/**********************************************************************************
 * Project: TMVA - a Root-integrated toolkit for multivariate data analysis       *
 * Package: TMVA                                                                  *
 * Class  : CCPruner                                                              *
 * Web    : http://tmva.sourceforge.net                                           *
 *                                                                                *
 * Description: Cost Complexity Pruning                                           *
 * 
 * Author: Doug Schouten (dschoute@sfu.ca)
 *
 *                                                                                *
 * Copyright (c) 2007:                                                            *
 *      CERN, Switzerland                                                         *
 *      MPI-K Heidelberg, Germany                                                 *
 *      U. of Texas at Austin, USA                                                *
 *                                                                                *
 * Redistribution and use in source and binary forms, with or without             *
 * modification, are permitted according to the terms listed in LICENSE           *
 * (http://tmva.sourceforge.net/LICENSE)                                          *
 **********************************************************************************/

#include "TMVA/CCPruner.h"
#include "TMVA/SeparationBase.h"
#include "TMVA/GiniIndex.h"
#include "TMVA/MisClassificationError.h"
#include "TMVA/CCTreeWrapper.h"
#include "TMVA/DataSet.h"

#include <iostream>
#include <fstream>
#include <limits>
#include <math.h>

 using namespace TMVA;

//_______________________________________________________________________
CCPruner::CCPruner( DecisionTree* t_max, const EventList* validationSample,
                    SeparationBase* qualityIndex ) :
   fAlpha(-1.0),
   fValidationSample(validationSample),
   fValidationDataSet(NULL),
   fOptimalK(-1)
{
   // Constructor taking the validation events as an EventList.
   //   t_max            : tree to be pruned; stored as fTree, not deleted by ~CCPruner
   //   validationSample : events used by Optimize() to evaluate each pruned tree
   //   qualityIndex     : separation criterion handed to the tree wrapper; when NULL
   //                      a MisClassificationError is allocated here and owned by
   //                      this object (released in ~CCPruner)
   fTree      = t_max;
   fOwnQIndex = (qualityIndex == NULL);
   if (fOwnQIndex) fQualityIndex = new MisClassificationError();
   else            fQualityIndex = qualityIndex;
   fDebug     = kTRUE; // Optimize() writes a trace to "costcomplexity.log" while this is set
}

//_______________________________________________________________________
CCPruner::CCPruner( DecisionTree* t_max, const DataSet* validationSample,
                    SeparationBase* qualityIndex ) :
   fAlpha(-1.0),
   fValidationSample(NULL),
   fValidationDataSet(validationSample),
   fOptimalK(-1)
{
   // Constructor taking the validation events as a DataSet.
   //   t_max            : tree to be pruned; stored as fTree, not deleted by ~CCPruner
   //   validationSample : data set used by Optimize() to evaluate each pruned tree
   //                      (takes precedence over an EventList in Optimize())
   //   qualityIndex     : separation criterion handed to the tree wrapper; when NULL
   //                      a MisClassificationError is allocated here and owned by
   //                      this object (released in ~CCPruner)
   fTree      = t_max;
   fOwnQIndex = (qualityIndex == NULL);
   if (fOwnQIndex) fQualityIndex = new MisClassificationError();
   else            fQualityIndex = qualityIndex;
   fDebug     = kTRUE; // Optimize() writes a trace to "costcomplexity.log" while this is set
}


//_______________________________________________________________________
CCPruner::~CCPruner( )
{
   // Destructor: deletes the quality index only if the constructor allocated it
   // (no index was passed in); an index supplied by the caller stays the caller's.
   // NOTE(review): ownership is a raw pointer plus a flag; unless CCPruner.h
   // disables copying, a copied CCPruner would delete the same index twice -- confirm.
   if (!fOwnQIndex) return;
   delete fQualityIndex;
}

//_______________________________________________________________________
void CCPruner::Optimize( )
{
   // determine the pruning sequence

   Bool_t HaveStopCondition = fAlpha > 0; // keep pruning the tree until reach the limit fAlpha

   // build a wrapper tree to perform work on
   CCTreeWrapper* dTWrapper = new CCTreeWrapper(fTree, fQualityIndex);

   Int_t    k = 0;
   Double_t epsilon = std::numeric_limits<double>::epsilon();
   Double_t alpha = -1.0e10;

   std::ofstream outfile;
   if (fDebug) outfile.open("costcomplexity.log");
   if(!HaveStopCondition && (fValidationSample == NULL && fValidationDataSet == NULL) ) {
      if (fDebug) outfile << "ERROR: no validation sample, so cannot optimize pruning!" << std::endl;
      delete dTWrapper;
      if (fDebug) outfile.close();
      return;
   }

   CCTreeWrapper::CCTreeNode* R = dTWrapper->GetRoot();
   while(R->GetNLeafDaughters() > 1) { // prune upwards to the root node
      if(R->GetMinAlphaC() > alpha) 
         alpha = R->GetMinAlphaC(); // initialize alpha

      if(HaveStopCondition && alpha > fAlpha) break;

      CCTreeWrapper::CCTreeNode* t = R;

      while(t->GetMinAlphaC() < t->GetAlphaC()) { // descend to the weakest link

         if(fabs(t->GetMinAlphaC() - t->GetLeftDaughter()->GetMinAlphaC())/fabs(t->GetMinAlphaC()) < epsilon) 
            t = t->GetLeftDaughter();
         else
            t = t->GetRightDaughter();
      }
    
      if( t == R ) {
         if (fDebug) outfile << std::endl << "Caught trying to prune the root node!" << std::endl;
         break;
      }

      CCTreeWrapper::CCTreeNode* n = t;

      if (fDebug){
         outfile << "===========================" << std::endl
                 << "Pruning branch listed below" << std::endl
                 << "===========================" << std::endl;
         t->PrintRec( outfile );
       
      }
      if (!(t->GetLeftDaughter()) && !(t->GetRightDaughter()) ) {
         break;
      }
      dTWrapper->PruneNode(t); // prune the branch rooted at node t

      while(t != R) { // go back up the (pruned) tree and recalculate R(T), alpha_c
         t = t->GetMother();
         t->SetNLeafDaughters(t->GetLeftDaughter()->GetNLeafDaughters() + t->GetRightDaughter()->GetNLeafDaughters());
         t->SetResubstitutionEstimate(t->GetLeftDaughter()->GetResubstitutionEstimate() + 
                                      t->GetRightDaughter()->GetResubstitutionEstimate());
         t->SetAlphaC((t->GetNodeResubstitutionEstimate() - t->GetResubstitutionEstimate())/(t->GetNLeafDaughters() - 1));
         t->SetMinAlphaC(std::min(t->GetAlphaC(), std::min(t->GetLeftDaughter()->GetMinAlphaC(), 
                                                           t->GetRightDaughter()->GetMinAlphaC())));
      }
      k += 1;
      if(!HaveStopCondition) {
         Double_t q;
         if (fValidationDataSet != NULL) q = dTWrapper->TestTreeQuality(fValidationDataSet);
         else q = dTWrapper->TestTreeQuality(fValidationSample);
         fQualityIndexList.push_back(q);
      }
      else { 
         fQualityIndexList.push_back(1.0);
      }
      fPruneSequence.push_back(n->GetDTNode());
      fPruneStrengthList.push_back(alpha);
   }
  
   Double_t qmax = -1.0e6;
   if(!HaveStopCondition) {
      for(UInt_t i = 0; i < fQualityIndexList.size(); i++) {
         if(fQualityIndexList[i] > qmax) {
            qmax = fQualityIndexList[i];
            k = i;
         }
      }
      fOptimalK = k;
   }
   else {
      fOptimalK = fPruneSequence.size() - 1;
   }

   if (fDebug){
      outfile << std::endl << "************ Summary **************"  << std::endl
              << "Number of trees in the sequence: " << fPruneSequence.size() << std::endl;
     
      outfile << "Pruning strength parameters: [";
      for(UInt_t i = 0; i < fPruneStrengthList.size()-1; i++) 
         outfile << fPruneStrengthList[i] << ", ";
      outfile << fPruneStrengthList[fPruneStrengthList.size()-1] << "]" << std::endl;
     
      outfile << "Misclassification rates: [";
      for(UInt_t i = 0; i < fQualityIndexList.size()-1; i++) 
         outfile << fQualityIndexList[i] << ", ";
      outfile << fQualityIndexList[fQualityIndexList.size()-1] << "]"  << std::endl;
     
      outfile << "Optimal index: " << fOptimalK+1 << std::endl;
      outfile.close();
   }
   delete dTWrapper;
}

//_______________________________________________________________________
std::vector<DecisionTreeNode*> CCPruner::GetOptimalPruneSequence( ) const
{
   // return the prune strength (=alpha) corresponding to the prune sequence
   std::vector<DecisionTreeNode*> optimalSequence;
   if( fOptimalK >= 0 ) {
      for( Int_t i = 0; i < fOptimalK; i++ ) {
         optimalSequence.push_back(fPruneSequence[i]);
      }
   }
   return optimalSequence;
}


 CCPruner.cxx:1
 CCPruner.cxx:2
 CCPruner.cxx:3
 CCPruner.cxx:4
 CCPruner.cxx:5
 CCPruner.cxx:6
 CCPruner.cxx:7
 CCPruner.cxx:8
 CCPruner.cxx:9
 CCPruner.cxx:10
 CCPruner.cxx:11
 CCPruner.cxx:12
 CCPruner.cxx:13
 CCPruner.cxx:14
 CCPruner.cxx:15
 CCPruner.cxx:16
 CCPruner.cxx:17
 CCPruner.cxx:18
 CCPruner.cxx:19
 CCPruner.cxx:20
 CCPruner.cxx:21
 CCPruner.cxx:22
 CCPruner.cxx:23
 CCPruner.cxx:24
 CCPruner.cxx:25
 CCPruner.cxx:26
 CCPruner.cxx:27
 CCPruner.cxx:28
 CCPruner.cxx:29
 CCPruner.cxx:30
 CCPruner.cxx:31
 CCPruner.cxx:32
 CCPruner.cxx:33
 CCPruner.cxx:34
 CCPruner.cxx:35
 CCPruner.cxx:36
 CCPruner.cxx:37
 CCPruner.cxx:38
 CCPruner.cxx:39
 CCPruner.cxx:40
 CCPruner.cxx:41
 CCPruner.cxx:42
 CCPruner.cxx:43
 CCPruner.cxx:44
 CCPruner.cxx:45
 CCPruner.cxx:46
 CCPruner.cxx:47
 CCPruner.cxx:48
 CCPruner.cxx:49
 CCPruner.cxx:50
 CCPruner.cxx:51
 CCPruner.cxx:52
 CCPruner.cxx:53
 CCPruner.cxx:54
 CCPruner.cxx:55
 CCPruner.cxx:56
 CCPruner.cxx:57
 CCPruner.cxx:58
 CCPruner.cxx:59
 CCPruner.cxx:60
 CCPruner.cxx:61
 CCPruner.cxx:62
 CCPruner.cxx:63
 CCPruner.cxx:64
 CCPruner.cxx:65
 CCPruner.cxx:66
 CCPruner.cxx:67
 CCPruner.cxx:68
 CCPruner.cxx:69
 CCPruner.cxx:70
 CCPruner.cxx:71
 CCPruner.cxx:72
 CCPruner.cxx:73
 CCPruner.cxx:74
 CCPruner.cxx:75
 CCPruner.cxx:76
 CCPruner.cxx:77
 CCPruner.cxx:78
 CCPruner.cxx:79
 CCPruner.cxx:80
 CCPruner.cxx:81
 CCPruner.cxx:82
 CCPruner.cxx:83
 CCPruner.cxx:84
 CCPruner.cxx:85
 CCPruner.cxx:86
 CCPruner.cxx:87
 CCPruner.cxx:88
 CCPruner.cxx:89
 CCPruner.cxx:90
 CCPruner.cxx:91
 CCPruner.cxx:92
 CCPruner.cxx:93
 CCPruner.cxx:94
 CCPruner.cxx:95
 CCPruner.cxx:96
 CCPruner.cxx:97
 CCPruner.cxx:98
 CCPruner.cxx:99
 CCPruner.cxx:100
 CCPruner.cxx:101
 CCPruner.cxx:102
 CCPruner.cxx:103
 CCPruner.cxx:104
 CCPruner.cxx:105
 CCPruner.cxx:106
 CCPruner.cxx:107
 CCPruner.cxx:108
 CCPruner.cxx:109
 CCPruner.cxx:110
 CCPruner.cxx:111
 CCPruner.cxx:112
 CCPruner.cxx:113
 CCPruner.cxx:114
 CCPruner.cxx:115
 CCPruner.cxx:116
 CCPruner.cxx:117
 CCPruner.cxx:118
 CCPruner.cxx:119
 CCPruner.cxx:120
 CCPruner.cxx:121
 CCPruner.cxx:122
 CCPruner.cxx:123
 CCPruner.cxx:124
 CCPruner.cxx:125
 CCPruner.cxx:126
 CCPruner.cxx:127
 CCPruner.cxx:128
 CCPruner.cxx:129
 CCPruner.cxx:130
 CCPruner.cxx:131
 CCPruner.cxx:132
 CCPruner.cxx:133
 CCPruner.cxx:134
 CCPruner.cxx:135
 CCPruner.cxx:136
 CCPruner.cxx:137
 CCPruner.cxx:138
 CCPruner.cxx:139
 CCPruner.cxx:140
 CCPruner.cxx:141
 CCPruner.cxx:142
 CCPruner.cxx:143
 CCPruner.cxx:144
 CCPruner.cxx:145
 CCPruner.cxx:146
 CCPruner.cxx:147
 CCPruner.cxx:148
 CCPruner.cxx:149
 CCPruner.cxx:150
 CCPruner.cxx:151
 CCPruner.cxx:152
 CCPruner.cxx:153
 CCPruner.cxx:154
 CCPruner.cxx:155
 CCPruner.cxx:156
 CCPruner.cxx:157
 CCPruner.cxx:158
 CCPruner.cxx:159
 CCPruner.cxx:160
 CCPruner.cxx:161
 CCPruner.cxx:162
 CCPruner.cxx:163
 CCPruner.cxx:164
 CCPruner.cxx:165
 CCPruner.cxx:166
 CCPruner.cxx:167
 CCPruner.cxx:168
 CCPruner.cxx:169
 CCPruner.cxx:170
 CCPruner.cxx:171
 CCPruner.cxx:172
 CCPruner.cxx:173
 CCPruner.cxx:174
 CCPruner.cxx:175
 CCPruner.cxx:176
 CCPruner.cxx:177
 CCPruner.cxx:178
 CCPruner.cxx:179
 CCPruner.cxx:180
 CCPruner.cxx:181
 CCPruner.cxx:182
 CCPruner.cxx:183
 CCPruner.cxx:184
 CCPruner.cxx:185
 CCPruner.cxx:186
 CCPruner.cxx:187
 CCPruner.cxx:188
 CCPruner.cxx:189
 CCPruner.cxx:190
 CCPruner.cxx:191
 CCPruner.cxx:192
 CCPruner.cxx:193
 CCPruner.cxx:194
 CCPruner.cxx:195
 CCPruner.cxx:196
 CCPruner.cxx:197
 CCPruner.cxx:198
 CCPruner.cxx:199
 CCPruner.cxx:200
 CCPruner.cxx:201
 CCPruner.cxx:202
 CCPruner.cxx:203
 CCPruner.cxx:204
 CCPruner.cxx:205
 CCPruner.cxx:206
 CCPruner.cxx:207
 CCPruner.cxx:208
 CCPruner.cxx:209
 CCPruner.cxx:210
 CCPruner.cxx:211
 CCPruner.cxx:212
 CCPruner.cxx:213
 CCPruner.cxx:214
 CCPruner.cxx:215
 CCPruner.cxx:216
 CCPruner.cxx:217