ROOT 6.08/07
Reference Guide
DecisionTree.h
1 // @(#)root/tmva $Id$
2 // Author: Andreas Hoecker, Joerg Stelzer, Helge Voss, Kai Voss, Jan Therhaag, Eckhard von Toerne
3 
4 /**********************************************************************************
5  * Project: TMVA - a Root-integrated toolkit for multivariate data analysis *
6  * Package: TMVA *
7  * Class : DecisionTree *
8  * Web : http://tmva.sourceforge.net *
9  * *
10  * Description: *
11  * Implementation of a Decision Tree *
12  * *
13  * Authors (alphabetical): *
14  * Andreas Hoecker <Andreas.Hocker@cern.ch> - CERN, Switzerland *
15  * Helge Voss <Helge.Voss@cern.ch> - MPI-K Heidelberg, Germany *
16  * Kai Voss <Kai.Voss@cern.ch> - U. of Victoria, Canada *
17  * Jan Therhaag <Jan.Therhaag@cern.ch> - U of Bonn, Germany *
18  * Eckhard v. Toerne <evt@uni-bonn.de> - U of Bonn, Germany *
19  * *
20  * Copyright (c) 2005-2011: *
21  * CERN, Switzerland *
22  * U. of Victoria, Canada *
23  * MPI-K Heidelberg, Germany *
24  * U. of Bonn, Germany *
25  * *
26  * Redistribution and use in source and binary forms, with or without *
27  * modification, are permitted according to the terms listed in LICENSE *
28  * (http://tmva.sourceforge.net/license.txt) *
29  * *
30  **********************************************************************************/
31 
32 #ifndef ROOT_TMVA_DecisionTree
33 #define ROOT_TMVA_DecisionTree
34 
35 //////////////////////////////////////////////////////////////////////////
36 // //
37 // DecisionTree //
38 // //
39 // Implementation of a Decision Tree //
40 // //
41 //////////////////////////////////////////////////////////////////////////
42 
43 #ifndef ROOT_TH2
44 #include "TH2.h"
45 #endif
46 
47 #ifndef ROOT_TMVA_Types
48 #include "TMVA/Types.h"
49 #endif
50 #ifndef ROOT_TMVA_DecisionTreeNode
51 #include "TMVA/DecisionTreeNode.h"
52 #endif
53 #ifndef ROOT_TMVA_BinaryTree
54 #include "TMVA/BinaryTree.h"
55 #endif
56 #ifndef ROOT_TMVA_BinarySearchTree
57 #include "TMVA/BinarySearchTree.h"
58 #endif
59 #ifndef ROOT_TMVA_SeparationBase
60 #include "TMVA/SeparationBase.h"
61 #endif
62 #ifndef ROOT_TMVA_RegressionVariance
63 #include "TMVA/RegressionVariance.h"
64 #endif
65 #include "TMVA/DataSetInfo.h"
66 
67 class TRandom3;
68 
69 namespace TMVA {
70 
71  class Event;
72 
73  class DecisionTree : public BinaryTree {
74 
75  private:
76 
77  static const Int_t fgRandomSeed; // set nonzero for debugging and zero for random seeds
78 
79  public:
80 
81  typedef std::vector<TMVA::Event*> EventList;
82  typedef std::vector<const TMVA::Event*> EventConstList;
83 
84  // the constructor needed for the "reading" of the decision tree from weight files
85  DecisionTree( void );
86 
87  // the constructor needed for constructing the decision tree via training with events
88  DecisionTree( SeparationBase *sepType, Float_t minSize,
89  Int_t nCuts, DataSetInfo* = NULL,
90  UInt_t cls =0,
91  Bool_t randomisedTree=kFALSE, Int_t useNvars=0, Bool_t usePoissonNvars=kFALSE,
92  UInt_t nMaxDepth=9999999,
93  Int_t iSeed=fgRandomSeed, Float_t purityLimit=0.5,
94  Int_t treeID = 0);
95 
96  // copy constructor
97  DecisionTree (const DecisionTree &d);
98 
99  virtual ~DecisionTree( void );
100 
101  // Retrieves the address of the root node
102  virtual DecisionTreeNode* GetRoot() const { return dynamic_cast<TMVA::DecisionTreeNode*>(fRoot); }
103  virtual DecisionTreeNode * CreateNode(UInt_t) const { return new DecisionTreeNode(); }
104  virtual BinaryTree* CreateTree() const { return new DecisionTree(); }
105  static DecisionTree* CreateFromXML(void* node, UInt_t tmva_Version_Code = TMVA_VERSION_CODE);
106  virtual const char* ClassName() const { return "DecisionTree"; }
107 
108  // building of a tree by recursively splitting the nodes
109 
110  // UInt_t BuildTree( const EventList & eventSample,
111  // DecisionTreeNode *node = NULL);
112  UInt_t BuildTree( const EventConstList & eventSample,
113  DecisionTreeNode *node = NULL);
114  // determine how a node is split (which variable, which cut value)
115 
116  Double_t TrainNode( const EventConstList & eventSample, DecisionTreeNode *node ) { return TrainNodeFast( eventSample, node ); }
117  Double_t TrainNodeFast( const EventConstList & eventSample, DecisionTreeNode *node );
118  Double_t TrainNodeFull( const EventConstList & eventSample, DecisionTreeNode *node );
119  void GetRandomisedVariables(Bool_t *useVariable, UInt_t *variableMap, UInt_t & nVars);
120  std::vector<Double_t> GetFisherCoefficients(const EventConstList &eventSample, UInt_t nFisherVars, UInt_t *mapVarInFisher);
121 
122  // fill a tree that already has a given structure (just to see how many signal/background
123  // events end up in each node)
124 
125  void FillTree( const EventList & eventSample);
126 
127  // fill the existing decision tree structure by passing events
128  // in from the top node and seeing where they end up
129  void FillEvent( const TMVA::Event & event,
130  TMVA::DecisionTreeNode *node );
131 
132  // returns: 1 = Signal (right), -1 = Bkg (left)
133 
134  Double_t CheckEvent( const TMVA::Event * , Bool_t UseYesNoLeaf = kFALSE ) const;
135  TMVA::DecisionTreeNode* GetEventNode(const TMVA::Event & e) const;
136 
137  // return the individual relative variable importance
138  std::vector< Double_t > GetVariableImportance();
139 
140  Double_t GetVariableImportance(UInt_t ivar);
141 
142  // clear the tree nodes (their S/N, Nevents etc), just keep the structure of the tree
143 
144  void ClearTree();
145 
146  // set pruning method
147  enum EPruneMethod { kExpectedErrorPruning=0, kCostComplexityPruning, kNoPruning };
148  void SetPruneMethod( EPruneMethod m = kCostComplexityPruning ) { fPruneMethod = m; }
149 
150  // recursive pruning of the tree, validation sample required for automatic pruning
151  Double_t PruneTree( const EventConstList* validationSample = NULL );
152 
153  // manage the pruning strength parameter (iff < 0 -> automate the pruning process)
154  void SetPruneStrength( Double_t p ) { fPruneStrength = p; }
155  Double_t GetPruneStrength( ) const { return fPruneStrength; }
156 
157  // apply pruning validation sample to a decision tree
158  void ApplyValidationSample( const EventConstList* validationSample ) const;
159 
160  // return the misclassification rate of a pruned tree
161  Double_t TestPrunedTreeQuality( const DecisionTreeNode* dt = NULL, Int_t mode=0 ) const;
162 
163  // pass a single validation event through a pruned decision tree
164  void CheckEventWithPrunedTree( const TMVA::Event* ) const;
165 
166  // calculate the normalization factor for a pruning validation sample
167  Double_t GetSumWeights( const EventConstList* validationSample ) const;
168 
169  void SetNodePurityLimit( Double_t p ) { fNodePurityLimit = p; }
170  Double_t GetNodePurityLimit( ) const { return fNodePurityLimit; }
171 
172  void DescendTree( Node *n = NULL );
173  void SetParentTreeInNodes( Node *n = NULL );
174 
175  // retrieve a node from the tree. Its position (up to a maximal tree depth of 64)
176  // is encoded as a sequence of left-right moves starting from the root, stored as
177  // 0-1 bit patterns in the "long integer", together with the depth
178  Node* GetNode( ULong_t sequence, UInt_t depth );
179 
180  UInt_t CleanTree(DecisionTreeNode *node=NULL);
181 
182  void PruneNode(TMVA::DecisionTreeNode *node);
183 
184  // prune a node from the tree without deleting its descendants; allows one to
185  // effectively prune a tree many times without making deep copies
186  void PruneNodeInPlace( TMVA::DecisionTreeNode* node );
187 
188  Int_t GetNNodesBeforePruning(){return (fNNodesBeforePruning)?fNNodesBeforePruning:fNNodes;}
189 
190 
191  UInt_t CountLeafNodes(TMVA::Node *n = NULL);
192 
193  void SetTreeID(Int_t treeID){fTreeID = treeID;};
194  Int_t GetTreeID(){return fTreeID;};
195 
196  Bool_t DoRegression() const { return fAnalysisType == Types::kRegression; }
197  void SetAnalysisType (Types::EAnalysisType t) { fAnalysisType = t;}
198  Types::EAnalysisType GetAnalysisType ( void ) { return fAnalysisType;}
199  inline void SetUseFisherCuts(Bool_t t=kTRUE)  { fUseFisherCuts = t;}
200  inline void SetMinLinCorrForFisher(Double_t min){fMinLinCorrForFisher = min;}
201  inline void SetUseExclusiveVars(Bool_t t=kTRUE){fUseExclusiveVars = t;}
202  inline void SetNVars(Int_t n){fNvars = n;}
203 
204 
205  private:
206  // utility functions
207 
208  // calculate the Purity out of the number of sig and bkg events collected
209  // from individual samples.
210 
211  // calculates the purity S/(S+B) of a given event sample
212  Double_t SamplePurity(EventList eventSample);
213 
214  UInt_t fNvars; // number of variables used to separate S and B
215  Int_t fNCuts; // number of grid points in variable cut scans
216  Bool_t fUseFisherCuts; // use multivariate splits using the Fisher criterion
217  Double_t fMinLinCorrForFisher; // the minimum linear correlation between two variables required for use in the Fisher criterion in node splitting
218  Bool_t fUseExclusiveVars; // individual variables already used in the Fisher criterion are no longer analysed individually for node splitting
219 
220  SeparationBase *fSepType; // the separation criterion
221  RegressionVariance *fRegType; // the separation criterion used in regression
222 
223  Double_t fMinSize; // min number of events in node
224  Double_t fMinNodeSize; // min fraction of training events in node
225  Double_t fMinSepGain; // minimum separation gain required to perform node splitting
226 
227  Bool_t fUseSearchTree; // cut scan done with binary trees or simple event loop.
228  Double_t fPruneStrength; // a parameter to set the "amount" of pruning; needs to be adjusted
229 
230  EPruneMethod fPruneMethod; // method used for pruning
231  Int_t fNNodesBeforePruning; // remember this one (in case of pruning, it allows one to monitor the number of nodes before/after pruning)
232 
233  Double_t fNodePurityLimit;// purity limit to decide whether a node is signal
234 
235  Bool_t fRandomisedTree; // choose a random set of variables at each node split
236  Int_t fUseNvars; // the number of variables used in randomised trees
237  Bool_t fUsePoissonNvars; // use "fUseNvars" not as a fixed number but as the mean of a Poisson distribution in each split
238 
239  TRandom3 *fMyTrandom; // random number generator for randomised trees
240 
241  std::vector< Double_t > fVariableImportance; // the relative importance of the different variables
242 
243  UInt_t fMaxDepth; // max depth
244  UInt_t fSigClass; // class which is treated as signal when building the tree
245  static const Int_t fgDebugLevel = 0; // debug level determining some printout/control plots etc.
246  Int_t fTreeID; // just an ID number given to the tree; makes debugging easier as the tree knows who it is
247 
248  Types::EAnalysisType fAnalysisType; // kClassification(=0=false) or kRegression(=1=true)
249 
249 
250  DataSetInfo*  fDataSetInfo;
251 
252 
253  ClassDef(DecisionTree,0); // implementation of a Decision Tree
254  };
255 
256 } // namespace TMVA
257 
258 #endif
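
Usage note (not part of the header above): a minimal sketch of how a tree might be constructed, trained and pruned through the public interface declared in this file. The event lists, the DataSetInfo object, the choice of TMVA::GiniIndex as the SeparationBase implementation and all argument values are illustrative assumptions, not a prescription from this header.

#include "TMVA/DecisionTree.h"
#include "TMVA/GiniIndex.h"

TMVA::DecisionTree* BuildExampleTree(const TMVA::DecisionTree::EventConstList& trainingEvents,
                                     const TMVA::DecisionTree::EventConstList& validationEvents,
                                     TMVA::DataSetInfo& dsi)
{
   // separation criterion used for node splitting; it must outlive the tree
   // (ownership handling is omitted in this sketch)
   TMVA::SeparationBase* sepType = new TMVA::GiniIndex();

   // minSize=0.01, nCuts=20 grid points per variable, signal class 0, no randomisation,
   // maximum depth 3; the remaining constructor arguments keep their defaults
   TMVA::DecisionTree* tree = new TMVA::DecisionTree(sepType, 0.01, 20, &dsi,
                                                     0, kFALSE, 0, kFALSE, 3);

   tree->BuildTree(trainingEvents);                                   // recursive splitting from the root
   tree->SetPruneMethod(TMVA::DecisionTree::kCostComplexityPruning);  // choose a pruning algorithm
   tree->SetPruneStrength(-1);                                        // < 0 -> automatic pruning
   tree->PruneTree(&validationEvents);                                // validation sample drives the pruning

   // classify one event (assuming the list is non-empty); the response is compared
   // against the node purity limit to call it signal- or background-like
   Double_t response = tree->CheckEvent(trainingEvents.front(), kFALSE);
   (void) response;

   return tree;
}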