
import ROOT

# -------------------------------------------------------
# The actual macro


def last_bin_center_below(cumulativeFractions, binCenters, threshold):
    """Return the centre of the last bin whose cumulative fraction is below ``threshold``.

    ``cumulativeFractions`` and ``binCenters`` are parallel sequences with one
    entry per histogram bin, in bin order. Every bin is inspected (there is no
    early exit). For a non-decreasing cumulative distribution, the result is
    therefore the highest bin that still lies below the quantile ``threshold``.

    Returns 0.0 when no bin qualifies (including for empty input).
    """
    result = 0.0
    for fraction, center in zip(cumulativeFractions, binCenters):
        if fraction < threshold:
            result = center
    return result


def _generate_toy_data(mc, data):
    """Generate one pseudo-dataset over the observables of the ModelConfig's pdf.

    An extendable pdf is generated as an extended dataset. A non-extendable pdf
    is only handled in the 'number counting' form, where the observed dataset
    has a single entry; one entry is generated then. In any other case a message
    is printed and an empty ``RooDataSet`` is returned.
    """
    pdf = mc.GetPdf()
    if pdf.canBeExtended():
        return pdf.generate(mc.GetObservables(), Extended=True)
    if data.numEntries() == 1:
        return pdf.generate(mc.GetObservables(), 1)
    print("Not sure what to do about this model")
    return ROOT.RooDataSet()


def _generate_global_observables(mc):
    """Draw new values for the global observables of ``mc``, in place (unconditional ensemble).

    For a ``RooSimultaneous`` pdf, each category's component pdf generates only
    the global observables it depends on. Otherwise, the full pdf generates all
    of them at once and the values are copied onto the pdf's variables. Does
    nothing when the model defines no global observables.
    """
    globalObs = mc.GetGlobalObservables()
    if not globalObs:
        return
    pdf = mc.GetPdf()
    if isinstance(pdf, ROOT.RooSimultaneous):
        for catState in pdf.indexCat():
            # catState is a (label, index) pair; look up the component pdf by label
            componentPdf = pdf.getPdf(str(catState.first))
            if not componentPdf:
                continue
            # Generate only the global variables defined by the pdf of this state
            componentGlobs = componentPdf.getObservables(globalObs)
            generated = componentPdf.generate(componentGlobs, 1)
            # Transfer the generated values onto the model's variables
            componentGlobs.assign(generated.get(0))
    else:
        generated = pdf.generate(globalObs, 1)
        allVars = pdf.getVariables()
        allVars.assign(generated.get())


def OneSidedFrequentistUpperLimitWithBands(
    infile="",
    workspaceName="combined",
    modelConfigName="ModelConfig",
    dataName="obsData",
    confidenceLevel=0.95,
    nPointsToScan=12,
    nToyMC=150,
):
    """Compute a one-sided frequentist upper limit with expected (background-only) bands.

    ``RooStats.FeldmanCousins`` is used only as the engine of a Neyman
    construction. Its test statistic is switched to the one-sided form, so the
    resulting interval is NOT a Feldman-Cousins interval.

    The function then throws background-only pseudo-experiments (POI = 0,
    nuisance parameters profiled on the observed data). For each one it finds
    the upper limit from the confidence belt, and from their distribution it
    prints the median and the +-1 / +-2 sigma bands. It also prints CLb for the
    observed upper limit.

    Args:
        infile: ROOT file holding the workspace. When empty, the standard
            HistFactory example file is used. If that file is missing, it is
            created by running ``prepareHistFactory`` and ``hist2workspace``.
        workspaceName: name of the RooWorkspace inside the file.
        modelConfigName: name of the ModelConfig inside the workspace.
        dataName: name of the observed dataset inside the workspace.
        confidenceLevel: confidence level of the Neyman construction.
        nPointsToScan: number of POI points scanned when building the belt.
        nToyMC: number of background-only pseudo-experiments.

    Side effects:
        Prints results. Draws a two-pad canvas and saves it as
        ``OneSidedFrequentistUpperLimitWithBands.png``. Stores the workspace,
        file and ModelConfig in the module globals ``gw``, ``gfile`` and
        ``gmc``, so they remain reachable after the function returns.
        Returns None.
    """
    global gw, gfile, gmc

    # -------------------------------------------------------
    # First part is just to access a user-defined file
    # or create the standard example file if it doesn't exist
    if infile == "":
        filename = "results/example_combined_GaussExample_model.root"
        # AccessPathName() returns True when the path is NOT accessible (inverted return code)
        if ROOT.gSystem.AccessPathName(filename):
            # Normally this would be run on the command line; ".!" is ROOT's shell escape
            print("will run standard hist2workspace example")
            ROOT.gROOT.ProcessLine(".! prepareHistFactory .")
            ROOT.gROOT.ProcessLine(".! hist2workspace config/example.xml")
            print("\n\n---------------------")
            print("Done creating example input")
            print("---------------------\n\n")
    else:
        filename = infile

    # Try to open the file; if it cannot be opened, quit
    inputFile = ROOT.TFile.Open(filename)
    if not inputFile:
        print(f"StandardRooStatsDemoMacro: Input file {filename} is not found")
        return

    # -------------------------------------------------------
    # Now get the workspace, the ModelConfig and the data
    w = inputFile.Get(workspaceName)
    gw = w
    gfile = inputFile
    if not w:
        print("workspace not found")
        return

    mc = w.obj(modelConfigName)
    data = w.data(dataName)
    if not data or not mc:
        w.Print()
        print("data or ModelConfig was not found")
        return
    gmc = mc  # debugging handle

    # -------------------------------------------------------
    # Now get the POI for convenience.
    # You may want to adjust the range of your POI, e.g.
    #   firstPOI.setMin(0)
    #   firstPOI.setMax(10)
    firstPOI = mc.GetParametersOfInterest().first()
    poiName = firstPOI.GetName()
    clLabel = f"{confidenceLevel:.0%}"  # e.g. "95%" for 0.95

    # --------------------------------------------
    # Create the FeldmanCousins tool to build the confidence belt on the
    # parameter of interest specified in the model config.
    # REMEMBER: the test statistic is changed below, so this is NOT a
    # Feldman-Cousins interval.
    fc = ROOT.RooStats.FeldmanCousins(data, mc)
    fc.SetConfidenceLevel(confidenceLevel)
    # Degrade/improve the sampling that defines the confidence belt; here it makes the example faster
    fc.AdditionalNToysFactor(0.5)
    # fc.UseAdaptiveSampling(True)  # speeds it up a bit, don't use for expected limits
    fc.SetNBins(nPointsToScan)  # how many points per parameter of interest to scan
    fc.CreateConfBelt(True)  # save the information in the belt for plotting

    # -------------------------------------------------------
    # Feldman-Cousins is a unified limit by definition, but the tool takes care
    # of a few things for us, like which values of the nuisance parameters are
    # used to generate toys. So we just switch the test statistic to its
    # one-sided form. The result is no longer "Feldman-Cousins" but a fully
    # frequentist Neyman construction.
    toymcsampler = fc.GetTestStatSampler()
    testStat = toymcsampler.GetTestStatistic()
    testStat.SetOneSided(True)

    # Since this tool needs to throw toy MC, the PDF needs to be extended, or
    # the tool needs to know how many entries a dataset has per pseudo
    # experiment. In the 'number counting form', the entries in the dataset
    # are counts, not values of discriminating variables. Such datasets
    # typically have only one entry, and the PDF is not extended.
    if not mc.GetPdf().canBeExtended():
        if data.numEntries() == 1:
            fc.FluctuateNumDataEntries(False)
        else:
            print("Not sure what to do about this model")

    if mc.GetGlobalObservables():
        print("will use global observables for unconditional ensemble")
        mc.GetGlobalObservables().Print()
        toymcsampler.SetGlobalObservables(mc.GetGlobalObservables())

    # Now get the interval and the belt it was built from
    interval = fc.GetInterval()
    belt = fc.GetConfidenceBelt()
    observedUL = interval.UpperLimit(firstPOI)

    # Print out the interval on the first Parameter of Interest
    print(f"\n{clLabel} interval on {poiName} is : [{interval.LowerLimit(firstPOI)}, {observedUL} ]")

    # Value of the test statistic on the observed data, evaluated at the observed UL
    tmpPOI = ROOT.RooArgSet(firstPOI)
    firstPOI.setVal(observedUL)
    obsTSatObsUL = toymcsampler.EvaluateTestStatistic(data, tmpPOI)

    # Ask the calculator which points were scanned, and look up the upper
    # acceptance threshold of each point once. The belt is not modified below,
    # so the same (POI value, threshold) pairs serve the plot and every toy.
    # For FeldmanCousins, the lower cut-off is always 0.
    parameterScan = fc.GetPointsToScan()
    scanPoints = []
    for i in range(parameterScan.numEntries()):
        point = parameterScan.get(i).clone("temp")
        scanPoints.append((point.getRealValue(poiName), belt.GetAcceptanceRegionMax(point)))

    # Make a histogram of parameter vs. threshold
    histOfThresholds = ROOT.TH1F("histOfThresholds", "", len(scanPoints), firstPOI.getMin(), firstPOI.getMax())
    histOfThresholds.GetXaxis().SetTitle(poiName)
    histOfThresholds.GetYaxis().SetTitle("Threshold")
    for poiVal, arMax in scanPoints:
        histOfThresholds.Fill(poiVal, arMax)

    c1 = ROOT.TCanvas()
    c1.Divide(2)
    c1.cd(1)
    histOfThresholds.SetMinimum(0)
    histOfThresholds.Draw()
    c1.Update()
    c1.Draw()
    c1.cd(2)

    # -------------------------------------------------------
    # Now we generate the expected bands and the power constraint.

    # First: find the parameter point for mu=0, with conditional MLEs for the nuisance parameters
    nll = mc.GetPdf().createNLL(data)
    profile = nll.createProfile(mc.GetParametersOfInterest())
    firstPOI.setVal(0.0)
    profile.getVal()  # this does the fit and sets the nuisance parameters to their profiled values
    poiAndNuisance = ROOT.RooArgSet()
    if mc.GetNuisanceParameters():
        poiAndNuisance.add(mc.GetNuisanceParameters())
    poiAndNuisance.add(mc.GetParametersOfInterest())
    w.saveSnapshot("paramsToGenerateData", poiAndNuisance)
    paramsToGenerateData = poiAndNuisance.snapshot()
    print("\nWill use these parameter points to generate pseudo data for bkg only")
    paramsToGenerateData.Print("v")

    # Now generate background-only toys and find the distribution of upper limits
    nStrict = 0  # toys with TS(toy) >  TS(obs) at the observed UL
    nInclusive = 0  # toys with TS(toy) >= TS(obs) at the observed UL
    histOfUL = ROOT.TH1F("histOfUL", "", 100, 0, firstPOI.getMax())
    histOfUL.GetXaxis().SetTitle("Upper Limit (background only)")
    histOfUL.GetYaxis().SetTitle("Entries")
    for _ in range(nToyMC):
        # Set the parameters back to the values used for generating pseudo data
        w.loadSnapshot("paramsToGenerateData")
        toyData = _generate_toy_data(mc, data)
        _generate_global_observables(mc)

        # Test statistic of the toy at the observed UL
        firstPOI.setVal(observedUL)
        toyTSatObsUL = toymcsampler.EvaluateTestStatistic(toyData, tmpPOI)
        if obsTSatObsUL < toyTSatObsUL:
            nStrict += 1
        if obsTSatObsUL <= toyTSatObsUL:
            nInclusive += 1

        # Walk up the scanned points; the UL is the last point still inside the acceptance region
        thisUL = 0.0
        for poiVal, arMax in scanPoints:
            firstPOI.setVal(poiVal)
            thisTS = toymcsampler.EvaluateTestStatistic(toyData, tmpPOI)
            if thisTS <= arMax:
                thisUL = firstPOI.getVal()
            else:
                break

        # For few events, the toy data are often the same, and so is the UL
        histOfUL.Fill(thisUL)

    c1.cd(2)
    histOfUL.Draw()
    c1.Update()
    c1.Draw()
    c1.SaveAs("OneSidedFrequentistUpperLimitWithBands.png")

    # If you want to see a plot of the sampling distribution for a particular scan point:
    #
    #   sampPlot = ROOT.RooStats.SamplingDistPlot()
    #   indexInScan = 0
    #   tmpPoint = parameterScan.get(indexInScan).clone("temp")
    #   firstPOI.setVal(tmpPoint.getRealValue(poiName))
    #   toymcsampler.SetParametersForTestStat(tmpPOI)
    #   samp = toymcsampler.GetSamplingDistribution(tmpPoint)
    #   sampPlot.AddSamplingDistribution(samp)
    #   sampPlot.Draw()

    # Now find the bands and the power constraint.
    # GetIntegral() holds the normalized cumulative content: index 0 is the
    # leading 0 and indices 1..nBins belong to the regular bins, so the
    # underflow slot must be skipped.
    nBins = histOfUL.GetNbinsX()
    integral = histOfUL.GetIntegral()
    fractions = [integral[i] for i in range(1, nBins + 1)]
    centers = [histOfUL.GetBinCenter(i) for i in range(1, nBins + 1)]
    band2sigDown = last_bin_center_below(fractions, centers, ROOT.RooStats.SignificanceToPValue(2))
    band1sigDown = last_bin_center_below(fractions, centers, ROOT.RooStats.SignificanceToPValue(1))
    bandMedian = last_bin_center_below(fractions, centers, 0.5)
    band1sigUp = last_bin_center_below(fractions, centers, ROOT.RooStats.SignificanceToPValue(-1))
    band2sigUp = last_bin_center_below(fractions, centers, ROOT.RooStats.SignificanceToPValue(-2))

    print("-2 sigma  band ", band2sigDown)
    print(f"-1 sigma  band {band1sigDown} [Power Constraint]")
    print("median of band ", bandMedian)
    print("+1 sigma  band ", band1sigUp)
    print("+2 sigma  band ", band2sigUp)

    CLb = nStrict / nToyMC if nToyMC > 0 else 0.0
    CLbinclusive = nInclusive / nToyMC if nToyMC > 0 else 0.0

    # Print out the observed limit on the first Parameter of Interest and CLb
    print(f"\nObserved {clLabel} upper-limit ", observedUL)
    print(f"CLb strict [P(toy>obs|0)] for observed {clLabel} upper-limit ", CLb)
    print(f"inclusive [P(toy>=obs|0)] for observed {clLabel} upper-limit ", CLbinclusive)


# Run the macro with the default inputs: the standard example file
# (created on first use), its "combined" workspace, the "ModelConfig"
# and the "obsData" dataset.
OneSidedFrequentistUpperLimitWithBands(
    infile="",
    workspaceName="combined",
    modelConfigName="ModelConfig",
    dataName="obsData",
)
