Logo ROOT   6.14/05
Reference Guide
MetropolisHastings.cxx
Go to the documentation of this file.
1 // @(#)root/roostats:$Id$
2 // Authors: Kevin Belasco 17/06/2009
3 // Authors: Kyle Cranmer 17/06/2009
4 /*************************************************************************
5  * Copyright (C) 1995-2008, Rene Brun and Fons Rademakers. *
6  * All rights reserved. *
7  * *
8  * For the licensing terms see $ROOTSYS/LICENSE. *
9  * For the list of contributors see $ROOTSYS/README/CREDITS. *
10  *************************************************************************/
11 
12 /** \class RooStats::MetropolisHastings
13  \ingroup Roostats
14 
15 This class uses the Metropolis-Hastings algorithm to construct a Markov Chain
16 of data points using Monte Carlo. In the main algorithm, new points in the
17 parameter space are proposed and then visited based on their relative
18 likelihoods. This class can use any implementation of the ProposalFunction,
19 including non-symmetric proposal functions, to propose parameter points and
20 still maintain detailed balance when constructing the chain.
21 
22 
23 
24 The "Likelihood" function that is sampled when deciding what steps to take in
25 the chain has been given a very generic implementation. The user can create
26 any RooAbsReal based on the parameters and pass it to a MetropolisHastings
27 object with the method SetFunction(RooAbsReal&). Be sure to tell
28 MetropolisHastings whether your RooAbsReal is on a (+/-) regular or log scale,
29 so that it knows what logic to use when sampling your RooAbsReal. For example,
30 a common use is to sample from a -log(Likelihood) distribution (NLL), for which
31 the appropriate configuration calls are SetType(MetropolisHastings::kLog);
32 SetSign(MetropolisHastings::kNegative);
33 If you're using a traditional likelihood function:
34 SetType(MetropolisHastings::kRegular); SetSign(MetropolisHastings::kPositive);
35 You must set these type and sign flags or MetropolisHastings will not construct
36 a MarkovChain.
37 
38 Also note that in ConstructChain(), the values of the variables are randomized
39 uniformly over their intervals before construction of the MarkovChain begins.
40 
41 */
42 
44 
45 #include "RooStats/MarkovChain.h"
46 #include "RooStats/MCMCInterval.h"
47 #include "RooStats/RooStatsUtils.h"
49 
50 #include "Rtypes.h"
51 #include "RooRealVar.h"
52 #include "RooNLLVar.h"
53 #include "RooGlobalFunc.h"
54 #include "RooDataSet.h"
55 #include "RooArgSet.h"
56 #include "RooArgList.h"
57 #include "RooMsgService.h"
58 #include "RooRandom.h"
59 #include "TH1.h"
60 #include "TMath.h"
61 #include "TFile.h"
62 
64 
65 using namespace RooFit;
66 using namespace RooStats;
67 using namespace std;
68 
69 ////////////////////////////////////////////////////////////////////////////////
70 
71 MetropolisHastings::MetropolisHastings()
72 {
73  // default constructor
74  fFunction = NULL;
75  fPropFunc = NULL;
76  fNumIters = 0;
77  fNumBurnInSteps = 0;
78  fSign = kSignUnset;
79  fType = kTypeUnset;
80 }
81 
82 ////////////////////////////////////////////////////////////////////////////////
83 
84 MetropolisHastings::MetropolisHastings(RooAbsReal& function, const RooArgSet& paramsOfInterest,
85  ProposalFunction& proposalFunction, Int_t numIters)
86 {
87  fFunction = &function;
88  SetParameters(paramsOfInterest);
89  SetProposalFunction(proposalFunction);
90  fNumIters = numIters;
91  fNumBurnInSteps = 0;
92  fSign = kSignUnset;
93  fType = kTypeUnset;
94 }
95 
96 ////////////////////////////////////////////////////////////////////////////////
97 
98 MarkovChain* MetropolisHastings::ConstructChain()
99 {
100  if (fParameters.getSize() == 0 || !fPropFunc || !fFunction) {
101  coutE(Eval) << "Critical members unintialized: parameters, proposal " <<
102  " function, or (log) likelihood function" << endl;
103  return NULL;
104  }
105  if (fSign == kSignUnset || fType == kTypeUnset) {
106  coutE(Eval) << "Please set type and sign of your function using "
107  << "MetropolisHastings::SetType() and MetropolisHastings::SetSign()" <<
108  endl;
109  return NULL;
110  }
111 
112  if (fChainParams.getSize() == 0) fChainParams.add(fParameters);
113 
114  RooArgSet x;
115  RooArgSet xPrime;
116  x.addClone(fParameters);
118  xPrime.addClone(fParameters);
119  RandomizeCollection(xPrime);
120 
121  MarkovChain* chain = new MarkovChain();
122  // only the POI will be added to the chain
123  chain->SetParameters(fChainParams);
124 
125  Int_t weight = 0;
126  Double_t xL = 0.0, xPrimeL = 0.0, a = 0.0;
127 
128  // ibucur: i think the user should have the possibility to display all the message
129  // levels should they want to; maybe a setPrintLevel would be appropriate
130  // (maybe for the other classes that use this approach as well)?
133 
134  // We will need to check if log-likelihood evaluation left an error status.
135  // Now using faster eval error logging with CountErrors.
136  if (fType == kLog) {
138  //N.B: need to clear the count in case of previous errors !
139  // the clear needs also to be done after calling setEvalErrorLoggingMode
141  }
142 
143  bool hadEvalError = true;
144 
145  Int_t i = 0;
146  // get a good starting point for x
147  // for fType == kLog, this means that fFunction->getVal() did not cause
148  // an eval error
149  // for fType == kRegular this means fFunction->getVal() != 0
150  //
151  // kbelasco: i < 1000 is sort of arbitrary, but way higher than the number of
152  // steps we should have to take for any reasonable (log) likelihood function
153  while (i < 1000 && hadEvalError) {
155  RooStats::SetParameters(&x, &fParameters);
156  xL = fFunction->getVal();
157 
158  if (fType == kLog) {
159  if (RooAbsReal::numEvalErrors() > 0) {
161  hadEvalError = true;
162  } else
163  hadEvalError = false;
164  } else if (fType == kRegular) {
165  if (xL == 0.0)
166  hadEvalError = true;
167  else
168  hadEvalError = false;
169  } else
170  // for now the only 2 types are kLog and kRegular (won't get here)
171  hadEvalError = false;
172  ++i;
173  }
174 
175  if(hadEvalError) {
176  coutE(Eval) << "Problem finding a good starting point in " <<
177  "MetropolisHastings::ConstructChain() " << endl;
178  }
179 
180 
181  ooccoutP((TObject *)0, Generation) << "Metropolis-Hastings progress: ";
182 
183  // do main loop
184  for (i = 0; i < fNumIters; i++) {
185  // reset error handling flag
186  hadEvalError = false;
187 
188  // print a dot every 1% of the chain construction
189  if (i % (fNumIters / 100) == 0) ooccoutP((TObject*)0, Generation) << ".";
190 
191  fPropFunc->Propose(xPrime, x);
192 
193  RooStats::SetParameters(&xPrime, &fParameters);
194  xPrimeL = fFunction->getVal();
195 
196  // check if log-likelihood for xprime had an error status
197  if (fFunction->numEvalErrors() > 0 && fType == kLog) {
198  xPrimeL = RooNumber::infinity();
199  fFunction->clearEvalErrorLog();
200  hadEvalError = true;
201  }
202 
203  // why evaluate the last point again, can't we cache it?
204  // kbelasco: commenting out lines below to add/test caching support
205  //RooStats::SetParameters(&x, &fParameters);
206  //xL = fFunction->getVal();
207 
208  if (fType == kLog) {
209  if (fSign == kPositive)
210  a = xL - xPrimeL;
211  else
212  a = xPrimeL - xL;
213  }
214  else
215  a = xPrimeL / xL;
216  //a = xL / xPrimeL;
217 
218  if (!hadEvalError && !fPropFunc->IsSymmetric(xPrime, x)) {
219  Double_t xPrimePD = fPropFunc->GetProposalDensity(xPrime, x);
220  Double_t xPD = fPropFunc->GetProposalDensity(x, xPrime);
221  if (fType == kRegular)
222  a *= xPD / xPrimePD;
223  else
224  a += TMath::Log(xPrimePD) - TMath::Log(xPD);
225  }
226 
227  if (!hadEvalError && ShouldTakeStep(a)) {
228  // go to the proposed point xPrime
229 
230  // add the current point with the current weight
231  if (weight != 0.0)
232  chain->Add(x, CalcNLL(xL), (Double_t)weight);
233 
234  // reset the weight and go to xPrime
235  weight = 1;
236  RooStats::SetParameters(&xPrime, &x);
237  xL = xPrimeL;
238  } else {
239  // stay at the current point
240  weight++;
241  }
242  }
243 
244  // make sure to add the last point
245  if (weight != 0.0)
246  chain->Add(x, CalcNLL(xL), (Double_t)weight);
247  ooccoutP((TObject *)0, Generation) << endl;
248 
250 
251  Int_t numAccepted = chain->Size();
252  coutI(Eval) << "Proposal acceptance rate: " <<
253  numAccepted/(Float_t)fNumIters * 100 << "%" << endl;
254  coutI(Eval) << "Number of steps in chain: " << numAccepted << endl;
255 
256  //TFile chainDataFile("chainData.root", "recreate");
257  //chain->GetDataSet()->Write();
258  //chainDataFile.Close();
259 
260  return chain;
261 }
262 
263 ////////////////////////////////////////////////////////////////////////////////
264 
265 Bool_t MetropolisHastings::ShouldTakeStep(Double_t a)
266 {
267  if ((fType == kLog && a <= 0.0) || (fType == kRegular && a >= 1.0)) {
268  // The proposed point has a higher likelihood than the
269  // current point, so we should go there
270  return kTRUE;
271  }
272  else {
273  // generate numbers on a log distribution to decide
274  // whether to go to xPrime or stay at x
275  //Double_t rand = fGen.Uniform(1.0);
276  Double_t rand = RooRandom::uniform();
277  if (fType == kLog) {
278  rand = TMath::Log(rand);
279  // kbelasco: should this be changed to just (-rand > a) for logical
280  // consistency with below test when fType == kRegular?
281  if (-1.0 * rand >= a)
282  // we chose to go to the new proposed point
283  // even though it has a lower likelihood than the current one
284  return kTRUE;
285  } else {
286  // fType must be kRegular
287  // kbelasco: ensure that we never visit a point where PDF == 0
288  //if (rand <= a)
289  if (rand < a)
290  // we chose to go to the new proposed point
291  // even though it has a lower likelihood than the current one
292  return kTRUE;
293  }
294  return kFALSE;
295  }
296 }
297 
298 ////////////////////////////////////////////////////////////////////////////////
299 
300 Double_t MetropolisHastings::CalcNLL(Double_t xL)
301 {
302  if (fType == kLog) {
303  if (fSign == kNegative)
304  return xL;
305  else
306  return -xL;
307  } else {
308  if (fSign == kPositive)
309  return -1.0 * TMath::Log(xL);
310  else
311  return -1.0 * TMath::Log(-xL);
312  }
313 }
ProposalFunction is an interface for all proposal functions that would be used with a Markov Chain Mo...
#define coutE(a)
Definition: RooMsgService.h:34
Double_t Log(Double_t x)
Definition: TMath.h:759
float Float_t
Definition: RtypesCore.h:53
#define coutI(a)
Definition: RooMsgService.h:31
void SetParameters(const RooArgSet *desiredVals, RooArgSet *paramsToChange)
Definition: RooStatsUtils.h:58
RooFit::MsgLevel globalKillBelow() const
static void clearEvalErrorLog()
Clear the stack of evaluation error messages.
virtual Int_t Size() const
get the number of steps in the chain
Definition: MarkovChain.h:49
int Int_t
Definition: RtypesCore.h:41
bool Bool_t
Definition: RtypesCore.h:59
static RooMsgService & instance()
Return reference to singleton instance.
#define ooccoutP(o, a)
Definition: RooMsgService.h:52
STL namespace.
This class uses the Metropolis-Hastings algorithm to construct a Markov Chain of data points using Mo...
static void setEvalErrorLoggingMode(ErrorLoggingMode m)
Set evaluation error logging mode.
void SetParameters(TFitEditor::FuncParams_t &pars, TF1 *func)
Restore the parameters from pars into the function.
Definition: TFitEditor.cxx:287
Double_t x[n]
Definition: legend1.C:17
virtual void Add(RooArgSet &entry, Double_t nllValue, Double_t weight=1.0)
safely add an entry to the chain
virtual void addClone(const RooAbsCollection &col, Bool_t silent=kFALSE)
Add a collection of arguments to this collection by calling addOwned() for each element in the source...
Definition: RooArgSet.h:94
static Int_t numEvalErrors()
Return the number of logged evaluation errors since the last clearing.
auto * a
Definition: textangle.C:12
static Double_t infinity()
Return internal infinity representation.
Definition: RooNumber.cxx:49
void setGlobalKillBelow(RooFit::MsgLevel level)
void RandomizeCollection(RooAbsCollection &set, Bool_t randomizeConstants=kTRUE)
Definition: RooStatsUtils.h:99
const Bool_t kFALSE
Definition: RtypesCore.h:88
PyObject * fType
Stores the steps in a Markov Chain of points.
Definition: MarkovChain.h:30
Namespace for the RooStats classes.
Definition: Asimov.h:20
#define ClassImp(name)
Definition: Rtypes.h:359
static Double_t uniform(TRandom *generator=randomGenerator())
Return a number uniformly distributed from (0,1)
Definition: RooRandom.cxx:84
double Double_t
Definition: RtypesCore.h:55
RooAbsReal is the common abstract base class for objects that represent a real value and implements f...
Definition: RooAbsReal.h:53
virtual void SetParameters(RooArgSet &parameters)
set which of your parameters this chain should store
Definition: MarkovChain.cxx:77
Mother of all ROOT objects.
Definition: TObject.h:37
const Bool_t kTRUE
Definition: RtypesCore.h:87