Logo ROOT  
Reference Guide
MethodFDA.cxx
Go to the documentation of this file.
1// @(#)root/tmva $Id$
2// Author: Andreas Hoecker, Peter Speckmayer, Joerg Stelzer
3
4/**********************************************************************************
5 * Project: TMVA - a Root-integrated toolkit for multivariate data analysis *
6 * Package: TMVA *
7 * Class : MethodFDA *
8 * Web : http://tmva.sourceforge.net *
9 * *
10 * Description: *
11 * Implementation *
12 * *
13 * Authors (alphabetical): *
14 * Andreas Hoecker <Andreas.Hocker@cern.ch> - CERN, Switzerland *
15 * Peter Speckmayer <speckmay@mail.cern.ch> - CERN, Switzerland *
16 * Joerg Stelzer <stelzer@cern.ch> - DESY, Germany *
17 * Maciej Kruk <mkruk@cern.ch> - IFJ PAN & AGH, Poland *
18 * *
19 * Copyright (c) 2005-2006: *
20 * CERN, Switzerland *
21 * MPI-K Heidelberg, Germany *
22 * *
23 * Redistribution and use in source and binary forms, with or without *
24 * modification, are permitted according to the terms listed in LICENSE *
25 * (http://tmva.sourceforge.net/LICENSE) *
26 **********************************************************************************/
27
28/*! \class TMVA::MethodFDA
29\ingroup TMVA
30
31Function discriminant analysis (FDA).
32
33This simple classifier
34fits any user-defined TFormula (via option configuration string) to
35the training data by requiring a formula response of 1 (0) to signal
36(background) events. The parameter fitting is done via the abstract
37class FitterBase, featuring Monte Carlo sampling, Genetic
38Algorithm, Simulated Annealing, MINUIT and combinations of these.
39
40Can compute regression value for one dimensional output
41*/
42
43#include "TMVA/MethodFDA.h"
44
46#include "TMVA/Config.h"
47#include "TMVA/Configurable.h"
48#include "TMVA/DataSetInfo.h"
49#include "TMVA/FitterBase.h"
50#include "TMVA/GeneticFitter.h"
51#include "TMVA/Interval.h"
52#include "TMVA/IFitterTarget.h"
53#include "TMVA/IMethod.h"
54#include "TMVA/MCFitter.h"
55#include "TMVA/MethodBase.h"
56#include "TMVA/MinuitFitter.h"
57#include "TMVA/MsgLogger.h"
58#include "TMVA/Timer.h"
59#include "TMVA/Tools.h"
61#include "TMVA/Types.h"
63
64#include "Riostream.h"
65#include "TList.h"
66#include "TFormula.h"
67#include "TString.h"
68#include "TObjString.h"
69#include "TRandom3.h"
70#include "TMath.h"
71
72#include <algorithm>
73#include <iterator>
74#include <stdexcept>
75#include <sstream>
76
77using std::stringstream;
78
80
82
83////////////////////////////////////////////////////////////////////////////////
84/// standard constructor
85
87 const TString& methodTitle,
88 DataSetInfo& theData,
89 const TString& theOption)
90 : MethodBase( jobName, Types::kFDA, methodTitle, theData, theOption),
92 fFormula ( 0 ),
93 fNPars ( 0 ),
94 fFitter ( 0 ),
95 fConvergerFitter( 0 ),
96 fSumOfWeightsSig( 0 ),
97 fSumOfWeightsBkg( 0 ),
98 fSumOfWeights ( 0 ),
99 fOutputDimensions( 0 )
100{
101}
102
103////////////////////////////////////////////////////////////////////////////////
104/// constructor from weight file
105
107 const TString& theWeightFile)
108 : MethodBase( Types::kFDA, theData, theWeightFile),
109 IFitterTarget (),
110 fFormula ( 0 ),
111 fNPars ( 0 ),
112 fFitter ( 0 ),
113 fConvergerFitter( 0 ),
114 fSumOfWeightsSig( 0 ),
115 fSumOfWeightsBkg( 0 ),
116 fSumOfWeights ( 0 ),
117 fOutputDimensions( 0 )
118{
119}
120
121////////////////////////////////////////////////////////////////////////////////
122/// default initialisation
123
125{
126 fNPars = 0;
127
128 fBestPars.clear();
129
130 fSumOfWeights = 0;
131 fSumOfWeightsSig = 0;
132 fSumOfWeightsBkg = 0;
133
134 fFormulaStringP = "";
135 fParRangeStringP = "";
136 fFormulaStringT = "";
137 fParRangeStringT = "";
138
139 fFitMethod = "";
140 fConverger = "";
141
142 if( DoMulticlass() )
143 if (fMulticlassReturnVal == NULL) fMulticlassReturnVal = new std::vector<Float_t>();
144
145}
146
147////////////////////////////////////////////////////////////////////////////////
148/// define the options (their key words) that can be set in the option string
149///
150/// format of function string:
151///
152/// "x0*(0)+((1)/x1)**(2)..."
153///
154/// where "[i]" are the parameters, and "xi" the input variables
155///
156/// format of parameter string:
157///
158/// "(-1.2,3.4);(-2.3,4.55);..."
159///
160/// where the numbers in "(a,b)" correspond to the a=min, b=max parameter ranges;
161/// each parameter defined in the function string must have a corresponding range
162
164{
165 DeclareOptionRef( fFormulaStringP = "(0)", "Formula", "The discrimination formula" );
166 DeclareOptionRef( fParRangeStringP = "()", "ParRanges", "Parameter ranges" );
167
168 // fitter
169 DeclareOptionRef( fFitMethod = "MINUIT", "FitMethod", "Optimisation Method");
170 AddPreDefVal(TString("MC"));
171 AddPreDefVal(TString("GA"));
172 AddPreDefVal(TString("SA"));
173 AddPreDefVal(TString("MINUIT"));
174
175 DeclareOptionRef( fConverger = "None", "Converger", "FitMethod uses Converger to improve result");
176 AddPreDefVal(TString("None"));
177 AddPreDefVal(TString("MINUIT"));
178}
179
180////////////////////////////////////////////////////////////////////////////////
181/// translate formula string into TFormula, and parameter string into par ranges
182
184{
185 // process transient strings
186 fFormulaStringT = fFormulaStringP;
187
188 // interpret formula string
189
190 // replace the parameters "(i)" by the TFormula style "[i]"
191 for (UInt_t ipar=0; ipar<fNPars; ipar++) {
192 fFormulaStringT.ReplaceAll( Form("(%i)",ipar), Form("[%i]",ipar) );
193 }
194
195 // sanity check, there should be no "(i)", with 'i' a number anymore
196 for (Int_t ipar=fNPars; ipar<1000; ipar++) {
197 if (fFormulaStringT.Contains( Form("(%i)",ipar) ))
198 Log() << kFATAL
199 << "<CreateFormula> Formula contains expression: \"" << Form("(%i)",ipar) << "\", "
200 << "which cannot be attributed to a parameter; "
201 << "it may be that the number of variable ranges given via \"ParRanges\" "
202 << "does not match the number of parameters in the formula expression, please verify!"
203 << Endl;
204 }
205
206 // write the variables "xi" as additional parameters "[npar+i]"
207 for (Int_t ivar=GetNvar()-1; ivar >= 0; ivar--) {
208 fFormulaStringT.ReplaceAll( Form("x%i",ivar), Form("[%i]",ivar+fNPars) );
209 }
210
211 // sanity check, there should be no "xi", with 'i' a number anymore
212 for (UInt_t ivar=GetNvar(); ivar<1000; ivar++) {
213 if (fFormulaStringT.Contains( Form("x%i",ivar) ))
214 Log() << kFATAL
215 << "<CreateFormula> Formula contains expression: \"" << Form("x%i",ivar) << "\", "
216 << "which cannot be attributed to an input variable" << Endl;
217 }
218
219 Log() << "User-defined formula string : \"" << fFormulaStringP << "\"" << Endl;
220 Log() << "TFormula-compatible formula string: \"" << fFormulaStringT << "\"" << Endl;
221 Log() << kDEBUG << "Creating and compiling formula" << Endl;
222
223 // create TF1
224 if (fFormula) delete fFormula;
225 fFormula = new TFormula( "FDA_Formula", fFormulaStringT );
226
227 // is formula correct ?
228 if (!fFormula->IsValid())
229 Log() << kFATAL << "<ProcessOptions> Formula expression could not be properly compiled" << Endl;
230
231 // other sanity checks
232 if (fFormula->GetNpar() > (Int_t)(fNPars + GetNvar()))
233 Log() << kFATAL << "<ProcessOptions> Dubious number of parameters in formula expression: "
234 << fFormula->GetNpar() << " - compared to maximum allowed: " << fNPars + GetNvar() << Endl;
235}
236
237////////////////////////////////////////////////////////////////////////////////
238/// the option string is decoded, for available options see "DeclareOptions"
239
241{
242 // process transient strings
243 fParRangeStringT = fParRangeStringP;
244
245 // interpret parameter string
246 fParRangeStringT.ReplaceAll( " ", "" );
247 fNPars = fParRangeStringT.CountChar( ')' );
248
249 TList* parList = gTools().ParseFormatLine( fParRangeStringT, ";" );
250 if ((UInt_t)parList->GetSize() != fNPars) {
251 Log() << kFATAL << "<ProcessOptions> Mismatch in parameter string: "
252 << "the number of parameters: " << fNPars << " != ranges defined: "
253 << parList->GetSize() << "; the format of the \"ParRanges\" string "
254 << "must be: \"(-1.2,3.4);(-2.3,4.55);...\", "
255 << "where the numbers in \"(a,b)\" correspond to the a=min, b=max parameter ranges; "
256 << "each parameter defined in the function string must have a corresponding rang."
257 << Endl;
258 }
259
260 fParRange.resize( fNPars );
261 for (UInt_t ipar=0; ipar<fNPars; ipar++) fParRange[ipar] = 0;
262
263 for (UInt_t ipar=0; ipar<fNPars; ipar++) {
264 // parse (a,b)
265 TString str = ((TObjString*)parList->At(ipar))->GetString();
266 Ssiz_t istr = str.First( ',' );
267 TString pminS(str(1,istr-1));
268 TString pmaxS(str(istr+1,str.Length()-2-istr));
269
270 stringstream stmin; Float_t pmin=0; stmin << pminS.Data(); stmin >> pmin;
271 stringstream stmax; Float_t pmax=0; stmax << pmaxS.Data(); stmax >> pmax;
272
273 // sanity check
274 if (TMath::Abs(pmax-pmin) < 1.e-30) pmax = pmin;
275 if (pmin > pmax) Log() << kFATAL << "<ProcessOptions> max > min in interval for parameter: ["
276 << ipar << "] : [" << pmin << ", " << pmax << "] " << Endl;
277
278 Log() << kINFO << "Create parameter interval for parameter " << ipar << " : [" << pmin << "," << pmax << "]" << Endl;
279 fParRange[ipar] = new Interval( pmin, pmax );
280 }
281 delete parList;
282
283 // create formula
284 CreateFormula();
285
286
287 // copy parameter ranges for each output dimension ==================
288 fOutputDimensions = 1;
289 if( DoRegression() )
290 fOutputDimensions = DataInfo().GetNTargets();
291 if( DoMulticlass() )
292 fOutputDimensions = DataInfo().GetNClasses();
293
294 for( Int_t dim = 1; dim < fOutputDimensions; ++dim ){
295 for( UInt_t par = 0; par < fNPars; ++par ){
296 fParRange.push_back( fParRange.at(par) );
297 }
298 }
299 // ====================
300
301 // create minimiser
302 fConvergerFitter = (IFitterTarget*)this;
303 if (fConverger == "MINUIT") {
304 fConvergerFitter = new MinuitFitter( *this, Form("%s_Converger_Minuit", GetName()), fParRange, GetOptions() );
305 SetOptions(dynamic_cast<Configurable*>(fConvergerFitter)->GetOptions());
306 }
307
308 if(fFitMethod == "MC")
309 fFitter = new MCFitter( *fConvergerFitter, Form("%s_Fitter_MC", GetName()), fParRange, GetOptions() );
310 else if (fFitMethod == "GA")
311 fFitter = new GeneticFitter( *fConvergerFitter, Form("%s_Fitter_GA", GetName()), fParRange, GetOptions() );
312 else if (fFitMethod == "SA")
313 fFitter = new SimulatedAnnealingFitter( *fConvergerFitter, Form("%s_Fitter_SA", GetName()), fParRange, GetOptions() );
314 else if (fFitMethod == "MINUIT")
315 fFitter = new MinuitFitter( *fConvergerFitter, Form("%s_Fitter_Minuit", GetName()), fParRange, GetOptions() );
316 else {
317 Log() << kFATAL << "<Train> Do not understand fit method:" << fFitMethod << Endl;
318 }
319
320 fFitter->CheckForUnusedOptions();
321}
322
323////////////////////////////////////////////////////////////////////////////////
324/// destructor
325
327{
328 ClearAll();
329}
330
331////////////////////////////////////////////////////////////////////////////////
332/// FDA can handle classification with 2 classes and regression with one regression-target
333
335{
336 if (type == Types::kClassification && numberClasses == 2) return kTRUE;
337 if (type == Types::kMulticlass ) return kTRUE;
338 if (type == Types::kRegression ) return kTRUE;
339 return kFALSE;
340}
341
342
343////////////////////////////////////////////////////////////////////////////////
344/// delete and clear all class members
345
347{
348 // if there is more than one output dimension, the paramater ranges are the same again (object has been copied).
349 // hence, ... erase the copied pointers to assure, that they are deleted only once.
350 // fParRange.erase( fParRange.begin()+(fNPars), fParRange.end() );
351 for (UInt_t ipar=0; ipar<fParRange.size() && ipar<fNPars; ipar++) {
352 if (fParRange[ipar] != 0) { delete fParRange[ipar]; fParRange[ipar] = 0; }
353 }
354 fParRange.clear();
355
356 if (fFormula != 0) { delete fFormula; fFormula = 0; }
357 fBestPars.clear();
358}
359
360////////////////////////////////////////////////////////////////////////////////
361/// FDA training
362
364{
365 // cache training events
366 fSumOfWeights = 0;
367 fSumOfWeightsSig = 0;
368 fSumOfWeightsBkg = 0;
369
370 for (UInt_t ievt=0; ievt<GetNEvents(); ievt++) {
371
372 // read the training event
373 const Event* ev = GetEvent(ievt);
374
375 // true event copy
376 Float_t w = ev->GetWeight();
377
378 if (!DoRegression()) {
379 if (DataInfo().IsSignal(ev)) { fSumOfWeightsSig += w; }
380 else { fSumOfWeightsBkg += w; }
381 }
382 fSumOfWeights += w;
383 }
384
385 // sanity check
386 if (!DoRegression()) {
387 if (fSumOfWeightsSig <= 0 || fSumOfWeightsBkg <= 0) {
388 Log() << kFATAL << "<Train> Troubles in sum of weights: "
389 << fSumOfWeightsSig << " (S) : " << fSumOfWeightsBkg << " (B)" << Endl;
390 }
391 }
392 else if (fSumOfWeights <= 0) {
393 Log() << kFATAL << "<Train> Troubles in sum of weights: "
394 << fSumOfWeights << Endl;
395 }
396
397 // starting values (not used by all fitters)
398 fBestPars.clear();
399 for (std::vector<Interval*>::const_iterator parIt = fParRange.begin(); parIt != fParRange.end(); ++parIt) {
400 fBestPars.push_back( (*parIt)->GetMean() );
401 }
402
403 // execute the fit
404 Double_t estimator = fFitter->Run( fBestPars );
405
406 // print results
407 PrintResults( fFitMethod, fBestPars, estimator );
408
409 delete fFitter; fFitter = 0;
410 if (fConvergerFitter!=0 && fConvergerFitter!=(IFitterTarget*)this) {
411 delete fConvergerFitter;
412 fConvergerFitter = 0;
413 }
414 ExitFromTraining();
415}
416
417////////////////////////////////////////////////////////////////////////////////
418/// display fit parameters
419/// check maximum length of variable name
420
421void TMVA::MethodFDA::PrintResults( const TString& fitter, std::vector<Double_t>& pars, const Double_t estimator ) const
422{
423 Log() << kINFO;
424 Log() << kHEADER << "Results for parameter fit using \"" << fitter << "\" fitter:" << Endl;
425 std::vector<TString> parNames;
426 for (UInt_t ipar=0; ipar<pars.size(); ipar++) parNames.push_back( Form("Par(%i)",ipar ) );
427 gTools().FormattedOutput( pars, parNames, "Parameter" , "Fit result", Log(), "%g" );
428 Log() << "Discriminator expression: \"" << fFormulaStringP << "\"" << Endl;
429 Log() << "Value of estimator at minimum: " << estimator << Endl;
430}
431
432////////////////////////////////////////////////////////////////////////////////
433/// compute estimator for given parameter set (to be minimised)
434
435Double_t TMVA::MethodFDA::EstimatorFunction( std::vector<Double_t>& pars )
436{
437 const Double_t sumOfWeights[] = { fSumOfWeightsBkg, fSumOfWeightsSig, fSumOfWeights };
438 Double_t estimator[] = { 0, 0, 0 };
439
440 Double_t result, deviation;
441 Double_t desired = 0.0;
442
443 // calculate the deviation from the desired value
444 if( DoRegression() ){
445 for (UInt_t ievt=0; ievt<GetNEvents(); ievt++) {
446 // read the training event
447 const TMVA::Event* ev = GetEvent(ievt);
448
449 for( Int_t dim = 0; dim < fOutputDimensions; ++dim ){
450 desired = ev->GetTarget( dim );
451 result = InterpretFormula( ev, pars.begin(), pars.end() );
452 deviation = TMath::Power(result - desired, 2);
453 estimator[2] += deviation * ev->GetWeight();
454 }
455 }
456 estimator[2] /= sumOfWeights[2];
457 // return value is sum over normalised signal and background contributions
458 return estimator[2];
459
460 }else if( DoMulticlass() ){
461 for (UInt_t ievt=0; ievt<GetNEvents(); ievt++) {
462 // read the training event
463 const TMVA::Event* ev = GetEvent(ievt);
464
465 CalculateMulticlassValues( ev, pars, *fMulticlassReturnVal );
466
468 for( Int_t dim = 0; dim < fOutputDimensions; ++dim ){
469 Double_t y = fMulticlassReturnVal->at(dim);
470 Double_t t = (ev->GetClass() == static_cast<UInt_t>(dim) ? 1.0 : 0.0 );
471 crossEntropy += t*log(y);
472 }
473 estimator[2] += ev->GetWeight()*crossEntropy;
474 }
475 estimator[2] /= sumOfWeights[2];
476 // return value is sum over normalised signal and background contributions
477 return estimator[2];
478
479 }else{
480 for (UInt_t ievt=0; ievt<GetNEvents(); ievt++) {
481 // read the training event
482 const TMVA::Event* ev = GetEvent(ievt);
483
484 desired = (DataInfo().IsSignal(ev) ? 1.0 : 0.0);
485 result = InterpretFormula( ev, pars.begin(), pars.end() );
486 deviation = TMath::Power(result - desired, 2);
487 estimator[Int_t(desired)] += deviation * ev->GetWeight();
488 }
489 estimator[0] /= sumOfWeights[0];
490 estimator[1] /= sumOfWeights[1];
491 // return value is sum over normalised signal and background contributions
492 return estimator[0] + estimator[1];
493 }
494}
495
496////////////////////////////////////////////////////////////////////////////////
497/// formula interpretation
498
499Double_t TMVA::MethodFDA::InterpretFormula( const Event* event, std::vector<Double_t>::iterator parBegin, std::vector<Double_t>::iterator parEnd )
500{
501 Int_t ipar = 0;
502 // std::cout << "pars ";
503 for( std::vector<Double_t>::iterator it = parBegin; it != parEnd; ++it ){
504 // std::cout << " i" << ipar << " val" << (*it);
505 fFormula->SetParameter( ipar, (*it) );
506 ++ipar;
507 }
508 for (UInt_t ivar=0; ivar<GetNvar(); ivar++) fFormula->SetParameter( ivar+ipar, event->GetValue(ivar) );
509
510 Double_t result = fFormula->Eval( 0 );
511 // std::cout << " result " << result << std::endl;
512 return result;
513}
514
515////////////////////////////////////////////////////////////////////////////////
516/// returns MVA value for given event
517
519{
520 const Event* ev = GetEvent();
521
522 // cannot determine error
523 NoErrorCalc(err, errUpper);
524
525 return InterpretFormula( ev, fBestPars.begin(), fBestPars.end() );
526}
527
528////////////////////////////////////////////////////////////////////////////////
529
530const std::vector<Float_t>& TMVA::MethodFDA::GetRegressionValues()
531{
532 if (fRegressionReturnVal == NULL) fRegressionReturnVal = new std::vector<Float_t>();
533 fRegressionReturnVal->clear();
534
535 const Event* ev = GetEvent();
536
537 Event* evT = new Event(*ev);
538
539 for( Int_t dim = 0; dim < fOutputDimensions; ++dim ){
540 Int_t offset = dim*fNPars;
541 evT->SetTarget(dim,InterpretFormula( ev, fBestPars.begin()+offset, fBestPars.begin()+offset+fNPars ) );
542 }
543 const Event* evT2 = GetTransformationHandler().InverseTransform( evT );
544 fRegressionReturnVal->push_back(evT2->GetTarget(0));
545
546 delete evT;
547
548 return (*fRegressionReturnVal);
549}
550
551////////////////////////////////////////////////////////////////////////////////
552
553const std::vector<Float_t>& TMVA::MethodFDA::GetMulticlassValues()
554{
555 if (fMulticlassReturnVal == NULL) fMulticlassReturnVal = new std::vector<Float_t>();
556 fMulticlassReturnVal->clear();
557 std::vector<Float_t> temp;
558
559 // returns MVA value for given event
560 const TMVA::Event* evt = GetEvent();
561
562 CalculateMulticlassValues( evt, fBestPars, temp );
563
564 UInt_t nClasses = DataInfo().GetNClasses();
565 for(UInt_t iClass=0; iClass<nClasses; iClass++){
566 Double_t norm = 0.0;
567 for(UInt_t j=0;j<nClasses;j++){
568 if(iClass!=j)
569 norm+=exp(temp[j]-temp[iClass]);
570 }
571 (*fMulticlassReturnVal).push_back(1.0/(1.0+norm));
572 }
573
574 return (*fMulticlassReturnVal);
575}
576
577
578////////////////////////////////////////////////////////////////////////////////
579/// calculate the values for multiclass
580
581void TMVA::MethodFDA::CalculateMulticlassValues( const TMVA::Event*& evt, std::vector<Double_t>& parameters, std::vector<Float_t>& values)
582{
583 values.clear();
584
585 // std::copy( parameters.begin(), parameters.end(), std::ostream_iterator<double>( std::cout, " " ) );
586 // std::cout << std::endl;
587
588 // char inp;
589 // std::cin >> inp;
590
591 Double_t sum=0;
592 for( Int_t dim = 0; dim < fOutputDimensions; ++dim ){ // check for all other dimensions (=classes)
593 Int_t offset = dim*fNPars;
594 Double_t value = InterpretFormula( evt, parameters.begin()+offset, parameters.begin()+offset+fNPars );
595 // std::cout << "dim : " << dim << " value " << value << " offset " << offset << std::endl;
596 values.push_back( value );
597 sum += value;
598 }
599}
600
601////////////////////////////////////////////////////////////////////////////////
602/// read back the training results from a file (stream)
603
605{
606 // retrieve best function parameters
607 // coverity[tainted_data_argument]
608 istr >> fNPars;
609
610 fBestPars.clear();
611 fBestPars.resize( fNPars );
612 for (UInt_t ipar=0; ipar<fNPars; ipar++) istr >> fBestPars[ipar];
613}
614
615////////////////////////////////////////////////////////////////////////////////
616/// create XML description for LD classification and regression
617/// (for arbitrary number of output classes/targets)
618
619void TMVA::MethodFDA::AddWeightsXMLTo( void* parent ) const
620{
621 void* wght = gTools().AddChild(parent, "Weights");
622 gTools().AddAttr( wght, "NPars", fNPars );
623 gTools().AddAttr( wght, "NDim", fOutputDimensions );
624 for (UInt_t ipar=0; ipar<fNPars*fOutputDimensions; ipar++) {
625 void* coeffxml = gTools().AddChild( wght, "Parameter" );
626 gTools().AddAttr( coeffxml, "Index", ipar );
627 gTools().AddAttr( coeffxml, "Value", fBestPars[ipar] );
628 }
629
630 // write formula
631 gTools().AddAttr( wght, "Formula", fFormulaStringP );
632}
633
634////////////////////////////////////////////////////////////////////////////////
635/// read coefficients from xml weight file
636
638{
639 gTools().ReadAttr( wghtnode, "NPars", fNPars );
640
641 if(gTools().HasAttr( wghtnode, "NDim")) {
642 gTools().ReadAttr( wghtnode, "NDim" , fOutputDimensions );
643 } else {
644 // older weight files don't have this attribute
645 fOutputDimensions = 1;
646 }
647
648 fBestPars.clear();
649 fBestPars.resize( fNPars*fOutputDimensions );
650
651 void* ch = gTools().GetChild(wghtnode);
652 Double_t par;
653 UInt_t ipar;
654 while (ch) {
655 gTools().ReadAttr( ch, "Index", ipar );
656 gTools().ReadAttr( ch, "Value", par );
657
658 // sanity check
659 if (ipar >= fNPars*fOutputDimensions) Log() << kFATAL << "<ReadWeightsFromXML> index out of range: "
660 << ipar << " >= " << fNPars << Endl;
661 fBestPars[ipar] = par;
662
663 ch = gTools().GetNextChild(ch);
664 }
665
666 // read formula
667 gTools().ReadAttr( wghtnode, "Formula", fFormulaStringP );
668
669 // create the TFormula
670 CreateFormula();
671}
672
673////////////////////////////////////////////////////////////////////////////////
674/// write FDA-specific classifier response
675
676void TMVA::MethodFDA::MakeClassSpecific( std::ostream& fout, const TString& className ) const
677{
678 fout << " double fParameter[" << fNPars << "];" << std::endl;
679 fout << "};" << std::endl;
680 fout << "" << std::endl;
681 fout << "inline void " << className << "::Initialize() " << std::endl;
682 fout << "{" << std::endl;
683 for(UInt_t ipar=0; ipar<fNPars; ipar++) {
684 fout << " fParameter[" << ipar << "] = " << fBestPars[ipar] << ";" << std::endl;
685 }
686 fout << "}" << std::endl;
687 fout << std::endl;
688 fout << "inline double " << className << "::GetMvaValue__( const std::vector<double>& inputValues ) const" << std::endl;
689 fout << "{" << std::endl;
690 fout << " // interpret the formula" << std::endl;
691
692 // replace parameters
693 TString str = fFormulaStringT;
694 for (UInt_t ipar=0; ipar<fNPars; ipar++) {
695 str.ReplaceAll( Form("[%i]", ipar), Form("fParameter[%i]", ipar) );
696 }
697
698 // replace input variables
699 for (UInt_t ivar=0; ivar<GetNvar(); ivar++) {
700 str.ReplaceAll( Form("[%i]", ivar+fNPars), Form("inputValues[%i]", ivar) );
701 }
702
703 fout << " double retval = " << str << ";" << std::endl;
704 fout << std::endl;
705 fout << " return retval; " << std::endl;
706 fout << "}" << std::endl;
707 fout << std::endl;
708 fout << "// Clean up" << std::endl;
709 fout << "inline void " << className << "::Clear() " << std::endl;
710 fout << "{" << std::endl;
711 fout << " // nothing to clear" << std::endl;
712 fout << "}" << std::endl;
713}
714
715////////////////////////////////////////////////////////////////////////////////
716/// get help message text
717///
718/// typical length of text line:
719/// "|--------------------------------------------------------------|"
720
722{
723 Log() << Endl;
724 Log() << gTools().Color("bold") << "--- Short description:" << gTools().Color("reset") << Endl;
725 Log() << Endl;
726 Log() << "The function discriminant analysis (FDA) is a classifier suitable " << Endl;
727 Log() << "to solve linear or simple nonlinear discrimination problems." << Endl;
728 Log() << Endl;
729 Log() << "The user provides the desired function with adjustable parameters" << Endl;
730 Log() << "via the configuration option string, and FDA fits the parameters to" << Endl;
731 Log() << "it, requiring the signal (background) function value to be as close" << Endl;
732 Log() << "as possible to 1 (0). Its advantage over the more involved and" << Endl;
733 Log() << "automatic nonlinear discriminators is the simplicity and transparency " << Endl;
734 Log() << "of the discrimination expression. A shortcoming is that FDA will" << Endl;
735 Log() << "underperform for involved problems with complicated, phase space" << Endl;
736 Log() << "dependent nonlinear correlations." << Endl;
737 Log() << Endl;
738 Log() << "Please consult the Users Guide for the format of the formula string" << Endl;
739 Log() << "and the allowed parameter ranges:" << Endl;
740 if (gConfig().WriteOptionsReference()) {
741 Log() << "<a href=\"http://tmva.sourceforge.net/docu/TMVAUsersGuide.pdf\">"
742 << "http://tmva.sourceforge.net/docu/TMVAUsersGuide.pdf</a>" << Endl;
743 }
744 else Log() << "http://tmva.sourceforge.net/docu/TMVAUsersGuide.pdf" << Endl;
745 Log() << Endl;
746 Log() << gTools().Color("bold") << "--- Performance optimisation:" << gTools().Color("reset") << Endl;
747 Log() << Endl;
748 Log() << "The FDA performance depends on the complexity and fidelity of the" << Endl;
749 Log() << "user-defined discriminator function. As a general rule, it should" << Endl;
750 Log() << "be able to reproduce the discrimination power of any linear" << Endl;
751 Log() << "discriminant analysis. To reach into the nonlinear domain, it is" << Endl;
752 Log() << "useful to inspect the correlation profiles of the input variables," << Endl;
753 Log() << "and add quadratic and higher polynomial terms between variables as" << Endl;
754 Log() << "necessary. Comparison with more involved nonlinear classifiers can" << Endl;
755 Log() << "be used as a guide." << Endl;
756 Log() << Endl;
757 Log() << gTools().Color("bold") << "--- Performance tuning via configuration options:" << gTools().Color("reset") << Endl;
758 Log() << Endl;
759 Log() << "Depending on the function used, the choice of \"FitMethod\" is" << Endl;
760 Log() << "crucial for getting valuable solutions with FDA. As a guideline it" << Endl;
761 Log() << "is recommended to start with \"FitMethod=MINUIT\". When more complex" << Endl;
762 Log() << "functions are used where MINUIT does not converge to reasonable" << Endl;
763 Log() << "results, the user should switch to non-gradient FitMethods such" << Endl;
764 Log() << "as GeneticAlgorithm (GA) or Monte Carlo (MC). It might prove to be" << Endl;
765 Log() << "useful to combine GA (or MC) with MINUIT by setting the option" << Endl;
766 Log() << "\"Converger=MINUIT\". GA (MC) will then set the starting parameters" << Endl;
767 Log() << "for MINUIT such that the basic quality of GA (MC) of finding global" << Endl;
768 Log() << "minima is combined with the efficacy of MINUIT of finding local" << Endl;
769 Log() << "minima." << Endl;
770}
#define REGISTER_METHOD(CLASS)
for example
int Int_t
Definition: RtypesCore.h:43
const Bool_t kFALSE
Definition: RtypesCore.h:90
double Double_t
Definition: RtypesCore.h:57
float Float_t
Definition: RtypesCore.h:55
const Bool_t kTRUE
Definition: RtypesCore.h:89
#define ClassImp(name)
Definition: Rtypes.h:361
int type
Definition: TGX11.cxx:120
double exp(double)
double log(double)
char * Form(const char *fmt,...)
virtual Int_t GetSize() const
Return the capacity of the collection, i.e.
Definition: TCollection.h:182
The Formula class.
Definition: TFormula.h:84
A doubly linked list.
Definition: TList.h:44
virtual TObject * At(Int_t idx) const
Returns the object at position idx. Returns 0 if idx is out of range.
Definition: TList.cxx:356
Class that contains all the data information.
Definition: DataSetInfo.h:60
Float_t GetValue(UInt_t ivar) const
return value of i'th variable
Definition: Event.cxx:236
void SetTarget(UInt_t itgt, Float_t value)
set the target value (dimension itgt) to value
Definition: Event.cxx:359
Double_t GetWeight() const
return the event weight - depending on whether the flag IgnoreNegWeightsInTraining is or not.
Definition: Event.cxx:381
UInt_t GetClass() const
Definition: Event.h:86
Float_t GetTarget(UInt_t itgt) const
Definition: Event.h:102
Fitter using a Genetic Algorithm.
Definition: GeneticFitter.h:43
Interface for a fitter 'target'.
Definition: IFitterTarget.h:44
The TMVA::Interval Class.
Definition: Interval.h:61
Fitter using Monte Carlo sampling of parameters.
Definition: MCFitter.h:43
Virtual base Class for all MVA method.
Definition: MethodBase.h:111
Function discriminant analysis (FDA).
Definition: MethodFDA.h:60
void Train(void)
FDA training.
Definition: MethodFDA.cxx:363
void AddWeightsXMLTo(void *parent) const
create XML description for FDA classification and regression (for arbitrary number of output classes/t...
Definition: MethodFDA.cxx:619
Double_t EstimatorFunction(std::vector< Double_t > &)
compute estimator for given parameter set (to be minimised)
Definition: MethodFDA.cxx:435
virtual ~MethodFDA(void)
destructor
Definition: MethodFDA.cxx:326
Double_t InterpretFormula(const Event *, std::vector< Double_t >::iterator begin, std::vector< Double_t >::iterator end)
formula interpretation
Definition: MethodFDA.cxx:499
Bool_t HasAnalysisType(Types::EAnalysisType type, UInt_t numberClasses, UInt_t numberTargets)
FDA can handle classification with 2 classes, multiclass classification and regression.
Definition: MethodFDA.cxx:334
void ReadWeightsFromXML(void *wghtnode)
read coefficients from xml weight file
Definition: MethodFDA.cxx:637
void CalculateMulticlassValues(const TMVA::Event *&evt, std::vector< Double_t > &parameters, std::vector< Float_t > &values)
calculate the values for multiclass
Definition: MethodFDA.cxx:581
void ReadWeightsFromStream(std::istream &i)
read back the training results from a file (stream)
Definition: MethodFDA.cxx:604
virtual const std::vector< Float_t > & GetMulticlassValues()
Definition: MethodFDA.cxx:553
MethodFDA(const TString &jobName, const TString &methodTitle, DataSetInfo &theData, const TString &theOption="")
standard constructor
Definition: MethodFDA.cxx:86
void Init(void)
default initialisation
Definition: MethodFDA.cxx:124
void ClearAll()
delete and clear all class members
Definition: MethodFDA.cxx:346
void PrintResults(const TString &, std::vector< Double_t > &, const Double_t) const
display fit parameters check maximum length of variable name
Definition: MethodFDA.cxx:421
void MakeClassSpecific(std::ostream &, const TString &) const
write FDA-specific classifier response
Definition: MethodFDA.cxx:676
Double_t GetMvaValue(Double_t *err=0, Double_t *errUpper=0)
returns MVA value for given event
Definition: MethodFDA.cxx:518
virtual const std::vector< Float_t > & GetRegressionValues()
Definition: MethodFDA.cxx:530
void ProcessOptions()
the option string is decoded, for available options see "DeclareOptions"
Definition: MethodFDA.cxx:240
void CreateFormula()
translate formula string into TFormula (the parameter ranges are parsed in ProcessOptions)
Definition: MethodFDA.cxx:183
void DeclareOptions()
define the options (their key words) that can be set in the option string
Definition: MethodFDA.cxx:163
void GetHelpMessage() const
get help message text
Definition: MethodFDA.cxx:721
Fitter using MINUIT
Definition: MinuitFitter.h:47
Fitter using a Simulated Annealing Algorithm.
void FormattedOutput(const std::vector< Double_t > &, const std::vector< TString > &, const TString titleVars, const TString titleValues, MsgLogger &logger, TString format="%+1.3f")
formatted output of simple table
Definition: Tools.cxx:898
TList * ParseFormatLine(TString theString, const char *sep=":")
Parse the string and cut into labels separated by ":".
Definition: Tools.cxx:412
void * GetNextChild(void *prevchild, const char *childname=0)
XML helpers.
Definition: Tools.cxx:1173
void * AddChild(void *parent, const char *childname, const char *content=0, bool isRootNode=false)
add child node
Definition: Tools.cxx:1135
const TString & Color(const TString &)
human readable color strings
Definition: Tools.cxx:839
void * GetChild(void *parent, const char *childname=0)
get child node
Definition: Tools.cxx:1161
void ReadAttr(void *node, const char *, T &value)
read attribute from xml
Definition: Tools.h:335
void AddAttr(void *node, const char *, const T &value, Int_t precision=16)
add attribute to xml
Definition: Tools.h:353
Singleton class for Global types used by TMVA.
Definition: Types.h:73
EAnalysisType
Definition: Types.h:127
@ kMulticlass
Definition: Types.h:130
@ kClassification
Definition: Types.h:128
@ kRegression
Definition: Types.h:129
Collectable string class.
Definition: TObjString.h:28
Basic string class.
Definition: TString.h:131
Ssiz_t Length() const
Definition: TString.h:405
Ssiz_t First(char c) const
Find first occurrence of a character c.
Definition: TString.cxx:499
const char * Data() const
Definition: TString.h:364
TString & ReplaceAll(const TString &s1, const TString &s2)
Definition: TString.h:687
Double_t y[n]
Definition: legend1.C:17
double crossEntropy(ItProbability itProbabilityBegin, ItProbability itProbabilityEnd, ItTruth itTruthBegin, ItTruth itTruthEnd, ItDelta itDelta, ItDelta itDeltaEnd, ItInvActFnc itInvActFnc, double patternWeight)
cross entropy error function
Definition: NeuralNet.icc:412
Config & gConfig()
Tools & gTools()
MsgLogger & Endl(MsgLogger &ml)
Definition: MsgLogger.h:158
Double_t Log(Double_t x)
Definition: TMath.h:750
LongDouble_t Power(LongDouble_t x, LongDouble_t y)
Definition: TMath.h:725
Short_t Abs(Short_t d)
Definition: TMathBase.h:120
static long int sum(long int i)
Definition: Factory.cxx:2275