Logo ROOT   6.08/07
Reference Guide
DataLoader.cxx
Go to the documentation of this file.
1 // @(#)root/tmva $Id$
2 // Author: Omar Zapata
3 // Mentors: Lorenzo Moneta, Sergei Gleyzer
4 //NOTE: Based on TMVA::Factory
5 
6 /**********************************************************************************
7  * Project: TMVA - a Root-integrated toolkit for multivariate data analysis *
8  * Package: TMVA *
9  * Class : DataLoader *
10  * Web : http://tmva.sourceforge.net *
11  * *
12  * Description: *
13  * This is a class to load datasets into every booked method *
14  * *
15  * Authors (alphabetical): *
16  * Lorenzo Moneta <Lorenzo.Moneta@cern.ch> - CERN, Switzerland *
17  * Omar Zapata <Omar.Zapata@cern.ch> - ITM/UdeA, Colombia *
18  * Sergei Gleyzer<sergei.gleyzer@cern.ch> - CERN, Switzerland *
19  * *
20  * Copyright (c) 2005-2015: *
21  * CERN, Switzerland *
22  * ITM/UdeA, Colombia *
23  * *
24  * Redistribution and use in source and binary forms, with or without *
25  * modification, are permitted according to the terms listed in LICENSE *
26  * (http://tmva.sourceforge.net/LICENSE) *
27  **********************************************************************************/
28 
29 
30 #include "TROOT.h"
31 #include "TFile.h"
32 #include "TTree.h"
33 #include "TLeaf.h"
34 #include "TEventList.h"
35 #include "TH2.h"
36 #include "TText.h"
37 #include "TStyle.h"
38 #include "TMatrixF.h"
39 #include "TMatrixDSym.h"
40 #include "TPaletteAxis.h"
41 #include "TPrincipal.h"
42 #include "TMath.h"
43 #include "TObjString.h"
44 #include "TRandom3.h"
45 
46 #include <string.h>
47 
48 #include "TMVA/Configurable.h"
49 #include "TMVA/DataLoader.h"
50 #include "TMVA/Config.h"
51 #include "TMVA/Tools.h"
52 #include "TMVA/Ranking.h"
53 #include "TMVA/DataSet.h"
54 #include "TMVA/IMethod.h"
55 #include "TMVA/MethodBase.h"
56 #include "TMVA/DataInputHandler.h"
57 #include "TMVA/DataSetManager.h"
58 #include "TMVA/DataSetInfo.h"
59 #include "TMVA/MethodBoost.h"
60 #include "TMVA/MethodCategory.h"
61 
62 #include "TMVA/VariableInfo.h"
69 
70 
72 #include "TMVA/ResultsRegression.h"
73 #include "TMVA/ResultsMulticlass.h"
74 #include "TMVA/Types.h"
75 
76 
78 
79 
80 //_______________________________________________________________________
82 : Configurable( ),
83  fDataSetManager ( NULL ), //DSMTEST
84  fDataInputHandler ( new DataInputHandler ),
85  fTransformations ( "I" ),
86  fVerbose ( kFALSE ),
87  fDataAssignType ( kAssignEvents ),
88  fATreeEvent (0),
89  fMakeFoldDataSet ( kFALSE )
90 {
91  fDataSetManager = new DataSetManager( *fDataInputHandler ); // DSMTEST
92  SetName(thedlName.Data());
93  fLogger->SetSource("DataLoader");
94 }
95 
96 
97 //_______________________________________________________________________
99 {
100  // destructor
101 
102  std::vector<TMVA::VariableTransformBase*>::iterator trfIt = fDefaultTrfs.begin();
103  for (;trfIt != fDefaultTrfs.end(); trfIt++) delete (*trfIt);
104 
105  delete fDataInputHandler;
106 
107  // destroy singletons
108  // DataSetManager::DestroyInstance(); // DSMTEST replaced by following line
109  delete fDataSetManager; // DSMTEST
110 
111  // problem with call of REGISTER_METHOD macro ...
112  // ClassifierDataLoader::DestroyInstance();
113  // Types::DestroyInstance();
116 }
117 
118 
119 //_______________________________________________________________________
121 {
122  return fDataSetManager->AddDataSetInfo(dsi); // DSMTEST
123 }
124 
125 //_______________________________________________________________________
127 {
128  DataSetInfo* dsi = fDataSetManager->GetDataSetInfo(dsiName); // DSMTEST
129 
130  if (dsi!=0) return *dsi;
131 
132  return fDataSetManager->AddDataSetInfo(*(new DataSetInfo(dsiName))); // DSMTEST
133 }
134 
135 //_______________________________________________________________________
137 {
138  return DefaultDataSetInfo(); // DSMTEST
139 }
140 
141 ////////////////////////////////////////////////////////////////////////////////
142 /// Transforms the variables and return a new DataLoader with the transformed
143 /// variables
144 
146 {
147  TString trOptions = "0";
148  TString trName = "None";
149  if (trafoDefinition.Contains("(")) {
150 
151  // contains transformation parameters
152  Ssiz_t parStart = trafoDefinition.Index( "(" );
153  Ssiz_t parLen = trafoDefinition.Index( ")", parStart )-parStart+1;
154 
155  trName = trafoDefinition(0,parStart);
156  trOptions = trafoDefinition(parStart,parLen);
157  trOptions.Remove(parLen-1,1);
158  trOptions.Remove(0,1);
159  }
160  else
161  trName = trafoDefinition;
162 
163  VarTransformHandler* handler = new VarTransformHandler(this);
164  // variance threshold variable transformation
165  if (trName == "VT") {
166 
167  // find threshold value from given input
168  Double_t threshold = 0.0;
169  if (!trOptions.IsFloat()){
170  Log() << kFATAL << " VT transformation must be passed a floating threshold value" << Endl;
171  return this;
172  }
173  else
174  threshold = trOptions.Atof();
175  TMVA::DataLoader *transformedLoader = handler->VarianceThreshold(threshold);
176  return transformedLoader;
177  }
178  else {
179  Log() << kFATAL << "Incorrect transformation string provided, please check" << Endl;
180  }
181  Log() << kINFO << "No transformation applied, returning original loader" << Endl;
182  return this;
183 }
184 
185 // ________________________________________________
186 // the next functions are to assign events directly
187 
188 //_______________________________________________________________________
190 {
191  // create the data assignment tree (for event-wise data assignment by user)
192  TTree * assignTree = new TTree( name, name );
193  assignTree->SetDirectory(0);
194  assignTree->Branch( "type", &fATreeType, "ATreeType/I" );
195  assignTree->Branch( "weight", &fATreeWeight, "ATreeWeight/F" );
196 
197  std::vector<VariableInfo>& vars = DefaultDataSetInfo().GetVariableInfos();
198  std::vector<VariableInfo>& tgts = DefaultDataSetInfo().GetTargetInfos();
199  std::vector<VariableInfo>& spec = DefaultDataSetInfo().GetSpectatorInfos();
200 
201  if (fATreeEvent.size()==0) fATreeEvent.resize(vars.size()+tgts.size()+spec.size());
202  // add variables
203  for (UInt_t ivar=0; ivar<vars.size(); ivar++) {
204  TString vname = vars[ivar].GetExpression();
205  assignTree->Branch( vname, &fATreeEvent[ivar], vname + "/F" );
206  }
207  // add targets
208  for (UInt_t itgt=0; itgt<tgts.size(); itgt++) {
209  TString vname = tgts[itgt].GetExpression();
210  assignTree->Branch( vname, &fATreeEvent[vars.size()+itgt], vname + "/F" );
211  }
212  // add spectators
213  for (UInt_t ispc=0; ispc<spec.size(); ispc++) {
214  TString vname = spec[ispc].GetExpression();
215  assignTree->Branch( vname, &fATreeEvent[vars.size()+tgts.size()+ispc], vname + "/F" );
216  }
217  return assignTree;
218 }
219 
220 //_______________________________________________________________________
221 void TMVA::DataLoader::AddSignalTrainingEvent( const std::vector<Double_t>& event, Double_t weight )
222 {
223  // add signal training event
224  AddEvent( "Signal", Types::kTraining, event, weight );
225 }
226 
227 //_______________________________________________________________________
228 void TMVA::DataLoader::AddSignalTestEvent( const std::vector<Double_t>& event, Double_t weight )
229 {
230  // add signal testing event
231  AddEvent( "Signal", Types::kTesting, event, weight );
232 }
233 
234 //_______________________________________________________________________
235 void TMVA::DataLoader::AddBackgroundTrainingEvent( const std::vector<Double_t>& event, Double_t weight )
236 {
237  // add signal training event
238  AddEvent( "Background", Types::kTraining, event, weight );
239 }
240 
241 //_______________________________________________________________________
242 void TMVA::DataLoader::AddBackgroundTestEvent( const std::vector<Double_t>& event, Double_t weight )
243 {
244  // add signal training event
245  AddEvent( "Background", Types::kTesting, event, weight );
246 }
247 
248 //_______________________________________________________________________
249 void TMVA::DataLoader::AddTrainingEvent( const TString& className, const std::vector<Double_t>& event, Double_t weight )
250 {
251  // add signal training event
252  AddEvent( className, Types::kTraining, event, weight );
253 }
254 
255 //_______________________________________________________________________
256 void TMVA::DataLoader::AddTestEvent( const TString& className, const std::vector<Double_t>& event, Double_t weight )
257 {
258  // add signal test event
259  AddEvent( className, Types::kTesting, event, weight );
260 }
261 
262 //_______________________________________________________________________
264  const std::vector<Double_t>& event, Double_t weight )
265 {
266  // add event
267  // vector event : the order of values is: variables + targets + spectators
268  ClassInfo* theClass = DefaultDataSetInfo().AddClass(className); // returns class (creates it if necessary)
269  UInt_t clIndex = theClass->GetNumber();
270 
271 
272  // set analysistype to "kMulticlass" if more than two classes and analysistype == kNoAnalysisType
273  if( fAnalysisType == Types::kNoAnalysisType && DefaultDataSetInfo().GetNClasses() > 2 )
275 
276 
277  if (clIndex>=fTrainAssignTree.size()) {
278  fTrainAssignTree.resize(clIndex+1, 0);
279  fTestAssignTree.resize(clIndex+1, 0);
280  }
281 
282  if (fTrainAssignTree[clIndex]==0) { // does not exist yet
283  fTrainAssignTree[clIndex] = CreateEventAssignTrees( Form("TrainAssignTree_%s", className.Data()) );
284  fTestAssignTree[clIndex] = CreateEventAssignTrees( Form("TestAssignTree_%s", className.Data()) );
285  }
286 
287  fATreeType = clIndex;
288  fATreeWeight = weight;
289  for (UInt_t ivar=0; ivar<event.size(); ivar++) fATreeEvent[ivar] = event[ivar];
290 
291  if(tt==Types::kTraining) fTrainAssignTree[clIndex]->Fill();
292  else fTestAssignTree[clIndex]->Fill();
293 
294 }
295 
296 //_______________________________________________________________________
298 {
299  //
300  return fTrainAssignTree[clIndex]!=0;
301 }
302 
303 //_______________________________________________________________________
305 {
306  // assign event-wise local trees to data set
307  UInt_t size = fTrainAssignTree.size();
308  for(UInt_t i=0; i<size; i++) {
309  if(!UserAssignEvents(i)) continue;
310  const TString& className = DefaultDataSetInfo().GetClassInfo(i)->GetName();
311  SetWeightExpression( "weight", className );
312  AddTree(fTrainAssignTree[i], className, 1.0, TCut(""), Types::kTraining );
313  AddTree(fTestAssignTree[i], className, 1.0, TCut(""), Types::kTesting );
314  }
315 }
316 
317 //_______________________________________________________________________
318 void TMVA::DataLoader::AddTree( TTree* tree, const TString& className, Double_t weight,
319  const TCut& cut, const TString& treetype )
320 {
321  // number of signal events (used to compute significance)
323  TString tmpTreeType = treetype; tmpTreeType.ToLower();
324  if (tmpTreeType.Contains( "train" ) && tmpTreeType.Contains( "test" )) tt = Types::kMaxTreeType;
325  else if (tmpTreeType.Contains( "train" )) tt = Types::kTraining;
326  else if (tmpTreeType.Contains( "test" )) tt = Types::kTesting;
327  else {
328  Log() << kFATAL << "<AddTree> cannot interpret tree type: \"" << treetype
329  << "\" should be \"Training\" or \"Test\" or \"Training and Testing\"" << Endl;
330  }
331  AddTree( tree, className, weight, cut, tt );
332 }
333 
334 //_______________________________________________________________________
335 void TMVA::DataLoader::AddTree( TTree* tree, const TString& className, Double_t weight,
336  const TCut& cut, Types::ETreeType tt )
337 {
338  if(!tree)
339  Log() << kFATAL << "Tree does not exist (empty pointer)." << Endl;
340 
341  DefaultDataSetInfo().AddClass( className );
342 
343  // set analysistype to "kMulticlass" if more than two classes and analysistype == kNoAnalysisType
344  if( fAnalysisType == Types::kNoAnalysisType && DefaultDataSetInfo().GetNClasses() > 2 )
346 
347  Log() << kINFO<< "Add Tree " << tree->GetName() << " of type " << className
348  << " with " << tree->GetEntries() << " events" << Endl;
349  DataInput().AddTree( tree, className, weight, cut, tt );
350 }
351 
352 //_______________________________________________________________________
354 {
355  // number of signal events (used to compute significance)
356  AddTree( signal, "Signal", weight, TCut(""), treetype );
357 }
358 
359 //_______________________________________________________________________
361 {
362  // add signal tree from text file
363 
364  // create trees from these ascii files
365  TTree* signalTree = new TTree( "TreeS", "Tree (S)" );
366  signalTree->ReadFile( datFileS );
367 
368  Log() << kINFO << "Create TTree objects from ASCII input files ... \n- Signal file : \""
369  << datFileS << Endl;
370 
371  // number of signal events (used to compute significance)
372  AddTree( signalTree, "Signal", weight, TCut(""), treetype );
373 }
374 
375 //_______________________________________________________________________
377 {
378  AddTree( signal, "Signal", weight, TCut(""), treetype );
379 }
380 
381 //_______________________________________________________________________
383 {
384  // number of signal events (used to compute significance)
385  AddTree( signal, "Background", weight, TCut(""), treetype );
386 }
387 //_______________________________________________________________________
389 {
390  // add background tree from text file
391 
392  // create trees from these ascii files
393  TTree* bkgTree = new TTree( "TreeB", "Tree (B)" );
394  bkgTree->ReadFile( datFileB );
395 
396  Log() << kINFO << "Create TTree objects from ASCII input files ... \n- Background file : \""
397  << datFileB << Endl;
398 
399  // number of signal events (used to compute significance)
400  AddTree( bkgTree, "Background", weight, TCut(""), treetype );
401 }
402 
403 //_______________________________________________________________________
405 {
406  AddTree( signal, "Background", weight, TCut(""), treetype );
407 }
408 
409 //_______________________________________________________________________
411 {
412  AddTree( tree, "Signal", weight );
413 }
414 
415 //_______________________________________________________________________
417 {
418  AddTree( tree, "Background", weight );
419 }
420 
421 //_______________________________________________________________________
422 void TMVA::DataLoader::SetTree( TTree* tree, const TString& className, Double_t weight )
423 {
424  // set background tree
425  AddTree( tree, className, weight, TCut(""), Types::kMaxTreeType );
426 }
427 
428 //_______________________________________________________________________
430  Double_t signalWeight, Double_t backgroundWeight )
431 {
432  // define the input trees for signal and background; no cuts are applied
433  AddTree( signal, "Signal", signalWeight, TCut(""), Types::kMaxTreeType );
434  AddTree( background, "Background", backgroundWeight, TCut(""), Types::kMaxTreeType );
435 }
436 
437 //_______________________________________________________________________
438 void TMVA::DataLoader::SetInputTrees( const TString& datFileS, const TString& datFileB,
439  Double_t signalWeight, Double_t backgroundWeight )
440 {
441  DataInput().AddTree( datFileS, "Signal", signalWeight );
442  DataInput().AddTree( datFileB, "Background", backgroundWeight );
443 }
444 
445 //_______________________________________________________________________
446 void TMVA::DataLoader::SetInputTrees( TTree* inputTree, const TCut& SigCut, const TCut& BgCut )
447 {
448  // define the input trees for signal and background from single input tree,
449  // containing both signal and background events distinguished by the type
450  // identifiers: SigCut and BgCut
451  AddTree( inputTree, "Signal", 1.0, SigCut, Types::kMaxTreeType );
452  AddTree( inputTree, "Background", 1.0, BgCut , Types::kMaxTreeType );
453 }
454 
455 //_______________________________________________________________________
456 void TMVA::DataLoader::AddVariable( const TString& expression, const TString& title, const TString& unit,
457  char type, Double_t min, Double_t max )
458 {
459  // user inserts discriminating variable in data set info
460  DefaultDataSetInfo().AddVariable( expression, title, unit, min, max, type );
461 }
462 
463 //_______________________________________________________________________
464 void TMVA::DataLoader::AddVariable( const TString& expression, char type,
465  Double_t min, Double_t max )
466 {
467  // user inserts discriminating variable in data set info
468  DefaultDataSetInfo().AddVariable( expression, "", "", min, max, type );
469 }
470 
471 //_______________________________________________________________________
472 void TMVA::DataLoader::AddTarget( const TString& expression, const TString& title, const TString& unit,
473  Double_t min, Double_t max )
474 {
475  // user inserts target in data set info
476 
479 
480  DefaultDataSetInfo().AddTarget( expression, title, unit, min, max );
481 }
482 
483 //_______________________________________________________________________
484 void TMVA::DataLoader::AddSpectator( const TString& expression, const TString& title, const TString& unit,
485  Double_t min, Double_t max )
486 {
487  // user inserts target in data set info
488  DefaultDataSetInfo().AddSpectator( expression, title, unit, min, max );
489 }
490 
491 //_______________________________________________________________________
493 {
494  // default creation
495  return AddDataSet( fName );
496 }
497 
498 //_______________________________________________________________________
499 void TMVA::DataLoader::SetInputVariables( std::vector<TString>* theVariables )
500 {
501  // fill input variables in data set
502  for (std::vector<TString>::iterator it=theVariables->begin();
503  it!=theVariables->end(); it++) AddVariable(*it);
504 }
505 
506 //_______________________________________________________________________
508 {
509  DefaultDataSetInfo().SetWeightExpression(variable, "Signal");
510 }
511 
512 //_______________________________________________________________________
514 {
515  DefaultDataSetInfo().SetWeightExpression(variable, "Background");
516 }
517 
518 //_______________________________________________________________________
519 void TMVA::DataLoader::SetWeightExpression( const TString& variable, const TString& className )
520 {
521  //Log() << kWarning << DefaultDataSetInfo().GetNClasses() /*fClasses.size()*/ << Endl;
522  if (className=="") {
523  SetSignalWeightExpression(variable);
525  }
526  else DefaultDataSetInfo().SetWeightExpression( variable, className );
527 }
528 
529 //_______________________________________________________________________
530 void TMVA::DataLoader::SetCut( const TString& cut, const TString& className ) {
531  SetCut( TCut(cut), className );
532 }
533 
534 //_______________________________________________________________________
535 void TMVA::DataLoader::SetCut( const TCut& cut, const TString& className )
536 {
537  DefaultDataSetInfo().SetCut( cut, className );
538 }
539 
540 //_______________________________________________________________________
541 void TMVA::DataLoader::AddCut( const TString& cut, const TString& className )
542 {
543  AddCut( TCut(cut), className );
544 }
545 
546 //_______________________________________________________________________
547 void TMVA::DataLoader::AddCut( const TCut& cut, const TString& className )
548 {
549  DefaultDataSetInfo().AddCut( cut, className );
550 }
551 
552 //_______________________________________________________________________
554  Int_t NsigTrain, Int_t NbkgTrain, Int_t NsigTest, Int_t NbkgTest,
555  const TString& otherOpt )
556 {
557  // prepare the training and test trees
559 
560  AddCut( cut );
561 
562  DefaultDataSetInfo().SetSplitOptions( Form("nTrain_Signal=%i:nTrain_Background=%i:nTest_Signal=%i:nTest_Background=%i:%s",
563  NsigTrain, NbkgTrain, NsigTest, NbkgTest, otherOpt.Data()) );
564 }
565 
566 //_______________________________________________________________________
568 {
569  // prepare the training and test trees
570  // kept for backward compatibility
572 
573  AddCut( cut );
574 
575  DefaultDataSetInfo().SetSplitOptions( Form("nTrain_Signal=%i:nTrain_Background=%i:nTest_Signal=%i:nTest_Background=%i:SplitMode=Random:EqualTrainSample:!V",
576  Ntrain, Ntrain, Ntest, Ntest) );
577 }
578 
579 //_______________________________________________________________________
581 {
582  // prepare the training and test trees
583  // -> same cuts for signal and background
585 
587  AddCut( cut );
589 }
590 
591 //_______________________________________________________________________
592 void TMVA::DataLoader::PrepareTrainingAndTestTree( TCut sigcut, TCut bkgcut, const TString& splitOpt )
593 {
594  // prepare the training and test trees
595 
596  // if event-wise data assignment, add local trees to dataset first
598 
599  //Log() << kINFO <<"Preparing trees for training and testing..."<< Endl;
600  AddCut( sigcut, "Signal" );
601  AddCut( bkgcut, "Background" );
602 
603  DefaultDataSetInfo().SetSplitOptions( splitOpt );
604 }
605 
606 //______________________________________________________________________
607 // Function required to split the training and testing datasets into a
608 // number of folds. Required by the CrossValidation and HyperParameterOptimisation
609 // classes. The option to split the training dataset into a training set and
610 // a validation set is implemented but not currently used.
611 void TMVA::DataLoader::MakeKFoldDataSet(UInt_t numberFolds, bool validationSet){
612 
613 
614  // No need to do it again if the sets have already been split.
615  if(fMakeFoldDataSet){
616  Log() << kInfo << "Splitting in k-folds has been already done" << Endl;
617  return;
618  }
619 
621 
622  // Get the original event vectors for testing and training from the dataset.
623  const std::vector<Event*> TrainingData = DefaultDataSetInfo().GetDataSet()->GetEventCollection(Types::kTraining);
624  const std::vector<Event*> TestingData = DefaultDataSetInfo().GetDataSet()->GetEventCollection(Types::kTesting);
625 
626  std::vector<Event*> TrainSigData;
627  std::vector<Event*> TrainBkgData;
628  std::vector<Event*> TestSigData;
629  std::vector<Event*> TestBkgData;
630 
631  // Split the testing and training sets into signal and background classes.
632  for(UInt_t i=0; i<TrainingData.size(); ++i){
633  if( strncmp( DefaultDataSetInfo().GetClassInfo( TrainingData.at(i)->GetClass() )->GetName(), "Signal", 6) == 0){ TrainSigData.push_back(TrainingData.at(i)); }
634  else if( strncmp( DefaultDataSetInfo().GetClassInfo( TrainingData.at(i)->GetClass() )->GetName(), "Background", 10) == 0){ TrainBkgData.push_back(TrainingData.at(i)); }
635  else{
636  Log() << kFATAL << "DataSets should only contain Signal and Background classes for classification, " << DefaultDataSetInfo().GetClassInfo( TrainingData.at(i)->GetClass() )->GetName() << " is not a recognised class" << Endl;
637  }
638  }
639 
640  for(UInt_t i=0; i<TestingData.size(); ++i){
641  if( strncmp( DefaultDataSetInfo().GetClassInfo( TestingData.at(i)->GetClass() )->GetName(), "Signal", 6) == 0){ TestSigData.push_back(TestingData.at(i)); }
642  else if( strncmp( DefaultDataSetInfo().GetClassInfo( TestingData.at(i)->GetClass() )->GetName(), "Background", 10) == 0){ TestBkgData.push_back(TestingData.at(i)); }
643  else{
644  Log() << kFATAL << "DataSets should only contain Signal and Background classes for classification, " << DefaultDataSetInfo().GetClassInfo( TrainingData.at(i)->GetClass() )->GetName() << " is not a recognised class" << Endl;
645  }
646  }
647 
648 
649  // Split the sets into the number of folds.
650  if(validationSet){
651  std::vector<std::vector<Event*>> tempSigEvents = SplitSets(TrainSigData,0,2);
652  std::vector<std::vector<Event*>> tempBkgEvents = SplitSets(TrainBkgData,0,2);
653  fTrainSigEvents = SplitSets(tempSigEvents.at(0),0,numberFolds);
654  fTrainBkgEvents = SplitSets(tempBkgEvents.at(0),0,numberFolds);
655  fValidSigEvents = SplitSets(tempSigEvents.at(1),0,numberFolds);
656  fValidBkgEvents = SplitSets(tempBkgEvents.at(1),0,numberFolds);
657  }
658  else{
659  fTrainSigEvents = SplitSets(TrainSigData,0,numberFolds);
660  fTrainBkgEvents = SplitSets(TrainBkgData,0,numberFolds);
661  }
662 
663  fTestSigEvents = SplitSets(TestSigData,0,numberFolds);
664  fTestBkgEvents = SplitSets(TestBkgData,0,numberFolds);
665 }
666 
667 //______________________________________________________________________
668 // Function for assigning the correct folds to the testing or training set.
670 
671  UInt_t numFolds = fTrainSigEvents.size();
672 
673  std::vector<Event*>* tempTrain = new std::vector<Event*>;
674  std::vector<Event*>* tempTest = new std::vector<Event*>;
675 
676  UInt_t nTrain = 0;
677  UInt_t nTest = 0;
678 
679  // Get the number of events so the memory can be reserved.
680  for(UInt_t i=0; i<numFolds; ++i){
681  if(tt == Types::kTraining){
682  if(i!=foldNumber){
683  nTrain += fTrainSigEvents.at(i).size();
684  nTrain += fTrainBkgEvents.at(i).size();
685  }
686  else{
687  nTest += fTrainSigEvents.at(i).size();
688  nTest += fTrainSigEvents.at(i).size();
689  }
690  }
691  else if(tt == Types::kValidation){
692  if(i!=foldNumber){
693  nTrain += fValidSigEvents.at(i).size();
694  nTrain += fValidBkgEvents.at(i).size();
695  }
696  else{
697  nTest += fValidSigEvents.at(i).size();
698  nTest += fValidSigEvents.at(i).size();
699  }
700  }
701  else if(tt == Types::kTesting){
702  if(i!=foldNumber){
703  nTrain += fTestSigEvents.at(i).size();
704  nTrain += fTestBkgEvents.at(i).size();
705  }
706  else{
707  nTest += fTestSigEvents.at(i).size();
708  nTest += fTestSigEvents.at(i).size();
709  }
710  }
711  }
712 
713  // Reserve memory before filling vectors
714  tempTrain->reserve(nTrain);
715  tempTest->reserve(nTest);
716 
717  // Fill vectors with correct folds for testing and training.
718  for(UInt_t j=0; j<numFolds; ++j){
719  if(tt == Types::kTraining){
720  if(j!=foldNumber){
721  tempTrain->insert(tempTrain->end(), fTrainSigEvents.at(j).begin(), fTrainSigEvents.at(j).end());
722  tempTrain->insert(tempTrain->end(), fTrainBkgEvents.at(j).begin(), fTrainBkgEvents.at(j).end());
723  }
724  else{
725  tempTest->insert(tempTest->end(), fTrainSigEvents.at(j).begin(), fTrainSigEvents.at(j).end());
726  tempTest->insert(tempTest->end(), fTrainBkgEvents.at(j).begin(), fTrainBkgEvents.at(j).end());
727  }
728  }
729  else if(tt == Types::kValidation){
730  if(j!=foldNumber){
731  tempTrain->insert(tempTrain->end(), fValidSigEvents.at(j).begin(), fValidSigEvents.at(j).end());
732  tempTrain->insert(tempTrain->end(), fValidBkgEvents.at(j).begin(), fValidBkgEvents.at(j).end());
733  }
734  else{
735  tempTest->insert(tempTest->end(), fValidSigEvents.at(j).begin(), fValidSigEvents.at(j).end());
736  tempTest->insert(tempTest->end(), fValidBkgEvents.at(j).begin(), fValidBkgEvents.at(j).end());
737  }
738  }
739  else if(tt == Types::kTesting){
740  if(j!=foldNumber){
741  tempTrain->insert(tempTrain->end(), fTestSigEvents.at(j).begin(), fTestSigEvents.at(j).end());
742  tempTrain->insert(tempTrain->end(), fTestBkgEvents.at(j).begin(), fTestBkgEvents.at(j).end());
743  }
744  else{
745  tempTest->insert(tempTest->end(), fTestSigEvents.at(j).begin(), fTestSigEvents.at(j).end());
746  tempTest->insert(tempTest->end(), fTestBkgEvents.at(j).begin(), fTestBkgEvents.at(j).end());
747  }
748  }
749  }
750 
751  // Assign the vectors of the events to rebuild the dataset
754 
755 }
756 
757 //______________________________________________________________________
758 // Splits the input vector in to equally sized randomly sampled folds.
759 std::vector<std::vector<TMVA::Event*>> TMVA::DataLoader::SplitSets(std::vector<TMVA::Event*>& oldSet, int seedNum, int numFolds){
760 
761  ULong64_t nEntries = oldSet.size();
762  ULong64_t foldSize = nEntries/numFolds;
763 
764  std::vector<std::vector<Event*>> tempSets;
765  tempSets.resize(numFolds);
766 
767  TRandom3 r(seedNum);
768 
769  ULong64_t inSet = 0;
770 
771  for(ULong64_t i=0; i<nEntries; i++){
772  bool inTree = false;
773  if(inSet == foldSize*numFolds){
774  break;
775  }
776  else{
777  while(!inTree){
778  int s = r.Integer(numFolds);
779  if(tempSets.at(s).size()<foldSize){
780  tempSets.at(s).push_back(oldSet.at(i));
781  inSet++;
782  inTree=true;
783  }
784  }
785  }
786  }
787 
788  return tempSets;
789 
790 }
791 
792 //_______________________________________________________________________
793 //Copy method use in VI and CV
795 {
796  TMVA::DataLoader* des=new TMVA::DataLoader(name);
797  DataLoaderCopy(des,this);
798  return des;
799 }
800 
801 //_______________________________________________________________________
803 {
804  //Loading Dataset from DataInputHandler for subseed
805  for( std::vector<TreeInfo>::const_iterator treeinfo=src->DataInput().Sbegin();treeinfo!=src->DataInput().Send();treeinfo++)
806  {
807  des->AddSignalTree( (*treeinfo).GetTree(), (*treeinfo).GetWeight(),(*treeinfo).GetTreeType());
808  }
809 
810  for( std::vector<TreeInfo>::const_iterator treeinfo=src->DataInput().Bbegin();treeinfo!=src->DataInput().Bend();treeinfo++)
811  {
812  des->AddBackgroundTree( (*treeinfo).GetTree(), (*treeinfo).GetWeight(),(*treeinfo).GetTreeType());
813  }
814 }
815 
816 //_______________________________________________________________________
818 {
819  //returns the correlation matrix of datasets
820  const TMatrixD * m = DefaultDataSetInfo().CorrelationMatrix(className);
822  "CorrelationMatrix"+className, "Correlation Matrix ("+className+")");
823 }
824 
825 
DataSetInfo * GetDataSetInfo(const TString &dsiName)
returns datasetinfo object for given name
void AddBackgroundTree(TTree *background, Double_t weight=1.0, Types::ETreeType treetype=Types::kMaxTreeType)
Definition: DataLoader.cxx:382
virtual const char * GetName() const
Returns name of object.
Definition: TNamed.h:51
DataSetManager * fDataSetManager
Definition: DataLoader.h:197
virtual ~DataLoader()
Definition: DataLoader.cxx:98
Random number generator class based on M. Matsumoto and T. Nishimura's Mersenne Twister algorithm.
Definition: TRandom3.h:29
std::vector< TreeInfo >::const_iterator Bend() const
MsgLogger & Endl(MsgLogger &ml)
Definition: MsgLogger.h:162
VariableInfo & AddTarget(const TString &expression, const TString &title, const TString &unit, Double_t min, Double_t max, Bool_t normalized=kTRUE, void *external=0)
add a target (can be a complex expression) to the set of targets used in the MV analysis
void AddTrainingEvent(const TString &className, const std::vector< Double_t > &event, Double_t weight)
Definition: DataLoader.cxx:249
std::vector< VariableInfo > & GetSpectatorInfos()
Definition: DataSetInfo.h:122
std::vector< TreeInfo >::const_iterator Bbegin() const
std::vector< TMVA::VariableTransformBase * > fDefaultTrfs
Definition: DataLoader.h:202
void SetCut(const TCut &cut, const TString &className)
set the cut for the classes
TMVA::DataLoader * VarianceThreshold(Double_t threshold)
Computes the variance of all the variables and returns a new DataLoader with only the variables whose variance is above the given threshold.
std::vector< std::vector< TMVA::Event * > > fTrainBkgEvents
Definition: DataLoader.h:218
MsgLogger & Log() const
Definition: Configurable.h:128
DataSetInfo & GetDataSetInfo()
Definition: DataLoader.cxx:136
Bool_t IsFloat() const
Returns kTRUE if string contains a floating point or integer number.
Definition: TString.cxx:1835
TTree * CreateEventAssignTrees(const TString &name)
Definition: DataLoader.cxx:189
DataSetInfo & DefaultDataSetInfo()
Definition: DataLoader.cxx:492
Ssiz_t Index(const char *pat, Ssiz_t i=0, ECaseCompare cmp=kExact) const
Definition: TString.h:582
DataLoader * VarTransform(TString trafoDefinition)
Transforms the variables and return a new DataLoader with the transformed variables.
Definition: DataLoader.cxx:145
Basic string class.
Definition: TString.h:137
void ToLower()
Change string to lower-case.
Definition: TString.cxx:1089
int Int_t
Definition: RtypesCore.h:41
void MakeKFoldDataSet(UInt_t numberFolds, bool validationSet=false)
Definition: DataLoader.cxx:611
bool Bool_t
Definition: RtypesCore.h:59
const Bool_t kFALSE
Definition: Rtypes.h:92
void DataLoaderCopy(TMVA::DataLoader *des, TMVA::DataLoader *src)
Definition: DataLoader.cxx:802
std::vector< TreeInfo >::const_iterator Send() const
void AddCut(const TCut &cut, const TString &className)
set the cut for the classes
void SetBackgroundTree(TTree *background, Double_t weight=1.0)
Definition: DataLoader.cxx:416
const std::vector< Event * > & GetEventCollection(Types::ETreeType type=Types::kMaxTreeType) const
Definition: DataSet.h:239
DataInputHandler * fDataInputHandler
Definition: DataLoader.h:200
Types::EAnalysisType fAnalysisType
Definition: DataLoader.h:228
void AddBackgroundTestEvent(const std::vector< Double_t > &event, Double_t weight=1.0)
Definition: DataLoader.cxx:242
TH2 * GetCorrelationMatrix(const TString &className)
Definition: DataLoader.cxx:817
std::vector< std::vector< TMVA::Event * > > fTestBkgEvents
Definition: DataLoader.h:222
void AddVariable(const TString &expression, const TString &title, const TString &unit, char type='F', Double_t min=0, Double_t max=0)
Definition: DataLoader.cxx:456
std::vector< TreeInfo >::const_iterator Sbegin() const
VariableInfo & AddSpectator(const TString &expression, const TString &title, const TString &unit, Double_t min, Double_t max, char type='F', Bool_t normalized=kTRUE, void *external=0)
add a spectator (can be a complex expression) to the set of spectator variables used in the MV analys...
TH2 * CreateCorrelationMatrixHist(const TMatrixD *m, const TString &hName, const TString &hTitle) const
TText * tt
Definition: textangle.C:16
void AddTestEvent(const TString &className, const std::vector< Double_t > &event, Double_t weight)
Definition: DataLoader.cxx:256
void SetInputTrees(const TString &signalFileName, const TString &backgroundFileName, Double_t signalWeight=1.0, Double_t backgroundWeight=1.0)
Definition: DataLoader.cxx:438
virtual UInt_t Integer(UInt_t imax)
Returns a random integer on [ 0, imax-1 ].
Definition: TRandom.cxx:320
void SetTree(TTree *tree, const TString &className, Double_t weight)
Definition: DataLoader.cxx:422
void PrepareFoldDataSet(UInt_t foldNumber, Types::ETreeType tt)
Definition: DataLoader.cxx:669
void SetInputVariables(std::vector< TString > *theVariables)
Definition: DataLoader.cxx:499
DataSetInfo & AddDataSet(DataSetInfo &)
Definition: DataLoader.cxx:120
std::vector< VariableInfo > & GetTargetInfos()
Definition: DataSetInfo.h:117
void AddCut(const TString &cut, const TString &className="")
Definition: DataLoader.cxx:541
A specialized string object used for TTree selections.
Definition: TCut.h:27
static void DestroyInstance()
Definition: Tools.cxx:95
void SetInputTreesFromEventAssignTrees()
Definition: DataLoader.cxx:304
Float_t fATreeWeight
Definition: DataLoader.h:225
void SetSplitOptions(const TString &so)
Definition: DataSetInfo.h:184
Bool_t fMakeFoldDataSet
Definition: DataLoader.h:230
DataInputHandler & DataInput()
Definition: DataLoader.h:183
TRandom2 r(17)
Service class for 2-Dim histogram classes.
Definition: TH2.h:36
const Int_t kInfo
Definition: TError.h:39
ClassInfo * GetClassInfo(Int_t clNum) const
const TMatrixD * CorrelationMatrix(const TString &className) const
void AddTree(TTree *tree, const TString &className, Double_t weight=1.0, const TCut &cut="", Types::ETreeType tt=Types::kMaxTreeType)
add tree of className events for tt (Training;Testing..) type as input ..
void SetWeightExpression(const TString &exp, const TString &className="")
set the weight expressions for the classes if class name is specified, set only for this class if cla...
DataSetInfo & AddDataSetInfo(DataSetInfo &dsi)
stores a copy of the dataset info object
unsigned int UInt_t
Definition: RtypesCore.h:42
TMarker * m
Definition: textangle.C:8
char * Form(const char *fmt,...)
std::vector< TTree * > fTestAssignTree
Definition: DataLoader.h:215
Bool_t UserAssignEvents(UInt_t clIndex)
Definition: DataLoader.cxx:297
std::vector< Float_t > fATreeEvent
Definition: DataLoader.h:226
std::vector< std::vector< TMVA::Event * > > fTestSigEvents
Definition: DataLoader.h:221
void PrintClasses() const
TString fName
Definition: TNamed.h:36
DataLoader * MakeCopy(TString name)
Definition: DataLoader.cxx:794
TString & Remove(Ssiz_t pos)
Definition: TString.h:616
int Ssiz_t
Definition: RtypesCore.h:63
void AddTree(TTree *tree, const TString &className, Double_t weight=1.0, const TCut &cut="", Types::ETreeType tt=Types::kMaxTreeType)
Definition: DataLoader.cxx:335
void PrepareTrainingAndTestTree(const TCut &cut, const TString &splitOpt)
Definition: DataLoader.cxx:580
virtual void SetDirectory(TDirectory *dir)
Change the tree's directory.
Definition: TTree.cxx:8326
#define ClassImp(name)
Definition: Rtypes.h:279
double Double_t
Definition: RtypesCore.h:55
void AddEvent(const TString &className, Types::ETreeType tt, const std::vector< Double_t > &event, Double_t weight)
Definition: DataLoader.cxx:263
void SetBackgroundWeightExpression(const TString &variable)
Definition: DataLoader.cxx:513
int type
Definition: TGX11.cxx:120
unsigned long long ULong64_t
Definition: RtypesCore.h:70
virtual Long64_t ReadFile(const char *filename, const char *branchDescriptor="", char delimiter=' ')
Create or simply read branches from filename.
Definition: TTree.cxx:7036
static void DestroyInstance()
static function: destroy TMVA instance
Definition: Config.cxx:81
void AddTarget(const TString &expression, const TString &title="", const TString &unit="", Double_t min=0, Double_t max=0)
Definition: DataLoader.cxx:472
void SetWeightExpression(const TString &variable, const TString &className="")
Definition: DataLoader.cxx:519
void AddBackgroundTrainingEvent(const std::vector< Double_t > &event, Double_t weight=1.0)
Definition: DataLoader.cxx:235
Bool_t Contains(const char *pat, ECaseCompare cmp=kExact) const
Definition: TString.h:567
void SetEventCollection(std::vector< Event *> *, Types::ETreeType, Bool_t deleteEvents=true)
Sets the event collection (by DataSetFactory)
Definition: DataSet.cxx:259
ClassInfo * AddClass(const TString &className)
UInt_t GetNumber() const
Definition: ClassInfo.h:73
void SetSignalWeightExpression(const TString &variable)
Definition: DataLoader.cxx:507
virtual Long64_t GetEntries() const
Definition: TTree.h:393
std::vector< std::vector< TMVA::Event * > > fTrainSigEvents
Definition: DataLoader.h:217
virtual Int_t Branch(TCollection *list, Int_t bufsize=32000, Int_t splitlevel=99, const char *name="")
Create one branch for each element in the collection.
Definition: TTree.cxx:1652
Abstract ClassifierFactory template that handles arbitrary types.
std::vector< TTree * > fTrainAssignTree
Definition: DataLoader.h:214
void AddSignalTestEvent(const std::vector< Double_t > &event, Double_t weight=1.0)
Definition: DataLoader.cxx:228
std::vector< std::vector< TMVA::Event * > > fValidBkgEvents
Definition: DataLoader.h:220
void AddSignalTrainingEvent(const std::vector< Double_t > &event, Double_t weight=1.0)
Definition: DataLoader.cxx:221
friend void DataLoaderCopy(TMVA::DataLoader *des, TMVA::DataLoader *src)
void SetSignalTree(TTree *signal, Double_t weight=1.0)
Definition: DataLoader.cxx:410
#define NULL
Definition: Rtypes.h:82
VariableInfo & AddVariable(const TString &expression, const TString &title="", const TString &unit="", Double_t min=0, Double_t max=0, char varType='F', Bool_t normalized=kTRUE, void *external=0)
add a variable (can be a complex expression) to the set of variables used in the MV analysis ...
Definition: tree.py:1
Double_t Atof() const
Return floating-point value contained in string.
Definition: TString.cxx:2031
A TTree object has a header with a name and a title.
Definition: TTree.h:98
#define I(x, y, z)
const Bool_t kTRUE
Definition: Rtypes.h:91
std::vector< std::vector< TMVA::Event * > > SplitSets(std::vector< TMVA::Event *> &oldSet, int seedNum, int numFolds)
Definition: DataLoader.cxx:759
std::vector< std::vector< TMVA::Event * > > fValidSigEvents
Definition: DataLoader.h:219
void AddSignalTree(TTree *signal, Double_t weight=1.0, Types::ETreeType treetype=Types::kMaxTreeType)
Definition: DataLoader.cxx:353
DataSet * GetDataSet() const
returns data set
std::vector< VariableInfo > & GetVariableInfos()
Definition: DataSetInfo.h:112
gr SetName("gr")
void SetCut(const TString &cut, const TString &className="")
Definition: DataLoader.cxx:530
char name[80]
Definition: TGX11.cxx:109
void AddSpectator(const TString &expression, const TString &title="", const TString &unit="", Double_t min=0, Double_t max=0)
Definition: DataLoader.cxx:484
const char * Data() const
Definition: TString.h:349