Logo ROOT  
Reference Guide
 
Loading...
Searching...
No Matches
Initialization.hxx
Go to the documentation of this file.
1// @(#)root/tmva/tmva/dnn:$Id$
2// Author: Simon Pfreundschuh 10/07/16
3
4/*************************************************************************
5 * Copyright (C) 2016, Simon Pfreundschuh *
6 * All rights reserved. *
7 * *
8 * For the licensing terms see $ROOTSYS/LICENSE. *
9 * For the list of contributors see $ROOTSYS/README/CREDITS. *
10 *************************************************************************/
11
12 //////////////////////////////////////////////////////////////////////
13 // Implementation of the initialization functions for the reference //
14 // implementation. //
15 //////////////////////////////////////////////////////////////////////
16
17#include "TRandom3.h"
19
20namespace TMVA
21{
22namespace DNN
23{
24
25template <typename Real_t>
27//______________________________________________________________________________
/// Reseeds the class-static generator fgRandomGen with `seed`, first creating it
/// as a default-constructed TRandom3 if it does not exist yet.
/// NOTE(review): the signature line (listing line 29) was lost in extraction; it
/// presumably declares SetRandomSeed(size_t seed) — restore from the source file.
28template<typename Real_t>
30{
31 if (!fgRandomGen) fgRandomGen = new TRandom3(); // lazy creation; no matching delete is visible in this file
32 fgRandomGen->SetSeed(seed);
33}
/// Returns a reference to the class-static generator fgRandomGen, creating it on
/// first use as TRandom3(0). All Initialize* functions below draw from it.
/// NOTE(review): in ROOT, seed 0 is documented to request an automatically chosen
/// (non-reproducible) seed — confirm; call SetRandomSeed() first for reproducible runs.
/// NOTE(review): the signature line (listing line 35) was lost in extraction; it
/// presumably declares TRandom & GetRandomGenerator() — restore from the source file.
34template<typename Real_t>
36{
37 if (!fgRandomGen) fgRandomGen = new TRandom3(0);
38 return *fgRandomGen;
39}
40
41//______________________________________________________________________________
/// Fills every element of A with an independent draw from a normal distribution
/// with mean 0 and standard deviation sigma = sqrt(2 / n), where n is the number
/// of columns of A. Draws come from GetRandomGenerator().
/// NOTE(review): this scaling matches He et al. initialization if the column count
/// is the layer fan-in — confirm against the weight-matrix layout.
/// NOTE(review): the signature line (listing line 43) was lost in extraction; it
/// presumably declares InitializeGauss(TMatrixT<Real_t> & A) — restore from the source file.
42template<typename Real_t>
44{
45 size_t m,n;
46 m = A.GetNrows();
47 n = A.GetNcols();
48
49 TRandom & rand = GetRandomGenerator();
50
51 Real_t sigma = sqrt(2.0 / ((Real_t) n)); // sigma = sqrt(2 / ncols)
52
53 for (size_t i = 0; i < m; i++) {
54 for (size_t j = 0; j < n; j++) {
55 A(i,j) = rand.Gaus(0.0, sigma);
56 }
57 }
58}
59
60//______________________________________________________________________________
/// Fills every element of A with an independent uniform draw between -range and
/// +range, where range = sqrt(2 / n) and n is the number of columns of A.
/// Draws come from GetRandomGenerator().
/// NOTE(review): the signature line (listing line 62) was lost in extraction; it
/// presumably declares InitializeUniform(TMatrixT<Real_t> & A) — restore from the source file.
61template<typename Real_t>
63{
64 size_t m,n;
65 m = A.GetNrows();
66 n = A.GetNcols();
67
68 TRandom & rand = GetRandomGenerator();
69
70 Real_t range = sqrt(2.0 / ((Real_t) n)); // half-width = sqrt(2 / ncols)
71
72 for (size_t i = 0; i < m; i++) {
73 for (size_t j = 0; j < n; j++) {
74 A(i,j) = rand.Uniform(-range, range);
75 }
76 }
77}
78
79 //______________________________________________________________________________
80/// Truncated normal initialization (Glorot, called also Xavier normal)
81/// The values are sample with a normal distribution with stddev = sqrt(2/N_input + N_output) and
82/// values larger than 2 * stddev are discarded
83/// See Glorot & Bengio, AISTATS 2010 - http://jmlr.org/proceedings/papers/v9/glorot10a/glorot10a.pdf
84template<typename Real_t>
86{
87 size_t m,n;
88 m = A.GetNrows();
89 n = A.GetNcols();
90
91 TRandom & rand = GetRandomGenerator();
92
93 Real_t sigma = sqrt(2.0 /( ((Real_t) n) + ((Real_t) m)) );
94
95 for (size_t i = 0; i < m; i++) {
96 for (size_t j = 0; j < n; j++) {
97 Real_t value = rand.Gaus(0.0, sigma);
98 if ( std::abs(value) > 2*sigma) continue;
99 A(i,j) = rand.Gaus(0.0, sigma);
100 }
101 }
102}
103
104//______________________________________________________________________________
105/// Sample each element of A from a uniform distribution between -lim and +lim, where
106/// lim = sqrt(6 / (N_in + N_out)); the code takes N_in + N_out as rows + columns of A.
107/// This initialization is also called Xavier uniform.
108/// See Glorot & Bengio, AISTATS 2010 - http://jmlr.org/proceedings/papers/v9/glorot10a/glorot10a.pdf
/// NOTE(review): the signature line (listing line 110) was lost in extraction; it
/// presumably declares InitializeGlorotUniform(TMatrixT<Real_t> & A) — restore from the source file.
109template<typename Real_t>
111{
112 size_t m,n;
113 m = A.GetNrows();
114 n = A.GetNcols();
115
116 TRandom & rand = GetRandomGenerator();
117
118 Real_t range = sqrt(6.0 /( ((Real_t) n) + ((Real_t) m)) ); // lim = sqrt(6 / (ncols + nrows))
119
120 for (size_t i = 0; i < m; i++) {
121 for (size_t j = 0; j < n; j++) {
122 A(i,j) = rand.Uniform(-range, range);
123 }
124 }
125}
126
127//______________________________________________________________________________
/// Overwrites A with a (possibly rectangular) identity pattern. Every element is
/// set to 0, then A(i,i) = 1 for each row i that also has a column i (i < ncols).
/// NOTE(review): the signature line (listing line 129) was lost in extraction; it
/// presumably declares InitializeIdentity(TMatrixT<Real_t> & A) — restore from the source file.
128template<typename Real_t>
130{
131 size_t m,n;
132 m = A.GetNrows();
133 n = A.GetNcols();
134
135 for (size_t i = 0; i < m; i++) {
136 for (size_t j = 0; j < n ; j++) {
137 A(i,j) = 0.0;
138 }
139
140 if (i < n) { // guard: when rows > cols, rows past the last column have no diagonal entry
141 A(i,i) = 1.0;
142 }
143 }
144}
145
146//______________________________________________________________________________
/// Sets every element of A to 0.
/// NOTE(review): the signature line (listing line 148) was lost in extraction; it
/// presumably declares InitializeZero(TMatrixT<Real_t> & A) — restore from the source file.
147template<typename Real_t>
149{
150 size_t m,n;
151 m = A.GetNrows();
152 n = A.GetNcols();
153
154 for (size_t i = 0; i < m; i++) {
155 for (size_t j = 0; j < n ; j++) {
156 A(i,j) = 0.0;
157 }
158 }
159}
160
161
162} // namespace DNN
163} // namespace TMVA
float Real_t
Definition RtypesCore.h:68
double sqrt(double)
static void InitializeIdentity(TMatrixT< AReal > &A)
static void InitializeGlorotNormal(TMatrixT< AReal > &A)
Truncated normal initialization (Glorot, also called Xavier normal). The values are sampled from a normal distribution...
static void SetRandomSeed(size_t seed)
static void InitializeZero(TMatrixT< AReal > &A)
static TRandom * fgRandomGen
Definition Reference.h:55
static void InitializeGauss(TMatrixT< AReal > &A)
static void InitializeGlorotUniform(TMatrixT< AReal > &A)
Sample from a uniform distribution in the range [-lim, +lim], where lim = sqrt(6/(N_in + N_out)).
static TRandom & GetRandomGenerator()
static void InitializeUniform(TMatrixT< AReal > &A)
Int_t GetNrows() const
Int_t GetNcols() const
TMatrixT.
Definition TMatrixT.h:39
Random number generator class based on M.
Definition TRandom3.h:27
This is the base class for the ROOT Random number generators.
Definition TRandom.h:27
virtual Double_t Gaus(Double_t mean=0, Double_t sigma=1)
Samples a random number from the standard Normal (Gaussian) Distribution with the given mean and sigm...
Definition TRandom.cxx:274
virtual Double_t Uniform(Double_t x1=1)
Returns a uniform deviate on the interval (0, x1).
Definition TRandom.cxx:672
const Double_t sigma
const Int_t n
Definition legend1.C:16
create variable transformations
auto * m
Definition textangle.C:8