Logo ROOT  
Reference Guide
 
Loading...
Searching...
No Matches
RReader.hxx
Go to the documentation of this file.
1#ifndef TMVA_RREADER
2#define TMVA_RREADER
3
4#include "TString.h"
5#include "TXMLEngine.h"
6
7#include "TMVA/RTensor.hxx"
8#include "TMVA/Reader.h"
9
10#include <memory> // std::unique_ptr
11#include <sstream> // std::stringstream
12
13namespace TMVA {
14namespace Experimental {
15
16namespace Internal {
17
/// Internal definition of analysis types.
/// Mirrors the `AnalysisType` value stored in the `GeneralInfo` section of a
/// TMVA XML weight file; `Undefined` marks a config whose analysis type was
/// missing or not recognised.
enum AnalysisType : unsigned int { Undefined = 0, Classification, Regression, Multiclass };

21/// Container for information extracted from TMVA XML config
22struct XMLConfig {
23 unsigned int numVariables;
24 std::vector<std::string> variables;
25 std::vector<std::string> variable_expressions;
26 unsigned int numSpectators;
27 std::vector<std::string> spectators;
28 std::vector<std::string> spectator_expressions;
29 unsigned int numClasses;
30 std::vector<std::string> classes;
33 : numVariables(0), variables(std::vector<std::string>(0)),
34 numSpectators(0), spectators(std::vector<std::string>(0)),
35 numClasses(0), classes(std::vector<std::string>(0)),
37 {
38 }
39};
40
41/// Parse TMVA XML config
42inline XMLConfig ParseXMLConfig(const std::string &filename)
43{
45
46 // Parse XML file and find root node
47 TXMLEngine xml;
48 auto xmldoc = xml.ParseFile(filename.c_str());
49 if (!xmldoc) {
50 std::stringstream ss;
51 ss << "Failed to open TMVA XML file "
52 << filename << ".";
53 throw std::runtime_error(ss.str());
54 }
55 auto mainNode = xml.DocGetRootElement(xmldoc);
56 for (auto node = xml.GetChild(mainNode); node; node = xml.GetNext(node)) {
57 const auto nodeName = std::string(xml.GetNodeName(node));
58 // Read out input variables
59 if (nodeName.compare("Variables") == 0) {
60 c.numVariables = std::atoi(xml.GetAttr(node, "NVar"));
61 c.variables = std::vector<std::string>(c.numVariables);
62 c.variable_expressions = std::vector<std::string>(c.numVariables);
63 for (auto thisNode = xml.GetChild(node); thisNode; thisNode = xml.GetNext(thisNode)) {
64 const auto iVariable = std::atoi(xml.GetAttr(thisNode, "VarIndex"));
65 c.variables[iVariable] = xml.GetAttr(thisNode, "Title");
66 c.variable_expressions[iVariable] = xml.GetAttr(thisNode, "Expression");
67 }
68 }
69 // Read out input spectators
70 else if (nodeName.compare("Spectators") == 0) {
71 c.numSpectators = std::atoi(xml.GetAttr(node, "NSpec"));
72 c.spectators = std::vector<std::string>(c.numSpectators);
73 c.spectator_expressions = std::vector<std::string>(c.numSpectators);
74 for (auto thisNode = xml.GetChild(node); thisNode; thisNode = xml.GetNext(thisNode)) {
75 const auto iVariable = std::atoi(xml.GetAttr(thisNode, "SpecIndex"));
76 c.spectators[iVariable] = xml.GetAttr(thisNode, "Title");
77 c.spectator_expressions[iVariable] = xml.GetAttr(thisNode, "Expression");
78 }
79 }
80 // Read out output classes
81 else if (nodeName.compare("Classes") == 0) {
82 c.numClasses = std::atoi(xml.GetAttr(node, "NClass"));
83 for (auto thisNode = xml.GetChild(node); thisNode; thisNode = xml.GetNext(thisNode)) {
84 c.classes.push_back(xml.GetAttr(thisNode, "Name"));
85 }
86 }
87 // Read out analysis type
88 else if (nodeName.compare("GeneralInfo") == 0) {
89 std::string analysisType = "";
90 for (auto thisNode = xml.GetChild(node); thisNode; thisNode = xml.GetNext(thisNode)) {
91 if (std::string("AnalysisType").compare(xml.GetAttr(thisNode, "name")) == 0) {
92 analysisType = xml.GetAttr(thisNode, "value");
93 }
94 }
95 if (analysisType.compare("Classification") == 0) {
97 } else if (analysisType.compare("Regression") == 0) {
99 } else if (analysisType.compare("Multiclass") == 0) {
101 }
102 }
103 }
104 xml.FreeDoc(xmldoc);
105
106 // Error-handling
107 if (c.numVariables != c.variables.size() || c.numVariables == 0) {
108 std::stringstream ss;
109 ss << "Failed to parse input variables from TMVA config " << filename << ".";
110 throw std::runtime_error(ss.str());
111 }
112 if (c.numSpectators != c.spectators.size()) {
113 std::stringstream ss;
114 ss << "Failed to parse input spectators from TMVA config " << filename << ".";
115 throw std::runtime_error(ss.str());
116 }
117 if (c.numClasses != c.classes.size() || c.numClasses == 0) {
118 std::stringstream ss;
119 ss << "Failed to parse output classes from TMVA config " << filename << ".";
120 throw std::runtime_error(ss.str());
121 }
122 if (c.analysisType == Internal::AnalysisType::Undefined) {
123 std::stringstream ss;
124 ss << "Failed to parse analysis type from TMVA config " << filename << ".";
125 throw std::runtime_error(ss.str());
126 }
127
128 return c;
129}
130
131} // namespace Internal
132
133/// A replacement for the TMVA::Reader legacy interface.
134/// Performs inference for TMVA models stored as XML files.
135/// For neural network inference consider using [SOFIE](https://github.com/root-project/root/blob/master/tmva/sofie/README.md) instead.
136class RReader {
137private:
138 std::unique_ptr<Reader> fReader;
139 std::vector<float> fVariableValues;
140 std::vector<std::string> fVariables;
141 std::vector<std::string> fVariableExpressions;
142 std::vector<float> fSpectatorValues;
143 std::vector<std::string> fSpectators;
144 std::vector<std::string> fSpectatorExpressions;
145 unsigned int fNumClasses;
146 const char *name = "RReader";
148
149public:
150 /// Create TMVA model from XML file
151 RReader(const std::string &path)
152 {
153 // Load config
154 auto c = Internal::ParseXMLConfig(path);
155 fVariables = c.variables;
156 fVariableExpressions = c.variable_expressions;
157 fSpectators = c.spectators;
158 fSpectatorExpressions = c.spectator_expressions;
159 fAnalysisType = c.analysisType;
160 fNumClasses = c.numClasses;
161
162 // Setup reader
163 fReader = std::make_unique<Reader>("Silent");
164 const auto numVars = fVariables.size();
165 fVariableValues = std::vector<float>(numVars);
166 for (std::size_t i = 0; i < numVars; i++) {
168 }
169 const auto numSpecs = fSpectators.size();
170 fSpectatorValues = std::vector<float>(numSpecs);
171 for (std::size_t i = 0; i < numSpecs; i++) {
173 }
174 fReader->BookMVA(name, path.c_str());
175 }
176
177 /// Compute model prediction on vector
178 std::vector<float> Compute(const std::vector<float> &x)
179 {
180 if (x.size() != (fVariables.size()+fSpectators.size()))
181 throw std::runtime_error("Size of input vector is not equal to number of variables.");
182
183 // Copy over inputs to memory used by TMVA reader
184 const auto nVars = fVariables.size();
185 for (std::size_t i = 0; i != nVars ; ++i) {
186 fVariableValues[i] = x[i];
187 }
188 for (std::size_t i = 0; i != fSpectators.size(); ++i) {
189 fSpectatorValues[i] = x[nVars+i];
190 }
191
192 // Take lock to protect model evaluation
194
195 // Evaluate TMVA model
196 // Classification
198 return std::vector<float>({static_cast<float>(fReader->EvaluateMVA(name))});
199 }
200 // Regression
202 return fReader->EvaluateRegression(name);
203 }
204 // Multiclass
206 return fReader->EvaluateMulticlass(name);
207 }
208 // Throw error
209 else {
210 throw std::runtime_error("RReader has undefined analysis type.");
211 return std::vector<float>();
212 }
213 }
214
215 /// Compute model prediction on input RTensor
217 {
218 // Error-handling for input tensor
219 const auto shape = x.GetShape();
220 if (shape.size() != 2)
221 throw std::runtime_error("Can only compute model outputs for input tensor of rank 2.");
222
223 const auto numEntries = shape[0];
224 const auto numVars = shape[1];
225 if (numVars != (fVariables.size()+fSpectators.size()))
226 throw std::runtime_error("Second dimension of input tensor is not equal to number of variables.");
227
228 // Define shape of output tensor based on analysis type
229 unsigned int numClasses = 1;
231 numClasses = fNumClasses;
232 RTensor<float> y({numEntries * numClasses});
234 y = y.Reshape({numEntries, numClasses});
235
236 // Fill output tensor
237 const auto nVars = fVariables.size(); // number of non-spectator variables
238 for (std::size_t i = 0; i < numEntries; i++) {
239 for (std::size_t j = 0; j < nVars; j++) {
240 fVariableValues[j] = x(i, j);
241 }
242 for (std::size_t j = 0; j < fSpectators.size(); ++j) {
243 fSpectatorValues[j] = x(i, nVars+j);
244 }
246 // Classification
248 y(i) = fReader->EvaluateMVA(name);
249 }
250 // Regression
252 y(i) = fReader->EvaluateRegression(name)[0];
253 }
254 // Multiclass
256 const auto p = fReader->EvaluateMulticlass(name);
257 for (std::size_t k = 0; k < numClasses; k++)
258 y(i, k) = p[k];
259 }
260 }
261
262 return y;
263 }
264
265 std::vector<std::string> GetVariableNames() { return fVariables; }
266 std::vector<std::string> GetSpectatorNames() { return fSpectators; }
267};
268
269} // namespace Experimental
270} // namespace TMVA
271
272#endif // TMVA_RREADER
#define c(i)
Definition RSha256.hxx:101
winID h TVirtualViewer3D TVirtualGLPainter p
Option_t Option_t TPoint TPoint const char GetTextMagnitude GetFillStyle GetLineColor GetLineWidth GetMarkerStyle GetTextAlign GetTextColor GetTextSize void char Point_t Rectangle_t WindowAttributes_t Float_t Float_t Float_t Int_t Int_t UInt_t UInt_t Rectangle_t Int_t Int_t Window_t TString Int_t GCValues_t GetPrimarySelectionOwner GetDisplay GetScreen GetColormap GetNativeEvent const char const char dpyName wid window const char font_name cursor keysym reg const char only_if_exist regb h Point_t winding char text const char depth char const char Int_t count const char ColorStruct_t color const char filename
#define R__WRITE_LOCKGUARD(mutex)
A replacement for the TMVA::Reader legacy interface.
Definition RReader.hxx:136
RTensor< float > Compute(RTensor< float > &x)
Compute model prediction on input RTensor.
Definition RReader.hxx:216
std::vector< float > Compute(const std::vector< float > &x)
Compute model prediction on vector.
Definition RReader.hxx:178
std::vector< std::string > GetSpectatorNames()
Definition RReader.hxx:266
Internal::AnalysisType fAnalysisType
Definition RReader.hxx:147
std::vector< float > fVariableValues
Definition RReader.hxx:139
std::vector< float > fSpectatorValues
Definition RReader.hxx:142
std::vector< std::string > fVariableExpressions
Definition RReader.hxx:141
std::vector< std::string > GetVariableNames()
Definition RReader.hxx:265
std::vector< std::string > fSpectatorExpressions
Definition RReader.hxx:144
std::vector< std::string > fSpectators
Definition RReader.hxx:143
std::vector< std::string > fVariables
Definition RReader.hxx:140
RReader(const std::string &path)
Create TMVA model from XML file.
Definition RReader.hxx:151
std::unique_ptr< Reader > fReader
Definition RReader.hxx:138
RTensor is a container with contiguous memory and shape information.
Definition RTensor.hxx:162
RTensor< Value_t, Container_t > Reshape(const Shape_t &shape) const
Reshape tensor.
Definition RTensor.hxx:480
const Shape_t & GetShape() const
Definition RTensor.hxx:242
Basic string class.
Definition TString.h:139
XMLNodePointer_t GetChild(XMLNodePointer_t xmlnode, Bool_t realnode=kTRUE)
returns first child of xmlnode
void FreeDoc(XMLDocPointer_t xmldoc)
frees allocated document data and deletes document itself
XMLNodePointer_t DocGetRootElement(XMLDocPointer_t xmldoc)
returns root node of document
const char * GetNodeName(XMLNodePointer_t xmlnode)
returns name of xmlnode
const char * GetAttr(XMLNodePointer_t xmlnode, const char *name)
returns value of attribute for xmlnode
XMLDocPointer_t ParseFile(const char *filename, Int_t maxbuf=100000)
Parses content of file and tries to produce xml structures.
XMLNodePointer_t GetNext(XMLNodePointer_t xmlnode, Bool_t realnode=kTRUE)
return next to xmlnode node if realnode==kTRUE, any special nodes in between will be skipped
Double_t y[n]
Definition legend1.C:17
Double_t x[n]
Definition legend1.C:17
R__EXTERN TVirtualRWMutex * gCoreMutex
XMLConfig ParseXMLConfig(const std::string &filename)
Parse TMVA XML config.
Definition RReader.hxx:42
AnalysisType
Internal definition of analysis types.
Definition RReader.hxx:19
create variable transformations
Container for information extracted from TMVA XML config.
Definition RReader.hxx:22
std::vector< std::string > classes
Definition RReader.hxx:30
std::vector< std::string > spectators
Definition RReader.hxx:27
std::vector< std::string > spectator_expressions
Definition RReader.hxx:28
std::vector< std::string > variable_expressions
Definition RReader.hxx:25
std::vector< std::string > variables
Definition RReader.hxx:24