Logo ROOT   6.16/01
Reference Guide
ReshapeLayer.h
Go to the documentation of this file.
1// @(#)root/tmva/tmva/dnn:$Id$
2// Author: Vladimir Ilievski
3
4/**********************************************************************************
5 * Project: TMVA - a Root-integrated toolkit for multivariate data analysis *
6 * Package: TMVA *
7 * Class : TReshapeLayer *
8 * Web : http://tmva.sourceforge.net *
9 * *
10 * Description: *
11 * Reshape Deep Neural Network Layer *
12 * *
13 * Authors (alphabetical): *
14 * Vladimir Ilievski <ilievski.vladimir@live.com> - CERN, Switzerland *
15 * *
16 * Copyright (c) 2005-2015: *
17 * CERN, Switzerland *
18 * U. of Victoria, Canada *
19 * MPI-K Heidelberg, Germany *
20 * U. of Bonn, Germany *
21 * *
22 * Redistribution and use in source and binary forms, with or without *
23 * modification, are permitted according to the terms listed in LICENSE *
24 * (http://tmva.sourceforge.net/LICENSE) *
25 **********************************************************************************/
26
27#ifndef TMVA_DNN_RESHAPELAYER
28#define TMVA_DNN_RESHAPELAYER
29
#include "TMatrix.h"

#include "TMVA/DNN/GeneralLayer.h"
#include "TMVA/DNN/Functions.h"
#include "TMVA/Tools.h"

#include <iostream>
#include <vector>
36
37namespace TMVA {
38namespace DNN {
39
40template <typename Architecture_t>
41class TReshapeLayer : public VGeneralLayer<Architecture_t> {
42public:
43 using Matrix_t = typename Architecture_t::Matrix_t;
44 using Scalar_t = typename Architecture_t::Scalar_t;
45
46private:
47 bool fFlattening; ///< Whather the layer is doing flattening
48
49public:
50 /*! Constructor */
51 TReshapeLayer(size_t BatchSize, size_t InputDepth, size_t InputHeight, size_t InputWidth, size_t Depth,
52 size_t Height, size_t Width, size_t OutputNSlices, size_t OutputNRows, size_t OutputNCols,
53 bool Flattening);
54
55 /*! Copy the reshape layer provided as a pointer */
57
58 /*! Copy Constructor */
60
61 /*! Destructor. */
63
64 /*! The input must be in 3D tensor form with the different matrices
65 * corresponding to different events in the batch. It transforms the
66 * input matrices. */
67 void Forward(std::vector<Matrix_t> &input, bool applyDropout = false);
68
69 void Backward(std::vector<Matrix_t> &gradients_backward, const std::vector<Matrix_t> &activations_backward,
70 std::vector<Matrix_t> &inp1, std::vector<Matrix_t> &inp2);
71
72 /*! Prints the info about the layer. */
73 void Print() const;
74
75 /*! Writes the information and the weights about the layer in an XML node. */
76 virtual void AddWeightsXMLTo(void *parent);
77
78 /*! Read the information and the weights about the layer from XML node. */
79 virtual void ReadWeightsFromXML(void *parent);
80
81
82 /*! TODO Add documentation
83 * Does this layer flatten? (necessary for DenseLayer)
84 * B x D1 x D2 --> 1 x B x (D1 * D2) */
85 bool isFlattening() const { return fFlattening; }
86};
87
88//
89//
90// The Reshape Layer Class - Implementation
91//_________________________________________________________________________________________________
92template <typename Architecture_t>
93TReshapeLayer<Architecture_t>::TReshapeLayer(size_t batchSize, size_t inputDepth, size_t inputHeight, size_t inputWidth,
94 size_t depth, size_t height, size_t width, size_t outputNSlices,
95 size_t outputNRows, size_t outputNCols, bool flattening)
96 : VGeneralLayer<Architecture_t>(batchSize, inputDepth, inputHeight, inputWidth, depth, height, width, 0, 0, 0, 0, 0,
97 0, outputNSlices, outputNRows, outputNCols, EInitialization::kZero),
98 fFlattening(flattening)
99{
100 if (this->GetInputDepth() * this->GetInputHeight() * this->GetInputWidth() !=
101 this->GetDepth() * this->GetHeight() * this->GetWidth()) {
102 std::cout << "Reshape Dimensions not compatible \n"
103 << this->GetInputDepth() << " x " << this->GetInputHeight() << " x " << this->GetInputWidth() << " --> "
104 << this->GetDepth() << " x " << this->GetHeight() << " x " << this->GetWidth() << std::endl;
105 return;
106 }
107}
108
109//_________________________________________________________________________________________________
110template <typename Architecture_t>
112 : VGeneralLayer<Architecture_t>(layer), fFlattening(layer->isFlattening())
113{
114}
115
116//_________________________________________________________________________________________________
117template <typename Architecture_t>
119 : VGeneralLayer<Architecture_t>(layer), fFlattening(layer.fFlattening)
120{
121 // Nothing to do here.
122}
123
124//_________________________________________________________________________________________________
125template <typename Architecture_t>
127{
128 // Nothing to do here.
129}
130
131//_________________________________________________________________________________________________
132template <typename Architecture_t>
133auto TReshapeLayer<Architecture_t>::Forward(std::vector<Matrix_t> &input, bool /*applyDropout*/) -> void
134{
135 if (fFlattening) {
136 size_t size = input.size();
137 size_t nRows = input[0].GetNrows();
138 size_t nCols = input[0].GetNcols();
139 Architecture_t::Flatten(this->GetOutputAt(0), input, size, nRows, nCols);
140 } else {
141 for (size_t i = 0; i < this->GetBatchSize(); i++) {
142 Architecture_t::Reshape(this->GetOutputAt(i), input[i]);
143 }
144 }
145}
146
147//_________________________________________________________________________________________________
148template <typename Architecture_t>
149auto TReshapeLayer<Architecture_t>::Backward(std::vector<Matrix_t> &gradients_backward,
150 const std::vector<Matrix_t> & /*activations_backward*/,
151 std::vector<Matrix_t> & /*inp1*/, std::vector<Matrix_t> &
152 /*inp2*/) -> void
153{
154 // in case of first layer size is zero - do nothing
155 if (gradients_backward.size() == 0) return;
156 if (fFlattening) {
157 size_t size = gradients_backward.size();
158 size_t nRows = gradients_backward[0].GetNrows();
159 size_t nCols = gradients_backward[0].GetNcols();
160 Architecture_t::Deflatten(gradients_backward, this->GetActivationGradientsAt(0), size, nRows, nCols);
161 } else {
162 for (size_t i = 0; i < this->GetBatchSize(); i++) {
163 Architecture_t::Reshape(gradients_backward[i], this->GetActivationGradientsAt(i));
164 }
165 }
166}
167
168//_________________________________________________________________________________________________
169template <typename Architecture_t>
171{
172 std::cout << " RESHAPE Layer \t ";
173 std::cout << "Input = ( " << this->GetInputDepth() << " , " << this->GetInputHeight() << " , " << this->GetInputWidth() << " ) ";
174 if (this->GetOutput().size() > 0) {
175 std::cout << "\tOutput = ( " << this->GetOutput().size() << " , " << this->GetOutput()[0].GetNrows() << " , " << this->GetOutput()[0].GetNcols() << " ) ";
176 }
177 std::cout << std::endl;
178}
179
180template <typename Architecture_t>
182{
183 auto layerxml = gTools().xmlengine().NewChild(parent, 0, "ReshapeLayer");
184
185 // write info for reshapelayer
186 gTools().xmlengine().NewAttr(layerxml, 0, "Depth", gTools().StringFromInt(this->GetDepth()));
187 gTools().xmlengine().NewAttr(layerxml, 0, "Height", gTools().StringFromInt(this->GetHeight()));
188 gTools().xmlengine().NewAttr(layerxml, 0, "Width", gTools().StringFromInt(this->GetWidth()));
189 gTools().xmlengine().NewAttr(layerxml, 0, "Flattening", gTools().StringFromInt(this->isFlattening()));
190
191
192}
193
194//______________________________________________________________________________
195template <typename Architecture_t>
197{
198 // no info to read
199}
200
201
202
203} // namespace DNN
204} // namespace TMVA
205
206#endif
include TDocParser_001 C image html pict1_TDocParser_001 png width
Definition: TDocParser.cxx:121
TReshapeLayer(size_t BatchSize, size_t InputDepth, size_t InputHeight, size_t InputWidth, size_t Depth, size_t Height, size_t Width, size_t OutputNSlices, size_t OutputNRows, size_t OutputNCols, bool Flattening)
Constructor.
Definition: ReshapeLayer.h:93
void Backward(std::vector< Matrix_t > &gradients_backward, const std::vector< Matrix_t > &activations_backward, std::vector< Matrix_t > &inp1, std::vector< Matrix_t > &inp2)
Backpropagates the error.
Definition: ReshapeLayer.h:149
virtual void AddWeightsXMLTo(void *parent)
Writes the information and the weights about the layer in an XML node.
Definition: ReshapeLayer.h:181
bool isFlattening() const
TODO Add documentation Does this layer flatten? (necessary for DenseLayer) B x D1 x D2 --> 1 x B x (D...
Definition: ReshapeLayer.h:85
virtual void ReadWeightsFromXML(void *parent)
Read the information and the weights about the layer from XML node.
Definition: ReshapeLayer.h:196
bool fFlattening
Whether the layer is doing flattening.
Definition: ReshapeLayer.h:47
void Forward(std::vector< Matrix_t > &input, bool applyDropout=false)
The input must be in 3D tensor form with the different matrices corresponding to different events in ...
Definition: ReshapeLayer.h:133
void Print() const
Prints the info about the layer.
Definition: ReshapeLayer.h:170
~TReshapeLayer()
Destructor.
Definition: ReshapeLayer.h:126
Generic General Layer class.
Definition: GeneralLayer.h:46
typename Architecture_t::Matrix_t Matrix_t
Definition: GeneralLayer.h:47
size_t GetDepth() const
Definition: GeneralLayer.h:145
typename Architecture_t::Scalar_t Scalar_t
Definition: GeneralLayer.h:48
size_t GetInputDepth() const
Definition: GeneralLayer.h:142
size_t GetInputHeight() const
Definition: GeneralLayer.h:143
size_t GetWidth() const
Definition: GeneralLayer.h:147
size_t GetHeight() const
Definition: GeneralLayer.h:146
size_t GetInputWidth() const
Definition: GeneralLayer.h:144
TXMLEngine & xmlengine()
Definition: Tools.h:270
XMLAttrPointer_t NewAttr(XMLNodePointer_t xmlnode, XMLNsPointer_t, const char *name, const char *value)
creates new attribute for xmlnode, namespaces are not supported for attributes
Definition: TXMLEngine.cxx:578
XMLNodePointer_t NewChild(XMLNodePointer_t parent, XMLNsPointer_t ns, const char *name, const char *content=0)
create new child element for parent node
Definition: TXMLEngine.cxx:707
EInitialization
Definition: Functions.h:70
UInt_t Depth(const Node< T > *node)
Definition: NodekNN.h:213
Abstract ClassifierFactory template that handles arbitrary types.
Tools & gTools()