Logo ROOT  
Reference Guide
MetropolisHastings.cxx
Go to the documentation of this file.
1// @(#)root/roostats:$Id$
2// Authors: Kevin Belasco 17/06/2009
3// Authors: Kyle Cranmer 17/06/2009
4/*************************************************************************
5 * Copyright (C) 1995-2008, Rene Brun and Fons Rademakers. *
6 * All rights reserved. *
7 * *
8 * For the licensing terms see $ROOTSYS/LICENSE. *
9 * For the list of contributors see $ROOTSYS/README/CREDITS. *
10 *************************************************************************/
11
12/** \class RooStats::MetropolisHastings
13 \ingroup Roostats
14
15This class uses the Metropolis-Hastings algorithm to construct a Markov Chain
16of data points using Monte Carlo. In the main algorithm, new points in the
17parameter space are proposed and then visited based on their relative
18likelihoods. This class can use any implementation of the ProposalFunction,
19including non-symmetric proposal functions, to propose parameter points and
20still maintain detailed balance when constructing the chain.
21
22
23
24The "Likelihood" function that is sampled when deciding what steps to take in
25the chain has been given a very generic implementation. The user can create
26any RooAbsReal based on the parameters and pass it to a MetropolisHastings
27object with the method SetFunction(RooAbsReal&). Be sure to tell
28MetropolisHastings whether your RooAbsReal is on a regular or log scale and
29whether it is positive or negated, so that it applies the correct acceptance logic.
30For example, a common use is to pass a negative log-likelihood (NLL), for which
31the appropriate configuration calls are SetType(MetropolisHastings::kLog);
32SetSign(MetropolisHastings::kNegative);
33If you're using a traditional likelihood function:
34SetType(MetropolisHastings::kRegular); SetSign(MetropolisHastings::kPositive);
35You must set these type and sign flags; otherwise ConstructChain() logs an error
36and returns NULL instead of constructing a MarkovChain.
37
38Also note that in ConstructChain(), the values of the variables are randomized
39uniformly over their intervals before construction of the MarkovChain begins.
40
41*/
42
44
49
50#include "Rtypes.h"
51#include "RooRealVar.h"
52#include "RooNLLVar.h"
53#include "RooGlobalFunc.h"
54#include "RooDataSet.h"
55#include "RooArgSet.h"
56#include "RooArgList.h"
57#include "RooMsgService.h"
58#include "RooRandom.h"
59#include "TH1.h"
60#include "TMath.h"
61#include "TFile.h"
62
64
65using namespace RooFit;
66using namespace RooStats;
67using namespace std;
68
69////////////////////////////////////////////////////////////////////////////////
70
71MetropolisHastings::MetropolisHastings()
72{
73 // Default constructor: the sampled function and the proposal function start
   // unset (NULL) and the number of iterations is zero. The parameters, function,
   // proposal function, type (SetType()) and sign (SetSign()) must all be
   // configured before ConstructChain() will build a chain; otherwise it
   // logs an error and returns NULL.
74 fFunction = NULL;
75 fPropFunc = NULL;
76 fNumIters = 0;
80}
81
82////////////////////////////////////////////////////////////////////////////////
83
85 ProposalFunction& proposalFunction, Int_t numIters)
86{
   // Convenience constructor: registers the parameters of interest and the
   // proposal function through their setters, and stores numIters as the
   // number of Metropolis-Hastings iterations ConstructChain() will perform.
88 SetParameters(paramsOfInterest);
89 SetProposalFunction(proposalFunction);
90 fNumIters = numIters;
94}
95
96////////////////////////////////////////////////////////////////////////////////
97
99{
   // Run the Metropolis-Hastings algorithm and return the resulting chain.
   //
   // After searching for a valid starting point, performs fNumIters
   // proposal / accept-reject iterations using fPropFunc and fFunction.
   // Whenever a proposal is accepted, the current point is added to the chain
   // with its NLL (see CalcNLL()) and a weight equal to the number of
   // iterations spent there; the final point is added after the loop.
   //
   // Returns a MarkovChain allocated with new (presumably owned by the caller
   // afterwards), or NULL when the parameters, proposal function or function
   // are missing, or when SetType()/SetSign() have not been called.
100 if (fParameters.getSize() == 0 || !fPropFunc || !fFunction) {
101 coutE(Eval) << "Critical members uninitialized: parameters, proposal " <<
102 "function, or (log) likelihood function" << endl;
103 return NULL;
104 }
105 if (fSign == kSignUnset || fType == kTypeUnset) {
106 coutE(Eval) << "Please set type and sign of your function using "
107 << "MetropolisHastings::SetType() and MetropolisHastings::SetSign()" <<
108 endl;
109 return NULL;
110 }
111
113
114 RooArgSet x;
115 RooArgSet xPrime;
116 x.addClone(fParameters);
118 xPrime.addClone(fParameters);
119 RandomizeCollection(xPrime);
120
121 MarkovChain* chain = new MarkovChain();
122 // only the POI will be added to the chain
124
125 Int_t weight = 0;
126 Double_t xL = 0.0, xPrimeL = 0.0, a = 0.0;
127
128 // ibucur: i think the user should have the possibility to display all the message
129 // levels should they want to; maybe a setPrintLevel would be appropriate
130 // (maybe for the other classes that use this approach as well)?
133
134 // We will need to check if log-likelihood evaluation left an error status.
135 // Now using faster eval error logging with CountErrors.
136 if (fType == kLog) {
138 //N.B: need to clear the count in case of previous errors !
139 // the clear needs also to be done after calling setEvalErrorLoggingMode
141 }
142
143 bool hadEvalError = true;
144
145 Int_t i = 0;
146 // get a good starting point for x
147 // for fType == kLog, this means that fFunction->getVal() did not cause
148 // an eval error
149 // for fType == kRegular this means fFunction->getVal() != 0
150 //
151 // kbelasco: i < 1000 is sort of arbitrary, but way higher than the number of
152 // steps we should have to take for any reasonable (log) likelihood function
153 while (i < 1000 && hadEvalError) {
156 xL = fFunction->getVal();
157
158 if (fType == kLog) {
159 if (RooAbsReal::numEvalErrors() > 0) {
161 hadEvalError = true;
162 } else
163 hadEvalError = false;
164 } else if (fType == kRegular) {
165 if (xL == 0.0)
166 hadEvalError = true;
167 else
168 hadEvalError = false;
169 } else
170 // for now the only 2 types are kLog and kRegular (won't get here)
171 hadEvalError = false;
172 ++i;
173 }
174
175 if(hadEvalError) {
176 coutE(Eval) << "Problem finding a good starting point in " <<
177 "MetropolisHastings::ConstructChain() " << endl;
178 }
   // NOTE(review): construction still proceeds from the last tried point
   // even when no valid starting point was found above.
179
180
181 ooccoutP((TObject *)0, Generation) << "Metropolis-Hastings progress: ";
182
183 // do main loop
184 for (i = 0; i < fNumIters; i++) {
185 // reset error handling flag
186 hadEvalError = false;
187
188 // print a dot every 1% of the chain construction
   // fNumIters / 100 is 0 when fNumIters < 100 and i % 0 is undefined
   // behaviour (integer division by zero), so short chains print a dot on
   // every iteration instead.
189 if (fNumIters < 100 || i % (fNumIters / 100) == 0) ooccoutP((TObject*)0, Generation) << ".";
190
191 fPropFunc->Propose(xPrime, x);
192
194 xPrimeL = fFunction->getVal();
195
196 // check if log-likelihood for xprime had an error status
197 if (fFunction->numEvalErrors() > 0 && fType == kLog) {
198 xPrimeL = RooNumber::infinity();
200 hadEvalError = true;
201 }
202
203 // why evaluate the last point again, can't we cache it?
204 // kbelasco: commenting out lines below to add/test caching support
205 //RooStats::SetParameters(&x, &fParameters);
206 //xL = fFunction->getVal();
207
   // `a` is the acceptance quantity consumed by ShouldTakeStep():
   // for kLog it is -log(L(xPrime)/L(x)); for kRegular it is L(xPrime)/L(x).
208 if (fType == kLog) {
209 if (fSign == kPositive)
210 a = xL - xPrimeL;
211 else
212 a = xPrimeL - xL;
213 }
214 else
215 a = xPrimeL / xL;
216 //a = xL / xPrimeL;
217
   // Hastings correction for asymmetric proposals: multiply the ratio by
   // q(x|xPrime)/q(xPrime|x) (kRegular), or add the negative log of that
   // factor (kLog).
218 if (!hadEvalError && !fPropFunc->IsSymmetric(xPrime, x)) {
219 Double_t xPrimePD = fPropFunc->GetProposalDensity(xPrime, x);
220 Double_t xPD = fPropFunc->GetProposalDensity(x, xPrime);
221 if (fType == kRegular)
222 a *= xPD / xPrimePD;
223 else
224 a += TMath::Log(xPrimePD) - TMath::Log(xPD);
225 }
226
227 if (!hadEvalError && ShouldTakeStep(a)) {
228 // go to the proposed point xPrime
229
230 // add the current point with the current weight
231 if (weight != 0)
232 chain->Add(x, CalcNLL(xL), (Double_t)weight);
233
234 // reset the weight and go to xPrime
235 weight = 1;
236 RooStats::SetParameters(&xPrime, &x);
237 xL = xPrimeL;
238 } else {
239 // stay at the current point
240 weight++;
241 }
242 }
243
244 // make sure to add the last point
245 if (weight != 0)
246 chain->Add(x, CalcNLL(xL), (Double_t)weight);
247 ooccoutP((TObject *)0, Generation) << endl;
248
250
   // NOTE(review): chain->Size() counts stored chain entries, which is not
   // exactly the number of accepted proposals (the starting point is stored
   // only if at least one proposal was rejected there). With fNumIters == 0
   // the ratio below evaluates 0/0.
251 Int_t numAccepted = chain->Size();
252 coutI(Eval) << "Proposal acceptance rate: " <<
253 numAccepted/(Float_t)fNumIters * 100 << "%" << endl;
254 coutI(Eval) << "Number of steps in chain: " << numAccepted << endl;
255
256 //TFile chainDataFile("chainData.root", "recreate");
257 //chain->GetDataSet()->Write();
258 //chainDataFile.Close();
259
260 return chain;
261}
262
263////////////////////////////////////////////////////////////////////////////////
264
266{
   // Metropolis-Hastings accept/reject decision for the acceptance quantity `a`
   // computed in ConstructChain():
   //  - kLog:     `a` is -log(likelihood ratio new/current), including any
   //              proposal-density correction. Always accept when a <= 0;
   //              otherwise accept when -log(rand) >= a, i.e. with
   //              probability exp(-a) if rand is uniform on (0,1).
   //  - kRegular: `a` is the likelihood ratio new/current, including any
   //              proposal-density correction. Always accept when a >= 1;
   //              otherwise accept when rand < a, i.e. with probability a.
   // Returns kTRUE to move to the proposed point, kFALSE to stay.
   // NOTE(review): the declaration of `rand` is not visible in this listing;
   // presumably it is a uniform draw on (0,1) (e.g. RooRandom::uniform()) - confirm.
267 if ((fType == kLog && a <= 0.0) || (fType == kRegular && a >= 1.0)) {
268 // The proposed point is at least as likely as the current one (after any
269 // proposal-density correction), so we always go there
270 return kTRUE;
271 }
272 else {
273 // draw a uniform random number to decide
274 // whether to go to xPrime or stay at x
275 //Double_t rand = fGen.Uniform(1.0);
277 if (fType == kLog) {
278 rand = TMath::Log(rand);
279 // kbelasco: should this be changed to just (-rand > a) for logical
280 // consistency with below test when fType == kRegular?
   // NOTE(review): for a continuous uniform draw, >= and > differ only on an
   // event of (essentially) zero probability, so either gives the same sampler.
281 if (-1.0 * rand >= a)
282 // we chose to go to the new proposed point
283 // even though it has a lower likelihood than the current one
284 return kTRUE;
285 } else {
286 // fType must be kRegular
287 // kbelasco: ensure that we never visit a point where PDF == 0
288 //if (rand <= a)
289 if (rand < a)
290 // we chose to go to the new proposed point
291 // even though it has a lower likelihood than the current one
292 return kTRUE;
293 }
294 return kFALSE;
295 }
296}
297
298////////////////////////////////////////////////////////////////////////////////
299
301{
   // Convert a raw function value xL into the negative log-likelihood (NLL)
   // recorded with each MarkovChain entry, according to fType and fSign:
   //   fType == kLog, fSign == kNegative : xL is already an NLL       -> xL
   //   fType == kLog, any other sign     : xL is a log-likelihood     -> -xL
   //   other type,    fSign == kPositive : xL is a likelihood         -> -log(xL)
   //   other type,    any other sign     : xL is a negated likelihood -> -log(-xL)
302 if (fType == kLog) {
303 if (fSign == kNegative)
304 return xL;
305 else
306 return -xL;
307 } else {
308 if (fSign == kPositive)
309 return -1.0 * TMath::Log(xL);
310 else
311 return -1.0 * TMath::Log(-xL);
312 }
313}
#define coutI(a)
Definition: RooMsgService.h:30
#define coutE(a)
Definition: RooMsgService.h:33
#define ooccoutP(o, a)
Definition: RooMsgService.h:54
const Bool_t kFALSE
Definition: RtypesCore.h:90
float Float_t
Definition: RtypesCore.h:55
const Bool_t kTRUE
Definition: RtypesCore.h:89
#define ClassImp(name)
Definition: Rtypes.h:361
Int_t getSize() const
RooAbsReal is the common abstract base class for objects that represent a real value and implements f...
Definition: RooAbsReal.h:60
Double_t getVal(const RooArgSet *normalisationSet=nullptr) const
Evaluate object.
Definition: RooAbsReal.h:90
static Int_t numEvalErrors()
Return the number of logged evaluation errors since the last clearing.
static void setEvalErrorLoggingMode(ErrorLoggingMode m)
Set evaluation error logging mode.
static void clearEvalErrorLog()
Clear the stack of evaluation error messages.
RooArgSet is a container object that can hold multiple RooAbsArg objects.
Definition: RooArgSet.h:28
virtual Bool_t add(const RooAbsCollection &col, Bool_t silent=kFALSE)
Add a collection of arguments to this collection by calling add() for each element in the source coll...
Definition: RooArgSet.h:88
virtual void addClone(const RooAbsCollection &col, Bool_t silent=kFALSE)
Add a collection of arguments to this collection by calling addOwned() for each element in the source...
Definition: RooArgSet.h:96
static RooMsgService & instance()
Return reference to singleton instance.
void setGlobalKillBelow(RooFit::MsgLevel level)
RooFit::MsgLevel globalKillBelow() const
static Double_t infinity()
Return internal infinity representation.
Definition: RooNumber.cxx:49
static Double_t uniform(TRandom *generator=randomGenerator())
Return a number uniformly distributed from (0,1)
Definition: RooRandom.cxx:83
Stores the steps in a Markov Chain of points.
Definition: MarkovChain.h:30
virtual void Add(RooArgSet &entry, Double_t nllValue, Double_t weight=1.0)
safely add an entry to the chain
virtual void SetParameters(RooArgSet &parameters)
set which of your parameters this chain should store
Definition: MarkovChain.cxx:77
virtual Int_t Size() const
get the number of steps in the chain
Definition: MarkovChain.h:49
This class uses the Metropolis-Hastings algorithm to construct a Markov Chain of data points using Mo...
virtual void SetProposalFunction(ProposalFunction &proposalFunction)
virtual Bool_t ShouldTakeStep(Double_t d)
virtual void SetParameters(const RooArgSet &set)
virtual MarkovChain * ConstructChain()
virtual Double_t CalcNLL(Double_t xL)
ProposalFunction is an interface for all proposal functions that would be used with a Markov Chain Mo...
virtual void Propose(RooArgSet &xPrime, RooArgSet &x)=0
Populate xPrime with the new proposed point, possibly based on the current point x.
virtual Double_t GetProposalDensity(RooArgSet &x1, RooArgSet &x2)=0
Return the probability of proposing the point x1 given the starting point x2.
virtual Bool_t IsSymmetric(RooArgSet &x1, RooArgSet &x2)=0
Determine whether or not the proposal density is symmetric for points x1 and x2 - that is,...
Mother of all ROOT objects.
Definition: TObject.h:37
Double_t x[n]
Definition: legend1.C:17
void function(const Char_t *name_, T fun, const Char_t *docstring=0)
Definition: RExports.h:151
The namespace RooFit contains mostly switches that change the behaviour of functions of PDFs (or othe...
MsgLevel
Verbosity level for RooMsgService::StreamConfig in RooMsgService.
Definition: RooGlobalFunc.h:65
@ Generation
Definition: RooGlobalFunc.h:67
Namespace for the RooStats classes.
Definition: Asimov.h:19
void SetParameters(const RooArgSet *desiredVals, RooArgSet *paramsToChange)
Definition: RooStatsUtils.h:65
void RandomizeCollection(RooAbsCollection &set, Bool_t randomizeConstants=kTRUE)
Double_t Log(Double_t x)
Definition: TMath.h:750
auto * a
Definition: textangle.C:12