Logo ROOT  
Reference Guide
 
Loading...
Searching...
No Matches
CodeSquashContext.cxx
Go to the documentation of this file.
1/*
2 * Project: RooFit
3 * Authors:
4 * Garima Singh, CERN 2023
5 * Jonas Rembser, CERN 2023
6 *
7 * Copyright (c) 2023, CERN
8 *
9 * Redistribution and use in source and binary forms,
10 * with or without modification, are permitted according to the terms
11 * listed in LICENSE (http://roofit.sourceforge.net/license.txt)
12 */
13
15
16#include "RooFitImplHelpers.h"
17
18#include <algorithm>
19#include <cctype>
20
21namespace RooFit {
22
23namespace Detail {
24
25/// @brief Adds (or overwrites) the string representing the result of a node.
26/// @param key The name of the node to add the result for.
27/// @param value The new name to assign/overwrite.
28void CodeSquashContext::addResult(const char *key, std::string const &value)
29{
30 const TNamed *namePtr = RooNameReg::known(key);
31 if (namePtr)
32 addResult(namePtr, value);
33}
34
35void CodeSquashContext::addResult(TNamed const *key, std::string const &value)
36{
37 _nodeNames[key] = value;
38}
39
40/// @brief Gets the result for the given node using the node name. This node also performs the necessary
41/// code generation through recursive calls to 'translate'. A call to this function modifies the already
42/// existing code body.
43/// @param key The node to get the result string for.
44/// @return String representing the result of this node.
45std::string const &CodeSquashContext::getResult(RooAbsArg const &arg)
46{
47 // If the result has already been recorded, just return the result.
48 // It is usually the responsibility of each translate function to assign
49 // the proper result to its class. Hence, if a result has already been recorded
50 // for a particular node, it means the node has already been 'translate'd and we
51 // dont need to visit it again.
52 auto found = _nodeNames.find(arg.namePtr());
53 if (found != _nodeNames.end())
54 return found->second;
55
56 // The result for vector observables should already be in the map if you
57 // opened the loop scope. This is just to check if we did not request the
58 // result of a vector-valued observable outside of the scope of a loop.
59 auto foundVecObs = _vecObsIndices.find(arg.namePtr());
60 if (foundVecObs != _vecObsIndices.end()) {
61 throw std::runtime_error("You requested the result of a vector observable outside a loop scope for it!");
62 }
63
64 // Now, recursively call translate into the current argument to load the correct result.
65 arg.translate(*this);
66
67 return _nodeNames.at(arg.namePtr());
68}
69
70/// @brief Adds the given string to the string block that will be emitted at the top of the squashed function. Useful
71/// for variable declarations.
72/// @param str The string to add to the global scope.
73void CodeSquashContext::addToGlobalScope(std::string const &str)
74{
75 _globalScope += str;
76}
77
78/// @brief Assemble and return the final code with the return expression and global statements.
79/// @param returnExpr The string representation of what the squashed function should return, usually the head node.
80/// @return The final body of the function.
81std::string CodeSquashContext::assembleCode(std::string const &returnExpr)
82{
83 return _globalScope + _code + "\n return " + returnExpr + ";\n";
84}
85
86/// @brief Since the squashed code represents all observables as a single flattened array, it is important
87/// to keep track of the start index for a vector valued observable which can later be expanded to access the correct
88/// element. For example, a vector valued variable x with 10 entries will be squashed to obs[start_idx + i].
89/// @param key The name of the node representing the vector valued observable.
90/// @param idx The start index (or relative position of the observable in the set of all observables).
91void CodeSquashContext::addVecObs(const char *key, int idx)
92{
93 const TNamed *namePtr = RooNameReg::known(key);
94 if (namePtr)
95 _vecObsIndices[namePtr] = idx;
96}
97
98/// @brief Adds the input string to the squashed code body. If a class implements a translate function that wants to
99/// emit something to the squashed code body, it must call this function with the code it wants to emit. In case of
100/// loops, automatically determines if code needs to be stored inside or outside loop scope.
101/// @param klass The class requesting this addition, usually 'this'.
102/// @param in String to add to the squashed code.
103void CodeSquashContext::addToCodeBody(RooAbsArg const *klass, std::string const &in)
104{
105 // If we are in a loop and the value is scope independent, save it at the top of the loop.
106 // else, just save it in the current scope.
108}
109
110/// @brief A variation of the previous addToCodeBody that takes in a bool value that determines
111/// if input is independent. This overload exists because there might other ways to determine if
112/// a value/collection of values is scope independent.
113/// @param in String to add to the squashed code.
114/// @param isScopeIndep The value determining if the input is scope dependent.
115void CodeSquashContext::addToCodeBody(std::string const &in, bool isScopeIndep /* = false */)
116{
117 // If we are in a loop and the value is scope independent, save it at the top of the loop.
118 // else, just save it in the current scope.
119 if (_scopePtr != -1 && isScopeIndep)
120 _tempScope += in;
121 else
122 _code += in;
123}
124
125/// @brief Create a RAII scope for iterating over vector observables. You can't use the result of vector observables
126/// outside these loop scopes.
127/// @param in A pointer to the calling class, used to determine the loop dependent variables.
128std::unique_ptr<CodeSquashContext::LoopScope> CodeSquashContext::beginLoop(RooAbsArg const *in)
129{
130 std::string idx = "loopIdx" + std::to_string(_loopLevel);
131
132 std::vector<TNamed const *> vars;
133 // set the results of the vector observables
134 for (auto const &it : _vecObsIndices) {
135 if (!in->dependsOn(it.first))
136 continue;
137
138 vars.push_back(it.first);
139 _nodeNames[it.first] = "obs[" + std::to_string(it.second) + " + " + idx + "]";
140 }
141
142 // TODO: we are using the size of the first loop variable to the the number
143 // of iterations, but it should be made sure that all loop vars are either
144 // scalar or have the same size.
145 std::size_t numEntries = 1;
146 for (auto &it : vars) {
147 std::size_t n = outputSize(it);
148 if (n > 1 && numEntries > 1 && n != numEntries) {
149 throw std::runtime_error("Trying to loop over variables with different sizes!");
150 }
151 numEntries = std::max(n, numEntries);
152 }
153
154 // Save the current size of the code array so that we can insert the code at the right position.
155 _scopePtr = _code.size();
156
157 // Make sure that the name of this variable doesn't clash with other stuff
158 addToCodeBody(in, "for(int " + idx + " = 0; " + idx + " < " + std::to_string(numEntries) + "; " + idx + "++) {\n");
159
160 ++_loopLevel;
161 return std::make_unique<LoopScope>(*this, std::move(vars));
162}
163
165{
166 _code += "}\n";
167
168 // Insert the temporary code into the correct code position.
169 _code.insert(_scopePtr, _tempScope);
170 _tempScope.erase();
171 _scopePtr = -1;
172
173 // clear the results of the loop variables if they were vector observables
174 for (auto const &ptr : scope.vars()) {
175 if (_vecObsIndices.find(ptr) != _vecObsIndices.end())
176 _nodeNames.erase(ptr);
177 }
178 --_loopLevel;
179}
180
181/// @brief Get a unique variable name to be used in the generated code.
183{
184 return "tmpVar" + std::to_string(_tmpVarIdx++);
185}
186
187/// @brief A function to save an expression that includes/depends on the result of the input node.
188/// @param in The node on which the valueToSave depends on/belongs to.
189/// @param valueToSave The actual string value to save as a temporary.
190void CodeSquashContext::addResult(RooAbsArg const *in, std::string const &valueToSave)
191{
192 std::string savedName = RooFit::Detail::makeValidVarName(in->GetName());
193
194 // Only save values if they contain operations.
195 bool hasOperations = valueToSave.find_first_of(":-+/*") != std::string::npos;
196
197 // If the name is not empty and this value is worth saving, save it to the correct scope.
198 // otherwise, just return the actual value itself
199 if (hasOperations) {
200 // If this is a scalar result, it will go just outside the loop because
201 // it doesn't need to be recomputed inside loops.
202 std::string outVarDecl = "const double " + savedName + " = " + valueToSave + ";\n";
203 addToCodeBody(in, outVarDecl);
204 } else {
205 savedName = valueToSave;
206 }
207
208 addResult(in->namePtr(), savedName);
209}
210
211/// @brief Function to save a RooListProxy as an array in the squashed code.
212/// @param in The list to convert to array.
213/// @return Name of the array that stores the input list in the squashed code.
215{
216 auto it = listNames.find(in.uniqueId().value());
217 if (it != listNames.end())
218 return it->second;
219
220 std::string savedName = getTmpVarName();
221 bool canSaveOutside = true;
222
223 std::stringstream declStrm;
224 declStrm << "double " << savedName << "[] = {";
225 for (const auto arg : in) {
226 declStrm << getResult(*arg) << ",";
227 canSaveOutside = canSaveOutside && isScopeIndependent(arg);
228 }
229 declStrm.seekp(-1, declStrm.cur);
230 declStrm << "};\n";
231
232 addToCodeBody(declStrm.str(), canSaveOutside);
233
234 listNames.insert({in.uniqueId().value(), savedName});
235 return savedName;
236}
237
238std::string CodeSquashContext::buildArg(std::span<const double> arr)
239{
240 unsigned int n = arr.size();
241 std::string arrName = getTmpVarName();
242 std::string arrDecl = "double " + arrName + "[" + std::to_string(n) + "] = {";
243 for (unsigned int i = 0; i < n; i++) {
244 arrDecl += " " + std::to_string(arr[i]) + ",";
245 }
246 arrDecl.back() = '}';
247 arrDecl += ";\n";
248 addToCodeBody(arrDecl, true);
249
250 return arrName;
251}
252
254{
255 return !in->isReducerNode() && outputSize(in->namePtr()) == 1;
256}
257
258} // namespace Detail
259} // namespace RooFit
Option_t Option_t TPoint TPoint const char GetTextMagnitude GetFillStyle GetLineColor GetLineWidth GetMarkerStyle GetTextAlign GetTextColor GetTextSize void value
Common abstract base class for objects that represent a value and a "shape" in RooFit.
Definition RooAbsArg.h:79
bool dependsOn(const RooAbsCollection &serverList, const RooAbsArg *ignoreArg=nullptr, bool valueOnly=false) const
Test whether we depend on (ie, are served by) any object in the specified collection.
const TNamed * namePtr() const
De-duplicated pointer to this object's name.
Definition RooAbsArg.h:563
virtual void translate(RooFit::Detail::CodeSquashContext &ctx) const
This function defines a translation for each RooAbsReal based object that can be used to express the ...
virtual bool isReducerNode() const
Definition RooAbsArg.h:577
Abstract container object that can hold multiple RooAbsArg objects.
RooFit::UniqueId< RooAbsCollection > const & uniqueId() const
Returns a unique ID that is different for every instantiated RooAbsCollection.
A class to manage loop scopes using the RAII technique.
std::vector< TNamed const * > const & vars() const
std::string assembleCode(std::string const &returnExpr)
Assemble and return the final code with the return expression and global statements.
std::string _tempScope
Stores code that eventually gets injected into main code body.
void addResult(RooAbsArg const *key, std::string const &value)
A function to save an expression that includes/depends on the result of the input node.
void endLoop(LoopScope const &scope)
std::unordered_map< const TNamed *, int > _vecObsIndices
A map to keep track of the observable indices if they are non scalar.
int _loopLevel
The current number of for loops that have been started.
int _tmpVarIdx
Index to get unique names for temporary variables.
std::size_t outputSize(RooFit::Detail::DataKey key) const
Figure out the output size of a node.
std::unordered_map< const TNamed *, std::string > _nodeNames
Map of node names to their result strings.
void addToCodeBody(RooAbsArg const *klass, std::string const &in)
Adds the input string to the squashed code body.
void addVecObs(const char *key, int idx)
Since the squashed code represents all observables as a single flattened array, it is important to ke...
bool isScopeIndependent(RooAbsArg const *in) const
std::string const & getResult(RooAbsArg const &arg)
Gets the result for the given node using the node name.
std::string getTmpVarName()
Get a unique variable name to be used in the generated code.
std::string _code
Stores the squashed code body.
void addToGlobalScope(std::string const &str)
Adds the given string to the string block that will be emitted at the top of the squashed function.
std::unordered_map< RooFit::UniqueId< RooAbsCollection >::Value_t, std::string > listNames
A map to keep track of list names as assigned by addResult.
std::string buildArg(RooAbsCollection const &x)
Function to save a RooListProxy as an array in the squashed code.
int _scopePtr
Keeps track of the position to go back and insert code to.
std::string _globalScope
Block of code that is placed before the rest of the function body.
std::unique_ptr< LoopScope > beginLoop(RooAbsArg const *in)
Create a RAII scope for iterating over vector observables.
static const TNamed * known(const char *stringPtr)
If the name is already known, return its TNamed pointer. Otherwise return 0 (don't register the name)...
The TNamed class is the base class for all named ROOT classes.
Definition TNamed.h:29
const char * GetName() const override
Returns name of object.
Definition TNamed.h:47
const Int_t n
Definition legend1.C:16
std::string makeValidVarName(std::string const &in)
The namespace RooFit contains mostly switches that change the behaviour of functions of PDFs (or othe...
Definition JSONIO.h:26
constexpr Value_t value() const
Return numerical value of ID.
Definition UniqueId.h:59