#include "TMVA/MethodFisher.h"

#include <cassert>
#include <iomanip>
#include <vector>

#include "Riostream.h"
#include "TMath.h"
#include "TMatrix.h"

#include "TMVA/ClassifierFactory.h"
#include "TMVA/Ranking.h"
#include "TMVA/Tools.h"
#include "TMVA/Types.h"
#include "TMVA/VariableTransformBase.h"
// Register the Fisher method with TMVA's classifier factory so it can be
// booked by name, and generate the ROOT dictionary entry for I/O.
REGISTER_METHOD(Fisher)
ClassImp(TMVA::MethodFisher);
// Standard constructor for training: forwards job/option information to
// MethodBase and zero-initializes all matrix/vector pointers and weight sums.
// The actual allocations happen later in Init()/InitMatrices().
TMVA::MethodFisher::MethodFisher( const TString& jobName,
const TString& methodTitle,
DataSetInfo& dsi,
const TString& theOption,
TDirectory* theTargetDir ) :
MethodBase( jobName, Types::kFisher, methodTitle, dsi, theOption, theTargetDir ),
fMeanMatx ( 0 ),
fTheMethod ( "Fisher" ),
fFisherMethod ( kFisher ),
fBetw ( 0 ),
fWith ( 0 ),
fCov ( 0 ),
fSumOfWeightsS( 0 ),
fSumOfWeightsB( 0 ),
fDiscrimPow ( 0 ),
fFisherCoeff ( 0 ),
fF0 ( 0 )
{
}
// Constructor used when instantiating the method from an existing weight file
// (application phase); pointers and sums are zeroed as in the training ctor.
TMVA::MethodFisher::MethodFisher( DataSetInfo& dsi,
const TString& theWeightFile,
TDirectory* theTargetDir ) :
MethodBase( Types::kFisher, dsi, theWeightFile, theTargetDir ),
fMeanMatx ( 0 ),
fTheMethod ( "Fisher" ),
fFisherMethod ( kFisher ),
fBetw ( 0 ),
fWith ( 0 ),
fCov ( 0 ),
fSumOfWeightsS( 0 ),
fSumOfWeightsB( 0 ),
fDiscrimPow ( 0 ),
fFisherCoeff ( 0 ),
fF0 ( 0 )
{
}
// Default initialization: allocates the coefficient vector (one entry per
// input variable), sets the default signal reference cut at 0, and allocates
// the working matrices.
void TMVA::MethodFisher::Init( void )
{
// allocate Fisher coefficients, one per input variable
fFisherCoeff = new std::vector<Double_t>( GetNvar() );
// the minimum requirement to declare an event signal-like
SetSignalReferenceCut( 0.0 );
fMeanMatx = 0;
fBetw = 0;
fWith = 0;
fCov = 0;
InitMatrices();
}
// Declare the configuration options understood by this method:
//   Method = "Fisher" (within-class covariance) or "Mahalanobis" (full covariance)
void TMVA::MethodFisher::DeclareOptions()
{
DeclareOptionRef( fTheMethod = "Fisher", "Method", "Discrimination method" );
AddPreDefVal(TString("Fisher"));
AddPreDefVal(TString("Mahalanobis"));
}
// Translate the user option string into the method enum and (re-)allocate the
// working matrices.
// NOTE(review): InitMatrices() is also called from Init(); it must tolerate
// being called twice without leaking.
void TMVA::MethodFisher::ProcessOptions()
{
if (fTheMethod == "Fisher" ) fFisherMethod = kFisher;
else fFisherMethod = kMahalanobis;
InitMatrices();
}
// Destructor: releases all heap objects allocated in Init()/InitMatrices().
// Fix: the original leaked fMeanMatx (allocated in InitMatrices but never
// deleted here).
TMVA::MethodFisher::~MethodFisher( void )
{
   if (fMeanMatx   ) { delete fMeanMatx;   fMeanMatx   = 0; }
   if (fBetw       ) { delete fBetw;       fBetw       = 0; }
   if (fWith       ) { delete fWith;       fWith       = 0; }
   if (fCov        ) { delete fCov;        fCov        = 0; }
   if (fDiscrimPow ) { delete fDiscrimPow; fDiscrimPow = 0; }
   if (fFisherCoeff) { delete fFisherCoeff; fFisherCoeff = 0; }
}
// Fisher handles binary classification only: accept exactly the combination
// of a classification analysis with two classes.
Bool_t TMVA::MethodFisher::HasAnalysisType( Types::EAnalysisType type, UInt_t numberClasses, UInt_t )
{
   const Bool_t isTwoClassClassification =
      (type == Types::kClassification) && (numberClasses == 2);
   return isTwoClassClassification ? kTRUE : kFALSE;
}
// Training: computes, in order, the class means, the within-class,
// between-class and full covariance matrices, then the Fisher coefficients
// and per-variable discrimination power. Each step feeds the next, so the
// order is fixed.
void TMVA::MethodFisher::Train( void )
{
// get mean value of each variable for signal, background and signal+background
GetMean();
// get the matrix of covariance 'within class'
GetCov_WithinClass();
// get the matrix of covariance 'between class'
GetCov_BetweenClass();
// get the full covariance matrix (sum of the two above)
GetCov_Full();
// solve for the Fisher coefficients and the offset fF0
GetFisherCoeff();
// rank variables by their between/full variance ratio
GetDiscrimPower();
// nice output
PrintCoefficients();
}
// Evaluate the linear discriminant on the current event:
//    F = fF0 + sum_i  c_i * x_i
// Fisher provides no MVA error estimate, hence NoErrorCalc.
Double_t TMVA::MethodFisher::GetMvaValue( Double_t* err, Double_t* errUpper )
{
   const Event* event = GetEvent();

   Double_t fisherVal = fF0;
   const UInt_t nVariables = GetNvar();
   for (UInt_t i = 0; i < nVariables; i++) {
      fisherVal += (*fFisherCoeff)[i] * event->GetValue(i);
   }

   // cannot determine error
   NoErrorCalc(err, errUpper);

   return fisherVal;
}
// Allocate the mean matrix, the covariance matrices and the
// discrimination-power vector.
// Fix: InitMatrices() is called from both Init() and ProcessOptions(); the
// original unconditionally allocated with 'new' and leaked the first set of
// objects on the second call. Delete any previous allocation first (all
// pointers are zero-initialized in the constructors, so delete is safe).
void TMVA::MethodFisher::InitMatrices( void )
{
   delete fMeanMatx;
   delete fBetw;
   delete fWith;
   delete fCov;
   delete fDiscrimPow;

   // average value of each variable: signal (col 0), background (col 1),
   // signal+background (col 2)
   fMeanMatx   = new TMatrixD( GetNvar(), 3 );
   // the covariance 'between class' and 'within class'
   fBetw       = new TMatrixD( GetNvar(), GetNvar() );
   fWith       = new TMatrixD( GetNvar(), GetNvar() );
   fCov        = new TMatrixD( GetNvar(), GetNvar() );
   // discriminating power per input variable
   fDiscrimPow = new std::vector<Double_t>( GetNvar() );
}
void TMVA::MethodFisher::GetMean( void )
{
fSumOfWeightsS = 0;
fSumOfWeightsB = 0;
const UInt_t nvar = DataInfo().GetNVariables();
Double_t* sumS = new Double_t[nvar];
Double_t* sumB = new Double_t[nvar];
for (UInt_t ivar=0; ivar<nvar; ivar++) { sumS[ivar] = sumB[ivar] = 0; }
for (Int_t ievt=0; ievt<Data()->GetNEvents(); ievt++) {
const Event * ev = GetEvent(ievt);
Double_t weight = GetTWeight(ev);
if (DataInfo().IsSignal(ev)) fSumOfWeightsS += weight;
else fSumOfWeightsB += weight;
Double_t* sum = DataInfo().IsSignal(ev) ? sumS : sumB;
for (UInt_t ivar=0; ivar<nvar; ivar++) sum[ivar] += ev->GetValue( ivar )*weight;
}
for (UInt_t ivar=0; ivar<nvar; ivar++) {
(*fMeanMatx)( ivar, 2 ) = sumS[ivar];
(*fMeanMatx)( ivar, 0 ) = sumS[ivar]/fSumOfWeightsS;
(*fMeanMatx)( ivar, 2 ) += sumB[ivar];
(*fMeanMatx)( ivar, 1 ) = sumB[ivar]/fSumOfWeightsB;
(*fMeanMatx)( ivar, 2 ) /= (fSumOfWeightsS + fSumOfWeightsB);
}
delete [] sumS;
delete [] sumB;
}
void TMVA::MethodFisher::GetCov_WithinClass( void )
{
assert( fSumOfWeightsS > 0 && fSumOfWeightsB > 0 );
const Int_t nvar = GetNvar();
const Int_t nvar2 = nvar*nvar;
Double_t *sumSig = new Double_t[nvar2];
Double_t *sumBgd = new Double_t[nvar2];
Double_t *xval = new Double_t[nvar];
memset(sumSig,0,nvar2*sizeof(Double_t));
memset(sumBgd,0,nvar2*sizeof(Double_t));
for (Int_t ievt=0; ievt<Data()->GetNEvents(); ievt++) {
const Event* ev = GetEvent(ievt);
Double_t weight = GetTWeight(ev);
for (Int_t x=0; x<nvar; x++) xval[x] = ev->GetValue( x );
Int_t k=0;
for (Int_t x=0; x<nvar; x++) {
for (Int_t y=0; y<nvar; y++) {
Double_t v = ( (xval[x] - (*fMeanMatx)(x, 0))*(xval[y] - (*fMeanMatx)(y, 0)) )*weight;
if (DataInfo().IsSignal(ev)) sumSig[k] += v;
else sumBgd[k] += v;
k++;
}
}
}
Int_t k=0;
for (Int_t x=0; x<nvar; x++) {
for (Int_t y=0; y<nvar; y++) {
(*fWith)(x, y) = (sumSig[k] + sumBgd[k])/(fSumOfWeightsS + fSumOfWeightsB);
k++;
}
}
delete [] sumSig;
delete [] sumBgd;
delete [] xval;
}
void TMVA::MethodFisher::GetCov_BetweenClass( void )
{
assert( fSumOfWeightsS > 0 && fSumOfWeightsB > 0);
Double_t prodSig, prodBgd;
for (UInt_t x=0; x<GetNvar(); x++) {
for (UInt_t y=0; y<GetNvar(); y++) {
prodSig = ( ((*fMeanMatx)(x, 0) - (*fMeanMatx)(x, 2))*
((*fMeanMatx)(y, 0) - (*fMeanMatx)(y, 2)) );
prodBgd = ( ((*fMeanMatx)(x, 1) - (*fMeanMatx)(x, 2))*
((*fMeanMatx)(y, 1) - (*fMeanMatx)(y, 2)) );
(*fBetw)(x, y) = (fSumOfWeightsS*prodSig + fSumOfWeightsB*prodBgd) / (fSumOfWeightsS + fSumOfWeightsB);
}
}
}
void TMVA::MethodFisher::GetCov_Full( void )
{
for (UInt_t x=0; x<GetNvar(); x++)
for (UInt_t y=0; y<GetNvar(); y++)
(*fCov)(x, y) = (*fWith)(x, y) + (*fBetw)(x, y);
}
void TMVA::MethodFisher::GetFisherCoeff( void )
{
assert( fSumOfWeightsS > 0 && fSumOfWeightsB > 0);
TMatrixD* theMat = 0;
switch (GetFisherMethod()) {
case kFisher:
theMat = fWith;
break;
case kMahalanobis:
theMat = fCov;
break;
default:
Log() << kFATAL << "<GetFisherCoeff> undefined method" << GetFisherMethod() << Endl;
}
TMatrixD invCov( *theMat );
if ( TMath::Abs(invCov.Determinant()) < 10E-24 ) {
Log() << kWARNING << "<GetFisherCoeff> matrix is almost singular with deterninant="
<< TMath::Abs(invCov.Determinant())
<< " did you use the variables that are linear combinations or highly correlated?"
<< Endl;
}
if ( TMath::Abs(invCov.Determinant()) < 10E-120 ) {
Log() << kFATAL << "<GetFisherCoeff> matrix is singular with determinant="
<< TMath::Abs(invCov.Determinant())
<< " did you use the variables that are linear combinations?"
<< Endl;
}
invCov.Invert();
Double_t xfact = TMath::Sqrt( fSumOfWeightsS*fSumOfWeightsB ) / (fSumOfWeightsS + fSumOfWeightsB);
std::vector<Double_t> diffMeans( GetNvar() );
UInt_t ivar, jvar;
for (ivar=0; ivar<GetNvar(); ivar++) {
(*fFisherCoeff)[ivar] = 0;
for (jvar=0; jvar<GetNvar(); jvar++) {
Double_t d = (*fMeanMatx)(jvar, 0) - (*fMeanMatx)(jvar, 1);
(*fFisherCoeff)[ivar] += invCov(ivar, jvar)*d;
}
(*fFisherCoeff)[ivar] *= xfact;
}
fF0 = 0.0;
for (ivar=0; ivar<GetNvar(); ivar++){
fF0 += (*fFisherCoeff)[ivar]*((*fMeanMatx)(ivar, 0) + (*fMeanMatx)(ivar, 1));
}
fF0 /= -2.0;
}
void TMVA::MethodFisher::GetDiscrimPower( void )
{
for (UInt_t ivar=0; ivar<GetNvar(); ivar++) {
if ((*fCov)(ivar, ivar) != 0)
(*fDiscrimPow)[ivar] = (*fBetw)(ivar, ivar)/(*fCov)(ivar, ivar);
else
(*fDiscrimPow)[ivar] = 0;
}
}
// Build the variable ranking from the discrimination power computed in
// GetDiscrimPower(); one entry per input variable. The returned object is
// owned by the base-class fRanking member.
const TMVA::Ranking* TMVA::MethodFisher::CreateRanking()
{
fRanking = new Ranking( GetName(), "Discr. power" );
for (UInt_t ivar=0; ivar<GetNvar(); ivar++) {
fRanking->AddRank( Rank( GetInputLabel(ivar), (*fDiscrimPow)[ivar] ) );
}
return fRanking;
}
// Log the trained Fisher coefficients (plus the offset fF0) in a formatted
// table. If variable transformations are active, list them; if the
// "Normalise" option was used, additionally print the per-variable
// normalisation formula the coefficients apply to.
void TMVA::MethodFisher::PrintCoefficients( void )
{
Log() << kINFO << "Results for Fisher coefficients:" << Endl;
if (GetTransformationHandler().GetTransformationList().GetSize() != 0) {
Log() << kINFO << "NOTE: The coefficients must be applied to TRANFORMED variables" << Endl;
Log() << kINFO << " List of the transformation: " << Endl;
TListIter trIt(&GetTransformationHandler().GetTransformationList());
while (VariableTransformBase *trf = (VariableTransformBase*) trIt()) {
Log() << kINFO << " -- " << trf->GetName() << Endl;
}
}
// collect labels and coefficients, with the offset appended last
std::vector<TString> vars;
std::vector<Double_t> coeffs;
for (UInt_t ivar=0; ivar<GetNvar(); ivar++) {
vars .push_back( GetInputLabel(ivar) );
coeffs.push_back( (*fFisherCoeff)[ivar] );
}
vars .push_back( "(offset)" );
coeffs.push_back( fF0 );
TMVA::gTools().FormattedOutput( coeffs, vars, "Variable" , "Coefficient", Log() );
if (IsNormalised()) {
Log() << kINFO << "NOTE: You have chosen to use the \"Normalise\" booking option. Hence, the" << Endl;
Log() << kINFO << " coefficients must be applied to NORMALISED (') variables as follows:" << Endl;
// width of the longest variable label, used to align the formulae below
Int_t maxL = 0;
for (UInt_t ivar=0; ivar<GetNvar(); ivar++) if (GetInputLabel(ivar).Length() > maxL) maxL = GetInputLabel(ivar).Length();
// print the normalisation formula [x]' = 2*([x] - xmin)/(xmax - xmin) - 1
for (UInt_t ivar=0; ivar<GetNvar(); ivar++) {
Log() << kINFO
<< setw(maxL+9) << TString("[") + GetInputLabel(ivar) + "]' = 2*("
<< setw(maxL+2) << TString("[") + GetInputLabel(ivar) + "]"
<< setw(3) << (GetXmin(ivar) > 0 ? " - " : " + ")
<< setw(6) << TMath::Abs(GetXmin(ivar)) << setw(3) << ")/"
<< setw(6) << (GetXmax(ivar) - GetXmin(ivar) )
<< setw(3) << " - 1"
<< Endl;
}
Log() << kINFO << "The TMVA Reader will properly account for this normalisation, but if the" << Endl;
Log() << kINFO << "Fisher classifier is applied outside the Reader, the transformation must be" << Endl;
Log() << kINFO << "implemented -- or the \"Normalise\" option is removed and Fisher retrained." << Endl;
Log() << kINFO << Endl;
}
}
// Read the weights from a plain-text stream: first the offset fF0, then one
// Fisher coefficient per input variable (fFisherCoeff is sized in Init()).
void TMVA::MethodFisher::ReadWeightsFromStream( istream& istr )
{
   istr >> fF0;
   std::vector<Double_t>::iterator coeff = fFisherCoeff->begin();
   for (; coeff != fFisherCoeff->end(); ++coeff) {
      istr >> *coeff;
   }
}
void TMVA::MethodFisher::AddWeightsXMLTo( void* parent ) const
{
void* wght = gTools().AddChild(parent, "Weights");
gTools().AddAttr( wght, "NCoeff", GetNvar()+1 );
void* coeffxml = gTools().AddChild(wght, "Coefficient");
gTools().AddAttr( coeffxml, "Index", 0 );
gTools().AddAttr( coeffxml, "Value", fF0 );
for (UInt_t ivar=0; ivar<GetNvar(); ivar++) {
coeffxml = gTools().AddChild( wght, "Coefficient" );
gTools().AddAttr( coeffxml, "Index", ivar+1 );
gTools().AddAttr( coeffxml, "Value", (*fFisherCoeff)[ivar] );
}
}
void TMVA::MethodFisher::ReadWeightsFromXML( void* wghtnode )
{
UInt_t ncoeff, coeffidx;
gTools().ReadAttr( wghtnode, "NCoeff", ncoeff );
fFisherCoeff->resize(ncoeff-1);
void* ch = gTools().GetChild(wghtnode);
Double_t coeff;
while (ch) {
gTools().ReadAttr( ch, "Index", coeffidx );
gTools().ReadAttr( ch, "Value", coeff );
if (coeffidx==0) fF0 = coeff;
else (*fFisherCoeff)[coeffidx-1] = coeff;
ch = gTools().GetNextChild(ch);
}
}
// Emit standalone C++ source for the trained classifier: member declarations,
// an Initialize() filling the coefficients, the GetMvaValue__ evaluation
// function, and a Clear() method. The emitted strings must be preserved
// verbatim — they are the generated program's code.
void TMVA::MethodFisher::MakeClassSpecific( std::ostream& fout, const TString& className ) const
{
// save the stream precision so it can be restored at the end
Int_t dp = fout.precision();
fout << " double fFisher0;" << endl;
fout << " std::vector<double> fFisherCoefficients;" << endl;
fout << "};" << endl;
fout << "" << endl;
fout << "inline void " << className << "::Initialize() " << endl;
fout << "{" << endl;
// 12 significant digits so the coefficients round-trip through text
fout << " fFisher0 = " << std::setprecision(12) << fF0 << ";" << endl;
for (UInt_t ivar=0; ivar<GetNvar(); ivar++) {
fout << " fFisherCoefficients.push_back( " << std::setprecision(12) << (*fFisherCoeff)[ivar] << " );" << endl;
}
fout << endl;
fout << " // sanity check" << endl;
fout << " if (fFisherCoefficients.size() != fNvars) {" << endl;
fout << " std::cout << \"Problem in class \\\"\" << fClassName << \"\\\"::Initialize: mismatch in number of input values\"" << endl;
fout << " << fFisherCoefficients.size() << \" != \" << fNvars << std::endl;" << endl;
fout << " fStatusIsClean = false;" << endl;
fout << " } " << endl;
fout << "}" << endl;
fout << endl;
fout << "inline double " << className << "::GetMvaValue__( const std::vector<double>& inputValues ) const" << endl;
fout << "{" << endl;
fout << " double retval = fFisher0;" << endl;
fout << " for (size_t ivar = 0; ivar < fNvars; ivar++) {" << endl;
fout << " retval += fFisherCoefficients[ivar]*inputValues[ivar];" << endl;
fout << " }" << endl;
fout << endl;
fout << " return retval;" << endl;
fout << "}" << endl;
fout << endl;
fout << "// Clean up" << endl;
fout << "inline void " << className << "::Clear() " << endl;
fout << "{" << endl;
fout << " // clear coefficients" << endl;
fout << " fFisherCoefficients.clear(); " << endl;
fout << "}" << endl;
// restore the caller's stream precision
fout << std::setprecision(dp);
}
// Print the standard TMVA help text for this method: a short description of
// the Fisher discriminant, performance-optimisation advice, and the (empty)
// list of tuning options. Pure logging; no state is touched.
void TMVA::MethodFisher::GetHelpMessage() const
{
Log() << Endl;
Log() << gTools().Color("bold") << "--- Short description:" << gTools().Color("reset") << Endl;
Log() << Endl;
Log() << "Fisher discriminants select events by distinguishing the mean " << Endl;
Log() << "values of the signal and background distributions in a trans- " << Endl;
Log() << "formed variable space where linear correlations are removed." << Endl;
Log() << Endl;
Log() << " (More precisely: the \"linear discriminator\" determines" << Endl;
Log() << " an axis in the (correlated) hyperspace of the input " << Endl;
Log() << " variables such that, when projecting the output classes " << Endl;
Log() << " (signal and background) upon this axis, they are pushed " << Endl;
Log() << " as far as possible away from each other, while events" << Endl;
Log() << " of a same class are confined in a close vicinity. The " << Endl;
Log() << " linearity property of this classifier is reflected in the " << Endl;
Log() << " metric with which \"far apart\" and \"close vicinity\" are " << Endl;
Log() << " determined: the covariance matrix of the discriminating" << Endl;
Log() << " variable space.)" << Endl;
Log() << Endl;
Log() << gTools().Color("bold") << "--- Performance optimisation:" << gTools().Color("reset") << Endl;
Log() << Endl;
Log() << "Optimal performance for Fisher discriminants is obtained for " << Endl;
Log() << "linearly correlated Gaussian-distributed variables. Any deviation" << Endl;
Log() << "from this ideal reduces the achievable separation power. In " << Endl;
Log() << "particular, no discrimination at all is achieved for a variable" << Endl;
Log() << "that has the same sample mean for signal and background, even if " << Endl;
Log() << "the shapes of the distributions are very different. Thus, Fisher " << Endl;
Log() << "discriminants often benefit from suitable transformations of the " << Endl;
Log() << "input variables. For example, if a variable x in [-1,1] has a " << Endl;
Log() << "a parabolic signal distributions, and a uniform background" << Endl;
Log() << "distributions, their mean value is zero in both cases, leading " << Endl;
Log() << "to no separation. The simple transformation x -> |x| renders this " << Endl;
Log() << "variable powerful for the use in a Fisher discriminant." << Endl;
Log() << Endl;
Log() << gTools().Color("bold") << "--- Performance tuning via configuration options:" << gTools().Color("reset") << Endl;
Log() << Endl;
Log() << "<None>" << Endl;
}