Logo ROOT   6.18/05
Reference Guide
Initialization.cxx
Go to the documentation of this file.
1// @(#)root/tmva/tmva/dnn:$Id$
2// Author: Simon Pfreundschuh 21/07/16
3
4/*************************************************************************
5 * Copyright (C) 2016, Simon Pfreundschuh *
6 * All rights reserved. *
7 * *
8 * For the licensing terms see $ROOTSYS/LICENSE. *
9 * For the list of contributors see $ROOTSYS/README/CREDITS. *
10 *************************************************************************/
11
12 //////////////////////////////////////////////////////////////
13 // Implementation of the DNN initialization methods for the //
14 // multi-threaded CPU backend. //
15 //////////////////////////////////////////////////////////////
16
17#include "TRandom3.h"
19
20namespace TMVA
21{
22namespace DNN
23{
24
25template <typename AFloat_t>
27//______________________________________________________________________________
28template<typename AFloat>
30{
31 if (!fgRandomGen) fgRandomGen = new TRandom3();
32 fgRandomGen->SetSeed(seed);
33}
34template<typename AFloat>
36{
37 if (!fgRandomGen) fgRandomGen = new TRandom3(0);
38 return *fgRandomGen;
39}
40
41//______________________________________________________________________________
42template<typename AFloat>
44{
45 size_t m,n;
46 m = A.GetNrows();
47 n = A.GetNcols();
48
49 TRandom & rand = GetRandomGenerator();
50
51 AFloat sigma = sqrt(2.0 / ((AFloat) n));
52
53 for (size_t i = 0; i < m; i++) {
54 for (size_t j = 0; j < n; j++) {
55 A(i,j) = rand.Gaus(0.0, sigma);
56 }
57 }
58}
59
60//______________________________________________________________________________
61template<typename AFloat>
63{
64 size_t m,n;
65 m = A.GetNrows();
66 n = A.GetNcols();
67
68 TRandom & rand = GetRandomGenerator();
69
70 AFloat range = sqrt(2.0 / ((AFloat) n));
71
72 for (size_t i = 0; i < m; i++) {
73 for (size_t j = 0; j < n; j++) {
74 A(i,j) = rand.Uniform(-range, range);
75 }
76 }
77}
78
79 //______________________________________________________________________________
80/// Truncated normal initialization (Glorot, called also Xavier normal)
81/// The values are sample with a normal distribution with stddev = sqrt(2/N_input + N_output) and
82/// values larger than 2 * stddev are discarded
83/// See Glorot & Bengio, AISTATS 2010 - http://jmlr.org/proceedings/papers/v9/glorot10a/glorot10a.pdf
84template<typename AFloat>
86{
87 size_t m,n;
88 m = A.GetNrows();
89 n = A.GetNcols();
90
91 TRandom & rand = GetRandomGenerator();
92
93 AFloat sigma = sqrt(2.0 /( ((AFloat) n) + ((AFloat) m)) );
94
95 for (size_t i = 0; i < m; i++) {
96 for (size_t j = 0; j < n; j++) {
97 AFloat value = rand.Gaus(0.0, sigma);
98 if ( std::abs(value) > 2*sigma) continue;
99 A(i,j) = rand.Gaus(0.0, sigma);
100 }
101 }
102}
103
104//______________________________________________________________________________
105/// Sample from a uniform distribution in range [ -lim,+lim] where
106/// lim = sqrt(6/N_in+N_out).
107/// This initialization is also called Xavier uniform
108/// see Glorot & Bengio, AISTATS 2010 - http://jmlr.org/proceedings/papers/v9/glorot10a/glorot10a.pdf
109template<typename AFloat>
111{
112 size_t m,n;
113 m = A.GetNrows();
114 n = A.GetNcols();
115
116 TRandom & rand = GetRandomGenerator();
117
118 AFloat range = sqrt(6.0 /( ((AFloat) n) + ((AFloat) m)) );
119
120 for (size_t i = 0; i < m; i++) {
121 for (size_t j = 0; j < n; j++) {
122 A(i,j) = rand.Uniform(-range, range);
123 }
124 }
125}
126
127//______________________________________________________________________________
128template<typename AFloat>
130{
131 size_t m,n;
132 m = A.GetNrows();
133 n = A.GetNcols();
134
135 for (size_t i = 0; i < m; i++) {
136 for (size_t j = 0; j < n ; j++) {
137 A(i,j) = 0.0;
138 }
139
140 if (i < n) {
141 A(i,i) = 1.0;
142 }
143 }
144}
145
146//______________________________________________________________________________
147template<typename AFloat>
149{
150 size_t m,n;
151 m = A.GetNrows();
152 n = A.GetNcols();
153
154 for (size_t i = 0; i < m; i++) {
155 for (size_t j = 0; j < n ; j++) {
156 A(i,j) = 0.0;
157 }
158 }
159}
160
161} // namespace DNN
162} // namespace TMVA
double sqrt(double)
The TCpuMatrix class.
Definition: CpuMatrix.h:89
static TRandom * fgRandomGen
Definition: Cpu.h:47
static void InitializeUniform(TCpuMatrix< Scalar_t > &A)
static void InitializeGlorotNormal(TCpuMatrix< Scalar_t > &A)
Truncated normal initialization (Glorot, also called Xavier normal). The values are sampled from a norm...
static void InitializeIdentity(TCpuMatrix< Scalar_t > &A)
static TRandom & GetRandomGenerator()
static void SetRandomSeed(size_t seed)
static void InitializeZero(TCpuMatrix< Scalar_t > &A)
static void InitializeGauss(TCpuMatrix< Scalar_t > &A)
static void InitializeGlorotUniform(TCpuMatrix< Scalar_t > &A)
Sample from a uniform distribution in the range [-lim, +lim] where lim = sqrt(6/(N_in+N_out)).
Random number generator class based on M.
Definition: TRandom3.h:27
This is the base class for the ROOT Random number generators.
Definition: TRandom.h:27
virtual Double_t Gaus(Double_t mean=0, Double_t sigma=1)
Samples a random number from the standard Normal (Gaussian) Distribution with the given mean and sigm...
Definition: TRandom.cxx:263
virtual Double_t Uniform(Double_t x1=1)
Returns a uniform deviate on the interval (0, x1).
Definition: TRandom.cxx:635
const Double_t sigma
const Int_t n
Definition: legend1.C:16
static double A[]
create variable transformations
auto * m
Definition: textangle.C:8