Logo ROOT  
Reference Guide
CostComplexityPruneTool.h
Go to the documentation of this file.
1 /**********************************************************************************
2  * Project: TMVA - a Root-integrated toolkit for multivariate data analysis *
3  * Package: TMVA *
4  * Class : TMVA::CostComplexityPruneTool *
5  * Web : http://tmva.sourceforge.net *
6  * *
7  * Description: *
8  * Cost Complexity pruning of a Decision Tree *
9  * *
10  * Authors (alphabetical): *
11  * Andreas Hoecker <Andreas.Hocker@cern.ch> - CERN, Switzerland *
12  * Helge Voss <Helge.Voss@cern.ch> - MPI-K Heidelberg, Germany *
13  * Kai Voss <Kai.Voss@cern.ch> - U. of Victoria, Canada *
14  * Doug Schouten <dschoute@sfu.ca> - Simon Fraser U., Canada *
15  * *
16  * Copyright (c) 2005: *
17  * CERN, Switzerland *
18  * U. of Victoria, Canada *
19  * MPI-K Heidelberg, Germany *
20  * *
21  * Redistribution and use in source and binary forms, with or without *
22  * modification, are permitted according to the terms listed in LICENSE *
23  * (http://mva.sourceforge.net/license.txt) *
24  * *
25  **********************************************************************************/
26 
27 #ifndef ROOT_TMVA_CostComplexityPruneTool
28 #define ROOT_TMVA_CostComplexityPruneTool
29 
30 ////////////////////////////////////////////////////////////////////////////////////////////////////////////
31 // CostComplexityPruneTool - a class to prune a decision tree using the Cost Complexity method //
32 // (see "Classification and Regression Trees" by Leo Breiman et al) //
33 // //
34 // Some definitions: //
35 // //
36 // T_max - the initial, usually highly overtrained tree, that is to be pruned back //
37 // R(T) - quality index (Gini, misclassification rate, or other) of a tree T //
38 // ~T - set of terminal nodes in T //
39 // T' - the pruned subtree of T_max that has the best quality index R(T') //
40 // alpha - the prune strength parameter in Cost Complexity pruning (R_alpha(T) = R(T) + alpha*|~T|) //
41 // //
42 // There are two running modes in CostComplexityPruneTool: (i) one may select a prune strength and prune //
43 // the tree T_max until the criterion //
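44 // alpha < ( R(t) - R(T_t) ) / ( |~T_t| - 1 ) //
45 // where T_t is the branch of T rooted at node t, and R(t) is the quality index of t taken as a single terminal node //
46 // //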
47 // //
48 // is true for all nodes t in T, or (ii) the algorithm finds the sequence of critical points //
49 // alpha_1 < alpha_2 < ... < alpha_K such that T_K = root(T_max) and then selects the optimally-pruned //
50 // subtree, defined to be the subtree with the best quality index for the validation sample. //
51 ////////////////////////////////////////////////////////////////////////////////////////////////////////////
52 
53 #include "TMVA/SeparationBase.h"
54 #include "TMVA/GiniIndex.h"
55 #include "TMVA/DecisionTree.h"
56 #include "TMVA/Event.h"
57 #include "TMVA/IPruneTool.h"
58 #include <vector>
59 
60 namespace TMVA {
61 
62  class CostComplexityPruneTool : public IPruneTool {
63  public:
64  CostComplexityPruneTool( SeparationBase* qualityIndex = NULL );
65  virtual ~CostComplexityPruneTool( );
66 
67  // calculate the prune sequence for a given tree
68  virtual PruningInfo* CalculatePruningInfo( DecisionTree* dt, const IPruneTool::EventSample* testEvents = NULL, Bool_t isAutomatic = kFALSE );
69 
70  private:
71  SeparationBase* fQualityIndexTool; //! the quality index used to calculate R(t), R(T) = sum[t in ~T]{ R(t) }
72 
73  std::vector<DecisionTreeNode*> fPruneSequence; //! map of weakest links (i.e., branches to prune) -> pruning index
74  std::vector<Double_t> fPruneStrengthList; //! map of alpha -> pruning index
75  std::vector<Double_t> fQualityIndexList; //! map of R(T) -> pruning index
76 
77  Int_t fOptimalK; //! the optimal index of the prune sequence
78 
79  private:
80  // set the meta data used for cost complexity pruning
81  void InitTreePruningMetaData( DecisionTreeNode* n );
82 
83  // optimize the pruning sequence
84  void Optimize( DecisionTree* dt, Double_t weights );
85 
86  mutable MsgLogger* fLogger; //! output stream to save logging information
87  MsgLogger& Log() const { return *fLogger; }
88 
89  };
90 }
91 
92 
93 #endif
n
const Int_t n
Definition: legend1.C:16
TMVA::CostComplexityPruneTool::fQualityIndexTool
SeparationBase * fQualityIndexTool
Definition: CostComplexityPruneTool.h:119
TMVA::IPruneTool::EventSample
std::vector< const Event * > EventSample
Definition: IPruneTool.h:93
TMVA::CostComplexityPruneTool::Log
MsgLogger & Log() const
access the message logger (returns *fLogger)
Definition: CostComplexityPruneTool.h:135
GiniIndex.h
Int_t
int Int_t
Definition: RtypesCore.h:45
Bool_t
bool Bool_t
Definition: RtypesCore.h:63
TMVA::CostComplexityPruneTool::InitTreePruningMetaData
void InitTreePruningMetaData(DecisionTreeNode *n)
set the meta data used for cost complexity pruning
Definition: CostComplexityPruneTool.cxx:181
SeparationBase.h
DecisionTree.h
TMVA::CostComplexityPruneTool::fQualityIndexList
std::vector< Double_t > fQualityIndexList
map of R(T) -> pruning index
Definition: CostComplexityPruneTool.h:123
TMVA::CostComplexityPruneTool::fLogger
MsgLogger * fLogger
Definition: CostComplexityPruneTool.h:134
TMVA::CostComplexityPruneTool::CostComplexityPruneTool
CostComplexityPruneTool(SeparationBase *qualityIndex=NULL)
the constructor for the cost complexity pruning
Definition: CostComplexityPruneTool.cxx:68
kFALSE
const Bool_t kFALSE
Definition: RtypesCore.h:92
Event.h
TMVA::CostComplexityPruneTool::CalculatePruningInfo
virtual PruningInfo * CalculatePruningInfo(DecisionTree *dt, const IPruneTool::EventSample *testEvents=NULL, Bool_t isAutomatic=kFALSE)
the routine that basically "steers" the pruning process.
Definition: CostComplexityPruneTool.cxx:98
TMVA::CostComplexityPruneTool::Optimize
void Optimize(DecisionTree *dt, Double_t weights)
after the critical values (at which the corresponding nodes would be pruned away) had been establish...
Definition: CostComplexityPruneTool.cxx:236
TMVA::CostComplexityPruneTool::fPruneStrengthList
std::vector< Double_t > fPruneStrengthList
map of alpha -> pruning index
Definition: CostComplexityPruneTool.h:122
TMVA::CostComplexityPruneTool::fPruneSequence
std::vector< DecisionTreeNode * > fPruneSequence
map of weakest links (i.e., branches to prune) -> pruning index
Definition: CostComplexityPruneTool.h:121
Double_t
double Double_t
Definition: RtypesCore.h:59
TMVA::MsgLogger
Definition: MsgLogger.h:83
TMVA::CostComplexityPruneTool::fOptimalK
Int_t fOptimalK
the optimal index of the prune sequence
Definition: CostComplexityPruneTool.h:125
IPruneTool.h
TMVA::CostComplexityPruneTool::~CostComplexityPruneTool
virtual ~CostComplexityPruneTool()
the destructor for the cost complexity pruning
Definition: CostComplexityPruneTool.cxx:89
TMVA
create variable transformations
Definition: GeneticMinimizer.h:22