Logo ROOT   master
Reference Guide
RReader.hxx
Go to the documentation of this file.
1 #ifndef TMVA_RREADER
2 #define TMVA_RREADER
3 
#include "TString.h"
#include "TXMLEngine.h"
#include "ROOT/RMakeUnique.hxx"

#include "TMVA/RTensor.hxx"
#include "TMVA/Reader.h"

#include <cstdlib> // std::atoi
#include <memory> // std::unique_ptr
#include <sstream> // std::stringstream
#include <stdexcept> // std::runtime_error
#include <string> // std::string
#include <vector> // std::vector
13 
14 namespace TMVA {
15 namespace Experimental {
16 
17 namespace Internal {
18 
/// Internal definition of analysis types
///
/// `Undefined` is the value a freshly constructed XMLConfig carries until the
/// analysis type has been read from the XML file.
enum AnalysisType : unsigned int { Undefined = 0, Classification, Regression, Multiclass };

/// Container for information extracted from TMVA XML config
struct XMLConfig {
   unsigned int numVariables;            ///< Value of the `NVar` attribute of the `Variables` node
   std::vector<std::string> variables;   ///< `Title` attribute of each variable, indexed by `VarIndex`
   std::vector<std::string> expressions; ///< `Expression` attribute of each variable, indexed by `VarIndex`
   unsigned int numClasses;              ///< Value of the `NClass` attribute of the `Classes` node
   std::vector<std::string> classes;     ///< `Name` attribute of each class, in document order
   AnalysisType analysisType;            ///< Analysis type read from the `GeneralInfo` node

   /// Create an empty config: zero counts, empty lists and an undefined analysis type.
   /// The vectors default-construct to empty, so only the scalar members need initializers.
   XMLConfig() : numVariables(0), numClasses(0), analysisType(AnalysisType::Undefined) {}
};
36 
37 /// Parse TMVA XML config
38 inline XMLConfig ParseXMLConfig(const std::string &filename)
39 {
40  XMLConfig c;
41 
42  // Parse XML file and find root node
43  TXMLEngine xml;
44  auto xmldoc = xml.ParseFile(filename.c_str());
45  if (xmldoc == 0) {
46  std::stringstream ss;
47  ss << "Failed to open TMVA XML file "
48  << filename << ".";
49  throw std::runtime_error(ss.str());
50  }
51  auto mainNode = xml.DocGetRootElement(xmldoc);
52  for (auto node = xml.GetChild(mainNode); node; node = xml.GetNext(node)) {
53  const auto nodeName = std::string(xml.GetNodeName(node));
54  // Read out input variables
55  if (nodeName.compare("Variables") == 0) {
56  c.numVariables = std::atoi(xml.GetAttr(node, "NVar"));
57  c.variables = std::vector<std::string>(c.numVariables);
58  c.expressions = std::vector<std::string>(c.numVariables);
59  for (auto thisNode = xml.GetChild(node); thisNode; thisNode = xml.GetNext(thisNode)) {
60  const auto iVariable = std::atoi(xml.GetAttr(thisNode, "VarIndex"));
61  c.variables[iVariable] = xml.GetAttr(thisNode, "Title");
62  c.expressions[iVariable] = xml.GetAttr(thisNode, "Expression");
63  }
64  }
65  // Read out output classes
66  else if (nodeName.compare("Classes") == 0) {
67  c.numClasses = std::atoi(xml.GetAttr(node, "NClass"));
68  for (auto thisNode = xml.GetChild(node); thisNode; thisNode = xml.GetNext(thisNode)) {
69  c.classes.push_back(xml.GetAttr(thisNode, "Name"));
70  }
71  }
72  // Read out analysis type
73  else if (nodeName.compare("GeneralInfo") == 0) {
74  std::string analysisType = "";
75  for (auto thisNode = xml.GetChild(node); thisNode; thisNode = xml.GetNext(thisNode)) {
76  if (std::string("AnalysisType").compare(xml.GetAttr(thisNode, "name")) == 0) {
77  analysisType = xml.GetAttr(thisNode, "value");
78  }
79  }
80  if (analysisType.compare("Classification") == 0) {
82  } else if (analysisType.compare("Regression") == 0) {
84  } else if (analysisType.compare("Multiclass") == 0) {
86  }
87  }
88  }
89  xml.FreeDoc(xmldoc);
90 
91  // Error-handling
92  if (c.numVariables != c.variables.size() || c.numVariables == 0) {
93  std::stringstream ss;
94  ss << "Failed to parse input variables from TMVA config " << filename << ".";
95  throw std::runtime_error(ss.str());
96  }
97  if (c.numClasses != c.classes.size() || c.numClasses == 0) {
98  std::stringstream ss;
99  ss << "Failed to parse output classes from TMVA config " << filename << ".";
100  throw std::runtime_error(ss.str());
101  }
102  if (c.analysisType == Internal::AnalysisType::Undefined) {
103  std::stringstream ss;
104  ss << "Failed to parse analysis type from TMVA config " << filename << ".";
105  throw std::runtime_error(ss.str());
106  }
107 
108  return c;
109 }
110 
111 } // namespace Internal
112 
113 /// TMVA::Reader legacy interface
114 class RReader {
115 private:
116  std::unique_ptr<Reader> fReader;
117  std::vector<float> fValues;
118  std::vector<std::string> fVariables;
119  std::vector<std::string> fExpressions;
120  unsigned int fNumClasses;
121  const char *name = "RReader";
123 
124 public:
125  /// Create TMVA model from XML file
126  RReader(const std::string &path)
127  {
128  // Load config
129  auto c = Internal::ParseXMLConfig(path);
130  fVariables = c.variables;
131  fExpressions = c.expressions;
132  fAnalysisType = c.analysisType;
133  fNumClasses = c.numClasses;
134 
135  // Setup reader
136  fReader = std::make_unique<Reader>("Silent");
137  const auto numVars = fVariables.size();
138  fValues = std::vector<float>(numVars);
139  for (std::size_t i = 0; i < numVars; i++) {
140  fReader->AddVariable(TString(fExpressions[i]), &fValues[i]);
141  }
142  fReader->BookMVA(name, path.c_str());
143  }
144 
145  /// Compute model prediction on vector
146  std::vector<float> Compute(const std::vector<float> &x)
147  {
148  if (x.size() != fVariables.size())
149  throw std::runtime_error("Size of input vector is not equal to number of variables.");
150 
151  // Copy over inputs to memory used by TMVA reader
152  for (std::size_t i = 0; i < x.size(); i++) {
153  fValues[i] = x[i];
154  }
155 
156  // Take lock to protect model evaluation
158 
159  // Evaluate TMVA model
160  // Classification
162  return std::vector<float>({static_cast<float>(fReader->EvaluateMVA(name))});
163  }
164  // Regression
166  return fReader->EvaluateRegression(name);
167  }
168  // Multiclass
170  return fReader->EvaluateMulticlass(name);
171  }
172  // Throw error
173  else {
174  throw std::runtime_error("RReader has undefined analysis type.");
175  return std::vector<float>();
176  }
177  }
178 
179  /// Compute model prediction on input RTensor
181  {
182  // Error-handling for input tensor
183  const auto shape = x.GetShape();
184  if (shape.size() != 2)
185  throw std::runtime_error("Can only compute model outputs for input tensor of rank 2.");
186 
187  const auto numEntries = shape[0];
188  const auto numVars = shape[1];
189  if (numVars != fVariables.size())
190  throw std::runtime_error("Second dimension of input tensor is not equal to number of variables.");
191 
192  // Define shape of output tensor based on analysis type
193  unsigned int numClasses = 1;
195  numClasses = fNumClasses;
196  RTensor<float> y({numEntries * numClasses});
198  y = y.Reshape({numEntries, numClasses});
199 
200  // Fill output tensor
201  for (std::size_t i = 0; i < numEntries; i++) {
202  for (std::size_t j = 0; j < numVars; j++) {
203  fValues[j] = x(i, j);
204  }
206  // Classification
208  y(i) = fReader->EvaluateMVA(name);
209  }
210  // Regression
212  y(i) = fReader->EvaluateRegression(name)[0];
213  }
214  // Multiclass
216  const auto p = fReader->EvaluateMulticlass(name);
217  for (std::size_t k = 0; k < numClasses; k++)
218  y(i, k) = p[k];
219  }
220  }
221 
222  return y;
223  }
224 
225  std::vector<std::string> GetVariableNames() { return fVariables; }
226 };
227 
228 } // namespace Experimental
229 } // namespace TMVA
230 
231 #endif // TMVA_RREADER
std::vector< std::string > classes
Definition: RReader.hxx:28
TMVA::Reader legacy interface.
Definition: RReader.hxx:114
Internal::AnalysisType fAnalysisType
Definition: RReader.hxx:122
XMLNodePointer_t GetNext(XMLNodePointer_t xmlnode, Bool_t realnode=kTRUE)
return next to xmlnode node if realnode==kTRUE, any special nodes in between will be skipped ...
Basic string class.
Definition: TString.h:131
STL namespace.
std::vector< std::string > GetVariableNames()
Definition: RReader.hxx:225
void FreeDoc(XMLDocPointer_t xmldoc)
frees allocated document data and deletes document itself
std::vector< float > Compute(const std::vector< float > &x)
Compute model prediction on vector.
Definition: RReader.hxx:146
Double_t x[n]
Definition: legend1.C:17
std::vector< std::string > fVariables
Definition: RReader.hxx:118
#define R__WRITE_LOCKGUARD(mutex)
R__EXTERN TVirtualRWMutex * gCoreMutex
AnalysisType
Internal definition of analysis types.
Definition: RReader.hxx:20
Container for information extracted from TMVA XML config.
Definition: RReader.hxx:23
const char * GetNodeName(XMLNodePointer_t xmlnode)
returns name of xmlnode
XMLConfig ParseXMLConfig(const std::string &filename)
Parse TMVA XML config.
Definition: RReader.hxx:38
RTensor< float > Compute(RTensor< float > &x)
Compute model prediction on input RTensor.
Definition: RReader.hxx:180
std::unique_ptr< Reader > fReader
Definition: RReader.hxx:116
XMLDocPointer_t ParseFile(const char *filename, Int_t maxbuf=100000)
Parses content of file and tries to produce xml structures.
const char * GetAttr(XMLNodePointer_t xmlnode, const char *name)
returns value of attribute for xmlnode
Definition: TXMLEngine.cxx:549
Double_t y[n]
Definition: legend1.C:17
std::vector< std::string > expressions
Definition: RReader.hxx:26
RReader(const std::string &path)
Create TMVA model from XML file.
Definition: RReader.hxx:126
RTensor is a container with contiguous memory and shape information.
Definition: RTensor.hxx:162
create variable transformations
XMLNodePointer_t GetChild(XMLNodePointer_t xmlnode, Bool_t realnode=kTRUE)
returns first child of xmlnode
std::vector< std::string > variables
Definition: RReader.hxx:25
XMLNodePointer_t DocGetRootElement(XMLDocPointer_t xmldoc)
returns root node of document
std::vector< std::string > fExpressions
Definition: RReader.hxx:119
#define c(i)
Definition: RSha256.hxx:101
std::vector< float > fValues
Definition: RReader.hxx:117