Logo ROOT  
Reference Guide
 
Loading...
Searching...
No Matches
LossFunctions.hxx
Go to the documentation of this file.
1// @(#)root/tmva/tmva/dnn:$Id$
2// Author: Simon Pfreundschuh 10/07/16
3
4/*************************************************************************
5 * Copyright (C) 2016, Simon Pfreundschuh *
6 * All rights reserved. *
7 * *
8 * For the licensing terms see $ROOTSYS/LICENSE. *
9 * For the list of contributors see $ROOTSYS/README/CREDITS. *
10 *************************************************************************/
11
12 ////////////////////////////////////////////////////////////
13 // Implementation of the loss functions for the reference //
14 // implementation. //
15 ////////////////////////////////////////////////////////////
16
18
19namespace TMVA
20{
21namespace DNN
22{
23//______________________________________________________________________________
24template <typename AReal>
26 const TMatrixT<AReal> &weights)
27{
28 size_t m,n;
29 m = Y.GetNrows();
30 n = Y.GetNcols();
31 AReal result = 0.0;
32
33 for (size_t i = 0; i < m; i++) {
34 for (size_t j = 0; j < n; j++) {
35 AReal dY = (Y(i,j) - output(i,j));
36 result += weights(i, 0) * dY * dY;
37 }
38 }
39 result /= static_cast<AReal>(m * n);
40 return result;
41}
42
43//______________________________________________________________________________
44template <typename AReal>
46 const TMatrixT<AReal> &output, const TMatrixT<AReal> &weights)
47{
48 size_t m,n;
49 m = Y.GetNrows();
50 n = Y.GetNcols();
51
52 dY.Minus(Y, output);
53 dY *= -2.0 / static_cast<AReal>(m * n);
54
55 for (size_t i = 0; i < m; i++) {
56 for (size_t j = 0; j < n; j++) {
57 dY(i, j) *= weights(i, 0);
58 }
59 }
60}
61
62//______________________________________________________________________________
63template <typename AReal>
65 const TMatrixT<AReal> &weights)
66{
67 size_t m,n;
68 m = Y.GetNrows();
69 n = Y.GetNcols();
70 AReal result = 0.0;
71
72 for (size_t i = 0; i < m; i++) {
73 AReal w = weights(i, 0);
74 for (size_t j = 0; j < n; j++) {
75 AReal sig = 1.0 / (1.0 + std::exp(-output(i,j)));
76 result += w * (Y(i, j) * std::log(sig) + (1.0 - Y(i, j)) * std::log(1.0 - sig));
77 }
78 }
79 result /= -static_cast<AReal>(m * n);
80 return result;
81}
82
83//______________________________________________________________________________
84template <typename AReal>
86 const TMatrixT<AReal> &output, const TMatrixT<AReal> &weights)
87{
88 size_t m,n;
89 m = Y.GetNrows();
90 n = Y.GetNcols();
91
92 AReal norm = 1.0 / static_cast<AReal>(m * n);
93 for (size_t i = 0; i < m; i++)
94 {
95 AReal w = weights(i, 0);
96 for (size_t j = 0; j < n; j++)
97 {
98 AReal y = Y(i,j);
99 AReal sig = 1.0 / (1.0 + std::exp(-output(i,j)));
100 dY(i, j) = norm * w * (sig - y);
101 }
102 }
103}
104
105//______________________________________________________________________________
106template <typename AReal>
108 const TMatrixT<AReal> &weights)
109{
110 size_t m,n;
111 m = Y.GetNrows();
112 n = Y.GetNcols();
113 AReal result = 0.0;
114
115 for (size_t i = 0; i < m; i++) {
116 AReal sum = 0.0;
117 AReal w = weights(i, 0);
118 for (size_t j = 0; j < n; j++) {
119 sum += exp(output(i,j));
120 }
121 for (size_t j = 0; j < n; j++) {
122 result += w * Y(i, j) * log(exp(output(i, j)) / sum);
123 }
124 }
125 result /= -static_cast<AReal>(m);
126 return result;
127}
128
129//______________________________________________________________________________
130template <typename AReal>
132 const TMatrixT<AReal> &output, const TMatrixT<AReal> &weights)
133{
134 size_t m,n;
135 m = Y.GetNrows();
136 n = Y.GetNcols();
137 AReal norm = 1.0 / m ;
138
139 for (size_t i = 0; i < m; i++)
140 {
141 AReal sum = 0.0;
142 AReal sumY = 0.0;
143 AReal w = weights(i, 0);
144 for (size_t j = 0; j < n; j++) {
145 sum += exp(output(i,j));
146 sumY += Y(i,j);
147 }
148 for (size_t j = 0; j < n; j++) {
149 dY(i, j) = w * norm * (exp(output(i, j)) / sum * sumY - Y(i, j));
150 }
151 }
152}
153
154} // namespace DNN
155} // namespace TMVA
Option_t Option_t TPoint TPoint const char GetTextMagnitude GetFillStyle GetLineColor GetLineWidth GetMarkerStyle GetTextAlign GetTextColor GetTextSize void char Point_t Rectangle_t WindowAttributes_t Float_t Float_t Float_t Int_t Int_t UInt_t UInt_t Rectangle_t result
static void CrossEntropyGradients(TMatrixT< AReal > &dY, const TMatrixT< AReal > &Y, const TMatrixT< AReal > &output, const TMatrixT< AReal > &weights)
static AReal SoftmaxCrossEntropy(const TMatrixT< AReal > &Y, const TMatrixT< AReal > &output, const TMatrixT< AReal > &weights)
Softmax transformation is implicitly applied, thus output should hold the linear activations of the last layer in the net.
static AReal CrossEntropy(const TMatrixT< AReal > &Y, const TMatrixT< AReal > &output, const TMatrixT< AReal > &weights)
Sigmoid transformation is implicitly applied, thus output should hold the linear activations of the last layer in the net.
static void MeanSquaredErrorGradients(TMatrixT< AReal > &dY, const TMatrixT< AReal > &Y, const TMatrixT< AReal > &output, const TMatrixT< AReal > &weights)
static void SoftmaxCrossEntropyGradients(TMatrixT< AReal > &dY, const TMatrixT< AReal > &Y, const TMatrixT< AReal > &output, const TMatrixT< AReal > &weights)
static AReal MeanSquaredError(const TMatrixT< AReal > &Y, const TMatrixT< AReal > &output, const TMatrixT< AReal > &weights)
Int_t GetNrows() const
Int_t GetNcols() const
TMatrixT.
Definition TMatrixT.h:39
void Minus(const TMatrixT< Element > &a, const TMatrixT< Element > &b)
General matrix subtraction. Replace this matrix with C such that C = A - B.
Definition TMatrixT.cxx:576
Double_t y[n]
Definition legend1.C:17
const Int_t n
Definition legend1.C:16
create variable transformations
TMarker m
Definition textangle.C:8
static uint64_t sum(uint64_t i)
Definition Factory.cxx:2345
static void output()