Logo ROOT  
Reference Guide
CCTreeWrapper.h
Go to the documentation of this file.
1 
2 /**********************************************************************************
3  * Project: TMVA - a Root-integrated toolkit for multivariate data analysis *
4  * Package: TMVA *
5  * Class : CCTreeWrapper *
6  * Web : http://tmva.sourceforge.net *
7  * *
8  * Description: a light wrapper of a decision tree, used to perform cost *
9  * complexity pruning "in-place" Cost Complexity Pruning *
10  * *
11  * Author: Doug Schouten (dschoute@sfu.ca) *
12  * *
13  * *
14  * Copyright (c) 2007: *
15  * CERN, Switzerland *
16  * MPI-K Heidelberg, Germany *
17  * U. of Texas at Austin, USA *
18  * *
19  * Redistribution and use in source and binary forms, with or without *
20  * modification, are permitted according to the terms listed in LICENSE *
21  * (http://tmva.sourceforge.net/LICENSE) *
22  **********************************************************************************/
23 
24 #ifndef ROOT_TMVA_CCTreeWrapper
25 #define ROOT_TMVA_CCTreeWrapper
26 
27 #include "TMVA/Event.h"
28 #include "TMVA/SeparationBase.h"
29 #include "TMVA/DecisionTree.h"
30 #include "TMVA/DataSet.h"
31 #include "TMVA/Version.h"
32 #include <vector>
33 #include <string>
34 #include <sstream>
35 
36 namespace TMVA {
37 
38  class CCTreeWrapper {
39 
40  public:
41 
42  typedef std::vector<Event*> EventList;
43 
44  /////////////////////////////////////////////////////////////
45  // CCTreeNode - a light wrapper of a decision tree node //
46  // //
47  /////////////////////////////////////////////////////////////
48 
49  class CCTreeNode : virtual public Node {
50 
51  public:
52 
53  CCTreeNode( DecisionTreeNode* n = NULL );
54  virtual ~CCTreeNode( );
55 
56  virtual Node* CreateNode() const { return new CCTreeNode(); }
57 
58  // set |~T_t|, the number of terminal descendants of node t
59  inline void SetNLeafDaughters( Int_t N ) { fNLeafDaughters = (N > 0 ? N : 0); }
60 
61  // return |~T_t|
62  inline Int_t GetNLeafDaughters() const { return fNLeafDaughters; }
63 
64  // set R(t), the node resubstitution estimate (Gini, misclassification, etc.) for the node t
66 
67  // return R(t) for node t
69 
70  // set R(T_t) = sum[t' in ~T_t]{ R(t) }, the resubstitution estimate for the branch rooted at
71  // node t (it is an estimate because it is calculated from the training dataset, i.e., the original tree)
72  inline void SetResubstitutionEstimate( Double_t R ) { fResubstitutionEstimate = (R >= 0 ? R : 0.0); }
73 
74  // return R(T_t) for node t
76 
77  // set the critical point of alpha
78  // R(t) - R(T_t)
79  // alpha_c < ------------- := g(t)
80  // |~T_t| - 1
81  // which is the value of alpha such that the branch rooted at node t is pruned
82  inline void SetAlphaC( Double_t alpha ) { fAlphaC = alpha; }
83 
84  // get the critical alpha value for this node
85  inline Double_t GetAlphaC( ) const { return fAlphaC; }
86 
87  // set the minimum critical alpha value for descendants of node t ( G(t) = min(alpha_c, g(t_l), g(t_r)) )
88  inline void SetMinAlphaC( Double_t alpha ) { fMinAlphaC = alpha; }
89 
90  // get the minimum critical alpha value
91  inline Double_t GetMinAlphaC( ) const { return fMinAlphaC; }
92 
93  // get the pointer to the wrapped DT node
94  inline DecisionTreeNode* GetDTNode( ) const { return fDTNode; }
95 
96  // get pointers to children, mother in the CC tree
97  inline CCTreeNode* GetLeftDaughter( ) { return dynamic_cast<CCTreeNode*>(GetLeft()); }
98  inline CCTreeNode* GetRightDaughter( ) { return dynamic_cast<CCTreeNode*>(GetRight()); }
99  inline CCTreeNode* GetMother( ) { return dynamic_cast<CCTreeNode*>(GetParent()); }
100 
101  // printout of the node (can be read in with ReadDataRecord)
102  virtual void Print( std::ostream& os ) const;
103 
104  // recursive printout of the node and its daughters
105  virtual void PrintRec ( std::ostream& os ) const;
106 
107  virtual void AddAttributesToNode(void* node) const;
108  virtual void AddContentToNode(std::stringstream& s) const;
109 
110 
111  // test event if it decends the tree at this node to the right
112  inline virtual Bool_t GoesRight( const Event& e ) const { return (GetDTNode() != NULL ?
113  GetDTNode()->GoesRight(e) : false); }
114 
115  // test event if it decends the tree at this node to the left
116  inline virtual Bool_t GoesLeft ( const Event& e ) const { return (GetDTNode() != NULL ?
117  GetDTNode()->GoesLeft(e) : false); }
118  // initialize a node from a data record
119  virtual void ReadAttributes(void* node, UInt_t tmva_Version_Code = TMVA_VERSION_CODE);
120  virtual void ReadContent(std::stringstream& s);
121  virtual Bool_t ReadDataRecord( std::istream& in, UInt_t tmva_Version_Code = TMVA_VERSION_CODE );
122 
123  private:
124 
125  Int_t fNLeafDaughters; //! number of terminal descendants
126  Double_t fNodeResubstitutionEstimate; //! R(t) = misclassification rate for node t
127  Double_t fResubstitutionEstimate; //! R(T_t) = sum[t' in ~T_t]{ R(t) }
128  Double_t fAlphaC; //! critical point, g(t) = alpha_c(t)
129  Double_t fMinAlphaC; //! G(t), minimum critical point of t and its descendants
130  DecisionTreeNode* fDTNode; //! pointer to wrapped node in the decision tree
131  };
132 
133  CCTreeWrapper( DecisionTree* T, SeparationBase* qualityIndex );
134  ~CCTreeWrapper( );
135 
136  // return the decision tree output for an event
137  Double_t CheckEvent( const TMVA::Event & e, Bool_t useYesNoLeaf = false );
138  // return the misclassification rate of a pruned tree for a validation event sample
139  Double_t TestTreeQuality( const EventList* validationSample );
140  Double_t TestTreeQuality( const DataSet* validationSample );
141 
142  // remove the branch rooted at node t
143  void PruneNode( CCTreeNode* t );
144  // initialize the node t and all its descendants
145  void InitTree( CCTreeNode* t );
146 
147  // return the root node for this tree
148  CCTreeNode* GetRoot() { return fRoot; }
149  private:
150  SeparationBase* fQualityIndex; //! pointer to the used quality index calculator
151  DecisionTree* fDTParent; //! pointer to underlying DecisionTree
152  CCTreeNode* fRoot; //! the root node of the (wrapped) decision Tree
153  };
154 
155 }
156 
157 #endif
158 
159 
160 
TMVA::CCTreeWrapper::CCTreeNode::GetMinAlphaC
Double_t GetMinAlphaC() const
Definition: CCTreeWrapper.h:91
n
const Int_t n
Definition: legend1.C:16
TMVA::CCTreeWrapper::CCTreeNode::fMinAlphaC
Double_t fMinAlphaC
G(t), minimum critical point of t and its descendants
Definition: CCTreeWrapper.h:129
e
#define e(i)
Definition: RSha256.hxx:103
TMVA::CCTreeWrapper::CCTreeNode::fNLeafDaughters
Int_t fNLeafDaughters
number of terminal descendants
Definition: CCTreeWrapper.h:125
TMVA::CCTreeWrapper
Definition: CCTreeWrapper.h:38
TMVA::CCTreeWrapper::CCTreeNode::GetNodeResubstitutionEstimate
Double_t GetNodeResubstitutionEstimate() const
Definition: CCTreeWrapper.h:68
TMVA::CCTreeWrapper::CCTreeNode::AddAttributesToNode
virtual void AddAttributesToNode(void *node) const
Definition: CCTreeWrapper.cxx:252
TMVA::CCTreeWrapper::CCTreeNode::fAlphaC
Double_t fAlphaC
critical point, g(t) = alpha_c(t)
Definition: CCTreeWrapper.h:128
TMVA::Node::GetRight
virtual Node * GetRight() const
Definition: Node.h:90
TMVA::CCTreeWrapper::CCTreeNode::CreateNode
virtual Node * CreateNode() const
Definition: CCTreeWrapper.h:56
TMVA::CCTreeWrapper::~CCTreeWrapper
~CCTreeWrapper()
destructor
Definition: CCTreeWrapper.cxx:116
TGeant4Unit::s
static constexpr double s
Definition: TGeant4SystemOfUnits.h:162
TMVA::Node::GetParent
virtual Node * GetParent() const
Definition: Node.h:91
TMVA::CCTreeWrapper::CCTreeNode::GetLeftDaughter
CCTreeNode * GetLeftDaughter()
Definition: CCTreeWrapper.h:97
N
#define N
TMVA::CCTreeWrapper::InitTree
void InitTree(CCTreeNode *t)
initialize the node t and all its descendants
Definition: CCTreeWrapper.cxx:123
TMVA::DecisionTreeNode
Definition: DecisionTreeNode.h:117
TMVA::CCTreeWrapper::EventList
std::vector< Event * > EventList
Definition: CCTreeWrapper.h:42
TMVA::CCTreeWrapper::CCTreeNode::GetResubstitutionEstimate
Double_t GetResubstitutionEstimate() const
Definition: CCTreeWrapper.h:75
TMVA::CCTreeWrapper::CCTreeNode::fDTNode
DecisionTreeNode * fDTNode
pointer to wrapped node in the decision tree
Definition: CCTreeWrapper.h:130
TMVA::CCTreeWrapper::CCTreeNode
Definition: CCTreeWrapper.h:49
TMVA::Node
Node for the BinarySearch or Decision Trees.
Definition: Node.h:58
TMVA::DecisionTree
Implementation of a Decision Tree.
Definition: DecisionTree.h:65
bool
Version.h
TMVA::CCTreeWrapper::CCTreeNode::GoesLeft
virtual Bool_t GoesLeft(const Event &e) const
Definition: CCTreeWrapper.h:116
TMVA::CCTreeWrapper::CCTreeNode::AddContentToNode
virtual void AddContentToNode(std::stringstream &s) const
Definition: CCTreeWrapper.cxx:257
R
#define R(a, b, c, d, e, f, g, h, i)
Definition: RSha256.hxx:110
TMVA::CCTreeWrapper::CCTreeNode::SetResubstitutionEstimate
void SetResubstitutionEstimate(Double_t R)
Definition: CCTreeWrapper.h:72
SeparationBase.h
TMVA::CCTreeWrapper::CCTreeNode::SetNodeResubstitutionEstimate
void SetNodeResubstitutionEstimate(Double_t R)
Definition: CCTreeWrapper.h:65
DecisionTree.h
TMVA::CCTreeWrapper::fRoot
CCTreeNode * fRoot
the root node of the (wrapped) decision tree
Definition: CCTreeWrapper.h:152
TMVA::CCTreeWrapper::CCTreeNode::SetMinAlphaC
void SetMinAlphaC(Double_t alpha)
Definition: CCTreeWrapper.h:88
TMVA::CCTreeWrapper::CCTreeNode::fResubstitutionEstimate
Double_t fResubstitutionEstimate
R(T_t) = sum[t' in ~T_t]{ R(t) }
Definition: CCTreeWrapper.h:127
TMVA::CCTreeWrapper::CCTreeNode::SetNLeafDaughters
void SetNLeafDaughters(Int_t N)
Definition: CCTreeWrapper.h:59
TMVA::CCTreeWrapper::CCTreeNode::fNodeResubstitutionEstimate
Double_t fNodeResubstitutionEstimate
R(t) = misclassification rate for node t
Definition: CCTreeWrapper.h:126
TMVA::CCTreeWrapper::CCTreeNode::ReadContent
virtual void ReadContent(std::stringstream &s)
Definition: CCTreeWrapper.cxx:267
TMVA::DataSet
Class that contains all the data information.
Definition: DataSet.h:58
Event.h
TMVA::CCTreeWrapper::CCTreeWrapper
CCTreeWrapper(DecisionTree *T, SeparationBase *qualityIndex)
constructor
Definition: CCTreeWrapper.cxx:104
TMVA::CCTreeWrapper::CCTreeNode::ReadDataRecord
virtual Bool_t ReadDataRecord(std::istream &in, UInt_t tmva_Version_Code=TMVA_VERSION_CODE)
initialize a node from a data record
Definition: CCTreeWrapper.cxx:67
TMVA::CCTreeWrapper::CCTreeNode::GoesRight
virtual Bool_t GoesRight(const Event &e) const
Definition: CCTreeWrapper.h:112
TMVA::CCTreeWrapper::CCTreeNode::GetNLeafDaughters
Int_t GetNLeafDaughters() const
Definition: CCTreeWrapper.h:62
TMVA::Node::GetLeft
virtual Node * GetLeft() const
Definition: Node.h:89
TMVA::CCTreeWrapper::CCTreeNode::SetAlphaC
void SetAlphaC(Double_t alpha)
Definition: CCTreeWrapper.h:82
TMVA::CCTreeWrapper::CCTreeNode::GetRightDaughter
CCTreeNode * GetRightDaughter()
Definition: CCTreeWrapper.h:98
TMVA::CCTreeWrapper::CCTreeNode::GetDTNode
DecisionTreeNode * GetDTNode() const
Definition: CCTreeWrapper.h:94
unsigned int
TMVA::SeparationBase
An interface to calculate the "SeparationGain" for different separation criteria used in various trai...
Definition: SeparationBase.h:82
TMVA::CCTreeWrapper::fDTParent
DecisionTree * fDTParent
pointer to underlying DecisionTree
Definition: CCTreeWrapper.h:151
TMVA::CCTreeWrapper::CCTreeNode::ReadAttributes
virtual void ReadAttributes(void *node, UInt_t tmva_Version_Code=TMVA_VERSION_CODE)
Definition: CCTreeWrapper.cxx:262
Double_t
double Double_t
Definition: RtypesCore.h:59
TMVA::CCTreeWrapper::CCTreeNode::Print
virtual void Print(std::ostream &os) const
printout of the node (can be read in with ReadDataRecord)
Definition: CCTreeWrapper.cxx:81
TMVA_VERSION_CODE
#define TMVA_VERSION_CODE
Definition: Version.h:47
TMVA::CCTreeWrapper::GetRoot
CCTreeNode * GetRoot()
Definition: CCTreeWrapper.h:148
TMVA::CCTreeWrapper::CCTreeNode::GetAlphaC
Double_t GetAlphaC() const
Definition: CCTreeWrapper.h:85
TMVA::CCTreeWrapper::CCTreeNode::CCTreeNode
CCTreeNode(DecisionTreeNode *n=NULL)
constructor of the CCTreeNode
Definition: CCTreeWrapper.cxx:39
TMVA::Event
Definition: Event.h:51
ROOT::Math::Chebyshev::T
double T(double x)
Definition: ChebyshevPol.h:34
TMVA::CCTreeWrapper::fQualityIndex
SeparationBase * fQualityIndex
pointer to the used quality index calculator
Definition: CCTreeWrapper.h:150
TMVA::CCTreeWrapper::TestTreeQuality
Double_t TestTreeQuality(const EventList *validationSample)
return the misclassification rate of a pruned tree for a validation event sample using an EventList
Definition: CCTreeWrapper.cxx:183
TMVA::CCTreeWrapper::CCTreeNode::~CCTreeNode
virtual ~CCTreeNode()
destructor of a CCTreeNode
Definition: CCTreeWrapper.cxx:59
TMVA::CCTreeWrapper::CheckEvent
Double_t CheckEvent(const TMVA::Event &e, Bool_t useYesNoLeaf=false)
return the decision tree output for an event
Definition: CCTreeWrapper.cxx:226
TMVA::CCTreeWrapper::PruneNode
void PruneNode(CCTreeNode *t)
remove the branch rooted at node t
Definition: CCTreeWrapper.cxx:160
DataSet.h
TMVA::CCTreeWrapper::CCTreeNode::PrintRec
virtual void PrintRec(std::ostream &os) const
recursive printout of the node and its daughters
Definition: CCTreeWrapper.cxx:93
TMVA
create variable transformations
Definition: GeneticMinimizer.h:22
int
TMVA::CCTreeWrapper::CCTreeNode::GetMother
CCTreeNode * GetMother()
Definition: CCTreeWrapper.h:99