Logo ROOT  
Reference Guide
MethodCategory.cxx
Go to the documentation of this file.
// @(#)root/tmva $Id$
// Author: Andreas Hoecker, Joerg Stelzer, Helge Voss, Eckhard von Toerne

/***********************************************************************************
 * Project: TMVA - a Root-integrated toolkit for multivariate data analysis        *
 * Package: TMVA                                                                   *
 * Class  : MethodCategory                                                         *
 * Web    : http://tmva.sourceforge.net                                            *
 *                                                                                 *
 * Description:                                                                    *
 *      Class for categorizing the phase space. Different sub-classifiers and      *
 *      variable sets can be used for different categories of events.              *
 *                                                                                 *
 * Authors (alphabetical):                                                         *
 *      Andreas Hoecker   <Andreas.Hocker@cern.ch>    - CERN, Switzerland          *
 *      Nadim Sah         <Nadim.Sah@cern.ch>         - Berlin, Germany            *
 *      Peter Speckmayer  <Peter.Speckmazer@cern.ch>  - CERN, Switzerland          *
 *      Joerg Stelzer     <Joerg.Stelzer@cern.ch>     - MSU East Lansing, USA      *
 *      Helge Voss        <Helge.Voss@cern.ch>        - MPI-K Heidelberg, Germany  *
 *      Jan Therhaag      <Jan.Therhaag@cern.ch>      - U of Bonn, Germany         *
 *      Eckhard v. Toerne <evt@uni-bonn.de>           - U of Bonn, Germany         *
 *                                                                                 *
 * Copyright (c) 2005-2011:                                                        *
 *      CERN, Switzerland                                                          *
 *      MSU East Lansing, USA                                                      *
 *      MPI-K Heidelberg, Germany                                                  *
 *      U. of Bonn, Germany                                                        *
 *                                                                                 *
 * Redistribution and use in source and binary forms, with or without              *
 * modification, are permitted according to the terms listed in LICENSE           *
 * (http://tmva.sourceforge.net/LICENSE)                                           *
 ***********************************************************************************/
32
/*! \class TMVA::MethodCategory
\ingroup TMVA

Class for categorizing the phase space

This class is meant to allow categorisation of the data. For different
categories, different classifiers may be booked and different variables
may be considered. The aim is to account for differences between events
that are due to their different locations/angles, i.e. to their
different regions of phase space.
*/
43
44
#include "TMVA/MethodCategory.h"

#include <algorithm>
#include <iomanip>
#include <vector>
#include <iostream>

#include "Riostream.h"
#include "TRandom3.h"
#include "TMath.h"
#include "TH1F.h"
#include "TGraph.h"
#include "TSpline.h"
#include "TDirectory.h"
#include "TTree.h"
#include "TTreeFormula.h"

#include "TMVA/ClassifierFactory.h"
#include "TMVA/Config.h"
#include "TMVA/DataSet.h"
#include "TMVA/DataSetInfo.h"
#include "TMVA/DataSetManager.h"
#include "TMVA/Event.h"
#include "TMVA/IMethod.h"
#include "TMVA/MethodBase.h"
#include "TMVA/MethodCompositeBase.h"
#include "TMVA/MsgLogger.h"
#include "TMVA/PDF.h"
#include "TMVA/Ranking.h"
#include "TMVA/Timer.h"
#include "TMVA/Tools.h"
#include "TMVA/Types.h"
#include "TMVA/VariableInfo.h"

78REGISTER_METHOD(Category)
79
81
82////////////////////////////////////////////////////////////////////////////////
83/// standard constructor
84
86 const TString& methodTitle,
87 DataSetInfo& theData,
88 const TString& theOption )
89 : TMVA::MethodCompositeBase( jobName, Types::kCategory, methodTitle, theData, theOption),
90 fCatTree(0),
91 fDataSetManager(NULL)
92{
93}
94
95////////////////////////////////////////////////////////////////////////////////
96/// constructor from weight file
97
99 const TString& theWeightFile)
100 : TMVA::MethodCompositeBase( Types::kCategory, dsi, theWeightFile),
101 fCatTree(0),
102 fDataSetManager(NULL)
103{
104}
105
106////////////////////////////////////////////////////////////////////////////////
107/// destructor
108
110{
111 std::vector<TTreeFormula*>::iterator formIt = fCatFormulas.begin();
112 std::vector<TTreeFormula*>::iterator lastF = fCatFormulas.end();
113 for(;formIt!=lastF; ++formIt) delete *formIt;
114 delete fCatTree;
115}
116
117////////////////////////////////////////////////////////////////////////////////
118/// check whether method category has analysis type
119/// the method type has to be the same for all sub-methods
120
122{
123 std::vector<IMethod*>::iterator itrMethod = fMethods.begin();
124
125 // iterate over methods and check whether they have the analysis type
126 for(; itrMethod != fMethods.end(); ++itrMethod ) {
127 if ( !(*itrMethod)->HasAnalysisType(type, numberClasses, numberTargets) )
128 return kFALSE;
129 }
130 return kTRUE;
131}
132
133////////////////////////////////////////////////////////////////////////////////
134/// options for this method
135
137{
138}
139
140////////////////////////////////////////////////////////////////////////////////
141/// adds sub-classifier for a category
142
144 const TString& theVariables,
145 Types::EMVA theMethod ,
146 const TString& theTitle,
147 const TString& theOptions )
148{
149 std::string addedMethodName(Types::Instance().GetMethodName(theMethod).Data());
150
151 Log() << kINFO << "Adding sub-classifier: " << addedMethodName << "::" << theTitle << Endl;
152
153 DataSetInfo& dsi = CreateCategoryDSI(theCut, theVariables, theTitle);
154
155 IMethod* addedMethod = ClassifierFactory::Instance().Create(addedMethodName,GetJobName(),theTitle,dsi,theOptions);
156
157 MethodBase *method = (dynamic_cast<MethodBase*>(addedMethod));
158 if(method==0) return 0;
159
160 if(fModelPersistence) method->SetWeightFileDir(fFileDir);
161 method->SetModelPersistence(fModelPersistence);
162 method->SetAnalysisType( fAnalysisType );
163 method->SetupMethod();
164 method->ParseOptions();
165 method->ProcessSetup();
166 method->SetFile(fFile);
167 method->SetSilentFile(IsSilentFile());
168
169
170 // set or create correct method base dir for added method
171 const TString dirName(Form("Method_%s",method->GetMethodTypeName().Data()));
172 TDirectory * dir = BaseDir()->GetDirectory(dirName);
173 if (dir != 0) method->SetMethodBaseDir( dir );
174 else method->SetMethodBaseDir( BaseDir()->mkdir(dirName,Form("Directory for all %s methods", method->GetMethodTypeName().Data())) );
175
176 // method->SetBaseDir(eigenes base dir, gucken ob Fisher dir existiert, sonst erzeugen )
177
178 // check-for-unused-options is performed; may be overridden by derived
179 // classes
180 method->CheckSetup();
181
182 // disable writing of XML files and standalone classes for sub methods
183 method->DisableWriting( kTRUE );
184
185 // store method, cut and variable names and create cut formula
186 fMethods.push_back(method);
187 fCategoryCuts.push_back(theCut);
188 fVars.push_back(theVariables);
189
190 DataSetInfo& primaryDSI = DataInfo();
191
192 UInt_t newSpectatorIndex = primaryDSI.GetSpectatorInfos().size();
193 fCategorySpecIdx.push_back(newSpectatorIndex);
194
195 primaryDSI.AddSpectator( Form("%s_cat%i:=%s", GetName(),(int)fMethods.size(),theCut.GetTitle()),
196 Form("%s:%s",GetName(),method->GetName()),
197 "pass", 0, 0, 'C' );
198
199 return method;
200}
201
202////////////////////////////////////////////////////////////////////////////////
203/// create a DataSetInfo object for a sub-classifier
204
206 const TString& theVariables,
207 const TString& theTitle)
208{
209 // create a new dsi with name: theTitle+"_dsi"
210 TString dsiName=theTitle+"_dsi";
211 DataSetInfo& oldDSI = DataInfo();
212 DataSetInfo* dsi = new DataSetInfo(dsiName);
213
214 // register the new dsi
215 // DataSetManager::Instance().AddDataSetInfo(*dsi); // DSMTEST replaced by following line
216 fDataSetManager->AddDataSetInfo(*dsi);
217
218 // copy the targets and spectators from the old dsi to the new dsi
219 std::vector<VariableInfo>::iterator itrVarInfo;
220
221 for (itrVarInfo = oldDSI.GetTargetInfos().begin(); itrVarInfo != oldDSI.GetTargetInfos().end(); ++itrVarInfo)
222 dsi->AddTarget(*itrVarInfo);
223
224 for (itrVarInfo = oldDSI.GetSpectatorInfos().begin(); itrVarInfo != oldDSI.GetSpectatorInfos().end(); ++itrVarInfo)
225 dsi->AddSpectator(*itrVarInfo);
226
227 // split string that contains the variables into tiny little pieces
228 std::vector<TString> variables = gTools().SplitString(theVariables,':' );
229
230 // prepare to create varMap
231 std::vector<UInt_t> varMap;
232 Int_t counter=0;
233
234 // add the variables that were specified in theVariables
235 std::vector<TString>::iterator itrVariables;
236 Bool_t found = kFALSE;
237
238 // iterate over all variables in 'variables' and add them
239 for (itrVariables = variables.begin(); itrVariables != variables.end(); ++itrVariables) {
240 counter=0;
241
242 // check the variables of the old dsi for the variable that we want to add
243 for (itrVarInfo = oldDSI.GetVariableInfos().begin(); itrVarInfo != oldDSI.GetVariableInfos().end(); ++itrVarInfo) {
244 if((*itrVariables==itrVarInfo->GetLabel()) ) { // || (*itrVariables==itrVarInfo->GetExpression())) {
245 // don't compare the expression, since the user might take two times the same expression, but with different labels
246 // and apply different transformations to the variables.
247 dsi->AddVariable(*itrVarInfo);
248 varMap.push_back(counter);
249 found = kTRUE;
250 }
251 counter++;
252 }
253
254 // check the spectators of the old dsi for the variable that we want to add
255 for (itrVarInfo = oldDSI.GetSpectatorInfos().begin(); itrVarInfo != oldDSI.GetSpectatorInfos().end(); ++itrVarInfo) {
256 if((*itrVariables==itrVarInfo->GetLabel()) ) { // || (*itrVariables==itrVarInfo->GetExpression())) {
257 // don't compare the expression, since the user might take two times the same expression, but with different labels
258 // and apply different transformations to the variables.
259 dsi->AddVariable(*itrVarInfo);
260 varMap.push_back(counter);
261 found = kTRUE;
262 }
263 counter++;
264 }
265
266 // if the variable is neither in the variables nor in the spectators, we abort
267 if (!found) {
268 Log() << kFATAL <<"The variable " << itrVariables->Data() << " was not found and could not be added " << Endl;
269 }
270 found = kFALSE;
271 }
272
273 // in the case that no variables are specified, add the default-variables from the original dsi
274 if (theVariables=="") {
275 for (UInt_t i=0; i<oldDSI.GetVariableInfos().size(); i++) {
276 dsi->AddVariable(oldDSI.GetVariableInfos()[i]);
277 varMap.push_back(i);
278 }
279 }
280
281 // add the variable map 'varMap' to the vector of varMaps
282 fVarMaps.push_back(varMap);
283
284 // set classes and cuts
285 UInt_t nClasses=oldDSI.GetNClasses();
286 TString className;
287
288 for (UInt_t i=0; i<nClasses; i++) {
289 className = oldDSI.GetClassInfo(i)->GetName();
290 dsi->AddClass(className);
291 dsi->SetCut(oldDSI.GetCut(i),className);
292 dsi->AddCut(theCut,className);
293 dsi->SetWeightExpression(oldDSI.GetWeightExpression(i),className);
294 }
295
296 // set split options, root dir and normalization for the new dsi
297 dsi->SetSplitOptions(oldDSI.GetSplitOptions());
298 dsi->SetRootDir(oldDSI.GetRootDir());
299 TString norm(oldDSI.GetNormalization().Data());
300 dsi->SetNormalization(norm);
301
302 DataSetInfo& dsiReference= (*dsi);
303
304 return dsiReference;
305}
306
307////////////////////////////////////////////////////////////////////////////////
308/// initialize the method
309
311{
312}
313
314////////////////////////////////////////////////////////////////////////////////
315/// initialize the circular tree
316
318{
319 delete fCatTree;
320
321 std::vector<VariableInfo>::const_iterator viIt;
322 const std::vector<VariableInfo>& vars = dsi.GetVariableInfos();
323 const std::vector<VariableInfo>& specs = dsi.GetSpectatorInfos();
324
325 Bool_t hasAllExternalLinks = kTRUE;
326 for (viIt = vars.begin(); viIt != vars.end(); ++viIt)
327 if( viIt->GetExternalLink() == 0 ) {
328 hasAllExternalLinks = kFALSE;
329 break;
330 }
331 for (viIt = specs.begin(); viIt != specs.end(); ++viIt)
332 if( viIt->GetExternalLink() == 0 ) {
333 hasAllExternalLinks = kFALSE;
334 break;
335 }
336
337 if(!hasAllExternalLinks) return;
338
339 {
340 // Rather than having TTree::TTree add to the current directory and then remove it, let
341 // make sure to not add it in the first place.
342 // The add-then-remove can lead to a problem if gDirectory points to the same directory (for example
343 // gROOT) in the current thread and another one (and both try to add to the directory at the same time).
344 TDirectory::TContext ctxt(nullptr);
345 fCatTree = new TTree(Form("Circ%s",GetMethodName().Data()),"Circular Tree for categorization");
346 fCatTree->SetCircular(1);
347 }
348
349 for (viIt = vars.begin(); viIt != vars.end(); ++viIt) {
350 const VariableInfo& vi = *viIt;
351 fCatTree->Branch(vi.GetExpression(),(Float_t*)vi.GetExternalLink(), TString(vi.GetExpression())+TString("/F"));
352 }
353 for (viIt = specs.begin(); viIt != specs.end(); ++viIt) {
354 const VariableInfo& vi = *viIt;
355 if(vi.GetVarType()=='C') continue;
356 fCatTree->Branch(vi.GetExpression(),(Float_t*)vi.GetExternalLink(), TString(vi.GetExpression())+TString("/F"));
357 }
358
359 for(UInt_t cat=0; cat!=fCategoryCuts.size(); ++cat) {
360 fCatFormulas.push_back(new TTreeFormula(Form("Category_%i",cat), fCategoryCuts[cat].GetTitle(), fCatTree));
361 }
362}
363
364////////////////////////////////////////////////////////////////////////////////
365/// train all sub-classifiers
366
368{
369 // specify the minimum # of training events and set 'classification'
370 const Int_t MinNoTrainingEvents = 10;
371
372 Types::EAnalysisType analysisType = GetAnalysisType();
373
374 // start the training
375 Log() << kINFO << "Train all sub-classifiers for "
376 << (analysisType == Types::kRegression ? "Regression" : "Classification") << " ..." << Endl;
377
378 // don't do anything if no sub-classifier booked
379 if (fMethods.empty()) {
380 Log() << kINFO << "...nothing found to train" << Endl;
381 return;
382 }
383
384 std::vector<IMethod*>::iterator itrMethod;
385
386 // iterate over all booked sub-classifiers and train them
387 for (itrMethod = fMethods.begin(); itrMethod != fMethods.end(); ++itrMethod ) {
388
389 MethodBase* mva = dynamic_cast<MethodBase*>(*itrMethod);
390 if(!mva) continue;
391 mva->SetAnalysisType( analysisType );
392 if (!mva->HasAnalysisType( analysisType,
393 mva->DataInfo().GetNClasses(),
394 mva->DataInfo().GetNTargets() ) ) {
395 Log() << kWARNING << "Method " << mva->GetMethodTypeName() << " is not capable of handling " ;
396 if (analysisType == Types::kRegression)
397 Log() << "regression with " << mva->DataInfo().GetNTargets() << " targets." << Endl;
398 else
399 Log() << "classification with " << mva->DataInfo().GetNClasses() << " classes." << Endl;
400 itrMethod = fMethods.erase( itrMethod );
401 continue;
402 }
404
405 Log() << kINFO << "Train method: " << mva->GetMethodName() << " for "
406 << (analysisType == Types::kRegression ? "Regression" : "Classification") << Endl;
407 mva->TrainMethod();
408 Log() << kINFO << "Training finished" << Endl;
409
410 } else {
411
412 Log() << kWARNING << "Method " << mva->GetMethodName()
413 << " not trained (training tree has less entries ["
414 << mva->Data()->GetNTrainingEvents()
415 << "] than required [" << MinNoTrainingEvents << "]" << Endl;
416
417 Log() << kERROR << " w/o training/test events for that category, I better stop here and let you fix " << Endl;
418 Log() << kFATAL << "that one first, otherwise things get too messy later ... " << Endl;
419
420 }
421 }
422
423 if (analysisType != Types::kRegression) {
424
425 // variable ranking
426 Log() << kINFO << "Begin ranking of input variables..." << Endl;
427 for (itrMethod = fMethods.begin(); itrMethod != fMethods.end(); ++itrMethod) {
428 MethodBase* mva = dynamic_cast<MethodBase*>(*itrMethod);
429 if (mva && mva->Data()->GetNTrainingEvents() >= MinNoTrainingEvents) {
430 const Ranking* ranking = (*itrMethod)->CreateRanking();
431 if (ranking != 0)
432 ranking->Print();
433 else
434 Log() << kINFO << "No variable ranking supplied by classifier: "
435 << dynamic_cast<MethodBase*>(*itrMethod)->GetMethodName() << Endl;
436 }
437 }
438 }
439}
440
441////////////////////////////////////////////////////////////////////////////////
442/// create XML description of Category classifier
443
445{
446 void* wght = gTools().AddChild(parent, "Weights");
447 gTools().AddAttr( wght, "NSubMethods", fMethods.size() );
448 void* submethod(0);
449
450 // iterate over methods and write them to XML file
451 for (UInt_t i=0; i<fMethods.size(); i++) {
452 MethodBase* method = dynamic_cast<MethodBase*>(fMethods[i]);
453 submethod = gTools().AddChild(wght, "SubMethod");
454 gTools().AddAttr(submethod, "Index", i);
455 gTools().AddAttr(submethod, "Method", method->GetMethodTypeName() + "::" + method->GetMethodName());
456 gTools().AddAttr(submethod, "Cut", fCategoryCuts[i]);
457 gTools().AddAttr(submethod, "Variables", fVars[i]);
458 method->WriteStateToXML( submethod );
459 }
460}
461
462////////////////////////////////////////////////////////////////////////////////
463/// read weights of sub-classifiers of MethodCategory from xml weight file
464
466{
467 UInt_t nSubMethods;
468 TString fullMethodName;
469 TString methodType;
470 TString methodTitle;
471 TString theCutString;
472 TString theVariables;
473 Int_t titleLength;
474 gTools().ReadAttr( wghtnode, "NSubMethods", nSubMethods );
475 void* subMethodNode = gTools().GetChild(wghtnode);
476
477 Log() << kINFO << "Recreating sub-classifiers from XML-file " << Endl;
478
479 // recreate all sub-methods from weight file
480 for (UInt_t i=0; i<nSubMethods; i++) {
481 gTools().ReadAttr( subMethodNode, "Method", fullMethodName );
482 gTools().ReadAttr( subMethodNode, "Cut", theCutString );
483 gTools().ReadAttr( subMethodNode, "Variables", theVariables );
484
485 // determine sub-method type
486 methodType = fullMethodName(0,fullMethodName.Index("::"));
487 if (methodType.Contains(" ")) methodType = methodType(methodType.Last(' ')+1,methodType.Length());
488
489 // determine sub-method title
490 titleLength = fullMethodName.Length()-fullMethodName.Index("::")-2;
491 methodTitle = fullMethodName(fullMethodName.Index("::")+2,titleLength);
492
493 // reconstruct dsi for sub-method
494 DataSetInfo& dsi = CreateCategoryDSI(TCut(theCutString), theVariables, methodTitle);
495
496 // recreate sub-method from weights and add to fMethods
497 MethodBase* method = dynamic_cast<MethodBase*>( ClassifierFactory::Instance().Create( methodType.Data(),
498 dsi, "none" ) );
499 if(method==0)
500 Log() << kFATAL << "Could not create sub-method " << method << " from XML." << Endl;
501
502 method->SetupMethod();
503 method->ReadStateFromXML(subMethodNode);
504
505 fMethods.push_back(method);
506 fCategoryCuts.push_back(TCut(theCutString));
507 fVars.push_back(theVariables);
508
509 DataSetInfo& primaryDSI = DataInfo();
510
511 UInt_t spectatorIdx = 10000;
512 UInt_t counter=0;
513
514 // find the spectator index
515 std::vector<VariableInfo>& spectators=primaryDSI.GetSpectatorInfos();
516 std::vector<VariableInfo>::iterator itrVarInfo;
517 TString specName= Form("%s_cat%i", GetName(),(int)fCategorySpecIdx.size()+1);
518
519 for (itrVarInfo = spectators.begin(); itrVarInfo != spectators.end(); ++itrVarInfo, ++counter) {
520 if((specName==itrVarInfo->GetLabel()) || (specName==itrVarInfo->GetExpression())) {
521 spectatorIdx=counter;
522 fCategorySpecIdx.push_back(spectatorIdx);
523 break;
524 }
525 }
526
527 subMethodNode = gTools().GetNextChild(subMethodNode);
528 }
529
530 InitCircularTree(DataInfo());
531
532}
533
534////////////////////////////////////////////////////////////////////////////////
535/// process user options
536
538{
539}
540
541////////////////////////////////////////////////////////////////////////////////
542/// Get help message text
543///
544/// typical length of text line:
545/// "|--------------------------------------------------------------|"
546
548{
549 Log() << Endl;
550 Log() << gTools().Color("bold") << "--- Short description:" << gTools().Color("reset") << Endl;
551 Log() << Endl;
552 Log() << "This method allows to define different categories of events. The" <<Endl;
553 Log() << "categories are defined via cuts on the variables. For each" << Endl;
554 Log() << "category, a different classifier and set of variables can be" <<Endl;
555 Log() << "specified. The categories which are defined for this method must" << Endl;
556 Log() << "be disjoint." << Endl;
557}
558
559////////////////////////////////////////////////////////////////////////////////
560/// no ranking
561
563{
564 return 0;
565}
566
567////////////////////////////////////////////////////////////////////////////////
568
570{
571 // if it's not a simple 'spectator' variable (0 or 1) that the categories are defined by
572 // (but rather some 'formula' (i.e. eta>0), then this formulas are stored in fCatTree and that
573 // one will be evaluated.. (the formulae return 'true' or 'false'
574 if (fCatTree) {
575 if (methodIdx>=fCatFormulas.size()) {
576 Log() << kFATAL << "Large method index " << methodIdx << ", number of category formulas = "
577 << fCatFormulas.size() << Endl;
578 }
579 TTreeFormula* f = fCatFormulas[methodIdx];
580 return f->EvalInstance(0) > 0.5;
581 }
582 // otherwise, it simply looks if "variable == true" ("greater 0.5 to be "sure" )
583 else {
584
585 // checks whether an event lies within a cut
586 if (methodIdx>=fCategorySpecIdx.size()) {
587 Log() << kFATAL << "Unknown method index " << methodIdx << " maximum allowed index="
588 << fCategorySpecIdx.size() << Endl;
589 }
590 UInt_t spectatorIdx = fCategorySpecIdx[methodIdx];
591 Float_t specVal = ev->GetSpectator(spectatorIdx);
592 Bool_t pass = (specVal>0.5);
593 return pass;
594 }
595}
596
597////////////////////////////////////////////////////////////////////////////////
598/// returns the mva value of the right sub-classifier
599
601{
602 if (fMethods.empty()) return 0;
603
604 UInt_t methodToUse = 0;
605 const Event* ev = GetEvent();
606
607 // determine which sub-classifier to use for this event
608 Int_t suitableCutsN = 0;
609
610 for (UInt_t i=0; i<fMethods.size(); ++i) {
611 if (PassesCut(ev, i)) {
612 ++suitableCutsN;
613 methodToUse=i;
614 }
615 }
616
617 if (suitableCutsN == 0) {
618 Log() << kWARNING << "Event does not lie within the cut of any sub-classifier." << Endl;
619 return 0;
620 }
621
622 if (suitableCutsN > 1) {
623 Log() << kFATAL << "The defined categories are not disjoint." << Endl;
624 return 0;
625 }
626
627 // get mva value from the suitable sub-classifier
628 ev->SetVariableArrangement(&fVarMaps[methodToUse]);
629 Double_t mvaValue = dynamic_cast<MethodBase*>(fMethods[methodToUse])->GetMvaValue(ev,err,errUpper);
631
632 return mvaValue;
633}
634
635
636
637////////////////////////////////////////////////////////////////////////////////
638/// returns the mva value of the right sub-classifier
639
640const std::vector<Float_t> &TMVA::MethodCategory::GetRegressionValues()
641{
642 if (fMethods.empty()) return MethodBase::GetRegressionValues();
643
644 UInt_t methodToUse = 0;
645 const Event* ev = GetEvent();
646
647 // determine which sub-classifier to use for this event
648 Int_t suitableCutsN = 0;
649
650 for (UInt_t i=0; i<fMethods.size(); ++i) {
651 if (PassesCut(ev, i)) {
652 ++suitableCutsN;
653 methodToUse=i;
654 }
655 }
656
657 if (suitableCutsN == 0) {
658 Log() << kWARNING << "Event does not lie within the cut of any sub-classifier." << Endl;
660 }
661
662 if (suitableCutsN > 1) {
663 Log() << kFATAL << "The defined categories are not disjoint." << Endl;
665 }
666 MethodBase* meth = dynamic_cast<MethodBase*>(fMethods[methodToUse]);
667 if (!meth){
668 Log() << kFATAL << "method not found in Category Regression method" << Endl;
670 }
671 // get mva value from the suitable sub-classifier
672 return meth->GetRegressionValues(ev);
673}
674
#define REGISTER_METHOD(CLASS)
for example
#define f(i)
Definition: RSha256.hxx:104
int Int_t
Definition: RtypesCore.h:43
unsigned int UInt_t
Definition: RtypesCore.h:44
const Bool_t kFALSE
Definition: RtypesCore.h:90
bool Bool_t
Definition: RtypesCore.h:61
double Double_t
Definition: RtypesCore.h:57
float Float_t
Definition: RtypesCore.h:55
const Bool_t kTRUE
Definition: RtypesCore.h:89
#define ClassImp(name)
Definition: Rtypes.h:361
int type
Definition: TGX11.cxx:120
char * Form(const char *fmt,...)
A specialized string object used for TTree selections.
Definition: TCut.h:25
Small helper to keep current directory context.
Definition: TDirectory.h:47
Describe directory structure in memory.
Definition: TDirectory.h:40
virtual TDirectory * GetDirectory(const char *namecycle, Bool_t printError=false, const char *funcname="GetDirectory")
Find a directory using apath.
Definition: TDirectory.cxx:401
IMethod * Create(const std::string &name, const TString &job, const TString &title, DataSetInfo &dsi, const TString &option)
creates the method if needed based on the method name using the creator function the factory has stor...
static ClassifierFactory & Instance()
access to the ClassifierFactory singleton creates the instance if needed
virtual void ParseOptions()
options parser
Class that contains all the data information.
Definition: DataSetInfo.h:60
const TString GetWeightExpression(Int_t i) const
Definition: DataSetInfo.h:162
std::vector< VariableInfo > & GetVariableInfos()
Definition: DataSetInfo.h:101
void SetSplitOptions(const TString &so)
Definition: DataSetInfo.h:183
ClassInfo * AddClass(const TString &className)
const TString & GetNormalization() const
Definition: DataSetInfo.h:129
std::vector< VariableInfo > & GetSpectatorInfos()
Definition: DataSetInfo.h:120
TDirectory * GetRootDir() const
Definition: DataSetInfo.h:188
void SetNormalization(const TString &norm)
Definition: DataSetInfo.h:130
UInt_t GetNClasses() const
Definition: DataSetInfo.h:153
const TString & GetSplitOptions() const
Definition: DataSetInfo.h:184
UInt_t GetNTargets() const
Definition: DataSetInfo.h:126
VariableInfo & AddTarget(const TString &expression, const TString &title, const TString &unit, Double_t min, Double_t max, Bool_t normalized=kTRUE, void *external=0)
add a variable (can be a complex expression) to the set of variables used in the MV analysis
VariableInfo & AddSpectator(const TString &expression, const TString &title, const TString &unit, Double_t min, Double_t max, char type='F', Bool_t normalized=kTRUE, void *external=0)
add a spectator (can be a complex expression) to the set of spectator variables used in the MV analys...
ClassInfo * GetClassInfo(Int_t clNum) const
const TCut & GetCut(Int_t i) const
Definition: DataSetInfo.h:166
void SetCut(const TCut &cut, const TString &className)
set the cut for the classes
VariableInfo & AddVariable(const TString &expression, const TString &title="", const TString &unit="", Double_t min=0, Double_t max=0, char varType='F', Bool_t normalized=kTRUE, void *external=0)
add a variable (can be a complex expression) to the set of variables used in the MV analysis
std::vector< VariableInfo > & GetTargetInfos()
Definition: DataSetInfo.h:112
void SetRootDir(TDirectory *d)
Definition: DataSetInfo.h:187
void SetWeightExpression(const TString &exp, const TString &className="")
set the weight expressions for the classes if class name is specified, set only for this class if cla...
void AddCut(const TCut &cut, const TString &className)
set the cut for the classes
Long64_t GetNTrainingEvents() const
Definition: DataSet.h:79
void SetVariableArrangement(std::vector< UInt_t > *const m) const
set the variable arrangement
Definition: Event.cxx:191
Float_t GetSpectator(UInt_t ivar) const
return spectator content
Definition: Event.cxx:261
Interface for all concrete MVA method implementations.
Definition: IMethod.h:54
virtual Bool_t HasAnalysisType(Types::EAnalysisType type, UInt_t numberClasses, UInt_t numberTargets)=0
Virtual base Class for all MVA method.
Definition: MethodBase.h:111
virtual const std::vector< Float_t > & GetRegressionValues()
Definition: MethodBase.h:220
const std::vector< Float_t > & GetRegressionValues(const TMVA::Event *const ev)
Definition: MethodBase.h:213
void SetSilentFile(Bool_t status)
Definition: MethodBase.h:377
void SetWeightFileDir(TString fileDir)
set directory of weight file
void WriteStateToXML(void *parent) const
general method used in writing the header of the weight files where the used variables,...
TString GetMethodTypeName() const
Definition: MethodBase.h:331
void DisableWriting(Bool_t setter)
Definition: MethodBase.h:442
const char * GetName() const
Definition: MethodBase.h:333
void SetupMethod()
setup of methods
Definition: MethodBase.cxx:408
virtual void SetAnalysisType(Types::EAnalysisType type)
Definition: MethodBase.h:436
const TString & GetMethodName() const
Definition: MethodBase.h:330
void ProcessSetup()
process all options the "CheckForUnusedOptions" is done in an independent call, since it may be overr...
Definition: MethodBase.cxx:425
DataSetInfo & DataInfo() const
Definition: MethodBase.h:409
void SetFile(TFile *file)
Definition: MethodBase.h:374
void ReadStateFromXML(void *parent)
friend class MethodCategory
Definition: MethodBase.h:268
void SetMethodBaseDir(TDirectory *methodDir)
Definition: MethodBase.h:373
DataSet * Data() const
Definition: MethodBase.h:408
void SetModelPersistence(Bool_t status)
Definition: MethodBase.h:381
virtual void CheckSetup()
check may be overridden by derived class (sometimes, eg, fitters are used which can only be implement...
Definition: MethodBase.cxx:435
Class for categorizing the phase space.
void InitCircularTree(const DataSetInfo &dsi)
initialize the circular tree
void GetHelpMessage() const
Get help message text.
void Init()
initialize the method
Bool_t PassesCut(const Event *ev, UInt_t methodIdx)
virtual Bool_t HasAnalysisType(Types::EAnalysisType type, UInt_t numberClasses, UInt_t)
check whether method category has analysis type the method type has to be the same for all sub-method...
void ProcessOptions()
process user options
Double_t GetMvaValue(Double_t *err=0, Double_t *errUpper=0)
returns the mva value of the right sub-classifier
TMVA::DataSetInfo & CreateCategoryDSI(const TCut &, const TString &, const TString &)
create a DataSetInfo object for a sub-classifier
void DeclareOptions()
options for this method
void AddWeightsXMLTo(void *parent) const
create XML description of Category classifier
const Ranking * CreateRanking()
no ranking
virtual ~MethodCategory(void)
destructor
virtual const std::vector< Float_t > & GetRegressionValues()
returns the mva value of the right sub-classifier
TMVA::IMethod * AddMethod(const TCut &, const TString &theVariables, Types::EMVA theMethod, const TString &theTitle, const TString &theOptions)
adds sub-classifier for a category
void ReadWeightsFromXML(void *wghtnode)
read weights of sub-classifiers of MethodCategory from xml weight file
void Train(void)
train all sub-classifiers
Virtual base class for combining several TMVA method.
Ranking for variables in method (implementation)
Definition: Ranking.h:48
virtual void Print() const
get maximum length of variable names
Definition: Ranking.cxx:111
void * GetNextChild(void *prevchild, const char *childname=0)
XML helpers.
Definition: Tools.cxx:1173
void * AddChild(void *parent, const char *childname, const char *content=0, bool isRootNode=false)
add child node
Definition: Tools.cxx:1135
std::vector< TString > SplitString(const TString &theOpt, const char separator) const
splits the option string at 'separator' and fills the list 'splitV' with the primitive strings
Definition: Tools.cxx:1210
const TString & Color(const TString &)
human readable color strings
Definition: Tools.cxx:839
void * GetChild(void *parent, const char *childname=0)
get child node
Definition: Tools.cxx:1161
void ReadAttr(void *node, const char *, T &value)
read attribute from xml
Definition: Tools.h:335
void AddAttr(void *node, const char *, const T &value, Int_t precision=16)
add attribute to xml
Definition: Tools.h:353
Singleton class for Global types used by TMVA.
Definition: Types.h:73
static Types & Instance()
returns the single instance of "Types" if it already exists, or creates it (Singleton)
Definition: Types.cxx:70
EAnalysisType
Definition: Types.h:127
@ kRegression
Definition: Types.h:129
Class for type info of MVA input variable.
Definition: VariableInfo.h:47
const TString & GetExpression() const
Definition: VariableInfo.h:57
char GetVarType() const
Definition: VariableInfo.h:61
void * GetExternalLink() const
Definition: VariableInfo.h:83
virtual const char * GetTitle() const
Returns title of object.
Definition: TNamed.h:48
virtual const char * GetName() const
Returns name of object.
Definition: TNamed.h:47
Basic string class.
Definition: TString.h:131
Ssiz_t Length() const
Definition: TString.h:405
const char * Data() const
Definition: TString.h:364
Ssiz_t Last(char c) const
Find last occurrence of a character c.
Definition: TString.cxx:892
Bool_t Contains(const char *pat, ECaseCompare cmp=kExact) const
Definition: TString.h:619
Ssiz_t Index(const char *pat, Ssiz_t i=0, ECaseCompare cmp=kExact) const
Definition: TString.h:634
Used to pass a selection expression to the Tree drawing routine.
Definition: TTreeFormula.h:58
A TTree represents a columnar dataset.
Definition: TTree.h:78
void GetMethodName(TString &name, TKey *mkey)
Definition: tmvaglob.cxx:335
create variable transformations
Tools & gTools()
void variables(TString dataset, TString fin="TMVA.root", TString dirName="InputVariables_Id", TString title="TMVA Input Variables", Bool_t isRegression=kFALSE, Bool_t useTMVAStyle=kTRUE)
MsgLogger & Endl(MsgLogger &ml)
Definition: MsgLogger.h:158
Double_t Log(Double_t x)
Definition: TMath.h:750
const Int_t MinNoTrainingEvents
Definition: Factory.cxx:98