Logo ROOT   6.10/09
Reference Guide
CCTreeWrapper.cxx
Go to the documentation of this file.
1 /**********************************************************************************
2  * Project: TMVA - a Root-integrated toolkit for multivariate data analysis *
3  * Package: TMVA *
4  * Class : CCTreeWrapper *
5  * Web : http://tmva.sourceforge.net *
6  * *
7  * Description: a light wrapper of a decision tree, used to perform cost *
8  * complexity pruning "in-place". *
9  * *
10  * Author: Doug Schouten (dschoute@sfu.ca) *
11  * *
12  * *
13  * Copyright (c) 2007: *
14  * CERN, Switzerland *
15  * MPI-K Heidelberg, Germany *
16  * U. of Texas at Austin, USA *
17  * *
18  * Redistribution and use in source and binary forms, with or without *
19  * modification, are permitted according to the terms listed in LICENSE *
20  * (http://tmva.sourceforge.net/LICENSE) *
21  **********************************************************************************/
22 
23 /*! \class TMVA::CCTreeWrapper
24 \ingroup TMVA
25 Light wrapper around a DecisionTree, used to perform cost complexity pruning in place.
26 */
27 
28 #include "TMVA/CCTreeWrapper.h"
29 #include "TMVA/DecisionTree.h"
30 
31 #include <iostream>
32 #include <limits>
33 
34 using namespace TMVA;
35 
36 ////////////////////////////////////////////////////////////////////////////////
37 ///constructor of the CCTreeNode
38 
40  Node(),
41  fNLeafDaughters(0),
42  fNodeResubstitutionEstimate(-1.0),
43  fResubstitutionEstimate(-1.0),
44  fAlphaC(-1.0),
45  fMinAlphaC(-1.0),
46  fDTNode(n)
47 {
48  if ( n != NULL && n->GetRight() != NULL && n->GetLeft() != NULL ) {
49  SetRight( new CCTreeNode( ((DecisionTreeNode*) n->GetRight()) ) );
50  GetRight()->SetParent(this);
51  SetLeft( new CCTreeNode( ((DecisionTreeNode*) n->GetLeft()) ) );
52  GetLeft()->SetParent(this);
53  }
54 }
55 
56 ////////////////////////////////////////////////////////////////////////////////
57 /// destructor of a CCTreeNode
58 
60  if(GetLeft() != NULL) delete GetLeftDaughter();
61  if(GetRight() != NULL) delete GetRightDaughter();
62 }
63 
64 ////////////////////////////////////////////////////////////////////////////////
65 /// initialize a node from a data record
66 
67 Bool_t TMVA::CCTreeWrapper::CCTreeNode::ReadDataRecord( std::istream& in, UInt_t /* tmva_Version_Code */ ) {
68  std::string header, title;
69  in >> header;
70  in >> title; in >> fNLeafDaughters;
71  in >> title; in >> fNodeResubstitutionEstimate;
72  in >> title; in >> fResubstitutionEstimate;
73  in >> title; in >> fAlphaC;
74  in >> title; in >> fMinAlphaC;
75  return true;
76 }
77 
78 ////////////////////////////////////////////////////////////////////////////////
79 /// printout of the node (can be read in with ReadDataRecord)
80 
81 void TMVA::CCTreeWrapper::CCTreeNode::Print( std::ostream& os ) const {
82  os << "----------------------" << std::endl
83  << "|~T_t| " << fNLeafDaughters << std::endl
84  << "R(t): " << fNodeResubstitutionEstimate << std::endl
85  << "R(T_t): " << fResubstitutionEstimate << std::endl
86  << "g(t): " << fAlphaC << std::endl
87  << "G(t): " << fMinAlphaC << std::endl;
88 }
89 
90 ////////////////////////////////////////////////////////////////////////////////
91 /// recursive printout of the node and its daughters
92 
93 void TMVA::CCTreeWrapper::CCTreeNode::PrintRec( std::ostream& os ) const {
94  this->Print(os);
95  if(this->GetLeft() != NULL && this->GetRight() != NULL) {
96  this->GetLeft()->PrintRec(os);
97  this->GetRight()->PrintRec(os);
98  }
99 }
100 
101 ////////////////////////////////////////////////////////////////////////////////
102 /// constructor
103 
105  fRoot(NULL)
106 {
107  fDTParent = T;
108  fRoot = new CCTreeNode( dynamic_cast<DecisionTreeNode*>(T->GetRoot()) );
109  fQualityIndex = qualityIndex;
110  InitTree(fRoot);
111 }
112 
113 ////////////////////////////////////////////////////////////////////////////////
114 /// destructor
115 
117  delete fRoot;
118 }
119 
120 ////////////////////////////////////////////////////////////////////////////////
121 /// initialize the node t and all its descendants
122 
124 {
125  Double_t s = t->GetDTNode()->GetNSigEvents();
126  Double_t b = t->GetDTNode()->GetNBkgEvents();
127  // Double_t s = t->GetDTNode()->GetNSigEvents_unweighted();
128  // Double_t b = t->GetDTNode()->GetNBkgEvents_unweighted();
129  // set R(t) = Gini(t) or MisclassificationError(t), etc.
131 
132  if(t->GetLeft() != NULL && t->GetRight() != NULL) { // n is an interior (non-leaf) node
133  // traverse the tree
136  // set |~T_t|
139  // set R(T) = sum[t' in ~T]{ R(t) }
142  // set g(t)
144  (t->GetNLeafDaughters() - 1));
145  // G(t) = min( g(t), G(l(t)), G(r(t)) )
146  t->SetMinAlphaC(std::min(t->GetAlphaC(), std::min(t->GetLeftDaughter()->GetMinAlphaC(),
147  t->GetRightDaughter()->GetMinAlphaC())));
148  }
149  else { // n is a terminal node
150  t->SetNLeafDaughters(1);
152  t->SetAlphaC(std::numeric_limits<double>::infinity( ));
153  t->SetMinAlphaC(std::numeric_limits<double>::infinity( ));
154  }
155 }
156 
157 ////////////////////////////////////////////////////////////////////////////////
158 /// remove the branch rooted at node t
159 
161 {
162  if( t->GetLeft() != NULL &&
163  t->GetRight() != NULL ) {
164  CCTreeNode* l = t->GetLeftDaughter();
165  CCTreeNode* r = t->GetRightDaughter();
166  t->SetNLeafDaughters( 1 );
168  t->SetAlphaC( std::numeric_limits<double>::infinity( ) );
169  t->SetMinAlphaC( std::numeric_limits<double>::infinity( ) );
170  delete l;
171  delete r;
172  t->SetLeft(NULL);
173  t->SetRight(NULL);
174  }else{
175  std::cout << " ERROR in CCTreeWrapper::PruneNode: you try to prune a leaf node.. that does not make sense " << std::endl;
176  }
177 }
178 
179 ////////////////////////////////////////////////////////////////////////////////
180 /// return the weighted fraction of correctly classified events (i.e. 1 - misclassification rate) of a pruned tree for a validation event sample
181 /// using an EventList
182 
184 {
185  Double_t ncorrect=0, nfalse=0;
186  for (UInt_t ievt=0; ievt < validationSample->size(); ievt++) {
187  Bool_t isSignalType = (CheckEvent(*(*validationSample)[ievt]) > fDTParent->GetNodePurityLimit() ) ? 1 : 0;
188 
189  if (isSignalType == ((*validationSample)[ievt]->GetClass() == 0)) {
190  ncorrect += (*validationSample)[ievt]->GetWeight();
191  }
192  else{
193  nfalse += (*validationSample)[ievt]->GetWeight();
194  }
195  }
196  return ncorrect / (ncorrect + nfalse);
197 }
198 
199 ////////////////////////////////////////////////////////////////////////////////
200 /// return the weighted fraction of correctly classified events (i.e. 1 - misclassification rate) of a pruned tree for a validation event sample
201 /// using the DataSet
202 
204 {
205  validationSample->SetCurrentType(Types::kValidation);
206  // test the tree quality.. in terms of Misclassification
207  Double_t ncorrect=0, nfalse=0;
208  for (Long64_t ievt=0; ievt<validationSample->GetNEvents(); ievt++){
209  const Event *ev = validationSample->GetEvent(ievt);
210 
211  Bool_t isSignalType = (CheckEvent(*ev) > fDTParent->GetNodePurityLimit() ) ? 1 : 0;
212 
213  if (isSignalType == (ev->GetClass() == 0)) {
214  ncorrect += ev->GetWeight();
215  }
216  else{
217  nfalse += ev->GetWeight();
218  }
219  }
220  return ncorrect / (ncorrect + nfalse);
221 }
222 
223 ////////////////////////////////////////////////////////////////////////////////
224 /// return the decision tree output for an event
225 
227 {
228  const DecisionTreeNode* current = fRoot->GetDTNode();
229  CCTreeNode* t = fRoot;
230 
231  while(//current->GetNodeType() == 0 &&
232  t->GetLeft() != NULL &&
233  t->GetRight() != NULL){ // at an interior (non-leaf) node
234  if (current->GoesRight(e)) {
235  //current = (DecisionTreeNode*)current->GetRight();
236  t = t->GetRightDaughter();
237  current = t->GetDTNode();
238  }
239  else {
240  //current = (DecisionTreeNode*)current->GetLeft();
241  t = t->GetLeftDaughter();
242  current = t->GetDTNode();
243  }
244  }
245 
246  if (useYesNoLeaf) return (current->GetPurity() > fDTParent->GetNodePurityLimit() ? 1.0 : -1.0);
247  else return current->GetPurity();
248 }
249 
250 ////////////////////////////////////////////////////////////////////////////////
251 
253 {}
254 
255 ////////////////////////////////////////////////////////////////////////////////
256 
257 void TMVA::CCTreeWrapper::CCTreeNode::AddContentToNode( std::stringstream& /*s*/ ) const
258 {}
259 
260 ////////////////////////////////////////////////////////////////////////////////
261 
262 void TMVA::CCTreeWrapper::CCTreeNode::ReadAttributes( void* /*node*/, UInt_t /* tmva_Version_Code */ )
263 {}
264 
265 ////////////////////////////////////////////////////////////////////////////////
266 
267 void TMVA::CCTreeWrapper::CCTreeNode::ReadContent( std::stringstream& /*s*/ )
268 {}
virtual void PrintRec(std::ostream &os) const =0
virtual ~CCTreeNode()
destructor of a CCTreeNode
CCTreeNode * fRoot
the root node of the (wrapped) decision tree
long long Long64_t
Definition: RtypesCore.h:69
virtual void PrintRec(std::ostream &os) const
recursive printout of the node and its daughters
double T(double x)
Definition: ChebyshevPol.h:34
Double_t GetNodePurityLimit() const
Definition: DecisionTree.h:156
bool Bool_t
Definition: RtypesCore.h:59
virtual void SetRight(Node *r)
Definition: Node.h:93
#define NULL
Definition: RtypesCore.h:88
Float_t GetNSigEvents(void) const
Double_t fAlphaC
critical point, g(t) = alpha_c(t)
virtual DecisionTreeNode * GetRoot() const
Definition: DecisionTree.h:88
virtual void AddContentToNode(std::stringstream &s) const
void SetResubstitutionEstimate(Double_t R)
Definition: CCTreeWrapper.h:70
void InitTree(CCTreeNode *t)
initialize the node t and all its descendants
Double_t CheckEvent(const TMVA::Event &e, Bool_t useYesNoLeaf=false)
return the decision tree output for an event
TClass * GetClass(T *)
Definition: TClass.h:545
virtual void SetLeft(Node *l)
Definition: Node.h:92
Float_t GetNBkgEvents(void) const
CCTreeWrapper(DecisionTree *T, SeparationBase *qualityIndex)
constructor
virtual Node * GetRight() const
Definition: Node.h:88
virtual Node * GetLeft() const
Definition: Node.h:87
UInt_t GetClass() const
Definition: Event.h:81
Double_t GetResubstitutionEstimate() const
Definition: CCTreeWrapper.h:73
Double_t fNodeResubstitutionEstimate
R(t) = resubstitution estimate of node t (e.g. Gini index or misclassification error)
Double_t GetWeight() const
return the event weight - depending on whether the flag IgnoreNegWeightsInTraining is or not...
Definition: Event.cxx:382
virtual void ReadContent(std::stringstream &s)
Class that contains all the data information.
Definition: DataSet.h:69
DecisionTree * fDTParent
pointer to underlying DecisionTree
std::vector< Event * > EventList
Definition: CCTreeWrapper.h:40
Double_t TestTreeQuality(const EventList *validationSample)
return the weighted fraction of correctly classified events (1 - misclassification rate) of a pruned tree for a validation event sample using an EventList ...
TRandom2 r(17)
SeparationBase * fQualityIndex
pointer to the used quality index calculator
void SetNodeResubstitutionEstimate(Double_t R)
Definition: CCTreeWrapper.h:63
Implementation of a Decision Tree.
Definition: DecisionTree.h:59
unsigned int UInt_t
Definition: RtypesCore.h:42
Double_t fResubstitutionEstimate
R(T_t) = sum[t' in ~T_t]{ R(t) }.
TLine * l
Definition: textangle.C:4
An interface to calculate the "SeparationGain" for different separation criteria used in various trai...
void SetMinAlphaC(Double_t alpha)
Definition: CCTreeWrapper.h:86
virtual void Print(std::ostream &os) const
printout of the node (can be read in with ReadDataRecord)
virtual Bool_t ReadDataRecord(std::istream &in, UInt_t tmva_Version_Code=TMVA_VERSION_CODE)
initialize a node from a data record
virtual void SetParent(Node *p)
Definition: Node.h:94
virtual Double_t GetSeparationIndex(const Double_t s, const Double_t b)=0
Float_t GetPurity(void) const
virtual Bool_t GoesRight(const Event &) const
test event if it descends the tree at this node to the right
double Double_t
Definition: RtypesCore.h:55
void SetCurrentType(Types::ETreeType type) const
Definition: DataSet.h:100
you should not use this method at all Int_t Int_t Double_t Double_t Double_t e
Definition: TRolke.cxx:630
DecisionTreeNode * GetDTNode() const
Definition: CCTreeWrapper.h:92
void SetAlphaC(Double_t alpha)
Definition: CCTreeWrapper.h:80
Abstract ClassifierFactory template that handles arbitrary types.
Node for the BinarySearch or Decision Trees.
Definition: Node.h:56
virtual void ReadAttributes(void *node, UInt_t tmva_Version_Code=TMVA_VERSION_CODE)
you should not use this method at all Int_t Int_t Double_t Double_t Double_t Int_t Double_t Double_t Double_t Double_t b
Definition: TRolke.cxx:630
virtual DecisionTreeNode * GetLeft() const
Long64_t GetNEvents(Types::ETreeType type=Types::kMaxTreeType) const
Definition: DataSet.h:215
virtual DecisionTreeNode * GetRight() const
~CCTreeWrapper()
destructor
CCTreeNode(DecisionTreeNode *n=NULL)
constructor of the CCTreeNode
void PruneNode(CCTreeNode *t)
remove the branch rooted at node t
const Int_t n
Definition: legend1.C:16
const Event * GetEvent() const
Definition: DataSet.cxx:202
virtual void AddAttributesToNode(void *node) const
Double_t fMinAlphaC
G(t) = min( g(t), G(l(t)), G(r(t)) )
Double_t GetNodeResubstitutionEstimate() const
Definition: CCTreeWrapper.h:66