Logo ROOT  
Reference Guide
 
Loading...
Searching...
No Matches
RBDT.hxx
Go to the documentation of this file.
1/**********************************************************************************
2 * Project: ROOT - a Root-integrated toolkit for multivariate data analysis *
3 * Package: TMVA *
4 * *
5 * *
6 * Description: *
7 * *
8 * Authors: *
9 * Stefan Wunsch (stefan.wunsch@cern.ch) *
10 * Jonas Rembser (jonas.rembser@cern.ch) *
11 * *
12 * Copyright (c) 2024: *
13 * CERN, Switzerland *
14 * *
15 * Redistribution and use in source and binary forms, with or without *
16 * modification, are permitted according to the terms listed in LICENSE *
17 * (see tmva/doc/LICENSE) *
18 **********************************************************************************/
19
20#ifndef TMVA_RBDT
21#define TMVA_RBDT
22
23#include <Rtypes.h>
24#include <ROOT/RSpan.hxx>
25#include <TMVA/RTensor.hxx>
26
27#include <array>
28#include <istream>
29#include <string>
30#include <unordered_map>
31#include <vector>
32
33namespace TMVA {
34
35namespace Experimental {
36
37class RBDT final {
38public:
39 typedef float Value_t;
40
41 /// IO constructor (both for ROOT IO and LoadText()).
42 RBDT() = default;
43
44 /// Construct backends from model in ROOT file.
45 RBDT(const std::string &key, const std::string &filename);
46
47 /// Compute model prediction on a single event.
48 ///
49 /// The method is intended to be used with std::vectors-like containers,
50 /// for example RVecs.
51 template <typename Vector>
52 Vector Compute(const Vector &x) const
53 {
54 std::size_t nOut = fBaseResponses.size() > 2 ? fBaseResponses.size() : 1;
55 Vector y(nOut);
56 ComputeImpl(x.data(), y.data());
57 return y;
58 }
59
60 /// Compute model prediction on a single event.
61 inline std::vector<Value_t> Compute(std::vector<Value_t> const &x) const { return Compute<std::vector<Value_t>>(x); }
62
64
65 static RBDT LoadText(std::string const &txtpath, std::vector<std::string> &features, int nClasses, bool logistic,
66 Value_t baseScore);
67
68private:
69 /// Map from XGBoost to RBDT indices.
70 using IndexMap = std::unordered_map<int, int>;
71
72 void Softmax(const Value_t *array, Value_t *out) const;
73 void ComputeImpl(const Value_t *array, Value_t *out) const;
74 Value_t EvaluateBinary(const Value_t *array) const;
75 static void correctIndices(std::span<int> indices, IndexMap const &nodeIndices, IndexMap const &leafIndices);
76 static void terminateTree(TMVA::Experimental::RBDT &ff, int &nPreviousNodes, int &nPreviousLeaves,
77 IndexMap &nodeIndices, IndexMap &leafIndices, int &treesSkipped);
78 static RBDT
79 LoadText(std::istream &is, std::vector<std::string> &features, int nClasses, bool logistic, Value_t baseScore);
80
81 std::vector<int> fRootIndices;
82 std::vector<unsigned int> fCutIndices;
83 std::vector<Value_t> fCutValues;
84 std::vector<int> fLeftIndices;
85 std::vector<int> fRightIndices;
86 std::vector<Value_t> fResponses;
87 std::vector<int> fTreeNumbers;
88 std::vector<Value_t> fBaseResponses;
90 bool fLogistic = false;
91
93};
94
95} // namespace Experimental
96
97} // namespace TMVA
98
99#endif // TMVA_RBDT
#define ClassDefNV(name, id)
Definition Rtypes.h:345
Option_t Option_t TPoint TPoint const char GetTextMagnitude GetFillStyle GetLineColor GetLineWidth GetMarkerStyle GetTextAlign GetTextColor GetTextSize void char Point_t Rectangle_t WindowAttributes_t Float_t Float_t Float_t Int_t Int_t UInt_t UInt_t Rectangle_t Int_t Int_t Window_t TString Int_t GCValues_t GetPrimarySelectionOwner GetDisplay GetScreen GetColormap GetNativeEvent const char const char dpyName wid window const char font_name cursor keysym reg const char only_if_exist regb h Point_t winding char text const char depth char const char Int_t count const char ColorStruct_t color const char filename
std::vector< Value_t > fCutValues
Definition RBDT.hxx:83
static void terminateTree(TMVA::Experimental::RBDT &ff, int &nPreviousNodes, int &nPreviousLeaves, IndexMap &nodeIndices, IndexMap &leafIndices, int &treesSkipped)
Definition RBDT.cxx:212
RBDT()=default
IO constructor (both for ROOT IO and LoadText()).
static void correctIndices(std::span< int > indices, IndexMap const &nodeIndices, IndexMap const &leafIndices)
RBDT uses a more efficient representation of the BDT in flat arrays.
Definition RBDT.cxx:191
std::vector< int > fRightIndices
Definition RBDT.hxx:85
std::unordered_map< int, int > IndexMap
Map from XGBoost to RBDT indices.
Definition RBDT.hxx:70
void Softmax(const Value_t *array, Value_t *out) const
Definition RBDT.cxx:129
std::vector< int > fTreeNumbers
Definition RBDT.hxx:87
Value_t EvaluateBinary(const Value_t *array) const
Definition RBDT.cxx:169
std::vector< Value_t > fResponses
Definition RBDT.hxx:86
std::vector< Value_t > fBaseResponses
Definition RBDT.hxx:88
std::vector< Value_t > Compute(std::vector< Value_t > const &x) const
Compute model prediction on a single event.
Definition RBDT.hxx:61
Vector Compute(const Vector &x) const
Compute model prediction on a single event.
Definition RBDT.hxx:52
std::vector< unsigned int > fCutIndices
Definition RBDT.hxx:82
void ComputeImpl(const Value_t *array, Value_t *out) const
Definition RBDT.cxx:156
static RBDT LoadText(std::string const &txtpath, std::vector< std::string > &features, int nClasses, bool logistic, Value_t baseScore)
Definition RBDT.cxx:234
std::vector< int > fRootIndices
Definition RBDT.hxx:81
std::vector< int > fLeftIndices
Definition RBDT.hxx:84
RTensor is a container with contiguous memory and shape information.
Definition RTensor.hxx:162
Double_t y[n]
Definition legend1.C:17
Double_t x[n]
Definition legend1.C:17
create variable transformations