Logo ROOT  
Reference Guide
CvSplit.cxx
Go to the documentation of this file.
1 // @(#)root/tmva $Id$
2 // Author: Kim Albertsson
3 
4 /*************************************************************************
5  * Copyright (C) 2018, Rene Brun and Fons Rademakers. *
6  * All rights reserved. *
7  * *
8  * For the licensing terms see $ROOTSYS/LICENSE. *
9  * For the list of contributors see $ROOTSYS/README/CREDITS. *
10  *************************************************************************/
11 
12 #include "TMVA/CvSplit.h"
13 
14 #include "TMVA/DataSet.h"
15 #include "TMVA/DataSetFactory.h"
16 #include "TMVA/DataSetInfo.h"
17 #include "TMVA/Event.h"
18 #include "TMVA/MsgLogger.h"
19 #include "TMVA/Tools.h"
20 
21 #include <TString.h>
22 #include <TFormula.h>
23 
24 #include <algorithm>
25 #include <numeric>
26 #include <stdexcept>
27 
30 
31 /* =============================================================================
32  TMVA::CvSplit
33 ============================================================================= */
34 
35 ////////////////////////////////////////////////////////////////////////////////
36 ///
37 
38 TMVA::CvSplit::CvSplit(UInt_t numFolds) : fNumFolds(numFolds), fMakeFoldDataSet(kFALSE) {}
39 
40 ////////////////////////////////////////////////////////////////////////////////
41 /// \brief Set training and test set vectors of dataset described by `dsi`.
42 /// \param[in] dsi DataSetInfo for data set to be split
43 /// \param[in] foldNumber Ordinal of fold to prepare
44 /// \param[in] tt The set used to prepare fold. If equal to `Types::kTraining`
45 /// splitting will be based off the original train set. If instead
46 /// equal to `Types::kTesting` the test set will be used.
47 /// The original training/test set is the set as defined by
48 /// `DataLoader::PrepareTrainingAndTestSet`.
49 ///
50 /// Sets the training and test set vectors of the DataSet described by `dsi` as
51 /// defined by the split. If `tt` is eqal to `Types::kTraining` the split will
52 /// be based off of the original training set.
53 ///
54 /// Note: Requires `MakeKFoldDataSet` to have been called first.
55 ///
56 
58 {
59  if (foldNumber >= fNumFolds) {
60  Log() << kFATAL << "DataSet prepared for \"" << fNumFolds << "\" folds, requested fold \"" << foldNumber
61  << "\" is outside of range." << Endl;
62  return;
63  }
64 
65  auto prepareDataSetInternal = [this, &dsi, foldNumber](std::vector<std::vector<Event *>> vec) {
66  UInt_t numFolds = fTrainEvents.size();
67 
68  // Events in training set (excludes current fold)
69  UInt_t nTotal = std::accumulate(vec.begin(), vec.end(), 0,
70  [&](UInt_t sum, std::vector<TMVA::Event *> v) { return sum + v.size(); });
71 
72  UInt_t nTrain = nTotal - vec.at(foldNumber).size();
73  UInt_t nTest = vec.at(foldNumber).size();
74 
75  std::vector<Event *> tempTrain;
76  std::vector<Event *> tempTest;
77 
78  tempTrain.reserve(nTrain);
79  tempTest.reserve(nTest);
80 
81  // Insert data into training set
82  for (UInt_t i = 0; i < numFolds; ++i) {
83  if (i == foldNumber) {
84  continue;
85  }
86 
87  tempTrain.insert(tempTrain.end(), vec.at(i).begin(), vec.at(i).end());
88  }
89 
90  // Insert data into test set
91  tempTest.insert(tempTest.end(), vec.at(foldNumber).begin(), vec.at(foldNumber).end());
92 
93  Log() << kDEBUG << "Fold prepared, num events in training set: " << tempTrain.size() << Endl;
94  Log() << kDEBUG << "Fold prepared, num events in test set: " << tempTest.size() << Endl;
95 
96  // Assign the vectors of the events to rebuild the dataset
97  dsi.GetDataSet()->SetEventCollection(&tempTrain, Types::kTraining, false);
98  dsi.GetDataSet()->SetEventCollection(&tempTest, Types::kTesting, false);
99  };
100 
101  if (tt == Types::kTraining) {
102  prepareDataSetInternal(fTrainEvents);
103  } else if (tt == Types::kTesting) {
104  prepareDataSetInternal(fTestEvents);
105  } else {
106  Log() << kFATAL << "PrepareFoldDataSet can only work with training and testing data sets." << std::endl;
107  return;
108  }
109 }
110 
111 ////////////////////////////////////////////////////////////////////////////////
112 ///
113 
115 {
116  if (tt != Types::kTraining) {
117  Log() << kFATAL << "Only kTraining is supported for CvSplit::RecombineKFoldDataSet currently." << std::endl;
118  }
119 
120  std::vector<Event *> *tempVec = new std::vector<Event *>;
121 
122  for (UInt_t i = 0; i < fNumFolds; ++i) {
123  tempVec->insert(tempVec->end(), fTrainEvents.at(i).begin(), fTrainEvents.at(i).end());
124  }
125 
126  dsi.GetDataSet()->SetEventCollection(tempVec, Types::kTraining, false);
127  dsi.GetDataSet()->SetEventCollection(tempVec, Types::kTesting, false);
128 
129  delete tempVec;
130 }
131 
132 /* =============================================================================
133  TMVA::CvSplitKFoldsExpr
134 ============================================================================= */
135 
136 ////////////////////////////////////////////////////////////////////////////////
137 ///
138 
140  : fDsi(dsi), fIdxFormulaParNumFolds(std::numeric_limits<UInt_t>::max()), fSplitFormula("", expr),
141  fParValues(fSplitFormula.GetNpar())
142 {
143  if (!fSplitFormula.IsValid()) {
144  throw std::runtime_error("Split expression \"" + std::string(fSplitExpr.Data()) + "\" is not a valid TFormula.");
145  }
146 
147  for (Int_t iFormulaPar = 0; iFormulaPar < fSplitFormula.GetNpar(); ++iFormulaPar) {
148  TString name = fSplitFormula.GetParName(iFormulaPar);
149 
150  // std::cout << "Found variable with name \"" << name << "\"." << std::endl;
151 
152  if (name == "NumFolds" || name == "numFolds") {
153  // std::cout << "NumFolds|numFolds is a reserved variable! Adding to context." << std::endl;
154  fIdxFormulaParNumFolds = iFormulaPar;
155  } else {
156  fFormulaParIdxToDsiSpecIdx.push_back(std::make_pair(iFormulaPar, GetSpectatorIndexForName(fDsi, name)));
157  }
158  }
159 }
160 
161 ////////////////////////////////////////////////////////////////////////////////
162 ///
163 
165 {
166  for (auto &p : fFormulaParIdxToDsiSpecIdx) {
167  auto iFormulaPar = p.first;
168  auto iSpectator = p.second;
169 
170  fParValues.at(iFormulaPar) = ev->GetSpectator(iSpectator);
171  }
172 
173  if (fIdxFormulaParNumFolds < fSplitFormula.GetNpar()) {
174  fParValues[fIdxFormulaParNumFolds] = numFolds;
175  }
176 
177  // NOTE: We are using a double to represent an integer here. This _will_
178  // lead to problems if the norm of the double grows too large. A quick test
179  // with python suggests that problems arise at a magnitude of ~1e16.
180  Double_t iFold_d = fSplitFormula.EvalPar(nullptr, &fParValues[0]);
181 
182  if (iFold_d < 0) {
183  throw std::runtime_error("Output of splitExpr must be non-negative.");
184  }
185 
186  UInt_t iFold = std::lround(iFold_d);
187  if (iFold >= numFolds) {
188  throw std::runtime_error("Output of splitExpr should be a non-negative"
189  "integer between 0 and numFolds-1 inclusive.");
190  }
191 
192  return iFold;
193 }
194 
195 ////////////////////////////////////////////////////////////////////////////////
196 ///
197 
199 {
200  return TFormula("", expr).IsValid();
201 }
202 
203 ////////////////////////////////////////////////////////////////////////////////
204 ///
205 
207 {
208  std::vector<VariableInfo> spectatorInfos = dsi.GetSpectatorInfos();
209 
210  for (UInt_t iSpectator = 0; iSpectator < spectatorInfos.size(); ++iSpectator) {
211  VariableInfo vi = spectatorInfos[iSpectator];
212  if (vi.GetName() == name) {
213  return iSpectator;
214  } else if (vi.GetLabel() == name) {
215  return iSpectator;
216  } else if (vi.GetExpression() == name) {
217  return iSpectator;
218  }
219  }
220 
221  throw std::runtime_error("Spectator \"" + std::string(name.Data()) + "\" not found.");
222 }
223 
224 /* =============================================================================
225  TMVA::CvSplitKFolds
226 ============================================================================= */
227 
228 ////////////////////////////////////////////////////////////////////////////////
229 /// \brief Splits a dataset into k folds, ready for use in cross validation.
230 /// \param[in] numFolds Number of folds to split data into
231 /// \param[in] stratified If true, use stratified splitting, balancing the
232 /// number of events across classes and folds. If false,
233 /// no such balancing is done. For
234 /// \param[in] splitExpr Expression used to split data into folds. If `""` a
235 /// random assignment will be done. Otherwise the
236 /// expression is fed into a TFormula and evaluated per
237 /// event. The resulting value is the the fold assignment.
238 /// \param[in] seed Used only when using random splitting (i.e. when
239 /// `splitExpr` is `""`). Seed is used to initialise the random
240 /// number generator when assigning events to folds.
241 ///
242 
243 TMVA::CvSplitKFolds::CvSplitKFolds(UInt_t numFolds, TString splitExpr, Bool_t stratified, UInt_t seed)
244  : CvSplit(numFolds), fSeed(seed), fSplitExprString(splitExpr), fStratified(stratified)
245 {
246  if (!CvSplitKFoldsExpr::Validate(fSplitExprString) && (splitExpr != TString(""))) {
247  Log() << kFATAL << "Split expression \"" << fSplitExprString << "\" is not a valid TFormula." << Endl;
248  }
249 
250 }
251 
252 ////////////////////////////////////////////////////////////////////////////////
253 /// \brief Prepares a DataSet for cross validation
254 
256 {
257  // Validate spectator
258  // fSpectatorIdx = GetSpectatorIndexForName(dsi, fSpectatorName);
259 
260  if (fSplitExprString != TString("")) {
261  fSplitExpr = std::unique_ptr<CvSplitKFoldsExpr>(new CvSplitKFoldsExpr(dsi, fSplitExprString));
262  }
263 
264  // No need to do it again if the sets have already been split.
265  if (fMakeFoldDataSet) {
266  Log() << kINFO << "Splitting in k-folds has been already done" << Endl;
267  return;
268  }
269 
270  fMakeFoldDataSet = kTRUE;
271 
272  UInt_t numClasses = dsi.GetNClasses();
273 
274  // Get the original event vectors for testing and training from the dataset.
275  std::vector<Event *> trainData = dsi.GetDataSet()->GetEventCollection(Types::kTraining);
276  std::vector<Event *> testData = dsi.GetDataSet()->GetEventCollection(Types::kTesting);
277 
278  // Split the sets into the number of folds.
279  fTrainEvents = SplitSets(trainData, fNumFolds, numClasses);
280  fTestEvents = SplitSets(testData, fNumFolds, numClasses);
281 }
282 
283 ////////////////////////////////////////////////////////////////////////////////
284 /// \brief Generates a vector of fold assignments
285 /// \param[in] nEntries Number of events in range
286 /// \param[in] numFolds Number of folds to split data into
287 /// \param[in] seed Random seed
/// \return Vector of `nEntries` fold indices; element `i` is the fold of event `i`.
288 ///
289 /// Randomly assigns events to `numFolds` folds. Each fold will hold at most
290 /// `nEntries / numFolds + 1` events.
/// The assignment is first filled round-robin (`i % numFolds`) and then
/// shuffled. `numFolds` must be non-zero, since it is used as a modulus.
291 ///
292 
293 std::vector<UInt_t> TMVA::CvSplitKFolds::GetEventIndexToFoldMapping(UInt_t nEntries, UInt_t numFolds, UInt_t seed)
294 {
295  // Generate assignment of the pattern `0, 1, 2, 0, 1, 2, 0, 1 ...` for
296  // `numFolds = 3`.
 // Note: despite the `f` prefix this is a local variable, not a data member.
297  std::vector<UInt_t> fOrigToFoldMapping;
298  fOrigToFoldMapping.reserve(nEntries);
299 
300  for (UInt_t iEvent = 0; iEvent < nEntries; ++iEvent) {
301  fOrigToFoldMapping.push_back(iEvent % numFolds);
302  }
303 
304  // Shuffle assignment
 // NOTE(review): the declaration of `rng` (listing line 305) is missing from this
 // copy of the listing; presumably it is a TMVA::RandomGenerator seeded from
 // `seed` - confirm against the original file.
306  std::shuffle(fOrigToFoldMapping.begin(), fOrigToFoldMapping.end(), rng);
307 
308  return fOrigToFoldMapping;
309 }
310 
311 
312 ////////////////////////////////////////////////////////////////////////////////
313 /// \brief Split a set of events into k folds
314 /// \param[in] oldSet Original, unsplit, events
315 /// \param[in] numFolds Number of folds to split data into
/// \param[in] numClasses Number of classes; used only by the stratified split,
///                       where every event's `GetClass()` must be below it.
/// \return `numFolds` vectors; element `k` holds the events assigned to fold `k`.
316 ///
/// The strategy depends on the configuration:
///  - split expression (when `fSplitExpr` is set and `fSplitExprString` is
///    non-empty): each event goes to fold `fSplitExpr->Eval(numFolds, ev)`;
///  - random (`fStratified` false): `GetEventIndexToFoldMapping` over all events;
///  - stratified (`fStratified` true): events are grouped by class, each group
///    is shuffled and then mapped to folds independently.
/// The random and stratified modes also record each event's fold in
/// `fEventToFoldMapping`; the expression mode does not.
317 
318 std::vector<std::vector<TMVA::Event *>>
319 TMVA::CvSplitKFolds::SplitSets(std::vector<TMVA::Event *> &oldSet, UInt_t numFolds, UInt_t numClasses)
320 {
321  const ULong64_t nEntries = oldSet.size();
322  const ULong64_t foldSize = nEntries / numFolds;
323 
324  std::vector<std::vector<Event *>> tempSets;
 // NOTE(review): reserves with the member fNumFolds rather than the numFolds
 // parameter. This is harmless while callers pass fNumFolds (MakeKFoldDataSet
 // does), but using the parameter would be consistent with the loop below.
325  tempSets.reserve(fNumFolds);
326  for (UInt_t iFold = 0; iFold < numFolds; ++iFold) {
327  tempSets.emplace_back();
328  tempSets.at(iFold).reserve(foldSize);
329  }
330 
331  Bool_t useSplitExpr = !(fSplitExpr == nullptr || fSplitExprString == "");
332 
333  if (useSplitExpr) {
334  // Deterministic split
335  for (ULong64_t i = 0; i < nEntries; i++) {
336  TMVA::Event *ev = oldSet[i];
337  UInt_t iFold = fSplitExpr->Eval(numFolds, ev);
338  tempSets.at((UInt_t)iFold).push_back(ev);
339  }
340  } else {
341  if(!fStratified){
342  // Random split
343  std::vector<UInt_t> fOrigToFoldMapping;
344  fOrigToFoldMapping = GetEventIndexToFoldMapping(nEntries, numFolds, fSeed);
345 
346  for (UInt_t iEvent = 0; iEvent < nEntries; ++iEvent) {
347  UInt_t iFold = fOrigToFoldMapping[iEvent];
348  TMVA::Event *ev = oldSet[iEvent];
349  tempSets.at(iFold).push_back(ev);
350 
351  fEventToFoldMapping[ev] = iFold;
352  }
353  } else {
354  // Stratified Split
355  std::vector<std::vector<TMVA::Event *>> oldSets;
356  oldSets.reserve(numClasses);
357 
358  for(UInt_t iClass = 0; iClass < numClasses; iClass++){
359  oldSets.emplace_back();
360  // NOTE(review): this reserves the outer `oldSets` (capacity for nEntries vectors), not the per-class vector just added; reserving each class needs its event count, which is not known here.
361  oldSets.reserve(nEntries);
362  }
363 
364  for(UInt_t iEvent = 0; iEvent < nEntries; ++iEvent){
365  // check the class of event and add to its vector of events
366  TMVA::Event *ev = oldSet[iEvent];
367  UInt_t iClass = ev->GetClass();
368  oldSets.at(iClass).push_back(ev);
369  }
370 
371  for(UInt_t i = 0; i<numClasses; ++i){
372  // Shuffle each vector individually
 // NOTE(review): the declaration of `rng` (listing line 373) is missing from this
 // copy of the listing; presumably a TMVA::RandomGenerator seeded from fSeed -
 // confirm against the original file.
374  std::shuffle(oldSets.at(i).begin(), oldSets.at(i).end(), rng);
375  }
376 
377  for(UInt_t i = 0; i<numClasses; ++i) {
378  std::vector<UInt_t> fOrigToFoldMapping;
379  fOrigToFoldMapping = GetEventIndexToFoldMapping(oldSets.at(i).size(), numFolds, fSeed);
380 
381  for (UInt_t iEvent = 0; iEvent < oldSets.at(i).size(); ++iEvent) {
382  UInt_t iFold = fOrigToFoldMapping[iEvent];
383  TMVA::Event *ev = oldSets.at(i)[iEvent];
384  tempSets.at(iFold).push_back(ev);
385  fEventToFoldMapping[ev] = iFold;
386  }
387  }
388  }
389  }
390  return tempSets;
391 }
TMVA::CvSplitKFoldsExpr::GetSpectatorIndexForName
UInt_t GetSpectatorIndexForName(DataSetInfo &dsi, TString name)
Definition: CvSplit.cxx:206
kTRUE
const Bool_t kTRUE
Definition: RtypesCore.h:91
TMVA::Configurable::Log
MsgLogger & Log() const
Definition: Configurable.h:122
tt
auto * tt
Definition: textangle.C:16
TFormula::IsValid
Bool_t IsValid() const
Definition: TFormula.h:236
TMVA::CvSplitKFoldsExpr::fDsi
DataSetInfo & fDsi
Definition: CvSplit.h:77
TFormula
The Formula class.
Definition: TFormula.h:87
TString::Data
const char * Data() const
Definition: TString.h:369
DataSetInfo.h
ClassImp
#define ClassImp(name)
Definition: Rtypes.h:364
TMVA::DataSetInfo::GetDataSet
DataSet * GetDataSet() const
returns data set
Definition: DataSetInfo.cxx:480
TMVA::Event::GetClass
UInt_t GetClass() const
Definition: Event.h:86
TMVA::CvSplitKFoldsExpr::Validate
static Bool_t Validate(TString expr)
Definition: CvSplit.cxx:198
sum
static uint64_t sum(uint64_t i)
Definition: Factory.cxx:2345
TMath::Log
Double_t Log(Double_t x)
Definition: TMath.h:760
TMVA::Types::kTesting
@ kTesting
Definition: Types.h:146
TMVA::RandomGenerator
Definition: Tools.h:305
TMVA::CvSplitKFoldsExpr::CvSplitKFoldsExpr
CvSplitKFoldsExpr(DataSetInfo &dsi, TString expr)
Definition: CvSplit.cxx:139
TMVA::DataSetInfo::GetSpectatorInfos
std::vector< VariableInfo > & GetSpectatorInfos()
Definition: DataSetInfo.h:122
CvSplit.h
TMVA::CvSplit::CvSplit
CvSplit(UInt_t numFolds)
Definition: CvSplit.cxx:38
TString
Basic string class.
Definition: TString.h:136
TMVA::CvSplitKFoldsExpr::fIdxFormulaParNumFolds
Int_t fIdxFormulaParNumFolds
Maps parameter indices in splitExpr to their spectator index in the DataSetInfo.
Definition: CvSplit.h:81
TMVA::CvSplitKFolds::fSplitExprString
TString fSplitExprString
Definition: CvSplit.h:108
v
@ v
Definition: rootcling_impl.cxx:3635
TString.h
TMVA::CvSplit::RecombineKFoldDataSet
virtual void RecombineKFoldDataSet(DataSetInfo &dsi, Types::ETreeType tt=Types::kTraining)
Definition: CvSplit.cxx:114
bool
TMVA::CvSplitKFoldsExpr::fSplitExpr
TString fSplitExpr
Keeps track of the index of reserved par "NumFolds" in splitExpr.
Definition: CvSplit.h:82
TMVA::DataSet::SetEventCollection
void SetEventCollection(std::vector< Event * > *, Types::ETreeType, Bool_t deleteEvents=true)
Sets the event collection (by DataSetFactory)
Definition: DataSet.cxx:250
TMVA::DataSetInfo::GetNClasses
UInt_t GetNClasses() const
Definition: DataSetInfo.h:155
TMVA::CvSplitKFoldsExpr
Definition: CvSplit.h:64
TMVA::CvSplitKFolds
Definition: CvSplit.h:92
TMVA::DataSetInfo
Class that contains all the data information.
Definition: DataSetInfo.h:62
TMVA::CvSplitKFolds::CvSplitKFolds
CvSplitKFolds(UInt_t numFolds, TString splitExpr="", Bool_t stratified=kTRUE, UInt_t seed=100)
Splits a dataset into k folds, ready for use in cross validation.
Definition: CvSplit.cxx:243
TMVA::CvSplit::PrepareFoldDataSet
virtual void PrepareFoldDataSet(DataSetInfo &dsi, UInt_t foldNumber, Types::ETreeType tt)
Set training and test set vectors of dataset described by dsi.
Definition: CvSplit.cxx:57
TFormula::GetParName
const char * GetParName(Int_t ipar) const
MsgLogger.h
TMVA::DataSet::GetEventCollection
const std::vector< Event * > & GetEventCollection(Types::ETreeType type=Types::kMaxTreeType) const
Definition: DataSet.h:216
TMVA::Types::ETreeType
ETreeType
Definition: Types.h:144
TMVA::CvSplitKFoldsExpr::Eval
UInt_t Eval(UInt_t numFolds, const Event *ev)
Definition: CvSplit.cxx:164
TMVA::CvSplitKFolds::GetEventIndexToFoldMapping
std::vector< UInt_t > GetEventIndexToFoldMapping(UInt_t nEntries, UInt_t numFolds, UInt_t seed=100)
Generates a vector of fold assignments.
Definition: CvSplit.cxx:293
kFALSE
const Bool_t kFALSE
Definition: RtypesCore.h:92
TMVA::VariableInfo
Class for type info of MVA input variable.
Definition: VariableInfo.h:47
TString::TString
TString()
TString default ctor.
Definition: TString.cxx:87
Event.h
TMVA::CvSplit
Definition: CvSplit.h:37
TMVA::Endl
MsgLogger & Endl(MsgLogger &ml)
Definition: MsgLogger.h:158
unsigned int
TMVA::Types::kTraining
@ kTraining
Definition: Types.h:145
DataSetFactory.h
ULong64_t
unsigned long long ULong64_t
Definition: RtypesCore.h:74
TFormula::GetNpar
Int_t GetNpar() const
Definition: TFormula.h:225
TMVA::Event::GetSpectator
Float_t GetSpectator(UInt_t ivar) const
return spectator content
Definition: Event.cxx:261
TMVA::CvSplitKFolds::SplitSets
std::vector< std::vector< Event * > > SplitSets(std::vector< TMVA::Event * > &oldSet, UInt_t numFolds, UInt_t numClasses)
Split sets into k-folds.
Definition: CvSplit.cxx:319
Double_t
double Double_t
Definition: RtypesCore.h:59
TMVA::Event
Definition: Event.h:51
TMVA::VariableInfo::GetLabel
const TString & GetLabel() const
Definition: VariableInfo.h:59
name
char name[80]
Definition: TGX11.cxx:110
TMVA::VariableInfo::GetExpression
const TString & GetExpression() const
Definition: VariableInfo.h:57
TMVA::CvSplitKFolds::MakeKFoldDataSet
void MakeKFoldDataSet(DataSetInfo &dsi) override
Prepares a DataSet for cross validation.
Definition: CvSplit.cxx:255
Tools.h
TNamed::GetName
virtual const char * GetName() const
Returns name of object.
Definition: TNamed.h:47
TMVA::CvSplitKFoldsExpr::fSplitFormula
TFormula fSplitFormula
Expression used to split data into folds. Should output values between 0 and numFolds - 1 (inclusive).
Definition: CvSplit.h:83
DataSet.h
TMVA::CvSplitKFoldsExpr::fFormulaParIdxToDsiSpecIdx
std::vector< std::pair< Int_t, Int_t > > fFormulaParIdxToDsiSpecIdx
Definition: CvSplit.h:80
int