Logo ROOT  
Reference Guide
 
Loading...
Searching...
No Matches
RReader.hxx
Go to the documentation of this file.
1#ifndef TMVA_RREADER
2#define TMVA_RREADER
3
4#include "TString.h"
5#include "TXMLEngine.h"
6
7#include "TMVA/RTensor.hxx"
8#include "TMVA/Reader.h"
9
10#include <memory> // std::unique_ptr
11#include <sstream> // std::stringstream
12
13namespace TMVA {
14namespace Experimental {
15
16namespace Internal {
17
/// Internal definition of analysis types.
/// `Undefined` is the state of a freshly constructed XMLConfig; the parser
/// treats a config that is still `Undefined` after reading as an error.
/// NOTE(review): this enumerator list was reconstructed from its uses in this
/// header (the original declaration line was lost) — confirm against upstream.
enum AnalysisType : unsigned int { Undefined = 0, Classification, Regression, Multiclass };

/// Container for information extracted from TMVA XML config
struct XMLConfig {
   unsigned int numVariables;            ///< Number of input variables declared in the config
   std::vector<std::string> variables;   ///< Variable titles, indexed by variable index
   std::vector<std::string> expressions; ///< Variable expressions, indexed by variable index
   unsigned int numClasses;              ///< Number of output classes declared in the config
   std::vector<std::string> classes;     ///< Output class names, in file order
   AnalysisType analysisType;            ///< Analysis type declared in the config

   /// Create an empty config: zero counts, empty lists, undefined analysis type.
   XMLConfig() : numVariables(0), numClasses(0), analysisType(AnalysisType::Undefined) {}
};
35
36/// Parse TMVA XML config
37inline XMLConfig ParseXMLConfig(const std::string &filename)
38{
40
41 // Parse XML file and find root node
42 TXMLEngine xml;
43 auto xmldoc = xml.ParseFile(filename.c_str());
44 if (!xmldoc) {
45 std::stringstream ss;
46 ss << "Failed to open TMVA XML file "
47 << filename << ".";
48 throw std::runtime_error(ss.str());
49 }
50 auto mainNode = xml.DocGetRootElement(xmldoc);
51 for (auto node = xml.GetChild(mainNode); node; node = xml.GetNext(node)) {
52 const auto nodeName = std::string(xml.GetNodeName(node));
53 // Read out input variables
54 if (nodeName.compare("Variables") == 0) {
55 c.numVariables = std::atoi(xml.GetAttr(node, "NVar"));
56 c.variables = std::vector<std::string>(c.numVariables);
57 c.expressions = std::vector<std::string>(c.numVariables);
58 for (auto thisNode = xml.GetChild(node); thisNode; thisNode = xml.GetNext(thisNode)) {
59 const auto iVariable = std::atoi(xml.GetAttr(thisNode, "VarIndex"));
60 c.variables[iVariable] = xml.GetAttr(thisNode, "Title");
61 c.expressions[iVariable] = xml.GetAttr(thisNode, "Expression");
62 }
63 }
64 // Read out output classes
65 else if (nodeName.compare("Classes") == 0) {
66 c.numClasses = std::atoi(xml.GetAttr(node, "NClass"));
67 for (auto thisNode = xml.GetChild(node); thisNode; thisNode = xml.GetNext(thisNode)) {
68 c.classes.push_back(xml.GetAttr(thisNode, "Name"));
69 }
70 }
71 // Read out analysis type
72 else if (nodeName.compare("GeneralInfo") == 0) {
73 std::string analysisType = "";
74 for (auto thisNode = xml.GetChild(node); thisNode; thisNode = xml.GetNext(thisNode)) {
75 if (std::string("AnalysisType").compare(xml.GetAttr(thisNode, "name")) == 0) {
76 analysisType = xml.GetAttr(thisNode, "value");
77 }
78 }
79 if (analysisType.compare("Classification") == 0) {
81 } else if (analysisType.compare("Regression") == 0) {
83 } else if (analysisType.compare("Multiclass") == 0) {
85 }
86 }
87 }
88 xml.FreeDoc(xmldoc);
89
90 // Error-handling
91 if (c.numVariables != c.variables.size() || c.numVariables == 0) {
92 std::stringstream ss;
93 ss << "Failed to parse input variables from TMVA config " << filename << ".";
94 throw std::runtime_error(ss.str());
95 }
96 if (c.numClasses != c.classes.size() || c.numClasses == 0) {
97 std::stringstream ss;
98 ss << "Failed to parse output classes from TMVA config " << filename << ".";
99 throw std::runtime_error(ss.str());
100 }
101 if (c.analysisType == Internal::AnalysisType::Undefined) {
102 std::stringstream ss;
103 ss << "Failed to parse analysis type from TMVA config " << filename << ".";
104 throw std::runtime_error(ss.str());
105 }
106
107 return c;
108}
109
110} // namespace Internal
111
112/// A replacement for the TMVA::Reader legacy interface.
113/// Performs inference for TMVA models stored as XML files.
114/// For neural network inference consider using [SOFIE](https://github.com/root-project/root/blob/master/tmva/sofie/README.md) instead.
115class RReader {
116private:
117 std::unique_ptr<Reader> fReader;
118 std::vector<float> fValues;
119 std::vector<std::string> fVariables;
120 std::vector<std::string> fExpressions;
121 unsigned int fNumClasses;
122 const char *name = "RReader";
124
125public:
126 /// Create TMVA model from XML file
127 RReader(const std::string &path)
128 {
129 // Load config
130 auto c = Internal::ParseXMLConfig(path);
131 fVariables = c.variables;
132 fExpressions = c.expressions;
133 fAnalysisType = c.analysisType;
134 fNumClasses = c.numClasses;
135
136 // Setup reader
137 fReader = std::make_unique<Reader>("Silent");
138 const auto numVars = fVariables.size();
139 fValues = std::vector<float>(numVars);
140 for (std::size_t i = 0; i < numVars; i++) {
141 fReader->AddVariable(TString(fExpressions[i]), &fValues[i]);
142 }
143 fReader->BookMVA(name, path.c_str());
144 }
145
146 /// Compute model prediction on vector
147 std::vector<float> Compute(const std::vector<float> &x)
148 {
149 if (x.size() != fVariables.size())
150 throw std::runtime_error("Size of input vector is not equal to number of variables.");
151
152 // Copy over inputs to memory used by TMVA reader
153 for (std::size_t i = 0; i < x.size(); i++) {
154 fValues[i] = x[i];
155 }
156
157 // Take lock to protect model evaluation
159
160 // Evaluate TMVA model
161 // Classification
163 return std::vector<float>({static_cast<float>(fReader->EvaluateMVA(name))});
164 }
165 // Regression
167 return fReader->EvaluateRegression(name);
168 }
169 // Multiclass
171 return fReader->EvaluateMulticlass(name);
172 }
173 // Throw error
174 else {
175 throw std::runtime_error("RReader has undefined analysis type.");
176 return std::vector<float>();
177 }
178 }
179
180 /// Compute model prediction on input RTensor
182 {
183 // Error-handling for input tensor
184 const auto shape = x.GetShape();
185 if (shape.size() != 2)
186 throw std::runtime_error("Can only compute model outputs for input tensor of rank 2.");
187
188 const auto numEntries = shape[0];
189 const auto numVars = shape[1];
190 if (numVars != fVariables.size())
191 throw std::runtime_error("Second dimension of input tensor is not equal to number of variables.");
192
193 // Define shape of output tensor based on analysis type
194 unsigned int numClasses = 1;
196 numClasses = fNumClasses;
197 RTensor<float> y({numEntries * numClasses});
199 y = y.Reshape({numEntries, numClasses});
200
201 // Fill output tensor
202 for (std::size_t i = 0; i < numEntries; i++) {
203 for (std::size_t j = 0; j < numVars; j++) {
204 fValues[j] = x(i, j);
205 }
207 // Classification
209 y(i) = fReader->EvaluateMVA(name);
210 }
211 // Regression
213 y(i) = fReader->EvaluateRegression(name)[0];
214 }
215 // Multiclass
217 const auto p = fReader->EvaluateMulticlass(name);
218 for (std::size_t k = 0; k < numClasses; k++)
219 y(i, k) = p[k];
220 }
221 }
222
223 return y;
224 }
225
226 std::vector<std::string> GetVariableNames() { return fVariables; }
227};
228
229} // namespace Experimental
230} // namespace TMVA
231
232#endif // TMVA_RREADER
#define c(i)
Definition RSha256.hxx:101
winID h TVirtualViewer3D TVirtualGLPainter p
Option_t Option_t TPoint TPoint const char GetTextMagnitude GetFillStyle GetLineColor GetLineWidth GetMarkerStyle GetTextAlign GetTextColor GetTextSize void char Point_t Rectangle_t WindowAttributes_t Float_t Float_t Float_t Int_t Int_t UInt_t UInt_t Rectangle_t Int_t Int_t Window_t TString Int_t GCValues_t GetPrimarySelectionOwner GetDisplay GetScreen GetColormap GetNativeEvent const char const char dpyName wid window const char font_name cursor keysym reg const char only_if_exist regb h Point_t winding char text const char depth char const char Int_t count const char ColorStruct_t color const char filename
#define R__WRITE_LOCKGUARD(mutex)
A replacement for the TMVA::Reader legacy interface.
Definition RReader.hxx:115
RTensor< float > Compute(RTensor< float > &x)
Compute model prediction on input RTensor.
Definition RReader.hxx:181
std::vector< float > fValues
Definition RReader.hxx:118
std::vector< float > Compute(const std::vector< float > &x)
Compute model prediction on vector.
Definition RReader.hxx:147
Internal::AnalysisType fAnalysisType
Definition RReader.hxx:123
std::vector< std::string > fExpressions
Definition RReader.hxx:120
std::vector< std::string > GetVariableNames()
Definition RReader.hxx:226
std::vector< std::string > fVariables
Definition RReader.hxx:119
RReader(const std::string &path)
Create TMVA model from XML file.
Definition RReader.hxx:127
std::unique_ptr< Reader > fReader
Definition RReader.hxx:117
RTensor is a container with contiguous memory and shape information.
Definition RTensor.hxx:162
RTensor< Value_t, Container_t > Reshape(const Shape_t &shape) const
Reshape tensor.
Definition RTensor.hxx:481
Basic string class.
Definition TString.h:139
XMLNodePointer_t GetChild(XMLNodePointer_t xmlnode, Bool_t realnode=kTRUE)
returns first child of xmlnode
void FreeDoc(XMLDocPointer_t xmldoc)
frees allocated document data and deletes document itself
XMLNodePointer_t DocGetRootElement(XMLDocPointer_t xmldoc)
returns root node of document
const char * GetNodeName(XMLNodePointer_t xmlnode)
returns name of xmlnode
const char * GetAttr(XMLNodePointer_t xmlnode, const char *name)
returns value of attribute for xmlnode
XMLDocPointer_t ParseFile(const char *filename, Int_t maxbuf=100000)
Parses content of file and tries to produce xml structures.
XMLNodePointer_t GetNext(XMLNodePointer_t xmlnode, Bool_t realnode=kTRUE)
return next to xmlnode node if realnode==kTRUE, any special nodes in between will be skipped
Double_t y[n]
Definition legend1.C:17
Double_t x[n]
Definition legend1.C:17
R__EXTERN TVirtualRWMutex * gCoreMutex
XMLConfig ParseXMLConfig(const std::string &filename)
Parse TMVA XML config.
Definition RReader.hxx:37
AnalysisType
Internal definition of analysis types.
Definition RReader.hxx:19
create variable transformations
Container for information extracted from TMVA XML config.
Definition RReader.hxx:22
std::vector< std::string > classes
Definition RReader.hxx:27
std::vector< std::string > expressions
Definition RReader.hxx:25
std::vector< std::string > variables
Definition RReader.hxx:24