Logo ROOT  
Reference Guide
CostComplexityPruneTool.cxx
Go to the documentation of this file.
1 /**********************************************************************************
2  * Project: TMVA - a Root-integrated toolkit for multivariate data analysis *
3  * Package: TMVA *
4  * Class : TMVA::CostComplexityPruneTool *
5  * Web : http://tmva.sourceforge.net *
6  * *
7  * Description: *
8  * Implementation of Cost Complexity pruning for a Decision Tree *
9  * *
10  * Authors (alphabetical): *
11  * Andreas Hoecker <Andreas.Hocker@cern.ch> - CERN, Switzerland *
12  * Helge Voss <Helge.Voss@cern.ch> - MPI-K Heidelberg, Germany *
13  * Kai Voss <Kai.Voss@cern.ch> - U. of Victoria, Canada *
14  * Doug Schouten <dschoute@sfu.ca> - Simon Fraser U., Canada *
15  * *
16  * Copyright (c) 2005: *
17  * CERN, Switzerland *
18  * U. of Victoria, Canada *
19  * MPI-K Heidelberg, Germany *
20  * *
21  * Redistribution and use in source and binary forms, with or without *
22  * modification, are permitted according to the terms listed in LICENSE *
23  * (http://mva.sourceforge.net/license.txt) *
24  * *
25  **********************************************************************************/
26 
27 /*! \class TMVA::CostComplexityPruneTool
28 \ingroup TMVA
29 A class to prune a decision tree using the Cost Complexity method.
30 (see "Classification and Regression Trees" by Leo Breiman et al)
31 
32 ### Some definitions:
33 
34  - \f$ T_{max} \f$ - the initial, usually highly overtrained tree, that is to be pruned back
35  - \f$ R(T) \f$ - quality index (Gini, misclassification rate, or other) of a tree \f$ T \f$
36  - \f$ \sim T \f$ - set of terminal nodes in \f$ T \f$
37  - \f$ T' \f$ - the pruned subtree of \f$ T_{max} \f$ that has the best quality index \f$ R(T') \f$
38  - \f$ \alpha \f$ - the prune strength parameter in Cost Complexity pruning \f$ (R_{\alpha}(T) = R(T) + \alpha*|\sim T|) \f$
39 
40 There are two running modes in CCPruner: (i) one may select a prune strength and prune back
41 the tree \f$ T_{max}\f$ until the criterion:
42 \f[
43  \alpha < \frac{R(t) - R(T_t)}{|\sim T_t| - 1}
44 \f]
45 
46 is true for all nodes t in \f$ T \f$, or (ii) the algorithm finds the sequence of critical points
47 \f$ \alpha_k < \alpha_{k+1} ... < \alpha_K \f$ such that \f$ T_K = root(T_{max}) \f$ and then selects the optimally-pruned
48 subtree, defined to be the subtree with the best quality index for the validation sample.
49 */
50 
52 
53 #include "TMVA/MsgLogger.h"
54 #include "TMVA/SeparationBase.h"
55 #include "TMVA/DecisionTree.h"
56 
57 #include "RtypesCore.h"
58 
59 #include <limits>
60 #include <cmath>
61 
62 using namespace TMVA;
63 
64 
65 ////////////////////////////////////////////////////////////////////////////////
66 /// the constructor for the cost complexity pruning
67 
69  IPruneTool(),
70  fLogger(new MsgLogger("CostComplexityPruneTool") )
71 {
72  fOptimalK = -1;
73 
74  // !! Changed from Doug's code: the QualityIndex already stored in the
75  // nodes is now used when no "new" QualityIndex calculator is given. This
76  // makes it easy to implement regression: there, the pruning uses the same
77  // separation index as in the tree building, and hence doesn't need to be re-calculated
78  // (which would need more info than simply "s" and "b")
79 
80  fQualityIndexTool = qualityIndex;
81 
82  //fLogger->SetMinType( kDEBUG );
83  fLogger->SetMinType( kWARNING );
84 }
85 
86 ////////////////////////////////////////////////////////////////////////////////
87 /// the destructor for the cost complexity pruning
88 
90  if(fQualityIndexTool != NULL) delete fQualityIndexTool;
91 }
92 
93 ////////////////////////////////////////////////////////////////////////////////
94 /// The routine that "steers" the pruning process: it triggers the calculation of
95 /// the pruning sequence, the tree quality, and the like.
96 
99  const IPruneTool::EventSample* validationSample,
100  Bool_t isAutomatic )
101 {
102  if( isAutomatic ) SetAutomatic();
103 
104  if( dt == NULL || (IsAutomatic() && validationSample == NULL) ) {
105  // must have a valid decision tree to prune, and if the prune strength
106  // is to be chosen automatically, must have a test sample from
107  // which to calculate the quality of the pruned tree(s)
108  return NULL;
109  }
110 
111  Double_t Q = -1.0;
112  Double_t W = 1.0;
113 
114  if(IsAutomatic()) {
115  // run the pruning validation sample through the unpruned tree
116  dt->ApplyValidationSample(validationSample);
117  W = dt->GetSumWeights(validationSample); // get the sum of weights in the pruning validation sample
118  // calculate the quality of the tree in the unpruned case
119  Q = dt->TestPrunedTreeQuality();
120 
121  Log() << kDEBUG << "Node purity limit is: " << dt->GetNodePurityLimit() << Endl;
122  Log() << kDEBUG << "Sum of weights in pruning validation sample: " << W << Endl;
123  Log() << kDEBUG << "Quality of tree prior to any pruning is " << Q/W << Endl;
124  }
125 
126  // store the cost complexity metadata for the decision tree at each node
127  try {
129  }
130  catch(const std::string &error) {
131  Log() << kERROR << "Couldn't initialize the tree meta data because of error ("
132  << error << ")" << Endl;
133  return NULL;
134  }
135 
136  Log() << kDEBUG << "Automatic cost complexity pruning is " << (IsAutomatic()?"on":"off") << "." << Endl;
137 
138  try {
139  Optimize( dt, W ); // run the cost complexity pruning algorithm
140  }
141  catch(const std::string &error) {
142  Log() << kERROR << "Error optimizing pruning sequence ("
143  << error << ")" << Endl;
144  return NULL;
145  }
146 
147  Log() << kDEBUG << "Index of pruning sequence to stop at: " << fOptimalK << Endl;
148 
149  PruningInfo* info = new PruningInfo();
150 
151 
152  if(fOptimalK < 0) {
153  // no pruning necessary, or wasn't able to compute a sequence
154  info->PruneStrength = 0;
155  info->QualityIndex = Q/W;
156  info->PruneSequence.clear();
157  Log() << kINFO << "no proper pruning could be calculated. Tree "
158  << dt->GetTreeID() << " will not be pruned. Do not worry if this "
159  << " happens for a few trees " << Endl;
160  return info;
161  }
163  Log() << kDEBUG << " prune until k=" << fOptimalK << " with alpha="<<fPruneStrengthList[fOptimalK]<< Endl;
164  for( Int_t i = 0; i < fOptimalK; i++ ){
165  info->PruneSequence.push_back(fPruneSequence[i]);
166  }
167  if( IsAutomatic() ){
169  }
170  else {
172  }
173 
174  return info;
175 }
176 
177 ////////////////////////////////////////////////////////////////////////////////
178 /// Initialise the "meta data" for the pruning (the cost complexity, the
179 /// critical alpha, the minimal alpha down the tree, etc.) for each node.
180 
182  if( n == NULL ) return;
183 
184  Double_t s = n->GetNSigEvents();
185  Double_t b = n->GetNBkgEvents();
186  // set R(t) = N_events*Gini(t) or MisclassificationError(t), etc.
188  else n->SetNodeR( (s+b)*n->GetSeparationIndex() );
189 
190  if(n->GetLeft() != NULL && n->GetRight() != NULL) { // n is an interior (non-leaf) node
191  n->SetTerminal(kFALSE);
192  // traverse the tree
193  InitTreePruningMetaData(n->GetLeft());
194  InitTreePruningMetaData(n->GetRight());
195  // set |~T_t|
196  n->SetNTerminal( n->GetLeft()->GetNTerminal() +
197  n->GetRight()->GetNTerminal());
198  // set R(T) = sum[n' in ~T]{ R(n') }
199  n->SetSubTreeR( (n->GetLeft()->GetSubTreeR() +
200  n->GetRight()->GetSubTreeR()));
201  // set alpha_c, the alpha value at which it becomes advantageous to prune at node n
202  n->SetAlpha( ((n->GetNodeR() - n->GetSubTreeR()) /
203  (n->GetNTerminal() - 1)));
204 
205  // G(t) = min( alpha_c, G(l(n)), G(r(n)) )
206  // the minimum alpha in subtree rooted at this node
207  n->SetAlphaMinSubtree( std::min(n->GetAlpha(), std::min(n->GetLeft()->GetAlphaMinSubtree(),
208  n->GetRight()->GetAlphaMinSubtree())));
209  n->SetCC(n->GetAlpha());
210 
211  } else { // n is a terminal node
212  n->SetNTerminal( 1 ); n->SetTerminal( );
213  if (fQualityIndexTool) n->SetSubTreeR(((s+b)*fQualityIndexTool->GetSeparationIndex(s,b)));
214  else n->SetSubTreeR( (s+b)*n->GetSeparationIndex() );
215  n->SetAlpha(std::numeric_limits<double>::infinity( ));
216  n->SetAlphaMinSubtree(std::numeric_limits<double>::infinity( ));
217  n->SetCC(n->GetAlpha());
218  }
219 
220  // DecisionTreeNode* R = (DecisionTreeNode*)mdt->GetRoot();
221  // Double_t x = R->GetAlphaMinSubtree();
222  // Log() << "alphaMin(Root) = " << x << Endl;
223 }
224 
225 
226 ////////////////////////////////////////////////////////////////////////////////
227 /// After the critical \f$ \alpha \f$ values (at which the corresponding nodes would
228 /// be pruned away) have been established in InitTreePruningMetaData, we now need:
229 /// - automatic pruning:
230 ///
231 /// find the value of \f$ \alpha \f$ for which the test sample gives minimal error,
232 /// on the tree with all nodes pruned that have \f$ \alpha_{critical} < \alpha \f$,
233 /// - fixed parameter pruning:
234 /// cut the pruning sequence at the relative position given by the prune strength (in percent).
235 
237  Int_t k = 1;
238  Double_t alpha = -1.0e10;
240 
241  fQualityIndexList.clear();
242  fPruneSequence.clear();
243  fPruneStrengthList.clear();
244 
246 
247  Double_t qmin = 0.0;
248  if(IsAutomatic()){
249  // initialize the tree quality (at this stage, it is the quality of the as yet unpruned tree)
250  qmin = dt->TestPrunedTreeQuality()/weights;
251  }
252 
253  // Now prune the tree in steps until it is gone. At each pruning step, the pruning
254  // takes place at the node that is regarded as the "weakest link".
255  // For automatic pruning, the current quality of the tree is calculated at each
256  // step, and in the end the tree is pruned at the minimum of the tree quality.
257  // For fixed-parameter pruning, the cut is simply set at a relative position
258  // in the sequence according to the "length" of the sequence of pruned trees:
259  // 100: at the end (pruned until the root node would be the next pruning candidate)
260  // 50: in the middle of the sequence
261  // etc...
262  while(R->GetNTerminal() > 1) { // prune upwards to the root node
263 
264  // initialize alpha
265  alpha = TMath::Max(R->GetAlphaMinSubtree(), alpha);
266 
267  if( R->GetAlphaMinSubtree() >= R->GetAlpha() ) {
268  Log() << kDEBUG << "\nCaught trying to prune the root node!" << Endl;
269  break;
270  }
271 
272 
273  DecisionTreeNode* t = R;
274 
275  // descend to the weakest link
276  while(t->GetAlphaMinSubtree() < t->GetAlpha()) {
277  // std::cout << t->GetAlphaMinSubtree() << " " << t->GetAlpha()<< " "
278  // << t->GetAlphaMinSubtree()- t->GetAlpha()<< " t==R?" << int(t == R) << std::endl;
279  // while( (t->GetAlphaMinSubtree() - t->GetAlpha()) < epsilon) {
280  // if(TMath::Abs(t->GetAlphaMinSubtree() - t->GetLeft()->GetAlphaMinSubtree())/TMath::Abs(t->GetAlphaMinSubtree()) < epsilon) {
282  t = t->GetLeft();
283  } else {
284  t = t->GetRight();
285  }
286  }
287 
288  if( t == R ) {
289  Log() << kDEBUG << "\nCaught trying to prune the root node!" << Endl;
290  break;
291  }
292 
293  DecisionTreeNode* n = t;
294 
295  // Log() << kDEBUG << "alpha[" << k << "]: " << alpha << Endl;
296  // Log() << kDEBUG << "===========================" << Endl
297  // << "Pruning branch listed below the node" << Endl;
298  // t->Print( Log() );
299  // Log() << kDEBUG << "===========================" << Endl;
300  // t->PrintRecPrune( Log() );
301 
302  dt->PruneNodeInPlace(t); // prune the branch rooted at node t
303 
304  while(t != R) { // go back up the (pruned) tree and recalculate R(T), alpha_c
305  t = t->GetParent();
306  t->SetNTerminal(t->GetLeft()->GetNTerminal() + t->GetRight()->GetNTerminal());
307  t->SetSubTreeR(t->GetLeft()->GetSubTreeR() + t->GetRight()->GetSubTreeR());
308  t->SetAlpha((t->GetNodeR() - t->GetSubTreeR())/(t->GetNTerminal() - 1));
309  t->SetAlphaMinSubtree(std::min(t->GetAlpha(), std::min(t->GetLeft()->GetAlphaMinSubtree(),
310  t->GetRight()->GetAlphaMinSubtree())));
311  t->SetCC(t->GetAlpha());
312  }
313  k += 1;
314 
315  Log() << kDEBUG << "after this pruning step I would have " << R->GetNTerminal() << " remaining terminal nodes " << Endl;
316 
317  if(IsAutomatic()) {
318  Double_t q = dt->TestPrunedTreeQuality()/weights;
319  fQualityIndexList.push_back(q);
320  }
321  else {
322  fQualityIndexList.push_back(1.0);
323  }
324  fPruneSequence.push_back(n);
325  fPruneStrengthList.push_back(alpha);
326  }
327 
328  if(fPruneSequence.empty()) {
329  fOptimalK = -1;
330  return;
331  }
332 
333  if(IsAutomatic()) {
334  k = -1;
335  for(UInt_t i = 0; i < fQualityIndexList.size(); i++) {
336  if(fQualityIndexList[i] < qmin) {
337  qmin = fQualityIndexList[i];
338  k = i;
339  }
340  }
341  fOptimalK = k;
342  }
343  else {
344  // regularize the prune strength relative to this tree
345  fOptimalK = int(fPruneStrength/100.0 * fPruneSequence.size() );
346  Log() << kDEBUG << "SequenzeSize="<<fPruneSequence.size()
347  << " fOptimalK " << fOptimalK << Endl;
348 
349  }
350 
351  Log() << kDEBUG << "\n************ Summary for Tree " << dt->GetTreeID() << " *******" << Endl
352  << "Number of trees in the sequence: " << fPruneSequence.size() << Endl;
353 
354  Log() << kDEBUG << "Pruning strength parameters: [";
355  for(UInt_t i = 0; i < fPruneStrengthList.size()-1; i++)
356  Log() << kDEBUG << fPruneStrengthList[i] << ", ";
357  Log() << kDEBUG << fPruneStrengthList[fPruneStrengthList.size()-1] << "]" << Endl;
358 
359  Log() << kDEBUG << "Misclassification rates: [";
360  for(UInt_t i = 0; i < fQualityIndexList.size()-1; i++)
361  Log() << kDEBUG << fQualityIndexList[i] << ", ";
362  Log() << kDEBUG << fQualityIndexList[fQualityIndexList.size()-1] << "]" << Endl;
363 
364  Log() << kDEBUG << "Prune index: " << fOptimalK+1 << Endl;
365 
366 }
367 
n
const Int_t n
Definition: legend1.C:16
TMVA::IPruneTool
Definition: IPruneTool.h:89
TMVA::PruningInfo
Definition: IPruneTool.h:58
TMVA::DecisionTreeNode::GetAlpha
Double_t GetAlpha() const
Definition: DecisionTreeNode.h:330
TMVA::CostComplexityPruneTool::fQualityIndexTool
SeparationBase * fQualityIndexTool
Definition: CostComplexityPruneTool.h:119
TMVA::IPruneTool::EventSample
std::vector< const Event * > EventSample
Definition: IPruneTool.h:93
TMVA::CostComplexityPruneTool::Log
MsgLogger & Log() const
output stream to save logging information
Definition: CostComplexityPruneTool.h:135
TMVA::DecisionTreeNode::SetCC
void SetCC(Double_t cc)
Definition: DecisionTreeNode.cxx:403
TMVA::PruningInfo::PruneStrength
Double_t PruneStrength
quality measure for a pruned subtree T of T_max
Definition: IPruneTool.h:84
TMath::Max
Short_t Max(Short_t a, Short_t b)
Definition: TMathBase.h:212
CostComplexityPruneTool.h
TMVA::DecisionTreeNode::SetNTerminal
void SetNTerminal(Int_t n)
Definition: DecisionTreeNode.h:337
TMVA::DecisionTreeNode::GetNodeR
Double_t GetNodeR() const
Definition: DecisionTreeNode.h:320
TMVA::SeparationBase::GetSeparationIndex
virtual Double_t GetSeparationIndex(const Double_t s, const Double_t b)=0
TMVA::DecisionTreeNode::GetSubTreeR
Double_t GetSubTreeR() const
Definition: DecisionTreeNode.h:324
TGeant4Unit::s
static constexpr double s
Definition: TGeant4SystemOfUnits.h:168
TMVA::PruningInfo::PruneSequence
std::vector< DecisionTreeNode * > PruneSequence
the regularization parameter for pruning
Definition: IPruneTool.h:85
TMVA::DecisionTreeNode
Definition: DecisionTreeNode.h:141
TMath::Abs
Short_t Abs(Short_t d)
Definition: TMathBase.h:120
TMVA::DecisionTree::TestPrunedTreeQuality
Double_t TestPrunedTreeQuality(const DecisionTreeNode *dt=NULL, Int_t mode=0) const
return the misclassification rate of a pruned tree a "pruned tree" may have set the variable "IsTermi...
Definition: DecisionTree.cxx:1043
b
#define b(i)
Definition: RSha256.hxx:118
TMVA::DecisionTree
Definition: DecisionTree.h:65
bool
q
float * q
Definition: THbookFile.cxx:89
TMVA::CostComplexityPruneTool::InitTreePruningMetaData
void InitTreePruningMetaData(DecisionTreeNode *n)
the optimal index of the prune sequence
Definition: CostComplexityPruneTool.cxx:181
TMVA::DecisionTree::GetNodePurityLimit
Double_t GetNodePurityLimit() const
Definition: DecisionTree.h:162
R
#define R(a, b, c, d, e, f, g, h, i)
Definition: RSha256.hxx:128
TMVA::DecisionTree::GetRoot
virtual DecisionTreeNode * GetRoot() const
Definition: DecisionTree.h:94
SeparationBase.h
MsgLogger.h
DecisionTree.h
TMVA::CostComplexityPruneTool::fQualityIndexList
std::vector< Double_t > fQualityIndexList
map of alpha -> pruning index
Definition: CostComplexityPruneTool.h:123
TMVA::DecisionTree::GetTreeID
Int_t GetTreeID()
Definition: DecisionTree.h:186
TMVA::DecisionTreeNode::GetLeft
virtual DecisionTreeNode * GetLeft() const
Definition: DecisionTreeNode.h:306
TMVA::CostComplexityPruneTool::fLogger
MsgLogger * fLogger
Definition: CostComplexityPruneTool.h:134
epsilon
REAL epsilon
Definition: triangle.c:617
TMVA::CostComplexityPruneTool::CostComplexityPruneTool
CostComplexityPruneTool(SeparationBase *qualityIndex=NULL)
the constructor for the cost complexity pruning
Definition: CostComplexityPruneTool.cxx:68
kFALSE
const Bool_t kFALSE
Definition: RtypesCore.h:92
TMVA::MsgLogger::SetMinType
void SetMinType(EMsgType minType)
Definition: MsgLogger.h:120
TMVA::CostComplexityPruneTool::CalculatePruningInfo
virtual PruningInfo * CalculatePruningInfo(DecisionTree *dt, const IPruneTool::EventSample *testEvents=NULL, Bool_t isAutomatic=kFALSE)
the routine that basically "steers" the pruning process.
Definition: CostComplexityPruneTool.cxx:98
TMVA::DecisionTree::ApplyValidationSample
void ApplyValidationSample(const EventConstList *validationSample) const
run the validation sample through the (pruned) tree and fill in the nodes the variables NSValidation ...
Definition: DecisionTree.cxx:1029
TMVA::DecisionTreeNode::SetSubTreeR
void SetSubTreeR(Double_t r)
Definition: DecisionTreeNode.h:323
TMVA::IPruneTool::IsAutomatic
Bool_t IsAutomatic() const
Definition: IPruneTool.h:114
TMVA::Endl
MsgLogger & Endl(MsgLogger &ml)
Definition: MsgLogger.h:182
unsigned int
TMVA::CostComplexityPruneTool::Optimize
void Optimize(DecisionTree *dt, Double_t weights)
after the critical values (at which the corresponding nodes would be pruned away) had been establish...
Definition: CostComplexityPruneTool.cxx:236
TMVA::CostComplexityPruneTool::fPruneStrengthList
std::vector< Double_t > fPruneStrengthList
map of weakest links (i.e., branches to prune) -> pruning index
Definition: CostComplexityPruneTool.h:122
TMVA::SeparationBase
Definition: SeparationBase.h:121
TMVA::CostComplexityPruneTool::fPruneSequence
std::vector< DecisionTreeNode * > fPruneSequence
the quality index used to calculate R(t), R(T) = sum[t in ~T]{ R(t) }
Definition: CostComplexityPruneTool.h:121
Double_t
double Double_t
Definition: RtypesCore.h:59
RtypesCore.h
TMVA::MsgLogger
Definition: MsgLogger.h:83
TMVA::CostComplexityPruneTool::fOptimalK
Int_t fOptimalK
map of R(T) -> pruning index
Definition: CostComplexityPruneTool.h:125
TMVA::DecisionTreeNode::GetAlphaMinSubtree
Double_t GetAlphaMinSubtree() const
Definition: DecisionTreeNode.h:334
TMVA::DecisionTree::GetSumWeights
Double_t GetSumWeights(const EventConstList *validationSample) const
calculate the normalization factor for a pruning validation sample
Definition: DecisionTree.cxx:1118
TMVA::IPruneTool::fPruneStrength
Double_t fPruneStrength
Definition: IPruneTool.h:120
TMVA::DecisionTreeNode::GetNTerminal
Int_t GetNTerminal() const
Definition: DecisionTreeNode.h:338
TMVA::DecisionTreeNode::SetAlpha
void SetAlpha(Double_t alpha)
Definition: DecisionTreeNode.h:329
TMVA::DecisionTree::PruneNodeInPlace
void PruneNodeInPlace(TMVA::DecisionTreeNode *node)
prune a node temporarily (without actually deleting its descendants which allows testing the pruned t...
Definition: DecisionTree.cxx:1217
TMVA::CostComplexityPruneTool::~CostComplexityPruneTool
virtual ~CostComplexityPruneTool()
the destructor for the cost complexity pruning
Definition: CostComplexityPruneTool.cxx:89
TMVA::IPruneTool::SetAutomatic
void SetAutomatic()
Definition: IPruneTool.h:113
TMVA::PruningInfo::QualityIndex
Double_t QualityIndex
Definition: IPruneTool.h:83
TMVA::DecisionTreeNode::GetRight
virtual DecisionTreeNode * GetRight() const
Definition: DecisionTreeNode.h:307
TMVA::DecisionTreeNode::GetParent
virtual DecisionTreeNode * GetParent() const
Definition: DecisionTreeNode.h:308
TMVA
create variable transformations
Definition: GeneticMinimizer.h:22
int
TMVA::DecisionTreeNode::SetAlphaMinSubtree
void SetAlphaMinSubtree(Double_t g)
Definition: DecisionTreeNode.h:333
ROOT::Math::Cephes::Q
static double Q[]
Definition: SpecFuncCephes.cxx:294