Logo ROOT  
Reference Guide
CvSplit.h
Go to the documentation of this file.
1 // @(#)root/tmva $Id$
2 // Author: Kim Albertsson
3 
4 /*************************************************************************
5  * Copyright (C) 2018, Rene Brun and Fons Rademakers. *
6  * All rights reserved. *
7  * *
8  * For the licensing terms see $ROOTSYS/LICENSE. *
9  * For the list of contributors see $ROOTSYS/README/CREDITS. *
10  *************************************************************************/
11 
12 #ifndef ROOT_TMVA_CvSplit
13 #define ROOT_TMVA_CvSplit
14 
15 #include "TMVA/Configurable.h"
16 #include "TMVA/Types.h"
17 
18 #include <Rtypes.h>
19 #include <TFormula.h>
20 
21 #include <memory>
22 #include <vector>
23 #include <map>
24 
25 class TString;
26 
27 namespace TMVA {
28 
29 class CrossValidation;
30 class DataSetInfo;
31 class Event;
32 
33 /* =============================================================================
34  TMVA::CvSplit
35 ============================================================================= */
36 
37 class CvSplit : public Configurable {
38 public:
39  CvSplit(UInt_t numFolds);
40  virtual ~CvSplit() {}
41 
42  virtual void MakeKFoldDataSet(DataSetInfo &dsi) = 0;
43  virtual void PrepareFoldDataSet(DataSetInfo &dsi, UInt_t foldNumber, Types::ETreeType tt);
45 
48 
49 protected:
52 
53  std::vector<std::vector<TMVA::Event *>> fTrainEvents;
54  std::vector<std::vector<TMVA::Event *>> fTestEvents;
55 
56 protected:
57  ClassDef(CvSplit, 0);
58 };
59 
60 /* =============================================================================
61  TMVA::CvSplitKFoldsExpr
62 ============================================================================= */
63 
65 public:
68 
69  UInt_t Eval(UInt_t numFolds, const Event *ev);
70 
71  static Bool_t Validate(TString expr);
72 
73 private:
75 
76 private:
78 
79  std::vector<std::pair<Int_t, Int_t>>
80  fFormulaParIdxToDsiSpecIdx; //! Maps parameter indicies in splitExpr to their spectator index in the datasetinfo.
81  Int_t fIdxFormulaParNumFolds; //! Keeps track of the index of reserved par "NumFolds" in splitExpr.
82  TString fSplitExpr; //! Expression used to split data into folds. Should output values between 0 and numFolds.
83  TFormula fSplitFormula; //! TFormula for splitExpr.
84 
85  std::vector<Double_t> fParValues;
86 };
87 
88 /* =============================================================================
89  TMVA::CvSplitKFolds
90 ============================================================================= */
91 
92 class CvSplitKFolds : public CvSplit {
93 
95 
96 public:
97  CvSplitKFolds(UInt_t numFolds, TString splitExpr = "", Bool_t stratified = kTRUE, UInt_t seed = 100);
98  ~CvSplitKFolds() override {}
99 
100  void MakeKFoldDataSet(DataSetInfo &dsi) override;
101 
102 private:
103  std::vector<std::vector<Event *>> SplitSets(std::vector<TMVA::Event *> &oldSet, UInt_t numFolds, UInt_t numClasses);
104  std::vector<UInt_t> GetEventIndexToFoldMapping(UInt_t nEntries, UInt_t numFolds, UInt_t seed = 100);
105 
106 private:
108  TString fSplitExprString; //! Expression used to split data into folds. Should output values between 0 and numFolds.
109  std::unique_ptr<CvSplitKFoldsExpr> fSplitExpr;
110  Bool_t fStratified; // If true, use stratified split. (Balance class presence in each fold).
111 
112  // Used for CrossValidation with random splits (not using the
113  // CVSplitKFoldsExpr functionality) to communicate Event to fold mapping.
114  std::map<const TMVA::Event *, UInt_t> fEventToFoldMapping;
115 
116 private:
118 };
119 
120 } // end namespace TMVA
121 
122 #endif
TMVA::CvSplitKFoldsExpr::GetSpectatorIndexForName
UInt_t GetSpectatorIndexForName(DataSetInfo &dsi, TString name)
Definition: CvSplit.cxx:206
TMVA::CvSplitKFolds::fEventToFoldMapping
std::map< const TMVA::Event *, UInt_t > fEventToFoldMapping
Definition: CvSplit.h:114
kTRUE
const Bool_t kTRUE
Definition: RtypesCore.h:91
tt
auto * tt
Definition: textangle.C:16
TMVA::Configurable
Definition: Configurable.h:45
TMVA::CvSplitKFolds::fSeed
UInt_t fSeed
Definition: CvSplit.h:107
TMVA::CvSplitKFoldsExpr::fDsi
DataSetInfo & fDsi
Definition: CvSplit.h:77
TFormula
The Formula class.
Definition: TFormula.h:87
TMVA::CvSplitKFoldsExpr::Validate
static Bool_t Validate(TString expr)
Definition: CvSplit.cxx:198
TMVA::CvSplitKFoldsExpr::CvSplitKFoldsExpr
CvSplitKFoldsExpr(DataSetInfo &dsi, TString expr)
Definition: CvSplit.cxx:139
TMVA::CvSplit::fMakeFoldDataSet
Bool_t fMakeFoldDataSet
Definition: CvSplit.h:51
TMVA::CvSplit::CvSplit
CvSplit(UInt_t numFolds)
Definition: CvSplit.cxx:38
TString
Basic string class.
Definition: TString.h:136
Bool_t
bool Bool_t
Definition: RtypesCore.h:63
TMVA::CvSplitKFoldsExpr::fIdxFormulaParNumFolds
Int_t fIdxFormulaParNumFolds
Keeps track of the index of reserved par "NumFolds" in splitExpr.
Definition: CvSplit.h:81
TMVA::CvSplitKFolds::fSplitExprString
TString fSplitExprString
Expression used to split data into folds. Should output values between 0 and numFolds.
Definition: CvSplit.h:108
TMVA::CvSplit::RecombineKFoldDataSet
virtual void RecombineKFoldDataSet(DataSetInfo &dsi, Types::ETreeType tt=Types::kTraining)
Definition: CvSplit.cxx:114
bool
TMVA::CvSplitKFoldsExpr::fSplitExpr
TString fSplitExpr
Expression used to split data into folds. Should output values between 0 and numFolds.
Definition: CvSplit.h:82
TMVA::CvSplit::~CvSplit
virtual ~CvSplit()
Definition: CvSplit.h:40
TMVA::CvSplitKFoldsExpr::~CvSplitKFoldsExpr
~CvSplitKFoldsExpr()
Definition: CvSplit.h:67
TMVA::CvSplitKFoldsExpr
Definition: CvSplit.h:64
TMVA::CvSplitKFolds
Definition: CvSplit.h:92
TMVA::CvSplitKFolds::~CvSplitKFolds
~CvSplitKFolds() override
Definition: CvSplit.h:98
TMVA::CvSplit::MakeKFoldDataSet
virtual void MakeKFoldDataSet(DataSetInfo &dsi)=0
TMVA::DataSetInfo
Class that contains all the data information.
Definition: DataSetInfo.h:62
TMVA::CvSplitKFolds::CvSplitKFolds
CvSplitKFolds(UInt_t numFolds, TString splitExpr="", Bool_t stratified=kTRUE, UInt_t seed=100)
Splits a dataset into k folds, ready for use in cross validation.
Definition: CvSplit.cxx:243
TMVA::CvSplit::PrepareFoldDataSet
virtual void PrepareFoldDataSet(DataSetInfo &dsi, UInt_t foldNumber, Types::ETreeType tt)
Set training and test set vectors of dataset described by dsi.
Definition: CvSplit.cxx:57
TMVA::CvSplitKFolds::fStratified
Bool_t fStratified
If true, use stratified split. (Balance class presence in each fold).
Definition: CvSplit.h:110
TMVA::CvSplit::GetNumFolds
UInt_t GetNumFolds()
Definition: CvSplit.h:46
TMVA::Types::ETreeType
ETreeType
Definition: Types.h:144
TMVA::CvSplitKFoldsExpr::Eval
UInt_t Eval(UInt_t numFolds, const Event *ev)
Definition: CvSplit.cxx:164
TMVA::CvSplitKFolds::GetEventIndexToFoldMapping
std::vector< UInt_t > GetEventIndexToFoldMapping(UInt_t nEntries, UInt_t numFolds, UInt_t seed=100)
Generates a vector of fold assignments.
Definition: CvSplit.cxx:293
UInt_t
unsigned int UInt_t
Definition: RtypesCore.h:46
TMVA::CvSplit
Definition: CvSplit.h:37
Types.h
Configurable.h
unsigned int
TMVA::Types::kTraining
@ kTraining
Definition: Types.h:145
TMVA::CvSplitKFolds::fSplitExpr
std::unique_ptr< CvSplitKFoldsExpr > fSplitExpr
Owning pointer to the CvSplitKFoldsExpr object that evaluates the split expression.
Definition: CvSplit.h:109
TMVA::CvSplit::NeedsRebuild
Bool_t NeedsRebuild()
Definition: CvSplit.h:47
TMVA::CvSplitKFolds::SplitSets
std::vector< std::vector< Event * > > SplitSets(std::vector< TMVA::Event * > &oldSet, UInt_t numFolds, UInt_t numClasses)
Splits sets into k folds.
Definition: CvSplit.cxx:319
TMVA::CvSplitKFolds::CrossValidation
friend CrossValidation
Definition: CvSplit.h:94
TMVA::CvSplit::fTestEvents
std::vector< std::vector< TMVA::Event * > > fTestEvents
Definition: CvSplit.h:54
TMVA::CvSplit::fTrainEvents
std::vector< std::vector< TMVA::Event * > > fTrainEvents
Definition: CvSplit.h:53
TMVA::Event
Definition: Event.h:51
ClassDef
#define ClassDef(name, id)
Definition: Rtypes.h:325
name
char name[80]
Definition: TGX11.cxx:110
TMVA::CvSplitKFolds::MakeKFoldDataSet
void MakeKFoldDataSet(DataSetInfo &dsi) override
Prepares a DataSet for cross validation.
Definition: CvSplit.cxx:255
TMVA::CvSplitKFoldsExpr::fSplitFormula
TFormula fSplitFormula
TFormula for splitExpr.
Definition: CvSplit.h:83
TMVA::CvSplitKFoldsExpr::fParValues
std::vector< Double_t > fParValues
Definition: CvSplit.h:85
Rtypes.h
TMVA::CvSplitKFoldsExpr::fFormulaParIdxToDsiSpecIdx
std::vector< std::pair< Int_t, Int_t > > fFormulaParIdxToDsiSpecIdx
Maps parameter indices in splitExpr to their spectator index in the datasetinfo.
Definition: CvSplit.h:80
TMVA::CvSplit::fNumFolds
UInt_t fNumFolds
Definition: CvSplit.h:50
TMVA
create variable transformations
Definition: GeneticMinimizer.h:22
int
TMVA::CvSplitKFolds::ClassDefOverride
ClassDefOverride(CvSplitKFolds, 0)