Logo ROOT  
Reference Guide
MethodCategory.cxx
Go to the documentation of this file.
1// @(#)root/tmva $Id$
2// Author: Andreas Hoecker, Joerg Stelzer, Helge Voss, Eckhard von Toerne
3
4/**********************************************************************************
5 * Project: TMVA - a Root-integrated toolkit for multivariate data analysis *
6 * Package: TMVA *
7 * Class : MethodCategory *
8 * Web : http://tmva.sourceforge.net *
9 * *
10 * Description: *
11 * Class for categorizing the phase space *
12 * *
13 * Authors (alphabetical): *
14 * Andreas Hoecker <Andreas.Hocker@cern.ch> - CERN, Switzerland *
15 * Nadim Sah <Nadim.Sah@cern.ch> - Berlin, Germany *
16 * Peter Speckmayer <Peter.Speckmazer@cern.ch> - CERN, Switzerland *
17 * Joerg Stelzer <Joerg.Stelzer@cern.ch> - MSU East Lansing, USA *
18 * Helge Voss <Helge.Voss@cern.ch> - MPI-K Heidelberg, Germany *
19 * Jan Therhaag <Jan.Therhaag@cern.ch> - U of Bonn, Germany *
20 * Eckhard v. Toerne <evt@uni-bonn.de> - U of Bonn, Germany *
21 * *
22 * Copyright (c) 2005-2011: *
23 * CERN, Switzerland *
24 * MSU East Lansing, USA *
25 * MPI-K Heidelberg, Germany *
26 * U. of Bonn, Germany *
27 * *
28 * Redistribution and use in source and binary forms, with or without *
29 * modification, are permitted according to the terms listed in LICENSE *
30 * (http://tmva.sourceforge.net/LICENSE) *
31 **********************************************************************************/
32
33/*! \class TMVA::MethodCategory
34\ingroup TMVA
35
36Class for categorizing the phase space
37
38This class is meant to allow categorisation of the data. For different
39categories, different classifiers may be booked and different variables
40may be considered. The aim is to account for the difference that
41is due to different locations/angles.
42*/
43
44
45#include "TMVA/MethodCategory.h"
46
47#include <algorithm>
48#include <iomanip>
49#include <vector>
50#include <iostream>
51
52#include "Riostream.h"
53#include "TRandom3.h"
54#include "TMath.h"
55#include "TObjString.h"
56#include "TH1F.h"
57#include "TGraph.h"
58#include "TSpline.h"
59#include "TDirectory.h"
60#include "TTreeFormula.h"
61
63#include "TMVA/Config.h"
64#include "TMVA/DataSet.h"
65#include "TMVA/DataSetInfo.h"
66#include "TMVA/DataSetManager.h"
67#include "TMVA/IMethod.h"
68#include "TMVA/MethodBase.h"
70#include "TMVA/MsgLogger.h"
71#include "TMVA/PDF.h"
72#include "TMVA/Ranking.h"
73#include "TMVA/Timer.h"
74#include "TMVA/Tools.h"
75#include "TMVA/Types.h"
76#include "TMVA/VariableInfo.h"
78
79REGISTER_METHOD(Category)
80
82
83////////////////////////////////////////////////////////////////////////////////
84/// standard constructor
85
87 const TString& methodTitle,
88 DataSetInfo& theData,
89 const TString& theOption )
90 : TMVA::MethodCompositeBase( jobName, Types::kCategory, methodTitle, theData, theOption),
91 fCatTree(0),
92 fDataSetManager(NULL)
93{
94}
95
96////////////////////////////////////////////////////////////////////////////////
97/// constructor from weight file
98
100 const TString& theWeightFile)
101 : TMVA::MethodCompositeBase( Types::kCategory, dsi, theWeightFile),
102 fCatTree(0),
103 fDataSetManager(NULL)
104{
105}
106
107////////////////////////////////////////////////////////////////////////////////
108/// destructor
109
111{
112 std::vector<TTreeFormula*>::iterator formIt = fCatFormulas.begin();
113 std::vector<TTreeFormula*>::iterator lastF = fCatFormulas.end();
114 for(;formIt!=lastF; ++formIt) delete *formIt;
115 delete fCatTree;
116}
117
118////////////////////////////////////////////////////////////////////////////////
119/// check whether the category method supports the requested analysis type:
120/// it does so only if every booked sub-method supports it
121
123{
124 std::vector<IMethod*>::iterator itrMethod = fMethods.begin();
125
126 // iterate over methods and check whether they have the analysis type
127 for(; itrMethod != fMethods.end(); ++itrMethod ) {
128 if ( !(*itrMethod)->HasAnalysisType(type, numberClasses, numberTargets) )
129 return kFALSE;
130 }
131 return kTRUE;
132}
133
134////////////////////////////////////////////////////////////////////////////////
135/// options for this method
136
138{
139}
140
141////////////////////////////////////////////////////////////////////////////////
142/// adds sub-classifier for a category
143
145 const TString& theVariables,
146 Types::EMVA theMethod ,
147 const TString& theTitle,
148 const TString& theOptions )
149{
150 std::string addedMethodName(Types::Instance().GetMethodName(theMethod).Data());
151
152 Log() << kINFO << "Adding sub-classifier: " << addedMethodName << "::" << theTitle << Endl;
153
154 DataSetInfo& dsi = CreateCategoryDSI(theCut, theVariables, theTitle);
155
156 IMethod* addedMethod = ClassifierFactory::Instance().Create(addedMethodName,GetJobName(),theTitle,dsi,theOptions);
157
158 MethodBase *method = (dynamic_cast<MethodBase*>(addedMethod));
159 if(method==0) return 0;
160
161 if(fModelPersistence) method->SetWeightFileDir(fFileDir);
162 method->SetModelPersistence(fModelPersistence);
163 method->SetAnalysisType( fAnalysisType );
164 method->SetupMethod();
165 method->ParseOptions();
166 method->ProcessSetup();
167 method->SetFile(fFile);
168 method->SetSilentFile(IsSilentFile());
169
170
171 // set or create correct method base dir for added method
172 const TString dirName(Form("Method_%s",method->GetMethodTypeName().Data()));
173 TDirectory * dir = BaseDir()->GetDirectory(dirName);
174 if (dir != 0) method->SetMethodBaseDir( dir );
175 else method->SetMethodBaseDir( BaseDir()->mkdir(dirName,Form("Directory for all %s methods", method->GetMethodTypeName().Data())) );
176
177 // method->SetBaseDir(own base dir; check whether the Fisher dir exists, otherwise create it)
178
179 // check-for-unused-options is performed; may be overridden by derived
180 // classes
181 method->CheckSetup();
182
183 // disable writing of XML files and standalone classes for sub methods
184 method->DisableWriting( kTRUE );
185
186 // store method, cut and variable names and create cut formula
187 fMethods.push_back(method);
188 fCategoryCuts.push_back(theCut);
189 fVars.push_back(theVariables);
190
191 DataSetInfo& primaryDSI = DataInfo();
192
193 UInt_t newSpectatorIndex = primaryDSI.GetSpectatorInfos().size();
194 fCategorySpecIdx.push_back(newSpectatorIndex);
195
196 primaryDSI.AddSpectator( Form("%s_cat%i:=%s", GetName(),(int)fMethods.size(),theCut.GetTitle()),
197 Form("%s:%s",GetName(),method->GetName()),
198 "pass", 0, 0, 'C' );
199
200 return method;
201}
202
203////////////////////////////////////////////////////////////////////////////////
204/// create a DataSetInfo object for a sub-classifier
205
207 const TString& theVariables,
208 const TString& theTitle)
209{
210 // create a new dsi with name: theTitle+"_dsi"
211 TString dsiName=theTitle+"_dsi";
212 DataSetInfo& oldDSI = DataInfo();
213 DataSetInfo* dsi = new DataSetInfo(dsiName);
214
215 // register the new dsi
216 // DataSetManager::Instance().AddDataSetInfo(*dsi); // DSMTEST replaced by following line
217 fDataSetManager->AddDataSetInfo(*dsi);
218
219 // copy the targets and spectators from the old dsi to the new dsi
220 std::vector<VariableInfo>::iterator itrVarInfo;
221
222 for (itrVarInfo = oldDSI.GetTargetInfos().begin(); itrVarInfo != oldDSI.GetTargetInfos().end(); ++itrVarInfo)
223 dsi->AddTarget(*itrVarInfo);
224
225 for (itrVarInfo = oldDSI.GetSpectatorInfos().begin(); itrVarInfo != oldDSI.GetSpectatorInfos().end(); ++itrVarInfo)
226 dsi->AddSpectator(*itrVarInfo);
227
228 // split string that contains the variables into tiny little pieces
229 std::vector<TString> variables = gTools().SplitString(theVariables,':' );
230
231 // prepare to create varMap
232 std::vector<UInt_t> varMap;
233 Int_t counter=0;
234
235 // add the variables that were specified in theVariables
236 std::vector<TString>::iterator itrVariables;
237 Bool_t found = kFALSE;
238
239 // iterate over all variables in 'variables' and add them
240 for (itrVariables = variables.begin(); itrVariables != variables.end(); ++itrVariables) {
241 counter=0;
242
243 // check the variables of the old dsi for the variable that we want to add
244 for (itrVarInfo = oldDSI.GetVariableInfos().begin(); itrVarInfo != oldDSI.GetVariableInfos().end(); ++itrVarInfo) {
245 if((*itrVariables==itrVarInfo->GetLabel()) ) { // || (*itrVariables==itrVarInfo->GetExpression())) {
246 // don't compare the expression, since the user might take two times the same expression, but with different labels
247 // and apply different transformations to the variables.
248 dsi->AddVariable(*itrVarInfo);
249 varMap.push_back(counter);
250 found = kTRUE;
251 }
252 counter++;
253 }
254
255 // check the spectators of the old dsi for the variable that we want to add
256 for (itrVarInfo = oldDSI.GetSpectatorInfos().begin(); itrVarInfo != oldDSI.GetSpectatorInfos().end(); ++itrVarInfo) {
257 if((*itrVariables==itrVarInfo->GetLabel()) ) { // || (*itrVariables==itrVarInfo->GetExpression())) {
258 // don't compare the expression, since the user might take two times the same expression, but with different labels
259 // and apply different transformations to the variables.
260 dsi->AddVariable(*itrVarInfo);
261 varMap.push_back(counter);
262 found = kTRUE;
263 }
264 counter++;
265 }
266
267 // if the variable is neither in the variables nor in the spectators, we abort
268 if (!found) {
269 Log() << kFATAL <<"The variable " << itrVariables->Data() << " was not found and could not be added " << Endl;
270 }
271 found = kFALSE;
272 }
273
274 // in the case that no variables are specified, add the default-variables from the original dsi
275 if (theVariables=="") {
276 for (UInt_t i=0; i<oldDSI.GetVariableInfos().size(); i++) {
277 dsi->AddVariable(oldDSI.GetVariableInfos()[i]);
278 varMap.push_back(i);
279 }
280 }
281
282 // add the variable map 'varMap' to the vector of varMaps
283 fVarMaps.push_back(varMap);
284
285 // set classes and cuts
286 UInt_t nClasses=oldDSI.GetNClasses();
287 TString className;
288
289 for (UInt_t i=0; i<nClasses; i++) {
290 className = oldDSI.GetClassInfo(i)->GetName();
291 dsi->AddClass(className);
292 dsi->SetCut(oldDSI.GetCut(i),className);
293 dsi->AddCut(theCut,className);
294 dsi->SetWeightExpression(oldDSI.GetWeightExpression(i),className);
295 }
296
297 // set split options, root dir and normalization for the new dsi
298 dsi->SetSplitOptions(oldDSI.GetSplitOptions());
299 dsi->SetRootDir(oldDSI.GetRootDir());
300 TString norm(oldDSI.GetNormalization().Data());
301 dsi->SetNormalization(norm);
302
303 DataSetInfo& dsiReference= (*dsi);
304
305 return dsiReference;
306}
307
308////////////////////////////////////////////////////////////////////////////////
309/// initialize the method
310
312{
313}
314
315////////////////////////////////////////////////////////////////////////////////
316/// initialize the circular tree
317
319{
320 delete fCatTree;
321
322 std::vector<VariableInfo>::const_iterator viIt;
323 const std::vector<VariableInfo>& vars = dsi.GetVariableInfos();
324 const std::vector<VariableInfo>& specs = dsi.GetSpectatorInfos();
325
326 Bool_t hasAllExternalLinks = kTRUE;
327 for (viIt = vars.begin(); viIt != vars.end(); ++viIt)
328 if( viIt->GetExternalLink() == 0 ) {
329 hasAllExternalLinks = kFALSE;
330 break;
331 }
332 for (viIt = specs.begin(); viIt != specs.end(); ++viIt)
333 if( viIt->GetExternalLink() == 0 ) {
334 hasAllExternalLinks = kFALSE;
335 break;
336 }
337
338 if(!hasAllExternalLinks) return;
339
340 {
341 // Rather than having TTree::TTree add to the current directory and then remove it, let's
342 // make sure not to add it in the first place.
343 // The add-then-remove can lead to a problem if gDirectory points to the same directory (for example
344 // gROOT) in the current thread and another one (and both try to add to the directory at the same time).
345 TDirectory::TContext ctxt(nullptr);
346 fCatTree = new TTree(Form("Circ%s",GetMethodName().Data()),"Circular Tree for categorization");
347 fCatTree->SetCircular(1);
348 }
349
350 for (viIt = vars.begin(); viIt != vars.end(); ++viIt) {
351 const VariableInfo& vi = *viIt;
352 fCatTree->Branch(vi.GetExpression(),(Float_t*)vi.GetExternalLink(), TString(vi.GetExpression())+TString("/F"));
353 }
354 for (viIt = specs.begin(); viIt != specs.end(); ++viIt) {
355 const VariableInfo& vi = *viIt;
356 if(vi.GetVarType()=='C') continue;
357 fCatTree->Branch(vi.GetExpression(),(Float_t*)vi.GetExternalLink(), TString(vi.GetExpression())+TString("/F"));
358 }
359
360 for(UInt_t cat=0; cat!=fCategoryCuts.size(); ++cat) {
361 fCatFormulas.push_back(new TTreeFormula(Form("Category_%i",cat), fCategoryCuts[cat].GetTitle(), fCatTree));
362 }
363}
364
365////////////////////////////////////////////////////////////////////////////////
366/// train all sub-classifiers
367
369{
370 // specify the minimum # of training events and retrieve the analysis type
371 const Int_t MinNoTrainingEvents = 10;
372
373 Types::EAnalysisType analysisType = GetAnalysisType();
374
375 // start the training
376 Log() << kINFO << "Train all sub-classifiers for "
377 << (analysisType == Types::kRegression ? "Regression" : "Classification") << " ..." << Endl;
378
379 // don't do anything if no sub-classifier booked
380 if (fMethods.empty()) {
381 Log() << kINFO << "...nothing found to train" << Endl;
382 return;
383 }
384
385 std::vector<IMethod*>::iterator itrMethod;
386
387 // iterate over all booked sub-classifiers and train them
388 for (itrMethod = fMethods.begin(); itrMethod != fMethods.end(); ++itrMethod ) {
389
390 MethodBase* mva = dynamic_cast<MethodBase*>(*itrMethod);
391 if(!mva) continue;
392 mva->SetAnalysisType( analysisType );
393 if (!mva->HasAnalysisType( analysisType,
394 mva->DataInfo().GetNClasses(),
395 mva->DataInfo().GetNTargets() ) ) {
396 Log() << kWARNING << "Method " << mva->GetMethodTypeName() << " is not capable of handling " ;
397 if (analysisType == Types::kRegression)
398 Log() << "regression with " << mva->DataInfo().GetNTargets() << " targets." << Endl;
399 else
400 Log() << "classification with " << mva->DataInfo().GetNClasses() << " classes." << Endl;
401 itrMethod = fMethods.erase( itrMethod );
402 continue;
403 }
405
406 Log() << kINFO << "Train method: " << mva->GetMethodName() << " for "
407 << (analysisType == Types::kRegression ? "Regression" : "Classification") << Endl;
408 mva->TrainMethod();
409 Log() << kINFO << "Training finished" << Endl;
410
411 } else {
412
413 Log() << kWARNING << "Method " << mva->GetMethodName()
414 << " not trained (training tree has less entries ["
415 << mva->Data()->GetNTrainingEvents()
416 << "] than required [" << MinNoTrainingEvents << "]" << Endl;
417
418 Log() << kERROR << " w/o training/test events for that category, I better stop here and let you fix " << Endl;
419 Log() << kFATAL << "that one first, otherwise things get too messy later ... " << Endl;
420
421 }
422 }
423
424 if (analysisType != Types::kRegression) {
425
426 // variable ranking
427 Log() << kINFO << "Begin ranking of input variables..." << Endl;
428 for (itrMethod = fMethods.begin(); itrMethod != fMethods.end(); ++itrMethod) {
429 MethodBase* mva = dynamic_cast<MethodBase*>(*itrMethod);
430 if (mva && mva->Data()->GetNTrainingEvents() >= MinNoTrainingEvents) {
431 const Ranking* ranking = (*itrMethod)->CreateRanking();
432 if (ranking != 0)
433 ranking->Print();
434 else
435 Log() << kINFO << "No variable ranking supplied by classifier: "
436 << dynamic_cast<MethodBase*>(*itrMethod)->GetMethodName() << Endl;
437 }
438 }
439 }
440}
441
442////////////////////////////////////////////////////////////////////////////////
443/// create XML description of Category classifier
444
446{
447 void* wght = gTools().AddChild(parent, "Weights");
448 gTools().AddAttr( wght, "NSubMethods", fMethods.size() );
449 void* submethod(0);
450
451 // iterate over methods and write them to XML file
452 for (UInt_t i=0; i<fMethods.size(); i++) {
453 MethodBase* method = dynamic_cast<MethodBase*>(fMethods[i]);
454 submethod = gTools().AddChild(wght, "SubMethod");
455 gTools().AddAttr(submethod, "Index", i);
456 gTools().AddAttr(submethod, "Method", method->GetMethodTypeName() + "::" + method->GetMethodName());
457 gTools().AddAttr(submethod, "Cut", fCategoryCuts[i]);
458 gTools().AddAttr(submethod, "Variables", fVars[i]);
459 method->WriteStateToXML( submethod );
460 }
461}
462
463////////////////////////////////////////////////////////////////////////////////
464/// read weights of sub-classifiers of MethodCategory from xml weight file
465
467{
468 UInt_t nSubMethods;
469 TString fullMethodName;
470 TString methodType;
471 TString methodTitle;
472 TString theCutString;
473 TString theVariables;
474 Int_t titleLength;
475 gTools().ReadAttr( wghtnode, "NSubMethods", nSubMethods );
476 void* subMethodNode = gTools().GetChild(wghtnode);
477
478 Log() << kINFO << "Recreating sub-classifiers from XML-file " << Endl;
479
480 // recreate all sub-methods from weight file
481 for (UInt_t i=0; i<nSubMethods; i++) {
482 gTools().ReadAttr( subMethodNode, "Method", fullMethodName );
483 gTools().ReadAttr( subMethodNode, "Cut", theCutString );
484 gTools().ReadAttr( subMethodNode, "Variables", theVariables );
485
486 // determine sub-method type
487 methodType = fullMethodName(0,fullMethodName.Index("::"));
488 if (methodType.Contains(" ")) methodType = methodType(methodType.Last(' ')+1,methodType.Length());
489
490 // determine sub-method title
491 titleLength = fullMethodName.Length()-fullMethodName.Index("::")-2;
492 methodTitle = fullMethodName(fullMethodName.Index("::")+2,titleLength);
493
494 // reconstruct dsi for sub-method
495 DataSetInfo& dsi = CreateCategoryDSI(TCut(theCutString), theVariables, methodTitle);
496
497 // recreate sub-method from weights and add to fMethods
498 MethodBase* method = dynamic_cast<MethodBase*>( ClassifierFactory::Instance().Create( methodType.Data(),
499 dsi, "none" ) );
500 if(method==0)
501 Log() << kFATAL << "Could not create sub-method " << method << " from XML." << Endl;
502
503 method->SetupMethod();
504 method->ReadStateFromXML(subMethodNode);
505
506 fMethods.push_back(method);
507 fCategoryCuts.push_back(TCut(theCutString));
508 fVars.push_back(theVariables);
509
510 DataSetInfo& primaryDSI = DataInfo();
511
512 UInt_t spectatorIdx = 10000;
513 UInt_t counter=0;
514
515 // find the spectator index
516 std::vector<VariableInfo>& spectators=primaryDSI.GetSpectatorInfos();
517 std::vector<VariableInfo>::iterator itrVarInfo;
518 TString specName= Form("%s_cat%i", GetName(),(int)fCategorySpecIdx.size()+1);
519
520 for (itrVarInfo = spectators.begin(); itrVarInfo != spectators.end(); ++itrVarInfo, ++counter) {
521 if((specName==itrVarInfo->GetLabel()) || (specName==itrVarInfo->GetExpression())) {
522 spectatorIdx=counter;
523 fCategorySpecIdx.push_back(spectatorIdx);
524 break;
525 }
526 }
527
528 subMethodNode = gTools().GetNextChild(subMethodNode);
529 }
530
531 InitCircularTree(DataInfo());
532
533}
534
535////////////////////////////////////////////////////////////////////////////////
536/// process user options
537
539{
540}
541
542////////////////////////////////////////////////////////////////////////////////
543/// Get help message text
544///
545/// typical length of text line:
546/// "|--------------------------------------------------------------|"
547
549{
550 Log() << Endl;
551 Log() << gTools().Color("bold") << "--- Short description:" << gTools().Color("reset") << Endl;
552 Log() << Endl;
553 Log() << "This method allows to define different categories of events. The" <<Endl;
554 Log() << "categories are defined via cuts on the variables. For each" << Endl;
555 Log() << "category, a different classifier and set of variables can be" <<Endl;
556 Log() << "specified. The categories which are defined for this method must" << Endl;
557 Log() << "be disjoint." << Endl;
558}
559
560////////////////////////////////////////////////////////////////////////////////
561/// no ranking is provided for the category method itself (always returns 0)
562
564{
565 return 0;
566}
567
568////////////////////////////////////////////////////////////////////////////////
569
571{
572 // if the categories are not defined by a simple 'spectator' variable (0 or 1)
573 // but rather by some 'formula' (e.g. eta>0), then these formulas are attached to fCatTree and
574 // they are evaluated instead (the formulae return 'true' or 'false')
575 if (fCatTree) {
576 if (methodIdx>=fCatFormulas.size()) {
577 Log() << kFATAL << "Large method index " << methodIdx << ", number of category formulas = "
578 << fCatFormulas.size() << Endl;
579 }
580 TTreeFormula* f = fCatFormulas[methodIdx];
581 return f->EvalInstance(0) > 0.5;
582 }
583 // otherwise, it simply checks whether the category spectator is "true" (greater than 0.5, to be "sure")
584 else {
585
586 // checks whether an event lies within a cut
587 if (methodIdx>=fCategorySpecIdx.size()) {
588 Log() << kFATAL << "Unknown method index " << methodIdx << " maximum allowed index="
589 << fCategorySpecIdx.size() << Endl;
590 }
591 UInt_t spectatorIdx = fCategorySpecIdx[methodIdx];
592 Float_t specVal = ev->GetSpectator(spectatorIdx);
593 Bool_t pass = (specVal>0.5);
594 return pass;
595 }
596}
597
598////////////////////////////////////////////////////////////////////////////////
599/// returns the mva value of the right sub-classifier
600
602{
603 if (fMethods.empty()) return 0;
604
605 UInt_t methodToUse = 0;
606 const Event* ev = GetEvent();
607
608 // determine which sub-classifier to use for this event
609 Int_t suitableCutsN = 0;
610
611 for (UInt_t i=0; i<fMethods.size(); ++i) {
612 if (PassesCut(ev, i)) {
613 ++suitableCutsN;
614 methodToUse=i;
615 }
616 }
617
618 if (suitableCutsN == 0) {
619 Log() << kWARNING << "Event does not lie within the cut of any sub-classifier." << Endl;
620 return 0;
621 }
622
623 if (suitableCutsN > 1) {
624 Log() << kFATAL << "The defined categories are not disjoint." << Endl;
625 return 0;
626 }
627
628 // get mva value from the suitable sub-classifier
629 ev->SetVariableArrangement(&fVarMaps[methodToUse]);
630 Double_t mvaValue = dynamic_cast<MethodBase*>(fMethods[methodToUse])->GetMvaValue(ev,err,errUpper);
632
633 return mvaValue;
634}
635
636
637
638////////////////////////////////////////////////////////////////////////////////
639/// returns the regression values of the right sub-classifier
640
641const std::vector<Float_t> &TMVA::MethodCategory::GetRegressionValues()
642{
643 if (fMethods.empty()) return MethodBase::GetRegressionValues();
644
645 UInt_t methodToUse = 0;
646 const Event* ev = GetEvent();
647
648 // determine which sub-classifier to use for this event
649 Int_t suitableCutsN = 0;
650
651 for (UInt_t i=0; i<fMethods.size(); ++i) {
652 if (PassesCut(ev, i)) {
653 ++suitableCutsN;
654 methodToUse=i;
655 }
656 }
657
658 if (suitableCutsN == 0) {
659 Log() << kWARNING << "Event does not lie within the cut of any sub-classifier." << Endl;
661 }
662
663 if (suitableCutsN > 1) {
664 Log() << kFATAL << "The defined categories are not disjoint." << Endl;
666 }
667 MethodBase* meth = dynamic_cast<MethodBase*>(fMethods[methodToUse]);
668 if (!meth){
669 Log() << kFATAL << "method not found in Category Regression method" << Endl;
671 }
672 // get regression values from the suitable sub-classifier
673 return meth->GetRegressionValues(ev);
674}
675
#define REGISTER_METHOD(CLASS)
for example
#define f(i)
Definition: RSha256.hxx:104
int Int_t
Definition: RtypesCore.h:41
unsigned int UInt_t
Definition: RtypesCore.h:42
const Bool_t kFALSE
Definition: RtypesCore.h:88
bool Bool_t
Definition: RtypesCore.h:59
double Double_t
Definition: RtypesCore.h:55
float Float_t
Definition: RtypesCore.h:53
const Bool_t kTRUE
Definition: RtypesCore.h:87
#define ClassImp(name)
Definition: Rtypes.h:365
int type
Definition: TGX11.cxx:120
char * Form(const char *fmt,...)
A specialized string object used for TTree selections.
Definition: TCut.h:25
Small helper to keep current directory context.
Definition: TDirectory.h:41
Describe directory structure in memory.
Definition: TDirectory.h:34
virtual TDirectory * GetDirectory(const char *namecycle, Bool_t printError=false, const char *funcname="GetDirectory")
Find a directory using a path.
Definition: TDirectory.cxx:400
IMethod * Create(const std::string &name, const TString &job, const TString &title, DataSetInfo &dsi, const TString &option)
creates the method if needed based on the method name using the creator function the factory has stor...
static ClassifierFactory & Instance()
access to the ClassifierFactory singleton creates the instance if needed
virtual void ParseOptions()
options parser
Class that contains all the data information.
Definition: DataSetInfo.h:60
const TString GetWeightExpression(Int_t i) const
Definition: DataSetInfo.h:162
std::vector< VariableInfo > & GetVariableInfos()
Definition: DataSetInfo.h:101
void SetSplitOptions(const TString &so)
Definition: DataSetInfo.h:183
ClassInfo * AddClass(const TString &className)
const TString & GetNormalization() const
Definition: DataSetInfo.h:129
std::vector< VariableInfo > & GetSpectatorInfos()
Definition: DataSetInfo.h:120
TDirectory * GetRootDir() const
Definition: DataSetInfo.h:188
void SetNormalization(const TString &norm)
Definition: DataSetInfo.h:130
UInt_t GetNClasses() const
Definition: DataSetInfo.h:153
const TString & GetSplitOptions() const
Definition: DataSetInfo.h:184
UInt_t GetNTargets() const
Definition: DataSetInfo.h:126
VariableInfo & AddTarget(const TString &expression, const TString &title, const TString &unit, Double_t min, Double_t max, Bool_t normalized=kTRUE, void *external=0)
add a variable (can be a complex expression) to the set of variables used in the MV analysis
VariableInfo & AddSpectator(const TString &expression, const TString &title, const TString &unit, Double_t min, Double_t max, char type='F', Bool_t normalized=kTRUE, void *external=0)
add a spectator (can be a complex expression) to the set of spectator variables used in the MV analys...
ClassInfo * GetClassInfo(Int_t clNum) const
const TCut & GetCut(Int_t i) const
Definition: DataSetInfo.h:166
void SetCut(const TCut &cut, const TString &className)
set the cut for the classes
VariableInfo & AddVariable(const TString &expression, const TString &title="", const TString &unit="", Double_t min=0, Double_t max=0, char varType='F', Bool_t normalized=kTRUE, void *external=0)
add a variable (can be a complex expression) to the set of variables used in the MV analysis
std::vector< VariableInfo > & GetTargetInfos()
Definition: DataSetInfo.h:112
void SetRootDir(TDirectory *d)
Definition: DataSetInfo.h:187
void SetWeightExpression(const TString &exp, const TString &className="")
set the weight expressions for the classes if class name is specified, set only for this class if cla...
void AddCut(const TCut &cut, const TString &className)
set the cut for the classes
Long64_t GetNTrainingEvents() const
Definition: DataSet.h:79
void SetVariableArrangement(std::vector< UInt_t > *const m) const
set the variable arrangement
Definition: Event.cxx:192
Float_t GetSpectator(UInt_t ivar) const
return spectator content
Definition: Event.cxx:262
Interface for all concrete MVA method implementations.
Definition: IMethod.h:54
virtual Bool_t HasAnalysisType(Types::EAnalysisType type, UInt_t numberClasses, UInt_t numberTargets)=0
Virtual base class for all MVA methods.
Definition: MethodBase.h:111
virtual const std::vector< Float_t > & GetRegressionValues()
Definition: MethodBase.h:220
const std::vector< Float_t > & GetRegressionValues(const TMVA::Event *const ev)
Definition: MethodBase.h:213
void SetSilentFile(Bool_t status)
Definition: MethodBase.h:377
void SetWeightFileDir(TString fileDir)
set directory of weight file
void WriteStateToXML(void *parent) const
general method used in writing the header of the weight files where the used variables,...
TString GetMethodTypeName() const
Definition: MethodBase.h:331
void DisableWriting(Bool_t setter)
Definition: MethodBase.h:442
const char * GetName() const
Definition: MethodBase.h:333
void SetupMethod()
setup of methods
Definition: MethodBase.cxx:411
virtual void SetAnalysisType(Types::EAnalysisType type)
Definition: MethodBase.h:436
const TString & GetMethodName() const
Definition: MethodBase.h:330
void ProcessSetup()
process all options the "CheckForUnusedOptions" is done in an independent call, since it may be overr...
Definition: MethodBase.cxx:428
DataSetInfo & DataInfo() const
Definition: MethodBase.h:409
void SetFile(TFile *file)
Definition: MethodBase.h:374
void ReadStateFromXML(void *parent)
friend class MethodCategory
Definition: MethodBase.h:268
void SetMethodBaseDir(TDirectory *methodDir)
Definition: MethodBase.h:373
DataSet * Data() const
Definition: MethodBase.h:408
void SetModelPersistence(Bool_t status)
Definition: MethodBase.h:381
virtual void CheckSetup()
check may be overridden by derived class (sometimes, eg, fitters are used which can only be implement...
Definition: MethodBase.cxx:438
Class for categorizing the phase space.
void InitCircularTree(const DataSetInfo &dsi)
initialize the circular tree
void GetHelpMessage() const
Get help message text.
void Init()
initialize the method
Bool_t PassesCut(const Event *ev, UInt_t methodIdx)
virtual Bool_t HasAnalysisType(Types::EAnalysisType type, UInt_t numberClasses, UInt_t)
check whether method category has analysis type the method type has to be the same for all sub-method...
void ProcessOptions()
process user options
Double_t GetMvaValue(Double_t *err=0, Double_t *errUpper=0)
returns the mva value of the right sub-classifier
TMVA::DataSetInfo & CreateCategoryDSI(const TCut &, const TString &, const TString &)
create a DataSetInfo object for a sub-classifier
void DeclareOptions()
options for this method
void AddWeightsXMLTo(void *parent) const
create XML description of Category classifier
const Ranking * CreateRanking()
no ranking
virtual ~MethodCategory(void)
destructor
virtual const std::vector< Float_t > & GetRegressionValues()
returns the mva value of the right sub-classifier
TMVA::IMethod * AddMethod(const TCut &, const TString &theVariables, Types::EMVA theMethod, const TString &theTitle, const TString &theOptions)
adds sub-classifier for a category
void ReadWeightsFromXML(void *wghtnode)
read weights of sub-classifiers of MethodCategory from xml weight file
void Train(void)
train all sub-classifiers
Virtual base class for combining several TMVA methods.
Ranking for variables in method (implementation)
Definition: Ranking.h:48
virtual void Print() const
get maximum length of variable names
Definition: Ranking.cxx:111
void * GetNextChild(void *prevchild, const char *childname=0)
XML helpers.
Definition: Tools.cxx:1174
void * AddChild(void *parent, const char *childname, const char *content=0, bool isRootNode=false)
add child node
Definition: Tools.cxx:1136
std::vector< TString > SplitString(const TString &theOpt, const char separator) const
splits the option string at 'separator' and fills the list 'splitV' with the primitive strings
Definition: Tools.cxx:1211
const TString & Color(const TString &)
human readable color strings
Definition: Tools.cxx:840
void * GetChild(void *parent, const char *childname=0)
get child node
Definition: Tools.cxx:1162
void ReadAttr(void *node, const char *, T &value)
read attribute from xml
Definition: Tools.h:337
void AddAttr(void *node, const char *, const T &value, Int_t precision=16)
add attribute to xml
Definition: Tools.h:355
Singleton class for Global types used by TMVA.
Definition: Types.h:73
static Types & Instance()
the single instance of "Types" if existing already, or create it (Singleton)
Definition: Types.cxx:70
EAnalysisType
Definition: Types.h:127
@ kRegression
Definition: Types.h:129
Class for type info of MVA input variable.
Definition: VariableInfo.h:47
const TString & GetExpression() const
Definition: VariableInfo.h:57
char GetVarType() const
Definition: VariableInfo.h:61
void * GetExternalLink() const
Definition: VariableInfo.h:83
virtual const char * GetTitle() const
Returns title of object.
Definition: TNamed.h:48
virtual const char * GetName() const
Returns name of object.
Definition: TNamed.h:47
Basic string class.
Definition: TString.h:131
Ssiz_t Length() const
Definition: TString.h:405
const char * Data() const
Definition: TString.h:364
Ssiz_t Last(char c) const
Find last occurrence of a character c.
Definition: TString.cxx:892
Bool_t Contains(const char *pat, ECaseCompare cmp=kExact) const
Definition: TString.h:619
Ssiz_t Index(const char *pat, Ssiz_t i=0, ECaseCompare cmp=kExact) const
Definition: TString.h:634
Used to pass a selection expression to the Tree drawing routine.
Definition: TTreeFormula.h:58
A TTree represents a columnar dataset.
Definition: TTree.h:72
std::string GetMethodName(TCppMethod_t)
Definition: Cppyy.cxx:757
std::string GetName(const std::string &scope_name)
Definition: Cppyy.cxx:150
create variable transformations
Tools & gTools()
void variables(TString dataset, TString fin="TMVA.root", TString dirName="InputVariables_Id", TString title="TMVA Input Variables", Bool_t isRegression=kFALSE, Bool_t useTMVAStyle=kTRUE)
MsgLogger & Endl(MsgLogger &ml)
Definition: MsgLogger.h:158
Double_t Log(Double_t x)
Definition: TMath.h:750
const Int_t MinNoTrainingEvents
Definition: Factory.cxx:99