Logo ROOT  
Reference Guide
Envelope.cxx
Go to the documentation of this file.
1// @(#)root/tmva $Id$
2// Author: Omar Zapata, Kim Albertsson
3
4/*************************************************************************
5 * Copyright (C) 2018, Rene Brun and Fons Rademakers. *
6 * All rights reserved. *
7 * *
8 * For the licensing terms see $ROOTSYS/LICENSE. *
9 * For the list of contributors see $ROOTSYS/README/CREDITS. *
10 *************************************************************************/
11
12#include <TMVA/Envelope.h>
13
14#include <TMVA/Configurable.h>
15#include <TMVA/DataLoader.h>
16#include <TMVA/MethodBase.h>
17#include <TMVA/OptionMap.h>
19#include <TMVA/Types.h>
20
21#include <TMVA/VariableInfo.h>
23
24#include <TAxis.h>
25#include <TCanvas.h>
26#include <TFile.h>
27#include <TGraph.h>
28#include <TSystem.h>
29#include <TH2.h>
30
31#include <iostream>
32
33using namespace TMVA;
34
35//_______________________________________________________________________
36/**
37Constructor for the initialization of Envelopes,
38differents Envelopes may needs differents constructors then
39this is a generic one protected.
40\param name the name algorithm.
41\param dataloader TMVA::DataLoader object with the data.
42\param file optional file to save the results.
43\param options extra options for the algorithm.
44*/
45Envelope::Envelope(const TString &name, DataLoader *dalaloader, TFile *file, const TString options)
46 : Configurable(options), fDataLoader(dalaloader), fFile(file), fModelPersistence(kTRUE), fVerbose(kFALSE),
47 fTransformations("I"), fSilentFile(kFALSE), fJobs(1)
48{
49 SetName(name.Data());
50 // render silent
51 if (gTools().CheckForSilentOption(GetOptions()))
52 Log().InhibitOutput(); // make sure is silent if wanted to
53
55 DeclareOptionRef(fVerbose, "V", "Verbose flag");
56
57 DeclareOptionRef(fModelPersistence, "ModelPersistence",
58 "Option to save the trained model in xml file or using serialization");
59 DeclareOptionRef(fTransformations, "Transformations", "List of transformations to test; formatting example: "
60 "\"Transformations=I;D;P;U;G,D\", for identity, "
61 "decorrelation, PCA, Uniform and Gaussianisation followed by "
62 "decorrelation transformations");
63 DeclareOptionRef(fJobs, "Jobs", "Option to run hign level algorithms in parallel with multi-thread");
64}
65
66//_______________________________________________________________________
68{
69}
70
71//_______________________________________________________________________
72/**
73Method to see if a file is available to save results
74\return Boolean with the status.
75*/
77
78//_______________________________________________________________________
79/**
80Method to get the pointer to TFile object.
81\return pointer to TFile object.
82*/
83TFile* Envelope::GetFile(){return fFile.get();}
84
85//_______________________________________________________________________
86/**
87Method to set the pointer to TFile object,
88with a writable file.
89\param file pointer to TFile object.
90*/
91void Envelope::SetFile(TFile *file){fFile=std::shared_ptr<TFile>(file);}
92
93//_______________________________________________________________________
94/**
95Method to see if the algorithm should print extra information.
96\return Boolean with the status.
97*/
99
100//_______________________________________________________________________
101/**
102Method to enable the printing of extra information in the algorithms.
103\param status Boolean with the status.
104*/
106
107//_______________________________________________________________________
108/**
109Method get the Booked methods in a option map object.
110\return vector of TMVA::OptionMap objects with the information of the Booked method
111*/
112std::vector<OptionMap> &Envelope::GetMethods()
113{
114 return fMethods;
115}
116
117//_______________________________________________________________________
118/**
119Method to get the pointer to TMVA::DataLoader object.
120\return pointer to TMVA::DataLoader object.
121*/
122
124
125//_______________________________________________________________________
126/**
127Method to set the pointer to TMVA::DataLoader object.
128\param dalaloader pointer to TMVA::DataLoader object.
129*/
130
132{
133 fDataLoader = std::shared_ptr<DataLoader>(dataloader);
134}
135
136//_______________________________________________________________________
137/**
138Method to see if the algorithm model is saved in xml or serialized files.
139\return Boolean with the status.
140*/
141Bool_t TMVA::Envelope::IsModelPersistence(){return fModelPersistence; }
142
143//_______________________________________________________________________
144/**
145Method enable model persistence, then algorithms model is saved in xml or serialized files.
146\param status Boolean with the status.
147*/
148void TMVA::Envelope::SetModelPersistence(Bool_t status){fModelPersistence=status;}
149
150//_______________________________________________________________________
151/**
152Method to book the machine learning method to perform the algorithm.
153\param method enum TMVA::Types::EMVA with the type of the mva method
154\param methodtitle String with the method title.
155\param options String with the options for the method.
156*/
157void TMVA::Envelope::BookMethod(Types::EMVA method, TString methodTitle, TString options){
158 BookMethod(Types::Instance().GetMethodName(method), methodTitle, options);
159}
160
161//_______________________________________________________________________
162/**
163Method to book the machine learning method to perform the algorithm.
164\param methodname String with the name of the mva method
165\param methodtitle String with the method title.
166\param options String with the options for the method.
167*/
168void TMVA::Envelope::BookMethod(TString methodName, TString methodTitle, TString options){
169 for (auto &meth : fMethods) {
170 if (meth.GetValue<TString>("MethodName") == methodName && meth.GetValue<TString>("MethodTitle") == methodTitle) {
171 Log() << kFATAL << "Booking failed since method with title <" << methodTitle << "> already exists "
172 << "in with DataSet Name <" << fDataLoader->GetName() << "> " << Endl;
173 }
174 }
175 OptionMap fMethod;
176 fMethod["MethodName"] = methodName;
177 fMethod["MethodTitle"] = methodTitle;
178 fMethod["MethodOptions"] = options;
179
180 fMethods.push_back(fMethod);
181}
182
183//_______________________________________________________________________
184/**
185Method to parse the internal option string.
186*/
188{
189
190 Bool_t silent = kFALSE;
191#ifdef WIN32
192 // under Windows, switch progress bar and color off by default, as the typical windows shell doesn't handle these
193 // (would need different sequences..)
194 Bool_t color = kFALSE;
195 Bool_t drawProgressBar = kFALSE;
196#else
197 Bool_t color = !gROOT->IsBatch();
198 Bool_t drawProgressBar = kTRUE;
199#endif
200 DeclareOptionRef(color, "Color", "Flag for coloured screen output (default: True, if in batch mode: False)");
201 DeclareOptionRef(drawProgressBar, "DrawProgressBar",
202 "Draw progress bar to display training, testing and evaluation schedule (default: True)");
203 DeclareOptionRef(silent, "Silent", "Batch mode: boolean silent flag inhibiting any output from TMVA after the "
204 "creation of the factory class object (default: False)");
205
207 CheckForUnusedOptions();
208
209 if (IsVerbose())
210 Log().SetMinType(kVERBOSE);
211
212 // global settings
213 gConfig().SetUseColor(color);
214 gConfig().SetSilent(silent);
215 gConfig().SetDrawProgressBar(drawProgressBar);
216}
217
218//_______________________________________________________________________
219/**
220 * function to check methods booked
221 * \param methodname Method's name.
222 * \param methodtitle title associated to the method.
223 * \return true if the method was booked.
224 */
226{
227 for (auto &meth : fMethods) {
228 if (meth.GetValue<TString>("MethodName") == methodname && meth.GetValue<TString>("MethodTitle") == methodtitle)
229 return kTRUE;
230 }
231 return kFALSE;
232}
233
234//_______________________________________________________________________
235/**
236 * method to save Train/Test information into the output file.
237 * \param fDataSetInfo TMVA::DataSetInfo object reference
238 * \param fAnalysisType Types::kMulticlass and Types::kRegression
239 */
241{
242 RootBaseDir()->cd();
243
244 if (!RootBaseDir()->GetDirectory(fDataSetInfo.GetName()))
245 RootBaseDir()->mkdir(fDataSetInfo.GetName());
246 else
247 return; // loader is now in the output file, we dont need to save again
248
249 RootBaseDir()->cd(fDataSetInfo.GetName());
250 fDataSetInfo.GetDataSet(); // builds dataset (including calculation of correlation matrix)
251
252 // correlation matrix of the default DS
253 const TMatrixD *m(0);
254 const TH2 *h(0);
255
256 if (fAnalysisType == Types::kMulticlass) {
257 for (UInt_t cls = 0; cls < fDataSetInfo.GetNClasses(); cls++) {
258 m = fDataSetInfo.CorrelationMatrix(fDataSetInfo.GetClassInfo(cls)->GetName());
259 h = fDataSetInfo.CreateCorrelationMatrixHist(
260 m, TString("CorrelationMatrix") + fDataSetInfo.GetClassInfo(cls)->GetName(),
261 TString("Correlation Matrix (") + fDataSetInfo.GetClassInfo(cls)->GetName() + TString(")"));
262 if (h != 0) {
263 h->Write();
264 delete h;
265 }
266 }
267 } else {
268 m = fDataSetInfo.CorrelationMatrix("Signal");
269 h = fDataSetInfo.CreateCorrelationMatrixHist(m, "CorrelationMatrixS", "Correlation Matrix (signal)");
270 if (h != 0) {
271 h->Write();
272 delete h;
273 }
274
275 m = fDataSetInfo.CorrelationMatrix("Background");
276 h = fDataSetInfo.CreateCorrelationMatrixHist(m, "CorrelationMatrixB", "Correlation Matrix (background)");
277 if (h != 0) {
278 h->Write();
279 delete h;
280 }
281
282 m = fDataSetInfo.CorrelationMatrix("Regression");
283 h = fDataSetInfo.CreateCorrelationMatrixHist(m, "CorrelationMatrix", "Correlation Matrix");
284 if (h != 0) {
285 h->Write();
286 delete h;
287 }
288 }
289
290 // some default transformations to evaluate
291 // NOTE: all transformations are destroyed after this test
292 TString processTrfs = "I"; //"I;N;D;P;U;G,D;"
293
294 // plus some user defined transformations
295 processTrfs = fTransformations;
296
297 // remove any trace of identity transform - if given (avoid to apply it twice)
298 std::vector<TMVA::TransformationHandler *> trfs;
299 TransformationHandler *identityTrHandler = 0;
300
301 std::vector<TString> trfsDef = gTools().SplitString(processTrfs, ';');
302 std::vector<TString>::iterator trfsDefIt = trfsDef.begin();
303 for (; trfsDefIt != trfsDef.end(); ++trfsDefIt) {
304 trfs.push_back(new TMVA::TransformationHandler(fDataSetInfo, "Envelope"));
305 TString trfS = (*trfsDefIt);
306
307 // Log() << kINFO << Endl;
308 Log() << kDEBUG << "current transformation string: '" << trfS.Data() << "'" << Endl;
309 TMVA::CreateVariableTransforms(trfS, fDataSetInfo, *(trfs.back()), Log());
310
311 if (trfS.BeginsWith('I'))
312 identityTrHandler = trfs.back();
313 }
314
315 const std::vector<Event *> &inputEvents = fDataSetInfo.GetDataSet()->GetEventCollection();
316
317 // apply all transformations
318 std::vector<TMVA::TransformationHandler *>::iterator trfIt = trfs.begin();
319
320 for (; trfIt != trfs.end(); ++trfIt) {
321 // setting a Root dir causes the variables distributions to be saved to the root file
322 (*trfIt)->SetRootDir(RootBaseDir()->GetDirectory(fDataSetInfo.GetName())); // every dataloader have its own dir
323 (*trfIt)->CalcTransformations(inputEvents);
324 }
325 if (identityTrHandler)
326 identityTrHandler->PrintVariableRanking();
327
328 // clean up
329 for (trfIt = trfs.begin(); trfIt != trfs.end(); ++trfIt)
330 delete *trfIt;
331}
#define h(i)
Definition: RSha256.hxx:106
unsigned int UInt_t
Definition: RtypesCore.h:42
const Bool_t kFALSE
Definition: RtypesCore.h:88
bool Bool_t
Definition: RtypesCore.h:59
const Bool_t kTRUE
Definition: RtypesCore.h:87
char name[80]
Definition: TGX11.cxx:109
#define gROOT
Definition: TROOT.h:415
A ROOT file is a suite of consecutive data records (TKey instances) with a well defined format.
Definition: TFile.h:48
Service class for 2-Dim histogram classes.
Definition: TH2.h:30
void SetDrawProgressBar(Bool_t d)
Definition: Config.h:71
void SetUseColor(Bool_t uc)
Definition: Config.h:62
void SetSilent(Bool_t s)
Definition: Config.h:65
OptionBase * DeclareOptionRef(T &ref, const TString &name, const TString &desc="")
virtual void ParseOptions()
options parser
const TString & GetOptions() const
Definition: Configurable.h:84
MsgLogger & Log() const
Definition: Configurable.h:122
Class that contains all the data information.
Definition: DataSetInfo.h:60
virtual const char * GetName() const
Returns name of object.
Definition: DataSetInfo.h:69
const TMatrixD * CorrelationMatrix(const TString &className) const
UInt_t GetNClasses() const
Definition: DataSetInfo.h:153
DataSet * GetDataSet() const
returns data set
TH2 * CreateCorrelationMatrixHist(const TMatrixD *m, const TString &hName, const TString &hTitle) const
ClassInfo * GetClassInfo(Int_t clNum) const
const std::vector< Event * > & GetEventCollection(Types::ETreeType type=Types::kMaxTreeType) const
Definition: DataSet.h:227
Bool_t HasMethod(TString methodname, TString methodtitle)
function to check methods booked
Definition: Envelope.cxx:225
~Envelope()
Default destructor.
Definition: Envelope.cxx:67
Bool_t IsModelPersistence()
Method to see if the algorithm model is saved in xml or serialized files.
Definition: Envelope.cxx:141
std::shared_ptr< TFile > fFile
data
Definition: Envelope.h:51
DataLoader * GetDataLoader()
Method to get the pointer to TMVA::DataLoader object.
Definition: Envelope.cxx:123
Bool_t fModelPersistence
file to save the results
Definition: Envelope.h:52
Bool_t IsSilentFile()
Method to see if a file is available to save results.
Definition: Envelope.cxx:76
void SetDataLoader(DataLoader *dalaloader)
Method to set the pointer to TMVA::DataLoader object.
Definition: Envelope.cxx:131
virtual void BookMethod(TString methodname, TString methodtitle, TString options="")
Method to book the machine learning method to perform the algorithm.
Definition: Envelope.cxx:168
std::vector< OptionMap > fMethods
Definition: Envelope.h:49
void SetVerbose(Bool_t status)
Method enable print extra information in the algorithms.
Definition: Envelope.cxx:105
void SetFile(TFile *file)
Method to set the pointer to TFile object, with a writable file.
Definition: Envelope.cxx:91
Bool_t IsVerbose()
Method to see if the algorithm should print extra information.
Definition: Envelope.cxx:98
Bool_t fVerbose
flag to save the trained model
Definition: Envelope.h:53
void SetModelPersistence(Bool_t status=kTRUE)
Method enable model persistence, then algorithms model is saved in xml or serialized files.
Definition: Envelope.cxx:148
std::shared_ptr< DataLoader > fDataLoader
Booked method information.
Definition: Envelope.h:50
virtual void ParseOptions()
Method to parse the internal option string.
Definition: Envelope.cxx:187
TFile * GetFile()
Method to get the pointer to TFile object.
Definition: Envelope.cxx:83
std::vector< OptionMap > & GetMethods()
Method get the Booked methods in a option map object.
Definition: Envelope.cxx:112
TString fTransformations
flag for extra information
Definition: Envelope.h:54
UInt_t fJobs
procpool object
Definition: Envelope.h:59
void WriteDataInformation(TMVA::DataSetInfo &fDataSetInfo, TMVA::Types::EAnalysisType fAnalysisType)
method to save Train/Test information into the output file.
Definition: Envelope.cxx:240
static void InhibitOutput()
Definition: MsgLogger.cxx:74
class to storage options for the differents methods
Definition: OptionMap.h:36
std::vector< TString > SplitString(const TString &theOpt, const char separator) const
splits the option string at 'separator' and fills the list 'splitV' with the primitive strings
Definition: Tools.cxx:1211
Class that contains all the data information.
void PrintVariableRanking() const
prints ranking of input variables
static Types & Instance()
the the single instance of "Types" if existing already, or create it (Singleton)
Definition: Types.cxx:70
EAnalysisType
Definition: Types.h:127
@ kMulticlass
Definition: Types.h:130
virtual void SetName(const char *name)
Set the name of the TNamed.
Definition: TNamed.cxx:140
virtual const char * GetName() const
Returns name of object.
Definition: TNamed.h:47
Basic string class.
Definition: TString.h:131
const char * Data() const
Definition: TString.h:364
Bool_t BeginsWith(const char *s, ECaseCompare cmp=kExact) const
Definition: TString.h:610
std::string GetMethodName(TCppMethod_t)
Definition: Cppyy.cxx:757
create variable transformations
Config & gConfig()
Tools & gTools()
void CreateVariableTransforms(const TString &trafoDefinition, TMVA::DataSetInfo &dataInfo, TMVA::TransformationHandler &transformationHandler, TMVA::MsgLogger &log)
MsgLogger & Endl(MsgLogger &ml)
Definition: MsgLogger.h:158
Double_t Log(Double_t x)
Definition: TMath.h:750
Definition: file.py:1
auto * m
Definition: textangle.C:8