Logo ROOT  
Reference Guide
ExpectedErrorPruneTool.cxx
Go to the documentation of this file.
1 /**********************************************************************************
2  * Project: TMVA - a Root-integrated toolkit for multivariate data analysis *
3  * Package: TMVA *
4  * Class : TMVA::DecisionTree *
5  * Web : http://tmva.sourceforge.net *
6  * *
7  * Description: *
8  * Implementation of a Decision Tree *
9  * *
10  * Authors (alphabetical): *
11  * Andreas Hoecker <Andreas.Hocker@cern.ch> - CERN, Switzerland *
12  * Helge Voss <Helge.Voss@cern.ch> - MPI-K Heidelberg, Germany *
13  * Kai Voss <Kai.Voss@cern.ch> - U. of Victoria, Canada *
14  * Doug Schouten <dschoute@sfu.ca> - Simon Fraser U., Canada *
15  * *
16  * Copyright (c) 2005: *
17  * CERN, Switzerland *
18  * U. of Victoria, Canada *
19  * MPI-K Heidelberg, Germany *
20  * *
21  * Redistribution and use in source and binary forms, with or without *
22  * modification, are permitted according to the terms listed in LICENSE *
23  * (http://tmva.sourceforge.net/license.txt) *
24  * *
25  **********************************************************************************/
26 
27 /*! \class TMVA::ExpectedErrorPruneTool
28 \ingroup TMVA
29 
30 A helper class to prune a decision tree using the expected error (C4.5) method
31 
32 Uses an upper limit on the error made by the classification done by each node.
33 If the \f$ \frac{S}{S+B} \f$ of the node is \f$ f \f$, then according to the
34 training sample, the error rate (fraction of misclassified events by this
35 node) is \f$ (1-f) \f$. Now \f$ f \f$ has a statistical error according to the
36 binomial distribution hence the error on \f$ f \f$ can be estimated (same error
37 as the binomial error for efficiency calculations
38 \f$ (\sigma = \sqrt{\frac{eff(1-eff)}{nEvts}}) \f$
39 
40 This tool prunes branches from a tree if the expected error of a node is less
41 than that of the sum of the error in its descendants.
42 
43 */
44 
#include "TMVA/ExpectedErrorPruneTool.h"

#include "TMVA/DecisionTree.h"
#include "TMVA/IPruneTool.h"
#include "TMVA/MsgLogger.h"
#include "TMVA/Types.h"

#include "RtypesCore.h"
#include "Rtypes.h"
#include "TMath.h"

#include <map>
56 
57 // pin the vtable here.
59 
60 ////////////////////////////////////////////////////////////////////////////////
61 
63  IPruneTool(),
64  fDeltaPruneStrength(0),
65  fNodePurityLimit(1),
66  fLogger( new MsgLogger("ExpectedErrorPruneTool") )
67 {}
68 
69 ////////////////////////////////////////////////////////////////////////////////
70 
72 {
73  delete fLogger;
74 }
75 
76 ////////////////////////////////////////////////////////////////////////////////
77 
80  const IPruneTool::EventSample* validationSample,
81  Bool_t isAutomatic )
82 {
83  if( isAutomatic ) {
84  //SetAutomatic( );
85  isAutomatic = kFALSE;
86  Log() << kWARNING << "Sorry automatic pruning strength determination is not implemented yet" << Endl;
87  }
88  if( dt == NULL || (IsAutomatic() && validationSample == NULL) ) {
89  // must have a valid decision tree to prune, and if the prune strength
90  // is to be chosen automatically, must have a test sample from
91  // which to calculate the quality of the pruned tree(s)
92  return NULL;
93  }
94  fNodePurityLimit = dt->GetNodePurityLimit();
95 
96  if(IsAutomatic()) {
97  Log() << kFATAL << "Sorry automatic pruning strength determination is not implemented yet" << Endl;
98  /*
99  dt->ApplyValidationSample(validationSample);
100  Double_t weights = dt->GetSumWeights(validationSample);
101  // set the initial prune strength
102  fPruneStrength = 1.0e-3; //! FIXME somehow make this automatic, it depends strongly on the tree structure
103  // better to set it too small, it will be increased automatically
104  fDeltaPruneStrength = 1.0e-5;
105  Int_t nnodes = this->CountNodes((DecisionTreeNode*)dt->GetRoot());
106 
107  Bool_t forceStop = kFALSE;
108  Int_t errCount = 0,
109  lastNodeCount = nnodes;
110 
111  // find the maximum prune strength that still leaves the root's daughter nodes
112 
113  while ( nnodes > 1 && !forceStop ) {
114  fPruneStrength += fDeltaPruneStrength;
115  Log() << "----------------------------------------------------" << Endl;
116  FindListOfNodes((DecisionTreeNode*)dt->GetRoot());
117  for( UInt_t i = 0; i < fPruneSequence.size(); i++ )
118  fPruneSequence[i]->SetTerminal(); // prune all the nodes from the sequence
119  // test the quality of the pruned tree
120  Double_t quality = 1.0 - dt->TestPrunedTreeQuality()/weights;
121  fQualityMap.insert(std::make_pair<const Double_t,Double_t>(quality,fPruneStrength));
122 
123  nnodes = CountNodes((DecisionTreeNode*)dt->GetRoot()); // count the number of nodes in the pruned tree
124 
125  Log() << "Prune strength : " << fPruneStrength << Endl;
126  Log() << "Had " << lastNodeCount << " nodes, now have " << nnodes << Endl;
127  Log() << "Quality index is: " << quality << Endl;
128 
129  if (lastNodeCount == nnodes) errCount++;
130  else {
131  errCount=0; // reset counter
132  if ( nnodes < lastNodeCount / 2 ) {
133  Log() << "Decreasing fDeltaPruneStrength to " << fDeltaPruneStrength/2.0
134  << " because the number of nodes in the tree decreased by a factor of 2." << Endl;
135  fDeltaPruneStrength /= 2.;
136  }
137  }
138  lastNodeCount = nnodes;
139  if (errCount > 20) {
140  Log() << "Increasing fDeltaPruneStrength to " << fDeltaPruneStrength*2.0
141  << " because the number of nodes in the tree didn't change." << Endl;
142  fDeltaPruneStrength *= 2.0;
143  }
144  if (errCount > 40) {
145  Log() << "Having difficulty determining the optimal prune strength, bailing out!" << Endl;
146  forceStop = kTRUE;
147  }
148  // reset the tree for the next iteration
149  for( UInt_t i = 0; i < fPruneSequence.size(); i++ )
150  fPruneSequence[i]->SetTerminal(false);
151  fPruneSequence.clear();
152  }
153  // from the set of pruned trees, find the one with the optimal quality index
154  std::multimap<Double_t,Double_t>::reverse_iterator it = fQualityMap.rend(); ++it;
155  fPruneStrength = it->second;
156  FindListOfNodes((DecisionTreeNode*)dt->GetRoot());
157 
158  // adjust the step size for the next tree automatically
159  fPruneStrength = 1.0e-3;
160  fDeltaPruneStrength = (fPruneStrength - 1.0)/(Double_t)fQualityMap.size();
161 
162  return new PruningInfo(it->first, it->second, fPruneSequence);
163  */
164  return NULL;
165  }
166  else { // no automatic pruning - just use the provided prune strength parameter
167  FindListOfNodes( (DecisionTreeNode*)dt->GetRoot() );
168  return new PruningInfo( -1.0, fPruneStrength, fPruneSequence );
169  }
170 }
171 
172 ////////////////////////////////////////////////////////////////////////////////
173 /// recursive pruning of nodes using the Expected Error Pruning (EEP)
174 
176 {
179  if (node->GetNodeType() == 0 && !(node->IsTerminal())) { // check all internal nodes
180  this->FindListOfNodes(l);
181  this->FindListOfNodes(r);
182  if (this->GetSubTreeError(node) >= this->GetNodeError(node)) {
183  //node->Print(Log());
184  fPruneSequence.push_back(node);
185  }
186  }
187 }
188 
189 ////////////////////////////////////////////////////////////////////////////////
190 /// calculate the expected statistical error on the subtree below "node"
191 /// which is used in the expected error pruning
192 
194 {
197  if (node->GetNodeType() == 0 && !(node->IsTerminal())) {
198  Double_t subTreeError =
199  (l->GetNEvents() * this->GetSubTreeError(l) +
200  r->GetNEvents() * this->GetSubTreeError(r)) /
201  node->GetNEvents();
202  return subTreeError;
203  }
204  else {
205  return this->GetNodeError(node);
206  }
207 }
208 
209 ////////////////////////////////////////////////////////////////////////////////
210 /// Calculate an UPPER limit on the error made by the classification done
211 /// by this node. If the S/S+B of the node is f, then according to the
212 /// training sample, the error rate (fraction of misclassified events by
213 /// this node) is (1-f)
214 /// Now f has a statistical error according to the binomial distribution
215 /// hence the error on f can be estimated (same error as the binomial error
216 /// for efficiency calculations
217 /// \f$ (\sigma = \sqrt{\frac{(eff(1-eff)}{nEvts}}) \f$
218 
220 {
221  Double_t errorRate = 0;
222 
223  Double_t nEvts = node->GetNEvents();
224 
225  // fraction of correctly classified events by this node:
226  Double_t f = 0;
227  if (node->GetPurity() > fNodePurityLimit) f = node->GetPurity();
228  else f = (1-node->GetPurity());
229 
230  Double_t df = TMath::Sqrt(f*(1-f)/nEvts);
231 
232  errorRate = std::min(1.0,(1.0 - (f-fPruneStrength*df)));
233 
234  // -------------------------------------------------------------------
235  // standard algorithm:
236  // step 1: Estimate error on node using Laplace estimate
237  // NodeError = (N - n + k -1 ) / (N + k)
238  // N: number of events
239  // k: number of event classes (2 for Signal, Background)
240  // n: n event out of N belong to the class which has the majority in the node
241  // step 2: Approximate "backed-up" error assuming we did not prune
242  // (I'm never quite sure if they consider whole subtrees, or only 'next-to-leaf'
243  // nodes)...
244  // Subtree error = Sum_children ( P_i * NodeError_i)
245  // P_i = probability of the node to make the decision, i.e. fraction of events in
246  // leaf node ( N_leaf / N_parent)
247  // step 3:
248 
249  // Minimum Error Pruning (MEP) according to Niblett/Bratko
250  //# of correctly classified events by this node:
251  //Double_t n=f*nEvts ;
252  //Double_t p_apriori = 0.5, m=100;
253  //errorRate = (nEvts - n + (1-p_apriori) * m ) / (nEvts + m);
254 
255  // Pessimistic error Pruning (proposed by Quinlan (error estimat with continuity approximation)
256  //# of correctly classified events by this node:
257  //Double_t n=f*nEvts ;
258  //errorRate = (nEvts - n + 0.5) / nEvts ;
259 
260  //const Double Z=.65;
261  //# of correctly classified events by this node:
262  //Double_t n=f*nEvts ;
263  //errorRate = (f + Z*Z/(2*nEvts ) + Z*sqrt(f/nEvts - f*f/nEvts + Z*Z/4/nEvts /nEvts ) ) / (1 + Z*Z/nEvts );
264  //errorRate = (n + Z*Z/2 + Z*sqrt(n - n*n/nEvts + Z*Z/4) )/ (nEvts + Z*Z);
265  //errorRate = 1 - errorRate;
266  // -------------------------------------------------------------------
267 
268  return errorRate;
269 }
270 
l
auto * l
Definition: textangle.C:4
TMVA::ExpectedErrorPruneTool::CalculatePruningInfo
virtual PruningInfo * CalculatePruningInfo(DecisionTree *dt, const IPruneTool::EventSample *testEvents=NULL, Bool_t isAutomatic=kFALSE)
Definition: ExpectedErrorPruneTool.cxx:79
TMVA::IPruneTool
IPruneTool - a helper interface class to prune a decision tree.
Definition: IPruneTool.h:70
TMVA::ExpectedErrorPruneTool::FindListOfNodes
void FindListOfNodes(DecisionTreeNode *node)
recursive pruning of nodes using the Expected Error Pruning (EEP)
Definition: ExpectedErrorPruneTool.cxx:175
TMVA::PruningInfo
Definition: IPruneTool.h:39
TMVA::IPruneTool::EventSample
std::vector< const Event * > EventSample
Definition: IPruneTool.h:74
f
#define f(i)
Definition: RSha256.hxx:104
TMVA::ExpectedErrorPruneTool::GetSubTreeError
Double_t GetSubTreeError(DecisionTreeNode *node) const
calculate the expected statistical error on the subtree below "node" which is used in the expected er...
Definition: ExpectedErrorPruneTool.cxx:193
TMVA::DecisionTreeNode::GetNodeType
Int_t GetNodeType(void) const
Definition: DecisionTreeNode.h:165
r
ROOT::R::TRInterface & r
Definition: Object.C:4
TMath::Log
Double_t Log(Double_t x)
Definition: TMath.h:760
TMath::Sqrt
Double_t Sqrt(Double_t x)
Definition: TMath.h:691
TMVA::ExpectedErrorPruneTool::ExpectedErrorPruneTool
ExpectedErrorPruneTool()
Definition: ExpectedErrorPruneTool.cxx:62
TMVA::IPruneTool::~IPruneTool
virtual ~IPruneTool()
Definition: ExpectedErrorPruneTool.cxx:58
TMVA::DecisionTreeNode
Definition: DecisionTreeNode.h:117
TMVA::ExpectedErrorPruneTool::GetNodeError
Double_t GetNodeError(DecisionTreeNode *node) const
Calculate an UPPER limit on the error made by the classification done by this node.
Definition: ExpectedErrorPruneTool.cxx:219
TMVA::DecisionTreeNode::IsTerminal
Bool_t IsTerminal() const
Definition: DecisionTreeNode.h:337
TMVA::DecisionTree
Implementation of a Decision Tree.
Definition: DecisionTree.h:65
bool
TMVA::DecisionTree::GetNodePurityLimit
Double_t GetNodePurityLimit() const
Definition: DecisionTree.h:162
TMVA::DecisionTree::GetRoot
virtual DecisionTreeNode * GetRoot() const
Definition: DecisionTree.h:94
MsgLogger.h
DecisionTree.h
TMVA::DecisionTreeNode::GetLeft
virtual DecisionTreeNode * GetLeft() const
Definition: DecisionTreeNode.h:282
kFALSE
const Bool_t kFALSE
Definition: RtypesCore.h:92
TMVA::DecisionTreeNode::GetPurity
Float_t GetPurity(void) const
Definition: DecisionTreeNode.h:168
Types.h
TMVA::Endl
MsgLogger & Endl(MsgLogger &ml)
Definition: MsgLogger.h:158
ExpectedErrorPruneTool.h
Double_t
double Double_t
Definition: RtypesCore.h:59
RtypesCore.h
TMVA::MsgLogger
ostringstream derivative to redirect and format output
Definition: MsgLogger.h:59
TMVA::ExpectedErrorPruneTool::~ExpectedErrorPruneTool
virtual ~ExpectedErrorPruneTool()
Definition: ExpectedErrorPruneTool.cxx:71
IPruneTool.h
Rtypes.h
TMVA::DecisionTreeNode::GetRight
virtual DecisionTreeNode * GetRight() const
Definition: DecisionTreeNode.h:283
TMath.h
TMVA::DecisionTreeNode::GetNEvents
Float_t GetNEvents(void) const
Definition: DecisionTreeNode.h:236