Logo ROOT  
Reference Guide
 
Loading...
Searching...
No Matches
ExpectedErrorPruneTool.cxx
Go to the documentation of this file.
1/**********************************************************************************
2 * Project: TMVA - a Root-integrated toolkit for multivariate data analysis *
3 * Package: TMVA *
4 * Class : TMVA::ExpectedErrorPruneTool *
5 * *
6 * *
7 * Description: *
8 * Expected-error (C4.5) pruning tool for Decision Trees *
9 * *
10 * Authors (alphabetical): *
11 * Andreas Hoecker <Andreas.Hocker@cern.ch> - CERN, Switzerland *
12 * Helge Voss <Helge.Voss@cern.ch> - MPI-K Heidelberg, Germany *
13 * Kai Voss <Kai.Voss@cern.ch> - U. of Victoria, Canada *
14 * Doug Schouten <dschoute@sfu.ca> - Simon Fraser U., Canada *
15 * *
16 * Copyright (c) 2005: *
17 * CERN, Switzerland *
18 * U. of Victoria, Canada *
19 * MPI-K Heidelberg, Germany *
20 * *
21 * Redistribution and use in source and binary forms, with or without *
22 * modification, are permitted according to the terms listed in LICENSE *
23 * (http://mva.sourceforge.net/license.txt) *
24 * *
25 **********************************************************************************/
26
27/*! \class TMVA::ExpectedErrorPruneTool
28\ingroup TMVA
29
30A helper class to prune a decision tree using the expected error (C4.5) method
31
32Uses an upper limit on the error made by the classification done by each node.
33If the \f$ \frac{S}{S+B} \f$ of the node is \f$ f \f$, then according to the
34training sample, the error rate (fraction of misclassified events by this
35node) is \f$ (1-f) \f$. Now \f$ f \f$ has a statistical error according to the
36binomial distribution, hence the error on \f$ f \f$ can be estimated (same error
37as the binomial error for efficiency calculations):
38\f$ \sigma = \sqrt{\frac{eff(1-eff)}{nEvts}} \f$
39
40This tool prunes branches from a tree if the expected error of a node is less
41than or equal to the event-weighted average error of its descendants.
42
43*/
44
46#include "TMVA/DecisionTree.h"
47#include "TMVA/IPruneTool.h"
48#include "TMVA/MsgLogger.h"
49#include "TMVA/Types.h"
50
51#include "RtypesCore.h"
52#include "Rtypes.h"
53#include "TMath.h"
54
55#include <map>
56
57// pin the vtable here.
59
60////////////////////////////////////////////////////////////////////////////////
61
// Members set here: fDeltaPruneStrength (step size of the automatic
// prune-strength scan, which exists only as disabled code in
// CalculatePruningInfo), fNodePurityLimit (placeholder 1; CalculatePruningInfo
// overwrites it with the tree's node-purity limit) and fLogger, a MsgLogger
// that this tool owns.
// NOTE(review): fLogger is created with a raw `new` and deleted in the
// destructor; unless copying is disabled in the header, copying this tool
// would double-delete the logger -- consider std::unique_ptr<MsgLogger>.
63 IPruneTool(),
64 fDeltaPruneStrength(0),
65 fNodePurityLimit(1),
66 fLogger( new MsgLogger("ExpectedErrorPruneTool") )
67{}
68
69////////////////////////////////////////////////////////////////////////////////
70
72{
// Release the MsgLogger allocated with `new` in the constructor.
73 delete fLogger;
74}
75
76////////////////////////////////////////////////////////////////////////////////
77
80 const IPruneTool::EventSample* validationSample,
81 Bool_t isAutomatic )
82{
// Decide which nodes of `dt` to prune with the expected-error criterion.
// Returns NULL if `dt` is NULL, or if the tool's own automatic mode
// (IsAutomatic()) is active without a validation sample. Otherwise the tree's
// node-purity limit is copied into fNodePurityLimit and, in non-automatic mode,
// a newly allocated PruningInfo built from fPruneStrength and the nodes found by
// FindListOfNodes() is returned (presumably owned by the caller -- confirm).
// Automatic prune-strength determination is not implemented: passing
// isAutomatic = true only logs a warning; if IsAutomatic() is set, a kFATAL
// message is logged and the following `return NULL` is reached only if kFATAL
// does not abort.
83 if( isAutomatic ) {
84 //SetAutomatic( );
// NOTE(review): this only resets the local parameter, which is never read
// again; the checks below use the member flag IsAutomatic(), not `isAutomatic`.
85 isAutomatic = kFALSE;
86 Log() << kWARNING << "Sorry automatic pruning strength determination is not implemented yet" << Endl;
87 }
88 if( dt == NULL || (IsAutomatic() && validationSample == NULL) ) {
89 // must have a valid decision tree to prune, and if the prune strength
90 // is to be chosen automatically, must have a test sample from
91 // which to calculate the quality of the pruned tree(s)
92 return NULL;
93 }
94 fNodePurityLimit = dt->GetNodePurityLimit();
95
96 if(IsAutomatic()) {
97 Log() << kFATAL << "Sorry automatic pruning strength determination is not implemented yet" << Endl;
// NOTE(review): the disabled implementation below is not usable as-is:
// `fQualityMap.rend(); ++it;` steps past the end of the reverse range (the
// highest-quality entry of the ascending multimap is rbegin()), and
// std::make_pair with explicit template arguments does not accept the lvalue
// arguments used here since C++11.
98 /*
99 dt->ApplyValidationSample(validationSample);
100 Double_t weights = dt->GetSumWeights(validationSample);
101 // set the initial prune strength
102 fPruneStrength = 1.0e-3; //! FIXME somehow make this automatic, it depends strongly on the tree structure
103 // better to set it too small, it will be increased automatically
104 fDeltaPruneStrength = 1.0e-5;
105 Int_t nnodes = this->CountNodes((DecisionTreeNode*)dt->GetRoot());
106
107 Bool_t forceStop = kFALSE;
108 Int_t errCount = 0,
109 lastNodeCount = nnodes;
110
111 // find the maximum prune strength that still leaves the root's daughter nodes
112
113 while ( nnodes > 1 && !forceStop ) {
114 fPruneStrength += fDeltaPruneStrength;
115 Log() << "----------------------------------------------------" << Endl;
116 FindListOfNodes((DecisionTreeNode*)dt->GetRoot());
117 for( UInt_t i = 0; i < fPruneSequence.size(); i++ )
118 fPruneSequence[i]->SetTerminal(); // prune all the nodes from the sequence
119 // test the quality of the pruned tree
120 Double_t quality = 1.0 - dt->TestPrunedTreeQuality()/weights;
121 fQualityMap.insert(std::make_pair<const Double_t,Double_t>(quality,fPruneStrength));
122
123 nnodes = CountNodes((DecisionTreeNode*)dt->GetRoot()); // count the number of nodes in the pruned tree
124
125 Log() << "Prune strength : " << fPruneStrength << Endl;
126 Log() << "Had " << lastNodeCount << " nodes, now have " << nnodes << Endl;
127 Log() << "Quality index is: " << quality << Endl;
128
129 if (lastNodeCount == nnodes) errCount++;
130 else {
131 errCount=0; // reset counter
132 if ( nnodes < lastNodeCount / 2 ) {
133 Log() << "Decreasing fDeltaPruneStrength to " << fDeltaPruneStrength/2.0
134 << " because the number of nodes in the tree decreased by a factor of 2." << Endl;
135 fDeltaPruneStrength /= 2.;
136 }
137 }
138 lastNodeCount = nnodes;
139 if (errCount > 20) {
140 Log() << "Increasing fDeltaPruneStrength to " << fDeltaPruneStrength*2.0
141 << " because the number of nodes in the tree didn't change." << Endl;
142 fDeltaPruneStrength *= 2.0;
143 }
144 if (errCount > 40) {
145 Log() << "Having difficulty determining the optimal prune strength, bailing out!" << Endl;
146 forceStop = kTRUE;
147 }
148 // reset the tree for the next iteration
149 for( UInt_t i = 0; i < fPruneSequence.size(); i++ )
150 fPruneSequence[i]->SetTerminal(false);
151 fPruneSequence.clear();
152 }
153 // from the set of pruned trees, find the one with the optimal quality index
154 std::multimap<Double_t,Double_t>::reverse_iterator it = fQualityMap.rend(); ++it;
155 fPruneStrength = it->second;
156 FindListOfNodes((DecisionTreeNode*)dt->GetRoot());
157
158 // adjust the step size for the next tree automatically
159 fPruneStrength = 1.0e-3;
160 fDeltaPruneStrength = (fPruneStrength - 1.0)/(Double_t)fQualityMap.size();
161
162 return new PruningInfo(it->first, it->second, fPruneSequence);
163 */
164 return NULL;
165 }
// Fixed-strength path: collect every node whose pruning does not increase the
// expected error at the current fPruneStrength. The first PruningInfo argument
// is -1.0; the disabled code passes the quality index in that slot, so -1.0
// presumably means "quality not evaluated" -- confirm against PruningInfo.
// NOTE(review): fPruneSequence is not cleared on this path, so reusing one tool
// instance for several trees would carry over nodes from earlier calls.
166 else { // no automatic pruning - just use the provided prune strength parameter
167 FindListOfNodes( (DecisionTreeNode*)dt->GetRoot() );
168 return new PruningInfo( -1.0, fPruneStrength, fPruneSequence );
169 }
170}
171
172////////////////////////////////////////////////////////////////////////////////
173/// recursive pruning of nodes using the Expected Error Pruning (EEP)
174
176{
// Recursively walk the subtree below `node` and append to fPruneSequence every
// internal (node type 0), non-terminal node whose expected error is no larger
// than the event-weighted error of its subtree, i.e. replacing the subtree by
// this node does not increase the expected error. Daughters are visited first,
// so deeper candidates precede their ancestors in fPruneSequence; since the
// subtree error is computed on the unpruned subtree, a node and some of its
// descendants can both be selected.
// NOTE(review): `l` and `r` are declared on source lines missing from this
// rendering (presumably node->GetLeft() / node->GetRight() -- confirm). `node`
// is dereferenced without a null check.
179 if (node->GetNodeType() == 0 && !(node->IsTerminal())) { // check all internal nodes
180 this->FindListOfNodes(l);
181 this->FindListOfNodes(r);
182 if (this->GetSubTreeError(node) >= this->GetNodeError(node)) {
183 //node->Print(Log());
184 fPruneSequence.push_back(node);
185 }
186 }
187}
188
189////////////////////////////////////////////////////////////////////////////////
190/// calculate the expected statistical error on the subtree below "node"
191/// which is used in the expected error pruning
192
194{
// For an internal, non-terminal node: the daughters' subtree errors weighted by
// their training event counts, (N_l*err_l + N_r*err_r) / N_node, computed
// recursively. For any other node: the node's own expected error (GetNodeError).
// NOTE(review): `l` and `r` are declared on source lines missing from this
// rendering (presumably the left/right daughters -- confirm). There is no guard
// against node->GetNEvents() being 0, or -1 when training info is undefined.
197 if (node->GetNodeType() == 0 && !(node->IsTerminal())) {
198 Double_t subTreeError =
199 (l->GetNEvents() * this->GetSubTreeError(l) +
200 r->GetNEvents() * this->GetSubTreeError(r)) /
201 node->GetNEvents();
202 return subTreeError;
203 }
204 else {
205 return this->GetNodeError(node);
206 }
207}
208
209////////////////////////////////////////////////////////////////////////////////
210/// Calculate an UPPER limit on the error made by the classification done
211/// by this node. If the S/S+B of the node is f, then according to the
212/// training sample, the error rate (fraction of misclassified events by
213/// this node) is (1-f)
214/// Now f has a statistical error according to the binomial distribution
215/// hence the error on f can be estimated (same error as the binomial error
216/// for efficiency calculations
217/// \f$ (\sigma = \sqrt{\frac{(eff(1-eff)}{nEvts}}) \f$
218
220{
221 Double_t errorRate = 0;
222
223 Double_t nEvts = node->GetNEvents();
224
225 // fraction of correctly classified events by this node:
226 Double_t f = 0;
227 if (node->GetPurity() > fNodePurityLimit) f = node->GetPurity();
228 else f = (1-node->GetPurity());
229
230 Double_t df = TMath::Sqrt(f*(1-f)/nEvts);
231
232 errorRate = std::min(1.0,(1.0 - (f-fPruneStrength*df)));
233
234 // -------------------------------------------------------------------
235 // standard algorithm:
236 // step 1: Estimate error on node using Laplace estimate
237 // NodeError = (N - n + k -1 ) / (N + k)
238 // N: number of events
239 // k: number of event classes (2 for Signal, Background)
240 // n: n event out of N belong to the class which has the majority in the node
241 // step 2: Approximate "backed-up" error assuming we did not prune
242 // (I'm never quite sure if they consider whole subtrees, or only 'next-to-leaf'
243 // nodes)...
244 // Subtree error = Sum_children ( P_i * NodeError_i)
245 // P_i = probability of the node to make the decision, i.e. fraction of events in
246 // leaf node ( N_leaf / N_parent)
247 // step 3:
248
249 // Minimum Error Pruning (MEP) according to Niblett/Bratko
250 //# of correctly classified events by this node:
251 //Double_t n=f*nEvts ;
252 //Double_t p_apriori = 0.5, m=100;
253 //errorRate = (nEvts - n + (1-p_apriori) * m ) / (nEvts + m);
254
255 // Pessimistic error Pruning (proposed by Quinlan (error estimat with continuity approximation)
256 //# of correctly classified events by this node:
257 //Double_t n=f*nEvts ;
258 //errorRate = (nEvts - n + 0.5) / nEvts ;
259
260 //const Double Z=.65;
261 //# of correctly classified events by this node:
262 //Double_t n=f*nEvts ;
263 //errorRate = (f + Z*Z/(2*nEvts ) + Z*sqrt(f/nEvts - f*f/nEvts + Z*Z/4/nEvts /nEvts ) ) / (1 + Z*Z/nEvts );
264 //errorRate = (n + Z*Z/2 + Z*sqrt(n - n*n/nEvts + Z*Z/4) )/ (nEvts + Z*Z);
265 //errorRate = 1 - errorRate;
266 // -------------------------------------------------------------------
267
268 return errorRate;
269}
270
#define f(i)
Definition RSha256.hxx:104
constexpr Bool_t kFALSE
Definition RtypesCore.h:94
Option_t Option_t TPoint TPoint const char GetTextMagnitude GetFillStyle GetLineColor GetLineWidth GetMarkerStyle GetTextAlign GetTextColor GetTextSize void char Point_t Rectangle_t WindowAttributes_t Float_t r
virtual DecisionTreeNode * GetLeft() const
Int_t GetNodeType(void) const
return node type: 1 signal node, -1 bkg leave, 0 intermediate Node
Float_t GetNEvents(void) const
return the number of events that entered the node (during training), or -1 if traininfo undefined
Float_t GetPurity(void) const
return S/(S+B) (purity) at this node (from training)
Bool_t IsTerminal() const
flag indicates whether this node is terminal
virtual DecisionTreeNode * GetRight() const
Implementation of a Decision Tree.
Double_t GetNodePurityLimit() const
virtual DecisionTreeNode * GetRoot() const
void FindListOfNodes(DecisionTreeNode *node)
recursive pruning of nodes using the Expected Error Pruning (EEP)
Double_t GetNodeError(DecisionTreeNode *node) const
Calculate an UPPER limit on the error made by the classification done by this node.
Double_t GetSubTreeError(DecisionTreeNode *node) const
calculate the expected statistical error on the subtree below "node" which is used in the expected er...
virtual PruningInfo * CalculatePruningInfo(DecisionTree *dt, const IPruneTool::EventSample *testEvents=nullptr, Bool_t isAutomatic=kFALSE)
IPruneTool - a helper interface class to prune a decision tree.
Definition IPruneTool.h:70
std::vector< const Event * > EventSample
Definition IPruneTool.h:74
ostringstream derivative to redirect and format output
Definition MsgLogger.h:57
MsgLogger & Endl(MsgLogger &ml)
Definition MsgLogger.h:148
Double_t Sqrt(Double_t x)
Returns the square root of x.
Definition TMath.h:662
TLine l
Definition textangle.C:4