Logo ROOT  
Reference Guide
 
Loading...
Searching...
No Matches
DecisionTree.h
Go to the documentation of this file.
1// @(#)root/tmva $Id$
2// Author: Andreas Hoecker, Joerg Stelzer, Helge Voss, Kai Voss, Jan Therhaag, Eckhard von Toerne
3
4/**********************************************************************************
5 * Project: TMVA - a Root-integrated toolkit for multivariate data analysis *
6 * Package: TMVA *
7 * Class : DecisionTree *
8 * *
9 * *
10 * Description: *
11 * Implementation of a Decision Tree *
12 * *
13 * Authors (alphabetical): *
14 * Andreas Hoecker <Andreas.Hocker@cern.ch> - CERN, Switzerland *
15 * Helge Voss <Helge.Voss@cern.ch> - MPI-K Heidelberg, Germany *
16 * Kai Voss <Kai.Voss@cern.ch> - U. of Victoria, Canada *
17 * Jan Therhaag <Jan.Therhaag@cern.ch> - U of Bonn, Germany *
18 * Eckhard v. Toerne <evt@uni-bonn.de> - U of Bonn, Germany *
19 * *
20 * Copyright (c) 2005-2011: *
21 * CERN, Switzerland *
22 * U. of Victoria, Canada *
23 * MPI-K Heidelberg, Germany *
24 * U. of Bonn, Germany *
25 * *
26 * Redistribution and use in source and binary forms, with or without *
27 * modification, are permitted according to the terms listed in LICENSE *
28 * (http://mva.sourceforge.net/license.txt) *
29 * *
30 **********************************************************************************/
31
32#ifndef ROOT_TMVA_DecisionTree
33#define ROOT_TMVA_DecisionTree
34
35//////////////////////////////////////////////////////////////////////////
36// //
37// DecisionTree //
38// //
39// Implementation of a Decision Tree //
40// //
41//////////////////////////////////////////////////////////////////////////
42
43#include "TH2.h"
44#include <vector>
45
46#include "TMVA/Types.h"
48#include "TMVA/BinaryTree.h"
50#include "TMVA/SeparationBase.h"
52#include "TMVA/DataSetInfo.h"
53
54#ifdef R__USE_IMT
56#include "TSystem.h"
57#endif
58
59class TRandom3;
60
// Namespace content of DecisionTree.h: declares TMVA::DecisionTree, a
// BinaryTree subclass that is grown by recursively splitting nodes
// (BuildTree / TrainNode / TrainNodeFast / TrainNodeFull), can be filled with
// events (FillTree / FillEvent), evaluated (CheckEvent) and pruned
// (PruneTree and the validation-sample helpers).
//
// NOTE(review): this listing was extracted from a rendered documentation page
// and several source lines are missing (gaps in the embedded line numbers:
// 89, 122, 127, 132, 139-140, 146-147, 161-162, 174, 178, 180, 186, 188-193,
// 241). As shown it does not compile. The following are absent:
//   - the copy-constructor declaration;
//   - the end of the FillEvent declaration;
//   - the declaration of the EPruneMethod type used by fPruneMethod;
//   - several setters/getters.
// Restore them from the original header.
// NOTE(review): DecisionTreeNode and RegressionVariance are used below, but
// their headers do not appear in the visible include list. They were
// presumably on the dropped include lines; confirm.
61namespace TMVA {
62
63 class Event;
64
65 class DecisionTree : public BinaryTree {
66
67 private:
68
69 static const Int_t fgRandomSeed; // set nonzero for debugging and zero for random seeds
70
71 public:
72
73 typedef std::vector<TMVA::Event*> EventList;
74 typedef std::vector<const TMVA::Event*> EventConstList;
75
76 // the constructor needed for the "reading" of the decision tree from weight files
77 DecisionTree( void );
78
79 // the constructor needed for constructing the decision tree via training with events
// NOTE: nMaxDepth defaults to 9999999 (effectively no depth limit) and iSeed
// defaults to fgRandomSeed.
80 DecisionTree( SeparationBase *sepType, Float_t minSize,
81 Int_t nCuts, DataSetInfo* = nullptr,
82 UInt_t cls =0,
83 Bool_t randomisedTree=kFALSE, Int_t useNvars=0, Bool_t usePoissonNvars=kFALSE,
84 UInt_t nMaxDepth=9999999,
85 Int_t iSeed=fgRandomSeed, Float_t purityLimit=0.5,
86 Int_t treeID = 0);
87
88 // copy constructor
// NOTE(review): the copy-constructor declaration (source line 89) is missing from this listing.
90
91 virtual ~DecisionTree( void );
92
93 // Retrieves the address of the root node
94 virtual DecisionTreeNode* GetRoot() const { return static_cast<TMVA::DecisionTreeNode*>(fRoot); }
95 virtual DecisionTreeNode * CreateNode(UInt_t) const { return new DecisionTreeNode(); }
96 virtual BinaryTree* CreateTree() const { return new DecisionTree(); }
97 static DecisionTree* CreateFromXML(void* node, UInt_t tmva_Version_Code = TMVA_VERSION_CODE);
98 virtual const char* ClassName() const { return "DecisionTree"; }
99
100 // building of a tree by recursively splitting the nodes
101
102 // UInt_t BuildTree( const EventList & eventSample,
103 // DecisionTreeNode *node = nullptr);
104 UInt_t BuildTree( const EventConstList & eventSample,
105 DecisionTreeNode *node = nullptr);
106 // determine how a node is split (which variable, which cut value);
// TrainNode simply forwards to TrainNodeFast
107
108 Double_t TrainNode( const EventConstList & eventSample, DecisionTreeNode *node ) { return TrainNodeFast( eventSample, node ); }
109 Double_t TrainNodeFast( const EventConstList & eventSample, DecisionTreeNode *node );
110 Double_t TrainNodeFull( const EventConstList & eventSample, DecisionTreeNode *node );
111 void GetRandomisedVariables(Bool_t *useVariable, UInt_t *variableMap, UInt_t & nVars);
112 std::vector<Double_t> GetFisherCoefficients(const EventConstList &eventSample, UInt_t nFisherVars, UInt_t *mapVarInFisher);
113
114 // fill a tree that already has a given structure (just see how many signal/background
115 // events end up in each node)
116
117 void FillTree( const EventList & eventSample);
118
119 // fill the existing decision tree structure by filling each event
120 // in from the top node and seeing where it ends up
// NOTE(review): the second line of this declaration (source line 122) is missing here;
// presumably it completes it as FillEvent(const TMVA::Event &event, TMVA::DecisionTreeNode *node).
// Confirm against the original header.
121 void FillEvent( const TMVA::Event & event,
123
124 // returns: 1 = Signal (right), -1 = Bkg (left)
125
126 Double_t CheckEvent( const TMVA::Event * , Bool_t UseYesNoLeaf = kFALSE ) const;
128
129 // return the individual relative variable importance
130 std::vector< Double_t > GetVariableImportance();
131
133
134 // clear the tree nodes (their S/N, Nevents etc), just keep the structure of the tree
135
136 void ClearTree();
137
138 // set pruning method
// NOTE(review): the declarations for this comment (source lines 139-140) are missing here.
// They include the EPruneMethod type used by fPruneMethod and, presumably, SetPruneMethod.
141
142 // recursive pruning of the tree, validation sample required for automatic pruning
143 Double_t PruneTree( const EventConstList* validationSample = nullptr );
144
145 // manage the pruning strength parameter (iff < 0 -> automate the pruning process)
// NOTE(review): the prune-strength setter/getter (source lines 146-147) are missing here.
148
149 // apply pruning validation sample to a decision tree
150 void ApplyValidationSample( const EventConstList* validationSample ) const;
151
152 // return the misclassification rate of a pruned tree
153 Double_t TestPrunedTreeQuality( const DecisionTreeNode* dt = nullptr, Int_t mode = 0 ) const;
154
155 // pass a single validation event through a pruned decision tree
156 void CheckEventWithPrunedTree( const TMVA::Event* ) const;
157
158 // calculate the normalization factor for a pruning validation sample
159 Double_t GetSumWeights( const EventConstList* validationSample ) const;
160
163
164 void DescendTree( Node *n = nullptr );
165 void SetParentTreeInNodes( Node *n = nullptr );
166
167 // retrieve node from the tree. Its position (up to a maximal tree depth of 64)
168 // is coded as a sequence of left-right moves starting from the root, coded as
169 // 0-1 bit patterns stored in the "long-integer" together with the depth
170 Node* GetNode( ULong_t sequence, UInt_t depth );
171
172 UInt_t CleanTree(DecisionTreeNode *node = nullptr);
173
175
176 // prune a node from the tree without deleting its descendants; allows one to
177 // effectively prune a tree many times without making deep copies
// NOTE(review): the declaration this comment describes (source line 178) is missing here.
179
181
182
183 UInt_t CountLeafNodes(TMVA::Node *n = nullptr);
184
185 void SetTreeID(Int_t treeID){fTreeID = treeID;};
187
194 inline void SetNVars(Int_t n){fNvars = n;}
195
196 private:
197 // utility functions
198
199 // calculate the Purity out of the number of sig and bkg events collected
200 // from individual samples.
201
202 // calculates the purity S/(S+B) of a given event sample
203 Double_t SamplePurity(EventList eventSample);
204
205 UInt_t fNvars; ///< number of variables used to separate S and B
206 Int_t fNCuts; ///< number of grid points in variable cut scans
207 Bool_t fUseFisherCuts; ///< use multivariate splits using the Fisher criterion
208 Double_t fMinLinCorrForFisher; ///< the minimum linear correlation between two variables required for use in the Fisher criterion in node splitting
209 Bool_t fUseExclusiveVars; ///< individual variables already used in the Fisher criterion are no longer analysed individually for node splitting
210
211 SeparationBase *fSepType; ///< the separation criterion
212 RegressionVariance *fRegType; ///< the separation criterion used in regression
213
214 Double_t fMinSize; ///< min number of events in node
215 Double_t fMinNodeSize; ///< min fraction of training events in node
216 Double_t fMinSepGain; ///< minimum separation gain required to perform node splitting
217
218 Bool_t fUseSearchTree; ///< cut scan done with binary trees or simple event loop.
219 Double_t fPruneStrength; ///< parameter setting the "amount" of pruning; needs to be adjusted
220
221 EPruneMethod fPruneMethod; ///< method used for pruning
222 Int_t fNNodesBeforePruning; ///< number of nodes before pruning, kept so the tree size before/after pruning can be monitored
223
224 Double_t fNodePurityLimit; ///< purity limit to decide whether a node is signal
225
226 Bool_t fRandomisedTree; ///< choose at each node splitting a random set of variables
227 Int_t fUseNvars; ///< the number of variables used in randomised trees;
228 Bool_t fUsePoissonNvars; ///< use "fUseNvars" not as a fixed number but as the mean of a Poisson distribution in each split
229
230 TRandom3 *fMyTrandom; ///< random number generator for randomised trees
231
232 std::vector< Double_t > fVariableImportance; ///< the relative importance of the different variables
233
234 UInt_t fMaxDepth; ///< max depth
235 UInt_t fSigClass; ///< class which is treated as signal when building the tree
236 static const Int_t fgDebugLevel = 0; ///< debug level determining some printout/control plots etc.
237 Int_t fTreeID; ///< ID number given to the tree; makes debugging easier since the tree knows its own ID.
238
239 Types::EAnalysisType fAnalysisType; ///< kClassification(=0=false) or kRegression(=1=true)
240
// NOTE(review): source line 241 is missing here (the page index also lists a fDataSetInfo member).
242
243 ClassDef(DecisionTree,0); // implementation of a Decision Tree
244 };
245
246} // namespace TMVA
247
248#endif
#define d(i)
Definition RSha256.hxx:102
#define e(i)
Definition RSha256.hxx:103
bool Bool_t
Definition RtypesCore.h:63
unsigned long ULong_t
Definition RtypesCore.h:55
unsigned int UInt_t
Definition RtypesCore.h:46
float Float_t
Definition RtypesCore.h:57
constexpr Bool_t kFALSE
Definition RtypesCore.h:94
double Double_t
Definition RtypesCore.h:59
constexpr Bool_t kTRUE
Definition RtypesCore.h:93
#define ClassDef(name, id)
Definition Rtypes.h:337
winID h TVirtualViewer3D TVirtualGLPainter p
Option_t Option_t TPoint TPoint const char mode
#define TMVA_VERSION_CODE
Definition Version.h:47
Base class for BinarySearch and Decision Trees.
Definition BinaryTree.h:62
Node * fRoot
the root node of the tree; the tree only holds its root node, the "daughters" are taken care of by the ...
Definition BinaryTree.h:110
UInt_t GetNNodes() const
Definition BinaryTree.h:86
Class that contains all the data information.
Definition DataSetInfo.h:62
Implementation of a Decision Tree.
Bool_t DoRegression() const
Int_t fNNodesBeforePruning
remember this one (in case of pruning, it allows to monitor the before/after
void SetAnalysisType(Types::EAnalysisType t)
Double_t fMinSize
min number of events in node
void SetUseExclusiveVars(Bool_t t=kTRUE)
UInt_t BuildTree(const EventConstList &eventSample, DecisionTreeNode *node=nullptr)
building the decision tree by recursively calling the splitting of one (root-) node into two daughter...
Double_t GetNodePurityLimit() const
void FillTree(const EventList &eventSample)
fill the existing decision tree structure by filling events in from the top node and seeing where the...
Double_t fMinNodeSize
min fraction of training events in node
void PruneNode(TMVA::DecisionTreeNode *node)
prune away the subtree below the node
TRandom3 * fMyTrandom
random number generator for randomised trees
void SetPruneMethod(EPruneMethod m=kCostComplexityPruning)
Int_t fTreeID
just an ID number given to the tree; makes debugging easier since the tree knows its own ID.
void ApplyValidationSample(const EventConstList *validationSample) const
run the validation sample through the (pruned) tree and fill in the nodes the variables NSValidation ...
Double_t TrainNodeFull(const EventConstList &eventSample, DecisionTreeNode *node)
train a node by finding the single optimal cut for a single variable that best separates signal and b...
EPruneMethod fPruneMethod
method used for pruning
Bool_t fUseSearchTree
cut scan done with binary trees or simple event loop.
std::vector< TMVA::Event * > EventList
virtual DecisionTreeNode * GetRoot() const
TMVA::DecisionTreeNode * GetEventNode(const TMVA::Event &e) const
get the pointer to the leaf node where a particular event ends up in... (used in gradient boosting)
static const Int_t fgDebugLevel
debug level determining some printout/control plots etc.
SeparationBase * fSepType
the separation criteria
void SetUseFisherCuts(Bool_t t=kTRUE)
UInt_t fMaxDepth
max depth
virtual const char * ClassName() const
void SetNodePurityLimit(Double_t p)
void GetRandomisedVariables(Bool_t *useVariable, UInt_t *variableMap, UInt_t &nVars)
Int_t fUseNvars
the number of variables used in randomised trees;
void SetParentTreeInNodes(Node *n=nullptr)
descend a tree to find all its leaf nodes, fill max depth reached in the tree at the same time.
void DescendTree(Node *n=nullptr)
descend a tree to find all its leaf nodes
virtual DecisionTreeNode * CreateNode(UInt_t) const
virtual BinaryTree * CreateTree() const
Double_t fPruneStrength
a parameter to set the "amount" of pruning..needs to be adjusted
static DecisionTree * CreateFromXML(void *node, UInt_t tmva_Version_Code=262657)
re-create a new tree (decision tree or search tree) from XML
UInt_t fSigClass
class which is treated as signal when building the tree
void SetPruneStrength(Double_t p)
Double_t TrainNode(const EventConstList &eventSample, DecisionTreeNode *node)
std::vector< const TMVA::Event * > EventConstList
Bool_t fUseFisherCuts
use multivariate splits using the Fisher criterion
Double_t fNodePurityLimit
purity limit to decide whether a node is signal
Double_t CheckEvent(const TMVA::Event *, Bool_t UseYesNoLeaf=kFALSE) const
the event e is put into the decision tree (starting at the root node) and the output is NodeType (sig...
Bool_t fUseExclusiveVars
individual variables already used in the Fisher criterion are no longer analysed individually for node ...
static const Int_t fgRandomSeed
Int_t fNCuts
number of grid point in variable cut scans
void SetTreeID(Int_t treeID)
UInt_t CleanTree(DecisionTreeNode *node=nullptr)
remove those last splits that result in two leaf nodes that are both of the type (i....
Double_t fMinLinCorrForFisher
the minimum linear correlation between two variables required for use in the Fisher criterion in node spl...
UInt_t fNvars
number of variables used to separate S and B
Bool_t fRandomisedTree
choose at each node splitting a random set of variables
virtual ~DecisionTree(void)
destructor
Types::EAnalysisType fAnalysisType
kClassification(=0=false) or kRegression(=1=true)
std::vector< Double_t > GetVariableImportance()
Return the relative variable importance, normalized to all variables together having the importance 1...
void SetNVars(Int_t n)
void CheckEventWithPrunedTree(const TMVA::Event *) const
pass a single validation event through a pruned decision tree on the way down the tree,...
void PruneNodeInPlace(TMVA::DecisionTreeNode *node)
prune a node temporarily (without actually deleting its descendants which allows testing the pruned t...
Double_t fMinSepGain
minimum separation gain required to perform node splitting
Double_t TestPrunedTreeQuality(const DecisionTreeNode *dt=nullptr, Int_t mode=0) const
return the misclassification rate of a pruned tree a "pruned tree" may have set the variable "IsTermi...
std::vector< Double_t > fVariableImportance
the relative importance of the different variables
Double_t PruneTree(const EventConstList *validationSample=nullptr)
prune (get rid of internal nodes) the Decision tree to avoid overtraining several different pruning m...
void FillEvent(const TMVA::Event &event, TMVA::DecisionTreeNode *node)
fill the existing decision tree structure by filling events in from the top node and seeing where the...
Int_t GetNNodesBeforePruning()
DataSetInfo * fDataSetInfo
void ClearTree()
clear the tree nodes (their S/N, Nevents etc), just keep the structure of the tree
Double_t SamplePurity(EventList eventSample)
calculates the purity S/(S+B) of a given event sample
Node * GetNode(ULong_t sequence, UInt_t depth)
retrieve node from the tree.
std::vector< Double_t > GetFisherCoefficients(const EventConstList &eventSample, UInt_t nFisherVars, UInt_t *mapVarInFisher)
calculate the fisher coefficients for the event sample and the variables used
UInt_t CountLeafNodes(TMVA::Node *n=nullptr)
return the number of terminal nodes in the sub-tree below Node n
Double_t TrainNodeFast(const EventConstList &eventSample, DecisionTreeNode *node)
Decide how to split a node using one of the variables that gives the best separation of signal/backgr...
Types::EAnalysisType GetAnalysisType(void)
Bool_t fUsePoissonNvars
use "fUseNvars" not as fixed number but as mean of a poisson distr. in each split
void SetMinLinCorrForFisher(Double_t min)
RegressionVariance * fRegType
the separation criteria used in Regression
DecisionTree(void)
default constructor using the GiniIndex as separation criterion, no restrictions on minimum number of ...
Double_t GetSumWeights(const EventConstList *validationSample) const
calculate the normalization factor for a pruning validation sample
Double_t GetPruneStrength() const
Node for the BinarySearch or Decision Trees.
Definition Node.h:58
Calculate the "SeparationGain" for Regression analysis separation criteria used in various training a...
An interface to calculate the "SeparationGain" for different separation criteria used in various trai...
@ kRegression
Definition Types.h:128
Random number generator class based on M.
Definition TRandom3.h:27
const Int_t n
Definition legend1.C:16
create variable transformations
TMarker m
Definition textangle.C:8