70 TMVA::MethodHMatrix::MethodHMatrix( const
TString& jobName,
75 :
TMVA::MethodBase( jobName, Types::kHMatrix, methodTitle, theData, theOption, theTargetDir )
104 fInvHMatrixS =
new TMatrixD( GetNvar(), GetNvar() );
105 fInvHMatrixB =
new TMatrixD( GetNvar(), GetNvar() );
106 fVecMeanS =
new TVectorD( GetNvar() );
107 fVecMeanB =
new TVectorD( GetNvar() );
110 SetSignalReferenceCut( 0.0 );
118 if (
NULL != fInvHMatrixS)
delete fInvHMatrixS;
119 if (
NULL != fInvHMatrixB)
delete fInvHMatrixB;
120 if (
NULL != fVecMeanS )
delete fVecMeanS;
121 if (
NULL != fVecMeanB )
delete fVecMeanB;
154 ComputeCovariance(
kTRUE, fInvHMatrixS );
155 ComputeCovariance(
kFALSE, fInvHMatrixB );
158 if (
TMath::Abs(fInvHMatrixS->Determinant()) < 10
E-24) {
159 Log() <<
kWARNING <<
"<Train> H-matrix S is almost singular with deterinant= "
161 <<
" did you use the variables that are linear combinations or highly correlated ???"
164 if (
TMath::Abs(fInvHMatrixB->Determinant()) < 10
E-24) {
165 Log() <<
kWARNING <<
"<Train> H-matrix B is almost singular with deterinant= "
167 <<
" did you use the variables that are linear combinations or highly correlated ???"
171 if (
TMath::Abs(fInvHMatrixS->Determinant()) < 10
E-120) {
172 Log() <<
kFATAL <<
"<Train> H-matrix S is singular with deterinant= "
174 <<
" did you use the variables that are linear combinations ???"
177 if (
TMath::Abs(fInvHMatrixB->Determinant()) < 10
E-120) {
178 Log() <<
kFATAL <<
"<Train> H-matrix B is singular with deterinant= "
180 <<
" did you use the variables that are linear combinations ???"
185 fInvHMatrixS->Invert();
186 fInvHMatrixB->Invert();
196 const UInt_t nvar = DataInfo().GetNVariables();
201 TMatrixD mat2(nvar, nvar); mat2 *= 0;
208 for (
Int_t i=0, iEnd=
Data()->GetNEvents(); i<iEnd; ++i) {
211 const Event* origEvt =
Data()->GetEvent(i);
215 if (IgnoreEventsWithNegWeightsInTraining() && weight <= 0)
continue;
217 if (DataInfo().IsSignal(origEvt) != isSignal)
continue;
220 GetTransformationHandler().SetTransformationReferenceClass( origEvt->
GetClass() );
221 const Event* ev = GetTransformationHandler().Transform( origEvt );
224 sumOfWeights += weight;
227 for (ivar=0; ivar<nvar; ivar++) xval[ivar] = ev->
GetValue(ivar);
230 for (ivar=0; ivar<nvar; ivar++) {
232 vec(ivar) += xval[ivar]*weight;
233 mat2(ivar, ivar) += (xval[ivar]*xval[ivar])*weight;
235 for (jvar=ivar+1; jvar<nvar; jvar++) {
236 mat2(ivar, jvar) += (xval[ivar]*xval[jvar])*weight;
237 mat2(jvar, ivar) = mat2(ivar, jvar);
243 for (ivar=0; ivar<nvar; ivar++) {
245 if (isSignal) (*fVecMeanS)(ivar) = vec(ivar)/sumOfWeights;
246 else (*fVecMeanB)(ivar) = vec(ivar)/sumOfWeights;
248 for (jvar=0; jvar<nvar; jvar++) {
249 (*mat)(ivar, jvar) = mat2(ivar, jvar)/sumOfWeights - vec(ivar)*vec(jvar)/(sumOfWeights*sumOfWeights);
264 if (s+b < 0)
Log() <<
kFATAL <<
"big trouble: s+b: " << s+b <<
Endl;
267 NoErrorCalc(err, errUpper);
269 return (b - s)/(s + b);
279 const Event* origEvt = fTmpEvent ? fTmpEvent:
Data()->GetEvent();
282 UInt_t ivar(0), jvar(0), nvar(GetNvar());
283 std::vector<Double_t> val( nvar );
287 GetTransformationHandler().SetTransformationReferenceClass( fSignalClass );
289 GetTransformationHandler().SetTransformationReferenceClass( fBackgroundClass );
291 const Event* ev = GetTransformationHandler().Transform( origEvt );
293 for (ivar=0; ivar<nvar; ivar++) val[ivar] = ev->
GetValue( ivar );
296 for (ivar=0; ivar<nvar; ivar++) {
297 for (jvar=0; jvar<nvar; jvar++) {
299 chi2 += ( (val[ivar] - (*fVecMeanS)(ivar))*(val[jvar] - (*fVecMeanS)(jvar))
300 * (*fInvHMatrixS)(ivar,jvar) );
302 chi2 += ( (val[ivar] - (*fVecMeanB)(ivar))*(val[jvar] - (*fVecMeanB)(jvar))
303 * (*fInvHMatrixB)(ivar,jvar) );
308 if (chi2 < 0)
Log() <<
kFATAL <<
"<GetChi2> negative chi2: " << chi2 <<
Endl;
353 for (ivar=0; ivar<GetNvar(); ivar++)
354 istr >> (*fVecMeanS)(ivar) >> (*fVecMeanB)(ivar);
357 for (ivar=0; ivar<GetNvar(); ivar++)
358 for (jvar=0; jvar<GetNvar(); jvar++)
359 istr >> (*fInvHMatrixS)(ivar,jvar);
362 for (ivar=0; ivar<GetNvar(); ivar++)
363 for (jvar=0; jvar<GetNvar(); jvar++)
364 istr >> (*fInvHMatrixB)(ivar,jvar);
372 fout <<
" // arrays of input evt vs. variable " << std::endl;
373 fout <<
" double fInvHMatrixS[" << GetNvar() <<
"][" << GetNvar() <<
"]; // inverse H-matrix (signal)" << std::endl;
374 fout <<
" double fInvHMatrixB[" << GetNvar() <<
"][" << GetNvar() <<
"]; // inverse H-matrix (background)" << std::endl;
375 fout <<
" double fVecMeanS[" << GetNvar() <<
"]; // vector of mean values (signal)" << std::endl;
376 fout <<
" double fVecMeanB[" << GetNvar() <<
"]; // vector of mean values (background)" << std::endl;
377 fout <<
" " << std::endl;
378 fout <<
" double GetChi2( const std::vector<double>& inputValues, int type ) const;" << std::endl;
379 fout <<
"};" << std::endl;
380 fout <<
" " << std::endl;
381 fout <<
"void " << className <<
"::Initialize() " << std::endl;
382 fout <<
"{" << std::endl;
383 fout <<
" // init vectors with mean values" << std::endl;
384 for (
UInt_t ivar=0; ivar<GetNvar(); ivar++) {
385 fout <<
" fVecMeanS[" << ivar <<
"] = " << (*fVecMeanS)(ivar) <<
";" << std::endl;
386 fout <<
" fVecMeanB[" << ivar <<
"] = " << (*fVecMeanB)(ivar) <<
";" << std::endl;
388 fout <<
" " << std::endl;
389 fout <<
" // init H-matrices" << std::endl;
390 for (
UInt_t ivar=0; ivar<GetNvar(); ivar++) {
391 for (
UInt_t jvar=0; jvar<GetNvar(); jvar++) {
392 fout <<
" fInvHMatrixS[" << ivar <<
"][" << jvar <<
"] = "
393 << (*fInvHMatrixS)(ivar,jvar) <<
";" << std::endl;
394 fout <<
" fInvHMatrixB[" << ivar <<
"][" << jvar <<
"] = "
395 << (*fInvHMatrixB)(ivar,jvar) <<
";" << std::endl;
398 fout <<
"}" << std::endl;
399 fout <<
" " << std::endl;
400 fout <<
"inline double " << className <<
"::GetMvaValue__( const std::vector<double>& inputValues ) const" << std::endl;
401 fout <<
"{" << std::endl;
402 fout <<
" // returns the H-matrix signal estimator" << std::endl;
403 fout <<
" std::vector<double> inputValuesSig = inputValues;" << std::endl;
404 fout <<
" std::vector<double> inputValuesBgd = inputValues;" << std::endl;
405 if (GetTransformationHandler().GetTransformationList().GetSize() != 0) {
407 UInt_t signalClass =DataInfo().GetClassInfo(
"Signal")->GetNumber();
408 UInt_t backgroundClass=DataInfo().GetClassInfo(
"Background")->GetNumber();
410 fout <<
" Transform(inputValuesSig," << signalClass <<
");" << std::endl;
411 fout <<
" Transform(inputValuesBgd," << backgroundClass <<
");" << std::endl;
416 fout <<
" double s = GetChi2( inputValuesSig, " <<
Types::kSignal <<
" );" << std::endl;
417 fout <<
" double b = GetChi2( inputValuesBgd, " <<
Types::kBackground <<
" );" << std::endl;
421 fout <<
" " << std::endl;
422 fout <<
" if (s+b <= 0) std::cout << \"Problem in class " << className <<
"::GetMvaValue__: s+b = \"" << std::endl;
423 fout <<
" << s+b << \" <= 0 \" << std::endl;" << std::endl;
424 fout <<
" " << std::endl;
425 fout <<
" return (b - s)/(s + b);" << std::endl;
426 fout <<
"}" << std::endl;
427 fout <<
" " << std::endl;
428 fout <<
"inline double " << className <<
"::GetChi2( const std::vector<double>& inputValues, int type ) const" << std::endl;
429 fout <<
"{" << std::endl;
430 fout <<
" // compute chi2-estimator for event according to type (signal/background)" << std::endl;
431 fout <<
" " << std::endl;
432 fout <<
" size_t ivar,jvar;" << std::endl;
433 fout <<
" double chi2 = 0;" << std::endl;
434 fout <<
" for (ivar=0; ivar<GetNvar(); ivar++) {" << std::endl;
435 fout <<
" for (jvar=0; jvar<GetNvar(); jvar++) {" << std::endl;
437 fout <<
" chi2 += ( (inputValues[ivar] - fVecMeanS[ivar])*(inputValues[jvar] - fVecMeanS[jvar])" << std::endl;
438 fout <<
" * fInvHMatrixS[ivar][jvar] );" << std::endl;
439 fout <<
" else" << std::endl;
440 fout <<
" chi2 += ( (inputValues[ivar] - fVecMeanB[ivar])*(inputValues[jvar] - fVecMeanB[jvar])" << std::endl;
441 fout <<
" * fInvHMatrixB[ivar][jvar] );" << std::endl;
442 fout <<
" }" << std::endl;
443 fout <<
" } // loop over variables " << std::endl;
444 fout <<
" " << std::endl;
445 fout <<
" // sanity check" << std::endl;
446 fout <<
" if (chi2 < 0) std::cout << \"Problem in class " << className <<
"::GetChi2: chi2 = \"" << std::endl;
447 fout <<
" << chi2 << \" < 0 \" << std::endl;" << std::endl;
448 fout <<
" " << std::endl;
449 fout <<
" return chi2;" << std::endl;
450 fout <<
"}" << std::endl;
451 fout <<
" " << std::endl;
452 fout <<
"// Clean up" << std::endl;
453 fout <<
"inline void " << className <<
"::Clear() " << std::endl;
454 fout <<
"{" << std::endl;
455 fout <<
" // nothing to clear" << std::endl;
456 fout <<
"}" << std::endl;
470 Log() <<
"The H-Matrix classifier discriminates one class (signal) of a feature" <<
Endl;
471 Log() <<
"vector from another (background). The correlated elements of the" <<
Endl;
472 Log() <<
"vector are assumed to be Gaussian distributed, and the inverse of" <<
Endl;
473 Log() <<
"the covariance matrix is the H-Matrix. A multivariate chi-squared" <<
Endl;
474 Log() <<
"estimator is built that exploits differences in the mean values of" <<
Endl;
475 Log() <<
"the vector elements between the two classes for the purpose of" <<
Endl;
476 Log() <<
"discrimination." <<
Endl;
480 Log() <<
"The TMVA implementation of the H-Matrix classifier has been shown" <<
Endl;
481 Log() <<
"to underperform in comparison with the corresponding Fisher discriminant," <<
Endl;
482 Log() <<
"when using similar assumptions and complexity. Its use is therefore" <<
Endl;
483 Log() <<
"depreciated. Only in cases where the background model is strongly" <<
Endl;
484 Log() <<
"non-Gaussian, H-Matrix may perform better than Fisher. In such" <<
Endl;
485 Log() <<
"occurrences the user is advised to employ non-linear classifiers. " <<
Endl;
void ReadWeightsFromXML(void *wghtnode)
read weights from XML file
void Init()
default initialization called by all constructors
MsgLogger & Endl(MsgLogger &ml)
void AddWeightsXMLTo(void *parent) const
create XML description for HMatrix classification
Double_t GetMvaValue(Double_t *err=0, Double_t *errUpper=0)
returns the H-matrix signal estimator
void ProcessOptions()
process user options
virtual Bool_t HasAnalysisType(Types::EAnalysisType type, UInt_t numberClasses, UInt_t numberTargets)
H-Matrix can handle two-class classification (signal vs. background); check the implementation for other analysis types.
Double_t GetWeight() const
return the event weight - depending on whether the flag IgnoreNegWeightsInTraining is or not...
void ReadWeightsFromStream(std::istream &istr)
read variable names and min/max NOTE: the latter values are mandatory for the normalisation in the re...
Float_t GetValue(UInt_t ivar) const
return value of i'th variable
virtual ~MethodHMatrix()
destructor
void GetHelpMessage() const
get help message text
ClassImp(TMVA::MethodHMatrix) TMVA
standard constructor for the H-Matrix method
std::vector< std::vector< double > > Data
MethodHMatrix(const TString &jobName, const TString &methodTitle, DataSetInfo &theData, const TString &theOption="", TDirectory *theTargetDir=0)
TVectorT< Double_t > TVectorD
TMatrixT< Double_t > TMatrixD
void Train()
computes H-matrices for signal and background samples
Describe directory structure in memory.
void ComputeCovariance(Bool_t, TMatrixD *)
compute covariance matrix
void MakeClassSpecific(std::ostream &, const TString &) const
write H-Matrix-specific classifier response (standalone C++ class with inverse H-matrices and mean vectors)
static RooMathCoreReg dummy
#define REGISTER_METHOD(CLASS)
for example
Abstract ClassifierFactory template that handles arbitrary types.
void DeclareOptions()
MethodHMatrix options: none (apart from those implemented in MethodBase)
Double_t GetChi2(Types::ESBType)
compute chi2-estimator for event according to type (signal/background)