* *
**********************************************************************************/
// Begin_Html
/*
Fisher and Mahalanobis Discriminants (Linear Discriminant Analysis)
<p>
In the method of Fisher discriminants event selection is performed
in a transformed variable space with zero linear correlations, by
distinguishing the mean values of the signal and background
distributions.<br></p>
<p>
The linear discriminant analysis determines an axis in the (correlated)
hyperspace of the input variables
such that, when projecting the output classes (signal and background)
upon this axis, they are pushed as far as possible away from each other,
while events of the same class are confined to a close vicinity.
The linearity property of this method is reflected in the metric with
which "far apart" and "close vicinity" are determined: the covariance
matrix of the discriminant variable space.
</p>
<p>
The classification of the events in signal and background classes
relies on the following characteristics (only): overall sample means,
<i><my:o>x</my:o><sub>i</sub></i>, for each input variable, <i>i</i>,
class-specific sample means, <i><my:o>x</my:o><sub>S(B),i</sub></i>,
and total covariance matrix <i>T<sub>ij</sub></i>. The total covariance matrix
can be decomposed into the sum of a <i>within-class</i> (<i>W<sub>ij</sub></i>)
and a <i>between-class</i> (<i>B<sub>ij</sub></i>) matrix. They describe
the dispersion of events relative to the means of their own class (within-class
matrix), and relative to the overall sample means (between-class matrix).
The Fisher coefficients, <i>F<sub>i</sub></i>, are then given by <br>
<center>
<img vspace=6 src="gif/tmva_fisherC.gif" align="bottom" >
</center>
where TMVA expects <i>N<sub>S</sub>=N<sub>B</sub></i> (a warning is printed
otherwise), in which case the factor in front of the sum simplifies to ½.
The Fisher discriminant then reads<br>
<center>
<img vspace=6 src="gif/tmva_fisherD.gif" align="bottom" >
</center>
The offset <i>F</i><sub>0</sub> centers the average of the signal and background
sample means of <i>x</i><sub>Fi</sub> at zero. Instead of using the within-class
matrix, the Mahalanobis variant determines the Fisher coefficients from the total
covariance matrix, as follows:<br>
<center>
<img vspace=6 src="gif/tmva_mahaC.gif" align="bottom" >
</center>
with resulting <i>x</i><sub>Ma</sub> that are very similar to the
<i>x</i><sub>Fi</sub>. <br></p>
TMVA provides two outputs for the ranking of the input variables:<br><p></p>
<ul>
<li> <u>Fisher test:</u> the Fisher analysis aims at simultaneously maximising
the between-class separation, while minimising the within-class dispersion.
A useful measure of the discrimination power of a variable is hence given
by the diagonal quantity <i>B<sub>ii</sub>/W<sub>ii</sub></i>. TMVA computes
<i>B<sub>ii</sub>/T<sub>ii</sub></i> with <i>T</i> = <i>W</i> + <i>B</i>,
which yields the same ranking.
</li>
<li> <u>Discrimination power:</u> the value of the Fisher coefficient is a
measure of the discriminating power of a variable. The discrimination power
of a set of input variables can therefore be measured by the scalar
<center>
<img vspace=6 src="gif/tmva_discpower.gif" align="bottom" >
</center>
</li>
</ul>
The corresponding numbers are printed on standard output.
*/
// End_Html
#include "TMVA/MethodFisher.h"
#include "TMVA/Tools.h"
#include "TMatrix.h"
#include "Riostream.h"
#include <algorithm>
#include <utility>
// ROOT class-implementation macro for TMVA::MethodFisher
ClassImp(TMVA::MethodFisher)
TMVA::MethodFisher::MethodFisher( TString jobName, vector<TString>* theVariables,
                                  TTree* theTree, TString theOption, TDirectory* theTargetDir )
   : TMVA::MethodBase( jobName, theVariables, theTree, theOption, theTargetDir )
{
   // standard constructor (training mode)
   //  - parses the option string: after lower-casing, a string containing "fi"
   //    selects the Fisher variant, one containing "ma" the Mahalanobis variant;
   //    anything else is a fatal error
   //  - counts signal (type == 1) and background (any other type) training events
   //  - fills the internal data matrices via Init()
   InitFisher();

   // option string too short (Sizeof() < 2): fall back to the Fisher variant
   if (fOptions.Sizeof()<2) {
      fOptions = "Fisher ";
      cout << "--- " << GetName() << ": using default options= "<< fOptions <<endl;
   }

   fOptions.ToLower();
   if      (fOptions.Contains( "fi" )) fFisherMethod = kFisher;
   else if (fOptions.Contains( "ma" )) fFisherMethod = kMahalanobis;
   else {
      cout << "--- " << GetName() << ": Error: unrecognized option string: "
           << GetOptions() << " | " << fOptions
           << " --> exit(1)" << endl;
      exit(1);
   }

   if (0 != fTrainingTree) {
      // the tree is expected to hold one branch per input variable plus one extra
      // branch (presumably the "type" branch read below)
      if (fTrainingTree->GetListOfBranches()->GetEntries() - 1 != fNvar) {
         cout << "--- " << GetName() << ": Error: mismatch in number of variables: "
              << fTrainingTree->GetListOfBranches()->GetEntries() << " " << fNvar
              << " --> exit(1)" << endl;
         exit(1);
      }

      // count signal and background events
      fNevt = fTrainingTree->GetEntries();
      fNsig = 0;
      fNbgd = 0;
      for (Int_t ievt = 0; ievt < fNevt; ievt++) {
         if ((Int_t)TMVA::Tools::GetValue( fTrainingTree, ievt, "type" ) == 1)
            ++fNsig;
         else
            ++fNbgd;
      }

      if (fNsig + fNbgd != fNevt) {
         cout << "--- " << GetName() << ": Error: mismatch in number of events"
              << " --> exit(1)" << endl;
         exit(1);
      }

      // both classes are required: the class means are divided by fNsig and fNbgd,
      // so an empty class would silently yield NaN means and NaN Fisher weights
      if (fNsig == 0 || fNbgd == 0) {
         cout << "--- " << GetName() << ": Error: training requires signal and background events"
              << " (signal, background): (" << fNsig << ", " << fNbgd << ")"
              << " --> exit(1)" << endl;
         exit(1);
      }

      if (Verbose())
         cout << "--- " << GetName() << " <verbose>: num of events for training (signal, background): "
              << " (" << fNsig << ", " << fNbgd << ")" << endl;

      // the coefficient normalisation is only the nominal 1/2 for equal class sizes
      if (fNsig != fNbgd) {
         cout << "--- " << GetName() << ":\t--------------------------------------------------"
              << endl;
         cout << "--- " << GetName() << ":\tWarning: different number of signal and background\n"
              << "--- " << GetName() << " \tevents: Fisher training will not be optimal :-("
              << endl;
         cout << "--- " << GetName() << ":\t--------------------------------------------------"
              << endl;
      }

      // fill the data matrices
      Init();
   }
   else {
      fNevt = 0;
      fNsig = 0;
      fNbgd = 0;
   }
}
TMVA::MethodFisher::MethodFisher( vector<TString> *theVariables, TString theWeightFile, TDirectory* theTargetDir )
   : TMVA::MethodBase( theVariables, theWeightFile, theTargetDir )
{
   // constructor taking a weight file instead of a training tree:
   // only the common Fisher initialisation is performed here
   InitFisher();
}
void TMVA::MethodFisher::InitFisher( void )
{
fMethodName = "Fisher";
fMethod = TMVA::Types::Fisher;
fTestvar = fTestvarPrefix+GetMethodName();
fMeanMatx = 0;
fBetw = 0;
fWith = 0;
fCov = 0;
fNevt = 0;
fNsig = 0;
fNbgd = 0;
fF0 = 0;
fFisherCoeff = new vector<Double_t>( fNvar );
}
TMVA::MethodFisher::~MethodFisher( void )
{
   // release everything allocated in InitFisher() and Init();
   // deleting a zero pointer is a no-op
   delete fSig;
   delete fBgd;
   delete fMeanMatx;   // allocated in Init() and previously leaked
   delete fBetw;
   delete fWith;
   delete fCov;
   delete fDiscrimPow;
   delete fFisherCoeff;
}
void TMVA::MethodFisher::Train( void )
{
   // training sequence on the data matrices filled by Init();
   // aborts the program if the sanity check fails
   if (!CheckSanity()) {
      cout << "--- " << GetName() << ": Error: sanity check failed" << endl;
      exit(1);
   }

   // class means, then within-class, between-class and full covariance matrices
   GetMean();
   GetCov_WithinClass();
   GetCov_BetweenClass();
   GetCov_Full();

   // derive coefficients and per-variable ranking, report and persist them
   GetFisherCoeff();
   GetDiscrimPower();
   PrintCoefficients();
   WriteWeightsToFile();
}
Double_t TMVA::MethodFisher::GetMvaValue( TMVA::Event *e )
{
   // Fisher discriminant of event e: offset fF0 plus, for every input variable,
   // its coefficient times the value normalised with the stored [xmin, xmax] range
   Double_t mva = fF0;
   Int_t ivar = 0;
   while (ivar < fNvar) {
      // __N__ is a macro defined elsewhere; the product is kept as one expression
      mva += (*fFisherCoeff)[ivar]*__N__( e->GetData(ivar) , GetXminNorm( ivar ), GetXmaxNorm( ivar ) );
      ++ivar;
   }
   return mva;
}
void TMVA::MethodFisher::Init( void )
{
if (0 == fTrainingTree) {
cout << "--- " << GetName() << ": Error in ::Init(): fTrainingTree is zero pointer"
<< " --> exit(1)" << endl;
exit(1);
}
fSig = new TMatrix( fNvar, fNsig );
fBgd = new TMatrix( fNvar, fNbgd );
fMeanMatx = new TMatrixD( fNvar, 3 );
fBetw = new TMatrixD( fNvar, fNvar );
fWith = new TMatrixD( fNvar, fNvar );
fCov = new TMatrixD( fNvar, fNvar );
fDiscrimPow = new vector<Double_t>( fNvar );
Int_t isig = 0, ibgd = 0, ivar;
for (Int_t ievt=0; ievt<fNevt; ievt++) {
if ((Int_t)TMVA::Tools::GetValue( fTrainingTree, ievt, "type" ) == 1) {
for (ivar=0; ivar<fNvar; ivar++) {
Double_t x = TMVA::Tools::GetValue( fTrainingTree, ievt, (*fInputVars)[ivar] );
(*fSig)(ivar, isig) = __N__( x, GetXminNorm( ivar ), GetXmaxNorm( ivar ) );
}
++isig;
}
else {
for (ivar=0; ivar<fNvar; ivar++) {
Double_t x = TMVA::Tools::GetValue( fTrainingTree, ievt, (*fInputVars)[ivar] );
(*fBgd)(ivar, ibgd) = __N__( x, GetXminNorm( ivar ), GetXmaxNorm( ivar ) );
}
++ibgd;
}
}
}
void TMVA::MethodFisher::GetMean( void )
{
for(Int_t ivar=0; ivar<fNvar; ivar++) {
Double_t sum = 0;
for (Int_t ievt=0; ievt<fNsig; ievt++) sum += (*fSig)(ivar, ievt);
(*fMeanMatx)( ivar, 2 ) = sum;
(*fMeanMatx)( ivar, 0 ) = sum/fNsig;
sum = 0;
for (Int_t ievt=0; ievt<fNbgd; ievt++) sum += (*fBgd)(ivar, ievt);
(*fMeanMatx)( ivar, 2 ) += sum;
(*fMeanMatx)( ivar, 1 ) = sum/fNbgd;
(*fMeanMatx)( ivar, 2 ) /= (fNsig + fNbgd);
}
}
void TMVA::MethodFisher::GetCov_WithinClass( void )
{
Double_t prodSig, prodBgd;
Int_t ievt;
for (Int_t x=0; x<fNvar; x++) {
for (Int_t y=0; y<fNvar; y++) {
Double_t sumSig = 0;
Double_t sumBgd = 0;
for (ievt=0; ievt<fNsig; ievt++) {
prodSig = ( ((*fSig)(x, ievt) - (*fMeanMatx)(x, 0))*
((*fSig)(y, ievt) - (*fMeanMatx)(y, 0)) );
sumSig += prodSig;
}
for (ievt=0; ievt<fNbgd; ievt++) {
prodBgd = ( ((*fBgd)(x, ievt) - (*fMeanMatx)(x, 1))*
((*fBgd)(y, ievt) - (*fMeanMatx)(y, 1)) );
sumBgd += prodBgd;
}
(*fWith)(x, y) = (sumSig + sumBgd)/fNevt;
}
}
}
void TMVA::MethodFisher::GetCov_BetweenClass( void )
{
Double_t prodSig, prodBgd;
for (Int_t x=0; x<fNvar; x++) {
for (Int_t y=0; y<fNvar; y++) {
prodSig = ( ((*fMeanMatx)(x, 0) - (*fMeanMatx)(x, 2))*
((*fMeanMatx)(y, 0) - (*fMeanMatx)(y, 2)) );
prodBgd = ( ((*fMeanMatx)(x, 1) - (*fMeanMatx)(x, 2))*
((*fMeanMatx)(y, 1) - (*fMeanMatx)(y, 2)) );
(*fBetw)(x, y) = (fNsig*prodSig + fNbgd*prodBgd)/Double_t(fNevt);
}
}
}
void TMVA::MethodFisher::GetCov_Full( void )
{
for (Int_t x=0; x<fNvar; x++)
for (Int_t y=0; y<fNvar; y++)
(*fCov)(x, y) = (*fWith)(x, y) + (*fBetw)(x, y);
}
void TMVA::MethodFisher::GetFisherCoeff( void )
{
   // Fisher coefficients
   //   F_i = xfact * sum_j (M^-1)_ij * (mean_S,j - mean_B,j)
   // where M is the within-class matrix (Fisher variant) or the full covariance
   // matrix (Mahalanobis variant) and xfact = sqrt(N_S*N_B)/(N_S+N_B), which is
   // 1/2 for N_S == N_B; followed by the offset
   //   F_0 = -1/2 * sum_i F_i * (mean_S,i + mean_B,i)
   TMatrixD* theMat = 0;
   switch (fFisherMethod) {
   case kFisher:
      theMat = fWith;
      break;
   case kMahalanobis:
      theMat = fCov;
      break;
   default:
      cout << "--- " << GetName() << ": ERROR: undefined method ==> exit(1)" << endl;
      exit(1);
   }

   TMatrixD invCov( *theMat );
   invCov.Invert();

   // convert to floating point BEFORE multiplying: the Int_t product fNsig*fNbgd
   // overflows once both classes hold more than ~46k events
   Double_t xfact = sqrt( Double_t(fNsig)*Double_t(fNbgd) )/Double_t(fNsig + fNbgd);

   // difference of the class means, computed once per variable
   vector<Double_t> diffMeans( fNvar );
   Int_t ivar, jvar;
   for (ivar=0; ivar<fNvar; ivar++)
      diffMeans[ivar] = (*fMeanMatx)(ivar, 0) - (*fMeanMatx)(ivar, 1);

   for (ivar=0; ivar<fNvar; ivar++) {
      (*fFisherCoeff)[ivar] = 0;
      for (jvar=0; jvar<fNvar; jvar++)
         (*fFisherCoeff)[ivar] += invCov(ivar, jvar)*diffMeans[jvar];
      (*fFisherCoeff)[ivar] *= xfact;
   }

   // offset: centres the mid-point of the signal and background means at zero
   fF0 = 0.0;
   for (ivar=0; ivar<fNvar; ivar++)
      fF0 += (*fFisherCoeff)[ivar]*((*fMeanMatx)(ivar, 0) + (*fMeanMatx)(ivar, 1));
   fF0 /= -2.0;
}
void TMVA::MethodFisher::GetDiscrimPower( void )
{
for (Int_t ivar=0; ivar<fNvar; ivar++)
if ((*fCov)(ivar, ivar) != 0)
(*fDiscrimPow)[ivar] = (*fBetw)(ivar, ivar)/(*fCov)(ivar, ivar);
else
(*fDiscrimPow)[ivar] = 0;
}
void TMVA::MethodFisher::PrintCoefficients( void )
{
Int_t maxL = 0;
vector<Double_t> dp( fNvar );
for (Int_t ivar=0; ivar<fNvar; ivar++) {
if ((*fInputVars)[ivar].Length() > maxL) maxL = (*fInputVars)[ivar].Length();
dp[ivar] = (*fDiscrimPow)[ivar];
}
sort ( dp.begin(), dp.end() );
reverse( dp.begin(), dp.end() );
cout << "--- " << endl;
cout << "--- " << GetName() << ": ranked output (top variable is best ranked)" << endl;
cout << "----------------------------------------------------------------" << endl;
cout << "--- " << setiosflags(ios::left) << setw(maxL+5) << "Variable :"
<< resetiosflags(ios::right)
<< setw(12) << " Coefficient:"
<< " Discr. power:" << endl;
cout << "----------------------------------------------------------------" << endl;
for (Int_t ivar=0; ivar<fNvar; ivar++) {
for (Int_t jvar=0; jvar<fNvar; jvar++) {
if (dp[ivar] == (*fDiscrimPow)[jvar]) {
printf( "--- %-11s: %+.3f %.4f\n",
(const char*)(*fInputVars)[jvar], (*fFisherCoeff)[jvar], (*fDiscrimPow)[jvar]);
}
}
}
printf( "--- %-11s: %+.3f %i\n", "(offset)", fF0, 0 );
cout << "----------------------------------------------------------------" << endl;
cout << "--- " << endl;
}
void TMVA::MethodFisher::WriteWeightsToFile( void )
{
   // write the trained weights as plain text:
   //   method name
   //   "NVars= <n>"
   //   n lines: <variable name> <xmin> <xmax>   (normalisation range)
   //   offset F0
   //   n lines: one Fisher coefficient per variable
   TString fname = GetWeightFileName();
   cout << "--- " << GetName() << ": creating weight file: " << fname << endl;
   ofstream fout( fname );
   if (!fout.good( )) {
      cout << "--- " << GetName() << ": Error in ::WriteWeightsToFile: "
           << "unable to open output weight file: " << fname << endl;
      exit(1);
   }

   // the default stream precision (6 significant digits) truncates coefficients and
   // normalisation ranges; 17 significant digits round-trip a double exactly
   fout.precision( 17 );

   fout << this->GetMethodName() <<endl;
   fout << "NVars= " << fNvar <<endl;
   for (Int_t ivar=0; ivar<fNvar; ivar++) {
      TString var = (*fInputVars)[ivar];
      fout << var << " " << GetXminNorm( var ) << " " << GetXmaxNorm( var ) << endl;
   }
   fout << fF0 << endl;
   for (Int_t ivar=0; ivar<fNvar; ivar++) fout << (*fFisherCoeff)[ivar] << endl;
   fout.close();
}
void TMVA::MethodFisher::ReadWeightsFromFile( void )
{
TString fname = GetWeightFileName();
cout << "--- " << GetName() << ": reading weight file: " << fname << endl;
ifstream fin( fname );
if (!fin.good( )) {
cout << "--- " << GetName() << ": Error in ::ReadWeightsFromFile: "
<< "unable to open input file: " << fname << endl;
exit(1);
}
TString var, dummy;
Double_t xmin, xmax;
fin >> dummy;
this->SetMethodName(dummy);
fin >> dummy >> fNvar;
for (Int_t ivar=0; ivar<fNvar; ivar++) {
fin >> var >> xmin >> xmax;
if (var != (*fInputVars)[ivar]) {
cout << "--- " << GetName() << ": Error while reading weight file; "
<< "unknown variable: " << var << " at position: " << ivar << ". "
<< "Expected variable: " << (*fInputVars)[ivar] << " ==> abort" << endl;
exit(1);
}
this->SetXminNorm( ivar, xmin );
this->SetXmaxNorm( ivar, xmax );
}
fin >> fF0;
for (Int_t ivar=0; ivar<fNvar; ivar++) fin >> (*fFisherCoeff)[ivar];
fin.close();
}
void TMVA::MethodFisher::WriteHistosToFile( void )
{
   // nothing is written here: only the target directory path is reported
   cout << "--- " << GetName()
        << ": write " << GetName()
        << " special histos to file: " << fBaseDir->GetPath()
        << endl;
}
// ROOT page - Class index - Class Hierarchy - Top of the page
// This page has been automatically generated. If you have any comments or suggestions about the page layout send a mail to ROOT support, or contact the developers with any questions or problems regarding ROOT.