Logo ROOT  
Reference Guide
MethodRXGB.cxx
Go to the documentation of this file.
1 // @(#)root/tmva/rmva $Id$
2 // Author: Omar Zapata,Lorenzo Moneta, Sergei Gleyzer 2015
3 
4 /**********************************************************************************
5  * Project: TMVA - a Root-integrated toolkit for multivariate data analysis *
6  * Package: TMVA *
7  * Class : MethodRXGB *
8  * Web : http://oproject.org *
9  * *
10  * Description: *
11  * R eXtreme Gradient Boosting *
12  * *
13  * *
14  * Redistribution and use in source and binary forms, with or without *
15  * modification, are permitted according to the terms listed in LICENSE *
16  * (http://tmva.sourceforge.net/LICENSE) *
17  * *
18  **********************************************************************************/
19 
20 #include <iomanip>
21 
22 #include "TMath.h"
23 #include "Riostream.h"
24 #include "TMatrix.h"
25 #include "TMatrixD.h"
26 #include "TVectorD.h"
27 
29 #include "TMVA/MethodRXGB.h"
30 #include "TMVA/Tools.h"
31 #include "TMVA/Config.h"
32 #include "TMVA/Ranking.h"
33 #include "TMVA/Types.h"
34 #include "TMVA/PDF.h"
35 #include "TMVA/ClassifierFactory.h"
36 
37 #include "TMVA/Results.h"
38 #include "TMVA/Timer.h"
39 
40 using namespace TMVA;
41 
42 REGISTER_METHOD(RXGB)
43 
45 
46 //creating an Instance
48 
49 //_______________________________________________________________________
51  const TString &methodTitle,
52  DataSetInfo &dsi,
53  const TString &theOption) : RMethodBase(jobName, Types::kRXGB, methodTitle, dsi, theOption),
    // Default hyper-parameters. The same members are exposed as the
    // "NRounds", "Eta" and "MaxDepth" options where DeclareOptionRef is
    // called later in this file.
54  fNRounds(10),
55  fEta(0.3),
56  fMaxDepth(6),
    // Handles to R functions, identified by name. predict is additionally
    // given "xgboost" as a second argument (presumably the R package to take
    // it from; TODO confirm against TRFunctionImport).
57  predict("predict", "xgboost"),
58  xgbtrain("xgboost"),
59  xgbdmatrix("xgb.DMatrix"),
60  xgbsave("xgb.save"),
61  xgbload("xgb.load"),
62  asfactor("as.factor"),
63  asmatrix("as.matrix"),
    // No model exists yet; fModel is assigned later, where a model is trained
    // or loaded from file.
64  fModel(NULL)
65 {
66  // standard constructor for the RXGB
    // NOTE(review): the first line of this signature (the one declaring
    // jobName, which is forwarded to RMethodBase above) is not present in
    // this listing.
67 
68 }
69 
70 //_______________________________________________________________________
71 MethodRXGB::MethodRXGB(DataSetInfo &theData, const TString &theWeightFile)
72  : RMethodBase(Types::kRXGB, theData, theWeightFile),
73  fNRounds(10),
74  fEta(0.3),
75  fMaxDepth(6),
76  predict("predict", "xgboost"),
77  xgbtrain("xgboost"),
78  xgbdmatrix("xgb.DMatrix"),
79  xgbsave("xgb.save"),
80  xgbload("xgb.load"),
81  asfactor("as.factor"),
82  asmatrix("as.matrix"),
83  fModel(NULL)
84 {
85 
86 }
87 
88 
89 //_______________________________________________________________________
91 {
    // Body of the destructor (per the cross-reference index at the end of
    // this page; the signature line itself is not present in this listing).
    // Frees the model wrapper that other parts of this file allocate with
    // `new ROOT::R::TRObject(...)`, if one was ever created.
92  if (fModel) delete fModel;
93 }
94 
95 //_______________________________________________________________________
97 {
    // Reports which analyses this method supports: only classification with
    // exactly two classes yields kTRUE; every other combination yields kFALSE.
    // NOTE(review): the signature line is not present in this listing; the
    // names `type` and `numberClasses` are declared there.
98  if (type == Types::kClassification && numberClasses == 2) return kTRUE;
99  return kFALSE;
100 }
101 
102 
103 //_______________________________________________________________________
105 {
    // Initialisation: checks that the R package flag is set and derives the
    // numeric label vector from the string labels.
    // NOTE(review): the signature line is not present in this listing.
106 
    // Bail out if the IsModuleLoaded flag is false. The failure is reported
    // both through Error(...) and through the logger at kFATAL level.
107  if (!IsModuleLoaded) {
108  Error("Init", "R's package xgboost can not be loaded.");
109  Log() << kFATAL << " R's package xgboost can not be loaded."
110  << Endl;
111  return;
112  }
113  //factors creations
114  //xgboost require a numeric factor then background=0 signal=1 from fFactorTrain
    // fFactorNumeric gets one entry per entry of fFactorTrain.
115  UInt_t size = fFactorTrain.size();
116  fFactorNumeric.resize(size);
117 
    // Exactly the string "signal" maps to 1; any other label maps to 0.
118  for (UInt_t i = 0; i < size; i++) {
119  if (fFactorTrain[i] == "signal") fFactorNumeric[i] = 1;
120  else fFactorNumeric[i] = 0;
121  }
122 
123 
124 
125 }
126 
128 {
    // Training: calls the R training function with the training data, numeric
    // labels, event weights and hyper-parameters, stores the resulting model
    // in fModel and, when model persistence is enabled, saves it to disk.
    // NOTE(review): the signature line is not present in this listing.

    // Training on an empty sample is reported at kFATAL level.
129  if (Data()->GetNTrainingEvents() == 0) Log() << kFATAL << "<Train> Data() has zero events" << Endl;
    // NOTE(review): `dmatrix` is used in the xgbtrain call below, but its
    // definition is not visible here (the listing numbering jumps from 129 to
    // 131) -- confirm against the full source.

    // Hyper-parameters passed to the training call as its "params" argument.
131  ROOT::R::TRDataFrame params;
132  params["eta"] = fEta;
133  params["max.depth"] = fMaxDepth;
134 
135  SEXP Model = xgbtrain(ROOT::R::Label["data"] = dmatrix,
136  ROOT::R::Label["label"] = fFactorNumeric,
137  ROOT::R::Label["weight"] = fWeightTrain,
138  ROOT::R::Label["nrounds"] = fNRounds,
139  ROOT::R::Label["params"] = params);
140 
    // NOTE(review): fModel is overwritten without deleting a previous value;
    // if this can run more than once per object the old wrapper would leak --
    // confirm.
141  fModel = new ROOT::R::TRObject(Model);
142  if (IsModelPersistence())
143  {
    // The model is written to "<weight file dir>/<method name>.RData".
144  TString path = GetWeightFileDir() + "/" + GetName() + ".RData";
145  Log() << Endl;
146  Log() << gTools().Color("bold") << "--- Saving State File In:" << gTools().Color("reset") << path << Endl;
147  Log() << Endl;
148  xgbsave(Model, path);
149  }
150 }
151 
152 //_______________________________________________________________________
154 {
    // Registers the three configurable hyper-parameters under the option
    // names "NRounds", "Eta" and "MaxDepth". Each option is bound by reference
    // to the corresponding data member (fNRounds, fEta, fMaxDepth).
    // NOTE(review): the signature line is not present in this listing.
155  DeclareOptionRef(fNRounds, "NRounds", "The max number of iterations");
156  DeclareOptionRef(fEta, "Eta", "Step size shrinkage used in update to prevents overfitting. After each boosting step, we can directly get the weights of new features. and eta actually shrinks the feature weights to make the boosting process more conservative.");
157  DeclareOptionRef(fMaxDepth, "MaxDepth", "Maximum depth of the tree");
158 }
159 
160 //_______________________________________________________________________
162 {
    // Empty body: this method performs no work.
    // NOTE(review): the signature line is not present in this listing; the
    // cross-reference index at the end of the page attributes this body to
    // MethodRXGB::ProcessOptions.
163 }
164 
165 //_______________________________________________________________________
167 {
    // Logs an informational banner.
    // NOTE(review): the signature line and one statement after the log call
    // are not present in this listing (the numbering jumps from 168 to 170) --
    // consult the full source before editing.
168  Log() << kINFO << "Testing Classification RXGB METHOD " << Endl;
170 }
171 
172 
173 //_______________________________________________________________________
175 {
    // Evaluates the model on the current event and returns the prediction as
    // a Double_t. errLower/errUpper are handed to NoErrorCalc.
    // NOTE(review): the signature line is not present in this listing.
176  NoErrorCalc(errLower, errUpper);
177  Double_t mvaValue;
178  const TMVA::Event *ev = GetEvent();
179  const UInt_t nvar = DataInfo().GetNVariables();
    // Build a data frame for this event with one column per input variable,
    // keyed by the variable name.
180  ROOT::R::TRDataFrame fDfEvent;
181  for (UInt_t i = 0; i < nvar; i++) {
    // GetListOfVariables() is called on every iteration here.
182  fDfEvent[DataInfo().GetListOfVariables()[i].Data()] = ev->GetValues()[i];
183  }
184  //if using persistence model
    // NOTE(review): the statement this comment refers to is not present in
    // this listing (the numbering jumps from 184 to 186).
186 
    // fModel is dereferenced without a null check; it must have been set
    // before this point.
187  mvaValue = (Double_t)predict(*fModel, xgbdmatrix(ROOT::R::Label["data"] = asmatrix(fDfEvent)));
188  return mvaValue;
189 }
190 
191 ////////////////////////////////////////////////////////////////////////////////
192 /// get all the MVA values for the events of the current Data type
193 std::vector<Double_t> MethodRXGB::GetMvaValues(Long64_t firstEvt, Long64_t lastEvt, Bool_t logProgress)
194 {
195  Long64_t nEvents = Data()->GetNEvents();
196  if (firstEvt > lastEvt || lastEvt > nEvents) lastEvt = nEvents;
197  if (firstEvt < 0) firstEvt = 0;
198 
199  nEvents = lastEvt-firstEvt;
200 
201  UInt_t nvars = Data()->GetNVariables();
202 
203  // use timer
204  Timer timer( nEvents, GetName(), kTRUE );
205  if (logProgress)
206  Log() << kINFO<<Form("Dataset[%s] : ",DataInfo().GetName())<< "Evaluation of " << GetMethodName() << " on "
207  << (Data()->GetCurrentType()==Types::kTraining?"training":"testing") << " sample (" << nEvents << " events)" << Endl;
208 
209 
210  // fill R DATA FRAME with events data
211  std::vector<std::vector<Float_t> > inputData(nvars);
212  for (UInt_t i = 0; i < nvars; i++) {
213  inputData[i] = std::vector<Float_t>(nEvents);
214  }
215 
216  for (Int_t ievt=firstEvt; ievt<lastEvt; ievt++) {
217  Data()->SetCurrentEvent(ievt);
218  const TMVA::Event *e = Data()->GetEvent();
219  assert(nvars == e->GetNVariables());
220  for (UInt_t i = 0; i < nvars; i++) {
221  inputData[i][ievt] = e->GetValue(i);
222  }
223  // if (ievt%100 == 0)
224  // std::cout << "Event " << ievt << " type" << DataInfo().IsSignal(e) << " : " << pValue[ievt*nvars] << " " << pValue[ievt*nvars+1] << " " << pValue[ievt*nvars+2] << std::endl;
225  }
226 
227  ROOT::R::TRDataFrame evtData;
228  for (UInt_t i = 0; i < nvars; i++) {
229  evtData[DataInfo().GetListOfVariables()[i].Data()] = inputData[i];
230  }
231  //if using persistence model
233 
234  std::vector<Double_t> mvaValues(nEvents);
235  ROOT::R::TRObject pred = predict(*fModel, xgbdmatrix(ROOT::R::Label["data"] = asmatrix(evtData)));
236  mvaValues = pred.As<std::vector<Double_t>>();
237 
238  if (logProgress) {
239  Log() << kINFO <<Form("Dataset[%s] : ",DataInfo().GetName())<< "Elapsed time for evaluation of " << nEvents << " events: "
240  << timer.GetElapsedTime() << " " << Endl;
241  }
242 
243  return mvaValues;
244 
245 }
246 //_______________________________________________________________________
248 {
    // Writes the help text for this method to the logger as three headed
    // sections: short description, performance optimisation, and performance
    // tuning via configuration options. The headings are wrapped in the
    // "bold"/"reset" colour strings from gTools().
    // NOTE(review): the signature line is not present in this listing.
249 // get help message text
250 //
251 // typical length of text line:
252 // "|--------------------------------------------------------------|"
253  Log() << Endl;
254  Log() << gTools().Color("bold") << "--- Short description:" << gTools().Color("reset") << Endl;
255  Log() << Endl;
256  Log() << "Decision Trees and Rule-Based Models " << Endl;
257  Log() << Endl;
    // The "Performance optimisation" section has a heading but no body text.
258  Log() << gTools().Color("bold") << "--- Performance optimisation:" << gTools().Color("reset") << Endl;
259  Log() << Endl;
260  Log() << Endl;
261  Log() << gTools().Color("bold") << "--- Performance tuning via configuration options:" << gTools().Color("reset") << Endl;
262  Log() << Endl;
263  Log() << "<None>" << Endl;
264 }
265 
266 //_______________________________________________________________________
268 {
    // Loads a previously saved model from
    // "<weight file dir>/<method name>.RData" -- the same path pattern used
    // when the model is saved after training -- and stores it in fModel.
    // NOTE(review): the signature line and the first statement of the body
    // are not present in this listing (the numbering jumps from 268 to 270).
270  TString path = GetWeightFileDir() + "/" + GetName() + ".RData";
271  Log() << Endl;
272  Log() << gTools().Color("bold") << "--- Loading State File From:" << gTools().Color("reset") << path << Endl;
273  Log() << Endl;
274 
275  SEXP Model = xgbload(path);
    // NOTE(review): fModel is overwritten without deleting a previous value;
    // if a model is already attached the old wrapper would leak -- confirm.
276  fModel = new ROOT::R::TRObject(Model);
277 
278 }
279 
280 //_______________________________________________________________________
281 void TMVA::MethodRXGB::MakeClass(const TString &/*theClassFileName*/) const
282 {
283 }
TMVA::MethodBase::TestClassification
virtual void TestClassification()
initialization
Definition: MethodBase.cxx:1111
TMVA::MethodRXGB::predict
ROOT::R::TRFunctionImport predict
Definition: MethodRXGB.h:104
TMVA::DataSet::GetNVariables
UInt_t GetNVariables() const
access the number of variables through the datasetinfo
Definition: DataSet.cxx:216
kTRUE
const Bool_t kTRUE
Definition: RtypesCore.h:91
TMVA::Configurable::Log
MsgLogger & Log() const
Definition: Configurable.h:164
TMVA::DataSet::GetCurrentType
Types::ETreeType GetCurrentType() const
Definition: DataSet.h:194
e
#define e(i)
Definition: RSha256.hxx:121
TVectorD.h
TMVA::MethodBase::Data
DataSet * Data() const
Definition: MethodBase.h:408
TMVA::MethodBase::IsModelPersistence
Bool_t IsModelPersistence() const
Definition: MethodBase.h:382
ClassImp
#define ClassImp(name)
Definition: Rtypes.h:364
Form
char * Form(const char *fmt,...)
TMVA::MethodRXGB::ProcessOptions
void ProcessOptions()
Definition: MethodRXGB.cxx:161
Long64_t
long long Long64_t
Definition: RtypesCore.h:73
TMath::Log
Double_t Log(Double_t x)
Definition: TMath.h:749
TObject::Error
virtual void Error(const char *method, const char *msgfmt,...) const
Issue error message.
Definition: TObject.cxx:890
TMVA::Event::GetValues
std::vector< Float_t > & GetValues()
return value vector
Definition: Event.h:94
Ranking.h
TMVA::DataSetInfo::GetNVariables
UInt_t GetNVariables() const
Definition: DataSetInfo.h:127
TMVA::MethodRXGB::asmatrix
ROOT::R::TRFunctionImport asmatrix
Definition: MethodRXGB.h:110
VariableTransformBase.h
TMVA::MethodRXGB::GetHelpMessage
void GetHelpMessage() const
Definition: MethodRXGB.cxx:247
TMVA::MethodRXGB::MethodRXGB
MethodRXGB(const TString &jobName, const TString &methodTitle, DataSetInfo &theData, const TString &theOption="")
Definition: MethodRXGB.cxx:50
TString
Definition: TString.h:136
Bool_t
bool Bool_t
Definition: RtypesCore.h:63
ROOT::R::TRObject
This is a class to get ROOT's objects from R's objects.
Definition: TRObject.h:83
ROOT::R::TRInterface::Require
Bool_t Require(TString pkg)
Method to load an R's package.
Definition: TRInterface.cxx:198
REGISTER_METHOD
#define REGISTER_METHOD(CLASS)
for example
Definition: ClassifierFactory.h:124
bool
TMatrix.h
TMVA::MethodRXGB::fFactorNumeric
std::vector< UInt_t > fFactorNumeric
Definition: MethodRXGB.h:101
PDF.h
TMVA::MethodBase::DataInfo
DataSetInfo & DataInfo() const
Definition: MethodBase.h:409
TMVA::MethodRXGB::DeclareOptions
void DeclareOptions()
Definition: MethodRXGB.cxx:153
TMVA::MethodRXGB::~MethodRXGB
~MethodRXGB(void)
Definition: MethodRXGB.cxx:90
TMVA::MethodRXGB::GetMvaValues
virtual std::vector< Double_t > GetMvaValues(Long64_t firstEvt=0, Long64_t lastEvt=-1, Bool_t logProgress=false)
get all the MVA values for the events of the current Data type
Definition: MethodRXGB.cxx:193
TMVA::MethodRXGB::HasAnalysisType
Bool_t HasAnalysisType(Types::EAnalysisType type, UInt_t numberClasses, UInt_t numberTargets)
Definition: MethodRXGB.cxx:96
TMVA::DataSetInfo
Definition: DataSetInfo.h:62
TMVA::Timer::GetElapsedTime
TString GetElapsedTime(Bool_t Scientific=kTRUE)
returns pretty string with elapsed time
Definition: Timer.cxx:146
TMVA::MethodBase::GetMethodName
const TString & GetMethodName() const
Definition: MethodBase.h:330
Timer.h
TMVA::Types::EAnalysisType
EAnalysisType
Definition: Types.h:150
TMVA::MethodRXGB::fModel
ROOT::R::TRObject * fModel
Definition: MethodRXGB.h:111
TMVA::DataSet::GetEvent
const Event * GetEvent() const
Definition: DataSet.cxx:202
TMVA::DataSet::GetNEvents
Long64_t GetNEvents(Types::ETreeType type=Types::kMaxTreeType) const
Definition: DataSet.h:206
kFALSE
const Bool_t kFALSE
Definition: RtypesCore.h:92
TMVA::MethodRXGB::fMaxDepth
UInt_t fMaxDepth
Definition: MethodRXGB.h:98
TMVA::MethodRXGB::MakeClass
virtual void MakeClass(const TString &classFileName=TString("")) const
create reader class for method (classification only at present)
Definition: MethodRXGB.cxx:281
TMVA::Types::kClassification
@ kClassification
Definition: Types.h:151
TMVA::RMethodBase
Definition: RMethodBase.h:48
TMVA::MethodBase::NoErrorCalc
void NoErrorCalc(Double_t *const err, Double_t *const errUpper)
Definition: MethodBase.cxx:836
TMVA::RMethodBase::fDfTrain
ROOT::R::TRDataFrame fDfTrain
Definition: RMethodBase.h:91
TMVA::MethodRXGB::GetMvaValue
Double_t GetMvaValue(Double_t *errLower=0, Double_t *errUpper=0)
Definition: MethodRXGB.cxx:174
TMVA::MethodBase::ReadStateFromFile
void ReadStateFromFile()
Function to write options and weights to file.
Definition: MethodBase.cxx:1412
TMVA::MethodBase::GetWeightFileDir
const TString & GetWeightFileDir() const
Definition: MethodBase.h:490
TMVA::MethodRXGB::ReadModelFromFile
void ReadModelFromFile()
Definition: MethodRXGB.cxx:267
TMVA::MethodRXGB::IsModuleLoaded
static Bool_t IsModuleLoaded
Definition: MethodRXGB.h:99
TMVA::Types
Definition: Types.h:96
Types.h
TMVA::MethodRXGB::fNRounds
UInt_t fNRounds
Definition: MethodRXGB.h:96
TMVA::Endl
MsgLogger & Endl(MsgLogger &ml)
Definition: MsgLogger.h:182
Config.h
unsigned int
TMVA::Timer
Definition: Timer.h:80
TMVA::Tools::Color
const TString & Color(const TString &)
human readable color strings
Definition: Tools.cxx:840
TMVA::Types::kTraining
@ kTraining
Definition: Types.h:167
TMVA::MethodRXGB::xgbsave
ROOT::R::TRFunctionImport xgbsave
Definition: MethodRXGB.h:107
TMVA::MethodRXGB::fEta
Double_t fEta
Definition: MethodRXGB.h:97
TMVA::RMethodBase::fWeightTrain
TVectorD fWeightTrain
Definition: RMethodBase.h:93
TMVA::DataSet::SetCurrentEvent
void SetCurrentEvent(Long64_t ievt) const
Definition: DataSet.h:111
Double_t
double Double_t
Definition: RtypesCore.h:59
ROOT::R::TRObject::As
T As()
Some datatypes of ROOT or c++ can be wrapped in to a TRObject, this method lets you unwrap those data...
Definition: TRObject.h:171
TMVA::MethodBase::GetName
const char * GetName() const
Definition: MethodBase.h:333
TMVA::Event
Definition: Event.h:51
TMVA::MethodBase::GetEvent
const Event * GetEvent() const
Definition: MethodBase.h:749
TMVA::MethodRXGB::xgbdmatrix
ROOT::R::TRFunctionImport xgbdmatrix
Definition: MethodRXGB.h:106
ROOT::R::TRInterface::Instance
static TRInterface & Instance()
static method to get an TRInterface instance reference
Definition: TRInterface.cxx:185
ROOT::R::Label
const Rcpp::internal::NamedPlaceHolder & Label
TMVA::MethodRXGB::TestClassification
virtual void TestClassification()
initialization
Definition: MethodRXGB.cxx:166
TMVA::MethodRXGB::xgbtrain
ROOT::R::TRFunctionImport xgbtrain
Definition: MethodRXGB.h:105
TMVA::MethodRXGB
Definition: MethodRXGB.h:41
Tools.h
ClassifierFactory.h
type
int type
Definition: TGX11.cxx:121
TMatrixD.h
Results.h
Riostream.h
TMVA::gTools
Tools & gTools()
TMVA::Configurable::DeclareOptionRef
OptionBase * DeclareOptionRef(T &ref, const TString &name, const TString &desc="")
TMVA::RMethodBase::fFactorTrain
std::vector< std::string > fFactorTrain
Definition: RMethodBase.h:95
TMVA::MethodRXGB::Init
void Init()
Definition: MethodRXGB.cxx:104
ROOT::R::TRDataFrame
Definition: TRDataFrame.h:189
TMVA::MethodRXGB::Train
void Train()
Definition: MethodRXGB.cxx:127
MethodRXGB.h
TMath.h
TMVA::DataSetInfo::GetListOfVariables
std::vector< TString > GetListOfVariables() const
returns list of variables
Definition: DataSetInfo.cxx:393
TMVA
create variable transformations
Definition: GeneticMinimizer.h:22
int