Logo ROOT  
Reference Guide
 
Loading...
Searching...
No Matches
ROCCurve.cxx
Go to the documentation of this file.
1// @(#)root/tmva $Id$
2// Author: Omar Zapata, Lorenzo Moneta, Sergei Gleyzer, Simon Pfreundschuh and Kim Albertsson
3
4/**********************************************************************************
5 * Project: TMVA - a Root-integrated toolkit for multivariate data analysis *
6 * Package: TMVA *
7 * Class : ROCCurve *
8 * *
9 * Description: *
10 * This class computes the ROC integral (area under the curve, AUC) *
11 * *
12 * Authors : *
13 * Omar Zapata <Omar.Zapata@cern.ch> - UdeA/ITM Colombia *
14 * Lorenzo Moneta <Lorenzo.Moneta@cern.ch> - CERN, Switzerland *
15 * Sergei Gleyzer <Sergei.Gleyzer@cern.ch> - U of Florida & CERN *
16 * Kim Albertsson <kim.albertsson@cern.ch> - LTU & CERN *
17 * *
18 * Copyright (c) 2015: *
19 * CERN, Switzerland *
20 * UdeA/ITM, Colombia *
21 * U. of Florida, USA *
22 **********************************************************************************/
23
24/*! \class TMVA::ROCCurve
25\ingroup TMVA
26
27*/
28#include "TMVA/Tools.h"
29#include "TMVA/TSpline1.h"
30#include "TMVA/ROCCurve.h"
31#include "TMVA/Config.h"
32#include "TMVA/Version.h"
33#include "TMVA/MsgLogger.h"
34#include "TGraph.h"
35
36#include <algorithm>
37#include <vector>
38#include <cassert>
39
40auto tupleSort = [](std::tuple<Float_t, Float_t, Bool_t> _a, std::tuple<Float_t, Float_t, Bool_t> _b) {
41 return std::get<0>(_a) < std::get<0>(_b);
42};
43
44//_______________________________________________________________________
45TMVA::ROCCurve::ROCCurve(const std::vector<std::tuple<Float_t, Float_t, Bool_t>> &mvas)
46 : fLogger(new TMVA::MsgLogger("ROCCurve")), fGraph(NULL), fMva(mvas)
47{
48}
49
50////////////////////////////////////////////////////////////////////////////////
51///
52
53TMVA::ROCCurve::ROCCurve(const std::vector<Float_t> &mvaValues, const std::vector<Bool_t> &mvaTargets,
54 const std::vector<Float_t> &mvaWeights)
55 : fLogger(new TMVA::MsgLogger("ROCCurve")), fGraph(NULL)
56{
57 assert(mvaValues.size() == mvaTargets.size());
58 assert(mvaValues.size() == mvaWeights.size());
59
60 for (UInt_t i = 0; i < mvaValues.size(); i++) {
61 fMva.emplace_back(mvaValues[i], mvaWeights[i], mvaTargets[i]);
62 }
63
64 std::sort(fMva.begin(), fMva.end(), tupleSort);
65}
66
67////////////////////////////////////////////////////////////////////////////////
68///
69
70TMVA::ROCCurve::ROCCurve(const std::vector<Float_t> &mvaValues, const std::vector<Bool_t> &mvaTargets)
71 : fLogger(new TMVA::MsgLogger("ROCCurve")), fGraph(NULL)
72{
73 assert(mvaValues.size() == mvaTargets.size());
74
75 for (UInt_t i = 0; i < mvaValues.size(); i++) {
76 fMva.emplace_back(mvaValues[i], 1, mvaTargets[i]);
77 }
78
79 std::sort(fMva.begin(), fMva.end(), tupleSort);
80}
81
82////////////////////////////////////////////////////////////////////////////////
83///
84
85TMVA::ROCCurve::ROCCurve(const std::vector<Float_t> &mvaSignal, const std::vector<Float_t> &mvaBackground)
86 : fLogger(new TMVA::MsgLogger("ROCCurve")), fGraph(NULL)
87{
88 for (UInt_t i = 0; i < mvaSignal.size(); i++) {
89 fMva.emplace_back(mvaSignal[i], 1, kTRUE);
90 }
91
92 for (UInt_t i = 0; i < mvaBackground.size(); i++) {
93 fMva.emplace_back(mvaBackground[i], 1, kFALSE);
94 }
95
96 std::sort(fMva.begin(), fMva.end(), tupleSort);
97}
98
99////////////////////////////////////////////////////////////////////////////////
100///
101
102TMVA::ROCCurve::ROCCurve(const std::vector<Float_t> &mvaSignal, const std::vector<Float_t> &mvaBackground,
103 const std::vector<Float_t> &mvaSignalWeights, const std::vector<Float_t> &mvaBackgroundWeights)
104 : fLogger(new TMVA::MsgLogger("ROCCurve")), fGraph(NULL)
105{
106 assert(mvaSignal.size() == mvaSignalWeights.size());
107 assert(mvaBackground.size() == mvaBackgroundWeights.size());
108
109 for (UInt_t i = 0; i < mvaSignal.size(); i++) {
110 fMva.emplace_back(mvaSignal[i], mvaSignalWeights[i], kTRUE);
111 }
112
113 for (UInt_t i = 0; i < mvaBackground.size(); i++) {
114 fMva.emplace_back(mvaBackground[i], mvaBackgroundWeights[i], kFALSE);
115 }
116
117 std::sort(fMva.begin(), fMva.end(), tupleSort);
118}
119
120////////////////////////////////////////////////////////////////////////////////
121/// destructor
122
// Release the logger (allocated by every constructor) and the graph last
// created by GetROCCurve(). fGraph is still null if GetROCCurve() was never
// called; the null check below is redundant, since deleting nullptr is a no-op.
124 delete fLogger;
125 if(fGraph) delete fGraph;
126}
127
129{
130 if (!fLogger)
131 fLogger = new TMVA::MsgLogger("ROCCurve");
132 return *fLogger;
133}
134
135////////////////////////////////////////////////////////////////////////////////
136///
137
139{
// Builds the specificity curve from the events in fMva. The visible code seeds
// the result with 0.0, fills it from the cumulative background weights
// collected in true_negatives, and closes it with 1.0.
// NOTE(review): source lines 149, 156, 160 and 162-164 are absent from this
// listing (e.g. true_negatives_sum is used but never declared here). Confirm
// against the original file before relying on the details below.
140 if (num_points <= 2) {
// Too few points requested: return only the two end points of the curve.
// In the visible code, num_points is used for nothing else.
141 return {0.0, 1.0};
142 }
143
144 std::vector<Double_t> specificity_vector;
145 std::vector<Double_t> true_negatives;
146 specificity_vector.reserve(fMva.size());
147 true_negatives.reserve(fMva.size());
148
// Walk events in stored order (ascending MVA value when fMva is sorted).
150 for (auto &ev : fMva) {
151 // auto value = std::get<0>(ev);
152 auto weight = std::get<1>(ev);
153 auto isSignal = std::get<2>(ev);
154
// Running sum of background (non-signal) event weights.
155 true_negatives_sum += weight * (!isSignal ? 1. : 0.);
157 }
158
159 specificity_vector.push_back(0.0);
161 for (auto &tn : true_negatives) {
165 }
166 specificity_vector.push_back(1.0);
167
168 return specificity_vector;
169}
170
171////////////////////////////////////////////////////////////////////////////////
172///
173
175{
// Builds the sensitivity curve from the events in fMva. The visible code seeds
// the result with 1.0, fills it from the cumulative signal weights collected in
// true_positives, and closes it with 0.0.
// NOTE(review): source lines 185, 192, 197 and 199-200 are absent from this
// listing (e.g. true_positives_sum is used but never declared here). Confirm
// against the original file before relying on the details below.
176 if (num_points <= 2) {
// Too few points requested: return only the two end points of the curve.
// In the visible code, num_points is used for nothing else.
177 return {1.0, 0.0};
178 }
179
180 std::vector<Double_t> sensitivity_vector;
181 std::vector<Double_t> true_positives;
182 sensitivity_vector.reserve(fMva.size());
183 true_positives.reserve(fMva.size());
184
// Walk events back to front, so the running sum counts signal weight from the
// high-MVA end of fMva downwards.
186 for (auto it = fMva.rbegin(); it != fMva.rend(); ++it) {
187 // auto value = std::get<0>(*it);
188 auto weight = std::get<1>(*it);
189 auto isSignal = std::get<2>(*it);
190
// Running sum of signal event weights.
191 true_positives_sum += weight * (isSignal);
193 }
// Restore forward (fMva) order so entries line up with ComputeSpecificity.
194 std::reverse(true_positives.begin(), true_positives.end());
195
196 sensitivity_vector.push_back(1.0);
198 for (auto &tp : true_positives) {
201 }
202 sensitivity_vector.push_back(0.0);
203
204 return sensitivity_vector;
205}
206
207////////////////////////////////////////////////////////////////////////////////
208/// Calculate the signal efficiency (sensitivity) for a given background
209/// efficiency (1 - specificity).
210///
211/// @param effB Background efficiency for which to calculate signal
212/// efficiency.
213/// @param num_points Number of points used for the underlying histogram.
214/// The number of bins will be num_points - 1.
215///
216
218{
219 assert(0.0 <= effB && effB <= 1.0);
220
221 auto effS_vec = ComputeSensitivity(num_points);
222 auto effB_vec = ComputeSpecificity(num_points);
223
224 // Specificity is actually rejB, so we need to transform it.
225 auto complement = [](Double_t x) { return 1 - x; };
226 std::transform(effB_vec.begin(), effB_vec.end(), effB_vec.begin(), complement);
227
228 // Since TSpline1 uses binary search (and assumes ascending sorting) we must ensure this.
229 std::reverse(effS_vec.begin(), effS_vec.end());
230 std::reverse(effB_vec.begin(), effB_vec.end());
231
232 TGraph *graph = new TGraph(effS_vec.size(), &effB_vec[0], &effS_vec[0]);
233
234 // TSpline1 does linear interpolation of ROC curve
235 TSpline1 rocSpline = TSpline1("", graph);
236 return rocSpline.Eval(effB);
237}
238
239////////////////////////////////////////////////////////////////////////////////
240/// Calculates the ROC integral (AUC)
241///
242/// @param num_points Granularity of the resulting curve used for integration.
243/// The curve will be subdivided into num_points - 1 regions
244/// where the performance of the classifier is sampled.
245/// Larger number means more accurate, but more costly,
246/// evaluation.
247
249{
// Area under the ROC curve: trapezoidal integration of specificity over the
// false-negative rate (1 - sensitivity). Assumes both vectors have the same
// length (both are derived from fMva) — TODO confirm.
250 auto sensitivity = ComputeSensitivity(num_points);
251 auto specificity = ComputeSpecificity(num_points);
252
253 Double_t integral = 0.0;
// sensitivity always has at least two entries, so size() - 1 cannot wrap.
254 for (UInt_t i = 0; i < sensitivity.size() - 1; i++) {
255 // FNR, false negative rate = 1 - Sensitivity
// NOTE(review): source line 256, which defines currFnr (presumably
// 1 - sensitivity[i]), is absent from this listing.
257 Double_t nextFnr = 1 - sensitivity[i + 1];
258 // Trapezoidal integration
259 integral += 0.5 * (nextFnr - currFnr) * (specificity[i] + specificity[i + 1]);
260 }
261
262 return integral;
263}
264
265////////////////////////////////////////////////////////////////////////////////
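266/// Returns a TGraph containing the ROC curve. Sensitivity is on the x-axis,
267/// specificity on the y-axis. The graph is owned by this ROCCurve: it is
/// deleted by the next call to GetROCCurve() and by the destructor.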
268///
269/// @param num_points Granularity of the resulting curve. The curve will be subdivided
270/// into num_points - 1 regions where the performance of the
271/// classifier is sampled. Larger number means more accurate,
272/// but more costly, evaluation.
273
275{
276 if (fGraph != nullptr) {
277 delete fGraph;
278 }
279
280 auto sensitivity = ComputeSensitivity(num_points);
281 auto specificity = ComputeSpecificity(num_points);
282
283 fGraph = new TGraph(sensitivity.size(), &sensitivity[0], &specificity[0]);
284
285 return fGraph;
286}
auto tupleSort
Definition ROCCurve.cxx:40
constexpr Bool_t kFALSE
Definition RtypesCore.h:94
constexpr Bool_t kTRUE
Definition RtypesCore.h:93
ROOT::Detail::TRangeCast< T, true > TRangeDynCast
TRangeDynCast is an adapter class that allows the typed iteration through a TCollection.
const_iterator begin() const
const_iterator end() const
A TGraph is an object made of two arrays X and Y with npoints each.
Definition TGraph.h:41
ostringstream derivative to redirect and format output
Definition MsgLogger.h:57
std::vector< Double_t > ComputeSpecificity(const UInt_t num_points)
Definition ROCCurve.cxx:138
ROCCurve(const std::vector< std::tuple< Float_t, Float_t, Bool_t > > &mvas)
Definition ROCCurve.cxx:45
~ROCCurve()
destructor
Definition ROCCurve.cxx:123
Double_t GetEffSForEffB(Double_t effB, const UInt_t num_points=41)
Calculate the signal efficiency (sensitivity) for a given background efficiency (1 - specificity).
Definition ROCCurve.cxx:217
std::vector< Double_t > ComputeSensitivity(const UInt_t num_points)
Definition ROCCurve.cxx:174
Double_t GetROCIntegral(const UInt_t points=41)
Calculates the ROC integral (AUC)
Definition ROCCurve.cxx:248
MsgLogger & Log() const
Definition ROCCurve.cxx:128
std::vector< std::tuple< Float_t, Float_t, Bool_t > > fMva
Definition ROCCurve.h:75
TGraph * GetROCCurve(const UInt_t points=100)
Returns a new TGraph containing the ROC curve.
Definition ROCCurve.cxx:274
Linear interpolation of TGraph.
Definition TSpline1.h:43
Double_t x[n]
Definition legend1.C:17
create variable transformations
void mvas(TString dataset, TString fin="TMVA.root", HistType htype=kMVAType, Bool_t useTMVAStyle=kTRUE)