27#ifndef TMVA_DNN_RESHAPELAYER
28#define TMVA_DNN_RESHAPELAYER
40template <
typename Architecture_t>
43 using Tensor_t =
typename Architecture_t::Tensor_t;
44 using Matrix_t =
typename Architecture_t::Matrix_t;
45 using Scalar_t =
typename Architecture_t::Scalar_t;
52 TReshapeLayer(
size_t BatchSize,
size_t InputDepth,
size_t InputHeight,
size_t InputWidth,
size_t Depth,
53 size_t Height,
size_t Width,
size_t OutputNSlices,
size_t OutputNRows,
size_t OutputNCols,
93template <
typename Architecture_t>
95 size_t depth,
size_t height,
size_t width,
size_t outputNSlices,
96 size_t outputNRows,
size_t outputNCols,
bool flattening)
97 :
VGeneralLayer<Architecture_t>(batchSize, inputDepth, inputHeight, inputWidth, depth, height,
width, 0, 0, 0, 0, 0,
99 fFlattening(flattening)
103 std::cout <<
"Reshape Dimensions not compatible \n"
111template <
typename Architecture_t>
113 :
VGeneralLayer<Architecture_t>(layer), fFlattening(layer->isFlattening())
118template <
typename Architecture_t>
120 :
VGeneralLayer<Architecture_t>(layer), fFlattening(layer.fFlattening)
126template <
typename Architecture_t>
133template <
typename Architecture_t>
138 Architecture_t::Flatten(this->GetOutput(), input);
143 Architecture_t::Deflatten(this->GetOutput(), input);
148template <
typename Architecture_t>
154 size_t size = gradients_backward.GetSize();
156 if (size == 0)
return;
159 Architecture_t::Deflatten(gradients_backward, this->GetActivationGradients());
162 Architecture_t::Flatten(gradients_backward, this->GetActivationGradients() );
168template <
typename Architecture_t>
171 std::cout <<
" RESHAPE Layer \t ";
172 std::cout <<
"Input = ( " << this->GetInputDepth() <<
" , " << this->GetInputHeight() <<
" , " << this->GetInputWidth() <<
" ) ";
173 if (this->GetOutput().GetSize() > 0) {
174 std::cout <<
"\tOutput = ( " << this->GetOutput().GetFirstSize() <<
" , " << this->GetOutput().GetHSize() <<
" , " << this->GetOutput().GetWSize() <<
" ) ";
176 std::cout << std::endl;
179template <
typename Architecture_t>
194template <
typename Architecture_t>
include TDocParser_001 C image html pict1_TDocParser_001 png width
TReshapeLayer(size_t BatchSize, size_t InputDepth, size_t InputHeight, size_t InputWidth, size_t Depth, size_t Height, size_t Width, size_t OutputNSlices, size_t OutputNRows, size_t OutputNCols, bool Flattening)
Constructor.
virtual void AddWeightsXMLTo(void *parent)
Writes the information and the weights about the layer in an XML node.
bool isFlattening() const
TODO: Add documentation. Does this layer flatten? (Necessary for DenseLayer.) B x D1 x D2 --> 1 x B x (D...
virtual void ReadWeightsFromXML(void *parent)
Read the information and the weights about the layer from XML node.
bool fFlattening
Whether the layer is doing flattening.
void Backward(Tensor_t &gradients_backward, const Tensor_t &activations_backward)
Backpropagates the error.
void Forward(Tensor_t &input, bool applyDropout=false)
The input must be in 3D tensor form with the different matrices corresponding to different events in ...
void Print() const
Prints the info about the layer.
~TReshapeLayer()
Destructor.
Generic General Layer class.
typename Architecture_t::Matrix_t Matrix_t
typename Architecture_t::Scalar_t Scalar_t
size_t GetInputDepth() const
size_t GetInputHeight() const
typename Architecture_t::Tensor_t Tensor_t
size_t GetInputWidth() const
XMLNodePointer_t NewChild(XMLNodePointer_t parent, XMLNsPointer_t ns, const char *name, const char *content=nullptr)
create new child element for parent node
XMLAttrPointer_t NewAttr(XMLNodePointer_t xmlnode, XMLNsPointer_t, const char *name, const char *value)
Creates a new attribute for xmlnode; namespaces are not supported for attributes.
UInt_t Depth(const Node< T > *node)
create variable transformations