Logo ROOT   6.14/05
Reference Guide
ReshapeLayer.h
Go to the documentation of this file.
1 // @(#)root/tmva/tmva/dnn:$Id$
2 // Author: Vladimir Ilievski
3 
4 /**********************************************************************************
5  * Project: TMVA - a Root-integrated toolkit for multivariate data analysis *
6  * Package: TMVA *
7  * Class : TReshapeLayer *
8  * Web : http://tmva.sourceforge.net *
9  * *
10  * Description: *
11  * Reshape Deep Neural Network Layer *
12  * *
13  * Authors (alphabetical): *
14  * Vladimir Ilievski <ilievski.vladimir@live.com> - CERN, Switzerland *
15  * *
16  * Copyright (c) 2005-2015: *
17  * CERN, Switzerland *
18  * U. of Victoria, Canada *
19  * MPI-K Heidelberg, Germany *
20  * U. of Bonn, Germany *
21  * *
22  * Redistribution and use in source and binary forms, with or without *
23  * modification, are permitted according to the terms listed in LICENSE *
24  * (http://tmva.sourceforge.net/LICENSE) *
25  **********************************************************************************/
26 
27 #ifndef TMVA_DNN_RESHAPELAYER
28 #define TMVA_DNN_RESHAPELAYER
29 
30 #include "TMatrix.h"
31 
32 #include "TMVA/DNN/GeneralLayer.h"
33 #include "TMVA/DNN/Functions.h"
34 
35 #include <iostream>
36 
37 namespace TMVA {
38 namespace DNN {
39 
/** Reshape layer of the TMVA deep-learning module.
 *  In non-flattening mode every matrix of the batch is reshaped individually
 *  (Architecture_t::Reshape); in flattening mode the whole batch is flattened
 *  into the first output matrix (Architecture_t::Flatten). The constructor
 *  reports on std::cout when the input and output shapes hold a different
 *  number of elements. */
40 template <typename Architecture_t>
41 class TReshapeLayer : public VGeneralLayer<Architecture_t> {
42 public:
43  using Matrix_t = typename Architecture_t::Matrix_t;
44  using Scalar_t = typename Architecture_t::Scalar_t;
45 
46 private:
47  bool fFlattening; ///< Whether the layer flattens the whole batch (true) or reshapes each matrix (false)
48 
49 public:
50  /*! Constructor. Depth x Height x Width must contain as many elements as
   *  InputDepth x InputHeight x InputWidth; a mismatch is only reported on std::cout. */
51  TReshapeLayer(size_t BatchSize, size_t InputDepth, size_t InputHeight, size_t InputWidth, size_t Depth,
52  size_t Height, size_t Width, size_t OutputNSlices, size_t OutputNRows, size_t OutputNCols,
53  bool Flattening);
54 
55  /*! Copy the reshape layer provided as a pointer */
57 
58  /*! Copy Constructor */
60 
61  /*! Destructor. */
63 
64  /*! The input must be in 3D tensor form with the different matrices
65  * corresponding to different events in the batch. It transforms the
66  * input matrices. The applyDropout argument is ignored by this layer. */
67  void Forward(std::vector<Matrix_t> &input, bool applyDropout = false);
68 
   /*! Propagates this layer's activation gradients back into gradients_backward by
    *  undoing the reshape (per matrix) or the flattening (whole batch). Does nothing
    *  if gradients_backward is empty; activations_backward, inp1 and inp2 are unused. */
69  void Backward(std::vector<Matrix_t> &gradients_backward, const std::vector<Matrix_t> &activations_backward,
70  std::vector<Matrix_t> &inp1, std::vector<Matrix_t> &inp2);
71 
72  /*! Prints the info about the layer. */
73  void Print() const;
74 
75  /*! Writes the information and the weights about the layer in an XML node. */
76  virtual void AddWeightsXMLTo(void *parent);
77 
78  /*! Read the information and the weights about the layer from XML node. */
79  virtual void ReadWeightsFromXML(void *parent);
80 
81 
82  /*! Whether this layer flattens its input.
83  * Does this layer flatten? (necessary for DenseLayer)
84  * B x D1 x D2 --> 1 x B x (D1 * D2) */
85  bool isFlattening() const { return fFlattening; }
86 };
87 
88 //
89 //
90 // The Reshape Layer Class - Implementation
91 //_________________________________________________________________________________________________
92 template <typename Architecture_t>
93 TReshapeLayer<Architecture_t>::TReshapeLayer(size_t batchSize, size_t inputDepth, size_t inputHeight, size_t inputWidth,
94  size_t depth, size_t height, size_t width, size_t outputNSlices,
95  size_t outputNRows, size_t outputNCols, bool flattening)
96  : VGeneralLayer<Architecture_t>(batchSize, inputDepth, inputHeight, inputWidth, depth, height, width, 0, 0, 0, 0, 0,
97  0, outputNSlices, outputNRows, outputNCols, EInitialization::kZero),
98  fFlattening(flattening)
99 {
100  if (this->GetInputDepth() * this->GetInputHeight() * this->GetInputWidth() !=
101  this->GetDepth() * this->GetHeight() * this->GetWidth()) {
102  std::cout << "Reshape Dimensions not compatible \n"
103  << this->GetInputDepth() << " x " << this->GetInputHeight() << " x " << this->GetInputWidth() << " --> "
104  << this->GetDepth() << " x " << this->GetHeight() << " x " << this->GetWidth() << std::endl;
105  return;
106  }
107 }
108 
109 //_________________________________________________________________________________________________
// Copy from a layer given by pointer: the base-class part is built from `layer`
// and the flattening flag is taken from layer->isFlattening(). No other state
// needs copying.
110 template <typename Architecture_t>
112  : VGeneralLayer<Architecture_t>(layer), fFlattening(layer->isFlattening())
113 {
114 }
115 
116 //_________________________________________________________________________________________________
// Copy constructor: copies the base-class part from `layer` and its
// fFlattening flag; fFlattening is the only member this class adds.
117 template <typename Architecture_t>
119  : VGeneralLayer<Architecture_t>(layer), fFlattening(layer.fFlattening)
120 {
121  // Nothing to do here.
122 }
123 
124 //_________________________________________________________________________________________________
// Destructor: TReshapeLayer only adds a bool member (fFlattening), so there is
// nothing to release here beyond what the base class handles.
125 template <typename Architecture_t>
127 {
128  // Nothing to do here.
129 }
130 
131 //_________________________________________________________________________________________________
132 template <typename Architecture_t>
133 auto TReshapeLayer<Architecture_t>::Forward(std::vector<Matrix_t> &input, bool /*applyDropout*/) -> void
134 {
135  if (fFlattening) {
136  size_t size = input.size();
137  size_t nRows = input[0].GetNrows();
138  size_t nCols = input[0].GetNcols();
139  Architecture_t::Flatten(this->GetOutputAt(0), input, size, nRows, nCols);
140  } else {
141  for (size_t i = 0; i < this->GetBatchSize(); i++) {
142  Architecture_t::Reshape(this->GetOutputAt(i), input[i]);
143  }
144  }
145 }
146 
147 //_________________________________________________________________________________________________
148 template <typename Architecture_t>
149 auto TReshapeLayer<Architecture_t>::Backward(std::vector<Matrix_t> &gradients_backward,
150  const std::vector<Matrix_t> & /*activations_backward*/,
151  std::vector<Matrix_t> & /*inp1*/, std::vector<Matrix_t> &
152  /*inp2*/) -> void
153 {
154  // in case of first layer size is zero - do nothing
155  if (gradients_backward.size() == 0) return;
156  if (fFlattening) {
157  size_t size = gradients_backward.size();
158  size_t nRows = gradients_backward[0].GetNrows();
159  size_t nCols = gradients_backward[0].GetNcols();
160  Architecture_t::Deflatten(gradients_backward, this->GetActivationGradientsAt(0), size, nRows, nCols);
161  } else {
162  for (size_t i = 0; i < this->GetBatchSize(); i++) {
163  Architecture_t::Reshape(gradients_backward[i], this->GetActivationGradientsAt(i));
164  }
165  }
166 }
167 
168 //_________________________________________________________________________________________________
// Print(): writes a one-line summary of the layer to std::cout -- the input
// shape (depth, height, width) and, when the output vector is non-empty, the
// output shape (number of output matrices, rows and columns of the first one).
169 template <typename Architecture_t>
171 {
172  std::cout << " RESHAPE Layer \t ";
173  std::cout << "Input = ( " << this->GetInputDepth() << " , " << this->GetInputHeight() << " , " << this->GetInputWidth() << " ) ";
174  if (this->GetOutput().size() > 0) {
175  std::cout << "\tOutput = ( " << this->GetOutput().size() << " , " << this->GetOutput()[0].GetNrows() << " , " << this->GetOutput()[0].GetNcols() << " ) ";
176  }
177  std::cout << std::endl;
178 }
179 
// AddWeightsXMLTo(parent): appends a "ReshapeLayer" child node to `parent` and
// stores the output Depth, Height and Width plus the Flattening flag (the bool
// is passed through StringFromInt) as attributes. Nothing else is written.
// NOTE(review): ReadWeightsFromXML in this file does not read these attributes back.
180 template <typename Architecture_t>
182 {
183  auto layerxml = gTools().xmlengine().NewChild(parent, 0, "ReshapeLayer");
184 
185  // write info for reshapelayer
186  gTools().xmlengine().NewAttr(layerxml, 0, "Depth", gTools().StringFromInt(this->GetDepth()));
187  gTools().xmlengine().NewAttr(layerxml, 0, "Height", gTools().StringFromInt(this->GetHeight()));
188  gTools().xmlengine().NewAttr(layerxml, 0, "Width", gTools().StringFromInt(this->GetWidth()));
189  gTools().xmlengine().NewAttr(layerxml, 0, "Flattening", gTools().StringFromInt(this->isFlattening()));
190 
191 
192 }
193 
194 //______________________________________________________________________________
// ReadWeightsFromXML(parent): intentionally a no-op -- nothing is read from
// `parent`, even though AddWeightsXMLTo writes Depth/Height/Width/Flattening.
195 template <typename Architecture_t>
197 {
198  // no info to read
199 }
200 
201 
202 
203 } // namespace DNN
204 } // namespace TMVA
205 
206 #endif
size_t GetDepth() const
Definition: GeneralLayer.h:144
size_t GetInputWidth() const
Definition: GeneralLayer.h:143
bool isFlattening() const
Does this layer flatten? (necessary for DenseLayer) B x D1 x D2 --> 1 x B x (D1 * D2)
Definition: ReshapeLayer.h:85
Generic General Layer class.
Definition: GeneralLayer.h:45
TXMLEngine & xmlengine()
Definition: Tools.h:270
bool fFlattening
Whether the layer is doing flattening.
Definition: ReshapeLayer.h:47
void Backward(std::vector< Matrix_t > &gradients_backward, const std::vector< Matrix_t > &activations_backward, std::vector< Matrix_t > &inp1, std::vector< Matrix_t > &inp2)
Backpropagates the error.
Definition: ReshapeLayer.h:149
image html pict1_TGaxis_012 png width
Define new text attributes for the label number "labNum".
Definition: TGaxis.cxx:2551
void Forward(std::vector< Matrix_t > &input, bool applyDropout=false)
The input must be in 3D tensor form with the different matrices corresponding to different events in the batch. It transforms the input matrices.
Definition: ReshapeLayer.h:133
UInt_t Depth(const Node< T > *node)
Definition: NodekNN.h:213
EInitialization
Definition: Functions.h:70
Matrix_t & GetOutputAt(size_t i)
Definition: GeneralLayer.h:179
size_t GetBatchSize() const
Getters.
Definition: GeneralLayer.h:140
size_t GetHeight() const
Definition: GeneralLayer.h:145
typename Architecture_t::Matrix_t Matrix_t
Definition: GeneralLayer.h:46
size_t GetInputHeight() const
Definition: GeneralLayer.h:142
size_t GetInputDepth() const
Definition: GeneralLayer.h:141
Tools & gTools()
Matrix_t & GetActivationGradientsAt(size_t i)
Definition: GeneralLayer.h:182
XMLAttrPointer_t NewAttr(XMLNodePointer_t xmlnode, XMLNsPointer_t, const char *name, const char *value)
creates new attribute for xmlnode, namespaces are not supported for attributes
Definition: TXMLEngine.cxx:578
void Print() const
Prints the info about the layer.
Definition: ReshapeLayer.h:170
~TReshapeLayer()
Destructor.
Definition: ReshapeLayer.h:126
virtual void ReadWeightsFromXML(void *parent)
Read the information and the weights about the layer from XML node.
Definition: ReshapeLayer.h:196
Abstract ClassifierFactory template that handles arbitrary types.
XMLNodePointer_t NewChild(XMLNodePointer_t parent, XMLNsPointer_t ns, const char *name, const char *content=0)
create new child element for parent node
Definition: TXMLEngine.cxx:707
typename Architecture_t::Scalar_t Scalar_t
Definition: GeneralLayer.h:47
const std::vector< Matrix_t > & GetOutput() const
Definition: GeneralLayer.h:173
virtual void AddWeightsXMLTo(void *parent)
Writes the information and the weights about the layer in an XML node.
Definition: ReshapeLayer.h:181
size_t GetWidth() const
Definition: GeneralLayer.h:146
TReshapeLayer(size_t BatchSize, size_t InputDepth, size_t InputHeight, size_t InputWidth, size_t Depth, size_t Height, size_t Width, size_t OutputNSlices, size_t OutputNRows, size_t OutputNCols, bool Flattening)
Constructor.
Definition: ReshapeLayer.h:93