59 if (foldNumber >= fNumFolds) {
60 Log() << kFATAL <<
"DataSet prepared for \"" << fNumFolds <<
"\" folds, requested fold \"" << foldNumber
61 <<
"\" is outside of range." <<
Endl;
65 auto prepareDataSetInternal = [
this, &dsi, foldNumber](std::vector<std::vector<Event *>> vec) {
66 UInt_t numFolds = fTrainEvents.size();
69 UInt_t nTotal = std::accumulate(vec.begin(), vec.end(), 0,
70 [&](
UInt_t sum, std::vector<TMVA::Event *>
v) { return sum + v.size(); });
72 UInt_t nTrain = nTotal - vec.at(foldNumber).size();
73 UInt_t nTest = vec.at(foldNumber).size();
75 std::vector<Event *> tempTrain;
76 std::vector<Event *> tempTest;
78 tempTrain.reserve(nTrain);
79 tempTest.reserve(nTest);
82 for (
UInt_t i = 0; i < numFolds; ++i) {
83 if (i == foldNumber) {
87 tempTrain.insert(tempTrain.end(), vec.at(i).begin(), vec.at(i).end());
91 tempTest.insert(tempTest.end(), vec.at(foldNumber).begin(), vec.at(foldNumber).end());
93 Log() << kDEBUG <<
"Fold prepared, num events in training set: " << tempTrain.size() <<
Endl;
94 Log() << kDEBUG <<
"Fold prepared, num events in test set: " << tempTest.size() <<
Endl;
102 prepareDataSetInternal(fTrainEvents);
104 prepareDataSetInternal(fTestEvents);
106 Log() << kFATAL <<
"PrepareFoldDataSet can only work with training and testing data sets." << std::endl;
117 Log() << kFATAL <<
"Only kTraining is supported for CvSplit::RecombineKFoldDataSet currently." << std::endl;
120 std::vector<Event *> *tempVec =
new std::vector<Event *>;
122 for (
UInt_t i = 0; i < fNumFolds; ++i) {
123 tempVec->insert(tempVec->end(), fTrainEvents.at(i).begin(), fTrainEvents.at(i).end());
140 : fDsi(dsi), fIdxFormulaParNumFolds(
std::numeric_limits<
UInt_t>::max()), fSplitFormula(
"", expr),
141 fParValues(fSplitFormula.GetNpar())
144 throw std::runtime_error(
"Split expression \"" + std::string(
fSplitExpr.
Data()) +
"\" is not a valid TFormula.");
152 if (
name ==
"NumFolds" or
name ==
"numFolds") {
166 for (
auto &p : fFormulaParIdxToDsiSpecIdx) {
167 auto iFormulaPar = p.first;
168 auto iSpectator = p.second;
170 fParValues.at(iFormulaPar) = ev->
GetSpectator(iSpectator);
173 if (fIdxFormulaParNumFolds < fSplitFormula.GetNpar()) {
174 fParValues[fIdxFormulaParNumFolds] = numFolds;
180 Double_t iFold_d = fSplitFormula.EvalPar(
nullptr, &fParValues[0]);
183 throw std::runtime_error(
"Output of splitExpr must be non-negative.");
186 UInt_t iFold = std::lround(iFold_d);
187 if (iFold >= numFolds) {
188 throw std::runtime_error(
"Output of splitExpr should be a non-negative"
189 "integer between 0 and numFolds-1 inclusive.");
210 for (
UInt_t iSpectator = 0; iSpectator < spectatorInfos.size(); ++iSpectator) {
221 throw std::runtime_error(
"Spectator \"" + std::string(
name.Data()) +
"\" not found.");
244 :
CvSplit(numFolds), fSeed(seed), fSplitExprString(splitExpr), fStratified(stratified)
260 if (fSplitExprString !=
TString(
"")) {
261 fSplitExpr = std::unique_ptr<CvSplitKFoldsExpr>(
new CvSplitKFoldsExpr(dsi, fSplitExprString));
265 if (fMakeFoldDataSet) {
266 Log() << kINFO <<
"Splitting in k-folds has been already done" <<
Endl;
270 fMakeFoldDataSet =
kTRUE;
279 fTrainEvents = SplitSets(trainData, fNumFolds, numClasses);
280 fTestEvents = SplitSets(testData, fNumFolds, numClasses);
297 std::vector<UInt_t> fOrigToFoldMapping;
298 fOrigToFoldMapping.reserve(nEntries);
300 for (
UInt_t iEvent = 0; iEvent < nEntries; ++iEvent) {
301 fOrigToFoldMapping.push_back(iEvent % numFolds);
306 std::shuffle(fOrigToFoldMapping.begin(), fOrigToFoldMapping.end(), rng);
308 return fOrigToFoldMapping;
318std::vector<std::vector<TMVA::Event *>>
321 const ULong64_t nEntries = oldSet.size();
322 const ULong64_t foldSize = nEntries / numFolds;
324 std::vector<std::vector<Event *>> tempSets;
325 tempSets.reserve(fNumFolds);
326 for (
UInt_t iFold = 0; iFold < numFolds; ++iFold) {
327 tempSets.emplace_back();
328 tempSets.at(iFold).reserve(foldSize);
331 Bool_t useSplitExpr = not(fSplitExpr ==
nullptr or fSplitExprString ==
"");
335 for (
ULong64_t i = 0; i < nEntries; i++) {
337 UInt_t iFold = fSplitExpr->Eval(numFolds, ev);
338 tempSets.at((
UInt_t)iFold).push_back(ev);
343 std::vector<UInt_t> fOrigToFoldMapping;
344 fOrigToFoldMapping = GetEventIndexToFoldMapping(nEntries, numFolds, fSeed);
346 for (
UInt_t iEvent = 0; iEvent < nEntries; ++iEvent) {
347 UInt_t iFold = fOrigToFoldMapping[iEvent];
349 tempSets.at(iFold).push_back(ev);
351 fEventToFoldMapping[ev] = iFold;
355 std::vector<std::vector<TMVA::Event *>> oldSets;
356 oldSets.reserve(numClasses);
358 for(
UInt_t iClass = 0; iClass < numClasses; iClass++){
359 oldSets.emplace_back();
361 oldSets.reserve(nEntries);
364 for(
UInt_t iEvent = 0; iEvent < nEntries; ++iEvent){
368 oldSets.at(iClass).push_back(ev);
371 for(
UInt_t i = 0; i<numClasses; ++i){
374 std::shuffle(oldSets.at(i).begin(), oldSets.at(i).end(), rng);
377 for(
UInt_t i = 0; i<numClasses; ++i) {
378 std::vector<UInt_t> fOrigToFoldMapping;
379 fOrigToFoldMapping = GetEventIndexToFoldMapping(oldSets.at(i).size(), numFolds, fSeed);
381 for (
UInt_t iEvent = 0; iEvent < oldSets.at(i).size(); ++iEvent) {
382 UInt_t iFold = fOrigToFoldMapping[iEvent];
384 tempSets.at(iFold).push_back(ev);
385 fEventToFoldMapping[ev] = iFold;
unsigned long long ULong64_t
Int_t fIdxFormulaParNumFolds
Maps parameter indices in splitExpr to their spectator index in the DataSetInfo.
UInt_t Eval(UInt_t numFolds, const Event *ev)
std::vector< std::pair< Int_t, Int_t > > fFormulaParIdxToDsiSpecIdx
UInt_t GetSpectatorIndexForName(DataSetInfo &dsi, TString name)
static Bool_t Validate(TString expr)
CvSplitKFoldsExpr(DataSetInfo &dsi, TString expr)
TFormula fSplitFormula
Expression used to split data into folds. Should output integer values between 0 and numFolds-1 (inclusive).
TString fSplitExpr
Keeps track of the index of the reserved parameter "NumFolds" in splitExpr.
std::vector< UInt_t > GetEventIndexToFoldMapping(UInt_t nEntries, UInt_t numFolds, UInt_t seed=100)
Generates a vector of fold assignments.
void MakeKFoldDataSet(DataSetInfo &dsi) override
Prepares a DataSet for cross validation.
std::vector< std::vector< Event * > > SplitSets(std::vector< TMVA::Event * > &oldSet, UInt_t numFolds, UInt_t numClasses)
Splits a set of events into k folds.
CvSplitKFolds(UInt_t numFolds, TString splitExpr="", Bool_t stratified=kTRUE, UInt_t seed=100)
Splits a dataset into k folds, ready for use in cross validation.
virtual void RecombineKFoldDataSet(DataSetInfo &dsi, Types::ETreeType tt=Types::kTraining)
virtual void PrepareFoldDataSet(DataSetInfo &dsi, UInt_t foldNumber, Types::ETreeType tt)
Sets the training and test set vectors of the dataset described by dsi.
Class that contains all the data information.
std::vector< VariableInfo > & GetSpectatorInfos()
UInt_t GetNClasses() const
DataSet * GetDataSet() const
returns data set
void SetEventCollection(std::vector< Event * > *, Types::ETreeType, Bool_t deleteEvents=true)
Sets the event collection (by DataSetFactory)
const std::vector< Event * > & GetEventCollection(Types::ETreeType type=Types::kMaxTreeType) const
Float_t GetSpectator(UInt_t ivar) const
return spectator content
Class for type info of MVA input variable.
const TString & GetLabel() const
const TString & GetExpression() const
virtual const char * GetName() const
Returns name of object.
const char * Data() const
MsgLogger & Endl(MsgLogger &ml)
static long int sum(long int i)