Logo ROOT  
Reference Guide
ROCCurve.cxx
Go to the documentation of this file.
1 // @(#)root/tmva $Id$
2 // Author: Omar Zapata, Lorenzo Moneta, Sergei Gleyzer, Simon Pfreundschuh and Kim Albertsson
3 
4 /**********************************************************************************
5  * Project: TMVA - a Root-integrated toolkit for multivariate data analysis *
6  * Package: TMVA *
7  * Class : ROCCurve *
8  * *
9  * Description: *
10  * This is a class to compute the ROC integral (AUC) *
11  * *
12  * Authors : *
13  * Omar Zapata <Omar.Zapata@cern.ch> - UdeA/ITM Colombia *
14  * Lorenzo Moneta <Lorenzo.Moneta@cern.ch> - CERN, Switzerland *
15  * Sergei Gleyzer <Sergei.Gleyzer@cern.ch> - U of Florida & CERN *
16  * Kim Albertsson <kim.albertsson@cern.ch> - LTU & CERN *
17  * *
18  * Copyright (c) 2015: *
19  * CERN, Switzerland *
20  * UdeA/ITM, Colombia *
21  * U. of Florida, USA *
22  **********************************************************************************/
23 
24 /*! \class TMVA::ROCCurve
25 \ingroup TMVA
26 
27 */
28 #include "TMVA/Tools.h"
29 #include "TMVA/TSpline1.h"
30 #include "TMVA/ROCCurve.h"
31 #include "TMVA/Config.h"
32 #include "TMVA/Version.h"
33 #include "TMVA/MsgLogger.h"
34 #include "TGraph.h"
35 
36 #include <algorithm>
37 #include <vector>
38 #include <cassert>
39 
40 using namespace std;
41 
42 auto tupleSort = [](std::tuple<Float_t, Float_t, Bool_t> _a, std::tuple<Float_t, Float_t, Bool_t> _b) {
43  return std::get<0>(_a) < std::get<0>(_b);
44 };
45 
46 //_______________________________________________________________________
47 TMVA::ROCCurve::ROCCurve(const std::vector<std::tuple<Float_t, Float_t, Bool_t>> &mvas)
48  : fLogger(new TMVA::MsgLogger("ROCCurve")), fGraph(NULL), fMva(mvas)
49 {
50 }
51 
52 ////////////////////////////////////////////////////////////////////////////////
53 ///
54 
55 TMVA::ROCCurve::ROCCurve(const std::vector<Float_t> &mvaValues, const std::vector<Bool_t> &mvaTargets,
56  const std::vector<Float_t> &mvaWeights)
57  : fLogger(new TMVA::MsgLogger("ROCCurve")), fGraph(NULL)
58 {
59  assert(mvaValues.size() == mvaTargets.size());
60  assert(mvaValues.size() == mvaWeights.size());
61 
62  for (UInt_t i = 0; i < mvaValues.size(); i++) {
63  fMva.emplace_back(mvaValues[i], mvaWeights[i], mvaTargets[i]);
64  }
65 
66  std::sort(fMva.begin(), fMva.end(), tupleSort);
67 }
68 
69 ////////////////////////////////////////////////////////////////////////////////
70 ///
71 
72 TMVA::ROCCurve::ROCCurve(const std::vector<Float_t> &mvaValues, const std::vector<Bool_t> &mvaTargets)
73  : fLogger(new TMVA::MsgLogger("ROCCurve")), fGraph(NULL)
74 {
75  assert(mvaValues.size() == mvaTargets.size());
76 
77  for (UInt_t i = 0; i < mvaValues.size(); i++) {
78  fMva.emplace_back(mvaValues[i], 1, mvaTargets[i]);
79  }
80 
81  std::sort(fMva.begin(), fMva.end(), tupleSort);
82 }
83 
84 ////////////////////////////////////////////////////////////////////////////////
85 ///
86 
87 TMVA::ROCCurve::ROCCurve(const std::vector<Float_t> &mvaSignal, const std::vector<Float_t> &mvaBackground)
88  : fLogger(new TMVA::MsgLogger("ROCCurve")), fGraph(NULL)
89 {
90  for (UInt_t i = 0; i < mvaSignal.size(); i++) {
91  fMva.emplace_back(mvaSignal[i], 1, kTRUE);
92  }
93 
94  for (UInt_t i = 0; i < mvaBackground.size(); i++) {
95  fMva.emplace_back(mvaBackground[i], 1, kFALSE);
96  }
97 
98  std::sort(fMva.begin(), fMva.end(), tupleSort);
99 }
100 
101 ////////////////////////////////////////////////////////////////////////////////
102 ///
103 
104 TMVA::ROCCurve::ROCCurve(const std::vector<Float_t> &mvaSignal, const std::vector<Float_t> &mvaBackground,
105  const std::vector<Float_t> &mvaSignalWeights, const std::vector<Float_t> &mvaBackgroundWeights)
106  : fLogger(new TMVA::MsgLogger("ROCCurve")), fGraph(NULL)
107 {
108  assert(mvaSignal.size() == mvaSignalWeights.size());
109  assert(mvaBackground.size() == mvaBackgroundWeights.size());
110 
111  for (UInt_t i = 0; i < mvaSignal.size(); i++) {
112  fMva.emplace_back(mvaSignal[i], mvaSignalWeights[i], kTRUE);
113  }
114 
115  for (UInt_t i = 0; i < mvaBackground.size(); i++) {
116  fMva.emplace_back(mvaBackground[i], mvaBackgroundWeights[i], kFALSE);
117  }
118 
119  std::sort(fMva.begin(), fMva.end(), tupleSort);
120 }
121 
122 ////////////////////////////////////////////////////////////////////////////////
123 /// destructor
124 
126  delete fLogger;
127  if(fGraph) delete fGraph;
128 }
129 
131 {
132  if (!fLogger)
133  fLogger = new TMVA::MsgLogger("ROCCurve");
134  return *fLogger;
135 }
136 
137 ////////////////////////////////////////////////////////////////////////////////
138 ///
139 
140 std::vector<Double_t> TMVA::ROCCurve::ComputeSpecificity(const UInt_t num_points)
141 {
142  if (num_points <= 2) {
143  return {0.0, 1.0};
144  }
145 
146  std::vector<Double_t> specificity_vector;
147  std::vector<Double_t> true_negatives;
148  specificity_vector.reserve(fMva.size());
149  true_negatives.reserve(fMva.size());
150 
151  Double_t true_negatives_sum = 0.0;
152  for (auto &ev : fMva) {
153  // auto value = std::get<0>(ev);
154  auto weight = std::get<1>(ev);
155  auto isSignal = std::get<2>(ev);
156 
157  true_negatives_sum += weight * (!isSignal);
158  true_negatives.push_back(true_negatives_sum);
159  }
160 
161  specificity_vector.push_back(0.0);
162  Double_t total_background = true_negatives_sum;
163  for (auto &tn : true_negatives) {
164  Double_t specificity =
165  (total_background <= std::numeric_limits<Double_t>::min()) ? (0.0) : (tn / total_background);
166  specificity_vector.push_back(specificity);
167  }
168  specificity_vector.push_back(1.0);
169 
170  return specificity_vector;
171 }
172 
173 ////////////////////////////////////////////////////////////////////////////////
174 ///
175 
176 std::vector<Double_t> TMVA::ROCCurve::ComputeSensitivity(const UInt_t num_points)
177 {
178  if (num_points <= 2) {
179  return {1.0, 0.0};
180  }
181 
182  std::vector<Double_t> sensitivity_vector;
183  std::vector<Double_t> true_positives;
184  sensitivity_vector.reserve(fMva.size());
185  true_positives.reserve(fMva.size());
186 
187  Double_t true_positives_sum = 0.0;
188  for (auto it = fMva.rbegin(); it != fMva.rend(); ++it) {
189  // auto value = std::get<0>(*it);
190  auto weight = std::get<1>(*it);
191  auto isSignal = std::get<2>(*it);
192 
193  true_positives_sum += weight * (isSignal);
194  true_positives.push_back(true_positives_sum);
195  }
196  std::reverse(true_positives.begin(), true_positives.end());
197 
198  sensitivity_vector.push_back(1.0);
199  Double_t total_signal = true_positives_sum;
200  for (auto &tp : true_positives) {
201  Double_t sensitivity = (total_signal <= std::numeric_limits<Double_t>::min()) ? (0.0) : (tp / total_signal);
202  sensitivity_vector.push_back(sensitivity);
203  }
204  sensitivity_vector.push_back(0.0);
205 
206  return sensitivity_vector;
207 }
208 
209 ////////////////////////////////////////////////////////////////////////////////
210 /// Calculate the signal efficiency (sensitivity) for a given background
211 /// efficiency (sensitivity).
212 ///
213 /// @param effB Background efficiency for which to calculate signal
214 /// efficiency.
215 /// @param num_points Number of points used for the underlying histogram.
216 /// The number of bins will be num_points - 1.
217 ///
218 
220 {
221  assert(0.0 <= effB && effB <= 1.0);
222 
223  auto effS_vec = ComputeSensitivity(num_points);
224  auto effB_vec = ComputeSpecificity(num_points);
225 
226  // Specificity is actually rejB, so we need to transform it.
227  auto complement = [](Double_t x) { return 1 - x; };
228  std::transform(effB_vec.begin(), effB_vec.end(), effB_vec.begin(), complement);
229 
230  // Since TSpline1 uses binary search (and assumes ascending sorting) we must ensure this.
231  std::reverse(effS_vec.begin(), effS_vec.end());
232  std::reverse(effB_vec.begin(), effB_vec.end());
233 
234  TGraph *graph = new TGraph(effS_vec.size(), &effB_vec[0], &effS_vec[0]);
235 
236  // TSpline1 does linear interpolation of ROC curve
237  TSpline1 rocSpline = TSpline1("", graph);
238  return rocSpline.Eval(effB);
239 }
240 
241 ////////////////////////////////////////////////////////////////////////////////
242 /// Calculates the ROC integral (AUC)
243 ///
244 /// @param num_points Granularity of the resulting curve used for integration.
245 /// The curve will be subdivided into num_points - 1 regions
246 /// where the performance of the classifier is sampled.
247 /// Larger number means more accurate, but more costly,
248 /// evaluation.
249 
251 {
252  auto sensitivity = ComputeSensitivity(num_points);
253  auto specificity = ComputeSpecificity(num_points);
254 
255  Double_t integral = 0.0;
256  for (UInt_t i = 0; i < sensitivity.size() - 1; i++) {
257  // FNR, false negatigve rate = 1 - Sensitivity
258  Double_t currFnr = 1 - sensitivity[i];
259  Double_t nextFnr = 1 - sensitivity[i + 1];
260  // Trapezodial integration
261  integral += 0.5 * (nextFnr - currFnr) * (specificity[i] + specificity[i + 1]);
262  }
263 
264  return integral;
265 }
266 
267 ////////////////////////////////////////////////////////////////////////////////
268 /// Returns a new TGraph containing the ROC curve. Sensitivity is on the x-axis,
269 /// specificity on the y-axis.
270 ///
271 /// @param num_points Granularity of the resulting curve. The curve will be subdivided
272 /// into num_points - 1 regions where the performance of the
273 /// classifier is sampled. Larger number means more accurate,
274 /// but more costly, evaluation.
275 
277 {
278  if (fGraph != nullptr) {
279  delete fGraph;
280  }
281 
282  auto sensitivity = ComputeSensitivity(num_points);
283  auto specificity = ComputeSpecificity(num_points);
284 
285  fGraph = new TGraph(sensitivity.size(), &sensitivity[0], &specificity[0]);
286 
287  return fGraph;
288 }
kTRUE
const Bool_t kTRUE
Definition: RtypesCore.h:91
TMVA::TSpline1::Eval
virtual Double_t Eval(Double_t x) const
returns linearly interpolated TGraph entry around x
Definition: TSpline1.cxx:61
TGraph.h
TMVA::ROCCurve::GetROCIntegral
Double_t GetROCIntegral(const UInt_t points=41)
Calculates the ROC integral (AUC)
Definition: ROCCurve.cxx:250
x
Double_t x[n]
Definition: legend1.C:17
TMVA::ROCCurve::~ROCCurve
~ROCCurve()
destructor
Definition: ROCCurve.cxx:125
TMVA::ROCCurve::ComputeSpecificity
std::vector< Double_t > ComputeSpecificity(const UInt_t num_points)
Definition: ROCCurve.cxx:140
Version.h
MsgLogger.h
TMVA::ROCCurve::ROCCurve
ROCCurve(const std::vector< std::tuple< Float_t, Float_t, Bool_t >> &mvas)
Definition: ROCCurve.cxx:47
kFALSE
const Bool_t kFALSE
Definition: RtypesCore.h:92
TMVA::TSpline1
Linear interpolation of TGraph.
Definition: TSpline1.h:43
TMVA::ROCCurve::ComputeSensitivity
std::vector< Double_t > ComputeSensitivity(const UInt_t num_points)
Definition: ROCCurve.cxx:176
ROCCurve.h
TSpline1.h
Config.h
unsigned int
tupleSort
auto tupleSort
Definition: ROCCurve.cxx:42
TMVA::ROCCurve::Log
MsgLogger & Log() const
message logger
Definition: ROCCurve.cxx:130
Double_t
double Double_t
Definition: RtypesCore.h:59
TGraph
A TGraph is an object made of two arrays X and Y with npoints each.
Definition: TGraph.h:41
TMVA::MsgLogger
ostringstream derivative to redirect and format output
Definition: MsgLogger.h:59
TMVA::ROCCurve::GetROCCurve
TGraph * GetROCCurve(const UInt_t points=100)
Returns a new TGraph containing the ROC curve.
Definition: ROCCurve.cxx:276
graph
Definition: graph.py:1
TMVA::mvas
void mvas(TString dataset, TString fin="TMVA.root", HistType htype=kMVAType, Bool_t useTMVAStyle=kTRUE)
Tools.h
TMVA::ROCCurve::fMva
std::vector< std::tuple< Float_t, Float_t, Bool_t > > fMva
Definition: ROCCurve.h:74
TMVA
create variable transformations
Definition: GeneticMinimizer.h:22
TMVA::ROCCurve::GetEffSForEffB
Double_t GetEffSForEffB(Double_t effB, const UInt_t num_points=41)
Calculate the signal efficiency (sensitivity) for a given background efficiency (1 - specificity).
Definition: ROCCurve.cxx:219