35template <
class Value_t>
36void softmaxTransformInplace(Value_t *out,
int nOut)
42 for (
int i = 1; i < nOut; ++i) {
45 for (
int i = 0; i < nOut; ++i) {
50 for (
int i = 0; i < nOut; ++i) {
51 out[i] /=
static_cast<float>(norm);
57inline bool isInteger(
const std::string &s)
59 if (s.empty() || ((!isdigit(s[0])) && (s[0] !=
'-') && (s[0] !=
'+')))
63 strtol(s.c_str(), &
p, 10);
68template <
class NumericType>
69struct NumericAfterSubstrOutput {
70 explicit NumericAfterSubstrOutput()
82template <
class NumericType>
83inline NumericAfterSubstrOutput<NumericType> numericAfterSubstr(std::string
const &str, std::string
const &substr)
86 NumericAfterSubstrOutput<NumericType>
output;
89 std::size_t found = str.find(substr);
90 if (found != std::string::npos) {
92 std::stringstream ss(str.substr(found + substr.size(), str.size() - found + substr.size()));
112 const std::size_t rows =
x.GetShape()[0];
113 const std::size_t cols =
x.GetShape()[1];
115 std::vector<Value_t> xRow(cols);
116 std::vector<Value_t> yRow(nOut);
117 for (std::size_t iRow = 0; iRow < rows; ++iRow) {
118 for (std::size_t iCol = 0; iCol < cols; ++iCol) {
119 xRow[iCol] =
x({iRow, iCol});
122 for (std::size_t iOut = 0; iOut < nOut; ++iOut) {
123 y({iRow, iOut}) = yRow[iOut];
131 std::size_t nOut = fBaseResponses.size() > 2 ? fBaseResponses.size() : 1;
133 throw std::runtime_error(
134 "Error in RBDT::softmax : binary classification models don't support softmax evaluation. Plase set "
135 "the number of classes in the RBDT-creating function if this is a multiclassification model.");
138 for (std::size_t i = 0; i < nOut; ++i) {
139 out[i] = fBaseScore + fBaseResponses[i];
143 for (
int index : fRootIndices) {
145 int r = fRightIndices[
index];
146 int l = fLeftIndices[
index];
149 out[fTreeNumbers[iRootIndex] % nOut] += fResponses[-
index];
153 softmaxTransformInplace(out, nOut);
158 std::size_t nOut = fBaseResponses.size() > 2 ? fBaseResponses.size() : 1;
162 out[0] = EvaluateBinary(array);
164 out[0] = 1.0 / (1.0 + std::exp(-out[0]));
171 Value_t out = fBaseScore + fBaseResponses[0];
173 for (std::vector<int>::const_iterator indexIter = fRootIndices.begin(); indexIter != fRootIndices.end();
175 int index = *indexIter;
177 int r = fRightIndices[
index];
178 int l = fLeftIndices[
index];
181 out += fResponses[-
index];
194 for (
int &idx : indices) {
195 auto foundNode = nodeIndices.find(idx);
196 if (foundNode != nodeIndices.end()) {
197 idx = foundNode->second;
200 auto foundLeaf = leafIndices.find(idx);
201 if (foundLeaf != leafIndices.end()) {
202 idx = -foundLeaf->second;
205 std::stringstream errMsg;
206 errMsg <<
"RBDT: something is wrong in the node structure - node with index " << idx <<
" doesn't exist";
207 throw std::runtime_error(errMsg.str());
218 if (nPreviousNodes !=
static_cast<int>(ff.
fCutValues.size())) {
222 int treeNumbers = ff.
fRootIndices.size() + treesSkipped;
235 std::vector<std::string> &features,
int nClasses,
236 bool logistic,
Value_t baseScore)
238 const std::string info =
"constructing RBDT from " + txtpath +
": ";
241 throw std::runtime_error(info +
"file does not exists");
244 std::ifstream file(txtpath.c_str());
245 return LoadText(file, features, nClasses, logistic, baseScore);
249 int nClasses,
bool logistic,
Value_t baseScore)
251 const std::string info =
"constructing RBDT from istream: ";
258 int treesSkipped = 0;
261 std::unordered_map<std::string, int> varIndices;
262 bool fixFeatures =
false;
264 if (!features.empty()) {
266 nVariables = features.size();
267 for (
int i = 0; i < nVariables; ++i) {
268 varIndices[features[i]] = i;
277 int nPreviousNodes = 0;
278 int nPreviousLeaves = 0;
280 while (std::getline(file,
line)) {
281 std::size_t foundBegin =
line.find(
"[");
282 std::size_t foundEnd =
line.find(
"]");
283 if (foundBegin != std::string::npos) {
284 std::string subline =
line.substr(foundBegin + 1, foundEnd - foundBegin - 1);
285 if (util::isInteger(subline) && !ff.
fResponses.empty()) {
286 terminateTree(ff, nPreviousNodes, nPreviousLeaves, nodeIndices, leafIndices, treesSkipped);
287 }
else if (!util::isInteger(subline)) {
288 std::stringstream ss(
line);
293 std::vector<std::string> splitstring =
ROOT::Split(subline,
"<");
294 std::string
const &varName = splitstring[0];
297 std::stringstream ss1(splitstring[1]);
300 if (!varIndices.count(varName)) {
302 throw std::runtime_error(info +
"feature " + varName +
" not in list of features");
304 varIndices[varName] = nVariables;
305 features.push_back(varName);
310 util::NumericAfterSubstrOutput<int>
output = util::numericAfterSubstr<int>(
line,
"yes=");
314 throw std::runtime_error(info +
"problem while parsing the text dump");
316 output = util::numericAfterSubstr<int>(
output.rest,
"no=");
320 throw std::runtime_error(info +
"problem while parsing the text dump");
327 std::size_t nNodeIndices = nodeIndices.size();
328 nodeIndices[
index] = nNodeIndices + nPreviousNodes;
332 util::NumericAfterSubstrOutput<Value_t>
output = util::numericAfterSubstr<Value_t>(
line,
"leaf=");
334 std::stringstream ss(
line);
340 std::size_t nLeafIndices = leafIndices.size();
341 leafIndices[
index] = nLeafIndices + nPreviousLeaves;
345 terminateTree(ff, nPreviousNodes, nPreviousLeaves, nodeIndices, leafIndices, treesSkipped);
347 if (nClasses > 2 && (ff.
fRootIndices.size() + treesSkipped) % nClasses != 0) {
348 std::stringstream ss;
349 ss <<
"Error in RBDT construction : Forest has " << ff.
fRootIndices.size()
350 <<
" trees, which is not compatible with " << nClasses <<
"classes!";
351 throw std::runtime_error(ss.str());
360 if (!file || file->IsZombie()) {
361 throw std::runtime_error(
"Failed to open input file " +
filename);
365 throw std::runtime_error(
"No RBDT with name " + key);
winID h TVirtualViewer3D TVirtualGLPainter p
Option_t Option_t TPoint TPoint const char GetTextMagnitude GetFillStyle GetLineColor GetLineWidth GetMarkerStyle GetTextAlign GetTextColor GetTextSize void char Point_t Rectangle_t WindowAttributes_t Float_t Float_t Float_t Int_t Int_t UInt_t UInt_t Rectangle_t Int_t Int_t Window_t TString Int_t GCValues_t GetPrimarySelectionOwner GetDisplay GetScreen GetColormap GetNativeEvent const char const char dpyName wid window const char font_name cursor keysym reg const char only_if_exist regb h Point_t winding char text const char depth char const char Int_t count const char ColorStruct_t color const char filename
Option_t Option_t TPoint TPoint const char GetTextMagnitude GetFillStyle GetLineColor GetLineWidth GetMarkerStyle GetTextAlign GetTextColor GetTextSize void char Point_t Rectangle_t WindowAttributes_t Float_t r
Option_t Option_t TPoint TPoint const char GetTextMagnitude GetFillStyle GetLineColor GetLineWidth GetMarkerStyle GetTextAlign GetTextColor GetTextSize void char Point_t Rectangle_t WindowAttributes_t index
Option_t Option_t TPoint TPoint const char GetTextMagnitude GetFillStyle GetLineColor GetLineWidth GetMarkerStyle GetTextAlign GetTextColor GetTextSize void value
Option_t Option_t TPoint TPoint const char GetTextMagnitude GetFillStyle GetLineColor GetLineWidth GetMarkerStyle GetTextAlign GetTextColor GetTextSize void char Point_t Rectangle_t wmax
R__EXTERN TSystem * gSystem
static TFile * Open(const char *name, Option_t *option="", const char *ftitle="", Int_t compress=ROOT::RCompressionSetting::EDefaults::kUseCompiledDefault, Int_t netopt=0)
Create / open a file.
std::vector< Value_t > fCutValues
static void terminateTree(TMVA::Experimental::RBDT &ff, int &nPreviousNodes, int &nPreviousLeaves, IndexMap &nodeIndices, IndexMap &leafIndices, int &treesSkipped)
RBDT()=default
IO constructor (both for ROOT IO and LoadText()).
static void correctIndices(std::span< int > indices, IndexMap const &nodeIndices, IndexMap const &leafIndices)
RBDT uses a more efficient representation of the BDT in flat arrays.
std::vector< int > fRightIndices
std::unordered_map< int, int > IndexMap
Map from XGBoost to RBDT indices.
void Softmax(const Value_t *array, Value_t *out) const
std::vector< int > fTreeNumbers
Value_t EvaluateBinary(const Value_t *array) const
std::vector< Value_t > fResponses
std::vector< Value_t > fBaseResponses
Vector Compute(const Vector &x) const
Compute model prediction on a single event.
std::vector< unsigned int > fCutIndices
void ComputeImpl(const Value_t *array, Value_t *out) const
static RBDT LoadText(std::string const &txtpath, std::vector< std::string > &features, int nClasses, bool logistic, Value_t baseScore)
std::vector< int > fRootIndices
std::vector< int > fLeftIndices
RTensor is a container with contiguous memory and shape information.
virtual Bool_t AccessPathName(const char *path, EAccessMode mode=kFileExists)
Returns FALSE if one can access a file using the specified access mode.
std::vector< std::string > Split(std::string_view str, std::string_view delims, bool skipEmpty=false)
Splits a string at each character in delims.