Logo ROOT   6.14/05
Reference Guide
Initialization.cxx
Go to the documentation of this file.
1 // @(#)root/tmva/tmva/dnn:$Id$
2 // Author: Simon Pfreundschuh 10/07/16
3 
4 /*************************************************************************
5  * Copyright (C) 2016, Simon Pfreundschuh *
6  * All rights reserved. *
7  * *
8  * For the licensing terms see $ROOTSYS/LICENSE. *
9  * For the list of contributors see $ROOTSYS/README/CREDITS. *
10  *************************************************************************/
11 
12  //////////////////////////////////////////////////////////////////////
13  // Implementation of the initialization functions for the reference //
14  // implementation. //
15  //////////////////////////////////////////////////////////////////////
16 
17 #include "TRandom3.h"
19 
20 namespace TMVA
21 {
22 namespace DNN
23 {
24 
25 template <typename Real_t>
27 //______________________________________________________________________________
28 template<typename Real_t>
30 {
31  if (!fgRandomGen) fgRandomGen = new TRandom3();
32  fgRandomGen->SetSeed(seed);
33 }
34 template<typename Real_t>
36 {
37  if (!fgRandomGen) fgRandomGen = new TRandom3(0);
38  return *fgRandomGen;
39 }
40 
41 //______________________________________________________________________________
42 template<typename Real_t>
44 {
45  size_t m,n;
46  m = A.GetNrows();
47  n = A.GetNcols();
48 
49  TRandom & rand = GetRandomGenerator();
50 
51  Real_t sigma = sqrt(2.0 / ((Real_t) n));
52 
53  for (size_t i = 0; i < m; i++) {
54  for (size_t j = 0; j < n; j++) {
55  A(i,j) = rand.Gaus(0.0, sigma);
56  }
57  }
58 }
59 
60 //______________________________________________________________________________
61 template<typename Real_t>
63 {
64  size_t m,n;
65  m = A.GetNrows();
66  n = A.GetNcols();
67 
68  TRandom & rand = GetRandomGenerator();
69 
70  Real_t range = sqrt(2.0 / ((Real_t) n));
71 
72  for (size_t i = 0; i < m; i++) {
73  for (size_t j = 0; j < n; j++) {
74  A(i,j) = rand.Uniform(-range, range);
75  }
76  }
77 }
78 
79  //______________________________________________________________________________
80 /// Truncated normal initialization (Glorot, called also Xavier normal)
81 /// The values are sample with a normal distribution with stddev = sqrt(2/N_input + N_output) and
82 /// values larger than 2 * stddev are discarded
83 /// See Glorot & Bengio, AISTATS 2010 - http://jmlr.org/proceedings/papers/v9/glorot10a/glorot10a.pdf
84 template<typename Real_t>
86 {
87  size_t m,n;
88  m = A.GetNrows();
89  n = A.GetNcols();
90 
91  TRandom & rand = GetRandomGenerator();
92 
93  Real_t sigma = sqrt(2.0 /( ((Real_t) n) + ((Real_t) m)) );
94 
95  for (size_t i = 0; i < m; i++) {
96  for (size_t j = 0; j < n; j++) {
97  Real_t value = rand.Gaus(0.0, sigma);
98  if ( std::abs(value) > 2*sigma) continue;
99  A(i,j) = rand.Gaus(0.0, sigma);
100  }
101  }
102 }
103 
104 //______________________________________________________________________________
105 /// Sample from a uniform distribution in range [ -lim,+lim] where
106 /// lim = sqrt(6/N_in+N_out).
107 /// This initialization is also called Xavier uniform
108 /// see Glorot & Bengio, AISTATS 2010 - http://jmlr.org/proceedings/papers/v9/glorot10a/glorot10a.pdf
109 template<typename Real_t>
111 {
112  size_t m,n;
113  m = A.GetNrows();
114  n = A.GetNcols();
115 
116  TRandom & rand = GetRandomGenerator();
117 
118  Real_t range = sqrt(6.0 /( ((Real_t) n) + ((Real_t) m)) );
119 
120  for (size_t i = 0; i < m; i++) {
121  for (size_t j = 0; j < n; j++) {
122  A(i,j) = rand.Uniform(-range, range);
123  }
124  }
125 }
126 
127 //______________________________________________________________________________
128 template<typename Real_t>
130 {
131  size_t m,n;
132  m = A.GetNrows();
133  n = A.GetNcols();
134 
135  for (size_t i = 0; i < m; i++) {
136  for (size_t j = 0; j < n ; j++) {
137  A(i,j) = 0.0;
138  }
139 
140  if (i < n) {
141  A(i,i) = 1.0;
142  }
143  }
144 }
145 
146 //______________________________________________________________________________
147 template<typename Real_t>
149 {
150  size_t m,n;
151  m = A.GetNrows();
152  n = A.GetNcols();
153 
154  for (size_t i = 0; i < m; i++) {
155  for (size_t j = 0; j < n ; j++) {
156  A(i,j) = 0.0;
157  }
158  }
159 }
160 
161 
162 } // namespace DNN
163 } // namespace TMVA
static void SetRandomSeed(size_t seed)
Random number generator class based on M. Matsumoto and T. Nishimura's Mersenne Twister generator.
Definition: TRandom3.h:27
static void InitializeGlorotUniform(TMatrixT< AReal > &A)
Sample from a uniform distribution in range [-lim, +lim] where lim = sqrt(6/(N_in+N_out)).
auto * m
Definition: textangle.C:8
virtual Double_t Gaus(Double_t mean=0, Double_t sigma=1)
Samples a random number from the standard Normal (Gaussian) Distribution with the given mean and sigm...
Definition: TRandom.cxx:256
Int_t GetNcols() const
Definition: TMatrixTBase.h:125
static double A[]
TMatrixT.
Definition: TMatrixDfwd.h:22
double sqrt(double)
This is the base class for the ROOT Random number generators.
Definition: TRandom.h:27
static TRandom & GetRandomGenerator()
const Double_t sigma
static void InitializeUniform(TMatrixT< AReal > &A)
static void InitializeGlorotNormal(TMatrixT< AReal > &A)
Truncated normal initialization (Glorot, also called Xavier normal). The values are sampled from a norm...
static void InitializeZero(TMatrixT< AReal > &A)
Int_t GetNrows() const
Definition: TMatrixTBase.h:122
static TRandom * fgRandomGen
Definition: Reference.h:46
virtual Double_t Uniform(Double_t x1=1)
Returns a uniform deviate on the interval (0, x1).
Definition: TRandom.cxx:627
static void InitializeIdentity(TMatrixT< AReal > &A)
float Real_t
Definition: RtypesCore.h:64
Abstract ClassifierFactory template that handles arbitrary types.
static void InitializeGauss(TMatrixT< AReal > &A)
const Int_t n
Definition: legend1.C:16