Logo ROOT  
Reference Guide
 
Loading...
Searching...
No Matches
MethodCategory.cxx
Go to the documentation of this file.
1// @(#)root/tmva $Id$
2// Author: Andreas Hoecker, Joerg Stelzer, Helge Voss, Eckhard von Toerne
3
4/**********************************************************************************
5 * Project: TMVA - a Root-integrated toolkit for multivariate data analysis *
6 * Package: TMVA *
7 * Class : MethodCategory *
8 * *
9 * *
10 * Description: *
11 * Class for categorizing the phase space *
12 * *
13 * Authors (alphabetical): *
14 * Andreas Hoecker <Andreas.Hocker@cern.ch> - CERN, Switzerland *
15 * Nadim Sah <Nadim.Sah@cern.ch> - Berlin, Germany *
16 * Peter Speckmayer <Peter.Speckmazer@cern.ch> - CERN, Switzerland *
17 * Joerg Stelzer <Joerg.Stelzer@cern.ch> - MSU East Lansing, USA *
18 * Helge Voss <Helge.Voss@cern.ch> - MPI-K Heidelberg, Germany *
19 * Jan Therhaag <Jan.Therhaag@cern.ch> - U of Bonn, Germany *
20 * Eckhard v. Toerne <evt@uni-bonn.de> - U of Bonn, Germany *
21 * *
22 * Copyright (c) 2005-2011: *
23 * CERN, Switzerland *
24 * MSU East Lansing, USA *
25 * MPI-K Heidelberg, Germany *
26 * U. of Bonn, Germany *
27 * *
28 * Redistribution and use in source and binary forms, with or without *
29 * modification, are permitted according to the terms listed in LICENSE *
30 * (see tmva/doc/LICENSE) *
31 **********************************************************************************/
32
33/*! \class TMVA::MethodCategory
34\ingroup TMVA
35
36Class for categorizing the phase space
37
38This class is meant to allow categorisation of the data. For different
39categories, different classifiers may be booked and different variables
40may be considered. The aim is to account for the difference that
41is due to different locations/angles.
42*/
43
44
45#include "TMVA/MethodCategory.h"
46
47#include <algorithm>
48#include <vector>
49
50#include "TRandom3.h"
51#include "TH1F.h"
52#include "TSpline.h"
53#include "TDirectory.h"
54#include "TTreeFormula.h"
55
57#include "TMVA/Config.h"
58#include "TMVA/DataSet.h"
59#include "TMVA/DataSetInfo.h"
60#include "TMVA/DataSetManager.h"
61#include "TMVA/IMethod.h"
62#include "TMVA/MethodBase.h"
64#include "TMVA/MsgLogger.h"
65#include "TMVA/PDF.h"
66#include "TMVA/Ranking.h"
67#include "TMVA/Timer.h"
68#include "TMVA/Tools.h"
69#include "TMVA/Types.h"
70#include "TMVA/VariableInfo.h"
72
74
75
76////////////////////////////////////////////////////////////////////////////////
77/// standard constructor
78
80 const TString& methodTitle,
83 : TMVA::MethodCompositeBase( jobName, Types::kCategory, methodTitle, theData, theOption),
84 fCatTree(0),
85 fDataSetManager(NULL)
86{
87}
88
89////////////////////////////////////////////////////////////////////////////////
90/// constructor from weight file
91
95 fCatTree(0),
96 fDataSetManager(NULL)
97{
98}
99
100////////////////////////////////////////////////////////////////////////////////
101/// destructor
102
104{
105 std::vector<TTreeFormula*>::iterator formIt = fCatFormulas.begin();
106 std::vector<TTreeFormula*>::iterator lastF = fCatFormulas.end();
107 for(;formIt!=lastF; ++formIt) delete *formIt;
108 delete fCatTree;
109}
110
111////////////////////////////////////////////////////////////////////////////////
112/// check whether method category has analysis type
113/// the method type has to be the same for all sub-methods
114
116{
117 std::vector<IMethod*>::iterator itrMethod = fMethods.begin();
118
119 // iterate over methods and check whether they have the analysis type
120 for(; itrMethod != fMethods.end(); ++itrMethod ) {
121 if ( !(*itrMethod)->HasAnalysisType(type, numberClasses, numberTargets) )
122 return kFALSE;
123 }
124 return kTRUE;
125}
126
127////////////////////////////////////////////////////////////////////////////////
128/// options for this method
129
133
134////////////////////////////////////////////////////////////////////////////////
135/// adds sub-classifier for a category
136
138 const TString& theVariables,
140 const TString& theTitle,
141 const TString& theOptions )
142{
143 std::string addedMethodName(Types::Instance().GetMethodName(theMethod).Data());
144
145 Log() << kINFO << "Adding sub-classifier: " << addedMethodName << "::" << theTitle << Endl;
146
147 DataSetInfo& dsi = CreateCategoryDSI(theCut, theVariables, theTitle);
148
150
151 MethodBase *method = (dynamic_cast<MethodBase*>(addedMethod));
152 if(method==0) return 0;
153
154 if(fModelPersistence) method->SetWeightFileDir(fFileDir);
155 method->SetModelPersistence(fModelPersistence);
156 method->SetAnalysisType( fAnalysisType );
157 method->SetupMethod();
158 method->ParseOptions();
159 method->ProcessSetup();
160 method->SetFile(fFile);
161 method->SetSilentFile(IsSilentFile());
162
163
164 // set or create correct method base dir for added method
165 const TString dirName = TString::Format("Method_%s",method->GetMethodTypeName().Data());
166 TDirectory * dir = BaseDir()->GetDirectory(dirName);
167 if (dir != 0) method->SetMethodBaseDir( dir );
168 else method->SetMethodBaseDir( BaseDir()->mkdir(dirName, TString::Format("Directory for all %s methods", method->GetMethodTypeName().Data())) );
169
170 // method->SetBaseDir(eigenes base dir, gucken ob Fisher dir existiert, sonst erzeugen )
171
172 // check-for-unused-options is performed; may be overridden by derived
173 // classes
174 method->CheckSetup();
175
176 // disable writing of XML files and standalone classes for sub methods
177 method->DisableWriting( kTRUE );
178
179 // store method, cut and variable names and create cut formula
180 fMethods.push_back(method);
181 fCategoryCuts.push_back(theCut);
182 fVars.push_back(theVariables);
183
184 DataSetInfo& primaryDSI = DataInfo();
185
186 UInt_t newSpectatorIndex = primaryDSI.GetSpectatorInfos().size();
187 fCategorySpecIdx.push_back(newSpectatorIndex);
188
189 primaryDSI.AddSpectator( TString::Format("%s_cat%i:=%s", GetName(),(int)fMethods.size(),theCut.GetTitle()).Data(),
190 TString::Format("%s:%s",GetName(),method->GetName()).Data(),
191 "pass", 0, 0, 'C' );
192
193 return method;
194}
195
196////////////////////////////////////////////////////////////////////////////////
197/// create a DataSetInfo object for a sub-classifier
198
200 const TString& theVariables,
201 const TString& theTitle)
202{
203 // create a new dsi with name: theTitle+"_dsi"
204 TString dsiName=theTitle+"_dsi";
205 DataSetInfo& oldDSI = DataInfo();
207
208 // register the new dsi
209 // DataSetManager::Instance().AddDataSetInfo(*dsi); // DSMTEST replaced by following line
210 fDataSetManager->AddDataSetInfo(*dsi);
211
212 // copy the targets and spectators from the old dsi to the new dsi
213 std::vector<VariableInfo>::iterator itrVarInfo;
214
215 for (itrVarInfo = oldDSI.GetTargetInfos().begin(); itrVarInfo != oldDSI.GetTargetInfos().end(); ++itrVarInfo)
216 dsi->AddTarget(*itrVarInfo);
217
218 for (itrVarInfo = oldDSI.GetSpectatorInfos().begin(); itrVarInfo != oldDSI.GetSpectatorInfos().end(); ++itrVarInfo)
219 dsi->AddSpectator(*itrVarInfo);
220
221 // split string that contains the variables into tiny little pieces
222 std::vector<TString> variables = gTools().SplitString(theVariables,':' );
223
224 // prepare to create varMap
225 std::vector<UInt_t> varMap;
226 Int_t counter=0;
227
228 // add the variables that were specified in theVariables
229 std::vector<TString>::iterator itrVariables;
230 Bool_t found = kFALSE;
231
232 // iterate over all variables in 'variables' and add them
233 for (itrVariables = variables.begin(); itrVariables != variables.end(); ++itrVariables) {
234 counter=0;
235
236 // check the variables of the old dsi for the variable that we want to add
237 for (itrVarInfo = oldDSI.GetVariableInfos().begin(); itrVarInfo != oldDSI.GetVariableInfos().end(); ++itrVarInfo) {
238 if((*itrVariables==itrVarInfo->GetLabel()) ) { // || (*itrVariables==itrVarInfo->GetExpression())) {
239 // don't compare the expression, since the user might take two times the same expression, but with different labels
240 // and apply different transformations to the variables.
241 dsi->AddVariable(*itrVarInfo);
242 varMap.push_back(counter);
243 found = kTRUE;
244 }
245 counter++;
246 }
247
248 // check the spectators of the old dsi for the variable that we want to add
249 for (itrVarInfo = oldDSI.GetSpectatorInfos().begin(); itrVarInfo != oldDSI.GetSpectatorInfos().end(); ++itrVarInfo) {
250 if((*itrVariables==itrVarInfo->GetLabel()) ) { // || (*itrVariables==itrVarInfo->GetExpression())) {
251 // don't compare the expression, since the user might take two times the same expression, but with different labels
252 // and apply different transformations to the variables.
253 dsi->AddVariable(*itrVarInfo);
254 varMap.push_back(counter);
255 found = kTRUE;
256 }
257 counter++;
258 }
259
260 // if the variable is neither in the variables nor in the spectators, we abort
261 if (!found) {
262 Log() << kFATAL <<"The variable " << itrVariables->Data() << " was not found and could not be added " << Endl;
263 }
264 found = kFALSE;
265 }
266
267 // in the case that no variables are specified, add the default-variables from the original dsi
268 if (theVariables=="") {
269 for (UInt_t i=0; i<oldDSI.GetVariableInfos().size(); i++) {
270 dsi->AddVariable(oldDSI.GetVariableInfos()[i]);
271 varMap.push_back(i);
272 }
273 }
274
275 // add the variable map 'varMap' to the vector of varMaps
276 fVarMaps.push_back(varMap);
277
278 // set classes and cuts
279 UInt_t nClasses=oldDSI.GetNClasses();
280 TString className;
281
282 for (UInt_t i=0; i<nClasses; i++) {
283 className = oldDSI.GetClassInfo(i)->GetName();
284 dsi->AddClass(className);
285 dsi->SetCut(oldDSI.GetCut(i),className);
286 dsi->AddCut(theCut,className);
287 dsi->SetWeightExpression(oldDSI.GetWeightExpression(i),className);
288 }
289
290 // set split options, root dir and normalization for the new dsi
291 dsi->SetSplitOptions(oldDSI.GetSplitOptions());
292 dsi->SetRootDir(oldDSI.GetRootDir());
293 TString norm(oldDSI.GetNormalization().Data());
294 dsi->SetNormalization(norm);
295 // need to add split options to normalize with cut efficiency
296 TString splitOpt = dsi->GetSplitOptions();
297 splitOpt += ":ScaleWithPreselEff";
298 dsi->SetSplitOptions(splitOpt);
299
300 DataSetInfo& dsiReference= (*dsi);
301
302 return dsiReference;
303}
304
305////////////////////////////////////////////////////////////////////////////////
306/// initialize the method
307
311
312////////////////////////////////////////////////////////////////////////////////
313/// initialize the circular tree
314
316{
317 delete fCatTree;
318 fCatTree = nullptr;
319
320 std::vector<VariableInfo>::const_iterator viIt;
321 const std::vector<VariableInfo>& vars = dsi.GetVariableInfos();
322 const std::vector<VariableInfo>& specs = dsi.GetSpectatorInfos();
323
325 for (viIt = vars.begin(); viIt != vars.end(); ++viIt)
326 if( viIt->GetExternalLink() == 0 ) {
328 break;
329 }
330 for (viIt = specs.begin(); viIt != specs.end(); ++viIt)
331 if( viIt->GetExternalLink() == 0 ) {
333 break;
334 }
335
336 if(!hasAllExternalLinks) return;
337
338 {
339 // Rather than having TTree::TTree add to the current directory and then remove it, let
340 // make sure to not add it in the first place.
341 // The add-then-remove can lead to a problem if gDirectory points to the same directory (for example
342 // gROOT) in the current thread and another one (and both try to add to the directory at the same time).
343 TDirectory::TContext ctxt(nullptr);
344 fCatTree = new TTree(TString::Format("Circ%s",GetMethodName().Data()).Data(),"Circular Tree for categorization");
345 fCatTree->SetCircular(1);
346 }
347
348 for (viIt = vars.begin(); viIt != vars.end(); ++viIt) {
349 const VariableInfo& vi = *viIt;
350 fCatTree->Branch(vi.GetExpression(),(Float_t*)vi.GetExternalLink(), TString(vi.GetExpression())+TString("/F"));
351 }
352 for (viIt = specs.begin(); viIt != specs.end(); ++viIt) {
353 const VariableInfo& vi = *viIt;
354 if(vi.GetVarType()=='C') continue;
355 fCatTree->Branch(vi.GetExpression(),(Float_t*)vi.GetExternalLink(), TString(vi.GetExpression())+TString("/F"));
356 }
357
358 for(UInt_t cat=0; cat!=fCategoryCuts.size(); ++cat) {
359 fCatFormulas.push_back(new TTreeFormula(TString::Format("Category_%i",cat).Data(), fCategoryCuts[cat].GetTitle(), fCatTree));
360 }
361}
362
363////////////////////////////////////////////////////////////////////////////////
364/// train all sub-classifiers
365
367{
368 // specify the minimum # of training events and set 'classification'
369 const Int_t MinNoTrainingEvents = 10;
370
371 Types::EAnalysisType analysisType = GetAnalysisType();
372
373 // start the training
374 Log() << kINFO << "Train all sub-classifiers for "
375 << (analysisType == Types::kRegression ? "Regression" : "Classification") << " ..." << Endl;
376
377 // don't do anything if no sub-classifier booked
378 if (fMethods.empty()) {
379 Log() << kINFO << "...nothing found to train" << Endl;
380 return;
381 }
382
383 std::vector<IMethod*>::iterator itrMethod;
384
385 // iterate over all booked sub-classifiers and train them
386 for (itrMethod = fMethods.begin(); itrMethod != fMethods.end(); ++itrMethod ) {
387
388 MethodBase* mva = dynamic_cast<MethodBase*>(*itrMethod);
389 if(!mva) continue;
390 mva->SetAnalysisType( analysisType );
391 if (!mva->HasAnalysisType( analysisType,
392 mva->DataInfo().GetNClasses(),
393 mva->DataInfo().GetNTargets() ) ) {
394 Log() << kWARNING << "Method " << mva->GetMethodTypeName() << " is not capable of handling " ;
395 if (analysisType == Types::kRegression)
396 Log() << "regression with " << mva->DataInfo().GetNTargets() << " targets." << Endl;
397 else
398 Log() << "classification with " << mva->DataInfo().GetNClasses() << " classes." << Endl;
399 itrMethod = fMethods.erase( itrMethod );
400 continue;
401 }
402 if (mva->Data()->GetNTrainingEvents() >= MinNoTrainingEvents) {
403
404 Log() << kINFO << "Train method: " << mva->GetMethodName() << " for "
405 << (analysisType == Types::kRegression ? "Regression" : "Classification") << Endl;
406 mva->TrainMethod();
407 Log() << kINFO << "Training finished" << Endl;
408
409 } else {
410
411 Log() << kWARNING << "Method " << mva->GetMethodName()
412 << " not trained (training tree has less entries ["
413 << mva->Data()->GetNTrainingEvents()
414 << "] than required [" << MinNoTrainingEvents << "]" << Endl;
415
416 Log() << kERROR << " w/o training/test events for that category, I better stop here and let you fix " << Endl;
417 Log() << kFATAL << "that one first, otherwise things get too messy later ... " << Endl;
418
419 }
420 }
421
422 if (analysisType != Types::kRegression) {
423
424 // variable ranking
425 Log() << kINFO << "Begin ranking of input variables..." << Endl;
426 for (itrMethod = fMethods.begin(); itrMethod != fMethods.end(); ++itrMethod) {
427 MethodBase* mva = dynamic_cast<MethodBase*>(*itrMethod);
428 if (mva && mva->Data()->GetNTrainingEvents() >= MinNoTrainingEvents) {
429 const Ranking* ranking = (*itrMethod)->CreateRanking();
430 if (ranking != 0)
431 ranking->Print();
432 else
433 Log() << kINFO << "No variable ranking supplied by classifier: "
434 << dynamic_cast<MethodBase*>(*itrMethod)->GetMethodName() << Endl;
435 }
436 }
437 }
438}
439
440////////////////////////////////////////////////////////////////////////////////
441/// create XML description of Category classifier
442
444{
445 void* wght = gTools().AddChild(parent, "Weights");
446 gTools().AddAttr( wght, "NSubMethods", fMethods.size() );
447 void* submethod(0);
448
449 // iterate over methods and write them to XML file
450 for (UInt_t i=0; i<fMethods.size(); i++) {
451 MethodBase* method = dynamic_cast<MethodBase*>(fMethods[i]);
452 submethod = gTools().AddChild(wght, "SubMethod");
453 gTools().AddAttr(submethod, "Index", i);
454 gTools().AddAttr(submethod, "Method", method->GetMethodTypeName() + "::" + method->GetMethodName());
455 gTools().AddAttr(submethod, "Cut", fCategoryCuts[i]);
456 gTools().AddAttr(submethod, "Variables", fVars[i]);
457 method->WriteStateToXML( submethod );
458 }
459}
460
461////////////////////////////////////////////////////////////////////////////////
462/// read weights of sub-classifiers of MethodCategory from xml weight file
463
465{
469 TString methodTitle;
473 gTools().ReadAttr( wghtnode, "NSubMethods", nSubMethods );
475
476 Log() << kINFO << "Recreating sub-classifiers from XML-file " << Endl;
477
478 // recreate all sub-methods from weight file
479 for (UInt_t i=0; i<nSubMethods; i++) {
482 gTools().ReadAttr( subMethodNode, "Variables", theVariables );
483
484 // determine sub-method type
486 if (methodType.Contains(" ")) methodType = methodType(methodType.Last(' ')+1,methodType.Length());
487
488 // determine sub-method title
489 titleLength = fullMethodName.Length()-fullMethodName.Index("::")-2;
490 methodTitle = fullMethodName(fullMethodName.Index("::")+2,titleLength);
491
492 // reconstruct dsi for sub-method
493 DataSetInfo& dsi = CreateCategoryDSI(TCut(theCutString), theVariables, methodTitle);
494
495 // recreate sub-method from weights and add to fMethods
496 MethodBase* method = dynamic_cast<MethodBase*>( ClassifierFactory::Instance().Create( methodType.Data(),
497 dsi, "none" ) );
498 if(method==0)
499 Log() << kFATAL << "Could not create sub-method " << method << " from XML." << Endl;
500
501 method->SetupMethod();
502 method->ReadStateFromXML(subMethodNode);
503
504 fMethods.push_back(method);
505 fCategoryCuts.push_back(TCut(theCutString));
506 fVars.push_back(theVariables);
507
508 DataSetInfo& primaryDSI = DataInfo();
509
510 UInt_t spectatorIdx = 10000;
511 UInt_t counter=0;
512
513 // find the spectator index
514 std::vector<VariableInfo>& spectators=primaryDSI.GetSpectatorInfos();
515 std::vector<VariableInfo>::iterator itrVarInfo;
516 TString specName= TString::Format("%s_cat%i", GetName(),(int)fCategorySpecIdx.size()+1);
517
518 for (itrVarInfo = spectators.begin(); itrVarInfo != spectators.end(); ++itrVarInfo, ++counter) {
519 if((specName==itrVarInfo->GetLabel()) || (specName==itrVarInfo->GetExpression())) {
520 spectatorIdx=counter;
521 fCategorySpecIdx.push_back(spectatorIdx);
522 break;
523 }
524 }
525
527 }
528
529 InitCircularTree(DataInfo());
530
531}
532
533////////////////////////////////////////////////////////////////////////////////
534/// process user options
535
539
540////////////////////////////////////////////////////////////////////////////////
541/// Get help message text
542///
543/// typical length of text line:
544/// "|--------------------------------------------------------------|"
545
547{
548 Log() << Endl;
549 Log() << gTools().Color("bold") << "--- Short description:" << gTools().Color("reset") << Endl;
550 Log() << Endl;
551 Log() << "This method allows to define different categories of events. The" <<Endl;
552 Log() << "categories are defined via cuts on the variables. For each" << Endl;
553 Log() << "category, a different classifier and set of variables can be" <<Endl;
554 Log() << "specified. The categories which are defined for this method must" << Endl;
555 Log() << "be disjoint." << Endl;
556}
557
558////////////////////////////////////////////////////////////////////////////////
559/// no ranking
560
562{
563 return 0;
564}
565
566////////////////////////////////////////////////////////////////////////////////
567
569{
570 // if it's not a simple 'spectator' variable (0 or 1) that the categories are defined by
571 // (but rather some 'formula' (i.e. eta>0), then this formulas are stored in fCatTree and that
572 // one will be evaluated.. (the formulae return 'true' or 'false'
573 if (fCatTree) {
574 if (methodIdx>=fCatFormulas.size()) {
575 Log() << kFATAL << "Large method index " << methodIdx << ", number of category formulas = "
576 << fCatFormulas.size() << Endl;
577 }
578 TTreeFormula* f = fCatFormulas[methodIdx];
579 return f->EvalInstance(0) > 0.5;
580 }
581 // otherwise, it simply looks if "variable == true" ("greater 0.5 to be "sure" )
582 else {
583
584 // checks whether an event lies within a cut
585 if (methodIdx>=fCategorySpecIdx.size()) {
586 Log() << kFATAL << "Unknown method index " << methodIdx << " maximum allowed index="
587 << fCategorySpecIdx.size() << Endl;
588 }
589 UInt_t spectatorIdx = fCategorySpecIdx[methodIdx];
590 Float_t specVal = ev->GetSpectator(spectatorIdx);
591 Bool_t pass = (specVal>0.5);
592 return pass;
593 }
594}
595
596////////////////////////////////////////////////////////////////////////////////
597/// returns the mva value of the right sub-classifier
598
600{
601 if (fMethods.empty()) return 0;
602
604 const Event* ev = GetEvent();
605
606 // determine which sub-classifier to use for this event
608
609 for (UInt_t i=0; i<fMethods.size(); ++i) {
610 if (PassesCut(ev, i)) {
612 methodToUse=i;
613 }
614 }
615
616 if (suitableCutsN == 0) {
617 Log() << kWARNING << "Event does not lie within the cut of any sub-classifier." << Endl;
618 return 0;
619 }
620
621 if (suitableCutsN > 1) {
622 Log() << kFATAL << "The defined categories are not disjoint." << Endl;
623 return 0;
624 }
625
626 // get mva value from the suitable sub-classifier
627 ev->SetVariableArrangement(&fVarMaps[methodToUse]);
628 Double_t mvaValue = dynamic_cast<MethodBase*>(fMethods[methodToUse])->GetMvaValue(ev,err,errUpper);
629 ev->SetVariableArrangement(0);
630
631 Log() << kDEBUG << "Event is for method " << methodToUse << " spectator is " << ev->GetSpectator(0) << " "
632 << fVarMaps[0][0] << " classID " << DataInfo().IsSignal(ev) << " value " << mvaValue
633 << " type " << Data()->GetCurrentType() << Endl;
634
635 return mvaValue;
636}
637
638///////////////////////////////////////////////////////////////
639/// returns the mva values of the right sub-classifier
640///
641std::vector<Double_t>
643{
644
645 std::vector<Double_t> result;
646
647 Info("GetMVaValues", "Evaluate MethodCategory for %d events type %d on the dataset %s", int(lastEvt - firstEvt),
648 (int)Data()->GetCurrentType(), DataInfo().GetName());
649
650 if (fMethods.empty())
651 return result;
652
653 auto data = Data();
654
655 // it is faster to evaluate all categories
656 std::vector<std::vector<Double_t>> mvaValues(fMethods.size());
657 for (UInt_t i = 0; i < fMethods.size(); ++i) {
658 // need to set variable map
659 for (UInt_t iev = firstEvt; iev < lastEvt; ++iev) {
660 data->SetCurrentEvent(iev);
661 const Event *ev = GetEvent(data->GetEvent());
662 ev->SetVariableArrangement(&fVarMaps[i]);
663 }
664 // need to set correct data in the different method
665 mvaValues[i] = dynamic_cast<MethodBase *>(fMethods[i])->GetDataMvaValues(data,firstEvt, lastEvt, logProgress);
666 }
667
668 // now loop on all events
669 result.resize(lastEvt - firstEvt);
670
671 for (UInt_t iev = firstEvt; iev < lastEvt; ++iev)
672 {
673 data->SetCurrentEvent(iev);
675 const Event *ev = GetEvent(data->GetEvent());
676
677 // determine which sub-classifier to use for this event
679
680 for (UInt_t i = 0; i < fMethods.size(); ++i) {
681 if (PassesCut(ev, i)) {
683 methodToUse = i;
684 }
685 }
686
687 if (suitableCutsN == 0) {
688 Log() << kWARNING << "Event does not lie within the cut of any sub-classifier." << Endl;
689 result[iev] = 0;
690 }
691
692 if (suitableCutsN > 1) {
693 Log() << kFATAL << "The defined categories are not disjoint." << Endl;
694 return result;
695 }
696
697
698 result[iev - firstEvt] = mvaValues[methodToUse][iev - firstEvt];
699
700 // reset variable map which was set it before
701 ev->SetVariableArrangement(nullptr);
702 }
703 return result;
704}
705
706////////////////////////////////////////////////////////////////////////////////
707/// returns the mva values of the multi-class right sub-classifier
708///
709const std::vector<Float_t> &TMVA::MethodCategory::GetMulticlassValues()
710{
711 if (fMethods.empty())
713
715 const Event *ev = GetEvent();
716
717 // determine which sub-classifier to use for this event
719
720 for (UInt_t i = 0; i < fMethods.size(); ++i) {
721 if (PassesCut(ev, i)) {
723 methodToUse = i;
724 }
725 }
726
727 if (suitableCutsN == 0) {
728 Log() << kWARNING << "Event does not lie within the cut of any sub-classifier." << Endl;
730 }
731
732 if (suitableCutsN > 1) {
733 Log() << kFATAL << "The defined categories are not disjoint." << Endl;
735 }
736 MethodBase *meth = dynamic_cast<MethodBase *>(fMethods[methodToUse]);
737 if (!meth) {
738 Log() << kFATAL << "method not found in Category Regression method" << Endl;
740 }
741 // get mva value from the suitable sub-classifier
742 ev->SetVariableArrangement(&fVarMaps[methodToUse]);
743 auto &result = meth->GetMulticlassValues();
744 ev->SetVariableArrangement(nullptr);
745 return result;
746}
747
748////////////////////////////////////////////////////////////////////////////////
749/// returns the mva value of the right sub-classifier
750
751const std::vector<Float_t> &TMVA::MethodCategory::GetRegressionValues()
752{
753 if (fMethods.empty()) return MethodBase::GetRegressionValues();
754
756 const Event* ev = GetEvent();
757
758 // determine which sub-classifier to use for this event
760
761 for (UInt_t i=0; i<fMethods.size(); ++i) {
762 if (PassesCut(ev, i)) {
764 methodToUse=i;
765 }
766 }
767
768 if (suitableCutsN == 0) {
769 Log() << kWARNING << "Event does not lie within the cut of any sub-classifier." << Endl;
771 }
772
773 if (suitableCutsN > 1) {
774 Log() << kFATAL << "The defined categories are not disjoint." << Endl;
776 }
777 MethodBase* meth = dynamic_cast<MethodBase*>(fMethods[methodToUse]);
778 if (!meth){
779 Log() << kFATAL << "method not found in Category Regression method" << Endl;
781 }
782 // get mva value from the suitable sub-classifier
783 ev->SetVariableArrangement(&fVarMaps[methodToUse]);
784 auto & result = meth->GetRegressionValues(ev);
785 return result;
786}
#define MinNoTrainingEvents
#define REGISTER_METHOD(CLASS)
for example
#define f(i)
Definition RSha256.hxx:104
size_t size(const MatrixT &matrix)
retrieve the size of a square matrix
float Float_t
Float 4 bytes (float)
Definition RtypesCore.h:71
constexpr Bool_t kFALSE
Definition RtypesCore.h:108
long long Long64_t
Portable signed long integer 8 bytes.
Definition RtypesCore.h:83
constexpr Bool_t kTRUE
Definition RtypesCore.h:107
ROOT::Detail::TRangeCast< T, true > TRangeDynCast
TRangeDynCast is an adapter class that allows the typed iteration through a TCollection.
void Info(const char *location, const char *msgfmt,...)
Use this function for informational messages.
Definition TError.cxx:241
Option_t Option_t TPoint TPoint const char GetTextMagnitude GetFillStyle GetLineColor GetLineWidth GetMarkerStyle GetTextAlign GetTextColor GetTextSize void data
Option_t Option_t TPoint TPoint const char GetTextMagnitude GetFillStyle GetLineColor GetLineWidth GetMarkerStyle GetTextAlign GetTextColor GetTextSize void char Point_t Rectangle_t WindowAttributes_t Float_t Float_t Float_t Int_t Int_t UInt_t UInt_t Rectangle_t result
Option_t Option_t TPoint TPoint const char GetTextMagnitude GetFillStyle GetLineColor GetLineWidth GetMarkerStyle GetTextAlign GetTextColor GetTextSize void char Point_t Rectangle_t WindowAttributes_t Float_t Float_t Float_t Int_t Int_t UInt_t UInt_t Rectangle_t Int_t Int_t Window_t TString Int_t GCValues_t GetPrimarySelectionOwner GetDisplay GetScreen GetColormap GetNativeEvent const char const char dpyName wid window const char font_name cursor keysym reg const char only_if_exist regb h Point_t winding char text const char depth char const char Int_t count const char ColorStruct_t color const char Pixmap_t Pixmap_t PictureAttributes_t attr const char char ret_data h unsigned char height h Atom_t Int_t ULong_t ULong_t unsigned char prop_list Atom_t Atom_t Atom_t Time_t type
const_iterator begin() const
const_iterator end() const
A specialized string object used for TTree selections.
Definition TCut.h:25
TDirectory::TContext keeps track and restore the current directory.
Definition TDirectory.h:89
Describe directory structure in memory.
Definition TDirectory.h:45
virtual TDirectory * GetDirectory(const char *namecycle, Bool_t printError=false, const char *funcname="GetDirectory")
Find a directory using apath.
static ClassifierFactory & Instance()
access to the ClassifierFactory singleton creates the instance if needed
Class that contains all the data information.
Definition DataSetInfo.h:62
Interface for all concrete MVA method implementations.
Definition IMethod.h:53
Virtual base Class for all MVA method.
Definition MethodBase.h:111
virtual const std::vector< Float_t > & GetRegressionValues()
Definition MethodBase.h:221
virtual const std::vector< Float_t > & GetMulticlassValues()
Definition MethodBase.h:227
const TString & GetMethodName() const
Definition MethodBase.h:331
friend class MethodCategory
Definition MethodBase.h:269
Class for categorizing the phase space.
void InitCircularTree(const DataSetInfo &dsi)
initialize the circular tree
void DeclareOptions() override
options for this method
void ReadWeightsFromXML(void *wghtnode) override
read weights of sub-classifiers of MethodCategory from xml weight file
Bool_t PassesCut(const Event *ev, UInt_t methodIdx)
void Train(void) override
train all sub-classifiers
Bool_t HasAnalysisType(Types::EAnalysisType type, UInt_t numberClasses, UInt_t) override
check whether method category has analysis type the method type has to be the same for all sub-method...
const Ranking * CreateRanking() override
no ranking
TMVA::DataSetInfo & CreateCategoryDSI(const TCut &, const TString &, const TString &)
create a DataSetInfo object for a sub-classifier
const std::vector< Float_t > & GetRegressionValues() override
returns the mva value of the right sub-classifier
void Init() override
initialize the method
virtual ~MethodCategory(void)
destructor
void AddWeightsXMLTo(void *parent) const override
create XML description of Category classifier
void ProcessOptions() override
process user options
TMVA::IMethod * AddMethod(const TCut &, const TString &theVariables, Types::EMVA theMethod, const TString &theTitle, const TString &theOptions)
adds sub-classifier for a category
void GetHelpMessage() const override
Get help message text.
Double_t GetMvaValue(Double_t *err=nullptr, Double_t *errUpper=nullptr) override
returns the mva value of the right sub-classifier
std::vector< Double_t > GetMvaValues(Long64_t firstEvt=0, Long64_t lastEvt=-1, Bool_t logProgress=false) override
returns the mva values of the right sub-classifier
const std::vector< Float_t > & GetMulticlassValues() override
returns the mva values of the multi-class right sub-classifier
Virtual base class for combining several TMVA method.
Ranking for variables in method (implementation)
Definition Ranking.h:48
std::vector< TString > SplitString(const TString &theOpt, const char separator) const
splits the option string at 'separator' and fills the list 'splitV' with the primitive strings
Definition Tools.cxx:1199
const TString & Color(const TString &)
human readable color strings
Definition Tools.cxx:828
void ReadAttr(void *node, const char *, T &value)
read attribute from xml
Definition Tools.h:329
void * GetChild(void *parent, const char *childname=nullptr)
get child node
Definition Tools.cxx:1150
void AddAttr(void *node, const char *, const T &value, Int_t precision=16)
add attribute to xml
Definition Tools.h:347
void * AddChild(void *parent, const char *childname, const char *content=nullptr, bool isRootNode=false)
add child node
Definition Tools.cxx:1124
void * GetNextChild(void *prevchild, const char *childname=nullptr)
XML helpers.
Definition Tools.cxx:1162
Singleton class for Global types used by TMVA.
Definition Types.h:71
static Types & Instance()
The single instance of "Types" if existing already, or create it (Singleton)
Definition Types.cxx:70
@ kRegression
Definition Types.h:128
Class for type info of MVA input variable.
Basic string class.
Definition TString.h:138
static TString Format(const char *fmt,...)
Static method which formats a string using a printf style format descriptor and return a TString.
Definition TString.cxx:2384
Used to pass a selection expression to the Tree drawing routine.
T EvalInstance(Int_t i=0, const char *stringStack[]=nullptr)
Evaluate this treeformula.
A TTree represents a columnar dataset.
Definition TTree.h:89
create variable transformations
Tools & gTools()
void variables(TString dataset, TString fin="TMVA.root", TString dirName="InputVariables_Id", TString title="TMVA Input Variables", Bool_t isRegression=kFALSE, Bool_t useTMVAStyle=kTRUE)
MsgLogger & Endl(MsgLogger &ml)
Definition MsgLogger.h:148