ROOT  6.07/01
Reference Guide
 All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Properties Friends Macros Groups Pages
Factory.cxx
Go to the documentation of this file.
1 // @(#)Root/tmva $Id$
2 // Author: Andreas Hoecker, Peter Speckmayer, Joerg Stelzer, Helge Voss, Kai Voss, Eckhard von Toerne, Jan Therhaag
3 
4 /**********************************************************************************
5  * Project: TMVA - a Root-integrated toolkit for multivariate data analysis *
6  * Package: TMVA *
7  * Class : Factory *
8  * Web : http://tmva.sourceforge.net *
9  * *
10  * Description: *
11  * Implementation (see header for description) *
12  * *
13  * Authors (alphabetical): *
14  * Andreas Hoecker <Andreas.Hocker@cern.ch> - CERN, Switzerland *
15  * Joerg Stelzer <stelzer@cern.ch> - DESY, Germany *
16  * Peter Speckmayer <peter.speckmayer@cern.ch> - CERN, Switzerland *
17  * Jan Therhaag <Jan.Therhaag@cern.ch> - U of Bonn, Germany *
18  * Eckhard v. Toerne <evt@uni-bonn.de> - U of Bonn, Germany *
19  * Helge Voss <Helge.Voss@cern.ch> - MPI-K Heidelberg, Germany *
20  * Kai Voss <Kai.Voss@cern.ch> - U. of Victoria, Canada *
21  * *
22  * Copyright (c) 2005-2011: *
23  * CERN, Switzerland *
24  * U. of Victoria, Canada *
25  * MPI-K Heidelberg, Germany *
26  * U. of Bonn, Germany *
27  * *
28  * Redistribution and use in source and binary forms, with or without *
29  * modification, are permitted according to the terms listed in LICENSE *
30  * (http://tmva.sourceforge.net/LICENSE) *
31  **********************************************************************************/
32 
33 //_______________________________________________________________________
34 //
35 // This is the main MVA steering class: it creates all MVA methods,
36 // and guides them through the training, testing and evaluation
37 // phases
38 //_______________________________________________________________________
39 
40 
41 #include "TROOT.h"
42 #include "TFile.h"
43 #include "TTree.h"
44 #include "TLeaf.h"
45 #include "TEventList.h"
46 #include "TH2.h"
47 #include "TText.h"
48 #include "TStyle.h"
49 #include "TMatrixF.h"
50 #include "TMatrixDSym.h"
51 #include "TPaletteAxis.h"
52 #include "TPrincipal.h"
53 #include "TMath.h"
54 #include "TObjString.h"
55 
56 #include "TMVA/Factory.h"
57 #include "TMVA/ClassifierFactory.h"
58 #include "TMVA/Config.h"
59 #include "TMVA/Configurable.h"
60 #include "TMVA/Tools.h"
61 #include "TMVA/Ranking.h"
62 #include "TMVA/DataSet.h"
63 #include "TMVA/IMethod.h"
64 #include "TMVA/MethodBase.h"
65 #include "TMVA/DataInputHandler.h"
66 #include "TMVA/DataSetManager.h"
67 #include "TMVA/DataSetInfo.h"
68 #include "TMVA/MethodBoost.h"
69 #include "TMVA/MethodCategory.h"
70 #include "TMVA/MsgLogger.h"
71 
74 #include "TMVA/VariableInfo.h"
78 
79 #include "TMVA/Results.h"
81 #include "TMVA/ResultsRegression.h"
82 #include "TMVA/ResultsMulticlass.h"
83 
84 #include "TMVA/Types.h"
85 
//const Int_t MinNoTestEvents = 1;

// static output file shared by all Factory instances; assigned in the constructor
TFile* TMVA::Factory::fgTargetFile = 0;

// compile-time behaviour switches (kTRUE = ROOT boolean true)
#define RECREATE_METHODS kTRUE
#define READXML kTRUE
94 
////////////////////////////////////////////////////////////////////////////////
/// standard constructor
/// jobname : this name will appear in all weight file names produced by the MVAs
/// theTargetFile : output ROOT file; the test tree and all evaluation plots
/// will be stored here
/// theOption : option string; currently: "V" for verbose

TMVA::Factory::Factory( TString jobName, TFile* theTargetFile, TString theOption )
: Configurable ( theOption ),
  fDataSetManager ( NULL ), //DSMTEST
  fDataInputHandler ( new DataInputHandler ),
  fTransformations ( "I" ),
  fVerbose ( kFALSE ),
  fJobName ( jobName ),
  fDataAssignType ( kAssignEvents ),
  fATreeEvent ( NULL ),
  fAnalysisType ( Types::kClassification )
{
   // remember the output file in the class-wide static (shared by all factories)
   fgTargetFile = theTargetFile;

   // DataSetManager::CreateInstance(*fDataInputHandler); // DSMTEST removed
   fDataSetManager = new DataSetManager( *fDataInputHandler ); // DSMTEST

   // render silent
   if (gTools().CheckForSilentOption( GetOptions() )) Log().InhibitOutput(); // make sure is silent if wanted to

   // init configurable
   SetConfigDescription( "Configuration options for Factory running" );
   SetConfigName( GetName() );

   // histograms are not automatically associated with the current
   // directory and hence don't go out of scope when closing the file
   // TH1::AddDirectory(kFALSE);
   Bool_t silent = kFALSE;
#ifdef WIN32
   // under Windows, switch progress bar and color off by default, as the typical windows shell doesn't handle these (would need different sequences..)
   Bool_t color = kFALSE;
   Bool_t drawProgressBar = kFALSE;
#else
   Bool_t color = !gROOT->IsBatch();
   Bool_t drawProgressBar = kTRUE;
#endif
   // declare the options recognised in the factory option string
   DeclareOptionRef( fVerbose, "V", "Verbose flag" );
   DeclareOptionRef( color, "Color", "Flag for coloured screen output (default: True, if in batch mode: False)" );
   DeclareOptionRef( fTransformations, "Transformations", "List of transformations to test; formatting example: \"Transformations=I;D;P;U;G,D\", for identity, decorrelation, PCA, Uniform and Gaussianisation followed by decorrelation transformations" );
   DeclareOptionRef( silent, "Silent", "Batch mode: boolean silent flag inhibiting any output from TMVA after the creation of the factory class object (default: False)" );
   DeclareOptionRef( drawProgressBar,
                     "DrawProgressBar", "Draw progress bar to display training, testing and evaluation schedule (default: True)" );

   TString analysisType("Auto");
   DeclareOptionRef( analysisType,
                     "AnalysisType", "Set the analysis type (Classification, Regression, Multiclass, Auto) (default: Auto)" );
   AddPreDefVal(TString("Classification"));
   AddPreDefVal(TString("Regression"));
   AddPreDefVal(TString("Multiclass"));
   AddPreDefVal(TString("Auto"));

   // parse theOption against the declarations above; must happen after all DeclareOptionRef calls
   ParseOptions();
   CheckForUnusedOptions();

   if (Verbose()) Log().SetMinType( kVERBOSE );

   // global settings
   gConfig().SetUseColor( color );
   gConfig().SetSilent( silent );
   gConfig().SetDrawProgressBar( drawProgressBar );

   // translate the (case-insensitive) AnalysisType string into the enum
   analysisType.ToLower();
   if ( analysisType == "classification" ) fAnalysisType = Types::kClassification;
   else if( analysisType == "regression" ) fAnalysisType = Types::kRegression;
   else if( analysisType == "multiclass" ) fAnalysisType = Types::kMulticlass;
   else if( analysisType == "auto" ) fAnalysisType = Types::kNoAnalysisType;

   Greetings();
}
172 
173 ////////////////////////////////////////////////////////////////////////////////
174 /// print welcome message
175 /// options are: kLogoWelcomeMsg, kIsometricWelcomeMsg, kLeanWelcomeMsg
176 
178 {
180  gTools().TMVAWelcomeMessage( Log(), gTools().kLogoWelcomeMsg );
181  gTools().TMVAVersionMessage( Log() ); Log() << Endl;
182 }
183 
184 ////////////////////////////////////////////////////////////////////////////////
185 /// destructor
/// NOTE(review): fATreeEvent (allocated in CreateEventAssignTrees) does not appear to be deleted here — verify ownership
187 
189 {
190  std::vector<TMVA::VariableTransformBase*>::iterator trfIt = fDefaultTrfs.begin();
191  for (;trfIt != fDefaultTrfs.end(); trfIt++) delete (*trfIt);
192 
193  this->DeleteAllMethods();
194  delete fDataInputHandler;
195 
196  // destroy singletons
197  // DataSetManager::DestroyInstance(); // DSMTEST replaced by following line
198  delete fDataSetManager; // DSMTEST
199 
200  // problem with call of REGISTER_METHOD macro ...
201  // ClassifierFactory::DestroyInstance();
202  // Types::DestroyInstance();
205 }
206 
207 ////////////////////////////////////////////////////////////////////////////////
208 /// delete methods
209 
211 {
212  MVector::iterator itrMethod = fMethods.begin();
213  for (; itrMethod != fMethods.end(); itrMethod++) {
214  Log() << kDEBUG << "Delete method: " << (*itrMethod)->GetName() << Endl;
215  delete (*itrMethod);
216  }
217  fMethods.clear();
218 }
219 
220 ////////////////////////////////////////////////////////////////////////////////
221 
223 {
224  fVerbose = v;
225 }
226 
227 ////////////////////////////////////////////////////////////////////////////////
228 
230 {
231  return fDataSetManager->AddDataSetInfo(dsi); // DSMTEST
232 }
233 
234 ////////////////////////////////////////////////////////////////////////////////
235 
237 {
238  DataSetInfo* dsi = fDataSetManager->GetDataSetInfo(dsiName); // DSMTEST
239 
240  if (dsi!=0) return *dsi;
241 
242  return fDataSetManager->AddDataSetInfo(*(new DataSetInfo(dsiName))); // DSMTEST
243 }
244 
245 // ________________________________________________
246 // the next functions are to assign events directly
247 
248 ////////////////////////////////////////////////////////////////////////////////
249 /// create the data assignment tree (for event-wise data assignment by user)
250 
252 {
253  TTree * assignTree = new TTree( name, name );
254  assignTree->SetDirectory(0);
255  assignTree->Branch( "type", &fATreeType, "ATreeType/I" );
256  assignTree->Branch( "weight", &fATreeWeight, "ATreeWeight/F" );
257 
258  std::vector<VariableInfo>& vars = DefaultDataSetInfo().GetVariableInfos();
259  std::vector<VariableInfo>& tgts = DefaultDataSetInfo().GetTargetInfos();
260  std::vector<VariableInfo>& spec = DefaultDataSetInfo().GetSpectatorInfos();
261 
262  if (!fATreeEvent) fATreeEvent = new Float_t[vars.size()+tgts.size()+spec.size()];
263  // add variables
264  for (UInt_t ivar=0; ivar<vars.size(); ivar++) {
265  TString vname = vars[ivar].GetExpression();
266  assignTree->Branch( vname, &(fATreeEvent[ivar]), vname + "/F" );
267  }
268  // add targets
269  for (UInt_t itgt=0; itgt<tgts.size(); itgt++) {
270  TString vname = tgts[itgt].GetExpression();
271  assignTree->Branch( vname, &(fATreeEvent[vars.size()+itgt]), vname + "/F" );
272  }
273  // add spectators
274  for (UInt_t ispc=0; ispc<spec.size(); ispc++) {
275  TString vname = spec[ispc].GetExpression();
276  assignTree->Branch( vname, &(fATreeEvent[vars.size()+tgts.size()+ispc]), vname + "/F" );
277  }
278  return assignTree;
279 }
280 
////////////////////////////////////////////////////////////////////////////////
/// add signal training event
/// event  : values ordered as: variables + targets + spectators
/// weight : event weight

void TMVA::Factory::AddSignalTrainingEvent( const std::vector<Double_t>& event, Double_t weight )
{
   AddEvent( "Signal", Types::kTraining, event, weight );
}
288 
////////////////////////////////////////////////////////////////////////////////
/// add signal testing event
/// event  : values ordered as: variables + targets + spectators
/// weight : event weight

void TMVA::Factory::AddSignalTestEvent( const std::vector<Double_t>& event, Double_t weight )
{
   AddEvent( "Signal", Types::kTesting, event, weight );
}
296 
////////////////////////////////////////////////////////////////////////////////
/// add background training event
/// event  : values ordered as: variables + targets + spectators
/// weight : event weight

void TMVA::Factory::AddBackgroundTrainingEvent( const std::vector<Double_t>& event, Double_t weight )
{
   AddEvent( "Background", Types::kTraining, event, weight );
}
304 
////////////////////////////////////////////////////////////////////////////////
/// add background testing event
/// event  : values ordered as: variables + targets + spectators
/// weight : event weight

void TMVA::Factory::AddBackgroundTestEvent( const std::vector<Double_t>& event, Double_t weight )
{
   AddEvent( "Background", Types::kTesting, event, weight );
}
312 
////////////////////////////////////////////////////////////////////////////////
/// add training event for the class with the given name
/// event  : values ordered as: variables + targets + spectators
/// weight : event weight

void TMVA::Factory::AddTrainingEvent( const TString& className, const std::vector<Double_t>& event, Double_t weight )
{
   AddEvent( className, Types::kTraining, event, weight );
}
320 
////////////////////////////////////////////////////////////////////////////////
/// add testing event for the class with the given name
/// event  : values ordered as: variables + targets + spectators
/// weight : event weight

void TMVA::Factory::AddTestEvent( const TString& className, const std::vector<Double_t>& event, Double_t weight )
{
   AddEvent( className, Types::kTesting, event, weight );
}
328 
329 ////////////////////////////////////////////////////////////////////////////////
330 /// add event
331 /// vector event : the order of values is: variables + targets + spectators
332 
334  const std::vector<Double_t>& event, Double_t weight )
335 {
336  ClassInfo* theClass = DefaultDataSetInfo().AddClass(className); // returns class (creates it if necessary)
337  UInt_t clIndex = theClass->GetNumber();
338 
339 
340  // set analysistype to "kMulticlass" if more than two classes and analysistype == kNoAnalysisType
341  if( fAnalysisType == Types::kNoAnalysisType && DefaultDataSetInfo().GetNClasses() > 2 )
342  fAnalysisType = Types::kMulticlass;
343 
344 
345  if (clIndex>=fTrainAssignTree.size()) {
346  fTrainAssignTree.resize(clIndex+1, 0);
347  fTestAssignTree.resize(clIndex+1, 0);
348  }
349 
350  if (fTrainAssignTree[clIndex]==0) { // does not exist yet
351  fTrainAssignTree[clIndex] = CreateEventAssignTrees( Form("TrainAssignTree_%s", className.Data()) );
352  fTestAssignTree[clIndex] = CreateEventAssignTrees( Form("TestAssignTree_%s", className.Data()) );
353  }
354 
355  fATreeType = clIndex;
356  fATreeWeight = weight;
357  for (UInt_t ivar=0; ivar<event.size(); ivar++) fATreeEvent[ivar] = event[ivar];
358 
359  if(tt==Types::kTraining) fTrainAssignTree[clIndex]->Fill();
360  else fTestAssignTree[clIndex]->Fill();
361 
362 }
363 
364 ////////////////////////////////////////////////////////////////////////////////
365 ///
366 
368 {
369  return fTrainAssignTree[clIndex]!=0;
370 }
371 
372 ////////////////////////////////////////////////////////////////////////////////
373 /// assign event-wise local trees to data set
374 
376 {
377  UInt_t size = fTrainAssignTree.size();
378  for(UInt_t i=0; i<size; i++) {
379  if(!UserAssignEvents(i)) continue;
380  const TString& className = DefaultDataSetInfo().GetClassInfo(i)->GetName();
381  SetWeightExpression( "weight", className );
382  AddTree(fTrainAssignTree[i], className, 1.0, TCut(""), Types::kTraining );
383  AddTree(fTestAssignTree[i], className, 1.0, TCut(""), Types::kTesting );
384  }
385 }
386 
387 ////////////////////////////////////////////////////////////////////////////////
388 /// number of signal events (used to compute significance)
389 
390 void TMVA::Factory::AddTree( TTree* tree, const TString& className, Double_t weight,
391  const TCut& cut, const TString& treetype )
392 {
394  TString tmpTreeType = treetype; tmpTreeType.ToLower();
395  if (tmpTreeType.Contains( "train" ) && tmpTreeType.Contains( "test" )) tt = Types::kMaxTreeType;
396  else if (tmpTreeType.Contains( "train" )) tt = Types::kTraining;
397  else if (tmpTreeType.Contains( "test" )) tt = Types::kTesting;
398  else {
399  Log() << kFATAL << "<AddTree> cannot interpret tree type: \"" << treetype
400  << "\" should be \"Training\" or \"Test\" or \"Training and Testing\"" << Endl;
401  }
402  AddTree( tree, className, weight, cut, tt );
403 }
404 
////////////////////////////////////////////////////////////////////////////////
/// add a tree of events for the given class and tree type (training/testing)

void TMVA::Factory::AddTree( TTree* tree, const TString& className, Double_t weight,
                             const TCut& cut, Types::ETreeType tt )
{
   // a null tree pointer is a fatal user error
   if(!tree)
      Log() << kFATAL << "Tree does not exist (empty pointer)." << Endl;

   // make the class known to the dataset info
   DefaultDataSetInfo().AddClass( className );

   // set analysistype to "kMulticlass" if more than two classes and analysistype == kNoAnalysisType
   if( fAnalysisType == Types::kNoAnalysisType && DefaultDataSetInfo().GetNClasses() > 2 )
      fAnalysisType = Types::kMulticlass;

   Log() << kINFO << "Add Tree " << tree->GetName() << " of type " << className
         << " with " << tree->GetEntries() << " events" << Endl;
   DataInput().AddTree( tree, className, weight, cut, tt );
}
423 
424 ////////////////////////////////////////////////////////////////////////////////
/// add the given signal tree; the tree type ("Training", "Test", or both) is passed as a string
426 
428 {
429  AddTree( signal, "Signal", weight, TCut(""), treetype );
430 }
431 
432 ////////////////////////////////////////////////////////////////////////////////
433 /// add signal tree from text file
434 
436 {
437  // create trees from these ascii files
438  TTree* signalTree = new TTree( "TreeS", "Tree (S)" );
439  signalTree->ReadFile( datFileS );
440 
441  Log() << kINFO << "Create TTree objects from ASCII input files ... \n- Signal file : \""
442  << datFileS << Endl;
443 
   // register the newly created tree as signal tree
445  AddTree( signalTree, "Signal", weight, TCut(""), treetype );
446 }
447 
////////////////////////////////////////////////////////////////////////////////
/// add the given tree as signal tree; the tree type ("Training", "Test",
/// or both) is passed on as a string

void TMVA::Factory::AddSignalTree( TTree* signal, Double_t weight, const TString& treetype )
{
   AddTree( signal, "Signal", weight, TCut(""), treetype );
}
454 
455 ////////////////////////////////////////////////////////////////////////////////
/// add the given background tree; the tree type ("Training", "Test", or both) is passed as a string
457 
459 {
460  AddTree( signal, "Background", weight, TCut(""), treetype );
461 }
462 ////////////////////////////////////////////////////////////////////////////////
463 /// add background tree from text file
464 
466 {
467  // create trees from these ascii files
468  TTree* bkgTree = new TTree( "TreeB", "Tree (B)" );
469  bkgTree->ReadFile( datFileB );
470 
471  Log() << kINFO << "Create TTree objects from ASCII input files ... \n- Background file : \""
472  << datFileB << Endl;
473 
   // register the newly created tree as background tree
475  AddTree( bkgTree, "Background", weight, TCut(""), treetype );
476 }
477 
478 ////////////////////////////////////////////////////////////////////////////////
479 
480 void TMVA::Factory::AddBackgroundTree( TTree* signal, Double_t weight, const TString& treetype )
481 {
482  AddTree( signal, "Background", weight, TCut(""), treetype );
483 }
484 
485 ////////////////////////////////////////////////////////////////////////////////
486 
488 {
489  AddTree( tree, "Signal", weight );
490 }
491 
492 ////////////////////////////////////////////////////////////////////////////////
493 
495 {
496  AddTree( tree, "Background", weight );
497 }
498 
////////////////////////////////////////////////////////////////////////////////
/// set tree for the given class; it is registered with tree type
/// kMaxTreeType, i.e. used for both training and testing

void TMVA::Factory::SetTree( TTree* tree, const TString& className, Double_t weight )
{
   AddTree( tree, className, weight, TCut(""), Types::kMaxTreeType );
}
506 
507 ////////////////////////////////////////////////////////////////////////////////
508 /// define the input trees for signal and background; no cuts are applied
509 
511  Double_t signalWeight, Double_t backgroundWeight )
512 {
513  AddTree( signal, "Signal", signalWeight, TCut(""), Types::kMaxTreeType );
514  AddTree( background, "Background", backgroundWeight, TCut(""), Types::kMaxTreeType );
515 }
516 
////////////////////////////////////////////////////////////////////////////////
/// define signal and background input from two ASCII data files

void TMVA::Factory::SetInputTrees( const TString& datFileS, const TString& datFileB,
                                   Double_t signalWeight, Double_t backgroundWeight )
{
   DataInput().AddTree( datFileS, "Signal", signalWeight );
   DataInput().AddTree( datFileB, "Background", backgroundWeight );
}
525 
////////////////////////////////////////////////////////////////////////////////
/// define the input trees for signal and background from single input tree,
/// containing both signal and background events distinguished by the type
/// identifiers: SigCut and BgCut

void TMVA::Factory::SetInputTrees( TTree* inputTree, const TCut& SigCut, const TCut& BgCut )
{
   // the same tree is registered twice, once per class, with the respective cut
   AddTree( inputTree, "Signal", 1.0, SigCut, Types::kMaxTreeType );
   AddTree( inputTree, "Background", 1.0, BgCut , Types::kMaxTreeType );
}
536 
////////////////////////////////////////////////////////////////////////////////
/// user inserts discriminating variable in data set info
/// expression : variable expression (branch name or formula)
/// title, unit: used for labelling
/// type       : variable type character (TTree branch convention, e.g. 'F', 'I')
/// min, max   : value range of the variable

void TMVA::Factory::AddVariable( const TString& expression, const TString& title, const TString& unit,
                                 char type, Double_t min, Double_t max )
{
   // forwards to the dataset info, which owns the variable bookkeeping
   DefaultDataSetInfo().AddVariable( expression, title, unit, min, max, type );
}
545 
546 ////////////////////////////////////////////////////////////////////////////////
547 /// user inserts discriminating variable in data set info
548 
549 void TMVA::Factory::AddVariable( const TString& expression, char type,
551 {
552  DefaultDataSetInfo().AddVariable( expression, "", "", min, max, type );
553 }
554 
555 ////////////////////////////////////////////////////////////////////////////////
556 /// user inserts target in data set info
557 
558 void TMVA::Factory::AddTarget( const TString& expression, const TString& title, const TString& unit,
560 {
561  if( fAnalysisType == Types::kNoAnalysisType )
562  fAnalysisType = Types::kRegression;
563 
564  DefaultDataSetInfo().AddTarget( expression, title, unit, min, max );
565 }
566 
567 ////////////////////////////////////////////////////////////////////////////////
568 /// user inserts target in data set info
569 
570 void TMVA::Factory::AddSpectator( const TString& expression, const TString& title, const TString& unit,
572 {
573  DefaultDataSetInfo().AddSpectator( expression, title, unit, min, max );
574 }
575 
576 ////////////////////////////////////////////////////////////////////////////////
577 /// default creation
578 
580 {
581  return AddDataSet( "Default" );
582 }
583 
584 ////////////////////////////////////////////////////////////////////////////////
585 /// fill input variables in data set
586 
587 void TMVA::Factory::SetInputVariables( std::vector<TString>* theVariables )
588 {
589  for (std::vector<TString>::iterator it=theVariables->begin();
590  it!=theVariables->end(); it++) AddVariable(*it);
591 }
592 
593 ////////////////////////////////////////////////////////////////////////////////
594 
596 {
597  DefaultDataSetInfo().SetWeightExpression(variable, "Signal");
598 }
599 
600 ////////////////////////////////////////////////////////////////////////////////
601 
603 {
604  DefaultDataSetInfo().SetWeightExpression(variable, "Background");
605 }
606 
607 ////////////////////////////////////////////////////////////////////////////////
/// set the event weight expression for the given class; if className is empty, it is applied to both "Signal" and "Background"
609 
610 void TMVA::Factory::SetWeightExpression( const TString& variable, const TString& className )
611 {
612  if (className=="") {
613  SetSignalWeightExpression(variable);
614  SetBackgroundWeightExpression(variable);
615  }
616  else DefaultDataSetInfo().SetWeightExpression( variable, className );
617 }
618 
////////////////////////////////////////////////////////////////////////////////
/// set the preselection cut (given as a string expression) for the class

void TMVA::Factory::SetCut( const TString& cut, const TString& className ) {
   // forward to the TCut overload
   SetCut( TCut(cut), className );
}
624 
////////////////////////////////////////////////////////////////////////////////
/// set the preselection cut for the given class in the dataset info

void TMVA::Factory::SetCut( const TCut& cut, const TString& className )
{
   DefaultDataSetInfo().SetCut( cut, className );
}
631 
////////////////////////////////////////////////////////////////////////////////
/// add a preselection cut (given as a string expression) for the class

void TMVA::Factory::AddCut( const TString& cut, const TString& className )
{
   // forward to the TCut overload
   AddCut( TCut(cut), className );
}
638 
////////////////////////////////////////////////////////////////////////////////
/// add a preselection cut for the given class in the dataset info

void TMVA::Factory::AddCut( const TCut& cut, const TString& className )
{
   DefaultDataSetInfo().AddCut( cut, className );
}
645 
646 ////////////////////////////////////////////////////////////////////////////////
647 /// prepare the training and test trees
648 
650  Int_t NsigTrain, Int_t NbkgTrain, Int_t NsigTest, Int_t NbkgTest,
651  const TString& otherOpt )
652 {
653  SetInputTreesFromEventAssignTrees();
654 
655  AddCut( cut );
656 
657  DefaultDataSetInfo().SetSplitOptions( Form("nTrain_Signal=%i:nTrain_Background=%i:nTest_Signal=%i:nTest_Background=%i:%s",
658  NsigTrain, NbkgTrain, NsigTest, NbkgTest, otherOpt.Data()) );
659 }
660 
661 ////////////////////////////////////////////////////////////////////////////////
662 /// prepare the training and test trees
663 /// kept for backward compatibility
664 
666 {
667  SetInputTreesFromEventAssignTrees();
668 
669  AddCut( cut );
670 
671  DefaultDataSetInfo().SetSplitOptions( Form("nTrain_Signal=%i:nTrain_Background=%i:nTest_Signal=%i:nTest_Background=%i:SplitMode=Random:EqualTrainSample:!V",
672  Ntrain, Ntrain, Ntest, Ntest) );
673 }
674 
675 ////////////////////////////////////////////////////////////////////////////////
676 /// prepare the training and test trees
677 /// -> same cuts for signal and background
678 
680 {
681  SetInputTreesFromEventAssignTrees();
682 
683  DefaultDataSetInfo().PrintClasses();
684  AddCut( cut );
685  DefaultDataSetInfo().SetSplitOptions( opt );
686 }
687 
////////////////////////////////////////////////////////////////////////////////
/// prepare the training and test trees
/// sigcut, bkgcut : preselection cuts applied to signal and background
/// splitOpt       : option string steering the training/test split

void TMVA::Factory::PrepareTrainingAndTestTree( TCut sigcut, TCut bkgcut, const TString& splitOpt )
{
   // if event-wise data assignment, add local trees to dataset first
   SetInputTreesFromEventAssignTrees();

   Log() << kINFO << "Preparing trees for training and testing..." << Endl;
   AddCut( sigcut, "Signal" );
   AddCut( bkgcut, "Background" );

   DefaultDataSetInfo().SetSplitOptions( splitOpt );
}
702 
////////////////////////////////////////////////////////////////////////////////
/// Book a classifier or regression method
/// theMethodName : method type name handed to the ClassifierFactory
/// methodTitle   : unique title distinguishing this method instance
/// theOption     : configuration option string for the method
/// Returns the created method, or 0 if creation failed or the method
/// cannot handle the requested analysis type.

TMVA::MethodBase* TMVA::Factory::BookMethod( TString theMethodName, TString methodTitle, TString theOption )
{
   // if no analysis type was fixed yet, infer it from the booked classes
   if( fAnalysisType == Types::kNoAnalysisType ){
      if( DefaultDataSetInfo().GetNClasses()==2
          && DefaultDataSetInfo().GetClassInfo("Signal") != NULL
          && DefaultDataSetInfo().GetClassInfo("Background") != NULL
          ){
         fAnalysisType = Types::kClassification; // default is classification
      } else if( DefaultDataSetInfo().GetNClasses() >= 2 ){
         fAnalysisType = Types::kMulticlass; // if two classes, but not named "Signal" and "Background"
      } else
         Log() << kFATAL << "No analysis type for " << DefaultDataSetInfo().GetNClasses() << " classes and "
               << DefaultDataSetInfo().GetNTargets() << " regression targets." << Endl;
   }

   // booking via name; the names are translated into enums and the
   // corresponding overloaded BookMethod is called
   if (GetMethod( methodTitle ) != 0) {
      Log() << kFATAL << "Booking failed since method with title <"
            << methodTitle <<"> already exists"
            << Endl;
   }

   Log() << kINFO << "Booking method: " << gTools().Color("bold") << methodTitle
         << gTools().Color("reset") << Endl;

   // interpret option string with respect to a request for boosting (i.e., BostNum > 0)
   Int_t boostNum = 0;
   TMVA::Configurable* conf = new TMVA::Configurable( theOption );
   conf->DeclareOptionRef( boostNum = 0, "Boost_num",
                           "Number of times the classifier will be boosted" );
   conf->ParseOptions();
   delete conf;

   // initialize methods
   IMethod* im;
   if (!boostNum) {
      // plain (non-boosted) method: create directly by type name
      im = ClassifierFactory::Instance().Create( std::string(theMethodName),
                                                 fJobName,
                                                 methodTitle,
                                                 DefaultDataSetInfo(),
                                                 theOption );
   }
   else {
      // boosted classifier, requires a specific definition, making it transparent for the user
      Log() << "Boost Number is " << boostNum << " > 0: train boosted classifier" << Endl;
      im = ClassifierFactory::Instance().Create( std::string("Boost"),
                                                 fJobName,
                                                 methodTitle,
                                                 DefaultDataSetInfo(),
                                                 theOption );
      MethodBoost* methBoost = dynamic_cast<MethodBoost*>(im); // DSMTEST divided into two lines
      if (!methBoost) // DSMTEST
         // NOTE(review): message should presumably say "MethodBoost", not "MethodCategory" — confirm
         Log() << kFATAL << "Method with type kBoost cannot be casted to MethodCategory. /Factory" << Endl; // DSMTEST
      methBoost->SetBoostedMethodName( theMethodName ); // DSMTEST divided into two lines
      methBoost->fDataSetManager = fDataSetManager; // DSMTEST

   }

   MethodBase *method = dynamic_cast<MethodBase*>(im);
   // NOTE(review): on this early return 'im' is not deleted — possible leak; confirm ownership
   if (method==0) return 0; // could not create method

   // set fDataSetManager if MethodCategory (to enable Category to create datasetinfo objects) // DSMTEST
   if (method->GetMethodType() == Types::kCategory) { // DSMTEST
      MethodCategory *methCat = (dynamic_cast<MethodCategory*>(im)); // DSMTEST
      if (!methCat) // DSMTEST
         Log() << kFATAL << "Method with type kCategory cannot be casted to MethodCategory. /Factory" << Endl; // DSMTEST
      methCat->fDataSetManager = fDataSetManager; // DSMTEST
   } // DSMTEST


   // reject methods that cannot handle the requested analysis type
   if (!method->HasAnalysisType( fAnalysisType,
                                 DefaultDataSetInfo().GetNClasses(),
                                 DefaultDataSetInfo().GetNTargets() )) {
      Log() << kWARNING << "Method " << method->GetMethodTypeName() << " is not capable of handling " ;
      if (fAnalysisType == Types::kRegression) {
         Log() << "regression with " << DefaultDataSetInfo().GetNTargets() << " targets." << Endl;
      }
      else if (fAnalysisType == Types::kMulticlass ) {
         Log() << "multiclass classification with " << DefaultDataSetInfo().GetNClasses() << " classes." << Endl;
      }
      else {
         Log() << "classification with " << DefaultDataSetInfo().GetNClasses() << " classes." << Endl;
      }
      // NOTE(review): 'method' is not deleted on this early return — possible leak; confirm
      return 0;
   }


   // configure the freshly created method
   method->SetAnalysisType( fAnalysisType );
   method->SetupMethod();
   method->ParseOptions();
   method->ProcessSetup();

   // check-for-unused-options is performed; may be overridden by derived classes
   method->CheckSetup();

   // register the method so Train/Test/Evaluate and GetMethod can find it
   fMethods.push_back( method );

   return method;
}
806 
807 ////////////////////////////////////////////////////////////////////////////////
808 /// books MVA method; the option configuration string is custom for each MVA
809 /// the TString field "theNameAppendix" serves to define (and distinguish)
810 /// several instances of a given MVA, eg, when one wants to compare the
811 /// performance of various configurations
812 
814 {
815  return BookMethod( Types::Instance().GetMethodName( theMethod ), methodTitle, theOption );
816 }
817 
818 ////////////////////////////////////////////////////////////////////////////////
819 /// returns pointer to MVA that corresponds to given method title
820 
821 TMVA::IMethod* TMVA::Factory::GetMethod( const TString &methodTitle ) const
822 {
823  MVector::const_iterator itrMethod = fMethods.begin();
824  MVector::const_iterator itrMethodEnd = fMethods.end();
825  //
826  for (; itrMethod != itrMethodEnd; itrMethod++) {
827  MethodBase* mva = dynamic_cast<MethodBase*>(*itrMethod);
828  if ( (mva->GetMethodName())==methodTitle ) return mva;
829  }
830  return 0;
831 }
832 
833 ////////////////////////////////////////////////////////////////////////////////
834 /// put correlations of input data and a few (default + user
835 /// selected) transformations into the root file
836 
838 {
839  RootBaseDir()->cd();
840 
841  DefaultDataSetInfo().GetDataSet(); // builds dataset (including calculation of correlation matrix)
842 
843 
844  // correlation matrix of the default DS
845  const TMatrixD* m(0);
846  const TH2* h(0);
847 
848  if(fAnalysisType == Types::kMulticlass){
849  for (UInt_t cls = 0; cls < DefaultDataSetInfo().GetNClasses() ; cls++) {
850  m = DefaultDataSetInfo().CorrelationMatrix(DefaultDataSetInfo().GetClassInfo(cls)->GetName());
851  h = DefaultDataSetInfo().CreateCorrelationMatrixHist(m, TString("CorrelationMatrix")+DefaultDataSetInfo().GetClassInfo(cls)->GetName(),
852  "Correlation Matrix ("+ DefaultDataSetInfo().GetClassInfo(cls)->GetName() +TString(")"));
853  if (h!=0) {
854  h->Write();
855  delete h;
856  }
857  }
858  }
859  else{
860  m = DefaultDataSetInfo().CorrelationMatrix( "Signal" );
861  h = DefaultDataSetInfo().CreateCorrelationMatrixHist(m, "CorrelationMatrixS", "Correlation Matrix (signal)");
862  if (h!=0) {
863  h->Write();
864  delete h;
865  }
866 
867  m = DefaultDataSetInfo().CorrelationMatrix( "Background" );
868  h = DefaultDataSetInfo().CreateCorrelationMatrixHist(m, "CorrelationMatrixB", "Correlation Matrix (background)");
869  if (h!=0) {
870  h->Write();
871  delete h;
872  }
873 
874  m = DefaultDataSetInfo().CorrelationMatrix( "Regression" );
875  h = DefaultDataSetInfo().CreateCorrelationMatrixHist(m, "CorrelationMatrix", "Correlation Matrix");
876  if (h!=0) {
877  h->Write();
878  delete h;
879  }
880  }
881 
882  // some default transformations to evaluate
883  // NOTE: all transformations are destroyed after this test
884  TString processTrfs = "I"; //"I;N;D;P;U;G,D;"
885 
886  // plus some user defined transformations
887  processTrfs = fTransformations;
888 
889  // remove any trace of identity transform - if given (avoid to apply it twice)
890  std::vector<TMVA::TransformationHandler*> trfs;
891  TransformationHandler* identityTrHandler = 0;
892 
893  std::vector<TString> trfsDef = gTools().SplitString(processTrfs,';');
894  std::vector<TString>::iterator trfsDefIt = trfsDef.begin();
895  for (; trfsDefIt!=trfsDef.end(); trfsDefIt++) {
896  trfs.push_back(new TMVA::TransformationHandler(DefaultDataSetInfo(), "Factory"));
897  TString trfS = (*trfsDefIt);
898 
899  Log() << kINFO << Endl;
900  Log() << kINFO << "current transformation string: '" << trfS.Data() << "'" << Endl;
902  DefaultDataSetInfo(),
903  *(trfs.back()),
904  Log() );
905 
906  if (trfS.BeginsWith('I')) identityTrHandler = trfs.back();
907  }
908 
909  const std::vector<Event*>& inputEvents = DefaultDataSetInfo().GetDataSet()->GetEventCollection();
910 
911  // apply all transformations
912  std::vector<TMVA::TransformationHandler*>::iterator trfIt = trfs.begin();
913 
914  for (;trfIt != trfs.end(); trfIt++) {
915  // setting a Root dir causes the variables distributions to be saved to the root file
916  (*trfIt)->SetRootDir(RootBaseDir());
917  (*trfIt)->CalcTransformations(inputEvents);
918  }
919  if(identityTrHandler) identityTrHandler->PrintVariableRanking();
920 
921  // clean up
922  for (trfIt = trfs.begin(); trfIt != trfs.end(); trfIt++) delete *trfIt;
923 }
924 
925 ////////////////////////////////////////////////////////////////////////////////
926 /// iterates through all booked methods and sees if they use parameter tuning and if so..
927 /// does just that i.e. calls "Method::Train()" for different parameter setttings and
928 /// keeps in mind the "optimal one"... and that's the one that will later on be used
929 /// in the main training loop.
930 
932 {
933 
934  MVector::iterator itrMethod;
935 
936  // iterate over methods and optimize
937  for( itrMethod = fMethods.begin(); itrMethod != fMethods.end(); ++itrMethod ) {
939  MethodBase* mva = dynamic_cast<MethodBase*>(*itrMethod);
940  if (!mva) {
941  Log() << kFATAL << "Dynamic cast to MethodBase failed" <<Endl;
942  return;
943  }
944 
945  if (mva->Data()->GetNTrainingEvents() < MinNoTrainingEvents) {
946  Log() << kWARNING << "Method " << mva->GetMethodName()
947  << " not trained (training tree has less entries ["
948  << mva->Data()->GetNTrainingEvents()
949  << "] than required [" << MinNoTrainingEvents << "]" << Endl;
950  continue;
951  }
952 
953  Log() << kINFO << "Optimize method: " << mva->GetMethodName() << " for "
954  << (fAnalysisType == Types::kRegression ? "Regression" :
955  (fAnalysisType == Types::kMulticlass ? "Multiclass classification" : "Classification")) << Endl;
956 
957  mva->OptimizeTuningParameters(fomType,fitType);
958  Log() << kINFO << "Optimization of tuning paremters finished for Method:"<<mva->GetName() << Endl;
959  }
960 }
961 
962 ////////////////////////////////////////////////////////////////////////////////
963 /// iterates through all booked methods and calls training
964 
966 {
967  if(fDataInputHandler->GetEntries() <=1) { // 0 entries --> 0 events, 1 entry --> dynamical dataset (or one entry)
968  Log() << kFATAL << "No input data for the training provided!" << Endl;
969  }
970 
971  if(fAnalysisType == Types::kRegression && DefaultDataSetInfo().GetNTargets() < 1 )
972  Log() << kFATAL << "You want to do regression training without specifying a target." << Endl;
973  else if( (fAnalysisType == Types::kMulticlass || fAnalysisType == Types::kClassification)
974  && DefaultDataSetInfo().GetNClasses() < 2 )
975  Log() << kFATAL << "You want to do classification training, but specified less than two classes." << Endl;
976 
977  // iterates over all MVAs that have been booked, and calls their training methods
978 
979  // first print some information about the default dataset
980  WriteDataInformation();
981 
982  // don't do anything if no method booked
983  if (fMethods.empty()) {
984  Log() << kINFO << "...nothing found to train" << Endl;
985  return;
986  }
987 
988  // here the training starts
989  Log() << kINFO << " " << Endl;
990  Log() << kINFO << "Train all methods for "
991  << (fAnalysisType == Types::kRegression ? "Regression" :
992  (fAnalysisType == Types::kMulticlass ? "Multiclass" : "Classification") ) << " ..." << Endl;
993 
994  MVector::iterator itrMethod;
995 
996  // iterate over methods and train
997  for( itrMethod = fMethods.begin(); itrMethod != fMethods.end(); ++itrMethod ) {
999  MethodBase* mva = dynamic_cast<MethodBase*>(*itrMethod);
1000  if(mva==0) continue;
1001 
1002  if (mva->Data()->GetNTrainingEvents() < MinNoTrainingEvents) {
1003  Log() << kWARNING << "Method " << mva->GetMethodName()
1004  << " not trained (training tree has less entries ["
1005  << mva->Data()->GetNTrainingEvents()
1006  << "] than required [" << MinNoTrainingEvents << "]" << Endl;
1007  continue;
1008  }
1009 
1010  Log() << kINFO << "Train method: " << mva->GetMethodName() << " for "
1011  << (fAnalysisType == Types::kRegression ? "Regression" :
1012  (fAnalysisType == Types::kMulticlass ? "Multiclass classification" : "Classification")) << Endl;
1013  mva->TrainMethod();
1014  Log() << kINFO << "Training finished" << Endl;
1015  }
1016 
1017  if (fAnalysisType != Types::kRegression) {
1018 
1019  // variable ranking
1020  Log() << Endl;
1021  Log() << kINFO << "Ranking input variables (method specific)..." << Endl;
1022  for (itrMethod = fMethods.begin(); itrMethod != fMethods.end(); itrMethod++) {
1023  MethodBase* mva = dynamic_cast<MethodBase*>(*itrMethod);
1024  if (mva && mva->Data()->GetNTrainingEvents() >= MinNoTrainingEvents) {
1025 
1026  // create and print ranking
1027  const Ranking* ranking = (*itrMethod)->CreateRanking();
1028  if (ranking != 0) ranking->Print();
1029  else Log() << kINFO << "No variable ranking supplied by classifier: "
1030  << dynamic_cast<MethodBase*>(*itrMethod)->GetMethodName() << Endl;
1031  }
1032  }
1033  }
1034 
1035  // delete all methods and recreate them from weight file - this ensures that the application
1036  // of the methods (in TMVAClassificationApplication) is consistent with the results obtained
1037  // in the testing
1038  Log() << Endl;
1039  if (RECREATE_METHODS) {
1040 
1041  Log() << kINFO << "=== Destroy and recreate all methods via weight files for testing ===" << Endl << Endl;
1042 
1043  RootBaseDir()->cd();
1044 
1045  // iterate through all booked methods
1046  for (UInt_t i=0; i<fMethods.size(); i++) {
1047 
1048  MethodBase* m = dynamic_cast<MethodBase*>(fMethods[i]);
1049  if(m==0) continue;
1050 
1051  TMVA::Types::EMVA methodType = m->GetMethodType();
1052  TString weightfile = m->GetWeightFileName();
1053 
1054  // decide if .txt or .xml file should be read:
1055  if (READXML) weightfile.ReplaceAll(".txt",".xml");
1056 
1057  DataSetInfo& dataSetInfo = m->DataInfo();
1058  TString testvarName = m->GetTestvarName();
1059  delete m; //itrMethod[i];
1060 
1061  // recreate
1062  m = dynamic_cast<MethodBase*>( ClassifierFactory::Instance()
1063  .Create( std::string(Types::Instance().GetMethodName(methodType)),
1064  dataSetInfo, weightfile ) );
1065  if( m->GetMethodType() == Types::kCategory ){
1066  MethodCategory *methCat = (dynamic_cast<MethodCategory*>(m));
1067  if( !methCat ) Log() << kFATAL << "Method with type kCategory cannot be casted to MethodCategory. /Factory" << Endl;
1068  else methCat->fDataSetManager = fDataSetManager;
1069  }
1070  //ToDo, Do we need to fill the DataSetManager of MethodBoost here too?
1071 
1072  m->SetAnalysisType(fAnalysisType);
1073  m->SetupMethod();
1074  m->ReadStateFromFile();
1075  m->SetTestvarName(testvarName);
1076 
1077  // replace trained method by newly created one (from weight file) in methods vector
1078  fMethods[i] = m;
1079  }
1080  }
1081 }
1082 
1083 ////////////////////////////////////////////////////////////////////////////////
1084 
1086 {
1087  Log() << kINFO << "Test all methods..." << Endl;
1088 
1089  // don't do anything if no method booked
1090  if (fMethods.empty()) {
1091  Log() << kINFO << "...nothing found to test" << Endl;
1092  return;
1093  }
1094 
1095  // iterates over all MVAs that have been booked, and calls their testing methods
1096  // iterate over methods and test
1097  MVector::iterator itrMethod = fMethods.begin();
1098  MVector::iterator itrMethodEnd = fMethods.end();
1099  for (; itrMethod != itrMethodEnd; itrMethod++) {
1101  MethodBase* mva = dynamic_cast<MethodBase*>(*itrMethod);
1102  if(mva==0) continue;
1103  Types::EAnalysisType analysisType = mva->GetAnalysisType();
1104  Log() << kINFO << "Test method: " << mva->GetMethodName() << " for "
1105  << (analysisType == Types::kRegression ? "Regression" :
1106  (analysisType == Types::kMulticlass ? "Multiclass classification" : "Classification")) << " performance" << Endl;
1107  mva->AddOutput( Types::kTesting, analysisType );
1108  }
1109 }
1110 
1111 ////////////////////////////////////////////////////////////////////////////////
1112 /// Print predefined help message of classifier
1113 /// iterate over methods and test
1114 
1115 void TMVA::Factory::MakeClass( const TString& methodTitle ) const
1116 {
1117  if (methodTitle != "") {
1118  IMethod* method = GetMethod( methodTitle );
1119  if (method) method->MakeClass();
1120  else {
1121  Log() << kWARNING << "<MakeClass> Could not find classifier \"" << methodTitle
1122  << "\" in list" << Endl;
1123  }
1124  }
1125  else {
1126 
1127  // no classifier specified, print all hepl messages
1128  MVector::const_iterator itrMethod = fMethods.begin();
1129  MVector::const_iterator itrMethodEnd = fMethods.end();
1130  for (; itrMethod != itrMethodEnd; itrMethod++) {
1131  MethodBase* method = dynamic_cast<MethodBase*>(*itrMethod);
1132  if(method==0) continue;
1133  Log() << kINFO << "Make response class for classifier: " << method->GetMethodName() << Endl;
1134  method->MakeClass();
1135  }
1136  }
1137 }
1138 
1139 ////////////////////////////////////////////////////////////////////////////////
1140 /// Print predefined help message of classifier
1141 /// iterate over methods and test
1142 
1143 void TMVA::Factory::PrintHelpMessage( const TString& methodTitle ) const
1144 {
1145  if (methodTitle != "") {
1146  IMethod* method = GetMethod( methodTitle );
1147  if (method) method->PrintHelpMessage();
1148  else {
1149  Log() << kWARNING << "<PrintHelpMessage> Could not find classifier \"" << methodTitle
1150  << "\" in list" << Endl;
1151  }
1152  }
1153  else {
1154 
1155  // no classifier specified, print all hepl messages
1156  MVector::const_iterator itrMethod = fMethods.begin();
1157  MVector::const_iterator itrMethodEnd = fMethods.end();
1158  for (; itrMethod != itrMethodEnd; itrMethod++) {
1159  MethodBase* method = dynamic_cast<MethodBase*>(*itrMethod);
1160  if(method==0) continue;
1161  Log() << kINFO << "Print help message for classifier: " << method->GetMethodName() << Endl;
1162  method->PrintHelpMessage();
1163  }
1164  }
1165 }
1166 
1167 ////////////////////////////////////////////////////////////////////////////////
1168 /// iterates over all MVA input varables and evaluates them
1169 
1171 {
1172  Log() << kINFO << "Evaluating all variables..." << Endl;
1174 
1175  for (UInt_t i=0; i<DefaultDataSetInfo().GetNVariables(); i++) {
1176  TString s = DefaultDataSetInfo().GetVariableInfo(i).GetLabel();
1177  if (options.Contains("V")) s += ":V";
1178  this->BookMethod( "Variable", s );
1179  }
1180 }
1181 
1182 ////////////////////////////////////////////////////////////////////////////////
1183 /// iterates over all MVAs that have been booked, and calls their evaluation methods
1184 
1186 {
1187  Log() << kINFO << "Evaluate all methods..." << Endl;
1188 
1189  // don't do anything if no method booked
1190  if (fMethods.empty()) {
1191  Log() << kINFO << "...nothing found to evaluate" << Endl;
1192  return;
1193  }
1194 
1195  // -----------------------------------------------------------------------
1196  // First part of evaluation process
1197  // --> compute efficiencies, and other separation estimators
1198  // -----------------------------------------------------------------------
1199 
1200  // although equal, we now want to seperate the outpuf for the variables
1201  // and the real methods
1202  Int_t isel; // will be 0 for a Method; 1 for a Variable
1203  Int_t nmeth_used[2] = {0,0}; // 0 Method; 1 Variable
1204 
1205  std::vector<std::vector<TString> > mname(2);
1206  std::vector<std::vector<Double_t> > sig(2), sep(2), roc(2);
1207  std::vector<std::vector<Double_t> > eff01(2), eff10(2), eff30(2), effArea(2);
1208  std::vector<std::vector<Double_t> > eff01err(2), eff10err(2), eff30err(2);
1209  std::vector<std::vector<Double_t> > trainEff01(2), trainEff10(2), trainEff30(2);
1210 
1211  std::vector<std::vector<Float_t> > multiclass_testEff;
1212  std::vector<std::vector<Float_t> > multiclass_trainEff;
1213  std::vector<std::vector<Float_t> > multiclass_testPur;
1214  std::vector<std::vector<Float_t> > multiclass_trainPur;
1215 
1216  std::vector<std::vector<Double_t> > biastrain(1); // "bias" of the regression on the training data
1217  std::vector<std::vector<Double_t> > biastest(1); // "bias" of the regression on test data
1218  std::vector<std::vector<Double_t> > devtrain(1); // "dev" of the regression on the training data
1219  std::vector<std::vector<Double_t> > devtest(1); // "dev" of the regression on test data
1220  std::vector<std::vector<Double_t> > rmstrain(1); // "rms" of the regression on the training data
1221  std::vector<std::vector<Double_t> > rmstest(1); // "rms" of the regression on test data
1222  std::vector<std::vector<Double_t> > minftrain(1); // "minf" of the regression on the training data
1223  std::vector<std::vector<Double_t> > minftest(1); // "minf" of the regression on test data
1224  std::vector<std::vector<Double_t> > rhotrain(1); // correlation of the regression on the training data
1225  std::vector<std::vector<Double_t> > rhotest(1); // correlation of the regression on test data
1226 
1227  // same as above but for 'truncated' quantities (computed for events within 2sigma of RMS)
1228  std::vector<std::vector<Double_t> > biastrainT(1);
1229  std::vector<std::vector<Double_t> > biastestT(1);
1230  std::vector<std::vector<Double_t> > devtrainT(1);
1231  std::vector<std::vector<Double_t> > devtestT(1);
1232  std::vector<std::vector<Double_t> > rmstrainT(1);
1233  std::vector<std::vector<Double_t> > rmstestT(1);
1234  std::vector<std::vector<Double_t> > minftrainT(1);
1235  std::vector<std::vector<Double_t> > minftestT(1);
1236 
1237  // following vector contains all methods - with the exception of Cuts, which are special
1238  MVector methodsNoCuts;
1239 
1240  Bool_t doRegression = kFALSE;
1241  Bool_t doMulticlass = kFALSE;
1242 
1243  // iterate over methods and evaluate
1244  MVector::iterator itrMethod = fMethods.begin();
1245  MVector::iterator itrMethodEnd = fMethods.end();
1246  for (; itrMethod != itrMethodEnd; itrMethod++) {
1248  MethodBase* theMethod = dynamic_cast<MethodBase*>(*itrMethod);
1249  if(theMethod==0) continue;
1250  if (theMethod->GetMethodType() != Types::kCuts) methodsNoCuts.push_back( *itrMethod );
1251 
1252  if (theMethod->DoRegression()) {
1253  doRegression = kTRUE;
1254 
1255  Log() << kINFO << "Evaluate regression method: " << theMethod->GetMethodName() << Endl;
1256  Double_t bias, dev, rms, mInf;
1257  Double_t biasT, devT, rmsT, mInfT;
1258  Double_t rho;
1259 
1260  theMethod->TestRegression( bias, biasT, dev, devT, rms, rmsT, mInf, mInfT, rho, TMVA::Types::kTesting );
1261  biastest[0] .push_back( bias );
1262  devtest[0] .push_back( dev );
1263  rmstest[0] .push_back( rms );
1264  minftest[0] .push_back( mInf );
1265  rhotest[0] .push_back( rho );
1266  biastestT[0] .push_back( biasT );
1267  devtestT[0] .push_back( devT );
1268  rmstestT[0] .push_back( rmsT );
1269  minftestT[0] .push_back( mInfT );
1270 
1271  theMethod->TestRegression( bias, biasT, dev, devT, rms, rmsT, mInf, mInfT, rho, TMVA::Types::kTraining );
1272  biastrain[0] .push_back( bias );
1273  devtrain[0] .push_back( dev );
1274  rmstrain[0] .push_back( rms );
1275  minftrain[0] .push_back( mInf );
1276  rhotrain[0] .push_back( rho );
1277  biastrainT[0].push_back( biasT );
1278  devtrainT[0] .push_back( devT );
1279  rmstrainT[0] .push_back( rmsT );
1280  minftrainT[0].push_back( mInfT );
1281 
1282  mname[0].push_back( theMethod->GetMethodName() );
1283  nmeth_used[0]++;
1284 
1285  Log() << kINFO << "Write evaluation histograms to file" << Endl;
1288  }
1289  else if (theMethod->DoMulticlass()) {
1290  doMulticlass = kTRUE;
1291  Log() << kINFO << "Evaluate multiclass classification method: " << theMethod->GetMethodName() << Endl;
1292  Log() << kINFO << "Write evaluation histograms to file" << Endl;
1295 
1296  theMethod->TestMulticlass();
1297  multiclass_testEff.push_back(theMethod->GetMulticlassEfficiency(multiclass_testPur));
1298 
1299  nmeth_used[0]++;
1300  mname[0].push_back( theMethod->GetMethodName() );
1301  }
1302  else {
1303 
1304  Log() << kINFO << "Evaluate classifier: " << theMethod->GetMethodName() << Endl;
1305  isel = (theMethod->GetMethodTypeName().Contains("Variable")) ? 1 : 0;
1306 
1307  // perform the evaluation
1308  theMethod->TestClassification();
1309 
1310  // evaluate the classifier
1311  mname[isel].push_back( theMethod->GetMethodName() );
1312  sig[isel].push_back ( theMethod->GetSignificance() );
1313  sep[isel].push_back ( theMethod->GetSeparation() );
1314  roc[isel].push_back ( theMethod->GetROCIntegral() );
1315 
1316  Double_t err;
1317  eff01[isel].push_back( theMethod->GetEfficiency("Efficiency:0.01", Types::kTesting, err) );
1318  eff01err[isel].push_back( err );
1319  eff10[isel].push_back( theMethod->GetEfficiency("Efficiency:0.10", Types::kTesting, err) );
1320  eff10err[isel].push_back( err );
1321  eff30[isel].push_back( theMethod->GetEfficiency("Efficiency:0.30", Types::kTesting, err) );
1322  eff30err[isel].push_back( err );
1323  effArea[isel].push_back( theMethod->GetEfficiency("", Types::kTesting, err) ); // computes the area (average)
1324 
1325  trainEff01[isel].push_back( theMethod->GetTrainingEfficiency("Efficiency:0.01") ); // the first pass takes longer
1326  trainEff10[isel].push_back( theMethod->GetTrainingEfficiency("Efficiency:0.10") );
1327  trainEff30[isel].push_back( theMethod->GetTrainingEfficiency("Efficiency:0.30") );
1328 
1329  nmeth_used[isel]++;
1330 
1331  Log() << kINFO << "Write evaluation histograms to file" << Endl;
1334  }
1335  }
1336  if (doRegression) {
1337 
1338  std::vector<TString> vtemps = mname[0];
1339  std::vector< std::vector<Double_t> > vtmp;
1340  vtmp.push_back( devtest[0] ); // this is the vector that is ranked
1341  vtmp.push_back( devtrain[0] );
1342  vtmp.push_back( biastest[0] );
1343  vtmp.push_back( biastrain[0] );
1344  vtmp.push_back( rmstest[0] );
1345  vtmp.push_back( rmstrain[0] );
1346  vtmp.push_back( minftest[0] );
1347  vtmp.push_back( minftrain[0] );
1348  vtmp.push_back( rhotest[0] );
1349  vtmp.push_back( rhotrain[0] );
1350  vtmp.push_back( devtestT[0] ); // this is the vector that is ranked
1351  vtmp.push_back( devtrainT[0] );
1352  vtmp.push_back( biastestT[0] );
1353  vtmp.push_back( biastrainT[0]);
1354  vtmp.push_back( rmstestT[0] );
1355  vtmp.push_back( rmstrainT[0] );
1356  vtmp.push_back( minftestT[0] );
1357  vtmp.push_back( minftrainT[0]);
1358  gTools().UsefulSortAscending( vtmp, &vtemps );
1359  mname[0] = vtemps;
1360  devtest[0] = vtmp[0];
1361  devtrain[0] = vtmp[1];
1362  biastest[0] = vtmp[2];
1363  biastrain[0] = vtmp[3];
1364  rmstest[0] = vtmp[4];
1365  rmstrain[0] = vtmp[5];
1366  minftest[0] = vtmp[6];
1367  minftrain[0] = vtmp[7];
1368  rhotest[0] = vtmp[8];
1369  rhotrain[0] = vtmp[9];
1370  devtestT[0] = vtmp[10];
1371  devtrainT[0] = vtmp[11];
1372  biastestT[0] = vtmp[12];
1373  biastrainT[0] = vtmp[13];
1374  rmstestT[0] = vtmp[14];
1375  rmstrainT[0] = vtmp[15];
1376  minftestT[0] = vtmp[16];
1377  minftrainT[0] = vtmp[17];
1378  }
1379  else if (doMulticlass) {
1380  // TODO: fill in something meaningfull
1381 
1382  }
1383  else {
1384  // now sort the variables according to the best 'eff at Beff=0.10'
1385  for (Int_t k=0; k<2; k++) {
1386  std::vector< std::vector<Double_t> > vtemp;
1387  vtemp.push_back( effArea[k] ); // this is the vector that is ranked
1388  vtemp.push_back( eff10[k] );
1389  vtemp.push_back( eff01[k] );
1390  vtemp.push_back( eff30[k] );
1391  vtemp.push_back( eff10err[k] );
1392  vtemp.push_back( eff01err[k] );
1393  vtemp.push_back( eff30err[k] );
1394  vtemp.push_back( trainEff10[k] );
1395  vtemp.push_back( trainEff01[k] );
1396  vtemp.push_back( trainEff30[k] );
1397  vtemp.push_back( sig[k] );
1398  vtemp.push_back( sep[k] );
1399  vtemp.push_back( roc[k] );
1400  std::vector<TString> vtemps = mname[k];
1401  gTools().UsefulSortDescending( vtemp, &vtemps );
1402  effArea[k] = vtemp[0];
1403  eff10[k] = vtemp[1];
1404  eff01[k] = vtemp[2];
1405  eff30[k] = vtemp[3];
1406  eff10err[k] = vtemp[4];
1407  eff01err[k] = vtemp[5];
1408  eff30err[k] = vtemp[6];
1409  trainEff10[k] = vtemp[7];
1410  trainEff01[k] = vtemp[8];
1411  trainEff30[k] = vtemp[9];
1412  sig[k] = vtemp[10];
1413  sep[k] = vtemp[11];
1414  roc[k] = vtemp[12];
1415  mname[k] = vtemps;
1416  }
1417  }
1418 
1419  // -----------------------------------------------------------------------
1420  // Second part of evaluation process
1421  // --> compute correlations among MVAs
1422  // --> compute correlations between input variables and MVA (determines importsance)
1423  // --> count overlaps
1424  // -----------------------------------------------------------------------
1425 
1426  const Int_t nmeth = methodsNoCuts.size();
1427  const Int_t nvar = DefaultDataSetInfo().GetNVariables();
1428  if (!doRegression && !doMulticlass ) {
1429 
1430  if (nmeth > 0) {
1431 
1432  // needed for correlations
1433  Double_t *dvec = new Double_t[nmeth+nvar];
1434  std::vector<Double_t> rvec;
1435 
1436  // for correlations
1437  TPrincipal* tpSig = new TPrincipal( nmeth+nvar, "" );
1438  TPrincipal* tpBkg = new TPrincipal( nmeth+nvar, "" );
1439 
1440  // set required tree branch references
1441  Int_t ivar = 0;
1442  std::vector<TString>* theVars = new std::vector<TString>;
1443  std::vector<ResultsClassification*> mvaRes;
1444  for (itrMethod = methodsNoCuts.begin(); itrMethod != methodsNoCuts.end(); itrMethod++, ivar++) {
1445  MethodBase* m = dynamic_cast<MethodBase*>(*itrMethod);
1446  if(m==0) continue;
1447  theVars->push_back( m->GetTestvarName() );
1448  rvec.push_back( m->GetSignalReferenceCut() );
1449  theVars->back().ReplaceAll( "MVA_", "" );
1450  mvaRes.push_back( dynamic_cast<ResultsClassification*>( m->Data()->GetResults( m->GetMethodName(),
1451  Types::kTesting,
1453  }
1454 
1455  // for overlap study
1456  TMatrixD* overlapS = new TMatrixD( nmeth, nmeth );
1457  TMatrixD* overlapB = new TMatrixD( nmeth, nmeth );
1458  (*overlapS) *= 0; // init...
1459  (*overlapB) *= 0; // init...
1460 
1461  // loop over test tree
1462  DataSet* defDs = DefaultDataSetInfo().GetDataSet();
1464  for (Int_t ievt=0; ievt<defDs->GetNEvents(); ievt++) {
1465  const Event* ev = defDs->GetEvent(ievt);
1466 
1467  // for correlations
1468  TMatrixD* theMat = 0;
1469  for (Int_t im=0; im<nmeth; im++) {
1470  // check for NaN value
1471  Double_t retval = (Double_t)(*mvaRes[im])[ievt][0];
1472  if (TMath::IsNaN(retval)) {
1473  Log() << kWARNING << "Found NaN return value in event: " << ievt
1474  << " for method \"" << methodsNoCuts[im]->GetName() << "\"" << Endl;
1475  dvec[im] = 0;
1476  }
1477  else dvec[im] = retval;
1478  }
1479  for (Int_t iv=0; iv<nvar; iv++) dvec[iv+nmeth] = (Double_t)ev->GetValue(iv);
1480  if (DefaultDataSetInfo().IsSignal(ev)) { tpSig->AddRow( dvec ); theMat = overlapS; }
1481  else { tpBkg->AddRow( dvec ); theMat = overlapB; }
1482 
1483  // count overlaps
1484  for (Int_t im=0; im<nmeth; im++) {
1485  for (Int_t jm=im; jm<nmeth; jm++) {
1486  if ((dvec[im] - rvec[im])*(dvec[jm] - rvec[jm]) > 0) {
1487  (*theMat)(im,jm)++;
1488  if (im != jm) (*theMat)(jm,im)++;
1489  }
1490  }
1491  }
1492  }
1493 
1494  // renormalise overlap matrix
1495  (*overlapS) *= (1.0/defDs->GetNEvtSigTest()); // init...
1496  (*overlapB) *= (1.0/defDs->GetNEvtBkgdTest()); // init...
1497 
1498  tpSig->MakePrincipals();
1499  tpBkg->MakePrincipals();
1500 
1501  const TMatrixD* covMatS = tpSig->GetCovarianceMatrix();
1502  const TMatrixD* covMatB = tpBkg->GetCovarianceMatrix();
1503 
1504  const TMatrixD* corrMatS = gTools().GetCorrelationMatrix( covMatS );
1505  const TMatrixD* corrMatB = gTools().GetCorrelationMatrix( covMatB );
1506 
1507  // print correlation matrices
1508  if (corrMatS != 0 && corrMatB != 0) {
1509 
1510  // extract MVA matrix
1511  TMatrixD mvaMatS(nmeth,nmeth);
1512  TMatrixD mvaMatB(nmeth,nmeth);
1513  for (Int_t im=0; im<nmeth; im++) {
1514  for (Int_t jm=0; jm<nmeth; jm++) {
1515  mvaMatS(im,jm) = (*corrMatS)(im,jm);
1516  mvaMatB(im,jm) = (*corrMatB)(im,jm);
1517  }
1518  }
1519 
1520  // extract variables - to MVA matrix
1521  std::vector<TString> theInputVars;
1522  TMatrixD varmvaMatS(nvar,nmeth);
1523  TMatrixD varmvaMatB(nvar,nmeth);
1524  for (Int_t iv=0; iv<nvar; iv++) {
1525  theInputVars.push_back( DefaultDataSetInfo().GetVariableInfo( iv ).GetLabel() );
1526  for (Int_t jm=0; jm<nmeth; jm++) {
1527  varmvaMatS(iv,jm) = (*corrMatS)(nmeth+iv,jm);
1528  varmvaMatB(iv,jm) = (*corrMatB)(nmeth+iv,jm);
1529  }
1530  }
1531 
1532  if (nmeth > 1) {
1533  Log() << kINFO << Endl;
1534  Log() << kINFO << "Inter-MVA correlation matrix (signal):" << Endl;
1535  gTools().FormattedOutput( mvaMatS, *theVars, Log() );
1536  Log() << kINFO << Endl;
1537 
1538  Log() << kINFO << "Inter-MVA correlation matrix (background):" << Endl;
1539  gTools().FormattedOutput( mvaMatB, *theVars, Log() );
1540  Log() << kINFO << Endl;
1541  }
1542 
1543  Log() << kINFO << "Correlations between input variables and MVA response (signal):" << Endl;
1544  gTools().FormattedOutput( varmvaMatS, theInputVars, *theVars, Log() );
1545  Log() << kINFO << Endl;
1546 
1547  Log() << kINFO << "Correlations between input variables and MVA response (background):" << Endl;
1548  gTools().FormattedOutput( varmvaMatB, theInputVars, *theVars, Log() );
1549  Log() << kINFO << Endl;
1550  }
1551  else Log() << kWARNING << "<TestAllMethods> cannot compute correlation matrices" << Endl;
1552 
1553  // print overlap matrices
1554  Log() << kINFO << "The following \"overlap\" matrices contain the fraction of events for which " << Endl;
1555  Log() << kINFO << "the MVAs 'i' and 'j' have returned conform answers about \"signal-likeness\"" << Endl;
1556  Log() << kINFO << "An event is signal-like, if its MVA output exceeds the following value:" << Endl;
1557  gTools().FormattedOutput( rvec, *theVars, "Method" , "Cut value", Log() );
1558  Log() << kINFO << "which correspond to the working point: eff(signal) = 1 - eff(background)" << Endl;
1559 
1560  // give notice that cut method has been excluded from this test
1561  if (nmeth != (Int_t)fMethods.size())
1562  Log() << kINFO << "Note: no correlations and overlap with cut method are provided at present" << Endl;
1563 
1564  if (nmeth > 1) {
1565  Log() << kINFO << Endl;
1566  Log() << kINFO << "Inter-MVA overlap matrix (signal):" << Endl;
1567  gTools().FormattedOutput( *overlapS, *theVars, Log() );
1568  Log() << kINFO << Endl;
1569 
1570  Log() << kINFO << "Inter-MVA overlap matrix (background):" << Endl;
1571  gTools().FormattedOutput( *overlapB, *theVars, Log() );
1572  }
1573 
1574  // cleanup
1575  delete tpSig;
1576  delete tpBkg;
1577  delete corrMatS;
1578  delete corrMatB;
1579  delete theVars;
1580  delete overlapS;
1581  delete overlapB;
1582  delete [] dvec;
1583  }
1584  }
1585 
1586  // -----------------------------------------------------------------------
1587  // Third part of evaluation process
1588  // --> output
1589  // -----------------------------------------------------------------------
1590 
1591  if (doRegression) {
1592 
1593  Log() << kINFO << Endl;
1594  TString hLine = "-------------------------------------------------------------------------";
1595  Log() << kINFO << "Evaluation results ranked by smallest RMS on test sample:" << Endl;
1596  Log() << kINFO << "(\"Bias\" quotes the mean deviation of the regression from true target." << Endl;
1597  Log() << kINFO << " \"MutInf\" is the \"Mutual Information\" between regression and target." << Endl;
1598  Log() << kINFO << " Indicated by \"_T\" are the corresponding \"truncated\" quantities ob-" << Endl;
1599  Log() << kINFO << " tained when removing events deviating more than 2sigma from average.)" << Endl;
1600  Log() << kINFO << hLine << Endl;
1601  Log() << kINFO << "MVA Method: <Bias> <Bias_T> RMS RMS_T | MutInf MutInf_T" << Endl;
1602  Log() << kINFO << hLine << Endl;
1603 
1604  for (Int_t i=0; i<nmeth_used[0]; i++) {
1605  Log() << kINFO << Form("%-15s:%#9.3g%#9.3g%#9.3g%#9.3g | %#5.3f %#5.3f",
1606  (const char*)mname[0][i],
1607  biastest[0][i], biastestT[0][i],
1608  rmstest[0][i], rmstestT[0][i],
1609  minftest[0][i], minftestT[0][i] )
1610  << Endl;
1611  }
1612  Log() << kINFO << hLine << Endl;
1613  Log() << kINFO << Endl;
1614  Log() << kINFO << "Evaluation results ranked by smallest RMS on training sample:" << Endl;
1615  Log() << kINFO << "(overtraining check)" << Endl;
1616  Log() << kINFO << hLine << Endl;
1617  Log() << kINFO << "MVA Method: <Bias> <Bias_T> RMS RMS_T | MutInf MutInf_T" << Endl;
1618  Log() << kINFO << hLine << Endl;
1619 
1620  for (Int_t i=0; i<nmeth_used[0]; i++) {
1621  Log() << kINFO << Form("%-15s:%#9.3g%#9.3g%#9.3g%#9.3g | %#5.3f %#5.3f",
1622  (const char*)mname[0][i],
1623  biastrain[0][i], biastrainT[0][i],
1624  rmstrain[0][i], rmstrainT[0][i],
1625  minftrain[0][i], minftrainT[0][i] )
1626  << Endl;
1627  }
1628  Log() << kINFO << hLine << Endl;
1629  Log() << kINFO << Endl;
1630  }
1631  else if( doMulticlass ){
1632  Log() << Endl;
1633  TString hLine = "--------------------------------------------------------------------------------";
1634  Log() << kINFO << "Evaluation results ranked by best signal efficiency times signal purity " << Endl;
1635  Log() << kINFO << hLine << Endl;
1636  TString header= "MVA Method ";
1637  for(UInt_t icls = 0; icls<DefaultDataSetInfo().GetNClasses(); ++icls){
1638  header += Form("%-12s ",DefaultDataSetInfo().GetClassInfo(icls)->GetName().Data());
1639  }
1640  Log() << kINFO << header << Endl;
1641  Log() << kINFO << hLine << Endl;
1642  for (Int_t i=0; i<nmeth_used[0]; i++) {
1643  TString res = Form("%-15s",(const char*)mname[0][i]);
1644  for(UInt_t icls = 0; icls<DefaultDataSetInfo().GetNClasses(); ++icls){
1645  res += Form("%#1.3f ",(multiclass_testEff[i][icls])*(multiclass_testPur[i][icls]));
1646  }
1647  Log() << kINFO << res << Endl;
1648  }
1649  Log() << kINFO << hLine << Endl;
1650  Log() << kINFO << Endl;
1651 
1652  }
1653  else {
1654  Log() << Endl;
1655  TString hLine = "--------------------------------------------------------------------------------";
1656  Log() << kINFO << "Evaluation results ranked by best signal efficiency and purity (area)" << Endl;
1657  Log() << kINFO << hLine << Endl;
1658  Log() << kINFO << "MVA Signal efficiency at bkg eff.(error): | Sepa- Signifi- " << Endl;
1659  Log() << kINFO << "Method: @B=0.01 @B=0.10 @B=0.30 ROC-integ. | ration: cance: " << Endl;
1660  Log() << kINFO << hLine << Endl;
1661  for (Int_t k=0; k<2; k++) {
1662  if (k == 1 && nmeth_used[k] > 0) {
1663  Log() << kINFO << hLine << Endl;
1664  Log() << kINFO << "Input Variables: " << Endl << hLine << Endl;
1665  }
1666  for (Int_t i=0; i<nmeth_used[k]; i++) {
1667  if (k == 1) mname[k][i].ReplaceAll( "Variable_", "" );
1668  if (sep[k][i] < 0 || sig[k][i] < 0) {
1669  // cannot compute separation/significance -> no MVA (usually for Cuts)
1670  Log() << kINFO << Form("%-15s: %#1.3f(%02i) %#1.3f(%02i) %#1.3f(%02i) %#1.3f | -- --",
1671  (const char*)mname[k][i],
1672  eff01[k][i], Int_t(1000*eff01err[k][i]),
1673  eff10[k][i], Int_t(1000*eff10err[k][i]),
1674  eff30[k][i], Int_t(1000*eff30err[k][i]),
1675  effArea[k][i]) << Endl;
1676  }
1677  else {
1678  Log() << kINFO << Form("%-15s: %#1.3f(%02i) %#1.3f(%02i) %#1.3f(%02i) %#1.3f | %#1.3f %#1.3f",
1679  (const char*)mname[k][i],
1680  eff01[k][i], Int_t(1000*eff01err[k][i]),
1681  eff10[k][i], Int_t(1000*eff10err[k][i]),
1682  eff30[k][i], Int_t(1000*eff30err[k][i]),
1683  effArea[k][i],
1684  sep[k][i], sig[k][i]) << Endl;
1685  }
1686  }
1687  }
1688  Log() << kINFO << hLine << Endl;
1689  Log() << kINFO << Endl;
1690  Log() << kINFO << "Testing efficiency compared to training efficiency (overtraining check)" << Endl;
1691  Log() << kINFO << hLine << Endl;
1692  Log() << kINFO << "MVA Signal efficiency: from test sample (from training sample) " << Endl;
1693  Log() << kINFO << "Method: @B=0.01 @B=0.10 @B=0.30 " << Endl;
1694  Log() << kINFO << hLine << Endl;
1695  for (Int_t k=0; k<2; k++) {
1696  if (k == 1 && nmeth_used[k] > 0) {
1697  Log() << kINFO << hLine << Endl;
1698  Log() << kINFO << "Input Variables: " << Endl << hLine << Endl;
1699  }
1700  for (Int_t i=0; i<nmeth_used[k]; i++) {
1701  if (k == 1) mname[k][i].ReplaceAll( "Variable_", "" );
1702  Log() << kINFO << Form("%-15s: %#1.3f (%#1.3f) %#1.3f (%#1.3f) %#1.3f (%#1.3f)",
1703  (const char*)mname[k][i],
1704  eff01[k][i],trainEff01[k][i],
1705  eff10[k][i],trainEff10[k][i],
1706  eff30[k][i],trainEff30[k][i]) << Endl;
1707  }
1708  }
1709  Log() << kINFO << hLine << Endl;
1710  Log() << kINFO << Endl;
1711  }
1712 
1713  // write test tree
1714  RootBaseDir()->cd();
1715  DefaultDataSetInfo().GetDataSet()->GetTree(Types::kTesting) ->Write( "", TObject::kOverwrite );
1716  DefaultDataSetInfo().GetDataSet()->GetTree(Types::kTraining)->Write( "", TObject::kOverwrite );
1717 
1718  // references for citation
1720 }
1721 
IMethod * Create(const std::string &name, const TString &job, const TString &title, DataSetInfo &dsi, const TString &option)
creates the method if needed based on the method name using the creator function the factory has stor...
static ClassifierFactory & Instance()
access to the ClassifierFactory singleton creates the instance if needed
virtual Int_t Write(const char *name=0, Int_t option=0, Int_t bufsize=0)
Write this object to the current directory.
Definition: TObject.cxx:823
Principal Components Analysis (PCA)
Definition: TPrincipal.h:28
void AddSignalTrainingEvent(const std::vector< Double_t > &event, Double_t weight=1.0)
add signal training event
Definition: Factory.cxx:284
void UsefulSortDescending(std::vector< std::vector< Double_t > > &, std::vector< TString > *vs=0)
sort 2D vector (AND in parallel a TString vector) in such a way that the "first vector is sorted" and...
Definition: Tools.cxx:573
void SetInputTrees(const TString &signalFileName, const TString &backgroundFileName, Double_t signalWeight=1.0, Double_t backgroundWeight=1.0)
Definition: Factory.cxx:519
virtual void MakeClass(const TString &methodTitle="") const
Print predefined help message of classifier iterate over methods and test.
Definition: Factory.cxx:1115
void OptimizeAllMethods(TString fomType="ROCIntegral", TString fitType="FitGA")
iterates through all booked methods and sees if they use parameter tuning and if so.
Definition: Factory.cxx:931
static Vc_ALWAYS_INLINE int_v min(const int_v &x, const int_v &y)
Definition: vector.h:433
MsgLogger & Endl(MsgLogger &ml)
Definition: MsgLogger.h:162
static void SetIsTraining(Bool_t)
when this static function is called, it sets the flag whether events with negative event weight shoul...
Definition: Event.cxx:386
void AddOutput(Types::ETreeType type, Types::EAnalysisType analysisType)
static void CreateVariableTransforms(const TString &trafoDefinition, TMVA::DataSetInfo &dataInfo, TMVA::TransformationHandler &transformationHandler, TMVA::MsgLogger &log)
create variable transformations
Definition: MethodBase.cxx:488
float Float_t
Definition: RtypesCore.h:53
void ROOTVersionMessage(MsgLogger &logger)
prints the ROOT release number and date
Definition: Tools.cxx:1333
TString & ReplaceAll(const TString &s1, const TString &s2)
Definition: TString.h:635
const char * GetName() const
Returns name of object.
Definition: MethodBase.h:299
virtual void WriteEvaluationHistosToFile(Types::ETreeType treetype)
writes all MVA evaluation histograms to file
void SetSignalWeightExpression(const TString &variable)
Definition: Factory.cxx:595
virtual std::map< TString, Double_t > OptimizeTuningParameters(TString fomType="ROCIntegral", TString fitType="FitGA")
call the Optimizer with the set of parameters and ranges that are meant to be tuned.
Definition: MethodBase.cxx:633
Config & gConfig()
void SetInputVariables(std::vector< TString > *theVariables)
fill input variables in data set
Definition: Factory.cxx:587
TH1 * h
Definition: legend2.C:5
OptionBase * DeclareOptionRef(T &ref, const TString &name, const TString &desc="")
A ROOT file is a suite of consecutive data records (TKey instances) with a well defined format...
Definition: TFile.h:45
DataSet * Data() const
Definition: MethodBase.h:363
EAnalysisType
Definition: Types.h:124
virtual void MakeClass(const TString &classFileName=TString("")) const =0
void AddBackgroundTrainingEvent(const std::vector< Double_t > &event, Double_t weight=1.0)
add background training event
Definition: Factory.cxx:300
#define gROOT
Definition: TROOT.h:344
void AddSpectator(const TString &expression, const TString &title="", const TString &unit="", Double_t min=0, Double_t max=0)
user inserts spectator in data set info
Definition: Factory.cxx:570
Basic string class.
Definition: TString.h:137
void ToLower()
Change string to lower-case.
Definition: TString.cxx:1075
int Int_t
Definition: RtypesCore.h:41
bool Bool_t
Definition: RtypesCore.h:59
void AddVariable(const TString &expression, const TString &title, const TString &unit, char type='F', Double_t min=0, Double_t max=0)
user inserts discriminating variable in data set info
Definition: Factory.cxx:540
void TrainAllMethods()
iterates through all booked methods and calls training
Definition: Factory.cxx:965
const Bool_t kFALSE
Definition: Rtypes.h:92
virtual void TestMulticlass()
test multiclass classification
DataSetInfo & DefaultDataSetInfo()
default creation
Definition: Factory.cxx:579
virtual Bool_t HasAnalysisType(Types::EAnalysisType type, UInt_t numberClasses, UInt_t numberTargets)=0
void AddTrainingEvent(const TString &className, const std::vector< Double_t > &event, Double_t weight)
add training event for the given class
Definition: Factory.cxx:316
Bool_t BeginsWith(const char *s, ECaseCompare cmp=kExact) const
Definition: TString.h:558
void AddTarget(const TString &expression, const TString &title="", const TString &unit="", Double_t min=0, Double_t max=0)
user inserts target in data set info
Definition: Factory.cxx:558
void SetSignalTree(TTree *signal, Double_t weight=1.0)
Definition: Factory.cxx:487
void WriteDataInformation()
put correlations of input data and a few (default + user selected) transformations into the root file...
Definition: Factory.cxx:837
const TString & GetMethodName() const
Definition: MethodBase.h:296
TString GetWeightFileName() const
retrieve weight file name
void AddTestEvent(const TString &className, const std::vector< Double_t > &event, Double_t weight)
add test event for the given class
Definition: Factory.cxx:324
void AddBackgroundTree(TTree *background, Double_t weight=1.0, Types::ETreeType treetype=Types::kMaxTreeType)
add background tree to the data set
Definition: Factory.cxx:458
#define RECREATE_METHODS
const char * Data() const
Definition: TString.h:349
void AddEvent(const TString &className, Types::ETreeType tt, const std::vector< Double_t > &event, Double_t weight)
add event vector event : the order of values is: variables + targets + spectators ...
Definition: Factory.cxx:333
static Types & Instance()
return the single instance of "Types" if existing already, or create it (Singleton)
Definition: Types.cxx:64
Types::EAnalysisType GetAnalysisType() const
Definition: MethodBase.h:391
Tools & gTools()
Definition: Tools.cxx:79
#define READXML
TText * tt
Definition: textangle.C:16
IMethod * GetMethod(const TString &title) const
returns pointer to MVA that corresponds to given method title
Definition: Factory.cxx:821
ClassImp(TMVA::Factory) TMVA
standard constructor jobname : this name will appear in all weight file names produced by the MVAs th...
Definition: Factory.cxx:90
void ReadStateFromFile()
Function to read options and weights from file.
virtual void MakeClass(const TString &classFileName=TString("")) const
create reader class for method (classification only at present)
void EvaluateAllVariables(TString options="")
iterates over all MVA input variables and evaluates them
Definition: Factory.cxx:1170
Bool_t DoMulticlass() const
Definition: MethodBase.h:393
std::vector< std::vector< double > > Data
virtual void ParseOptions()
options parser
void SetupMethod()
setup of methods
Definition: MethodBase.cxx:302
TTree * CreateEventAssignTrees(const TString &name)
create the data assignment tree (for event-wise data assignment by user)
Definition: Factory.cxx:251
void SetDrawProgressBar(Bool_t d)
Definition: Config.h:70
Types::EMVA GetMethodType() const
Definition: MethodBase.h:298
virtual Double_t GetEfficiency(const TString &, Types::ETreeType, Double_t &err)
fill background efficiency (resp.
virtual std::vector< Float_t > GetMulticlassEfficiency(std::vector< std::vector< Float_t > > &purity)
TCppMethod_t GetMethod(TCppScope_t scope, TCppIndex_t imeth)
Definition: Cppyy.cxx:700
A specialized string object used for TTree selections.
Definition: TCut.h:27
const TMatrixD * GetCovarianceMatrix() const
Definition: TPrincipal.h:66
TMatrixT< Double_t > TMatrixD
Definition: TMatrixDfwd.h:24
static void DestroyInstance()
Definition: Tools.cxx:95
Bool_t UserAssignEvents(UInt_t clIndex)
Definition: Factory.cxx:367
const Int_t MinNoTrainingEvents
Definition: Factory.cxx:86
virtual ~Factory()
destructor delete fATreeEvent;
Definition: Factory.cxx:188
void AddBackgroundTestEvent(const std::vector< Double_t > &event, Double_t weight=1.0)
add background test event
Definition: Factory.cxx:308
void SetCut(const TString &cut, const TString &className="")
Definition: Factory.cxx:621
DataSetInfo & AddDataSet(DataSetInfo &)
Definition: Factory.cxx:229
Results * GetResults(const TString &, Types::ETreeType type, Types::EAnalysisType analysistype)
TString info(resultsName+"/"); switch(type) { case Types::kTraining: info += "kTraining/"; break; cas...
Definition: DataSet.cxx:263
std::string GetMethodName(TCppMethod_t)
Definition: Cppyy.cxx:706
Service class for 2-Dim histogram classes.
Definition: TH2.h:36
MethodBase * BookMethod(TString theMethodName, TString methodTitle, TString theOption="")
Book a classifier or regression method.
Definition: Factory.cxx:706
virtual void AddRow(const Double_t *x)
Add a data point and update the covariance matrix.
Definition: TPrincipal.cxx:410
SVector< double, 2 > v
Definition: Dict.h:5
void TMVAWelcomeMessage()
direct output, eg, when starting ROOT session -> no use of Logger here
Definition: Tools.cxx:1310
Long64_t GetNEvtSigTest()
return number of signal test events in dataset
Definition: DataSet.cxx:396
TPaveLabel title(3, 27.1, 15, 28.7,"ROOT Environment and Tools")
void EvaluateAllMethods(void)
iterates over all MVAs that have been booked, and calls their evaluation methods
Definition: Factory.cxx:1185
void TestAllMethods()
Definition: Factory.cxx:1085
unsigned int UInt_t
Definition: RtypesCore.h:42
TMarker * m
Definition: textangle.C:8
char * Form(const char *fmt,...)
DataSetManager * fDataSetManager
void AddCut(const TString &cut, const TString &className="")
Definition: Factory.cxx:634
virtual const char * GetName() const
Returns name of object.
Definition: TNamed.h:51
void SetBackgroundWeightExpression(const TString &variable)
Definition: Factory.cxx:602
virtual Double_t GetSignificance() const
compute significance of mean difference significance = |<S> - |/Sqrt(RMS_S2 + RMS_B2) ...
void Greetings()
print welcome message options are: kLogoWelcomeMsg, kIsometricWelcomeMsg, kLeanWelcomeMsg ...
Definition: Factory.cxx:177
virtual void MakePrincipals()
Perform the principal components analysis.
Definition: TPrincipal.cxx:862
void SetBoostedMethodName(TString methodName)
Definition: MethodBoost.h:87
void SetVerbose(Bool_t v=kTRUE)
Definition: Factory.cxx:222
Long64_t GetNEvtBkgdTest()
return number of background test events in dataset
Definition: DataSet.cxx:404
const Event * GetEvent() const
Definition: DataSet.cxx:186
void SetInputTreesFromEventAssignTrees()
assign event-wise local trees to data set
Definition: Factory.cxx:375
void PrintHelpMessage(const TString &methodTitle="") const
Print predefined help message of classifier iterate over methods and test.
Definition: Factory.cxx:1143
void SetCurrentType(Types::ETreeType type) const
Definition: DataSet.h:111
tuple tree
Definition: tree.py:24
virtual void SetDirectory(TDirectory *dir)
Change the tree's directory.
Definition: TTree.cxx:8064
virtual void CheckSetup()
check may be overridden by derived class (sometimes, eg, fitters are used which can only be implement...
Definition: MethodBase.cxx:327
double Double_t
Definition: RtypesCore.h:55
virtual void PrintHelpMessage() const =0
virtual Double_t GetROCIntegral(TH1D *histS, TH1D *histB) const
calculate the area (integral) under the ROC curve as a overall quality measure of the classification ...
const TMatrixD * GetCorrelationMatrix(const TMatrixD *covMat)
turns covariance into correlation matrix
Definition: Tools.cxx:337
void AddTree(TTree *tree, const TString &className, Double_t weight=1.0, const TCut &cut="", Types::ETreeType tt=Types::kMaxTreeType)
Definition: Factory.cxx:407
int type
Definition: TGX11.cxx:120
Long64_t GetNEvents(Types::ETreeType type=Types::kMaxTreeType) const
Definition: DataSet.h:225
Float_t GetValue(UInt_t ivar) const
return value of i'th variable
Definition: Event.cxx:231
static void DestroyInstance()
static function: destroy TMVA instance
Definition: Config.cxx:81
MsgLogger & Log() const
Definition: Configurable.h:130
DataSetInfo & DataInfo() const
Definition: MethodBase.h:364
void UsefulSortAscending(std::vector< std::vector< Double_t > > &, std::vector< TString > *vs=0)
sort 2D vector (AND in parallel a TString vector) in such a way that the "first vector is sorted" and...
Definition: Tools.cxx:547
RooCmdArg Verbose(Bool_t flag=kTRUE)
void TMVAVersionMessage(MsgLogger &logger)
prints the TMVA release number and date
Definition: Tools.cxx:1324
void ProcessSetup()
process all options the "CheckForUnusedOptions" is done in an independent call, since it may be overr...
Definition: MethodBase.cxx:317
static Vc_ALWAYS_INLINE int_v max(const int_v &x, const int_v &y)
Definition: vector.h:440
#define name(a, b)
Definition: linkTestLib0.cpp:5
void PrintVariableRanking() const
prints ranking of input variables
void FormattedOutput(const std::vector< Double_t > &, const std::vector< TString > &, const TString titleVars, const TString titleValues, MsgLogger &logger, TString format="%+1.3f")
formatted output of simple table
Definition: Tools.cxx:896
const TString & Color(const TString &)
human readable color strings
Definition: Tools.cxx:837
void AddSignalTree(TTree *signal, Double_t weight=1.0, Types::ETreeType treetype=Types::kMaxTreeType)
add signal tree to the data set
Definition: Factory.cxx:427
virtual Long64_t ReadFile(const char *filename, const char *branchDescriptor="", char delimiter= ' ')
Create or simply read branches from filename.
Definition: TTree.cxx:6801
void SetUseColor(Bool_t uc)
Definition: Config.h:61
virtual Int_t Branch(TCollection *list, Int_t bufsize=32000, Int_t splitlevel=99, const char *name="")
Create one branch for each element in the collection.
Definition: TTree.cxx:1623
Bool_t DoRegression() const
Definition: MethodBase.h:392
virtual Double_t GetSeparation(TH1 *, TH1 *) const
compute "separation" defined as <s2> = (1/2) Int_-oo..+oo { (S(x) - B(x))^2/(S(x) + B(x)) dx } ...
virtual void TestRegression(Double_t &bias, Double_t &biasT, Double_t &dev, Double_t &devT, Double_t &rms, Double_t &rmsT, Double_t &mInf, Double_t &mInfT, Double_t &corr, Types::ETreeType type)
calculate <sum-of-deviation-squared> of regression output versus "true" value from test sample ...
Definition: MethodBase.cxx:938
DataSetManager * fDataSetManager
Definition: MethodBoost.h:194
void SetWeightExpression(const TString &variable, const TString &className="")
Log() << kWarning << DefaultDataSetInfo().GetNClasses() /*fClasses.size()*/ << Endl;.
Definition: Factory.cxx:610
Factory(TString theJobName, TFile *theTargetFile, TString theOption="")
void PrintHelpMessage() const
prints out method-specific help method
void SetSilent(Bool_t s)
Definition: Config.h:64
Bool_t Contains(const char *pat, ECaseCompare cmp=kExact) const
Definition: TString.h:567
Int_t IsNaN(Double_t x)
Definition: TMath.h:617
void DeleteAllMethods(void)
delete methods
Definition: Factory.cxx:210
#define NULL
Definition: Rtypes.h:82
virtual Double_t GetTrainingEfficiency(const TString &)
virtual Long64_t GetEntries() const
Definition: TTree.h:386
void AddSignalTestEvent(const std::vector< Double_t > &event, Double_t weight=1.0)
add signal testing event
Definition: Factory.cxx:292
Long64_t GetNTrainingEvents() const
Definition: DataSet.h:90
A TTree object has a header with a name and a title.
Definition: TTree.h:98
void TMVACitation(MsgLogger &logger, ECitation citType=kPlainText)
kinds of TMVA citation
Definition: Tools.cxx:1448
std::vector< IMethod * > MVector
Definition: Factory.h:80
Double_t GetSignalReferenceCut() const
Definition: MethodBase.h:325
virtual void Print() const
get maximum length of variable names
Definition: Ranking.cxx:111
TString GetMethodTypeName() const
Definition: MethodBase.h:297
const Bool_t kTRUE
Definition: Rtypes.h:91
void SetTestvarName(const TString &v="")
Definition: MethodBase.h:306
UInt_t GetNumber() const
Definition: ClassInfo.h:75
void SetTree(TTree *tree, const TString &className, Double_t weight)
set background tree
Definition: Factory.cxx:502
virtual void TestClassification()
initialization
std::vector< TString > SplitString(const TString &theOpt, const char separator) const
splits the option string at 'separator' and fills the list 'splitV' with the primitive strings ...
Definition: Tools.cxx:1207
TH1F * background
Definition: fithist.C:4
void PrepareTrainingAndTestTree(const TCut &cut, const TString &splitOpt)
prepare the training and test trees -> same cuts for signal and background
Definition: Factory.cxx:679
virtual void SetAnalysisType(Types::EAnalysisType type)
Definition: MethodBase.h:390
Definition: math.cpp:60
const TString & GetTestvarName() const
Definition: MethodBase.h:300
void SetBackgroundTree(TTree *background, Double_t weight=1.0)
Definition: Factory.cxx:494