// @(#)root/tmva $Id: TMVA_MethodBDT.cxx,v 1.1 2006/05/08 12:46:31 brun Exp $ 
// Author: Andreas Hoecker, Helge Voss, Kai Voss 

/**********************************************************************************
 * Project: TMVA - a Root-integrated toolkit for multivariate data analysis       *
 * Package: TMVA                                                                  *
 * Class  : TMVA_MethodBDT  (Boosted Decision Trees)                              *
 *                                                                                *
 * Description:                                                                   *
 *      Analysis of Boosted Decision Trees                                        *
 *                                                                                *
 *                                                                                *
 * Authors (alphabetical):                                                        *
 *      Andreas Hoecker <Andreas.Hocker@cern.ch> - CERN, Switzerland              *
 *      Xavier Prudent  <prudent@lapp.in2p3.fr>  - LAPP, France                   *
 *      Helge Voss      <Helge.Voss@cern.ch>     - MPI-KP Heidelberg, Germany     *
 *      Kai Voss        <Kai.Voss@cern.ch>       - U. of Victoria, Canada         *
 *                                                                                *
 * Copyright (c) 2005:                                                            *
 *      CERN, Switzerland,                                                        * 
 *      U. of Victoria, Canada,                                                   * 
 *      MPI-KP Heidelberg, Germany,                                               * 
 *      LAPP, Annecy, France                                                      *
 *                                                                                *
 * Redistribution and use in source and binary forms, with or without             *
 * modification, are permitted according to the terms listed in LICENSE           *
 * (http://mva.sourceforge.net/license.txt)                                       *
 *                                                                                *
 **********************************************************************************/

//_______________________________________________________________________
//                                                                      
// Analysis of Boosted Decision Trees                                   
//                                                                      
//_______________________________________________________________________

#include "TMVA_MethodBDT.h"
#include "TMVA_Tools.h"
#include "TMVA_Timer.h"
#include "Riostream.h"
#include "TRandom.h"
#include <algorithm>
#include "TObjString.h"

#define DEBUG_TMVA_MethodBDT kTRUE

using std::vector;

ClassImp(TMVA_MethodBDT)
 
//_______________________________________________________________________
 TMVA_MethodBDT::TMVA_MethodBDT( TString jobName, vector<TString>* theVariables,  
				TTree* theTree, TString theOption, TDirectory* theTargetDir )
  : TMVA_MethodBase( jobName, theVariables, theTree, theOption, theTargetDir )
{
  InitBDT();

  if (fOptions.Sizeof()<0) {
    cout << "--- " << GetName() << ": using default options= "<< fOptions <<endl;
  }
  cout << "--- "<<GetName() << " options:" << fOptions <<endl;
  fOptions.ToLower();
  TList*  list  = TMVA_Tools::ParseFormatLine( fOptions );
  if (list->GetSize() > 0){
    fNTrees = atoi( ((TObjString*)list->At(0))->GetString() ) ;
  }
  if (list->GetSize() > 1)fBoostType=((TObjString*)list->At(1))->GetString();
  if (list->GetSize() > 2){
    TString sepType=((TObjString*)list->At(2))->GetString();
    if (sepType.Contains("misclassificationerror")) {
      fSepType = new TMVA_MisClassificationError();
    }
    else if (sepType.Contains("giniindex")) {
      fSepType = new TMVA_GiniIndex();
    }
    else if (sepType.Contains("crossentropy")) {
      fSepType = new TMVA_CrossEntropy();
    }
    else if (sepType.Contains("sdivsqrtsplusb")) {
      fSepType = new TMVA_SdivSqrtSplusB();
    }
    else{
      cout <<"--- TMVA_DecisionTree::TrainNode Error!! separation Routine not found\n" << endl;
      cout << sepType <<endl;
      exit(1);
    }

  }
  else{
    cout <<"---" <<GetName() <<": using default GiniIndex as separation criterion"<<endl;
    fSepType = new TMVA_GiniIndex();
  }
  fMethodName = "BDT"+fSepType->GetName();
  fTestvar    = fTestvarPrefix+GetMethodName();

  if (list->GetSize() > 4){
    fNodeMinEvents = atoi( ((TObjString*)list->At(3))->GetString() ) ;
    fNodeMinSepGain = Double_t(atof( ((TObjString*)list->At(4))->GetString() )) ;
  }
  if (list->GetSize() > 5){
    fNCuts = atoi( ((TObjString*)list->At(5))->GetString() ) ;
  }
  if (list->GetSize() > 6){
    fSignalFraction = atof( ((TObjString*)list->At(6))->GetString() ) ;
  }

  cout << "--- " << GetName() << ": Called with "<<fNTrees <<" trees in the forest"<<endl; 

  cout << "--- " << GetName() << ": Booked with options: "<<endl;
  cout << "--- " << GetName() << ": BoostType: "
       << fBoostType << "   nTress "<< fNTrees<<endl;
  cout << "--- " << GetName() << ": separation criteria in Node training: "
       <<fSepType->GetName()<<endl;
  cout << "--- " << GetName() << ": NodeMinEvents: " << fNodeMinEvents << endl
       << "--- " << GetName() << ": NodeMinSepGain: " << fNodeMinSepGain << endl
       << "--- " << GetName() << ": NCuts        : " << fNCuts         << endl;

  if (0 != fTrainingTree) {
    if (Verbose())
      cout << "--- " << GetName() << " called " << endl;
    // fill the STL Vector with the event sample 
    this->InitEventSample();
  }
  else{
    cout << "--- " << GetName() << ": Warning: no training Tree given " <<endl;
    cout << "--- " << GetName() << "  you'll not allowed to cal Train e.t.c..."<<endl;
  }
}

//_______________________________________________________________________
 TMVA_MethodBDT::TMVA_MethodBDT( vector<TString> *theVariables, 
				TString theWeightFile,  
				TDirectory* theTargetDir )
  : TMVA_MethodBase( theVariables, theWeightFile, theTargetDir ) 
{
  InitBDT();
}

//_______________________________________________________________________
 void TMVA_MethodBDT::InitBDT( void )
{
  fMethodName = "BDT";
  fMethod     = TMVA_Types::BDT;
  fNTrees     = 100;
  fBoostType  = "AdaBoost";
  fNodeMinEvents  = 10;
  fNodeMinSepGain = 0.0002;
  fNCuts          = 20;
  fSignalFraction =-1.;     // -1 means scaling the signal fraction in the is switched off, any
                             // value > 0 would scale the number of background events in the 
                             // training tree by the corresponding number
}

//_______________________________________________________________________
 TMVA_MethodBDT::~TMVA_MethodBDT( void )
{
  for (UInt_t i=0; i<fEventSample.size(); i++) delete fEventSample[i];
  for (UInt_t i=0; i<fForest.size(); i++) delete fForest[i];
}


//_______________________________________________________________________
 void TMVA_MethodBDT::InitEventSample( void )
{
  // write all Events from the Tree into a vector of TMVA_Events, that are 
  // more easily manipulated 
  // should never be called without existing trainingTree
  if (0 == fTrainingTree) {
    cout << "--- " << GetName() << ": Error in ::Init(): fTrainingTree is zero pointer"
	 << " --> exit(1)" << endl;
    exit(1);
  }
  Int_t nevents = fTrainingTree->GetEntries();
  for (int ievt=0; ievt<nevents; ievt++){
    fEventSample.push_back(new TMVA_Event(fTrainingTree, ievt, fInputVars));
    if (fSignalFraction > 0){
      if (fEventSample.back()->GetType2() < 0) fEventSample.back()->SetWeight(fSignalFraction*fEventSample.back()->GetWeight());
    }
  }
}

//_______________________________________________________________________
 void TMVA_MethodBDT::Train( void )
{  
  // default sanity checks
  if (!CheckSanity()) { 
    cout << "--- " << GetName() << ": Error: sanity check failed" << endl;
    exit(1);
  }

  cout << "--- " << GetName() << ": I will train "<< fNTrees << " Decision Trees"  
       << " ... patience please" << endl;
  TMVA_Timer timer( fNTrees, GetName() ); 
  for (int itree=0; itree<fNTrees; itree++){
    timer.DrawProgressBar( itree );

    fForest.push_back(new TMVA_DecisionTree(fSepType,
      fNodeMinEvents,fNodeMinSepGain,fNCuts));
    fForest.back()->BuildTree(fEventSample);
    this->Boost(fEventSample, fForest.back(), itree);
  }

  // get elapsed time
  cout << "--- " << GetName() << ": elapsed time: " << timer.GetElapsedTime() 
       << endl;    

  // write Weights to file
  WriteWeightsToFile();
}

//_______________________________________________________________________
 void TMVA_MethodBDT::Boost( vector<TMVA_Event*> eventSample, TMVA_DecisionTree *dt, Int_t iTree )
{
  if (fOptions.Contains("adaboost")) this->AdaBoost(eventSample, dt);
  else if (fOptions.Contains("epsilonboost")) this->EpsilonBoost(eventSample, dt);
  else if (fOptions.Contains("bagging")) this->Bagging(eventSample, iTree);
  else {
    cout << "--- " << this->GetName() << "::Boost: ERROR Unknow boost option called\n";
    cout << fOptions << endl;
   exit(1);
  }
}

//_______________________________________________________________________
 void TMVA_MethodBDT::AdaBoost( vector<TMVA_Event*> eventSample, TMVA_DecisionTree *dt )
{
  fAdaBoostBeta=1.;   // that's apparently the standard value :)
  // in order to perform the boosting, you first have to see how the events of 
  // the original sample were selected with this algorithm... in practice, this is
  // already an information we'd have in the node when they were build, but it's easier
  // right now (.... to be changed later) to simple loop over all the event's again.
  Double_t err=0, sumw=0, sumwfalse=0, count=0;
  vector<Bool_t> correctSelected;
  correctSelected.reserve(eventSample.size());
  for (vector<TMVA_Event*>::iterator e=eventSample.begin(); e!=eventSample.end();e++) {
    Int_t evType=dt->CheckEvent(*e);
    sumw+=(*e)->GetWeight();

    // I knew I'd get it worng.. 
    // event Type = 0 bkg, 1 sig
    // nodeType:  =-1 bkg  1 sig  
    // if (evType != (*e)->GetType()) { 
    if (evType != (*e)->GetType2()) { 
      sumwfalse+= (*e)->GetWeight();
      count+=1;
      correctSelected.push_back(kFALSE);
    }
    else{
      correctSelected.push_back(kTRUE);
    }    
  }
  err=sumwfalse/sumw;

  Double_t newSumw=0;
  Int_t i=0;
  Double_t newWeight;
  for (vector<TMVA_Event*>::iterator e=eventSample.begin(); e!=eventSample.end();e++){
    if (!correctSelected[i]){
      if (fAdaBoostBeta == 1){
	newWeight = (*e)->GetWeight() * ((1-err)/err) ;
	//	newWeight =  ((1-err)/err) ;
      }else{
	newWeight = (*e)->GetWeight() * pow((1-err)/err,fAdaBoostBeta) ;
	//newWeight =  pow((1-err)/err,fAdaBoostBeta) ;
      }
      (*e)->SetWeight(newWeight);
    }//else (*e)->SetWeight(1.);
    newSumw+=(*e)->GetWeight();    
    i++;
  }
  //re-normalise the Weights
  for (vector<TMVA_Event*>::iterator e=eventSample.begin(); e!=eventSample.end();e++){
    (*e)->SetWeight( (*e)->GetWeight() * sumw / newSumw );
  }
}

//_______________________________________________________________________
 void TMVA_MethodBDT::EpsilonBoost( vector<TMVA_Event*>  /*eventSample*/, TMVA_DecisionTree * /*dt*/ ){
  cout << "!!! Sorry...EpsilonBoost is not yet implement \n"; exit(1);
}

//_______________________________________________________________________
 void TMVA_MethodBDT::Bagging( vector<TMVA_Event*> eventSample, Int_t iTree )
{
  // call it Bootstrapping, re-sampling or whatever you like, in the end it is nothing
  // else but applying "random Weights" to each event.
  Double_t newSumw=0;
  Double_t newWeight;
  TRandom *trandom   = new TRandom(iTree);
  for (vector<TMVA_Event*>::iterator e=eventSample.begin(); e!=eventSample.end();e++){
    newWeight = trandom->Rndm();
    (*e)->SetWeight(newWeight);
    newSumw+=(*e)->GetWeight();    
  }
  for (vector<TMVA_Event*>::iterator e=eventSample.begin(); e!=eventSample.end();e++){
    (*e)->SetWeight( (*e)->GetWeight() * eventSample.size() / newSumw );
  }
}

//_______________________________________________________________________
 void  TMVA_MethodBDT::WriteWeightsToFile( void )
{  
   // write coefficients to file
  TString fname = GetWeightFileName();
  cout << "--- " << GetName() << ": creating Weight file: " << fname << endl;
  ofstream fout( fname );
  if (!fout.good( )) { // file not found --> Error
    cout << "--- " << GetName() << ": Error in ::WriteWeightsToFile: "
         << "unable to open output  Weight file: " << fname << endl;
    exit(1);
  }

  // write variable names and min/max 
  // NOTE: the latter values are mandatory for the normalisation 
  // in the reader application !!!
  fout << this->GetMethodName() <<endl;
  fout << "NVars= " << fNvar <<endl; 
  for (Int_t ivar=0; ivar<fNvar; ivar++) {
    TString var = (*fInputVars)[ivar];
    fout << var << "  " << GetXminNorm( var ) << "  " << GetXmaxNorm( var ) << endl;
  }

  // and save the Weights
  fout << "NTrees= " << fForest.size() <<endl; 
  for (UInt_t i=0; i< fForest.size(); i++){
    fout << "-999 *******Tree " << i << endl;
    (fForest[i])->Print(fout);
  }

  fout.close();    
}
  
//_______________________________________________________________________
 void  TMVA_MethodBDT::ReadWeightsFromFile( void )
{
   // read coefficients from file
  TString fname = GetWeightFileName();
  cout << "--- " << GetName() << ": reading Weight file: " << fname << endl;
  ifstream fin( fname );

  if (!fin.good( )) { // file not found --> Error
    cout << "--- " << GetName() << ": Error in ::ReadWeightsFromFile: "
         << "unable to open input file: " << fname << endl;
    exit(1);
  }

  // read variable names and min/max
  // NOTE: the latter values are mandatory for the normalisation 
  // in the reader application !!!
  TString var, dummy;
  Double_t xmin, xmax;
  fin >> dummy;
  this->SetMethodName(dummy);
  fin >> dummy >> fNvar;
  for (Int_t ivar=0; ivar<fNvar; ivar++) {
    fin >> var >> xmin >> xmax;
    (*fInputVars)[ivar] = var;
    // set min/max
    this->SetXminNorm( ivar, xmin );
    this->SetXmaxNorm( ivar, xmax );
  } 

  // and read the Weights (BDT coefficients)  
  fin >> dummy >> fNTrees;
  cout << "--- " << GetName() << ": Read "<<fNTrees<<" Decision trees\n";
  
  for (UInt_t i=0;i<fForest.size();i++) delete fForest[i];
  fForest.clear();
  Int_t iTree;
  fin >> var >> var;
  for (int i=0;i<fNTrees;i++){
    fin >> iTree;
    if (iTree != i) {
      cout << "--- "  << ": Error while reading Weight file \n ";
      cout << "--- "  << ": mismatch Itree="<<iTree<<" i="<<i<<endl;
      exit(1);
    }
    TMVA_DecisionTreeNode *n = new TMVA_DecisionTreeNode();
    TMVA_NodeID id;
    n->ReadRec(fin,id);
    fForest.push_back(new TMVA_DecisionTree());
    fForest.back()->SetRoot(n);
  }
  
  fin.close();      
}

//_______________________________________________________________________
 Double_t TMVA_MethodBDT::GetMvaValue(TMVA_Event *e)
{
  Double_t myMVA = 0;
  for (UInt_t itree=0; itree<fForest.size(); itree++){
    myMVA += fForest[itree]->CheckEvent(e);
  }
  return myMVA/= Double_t(fForest.size());;
}

//_______________________________________________________________________
 void  TMVA_MethodBDT::WriteHistosToFile( void )
{
  cout << "--- " << GetName() << ": write " << GetName() 
       <<" special histos to file: " << fBaseDir->GetPath() << endl;
}


ROOT page - Class index - Class Hierarchy - Top of the page

This page has been automatically generated. If you have any comments or suggestions about the page layout send a mail to ROOT support, or contact the developers with any questions or problems regarding ROOT.