Logo ROOT  
Reference Guide
DecisionTreeNode.h
Go to the documentation of this file.
1// @(#)root/tmva $Id$
2// Author: Andreas Hoecker, Joerg Stelzer, Helge Voss, Kai Voss, Eckhard von Toerne
3
4/**********************************************************************************
5 * Project: TMVA - a Root-integrated toolkit for multivariate data analysis *
6 * Package: TMVA *
7 * Class : DecisionTreeNode *
8 * Web : http://tmva.sourceforge.net *
9 * *
10 * Description: *
11 * Node for the Decision Tree *
12 * *
13 * Authors (alphabetical): *
14 * Andreas Hoecker <Andreas.Hocker@cern.ch> - CERN, Switzerland *
15 * Helge Voss <Helge.Voss@cern.ch> - MPI-K Heidelberg, Germany *
16 * Kai Voss <Kai.Voss@cern.ch> - U. of Victoria, Canada *
17 * Eckhard von Toerne <evt@physik.uni-bonn.de> - U. of Bonn, Germany *
18 * *
19 * Copyright (c) 2009: *
20 * CERN, Switzerland *
21 * U. of Victoria, Canada *
22 * MPI-K Heidelberg, Germany *
23 * U. of Bonn, Germany *
24 * *
25 * Redistribution and use in source and binary forms, with or without *
26 * modification, are permitted according to the terms listed in LICENSE *
27 * (http://tmva.sourceforge.net/LICENSE) *
28 **********************************************************************************/
29
30#ifndef ROOT_TMVA_DecisionTreeNode
31#define ROOT_TMVA_DecisionTreeNode
32
33//////////////////////////////////////////////////////////////////////////
34// //
35// DecisionTreeNode //
36// //
37// Node for the Decision Tree //
38// //
39//////////////////////////////////////////////////////////////////////////
40
41#include "TMVA/Node.h"
42
43#include "TMVA/Version.h"
44
45#include <iostream>
46#include <vector>
47#include <map>
48namespace TMVA {
49
51 {
52 public:
54 fSampleMax(),
55 fNodeR(0),fSubTreeR(0),fAlpha(0),fG(0),fNTerminal(0),
56 fNB(0),fNS(0),fSumTarget(0),fSumTarget2(0),fCC(0),
57 fNSigEvents ( 0 ), fNBkgEvents ( 0 ),
58 fNEvents ( -1 ),
65 fSeparationIndex (-1 ),
66 fSeparationGain ( -1 )
67 {
68 }
69 std::vector< Float_t > fSampleMin; // the minima for each ivar of the sample on the node during training
70 std::vector< Float_t > fSampleMax; // the maxima for each ivar of the sample on the node during training
71 Double_t fNodeR; // node resubstitution estimate, R(t)
72 Double_t fSubTreeR; // R(T) = Sum(R(t) : t in ~T)
73 Double_t fAlpha; // critical alpha for this node
74 Double_t fG; // minimum alpha in subtree rooted at this node
75 Int_t fNTerminal; // number of terminal nodes in subtree rooted at this node
76 Double_t fNB; // sum of weights of background events from the pruning sample in this node
77 Double_t fNS; // ditto for the signal events
78 Float_t fSumTarget; // sum of weight*target used for the calculation of the variance (regression)
79 Float_t fSumTarget2; // sum of weight*target^2 used for the calculation of the variance (regression)
80 Double_t fCC; // debug variable for cost complexity pruning ..
81
82 Float_t fNSigEvents; // sum of weights of signal events in the node
83 Float_t fNBkgEvents; // sum of weights of background events in the node
84 Float_t fNEvents; // number of events that entered the node (during training)
85 Float_t fNSigEvents_unweighted; // sum of unweighted signal events in the node
86 Float_t fNBkgEvents_unweighted; // sum of unweighted background events in the node
87 Float_t fNEvents_unweighted; // number of unweighted events that entered the node (during training)
88 Float_t fNSigEvents_unboosted; // sum of unboosted signal events in the node
89 Float_t fNBkgEvents_unboosted; // sum of unboosted background events in the node
90 Float_t fNEvents_unboosted; // number of unboosted events that entered the node (during training)
91 Float_t fSeparationIndex; // measure of "purity" (separation between S and B) AT this node
92 Float_t fSeparationGain; // measure of "purity", separation, or information gained BY this node's selection
93
94 // copy constructor
96 fSampleMin(),fSampleMax(), // Samplemin and max are reset in copy constructor
98 fAlpha(n.fAlpha), fG(n.fG),
100 fNB(n.fNB), fNS(n.fNS),
101 fSumTarget(0),fSumTarget2(0), // SumTarget reset in copy constructor
102 fCC(0),
104 fNEvents ( n.fNEvents ),
110 { }
111 };
112
113 class Event;
114 class MsgLogger;
115
116 class DecisionTreeNode: public Node {
117
118 public:
119
120 // constructor of an essentially "empty" node floating in space
122 // constructor of a daughter node as a daughter of 'p'
123 DecisionTreeNode (Node* p, char pos);
124
125 // copy constructor
126 DecisionTreeNode (const DecisionTreeNode &n, DecisionTreeNode* parent = NULL);
127
128 // destructor
129 virtual ~DecisionTreeNode();
130
131 virtual Node* CreateNode() const { return new DecisionTreeNode(); }
132
133 inline void SetNFisherCoeff(Int_t nvars){fFisherCoeff.resize(nvars);}
134 inline UInt_t GetNFisherCoeff() const { return fFisherCoeff.size();}
135 // set fisher coefficients
136 void SetFisherCoeff(Int_t ivar, Double_t coeff);
137 // get fisher coefficients
138 Double_t GetFisherCoeff(Int_t ivar) const {return fFisherCoeff.at(ivar);}
139
140 // test whether the event descends the tree at this node to the right
141 virtual Bool_t GoesRight( const Event & ) const;
142
143 // test whether the event descends the tree at this node to the left
144 virtual Bool_t GoesLeft ( const Event & ) const;
145
146 // set index of variable used for discrimination at this node
147 void SetSelector( Short_t i) { fSelector = i; }
148 // return index of variable used for discrimination at this node
149 Short_t GetSelector() const { return fSelector; }
150
151 // set the cut value applied at this node
153 // return the cut value applied at this node
154 Float_t GetCutValue ( void ) const { return fCutValue; }
155
156 // set true: if event variable > cutValue ==> signal , false otherwise
157 void SetCutType( Bool_t t ) { fCutType = t; }
158 // return kTRUE: Cuts select signal, kFALSE: Cuts select bkg
159 Bool_t GetCutType( void ) const { return fCutType; }
160
161 // set node type: 1 signal leaf, -1 bkg leaf, 0 intermediate Node
162 void SetNodeType( Int_t t ) { fNodeType = t;}
163 // return node type: 1 signal leaf, -1 bkg leaf, 0 intermediate Node
164 Int_t GetNodeType( void ) const { return fNodeType; }
165
166 //return S/(S+B) (purity) at this node (from training)
167 Float_t GetPurity( void ) const { return fPurity;}
168 //calculate S/(S+B) (purity) at this node (from training)
169 void SetPurity( void );
170
171 //set the response of the node (for regression)
173
174 //return the response of the node (for regression)
175 Float_t GetResponse( void ) const { return fResponse;}
176
177 //set the RMS of the response of the node (for regression)
178 void SetRMS( Float_t r ) { fRMS = r;}
179
180 //return the RMS of the response of the node (for regression)
181 Float_t GetRMS( void ) const { return fRMS;}
182
183 // set the sum of the signal weights in the node
185
186 // set the sum of the backgr weights in the node
188
189 // set the number of events that entered the node (during training)
190 void SetNEvents( Float_t nev ){ fTrainInfo->fNEvents =nev ; }
191
192 // set the sum of the unweighted signal events in the node
194
195 // set the sum of the unweighted backgr events in the node
197
198 // set the number of unweighted events that entered the node (during training)
200
201 // set the sum of the unboosted signal events in the node
203
204 // set the sum of the unboosted backgr events in the node
206
207 // set the number of unboosted events that entered the node (during training)
209
210 // increment the sum of the signal weights in the node
212
213 // increment the sum of the backgr weights in the node
215
216 // increment the number of events that entered the node (during training)
218
219 // increment the sum of the signal weights in the node
221
222 // increment the sum of the backgr weights in the node
224
225 // increment the number of events that entered the node (during training)
227
228 // return the sum of the signal weights in the node
229 Float_t GetNSigEvents( void ) const { return fTrainInfo->fNSigEvents; }
230
231 // return the sum of the backgr weights in the node
232 Float_t GetNBkgEvents( void ) const { return fTrainInfo->fNBkgEvents; }
233
234 // return the number of events that entered the node (during training)
235 Float_t GetNEvents( void ) const { return fTrainInfo->fNEvents; }
236
237 // return the sum of unweighted signal events in the node
239
240 // return the sum of unweighted backgr events in the node
242
243 // return the number of unweighted events that entered the node (during training)
245
246 // return the sum of unboosted signal weights in the node
248
249 // return the sum of unboosted backgr weights in the node
251
252 // return the number of unboosted events that entered the node (during training)
254
255
256 // set the chosen index, measure of "purity" (separation between S and B) AT this node
258 // return the separation index AT this node
260
261 // set the separation, or information gained BY this node's selection
263 // return the gain in separation obtained by this node's selection
265
266 // printout of the node
267 virtual void Print( std::ostream& os ) const;
268
269 // recursively print the node and its daughters (--> print the 'tree')
270 virtual void PrintRec( std::ostream& os ) const;
271
272 virtual void AddAttributesToNode(void* node) const;
273 virtual void AddContentToNode(std::stringstream& s) const;
274
275 // recursively clear the nodes content (S/N etc, but not the cut criteria)
277
278 // get pointers to children, mother in the tree
279
280 // return pointer to the left/right daughter or parent node
281 inline virtual DecisionTreeNode* GetLeft( ) const { return static_cast<DecisionTreeNode*>(fLeft); }
282 inline virtual DecisionTreeNode* GetRight( ) const { return static_cast<DecisionTreeNode*>(fRight); }
283 inline virtual DecisionTreeNode* GetParent( ) const { return static_cast<DecisionTreeNode*>(fParent); }
284
285 // set pointer to the left/right daughter and parent node
286 inline virtual void SetLeft (Node* l) { fLeft = l;}
287 inline virtual void SetRight (Node* r) { fRight = r;}
288 inline virtual void SetParent(Node* p) { fParent = p;}
289
290
291
292
293 // the node resubstitution estimate, R(t), for Cost Complexity pruning
294 inline void SetNodeR( Double_t r ) { fTrainInfo->fNodeR = r; }
295 inline Double_t GetNodeR( ) const { return fTrainInfo->fNodeR; }
296
297 // the resubstitution estimate, R(T_t), of the tree rooted at this node
299 inline Double_t GetSubTreeR( ) const { return fTrainInfo->fSubTreeR; }
300
301 // R(t) - R(T_t)
302 // the critical point alpha = -------------
303 // |~T_t| - 1
304 inline void SetAlpha( Double_t alpha ) { fTrainInfo->fAlpha = alpha; }
305 inline Double_t GetAlpha( ) const { return fTrainInfo->fAlpha; }
306
307 // the minimum alpha in the tree rooted at this node
309 inline Double_t GetAlphaMinSubtree( ) const { return fTrainInfo->fG; }
310
311 // number of terminal nodes in the subtree rooted here
312 inline void SetNTerminal( Int_t n ) { fTrainInfo->fNTerminal = n; }
313 inline Int_t GetNTerminal( ) const { return fTrainInfo->fNTerminal; }
314
315 // number of background/signal events from the pruning validation sample
316 inline void SetNBValidation( Double_t b ) { fTrainInfo->fNB = b; }
317 inline void SetNSValidation( Double_t s ) { fTrainInfo->fNS = s; }
318 inline Double_t GetNBValidation( ) const { return fTrainInfo->fNB; }
319 inline Double_t GetNSValidation( ) const { return fTrainInfo->fNS; }
320
321
324
327
328 inline Float_t GetSumTarget() const {return fTrainInfo? fTrainInfo->fSumTarget : -9999;}
329 inline Float_t GetSumTarget2() const {return fTrainInfo? fTrainInfo->fSumTarget2: -9999;}
330
331
332 // reset the pruning validation data
333 void ResetValidationData( );
334
335 // flag indicates whether this node is terminal
336 inline Bool_t IsTerminal() const { return fIsTerminalNode; }
337 inline void SetTerminal( Bool_t s = kTRUE ) { fIsTerminalNode = s; }
338 void PrintPrune( std::ostream& os ) const ;
339 void PrintRecPrune( std::ostream& os ) const;
340
341 void SetCC(Double_t cc);
342 Double_t GetCC() const {return (fTrainInfo? fTrainInfo->fCC : -1.);}
343
344 Float_t GetSampleMin(UInt_t ivar) const;
345 Float_t GetSampleMax(UInt_t ivar) const;
346 void SetSampleMin(UInt_t ivar, Float_t xmin);
347 void SetSampleMax(UInt_t ivar, Float_t xmax);
348
349 static bool fgIsTraining; // static variable to flag training phase in which we need fTrainInfo
350 static UInt_t fgTmva_Version_Code; // set only when read from weightfile
351
352 virtual Bool_t ReadDataRecord( std::istream& is, UInt_t tmva_Version_Code = TMVA_VERSION_CODE );
353 virtual void ReadAttributes(void* node, UInt_t tmva_Version_Code = TMVA_VERSION_CODE );
354 virtual void ReadContent(std::stringstream& s);
355
356 protected:
357
358 static MsgLogger& Log();
359
360 std::vector<Double_t> fFisherCoeff; // the fisher coeff (offset at the last element)
361
362 Float_t fCutValue; // cut value applied on this node to discriminate bkg against sig
363 Bool_t fCutType; // true: if event variable > cutValue ==> signal , false otherwise
364 Short_t fSelector; // index of variable used in node selection (decision tree)
365
366 Float_t fResponse; // response value in case of regression
367 Float_t fRMS; // response RMS of the regression node
368 Int_t fNodeType; // Type of node: -1 == Bkg-leaf, 1 == Signal-leaf, 0 = internal
369 Float_t fPurity; // the node purity
370
371 Bool_t fIsTerminalNode; //! flag to set node as terminal (i.e., without deleting its descendants)
372
374
375 private:
376
377 ClassDef(DecisionTreeNode,0); // Node for the Decision Tree
378 };
379} // namespace TMVA
380
381#endif
ROOT::R::TRInterface & r
Definition: Object.C:4
#define b(i)
Definition: RSha256.hxx:100
#define c(i)
Definition: RSha256.hxx:101
#define g(i)
Definition: RSha256.hxx:105
int Int_t
Definition: RtypesCore.h:41
unsigned int UInt_t
Definition: RtypesCore.h:42
bool Bool_t
Definition: RtypesCore.h:59
short Short_t
Definition: RtypesCore.h:35
double Double_t
Definition: RtypesCore.h:55
float Float_t
Definition: RtypesCore.h:53
const Bool_t kTRUE
Definition: RtypesCore.h:87
#define ClassDef(name, id)
Definition: Rtypes.h:326
float xmin
Definition: THbookFile.cxx:93
float xmax
Definition: THbookFile.cxx:93
#define TMVA_VERSION_CODE
Definition: Version.h:47
std::vector< Float_t > fSampleMax
DTNodeTrainingInfo(const DTNodeTrainingInfo &n)
std::vector< Float_t > fSampleMin
virtual void AddContentToNode(std::stringstream &s) const
adding attributes to tree node (well, was used in BinarySearchTree, and somehow I guess someone progr...
void SetNEvents_unweighted(Float_t nev)
Float_t GetNBkgEvents_unboosted(void) const
DTNodeTrainingInfo * fTrainInfo
flag to set node as terminal (i.e., without deleting its descendants)
virtual ~DecisionTreeNode()
destructor
Float_t GetNSigEvents_unweighted(void) const
Float_t GetNBkgEvents_unweighted(void) const
Double_t GetSubTreeR() const
Float_t GetSeparationIndex(void) const
void SetSeparationGain(Float_t sep)
void SetNBkgEvents(Float_t b)
Float_t GetNSigEvents_unboosted(void) const
Double_t GetNSValidation() const
void PrintPrune(std::ostream &os) const
printout of the node (can be read in with ReadDataRecord)
Float_t GetSumTarget() const
void PrintRecPrune(std::ostream &os) const
recursive printout of the node and its daughters
void SetFisherCoeff(Int_t ivar, Double_t coeff)
set fisher coefficients
void SetNSigEvents_unboosted(Float_t s)
void SetSumTarget2(Float_t t2)
void SetAlphaMinSubtree(Double_t g)
static UInt_t fgTmva_Version_Code
void IncrementNBkgEvents(Float_t b)
void SetNEvents_unboosted(Float_t nev)
Float_t GetNSigEvents(void) const
virtual void SetLeft(Node *l)
Double_t GetAlphaMinSubtree() const
void SetTerminal(Bool_t s=kTRUE)
Float_t GetNEvents_unweighted(void) const
void SetResponse(Float_t r)
UInt_t GetNFisherCoeff() const
void SetSampleMax(UInt_t ivar, Float_t xmax)
set the maximum of variable ivar from the training sample that pass/end up in this node
void ClearNodeAndAllDaughters()
clear the nodes (their S/N, Nevents etc), just keep the structure of the tree
virtual Bool_t GoesLeft(const Event &) const
test event if it descends the tree at this node to the left
virtual void ReadContent(std::stringstream &s)
reading attributes from tree node (well, was used in BinarySearchTree, and somehow I guess someone pr...
void SetNBValidation(Double_t b)
Float_t GetRMS(void) const
void IncrementNEvents(Float_t nev)
void SetPurity(void)
return the S/(S+B) (purity) for the node REM: even if nodes with purity 0.01 are very PURE background...
void SetSubTreeR(Double_t r)
void AddToSumTarget2(Float_t t2)
virtual void Print(std::ostream &os) const
print the node
virtual DecisionTreeNode * GetLeft() const
Double_t GetNodeR() const
Float_t GetSumTarget2() const
virtual Bool_t GoesRight(const Event &) const
test event if it descends the tree at this node to the right
DecisionTreeNode()
constructor of an essentially "empty" node floating in space
void SetNFisherCoeff(Int_t nvars)
virtual void AddAttributesToNode(void *node) const
add attribute to xml
Short_t GetSelector() const
virtual void ReadAttributes(void *node, UInt_t tmva_Version_Code=TMVA_VERSION_CODE)
void SetNSigEvents(Float_t s)
Float_t GetResponse(void) const
Float_t GetCutValue(void) const
Int_t GetNodeType(void) const
Double_t GetAlpha() const
Bool_t GetCutType(void) const
static MsgLogger & Log()
void ResetValidationData()
temporary stored node values (number of events, etc.) that originate not from the training but from t...
virtual void PrintRec(std::ostream &os) const
recursively print the node and its daughters (--> print the 'tree')
void SetNSigEvents_unweighted(Float_t s)
Float_t GetNEvents(void) const
Double_t GetCC() const
virtual Node * CreateNode() const
Double_t GetNBValidation() const
void SetAlpha(Double_t alpha)
void SetSeparationIndex(Float_t sep)
virtual void SetRight(Node *r)
virtual Bool_t ReadDataRecord(std::istream &is, UInt_t tmva_Version_Code=TMVA_VERSION_CODE)
Read the data block.
void SetSumTarget(Float_t t)
virtual void SetParent(Node *p)
void SetNodeR(Double_t r)
void SetNBkgEvents_unboosted(Float_t b)
Float_t GetPurity(void) const
Float_t GetNEvents_unboosted(void) const
void IncrementNSigEvents(Float_t s)
Float_t GetSeparationGain(void) const
Float_t GetSampleMax(UInt_t ivar) const
return the maximum of variable ivar from the training sample that pass/end up in this node
void SetCutValue(Float_t c)
Float_t GetNBkgEvents(void) const
Float_t GetSampleMin(UInt_t ivar) const
return the minimum of variable ivar from the training sample that pass/end up in this node
void SetSampleMin(UInt_t ivar, Float_t xmin)
set the minimum of variable ivar from the training sample that pass/end up in this node
void SetSelector(Short_t i)
std::vector< Double_t > fFisherCoeff
virtual DecisionTreeNode * GetParent() const
Double_t GetFisherCoeff(Int_t ivar) const
void SetNBkgEvents_unweighted(Float_t b)
void SetNSValidation(Double_t s)
void AddToSumTarget(Float_t t)
void SetNEvents(Float_t nev)
virtual DecisionTreeNode * GetRight() const
ostringstream derivative to redirect and format output
Definition: MsgLogger.h:59
Node for the BinarySearch or Decision Trees.
Definition: Node.h:56
Node * fLeft
Definition: Node.h:137
Node * fParent
Definition: Node.h:136
Node * fRight
Definition: Node.h:138
const Int_t n
Definition: legend1.C:16
static constexpr double s
create variable transformations
auto * l
Definition: textangle.C:4