33 #ifndef ROOT_TMVA_MethodBase 34 #define ROOT_TMVA_MethodBase 88 namespace Experimental {
96 void Init(std::vector<TString>& graphTitles);
99 void AddPoint(std::vector<Double_t>& dat);
127 const TString& theOption =
"" );
141 virtual void CheckSetup();
153 virtual std::map<TString,Double_t> OptimizeTuningParameters(
TString fomType=
"ROCIntegral",
TString fitType=
"FitGA");
154 virtual void SetTuneParameters(std::map<TString,Double_t> tuneParameters);
156 virtual void Train() = 0;
167 virtual void TestClassification();
171 virtual void TestMulticlass();
182 virtual void Init() = 0;
183 virtual void DeclareOptions() = 0;
184 virtual void ProcessOptions() = 0;
185 virtual void DeclareCompatibilityOptions();
206 virtual std::vector<Double_t> GetMvaValues(
Long64_t firstEvt = 0,
Long64_t lastEvt = -1,
Bool_t logProgress =
false);
213 const std::vector<Float_t>* ptr = &GetRegressionValues();
219 std::vector<Float_t>* ptr =
new std::vector<Float_t>(0);
225 std::vector<Float_t>* ptr =
new std::vector<Float_t>(0);
237 virtual const Ranking* CreateRanking() = 0;
240 virtual void MakeClass(
const TString& classFileName =
TString(
"") )
const;
243 void PrintHelpMessage()
const;
249 void WriteStateToFile ()
const;
250 void ReadStateFromFile ();
254 virtual void AddWeightsXMLTo (
void* parent )
const = 0;
255 virtual void ReadWeightsFromXML (
void* wghtnode ) = 0;
256 virtual void ReadWeightsFromStream( std::istream& ) = 0;
262 void WriteStateToXML (
void* parent )
const;
263 void ReadStateFromXML (
void* parent );
264 void WriteStateToStream ( std::ostream& tf )
const;
265 void WriteVarsToStream ( std::ostream& tf,
const TString& prefix =
"" )
const;
269 void ReadStateFromStream ( std::istream& tf );
270 void ReadStateFromStream (
TFile& rf );
271 void ReadStateFromXMLString(
const char* xmlstr );
275 void AddVarsXMLTo (
void* parent )
const;
276 void AddSpectatorsXMLTo (
void* parent )
const;
277 void AddTargetsXMLTo (
void* parent )
const;
278 void AddClassesXMLTo (
void* parent )
const;
279 void ReadVariablesFromXML (
void* varnode );
280 void ReadSpectatorsFromXML(
void* specnode);
281 void ReadTargetsFromXML (
void* tarnode );
282 void ReadClassesFromXML (
void* clsnode );
283 void ReadVarsFromStream ( std::istream& istr );
292 virtual void WriteMonitoringHistosToFile()
const;
306 virtual std::vector<Float_t> GetMulticlassEfficiency( std::vector<std::vector<Float_t> >& purity );
307 virtual std::vector<Float_t> GetMulticlassTrainingEfficiency(std::vector<std::vector<Float_t> >& purity );
309 virtual Double_t GetSignificance()
const;
313 Double_t& optimal_significance_value )
const;
315 virtual Double_t GetSeparation(
PDF* pdfS = 0,
PDF* pdfB = 0 )
const;
325 const char*
GetName ()
const {
return fMethodName.Data(); }
328 TString GetWeightFileName()
const;
382 TString GetTrainingTMVAVersionString()
const;
383 TString GetTrainingROOTVersionString()
const;
387 if(fTransformationPointer && takeReroutedIfAvailable)
return *fTransformationPointer;
else return fTransformation;
391 if(fTransformationPointer && takeReroutedIfAvailable)
return *fTransformationPointer;
else return fTransformation;
409 const Event* GetEvent ()
const;
415 const std::vector<TMVA::Event*>& GetEventCollection(
Types::ETreeType type );
422 virtual Bool_t IsSignalLike();
438 bool fExitFromTraining =
false;
439 UInt_t fIPyMaxIter = 0, fIPyCurrentIter = 0;
445 if (fInteractive)
delete fInteractive;
454 fExitFromTraining =
true;
459 if (fExitFromTraining && fInteractive){
461 fInteractive =
nullptr;
463 return fExitFromTraining;
479 void SetWeightFileName(
TString );
482 void SetWeightFileDir(
TString fileDir );
536 void DeclareBaseOptions();
537 void ProcessBaseOptions();
546 void ResetThisBase();
551 void CreateMVAPdfs();
558 Bool_t GetLine( std::istream& fin,
char * buf );
737 return GetTransformationHandler().Transform(ev);
743 return GetTransformationHandler().Transform(fTmpEvent);
745 return GetTransformationHandler().Transform(Data()->GetEvent());
750 assert(fTmpEvent==0);
751 return GetTransformationHandler().Transform(Data()->GetEvent(ievt));
756 assert(fTmpEvent==0);
757 return GetTransformationHandler().Transform(Data()->GetEvent(ievt, type));
762 assert(fTmpEvent==0);
768 assert(fTmpEvent==0);
Bool_t HasMVAPdfs() const
Types::EAnalysisType fAnalysisType
void SetModelPersistence(Bool_t status)
virtual void ReadWeightsFromStream(TFile &)
virtual const std::vector< Float_t > & GetMulticlassValues()
TString GetMethodName(Types::EMVA method) const
Bool_t fIgnoreNegWeightsInTraining
Bool_t IsConstructedFromWeightFile() const
virtual void MakeClassSpecificHeader(std::ostream &, const TString &="") const
const TString GetProbaName() const
std::vector< TGraph * > fGraphs
const TransformationHandler & GetTransformationHandler(Bool_t takeReroutedIfAvailable=true) const
static Types & Instance()
Return the single instance of "Types" if it exists already, or create it (Singleton)
const TString & GetOriginalVarName(Int_t ivar) const
TString fVariableTransformTypeString
void SetMethodBaseDir(TDirectory *methodDir)
Base class for spline implementation containing the Draw/Paint methods.
TransformationHandler * fTransformationPointer
Types::ESBType fVariableTransformType
A ROOT file is a suite of consecutive data records (TKey instances) with a well defined format...
A TMultiGraph is a collection of TGraph (or derived) objects.
void InitIPythonInteractive()
Virtual base class for all MVA methods.
const std::vector< Float_t > & GetRegressionValues(const TMVA::Event *const ev)
void SetSignalReferenceCutOrientation(Double_t cutOrientation)
virtual const std::vector< Float_t > & GetRegressionValues()
1-D histogram with a float per channel (see TH1 documentation)
void SetTrainTime(Double_t trainTime)
const TString & GetInternalVarName(Int_t ivar) const
TransformationHandler & GetTransformationHandler(Bool_t takeReroutedIfAvailable=true)
Ranking for variables in method (implementation)
UInt_t GetNTargets() const
std::vector< TString > * fInputVars
const char * GetInputTitle(Int_t i) const
void SetSilentFile(Bool_t status)
Double_t GetMean(Int_t ivar) const
Double_t GetTrainTime() const
const TString & GetInputLabel(Int_t i) const
void SetMethodDir(TDirectory *methodDir)
const TString & GetWeightFileDir() const
const TString & GetInputVar(Int_t i) const
DataSetInfo & fDataSetInfo
#define ClassDef(name, id)
ECutOrientation fCutOrientation
Bool_t TxtWeightsOnly() const
UInt_t GetTrainingTMVAVersionCode() const
const Event * GetEvent() const
void Init(TClassEdit::TInterpreterLookupHelper *helper)
Virtual base class for combining several TMVA methods.
Double_t GetXmin(Int_t ivar) const
DataSetInfo & DataInfo() const
Bool_t DoRegression() const
Class that contains all the data information.
PDF wrapper for histograms; uses user-defined spline interpolation.
const Event * GetTrainingEvent(Long64_t ievt) const
UInt_t fTMVATrainingVersion
UInt_t GetNEvents() const
temporary event when testing on a different DataSet than its own
Class for boosting a TMVA method.
Double_t GetXmax(Int_t ivar) const
TransformationHandler fTransformation
Bool_t DoMulticlass() const
Class that contains all the data information.
virtual void MakeClassSpecific(std::ostream &, const TString &="") const
const Event * GetTestingEvent(Long64_t ievt) const
Bool_t HasTrainingTree() const
std::string GetMethodName(TCppMethod_t)
TDirectory * fMethodBaseDir
UInt_t fROOTTrainingVersion
const char * GetName() const
UInt_t GetTrainingROOTVersionCode() const
const TString & GetJobName() const
const TString & GetMethodName() const
TSpline * fSplTrainEffBvsS
This is the main MVA steering class.
1-D histogram with a double per channel (see TH1 documentation)
Linear interpolation of TGraph.
Double_t GetSignalReferenceCutOrientation() const
void SetNormalised(Bool_t norm)
Double_t GetTestTime() const
UInt_t GetNVariables() const
std::vector< const std::vector< TMVA::Event * > * > fEventCollections
TString fVerbosityLevelString
Class for categorizing the phase space.
Bool_t IgnoreEventsWithNegWeightsInTraining() const
void RerouteTransformationHandler(TransformationHandler *fTargetTransformation)
void SetTestTime(Double_t testTime)
Multivariate optimisation of signal efficiency for given background efficiency, applying rectangular ...
Describe directory structure in memory.
Class to perform cross validation, splitting the dataloader into folds.
std::vector< Float_t > * fMulticlassReturnVal
Bool_t IsNormalised() const
void SetFile(TFile *file)
Bool_t fConstructedFromWeightFile
TString fVarTransformString
Interface for all concrete MVA method implementations.
Double_t GetRMS(Int_t ivar) const
Root finding using Brent's algorithm (translated from CERNLIB function RZERO)
This class is needed by JsMVA, and it's a helper class for tracking errors during the training in Jup...
Abstract ClassifierFactory template that handles arbitrary types.
TString GetMethodTypeName() const
Class that is the base class for a vector of results.
Double_t fSignalReferenceCut
the data set information (sometimes needed)
Double_t GetSignalReferenceCut() const
A Graph is a graphics object made of two arrays X and Y with npoints each.
void DisableWriting(Bool_t setter)
ECutOrientation GetCutOrientation() const
std::vector< Float_t > * fRegressionReturnVal
Types::EAnalysisType GetAnalysisType() const
A TTree object has a header with a name and a title.
const TString & GetTestvarName() const
void SetTestvarName(const TString &v="")
TMultiGraph * GetInteractiveTrainingError()
Types::EMVA GetMethodType() const
void SetBaseDir(TDirectory *methodDir)
virtual void SetAnalysisType(Types::EAnalysisType type)
void SetSignalReferenceCut(Double_t cut)
Double_t fSignalReferenceCutOrientation
Bool_t IsModelPersistence()