ROOT logo
// @(#)root/roostats:$Id: MCMCCalculator.cxx 28978 2009-06-17 14:33:31Z kbelasco $
// Authors: Kevin Belasco        17/06/2009
// Authors: Kyle Cranmer         17/06/2009
/*************************************************************************
 * Copyright (C) 1995-2008, Rene Brun and Fons Rademakers.               *
 * All rights reserved.                                                  *
 *                                                                       *
 * For the licensing terms see $ROOTSYS/LICENSE.                         *
 * For the list of contributors see $ROOTSYS/README/CREDITS.             *
 *************************************************************************/

//_________________________________________________
/*
BEGIN_HTML
<p>
MCMCCalculator is a concrete implementation of IntervalCalculator.
It creates a Markov Chain of data points using Monte Carlo, implementing the
Metropolis-Hastings algorithm.  From this Markov Chain, this class can generate an
MCMCInterval as per user specification.
</p>

<p>
Note: Currently the Markov Chain is created within this class, but this
feature will be factored out in future implementations so that Markov
Chains can be generated for other purposes.
</p>

<p>
In the main algorithm, new points in the space of parameters
are proposed and then visited based on their relative likelihoods.
This class can accept any implementation of the ProposalFunction interface,
including non-symmetric proposal functions, and still maintain detailed balance.
</p>

<p>
The interface allows one to pass the model, data, and parameters via a workspace and
then specify them with names.
</p>
<p>
After configuring the calculator, one only needs to call GetInterval(), which will
return a ConfInterval (MCMCInterval in this case) pointer owned by the caller.
</p>
END_HTML
*/
//_________________________________________________

#ifndef RooStats_RooStatsUtils
#include "RooStats/RooStatsUtils.h"
#endif

#ifndef ROOT_Rtypes
#include "Rtypes.h"
#endif

#include "RooRealVar.h"
#include "RooNLLVar.h"
#include "RooGlobalFunc.h"
#include "RooDataSet.h"
#include "RooArgSet.h"
#include "RooArgList.h"
#include "TRandom.h"
#include "TH1.h"
#include "TMath.h"
#include "TFile.h"
#include "RooStats/MCMCCalculator.h"
#include "RooStats/MCMCInterval.h"

ClassImp(RooStats::MCMCCalculator);

using namespace RooFit;
using namespace RooStats;

MCMCCalculator::MCMCCalculator()
{
   // Default constructor: builds an unconfigured calculator.  No workspace,
   // model, data, parameters of interest or proposal function is attached;
   // they must be configured through the setters before GetInterval() can
   // produce an interval.
   fWS = NULL;
   fPOI = NULL;
   fNuisParams = NULL;
   fOwnsWorkspace = false;
   fPropFunc = NULL;
   fPdfName = NULL;
   fDataName = NULL;
   fNumIters = 0;
   fNumBurnInSteps = 0;
   fNumBins = 0;
   fAxes = NULL;
   // fSize used to be left uninitialized although GetInterval() reads it to
   // set the confidence level (1 - fSize).
   // NOTE(review): 0.05 is assumed to match the default test size of the
   // other constructors declared in the header -- confirm.
   fSize = 0.05;
}

MCMCCalculator::MCMCCalculator(RooWorkspace& ws, RooAbsData& data,
                RooAbsPdf& pdf, RooArgSet& paramsOfInterest,
                ProposalFunction& proposalFunction, Int_t numIters,
                RooArgList* axes, Double_t size)
{
   // Constructor using a workspace supplied by the caller (not owned by the
   // calculator: fOwnsWorkspace is false).  Registers the data, pdf,
   // parameters of interest, test size and proposal function through the
   // setters, and runs numIters Metropolis-Hastings iterations when
   // GetInterval() is called.  axes (may be NULL) selects the variables used
   // for each axis of the resulting interval.

   // Give every pointer member a defined value before any setter runs.  The
   // setters are declared in the header and may inspect the current state
   // (e.g. whether fWS is already set); previously these members were read
   // uninitialized.  fNuisParams is not set anywhere else in this file.
   fWS = NULL;
   fPOI = NULL;
   fNuisParams = NULL;
   fPropFunc = NULL;
   fPdfName = NULL;
   fDataName = NULL;
   fOwnsWorkspace = false;

   SetWorkspace(ws);
   SetData(data);
   SetPdf(pdf);
   SetParameters(paramsOfInterest);
   SetTestSize(size);
   SetProposalFunction(proposalFunction);
   fNumIters = numIters;
   fNumBurnInSteps = 0;
   fNumBins = 0;
   fAxes = axes;
}

MCMCCalculator::MCMCCalculator(RooAbsData& data, RooAbsPdf& pdf,
                RooArgSet& paramsOfInterest, ProposalFunction& proposalFunction,
                Int_t numIters, RooArgList* axes, Double_t size)
{
   // Alternate constructor: creates a private RooWorkspace owned by the
   // calculator (fOwnsWorkspace is true).  The data, pdf, parameters of
   // interest, test size and proposal function are registered through the
   // setters (presumably importing data and pdf into that workspace -- the
   // setters are defined in the header).  numIters Metropolis-Hastings
   // iterations are run when GetInterval() is called; axes (may be NULL)
   // selects the variables used for each axis of the resulting interval.

   // Give the remaining pointer members a defined value before any setter
   // runs; previously they were left uninitialized, and fNuisParams is not
   // set anywhere else in this file.
   fPOI = NULL;
   fNuisParams = NULL;
   fPropFunc = NULL;
   fPdfName = NULL;
   fDataName = NULL;

   fWS = new RooWorkspace();
   fOwnsWorkspace = true;
   SetData(data);
   SetPdf(pdf);
   SetParameters(paramsOfInterest);
   SetTestSize(size);
   SetProposalFunction(proposalFunction);
   fNumIters = numIters;
   fNumBurnInSteps = 0;
   fNumBins = 0;
   fAxes = axes;
}

MCMCInterval* MCMCCalculator::GetInterval() const
{
   // Main interface to get a RooStats::ConfInterval.
   //
   // Builds a Markov chain of fNumIters Metropolis-Hastings steps over the
   // non-constant parameters of the model pdf.  At each step fPropFunc
   // proposes a new point, which is accepted or rejected by comparing the
   // negative log-likelihood of the data at the proposed and current points,
   // with a Hastings correction when the proposal is not symmetric.  Each
   // distinct point of the chain is stored with a weight equal to the number
   // of steps spent there.  The chain is handed to a new MCMCInterval over
   // the parameters of interest, configured with fAxes, fNumBins,
   // fNumBurnInSteps (when set) and a confidence level of 1 - fSize.
   //
   // Returns 0 when the workspace, proposal function, pdf, data or
   // parameters of interest are not configured; otherwise the caller owns
   // the returned interval.

   // The default constructor leaves fWS and fPropFunc NULL; guard against
   // dereferencing them.
   if (!fWS || !fPropFunc) return 0;
   RooAbsPdf* pdf = fWS->pdf(fPdfName);
   RooAbsData* data = fWS->data(fDataName);
   if (!data || !pdf || !fPOI) return 0;

   // x is the current point of the chain plus the weight variable w (x takes
   // ownership of w via addOwned); xPrime receives each proposed point.
   RooArgSet x;
   RooArgSet xPrime;
   RooRealVar* w = new RooRealVar("w", "weight", 0);
   RooArgSet* parameters = pdf->getParameters(data);
   RemoveConstantParameters(parameters);
   x.addClone(*parameters);
   x.addOwned(*w);
   xPrime.addClone(*parameters);
   // x and xPrime hold clones, so the set returned by getParameters() is no
   // longer needed (it used to be leaked).
   delete parameters;

   RooDataSet* points = new RooDataSet("points", "Markov Chain", x, WeightVar(*w));

   // Default-constructed generator, i.e. TRandom's default seed.
   TRandom gen;
   RooArgSet* constrainedParams = pdf->getParameters(*data);
   RooAbsReal* nll = pdf->createNLL(*data, Constrain(*constrainedParams));
   delete constrainedParams;

   RooArgSet* nllParams = nll->getParameters(*data);
   // Number of iterations spent at the current point x so far.
   Int_t weight = 0;

   for (int i = 0; i < fNumIters; i++) {
      // progress indicator: one dot every 100 iterations
      if (i % 100 == 0) {
         fprintf(stdout, ".");
         fflush(NULL);
      }

      fPropFunc->Propose(xPrime, x);

      RooStats::SetParameters(&xPrime, nllParams);
      Double_t xPrimeNLL = nll->getVal();
      RooStats::SetParameters(&x, nllParams);
      Double_t xNLL = nll->getVal();
      Double_t diff = xPrimeNLL - xNLL;

      // Hastings correction for non-symmetric proposals: add
      // log q(xPrime|x) - log q(x|xPrime).
      if (!fPropFunc->IsSymmetric(xPrime, x))
         diff += TMath::Log(fPropFunc->GetProposalDensity(xPrime, x)) -
                 TMath::Log(fPropFunc->GetProposalDensity(x, xPrime));

      // A point with higher (corrected) likelihood is always accepted.
      // Otherwise accept when -log(u) >= diff for u uniform in (0,1].  The
      // random number is drawn only in that case, so the generator sequence
      // is unchanged.
      bool accept = (diff < 0.0);
      if (!accept)
         accept = (-TMath::Log(gen.Uniform(1.0)) >= diff);

      if (accept) {
         // record the current point with its accumulated weight, then move
         points->addFast(x, (Double_t)weight);
         weight = 1;
         RooStats::SetParameters(&xPrime, &x);
      } else {
         // stay at the current point x
         weight++;
      }
   }
   delete nllParams;
   printf("\n");
   // make sure to add the last point
   points->addFast(x, (Double_t)weight);
   // createNLL() returns a new object owned by the caller; it used to be leaked.
   delete nll;

   MCMCInterval* interval = new MCMCInterval("mcmcinterval", "MCMCInterval",
                                             *fPOI, *points);
   if (fAxes != NULL)
      interval->SetAxes(*fAxes);
   if (fNumBins > 0)
      interval->SetNumBins(fNumBins);
   if (fNumBurnInSteps > 0)
      interval->SetNumBurnInSteps(fNumBurnInSteps);
   interval->SetConfidenceLevel(1.0 - fSize);
   return interval;
}
 MCMCCalculator.cxx:1
 MCMCCalculator.cxx:2
 MCMCCalculator.cxx:3
 MCMCCalculator.cxx:4
 MCMCCalculator.cxx:5
 MCMCCalculator.cxx:6
 MCMCCalculator.cxx:7
 MCMCCalculator.cxx:8
 MCMCCalculator.cxx:9
 MCMCCalculator.cxx:10
 MCMCCalculator.cxx:11
 MCMCCalculator.cxx:12
 MCMCCalculator.cxx:13
 MCMCCalculator.cxx:14
 MCMCCalculator.cxx:15
 MCMCCalculator.cxx:16
 MCMCCalculator.cxx:17
 MCMCCalculator.cxx:18
 MCMCCalculator.cxx:19
 MCMCCalculator.cxx:20
 MCMCCalculator.cxx:21
 MCMCCalculator.cxx:22
 MCMCCalculator.cxx:23
 MCMCCalculator.cxx:24
 MCMCCalculator.cxx:25
 MCMCCalculator.cxx:26
 MCMCCalculator.cxx:27
 MCMCCalculator.cxx:28
 MCMCCalculator.cxx:29
 MCMCCalculator.cxx:30
 MCMCCalculator.cxx:31
 MCMCCalculator.cxx:32
 MCMCCalculator.cxx:33
 MCMCCalculator.cxx:34
 MCMCCalculator.cxx:35
 MCMCCalculator.cxx:36
 MCMCCalculator.cxx:37
 MCMCCalculator.cxx:38
 MCMCCalculator.cxx:39
 MCMCCalculator.cxx:40
 MCMCCalculator.cxx:41
 MCMCCalculator.cxx:42
 MCMCCalculator.cxx:43
 MCMCCalculator.cxx:44
 MCMCCalculator.cxx:45
 MCMCCalculator.cxx:46
 MCMCCalculator.cxx:47
 MCMCCalculator.cxx:48
 MCMCCalculator.cxx:49
 MCMCCalculator.cxx:50
 MCMCCalculator.cxx:51
 MCMCCalculator.cxx:52
 MCMCCalculator.cxx:53
 MCMCCalculator.cxx:54
 MCMCCalculator.cxx:55
 MCMCCalculator.cxx:56
 MCMCCalculator.cxx:57
 MCMCCalculator.cxx:58
 MCMCCalculator.cxx:59
 MCMCCalculator.cxx:60
 MCMCCalculator.cxx:61
 MCMCCalculator.cxx:62
 MCMCCalculator.cxx:63
 MCMCCalculator.cxx:64
 MCMCCalculator.cxx:65
 MCMCCalculator.cxx:66
 MCMCCalculator.cxx:67
 MCMCCalculator.cxx:68
 MCMCCalculator.cxx:69
 MCMCCalculator.cxx:70
 MCMCCalculator.cxx:71
 MCMCCalculator.cxx:72
 MCMCCalculator.cxx:73
 MCMCCalculator.cxx:74
 MCMCCalculator.cxx:75
 MCMCCalculator.cxx:76
 MCMCCalculator.cxx:77
 MCMCCalculator.cxx:78
 MCMCCalculator.cxx:79
 MCMCCalculator.cxx:80
 MCMCCalculator.cxx:81
 MCMCCalculator.cxx:82
 MCMCCalculator.cxx:83
 MCMCCalculator.cxx:84
 MCMCCalculator.cxx:85
 MCMCCalculator.cxx:86
 MCMCCalculator.cxx:87
 MCMCCalculator.cxx:88
 MCMCCalculator.cxx:89
 MCMCCalculator.cxx:90
 MCMCCalculator.cxx:91
 MCMCCalculator.cxx:92
 MCMCCalculator.cxx:93
 MCMCCalculator.cxx:94
 MCMCCalculator.cxx:95
 MCMCCalculator.cxx:96
 MCMCCalculator.cxx:97
 MCMCCalculator.cxx:98
 MCMCCalculator.cxx:99
 MCMCCalculator.cxx:100
 MCMCCalculator.cxx:101
 MCMCCalculator.cxx:102
 MCMCCalculator.cxx:103
 MCMCCalculator.cxx:104
 MCMCCalculator.cxx:105
 MCMCCalculator.cxx:106
 MCMCCalculator.cxx:107
 MCMCCalculator.cxx:108
 MCMCCalculator.cxx:109
 MCMCCalculator.cxx:110
 MCMCCalculator.cxx:111
 MCMCCalculator.cxx:112
 MCMCCalculator.cxx:113
 MCMCCalculator.cxx:114
 MCMCCalculator.cxx:115
 MCMCCalculator.cxx:116
 MCMCCalculator.cxx:117
 MCMCCalculator.cxx:118
 MCMCCalculator.cxx:119
 MCMCCalculator.cxx:120
 MCMCCalculator.cxx:121
 MCMCCalculator.cxx:122
 MCMCCalculator.cxx:123
 MCMCCalculator.cxx:124
 MCMCCalculator.cxx:125
 MCMCCalculator.cxx:126
 MCMCCalculator.cxx:127
 MCMCCalculator.cxx:128
 MCMCCalculator.cxx:129
 MCMCCalculator.cxx:130
 MCMCCalculator.cxx:131
 MCMCCalculator.cxx:132
 MCMCCalculator.cxx:133
 MCMCCalculator.cxx:134
 MCMCCalculator.cxx:135
 MCMCCalculator.cxx:136
 MCMCCalculator.cxx:137
 MCMCCalculator.cxx:138
 MCMCCalculator.cxx:139
 MCMCCalculator.cxx:140
 MCMCCalculator.cxx:141
 MCMCCalculator.cxx:142
 MCMCCalculator.cxx:143
 MCMCCalculator.cxx:144
 MCMCCalculator.cxx:145
 MCMCCalculator.cxx:146
 MCMCCalculator.cxx:147
 MCMCCalculator.cxx:148
 MCMCCalculator.cxx:149
 MCMCCalculator.cxx:150
 MCMCCalculator.cxx:151
 MCMCCalculator.cxx:152
 MCMCCalculator.cxx:153
 MCMCCalculator.cxx:154
 MCMCCalculator.cxx:155
 MCMCCalculator.cxx:156
 MCMCCalculator.cxx:157
 MCMCCalculator.cxx:158
 MCMCCalculator.cxx:159
 MCMCCalculator.cxx:160
 MCMCCalculator.cxx:161
 MCMCCalculator.cxx:162
 MCMCCalculator.cxx:163
 MCMCCalculator.cxx:164
 MCMCCalculator.cxx:165
 MCMCCalculator.cxx:166
 MCMCCalculator.cxx:167
 MCMCCalculator.cxx:168
 MCMCCalculator.cxx:169
 MCMCCalculator.cxx:170
 MCMCCalculator.cxx:171
 MCMCCalculator.cxx:172
 MCMCCalculator.cxx:173
 MCMCCalculator.cxx:174
 MCMCCalculator.cxx:175
 MCMCCalculator.cxx:176
 MCMCCalculator.cxx:177
 MCMCCalculator.cxx:178
 MCMCCalculator.cxx:179
 MCMCCalculator.cxx:180
 MCMCCalculator.cxx:181
 MCMCCalculator.cxx:182
 MCMCCalculator.cxx:183
 MCMCCalculator.cxx:184
 MCMCCalculator.cxx:185
 MCMCCalculator.cxx:186
 MCMCCalculator.cxx:187
 MCMCCalculator.cxx:188
 MCMCCalculator.cxx:189
 MCMCCalculator.cxx:190
 MCMCCalculator.cxx:191
 MCMCCalculator.cxx:192
 MCMCCalculator.cxx:193
 MCMCCalculator.cxx:194
 MCMCCalculator.cxx:195
 MCMCCalculator.cxx:196
 MCMCCalculator.cxx:197
 MCMCCalculator.cxx:198
 MCMCCalculator.cxx:199
 MCMCCalculator.cxx:200
 MCMCCalculator.cxx:201
 MCMCCalculator.cxx:202
 MCMCCalculator.cxx:203
 MCMCCalculator.cxx:204
 MCMCCalculator.cxx:205
 MCMCCalculator.cxx:206
 MCMCCalculator.cxx:207
 MCMCCalculator.cxx:208
 MCMCCalculator.cxx:209
 MCMCCalculator.cxx:210
 MCMCCalculator.cxx:211
 MCMCCalculator.cxx:212
 MCMCCalculator.cxx:213
 MCMCCalculator.cxx:214
 MCMCCalculator.cxx:215
 MCMCCalculator.cxx:216
 MCMCCalculator.cxx:217
 MCMCCalculator.cxx:218
 MCMCCalculator.cxx:219