Logo ROOT  
Reference Guide
 
Loading...
Searching...
No Matches
CodeSquashContext.cxx
Go to the documentation of this file.
1/*
2 * Project: RooFit
3 * Authors:
4 * Garima Singh, CERN 2023
5 * Jonas Rembser, CERN 2023
6 *
7 * Copyright (c) 2023, CERN
8 *
9 * Redistribution and use in source and binary forms,
10 * with or without modification, are permitted according to the terms
11 * listed in LICENSE (http://roofit.sourceforge.net/license.txt)
12 */
13
15
16#include "RooFuncWrapper.h"
17
18#include "RooFitImplHelpers.h"
19
20#include <algorithm>
21#include <cctype>
22
23namespace RooFit {
24
25namespace Detail {
26
27CodeSquashContext::CodeSquashContext(std::map<RooFit::Detail::DataKey, std::size_t> const &outputSizes,
28 std::vector<double> &xlarr, Experimental::RooFuncWrapper &wrapper)
29 : _wrapper{&wrapper}, _nodeOutputSizes(outputSizes), _xlArr(xlarr)
30{
31}
32
33/// @brief Adds (or overwrites) the string representing the result of a node.
34/// @param key The name of the node to add the result for.
35/// @param value The new name to assign/overwrite.
36void CodeSquashContext::addResult(const char *key, std::string const &value)
37{
38 const TNamed *namePtr = RooNameReg::known(key);
39 if (namePtr)
40 addResult(namePtr, value);
41}
42
43void CodeSquashContext::addResult(TNamed const *key, std::string const &value)
44{
45 _nodeNames[key] = value;
46}
47
48/// @brief Gets the result for the given node using the node name. This node also performs the necessary
49/// code generation through recursive calls to 'translate'. A call to this function modifies the already
50/// existing code body.
51/// @param key The node to get the result string for.
52/// @return String representing the result of this node.
53std::string const &CodeSquashContext::getResult(RooAbsArg const &arg)
54{
55 // If the result has already been recorded, just return the result.
56 // It is usually the responsibility of each translate function to assign
57 // the proper result to its class. Hence, if a result has already been recorded
58 // for a particular node, it means the node has already been 'translate'd and we
59 // dont need to visit it again.
60 auto found = _nodeNames.find(arg.namePtr());
61 if (found != _nodeNames.end())
62 return found->second;
63
64 // The result for vector observables should already be in the map if you
65 // opened the loop scope. This is just to check if we did not request the
66 // result of a vector-valued observable outside of the scope of a loop.
67 auto foundVecObs = _vecObsIndices.find(arg.namePtr());
68 if (foundVecObs != _vecObsIndices.end()) {
69 throw std::runtime_error("You requested the result of a vector observable outside a loop scope for it!");
70 }
71
72 // Now, recursively call translate into the current argument to load the correct result.
73 arg.translate(*this);
74
75 return _nodeNames.at(arg.namePtr());
76}
77
78/// @brief Adds the given string to the string block that will be emitted at the top of the squashed function. Useful
79/// for variable declarations.
80/// @param str The string to add to the global scope.
81void CodeSquashContext::addToGlobalScope(std::string const &str)
82{
83 _globalScope += str;
84}
85
86/// @brief Assemble and return the final code with the return expression and global statements.
87/// @param returnExpr The string representation of what the squashed function should return, usually the head node.
88/// @return The final body of the function.
89std::string CodeSquashContext::assembleCode(std::string const &returnExpr)
90{
91 return _globalScope + _code + "\n return " + returnExpr + ";\n";
92}
93
94/// @brief Since the squashed code represents all observables as a single flattened array, it is important
95/// to keep track of the start index for a vector valued observable which can later be expanded to access the correct
96/// element. For example, a vector valued variable x with 10 entries will be squashed to obs[start_idx + i].
97/// @param key The name of the node representing the vector valued observable.
98/// @param idx The start index (or relative position of the observable in the set of all observables).
99void CodeSquashContext::addVecObs(const char *key, int idx)
100{
101 const TNamed *namePtr = RooNameReg::known(key);
102 if (namePtr)
103 _vecObsIndices[namePtr] = idx;
104}
105
106/// @brief Adds the input string to the squashed code body. If a class implements a translate function that wants to
107/// emit something to the squashed code body, it must call this function with the code it wants to emit. In case of
108/// loops, automatically determines if code needs to be stored inside or outside loop scope.
109/// @param klass The class requesting this addition, usually 'this'.
110/// @param in String to add to the squashed code.
111void CodeSquashContext::addToCodeBody(RooAbsArg const *klass, std::string const &in)
112{
113 // If we are in a loop and the value is scope independent, save it at the top of the loop.
114 // else, just save it in the current scope.
116}
117
118/// @brief A variation of the previous addToCodeBody that takes in a bool value that determines
119/// if input is independent. This overload exists because there might other ways to determine if
120/// a value/collection of values is scope independent.
121/// @param in String to add to the squashed code.
122/// @param isScopeIndep The value determining if the input is scope dependent.
123void CodeSquashContext::addToCodeBody(std::string const &in, bool isScopeIndep /* = false */)
124{
125 // If we are in a loop and the value is scope independent, save it at the top of the loop.
126 // else, just save it in the current scope.
127 if (_scopePtr != -1 && isScopeIndep) {
128 _tempScope += in;
129 } else {
130 _code += in;
131 }
132}
133
134/// @brief Create a RAII scope for iterating over vector observables. You can't use the result of vector observables
135/// outside these loop scopes.
136/// @param in A pointer to the calling class, used to determine the loop dependent variables.
137std::unique_ptr<CodeSquashContext::LoopScope> CodeSquashContext::beginLoop(RooAbsArg const *in)
138{
139 std::string idx = "loopIdx" + std::to_string(_loopLevel);
140
141 std::vector<TNamed const *> vars;
142 // set the results of the vector observables
143 for (auto const &it : _vecObsIndices) {
144 if (!in->dependsOn(it.first))
145 continue;
146
147 vars.push_back(it.first);
148 _nodeNames[it.first] = "obs[" + std::to_string(it.second) + " + " + idx + "]";
149 }
150
151 // TODO: we are using the size of the first loop variable to the the number
152 // of iterations, but it should be made sure that all loop vars are either
153 // scalar or have the same size.
154 std::size_t numEntries = 1;
155 for (auto &it : vars) {
156 std::size_t n = outputSize(it);
157 if (n > 1 && numEntries > 1 && n != numEntries) {
158 throw std::runtime_error("Trying to loop over variables with different sizes!");
159 }
160 numEntries = std::max(n, numEntries);
161 }
162
163 // Save the current size of the code array so that we can insert the code at the right position.
164 _scopePtr = _code.size();
165
166 // Make sure that the name of this variable doesn't clash with other stuff
167 addToCodeBody(in, "for(int " + idx + " = 0; " + idx + " < " + std::to_string(numEntries) + "; " + idx + "++) {\n");
168
169 ++_loopLevel;
170 return std::make_unique<LoopScope>(*this, std::move(vars));
171}
172
174{
175 _code += "}\n";
176
177 // Insert the temporary code into the correct code position.
178 _code.insert(_scopePtr, _tempScope);
179 _tempScope.erase();
180 _scopePtr = -1;
181
182 // clear the results of the loop variables if they were vector observables
183 for (auto const &ptr : scope.vars()) {
184 if (_vecObsIndices.find(ptr) != _vecObsIndices.end())
185 _nodeNames.erase(ptr);
186 }
187 --_loopLevel;
188}
189
190/// @brief Get a unique variable name to be used in the generated code.
192{
193 return "t" + std::to_string(_tmpVarIdx++);
194}
195
196/// @brief A function to save an expression that includes/depends on the result of the input node.
197/// @param in The node on which the valueToSave depends on/belongs to.
198/// @param valueToSave The actual string value to save as a temporary.
199void CodeSquashContext::addResult(RooAbsArg const *in, std::string const &valueToSave)
200{
201 // std::string savedName = RooFit::Detail::makeValidVarName(in->GetName());
202 std::string savedName = getTmpVarName();
203
204 // Only save values if they contain operations.
205 bool hasOperations = valueToSave.find_first_of(":-+/*") != std::string::npos;
206
207 // If the name is not empty and this value is worth saving, save it to the correct scope.
208 // otherwise, just return the actual value itself
209 if (hasOperations) {
210 // If this is a scalar result, it will go just outside the loop because
211 // it doesn't need to be recomputed inside loops.
212 std::string outVarDecl = "const double " + savedName + " = " + valueToSave + ";\n";
213 addToCodeBody(in, outVarDecl);
214 } else {
215 savedName = valueToSave;
216 }
217
218 addResult(in->namePtr(), savedName);
219}
220
221/// @brief Function to save a RooListProxy as an array in the squashed code.
222/// @param in The list to convert to array.
223/// @return Name of the array that stores the input list in the squashed code.
225{
226 if (in.empty()) {
227 return "nullptr";
228 }
229
230 auto it = listNames.find(in.uniqueId().value());
231 if (it != listNames.end())
232 return it->second;
233
234 std::string savedName = getTmpVarName();
235 bool canSaveOutside = true;
236
237 std::stringstream declStrm;
238 declStrm << "double " << savedName << "[] = {";
239 for (const auto arg : in) {
240 declStrm << getResult(*arg) << ",";
241 canSaveOutside = canSaveOutside && isScopeIndependent(arg);
242 }
243 declStrm.seekp(-1, declStrm.cur);
244 declStrm << "};\n";
245
246 addToCodeBody(declStrm.str(), canSaveOutside);
247
248 listNames.insert({in.uniqueId().value(), savedName});
249 return savedName;
250}
251
252std::string CodeSquashContext::buildArg(std::span<const double> arr)
253{
254 unsigned int n = arr.size();
255 std::string offset = std::to_string(_xlArr.size());
256 _xlArr.reserve(_xlArr.size() + n);
257 for (unsigned int i = 0; i < n; i++) {
258 _xlArr.push_back(arr[i]);
259 }
260 return "xlArr + " + offset;
261}
262
264{
265 return !in->isReducerNode() && outputSize(in->namePtr()) == 1;
266}
267
268/// @brief Register a function that is only know to the interpreter to the context.
269/// This is useful to dump the standalone C++ code for the computation graph.
271{
273}
274
275} // namespace Detail
276} // namespace RooFit
Option_t Option_t TPoint TPoint const char GetTextMagnitude GetFillStyle GetLineColor GetLineWidth GetMarkerStyle GetTextAlign GetTextColor GetTextSize void char Point_t Rectangle_t WindowAttributes_t Float_t Float_t Float_t Int_t Int_t UInt_t UInt_t Rectangle_t Int_t Int_t Window_t TString Int_t GCValues_t GetPrimarySelectionOwner GetDisplay GetScreen GetColormap GetNativeEvent const char const char dpyName wid window const char font_name cursor keysym reg const char only_if_exist regb h Point_t winding char text const char depth char const char Int_t count const char ColorStruct_t color const char Pixmap_t Pixmap_t PictureAttributes_t attr const char char ret_data h unsigned char height h offset
Option_t Option_t TPoint TPoint const char GetTextMagnitude GetFillStyle GetLineColor GetLineWidth GetMarkerStyle GetTextAlign GetTextColor GetTextSize void value
char name[80]
Definition TGX11.cxx:110
Common abstract base class for objects that represent a value and a "shape" in RooFit.
Definition RooAbsArg.h:79
bool dependsOn(const RooAbsCollection &serverList, const RooAbsArg *ignoreArg=nullptr, bool valueOnly=false) const
Test whether we depend on (ie, are served by) any object in the specified collection.
const TNamed * namePtr() const
De-duplicated pointer to this object's name.
Definition RooAbsArg.h:535
virtual void translate(RooFit::Detail::CodeSquashContext &ctx) const
This function defines a translation for each RooAbsReal based object that can be used to express the ...
virtual bool isReducerNode() const
Definition RooAbsArg.h:549
Abstract container object that can hold multiple RooAbsArg objects.
RooFit::UniqueId< RooAbsCollection > const & uniqueId() const
Returns a unique ID that is different for every instantiated RooAbsCollection.
A class to manage loop scopes using the RAII technique.
std::vector< TNamed const * > const & vars() const
std::string assembleCode(std::string const &returnExpr)
Assemble and return the final code with the return expression and global statements.
std::string _tempScope
Stores code that eventually gets injected into main code body.
void addResult(RooAbsArg const *key, std::string const &value)
A function to save an expression that includes/depends on the result of the input node.
void endLoop(LoopScope const &scope)
std::unordered_map< const TNamed *, int > _vecObsIndices
A map to keep track of the observable indices if they are non scalar.
int _loopLevel
The current number of for loops that have been started.
int _tmpVarIdx
Index to get unique names for temporary variables.
std::size_t outputSize(RooFit::Detail::DataKey key) const
Figure out the output size of a node.
std::unordered_map< const TNamed *, std::string > _nodeNames
Map of node names to their result strings.
void addToCodeBody(RooAbsArg const *klass, std::string const &in)
Adds the input string to the squashed code body.
void addVecObs(const char *key, int idx)
Since the squashed code represents all observables as a single flattened array, it is important to ke...
bool isScopeIndependent(RooAbsArg const *in) const
std::string getTmpVarName() const
Get a unique variable name to be used in the generated code.
std::string const & getResult(RooAbsArg const &arg)
Gets the result for the given node using the node name.
std::string _code
Stores the squashed code body.
void addToGlobalScope(std::string const &str)
Adds the given string to the string block that will be emitted at the top of the squashed function.
std::unordered_map< RooFit::UniqueId< RooAbsCollection >::Value_t, std::string > listNames
A map to keep track of list names as assigned by buildArg.
std::string buildArg(RooAbsCollection const &x)
Function to save a RooListProxy as an array in the squashed code.
Experimental::RooFuncWrapper * _wrapper
void collectFunction(std::string const &name)
Register with the context a function that is only known to the interpreter.
CodeSquashContext(std::map< RooFit::Detail::DataKey, std::size_t > const &outputSizes, std::vector< double > &xlarr, Experimental::RooFuncWrapper &wrapper)
int _scopePtr
Keeps track of the position to go back and insert code to.
std::string _globalScope
Block of code that is placed before the rest of the function body.
std::unique_ptr< LoopScope > beginLoop(RooAbsArg const *in)
Create a RAII scope for iterating over vector observables.
A wrapper class to store a C++ function of type 'double (*)(double*, double*)'.
void collectFunction(std::string const &funcName)
static const TNamed * known(const char *stringPtr)
If the name is already known, return its TNamed pointer. Otherwise return 0 (don't register the name)...
The TNamed class is the base class for all named ROOT classes.
Definition TNamed.h:29
const Int_t n
Definition legend1.C:16
The namespace RooFit contains mostly switches that change the behaviour of functions of PDFs (or othe...
Definition JSONIO.h:26
constexpr Value_t value() const
Return numerical value of ID.
Definition UniqueId.h:59