// @(#)root/tmva $Id$
// Author: Andreas Hoecker, Joerg Stelzer, Fredrik Tegenfeldt, Helge Voss

/**********************************************************************************
 * Project: TMVA - a Root-integrated toolkit for multivariate data analysis       *
 * Package: TMVA                                                                  *
 * Class  : RuleFit                                                               *
 * Web    : http://tmva.sourceforge.net                                           *
 *                                                                                *
 * Description:                                                                   *
 *      A class implementing various fits of rule ensembles                       *
 *                                                                                *
 * Authors (alphabetical):                                                        *
 *      Fredrik Tegenfeldt <Fredrik.Tegenfeldt@cern.ch> - Iowa State U., USA      *
 *      Helge Voss         <Helge.Voss@cern.ch>         - MPI-KP Heidelberg, Ger. *
 *                                                                                *
 * Copyright (c) 2005:                                                            *
 *      CERN, Switzerland                                                         *
 *      Iowa State U.                                                             *
 *      MPI-K Heidelberg, Germany                                                 *
 *                                                                                *
 * Redistribution and use in source and binary forms, with or without             *
 * modification, are permitted according to the terms listed in LICENSE           *
 * (http://tmva.sourceforge.net/LICENSE)                                          *
 **********************************************************************************/

#ifndef ROOT_TMVA_RuleFit
#define ROOT_TMVA_RuleFit

#include <algorithm>
#include <random>
#include <vector>

#ifndef ROOT_TMVA_DecisionTree
#include "TMVA/DecisionTree.h"
#endif
#ifndef ROOT_TMVA_RuleEnsemble
#include "TMVA/RuleEnsemble.h"
#endif
#ifndef ROOT_TMVA_RuleFitParams
#include "TMVA/RuleFitParams.h"
#endif
#ifndef ROOT_TMVA_Event
#include "TMVA/Event.h"
#endif

namespace TMVA {


   class MethodBase;
   class MethodRuleFit;
   class MsgLogger;

   // RuleFit: calculations for Friedman's RuleFit method.
   // Holds the training events (plus a randomly shuffled copy), the forest of
   // decision trees used for rule generation, the resulting RuleEnsemble and the
   // RuleFitParams used to fit its coefficients. It is set up from a MethodBase
   // (see the main constructor, Initialize and InitPtrs).
   class RuleFit {

   public:

      // main constructor
      RuleFit( const TMVA::MethodBase *rfbase );

      // empty constructor
      RuleFit( void );

      virtual ~RuleFit( void );

      void InitNEveEff();
      void InitPtrs( const TMVA::MethodBase *rfbase );
      void Initialize(  const TMVA::MethodBase *rfbase );

      void SetMsgType( EMsgType t );

      void SetTrainingEvents( const std::vector<const TMVA::Event *> & el );

      // Randomly permute fTrainingEventsRndm.
      // std::random_shuffle was deprecated in C++14 and removed in C++17; std::shuffle
      // with an explicit engine is its standard replacement. The engine (fRNGEngine) is
      // a default-seeded member. The permutation sequence therefore no longer depends on
      // std::rand()/srand(), and it is reproducible for the same sequence of calls.
      void ReshuffleEvents() { std::shuffle(fTrainingEventsRndm.begin(), fTrainingEventsRndm.end(), fRNGEngine); }

      void SetMethodBase( const MethodBase *rfbase );

      // make the forest of trees for rule generation
      void MakeForest();

      // build a tree
      void BuildTree( TMVA::DecisionTree *dt );

      // save event weights
      void SaveEventWeights();

      // restore saved event weights
      void RestoreEventWeights();

      // boost events based on the given tree
      void Boost( TMVA::DecisionTree *dt );

      // calculate and print some statistics on the given forest
      void ForestStatistics();

      // calculate the discriminating variable for the given event
      Double_t EvalEvent( const Event& e );

      // calculate the sum of the weights of the given events;
      // neve presumably limits how many events are summed (0 = all?) -- TODO confirm
      Double_t CalcWeightSum( const std::vector<const TMVA::Event *> *events, UInt_t neve=0 );

      // do the fitting of the coefficients
      void     FitCoefficients();

      // calculate variable and rule importance from a set of events
      void     CalcImportance();

      // set usage of linear terms only
      void     SetModelLinear()                      { fRuleEnsemble.SetModelLinear(); }
      // set usage of rules only
      void     SetModelRules()                       { fRuleEnsemble.SetModelRules(); }
      // set usage of the full model (presumably rules and linear terms together; see RuleEnsemble)
      void     SetModelFull()                        { fRuleEnsemble.SetModelFull(); }
      // set minimum importance allowed
      void     SetImportanceCut( Double_t minimp=0 ) { fRuleEnsemble.SetImportanceCut(minimp); }
      // set minimum rule distance - see RuleEnsemble
      void     SetRuleMinDist( Double_t d )          { fRuleEnsemble.SetRuleMinDist(d); }
      // set gradient-directed path parameters (forwarded to RuleFitParams)
      void     SetGDTau( Double_t t=0.0 )       { fRuleFitParams.SetGDTau(t); }
      void     SetGDPathStep( Double_t s=0.01 ) { fRuleFitParams.SetGDPathStep(s); }
      void     SetGDNPathSteps( Int_t n=100 )   { fRuleFitParams.SetGDNPathSteps(n); }
      // make visualization histograms
      // (fVisHistsUseImp: kTRUE -> weight by importance, kFALSE -> weight by coefficient)
      void     SetVisHistsUseImp( Bool_t f ) { fVisHistsUseImp = f; }
      void     UseImportanceVisHists()       { fVisHistsUseImp = kTRUE; }
      void     UseCoefficientsVisHists()     { fVisHistsUseImp = kFALSE; }
      void     MakeVisHists();
      void     FillVisHistCut(const Rule * rule, std::vector<TH2F *> & hlist);
      void     FillVisHistCorr(const Rule * rule, std::vector<TH2F *> & hlist);
      void     FillCut(TH2F* h2,const TMVA::Rule *rule,Int_t vind);
      void     FillLin(TH2F* h2,Int_t vind);
      void     FillCorr(TH2F* h2,const TMVA::Rule *rule,Int_t v1, Int_t v2);
      void     NormVisHists(std::vector<TH2F *> & hlist);
      void     MakeDebugHists();
      Bool_t   GetCorrVars(TString & title, TString & var1, TString & var2);
      // accessors
      UInt_t        GetNTreeSample()            const { return fNTreeSample; }
      Double_t      GetNEveEff()                const { return fNEveEffTrain; } // reweighted number of events = sum(wi)
      // NOTE: no bounds check on i
      const Event*  GetTrainingEvent(UInt_t i)  const { return fTrainingEvents[i]; }
      Double_t      GetTrainingEventWeight(UInt_t i)  const { return fTrainingEvents[i]->GetWeight(); }

      //      const Event*  GetTrainingEvent(UInt_t i, UInt_t isub)  const { return &(fTrainingEvents[fSubsampleEvents[isub]])[i]; }

      const std::vector< const TMVA::Event * > & GetTrainingEvents()  const { return fTrainingEvents; }
      //      const std::vector< Int_t >               & GetSubsampleEvents() const { return fSubsampleEvents; }

      //      void  GetSubsampleEvents(Int_t sub, UInt_t & ibeg, UInt_t & iend) const;
      void  GetRndmSampleEvents(std::vector< const TMVA::Event * > & evevec, UInt_t nevents);
      //
      const std::vector< const TMVA::DecisionTree *> & GetForest()     const { return fForest; }
      const RuleEnsemble                       & GetRuleEnsemble()     const { return fRuleEnsemble; }
            RuleEnsemble                       * GetRuleEnsemblePtr()        { return &fRuleEnsemble; }
      const RuleFitParams                      & GetRuleFitParams()    const { return fRuleFitParams; }
            RuleFitParams                      * GetRuleFitParamsPtr()       { return &fRuleFitParams; }
      const MethodRuleFit                      * GetMethodRuleFit()    const { return fMethodRuleFit; }
      const MethodBase                         * GetMethodBase()       const { return fMethodBase; }

   private:

      // copy constructor -- private, so clients cannot copy a RuleFit
      RuleFit( const RuleFit & other );

      // copy method
      void Copy( const RuleFit & other );

      std::vector<const TMVA::Event *>    fTrainingEvents;      // all training events
      std::vector<const TMVA::Event *>    fTrainingEventsRndm;  // idem, but randomly shuffled
      std::vector<Double_t>               fEventWeights;        // original weights of the events - follows fTrainingEvents
      UInt_t                              fNTreeSample;         // number of events in sub sample = frac*neve

      Double_t                            fNEveEffTrain;    // reweighted number of events = sum(wi)
      std::vector< const TMVA::DecisionTree *>  fForest;    // the input forest of decision trees
      RuleEnsemble                        fRuleEnsemble;    // the ensemble of rules
      RuleFitParams                       fRuleFitParams;   // fit rule parameters
      const MethodRuleFit                *fMethodRuleFit;   // pointer to the method which initialized this RuleFit instance
      const MethodBase                   *fMethodBase;      // pointer to the method base which initialized this RuleFit instance
      Bool_t                              fVisHistsUseImp;  // if true, use importance as weight; else coef in vis hists

      mutable MsgLogger*                  fLogger;   // message logger
      MsgLogger& Log() const { return *fLogger; }    

      static const Int_t randSEED = 0; // set to 1 for debugging purposes or to zero for random seeds

      std::default_random_engine          fRNGEngine; //! engine used by ReshuffleEvents (transient, not streamed)

      ClassDef(RuleFit,0)  // Calculations for Friedman's RuleFit method
   };
}

#endif
 RuleFit.h:1
 RuleFit.h:2
 RuleFit.h:3
 RuleFit.h:4
 RuleFit.h:5
 RuleFit.h:6
 RuleFit.h:7
 RuleFit.h:8
 RuleFit.h:9
 RuleFit.h:10
 RuleFit.h:11
 RuleFit.h:12
 RuleFit.h:13
 RuleFit.h:14
 RuleFit.h:15
 RuleFit.h:16
 RuleFit.h:17
 RuleFit.h:18
 RuleFit.h:19
 RuleFit.h:20
 RuleFit.h:21
 RuleFit.h:22
 RuleFit.h:23
 RuleFit.h:24
 RuleFit.h:25
 RuleFit.h:26
 RuleFit.h:27
 RuleFit.h:28
 RuleFit.h:29
 RuleFit.h:30
 RuleFit.h:31
 RuleFit.h:32
 RuleFit.h:33
 RuleFit.h:34
 RuleFit.h:35
 RuleFit.h:36
 RuleFit.h:37
 RuleFit.h:38
 RuleFit.h:39
 RuleFit.h:40
 RuleFit.h:41
 RuleFit.h:42
 RuleFit.h:43
 RuleFit.h:44
 RuleFit.h:45
 RuleFit.h:46
 RuleFit.h:47
 RuleFit.h:48
 RuleFit.h:49
 RuleFit.h:50
 RuleFit.h:51
 RuleFit.h:52
 RuleFit.h:53
 RuleFit.h:54
 RuleFit.h:55
 RuleFit.h:56
 RuleFit.h:57
 RuleFit.h:58
 RuleFit.h:59
 RuleFit.h:60
 RuleFit.h:61
 RuleFit.h:62
 RuleFit.h:63
 RuleFit.h:64
 RuleFit.h:65
 RuleFit.h:66
 RuleFit.h:67
 RuleFit.h:68
 RuleFit.h:69
 RuleFit.h:70
 RuleFit.h:71
 RuleFit.h:72
 RuleFit.h:73
 RuleFit.h:74
 RuleFit.h:75
 RuleFit.h:76
 RuleFit.h:77
 RuleFit.h:78
 RuleFit.h:79
 RuleFit.h:80
 RuleFit.h:81
 RuleFit.h:82
 RuleFit.h:83
 RuleFit.h:84
 RuleFit.h:85
 RuleFit.h:86
 RuleFit.h:87
 RuleFit.h:88
 RuleFit.h:89
 RuleFit.h:90
 RuleFit.h:91
 RuleFit.h:92
 RuleFit.h:93
 RuleFit.h:94
 RuleFit.h:95
 RuleFit.h:96
 RuleFit.h:97
 RuleFit.h:98
 RuleFit.h:99
 RuleFit.h:100
 RuleFit.h:101
 RuleFit.h:102
 RuleFit.h:103
 RuleFit.h:104
 RuleFit.h:105
 RuleFit.h:106
 RuleFit.h:107
 RuleFit.h:108
 RuleFit.h:109
 RuleFit.h:110
 RuleFit.h:111
 RuleFit.h:112
 RuleFit.h:113
 RuleFit.h:114
 RuleFit.h:115
 RuleFit.h:116
 RuleFit.h:117
 RuleFit.h:118
 RuleFit.h:119
 RuleFit.h:120
 RuleFit.h:121
 RuleFit.h:122
 RuleFit.h:123
 RuleFit.h:124
 RuleFit.h:125
 RuleFit.h:126
 RuleFit.h:127
 RuleFit.h:128
 RuleFit.h:129
 RuleFit.h:130
 RuleFit.h:131
 RuleFit.h:132
 RuleFit.h:133
 RuleFit.h:134
 RuleFit.h:135
 RuleFit.h:136
 RuleFit.h:137
 RuleFit.h:138
 RuleFit.h:139
 RuleFit.h:140
 RuleFit.h:141
 RuleFit.h:142
 RuleFit.h:143
 RuleFit.h:144
 RuleFit.h:145
 RuleFit.h:146
 RuleFit.h:147
 RuleFit.h:148
 RuleFit.h:149
 RuleFit.h:150
 RuleFit.h:151
 RuleFit.h:152
 RuleFit.h:153
 RuleFit.h:154
 RuleFit.h:155
 RuleFit.h:156
 RuleFit.h:157
 RuleFit.h:158
 RuleFit.h:159
 RuleFit.h:160
 RuleFit.h:161
 RuleFit.h:162
 RuleFit.h:163
 RuleFit.h:164
 RuleFit.h:165
 RuleFit.h:166
 RuleFit.h:167
 RuleFit.h:168
 RuleFit.h:169
 RuleFit.h:170
 RuleFit.h:171
 RuleFit.h:172
 RuleFit.h:173
 RuleFit.h:174
 RuleFit.h:175
 RuleFit.h:176
 RuleFit.h:177
 RuleFit.h:178
 RuleFit.h:179
 RuleFit.h:180
 RuleFit.h:181
 RuleFit.h:182
 RuleFit.h:183
 RuleFit.h:184
 RuleFit.h:185