#ifndef ROOT_TMVA_MethodRuleFit
#define ROOT_TMVA_MethodRuleFit
#ifndef ROOT_TMVA_MethodBase
#include "TMVA/MethodBase.h"
#endif
#ifndef ROOT_TMatrixDfwd
#include "TMatrixDfwd.h"
#endif
#ifndef ROOT_TVectorD
#include "TVectorD.h"
#endif
#ifndef ROOT_TMVA_DecisionTree
#include "TMVA/DecisionTree.h"
#endif
#ifndef ROOT_TMVA_RuleFit
#include "TMVA/RuleFit.h"
#endif
namespace TMVA {

   class SeparationBase;

   /// TMVA method wrapping a RuleFit model, held by value in `fRuleFit`.
   ///
   /// Two training paths are declared: TrainTMVARuleFit() and TrainJFRuleFit().
   /// NOTE(review): judging by the names and by the fUseRuleFitJF / fRFWorkDir /
   /// fRFNrules / fRFNendnodes members, the JF path presumably drives an external
   /// RuleFit implementation run from a work directory — confirm in the .cxx.
   class MethodRuleFit : public MethodBase {

   public:

      /// Construct with a job name, method title, dataset description and
      /// option string (`theOption`).
      MethodRuleFit( const TString& jobName,
                     const TString& methodTitle,
                     DataSetInfo& theData,
                     const TString& theOption = "",
                     TDirectory* theTargetDir = 0 );

      /// Construct from a dataset description and a weight file
      /// (presumably for applying an already-trained method — TODO confirm).
      MethodRuleFit( DataSetInfo& theData,
                     const TString& theWeightFile,
                     TDirectory* theTargetDir = 0 );

      virtual ~MethodRuleFit( void );

      virtual Bool_t HasAnalysisType( Types::EAnalysisType type, UInt_t numberClasses, UInt_t );

      void Train( void );

      // Re-expose any other MethodBase::ReadWeightsFromStream overloads; the
      // declaration below would otherwise hide them (C++ name hiding).
      using MethodBase::ReadWeightsFromStream;

      void AddWeightsXMLTo ( void* parent ) const;
      // std::-qualified, consistent with the std::ostream parameters below. The
      // unqualified `istream` only compiled if another header injected it into
      // the global namespace.
      void ReadWeightsFromStream( std::istream& istr );
      void ReadWeightsFromXML ( void* wghtnode );

      Double_t GetMvaValue( Double_t* err = 0, Double_t* errUpper = 0 );

      void WriteMonitoringHistosToFile( void ) const;

      const Ranking* CreateRanking();

      Bool_t UseBoost() const { return fUseBoost; }

      // Pointers to the embedded (by-value) RuleFit object.
      RuleFit* GetRuleFitPtr() { return &fRuleFit; }
      const RuleFit* GetRuleFitConstPtr() const { return &fRuleFit; }

      // Forwards to MethodBase::BaseDir().
      TDirectory* GetMethodBaseDir() const { return BaseDir(); }

      // Read-only views of the stored event sample and tree forest.
      const std::vector<TMVA::Event*>& GetTrainingEvents() const { return fEventSample; }
      const std::vector<TMVA::DecisionTree*>& GetForest() const { return fForest; }

      // Plain getters for the configured option values.
      Int_t GetNTrees() const { return fNTrees; }
      Double_t GetTreeEveFrac() const { return fTreeEveFrac; }
      const SeparationBase* GetSeparationBaseConst() const { return fSepType; }
      // NOTE: hands out a mutable SeparationBase* from a const member function.
      SeparationBase* GetSeparationBase() const { return fSepType; }
      TMVA::DecisionTree::EPruneMethod GetPruneMethod() const { return fPruneMethod; }
      Double_t GetPruneStrength() const { return fPruneStrength; }
      Double_t GetMinFracNEve() const { return fMinFracNEve; }
      Double_t GetMaxFracNEve() const { return fMaxFracNEve; }
      Int_t GetNCuts() const { return fNCuts; }
      Int_t GetGDNPathSteps() const { return fGDNPathSteps; }
      Double_t GetGDPathStep() const { return fGDPathStep; }
      Double_t GetGDErrScale() const { return fGDErrScale; }
      Double_t GetGDPathEveFrac() const { return fGDPathEveFrac; }
      Double_t GetGDValidEveFrac() const { return fGDValidEveFrac; }
      Double_t GetLinQuantile() const { return fLinQuantile; }
      const TString GetRFWorkDir() const { return fRFWorkDir; }
      Int_t GetRFNrules() const { return fRFNrules; }
      Int_t GetRFNendnodes() const { return fRFNendnodes; }

   protected:

      // Standalone-class code generation helpers.
      void MakeClassSpecific( std::ostream&, const TString& ) const;
      void MakeClassRuleCuts( std::ostream& ) const;
      void MakeClassLinear( std::ostream& ) const;

      void GetHelpMessage() const;

      void Init( void );
      void InitEventSample( void );
      void InitMonitorNtuple();
      void TrainTMVARuleFit();
      void TrainJFRuleFit();

   private:

      // Range checks for option values (definitions follow the class):
      //  - 5-arg: clamp var to [vmin,vmax] and warn via mlog; returns kTRUE if changed
      //  - 6-arg: reset an out-of-range var to vdef and warn; returns kTRUE if changed
      //  - 3-arg: returns +1 if var > vmax, -1 if var < vmin, else 0
      template<typename T>
      inline Bool_t VerifyRange( MsgLogger& mlog, const char *varstr, T& var, const T& vmin, const T& vmax );
      template<typename T>
      inline Bool_t VerifyRange( MsgLogger& mlog, const char *varstr, T& var, const T& vmin, const T& vmax, const T& vdef );
      template<typename T>
      inline Int_t VerifyRange( const T& var, const T& vmin, const T& vmax );

      void DeclareOptions();
      void ProcessOptions();

      RuleFit fRuleFit;                          // the wrapped RuleFit model (owned by value)
      std::vector<TMVA::Event *> fEventSample;   // stored events; ownership not visible here — TODO confirm
      Double_t fSignalFraction;

      // Monitoring ntuple and its branch buffers (presumably filled per rule
      // in InitMonitorNtuple()/training — TODO confirm in the .cxx).
      TTree *fMonitorNtuple;
      Double_t fNTImportance;
      Double_t fNTCoefficient;
      Double_t fNTSupport;
      Int_t fNTNcuts;
      Int_t fNTNvars;
      Double_t fNTPtag;
      Double_t fNTPss;
      Double_t fNTPsb;
      Double_t fNTPbs;
      Double_t fNTPbb;
      Double_t fNTSSB;
      Int_t fNTType;

      // Module selection and settings that presumably apply to the JF path.
      TString fRuleFitModuleS;
      Bool_t fUseRuleFitJF;
      TString fRFWorkDir;
      Int_t fRFNrules;
      Int_t fRFNendnodes;

      // Forest / rule-generation settings.
      std::vector<DecisionTree *> fForest;       // ownership not visible here — TODO confirm
      Int_t fNTrees;
      Double_t fTreeEveFrac;
      SeparationBase *fSepType;                  // ownership not visible here — TODO confirm
      Double_t fMinFracNEve;
      Double_t fMaxFracNEve;
      Int_t fNCuts;
      TString fSepTypeS;                         // *S members: string-valued options,
      TString fPruneMethodS;                     // presumably parsed in ProcessOptions()
      TMVA::DecisionTree::EPruneMethod fPruneMethod;
      Double_t fPruneStrength;
      TString fForestTypeS;
      Bool_t fUseBoost;

      // Gradient-directed (GD) path-search settings (names only — TODO confirm semantics).
      Double_t fGDPathEveFrac;
      Double_t fGDValidEveFrac;
      Double_t fGDTau;
      Double_t fGDTauPrec;
      Double_t fGDTauMin;
      Double_t fGDTauMax;
      UInt_t fGDTauScan;
      Double_t fGDPathStep;
      Int_t fGDNPathSteps;
      Double_t fGDErrScale;

      // Model settings.
      Double_t fMinimp;
      TString fModelTypeS;
      Double_t fRuleMinDist;
      Double_t fLinQuantile;

      ClassDef(MethodRuleFit,0)
   };
}
template<typename T>
inline Int_t TMVA::MethodRuleFit::VerifyRange( const T& var, const T& vmin, const T& vmax )
{
   // Position of var relative to [vmin, vmax]:
   //   +1 -> var > vmax (checked first)
   //   -1 -> var < vmin
   //    0 -> otherwise (inside the range)
   Int_t side = 0;
   if (var > vmax) {
      side = 1;
   } else if (var < vmin) {
      side = -1;
   }
   return side;
}
template<typename T>
inline Bool_t TMVA::MethodRuleFit::VerifyRange( TMVA::MsgLogger& mlog, const char *varstr, T& var, const T& vmin, const T& vmax )
{
   // Clamp var into [vmin, vmax]. When it had to be moved, log a warning naming
   // the option (varstr) and the new value, and return kTRUE; otherwise kFALSE.
   const Int_t side = TMVA::MethodRuleFit::VerifyRange(var,vmin,vmax);
   if (side == 0) return kFALSE;

   // side is +1 (above) or -1 (below): snap to the violated bound
   var = (side == 1 ? vmax : vmin);
   mlog << kWARNING << "Option <" << varstr << "> " << (side==1 ? "above":"below") << " allowed range. Reset to new value = " << var << Endl;
   return kTRUE;
}
template<typename T>
inline Bool_t TMVA::MethodRuleFit::VerifyRange( TMVA::MsgLogger& mlog, const char *varstr, T& var, const T& vmin, const T& vmax, const T& vdef )
{
   // If var lies outside [vmin, vmax], replace it with the default vdef, log a
   // warning naming the option (varstr) and the value, and return kTRUE.
   // An in-range var is left untouched and kFALSE is returned.
   const Int_t side = TMVA::MethodRuleFit::VerifyRange(var,vmin,vmax);
   const Bool_t outside = (side != 0);
   if (outside) {
      var = vdef;
      mlog << kWARNING << "Option <" << varstr << "> " << (side==1 ? "above":"below") << " allowed range. Reset to default value = " << var << Endl;
   }
   return outside;
}
#endif // ROOT_TMVA_MethodRuleFit