Logo ROOT  
Reference Guide
Initialization.hxx
Go to the documentation of this file.
1// @(#)root/tmva/tmva/dnn:$Id$
2// Author: Simon Pfreundschuh 21/07/16
3
4/*************************************************************************
5 * Copyright (C) 2016, Simon Pfreundschuh *
6 * All rights reserved. *
7 * *
8 * For the licensing terms see $ROOTSYS/LICENSE. *
9 * For the list of contributors see $ROOTSYS/README/CREDITS. *
10 *************************************************************************/
11
12 //////////////////////////////////////////////////////////////
13 // Implementation of the DNN initialization methods for the //
14 // multi-threaded CPU backend. //
15 //////////////////////////////////////////////////////////////
16
#include "TMVA/DNN/Architectures/Cpu.h"
#include "TRandom3.h"

#include <cmath>

20namespace TMVA
21{
22namespace DNN
23{
24
25template <typename AFloat_t>
27//______________________________________________________________________________
28template<typename AFloat>
30{
31 if (!fgRandomGen) fgRandomGen = new TRandom3();
32 fgRandomGen->SetSeed(seed);
33}
34template<typename AFloat>
36{
37 if (!fgRandomGen) fgRandomGen = new TRandom3(0);
38 return *fgRandomGen;
39}
40
41//______________________________________________________________________________
42template<typename AFloat>
44{
45 size_t n = A.GetNcols();
46
47 TRandom & rand = GetRandomGenerator();
48
49 AFloat sigma = sqrt(2.0 / ((AFloat) n));
50
51 for (size_t i = 0; i < A.GetSize(); ++i) {
52 A.GetRawDataPointer()[i] = rand.Gaus(0.0, sigma);
53 }
54}
55
56//______________________________________________________________________________
57template<typename AFloat>
59{
60 //size_t m = A.GetNrows();
61 size_t n = A.GetNcols();
62
63 TRandom & rand = GetRandomGenerator();
64
65 AFloat range = sqrt(2.0 / ((AFloat) n));
66
67 // for debugging
68 //range = 1;
69 //rand.SetSeed(111);
70
71 for (size_t i = 0; i < A.GetSize(); ++i) {
72 A.GetRawDataPointer()[i] = rand.Uniform(-range, range);
73 }
74}
75
76 //______________________________________________________________________________
77/// Truncated normal initialization (Glorot, called also Xavier normal)
78/// The values are sample with a normal distribution with stddev = sqrt(2/N_input + N_output) and
79/// values larger than 2 * stddev are discarded
80/// See Glorot & Bengio, AISTATS 2010 - http://jmlr.org/proceedings/papers/v9/glorot10a/glorot10a.pdf
81template<typename AFloat>
83{
84 size_t m,n;
85 // for conv layer weights output m is only output depth. It shouild ne multiplied also by filter sizes
86 // e.g. 9 for a 3x3 filter. But this information is lost if we use Tensors of dims 2
87 m = A.GetNrows();
88 n = A.GetNcols();
89
90 TRandom & rand = GetRandomGenerator();
91
92 AFloat sigma = sqrt(2.0 /( ((AFloat) n) + ((AFloat) m)) );
93 // AFloat sigma = sqrt(2.0 /( ((AFloat) m)) );
94
95 size_t nsize = A.GetSize();
96 for (size_t i = 0; i < nsize; i++) {
97 AFloat value = 0;
98 do {
99 value = rand.Gaus(0.0, sigma);
100 } while (std::abs(value) > 2 * sigma);
101 R__ASSERT(std::abs(value) < 2 * sigma);
102 A.GetRawDataPointer()[i] = value;
103 }
104}
105
106//______________________________________________________________________________
107/// Sample from a uniform distribution in range [ -lim,+lim] where
108/// lim = sqrt(6/N_in+N_out).
109/// This initialization is also called Xavier uniform
110/// see Glorot & Bengio, AISTATS 2010 - http://jmlr.org/proceedings/papers/v9/glorot10a/glorot10a.pdf
111template<typename AFloat>
113{
114 size_t m,n;
115 m = A.GetNrows();
116 n = A.GetNcols();
117
118 TRandom & rand = GetRandomGenerator();
119
120 AFloat range = sqrt(6.0 /( ((AFloat) n) + ((AFloat) m)) );
121
122 size_t nsize = A.GetSize();
123 for (size_t i = 0; i < nsize; i++) {
124 A.GetRawDataPointer()[i] = rand.Uniform(-range, range);
125 }
126}
127
128//______________________________________________________________________________
129template<typename AFloat>
131{
132 size_t m,n;
133 m = A.GetNrows();
134 n = A.GetNcols();
135
136 for (size_t i = 0; i < m; i++) {
137 for (size_t j = 0; j < n ; j++) {
138 //A(i,j) = 0.0;
139 A(i,j) = 1.0;
140 }
141
142 if (i < n) {
143 A(i,i) = 1.0;
144 }
145 }
146}
147
148//______________________________________________________________________________
149template<typename AFloat>
151{
152 size_t m,n;
153 m = A.GetNrows();
154 n = A.GetNcols();
155
156 for (size_t i = 0; i < m; i++) {
157 for (size_t j = 0; j < n ; j++) {
158 A(i,j) = 0.0;
159 }
160 }
161}
162
163} // namespace DNN
164} // namespace TMVA
#define R__ASSERT(e)
Definition: TError.h:96
double sqrt(double)
The TCpuMatrix class.
Definition: CpuMatrix.h:87
static TRandom * fgRandomGen
Definition: Cpu.h:65
static void InitializeIdentity(Matrix_t &A)
static TRandom & GetRandomGenerator()
static void InitializeUniform(Matrix_t &A)
static void SetRandomSeed(size_t seed)
static void InitializeGauss(Matrix_t &A)
static void InitializeGlorotUniform(Matrix_t &A)
Sample from a uniform distribution in the range [-lim, +lim], where lim = sqrt(6 / (N_in + N_out)).
static void InitializeGlorotNormal(Matrix_t &A)
Truncated normal initialization (Glorot, also called Xavier normal). The values are sampled with a norm...
static void InitializeZero(Matrix_t &A)
Random number generator class based on M.
Definition: TRandom3.h:27
This is the base class for the ROOT Random number generators.
Definition: TRandom.h:27
virtual Double_t Gaus(Double_t mean=0, Double_t sigma=1)
Samples a random number from the standard Normal (Gaussian) Distribution with the given mean and sigm...
Definition: TRandom.cxx:263
virtual Double_t Uniform(Double_t x1=1)
Returns a uniform deviate on the interval (0, x1).
Definition: TRandom.cxx:635
const Double_t sigma
const Int_t n
Definition: legend1.C:16
static double A[]
create variable transformations
auto * m
Definition: textangle.C:8