Logo ROOT  
Reference Guide
 
Loading...
Searching...
No Matches
MethodCategory.cxx
Go to the documentation of this file.
1// @(#)root/tmva $Id$
2// Author: Andreas Hoecker, Joerg Stelzer, Helge Voss, Eckhard von Toerne
3
4/**********************************************************************************
5 * Project: TMVA - a Root-integrated toolkit for multivariate data analysis *
6 * Package: TMVA *
7 * Class : MethodCategory *
8 * Web : http://tmva.sourceforge.net *
9 * *
10 * Description: *
11 * Splits the phase space into categories, each with its own sub-classifier *
12 * *
13 * Authors (alphabetical): *
14 * Andreas Hoecker <Andreas.Hocker@cern.ch> - CERN, Switzerland *
15 * Nadim Sah <Nadim.Sah@cern.ch> - Berlin, Germany *
16 * Peter Speckmayer <Peter.Speckmazer@cern.ch> - CERN, Switzerland *
17 * Joerg Stelzer <Joerg.Stelzer@cern.ch> - MSU East Lansing, USA *
18 * Helge Voss <Helge.Voss@cern.ch> - MPI-K Heidelberg, Germany *
19 * Jan Therhaag <Jan.Therhaag@cern.ch> - U of Bonn, Germany *
20 * Eckhard v. Toerne <evt@uni-bonn.de> - U of Bonn, Germany *
21 * *
22 * Copyright (c) 2005-2011: *
23 * CERN, Switzerland *
24 * MSU East Lansing, USA *
25 * MPI-K Heidelberg, Germany *
26 * U. of Bonn, Germany *
27 * *
28 * Redistribution and use in source and binary forms, with or without *
29 * modification, are permitted according to the terms listed in LICENSE *
30 * (http://tmva.sourceforge.net/LICENSE) *
31 **********************************************************************************/
32
33/*! \class TMVA::MethodCategory
34\ingroup TMVA
35
36Class for categorizing the phase space
37
38This class is meant to allow categorisation of the data. For different
39categories, different classifiers may be booked and different variables
40may be considered. The aim is to account for differences in the data
41that arise, for example, from different locations or angles.
42*/
43
44
45#include "TMVA/MethodCategory.h"
46
47#include <algorithm>
48#include <vector>
49
50#include "TRandom3.h"
51#include "TH1F.h"
52#include "TSpline.h"
53#include "TDirectory.h"
54#include "TTreeFormula.h"
55
57#include "TMVA/Config.h"
58#include "TMVA/DataSet.h"
59#include "TMVA/DataSetInfo.h"
60#include "TMVA/DataSetManager.h"
61#include "TMVA/IMethod.h"
62#include "TMVA/MethodBase.h"
64#include "TMVA/MsgLogger.h"
65#include "TMVA/PDF.h"
66#include "TMVA/Ranking.h"
67#include "TMVA/Timer.h"
68#include "TMVA/Tools.h"
69#include "TMVA/Types.h"
70#include "TMVA/VariableInfo.h"
72
73REGISTER_METHOD(Category)
74
76
77////////////////////////////////////////////////////////////////////////////////
78/// standard constructor
79
81 const TString& methodTitle,
82 DataSetInfo& theData,
83 const TString& theOption )
84 : TMVA::MethodCompositeBase( jobName, Types::kCategory, methodTitle, theData, theOption),
85 fCatTree(0),
86 fDataSetManager(NULL)
87{
// Standard constructor: intentionally empty body. The composite base is told
// this is a Types::kCategory method; fCatTree and fDataSetManager start null.
// NOTE(review): CreateCategoryDSI() dereferences fDataSetManager without a null
// check, so it must be assigned before any sub-classifier is added — TODO
// confirm where the booking code sets it.
88}
89
90////////////////////////////////////////////////////////////////////////////////
91/// constructor from weight file
92
94 const TString& theWeightFile)
95 : TMVA::MethodCompositeBase( Types::kCategory, dsi, theWeightFile),
96 fCatTree(0),
97 fDataSetManager(NULL)
98{
// Weight-file constructor: intentionally empty body; fCatTree and
// fDataSetManager start null. Sub-classifiers are recreated later by
// ReadWeightsFromXML(), which calls CreateCategoryDSI() and therefore needs
// fDataSetManager to have been assigned by then — TODO confirm where it is set.
99}
100
101////////////////////////////////////////////////////////////////////////////////
102/// destructor
103
105{
// Destructor: deletes every category formula owned by fCatFormulas, then the
// circular tree they were built on. Formulas go first because they were created
// against fCatTree. Deleting a null fCatTree is a no-op.
106 std::vector<TTreeFormula*>::iterator formIt = fCatFormulas.begin();
107 std::vector<TTreeFormula*>::iterator lastF = fCatFormulas.end();
108 for(;formIt!=lastF; ++formIt) delete *formIt;
109 delete fCatTree;
110}
111
112////////////////////////////////////////////////////////////////////////////////
113/// check whether method category has analysis type
114/// the method type has to be the same for all sub-methods
115
117{
// Returns kTRUE only if every booked sub-method supports the requested analysis
// type with the given number of classes/targets. It returns kFALSE at the first
// sub-method that does not. With no sub-methods booked the loop body never runs
// and kTRUE is returned.
118 std::vector<IMethod*>::iterator itrMethod = fMethods.begin();
119
120 // iterate over methods and check whether they have the analysis type
121 for(; itrMethod != fMethods.end(); ++itrMethod ) {
122 if ( !(*itrMethod)->HasAnalysisType(type, numberClasses, numberTargets) )
123 return kFALSE;
124 }
125 return kTRUE;
126}
127
128////////////////////////////////////////////////////////////////////////////////
129/// options for this method
130
132{
// No options of its own: each sub-classifier receives its option string through
// AddMethod().
133}
134
135////////////////////////////////////////////////////////////////////////////////
136/// adds sub-classifier for a category
137
139 const TString& theVariables,
140 Types::EMVA theMethod ,
141 const TString& theTitle,
142 const TString& theOptions )
143{
// Books one sub-classifier of type theMethod for the category selected by
// theCut. It uses the ':'-separated variable labels in theVariables; when that
// string is empty, CreateCategoryDSI uses all primary variables.
// Side effects:
//  - the method is configured like a top-level method and stored in fMethods;
//  - the cut, the variable string and the category-spectator index are stored
//    in parallel vectors;
//  - a 'C'-type category spectator is added to the primary DataSetInfo.
// Returns the new MethodBase, or 0 if the factory result is null or not a
// MethodBase.
144 std::string addedMethodName(Types::Instance().GetMethodName(theMethod).Data());
145
146 Log() << kINFO << "Adding sub-classifier: " << addedMethodName << "::" << theTitle << Endl;
147
148 DataSetInfo& dsi = CreateCategoryDSI(theCut, theVariables, theTitle);
149
150 IMethod* addedMethod = ClassifierFactory::Instance().Create(addedMethodName,GetJobName(),theTitle,dsi,theOptions);
151
152 MethodBase *method = (dynamic_cast<MethodBase*>(addedMethod));
// NOTE(review): on this early return addedMethod is neither deleted nor stored.
// A non-null IMethod that is not a MethodBase therefore appears to leak. The
// category DSI created above also stays registered, and its varMap has already
// been appended to fVarMaps, so fVarMaps no longer lines up with fMethods.
// TODO confirm ownership.
153 if(method==0) return 0;
154
155 if(fModelPersistence) method->SetWeightFileDir(fFileDir);
156 method->SetModelPersistence(fModelPersistence);
157 method->SetAnalysisType( fAnalysisType );
158 method->SetupMethod();
159 method->ParseOptions();
160 method->ProcessSetup();
161 method->SetFile(fFile);
162 method->SetSilentFile(IsSilentFile());
163
164
165 // set or create correct method base dir for added method
166 const TString dirName = TString::Format("Method_%s",method->GetMethodTypeName().Data());
167 TDirectory * dir = BaseDir()->GetDirectory(dirName);
168 if (dir != 0) method->SetMethodBaseDir( dir );
169 else method->SetMethodBaseDir( BaseDir()->mkdir(dirName, TString::Format("Directory for all %s methods", method->GetMethodTypeName().Data())) );
170
171 // method->SetBaseDir(own base dir, check whether the Fisher dir exists, otherwise create it )
172
173 // check-for-unused-options is performed; may be overridden by derived
174 // classes
175 method->CheckSetup();
176
177 // disable writing of XML files and standalone classes for sub methods
178 method->DisableWriting( kTRUE );
179
180 // store method, cut and variable names and create cut formula
181 fMethods.push_back(method);
182 fCategoryCuts.push_back(theCut);
183 fVars.push_back(theVariables);
184
185 DataSetInfo& primaryDSI = DataInfo();
186
// The recorded index is the spectator count before AddSpectator, i.e. the
// index the new category spectator receives.
187 UInt_t newSpectatorIndex = primaryDSI.GetSpectatorInfos().size();
188 fCategorySpecIdx.push_back(newSpectatorIndex);
189
// The spectator string is "<name>_cat<k>:=<cut>", where k = fMethods.size()
// after the push_back above, i.e. a 1-based category number.
// ReadWeightsFromXML looks for a spectator labelled "<name>_cat<k>" using the
// same numbering.
190 primaryDSI.AddSpectator( TString::Format("%s_cat%i:=%s", GetName(),(int)fMethods.size(),theCut.GetTitle()).Data(),
191 TString::Format("%s:%s",GetName(),method->GetName()).Data(),
192 "pass", 0, 0, 'C' );
193
194 return method;
195}
196
197////////////////////////////////////////////////////////////////////////////////
198/// create a DataSetInfo object for a sub-classifier
199
201 const TString& theVariables,
202 const TString& theTitle)
203{
// Builds the DataSetInfo for one category. The new DSI is named
// theTitle+"_dsi" and is registered with fDataSetManager. It:
//  - copies the primary DSI's targets and spectators;
//  - adds each ':'-separated label in theVariables, looked up among the
//    primary variables and then among the spectators;
//  - copies each class together with its cut and weight expression, and adds
//    theCut via AddCut;
//  - appends ":ScaleWithPreselEff" to the split options.
// The matching index map is appended to fVarMaps. The DSI is heap-allocated
// and only a reference is returned; ownership presumably passes to the
// DataSetManager — TODO confirm.
204 // create a new dsi with name: theTitle+"_dsi"
205 TString dsiName=theTitle+"_dsi";
206 DataSetInfo& oldDSI = DataInfo();
207 DataSetInfo* dsi = new DataSetInfo(dsiName);
208
209 // register the new dsi
210 // DataSetManager::Instance().AddDataSetInfo(*dsi); // DSMTEST replaced by following line
211 fDataSetManager->AddDataSetInfo(*dsi);
212
213 // copy the targets and spectators from the old dsi to the new dsi
214 std::vector<VariableInfo>::iterator itrVarInfo;
215
216 for (itrVarInfo = oldDSI.GetTargetInfos().begin(); itrVarInfo != oldDSI.GetTargetInfos().end(); ++itrVarInfo)
217 dsi->AddTarget(*itrVarInfo);
218
219 for (itrVarInfo = oldDSI.GetSpectatorInfos().begin(); itrVarInfo != oldDSI.GetSpectatorInfos().end(); ++itrVarInfo)
220 dsi->AddSpectator(*itrVarInfo);
221
222 // split string that contains the variables into tiny little pieces
223 std::vector<TString> variables = gTools().SplitString(theVariables,':' );
224
225 // prepare to create varMap
226 std::vector<UInt_t> varMap;
227 Int_t counter=0;
228
229 // add the variables that were specified in theVariables
230 std::vector<TString>::iterator itrVariables;
231 Bool_t found = kFALSE;
232
233 // iterate over all variables in 'variables' and add them
234 for (itrVariables = variables.begin(); itrVariables != variables.end(); ++itrVariables) {
235 counter=0;
236
237 // check the variables of the old dsi for the variable that we want to add
238 for (itrVarInfo = oldDSI.GetVariableInfos().begin(); itrVarInfo != oldDSI.GetVariableInfos().end(); ++itrVarInfo) {
239 if((*itrVariables==itrVarInfo->GetLabel()) ) { // || (*itrVariables==itrVarInfo->GetExpression())) {
240 // don't compare the expression, since the user might take two times the same expression, but with different labels
241 // and apply different transformations to the variables.
242 dsi->AddVariable(*itrVarInfo);
243 varMap.push_back(counter);
244 found = kTRUE;
245 }
246 counter++;
247 }
248
// counter is NOT reset before this loop, so a spectator's varMap entry is
// (number of primary variables + spectator index). This is presumably because
// the event layout holds variables followed by spectators — TODO confirm.
// A label that matches several entries adds each of them.
249 // check the spectators of the old dsi for the variable that we want to add
250 for (itrVarInfo = oldDSI.GetSpectatorInfos().begin(); itrVarInfo != oldDSI.GetSpectatorInfos().end(); ++itrVarInfo) {
251 if((*itrVariables==itrVarInfo->GetLabel()) ) { // || (*itrVariables==itrVarInfo->GetExpression())) {
252 // don't compare the expression, since the user might take two times the same expression, but with different labels
253 // and apply different transformations to the variables.
254 dsi->AddVariable(*itrVarInfo);
255 varMap.push_back(counter);
256 found = kTRUE;
257 }
258 counter++;
259 }
260
261 // if the variable is neither in the variables nor in the spectators, we abort
262 if (!found) {
263 Log() << kFATAL <<"The variable " << itrVariables->Data() << " was not found and could not be added " << Endl;
264 }
265 found = kFALSE;
266 }
267
// NOTE(review): this branch assumes SplitString("") produced no entries above.
// Otherwise the label loop would already have hit the kFATAL for an empty
// label — TODO confirm.
268 // in the case that no variables are specified, add the default-variables from the original dsi
269 if (theVariables=="") {
270 for (UInt_t i=0; i<oldDSI.GetVariableInfos().size(); i++) {
271 dsi->AddVariable(oldDSI.GetVariableInfos()[i]);
272 varMap.push_back(i);
273 }
274 }
275
276 // add the variable map 'varMap' to the vector of varMaps
277 fVarMaps.push_back(varMap);
278
279 // set classes and cuts
280 UInt_t nClasses=oldDSI.GetNClasses();
281 TString className;
282
283 for (UInt_t i=0; i<nClasses; i++) {
284 className = oldDSI.GetClassInfo(i)->GetName();
285 dsi->AddClass(className);
286 dsi->SetCut(oldDSI.GetCut(i),className);
287 dsi->AddCut(theCut,className);
288 dsi->SetWeightExpression(oldDSI.GetWeightExpression(i),className);
289 }
290
291 // set split options, root dir and normalization for the new dsi
292 dsi->SetSplitOptions(oldDSI.GetSplitOptions());
293 dsi->SetRootDir(oldDSI.GetRootDir());
294 TString norm(oldDSI.GetNormalization().Data());
295 dsi->SetNormalization(norm);
296 // need to add split options to normalize with cut efficiency
297 TString splitOpt = dsi->GetSplitOptions();
298 splitOpt += ":ScaleWithPreselEff";
299 dsi->SetSplitOptions(splitOpt);
300
301 DataSetInfo& dsiReference= (*dsi);
302
303 return dsiReference;
304}
305
306////////////////////////////////////////////////////////////////////////////////
307/// initialize the method
308
310{
// Nothing to initialise here.
311}
312
313////////////////////////////////////////////////////////////////////////////////
314/// initialize the circular tree
315
317{
// (Re)builds fCatTree, a circular TTree (SetCircular(1)) whose branches point
// at the externally linked buffers of the given DSI's variables and non-'C'
// spectators. It then creates one TTreeFormula per category cut on that tree.
// If any variable or spectator lacks an external link, the function returns
// early with fCatTree left null. PassesCut() then falls back to the 'C'
// category spectators.
// NOTE(review): fCatFormulas is neither cleared nor are its entries deleted
// here. Any formulas from a previous call still reference the tree deleted just
// below, and the new formulas are appended after them. PassesCut() would then
// index stale formulas. This only matters if the function runs more than once
// — TODO confirm.
318 delete fCatTree;
319 fCatTree = nullptr;
320
321 std::vector<VariableInfo>::const_iterator viIt;
322 const std::vector<VariableInfo>& vars = dsi.GetVariableInfos();
323 const std::vector<VariableInfo>& specs = dsi.GetSpectatorInfos();
324
325 Bool_t hasAllExternalLinks = kTRUE;
326 for (viIt = vars.begin(); viIt != vars.end(); ++viIt)
327 if( viIt->GetExternalLink() == 0 ) {
328 hasAllExternalLinks = kFALSE;
329 break;
330 }
331 for (viIt = specs.begin(); viIt != specs.end(); ++viIt)
332 if( viIt->GetExternalLink() == 0 ) {
333 hasAllExternalLinks = kFALSE;
334 break;
335 }
336
337 if(!hasAllExternalLinks) return;
338
339 {
340 // Rather than having TTree::TTree add to the current directory and then remove it, let's
341 // make sure to not add it in the first place.
342 // The add-then-remove can lead to a problem if gDirectory points to the same directory (for example
343 // gROOT) in the current thread and another one (and both try to add to the directory at the same time).
344 TDirectory::TContext ctxt(nullptr);
345 fCatTree = new TTree(TString::Format("Circ%s",GetMethodName().Data()).Data(),"Circular Tree for categorization");
346 fCatTree->SetCircular(1);
347 }
348
// Branch addresses are the external links reinterpreted as Float_t*, with leaf
// type "/F". This assumes the external storage is float — TODO confirm.
349 for (viIt = vars.begin(); viIt != vars.end(); ++viIt) {
350 const VariableInfo& vi = *viIt;
351 fCatTree->Branch(vi.GetExpression(),(Float_t*)vi.GetExternalLink(), TString(vi.GetExpression())+TString("/F"));
352 }
// 'C'-type spectators (the category flags added by AddMethod) get no branch.
353 for (viIt = specs.begin(); viIt != specs.end(); ++viIt) {
354 const VariableInfo& vi = *viIt;
355 if(vi.GetVarType()=='C') continue;
356 fCatTree->Branch(vi.GetExpression(),(Float_t*)vi.GetExternalLink(), TString(vi.GetExpression())+TString("/F"));
357 }
358
359 for(UInt_t cat=0; cat!=fCategoryCuts.size(); ++cat) {
360 fCatFormulas.push_back(new TTreeFormula(TString::Format("Category_%i",cat).Data(), fCategoryCuts[cat].GetTitle(), fCatTree));
361 }
362}
363
364////////////////////////////////////////////////////////////////////////////////
365/// train all sub-classifiers
366
368{
// Trains every booked sub-classifier for the current analysis type.
// Sub-methods that cannot handle it are logged and removed from fMethods.
// Sub-methods with too few training events produce a kERROR/kFATAL. For
// non-regression analyses it then prints the variable ranking of each trained
// sub-method.
369 // specify the minimum # of training events and set 'classification'
370 const Int_t MinNoTrainingEvents = 10;
371
372 Types::EAnalysisType analysisType = GetAnalysisType();
373
374 // start the training
375 Log() << kINFO << "Train all sub-classifiers for "
376 << (analysisType == Types::kRegression ? "Regression" : "Classification") << " ..." << Endl;
377
378 // don't do anything if no sub-classifier booked
379 if (fMethods.empty()) {
380 Log() << kINFO << "...nothing found to train" << Endl;
381 return;
382 }
383
384 std::vector<IMethod*>::iterator itrMethod;
385
386 // iterate over all booked sub-classifiers and train them
387 for (itrMethod = fMethods.begin(); itrMethod != fMethods.end(); ++itrMethod ) {
388
389 MethodBase* mva = dynamic_cast<MethodBase*>(*itrMethod);
390 if(!mva) continue;
391 mva->SetAnalysisType( analysisType );
392 if (!mva->HasAnalysisType( analysisType,
393 mva->DataInfo().GetNClasses(),
394 mva->DataInfo().GetNTargets() ) ) {
395 Log() << kWARNING << "Method " << mva->GetMethodTypeName() << " is not capable of handling " ;
396 if (analysisType == Types::kRegression)
397 Log() << "regression with " << mva->DataInfo().GetNTargets() << " targets." << Endl;
398 else
399 Log() << "classification with " << mva->DataInfo().GetNClasses() << " classes." << Endl;
// NOTE(review): erase() already returns the iterator to the next element, and
// the loop's ++itrMethod then advances once more. The method after an erased
// one is therefore never checked or trained. If the erased method was the last
// one, ++ is applied to end(), which is undefined behaviour. The parallel
// vectors fCategoryCuts, fVars, fVarMaps and fCategorySpecIdx are also not
// erased, so the indices used by PassesCut()/GetMvaValue() no longer line up
// with fMethods.
400 itrMethod = fMethods.erase( itrMethod );
401 continue;
402 }
// NOTE(review): original line 403 is missing from this listing. It is the
// opening `if` that pairs with the `} else {` below. Judging by the else-branch
// message, it presumably tests
// mva->Data()->GetNTrainingEvents() >= MinNoTrainingEvents.
404
405 Log() << kINFO << "Train method: " << mva->GetMethodName() << " for "
406 << (analysisType == Types::kRegression ? "Regression" : "Classification") << Endl;
407 mva->TrainMethod();
408 Log() << kINFO << "Training finished" << Endl;
409
410 } else {
411
412 Log() << kWARNING << "Method " << mva->GetMethodName()
413 << " not trained (training tree has less entries ["
414 << mva->Data()->GetNTrainingEvents()
415 << "] than required [" << MinNoTrainingEvents << "]" << Endl;
416
417 Log() << kERROR << " w/o training/test events for that category, I better stop here and let you fix " << Endl;
418 Log() << kFATAL << "that one first, otherwise things get too messy later ... " << Endl;
419
420 }
421 }
422
423 if (analysisType != Types::kRegression) {
424
425 // variable ranking
426 Log() << kINFO << "Begin ranking of input variables..." << Endl;
427 for (itrMethod = fMethods.begin(); itrMethod != fMethods.end(); ++itrMethod) {
428 MethodBase* mva = dynamic_cast<MethodBase*>(*itrMethod);
429 if (mva && mva->Data()->GetNTrainingEvents() >= MinNoTrainingEvents) {
// NOTE(review): the Ranking returned by CreateRanking() is not deleted here —
// TODO confirm whether the sub-method retains ownership.
430 const Ranking* ranking = (*itrMethod)->CreateRanking();
431 if (ranking != 0)
432 ranking->Print();
433 else
434 Log() << kINFO << "No variable ranking supplied by classifier: "
435 << dynamic_cast<MethodBase*>(*itrMethod)->GetMethodName() << Endl;
436 }
437 }
438 }
439}
440
441////////////////////////////////////////////////////////////////////////////////
442/// create XML description of Category classifier
443
445{
// Writes a "Weights" node with attribute NSubMethods. Under it goes one
// "SubMethod" child per sub-classifier, carrying Index, "Type::Name" (Method),
// Cut and Variables attributes, followed by that sub-method's own XML state.
// The dynamic_cast result is used unchecked. Entries pushed by AddMethod() and
// ReadWeightsFromXML() are always MethodBase instances.
446 void* wght = gTools().AddChild(parent, "Weights");
447 gTools().AddAttr( wght, "NSubMethods", fMethods.size() );
448 void* submethod(0);
449
450 // iterate over methods and write them to XML file
451 for (UInt_t i=0; i<fMethods.size(); i++) {
452 MethodBase* method = dynamic_cast<MethodBase*>(fMethods[i]);
453 submethod = gTools().AddChild(wght, "SubMethod");
454 gTools().AddAttr(submethod, "Index", i);
455 gTools().AddAttr(submethod, "Method", method->GetMethodTypeName() + "::" + method->GetMethodName());
456 gTools().AddAttr(submethod, "Cut", fCategoryCuts[i]);
457 gTools().AddAttr(submethod, "Variables", fVars[i]);
458 method->WriteStateToXML( submethod );
459 }
460}
461
462////////////////////////////////////////////////////////////////////////////////
463/// read weights of sub-classifiers of MethodCategory from xml weight file
464
466{
// Recreates the sub-classifiers from the node written by AddWeightsXMLTo().
// For each "SubMethod" child it:
//  - splits the "Type::Title" attribute into type and title;
//  - rebuilds the category DSI;
//  - creates the method through the factory and restores its state;
//  - locates the matching "<name>_cat<k>" spectator in the primary DSI.
// It finally builds the circular tree for formula-based category cuts.
467 UInt_t nSubMethods;
468 TString fullMethodName;
469 TString methodType;
470 TString methodTitle;
471 TString theCutString;
472 TString theVariables;
473 Int_t titleLength;
474 gTools().ReadAttr( wghtnode, "NSubMethods", nSubMethods );
475 void* subMethodNode = gTools().GetChild(wghtnode);
476
477 Log() << kINFO << "Recreating sub-classifiers from XML-file " << Endl;
478
479 // recreate all sub-methods from weight file
480 for (UInt_t i=0; i<nSubMethods; i++) {
481 gTools().ReadAttr( subMethodNode, "Method", fullMethodName );
482 gTools().ReadAttr( subMethodNode, "Cut", theCutString );
483 gTools().ReadAttr( subMethodNode, "Variables", theVariables );
484
// The parsing below assumes the "Type::Title" form written by AddWeightsXMLTo.
// It takes the type as the last space-separated word before "::".
485 // determine sub-method type
486 methodType = fullMethodName(0,fullMethodName.Index("::"));
487 if (methodType.Contains(" ")) methodType = methodType(methodType.Last(' ')+1,methodType.Length());
488
489 // determine sub-method title
490 titleLength = fullMethodName.Length()-fullMethodName.Index("::")-2;
491 methodTitle = fullMethodName(fullMethodName.Index("::")+2,titleLength);
492
493 // reconstruct dsi for sub-method
494 DataSetInfo& dsi = CreateCategoryDSI(TCut(theCutString), theVariables, methodTitle);
495
496 // recreate sub-method from weights and add to fMethods
497 MethodBase* method = dynamic_cast<MethodBase*>( ClassifierFactory::Instance().Create( methodType.Data(),
498 dsi, "none" ) );
// NOTE(review): this message streams `method`, which is null at this point.
// It therefore reports a null pointer instead of methodType, the type that
// failed. If kFATAL did not terminate, the next line would dereference null.
499 if(method==0)
500 Log() << kFATAL << "Could not create sub-method " << method << " from XML." << Endl;
501
502 method->SetupMethod();
503 method->ReadStateFromXML(subMethodNode);
504
505 fMethods.push_back(method);
506 fCategoryCuts.push_back(TCut(theCutString));
507 fVars.push_back(theVariables);
508
509 DataSetInfo& primaryDSI = DataInfo();
510
511 UInt_t spectatorIdx = 10000;
512 UInt_t counter=0;
513
// NOTE(review): if no spectator matches, nothing is pushed to fCategorySpecIdx
// and the 10000 default is never used. fCategorySpecIdx then ends up shorter
// than fMethods, and every later "<name>_cat<k>" lookup is shifted by one.
514 // find the spectator index
515 std::vector<VariableInfo>& spectators=primaryDSI.GetSpectatorInfos();
516 std::vector<VariableInfo>::iterator itrVarInfo;
517 TString specName= TString::Format("%s_cat%i", GetName(),(int)fCategorySpecIdx.size()+1);
518
519 for (itrVarInfo = spectators.begin(); itrVarInfo != spectators.end(); ++itrVarInfo, ++counter) {
520 if((specName==itrVarInfo->GetLabel()) || (specName==itrVarInfo->GetExpression())) {
521 spectatorIdx=counter;
522 fCategorySpecIdx.push_back(spectatorIdx);
523 break;
524 }
525 }
526
527 subMethodNode = gTools().GetNextChild(subMethodNode);
528 }
529
530 InitCircularTree(DataInfo());
531
532}
533
534////////////////////////////////////////////////////////////////////////////////
535/// process user options
536
538{
// No options to process for the category method itself.
539}
540
541////////////////////////////////////////////////////////////////////////////////
542/// Get help message text
543///
544/// typical length of text line:
545/// "|--------------------------------------------------------------|"
546
548{
// Prints a short, fixed description of the category method to the logger.
549 Log() << Endl;
550 Log() << gTools().Color("bold") << "--- Short description:" << gTools().Color("reset") << Endl;
551 Log() << Endl;
552 Log() << "This method allows to define different categories of events. The" <<Endl;
553 Log() << "categories are defined via cuts on the variables. For each" << Endl;
554 Log() << "category, a different classifier and set of variables can be" <<Endl;
555 Log() << "specified. The categories which are defined for this method must" << Endl;
556 Log() << "be disjoint." << Endl;
557}
558
559////////////////////////////////////////////////////////////////////////////////
560/// no ranking
561
563{
// The category method supplies no ranking of its own; it returns a null
// pointer. Train() prints the rankings of the individual sub-classifiers.
564 return 0;
565}
566
567////////////////////////////////////////////////////////////////////////////////
568
570{
// Returns whether the current event belongs to category methodIdx.
// There are two modes:
//  - fCatTree set: the event is tested with the category's TTreeFormula.
//    `ev` is unused on this path; the formula reads whatever values currently
//    sit in the externally linked buffers, which are assumed to hold this
//    event — TODO confirm.
//  - fCatTree null: the event's 'C' category spectator must exceed 0.5.
// NOTE(review): the "maximum allowed index=" message prints the vector size,
// one more than the largest valid index.
571 // If the categories are not defined by simple 0/1 'spectator' variables but by
572 // formulas (e.g. eta>0), those formulas are stored on fCatTree and evaluated
573 // here; each formula yields 'true' or 'false'.
574 if (fCatTree) {
575 if (methodIdx>=fCatFormulas.size()) {
576 Log() << kFATAL << "Large method index " << methodIdx << ", number of category formulas = "
577 << fCatFormulas.size() << Endl;
578 }
579 TTreeFormula* f = fCatFormulas[methodIdx];
580 return f->EvalInstance(0) > 0.5;
581 }
582 // Otherwise, check whether the category spectator is 'true' (> 0.5, to be safe with float values).
583 else {
584
585 // checks whether an event lies within a cut
586 if (methodIdx>=fCategorySpecIdx.size()) {
587 Log() << kFATAL << "Unknown method index " << methodIdx << " maximum allowed index="
588 << fCategorySpecIdx.size() << Endl;
589 }
590 UInt_t spectatorIdx = fCategorySpecIdx[methodIdx];
591 Float_t specVal = ev->GetSpectator(spectatorIdx);
592 Bool_t pass = (specVal>0.5);
593 return pass;
594 }
595}
596
597////////////////////////////////////////////////////////////////////////////////
598/// returns the mva value of the right sub-classifier
599
601{
// Returns the MVA value of the single sub-classifier whose category contains
// the current event. The event's variable arrangement is switched to that
// category's map first. It returns 0 when no sub-classifier is booked or no
// category matches. Overlapping categories are reported with kFATAL.
602 if (fMethods.empty()) return 0;
603
604 UInt_t methodToUse = 0;
605 const Event* ev = GetEvent();
606
607 // determine which sub-classifier to use for this event
608 Int_t suitableCutsN = 0;
609
610 for (UInt_t i=0; i<fMethods.size(); ++i) {
611 if (PassesCut(ev, i)) {
612 ++suitableCutsN;
613 methodToUse=i;
614 }
615 }
616
617 if (suitableCutsN == 0) {
618 Log() << kWARNING << "Event does not lie within the cut of any sub-classifier." << Endl;
619 return 0;
620 }
621
622 if (suitableCutsN > 1) {
623 Log() << kFATAL << "The defined categories are not disjoint." << Endl;
624 return 0;
625 }
626
627 // get mva value from the suitable sub-classifier
628 ev->SetVariableArrangement(&fVarMaps[methodToUse]);
629 Double_t mvaValue = dynamic_cast<MethodBase*>(fMethods[methodToUse])->GetMvaValue(ev,err,errUpper);
// NOTE(review): original line 630 is absent from this listing. From here it
// cannot be confirmed whether the arrangement set above is reset afterwards;
// GetMvaValues() and GetMulticlassValues() reset theirs with
// SetVariableArrangement(nullptr).
631
// NOTE(review): every operand of this debug statement is evaluated whatever the
// log level. fVarMaps[0][0] is therefore always read, and it is out of range
// if the first category's variable map is empty.
632 Log() << kDEBUG << "Event is for method " << methodToUse << " spectator is " << ev->GetSpectator(0) << " "
633 << fVarMaps[0][0] << " classID " << DataInfo().IsSignal(ev) << " value " << mvaValue
634 << " type " << Data()->GetCurrentType() << Endl;
635
636 return mvaValue;
637}
638
639///////////////////////////////////////////////////////////////
640/// returns the mva values of the right sub-classifier
641///
642std::vector<Double_t>
644{
// Batch evaluation over events [firstEvt, lastEvt) of the current data set.
// Every sub-classifier is evaluated on the whole range, with each event's
// variable arrangement set to that category's map. Each event then takes the
// value from the category it falls into. The result holds lastEvt - firstEvt
// entries.
// NOTE(review): the location string "GetMVaValues" passed to Info() is a typo
// for "GetMvaValues".
645
646 std::vector<Double_t> result;
647
648 Info("GetMVaValues", "Evaluate MethodCategory for %d events type %d on the dataset %s", int(lastEvt - firstEvt),
649 (int)Data()->GetCurrentType(), DataInfo().GetName());
650
651 if (fMethods.empty())
652 return result;
653
654 auto data = Data();
655
656 // it is faster to evaluate all categories
657 std::vector<std::vector<Double_t>> mvaValues(fMethods.size());
658 for (UInt_t i = 0; i < fMethods.size(); ++i) {
659 // need to set variable map
660 for (UInt_t iev = firstEvt; iev < lastEvt; ++iev) {
661 data->SetCurrentEvent(iev);
662 const Event *ev = GetEvent(data->GetEvent());
663 ev->SetVariableArrangement(&fVarMaps[i]);
664 }
665 // need to set correct data in the different method
666 mvaValues[i] = dynamic_cast<MethodBase *>(fMethods[i])->GetDataMvaValues(data,firstEvt, lastEvt, logProgress);
667 }
668
669 // now loop on all events
670 result.resize(lastEvt - firstEvt);
671
672 for (UInt_t iev = firstEvt; iev < lastEvt; ++iev)
673 {
674 data->SetCurrentEvent(iev);
675 UInt_t methodToUse = 0;
676 const Event *ev = GetEvent(data->GetEvent());
677
678 // determine which sub-classifier to use for this event
679 Int_t suitableCutsN = 0;
680
681 for (UInt_t i = 0; i < fMethods.size(); ++i) {
682 if (PassesCut(ev, i)) {
683 ++suitableCutsN;
684 methodToUse = i;
685 }
686 }
687
// NOTE(review): result[iev] uses the absolute event index, but result only has
// lastEvt - firstEvt entries. It writes out of range whenever firstEvt > 0; the
// correct slot is iev - firstEvt. There is also no `continue`, so line 699 then
// overwrites the slot with category 0's value (methodToUse is still 0). An
// uncategorised event thus gets category 0's value rather than the 0 that
// GetMvaValue() returns.
688 if (suitableCutsN == 0) {
689 Log() << kWARNING << "Event does not lie within the cut of any sub-classifier." << Endl;
690 result[iev] = 0;
691 }
692
// NOTE(review): this early return yields a partially filled result and leaves
// the current event's variable arrangement set.
693 if (suitableCutsN > 1) {
694 Log() << kFATAL << "The defined categories are not disjoint." << Endl;
695 return result;
696 }
697
698
699 result[iev - firstEvt] = mvaValues[methodToUse][iev - firstEvt];
700
701 // reset variable map which was set it before
702 ev->SetVariableArrangement(nullptr);
703 }
704 return result;
705}
706
707////////////////////////////////////////////////////////////////////////////////
708/// returns the mva values of the multi-class right sub-classifier
709///
710const std::vector<Float_t> &TMVA::MethodCategory::GetMulticlassValues()
711{
// Returns the multiclass response of the sub-classifier whose category
// contains the current event. The variable arrangement is set to that
// category's map for the call and reset to nullptr afterwards. The returned
// reference refers to the sub-method's result.
// NOTE(review): original lines 713, 730, 735 and 740 are absent from this
// listing (presumably early returns). As displayed, the `if` at 712 has no
// statement of its own.
// NOTE(review): the sub-method's GetMulticlassValues() receives no event
// argument. The arrangement set on `ev` only has an effect if the sub-method
// reads the same Event object — TODO confirm.
// The kFATAL text below says "Regression" although this is the multiclass
// path; it is left unchanged here because it is runtime output.
712 if (fMethods.empty())
714
715 UInt_t methodToUse = 0;
716 const Event *ev = GetEvent();
717
718 // determine which sub-classifier to use for this event
719 Int_t suitableCutsN = 0;
720
721 for (UInt_t i = 0; i < fMethods.size(); ++i) {
722 if (PassesCut(ev, i)) {
723 ++suitableCutsN;
724 methodToUse = i;
725 }
726 }
727
728 if (suitableCutsN == 0) {
729 Log() << kWARNING << "Event does not lie within the cut of any sub-classifier." << Endl;
731 }
732
733 if (suitableCutsN > 1) {
734 Log() << kFATAL << "The defined categories are not disjoint." << Endl;
736 }
737 MethodBase *meth = dynamic_cast<MethodBase *>(fMethods[methodToUse]);
738 if (!meth) {
739 Log() << kFATAL << "method not found in Category Regression method" << Endl;
741 }
742 // get mva value from the suitable sub-classifier
743 ev->SetVariableArrangement(&fVarMaps[methodToUse]);
744 auto &result = meth->GetMulticlassValues();
745 ev->SetVariableArrangement(nullptr);
746 return result;
747}
748
749////////////////////////////////////////////////////////////////////////////////
750/// returns the mva value of the right sub-classifier
751
752const std::vector<Float_t> &TMVA::MethodCategory::GetRegressionValues()
753{
// Returns the regression targets computed by the sub-classifier whose category
// contains the current event. With no sub-classifier booked it defers to
// MethodBase::GetRegressionValues().
// NOTE(review): original lines 771, 776 and 781 are absent from this listing
// (presumably early returns after the warnings/fatals).
754 if (fMethods.empty()) return MethodBase::GetRegressionValues();
755
756 UInt_t methodToUse = 0;
757 const Event* ev = GetEvent();
758
759 // determine which sub-classifier to use for this event
760 Int_t suitableCutsN = 0;
761
762 for (UInt_t i=0; i<fMethods.size(); ++i) {
763 if (PassesCut(ev, i)) {
764 ++suitableCutsN;
765 methodToUse=i;
766 }
767 }
768
769 if (suitableCutsN == 0) {
770 Log() << kWARNING << "Event does not lie within the cut of any sub-classifier." << Endl;
772 }
773
774 if (suitableCutsN > 1) {
775 Log() << kFATAL << "The defined categories are not disjoint." << Endl;
777 }
778 MethodBase* meth = dynamic_cast<MethodBase*>(fMethods[methodToUse]);
779 if (!meth){
780 Log() << kFATAL << "method not found in Category Regression method" << Endl;
782 }
// NOTE(review): unlike GetMulticlassValues(), the arrangement set here is never
// reset with SetVariableArrangement(nullptr). The event keeps this category's
// variable mapping after the call returns.
783 // get mva value from the suitable sub-classifier
784 ev->SetVariableArrangement(&fVarMaps[methodToUse]);
785 auto & result = meth->GetRegressionValues(ev);
786 return result;
787}
#define MinNoTrainingEvents
#define REGISTER_METHOD(CLASS)
for example
#define f(i)
Definition RSha256.hxx:104
size_t size(const MatrixT &matrix)
retrieve the size of a square matrix
float Float_t
Definition RtypesCore.h:57
constexpr Bool_t kFALSE
Definition RtypesCore.h:101
long long Long64_t
Definition RtypesCore.h:80
constexpr Bool_t kTRUE
Definition RtypesCore.h:100
#define ClassImp(name)
Definition Rtypes.h:377
void Info(const char *location, const char *msgfmt,...)
Use this function for informational messages.
Definition TError.cxx:218
Option_t Option_t TPoint TPoint const char GetTextMagnitude GetFillStyle GetLineColor GetLineWidth GetMarkerStyle GetTextAlign GetTextColor GetTextSize void data
Option_t Option_t TPoint TPoint const char GetTextMagnitude GetFillStyle GetLineColor GetLineWidth GetMarkerStyle GetTextAlign GetTextColor GetTextSize void char Point_t Rectangle_t WindowAttributes_t Float_t Float_t Float_t Int_t Int_t UInt_t UInt_t Rectangle_t result
Option_t Option_t TPoint TPoint const char GetTextMagnitude GetFillStyle GetLineColor GetLineWidth GetMarkerStyle GetTextAlign GetTextColor GetTextSize void char Point_t Rectangle_t WindowAttributes_t Float_t Float_t Float_t Int_t Int_t UInt_t UInt_t Rectangle_t Int_t Int_t Window_t TString Int_t GCValues_t GetPrimarySelectionOwner GetDisplay GetScreen GetColormap GetNativeEvent const char const char dpyName wid window const char font_name cursor keysym reg const char only_if_exist regb h Point_t winding char text const char depth char const char Int_t count const char ColorStruct_t color const char Pixmap_t Pixmap_t PictureAttributes_t attr const char char ret_data h unsigned char height h Atom_t Int_t ULong_t ULong_t unsigned char prop_list Atom_t Atom_t Atom_t Time_t type
A specialized string object used for TTree selections.
Definition TCut.h:25
TDirectory::TContext keeps track and restore the current directory.
Definition TDirectory.h:89
Describe directory structure in memory.
Definition TDirectory.h:45
virtual TDirectory * GetDirectory(const char *namecycle, Bool_t printError=false, const char *funcname="GetDirectory")
Find a directory using apath.
IMethod * Create(const std::string &name, const TString &job, const TString &title, DataSetInfo &dsi, const TString &option)
creates the method if needed based on the method name using the creator function the factory has stor...
static ClassifierFactory & Instance()
access to the ClassifierFactory singleton creates the instance if needed
virtual void ParseOptions()
options parser
Class that contains all the data information.
Definition DataSetInfo.h:62
const TString GetWeightExpression(Int_t i) const
std::vector< VariableInfo > & GetVariableInfos()
VariableInfo & AddVariable(const TString &expression, const TString &title="", const TString &unit="", Double_t min=0, Double_t max=0, char varType='F', Bool_t normalized=kTRUE, void *external=nullptr)
add a variable (can be a complex expression) to the set of variables used in the MV analysis
void SetSplitOptions(const TString &so)
ClassInfo * AddClass(const TString &className)
const TString & GetNormalization() const
std::vector< VariableInfo > & GetSpectatorInfos()
TDirectory * GetRootDir() const
void SetNormalization(const TString &norm)
UInt_t GetNClasses() const
const TString & GetSplitOptions() const
UInt_t GetNTargets() const
ClassInfo * GetClassInfo(Int_t clNum) const
VariableInfo & AddTarget(const TString &expression, const TString &title, const TString &unit, Double_t min, Double_t max, Bool_t normalized=kTRUE, void *external=nullptr)
add a variable (can be a complex expression) to the set of variables used in the MV analysis
const TCut & GetCut(Int_t i) const
void SetCut(const TCut &cut, const TString &className)
set the cut for the classes
std::vector< VariableInfo > & GetTargetInfos()
void SetRootDir(TDirectory *d)
VariableInfo & AddSpectator(const TString &expression, const TString &title, const TString &unit, Double_t min, Double_t max, char type='F', Bool_t normalized=kTRUE, void *external=nullptr)
add a spectator (can be a complex expression) to the set of spectator variables used in the MV analys...
void SetWeightExpression(const TString &exp, const TString &className="")
set the weight expressions for the classes if class name is specified, set only for this class if cla...
void AddCut(const TCut &cut, const TString &className)
set the cut for the classes
Long64_t GetNTrainingEvents() const
Definition DataSet.h:68
void SetVariableArrangement(std::vector< UInt_t > *const m) const
set the variable arrangement
Definition Event.cxx:191
Float_t GetSpectator(UInt_t ivar) const
return spectator content
Definition Event.cxx:261
Interface for all concrete MVA method implementations.
Definition IMethod.h:53
virtual Bool_t HasAnalysisType(Types::EAnalysisType type, UInt_t numberClasses, UInt_t numberTargets)=0
Virtual base Class for all MVA method.
Definition MethodBase.h:111
virtual const std::vector< Float_t > & GetRegressionValues()
Definition MethodBase.h:221
const std::vector< Float_t > & GetRegressionValues(const TMVA::Event *const ev)
Definition MethodBase.h:214
void SetSilentFile(Bool_t status)
Definition MethodBase.h:378
void SetWeightFileDir(TString fileDir)
set directory of weight file
void WriteStateToXML(void *parent) const
general method used in writing the header of the weight files where the used variables,...
TString GetMethodTypeName() const
Definition MethodBase.h:332
void DisableWriting(Bool_t setter)
Definition MethodBase.h:442
const char * GetName() const
Definition MethodBase.h:334
virtual const std::vector< Float_t > & GetMulticlassValues()
Definition MethodBase.h:227
void SetupMethod()
setup of methods
virtual void SetAnalysisType(Types::EAnalysisType type)
Definition MethodBase.h:436
const TString & GetMethodName() const
Definition MethodBase.h:331
void ProcessSetup()
process all options the "CheckForUnusedOptions" is done in an independent call, since it may be overr...
DataSetInfo & DataInfo() const
Definition MethodBase.h:410
void SetFile(TFile *file)
Definition MethodBase.h:375
void ReadStateFromXML(void *parent)
friend class MethodCategory
Definition MethodBase.h:269
void SetMethodBaseDir(TDirectory *methodDir)
Definition MethodBase.h:374
DataSet * Data() const
Definition MethodBase.h:409
void SetModelPersistence(Bool_t status)
Definition MethodBase.h:382
virtual void CheckSetup()
check may be overridden by derived class (sometimes, eg, fitters are used which can only be implement...
Class for categorizing the phase space.
void InitCircularTree(const DataSetInfo &dsi)
initialize the circular tree
void GetHelpMessage() const
Get help message text.
void Init()
initialize the method
Bool_t PassesCut(const Event *ev, UInt_t methodIdx)
virtual Bool_t HasAnalysisType(Types::EAnalysisType type, UInt_t numberClasses, UInt_t)
check whether method category has analysis type the method type has to be the same for all sub-method...
void ProcessOptions()
process user options
virtual const std::vector< Float_t > & GetMulticlassValues()
returns the mva values of the multi-class right sub-classifier
Double_t GetMvaValue(Double_t *err=nullptr, Double_t *errUpper=nullptr)
returns the mva value of the right sub-classifier
TMVA::DataSetInfo & CreateCategoryDSI(const TCut &, const TString &, const TString &)
create a DataSetInfo object for a sub-classifier
void DeclareOptions()
options for this method
void AddWeightsXMLTo(void *parent) const
create XML description of Category classifier
const Ranking * CreateRanking()
no ranking
virtual ~MethodCategory(void)
destructor
virtual const std::vector< Float_t > & GetRegressionValues()
returns the mva value of the right sub-classifier
TMVA::IMethod * AddMethod(const TCut &, const TString &theVariables, Types::EMVA theMethod, const TString &theTitle, const TString &theOptions)
adds sub-classifier for a category
virtual std::vector< Double_t > GetMvaValues(Long64_t firstEvt=0, Long64_t lastEvt=-1, Bool_t logProgress=false)
returns the mva values of the right sub-classifier
void ReadWeightsFromXML(void *wghtnode)
read weights of sub-classifiers of MethodCategory from xml weight file
void Train(void)
train all sub-classifiers
Virtual base class for combining several TMVA method.
Ranking for variables in method (implementation)
Definition Ranking.h:48
virtual void Print() const
get maximum length of variable names
Definition Ranking.cxx:111
std::vector< TString > SplitString(const TString &theOpt, const char separator) const
splits the option string at 'separator' and fills the list 'splitV' with the primitive strings
Definition Tools.cxx:1199
const TString & Color(const TString &)
human readable color strings
Definition Tools.cxx:828
void ReadAttr(void *node, const char *, T &value)
read attribute from xml
Definition Tools.h:329
void * GetChild(void *parent, const char *childname=nullptr)
get child node
Definition Tools.cxx:1150
void AddAttr(void *node, const char *, const T &value, Int_t precision=16)
add attribute to xml
Definition Tools.h:347
void * AddChild(void *parent, const char *childname, const char *content=nullptr, bool isRootNode=false)
add child node
Definition Tools.cxx:1124
void * GetNextChild(void *prevchild, const char *childname=nullptr)
XML helpers.
Definition Tools.cxx:1162
Singleton class for Global types used by TMVA.
Definition Types.h:71
static Types & Instance()
The single instance of "Types" if existing already, or create it (Singleton)
Definition Types.cxx:70
@ kRegression
Definition Types.h:128
Class for type info of MVA input variable.
const TString & GetExpression() const
char GetVarType() const
void * GetExternalLink() const
const char * GetName() const override
Returns name of object.
Definition TNamed.h:47
const char * GetTitle() const override
Returns title of object.
Definition TNamed.h:48
Basic string class.
Definition TString.h:139
Ssiz_t Length() const
Definition TString.h:421
const char * Data() const
Definition TString.h:380
Ssiz_t Last(char c) const
Find last occurrence of a character c.
Definition TString.cxx:924
static TString Format(const char *fmt,...)
Static method which formats a string using a printf style format descriptor and return a TString.
Definition TString.cxx:2356
Bool_t Contains(const char *pat, ECaseCompare cmp=kExact) const
Definition TString.h:636
Ssiz_t Index(const char *pat, Ssiz_t i=0, ECaseCompare cmp=kExact) const
Definition TString.h:651
Used to pass a selection expression to the Tree drawing routine.
T EvalInstance(Int_t i=0, const char *stringStack[]=nullptr)
Evaluate this treeformula.
A TTree represents a columnar dataset.
Definition TTree.h:79
create variable transformations
Tools & gTools()
void variables(TString dataset, TString fin="TMVA.root", TString dirName="InputVariables_Id", TString title="TMVA Input Variables", Bool_t isRegression=kFALSE, Bool_t useTMVAStyle=kTRUE)
MsgLogger & Endl(MsgLogger &ml)
Definition MsgLogger.h:148