DecisionTreeNode.h
// @(#)root/tmva $Id$
// Author: Andreas Hoecker, Joerg Stelzer, Helge Voss, Kai Voss, Eckhard von Toerne

/**********************************************************************************
 * Project: TMVA - a Root-integrated toolkit for multivariate data analysis      *
 * Package: TMVA                                                                  *
 * Class  : DecisionTreeNode                                                      *
 * Web    : http://tmva.sourceforge.net                                           *
 *                                                                                *
 * Description:                                                                   *
 *      Node for the Decision Tree                                                *
 *                                                                                *
 * Authors (alphabetical):                                                        *
 *      Andreas Hoecker    <Andreas.Hocker@cern.ch>   - CERN, Switzerland         *
 *      Helge Voss         <Helge.Voss@cern.ch>       - MPI-K Heidelberg, Germany *
 *      Kai Voss           <Kai.Voss@cern.ch>         - U. of Victoria, Canada    *
 *      Eckhard von Toerne <evt@physik.uni-bonn.de>   - U. of Bonn, Germany       *
 *                                                                                *
 * Copyright (c) 2009:                                                            *
 *      CERN, Switzerland                                                         *
 *      U. of Victoria, Canada                                                    *
 *      MPI-K Heidelberg, Germany                                                 *
 *      U. of Bonn, Germany                                                       *
 *                                                                                *
 * Redistribution and use in source and binary forms, with or without            *
 * modification, are permitted according to the terms listed in LICENSE          *
 * (http://tmva.sourceforge.net/LICENSE)                                          *
 **********************************************************************************/

#ifndef ROOT_TMVA_DecisionTreeNode
#define ROOT_TMVA_DecisionTreeNode

//////////////////////////////////////////////////////////////////////////
//                                                                      //
// DecisionTreeNode                                                     //
//                                                                      //
// Node for the Decision Tree                                           //
//                                                                      //
//////////////////////////////////////////////////////////////////////////

#ifndef ROOT_TMVA_Node
#include "TMVA/Node.h"
#endif

#ifndef ROOT_TMVA_Version
#include "TMVA/Version.h"
#endif

#include <iostream>
#include <vector>
#include <map>

namespace TMVA {

   class DTNodeTrainingInfo
   {
   public:
      DTNodeTrainingInfo() : fSampleMin(),
         fSampleMax(),
         fNodeR(0), fSubTreeR(0), fAlpha(0), fG(0), fNTerminal(0),
         fNB(0), fNS(0), fSumTarget(0), fSumTarget2(0), fCC(0),
         fNSigEvents ( 0 ), fNBkgEvents ( 0 ),
         fNEvents ( -1 ),
         fNSigEvents_unweighted ( 0 ), fNBkgEvents_unweighted ( 0 ),
         fNEvents_unweighted ( 0 ),
         fNSigEvents_unboosted ( 0 ), fNBkgEvents_unboosted ( 0 ),
         fNEvents_unboosted ( 0 ),
         fSeparationIndex ( -1 ),
         fSeparationGain ( -1 )
      {
      }
      std::vector< Float_t > fSampleMin; // the minima for each ivar of the sample on the node during training
      std::vector< Float_t > fSampleMax; // the maxima for each ivar of the sample on the node during training
      Double_t fNodeR;     // node resubstitution estimate, R(t)
      Double_t fSubTreeR;  // R(T) = Sum(R(t) : t in ~T)
      Double_t fAlpha;     // critical alpha for this node
      Double_t fG;         // minimum alpha in subtree rooted at this node
      Int_t    fNTerminal; // number of terminal nodes in subtree rooted at this node
      Double_t fNB;        // sum of weights of background events from the pruning sample in this node
      Double_t fNS;        // ditto for the signal events
      Float_t  fSumTarget; // sum of weight*target, used for the calculation of the variance (regression)
      Float_t  fSumTarget2;// sum of weight*target^2, used for the calculation of the variance (regression)
      Double_t fCC;        // debug variable for cost complexity pruning

      Float_t  fNSigEvents;            // sum of weights of signal events in the node
      Float_t  fNBkgEvents;            // sum of weights of background events in the node
      Float_t  fNEvents;               // number of events that entered the node (during training)
      Float_t  fNSigEvents_unweighted; // number of unweighted signal events in the node
      Float_t  fNBkgEvents_unweighted; // number of unweighted background events in the node
      Float_t  fNEvents_unweighted;    // number of unweighted events that entered the node (during training)
      Float_t  fNSigEvents_unboosted;  // number of unboosted signal events in the node
      Float_t  fNBkgEvents_unboosted;  // number of unboosted background events in the node
      Float_t  fNEvents_unboosted;     // number of unboosted events that entered the node (during training)
      Float_t  fSeparationIndex;       // measure of "purity" (separation between S and B) AT this node
      Float_t  fSeparationGain;        // measure of "purity", separation, or information gained BY this node's selection

      // copy constructor
      DTNodeTrainingInfo(const DTNodeTrainingInfo& n) :
         fSampleMin(), fSampleMax(),   // SampleMin and SampleMax are reset in the copy constructor
         fNodeR(n.fNodeR), fSubTreeR(n.fSubTreeR),
         fAlpha(n.fAlpha), fG(n.fG),
         fNTerminal(n.fNTerminal),
         fNB(n.fNB), fNS(n.fNS),
         fSumTarget(0), fSumTarget2(0),// SumTarget is reset in the copy constructor
         fCC(0),
         fNSigEvents ( n.fNSigEvents ), fNBkgEvents ( n.fNBkgEvents ),
         fNEvents ( n.fNEvents ),
         fNSigEvents_unweighted ( n.fNSigEvents_unweighted ),
         fNBkgEvents_unweighted ( n.fNBkgEvents_unweighted ),
         fNEvents_unweighted ( n.fNEvents_unweighted ),
         fSeparationIndex( n.fSeparationIndex ),
         fSeparationGain ( n.fSeparationGain )
      { }
   };

   class Event;
   class MsgLogger;

   class DecisionTreeNode: public Node {

   public:

      // constructor of an essentially "empty" node floating in space
      DecisionTreeNode ();
      // constructor of a daughter node as a daughter of 'p'
      DecisionTreeNode (Node* p, char pos);

      // copy constructor
      DecisionTreeNode (const DecisionTreeNode &n, DecisionTreeNode* parent = NULL);

      // destructor
      virtual ~DecisionTreeNode();

      virtual Node* CreateNode() const { return new DecisionTreeNode(); }

      inline void SetNFisherCoeff(Int_t nvars) { fFisherCoeff.resize(nvars); }
      inline UInt_t GetNFisherCoeff() const { return fFisherCoeff.size(); }
      // set fisher coefficients
      void SetFisherCoeff(Int_t ivar, Double_t coeff);
      // get fisher coefficients
      Double_t GetFisherCoeff(Int_t ivar) const { return fFisherCoeff.at(ivar); }
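      // Example (illustrative, not part of the original header): for a node that cuts
      // on a Fisher-like linear combination of 3 input variables, SetNFisherCoeff(4)
      // would reserve one coefficient per variable plus the offset, which by the
      // convention noted at fFisherCoeff below is stored in the last element;
      // GoesRight() then compares the linear combination against the cut value.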

      // test if the event descends the tree at this node to the right
      virtual Bool_t GoesRight( const Event & ) const;

      // test if the event descends the tree at this node to the left
      virtual Bool_t GoesLeft ( const Event & ) const;

      // set index of variable used for discrimination at this node
      void SetSelector( Short_t i ) { fSelector = i; }
      // return index of variable used for discrimination at this node
      Short_t GetSelector() const { return fSelector; }

      // set the cut value applied at this node
      void SetCutValue ( Float_t c ) { fCutValue = c; }
      // return the cut value applied at this node
      Float_t GetCutValue ( void ) const { return fCutValue; }

      // set true: if event variable > cutValue ==> signal , false otherwise
      void SetCutType( Bool_t t ) { fCutType = t; }
      // return kTRUE: Cuts select signal, kFALSE: Cuts select bkg
      Bool_t GetCutType( void ) const { return fCutType; }
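      // Example (illustrative, not part of the original header): with fSelector = 2,
      // fCutValue = 0.5 and fCutType = kTRUE, an event whose variable 2 lies above 0.5
      // descends to the right, i.e. towards the signal-like daughter (see GoesRight()).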

      // set node type: 1 signal node, -1 bkg leaf, 0 intermediate node
      void SetNodeType( Int_t t ) { fNodeType = t; }
      // return node type: 1 signal node, -1 bkg leaf, 0 intermediate node
      Int_t GetNodeType( void ) const { return fNodeType; }

      // return S/(S+B) (purity) at this node (from training)
      Float_t GetPurity( void ) const { return fPurity; }
      // calculate S/(S+B) (purity) at this node (from training)
      void SetPurity( void );

      // set the response of the node (for regression)
      void SetResponse( Float_t r ) { fResponse = r; }

      // return the response of the node (for regression)
      Float_t GetResponse( void ) const { return fResponse; }

      // set the RMS of the response of the node (for regression)
      void SetRMS( Float_t r ) { fRMS = r; }

      // return the RMS of the response of the node (for regression)
      Float_t GetRMS( void ) const { return fRMS; }

      // set the sum of the signal weights in the node
      void SetNSigEvents( Float_t s ) { fTrainInfo->fNSigEvents = s; }

      // set the sum of the backgr weights in the node
      void SetNBkgEvents( Float_t b ) { fTrainInfo->fNBkgEvents = b; }

      // set the number of events that entered the node (during training)
      void SetNEvents( Float_t nev ) { fTrainInfo->fNEvents = nev; }

      // set the number of unweighted signal events in the node
      void SetNSigEvents_unweighted( Float_t s ) { fTrainInfo->fNSigEvents_unweighted = s; }

      // set the number of unweighted backgr events in the node
      void SetNBkgEvents_unweighted( Float_t b ) { fTrainInfo->fNBkgEvents_unweighted = b; }

      // set the number of unweighted events that entered the node (during training)
      void SetNEvents_unweighted( Float_t nev ) { fTrainInfo->fNEvents_unweighted = nev; }

      // set the number of unboosted signal events in the node
      void SetNSigEvents_unboosted( Float_t s ) { fTrainInfo->fNSigEvents_unboosted = s; }

      // set the number of unboosted backgr events in the node
      void SetNBkgEvents_unboosted( Float_t b ) { fTrainInfo->fNBkgEvents_unboosted = b; }

      // set the number of unboosted events that entered the node (during training)
      void SetNEvents_unboosted( Float_t nev ) { fTrainInfo->fNEvents_unboosted = nev; }

      // increment the sum of the signal weights in the node
      void IncrementNSigEvents( Float_t s ) { fTrainInfo->fNSigEvents += s; }

      // increment the sum of the backgr weights in the node
      void IncrementNBkgEvents( Float_t b ) { fTrainInfo->fNBkgEvents += b; }

      // increment the number of events that entered the node (during training)
      void IncrementNEvents( Float_t nev ) { fTrainInfo->fNEvents += nev; }

      // increment the number of unweighted signal events in the node
      void IncrementNSigEvents_unweighted( ) { fTrainInfo->fNSigEvents_unweighted += 1; }

      // increment the number of unweighted backgr events in the node
      void IncrementNBkgEvents_unweighted( ) { fTrainInfo->fNBkgEvents_unweighted += 1; }

      // increment the number of unweighted events that entered the node (during training)
      void IncrementNEvents_unweighted( ) { fTrainInfo->fNEvents_unweighted += 1; }

      // return the sum of the signal weights in the node
      Float_t GetNSigEvents( void ) const { return fTrainInfo->fNSigEvents; }

      // return the sum of the backgr weights in the node
      Float_t GetNBkgEvents( void ) const { return fTrainInfo->fNBkgEvents; }

      // return the number of events that entered the node (during training)
      Float_t GetNEvents( void ) const { return fTrainInfo->fNEvents; }

      // return the number of unweighted signal events in the node
      Float_t GetNSigEvents_unweighted( void ) const { return fTrainInfo->fNSigEvents_unweighted; }

      // return the number of unweighted backgr events in the node
      Float_t GetNBkgEvents_unweighted( void ) const { return fTrainInfo->fNBkgEvents_unweighted; }

      // return the number of unweighted events that entered the node (during training)
      Float_t GetNEvents_unweighted( void ) const { return fTrainInfo->fNEvents_unweighted; }

      // return the number of unboosted signal events in the node
      Float_t GetNSigEvents_unboosted( void ) const { return fTrainInfo->fNSigEvents_unboosted; }

      // return the number of unboosted backgr events in the node
      Float_t GetNBkgEvents_unboosted( void ) const { return fTrainInfo->fNBkgEvents_unboosted; }

      // return the number of unboosted events that entered the node (during training)
      Float_t GetNEvents_unboosted( void ) const { return fTrainInfo->fNEvents_unboosted; }

      // set the chosen index, measure of "purity" (separation between S and B) AT this node
      void SetSeparationIndex( Float_t sep ) { fTrainInfo->fSeparationIndex = sep; }
      // return the separation index AT this node
      Float_t GetSeparationIndex( void ) const { return fTrainInfo->fSeparationIndex; }

      // set the separation, or information gained BY this node's selection
      void SetSeparationGain( Float_t sep ) { fTrainInfo->fSeparationGain = sep; }
      // return the gain in separation obtained by this node's selection
      Float_t GetSeparationGain( void ) const { return fTrainInfo->fSeparationGain; }

      // printout of the node
      virtual void Print( std::ostream& os ) const;

      // recursively print the node and its daughters (--> print the 'tree')
      virtual void PrintRec( std::ostream& os ) const;

      virtual void AddAttributesToNode(void* node) const;
      virtual void AddContentToNode(std::stringstream& s) const;

      // recursively clear the node's content (S/N etc, but not the cut criteria)
      void ClearNodeAndAllDaughters();

      // get pointers to children, mother in the tree

      // return pointer to the left/right daughter or parent node
      inline virtual DecisionTreeNode* GetLeft( )   const { return dynamic_cast<DecisionTreeNode*>(fLeft); }
      inline virtual DecisionTreeNode* GetRight( )  const { return dynamic_cast<DecisionTreeNode*>(fRight); }
      inline virtual DecisionTreeNode* GetParent( ) const { return dynamic_cast<DecisionTreeNode*>(fParent); }

      // set pointer to the left/right daughter and parent node
      inline virtual void SetLeft  (Node* l) { fLeft   = dynamic_cast<DecisionTreeNode*>(l); }
      inline virtual void SetRight (Node* r) { fRight  = dynamic_cast<DecisionTreeNode*>(r); }
      inline virtual void SetParent(Node* p) { fParent = dynamic_cast<DecisionTreeNode*>(p); }

      // the node resubstitution estimate, R(t), for Cost Complexity pruning
      inline void SetNodeR( Double_t r ) { fTrainInfo->fNodeR = r; }
      inline Double_t GetNodeR( ) const { return fTrainInfo->fNodeR; }

      // the resubstitution estimate, R(T_t), of the tree rooted at this node
      inline void SetSubTreeR( Double_t r ) { fTrainInfo->fSubTreeR = r; }
      inline Double_t GetSubTreeR( ) const { return fTrainInfo->fSubTreeR; }

      // the critical point alpha = ( R(t) - R(T_t) ) / ( |~T_t| - 1 )
      inline void SetAlpha( Double_t alpha ) { fTrainInfo->fAlpha = alpha; }
      inline Double_t GetAlpha( ) const { return fTrainInfo->fAlpha; }
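      // Worked example (illustrative, not part of the original header): a node with
      // R(t) = 0.20, subtree estimate R(T_t) = 0.12 and |~T_t| = 3 terminal nodes in
      // its subtree has alpha = (0.20 - 0.12) / (3 - 1) = 0.04.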

      // the minimum alpha in the tree rooted at this node
      inline void SetAlphaMinSubtree( Double_t g ) { fTrainInfo->fG = g; }
      inline Double_t GetAlphaMinSubtree( ) const { return fTrainInfo->fG; }

      // number of terminal nodes in the subtree rooted here
      inline void SetNTerminal( Int_t n ) { fTrainInfo->fNTerminal = n; }
      inline Int_t GetNTerminal( ) const { return fTrainInfo->fNTerminal; }

      // number of background/signal events from the pruning validation sample
      inline void SetNBValidation( Double_t b ) { fTrainInfo->fNB = b; }
      inline void SetNSValidation( Double_t s ) { fTrainInfo->fNS = s; }
      inline Double_t GetNBValidation( ) const { return fTrainInfo->fNB; }
      inline Double_t GetNSValidation( ) const { return fTrainInfo->fNS; }

      inline void SetSumTarget(Float_t t)   { fTrainInfo->fSumTarget = t; }
      inline void SetSumTarget2(Float_t t2) { fTrainInfo->fSumTarget2 = t2; }

      inline void AddToSumTarget(Float_t t)   { fTrainInfo->fSumTarget += t; }
      inline void AddToSumTarget2(Float_t t2) { fTrainInfo->fSumTarget2 += t2; }

      inline Float_t GetSumTarget() const  { return fTrainInfo ? fTrainInfo->fSumTarget  : -9999; }
      inline Float_t GetSumTarget2() const { return fTrainInfo ? fTrainInfo->fSumTarget2 : -9999; }
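      // Note (illustrative, not part of the original header): from these sums the
      // weighted target variance of the node, used for regression splits, can be
      // estimated roughly as fSumTarget2/W - (fSumTarget/W)^2, with W the sum of weights.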

      // reset the pruning validation data
      void ResetValidationData( );

      // flag indicates whether this node is terminal
      inline Bool_t IsTerminal() const { return fIsTerminalNode; }
      inline void SetTerminal( Bool_t s = kTRUE ) { fIsTerminalNode = s; }
      void PrintPrune( std::ostream& os ) const;
      void PrintRecPrune( std::ostream& os ) const;

      void SetCC(Double_t cc);
      Double_t GetCC() const { return (fTrainInfo ? fTrainInfo->fCC : -1.); }

      Float_t GetSampleMin(UInt_t ivar) const;
      Float_t GetSampleMax(UInt_t ivar) const;
      void SetSampleMin(UInt_t ivar, Float_t xmin);
      void SetSampleMax(UInt_t ivar, Float_t xmax);

      static bool fgIsTraining; // static variable to flag the training phase, in which we need fTrainInfo
      static UInt_t fgTmva_Version_Code; // set only when read from weightfile

   protected:

      static MsgLogger& Log();

      std::vector<Double_t> fFisherCoeff; // the Fisher coefficients (offset at the last element)

      Float_t fCutValue;  // cut value applied on this node to discriminate bkg against sig
      Bool_t  fCutType;   // true: if event variable > cutValue ==> signal , false otherwise
      Short_t fSelector;  // index of variable used in node selection (decision tree)

      Float_t fResponse;  // response value in case of regression
      Float_t fRMS;       // response RMS of the regression node
      Int_t   fNodeType;  // type of node: -1 == Bkg-leaf, 1 == Signal-leaf, 0 == internal
      Float_t fPurity;    // the node purity

      Bool_t fIsTerminalNode; //! flag to set node as terminal (i.e., without deleting its descendants)

      DTNodeTrainingInfo* fTrainInfo; // training information, needed only during the training phase (see fgIsTraining)

   private:

      virtual void ReadAttributes(void* node, UInt_t tmva_Version_Code = TMVA_VERSION_CODE );
      virtual Bool_t ReadDataRecord( std::istream& is, UInt_t tmva_Version_Code = TMVA_VERSION_CODE );
      virtual void ReadContent(std::stringstream& s);

      ClassDef(DecisionTreeNode,0); // Node for the Decision Tree

   };
} // namespace TMVA

#endif
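
A minimal usage sketch (illustrative, not part of the header above): client code could assemble a single split from DecisionTreeNode objects using the setters declared in this file. The variable index, cut value and the 'l'/'r' position tags are assumptions chosen for illustration; in TMVA the owning DecisionTree normally creates and manages its nodes.

#include "TMVA/DecisionTreeNode.h"

void MakeToySplit()
{
   // root node: split on input variable 0 at a cut value of 1.5,
   // values above the cut being treated as signal-like (cut type kTRUE)
   TMVA::DecisionTreeNode* root = new TMVA::DecisionTreeNode();
   root->SetSelector(0);
   root->SetCutValue(1.5);
   root->SetCutType(kTRUE);
   root->SetNodeType(0);                  // 0 == intermediate node

   // daughters take a pointer to their parent and a position tag
   TMVA::DecisionTreeNode* left  = new TMVA::DecisionTreeNode(root, 'l');
   TMVA::DecisionTreeNode* right = new TMVA::DecisionTreeNode(root, 'r');
   left ->SetNodeType(-1);                // -1 == background leaf
   right->SetNodeType( 1);                //  1 == signal leaf
   root->SetLeft(left);
   root->SetRight(right);
}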