
import sys

import ROOT

# User configuration parameters
#   ntoys      : number of toys thrown for the null (background-only) hypothesis
#   nToysRatio : ratio Ntoys(null) / Ntoys(alt); the alternate hypothesis
#                is therefore sampled with ntoys // nToysRatio toys
ntoys, nToysRatio = 6000, 20

# -------------------------------------------------------
# A New Test Statistic Class for this example.
# BinCountTestStat is declared to ROOT's C++ interpreter through
# gInterpreter.Declare.  Its Evaluate() loops over every entry of the
# dataset and returns the sum of the values in the column named
# fColumnName ("tmp" when default-constructed); the POI argument is
# ignored.
# NOTE(review): nothing in this script instantiates this class -- the
# macro below uses the built-in RooStats NumEventsTestStat instead.  It is
# kept for illustration only; you can ignore it and focus on the macro below.

ROOT.gInterpreter.Declare(
    """
using namespace RooFit;
using namespace RooStats;

class BinCountTestStat : public TestStatistic {
public:
   BinCountTestStat(void) : fColumnName("tmp") {}
   BinCountTestStat(string columnName) : fColumnName(columnName) {}

   virtual Double_t Evaluate(RooAbsData &data, RooArgSet & /*nullPOI*/)
   {
      // This is the main method in the interface
      Double_t value = 0.0;
      for (int i = 0; i < data.numEntries(); i++) {
         value += data.get(i)->getRealValue(fColumnName.c_str());
      }
      return value;
   }
   virtual const TString GetVarName() const { return fColumnName; }

private:
   string fColumnName;

protected:
   ClassDef(BinCountTestStat, 1)
};
"""
)


# -----------------------------
# The Actual Tutorial Macro
# -----------------------------

# This tutorial has 6 parts
# Table of Contents
# Setup
#   1. Make the model for the 'prototype problem'
# Special cases
#   2. NOT RELEVANT HERE
#   3. Use RooStats analytic solution for this problem
# RooStats HybridCalculator -- can be generalized
#   4. RooStats ToyMC version of 2. & 3.
#   5. RooStats ToyMC with an equivalent test statistic
#   6. RooStats ToyMC with simultaneous control & main measurement

# Stopwatch used to report the time spent in each part.
t = ROOT.TStopwatch()
t.Start()
# 2x2 canvas; later parts draw their plots into individual pads.
c = ROOT.TCanvas()
c.Divide(2, 2)

# -----------------------------------------------------
# P A R T   1  :  D I R E C T   I N T E G R A T I O N
# ====================================================
# Model for the prototype on/off problem:
#   Pois(x | s+b) * Pois(y | tau b)
# For Z_Gamma, use a uniform prior on b.
w = ROOT.RooWorkspace("w")

# The number-counting form of the main measurement,
#   w.factory("Poisson::px(x[150,0,500],sum::splusb(s[0,0,100],b[100,0,300]))")
# is replaced by the standard (extended) form below, in which x is
# encoded as the event count of the dummy discriminating variable m.
_factory_specs = (
    "Uniform::f(m[0,1])",
    "ExtendPdf::px(f,sum::splusb(s[0,0,100],b[100,0.1,300]))",
    "Poisson::py(y[100,0.1,500],prod::taub(tau[1.],b))",
    "PROD::model(px,py)",
    "Uniform::prior_b(b)",
)
for _spec in _factory_specs:
    w.factory(_spec)

# Later parts temporarily silence verbose progress messages; remember the
# current RooMsgService threshold so it can be restored afterwards.
msglevel = ROOT.RooMsgService.instance().globalKillBelow()

# -----------------------------------------------
# P A R T   3  :  A N A L Y T I C   R E S U L T
# ==============================================
# In this special case, the integrals are known analytically
# and they are implemented in RooStats::NumberCountingUtils.
#
# The arguments (150, 100, 1) match the problem set up above: 150 events
# in the main measurement, auxiliary count y = 100 and tau = 1
# (argument order per the NumberCountingUtils reference).

# analytic Z_Bi: p-value and the corresponding significance
p_Bi = ROOT.RooStats.NumberCountingUtils.BinomialWithTauObsP(150, 100, 1)
Z_Bi = ROOT.RooStats.NumberCountingUtils.BinomialWithTauObsZ(150, 100, 1)
print("-----------------------------------------")
print("Part 3")
print(f"Z_Bi p-value (analytic): {p_Bi}")
# Same "label: value" format as the p-value line above.
print(f"Z_Bi significance (analytic): {Z_Bi}")
# Report the time spent in this part and restart the stopwatch.
t.Stop()
t.Print()
t.Reset()
t.Start()

# --------------------------------------------------------------
# P A R T   4  :  U S I N G   H Y B R I D   C A L C U L A T O R
# ==============================================================
# Now we demonstrate the RooStats HybridCalculator.
#
# Like every RooStats calculator, it needs the data plus a ModelConfig
# for each hypothesis under test: one for the null (background-only)
# and one for the alternate (signal+background).  Each ModelConfig
# names the PDF, the parameters of interest and the observables and,
# because the parameter of interest floats, also the value it takes
# under that hypothesis (here s=0 for the null and s=50 for the alternate).
#
# Named sets: observables obs={m} and parameters of interest poi={s}.
# The main-measurement count x is not an explicit variable; it is the
# number of events of the dummy observable m.  The auxiliary measurement
# y is treated separately: it produces the prior used in this calculation
# to randomize the nuisance parameters.
for _set_name, _set_members in (("obs", "m"), ("poi", "s")):
    w.defineSet(_set_name, _set_members)

# Toy "observed" dataset: 150 events generated from px, i.e. x = 150.
# (Alternative kept for reference:
#    data = ROOT.RooDataSet("d", "d", w.set("obs"))
#    data.add(w.set("obs")) )
_px = w.pdf("px")
data = _px.generate(w.set("obs"), 150)

# Part 4a : Setup ModelConfigs
# -------------------------------------------------------
def _make_model_config(name, s_value):
    """Build a ModelConfig named ``name`` on workspace ``w`` for pdf ``px``.

    The parameter of interest ``s`` is set to ``s_value`` right before the
    snapshot of the POI set is taken, so the snapshot carries that value.
    The workspace variable ``s`` keeps ``s_value`` after the call.
    """
    config = ROOT.RooStats.ModelConfig(name, w)
    config.SetPdf(w.pdf("px"))
    config.SetObservables(w.set("obs"))
    config.SetParametersOfInterest(w.set("poi"))
    w.var("s").setVal(s_value)  # important: must precede SetSnapshot
    config.SetSnapshot(w.set("poi"))
    return config


# null (background-only) hypothesis: s = 0
b_model = _make_model_config("B_model", 0.0)
# alternate (signal+background) hypothesis: s = 50
sb_model = _make_model_config("S+B_model", 50.0)

# Part 4b : Choose Test Statistic
# --------------------------------------------------------------
# To make a calculation equivalent to Part 3 we need to use the count x
# as the test statistic.  Since x is encoded in the number of events of
# the dataset, the built-in RooStats NumEventsTestStat (which uses the
# number of events in the data as its test statistic) does exactly this;
# the custom BinCountTestStat class declared at the top of the file is
# not needed here.

eventCount = ROOT.RooStats.NumEventsTestStat(w.pdf("px"))

# Part 4c : Define Prior used to randomize nuisance parameters
# -------------------------------------------------------------
#
# The prior used for the hybrid calculator is the posterior
# from the auxiliary measurement y.  The model for the aux.
# measurement is Pois(y|tau*b), so the likelihood function
# is proportional to (has the form of) a Gamma distribution.
# If the 'original prior' $\eta(b)$ is uniform, then from
# Bayes's theorem the posterior is:
#  $\pi(b) \propto Pois(y|tau*b) * \eta(b)$
# If $\eta(b)$ is flat, then we arrive at a Gamma distribution.
# Since RooFit will normalize the PDF we can actually supply
# py=Pois(y,tau*b), which is equivalent to multiplying by a Uniform.
#
# Alternatively, we could explicitly use a gamma distribution:
#
# `w.factory("Gamma::gamma(b,sum::temp(y,1),1,0)")`
#
# or we can use some other ad hoc prior that does not naturally
# follow from the known form of the auxiliary measurement.
# The common choice is the equivalent Gaussian (mean y, width sqrt(y)):
w.factory("Gaussian::gauss_prior(b,y, expr::sqrty('sqrt(y)',y))")
# this corresponds to the "Z_N" calculation.
#
# or one could use the analogous log-normal prior (kappa = 1 + 1/sqrt(y)):
w.factory("Lognormal::lognorm_prior(b,y, expr::kappa('1+1./sqrt(y)',y))")
#
# Ideally, the HybridCalculator would be able to inspect the full
# model Pois(x | s+b) * Pois(y | tau b ) and be given the original
# prior $\eta(b)$ to form $\pi(b) \propto Pois(y|tau*b) * \eta(b)$.
# This is not yet implemented because in the general case
# it is not easy to identify the terms in the PDF that correspond
# to the auxiliary measurement.  So for now, it must be set
# explicitly with:
#  - ForcePriorNuisanceNull()
#  - ForcePriorNuisanceAlt()
# The name "ForcePriorNuisance" was chosen because we anticipate
# this to be auto-detected, but will leave the option open
# to force a different prior for the nuisance parameters.

# Part 4d : Construct and configure the HybridCalculator
# -------------------------------------------------------
# The calculator receives the observed data, the alternate (S+B) and the
# null (B-only) ModelConfigs, and samples the test statistic with toys.

hc1 = ROOT.RooStats.HybridCalculator(data, sb_model, b_model)
toymcs1 = hc1.GetTestStatSampler()
#  toymcs1.SetNEventsPerToy(1) # because the model is in number counting form
toymcs1.SetTestStatistic(eventCount)  # set the test statistic
#  toymcs1.SetGenerateBinned()
# ntoys toys for the null hypothesis, ntoys // nToysRatio for the alternate
hc1.SetToys(ntoys, ntoys // nToysRatio)
# Randomize the nuisance parameter b with the auxiliary-measurement
# posterior py under both hypotheses.
hc1.ForcePriorNuisanceAlt(w.pdf("py"))
hc1.ForcePriorNuisanceNull(w.pdf("py"))
# if you wanted to use the ad hoc Gaussian prior instead
# ~~~
#  hc1.ForcePriorNuisanceAlt(w.pdf("gauss_prior"))
#  hc1.ForcePriorNuisanceNull(w.pdf("gauss_prior"))
# ~~~
# if you wanted to use the ad hoc log-normal prior instead
# ~~~
#  hc1.ForcePriorNuisanceAlt(w.pdf("lognorm_prior"))
#  hc1.ForcePriorNuisanceNull(w.pdf("lognorm_prior"))
# ~~~

# Silence RooFit messages below ERROR while the toys run.  The previous
# threshold (msglevel) is restored even if GetHypoTest raises, so a failure
# here does not leave the message service muted.
ROOT.RooMsgService.instance().setGlobalKillBelow(ROOT.RooFit.ERROR)
try:
    r1 = hc1.GetHypoTest()
finally:
    ROOT.RooMsgService.instance().setGlobalKillBelow(msglevel)
print("-----------------------------------------")
print("Part 4")
r1.Print()
t.Stop()
t.Print()
t.Reset()
t.Start()

c.cd(2)
p1 = ROOT.RooStats.HypoTestPlot(r1, 30)  # 30 bins, TS is discrete
p1.Draw()

# Stop here to keep the running time short by default (remove this line to
# also run Part 5).  sys.exit() is used because quit() is only intended for
# the interactive interpreter and is unavailable when Python runs without
# the site module.
sys.exit()
# -------------------------------------------------------------------------
# P A R T   5  :  U S I N G   H Y B R I D   C A L C U L A T O R   W I T H
# A N   A L T E R N A T I V E   T E S T   S T A T I S T I C
#
# A likelihood ratio test statistic should be 1-to-1 with the count x
# when the value of b is fixed in the likelihood.  This is implemented
# by the SimpleLikelihoodRatioTestStat.
#
# NOTE: this part only runs if the early exit at the end of Part 4 is removed.

slrts = ROOT.RooStats.SimpleLikelihoodRatioTestStat(b_model.GetPdf(), sb_model.GetPdf())
slrts.SetNullParameters(b_model.GetSnapshot())
slrts.SetAltParameters(sb_model.GetSnapshot())

# HYBRID CALCULATOR
hc2 = ROOT.RooStats.HybridCalculator(data, sb_model, b_model)
# Configure the calculator's own test-statistic sampler; no separately
# constructed ToyMCSampler is needed.
toymcs2 = hc2.GetTestStatSampler()
#  toymcs2.SetNEventsPerToy(1)
toymcs2.SetTestStatistic(slrts)
#  toymcs2.SetGenerateBinned()
# ntoys toys for the null hypothesis, ntoys // nToysRatio for the alternate
hc2.SetToys(ntoys, ntoys // nToysRatio)
hc2.ForcePriorNuisanceAlt(w.pdf("py"))
hc2.ForcePriorNuisanceNull(w.pdf("py"))
# if you wanted to use the ad hoc Gaussian prior instead
# ~~~
#  hc2.ForcePriorNuisanceAlt(w.pdf("gauss_prior"))
#  hc2.ForcePriorNuisanceNull(w.pdf("gauss_prior"))
# ~~~
# if you wanted to use the ad hoc log-normal prior instead
# ~~~
#  hc2.ForcePriorNuisanceAlt(w.pdf("lognorm_prior"))
#  hc2.ForcePriorNuisanceNull(w.pdf("lognorm_prior"))
# ~~~

# Silence RooFit messages below ERROR while the toys run; restore the
# previous threshold (msglevel) even if GetHypoTest raises.
ROOT.RooMsgService.instance().setGlobalKillBelow(ROOT.RooFit.ERROR)
try:
    r2 = hc2.GetHypoTest()
finally:
    ROOT.RooMsgService.instance().setGlobalKillBelow(msglevel)
print("-----------------------------------------")
print("Part 5")
r2.Print()
t.Stop()
t.Print()
t.Reset()
t.Start()

c.cd(3)
p2 = ROOT.RooStats.HypoTestPlot(r2, 30)  # 30 bins
p2.Draw()

# End of the executable part; what follows is reference output only.
# sys.exit() rather than quit(), which is meant for interactive use.
sys.exit()

# ---------------------------------------------
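# Example OUTPUT (2.66 GHz Intel Core i7)
# (recorded with different toy counts than the ntoys / nToysRatio
#  settings at the top of this script)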
# ============================================

# -----------------------------------------
# Part 3
# Z_Bi p-value (analytic): 0.00094165
# Z_Bi significance (analytic): 3.10804
# Real time 0:00:00, CP time 0.610

# Results HybridCalculator_result:
# - Null p-value = 0.00103333 +/- 0.000179406
# - Significance = 3.08048 sigma
# - Number of S+B toys: 1000
# - Number of B toys: 30000
# - Test statistic evaluated on data: 150
# - CL_b: 0.998967 +/- 0.000185496
# - CL_s+b: 0.495 +/- 0.0158106
# - CL_s: 0.495512 +/- 0.0158272
# Real time 0:04:43, CP time 283.780

# -------------------------------------------------------
# Comparison
# -------------------------------------------------------
# LEPStatToolsForLHC
# https://plone4.fnal.gov:4430/P0/phystat/packages/0703002
# Uses Gaussian prior
# CL_b = 6.218476e-04, Significance = 3.228665 sigma
#
# -------------------------------------------------------
# Comparison
# -------------------------------------------------------
# Asymptotic
# From the value of the profile likelihood ratio (5.0338)
# The significance can be estimated using Wilks's theorem
# significance = sqrt(2*profileLR) = 3.1729 sigma
