Logo ROOT  
Reference Guide
CCTreeWrapper.cxx
Go to the documentation of this file.
1 /**********************************************************************************
2  * Project: TMVA - a Root-integrated toolkit for multivariate data analysis *
3  * Package: TMVA *
4  * Class : CCTreeWrapper *
5  * Web : http://tmva.sourceforge.net *
6  * *
 * Description: a light wrapper of a decision tree, used to perform cost-        *
 *              complexity pruning "in place"                                     *
9  * *
10  * Author: Doug Schouten (dschoute@sfu.ca) *
11  * *
12  * *
13  * Copyright (c) 2007: *
14  * CERN, Switzerland *
15  * MPI-K Heidelberg, Germany *
16  * U. of Texas at Austin, USA *
17  * *
18  * Redistribution and use in source and binary forms, with or without *
19  * modification, are permitted according to the terms listed in LICENSE *
20  * (http://tmva.sourceforge.net/LICENSE) *
21  **********************************************************************************/
22 
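/*! \class TMVA::CCTreeWrapper
\ingroup TMVA

A light wrapper around a DecisionTree, used to perform cost-complexity pruning
"in place". The wrapper mirrors the tree with CCTreeNode objects that carry the
pruning quantities of each node: the number of leaves |~T_t|, the
resubstitution estimates R(t) and R(T_t), and the pruning parameters g(t) and
G(t). Branches are pruned by deleting wrapper nodes rather than the wrapped
DecisionTreeNodes.
*/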
27 
#include "TMVA/CCTreeWrapper.h"
#include "TMVA/DecisionTree.h"

#include <cstdlib>
#include <iostream>
#include <limits>
#include <sstream>
#include <string>
33 
34 using namespace TMVA;
35 
36 ////////////////////////////////////////////////////////////////////////////////
37 ///constructor of the CCTreeNode
38 
40  Node(),
41  fNLeafDaughters(0),
42  fNodeResubstitutionEstimate(-1.0),
43  fResubstitutionEstimate(-1.0),
44  fAlphaC(-1.0),
45  fMinAlphaC(-1.0),
46  fDTNode(n)
47 {
48  if ( n != NULL && n->GetRight() != NULL && n->GetLeft() != NULL ) {
49  SetRight( new CCTreeNode( ((DecisionTreeNode*) n->GetRight()) ) );
50  GetRight()->SetParent(this);
51  SetLeft( new CCTreeNode( ((DecisionTreeNode*) n->GetLeft()) ) );
52  GetLeft()->SetParent(this);
53  }
54 }
55 
56 ////////////////////////////////////////////////////////////////////////////////
57 /// destructor of a CCTreeNode
58 
60  if(GetLeft() != NULL) delete GetLeftDaughter();
61  if(GetRight() != NULL) delete GetRightDaughter();
62 }
63 
64 ////////////////////////////////////////////////////////////////////////////////
65 /// initialize a node from a data record
66 
67 Bool_t TMVA::CCTreeWrapper::CCTreeNode::ReadDataRecord( std::istream& in, UInt_t /* tmva_Version_Code */ ) {
68  std::string header, title;
69  in >> header;
70  in >> title; in >> fNLeafDaughters;
71  in >> title; in >> fNodeResubstitutionEstimate;
72  in >> title; in >> fResubstitutionEstimate;
73  in >> title; in >> fAlphaC;
74  in >> title; in >> fMinAlphaC;
75  return true;
76 }
77 
78 ////////////////////////////////////////////////////////////////////////////////
79 /// printout of the node (can be read in with ReadDataRecord)
80 
81 void TMVA::CCTreeWrapper::CCTreeNode::Print( std::ostream& os ) const {
82  os << "----------------------" << std::endl
83  << "|~T_t| " << fNLeafDaughters << std::endl
84  << "R(t): " << fNodeResubstitutionEstimate << std::endl
85  << "R(T_t): " << fResubstitutionEstimate << std::endl
86  << "g(t): " << fAlphaC << std::endl
87  << "G(t): " << fMinAlphaC << std::endl;
88 }
89 
90 ////////////////////////////////////////////////////////////////////////////////
91 /// recursive printout of the node and its daughters
92 
93 void TMVA::CCTreeWrapper::CCTreeNode::PrintRec( std::ostream& os ) const {
94  this->Print(os);
95  if(this->GetLeft() != NULL && this->GetRight() != NULL) {
96  this->GetLeft()->PrintRec(os);
97  this->GetRight()->PrintRec(os);
98  }
99 }
100 
101 ////////////////////////////////////////////////////////////////////////////////
102 /// constructor
103 
105  fRoot(NULL)
106 {
107  fDTParent = T;
108  fRoot = new CCTreeNode( dynamic_cast<DecisionTreeNode*>(T->GetRoot()) );
109  fQualityIndex = qualityIndex;
110  InitTree(fRoot);
111 }
112 
113 ////////////////////////////////////////////////////////////////////////////////
114 /// destructor
115 
117  delete fRoot;
118 }
119 
120 ////////////////////////////////////////////////////////////////////////////////
121 /// initialize the node t and all its descendants
///
/// Fills in the cost-complexity quantities of t and, recursively, of every
/// node below it. The daughters are initialized before their parent because
/// the parent's G(t) is built from theirs.
///  - R(t) = (s+b) * fQualityIndex->GetSeparationIndex(s,b), with s and b the
///    signal and background event counts of the wrapped DecisionTreeNode
///    (presumably weighted, since the *_unweighted variants are commented out)
///  - leaf:     |~T_t| = 1, R(T_t) = R(t), g(t) = G(t) = +infinity
///  - interior: G(t) = min( g(t), G(l(t)), G(r(t)) )
/// t and the DecisionTreeNode it wraps must be non-NULL, because both are
/// dereferenced.
122 
124 {
125  Double_t s = t->GetDTNode()->GetNSigEvents();
126  Double_t b = t->GetDTNode()->GetNBkgEvents();
127  // Double_t s = t->GetDTNode()->GetNSigEvents_unweighted();
128  // Double_t b = t->GetDTNode()->GetNBkgEvents_unweighted();
129  // set R(t) = Gini(t) or MisclassificationError(t), etc.
130  t->SetNodeResubstitutionEstimate((s+b)*fQualityIndex->GetSeparationIndex(s,b));
131 
132  if(t->GetLeft() != NULL && t->GetRight() != NULL) { // n is an interior (non-leaf) node
133  // traverse the tree
134  InitTree(t->GetLeftDaughter());
135  InitTree(t->GetRightDaughter());
136  // set |~T_t|
 // NOTE(review): original source lines 137-138, 140-141 and 143 are missing
 // from this extract; only the tail of the g(t) statement (line 144) survived.
 // Judging by the surrounding comments, they presumably set |~T_t| to the sum
 // of the daughters' leaf counts and R(T_t) to the sum of the daughters'
 // R(T_t). They presumably also compute g(t) = (R(t) - R(T_t)) / (|~T_t| - 1).
 // Restore them from the upstream file.
139  // set R(T) = sum[t' in ~T]{ R(t) }
142  // set g(t)
144  (t->GetNLeafDaughters() - 1));
145  // G(t) = min( g(t), G(l(t)), G(r(t)) )
146  t->SetMinAlphaC(std::min(t->GetAlphaC(), std::min(t->GetLeftDaughter()->GetMinAlphaC(),
147  t->GetRightDaughter()->GetMinAlphaC())));
148  }
149  else { // n is a terminal node
150  t->SetNLeafDaughters(1);
 // a leaf's branch is the leaf itself, so R(T_t) equals R(t)
151  t->SetResubstitutionEstimate((s+b)*fQualityIndex->GetSeparationIndex(s,b));
 // +infinity: a leaf has nothing below it that could be pruned
152  t->SetAlphaC(std::numeric_limits<double>::infinity( ));
153  t->SetMinAlphaC(std::numeric_limits<double>::infinity( ));
154  }
155 }
156 
157 ////////////////////////////////////////////////////////////////////////////////
158 /// remove the branch rooted at node t
///
/// Turns the interior node t into a leaf of the wrapper tree. Its two wrapper
/// daughters are deleted, together with everything below them via
/// ~CCTreeNode(). Then |~T_t| is reset to 1 and g(t), G(t) to +infinity. The
/// visible code deletes only CCTreeNode objects; the wrapped DecisionTreeNodes
/// are not deleted here.
/// If t is already a leaf, an error is printed to std::cout and t is left
/// unchanged.
/// NOTE(review): only t itself is updated. |~T|, R(T) and G() of t's ancestors
/// are not recomputed here; confirm that the caller takes care of that.
159 
161 {
162  if( t->GetLeft() != NULL &&
163  t->GetRight() != NULL ) {
164  CCTreeNode* l = t->GetLeftDaughter();
165  CCTreeNode* r = t->GetRightDaughter();
166  t->SetNLeafDaughters( 1 );
 // NOTE(review): original source line 167 is missing from this extract. By
 // analogy with the leaf case of InitTree, it presumably resets R(T_t) to R(t)
 // (SetResubstitutionEstimate(GetNodeResubstitutionEstimate())); confirm.
 // g(t) = G(t) = +infinity, as for every leaf
168  t->SetAlphaC( std::numeric_limits<double>::infinity( ) );
169  t->SetMinAlphaC( std::numeric_limits<double>::infinity( ) );
170  delete l;
171  delete r;
 // clear the now-dangling daughter pointers; t is a leaf from here on
172  t->SetLeft(NULL);
173  t->SetRight(NULL);
174  }else{
175  std::cout << " ERROR in CCTreeWrapper::PruneNode: you try to prune a leaf node.. that does not make sense " << std::endl;
176  }
177 }
178 
179 ////////////////////////////////////////////////////////////////////////////////
180 /// return the misclassification rate of a pruned tree for a validation event sample
181 /// using an EventList
182 
184 {
185  Double_t ncorrect=0, nfalse=0;
186  for (UInt_t ievt=0; ievt < validationSample->size(); ievt++) {
187  Bool_t isSignalType = (CheckEvent(*(*validationSample)[ievt]) > fDTParent->GetNodePurityLimit() ) ? 1 : 0;
188 
189  if (isSignalType == ((*validationSample)[ievt]->GetClass() == 0)) {
190  ncorrect += (*validationSample)[ievt]->GetWeight();
191  }
192  else{
193  nfalse += (*validationSample)[ievt]->GetWeight();
194  }
195  }
196  return ncorrect / (ncorrect + nfalse);
197 }
198 
199 ////////////////////////////////////////////////////////////////////////////////
200 /// return the misclassification rate of a pruned tree for a validation event sample
201 /// using the DataSet
202 
204 {
205  validationSample->SetCurrentType(Types::kValidation);
206  // test the tree quality.. in terms of Misclassification
207  Double_t ncorrect=0, nfalse=0;
208  for (Long64_t ievt=0; ievt<validationSample->GetNEvents(); ievt++){
209  const Event *ev = validationSample->GetEvent(ievt);
210 
211  Bool_t isSignalType = (CheckEvent(*ev) > fDTParent->GetNodePurityLimit() ) ? 1 : 0;
212 
213  if (isSignalType == (ev->GetClass() == 0)) {
214  ncorrect += ev->GetWeight();
215  }
216  else{
217  nfalse += ev->GetWeight();
218  }
219  }
220  return ncorrect / (ncorrect + nfalse);
221 }
222 
223 ////////////////////////////////////////////////////////////////////////////////
224 /// return the decision tree output for an event
225 
227 {
228  const DecisionTreeNode* current = fRoot->GetDTNode();
229  CCTreeNode* t = fRoot;
230 
231  while(//current->GetNodeType() == 0 &&
232  t->GetLeft() != NULL &&
233  t->GetRight() != NULL){ // at an interior (non-leaf) node
234  if (current->GoesRight(e)) {
235  //current = (DecisionTreeNode*)current->GetRight();
236  t = t->GetRightDaughter();
237  current = t->GetDTNode();
238  }
239  else {
240  //current = (DecisionTreeNode*)current->GetLeft();
241  t = t->GetLeftDaughter();
242  current = t->GetDTNode();
243  }
244  }
245 
246  if (useYesNoLeaf) return (current->GetPurity() > fDTParent->GetNodePurityLimit() ? 1.0 : -1.0);
247  else return current->GetPurity();
248 }
249 
250 ////////////////////////////////////////////////////////////////////////////////
251 
253 {}
254 
255 ////////////////////////////////////////////////////////////////////////////////
256 
257 void TMVA::CCTreeWrapper::CCTreeNode::AddContentToNode( std::stringstream& /*s*/ ) const
258 {}
259 
260 ////////////////////////////////////////////////////////////////////////////////
261 
262 void TMVA::CCTreeWrapper::CCTreeNode::ReadAttributes( void* /*node*/, UInt_t /* tmva_Version_Code */ )
263 {}
264 
265 ////////////////////////////////////////////////////////////////////////////////
266 
267 void TMVA::CCTreeWrapper::CCTreeNode::ReadContent( std::stringstream& /*s*/ )
268 {}
TMVA::CCTreeWrapper::CCTreeNode::GetMinAlphaC
Double_t GetMinAlphaC() const
Definition: CCTreeWrapper.h:91
l
auto * l
Definition: textangle.C:4
n
const Int_t n
Definition: legend1.C:16
ROOT::Math::IntegOptionsUtil::Print
void Print(std::ostream &os, const OptionType &opt)
Definition: IntegratorOptions.cxx:91
TMVA::Node::SetParent
virtual void SetParent(Node *p)
Definition: Node.h:96
e
#define e(i)
Definition: RSha256.hxx:103
TMVA::CCTreeWrapper::CCTreeNode::GetNodeResubstitutionEstimate
Double_t GetNodeResubstitutionEstimate() const
Definition: CCTreeWrapper.h:68
TMVA::CCTreeWrapper::CCTreeNode::AddAttributesToNode
virtual void AddAttributesToNode(void *node) const
Definition: CCTreeWrapper.cxx:252
TMVA::Node::GetRight
virtual Node * GetRight() const
Definition: Node.h:90
TMVA::Event::GetClass
UInt_t GetClass() const
Definition: Event.h:86
TMVA::DataSet::SetCurrentType
void SetCurrentType(Types::ETreeType type) const
Definition: DataSet.h:89
r
ROOT::R::TRInterface & r
Definition: Object.C:4
Long64_t
long long Long64_t
Definition: RtypesCore.h:73
TMVA::CCTreeWrapper::~CCTreeWrapper
~CCTreeWrapper()
destructor
Definition: CCTreeWrapper.cxx:116
TGeant4Unit::s
static constexpr double s
Definition: TGeant4SystemOfUnits.h:162
CCTreeWrapper.h
TMVA::DecisionTreeNode::GoesRight
virtual Bool_t GoesRight(const Event &) const
test event if it descends the tree at this node to the right
Definition: DecisionTreeNode.cxx:155
TMVA::CCTreeWrapper::CCTreeNode::GetLeftDaughter
CCTreeNode * GetLeftDaughter()
Definition: CCTreeWrapper.h:97
TMVA::CCTreeWrapper::InitTree
void InitTree(CCTreeNode *t)
initialize the node t and all its descendants
Definition: CCTreeWrapper.cxx:123
ROOT::GetClass
TClass * GetClass(T *)
Definition: TClass.h:594
TMVA::DecisionTreeNode
Definition: DecisionTreeNode.h:117
TMVA::CCTreeWrapper::EventList
std::vector< Event * > EventList
Definition: CCTreeWrapper.h:42
TMVA::CCTreeWrapper::CCTreeNode::GetResubstitutionEstimate
Double_t GetResubstitutionEstimate() const
Definition: CCTreeWrapper.h:75
TMVA::CCTreeWrapper::CCTreeNode
Definition: CCTreeWrapper.h:49
b
#define b(i)
Definition: RSha256.hxx:100
TMVA::Node
Node for the BinarySearch or Decision Trees.
Definition: Node.h:58
TMVA::DecisionTree
Implementation of a Decision Tree.
Definition: DecisionTree.h:65
bool
TMVA::CCTreeWrapper::CCTreeNode::AddContentToNode
virtual void AddContentToNode(std::stringstream &s) const
Definition: CCTreeWrapper.cxx:257
TMVA::CCTreeWrapper::CCTreeNode::SetResubstitutionEstimate
void SetResubstitutionEstimate(Double_t R)
Definition: CCTreeWrapper.h:72
TMVA::DecisionTreeNode::GetNSigEvents
Float_t GetNSigEvents(void) const
Definition: DecisionTreeNode.h:230
TMVA::CCTreeWrapper::CCTreeNode::SetNodeResubstitutionEstimate
void SetNodeResubstitutionEstimate(Double_t R)
Definition: CCTreeWrapper.h:65
DecisionTree.h
TMVA::CCTreeWrapper::fRoot
CCTreeNode * fRoot
pointer to underlying DecisionTree
Definition: CCTreeWrapper.h:152
TMVA::CCTreeWrapper::CCTreeNode::SetMinAlphaC
void SetMinAlphaC(Double_t alpha)
Definition: CCTreeWrapper.h:88
TMVA::DecisionTreeNode::GetNBkgEvents
Float_t GetNBkgEvents(void) const
Definition: DecisionTreeNode.h:233
TMVA::DataSet::GetEvent
const Event * GetEvent() const
Definition: DataSet.cxx:202
TMVA::DataSet::GetNEvents
Long64_t GetNEvents(Types::ETreeType type=Types::kMaxTreeType) const
Definition: DataSet.h:206
TMVA::Node::SetLeft
virtual void SetLeft(Node *l)
Definition: Node.h:94
TMVA::CCTreeWrapper::CCTreeNode::SetNLeafDaughters
void SetNLeafDaughters(Int_t N)
Definition: CCTreeWrapper.h:59
TMVA::CCTreeWrapper::CCTreeNode::ReadContent
virtual void ReadContent(std::stringstream &s)
Definition: CCTreeWrapper.cxx:267
TMVA::DataSet
Class that contains all the data information.
Definition: DataSet.h:58
TMVA::DecisionTreeNode::GetPurity
Float_t GetPurity(void) const
Definition: DecisionTreeNode.h:168
TMVA::CCTreeWrapper::CCTreeWrapper
CCTreeWrapper(DecisionTree *T, SeparationBase *qualityIndex)
constructor
Definition: CCTreeWrapper.cxx:104
TMVA::CCTreeWrapper::CCTreeNode::ReadDataRecord
virtual Bool_t ReadDataRecord(std::istream &in, UInt_t tmva_Version_Code=TMVA_VERSION_CODE)
initialize a node from a data record
Definition: CCTreeWrapper.cxx:67
TMVA::CCTreeWrapper::CCTreeNode::GetNLeafDaughters
Int_t GetNLeafDaughters() const
Definition: CCTreeWrapper.h:62
TMVA::Node::GetLeft
virtual Node * GetLeft() const
Definition: Node.h:89
TMVA::CCTreeWrapper::CCTreeNode::SetAlphaC
void SetAlphaC(Double_t alpha)
Definition: CCTreeWrapper.h:82
TMVA::CCTreeWrapper::CCTreeNode::GetRightDaughter
CCTreeNode * GetRightDaughter()
Definition: CCTreeWrapper.h:98
TMVA::CCTreeWrapper::CCTreeNode::GetDTNode
DecisionTreeNode * GetDTNode() const
Definition: CCTreeWrapper.h:94
unsigned int
TMVA::SeparationBase
An interface to calculate the "SeparationGain" for different separation criteria used in various trai...
Definition: SeparationBase.h:82
TMVA::CCTreeWrapper::fDTParent
DecisionTree * fDTParent
pointer to the used quality index calculator
Definition: CCTreeWrapper.h:151
TMVA::CCTreeWrapper::CCTreeNode::ReadAttributes
virtual void ReadAttributes(void *node, UInt_t tmva_Version_Code=TMVA_VERSION_CODE)
Definition: CCTreeWrapper.cxx:262
Double_t
double Double_t
Definition: RtypesCore.h:59
TMVA::Types::kValidation
@ kValidation
Definition: Types.h:148
TMVA::CCTreeWrapper::CCTreeNode::Print
virtual void Print(std::ostream &os) const
printout of the node (can be read in with ReadDataRecord)
Definition: CCTreeWrapper.cxx:81
TMVA::Event::GetWeight
Double_t GetWeight() const
return the event weight - depending on whether the flag IgnoreNegWeightsInTraining is or not.
Definition: Event.cxx:381
TMVA::CCTreeWrapper::CCTreeNode::GetAlphaC
Double_t GetAlphaC() const
Definition: CCTreeWrapper.h:85
TMVA::CCTreeWrapper::CCTreeNode::CCTreeNode
CCTreeNode(DecisionTreeNode *n=NULL)
constructor of the CCTreeNode
Definition: CCTreeWrapper.cxx:39
TMVA::Event
Definition: Event.h:51
ROOT::Math::Chebyshev::T
double T(double x)
Definition: ChebyshevPol.h:34
TMVA::CCTreeWrapper::fQualityIndex
SeparationBase * fQualityIndex
Definition: CCTreeWrapper.h:150
TMVA::CCTreeWrapper::TestTreeQuality
Double_t TestTreeQuality(const EventList *validationSample)
return the misclassification rate of a pruned tree for a validation event sample using an EventList
Definition: CCTreeWrapper.cxx:183
TMVA::CCTreeWrapper::CCTreeNode::~CCTreeNode
virtual ~CCTreeNode()
destructor of a CCTreeNode
Definition: CCTreeWrapper.cxx:59
TMVA::Node::SetRight
virtual void SetRight(Node *r)
Definition: Node.h:95
TMVA::CCTreeWrapper::CheckEvent
Double_t CheckEvent(const TMVA::Event &e, Bool_t useYesNoLeaf=false)
return the decision tree output for an event
Definition: CCTreeWrapper.cxx:226
TMVA::CCTreeWrapper::PruneNode
void PruneNode(CCTreeNode *t)
remove the branch rooted at node t
Definition: CCTreeWrapper.cxx:160
TMVA::CCTreeWrapper::CCTreeNode::PrintRec
virtual void PrintRec(std::ostream &os) const
recursive printout of the node and its daughters
Definition: CCTreeWrapper.cxx:93
TMVA
create variable transformations
Definition: GeneticMinimizer.h:22