ROOT  6.07/01
Reference Guide
 All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Properties Friends Macros Groups Pages
CCTreeWrapper.cxx
Go to the documentation of this file.
1 /**********************************************************************************
2  * Project: TMVA - a Root-integrated toolkit for multivariate data analysis *
3  * Package: TMVA *
4  * Class : CCTreeWrapper *
5  * Web : http://tmva.sourceforge.net *
6  * *
7  * Description: a light wrapper of a decision tree, used to perform cost *
8  * complexity pruning "in-place" Cost Complexity Pruning *
9  * *
10  * Author: Doug Schouten (dschoute@sfu.ca) *
11  * *
12  * *
13  * Copyright (c) 2007: *
14  * CERN, Switzerland *
15  * MPI-K Heidelberg, Germany *
16  * U. of Texas at Austin, USA *
17  * *
18  * Redistribution and use in source and binary forms, with or without *
19  * modification, are permitted according to the terms listed in LICENSE *
20  * (http://tmva.sourceforge.net/LICENSE) *
21  **********************************************************************************/
22 
23 #include "TMVA/CCTreeWrapper.h"
24 #include "TMVA/DecisionTree.h"
25 
26 #include <iostream>
27 #include <limits>
28 
29 using namespace TMVA;
30 
31 ////////////////////////////////////////////////////////////////////////////////
32 ///constructor of the CCTreeNode
33 
35  Node(),
36  fNLeafDaughters(0),
37  fNodeResubstitutionEstimate(-1.0),
38  fResubstitutionEstimate(-1.0),
39  fAlphaC(-1.0),
40  fMinAlphaC(-1.0),
41  fDTNode(n)
42 {
43  if ( n != NULL && n->GetRight() != NULL && n->GetLeft() != NULL ) {
44  SetRight( new CCTreeNode( ((DecisionTreeNode*) n->GetRight()) ) );
45  GetRight()->SetParent(this);
46  SetLeft( new CCTreeNode( ((DecisionTreeNode*) n->GetLeft()) ) );
47  GetLeft()->SetParent(this);
48  }
49 }
50 
51 ////////////////////////////////////////////////////////////////////////////////
52 /// destructor of a CCTreeNode
53 
55  if(GetLeft() != NULL) delete GetLeftDaughter();
56  if(GetRight() != NULL) delete GetRightDaughter();
57 }
58 
59 ////////////////////////////////////////////////////////////////////////////////
60 /// initialize a node from a data record
61 
62 Bool_t TMVA::CCTreeWrapper::CCTreeNode::ReadDataRecord( std::istream& in, UInt_t /* tmva_Version_Code */ ) {
63  std::string header, title;
64  in >> header;
65  in >> title; in >> fNLeafDaughters;
66  in >> title; in >> fNodeResubstitutionEstimate;
67  in >> title; in >> fResubstitutionEstimate;
68  in >> title; in >> fAlphaC;
69  in >> title; in >> fMinAlphaC;
70  return true;
71 }
72 
73 ////////////////////////////////////////////////////////////////////////////////
74 /// printout of the node (can be read in with ReadDataRecord)
75 
76 void TMVA::CCTreeWrapper::CCTreeNode::Print( std::ostream& os ) const {
77  os << "----------------------" << std::endl
78  << "|~T_t| " << fNLeafDaughters << std::endl
79  << "R(t): " << fNodeResubstitutionEstimate << std::endl
80  << "R(T_t): " << fResubstitutionEstimate << std::endl
81  << "g(t): " << fAlphaC << std::endl
82  << "G(t): " << fMinAlphaC << std::endl;
83 }
84 
85 ////////////////////////////////////////////////////////////////////////////////
86 /// recursive printout of the node and its daughters
87 
88 void TMVA::CCTreeWrapper::CCTreeNode::PrintRec( std::ostream& os ) const {
89  this->Print(os);
90  if(this->GetLeft() != NULL && this->GetRight() != NULL) {
91  this->GetLeft()->PrintRec(os);
92  this->GetRight()->PrintRec(os);
93  }
94 }
95 
96 ////////////////////////////////////////////////////////////////////////////////
97 /// constructor
98 
100  fRoot(NULL)
101 {
102  fDTParent = T;
103  fRoot = new CCTreeNode( dynamic_cast<DecisionTreeNode*>(T->GetRoot()) );
104  fQualityIndex = qualityIndex;
105  InitTree(fRoot);
106 }
107 
108 ////////////////////////////////////////////////////////////////////////////////
109 /// destructor
110 
112  delete fRoot;
113 }
114 
115 ////////////////////////////////////////////////////////////////////////////////
116 /// initialize the node t and all its descendants
117 
119 {
120  Double_t s = t->GetDTNode()->GetNSigEvents();
121  Double_t b = t->GetDTNode()->GetNBkgEvents();
122  // Double_t s = t->GetDTNode()->GetNSigEvents_unweighted();
123  // Double_t b = t->GetDTNode()->GetNBkgEvents_unweighted();
124  // set R(t) = Gini(t) or MisclassificationError(t), etc.
125  t->SetNodeResubstitutionEstimate((s+b)*fQualityIndex->GetSeparationIndex(s,b));
126 
127  if(t->GetLeft() != NULL && t->GetRight() != NULL) { // n is an interior (non-leaf) node
128  // traverse the tree
129  InitTree(t->GetLeftDaughter());
130  InitTree(t->GetRightDaughter());
131  // set |~T_t|
134  // set R(T) = sum[t' in ~T]{ R(t) }
137  // set g(t)
139  (t->GetNLeafDaughters() - 1));
140  // G(t) = min( g(t), G(l(t)), G(r(t)) )
142  t->GetRightDaughter()->GetMinAlphaC())));
143  }
144  else { // n is a terminal node
145  t->SetNLeafDaughters(1);
146  t->SetResubstitutionEstimate((s+b)*fQualityIndex->GetSeparationIndex(s,b));
149  }
150 }
151 
152 ////////////////////////////////////////////////////////////////////////////////
153 /// remove the branch rooted at node t
154 
156 {
157  if( t->GetLeft() != NULL &&
158  t->GetRight() != NULL ) {
159  CCTreeNode* l = t->GetLeftDaughter();
160  CCTreeNode* r = t->GetRightDaughter();
161  t->SetNLeafDaughters( 1 );
165  delete l;
166  delete r;
167  t->SetLeft(NULL);
168  t->SetRight(NULL);
169  }else{
170  std::cout << " ERROR in CCTreeWrapper::PruneNode: you try to prune a leaf node.. that does not make sense " << std::endl;
171  }
172 }
173 
174 ////////////////////////////////////////////////////////////////////////////////
175 /// return the weighted fraction of correctly classified events of a pruned tree
176 /// for a validation event sample, using an EventList
177 
179 {
180  Double_t ncorrect=0, nfalse=0;
181  for (UInt_t ievt=0; ievt < validationSample->size(); ievt++) {
182  Bool_t isSignalType = (CheckEvent(*(*validationSample)[ievt]) > fDTParent->GetNodePurityLimit() ) ? 1 : 0;
183 
184  if (isSignalType == ((*validationSample)[ievt]->GetClass() == 0)) {
185  ncorrect += (*validationSample)[ievt]->GetWeight();
186  }
187  else{
188  nfalse += (*validationSample)[ievt]->GetWeight();
189  }
190  }
191  return ncorrect / (ncorrect + nfalse);
192 }
193 
194 ////////////////////////////////////////////////////////////////////////////////
195 /// return the weighted fraction of correctly classified events of a pruned tree
196 /// for a validation event sample, using the DataSet
197 
199 {
200  validationSample->SetCurrentType(Types::kValidation);
201 // test the tree quality.. in terms of Misclassification
202  Double_t ncorrect=0, nfalse=0;
203  for (Long64_t ievt=0; ievt<validationSample->GetNEvents(); ievt++){
204  const Event *ev = validationSample->GetEvent(ievt);
205 
206  Bool_t isSignalType = (CheckEvent(*ev) > fDTParent->GetNodePurityLimit() ) ? 1 : 0;
207 
208  if (isSignalType == (ev->GetClass() == 0)) {
209  ncorrect += ev->GetWeight();
210  }
211  else{
212  nfalse += ev->GetWeight();
213  }
214  }
215  return ncorrect / (ncorrect + nfalse);
216 }
217 
218 ////////////////////////////////////////////////////////////////////////////////
219 /// return the decision tree output for an event
220 
222 {
223  const DecisionTreeNode* current = fRoot->GetDTNode();
224  CCTreeNode* t = fRoot;
225 
226  while(//current->GetNodeType() == 0 &&
227  t->GetLeft() != NULL &&
228  t->GetRight() != NULL){ // at an interior (non-leaf) node
229  if (current->GoesRight(e)) {
230  //current = (DecisionTreeNode*)current->GetRight();
231  t = t->GetRightDaughter();
232  current = t->GetDTNode();
233  }
234  else {
235  //current = (DecisionTreeNode*)current->GetLeft();
236  t = t->GetLeftDaughter();
237  current = t->GetDTNode();
238  }
239  }
240 
241  if (useYesNoLeaf) return (current->GetPurity() > fDTParent->GetNodePurityLimit() ? 1.0 : -1.0);
242  else return current->GetPurity();
243 }
244 
245 ////////////////////////////////////////////////////////////////////////////////
246 
248 {}
249 
250 ////////////////////////////////////////////////////////////////////////////////
251 
252 void TMVA::CCTreeWrapper::CCTreeNode::AddContentToNode( std::stringstream& /*s*/ ) const
253 {}
254 
255 ////////////////////////////////////////////////////////////////////////////////
256 
257 void TMVA::CCTreeWrapper::CCTreeNode::ReadAttributes( void* /*node*/, UInt_t /* tmva_Version_Code */ )
258 {}
259 
260 ////////////////////////////////////////////////////////////////////////////////
261 
262 void TMVA::CCTreeWrapper::CCTreeNode::ReadContent( std::stringstream& /*s*/ )
263 {}
virtual ~CCTreeNode()
destructor of a CCTreeNode
CCTreeNode * fRoot
pointer to underlying DecisionTree
static Vc_ALWAYS_INLINE int_v min(const int_v &x, const int_v &y)
Definition: vector.h:433
long long Long64_t
Definition: RtypesCore.h:69
const char * current
Definition: demos.C:12
virtual DecisionTreeNode * GetRight() const
bool Bool_t
Definition: RtypesCore.h:59
virtual void SetRight(Node *r)
Definition: Node.h:97
Double_t GetWeight() const
return the event weight - depending on whether the flag IgnoreNegWeightsInTraining is or not...
Definition: Event.cxx:376
Double_t GetResubstitutionEstimate() const
Definition: CCTreeWrapper.h:83
virtual DecisionTreeNode * GetLeft() const
virtual DecisionTreeNode * GetRoot() const
Definition: DecisionTree.h:102
void SetResubstitutionEstimate(Double_t R)
Definition: CCTreeWrapper.h:80
void InitTree(CCTreeNode *t)
initialize the node t and all its descendants
TTree * T
Double_t CheckEvent(const TMVA::Event &e, Bool_t useYesNoLeaf=false)
return the decision tree output for an event
TClass * GetClass(T *)
Definition: TClass.h:554
virtual void SetLeft(Node *l)
Definition: Node.h:96
CCTreeWrapper(DecisionTree *T, SeparationBase *qualityIndex)
constructor
Double_t GetNodeResubstitutionEstimate() const
Definition: CCTreeWrapper.h:76
Float_t GetPurity(void) const
DecisionTreeNode * GetDTNode() const
virtual void ReadContent(std::stringstream &s)
Float_t GetNBkgEvents(void) const
DecisionTree * fDTParent
pointer to the used quality index calculator
TThread * t[5]
Definition: threadsh1.C:13
std::vector< Event * > EventList
Definition: CCTreeWrapper.h:50
Double_t TestTreeQuality(const EventList *validationSample)
return the misclassification rate of a pruned tree for a validation event sample using an EventList ...
ROOT::R::TRInterface & r
Definition: Object.C:4
SeparationBase * fQualityIndex
TPaveLabel title(3, 27.1, 15, 28.7,"ROOT Environment and Tools")
virtual void PrintRec(std::ostream &os) const
recursive printout of the node and its daughters
virtual void AddAttributesToNode(void *node) const
void SetNodeResubstitutionEstimate(Double_t R)
Definition: CCTreeWrapper.h:73
unsigned int UInt_t
Definition: RtypesCore.h:42
TLine * l
Definition: textangle.C:4
virtual void AddContentToNode(std::stringstream &s) const
const Double_t infinity
Definition: CsgOps.cxx:85
void SetMinAlphaC(Double_t alpha)
Definition: CCTreeWrapper.h:96
virtual Bool_t ReadDataRecord(std::istream &in, UInt_t tmva_Version_Code=TMVA_VERSION_CODE)
initialize a node from a data record
virtual void SetParent(Node *p)
Definition: Node.h:98
virtual void Print(std::ostream &os) const
printout of the node (can be read in with ReadDataRecord)
const Event * GetEvent() const
Definition: DataSet.cxx:186
void SetCurrentType(Types::ETreeType type) const
Definition: DataSet.h:111
void Print(std::ostream &os, const OptionType &opt)
Double_t GetMinAlphaC() const
Definition: CCTreeWrapper.h:99
double Double_t
Definition: RtypesCore.h:55
Long64_t GetNEvents(Types::ETreeType type=Types::kMaxTreeType) const
Definition: DataSet.h:225
virtual Node * GetRight() const
Definition: Node.h:92
UInt_t GetClass() const
Definition: Event.h:86
void SetAlphaC(Double_t alpha)
Definition: CCTreeWrapper.h:90
virtual void ReadAttributes(void *node, UInt_t tmva_Version_Code=TMVA_VERSION_CODE)
#define NULL
Definition: Rtypes.h:82
Float_t GetNSigEvents(void) const
~CCTreeWrapper()
destructor
CCTreeNode(DecisionTreeNode *n=NULL)
constructor of the CCTreeNode
void PruneNode(CCTreeNode *t)
remove the branch rooted at node t
virtual Bool_t GoesRight(const Event &) const
test event if it descends the tree at this node to the right
const Int_t n
Definition: legend1.C:16
virtual Node * GetLeft() const
Definition: Node.h:91