// @(#)root/tmva $Id: MethodBDT.h 29122 2009-06-22 06:51:30Z brun $
// Author: Andreas Hoecker, Joerg Stelzer, Helge Voss, Kai Voss

/**********************************************************************************
 * Project: TMVA - a Root-integrated toolkit for multivariate data analysis       *
 * Package: TMVA                                                                  *
 * Class  : MethodBDT  (Boosted Decision Trees)                                   *
 * Web    : http://tmva.sourceforge.net                                           *
 *                                                                                *
 * Description:                                                                   *
 *      Analysis of Boosted Decision Trees                                        *
 *                                                                                *
 * Authors (alphabetical):                                                        *
 *      Andreas Hoecker <Andreas.Hocker@cern.ch> - CERN, Switzerland              *
 *      Helge Voss      <Helge.Voss@cern.ch>     - MPI-K Heidelberg, Germany      *
 *      Kai Voss        <Kai.Voss@cern.ch>       - U. of Victoria, Canada         *
 *      Doug Schouten   <dschoute@sfu.ca>        - Simon Fraser U., Canada        *
 *                                                                                *
 * Copyright (c) 2005:                                                            *
 *      CERN, Switzerland                                                         *
 *      U. of Victoria, Canada                                                    *
 *      MPI-K Heidelberg, Germany                                                 *
 *                                                                                *
 * Redistribution and use in source and binary forms, with or without             *
 * modification, are permitted according to the terms listed in LICENSE           *
 * (http://tmva.sourceforge.net/LICENSE)                                          *
 **********************************************************************************/

#ifndef ROOT_TMVA_MethodBDT
#define ROOT_TMVA_MethodBDT

//////////////////////////////////////////////////////////////////////////
//                                                                      //
// MethodBDT                                                            //
//                                                                      //
// Analysis of Boosted Decision Trees                                   //
//                                                                      //
//////////////////////////////////////////////////////////////////////////

#include <vector>
#ifndef ROOT_TH2
#include "TH2.h"
#endif
#ifndef ROOT_TTree
#include "TTree.h"
#endif
#ifndef ROOT_TMVA_MethodBase
#include "TMVA/MethodBase.h"
#endif
#ifndef ROOT_TMVA_DecisionTree
#include "TMVA/DecisionTree.h"
#endif
#ifndef ROOT_TMVA_Event
#include "TMVA/Event.h"
#endif

namespace TMVA {

   class SeparationBase;

   // Boosted Decision Trees (BDT): trains a forest of decision trees on
   // (re-)weighted versions of the training sample and combines the
   // individual tree responses into a single classification/regression output.
   class MethodBDT : public MethodBase {

   public:
      // constructor for training and reading
      MethodBDT( const TString& jobName,
                 const TString& methodTitle,
                 DataSetInfo& theData,
                 const TString& theOption = "",
                 TDirectory* theTargetDir = 0 );

      // constructor for calculating BDT-MVA using previously generated decision trees
      MethodBDT( DataSetInfo& theData,
                 const TString& theWeightFile,
                 TDirectory* theTargetDir = NULL );

      virtual ~MethodBDT( void );

      // report whether the given analysis type (with the given number of
      // classes/targets) is supported by this method
      virtual Bool_t HasAnalysisType( Types::EAnalysisType type, UInt_t numberClasses, UInt_t numberTargets );


      // write all Events from the Tree into a vector of Events, that are
      // more easily manipulated
      void InitEventSample();

      // training method
      void Train( void );

      using MethodBase::WriteWeightsToStream;
      using MethodBase::ReadWeightsFromStream;

      // write weights to file
      void WriteWeightsToStream( ostream& o ) const;
      void AddWeightsXMLTo( void* parent ) const;

      // read weights from file
      void ReadWeightsFromStream( istream& istr );
      void ReadWeightsFromXML(void* parent);

      // write method specific histos to target file
      void WriteMonitoringHistosToFile( void ) const;

      // calculate the MVA value; the overload taking useNTrees evaluates
      // only the first useNTrees trees of the forest
      Double_t GetMvaValue( Double_t* err = 0);
      Double_t GetMvaValue( Double_t* err , UInt_t useNTrees );

      // regression response
      const std::vector<Float_t>& GetRegressionValues();

      // apply the boost algorithm to a tree in the collection
      Double_t Boost( std::vector<Event*>, DecisionTree *dt, Int_t iTree );

      // ranking of input variables
      const Ranking* CreateRanking();

      // the option handling methods
      void DeclareOptions();
      void ProcessOptions();

      // get the forest
      inline const std::vector<DecisionTree*> & GetForest() const;

      // get the training events
      inline const std::vector<Event*> & GetTrainingEvents() const;

      // get the boost weights (one entry per tree in the forest)
      inline const std::vector<double> & GetBoostWeights() const;

      //return the individual relative variable importance
      std::vector<Double_t> GetVariableImportance();
      Double_t GetVariableImportance(UInt_t ivar);

      // evaluate the quality of a single tree (on the validation sample)
      Double_t TestTreeQuality( DecisionTree *dt );

      // make ROOT-independent C++ class for classifier response (classifier-specific implementation)
      void MakeClassSpecific( std::ostream&, const TString& ) const;

      // header and auxiliary classes
      void MakeClassSpecificHeader( std::ostream&, const TString& ) const;

      // write standalone C++ code that re-instantiates the given tree node
      // (used when generating the ROOT-independent response class)
      void MakeClassInstantiateNode( DecisionTreeNode *n, std::ostream& fout,
                                     const TString& className ) const;

      // print method-specific help message
      void GetHelpMessage() const;

   private:
      // Init used in the various constructors
      void Init( void );

      // boosting algorithm (adaptive boosting)
      Double_t AdaBoost( std::vector<Event*>, DecisionTree *dt );

      // boosting as a random re-weighting
      Double_t Bagging( std::vector<Event*>, Int_t iTree );

      // boosting special for regression
      Double_t RegBoost( std::vector<Event*>, DecisionTree *dt );

      // adaboost adapted to regression
      Double_t AdaBoostR2( std::vector<Event*>, DecisionTree *dt );

      // binomial likelihood gradient boost for classification
      // (see Friedman: "Greedy Function Approximation: a Gradient Boosting Machine"
      // Technical report, Dept. of Statistics, Stanford University)
      Double_t GradBoost( std::vector<Event*>, DecisionTree *dt );
      void InitGradBoost( std::vector<Event*>);
      void UpdateTargets( std::vector<Event*>);
      Double_t GetGradBoostMVA(TMVA::Event& e, UInt_t nTrees);
      void GetRandomSubSample();

      std::vector<Event*>             fEventSample;     // the training events
      std::vector<Event*>             fValidationSample;// the Validation events
      std::vector<Event*>             fSubSample;       // subsample for bagged grad boost

      Int_t                           fNTrees;          // number of decision trees requested
      std::vector<DecisionTree*>      fForest;          // the collection of decision trees
      std::vector<double>             fBoostWeights;    // the weights applied in the individual boosts
      TString                         fBoostType;       // string specifying the boost type
      Double_t                        fAdaBoostBeta;    // beta parameter for AdaBoost algorithm
      TString                         fAdaBoostR2Loss;  // loss type used in AdaBoostR2 (Linear, Quadratic or Exponential)
      Double_t                        fShrinkage;       // learning rate for gradient boost
      Bool_t                          fBaggedGradBoost; // turn bagging in combination with grad boost on/off
      Double_t                        fSampleFraction;  // fraction of events used for bagged grad boost

      //options for the decision Tree
      SeparationBase                 *fSepType;         // the separation used in node splitting
      TString                         fSepTypeS;        // the separation (option string) used in node splitting
      Int_t                           fNodeMinEvents;   // min number of events in node

      Int_t                           fNCuts;           // grid used in cut applied in node splitting
      Bool_t                          fUseYesNoLeaf;    // use sig or bkg classification in leaf nodes, or the sig/bkg purity
      Double_t                        fNodePurityLimit; // purity limit for sig/bkg nodes
      Bool_t                          fUseWeightedTrees;// use average classification from the trees, or have the individual trees in the forest weighted (e.g. log(boostweight) from AdaBoost)
      UInt_t                          fNNodesMax;       // max # of nodes
      UInt_t                          fMaxDepth;        // max depth


      //some histograms for monitoring
      TTree*                           fMonitorNtuple;   // monitoring ntuple
      Int_t                            fITree;           // ntuple var: ith tree
      Double_t                         fBoostWeight;     // ntuple var: boost weight
      Double_t                         fErrorFraction;   // ntuple var: misclassification error fraction
      Double_t                         fPruneStrength;   // a parameter to set the "amount" of pruning..needs to be adjusted
      DecisionTree::EPruneMethod       fPruneMethod;     // method used for pruning
      TString                          fPruneMethodS;    // prune method option String
      Bool_t                           fPruneBeforeBoost;// flag to prune before boosting
      Double_t                         fFValidationEvents;    // fraction of events to use for pruning
      Bool_t                           fAutomatic;       // use user given prune strength or automatically determined one using a validation sample
      Bool_t                           fRandomisedTrees; // choose a random subset of possible cut variables at each node during training
      UInt_t                           fUseNvars;        // the number of variables used in the randomised tree splitting
      UInt_t                           fUseNTrainEvents; // number of randomly picked training events used in randomised (and bagged) trees

      std::vector<Double_t>           fVariableImportance; // the relative importance of the different variables

      // debugging flags
      static const Int_t  fgDebugLevel = 0;     // debug level determining some printout/control plots etc.


      ClassDef(MethodBDT,0)  // Analysis of Boosted Decision Trees
         };

} // namespace TMVA

// accessor: the trained forest of decision trees
const std::vector<TMVA::DecisionTree*>& TMVA::MethodBDT::GetForest() const
{
   return fForest;
}
// accessor: the vector of training events
const std::vector<TMVA::Event*>& TMVA::MethodBDT::GetTrainingEvents() const
{
   return fEventSample;
}
// accessor: the per-tree boost weights
const std::vector<double>& TMVA::MethodBDT::GetBoostWeights() const
{
   return fBoostWeights;
}

#endif