Logo ROOT  
Reference Guide
ResultsMulticlass.cxx
Go to the documentation of this file.
1 // @(#)root/tmva $Id$
2 // Author: Andreas Hoecker, Peter Speckmayer, Joerg Stelzer, Helge Voss, Jan Therhaag
3 
4 /**********************************************************************************
5  * Project: TMVA - a Root-integrated toolkit for multivariate data analysis *
6  * Package: TMVA *
7  * Class : ResultsMulticlass *
8  * Web : http://tmva.sourceforge.net *
9  * *
10  * Description: *
11  * Implementation (see header for description) *
12  * *
13  * Authors (alphabetical): *
14  * Andreas Hoecker <Andreas.Hocker@cern.ch> - CERN, Switzerland *
15  * Peter Speckmayer <Peter.Speckmayer@cern.ch> - CERN, Switzerland *
16  * Joerg Stelzer <Joerg.Stelzer@cern.ch> - CERN, Switzerland *
17  * Jan Therhaag <Jan.Therhaag@cern.ch> - U of Bonn, Germany *
18  * Helge Voss <Helge.Voss@cern.ch> - MPI-K Heidelberg, Germany *
19  * *
20  * Copyright (c) 2006: *
21  * CERN, Switzerland *
22  * MPI-K Heidelberg, Germany *
23  * U. of Bonn, Germany *
24  * *
25  * Redistribution and use in source and binary forms, with or without *
26  * modification, are permitted according to the terms listed in LICENSE *
27  * (http://tmva.sourceforge.net/LICENSE) *
28  **********************************************************************************/
29 
30 /*! \class TMVA::ResultsMulticlass
31 \ingroup TMVA
32 Class which takes the results of a multiclass classification
33 */
34 
35 #include "TMVA/ResultsMulticlass.h"
36 
37 #include "TMVA/DataSet.h"
38 #include "TMVA/DataSetInfo.h"
39 #include "TMVA/GeneticAlgorithm.h"
40 #include "TMVA/GeneticFitter.h"
41 #include "TMVA/MsgLogger.h"
42 #include "TMVA/Results.h"
43 #include "TMVA/ROCCurve.h"
44 #include "TMVA/Tools.h"
45 #include "TMVA/Types.h"
46 
47 #include "TGraph.h"
48 #include "TH1F.h"
49 #include "TMatrixD.h"
50 
51 #include <limits>
52 #include <vector>
53 
54 
55 ////////////////////////////////////////////////////////////////////////////////
56 /// constructor
57 
59  : Results( dsi, resultsName ),
60  IFitterTarget(),
61  fLogger( new MsgLogger(Form("ResultsMultiClass%s",resultsName.Data()) , kINFO) ),
62  fClassToOptimize(0),
63  fAchievableEff(dsi->GetNClasses()),
64  fAchievablePur(dsi->GetNClasses()),
65  fBestCuts(dsi->GetNClasses(),std::vector<Double_t>(dsi->GetNClasses()))
66 {
67 }
68 
69 ////////////////////////////////////////////////////////////////////////////////
70 /// destructor
71 
73 {
74  delete fLogger;
75 }
76 
77 ////////////////////////////////////////////////////////////////////////////////
78 
79 void TMVA::ResultsMulticlass::SetValue( std::vector<Float_t>& value, Int_t ievt )
80 {
81  if (ievt >= (Int_t)fMultiClassValues.size()) fMultiClassValues.resize( ievt+1 );
82  fMultiClassValues[ievt] = value;
83 }
84 
85 ////////////////////////////////////////////////////////////////////////////////
86 /// Returns a confusion matrix where each class is pitted against each other.
87 /// Results are
88 
90 {
91  const DataSet *ds = GetDataSet();
92  const DataSetInfo *dsi = GetDataSetInfo();
93  ds->SetCurrentType(GetTreeType());
94 
95  UInt_t numClasses = dsi->GetNClasses();
96  TMatrixD mat(numClasses, numClasses);
97 
98  // class == iRow is considered signal class
99  for (UInt_t iRow = 0; iRow < numClasses; ++iRow) {
100  for (UInt_t iCol = 0; iCol < numClasses; ++iCol) {
101 
102  // Number is meaningless with only one class
103  if (iRow == iCol) {
104  mat(iRow, iCol) = std::numeric_limits<double>::quiet_NaN();
105  }
106 
107  std::vector<Float_t> valueVector;
108  std::vector<Bool_t> classVector;
109  std::vector<Float_t> weightVector;
110 
111  for (UInt_t iEvt = 0; iEvt < ds->GetNEvents(); ++iEvt) {
112  const Event *ev = ds->GetEvent(iEvt);
113  const UInt_t cls = ev->GetClass();
114  const Float_t weight = ev->GetWeight();
115  const Float_t mvaValue = fMultiClassValues[iEvt][iRow];
116 
117  if (cls != iRow && cls != iCol) {
118  continue;
119  }
120 
121  classVector.push_back(cls == iRow);
122  weightVector.push_back(weight);
123  valueVector.push_back(mvaValue);
124  }
125 
126  ROCCurve roc(valueVector, classVector, weightVector);
127  mat(iRow, iCol) = roc.GetEffSForEffB(effB);
128  }
129  }
130 
131  return mat;
132 }
133 
134 ////////////////////////////////////////////////////////////////////////////////
135 
136 Double_t TMVA::ResultsMulticlass::EstimatorFunction( std::vector<Double_t> & cutvalues ){
137 
138  DataSet* ds = GetDataSet();
139  ds->SetCurrentType( GetTreeType() );
140 
141  // Cache optimisation, count true and false positives with memory access
142  // instead of code branch.
143  Float_t positives[2] = {0, 0};
144 
145  for (Int_t ievt = 0; ievt < ds->GetNEvents(); ievt++) {
146  UInt_t evClass = fEventClasses[ievt];
147  Float_t w = fEventWeights[ievt];
148 
149  Bool_t break_outer_loop = false;
150  for (UInt_t icls = 0; icls < cutvalues.size(); ++icls) {
151  auto value = fMultiClassValues[ievt][icls];
152  auto cutvalue = cutvalues.at(icls);
153  if (cutvalue < 0. ? (-value < cutvalue) : (+value <= cutvalue)) {
154  break_outer_loop = true;
155  break;
156  }
157  }
158 
159  if (break_outer_loop) {
160  continue;
161  }
162 
163  Bool_t isEvCurrClass = (evClass == fClassToOptimize);
164  positives[isEvCurrClass] += w;
165  }
166 
167  const Float_t truePositive = positives[1];
168  const Float_t falsePositive = positives[0];
169 
170  Float_t eff = truePositive / fClassSumWeights[fClassToOptimize];
171  Float_t pur = truePositive / (truePositive + falsePositive);
172  Float_t effTimesPur = eff*pur;
173 
174  Float_t toMinimize = std::numeric_limits<float>::max();
175  if (effTimesPur > std::numeric_limits<float>::min())
176  toMinimize = 1./(effTimesPur); // we want to minimize 1/efficiency*purity
177 
178  fAchievableEff.at(fClassToOptimize) = eff;
179  fAchievablePur.at(fClassToOptimize) = pur;
180 
181  return toMinimize;
182 }
183 
184 ////////////////////////////////////////////////////////////////////////////////
185 ///calculate the best working point (optimal cut values)
186 ///for the multiclass classifier
187 
188 std::vector<Double_t> TMVA::ResultsMulticlass::GetBestMultiClassCuts(UInt_t targetClass){
189 
190  const DataSetInfo* dsi = GetDataSetInfo();
191  Log() << kINFO << "Calculating best set of cuts for class "
192  << dsi->GetClassInfo( targetClass )->GetName() << Endl;
193 
194  fClassToOptimize = targetClass;
195  std::vector<Interval*> ranges(dsi->GetNClasses(), new Interval(-1,1));
196 
197  fClassSumWeights.clear();
198  fEventWeights.clear();
199  fEventClasses.clear();
200 
201  for (UInt_t icls = 0; icls < dsi->GetNClasses(); ++icls) {
202  fClassSumWeights.push_back(0);
203  }
204 
205  DataSet *ds = GetDataSet();
206  for (Int_t ievt = 0; ievt < ds->GetNEvents(); ievt++) {
207  const Event *ev = ds->GetEvent(ievt);
208  fClassSumWeights[ev->GetClass()] += ev->GetWeight();
209  fEventWeights.push_back(ev->GetWeight());
210  fEventClasses.push_back(ev->GetClass());
211  }
212 
213  const TString name( "MulticlassGA" );
214  const TString opts( "PopSize=100:Steps=30" );
215  GeneticFitter mg( *this, name, ranges, opts);
216 
217  std::vector<Double_t> result;
218  mg.Run(result);
219 
220  fBestCuts.at(targetClass) = result;
221 
222  UInt_t n = 0;
223  for( std::vector<Double_t>::iterator it = result.begin(); it<result.end(); ++it ){
224  Log() << kINFO << " cutValue[" <<dsi->GetClassInfo( n )->GetName() << "] = " << (*it) << ";"<< Endl;
225  n++;
226  }
227 
228  return result;
229 }
230 
231 ////////////////////////////////////////////////////////////////////////////////
232 /// Create performance graphs for this classifier a multiclass setting.
233 /// Requires that the method has already been evaluated (that a resultset
234 /// already exists.)
235 ///
236 /// Currently uses the new way of calculating ROC Curves. If anything looks
237 /// fishy, please contact the ROOT TMVA team.
238 ///
239 
241 {
242 
243  Log() << kINFO << "Creating multiclass performance histograms..." << Endl;
244 
245  DataSet *ds = GetDataSet();
246  ds->SetCurrentType(GetTreeType());
247  const DataSetInfo *dsi = GetDataSetInfo();
248 
249  UInt_t numClasses = dsi->GetNClasses();
250 
251  std::vector<std::vector<Float_t>> *rawMvaRes = GetValueVector();
252 
253  //
254  // 1-vs-rest ROC curves
255  //
256  for (size_t iClass = 0; iClass < numClasses; ++iClass) {
257 
258  TString className = dsi->GetClassInfo(iClass)->GetName();
259  TString name = Form("%s_rejBvsS_%s", prefix.Data(), className.Data());
260  TString title = Form("%s_%s", prefix.Data(), className.Data());
261 
262  // Histograms are already generated, skip.
263  if ( DoesExist(name) ) {
264  return;
265  }
266 
267  // Format data
268  std::vector<Float_t> mvaRes;
269  std::vector<Bool_t> mvaResTypes;
270  std::vector<Float_t> mvaResWeights;
271 
272  // Vector transpose due to values being stored as
273  // [ [0, 1, 2], [0, 1, 2], ... ]
274  // in ResultsMulticlass::GetValueVector.
275  mvaRes.reserve(rawMvaRes->size());
276  for (auto item : *rawMvaRes) {
277  mvaRes.push_back(item[iClass]);
278  }
279 
280  auto eventCollection = ds->GetEventCollection();
281  mvaResTypes.reserve(eventCollection.size());
282  mvaResWeights.reserve(eventCollection.size());
283  for (auto ev : eventCollection) {
284  mvaResTypes.push_back(ev->GetClass() == iClass);
285  mvaResWeights.push_back(ev->GetWeight());
286  }
287 
288  // Get ROC Curve
289  ROCCurve *roc = new ROCCurve(mvaRes, mvaResTypes, mvaResWeights);
290  TGraph *rocGraph = new TGraph(*(roc->GetROCCurve()));
291  delete roc;
292 
293  // Style ROC Curve
294  rocGraph->SetName(name);
295  rocGraph->SetTitle(title);
296 
297  // Store ROC Curve
298  Store(rocGraph);
299  }
300 
301  //
302  // 1-vs-1 ROC curves
303  //
304  for (size_t iClass = 0; iClass < numClasses; ++iClass) {
305  for (size_t jClass = 0; jClass < numClasses; ++jClass) {
306  if (iClass == jClass) {
307  continue;
308  }
309 
310  auto eventCollection = ds->GetEventCollection();
311 
312  // Format data
313  std::vector<Float_t> mvaRes;
314  std::vector<Bool_t> mvaResTypes;
315  std::vector<Float_t> mvaResWeights;
316 
317  mvaRes.reserve(rawMvaRes->size());
318  mvaResTypes.reserve(eventCollection.size());
319  mvaResWeights.reserve(eventCollection.size());
320 
321  for (size_t iEvent = 0; iEvent < eventCollection.size(); ++iEvent) {
322  Event *ev = eventCollection[iEvent];
323 
324  if (ev->GetClass() == iClass || ev->GetClass() == jClass) {
325  Float_t output_value = (*rawMvaRes)[iEvent][iClass];
326  mvaRes.push_back(output_value);
327  mvaResTypes.push_back(ev->GetClass() == iClass);
328  mvaResWeights.push_back(ev->GetWeight());
329  }
330  }
331 
332  // Get ROC Curve
333  ROCCurve *roc = new ROCCurve(mvaRes, mvaResTypes, mvaResWeights);
334  TGraph *rocGraph = new TGraph(*(roc->GetROCCurve()));
335  delete roc;
336 
337  // Style ROC Curve
338  TString iClassName = dsi->GetClassInfo(iClass)->GetName();
339  TString jClassName = dsi->GetClassInfo(jClass)->GetName();
340  TString name = Form("%s_1v1rejBvsS_%s_vs_%s", prefix.Data(), iClassName.Data(), jClassName.Data());
341  TString title = Form("%s_%s_vs_%s", prefix.Data(), iClassName.Data(), jClassName.Data());
342  rocGraph->SetName(name);
343  rocGraph->SetTitle(title);
344 
345  // Store ROC Curve
346  Store(rocGraph);
347  }
348  }
349 }
350 
351 ////////////////////////////////////////////////////////////////////////////////
352 /// this function fills the mva response histos for multiclass classification
353 
355 {
356  Log() << kINFO << "Creating multiclass response histograms..." << Endl;
357 
358  DataSet* ds = GetDataSet();
359  ds->SetCurrentType( GetTreeType() );
360  const DataSetInfo* dsi = GetDataSetInfo();
361 
362  std::vector<std::vector<TH1F*> > histos;
363  Float_t xmin = 0.-0.0002;
364  Float_t xmax = 1.+0.0002;
365  for (UInt_t iCls = 0; iCls < dsi->GetNClasses(); iCls++) {
366  histos.push_back(std::vector<TH1F*>(0));
367  for (UInt_t jCls = 0; jCls < dsi->GetNClasses(); jCls++) {
368  TString name(Form("%s_%s_prob_for_%s",prefix.Data(),
369  dsi->GetClassInfo( jCls )->GetName(),
370  dsi->GetClassInfo( iCls )->GetName()));
371 
372  // Histograms are already generated, skip.
373  if ( DoesExist(name) ) {
374  return;
375  }
376 
377  histos.at(iCls).push_back(new TH1F(name,name,nbins,xmin,xmax));
378  }
379  }
380 
381  for (Int_t ievt=0; ievt<ds->GetNEvents(); ievt++) {
382  const Event* ev = ds->GetEvent(ievt);
383  Int_t cls = ev->GetClass();
384  Float_t w = ev->GetWeight();
385  for (UInt_t jCls = 0; jCls < dsi->GetNClasses(); jCls++) {
386  histos.at(cls).at(jCls)->Fill(fMultiClassValues[ievt][jCls],w);
387  }
388  }
389  for (UInt_t iCls = 0; iCls < dsi->GetNClasses(); iCls++) {
390  for (UInt_t jCls = 0; jCls < dsi->GetNClasses(); jCls++) {
391  gTools().NormHist( histos.at(iCls).at(jCls) );
392  Store(histos.at(iCls).at(jCls));
393  }
394  }
395 
396  /*
397  //fill fine binned histos for testing
398  if(prefix.Contains("Test")){
399  std::vector<std::vector<TH1F*> > histos_highbin;
400  for (UInt_t iCls = 0; iCls < dsi->GetNClasses(); iCls++) {
401  histos_highbin.push_back(std::vector<TH1F*>(0));
402  for (UInt_t jCls = 0; jCls < dsi->GetNClasses(); jCls++) {
403  TString name(Form("%s_%s_prob_for_%s_HIGHBIN",prefix.Data(),
404  dsi->GetClassInfo( jCls )->GetName().Data(),
405  dsi->GetClassInfo( iCls )->GetName().Data()));
406  histos_highbin.at(iCls).push_back(new TH1F(name,name,nbins_high,xmin,xmax));
407  }
408  }
409 
410  for (Int_t ievt=0; ievt<ds->GetNEvents(); ievt++) {
411  const Event* ev = ds->GetEvent(ievt);
412  Int_t cls = ev->GetClass();
413  Float_t w = ev->GetWeight();
414  for (UInt_t jCls = 0; jCls < dsi->GetNClasses(); jCls++) {
415  histos_highbin.at(cls).at(jCls)->Fill(fMultiClassValues[ievt][jCls],w);
416  }
417  }
418  for (UInt_t iCls = 0; iCls < dsi->GetNClasses(); iCls++) {
419  for (UInt_t jCls = 0; jCls < dsi->GetNClasses(); jCls++) {
420  gTools().NormHist( histos_highbin.at(iCls).at(jCls) );
421  Store(histos_highbin.at(iCls).at(jCls));
422  }
423  }
424  }
425  */
426 }
TGeant4Unit::mg
static constexpr double mg
Definition: TGeant4SystemOfUnits.h:210
n
const Int_t n
Definition: legend1.C:16
TMVA::GeneticFitter
Definition: GeneticFitter.h:63
TH1F.h
TMVA::Tools::NormHist
Double_t NormHist(TH1 *theHist, Double_t norm=1.0)
normalises histogram
Definition: Tools.cxx:395
TGraph::SetTitle
virtual void SetTitle(const char *title="")
Change (i.e. set) the title.
Definition: TGraph.cxx:2324
TMVA::ResultsMulticlass::CreateMulticlassHistos
void CreateMulticlassHistos(TString prefix, Int_t nbins, Int_t nbins_high)
this function fills the mva response histos for multiclass classification
Definition: ResultsMulticlass.cxx:354
TMVA::ResultsMulticlass::EstimatorFunction
Double_t EstimatorFunction(std::vector< Double_t > &)
Definition: ResultsMulticlass.cxx:136
TString::Data
const char * Data() const
Definition: TString.h:369
DataSetInfo.h
Form
char * Form(const char *fmt,...)
TMVA::Event::GetClass
UInt_t GetClass() const
Definition: Event.h:86
TMVA::DataSet::SetCurrentType
void SetCurrentType(Types::ETreeType type) const
Definition: DataSet.h:112
TGraph.h
xmax
float xmax
Definition: THbookFile.cxx:95
TMath::Log
Double_t Log(Double_t x)
Definition: TMath.h:760
TMVA::ResultsMulticlass::CreateMulticlassPerformanceHistos
void CreateMulticlassPerformanceHistos(TString prefix)
Create performance graphs for this classifier in a multiclass setting.
Definition: ResultsMulticlass.cxx:240
Float_t
float Float_t
Definition: RtypesCore.h:57
TString
Definition: TString.h:136
TMatrixT
Definition: TMatrixDfwd.h:22
TMVA::ResultsMulticlass::ResultsMulticlass
ResultsMulticlass(const DataSetInfo *dsi, TString resultsName)
constructor
Definition: ResultsMulticlass.cxx:58
bool
TMVA::ResultsMulticlass::GetBestMultiClassCuts
std::vector< Double_t > GetBestMultiClassCuts(UInt_t targetClass)
calculate the best working point (optimal cut values) for the multiclass classifier
Definition: ResultsMulticlass.cxx:188
TMVA::DataSetInfo::GetNClasses
UInt_t GetNClasses() const
Definition: DataSetInfo.h:155
TMVA::DataSetInfo
Definition: DataSetInfo.h:62
MsgLogger.h
TMVA::DataSet::GetEventCollection
const std::vector< Event * > & GetEventCollection(Types::ETreeType type=Types::kMaxTreeType) const
Definition: DataSet.h:216
TMVA::ResultsMulticlass::SetValue
void SetValue(std::vector< Float_t > &value, Int_t ievt)
Definition: ResultsMulticlass.cxx:79
xmin
float xmin
Definition: THbookFile.cxx:95
TGraph::SetName
virtual void SetName(const char *name="")
Set graph name.
Definition: TGraph.cxx:2308
TMVA::DataSet::GetEvent
const Event * GetEvent() const
Definition: DataSet.cxx:202
TMVA::DataSet::GetNEvents
Long64_t GetNEvents(Types::ETreeType type=Types::kMaxTreeType) const
Definition: DataSet.h:206
TMVA::DataSetInfo::GetClassInfo
ClassInfo * GetClassInfo(Int_t clNum) const
Definition: DataSetInfo.cxx:146
GeneticFitter.h
TMVA::DataSet
Definition: DataSet.h:81
ROCCurve.h
Types.h
TMVA::Results
Definition: Results.h:57
TMVA::Endl
MsgLogger & Endl(MsgLogger &ml)
Definition: MsgLogger.h:182
unsigned int
Double_t
double Double_t
Definition: RtypesCore.h:59
TGraph
Definition: TGraph.h:41
TMVA::MsgLogger
Definition: MsgLogger.h:83
TMVA::ROCCurve::GetROCCurve
TGraph * GetROCCurve(const UInt_t points=100)
Returns a new TGraph containing the ROC curve.
Definition: ROCCurve.cxx:276
TMVA::Event::GetWeight
Double_t GetWeight() const
Return the event weight, depending on whether the flag IgnoreNegWeightsInTraining is set or not.
Definition: Event.cxx:381
TMVA::ResultsMulticlass::~ResultsMulticlass
~ResultsMulticlass()
destructor
Definition: ResultsMulticlass.cxx:72
TH1F
1-D histogram with a float per channel (see TH1 documentation)
Definition: TH1.h:572
TMVA::Event
Definition: Event.h:51
name
char name[80]
Definition: TGX11.cxx:110
GeneticAlgorithm.h
TMVA::IFitterTarget
Definition: IFitterTarget.h:64
TMVA::ResultsMulticlass::GetConfusionMatrix
TMatrixD GetConfusionMatrix(Double_t effB)
Returns a confusion matrix where each class is pitted against each other.
Definition: ResultsMulticlass.cxx:89
ResultsMulticlass.h
Tools.h
TNamed::GetName
virtual const char * GetName() const
Returns name of object.
Definition: TNamed.h:53
TMVA::ROCCurve
Definition: ROCCurve.h:45
TMatrixD.h
Results.h
TMVA::gTools
Tools & gTools()
DataSet.h
int
TMVA::ROCCurve::GetEffSForEffB
Double_t GetEffSForEffB(Double_t effB, const UInt_t num_points=41)
Calculate the signal efficiency (sensitivity) for a given background efficiency (sensitivity).
Definition: ROCCurve.cxx:219
TMVA::Interval
Definition: Interval.h:61