#ifndef ROOT_TMVA_RuleFit
#define ROOT_TMVA_RuleFit
#include <algorithm>
#include <random>
#ifndef ROOT_TMVA_DecisionTree
#include "TMVA/DecisionTree.h"
#endif
#ifndef ROOT_TMVA_RuleEnsemble
#include "TMVA/RuleEnsemble.h"
#endif
#ifndef ROOT_TMVA_RuleFitParams
#include "TMVA/RuleFitParams.h"
#endif
#ifndef ROOT_TMVA_Event
#include "TMVA/Event.h"
#endif
#ifndef ROOT_TMVA_MsgLogger
#include "TMVA/MsgLogger.h"
#endif
namespace TMVA {
   class MethodBase;
   class MethodRuleFit;

   /// RuleFit holds the state used by the RuleFit method:
   ///  - the training events (fTrainingEvents) and a separately shuffleable copy
   ///    of that pointer list (fTrainingEventsRndm),
   ///  - a forest of decision trees (fForest),
   ///  - the rule ensemble (fRuleEnsemble) and fit parameters (fRuleFitParams),
   ///  - non-owning pointers back to the calling method (fMethodRuleFit, fMethodBase).
   /// Most member functions are defined out of line; comments on those are
   /// hedged where the implementation is not visible in this header.
   class RuleFit {
   public:

      /// Construct bound to the given method.
      /// NOTE(review): presumably calls Initialize(rfbase) -- confirm in the .cxx.
      RuleFit( const TMVA::MethodBase *rfbase );

      /// Default constructor; the method is presumably attached later via
      /// Initialize()/SetMethodBase() -- TODO confirm.
      RuleFit( void );
      virtual ~RuleFit( void );

      /// Presumably computes fNEveEffTrain (see GetNEveEff()) -- TODO confirm.
      void InitNEveEff();
      /// Presumably sets fMethodBase / fMethodRuleFit from rfbase -- TODO confirm.
      void InitPtrs( const TMVA::MethodBase *rfbase );
      void Initialize(  const TMVA::MethodBase *rfbase );
      void SetMsgType( EMsgType t );
      void SetTrainingEvents( const std::vector<TMVA::Event *> & el );

      /// Randomly permute fTrainingEventsRndm in place; fTrainingEvents is untouched.
      /// Uses std::shuffle with the per-instance engine fRNGEngine because
      /// std::random_shuffle was deprecated in C++14 and removed in C++17.
      void ReshuffleEvents() { std::shuffle(fTrainingEventsRndm.begin(), fTrainingEventsRndm.end(), fRNGEngine); }
      void SetMethodBase( const MethodBase *rfbase );

      /// Presumably builds the decision trees stored in fForest -- TODO confirm.
      void MakeForest();

      /// Presumably trains the given tree on a sample of training events -- TODO confirm.
      void BuildTree( TMVA::DecisionTree *dt );

      /// Presumably copies the current event weights into fEventWeights -- TODO confirm.
      void SaveEventWeights();

      /// Presumably restores event weights from fEventWeights -- TODO confirm.
      void RestoreEventWeights();

      /// Presumably reweights events after building dt (boosting step) -- TODO confirm.
      void Boost( TMVA::DecisionTree *dt );

      /// Presumably reports statistics over fForest -- TODO confirm.
      void ForestStatistics();

      /// Evaluate the model on one event (implementation not visible here).
      Double_t EvalEvent( const Event& e );

      /// Sum of event weights; neve presumably limits the count (0 = all) -- TODO confirm.
      Double_t CalcWeightSum( const std::vector<TMVA::Event *> *events, UInt_t neve=0 );

      /// Fit the ensemble coefficients (implementation not visible here).
      void     FitCoefficients();

      /// Compute rule/linear-term importances (implementation not visible here).
      void     CalcImportance();

      // Thin forwarders to the rule ensemble.
      void     SetModelLinear()                      { fRuleEnsemble.SetModelLinear(); }

      void     SetModelRules()                       { fRuleEnsemble.SetModelRules(); }

      void     SetModelFull()                        { fRuleEnsemble.SetModelFull(); }

      void     SetImportanceCut( Double_t minimp=0 ) { fRuleEnsemble.SetImportanceCut(minimp); }

      void     SetRuleMinDist( Double_t d )          { fRuleEnsemble.SetRuleMinDist(d); }

      // Thin forwarders to the gradient-directed path parameters.
      void     SetGDTau( Double_t t=0.0 )       { fRuleFitParams.SetGDTau(t); }
      void     SetGDPathStep( Double_t s=0.01 ) { fRuleFitParams.SetGDPathStep(s); }
      void     SetGDNPathSteps( Int_t n=100 )   { fRuleFitParams.SetGDNPathSteps(n); }

      // Select whether visualisation histograms are weighted by importance
      // (kTRUE) or by coefficients (kFALSE); stored in fVisHistsUseImp.
      void     SetVisHistsUseImp( Bool_t f ) { fVisHistsUseImp = f; }
      void     UseImportanceVisHists()       { fVisHistsUseImp = kTRUE; }
      void     UseCoefficientsVisHists()     { fVisHistsUseImp = kFALSE; }
      void     MakeVisHists();
      void     FillVisHistCut(const Rule * rule, std::vector<TH2F *> & hlist);
      void     FillVisHistCorr(const Rule * rule, std::vector<TH2F *> & hlist);
      void     FillCut(TH2F* h2,const TMVA::Rule *rule,Int_t vind);
      void     FillLin(TH2F* h2,Int_t vind);
      void     FillCorr(TH2F* h2,const TMVA::Rule *rule,Int_t v1, Int_t v2);
      void     NormVisHists(std::vector<TH2F *> & hlist);
      void     MakeDebugHists();
      Bool_t   GetCorrVars(TString & title, TString & var1, TString & var2);

      // Accessors. GetTrainingEvent / GetTrainingEventWeight do no bounds checking.
      UInt_t        GetNTreeSample()            const { return fNTreeSample; }
      Double_t      GetNEveEff()                const { return fNEveEffTrain; }
      const Event*  GetTrainingEvent(UInt_t i)  const { return fTrainingEvents[i]; }
      Double_t      GetTrainingEventWeight(UInt_t i)  const { return fTrainingEvents[i]->GetWeight(); }

      const std::vector< TMVA::Event * > & GetTrainingEvents()  const { return fTrainingEvents; }

      /// Fill evevec with nevents events; presumably drawn from fTrainingEventsRndm -- TODO confirm.
      void  GetRndmSampleEvents(std::vector< const TMVA::Event * > & evevec, UInt_t nevents);

      const std::vector< const TMVA::DecisionTree *> & GetForest()     const { return fForest; }
      const RuleEnsemble                       & GetRuleEnsemble()     const { return fRuleEnsemble; }
            RuleEnsemble                       * GetRuleEnsemblePtr()        { return &fRuleEnsemble; }
      const RuleFitParams                      & GetRuleFitParams()    const { return fRuleFitParams; }
            RuleFitParams                      * GetRuleFitParamsPtr()       { return &fRuleFitParams; }
      const MethodRuleFit                      * GetMethodRuleFit()    const { return fMethodRuleFit; }
      const MethodBase                         * GetMethodBase()       const { return fMethodBase; }
   private:

      /// Private copy constructor (pre-C++11 non-copyable idiom).
      /// NOTE(review): copy-assignment is not declared, so it is still
      /// compiler-generated -- consider whether that is intended.
      RuleFit( const RuleFit & other );


      std::vector<TMVA::Event *>          fTrainingEvents;      // training events
      std::vector<TMVA::Event *>          fTrainingEventsRndm;  // same pointers; permuted by ReshuffleEvents()
      std::vector<Double_t>               fEventWeights;        // saved event weights (see SaveEventWeights())
      UInt_t                              fNTreeSample;         // returned by GetNTreeSample()
      Double_t                            fNEveEffTrain;        // returned by GetNEveEff()
      std::vector< const TMVA::DecisionTree *>  fForest;        // decision-tree forest (non-const pointees not exposed)
      RuleEnsemble                        fRuleEnsemble;        // rule ensemble (by value)
      RuleFitParams                       fRuleFitParams;       // fit parameters (by value)
      const MethodRuleFit                *fMethodRuleFit;       // calling method (not owned, presumably)
      const MethodBase                   *fMethodBase;          // calling method base (not owned, presumably)
      Bool_t                              fVisHistsUseImp;      // kTRUE: importance-weighted vis hists; kFALSE: coefficients
      mutable MsgLogger                   fLogger;              // message logger
      // Engine used by ReshuffleEvents(). Default-constructed, so it starts from the
      // engine's default seed (reproducible for a given standard library).
      // Declared last so existing constructor init lists keep their order.
      std::default_random_engine          fRNGEngine;
      ClassDef(RuleFit,0)
   };
}
#endif
// Last change: Tue May 13 17:21:00 2008
// Last generated: 2008-05-13 17:21
// This page has been automatically generated. If you have any comments or suggestions about the page layout, send a mail to ROOT support, or contact the developers with any questions or problems regarding ROOT.