#include "TMVA/MethodCommittee.h"
#include "TMVA/Tools.h"
#include "TMVA/Timer.h"
#include "Riostream.h"
#include "TRandom.h"
#include <algorithm>
#include "TObjString.h"
#include "TDirectory.h"
#include "TMVA/Ranking.h"
#include "TMVA/Methods.h"
using std::vector;
ClassImp(TMVA::MethodCommittee)
 
TMVA::MethodCommittee::MethodCommittee( TString jobName, TString committeeTitle, DataSet& theData, 
                                        TString committeeOptions,
                                        Types::EMVA method, TString methodOptions,
                                        TDirectory* theTargetDir )
   : TMVA::MethodBase( jobName, committeeTitle, theData, committeeOptions, theTargetDir ),
     fMemberType( method ),
     fMemberOption( methodOptions )
{
   // Standard constructor, used for training a committee.
   //
   //   jobName, committeeTitle, theData, theTargetDir : forwarded to MethodBase
   //   committeeOptions : option string of the committee itself (forwarded to
   //                      MethodBase; see DeclareOptions for the option names)
   //   method           : classifier type instantiated for every committee member
   //   methodOptions    : option string handed to each member's constructor in Train()

   // set defaults, declare the committee options, then parse the option string
   InitCommittee(); 
   DeclareOptions();
   ParseOptions();
   // (no committee-specific post-processing; see ProcessOptions)
   ProcessOptions();
   // monitoring objects; fErrFractHist has one x-bin per member, so it is booked
   // only after the options (and therefore fNMembers) have been parsed.
   // The ntuple branches point at fITree / fBoostFactor / fErrorFraction, which
   // are read at every fMonitorNtuple->Fill() in Train().
   fBoostFactorHist = new TH1F("fBoostFactor","Ada Boost weights",100,1,100);
   fErrFractHist    = new TH2F("fErrFractHist","error fraction vs tree number",
                               fNMembers,0,fNMembers,50,0,0.5);
   fMonitorNtuple   = new TTree("fMonitorNtuple","Committee variables");
   fMonitorNtuple->Branch("iTree",&fITree,"iTree/I");
   fMonitorNtuple->Branch("boostFactor",&fBoostFactor,"boostFactor/D");
   fMonitorNtuple->Branch("errorFraction",&fErrorFraction,"errorFraction/D");
}
TMVA::MethodCommittee::MethodCommittee( DataSet& theData, 
                                        TString theWeightFile,  
                                        TDirectory* theTargetDir )
   : TMVA::MethodBase( theData, theWeightFile, theTargetDir ) 
{
   // Constructor used when a trained committee is reconstructed from a weight
   // file (theWeightFile is forwarded to MethodBase).
   //
   // Only the common initialisation and the option declarations are done here;
   // no option string is parsed.
   // NOTE(review): fMemberType / fMemberOption are not initialised by this
   // constructor, yet ReadWeightsFromStream switches on fMemberType, and the
   // monitoring histograms/ntuple are only booked in the training constructor —
   // confirm where these are set on the reading path.
   InitCommittee();
   // declared so that options stored with the weights can be recognised
   DeclareOptions();
}
void TMVA::MethodCommittee::DeclareOptions() 
{
   // Declare the options understood by the committee itself:
   //
   //   NMembers           : number of committee members trained in Train()
   //   UseMemberDecision  : in GetMvaValue, combine the members' binary
   //                        IsSignalLike() decisions (+1/-1) instead of their
   //                        continuous MVA values (default: false)
   //   UseWeightedMembers : in GetMvaValue, weight each member by its boost
   //                        weight instead of a plain average (default: true)
   //   BoostType          : reweighting applied after each member is trained;
   //                        "AdaBoost" or "Bagging" (default set in InitCommittee)
   //
   // The assignments inside DeclareOptionRef set the defaults of the two flags.
   DeclareOptionRef(fNMembers, "NMembers", "number of members in the committee");
   DeclareOptionRef(fUseMemberDecision=kFALSE, "UseMemberDecision", "use binary information from IsSignal");
   DeclareOptionRef(fUseWeightedMembers=kTRUE, "UseWeightedMembers", "use weighted trees or simple average in classification from the forest");
   DeclareOptionRef(fBoostType, "BoostType", "boosting type");
   // presumably restricts the most recently declared option (BoostType) to the
   // two schemes Boost() can dispatch to
   AddPreDefVal(TString("AdaBoost"));
   AddPreDefVal(TString("Bagging"));
}
void TMVA::MethodCommittee::ProcessOptions() 
{
   // The committee has no option post-processing of its own;
   // everything is delegated to the base-class implementation.
   TMVA::MethodBase::ProcessOptions();
}
void TMVA::MethodCommittee::InitCommittee( void )
{
   // Initialisation shared by both constructors.

   // identify this method to the framework
   SetMethodName( "Committee" );
   SetMethodType( TMVA::Types::kCommittee );
   SetTestvarName();

   // defaults for the BoostType / NMembers options
   fBoostType = "AdaBoost";
   fNMembers  = 100;

   // start from an empty committee
   fBoostWeights.clear();
   fCommittee.clear();
}
TMVA::MethodCommittee::~MethodCommittee( void )
{
   // The committee owns its members (they are created with new in Train and
   // ReadWeightsFromStream), so each one is deleted here.
   std::vector<IMethod*>::iterator it = fCommittee.begin();
   while (it != fCommittee.end()) {
      delete *it;
      ++it;
   }
   fCommittee.clear();
}
void TMVA::MethodCommittee::WriteStateToFile() const
{ 
   // Write the complete committee state (via WriteStateToStream) to the
   // weight file returned by GetWeightFileName().
   //
   // The output stream is a local object, so it is flushed and closed when this
   // function returns. A heap-allocated stream that is never deleted would
   // leave the file unclosed and possibly truncated.
   TString fname(GetWeightFileName());
   fLogger << kINFO << "creating weight file: " << fname << Endl;

   std::ofstream fout( fname );
   if (!fout.good()) { 
      // Endl (not endl) is required: it is what dispatches the logger message
      fLogger << kFATAL << "<WriteStateToFile> "
              << "unable to open output  weight file: " << fname << Endl;
   }

   WriteStateToStream( fout );
}
void TMVA::MethodCommittee::Train( void )
{  
   
   if (!CheckSanity()) fLogger << kFATAL << "<Train> sanity check failed" << Endl;
   fLogger << kINFO << "will train "<< fNMembers << " committee members ... patience please" << Endl;
   Timer timer( fNMembers, GetName() ); 
   for (UInt_t imember=0; imember<fNMembers; imember++){
      timer.DrawProgressBar( imember );
      TMVA::IMethod *method = 0;
      
      
      switch(fMemberType) {
      case TMVA::Types::kCuts:       
         method = new TMVA::MethodCuts           ( GetJobName(), GetMethodTitle(), Data(), fMemberOption ); break;
      case TMVA::Types::kFisher:     
         method = new TMVA::MethodFisher         ( GetJobName(), GetMethodTitle(), Data(), fMemberOption ); break;
      case TMVA::Types::kKNN:     
         method = new TMVA::MethodKNN            ( GetJobName(), GetMethodTitle(), Data(), fMemberOption ); break;
      case TMVA::Types::kMLP:        
         method = new TMVA::MethodMLP            ( GetJobName(), GetMethodTitle(), Data(), fMemberOption ); break;
      case TMVA::Types::kTMlpANN:    
         method = new TMVA::MethodTMlpANN        ( GetJobName(), GetMethodTitle(), Data(), fMemberOption ); break;
      case TMVA::Types::kCFMlpANN:   
         method = new TMVA::MethodCFMlpANN       ( GetJobName(), GetMethodTitle(), Data(), fMemberOption ); break;
      case TMVA::Types::kLikelihood: 
         method = new TMVA::MethodLikelihood     ( GetJobName(), GetMethodTitle(), Data(), fMemberOption ); break;
      case TMVA::Types::kHMatrix:    
         method = new TMVA::MethodHMatrix        ( GetJobName(), GetMethodTitle(), Data(), fMemberOption ); break;
      case TMVA::Types::kPDERS:      
         method = new TMVA::MethodPDERS          ( GetJobName(), GetMethodTitle(), Data(), fMemberOption ); break;
      case TMVA::Types::kBDT:        
         method = new TMVA::MethodBDT            ( GetJobName(), GetMethodTitle(), Data(), fMemberOption ); break;
      case TMVA::Types::kSVM:        
         method = new TMVA::MethodSVM            ( GetJobName(), GetMethodTitle(), Data(), fMemberOption ); break;
      case TMVA::Types::kRuleFit:    
         method = new TMVA::MethodRuleFit        ( GetJobName(), GetMethodTitle(), Data(), fMemberOption ); break;
      case TMVA::Types::kBayesClassifier:    
         method = new TMVA::MethodBayesClassifier( GetJobName(), GetMethodTitle(), Data(), fMemberOption ); break;
      default:
         fLogger << kFATAL << "method: " << fMemberType << " does not exist" << Endl;
      }
      
      
      method->Train();
      GetBoostWeights().push_back( this->Boost( method, imember ) );
      GetCommittee().push_back( method );
      fMonitorNtuple->Fill();
   }
   
   fLogger << kINFO << "elapsed time: " << timer.GetElapsedTime()    
           << "                              " << Endl;    
}
Double_t TMVA::MethodCommittee::Boost( TMVA::IMethod* method, UInt_t imember )
{
   // Apply the reweighting scheme selected by the BoostType option after
   // `method` (member number `imember`) has been trained. Returns the member's
   // boost weight. An unknown BoostType is reported via kFATAL; 1.0 is returned
   // if execution continues.
   if (fBoostType == "AdaBoost") {
      return this->AdaBoost( method );
   }
   if (fBoostType == "Bagging") {
      return this->Bagging( imember );
   }

   // no matching scheme: show the option string to help diagnose the typo
   fLogger << kINFO << GetOptions() << Endl;
   fLogger << kFATAL << "<Boost> unknown boost option called" << Endl;
   return 1.0;
}
Double_t TMVA::MethodCommittee::AdaBoost( TMVA::IMethod* method )
{
   // AdaBoost reweighting of the training sample after `method` has been trained.
   //
   // err = (boost weight of events where method->IsSignalLike() disagrees with
   //        the true event type) / (total boost weight).
   // Misclassified events have their boost weight multiplied by
   //    boostFactor = ((1-err)/err)^adaBoostBeta      (adaBoostBeta = 1 here),
   // or by 1000 if err == 0. All boost weights are then rescaled so that their
   // sum equals the sum before reweighting.
   //
   // Returns log(boostFactor), stored as the member's weight for GetMvaValue().
   // Also fills the monitoring histograms and sets fBoostFactor / fErrorFraction
   // for the monitoring ntuple.
   Double_t adaBoostBeta = 1.;   

   if (!HasTrainingTree()) fLogger << kFATAL << "<AdaBoost> Data().TrainingTree() is zero pointer" << Endl;

   Event& event = GetEvent();
   Double_t sumw = 0, sumwfalse = 0;
   vector<Bool_t> correctSelected;

   // pass 1: classify every training event with the new member and accumulate
   // the total and the misclassified boost weight
   MethodBase* mbase = (MethodBase*)method;
   for (Int_t ievt=0; ievt<Data().GetNEvtTrain(); ievt++) {
      ReadTrainingEvent(ievt);
      sumw += event.GetBoostWeight();

      Bool_t isSignalType = mbase->IsSignalLike();
      if (isSignalType == event.IsSignal()) correctSelected.push_back( kTRUE );
      else {
         sumwfalse += event.GetBoostWeight();
         correctSelected.push_back( kFALSE );
      }    
   }
   if (0 == sumw) {
      fLogger << kFATAL << "<AdaBoost> fatal error sum of event boostweights is zero" << Endl;
   }

   const Double_t err = sumwfalse/sumw;
   Double_t boostFactor = 1;
   if (err>0){
      if (adaBoostBeta == 1){
         boostFactor = (1-err)/err ;
      }
      else {
         boostFactor =  pow((1-err)/err,adaBoostBeta) ;
      }
   }
   else {
      // no misclassified events: (1-err)/err is undefined, use a large fixed factor
      boostFactor = 1000; 
   }

   // pass 2: boost the misclassified events
   Double_t newSumw=0;
   for (Int_t ievt=0; ievt<Data().GetNEvtTrain(); ievt++) {
      ReadTrainingEvent(ievt);
      if (!correctSelected[ievt]) event.SetBoostWeight( event.GetBoostWeight() * boostFactor);
      newSumw += event.GetBoostWeight();    
   }

   // pass 3: renormalise so the total boost weight is unchanged. Each event must
   // be read before it is rescaled; otherwise only the last event read would be
   // rescaled, GetNEvtTrain() times over. newSumw can be 0 (e.g. err == 1
   // gives boostFactor == 0), in which case rescaling would produce NaN.
   if (newSumw != 0) {
      for (Int_t ievt=0; ievt<Data().GetNEvtTrain(); ievt++) {
         ReadTrainingEvent(ievt);
         event.SetBoostWeight( event.GetBoostWeight() * sumw / newSumw );      
      }
   }
   else {
      fLogger << kWARNING << "<AdaBoost> sum of boost weights is zero after reweighting; "
              << "renormalisation skipped" << Endl;
   }

   fBoostFactorHist->Fill(boostFactor);
   fErrFractHist->Fill(GetCommittee().size(),err);

   // picked up by the next fMonitorNtuple->Fill() in Train()
   fBoostFactor   = boostFactor;
   fErrorFraction = err;

   return log(boostFactor);
}
Double_t TMVA::MethodCommittee::Bagging( UInt_t imember )
{
   // Randomised reweighting of the training sample before the next member.
   //
   // Each training event gets a new boost weight drawn from TRandom::Rndm(),
   // using a generator seeded with the member index. The weights are then
   // rescaled so that they sum to the number of training events.
   // Always returns 1.0, so all members weigh equally in GetMvaValue().
   // NOTE(review): for imember == 0 the seed is 0, which ROOT may treat as a
   // request for a non-reproducible seed — confirm against TRandom::SetSeed.
   Double_t newSumw = 0;

   // local generator: released automatically (a heap-allocated one was leaked)
   TRandom trandom( imember );

   Event& event = GetEvent();

   // pass 1: draw a fresh random boost weight for every training event
   for (Int_t ievt=0; ievt<Data().GetNEvtTrain(); ievt++) {
      ReadTrainingEvent(ievt);
      Double_t newWeight = trandom.Rndm();
      event.SetBoostWeight( newWeight );
      newSumw += newWeight;
   }

   // pass 2: normalise. Each event is read before rescaling so that every event,
   // not only the last one read, is updated.
   for (Int_t ievt=0; ievt<Data().GetNEvtTrain(); ievt++) {
      ReadTrainingEvent(ievt);
      event.SetBoostWeight( event.GetBoostWeight() * Data().GetNEvtTrain() / newSumw );      
   }

   return 1.0;  
}
void TMVA::MethodCommittee::WriteWeightsToStream( ostream& o ) const
{
   // Serialise every committee member. Each member is written as:
   //   an empty line,
   //   "------------------------------ new member: <index> ---------------",
   //   "boost weight: <weight>",
   //   followed by the member's own state (MethodBase::WriteStateToStream).
   // ReadWeightsFromStream reads this layout back.
   const UInt_t nMembers = GetCommittee().size();
   for (UInt_t im = 0; im < nMembers; ++im) {
      o << endl
        << "------------------------------ new member: " << im << " ---------------" << endl
        << "boost weight: " << GetBoostWeights()[im] << endl;
      MethodBase* member = (MethodBase*)GetCommittee()[im];
      member->WriteStateToStream( o );
   }
}
  
void  TMVA::MethodCommittee::ReadWeightsFromStream( istream& istr )
{
   // Rebuild the committee from a stream written by WriteWeightsToStream.
   //
   // Any existing members are deleted first. Then fNMembers members are read,
   // each laid out as
   //   "------------------------------ new member: <index> ---------------"
   //   "boost weight: <weight>"
   //   <member state>   (read by MethodBase::ReadStateFromStream)
   // Each member is instantiated according to fMemberType.
   std::vector<IMethod*>::iterator member = GetCommittee().begin();
   for (; member != GetCommittee().end(); ++member) delete *member;
   GetCommittee().clear();
   GetBoostWeights().clear();

   TString  dummy;
   UInt_t   imember = 0;
   Double_t boostWeight = 0;

   for (UInt_t i=0; i<fNMembers; i++) {
      // header tokens: "------------------------------" "new" "member:" <index> "---------------"
      // The trailing dash token must be consumed too. Otherwise the boost-weight
      // line is misaligned and boostWeight would be parsed from "weight:".
      istr >> dummy >> dummy >> dummy >> imember >> dummy;
      // tokens: "boost" "weight:" <weight>
      istr >> dummy >> dummy >> boostWeight;
      if (istr.fail()) {
         fLogger << kFATAL << "<ReadWeightsFromStream> fatal error while reading Weight file \n "
                 << ": unable to read header of member: " << i << Endl;
      }
      if (imember != i) {
         fLogger << kFATAL << "<ReadWeightsFromStream> fatal error while reading Weight file \n "
                 << ": mismatch imember: " << imember << " != i: " << i << Endl;
      }

      // instantiate an empty member of the committee's member type
      TMVA::IMethod *method = 0;
      switch(fMemberType) {
      case TMVA::Types::kCuts:       
         method = new TMVA::MethodCuts           ( Data(), "" ); break;
      case TMVA::Types::kFisher:     
         method = new TMVA::MethodFisher         ( Data(), "" ); break;
      case TMVA::Types::kKNN:     
         method = new TMVA::MethodKNN            ( Data(), "" ); break;
      case TMVA::Types::kMLP:        
         method = new TMVA::MethodMLP            ( Data(), "" ); break;
      case TMVA::Types::kTMlpANN:    
         method = new TMVA::MethodTMlpANN        ( Data(), "" ); break;
      case TMVA::Types::kCFMlpANN:   
         method = new TMVA::MethodCFMlpANN       ( Data(), "" ); break;
      case TMVA::Types::kLikelihood: 
         method = new TMVA::MethodLikelihood     ( Data(), "" ); break;
      case TMVA::Types::kHMatrix:    
         method = new TMVA::MethodHMatrix        ( Data(), "" ); break;
      case TMVA::Types::kPDERS:      
         method = new TMVA::MethodPDERS          ( Data(), "" ); break;
      case TMVA::Types::kBDT:        
         method = new TMVA::MethodBDT            ( Data(), "" ); break;
      case TMVA::Types::kSVM:        
         method = new TMVA::MethodSVM            ( Data(), "" ); break;
      case TMVA::Types::kRuleFit:    
         method = new TMVA::MethodRuleFit        ( Data(), "" ); break;
      case TMVA::Types::kBayesClassifier:    
         method = new TMVA::MethodBayesClassifier( Data(), "" ); break;
      default:
         // Endl (not endl) is what actually dispatches the fatal message;
         // otherwise execution would continue with a null `method`
         fLogger << kFATAL << "<ReadWeightsFromStream> fatal error: method: " 
                 << fMemberType << " does not exist" << Endl;
      }

      ((MethodBase*)method)->ReadStateFromStream(istr);
      GetCommittee().push_back(method);
      GetBoostWeights().push_back(boostWeight);
   }
}
Double_t TMVA::MethodCommittee::GetMvaValue()
{
   // Committee response for the current event: the average of the member
   // responses, weighted by the members' boost weights if UseWeightedMembers
   // is set (otherwise each member counts 1).
   // A member's response is its own MVA value, or +1/-1 from IsSignalLike()
   // if UseMemberDecision is set.
   // Returns -999 if the normalisation is zero (e.g. an empty committee).
   Double_t weightedSum  = 0;
   Double_t sumOfWeights = 0;

   const UInt_t nMembers = GetCommittee().size();
   for (UInt_t im = 0; im < nMembers; ++im) {
      IMethod* member = GetCommittee()[im];

      Double_t response;
      if (fUseMemberDecision) response = ((MethodBase*)member)->IsSignalLike() ? 1.0 : -1.0;
      else                    response = member->GetMvaValue();

      const Double_t w = fUseWeightedMembers ? GetBoostWeights()[im] : 1.0;
      weightedSum  += w * response;
      sumOfWeights += w;
   }

   if (sumOfWeights == 0) return -999;
   return weightedSum / sumOfWeights;
}
void  TMVA::MethodCommittee::WriteMonitoringHistosToFile( void ) const
{
   // Write the monitoring objects booked in the training constructor: the
   // boost-factor histogram, the error-fraction-vs-member histogram and the
   // per-member monitoring ntuple. Afterwards BaseDir() is made the current
   // directory (cd()).
   // NOTE(review): the Write() calls target ROOT's current directory, which is
   // presumably BaseDir() at this point — confirm with the caller. These
   // objects are not booked by the weight-file constructor.
   fLogger << kINFO << "write monitoring histograms to file: " << BaseDir()->GetPath() << Endl;
   fBoostFactorHist->Write();
   fErrFractHist->Write();
   fMonitorNtuple->Write();
   BaseDir()->cd();
}
vector< Double_t > TMVA::MethodCommittee::GetVariableImportance()
{
   // Placeholder: the committee computes no variable-importance measure.
   // The cached vector is only resized to the number of input variables (new
   // entries are value-initialised to 0) and a copy is returned.
   vector<Double_t>& importance = fVariableImportance;
   importance.resize( GetNvar() );
   return importance;
}
Double_t TMVA::MethodCommittee::GetVariableImportance(UInt_t ivar)
{
   // Importance of input variable `ivar`. An out-of-range index is reported
   // via kFATAL; -1 is returned if execution continues.
   const vector<Double_t> importance = this->GetVariableImportance();
   if (ivar >= (UInt_t)importance.size()) {
      fLogger << kFATAL << "<GetVariableImportance> ivar = " << ivar << " is out of range " << Endl;
      return -1;
   }
   return importance[ivar];
}
const TMVA::Ranking* TMVA::MethodCommittee::CreateRanking()
{
   // Build the variable ranking from GetVariableImportance(): one Rank per
   // input variable, labelled with its input expression.
   //
   // AddRank takes the Rank by reference and stores its own copy, so a
   // temporary is passed. Dereferencing a `new Rank` here would leak one
   // allocation per variable.
   fRanking = new Ranking( GetName(), "Variable Importance" );
   vector< Double_t> importance(this->GetVariableImportance());
   for (Int_t ivar=0; ivar<GetNvar(); ivar++) {
      fRanking->AddRank( Rank( GetInputExp(ivar), importance[ivar] ) );
   }
   return fRanking;
}
void TMVA::MethodCommittee::MakeClassSpecific( std::ostream& fout, const TString& className ) const
{
   // Standalone response-class export is not supported for committees:
   // emit a placeholder comment and close the generated class body.
   fout << "   // not implemented for class: \"" << className << "\"" << endl
        << "};" << endl;
}
void TMVA::MethodCommittee::GetHelpMessage() const
{
   
   
   
   
   fLogger << Endl;
   fLogger << Tools::Color("bold") << "--- Short description:" << Tools::Color("reset") << Endl;
   fLogger << Endl;
   fLogger << "<None>" << Endl;
   fLogger << Endl;
   fLogger << Tools::Color("bold") << "--- Performance optimisation:" << Tools::Color("reset") << Endl;
   fLogger << Endl;
   fLogger << "<None>" << Endl;
   fLogger << Endl;
   fLogger << Tools::Color("bold") << "--- Performance tuning via configuration options:" << Tools::Color("reset") << Endl;
   fLogger << Endl;
   fLogger << "<None>" << Endl;
}
// This page has been automatically generated. If you have any comments or suggestions about the page layout send a mail to ROOT support, or contact the developers with any questions or problems regarding ROOT.