Logo ROOT  
Reference Guide
 
Loading...
Searching...
No Matches
MethodFDA.cxx
Go to the documentation of this file.
1// @(#)root/tmva $Id$
2// Author: Andreas Hoecker, Peter Speckmayer, Joerg Stelzer
3
4/**********************************************************************************
5 * Project: TMVA - a Root-integrated toolkit for multivariate data analysis *
6 * Package: TMVA *
7 * Class : MethodFDA *
8 * *
9 * *
10 * Description: *
11 * Implementation *
12 * *
13 * Authors (alphabetical): *
14 * Andreas Hoecker <Andreas.Hocker@cern.ch> - CERN, Switzerland *
15 * Peter Speckmayer <speckmay@mail.cern.ch> - CERN, Switzerland *
16 * Joerg Stelzer <stelzer@cern.ch> - DESY, Germany *
17 * Maciej Kruk <mkruk@cern.ch> - IFJ PAN & AGH, Poland *
18 * *
19 * Copyright (c) 2005-2006: *
20 * CERN, Switzerland *
21 * MPI-K Heidelberg, Germany *
22 * *
23 * Redistribution and use in source and binary forms, with or without *
24 * modification, are permitted according to the terms listed in LICENSE *
25 * (see tmva/doc/LICENSE) *
26 **********************************************************************************/
27
28/*! \class TMVA::MethodFDA
29\ingroup TMVA
30
31Function discriminant analysis (FDA).
32
33This simple classifier
34fits any user-defined TFormula (via option configuration string) to
35the training data by requiring a formula response of 1 (0) to signal
36(background) events. The parameter fitting is done via the abstract
37class FitterBase, featuring Monte Carlo sampling, Genetic
38Algorithm, Simulated Annealing, MINUIT and combinations of these.
39
40Can compute regression value for one dimensional output
41*/
42
43#include "TMVA/MethodFDA.h"
44
46#include "TMVA/Config.h"
47#include "TMVA/Configurable.h"
48#include "TMVA/DataSetInfo.h"
49#include "TMVA/FitterBase.h"
50#include "TMVA/GeneticFitter.h"
51#include "TMVA/Interval.h"
52#include "TMVA/IFitterTarget.h"
53#include "TMVA/IMethod.h"
54#include "TMVA/MCFitter.h"
55#include "TMVA/MethodBase.h"
56#include "TMVA/MinuitFitter.h"
57#include "TMVA/MsgLogger.h"
58#include "TMVA/Timer.h"
59#include "TMVA/Tools.h"
61#include "TMVA/Types.h"
63
64#include "TList.h"
65#include "TFormula.h"
66#include "TString.h"
67#include "TObjString.h"
68#include "TRandom3.h"
69#include "TMath.h"
70
71#include <iostream>
72#include <algorithm>
73#include <iterator>
74#include <stdexcept>
75#include <sstream>
76
77using std::stringstream;
78
80
81
82////////////////////////////////////////////////////////////////////////////////
83/// standard constructor
84
86 const TString& methodTitle,
89 : MethodBase( jobName, Types::kFDA, methodTitle, theData, theOption),
91 fFormula ( 0 ),
92 fNPars ( 0 ),
93 fFitter ( 0 ),
94 fConvergerFitter( 0 ),
95 fSumOfWeightsSig( 0 ),
96 fSumOfWeightsBkg( 0 ),
97 fSumOfWeights ( 0 ),
98 fOutputDimensions( 0 )
99{
100}
101
102////////////////////////////////////////////////////////////////////////////////
103/// constructor from weight file
104
106 const TString& theWeightFile)
108 IFitterTarget (),
109 fFormula ( 0 ),
110 fNPars ( 0 ),
111 fFitter ( 0 ),
112 fConvergerFitter( 0 ),
113 fSumOfWeightsSig( 0 ),
114 fSumOfWeightsBkg( 0 ),
115 fSumOfWeights ( 0 ),
116 fOutputDimensions( 0 )
117{
118}
119
120////////////////////////////////////////////////////////////////////////////////
121/// default initialisation
122
124{
125 fNPars = 0;
126
127 fBestPars.clear();
128
129 fSumOfWeights = 0;
130 fSumOfWeightsSig = 0;
131 fSumOfWeightsBkg = 0;
132
133 fFormulaStringP = "";
134 fParRangeStringP = "";
135 fFormulaStringT = "";
136 fParRangeStringT = "";
137
138 fFitMethod = "";
139 fConverger = "";
140
141 if( DoMulticlass() )
142 if (fMulticlassReturnVal == NULL) fMulticlassReturnVal = new std::vector<Float_t>();
143
144}
145
146////////////////////////////////////////////////////////////////////////////////
147/// define the options (their key words) that can be set in the option string
148///
149/// format of function string:
150///
151/// "x0*(0)+((1)/x1)**(2)..."
152///
153/// where "(i)" are the parameters, and "xi" the input variables
154///
155/// format of parameter string:
156///
157/// "(-1.2,3.4);(-2.3,4.55);..."
158///
159/// where the numbers in "(a,b)" correspond to the a=min, b=max parameter ranges;
160/// each parameter defined in the function string must have a corresponding range
161
163{
164 DeclareOptionRef( fFormulaStringP = "(0)", "Formula", "The discrimination formula" );
165 DeclareOptionRef( fParRangeStringP = "()", "ParRanges", "Parameter ranges" );
166
167 // fitter
168 DeclareOptionRef( fFitMethod = "MINUIT", "FitMethod", "Optimisation Method");
169 AddPreDefVal(TString("MC"));
170 AddPreDefVal(TString("GA"));
171 AddPreDefVal(TString("SA"));
172 AddPreDefVal(TString("MINUIT"));
173
174 DeclareOptionRef( fConverger = "None", "Converger", "FitMethod uses Converger to improve result");
175 AddPreDefVal(TString("None"));
176 AddPreDefVal(TString("MINUIT"));
177}
178
179////////////////////////////////////////////////////////////////////////////////
180/// translate the user formula string into a TFormula: parameters "(i)" become "[i]" and input variables "xi" become "[npar+i]"
181
183{
184 // process transient strings
185 fFormulaStringT = fFormulaStringP;
186
187 // interpret formula string
188
189 // replace the parameters "(i)" by the TFormula style "[i]"
190 for (UInt_t ipar=0; ipar<fNPars; ipar++) {
191 fFormulaStringT.ReplaceAll( TString::Format("(%i)",ipar), TString::Format("[%i]",ipar) );
192 }
193
194 // sanity check, there should be no "(i)", with 'i' a number anymore
195 for (Int_t ipar=fNPars; ipar<1000; ipar++) {
196 if (fFormulaStringT.Contains( TString::Format("(%i)",ipar) ))
197 Log() << kFATAL
198 << "<CreateFormula> Formula contains expression: \"" << TString::Format("(%i)",ipar) << "\", "
199 << "which cannot be attributed to a parameter; "
200 << "it may be that the number of variable ranges given via \"ParRanges\" "
201 << "does not match the number of parameters in the formula expression, please verify!"
202 << Endl;
203 }
204
205 // write the variables "xi" as additional parameters "[npar+i]"
206 for (Int_t ivar=GetNvar()-1; ivar >= 0; ivar--) {
207 fFormulaStringT.ReplaceAll( TString::Format("x%i",ivar), TString::Format("[%i]",ivar+fNPars) );
208 }
209
210 // sanity check, there should be no "xi", with 'i' a number anymore
211 for (UInt_t ivar=GetNvar(); ivar<1000; ivar++) {
212 if (fFormulaStringT.Contains( TString::Format("x%i",ivar) ))
213 Log() << kFATAL
214 << "<CreateFormula> Formula contains expression: \"" << TString::Format("x%i",ivar) << "\", "
215 << "which cannot be attributed to an input variable" << Endl;
216 }
217
218 Log() << "User-defined formula string : \"" << fFormulaStringP << "\"" << Endl;
219 Log() << "TFormula-compatible formula string: \"" << fFormulaStringT << "\"" << Endl;
220 Log() << kDEBUG << "Creating and compiling formula" << Endl;
221
222 // create TF1
223 if (fFormula) delete fFormula;
224 fFormula = new TFormula( "FDA_Formula", fFormulaStringT );
225
226 // is formula correct ?
227 if (!fFormula->IsValid())
228 Log() << kFATAL << "<ProcessOptions> Formula expression could not be properly compiled" << Endl;
229
230 // other sanity checks
231 if (fFormula->GetNpar() > (Int_t)(fNPars + GetNvar()))
232 Log() << kFATAL << "<ProcessOptions> Dubious number of parameters in formula expression: "
233 << fFormula->GetNpar() << " - compared to maximum allowed: " << fNPars + GetNvar() << Endl;
234}
235
236////////////////////////////////////////////////////////////////////////////////
237/// the option string is decoded, for available options see "DeclareOptions"
238
240{
241 // process transient strings
242 fParRangeStringT = fParRangeStringP;
243
244 // interpret parameter string
245 fParRangeStringT.ReplaceAll( " ", "" );
246 fNPars = fParRangeStringT.CountChar( ')' );
247
248 TList* parList = gTools().ParseFormatLine( fParRangeStringT, ";" );
249 if ((UInt_t)parList->GetSize() != fNPars) {
250 Log() << kFATAL << "<ProcessOptions> Mismatch in parameter string: "
251 << "the number of parameters: " << fNPars << " != ranges defined: "
252 << parList->GetSize() << "; the format of the \"ParRanges\" string "
253 << "must be: \"(-1.2,3.4);(-2.3,4.55);...\", "
254 << "where the numbers in \"(a,b)\" correspond to the a=min, b=max parameter ranges; "
255 << "each parameter defined in the function string must have a corresponding rang."
256 << Endl;
257 }
258
259 fParRange.resize( fNPars );
260 for (UInt_t ipar=0; ipar<fNPars; ipar++) fParRange[ipar] = 0;
261
262 for (UInt_t ipar=0; ipar<fNPars; ipar++) {
263 // parse (a,b)
264 TString str = ((TObjString*)parList->At(ipar))->GetString();
265 Ssiz_t istr = str.First( ',' );
266 TString pminS(str(1,istr-1));
267 TString pmaxS(str(istr+1,str.Length()-2-istr));
268
269 stringstream stmin; Float_t pmin=0; stmin << pminS.Data(); stmin >> pmin;
270 stringstream stmax; Float_t pmax=0; stmax << pmaxS.Data(); stmax >> pmax;
271
272 // sanity check
273 if (TMath::Abs(pmax-pmin) < 1.e-30) pmax = pmin;
274 if (pmin > pmax) Log() << kFATAL << "<ProcessOptions> max > min in interval for parameter: ["
275 << ipar << "] : [" << pmin << ", " << pmax << "] " << Endl;
276
277 Log() << kINFO << "Create parameter interval for parameter " << ipar << " : [" << pmin << "," << pmax << "]" << Endl;
278 fParRange[ipar] = new Interval( pmin, pmax );
279 }
280 delete parList;
281
282 // create formula
283 CreateFormula();
284
285
286 // copy parameter ranges for each output dimension ==================
287 fOutputDimensions = 1;
288 if( DoRegression() )
289 fOutputDimensions = DataInfo().GetNTargets();
290 if( DoMulticlass() )
291 fOutputDimensions = DataInfo().GetNClasses();
292
293 for( Int_t dim = 1; dim < fOutputDimensions; ++dim ){
294 for( UInt_t par = 0; par < fNPars; ++par ){
295 fParRange.push_back( fParRange.at(par) );
296 }
297 }
298 // ====================
299
300 // create minimiser
301 fConvergerFitter = (IFitterTarget*)this;
302 if (fConverger == "MINUIT") {
303 fConvergerFitter = new MinuitFitter( *this, TString::Format("%s_Converger_Minuit", GetName()), fParRange, GetOptions() );
304 SetOptions(dynamic_cast<Configurable*>(fConvergerFitter)->GetOptions());
305 }
306
307 if(fFitMethod == "MC")
308 fFitter = new MCFitter( *fConvergerFitter, TString::Format("%s_Fitter_MC", GetName()), fParRange, GetOptions() );
309 else if (fFitMethod == "GA")
310 fFitter = new GeneticFitter( *fConvergerFitter, TString::Format("%s_Fitter_GA", GetName()), fParRange, GetOptions() );
311 else if (fFitMethod == "SA")
312 fFitter = new SimulatedAnnealingFitter( *fConvergerFitter, TString::Format("%s_Fitter_SA", GetName()), fParRange, GetOptions() );
313 else if (fFitMethod == "MINUIT")
314 fFitter = new MinuitFitter( *fConvergerFitter, TString::Format("%s_Fitter_Minuit", GetName()), fParRange, GetOptions() );
315 else {
316 Log() << kFATAL << "<Train> Do not understand fit method:" << fFitMethod << Endl;
317 }
318
319 fFitter->CheckForUnusedOptions();
320}
321
322////////////////////////////////////////////////////////////////////////////////
323/// destructor
324
326{
327 ClearAll();
328}
329
330////////////////////////////////////////////////////////////////////////////////
331/// FDA can handle two-class classification, multiclass classification and regression
332
334{
335 if (type == Types::kClassification && numberClasses == 2) return kTRUE;
336 if (type == Types::kMulticlass ) return kTRUE;
337 if (type == Types::kRegression ) return kTRUE;
338 return kFALSE;
339}
340
341
342////////////////////////////////////////////////////////////////////////////////
343/// delete and clear all class members
344
346{
347 // if there is more than one output dimension, the paramater ranges are the same again (object has been copied).
348 // hence, ... erase the copied pointers to assure, that they are deleted only once.
349 // fParRange.erase( fParRange.begin()+(fNPars), fParRange.end() );
350 for (UInt_t ipar=0; ipar<fParRange.size() && ipar<fNPars; ipar++) {
351 if (fParRange[ipar] != 0) { delete fParRange[ipar]; fParRange[ipar] = 0; }
352 }
353 fParRange.clear();
354
355 if (fFormula != 0) { delete fFormula; fFormula = 0; }
356 fBestPars.clear();
357}
358
359////////////////////////////////////////////////////////////////////////////////
360/// FDA training
361
363{
364 // cache training events
365 fSumOfWeights = 0;
366 fSumOfWeightsSig = 0;
367 fSumOfWeightsBkg = 0;
368
369 for (UInt_t ievt=0; ievt<GetNEvents(); ievt++) {
370
371 // read the training event
372 const Event* ev = GetEvent(ievt);
373
374 // true event copy
375 Float_t w = ev->GetWeight();
376
377 if (!DoRegression()) {
378 if (DataInfo().IsSignal(ev)) { fSumOfWeightsSig += w; }
379 else { fSumOfWeightsBkg += w; }
380 }
381 fSumOfWeights += w;
382 }
383
384 // sanity check
385 if (!DoRegression()) {
386 if (fSumOfWeightsSig <= 0 || fSumOfWeightsBkg <= 0) {
387 Log() << kFATAL << "<Train> Troubles in sum of weights: "
388 << fSumOfWeightsSig << " (S) : " << fSumOfWeightsBkg << " (B)" << Endl;
389 }
390 }
391 else if (fSumOfWeights <= 0) {
392 Log() << kFATAL << "<Train> Troubles in sum of weights: "
393 << fSumOfWeights << Endl;
394 }
395
396 // starting values (not used by all fitters)
397 fBestPars.clear();
398 for (std::vector<Interval*>::const_iterator parIt = fParRange.begin(); parIt != fParRange.end(); ++parIt) {
399 fBestPars.push_back( (*parIt)->GetMean() );
400 }
401
402 // execute the fit
403 Double_t estimator = fFitter->Run( fBestPars );
404
405 // print results
406 PrintResults( fFitMethod, fBestPars, estimator );
407
408 delete fFitter; fFitter = 0;
409 if (fConvergerFitter!=0 && fConvergerFitter!=(IFitterTarget*)this) {
410 delete fConvergerFitter;
411 fConvergerFitter = 0;
412 }
413 ExitFromTraining();
414}
415
416////////////////////////////////////////////////////////////////////////////////
417/// display fit parameters
418/// together with the formula expression and the estimator value at the minimum
419
420void TMVA::MethodFDA::PrintResults( const TString& fitter, std::vector<Double_t>& pars, const Double_t estimator ) const
421{
422 Log() << kINFO;
423 Log() << kHEADER << "Results for parameter fit using \"" << fitter << "\" fitter:" << Endl;
424 std::vector<TString> parNames;
425 for (UInt_t ipar=0; ipar<pars.size(); ipar++) parNames.push_back( TString::Format("Par(%i)",ipar ) );
426 gTools().FormattedOutput( pars, parNames, "Parameter" , "Fit result", Log(), "%g" );
427 Log() << "Discriminator expression: \"" << fFormulaStringP << "\"" << Endl;
428 Log() << "Value of estimator at minimum: " << estimator << Endl;
429}
430
431////////////////////////////////////////////////////////////////////////////////
432/// compute estimator for given parameter set (to be minimised)
433
434Double_t TMVA::MethodFDA::EstimatorFunction( std::vector<Double_t>& pars )
435{
436 const Double_t sumOfWeights[] = { fSumOfWeightsBkg, fSumOfWeightsSig, fSumOfWeights };
437 Double_t estimator[] = { 0, 0, 0 };
438
440 Double_t desired = 0.0;
441
442 // calculate the deviation from the desired value
443 if( DoRegression() ){
444 for (UInt_t ievt=0; ievt<GetNEvents(); ievt++) {
445 // read the training event
446 const TMVA::Event* ev = GetEvent(ievt);
447
448 for( Int_t dim = 0; dim < fOutputDimensions; ++dim ){
449 desired = ev->GetTarget( dim );
450 result = InterpretFormula( ev, pars.begin(), pars.end() );
452 estimator[2] += deviation * ev->GetWeight();
453 }
454 }
455 estimator[2] /= sumOfWeights[2];
456 // return value is sum over normalised signal and background contributions
457 return estimator[2];
458
459 }else if( DoMulticlass() ){
460 for (UInt_t ievt=0; ievt<GetNEvents(); ievt++) {
461 // read the training event
462 const TMVA::Event* ev = GetEvent(ievt);
463
464 CalculateMulticlassValues( ev, pars, *fMulticlassReturnVal );
465
467 for( Int_t dim = 0; dim < fOutputDimensions; ++dim ){
468 Double_t y = fMulticlassReturnVal->at(dim);
469 Double_t t = (ev->GetClass() == static_cast<UInt_t>(dim) ? 1.0 : 0.0 );
470 crossEntropy += t*log(y);
471 }
472 estimator[2] += ev->GetWeight()*crossEntropy;
473 }
474 estimator[2] /= sumOfWeights[2];
475 // return value is sum over normalised signal and background contributions
476 return estimator[2];
477
478 }else{
479 for (UInt_t ievt=0; ievt<GetNEvents(); ievt++) {
480 // read the training event
481 const TMVA::Event* ev = GetEvent(ievt);
482
483 desired = (DataInfo().IsSignal(ev) ? 1.0 : 0.0);
484 result = InterpretFormula( ev, pars.begin(), pars.end() );
486 estimator[Int_t(desired)] += deviation * ev->GetWeight();
487 }
488 estimator[0] /= sumOfWeights[0];
489 estimator[1] /= sumOfWeights[1];
490 // return value is sum over normalised signal and background contributions
491 return estimator[0] + estimator[1];
492 }
493}
494
495////////////////////////////////////////////////////////////////////////////////
496/// formula interpretation
497
498Double_t TMVA::MethodFDA::InterpretFormula( const Event* event, std::vector<Double_t>::iterator parBegin, std::vector<Double_t>::iterator parEnd )
499{
500 Int_t ipar = 0;
501 // std::cout << "pars ";
502 for( std::vector<Double_t>::iterator it = parBegin; it != parEnd; ++it ){
503 // std::cout << " i" << ipar << " val" << (*it);
504 fFormula->SetParameter( ipar, (*it) );
505 ++ipar;
506 }
507 for (UInt_t ivar=0; ivar<GetNvar(); ivar++) fFormula->SetParameter( ivar+ipar, event->GetValue(ivar) );
508
509 Double_t result = fFormula->Eval( 0 );
510 // std::cout << " result " << result << std::endl;
511 return result;
512}
513
514////////////////////////////////////////////////////////////////////////////////
515/// returns MVA value for given event
516
518{
519 const Event* ev = GetEvent();
520
521 // cannot determine error
522 NoErrorCalc(err, errUpper);
523
524 return InterpretFormula( ev, fBestPars.begin(), fBestPars.end() );
525}
526
527////////////////////////////////////////////////////////////////////////////////
528
529const std::vector<Float_t>& TMVA::MethodFDA::GetRegressionValues()
530{
531 if (fRegressionReturnVal == NULL) fRegressionReturnVal = new std::vector<Float_t>();
532 fRegressionReturnVal->clear();
533
534 const Event* ev = GetEvent();
535
536 Event* evT = new Event(*ev);
537
538 for( Int_t dim = 0; dim < fOutputDimensions; ++dim ){
539 Int_t offset = dim*fNPars;
540 evT->SetTarget(dim,InterpretFormula( ev, fBestPars.begin()+offset, fBestPars.begin()+offset+fNPars ) );
541 }
542 const Event* evT2 = GetTransformationHandler().InverseTransform( evT );
543 fRegressionReturnVal->push_back(evT2->GetTarget(0));
544
545 delete evT;
546
547 return (*fRegressionReturnVal);
548}
549
550////////////////////////////////////////////////////////////////////////////////
551
552const std::vector<Float_t>& TMVA::MethodFDA::GetMulticlassValues()
553{
554 if (fMulticlassReturnVal == NULL) fMulticlassReturnVal = new std::vector<Float_t>();
555 fMulticlassReturnVal->clear();
556 std::vector<Float_t> temp;
557
558 // returns MVA value for given event
559 const TMVA::Event* evt = GetEvent();
560
561 CalculateMulticlassValues( evt, fBestPars, temp );
562
563 UInt_t nClasses = DataInfo().GetNClasses();
565 Double_t norm = 0.0;
566 for(UInt_t j=0;j<nClasses;j++){
567 if(iClass!=j)
568 norm+=exp(temp[j]-temp[iClass]);
569 }
570 (*fMulticlassReturnVal).push_back(1.0/(1.0+norm));
571 }
572
573 return (*fMulticlassReturnVal);
574}
575
576
577////////////////////////////////////////////////////////////////////////////////
578/// calculate the values for multiclass
579
580void TMVA::MethodFDA::CalculateMulticlassValues( const TMVA::Event*& evt, std::vector<Double_t>& parameters, std::vector<Float_t>& values)
581{
582 values.clear();
583
584 // std::copy( parameters.begin(), parameters.end(), std::ostream_iterator<double>( std::cout, " " ) );
585 // std::cout << std::endl;
586
587 // char inp;
588 // std::cin >> inp;
589
590 for( Int_t dim = 0; dim < fOutputDimensions; ++dim ){ // check for all other dimensions (=classes)
591 Int_t offset = dim*fNPars;
592 Double_t value = InterpretFormula( evt, parameters.begin()+offset, parameters.begin()+offset+fNPars );
593 // std::cout << "dim : " << dim << " value " << value << " offset " << offset << std::endl;
594 values.push_back( value );
595 }
596}
597
598////////////////////////////////////////////////////////////////////////////////
599/// read back the training results from a file (stream)
600
602{
603 // retrieve best function parameters
604 // coverity[tainted_data_argument]
605 istr >> fNPars;
606
607 fBestPars.clear();
608 fBestPars.resize( fNPars );
609 for (UInt_t ipar=0; ipar<fNPars; ipar++) istr >> fBestPars[ipar];
610}
611
612////////////////////////////////////////////////////////////////////////////////
613/// create XML description of the FDA fit (parameters and formula) for classification and regression
614/// (for arbitrary number of output classes/targets)
615
616void TMVA::MethodFDA::AddWeightsXMLTo( void* parent ) const
617{
618 void* wght = gTools().AddChild(parent, "Weights");
619 gTools().AddAttr( wght, "NPars", fNPars );
620 gTools().AddAttr( wght, "NDim", fOutputDimensions );
621 for (UInt_t ipar=0; ipar<fNPars*fOutputDimensions; ipar++) {
622 void* coeffxml = gTools().AddChild( wght, "Parameter" );
623 gTools().AddAttr( coeffxml, "Index", ipar );
624 gTools().AddAttr( coeffxml, "Value", fBestPars[ipar] );
625 }
626
627 // write formula
628 gTools().AddAttr( wght, "Formula", fFormulaStringP );
629}
630
631////////////////////////////////////////////////////////////////////////////////
632/// read coefficients from xml weight file
633
635{
636 gTools().ReadAttr( wghtnode, "NPars", fNPars );
637
638 if(gTools().HasAttr( wghtnode, "NDim")) {
639 gTools().ReadAttr( wghtnode, "NDim" , fOutputDimensions );
640 } else {
641 // older weight files don't have this attribute
642 fOutputDimensions = 1;
643 }
644
645 fBestPars.clear();
646 fBestPars.resize( fNPars*fOutputDimensions );
647
648 void* ch = gTools().GetChild(wghtnode);
649 Double_t par;
650 UInt_t ipar;
651 while (ch) {
652 gTools().ReadAttr( ch, "Index", ipar );
653 gTools().ReadAttr( ch, "Value", par );
654
655 // sanity check
656 if (ipar >= fNPars*fOutputDimensions) Log() << kFATAL << "<ReadWeightsFromXML> index out of range: "
657 << ipar << " >= " << fNPars << Endl;
658 fBestPars[ipar] = par;
659
660 ch = gTools().GetNextChild(ch);
661 }
662
663 // read formula
664 gTools().ReadAttr( wghtnode, "Formula", fFormulaStringP );
665
666 // create the TFormula
667 CreateFormula();
668}
669
670////////////////////////////////////////////////////////////////////////////////
671/// write FDA-specific classifier response
672
673void TMVA::MethodFDA::MakeClassSpecific( std::ostream& fout, const TString& className ) const
674{
675 fout << " double fParameter[" << fNPars << "];" << std::endl;
676 fout << "};" << std::endl;
677 fout << "" << std::endl;
678 fout << "inline void " << className << "::Initialize() " << std::endl;
679 fout << "{" << std::endl;
680 for(UInt_t ipar=0; ipar<fNPars; ipar++) {
681 fout << " fParameter[" << ipar << "] = " << fBestPars[ipar] << ";" << std::endl;
682 }
683 fout << "}" << std::endl;
684 fout << std::endl;
685 fout << "inline double " << className << "::GetMvaValue__( const std::vector<double>& inputValues ) const" << std::endl;
686 fout << "{" << std::endl;
687 fout << " // interpret the formula" << std::endl;
688
689 // replace parameters
690 TString str = fFormulaStringT;
691 for (UInt_t ipar=0; ipar<fNPars; ipar++) {
692 str.ReplaceAll( TString::Format("[%i]", ipar), TString::Format("fParameter[%i]", ipar) );
693 }
694
695 // replace input variables
696 for (UInt_t ivar=0; ivar<GetNvar(); ivar++) {
697 str.ReplaceAll( TString::Format("[%i]", ivar+fNPars), TString::Format("inputValues[%i]", ivar) );
698 }
699
700 fout << " double retval = " << str << ";" << std::endl;
701 fout << std::endl;
702 fout << " return retval; " << std::endl;
703 fout << "}" << std::endl;
704 fout << std::endl;
705 fout << "// Clean up" << std::endl;
706 fout << "inline void " << className << "::Clear() " << std::endl;
707 fout << "{" << std::endl;
708 fout << " // nothing to clear" << std::endl;
709 fout << "}" << std::endl;
710}
711
712////////////////////////////////////////////////////////////////////////////////
713/// get help message text
714///
715/// typical length of text line:
716/// "|--------------------------------------------------------------|"
717
719{
720 Log() << Endl;
721 Log() << gTools().Color("bold") << "--- Short description:" << gTools().Color("reset") << Endl;
722 Log() << Endl;
723 Log() << "The function discriminant analysis (FDA) is a classifier suitable " << Endl;
724 Log() << "to solve linear or simple nonlinear discrimination problems." << Endl;
725 Log() << Endl;
726 Log() << "The user provides the desired function with adjustable parameters" << Endl;
727 Log() << "via the configuration option string, and FDA fits the parameters to" << Endl;
728 Log() << "it, requiring the signal (background) function value to be as close" << Endl;
729 Log() << "as possible to 1 (0). Its advantage over the more involved and" << Endl;
730 Log() << "automatic nonlinear discriminators is the simplicity and transparency " << Endl;
731 Log() << "of the discrimination expression. A shortcoming is that FDA will" << Endl;
732 Log() << "underperform for involved problems with complicated, phase space" << Endl;
733 Log() << "dependent nonlinear correlations." << Endl;
734 Log() << Endl;
735 Log() << "Please consult the Users Guide for the format of the formula string" << Endl;
736 Log() << "and the allowed parameter ranges:" << Endl;
737 if (gConfig().WriteOptionsReference()) {
738 Log() << "<a href=\"https://github.com/root-project/root/blob/master/documentation/tmva/UsersGuide/TMVAUsersGuide.pdf\">"
739 << "TMVAUsersGuide.pdf</a>" << Endl;
740 }
741 else Log() << "documentation/tmva/UsersGuide/TMVAUsersGuide.pdf" << Endl;
742 Log() << Endl;
743 Log() << gTools().Color("bold") << "--- Performance optimisation:" << gTools().Color("reset") << Endl;
744 Log() << Endl;
745 Log() << "The FDA performance depends on the complexity and fidelity of the" << Endl;
746 Log() << "user-defined discriminator function. As a general rule, it should" << Endl;
747 Log() << "be able to reproduce the discrimination power of any linear" << Endl;
748 Log() << "discriminant analysis. To reach into the nonlinear domain, it is" << Endl;
749 Log() << "useful to inspect the correlation profiles of the input variables," << Endl;
750 Log() << "and add quadratic and higher polynomial terms between variables as" << Endl;
751 Log() << "necessary. Comparison with more involved nonlinear classifiers can" << Endl;
752 Log() << "be used as a guide." << Endl;
753 Log() << Endl;
754 Log() << gTools().Color("bold") << "--- Performance tuning via configuration options:" << gTools().Color("reset") << Endl;
755 Log() << Endl;
756 Log() << "Depending on the function used, the choice of \"FitMethod\" is" << Endl;
757 Log() << "crucial for getting valuable solutions with FDA. As a guideline it" << Endl;
758 Log() << "is recommended to start with \"FitMethod=MINUIT\". When more complex" << Endl;
759 Log() << "functions are used where MINUIT does not converge to reasonable" << Endl;
760 Log() << "results, the user should switch to non-gradient FitMethods such" << Endl;
761 Log() << "as GeneticAlgorithm (GA) or Monte Carlo (MC). It might prove to be" << Endl;
762 Log() << "useful to combine GA (or MC) with MINUIT by setting the option" << Endl;
763 Log() << "\"Converger=MINUIT\". GA (MC) will then set the starting parameters" << Endl;
764 Log() << "for MINUIT such that the basic quality of GA (MC) of finding global" << Endl;
765 Log() << "minima is combined with the efficacy of MINUIT of finding local" << Endl;
766 Log() << "minima." << Endl;
767}
#define REGISTER_METHOD(CLASS)
for example
int Int_t
Signed integer 4 bytes (int)
Definition RtypesCore.h:59
float Float_t
Float 4 bytes (float)
Definition RtypesCore.h:71
constexpr Bool_t kFALSE
Definition RtypesCore.h:108
constexpr Bool_t kTRUE
Definition RtypesCore.h:107
ROOT::Detail::TRangeCast< T, true > TRangeDynCast
TRangeDynCast is an adapter class that allows the typed iteration through a TCollection.
Option_t Option_t TPoint TPoint const char GetTextMagnitude GetFillStyle GetLineColor GetLineWidth GetMarkerStyle GetTextAlign GetTextColor GetTextSize void char Point_t Rectangle_t WindowAttributes_t Float_t Float_t Float_t Int_t Int_t UInt_t UInt_t Rectangle_t Int_t Int_t Window_t TString Int_t GCValues_t GetPrimarySelectionOwner GetDisplay GetScreen GetColormap GetNativeEvent const char const char dpyName wid window const char font_name cursor keysym reg const char only_if_exist regb h Point_t winding char text const char depth char const char Int_t count const char ColorStruct_t color const char Pixmap_t Pixmap_t PictureAttributes_t attr const char char ret_data h unsigned char height h offset
Option_t Option_t TPoint TPoint const char GetTextMagnitude GetFillStyle GetLineColor GetLineWidth GetMarkerStyle GetTextAlign GetTextColor GetTextSize void char Point_t Rectangle_t WindowAttributes_t Float_t Float_t Float_t Int_t Int_t UInt_t UInt_t Rectangle_t result
Option_t Option_t TPoint TPoint const char GetTextMagnitude GetFillStyle GetLineColor GetLineWidth GetMarkerStyle GetTextAlign GetTextColor GetTextSize void value
Option_t Option_t TPoint TPoint const char GetTextMagnitude GetFillStyle GetLineColor GetLineWidth GetMarkerStyle GetTextAlign GetTextColor GetTextSize void char Point_t Rectangle_t WindowAttributes_t Float_t Float_t Float_t Int_t Int_t UInt_t UInt_t Rectangle_t Int_t Int_t Window_t TString Int_t GCValues_t GetPrimarySelectionOwner GetDisplay GetScreen GetColormap GetNativeEvent const char const char dpyName wid window const char font_name cursor keysym reg const char only_if_exist regb h Point_t winding char text const char depth char const char Int_t count const char ColorStruct_t color const char Pixmap_t Pixmap_t PictureAttributes_t attr const char char ret_data h unsigned char height h Atom_t Int_t ULong_t ULong_t unsigned char prop_list Atom_t Atom_t Atom_t Time_t type
const_iterator begin() const
const_iterator end() const
The Formula class.
Definition TFormula.h:89
A doubly linked list.
Definition TList.h:38
Class that contains all the data information.
Definition DataSetInfo.h:62
Float_t GetValue(UInt_t ivar) const
return value of i'th variable
Definition Event.cxx:236
Fitter using a Genetic Algorithm.
Interface for a fitter 'target'.
The TMVA::Interval Class.
Definition Interval.h:61
Fitter using Monte Carlo sampling of parameters.
Definition MCFitter.h:44
Virtual base Class for all MVA method.
Definition MethodBase.h:111
Function discriminant analysis (FDA).
Definition MethodFDA.h:61
Bool_t HasAnalysisType(Types::EAnalysisType type, UInt_t numberClasses, UInt_t numberTargets) override
FDA can handle classification with 2 classes and regression with one regression-target.
const std::vector< Float_t > & GetRegressionValues() override
virtual ~MethodFDA(void)
destructor
Double_t InterpretFormula(const Event *, std::vector< Double_t >::iterator begin, std::vector< Double_t >::iterator end)
formula interpretation
void Train(void) override
FDA training.
void CalculateMulticlassValues(const TMVA::Event *&evt, std::vector< Double_t > &parameters, std::vector< Float_t > &values)
calculate the values for multiclass
void ReadWeightsFromStream(std::istream &i) override
read back the training results from a file (stream)
MethodFDA(const TString &jobName, const TString &methodTitle, DataSetInfo &theData, const TString &theOption="")
standard constructor
Definition MethodFDA.cxx:85
void ClearAll()
delete and clear all class members
void PrintResults(const TString &, std::vector< Double_t > &, const Double_t) const
display fit parameters check maximum length of variable name
Double_t GetMvaValue(Double_t *err=nullptr, Double_t *errUpper=nullptr) override
returns MVA value for given event
void MakeClassSpecific(std::ostream &, const TString &) const override
write FDA-specific classifier response
void AddWeightsXMLTo(void *parent) const override
create XML description for LD classification and regression (for arbitrary number of output classes/t...
void ProcessOptions() override
the option string is decoded, for available options see "DeclareOptions"
void ReadWeightsFromXML(void *wghtnode) override
read coefficients from xml weight file
Double_t EstimatorFunction(std::vector< Double_t > &) override
compute estimator for given parameter set (to be minimised)
void DeclareOptions() override
define the options (their key words) that can be set in the option string
const std::vector< Float_t > & GetMulticlassValues() override
void CreateFormula()
translate formula string into TFormula, and parameter string into par ranges
void Init(void) override
default initialisation
void GetHelpMessage() const override
get help message text
/Fitter using MINUIT
Fitter using a Simulated Annealing Algorithm.
void FormattedOutput(const std::vector< Double_t > &, const std::vector< TString > &, const TString titleVars, const TString titleValues, MsgLogger &logger, TString format="%+1.3f")
formatted output of simple table
Definition Tools.cxx:887
TList * ParseFormatLine(TString theString, const char *sep=":")
Parse the string and cut into labels separated by ":".
Definition Tools.cxx:401
const TString & Color(const TString &)
human readable color strings
Definition Tools.cxx:828
void ReadAttr(void *node, const char *, T &value)
read attribute from xml
Definition Tools.h:329
void * GetChild(void *parent, const char *childname=nullptr)
get child node
Definition Tools.cxx:1150
void AddAttr(void *node, const char *, const T &value, Int_t precision=16)
add attribute to xml
Definition Tools.h:347
void * AddChild(void *parent, const char *childname, const char *content=nullptr, bool isRootNode=false)
add child node
Definition Tools.cxx:1124
void * GetNextChild(void *prevchild, const char *childname=nullptr)
XML helpers.
Definition Tools.cxx:1162
Singleton class for Global types used by TMVA.
Definition Types.h:71
@ kMulticlass
Definition Types.h:129
@ kClassification
Definition Types.h:127
@ kRegression
Definition Types.h:128
Collectable string class.
Definition TObjString.h:28
Basic string class.
Definition TString.h:138
TString & ReplaceAll(const TString &s1, const TString &s2)
Definition TString.h:712
static TString Format(const char *fmt,...)
Static method which formats a string using a printf style format descriptor and return a TString.
Definition TString.cxx:2384
Double_t y[n]
Definition legend1.C:17
double crossEntropy(ItProbability itProbabilityBegin, ItProbability itProbabilityEnd, ItTruth itTruthBegin, ItTruth itTruthEnd, ItDelta itDelta, ItDelta itDeltaEnd, ItInvActFnc itInvActFnc, double patternWeight)
cross entropy error function
create variable transformations
Config & gConfig()
Tools & gTools()
MsgLogger & Endl(MsgLogger &ml)
Definition MsgLogger.h:148
LongDouble_t Power(LongDouble_t x, LongDouble_t y)
Returns x raised to the power y.
Definition TMath.h:732
Short_t Abs(Short_t d)
Returns the absolute value of parameter Short_t d.
Definition TMathBase.h:124