
import ROOT

# Switches for the slow calculators; flip to True to run them.
doBayesian = False  # BayesianCalculator (slow for this problem)
doFeldmanCousins = False  # FeldmanCousins (slowest of the calculators here)
doMCMC = False  # MCMCCalculator

# Time the whole example, which is a challenging one.
t = ROOT.TStopwatch()
t.Start()

# Fix the seed of RooFit's random generator so the toy data are reproducible.
rooRandomGenerator = ROOT.RooRandom.randomGenerator()
rooRandomGenerator.SetSeed(4357)

# make model
# The full model is the product of four Poisson terms and one Gaussian term:
#   on     : non     ~ Poisson(s + b)
#   off    : noff    ~ Poisson(b * tau * rho)
#   onbar  : nonbar  ~ Poisson(bbar)
#   offbar : noffbar ~ Poisson(bbar * tau)
#   mcCons : Gaussian constraint of rho, observed as rhonom, width sigma
wspace = ROOT.RooWorkspace("wspace")
modelExpressions = (
    "Poisson::on(non[0,1000], sum::splusb(s[40,0,100],b[100,0,300]))",
    "Poisson::off(noff[0,5000], prod::taub(b,tau[5,3,7],rho[1,0,2]))",
    "Poisson::onbar(nonbar[0,10000], bbar[1000,500,2000])",
    "Poisson::offbar(noffbar[0,1000000], prod::lambdaoffbar(bbar, tau))",
    "Gaussian::mcCons(rhonom[1.,0,2], rho, sigma[.2])",
    "PROD::model(on,off,onbar,offbar,mcCons)",
)
for expression in modelExpressions:
    wspace.factory(expression)
wspace.defineSet("obs", "non,noff,nonbar,noffbar,rhonom")

# Flat (uniform) priors on the parameter of interest and on the nuisance
# parameters, combined into a single prior pdf.
priorExpressions = (
    "Uniform::prior_poi({s})",
    "Uniform::prior_nuis({b,bbar,tau, rho})",
    "PROD::prior(prior_poi,prior_nuis)",
)
for expression in priorExpressions:
    wspace.factory(expression)

# ----------------------------------
# Control some interesting variations.
# Define the parameters of interest (a single one gives 1-d plots) and the
# nuisance parameters.
wspace.defineSet("poi", "s")
wspace.defineSet("nuis", "b,tau,rho,bbar")

# For 2-d plots to inspect correlations, use instead:
# wspace.defineSet("poi","s,rho")

# To test simpler cases where some parameters are known, fix them:
# wspace["tau"].setConstant()
# wspace["rho"].setConstant()
# wspace["b"].setConstant()
# wspace["bbar"].setConstant()

# To inspect the workspace:
# wspace.Print()

# ----------------------------------------------------------
# Generate toy data
# Draw a single entry of the observables from the model at the current
# parameter values and import it into the workspace.
# (Add Verbose() to see how it is being generated.)
observables = wspace.set("obs")
data = wspace["model"].generate(observables, 1)
# data.Print("v")
wspace.Import(data)

# ----------------------------------
# Now the statistical tests
# The ModelConfig collects the pdf, the prior, the parameters of interest and
# the nuisance parameters that the calculators below are constructed from.
# The workspace is attached first, before the remaining setters run.
modelConfig = ROOT.RooStats.ModelConfig("FourBins")
modelConfig.SetWorkspace(wspace)
configSetters = (
    (modelConfig.SetPdf, wspace["model"]),
    (modelConfig.SetPriorPdf, wspace["prior"]),
    (modelConfig.SetParametersOfInterest, wspace.set("poi")),
    (modelConfig.SetNuisanceParameters, wspace.set("nuis")),
)
for setter, value in configSetters:
    setter(value)
wspace.Import(modelConfig)
# wspace.writeToFile("FourBin.root")

# -------------------------------------------------
# To see the covariance matrix, uncomment:
# wspace["model"].fitTo(data)

# Profile likelihood interval at 95% CL.
plc = ROOT.RooStats.ProfileLikelihoodCalculator(data, modelConfig)
plc.SetConfidenceLevel(0.95)
plInt = plc.GetInterval()

# Call LowerLimit() once with every RooFit message below FATAL suppressed, to
# get its ugly printout out of the way, then restore the previous kill level.
# This is a workaround that should be fixed properly.
msgService = ROOT.RooMsgService.instance()
previousKillLevel = msgService.globalKillBelow()
msgService.setGlobalKillBelow(ROOT.RooFit.FATAL)
plInt.LowerLimit(wspace["s"])
msgService.setGlobalKillBelow(previousKillLevel)

# use FeldmanCousins (very slow, so the interval is only computed when
# doFeldmanCousins is set)
fc = ROOT.RooStats.FeldmanCousins(data, modelConfig)
fc.SetConfidenceLevel(0.95)
# Number counting: the dataset always has 1 entry holding the observed counts,
# so the number of entries is not fluctuated in the toys.
fc.FluctuateNumDataEntries(False)
fc.UseAdaptiveSampling(True)
fc.SetNBins(40)
fcInt = ROOT.RooStats.PointSetInterval()  # default-constructed placeholder
if doFeldmanCousins:
    fcInt = fc.GetInterval()

# use BayesianCalculator (only 1-d parameter of interest, slow for this problem)
bc = ROOT.RooStats.BayesianCalculator(data, modelConfig)
bc.SetConfidenceLevel(0.95)
# Default-constructed placeholder; it is only queried later when the Bayesian
# interval was actually computed.
bInt = ROOT.RooStats.SimpleInterval()
if doBayesian:
    if len(wspace.set("poi")) == 1:
        bInt = bc.GetInterval()
    else:
        # Warn only when the Bayesian calculation was requested but cannot run.
        # Previously this message was printed whenever doBayesian was False.
        print("Bayesian Calc. only supports one parameter of interest")

# use MCMCCalculator (takes about 1 min)
# For an efficient proposal function, derive it from the covariance matrix
# of a fit of the model to the data.
fitResult = wspace["model"].fitTo(data, Save=True)
proposalHelper = ROOT.RooStats.ProposalHelper()
proposalHelper.SetVariables(fitResult.floatParsFinal())
proposalHelper.SetCovMatrix(fitResult.covarianceMatrix())
proposalHelper.SetUpdateProposalParameters(True)  # auto-create mean vars and add mappings
proposalHelper.SetCacheSize(100)
proposal = proposalHelper.GetProposalFunction()

mc = ROOT.RooStats.MCMCCalculator(data, modelConfig)
mc.SetConfidenceLevel(0.95)
mc.SetProposalFunction(proposal)
mc.SetNumBurnInSteps(500)  # the first N steps are discarded as burn-in
mc.SetNumIters(50000)
mc.SetLeftSideTailFraction(0.5)  # equal tails on both sides: a central interval
mcInt = ROOT.RooStats.MCMCInterval()  # default-constructed placeholder
if doMCMC:
    mcInt = mc.GetInterval()

# ----------------------------------
# Make some plots
# Reuse an existing canvas named "c1" if there is one, otherwise create it.
c1 = ROOT.gROOT.Get("c1")
if not c1:
    c1 = ROOT.TCanvas("c1")

# The Bayesian posterior is only computed (and plotted) for a single parameter
# of interest, so size the canvas from what is actually drawn rather than from
# the raw doBayesian flag. Otherwise an empty pad would be left on the canvas.
showBayesian = doBayesian and len(wspace.set("poi")) == 1
nPads = 1 + int(showBayesian) + int(doMCMC)
if nPads > 1:
    c1.Divide(nPads)
    c1.cd(1)

# The profile-likelihood plot is always drawn, in the first pad.
lrplot = ROOT.RooStats.LikelihoodIntervalPlot(plInt)
lrplot.Draw()

if showBayesian:
    c1.cd(2)
    # The full posterior plot takes a long time and prints lots of errors;
    # a scan of the posterior works better.
    bc.SetScanOfPosterior(20)
    bplot = bc.GetPosteriorPlot()
    bplot.Draw()

if doMCMC:
    # The MCMC plot always goes in the last pad: 3 with the Bayesian plot, else 2.
    c1.cd(nPads)
    mcPlot = ROOT.RooStats.MCMCIntervalPlot(mcInt)
    mcPlot.Draw()

# ----------------------------------
# query intervals
# All intervals are reported for the signal parameter s.
sVar = wspace["s"]

plLower = plInt.LowerLimit(sVar)
plUpper = plInt.UpperLimit(sVar)
print(f"Profile Likelihood interval on s = [{plLower}, {plUpper}]")
# Profile Likelihood interval on s = [12.1902, 88.6871]

if doBayesian and len(wspace.set("poi")) == 1:
    print(f"Bayesian interval on s = [{bInt.LowerLimit()}, {bInt.UpperLimit()}]")

if doFeldmanCousins:
    fcLower = fcInt.LowerLimit(sVar)
    fcUpper = fcInt.UpperLimit(sVar)
    print(f"Feldman Cousins interval on s = [{fcLower}, {fcUpper}]")
    # Feldman Cousins interval on s = [18.75 +/- 2.45, 83.75 +/- 2.45]

if doMCMC:
    mcLower = mcInt.LowerLimit(sVar)
    mcUpper = mcInt.UpperLimit(sVar)
    print(f"MCMC interval on s = [{mcLower}, {mcUpper}]")
    # MCMC interval on s = [15.7628, 84.7266]

# Stop the stopwatch started at the top of the script and print its timing,
# then save the canvas to a PNG file.
t.Stop()
t.Print()

c1.SaveAs("FourBinInstructional.png")

# TODO: The calculators have to be destructed first. Otherwise, we can get
# segmentation faults depending on the destruction order, which is random in
# Python. Probably the issue is that some object has a non-owning pointer to
# another object, which it uses in its destructor. This should be fixed either
# in the design of RooStats in C++, or with pythonizations.
# `del` removes the names left to right: plc, then bc, then mc.
del plc, bc, mc
