ROOT  6.06/09
Reference Guide
ExpectedErrorPruneTool.cxx
Go to the documentation of this file.
1 /**********************************************************************************
2  * Project: TMVA - a Root-integrated toolkit for multivariate data analysis *
3  * Package: TMVA *
4  * Class : TMVA::ExpectedErrorPruneTool *
5  * Web : http://tmva.sourceforge.net *
6  * *
7  * Description: *
8  * Implementation of the Expected Error Pruning (EEP) tool for decision trees *
9  * *
10  * Authors (alphabetical): *
11  * Andreas Hoecker <Andreas.Hocker@cern.ch> - CERN, Switzerland *
12  * Helge Voss <Helge.Voss@cern.ch> - MPI-K Heidelberg, Germany *
13  * Kai Voss <Kai.Voss@cern.ch> - U. of Victoria, Canada *
14  * Doug Schouten <dschoute@sfu.ca> - Simon Fraser U., Canada *
15  * *
16  * Copyright (c) 2005: *
17  * CERN, Switzerland *
18  * U. of Victoria, Canada *
19  * MPI-K Heidelberg, Germany *
20  * *
21  * Redistribution and use in source and binary forms, with or without *
22  * modification, are permitted according to the terms listed in LICENSE *
23  * (http://tmva.sourceforge.net/license.txt) *
24  * *
25  **********************************************************************************/
26 
28 #include "TMVA/MsgLogger.h"
29 
30 #include <map>
31 
32 ////////////////////////////////////////////////////////////////////////////////
33 
35  IPruneTool(),
36  fDeltaPruneStrength(0),
37  fNodePurityLimit(1),
38  fLogger( new MsgLogger("ExpectedErrorPruneTool") )
39 {}
40 
41 ////////////////////////////////////////////////////////////////////////////////
42 
44 {
45  delete fLogger;
46 }
47 
48 ////////////////////////////////////////////////////////////////////////////////
49 
52  const IPruneTool::EventSample* validationSample,
53  Bool_t isAutomatic )
54 {
55  if( isAutomatic ) {
56  //SetAutomatic( );
57  isAutomatic = kFALSE;
58  Log() << kWARNING << "Sorry autmoatic pruning strength determination is not implemented yet" << Endl;
59  }
60  if( dt == NULL || (IsAutomatic() && validationSample == NULL) ) {
61  // must have a valid decision tree to prune, and if the prune strength
62  // is to be chosen automatically, must have a test sample from
63  // which to calculate the quality of the pruned tree(s)
64  return NULL;
65  }
66  fNodePurityLimit = dt->GetNodePurityLimit();
67 
68  if(IsAutomatic()) {
69  Log() << kFATAL << "Sorry autmoatic pruning strength determination is not implemented yet" << Endl;
70  /*
71  dt->ApplyValidationSample(validationSample);
72  Double_t weights = dt->GetSumWeights(validationSample);
73  // set the initial prune strength
74  fPruneStrength = 1.0e-3; //! FIXME somehow make this automatic, it depends strongly on the tree structure
75  // better to set it too small, it will be increased automatically
76  fDeltaPruneStrength = 1.0e-5;
77  Int_t nnodes = this->CountNodes((DecisionTreeNode*)dt->GetRoot());
78 
79  Bool_t forceStop = kFALSE;
80  Int_t errCount = 0,
81  lastNodeCount = nnodes;
82 
83  // find the maximum prune strength that still leaves the root's daughter nodes
84 
85  while ( nnodes > 1 && !forceStop ) {
86  fPruneStrength += fDeltaPruneStrength;
87  Log() << "----------------------------------------------------" << Endl;
88  FindListOfNodes((DecisionTreeNode*)dt->GetRoot());
89  for( UInt_t i = 0; i < fPruneSequence.size(); i++ )
90  fPruneSequence[i]->SetTerminal(); // prune all the nodes from the sequence
91  // test the quality of the pruned tree
92  Double_t quality = 1.0 - dt->TestPrunedTreeQuality()/weights;
93  fQualityMap.insert(std::make_pair<const Double_t,Double_t>(quality,fPruneStrength));
94 
95  nnodes = CountNodes((DecisionTreeNode*)dt->GetRoot()); // count the number of nodes in the pruned tree
96 
97  Log() << "Prune strength : " << fPruneStrength << Endl;
98  Log() << "Had " << lastNodeCount << " nodes, now have " << nnodes << Endl;
99  Log() << "Quality index is: " << quality << Endl;
100 
101  if (lastNodeCount == nnodes) errCount++;
102  else {
103  errCount=0; // reset counter
104  if ( nnodes < lastNodeCount / 2 ) {
105  Log() << "Decreasing fDeltaPruneStrength to " << fDeltaPruneStrength/2.0
106  << " because the number of nodes in the tree decreased by a factor of 2." << Endl;
107  fDeltaPruneStrength /= 2.;
108  }
109  }
110  lastNodeCount = nnodes;
111  if (errCount > 20) {
112  Log() << "Increasing fDeltaPruneStrength to " << fDeltaPruneStrength*2.0
113  << " because the number of nodes in the tree didn't change." << Endl;
114  fDeltaPruneStrength *= 2.0;
115  }
116  if (errCount > 40) {
117  Log() << "Having difficulty determining the optimal prune strength, bailing out!" << Endl;
118  forceStop = kTRUE;
119  }
120  // reset the tree for the next iteration
121  for( UInt_t i = 0; i < fPruneSequence.size(); i++ )
122  fPruneSequence[i]->SetTerminal(false);
123  fPruneSequence.clear();
124  }
125  // from the set of pruned trees, find the one with the optimal quality index
126  std::multimap<Double_t,Double_t>::reverse_iterator it = fQualityMap.rend(); ++it;
127  fPruneStrength = it->second;
128  FindListOfNodes((DecisionTreeNode*)dt->GetRoot());
129 
130  // adjust the step size for the next tree automatically
131  fPruneStrength = 1.0e-3;
132  fDeltaPruneStrength = (fPruneStrength - 1.0)/(Double_t)fQualityMap.size();
133 
134  return new PruningInfo(it->first, it->second, fPruneSequence);
135  */
136  return NULL;
137  }
138  else { // no automatic pruning - just use the provided prune strength parameter
139  FindListOfNodes( (DecisionTreeNode*)dt->GetRoot() );
140  return new PruningInfo( -1.0, fPruneStrength, fPruneSequence );
141  }
142 }
143 
144 ////////////////////////////////////////////////////////////////////////////////
145 /// recursive pruning of nodes using the Expected Error Pruning (EEP)
146 
148 {
151  if (node->GetNodeType() == 0 && !(node->IsTerminal())) { // check all internal nodes
152  this->FindListOfNodes(l);
153  this->FindListOfNodes(r);
154  if (this->GetSubTreeError(node) >= this->GetNodeError(node)) {
155  //node->Print(Log());
156  fPruneSequence.push_back(node);
157  }
158  }
159 }
160 
161 ////////////////////////////////////////////////////////////////////////////////
162 /// calculate the expected statistical error on the subtree below "node"
163 /// which is used in the expected error pruning
164 
166 {
169  if (node->GetNodeType() == 0 && !(node->IsTerminal())) {
170  Double_t subTreeError =
171  (l->GetNEvents() * this->GetSubTreeError(l) +
172  r->GetNEvents() * this->GetSubTreeError(r)) /
173  node->GetNEvents();
174  return subTreeError;
175  }
176  else {
177  return this->GetNodeError(node);
178  }
179 }
180 
181 ////////////////////////////////////////////////////////////////////////////////
182 /// Calculate an UPPER limit on the error made by the classification done
183 /// by this node. If the S/(S+B) of the node is f, then according to the
184 /// training sample, the error rate (fraction of misclassified events by
185 /// this node) is (1-f).
186 /// Now f has a statistical error according to the binomial distribution,
187 /// hence the error on f can be estimated (same error as the binomial error
188 /// for efficiency calculations: sigma = sqrt(eff(1-eff)/nEvts) ).
189 
191 {
192  Double_t errorRate = 0;
193 
194  Double_t nEvts = node->GetNEvents();
195 
196  // fraction of correctly classified events by this node:
197  Double_t f = 0;
198  if (node->GetPurity() > fNodePurityLimit) f = node->GetPurity();
199  else f = (1-node->GetPurity());
200 
201  Double_t df = TMath::Sqrt(f*(1-f)/nEvts);
202 
203  errorRate = std::min(1.0,(1.0 - (f-fPruneStrength*df)));
204 
205  // -------------------------------------------------------------------
206  // standard algorithm:
207  // step 1: Estimate error on node using Laplace estimate
208  // NodeError = (N - n + k -1 ) / (N + k)
209  // N: number of events
210  // k: number of event classes (2 for Signal, Background)
211  // n: n event out of N belong to the class which has the majority in the node
212  // step 2: Approximate "backed-up" error assuming we did not prune
213  // (I'm never quite sure if they consider whole subtrees, or only 'next-to-leaf'
214  // nodes)...
215  // Subtree error = Sum_children ( P_i * NodeError_i)
216  // P_i = probability of the node to make the decision, i.e. fraction of events in
217  // leaf node ( N_leaf / N_parent)
218  // step 3:
219 
220  // Minimum Error Pruning (MEP) according to Niblett/Bratko
221  //# of correctly classified events by this node:
222  //Double_t n=f*nEvts ;
223  //Double_t p_apriori = 0.5, m=100;
224  //errorRate = (nEvts - n + (1-p_apriori) * m ) / (nEvts + m);
225 
226  // Pessimistic Error Pruning (proposed by Quinlan; error estimate with continuity approximation)
227  //# of correctly classified events by this node:
228  //Double_t n=f*nEvts ;
229  //errorRate = (nEvts - n + 0.5) / nEvts ;
230 
231  //const Double Z=.65;
232  //# of correctly classified events by this node:
233  //Double_t n=f*nEvts ;
234  //errorRate = (f + Z*Z/(2*nEvts ) + Z*sqrt(f/nEvts - f*f/nEvts + Z*Z/4/nEvts /nEvts ) ) / (1 + Z*Z/nEvts );
235  //errorRate = (n + Z*Z/2 + Z*sqrt(n - n*n/nEvts + Z*Z/4) )/ (nEvts + Z*Z);
236  //errorRate = 1 - errorRate;
237  // -------------------------------------------------------------------
238 
239  return errorRate;
240 }
241 
242 
static Vc_ALWAYS_INLINE int_v min(const int_v &x, const int_v &y)
Definition: vector.h:433
MsgLogger & Endl(MsgLogger &ml)
Definition: MsgLogger.h:162
Double_t GetNodePurityLimit() const
Definition: DecisionTree.h:170
virtual DecisionTreeNode * GetRight() const
Int_t GetNodeType(void) const
bool Bool_t
Definition: RtypesCore.h:59
const Bool_t kFALSE
Definition: Rtypes.h:92
virtual DecisionTreeNode * GetLeft() const
virtual DecisionTreeNode * GetRoot() const
Definition: DecisionTree.h:102
Float_t GetPurity(void) const
std::vector< const Event * > EventSample
Definition: IPruneTool.h:76
virtual PruningInfo * CalculatePruningInfo(DecisionTree *dt, const IPruneTool::EventSample *testEvents=NULL, Bool_t isAutomatic=kFALSE)
ROOT::R::TRInterface & r
Definition: Object.C:4
void FindListOfNodes(DecisionTreeNode *node)
recursive pruning of nodes using the Expected Error Pruning (EEP)
Float_t GetNEvents(void) const
TLine * l
Definition: textangle.C:4
Bool_t IsTerminal() const
double f(double x)
double Double_t
Definition: RtypesCore.h:55
Double_t GetNodeError(DecisionTreeNode *node) const
Calculate an UPPER limit on the error made by the classification done by this node.
#define NULL
Definition: Rtypes.h:82
Double_t Sqrt(Double_t x)
Definition: TMath.h:464
Double_t GetSubTreeError(DecisionTreeNode *node) const
calculate the expected statistical error on the subtree below "node" which is used in the expected er...
Definition: math.cpp:60