Logo ROOT  
Reference Guide
 
Loading...
Searching...
No Matches
MethodFDA.cxx
Go to the documentation of this file.
1// @(#)root/tmva $Id$
2// Author: Andreas Hoecker, Peter Speckmayer, Joerg Stelzer
3
4/**********************************************************************************
5 * Project: TMVA - a Root-integrated toolkit for multivariate data analysis *
6 * Package: TMVA *
7 * Class : MethodFDA *
8 * *
9 * *
10 * Description: *
11 * Implementation *
12 * *
13 * Authors (alphabetical): *
14 * Andreas Hoecker <Andreas.Hocker@cern.ch> - CERN, Switzerland *
15 * Peter Speckmayer <speckmay@mail.cern.ch> - CERN, Switzerland *
16 * Joerg Stelzer <stelzer@cern.ch> - DESY, Germany *
17 * Maciej Kruk <mkruk@cern.ch> - IFJ PAN & AGH, Poland *
18 * *
19 * Copyright (c) 2005-2006: *
20 * CERN, Switzerland *
21 * MPI-K Heidelberg, Germany *
22 * *
23 * Redistribution and use in source and binary forms, with or without *
24 * modification, are permitted according to the terms listed in LICENSE *
25 * (see tmva/doc/LICENSE) *
26 **********************************************************************************/
27
28/*! \class TMVA::MethodFDA
29\ingroup TMVA
30
31Function discriminant analysis (FDA).
32
33This simple classifier
34fits any user-defined TFormula (via option configuration string) to
35the training data by requiring a formula response of 1 (0) to signal
36(background) events. The parameter fitting is done via the abstract
37class FitterBase, featuring Monte Carlo sampling, Genetic
38Algorithm, Simulated Annealing, MINUIT and combinations of these.
39
40Can compute regression value for one dimensional output
41*/
42
43#include "TMVA/MethodFDA.h"
44
46#include "TMVA/Config.h"
47#include "TMVA/Configurable.h"
48#include "TMVA/DataSetInfo.h"
49#include "TMVA/FitterBase.h"
50#include "TMVA/GeneticFitter.h"
51#include "TMVA/Interval.h"
52#include "TMVA/IFitterTarget.h"
53#include "TMVA/IMethod.h"
54#include "TMVA/MCFitter.h"
55#include "TMVA/MethodBase.h"
56#include "TMVA/MinuitFitter.h"
57#include "TMVA/MsgLogger.h"
58#include "TMVA/Timer.h"
59#include "TMVA/Tools.h"
61#include "TMVA/Types.h"
63
64#include "TList.h"
65#include "TFormula.h"
66#include "TString.h"
67#include "TObjString.h"
68#include "TRandom3.h"
69#include "TMath.h"
70
71#include <iostream>
72#include <algorithm>
73#include <iterator>
74#include <stdexcept>
75#include <sstream>
76
77using std::stringstream;
78
80
82
83////////////////////////////////////////////////////////////////////////////////
84/// standard constructor
85
87 const TString& methodTitle,
88 DataSetInfo& theData,
89 const TString& theOption)
90 : MethodBase( jobName, Types::kFDA, methodTitle, theData, theOption),
92 fFormula ( 0 ),
93 fNPars ( 0 ),
94 fFitter ( 0 ),
95 fConvergerFitter( 0 ),
96 fSumOfWeightsSig( 0 ),
97 fSumOfWeightsBkg( 0 ),
98 fSumOfWeights ( 0 ),
99 fOutputDimensions( 0 )
100{
101}
102
103////////////////////////////////////////////////////////////////////////////////
104/// constructor from weight file
105
107 const TString& theWeightFile)
108 : MethodBase( Types::kFDA, theData, theWeightFile),
109 IFitterTarget (),
110 fFormula ( 0 ),
111 fNPars ( 0 ),
112 fFitter ( 0 ),
113 fConvergerFitter( 0 ),
114 fSumOfWeightsSig( 0 ),
115 fSumOfWeightsBkg( 0 ),
116 fSumOfWeights ( 0 ),
117 fOutputDimensions( 0 )
118{
119}
120
121////////////////////////////////////////////////////////////////////////////////
122/// default initialisation
123
125{
126 fNPars = 0;
127
128 fBestPars.clear();
129
130 fSumOfWeights = 0;
131 fSumOfWeightsSig = 0;
132 fSumOfWeightsBkg = 0;
133
134 fFormulaStringP = "";
135 fParRangeStringP = "";
136 fFormulaStringT = "";
137 fParRangeStringT = "";
138
139 fFitMethod = "";
140 fConverger = "";
141
142 if( DoMulticlass() )
143 if (fMulticlassReturnVal == NULL) fMulticlassReturnVal = new std::vector<Float_t>();
144
145}
146
147////////////////////////////////////////////////////////////////////////////////
148/// define the options (their key words) that can be set in the option string
149///
150/// format of function string:
151///
152/// "x0*(0)+((1)/x1)**(2)..."
153///
154/// where "(i)" are the parameters, and "xi" the input variables
155///
156/// format of parameter string:
157///
158/// "(-1.2,3.4);(-2.3,4.55);..."
159///
160/// where the numbers in "(a,b)" correspond to the a=min, b=max parameter ranges;
161/// each parameter defined in the function string must have a corresponding range
162
164{
165 DeclareOptionRef( fFormulaStringP = "(0)", "Formula", "The discrimination formula" );
166 DeclareOptionRef( fParRangeStringP = "()", "ParRanges", "Parameter ranges" );
167
168 // fitter
169 DeclareOptionRef( fFitMethod = "MINUIT", "FitMethod", "Optimisation Method");
170 AddPreDefVal(TString("MC"));
171 AddPreDefVal(TString("GA"));
172 AddPreDefVal(TString("SA"));
173 AddPreDefVal(TString("MINUIT"));
174
175 DeclareOptionRef( fConverger = "None", "Converger", "FitMethod uses Converger to improve result");
176 AddPreDefVal(TString("None"));
177 AddPreDefVal(TString("MINUIT"));
178}
179
180////////////////////////////////////////////////////////////////////////////////
181/// translate formula string into TFormula, and parameter string into par ranges
182
184{
185 // process transient strings
186 fFormulaStringT = fFormulaStringP;
187
188 // interpret formula string
189
190 // replace the parameters "(i)" by the TFormula style "[i]"
191 for (UInt_t ipar=0; ipar<fNPars; ipar++) {
192 fFormulaStringT.ReplaceAll( TString::Format("(%i)",ipar), TString::Format("[%i]",ipar) );
193 }
194
195 // sanity check, there should be no "(i)", with 'i' a number anymore
196 for (Int_t ipar=fNPars; ipar<1000; ipar++) {
197 if (fFormulaStringT.Contains( TString::Format("(%i)",ipar) ))
198 Log() << kFATAL
199 << "<CreateFormula> Formula contains expression: \"" << TString::Format("(%i)",ipar) << "\", "
200 << "which cannot be attributed to a parameter; "
201 << "it may be that the number of variable ranges given via \"ParRanges\" "
202 << "does not match the number of parameters in the formula expression, please verify!"
203 << Endl;
204 }
205
206 // write the variables "xi" as additional parameters "[npar+i]"
207 for (Int_t ivar=GetNvar()-1; ivar >= 0; ivar--) {
208 fFormulaStringT.ReplaceAll( TString::Format("x%i",ivar), TString::Format("[%i]",ivar+fNPars) );
209 }
210
211 // sanity check, there should be no "xi", with 'i' a number anymore
212 for (UInt_t ivar=GetNvar(); ivar<1000; ivar++) {
213 if (fFormulaStringT.Contains( TString::Format("x%i",ivar) ))
214 Log() << kFATAL
215 << "<CreateFormula> Formula contains expression: \"" << TString::Format("x%i",ivar) << "\", "
216 << "which cannot be attributed to an input variable" << Endl;
217 }
218
219 Log() << "User-defined formula string : \"" << fFormulaStringP << "\"" << Endl;
220 Log() << "TFormula-compatible formula string: \"" << fFormulaStringT << "\"" << Endl;
221 Log() << kDEBUG << "Creating and compiling formula" << Endl;
222
223 // create TF1
224 if (fFormula) delete fFormula;
225 fFormula = new TFormula( "FDA_Formula", fFormulaStringT );
226
227 // is formula correct ?
228 if (!fFormula->IsValid())
229 Log() << kFATAL << "<ProcessOptions> Formula expression could not be properly compiled" << Endl;
230
231 // other sanity checks
232 if (fFormula->GetNpar() > (Int_t)(fNPars + GetNvar()))
233 Log() << kFATAL << "<ProcessOptions> Dubious number of parameters in formula expression: "
234 << fFormula->GetNpar() << " - compared to maximum allowed: " << fNPars + GetNvar() << Endl;
235}
236
237////////////////////////////////////////////////////////////////////////////////
238/// the option string is decoded, for available options see "DeclareOptions"
239
241{
242 // process transient strings
243 fParRangeStringT = fParRangeStringP;
244
245 // interpret parameter string
246 fParRangeStringT.ReplaceAll( " ", "" );
247 fNPars = fParRangeStringT.CountChar( ')' );
248
249 TList* parList = gTools().ParseFormatLine( fParRangeStringT, ";" );
250 if ((UInt_t)parList->GetSize() != fNPars) {
251 Log() << kFATAL << "<ProcessOptions> Mismatch in parameter string: "
252 << "the number of parameters: " << fNPars << " != ranges defined: "
253 << parList->GetSize() << "; the format of the \"ParRanges\" string "
254 << "must be: \"(-1.2,3.4);(-2.3,4.55);...\", "
255 << "where the numbers in \"(a,b)\" correspond to the a=min, b=max parameter ranges; "
256 << "each parameter defined in the function string must have a corresponding rang."
257 << Endl;
258 }
259
260 fParRange.resize( fNPars );
261 for (UInt_t ipar=0; ipar<fNPars; ipar++) fParRange[ipar] = 0;
262
263 for (UInt_t ipar=0; ipar<fNPars; ipar++) {
264 // parse (a,b)
265 TString str = ((TObjString*)parList->At(ipar))->GetString();
266 Ssiz_t istr = str.First( ',' );
267 TString pminS(str(1,istr-1));
268 TString pmaxS(str(istr+1,str.Length()-2-istr));
269
270 stringstream stmin; Float_t pmin=0; stmin << pminS.Data(); stmin >> pmin;
271 stringstream stmax; Float_t pmax=0; stmax << pmaxS.Data(); stmax >> pmax;
272
273 // sanity check
274 if (TMath::Abs(pmax-pmin) < 1.e-30) pmax = pmin;
275 if (pmin > pmax) Log() << kFATAL << "<ProcessOptions> max > min in interval for parameter: ["
276 << ipar << "] : [" << pmin << ", " << pmax << "] " << Endl;
277
278 Log() << kINFO << "Create parameter interval for parameter " << ipar << " : [" << pmin << "," << pmax << "]" << Endl;
279 fParRange[ipar] = new Interval( pmin, pmax );
280 }
281 delete parList;
282
283 // create formula
284 CreateFormula();
285
286
287 // copy parameter ranges for each output dimension ==================
288 fOutputDimensions = 1;
289 if( DoRegression() )
290 fOutputDimensions = DataInfo().GetNTargets();
291 if( DoMulticlass() )
292 fOutputDimensions = DataInfo().GetNClasses();
293
294 for( Int_t dim = 1; dim < fOutputDimensions; ++dim ){
295 for( UInt_t par = 0; par < fNPars; ++par ){
296 fParRange.push_back( fParRange.at(par) );
297 }
298 }
299 // ====================
300
301 // create minimiser
302 fConvergerFitter = (IFitterTarget*)this;
303 if (fConverger == "MINUIT") {
304 fConvergerFitter = new MinuitFitter( *this, TString::Format("%s_Converger_Minuit", GetName()), fParRange, GetOptions() );
305 SetOptions(dynamic_cast<Configurable*>(fConvergerFitter)->GetOptions());
306 }
307
308 if(fFitMethod == "MC")
309 fFitter = new MCFitter( *fConvergerFitter, TString::Format("%s_Fitter_MC", GetName()), fParRange, GetOptions() );
310 else if (fFitMethod == "GA")
311 fFitter = new GeneticFitter( *fConvergerFitter, TString::Format("%s_Fitter_GA", GetName()), fParRange, GetOptions() );
312 else if (fFitMethod == "SA")
313 fFitter = new SimulatedAnnealingFitter( *fConvergerFitter, TString::Format("%s_Fitter_SA", GetName()), fParRange, GetOptions() );
314 else if (fFitMethod == "MINUIT")
315 fFitter = new MinuitFitter( *fConvergerFitter, TString::Format("%s_Fitter_Minuit", GetName()), fParRange, GetOptions() );
316 else {
317 Log() << kFATAL << "<Train> Do not understand fit method:" << fFitMethod << Endl;
318 }
319
320 fFitter->CheckForUnusedOptions();
321}
322
323////////////////////////////////////////////////////////////////////////////////
324/// destructor
325
327{
328 ClearAll();
329}
330
331////////////////////////////////////////////////////////////////////////////////
332/// FDA can handle classification with 2 classes, multiclass classification and regression
333
335{
336 if (type == Types::kClassification && numberClasses == 2) return kTRUE;
337 if (type == Types::kMulticlass ) return kTRUE;
338 if (type == Types::kRegression ) return kTRUE;
339 return kFALSE;
340}
341
342
343////////////////////////////////////////////////////////////////////////////////
344/// delete and clear all class members
345
347{
348 // if there is more than one output dimension, the parameter ranges are the same again (object has been copied).
349 // hence, ... erase the copied pointers to assure, that they are deleted only once.
350 // fParRange.erase( fParRange.begin()+(fNPars), fParRange.end() );
351 for (UInt_t ipar=0; ipar<fParRange.size() && ipar<fNPars; ipar++) {
352 if (fParRange[ipar] != 0) { delete fParRange[ipar]; fParRange[ipar] = 0; }
353 }
354 fParRange.clear();
355
356 if (fFormula != 0) { delete fFormula; fFormula = 0; }
357 fBestPars.clear();
358}
359
360////////////////////////////////////////////////////////////////////////////////
361/// FDA training
362
364{
365 // cache training events
366 fSumOfWeights = 0;
367 fSumOfWeightsSig = 0;
368 fSumOfWeightsBkg = 0;
369
370 for (UInt_t ievt=0; ievt<GetNEvents(); ievt++) {
371
372 // read the training event
373 const Event* ev = GetEvent(ievt);
374
375 // true event copy
376 Float_t w = ev->GetWeight();
377
378 if (!DoRegression()) {
379 if (DataInfo().IsSignal(ev)) { fSumOfWeightsSig += w; }
380 else { fSumOfWeightsBkg += w; }
381 }
382 fSumOfWeights += w;
383 }
384
385 // sanity check
386 if (!DoRegression()) {
387 if (fSumOfWeightsSig <= 0 || fSumOfWeightsBkg <= 0) {
388 Log() << kFATAL << "<Train> Troubles in sum of weights: "
389 << fSumOfWeightsSig << " (S) : " << fSumOfWeightsBkg << " (B)" << Endl;
390 }
391 }
392 else if (fSumOfWeights <= 0) {
393 Log() << kFATAL << "<Train> Troubles in sum of weights: "
394 << fSumOfWeights << Endl;
395 }
396
397 // starting values (not used by all fitters)
398 fBestPars.clear();
399 for (std::vector<Interval*>::const_iterator parIt = fParRange.begin(); parIt != fParRange.end(); ++parIt) {
400 fBestPars.push_back( (*parIt)->GetMean() );
401 }
402
403 // execute the fit
404 Double_t estimator = fFitter->Run( fBestPars );
405
406 // print results
407 PrintResults( fFitMethod, fBestPars, estimator );
408
409 delete fFitter; fFitter = 0;
410 if (fConvergerFitter!=0 && fConvergerFitter!=(IFitterTarget*)this) {
411 delete fConvergerFitter;
412 fConvergerFitter = 0;
413 }
414 ExitFromTraining();
415}
416
417////////////////////////////////////////////////////////////////////////////////
418/// display fit parameters
419/// check maximum length of variable name
420
421void TMVA::MethodFDA::PrintResults( const TString& fitter, std::vector<Double_t>& pars, const Double_t estimator ) const
422{
423 Log() << kINFO;
424 Log() << kHEADER << "Results for parameter fit using \"" << fitter << "\" fitter:" << Endl;
425 std::vector<TString> parNames;
426 for (UInt_t ipar=0; ipar<pars.size(); ipar++) parNames.push_back( TString::Format("Par(%i)",ipar ) );
427 gTools().FormattedOutput( pars, parNames, "Parameter" , "Fit result", Log(), "%g" );
428 Log() << "Discriminator expression: \"" << fFormulaStringP << "\"" << Endl;
429 Log() << "Value of estimator at minimum: " << estimator << Endl;
430}
431
432////////////////////////////////////////////////////////////////////////////////
433/// compute estimator for given parameter set (to be minimised)
434
435Double_t TMVA::MethodFDA::EstimatorFunction( std::vector<Double_t>& pars )
436{
437 const Double_t sumOfWeights[] = { fSumOfWeightsBkg, fSumOfWeightsSig, fSumOfWeights };
438 Double_t estimator[] = { 0, 0, 0 };
439
440 Double_t result, deviation;
441 Double_t desired = 0.0;
442
443 // calculate the deviation from the desired value
444 if( DoRegression() ){
445 for (UInt_t ievt=0; ievt<GetNEvents(); ievt++) {
446 // read the training event
447 const TMVA::Event* ev = GetEvent(ievt);
448
449 for( Int_t dim = 0; dim < fOutputDimensions; ++dim ){
450 desired = ev->GetTarget( dim );
451 result = InterpretFormula( ev, pars.begin(), pars.end() );
452 deviation = TMath::Power(result - desired, 2);
453 estimator[2] += deviation * ev->GetWeight();
454 }
455 }
456 estimator[2] /= sumOfWeights[2];
457 // return value is sum over normalised signal and background contributions
458 return estimator[2];
459
460 }else if( DoMulticlass() ){
461 for (UInt_t ievt=0; ievt<GetNEvents(); ievt++) {
462 // read the training event
463 const TMVA::Event* ev = GetEvent(ievt);
464
465 CalculateMulticlassValues( ev, pars, *fMulticlassReturnVal );
466
468 for( Int_t dim = 0; dim < fOutputDimensions; ++dim ){
469 Double_t y = fMulticlassReturnVal->at(dim);
470 Double_t t = (ev->GetClass() == static_cast<UInt_t>(dim) ? 1.0 : 0.0 );
471 crossEntropy += t*log(y);
472 }
473 estimator[2] += ev->GetWeight()*crossEntropy;
474 }
475 estimator[2] /= sumOfWeights[2];
476 // return value is sum over normalised signal and background contributions
477 return estimator[2];
478
479 }else{
480 for (UInt_t ievt=0; ievt<GetNEvents(); ievt++) {
481 // read the training event
482 const TMVA::Event* ev = GetEvent(ievt);
483
484 desired = (DataInfo().IsSignal(ev) ? 1.0 : 0.0);
485 result = InterpretFormula( ev, pars.begin(), pars.end() );
486 deviation = TMath::Power(result - desired, 2);
487 estimator[Int_t(desired)] += deviation * ev->GetWeight();
488 }
489 estimator[0] /= sumOfWeights[0];
490 estimator[1] /= sumOfWeights[1];
491 // return value is sum over normalised signal and background contributions
492 return estimator[0] + estimator[1];
493 }
494}
495
496////////////////////////////////////////////////////////////////////////////////
497/// formula interpretation
498
499Double_t TMVA::MethodFDA::InterpretFormula( const Event* event, std::vector<Double_t>::iterator parBegin, std::vector<Double_t>::iterator parEnd )
500{
501 Int_t ipar = 0;
502 // std::cout << "pars ";
503 for( std::vector<Double_t>::iterator it = parBegin; it != parEnd; ++it ){
504 // std::cout << " i" << ipar << " val" << (*it);
505 fFormula->SetParameter( ipar, (*it) );
506 ++ipar;
507 }
508 for (UInt_t ivar=0; ivar<GetNvar(); ivar++) fFormula->SetParameter( ivar+ipar, event->GetValue(ivar) );
509
510 Double_t result = fFormula->Eval( 0 );
511 // std::cout << " result " << result << std::endl;
512 return result;
513}
514
515////////////////////////////////////////////////////////////////////////////////
516/// returns MVA value for given event
517
519{
520 const Event* ev = GetEvent();
521
522 // cannot determine error
523 NoErrorCalc(err, errUpper);
524
525 return InterpretFormula( ev, fBestPars.begin(), fBestPars.end() );
526}
527
528////////////////////////////////////////////////////////////////////////////////
529
530const std::vector<Float_t>& TMVA::MethodFDA::GetRegressionValues()
531{
532 if (fRegressionReturnVal == NULL) fRegressionReturnVal = new std::vector<Float_t>();
533 fRegressionReturnVal->clear();
534
535 const Event* ev = GetEvent();
536
537 Event* evT = new Event(*ev);
538
539 for( Int_t dim = 0; dim < fOutputDimensions; ++dim ){
540 Int_t offset = dim*fNPars;
541 evT->SetTarget(dim,InterpretFormula( ev, fBestPars.begin()+offset, fBestPars.begin()+offset+fNPars ) );
542 }
543 const Event* evT2 = GetTransformationHandler().InverseTransform( evT );
544 fRegressionReturnVal->push_back(evT2->GetTarget(0));
545
546 delete evT;
547
548 return (*fRegressionReturnVal);
549}
550
551////////////////////////////////////////////////////////////////////////////////
552
553const std::vector<Float_t>& TMVA::MethodFDA::GetMulticlassValues()
554{
555 if (fMulticlassReturnVal == NULL) fMulticlassReturnVal = new std::vector<Float_t>();
556 fMulticlassReturnVal->clear();
557 std::vector<Float_t> temp;
558
559 // returns MVA value for given event
560 const TMVA::Event* evt = GetEvent();
561
562 CalculateMulticlassValues( evt, fBestPars, temp );
563
564 UInt_t nClasses = DataInfo().GetNClasses();
565 for(UInt_t iClass=0; iClass<nClasses; iClass++){
566 Double_t norm = 0.0;
567 for(UInt_t j=0;j<nClasses;j++){
568 if(iClass!=j)
569 norm+=exp(temp[j]-temp[iClass]);
570 }
571 (*fMulticlassReturnVal).push_back(1.0/(1.0+norm));
572 }
573
574 return (*fMulticlassReturnVal);
575}
576
577
578////////////////////////////////////////////////////////////////////////////////
579/// calculate the values for multiclass
580
581void TMVA::MethodFDA::CalculateMulticlassValues( const TMVA::Event*& evt, std::vector<Double_t>& parameters, std::vector<Float_t>& values)
582{
583 values.clear();
584
585 // std::copy( parameters.begin(), parameters.end(), std::ostream_iterator<double>( std::cout, " " ) );
586 // std::cout << std::endl;
587
588 // char inp;
589 // std::cin >> inp;
590
591 for( Int_t dim = 0; dim < fOutputDimensions; ++dim ){ // check for all other dimensions (=classes)
592 Int_t offset = dim*fNPars;
593 Double_t value = InterpretFormula( evt, parameters.begin()+offset, parameters.begin()+offset+fNPars );
594 // std::cout << "dim : " << dim << " value " << value << " offset " << offset << std::endl;
595 values.push_back( value );
596 }
597}
598
599////////////////////////////////////////////////////////////////////////////////
600/// read back the training results from a file (stream)
601
603{
604 // retrieve best function parameters
605 // coverity[tainted_data_argument]
606 istr >> fNPars;
607
608 fBestPars.clear();
609 fBestPars.resize( fNPars );
610 for (UInt_t ipar=0; ipar<fNPars; ipar++) istr >> fBestPars[ipar];
611}
612
613////////////////////////////////////////////////////////////////////////////////
614/// create XML description for LD classification and regression
615/// (for arbitrary number of output classes/targets)
616
617void TMVA::MethodFDA::AddWeightsXMLTo( void* parent ) const
618{
619 void* wght = gTools().AddChild(parent, "Weights");
620 gTools().AddAttr( wght, "NPars", fNPars );
621 gTools().AddAttr( wght, "NDim", fOutputDimensions );
622 for (UInt_t ipar=0; ipar<fNPars*fOutputDimensions; ipar++) {
623 void* coeffxml = gTools().AddChild( wght, "Parameter" );
624 gTools().AddAttr( coeffxml, "Index", ipar );
625 gTools().AddAttr( coeffxml, "Value", fBestPars[ipar] );
626 }
627
628 // write formula
629 gTools().AddAttr( wght, "Formula", fFormulaStringP );
630}
631
632////////////////////////////////////////////////////////////////////////////////
633/// read coefficients from xml weight file
634
636{
637 gTools().ReadAttr( wghtnode, "NPars", fNPars );
638
639 if(gTools().HasAttr( wghtnode, "NDim")) {
640 gTools().ReadAttr( wghtnode, "NDim" , fOutputDimensions );
641 } else {
642 // older weight files don't have this attribute
643 fOutputDimensions = 1;
644 }
645
646 fBestPars.clear();
647 fBestPars.resize( fNPars*fOutputDimensions );
648
649 void* ch = gTools().GetChild(wghtnode);
650 Double_t par;
651 UInt_t ipar;
652 while (ch) {
653 gTools().ReadAttr( ch, "Index", ipar );
654 gTools().ReadAttr( ch, "Value", par );
655
656 // sanity check
657 if (ipar >= fNPars*fOutputDimensions) Log() << kFATAL << "<ReadWeightsFromXML> index out of range: "
658 << ipar << " >= " << fNPars << Endl;
659 fBestPars[ipar] = par;
660
661 ch = gTools().GetNextChild(ch);
662 }
663
664 // read formula
665 gTools().ReadAttr( wghtnode, "Formula", fFormulaStringP );
666
667 // create the TFormula
668 CreateFormula();
669}
670
671////////////////////////////////////////////////////////////////////////////////
672/// write FDA-specific classifier response
673
674void TMVA::MethodFDA::MakeClassSpecific( std::ostream& fout, const TString& className ) const
675{
676 fout << " double fParameter[" << fNPars << "];" << std::endl;
677 fout << "};" << std::endl;
678 fout << "" << std::endl;
679 fout << "inline void " << className << "::Initialize() " << std::endl;
680 fout << "{" << std::endl;
681 for(UInt_t ipar=0; ipar<fNPars; ipar++) {
682 fout << " fParameter[" << ipar << "] = " << fBestPars[ipar] << ";" << std::endl;
683 }
684 fout << "}" << std::endl;
685 fout << std::endl;
686 fout << "inline double " << className << "::GetMvaValue__( const std::vector<double>& inputValues ) const" << std::endl;
687 fout << "{" << std::endl;
688 fout << " // interpret the formula" << std::endl;
689
690 // replace parameters
691 TString str = fFormulaStringT;
692 for (UInt_t ipar=0; ipar<fNPars; ipar++) {
693 str.ReplaceAll( TString::Format("[%i]", ipar), TString::Format("fParameter[%i]", ipar) );
694 }
695
696 // replace input variables
697 for (UInt_t ivar=0; ivar<GetNvar(); ivar++) {
698 str.ReplaceAll( TString::Format("[%i]", ivar+fNPars), TString::Format("inputValues[%i]", ivar) );
699 }
700
701 fout << " double retval = " << str << ";" << std::endl;
702 fout << std::endl;
703 fout << " return retval; " << std::endl;
704 fout << "}" << std::endl;
705 fout << std::endl;
706 fout << "// Clean up" << std::endl;
707 fout << "inline void " << className << "::Clear() " << std::endl;
708 fout << "{" << std::endl;
709 fout << " // nothing to clear" << std::endl;
710 fout << "}" << std::endl;
711}
712
713////////////////////////////////////////////////////////////////////////////////
714/// get help message text
715///
716/// typical length of text line:
717/// "|--------------------------------------------------------------|"
718
720{
721 Log() << Endl;
722 Log() << gTools().Color("bold") << "--- Short description:" << gTools().Color("reset") << Endl;
723 Log() << Endl;
724 Log() << "The function discriminant analysis (FDA) is a classifier suitable " << Endl;
725 Log() << "to solve linear or simple nonlinear discrimination problems." << Endl;
726 Log() << Endl;
727 Log() << "The user provides the desired function with adjustable parameters" << Endl;
728 Log() << "via the configuration option string, and FDA fits the parameters to" << Endl;
729 Log() << "it, requiring the signal (background) function value to be as close" << Endl;
730 Log() << "as possible to 1 (0). Its advantage over the more involved and" << Endl;
731 Log() << "automatic nonlinear discriminators is the simplicity and transparency " << Endl;
732 Log() << "of the discrimination expression. A shortcoming is that FDA will" << Endl;
733 Log() << "underperform for involved problems with complicated, phase space" << Endl;
734 Log() << "dependent nonlinear correlations." << Endl;
735 Log() << Endl;
736 Log() << "Please consult the Users Guide for the format of the formula string" << Endl;
737 Log() << "and the allowed parameter ranges:" << Endl;
738 if (gConfig().WriteOptionsReference()) {
739 Log() << "<a href=\"https://github.com/root-project/root/blob/master/documentation/tmva/UsersGuide/TMVAUsersGuide.pdf\">"
740 << "TMVAUsersGuide.pdf</a>" << Endl;
741 }
742 else Log() << "documentation/tmva/UsersGuide/TMVAUsersGuide.pdf" << Endl;
743 Log() << Endl;
744 Log() << gTools().Color("bold") << "--- Performance optimisation:" << gTools().Color("reset") << Endl;
745 Log() << Endl;
746 Log() << "The FDA performance depends on the complexity and fidelity of the" << Endl;
747 Log() << "user-defined discriminator function. As a general rule, it should" << Endl;
748 Log() << "be able to reproduce the discrimination power of any linear" << Endl;
749 Log() << "discriminant analysis. To reach into the nonlinear domain, it is" << Endl;
750 Log() << "useful to inspect the correlation profiles of the input variables," << Endl;
751 Log() << "and add quadratic and higher polynomial terms between variables as" << Endl;
752 Log() << "necessary. Comparison with more involved nonlinear classifiers can" << Endl;
753 Log() << "be used as a guide." << Endl;
754 Log() << Endl;
755 Log() << gTools().Color("bold") << "--- Performance tuning via configuration options:" << gTools().Color("reset") << Endl;
756 Log() << Endl;
757 Log() << "Depending on the function used, the choice of \"FitMethod\" is" << Endl;
758 Log() << "crucial for getting valuable solutions with FDA. As a guideline it" << Endl;
759 Log() << "is recommended to start with \"FitMethod=MINUIT\". When more complex" << Endl;
760 Log() << "functions are used where MINUIT does not converge to reasonable" << Endl;
761 Log() << "results, the user should switch to non-gradient FitMethods such" << Endl;
762 Log() << "as GeneticAlgorithm (GA) or Monte Carlo (MC). It might prove to be" << Endl;
763 Log() << "useful to combine GA (or MC) with MINUIT by setting the option" << Endl;
764 Log() << "\"Converger=MINUIT\". GA (MC) will then set the starting parameters" << Endl;
765 Log() << "for MINUIT such that the basic quality of GA (MC) of finding global" << Endl;
766 Log() << "minima is combined with the efficacy of MINUIT of finding local" << Endl;
767 Log() << "minima." << Endl;
768}
#define REGISTER_METHOD(CLASS)
for example
int Int_t
Definition RtypesCore.h:45
float Float_t
Definition RtypesCore.h:57
constexpr Bool_t kFALSE
Definition RtypesCore.h:101
constexpr Bool_t kTRUE
Definition RtypesCore.h:100
#define ClassImp(name)
Definition Rtypes.h:377
Option_t Option_t TPoint TPoint const char GetTextMagnitude GetFillStyle GetLineColor GetLineWidth GetMarkerStyle GetTextAlign GetTextColor GetTextSize void char Point_t Rectangle_t WindowAttributes_t Float_t Float_t Float_t Int_t Int_t UInt_t UInt_t Rectangle_t Int_t Int_t Window_t TString Int_t GCValues_t GetPrimarySelectionOwner GetDisplay GetScreen GetColormap GetNativeEvent const char const char dpyName wid window const char font_name cursor keysym reg const char only_if_exist regb h Point_t winding char text const char depth char const char Int_t count const char ColorStruct_t color const char Pixmap_t Pixmap_t PictureAttributes_t attr const char char ret_data h unsigned char height h offset
Option_t Option_t TPoint TPoint const char GetTextMagnitude GetFillStyle GetLineColor GetLineWidth GetMarkerStyle GetTextAlign GetTextColor GetTextSize void char Point_t Rectangle_t WindowAttributes_t Float_t Float_t Float_t Int_t Int_t UInt_t UInt_t Rectangle_t result
Option_t Option_t TPoint TPoint const char GetTextMagnitude GetFillStyle GetLineColor GetLineWidth GetMarkerStyle GetTextAlign GetTextColor GetTextSize void value
Option_t Option_t TPoint TPoint const char GetTextMagnitude GetFillStyle GetLineColor GetLineWidth GetMarkerStyle GetTextAlign GetTextColor GetTextSize void char Point_t Rectangle_t WindowAttributes_t Float_t Float_t Float_t Int_t Int_t UInt_t UInt_t Rectangle_t Int_t Int_t Window_t TString Int_t GCValues_t GetPrimarySelectionOwner GetDisplay GetScreen GetColormap GetNativeEvent const char const char dpyName wid window const char font_name cursor keysym reg const char only_if_exist regb h Point_t winding char text const char depth char const char Int_t count const char ColorStruct_t color const char Pixmap_t Pixmap_t PictureAttributes_t attr const char char ret_data h unsigned char height h Atom_t Int_t ULong_t ULong_t unsigned char prop_list Atom_t Atom_t Atom_t Time_t type
virtual Int_t GetSize() const
Return the capacity of the collection, i.e.
The Formula class.
Definition TFormula.h:89
A doubly linked list.
Definition TList.h:38
TObject * At(Int_t idx) const override
Returns the object at position idx. Returns 0 if idx is out of range.
Definition TList.cxx:355
Class that contains all the data information.
Definition DataSetInfo.h:62
Float_t GetValue(UInt_t ivar) const
return value of i'th variable
Definition Event.cxx:236
void SetTarget(UInt_t itgt, Float_t value)
set the target value (dimension itgt) to value
Definition Event.cxx:367
Double_t GetWeight() const
return the event weight - depending on whether the flag IgnoreNegWeightsInTraining is or not.
Definition Event.cxx:389
UInt_t GetClass() const
Definition Event.h:86
Float_t GetTarget(UInt_t itgt) const
Definition Event.h:102
Fitter using a Genetic Algorithm.
Interface for a fitter 'target'.
The TMVA::Interval Class.
Definition Interval.h:61
Fitter using Monte Carlo sampling of parameters.
Definition MCFitter.h:44
Virtual base Class for all MVA method.
Definition MethodBase.h:111
Function discriminant analysis (FDA).
Definition MethodFDA.h:61
void Train(void)
FDA training.
void AddWeightsXMLTo(void *parent) const
create XML description for LD classification and regression (for arbitrary number of output classes/t...
Double_t EstimatorFunction(std::vector< Double_t > &)
compute estimator for given parameter set (to be minimised)
virtual ~MethodFDA(void)
destructor
Double_t InterpretFormula(const Event *, std::vector< Double_t >::iterator begin, std::vector< Double_t >::iterator end)
formula interpretation
Bool_t HasAnalysisType(Types::EAnalysisType type, UInt_t numberClasses, UInt_t numberTargets)
FDA can handle classification with 2 classes and regression with one regression-target.
void ReadWeightsFromXML(void *wghtnode)
read coefficients from xml weight file
void CalculateMulticlassValues(const TMVA::Event *&evt, std::vector< Double_t > &parameters, std::vector< Float_t > &values)
calculate the values for multiclass
void ReadWeightsFromStream(std::istream &i)
read back the training results from a file (stream)
virtual const std::vector< Float_t > & GetMulticlassValues()
MethodFDA(const TString &jobName, const TString &methodTitle, DataSetInfo &theData, const TString &theOption="")
standard constructor
Definition MethodFDA.cxx:86
void Init(void)
default initialisation
void ClearAll()
delete and clear all class members
void PrintResults(const TString &, std::vector< Double_t > &, const Double_t) const
display fit parameters check maximum length of variable name
void MakeClassSpecific(std::ostream &, const TString &) const
write FDA-specific classifier response
virtual const std::vector< Float_t > & GetRegressionValues()
void ProcessOptions()
the option string is decoded, for available options see "DeclareOptions"
void CreateFormula()
translate formula string into TFormula, and parameter string into par ranges
void DeclareOptions()
define the options (their key words) that can be set in the option string
Double_t GetMvaValue(Double_t *err=nullptr, Double_t *errUpper=nullptr)
returns MVA value for given event
void GetHelpMessage() const
get help message text
/Fitter using MINUIT
Fitter using a Simulated Annealing Algorithm.
void FormattedOutput(const std::vector< Double_t > &, const std::vector< TString > &, const TString titleVars, const TString titleValues, MsgLogger &logger, TString format="%+1.3f")
formatted output of simple table
Definition Tools.cxx:887
TList * ParseFormatLine(TString theString, const char *sep=":")
Parse the string and cut into labels separated by ":".
Definition Tools.cxx:401
const TString & Color(const TString &)
human readable color strings
Definition Tools.cxx:828
void ReadAttr(void *node, const char *, T &value)
read attribute from xml
Definition Tools.h:329
void * GetChild(void *parent, const char *childname=nullptr)
get child node
Definition Tools.cxx:1150
void AddAttr(void *node, const char *, const T &value, Int_t precision=16)
add attribute to xml
Definition Tools.h:347
void * AddChild(void *parent, const char *childname, const char *content=nullptr, bool isRootNode=false)
add child node
Definition Tools.cxx:1124
void * GetNextChild(void *prevchild, const char *childname=nullptr)
XML helpers.
Definition Tools.cxx:1162
Singleton class for Global types used by TMVA.
Definition Types.h:71
@ kMulticlass
Definition Types.h:129
@ kClassification
Definition Types.h:127
@ kRegression
Definition Types.h:128
Collectable string class.
Definition TObjString.h:28
Basic string class.
Definition TString.h:139
Ssiz_t Length() const
Definition TString.h:417
Ssiz_t First(char c) const
Find first occurrence of a character c.
Definition TString.cxx:538
const char * Data() const
Definition TString.h:376
TString & ReplaceAll(const TString &s1, const TString &s2)
Definition TString.h:704
static TString Format(const char *fmt,...)
Static method which formats a string using a printf style format descriptor and return a TString.
Definition TString.cxx:2378
Double_t y[n]
Definition legend1.C:17
double crossEntropy(ItProbability itProbabilityBegin, ItProbability itProbabilityEnd, ItTruth itTruthBegin, ItTruth itTruthEnd, ItDelta itDelta, ItDelta itDeltaEnd, ItInvActFnc itInvActFnc, double patternWeight)
cross entropy error function
Config & gConfig()
Tools & gTools()
MsgLogger & Endl(MsgLogger &ml)
Definition MsgLogger.h:148
LongDouble_t Power(LongDouble_t x, LongDouble_t y)
Returns x raised to the power y.
Definition TMath.h:721
Short_t Abs(Short_t d)
Returns the absolute value of parameter Short_t d.
Definition TMathBase.h:123