#ifndef ROOT_TMVA_RuleFitAPI
#define ROOT_TMVA_RuleFitAPI
#include <fstream>
#include <vector>

#ifndef ROOT_TMVA_MethodRuleFit
#include "TMVA/MethodRuleFit.h"
#endif
#ifndef ROOT_TMVA_RuleFit
#include "TMVA/RuleFit.h"
#endif
namespace TMVA {
   class MsgLogger;
   class RuleFitAPI {
   public:
      RuleFitAPI( const TMVA::MethodRuleFit *rfbase, TMVA::RuleFit *rulefit, EMsgType minType );
      virtual ~RuleFitAPI();
      
      void WelcomeMessage();
      
      void HowtoSetupRF();
      
      void SetRFWorkDir(const char * wdir);
      
      void CheckRFWorkDir();
      
      inline void TrainRuleFit();
      inline void TestRuleFit();
      inline void VarImp();
      
      Bool_t ReadModelSum();
      
      const TString GetRFWorkDir() const { return fRFWorkDir; }
   protected:
      enum ERFMode    { kRfRegress=1, kRfClass=2 };          
      enum EModel     { kRfLinear=0, kRfRules=1, kRfBoth=2 }; 
      enum ERFProgram { kRfTrain=0, kRfPredict, kRfVarimp };    
  
      
      typedef struct {
         Int_t mode;
         Int_t lmode;
         Int_t n;
         Int_t p;
         Int_t max_rules;
         Int_t tree_size;
         Int_t path_speed;
         Int_t path_xval;
         Int_t path_steps;
         Int_t path_testfreq;
         Int_t tree_store;
         Int_t cat_store;
      } IntParms;
      
      typedef struct {
         Float_t  xmiss;
         Float_t  trim_qntl;
         Float_t  huber;
         Float_t  inter_supp;
         Float_t  memory_par;
         Float_t  samp_fract;
         Float_t  path_inc;
         Float_t  conv_fac;
      } RealParms;
      
      void InitRuleFit();
      void FillRealParmsDef();
      void FillIntParmsDef();
      void ImportSetup();
      void SetTrainParms();
      void SetTestParms();
      
      Int_t  RunRuleFit();
      
      void SetRFTrain()   { fRFProgram = kRfTrain; }
      void SetRFPredict() { fRFProgram = kRfPredict; }
      void SetRFVarimp()  { fRFProgram = kRfVarimp; }
      
      inline TString GetRFName(TString name);
      inline Bool_t  OpenRFile(TString name, std::ofstream & f);
      inline Bool_t  OpenRFile(TString name, std::ifstream & f);
      
      inline Bool_t WriteInt(ofstream &   f, const Int_t   *v, Int_t n=1);
      inline Bool_t WriteFloat(ofstream & f, const Float_t *v, Int_t n=1);
      inline Int_t  ReadInt(ifstream & f,   Int_t *v, Int_t n=1) const;
      inline Int_t  ReadFloat(ifstream & f, Float_t *v, Int_t n=1) const;
  
      
      Bool_t WriteAll();
      Bool_t WriteIntParms();
      Bool_t WriteRealParms();
      Bool_t WriteLx();
      Bool_t WriteProgram();
      Bool_t WriteRealVarImp();
      Bool_t WriteRfOut();
      Bool_t WriteRfStatus();
      Bool_t WriteRuleFitMod();
      Bool_t WriteRuleFitSum();
      Bool_t WriteTrain();
      Bool_t WriteVarNames();
      Bool_t WriteVarImp();
      Bool_t WriteYhat();
      Bool_t WriteTest();
      
      Bool_t ReadYhat();
      Bool_t ReadIntParms();
      Bool_t ReadRealParms();
      Bool_t ReadLx();
      Bool_t ReadProgram();
      Bool_t ReadRealVarImp();
      Bool_t ReadRfOut();
      Bool_t ReadRfStatus();
      Bool_t ReadRuleFitMod();
      Bool_t ReadRuleFitSum();
      Bool_t ReadTrainX();
      Bool_t ReadTrainY();
      Bool_t ReadTrainW();
      Bool_t ReadVarNames();
      Bool_t ReadVarImp();
   private:
      
      RuleFitAPI();
      const MethodRuleFit *fMethodRuleFit; 
      RuleFit             *fRuleFit;       
      
      std::vector<Float_t> fRFYhat;      
      std::vector<Float_t> fRFVarImp;    
      std::vector<Int_t>   fRFVarImpInd; 
      TString              fRFWorkDir;   
      IntParms             fRFIntParms;  
      RealParms            fRFRealParms; 
      std::vector<int>     fRFLx;        
      ERFProgram           fRFProgram;   
      TString              fModelType;   
      mutable MsgLogger    fLogger;          
      ClassDef(RuleFitAPI,0)        
   };
} 
void TMVA::RuleFitAPI::TrainRuleFit()
{
   
   SetTrainParms();
   WriteAll();
   RunRuleFit();
}
void TMVA::RuleFitAPI::TestRuleFit()
{
   
   SetTestParms();
   WriteAll();
   RunRuleFit();
   ReadYhat(); 
}
void TMVA::RuleFitAPI::VarImp()
{
   
   SetRFVarimp();
   WriteAll();
   RunRuleFit();
   ReadVarImp(); 
}
TString TMVA::RuleFitAPI::GetRFName(TString name)
{
   // Build the full path of a RuleFit file: "<work dir>/<name>".
   TString path(fRFWorkDir);
   path += "/";
   path += name;
   return path;
}
Bool_t TMVA::RuleFitAPI::OpenRFile(TString name, std::ofstream & f)
{
   // Open the file GetRFName(name) for writing through f.
   // Returns kTRUE on success; otherwise logs an error and returns kFALSE.
   // NOTE(review): the stream is opened in the default (text) mode while WriteInt/WriteFloat
   // emit raw bytes; on platforms with newline translation this may need std::ios::binary — confirm.
   const TString path = GetRFName(name);
   f.open(path.Data());
   if (f.is_open()) return kTRUE;

   fLogger << kERROR << "Error opening RuleFit file for output: "
           << path << Endl;
   return kFALSE;
}
Bool_t TMVA::RuleFitAPI::OpenRFile(TString name, std::ifstream & f)
{
   // Open the file GetRFName(name) for reading through f.
   // Returns kTRUE on success; otherwise logs an error and returns kFALSE.
   const TString path = GetRFName(name);
   f.open(path.Data());
   if (f.is_open()) return kTRUE;

   fLogger << kERROR << "Error opening RuleFit file for input: "
           << path << Endl;
   return kFALSE;
}
Bool_t TMVA::RuleFitAPI::WriteInt(std::ofstream &   f, const Int_t   *v, Int_t n)
{
   // Write n Int_t values from v to f as raw bytes (machine-native representation).
   // Returns kTRUE if the stream is open, n is non-negative and the write succeeded.
   if (!f.is_open()) return kFALSE;
   if (n < 0)        return kFALSE; // a negative count would wrap around in n*sizeof(Int_t)

   const std::streamsize nbytes =
      static_cast<std::streamsize>(n) * static_cast<std::streamsize>(sizeof(Int_t));

   // Since C++11, std::basic_ios::operator bool is explicit, so the stream cannot be
   // implicitly converted to the Bool_t return value; convert it explicitly.
   return static_cast<Bool_t>(f.write(reinterpret_cast<char const *>(v), nbytes));
}
Bool_t TMVA::RuleFitAPI::WriteFloat(std::ofstream & f, const Float_t *v, Int_t n)
{
   // Write n Float_t values from v to f as raw bytes (machine-native representation).
   // Returns kTRUE if the stream is open, n is non-negative and the write succeeded.
   if (!f.is_open()) return kFALSE;
   if (n < 0)        return kFALSE; // a negative count would wrap around in n*sizeof(Float_t)

   const std::streamsize nbytes =
      static_cast<std::streamsize>(n) * static_cast<std::streamsize>(sizeof(Float_t));

   // Since C++11, std::basic_ios::operator bool is explicit, so the stream cannot be
   // implicitly converted to the Bool_t return value; convert it explicitly.
   return static_cast<Bool_t>(f.write(reinterpret_cast<char const *>(v), nbytes));
}
Int_t TMVA::RuleFitAPI::ReadInt(std::ifstream & f,   Int_t *v, Int_t n) const
{
   // Read n Int_t values (raw, machine-native bytes) from f into v.
   // Returns 1 if all n values were read, 0 if the stream is not open,
   // n is negative, or the read came up short.
   if (!f.is_open() || n < 0) return 0;

   const std::streamsize nbytes =
      static_cast<std::streamsize>(n) * static_cast<std::streamsize>(sizeof(Int_t));

   return f.read(reinterpret_cast<char *>(v), nbytes) ? 1 : 0;
}
Int_t TMVA::RuleFitAPI::ReadFloat(std::ifstream & f, Float_t *v, Int_t n) const
{
   // Read n Float_t values (raw, machine-native bytes) from f into v.
   // Returns 1 if all n values were read, 0 if the stream is not open,
   // n is negative, or the read came up short.
   if (!f.is_open() || n < 0) return 0;

   const std::streamsize nbytes =
      static_cast<std::streamsize>(n) * static_cast<std::streamsize>(sizeof(Float_t));

   return f.read(reinterpret_cast<char *>(v), nbytes) ? 1 : 0;
}
#endif // ROOT_TMVA_RuleFitAPI
// This page has been automatically generated. If you have any comments or suggestions about the page layout send a mail to ROOT support, or contact the developers with any questions or problems regarding ROOT.