ROOT Reference Guide

RReader.hxx
Go to the documentation of this file.
1#ifndef TMVA_RREADER
2#define TMVA_RREADER
3
4#include "TString.h"
5#include "TXMLEngine.h"
7
8#include "TMVA/RTensor.hxx"
9#include "TMVA/Reader.h"
10
11#include <memory> // std::unique_ptr
12#include <sstream> // std::stringstream
13
14namespace TMVA {
15namespace Experimental {
16
17namespace Internal {
18
/// Internal definition of analysis types
///
/// Undefined marks a config whose <GeneralInfo> section did not provide a
/// recognized AnalysisType; ParseXMLConfig rejects configs left in this state.
enum AnalysisType : unsigned int { Undefined = 0, Classification, Regression, Multiclass };

/// Container for information extracted from TMVA XML config
struct XMLConfig {
   unsigned int numVariables = 0;                       ///< Value of the NVar attribute of <Variables>
   std::vector<std::string> variables;                  ///< Variable titles, indexed by VarIndex
   std::vector<std::string> expressions;                ///< Variable expressions, indexed by VarIndex
   unsigned int numClasses = 0;                         ///< Value of the NClass attribute of <Classes>
   std::vector<std::string> classes;                    ///< Class names in document order
   AnalysisType analysisType = AnalysisType::Undefined; ///< Analysis type read from <GeneralInfo>
};
36
37/// Parse TMVA XML config
38inline XMLConfig ParseXMLConfig(const std::string &filename)
39{
41
42 // Parse XML file and find root node
43 TXMLEngine xml;
44 auto xmldoc = xml.ParseFile(filename.c_str());
45 if (xmldoc == 0) {
46 std::stringstream ss;
47 ss << "Failed to open TMVA XML file "
48 << filename << ".";
49 throw std::runtime_error(ss.str());
50 }
51 auto mainNode = xml.DocGetRootElement(xmldoc);
52 for (auto node = xml.GetChild(mainNode); node; node = xml.GetNext(node)) {
53 const auto nodeName = std::string(xml.GetNodeName(node));
54 // Read out input variables
55 if (nodeName.compare("Variables") == 0) {
56 c.numVariables = std::atoi(xml.GetAttr(node, "NVar"));
57 c.variables = std::vector<std::string>(c.numVariables);
58 c.expressions = std::vector<std::string>(c.numVariables);
59 for (auto thisNode = xml.GetChild(node); thisNode; thisNode = xml.GetNext(thisNode)) {
60 const auto iVariable = std::atoi(xml.GetAttr(thisNode, "VarIndex"));
61 c.variables[iVariable] = xml.GetAttr(thisNode, "Title");
62 c.expressions[iVariable] = xml.GetAttr(thisNode, "Expression");
63 }
64 }
65 // Read out output classes
66 else if (nodeName.compare("Classes") == 0) {
67 c.numClasses = std::atoi(xml.GetAttr(node, "NClass"));
68 for (auto thisNode = xml.GetChild(node); thisNode; thisNode = xml.GetNext(thisNode)) {
69 c.classes.push_back(xml.GetAttr(thisNode, "Name"));
70 }
71 }
72 // Read out analysis type
73 else if (nodeName.compare("GeneralInfo") == 0) {
74 std::string analysisType = "";
75 for (auto thisNode = xml.GetChild(node); thisNode; thisNode = xml.GetNext(thisNode)) {
76 if (std::string("AnalysisType").compare(xml.GetAttr(thisNode, "name")) == 0) {
77 analysisType = xml.GetAttr(thisNode, "value");
78 }
79 }
80 if (analysisType.compare("Classification") == 0) {
82 } else if (analysisType.compare("Regression") == 0) {
84 } else if (analysisType.compare("Multiclass") == 0) {
86 }
87 }
88 }
89 xml.FreeDoc(xmldoc);
90
91 // Error-handling
92 if (c.numVariables != c.variables.size() || c.numVariables == 0) {
93 std::stringstream ss;
94 ss << "Failed to parse input variables from TMVA config " << filename << ".";
95 throw std::runtime_error(ss.str());
96 }
97 if (c.numClasses != c.classes.size() || c.numClasses == 0) {
98 std::stringstream ss;
99 ss << "Failed to parse output classes from TMVA config " << filename << ".";
100 throw std::runtime_error(ss.str());
101 }
102 if (c.analysisType == Internal::AnalysisType::Undefined) {
103 std::stringstream ss;
104 ss << "Failed to parse analysis type from TMVA config " << filename << ".";
105 throw std::runtime_error(ss.str());
106 }
107
108 return c;
109}
110
111} // namespace Internal
112
113/// TMVA::Reader legacy interface
114class RReader {
115private:
116 std::unique_ptr<Reader> fReader;
117 std::vector<float> fValues;
118 std::vector<std::string> fVariables;
119 std::vector<std::string> fExpressions;
120 unsigned int fNumClasses;
121 const char *name = "RReader";
123
124public:
125 /// Create TMVA model from XML file
126 RReader(const std::string &path)
127 {
128 // Load config
129 auto c = Internal::ParseXMLConfig(path);
130 fVariables = c.variables;
131 fExpressions = c.expressions;
132 fAnalysisType = c.analysisType;
133 fNumClasses = c.numClasses;
134
135 // Setup reader
136 fReader = std::make_unique<Reader>("Silent");
137 const auto numVars = fVariables.size();
138 fValues = std::vector<float>(numVars);
139 for (std::size_t i = 0; i < numVars; i++) {
140 fReader->AddVariable(TString(fExpressions[i]), &fValues[i]);
141 }
142 fReader->BookMVA(name, path.c_str());
143 }
144
145 /// Compute model prediction on vector
146 std::vector<float> Compute(const std::vector<float> &x)
147 {
148 if (x.size() != fVariables.size())
149 throw std::runtime_error("Size of input vector is not equal to number of variables.");
150
151 // Copy over inputs to memory used by TMVA reader
152 for (std::size_t i = 0; i < x.size(); i++) {
153 fValues[i] = x[i];
154 }
155
156 // Take lock to protect model evaluation
158
159 // Evaluate TMVA model
160 // Classification
162 return std::vector<float>({static_cast<float>(fReader->EvaluateMVA(name))});
163 }
164 // Regression
166 return fReader->EvaluateRegression(name);
167 }
168 // Multiclass
170 return fReader->EvaluateMulticlass(name);
171 }
172 // Throw error
173 else {
174 throw std::runtime_error("RReader has undefined analysis type.");
175 return std::vector<float>();
176 }
177 }
178
179 /// Compute model prediction on input RTensor
181 {
182 // Error-handling for input tensor
183 const auto shape = x.GetShape();
184 if (shape.size() != 2)
185 throw std::runtime_error("Can only compute model outputs for input tensor of rank 2.");
186
187 const auto numEntries = shape[0];
188 const auto numVars = shape[1];
189 if (numVars != fVariables.size())
190 throw std::runtime_error("Second dimension of input tensor is not equal to number of variables.");
191
192 // Define shape of output tensor based on analysis type
193 unsigned int numClasses = 1;
195 numClasses = fNumClasses;
196 RTensor<float> y({numEntries * numClasses});
198 y = y.Reshape({numEntries, numClasses});
199
200 // Fill output tensor
201 for (std::size_t i = 0; i < numEntries; i++) {
202 for (std::size_t j = 0; j < numVars; j++) {
203 fValues[j] = x(i, j);
204 }
206 // Classification
208 y(i) = fReader->EvaluateMVA(name);
209 }
210 // Regression
212 y(i) = fReader->EvaluateRegression(name)[0];
213 }
214 // Multiclass
216 const auto p = fReader->EvaluateMulticlass(name);
217 for (std::size_t k = 0; k < numClasses; k++)
218 y(i, k) = p[k];
219 }
220 }
221
222 return y;
223 }
224
225 std::vector<std::string> GetVariableNames() { return fVariables; }
226};
227
228} // namespace Experimental
229} // namespace TMVA
230
231#endif // TMVA_RREADER
#define c(i)
Definition RSha256.hxx:101
#define R__WRITE_LOCKGUARD(mutex)
TMVA::Reader legacy interface.
Definition RReader.hxx:114
RTensor< float > Compute(RTensor< float > &x)
Compute model prediction on input RTensor.
Definition RReader.hxx:180
std::vector< float > fValues
Definition RReader.hxx:117
std::vector< float > Compute(const std::vector< float > &x)
Compute model prediction on vector.
Definition RReader.hxx:146
Internal::AnalysisType fAnalysisType
Definition RReader.hxx:122
std::vector< std::string > fExpressions
Definition RReader.hxx:119
std::vector< std::string > GetVariableNames()
Definition RReader.hxx:225
std::vector< std::string > fVariables
Definition RReader.hxx:118
RReader(const std::string &path)
Create TMVA model from XML file.
Definition RReader.hxx:126
std::unique_ptr< Reader > fReader
Definition RReader.hxx:116
RTensor is a container with contiguous memory and shape information.
Definition RTensor.hxx:162
RTensor< Value_t, Container_t > Reshape(const Shape_t &shape) const
Reshape tensor.
Definition RTensor.hxx:477
Basic string class.
Definition TString.h:136
XMLNodePointer_t GetChild(XMLNodePointer_t xmlnode, Bool_t realnode=kTRUE)
returns first child of xmlnode
void FreeDoc(XMLDocPointer_t xmldoc)
frees allocated document data and deletes document itself
XMLNodePointer_t DocGetRootElement(XMLDocPointer_t xmldoc)
returns root node of document
const char * GetNodeName(XMLNodePointer_t xmlnode)
returns name of xmlnode
const char * GetAttr(XMLNodePointer_t xmlnode, const char *name)
returns value of attribute for xmlnode
XMLDocPointer_t ParseFile(const char *filename, Int_t maxbuf=100000)
Parses content of file and tries to produce xml structures.
XMLNodePointer_t GetNext(XMLNodePointer_t xmlnode, Bool_t realnode=kTRUE)
return next to xmlnode node if realnode==kTRUE, any special nodes in between will be skipped
Double_t y[n]
Definition legend1.C:17
Double_t x[n]
Definition legend1.C:17
R__EXTERN TVirtualRWMutex * gCoreMutex
XMLConfig ParseXMLConfig(const std::string &filename)
Parse TMVA XML config.
Definition RReader.hxx:38
AnalysisType
Internal definition of analysis types.
Definition RReader.hxx:20
create variable transformations
Container for information extracted from TMVA XML config.
Definition RReader.hxx:23
std::vector< std::string > classes
Definition RReader.hxx:28
std::vector< std::string > expressions
Definition RReader.hxx:26
std::vector< std::string > variables
Definition RReader.hxx:25