ROOT  6.06/09
Reference Guide
Factory.cxx
Go to the documentation of this file.
1 // @(#)Root/tmva $Id$
2 // Author: Andreas Hoecker, Peter Speckmayer, Joerg Stelzer, Helge Voss, Kai Voss, Eckhard von Toerne, Jan Therhaag
3 
4 /**********************************************************************************
5  * Project: TMVA - a Root-integrated toolkit for multivariate data analysis *
6  * Package: TMVA *
7  * Class : Factory *
8  * Web : http://tmva.sourceforge.net *
9  * *
10  * Description: *
11  * Implementation (see header for description) *
12  * *
13  * Authors (alphabetical): *
14  * Andreas Hoecker <Andreas.Hocker@cern.ch> - CERN, Switzerland *
15  * Joerg Stelzer <stelzer@cern.ch> - DESY, Germany *
16  * Peter Speckmayer <peter.speckmayer@cern.ch> - CERN, Switzerland *
17  * Jan Therhaag <Jan.Therhaag@cern.ch> - U of Bonn, Germany *
18  * Eckhard v. Toerne <evt@uni-bonn.de> - U of Bonn, Germany *
19  * Helge Voss <Helge.Voss@cern.ch> - MPI-K Heidelberg, Germany *
20  * Kai Voss <Kai.Voss@cern.ch> - U. of Victoria, Canada *
21  * *
22  * Copyright (c) 2005-2011: *
23  * CERN, Switzerland *
24  * U. of Victoria, Canada *
25  * MPI-K Heidelberg, Germany *
26  * U. of Bonn, Germany *
27  * *
28  * Redistribution and use in source and binary forms, with or without *
29  * modification, are permitted according to the terms listed in LICENSE *
30  * (http://tmva.sourceforge.net/LICENSE) *
31  **********************************************************************************/
32 
33 //_______________________________________________________________________
34 //
35 // This is the main MVA steering class: it creates all MVA methods,
36 // and guides them through the training, testing and evaluation
37 // phases
38 //_______________________________________________________________________
39 
40 
41 #include "TROOT.h"
42 #include "TFile.h"
43 #include "TTree.h"
44 #include "TLeaf.h"
45 #include "TEventList.h"
46 #include "TH2.h"
47 #include "TText.h"
48 #include "TStyle.h"
49 #include "TMatrixF.h"
50 #include "TMatrixDSym.h"
51 #include "TPaletteAxis.h"
52 #include "TPrincipal.h"
53 #include "TMath.h"
54 #include "TObjString.h"
55 
56 #include "TMVA/Factory.h"
57 #include "TMVA/ClassifierFactory.h"
58 #include "TMVA/Config.h"
59 #include "TMVA/Tools.h"
60 #include "TMVA/Ranking.h"
61 #include "TMVA/DataSet.h"
62 #include "TMVA/IMethod.h"
63 #include "TMVA/MethodBase.h"
64 #include "TMVA/DataInputHandler.h"
65 #include "TMVA/DataSetManager.h"
66 #include "TMVA/DataSetInfo.h"
67 #include "TMVA/MethodBoost.h"
68 #include "TMVA/MethodCategory.h"
69 
75 
77 #include "TMVA/ResultsRegression.h"
78 #include "TMVA/ResultsMulticlass.h"
79 
81 //const Int_t MinNoTestEvents = 1;
// Static handle to the output ROOT file shared by all Factory instances;
// assigned in the Factory constructor from the user-supplied target file.
82 TFile* TMVA::Factory::fgTargetFile = 0;
83 
85 
// Compile-time switches: re-create methods on each run; read weight files as XML.
86 #define RECREATE_METHODS kTRUE
87 #define READXML kTRUE
88 
89 ////////////////////////////////////////////////////////////////////////////////
90 /// standard constructor
91 /// jobname : this name will appear in all weight file names produced by the MVAs
92 /// theTargetFile : output ROOT file; the test tree and all evaluation plots
93 /// will be stored here
94 /// theOption : option string; currently: "V" for verbose
95 
96 TMVA::Factory::Factory( TString jobName, TFile* theTargetFile, TString theOption )
97 : Configurable ( theOption ),
98  fDataSetManager ( NULL ), //DSMTEST
99  fDataInputHandler ( new DataInputHandler ),
100  fTransformations ( "I" ),
101  fVerbose ( kFALSE ),
102  fJobName ( jobName ),
103  fDataAssignType ( kAssignEvents ),
104  fATreeEvent ( NULL ),
105  fAnalysisType ( Types::kClassification )
106 {
107  fgTargetFile = theTargetFile;
108 
109  // DataSetManager::CreateInstance(*fDataInputHandler); // DSMTEST removed
110  fDataSetManager = new DataSetManager( *fDataInputHandler ); // DSMTEST
111 
112 
113  // render silent
114  if (gTools().CheckForSilentOption( GetOptions() )) Log().InhibitOutput(); // make sure is silent if wanted to
115 
116 
117  // init configurable
118  SetConfigDescription( "Configuration options for Factory running" );
119  SetConfigName( GetName() );
120 
121  // histograms are not automatically associated with the current
122  // directory and hence don't go out of scope when closing the file
123  // TH1::AddDirectory(kFALSE);
124  Bool_t silent = kFALSE;
125 #ifdef WIN32
126  // under Windows, switch progress bar and color off by default, as the typical windows shell doesn't handle these (would need different sequences..)
127  Bool_t color = kFALSE;
128  Bool_t drawProgressBar = kFALSE;
129 #else
130  Bool_t color = !gROOT->IsBatch();
131  Bool_t drawProgressBar = kTRUE;
132 #endif
133  DeclareOptionRef( fVerbose, "V", "Verbose flag" );
134  DeclareOptionRef( color, "Color", "Flag for coloured screen output (default: True, if in batch mode: False)" );
135  DeclareOptionRef( fTransformations, "Transformations", "List of transformations to test; formatting example: \"Transformations=I;D;P;U;G,D\", for identity, decorrelation, PCA, Uniform and Gaussianisation followed by decorrelation transformations" );
136  DeclareOptionRef( silent, "Silent", "Batch mode: boolean silent flag inhibiting any output from TMVA after the creation of the factory class object (default: False)" );
137  DeclareOptionRef( drawProgressBar,
138  "DrawProgressBar", "Draw progress bar to display training, testing and evaluation schedule (default: True)" );
139 
140  TString analysisType("Auto");
141  DeclareOptionRef( analysisType,
142  "AnalysisType", "Set the analysis type (Classification, Regression, Multiclass, Auto) (default: Auto)" );
143  AddPreDefVal(TString("Classification"));
144  AddPreDefVal(TString("Regression"));
145  AddPreDefVal(TString("Multiclass"));
146  AddPreDefVal(TString("Auto"));
147 
148  ParseOptions();
149  CheckForUnusedOptions();
150 
151  if (Verbose()) Log().SetMinType( kVERBOSE );
152 
153  // global settings
154  gConfig().SetUseColor( color );
155  gConfig().SetSilent( silent );
156  gConfig().SetDrawProgressBar( drawProgressBar );
157 
158  analysisType.ToLower();
159  if ( analysisType == "classification" ) fAnalysisType = Types::kClassification;
160  else if( analysisType == "regression" ) fAnalysisType = Types::kRegression;
161  else if( analysisType == "multiclass" ) fAnalysisType = Types::kMulticlass;
162  else if( analysisType == "auto" ) fAnalysisType = Types::kNoAnalysisType;
163 
164  Greetings();
165 }
166 
167 ////////////////////////////////////////////////////////////////////////////////
168 /// Print the TMVA welcome banner and version line to the factory logger.
169 /// options are: kLogoWelcomeMsg, kIsometricWelcomeMsg, kLeanWelcomeMsg
// NOTE(review): the declaration line (presumably void TMVA::Factory::Greetings())
// is missing from this extraction — confirm against the original source.
170 
172 {
174  gTools().TMVAWelcomeMessage( Log(), gTools().kLogoWelcomeMsg );
175  gTools().TMVAVersionMessage( Log() ); Log() << Endl;
176 }
177 
178 ////////////////////////////////////////////////////////////////////////////////
179 /// Destructor: deletes the default transformations, all booked methods, the
180 /// data input handler and the data set manager.
// NOTE(review): the declaration line is missing from this extraction; also
// fATreeEvent (allocated with new Float_t[] in CreateEventAssignTrees) is not
// freed in the visible code — confirm against the original source.
181 
183 {
184  std::vector<TMVA::VariableTransformBase*>::iterator trfIt = fDefaultTrfs.begin();
185  for (;trfIt != fDefaultTrfs.end(); trfIt++) delete (*trfIt);
186 
187  this->DeleteAllMethods();
188  delete fDataInputHandler;
189 
190  // destroy singletons
191  // DataSetManager::DestroyInstance(); // DSMTEST replaced by following line
192  delete fDataSetManager; // DSMTEST
193 
194  // problem with call of REGISTER_METHOD macro ...
195  // ClassifierFactory::DestroyInstance();
196  // Types::DestroyInstance();
199 }
200 
201 ////////////////////////////////////////////////////////////////////////////////
202 /// Delete every booked method and clear the method list.
// NOTE(review): the declaration line is missing from this extraction.
203 
205 {
206  MVector::iterator itrMethod = fMethods.begin();
207  for (; itrMethod != fMethods.end(); itrMethod++) {
208  Log() << kDEBUG << "Delete method: " << (*itrMethod)->GetName() << Endl;
209  delete (*itrMethod);
210  }
211  fMethods.clear();
212 }
213 
214 ////////////////////////////////////////////////////////////////////////////////
// Set the verbosity flag queried via Verbose().
// NOTE(review): the declaration line is missing from this extraction.
215 
217 {
218  fVerbose = v;
219 }
220 
221 ////////////////////////////////////////////////////////////////////////////////
// Register an externally created DataSetInfo with the data set manager.
// NOTE(review): the declaration line is missing from this extraction.
222 
224 {
225  return fDataSetManager->AddDataSetInfo(dsi); // DSMTEST
226 }
227 
228 ////////////////////////////////////////////////////////////////////////////////
// Return the DataSetInfo of the given name; if none exists yet, create a new
// one and register it with the manager (which keeps the heap-allocated object).
// NOTE(review): the declaration line is missing from this extraction.
229 
231 {
232  DataSetInfo* dsi = fDataSetManager->GetDataSetInfo(dsiName); // DSMTEST
233 
234  if (dsi!=0) return *dsi;
235 
236  return fDataSetManager->AddDataSetInfo(*(new DataSetInfo(dsiName))); // DSMTEST
237 }
238 
239 // ________________________________________________
240 // the next functions are to assign events directly
241 
242 ////////////////////////////////////////////////////////////////////////////////
243 /// create the data assignment tree (for event-wise data assignment by user)
// Branch layout: class index ("type"), event weight ("weight"), then one Float_t
// branch per variable, target and spectator, all backed by the shared
// fATreeEvent buffer filled in AddEvent.
// NOTE(review): the declaration line is missing from this extraction.
244 
246 {
247  TTree * assignTree = new TTree( name, name );
248  assignTree->SetDirectory(0);
249  assignTree->Branch( "type", &fATreeType, "ATreeType/I" );
250  assignTree->Branch( "weight", &fATreeWeight, "ATreeWeight/F" );
251 
252  std::vector<VariableInfo>& vars = DefaultDataSetInfo().GetVariableInfos();
253  std::vector<VariableInfo>& tgts = DefaultDataSetInfo().GetTargetInfos();
254  std::vector<VariableInfo>& spec = DefaultDataSetInfo().GetSpectatorInfos();
255 
// allocate the shared buffer once, sized for variables + targets + spectators
256  if (!fATreeEvent) fATreeEvent = new Float_t[vars.size()+tgts.size()+spec.size()];
257  // add variables
258  for (UInt_t ivar=0; ivar<vars.size(); ivar++) {
259  TString vname = vars[ivar].GetExpression();
260  assignTree->Branch( vname, &(fATreeEvent[ivar]), vname + "/F" );
261  }
262  // add targets
263  for (UInt_t itgt=0; itgt<tgts.size(); itgt++) {
264  TString vname = tgts[itgt].GetExpression();
265  assignTree->Branch( vname, &(fATreeEvent[vars.size()+itgt]), vname + "/F" );
266  }
267  // add spectators
268  for (UInt_t ispc=0; ispc<spec.size(); ispc++) {
269  TString vname = spec[ispc].GetExpression();
270  assignTree->Branch( vname, &(fATreeEvent[vars.size()+tgts.size()+ispc]), vname + "/F" );
271  }
272  return assignTree;
273 }
274 
275 ////////////////////////////////////////////////////////////////////////////////
276 /// add signal training event
277 
278 void TMVA::Factory::AddSignalTrainingEvent( const std::vector<Double_t>& event, Double_t weight )
279 {
280  AddEvent( "Signal", Types::kTraining, event, weight );
281 }
282 
283 ////////////////////////////////////////////////////////////////////////////////
284 /// add signal testing event
285 
286 void TMVA::Factory::AddSignalTestEvent( const std::vector<Double_t>& event, Double_t weight )
287 {
288  AddEvent( "Signal", Types::kTesting, event, weight );
289 }
290 
291 ////////////////////////////////////////////////////////////////////////////////
292 /// add signal training event
293 
294 void TMVA::Factory::AddBackgroundTrainingEvent( const std::vector<Double_t>& event, Double_t weight )
295 {
296  AddEvent( "Background", Types::kTraining, event, weight );
297 }
298 
299 ////////////////////////////////////////////////////////////////////////////////
300 /// add signal training event
301 
302 void TMVA::Factory::AddBackgroundTestEvent( const std::vector<Double_t>& event, Double_t weight )
303 {
304  AddEvent( "Background", Types::kTesting, event, weight );
305 }
306 
307 ////////////////////////////////////////////////////////////////////////////////
308 /// add signal training event
309 
310 void TMVA::Factory::AddTrainingEvent( const TString& className, const std::vector<Double_t>& event, Double_t weight )
311 {
312  AddEvent( className, Types::kTraining, event, weight );
313 }
314 
315 ////////////////////////////////////////////////////////////////////////////////
316 /// add signal test event
317 
318 void TMVA::Factory::AddTestEvent( const TString& className, const std::vector<Double_t>& event, Double_t weight )
319 {
320  AddEvent( className, Types::kTesting, event, weight );
321 }
322 
323 ////////////////////////////////////////////////////////////////////////////////
324 /// add event
325 /// vector event : the order of values is: variables + targets + spectators
// NOTE(review): the first declaration line is missing from this extraction; the
// visible tail matches AddEvent( <className>, <tree type tt>, event, weight ).
326 
328  const std::vector<Double_t>& event, Double_t weight )
329 {
330  ClassInfo* theClass = DefaultDataSetInfo().AddClass(className); // returns class (creates it if necessary)
331  UInt_t clIndex = theClass->GetNumber();
332 
333 
334  // set analysistype to "kMulticlass" if more than two classes and analysistype == kNoAnalysisType
335  if( fAnalysisType == Types::kNoAnalysisType && DefaultDataSetInfo().GetNClasses() > 2 )
336  fAnalysisType = Types::kMulticlass;
337 
338 
// grow the per-class assignment-tree lists on demand
339  if (clIndex>=fTrainAssignTree.size()) {
340  fTrainAssignTree.resize(clIndex+1, 0);
341  fTestAssignTree.resize(clIndex+1, 0);
342  }
343 
344  if (fTrainAssignTree[clIndex]==0) { // does not exist yet
345  fTrainAssignTree[clIndex] = CreateEventAssignTrees( Form("TrainAssignTree_%s", className.Data()) );
346  fTestAssignTree[clIndex] = CreateEventAssignTrees( Form("TestAssignTree_%s", className.Data()) );
347  }
348 
// copy the event values into the shared branch buffer, then fill the tree
349  fATreeType = clIndex;
350  fATreeWeight = weight;
351  for (UInt_t ivar=0; ivar<event.size(); ivar++) fATreeEvent[ivar] = event[ivar];
352 
353  if(tt==Types::kTraining) fTrainAssignTree[clIndex]->Fill();
354  else fTestAssignTree[clIndex]->Fill();
355 
356 }
357 
358 ////////////////////////////////////////////////////////////////////////////////
359 /// Returns kTRUE if a user-filled training assignment tree exists for this class index.
// NOTE(review): the declaration line is missing from this extraction.
360 
362 {
363  return fTrainAssignTree[clIndex]!=0;
364 }
365 
366 ////////////////////////////////////////////////////////////////////////////////
367 /// assign event-wise local trees to data set
// NOTE(review): the declaration line is missing from this extraction.
368 
370 {
371  UInt_t size = fTrainAssignTree.size();
372  for(UInt_t i=0; i<size; i++) {
373  if(!UserAssignEvents(i)) continue;
374  const TString& className = DefaultDataSetInfo().GetClassInfo(i)->GetName();
// the per-event "weight" branch becomes the weight expression of this class
375  SetWeightExpression( "weight", className );
376  AddTree(fTrainAssignTree[i], className, 1.0, TCut(""), Types::kTraining );
377  AddTree(fTestAssignTree[i], className, 1.0, TCut(""), Types::kTesting );
378  }
379 }
380 
381 ////////////////////////////////////////////////////////////////////////////////
382 /// number of signal events (used to compute significance)
383 
384 void TMVA::Factory::AddTree( TTree* tree, const TString& className, Double_t weight,
385  const TCut& cut, const TString& treetype )
386 {
388  TString tmpTreeType = treetype; tmpTreeType.ToLower();
389  if (tmpTreeType.Contains( "train" ) && tmpTreeType.Contains( "test" )) tt = Types::kMaxTreeType;
390  else if (tmpTreeType.Contains( "train" )) tt = Types::kTraining;
391  else if (tmpTreeType.Contains( "test" )) tt = Types::kTesting;
392  else {
393  Log() << kFATAL << "<AddTree> cannot interpret tree type: \"" << treetype
394  << "\" should be \"Training\" or \"Test\" or \"Training and Testing\"" << Endl;
395  }
396  AddTree( tree, className, weight, cut, tt );
397 }
398 
399 ////////////////////////////////////////////////////////////////////////////////
400 
401 void TMVA::Factory::AddTree( TTree* tree, const TString& className, Double_t weight,
402  const TCut& cut, Types::ETreeType tt )
403 {
404  if(!tree)
405  Log() << kFATAL << "Tree does not exist (empty pointer)." << Endl;
406 
407  DefaultDataSetInfo().AddClass( className );
408 
409  // set analysistype to "kMulticlass" if more than two classes and analysistype == kNoAnalysisType
410  if( fAnalysisType == Types::kNoAnalysisType && DefaultDataSetInfo().GetNClasses() > 2 )
411  fAnalysisType = Types::kMulticlass;
412 
413  Log() << kINFO << "Add Tree " << tree->GetName() << " of type " << className
414  << " with " << tree->GetEntries() << " events" << Endl;
415  DataInput().AddTree( tree, className, weight, cut, tt );
416 }
417 
418 ////////////////////////////////////////////////////////////////////////////////
419 /// Add a signal tree; the tree type ("Training"/"Test") is given as a string.
// NOTE(review): the declaration line is missing from this extraction; the
// original header comment ("number of signal events ...") was a copy-paste error.
420 
422 {
423  AddTree( signal, "Signal", weight, TCut(""), treetype );
424 }
425 
426 ////////////////////////////////////////////////////////////////////////////////
427 /// add signal tree from text file
// NOTE(review): the declaration line is missing from this extraction.
428 
430 {
431  // create trees from these ascii files
432  TTree* signalTree = new TTree( "TreeS", "Tree (S)" );
433  signalTree->ReadFile( datFileS );
434 
435  Log() << kINFO << "Create TTree objects from ASCII input files ... \n- Signal file : \""
436  << datFileS << Endl;
437 
438  // register the freshly created tree as signal input
439  AddTree( signalTree, "Signal", weight, TCut(""), treetype );
440 }
441 
442 ////////////////////////////////////////////////////////////////////////////////
443 
444 void TMVA::Factory::AddSignalTree( TTree* signal, Double_t weight, const TString& treetype )
445 {
446  AddTree( signal, "Signal", weight, TCut(""), treetype );
447 }
448 
449 ////////////////////////////////////////////////////////////////////////////////
450 /// Add a background tree; the tree type ("Training"/"Test") is given as a string.
// NOTE(review): the declaration line is missing from this extraction; the visible
// call uses a parameter named "signal" although it carries a background tree.
451 
453 {
454  AddTree( signal, "Background", weight, TCut(""), treetype );
455 }
456 ////////////////////////////////////////////////////////////////////////////////
457 /// add background tree from text file
// NOTE(review): the declaration line is missing from this extraction.
458 
460 {
461  // create trees from these ascii files
462  TTree* bkgTree = new TTree( "TreeB", "Tree (B)" );
463  bkgTree->ReadFile( datFileB );
464 
465  Log() << kINFO << "Create TTree objects from ASCII input files ... \n- Background file : \""
466  << datFileB << Endl;
467 
468  // register the freshly created tree as background input
469  AddTree( bkgTree, "Background", weight, TCut(""), treetype );
470 }
471 
472 ////////////////////////////////////////////////////////////////////////////////
473 
474 void TMVA::Factory::AddBackgroundTree( TTree* signal, Double_t weight, const TString& treetype )
475 {
476  AddTree( signal, "Background", weight, TCut(""), treetype );
477 }
478 
479 ////////////////////////////////////////////////////////////////////////////////
// Set the signal tree (forwards to AddTree with class "Signal").
// NOTE(review): the declaration line is missing from this extraction.
480 
482 {
483  AddTree( tree, "Signal", weight );
484 }
485 
486 ////////////////////////////////////////////////////////////////////////////////
// Set the background tree (forwards to AddTree with class "Background").
// NOTE(review): the declaration line is missing from this extraction.
487 
489 {
490  AddTree( tree, "Background", weight );
491 }
492 
493 ////////////////////////////////////////////////////////////////////////////////
494 /// set background tree
495 
496 void TMVA::Factory::SetTree( TTree* tree, const TString& className, Double_t weight )
497 {
498  AddTree( tree, className, weight, TCut(""), Types::kMaxTreeType );
499 }
500 
501 ////////////////////////////////////////////////////////////////////////////////
502 /// define the input trees for signal and background; no cuts are applied
// NOTE(review): the first declaration line is missing from this extraction; the
// visible tail matches SetInputTrees( <signal tree>, <background tree>, ... ).
503 
505  Double_t signalWeight, Double_t backgroundWeight )
506 {
507  AddTree( signal, "Signal", signalWeight, TCut(""), Types::kMaxTreeType );
508  AddTree( background, "Background", backgroundWeight, TCut(""), Types::kMaxTreeType );
509 }
510 
511 ////////////////////////////////////////////////////////////////////////////////
512 
513 void TMVA::Factory::SetInputTrees( const TString& datFileS, const TString& datFileB,
514  Double_t signalWeight, Double_t backgroundWeight )
515 {
516  DataInput().AddTree( datFileS, "Signal", signalWeight );
517  DataInput().AddTree( datFileB, "Background", backgroundWeight );
518 }
519 
520 ////////////////////////////////////////////////////////////////////////////////
521 /// define the input trees for signal and background from single input tree,
522 /// containing both signal and background events distinguished by the type
523 /// identifiers: SigCut and BgCut
524 
525 void TMVA::Factory::SetInputTrees( TTree* inputTree, const TCut& SigCut, const TCut& BgCut )
526 {
527  AddTree( inputTree, "Signal", 1.0, SigCut, Types::kMaxTreeType );
528  AddTree( inputTree, "Background", 1.0, BgCut , Types::kMaxTreeType );
529 }
530 
531 ////////////////////////////////////////////////////////////////////////////////
532 /// user inserts discriminating variable in data set info
533 
534 void TMVA::Factory::AddVariable( const TString& expression, const TString& title, const TString& unit,
535  char type, Double_t min, Double_t max )
536 {
537  DefaultDataSetInfo().AddVariable( expression, title, unit, min, max, type );
538 }
539 
540 ////////////////////////////////////////////////////////////////////////////////
541 /// user inserts discriminating variable in data set info
// NOTE(review): the second declaration line (presumably the min/max parameters)
// is missing from this extraction.
542 
543 void TMVA::Factory::AddVariable( const TString& expression, char type,
545 {
546  DefaultDataSetInfo().AddVariable( expression, "", "", min, max, type );
547 }
548 
549 ////////////////////////////////////////////////////////////////////////////////
550 /// user inserts target in data set info
// NOTE(review): the second declaration line (presumably the min/max parameters)
// is missing from this extraction.
551 
552 void TMVA::Factory::AddTarget( const TString& expression, const TString& title, const TString& unit,
554 {
// adding a target implies regression when no analysis type was fixed yet
555  if( fAnalysisType == Types::kNoAnalysisType )
556  fAnalysisType = Types::kRegression;
557 
558  DefaultDataSetInfo().AddTarget( expression, title, unit, min, max );
559 }
560 
561 ////////////////////////////////////////////////////////////////////////////////
562 /// user inserts spectator variable in data set info
// NOTE(review): the second declaration line (presumably the min/max parameters)
// is missing from this extraction; the original header comment said "target".
563 
564 void TMVA::Factory::AddSpectator( const TString& expression, const TString& title, const TString& unit,
566 {
567  DefaultDataSetInfo().AddSpectator( expression, title, unit, min, max );
568 }
569 
570 ////////////////////////////////////////////////////////////////////////////////
571 /// default creation
// Returns (creating and registering if needed) the data set info named "Default".
// NOTE(review): the declaration line is missing from this extraction.
572 
574 {
575  return AddDataSet( "Default" );
576 }
577 
578 ////////////////////////////////////////////////////////////////////////////////
579 /// fill input variables in data set
580 
581 void TMVA::Factory::SetInputVariables( std::vector<TString>* theVariables )
582 {
583  for (std::vector<TString>::iterator it=theVariables->begin();
584  it!=theVariables->end(); it++) AddVariable(*it);
585 }
586 
587 ////////////////////////////////////////////////////////////////////////////////
// Set the event-weight expression for the "Signal" class.
// NOTE(review): the declaration line is missing from this extraction.
588 
590 {
591  DefaultDataSetInfo().SetWeightExpression(variable, "Signal");
592 }
593 
594 ////////////////////////////////////////////////////////////////////////////////
// Set the event-weight expression for the "Background" class.
// NOTE(review): the declaration line is missing from this extraction.
595 
597 {
598  DefaultDataSetInfo().SetWeightExpression(variable, "Background");
599 }
600 
601 ////////////////////////////////////////////////////////////////////////////////
602 ///Log() << kWarning << DefaultDataSetInfo().GetNClasses() /*fClasses.size()*/ << Endl;
603 
604 void TMVA::Factory::SetWeightExpression( const TString& variable, const TString& className )
605 {
606  if (className=="") {
607  SetSignalWeightExpression(variable);
608  SetBackgroundWeightExpression(variable);
609  }
610  else DefaultDataSetInfo().SetWeightExpression( variable, className );
611 }
612 
613 ////////////////////////////////////////////////////////////////////////////////
614 
615 void TMVA::Factory::SetCut( const TString& cut, const TString& className ) {
616  SetCut( TCut(cut), className );
617 }
618 
619 ////////////////////////////////////////////////////////////////////////////////
620 
621 void TMVA::Factory::SetCut( const TCut& cut, const TString& className )
622 {
623  DefaultDataSetInfo().SetCut( cut, className );
624 }
625 
626 ////////////////////////////////////////////////////////////////////////////////
627 
628 void TMVA::Factory::AddCut( const TString& cut, const TString& className )
629 {
630  AddCut( TCut(cut), className );
631 }
632 
633 ////////////////////////////////////////////////////////////////////////////////
634 
635 void TMVA::Factory::AddCut( const TCut& cut, const TString& className )
636 {
637  DefaultDataSetInfo().AddCut( cut, className );
638 }
639 
640 ////////////////////////////////////////////////////////////////////////////////
641 /// prepare the training and test trees
// NOTE(review): the first declaration line is missing from this extraction; the
// visible tail matches PrepareTrainingAndTestTree( <cut>, NsigTrain, ... ).
642 
644  Int_t NsigTrain, Int_t NbkgTrain, Int_t NsigTest, Int_t NbkgTest,
645  const TString& otherOpt )
646 {
// attach event-wise user trees (if any) before splitting
647  SetInputTreesFromEventAssignTrees();
648 
649  AddCut( cut );
650 
651  DefaultDataSetInfo().SetSplitOptions( Form("nTrain_Signal=%i:nTrain_Background=%i:nTest_Signal=%i:nTest_Background=%i:%s",
652  NsigTrain, NbkgTrain, NsigTest, NbkgTest, otherOpt.Data()) );
653 }
654 
655 ////////////////////////////////////////////////////////////////////////////////
656 /// prepare the training and test trees
657 /// kept for backward compatibility
// NOTE(review): the declaration line is missing from this extraction.
658 
660 {
661  SetInputTreesFromEventAssignTrees();
662 
663  AddCut( cut );
664 
// same event counts are requested for signal and background
665  DefaultDataSetInfo().SetSplitOptions( Form("nTrain_Signal=%i:nTrain_Background=%i:nTest_Signal=%i:nTest_Background=%i:SplitMode=Random:EqualTrainSample:!V",
666  Ntrain, Ntrain, Ntest, Ntest) );
667 }
668 
669 ////////////////////////////////////////////////////////////////////////////////
670 /// prepare the training and test trees
671 /// -> same cuts for signal and background
// NOTE(review): the declaration line is missing from this extraction.
672 
674 {
675  SetInputTreesFromEventAssignTrees();
676 
677  DefaultDataSetInfo().PrintClasses();
678  AddCut( cut );
679  DefaultDataSetInfo().SetSplitOptions( opt );
680 }
681 
682 ////////////////////////////////////////////////////////////////////////////////
683 /// prepare the training and test trees
684 
685 void TMVA::Factory::PrepareTrainingAndTestTree( TCut sigcut, TCut bkgcut, const TString& splitOpt )
686 {
687  // if event-wise data assignment, add local trees to dataset first
688  SetInputTreesFromEventAssignTrees();
689 
690  Log() << kINFO << "Preparing trees for training and testing..." << Endl;
691  AddCut( sigcut, "Signal" );
692  AddCut( bkgcut, "Background" );
693 
694  DefaultDataSetInfo().SetSplitOptions( splitOpt );
695 }
696 
697 ////////////////////////////////////////////////////////////////////////////////
698 /// Book a classifier or regression method
699 
700 TMVA::MethodBase* TMVA::Factory::BookMethod( TString theMethodName, TString methodTitle, TString theOption )
701 {
702  if( fAnalysisType == Types::kNoAnalysisType ){
703  if( DefaultDataSetInfo().GetNClasses()==2
704  && DefaultDataSetInfo().GetClassInfo("Signal") != NULL
705  && DefaultDataSetInfo().GetClassInfo("Background") != NULL
706  ){
707  fAnalysisType = Types::kClassification; // default is classification
708  } else if( DefaultDataSetInfo().GetNClasses() >= 2 ){
709  fAnalysisType = Types::kMulticlass; // if two classes, but not named "Signal" and "Background"
710  } else
711  Log() << kFATAL << "No analysis type for " << DefaultDataSetInfo().GetNClasses() << " classes and "
712  << DefaultDataSetInfo().GetNTargets() << " regression targets." << Endl;
713  }
714 
715  // booking via name; the names are translated into enums and the
716  // corresponding overloaded BookMethod is called
717  if (GetMethod( methodTitle ) != 0) {
718  Log() << kFATAL << "Booking failed since method with title <"
719  << methodTitle <<"> already exists"
720  << Endl;
721  }
722 
723  Log() << kINFO << "Booking method: " << gTools().Color("bold") << methodTitle
724  << gTools().Color("reset") << Endl;
725 
726  // interpret option string with respect to a request for boosting (i.e., BostNum > 0)
727  Int_t boostNum = 0;
728  TMVA::Configurable* conf = new TMVA::Configurable( theOption );
729  conf->DeclareOptionRef( boostNum = 0, "Boost_num",
730  "Number of times the classifier will be boosted" );
731  conf->ParseOptions();
732  delete conf;
733 
734  // initialize methods
735  IMethod* im;
736  if (!boostNum) {
737  im = ClassifierFactory::Instance().Create( std::string(theMethodName),
738  fJobName,
739  methodTitle,
740  DefaultDataSetInfo(),
741  theOption );
742  }
743  else {
744  // boosted classifier, requires a specific definition, making it transparent for the user
745  Log() << "Boost Number is " << boostNum << " > 0: train boosted classifier" << Endl;
746  im = ClassifierFactory::Instance().Create( std::string("Boost"),
747  fJobName,
748  methodTitle,
749  DefaultDataSetInfo(),
750  theOption );
751  MethodBoost* methBoost = dynamic_cast<MethodBoost*>(im); // DSMTEST divided into two lines
752  if (!methBoost) // DSMTEST
753  Log() << kFATAL << "Method with type kBoost cannot be casted to MethodCategory. /Factory" << Endl; // DSMTEST
754  methBoost->SetBoostedMethodName( theMethodName ); // DSMTEST divided into two lines
755  methBoost->fDataSetManager = fDataSetManager; // DSMTEST
756 
757  }
758 
759  MethodBase *method = dynamic_cast<MethodBase*>(im);
760  if (method==0) return 0; // could not create method
761 
762  // set fDataSetManager if MethodCategory (to enable Category to create datasetinfo objects) // DSMTEST
763  if (method->GetMethodType() == Types::kCategory) { // DSMTEST
764  MethodCategory *methCat = (dynamic_cast<MethodCategory*>(im)); // DSMTEST
765  if (!methCat) // DSMTEST
766  Log() << kFATAL << "Method with type kCategory cannot be casted to MethodCategory. /Factory" << Endl; // DSMTEST
767  methCat->fDataSetManager = fDataSetManager; // DSMTEST
768  } // DSMTEST
769 
770 
771  if (!method->HasAnalysisType( fAnalysisType,
772  DefaultDataSetInfo().GetNClasses(),
773  DefaultDataSetInfo().GetNTargets() )) {
774  Log() << kWARNING << "Method " << method->GetMethodTypeName() << " is not capable of handling " ;
775  if (fAnalysisType == Types::kRegression) {
776  Log() << "regression with " << DefaultDataSetInfo().GetNTargets() << " targets." << Endl;
777  }
778  else if (fAnalysisType == Types::kMulticlass ) {
779  Log() << "multiclass classification with " << DefaultDataSetInfo().GetNClasses() << " classes." << Endl;
780  }
781  else {
782  Log() << "classification with " << DefaultDataSetInfo().GetNClasses() << " classes." << Endl;
783  }
784  return 0;
785  }
786 
787 
788  method->SetAnalysisType( fAnalysisType );
789  method->SetupMethod();
790  method->ParseOptions();
791  method->ProcessSetup();
792 
793  // check-for-unused-options is performed; may be overridden by derived classes
794  method->CheckSetup();
795 
796  fMethods.push_back( method );
797 
798  return method;
799 }
800 
801 ////////////////////////////////////////////////////////////////////////////////
802 /// books MVA method; the option configuration string is custom for each MVA
803 /// the TString field "theNameAppendix" serves to define (and distinguish)
804 /// several instances of a given MVA, eg, when one wants to compare the
805 /// performance of various configurations
// NOTE(review): the declaration line is missing from this extraction; the method
// enum is translated to its name and forwarded to the string-based overload.
806 
808 {
809  return BookMethod( Types::Instance().GetMethodName( theMethod ), methodTitle, theOption );
810 }
811 
812 ////////////////////////////////////////////////////////////////////////////////
813 /// returns pointer to MVA that corresponds to given method title
814 
815 TMVA::IMethod* TMVA::Factory::GetMethod( const TString &methodTitle ) const
816 {
817  MVector::const_iterator itrMethod = fMethods.begin();
818  MVector::const_iterator itrMethodEnd = fMethods.end();
819  //
820  for (; itrMethod != itrMethodEnd; itrMethod++) {
821  MethodBase* mva = dynamic_cast<MethodBase*>(*itrMethod);
822  if ( (mva->GetMethodName())==methodTitle ) return mva;
823  }
824  return 0;
825 }
826 
827 ////////////////////////////////////////////////////////////////////////////////
828 /// put correlations of input data and a few (default + user
829 /// selected) transformations into the root file
// NOTE(review): the declaration line is missing from this extraction.
830 
832 {
833  RootBaseDir()->cd();
834 
835  DefaultDataSetInfo().GetDataSet(); // builds dataset (including calculation of correlation matrix)
836 
837 
838  // correlation matrix of the default DS
839  const TMatrixD* m(0);
840  const TH2* h(0);
841 
// multiclass: one correlation histogram per class; otherwise the classic
// signal/background (plus regression) set
842  if(fAnalysisType == Types::kMulticlass){
843  for (UInt_t cls = 0; cls < DefaultDataSetInfo().GetNClasses() ; cls++) {
844  m = DefaultDataSetInfo().CorrelationMatrix(DefaultDataSetInfo().GetClassInfo(cls)->GetName());
845  h = DefaultDataSetInfo().CreateCorrelationMatrixHist(m, TString("CorrelationMatrix")+DefaultDataSetInfo().GetClassInfo(cls)->GetName(),
846  "Correlation Matrix ("+ DefaultDataSetInfo().GetClassInfo(cls)->GetName() +TString(")"));
847  if (h!=0) {
848  h->Write();
849  delete h;
850  }
851  }
852  }
853  else{
854  m = DefaultDataSetInfo().CorrelationMatrix( "Signal" );
855  h = DefaultDataSetInfo().CreateCorrelationMatrixHist(m, "CorrelationMatrixS", "Correlation Matrix (signal)");
856  if (h!=0) {
857  h->Write();
858  delete h;
859  }
860 
861  m = DefaultDataSetInfo().CorrelationMatrix( "Background" );
862  h = DefaultDataSetInfo().CreateCorrelationMatrixHist(m, "CorrelationMatrixB", "Correlation Matrix (background)");
863  if (h!=0) {
864  h->Write();
865  delete h;
866  }
867 
868  m = DefaultDataSetInfo().CorrelationMatrix( "Regression" );
869  h = DefaultDataSetInfo().CreateCorrelationMatrixHist(m, "CorrelationMatrix", "Correlation Matrix");
870  if (h!=0) {
871  h->Write();
872  delete h;
873  }
874  }
875 
876  // some default transformations to evaluate
877  // NOTE: all transformations are destroyed after this test
878  TString processTrfs = "I"; //"I;N;D;P;U;G,D;"
879 
880  // plus some user defined transformations
881  processTrfs = fTransformations;
882 
883  // remove any trace of identity transform - if given (avoid to apply it twice)
884  std::vector<TMVA::TransformationHandler*> trfs;
885  TransformationHandler* identityTrHandler = 0;
886 
887  std::vector<TString> trfsDef = gTools().SplitString(processTrfs,';');
888  std::vector<TString>::iterator trfsDefIt = trfsDef.begin();
889  for (; trfsDefIt!=trfsDef.end(); trfsDefIt++) {
890  trfs.push_back(new TMVA::TransformationHandler(DefaultDataSetInfo(), "Factory"));
891  TString trfS = (*trfsDefIt);
892 
893  Log() << kINFO << Endl;
894  Log() << kINFO << "current transformation string: '" << trfS.Data() << "'" << Endl;
// NOTE(review): the line opening this call (original line 895, presumably the
// variable-transform creation taking trfS) is missing from this extraction.
896  DefaultDataSetInfo(),
897  *(trfs.back()),
898  Log() );
899 
900  if (trfS.BeginsWith('I')) identityTrHandler = trfs.back();
901  }
902 
903  const std::vector<Event*>& inputEvents = DefaultDataSetInfo().GetDataSet()->GetEventCollection();
904 
905  // apply all transformations
906  std::vector<TMVA::TransformationHandler*>::iterator trfIt = trfs.begin();
907 
908  for (;trfIt != trfs.end(); trfIt++) {
909  // setting a Root dir causes the variables distributions to be saved to the root file
910  (*trfIt)->SetRootDir(RootBaseDir());
911  (*trfIt)->CalcTransformations(inputEvents);
912  }
913  if(identityTrHandler) identityTrHandler->PrintVariableRanking();
914 
915  // clean up
916  for (trfIt = trfs.begin(); trfIt != trfs.end(); trfIt++) delete *trfIt;
917 }
918 
919 ////////////////////////////////////////////////////////////////////////////////
920 /// iterates through all booked methods and sees if they use parameter tuning and if so..
921 /// does just that i.e. calls "Method::Train()" for different parameter setttings and
922 /// keeps in mind the "optimal one"... and that's the one that will later on be used
923 /// in the main training loop.
924 
926 {
927 
928  MVector::iterator itrMethod;
929 
930  // iterate over methods and optimize
931  for( itrMethod = fMethods.begin(); itrMethod != fMethods.end(); ++itrMethod ) {
933  MethodBase* mva = dynamic_cast<MethodBase*>(*itrMethod);
934  if (!mva) {
935  Log() << kFATAL << "Dynamic cast to MethodBase failed" <<Endl;
936  return;
937  }
938 
939  if (mva->Data()->GetNTrainingEvents() < MinNoTrainingEvents) {
940  Log() << kWARNING << "Method " << mva->GetMethodName()
941  << " not trained (training tree has less entries ["
942  << mva->Data()->GetNTrainingEvents()
943  << "] than required [" << MinNoTrainingEvents << "]" << Endl;
944  continue;
945  }
946 
947  Log() << kINFO << "Optimize method: " << mva->GetMethodName() << " for "
948  << (fAnalysisType == Types::kRegression ? "Regression" :
949  (fAnalysisType == Types::kMulticlass ? "Multiclass classification" : "Classification")) << Endl;
950 
951  mva->OptimizeTuningParameters(fomType,fitType);
952  Log() << kINFO << "Optimization of tuning paremters finished for Method:"<<mva->GetName() << Endl;
953  }
954 }
955 
956 ////////////////////////////////////////////////////////////////////////////////
957 /// iterates through all booked methods and calls training
958 
960 {
961  if(fDataInputHandler->GetEntries() <=1) { // 0 entries --> 0 events, 1 entry --> dynamical dataset (or one entry)
962  Log() << kFATAL << "No input data for the training provided!" << Endl;
963  }
964 
965  if(fAnalysisType == Types::kRegression && DefaultDataSetInfo().GetNTargets() < 1 )
966  Log() << kFATAL << "You want to do regression training without specifying a target." << Endl;
967  else if( (fAnalysisType == Types::kMulticlass || fAnalysisType == Types::kClassification)
968  && DefaultDataSetInfo().GetNClasses() < 2 )
969  Log() << kFATAL << "You want to do classification training, but specified less than two classes." << Endl;
970 
971  // iterates over all MVAs that have been booked, and calls their training methods
972 
973  // first print some information about the default dataset
974  WriteDataInformation();
975 
976  // don't do anything if no method booked
977  if (fMethods.empty()) {
978  Log() << kINFO << "...nothing found to train" << Endl;
979  return;
980  }
981 
982  // here the training starts
983  Log() << kINFO << " " << Endl;
984  Log() << kINFO << "Train all methods for "
985  << (fAnalysisType == Types::kRegression ? "Regression" :
986  (fAnalysisType == Types::kMulticlass ? "Multiclass" : "Classification") ) << " ..." << Endl;
987 
988  MVector::iterator itrMethod;
989 
990  // iterate over methods and train
991  for( itrMethod = fMethods.begin(); itrMethod != fMethods.end(); ++itrMethod ) {
993  MethodBase* mva = dynamic_cast<MethodBase*>(*itrMethod);
994  if(mva==0) continue;
995 
996  if (mva->Data()->GetNTrainingEvents() < MinNoTrainingEvents) {
997  Log() << kWARNING << "Method " << mva->GetMethodName()
998  << " not trained (training tree has less entries ["
999  << mva->Data()->GetNTrainingEvents()
1000  << "] than required [" << MinNoTrainingEvents << "]" << Endl;
1001  continue;
1002  }
1003 
1004  Log() << kINFO << "Train method: " << mva->GetMethodName() << " for "
1005  << (fAnalysisType == Types::kRegression ? "Regression" :
1006  (fAnalysisType == Types::kMulticlass ? "Multiclass classification" : "Classification")) << Endl;
1007  mva->TrainMethod();
1008  Log() << kINFO << "Training finished" << Endl;
1009  }
1010 
1011  if (fAnalysisType != Types::kRegression) {
1012 
1013  // variable ranking
1014  Log() << Endl;
1015  Log() << kINFO << "Ranking input variables (method specific)..." << Endl;
1016  for (itrMethod = fMethods.begin(); itrMethod != fMethods.end(); itrMethod++) {
1017  MethodBase* mva = dynamic_cast<MethodBase*>(*itrMethod);
1018  if (mva && mva->Data()->GetNTrainingEvents() >= MinNoTrainingEvents) {
1019 
1020  // create and print ranking
1021  const Ranking* ranking = (*itrMethod)->CreateRanking();
1022  if (ranking != 0) ranking->Print();
1023  else Log() << kINFO << "No variable ranking supplied by classifier: "
1024  << dynamic_cast<MethodBase*>(*itrMethod)->GetMethodName() << Endl;
1025  }
1026  }
1027  }
1028 
1029  // delete all methods and recreate them from weight file - this ensures that the application
1030  // of the methods (in TMVAClassificationApplication) is consistent with the results obtained
1031  // in the testing
1032  Log() << Endl;
1033  if (RECREATE_METHODS) {
1034 
1035  Log() << kINFO << "=== Destroy and recreate all methods via weight files for testing ===" << Endl << Endl;
1036 
1037  RootBaseDir()->cd();
1038 
1039  // iterate through all booked methods
1040  for (UInt_t i=0; i<fMethods.size(); i++) {
1041 
1042  MethodBase* m = dynamic_cast<MethodBase*>(fMethods[i]);
1043  if(m==0) continue;
1044 
1045  TMVA::Types::EMVA methodType = m->GetMethodType();
1046  TString weightfile = m->GetWeightFileName();
1047 
1048  // decide if .txt or .xml file should be read:
1049  if (READXML) weightfile.ReplaceAll(".txt",".xml");
1050 
1051  DataSetInfo& dataSetInfo = m->DataInfo();
1052  TString testvarName = m->GetTestvarName();
1053  delete m; //itrMethod[i];
1054 
1055  // recreate
1056  m = dynamic_cast<MethodBase*>( ClassifierFactory::Instance()
1057  .Create( std::string(Types::Instance().GetMethodName(methodType)),
1058  dataSetInfo, weightfile ) );
1059  if( m->GetMethodType() == Types::kCategory ){
1060  MethodCategory *methCat = (dynamic_cast<MethodCategory*>(m));
1061  if( !methCat ) Log() << kFATAL << "Method with type kCategory cannot be casted to MethodCategory. /Factory" << Endl;
1062  else methCat->fDataSetManager = fDataSetManager;
1063  }
1064  //ToDo, Do we need to fill the DataSetManager of MethodBoost here too?
1065 
1066  m->SetAnalysisType(fAnalysisType);
1067  m->SetupMethod();
1068  m->ReadStateFromFile();
1069  m->SetTestvarName(testvarName);
1070 
1071  // replace trained method by newly created one (from weight file) in methods vector
1072  fMethods[i] = m;
1073  }
1074  }
1075 }
1076 
1077 ////////////////////////////////////////////////////////////////////////////////
1078 
1080 {
1081  Log() << kINFO << "Test all methods..." << Endl;
1082 
1083  // don't do anything if no method booked
1084  if (fMethods.empty()) {
1085  Log() << kINFO << "...nothing found to test" << Endl;
1086  return;
1087  }
1088 
1089  // iterates over all MVAs that have been booked, and calls their testing methods
1090  // iterate over methods and test
1091  MVector::iterator itrMethod = fMethods.begin();
1092  MVector::iterator itrMethodEnd = fMethods.end();
1093  for (; itrMethod != itrMethodEnd; itrMethod++) {
1095  MethodBase* mva = dynamic_cast<MethodBase*>(*itrMethod);
1096  if(mva==0) continue;
1097  Types::EAnalysisType analysisType = mva->GetAnalysisType();
1098  Log() << kINFO << "Test method: " << mva->GetMethodName() << " for "
1099  << (analysisType == Types::kRegression ? "Regression" :
1100  (analysisType == Types::kMulticlass ? "Multiclass classification" : "Classification")) << " performance" << Endl;
1101  mva->AddOutput( Types::kTesting, analysisType );
1102  }
1103 }
1104 
1105 ////////////////////////////////////////////////////////////////////////////////
1106 /// Print predefined help message of classifier
1107 /// iterate over methods and test
1108 
1109 void TMVA::Factory::MakeClass( const TString& methodTitle ) const
1110 {
1111  if (methodTitle != "") {
1112  IMethod* method = GetMethod( methodTitle );
1113  if (method) method->MakeClass();
1114  else {
1115  Log() << kWARNING << "<MakeClass> Could not find classifier \"" << methodTitle
1116  << "\" in list" << Endl;
1117  }
1118  }
1119  else {
1120 
1121  // no classifier specified, print all hepl messages
1122  MVector::const_iterator itrMethod = fMethods.begin();
1123  MVector::const_iterator itrMethodEnd = fMethods.end();
1124  for (; itrMethod != itrMethodEnd; itrMethod++) {
1125  MethodBase* method = dynamic_cast<MethodBase*>(*itrMethod);
1126  if(method==0) continue;
1127  Log() << kINFO << "Make response class for classifier: " << method->GetMethodName() << Endl;
1128  method->MakeClass();
1129  }
1130  }
1131 }
1132 
1133 ////////////////////////////////////////////////////////////////////////////////
1134 /// Print predefined help message of classifier
1135 /// iterate over methods and test
1136 
1137 void TMVA::Factory::PrintHelpMessage( const TString& methodTitle ) const
1138 {
1139  if (methodTitle != "") {
1140  IMethod* method = GetMethod( methodTitle );
1141  if (method) method->PrintHelpMessage();
1142  else {
1143  Log() << kWARNING << "<PrintHelpMessage> Could not find classifier \"" << methodTitle
1144  << "\" in list" << Endl;
1145  }
1146  }
1147  else {
1148 
1149  // no classifier specified, print all hepl messages
1150  MVector::const_iterator itrMethod = fMethods.begin();
1151  MVector::const_iterator itrMethodEnd = fMethods.end();
1152  for (; itrMethod != itrMethodEnd; itrMethod++) {
1153  MethodBase* method = dynamic_cast<MethodBase*>(*itrMethod);
1154  if(method==0) continue;
1155  Log() << kINFO << "Print help message for classifier: " << method->GetMethodName() << Endl;
1156  method->PrintHelpMessage();
1157  }
1158  }
1159 }
1160 
1161 ////////////////////////////////////////////////////////////////////////////////
1162 /// iterates over all MVA input varables and evaluates them
1163 
1165 {
1166  Log() << kINFO << "Evaluating all variables..." << Endl;
1168 
1169  for (UInt_t i=0; i<DefaultDataSetInfo().GetNVariables(); i++) {
1170  TString s = DefaultDataSetInfo().GetVariableInfo(i).GetLabel();
1171  if (options.Contains("V")) s += ":V";
1172  this->BookMethod( "Variable", s );
1173  }
1174 }
1175 
1176 ////////////////////////////////////////////////////////////////////////////////
1177 /// iterates over all MVAs that have been booked, and calls their evaluation methods
1178 
1180 {
1181  Log() << kINFO << "Evaluate all methods..." << Endl;
1182 
1183  // don't do anything if no method booked
1184  if (fMethods.empty()) {
1185  Log() << kINFO << "...nothing found to evaluate" << Endl;
1186  return;
1187  }
1188 
1189  // -----------------------------------------------------------------------
1190  // First part of evaluation process
1191  // --> compute efficiencies, and other separation estimators
1192  // -----------------------------------------------------------------------
1193 
1194  // although equal, we now want to seperate the outpuf for the variables
1195  // and the real methods
1196  Int_t isel; // will be 0 for a Method; 1 for a Variable
1197  Int_t nmeth_used[2] = {0,0}; // 0 Method; 1 Variable
1198 
1199  std::vector<std::vector<TString> > mname(2);
1200  std::vector<std::vector<Double_t> > sig(2), sep(2), roc(2);
1201  std::vector<std::vector<Double_t> > eff01(2), eff10(2), eff30(2), effArea(2);
1202  std::vector<std::vector<Double_t> > eff01err(2), eff10err(2), eff30err(2);
1203  std::vector<std::vector<Double_t> > trainEff01(2), trainEff10(2), trainEff30(2);
1204 
1205  std::vector<std::vector<Float_t> > multiclass_testEff;
1206  std::vector<std::vector<Float_t> > multiclass_trainEff;
1207  std::vector<std::vector<Float_t> > multiclass_testPur;
1208  std::vector<std::vector<Float_t> > multiclass_trainPur;
1209 
1210  std::vector<std::vector<Double_t> > biastrain(1); // "bias" of the regression on the training data
1211  std::vector<std::vector<Double_t> > biastest(1); // "bias" of the regression on test data
1212  std::vector<std::vector<Double_t> > devtrain(1); // "dev" of the regression on the training data
1213  std::vector<std::vector<Double_t> > devtest(1); // "dev" of the regression on test data
1214  std::vector<std::vector<Double_t> > rmstrain(1); // "rms" of the regression on the training data
1215  std::vector<std::vector<Double_t> > rmstest(1); // "rms" of the regression on test data
1216  std::vector<std::vector<Double_t> > minftrain(1); // "minf" of the regression on the training data
1217  std::vector<std::vector<Double_t> > minftest(1); // "minf" of the regression on test data
1218  std::vector<std::vector<Double_t> > rhotrain(1); // correlation of the regression on the training data
1219  std::vector<std::vector<Double_t> > rhotest(1); // correlation of the regression on test data
1220 
1221  // same as above but for 'truncated' quantities (computed for events within 2sigma of RMS)
1222  std::vector<std::vector<Double_t> > biastrainT(1);
1223  std::vector<std::vector<Double_t> > biastestT(1);
1224  std::vector<std::vector<Double_t> > devtrainT(1);
1225  std::vector<std::vector<Double_t> > devtestT(1);
1226  std::vector<std::vector<Double_t> > rmstrainT(1);
1227  std::vector<std::vector<Double_t> > rmstestT(1);
1228  std::vector<std::vector<Double_t> > minftrainT(1);
1229  std::vector<std::vector<Double_t> > minftestT(1);
1230 
1231  // following vector contains all methods - with the exception of Cuts, which are special
1232  MVector methodsNoCuts;
1233 
1234  Bool_t doRegression = kFALSE;
1235  Bool_t doMulticlass = kFALSE;
1236 
1237  // iterate over methods and evaluate
1238  MVector::iterator itrMethod = fMethods.begin();
1239  MVector::iterator itrMethodEnd = fMethods.end();
1240  for (; itrMethod != itrMethodEnd; itrMethod++) {
1242  MethodBase* theMethod = dynamic_cast<MethodBase*>(*itrMethod);
1243  if(theMethod==0) continue;
1244  if (theMethod->GetMethodType() != Types::kCuts) methodsNoCuts.push_back( *itrMethod );
1245 
1246  if (theMethod->DoRegression()) {
1247  doRegression = kTRUE;
1248 
1249  Log() << kINFO << "Evaluate regression method: " << theMethod->GetMethodName() << Endl;
1250  Double_t bias, dev, rms, mInf;
1251  Double_t biasT, devT, rmsT, mInfT;
1252  Double_t rho;
1253 
1254  theMethod->TestRegression( bias, biasT, dev, devT, rms, rmsT, mInf, mInfT, rho, TMVA::Types::kTesting );
1255  biastest[0] .push_back( bias );
1256  devtest[0] .push_back( dev );
1257  rmstest[0] .push_back( rms );
1258  minftest[0] .push_back( mInf );
1259  rhotest[0] .push_back( rho );
1260  biastestT[0] .push_back( biasT );
1261  devtestT[0] .push_back( devT );
1262  rmstestT[0] .push_back( rmsT );
1263  minftestT[0] .push_back( mInfT );
1264 
1265  theMethod->TestRegression( bias, biasT, dev, devT, rms, rmsT, mInf, mInfT, rho, TMVA::Types::kTraining );
1266  biastrain[0] .push_back( bias );
1267  devtrain[0] .push_back( dev );
1268  rmstrain[0] .push_back( rms );
1269  minftrain[0] .push_back( mInf );
1270  rhotrain[0] .push_back( rho );
1271  biastrainT[0].push_back( biasT );
1272  devtrainT[0] .push_back( devT );
1273  rmstrainT[0] .push_back( rmsT );
1274  minftrainT[0].push_back( mInfT );
1275 
1276  mname[0].push_back( theMethod->GetMethodName() );
1277  nmeth_used[0]++;
1278 
1279  Log() << kINFO << "Write evaluation histograms to file" << Endl;
1282  }
1283  else if (theMethod->DoMulticlass()) {
1284  doMulticlass = kTRUE;
1285  Log() << kINFO << "Evaluate multiclass classification method: " << theMethod->GetMethodName() << Endl;
1286  Log() << kINFO << "Write evaluation histograms to file" << Endl;
1289 
1290  theMethod->TestMulticlass();
1291  multiclass_testEff.push_back(theMethod->GetMulticlassEfficiency(multiclass_testPur));
1292 
1293  nmeth_used[0]++;
1294  mname[0].push_back( theMethod->GetMethodName() );
1295  }
1296  else {
1297 
1298  Log() << kINFO << "Evaluate classifier: " << theMethod->GetMethodName() << Endl;
1299  isel = (theMethod->GetMethodTypeName().Contains("Variable")) ? 1 : 0;
1300 
1301  // perform the evaluation
1302  theMethod->TestClassification();
1303 
1304  // evaluate the classifier
1305  mname[isel].push_back( theMethod->GetMethodName() );
1306  sig[isel].push_back ( theMethod->GetSignificance() );
1307  sep[isel].push_back ( theMethod->GetSeparation() );
1308  roc[isel].push_back ( theMethod->GetROCIntegral() );
1309 
1310  Double_t err;
1311  eff01[isel].push_back( theMethod->GetEfficiency("Efficiency:0.01", Types::kTesting, err) );
1312  eff01err[isel].push_back( err );
1313  eff10[isel].push_back( theMethod->GetEfficiency("Efficiency:0.10", Types::kTesting, err) );
1314  eff10err[isel].push_back( err );
1315  eff30[isel].push_back( theMethod->GetEfficiency("Efficiency:0.30", Types::kTesting, err) );
1316  eff30err[isel].push_back( err );
1317  effArea[isel].push_back( theMethod->GetEfficiency("", Types::kTesting, err) ); // computes the area (average)
1318 
1319  trainEff01[isel].push_back( theMethod->GetTrainingEfficiency("Efficiency:0.01") ); // the first pass takes longer
1320  trainEff10[isel].push_back( theMethod->GetTrainingEfficiency("Efficiency:0.10") );
1321  trainEff30[isel].push_back( theMethod->GetTrainingEfficiency("Efficiency:0.30") );
1322 
1323  nmeth_used[isel]++;
1324 
1325  Log() << kINFO << "Write evaluation histograms to file" << Endl;
1328  }
1329  }
1330  if (doRegression) {
1331 
1332  std::vector<TString> vtemps = mname[0];
1333  std::vector< std::vector<Double_t> > vtmp;
1334  vtmp.push_back( devtest[0] ); // this is the vector that is ranked
1335  vtmp.push_back( devtrain[0] );
1336  vtmp.push_back( biastest[0] );
1337  vtmp.push_back( biastrain[0] );
1338  vtmp.push_back( rmstest[0] );
1339  vtmp.push_back( rmstrain[0] );
1340  vtmp.push_back( minftest[0] );
1341  vtmp.push_back( minftrain[0] );
1342  vtmp.push_back( rhotest[0] );
1343  vtmp.push_back( rhotrain[0] );
1344  vtmp.push_back( devtestT[0] ); // this is the vector that is ranked
1345  vtmp.push_back( devtrainT[0] );
1346  vtmp.push_back( biastestT[0] );
1347  vtmp.push_back( biastrainT[0]);
1348  vtmp.push_back( rmstestT[0] );
1349  vtmp.push_back( rmstrainT[0] );
1350  vtmp.push_back( minftestT[0] );
1351  vtmp.push_back( minftrainT[0]);
1352  gTools().UsefulSortAscending( vtmp, &vtemps );
1353  mname[0] = vtemps;
1354  devtest[0] = vtmp[0];
1355  devtrain[0] = vtmp[1];
1356  biastest[0] = vtmp[2];
1357  biastrain[0] = vtmp[3];
1358  rmstest[0] = vtmp[4];
1359  rmstrain[0] = vtmp[5];
1360  minftest[0] = vtmp[6];
1361  minftrain[0] = vtmp[7];
1362  rhotest[0] = vtmp[8];
1363  rhotrain[0] = vtmp[9];
1364  devtestT[0] = vtmp[10];
1365  devtrainT[0] = vtmp[11];
1366  biastestT[0] = vtmp[12];
1367  biastrainT[0] = vtmp[13];
1368  rmstestT[0] = vtmp[14];
1369  rmstrainT[0] = vtmp[15];
1370  minftestT[0] = vtmp[16];
1371  minftrainT[0] = vtmp[17];
1372  }
1373  else if (doMulticlass) {
1374  // TODO: fill in something meaningfull
1375 
1376  }
1377  else {
1378  // now sort the variables according to the best 'eff at Beff=0.10'
1379  for (Int_t k=0; k<2; k++) {
1380  std::vector< std::vector<Double_t> > vtemp;
1381  vtemp.push_back( effArea[k] ); // this is the vector that is ranked
1382  vtemp.push_back( eff10[k] );
1383  vtemp.push_back( eff01[k] );
1384  vtemp.push_back( eff30[k] );
1385  vtemp.push_back( eff10err[k] );
1386  vtemp.push_back( eff01err[k] );
1387  vtemp.push_back( eff30err[k] );
1388  vtemp.push_back( trainEff10[k] );
1389  vtemp.push_back( trainEff01[k] );
1390  vtemp.push_back( trainEff30[k] );
1391  vtemp.push_back( sig[k] );
1392  vtemp.push_back( sep[k] );
1393  vtemp.push_back( roc[k] );
1394  std::vector<TString> vtemps = mname[k];
1395  gTools().UsefulSortDescending( vtemp, &vtemps );
1396  effArea[k] = vtemp[0];
1397  eff10[k] = vtemp[1];
1398  eff01[k] = vtemp[2];
1399  eff30[k] = vtemp[3];
1400  eff10err[k] = vtemp[4];
1401  eff01err[k] = vtemp[5];
1402  eff30err[k] = vtemp[6];
1403  trainEff10[k] = vtemp[7];
1404  trainEff01[k] = vtemp[8];
1405  trainEff30[k] = vtemp[9];
1406  sig[k] = vtemp[10];
1407  sep[k] = vtemp[11];
1408  roc[k] = vtemp[12];
1409  mname[k] = vtemps;
1410  }
1411  }
1412 
1413  // -----------------------------------------------------------------------
1414  // Second part of evaluation process
1415  // --> compute correlations among MVAs
1416  // --> compute correlations between input variables and MVA (determines importsance)
1417  // --> count overlaps
1418  // -----------------------------------------------------------------------
1419 
1420  const Int_t nmeth = methodsNoCuts.size();
1421  const Int_t nvar = DefaultDataSetInfo().GetNVariables();
1422  if (!doRegression && !doMulticlass ) {
1423 
1424  if (nmeth > 0) {
1425 
1426  // needed for correlations
1427  Double_t *dvec = new Double_t[nmeth+nvar];
1428  std::vector<Double_t> rvec;
1429 
1430  // for correlations
1431  TPrincipal* tpSig = new TPrincipal( nmeth+nvar, "" );
1432  TPrincipal* tpBkg = new TPrincipal( nmeth+nvar, "" );
1433 
1434  // set required tree branch references
1435  Int_t ivar = 0;
1436  std::vector<TString>* theVars = new std::vector<TString>;
1437  std::vector<ResultsClassification*> mvaRes;
1438  for (itrMethod = methodsNoCuts.begin(); itrMethod != methodsNoCuts.end(); itrMethod++, ivar++) {
1439  MethodBase* m = dynamic_cast<MethodBase*>(*itrMethod);
1440  if(m==0) continue;
1441  theVars->push_back( m->GetTestvarName() );
1442  rvec.push_back( m->GetSignalReferenceCut() );
1443  theVars->back().ReplaceAll( "MVA_", "" );
1444  mvaRes.push_back( dynamic_cast<ResultsClassification*>( m->Data()->GetResults( m->GetMethodName(),
1445  Types::kTesting,
1447  }
1448 
1449  // for overlap study
1450  TMatrixD* overlapS = new TMatrixD( nmeth, nmeth );
1451  TMatrixD* overlapB = new TMatrixD( nmeth, nmeth );
1452  (*overlapS) *= 0; // init...
1453  (*overlapB) *= 0; // init...
1454 
1455  // loop over test tree
1456  DataSet* defDs = DefaultDataSetInfo().GetDataSet();
1458  for (Int_t ievt=0; ievt<defDs->GetNEvents(); ievt++) {
1459  const Event* ev = defDs->GetEvent(ievt);
1460 
1461  // for correlations
1462  TMatrixD* theMat = 0;
1463  for (Int_t im=0; im<nmeth; im++) {
1464  // check for NaN value
1465  Double_t retval = (Double_t)(*mvaRes[im])[ievt][0];
1466  if (TMath::IsNaN(retval)) {
1467  Log() << kWARNING << "Found NaN return value in event: " << ievt
1468  << " for method \"" << methodsNoCuts[im]->GetName() << "\"" << Endl;
1469  dvec[im] = 0;
1470  }
1471  else dvec[im] = retval;
1472  }
1473  for (Int_t iv=0; iv<nvar; iv++) dvec[iv+nmeth] = (Double_t)ev->GetValue(iv);
1474  if (DefaultDataSetInfo().IsSignal(ev)) { tpSig->AddRow( dvec ); theMat = overlapS; }
1475  else { tpBkg->AddRow( dvec ); theMat = overlapB; }
1476 
1477  // count overlaps
1478  for (Int_t im=0; im<nmeth; im++) {
1479  for (Int_t jm=im; jm<nmeth; jm++) {
1480  if ((dvec[im] - rvec[im])*(dvec[jm] - rvec[jm]) > 0) {
1481  (*theMat)(im,jm)++;
1482  if (im != jm) (*theMat)(jm,im)++;
1483  }
1484  }
1485  }
1486  }
1487 
1488  // renormalise overlap matrix
1489  (*overlapS) *= (1.0/defDs->GetNEvtSigTest()); // init...
1490  (*overlapB) *= (1.0/defDs->GetNEvtBkgdTest()); // init...
1491 
1492  tpSig->MakePrincipals();
1493  tpBkg->MakePrincipals();
1494 
1495  const TMatrixD* covMatS = tpSig->GetCovarianceMatrix();
1496  const TMatrixD* covMatB = tpBkg->GetCovarianceMatrix();
1497 
1498  const TMatrixD* corrMatS = gTools().GetCorrelationMatrix( covMatS );
1499  const TMatrixD* corrMatB = gTools().GetCorrelationMatrix( covMatB );
1500 
1501  // print correlation matrices
1502  if (corrMatS != 0 && corrMatB != 0) {
1503 
1504  // extract MVA matrix
1505  TMatrixD mvaMatS(nmeth,nmeth);
1506  TMatrixD mvaMatB(nmeth,nmeth);
1507  for (Int_t im=0; im<nmeth; im++) {
1508  for (Int_t jm=0; jm<nmeth; jm++) {
1509  mvaMatS(im,jm) = (*corrMatS)(im,jm);
1510  mvaMatB(im,jm) = (*corrMatB)(im,jm);
1511  }
1512  }
1513 
1514  // extract variables - to MVA matrix
1515  std::vector<TString> theInputVars;
1516  TMatrixD varmvaMatS(nvar,nmeth);
1517  TMatrixD varmvaMatB(nvar,nmeth);
1518  for (Int_t iv=0; iv<nvar; iv++) {
1519  theInputVars.push_back( DefaultDataSetInfo().GetVariableInfo( iv ).GetLabel() );
1520  for (Int_t jm=0; jm<nmeth; jm++) {
1521  varmvaMatS(iv,jm) = (*corrMatS)(nmeth+iv,jm);
1522  varmvaMatB(iv,jm) = (*corrMatB)(nmeth+iv,jm);
1523  }
1524  }
1525 
1526  if (nmeth > 1) {
1527  Log() << kINFO << Endl;
1528  Log() << kINFO << "Inter-MVA correlation matrix (signal):" << Endl;
1529  gTools().FormattedOutput( mvaMatS, *theVars, Log() );
1530  Log() << kINFO << Endl;
1531 
1532  Log() << kINFO << "Inter-MVA correlation matrix (background):" << Endl;
1533  gTools().FormattedOutput( mvaMatB, *theVars, Log() );
1534  Log() << kINFO << Endl;
1535  }
1536 
1537  Log() << kINFO << "Correlations between input variables and MVA response (signal):" << Endl;
1538  gTools().FormattedOutput( varmvaMatS, theInputVars, *theVars, Log() );
1539  Log() << kINFO << Endl;
1540 
1541  Log() << kINFO << "Correlations between input variables and MVA response (background):" << Endl;
1542  gTools().FormattedOutput( varmvaMatB, theInputVars, *theVars, Log() );
1543  Log() << kINFO << Endl;
1544  }
1545  else Log() << kWARNING << "<TestAllMethods> cannot compute correlation matrices" << Endl;
1546 
1547  // print overlap matrices
1548  Log() << kINFO << "The following \"overlap\" matrices contain the fraction of events for which " << Endl;
1549  Log() << kINFO << "the MVAs 'i' and 'j' have returned conform answers about \"signal-likeness\"" << Endl;
1550  Log() << kINFO << "An event is signal-like, if its MVA output exceeds the following value:" << Endl;
1551  gTools().FormattedOutput( rvec, *theVars, "Method" , "Cut value", Log() );
1552  Log() << kINFO << "which correspond to the working point: eff(signal) = 1 - eff(background)" << Endl;
1553 
1554  // give notice that cut method has been excluded from this test
1555  if (nmeth != (Int_t)fMethods.size())
1556  Log() << kINFO << "Note: no correlations and overlap with cut method are provided at present" << Endl;
1557 
1558  if (nmeth > 1) {
1559  Log() << kINFO << Endl;
1560  Log() << kINFO << "Inter-MVA overlap matrix (signal):" << Endl;
1561  gTools().FormattedOutput( *overlapS, *theVars, Log() );
1562  Log() << kINFO << Endl;
1563 
1564  Log() << kINFO << "Inter-MVA overlap matrix (background):" << Endl;
1565  gTools().FormattedOutput( *overlapB, *theVars, Log() );
1566  }
1567 
1568  // cleanup
1569  delete tpSig;
1570  delete tpBkg;
1571  delete corrMatS;
1572  delete corrMatB;
1573  delete theVars;
1574  delete overlapS;
1575  delete overlapB;
1576  delete [] dvec;
1577  }
1578  }
1579 
1580  // -----------------------------------------------------------------------
1581  // Third part of evaluation process
1582  // --> output
1583  // -----------------------------------------------------------------------
1584 
1585  if (doRegression) {
1586 
1587  Log() << kINFO << Endl;
1588  TString hLine = "-------------------------------------------------------------------------";
1589  Log() << kINFO << "Evaluation results ranked by smallest RMS on test sample:" << Endl;
1590  Log() << kINFO << "(\"Bias\" quotes the mean deviation of the regression from true target." << Endl;
1591  Log() << kINFO << " \"MutInf\" is the \"Mutual Information\" between regression and target." << Endl;
1592  Log() << kINFO << " Indicated by \"_T\" are the corresponding \"truncated\" quantities ob-" << Endl;
1593  Log() << kINFO << " tained when removing events deviating more than 2sigma from average.)" << Endl;
1594  Log() << kINFO << hLine << Endl;
1595  Log() << kINFO << "MVA Method: <Bias> <Bias_T> RMS RMS_T | MutInf MutInf_T" << Endl;
1596  Log() << kINFO << hLine << Endl;
1597 
1598  for (Int_t i=0; i<nmeth_used[0]; i++) {
1599  Log() << kINFO << Form("%-15s:%#9.3g%#9.3g%#9.3g%#9.3g | %#5.3f %#5.3f",
1600  (const char*)mname[0][i],
1601  biastest[0][i], biastestT[0][i],
1602  rmstest[0][i], rmstestT[0][i],
1603  minftest[0][i], minftestT[0][i] )
1604  << Endl;
1605  }
1606  Log() << kINFO << hLine << Endl;
1607  Log() << kINFO << Endl;
1608  Log() << kINFO << "Evaluation results ranked by smallest RMS on training sample:" << Endl;
1609  Log() << kINFO << "(overtraining check)" << Endl;
1610  Log() << kINFO << hLine << Endl;
1611  Log() << kINFO << "MVA Method: <Bias> <Bias_T> RMS RMS_T | MutInf MutInf_T" << Endl;
1612  Log() << kINFO << hLine << Endl;
1613 
1614  for (Int_t i=0; i<nmeth_used[0]; i++) {
1615  Log() << kINFO << Form("%-15s:%#9.3g%#9.3g%#9.3g%#9.3g | %#5.3f %#5.3f",
1616  (const char*)mname[0][i],
1617  biastrain[0][i], biastrainT[0][i],
1618  rmstrain[0][i], rmstrainT[0][i],
1619  minftrain[0][i], minftrainT[0][i] )
1620  << Endl;
1621  }
1622  Log() << kINFO << hLine << Endl;
1623  Log() << kINFO << Endl;
1624  }
1625  else if( doMulticlass ){
1626  Log() << Endl;
1627  TString hLine = "--------------------------------------------------------------------------------";
1628  Log() << kINFO << "Evaluation results ranked by best signal efficiency times signal purity " << Endl;
1629  Log() << kINFO << hLine << Endl;
1630  TString header= "MVA Method ";
1631  for(UInt_t icls = 0; icls<DefaultDataSetInfo().GetNClasses(); ++icls){
1632  header += Form("%-12s ",DefaultDataSetInfo().GetClassInfo(icls)->GetName().Data());
1633  }
1634  Log() << kINFO << header << Endl;
1635  Log() << kINFO << hLine << Endl;
1636  for (Int_t i=0; i<nmeth_used[0]; i++) {
1637  TString res = Form("%-15s",(const char*)mname[0][i]);
1638  for(UInt_t icls = 0; icls<DefaultDataSetInfo().GetNClasses(); ++icls){
1639  res += Form("%#1.3f ",(multiclass_testEff[i][icls])*(multiclass_testPur[i][icls]));
1640  }
1641  Log() << kINFO << res << Endl;
1642  }
1643  Log() << kINFO << hLine << Endl;
1644  Log() << kINFO << Endl;
1645 
1646  }
1647  else {
1648  Log() << Endl;
1649  TString hLine = "--------------------------------------------------------------------------------";
1650  Log() << kINFO << "Evaluation results ranked by best signal efficiency and purity (area)" << Endl;
1651  Log() << kINFO << hLine << Endl;
1652  Log() << kINFO << "MVA Signal efficiency at bkg eff.(error): | Sepa- Signifi- " << Endl;
1653  Log() << kINFO << "Method: @B=0.01 @B=0.10 @B=0.30 ROC-integ. | ration: cance: " << Endl;
1654  Log() << kINFO << hLine << Endl;
1655  for (Int_t k=0; k<2; k++) {
1656  if (k == 1 && nmeth_used[k] > 0) {
1657  Log() << kINFO << hLine << Endl;
1658  Log() << kINFO << "Input Variables: " << Endl << hLine << Endl;
1659  }
1660  for (Int_t i=0; i<nmeth_used[k]; i++) {
1661  if (k == 1) mname[k][i].ReplaceAll( "Variable_", "" );
1662  if (sep[k][i] < 0 || sig[k][i] < 0) {
1663  // cannot compute separation/significance -> no MVA (usually for Cuts)
1664  Log() << kINFO << Form("%-15s: %#1.3f(%02i) %#1.3f(%02i) %#1.3f(%02i) %#1.3f | -- --",
1665  (const char*)mname[k][i],
1666  eff01[k][i], Int_t(1000*eff01err[k][i]),
1667  eff10[k][i], Int_t(1000*eff10err[k][i]),
1668  eff30[k][i], Int_t(1000*eff30err[k][i]),
1669  effArea[k][i]) << Endl;
1670  }
1671  else {
1672  Log() << kINFO << Form("%-15s: %#1.3f(%02i) %#1.3f(%02i) %#1.3f(%02i) %#1.3f | %#1.3f %#1.3f",
1673  (const char*)mname[k][i],
1674  eff01[k][i], Int_t(1000*eff01err[k][i]),
1675  eff10[k][i], Int_t(1000*eff10err[k][i]),
1676  eff30[k][i], Int_t(1000*eff30err[k][i]),
1677  effArea[k][i],
1678  sep[k][i], sig[k][i]) << Endl;
1679  }
1680  }
1681  }
1682  Log() << kINFO << hLine << Endl;
1683  Log() << kINFO << Endl;
1684  Log() << kINFO << "Testing efficiency compared to training efficiency (overtraining check)" << Endl;
1685  Log() << kINFO << hLine << Endl;
1686  Log() << kINFO << "MVA Signal efficiency: from test sample (from training sample) " << Endl;
1687  Log() << kINFO << "Method: @B=0.01 @B=0.10 @B=0.30 " << Endl;
1688  Log() << kINFO << hLine << Endl;
1689  for (Int_t k=0; k<2; k++) {
1690  if (k == 1 && nmeth_used[k] > 0) {
1691  Log() << kINFO << hLine << Endl;
1692  Log() << kINFO << "Input Variables: " << Endl << hLine << Endl;
1693  }
1694  for (Int_t i=0; i<nmeth_used[k]; i++) {
1695  if (k == 1) mname[k][i].ReplaceAll( "Variable_", "" );
1696  Log() << kINFO << Form("%-15s: %#1.3f (%#1.3f) %#1.3f (%#1.3f) %#1.3f (%#1.3f)",
1697  (const char*)mname[k][i],
1698  eff01[k][i],trainEff01[k][i],
1699  eff10[k][i],trainEff10[k][i],
1700  eff30[k][i],trainEff30[k][i]) << Endl;
1701  }
1702  }
1703  Log() << kINFO << hLine << Endl;
1704  Log() << kINFO << Endl;
1705  }
1706 
1707  // write test tree
1708  RootBaseDir()->cd();
1709  DefaultDataSetInfo().GetDataSet()->GetTree(Types::kTesting) ->Write( "", TObject::kOverwrite );
1710  DefaultDataSetInfo().GetDataSet()->GetTree(Types::kTraining)->Write( "", TObject::kOverwrite );
1711 
1712  // references for citation
1714 }
1715 
IMethod * Create(const std::string &name, const TString &job, const TString &title, DataSetInfo &dsi, const TString &option)
creates the method if needed based on the method name using the creator function the factory has stor...
static ClassifierFactory & Instance()
access to the ClassifierFactory singleton creates the instance if needed
virtual Int_t Write(const char *name=0, Int_t option=0, Int_t bufsize=0)
Write this object to the current directory.
Definition: TObject.cxx:823
Principal Components Analysis (PCA)
Definition: TPrincipal.h:28
void AddSignalTrainingEvent(const std::vector< Double_t > &event, Double_t weight=1.0)
add signal training event
Definition: Factory.cxx:278
void UsefulSortDescending(std::vector< std::vector< Double_t > > &, std::vector< TString > *vs=0)
sort 2D vector (AND in parallel a TString vector) in such a way that the "first vector is sorted" and...
Definition: Tools.cxx:573
void SetInputTrees(const TString &signalFileName, const TString &backgroundFileName, Double_t signalWeight=1.0, Double_t backgroundWeight=1.0)
Definition: Factory.cxx:513
virtual void MakeClass(const TString &methodTitle="") const
Print predefined help message of classifier iterate over methods and test.
Definition: Factory.cxx:1109
void OptimizeAllMethods(TString fomType="ROCIntegral", TString fitType="FitGA")
iterates through all booked methods and sees if they use parameter tuning and if so.
Definition: Factory.cxx:925
static Vc_ALWAYS_INLINE int_v min(const int_v &x, const int_v &y)
Definition: vector.h:433
MsgLogger & Endl(MsgLogger &ml)
Definition: MsgLogger.h:162
void AddOutput(Types::ETreeType type, Types::EAnalysisType analysisType)
static void CreateVariableTransforms(const TString &trafoDefinition, TMVA::DataSetInfo &dataInfo, TMVA::TransformationHandler &transformationHandler, TMVA::MsgLogger &log)
create variable transformations
Definition: MethodBase.cxx:481
float Float_t
Definition: RtypesCore.h:53
void ROOTVersionMessage(MsgLogger &logger)
prints the ROOT release number and date
Definition: Tools.cxx:1333
TString & ReplaceAll(const TString &s1, const TString &s2)
Definition: TString.h:635
const char * GetName() const
Returns name of object.
Definition: MethodBase.h:299
virtual void WriteEvaluationHistosToFile(Types::ETreeType treetype)
writes all MVA evaluation histograms to file
void SetSignalWeightExpression(const TString &variable)
Definition: Factory.cxx:589
virtual std::map< TString, Double_t > OptimizeTuningParameters(TString fomType="ROCIntegral", TString fitType="FitGA")
call the Optimizer with the set of parameters and ranges that are meant to be tuned.
Definition: MethodBase.cxx:626
Config & gConfig()
void SetInputVariables(std::vector< TString > *theVariables)
fill input variables in data set
Definition: Factory.cxx:581
TH1 * h
Definition: legend2.C:5
OptionBase * DeclareOptionRef(T &ref, const TString &name, const TString &desc="")
A ROOT file is a suite of consecutive data records (TKey instances) with a well defined format...
Definition: TFile.h:45
DataSet * Data() const
Definition: MethodBase.h:363
Double_t background(Double_t *x, Double_t *par)
EAnalysisType
Definition: Types.h:124
virtual void MakeClass(const TString &classFileName=TString("")) const =0
void AddBackgroundTrainingEvent(const std::vector< Double_t > &event, Double_t weight=1.0)
add background training event
Definition: Factory.cxx:294
#define gROOT
Definition: TROOT.h:340
void AddSpectator(const TString &expression, const TString &title="", const TString &unit="", Double_t min=0, Double_t max=0)
user inserts spectator in data set info
Definition: Factory.cxx:564
Basic string class.
Definition: TString.h:137
void ToLower()
Change string to lower-case.
Definition: TString.cxx:1088
int Int_t
Definition: RtypesCore.h:41
bool Bool_t
Definition: RtypesCore.h:59
void AddVariable(const TString &expression, const TString &title, const TString &unit, char type='F', Double_t min=0, Double_t max=0)
user inserts discriminating variable in data set info
Definition: Factory.cxx:534
void TrainAllMethods()
iterates through all booked methods and calls training
Definition: Factory.cxx:959
const Bool_t kFALSE
Definition: Rtypes.h:92
virtual void TestMulticlass()
test multiclass classification
DataSetInfo & DefaultDataSetInfo()
default creation
Definition: Factory.cxx:573
virtual Bool_t HasAnalysisType(Types::EAnalysisType type, UInt_t numberClasses, UInt_t numberTargets)=0
void AddTrainingEvent(const TString &className, const std::vector< Double_t > &event, Double_t weight)
add training event for the given class
Definition: Factory.cxx:310
Bool_t BeginsWith(const char *s, ECaseCompare cmp=kExact) const
Definition: TString.h:558
void AddTarget(const TString &expression, const TString &title="", const TString &unit="", Double_t min=0, Double_t max=0)
user inserts target in data set info
Definition: Factory.cxx:552
void SetSignalTree(TTree *signal, Double_t weight=1.0)
Definition: Factory.cxx:481
void WriteDataInformation()
put correlations of input data and a few (default + user selected) transformations into the root file...
Definition: Factory.cxx:831
const TString & GetMethodName() const
Definition: MethodBase.h:296
TString GetWeightFileName() const
retrieve weight file name
Float_t GetValue(UInt_t ivar) const
return value of i'th variable
Definition: Event.cxx:231
void AddTestEvent(const TString &className, const std::vector< Double_t > &event, Double_t weight)
add test event for the given class
Definition: Factory.cxx:318
void AddBackgroundTree(TTree *background, Double_t weight=1.0, Types::ETreeType treetype=Types::kMaxTreeType)
add background tree to the data set
Definition: Factory.cxx:452
#define RECREATE_METHODS
const char * Data() const
Definition: TString.h:349
void AddEvent(const TString &className, Types::ETreeType tt, const std::vector< Double_t > &event, Double_t weight)
add event vector event : the order of values is: variables + targets + spectators ...
Definition: Factory.cxx:327
static Types & Instance()
return the single instance of "Types" if existing already, or create it (Singleton)
Definition: Types.cxx:61
Types::EAnalysisType GetAnalysisType() const
Definition: MethodBase.h:391
static void SetIsTraining(Bool_t)
when this static function is called, it sets the flag whether events with negative event weight shoul...
Definition: Event.cxx:386
Tools & gTools()
Definition: Tools.cxx:79
#define READXML
TText * tt
Definition: textangle.C:16
IMethod * GetMethod(const TString &title) const
returns pointer to MVA that corresponds to given method title
Definition: Factory.cxx:815
ClassImp(TMVA::Factory) TMVA
standard constructor jobname : this name will appear in all weight file names produced by the MVAs th...
Definition: Factory.cxx:84
void ReadStateFromFile()
Function to read options and weights from file.
virtual void MakeClass(const TString &classFileName=TString("")) const
create reader class for method (classification only at present)
void EvaluateAllVariables(TString options="")
iterates over all MVA input variables and evaluates them
Definition: Factory.cxx:1164
Bool_t DoMulticlass() const
Definition: MethodBase.h:393
std::vector< std::vector< double > > Data
virtual void ParseOptions()
options parser
void SetupMethod()
setup of methods
Definition: MethodBase.cxx:295
TTree * CreateEventAssignTrees(const TString &name)
create the data assignment tree (for event-wise data assignment by user)
Definition: Factory.cxx:245
void SetDrawProgressBar(Bool_t d)
Definition: Config.h:70
Types::EMVA GetMethodType() const
Definition: MethodBase.h:298
virtual Double_t GetEfficiency(const TString &, Types::ETreeType, Double_t &err)
fill background efficiency (resp.
virtual std::vector< Float_t > GetMulticlassEfficiency(std::vector< std::vector< Float_t > > &purity)
TCppMethod_t GetMethod(TCppScope_t scope, TCppIndex_t imeth)
Definition: Cppyy.cxx:701
A specialized string object used for TTree selections.
Definition: TCut.h:27
const TMatrixD * GetCovarianceMatrix() const
Definition: TPrincipal.h:66
TMatrixT< Double_t > TMatrixD
Definition: TMatrixDfwd.h:24
static void DestroyInstance()
Definition: Tools.cxx:95
Bool_t UserAssignEvents(UInt_t clIndex)
Definition: Factory.cxx:361
const Int_t MinNoTrainingEvents
Definition: Factory.cxx:80
virtual ~Factory()
destructor (deletes fATreeEvent)
Definition: Factory.cxx:182
void AddBackgroundTestEvent(const std::vector< Double_t > &event, Double_t weight=1.0)
add background test event
Definition: Factory.cxx:302
void SetCut(const TString &cut, const TString &className="")
Definition: Factory.cxx:615
DataSetInfo & AddDataSet(DataSetInfo &)
Definition: Factory.cxx:223
Results * GetResults(const TString &, Types::ETreeType type, Types::EAnalysisType analysistype)
TString info(resultsName+"/"); switch(type) { case Types::kTraining: info += "kTraining/"; break; cas...
Definition: DataSet.cxx:257
std::string GetMethodName(TCppMethod_t)
Definition: Cppyy.cxx:707
Service class for 2-Dim histogram classes.
Definition: TH2.h:36
MethodBase * BookMethod(TString theMethodName, TString methodTitle, TString theOption="")
Book a classifier or regression method.
Definition: Factory.cxx:700
virtual void AddRow(const Double_t *x)
Begin_Html.
Definition: TPrincipal.cxx:372
SVector< double, 2 > v
Definition: Dict.h:5
void TMVAWelcomeMessage()
direct output, eg, when starting ROOT session -> no use of Logger here
Definition: Tools.cxx:1310
Long64_t GetNEvtSigTest()
return number of signal test events in dataset
Definition: DataSet.cxx:390
void EvaluateAllMethods(void)
iterates over all MVAs that have been booked, and calls their evaluation methods
Definition: Factory.cxx:1179
void TestAllMethods()
Definition: Factory.cxx:1079
unsigned int UInt_t
Definition: RtypesCore.h:42
TMarker * m
Definition: textangle.C:8
char * Form(const char *fmt,...)
DataSetManager * fDataSetManager
void AddCut(const TString &cut, const TString &className="")
Definition: Factory.cxx:628
virtual const char * GetName() const
Returns name of object.
Definition: TNamed.h:51
void SetBackgroundWeightExpression(const TString &variable)
Definition: Factory.cxx:596
virtual Double_t GetSignificance() const
compute significance of mean difference: significance = |&lt;S&gt; - &lt;B&gt;| / Sqrt(RMS_S^2 + RMS_B^2) ...
void Greetings()
print welcome message options are: kLogoWelcomeMsg, kIsometricWelcomeMsg, kLeanWelcomeMsg ...
Definition: Factory.cxx:171
virtual void MakePrincipals()
Perform the principal components analysis.
Definition: TPrincipal.cxx:950
void SetBoostedMethodName(TString methodName)
Definition: MethodBoost.h:87
void SetVerbose(Bool_t v=kTRUE)
Definition: Factory.cxx:216
Long64_t GetNEvtBkgdTest()
return number of background test events in dataset
Definition: DataSet.cxx:398
const Event * GetEvent() const
Definition: DataSet.cxx:180
void SetInputTreesFromEventAssignTrees()
assign event-wise local trees to data set
Definition: Factory.cxx:369
void PrintHelpMessage(const TString &methodTitle="") const
Print predefined help message of classifier iterate over methods and test.
Definition: Factory.cxx:1137
void SetCurrentType(Types::ETreeType type) const
Definition: DataSet.h:111
virtual void SetDirectory(TDirectory *dir)
Change the tree's directory.
Definition: TTree.cxx:8095
virtual void CheckSetup()
check may be overridden by derived class (sometimes, eg, fitters are used which can only be implement...
Definition: MethodBase.cxx:320
double Double_t
Definition: RtypesCore.h:55
virtual void PrintHelpMessage() const =0
virtual Double_t GetROCIntegral(TH1D *histS, TH1D *histB) const
calculate the area (integral) under the ROC curve as a overall quality measure of the classification ...
const TMatrixD * GetCorrelationMatrix(const TMatrixD *covMat)
turns covariance into correlation matrix
Definition: Tools.cxx:337
void AddTree(TTree *tree, const TString &className, Double_t weight=1.0, const TCut &cut="", Types::ETreeType tt=Types::kMaxTreeType)
Definition: Factory.cxx:401
int type
Definition: TGX11.cxx:120
Long64_t GetNEvents(Types::ETreeType type=Types::kMaxTreeType) const
Definition: DataSet.h:225
static void DestroyInstance()
static function: destroy TMVA instance
Definition: Config.cxx:78
MsgLogger & Log() const
Definition: Configurable.h:130
DataSetInfo & DataInfo() const
Definition: MethodBase.h:364
void UsefulSortAscending(std::vector< std::vector< Double_t > > &, std::vector< TString > *vs=0)
sort 2D vector (AND in parallel a TString vector) in such a way that the "first vector is sorted" and...
Definition: Tools.cxx:547
RooCmdArg Verbose(Bool_t flag=kTRUE)
void TMVAVersionMessage(MsgLogger &logger)
prints the TMVA release number and date
Definition: Tools.cxx:1324
void ProcessSetup()
process all options the "CheckForUnusedOptions" is done in an independent call, since it may be overr...
Definition: MethodBase.cxx:310
static Vc_ALWAYS_INLINE int_v max(const int_v &x, const int_v &y)
Definition: vector.h:440
#define name(a, b)
Definition: linkTestLib0.cpp:5
void PrintVariableRanking() const
prints ranking of input variables
void FormattedOutput(const std::vector< Double_t > &, const std::vector< TString > &, const TString titleVars, const TString titleValues, MsgLogger &logger, TString format="%+1.3f")
formatted output of simple table
Definition: Tools.cxx:896
const TString & Color(const TString &)
human readable color strings
Definition: Tools.cxx:837
void AddSignalTree(TTree *signal, Double_t weight=1.0, Types::ETreeType treetype=Types::kMaxTreeType)
add signal tree to the data set
Definition: Factory.cxx:421
virtual Long64_t ReadFile(const char *filename, const char *branchDescriptor="", char delimiter= ' ')
Create or simply read branches from filename.
Definition: TTree.cxx:6820
void SetUseColor(Bool_t uc)
Definition: Config.h:61
virtual Int_t Branch(TCollection *list, Int_t bufsize=32000, Int_t splitlevel=99, const char *name="")
Create one branch for each element in the collection.
Definition: TTree.cxx:1624
Bool_t DoRegression() const
Definition: MethodBase.h:392
virtual Double_t GetSeparation(TH1 *, TH1 *) const
compute "separation" defined as &lt;s2&gt; = (1/2) Int_-oo..+oo { (S(x) - B(x))^2/(S(x) + B(x)) dx } ...
virtual void TestRegression(Double_t &bias, Double_t &biasT, Double_t &dev, Double_t &devT, Double_t &rms, Double_t &rmsT, Double_t &mInf, Double_t &mInfT, Double_t &corr, Types::ETreeType type)
calculate of regression output versus "true" value from test sample ...
Definition: MethodBase.cxx:931
DataSetManager * fDataSetManager
Definition: MethodBoost.h:194
void SetWeightExpression(const TString &variable, const TString &className="")
set the weight expression for events of the given class
Definition: Factory.cxx:604
Factory(TString theJobName, TFile *theTargetFile, TString theOption="")
void PrintHelpMessage() const
prints out method-specific help method
void SetSilent(Bool_t s)
Definition: Config.h:64
Bool_t Contains(const char *pat, ECaseCompare cmp=kExact) const
Definition: TString.h:567
Int_t IsNaN(Double_t x)
Definition: TMath.h:617
void DeleteAllMethods(void)
delete methods
Definition: Factory.cxx:204
#define NULL
Definition: Rtypes.h:82
virtual Double_t GetTrainingEfficiency(const TString &)
virtual Long64_t GetEntries() const
Definition: TTree.h:382
void AddSignalTestEvent(const std::vector< Double_t > &event, Double_t weight=1.0)
add signal testing event
Definition: Factory.cxx:286
Long64_t GetNTrainingEvents() const
Definition: DataSet.h:90
A TTree object has a header with a name and a title.
Definition: TTree.h:94
void TMVACitation(MsgLogger &logger, ECitation citType=kPlainText)
kinds of TMVA citation
Definition: Tools.cxx:1448
std::vector< IMethod * > MVector
Definition: Factory.h:80
Double_t GetSignalReferenceCut() const
Definition: MethodBase.h:325
virtual void Print() const
get maximum length of variable names
Definition: Ranking.cxx:109
TString GetMethodTypeName() const
Definition: MethodBase.h:297
const Bool_t kTRUE
Definition: Rtypes.h:91
void SetTestvarName(const TString &v="")
Definition: MethodBase.h:306
UInt_t GetNumber() const
Definition: ClassInfo.h:75
void SetTree(TTree *tree, const TString &className, Double_t weight)
set background tree
Definition: Factory.cxx:496
virtual void TestClassification()
initialization
std::vector< TString > SplitString(const TString &theOpt, const char separator) const
splits the option string at 'separator' and fills the list 'splitV' with the primitive strings ...
Definition: Tools.cxx:1207
void PrepareTrainingAndTestTree(const TCut &cut, const TString &splitOpt)
prepare the training and test trees -> same cuts for signal and background
Definition: Factory.cxx:673
virtual void SetAnalysisType(Types::EAnalysisType type)
Definition: MethodBase.h:390
Definition: math.cpp:60
const TString & GetTestvarName() const
Definition: MethodBase.h:300
void SetBackgroundTree(TTree *background, Double_t weight=1.0)
Definition: Factory.cxx:488