Logo ROOT  
Reference Guide
 
Loading...
Searching...
No Matches
CCPruner.h
Go to the documentation of this file.
1#ifndef ROOT_TMVA_CCPruner
2#define ROOT_TMVA_CCPruner
3/**********************************************************************************
4 * Project: TMVA - a Root-integrated toolkit for multivariate data analysis *
5 * Package: TMVA *
6 * Class : CCPruner *
7 * *
8 * *
9 * Description: Cost Complexity Pruning *
10 *
11 * Author: Doug Schouten (dschoute@sfu.ca)
12 *
13 * *
14 * Copyright (c) 2007: *
15 * CERN, Switzerland *
16 * MPI-K Heidelberg, Germany *
17 * U. of Texas at Austin, USA *
18 * *
19 * Redistribution and use in source and binary forms, with or without *
20 * modification, are permitted according to the terms listed in LICENSE *
21 * (see tmva/doc/LICENSE) *
22 **********************************************************************************/
23
24////////////////////////////////////////////////////////////////////////////////////////////////////////////
25// CCPruner - a helper class to prune a decision tree using the Cost Complexity method //
26// (see Classification and Regression Trees by Leo Breiman et al) //
27// //
28// Some definitions: //
29// //
30// T_max - the initial, usually highly overtrained tree, that is to be pruned back //
31// R(T) - quality index (Gini, misclassification rate, or other) of a tree T //
32// ~T - set of terminal nodes in T //
33// T' - the pruned subtree of T_max that has the best quality index R(T') //
34// alpha - the prune strength parameter in Cost Complexity pruning (R_alpha(T) = R(T) + alpha*|~T|) //
35// //
36// There are two running modes in CCPruner: (i) one may select a prune strength and prune back //
37// the tree T_max until the criterion //
38// R(t) - R(T_t) //
39// alpha < ------------- //
40// |~T_t| - 1 //
41// //
42// is true for all nodes t in T, or (ii) the algorithm finds the sequence of critical points //
43// alpha_k < alpha_k+1 ... < alpha_K such that T_K = root(T_max) and then selects the optimally-pruned //
44// subtree, defined to be the subtree with the best quality index for the validation sample. //
45////////////////////////////////////////////////////////////////////////////////////////////////////////////
46
47
48#include "TMVA/DecisionTree.h"
49
50/* #ifndef ROOT_TMVA_DecisionTreeNode */
51/* #include "TMVA/DecisionTreeNode.h" */
52/* #endif */
53
54#include "TMVA/Event.h"
55#include <vector>
56
57namespace TMVA {
58 class DataSet;
59 class DecisionTreeNode;
60 class SeparationBase;
61
62 class CCPruner { // helper class that prunes a decision tree using the Cost Complexity method
63 public:
64 typedef std::vector<Event*> EventList; // list of event pointers; the type of the validation sample taken by the first constructor
65 // constructor taking the tree to prune (t_max), the validation sample as an event list, and an optional quality index
66 CCPruner( DecisionTree* t_max,
67 const EventList* validationSample,
68 SeparationBase* qualityIndex = nullptr );
69 // constructor taking the tree to prune (t_max), the validation sample as a DataSet, and an optional quality index
70 CCPruner( DecisionTree* t_max,
71 const DataSet* validationSample,
72 SeparationBase* qualityIndex = nullptr );
73 // NOTE(review): presumably releases fQualityIndex when fOwnQIndex is set -- confirm in CCPruner.cxx
74 ~CCPruner( );
75
76 // set the pruning strength parameter alpha (if alpha < 0, the optimal alpha is calculated)
77 void SetPruneStrength( Float_t alpha = -1.0 );
78 // determine the pruning sequence
79 void Optimize( );
80
81 // return the list of pruning locations to define the optimal subtree T' of T_max
82 std::vector<TMVA::DecisionTreeNode*> GetOptimalPruneSequence( ) const;
83
84 // return the quality index from the validation sample for the optimal subtree T'
85 inline Float_t GetOptimalQualityIndex( ) const { return (fOptimalK >= 0 && fQualityIndexList.size() > 0 ?
87 // NOTE(review): listing line 86 (the remainder of the return expression above) is missing from this extract
88 // return the prune strength (=alpha) corresponding to the prune sequence
89 inline Float_t GetOptimalPruneStrength( ) const { return (fOptimalK >= 0 && fPruneStrengthList.size() > 0 ?
91 // NOTE(review): listing line 90 (the remainder of the return expression above) is missing from this extract
92 private:
93 Float_t fAlpha; ///<! regularization parameter in CC pruning
94 const EventList* fValidationSample; ///<! validation sample given as an event list, used to select the optimally-pruned tree
95 const DataSet* fValidationDataSet; ///<! validation sample given as a DataSet, used to select the optimally-pruned tree
96 SeparationBase* fQualityIndex; ///<! the quality index used to calculate R(t), R(T) = sum[t in ~T]{ R(t) }
97 Bool_t fOwnQIndex; ///<! flag indicates if fQualityIndex is owned by this
98
99 DecisionTree* fTree; ///<! (pruned) decision tree
100
101 std::vector<TMVA::DecisionTreeNode*> fPruneSequence; ///<! map of weakest links (i.e., branches to prune) -> pruning index
102 std::vector<Float_t> fPruneStrengthList; ///<! map of alpha -> pruning index
103 std::vector<Float_t> fQualityIndexList; ///<! map of R(T) -> pruning index
104
105 Int_t fOptimalK; ///<! index of the optimal tree in the pruned tree sequence
106 Bool_t fDebug; ///<! debug flag
107 };
108}
109
111 fAlpha = (alpha > 0 ? alpha : 0.0); // keep alpha only when strictly positive; zero or negative input is stored as 0.0
112} // NOTE(review): listing line 110 (the opening line of this definition) is missing from this extract
113
114
115#endif
116
117
float Float_t
Definition RtypesCore.h:57
A helper class to prune a decision tree using the Cost Complexity method (see Classification and Regr...
Definition CCPruner.h:62
Float_t GetOptimalQualityIndex() const
Definition CCPruner.h:85
void SetPruneStrength(Float_t alpha=-1.0)
Definition CCPruner.h:110
Float_t fAlpha
! regularization parameter in CC pruning
Definition CCPruner.h:93
std::vector< Float_t > fQualityIndexList
! map of R(T) -> pruning index
Definition CCPruner.h:103
void Optimize()
determine the pruning sequence
Definition CCPruner.cxx:124
Bool_t fDebug
! debug flag
Definition CCPruner.h:106
Bool_t fOwnQIndex
! flag indicates if fQualityIndex is owned by this
Definition CCPruner.h:97
std::vector< Event * > EventList
Definition CCPruner.h:64
std::vector< TMVA::DecisionTreeNode * > fPruneSequence
! map of weakest links (i.e., branches to prune) -> pruning index
Definition CCPruner.h:101
const EventList * fValidationSample
! the event sample to select the optimally-pruned tree
Definition CCPruner.h:94
std::vector< TMVA::DecisionTreeNode * > GetOptimalPruneSequence() const
return the list of pruning locations to define the optimal subtree T' of T_max
Definition CCPruner.cxx:240
Int_t fOptimalK
! index of the optimal tree in the pruned tree sequence
Definition CCPruner.h:105
const DataSet * fValidationDataSet
! the event sample to select the optimally-pruned tree
Definition CCPruner.h:95
std::vector< Float_t > fPruneStrengthList
! map of alpha -> pruning index
Definition CCPruner.h:102
SeparationBase * fQualityIndex
! the quality index used to calculate R(t), R(T) = sum[t in ~T]{ R(t) }
Definition CCPruner.h:96
DecisionTree * fTree
! (pruned) decision tree
Definition CCPruner.h:99
Float_t GetOptimalPruneStrength() const
Definition CCPruner.h:89
Class that contains all the data information.
Definition DataSet.h:58
Implementation of a Decision Tree.
An interface to calculate the "SeparationGain" for different separation criteria used in various trai...
create variable transformations