14namespace Experimental {
 
   46      ss << 
"Failed to open TMVA XML file " 
   48      throw std::runtime_error(ss.str());
 
   51   for (
auto node = xml.
GetChild(mainNode); node; node = xml.
GetNext(node)) {
 
   52      const auto nodeName = std::string(xml.
GetNodeName(node));
 
   54      if (nodeName.compare(
"Variables") == 0) {
 
   55         c.numVariables = std::atoi(xml.
GetAttr(node, 
"NVar"));
 
   56         c.variables = std::vector<std::string>(
c.numVariables);
 
   57         c.expressions = std::vector<std::string>(
c.numVariables);
 
   58         for (
auto thisNode = xml.
GetChild(node); thisNode; thisNode = xml.
GetNext(thisNode)) {
 
   59            const auto iVariable = std::atoi(xml.
GetAttr(thisNode, 
"VarIndex"));
 
   60            c.variables[iVariable] = xml.
GetAttr(thisNode, 
"Title");
 
   61            c.expressions[iVariable] = xml.
GetAttr(thisNode, 
"Expression");
 
   65      else if (nodeName.compare(
"Classes") == 0) {
 
   66         c.numClasses = std::atoi(xml.
GetAttr(node, 
"NClass"));
 
   67         for (
auto thisNode = xml.
GetChild(node); thisNode; thisNode = xml.
GetNext(thisNode)) {
 
   68            c.classes.push_back(xml.
GetAttr(thisNode, 
"Name"));
 
   72      else if (nodeName.compare(
"GeneralInfo") == 0) {
 
   73         std::string analysisType = 
"";
 
   74         for (
auto thisNode = xml.
GetChild(node); thisNode; thisNode = xml.
GetNext(thisNode)) {
 
   75            if (std::string(
"AnalysisType").compare(xml.
GetAttr(thisNode, 
"name")) == 0) {
 
   76               analysisType = xml.
GetAttr(thisNode, 
"value");
 
   79         if (analysisType.compare(
"Classification") == 0) {
 
   81         } 
else if (analysisType.compare(
"Regression") == 0) {
 
   83         } 
else if (analysisType.compare(
"Multiclass") == 0) {
 
   91   if (
c.numVariables != 
c.variables.size() || 
c.numVariables == 0) {
 
   93      ss << 
"Failed to parse input variables from TMVA config " << 
filename << 
".";
 
   94      throw std::runtime_error(ss.str());
 
   96   if (
c.numClasses != 
c.classes.size() || 
c.numClasses == 0) {
 
   98      ss << 
"Failed to parse output classes from TMVA config " << 
filename << 
".";
 
   99      throw std::runtime_error(ss.str());
 
  102      std::stringstream ss;
 
  103      ss << 
"Failed to parse analysis type from TMVA config " << 
filename << 
".";
 
  104      throw std::runtime_error(ss.str());
 
  137      fReader = std::make_unique<Reader>(
"Silent");
 
  139      fValues = std::vector<float>(numVars);
 
  140      for (std::size_t i = 0; i < numVars; i++) {
 
  147   std::vector<float> 
Compute(
const std::vector<float> &
x)
 
  150         throw std::runtime_error(
"Size of input vector is not equal to number of variables.");
 
  153      for (std::size_t i = 0; i < 
x.size(); i++) {
 
  163         return std::vector<float>({
static_cast<float>(
fReader->EvaluateMVA(
name))});
 
  175         throw std::runtime_error(
"RReader has undefined analysis type.");
 
  176         return std::vector<float>();
 
  184      const auto shape = 
x.GetShape();
 
  185      if (shape.size() != 2)
 
  186         throw std::runtime_error(
"Can only compute model outputs for input tensor of rank 2.");
 
  188      const auto numEntries = shape[0];
 
  189      const auto numVars = shape[1];
 
  191         throw std::runtime_error(
"Second dimension of input tensor is not equal to number of variables.");
 
  194      unsigned int numClasses = 1;
 
  202      for (std::size_t i = 0; i < numEntries; i++) {
 
  203         for (std::size_t j = 0; j < numVars; j++) {
 
  218            for (std::size_t k = 0; k < numClasses; k++)
 
winID h TVirtualViewer3D TVirtualGLPainter p
Option_t Option_t TPoint TPoint const char GetTextMagnitude GetFillStyle GetLineColor GetLineWidth GetMarkerStyle GetTextAlign GetTextColor GetTextSize void char Point_t Rectangle_t WindowAttributes_t Float_t Float_t Float_t Int_t Int_t UInt_t UInt_t Rectangle_t Int_t Int_t Window_t TString Int_t GCValues_t GetPrimarySelectionOwner GetDisplay GetScreen GetColormap GetNativeEvent const char const char dpyName wid window const char font_name cursor keysym reg const char only_if_exist regb h Point_t winding char text const char depth char const char Int_t count const char ColorStruct_t color const char filename
#define R__WRITE_LOCKGUARD(mutex)
A replacement for the TMVA::Reader legacy interface.
RTensor< float > Compute(RTensor< float > &x)
Compute model prediction on input RTensor.
std::vector< float > fValues
std::vector< float > Compute(const std::vector< float > &x)
Compute model prediction on vector.
Internal::AnalysisType fAnalysisType
std::vector< std::string > fExpressions
std::vector< std::string > GetVariableNames()
std::vector< std::string > fVariables
RReader(const std::string &path)
Create TMVA model from XML file.
std::unique_ptr< Reader > fReader
RTensor is a container with contiguous memory and shape information.
RTensor< Value_t, Container_t > Reshape(const Shape_t &shape) const
Reshape tensor.
XMLNodePointer_t GetChild(XMLNodePointer_t xmlnode, Bool_t realnode=kTRUE)
returns first child of xmlnode
void FreeDoc(XMLDocPointer_t xmldoc)
frees allocated document data and deletes document itself
XMLNodePointer_t DocGetRootElement(XMLDocPointer_t xmldoc)
returns root node of document
const char * GetNodeName(XMLNodePointer_t xmlnode)
returns name of xmlnode
const char * GetAttr(XMLNodePointer_t xmlnode, const char *name)
returns value of attribute for xmlnode
XMLDocPointer_t ParseFile(const char *filename, Int_t maxbuf=100000)
Parses content of file and tries to produce xml structures.
XMLNodePointer_t GetNext(XMLNodePointer_t xmlnode, Bool_t realnode=kTRUE)
return next to xmlnode node if realnode==kTRUE, any special nodes in between will be skipped
R__EXTERN TVirtualRWMutex * gCoreMutex
XMLConfig ParseXMLConfig(const std::string &filename)
Parse TMVA XML config.
AnalysisType
Internal definition of analysis types.
create variable transformations
Container for information extracted from TMVA XML config.
std::vector< std::string > classes
unsigned int numVariables
std::vector< std::string > expressions
AnalysisType analysisType
std::vector< std::string > variables