Logo ROOT  
Reference Guide
 
Loading...
Searching...
No Matches
CvSplit.cxx
Go to the documentation of this file.
1// @(#)root/tmva $Id$
2// Author: Kim Albertsson
3
4/*************************************************************************
5 * Copyright (C) 2018, Rene Brun and Fons Rademakers. *
6 * All rights reserved. *
7 * *
8 * For the licensing terms see $ROOTSYS/LICENSE. *
9 * For the list of contributors see $ROOTSYS/README/CREDITS. *
10 *************************************************************************/
11
12#include "TMVA/CvSplit.h"
13
14#include "TMVA/DataSet.h"
15#include "TMVA/DataSetFactory.h"
16#include "TMVA/DataSetInfo.h"
17#include "TMVA/Event.h"
18#include "TMVA/MsgLogger.h"
19#include "TMVA/Tools.h"
20
21#include <TString.h>
22#include <TFormula.h>
23
24#include <algorithm>
25#include <numeric>
26#include <stdexcept>
27
30
31/* =============================================================================
32 TMVA::CvSplit
33============================================================================= */
34
35////////////////////////////////////////////////////////////////////////////////
36///
37
38TMVA::CvSplit::CvSplit(UInt_t numFolds) : fNumFolds(numFolds), fMakeFoldDataSet(kFALSE) {}
39
40////////////////////////////////////////////////////////////////////////////////
41/// \brief Set training and test set vectors of dataset described by `dsi`.
42/// \param[in] dsi DataSetInfo for data set to be split
43/// \param[in] foldNumber Ordinal of fold to prepare
44/// \param[in] tt The set used to prepare fold. If equal to `Types::kTraining`
45/// splitting will be based off the original train set. If instead
46/// equal to `Types::kTesting` the test set will be used.
47/// The original training/test set is the set as defined by
48/// `DataLoader::PrepareTrainingAndTestSet`.
49///
50/// Sets the training and test set vectors of the DataSet described by `dsi` as
51/// defined by the split. If `tt` is eqal to `Types::kTraining` the split will
52/// be based off of the original training set.
53///
54/// Note: Requires `MakeKFoldDataSet` to have been called first.
55///
56
58{
59 if (foldNumber >= fNumFolds) {
60 Log() << kFATAL << "DataSet prepared for \"" << fNumFolds << "\" folds, requested fold \"" << foldNumber
61 << "\" is outside of range." << Endl;
62 return;
63 }
64
65 auto prepareDataSetInternal = [this, &dsi, foldNumber](std::vector<std::vector<Event *>> vec) {
66 UInt_t numFolds = fTrainEvents.size();
67
68 // Events in training set (excludes current fold)
69 UInt_t nTotal = std::accumulate(vec.begin(), vec.end(), 0,
70 [&](UInt_t sum, std::vector<TMVA::Event *> v) { return sum + v.size(); });
71
72 UInt_t nTrain = nTotal - vec.at(foldNumber).size();
73 UInt_t nTest = vec.at(foldNumber).size();
74
75 std::vector<Event *> tempTrain;
76 std::vector<Event *> tempTest;
77
78 tempTrain.reserve(nTrain);
79 tempTest.reserve(nTest);
80
81 // Insert data into training set
82 for (UInt_t i = 0; i < numFolds; ++i) {
83 if (i == foldNumber) {
84 continue;
85 }
86
87 tempTrain.insert(tempTrain.end(), vec.at(i).begin(), vec.at(i).end());
88 }
89
90 // Insert data into test set
91 tempTest.insert(tempTest.end(), vec.at(foldNumber).begin(), vec.at(foldNumber).end());
92
93 Log() << kDEBUG << "Fold prepared, num events in training set: " << tempTrain.size() << Endl;
94 Log() << kDEBUG << "Fold prepared, num events in test set: " << tempTest.size() << Endl;
95
96 // Assign the vectors of the events to rebuild the dataset
97 dsi.GetDataSet()->SetEventCollection(&tempTrain, Types::kTraining, false);
98 dsi.GetDataSet()->SetEventCollection(&tempTest, Types::kTesting, false);
99 };
100
101 if (tt == Types::kTraining) {
102 prepareDataSetInternal(fTrainEvents);
103 } else if (tt == Types::kTesting) {
104 prepareDataSetInternal(fTestEvents);
105 } else {
106 Log() << kFATAL << "PrepareFoldDataSet can only work with training and testing data sets." << std::endl;
107 return;
108 }
109}
110
111////////////////////////////////////////////////////////////////////////////////
112///
113
115{
116 if (tt != Types::kTraining) {
117 Log() << kFATAL << "Only kTraining is supported for CvSplit::RecombineKFoldDataSet currently." << std::endl;
118 }
119
120 std::vector<Event *> *tempVec = new std::vector<Event *>;
121
122 for (UInt_t i = 0; i < fNumFolds; ++i) {
123 tempVec->insert(tempVec->end(), fTrainEvents.at(i).begin(), fTrainEvents.at(i).end());
124 }
125
126 dsi.GetDataSet()->SetEventCollection(tempVec, Types::kTraining, false);
127 dsi.GetDataSet()->SetEventCollection(tempVec, Types::kTesting, false);
128
129 delete tempVec;
130}
131
132/* =============================================================================
133 TMVA::CvSplitKFoldsExpr
134============================================================================= */
135
136////////////////////////////////////////////////////////////////////////////////
137///
138
140 : fDsi(dsi), fIdxFormulaParNumFolds(std::numeric_limits<UInt_t>::max()), fSplitFormula("", expr),
141 fParValues(fSplitFormula.GetNpar())
142{
143 if (!fSplitFormula.IsValid()) {
144 throw std::runtime_error("Split expression \"" + std::string(fSplitExpr.Data()) + "\" is not a valid TFormula.");
145 }
146
147 for (Int_t iFormulaPar = 0; iFormulaPar < fSplitFormula.GetNpar(); ++iFormulaPar) {
148 TString name = fSplitFormula.GetParName(iFormulaPar);
149
150 // std::cout << "Found variable with name \"" << name << "\"." << std::endl;
151
152 if (name == "NumFolds" || name == "numFolds") {
153 // std::cout << "NumFolds|numFolds is a reserved variable! Adding to context." << std::endl;
154 fIdxFormulaParNumFolds = iFormulaPar;
155 } else {
156 fFormulaParIdxToDsiSpecIdx.push_back(std::make_pair(iFormulaPar, GetSpectatorIndexForName(fDsi, name)));
157 }
158 }
159}
160
161////////////////////////////////////////////////////////////////////////////////
162///
163
165{
166 for (auto &p : fFormulaParIdxToDsiSpecIdx) {
167 auto iFormulaPar = p.first;
168 auto iSpectator = p.second;
169
170 fParValues.at(iFormulaPar) = ev->GetSpectator(iSpectator);
171 }
172
173 if (fIdxFormulaParNumFolds < fSplitFormula.GetNpar()) {
174 fParValues[fIdxFormulaParNumFolds] = numFolds;
175 }
176
177 // NOTE: We are using a double to represent an integer here. This _will_
178 // lead to problems if the norm of the double grows too large. A quick test
179 // with python suggests that problems arise at a magnitude of ~1e16.
180 Double_t iFold_d = fSplitFormula.EvalPar(nullptr, &fParValues[0]);
181
182 if (iFold_d < 0) {
183 throw std::runtime_error("Output of splitExpr must be non-negative.");
184 }
185
186 UInt_t iFold = std::lround(iFold_d);
187 if (iFold >= numFolds) {
188 throw std::runtime_error("Output of splitExpr should be a non-negative"
189 "integer between 0 and numFolds-1 inclusive.");
190 }
191
192 return iFold;
193}
194
195////////////////////////////////////////////////////////////////////////////////
196///
197
199{
200 return TFormula("", expr).IsValid();
201}
202
203////////////////////////////////////////////////////////////////////////////////
204///
205
207{
208 std::vector<VariableInfo> spectatorInfos = dsi.GetSpectatorInfos();
209
210 for (UInt_t iSpectator = 0; iSpectator < spectatorInfos.size(); ++iSpectator) {
211 VariableInfo vi = spectatorInfos[iSpectator];
212 if (vi.GetName() == name) {
213 return iSpectator;
214 } else if (vi.GetLabel() == name) {
215 return iSpectator;
216 } else if (vi.GetExpression() == name) {
217 return iSpectator;
218 }
219 }
220
221 throw std::runtime_error("Spectator \"" + std::string(name.Data()) + "\" not found.");
222}
223
224/* =============================================================================
225 TMVA::CvSplitKFolds
226============================================================================= */
227
228////////////////////////////////////////////////////////////////////////////////
229/// \brief Splits a dataset into k folds, ready for use in cross validation.
230/// \param[in] numFolds Number of folds to split data into
231/// \param[in] stratified If true, use stratified splitting, balancing the
232/// number of events across classes and folds. If false,
233/// no such balancing is done. For
234/// \param[in] splitExpr Expression used to split data into folds. If `""` a
235/// random assignment will be done. Otherwise the
236/// expression is fed into a TFormula and evaluated per
237/// event. The resulting value is the the fold assignment.
238/// \param[in] seed Used only when using random splitting (i.e. when
239/// `splitExpr` is `""`). Seed is used to initialise the random
240/// number generator when assigning events to folds.
241///
242
243TMVA::CvSplitKFolds::CvSplitKFolds(UInt_t numFolds, TString splitExpr, Bool_t stratified, UInt_t seed)
244 : CvSplit(numFolds), fSeed(seed), fSplitExprString(splitExpr), fStratified(stratified)
245{
246 if (!CvSplitKFoldsExpr::Validate(fSplitExprString) && (splitExpr != TString(""))) {
247 Log() << kFATAL << "Split expression \"" << fSplitExprString << "\" is not a valid TFormula." << Endl;
248 }
249
250}
251
252////////////////////////////////////////////////////////////////////////////////
253/// \brief Prepares a DataSet for cross validation
254
256{
257 // Validate spectator
258 // fSpectatorIdx = GetSpectatorIndexForName(dsi, fSpectatorName);
259
260 if (fSplitExprString != TString("")) {
261 fSplitExpr = std::unique_ptr<CvSplitKFoldsExpr>(new CvSplitKFoldsExpr(dsi, fSplitExprString));
262 }
263
264 // No need to do it again if the sets have already been split.
265 if (fMakeFoldDataSet) {
266 Log() << kINFO << "Splitting in k-folds has been already done" << Endl;
267 return;
268 }
269
270 fMakeFoldDataSet = kTRUE;
271
272 UInt_t numClasses = dsi.GetNClasses();
273
274 // Get the original event vectors for testing and training from the dataset.
275 std::vector<Event *> trainData = dsi.GetDataSet()->GetEventCollection(Types::kTraining);
276 std::vector<Event *> testData = dsi.GetDataSet()->GetEventCollection(Types::kTesting);
277
278 // Split the sets into the number of folds.
279 fTrainEvents = SplitSets(trainData, fNumFolds, numClasses);
280 fTestEvents = SplitSets(testData, fNumFolds, numClasses);
281}
282
283////////////////////////////////////////////////////////////////////////////////
284/// \brief Generates a vector of fold assignments
285/// \param[in] nEntries Number of events in range
286/// \param[in] numFolds Number of folds to split data into
287/// \param[in] seed Random seed
288///
289/// Randomly assigns events to `numFolds` folds. Each fold will hold at most
290/// `nEntries / numFolds + 1` events.
291///
292
293std::vector<UInt_t> TMVA::CvSplitKFolds::GetEventIndexToFoldMapping(UInt_t nEntries, UInt_t numFolds, UInt_t seed)
294{
295 // Generate assignment of the pattern `0, 1, 2, 0, 1, 2, 0, 1 ...` for
296 // `numFolds = 3`.
297 std::vector<UInt_t> fOrigToFoldMapping;
298 fOrigToFoldMapping.reserve(nEntries);
299
300 for (UInt_t iEvent = 0; iEvent < nEntries; ++iEvent) {
301 fOrigToFoldMapping.push_back(iEvent % numFolds);
302 }
303
304 // Shuffle assignment
306 std::shuffle(fOrigToFoldMapping.begin(), fOrigToFoldMapping.end(), rng);
307
308 return fOrigToFoldMapping;
309}
310
311
312////////////////////////////////////////////////////////////////////////////////
313/// \brief Split sets for into k-folds
314/// \param[in] oldSet Original, unsplit, events
315/// \param[in] numFolds Number of folds to split data into
316///
317
318std::vector<std::vector<TMVA::Event *>>
319TMVA::CvSplitKFolds::SplitSets(std::vector<TMVA::Event *> &oldSet, UInt_t numFolds, UInt_t numClasses)
320{
321 const ULong64_t nEntries = oldSet.size();
322 const ULong64_t foldSize = nEntries / numFolds;
323
324 std::vector<std::vector<Event *>> tempSets;
325 tempSets.reserve(fNumFolds);
326 for (UInt_t iFold = 0; iFold < numFolds; ++iFold) {
327 tempSets.emplace_back();
328 tempSets.at(iFold).reserve(foldSize);
329 }
330
331 Bool_t useSplitExpr = !(fSplitExpr == nullptr || fSplitExprString == "");
332
333 if (useSplitExpr) {
334 // Deterministic split
335 for (ULong64_t i = 0; i < nEntries; i++) {
336 TMVA::Event *ev = oldSet[i];
337 UInt_t iFold = fSplitExpr->Eval(numFolds, ev);
338 tempSets.at((UInt_t)iFold).push_back(ev);
339 }
340 } else {
341 if(!fStratified){
342 // Random split
343 std::vector<UInt_t> fOrigToFoldMapping;
344 fOrigToFoldMapping = GetEventIndexToFoldMapping(nEntries, numFolds, fSeed);
345
346 for (UInt_t iEvent = 0; iEvent < nEntries; ++iEvent) {
347 UInt_t iFold = fOrigToFoldMapping[iEvent];
348 TMVA::Event *ev = oldSet[iEvent];
349 tempSets.at(iFold).push_back(ev);
350
351 fEventToFoldMapping[ev] = iFold;
352 }
353 } else {
354 // Stratified Split
355 std::vector<std::vector<TMVA::Event *>> oldSets;
356 oldSets.reserve(numClasses);
357
358 for(UInt_t iClass = 0; iClass < numClasses; iClass++){
359 oldSets.emplace_back();
360 //find a way to get number of events in each class
361 oldSets.reserve(nEntries);
362 }
363
364 for(UInt_t iEvent = 0; iEvent < nEntries; ++iEvent){
365 // check the class of event and add to its vector of events
366 TMVA::Event *ev = oldSet[iEvent];
367 UInt_t iClass = ev->GetClass();
368 oldSets.at(iClass).push_back(ev);
369 }
370
371 for(UInt_t i = 0; i<numClasses; ++i){
372 // Shuffle each vector individually
374 std::shuffle(oldSets.at(i).begin(), oldSets.at(i).end(), rng);
375 }
376
377 for(UInt_t i = 0; i<numClasses; ++i) {
378 std::vector<UInt_t> fOrigToFoldMapping;
379 fOrigToFoldMapping = GetEventIndexToFoldMapping(oldSets.at(i).size(), numFolds, fSeed);
380
381 for (UInt_t iEvent = 0; iEvent < oldSets.at(i).size(); ++iEvent) {
382 UInt_t iFold = fOrigToFoldMapping[iEvent];
383 TMVA::Event *ev = oldSets.at(i)[iEvent];
384 tempSets.at(iFold).push_back(ev);
385 fEventToFoldMapping[ev] = iFold;
386 }
387 }
388 }
389 }
390 return tempSets;
391}
const Bool_t kFALSE
Definition RtypesCore.h:92
double Double_t
Definition RtypesCore.h:59
unsigned long long ULong64_t
Definition RtypesCore.h:74
const Bool_t kTRUE
Definition RtypesCore.h:91
#define ClassImp(name)
Definition Rtypes.h:364
char name[80]
Definition TGX11.cxx:110
The Formula class.
Definition TFormula.h:87
const char * GetParName(Int_t ipar) const
Return parameter name given by integer.
Bool_t IsValid() const
Definition TFormula.h:241
Int_t GetNpar() const
Definition TFormula.h:230
MsgLogger & Log() const
Int_t fIdxFormulaParNumFolds
Maps parameter indices in splitExpr to their spectator index in the datasetinfo.
Definition CvSplit.h:81
UInt_t Eval(UInt_t numFolds, const Event *ev)
Definition CvSplit.cxx:164
std::vector< std::pair< Int_t, Int_t > > fFormulaParIdxToDsiSpecIdx
Definition CvSplit.h:80
UInt_t GetSpectatorIndexForName(DataSetInfo &dsi, TString name)
Definition CvSplit.cxx:206
static Bool_t Validate(TString expr)
Definition CvSplit.cxx:198
CvSplitKFoldsExpr(DataSetInfo &dsi, TString expr)
Definition CvSplit.cxx:139
TFormula fSplitFormula
Expression used to split data into folds. Should output values between 0 and numFolds-1 inclusive.
Definition CvSplit.h:83
DataSetInfo & fDsi
Definition CvSplit.h:77
TString fSplitExpr
Keeps track of the index of reserved par "NumFolds" in splitExpr.
Definition CvSplit.h:82
std::vector< UInt_t > GetEventIndexToFoldMapping(UInt_t nEntries, UInt_t numFolds, UInt_t seed=100)
Generates a vector of fold assignments.
Definition CvSplit.cxx:293
void MakeKFoldDataSet(DataSetInfo &dsi) override
Prepares a DataSet for cross validation.
Definition CvSplit.cxx:255
std::vector< std::vector< Event * > > SplitSets(std::vector< TMVA::Event * > &oldSet, UInt_t numFolds, UInt_t numClasses)
Split sets into k-folds.
Definition CvSplit.cxx:319
TString fSplitExprString
Definition CvSplit.h:108
CvSplitKFolds(UInt_t numFolds, TString splitExpr="", Bool_t stratified=kTRUE, UInt_t seed=100)
Splits a dataset into k folds, ready for use in cross validation.
Definition CvSplit.cxx:243
virtual void RecombineKFoldDataSet(DataSetInfo &dsi, Types::ETreeType tt=Types::kTraining)
Definition CvSplit.cxx:114
virtual void PrepareFoldDataSet(DataSetInfo &dsi, UInt_t foldNumber, Types::ETreeType tt)
Set training and test set vectors of dataset described by dsi.
Definition CvSplit.cxx:57
CvSplit(UInt_t numFolds)
Definition CvSplit.cxx:38
Class that contains all the data information.
Definition DataSetInfo.h:62
std::vector< VariableInfo > & GetSpectatorInfos()
UInt_t GetNClasses() const
DataSet * GetDataSet() const
returns data set
void SetEventCollection(std::vector< Event * > *, Types::ETreeType, Bool_t deleteEvents=true)
Sets the event collection (by DataSetFactory)
Definition DataSet.cxx:250
const std::vector< Event * > & GetEventCollection(Types::ETreeType type=Types::kMaxTreeType) const
Definition DataSet.h:216
Float_t GetSpectator(UInt_t ivar) const
return spectator content
Definition Event.cxx:261
UInt_t GetClass() const
Definition Event.h:86
@ kTraining
Definition Types.h:145
Class for type info of MVA input variable.
const TString & GetLabel() const
const TString & GetExpression() const
virtual const char * GetName() const
Returns name of object.
Definition TNamed.h:47
Basic string class.
Definition TString.h:136
const char * Data() const
Definition TString.h:369
MsgLogger & Endl(MsgLogger &ml)
Definition MsgLogger.h:158
auto * tt
Definition textangle.C:16
static uint64_t sum(uint64_t i)
Definition Factory.cxx:2345