Loading [MathJax]/extensions/tex2jax.js
Logo ROOT  
Reference Guide
All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Properties Friends Macros Modules Pages
ROCCurve.cxx
Go to the documentation of this file.
1// @(#)root/tmva $Id$
2// Author: Omar Zapata, Lorenzo Moneta, Sergei Gleyzer, Simon Pfreundschuh and Kim Albertsson
3
4/**********************************************************************************
5 * Project: TMVA - a Root-integrated toolkit for multivariate data analysis *
6 * Package: TMVA *
7 * Class : ROCCurve *
8 * *
9 * Description: *
 * This class computes the ROC integral (AUC).                                  *
11 * *
12 * Authors : *
13 * Omar Zapata <Omar.Zapata@cern.ch> - UdeA/ITM Colombia *
14 * Lorenzo Moneta <Lorenzo.Moneta@cern.ch> - CERN, Switzerland *
15 * Sergei Gleyzer <Sergei.Gleyzer@cern.ch> - U of Florida & CERN *
16 * Kim Albertsson <kim.albertsson@cern.ch> - LTU & CERN *
17 * *
18 * Copyright (c) 2015: *
19 * CERN, Switzerland *
20 * UdeA/ITM, Colombia *
21 * U. of Florida, USA *
22 **********************************************************************************/
23
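/*! \class TMVA::ROCCurve
\ingroup TMVA

Computes ROC curves and the ROC integral (AUC) from per-event classifier
outputs, event weights and true class labels (signal or background).
*/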
#include "TMVA/Tools.h"
#include "TMVA/TSpline1.h"
#include "TMVA/ROCCurve.h"
#include "TMVA/Config.h"
#include "TMVA/Version.h"
#include "TMVA/MsgLogger.h"
#include "TGraph.h"
#include "TMath.h"

#include <algorithm>
#include <cassert>
#include <limits>
#include <tuple>
#include <vector>

using namespace std;
42
43auto tupleSort = [](std::tuple<Float_t, Float_t, Bool_t> _a, std::tuple<Float_t, Float_t, Bool_t> _b) {
44 return std::get<0>(_a) < std::get<0>(_b);
45};
46
47//_______________________________________________________________________
48TMVA::ROCCurve::ROCCurve(const std::vector<std::tuple<Float_t, Float_t, Bool_t>> &mvas)
49 : fLogger(new TMVA::MsgLogger("ROCCurve")), fGraph(NULL), fMva(mvas)
50{
51}
52
53////////////////////////////////////////////////////////////////////////////////
54///
55
56TMVA::ROCCurve::ROCCurve(const std::vector<Float_t> &mvaValues, const std::vector<Bool_t> &mvaTargets,
57 const std::vector<Float_t> &mvaWeights)
58 : fLogger(new TMVA::MsgLogger("ROCCurve")), fGraph(NULL)
59{
60 assert(mvaValues.size() == mvaTargets.size());
61 assert(mvaValues.size() == mvaWeights.size());
62
63 for (UInt_t i = 0; i < mvaValues.size(); i++) {
64 fMva.emplace_back(mvaValues[i], mvaWeights[i], mvaTargets[i]);
65 }
66
67 std::sort(fMva.begin(), fMva.end(), tupleSort);
68}
69
70////////////////////////////////////////////////////////////////////////////////
71///
72
73TMVA::ROCCurve::ROCCurve(const std::vector<Float_t> &mvaValues, const std::vector<Bool_t> &mvaTargets)
74 : fLogger(new TMVA::MsgLogger("ROCCurve")), fGraph(NULL)
75{
76 assert(mvaValues.size() == mvaTargets.size());
77
78 for (UInt_t i = 0; i < mvaValues.size(); i++) {
79 fMva.emplace_back(mvaValues[i], 1, mvaTargets[i]);
80 }
81
82 std::sort(fMva.begin(), fMva.end(), tupleSort);
83}
84
85////////////////////////////////////////////////////////////////////////////////
86///
87
88TMVA::ROCCurve::ROCCurve(const std::vector<Float_t> &mvaSignal, const std::vector<Float_t> &mvaBackground)
89 : fLogger(new TMVA::MsgLogger("ROCCurve")), fGraph(NULL)
90{
91 for (UInt_t i = 0; i < mvaSignal.size(); i++) {
92 fMva.emplace_back(mvaSignal[i], 1, kTRUE);
93 }
94
95 for (UInt_t i = 0; i < mvaBackground.size(); i++) {
96 fMva.emplace_back(mvaBackground[i], 1, kFALSE);
97 }
98
99 std::sort(fMva.begin(), fMva.end(), tupleSort);
100}
101
102////////////////////////////////////////////////////////////////////////////////
103///
104
105TMVA::ROCCurve::ROCCurve(const std::vector<Float_t> &mvaSignal, const std::vector<Float_t> &mvaBackground,
106 const std::vector<Float_t> &mvaSignalWeights, const std::vector<Float_t> &mvaBackgroundWeights)
107 : fLogger(new TMVA::MsgLogger("ROCCurve")), fGraph(NULL)
108{
109 assert(mvaSignal.size() == mvaSignalWeights.size());
110 assert(mvaBackground.size() == mvaBackgroundWeights.size());
111
112 for (UInt_t i = 0; i < mvaSignal.size(); i++) {
113 fMva.emplace_back(mvaSignal[i], mvaSignalWeights[i], kTRUE);
114 }
115
116 for (UInt_t i = 0; i < mvaBackground.size(); i++) {
117 fMva.emplace_back(mvaBackground[i], mvaBackgroundWeights[i], kFALSE);
118 }
119
120 std::sort(fMva.begin(), fMva.end(), tupleSort);
121}
122
123////////////////////////////////////////////////////////////////////////////////
124/// destructor
125
127 delete fLogger;
128 if(fGraph) delete fGraph;
129}
130
132{
133 if (!fLogger)
134 fLogger = new TMVA::MsgLogger("ROCCurve");
135 return *fLogger;
136}
137
138////////////////////////////////////////////////////////////////////////////////
139///
140
141std::vector<Double_t> TMVA::ROCCurve::ComputeSpecificity(const UInt_t num_points)
142{
143 if (num_points <= 2) {
144 return {0.0, 1.0};
145 }
146
147 std::vector<Double_t> specificity_vector;
148 std::vector<Double_t> true_negatives;
149 specificity_vector.reserve(fMva.size());
150 true_negatives.reserve(fMva.size());
151
152 Double_t true_negatives_sum = 0.0;
153 for (auto &ev : fMva) {
154 // auto value = std::get<0>(ev);
155 auto weight = std::get<1>(ev);
156 auto isSignal = std::get<2>(ev);
157
158 true_negatives_sum += weight * (!isSignal);
159 true_negatives.push_back(true_negatives_sum);
160 }
161
162 specificity_vector.push_back(0.0);
163 Double_t total_background = true_negatives_sum;
164 for (auto &tn : true_negatives) {
165 Double_t specificity =
166 (total_background <= std::numeric_limits<Double_t>::min()) ? (0.0) : (tn / total_background);
167 specificity_vector.push_back(specificity);
168 }
169 specificity_vector.push_back(1.0);
170
171 return specificity_vector;
172}
173
174////////////////////////////////////////////////////////////////////////////////
175///
176
177std::vector<Double_t> TMVA::ROCCurve::ComputeSensitivity(const UInt_t num_points)
178{
179 if (num_points <= 2) {
180 return {1.0, 0.0};
181 }
182
183 std::vector<Double_t> sensitivity_vector;
184 std::vector<Double_t> true_positives;
185 sensitivity_vector.reserve(fMva.size());
186 true_positives.reserve(fMva.size());
187
188 Double_t true_positives_sum = 0.0;
189 for (auto it = fMva.rbegin(); it != fMva.rend(); ++it) {
190 // auto value = std::get<0>(*it);
191 auto weight = std::get<1>(*it);
192 auto isSignal = std::get<2>(*it);
193
194 true_positives_sum += weight * (isSignal);
195 true_positives.push_back(true_positives_sum);
196 }
197 std::reverse(true_positives.begin(), true_positives.end());
198
199 sensitivity_vector.push_back(1.0);
200 Double_t total_signal = true_positives_sum;
201 for (auto &tp : true_positives) {
202 Double_t sensitivity = (total_signal <= std::numeric_limits<Double_t>::min()) ? (0.0) : (tp / total_signal);
203 sensitivity_vector.push_back(sensitivity);
204 }
205 sensitivity_vector.push_back(0.0);
206
207 return sensitivity_vector;
208}
209
210////////////////////////////////////////////////////////////////////////////////
211/// Calculate the signal efficiency (sensitivity) for a given background
212/// efficiency (sensitivity).
213///
214/// @param effB Background efficiency for which to calculate signal
215/// efficiency.
216/// @param num_points Number of points used for the underlying histogram.
217/// The number of bins will be num_points - 1.
218///
219
221{
222 assert(0.0 <= effB && effB <= 1.0);
223
224 auto effS_vec = ComputeSensitivity(num_points);
225 auto effB_vec = ComputeSpecificity(num_points);
226
227 // Specificity is actually rejB, so we need to transform it.
228 auto complement = [](Double_t x) { return 1 - x; };
229 std::transform(effB_vec.begin(), effB_vec.end(), effB_vec.begin(), complement);
230
231 // Since TSpline1 uses binary search (and assumes ascending sorting) we must ensure this.
232 std::reverse(effS_vec.begin(), effS_vec.end());
233 std::reverse(effB_vec.begin(), effB_vec.end());
234
235 TGraph *graph = new TGraph(effS_vec.size(), &effB_vec[0], &effS_vec[0]);
236
237 // TSpline1 does linear interpolation of ROC curve
238 TSpline1 rocSpline = TSpline1("", graph);
239 return rocSpline.Eval(effB);
240}
241
242////////////////////////////////////////////////////////////////////////////////
243/// Calculates the ROC integral (AUC)
244///
245/// @param num_points Granularity of the resulting curve used for integration.
246/// The curve will be subdivided into num_points - 1 regions
247/// where the performance of the classifier is sampled.
248/// Larger number means more accurate, but more costly,
249/// evaluation.
250
252{
253 auto sensitivity = ComputeSensitivity(num_points);
254 auto specificity = ComputeSpecificity(num_points);
255
256 Double_t integral = 0.0;
257 for (UInt_t i = 0; i < sensitivity.size() - 1; i++) {
258 // FNR, false negatigve rate = 1 - Sensitivity
259 Double_t currFnr = 1 - sensitivity[i];
260 Double_t nextFnr = 1 - sensitivity[i + 1];
261 // Trapezodial integration
262 integral += 0.5 * (nextFnr - currFnr) * (specificity[i] + specificity[i + 1]);
263 }
264
265 return integral;
266}
267
268////////////////////////////////////////////////////////////////////////////////
269/// Returns a new TGraph containing the ROC curve. Specificity is on the x-axis,
270/// sensitivity on the y-axis.
271///
272/// @param num_points Granularity of the resulting curve. The curve will be subdivided
273/// into num_points - 1 regions where the performance of the
274/// classifier is sampled. Larger number means more accurate,
275/// but more costly, evaluation.
276
278{
279 if (fGraph != nullptr) {
280 delete fGraph;
281 }
282
283 auto sensitivity = ComputeSensitivity(num_points);
284 auto specificity = ComputeSpecificity(num_points);
285
286 fGraph = new TGraph(sensitivity.size(), &sensitivity[0], &specificity[0]);
287
288 return fGraph;
289}
auto tupleSort
Definition: ROCCurve.cxx:43
const Bool_t kFALSE
Definition: RtypesCore.h:90
const Bool_t kTRUE
Definition: RtypesCore.h:89
A TGraph is an object made of two arrays X and Y with npoints each.
Definition: TGraph.h:41
ostringstream derivative to redirect and format output
Definition: MsgLogger.h:59
std::vector< Double_t > ComputeSpecificity(const UInt_t num_points)
Definition: ROCCurve.cxx:141
ROCCurve(const std::vector< std::tuple< Float_t, Float_t, Bool_t > > &mvas)
Definition: ROCCurve.cxx:48
~ROCCurve()
destructor
Definition: ROCCurve.cxx:126
Double_t GetEffSForEffB(Double_t effB, const UInt_t num_points=41)
Calculate the signal efficiency (sensitivity) for a given background efficiency (1 - specificity).
Definition: ROCCurve.cxx:220
std::vector< Double_t > ComputeSensitivity(const UInt_t num_points)
Definition: ROCCurve.cxx:177
Double_t GetROCIntegral(const UInt_t points=41)
Calculates the ROC integral (AUC)
Definition: ROCCurve.cxx:251
MsgLogger & Log() const
message logger
Definition: ROCCurve.cxx:131
std::vector< std::tuple< Float_t, Float_t, Bool_t > > fMva
Definition: ROCCurve.h:76
TGraph * GetROCCurve(const UInt_t points=100)
Returns a new TGraph containing the ROC curve.
Definition: ROCCurve.cxx:277
Linear interpolation of TGraph.
Definition: TSpline1.h:43
virtual Double_t Eval(Double_t x) const
returns linearly interpolated TGraph entry around x
Definition: TSpline1.cxx:61
Double_t x[n]
Definition: legend1.C:17
create variable transformations
void mvas(TString dataset, TString fin="TMVA.root", HistType htype=kMVAType, Bool_t useTMVAStyle=kTRUE)
Definition: graph.py:1