Logo ROOT  
Reference Guide
Factory.h
Go to the documentation of this file.
1// @(#)root/tmva $Id$
2// Author: Andreas Hoecker, Peter Speckmayer, Joerg Stelzer, Helge Voss, Kai Voss, Eckhard von Toerne, Jan Therhaag
3// Updated by: Omar Zapata, Lorenzo Moneta, Sergei Gleyzer
4
5/**********************************************************************************
6 * Project: TMVA - a Root-integrated toolkit for multivariate data analysis *
7 * Package: TMVA *
8 * Class : Factory *
9 * Web : http://tmva.sourceforge.net *
10 * *
11 * Description: *
12 * This is the main MVA steering class: it creates (books) all MVA methods, *
13 * and guides them through the training, testing and evaluation phases. *
14 * *
15 * Authors (alphabetical): *
16 * Andreas Hoecker <Andreas.Hocker@cern.ch> - CERN, Switzerland *
17 * Joerg Stelzer <stelzer@cern.ch> - DESY, Germany *
18 * Peter Speckmayer <peter.speckmayer@cern.ch> - CERN, Switzerland *
19 * Jan Therhaag <Jan.Therhaag@cern.ch> - U of Bonn, Germany *
20 * Eckhard v. Toerne <evt@uni-bonn.de> - U of Bonn, Germany *
21 * Helge Voss <Helge.Voss@cern.ch> - MPI-K Heidelberg, Germany *
22 * Kai Voss <Kai.Voss@cern.ch> - U. of Victoria, Canada *
23 * Omar Zapata <Omar.Zapata@cern.ch> - UdeA/ITM Colombia *
24 * Lorenzo Moneta <Lorenzo.Moneta@cern.ch> - CERN, Switzerland *
25 * Sergei Gleyzer <Sergei.Gleyzer@cern.ch> - U of Florida & CERN *
26 * *
27 * Copyright (c) 2005-2011: *
28 * CERN, Switzerland *
29 * U. of Victoria, Canada *
30 * MPI-K Heidelberg, Germany *
31 * U. of Bonn, Germany *
32 * UdeA/ITM, Colombia *
33 * U. of Florida, USA *
34 * *
35 * Redistribution and use in source and binary forms, with or without *
36 * modification, are permitted according to the terms listed in LICENSE *
37 * (http://tmva.sourceforge.net/LICENSE) *
38 **********************************************************************************/
39
40#ifndef ROOT_TMVA_Factory
41#define ROOT_TMVA_Factory
42
43//////////////////////////////////////////////////////////////////////////
44// //
45// Factory //
46// //
47// This is the main MVA steering class: it creates all MVA methods, //
48// and guides them through the training, testing and evaluation //
49// phases //
50// //
51//////////////////////////////////////////////////////////////////////////
52
53#include <string>
54#include <vector>
55#include <map>
56#include "TCut.h"
57
58#include "TMVA/Configurable.h"
59#include "TMVA/Types.h"
60#include "TMVA/DataSet.h"
61
62class TCanvas;
63class TDirectory;
64class TFile;
65class TGraph;
66class TH1F;
67class TMultiGraph;
68class TTree;
69namespace TMVA {
70
71 class IMethod;
72 class MethodBase;
73 class DataInputHandler;
74 class DataSetInfo;
75 class DataSetManager;
76 class DataLoader;
77 class ROCCurve;
78 class VariableTransformBase;
79
80
81 class Factory : public Configurable {
82 friend class CrossValidation;
83 public:
84
85 typedef std::vector<IMethod*> MVector;
86 std::map<TString,MVector*> fMethodsMap;//all methods for every dataset with the same name
87
88 // no default constructor
89 Factory( TString theJobName, TFile* theTargetFile, TString theOption = "" );
90
91 // constructor to work without file
92 Factory( TString theJobName, TString theOption = "" );
93
94 // default destructor
95 virtual ~Factory();
96
97 // use TNamed::GetName and define correct name in constructor
98 //virtual const char* GetName() const { return "Factory"; }
99
100
101 MethodBase* BookMethod( DataLoader *loader, TString theMethodName, TString methodTitle, TString theOption = "" );
102 MethodBase* BookMethod( DataLoader *loader, Types::EMVA theMethod, TString methodTitle, TString theOption = "" );
104 TString /*methodTitle*/,
105 TString /*methodOption*/,
106 TMVA::Types::EMVA /*theComposite*/,
107 TString /*compositeOption = ""*/ ) { return 0; }
108
109 // optimize all booked methods (well, if desired by the method)
110 std::map<TString,Double_t> OptimizeAllMethods (TString fomType="ROCIntegral", TString fitType="FitGA");
111 void OptimizeAllMethodsForClassification(TString fomType="ROCIntegral", TString fitType="FitGA") { OptimizeAllMethods(fomType,fitType); }
112 void OptimizeAllMethodsForRegression (TString fomType="ROCIntegral", TString fitType="FitGA") { OptimizeAllMethods(fomType,fitType); }
113
114 // training for all booked methods
115 void TrainAllMethods ();
118
119 // testing
120 void TestAllMethods();
121
122 // performance evaluation
123 void EvaluateAllMethods( void );
124 void EvaluateAllVariables(DataLoader *loader, TString options = "" );
125
126 TH1F* EvaluateImportance( DataLoader *loader,VIType vitype, Types::EMVA theMethod, TString methodTitle, const char *theOption = "" );
127
128 // delete all methods and reset the method vector
129 void DeleteAllMethods( void );
130
131 // accessors
132 IMethod* GetMethod( const TString& datasetname, const TString& title ) const;
133 Bool_t HasMethod( const TString& datasetname, const TString& title ) const;
134
135 Bool_t Verbose( void ) const { return fVerbose; }
136 void SetVerbose( Bool_t v=kTRUE );
137
138 // make ROOT-independent C++ class for classifier response
139 // (classifier-specific implementation)
140 // If no classifier name is given, help messages for all booked
141 // classifiers are printed
142 virtual void MakeClass(const TString& datasetname , const TString& methodTitle = "" ) const;
143
144 // prints classifier-specific help messages, dedicated to
145 // help with the optimisation and configuration options tuning.
146 // If no classifier name is given, help messages for all booked
147 // classifiers are printed
148 void PrintHelpMessage(const TString& datasetname , const TString& methodTitle = "" ) const;
149
151
152 Bool_t IsSilentFile() const { return fSilentFile;}
154
155 Double_t GetROCIntegral(DataLoader *loader, TString theMethodName, UInt_t iClass = 0);
156 Double_t GetROCIntegral(TString datasetname, TString theMethodName, UInt_t iClass = 0);
157
158 // Methods to get a TGraph for an indicated method in dataset.
159 // Optional title and axis added with setTitles=kTRUE.
160 // Argument iClass used in multiclass settings, otherwise ignored.
161 TGraph* GetROCCurve(DataLoader *loader, TString theMethodName, Bool_t setTitles=kTRUE, UInt_t iClass=0);
162 TGraph* GetROCCurve(TString datasetname, TString theMethodName, Bool_t setTitles=kTRUE, UInt_t iClass=0);
163
164 // Methods to get a TMultiGraph for a given class and all methods in dataset.
167
168 // Draw all ROC curves of a given class for all methods in the dataset.
169 TCanvas* GetROCCurve(DataLoader *loader, UInt_t iClass=0);
170 TCanvas* GetROCCurve(TString datasetname, UInt_t iClass=0);
171
172 private:
173
174 // the beautiful greeting message
175 void Greetings();
176
177 //evaluate the simple case of removing one variable at a time
178 TH1F* EvaluateImportanceShort( DataLoader *loader,Types::EMVA theMethod, TString methodTitle, const char *theOption = "" );
179 //evaluate all variables combinations
180 TH1F* EvaluateImportanceAll( DataLoader *loader,Types::EMVA theMethod, TString methodTitle, const char *theOption = "" );
181 //evaluate randomly given a number of seeds
182 TH1F* EvaluateImportanceRandom( DataLoader *loader,UInt_t nseeds, Types::EMVA theMethod, TString methodTitle, const char *theOption = "" );
183
184 TH1F* GetImportance(const int nbits,std::vector<Double_t> importances,std::vector<TString> varNames);
185
186 // Helpers for public facing ROC methods
187 ROCCurve *GetROC(DataLoader *loader, TString theMethodName, UInt_t iClass = 0,
189 ROCCurve *GetROC(TString datasetname, TString theMethodName, UInt_t iClass = 0,
191
192 void WriteDataInformation(DataSetInfo& fDataSetInfo);
193
195
196 MethodBase* BookMethodWeightfile(DataLoader *dataloader, TMVA::Types::EMVA methodType, const TString &weightfile);
197
198 private:
199
200 // data members
201
202 TFile* fgTargetFile; //! ROOT output file
203
204
205 std::vector<TMVA::VariableTransformBase*> fDefaultTrfs; //! list of transformations on default DataSet
206
207 // cd to local directory
208 TString fOptions; //! option string given by construction (presently only "V")
209 TString fTransformations; //! list of transformations to test
210 Bool_t fVerbose; //! verbose mode
211 TString fVerboseLevel; //! verbosity level, controls granularity of logging
212 Bool_t fCorrelations; //! enable to calculate correlations
213 Bool_t fROC; //! enable to calculate ROC values
214 Bool_t fSilentFile; //! used in constructor without file
215
216 TString fJobName; //! jobname, used as extension in weight file names
217
218 Types::EAnalysisType fAnalysisType; //! the training type
219 Bool_t fModelPersistence;//! option to save the trained model in xml file or using serialization
220
221
222 protected:
223
224 ClassDef(Factory,0); // The factory creates all MVA methods, and performs their training and testing
225 };
226
227} // namespace TMVA
228
229#endif
unsigned int UInt_t
Definition: RtypesCore.h:42
bool Bool_t
Definition: RtypesCore.h:59
double Double_t
Definition: RtypesCore.h:55
const Bool_t kTRUE
Definition: RtypesCore.h:87
#define ClassDef(name, id)
Definition: Rtypes.h:326
int type
Definition: TGX11.cxx:120
The Canvas class.
Definition: TCanvas.h:31
Describe directory structure in memory.
Definition: TDirectory.h:34
A ROOT file is a suite of consecutive data records (TKey instances) with a well defined format.
Definition: TFile.h:48
A Graph is a graphics object made of two arrays X and Y with npoints each.
Definition: TGraph.h:41
1-D histogram with a float per channel (see TH1 documentation)
Definition: TH1.h:571
Class to perform cross validation, splitting the dataloader into folds.
Class that contains all the data information.
Definition: DataSetInfo.h:60
This is the main MVA steering class.
Definition: Factory.h:81
void PrintHelpMessage(const TString &datasetname, const TString &methodTitle="") const
Print predefined help message of classifier.
Definition: Factory.cxx:1308
Double_t GetROCIntegral(DataLoader *loader, TString theMethodName, UInt_t iClass=0)
Calculate the integral of the ROC curve, also known as the area under curve (AUC),...
Definition: Factory.cxx:843
Bool_t fSilentFile
used in constructor without file
Definition: Factory.h:214
Bool_t fCorrelations
enable to calculate correlations
Definition: Factory.h:212
Bool_t IsModelPersistence() const
Definition: Factory.h:153
TString fOptions
option string given by construction (presently only "V")
Definition: Factory.h:208
std::vector< IMethod * > MVector
Definition: Factory.h:85
void TrainAllMethods()
Iterates through all booked methods and calls training.
Definition: Factory.cxx:1094
Bool_t Verbose(void) const
Definition: Factory.h:135
void WriteDataInformation(DataSetInfo &fDataSetInfo)
Definition: Factory.cxx:597
MethodBase * BookMethod(DataLoader *loader, TString theMethodName, TString methodTitle, TString theOption="")
Book a classifier or regression method.
Definition: Factory.cxx:346
Factory(TString theJobName, TFile *theTargetFile, TString theOption="")
Standard constructor.
Definition: Factory.cxx:119
void TestAllMethods()
Evaluates all booked methods on the testing data and adds the output to the Results in the corresponi...
Definition: Factory.cxx:1245
void TrainAllMethodsForClassification(void)
Definition: Factory.h:116
Bool_t fVerbose
verbose mode
Definition: Factory.h:210
void EvaluateAllMethods(void)
Iterates over all MVAs that have been booked, and calls their evaluation methods.
Definition: Factory.cxx:1350
TH1F * EvaluateImportanceRandom(DataLoader *loader, UInt_t nseeds, Types::EMVA theMethod, TString methodTitle, const char *theOption="")
Definition: Factory.cxx:2390
TH1F * GetImportance(const int nbits, std::vector< Double_t > importances, std::vector< TString > varNames)
Definition: Factory.cxx:2502
Bool_t fROC
enable to calculate ROC values
Definition: Factory.h:213
void EvaluateAllVariables(DataLoader *loader, TString options="")
Iterates over all MVA input variables and evaluates them.
Definition: Factory.cxx:1335
TDirectory * RootBaseDir()
Definition: Factory.h:150
TString fVerboseLevel
verbosity level, controls granularity of logging
Definition: Factory.h:211
TH1F * EvaluateImportance(DataLoader *loader, VIType vitype, Types::EMVA theMethod, TString methodTitle, const char *theOption="")
Evaluate Variable Importance.
Definition: Factory.cxx:2166
void OptimizeAllMethodsForRegression(TString fomType="ROCIntegral", TString fitType="FitGA")
Definition: Factory.h:112
std::map< TString, MVector * > fMethodsMap
Definition: Factory.h:86
void SetInputTreesFromEventAssignTrees()
virtual ~Factory()
Destructor.
Definition: Factory.cxx:300
TGraph * GetROCCurve(DataLoader *loader, TString theMethodName, Bool_t setTitles=kTRUE, UInt_t iClass=0)
Argument iClass specifies the class to generate the ROC curve in a multiclass setting.
Definition: Factory.cxx:904
virtual void MakeClass(const TString &datasetname, const TString &methodTitle="") const
Definition: Factory.cxx:1280
MethodBase * BookMethodWeightfile(DataLoader *dataloader, TMVA::Types::EMVA methodType, const TString &weightfile)
Adds an already constructed method to be managed by this factory.
Definition: Factory.cxx:498
Bool_t fModelPersistence
option to save the trained model in xml file or using serialization
Definition: Factory.h:219
std::map< TString, Double_t > OptimizeAllMethods(TString fomType="ROCIntegral", TString fitType="FitGA")
Iterates through all booked methods and sees if they use parameter tuning and if so.
Definition: Factory.cxx:695
void OptimizeAllMethodsForClassification(TString fomType="ROCIntegral", TString fitType="FitGA")
Definition: Factory.h:111
ROCCurve * GetROC(DataLoader *loader, TString theMethodName, UInt_t iClass=0, Types::ETreeType type=Types::kTesting)
Private method to generate a ROCCurve instance for a given method.
Definition: Factory.cxx:744
Bool_t IsSilentFile() const
Definition: Factory.h:152
TH1F * EvaluateImportanceShort(DataLoader *loader, Types::EMVA theMethod, TString methodTitle, const char *theOption="")
Definition: Factory.cxx:2285
Types::EAnalysisType fAnalysisType
the training type
Definition: Factory.h:218
TString fJobName
jobname, used as extension in weight file names
Definition: Factory.h:216
Bool_t HasMethod(const TString &datasetname, const TString &title) const
Checks whether a given method name is defined for a given dataset.
Definition: Factory.cxx:580
MethodBase * BookMethod(DataLoader *, TMVA::Types::EMVA, TString, TString, TMVA::Types::EMVA, TString)
Definition: Factory.h:103
void TrainAllMethodsForRegression(void)
Definition: Factory.h:117
TH1F * EvaluateImportanceAll(DataLoader *loader, Types::EMVA theMethod, TString methodTitle, const char *theOption="")
Definition: Factory.cxx:2189
void SetVerbose(Bool_t v=kTRUE)
Definition: Factory.cxx:338
TFile * fgTargetFile
Definition: Factory.h:202
std::vector< TMVA::VariableTransformBase * > fDefaultTrfs
list of transformations on default DataSet
Definition: Factory.h:205
IMethod * GetMethod(const TString &datasetname, const TString &title) const
Returns pointer to MVA that corresponds to given method title.
Definition: Factory.cxx:562
void DeleteAllMethods(void)
Delete methods.
Definition: Factory.cxx:318
TString fTransformations
list of transformations to test
Definition: Factory.h:209
void Greetings()
Print welcome message.
Definition: Factory.cxx:290
TMultiGraph * GetROCCurveAsMultiGraph(DataLoader *loader, UInt_t iClass)
Generate a collection of graphs, for all methods for a given class.
Definition: Factory.cxx:973
Interface for all concrete MVA method implementations.
Definition: IMethod.h:54
Virtual base class for all MVA methods.
Definition: MethodBase.h:111
EAnalysisType
Definition: Types.h:127
@ kTesting
Definition: Types.h:145
A TMultiGraph is a collection of TGraph (or derived) objects.
Definition: TMultiGraph.h:35
Basic string class.
Definition: TString.h:131
A TTree represents a columnar dataset.
Definition: TTree.h:72
create variable transformations