ROOT   6.10/09 Reference Guide
ExpectedErrorPruneTool.cxx
Go to the documentation of this file.
1 /**********************************************************************************
2  * Project: TMVA - a Root-integrated toolkit for multivariate data analysis *
3  * Package: TMVA *
4  * Class : TMVA::ExpectedErrorPruneTool *
5  * Web : http://tmva.sourceforge.net *
6  * *
7  * Description: *
8  * Prune a decision tree using the expected error (C4.5) method *
9  * *
10  * Authors (alphabetical): *
11  * Andreas Hoecker <Andreas.Hocker@cern.ch> - CERN, Switzerland *
12  * Helge Voss <Helge.Voss@cern.ch> - MPI-K Heidelberg, Germany *
13  * Kai Voss <Kai.Voss@cern.ch> - U. of Victoria, Canada *
14  * Doug Schouten <dschoute@sfu.ca> - Simon Fraser U., Canada *
15  * *
16  * Copyright (c) 2005: *
17  * CERN, Switzerland *
18  * U. of Victoria, Canada *
19  * MPI-K Heidelberg, Germany *
20  * *
21  * Redistribution and use in source and binary forms, with or without *
22  * modification, are permitted according to the terms listed in LICENSE *
24  * *
25  **********************************************************************************/
26
27 /*! \class TMVA::ExpectedErrorPruneTool
28 \ingroup TMVA
29
30 A helper class to prune a decision tree using the expected error (C4.5) method
31
32 Uses an upper limit on the error made by the classification done by each node.
33 If the \f$\frac{S}{S+B} \f$ of the node is \f$f \f$, then according to the
34 training sample, the error rate (fraction of misclassified events by this
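35 node) is \f$(1-f) \f$. Now \f$f \f$ has a statistical error according to the
36 binomial distribution, hence the error on \f$f \f$ can be estimated (it is the
37 same as the binomial error used in efficiency calculations):
38 \f$\sigma = \sqrt{\frac{f(1-f)}{nEvts}} \f$. The expected error of the node is then taken as the upper limit \f$\min\left(1, (1-f) + s\,\sigma\right) \f$, where \f$s \f$ is the prune strength.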
39
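40 This tool prunes branches from a tree if the expected error of a node is less than
41 or equal to the expected error of its subtree, i.e. the event-count-weighted average of the expected errors of its daughter nodes (computed recursively down to the leaves).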
42
43 */
44
46 #include "TMVA/DecisionTree.h"
47 #include "TMVA/IPruneTool.h"
48 #include "TMVA/MsgLogger.h"
49 #include "TMVA/Types.h"
50
51 #include "RtypesCore.h"
52 #include "Rtypes.h"
53 #include "TMath.h"
54
55 #include <map>
56
57 ////////////////////////////////////////////////////////////////////////////////
58
// Constructor: default prune-search step, default node purity limit, and an
// owned message logger (released in the destructor).
60  IPruneTool(),
 // Step size for the automatic prune-strength search; in this file it is only
 // referenced inside the commented-out automatic-pruning code.
61  fDeltaPruneStrength(0),
 // NOTE(review): with a limit of 1, the `GetPurity() > fNodePurityLimit` test in
 // GetNodeError can never be true (assuming purity lies in [0,1]), so every node
 // would use f = 1 - purity. Presumably this is overwritten before pruning
 // (DecisionTree::GetNodePurityLimit() is referenced on this page) -- TODO confirm.
62  fNodePurityLimit(1),
 // Logger allocated with `new`; deleted in the destructor.
63  fLogger( new MsgLogger("ExpectedErrorPruneTool") )
64 {}
65
66 ////////////////////////////////////////////////////////////////////////////////
67
69 {
 // Release the MsgLogger allocated with `new` in the constructor.
 // NOTE(review): fLogger is a raw owning pointer. Unless copy construction and
 // copy assignment are disabled in the header (not visible here), copying this
 // tool would delete the same logger twice -- consider std::unique_ptr<MsgLogger>.
70  delete fLogger;
71 }
72
73 ////////////////////////////////////////////////////////////////////////////////
74
77  const IPruneTool::EventSample* validationSample,
78  Bool_t isAutomatic )
79 {
80  if( isAutomatic ) {
81  //SetAutomatic( );
82  isAutomatic = kFALSE;
83  Log() << kWARNING << "Sorry automatic pruning strength determination is not implemented yet" << Endl;
84  }
85  if( dt == NULL || (IsAutomatic() && validationSample == NULL) ) {
86  // must have a valid decision tree to prune, and if the prune strength
87  // is to be chosen automatically, must have a test sample from
88  // which to calculate the quality of the pruned tree(s)
89  return NULL;
90  }
92
93  if(IsAutomatic()) {
94  Log() << kFATAL << "Sorry automatic pruning strength determination is not implemented yet" << Endl;
95  /*
96  dt->ApplyValidationSample(validationSample);
97  Double_t weights = dt->GetSumWeights(validationSample);
98  // set the initial prune strength
99  fPruneStrength = 1.0e-3; //! FIXME somehow make this automatic, it depends strongly on the tree structure
100  // better to set it too small, it will be increased automatically
101  fDeltaPruneStrength = 1.0e-5;
102  Int_t nnodes = this->CountNodes((DecisionTreeNode*)dt->GetRoot());
103
104  Bool_t forceStop = kFALSE;
105  Int_t errCount = 0,
106  lastNodeCount = nnodes;
107
108  // find the maximum prune strength that still leaves the root's daughter nodes
109
110  while ( nnodes > 1 && !forceStop ) {
111  fPruneStrength += fDeltaPruneStrength;
112  Log() << "----------------------------------------------------" << Endl;
113  FindListOfNodes((DecisionTreeNode*)dt->GetRoot());
114  for( UInt_t i = 0; i < fPruneSequence.size(); i++ )
115  fPruneSequence[i]->SetTerminal(); // prune all the nodes from the sequence
116  // test the quality of the pruned tree
117  Double_t quality = 1.0 - dt->TestPrunedTreeQuality()/weights;
118  fQualityMap.insert(std::make_pair<const Double_t,Double_t>(quality,fPruneStrength));
119
120  nnodes = CountNodes((DecisionTreeNode*)dt->GetRoot()); // count the number of nodes in the pruned tree
121
122  Log() << "Prune strength : " << fPruneStrength << Endl;
123  Log() << "Had " << lastNodeCount << " nodes, now have " << nnodes << Endl;
124  Log() << "Quality index is: " << quality << Endl;
125
126  if (lastNodeCount == nnodes) errCount++;
127  else {
128  errCount=0; // reset counter
129  if ( nnodes < lastNodeCount / 2 ) {
130  Log() << "Decreasing fDeltaPruneStrength to " << fDeltaPruneStrength/2.0
131  << " because the number of nodes in the tree decreased by a factor of 2." << Endl;
132  fDeltaPruneStrength /= 2.;
133  }
134  }
135  lastNodeCount = nnodes;
136  if (errCount > 20) {
137  Log() << "Increasing fDeltaPruneStrength to " << fDeltaPruneStrength*2.0
138  << " because the number of nodes in the tree didn't change." << Endl;
139  fDeltaPruneStrength *= 2.0;
140  }
141  if (errCount > 40) {
142  Log() << "Having difficulty determining the optimal prune strength, bailing out!" << Endl;
143  forceStop = kTRUE;
144  }
145  // reset the tree for the next iteration
146  for( UInt_t i = 0; i < fPruneSequence.size(); i++ )
147  fPruneSequence[i]->SetTerminal(false);
148  fPruneSequence.clear();
149  }
150  // from the set of pruned trees, find the one with the optimal quality index
151  std::multimap<Double_t,Double_t>::reverse_iterator it = fQualityMap.rend(); ++it;
152  fPruneStrength = it->second;
153  FindListOfNodes((DecisionTreeNode*)dt->GetRoot());
154
155  // adjust the step size for the next tree automatically
156  fPruneStrength = 1.0e-3;
157  fDeltaPruneStrength = (fPruneStrength - 1.0)/(Double_t)fQualityMap.size();
158
159  return new PruningInfo(it->first, it->second, fPruneSequence);
160  */
161  return NULL;
162  }
163  else { // no automatic pruning - just use the provided prune strength parameter
165  return new PruningInfo( -1.0, fPruneStrength, fPruneSequence );
166  }
167 }
168
169 ////////////////////////////////////////////////////////////////////////////////
170 /// recursive pruning of nodes using the Expected Error Pruning (EEP)
171
173 {
176  if (node->GetNodeType() == 0 && !(node->IsTerminal())) { // check all internal nodes
177  this->FindListOfNodes(l);
178  this->FindListOfNodes(r);
179  if (this->GetSubTreeError(node) >= this->GetNodeError(node)) {
180  //node->Print(Log());
181  fPruneSequence.push_back(node);
182  }
183  }
184 }
185
186 ////////////////////////////////////////////////////////////////////////////////
187 /// calculate the expected statistical error on the subtree below "node"
188 /// which is used in the expected error pruning
189
191 {
194  if (node->GetNodeType() == 0 && !(node->IsTerminal())) {
195  Double_t subTreeError =
196  (l->GetNEvents() * this->GetSubTreeError(l) +
197  r->GetNEvents() * this->GetSubTreeError(r)) /
198  node->GetNEvents();
199  return subTreeError;
200  }
201  else {
202  return this->GetNodeError(node);
203  }
204 }
205
206 ////////////////////////////////////////////////////////////////////////////////
207 /// Calculate an UPPER limit on the error made by the classification done
208 /// by this node. If S/(S+B) of the node is f, then according to the
209 /// training sample, the error rate (fraction of misclassified events by
210 /// this node) is (1-f).
211 /// Now f has a statistical error according to the binomial distribution,
212 /// hence the error on f can be estimated (same error as the binomial error
213 /// used for efficiency calculations):
214 /// \f$\sigma = \sqrt{\frac{f(1-f)}{nEvts}} \f$; the returned upper limit is \f$\min\left(1, (1-f) + fPruneStrength \cdot \sigma\right) \f$.
215
217 {
 // Returns an upper limit on the misclassification rate of `node`:
 //   f     = purity if purity > fNodePurityLimit, otherwise 1 - purity
 //           (the fraction of events in the class the node is labelled as)
 //   df    = binomial standard error on f: sqrt(f*(1-f)/nEvts)
 //   error = min(1, (1 - f) + fPruneStrength * df)
218  Double_t errorRate = 0;
219
220  Double_t nEvts = node->GetNEvents();
 // NOTE(review): if nEvts is 0, the division in df below yields inf or NaN.
 // With a non-negative fPruneStrength, std::min(1.0, ...) then happens to
 // return 1.0, but an explicit guard for empty nodes would state the intent.
221
222  // fraction of correctly classified events by this node:
223  Double_t f = 0;
224  if (node->GetPurity() > fNodePurityLimit) f = node->GetPurity();
225  else f = (1-node->GetPurity());
226
227  Double_t df = TMath::Sqrt(f*(1-f)/nEvts);
228
 // (1 - f) is the training-sample error rate; adding fPruneStrength standard
 // errors turns it into an upper limit, which is capped at 1.
229  errorRate = std::min(1.0,(1.0 - (f-fPruneStrength*df)));
230
 // The alternatives below are kept for reference only; they are all commented out.
231  // -------------------------------------------------------------------
232  // standard algorithm:
233  // step 1: Estimate error on node using Laplace estimate
234  // NodeError = (N - n + k -1 ) / (N + k)
235  // N: number of events
236  // k: number of event classes (2 for Signal, Background)
237  // n: n events out of N belong to the class which has the majority in the node
238  // step 2: Approximate "backed-up" error assuming we did not prune
239  // (I'm never quite sure if they consider whole subtrees, or only 'next-to-leaf'
240  // nodes)...
241  // Subtree error = Sum_children ( P_i * NodeError_i)
242  // P_i = probability of the node to make the decision, i.e. fraction of events in
243  // leaf node ( N_leaf / N_parent)
244  // step 3: (not written down)
245
246  // Minimum Error Pruning (MEP) according to Niblett/Bratko
247  //# of correctly classified events by this node:
248  //Double_t n=f*nEvts ;
249  //Double_t p_apriori = 0.5, m=100;
250  //errorRate = (nEvts - n + (1-p_apriori) * m ) / (nEvts + m);
251
252  // Pessimistic Error Pruning (proposed by Quinlan; error estimate with continuity correction)
253  //# of correctly classified events by this node:
254  //Double_t n=f*nEvts ;
255  //errorRate = (nEvts - n + 0.5) / nEvts ;
256
 // Confidence bound on f using the Wilson score interval (z = Z):
257  //const Double Z=.65;
258  //# of correctly classified events by this node:
259  //Double_t n=f*nEvts ;
260  //errorRate = (f + Z*Z/(2*nEvts ) + Z*sqrt(f/nEvts - f*f/nEvts + Z*Z/4/nEvts /nEvts ) ) / (1 + Z*Z/nEvts );
261  //errorRate = (n + Z*Z/2 + Z*sqrt(n - n*n/nEvts + Z*Z/4) )/ (nEvts + Z*Z);
262  //errorRate = 1 - errorRate;
263  // -------------------------------------------------------------------
264
265  return errorRate;
266 }
267
MsgLogger & Endl(MsgLogger &ml)
Definition: MsgLogger.h:158
Float_t GetNEvents(void) const
Double_t fPruneStrength
Definition: IPruneTool.h:101
Double_t GetNodePurityLimit() const
Definition: DecisionTree.h:156
bool Bool_t
Definition: RtypesCore.h:59
std::vector< DecisionTreeNode * > fPruneSequence
the (optimal) prune sequence
#define NULL
Definition: RtypesCore.h:88
virtual DecisionTreeNode * GetRoot() const
Definition: DecisionTree.h:88
Bool_t IsAutomatic() const
Definition: IPruneTool.h:95
std::vector< const Event * > EventSample
Definition: IPruneTool.h:74
virtual PruningInfo * CalculatePruningInfo(DecisionTree *dt, const IPruneTool::EventSample *testEvents=NULL, Bool_t isAutomatic=kFALSE)
Int_t GetNodeType(void) const
TRandom2 r(17)
Double_t GetSubTreeError(DecisionTreeNode *node) const
calculate the expected statistical error on the subtree below "node" which is used in the expected er...
void FindListOfNodes(DecisionTreeNode *node)
recursive pruning of nodes using the Expected Error Pruning (EEP)
Double_t GetNodeError(DecisionTreeNode *node) const
Calculate an UPPER limit on the error made by the classification done by this node.
Implementation of a Decision Tree.
Definition: DecisionTree.h:59
TLine * l
Definition: textangle.C:4
const Bool_t kFALSE
Definition: RtypesCore.h:92
Float_t GetPurity(void) const
Double_t fNodePurityLimit
the purity limit for labelling a terminal node as signal
double f(double x)
double Double_t
Definition: RtypesCore.h:55
IPruneTool - a helper interface class to prune a decision tree.
Definition: IPruneTool.h:70
ostringstream derivative to redirect and format output
Definition: MsgLogger.h:59
virtual DecisionTreeNode * GetLeft() const
virtual DecisionTreeNode * GetRight() const
Double_t Sqrt(Double_t x)
Definition: TMath.h:591
MsgLogger * fLogger
message logger (created in the constructor, deleted in the destructor)