#include "TMVA/CCPruner.h"
#include "TMVA/SeparationBase.h"
#include "TMVA/GiniIndex.h"
#include "TMVA/MisClassificationError.h"
#include "TMVA/CCTreeWrapper.h"
#include "TMVA/DataSet.h"
#include <iostream>
#include <fstream>
#include <limits>
#include <math.h>
using namespace TMVA;
CCPruner::CCPruner( DecisionTree* t_max, const EventList* validationSample,
SeparationBase* qualityIndex ) :
fAlpha(-1.0),
fValidationSample(validationSample),
fValidationDataSet(NULL),
fOptimalK(-1)
{
fTree = t_max;
if(qualityIndex == NULL) {
fOwnQIndex = true;
fQualityIndex = new MisClassificationError();
}
else {
fOwnQIndex = false;
fQualityIndex = qualityIndex;
}
fDebug = kTRUE;
}
CCPruner::CCPruner( DecisionTree* t_max, const DataSet* validationSample,
SeparationBase* qualityIndex ) :
fAlpha(-1.0),
fValidationSample(NULL),
fValidationDataSet(validationSample),
fOptimalK(-1)
{
fTree = t_max;
if(qualityIndex == NULL) {
fOwnQIndex = true;
fQualityIndex = new MisClassificationError();
}
else {
fOwnQIndex = false;
fQualityIndex = qualityIndex;
}
fDebug = kTRUE;
}
CCPruner::~CCPruner( )
{
if(fOwnQIndex) delete fQualityIndex;
}
void CCPruner::Optimize( )
{
Bool_t HaveStopCondition = fAlpha > 0;
CCTreeWrapper* dTWrapper = new CCTreeWrapper(fTree, fQualityIndex);
Int_t k = 0;
Double_t epsilon = std::numeric_limits<double>::epsilon();
Double_t alpha = -1.0e10;
std::ofstream outfile;
if (fDebug) outfile.open("costcomplexity.log");
if(!HaveStopCondition && (fValidationSample == NULL && fValidationDataSet == NULL) ) {
if (fDebug) outfile << "ERROR: no validation sample, so cannot optimize pruning!" << std::endl;
delete dTWrapper;
if (fDebug) outfile.close();
return;
}
CCTreeWrapper::CCTreeNode* R = dTWrapper->GetRoot();
while(R->GetNLeafDaughters() > 1) {
if(R->GetMinAlphaC() > alpha)
alpha = R->GetMinAlphaC();
if(HaveStopCondition && alpha > fAlpha) break;
CCTreeWrapper::CCTreeNode* t = R;
while(t->GetMinAlphaC() < t->GetAlphaC()) {
if(fabs(t->GetMinAlphaC() - t->GetLeftDaughter()->GetMinAlphaC())/fabs(t->GetMinAlphaC()) < epsilon)
t = t->GetLeftDaughter();
else
t = t->GetRightDaughter();
}
if( t == R ) {
if (fDebug) outfile << std::endl << "Caught trying to prune the root node!" << std::endl;
break;
}
CCTreeWrapper::CCTreeNode* n = t;
if (fDebug){
outfile << "===========================" << std::endl
<< "Pruning branch listed below" << std::endl
<< "===========================" << std::endl;
t->PrintRec( outfile );
}
if (!(t->GetLeftDaughter()) && !(t->GetRightDaughter()) ) {
break;
}
dTWrapper->PruneNode(t);
while(t != R) {
t = t->GetMother();
t->SetNLeafDaughters(t->GetLeftDaughter()->GetNLeafDaughters() + t->GetRightDaughter()->GetNLeafDaughters());
t->SetResubstitutionEstimate(t->GetLeftDaughter()->GetResubstitutionEstimate() +
t->GetRightDaughter()->GetResubstitutionEstimate());
t->SetAlphaC((t->GetNodeResubstitutionEstimate() - t->GetResubstitutionEstimate())/(t->GetNLeafDaughters() - 1));
t->SetMinAlphaC(std::min(t->GetAlphaC(), std::min(t->GetLeftDaughter()->GetMinAlphaC(),
t->GetRightDaughter()->GetMinAlphaC())));
}
k += 1;
if(!HaveStopCondition) {
Double_t q;
if (fValidationDataSet != NULL) q = dTWrapper->TestTreeQuality(fValidationDataSet);
else q = dTWrapper->TestTreeQuality(fValidationSample);
fQualityIndexList.push_back(q);
}
else {
fQualityIndexList.push_back(1.0);
}
fPruneSequence.push_back(n->GetDTNode());
fPruneStrengthList.push_back(alpha);
}
Double_t qmax = -1.0e6;
if(!HaveStopCondition) {
for(UInt_t i = 0; i < fQualityIndexList.size(); i++) {
if(fQualityIndexList[i] > qmax) {
qmax = fQualityIndexList[i];
k = i;
}
}
fOptimalK = k;
}
else {
fOptimalK = fPruneSequence.size() - 1;
}
if (fDebug){
outfile << std::endl << "************ Summary **************" << std::endl
<< "Number of trees in the sequence: " << fPruneSequence.size() << std::endl;
outfile << "Pruning strength parameters: [";
for(UInt_t i = 0; i < fPruneStrengthList.size()-1; i++)
outfile << fPruneStrengthList[i] << ", ";
outfile << fPruneStrengthList[fPruneStrengthList.size()-1] << "]" << std::endl;
outfile << "Misclassification rates: [";
for(UInt_t i = 0; i < fQualityIndexList.size()-1; i++)
outfile << fQualityIndexList[i] << ", ";
outfile << fQualityIndexList[fQualityIndexList.size()-1] << "]" << std::endl;
outfile << "Optimal index: " << fOptimalK+1 << std::endl;
outfile.close();
}
delete dTWrapper;
}
std::vector<DecisionTreeNode*> CCPruner::GetOptimalPruneSequence( ) const
{
std::vector<DecisionTreeNode*> optimalSequence;
if( fOptimalK >= 0 ) {
for( Int_t i = 0; i < fOptimalK; i++ ) {
optimalSequence.push_back(fPruneSequence[i]);
}
}
return optimalSequence;
}