Logo ROOT  
Reference Guide
MetropolisHastings.cxx
Go to the documentation of this file.
1 // @(#)root/roostats:$Id$
2 // Authors: Kevin Belasco 17/06/2009
3 // Authors: Kyle Cranmer 17/06/2009
4 /*************************************************************************
5  * Copyright (C) 1995-2008, Rene Brun and Fons Rademakers. *
6  * All rights reserved. *
7  * *
8  * For the licensing terms see $ROOTSYS/LICENSE. *
9  * For the list of contributors see $ROOTSYS/README/CREDITS. *
10  *************************************************************************/
11 
12 /** \class RooStats::MetropolisHastings
13  \ingroup Roostats
14 
15 This class uses the Metropolis-Hastings algorithm to construct a Markov Chain
16 of data points using Monte Carlo. In the main algorithm, new points in the
17 parameter space are proposed and then visited based on their relative
18 likelihoods. This class can use any implementation of the ProposalFunction,
19 including non-symmetric proposal functions, to propose parameter points and
20 still maintain detailed balance when constructing the chain.
21 
22 
23 
24 The "Likelihood" function that is sampled when deciding what steps to take in
25 the chain has been given a very generic implementation. The user can create
26 any RooAbsReal based on the parameters and pass it to a MetropolisHastings
27 object with the method SetFunction(RooAbsReal&). Be sure to tell
28 MetropolisHastings whether your RooAbsReal is on a (+/-) regular or log scale,
29 so that it knows what logic to use when sampling your RooAbsReal. For example,
30 a common use is to sample from a -log(Likelihood) distribution (NLL), for which
31 the appropriate configuration calls are SetType(MetropolisHastings::kLog);
32 SetSign(MetropolisHastings::kNegative);
33 If you're using a traditional likelihood function:
34 SetType(MetropolisHastings::kRegular); SetSign(MetropolisHastings::kPositive);
35 You must set these type and sign flags or MetropolisHastings will not construct
36 a MarkovChain.
37 
38 Also note that in ConstructChain(), the values of the variables are randomized
39 uniformly over their intervals before construction of the MarkovChain begins.
40 
41 */
42 
44 
45 #include "RooStats/MarkovChain.h"
46 #include "RooStats/MCMCInterval.h"
47 #include "RooStats/RooStatsUtils.h"
49 
50 #include "Rtypes.h"
51 #include "RooRealVar.h"
52 #include "RooNLLVar.h"
53 #include "RooGlobalFunc.h"
54 #include "RooDataSet.h"
55 #include "RooArgSet.h"
56 #include "RooArgList.h"
57 #include "RooMsgService.h"
58 #include "RooRandom.h"
59 #include "TMath.h"
60 #include "TFile.h"
61 
63 
64 using namespace RooFit;
65 using namespace RooStats;
66 using namespace std;
67 
68 ////////////////////////////////////////////////////////////////////////////////
69 
70 MetropolisHastings::MetropolisHastings()
71 {
72  // default constructor
73  fFunction = NULL;
74  fPropFunc = NULL;
75  fNumIters = 0;
76  fNumBurnInSteps = 0;
77  fSign = kSignUnset;
78  fType = kTypeUnset;
79 }
80 
81 ////////////////////////////////////////////////////////////////////////////////
82 
83 MetropolisHastings::MetropolisHastings(RooAbsReal& function, const RooArgSet& paramsOfInterest,
84  ProposalFunction& proposalFunction, Int_t numIters)
85 {
86  fFunction = &function;
87  SetParameters(paramsOfInterest);
88  SetProposalFunction(proposalFunction);
89  fNumIters = numIters;
90  fNumBurnInSteps = 0;
91  fSign = kSignUnset;
92  fType = kTypeUnset;
93 }
94 
95 ////////////////////////////////////////////////////////////////////////////////
96 
97 MarkovChain* MetropolisHastings::ConstructChain()
98 {
99  if (fParameters.getSize() == 0 || !fPropFunc || !fFunction) {
100  coutE(Eval) << "Critical members unintialized: parameters, proposal " <<
101  " function, or (log) likelihood function" << endl;
102  return NULL;
103  }
104  if (fSign == kSignUnset || fType == kTypeUnset) {
105  coutE(Eval) << "Please set type and sign of your function using "
106  << "MetropolisHastings::SetType() and MetropolisHastings::SetSign()" <<
107  endl;
108  return NULL;
109  }
110 
111  if (fChainParams.getSize() == 0) fChainParams.add(fParameters);
112 
113  RooArgSet x;
114  RooArgSet xPrime;
115  x.addClone(fParameters);
117  xPrime.addClone(fParameters);
118  RandomizeCollection(xPrime);
119 
120  MarkovChain* chain = new MarkovChain();
121  // only the POI will be added to the chain
122  chain->SetParameters(fChainParams);
123 
124  Int_t weight = 0;
125  Double_t xL = 0.0, xPrimeL = 0.0, a = 0.0;
126 
127  // ibucur: i think the user should have the possibility to display all the message
128  // levels should they want to; maybe a setPrintLevel would be appropriate
129  // (maybe for the other classes that use this approach as well)?
132 
133  // We will need to check if log-likelihood evaluation left an error status.
134  // Now using faster eval error logging with CountErrors.
135  if (fType == kLog) {
137  //N.B: need to clear the count in case of previous errors !
138  // the clear needs also to be done after calling setEvalErrorLoggingMode
140  }
141 
142  bool hadEvalError = true;
143 
144  Int_t i = 0;
145  // get a good starting point for x
146  // for fType == kLog, this means that fFunction->getVal() did not cause
147  // an eval error
148  // for fType == kRegular this means fFunction->getVal() != 0
149  //
150  // kbelasco: i < 1000 is sort of arbitrary, but way higher than the number of
151  // steps we should have to take for any reasonable (log) likelihood function
152  while (i < 1000 && hadEvalError) {
154  RooStats::SetParameters(&x, &fParameters);
155  xL = fFunction->getVal();
156 
157  if (fType == kLog) {
158  if (RooAbsReal::numEvalErrors() > 0) {
160  hadEvalError = true;
161  } else
162  hadEvalError = false;
163  } else if (fType == kRegular) {
164  if (xL == 0.0)
165  hadEvalError = true;
166  else
167  hadEvalError = false;
168  } else
169  // for now the only 2 types are kLog and kRegular (won't get here)
170  hadEvalError = false;
171  ++i;
172  }
173 
174  if(hadEvalError) {
175  coutE(Eval) << "Problem finding a good starting point in " <<
176  "MetropolisHastings::ConstructChain() " << endl;
177  }
178 
179 
180  ooccoutP((TObject *)0, Generation) << "Metropolis-Hastings progress: ";
181 
182  // do main loop
183  for (i = 0; i < fNumIters; i++) {
184  // reset error handling flag
185  hadEvalError = false;
186 
187  // print a dot every 1% of the chain construction
188  if (i % (fNumIters / 100) == 0) ooccoutP((TObject*)0, Generation) << ".";
189 
190  fPropFunc->Propose(xPrime, x);
191 
192  RooStats::SetParameters(&xPrime, &fParameters);
193  xPrimeL = fFunction->getVal();
194 
195  // check if log-likelihood for xprime had an error status
196  if (fFunction->numEvalErrors() > 0 && fType == kLog) {
197  xPrimeL = RooNumber::infinity();
198  fFunction->clearEvalErrorLog();
199  hadEvalError = true;
200  }
201 
202  // why evaluate the last point again, can't we cache it?
203  // kbelasco: commenting out lines below to add/test caching support
204  //RooStats::SetParameters(&x, &fParameters);
205  //xL = fFunction->getVal();
206 
207  if (fType == kLog) {
208  if (fSign == kPositive)
209  a = xL - xPrimeL;
210  else
211  a = xPrimeL - xL;
212  }
213  else
214  a = xPrimeL / xL;
215  //a = xL / xPrimeL;
216 
217  if (!hadEvalError && !fPropFunc->IsSymmetric(xPrime, x)) {
218  Double_t xPrimePD = fPropFunc->GetProposalDensity(xPrime, x);
219  Double_t xPD = fPropFunc->GetProposalDensity(x, xPrime);
220  if (fType == kRegular)
221  a *= xPD / xPrimePD;
222  else
223  a += TMath::Log(xPrimePD) - TMath::Log(xPD);
224  }
225 
226  if (!hadEvalError && ShouldTakeStep(a)) {
227  // go to the proposed point xPrime
228 
229  // add the current point with the current weight
230  if (weight != 0.0)
231  chain->Add(x, CalcNLL(xL), (Double_t)weight);
232 
233  // reset the weight and go to xPrime
234  weight = 1;
235  RooStats::SetParameters(&xPrime, &x);
236  xL = xPrimeL;
237  } else {
238  // stay at the current point
239  weight++;
240  }
241  }
242 
243  // make sure to add the last point
244  if (weight != 0.0)
245  chain->Add(x, CalcNLL(xL), (Double_t)weight);
246  ooccoutP((TObject *)0, Generation) << endl;
247 
249 
250  Int_t numAccepted = chain->Size();
251  coutI(Eval) << "Proposal acceptance rate: " <<
252  numAccepted/(Float_t)fNumIters * 100 << "%" << endl;
253  coutI(Eval) << "Number of steps in chain: " << numAccepted << endl;
254 
255  //TFile chainDataFile("chainData.root", "recreate");
256  //chain->GetDataSet()->Write();
257  //chainDataFile.Close();
258 
259  return chain;
260 }
261 
262 ////////////////////////////////////////////////////////////////////////////////
263 
264 Bool_t MetropolisHastings::ShouldTakeStep(Double_t a)
265 {
266  if ((fType == kLog && a <= 0.0) || (fType == kRegular && a >= 1.0)) {
267  // The proposed point has a higher likelihood than the
268  // current point, so we should go there
269  return kTRUE;
270  }
271  else {
272  // generate numbers on a log distribution to decide
273  // whether to go to xPrime or stay at x
274  //Double_t rand = fGen.Uniform(1.0);
275  Double_t rand = RooRandom::uniform();
276  if (fType == kLog) {
277  rand = TMath::Log(rand);
278  // kbelasco: should this be changed to just (-rand > a) for logical
279  // consistency with below test when fType == kRegular?
280  if (-1.0 * rand >= a)
281  // we chose to go to the new proposed point
282  // even though it has a lower likelihood than the current one
283  return kTRUE;
284  } else {
285  // fType must be kRegular
286  // kbelasco: ensure that we never visit a point where PDF == 0
287  //if (rand <= a)
288  if (rand < a)
289  // we chose to go to the new proposed point
290  // even though it has a lower likelihood than the current one
291  return kTRUE;
292  }
293  return kFALSE;
294  }
295 }
296 
297 ////////////////////////////////////////////////////////////////////////////////
298 
299 Double_t MetropolisHastings::CalcNLL(Double_t xL)
300 {
301  if (fType == kLog) {
302  if (fSign == kNegative)
303  return xL;
304  else
305  return -xL;
306  } else {
307  if (fSign == kPositive)
308  return -1.0 * TMath::Log(xL);
309  else
310  return -1.0 * TMath::Log(-xL);
311  }
312 }
RooAbsReal::CountErrors
@ CountErrors
Definition: RooAbsReal.h:298
kTRUE
const Bool_t kTRUE
Definition: RtypesCore.h:91
RooAbsReal::setEvalErrorLoggingMode
static void setEvalErrorLoggingMode(ErrorLoggingMode m)
Set evaluation error logging mode.
Definition: RooAbsReal.cxx:4838
RooMsgService.h
RooArgSet.h
RooStats::MarkovChain::Size
virtual Int_t Size() const
get the number of steps in the chain
Definition: MarkovChain.h:61
ClassImp
#define ClassImp(name)
Definition: Rtypes.h:364
TMath::Log
Double_t Log(Double_t x)
Definition: TMath.h:760
coutE
#define coutE(a)
Definition: RooMsgService.h:33
ooccoutP
#define ooccoutP(o, a)
Definition: RooMsgService.h:54
RooStats::MarkovChain
Definition: MarkovChain.h:36
Float_t
float Float_t
Definition: RtypesCore.h:57
x
Double_t x[n]
Definition: legend1.C:17
SetParameters
void SetParameters(TFitEditor::FuncParams_t &pars, TF1 *func)
Restore the parameters from pars into the function.
Definition: TFitEditor.cxx:277
coutI
#define coutI(a)
Definition: RooMsgService.h:30
RooAbsReal
Definition: RooAbsReal.h:61
RooFit::MsgLevel
MsgLevel
Verbosity level for RooMsgService::StreamConfig in RooMsgService.
Definition: RooGlobalFunc.h:65
RooDataSet.h
RooAbsReal::clearEvalErrorLog
static void clearEvalErrorLog()
Clear the stack of evaluation error messages.
Definition: RooAbsReal.cxx:3805
TFile.h
RooNLLVar.h
bool
RooStats::MetropolisHastings
Definition: MetropolisHastings.h:30
MarkovChain.h
RooFit
Definition: RooCFunction1Binding.h:29
a
auto * a
Definition: textangle.C:12
RooFit::Generation
@ Generation
Definition: RooGlobalFunc.h:67
kFALSE
const Bool_t kFALSE
Definition: RtypesCore.h:92
RooRandom.h
MCMCInterval.h
RooMsgService::setGlobalKillBelow
void setGlobalKillBelow(RooFit::MsgLevel level)
Definition: RooMsgService.h:160
RooStats::MarkovChain::Add
virtual void Add(RooArgSet &entry, Double_t nllValue, Double_t weight=1.0)
safely add an entry to the chain
Definition: MarkovChain.cxx:102
RooRealVar.h
RooFit::PROGRESS
@ PROGRESS
Definition: RooGlobalFunc.h:65
RooGlobalFunc.h
RooStatsUtils.h
MetropolisHastings.h
RooStats::MarkovChain::SetParameters
virtual void SetParameters(RooArgSet &parameters)
set which of your parameters this chain should store
Definition: MarkovChain.cxx:77
RooStats::RandomizeCollection
void RandomizeCollection(RooAbsCollection &set, Bool_t randomizeConstants=kTRUE)
Definition: RooStatsUtils.h:106
RooNumber::infinity
static Double_t infinity()
Return internal infinity representation.
Definition: RooNumber.cxx:49
Double_t
double Double_t
Definition: RtypesCore.h:59
RooStats
Definition: Asimov.h:19
ProposalFunction.h
TObject
Definition: TObject.h:37
RooStats::SetParameters
void SetParameters(const RooArgSet *desiredVals, RooArgSet *paramsToChange)
Definition: RooStatsUtils.h:65
RooStats::ProposalFunction
Definition: ProposalFunction.h:48
RooArgSet::addClone
virtual void addClone(const RooAbsCollection &col, Bool_t silent=kFALSE)
Add a collection of arguments to this collection by calling addOwned() for each element in the source...
Definition: RooArgSet.h:96
RooAbsReal::numEvalErrors
static Int_t numEvalErrors()
Return the number of logged evaluation errors since the last clearing.
Definition: RooAbsReal.cxx:3867
RooFit::Eval
@ Eval
Definition: RooGlobalFunc.h:68
RooRandom::uniform
static Double_t uniform(TRandom *generator=randomGenerator())
Return a number uniformly distributed from (0,1)
Definition: RooRandom.cxx:83
RooArgList.h
RooMsgService::instance
static RooMsgService & instance()
Return reference to singleton instance.
Definition: RooMsgService.cxx:363
Rtypes.h
RooMsgService::globalKillBelow
RooFit::MsgLevel globalKillBelow() const
Definition: RooMsgService.h:161
TMath.h
RooArgSet
Definition: RooArgSet.h:28
int