Logo ROOT  
Reference Guide
DecisionTree.h
Go to the documentation of this file.
// @(#)root/tmva $Id$
// Author: Andreas Hoecker, Joerg Stelzer, Helge Voss, Kai Voss, Jan Therhaag, Eckhard von Toerne

/**********************************************************************************
 * Project: TMVA - a Root-integrated toolkit for multivariate data analysis *
 * Package: TMVA *
 * Class : DecisionTree *
 * Web : http://tmva.sourceforge.net *
 * *
 * Description: *
 * Implementation of a Decision Tree *
 * *
 * Authors (alphabetical): *
 * Andreas Hoecker <Andreas.Hocker@cern.ch> - CERN, Switzerland *
 * Helge Voss <Helge.Voss@cern.ch> - MPI-K Heidelberg, Germany *
 * Kai Voss <Kai.Voss@cern.ch> - U. of Victoria, Canada *
 * Jan Therhaag <Jan.Therhaag@cern.ch> - U of Bonn, Germany *
 * Eckhard v. Toerne <evt@uni-bonn.de> - U of Bonn, Germany *
 * *
 * Copyright (c) 2005-2011: *
 * CERN, Switzerland *
 * U. of Victoria, Canada *
 * MPI-K Heidelberg, Germany *
 * U. of Bonn, Germany *
 * *
 * Redistribution and use in source and binary forms, with or without *
 * modification, are permitted according to the terms listed in LICENSE *
 * (http://mva.sourceforge.net/license.txt) *
 * *
 **********************************************************************************/
31
32#ifndef ROOT_TMVA_DecisionTree
33#define ROOT_TMVA_DecisionTree
34
//////////////////////////////////////////////////////////////////////////
//                                                                      //
// DecisionTree                                                         //
//                                                                      //
// Implementation of a Decision Tree                                    //
//                                                                      //
//////////////////////////////////////////////////////////////////////////
42
#include <vector>

#include "TH2.h"

#include "TMVA/Types.h"
#include "TMVA/DecisionTreeNode.h"
#include "TMVA/BinaryTree.h"
#include "TMVA/BinarySearchTree.h"
#include "TMVA/SeparationBase.h"
#include "TMVA/RegressionVariance.h"
#include "TMVA/DataSetInfo.h"

#ifdef R__USE_IMT
#include <ROOT/TThreadExecutor.hxx>
#include "TSystem.h"
#endif
57
58class TRandom3;
59
60namespace TMVA {
61
62 class Event;
63
64 class DecisionTree : public BinaryTree {
65
66 private:
67
68 static const Int_t fgRandomSeed; // set nonzero for debugging and zero for random seeds
69
70 public:
71
72 typedef std::vector<TMVA::Event*> EventList;
73 typedef std::vector<const TMVA::Event*> EventConstList;
74
75 // the constructur needed for the "reading" of the decision tree from weight files
76 DecisionTree( void );
77
78 // the constructur needed for constructing the decision tree via training with events
79 DecisionTree( SeparationBase *sepType, Float_t minSize,
80 Int_t nCuts, DataSetInfo* = NULL,
81 UInt_t cls =0,
82 Bool_t randomisedTree=kFALSE, Int_t useNvars=0, Bool_t usePoissonNvars=kFALSE,
83 UInt_t nMaxDepth=9999999,
84 Int_t iSeed=fgRandomSeed, Float_t purityLimit=0.5,
85 Int_t treeID = 0);
86
87 // copy constructor
89
90 virtual ~DecisionTree( void );
91
92 // Retrieves the address of the root node
93 virtual DecisionTreeNode* GetRoot() const { return static_cast<TMVA::DecisionTreeNode*>(fRoot); }
94 virtual DecisionTreeNode * CreateNode(UInt_t) const { return new DecisionTreeNode(); }
95 virtual BinaryTree* CreateTree() const { return new DecisionTree(); }
96 static DecisionTree* CreateFromXML(void* node, UInt_t tmva_Version_Code = TMVA_VERSION_CODE);
97 virtual const char* ClassName() const { return "DecisionTree"; }
98
99 // building of a tree by recursivly splitting the nodes
100
101 // UInt_t BuildTree( const EventList & eventSample,
102 // DecisionTreeNode *node = NULL);
103 UInt_t BuildTree( const EventConstList & eventSample,
104 DecisionTreeNode *node = NULL);
105 // determine the way how a node is split (which variable, which cut value)
106
107 Double_t TrainNode( const EventConstList & eventSample, DecisionTreeNode *node ) { return TrainNodeFast( eventSample, node ); }
108 Double_t TrainNodeFast( const EventConstList & eventSample, DecisionTreeNode *node );
109 Double_t TrainNodeFull( const EventConstList & eventSample, DecisionTreeNode *node );
110 void GetRandomisedVariables(Bool_t *useVariable, UInt_t *variableMap, UInt_t & nVars);
111 std::vector<Double_t> GetFisherCoefficients(const EventConstList &eventSample, UInt_t nFisherVars, UInt_t *mapVarInFisher);
112
113 // fill at tree with a given structure already (just see how many signa/bkgr
114 // events end up in each node
115
116 void FillTree( const EventList & eventSample);
117
118 // fill the existing the decision tree structure by filling event
119 // in from the top node and see where they happen to end up
120 void FillEvent( const TMVA::Event & event,
122
123 // returns: 1 = Signal (right), -1 = Bkg (left)
124
125 Double_t CheckEvent( const TMVA::Event * , Bool_t UseYesNoLeaf = kFALSE ) const;
127
128 // return the individual relative variable importance
129 std::vector< Double_t > GetVariableImportance();
130
132
133 // clear the tree nodes (their S/N, Nevents etc), just keep the structure of the tree
134
135 void ClearTree();
136
137 // set pruning method
140
141 // recursive pruning of the tree, validation sample required for automatic pruning
142 Double_t PruneTree( const EventConstList* validationSample = NULL );
143
144 // manage the pruning strength parameter (iff < 0 -> automate the pruning process)
147
148 // apply pruning validation sample to a decision tree
149 void ApplyValidationSample( const EventConstList* validationSample ) const;
150
151 // return the misclassification rate of a pruned tree
152 Double_t TestPrunedTreeQuality( const DecisionTreeNode* dt = NULL, Int_t mode=0 ) const;
153
154 // pass a single validation event throught a pruned decision tree
155 void CheckEventWithPrunedTree( const TMVA::Event* ) const;
156
157 // calculate the normalization factor for a pruning validation sample
158 Double_t GetSumWeights( const EventConstList* validationSample ) const;
159
162
163 void DescendTree( Node *n = NULL );
164 void SetParentTreeInNodes( Node *n = NULL );
165
166 // retrieve node from the tree. Its position (up to a maximal tree depth of 64)
167 // is coded as a sequence of left-right moves starting from the root, coded as
168 // 0-1 bit patterns stored in the "long-integer" together with the depth
169 Node* GetNode( ULong_t sequence, UInt_t depth );
170
172
174
175 // prune a node from the tree without deleting its descendants; allows one to
176 // effectively prune a tree many times without making deep copies
178
180
181
183
184 void SetTreeID(Int_t treeID){fTreeID = treeID;};
186
193 inline void SetNVars(Int_t n){fNvars = n;}
194
195 private:
196 // utility functions
197
198 // calculate the Purity out of the number of sig and bkg events collected
199 // from individual samples.
200
201 // calculates the purity S/(S+B) of a given event sample
202 Double_t SamplePurity(EventList eventSample);
203
204 UInt_t fNvars; // number of variables used to separate S and B
205 Int_t fNCuts; // number of grid point in variable cut scans
206 Bool_t fUseFisherCuts; // use multivariate splits using the Fisher criterium
207 Double_t fMinLinCorrForFisher; // the minimum linear correlation between two variables demanded for use in fisher criterium in node splitting
208 Bool_t fUseExclusiveVars; // individual variables already used in fisher criterium are not anymore analysed individually for node splitting
209
210 SeparationBase *fSepType; // the separation crition
211 RegressionVariance *fRegType; // the separation crition used in Regression
212
213 Double_t fMinSize; // min number of events in node
214 Double_t fMinNodeSize; // min fraction of training events in node
215 Double_t fMinSepGain; // min number of separation gain to perform node splitting
216
217 Bool_t fUseSearchTree; // cut scan done with binary trees or simple event loop.
218 Double_t fPruneStrength; // a parameter to set the "amount" of pruning..needs to be adjusted
219
220 EPruneMethod fPruneMethod; // method used for prunig
221 Int_t fNNodesBeforePruning; //remember this one (in case of pruning, it allows to monitor the before/after
222
223 Double_t fNodePurityLimit;// purity limit to decide whether a node is signal
224
225 Bool_t fRandomisedTree; // choose at each node splitting a random set of variables
226 Int_t fUseNvars; // the number of variables used in randomised trees;
227 Bool_t fUsePoissonNvars; // use "fUseNvars" not as fixed number but as mean of a possion distr. in each split
228
229 TRandom3 *fMyTrandom; // random number generator for randomised trees
230
231 std::vector< Double_t > fVariableImportance; // the relative importance of the different variables
232
233 UInt_t fMaxDepth; // max depth
234 UInt_t fSigClass; // class which is treated as signal when building the tree
235 static const Int_t fgDebugLevel = 0; // debug level determining some printout/control plots etc.
236 Int_t fTreeID; // just an ID number given to the tree.. makes debugging easier as tree knows who he is.
237
238 Types::EAnalysisType fAnalysisType; // kClassification(=0=false) or kRegression(=1=true)
239
241
242 ClassDef(DecisionTree,0); // implementation of a Decision Tree
243 };
244
245} // namespace TMVA
246
247#endif
#define d(i)
Definition: RSha256.hxx:102
#define e(i)
Definition: RSha256.hxx:103
int Int_t
Definition: RtypesCore.h:41
unsigned int UInt_t
Definition: RtypesCore.h:42
const Bool_t kFALSE
Definition: RtypesCore.h:88
unsigned long ULong_t
Definition: RtypesCore.h:51
bool Bool_t
Definition: RtypesCore.h:59
double Double_t
Definition: RtypesCore.h:55
float Float_t
Definition: RtypesCore.h:53
const Bool_t kTRUE
Definition: RtypesCore.h:87
#define ClassDef(name, id)
Definition: Rtypes.h:326
#define TMVA_VERSION_CODE
Definition: Version.h:47
Base class for BinarySearch and Decision Trees.
Definition: BinaryTree.h:62
UInt_t GetNNodes() const
Definition: BinaryTree.h:86
Class that contains all the data information.
Definition: DataSetInfo.h:60
Implementation of a Decision Tree.
Definition: DecisionTree.h:64
Bool_t DoRegression() const
Definition: DecisionTree.h:187
void SetAnalysisType(Types::EAnalysisType t)
Definition: DecisionTree.h:188
void SetUseExclusiveVars(Bool_t t=kTRUE)
Definition: DecisionTree.h:192
Double_t GetNodePurityLimit() const
Definition: DecisionTree.h:161
void FillTree(const EventList &eventSample)
fill the existing decision tree structure by filling events in from the top node and seeing where the...
void PruneNode(TMVA::DecisionTreeNode *node)
prune away the subtree below the node
TRandom3 * fMyTrandom
Definition: DecisionTree.h:229
void SetPruneMethod(EPruneMethod m=kCostComplexityPruning)
Definition: DecisionTree.h:139
void ApplyValidationSample(const EventConstList *validationSample) const
run the validation sample through the (pruned) tree and fill in the nodes the variables NSValidation ...
Double_t TrainNodeFull(const EventConstList &eventSample, DecisionTreeNode *node)
train a node by finding the single optimal cut for a single variable that best separates signal and b...
EPruneMethod fPruneMethod
Definition: DecisionTree.h:220
std::vector< TMVA::Event * > EventList
Definition: DecisionTree.h:72
virtual DecisionTreeNode * GetRoot() const
Definition: DecisionTree.h:93
TMVA::DecisionTreeNode * GetEventNode(const TMVA::Event &e) const
get the pointer to the leaf node where a particular event ends up in... (used in gradient boosting)
static const Int_t fgDebugLevel
Definition: DecisionTree.h:235
SeparationBase * fSepType
Definition: DecisionTree.h:210
void SetUseFisherCuts(Bool_t t=kTRUE)
Definition: DecisionTree.h:190
virtual const char * ClassName() const
Definition: DecisionTree.h:97
void SetNodePurityLimit(Double_t p)
Definition: DecisionTree.h:160
void GetRandomisedVariables(Bool_t *useVariable, UInt_t *variableMap, UInt_t &nVars)
virtual DecisionTreeNode * CreateNode(UInt_t) const
Definition: DecisionTree.h:94
virtual BinaryTree * CreateTree() const
Definition: DecisionTree.h:95
UInt_t CleanTree(DecisionTreeNode *node=NULL)
remove those last splits that result in two leaf nodes that are both of the type (i....
Double_t fPruneStrength
Definition: DecisionTree.h:218
void SetPruneStrength(Double_t p)
Definition: DecisionTree.h:145
Double_t TrainNode(const EventConstList &eventSample, DecisionTreeNode *node)
Definition: DecisionTree.h:107
std::vector< const TMVA::Event * > EventConstList
Definition: DecisionTree.h:73
Double_t fNodePurityLimit
Definition: DecisionTree.h:223
Double_t CheckEvent(const TMVA::Event *, Bool_t UseYesNoLeaf=kFALSE) const
the event e is put into the decision tree (starting at the root node) and the output is NodeType (sig...
static const Int_t fgRandomSeed
Definition: DecisionTree.h:68
void SetTreeID(Int_t treeID)
Definition: DecisionTree.h:184
UInt_t BuildTree(const EventConstList &eventSample, DecisionTreeNode *node=NULL)
building the decision tree by recursively calling the splitting of one (root-) node into two daughter...
Double_t fMinLinCorrForFisher
Definition: DecisionTree.h:207
Double_t PruneTree(const EventConstList *validationSample=NULL)
prune (get rid of internal nodes) the Decision tree to avoid overtraining several different pruning m...
virtual ~DecisionTree(void)
destructor
Types::EAnalysisType fAnalysisType
Definition: DecisionTree.h:238
std::vector< Double_t > GetVariableImportance()
Return the relative variable importance, normalized to all variables together having the importance 1...
void SetNVars(Int_t n)
Definition: DecisionTree.h:193
void CheckEventWithPrunedTree(const TMVA::Event *) const
pass a single validation event through a pruned decision tree on the way down the tree,...
void PruneNodeInPlace(TMVA::DecisionTreeNode *node)
prune a node temporarily (without actually deleting its descendants which allows testing the pruned t...
std::vector< Double_t > fVariableImportance
Definition: DecisionTree.h:231
void FillEvent(const TMVA::Event &event, TMVA::DecisionTreeNode *node)
fill the existing decision tree structure by filling events in from the top node and seeing where the...
UInt_t CountLeafNodes(TMVA::Node *n=NULL)
return the number of terminal nodes in the sub-tree below Node n
Int_t GetNNodesBeforePruning()
Definition: DecisionTree.h:179
Double_t TestPrunedTreeQuality(const DecisionTreeNode *dt=NULL, Int_t mode=0) const
return the misclassification rate of a pruned tree a "pruned tree" may have set the variable "IsTermi...
DataSetInfo * fDataSetInfo
Definition: DecisionTree.h:240
void ClearTree()
clear the tree nodes (their S/N, Nevents etc), just keep the structure of the tree
static DecisionTree * CreateFromXML(void *node, UInt_t tmva_Version_Code=TMVA_VERSION_CODE)
re-create a new tree (decision tree or search tree) from XML
Double_t SamplePurity(EventList eventSample)
calculates the purity S/(S+B) of a given event sample
void SetParentTreeInNodes(Node *n=NULL)
descend a tree to find all its leaf nodes, fill max depth reached in the tree at the same time.
Node * GetNode(ULong_t sequence, UInt_t depth)
retrieve node from the tree.
std::vector< Double_t > GetFisherCoefficients(const EventConstList &eventSample, UInt_t nFisherVars, UInt_t *mapVarInFisher)
calculate the fisher coefficients for the event sample and the variables used
Double_t TrainNodeFast(const EventConstList &eventSample, DecisionTreeNode *node)
Decide how to split a node using one of the variables that gives the best separation of signal/backgr...
Types::EAnalysisType GetAnalysisType(void)
Definition: DecisionTree.h:189
void DescendTree(Node *n=NULL)
descend a tree to find all its leaf nodes
void SetMinLinCorrForFisher(Double_t min)
Definition: DecisionTree.h:191
RegressionVariance * fRegType
Definition: DecisionTree.h:211
DecisionTree(void)
default constructor using the GiniIndex as separation criterion, no restrictions on minimum number of ...
Double_t GetSumWeights(const EventConstList *validationSample) const
calculate the normalization factor for a pruning validation sample
Double_t GetPruneStrength() const
Definition: DecisionTree.h:146
Node for the BinarySearch or Decision Trees.
Definition: Node.h:56
Calculate the "SeparationGain" for Regression analysis separation criteria used in various training a...
An interface to calculate the "SeparationGain" for different separation criteria used in various trai...
EAnalysisType
Definition: Types.h:127
@ kRegression
Definition: Types.h:129
Random number generator class based on M.
Definition: TRandom3.h:27
const Int_t n
Definition: legend1.C:16
create variable transformations
auto * m
Definition: textangle.C:8