Logo ROOT   6.16/01
Reference Guide
MethodCompositeBase.cxx
Go to the documentation of this file.
1// @(#)root/tmva $Id$
2// Author: Andreas Hoecker, Joerg Stelzer, Helge Voss, Kai Voss,Or Cohen
3
4/*****************************************************************************
5 * Project: TMVA - a Root-integrated toolkit for multivariate data analysis *
6 * Package: TMVA *
7 * Class : MethodCompositeBase *
8 * Web : http://tmva.sourceforge.net *
9 * *
10 * Description: *
11 * Virtual base class for combining several MVA methods *
12 * *
13 * Authors (alphabetical): *
14 * Andreas Hoecker <Andreas.Hocker@cern.ch> - CERN, Switzerland *
15 * Joerg Stelzer <Joerg.Stelzer@cern.ch> - MSU, USA *
16 * Helge Voss <Helge.Voss@cern.ch> - MPI-K Heidelberg, Germany *
17 * Kai Voss <Kai.Voss@cern.ch> - U. of Victoria, Canada *
18 * Or Cohen <orcohenor@gmail.com> - Weizmann Inst., Israel *
19 * *
20 * Copyright (c) 2005: *
21 * CERN, Switzerland *
22 * U. of Victoria, Canada *
23 * MPI-K Heidelberg, Germany *
24 * LAPP, Annecy, France *
25 * *
26 * Redistribution and use in source and binary forms, with or without *
27 * modification, are permitted according to the terms listed in LICENSE *
28 * (http://tmva.sourceforge.net/LICENSE) *
29 *****************************************************************************/
30
31/*! \class TMVA::MethodCompositeBase
32\ingroup TMVA
33
34Virtual base class for combining several TMVA methods.
35
36This class is a virtual class meant to combine more than one classifier
37together. The training of the classifiers is done by classes that are
38derived from this one, while the saving and loading of weight files
39and the evaluation are done here.
40*/
41
43
45#include "TMVA/DataSetInfo.h"
46#include "TMVA/Factory.h"
47#include "TMVA/IMethod.h"
48#include "TMVA/MethodBase.h"
49#include "TMVA/MethodBoost.h"
50#include "TMVA/MsgLogger.h"
51#include "TMVA/Tools.h"
52#include "TMVA/Types.h"
53#include "TMVA/Config.h"
54
55#include "Riostream.h"
56#include "TRandom3.h"
57#include "TMath.h"
58#include "TObjString.h"
59
60#include <algorithm>
61#include <iomanip>
62#include <vector>
63
64
65using std::vector;
66
68
69////////////////////////////////////////////////////////////////////////////////
70
72 Types::EMVA methodType,
73 const TString& methodTitle,
74 DataSetInfo& theData,
75 const TString& theOption )
76: TMVA::MethodBase( jobName, methodType, methodTitle, theData, theOption),
77 fCurrentMethodIdx(0), fCurrentMethod(0)
78{}
79
80////////////////////////////////////////////////////////////////////////////////
81
83 DataSetInfo& dsi,
84 const TString& weightFile)
85 : TMVA::MethodBase( methodType, dsi, weightFile),
86 fCurrentMethodIdx(0), fCurrentMethod(0)
87{}
88
89////////////////////////////////////////////////////////////////////////////////
90/// returns pointer to MVA that corresponds to given method title
91
93{
 // Linear search through fMethods: the first entry whose GetMethodName()
 // equals methodTitle is returned; 0 is returned when nothing matches
 // (including when fMethods is empty).
94 std::vector<IMethod*>::const_iterator itrMethod = fMethods.begin();
95 std::vector<IMethod*>::const_iterator itrMethodEnd = fMethods.end();
96
97 for (; itrMethod != itrMethodEnd; ++itrMethod) {
98 MethodBase* mva = dynamic_cast<MethodBase*>(*itrMethod);
 // NOTE(review): mva is dereferenced without a null check. If an entry of
 // fMethods is null or is not a MethodBase, the dynamic_cast yields 0 and
 // the next line dereferences it -- confirm every stored IMethod is a MethodBase.
99 if ( (mva->GetMethodName())==methodTitle ) return mva;
100 }
101 return 0;
102}
103
104////////////////////////////////////////////////////////////////////////////////
105/// returns pointer to MVA that corresponds to given method index
106
108{
 // Positional lookup: returns the element of fMethods at offset `index`,
 // or 0 when the resulting iterator is not before end().
 // NOTE(review): the iterator is formed *before* the range check. For an
 // index greater than fMethods.size() (or a negative one, if the parameter
 // is signed -- its declaration is not visible in this listing) begin()+index
 // is already outside the valid iterator range, which the C++ standard leaves
 // undefined; comparing index against size() first would avoid that.
109 std::vector<IMethod*>::const_iterator itrMethod = fMethods.begin()+index;
110 if (itrMethod<fMethods.end()) return *itrMethod;
111 else return 0;
112}
113
114
115////////////////////////////////////////////////////////////////////////////////
116
118{
 // Serialises the composite into XML below `parent`:
 //   <Weights NMethods="...">
 //     <Method Index=... Weight=... MethodSigCut=... ...> sub-method weights </Method>
 //     ...
 //   </Weights>
 // Each sub-method writes its own weights beneath its <Method> node.
119 void* wght = gTools().AddChild(parent, "Weights");
120 gTools().AddAttr( wght, "NMethods", fMethods.size() );
121 for (UInt_t i=0; i< fMethods.size(); i++)
122 {
123 void* methxml = gTools().AddChild( wght, "Method" );
 // NOTE(review): the cast result is used below without a null check, and
 // fMethodWeight[i] is indexed on the assumption that fMethodWeight has at
 // least as many entries as fMethods -- neither is verified here.
124 MethodBase* method = dynamic_cast<MethodBase*>(fMethods[i]);
125 gTools().AddAttr(methxml,"Index", i );
126 gTools().AddAttr(methxml,"Weight", fMethodWeight[i]);
127 gTools().AddAttr(methxml,"MethodSigCut", method->GetSignalReferenceCut());
128 gTools().AddAttr(methxml,"MethodSigCutOrientation", method->GetSignalReferenceCutOrientation());
129 gTools().AddAttr(methxml,"MethodTypeName", method->GetMethodTypeName());
130 gTools().AddAttr(methxml,"MethodName", method->GetMethodName() );
131 gTools().AddAttr(methxml,"JobName", method->GetJobName());
132 gTools().AddAttr(methxml,"Options", method->GetOptions());
 // Record as "true"/"false" whether the sub-method has a non-null
 // fTransformationPointer.
133 if (method->fTransformationPointer)
134 gTools().AddAttr(methxml,"UseMainMethodTransformation", TString("true"));
135 else
136 gTools().AddAttr(methxml,"UseMainMethodTransformation", TString("false"));
 // Delegate the method-specific weights to the sub-method itself.
137 method->AddWeightsXMLTo(methxml);
138 }
139}
140
141////////////////////////////////////////////////////////////////////////////////
142/// delete methods
143
145{
 // The composite deletes every IMethod held in fMethods, logging each
 // deletion at kVERBOSE level, and then empties the vector.
 // NOTE(review): GetName() is called on each element before the delete, so a
 // null entry in fMethods would be dereferenced here.
146 std::vector<IMethod*>::iterator itrMethod = fMethods.begin();
147 for (; itrMethod != fMethods.end(); ++itrMethod) {
148 Log() << kVERBOSE << "Delete method: " << (*itrMethod)->GetName() << Endl;
149 delete (*itrMethod);
150 }
151 fMethods.clear();
152}
153
154////////////////////////////////////////////////////////////////////////////////
155/// XML streamer
156
158{
 // Rebuilds the list of sub-methods from the XML node `wghtnode`, the
 // counterpart of AddWeightsXMLTo(): existing sub-methods are deleted, then
 // for each of the "NMethods" child nodes a method is created through
 // ClassifierFactory, configured, and asked to read its own weights.
 // nMethods has no initial value; it is filled by the ReadAttr call below.
159 UInt_t nMethods;
160 TString methodName, methodTypeName, jobName, optionString;
161
 // Drop any previously held sub-methods and their weights.
162 for (UInt_t i=0;i<fMethods.size();i++) delete fMethods[i];
163 fMethods.clear();
164 fMethodWeight.clear();
165 gTools().ReadAttr( wghtnode, "NMethods", nMethods );
 // `ch` walks the children of wghtnode, one per sub-method; it is advanced
 // with GetNextChild at the end of every iteration.
 // NOTE(review): the loop trusts NMethods; `ch` is not checked for null if
 // the file holds fewer child nodes than announced.
166 void* ch = gTools().GetChild(wghtnode);
167 for (UInt_t i=0; i< nMethods; i++) {
168 Double_t methodWeight, methodSigCut, methodSigCutOrientation;
 // The "Index" attribute written by AddWeightsXMLTo is not read back;
 // the position in the child list is used instead.
169 gTools().ReadAttr( ch, "Weight", methodWeight );
170 gTools().ReadAttr( ch, "MethodSigCut", methodSigCut);
171 gTools().ReadAttr( ch, "MethodSigCutOrientation", methodSigCutOrientation);
172 gTools().ReadAttr( ch, "MethodTypeName", methodTypeName );
173 gTools().ReadAttr( ch, "MethodName", methodName );
174 gTools().ReadAttr( ch, "JobName", jobName );
175 gTools().ReadAttr( ch, "Options", optionString );
176
 // The attribute is read and lower-cased, but the result is currently
 // unused: the code that would act on it is commented out.
177 // Bool_t rerouteTransformation = kFALSE;
178 if (gTools().HasAttr( ch, "UseMainMethodTransformation")) {
179 TString rerouteString("");
180 gTools().ReadAttr( ch, "UseMainMethodTransformation", rerouteString );
181 rerouteString.ToLower();
182 // if (rerouteString=="true")
183 // rerouteTransformation=kTRUE;
184 }
185
 // Option-string normalisation. Note that ReplaceAll removes *every* "~",
 // not only a trailing one; afterwards each "Boost_" option is prefixed
 // with "~" and a "!~" sequence is reordered to "~!" so the marker
 // precedes the negation.
186 //remove trailing "~" to signal that options have to be reused
187 optionString.ReplaceAll("~","");
188 //ignore meta-options for method Boost
189 optionString.ReplaceAll("Boost_","~Boost_");
190 optionString.ReplaceAll("!~","~!");
191
 // NOTE(review): unlike ReadWeightsFromStream, this cast is performed
 // without checking GetMethodType() == Types::kBoost; it presumes the
 // object really is a MethodBoost -- confirm for other derived classes.
192 if (i==0){
193 // the cast on MethodBoost is ugly, but a similar line is also in ReadWeightsFromFile --> needs to be fixed later
194 ((TMVA::MethodBoost*)this)->BookMethod( Types::Instance().GetMethodType( methodTypeName), methodName, optionString );
195 }
 // Instantiate the sub-method and keep it together with its weight.
196 fMethods.push_back(
197 ClassifierFactory::Instance().Create(methodTypeName.Data(), jobName, methodName, DataInfo(), optionString));
198
199 fMethodWeight.push_back(methodWeight);
200 MethodBase* meth = dynamic_cast<MethodBase*>(fMethods.back());
201
 // NOTE(review): meth is dereferenced below even when it is 0; this is
 // only safe if logging at kFATAL does not return (presumably it aborts
 // or throws -- verify against MsgLogger).
202 if(meth==0)
203 Log() << kFATAL << "Could not read method from XML" << Endl;
204
 // First child of the <Method> node: handed to the sub-method's own reader.
205 void* methXML = gTools().GetChild(ch);
206
 // Weight directory is "<dataset name>/<configured weight file dir>".
207 TString _fFileDir= meth->DataInfo().GetName();
208 _fFileDir+="/"+gConfig().GetIONames().fWeightFileDir;
209 meth->SetWeightFileDir(_fFileDir);
 // Propagate this composite's persistence / silent-file settings.
210 meth->SetModelPersistence(IsModelPersistence());
211 meth->SetSilentFile(IsSilentFile());
 // Set up and validate the sub-method before it reads its weights.
212 meth->SetupMethod();
213 meth->SetMsgType(kWARNING);
214 meth->ParseOptions();
215 meth->ProcessSetup();
216 meth->CheckSetup();
217 meth->ReadWeightsFromXML(methXML);
218 meth->SetSignalReferenceCut(methodSigCut);
219 meth->SetSignalReferenceCutOrientation(methodSigCutOrientation);
220
 // The sub-method is always pointed at this composite's transformation
 // handler, regardless of the "UseMainMethodTransformation" attribute.
221 meth->RerouteTransformationHandler (&(this->GetTransformationHandler()));
222
223 ch = gTools().GetNextChild(ch);
224 }
225 //Log() << kINFO << "Reading methods from XML done " << Endl;
226}
227
228////////////////////////////////////////////////////////////////////////////////
229/// text streamer
230
232{
 // Rebuilds the sub-methods from the whitespace-separated text stream `istr`.
 // Layout consumed: a label token and the method count, then per method
 // label/value pairs for name, index and weight, optionally followed by
 // job name, title and option string, and finally the sub-method's own
 // weight block.
 // `dummy` swallows the label tokens; `var` is declared but never used.
233 TString var, dummy;
 // Defaults taken from this composite; they are overwritten per method only
 // in the branch that reads them from the stream.
234 TString methodName, methodTitle=GetMethodName(),
235 jobName=GetJobName(),optionString=GetOptions();
236 UInt_t methodNum; Double_t methodWeight;
237 // and read the Weights (BDT coefficients)
238 // coverity[tainted_data_argument]
239 istr >> dummy >> methodNum;
240 Log() << kINFO << "Read " << methodNum << " Classifiers" << Endl;
 // Drop any previously held sub-methods and their weights.
241 for (UInt_t i=0;i<fMethods.size();i++) delete fMethods[i];
242 fMethods.clear();
243 fMethodWeight.clear();
244 for (UInt_t i=0; i<methodNum; i++) {
 // The index stored in the file is read into fCurrentMethodIdx and must
 // match the loop position; a mismatch is reported at kFATAL level.
245 istr >> dummy >> methodName >> dummy >> fCurrentMethodIdx >> dummy >> methodWeight;
246 if ((UInt_t)fCurrentMethodIdx != i) {
247 Log() << kFATAL << "Error while reading weight file; mismatch MethodIndex="
248 << fCurrentMethodIdx << " i=" << i
249 << " MethodName " << methodName
250 << " dummy " << dummy
251 << " MethodWeight= " << methodWeight
252 << Endl;
253 }
 // Job name, title and options are read for every method, except for a
 // Boost composite where they are read for the first method only. In
 // the Boost case the method is also booked on this object at that point.
 // Each value is extracted with operator>>, i.e. as one token.
254 if (GetMethodType() != Types::kBoost || i==0) {
255 istr >> dummy >> jobName;
256 istr >> dummy >> methodTitle;
257 istr >> dummy >> optionString;
258 if (GetMethodType() == Types::kBoost)
259 ((TMVA::MethodBoost*)this)->BookMethod( Types::Instance().GetMethodType( methodName), methodTitle, optionString );
260 }
 // Later Boost members reuse the job name and options read for the first
 // one and get a generated title "<composite name> (NNNN)".
261 else methodTitle=Form("%s (%04i)",GetMethodName().Data(),fCurrentMethodIdx);
262 fMethods.push_back(
263 ClassifierFactory::Instance().Create(methodName.Data(), jobName, methodTitle, DataInfo(), optionString));
264 fMethodWeight.push_back( methodWeight );
 // NOTE(review): if the created object is not a MethodBase its weight
 // block is silently left unread in the stream, which would misalign
 // every following extraction.
265 if(MethodBase* m = dynamic_cast<MethodBase*>(fMethods.back()) )
266 m->ReadWeightsFromStream(istr);
267 }
268}
269
270////////////////////////////////////////////////////////////////////////////////
271/// return composite MVA response
272
274{
 // Composite response: the sum over all sub-methods of
 // GetMvaValue() * fMethodWeight[i]. The sum is not divided by the total
 // weight here; with no sub-methods the result is 0.
 // NOTE(review): fMethodWeight is indexed on the assumption that it has at
 // least fMethods.size() entries.
275 Double_t mvaValue = 0;
276 for (UInt_t i=0;i< fMethods.size(); i++) mvaValue+=fMethods[i]->GetMvaValue()*fMethodWeight[i];
277
 // No error is computed for the combination; err/errUpper are handed to
 // NoErrorCalc.
278 // cannot determine error
279 NoErrorCalc(err, errUpper);
280
281 return mvaValue;
282}
static RooMathCoreReg dummy
int Int_t
Definition: RtypesCore.h:41
unsigned int UInt_t
Definition: RtypesCore.h:42
double Double_t
Definition: RtypesCore.h:55
#define ClassImp(name)
Definition: Rtypes.h:363
char * Form(const char *fmt,...)
IMethod * Create(const std::string &name, const TString &job, const TString &title, DataSetInfo &dsi, const TString &option)
creates the method if needed based on the method name using the creator function the factory has stor...
static ClassifierFactory & Instance()
access to the ClassifierFactory singleton creates the instance if needed
TString fWeightFileDir
Definition: Config.h:112
IONames & GetIONames()
Definition: Config.h:90
virtual void ParseOptions()
options parser
const TString & GetOptions() const
Definition: Configurable.h:84
void SetMsgType(EMsgType t)
Definition: Configurable.h:125
Class that contains all the data information.
Definition: DataSetInfo.h:60
virtual const char * GetName() const
Returns name of object.
Definition: DataSetInfo.h:67
Interface for all concrete MVA method implementations.
Definition: IMethod.h:54
Virtual base Class for all MVA method.
Definition: MethodBase.h:109
TransformationHandler * fTransformationPointer
Definition: MethodBase.h:660
void SetSilentFile(Bool_t status)
Definition: MethodBase.h:369
void SetWeightFileDir(TString fileDir)
set directory of weight file
TString GetMethodTypeName() const
Definition: MethodBase.h:323
virtual void ReadWeightsFromXML(void *wghtnode)=0
const TString & GetJobName() const
Definition: MethodBase.h:321
void SetupMethod()
setup of methods
Definition: MethodBase.cxx:411
friend class MethodCompositeBase
Definition: MethodBase.h:261
const TString & GetMethodName() const
Definition: MethodBase.h:322
void ProcessSetup()
process all options the "CheckForUnusedOptions" is done in an independent call, since it may be overr...
Definition: MethodBase.cxx:428
void RerouteTransformationHandler(TransformationHandler *fTargetTransformation)
Definition: MethodBase.h:394
DataSetInfo & DataInfo() const
Definition: MethodBase.h:401
virtual void AddWeightsXMLTo(void *parent) const =0
Double_t GetSignalReferenceCutOrientation() const
Definition: MethodBase.h:352
void SetSignalReferenceCut(Double_t cut)
Definition: MethodBase.h:355
void SetSignalReferenceCutOrientation(Double_t cutOrientation)
Definition: MethodBase.h:356
void SetModelPersistence(Bool_t status)
Definition: MethodBase.h:373
Double_t GetSignalReferenceCut() const
Definition: MethodBase.h:351
virtual void CheckSetup()
check may be overridden by derived class (sometimes, eg, fitters are used which can only be implement...
Definition: MethodBase.cxx:438
Class for boosting a TMVA method.
Definition: MethodBoost.h:58
Virtual base class for combining several TMVA method.
void ReadWeightsFromStream(std::istream &istr)
text streamer
Double_t GetMvaValue(Double_t *err=0, Double_t *errUpper=0)
return composite MVA response
virtual ~MethodCompositeBase(void)
delete methods
IMethod * GetMethod(const TString &title) const
returns pointer to MVA that corresponds to given method title
void ReadWeightsFromXML(void *wghtnode)
XML streamer.
void AddWeightsXMLTo(void *parent) const
void * GetNextChild(void *prevchild, const char *childname=0)
XML helpers.
Definition: Tools.cxx:1174
void * AddChild(void *parent, const char *childname, const char *content=0, bool isRootNode=false)
add child node
Definition: Tools.cxx:1136
void * GetChild(void *parent, const char *childname=0)
get child node
Definition: Tools.cxx:1162
void ReadAttr(void *node, const char *, T &value)
read attribute from xml
Definition: Tools.h:337
void AddAttr(void *node, const char *, const T &value, Int_t precision=16)
add attribute to xml
Definition: Tools.h:355
static Types & Instance()
returns the single instance of "Types" if existing already, or create it (Singleton)
Definition: Types.cxx:70
@ kBoost
Definition: Types.h:95
Types::EMVA GetMethodType(const TString &method) const
returns the method type (enum) for a given method (string)
Definition: Types.cxx:121
Basic string class.
Definition: TString.h:131
void ToLower()
Change string to lower-case.
Definition: TString.cxx:1100
const char * Data() const
Definition: TString.h:364
TString & ReplaceAll(const TString &s1, const TString &s2)
Definition: TString.h:687
std::string GetMethodName(TCppMethod_t)
Definition: Cppyy.cxx:750
Abstract ClassifierFactory template that handles arbitrary types.
Config & gConfig()
Tools & gTools()
MsgLogger & Endl(MsgLogger &ml)
Definition: MsgLogger.h:158
Double_t Log(Double_t x)
Definition: TMath.h:748
auto * m
Definition: textangle.C:8