Logo ROOT  
Reference Guide
 
Loading...
Searching...
No Matches
Initialization.cu
Go to the documentation of this file.
// @(#)root/tmva/tmva/dnn:$Id$
// Author: Simon Pfreundschuh 14/07/16

/*************************************************************************
 * Copyright (C) 2016, Simon Pfreundschuh *
 * All rights reserved. *
 * *
 * For the licensing terms see $ROOTSYS/LICENSE. *
 * For the list of contributors see $ROOTSYS/README/CREDITS. *
 *************************************************************************/

 /////////////////////////////////////////////////////////////
 // Implementation of the initialization functions for CUDA //
 // Architectures //
 /////////////////////////////////////////////////////////////

#include "TRandom3.h"
#include "TMatrix.h"
#include "TMVA/DNN/Architectures/Cuda.h"
#include "Kernels.cuh"

#include <cmath>

namespace TMVA
{
namespace DNN
{

//______________________________________________________________________________
// Shared random number generator used by all initialization routines of the
// CUDA architecture. It starts out null and is created lazily by
// SetRandomSeed() or GetRandomGenerator().
template <typename AFloat>
TRandom * TCuda<AFloat>::fgRandomGen = nullptr;

//______________________________________________________________________________
/// Seed the shared random number generator.
/// If no generator exists yet, a TRandom3 is created first.
/// \param seed  value forwarded to TRandom::SetSeed()
template<typename AFloat>
void TCuda<AFloat>::SetRandomSeed(size_t seed)
{
   if (!fgRandomGen) fgRandomGen = new TRandom3();
   fgRandomGen->SetSeed(seed);
}

//______________________________________________________________________________
/// Return the shared random number generator.
/// If SetRandomSeed() has not been called before, the generator is created
/// here as TRandom3(0).
template<typename AFloat>
TRandom & TCuda<AFloat>::GetRandomGenerator()
{
   if (!fgRandomGen) fgRandomGen = new TRandom3(0);
   return *fgRandomGen;
}

//______________________________________________________________________________
/// Gaussian initialization.
/// Every element of A is drawn from a normal distribution with mean 0 and
/// standard deviation sqrt(2 / n), where n is the number of columns of A.
/// The values are generated on the host into a TMatrixT with the shared
/// generator and then assigned to the CUDA matrix A.
template<typename AFloat>
void TCuda<AFloat>::InitializeGauss(TCudaMatrix<AFloat> & A)
{
   const size_t m = A.GetNrows();
   const size_t n = A.GetNcols();

   TRandom & rand = GetRandomGenerator();
   TMatrixT<AFloat> B(m, n);

   Double_t sigma = sqrt(2.0 / ((Double_t) n));

   for (size_t i = 0; i < m; i++) {
      for (size_t j = 0; j < n; j++) {
         B(i,j) = rand.Gaus(0.0, sigma);
      }
   }
   A = B;
}

//______________________________________________________________________________
/// Uniform initialization.
/// Every element of A is drawn uniformly from [-range, +range] with
/// range = sqrt(2 / n), where n is the number of columns of A.
/// The values are generated on the host into a TMatrixT with the shared
/// generator and then assigned to the CUDA matrix A.
template<typename AFloat>
void TCuda<AFloat>::InitializeUniform(TCudaMatrix<AFloat> & A)
{
   const size_t m = A.GetNrows();
   const size_t n = A.GetNcols();

   TRandom & rand = GetRandomGenerator();
   TMatrixT<AFloat> B(m, n);

   Double_t range = sqrt(2.0 / ((Double_t) n));

   for (size_t i = 0; i < m; i++) {
      for (size_t j = 0; j < n; j++) {
         B(i,j) = rand.Uniform(-range, range);
      }
   }
   A = B;
}

//______________________________________________________________________________
/// Truncated normal initialization (Glorot, also called Xavier normal).
/// The values are sampled from a normal distribution with mean 0 and
/// stddev = sqrt(2 / (N_input + N_output)). Here this is computed as
/// sqrt(2 / (n + m)), with m and n the number of rows and columns of A.
/// Samples whose magnitude exceeds 2 * stddev are rejected and redrawn, so
/// every stored value lies in [-2 * stddev, +2 * stddev].
/// See Glorot & Bengio, AISTATS 2010 - http://proceedings.mlr.press/v9/glorot10a/glorot10a.pdf
template<typename AFloat>
void TCuda<AFloat>::InitializeGlorotNormal(TCudaMatrix<AFloat> & A)
{
   const size_t m = A.GetNrows();
   const size_t n = A.GetNcols();

   TRandom & rand = GetRandomGenerator();
   TMatrixT<AFloat> B(m, n);

   AFloat sigma = sqrt(2.0 /( ((AFloat) n) + ((AFloat) m)) );

   for (size_t i = 0; i < m; i++) {
      for (size_t j = 0; j < n; j++) {
         AFloat value = 0;
         // Rejection sampling: redraw until the sample falls inside +-2 sigma.
         do {
            value = rand.Gaus(0.0, sigma);
         } while ( std::abs(value) > 2*sigma);
         B(i,j) = value;
      }
   }
   A = B;
}

//______________________________________________________________________________
/// Glorot uniform initialization, also called Xavier uniform.
/// Samples from a uniform distribution in the range [-lim, +lim], where
/// lim = sqrt(6 / (N_in + N_out)). Here this is computed as sqrt(6 / (n + m)),
/// with m and n the number of rows and columns of A.
/// See Glorot & Bengio, AISTATS 2010 - http://proceedings.mlr.press/v9/glorot10a/glorot10a.pdf
template<typename AFloat>
void TCuda<AFloat>::InitializeGlorotUniform(TCudaMatrix<AFloat> & A)
{
   const size_t m = A.GetNrows();
   const size_t n = A.GetNcols();

   TRandom & rand = GetRandomGenerator();
   TMatrixT<AFloat> B(m, n);

   AFloat range = sqrt(6.0 /( ((AFloat) n) + ((AFloat) m)) );

   for (size_t i = 0; i < m; i++) {
      for (size_t j = 0; j < n; j++) {
         B(i,j) = rand.Uniform(-range, range);
      }
   }
   // No debug output here: the matrix used to be printed on every call.
   A = B;
}

//______________________________________________________________________________
/// Identity initialization.
/// A is set to zero everywhere except for the main-diagonal elements (i, i)
/// with i < min(m, n), which are set to 1. The pattern is built on the host
/// and then assigned to the CUDA matrix A.
template<typename AFloat>
void TCuda<AFloat>::InitializeIdentity(TCudaMatrix<AFloat> & A)
{
   const size_t m = A.GetNrows();
   const size_t n = A.GetNcols();
   TMatrixT<AFloat> B(m, n);

   for (size_t i = 0; i < m; i++) {
      for (size_t j = 0; j < n ; j++) {
         B(i,j) = 0.0;
      }

      // Rows beyond the last column have no diagonal element.
      if (i < n) {
         B(i,i) = 1.0;
      }
   }
   A = B;
}

//______________________________________________________________________________
/// Zero initialization of a CUDA matrix.
/// Delegates to TCudaMatrix::Zero() instead of building a host matrix.
template<typename AFloat>
void TCuda<AFloat>::InitializeZero(TCudaMatrix<AFloat> & A)
{
   // use fast zero initialization on the device
   A.Zero();
}

//______________________________________________________________________________
/// Zero initialization of a CUDA tensor.
/// Delegates to TCudaTensor::Zero().
template <typename AFloat>
void TCuda<AFloat>::InitializeZero(TCudaTensor<AFloat> &T)
{
   // use fast zero initialization on the device
   T.Zero();
}

} // namespace DNN
} // namespace TMVA
Option_t Option_t TPoint TPoint const char GetTextMagnitude GetFillStyle GetLineColor GetLineWidth GetMarkerStyle GetTextAlign GetTextColor GetTextSize void value
TCudaMatrix Class.
Definition CudaMatrix.h:103
size_t GetNcols() const
Definition CudaMatrix.h:160
size_t GetNrows() const
Definition CudaMatrix.h:159
TCudaTensor Class.
Definition CudaTensor.h:84
static void SetRandomSeed(size_t seed)
static void InitializeGlorotUniform(Matrix_t &A)
Sample from a uniform distribution in range [-lim, +lim] where lim = sqrt(6/(N_in+N_out)).
static void InitializeGlorotNormal(Matrix_t &A)
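Truncated normal initialization (Glorot, also called Xavier normal). The values are sampled from a normal distribution with stddev = sqrt(2/(N_input + N_output)); values larger than 2 * stddev are discarded.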
static void InitializeUniform(Matrix_t &A)
static void InitializeGauss(Matrix_t &A)
static TRandom * fgRandomGen
Definition Cuda.h:67
static void InitializeIdentity(Matrix_t &A)
static TRandom & GetRandomGenerator()
static void InitializeZero(Matrix_t &A)
void Print(Option_t *name="") const override
Print the matrix as a table of elements.
TMatrixT.
Definition TMatrixT.h:40
Random number generator class based on M. Matsumoto and T. Nishimura's Mersenne Twister.
Definition TRandom3.h:27
This is the base class for the ROOT Random number generators.
Definition TRandom.h:27
virtual Double_t Gaus(Double_t mean=0, Double_t sigma=1)
Samples a random number from the standard Normal (Gaussian) Distribution with the given mean and sigma.
Definition TRandom.cxx:275
virtual Double_t Uniform(Double_t x1=1)
Returns a uniform deviate on the interval (0, x1).
Definition TRandom.cxx:682
const Double_t sigma
const Int_t n
Definition legend1.C:16
create variable transformations
TMarker m
Definition textangle.C:8