#include #include #include #include #include "TChain.h" #include "TFile.h" #include "TTree.h" #include "TString.h" #include "TObjString.h" #include "TSystem.h" #include "TROOT.h" #include "TMVA/Factory.h" #include "TMVA/DataLoader.h" #include "TMVA/Tools.h" #include "TMVA/TMVAGui.h" #include "TMVA/CrossValidation.h" TFile * getDataFile(TString fname) { TFile *input(nullptr); if (!gSystem->AccessPathName(fname)) { input = TFile::Open(fname); // check if file in local directory exists } if (!input) { std::cout << "ERROR: could not open data file " << fname << std::endl; exit(1); } return input; } int TMVACrossValidationRegression() { // This loads the library TMVA::Tools::Instance(); // -------------------------------------------------------------------------- // Create a ROOT output file where TMVA will store ntuples, histograms, etc. TString outfileName("TMVARegCv.root"); TFile * outputFile = TFile::Open(outfileName, "RECREATE"); TString infileName = gROOT->GetTutorialDir() + "/machine_learning/data/tmva_reg_example.root"; TFile * inputFile = getDataFile(infileName); TMVA::DataLoader *dataloader=new TMVA::DataLoader("datasetcvreg"); dataloader->AddVariable("var1", "Variable 1", "units", 'F'); dataloader->AddVariable("var2", "Variable 2", "units", 'F'); // Add the variable carrying the regression target dataloader->AddTarget("fvalue"); TTree * regTree = (TTree*)inputFile->Get("TreeR"); dataloader->AddRegressionTree(regTree, 1.0); // Individual events can be weighted // dataloader->SetWeightExpression("weight", "Regression"); std::cout << "--- TMVACrossValidationRegression: Using input file: " << inputFile->GetName() << std::endl; // Bypasses the normal splitting mechanism, CV uses a new system for this. // Unfortunately the old system is unhappy if we leave the test set empty so // we ensure that there is at least one event by placing the first event in // it. // You can with the selection cut place a global cut on the defined // variables. Only events passing the cut will be using in training/testing. // Example: `TCut selectionCut = "var1 < 1";` TCut selectionCut = ""; dataloader->PrepareTrainingAndTestTree(selectionCut, "nTest_Regression=1" ":SplitMode=Block" ":NormMode=NumEvents" ":!V"); // -------------------------------------------------------------------------- // // This sets up a CrossValidation class (which wraps a TMVA::Factory // internally) for 2-fold cross validation. The data will be split into the // two folds randomly if `splitExpr` is `""`. // // One can also give a deterministic split using spectator variables. An // example would be e.g. `"int(fabs([spec1]))%int([NumFolds])"`. // UInt_t numFolds = 2; TString analysisType = "Regression"; TString splitExpr = ""; TString cvOptions = Form("!V" ":!Silent" ":ModelPersistence" ":!FoldFileOutput" ":AnalysisType=%s" ":NumFolds=%i" ":SplitExpr=%s", analysisType.Data(), numFolds, splitExpr.Data()); TMVA::CrossValidation cv{"TMVACrossValidationRegression", dataloader, outputFile, cvOptions}; // -------------------------------------------------------------------------- // // Books a method to use for evaluation // cv.BookMethod(TMVA::Types::kBDT, "BDTG", "!H:!V:NTrees=500:BoostType=Grad:Shrinkage=0.1:" "UseBaggedBoost:BaggedSampleFraction=0.5:nCuts=20:MaxDepth=3"); // -------------------------------------------------------------------------- // // Train, test and evaluate the booked methods. // Evaluates the booked methods once for each fold and aggregates the result // in the specified output file. // cv.Evaluate(); // -------------------------------------------------------------------------- // // Save the output // outputFile->Close(); std::cout << "==> Wrote root file: " << outputFile->GetName() << std::endl; std::cout << "==> TMVACrossValidationRegression is done!" << std::endl; // -------------------------------------------------------------------------- // // Launch the GUI for the root macros // if (!gROOT->IsBatch()) { TMVA::TMVAGui(outfileName); } return 0; } // // This is used if the macro is compiled. If run through ROOT with // `root -b -q MACRO.C` or similar it is unused. // int main(int argc, char **argv) { TMVACrossValidationRegression(); }