Logo ROOT  
Reference Guide
 
Loading...
Searching...
No Matches
DecisionTree.h
Go to the documentation of this file.
1// @(#)root/tmva $Id$
2// Author: Andreas Hoecker, Joerg Stelzer, Helge Voss, Kai Voss, Jan Therhaag, Eckhard von Toerne
3
4/**********************************************************************************
5 * Project: TMVA - a Root-integrated toolkit for multivariate data analysis *
6 * Package: TMVA *
7 * Class : DecisionTree *
8 * Web : http://tmva.sourceforge.net *
9 * *
10 * Description: *
11 * Implementation of a Decision Tree *
12 * *
13 * Authors (alphabetical): *
14 * Andreas Hoecker <Andreas.Hocker@cern.ch> - CERN, Switzerland *
15 * Helge Voss <Helge.Voss@cern.ch> - MPI-K Heidelberg, Germany *
16 * Kai Voss <Kai.Voss@cern.ch> - U. of Victoria, Canada *
17 * Jan Therhaag <Jan.Therhaag@cern.ch> - U of Bonn, Germany *
18 * Eckhard v. Toerne <evt@uni-bonn.de> - U of Bonn, Germany *
19 * *
20 * Copyright (c) 2005-2011: *
21 * CERN, Switzerland *
22 * U. of Victoria, Canada *
23 * MPI-K Heidelberg, Germany *
24 * U. of Bonn, Germany *
25 * *
26 * Redistribution and use in source and binary forms, with or without *
27 * modification, are permitted according to the terms listed in LICENSE *
28 * (http://mva.sourceforge.net/license.txt) *
29 * *
30 **********************************************************************************/
31
32#ifndef ROOT_TMVA_DecisionTree
33#define ROOT_TMVA_DecisionTree
34
35//////////////////////////////////////////////////////////////////////////
36// //
37// DecisionTree //
38// //
39// Implementation of a Decision Tree //
40// //
41//////////////////////////////////////////////////////////////////////////
42
43#include "TH2.h"
44#include <vector>
45
46#include "TMVA/Types.h"
48#include "TMVA/BinaryTree.h"
50#include "TMVA/SeparationBase.h"
52#include "TMVA/DataSetInfo.h"
53
54#ifdef R__USE_IMT
56#include "TSystem.h"
57#endif
58
59class TRandom3;
60
61namespace TMVA {
62
63 class Event;
64
65 class DecisionTree : public BinaryTree {
66
67 private:
68
69 static const Int_t fgRandomSeed; // set nonzero for debugging and zero for random seeds
70
71 public:
72
73 typedef std::vector<TMVA::Event*> EventList;
74 typedef std::vector<const TMVA::Event*> EventConstList;
75
76 // the constructor needed for the "reading" of the decision tree from weight files
77 DecisionTree( void );
78
79 // the constructor needed for constructing the decision tree via training with events
80 DecisionTree( SeparationBase *sepType, Float_t minSize,
81 Int_t nCuts, DataSetInfo* = NULL,
82 UInt_t cls =0,
83 Bool_t randomisedTree=kFALSE, Int_t useNvars=0, Bool_t usePoissonNvars=kFALSE,
84 UInt_t nMaxDepth=9999999,
85 Int_t iSeed=fgRandomSeed, Float_t purityLimit=0.5,
86 Int_t treeID = 0);
87
88 // copy constructor (NOTE: its declaration, source line 89, is missing from this extracted listing)
90
91 virtual ~DecisionTree( void );
92
93 // Retrieves the address of the root node
94 virtual DecisionTreeNode* GetRoot() const { return static_cast<TMVA::DecisionTreeNode*>(fRoot); }
95 virtual DecisionTreeNode * CreateNode(UInt_t) const { return new DecisionTreeNode(); }
96 virtual BinaryTree* CreateTree() const { return new DecisionTree(); }
97 static DecisionTree* CreateFromXML(void* node, UInt_t tmva_Version_Code = TMVA_VERSION_CODE);
98 virtual const char* ClassName() const { return "DecisionTree"; }
99
100 // building of a tree by recursively splitting the nodes
101
102 // UInt_t BuildTree( const EventList & eventSample,
103 // DecisionTreeNode *node = NULL);
104 UInt_t BuildTree( const EventConstList & eventSample,
105 DecisionTreeNode *node = NULL);
106 // determine how a node is split (which variable, which cut value)
107
108 Double_t TrainNode( const EventConstList & eventSample, DecisionTreeNode *node ) { return TrainNodeFast( eventSample, node ); }
109 Double_t TrainNodeFast( const EventConstList & eventSample, DecisionTreeNode *node );
110 Double_t TrainNodeFull( const EventConstList & eventSample, DecisionTreeNode *node );
111 void GetRandomisedVariables(Bool_t *useVariable, UInt_t *variableMap, UInt_t & nVars);
112 std::vector<Double_t> GetFisherCoefficients(const EventConstList &eventSample, UInt_t nFisherVars, UInt_t *mapVarInFisher);
113
114 // fill a tree that already has a given structure (just see how many signal/bkg
115 // events end up in each node)
116
117 void FillTree( const EventList & eventSample);
118
119 // fill the existing decision tree structure by feeding an event
120 // in from the top node and seeing where it ends up. NOTE(review): the second parameter line (source line 122) is missing from this listing; the member index gives FillEvent(const TMVA::Event &event, TMVA::DecisionTreeNode *node)
121 void FillEvent( const TMVA::Event & event,
123
124 // returns: 1 = Signal (right), -1 = Bkg (left)
125
126 Double_t CheckEvent( const TMVA::Event * , Bool_t UseYesNoLeaf = kFALSE ) const;
128
129 // return the individual relative variable importance
130 std::vector< Double_t > GetVariableImportance();
131
133
134 // clear the tree nodes (their S/N, Nevents etc), just keep the structure of the tree
135
136 void ClearTree();
137
138 // set pruning method (declaration elided from this listing; the member index shows SetPruneMethod(EPruneMethod m=kCostComplexityPruning))
141
142 // recursive pruning of the tree, validation sample required for automatic pruning
143 Double_t PruneTree( const EventConstList* validationSample = NULL );
144
145 // manage the pruning strength parameter (if < 0 -> automate the pruning process); the setter/getter declarations are elided from this listing
148
149 // apply pruning validation sample to a decision tree
150 void ApplyValidationSample( const EventConstList* validationSample ) const;
151
152 // return the misclassification rate of a pruned tree
153 Double_t TestPrunedTreeQuality( const DecisionTreeNode* dt = NULL, Int_t mode=0 ) const;
154
155 // pass a single validation event through a pruned decision tree
156 void CheckEventWithPrunedTree( const TMVA::Event* ) const;
157
158 // calculate the normalization factor for a pruning validation sample
159 Double_t GetSumWeights( const EventConstList* validationSample ) const;
160
163
164 void DescendTree( Node *n = NULL );
165 void SetParentTreeInNodes( Node *n = NULL );
166
167 // retrieve node from the tree. Its position (up to a maximal tree depth of 64)
168 // is coded as a sequence of left-right moves starting from the root, coded as
169 // 0-1 bit patterns stored in the "long-integer" together with the depth
170 Node* GetNode( ULong_t sequence, UInt_t depth );
171
173
175
176 // prune a node from the tree without deleting its descendants; allows one to
177 // effectively prune a tree many times without making deep copies (declaration elided from this listing; the member index shows PruneNodeInPlace(TMVA::DecisionTreeNode *node))
179
181
182
184
185 void SetTreeID(Int_t treeID){fTreeID = treeID;};
187
194 inline void SetNVars(Int_t n){fNvars = n;}
195
196 private:
197 // utility functions
198
199 // calculate the purity out of the number of sig and bkg events collected
200 // from individual samples.
201
202 // calculates the purity S/(S+B) of a given event sample
203 Double_t SamplePurity(EventList eventSample);
204
205 UInt_t fNvars; // number of variables used to separate S and B
206 Int_t fNCuts; // number of grid points in variable cut scans
207 Bool_t fUseFisherCuts; // use multivariate splits using the Fisher criterion
208 Double_t fMinLinCorrForFisher; // the minimum linear correlation between two variables demanded for use in the Fisher criterion in node splitting
209 Bool_t fUseExclusiveVars; // individual variables already used in the Fisher criterion are no longer analysed individually for node splitting
210
211 SeparationBase *fSepType; // the separation criterion
212 RegressionVariance *fRegType; // the separation criterion used in Regression
213
214 Double_t fMinSize; // min number of events in node
215 Double_t fMinNodeSize; // min fraction of training events in node
216 Double_t fMinSepGain; // minimum separation gain required to perform node splitting
217
218 Bool_t fUseSearchTree; // cut scan done with binary trees or simple event loop.
219 Double_t fPruneStrength; // a parameter to set the "amount" of pruning; needs to be adjusted
220
221 EPruneMethod fPruneMethod; // method used for pruning
222 Int_t fNNodesBeforePruning; // remembered so that, in case of pruning, the node count before/after can be monitored
223
224 Double_t fNodePurityLimit;// purity limit to decide whether a node is signal
225
226 Bool_t fRandomisedTree; // choose at each node splitting a random set of variables
227 Int_t fUseNvars; // the number of variables used in randomised trees;
228 Bool_t fUsePoissonNvars; // use "fUseNvars" not as a fixed number but as the mean of a Poisson distribution in each split
229
230 TRandom3 *fMyTrandom; // random number generator for randomised trees
231
232 std::vector< Double_t > fVariableImportance; // the relative importance of the different variables
233
234 UInt_t fMaxDepth; // max depth
235 UInt_t fSigClass; // class which is treated as signal when building the tree
236 static const Int_t fgDebugLevel = 0; // debug level determining some printout/control plots etc.
237 Int_t fTreeID; // just an ID number given to the tree; makes debugging easier as the tree knows its own ID.
238
239 Types::EAnalysisType fAnalysisType; // kClassification(=0=false) or kRegression(=1=true)
240
242
243 ClassDef(DecisionTree,0); // implementation of a Decision Tree
244 };
245
246} // namespace TMVA
247
248#endif
#define d(i)
Definition RSha256.hxx:102
#define e(i)
Definition RSha256.hxx:103
unsigned int UInt_t
Definition RtypesCore.h:46
const Bool_t kFALSE
Definition RtypesCore.h:101
unsigned long ULong_t
Definition RtypesCore.h:55
bool Bool_t
Definition RtypesCore.h:63
double Double_t
Definition RtypesCore.h:59
float Float_t
Definition RtypesCore.h:57
const Bool_t kTRUE
Definition RtypesCore.h:100
#define ClassDef(name, id)
Definition Rtypes.h:325
#define TMVA_VERSION_CODE
Definition Version.h:47
Base class for BinarySearch and Decision Trees.
Definition BinaryTree.h:62
UInt_t GetNNodes() const
Definition BinaryTree.h:86
Class that contains all the data information.
Definition DataSetInfo.h:62
Implementation of a Decision Tree.
Bool_t DoRegression() const
void SetAnalysisType(Types::EAnalysisType t)
void SetUseExclusiveVars(Bool_t t=kTRUE)
Double_t GetNodePurityLimit() const
void FillTree(const EventList &eventSample)
fill the existing decision tree structure by filling events in from the top node and seeing where the...
void PruneNode(TMVA::DecisionTreeNode *node)
prune away the subtree below the node
TRandom3 * fMyTrandom
void SetPruneMethod(EPruneMethod m=kCostComplexityPruning)
void ApplyValidationSample(const EventConstList *validationSample) const
run the validation sample through the (pruned) tree and fill in the nodes the variables NSValidation ...
Double_t TrainNodeFull(const EventConstList &eventSample, DecisionTreeNode *node)
train a node by finding the single optimal cut for a single variable that best separates signal and b...
EPruneMethod fPruneMethod
std::vector< TMVA::Event * > EventList
virtual DecisionTreeNode * GetRoot() const
TMVA::DecisionTreeNode * GetEventNode(const TMVA::Event &e) const
get the pointer to the leaf node where a particular event ends up in... (used in gradient boosting)
static const Int_t fgDebugLevel
SeparationBase * fSepType
void SetUseFisherCuts(Bool_t t=kTRUE)
virtual const char * ClassName() const
void SetNodePurityLimit(Double_t p)
void GetRandomisedVariables(Bool_t *useVariable, UInt_t *variableMap, UInt_t &nVars)
virtual DecisionTreeNode * CreateNode(UInt_t) const
virtual BinaryTree * CreateTree() const
UInt_t CleanTree(DecisionTreeNode *node=NULL)
remove those last splits that result in two leaf nodes that are both of the type (i....
void SetPruneStrength(Double_t p)
Double_t TrainNode(const EventConstList &eventSample, DecisionTreeNode *node)
std::vector< const TMVA::Event * > EventConstList
Double_t fNodePurityLimit
Double_t CheckEvent(const TMVA::Event *, Bool_t UseYesNoLeaf=kFALSE) const
the event e is put into the decision tree (starting at the root node) and the output is NodeType (sig...
static const Int_t fgRandomSeed
void SetTreeID(Int_t treeID)
UInt_t BuildTree(const EventConstList &eventSample, DecisionTreeNode *node=NULL)
building the decision tree by recursively calling the splitting of one (root-) node into two daughter...
Double_t fMinLinCorrForFisher
Double_t PruneTree(const EventConstList *validationSample=NULL)
prune (get rid of internal nodes) the Decision tree to avoid overtraining; several different pruning m...
virtual ~DecisionTree(void)
destructor
Types::EAnalysisType fAnalysisType
std::vector< Double_t > GetVariableImportance()
Return the relative variable importance, normalized to all variables together having the importance 1...
void SetNVars(Int_t n)
void CheckEventWithPrunedTree(const TMVA::Event *) const
pass a single validation event through a pruned decision tree on the way down the tree,...
void PruneNodeInPlace(TMVA::DecisionTreeNode *node)
prune a node temporarily (without actually deleting its descendants which allows testing the pruned t...
std::vector< Double_t > fVariableImportance
void FillEvent(const TMVA::Event &event, TMVA::DecisionTreeNode *node)
fill the existing decision tree structure by filling an event in from the top node and seeing where the...
UInt_t CountLeafNodes(TMVA::Node *n=NULL)
return the number of terminal nodes in the sub-tree below Node n
Int_t GetNNodesBeforePruning()
Double_t TestPrunedTreeQuality(const DecisionTreeNode *dt=NULL, Int_t mode=0) const
return the misclassification rate of a pruned tree. A "pruned tree" may have set the variable "IsTermi...
DataSetInfo * fDataSetInfo
void ClearTree()
clear the tree nodes (their S/N, Nevents etc), just keep the structure of the tree
static DecisionTree * CreateFromXML(void *node, UInt_t tmva_Version_Code=TMVA_VERSION_CODE)
re-create a new tree (decision tree or search tree) from XML
Double_t SamplePurity(EventList eventSample)
calculates the purity S/(S+B) of a given event sample
void SetParentTreeInNodes(Node *n=NULL)
descend a tree to find all its leaf nodes, fill max depth reached in the tree at the same time.
Node * GetNode(ULong_t sequence, UInt_t depth)
retrieve node from the tree.
std::vector< Double_t > GetFisherCoefficients(const EventConstList &eventSample, UInt_t nFisherVars, UInt_t *mapVarInFisher)
calculate the fisher coefficients for the event sample and the variables used
Double_t TrainNodeFast(const EventConstList &eventSample, DecisionTreeNode *node)
Decide how to split a node using one of the variables that gives the best separation of signal/backgr...
Types::EAnalysisType GetAnalysisType(void)
void DescendTree(Node *n=NULL)
descend a tree to find all its leaf nodes
void SetMinLinCorrForFisher(Double_t min)
RegressionVariance * fRegType
DecisionTree(void)
default constructor using the GiniIndex as separation criterion, no restrictions on minimum number of ...
Double_t GetSumWeights(const EventConstList *validationSample) const
calculate the normalization factor for a pruning validation sample
Double_t GetPruneStrength() const
Node for the BinarySearch or Decision Trees.
Definition Node.h:58
Calculate the "SeparationGain" for Regression analysis separation criteria used in various training a...
An interface to calculate the "SeparationGain" for different separation criteria used in various trai...
@ kRegression
Definition Types.h:128
Random number generator class based on M.
Definition TRandom3.h:27
const Int_t n
Definition legend1.C:16
create variable transformations
auto * m
Definition textangle.C:8