27#ifndef TMVA_DNN_RESHAPELAYER
28#define TMVA_DNN_RESHAPELAYER
40template <
typename Architecture_t>
43 using Tensor_t =
typename Architecture_t::Tensor_t;
44 using Matrix_t =
typename Architecture_t::Matrix_t;
45 using Scalar_t =
typename Architecture_t::Scalar_t;
93template <
typename Architecture_t>
97 :
VGeneralLayer<Architecture_t>(batchSize, inputDepth, inputHeight, inputWidth,
depth,
height,
width, 0, 0, 0, 0, 0,
103 std::cout <<
"Reshape Dimensions not compatible \n"
111template <
typename Architecture_t>
118template <
typename Architecture_t>
126template <
typename Architecture_t>
133template <
typename Architecture_t>
138 Architecture_t::Flatten(this->GetOutput(),
input);
143 Architecture_t::Deflatten(this->GetOutput(),
input);
148template <
typename Architecture_t>
156 if (
size == 0)
return;
168template <
typename Architecture_t>
171 std::cout <<
" RESHAPE Layer \t ";
172 std::cout <<
"Input = ( " << this->GetInputDepth() <<
" , " << this->GetInputHeight() <<
" , " << this->GetInputWidth() <<
" ) ";
173 if (this->GetOutput().GetSize() > 0) {
174 std::cout <<
"\tOutput = ( " << this->GetOutput().GetFirstSize() <<
" , " << this->GetOutput().GetHSize() <<
" , " << this->GetOutput().GetWSize() <<
" ) ";
176 std::cout << std::endl;
179template <
typename Architecture_t>
194template <
typename Architecture_t>
size_t size(const MatrixT &matrix)
retrieve the size of a square matrix
Option_t Option_t TPoint TPoint const char GetTextMagnitude GetFillStyle GetLineColor GetLineWidth GetMarkerStyle GetTextAlign GetTextColor GetTextSize void input
Option_t Option_t TPoint TPoint const char GetTextMagnitude GetFillStyle GetLineColor GetLineWidth GetMarkerStyle GetTextAlign GetTextColor GetTextSize void char Point_t Rectangle_t height
typename Architecture_t::Scalar_t Scalar_t
typename Architecture_t::Tensor_t Tensor_t
TReshapeLayer(size_t BatchSize, size_t InputDepth, size_t InputHeight, size_t InputWidth, size_t Depth, size_t Height, size_t Width, size_t OutputNSlices, size_t OutputNRows, size_t OutputNCols, bool Flattening)
Constructor.
typename Architecture_t::Matrix_t Matrix_t
virtual void AddWeightsXMLTo(void *parent)
Writes the information and the weights about the layer in an XML node.
bool isFlattening() const
TODO Add documentation Does this layer flatten? (necessary for DenseLayer) B x D1 x D2 --> 1 x B x (D...
virtual void ReadWeightsFromXML(void *parent)
Read the information and the weights about the layer from XML node.
bool fFlattening
Whether the layer is doing flattening.
void Backward(Tensor_t &gradients_backward, const Tensor_t &activations_backward)
Backpropagates the error.
void Forward(Tensor_t &input, bool applyDropout=false)
The input must be in 3D tensor form with the different matrices corresponding to different events in ...
void Print() const
Prints the info about the layer.
~TReshapeLayer()
Destructor.
Generic General Layer class.
size_t GetInputDepth() const
size_t GetInputHeight() const
size_t GetInputWidth() const
XMLNodePointer_t NewChild(XMLNodePointer_t parent, XMLNsPointer_t ns, const char *name, const char *content=nullptr)
create new child element for parent node
XMLAttrPointer_t NewAttr(XMLNodePointer_t xmlnode, XMLNsPointer_t, const char *name, const char *value)
creates new attribute for xmlnode, namespaces are not supported for attributes
create variable transformations