#ifndef ROOT_Rtypes
#include "Rtypes.h"
#endif
#ifndef ROO_GLOBAL_FUNC
#include "RooGlobalFunc.h"
#endif
#ifndef ROO_ABS_REAL
#include "RooAbsReal.h"
#endif
#ifndef ROO_ARG_SET
#include "RooArgSet.h"
#endif
#ifndef ROO_ARG_LIST
#include "RooArgList.h"
#endif
#ifndef ROOSTATS_ModelConfig
#include "RooStats/ModelConfig.h"
#endif
#ifndef RooStats_RooStatsUtils
#include "RooStats/RooStatsUtils.h"
#endif
#ifndef ROOSTATS_MCMCCalculator
#include "RooStats/MCMCCalculator.h"
#endif
#ifndef ROOSTATS_MetropolisHastings
#include "RooStats/MetropolisHastings.h"
#endif
#ifndef ROOSTATS_MarkovChain
#include "RooStats/MarkovChain.h"
#endif
#ifndef RooStats_MCMCInterval
#include "RooStats/MCMCInterval.h"
#endif
#ifndef ROOT_TIterator
#include "TIterator.h"
#endif
#ifndef ROOSTATS_UniformProposal
#include "RooStats/UniformProposal.h"
#endif
#ifndef ROOSTATS_PdfProposal
#include "RooStats/PdfProposal.h"
#endif
#ifndef ROO_PROD_PDF
#include "RooProdPdf.h"
#endif
// ROOT class-implementation macro for the calculator defined in this file.
ClassImp(RooStats::MCMCCalculator);
// The definitions below use RooFit and RooStats names unqualified
// (e.g. Constrain(), MCMCInterval, MetropolisHastings).
using namespace RooFit;
using namespace RooStats;
////////////////////////////////////////////////////////////////////////////////
/// Default constructor.
///
/// Leaves the calculator unconfigured: no proposal function, pdf, prior pdf,
/// data set or axes, zero iterations / burn-in steps / bins, and the test
/// size, left-side tail fraction, epsilon and delta all set to -1.
/// GetInterval() refuses to run while the data, pdf or test size are missing.
MCMCCalculator::MCMCCalculator() :
   fPropFunc(NULL),
   fPdf(NULL),
   fPriorPdf(NULL),
   fData(NULL),
   fAxes(NULL)
{
   // Interval definition: shortest interval, with the numeric knobs "unset".
   fIntervalType = MCMCInterval::kShortest;
   fSize         = -1;
   fLeftSideTF   = -1;
   fEpsilon      = -1;
   fDelta        = -1;

   // Chain length and binning.
   fNumBins        = 0;
   fNumBurnInSteps = 0;
   fNumIters       = 0;

   // Posterior representation options.
   fUseSparseHist = kFALSE;
   fUseKeys       = kFALSE;
}
////////////////////////////////////////////////////////////////////////////////
/// Construct a calculator for a data set and a model configuration.
///
/// Only the address of `data` is stored, so the data set must outlive this
/// calculator. The pdf, prior pdf, parameters of interest and nuisance
/// parameters are read from `model` through SetModel(); every remaining
/// setting receives the defaults chosen by SetupBasicUsage().
///
/// \param data   data set used when building the likelihood
/// \param model  source of the pdf, prior pdf and parameter sets
MCMCCalculator::MCMCCalculator(RooAbsData& data, const ModelConfig & model) :
   fPropFunc(NULL),
   fData(&data),
   fAxes(NULL)
{
   // Model-derived members first, then the default run configuration.
   SetModel(model);
   SetupBasicUsage();
}
void MCMCCalculator::SetModel(const ModelConfig & model) {
fPdf = model.GetPdf();
fPriorPdf = model.GetPriorPdf();
fPOI.removeAll();
fNuisParams.removeAll();
if (model.GetParametersOfInterest())
fPOI.add(*model.GetParametersOfInterest());
if (model.GetNuisanceParameters())
fNuisParams.add(*model.GetNuisanceParameters());
}
////////////////////////////////////////////////////////////////////////////////
/// Apply the default run configuration.
///
/// Defaults: no user proposal function (GetInterval() then creates a
/// UniformProposal itself), 10000 iterations, 40 burn-in steps, 50 bins,
/// neither keys nor sparse histograms, test size 0.05, shortest-interval type,
/// and left-side tail fraction / epsilon / delta set to -1.
void MCMCCalculator::SetupBasicUsage()
{
   // Sampling setup.
   fPropFunc       = 0;
   fNumBins        = 50;
   fNumBurnInSteps = 40;
   fNumIters       = 10000;

   // Posterior representation options.
   fUseSparseHist = kFALSE;
   fUseKeys       = kFALSE;

   // 0.05 test size; GetInterval() passes 1 - size as the confidence level.
   SetTestSize(0.05);

   // Interval definition; -1 leaves the optional numeric settings unset.
   fIntervalType = MCMCInterval::kShortest;
   fDelta        = -1;
   fEpsilon      = -1;
   fLeftSideTF   = -1;
}
////////////////////////////////////////////////////////////////////////////////
/// Request a tail-fraction interval with the given left-side tail fraction.
///
/// On success the fraction is stored and the interval type is switched to
/// MCMCInterval::kTailFraction. A value outside [0, 1] (including NaN) is
/// reported through the RooFit message service and leaves the current
/// configuration untouched.
///
/// \param a  left-side tail fraction, expected to lie in [0, 1]
void MCMCCalculator::SetLeftSideTailFraction(Double_t a)
{
   // Written as a negated "in range" test so that NaN is rejected as well;
   // `a < 0 || a > 1` is false for NaN and would let it through.
   if (!(a >= 0 && a <= 1)) {
      coutE(InputArguments) << "MCMCCalculator::SetLeftSideTailFraction: "
         << "Fraction must be in the range [0, 1]. "
         << a << " is not allowed." << endl;
      return;
   }

   fLeftSideTF = a;
   fIntervalType = MCMCInterval::kTailFraction;
}
////////////////////////////////////////////////////////////////////////////////
/// Build a Markov chain with the Metropolis-Hastings algorithm and wrap it in
/// an MCMCInterval.
///
/// Steps:
///  1. Validate the configuration (data, pdf, parameters of interest, size).
///  2. If a prior pdf is set, sample from the product pdf * prior.
///  3. Create the negative log-likelihood and collect its non-constant
///     parameters.
///  4. Run Metropolis-Hastings with the configured proposal function, or with
///     a temporary UniformProposal when none was set.
///  5. Configure the resulting interval (axes, burn-in, keys / sparse
///     histogram, interval type, epsilon, delta, confidence level).
///
/// \return a newly allocated MCMCInterval that the caller must delete, or a
///         null pointer when the configuration is incomplete or the chain
///         could not be constructed.
MCMCInterval* MCMCCalculator::GetInterval() const
{
   if (!fData || !fPdf ) return 0;
   if (fPOI.getSize() == 0) return 0;

   if (fSize < 0) {
      coutE(InputArguments) << "MCMCCalculator::GetInterval: "
         << "Test size/Confidence level not set. Returning NULL." << endl;
      return NULL;
   }

   const bool useDefaultPropFunc = (fPropFunc == 0);
   const bool usePriorPdf = (fPriorPdf != 0);

   // The default proposal only lives for the duration of this call; it is
   // released (and fPropFunc reset) in the cleanup section at the end.
   if (useDefaultPropFunc) fPropFunc = new UniformProposal();

   // With a prior, sample from pdf * prior instead of the bare pdf.
   RooAbsPdf * prodPdf = fPdf;
   if (usePriorPdf) {
      TString prodName = TString("product_") + TString(fPdf->GetName()) + TString("_") + TString(fPriorPdf->GetName() );
      prodPdf = new RooProdPdf(prodName,prodName,RooArgList(*fPdf,*fPriorPdf) );
   }

   RooArgSet* constrainedParams = prodPdf->getParameters(*fData);
   RooAbsReal* nll = prodPdf->createNLL(*fData, Constrain(*constrainedParams));
   delete constrainedParams;

   // Only non-constant parameters are handed to the sampler.
   RooArgSet* params = nll->getParameters(*fData);
   RemoveConstantParameters(params);

   if (fNumBins > 0) {
      SetBins(*params, fNumBins);
      SetBins(fPOI, fNumBins);

      // A pdf-based proposal has variables of its own that need the same
      // binning.
      PdfProposal* pdfProposal = dynamic_cast<PdfProposal*>(fPropFunc);
      if (pdfProposal != NULL) {
         RooArgSet* proposalVars = pdfProposal->GetPdf()->
            getParameters(static_cast<RooAbsData*>(NULL));
         SetBins(*proposalVars, fNumBins);
         // The set is owned by us, like the other getParameters() results
         // in this function; the original code leaked it.
         delete proposalVars;
      }
   }

   MetropolisHastings mh;
   mh.SetFunction(*nll);
   mh.SetType(MetropolisHastings::kLog);
   mh.SetSign(MetropolisHastings::kNegative);
   mh.SetParameters(*params);
   mh.SetProposalFunction(*fPropFunc);
   mh.SetNumIters(fNumIters);

   MarkovChain* chain = mh.ConstructChain();

   MCMCInterval* interval = NULL;
   if (chain != NULL) {
      TString name = TString("MCMCInterval_") + TString(GetName() );
      interval = new MCMCInterval(name, fPOI, *chain);

      if (fAxes != NULL)
         interval->SetAxes(*fAxes);
      if (fNumBurnInSteps > 0)
         interval->SetNumBurnInSteps(fNumBurnInSteps);
      interval->SetUseKeys(fUseKeys);
      interval->SetUseSparseHist(fUseSparseHist);
      interval->SetIntervalType(fIntervalType);
      if (fIntervalType == MCMCInterval::kTailFraction)
         interval->SetLeftSideTailFraction(fLeftSideTF);
      // Negative epsilon / delta mean "not set": keep the interval's own
      // defaults.
      if (fEpsilon >= 0)
         interval->SetEpsilon(fEpsilon);
      if (fDelta >= 0)
         interval->SetDelta(fDelta);
      interval->SetConfidenceLevel(1.0 - fSize);
   } else {
      // NOTE(review): presumably happens when no non-constant parameter is
      // left to sample -- confirm against MetropolisHastings::ConstructChain.
      coutE(InputArguments) << "MCMCCalculator::GetInterval: "
         << "Markov chain could not be constructed. Returning NULL." << endl;
   }

   // Release temporaries in reverse order of creation.
   delete params;
   delete nll;
   if (usePriorPdf) delete prodPdf;
   if (useDefaultPropFunc) {
      delete fPropFunc;
      // Reset so the next call creates a fresh default proposal instead of
      // reusing the pointer just deleted.
      fPropFunc = 0;
   }

   return interval;
}