Logo ROOT   6.14/05
Reference Guide
MethodCategory.cxx
Go to the documentation of this file.
1 // @(#)root/tmva $Id$
2 // Author: Andreas Hoecker, Joerg Stelzer, Helge Voss, Eckhard von Toerne
3 
4 /**********************************************************************************
5  * Project: TMVA - a Root-integrated toolkit for multivariate data analysis *
6  * Package: TMVA *
7  * Class : MethodCategory *
8  * Web : http://tmva.sourceforge.net *
9  * *
10  * Description: *
11  * Categorizes the phase space; one sub-classifier per category *
12  * *
13  * Authors (alphabetical): *
14  * Andreas Hoecker <Andreas.Hocker@cern.ch> - CERN, Switzerland *
15  * Nadim Sah <Nadim.Sah@cern.ch> - Berlin, Germany *
16  * Peter Speckmayer <Peter.Speckmazer@cern.ch> - CERN, Switzerland *
17  * Joerg Stelzer <Joerg.Stelzer@cern.ch> - MSU East Lansing, USA *
18  * Helge Voss <Helge.Voss@cern.ch> - MPI-K Heidelberg, Germany *
19  * Jan Therhaag <Jan.Therhaag@cern.ch> - U of Bonn, Germany *
20  * Eckhard v. Toerne <evt@uni-bonn.de> - U of Bonn, Germany *
21  * *
22  * Copyright (c) 2005-2011: *
23  * CERN, Switzerland *
24  * MSU East Lansing, USA *
25  * MPI-K Heidelberg, Germany *
26  * U. of Bonn, Germany *
27  * *
28  * Redistribution and use in source and binary forms, with or without *
29  * modification, are permitted according to the terms listed in LICENSE *
30  * (http://tmva.sourceforge.net/LICENSE) *
31  **********************************************************************************/
32 
33 /*! \class TMVA::MethodCategory
34 \ingroup TMVA
35 
36 Class for categorizing the phase space
37 
38 This class is meant to allow categorisation of the data. For different
39 categories, different classifiers may be booked and different variables
40 may be considered. The aim is to account for differences in the data
41 that arise from different locations/angles.
42 */
43 
44 
45 #include "TMVA/MethodCategory.h"
46 
47 #include <algorithm>
48 #include <iomanip>
49 #include <vector>
50 #include <iostream>
51 
52 #include "Riostream.h"
53 #include "TRandom3.h"
54 #include "TMath.h"
55 #include "TObjString.h"
56 #include "TH1F.h"
57 #include "TGraph.h"
58 #include "TSpline.h"
59 #include "TDirectory.h"
60 #include "TTreeFormula.h"
61 
62 #include "TMVA/ClassifierFactory.h"
63 #include "TMVA/Config.h"
64 #include "TMVA/DataSet.h"
65 #include "TMVA/DataSetInfo.h"
66 #include "TMVA/DataSetManager.h"
67 #include "TMVA/IMethod.h"
68 #include "TMVA/MethodBase.h"
70 #include "TMVA/MsgLogger.h"
71 #include "TMVA/PDF.h"
72 #include "TMVA/Ranking.h"
73 #include "TMVA/Timer.h"
74 #include "TMVA/Tools.h"
75 #include "TMVA/Types.h"
76 #include "TMVA/VariableInfo.h"
78 
79 REGISTER_METHOD(Category)
80 
82 
83 ////////////////////////////////////////////////////////////////////////////////
84 /// standard constructor
85 
87  const TString& methodTitle,
88  DataSetInfo& theData,
89  const TString& theOption )
90  : TMVA::MethodCompositeBase( jobName, Types::kCategory, methodTitle, theData, theOption),
91  fCatTree(0),
92  fDataSetManager(NULL)
93 {
94 }
95 
96 ////////////////////////////////////////////////////////////////////////////////
97 /// constructor from weight file
98 
100  const TString& theWeightFile)
101  : TMVA::MethodCompositeBase( Types::kCategory, dsi, theWeightFile),
102  fCatTree(0),
103  fDataSetManager(NULL)
104 {
105 }
106 
107 ////////////////////////////////////////////////////////////////////////////////
108 /// destructor
109 
111 {
112  std::vector<TTreeFormula*>::iterator formIt = fCatFormulas.begin();
113  std::vector<TTreeFormula*>::iterator lastF = fCatFormulas.end();
114  for(;formIt!=lastF; ++formIt) delete *formIt;
115  delete fCatTree;
116 }
117 
118 ////////////////////////////////////////////////////////////////////////////////
119 /// check whether method category has analysis type
120 /// the method type has to be the same for all sub-methods
121 
123 {
124  std::vector<IMethod*>::iterator itrMethod = fMethods.begin();
125 
126  // iterate over methods and check whether they have the analysis type
127  for(; itrMethod != fMethods.end(); ++itrMethod ) {
128  if ( !(*itrMethod)->HasAnalysisType(type, numberClasses, numberTargets) )
129  return kFALSE;
130  }
131  return kTRUE;
132 }
133 
134 ////////////////////////////////////////////////////////////////////////////////
135 /// options for this method
136 
138 {
   // MethodCategory declares no options of its own; each sub-classifier receives its
   // options through the theOptions argument of AddMethod
139 }
140 
141 ////////////////////////////////////////////////////////////////////////////////
142 /// adds sub-classifier for a category
143 
145  const TString& theVariables,
146  Types::EMVA theMethod ,
147  const TString& theTitle,
148  const TString& theOptions )
149 {
150  std::string addedMethodName(Types::Instance().GetMethodName(theMethod).Data());
151 
152  Log() << kINFO << "Adding sub-classifier: " << addedMethodName << "::" << theTitle << Endl;
153 
154  DataSetInfo& dsi = CreateCategoryDSI(theCut, theVariables, theTitle);
155 
156  IMethod* addedMethod = ClassifierFactory::Instance().Create(addedMethodName,GetJobName(),theTitle,dsi,theOptions);
157 
158  MethodBase *method = (dynamic_cast<MethodBase*>(addedMethod));
159  if(method==0) return 0;
160 
163  method->SetAnalysisType( fAnalysisType );
164  method->SetupMethod();
165  method->ParseOptions();
166  method->ProcessSetup();
167  method->SetFile(fFile);
168  method->SetSilentFile(IsSilentFile());
169 
170 
171  // set or create correct method base dir for added method
172  const TString dirName(Form("Method_%s",method->GetMethodTypeName().Data()));
173  TDirectory * dir = BaseDir()->GetDirectory(dirName);
174  if (dir != 0) method->SetMethodBaseDir( dir );
175  else method->SetMethodBaseDir( BaseDir()->mkdir(dirName,Form("Directory for all %s methods", method->GetMethodTypeName().Data())) );
176 
177  // method->SetBaseDir(own base dir; check whether the Fisher dir exists, otherwise create it)
178 
179  // check-for-unused-options is performed; may be overridden by derived
180  // classes
181  method->CheckSetup();
182 
183  // disable writing of XML files and standalone classes for sub methods
184  method->DisableWriting( kTRUE );
185 
186  // store method, cut and variable names and create cut formula
187  fMethods.push_back(method);
188  fCategoryCuts.push_back(theCut);
189  fVars.push_back(theVariables);
190 
191  DataSetInfo& primaryDSI = DataInfo();
192 
193  UInt_t newSpectatorIndex = primaryDSI.GetSpectatorInfos().size();
194  fCategorySpecIdx.push_back(newSpectatorIndex);
195 
196  primaryDSI.AddSpectator( Form("%s_cat%i:=%s", GetName(),(int)fMethods.size(),theCut.GetTitle()),
197  Form("%s:%s",GetName(),method->GetName()),
198  "pass", 0, 0, 'C' );
199 
200  return method;
201 }
202 
203 ////////////////////////////////////////////////////////////////////////////////
204 /// create a DataSetInfo object for a sub-classifier
205 
207  const TString& theVariables,
208  const TString& theTitle)
209 {
210  // create a new dsi with name: theTitle+"_dsi"
211  TString dsiName=theTitle+"_dsi";
212  DataSetInfo& oldDSI = DataInfo();
213  DataSetInfo* dsi = new DataSetInfo(dsiName);
214 
215  // register the new dsi
216  // DataSetManager::Instance().AddDataSetInfo(*dsi); // DSMTEST replaced by following line
218 
219  // copy the targets and spectators from the old dsi to the new dsi
220  std::vector<VariableInfo>::iterator itrVarInfo;
221 
222  for (itrVarInfo = oldDSI.GetTargetInfos().begin(); itrVarInfo != oldDSI.GetTargetInfos().end(); ++itrVarInfo)
223  dsi->AddTarget(*itrVarInfo);
224 
225  for (itrVarInfo = oldDSI.GetSpectatorInfos().begin(); itrVarInfo != oldDSI.GetSpectatorInfos().end(); ++itrVarInfo)
226  dsi->AddSpectator(*itrVarInfo);
227 
228  // split string that contains the variables into tiny little pieces
229  std::vector<TString> variables = gTools().SplitString(theVariables,':' );
230 
231  // prepare to create varMap
232  std::vector<UInt_t> varMap;
233  Int_t counter=0;
234 
235  // add the variables that were specified in theVariables
236  std::vector<TString>::iterator itrVariables;
237  Bool_t found = kFALSE;
238 
239  // iterate over all variables in 'variables' and add them
240  for (itrVariables = variables.begin(); itrVariables != variables.end(); ++itrVariables) {
241  counter=0;
242 
243  // check the variables of the old dsi for the variable that we want to add
244  for (itrVarInfo = oldDSI.GetVariableInfos().begin(); itrVarInfo != oldDSI.GetVariableInfos().end(); ++itrVarInfo) {
245  if((*itrVariables==itrVarInfo->GetLabel()) ) { // || (*itrVariables==itrVarInfo->GetExpression())) {
246  // don't compare the expression, since the user might take two times the same expression, but with different labels
247  // and apply different transformations to the variables.
248  dsi->AddVariable(*itrVarInfo);
249  varMap.push_back(counter);
250  found = kTRUE;
251  }
252  counter++;
253  }
254 
255  // check the spectators of the old dsi for the variable that we want to add
256  for (itrVarInfo = oldDSI.GetSpectatorInfos().begin(); itrVarInfo != oldDSI.GetSpectatorInfos().end(); ++itrVarInfo) {
257  if((*itrVariables==itrVarInfo->GetLabel()) ) { // || (*itrVariables==itrVarInfo->GetExpression())) {
258  // don't compare the expression, since the user might take two times the same expression, but with different labels
259  // and apply different transformations to the variables.
260  dsi->AddVariable(*itrVarInfo);
261  varMap.push_back(counter);
262  found = kTRUE;
263  }
264  counter++;
265  }
266 
267  // if the variable is neither in the variables nor in the spectators, we abort
268  if (!found) {
269  Log() << kFATAL <<"The variable " << itrVariables->Data() << " was not found and could not be added " << Endl;
270  }
271  found = kFALSE;
272  }
273 
274  // in the case that no variables are specified, add the default-variables from the original dsi
275  if (theVariables=="") {
276  for (UInt_t i=0; i<oldDSI.GetVariableInfos().size(); i++) {
277  dsi->AddVariable(oldDSI.GetVariableInfos()[i]);
278  varMap.push_back(i);
279  }
280  }
281 
282  // add the variable map 'varMap' to the vector of varMaps
283  fVarMaps.push_back(varMap);
284 
285  // set classes and cuts
286  UInt_t nClasses=oldDSI.GetNClasses();
287  TString className;
288 
289  for (UInt_t i=0; i<nClasses; i++) {
290  className = oldDSI.GetClassInfo(i)->GetName();
291  dsi->AddClass(className);
292  dsi->SetCut(oldDSI.GetCut(i),className);
293  dsi->AddCut(theCut,className);
294  dsi->SetWeightExpression(oldDSI.GetWeightExpression(i),className);
295  }
296 
297  // set split options, root dir and normalization for the new dsi
298  dsi->SetSplitOptions(oldDSI.GetSplitOptions());
299  dsi->SetRootDir(oldDSI.GetRootDir());
300  TString norm(oldDSI.GetNormalization().Data());
301  dsi->SetNormalization(norm);
302 
303  DataSetInfo& dsiReference= (*dsi);
304 
305  return dsiReference;
306 }
307 
308 ////////////////////////////////////////////////////////////////////////////////
309 /// initialize the method
310 
312 {
   // intentionally empty: nothing to initialize for the category method itself
313 }
314 
315 ////////////////////////////////////////////////////////////////////////////////
316 /// initialize the circular tree
317 
319 {
320  delete fCatTree;
321 
322  std::vector<VariableInfo>::const_iterator viIt;
323  const std::vector<VariableInfo>& vars = dsi.GetVariableInfos();
324  const std::vector<VariableInfo>& specs = dsi.GetSpectatorInfos();
325 
326  Bool_t hasAllExternalLinks = kTRUE;
327  for (viIt = vars.begin(); viIt != vars.end(); ++viIt)
328  if( viIt->GetExternalLink() == 0 ) {
329  hasAllExternalLinks = kFALSE;
330  break;
331  }
332  for (viIt = specs.begin(); viIt != specs.end(); ++viIt)
333  if( viIt->GetExternalLink() == 0 ) {
334  hasAllExternalLinks = kFALSE;
335  break;
336  }
337 
338  if(!hasAllExternalLinks) return;
339 
340  {
341  // Rather than having TTree::TTree add to the current directory and then remove it, let's
342  // make sure not to add it in the first place.
343  // The add-then-remove can lead to a problem if gDirectory points to the same directory (for example
344  // gROOT) in the current thread and another one (and both try to add to the directory at the same time).
345  TDirectory::TContext ctxt(nullptr);
346  fCatTree = new TTree(Form("Circ%s",GetMethodName().Data()),"Circular Tree for categorization");
347  fCatTree->SetCircular(1);
348  }
349 
350  for (viIt = vars.begin(); viIt != vars.end(); ++viIt) {
351  const VariableInfo& vi = *viIt;
353  }
354  for (viIt = specs.begin(); viIt != specs.end(); ++viIt) {
355  const VariableInfo& vi = *viIt;
356  if(vi.GetVarType()=='C') continue;
358  }
359 
360  for(UInt_t cat=0; cat!=fCategoryCuts.size(); ++cat) {
361  fCatFormulas.push_back(new TTreeFormula(Form("Category_%i",cat), fCategoryCuts[cat].GetTitle(), fCatTree));
362  }
363 }
364 
365 ////////////////////////////////////////////////////////////////////////////////
366 /// train all sub-classifiers
367 
369 {
370  // specify the minimum # of training events and set 'classification'
371  const Int_t MinNoTrainingEvents = 10;
372 
373  Types::EAnalysisType analysisType = GetAnalysisType();
374 
375  // start the training
376  Log() << kINFO << "Train all sub-classifiers for "
377  << (analysisType == Types::kRegression ? "Regression" : "Classification") << " ..." << Endl;
378 
379  // don't do anything if no sub-classifier booked
380  if (fMethods.empty()) {
381  Log() << kINFO << "...nothing found to train" << Endl;
382  return;
383  }
384 
385  std::vector<IMethod*>::iterator itrMethod;
386 
387  // iterate over all booked sub-classifiers and train them
388  for (itrMethod = fMethods.begin(); itrMethod != fMethods.end(); ++itrMethod ) {
389 
390  MethodBase* mva = dynamic_cast<MethodBase*>(*itrMethod);
391  if(!mva) continue;
392  mva->SetAnalysisType( analysisType );
393  if (!mva->HasAnalysisType( analysisType,
394  mva->DataInfo().GetNClasses(),
395  mva->DataInfo().GetNTargets() ) ) {
396  Log() << kWARNING << "Method " << mva->GetMethodTypeName() << " is not capable of handling " ;
397  if (analysisType == Types::kRegression)
398  Log() << "regression with " << mva->DataInfo().GetNTargets() << " targets." << Endl;
399  else
400  Log() << "classification with " << mva->DataInfo().GetNClasses() << " classes." << Endl;
401  itrMethod = fMethods.erase( itrMethod );
402  continue;
403  }
404  if (mva->Data()->GetNTrainingEvents() >= MinNoTrainingEvents) {
405 
406  Log() << kINFO << "Train method: " << mva->GetMethodName() << " for "
407  << (analysisType == Types::kRegression ? "Regression" : "Classification") << Endl;
408  mva->TrainMethod();
409  Log() << kINFO << "Training finished" << Endl;
410 
411  } else {
412 
413  Log() << kWARNING << "Method " << mva->GetMethodName()
414  << " not trained (training tree has less entries ["
415  << mva->Data()->GetNTrainingEvents()
416  << "] than required [" << MinNoTrainingEvents << "]" << Endl;
417 
418  Log() << kERROR << " w/o training/test events for that category, I better stop here and let you fix " << Endl;
419  Log() << kFATAL << "that one first, otherwise things get too messy later ... " << Endl;
420 
421  }
422  }
423 
424  if (analysisType != Types::kRegression) {
425 
426  // variable ranking
427  Log() << kINFO << "Begin ranking of input variables..." << Endl;
428  for (itrMethod = fMethods.begin(); itrMethod != fMethods.end(); ++itrMethod) {
429  MethodBase* mva = dynamic_cast<MethodBase*>(*itrMethod);
430  if (mva && mva->Data()->GetNTrainingEvents() >= MinNoTrainingEvents) {
431  const Ranking* ranking = (*itrMethod)->CreateRanking();
432  if (ranking != 0)
433  ranking->Print();
434  else
435  Log() << kINFO << "No variable ranking supplied by classifier: "
436  << dynamic_cast<MethodBase*>(*itrMethod)->GetMethodName() << Endl;
437  }
438  }
439  }
440 }
441 
442 ////////////////////////////////////////////////////////////////////////////////
443 /// create XML description of Category classifier
444 
445 void TMVA::MethodCategory::AddWeightsXMLTo( void* parent ) const
446 {
447  void* wght = gTools().AddChild(parent, "Weights");
448  gTools().AddAttr( wght, "NSubMethods", fMethods.size() );
449  void* submethod(0);
450 
451  // iterate over methods and write them to XML file
452  for (UInt_t i=0; i<fMethods.size(); i++) {
453  MethodBase* method = dynamic_cast<MethodBase*>(fMethods[i]);
454  submethod = gTools().AddChild(wght, "SubMethod");
455  gTools().AddAttr(submethod, "Index", i);
456  gTools().AddAttr(submethod, "Method", method->GetMethodTypeName() + "::" + method->GetMethodName());
457  gTools().AddAttr(submethod, "Cut", fCategoryCuts[i]);
458  gTools().AddAttr(submethod, "Variables", fVars[i]);
459  method->WriteStateToXML( submethod );
460  }
461 }
462 
463 ////////////////////////////////////////////////////////////////////////////////
464 /// read weights of sub-classifiers of MethodCategory from xml weight file
465 
467 {
468  UInt_t nSubMethods;
469  TString fullMethodName;
470  TString methodType;
471  TString methodTitle;
472  TString theCutString;
473  TString theVariables;
474  Int_t titleLength;
475  gTools().ReadAttr( wghtnode, "NSubMethods", nSubMethods );
476  void* subMethodNode = gTools().GetChild(wghtnode);
477 
478  Log() << kINFO << "Recreating sub-classifiers from XML-file " << Endl;
479 
480  // recreate all sub-methods from weight file
481  for (UInt_t i=0; i<nSubMethods; i++) {
482  gTools().ReadAttr( subMethodNode, "Method", fullMethodName );
483  gTools().ReadAttr( subMethodNode, "Cut", theCutString );
484  gTools().ReadAttr( subMethodNode, "Variables", theVariables );
485 
486  // determine sub-method type
487  methodType = fullMethodName(0,fullMethodName.Index("::"));
488  if (methodType.Contains(" ")) methodType = methodType(methodType.Last(' ')+1,methodType.Length());
489 
490  // determine sub-method title
491  titleLength = fullMethodName.Length()-fullMethodName.Index("::")-2;
492  methodTitle = fullMethodName(fullMethodName.Index("::")+2,titleLength);
493 
494  // reconstruct dsi for sub-method
495  DataSetInfo& dsi = CreateCategoryDSI(TCut(theCutString), theVariables, methodTitle);
496 
497  // recreate sub-method from weights and add to fMethods
498  MethodBase* method = dynamic_cast<MethodBase*>( ClassifierFactory::Instance().Create( methodType.Data(),
499  dsi, "none" ) );
500  if(method==0)
501  Log() << kFATAL << "Could not create sub-method " << method << " from XML." << Endl;
502 
503  method->SetupMethod();
504  method->ReadStateFromXML(subMethodNode);
505 
506  fMethods.push_back(method);
507  fCategoryCuts.push_back(TCut(theCutString));
508  fVars.push_back(theVariables);
509 
510  DataSetInfo& primaryDSI = DataInfo();
511 
512  UInt_t spectatorIdx = 10000;
513  UInt_t counter=0;
514 
515  // find the spectator index
516  std::vector<VariableInfo>& spectators=primaryDSI.GetSpectatorInfos();
517  std::vector<VariableInfo>::iterator itrVarInfo;
518  TString specName= Form("%s_cat%i", GetName(),(int)fCategorySpecIdx.size()+1);
519 
520  for (itrVarInfo = spectators.begin(); itrVarInfo != spectators.end(); ++itrVarInfo, ++counter) {
521  if((specName==itrVarInfo->GetLabel()) || (specName==itrVarInfo->GetExpression())) {
522  spectatorIdx=counter;
523  fCategorySpecIdx.push_back(spectatorIdx);
524  break;
525  }
526  }
527 
528  subMethodNode = gTools().GetNextChild(subMethodNode);
529  }
530 
532 
533 }
534 
535 ////////////////////////////////////////////////////////////////////////////////
536 /// process user options
537 
539 {
   // intentionally empty: there are no MethodCategory-specific options to process
540 }
541 
542 ////////////////////////////////////////////////////////////////////////////////
543 /// Get help message text
544 ///
545 /// typical length of text line:
546 /// "|--------------------------------------------------------------|"
547 
549 {
550  Log() << Endl;
551  Log() << gTools().Color("bold") << "--- Short description:" << gTools().Color("reset") << Endl;
552  Log() << Endl;
553  Log() << "This method allows to define different categories of events. The" <<Endl;
554  Log() << "categories are defined via cuts on the variables. For each" << Endl;
555  Log() << "category, a different classifier and set of variables can be" <<Endl;
556  Log() << "specified. The categories which are defined for this method must" << Endl;
557  Log() << "be disjoint." << Endl;
558 }
559 
560 ////////////////////////////////////////////////////////////////////////////////
561 /// no ranking
562 
564 {
565  return 0;
566 }
567 
568 ////////////////////////////////////////////////////////////////////////////////
569 
571 {
572  // if it's not a simple 'spectator' variable (0 or 1) that the categories are defined by
573  // (but rather some 'formula' (i.e. eta>0), then this formulas are stored in fCatTree and that
574  // one will be evaluated.. (the formulae return 'true' or 'false'
575  if (fCatTree) {
576  if (methodIdx>=fCatFormulas.size()) {
577  Log() << kFATAL << "Large method index " << methodIdx << ", number of category formulas = "
578  << fCatFormulas.size() << Endl;
579  }
580  TTreeFormula* f = fCatFormulas[methodIdx];
581  return f->EvalInstance(0) > 0.5;
582  }
583  // otherwise, it simply looks if "variable == true" ("greater 0.5 to be "sure" )
584  else {
585 
586  // checks whether an event lies within a cut
587  if (methodIdx>=fCategorySpecIdx.size()) {
588  Log() << kFATAL << "Unknown method index " << methodIdx << " maximum allowed index="
589  << fCategorySpecIdx.size() << Endl;
590  }
591  UInt_t spectatorIdx = fCategorySpecIdx[methodIdx];
592  Float_t specVal = ev->GetSpectator(spectatorIdx);
593  Bool_t pass = (specVal>0.5);
594  return pass;
595  }
596 }
597 
598 ////////////////////////////////////////////////////////////////////////////////
599 /// returns the mva value of the right sub-classifier
600 
602 {
603  if (fMethods.empty()) return 0;
604 
605  UInt_t methodToUse = 0;
606  const Event* ev = GetEvent();
607 
608  // determine which sub-classifier to use for this event
609  Int_t suitableCutsN = 0;
610 
611  for (UInt_t i=0; i<fMethods.size(); ++i) {
612  if (PassesCut(ev, i)) {
613  ++suitableCutsN;
614  methodToUse=i;
615  }
616  }
617 
618  if (suitableCutsN == 0) {
619  Log() << kWARNING << "Event does not lie within the cut of any sub-classifier." << Endl;
620  return 0;
621  }
622 
623  if (suitableCutsN > 1) {
624  Log() << kFATAL << "The defined categories are not disjoint." << Endl;
625  return 0;
626  }
627 
628  // get mva value from the suitable sub-classifier
629  ev->SetVariableArrangement(&fVarMaps[methodToUse]);
630  Double_t mvaValue = dynamic_cast<MethodBase*>(fMethods[methodToUse])->GetMvaValue(ev,err,errUpper);
631  ev->SetVariableArrangement(0);
632 
633  return mvaValue;
634 }
635 
636 
637 
638 ////////////////////////////////////////////////////////////////////////////////
639 /// returns the mva value of the right sub-classifier
640 
641 const std::vector<Float_t> &TMVA::MethodCategory::GetRegressionValues()
642 {
643  if (fMethods.empty()) return MethodBase::GetRegressionValues();
644 
645  UInt_t methodToUse = 0;
646  const Event* ev = GetEvent();
647 
648  // determine which sub-classifier to use for this event
649  Int_t suitableCutsN = 0;
650 
651  for (UInt_t i=0; i<fMethods.size(); ++i) {
652  if (PassesCut(ev, i)) {
653  ++suitableCutsN;
654  methodToUse=i;
655  }
656  }
657 
658  if (suitableCutsN == 0) {
659  Log() << kWARNING << "Event does not lie within the cut of any sub-classifier." << Endl;
661  }
662 
663  if (suitableCutsN > 1) {
664  Log() << kFATAL << "The defined categories are not disjoint." << Endl;
666  }
667  MethodBase* meth = dynamic_cast<MethodBase*>(fMethods[methodToUse]);
668  if (!meth){
669  Log() << kFATAL << "method not found in Category Regression method" << Endl;
671  }
672  // get mva value from the suitable sub-classifier
673  return meth->GetRegressionValues(ev);
674 }
675 
IMethod * Create(const std::string &name, const TString &job, const TString &title, DataSetInfo &dsi, const TString &option)
creates the method if needed based on the method name using the creator function the factory has stor...
static ClassifierFactory & Instance()
access to the ClassifierFactory singleton creates the instance if needed
Types::EAnalysisType fAnalysisType
Definition: MethodBase.h:584
std::vector< IMethod * > fMethods
void SetModelPersistence(Bool_t status)
Definition: MethodBase.h:373
virtual const char * GetName() const
Returns name of object.
Definition: TNamed.h:47
void Init()
initialize the method
void WriteStateToXML(void *parent) const
general method used in writing the header of the weight files where the used variables, variable transformation type etc.
TMVA::IMethod * AddMethod(const TCut &, const TString &theVariables, Types::EMVA theMethod, const TString &theTitle, const TString &theOptions)
adds sub-classifier for a category
MsgLogger & Endl(MsgLogger &ml)
Definition: MsgLogger.h:158
Singleton class for Global types used by TMVA.
Definition: Types.h:73
VariableInfo & AddTarget(const TString &expression, const TString &title, const TString &unit, Double_t min, Double_t max, Bool_t normalized=kTRUE, void *external=0)
add a variable (can be a complex expression) to the set of variables used in the MV analysis ...
void ReadStateFromXML(void *parent)
std::vector< VariableInfo > & GetSpectatorInfos()
Definition: DataSetInfo.h:104
float Float_t
Definition: RtypesCore.h:53
void SetCut(const TCut &cut, const TString &className)
set the cut for the classes
void InitCircularTree(const DataSetInfo &dsi)
initialize the circular tree
static Types & Instance()
returns the single instance of "Types" if it already exists, or creates it (Singleton)
Definition: Types.cxx:70
void SetMethodBaseDir(TDirectory *methodDir)
Definition: MethodBase.h:365
MsgLogger & Log() const
Definition: Configurable.h:122
EAnalysisType
Definition: Types.h:127
std::vector< TCut > fCategoryCuts
Virtual base Class for all MVA method.
Definition: MethodBase.h:109
const std::vector< Float_t > & GetRegressionValues(const TMVA::Event *const ev)
Definition: MethodBase.h:211
std::vector< UInt_t > fCategorySpecIdx
Ssiz_t Index(const char *pat, Ssiz_t i=0, ECaseCompare cmp=kExact) const
Definition: TString.h:634
virtual const std::vector< Float_t > & GetRegressionValues()
Definition: MethodBase.h:218
Basic string class.
Definition: TString.h:131
Ranking for variables in method (implementation)
Definition: Ranking.h:48
#define f(i)
Definition: RSha256.hxx:104
int Int_t
Definition: RtypesCore.h:41
bool Bool_t
Definition: RtypesCore.h:59
UInt_t GetNClasses() const
Definition: DataSetInfo.h:136
virtual ~MethodCategory(void)
destructor
void AddCut(const TCut &cut, const TString &className)
set the cut for the classes
void SetSilentFile(Bool_t status)
Definition: MethodBase.h:369
virtual Bool_t HasAnalysisType(Types::EAnalysisType type, UInt_t numberClasses, UInt_t numberTargets)=0
void AddAttr(void *node, const char *, const T &value, Int_t precision=16)
add attribute to xml
Definition: Tools.h:353
void * AddChild(void *parent, const char *childname, const char *content=0, bool isRootNode=false)
add child node
Definition: Tools.cxx:1136
const TString & GetExpression() const
Definition: VariableInfo.h:57
VariableInfo & AddSpectator(const TString &expression, const TString &title, const TString &unit, Double_t min, Double_t max, char type='F', Bool_t normalized=kTRUE, void *external=0)
add a spectator (can be a complex expression) to the set of spectator variables used in the MV analys...
const TString & GetNormalization() const
Definition: DataSetInfo.h:114
char GetVarType() const
Definition: VariableInfo.h:61
std::vector< std::vector< UInt_t > > fVarMaps
void DeclareOptions()
options for this method
const Event * GetEvent() const
Definition: MethodBase.h:740
DataSet * Data() const
Definition: MethodBase.h:400
Virtual base class for combining several TMVA method.
void * GetChild(void *parent, const char *childname=0)
get child node
Definition: Tools.cxx:1162
void SetVariableArrangement(std::vector< UInt_t > *const m) const
set the variable arrangement
Definition: Event.cxx:192
virtual void ParseOptions()
options parser
void SetupMethod()
setup of methods
Definition: MethodBase.cxx:411
DataSetInfo & DataInfo() const
Definition: MethodBase.h:401
Class that contains all the data information.
Definition: DataSetInfo.h:60
Long64_t GetNTrainingEvents() const
Definition: DataSet.h:79
std::vector< VariableInfo > & GetTargetInfos()
Definition: DataSetInfo.h:99
Bool_t fModelPersistence
Definition: MethodBase.h:622
TDirectory * GetRootDir() const
Definition: DataSetInfo.h:171
Used to pass a selection expression to the Tree drawing routine.
Definition: TTreeFormula.h:58
A specialized string object used for TTree selections.
Definition: TCut.h:25
std::vector< TTreeFormula * > fCatFormulas
needed in conjunction with TTreeFormulas for evaluating category expressions
MethodCategory(const TString &jobName, const TString &methodTitle, DataSetInfo &theData, const TString &theOption="")
standard constructor
const Int_t MinNoTrainingEvents
Definition: Factory.cxx:99
void SetSplitOptions(const TString &so)
Definition: DataSetInfo.h:166
const Ranking * CreateRanking()
no ranking
void * GetExternalLink() const
Definition: VariableInfo.h:81
UInt_t GetNTargets() const
Definition: DataSetInfo.h:111
TMVA::DataSetInfo & CreateCategoryDSI(const TCut &, const TString &, const TString &)
create a DataSetInfo object for a sub-classifier
const char * GetName() const
Definition: MethodBase.h:325
ClassInfo * GetClassInfo(Int_t clNum) const
void SetWeightExpression(const TString &exp, const TString &className="")
set the weight expressions for the classes if class name is specified, set only for this class if cla...
DataSetInfo & AddDataSetInfo(DataSetInfo &dsi)
stores a copy of the dataset info object
unsigned int UInt_t
Definition: RtypesCore.h:42
char * Form(const char *fmt,...)
DataSetManager * fDataSetManager
Ssiz_t Length() const
Definition: TString.h:405
const TString & GetJobName() const
Definition: MethodBase.h:321
const TString & GetMethodName() const
Definition: MethodBase.h:322
void Train(void)
train all sub-classifiers
void ReadAttr(void *node, const char *, T &value)
read attribute from xml
Definition: Tools.h:335
Tools & gTools()
Bool_t IsSilentFile()
Definition: MethodBase.h:370
const Bool_t kFALSE
Definition: RtypesCore.h:88
Class for categorizing the phase space.
virtual void Print() const
get maximum length of variable names
Definition: Ranking.cxx:111
virtual void CheckSetup()
check may be overridden by derived class (sometimes, eg, fitters are used which can only be implement...
Definition: MethodBase.cxx:438
#define ClassImp(name)
Definition: Rtypes.h:359
double Double_t
Definition: RtypesCore.h:55
Ssiz_t Last(char c) const
Find last occurrence of a character c.
Definition: TString.cxx:876
Describe directory structure in memory.
Definition: TDirectory.h:34
int type
Definition: TGX11.cxx:120
void SetFile(TFile *file)
Definition: MethodBase.h:366
void * GetNextChild(void *prevchild, const char *childname=0)
XML helpers.
Definition: Tools.cxx:1174
Bool_t Contains(const char *pat, ECaseCompare cmp=kExact) const
Definition: TString.h:619
virtual Bool_t HasAnalysisType(Types::EAnalysisType type, UInt_t numberClasses, UInt_t)
check whether method category has analysis type the method type has to be the same for all sub-method...
T EvalInstance(Int_t i=0, const char *stringStack[]=0)
Evaluate this treeformula.
ClassInfo * AddClass(const TString &className)
void ProcessSetup()
process all options the "CheckForUnusedOptions" is done in an independent call, since it may be overr...
Definition: MethodBase.cxx:428
std::vector< TString > fVars
void GetHelpMessage() const
Get help message text.
const TString & Color(const TString &)
human readable color strings
Definition: Tools.cxx:840
Double_t GetMvaValue(Double_t *err=0, Double_t *errUpper=0)
returns the mva value of the right sub-classifier
Interface for all concrete MVA method implementations.
Definition: IMethod.h:54
virtual Int_t Branch(TCollection *list, Int_t bufsize=32000, Int_t splitlevel=99, const char *name="")
Create one branch for each element in the collection.
Definition: TTree.cxx:1711
#define REGISTER_METHOD(CLASS)
for example
Abstract ClassifierFactory template that handles arbitrary types.
const TString & GetSplitOptions() const
Definition: DataSetInfo.h:167
TDirectory * BaseDir() const
returns the ROOT directory where info/histograms etc of the corresponding MVA method instance are sto...
TString GetMethodTypeName() const
Definition: MethodBase.h:323
const TCut & GetCut(Int_t i) const
Definition: DataSetInfo.h:149
void SetWeightFileDir(TString fileDir)
set directory of weight file
void AddWeightsXMLTo(void *parent) const
create XML description of Category classifier
virtual TDirectory * GetDirectory(const char *namecycle, Bool_t printError=false, const char *funcname="GetDirectory")
Find a directory using apath.
Definition: TDirectory.cxx:400
std::vector< TString > SplitString(const TString &theOpt, const char separator) const
splits the option string at 'separator' and fills the list 'splitV' with the primitive strings ...
Definition: Tools.cxx:1211
virtual void SetCircular(Long64_t maxEntries)
Enable/Disable circularity for this tree.
Definition: TTree.cxx:8476
VariableInfo & AddVariable(const TString &expression, const TString &title="", const TString &unit="", Double_t min=0, Double_t max=0, char varType='F', Bool_t normalized=kTRUE, void *external=0)
add a variable (can be a complex expression) to the set of variables used in the MV analysis ...
void DisableWriting(Bool_t setter)
Definition: MethodBase.h:433
Types::EAnalysisType GetAnalysisType() const
Definition: MethodBase.h:428
A TTree object has a header with a name and a title.
Definition: TTree.h:70
void ReadWeightsFromXML(void *wghtnode)
read weights of sub-classifiers of MethodCategory from xml weight file
Bool_t PassesCut(const Event *ev, UInt_t methodIdx)
Class for type info of MVA input variable.
Definition: VariableInfo.h:47
TString fFileDir
Definition: MethodBase.h:626
Float_t GetSpectator(UInt_t ivar) const
return spectator content
Definition: Event.cxx:262
const TString GetWeightExpression(Int_t i) const
Definition: DataSetInfo.h:145
const Bool_t kTRUE
Definition: RtypesCore.h:87
void SetNormalization(const TString &norm)
Definition: DataSetInfo.h:115
void SetRootDir(TDirectory *d)
Definition: DataSetInfo.h:170
virtual const std::vector< Float_t > & GetRegressionValues()
returns the mva value of the right sub-classifier
std::vector< VariableInfo > & GetVariableInfos()
Definition: DataSetInfo.h:94
void variables(TString dataset, TString fin="TMVA.root", TString dirName="InputVariables_Id", TString title="TMVA Input Variables", Bool_t isRegression=kFALSE, Bool_t useTMVAStyle=kTRUE)
virtual void SetAnalysisType(Types::EAnalysisType type)
Definition: MethodBase.h:427
virtual const char * GetTitle() const
Returns title of object.
Definition: TNamed.h:48
void ProcessOptions()
process user options
const char * Data() const
Definition: TString.h:364