Logo ROOT   master
Reference Guide
DecisionTree.h
Go to the documentation of this file.
1 // @(#)root/tmva $Id$
2 // Author: Andreas Hoecker, Joerg Stelzer, Helge Voss, Kai Voss, Jan Therhaag, Eckhard von Toerne
3 
4 /**********************************************************************************
5  * Project: TMVA - a Root-integrated toolkit for multivariate data analysis *
6  * Package: TMVA *
7  * Class : DecisionTree *
8  * Web : http://tmva.sourceforge.net *
9  * *
10  * Description: *
11  * Implementation of a Decision Tree *
12  * *
13  * Authors (alphabetical): *
14  * Andreas Hoecker <Andreas.Hocker@cern.ch> - CERN, Switzerland *
15  * Helge Voss <Helge.Voss@cern.ch> - MPI-K Heidelberg, Germany *
16  * Kai Voss <Kai.Voss@cern.ch> - U. of Victoria, Canada *
17  * Jan Therhaag <Jan.Therhaag@cern.ch> - U of Bonn, Germany *
18  * Eckhard v. Toerne <evt@uni-bonn.de> - U of Bonn, Germany *
19  * *
20  * Copyright (c) 2005-2011: *
21  * CERN, Switzerland *
22  * U. of Victoria, Canada *
23  * MPI-K Heidelberg, Germany *
24  * U. of Bonn, Germany *
25  * *
26  * Redistribution and use in source and binary forms, with or without *
27  * modification, are permitted according to the terms listed in LICENSE *
28  * (http://mva.sourceforge.net/license.txt) *
29  * *
30  **********************************************************************************/
31 
32 #ifndef ROOT_TMVA_DecisionTree
33 #define ROOT_TMVA_DecisionTree
34 
35 //////////////////////////////////////////////////////////////////////////
36 // //
37 // DecisionTree //
38 // //
39 // Implementation of a Decision Tree //
40 // //
41 //////////////////////////////////////////////////////////////////////////
42 
43 #include "TH2.h"
44 
45 #include "TMVA/Types.h"
46 #include "TMVA/DecisionTreeNode.h"
47 #include "TMVA/BinaryTree.h"
48 #include "TMVA/BinarySearchTree.h"
49 #include "TMVA/SeparationBase.h"
51 #include "TMVA/DataSetInfo.h"
52 
53 #ifdef R__USE_IMT
54 #include <ROOT/TThreadExecutor.hxx>
55 #include "TSystem.h"
56 #endif
57 
58 class TRandom3;
59 
60 namespace TMVA {
61 
62  class Event;
63 
64  class DecisionTree : public BinaryTree {
65 
66  private:
67 
68  static const Int_t fgRandomSeed; // set nonzero for debugging and zero for random seeds
69 
70  public:
71 
72  typedef std::vector<TMVA::Event*> EventList;
73  typedef std::vector<const TMVA::Event*> EventConstList;
74 
75  // the constructor needed for "reading" the decision tree from weight files
76  DecisionTree( void );
77 
78  // the constructor needed for constructing the decision tree via training with events
79  DecisionTree( SeparationBase *sepType, Float_t minSize,
80  Int_t nCuts, DataSetInfo* = NULL,
81  UInt_t cls =0,
82  Bool_t randomisedTree=kFALSE, Int_t useNvars=0, Bool_t usePoissonNvars=kFALSE,
83  UInt_t nMaxDepth=9999999,
84  Int_t iSeed=fgRandomSeed, Float_t purityLimit=0.5,
85  Int_t treeID = 0);
86 
87  // copy constructor
88  DecisionTree (const DecisionTree &d);
89 
90  virtual ~DecisionTree( void );
91 
92  // Retrieves the address of the root node
93  virtual DecisionTreeNode* GetRoot() const { return static_cast<TMVA::DecisionTreeNode*>(fRoot); }
94  virtual DecisionTreeNode * CreateNode(UInt_t) const { return new DecisionTreeNode(); }
95  virtual BinaryTree* CreateTree() const { return new DecisionTree(); }
96  static DecisionTree* CreateFromXML(void* node, UInt_t tmva_Version_Code = TMVA_VERSION_CODE);
97  virtual const char* ClassName() const { return "DecisionTree"; }
98 
99  // building of a tree by recursively splitting the nodes
100 
101  // UInt_t BuildTree( const EventList & eventSample,
102  // DecisionTreeNode *node = NULL);
103  UInt_t BuildTree( const EventConstList & eventSample,
104  DecisionTreeNode *node = NULL);
105  // determine how a node is split (which variable, which cut value)
106 
107  Double_t TrainNode( const EventConstList & eventSample, DecisionTreeNode *node ) { return TrainNodeFast( eventSample, node ); }
108  Double_t TrainNodeFast( const EventConstList & eventSample, DecisionTreeNode *node );
109  Double_t TrainNodeFull( const EventConstList & eventSample, DecisionTreeNode *node );
110  void GetRandomisedVariables(Bool_t *useVariable, UInt_t *variableMap, UInt_t & nVars);
111  std::vector<Double_t> GetFisherCoefficients(const EventConstList &eventSample, UInt_t nFisherVars, UInt_t *mapVarInFisher);
112 
113  // fill a tree whose structure is already given (just see how many signal/bkgr
114  // events end up in each node)
115 
116  void FillTree( const EventList & eventSample);
117 
118  // fill the existing decision tree structure by filling the event
119  // in from the top node and seeing where it happens to end up
120  void FillEvent( const TMVA::Event & event,
121  TMVA::DecisionTreeNode *node );
122 
123  // returns: 1 = Signal (right), -1 = Bkg (left)
124 
125  Double_t CheckEvent( const TMVA::Event * , Bool_t UseYesNoLeaf = kFALSE ) const;
127 
128  // return the individual relative variable importance
129  std::vector< Double_t > GetVariableImportance();
130 
132 
133  // clear the tree nodes (their S/N, Nevents etc), just keep the structure of the tree
134 
135  void ClearTree();
136 
137  // set pruning method
140 
141  // recursive pruning of the tree, validation sample required for automatic pruning
142  Double_t PruneTree( const EventConstList* validationSample = NULL );
143 
144  // manage the pruning strength parameter (iff < 0 -> automate the pruning process)
147 
148  // apply pruning validation sample to a decision tree
149  void ApplyValidationSample( const EventConstList* validationSample ) const;
150 
151  // return the misclassification rate of a pruned tree
152  Double_t TestPrunedTreeQuality( const DecisionTreeNode* dt = NULL, Int_t mode=0 ) const;
153 
154  // pass a single validation event through a pruned decision tree
155  void CheckEventWithPrunedTree( const TMVA::Event* ) const;
156 
157  // calculate the normalization factor for a pruning validation sample
158  Double_t GetSumWeights( const EventConstList* validationSample ) const;
159 
162 
163  void DescendTree( Node *n = NULL );
164  void SetParentTreeInNodes( Node *n = NULL );
165 
166  // retrieve node from the tree. Its position (up to a maximal tree depth of 64)
167  // is coded as a sequence of left-right moves starting from the root, coded as
168  // 0-1 bit patterns stored in the "long-integer" together with the depth
169  Node* GetNode( ULong_t sequence, UInt_t depth );
170 
171  UInt_t CleanTree(DecisionTreeNode *node=NULL);
172 
173  void PruneNode(TMVA::DecisionTreeNode *node);
174 
175  // prune a node from the tree without deleting its descendants; allows one to
176  // effectively prune a tree many times without making deep copies
178 
180 
181 
183 
184  void SetTreeID(Int_t treeID){fTreeID = treeID;};
185  Int_t GetTreeID(){return fTreeID;};
186 
193  inline void SetNVars(Int_t n){fNvars = n;}
194 
195  private:
196  // utility functions
197 
198  // calculate the Purity out of the number of sig and bkg events collected
199  // from individual samples.
200 
201  // calculates the purity S/(S+B) of a given event sample
202  Double_t SamplePurity(EventList eventSample);
203 
204  UInt_t fNvars; // number of variables used to separate S and B
205  Int_t fNCuts; // number of grid points in variable cut scans
206  Bool_t fUseFisherCuts; // use multivariate splits using the Fisher criterion
207  Double_t fMinLinCorrForFisher; // the minimum linear correlation between two variables demanded for use in Fisher criterion in node splitting
208  Bool_t fUseExclusiveVars; // individual variables already used in Fisher criterion are no longer analysed individually for node splitting
209 
210  SeparationBase *fSepType; // the separation criterion
211  RegressionVariance *fRegType; // the separation criterion used in Regression
212 
213  Double_t fMinSize; // min number of events in node
214  Double_t fMinNodeSize; // min fraction of training events in node
215  Double_t fMinSepGain; // min separation gain required to perform node splitting
216 
217  Bool_t fUseSearchTree; // cut scan done with binary trees or simple event loop.
218  Double_t fPruneStrength; // a parameter to set the "amount" of pruning..needs to be adjusted
219 
220  EPruneMethod fPruneMethod; // method used for pruning
221  Int_t fNNodesBeforePruning; // remember this one (in case of pruning, it allows monitoring the before/after)
222 
223  Double_t fNodePurityLimit;// purity limit to decide whether a node is signal
224 
225  Bool_t fRandomisedTree; // choose at each node splitting a random set of variables
226  Int_t fUseNvars; // the number of variables used in randomised trees
227  Bool_t fUsePoissonNvars; // use "fUseNvars" not as a fixed number but as the mean of a Poisson distr. in each split
228 
229  TRandom3 *fMyTrandom; // random number generator for randomised trees
230 
231  std::vector< Double_t > fVariableImportance; // the relative importance of the different variables
232 
233  UInt_t fMaxDepth; // max depth
234  UInt_t fSigClass; // class which is treated as signal when building the tree
235  static const Int_t fgDebugLevel = 0; // debug level determining some printout/control plots etc.
236  Int_t fTreeID; // just an ID number given to the tree; makes debugging easier as the tree knows which one it is.
237 
238  Types::EAnalysisType fAnalysisType; // kClassification(=0=false) or kRegression(=1=true)
239 
241 
242  ClassDef(DecisionTree,0); // implementation of a Decision Tree
243  };
244 
245 } // namespace TMVA
246 
247 #endif
void SetPruneMethod(EPruneMethod m=kCostComplexityPruning)
Definition: DecisionTree.h:139
DataSetInfo * fDataSetInfo
Definition: DecisionTree.h:240
virtual BinaryTree * CreateTree() const
Definition: DecisionTree.h:95
Random number generator class based on M.
Definition: TRandom3.h:27
unsigned long ULong_t
Definition: CPyCppyy.h:51
#define TMVA_VERSION_CODE
Definition: Version.h:47
auto * m
Definition: textangle.C:8
float Float_t
Definition: RtypesCore.h:55
unsigned int UInt_t
Definition: CPyCppyy.h:44
Double_t CheckEvent(const TMVA::Event *, Bool_t UseYesNoLeaf=kFALSE) const
the event e is put into the decision tree (starting at the root node) and the output is NodeType (sig...
Double_t GetNodePurityLimit() const
Definition: DecisionTree.h:161
UInt_t GetNNodes() const
Definition: BinaryTree.h:86
EPruneMethod fPruneMethod
Definition: DecisionTree.h:220
EAnalysisType
Definition: Types.h:127
Types::EAnalysisType GetAnalysisType(void)
Definition: DecisionTree.h:189
Calculate the "SeparationGain" for Regression analysis separation criteria used in various training a...
std::vector< Double_t > GetFisherCoefficients(const EventConstList &eventSample, UInt_t nFisherVars, UInt_t *mapVarInFisher)
calculate the fisher coefficients for the event sample and the variables used
bool Bool_t
Definition: RtypesCore.h:61
void SetUseExclusiveVars(Bool_t t=kTRUE)
Definition: DecisionTree.h:192
Double_t fNodePurityLimit
Definition: DecisionTree.h:223
virtual ~DecisionTree(void)
destructor
Double_t TestPrunedTreeQuality(const DecisionTreeNode *dt=NULL, Int_t mode=0) const
return the misclassification rate of a pruned tree a "pruned tree" may have set the variable "IsTermi...
virtual DecisionTreeNode * GetRoot() const
Definition: DecisionTree.h:93
void CheckEventWithPrunedTree(const TMVA::Event *) const
pass a single validation event through a pruned decision tree on the way down the tree...
void SetNodePurityLimit(Double_t p)
Definition: DecisionTree.h:160
std::vector< Double_t > GetVariableImportance()
Return the relative variable importance, normalized to all variables together having the importance 1...
void SetAnalysisType(Types::EAnalysisType t)
Definition: DecisionTree.h:188
std::vector< const TMVA::Event * > EventConstList
Definition: DecisionTree.h:73
Base class for BinarySearch and Decision Trees.
Definition: BinaryTree.h:62
#define ClassDef(name, id)
Definition: Rtypes.h:322
static const Int_t fgRandomSeed
Definition: DecisionTree.h:68
void FillTree(const EventList &eventSample)
fill the existing decision tree structure by filling event in from the top node and see where the...
Double_t SamplePurity(EventList eventSample)
calculates the purity S/(S+B) of a given event sample
std::vector< Double_t > fVariableImportance
Definition: DecisionTree.h:231
Double_t GetSumWeights(const EventConstList *validationSample) const
calculate the normalization factor for a pruning validation sample
Class that contains all the data information.
Definition: DataSetInfo.h:60
void SetTreeID(Int_t treeID)
Definition: DecisionTree.h:184
UInt_t CountLeafNodes(TMVA::Node *n=NULL)
return the number of terminal nodes in the sub-tree below Node n
Double_t TrainNodeFast(const EventConstList &eventSample, DecisionTreeNode *node)
Decide how to split a node using one of the variables that gives the best separation of signal/backgr...
void DescendTree(Node *n=NULL)
descend a tree to find all its leaf nodes
void FillEvent(const TMVA::Event &event, TMVA::DecisionTreeNode *node)
fill the existing decision tree structure by filling event in from the top node and see where the...
Double_t fPruneStrength
Definition: DecisionTree.h:218
Bool_t DoRegression() const
Definition: DecisionTree.h:187
Double_t fMinLinCorrForFisher
Definition: DecisionTree.h:207
void SetNVars(Int_t n)
Definition: DecisionTree.h:193
void SetMinLinCorrForFisher(Double_t min)
Definition: DecisionTree.h:191
UInt_t CleanTree(DecisionTreeNode *node=NULL)
remove those last splits that result in two leaf nodes that are both of the type (i.e.
virtual DecisionTreeNode * CreateNode(UInt_t) const
Definition: DecisionTree.h:94
Int_t GetNNodesBeforePruning()
Definition: DecisionTree.h:179
void SetPruneStrength(Double_t p)
Definition: DecisionTree.h:145
void GetRandomisedVariables(Bool_t *useVariable, UInt_t *variableMap, UInt_t &nVars)
Implementation of a Decision Tree.
Definition: DecisionTree.h:64
Double_t TrainNodeFull(const EventConstList &eventSample, DecisionTreeNode *node)
train a node by finding the single optimal cut for a single variable that best separates signal and b...
void SetParentTreeInNodes(Node *n=NULL)
descend a tree to find all its leaf nodes, fill max depth reached in the tree at the same time...
An interface to calculate the "SeparationGain" for different separation criteria used in various trai...
void PruneNodeInPlace(TMVA::DecisionTreeNode *node)
prune a node temporarily (without actually deleting its descendants which allows testing the pruned t...
TMVA::DecisionTreeNode * GetEventNode(const TMVA::Event &e) const
get the pointer to the leaf node where a particular event ends up in...
std::vector< TMVA::Event * > EventList
Definition: DecisionTree.h:72
void SetUseFisherCuts(Bool_t t=kTRUE)
Definition: DecisionTree.h:190
void ApplyValidationSample(const EventConstList *validationSample) const
run the validation sample through the (pruned) tree and fill in the nodes the variables NSValidation ...
const Bool_t kFALSE
Definition: RtypesCore.h:90
#define d(i)
Definition: RSha256.hxx:102
TRandom3 * fMyTrandom
Definition: DecisionTree.h:229
double Double_t
Definition: RtypesCore.h:57
Node * GetNode(ULong_t sequence, UInt_t depth)
retrieve node from the tree.
static const Int_t fgDebugLevel
Definition: DecisionTree.h:235
void ClearTree()
clear the tree nodes (their S/N, Nevents etc), just keep the structure of the tree ...
Types::EAnalysisType fAnalysisType
Definition: DecisionTree.h:238
static DecisionTree * CreateFromXML(void *node, UInt_t tmva_Version_Code=TMVA_VERSION_CODE)
re-create a new tree (decision tree or search tree) from XML
you should not use this method at all Int_t Int_t Double_t Double_t Double_t e
Definition: TRolke.cxx:630
RegressionVariance * fRegType
Definition: DecisionTree.h:211
SeparationBase * fSepType
Definition: DecisionTree.h:210
Double_t PruneTree(const EventConstList *validationSample=NULL)
prune (get rid of internal nodes) the Decision tree to avoid overtraining several different pruning m...
create variable transformations
Node for the BinarySearch or Decision Trees.
Definition: Node.h:56
UInt_t BuildTree(const EventConstList &eventSample, DecisionTreeNode *node=NULL)
building the decision tree by recursively calling the splitting of one (root-) node into two daughter...
DecisionTree(void)
default constructor using the GiniIndex as separation criterion, no restrictions on minimum number of ...
Double_t GetPruneStrength() const
Definition: DecisionTree.h:146
const Bool_t kTRUE
Definition: RtypesCore.h:89
virtual const char * ClassName() const
Definition: DecisionTree.h:97
const Int_t n
Definition: legend1.C:16
Double_t TrainNode(const EventConstList &eventSample, DecisionTreeNode *node)
Definition: DecisionTree.h:107
void PruneNode(TMVA::DecisionTreeNode *node)
prune away the subtree below the node