Logo ROOT  
Reference Guide
OutputFunctions.hxx
Go to the documentation of this file.
1// @(#)root/tmva/tmva/dnn:$Id$
2// Author: Simon Pfreundschuh 21/07/16
3
4/*************************************************************************
5 * Copyright (C) 2016, Simon Pfreundschuh *
6 * All rights reserved. *
7 * *
8 * For the licensing terms see $ROOTSYS/LICENSE. *
9 * For the list of contributors see $ROOTSYS/README/CREDITS. *
10 *************************************************************************/
11
12///////////////////////////////////////////////////////////////
13// Implementation of output functions for multi-threaded CPU //
14// architectures. //
15///////////////////////////////////////////////////////////////
16
18
19namespace TMVA
20{
21namespace DNN
22{
23
24template<typename AFloat>
25void TCpu<AFloat>::Sigmoid(TCpuMatrix<AFloat> & B,
26 const TCpuMatrix<AFloat> & A)
27{
28 auto f = [](AFloat x) {return 1.0 / (1.0 + exp(-x));};
29 B.MapFrom(f, A);
30}
31
32template<typename AFloat>
34 const TCpuMatrix<AFloat> & A)
35{
36 const AFloat *dataA = A.GetRawDataPointer();
37 AFloat *dataB = B.GetRawDataPointer();
38 size_t n = A.GetNcols();
39 size_t m = A.GetNrows();
40
41 auto f = [&dataA, &dataB, n, m](UInt_t workerID)
42 {
43 AFloat sum = 0.0;
44 for (size_t i = 0; i < n; i++) {
45 sum += exp(dataA[workerID + i * m]);
46 }
47 for (size_t i = 0; i < n; i++) {
48 dataB[workerID + i * m] = exp(dataA[workerID + i * m]) / sum;
49 }
50 return 0;
51 };
52
53 B.GetThreadExecutor().Map(f, ROOT::TSeqI(A.GetNrows()));
54}
55
56} // namespace DNN
57} // namespace TMVA
#define f(i)
Definition: RSha256.hxx:104
double exp(double)
A pseudo container class which is a generator of indices.
Definition: TSeq.hxx:66
The TCpuMatrix class.
Definition: CpuMatrix.h:86
static void Sigmoid(Tensor_t &B)
static void Softmax(Matrix_t &YHat, const Matrix_t &)
Double_t x[n]
Definition: legend1.C:17
const Int_t n
Definition: legend1.C:16
static double B[]
static double A[]
create variable transformations
auto * m
Definition: textangle.C:8
static long int sum(long int i)
Definition: Factory.cxx:2275