#include "TMVA/CostComplexityPruneTool.h"
#include "TMVA/MsgLogger.h"
#include <fstream>
#include <limits>
#include <math.h>
using namespace TMVA;
////////////////////////////////////////////////////////////////////////////////
/// Construct the cost-complexity prune tool.
///
/// \param qualityIndex  separation criterion used to compute the node cost R(t).
///                      It may be NULL, in which case each node's own stored
///                      separation index is used instead. The tool takes ownership
///                      and deletes it in its destructor.

CostComplexityPruneTool::CostComplexityPruneTool( SeparationBase* qualityIndex )
   : IPruneTool(),
     fLogger( new MsgLogger("CostComplexityPruneTool") )
{
   // Only messages at warning level or above are printed by default.
   fLogger->SetMinType( kWARNING );

   fQualityIndexTool = qualityIndex;

   // A negative index means "no pruning step selected yet".
   fOptimalK = -1;
}
////////////////////////////////////////////////////////////////////////////////
/// Destructor.
///
/// Releases the separation criterion owned by the tool (fQualityIndexTool) and
/// the message logger that the constructor allocated with new. Previously the
/// logger was never freed, so it leaked.

CostComplexityPruneTool::~CostComplexityPruneTool( ) {
   // delete on a NULL pointer is a no-op, so no explicit checks are needed.
   delete fQualityIndexTool;
   delete fLogger;
}
////////////////////////////////////////////////////////////////////////////////
/// Compute the cost-complexity pruning sequence of `dt` and decide how much of it
/// to apply.
///
/// \param dt                tree to analyse. Its pruning meta data is (re)initialised,
///                          and Optimize() prunes it in place while building the
///                          sequence.
/// \param validationSample  events used to measure tree quality. It must be non-NULL
///                          in automatic mode.
/// \param isAutomatic       if true, switch the tool to automatic mode, in which the
///                          pruning step with the best validation quality is chosen.
///                          Otherwise fPruneStrength (a percentage of the sequence
///                          length) determines how many nodes are pruned.
/// \return a PruningInfo allocated with new and not retained by the tool (the caller
///         must delete it). Returns NULL when `dt` is NULL, when the validation sample
///         is missing in automatic mode, or when computing the meta data or the
///         sequence threw a std::string.
///
/// On success:
///  - PruningInfo::PruneSequence holds the nodes to prune, in order.
///  - PruneStrength is the alpha of the last selected step (automatic mode) or
///    fPruneStrength (manual mode).
///  - QualityIndex is the weight-normalised quality of the selected tree
///    (a placeholder value in manual mode).

PruningInfo*
CostComplexityPruneTool::CalculatePruningInfo( DecisionTree* dt,
                                               const IPruneTool::EventSample* validationSample,
                                               Bool_t isAutomatic )
{
   if( isAutomatic ) SetAutomatic();

   if( dt == NULL || (IsAutomatic() && validationSample == NULL) ) {
      return NULL;
   }

   Double_t Q = -1.0; // quality of the unpruned tree (measured in automatic mode only)
   Double_t W = 1.0;  // normalisation: sum of validation weights (automatic mode only)

   if(IsAutomatic()) {
      dt->ApplyValidationSample(validationSample);
      W = dt->GetSumWeights(validationSample);
      Q = dt->TestPrunedTreeQuality();

      Log() << kDEBUG << "Node purity limit is: " << dt->GetNodePurityLimit() << Endl;
      Log() << kDEBUG << "Sum of weights in pruning validation sample: " << W << Endl;
      Log() << kDEBUG << "Quality of tree prior to any pruning is " << Q/W << Endl;
   }

   try {
      InitTreePruningMetaData((DecisionTreeNode*)dt->GetRoot());
   }
   catch(const std::string& error) {
      Log() << kERROR << "Couldn't initialize the tree meta data because of error ("
            << error << ")" << Endl;
      return NULL;
   }

   Log() << kDEBUG << "Automatic cost complexity pruning is " << (IsAutomatic()?"on":"off") << "." << Endl;

   try {
      Optimize( dt, W );
   }
   catch(const std::string& error) {
      Log() << kERROR << "Error optimzing pruning sequence ("
            << error << ")" << Endl;
      return NULL;
   }

   Log() << kDEBUG << "Index of pruning sequence to stop at: " << fOptimalK << Endl;

   PruningInfo* info = new PruningInfo();

   if(fOptimalK < 0 || fPruneSequence.empty()) {
      info->PruneStrength = 0;
      info->QualityIndex = Q/W;
      info->PruneSequence.clear();
      Log() << kINFO << "no proper pruning could be calulated. Tree "
            << dt->GetTreeID() << " will not be pruned. Do not worry if this "
            << " happens for a few trees " << Endl;
      return info;
   }

   // Optimize() gives fOptimalK a different meaning in each mode:
   //  - automatic: an index i into fQualityIndexList / fPruneStrengthList. Entry i
   //    describes the tree obtained after pruning fPruneSequence[0..i], i.e. i+1 nodes.
   //  - manual: the number of nodes to prune. This may equal the sequence length
   //    (strength 100) or exceed it (strength above 100).
   const Int_t sequenceLength = static_cast<Int_t>(fPruneSequence.size());
   Int_t nPruned = IsAutomatic() ? fOptimalK + 1 : fOptimalK;
   if( nPruned > sequenceLength ) nPruned = sequenceLength;

   // List entry describing the tree after nPruned prunings (entry 0 when nothing
   // is pruned). This keeps every list access within bounds.
   const Int_t step = (nPruned > 0) ? nPruned - 1 : 0;

   // fQualityIndexList already holds values divided by W (see Optimize()), so it
   // must not be normalised a second time.
   info->QualityIndex = fQualityIndexList[step];
   Log() << kDEBUG << " prune until k=" << fOptimalK << " with alpha="<<fPruneStrengthList[step]<< Endl;

   for( Int_t i = 0; i < nPruned; i++ ){
      info->PruneSequence.push_back(fPruneSequence[i]);
   }

   if( IsAutomatic() ){
      info->PruneStrength = fPruneStrengthList[step];
   }
   else {
      info->PruneStrength = fPruneStrength;
   }

   return info;
}
////////////////////////////////////////////////////////////////////////////////
/// Recursively fill in the cost-complexity bookkeeping of the subtree rooted at `n`:
///  - NodeR:           (s+b) * separation index of the node itself;
///  - SubTreeR:        sum of NodeR over the terminal nodes below `n`;
///  - NTerminal:       number of terminal nodes below `n`;
///  - Alpha and CC:    (NodeR - SubTreeR) / (NTerminal - 1) for internal nodes,
///                     +infinity for leaves;
///  - AlphaMinSubtree: the smallest Alpha among `n` and all of its descendants.
///
/// A node is treated as internal only if it has both a left and a right daughter.
/// The separation index comes from fQualityIndexTool when one is set; otherwise the
/// node's own stored separation index is used. A NULL node is ignored.

void CostComplexityPruneTool::InitTreePruningMetaData( DecisionTreeNode* n ) {
   if( n == NULL ) return;

   const Double_t nSig = n->GetNSigEvents();
   const Double_t nBkg = n->GetNBkgEvents();
   const Double_t nodeCost = (fQualityIndexTool != NULL)
      ? (nSig+nBkg)*fQualityIndexTool->GetSeparationIndex(nSig,nBkg)
      : (nSig+nBkg)*n->GetSeparationIndex();
   n->SetNodeR( nodeCost );

   DecisionTreeNode* left  = n->GetLeft();
   DecisionTreeNode* right = n->GetRight();

   if( left == NULL || right == NULL ) {
      // Leaf: the subtree consists of this node alone, and it can never be the
      // weakest link, so its alpha is +infinity.
      const Double_t infinity = std::numeric_limits<double>::infinity( );
      n->SetNTerminal( 1 );
      n->SetTerminal( );
      n->SetSubTreeR( nodeCost );
      n->SetAlpha( infinity );
      n->SetAlphaMinSubtree( infinity );
      n->SetCC( n->GetAlpha() );
      return;
   }

   // Internal node: fill both daughters first, then aggregate their results.
   n->SetTerminal( kFALSE );
   InitTreePruningMetaData( left );
   InitTreePruningMetaData( right );

   n->SetNTerminal( left->GetNTerminal() + right->GetNTerminal() );
   n->SetSubTreeR( left->GetSubTreeR() + right->GetSubTreeR() );
   n->SetAlpha( (n->GetNodeR() - n->GetSubTreeR()) / (n->GetNTerminal() - 1) );

   const Double_t daughterMin = std::min( left->GetAlphaMinSubtree(),
                                          right->GetAlphaMinSubtree() );
   n->SetAlphaMinSubtree( std::min( n->GetAlpha(), daughterMin ) );
   n->SetCC( n->GetAlpha() );
}
////////////////////////////////////////////////////////////////////////////////
/// Build the cost-complexity (weakest-link) pruning sequence for `dt` and choose
/// the stopping point, which is stored in fOptimalK.
///
/// The function starts from the meta data set by InitTreePruningMetaData(). Each
/// iteration does the following:
///  - walks from the root to the node whose own alpha equals the minimal subtree
///    alpha;
///  - prunes that node in place with dt->PruneNodeInPlace();
///  - recomputes NTerminal, SubTreeR, Alpha, AlphaMinSubtree and CC on the path
///    back to the root.
/// For every step, the pruned node, the running alpha and a quality value are
/// appended to fPruneSequence, fPruneStrengthList and fQualityIndexList.
///
/// \param dt       tree whose root carries initialised pruning meta data.
/// \param weights  divisor applied to dt->TestPrunedTreeQuality() in automatic
///                 mode. CalculatePruningInfo passes the validation-sample sum of
///                 weights, or 1.0 in manual mode.
///
/// Resulting fOptimalK:
///  - -1 if no pruning step was recorded, or (automatic mode) if no step has a
///    quality strictly below that of the unpruned tree;
///  - automatic mode: the index i of the smallest fQualityIndexList entry (the
///    first one on ties). That entry describes the tree obtained after pruning
///    fPruneSequence[0..i];
///  - manual mode: int(fPruneStrength/100 * sequence length). This is a node count
///    and is not clamped to the sequence length.
void CostComplexityPruneTool::Optimize( DecisionTree* dt, Double_t weights ) {
// Counts the trees in the sequence (the unpruned tree is k=1). In automatic mode
// k is reused below to hold the selected list index.
Int_t k = 1;
// Alpha of the current pruning step; it only ever grows (TMath::Max below).
Double_t alpha = -1.0e10;
// Absolute tolerance used to decide which daughter carries the minimal subtree alpha.
Double_t epsilon = std::numeric_limits<double>::epsilon();
fQualityIndexList.clear();
fPruneSequence.clear();
fPruneStrengthList.clear();
DecisionTreeNode* R = (DecisionTreeNode*)dt->GetRoot();
// Normalised quality of the tree before any pruning. In automatic mode a step is
// only selected if its quality is strictly below this value.
Double_t qmin = 0.0;
if(IsAutomatic()){
qmin = dt->TestPrunedTreeQuality()/weights;
}
// Prune one weakest link per iteration while the root still has more than one
// terminal node below it.
while(R->GetNTerminal() > 1) {
alpha = TMath::Max(R->GetAlphaMinSubtree(), alpha);
// If the smallest alpha in the tree is the root's own alpha, the next step would
// collapse the whole tree into the root, so stop here.
if( R->GetAlphaMinSubtree() >= R->GetAlpha() ) {
Log() << kDEBUG << "\nCaught trying to prune the root node!" << Endl;
break;
}
// Descend towards the weakest link. Follow the left daughter if its
// AlphaMinSubtree matches the current node's (within epsilon), otherwise the
// right daughter. Stop at the node whose own alpha is the subtree minimum.
DecisionTreeNode* t = R;
while(t->GetAlphaMinSubtree() < t->GetAlpha()) {
if(TMath::Abs(t->GetAlphaMinSubtree() - t->GetLeft()->GetAlphaMinSubtree()) < epsilon) {
t = t->GetLeft();
} else {
t = t->GetRight();
}
}
// Defensive check; the test at the top of the loop should already exclude this.
if( t == R ) {
Log() << kDEBUG << "\nCaught trying to prune the root node!" << Endl;
break;
}
// Remember the weakest link, prune it in place, then refresh the bookkeeping of
// every ancestor from its parent up to the root.
DecisionTreeNode* n = t;
dt->PruneNodeInPlace(t);
while(t != R) {
t = t->GetParent();
t->SetNTerminal(t->GetLeft()->GetNTerminal() + t->GetRight()->GetNTerminal());
t->SetSubTreeR(t->GetLeft()->GetSubTreeR() + t->GetRight()->GetSubTreeR());
t->SetAlpha((t->GetNodeR() - t->GetSubTreeR())/(t->GetNTerminal() - 1));
t->SetAlphaMinSubtree(std::min(t->GetAlpha(), std::min(t->GetLeft()->GetAlphaMinSubtree(),
t->GetRight()->GetAlphaMinSubtree())));
t->SetCC(t->GetAlpha());
}
k += 1;
Log() << kDEBUG << "after this pruning step I would have " << R->GetNTerminal() << " remaining terminal nodes " << Endl;
// Quality of the tree after this step: the weight-normalised validation quality
// in automatic mode, a constant placeholder of 1.0 otherwise.
if(IsAutomatic()) {
Double_t q = dt->TestPrunedTreeQuality()/weights;
fQualityIndexList.push_back(q);
}
else {
fQualityIndexList.push_back(1.0);
}
fPruneSequence.push_back(n);
fPruneStrengthList.push_back(alpha);
}
if(fPruneSequence.empty()) {
fOptimalK = -1;
return;
}
// Automatic mode: choose the step with the lowest quality value (first on ties),
// provided it is strictly below the unpruned tree's quality; otherwise k stays -1.
if(IsAutomatic()) {
k = -1;
for(UInt_t i = 0; i < fQualityIndexList.size(); i++) {
if(fQualityIndexList[i] < qmin) {
qmin = fQualityIndexList[i];
k = i;
}
}
fOptimalK = k;
}
else {
// Manual mode: prune the fraction fPruneStrength/100 of the sequence. The result
// is a node count (not a list index) and is not clamped to the sequence length.
fOptimalK = int(fPruneStrength/100.0 * fPruneSequence.size() );
Log() << kDEBUG << "SequenzeSize="<<fPruneSequence.size()
<< " fOptimalK " << fOptimalK << Endl;
}
// Debug summary. The lists are non-empty here because of the early return above,
// so size()-1 cannot underflow.
Log() << kDEBUG << "\n************ Summary for Tree " << dt->GetTreeID() << " *******" << Endl
<< "Number of trees in the sequence: " << fPruneSequence.size() << Endl;
Log() << kDEBUG << "Pruning strength parameters: [";
for(UInt_t i = 0; i < fPruneStrengthList.size()-1; i++)
Log() << kDEBUG << fPruneStrengthList[i] << ", ";
Log() << kDEBUG << fPruneStrengthList[fPruneStrengthList.size()-1] << "]" << Endl;
Log() << kDEBUG << "Misclassification rates: [";
for(UInt_t i = 0; i < fQualityIndexList.size()-1; i++)
Log() << kDEBUG << fQualityIndexList[i] << ", ";
Log() << kDEBUG << fQualityIndexList[fQualityIndexList.size()-1] << "]" << Endl;
// Logged as fOptimalK+1, i.e. shifted by one relative to the stored value.
Log() << kDEBUG << "Prune index: " << fOptimalK+1 << Endl;
}
CostComplexityPruneTool.cxx:1 CostComplexityPruneTool.cxx:2 CostComplexityPruneTool.cxx:3 CostComplexityPruneTool.cxx:4 CostComplexityPruneTool.cxx:5 CostComplexityPruneTool.cxx:6 CostComplexityPruneTool.cxx:7 CostComplexityPruneTool.cxx:8 CostComplexityPruneTool.cxx:9 CostComplexityPruneTool.cxx:10 CostComplexityPruneTool.cxx:11 CostComplexityPruneTool.cxx:12 CostComplexityPruneTool.cxx:13 CostComplexityPruneTool.cxx:14 CostComplexityPruneTool.cxx:15 CostComplexityPruneTool.cxx:16 CostComplexityPruneTool.cxx:17 CostComplexityPruneTool.cxx:18 CostComplexityPruneTool.cxx:19 CostComplexityPruneTool.cxx:20 CostComplexityPruneTool.cxx:21 CostComplexityPruneTool.cxx:22 CostComplexityPruneTool.cxx:23 CostComplexityPruneTool.cxx:24 CostComplexityPruneTool.cxx:25 CostComplexityPruneTool.cxx:26 CostComplexityPruneTool.cxx:27 CostComplexityPruneTool.cxx:28 CostComplexityPruneTool.cxx:29 CostComplexityPruneTool.cxx:30 CostComplexityPruneTool.cxx:31 CostComplexityPruneTool.cxx:32 CostComplexityPruneTool.cxx:33 CostComplexityPruneTool.cxx:34 CostComplexityPruneTool.cxx:35 CostComplexityPruneTool.cxx:36 CostComplexityPruneTool.cxx:37 CostComplexityPruneTool.cxx:38 CostComplexityPruneTool.cxx:39 CostComplexityPruneTool.cxx:40 CostComplexityPruneTool.cxx:41 CostComplexityPruneTool.cxx:42 CostComplexityPruneTool.cxx:43 CostComplexityPruneTool.cxx:44 CostComplexityPruneTool.cxx:45 CostComplexityPruneTool.cxx:46 CostComplexityPruneTool.cxx:47 CostComplexityPruneTool.cxx:48 CostComplexityPruneTool.cxx:49 CostComplexityPruneTool.cxx:50 CostComplexityPruneTool.cxx:51 CostComplexityPruneTool.cxx:52 CostComplexityPruneTool.cxx:53 CostComplexityPruneTool.cxx:54 CostComplexityPruneTool.cxx:55 CostComplexityPruneTool.cxx:56 CostComplexityPruneTool.cxx:57 CostComplexityPruneTool.cxx:58 CostComplexityPruneTool.cxx:59 CostComplexityPruneTool.cxx:60 CostComplexityPruneTool.cxx:61 CostComplexityPruneTool.cxx:62 CostComplexityPruneTool.cxx:63 CostComplexityPruneTool.cxx:64 
CostComplexityPruneTool.cxx:65 CostComplexityPruneTool.cxx:66 CostComplexityPruneTool.cxx:67 CostComplexityPruneTool.cxx:68 CostComplexityPruneTool.cxx:69 CostComplexityPruneTool.cxx:70 CostComplexityPruneTool.cxx:71 CostComplexityPruneTool.cxx:72 CostComplexityPruneTool.cxx:73 CostComplexityPruneTool.cxx:74 CostComplexityPruneTool.cxx:75 CostComplexityPruneTool.cxx:76 CostComplexityPruneTool.cxx:77 CostComplexityPruneTool.cxx:78 CostComplexityPruneTool.cxx:79 CostComplexityPruneTool.cxx:80 CostComplexityPruneTool.cxx:81 CostComplexityPruneTool.cxx:82 CostComplexityPruneTool.cxx:83 CostComplexityPruneTool.cxx:84 CostComplexityPruneTool.cxx:85 CostComplexityPruneTool.cxx:86 CostComplexityPruneTool.cxx:87 CostComplexityPruneTool.cxx:88 CostComplexityPruneTool.cxx:89 CostComplexityPruneTool.cxx:90 CostComplexityPruneTool.cxx:91 CostComplexityPruneTool.cxx:92 CostComplexityPruneTool.cxx:93 CostComplexityPruneTool.cxx:94 CostComplexityPruneTool.cxx:95 CostComplexityPruneTool.cxx:96 CostComplexityPruneTool.cxx:97 CostComplexityPruneTool.cxx:98 CostComplexityPruneTool.cxx:99 CostComplexityPruneTool.cxx:100 CostComplexityPruneTool.cxx:101 CostComplexityPruneTool.cxx:102 CostComplexityPruneTool.cxx:103 CostComplexityPruneTool.cxx:104 CostComplexityPruneTool.cxx:105 CostComplexityPruneTool.cxx:106 CostComplexityPruneTool.cxx:107 CostComplexityPruneTool.cxx:108 CostComplexityPruneTool.cxx:109 CostComplexityPruneTool.cxx:110 CostComplexityPruneTool.cxx:111 CostComplexityPruneTool.cxx:112 CostComplexityPruneTool.cxx:113 CostComplexityPruneTool.cxx:114 CostComplexityPruneTool.cxx:115 CostComplexityPruneTool.cxx:116 CostComplexityPruneTool.cxx:117 CostComplexityPruneTool.cxx:118 CostComplexityPruneTool.cxx:119 CostComplexityPruneTool.cxx:120 CostComplexityPruneTool.cxx:121 CostComplexityPruneTool.cxx:122 CostComplexityPruneTool.cxx:123 CostComplexityPruneTool.cxx:124 CostComplexityPruneTool.cxx:125 CostComplexityPruneTool.cxx:126 CostComplexityPruneTool.cxx:127 
CostComplexityPruneTool.cxx:128 CostComplexityPruneTool.cxx:129 CostComplexityPruneTool.cxx:130 CostComplexityPruneTool.cxx:131 CostComplexityPruneTool.cxx:132 CostComplexityPruneTool.cxx:133 CostComplexityPruneTool.cxx:134 CostComplexityPruneTool.cxx:135 CostComplexityPruneTool.cxx:136 CostComplexityPruneTool.cxx:137 CostComplexityPruneTool.cxx:138 CostComplexityPruneTool.cxx:139 CostComplexityPruneTool.cxx:140 CostComplexityPruneTool.cxx:141 CostComplexityPruneTool.cxx:142 CostComplexityPruneTool.cxx:143 CostComplexityPruneTool.cxx:144 CostComplexityPruneTool.cxx:145 CostComplexityPruneTool.cxx:146 CostComplexityPruneTool.cxx:147 CostComplexityPruneTool.cxx:148 CostComplexityPruneTool.cxx:149 CostComplexityPruneTool.cxx:150 CostComplexityPruneTool.cxx:151 CostComplexityPruneTool.cxx:152 CostComplexityPruneTool.cxx:153 CostComplexityPruneTool.cxx:154 CostComplexityPruneTool.cxx:155 CostComplexityPruneTool.cxx:156 CostComplexityPruneTool.cxx:157 CostComplexityPruneTool.cxx:158 CostComplexityPruneTool.cxx:159 CostComplexityPruneTool.cxx:160 CostComplexityPruneTool.cxx:161 CostComplexityPruneTool.cxx:162 CostComplexityPruneTool.cxx:163 CostComplexityPruneTool.cxx:164 CostComplexityPruneTool.cxx:165 CostComplexityPruneTool.cxx:166 CostComplexityPruneTool.cxx:167 CostComplexityPruneTool.cxx:168 CostComplexityPruneTool.cxx:169 CostComplexityPruneTool.cxx:170 CostComplexityPruneTool.cxx:171 CostComplexityPruneTool.cxx:172 CostComplexityPruneTool.cxx:173 CostComplexityPruneTool.cxx:174 CostComplexityPruneTool.cxx:175 CostComplexityPruneTool.cxx:176 CostComplexityPruneTool.cxx:177 CostComplexityPruneTool.cxx:178 CostComplexityPruneTool.cxx:179 CostComplexityPruneTool.cxx:180 CostComplexityPruneTool.cxx:181 CostComplexityPruneTool.cxx:182 CostComplexityPruneTool.cxx:183 CostComplexityPruneTool.cxx:184 CostComplexityPruneTool.cxx:185 CostComplexityPruneTool.cxx:186 CostComplexityPruneTool.cxx:187 CostComplexityPruneTool.cxx:188 CostComplexityPruneTool.cxx:189 
CostComplexityPruneTool.cxx:190 CostComplexityPruneTool.cxx:191 CostComplexityPruneTool.cxx:192 CostComplexityPruneTool.cxx:193 CostComplexityPruneTool.cxx:194 CostComplexityPruneTool.cxx:195 CostComplexityPruneTool.cxx:196 CostComplexityPruneTool.cxx:197 CostComplexityPruneTool.cxx:198 CostComplexityPruneTool.cxx:199 CostComplexityPruneTool.cxx:200 CostComplexityPruneTool.cxx:201 CostComplexityPruneTool.cxx:202 CostComplexityPruneTool.cxx:203 CostComplexityPruneTool.cxx:204 CostComplexityPruneTool.cxx:205 CostComplexityPruneTool.cxx:206 CostComplexityPruneTool.cxx:207 CostComplexityPruneTool.cxx:208 CostComplexityPruneTool.cxx:209 CostComplexityPruneTool.cxx:210 CostComplexityPruneTool.cxx:211 CostComplexityPruneTool.cxx:212 CostComplexityPruneTool.cxx:213 CostComplexityPruneTool.cxx:214 CostComplexityPruneTool.cxx:215 CostComplexityPruneTool.cxx:216 CostComplexityPruneTool.cxx:217 CostComplexityPruneTool.cxx:218 CostComplexityPruneTool.cxx:219 CostComplexityPruneTool.cxx:220 CostComplexityPruneTool.cxx:221 CostComplexityPruneTool.cxx:222 CostComplexityPruneTool.cxx:223 CostComplexityPruneTool.cxx:224 CostComplexityPruneTool.cxx:225 CostComplexityPruneTool.cxx:226 CostComplexityPruneTool.cxx:227 CostComplexityPruneTool.cxx:228 CostComplexityPruneTool.cxx:229 CostComplexityPruneTool.cxx:230 CostComplexityPruneTool.cxx:231 CostComplexityPruneTool.cxx:232 CostComplexityPruneTool.cxx:233 CostComplexityPruneTool.cxx:234 CostComplexityPruneTool.cxx:235 CostComplexityPruneTool.cxx:236 CostComplexityPruneTool.cxx:237 CostComplexityPruneTool.cxx:238 CostComplexityPruneTool.cxx:239 CostComplexityPruneTool.cxx:240 CostComplexityPruneTool.cxx:241 CostComplexityPruneTool.cxx:242 CostComplexityPruneTool.cxx:243 CostComplexityPruneTool.cxx:244 CostComplexityPruneTool.cxx:245 CostComplexityPruneTool.cxx:246 CostComplexityPruneTool.cxx:247 CostComplexityPruneTool.cxx:248 CostComplexityPruneTool.cxx:249 CostComplexityPruneTool.cxx:250 CostComplexityPruneTool.cxx:251 
CostComplexityPruneTool.cxx:252 CostComplexityPruneTool.cxx:253 CostComplexityPruneTool.cxx:254 CostComplexityPruneTool.cxx:255 CostComplexityPruneTool.cxx:256 CostComplexityPruneTool.cxx:257 CostComplexityPruneTool.cxx:258 CostComplexityPruneTool.cxx:259 CostComplexityPruneTool.cxx:260 CostComplexityPruneTool.cxx:261 CostComplexityPruneTool.cxx:262 CostComplexityPruneTool.cxx:263 CostComplexityPruneTool.cxx:264 CostComplexityPruneTool.cxx:265 CostComplexityPruneTool.cxx:266 CostComplexityPruneTool.cxx:267 CostComplexityPruneTool.cxx:268 CostComplexityPruneTool.cxx:269 CostComplexityPruneTool.cxx:270 CostComplexityPruneTool.cxx:271 CostComplexityPruneTool.cxx:272 CostComplexityPruneTool.cxx:273 CostComplexityPruneTool.cxx:274 CostComplexityPruneTool.cxx:275 CostComplexityPruneTool.cxx:276 CostComplexityPruneTool.cxx:277 CostComplexityPruneTool.cxx:278 CostComplexityPruneTool.cxx:279 CostComplexityPruneTool.cxx:280 CostComplexityPruneTool.cxx:281 CostComplexityPruneTool.cxx:282 CostComplexityPruneTool.cxx:283 CostComplexityPruneTool.cxx:284 CostComplexityPruneTool.cxx:285 CostComplexityPruneTool.cxx:286 CostComplexityPruneTool.cxx:287 CostComplexityPruneTool.cxx:288 CostComplexityPruneTool.cxx:289 CostComplexityPruneTool.cxx:290 CostComplexityPruneTool.cxx:291 CostComplexityPruneTool.cxx:292 CostComplexityPruneTool.cxx:293 CostComplexityPruneTool.cxx:294 CostComplexityPruneTool.cxx:295 CostComplexityPruneTool.cxx:296 CostComplexityPruneTool.cxx:297 CostComplexityPruneTool.cxx:298 CostComplexityPruneTool.cxx:299 CostComplexityPruneTool.cxx:300 CostComplexityPruneTool.cxx:301 CostComplexityPruneTool.cxx:302 CostComplexityPruneTool.cxx:303 CostComplexityPruneTool.cxx:304 CostComplexityPruneTool.cxx:305 CostComplexityPruneTool.cxx:306 CostComplexityPruneTool.cxx:307 CostComplexityPruneTool.cxx:308 CostComplexityPruneTool.cxx:309 CostComplexityPruneTool.cxx:310 CostComplexityPruneTool.cxx:311 CostComplexityPruneTool.cxx:312 CostComplexityPruneTool.cxx:313 
CostComplexityPruneTool.cxx:314 CostComplexityPruneTool.cxx:315 CostComplexityPruneTool.cxx:316 CostComplexityPruneTool.cxx:317 CostComplexityPruneTool.cxx:318 CostComplexityPruneTool.cxx:319 CostComplexityPruneTool.cxx:320 CostComplexityPruneTool.cxx:321 CostComplexityPruneTool.cxx:322 CostComplexityPruneTool.cxx:323 CostComplexityPruneTool.cxx:324 CostComplexityPruneTool.cxx:325 CostComplexityPruneTool.cxx:326 CostComplexityPruneTool.cxx:327 CostComplexityPruneTool.cxx:328 CostComplexityPruneTool.cxx:329 CostComplexityPruneTool.cxx:330 CostComplexityPruneTool.cxx:331 CostComplexityPruneTool.cxx:332 CostComplexityPruneTool.cxx:333 CostComplexityPruneTool.cxx:334 CostComplexityPruneTool.cxx:335 CostComplexityPruneTool.cxx:336 CostComplexityPruneTool.cxx:337 CostComplexityPruneTool.cxx:338 CostComplexityPruneTool.cxx:339