44 fSigs.resize(numFolds);
45 fSeps.resize(numFolds);
80 fSigs[iFold] = fr.
fSig;
81 fSeps[iFold] = fr.
fSep;
82 fEff01s[iFold] = fr.
fEff01;
83 fEff10s[iFold] = fr.
fEff10;
84 fEff30s[iFold] = fr.
fEff30;
94 return fROCCurves.get();
110 Double_t increment = 1.0 / (numSamples-1);
111 std::vector<Double_t>
x(numSamples),
y(numSamples);
113 TList *rocCurveList = fROCCurves.get()->GetListOfGraphs();
115 for(
UInt_t iSample = 0; iSample < numSamples; iSample++) {
116 Double_t xPoint = iSample * increment;
119 for(
Int_t iGraph = 0; iGraph < rocCurveList->
GetSize(); iGraph++) {
120 TGraph *foldROC =
static_cast<TGraph *
>(rocCurveList->
At(iGraph));
121 rocSum += foldROC->
Eval(xPoint);
125 y[iSample] = rocSum/rocCurveList->
GetSize();
128 return new TGraph(numSamples, &
x[0], &
y[0]);
135 for(
auto &roc : fROCs) {
138 return avg/fROCs.size();
147 for(
auto &roc : fROCs) {
160 fLogger << kHEADER <<
" ==== Results ====" <<
Endl;
161 for(
auto &item:fROCs) {
162 fLogger << kINFO <<
Form(
"Fold %i ROC-Int : %.4f",item.first,item.second) << std::endl;
165 fLogger << kINFO <<
"------------------------" <<
Endl;
166 fLogger << kINFO <<
Form(
"Average ROC-Int : %.4f",GetROCAverage()) <<
Endl;
167 fLogger << kINFO <<
Form(
"Std-Dev ROC-Int : %.4f",GetROCStandardDeviation()) <<
Endl;
176 fROCCurves->Draw(
"AL");
177 fROCCurves->GetXaxis()->SetTitle(
" Signal Efficiency ");
178 fROCCurves->GetYaxis()->SetTitle(
" Background Rejection ");
179 Float_t adjust=1+fROCs.size()*0.01;
180 c->BuildLegend(0.15,0.15,0.4*adjust,0.5*adjust);
181 c->SetTitle(
"Cross Validation ROC Curves");
195 for (
auto foldRocObj : *(*fROCCurves).GetListOfGraphs()) {
196 TGraph * foldRocGraph =
dynamic_cast<TGraph *
>(foldRocObj->Clone());
199 rocs->
Add(foldRocGraph);
204 TGraph *avgRocGraph = GetAvgROCCurve(100);
205 avgRocGraph->
SetTitle(
"Avg ROC Curve");
208 rocs->
Add(avgRocGraph);
214 title =
"Cross Validation Average ROC Curve";
229 leg->AddEntry(
static_cast<TGraph *
>(ROCCurveList->
At(nCurves-1)),
230 "Avg ROC Curve",
"l");
231 leg->AddEntry(
static_cast<TGraph *
>(ROCCurveList->
At(0)),
232 "Fold ROC Curves",
"l");
239 c->SetTitle(
"Cross Validation Average ROC Curve");
281 :
TMVA::
Envelope(jobName, dataloader, nullptr, options),
282 fAnalysisType(
Types::kMaxAnalysisType),
283 fAnalysisTypeStr(
"Auto"),
284 fSplitTypeStr(
"Random"),
286 fCvFactoryOptions(
""),
293 fOutputFactoryOptions(
""),
294 fOutputFile(outputFile),
296 fSplitExprString(
""),
298 fTransformations(
""),
326 DeclareOptionRef(fSilent,
"Silent",
327 "Batch mode: boolean silent flag inhibiting any output from TMVA after the creation of the factory "
328 "class object (default: False)");
329 DeclareOptionRef(fVerbose,
"V",
"Verbose flag");
330 DeclareOptionRef(fVerboseLevel =
TString(
"Info"),
"VerboseLevel",
"VerboseLevel (Debug/Verbose/Info)");
331 AddPreDefVal(
TString(
"Debug"));
332 AddPreDefVal(
TString(
"Verbose"));
335 DeclareOptionRef(fTransformations,
"Transformations",
336 "List of transformations to test; formatting example: \"Transformations=I;D;P;U;G,D\", for "
337 "identity, decorrelation, PCA, Uniform and Gaussianisation followed by decorrelation "
340 DeclareOptionRef(fDrawProgressBar,
"DrawProgressBar",
"Boolean to show draw progress bar");
341 DeclareOptionRef(fCorrelations,
"Correlations",
"Boolean to show correlation in output");
342 DeclareOptionRef(fROC,
"ROC",
"Boolean to show ROC in output");
345 DeclareOptionRef(fAnalysisTypeStr,
"AnalysisType",
346 "Set the analysis type (Classification, Regression, Multiclass, Auto) (default: Auto)");
347 AddPreDefVal(
TString(
"Classification"));
348 AddPreDefVal(
TString(
"Regression"));
349 AddPreDefVal(
TString(
"Multiclass"));
353 DeclareOptionRef(fSplitTypeStr,
"SplitType",
354 "Set the split type (Deterministic, Random, RandomStratified) (default: Random)");
355 AddPreDefVal(
TString(
"Deterministic"));
356 AddPreDefVal(
TString(
"Random"));
357 AddPreDefVal(
TString(
"RandomStratified"));
359 DeclareOptionRef(fSplitExprString,
"SplitExpr",
"The expression used to assign events to folds");
360 DeclareOptionRef(fNumFolds,
"NumFolds",
"Number of folds to generate");
361 DeclareOptionRef(fNumWorkerProcs,
"NumWorkerProcs",
362 "Determines how many processes to use for evaluation. 1 means no"
363 " parallelisation. 2 means use 2 processes. 0 means figure out the"
364 " number automatically based on the number of cpus available. Default"
367 DeclareOptionRef(fFoldFileOutput,
"FoldFileOutput",
368 "If given a TMVA output file will be generated for each fold. Filename will be the same as "
369 "specifed for the combined output with a _foldX suffix. (default: false)");
371 DeclareOptionRef(fOutputEnsembling =
TString(
"None"),
"OutputEnsembling",
372 "Combines output from contained methods. If None, no combination is performed. (default None)");
384 if (fSplitTypeStr !=
"Deterministic" && fSplitExprString !=
"") {
385 Log() << kFATAL <<
"SplitExpr can only be used with Deterministic Splitting" <<
Endl;
389 fAnalysisTypeStr.ToLower();
390 if (fAnalysisTypeStr ==
"classification") {
392 }
else if (fAnalysisTypeStr ==
"regression") {
394 }
else if (fAnalysisTypeStr ==
"multiclass") {
396 }
else if (fAnalysisTypeStr ==
"auto") {
401 fCvFactoryOptions +=
"V:";
402 fOutputFactoryOptions +=
"V:";
404 fCvFactoryOptions +=
"!V:";
405 fOutputFactoryOptions +=
"!V:";
408 fCvFactoryOptions +=
Form(
"VerboseLevel=%s:", fVerboseLevel.Data());
409 fOutputFactoryOptions +=
Form(
"VerboseLevel=%s:", fVerboseLevel.Data());
411 fCvFactoryOptions +=
Form(
"AnalysisType=%s:", fAnalysisTypeStr.Data());
412 fOutputFactoryOptions +=
Form(
"AnalysisType=%s:", fAnalysisTypeStr.Data());
414 if (!fDrawProgressBar) {
415 fCvFactoryOptions +=
"!DrawProgressBar:";
416 fOutputFactoryOptions +=
"!DrawProgressBar:";
419 if (fTransformations !=
"") {
420 fCvFactoryOptions +=
Form(
"Transformations=%s:", fTransformations.Data());
421 fOutputFactoryOptions +=
Form(
"Transformations=%s:", fTransformations.Data());
425 fCvFactoryOptions +=
"Correlations:";
426 fOutputFactoryOptions +=
"Correlations:";
428 fCvFactoryOptions +=
"!Correlations:";
429 fOutputFactoryOptions +=
"!Correlations:";
433 fCvFactoryOptions +=
"ROC:";
434 fOutputFactoryOptions +=
"ROC:";
436 fCvFactoryOptions +=
"!ROC:";
437 fOutputFactoryOptions +=
"!ROC:";
441 fCvFactoryOptions +=
Form(
"Silent:");
442 fOutputFactoryOptions +=
Form(
"Silent:");
446 if (fFoldFileOutput && fOutputFile ==
nullptr) {
447 Log() << kFATAL <<
"No output file given, cannot generate per fold output." <<
Endl;
452 fFoldFactory = std::make_unique<TMVA::Factory>(fJobName, fCvFactoryOptions);
457 if (fOutputFile ==
nullptr) {
458 fFactory = std::make_unique<TMVA::Factory>(fJobName, fOutputFactoryOptions);
460 fFactory = std::make_unique<TMVA::Factory>(fJobName, fOutputFile, fOutputFactoryOptions);
463 if(fSplitTypeStr ==
"Random"){
464 fSplit = std::unique_ptr<CvSplitKFolds>(
new CvSplitKFolds(fNumFolds, fSplitExprString,
kFALSE));
465 }
else if(fSplitTypeStr ==
"RandomStratified"){
466 fSplit = std::unique_ptr<CvSplitKFolds>(
new CvSplitKFolds(fNumFolds, fSplitExprString,
kTRUE));
468 fSplit = std::unique_ptr<CvSplitKFolds>(
new CvSplitKFolds(fNumFolds, fSplitExprString));
476 if (i != fNumFolds) {
478 fSplit = std::make_unique<CvSplitKFolds>(fNumFolds, fSplitExprString);
479 fDataLoader->MakeKFoldDataSet(*fSplit);
489 if (splitExpr != fSplitExprString) {
490 fSplitExprString = splitExpr;
491 fSplit = std::make_unique<CvSplitKFolds>(fNumFolds, fSplitExprString);
492 fDataLoader->MakeKFoldDataSet(*fSplit);
514 Log() << kDEBUG <<
"Processing " << methodTitle <<
" fold " << iFold <<
Endl;
517 TFile *foldOutputFile =
nullptr;
519 if (fFoldFileOutput && fOutputFile !=
nullptr) {
522 Log() << kINFO <<
"Creating fold output at:" << path <<
Endl;
523 fFoldFactory = std::make_unique<TMVA::Factory>(fJobName, foldOutputFile, fCvFactoryOptions);
527 MethodBase *smethod = fFoldFactory->BookMethod(fDataLoader.get(), methodTypeName, foldTitle, methodOptions);
534 fFoldFactory->TestAllMethods();
535 fFoldFactory->EvaluateAllMethods();
541 result.
fROCIntegral = fFoldFactory->GetROCIntegral(fDataLoader->GetName(), foldTitle);
543 TGraph *
gr = fFoldFactory->GetROCCurve(fDataLoader->GetName(), foldTitle,
true);
567 if (fFoldFileOutput && foldOutputFile !=
nullptr) {
568 foldOutputFile->
Close();
577 fFoldFactory->DeleteAllMethods();
578 fFoldFactory->fMethodsMap.clear();
592 fDataLoader->MakeKFoldDataSet(*fSplit);
596 fResults.reserve(fMethods.size());
597 for (
auto & methodInfo : fMethods) {
600 TString methodTypeName = methodInfo.GetValue<
TString>(
"MethodName");
601 TString methodTitle = methodInfo.GetValue<
TString>(
"MethodTitle");
603 if (methodTypeName ==
"") {
604 Log() << kFATAL <<
"No method booked for cross-validation" <<
Endl;
608 Log() << kINFO <<
Endl;
609 Log() << kINFO <<
Endl;
610 Log() << kINFO <<
"========================================" <<
Endl;
611 Log() << kINFO <<
"Processing folds for method " << methodTitle <<
Endl;
612 Log() << kINFO <<
"========================================" <<
Endl;
613 Log() << kINFO <<
Endl;
616 auto nWorkers = fNumWorkerProcs;
622 for (
UInt_t iFold = 0; iFold < fNumFolds; ++iFold) {
623 auto fold_result = ProcessFold(iFold, methodInfo);
624 result.Fill(fold_result);
629 std::vector<CrossValidationFoldResult> result_vector;
631 auto workItem = [
this, methodInfo](
UInt_t iFold) {
632 return ProcessFold(iFold, methodInfo);
637 for (
auto && fold_result : result_vector) {
638 result.Fill(fold_result);
643 fResults.push_back(result);
647 Form(
"SplitExpr=%s:NumFolds=%i"
648 ":EncapsulatedMethodName=%s"
649 ":EncapsulatedMethodTypeName=%s"
650 ":OutputEnsembling=%s",
651 fSplitExprString.Data(), fNumFolds, methodTitle.Data(), methodTypeName.
Data(), fOutputEnsembling.Data());
657 IMethod *method_interface = fFactory->GetMethod(fDataLoader->GetName(), methodTitle);
663 Log() << kINFO <<
Endl;
664 Log() << kINFO <<
Endl;
665 Log() << kINFO <<
"========================================" <<
Endl;
666 Log() << kINFO <<
"Folds processed for all methods, evaluating." <<
Endl;
667 Log() << kINFO <<
"========================================" <<
Endl;
668 Log() << kINFO <<
Endl;
671 fDataLoader->RecombineKFoldDataSet(*fSplit);
674 for (
auto & methodInfo : fMethods) {
675 TString methodTypeName = methodInfo.GetValue<
TString>(
"MethodName");
676 TString methodTitle = methodInfo.GetValue<
TString>(
"MethodTitle");
678 IMethod *method_interface = fFactory->GetMethod(fDataLoader->GetName(), methodTitle);
681 if (fOutputFile !=
nullptr) {
682 fFactory->WriteDataInformation(method->fDataSetInfo);
686 method->TrainMethod();
691 fFactory->TestAllMethods();
694 fFactory->EvaluateAllMethods();
696 Log() << kINFO <<
"Evaluation done." <<
Endl;
702 if (fResults.empty()) {
703 Log() << kFATAL <<
"No cross-validation results available" <<
Endl;
char * Form(const char *fmt,...)
R__EXTERN TSystem * gSystem
auto Map(F func, unsigned nTimes) -> std::vector< typename std::result_of< F()>::type >
Execute a function without arguments several times.
This class provides a simple interface to execute the same task multiple times in parallel,...
A pseudo container class which is a generator of indices.
virtual void SetLineWidth(Width_t lwidth)
Set the line width.
virtual void SetLineColor(Color_t lcolor)
Set the line color.
virtual Int_t GetSize() const
Return the capacity of the collection, i.e.
A ROOT file is a suite of consecutive data records (TKey instances) with a well defined format.
static TFile * Open(const char *name, Option_t *option="", const char *ftitle="", Int_t compress=ROOT::RCompressionSetting::EDefaults::kUseCompiledDefault, Int_t netopt=0)
Create / open a file.
void Close(Option_t *option="") override
Close a file.
A TGraph is an object made of two arrays X and Y with npoints each.
virtual void SetTitle(const char *title="")
Change (i.e.
virtual Double_t Eval(Double_t x, TSpline *spline=nullptr, Option_t *option="") const
Interpolate points in this graph at x using a TSpline.
This class displays a legend box (TPaveText) containing several legend entries.
virtual TObject * At(Int_t idx) const
Returns the object at position idx. Returns 0 if idx is out of range.
UInt_t GetNumWorkers() const
void CheckForUnusedOptions() const
checks for unused options in option string
Class to save the results of cross validation; the metric for classification is the ROC, and you can ...
std::vector< Double_t > fSeps
std::vector< Double_t > fEff01s
CrossValidationResult(UInt_t numFolds)
std::vector< Double_t > fTrainEff30s
std::shared_ptr< TMultiGraph > fROCCurves
std::vector< Double_t > fSigs
std::vector< Double_t > fEff30s
void Fill(CrossValidationFoldResult const &fr)
Float_t GetROCStandardDeviation() const
std::vector< Double_t > fEff10s
std::vector< Double_t > fTrainEff01s
std::map< UInt_t, Float_t > fROCs
std::vector< Double_t > fTrainEff10s
Float_t GetROCAverage() const
std::vector< Double_t > fEffAreas
TCanvas * DrawAvgROCCurve(Bool_t drawFolds=kFALSE, TString title="") const
TMultiGraph * GetROCCurves(Bool_t fLegend=kTRUE)
TGraph * GetAvgROCCurve(UInt_t numSamples=100) const
Generates a multigraph that contains an average ROC Curve.
TCanvas * Draw(const TString name="CrossValidation") const
Class to perform cross validation, splitting the dataloader into folds.
void SetNumFolds(UInt_t i)
void ParseOptions()
Method to parse the internal option string.
const std::vector< CrossValidationResult > & GetResults() const
CrossValidation(TString jobName, TMVA::DataLoader *dataloader, TString options)
void SetSplitExpr(TString splitExpr)
void Evaluate()
Does training, test-set evaluation and performance evaluation using cross-evaluation.
CrossValidationFoldResult ProcessFold(UInt_t iFold, const OptionMap &methodInfo)
Evaluates each fold in turn.
void DeleteAllResults(Types::ETreeType type, Types::EAnalysisType analysistype)
Deletes all results currently in the dataset.
Abstract base class for all high-level ML algorithms; you can book ML methods like BDT,...
virtual void ParseOptions()
Method to parse the internal option string.
static void SetIsTraining(Bool_t)
when this static function is called, it sets the flag whether events with negative event weight shoul...
Interface for all concrete MVA method implementations.
Virtual base class for all MVA methods.
virtual Double_t GetSeparation(TH1 *, TH1 *) const
compute "separation" defined as
virtual Double_t GetSignificance() const
compute significance of mean difference
Types::EAnalysisType GetAnalysisType() const
virtual Double_t GetEfficiency(const TString &, Types::ETreeType, Double_t &err)
fill background efficiency (resp.
virtual Double_t GetTrainingEfficiency(const TString &)
std::map< const TMVA::Event *, UInt_t > fEventToFoldMapping
ostringstream derivative to redirect and format output
static void EnableOutput()
Class to store options for the different methods.
T GetValue(const TString &key)
Singleton class for Global types used by TMVA.
A TMultiGraph is a collection of TGraph (or derived) objects.
TList * GetListOfGraphs() const
virtual void Add(TGraph *graph, Option_t *chopt="")
Add a new graph to the list of graphs.
TAxis * GetYaxis()
Get y axis of the graph.
TAxis * GetXaxis()
Get x axis of the graph.
virtual void SetTitle(const char *title="")
Set the title of the TNamed.
virtual void SetName(const char *name)
Set the name of the TNamed.
virtual TObject * Clone(const char *newname="") const
Make a clone of an object using the Streamer facility.
virtual TObject * DrawClone(Option_t *option="") const
Draw a clone of this object in the current selected pad for instance with: gROOT->SetSelectedPad(gPad...
const char * Data() const
static TString Format(const char *fmt,...)
Static method which formats a string using a printf style format descriptor and return a TString.
virtual TString GetDirName(const char *pathname)
Return the directory name in pathname.
create variable transformations
MsgLogger & Endl(MsgLogger &ml)
Double_t Sqrt(Double_t x)
LongDouble_t Power(LongDouble_t x, LongDouble_t y)