ROOT logo
/**********************************************************************************
 * Project: TMVA - a Root-integrated toolkit for multivariate data analysis       *
 * Package: TMVA                                                                  *
 * Class  : CCTreeWrapper                                                         *
 * Web    : http://tmva.sourceforge.net                                           *
 *                                                                                *
 * Description: a light wrapper of a decision tree, used to perform cost          *
 *              complexity pruning "in-place"                                     *
 *                                                                                *  
 * Author: Doug Schouten (dschoute@sfu.ca)                                        *
 *                                                                                *
 *                                                                                *
 * Copyright (c) 2007:                                                            *
 *      CERN, Switzerland                                                         *
 *      MPI-K Heidelberg, Germany                                                 *
 *      U. of Texas at Austin, USA                                                *
 *                                                                                *
 * Redistribution and use in source and binary forms, with or without             *
 * modification, are permitted according to the terms listed in LICENSE           *
 * (http://tmva.sourceforge.net/LICENSE)                                          *
 **********************************************************************************/

#include "TMVA/CCTreeWrapper.h"

#include <iostream>
#include <limits>

using namespace TMVA;

//_______________________________________________________________________
TMVA::CCTreeWrapper::CCTreeNode::CCTreeNode( DecisionTreeNode* n ) :
   Node(),
   fNLeafDaughters(0),
   fNodeResubstitutionEstimate(-1.0),
   fResubstitutionEstimate(-1.0),
   fAlphaC(-1.0),
   fMinAlphaC(-1.0),
   fDTNode(n)
{
   // Wrap the decision tree node n and, recursively, the subtree below it.
   // The leaf count starts at 0 and the estimates/alphas at -1.0.
   // A NULL node, or a node that lacks either daughter, gets no wrapper
   // daughters (i.e. it is wrapped as a leaf).
   if ( n == NULL ) return;
   if ( n->GetRight() == NULL || n->GetLeft() == NULL ) return;

   // right subtree first, then left; each new wrapper is linked back to this node
   SetRight( new CCTreeNode( static_cast<DecisionTreeNode*>( n->GetRight() ) ) );
   GetRight()->SetParent(this);

   SetLeft( new CCTreeNode( static_cast<DecisionTreeNode*>( n->GetLeft() ) ) );
   GetLeft()->SetParent(this);
}

//_______________________________________________________________________
TMVA::CCTreeWrapper::CCTreeNode::~CCTreeNode() {
   // Destructor: releases the left daughter and then the right daughter when
   // present; their own destructors take care of the levels below them.

   if (GetLeft() != NULL) {
      CCTreeNode* leftChild = GetLeftDaughter();
      delete leftChild;
   }
   if (GetRight() != NULL) {
      CCTreeNode* rightChild = GetRightDaughter();
      delete rightChild;
   }
}

//_______________________________________________________________________
Bool_t TMVA::CCTreeWrapper::CCTreeNode::ReadDataRecord( std::istream& in, UInt_t /* tmva_Version_Code */ ) {
   // Initialize the node from a text record laid out the way Print() writes it:
   // one header token followed by five "<label> <value>" pairs, in the order
   // |~T_t|, R(t), R(T_t), g(t), G(t). Header and labels are read and discarded.
   //
   //   in : stream positioned at the start of the record
   //   (the second argument, the TMVA version code, is unused)
   //
   // Returns true when every field could be extracted, false when the stream
   // ran out of data or a token could not be parsed. On failure the members
   // may hold partially updated values.
   //
   // NOTE(review): leaf nodes carry infinite g(t)/G(t) (see InitTree); confirm
   // that the text form written for them can be parsed back by operator>>.

   std::string header, label;
   in >> header;
   in >> label >> fNLeafDaughters;
   in >> label >> fNodeResubstitutionEstimate;
   in >> label >> fResubstitutionEstimate;
   in >> label >> fAlphaC;
   in >> label >> fMinAlphaC;

   // report extraction problems instead of unconditionally claiming success
   return !in.fail();
}

//_______________________________________________________________________
void TMVA::CCTreeWrapper::CCTreeNode::Print( std::ostream& os ) const {
   // Write the node's pruning quantities to os: a dashed separator line, then
   // one "<label> <value>" line each for |~T_t|, R(t), R(T_t), g(t) and G(t).
   // This is the layout ReadDataRecord consumes.

   os << "----------------------" << std::endl;
   os << "|~T_t| " << fNLeafDaughters << std::endl;
   os << "R(t): " << fNodeResubstitutionEstimate << std::endl;
   os << "R(T_t): " << fResubstitutionEstimate << std::endl;
   os << "g(t): " << fAlphaC << std::endl;
   os << "G(t): " << fMinAlphaC << std::endl;
}

//_______________________________________________________________________
void TMVA::CCTreeWrapper::CCTreeNode::PrintRec( std::ostream& os ) const {
   // recursive printout of the node and its daughters 

   this->Print(os);
   if(this->GetLeft() != NULL && this->GetRight() != NULL) {
      this->GetLeft()->PrintRec(os);
      this->GetRight()->PrintRec(os);
   }
}

//_______________________________________________________________________
TMVA::CCTreeWrapper::CCTreeWrapper( DecisionTree* T, SeparationBase* qualityIndex ) :
   fRoot(NULL)
{
   // Constructor: remembers the decision tree T and the separation criterion,
   // builds a wrapper node for every node reachable from T's root, and then
   // computes the pruning quantities for the whole wrapper tree (InitTree).
   // T must not be NULL; it is dereferenced here.

   fQualityIndex = qualityIndex;
   fDTParent = T;

   DecisionTreeNode* dtRoot = dynamic_cast<DecisionTreeNode*>( T->GetRoot() );
   fRoot = new CCTreeNode( dtRoot );
   InitTree( fRoot );
}
  
//_______________________________________________________________________
TMVA::CCTreeWrapper::~CCTreeWrapper( ) {
   // Destructor: deletes the root wrapper node. ~CCTreeNode deletes each
   // node's daughters in turn, so the whole wrapper tree is released.
   // fDTParent and fQualityIndex are not deleted here.

   delete fRoot; 
}  

//_______________________________________________________________________
void TMVA::CCTreeWrapper::InitTree( CCTreeNode* t )
{
   // Fill in the cost-complexity quantities of node t and, recursively, of
   // every node below it:
   //   R(t)    (s+b) * separation index(s,b) of the node itself
   //   |~T_t|  number of leaves in the subtree rooted at t
   //   R(T_t)  sum of R over the leaves of that subtree
   //   g(t)    (R(t) - R(T_t)) / (|~T_t| - 1); infinity for a leaf
   //   G(t)    min( g(t), G(left), G(right) ); infinity for a leaf
   // s and b are the signal/background event counts stored in the wrapped
   // decision tree node.
   const Double_t nSig = t->GetDTNode()->GetNSigEvents();
   const Double_t nBkg = t->GetDTNode()->GetNBkgEvents();

   // R(t) = Gini(t) or MisclassificationError(t), etc., scaled by the node population
   t->SetNodeResubstitutionEstimate( (nSig + nBkg) * fQualityIndex->GetSeparationIndex(nSig, nBkg) );

   if (t->GetLeft() == NULL || t->GetRight() == NULL) {
      // terminal node: it is its own (single) leaf and can never be pruned
      const Double_t inf = std::numeric_limits<double>::infinity( );
      t->SetNLeafDaughters(1);
      t->SetResubstitutionEstimate( (nSig + nBkg) * fQualityIndex->GetSeparationIndex(nSig, nBkg) );
      t->SetAlphaC(inf);
      t->SetMinAlphaC(inf);
      return;
   }

   // interior node: initialise both subtrees first, then combine their results
   CCTreeNode* left  = t->GetLeftDaughter();
   CCTreeNode* right = t->GetRightDaughter();
   InitTree(left);
   InitTree(right);

   // |~T_t|
   t->SetNLeafDaughters( left->GetNLeafDaughters() + right->GetNLeafDaughters() );

   // R(T_t) = sum over the leaves t' of T_t of R(t')
   t->SetResubstitutionEstimate( left->GetResubstitutionEstimate() + right->GetResubstitutionEstimate() );

   // g(t)
   const Double_t gain = t->GetNodeResubstitutionEstimate() - t->GetResubstitutionEstimate();
   t->SetAlphaC( gain / (t->GetNLeafDaughters() - 1) );

   // G(t) = min( g(t), G(l(t)), G(r(t)) )
   const Double_t minBelow = std::min( left->GetMinAlphaC(), right->GetMinAlphaC() );
   t->SetMinAlphaC( std::min( t->GetAlphaC(), minBelow ) );
}

//_______________________________________________________________________
void TMVA::CCTreeWrapper::PruneNode( CCTreeNode* t )
{
   // Collapse the branch rooted at t: t becomes a leaf of the wrapper tree
   // (one leaf, R(T_t) = R(t), g(t) = G(t) = infinity) and both wrapper
   // subtrees below it are deleted. Only t itself is updated here; the
   // quantities stored in its ancestors are left as they are.
   // Calling this on a node that is already a leaf only prints an error.

   if (t->GetLeft() == NULL || t->GetRight() == NULL) {
      std::cout << " ERROR in CCTreeWrapper::PruneNode: you try to prune a leaf node.. that does not make sense " << std::endl;
      return;
   }

   CCTreeNode* leftBranch  = t->GetLeftDaughter();
   CCTreeNode* rightBranch = t->GetRightDaughter();

   const Double_t inf = std::numeric_limits<double>::infinity( );
   t->SetNLeafDaughters( 1 );
   t->SetResubstitutionEstimate( t->GetNodeResubstitutionEstimate() );
   t->SetAlphaC( inf );
   t->SetMinAlphaC( inf );

   delete leftBranch;
   delete rightBranch;
   t->SetLeft(NULL);
   t->SetRight(NULL);
}

//_______________________________________________________________________
Double_t TMVA::CCTreeWrapper::TestTreeQuality( const EventList* validationSample )
{
   // Evaluate the (pruned) tree on a validation sample given as an EventList.
   // An event is tagged as signal when CheckEvent() exceeds the node purity
   // limit of the wrapped decision tree; events of class 0 count as true signal.
   // Returns the weighted fraction of correctly classified events,
   // sum(w_correct) / sum(w_all). For an empty sample (or zero total weight)
   // the result is 0/0.

   Double_t wCorrect = 0.0;
   Double_t wWrong   = 0.0;

   for (UInt_t i = 0; i < validationSample->size(); ++i) {
      const Event& ev = *(*validationSample)[i];

      const Bool_t taggedSignal = CheckEvent(ev) > fDTParent->GetNodePurityLimit();
      const Bool_t trulySignal  = (ev.GetClass() == 0);

      if (taggedSignal == trulySignal) wCorrect += ev.GetWeight();
      else                             wWrong   += ev.GetWeight();
   }

   return wCorrect / (wCorrect + wWrong);
}

//_______________________________________________________________________
Double_t TMVA::CCTreeWrapper::TestTreeQuality( const DataSet* validationSample )
{
   // Evaluate the (pruned) tree on the validation part of a DataSet. The data
   // set is first switched to Types::kValidation, then every event is run
   // through CheckEvent(). An event is tagged as signal when the tree output
   // exceeds the node purity limit of the wrapped decision tree; events of
   // class 0 count as true signal.
   // Returns the weighted fraction of correctly classified events,
   // sum(w_correct) / sum(w_all). For an empty sample (or zero total weight)
   // the result is 0/0.

   validationSample->SetCurrentType(Types::kValidation);

   Double_t wCorrect = 0.0;
   Double_t wWrong   = 0.0;

   for (Long64_t i = 0; i < validationSample->GetNEvents(); ++i) {
      const Event* ev = validationSample->GetEvent(i);

      const Bool_t taggedSignal = CheckEvent(*ev) > fDTParent->GetNodePurityLimit();
      const Bool_t trulySignal  = (ev->GetClass() == 0);

      if (taggedSignal == trulySignal) wCorrect += ev->GetWeight();
      else                             wWrong   += ev->GetWeight();
   }

   return wCorrect / (wCorrect + wWrong);
}

//_______________________________________________________________________
Double_t TMVA::CCTreeWrapper::CheckEvent( const TMVA::Event & e, Bool_t useYesNoLeaf )
{
   // Return the tree output for event e.
   // The descent follows the wrapper tree, not the daughters of the wrapped
   // decision tree nodes: it stops at the first wrapper node that lacks a
   // daughter, so branches removed with PruneNode are not entered. At every
   // interior node the wrapped decision tree node decides left/right.
   //
   //   useYesNoLeaf = true  : +1.0 if the purity of the final node exceeds the
   //                          node purity limit of the wrapped tree, else -1.0
   //   useYesNoLeaf = false : the purity of the final node

   CCTreeNode* node = fRoot;
   const DecisionTreeNode* dtNode = fRoot->GetDTNode();

   while (node->GetLeft() != NULL && node->GetRight() != NULL) {
      node   = dtNode->GoesRight(e) ? node->GetRightDaughter() : node->GetLeftDaughter();
      dtNode = node->GetDTNode();
   }

   const Double_t purity = dtNode->GetPurity();
   if (!useYesNoLeaf) return purity;
   return (purity > fDTParent->GetNodePurityLimit()) ? 1.0 : -1.0;
}

//_______________________________________________________________________
// No-op: the body is empty and the node argument is ignored.
// NOTE(review): presumably present only to satisfy the Node interface - confirm.
void TMVA::CCTreeWrapper::CCTreeNode::AddAttributesToNode( void* /*node*/ ) const
{}

//_______________________________________________________________________
// No-op: the body is empty and nothing is written to the stream argument.
// NOTE(review): presumably present only to satisfy the Node interface - confirm.
void TMVA::CCTreeWrapper::CCTreeNode::AddContentToNode( std::stringstream& /*s*/ ) const
{}

//_______________________________________________________________________
// No-op: the body is empty; both arguments are ignored and no member is modified.
// NOTE(review): presumably present only to satisfy the Node interface - confirm.
void TMVA::CCTreeWrapper::CCTreeNode::ReadAttributes( void* /*node*/, UInt_t /* tmva_Version_Code */  )
{}

//_______________________________________________________________________
// No-op: the body is empty; nothing is read from the stream argument.
// NOTE(review): presumably present only to satisfy the Node interface - confirm.
void TMVA::CCTreeWrapper::CCTreeNode::ReadContent( std::stringstream& /*s*/ )
{}
 CCTreeWrapper.cxx:1
 CCTreeWrapper.cxx:2
 CCTreeWrapper.cxx:3
 CCTreeWrapper.cxx:4
 CCTreeWrapper.cxx:5
 CCTreeWrapper.cxx:6
 CCTreeWrapper.cxx:7
 CCTreeWrapper.cxx:8
 CCTreeWrapper.cxx:9
 CCTreeWrapper.cxx:10
 CCTreeWrapper.cxx:11
 CCTreeWrapper.cxx:12
 CCTreeWrapper.cxx:13
 CCTreeWrapper.cxx:14
 CCTreeWrapper.cxx:15
 CCTreeWrapper.cxx:16
 CCTreeWrapper.cxx:17
 CCTreeWrapper.cxx:18
 CCTreeWrapper.cxx:19
 CCTreeWrapper.cxx:20
 CCTreeWrapper.cxx:21
 CCTreeWrapper.cxx:22
 CCTreeWrapper.cxx:23
 CCTreeWrapper.cxx:24
 CCTreeWrapper.cxx:25
 CCTreeWrapper.cxx:26
 CCTreeWrapper.cxx:27
 CCTreeWrapper.cxx:28
 CCTreeWrapper.cxx:29
 CCTreeWrapper.cxx:30
 CCTreeWrapper.cxx:31
 CCTreeWrapper.cxx:32
 CCTreeWrapper.cxx:33
 CCTreeWrapper.cxx:34
 CCTreeWrapper.cxx:35
 CCTreeWrapper.cxx:36
 CCTreeWrapper.cxx:37
 CCTreeWrapper.cxx:38
 CCTreeWrapper.cxx:39
 CCTreeWrapper.cxx:40
 CCTreeWrapper.cxx:41
 CCTreeWrapper.cxx:42
 CCTreeWrapper.cxx:43
 CCTreeWrapper.cxx:44
 CCTreeWrapper.cxx:45
 CCTreeWrapper.cxx:46
 CCTreeWrapper.cxx:47
 CCTreeWrapper.cxx:48
 CCTreeWrapper.cxx:49
 CCTreeWrapper.cxx:50
 CCTreeWrapper.cxx:51
 CCTreeWrapper.cxx:52
 CCTreeWrapper.cxx:53
 CCTreeWrapper.cxx:54
 CCTreeWrapper.cxx:55
 CCTreeWrapper.cxx:56
 CCTreeWrapper.cxx:57
 CCTreeWrapper.cxx:58
 CCTreeWrapper.cxx:59
 CCTreeWrapper.cxx:60
 CCTreeWrapper.cxx:61
 CCTreeWrapper.cxx:62
 CCTreeWrapper.cxx:63
 CCTreeWrapper.cxx:64
 CCTreeWrapper.cxx:65
 CCTreeWrapper.cxx:66
 CCTreeWrapper.cxx:67
 CCTreeWrapper.cxx:68
 CCTreeWrapper.cxx:69
 CCTreeWrapper.cxx:70
 CCTreeWrapper.cxx:71
 CCTreeWrapper.cxx:72
 CCTreeWrapper.cxx:73
 CCTreeWrapper.cxx:74
 CCTreeWrapper.cxx:75
 CCTreeWrapper.cxx:76
 CCTreeWrapper.cxx:77
 CCTreeWrapper.cxx:78
 CCTreeWrapper.cxx:79
 CCTreeWrapper.cxx:80
 CCTreeWrapper.cxx:81
 CCTreeWrapper.cxx:82
 CCTreeWrapper.cxx:83
 CCTreeWrapper.cxx:84
 CCTreeWrapper.cxx:85
 CCTreeWrapper.cxx:86
 CCTreeWrapper.cxx:87
 CCTreeWrapper.cxx:88
 CCTreeWrapper.cxx:89
 CCTreeWrapper.cxx:90
 CCTreeWrapper.cxx:91
 CCTreeWrapper.cxx:92
 CCTreeWrapper.cxx:93
 CCTreeWrapper.cxx:94
 CCTreeWrapper.cxx:95
 CCTreeWrapper.cxx:96
 CCTreeWrapper.cxx:97
 CCTreeWrapper.cxx:98
 CCTreeWrapper.cxx:99
 CCTreeWrapper.cxx:100
 CCTreeWrapper.cxx:101
 CCTreeWrapper.cxx:102
 CCTreeWrapper.cxx:103
 CCTreeWrapper.cxx:104
 CCTreeWrapper.cxx:105
 CCTreeWrapper.cxx:106
 CCTreeWrapper.cxx:107
 CCTreeWrapper.cxx:108
 CCTreeWrapper.cxx:109
 CCTreeWrapper.cxx:110
 CCTreeWrapper.cxx:111
 CCTreeWrapper.cxx:112
 CCTreeWrapper.cxx:113
 CCTreeWrapper.cxx:114
 CCTreeWrapper.cxx:115
 CCTreeWrapper.cxx:116
 CCTreeWrapper.cxx:117
 CCTreeWrapper.cxx:118
 CCTreeWrapper.cxx:119
 CCTreeWrapper.cxx:120
 CCTreeWrapper.cxx:121
 CCTreeWrapper.cxx:122
 CCTreeWrapper.cxx:123
 CCTreeWrapper.cxx:124
 CCTreeWrapper.cxx:125
 CCTreeWrapper.cxx:126
 CCTreeWrapper.cxx:127
 CCTreeWrapper.cxx:128
 CCTreeWrapper.cxx:129
 CCTreeWrapper.cxx:130
 CCTreeWrapper.cxx:131
 CCTreeWrapper.cxx:132
 CCTreeWrapper.cxx:133
 CCTreeWrapper.cxx:134
 CCTreeWrapper.cxx:135
 CCTreeWrapper.cxx:136
 CCTreeWrapper.cxx:137
 CCTreeWrapper.cxx:138
 CCTreeWrapper.cxx:139
 CCTreeWrapper.cxx:140
 CCTreeWrapper.cxx:141
 CCTreeWrapper.cxx:142
 CCTreeWrapper.cxx:143
 CCTreeWrapper.cxx:144
 CCTreeWrapper.cxx:145
 CCTreeWrapper.cxx:146
 CCTreeWrapper.cxx:147
 CCTreeWrapper.cxx:148
 CCTreeWrapper.cxx:149
 CCTreeWrapper.cxx:150
 CCTreeWrapper.cxx:151
 CCTreeWrapper.cxx:152
 CCTreeWrapper.cxx:153
 CCTreeWrapper.cxx:154
 CCTreeWrapper.cxx:155
 CCTreeWrapper.cxx:156
 CCTreeWrapper.cxx:157
 CCTreeWrapper.cxx:158
 CCTreeWrapper.cxx:159
 CCTreeWrapper.cxx:160
 CCTreeWrapper.cxx:161
 CCTreeWrapper.cxx:162
 CCTreeWrapper.cxx:163
 CCTreeWrapper.cxx:164
 CCTreeWrapper.cxx:165
 CCTreeWrapper.cxx:166
 CCTreeWrapper.cxx:167
 CCTreeWrapper.cxx:168
 CCTreeWrapper.cxx:169
 CCTreeWrapper.cxx:170
 CCTreeWrapper.cxx:171
 CCTreeWrapper.cxx:172
 CCTreeWrapper.cxx:173
 CCTreeWrapper.cxx:174
 CCTreeWrapper.cxx:175
 CCTreeWrapper.cxx:176
 CCTreeWrapper.cxx:177
 CCTreeWrapper.cxx:178
 CCTreeWrapper.cxx:179
 CCTreeWrapper.cxx:180
 CCTreeWrapper.cxx:181
 CCTreeWrapper.cxx:182
 CCTreeWrapper.cxx:183
 CCTreeWrapper.cxx:184
 CCTreeWrapper.cxx:185
 CCTreeWrapper.cxx:186
 CCTreeWrapper.cxx:187
 CCTreeWrapper.cxx:188
 CCTreeWrapper.cxx:189
 CCTreeWrapper.cxx:190
 CCTreeWrapper.cxx:191
 CCTreeWrapper.cxx:192
 CCTreeWrapper.cxx:193
 CCTreeWrapper.cxx:194
 CCTreeWrapper.cxx:195
 CCTreeWrapper.cxx:196
 CCTreeWrapper.cxx:197
 CCTreeWrapper.cxx:198
 CCTreeWrapper.cxx:199
 CCTreeWrapper.cxx:200
 CCTreeWrapper.cxx:201
 CCTreeWrapper.cxx:202
 CCTreeWrapper.cxx:203
 CCTreeWrapper.cxx:204
 CCTreeWrapper.cxx:205
 CCTreeWrapper.cxx:206
 CCTreeWrapper.cxx:207
 CCTreeWrapper.cxx:208
 CCTreeWrapper.cxx:209
 CCTreeWrapper.cxx:210
 CCTreeWrapper.cxx:211
 CCTreeWrapper.cxx:212
 CCTreeWrapper.cxx:213
 CCTreeWrapper.cxx:214
 CCTreeWrapper.cxx:215
 CCTreeWrapper.cxx:216
 CCTreeWrapper.cxx:217
 CCTreeWrapper.cxx:218
 CCTreeWrapper.cxx:219
 CCTreeWrapper.cxx:220
 CCTreeWrapper.cxx:221
 CCTreeWrapper.cxx:222
 CCTreeWrapper.cxx:223
 CCTreeWrapper.cxx:224
 CCTreeWrapper.cxx:225
 CCTreeWrapper.cxx:226
 CCTreeWrapper.cxx:227
 CCTreeWrapper.cxx:228
 CCTreeWrapper.cxx:229
 CCTreeWrapper.cxx:230
 CCTreeWrapper.cxx:231
 CCTreeWrapper.cxx:232
 CCTreeWrapper.cxx:233
 CCTreeWrapper.cxx:234
 CCTreeWrapper.cxx:235
 CCTreeWrapper.cxx:236
 CCTreeWrapper.cxx:237
 CCTreeWrapper.cxx:238
 CCTreeWrapper.cxx:239
 CCTreeWrapper.cxx:240
 CCTreeWrapper.cxx:241
 CCTreeWrapper.cxx:242
 CCTreeWrapper.cxx:243
 CCTreeWrapper.cxx:244
 CCTreeWrapper.cxx:245
 CCTreeWrapper.cxx:246
 CCTreeWrapper.cxx:247
 CCTreeWrapper.cxx:248
 CCTreeWrapper.cxx:249
 CCTreeWrapper.cxx:250
 CCTreeWrapper.cxx:251
 CCTreeWrapper.cxx:252
 CCTreeWrapper.cxx:253
 CCTreeWrapper.cxx:254
 CCTreeWrapper.cxx:255
 CCTreeWrapper.cxx:256