ROOT   Reference Guide
ExpectedErrorPruneTool.cxx
Go to the documentation of this file.
/**********************************************************************************
 * Project: TMVA - a Root-integrated toolkit for multivariate data analysis       *
 * Package: TMVA                                                                  *
 * Class  : TMVA::ExpectedErrorPruneTool                                          *
 * Web    : http://tmva.sourceforge.net                                           *
 *                                                                                *
 * Description:                                                                   *
 *     Prunes a decision tree using the expected error (C4.5) method              *
 *                                                                                *
 * Authors (alphabetical):                                                        *
 *     Andreas Hoecker <Andreas.Hocker@cern.ch> - CERN, Switzerland               *
 *     Helge Voss <Helge.Voss@cern.ch> - MPI-K Heidelberg, Germany                *
 *     Kai Voss <Kai.Voss@cern.ch> - U. of Victoria, Canada                       *
 *     Doug Schouten <dschoute@sfu.ca> - Simon Fraser U., Canada                  *
 *                                                                                *
 * Copyright (c) 2005:                                                            *
 *     CERN, Switzerland                                                          *
 *     U. of Victoria, Canada                                                     *
 *     MPI-K Heidelberg, Germany                                                  *
 *                                                                                *
 * Redistribution and use in source and binary forms, with or without            *
 * modification, are permitted according to the terms listed in LICENSE          *
 * (http://tmva.sourceforge.net/LICENSE)                                          *
 *                                                                                *
 **********************************************************************************/
26
/*! \class TMVA::ExpectedErrorPruneTool
\ingroup TMVA

A helper class to prune a decision tree using the expected error (C4.5) method.

It uses an upper limit on the error made by the classification done by each node.
Let \f$ f \f$ be the fraction of training events that the node classifies
correctly: the node purity \f$ \frac{S}{S+B} \f$ if that purity exceeds the node
purity limit of the tree, and \f$ 1 - \frac{S}{S+B} \f$ otherwise. According to
the training sample, the error rate (the fraction of events misclassified by this
node) is then \f$ 1-f \f$. Since \f$ f \f$ has a statistical error according to
the binomial distribution, that error can be estimated in the same way as the
binomial error in efficiency calculations:
\f[ \sigma_f = \sqrt{\frac{f(1-f)}{N_{evts}}} \f]
The upper limit used for the node error is \f$ (1-f) + \mathrm{PruneStrength}\cdot\sigma_f \f$,
capped at 1.

This tool selects a node for pruning if the expected error of the node is not
larger than the event-weighted expected error of the subtree below it.
*/
44
#include "TMVA/ExpectedErrorPruneTool.h"
#include "TMVA/DecisionTree.h"
#include "TMVA/IPruneTool.h"
#include "TMVA/MsgLogger.h"
#include "TMVA/Types.h"

#include "RtypesCore.h"
#include "Rtypes.h"
#include "TMath.h"

#include <algorithm>
#include <map>
56
57////////////////////////////////////////////////////////////////////////////////
58
60 IPruneTool(),
61 fDeltaPruneStrength(0),
62 fNodePurityLimit(1),
63 fLogger( new MsgLogger("ExpectedErrorPruneTool") )
64{}
65
66////////////////////////////////////////////////////////////////////////////////
67
69{
70 delete fLogger;
71}
72
73////////////////////////////////////////////////////////////////////////////////
74
77 const IPruneTool::EventSample* validationSample,
78 Bool_t isAutomatic )
79{
80 if( isAutomatic ) {
81 //SetAutomatic( );
82 isAutomatic = kFALSE;
83 Log() << kWARNING << "Sorry automatic pruning strength determination is not implemented yet" << Endl;
84 }
85 if( dt == NULL || (IsAutomatic() && validationSample == NULL) ) {
86 // must have a valid decision tree to prune, and if the prune strength
87 // is to be chosen automatically, must have a test sample from
88 // which to calculate the quality of the pruned tree(s)
89 return NULL;
90 }
91 fNodePurityLimit = dt->GetNodePurityLimit();
92
93 if(IsAutomatic()) {
94 Log() << kFATAL << "Sorry automatic pruning strength determination is not implemented yet" << Endl;
95 /*
96 dt->ApplyValidationSample(validationSample);
97 Double_t weights = dt->GetSumWeights(validationSample);
98 // set the initial prune strength
99 fPruneStrength = 1.0e-3; //! FIXME somehow make this automatic, it depends strongly on the tree structure
100 // better to set it too small, it will be increased automatically
101 fDeltaPruneStrength = 1.0e-5;
102 Int_t nnodes = this->CountNodes((DecisionTreeNode*)dt->GetRoot());
103
104 Bool_t forceStop = kFALSE;
105 Int_t errCount = 0,
106 lastNodeCount = nnodes;
107
108 // find the maximum prune strength that still leaves the root's daughter nodes
109
110 while ( nnodes > 1 && !forceStop ) {
111 fPruneStrength += fDeltaPruneStrength;
112 Log() << "----------------------------------------------------" << Endl;
113 FindListOfNodes((DecisionTreeNode*)dt->GetRoot());
114 for( UInt_t i = 0; i < fPruneSequence.size(); i++ )
115 fPruneSequence[i]->SetTerminal(); // prune all the nodes from the sequence
116 // test the quality of the pruned tree
117 Double_t quality = 1.0 - dt->TestPrunedTreeQuality()/weights;
118 fQualityMap.insert(std::make_pair<const Double_t,Double_t>(quality,fPruneStrength));
119
120 nnodes = CountNodes((DecisionTreeNode*)dt->GetRoot()); // count the number of nodes in the pruned tree
121
122 Log() << "Prune strength : " << fPruneStrength << Endl;
123 Log() << "Had " << lastNodeCount << " nodes, now have " << nnodes << Endl;
124 Log() << "Quality index is: " << quality << Endl;
125
126 if (lastNodeCount == nnodes) errCount++;
127 else {
128 errCount=0; // reset counter
129 if ( nnodes < lastNodeCount / 2 ) {
130 Log() << "Decreasing fDeltaPruneStrength to " << fDeltaPruneStrength/2.0
131 << " because the number of nodes in the tree decreased by a factor of 2." << Endl;
132 fDeltaPruneStrength /= 2.;
133 }
134 }
135 lastNodeCount = nnodes;
136 if (errCount > 20) {
137 Log() << "Increasing fDeltaPruneStrength to " << fDeltaPruneStrength*2.0
138 << " because the number of nodes in the tree didn't change." << Endl;
139 fDeltaPruneStrength *= 2.0;
140 }
141 if (errCount > 40) {
142 Log() << "Having difficulty determining the optimal prune strength, bailing out!" << Endl;
143 forceStop = kTRUE;
144 }
145 // reset the tree for the next iteration
146 for( UInt_t i = 0; i < fPruneSequence.size(); i++ )
147 fPruneSequence[i]->SetTerminal(false);
148 fPruneSequence.clear();
149 }
150 // from the set of pruned trees, find the one with the optimal quality index
151 std::multimap<Double_t,Double_t>::reverse_iterator it = fQualityMap.rend(); ++it;
152 fPruneStrength = it->second;
153 FindListOfNodes((DecisionTreeNode*)dt->GetRoot());
154
155 // adjust the step size for the next tree automatically
156 fPruneStrength = 1.0e-3;
157 fDeltaPruneStrength = (fPruneStrength - 1.0)/(Double_t)fQualityMap.size();
158
159 return new PruningInfo(it->first, it->second, fPruneSequence);
160 */
161 return NULL;
162 }
163 else { // no automatic pruning - just use the provided prune strength parameter
164 FindListOfNodes( (DecisionTreeNode*)dt->GetRoot() );
165 return new PruningInfo( -1.0, fPruneStrength, fPruneSequence );
166 }
167}
168
169////////////////////////////////////////////////////////////////////////////////
170/// recursive pruning of nodes using the Expected Error Pruning (EEP)
171
173{
176 if (node->GetNodeType() == 0 && !(node->IsTerminal())) { // check all internal nodes
177 this->FindListOfNodes(l);
178 this->FindListOfNodes(r);
179 if (this->GetSubTreeError(node) >= this->GetNodeError(node)) {
180 //node->Print(Log());
181 fPruneSequence.push_back(node);
182 }
183 }
184}
185
186////////////////////////////////////////////////////////////////////////////////
187/// calculate the expected statistical error on the subtree below "node"
188/// which is used in the expected error pruning
189
191{
194 if (node->GetNodeType() == 0 && !(node->IsTerminal())) {
195 Double_t subTreeError =
196 (l->GetNEvents() * this->GetSubTreeError(l) +
197 r->GetNEvents() * this->GetSubTreeError(r)) /
198 node->GetNEvents();
199 return subTreeError;
200 }
201 else {
202 return this->GetNodeError(node);
203 }
204}
205
206////////////////////////////////////////////////////////////////////////////////
207/// Calculate an UPPER limit on the error made by the classification done
208/// by this node. If the S/S+B of the node is f, then according to the
209/// training sample, the error rate (fraction of misclassified events by
210/// this node) is (1-f)
211/// Now f has a statistical error according to the binomial distribution
212/// hence the error on f can be estimated (same error as the binomial error
213/// for efficiency calculations
214/// \f$(\sigma = \sqrt{\frac{(eff(1-eff)}{nEvts}}) \f$
215
217{
218 Double_t errorRate = 0;
219
220 Double_t nEvts = node->GetNEvents();
221
222 // fraction of correctly classified events by this node:
223 Double_t f = 0;
224 if (node->GetPurity() > fNodePurityLimit) f = node->GetPurity();
225 else f = (1-node->GetPurity());
226
227 Double_t df = TMath::Sqrt(f*(1-f)/nEvts);
228
229 errorRate = std::min(1.0,(1.0 - (f-fPruneStrength*df)));
230
231 // -------------------------------------------------------------------
232 // standard algorithm:
233 // step 1: Estimate error on node using Laplace estimate
234 // NodeError = (N - n + k -1 ) / (N + k)
235 // N: number of events
236 // k: number of event classes (2 for Signal, Background)
237 // n: n event out of N belong to the class which has the majority in the node
238 // step 2: Approximate "backed-up" error assuming we did not prune
239 // (I'm never quite sure if they consider whole subtrees, or only 'next-to-leaf'
240 // nodes)...
241 // Subtree error = Sum_children ( P_i * NodeError_i)
242 // P_i = probability of the node to make the decision, i.e. fraction of events in
243 // leaf node ( N_leaf / N_parent)
244 // step 3:
245
246 // Minimum Error Pruning (MEP) according to Niblett/Bratko
247 //# of correctly classified events by this node:
248 //Double_t n=f*nEvts ;
249 //Double_t p_apriori = 0.5, m=100;
250 //errorRate = (nEvts - n + (1-p_apriori) * m ) / (nEvts + m);
251
252 // Pessimistic error Pruning (proposed by Quinlan (error estimat with continuity approximation)
253 //# of correctly classified events by this node:
254 //Double_t n=f*nEvts ;
255 //errorRate = (nEvts - n + 0.5) / nEvts ;
256
257 //const Double Z=.65;
258 //# of correctly classified events by this node:
259 //Double_t n=f*nEvts ;
260 //errorRate = (f + Z*Z/(2*nEvts ) + Z*sqrt(f/nEvts - f*f/nEvts + Z*Z/4/nEvts /nEvts ) ) / (1 + Z*Z/nEvts );
261 //errorRate = (n + Z*Z/2 + Z*sqrt(n - n*n/nEvts + Z*Z/4) )/ (nEvts + Z*Z);
262 //errorRate = 1 - errorRate;
263 // -------------------------------------------------------------------
264
265 return errorRate;
266}
267
ROOT::R::TRInterface & r
Definition: Object.C:4
#define f(i)
Definition: RSha256.hxx:104
const Bool_t kFALSE
Definition: RtypesCore.h:88
bool Bool_t
Definition: RtypesCore.h:59
double Double_t
Definition: RtypesCore.h:55
virtual DecisionTreeNode * GetLeft() const
Int_t GetNodeType(void) const
Float_t GetNEvents(void) const
Float_t GetPurity(void) const
virtual DecisionTreeNode * GetRight() const
Implementation of a Decision Tree.
Definition: DecisionTree.h:64
Double_t GetNodePurityLimit() const
Definition: DecisionTree.h:161
virtual DecisionTreeNode * GetRoot() const
Definition: DecisionTree.h:93
void FindListOfNodes(DecisionTreeNode *node)
recursive pruning of nodes using the Expected Error Pruning (EEP)
Double_t GetNodeError(DecisionTreeNode *node) const
Calculate an UPPER limit on the error made by the classification done by this node.
Double_t GetSubTreeError(DecisionTreeNode *node) const
calculate the expected statistical error on the subtree below "node" which is used in the expected er...
virtual PruningInfo * CalculatePruningInfo(DecisionTree *dt, const IPruneTool::EventSample *testEvents=NULL, Bool_t isAutomatic=kFALSE)
IPruneTool - a helper interface class to prune a decision tree.
Definition: IPruneTool.h:70
std::vector< const Event * > EventSample
Definition: IPruneTool.h:74
ostringstream derivative to redirect and format output
Definition: MsgLogger.h:59
MsgLogger & Endl(MsgLogger &ml)
Definition: MsgLogger.h:158
Double_t Log(Double_t x)
Definition: TMath.h:750
Double_t Sqrt(Double_t x)
Definition: TMath.h:681
auto * l
Definition: textangle.C:4