// @(#)root/tmva $Id$
// Author: Andreas Hoecker, Joerg Stelzer, Helge Voss, Kai Voss, Jan Therhaag

/**********************************************************************************
 * Project: TMVA - a Root-integrated toolkit for multivariate data analysis       *
 * Package: TMVA                                                                  *
 * Class  : MethodBDT  (Boosted Decision Trees)                                   *
 * Web    : http://tmva.sourceforge.net                                           *
 *                                                                                *
 * Description:                                                                   *
 *      Analysis of Boosted Decision Trees                                        *
 *                                                                                *
 * Authors (alphabetical):                                                        *
 *      Andreas Hoecker <Andreas.Hocker@cern.ch> - CERN, Switzerland              *
 *      Helge Voss      <Helge.Voss@cern.ch>     - MPI-K Heidelberg, Germany      *
 *      Kai Voss        <Kai.Voss@cern.ch>       - U. of Victoria, Canada         *
 *      Doug Schouten   <dschoute@sfu.ca>        - Simon Fraser U., Canada        *
 *      Jan Therhaag    <jan.therhaag@cern.ch>   - U. of Bonn, Germany            *
 *                                                                                *
 * Copyright (c) 2005-2011:                                                       *
 *      CERN, Switzerland                                                         *
 *      U. of Victoria, Canada                                                    *
 *      MPI-K Heidelberg, Germany                                                 *
 *      U. of Bonn, Germany                                                       *
 *                                                                                *
 * Redistribution and use in source and binary forms, with or without             *
 * modification, are permitted according to the terms listed in LICENSE           *
 * (http://tmva.sourceforge.net/LICENSE)                                          *
 **********************************************************************************/

#ifndef ROOT_TMVA_MethodBDT
#define ROOT_TMVA_MethodBDT

//////////////////////////////////////////////////////////////////////////
//                                                                      //
// MethodBDT                                                            //
//                                                                      //
// Analysis of Boosted Decision Trees                                   //
//                                                                      //
//////////////////////////////////////////////////////////////////////////
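//
// A minimal training sketch (illustrative, not verbatim from the TMVA
// examples; file names, variable names and option values are assumptions)
// books this method through the TMVA::Factory with an option string whose
// keys map onto the data members declared below:
//
//    #include "TFile.h"
//    #include "TMVA/Factory.h"
//
//    TFile* outputFile = TFile::Open("TMVA.root", "RECREATE");
//    TMVA::Factory factory("TMVAClassification", outputFile, "!V:!Silent");
//    factory.AddVariable("var1", 'F');
//    factory.AddVariable("var2", 'F');
//    // ... add signal/background trees and call PrepareTrainingAndTestTree ...
//    factory.BookMethod(TMVA::Types::kBDT, "BDT",
//                       "NTrees=850:MaxDepth=3:BoostType=AdaBoost:AdaBoostBeta=0.5:"
//                       "SeparationType=GiniIndex:nCuts=20");
//    factory.TrainAllMethods();
//    factory.TestAllMethods();
//    factory.EvaluateAllMethods();
//    outputFile->Close();
//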

#include <vector>
#ifndef ROOT_TH2
#include "TH2.h"
#endif
#ifndef ROOT_TTree
#include "TTree.h"
#endif
#ifndef ROOT_TMVA_MethodBase
#include "TMVA/MethodBase.h"
#endif
#ifndef ROOT_TMVA_DecisionTree
#include "TMVA/DecisionTree.h"
#endif
#ifndef ROOT_TMVA_Event
#include "TMVA/Event.h"
#endif

namespace TMVA {

   class SeparationBase;

   class MethodBDT : public MethodBase {

   public:
      // constructor for training and reading
      MethodBDT( const TString& jobName,
                 const TString& methodTitle,
                 DataSetInfo& theData,
                 const TString& theOption = "",
                 TDirectory* theTargetDir = 0 );

      // constructor for calculating BDT-MVA using previously generated decision trees
      MethodBDT( DataSetInfo& theData,
                 const TString& theWeightFile,
                 TDirectory* theTargetDir = NULL );

      virtual ~MethodBDT( void );

      virtual Bool_t HasAnalysisType( Types::EAnalysisType type, UInt_t numberClasses, UInt_t numberTargets );


      // write all events from the tree into a vector of events that can be
      // manipulated more easily
      void InitEventSample();

      // optimize tuning parameters
      virtual std::map<TString,Double_t> OptimizeTuningParameters(TString fomType="ROCIntegral", TString fitType="FitGA");
      virtual void SetTuneParameters(std::map<TString,Double_t> tuneParameters);
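
      // (usage sketch: "bdt" is a hypothetical pointer to this method; the
      // returned map holds the best value found for each tuned option and
      // can be applied via SetTuneParameters)
      //    std::map<TString,Double_t> tuned =
      //       bdt->OptimizeTuningParameters("ROCIntegral", "FitGA");
      //    bdt->SetTuneParameters(tuned);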

      // training method
      void Train( void );

      // revoke training
      void Reset( void );

      using MethodBase::ReadWeightsFromStream;

      // write weights to file
      void AddWeightsXMLTo( void* parent ) const;

      // read weights from file
      void ReadWeightsFromStream( std::istream& istr );
      void ReadWeightsFromXML(void* parent);

      // write method specific histos to target file
      void WriteMonitoringHistosToFile( void ) const;

      // calculate the MVA value
      Double_t GetMvaValue( Double_t* err = 0, Double_t* errUpper = 0);

      // get the actual forest size (might be less than fNTrees, the requested number, if boosting stopped early)
      UInt_t   GetNTrees() const {return fForest.size();}
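
      // (application sketch: a trained BDT is typically evaluated through the
      // TMVA::Reader; the weight-file path and variable names below are
      // assumptions matching the training sketch in the class description)
      //    TMVA::Reader reader("!Color:!Silent");
      //    Float_t var1, var2;
      //    reader.AddVariable("var1", &var1);
      //    reader.AddVariable("var2", &var2);
      //    reader.BookMVA("BDT", "weights/TMVAClassification_BDT.weights.xml");
      //    // fill var1 and var2 for the current event, then:
      //    Double_t mvaValue = reader.EvaluateMVA("BDT");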
   private:
      Double_t GetMvaValue( Double_t* err, Double_t* errUpper, UInt_t useNTrees );
      Double_t PrivateGetMvaValue( const TMVA::Event *ev, Double_t* err=0, Double_t* errUpper=0, UInt_t useNTrees=0 );
      void     BoostMonitor(Int_t iTree);

   public:
      const std::vector<Float_t>& GetMulticlassValues();

      // regression response
      const std::vector<Float_t>& GetRegressionValues();

      // apply the boost algorithm to a tree in the collection
      Double_t Boost( std::vector<const TMVA::Event*>&, DecisionTree *dt, UInt_t cls = 0);

      // ranking of input variables
      const Ranking* CreateRanking();

      // the option handling methods
      void DeclareOptions();
      void ProcessOptions();
      void SetMaxDepth(Int_t d){fMaxDepth = d;}
      void SetMinNodeSize(Double_t sizeInPercent);
      void SetMinNodeSize(TString sizeInPercent);

      void SetNTrees(Int_t d){fNTrees = d;}
      void SetAdaBoostBeta(Double_t b){fAdaBoostBeta = b;}
      void SetNodePurityLimit(Double_t l){fNodePurityLimit = l;} 
      void SetShrinkage(Double_t s){fShrinkage = s;}
      void SetUseNvars(Int_t n){fUseNvars = n;}
      void SetBaggedSampleFraction(Double_t f){fBaggedSampleFraction = f;}
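
      // (usage sketch: programmatic configuration via the setters above,
      // e.g. from SetTuneParameters; "bdt" is a hypothetical MethodBDT*)
      //    bdt->SetNTrees(1000);
      //    bdt->SetMaxDepth(4);
      //    bdt->SetShrinkage(0.1);
      //    bdt->SetBaggedSampleFraction(0.6);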


      // get the forest
      inline const std::vector<TMVA::DecisionTree*> & GetForest() const;

      // get the forest
      inline const std::vector<const TMVA::Event*> & GetTrainingEvents() const;

      inline const std::vector<double> & GetBoostWeights() const;

      // return the individual relative variable importance
      std::vector<Double_t> GetVariableImportance();
      Double_t GetVariableImportance(UInt_t ivar);
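
      // (usage sketch: print the relative importance of each input variable
      // after training; "bdt" is a hypothetical MethodBDT*)
      //    std::vector<Double_t> importance = bdt->GetVariableImportance();
      //    for (UInt_t ivar = 0; ivar < importance.size(); ++ivar)
      //       std::cout << "variable " << ivar << " : " << importance[ivar] << std::endl;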

      Double_t TestTreeQuality( DecisionTree *dt );

      // make ROOT-independent C++ class for classifier response (classifier-specific implementation)
      void MakeClassSpecific( std::ostream&, const TString& ) const;

      // header and auxiliary classes
      void MakeClassSpecificHeader( std::ostream&, const TString& ) const;

      void MakeClassInstantiateNode( DecisionTreeNode *n, std::ostream& fout,
                                     const TString& className ) const;

      void GetHelpMessage() const;

   protected:
      void DeclareCompatibilityOptions();

   private:
      // Init used in the various constructors
      void Init( void );

      void PreProcessNegativeEventWeights();

      // boosting algorithm (adaptive boosting)
      Double_t AdaBoost( std::vector<const TMVA::Event*>&, DecisionTree *dt );

      // boosting algorithm (adaptive boosting with cost matrix)
      Double_t AdaCost( std::vector<const TMVA::Event*>&, DecisionTree *dt );

      // boosting as a random re-weighting
      Double_t Bagging( );

      // boosting special for regression
      Double_t RegBoost( std::vector<const TMVA::Event*>&, DecisionTree *dt );

      // adaboost adapted to regression
      Double_t AdaBoostR2( std::vector<const TMVA::Event*>&, DecisionTree *dt );

      // binomial likelihood gradient boost for classification
      // (see Friedman: "Greedy Function Approximation: A Gradient Boosting Machine",
      // Technical Report, Dept. of Statistics, Stanford University; a sketch of
      // the update rule follows these declarations)
      Double_t GradBoost( std::vector<const TMVA::Event*>&, DecisionTree *dt, UInt_t cls = 0);
      Double_t GradBoostRegression(std::vector<const TMVA::Event*>&, DecisionTree *dt );
      void InitGradBoost( std::vector<const TMVA::Event*>&);
      void UpdateTargets( std::vector<const TMVA::Event*>&, UInt_t cls = 0);
      void UpdateTargetsRegression( std::vector<const TMVA::Event*>&,Bool_t first=kFALSE);
      Double_t GetGradBoostMVA(const TMVA::Event *e, UInt_t nTrees);
      void     GetBaggedSubSample(std::vector<const TMVA::Event*>&);
      Double_t GetWeightedQuantile(std::vector<std::pair<Double_t, Double_t> > vec, const Double_t quantile, const Double_t SumOfWeights = 0.0);
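
      // Sketch of the GradBoost update rule referenced above (a summary of
      // Friedman's binomial log-likelihood recipe, not a line-by-line
      // description of this implementation): for labels y in {-1,+1} and
      // current model F(x), each event is assigned the negative gradient of
      // the loss as its training target,
      //    r_i = 2*y_i / (1 + exp(2*y_i*F(x_i))),
      // a new regression tree is fitted to the r_i, and its leaf responses,
      // scaled by the learning rate fShrinkage, are added to F(x).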

      std::vector<const TMVA::Event*>       fEventSample;     // the training events
      std::vector<const TMVA::Event*>       fValidationSample;// the Validation events
      std::vector<const TMVA::Event*>       fSubSample;       // subsample for bagged grad boost
      std::vector<const TMVA::Event*>      *fTrainSample;     // pointer to the sample actually used in training (fEventSample or fSubSample)

      Int_t                           fNTrees;          // number of decision trees requested
      std::vector<DecisionTree*>      fForest;          // the collection of decision trees
      std::vector<double>             fBoostWeights;    // the weights applied in the individual boosts
      Double_t                        fSigToBkgFraction;// Signal to Background fraction assumed during training
      TString                         fBoostType;       // string specifying the boost type
      Double_t                        fAdaBoostBeta;    // beta parameter for AdaBoost algorithm
      TString                         fAdaBoostR2Loss;  // loss type used in AdaBoostR2 (Linear, Quadratic, or Exponential)
      Double_t                        fTransitionPoint; // break-down point for gradient regression
      Double_t                        fShrinkage;       // learning rate for gradient boost
      Bool_t                          fBaggedBoost;     // turn bagging in combination with boost on/off
      Bool_t                          fBaggedGradBoost; // turn bagging in combination with grad boost on/off
      Double_t                        fSumOfWeights;    // sum of all event weights
      std::map< const TMVA::Event*, std::pair<Double_t, Double_t> >       fWeightedResiduals;  // weighted regression residuals
      std::map< const TMVA::Event*,std::vector<double> > fResiduals; // individual event residuals for gradient boost

      //options for the decision Tree
      SeparationBase                 *fSepType;         // the separation used in node splitting
      TString                         fSepTypeS;        // the separation (option string) used in node splitting
      Int_t                           fMinNodeEvents;   // min number of events in node
      Float_t                         fMinNodeSize;     // min percentage of training events in node
      TString                         fMinNodeSizeS;    // string containing min percentage of training events in node

      Int_t                           fNCuts;           // number of grid points in the variable range used to find the optimal cut in node splitting
      Bool_t                          fUseFisherCuts;   // use multivariate splits using the Fisher criterion
      Double_t                        fMinLinCorrForFisher; // the minimum linear correlation between two variables demanded for use in the Fisher criterion in node splitting
      Bool_t                          fUseExclusiveVars; // variables already used in the Fisher criterion are no longer analysed individually for node splitting
      Bool_t                          fUseYesNoLeaf;    // classify leaf nodes as pure sig or bkg, rather than by their sig/bkg purity
      Double_t                        fNodePurityLimit; // purity limit for sig/bkg nodes
      UInt_t                          fNNodesMax;       // max # of nodes
      UInt_t                          fMaxDepth;        // max depth

      DecisionTree::EPruneMethod       fPruneMethod;     // method used for pruning
      TString                          fPruneMethodS;    // prune method option string
      Double_t                         fPruneStrength;   // a parameter to set the "amount" of pruning; needs to be adjusted
      Double_t                         fFValidationEvents;    // fraction of events to use for pruning
      Bool_t                           fAutomatic;       // use user given prune strength or automatically determined one using a validation sample
      Bool_t                           fRandomisedTrees; // choose a random subset of possible cut variables at each node during training
      UInt_t                           fUseNvars;        // the number of variables used in the randomised tree splitting
      Bool_t                           fUsePoissonNvars; // use "fUseNvars" not as a fixed number but as the mean of a Poisson distribution in each split
      UInt_t                           fUseNTrainEvents; // number of randomly picked training events used in randomised (and bagged) trees

      Double_t                         fBaggedSampleFraction;     // relative size of bagged event sample to original sample size
      TString                          fNegWeightTreatment;     // variable that holds the option of how to treat negative event weights in training
      Bool_t                           fNoNegWeightsInTraining; // ignore negative event weights in the training
      Bool_t                           fInverseBoostNegWeights; // boost events with negative weights with 1/boostweight rather than boostweight
      Bool_t                           fPairNegWeightsGlobal;   // pair events with negative and positive weights in the training sample and "annihilate" them
      Bool_t                           fTrainWithNegWeights; // the training sample contains negative event weights and they are not ignored
      Bool_t                           fDoBoostMonitor; // create a control plot of ROC integral vs. tree number


      //some histograms for monitoring
      TTree*                           fMonitorNtuple;   // monitoring ntuple
      Int_t                            fITree;           // ntuple var: ith tree
      Double_t                         fBoostWeight;     // ntuple var: boost weight
      Double_t                         fErrorFraction;   // ntuple var: misclassification error fraction

      Double_t                         fCss;             // AdaCost: cost factor for a true signal selected as signal
      Double_t                         fCts_sb;          // AdaCost: cost factor for a true signal selected as background
      Double_t                         fCtb_ss;          // AdaCost: cost factor for a true background selected as signal
      Double_t                         fCbb;             // AdaCost: cost factor for a true background selected as background
      
      Bool_t                           fDoPreselection;  // do or do not perform automatic pre-selection of 100% eff. cuts

      std::vector<Double_t>            fVariableImportance; // the relative importance of the different variables


      void                             DeterminePreselectionCuts(const std::vector<const TMVA::Event*>& eventSample);
      Double_t                         ApplyPreselectionCuts(const Event* ev);
      
      std::vector<Double_t> fLowSigCut;
      std::vector<Double_t> fLowBkgCut;
      std::vector<Double_t> fHighSigCut;
      std::vector<Double_t> fHighBkgCut;
      
      std::vector<Bool_t>  fIsLowSigCut;  
      std::vector<Bool_t>  fIsLowBkgCut;  
      std::vector<Bool_t>  fIsHighSigCut; 
      std::vector<Bool_t>  fIsHighBkgCut; 
      
      Bool_t fHistoricBool; //historic variable, only needed for "CompatibilityOptions" 


      // debugging flags
      static const Int_t               fgDebugLevel;     // debug level determining some printout/control plots etc.

      // for backward compatibility

      ClassDef(MethodBDT,0)  // Analysis of Boosted Decision Trees
   };

} // namespace TMVA

const std::vector<TMVA::DecisionTree*>& TMVA::MethodBDT::GetForest()         const { return fForest; }
const std::vector<const TMVA::Event*> & TMVA::MethodBDT::GetTrainingEvents() const { return fEventSample; }
const std::vector<double>&              TMVA::MethodBDT::GetBoostWeights()   const { return fBoostWeights; }

#endif