Logo ROOT  
Reference Guide
 
Loading...
Searching...
No Matches
RBDT.hxx
Go to the documentation of this file.
1/**********************************************************************************
2 * Project: ROOT - a Root-integrated toolkit for multivariate data analysis *
3 * Package: TMVA *
4 * Web : http://tmva.sourceforge.net *
5 * *
6 * Description: *
7 * *
8 * Authors: *
9 * Stefan Wunsch (stefan.wunsch@cern.ch) *
10 * *
11 * Copyright (c) 2019: *
12 * CERN, Switzerland *
13 * *
14 * Redistribution and use in source and binary forms, with or without *
15 * modification, are permitted according to the terms listed in LICENSE *
16 * (http://tmva.sourceforge.net/LICENSE) *
17 **********************************************************************************/
18
19#ifndef TMVA_RBDT
20#define TMVA_RBDT
21
#include "TMVA/RTensor.hxx"
#include "TFile.h"

#include <memory>    // std::unique_ptr
#include <sstream>   // std::stringstream
#include <stdexcept> // std::runtime_error
#include <string>
#include <vector>
29
30namespace TMVA {
31namespace Experimental {
32
33/// Fast boosted decision tree inference
34template <typename Backend = BranchlessJittedForest<float>>
35class RBDT {
36public:
37 using Value_t = typename Backend::Value_t;
38 using Backend_t = Backend;
39
40private:
43 std::vector<Backend_t> fBackends;
44
45public:
46 /// Construct backends from model in ROOT file
47 RBDT(const std::string &key, const std::string &filename)
48 {
49 // Get number of output nodes of the forest
50 auto file = TFile::Open(filename.c_str(), "READ");
51 auto numOutputs = Internal::GetObjectSafe<std::vector<int>>(file, filename, key + "/num_outputs");
52 fNumOutputs = numOutputs->at(0);
53 delete numOutputs;
54
55 // Get objective and decide whether to normalize output nodes for example in the multiclass case
56 auto objective = Internal::GetObjectSafe<std::string>(file, filename, key + "/objective");
57 if (objective->compare("softmax") == 0)
58 fNormalizeOutputs = true;
59 else
60 fNormalizeOutputs = false;
61 delete objective;
62 file->Close();
63
64 // Initialize backends
65 fBackends = std::vector<Backend_t>(fNumOutputs);
66 for (int i = 0; i < fNumOutputs; i++)
67 fBackends[i].Load(key, filename, i);
68 }
69
70 /// Compute model prediction on a single event
71 ///
72 /// The method is intended to be used with std::vectors-like containers,
73 /// for example RVecs.
74 template <typename Vector>
75 Vector Compute(const Vector &x)
76 {
77 Vector y;
78 y.resize(fNumOutputs);
79 for (int i = 0; i < fNumOutputs; i++)
80 fBackends[i].Inference(&x[0], 1, true, &y[i]);
82 Value_t s = 0.0;
83 for (int i = 0; i < fNumOutputs; i++)
84 s += y[i];
85 for (int i = 0; i < fNumOutputs; i++)
86 y[i] /= s;
87 }
88 return y;
89 }
90
91 /// Compute model prediction on a single event
92 std::vector<Value_t> Compute(const std::vector<Value_t> &x) { return this->Compute<std::vector<Value_t>>(x); }
93
94 /// Compute model prediction on input RTensor
96 {
97 const auto rows = x.GetShape()[0];
98 RTensor<Value_t> y({rows, static_cast<std::size_t>(fNumOutputs)}, MemoryLayout::ColumnMajor);
99 const bool layout = x.GetMemoryLayout() == MemoryLayout::ColumnMajor ? false : true;
100 for (int i = 0; i < fNumOutputs; i++)
101 fBackends[i].Inference(x.GetData(), rows, layout, &y(0, i));
102 if (fNormalizeOutputs) {
103 Value_t s;
104 for (int i = 0; i < static_cast<int>(rows); i++) {
105 s = 0.0;
106 for (int j = 0; j < fNumOutputs; j++)
107 s += y(i, j);
108 for (int j = 0; j < fNumOutputs; j++)
109 y(i, j) /= s;
110 }
111 }
112 return y;
113 }
114};
115
118
119} // namespace Experimental
120} // namespace TMVA
121
122#endif // TMVA_RBDT
static TFile * Open(const char *name, Option_t *option="", const char *ftitle="", Int_t compress=ROOT::RCompressionSetting::EDefaults::kUseCompiledDefault, Int_t netopt=0)
Create / open a file.
Definition TFile.cxx:3997
Fast boosted decision tree inference.
Definition RBDT.hxx:35
RTensor< Value_t > Compute(const RTensor< Value_t > &x)
Compute model prediction on input RTensor.
Definition RBDT.hxx:95
std::vector< Value_t > Compute(const std::vector< Value_t > &x)
Compute model prediction on a single event.
Definition RBDT.hxx:92
typename Backend::Value_t Value_t
Definition RBDT.hxx:37
RBDT(const std::string &key, const std::string &filename)
Construct backends from model in ROOT file.
Definition RBDT.hxx:47
Vector Compute(const Vector &x)
Compute model prediction on a single event.
Definition RBDT.hxx:75
std::vector< Backend_t > fBackends
Definition RBDT.hxx:43
RTensor is a container with contiguous memory and shape information.
Definition RTensor.hxx:162
MemoryLayout GetMemoryLayout() const
Definition RTensor.hxx:248
Double_t y[n]
Definition legend1.C:17
Double_t x[n]
Definition legend1.C:17
create variable transformations
Definition file.py:1