#include "RooStats/ToyMCSampler.h"
#ifndef ROO_MSG_SERVICE
#include "RooMsgService.h"
#endif
#ifndef ROO_DATA_HIST
#include "RooDataHist.h"
#endif
#ifndef ROO_REAL_VAR
#include "RooRealVar.h"
#endif
#include "TCanvas.h"
#include "RooPlot.h"
#include "RooRandom.h"
#include "RooStudyManager.h"
#include "RooStats/ToyMCStudy.h"
#include "TMath.h"
ClassImp(RooStats::ToyMCSampler)
namespace RooStats {
class NuisanceParametersSampler {
public:
NuisanceParametersSampler(RooAbsPdf *prior=NULL, const RooArgSet *parameters=NULL, Int_t nToys=1000, Bool_t asimov=kFALSE) :
fPrior(prior),
fParams(parameters),
fNToys(nToys),
fExpected(asimov),
fPoints(NULL),
fIndex(0)
{
if(prior) Refresh();
}
void NextPoint(RooArgSet& nuisPoint, Double_t& weight) {
if (fIndex >= fNToys) {
Refresh();
fIndex = 0;
}
nuisPoint = *fPoints->get(fIndex++);
weight = fPoints->weight();
if(fPoints->weight() == 0.0) {
oocoutI((TObject*)NULL,Generation) << "Weight 0 encountered. Skipping." << endl;
NextPoint(nuisPoint, weight);
}
}
protected:
void Refresh() {
if (!fPrior || !fParams) return;
if (fPoints) delete fPoints;
if (fExpected) {
oocoutI((TObject*)NULL,InputArguments) << "Using expected nuisance parameters." << endl;
int nBins = fNToys;
TIter it2 = fParams->createIterator();
RooRealVar *myarg2;
while ((myarg2 = dynamic_cast<RooRealVar*>(it2.Next()))) {
myarg2->setBins(nBins);
}
fPoints = fPrior->generateBinned(
*fParams,
RooFit::ExpectedData(),
RooFit::NumEvents(1)
);
if(fPoints->numEntries() != fNToys) {
fNToys = fPoints->numEntries();
oocoutI((TObject*)NULL,InputArguments) <<
"Adjusted number of toys to number of bins of nuisance parameters: " << fNToys << endl;
}
}else{
oocoutI((TObject*)NULL,InputArguments) << "Using randomized nuisance parameters." << endl;
fPoints = fPrior->generate(*fParams, fNToys);
}
}
private:
RooAbsPdf *fPrior;
const RooArgSet *fParams;
Int_t fNToys;
Bool_t fExpected;
RooAbsData *fPoints;
Int_t fIndex;
};
Bool_t ToyMCSampler::CheckConfig(void) {
   // Checks that the test statistic, the observables, the null parameter
   // point and the pdf have all been set. Every missing item is reported
   // through the message service; nothing is guessed or filled in.
   // Returns true only if none of them is missing.
   bool allSet = true;

   if (!fTestStat) {
      ooccoutE((TObject*)NULL,InputArguments) << "Test statistic not set." << endl;
      allSet = false;
   }

   if (!fObservables) {
      ooccoutE((TObject*)NULL,InputArguments) << "Observables not set." << endl;
      allSet = false;
   }

   if (!fNullPOI) {
      ooccoutE((TObject*)NULL,InputArguments) << "Parameter values used to evaluate for test statistic not set." << endl;
      allSet = false;
   }

   if (!fPdf) {
      ooccoutE((TObject*)NULL,InputArguments) << "Pdf not set." << endl;
      allSet = false;
   }

   return allSet;
}
SamplingDistribution* ToyMCSampler::GetSamplingDistribution(RooArgSet& paramPointIn) {
   // Entry point for both serial and parallel runs.
   // Without a proof configuration the work is delegated to
   // GetSamplingDistributionSingleWorker(). Otherwise the toys are split
   // over the configured number of experiments, run through a
   // RooStudyManager and merged into a newly allocated SamplingDistribution,
   // which is returned to the caller.

   // ---- serial run ----
   if (!fProofConfig) {
      return GetSamplingDistributionSingleWorker(paramPointIn);
   }

   // ---- parallel run ----
   // (the outcome of the configuration check is reported but not acted upon)
   CheckConfig();

   // adaptive sampling is switched off for parallel runs; the setting is not
   // restored afterwards
   if (fToysInTails) {
      fToysInTails = 0;
      oocoutW((TObject*)NULL, InputArguments)
         << "Adaptive sampling in ToyMCSampler is not supported for parallel runs."
         << endl;
   }

   // give each experiment its share of the toys, rounded up, and remember
   // the requested total so it can be put back at the end
   const Int_t requestedToys = fNToys;
   fNToys = static_cast<int>(
      ceil(static_cast<double>(fNToys) / static_cast<double>(fProofConfig->GetNExperiments()))
   );

   // the study carries this sampler and the parameter point
   ToyMCStudy study;
   study.SetToyMCSampler(*this);
   study.SetParamPointOfInterest(paramPointIn);

   // local workspace built from the one held by the proof configuration
   RooWorkspace proofWorkspace(fProofConfig->GetWorkspace());
   RooStudyManager manager(proofWorkspace, study);
   manager.runProof(fProofConfig->GetNExperiments(), fProofConfig->GetHost(), fProofConfig->GetShowGui());

   // collect the results of the study into a single distribution
   SamplingDistribution *merged = new SamplingDistribution(GetSamplingDistName().c_str(), GetSamplingDistName().c_str());
   study.merge(*merged);

   // put the requested total number of toys back
   fNToys = requestedToys;

   return merged;
}
SamplingDistribution* ToyMCSampler::GetSamplingDistributionSingleWorker(RooArgSet& paramPointIn) {
   // Serial toy loop; GetSamplingDistribution() calls this when no proof
   // configuration is set.
   //
   // For every toy the pdf variables are set to the requested parameter point
   // (plus, if a nuisance prior and nuisance parameters are configured, a
   // sampled nuisance parameter point), a toy dataset is generated and the
   // test statistic is evaluated on it. Toys whose test statistic is NaN are
   // dropped. The loop ends once at least fNToys toys were thrown and the
   // (weighted) number of toys in the tails reached fToysInTails, or after
   // fMaxToys iterations.
   //
   // Returns a newly allocated SamplingDistribution (weighted if any weights
   // were collected); the caller takes ownership. The values of the pdf
   // variables are restored before returning.

   // (the outcome of the configuration check is reported but not acted upon)
   CheckConfig();

   std::vector<Double_t> testStatVec;
   std::vector<Double_t> testStatWeights;

   // private copy of the requested point; the pdf variables are reset from
   // this copy for every toy
   RooArgSet *paramPoint = (RooArgSet*) paramPointIn.snapshot();
   RooArgSet *allVars = fPdf->getVariables();
   // values to restore once all toys are done
   RooArgSet *saveAll = (RooArgSet*) allVars->snapshot();

   // nuisance parameter points are only sampled if both the prior and the
   // parameters are configured
   NuisanceParametersSampler *np = NULL;
   if(fPriorNuisance && fNuisancePars)
      np = new NuisanceParametersSampler(fPriorNuisance, fNuisancePars, fNToys, fExpectedNuisancePar);

   // (weighted) number of toys outside the adaptive limits
   Double_t toysInTails = 0.0;

   for (Int_t i = 0; i < fMaxToys; ++i) {
      // progress message every 500 toys
      if ( i% 500 == 0 && i>0 ) {
         oocoutP((TObject*)0,Generation) << "generated toys: " << i << " / " << fNToys;
         if (fToysInTails) ooccoutP((TObject*)0,Generation) << " (tails: " << toysInTails << " / " << fToysInTails << ")" << std::endl;
         else ooccoutP((TObject*)0,Generation) << endl;
      }

      // NOTE(review): GenerateToyData() forwards the result of Generate(),
      // which is NULL when the pdf cannot be extended and no number of events
      // is configured; toydata is dereferenced unchecked below.
      Double_t value, weight;
      if (np) {
         // requested point first, then the sampled nuisance point on top
         *allVars = *paramPoint;
         np->NextPoint(*allVars, weight);

         RooAbsData* toydata = GenerateToyData(*allVars);
         value = fTestStat->Evaluate(*toydata, *fNullPOI);

         if(fImportanceDensity) {
            // importance sampling: scale the weight by
            // L(pdf) / L(importance density) = exp( NLL(imp) - NLL(pdf) )
            *allVars = *fImportanceSnapshot;
            RooAbsReal *impNLL = fImportanceDensity->createNLL(*toydata, RooFit::Extended(kFALSE), RooFit::CloneData(kFALSE));
            double impNLLVal = impNLL->getVal();
            delete impNLL;

            *allVars = *paramPoint;
            RooAbsReal *pdfNLL = fPdf->createNLL(*toydata, RooFit::Extended(kFALSE), RooFit::CloneData(kFALSE));
            double pdfNLLVal = pdfNLL->getVal();
            delete pdfNLL;

            weight *= exp(impNLLVal - pdfNLLVal);
         }
         delete toydata;
      }else{
         *allVars = *paramPoint;
         RooAbsData* toydata = GenerateToyData(*allVars);
         value = fTestStat->Evaluate(*toydata, *fNullPOI);
         // negative weight marks an unweighted toy: it is not stored in
         // testStatWeights and counts as 1 in the tails
         weight = -1.;
         delete toydata;
      }

      // NaN check (NaN is the only value that differs from itself)
      if(value != value) {
         oocoutW((TObject*)NULL, Generation) << "skip: " << value << ", " << weight << endl;
         continue;
      }

      testStatVec.push_back(value);
      if(weight >= 0.) testStatWeights.push_back(weight);

      // adaptive sampling bookkeeping
      if (value <= fAdaptiveLowLimit || value >= fAdaptiveHighLimit) {
         if(weight >= 0.) toysInTails += weight;
         else toysInTails += 1.;
      }
      if (toysInTails >= fToysInTails && i+1 >= fNToys) break;
   }

   // restore the variables and release everything allocated above; this
   // includes the snapshot of the parameter point, which used to leak on
   // every call
   *allVars = *saveAll;
   delete saveAll;
   delete allVars;
   delete paramPoint;
   delete np;

   if (testStatWeights.size()) {
      return new SamplingDistribution(
         fSamplingDistName.c_str(),
         fSamplingDistName.c_str(),
         testStatVec,
         testStatWeights,
         fTestStat->GetVarName()
      );
   }
   return new SamplingDistribution(
      fSamplingDistName.c_str(),
      fSamplingDistName.c_str(),
      testStatVec,
      fTestStat->GetVarName()
   );
}
RooAbsData* ToyMCSampler::GenerateToyData(RooArgSet& ) const {
   // Generates one toy dataset from the current state of the pdf variables.
   // The (unnamed) argument is not used.
   //
   // If global observables are set, one event of them is generated first and
   // assigned to the pdf variables; they are then left out of the set of
   // observables the toy is generated for.
   //
   // Without an importance density the toy is generated from fPdf. With one,
   // it is generated from the importance density (at the importance snapshot,
   // if set) and the values of the variables are restored afterwards.
   //
   // Returns whatever Generate() returns; the caller takes ownership.

   RooArgSet observables(*fObservables);
   if(fGlobalObservables && fGlobalObservables->getSize()) {
      observables.remove(*fGlobalObservables);

      // generate one set of global observables and assign it
      RooDataSet *one = fPdf->generate(*fGlobalObservables, 1);
      const RooArgSet *values = one->get();
      RooArgSet *allVars = fPdf->getVariables();
      *allVars = *values;
      delete allVars;
      // 'values' points at storage owned by the dataset 'one' and goes away
      // with it; deleting it separately would free it twice
      delete one;
   }

   RooAbsData* data = NULL;

   if(!fImportanceDensity) {
      data = Generate(*fPdf, observables);
   }else{
      // variables of both pdfs, and a snapshot of their current values
      RooArgSet* allVars = fPdf->getVariables();
      RooArgSet* allVars2 = fImportanceDensity->getVariables();
      allVars->add(*allVars2);
      const RooArgSet* saveVars = (const RooArgSet*)allVars->snapshot();

      // if no number of events is configured, use the (truncated) expected
      // number of events of fPdf with a Poisson fluctuation
      int forceEvents = 0;
      if(fNEvents == 0) {
         forceEvents = (int)fPdf->expectedEvents(observables);
         forceEvents = RooRandom::randomGenerator()->Poisson(forceEvents);
      }

      if(fImportanceSnapshot) *allVars = *fImportanceSnapshot;

      data = Generate(*fImportanceDensity, observables, NULL, forceEvents);

      // put the variables back to where they were
      *allVars = *saveVars;
      delete allVars;
      delete allVars2;
      delete saveVars;
   }

   return data;
}
RooAbsData* ToyMCSampler::Generate(RooAbsPdf &pdf, RooArgSet &observables, const RooDataSet* protoData, int forceEvents) const {
   // Generates a dataset for the given observables from pdf, honouring the
   // sampler settings:
   //  - proto data set on the sampler replaces the protoData argument, and
   //    its number of entries replaces forceEvents
   //  - the number of events is forceEvents if non-zero, else fNEvents
   //  - if that number is still zero, an extended generation is done, which
   //    requires a pdf that can be extended with a positive expected number
   //    of events; otherwise an error is logged and NULL is returned
   //  - fGenerateBinned selects generateBinned() over generate()
   // The caller takes ownership of the returned dataset.

   if (fProtoData) {
      protoData = fProtoData;
      forceEvents = protoData->numEntries();
   }

   int events = (forceEvents != 0) ? forceEvents : fNEvents;

   // ---- explicit number of events ----
   if (events != 0) {
      if (fGenerateBinned) {
         if (protoData) return pdf.generateBinned(observables, events, RooFit::ProtoData(*protoData, true, true));
         return pdf.generateBinned(observables, events);
      }
      if (protoData) return pdf.generate(observables, events, RooFit::ProtoData(*protoData, true, true));
      return pdf.generate(observables, events);
   }

   // ---- no number of events given: only an extended pdf can decide ----
   if (!pdf.canBeExtended() || !(pdf.expectedEvents(observables) > 0)) {
      oocoutE((TObject*)0,InputArguments)
         << "ToyMCSampler: Error : pdf is not extended and number of events per toy is zero"
         << endl;
      return NULL;
   }

   if (fGenerateBinned) {
      if (protoData) return pdf.generateBinned(observables, RooFit::Extended(), RooFit::ProtoData(*protoData, true, true));
      return pdf.generateBinned(observables, RooFit::Extended());
   }
   if (protoData) return pdf.generate(observables, RooFit::Extended(), RooFit::ProtoData(*protoData, true, true));
   return pdf.generate(observables, RooFit::Extended());
}
}