Logo ROOT  
Reference Guide
MethodCategory.cxx
Go to the documentation of this file.
1 // @(#)root/tmva $Id$
2 // Author: Andreas Hoecker, Joerg Stelzer, Helge Voss, Eckhard von Toerne
3 
4 /**********************************************************************************
5  * Project: TMVA - a Root-integrated toolkit for multivariate data analysis *
6  * Package: TMVA *
7  * Class : MethodCategory *
8  * Web : http://tmva.sourceforge.net *
9  * *
10  * Description: *
11  * Class for categorizing the phase space *
12  * *
13  * Authors (alphabetical): *
14  * Andreas Hoecker <Andreas.Hocker@cern.ch> - CERN, Switzerland *
15  * Nadim Sah <Nadim.Sah@cern.ch> - Berlin, Germany *
16  * Peter Speckmayer <Peter.Speckmazer@cern.ch> - CERN, Switzerland *
17  * Joerg Stelzer <Joerg.Stelzer@cern.ch> - MSU East Lansing, USA *
18  * Helge Voss <Helge.Voss@cern.ch> - MPI-K Heidelberg, Germany *
19  * Jan Therhaag <Jan.Therhaag@cern.ch> - U of Bonn, Germany *
20  * Eckhard v. Toerne <evt@uni-bonn.de> - U of Bonn, Germany *
21  * *
22  * Copyright (c) 2005-2011: *
23  * CERN, Switzerland *
24  * MSU East Lansing, USA *
25  * MPI-K Heidelberg, Germany *
26  * U. of Bonn, Germany *
27  * *
28  * Redistribution and use in source and binary forms, with or without *
29  * modification, are permitted according to the terms listed in LICENSE *
30  * (http://tmva.sourceforge.net/LICENSE) *
31  **********************************************************************************/
32 
33 /*! \class TMVA::MethodCategory
34 \ingroup TMVA
35 
36 Class for categorizing the phase space
37 
38 This class is meant to allow categorisation of the data. For different
39 categories, different classifiers may be booked and different variables
40 may be considered. The aim is to account for the difference that
41 is due to different locations/angles.
42 */
43 
44 
45 #include "TMVA/MethodCategory.h"
46 
47 #include <algorithm>
48 #include <vector>
49 
50 #include "TRandom3.h"
51 #include "TH1F.h"
52 #include "TSpline.h"
53 #include "TDirectory.h"
54 #include "TTreeFormula.h"
55 
56 #include "TMVA/ClassifierFactory.h"
57 #include "TMVA/Config.h"
58 #include "TMVA/DataSet.h"
59 #include "TMVA/DataSetInfo.h"
60 #include "TMVA/DataSetManager.h"
61 #include "TMVA/IMethod.h"
62 #include "TMVA/MethodBase.h"
64 #include "TMVA/MsgLogger.h"
65 #include "TMVA/PDF.h"
66 #include "TMVA/Ranking.h"
67 #include "TMVA/Timer.h"
68 #include "TMVA/Tools.h"
69 #include "TMVA/Types.h"
70 #include "TMVA/VariableInfo.h"
72 
73 REGISTER_METHOD(Category)
74 
76 
77 ////////////////////////////////////////////////////////////////////////////////
78 /// standard constructor
79 
81  const TString& methodTitle,
82  DataSetInfo& theData,
83  const TString& theOption )
84  : TMVA::MethodCompositeBase( jobName, Types::kCategory, methodTitle, theData, theOption),
85  fCatTree(0),
86  fDataSetManager(NULL)
87 {
88 }
89 
90 ////////////////////////////////////////////////////////////////////////////////
91 /// constructor from weight file
92 
94  const TString& theWeightFile)
95  : TMVA::MethodCompositeBase( Types::kCategory, dsi, theWeightFile),
96  fCatTree(0),
97  fDataSetManager(NULL)
98 {
99 }
100 
101 ////////////////////////////////////////////////////////////////////////////////
102 /// destructor
103 
105 {
106  std::vector<TTreeFormula*>::iterator formIt = fCatFormulas.begin();
107  std::vector<TTreeFormula*>::iterator lastF = fCatFormulas.end();
108  for(;formIt!=lastF; ++formIt) delete *formIt;
109  delete fCatTree;
110 }
111 
112 ////////////////////////////////////////////////////////////////////////////////
113 /// check whether method category has analysis type
114 /// the method type has to be the same for all sub-methods
115 
117 {
118  std::vector<IMethod*>::iterator itrMethod = fMethods.begin();
119 
120  // iterate over methods and check whether they have the analysis type
121  for(; itrMethod != fMethods.end(); ++itrMethod ) {
122  if ( !(*itrMethod)->HasAnalysisType(type, numberClasses, numberTargets) )
123  return kFALSE;
124  }
125  return kTRUE;
126 }
127 
128 ////////////////////////////////////////////////////////////////////////////////
129 /// options for this method
130 
132 {
133 }
134 
135 ////////////////////////////////////////////////////////////////////////////////
136 /// adds sub-classifier for a category
137 
139  const TString& theVariables,
140  Types::EMVA theMethod ,
141  const TString& theTitle,
142  const TString& theOptions )
143 {
144  std::string addedMethodName(Types::Instance().GetMethodName(theMethod).Data());
145 
146  Log() << kINFO << "Adding sub-classifier: " << addedMethodName << "::" << theTitle << Endl;
147 
148  DataSetInfo& dsi = CreateCategoryDSI(theCut, theVariables, theTitle);
149 
150  IMethod* addedMethod = ClassifierFactory::Instance().Create(addedMethodName,GetJobName(),theTitle,dsi,theOptions);
151 
152  MethodBase *method = (dynamic_cast<MethodBase*>(addedMethod));
153  if(method==0) return 0;
154 
155  if(fModelPersistence) method->SetWeightFileDir(fFileDir);
156  method->SetModelPersistence(fModelPersistence);
157  method->SetAnalysisType( fAnalysisType );
158  method->SetupMethod();
159  method->ParseOptions();
160  method->ProcessSetup();
161  method->SetFile(fFile);
162  method->SetSilentFile(IsSilentFile());
163 
164 
165  // set or create correct method base dir for added method
166  const TString dirName(Form("Method_%s",method->GetMethodTypeName().Data()));
167  TDirectory * dir = BaseDir()->GetDirectory(dirName);
168  if (dir != 0) method->SetMethodBaseDir( dir );
169  else method->SetMethodBaseDir( BaseDir()->mkdir(dirName,Form("Directory for all %s methods", method->GetMethodTypeName().Data())) );
170 
171  // method->SetBaseDir(eigenes base dir, gucken ob Fisher dir existiert, sonst erzeugen )
172 
173  // check-for-unused-options is performed; may be overridden by derived
174  // classes
175  method->CheckSetup();
176 
177  // disable writing of XML files and standalone classes for sub methods
178  method->DisableWriting( kTRUE );
179 
180  // store method, cut and variable names and create cut formula
181  fMethods.push_back(method);
182  fCategoryCuts.push_back(theCut);
183  fVars.push_back(theVariables);
184 
185  DataSetInfo& primaryDSI = DataInfo();
186 
187  UInt_t newSpectatorIndex = primaryDSI.GetSpectatorInfos().size();
188  fCategorySpecIdx.push_back(newSpectatorIndex);
189 
190  primaryDSI.AddSpectator( Form("%s_cat%i:=%s", GetName(),(int)fMethods.size(),theCut.GetTitle()),
191  Form("%s:%s",GetName(),method->GetName()),
192  "pass", 0, 0, 'C' );
193 
194  return method;
195 }
196 
197 ////////////////////////////////////////////////////////////////////////////////
198 /// create a DataSetInfo object for a sub-classifier
199 
201  const TString& theVariables,
202  const TString& theTitle)
203 {
204  // create a new dsi with name: theTitle+"_dsi"
205  TString dsiName=theTitle+"_dsi";
206  DataSetInfo& oldDSI = DataInfo();
207  DataSetInfo* dsi = new DataSetInfo(dsiName);
208 
209  // register the new dsi
210  // DataSetManager::Instance().AddDataSetInfo(*dsi); // DSMTEST replaced by following line
211  fDataSetManager->AddDataSetInfo(*dsi);
212 
213  // copy the targets and spectators from the old dsi to the new dsi
214  std::vector<VariableInfo>::iterator itrVarInfo;
215 
216  for (itrVarInfo = oldDSI.GetTargetInfos().begin(); itrVarInfo != oldDSI.GetTargetInfos().end(); ++itrVarInfo)
217  dsi->AddTarget(*itrVarInfo);
218 
219  for (itrVarInfo = oldDSI.GetSpectatorInfos().begin(); itrVarInfo != oldDSI.GetSpectatorInfos().end(); ++itrVarInfo)
220  dsi->AddSpectator(*itrVarInfo);
221 
222  // split string that contains the variables into tiny little pieces
223  std::vector<TString> variables = gTools().SplitString(theVariables,':' );
224 
225  // prepare to create varMap
226  std::vector<UInt_t> varMap;
227  Int_t counter=0;
228 
229  // add the variables that were specified in theVariables
230  std::vector<TString>::iterator itrVariables;
231  Bool_t found = kFALSE;
232 
233  // iterate over all variables in 'variables' and add them
234  for (itrVariables = variables.begin(); itrVariables != variables.end(); ++itrVariables) {
235  counter=0;
236 
237  // check the variables of the old dsi for the variable that we want to add
238  for (itrVarInfo = oldDSI.GetVariableInfos().begin(); itrVarInfo != oldDSI.GetVariableInfos().end(); ++itrVarInfo) {
239  if((*itrVariables==itrVarInfo->GetLabel()) ) { // || (*itrVariables==itrVarInfo->GetExpression())) {
240  // don't compare the expression, since the user might take two times the same expression, but with different labels
241  // and apply different transformations to the variables.
242  dsi->AddVariable(*itrVarInfo);
243  varMap.push_back(counter);
244  found = kTRUE;
245  }
246  counter++;
247  }
248 
249  // check the spectators of the old dsi for the variable that we want to add
250  for (itrVarInfo = oldDSI.GetSpectatorInfos().begin(); itrVarInfo != oldDSI.GetSpectatorInfos().end(); ++itrVarInfo) {
251  if((*itrVariables==itrVarInfo->GetLabel()) ) { // || (*itrVariables==itrVarInfo->GetExpression())) {
252  // don't compare the expression, since the user might take two times the same expression, but with different labels
253  // and apply different transformations to the variables.
254  dsi->AddVariable(*itrVarInfo);
255  varMap.push_back(counter);
256  found = kTRUE;
257  }
258  counter++;
259  }
260 
261  // if the variable is neither in the variables nor in the spectators, we abort
262  if (!found) {
263  Log() << kFATAL <<"The variable " << itrVariables->Data() << " was not found and could not be added " << Endl;
264  }
265  found = kFALSE;
266  }
267 
268  // in the case that no variables are specified, add the default-variables from the original dsi
269  if (theVariables=="") {
270  for (UInt_t i=0; i<oldDSI.GetVariableInfos().size(); i++) {
271  dsi->AddVariable(oldDSI.GetVariableInfos()[i]);
272  varMap.push_back(i);
273  }
274  }
275 
276  // add the variable map 'varMap' to the vector of varMaps
277  fVarMaps.push_back(varMap);
278 
279  // set classes and cuts
280  UInt_t nClasses=oldDSI.GetNClasses();
281  TString className;
282 
283  for (UInt_t i=0; i<nClasses; i++) {
284  className = oldDSI.GetClassInfo(i)->GetName();
285  dsi->AddClass(className);
286  dsi->SetCut(oldDSI.GetCut(i),className);
287  dsi->AddCut(theCut,className);
288  dsi->SetWeightExpression(oldDSI.GetWeightExpression(i),className);
289  }
290 
291  // set split options, root dir and normalization for the new dsi
292  dsi->SetSplitOptions(oldDSI.GetSplitOptions());
293  dsi->SetRootDir(oldDSI.GetRootDir());
294  TString norm(oldDSI.GetNormalization().Data());
295  dsi->SetNormalization(norm);
296 
297  DataSetInfo& dsiReference= (*dsi);
298 
299  return dsiReference;
300 }
301 
302 ////////////////////////////////////////////////////////////////////////////////
303 /// initialize the method
304 
306 {
307 }
308 
309 ////////////////////////////////////////////////////////////////////////////////
310 /// initialize the circular tree
311 
313 {
314  delete fCatTree;
315 
316  std::vector<VariableInfo>::const_iterator viIt;
317  const std::vector<VariableInfo>& vars = dsi.GetVariableInfos();
318  const std::vector<VariableInfo>& specs = dsi.GetSpectatorInfos();
319 
320  Bool_t hasAllExternalLinks = kTRUE;
321  for (viIt = vars.begin(); viIt != vars.end(); ++viIt)
322  if( viIt->GetExternalLink() == 0 ) {
323  hasAllExternalLinks = kFALSE;
324  break;
325  }
326  for (viIt = specs.begin(); viIt != specs.end(); ++viIt)
327  if( viIt->GetExternalLink() == 0 ) {
328  hasAllExternalLinks = kFALSE;
329  break;
330  }
331 
332  if(!hasAllExternalLinks) return;
333 
334  {
335  // Rather than having TTree::TTree add to the current directory and then remove it, let
336  // make sure to not add it in the first place.
337  // The add-then-remove can lead to a problem if gDirectory points to the same directory (for example
338  // gROOT) in the current thread and another one (and both try to add to the directory at the same time).
339  TDirectory::TContext ctxt(nullptr);
340  fCatTree = new TTree(Form("Circ%s",GetMethodName().Data()),"Circular Tree for categorization");
341  fCatTree->SetCircular(1);
342  }
343 
344  for (viIt = vars.begin(); viIt != vars.end(); ++viIt) {
345  const VariableInfo& vi = *viIt;
346  fCatTree->Branch(vi.GetExpression(),(Float_t*)vi.GetExternalLink(), TString(vi.GetExpression())+TString("/F"));
347  }
348  for (viIt = specs.begin(); viIt != specs.end(); ++viIt) {
349  const VariableInfo& vi = *viIt;
350  if(vi.GetVarType()=='C') continue;
351  fCatTree->Branch(vi.GetExpression(),(Float_t*)vi.GetExternalLink(), TString(vi.GetExpression())+TString("/F"));
352  }
353 
354  for(UInt_t cat=0; cat!=fCategoryCuts.size(); ++cat) {
355  fCatFormulas.push_back(new TTreeFormula(Form("Category_%i",cat), fCategoryCuts[cat].GetTitle(), fCatTree));
356  }
357 }
358 
359 ////////////////////////////////////////////////////////////////////////////////
360 /// train all sub-classifiers
361 
363 {
364  // specify the minimum # of training events and set 'classification'
365  const Int_t MinNoTrainingEvents = 10;
366 
367  Types::EAnalysisType analysisType = GetAnalysisType();
368 
369  // start the training
370  Log() << kINFO << "Train all sub-classifiers for "
371  << (analysisType == Types::kRegression ? "Regression" : "Classification") << " ..." << Endl;
372 
373  // don't do anything if no sub-classifier booked
374  if (fMethods.empty()) {
375  Log() << kINFO << "...nothing found to train" << Endl;
376  return;
377  }
378 
379  std::vector<IMethod*>::iterator itrMethod;
380 
381  // iterate over all booked sub-classifiers and train them
382  for (itrMethod = fMethods.begin(); itrMethod != fMethods.end(); ++itrMethod ) {
383 
384  MethodBase* mva = dynamic_cast<MethodBase*>(*itrMethod);
385  if(!mva) continue;
386  mva->SetAnalysisType( analysisType );
387  if (!mva->HasAnalysisType( analysisType,
388  mva->DataInfo().GetNClasses(),
389  mva->DataInfo().GetNTargets() ) ) {
390  Log() << kWARNING << "Method " << mva->GetMethodTypeName() << " is not capable of handling " ;
391  if (analysisType == Types::kRegression)
392  Log() << "regression with " << mva->DataInfo().GetNTargets() << " targets." << Endl;
393  else
394  Log() << "classification with " << mva->DataInfo().GetNClasses() << " classes." << Endl;
395  itrMethod = fMethods.erase( itrMethod );
396  continue;
397  }
398  if (mva->Data()->GetNTrainingEvents() >= MinNoTrainingEvents) {
399 
400  Log() << kINFO << "Train method: " << mva->GetMethodName() << " for "
401  << (analysisType == Types::kRegression ? "Regression" : "Classification") << Endl;
402  mva->TrainMethod();
403  Log() << kINFO << "Training finished" << Endl;
404 
405  } else {
406 
407  Log() << kWARNING << "Method " << mva->GetMethodName()
408  << " not trained (training tree has less entries ["
409  << mva->Data()->GetNTrainingEvents()
410  << "] than required [" << MinNoTrainingEvents << "]" << Endl;
411 
412  Log() << kERROR << " w/o training/test events for that category, I better stop here and let you fix " << Endl;
413  Log() << kFATAL << "that one first, otherwise things get too messy later ... " << Endl;
414 
415  }
416  }
417 
418  if (analysisType != Types::kRegression) {
419 
420  // variable ranking
421  Log() << kINFO << "Begin ranking of input variables..." << Endl;
422  for (itrMethod = fMethods.begin(); itrMethod != fMethods.end(); ++itrMethod) {
423  MethodBase* mva = dynamic_cast<MethodBase*>(*itrMethod);
424  if (mva && mva->Data()->GetNTrainingEvents() >= MinNoTrainingEvents) {
425  const Ranking* ranking = (*itrMethod)->CreateRanking();
426  if (ranking != 0)
427  ranking->Print();
428  else
429  Log() << kINFO << "No variable ranking supplied by classifier: "
430  << dynamic_cast<MethodBase*>(*itrMethod)->GetMethodName() << Endl;
431  }
432  }
433  }
434 }
435 
436 ////////////////////////////////////////////////////////////////////////////////
437 /// create XML description of Category classifier
438 
439 void TMVA::MethodCategory::AddWeightsXMLTo( void* parent ) const
440 {
441  void* wght = gTools().AddChild(parent, "Weights");
442  gTools().AddAttr( wght, "NSubMethods", fMethods.size() );
443  void* submethod(0);
444 
445  // iterate over methods and write them to XML file
446  for (UInt_t i=0; i<fMethods.size(); i++) {
447  MethodBase* method = dynamic_cast<MethodBase*>(fMethods[i]);
448  submethod = gTools().AddChild(wght, "SubMethod");
449  gTools().AddAttr(submethod, "Index", i);
450  gTools().AddAttr(submethod, "Method", method->GetMethodTypeName() + "::" + method->GetMethodName());
451  gTools().AddAttr(submethod, "Cut", fCategoryCuts[i]);
452  gTools().AddAttr(submethod, "Variables", fVars[i]);
453  method->WriteStateToXML( submethod );
454  }
455 }
456 
457 ////////////////////////////////////////////////////////////////////////////////
458 /// read weights of sub-classifiers of MethodCategory from xml weight file
459 
461 {
462  UInt_t nSubMethods;
463  TString fullMethodName;
464  TString methodType;
465  TString methodTitle;
466  TString theCutString;
467  TString theVariables;
468  Int_t titleLength;
469  gTools().ReadAttr( wghtnode, "NSubMethods", nSubMethods );
470  void* subMethodNode = gTools().GetChild(wghtnode);
471 
472  Log() << kINFO << "Recreating sub-classifiers from XML-file " << Endl;
473 
474  // recreate all sub-methods from weight file
475  for (UInt_t i=0; i<nSubMethods; i++) {
476  gTools().ReadAttr( subMethodNode, "Method", fullMethodName );
477  gTools().ReadAttr( subMethodNode, "Cut", theCutString );
478  gTools().ReadAttr( subMethodNode, "Variables", theVariables );
479 
480  // determine sub-method type
481  methodType = fullMethodName(0,fullMethodName.Index("::"));
482  if (methodType.Contains(" ")) methodType = methodType(methodType.Last(' ')+1,methodType.Length());
483 
484  // determine sub-method title
485  titleLength = fullMethodName.Length()-fullMethodName.Index("::")-2;
486  methodTitle = fullMethodName(fullMethodName.Index("::")+2,titleLength);
487 
488  // reconstruct dsi for sub-method
489  DataSetInfo& dsi = CreateCategoryDSI(TCut(theCutString), theVariables, methodTitle);
490 
491  // recreate sub-method from weights and add to fMethods
492  MethodBase* method = dynamic_cast<MethodBase*>( ClassifierFactory::Instance().Create( methodType.Data(),
493  dsi, "none" ) );
494  if(method==0)
495  Log() << kFATAL << "Could not create sub-method " << method << " from XML." << Endl;
496 
497  method->SetupMethod();
498  method->ReadStateFromXML(subMethodNode);
499 
500  fMethods.push_back(method);
501  fCategoryCuts.push_back(TCut(theCutString));
502  fVars.push_back(theVariables);
503 
504  DataSetInfo& primaryDSI = DataInfo();
505 
506  UInt_t spectatorIdx = 10000;
507  UInt_t counter=0;
508 
509  // find the spectator index
510  std::vector<VariableInfo>& spectators=primaryDSI.GetSpectatorInfos();
511  std::vector<VariableInfo>::iterator itrVarInfo;
512  TString specName= Form("%s_cat%i", GetName(),(int)fCategorySpecIdx.size()+1);
513 
514  for (itrVarInfo = spectators.begin(); itrVarInfo != spectators.end(); ++itrVarInfo, ++counter) {
515  if((specName==itrVarInfo->GetLabel()) || (specName==itrVarInfo->GetExpression())) {
516  spectatorIdx=counter;
517  fCategorySpecIdx.push_back(spectatorIdx);
518  break;
519  }
520  }
521 
522  subMethodNode = gTools().GetNextChild(subMethodNode);
523  }
524 
525  InitCircularTree(DataInfo());
526 
527 }
528 
529 ////////////////////////////////////////////////////////////////////////////////
530 /// process user options
531 
533 {
534 }
535 
536 ////////////////////////////////////////////////////////////////////////////////
537 /// Get help message text
538 ///
539 /// typical length of text line:
540 /// "|--------------------------------------------------------------|"
541 
543 {
544  Log() << Endl;
545  Log() << gTools().Color("bold") << "--- Short description:" << gTools().Color("reset") << Endl;
546  Log() << Endl;
547  Log() << "This method allows to define different categories of events. The" <<Endl;
548  Log() << "categories are defined via cuts on the variables. For each" << Endl;
549  Log() << "category, a different classifier and set of variables can be" <<Endl;
550  Log() << "specified. The categories which are defined for this method must" << Endl;
551  Log() << "be disjoint." << Endl;
552 }
553 
554 ////////////////////////////////////////////////////////////////////////////////
555 /// no ranking
556 
558 {
559  return 0;
560 }
561 
562 ////////////////////////////////////////////////////////////////////////////////
563 
565 {
566  // if it's not a simple 'spectator' variable (0 or 1) that the categories are defined by
567  // (but rather some 'formula' (i.e. eta>0), then this formulas are stored in fCatTree and that
568  // one will be evaluated.. (the formulae return 'true' or 'false'
569  if (fCatTree) {
570  if (methodIdx>=fCatFormulas.size()) {
571  Log() << kFATAL << "Large method index " << methodIdx << ", number of category formulas = "
572  << fCatFormulas.size() << Endl;
573  }
574  TTreeFormula* f = fCatFormulas[methodIdx];
575  return f->EvalInstance(0) > 0.5;
576  }
577  // otherwise, it simply looks if "variable == true" ("greater 0.5 to be "sure" )
578  else {
579 
580  // checks whether an event lies within a cut
581  if (methodIdx>=fCategorySpecIdx.size()) {
582  Log() << kFATAL << "Unknown method index " << methodIdx << " maximum allowed index="
583  << fCategorySpecIdx.size() << Endl;
584  }
585  UInt_t spectatorIdx = fCategorySpecIdx[methodIdx];
586  Float_t specVal = ev->GetSpectator(spectatorIdx);
587  Bool_t pass = (specVal>0.5);
588  return pass;
589  }
590 }
591 
592 ////////////////////////////////////////////////////////////////////////////////
593 /// returns the mva value of the right sub-classifier
594 
596 {
597  if (fMethods.empty()) return 0;
598 
599  UInt_t methodToUse = 0;
600  const Event* ev = GetEvent();
601 
602  // determine which sub-classifier to use for this event
603  Int_t suitableCutsN = 0;
604 
605  for (UInt_t i=0; i<fMethods.size(); ++i) {
606  if (PassesCut(ev, i)) {
607  ++suitableCutsN;
608  methodToUse=i;
609  }
610  }
611 
612  if (suitableCutsN == 0) {
613  Log() << kWARNING << "Event does not lie within the cut of any sub-classifier." << Endl;
614  return 0;
615  }
616 
617  if (suitableCutsN > 1) {
618  Log() << kFATAL << "The defined categories are not disjoint." << Endl;
619  return 0;
620  }
621 
622  // get mva value from the suitable sub-classifier
623  ev->SetVariableArrangement(&fVarMaps[methodToUse]);
624  Double_t mvaValue = dynamic_cast<MethodBase*>(fMethods[methodToUse])->GetMvaValue(ev,err,errUpper);
625  ev->SetVariableArrangement(0);
626 
627  return mvaValue;
628 }
629 
630 
631 
632 ////////////////////////////////////////////////////////////////////////////////
633 /// returns the mva value of the right sub-classifier
634 
635 const std::vector<Float_t> &TMVA::MethodCategory::GetRegressionValues()
636 {
637  if (fMethods.empty()) return MethodBase::GetRegressionValues();
638 
639  UInt_t methodToUse = 0;
640  const Event* ev = GetEvent();
641 
642  // determine which sub-classifier to use for this event
643  Int_t suitableCutsN = 0;
644 
645  for (UInt_t i=0; i<fMethods.size(); ++i) {
646  if (PassesCut(ev, i)) {
647  ++suitableCutsN;
648  methodToUse=i;
649  }
650  }
651 
652  if (suitableCutsN == 0) {
653  Log() << kWARNING << "Event does not lie within the cut of any sub-classifier." << Endl;
655  }
656 
657  if (suitableCutsN > 1) {
658  Log() << kFATAL << "The defined categories are not disjoint." << Endl;
660  }
661  MethodBase* meth = dynamic_cast<MethodBase*>(fMethods[methodToUse]);
662  if (!meth){
663  Log() << kFATAL << "method not found in Category Regression method" << Endl;
665  }
666  // get mva value from the suitable sub-classifier
667  return meth->GetRegressionValues(ev);
668 }
669 
TMVA::MethodCategory::HasAnalysisType
virtual Bool_t HasAnalysisType(Types::EAnalysisType type, UInt_t numberClasses, UInt_t)
check whether method category has analysis type the method type has to be the same for all sub-method...
Definition: MethodCategory.cxx:116
TMVA::MethodCategory::AddMethod
TMVA::IMethod * AddMethod(const TCut &, const TString &theVariables, Types::EMVA theMethod, const TString &theTitle, const TString &theOptions)
adds sub-classifier for a category
Definition: MethodCategory.cxx:138
TCut
A specialized string object used for TTree selections.
Definition: TCut.h:25
kTRUE
const Bool_t kTRUE
Definition: RtypesCore.h:91
TH1F.h
DataSetManager.h
TMVA::Tools::GetChild
void * GetChild(void *parent, const char *childname=0)
get child node
Definition: Tools.cxx:1162
TMVA::MethodBase::Data
DataSet * Data() const
Definition: MethodBase.h:408
TMVA::MethodBase::SetModelPersistence
void SetModelPersistence(Bool_t status)
Definition: MethodBase.h:381
TMVA::Ranking::Print
virtual void Print() const
get maximum length of variable names
Definition: Ranking.cxx:111
f
#define f(i)
Definition: RSha256.hxx:104
TDirectory.h
TMVA::DataSetInfo::AddSpectator
VariableInfo & AddSpectator(const TString &expression, const TString &title, const TString &unit, Double_t min, Double_t max, char type='F', Bool_t normalized=kTRUE, void *external=0)
add a spectator (can be a complex expression) to the set of spectator variables used in the MV analys...
Definition: DataSetInfo.cxx:289
TMVA::DataSetInfo::SetCut
void SetCut(const TCut &cut, const TString &className)
set the cut for the classes
Definition: DataSetInfo.cxx:360
TMVA::Tools::SplitString
std::vector< TString > SplitString(const TString &theOpt, const char separator) const
splits the option string at 'separator' and fills the list 'splitV' with the primitive strings
Definition: Tools.cxx:1211
TMVA::Types::kRegression
@ kRegression
Definition: Types.h:130
TMVA::MethodBase::SetupMethod
void SetupMethod()
setup of methods
Definition: MethodBase.cxx:406
TMVA::DataSetInfo::SetNormalization
void SetNormalization(const TString &norm)
Definition: DataSetInfo.h:132
TString::Data
const char * Data() const
Definition: TString.h:369
DataSetInfo.h
ClassImp
#define ClassImp(name)
Definition: Rtypes.h:364
Form
char * Form(const char *fmt,...)
TNamed::GetTitle
virtual const char * GetTitle() const
Returns title of object.
Definition: TNamed.h:48
TMVA::Ranking
Ranking for variables in method (implementation)
Definition: Ranking.h:48
TMVA::MethodBase::ReadStateFromXML
void ReadStateFromXML(void *parent)
Definition: MethodBase.cxx:1469
TMVA::MethodCategory::GetRegressionValues
virtual const std::vector< Float_t > & GetRegressionValues()
returns the mva value of the right sub-classifier
Definition: MethodCategory.cxx:635
IMethod.h
TMVA::MethodBase::SetSilentFile
void SetSilentFile(Bool_t status)
Definition: MethodBase.h:377
MinNoTrainingEvents
const Int_t MinNoTrainingEvents
Definition: Factory.cxx:95
TMath::Log
Double_t Log(Double_t x)
Definition: TMath.h:760
TTree
A TTree represents a columnar dataset.
Definition: TTree.h:79
Ranking.h
TMVA::Tools::AddChild
void * AddChild(void *parent, const char *childname, const char *content=0, bool isRootNode=false)
add child node
Definition: Tools.cxx:1136
TMVA::MethodCategory::ProcessOptions
void ProcessOptions()
process user options
Definition: MethodCategory.cxx:532
TTreeFormula.h
TMVA::MethodBase::TrainMethod
void TrainMethod()
Definition: MethodBase.cxx:650
Float_t
float Float_t
Definition: RtypesCore.h:57
VariableInfo.h
TString::Contains
Bool_t Contains(const char *pat, ECaseCompare cmp=kExact) const
Definition: TString.h:624
TString::Length
Ssiz_t Length() const
Definition: TString.h:410
TMVA::DataSetInfo::GetSpectatorInfos
std::vector< VariableInfo > & GetSpectatorInfos()
Definition: DataSetInfo.h:122
TMVA::MethodBase::GetRegressionValues
const std::vector< Float_t > & GetRegressionValues(const TMVA::Event *const ev)
Definition: MethodBase.h:213
TDirectory::TContext
Small helper to keep current directory context.
Definition: TDirectory.h:47
MethodBase.h
TMVA::DataSetInfo::SetSplitOptions
void SetSplitOptions(const TString &so)
Definition: DataSetInfo.h:185
TString
Basic string class.
Definition: TString.h:136
TMVA::VariableInfo::GetVarType
char GetVarType() const
Definition: VariableInfo.h:61
TMVA::MethodCategory
Class for categorizing the phase space.
Definition: MethodCategory.h:58
TMVA::MethodCategory::PassesCut
Bool_t PassesCut(const Event *ev, UInt_t methodIdx)
Definition: MethodCategory.cxx:564
REGISTER_METHOD
#define REGISTER_METHOD(CLASS)
for example
Definition: ClassifierFactory.h:124
TMVA::MethodCompositeBase
Virtual base class for combining several TMVA method.
Definition: MethodCompositeBase.h:50
bool
TMVA::MethodCategory::DeclareOptions
void DeclareOptions()
options for this method
Definition: MethodCategory.cxx:131
TString::Last
Ssiz_t Last(char c) const
Find last occurrence of a character c.
Definition: TString.cxx:892
PDF.h
TMVA::MethodCategory::AddWeightsXMLTo
void AddWeightsXMLTo(void *parent) const
create XML description of Category classifier
Definition: MethodCategory.cxx:439
TMVA::MethodBase::DataInfo
DataSetInfo & DataInfo() const
Definition: MethodBase.h:409
TMVA::DataSetInfo::SetWeightExpression
void SetWeightExpression(const TString &exp, const TString &className="")
set the weight expressions for the classes if class name is specified, set only for this class if cla...
Definition: DataSetInfo.cxx:333
TMVA::ClassifierFactory::Instance
static ClassifierFactory & Instance()
access to the ClassifierFactory singleton creates the instance if needed
Definition: ClassifierFactory.cxx:48
TMVA::DataSetInfo::GetNClasses
UInt_t GetNClasses() const
Definition: DataSetInfo.h:155
TMVA::MethodCategory::~MethodCategory
virtual ~MethodCategory(void)
destructor
Definition: MethodCategory.cxx:104
TMVA::Tools::AddAttr
void AddAttr(void *node, const char *, const T &value, Int_t precision=16)
add attribute to xml
Definition: Tools.h:353
TSpline.h
TMVA::DataSetInfo
Class that contains all the data information.
Definition: DataSetInfo.h:62
TMVA::DataSetInfo::GetCut
const TCut & GetCut(Int_t i) const
Definition: DataSetInfo.h:168
TMVA::MethodCategory::Train
void Train(void)
train all sub-classifiers
Definition: MethodCategory.cxx:362
TMVA::DataSetInfo::GetTargetInfos
std::vector< VariableInfo > & GetTargetInfos()
Definition: DataSetInfo.h:114
TMVA::DataSetInfo::AddCut
void AddCut(const TCut &cut, const TString &className)
set the cut for the classes
Definition: DataSetInfo.cxx:376
VariableRearrangeTransform.h
TMVA::MethodBase::GetMethodName
const TString & GetMethodName() const
Definition: MethodBase.h:330
MsgLogger.h
Timer.h
TMVA::MethodBase::MethodCategory
friend class MethodCategory
Definition: MethodBase.h:268
TMVA::DataSetInfo::GetNormalization
const TString & GetNormalization() const
Definition: DataSetInfo.h:131
TMVA::Types::EAnalysisType
EAnalysisType
Definition: Types.h:128
TMVA::MethodCategory::CreateRanking
const Ranking * CreateRanking()
no ranking
Definition: MethodCategory.cxx:557
TMVA::variables
void variables(TString dataset, TString fin="TMVA.root", TString dirName="InputVariables_Id", TString title="TMVA Input Variables", Bool_t isRegression=kFALSE, Bool_t useTMVAStyle=kTRUE)
TMVA::MethodBase::CheckSetup
virtual void CheckSetup()
check may be overridden by derived class (sometimes, eg, fitters are used which can only be implement...
Definition: MethodBase.cxx:433
TMVA::DataSetInfo::GetWeightExpression
const TString GetWeightExpression(Int_t i) const
Definition: DataSetInfo.h:164
TMVA::DataSetInfo::GetClassInfo
ClassInfo * GetClassInfo(Int_t clNum) const
Definition: DataSetInfo.cxx:146
kFALSE
const Bool_t kFALSE
Definition: RtypesCore.h:92
TMVA::VariableInfo
Class for type info of MVA input variable.
Definition: VariableInfo.h:47
TMVA::MethodBase::SetMethodBaseDir
void SetMethodBaseDir(TDirectory *methodDir)
Definition: MethodBase.h:373
TMVA::MethodCategory::InitCircularTree
void InitCircularTree(const DataSetInfo &dsi)
initialize the circular tree
Definition: MethodCategory.cxx:312
TMVA::Tools::ReadAttr
void ReadAttr(void *node, const char *, T &value)
read attribute from xml
Definition: Tools.h:335
TMVA::MethodCategory::Init
void Init()
initialize the method
Definition: MethodCategory.cxx:305
TMVA::MethodCompositeBase::GetMvaValue
virtual Double_t GetMvaValue(Double_t *errLower=0, Double_t *errUpper=0)=0
TTreeFormula
Used to pass a selection expression to the Tree drawing routine.
Definition: TTreeFormula.h:58
TRandom3.h
TDirectory::GetDirectory
virtual TDirectory * GetDirectory(const char *namecycle, Bool_t printError=false, const char *funcname="GetDirectory")
Find a directory using apath.
Definition: TDirectory.cxx:401
TMVA::MethodBase::SetWeightFileDir
void SetWeightFileDir(TString fileDir)
set directory of weight file
Definition: MethodBase.cxx:2048
TMVA::MethodBase::GetMethodTypeName
TString GetMethodTypeName() const
Definition: MethodBase.h:331
TMVA::MethodBase
Virtual base class for all MVA methods.
Definition: MethodBase.h:111
TMVA::Types
Singleton class for Global types used by TMVA.
Definition: Types.h:73
TMVA::MethodBase::DisableWriting
void DisableWriting(Bool_t setter)
Definition: MethodBase.h:442
TMVA::DataSetInfo::GetRootDir
TDirectory * GetRootDir() const
Definition: DataSetInfo.h:190
Types.h
TMVA::DataSetInfo::GetSplitOptions
const TString & GetSplitOptions() const
Definition: DataSetInfo.h:186
TMVA::MethodBase::WriteStateToXML
void WriteStateToXML(void *parent) const
general method used in writing the header of the weight files where the used variables,...
Definition: MethodBase.cxx:1320
TMVA::Endl
MsgLogger & Endl(MsgLogger &ml)
Definition: MsgLogger.h:158
Config.h
unsigned int
TMVA::TMVAGlob::GetMethodName
void GetMethodName(TString &name, TKey *mkey)
Definition: tmvaglob.cxx:342
TMVA::IMethod
Interface for all concrete MVA method implementations.
Definition: IMethod.h:53
TMVA::Tools::Color
const TString & Color(const TString &)
human readable color strings
Definition: Tools.cxx:840
TString::Index
Ssiz_t Index(const char *pat, Ssiz_t i=0, ECaseCompare cmp=kExact) const
Definition: TString.h:639
TMVA::DataSetInfo::AddVariable
VariableInfo & AddVariable(const TString &expression, const TString &title="", const TString &unit="", Double_t min=0, Double_t max=0, char varType='F', Bool_t normalized=kTRUE, void *external=0)
add a variable (can be a complex expression) to the set of variables used in the MV analysis
Definition: DataSetInfo.cxx:207
TMVA::DataSetInfo::GetNTargets
UInt_t GetNTargets() const
Definition: DataSetInfo.h:128
TMVA::Event::GetSpectator
Float_t GetSpectator(UInt_t ivar) const
return spectator content
Definition: Event.cxx:261
Double_t
double Double_t
Definition: RtypesCore.h:59
TMVA::DataSetInfo::SetRootDir
void SetRootDir(TDirectory *d)
Definition: DataSetInfo.h:189
TMVA::DataSetInfo::GetVariableInfos
std::vector< VariableInfo > & GetVariableInfos()
Definition: DataSetInfo.h:103
TMVA::MethodBase::SetAnalysisType
virtual void SetAnalysisType(Types::EAnalysisType type)
Definition: MethodBase.h:436
TMVA::MethodBase::GetRegressionValues
virtual const std::vector< Float_t > & GetRegressionValues()
Definition: MethodBase.h:220
TMVA::Types::Instance
static Types & Instance()
return the single instance of "Types" if it already exists, or create it (Singleton)
Definition: Types.cxx:69
TMVA::DataSetInfo::AddClass
ClassInfo * AddClass(const TString &className)
Definition: DataSetInfo.cxx:113
TMVA::IMethod::HasAnalysisType
virtual Bool_t HasAnalysisType(Types::EAnalysisType type, UInt_t numberClasses, UInt_t numberTargets)=0
TMVA::Tools::GetNextChild
void * GetNextChild(void *prevchild, const char *childname=0)
XML helpers.
Definition: Tools.cxx:1174
TMVA::Types::EMVA
EMVA
Definition: Types.h:78
TMVA::MethodBase::GetName
const char * GetName() const
Definition: MethodBase.h:333
TMVA::MethodCategory::CreateCategoryDSI
TMVA::DataSetInfo & CreateCategoryDSI(const TCut &, const TString &, const TString &)
create a DataSetInfo object for a sub-classifier
Definition: MethodCategory.cxx:200
TMVA::Event
Definition: Event.h:51
TMVA::Event::SetVariableArrangement
void SetVariableArrangement(std::vector< UInt_t > *const m) const
set the variable arrangement
Definition: Event.cxx:191
TMVA::MethodCategory::ReadWeightsFromXML
void ReadWeightsFromXML(void *wghtnode)
read weights of sub-classifiers of MethodCategory from xml weight file
Definition: MethodCategory.cxx:460
TDirectory
Describe directory structure in memory.
Definition: TDirectory.h:40
TMVA::VariableInfo::GetExpression
const TString & GetExpression() const
Definition: VariableInfo.h:57
MethodCompositeBase.h
TMVA::MethodBase::SetFile
void SetFile(TFile *file)
Definition: MethodBase.h:374
TMVA::VariableInfo::GetExternalLink
void * GetExternalLink() const
Definition: VariableInfo.h:83
Tools.h
TMVA::DataSet::GetNTrainingEvents
Long64_t GetNTrainingEvents() const
Definition: DataSet.h:68
TNamed::GetName
virtual const char * GetName() const
Returns name of object.
Definition: TNamed.h:47
ClassifierFactory.h
type
int type
Definition: TGX11.cxx:121
TMVA::DataSetInfo::AddTarget
VariableInfo & AddTarget(const TString &expression, const TString &title, const TString &unit, Double_t min, Double_t max, Bool_t normalized=kTRUE, void *external=0)
add a target (can be a complex expression) to the set of targets used in the MV analysis
Definition: DataSetInfo.cxx:259
TMVA::gTools
Tools & gTools()
TMVA::MethodCategory::GetHelpMessage
void GetHelpMessage() const
Get help message text.
Definition: MethodCategory.cxx:542
MethodCategory.h
DataSet.h
TMVA::Configurable::ParseOptions
virtual void ParseOptions()
options parser
Definition: Configurable.cxx:124
TMVA::ClassifierFactory::Create
IMethod * Create(const std::string &name, const TString &job, const TString &title, DataSetInfo &dsi, const TString &option)
creates the method if needed based on the method name using the creator function the factory has stor...
Definition: ClassifierFactory.cxx:89
TMVA::MethodBase::ProcessSetup
void ProcessSetup()
process all options the "CheckForUnusedOptions" is done in an independent call, since it may be overr...
Definition: MethodBase.cxx:423
TMVA
create variable transformations
Definition: GeneticMinimizer.h:22
int