ROOT  6.06/09
Reference Guide
MethodCompositeBase.cxx
Go to the documentation of this file.
1 // @(#)root/tmva $Id$
2 // Author: Andreas Hoecker, Joerg Stelzer, Helge Voss, Kai Voss,Or Cohen
3 
4 /*****************************************************************************
5  * Project: TMVA - a Root-integrated toolkit for multivariate data analysis *
6  * Package: TMVA *
7  * Class : MethodCompositeBase *
8  * Web : http://tmva.sourceforge.net *
9  * *
10  * Description: *
11  * Virtual base class for all composite MVA methods *
12  * *
13  * Authors (alphabetical): *
14  * Andreas Hoecker <Andreas.Hocker@cern.ch> - CERN, Switzerland *
15  * Joerg Stelzer <Joerg.Stelzer@cern.ch> - MSU, USA *
16  * Helge Voss <Helge.Voss@cern.ch> - MPI-K Heidelberg, Germany *
17  * Kai Voss <Kai.Voss@cern.ch> - U. of Victoria, Canada *
18  * Or Cohen <orcohenor@gmail.com> - Weizmann Inst., Israel *
19  * *
20  * Copyright (c) 2005: *
21  * CERN, Switzerland *
22  * U. of Victoria, Canada *
23  * MPI-K Heidelberg, Germany *
24  * LAPP, Annecy, France *
25  * *
26  * Redistribution and use in source and binary forms, with or without *
27  * modification, are permitted according to the terms listed in LICENSE *
28  * (http://tmva.sourceforge.net/LICENSE) *
29  *****************************************************************************/
30 
31 //_______________________________________________________________________
32 //
33 // This class is a virtual class meant to combine more than one classifier //
34 // together. The training of the classifiers is done by classes that are //
35 // derived from this one, while the saving and loading of weight files //
36 // and the evaluation are done here. //
37 //_______________________________________________________________________
38 
39 #include <algorithm>
40 #include <iomanip>
41 #include <vector>
42 
43 #include "Riostream.h"
44 #include "TRandom3.h"
45 #include "TMath.h"
46 #include "TObjString.h"
47 
49 #include "TMVA/MethodBoost.h"
50 #include "TMVA/MethodBase.h"
51 #include "TMVA/Tools.h"
52 #include "TMVA/Types.h"
53 #include "TMVA/Factory.h"
54 #include "TMVA/ClassifierFactory.h"
55 
56 using std::vector;
57 
59 
60 ////////////////////////////////////////////////////////////////////////////////
61 /// constructor: hands jobName, methodType, methodTitle, theData, theOption and theTargetDir on to MethodBase and starts with fCurrentMethodIdx and fCurrentMethod set to 0
62 TMVA::MethodCompositeBase::MethodCompositeBase( const TString& jobName,
63  Types::EMVA methodType,
64  const TString& methodTitle,
65  DataSetInfo& theData,
66  const TString& theOption,
67  TDirectory* theTargetDir )
68  : TMVA::MethodBase( jobName, methodType, methodTitle, theData, theOption, theTargetDir ),
69  fCurrentMethodIdx(0), fCurrentMethod(0)
70 {}
71 
72 ////////////////////////////////////////////////////////////////////////////////
73 /// constructor taking a weight file: hands methodType, dsi, weightFile and theTargetDir on to MethodBase and starts with fCurrentMethodIdx and fCurrentMethod set to 0
75  DataSetInfo& dsi,
76  const TString& weightFile,
77  TDirectory* theTargetDir )
78  : TMVA::MethodBase( methodType, dsi, weightFile, theTargetDir ),
79  fCurrentMethodIdx(0), fCurrentMethod(0)
80 {}
81 
82 ////////////////////////////////////////////////////////////////////////////////
83 /// returns pointer to the MVA whose method name equals the given method title, or 0 if none of the held methods matches
84 
86 {
87  std::vector<IMethod*>::const_iterator itrMethod = fMethods.begin();
88  std::vector<IMethod*>::const_iterator itrMethodEnd = fMethods.end();
89 
90  for (; itrMethod != itrMethodEnd; ++itrMethod) {
91  MethodBase* mva = dynamic_cast<MethodBase*>(*itrMethod);
92  if ( mva!=0 && (mva->GetMethodName())==methodTitle ) return mva; // the cast yields 0 for an entry that is not a MethodBase: skip it instead of dereferencing it
93  }
94  return 0;
95 }
96 
97 ////////////////////////////////////////////////////////////////////////////////
98 /// returns pointer to MVA that corresponds to given method index, or 0 if the index does not address one of the held methods
99 
101 {
102  // check the index against the size first: forming fMethods.begin()+index for an index outside the vector is undefined behaviour
103  if (index < 0 || static_cast<UInt_t>(index) >= fMethods.size()) return 0;
104  return fMethods[index];
105 }
106 
107 
108 ////////////////////////////////////////////////////////////////////////////////
109 /// writes the composite to XML: adds a "Weights" node below parent and, inside it, one "Method" node per held method
111 {
112  void* weightsNode = gTools().AddChild(parent, "Weights");
113  gTools().AddAttr( weightsNode, "NMethods", fMethods.size() );
114  for (UInt_t iMethod=0; iMethod<fMethods.size(); ++iMethod)
115  {
116  void* methodNode = gTools().AddChild( weightsNode, "Method" );
117  MethodBase* subMethod = dynamic_cast<MethodBase*>(fMethods[iMethod]); // NOTE(review): the cast result is used without a null check
118  gTools().AddAttr(methodNode,"Index", iMethod );
119  gTools().AddAttr(methodNode,"Weight", fMethodWeight[iMethod]);
120  gTools().AddAttr(methodNode,"MethodSigCut", subMethod->GetSignalReferenceCut());
121  gTools().AddAttr(methodNode,"MethodSigCutOrientation", subMethod->GetSignalReferenceCutOrientation());
122  gTools().AddAttr(methodNode,"MethodTypeName", subMethod->GetMethodTypeName());
123  gTools().AddAttr(methodNode,"MethodName", subMethod->GetMethodName() );
124  gTools().AddAttr(methodNode,"JobName", subMethod->GetJobName());
125  gTools().AddAttr(methodNode,"Options", subMethod->GetOptions());
126  // "true" when the sub-method has a non-null fTransformationPointer, "false" otherwise
127  const TString usesMainTransformation( subMethod->fTransformationPointer ? "true" : "false" );
128  gTools().AddAttr(methodNode,"UseMainMethodTransformation", usesMainTransformation);
129  // last, the sub-method appends its own weights below its "Method" node
130  subMethod->AddWeightsXMLTo(methodNode);
131  }
132 }
133 
134 ////////////////////////////////////////////////////////////////////////////////
135 /// destructor: logs and deletes every held method in order, then empties the list
136 
138 {
139  for (UInt_t iMethod=0; iMethod<fMethods.size(); ++iMethod) {
140  IMethod* doomed = fMethods[iMethod];
141  Log() << kVERBOSE << "Delete method: " << doomed->GetName() << Endl;
142  delete doomed;
143  }
144  fMethods.clear();
145 }
146 
147 ////////////////////////////////////////////////////////////////////////////////
148 /// XML streamer: deletes the methods and weights held so far, then rebuilds them from the children of wghtnode (as many as its "NMethods" attribute states)
149 
151 {
152  UInt_t nMethods;
153  TString methodName, methodTypeName, jobName, optionString;
154 
155  for (UInt_t i=0;i<fMethods.size();i++) delete fMethods[i]; // discard whatever was held before reading
156  fMethods.clear();
157  fMethodWeight.clear();
158  gTools().ReadAttr( wghtnode, "NMethods", nMethods );
159  void* ch = gTools().GetChild(wghtnode); // first child; advanced with GetNextChild at the end of each iteration
160  for (UInt_t i=0; i< nMethods; i++) {
161  Double_t methodWeight, methodSigCut, methodSigCutOrientation;
162  gTools().ReadAttr( ch, "Weight", methodWeight );
163  gTools().ReadAttr( ch, "MethodSigCut", methodSigCut);
164  gTools().ReadAttr( ch, "MethodSigCutOrientation", methodSigCutOrientation);
165  gTools().ReadAttr( ch, "MethodTypeName", methodTypeName );
166  gTools().ReadAttr( ch, "MethodName", methodName );
167  gTools().ReadAttr( ch, "JobName", jobName );
168  gTools().ReadAttr( ch, "Options", optionString );
169 
170  // Bool_t rerouteTransformation = kFALSE;
171  if (gTools().HasAttr( ch, "UseMainMethodTransformation")) {
172  TString rerouteString("");
173  gTools().ReadAttr( ch, "UseMainMethodTransformation", rerouteString );
174  rerouteString.ToLower();
175  // if (rerouteString=="true")
176  // rerouteTransformation=kTRUE;
177  } // NOTE(review): rerouteString is read and lower-cased but never used; the code that consumed it is commented out
178 
179  //remove trailing "~" to signal that options have to be reused
180  optionString.ReplaceAll("~","");
181  //ignore meta-options for method Boost
182  optionString.ReplaceAll("Boost_","~Boost_");
183  optionString.ReplaceAll("!~","~!");
184 
185  if (i==0){ // BookMethod is called for the first child only
186  // the cast on MethodBoost is ugly, but a similar line is also in ReadWeightsFromFile --> needs to be fixed later
187  ((TMVA::MethodBoost*)this)->BookMethod( Types::Instance().GetMethodType( methodTypeName), methodName, optionString );
188  }
189  fMethods.push_back(ClassifierFactory::Instance().Create(
190  std::string(methodTypeName),jobName, methodName,DataInfo(),optionString));
191 
192  fMethodWeight.push_back(methodWeight);
193  MethodBase* meth = dynamic_cast<MethodBase*>(fMethods.back());
194 
195  if(meth==0) // NOTE(review): meth is dereferenced below, so this presumably relies on the kFATAL message stopping execution - confirm
196  Log() << kFATAL << "Could not read method from XML" << Endl;
197 
198  void* methXML = gTools().GetChild(ch); // first child of this entry; it is handed to the new method's own ReadWeightsFromXML below
199  meth->SetupMethod();
200  meth->SetMsgType(kWARNING);
201  meth->ParseOptions();
202  meth->ProcessSetup();
203  meth->CheckSetup();
204  meth->ReadWeightsFromXML(methXML);
205  meth->SetSignalReferenceCut(methodSigCut);
206  meth->SetSignalReferenceCutOrientation(methodSigCutOrientation);
207 
208  meth->RerouteTransformationHandler (&(this->GetTransformationHandler())); // every restored method gets this object's transformation handler
209 
210  ch = gTools().GetNextChild(ch);
211  }
212  //Log() << kINFO << "Reading methods from XML done " << Endl;
213 }
214 
215 ////////////////////////////////////////////////////////////////////////////////
216 /// text streamer: deletes the methods and weights held so far, then reads the number of classifiers from istr followed by one record per classifier
217 
219 {
220  TString dummy; // receives the label tokens of the stream; its content is only printed in the mismatch message
221  TString methodName, methodTitle=GetMethodName(),
222  jobName=GetJobName(),optionString=GetOptions();
223  UInt_t methodNum = 0; Double_t methodWeight = 0; // initialised: a stream that is already in a failed state leaves the targets of >> untouched
224  // and read the Weights (BDT coefficients)
225  // coverity[tainted_data_argument]
226  istr >> dummy >> methodNum;
227  Log() << kINFO << "Read " << methodNum << " Classifiers" << Endl;
228  for (UInt_t i=0;i<fMethods.size();i++) delete fMethods[i];
229  fMethods.clear();
230  fMethodWeight.clear();
231  for (UInt_t i=0; i<methodNum; i++) {
232  istr >> dummy >> methodName >> dummy >> fCurrentMethodIdx >> dummy >> methodWeight;
233  if ((UInt_t)fCurrentMethodIdx != i) { // the index stored in the file has to follow the reading order
234  Log() << kFATAL << "Error while reading weight file; mismatch MethodIndex="
235  << fCurrentMethodIdx << " i=" << i
236  << " MethodName " << methodName
237  << " dummy " << dummy
238  << " MethodWeight= " << methodWeight
239  << Endl;
240  }
241  if (GetMethodType() != Types::kBoost || i==0) { // for type kBoost only the first record carries job name, title and options
242  istr >> dummy >> jobName;
243  istr >> dummy >> methodTitle;
244  istr >> dummy >> optionString;
245  if (GetMethodType() == Types::kBoost)
246  ((TMVA::MethodBoost*)this)->BookMethod( Types::Instance().GetMethodType( methodName), methodTitle, optionString );
247  }
248  else methodTitle=Form("%s (%04i)",GetMethodName().Data(),fCurrentMethodIdx); // later kBoost records get a title built from the method name and the index
249  fMethods.push_back(ClassifierFactory::Instance().Create( std::string(methodName), jobName,
250  methodTitle,DataInfo(), optionString) );
251  fMethodWeight.push_back( methodWeight );
252  if(MethodBase* m = dynamic_cast<MethodBase*>(fMethods.back()) )
253  m->ReadWeightsFromStream(istr);
254  }
255 }
256 
257 ////////////////////////////////////////////////////////////////////////////////
258 /// composite MVA response: the sum over all held methods of their GetMvaValue() times the matching entry of fMethodWeight
259 
261 {
262  Double_t weightedSum = 0;
263  for (UInt_t iMethod=0; iMethod<fMethods.size(); ++iMethod) {
264  weightedSum += fMethods[iMethod]->GetMvaValue()*fMethodWeight[iMethod];
265  }
266  NoErrorCalc(err, errUpper); // the error of the combination cannot be determined
267 
268  return weightedSum;
269 }
static ClassifierFactory & Instance()
access to the ClassifierFactory singleton; creates the instance if needed
void SetMsgType(EMsgType t)
Definition: Configurable.h:135
MsgLogger & Endl(MsgLogger &ml)
Definition: MsgLogger.h:162
void ReadWeightsFromXML(void *wghtnode)
XML streamer.
TString & ReplaceAll(const TString &s1, const TString &s2)
Definition: TString.h:635
TransformationHandler * fTransformationPointer
Definition: MethodBase.h:587
void SetSignalReferenceCutOrientation(Double_t cutOrientation)
Definition: MethodBase.h:330
Basic string class.
Definition: TString.h:137
void ToLower()
Change string to lower-case.
Definition: TString.cxx:1088
int Int_t
Definition: RtypesCore.h:41
void ReadWeightsFromStream(std::istream &istr)
text streamer
void AddAttr(void *node, const char *, const T &value, Int_t precision=16)
Definition: Tools.h:308
void AddWeightsXMLTo(void *parent) const
void * AddChild(void *parent, const char *childname, const char *content=0, bool isRootNode=false)
add child node
Definition: Tools.cxx:1134
const TString & GetMethodName() const
Definition: MethodBase.h:296
MethodCompositeBase(const TString &jobName, Types::EMVA methodType, const TString &methodTitle, DataSetInfo &theData, const TString &theOption="", TDirectory *theTargetDir=NULL)
static Types & Instance()
return the single instance of "Types" if it already exists, or create it (Singleton)
Definition: Types.cxx:61
ClassImp(TMVA::MethodCompositeBase) TMVA
Tools & gTools()
Definition: Tools.cxx:79
void * GetChild(void *parent, const char *childname=0)
get child node
Definition: Tools.cxx:1158
virtual ~MethodCompositeBase(void)
delete methods
std::vector< std::vector< double > > Data
virtual void ParseOptions()
options parser
void SetupMethod()
setup of methods
Definition: MethodBase.cxx:295
virtual void AddWeightsXMLTo(void *parent) const =0
virtual void ReadWeightsFromXML(void *wghtnode)=0
std::string GetMethodName(TCppMethod_t)
Definition: Cppyy.cxx:707
IMethod * GetMethod(const TString &title) const
returns pointer to MVA that corresponds to given method title
unsigned int UInt_t
Definition: RtypesCore.h:42
TMarker * m
Definition: textangle.C:8
char * Form(const char *fmt,...)
void ReadAttr(void *node, const char *, T &value)
Definition: Tools.h:295
virtual void CheckSetup()
check may be overridden by derived class (sometimes, eg, fitters are used which can only be implement...
Definition: MethodBase.cxx:320
void RerouteTransformationHandler(TransformationHandler *fTargetTransformation)
Definition: MethodBase.h:357
double Double_t
Definition: RtypesCore.h:55
Describe directory structure in memory.
Definition: TDirectory.h:41
static RooMathCoreReg dummy
void * GetNextChild(void *prevchild, const char *childname=0)
XML helpers.
Definition: Tools.cxx:1170
Double_t GetSignalReferenceCutOrientation() const
Definition: MethodBase.h:326
void ProcessSetup()
process all options the "CheckForUnusedOptions" is done in an independent call, since it may be overr...
Definition: MethodBase.cxx:310
Abstract ClassifierFactory template that handles arbitrary types.
Double_t GetMvaValue(Double_t *err=0, Double_t *errUpper=0)
return composite MVA response
const TString & GetJobName() const
Definition: MethodBase.h:295
const TString & GetOptions() const
Definition: Configurable.h:91
Double_t GetSignalReferenceCut() const
Definition: MethodBase.h:325
TString GetMethodTypeName() const
Definition: MethodBase.h:297
Definition: math.cpp:60
void SetSignalReferenceCut(Double_t cut)
Definition: MethodBase.h:329