Logo ROOT   6.14/05
Reference Guide
MethodCompositeBase.cxx
Go to the documentation of this file.
1 // @(#)root/tmva $Id$
2 // Author: Andreas Hoecker, Joerg Stelzer, Helge Voss, Kai Voss,Or Cohen
3 
4 /*****************************************************************************
5  * Project: TMVA - a Root-integrated toolkit for multivariate data analysis *
6  * Package: TMVA *
7  * Class : MethodCompositeBase *
8  * Web : http://tmva.sourceforge.net *
9  * *
10  * Description: *
11  * Virtual base class for combining several MVA methods *
12  * *
13  * Authors (alphabetical): *
14  * Andreas Hoecker <Andreas.Hocker@cern.ch> - CERN, Switzerland *
15  * Joerg Stelzer <Joerg.Stelzer@cern.ch> - MSU, USA *
16  * Helge Voss <Helge.Voss@cern.ch> - MPI-K Heidelberg, Germany *
17  * Kai Voss <Kai.Voss@cern.ch> - U. of Victoria, Canada *
18  * Or Cohen <orcohenor@gmail.com> - Weizmann Inst., Israel *
19  * *
20  * Copyright (c) 2005: *
21  * CERN, Switzerland *
22  * U. of Victoria, Canada *
23  * MPI-K Heidelberg, Germany *
24  * LAPP, Annecy, France *
25  * *
26  * Redistribution and use in source and binary forms, with or without *
27  * modification, are permitted according to the terms listed in LICENSE *
28  * (http://tmva.sourceforge.net/LICENSE) *
29  *****************************************************************************/
30 
31 /*! \class TMVA::MethodCompositeBase
32 \ingroup TMVA
33 
34 Virtual base class for combining several TMVA methods.
35 
36 This class is a virtual class meant to combine more than one classifier
37 together. The training of the classifiers is done by classes that are
38 derived from this one, while the saving and loading of weight files
39 and the evaluation are done here.
40 */
41 
43 
44 #include "TMVA/ClassifierFactory.h"
45 #include "TMVA/DataSetInfo.h"
46 #include "TMVA/Factory.h"
47 #include "TMVA/IMethod.h"
48 #include "TMVA/MethodBase.h"
49 #include "TMVA/MethodBoost.h"
50 #include "TMVA/MsgLogger.h"
51 #include "TMVA/Tools.h"
52 #include "TMVA/Types.h"
53 #include "TMVA/Config.h"
54 
55 #include "Riostream.h"
56 #include "TRandom3.h"
57 #include "TMath.h"
58 #include "TObjString.h"
59 
60 #include <algorithm>
61 #include <iomanip>
62 #include <vector>
63 
64 
65 using std::vector;
66 
68 
69 ////////////////////////////////////////////////////////////////////////////////
70 
72  Types::EMVA methodType,
73  const TString& methodTitle,
74  DataSetInfo& theData,
75  const TString& theOption )
76 : TMVA::MethodBase( jobName, methodType, methodTitle, theData, theOption),
77  fCurrentMethodIdx(0), fCurrentMethod(0)
78 {}
79 
80 ////////////////////////////////////////////////////////////////////////////////
81 
83  DataSetInfo& dsi,
84  const TString& weightFile)
85  : TMVA::MethodBase( methodType, dsi, weightFile),
87 {}
88 
89 ////////////////////////////////////////////////////////////////////////////////
90 /// returns pointer to MVA that corresponds to given method title
91 
93 {
94  std::vector<IMethod*>::const_iterator itrMethod = fMethods.begin();
95  std::vector<IMethod*>::const_iterator itrMethodEnd = fMethods.end();
96 
97  for (; itrMethod != itrMethodEnd; ++itrMethod) {
98  MethodBase* mva = dynamic_cast<MethodBase*>(*itrMethod);
99  if ( (mva->GetMethodName())==methodTitle ) return mva;
100  }
101  return 0;
102 }
103 
104 ////////////////////////////////////////////////////////////////////////////////
105 /// returns pointer to MVA that corresponds to given method index
106 
108 {
109  std::vector<IMethod*>::const_iterator itrMethod = fMethods.begin()+index;
110  if (itrMethod<fMethods.end()) return *itrMethod;
111  else return 0;
112 }
113 
114 
115 ////////////////////////////////////////////////////////////////////////////////
116 
118 {
119  void* wght = gTools().AddChild(parent, "Weights");
120  gTools().AddAttr( wght, "NMethods", fMethods.size() );
121  for (UInt_t i=0; i< fMethods.size(); i++)
122  {
123  void* methxml = gTools().AddChild( wght, "Method" );
124  MethodBase* method = dynamic_cast<MethodBase*>(fMethods[i]);
125  gTools().AddAttr(methxml,"Index", i );
126  gTools().AddAttr(methxml,"Weight", fMethodWeight[i]);
127  gTools().AddAttr(methxml,"MethodSigCut", method->GetSignalReferenceCut());
128  gTools().AddAttr(methxml,"MethodSigCutOrientation", method->GetSignalReferenceCutOrientation());
129  gTools().AddAttr(methxml,"MethodTypeName", method->GetMethodTypeName());
130  gTools().AddAttr(methxml,"MethodName", method->GetMethodName() );
131  gTools().AddAttr(methxml,"JobName", method->GetJobName());
132  gTools().AddAttr(methxml,"Options", method->GetOptions());
133  if (method->fTransformationPointer)
134  gTools().AddAttr(methxml,"UseMainMethodTransformation", TString("true"));
135  else
136  gTools().AddAttr(methxml,"UseMainMethodTransformation", TString("false"));
137  method->AddWeightsXMLTo(methxml);
138  }
139 }
140 
141 ////////////////////////////////////////////////////////////////////////////////
142 /// delete methods
143 
145 {
146  std::vector<IMethod*>::iterator itrMethod = fMethods.begin();
147  for (; itrMethod != fMethods.end(); ++itrMethod) {
148  Log() << kVERBOSE << "Delete method: " << (*itrMethod)->GetName() << Endl;
149  delete (*itrMethod);
150  }
151  fMethods.clear();
152 }
153 
154 ////////////////////////////////////////////////////////////////////////////////
155 /// XML streamer
156 
158 {
159  UInt_t nMethods;
160  TString methodName, methodTypeName, jobName, optionString;
161 
162  for (UInt_t i=0;i<fMethods.size();i++) delete fMethods[i];
163  fMethods.clear();
164  fMethodWeight.clear();
165  gTools().ReadAttr( wghtnode, "NMethods", nMethods );
166  void* ch = gTools().GetChild(wghtnode);
167  for (UInt_t i=0; i< nMethods; i++) {
168  Double_t methodWeight, methodSigCut, methodSigCutOrientation;
169  gTools().ReadAttr( ch, "Weight", methodWeight );
170  gTools().ReadAttr( ch, "MethodSigCut", methodSigCut);
171  gTools().ReadAttr( ch, "MethodSigCutOrientation", methodSigCutOrientation);
172  gTools().ReadAttr( ch, "MethodTypeName", methodTypeName );
173  gTools().ReadAttr( ch, "MethodName", methodName );
174  gTools().ReadAttr( ch, "JobName", jobName );
175  gTools().ReadAttr( ch, "Options", optionString );
176 
177  // Bool_t rerouteTransformation = kFALSE;
178  if (gTools().HasAttr( ch, "UseMainMethodTransformation")) {
179  TString rerouteString("");
180  gTools().ReadAttr( ch, "UseMainMethodTransformation", rerouteString );
181  rerouteString.ToLower();
182  // if (rerouteString=="true")
183  // rerouteTransformation=kTRUE;
184  }
185 
186  //remove trailing "~" to signal that options have to be reused
187  optionString.ReplaceAll("~","");
188  //ignore meta-options for method Boost
189  optionString.ReplaceAll("Boost_","~Boost_");
190  optionString.ReplaceAll("!~","~!");
191 
192  if (i==0){
193  // the cast on MethodBoost is ugly, but a similar line is also in ReadWeightsFromFile --> needs to be fixed later
194  ((TMVA::MethodBoost*)this)->BookMethod( Types::Instance().GetMethodType( methodTypeName), methodName, optionString );
195  }
196  fMethods.push_back(
197  ClassifierFactory::Instance().Create(methodTypeName.Data(), jobName, methodName, DataInfo(), optionString));
198 
199  fMethodWeight.push_back(methodWeight);
200  MethodBase* meth = dynamic_cast<MethodBase*>(fMethods.back());
201 
202  if(meth==0)
203  Log() << kFATAL << "Could not read method from XML" << Endl;
204 
205  void* methXML = gTools().GetChild(ch);
206 
207  TString _fFileDir= meth->DataInfo().GetName();
208  _fFileDir+="/"+gConfig().GetIONames().fWeightFileDir;
209  meth->SetWeightFileDir(_fFileDir);
211  meth->SetSilentFile(IsSilentFile());
212  meth->SetupMethod();
213  meth->SetMsgType(kWARNING);
214  meth->ParseOptions();
215  meth->ProcessSetup();
216  meth->CheckSetup();
217  meth->ReadWeightsFromXML(methXML);
218  meth->SetSignalReferenceCut(methodSigCut);
219  meth->SetSignalReferenceCutOrientation(methodSigCutOrientation);
220 
222 
223  ch = gTools().GetNextChild(ch);
224  }
225  //Log() << kINFO << "Reading methods from XML done " << Endl;
226 }
227 
228 ////////////////////////////////////////////////////////////////////////////////
229 /// text streamer
230 
232 {
233  TString var, dummy;
234  TString methodName, methodTitle=GetMethodName(),
235  jobName=GetJobName(),optionString=GetOptions();
236  UInt_t methodNum; Double_t methodWeight;
237  // and read the Weights (BDT coefficients)
238  // coverity[tainted_data_argument]
239  istr >> dummy >> methodNum;
240  Log() << kINFO << "Read " << methodNum << " Classifiers" << Endl;
241  for (UInt_t i=0;i<fMethods.size();i++) delete fMethods[i];
242  fMethods.clear();
243  fMethodWeight.clear();
244  for (UInt_t i=0; i<methodNum; i++) {
245  istr >> dummy >> methodName >> dummy >> fCurrentMethodIdx >> dummy >> methodWeight;
246  if ((UInt_t)fCurrentMethodIdx != i) {
247  Log() << kFATAL << "Error while reading weight file; mismatch MethodIndex="
248  << fCurrentMethodIdx << " i=" << i
249  << " MethodName " << methodName
250  << " dummy " << dummy
251  << " MethodWeight= " << methodWeight
252  << Endl;
253  }
254  if (GetMethodType() != Types::kBoost || i==0) {
255  istr >> dummy >> jobName;
256  istr >> dummy >> methodTitle;
257  istr >> dummy >> optionString;
258  if (GetMethodType() == Types::kBoost)
259  ((TMVA::MethodBoost*)this)->BookMethod( Types::Instance().GetMethodType( methodName), methodTitle, optionString );
260  }
261  else methodTitle=Form("%s (%04i)",GetMethodName().Data(),fCurrentMethodIdx);
262  fMethods.push_back(
263  ClassifierFactory::Instance().Create(methodName.Data(), jobName, methodTitle, DataInfo(), optionString));
264  fMethodWeight.push_back( methodWeight );
265  if(MethodBase* m = dynamic_cast<MethodBase*>(fMethods.back()) )
266  m->ReadWeightsFromStream(istr);
267  }
268 }
269 
270 ////////////////////////////////////////////////////////////////////////////////
271 /// return composite MVA response
272 
274 {
275  Double_t mvaValue = 0;
276  for (UInt_t i=0;i< fMethods.size(); i++) mvaValue+=fMethods[i]->GetMvaValue()*fMethodWeight[i];
277 
278  // cannot determine error
279  NoErrorCalc(err, errUpper);
280 
281  return mvaValue;
282 }
static ClassifierFactory & Instance()
Access to the ClassifierFactory singleton; creates the instance if needed.
MethodCompositeBase(const TString &jobName, Types::EMVA methodType, const TString &methodTitle, DataSetInfo &theData, const TString &theOption="")
void SetModelPersistence(Bool_t status)
Definition: MethodBase.h:373
void SetMsgType(EMsgType t)
Definition: Configurable.h:125
MsgLogger & Endl(MsgLogger &ml)
Definition: MsgLogger.h:158
auto * m
Definition: textangle.C:8
void ReadWeightsFromXML(void *wghtnode)
XML streamer.
TString & ReplaceAll(const TString &s1, const TString &s2)
Definition: TString.h:687
static Types & Instance()
Return the single instance of "Types" if it already exists, or create it (Singleton).
Definition: Types.cxx:70
Config & gConfig()
MsgLogger & Log() const
Definition: Configurable.h:122
TransformationHandler * fTransformationPointer
Definition: MethodBase.h:660
Virtual base class for all MVA methods.
Definition: MethodBase.h:109
void SetSignalReferenceCutOrientation(Double_t cutOrientation)
Definition: MethodBase.h:356
Basic string class.
Definition: TString.h:131
TransformationHandler & GetTransformationHandler(Bool_t takeReroutedIfAvailable=true)
Definition: MethodBase.h:385
void ToLower()
Change string to lower-case.
Definition: TString.cxx:1100
int Int_t
Definition: RtypesCore.h:41
void SetSilentFile(Bool_t status)
Definition: MethodBase.h:369
void ReadWeightsFromStream(std::istream &istr)
text streamer
void AddAttr(void *node, const char *, const T &value, Int_t precision=16)
add attribute to xml
Definition: Tools.h:353
void * AddChild(void *parent, const char *childname, const char *content=0, bool isRootNode=false)
add child node
Definition: Tools.cxx:1136
void AddWeightsXMLTo(void *parent) const
IMethod * GetMethod(const TString &title) const
returns pointer to MVA that corresponds to given method title
std::vector< Double_t > fMethodWeight
DataSet * Data() const
Definition: MethodBase.h:400
Virtual base class for combining several TMVA methods.
TString fWeightFileDir
Definition: Config.h:112
void * GetChild(void *parent, const char *childname=0)
get child node
Definition: Tools.cxx:1162
virtual ~MethodCompositeBase(void)
delete methods
IONames & GetIONames()
Definition: Config.h:90
virtual void ParseOptions()
options parser
void SetupMethod()
setup of methods
Definition: MethodBase.cxx:411
DataSetInfo & DataInfo() const
Definition: MethodBase.h:401
Class that contains all the data information.
Definition: DataSetInfo.h:60
virtual void AddWeightsXMLTo(void *parent) const =0
Class for boosting a TMVA method.
Definition: MethodBoost.h:58
virtual void ReadWeightsFromXML(void *wghtnode)=0
unsigned int UInt_t
Definition: RtypesCore.h:42
char * Form(const char *fmt,...)
const TString & GetJobName() const
Definition: MethodBase.h:321
const TString & GetMethodName() const
Definition: MethodBase.h:322
void ReadAttr(void *node, const char *, T &value)
read attribute from xml
Definition: Tools.h:335
Tools & gTools()
Bool_t IsSilentFile()
Definition: MethodBase.h:370
Double_t GetSignalReferenceCutOrientation() const
Definition: MethodBase.h:352
virtual void CheckSetup()
check may be overridden by derived class (sometimes, eg, fitters are used which can only be implement...
Definition: MethodBase.cxx:438
#define ClassImp(name)
Definition: Rtypes.h:359
void RerouteTransformationHandler(TransformationHandler *fTargetTransformation)
Definition: MethodBase.h:394
double Double_t
Definition: RtypesCore.h:55
static RooMathCoreReg dummy
void * GetNextChild(void *prevchild, const char *childname=0)
XML helpers.
Definition: Tools.cxx:1174
virtual const char * GetName() const
Returns name of object.
Definition: DataSetInfo.h:67
void ProcessSetup()
process all options the "CheckForUnusedOptions" is done in an independent call, since it may be overr...
Definition: MethodBase.cxx:428
const TString & GetOptions() const
Definition: Configurable.h:84
Interface for all concrete MVA method implementations.
Definition: IMethod.h:54
std::vector< IMethod * > fMethods
Abstract ClassifierFactory template that handles arbitrary types.
Double_t GetMvaValue(Double_t *err=0, Double_t *errUpper=0)
return composite MVA response
TString GetMethodTypeName() const
Definition: MethodBase.h:323
void SetWeightFileDir(TString fileDir)
set directory of weight file
Double_t GetSignalReferenceCut() const
Definition: MethodBase.h:351
virtual const char * GetName() const
Returns name of object.
Definition: TObject.cxx:357
Types::EMVA GetMethodType() const
Definition: MethodBase.h:324
void NoErrorCalc(Double_t *const err, Double_t *const errUpper)
Definition: MethodBase.cxx:841
void SetSignalReferenceCut(Double_t cut)
Definition: MethodBase.h:355
const char * Data() const
Definition: TString.h:364
Bool_t IsModelPersistence()
Definition: MethodBase.h:374