27#ifndef TMVA_DNN_RESHAPELAYER
28#define TMVA_DNN_RESHAPELAYER
40template <
typename Architecture_t>
43 using Tensor_t =
typename Architecture_t::Tensor_t;
44 using Matrix_t =
typename Architecture_t::Matrix_t;
45 using Scalar_t =
typename Architecture_t::Scalar_t;
74 void Print()
const override;
93template <
typename Architecture_t>
97 :
VGeneralLayer<Architecture_t>(batchSize, inputDepth, inputHeight, inputWidth,
depth,
height,
width, 0, 0, 0, 0, 0,
103 std::cout <<
"Reshape Dimensions not compatible \n"
111template <
typename Architecture_t>
118template <
typename Architecture_t>
126template <
typename Architecture_t>
133template <
typename Architecture_t>
148template <
typename Architecture_t>
156 if (
size == 0)
return;
168template <
typename Architecture_t>
171 std::cout <<
" RESHAPE Layer \t ";
172 std::cout <<
"Input = ( " << this->GetInputDepth() <<
" , " << this->GetInputHeight() <<
" , " << this->GetInputWidth() <<
" ) ";
174 std::cout <<
"\tOutput = ( " << this->
GetOutput().GetFirstSize() <<
" , " << this->
GetOutput().GetHSize() <<
" , " << this->
GetOutput().GetWSize() <<
" ) ";
176 std::cout << std::endl;
179template <
typename Architecture_t>
194template <
typename Architecture_t>
size_t size(const MatrixT &matrix)
retrieve the size of a square matrix
Option_t Option_t TPoint TPoint const char GetTextMagnitude GetFillStyle GetLineColor GetLineWidth GetMarkerStyle GetTextAlign GetTextColor GetTextSize void input
Option_t Option_t TPoint TPoint const char GetTextMagnitude GetFillStyle GetLineColor GetLineWidth GetMarkerStyle GetTextAlign GetTextColor GetTextSize void char Point_t Rectangle_t height
TObject * GetOutput(const char *name)
typename Architecture_t::Scalar_t Scalar_t
typename Architecture_t::Tensor_t Tensor_t
void ReadWeightsFromXML(void *parent) override
Read the information and the weights about the layer from XML node.
TReshapeLayer(size_t BatchSize, size_t InputDepth, size_t InputHeight, size_t InputWidth, size_t Depth, size_t Height, size_t Width, size_t OutputNSlices, size_t OutputNRows, size_t OutputNCols, bool Flattening)
Constructor.
typename Architecture_t::Matrix_t Matrix_t
void Forward(Tensor_t &input, bool applyDropout=false) override
The input must be in 3D tensor form with the different matrices corresponding to different events in ...
bool isFlattening() const
TODO Add documentation Does this layer flatten? (necessary for DenseLayer) B x D1 x D2 --> 1 x B x (D...
bool fFlattening
Whether the layer is doing flattening.
void AddWeightsXMLTo(void *parent) override
Writes the information and the weights about the layer in an XML node.
void Backward(Tensor_t &gradients_backward, const Tensor_t &activations_backward) override
Backpropagates the error.
void Print() const override
Prints the info about the layer.
~TReshapeLayer()
Destructor.
Generic General Layer class.
size_t GetInputDepth() const
size_t GetInputHeight() const
size_t GetInputWidth() const
XMLNodePointer_t NewChild(XMLNodePointer_t parent, XMLNsPointer_t ns, const char *name, const char *content=nullptr)
create new child element for parent node
XMLAttrPointer_t NewAttr(XMLNodePointer_t xmlnode, XMLNsPointer_t, const char *name, const char *value)
creates new attribute for xmlnode, namespaces are not supported for attributes
create variable transformations