Logo ROOT  
Reference Guide
 
Loading...
Searching...
No Matches
RooFitDriver.h
Go to the documentation of this file.
1/*
2 * Project: RooFit
3 * Authors:
4 * Jonas Rembser, CERN 2021
5 * Emmanouil Michalainas, CERN 2021
6 *
7 * Copyright (c) 2021, CERN
8 *
9 * Redistribution and use in source and binary forms,
10 * with or without modification, are permitted according to the terms
11 * listed in LICENSE (http://roofit.sourceforge.net/license.txt)
12 */
13
14#ifndef RooFit_RooFitDriver_h
15#define RooFit_RooFitDriver_h
16
17#include <RooAbsData.h>
19#include <RooGlobalFunc.h>
20#include <RooHelpers.h>
21#include <RooRealProxy.h>
23
24#include <chrono>
25#include <memory>
26#include <stack>
27
28class RooAbsArg;
29class RooAbsCategory;
30class RooSimultaneous;
31
32namespace ROOT {
33namespace Experimental {
34
35struct NodeInfo;
36
// NOTE(review): recovered from a rendered documentation listing. The leading number on each
// line is the listing's line number, not code. The class head (original line 37, presumably
// `class RooFitDriver {`, since RooAbsRealWrapper below holds a RooFitDriver) is missing.
// Original lines 47, 53, 55, 73, 81-83 and 87-88 are also missing. Restore them from the
// upstream header before compiling.
38public:
39 ////////////////////
40 // Enums and aliases
41
   /// Maps a RooFit::Detail::DataKey to a read-only span of doubles.
   /// This is the input type of the setData(DataSpansMap const &) overload below.
42 using DataSpansMap = std::map<RooFit::Detail::DataKey, RooSpan<const double>>;
43
44 //////////////////////////
45 // Public member functions
46
   // NOTE(review): original line 47 (presumably the constructor) is missing from this listing.
48
   /// Sets the input data from a RooAbsData.
   /// The exact effect of rangeName, simPdf, skipZeroWeights and
   /// takeGlobalObservablesFromData is implemented out of line; confirm in the .cpp.
49 void setData(RooAbsData const &data, std::string const &rangeName = "", RooSimultaneous const *simPdf = nullptr,
50 bool skipZeroWeights = false, bool takeGlobalObservablesFromData = true);
   /// Sets the input data directly from already-prepared spans keyed by DataKey.
51 void setData(DataSpansMap const &dataSpans);
52
   // NOTE(review): original lines 53 and 55 are missing from this listing.
54
   /// Returns a vector of doubles; presumably the per-entry values of the top node. TODO confirm.
56 std::vector<double> getValues();
   /// Returns the value of the top node in the computation graph.
57 double getVal();
   /// Returns the top node of the computation graph.
   /// Note: this is a const member that hands out a non-const reference.
58 RooAbsReal &topNode() const;
59
   /// Prints a description of the driver to `os`.
60 void print(std::ostream &os) const;
61
62private:
63 ///////////////////////////
64 // Private member functions
65
   /// Process a variable in the computation graph.
66 void processVariable(NodeInfo &nodeInfo);
   /// Flags all the clients of a given node dirty.
67 void setClientsDirty(NodeInfo &nodeInfo);
   /// Returns the value of the top node in the computation graph.
   /// This is presumably the mixed CPU/GPU evaluation path; confirm in the .cpp.
68 double getValHeterogeneous();
   /// Decides which nodes are assigned to the GPU in a CUDA fit.
69 void markGPUNodes();
   /// Assign a node to be computed in the GPU.
70 void assignToGPU(NodeInfo &info);
   /// Computes `node` on the CPU, using/updating its bookkeeping in `info`.
71 void computeCPUNode(const RooAbsArg *node, NodeInfo &info);
   /// Temporarily change the operation mode of a RooAbsArg until the RooFitDriver gets deleted.
   /// (The restoring objects are kept in _changeOperModeRAIIs.)
72 void setOperMode(RooAbsArg *arg, RooAbsArg::OperMode opMode);
   // NOTE(review): original line 73 is missing from this listing.
   /// Handles servers with the same name that got de-duplicated in the _nodes list.
   /// The index summary describing this is truncated; see the .cpp for details.
74 void syncDataTokens();
75
76 ///////////////////////////
77 // Private member variables
78
79 Detail::BufferManager _bufferManager; // The object managing the different buffers for the intermediate results
80
   // NOTE(review): original lines 81-83 are missing. The documentation index also lists members
   // `_batchMode`, `_dataMapCPU` and `_dataMapCUDA`, which presumably lived there.
   // Raw, non-owning-looking pointer; presumably the dataset copy used in CUDA mode.
   // Confirm who allocates and frees it.
84 double *_cudaMemDataset = nullptr;
85
86 // used for preserving static info about the computation graph
   // NOTE(review): the member this comment describes (original lines 87-88) is missing from this listing.
89
90 // the ordered computation graph
91 std::vector<NodeInfo> _nodes;
92
93 // used for preserving resources
94 std::stack<std::vector<double>> _vectorBuffers;
95
96 // RAII structures to reset state of computation graph after driver destruction
97 std::stack<RooHelpers::ChangeOperModeRAII> _changeOperModeRAIIs;
98};
99
100class RooAbsRealWrapper final : public RooAbsReal {
101public:
102 RooAbsRealWrapper(std::unique_ptr<RooFitDriver> driver, std::string const &rangeName, RooSimultaneous const *simPdf,
103 bool takeGlobalObservablesFromData);
104
105 RooAbsRealWrapper(const RooAbsRealWrapper &other, const char *name = nullptr);
106
107 TObject *clone(const char *newname) const override { return new RooAbsRealWrapper(*this, newname); }
108
109 double defaultErrorLevel() const override { return _driver->topNode().defaultErrorLevel(); }
110
111 bool getParameters(const RooArgSet *observables, RooArgSet &outputSet, bool stripDisconnected) const override;
112
113 bool setData(RooAbsData &data, bool cloneData) override;
114
115 double getValV(const RooArgSet *) const override { return evaluate(); }
116
117 void applyWeightSquared(bool flag) override
118 {
119 const_cast<RooAbsReal &>(_driver->topNode()).applyWeightSquared(flag);
120 }
121
122 void printMultiline(std::ostream &os, Int_t /*contents*/, bool /*verbose*/ = false,
123 TString /*indent*/ = "") const override
124 {
125 _driver->print(os);
126 }
127
128protected:
129 double evaluate() const override { return _driver ? _driver->getVal() : 0.0; }
130
131private:
132 std::shared_ptr<RooFitDriver> _driver;
134 RooAbsData *_data = nullptr;
136 std::string _rangeName;
137 RooSimultaneous const *_simPdf = nullptr;
139};
140
141} // end namespace Experimental
142} // end namespace ROOT
143
144#endif
Option_t Option_t TPoint TPoint const char GetTextMagnitude GetFillStyle GetLineColor GetLineWidth GetMarkerStyle GetTextAlign GetTextColor GetTextSize void data
char name[80]
Definition TGX11.cxx:110
void applyWeightSquared(bool flag) override
Disables or enables the usage of squared weights.
TObject * clone(const char *newname) const override
double getValV(const RooArgSet *) const override
Return value of object.
void printMultiline(std::ostream &os, Int_t, bool=false, TString="") const override
Structure printing.
double defaultErrorLevel() const override
std::shared_ptr< RooFitDriver > _driver
bool setData(RooAbsData &data, bool cloneData) override
bool getParameters(const RooArgSet *observables, RooArgSet &outputSet, bool stripDisconnected) const override
Fills a list with leaf nodes in the arg tree starting with ourself as top node that don't match any o...
double evaluate() const override
Evaluate this PDF / function / constant. Needs to be overridden by all derived classes.
std::map< RooFit::Detail::DataKey, RooSpan< const double > > DataSpansMap
const RooFit::BatchModeOption _batchMode
void setOperMode(RooAbsArg *arg, RooAbsArg::OperMode opMode)
Temporarily change the operation mode of a RooAbsArg until the RooFitDriver gets deleted.
void syncDataTokens()
If there are servers with the same name that got de-duplicated in the _nodes list,...
RooFit::Detail::DataMap _dataMapCPU
double getValHeterogeneous()
Returns the value of the top node in the computation graph.
double getVal()
Returns the value of the top node in the computation graph.
RooFit::Detail::DataMap _dataMapCUDA
void setData(RooAbsData const &data, std::string const &rangeName="", RooSimultaneous const *simPdf=nullptr, bool skipZeroWeights=false, bool takeGlobalObservablesFromData=true)
std::vector< NodeInfo > _nodes
std::vector< double > getValues()
void assignToGPU(NodeInfo &info)
Assign a node to be computed in the GPU.
void processVariable(NodeInfo &nodeInfo)
Process a variable in the computation graph.
void computeCPUNode(const RooAbsArg *node, NodeInfo &info)
Detail::BufferManager _bufferManager
std::stack< std::vector< double > > _vectorBuffers
void markGPUNodes()
Decides which nodes are assigned to the GPU in a CUDA fit.
std::stack< RooHelpers::ChangeOperModeRAII > _changeOperModeRAIIs
void setClientsDirty(NodeInfo &nodeInfo)
Flags all the clients of a given node dirty.
void print(std::ostream &os) const
RooAbsArg is the common abstract base class for objects that represent a value and a "shape" in RooFi...
Definition RooAbsArg.h:74
A space to attach TBranches.
RooAbsData is the common abstract base class for binned and unbinned datasets.
Definition RooAbsData.h:59
RooAbsReal is the common abstract base class for objects that represent a real value and implements f...
Definition RooAbsReal.h:62
RooArgSet is a container object that can hold multiple RooAbsArg objects.
Definition RooArgSet.h:55
RooSimultaneous facilitates simultaneous fitting of multiple PDFs to subsets of a given dataset.
Mother of all ROOT objects.
Definition TObject.h:41
Basic string class.
Definition TString.h:139
This file contains a specialised ROOT message handler to test for diagnostics in unit tests.
BatchModeOption
For setting the batch mode flag with the BatchMode() command argument to RooAbsPdf::fitTo()
A struct used by the RooFitDriver to store information on the RooAbsArgs in the computation graph.