// @(#)root/tmva $Id: TMVA_MethodFisher.cxx,v 1.2 2006/05/09 08:37:06 brun Exp $
// Author: Andreas Hoecker, Xavier Prudent, Helge Voss, Kai Voss 

/**********************************************************************************
 * Project: TMVA - a Root-integrated toolkit for multivariate Data analysis       *
 * Package: TMVA                                                                  *
 * Class  : TMVA_MethodFisher                                                     *
 *                                                                                *
 * Description:                                                                   *
 *      Implementation (see header for description)                               *
 *                                                                                *
 * Original author of this Fisher-Discriminant implementation:                    *
 *      Andre Gaidot, CEA-France;                                                 *
 *      (Translation from FORTRAN)                                                *
 *                                                                                *
 * Authors (alphabetical):                                                        *
 *      Andreas Hoecker <Andreas.Hocker@cern.ch> - CERN, Switzerland              *
 *      Xavier Prudent  <prudent@lapp.in2p3.fr>  - LAPP, France                   *
 *      Helge Voss      <Helge.Voss@cern.ch>     - MPI-KP Heidelberg, Germany     *
 *      Kai Voss        <Kai.Voss@cern.ch>       - U. of Victoria, Canada         *
 *                                                                                *
 * Copyright (c) 2005:                                                            *
 *      CERN, Switzerland,                                                        * 
 *      U. of Victoria, Canada,                                                   * 
 *      MPI-KP Heidelberg, Germany,                                               * 
 *      LAPP, Annecy, France                                                      *
 *                                                                                *
 * Redistribution and use in source and binary forms, with or without             *
 * modification, are permitted according to the terms listed in LICENSE           *
 * (http://mva.sourceforge.net/license.txt)                                       *
 *                                                                                *
 **********************************************************************************/
 
//_______________________________________________________________________
//                                                                      
// Analysis of Fisher discriminant (Fisher or Mahalanobis approach)     
//                                                                      
//_______________________________________________________________________

#include "TMVA_MethodFisher.h"
#include "TMVA_Tools.h"
#include "TMatrix.h"
#include "Riostream.h"
#include <algorithm>
#include <utility>
#include <vector>

#define DEBUG_TMVA_MethodFisher kFALSE

ClassImp(TMVA_MethodFisher)

//_______________________________________________________________________
 TMVA_MethodFisher::TMVA_MethodFisher( TString jobName, vector<TString>* theVariables,  
                                       TTree* theTree, TString theOption, TDirectory* theTargetDir )
  : TMVA_MethodBase( jobName, theVariables, theTree, theOption, theTargetDir )
{
  // Training constructor.
  //
  // The option string selects the discriminant flavour: after lower-casing,
  // an option containing "fi" selects kFisher, otherwise one containing "ma"
  // selects kMahalanobis; anything else terminates the program.
  // An empty option string defaults to "Fisher ".
  //
  // If a training tree is available, the number of branches (minus the "type"
  // branch) must equal fNvar; events with type == 1 are counted as signal,
  // all others as background, and Init() is called to fill the look-up tables.
  InitFisher();

  if (fOptions.Sizeof()<2) {
    fOptions = "Fisher ";
    cout << "--- " << GetName() << ": using default options= "<< fOptions <<endl;
  }
  // option string defines "Method" (Fisher, Mahalanobis)
  // note: fOptions itself is converted to lower case here
  fOptions.ToLower();
  if      (fOptions.Contains( "fi" )) fFisherMethod = kFisher;
  else if (fOptions.Contains( "ma" )) fFisherMethod = kMahalanobis;
  else {
    cout << "--- " << GetName() << ": Error: unrecognized option string: " 
         << GetOptions() << " | " << fOptions
         << " --> exit(1)" << endl;
    exit(1);
  }

  // note that one variable is type
  if (0 != fTrainingTree) {

    // trainingTree should only contain those variables that are used in the MVA
    if (fTrainingTree->GetListOfBranches()->GetEntries() - 1 != fNvar) {
      cout << "--- " << GetName() << ": Error: mismatch in number of variables: " 
           << fTrainingTree->GetListOfBranches()->GetEntries() << " " << fNvar
           << " --> exit(1)" << endl;
      exit(1);
    }

    // count number of signal and background events
    fNevt = fTrainingTree->GetEntries();
    fNsig = 0;
    fNbgd = 0;
    for (Int_t ievt = 0; ievt < fNevt; ievt++) {
      if ((Int_t)TMVA_Tools::GetValue( fTrainingTree, ievt, "type" ) == 1) 
        ++fNsig;
      else       
        ++fNbgd;
    }        

    // numbers of events should match
    if (fNsig + fNbgd != fNevt) {
      cout << "--- " << GetName() << ": Error: mismatch in number of events" 
           << " --> exit(1)" << endl;
      exit(1);
    }

    if (Verbose())
      cout << "--- " << GetName() << " <verbose>: num of events for training (signal, background): "
           << " (" << fNsig << ", " << fNbgd << ")" << endl;

    // Fisher wants same number of events in each species
    if (fNsig != fNbgd) {
      cout << "--- " << GetName() << ":\t--------------------------------------------------"
           << endl;
      cout << "--- " << GetName() << ":\tWarning: different number of signal and background\n"
           << "--- " << GetName() << " \tevents: Fisher training will not be optimal :-("
           << endl;
      cout << "--- " << GetName() << ":\t--------------------------------------------------"
           << endl;
    }      

    // allocate arrays 
    Init();
  }
  else {    
    fNevt = 0;
    fNsig = 0;
    fNbgd = 0;
  }
}

//_______________________________________________________________________
 TMVA_MethodFisher::TMVA_MethodFisher( vector<TString> *theVariables, TString theWeightFile,
                                       TDirectory* theTargetDir )
  : TMVA_MethodBase( theVariables, theWeightFile, theTargetDir )
{
  // Constructor taking a weight file instead of a training tree.
  // Only the default member initialisation is performed here.
  this->InitFisher();
}

//_______________________________________________________________________
 void TMVA_MethodFisher::InitFisher( void )
{
  fMethodName  = "Fisher";
  fMethod      = TMVA_Types::Fisher;  
  fTestvar     = fTestvarPrefix+GetMethodName();
  fMeanMatx    = 0; 
  fBetw        = 0;
  fWith        = 0;
  fCov         = 0;

  fNevt        = 0;
  fNsig        = 0;
  fNbgd        = 0;

  // allocate Fisher coefficients
  fF0          = 0;
  fFisherCoeff = new vector<Double_t>( fNvar );
}

//_______________________________________________________________________
 TMVA_MethodFisher::~TMVA_MethodFisher( void )
{
  // Release all heap objects owned by this method.
  delete fSig;
  delete fBgd;
  // fMeanMatx is allocated in Init() together with the other matrices and
  // set to 0 in InitFisher(); it was previously never released.
  delete fMeanMatx;
  delete fBetw;
  delete fWith;
  delete fCov;
  delete fDiscrimPow;
  delete fFisherCoeff;
}

//_______________________________________________________________________
 void TMVA_MethodFisher::Train( void )
{
  // Run the complete training sequence: sanity check, class means,
  // covariance matrices, Fisher coefficients, discriminating power,
  // printout, and finally the weight file.
  // The order of the calls matters: each step reads members filled by the
  // preceding ones.
  //--------------------------------------------------------------

  // default sanity checks; the program terminates if they fail
  if (!CheckSanity()) { 
    cout << "--- " << GetName() << ": Error: sanity check failed" << endl;
    exit(1);
  }
  // get mean value of each variable for signal, background and signal+background
  GetMean();

  // get the matrix of covariance 'within class'
  GetCov_WithinClass();

  // get the matrix of covariance 'between class'
  GetCov_BetweenClass();

  // get the full covariance matrix (sum of 'within class' and 'between class')
  GetCov_Full();
 
  //--------------------------------------------------------------

  // get the Fisher coefficients
  GetFisherCoeff();

  // get the discriminating power of each variable
  GetDiscrimPower();

  // nice output
  PrintCoefficients();

  // write weights to file
  WriteWeightsToFile();
}

//_______________________________________________________________________
 Double_t TMVA_MethodFisher::GetMvaValue( TMVA_Event *e )
{
  // Fisher response for one event: the offset fF0 plus the sum over all
  // variables of coefficient times normalised input value.
  const vector<Double_t>& coeff = *fFisherCoeff;

  Double_t fisher = fF0;
  for (Int_t k = 0; k < fNvar; ++k)
    fisher += coeff[k]*__N__( e->GetData(k) , GetXminNorm( k ), GetXmaxNorm( k ) );

  return fisher;
}


//_______________________________________________________________________
 void TMVA_MethodFisher::Init( void )
{
  // should never be called without existing trainingTree
  if (0 == fTrainingTree) {
    cout << "--- " << GetName() << ": Error in ::Init(): fTrainingTree is zero pointer"
	 << " --> exit(1)" << endl;
    exit(1);
  }

  // signal and background LUTs
  fSig = new TMatrixT<float>( fNvar, fNsig );
  fBgd = new TMatrixT<float>( fNvar, fNbgd );
   
  // average value of each variables for S, B, S+B
  fMeanMatx = new TMatrixT<float>( fNvar, 3 );

  // the covariance 'within class' and 'between class' matrices
  fBetw = new TMatrixT<float>( fNvar, fNvar );
  fWith = new TMatrixT<float>( fNvar, fNvar );
  fCov  = new TMatrixT<float>( fNvar, fNvar );

  // discriminating power
  fDiscrimPow = new vector<Double_t>( fNvar );

  // ---- fill LUTs

  Int_t isig = 0, ibgd = 0, ivar;
  for (Int_t ievt=0; ievt<fNevt; ievt++) {

    // separate signal and background events  
    if ((Int_t)TMVA_Tools::GetValue( fTrainingTree, ievt, "type" ) == 1) {
      for (ivar=0; ivar<fNvar; ivar++) {
	Double_t x = TMVA_Tools::GetValue( fTrainingTree, ievt, (*fInputVars)[ivar] );
	(*fSig)(ivar, isig) = __N__( x, GetXminNorm( ivar ), GetXmaxNorm( ivar ) );
      }
      ++isig;
    }
    else {
      for (ivar=0; ivar<fNvar; ivar++) {
	Double_t x = TMVA_Tools::GetValue( fTrainingTree, ievt, (*fInputVars)[ivar] );
	(*fBgd)(ivar, ibgd) = __N__( x, GetXminNorm( ivar ), GetXmaxNorm( ivar ) );
      }
      ++ibgd;
    }
  }  
}

//_______________________________________________________________________
 void TMVA_MethodFisher::GetMean( void )
{
  for(Int_t ivar=0; ivar<fNvar; ivar++) {   

    // signal
    Double_t sum = 0;
    for (Int_t ievt=0; ievt<fNsig; ievt++) sum += (*fSig)(ivar, ievt);
    (*fMeanMatx)( ivar, 2 ) = sum;
    (*fMeanMatx)( ivar, 0 ) = sum/fNsig;
    
    // background
    sum = 0;
    for (Int_t ievt=0; ievt<fNbgd; ievt++) sum += (*fBgd)(ivar, ievt);
    (*fMeanMatx)( ivar, 2 ) += sum;
    (*fMeanMatx)( ivar, 1 ) = sum/fNbgd;       

    // signal + background
    (*fMeanMatx)( ivar, 2 ) /= (fNsig + fNbgd);
  }  
}

//_______________________________________________________________________
 void TMVA_MethodFisher::GetCov_WithinClass( void )
{
  // the matrix of covariance 'within class' reflects the dispersion of the
  // events relative to the center of gravity of their own class  

  // products matrix's (x-<x>)(y-<y>) where x;y are variables
  Double_t prodSig, prodBgd;
  Int_t    ievt;

  // 'within class' covariance
  for (Int_t x=0; x<fNvar; x++) {
    for (Int_t y=0; y<fNvar; y++) {

      Double_t sumSig = 0;
      Double_t sumBgd = 0;

      for (ievt=0; ievt<fNsig; ievt++) {
	prodSig = ( ((*fSig)(x, ievt) - (*fMeanMatx)(x, 0))* 
		    ((*fSig)(y, ievt) - (*fMeanMatx)(y, 0)) );
	sumSig += prodSig;
      }

      for (ievt=0; ievt<fNbgd; ievt++) {
	prodBgd = ( ((*fBgd)(x, ievt) - (*fMeanMatx)(x, 1))* 
		    ((*fBgd)(y, ievt) - (*fMeanMatx)(y, 1)) );
	sumBgd += prodBgd;
      }

      (*fWith)(x, y) = (sumSig + sumBgd)/fNevt;
    }
  }
}

//_______________________________________________________________________
 void TMVA_MethodFisher::GetCov_BetweenClass( void )
{
  // the matrix of covariance 'between class' reflects the dispersion of the
  // events of a class relative to the global center of gravity of all the class
  // hence the separation between classes

  Double_t prodSig, prodBgd;

  for (Int_t x=0; x<fNvar; x++) {
    for (Int_t y=0; y<fNvar; y++) {

      prodSig = ( ((*fMeanMatx)(x, 0) - (*fMeanMatx)(x, 2))*
		  ((*fMeanMatx)(y, 0) - (*fMeanMatx)(y, 2)) );
      prodBgd = ( ((*fMeanMatx)(x, 1) - (*fMeanMatx)(x, 2))*
		  ((*fMeanMatx)(y, 1) - (*fMeanMatx)(y, 2)) );

      (*fBetw)(x, y) = (fNsig*prodSig + fNbgd*prodBgd)/Double_t(fNevt);
    }
  }
}

//_______________________________________________________________________
 void TMVA_MethodFisher::GetCov_Full( void )
{
  for (Int_t x=0; x<fNvar; x++) 
    for (Int_t y=0; y<fNvar; y++) 
      (*fCov)(x, y) = (*fWith)(x, y) + (*fBetw)(x, y);
}

//_______________________________________________________________________
 void TMVA_MethodFisher::GetFisherCoeff( void )
{
  // Fisher = Sum { [coeff]*[variables] }
  //
  // let Xs be the array of the mean values of variables for signal evts
  // let Xb be the array of the mean values of variables for backgd evts
  // let InvWith be the inverse matrix of the 'within class' correlation matrix
  //
  // then the array of Fisher coefficients is 
  // [coeff] =sqrt(fNsig*fNbgd)/fNevt*transpose{Xs-Xb}*InvWith

  // invert covariance matrix
  TMatrixT<float>* theMat = 0;
  switch (fFisherMethod) {
  case kFisher:
    theMat = fWith;
    break;
  case kMahalanobis:
    theMat = fCov;
    break;
  default:
    cout << "--- " << GetName() << ": ERROR: undefined method ==> exit(1)" << endl;
    exit(1);
  }

  TMatrixT<float> invCov( *theMat );
  invCov.Invert();

  // apply rescaling factor
  Double_t xfact = sqrt(Double_t(fNsig*fNbgd))/Double_t(fNsig + fNbgd);

  // compute difference of mean values
  vector<Double_t> diffMeans( fNvar );
  Int_t ivar, jvar;
  for (ivar=0; ivar<fNvar; ivar++) {
    (*fFisherCoeff)[ivar] = 0;

    for(jvar=0; jvar<fNvar; jvar++) {
      Double_t d = (*fMeanMatx)(jvar, 0) - (*fMeanMatx)(jvar, 1);
      (*fFisherCoeff)[ivar] += invCov(ivar, jvar)*d;
    }    
    
    // rescale
    (*fFisherCoeff)[ivar] *= xfact;
  }

  // offset correction
  fF0 = 0.0;
  for(ivar=0; ivar<fNvar; ivar++) 
    fF0 += (*fFisherCoeff)[ivar]*((*fMeanMatx)(ivar, 0) + (*fMeanMatx)(ivar, 1));
  fF0 /= -2.0;  
}

//_______________________________________________________________________
 void TMVA_MethodFisher::GetDiscrimPower( void )
{
  //small values of "fWith" indicates little compactness of sig & of backgd
  //big values of "fBetw" indicates large separation between sig & backgd
  //
  //we want signal & backgd classes as compact and separated as possible
  //the discriminating power is then defined as the ration "fBetw/fWith"
  for (Int_t ivar=0; ivar<fNvar; ivar++)
    if ((*fCov)(ivar, ivar) != 0) 
      (*fDiscrimPow)[ivar] = (*fBetw)(ivar, ivar)/(*fCov)(ivar, ivar);
    else
      (*fDiscrimPow)[ivar] = 0;
}

//_______________________________________________________________________
 void TMVA_MethodFisher::PrintCoefficients( void ) 
{
  // display Fisher coefficients and discriminating power for each variable
  // check maximum length of variable name
  Int_t maxL = 0; 
  vector<Double_t> dp( fNvar );
  for (Int_t ivar=0; ivar<fNvar; ivar++) {
    if ((*fInputVars)[ivar].Length() > maxL) maxL = (*fInputVars)[ivar].Length();
    dp[ivar] = (*fDiscrimPow)[ivar];
  }

  // sort according to rank (descending)
  sort   ( dp.begin(), dp.end() );
  reverse( dp.begin(), dp.end() );

  cout << "--- " << endl;
  cout << "--- " << GetName() << ": ranked output (top variable is best ranked)" << endl;
  cout << "----------------------------------------------------------------" << endl;
  cout << "--- " << setiosflags(ios::left
       << resetiosflags(ios::right
       << setw(12) << "  Coefficient:"
       << "   Discr. power:" << endl;
  cout << "----------------------------------------------------------------" << endl;
  for (Int_t ivar=0; ivar<fNvar; ivar++) 
    for (Int_t jvar=0; jvar<fNvar; jvar++) 
      if (dp[ivar] == (*fDiscrimPow)[jvar])
	printf( "--- %-11s:   %+.3f         %.4f\n", 
		(const char*)(*fInputVars)[jvar], (*fFisherCoeff)[jvar], (*fDiscrimPow)[jvar]);
  printf( "--- %-11s:   %+.3f         %i\n", "(offset)", fF0, 0 );
  cout << "----------------------------------------------------------------" << endl;
  cout << "--- " << endl;
}

//_______________________________________________________________________
 void  TMVA_MethodFisher::WriteWeightsToFile( void )
{  
  // write coefficients to file
  TString fname = GetWeightFileName();
  cout << "--- " << GetName() << ": creating weight file: " << fname << endl;
  
  ofstream fout( fname );
  if (!fout.good( )) { // file not found --> Error
    cout << "--- " << GetName() << ": Error in ::WriteWeightsToFile: "
         << "unable to open output  weight file: " << fname << endl;
    exit(1);
  }

  // write variable names and min/max 
  // NOTE: the latter values are mandatory for the normalisation 
  // in the reader application !!!
  fout << this->GetMethodName() <<endl;
  fout << "NVars= " << fNvar <<endl; 
  for (Int_t ivar=0; ivar<fNvar; ivar++) {
    TString var = (*fInputVars)[ivar];
    fout << var << "  " << GetXminNorm( var ) << "  " << GetXmaxNorm( var ) << endl;
  }

  // and save the weights
  fout << fF0 << endl;
  for (Int_t ivar=0; ivar<fNvar; ivar++) fout << (*fFisherCoeff)[ivar] << endl;
  fout.close();    
}
  
//_______________________________________________________________________
 void  TMVA_MethodFisher::ReadWeightsFromFile( void )
{
  // read coefficients from file
  TString fname = GetWeightFileName();
  cout << "--- " << GetName() << ": reading weight file: " << fname << endl;
  ifstream fin( fname );

  if (!fin.good( )) { // file not found --> Error
    cout << "--- " << GetName() << ": Error in ::ReadWeightsFromFile: "
         << "unable to open input file: " << fname << endl;
    exit(1);
  }

  // read variable names and min/max
  // NOTE: the latter values are mandatory for the normalisation 
  // in the reader application !!!
  TString var, dummy;
  Double_t xmin, xmax;
  fin >> dummy;
  this->SetMethodName(dummy);
  fin >> dummy >> fNvar;
  for (Int_t ivar=0; ivar<fNvar; ivar++) {
    fin >> var >> xmin >> xmax;

    // sanity check
    if (var != (*fInputVars)[ivar]) {
      cout << "--- " << GetName() << ": Error while reading weight file; "
	   << "unknown variable: " << var << " at position: " << ivar << ". "
	   << "Expected variable: " << (*fInputVars)[ivar] << " ==> abort" << endl;
      exit(1);
    }
    
    // set min/max
    this->SetXminNorm( ivar, xmin );
    this->SetXmaxNorm( ivar, xmax );
  }  

  // and read the weights (Fisher coefficients)
  fin >> fF0;
  for (Int_t ivar=0; ivar<fNvar; ivar++) fin >> (*fFisherCoeff)[ivar];
  fin.close();    
}

//_______________________________________________________________________
 void  TMVA_MethodFisher::WriteHistosToFile( void )
{
  // Only prints a message naming the target directory.
  cout << "--- " << GetName() << ": write " << GetName();
  cout << " special histos to file: " << fBaseDir->GetPath() << endl;
}


// ROOT page - Class index - Class Hierarchy - Top of the page
//
// This page has been automatically generated. If you have any comments or suggestions about the page layout, send a mail to ROOT support, or contact the developers with any questions or problems regarding ROOT.