ROOT logo
// @(#)root/tmva $Id: MethodBoost.h 31458 2009-11-30 13:58:20Z stelzer $   
// Author: Andreas Hoecker, Joerg Stelzer, Helge Voss, Kai Voss,Or Cohen 

/**********************************************************************************
 * Project: TMVA - a Root-integrated toolkit for multivariate data analysis       *
 * Package: TMVA                                                                  *
 * Class  : MethodBoost                                                           *
 * Web    : http://tmva.sourceforge.net                                           *
 *                                                                                *
 * Description:                                                                   *
 *      Class for boosting a TMVA method                                          *
 *                                                                                *
 * Authors (alphabetical):                                                        *
 *      Andreas Hoecker    <> - CERN, Switzerland           *
 *      Joerg Stelzer      <>  - CERN, Switzerland           *
 *      Helge Voss         <>     - MPI-K Heidelberg, Germany   *
 *      Kai Voss           <>       - U. of Victoria, Canada      *
 *      Or Cohen           <>    - Weizmann Inst., Israel      *
 *      Eckhard v. Toerne  <>        - U of Bonn, Germany          *
 *                                                                                *
 * Copyright (c) 2005:                                                            *
 *      CERN, Switzerland                                                         *
 *      U. of Victoria, Canada                                                    *
 *      MPI-K Heidelberg, Germany                                                 *
 *      LAPP, Annecy, France                                                      *
 *                                                                                *
 * Redistribution and use in source and binary forms, with or without             *
 * modification, are permitted according to the terms listed in LICENSE           *
 * (http://tmva.sourceforge.net/LICENSE)                                          *
 **********************************************************************************/

#ifndef ROOT_TMVA_MethodBoost
#define ROOT_TMVA_MethodBoost

//                                                                      //
// MethodBoost                                                          //
//                                                                      //
// Class for boosting a TMVA method                                     //
//                                                                      //

#include <iosfwd>
#include <vector>

#ifndef ROOT_TMVA_MethodBase
#include "TMVA/MethodBase.h"
#endif

#ifndef ROOT_TMVA_MethodCompositeBase
#include "TMVA/MethodCompositeBase.h"
#endif

namespace TMVA {

   class MethodBoost : public MethodCompositeBase {

   public :

      // constructors

      // Constructor taking the job name, the method title, the dataset
      // description, an optional option string (default: empty) and an
      // optional target directory (default: NULL).
      MethodBoost( const TString& jobName,
                   const TString& methodTitle,
                   DataSetInfo& theData,
                   const TString& theOption = "",
                   TDirectory* theTargetDir = NULL );

      // Constructor taking the dataset description, a weight file and an
      // optional target directory (default: NULL).
      // NOTE(review): presumably used to set the method up from a stored
      // weight file -- confirm against the implementation.
      MethodBoost( DataSetInfo& dsi, 
                   const TString& theWeightFile,  
                   TDirectory* theTargetDir = NULL );      

      // destructor
      virtual ~MethodBoost( void );

      // Takes an analysis type and the number of classes and returns a Bool_t;
      // presumably tells whether this method supports that analysis type
      // (confirm against the implementation). The third parameter (number of
      // targets) is left unnamed in this declaration.
      virtual Bool_t HasAnalysisType( Types::EAnalysisType type, UInt_t numberClasses, UInt_t /*numberTargets*/ );

      // trains and boosts all the classifiers
      void Train( void );

      // creates the ranking of the input variables; the result is returned
      // as a pointer to a const Ranking
      const Ranking* CreateRanking();
      // Saves the name and option string of the boosted classifier.
      // Arguments: the method type, the method title and the option string.
      // NOTE(review): the meaning of the returned Bool_t is not visible in this
      // header -- confirm against the implementation.
      Bool_t BookMethod( Types::EMVA theMethod, TString methodTitle, TString theOption );
      // stores methodName as the name of the boosted classifier
      void SetBoostedMethodName ( TString methodName )
      {
         this->fBoostedMethodName = methodName;
      }

      // returns fBoostNum, the number of times the classifier is boosted
      Int_t          GetBoostNum()
      {
         return this->fBoostNum;
      }

      // Gives the monitoring histogram from the vector according to the index
      // of the histogram added in the MonitorBoost function, i.e. the element
      // at position fDefaultHistNum+histInd of *fMonitorHist.
      // Returns NULL (instead of reading outside the vector) when the
      // histogram vector has not been allocated or when the resulting position
      // is negative or beyond the last stored histogram.
      TH1*           GetMonitoringHist( Int_t histInd ) {
         if (fMonitorHist == NULL) return NULL;
         const Int_t pos = fDefaultHistNum + histInd;
         // std::vector::operator[] performs no range check, so guard it here
         if (pos < 0 || pos >= static_cast<Int_t>( fMonitorHist->size() )) return NULL;
         return (*fMonitorHist)[pos];
      }

      // appends hist to the vector of monitoring histograms
      void           AddMonitoringHist( TH1* hist )
      {
         fMonitorHist->push_back( hist );
      }

      // returns the current stage of the boosting (fBoostStage)
      Types::EBoostStage    GetBoostStage()
      {
         return this->fBoostStage;
      }

      // NOTE(review): presumably cleans up the boost-related options; the
      // exact effect is not visible in this header -- confirm against the
      // implementation.
      void CleanBoostOptions();

      // Returns the MVA value as a Double_t.
      // NOTE(review): err is presumably an optional output for an error
      // estimate on the returned value -- confirm against the implementation.
      Double_t GetMvaValue( Double_t* err );

   private :
      // clean up
      // NOTE(review): what exactly is released is not visible in this header.
      void ClearAll();

      // Print fit results. Takes a string, a vector of Double_t and a Double_t;
      // the parameters are unnamed in this declaration, so their individual
      // meaning has to be taken from the implementation.
      void PrintResults( const TString&, std::vector<Double_t>&, const Double_t ) const;

      // initialization, mostly of the monitoring tools of the boost process
      void Init();
      // NOTE(review): presumably sets up the monitoring histograms -- confirm
      void InitHistos();
      // NOTE(review): presumably validates the configuration -- confirm
      void CheckSetup();

      // the option handling methods: declaration of the options and
      // processing of the option values
      void DeclareOptions();
      void ProcessOptions();

      // records stage as the current boosting stage and returns this object
      MethodBoost* SetStage( Types::EBoostStage stage )
      {
         fBoostStage = stage;
         return this;
      }

      // trains a single classifier
      void SingleTrain();

      // calculates a boosting weight from the classifier and stores it in the
      // next one
      void SingleBoost();

      // writes the monitoring histograms and tree to a file
      void WriteMonitoringHistosToFile( void ) const;

      // writes evaluation histograms into the target file for the given
      // tree type
      virtual void WriteEvaluationHistosToFile(Types::ETreeType treetype);

      // performs the MethodBase testing plus the testing of each boosted
      // classifier
      virtual void TestClassification();

      // finds the MVA value at which to cut between signal and background,
      // according to fMVACutPerc and fMVACutType
      // NOTE(review): fMVACutPerc / fMVACutType are not declared in the part of
      // the class visible here -- confirm they still exist.
      void FindMVACut();

      // sets all the boost weights to 1
      void ResetBoostWeights();

      // creates the vectors of histograms used for monitoring the MVA response
      // of each classifier
      // NOTE(review): the method name contains a typo ("Historgrams"); it is
      // kept unchanged because renaming it would change the class interface.
      void CreateMVAHistorgrams();

      // number of times the classifier is boosted (set by the user)
      Int_t             fBoostNum;
      // string specifying the boost type (AdaBoost / Bagging)
      TString           fBoostType; 

      // string specifying the method weight type (ByError, Average, LastMethod)
      TString           fMethodWeightType;

      // estimated error level of the classifier on the training dataset
      Double_t          fMethodError;
      // estimated error level of the classifier on the training dataset
      // (with unboosted weights)
      Double_t          fOrigMethodError;

      // the weight used to boost the next classifier
      Double_t          fBoostWeight;

      // transformation string
      // NOTE(review): the previous comment here ("min and max values for the
      // classifier response") did not describe a TString member; confirm the
      // exact meaning of this string against the option handling code.
      TString fTransformString;

      // AdaBoost parameter, default is 1
      Double_t          fADABoostBeta;

      // details of the boosted classifier: its name, title and option string
      TString           fBoostedMethodName;
      TString           fBoostedMethodTitle;
      TString           fBoostedMethodOptions;

      // histograms to monitor values during the boosting
      // (pointer to a vector; read by GetMonitoringHist, filled by
      // AddMonitoringHist)
      std::vector<TH1*>* fMonitorHist;

      // whether to monitor the MVA response of every classifier
      // NOTE(review): the original comment was cut off after "using the";
      // presumably it referred to the histogram vectors below -- confirm.
      Bool_t                fMonitorBoostedMethod;

      // MVA output from each classifier over the training sample, using
      // original event weights (signal / background)
      std::vector< TH1* >   fTrainSigMVAHist;
      std::vector< TH1* >   fTrainBgdMVAHist;
      // MVA output from each classifier over the training sample, using
      // boosted event weights (signal / background)
      std::vector< TH1* >   fBTrainSigMVAHist;
      std::vector< TH1* >   fBTrainBgdMVAHist;
      // MVA output from each classifier over the testing sample
      // (signal / background)
      std::vector< TH1* >   fTestSigMVAHist;
      std::vector< TH1* >   fTestBgdMVAHist;

      // tree to monitor values during the boosting
      TTree*            fMonitorTree;

      // the current stage of the boosting (set by SetStage, read by
      // GetBoostStage)
      Types::EBoostStage fBoostStage;

      // the number of histograms filled for every type of boosted classifier;
      // used by GetMonitoringHist as the offset into fMonitorHist
      Int_t             fDefaultHistNum;

      // whether to recalculate the MVA cut at every boosting step
      Bool_t            fRecalculateMVACut;


      // get help message text
      // (declared void and const: nothing is returned to the caller)
      void GetHelpMessage() const;