Logo ROOT  
Reference Guide
 
Loading...
Searching...
No Matches
CCTreeWrapper.h
Go to the documentation of this file.
1
2/**********************************************************************************
3 * Project: TMVA - a Root-integrated toolkit for multivariate data analysis *
4 * Package: TMVA *
5 * Class : CCTreeWrapper *
6 * *
7 * *
8 * Description: a light wrapper of a decision tree, used to perform cost *
9 * complexity pruning "in-place" Cost Complexity Pruning *
10 * *
11 * Author: Doug Schouten (dschoute@sfu.ca) *
12 * *
13 * *
14 * Copyright (c) 2007: *
15 * CERN, Switzerland *
16 * MPI-K Heidelberg, Germany *
17 * U. of Texas at Austin, USA *
18 * *
19 * Redistribution and use in source and binary forms, with or without *
20 * modification, are permitted according to the terms listed in LICENSE *
21 * (see tmva/doc/LICENSE) *
22 **********************************************************************************/
23
24#ifndef ROOT_TMVA_CCTreeWrapper
25#define ROOT_TMVA_CCTreeWrapper
26
27#include "TMVA/Event.h"
28#include "TMVA/SeparationBase.h"
29#include "TMVA/DecisionTree.h"
30#include "TMVA/DataSet.h"
31#include "TMVA/Version.h"
32#include <vector>
33#include <string>
34#include <sstream>
35
36namespace TMVA {
37
39
40 public:
41
42 typedef std::vector<Event*> EventList;
43
44 /////////////////////////////////////////////////////////////
45 // CCTreeNode - a light wrapper of a decision tree node //
46 // //
47 /////////////////////////////////////////////////////////////
48
49 class CCTreeNode : virtual public Node {
50
51 public:
52
53 CCTreeNode( DecisionTreeNode* n = nullptr );
54 virtual ~CCTreeNode( );
55
56 virtual Node* CreateNode() const { return new CCTreeNode(); }
57
58 // set |~T_t|, the number of terminal descendants of node t
59 inline void SetNLeafDaughters( Int_t N ) { fNLeafDaughters = (N > 0 ? N : 0); }
60
61 // return |~T_t|
62 inline Int_t GetNLeafDaughters() const { return fNLeafDaughters; }
63
64 // set R(t), the node resubstitution estimate (Gini, misclassification, etc.) for the node t
66
67 // return R(t) for node t
69
70 // set R(T_t) = sum[t' in ~T_t]{ R(t) }, the resubstitution estimate for the branch rooted at
71 // node t (it is an estimate because it is calculated from the training dataset, i.e., the original tree)
72 inline void SetResubstitutionEstimate( Double_t R ) { fResubstitutionEstimate = (R >= 0 ? R : 0.0); }
73
74 // return R(T_t) for node t
76
77 // set the critical point of alpha
78 // R(t) - R(T_t)
79 // alpha_c < ------------- := g(t)
80 // |~T_t| - 1
81 // which is the value of alpha such that the branch rooted at node t is pruned
82 inline void SetAlphaC( Double_t alpha ) { fAlphaC = alpha; }
83
84 // get the critical alpha value for this node
85 inline Double_t GetAlphaC( ) const { return fAlphaC; }
86
87 // set the minimum critical alpha value for descendants of node t ( G(t) = min(alpha_c, g(t_l), g(t_r)) )
88 inline void SetMinAlphaC( Double_t alpha ) { fMinAlphaC = alpha; }
89
90 // get the minimum critical alpha value
91 inline Double_t GetMinAlphaC( ) const { return fMinAlphaC; }
92
93 // get the pointer to the wrapped DT node
94 inline DecisionTreeNode* GetDTNode( ) const { return fDTNode; }
95
96 // get pointers to children, mother in the CC tree
97 inline CCTreeNode* GetLeftDaughter( ) { return dynamic_cast<CCTreeNode*>(GetLeft()); }
98 inline CCTreeNode* GetRightDaughter( ) { return dynamic_cast<CCTreeNode*>(GetRight()); }
99 inline CCTreeNode* GetMother( ) { return dynamic_cast<CCTreeNode*>(GetParent()); }
100
101 // printout of the node (can be read in with ReadDataRecord)
102 virtual void Print( std::ostream& os ) const;
103
104 // recursive printout of the node and its daughters
105 virtual void PrintRec ( std::ostream& os ) const;
106
107 virtual void AddAttributesToNode(void* node) const;
108 virtual void AddContentToNode(std::stringstream& s) const;
109
110
111 // test event if it descends the tree at this node to the right
112 inline virtual Bool_t GoesRight( const Event& e ) const { return GetDTNode() ?
113 GetDTNode()->GoesRight(e) : false; }
114
115 // test event if it descends the tree at this node to the left
116 inline virtual Bool_t GoesLeft ( const Event& e ) const { return GetDTNode() ?
117 GetDTNode()->GoesLeft(e) : false; }
118 // initialize a node from a data record
119 virtual void ReadAttributes(void* node, UInt_t tmva_Version_Code = TMVA_VERSION_CODE);
120 virtual void ReadContent(std::stringstream& s);
121 virtual Bool_t ReadDataRecord( std::istream& in, UInt_t tmva_Version_Code = TMVA_VERSION_CODE );
122
123 private:
124
125 Int_t fNLeafDaughters; //! number of terminal descendants
126 Double_t fNodeResubstitutionEstimate; //! R(t) = misclassification rate for node t
127 Double_t fResubstitutionEstimate; //! R(T_t) = sum[t' in ~T_t]{ R(t) }
128 Double_t fAlphaC; //! critical point, g(t) = alpha_c(t)
129 Double_t fMinAlphaC; //! G(t), minimum critical point of t and its descendants
130 DecisionTreeNode* fDTNode; //! pointer to wrapped node in the decision tree
131 };
132
133 CCTreeWrapper( DecisionTree* T, SeparationBase* qualityIndex );
135
136 // return the decision tree output for an event
137 Double_t CheckEvent( const TMVA::Event & e, Bool_t useYesNoLeaf = false );
138 // return the misclassification rate of a pruned tree for a validation event sample
139 Double_t TestTreeQuality( const EventList* validationSample );
140 Double_t TestTreeQuality( const DataSet* validationSample );
141
142 // remove the branch rooted at node t
143 void PruneNode( CCTreeNode* t );
144 // initialize the node t and all its descendants
145 void InitTree( CCTreeNode* t );
146
147 // return the root node for this tree
148 CCTreeNode* GetRoot() { return fRoot; }
149 private:
150 SeparationBase* fQualityIndex; ///<! pointer to the used quality index calculator
151 DecisionTree* fDTParent; ///<! pointer to underlying DecisionTree
152 CCTreeNode* fRoot; ///<! the root node of the (wrapped) decision Tree
153 };
154
155}
156
157#endif
158
159
160
#define R(a, b, c, d, e, f, g, h, i)
Definition RSha256.hxx:110
#define e(i)
Definition RSha256.hxx:103
#define N
#define TMVA_VERSION_CODE
Definition Version.h:47
Double_t fMinAlphaC
G(t), minimum critical point of t and its descendants.
virtual void ReadAttributes(void *node, UInt_t tmva_Version_Code=262657)
Double_t GetNodeResubstitutionEstimate() const
virtual void ReadContent(std::stringstream &s)
Double_t fNodeResubstitutionEstimate
R(t) = misclassification rate for node t.
virtual void AddAttributesToNode(void *node) const
void SetMinAlphaC(Double_t alpha)
DecisionTreeNode * GetDTNode() const
void SetResubstitutionEstimate(Double_t R)
Double_t fAlphaC
critical point, g(t) = alpha_c(t)
virtual Bool_t GoesRight(const Event &e) const
Double_t GetResubstitutionEstimate() const
virtual Bool_t ReadDataRecord(std::istream &in, UInt_t tmva_Version_Code=262657)
initialize a node from a data record
virtual void PrintRec(std::ostream &os) const
recursive printout of the node and its daughters
Double_t fResubstitutionEstimate
R(T_t) = sum[t' in ~T_t]{ R(t') }.
virtual Bool_t GoesLeft(const Event &e) const
DecisionTreeNode * fDTNode
pointer to wrapped node in the decision tree
virtual void AddContentToNode(std::stringstream &s) const
void SetAlphaC(Double_t alpha)
void SetNodeResubstitutionEstimate(Double_t R)
virtual ~CCTreeNode()
destructor of a CCTreeNode
virtual Node * CreateNode() const
virtual void Print(std::ostream &os) const
printout of the node (can be read in with ReadDataRecord)
SeparationBase * fQualityIndex
! pointer to the used quality index calculator
std::vector< Event * > EventList
DecisionTree * fDTParent
! pointer to underlying DecisionTree
CCTreeNode * fRoot
! the root node of the (wrapped) decision Tree
CCTreeNode * GetRoot()
Double_t TestTreeQuality(const EventList *validationSample)
return the misclassification rate of a pruned tree for a validation event sample using an EventList
void InitTree(CCTreeNode *t)
initialize the node t and all its descendants
void PruneNode(CCTreeNode *t)
remove the branch rooted at node t
Double_t CheckEvent(const TMVA::Event &e, Bool_t useYesNoLeaf=false)
return the decision tree output for an event
Class that contains all the data information.
Definition DataSet.h:58
virtual Bool_t GoesLeft(const Event &) const
test event if it descends the tree at this node to the left
virtual Bool_t GoesRight(const Event &) const
test event if it descends the tree at this node to the right
Implementation of a Decision Tree.
Node for the BinarySearch or Decision Trees.
Definition Node.h:58
virtual Node * GetLeft() const
Definition Node.h:89
virtual Node * GetParent() const
Definition Node.h:91
virtual Node * GetRight() const
Definition Node.h:90
An interface to calculate the "SeparationGain" for different separation criteria used in various trai...
const Int_t n
Definition legend1.C:16
create variable transformations