Logo ROOT  
Reference Guide
 
Loading...
Searching...
No Matches
Initialization.hxx
Go to the documentation of this file.
1// @(#)root/tmva/tmva/dnn:$Id$
2// Author: Simon Pfreundschuh 21/07/16
3
4/*************************************************************************
5 * Copyright (C) 2016, Simon Pfreundschuh *
6 * All rights reserved. *
7 * *
8 * For the licensing terms see $ROOTSYS/LICENSE. *
9 * For the list of contributors see $ROOTSYS/README/CREDITS. *
10 *************************************************************************/
11
12 //////////////////////////////////////////////////////////////
13 // Implementation of the DNN initialization methods for the //
14 // multi-threaded CPU backend. //
15 //////////////////////////////////////////////////////////////
16
17#include "TRandom3.h"
19
20namespace TMVA
21{
22namespace DNN
23{
24
// NOTE(review): the declaration introduced by this template header (source line 26) is missing
// from this rendered listing; since the functions below read and assign fgRandomGen, it is
// presumably the out-of-class definition of that static member -- confirm against the real file.
25template <typename AFloat_t>
27//______________________________________________________________________________
/// Seed the shared random generator fgRandomGen with `seed`, first creating it as a
/// default-constructed TRandom3 if it does not exist yet.
/// NOTE(review): the signature line (source line 29) is missing from this listing; presumably
/// `SetRandomSeed(size_t seed)` per the member summary at the end of the page -- confirm.
/// NOTE(review): fgRandomGen is allocated with `new` and never deleted in the visible code.
28template<typename AFloat>
30{
31 if (!fgRandomGen) fgRandomGen = new TRandom3(); // GetRandomGenerator constructs TRandom3(0) instead; only the construction argument differs, since the seed is overwritten on the next line
32 fgRandomGen->SetSeed(seed);
33}
/// Return a reference to the shared random generator, lazily creating it as TRandom3(0) on first use.
/// NOTE(review): per ROOT's TRandom3 documentation a seed of 0 requests an automatically computed
/// seed, so results are not reproducible unless SetRandomSeed is called before the first use.
/// NOTE(review): the signature line (source line 35) is missing from this listing; presumably
/// `GetRandomGenerator()` returning `TRandom &` per the member summary -- confirm.
34template<typename AFloat>
36{
37 if (!fgRandomGen) fgRandomGen = new TRandom3(0);
38 return *fgRandomGen;
39}
40
41//______________________________________________________________________________
/// Fill every entry of A's raw data buffer (indices 0..GetSize()-1) with a sample from a normal
/// distribution with mean 0 and stddev sigma = sqrt(2 / A.GetNcols()), using GetRandomGenerator().
/// NOTE(review): the signature line (source line 43) is missing from this listing; presumably
/// `InitializeGauss(Matrix_t &A)` per the member summary -- confirm.
42template<typename AFloat>
44{
45 size_t n = A.GetNcols(); // presumably the fan-in (the Glorot-uniform code labels ncols "input size")
46
47 TRandom & rand = GetRandomGenerator();
48
49 AFloat sigma = sqrt(2.0 / ((AFloat) n));
50
51 for (size_t i = 0; i < A.GetSize(); ++i) {
52 A.GetRawDataPointer()[i] = rand.Gaus(0.0, sigma);
53 }
// NOTE(review): the commented-out loop below duplicates the live loop above and can be removed.
54 // for (size_t i = 0; i < A.GetSize(); ++i) {
55 // A.GetRawDataPointer()[i] = rand.Gaus(0.0, sigma);
56 // }
57}
58
59//______________________________________________________________________________
/// Fill every entry of A's raw data buffer (indices 0..GetSize()-1) with a sample from a uniform
/// distribution between -range and +range, where range = sqrt(2 / A.GetNcols()).
/// NOTE(review): U(-r, r) has variance r^2/3, so this yields variance 2/(3*ncols), a third of what
/// InitializeGauss produces; a He-style uniform init would use range = sqrt(6/ncols). Confirm intent.
/// NOTE(review): the signature line (source line 61) is missing from this listing; presumably
/// `InitializeUniform(Matrix_t &A)` per the member summary -- confirm.
60template<typename AFloat>
62{
63 //size_t m = A.GetNrows();
64 size_t n = A.GetNcols();
65
66 TRandom & rand = GetRandomGenerator();
67
68 AFloat range = sqrt(2.0 / ((AFloat) n));
69
70 // for debugging
71 //range = 1;
72 //rand.SetSeed(111);
73
74 for (size_t i = 0; i < A.GetSize(); ++i) {
75 A.GetRawDataPointer()[i] = rand.Uniform(-range, range);
76 }
77}
78
79 //______________________________________________________________________________
80/// Truncated normal initialization (Glorot, also called Xavier normal).
81/// Values are sampled from a normal distribution with stddev = sqrt(6/(N_input + N_output)) (the paper's
82/// normal variant uses sqrt(2/(N_input + N_output))); samples with |value| > 2 * stddev are discarded and redrawn.
83/// See Glorot & Bengio, AISTATS 2010 - http://jmlr.org/proceedings/papers/v9/glorot10a/glorot10a.pdf
/// Fill every entry of A's raw data buffer (indices 0..GetSize()-1) with a truncated-normal sample:
/// draw from a normal distribution with mean 0 and stddev sigma = sqrt(6 / (ncols + nrows)), and
/// redraw until |value| <= 2 * sigma.
/// NOTE(review): Glorot & Bengio's normal variant (e.g. Keras GlorotNormal) uses stddev
/// sqrt(2 / (fan_in + fan_out)); sqrt(6 / (fan_in + fan_out)) is the Glorot *uniform* limit. Also,
/// truncating at 2 sigma shrinks the realised stddev to about 0.88 * sigma. Confirm whether intended.
/// NOTE(review): the signature line (source line 85) is missing from this listing; presumably
/// `InitializeGlorotNormal(Matrix_t &A)` per the member summary -- confirm.
84template<typename AFloat>
86{
87 size_t m,n;
88 // for conv layer weights, m is only the output depth. It should also be multiplied by the filter size,
89 // e.g. 9 for a 3x3 filter, but this information is lost when the weights are stored as 2-D tensors
90 m = A.GetNrows();
91 n = A.GetNcols();
92
93 TRandom & rand = GetRandomGenerator();
94
95 AFloat sigma = sqrt(6.0 /( ((AFloat) n) + ((AFloat) m)) ); // see NOTE(review) above about the factor 6
96 // AFloat sigma = sqrt(2.0 /( ((AFloat) m)) );
97
98 size_t nsize = A.GetSize();
99 for (size_t i = 0; i < nsize; i++) {
100 AFloat value = 0;
101 do {
102 value = rand.Gaus(0.0, sigma);
103 } while (std::abs(value) > 2 * sigma); // rejection sampling: redraw anything beyond 2 sigma
104 A.GetRawDataPointer()[i] = value;
105 }
106}
107
108//______________________________________________________________________________
109/// Sample from a uniform distribution in the range [-lim, +lim], where
110/// lim = sqrt(6/(N_in + N_out)).
111/// This initialization is also called Xavier uniform
112/// see Glorot & Bengio, AISTATS 2010 - http://jmlr.org/proceedings/papers/v9/glorot10a/glorot10a.pdf
/// Fill every entry of A's raw data buffer (indices 0..GetSize()-1) with a sample from a uniform
/// distribution between -range and +range, where range = sqrt(6 / (ncols + nrows)).
/// NOTE(review): the signature line (source line 114) is missing from this listing; presumably
/// `InitializeGlorotUniform(Matrix_t &A)` per the member summary -- confirm.
113template<typename AFloat>
115{
116 size_t m,n;
117 m = A.GetNrows(); // output size
118 n = A.GetNcols(); // input size
119 // Note that m and n are inverted with respect to cudnn because tensor is here column-wise
120
121 TRandom & rand = GetRandomGenerator();
122
123 AFloat range = sqrt(6.0 /( ((AFloat) n) + ((AFloat) m)) ); // symmetric in m and n, so the swap above does not affect it
124
125 size_t nsize = A.GetSize();
126 for (size_t i = 0; i < nsize; i++) {
127 A.GetRawDataPointer()[i] = rand.Uniform(-range, range);
128 }
129}
130
131//______________________________________________________________________________
/// Set A to a (possibly rectangular) identity: every element A(i,j) is zeroed, then A(i,i) = 1
/// for every i < min(nrows, ncols).
/// NOTE(review): the signature line (source line 133) is missing from this listing; presumably
/// `InitializeIdentity(Matrix_t &A)` per the member summary -- confirm.
132template<typename AFloat>
134{
135 size_t m,n;
136 m = A.GetNrows();
137 n = A.GetNcols();
138
139 for (size_t i = 0; i < m; i++) {
140 for (size_t j = 0; j < n; j++) {
141 A(i,j) = 0.0;
142 //A(i,j) = 1.0;
143 }
144
145 if (i < n) { // rows beyond the last column have no diagonal element
146 A(i,i) = 1.0;
147 }
148 }
149}
150
151//______________________________________________________________________________
/// Set every element A(i,j), with 0 <= i < nrows and 0 <= j < ncols, to zero via element access.
/// NOTE(review): the signature line (source line 153) is missing from this listing; presumably
/// `InitializeZero(Matrix_t &A)` per the member summary -- confirm.
152template<typename AFloat>
154{
155 size_t m,n;
156 m = A.GetNrows();
157 n = A.GetNcols();
158
159 for (size_t i = 0; i < m; i++) {
160 for (size_t j = 0; j < n ; j++) {
161 A(i,j) = 0.0;
162 }
163 }
164}
165//______________________________________________________________________________
/// Set every entry of A's raw data buffer (indices 0..GetSize()-1) to zero.
/// NOTE(review): the signature line (source line 167) is missing from this listing. Because this body
/// uses GetSize()/GetRawDataPointer() rather than (i,j) access, it is presumably the Tensor_t
/// overload of InitializeZero -- confirm against the real file.
166template <typename AFloat>
168{
169 size_t n = A.GetSize();
170
171 for (size_t i = 0; i < n; i++) {
172 A.GetRawDataPointer()[i] = 0.0;
173 }
174}
175
176} // namespace DNN
177} // namespace TMVA
ROOT::Detail::TRangeCast< T, true > TRangeDynCast
TRangeDynCast is an adapter class that allows the typed iteration through a TCollection.
Option_t Option_t TPoint TPoint const char GetTextMagnitude GetFillStyle GetLineColor GetLineWidth GetMarkerStyle GetTextAlign GetTextColor GetTextSize void value
The TCpuMatrix class.
Definition CpuMatrix.h:86
size_t GetNcols() const
Definition CpuMatrix.h:156
size_t GetNrows() const
Definition CpuMatrix.h:155
static TRandom * fgRandomGen
Definition Cpu.h:67
TCpuTensor< AReal > Tensor_t
Definition Cpu.h:70
static void InitializeIdentity(Matrix_t &A)
static TRandom & GetRandomGenerator()
static void InitializeUniform(Matrix_t &A)
static void SetRandomSeed(size_t seed)
static void InitializeGauss(Matrix_t &A)
static void InitializeZero(Matrix_t &A)
static void InitializeGlorotUniform(Matrix_t &A)
Sample from a uniform distribution in the range [-lim, +lim], where lim = sqrt(6/(N_in + N_out)).
static void InitializeGlorotNormal(Matrix_t &A)
Truncated normal initialization (Glorot, also called Xavier normal). The values are sampled from a norm...
Random number generator class based on the Mersenne Twister of M. Matsumoto and T. Nishimura.
Definition TRandom3.h:27
This is the base class for the ROOT Random number generators.
Definition TRandom.h:27
const Double_t sigma
const Int_t n
Definition legend1.C:16
create variable transformations
TMarker m
Definition textangle.C:8