Logo ROOT  
Reference Guide
 
Loading...
Searching...
No Matches
RuleFit.h
Go to the documentation of this file.
1// @(#)root/tmva $Id$
2// Author: Andreas Hoecker, Joerg Stelzer, Fredrik Tegenfeldt, Helge Voss
3
4/**********************************************************************************
5 * Project: TMVA - a Root-integrated toolkit for multivariate data analysis *
6 * Package: TMVA *
7 * Class : RuleFit *
8 * *
9 * *
10 * Description: *
11 * A class implementing various fits of rule ensembles *
12 * *
13 * Authors (alphabetical): *
14 * Fredrik Tegenfeldt <Fredrik.Tegenfeldt@cern.ch> - Iowa State U., USA *
15 * Helge Voss <Helge.Voss@cern.ch> - MPI-KP Heidelberg, Ger. *
16 * *
17 * Copyright (c) 2005: *
18 * CERN, Switzerland *
19 * Iowa State U. *
20 * MPI-K Heidelberg, Germany *
21 * *
22 * Redistribution and use in source and binary forms, with or without *
23 * modification, are permitted according to the terms listed in LICENSE *
24 * (see tmva/doc/LICENSE) *
25 **********************************************************************************/
26
27#ifndef ROOT_TMVA_RuleFit
28#define ROOT_TMVA_RuleFit
29
30#include "TMVA/DecisionTree.h"
31#include "TMVA/RuleEnsemble.h"
32#include "TMVA/RuleFitParams.h"
33#include "TMVA/Event.h"
34
35#include <algorithm>
36#include <random>
37#include <vector>
38
39namespace TMVA {
40
41
42 class MethodBase;
43 class MethodRuleFit;
44 class MsgLogger;
45
46 class RuleFit {
47
48 public:
49
50 // main constructor
51 RuleFit( const TMVA::MethodBase *rfbase );
52
53 // empty constructor
54 RuleFit( void );
55
56 virtual ~RuleFit( void );
57
58 void InitNEveEff();
59 void InitPtrs( const TMVA::MethodBase *rfbase );
60 void Initialize( const TMVA::MethodBase *rfbase );
61
62 void SetMsgType( EMsgType t );
63
64 void SetTrainingEvents( const std::vector<const TMVA::Event *> & el );
65
67 {
68 std::shuffle(fTrainingEventsRndm.begin(), fTrainingEventsRndm.end(), fRNGEngine);
69 }
70
71 void SetMethodBase( const MethodBase *rfbase );
72
73 // make the forest of trees for rule generation
74 void MakeForest();
75
76 // build a tree
77 void BuildTree( TMVA::DecisionTree *dt );
78
79 // save event weights
80 void SaveEventWeights();
81
82 // restore saved event weights
84
85 // boost events based on the given tree
86 void Boost( TMVA::DecisionTree *dt );
87
88 // calculate and print some statistics on the given forest
89 void ForestStatistics();
90
91 // calculate the discriminating variable for the given event
92 Double_t EvalEvent( const Event& e );
93
94 // calculate sum of
95 Double_t CalcWeightSum( const std::vector<const TMVA::Event *> *events, UInt_t neve=0 );
96
97 // do the fitting of the coefficients
98 void FitCoefficients();
99
100 // calculate variable and rule importance from a set of events
101 void CalcImportance();
102
103 // set usage of linear term
105 // set usage of rules
107 // set usage of linear term
109 // set minimum importance allowed
111 // set minimum rule distance - see RuleEnsemble
113 // set path related parameters
117 // make visualization histograms
121 void MakeVisHists();
122 void FillVisHistCut(const Rule * rule, std::vector<TH2F *> & hlist);
123 void FillVisHistCorr(const Rule * rule, std::vector<TH2F *> & hlist);
124 void FillCut(TH2F* h2,const TMVA::Rule *rule,Int_t vind);
125 void FillLin(TH2F* h2,Int_t vind);
126 void FillCorr(TH2F* h2,const TMVA::Rule *rule,Int_t v1, Int_t v2);
127 void NormVisHists(std::vector<TH2F *> & hlist);
128 void MakeDebugHists();
129 Bool_t GetCorrVars(TString & title, TString & var1, TString & var2);
130 // accessors
132 Double_t GetNEveEff() const { return fNEveEffTrain; } // reweighted number of events = sum(wi)
133 const Event* GetTrainingEvent(UInt_t i) const { return static_cast< const Event *>(fTrainingEvents[i]); }
134 Double_t GetTrainingEventWeight(UInt_t i) const { return fTrainingEvents[i]->GetWeight(); }
135
136 // const Event* GetTrainingEvent(UInt_t i, UInt_t isub) const { return &(fTrainingEvents[fSubsampleEvents[isub]])[i]; }
137
138 const std::vector< const TMVA::Event * > & GetTrainingEvents() const { return fTrainingEvents; }
139 // const std::vector< Int_t > & GetSubsampleEvents() const { return fSubsampleEvents; }
140
141 // void GetSubsampleEvents(Int_t sub, UInt_t & ibeg, UInt_t & iend) const;
142 void GetRndmSampleEvents(std::vector< const TMVA::Event * > & evevec, UInt_t nevents);
143 //
144 const std::vector< const TMVA::DecisionTree *> & GetForest() const { return fForest; }
145 const RuleEnsemble & GetRuleEnsemble() const { return fRuleEnsemble; }
150 const MethodBase * GetMethodBase() const { return fMethodBase; }
151
152 private:
153
154 // copy constructor
155 RuleFit( const RuleFit & other );
156
157 // copy method
158 void Copy( const RuleFit & other );
159
160 std::vector<const TMVA::Event *> fTrainingEvents; ///< all training events
161 std::vector<const TMVA::Event *> fTrainingEventsRndm; ///< idem, but randomly shuffled
162 std::vector<Double_t> fEventWeights; ///< original weights of the events - follows fTrainingEvents
163 UInt_t fNTreeSample; ///< number of events in sub sample = frac*neve
164
165 Double_t fNEveEffTrain; ///< reweighted number of events = sum(wi)
166 std::vector< const TMVA::DecisionTree *> fForest; ///< the input forest of decision trees
167 RuleEnsemble fRuleEnsemble; ///< the ensemble of rules
168 RuleFitParams fRuleFitParams; ///< fit rule parameters
169 const MethodRuleFit *fMethodRuleFit; ///< pointer the method which initialized this RuleFit instance
170 const MethodBase *fMethodBase; ///< pointer the method base which initialized this RuleFit instance
171 Bool_t fVisHistsUseImp; ///< if true, use importance as weight; else coef in vis hists
172
173 mutable MsgLogger* fLogger; ///<! message logger
174 MsgLogger& Log() const { return *fLogger; }
175
176 static const Int_t randSEED = 0; // set to 1 for debugging purposes or to zero for random seeds
177 std::default_random_engine fRNGEngine;
178
179 ClassDef(RuleFit,0); // Calculations for Friedman's RuleFit method
180 };
181}
182
183#endif
#define d(i)
Definition RSha256.hxx:102
#define f(i)
Definition RSha256.hxx:104
#define e(i)
Definition RSha256.hxx:103
bool Bool_t
Definition RtypesCore.h:63
unsigned int UInt_t
Definition RtypesCore.h:46
constexpr Bool_t kFALSE
Definition RtypesCore.h:94
double Double_t
Definition RtypesCore.h:59
constexpr Bool_t kTRUE
Definition RtypesCore.h:93
#define ClassDef(name, id)
Definition Rtypes.h:342
2-D histogram with a float per channel (see TH1 documentation)
Definition TH2.h:308
Implementation of a Decision Tree.
Virtual base Class for all MVA method.
Definition MethodBase.h:111
J Friedman's RuleFit method.
ostringstream derivative to redirect and format output
Definition MsgLogger.h:57
void SetRuleMinDist(Double_t d)
void SetImportanceCut(Double_t minimp=0)
A class doing the actual fitting of a linear model using rules as base functions.
void SetGDPathStep(Double_t s)
void SetGDTau(Double_t t)
void SetGDNPathSteps(Int_t np)
A class implementing various fits of rule ensembles.
Definition RuleFit.h:46
void GetRndmSampleEvents(std::vector< const TMVA::Event * > &evevec, UInt_t nevents)
draw a random subsample of the training events without replacement
Definition RuleFit.cxx:456
Double_t EvalEvent(const Event &e)
evaluate single event
Definition RuleFit.cxx:421
UInt_t fNTreeSample
number of events in sub sample = frac*neve
Definition RuleFit.h:163
void SetMethodBase(const MethodBase *rfbase)
set MethodBase
Definition RuleFit.cxx:150
Double_t GetNEveEff() const
Definition RuleFit.h:132
void InitPtrs(const TMVA::MethodBase *rfbase)
initialize pointers
Definition RuleFit.cxx:109
void Boost(TMVA::DecisionTree *dt)
Boost the events.
Definition RuleFit.cxx:328
RuleEnsemble * GetRuleEnsemblePtr()
Definition RuleFit.h:146
Bool_t fVisHistsUseImp
if true, use importance as weight; else coef in vis hists
Definition RuleFit.h:171
const RuleFitParams & GetRuleFitParams() const
Definition RuleFit.h:147
void ForestStatistics()
summary of statistics of all trees
Definition RuleFit.cxx:375
static const Int_t randSEED
Definition RuleFit.h:176
void CalcImportance()
calculates the importance of each rule
Definition RuleFit.cxx:407
RuleFitParams * GetRuleFitParamsPtr()
Definition RuleFit.h:148
void SetMsgType(EMsgType t)
set the current message type to t for this class and all other sub-tools
Definition RuleFit.cxx:190
void Initialize(const TMVA::MethodBase *rfbase)
initialize the parameters of the RuleFit method and make rules
Definition RuleFit.cxx:119
std::vector< const TMVA::Event * > fTrainingEventsRndm
idem, but randomly shuffled
Definition RuleFit.h:161
virtual ~RuleFit(void)
destructor
Definition RuleFit.cxx:89
void FillVisHistCorr(const Rule *rule, std::vector< TH2F * > &hlist)
helper routine for MakeVisHists() - fills all correlation plots
Definition RuleFit.cxx:704
std::default_random_engine fRNGEngine
Definition RuleFit.h:177
void InitNEveEff()
init effective number of events (using event weights)
Definition RuleFit.cxx:97
UInt_t GetNTreeSample() const
Definition RuleFit.h:131
std::vector< const TMVA::DecisionTree * > fForest
the input forest of decision trees
Definition RuleFit.h:166
void SetGDPathStep(Double_t s=0.01)
Definition RuleFit.h:115
MsgLogger & Log() const
Definition RuleFit.h:174
const MethodBase * fMethodBase
pointer to the method base which initialized this RuleFit instance
Definition RuleFit.h:170
Double_t GetTrainingEventWeight(UInt_t i) const
Definition RuleFit.h:134
std::vector< const TMVA::Event * > fTrainingEvents
all training events
Definition RuleFit.h:160
void SaveEventWeights()
save event weights - must be done before making the forest
Definition RuleFit.cxx:298
void FillCut(TH2F *h2, const TMVA::Rule *rule, Int_t vind)
Fill cut.
Definition RuleFit.cxx:522
void FillLin(TH2F *h2, Int_t vind)
fill lin
Definition RuleFit.cxx:573
Bool_t GetCorrVars(TString &title, TString &var1, TString &var2)
get first and second variables from title
Definition RuleFit.cxx:743
void UseCoefficientsVisHists()
Definition RuleFit.h:120
const Event * GetTrainingEvent(UInt_t i) const
Definition RuleFit.h:133
void MakeForest()
make a forest of decision trees
Definition RuleFit.cxx:221
const std::vector< const TMVA::DecisionTree * > & GetForest() const
Definition RuleFit.h:144
void FitCoefficients()
Fit the coefficients for the rule ensemble.
Definition RuleFit.cxx:398
void SetRuleMinDist(Double_t d)
Definition RuleFit.h:112
void SetModelRules()
Definition RuleFit.h:106
RuleFit(const RuleFit &other)
const MethodRuleFit * fMethodRuleFit
pointer to the method which initialized this RuleFit instance
Definition RuleFit.h:169
const MethodBase * GetMethodBase() const
Definition RuleFit.h:150
void FillCorr(TH2F *h2, const TMVA::Rule *rule, Int_t v1, Int_t v2)
fill rule correlation between vx and vy, weighted with either the importance or the coefficient
Definition RuleFit.cxx:597
void NormVisHists(std::vector< TH2F * > &hlist)
normalize rule importance hists
Definition RuleFit.cxx:475
void SetGDNPathSteps(Int_t n=100)
Definition RuleFit.h:116
void RestoreEventWeights()
restore the event weights previously saved by SaveEventWeights()
Definition RuleFit.cxx:310
RuleFitParams fRuleFitParams
fit rule parameters
Definition RuleFit.h:168
void SetVisHistsUseImp(Bool_t f)
Definition RuleFit.h:118
void SetModelFull()
Definition RuleFit.h:108
void MakeVisHists()
this will create histograms visualizing the rule ensemble
Definition RuleFit.cxx:766
void FillVisHistCut(const Rule *rule, std::vector< TH2F * > &hlist)
helper routine for MakeVisHists() - fills the plots for all variables
Definition RuleFit.cxx:673
std::vector< Double_t > fEventWeights
original weights of the events - follows fTrainingEvents
Definition RuleFit.h:162
void BuildTree(TMVA::DecisionTree *dt)
build the decision tree using fNTreeSample events from fTrainingEventsRndm
Definition RuleFit.cxx:200
const std::vector< const TMVA::Event * > & GetTrainingEvents() const
Definition RuleFit.h:138
void SetGDTau(Double_t t=0.0)
Definition RuleFit.h:114
const MethodRuleFit * GetMethodRuleFit() const
Definition RuleFit.h:149
void UseImportanceVisHists()
Definition RuleFit.h:119
void SetTrainingEvents(const std::vector< const TMVA::Event * > &el)
set the training events randomly
Definition RuleFit.cxx:429
void ReshuffleEvents()
Definition RuleFit.h:66
Double_t fNEveEffTrain
reweighted number of events = sum(wi)
Definition RuleFit.h:165
void SetModelLinear()
Definition RuleFit.h:104
void Copy(const RuleFit &other)
copy method
Definition RuleFit.cxx:159
RuleEnsemble fRuleEnsemble
the ensemble of rules
Definition RuleFit.h:167
const RuleEnsemble & GetRuleEnsemble() const
Definition RuleFit.h:145
void SetImportanceCut(Double_t minimp=0)
Definition RuleFit.h:110
Double_t CalcWeightSum(const std::vector< const TMVA::Event * > *events, UInt_t neve=0)
calculate the sum of weights
Definition RuleFit.cxx:175
RuleFit(void)
default constructor
Definition RuleFit.cxx:75
MsgLogger * fLogger
! message logger
Definition RuleFit.h:173
void MakeDebugHists()
this will create histograms intended mainly for debugging or for the curious user
Definition RuleFit.cxx:926
Implementation of a rule.
Definition Rule.h:50
Basic string class.
Definition TString.h:139
const Int_t n
Definition legend1.C:16
create variable transformations