12#ifndef ROOT_TMVA_CROSS_EVALUATION
13#define ROOT_TMVA_CROSS_EVALUATION
46using EventCollection_t = std::vector<Event *>;
47using EventTypes_t = std::vector<Bool_t>;
48using EventOutputs_t = std::vector<Float_t>;
49using EventOutputsMulticlass_t = std::vector<std::vector<Float_t>>;
80 std::map<UInt_t, Float_t>
fROCs;
140 const std::vector<CrossValidationResult> &
GetResults()
const;
#define ClassDef(name, id)
A ROOT file is a suite of consecutive data records (TKey instances) with a well-defined format.
A Graph is a graphics object made of two arrays X and Y with npoints each.
CrossValidationFoldResult(UInt_t iFold)
CrossValidationFoldResult()
Class to save the results of cross validation; the metric for classification is ROC, and you can ...
std::vector< Double_t > fSeps
std::vector< Double_t > fEff01s
CrossValidationResult(UInt_t numFolds)
std::vector< Double_t > GetTrainEff10Values() const
std::vector< Double_t > fTrainEff30s
std::shared_ptr< TMultiGraph > fROCCurves
std::vector< Double_t > GetTrainEff30Values() const
std::vector< Double_t > fSigs
std::vector< Double_t > fEff30s
void Fill(CrossValidationFoldResult const &fr)
Float_t GetROCStandardDeviation() const
std::vector< Double_t > fEff10s
std::vector< Double_t > fTrainEff01s
std::vector< Double_t > GetEff10Values() const
std::map< UInt_t, Float_t > fROCs
std::vector< Double_t > fTrainEff10s
std::vector< Double_t > GetTrainEff01Values() const
Float_t GetROCAverage() const
std::vector< Double_t > fEffAreas
std::vector< Double_t > GetEff01Values() const
TCanvas * DrawAvgROCCurve(Bool_t drawFolds=kFALSE, TString title="") const
std::vector< Double_t > GetSigValues() const
TMultiGraph * GetROCCurves(Bool_t fLegend=kTRUE)
TGraph * GetAvgROCCurve(UInt_t numSamples=100) const
Generates a graph that contains an average ROC curve.
std::map< UInt_t, Float_t > GetROCValues() const
std::vector< Double_t > GetEffAreaValues() const
TCanvas * Draw(const TString name="CrossValidation") const
std::vector< Double_t > GetSepValues() const
std::vector< Double_t > GetEff30Values() const
Class to perform cross validation, splitting the dataloader into folds.
void SetNumFolds(UInt_t i)
void ParseOptions()
Method to parse the internal option string.
const std::vector< CrossValidationResult > & GetResults() const
std::vector< CrossValidationResult > fResults
std::unique_ptr< Factory > fFoldFactory
CrossValidation(TString jobName, TMVA::DataLoader *dataloader, TString options)
Bool_t fFoldStatus
If true: generate output file for each fold.
std::unique_ptr< CvSplitKFolds > fSplit
TFile * fOutputFile
How to combine output of individual folds.
Types::EAnalysisType fAnalysisType
TString fCvFactoryOptions
void SetSplitExpr(TString splitExpr)
void Evaluate()
Performs training, test-set evaluation and performance evaluation using cross-validation.
TString fOutputFactoryOptions
Number of processes to use for fold evaluation.
std::unique_ptr< Factory > fFactory
TString fOutputEnsembling
UInt_t fNumWorkerProcs
Number of folds to prepare.
TString fJobName
If true: dataset is prepared.
CrossValidationFoldResult ProcessFold(UInt_t iFold, const OptionMap &methodInfo)
Evaluates each fold in turn.
Abstract base class for all high-level ML algorithms; you can book ML methods like BDT, ...
This is the main MVA steering class.
Class to store options for the different methods.
A TMultiGraph is a collection of TGraph (or derived) objects.
Abstract ClassifierFactory template that handles arbitrary types.