Logo ROOT   6.08/07
Reference Guide
CCPruner.h
Go to the documentation of this file.
1 #ifndef ROOT_TMVA_CCPruner
2 #define ROOT_TMVA_CCPruner
3 /**********************************************************************************
4  * Project: TMVA - a Root-integrated toolkit for multivariate data analysis *
5  * Package: TMVA *
6  * Class : CCPruner *
7  * Web : http://tmva.sourceforge.net *
8  * *
9  * Description: Cost Complexity Pruning *
10  *
11  * Author: Doug Schouten (dschoute@sfu.ca)
12  *
13  * *
14  * Copyright (c) 2007: *
15  * CERN, Switzerland *
16  * MPI-K Heidelberg, Germany *
17  * U. of Texas at Austin, USA *
18  * *
19  * Redistribution and use in source and binary forms, with or without *
20  * modification, are permitted according to the terms listed in LICENSE *
21  * (http://tmva.sourceforge.net/LICENSE) *
22  **********************************************************************************/
23 
24 ////////////////////////////////////////////////////////////////////////////////////////////////////////////
25 // CCPruner - a helper class to prune a decision tree using the Cost Complexity method //
26 // (see Classification and Regression Trees by Leo Breiman et al) //
27 // //
28 // Some definitions: //
29 // //
30 // T_max - the initial, usually highly overtrained tree, that is to be pruned back //
31 // R(T) - quality index (Gini, misclassification rate, or other) of a tree T //
32 // ~T - set of terminal nodes in T //
33 // T' - the pruned subtree of T_max that has the best quality index R(T') //
34 // alpha - the prune strength parameter in Cost Complexity pruning (R_alpha(T) = R(T) + alpha * |~T|) //
35 // //
36 // There are two running modes in CCPruner: (i) one may select a prune strength and prune back //
37 // the tree T_max until the criterion //
38 // R(T) - R(t) //
39 // alpha < ---------- //
40 // |~T_t| - 1 //
41 // //
42 // is true for all nodes t in T, or (ii) the algorithm finds the sequence of critical points //
43 // alpha_k < alpha_k+1 ... < alpha_K such that T_K = root(T_max) and then selects the optimally-pruned //
44 // subtree, defined to be the subtree with the best quality index for the validation sample. //
45 ////////////////////////////////////////////////////////////////////////////////////////////////////////////
46 
47 
48 #ifndef ROOT_TMVA_DecisionTree
49 #include "TMVA/DecisionTree.h"
50 #endif
51 
52 /* #ifndef ROOT_TMVA_DecisionTreeNode */
53 /* #include "TMVA/DecisionTreeNode.h" */
54 /* #endif */
55 
56 #ifndef ROOT_TMVA_Event
57 #include "TMVA/Event.h"
58 #endif
59 
60 namespace TMVA {
61  class DataSet;
62  class DecisionTreeNode;
63  class SeparationBase;
64 
65  class CCPruner {
66  public:
67  typedef std::vector<Event*> EventList;
68 
69  CCPruner( DecisionTree* t_max,
70  const EventList* validationSample,
71  SeparationBase* qualityIndex = NULL );
72 
73  CCPruner( DecisionTree* t_max,
74  const DataSet* validationSample,
75  SeparationBase* qualityIndex = NULL );
76 
77  ~CCPruner( );
78 
79  // set the pruning strength parameter alpha (if alpha < 0, the optimal alpha is calculated)
80  void SetPruneStrength( Float_t alpha = -1.0 );
81 
82  void Optimize( );
83 
84  // return the list of pruning locations to define the optimal subtree T' of T_max
85  std::vector<TMVA::DecisionTreeNode*> GetOptimalPruneSequence( ) const;
86 
87  // return the quality index from the validation sample for the optimal subtree T'
88  inline Float_t GetOptimalQualityIndex( ) const { return (fOptimalK >= 0 && fQualityIndexList.size() > 0 ?
89  fQualityIndexList[fOptimalK] : -1.0); }
90 
91  // return the prune strength (=alpha) corresponding to the prune sequence
92  inline Float_t GetOptimalPruneStrength( ) const { return (fOptimalK >= 0 && fPruneStrengthList.size() > 0 ?
93  fPruneStrengthList[fOptimalK] : -1.0); }
94 
95  private:
96  Float_t fAlpha; //! regularization parameter in CC pruning
97  const EventList* fValidationSample; //! the event sample to select the optimally-pruned tree
98  const DataSet* fValidationDataSet; //! the event sample to select the optimally-pruned tree
99  SeparationBase* fQualityIndex; //! the quality index used to calculate R(t), R(T) = sum[t in ~T]{ R(t) }
100  Bool_t fOwnQIndex; //! flag indicates if fQualityIndex is owned by this
101 
102  DecisionTree* fTree; //! (pruned) decision tree
103 
104  std::vector<TMVA::DecisionTreeNode*> fPruneSequence; //! map of weakest links (i.e., branches to prune) -> pruning index
105  std::vector<Float_t> fPruneStrengthList; //! map of alpha -> pruning index
106  std::vector<Float_t> fQualityIndexList; //! map of R(T) -> pruning index
107 
108  Int_t fOptimalK; //! index of the optimal tree in the pruned tree sequence
109  Bool_t fDebug; //! debug flag
110  };
111 }
112 
114  fAlpha = (alpha > 0 ? alpha : 0.0);
115 }
116 
117 
118 #endif
119 
120 
void Optimize()
determine the pruning sequence
Definition: CCPruner.cxx:100
std::vector< Float_t > fQualityIndexList
map of R(T) -> pruning index
Definition: CCPruner.h:106
Float_t fAlpha
regularization parameter in CC pruning
Definition: CCPruner.h:96
float Float_t
Definition: RtypesCore.h:53
const DataSet * fValidationDataSet
the event sample to select the optimally-pruned tree
Definition: CCPruner.h:98
Float_t GetOptimalPruneStrength() const
return the prune strength (=alpha) corresponding to the prune sequence
Definition: CCPruner.h:92
Float_t GetOptimalQualityIndex() const
return the quality index from the validation sample for the optimal subtree T'
Definition: CCPruner.h:88
int Int_t
Definition: RtypesCore.h:41
bool Bool_t
Definition: RtypesCore.h:59
Int_t fOptimalK
index of the optimal tree in the pruned tree sequence
Definition: CCPruner.h:108
DecisionTree * fTree
(pruned) decision tree
Definition: CCPruner.h:102
Bool_t fDebug
debug flag
Definition: CCPruner.h:109
SeparationBase * fQualityIndex
the quality index used to calculate R(t), R(T) = sum[t in ~T]{ R(t) }
Definition: CCPruner.h:99
void SetPruneStrength(Float_t alpha=-1.0)
Definition: CCPruner.h:113
Bool_t fOwnQIndex
flag indicates if fQualityIndex is owned by this
Definition: CCPruner.h:100
std::vector< TMVA::DecisionTreeNode * > fPruneSequence
map of weakest links (i.e., branches to prune) -> pruning index
Definition: CCPruner.h:104
std::vector< Event * > EventList
Definition: CCPruner.h:67
const EventList * fValidationSample
the event sample to select the optimally-pruned tree
Definition: CCPruner.h:97
std::vector< TMVA::DecisionTreeNode * > GetOptimalPruneSequence() const
return the list of pruning locations to define the optimal subtree T' of T_max
Definition: CCPruner.cxx:216
Abstract ClassifierFactory template that handles arbitrary types.
#define NULL
Definition: Rtypes.h:82
CCPruner(DecisionTree *t_max, const EventList *validationSample, SeparationBase *qualityIndex=NULL)
constructor
Definition: CCPruner.cxx:45
std::vector< Float_t > fPruneStrengthList
map of alpha -> pruning index
Definition: CCPruner.h:105