14namespace Experimental {
51 ss <<
"Failed to open TMVA XML file "
53 throw std::runtime_error(
ss.str());
56 for (
auto node =
xml.GetChild(
mainNode); node; node =
xml.GetNext(node)) {
57 const auto nodeName = std::string(
xml.GetNodeName(node));
59 if (
nodeName.compare(
"Variables") == 0) {
60 c.numVariables = std::atoi(
xml.GetAttr(node,
"NVar"));
61 c.variables = std::vector<std::string>(
c.numVariables);
62 c.variable_expressions = std::vector<std::string>(
c.numVariables);
70 else if (
nodeName.compare(
"Spectators") == 0) {
71 c.numSpectators = std::atoi(
xml.GetAttr(node,
"NSpec"));
72 c.spectators = std::vector<std::string>(
c.numSpectators);
73 c.spectator_expressions = std::vector<std::string>(
c.numSpectators);
81 else if (
nodeName.compare(
"Classes") == 0) {
82 c.numClasses = std::atoi(
xml.GetAttr(node,
"NClass"));
88 else if (
nodeName.compare(
"GeneralInfo") == 0) {
89 std::string analysisType =
"";
91 if (std::string(
"AnalysisType").compare(
xml.GetAttr(
thisNode,
"name")) == 0) {
95 if (analysisType.compare(
"Classification") == 0) {
97 }
else if (analysisType.compare(
"Regression") == 0) {
99 }
else if (analysisType.compare(
"Multiclass") == 0) {
107 if (
c.numVariables !=
c.variables.size() ||
c.numVariables == 0) {
108 std::stringstream
ss;
109 ss <<
"Failed to parse input variables from TMVA config " <<
filename <<
".";
110 throw std::runtime_error(
ss.str());
112 if (
c.numSpectators !=
c.spectators.size()) {
113 std::stringstream
ss;
114 ss <<
"Failed to parse input spectators from TMVA config " <<
filename <<
".";
115 throw std::runtime_error(
ss.str());
117 if (
c.numClasses !=
c.classes.size() ||
c.numClasses == 0) {
118 std::stringstream
ss;
119 ss <<
"Failed to parse output classes from TMVA config " <<
filename <<
".";
120 throw std::runtime_error(
ss.str());
123 std::stringstream
ss;
124 ss <<
"Failed to parse analysis type from TMVA config " <<
filename <<
".";
125 throw std::runtime_error(
ss.str());
163 fReader = std::make_unique<Reader>(
"Silent");
166 for (std::size_t i = 0; i <
numVars; i++) {
171 for (std::size_t i = 0; i <
numSpecs; i++) {
178 std::vector<float>
Compute(
const std::vector<float> &
x)
184 throw std::runtime_error(
"Size of input vector is not equal to number of variables.");
188 for (std::size_t i = 0; i !=
nVars ; ++i) {
191 for (std::size_t i = 0; i !=
fSpectators.size(); ++i) {
198 return std::vector<float>({
static_cast<float>(
fReader->EvaluateMVA(
name))});
210 throw std::runtime_error(
"RReader has undefined analysis type.");
211 return std::vector<float>();
220 if (shape.size() != 2)
221 throw std::runtime_error(
"Can only compute model outputs for input tensor of rank 2.");
223 const auto numEntries = shape[0];
226 throw std::runtime_error(
"Second dimension of input tensor is not equal to number of variables.");
229 unsigned int numClasses = 1;
234 y =
y.Reshape({numEntries, numClasses});
239 for (std::size_t i = 0; i < numEntries; i++) {
240 for (std::size_t
j = 0;
j <
nVars;
j++) {
257 for (std::size_t k = 0; k < numClasses; k++)
ROOT::Detail::TRangeCast< T, true > TRangeDynCast
TRangeDynCast is an adapter class that allows the typed iteration through a TCollection.
winID h TVirtualViewer3D TVirtualGLPainter p
Option_t Option_t TPoint TPoint const char GetTextMagnitude GetFillStyle GetLineColor GetLineWidth GetMarkerStyle GetTextAlign GetTextColor GetTextSize void char Point_t Rectangle_t WindowAttributes_t Float_t Float_t Float_t Int_t Int_t UInt_t UInt_t Rectangle_t Int_t Int_t Window_t TString Int_t GCValues_t GetPrimarySelectionOwner GetDisplay GetScreen GetColormap GetNativeEvent const char const char dpyName wid window const char font_name cursor keysym reg const char only_if_exist regb h Point_t winding char text const char depth char const char Int_t count const char ColorStruct_t color const char filename
#define R__WRITE_LOCKGUARD(mutex)
A replacement for the TMVA::Reader legacy interface.
RTensor< float > Compute(RTensor< float > &x)
Compute model prediction on input RTensor.
std::vector< float > Compute(const std::vector< float > &x)
Compute model prediction on vector.
std::vector< std::string > GetSpectatorNames()
Internal::AnalysisType fAnalysisType
std::vector< float > fVariableValues
std::vector< float > fSpectatorValues
std::vector< std::string > fVariableExpressions
std::vector< std::string > GetVariableNames()
std::vector< std::string > fSpectatorExpressions
std::vector< std::string > fSpectators
std::vector< std::string > fVariables
RReader(const std::string &path)
Create TMVA model from XML file.
std::unique_ptr< Reader > fReader
const Shape_t & GetShape() const
R__EXTERN TVirtualRWMutex * gCoreMutex
XMLConfig ParseXMLConfig(const std::string &filename)
Parse TMVA XML config.
AnalysisType
Internal definition of analysis types.
create variable transformations
Container for information extracted from TMVA XML config.
std::vector< std::string > classes
unsigned int numVariables
AnalysisType analysisType
std::vector< std::string > spectators
std::vector< std::string > spectator_expressions
unsigned int numSpectators
std::vector< std::string > variable_expressions
std::vector< std::string > variables