33 #ifndef ROOT_TMVA_MethodBase 34 #define ROOT_TMVA_MethodBase 53 #ifndef ROOT_TMVA_IMethod 56 #ifndef ROOT_TMVA_Configurable 59 #ifndef ROOT_TMVA_Types 62 #ifndef ROOT_TMVA_DataSet 65 #ifndef ROOT_TMVA_Event 68 #ifndef ROOT_TMVA_TransformationHandler 71 #ifndef ROOT_TMVA_Results 106 void Init(std::vector<TString>& graphTitles);
109 void AddPoint(std::vector<Double_t>& dat);
133 const TString& theOption =
"" );
147 virtual void CheckSetup();
159 virtual std::map<TString,Double_t> OptimizeTuningParameters(
TString fomType=
"ROCIntegral",
TString fitType=
"FitGA");
160 virtual void SetTuneParameters(std::map<TString,Double_t> tuneParameters);
162 virtual void Train() = 0;
173 virtual void TestClassification();
177 virtual void TestMulticlass();
188 virtual void Init() = 0;
189 virtual void DeclareOptions() = 0;
190 virtual void ProcessOptions() = 0;
191 virtual void DeclareCompatibilityOptions();
212 virtual std::vector<Double_t> GetMvaValues(
Long64_t firstEvt = 0,
Long64_t lastEvt = -1,
Bool_t logProgress =
false);
219 const std::vector<Float_t>* ptr = &GetRegressionValues();
225 std::vector<Float_t>* ptr =
new std::vector<Float_t>(0);
231 std::vector<Float_t>* ptr =
new std::vector<Float_t>(0);
243 virtual const Ranking* CreateRanking() = 0;
246 virtual void MakeClass(
const TString& classFileName =
TString(
"") )
const;
249 void PrintHelpMessage()
const;
255 void WriteStateToFile ()
const;
256 void ReadStateFromFile ();
260 virtual void AddWeightsXMLTo (
void* parent )
const = 0;
261 virtual void ReadWeightsFromXML (
void* wghtnode ) = 0;
262 virtual void ReadWeightsFromStream( std::istream& ) = 0;
268 void WriteStateToXML (
void* parent )
const;
269 void ReadStateFromXML (
void* parent );
270 void WriteStateToStream ( std::ostream& tf )
const;
271 void WriteVarsToStream ( std::ostream& tf,
const TString& prefix =
"" )
const;
275 void ReadStateFromStream ( std::istream& tf );
276 void ReadStateFromStream (
TFile& rf );
277 void ReadStateFromXMLString(
const char* xmlstr );
281 void AddVarsXMLTo (
void* parent )
const;
282 void AddSpectatorsXMLTo (
void* parent )
const;
283 void AddTargetsXMLTo (
void* parent )
const;
284 void AddClassesXMLTo (
void* parent )
const;
285 void ReadVariablesFromXML (
void* varnode );
286 void ReadSpectatorsFromXML(
void* specnode);
287 void ReadTargetsFromXML (
void* tarnode );
288 void ReadClassesFromXML (
void* clsnode );
289 void ReadVarsFromStream ( std::istream& istr );
298 virtual void WriteMonitoringHistosToFile()
const;
312 virtual std::vector<Float_t> GetMulticlassEfficiency( std::vector<std::vector<Float_t> >& purity );
313 virtual std::vector<Float_t> GetMulticlassTrainingEfficiency(std::vector<std::vector<Float_t> >& purity );
314 virtual Double_t GetSignificance()
const;
318 Double_t& optimal_significance_value )
const;
320 virtual Double_t GetSeparation(
PDF* pdfS = 0,
PDF* pdfB = 0 )
const;
330 const char*
GetName ()
const {
return fMethodName.Data(); }
333 TString GetWeightFileName()
const;
387 TString GetTrainingTMVAVersionString()
const;
388 TString GetTrainingROOTVersionString()
const;
392 if(fTransformationPointer && takeReroutedIfAvailable)
return *fTransformationPointer;
else return fTransformation;
396 if(fTransformationPointer && takeReroutedIfAvailable)
return *fTransformationPointer;
else return fTransformation;
414 const Event* GetEvent ()
const;
420 const std::vector<TMVA::Event*>& GetEventCollection(
Types::ETreeType type );
427 virtual Bool_t IsSignalLike();
443 bool fExitFromTraining =
false;
444 UInt_t fIPyMaxIter = 0, fIPyCurrentIter = 0;
450 if (fInteractive)
delete fInteractive;
459 fExitFromTraining =
true;
464 if (fExitFromTraining && fInteractive){
466 fInteractive =
nullptr;
468 return fExitFromTraining;
484 void SetWeightFileName(
TString );
487 void SetWeightFileDir(
TString fileDir );
541 void DeclareBaseOptions();
542 void ProcessBaseOptions();
551 void ResetThisBase();
556 void CreateMVAPdfs();
563 Bool_t GetLine( std::istream& fin,
char * buf );
742 return GetTransformationHandler().Transform(ev);
748 return GetTransformationHandler().Transform(fTmpEvent);
750 return GetTransformationHandler().Transform(
Data()->GetEvent());
755 assert(fTmpEvent==0);
756 return GetTransformationHandler().Transform(
Data()->GetEvent(ievt));
761 assert(fTmpEvent==0);
762 return GetTransformationHandler().Transform(
Data()->GetEvent(ievt, type));
767 assert(fTmpEvent==0);
773 assert(fTmpEvent==0);
Bool_t HasMVAPdfs() const
Types::EAnalysisType fAnalysisType
void SetModelPersistence(Bool_t status)
virtual void ReadWeightsFromStream(TFile &)
virtual const std::vector< Float_t > & GetMulticlassValues()
void AddPoint(Double_t x, Double_t y1, Double_t y2)
This function is used only in the two-TGraph case; it adds new data points to the graphs.
TString GetMethodName(Types::EMVA method) const
Bool_t fIgnoreNegWeightsInTraining
Bool_t IsConstructedFromWeightFile() const
virtual void MakeClassSpecificHeader(std::ostream &, const TString &="") const
const TString GetProbaName() const
std::vector< TGraph * > fGraphs
const TransformationHandler & GetTransformationHandler(Bool_t takeReroutedIfAvailable=true) const
static Types & Instance()
the single instance of "Types" if it exists already, or create it (Singleton)
const TString & GetOriginalVarName(Int_t ivar) const
TString fVariableTransformTypeString
void SetMethodBaseDir(TDirectory *methodDir)
Base class for spline implementation containing the Draw/Paint methods.
TransformationHandler * fTransformationPointer
Types::ESBType fVariableTransformType
A ROOT file is a suite of consecutive data records (TKey instances) with a well defined format...
A TMultiGraph is a collection of TGraph (or derived) objects.
void InitIPythonInteractive()
const std::vector< Float_t > & GetRegressionValues(const TMVA::Event *const ev)
void SetSignalReferenceCutOrientation(Double_t cutOrientation)
virtual const std::vector< Float_t > & GetRegressionValues()
1-D histogram with a float per channel (see TH1 documentation)
void SetTrainTime(Double_t trainTime)
TMultiGraph * fMultiGraph
const TString & GetInternalVarName(Int_t ivar) const
TransformationHandler & GetTransformationHandler(Bool_t takeReroutedIfAvailable=true)
UInt_t GetNTargets() const
std::vector< TString > * fInputVars
const char * GetInputTitle(Int_t i) const
void SetSilentFile(Bool_t status)
Double_t GetMean(Int_t ivar) const
Double_t GetTrainTime() const
const TString & GetInputLabel(Int_t i) const
void SetMethodDir(TDirectory *methodDir)
const TString & GetWeightFileDir() const
const TString & GetInputVar(Int_t i) const
DataSetInfo & fDataSetInfo
#define ClassDef(name, id)
ECutOrientation fCutOrientation
Bool_t TxtWeightsOnly() const
UInt_t GetTrainingTMVAVersionCode() const
const Event * GetEvent() const
void ClearGraphs()
This function sets the point number to 0 for all graphs.
~IPythonInteractive()
standard destructor
std::vector< std::vector< double > > Data
Double_t GetXmin(Int_t ivar) const
void Init(std::vector< TString > &graphTitles)
This function takes a list of titles and creates a TGraph for each title.
DataSetInfo & DataInfo() const
Bool_t DoRegression() const
const Event * GetTrainingEvent(Long64_t ievt) const
UInt_t fTMVATrainingVersion
UInt_t GetNEvents() const
temporary event used when testing on a DataSet other than the method's own one
Double_t GetXmax(Int_t ivar) const
TransformationHandler fTransformation
Bool_t DoMulticlass() const
virtual void MakeClassSpecific(std::ostream &, const TString &="") const
const Event * GetTestingEvent(Long64_t ievt) const
Bool_t HasTrainingTree() const
std::string GetMethodName(TCppMethod_t)
TDirectory * fMethodBaseDir
UInt_t fROOTTrainingVersion
const char * GetName() const
UInt_t GetTrainingROOTVersionCode() const
const TString & GetJobName() const
const TString & GetMethodName() const
TSpline * fSplTrainEffBvsS
1-D histogram with a double per channel (see TH1 documentation)
Double_t GetSignalReferenceCutOrientation() const
void SetNormalised(Bool_t norm)
Double_t GetTestTime() const
UInt_t GetNVariables() const
std::vector< const std::vector< TMVA::Event * > * > fEventCollections
TString fVerbosityLevelString
Bool_t IgnoreEventsWithNegWeightsInTraining() const
void RerouteTransformationHandler(TransformationHandler *fTargetTransformation)
void SetTestTime(Double_t testTime)
Describe directory structure in memory.
std::vector< Float_t > * fMulticlassReturnVal
Bool_t IsNormalised() const
void SetFile(TFile *file)
IPythonInteractive()
standard constructor
Bool_t fConstructedFromWeightFile
TString fVarTransformString
Double_t GetRMS(Int_t ivar) const
This class is needed by JsMVA, and it's a helper class for tracking errors during the training in Jup...
Abstract ClassifierFactory template that handles arbitrary types.
TString GetMethodTypeName() const
Double_t fSignalReferenceCut
the data set information (sometimes needed)
Double_t GetSignalReferenceCut() const
A Graph is a graphics object made of two arrays X and Y with npoints each.
void DisableWriting(Bool_t setter)
ECutOrientation GetCutOrientation() const
std::vector< Float_t > * fRegressionReturnVal
Types::EAnalysisType GetAnalysisType() const
A TTree object has a header with a name and a title.
const TString & GetTestvarName() const
void SetTestvarName(const TString &v="")
TMultiGraph * GetInteractiveTrainingError()
double norm(double *x, double *p)
Types::EMVA GetMethodType() const
void SetBaseDir(TDirectory *methodDir)
virtual void SetAnalysisType(Types::EAnalysisType type)
void SetSignalReferenceCut(Double_t cut)
Double_t fSignalReferenceCutOrientation
Bool_t IsModelPersistence()