ROOT Reference Guide

ROperator_ConvTranspose.hxx (source listing)
Go to the documentation of this file.
1#ifndef TMVA_SOFIE_ROPERATOR_CONVTRANSPOSE_HXX
2#define TMVA_SOFIE_ROPERATOR_CONVTRANSPOSE_HXX
3
5#include <TMVA/ROperator.hxx>
6#include <TMVA/RModel.hxx>
7
8#include <memory>
9#include <sstream>
10#include <algorithm>
11#include <stdexcept>
12#include <vector>
13#include <cassert>
14
15namespace TMVA {
16namespace Experimental {
17namespace SOFIE {
18
19/*! \brief Transposed Convolution operator
20 *
21 * Inference code generation for a transposed convolution layer.
22 * See the <a href="https://github.com/onnx/onnx/blob/main/docs/Operators.md#convtranspose">ONNX documentation</a> for
23 * details about the transposed conv layer.
24 */
25template <typename T>
27private:
// --- ONNX attributes of the ConvTranspose node (see constructor parameters) ---
std::string fAttrAutopad;              ///< auto_pad attribute ("autopad padding" in the constructor)
std::vector<size_t> fAttrDilations;    ///< dilations of the kernel
size_t fAttrGroup{1};                  ///< number of groups; defaults to 1 (the ONNX default) so a default-constructed op is not left indeterminate
std::vector<size_t> fAttrKernelShape;  ///< shape of the kernel
std::vector<size_t> fAttrOutputPadding; ///< padding of the output
std::vector<size_t> fAttrOutputShape;  ///< shape of the output
std::vector<size_t> fAttrPads;         ///< padding of the input
std::vector<size_t> fAttrStrides;      ///< strides

// --- tensor names (cleaned with UTILITY::Clean_name by the constructor) ---
std::string fNX;            ///< name of the input
std::string fNW;            ///< name of the weight
std::string fNB;            ///< name of the bias
std::string fNBroadcastedB; ///< presumably the name of the broadcast bias tensor; not set in this header (TODO confirm)
std::string fNY;            ///< name of the output

// --- tensor shapes ---
std::vector<size_t> fShapeX; ///< shape of the input
std::vector<size_t> fShapeW; ///< shape of the weight
std::vector<size_t> fShapeB; ///< shape of the bias
std::vector<size_t> fShapeY; ///< shape of the output

std::string fType; ///< element type name used in the generated code ("float" is the only one accepted by the constructor)

size_t fDim{0}; ///< dimension of the convolution; 0 until it is determined
51
52public:
53 /*! Default constructor of ROperator_ConvTranspose */
55
56 /*! \brief Constructor of ROperator_ConvTranspose from the attributes
57 *
58 * \param autopad padding
59 * \param dilations dilations of the kernel
60 * \param group number of groups
61 * \param kernelShape shape of the kernel
62 * \param outputPadding padding of the output
63 * \param outputShape shape of the output
64 * \param pads padding of the input
65 * \param strides strides
66 * \param nameX name of the input
67 * \param nameW name of the weight
68 * \param nameB name of the bias
69 * \param nameY name of the output
70 */
71 ROperator_ConvTranspose(std::string autopad, std::vector<size_t> dilations, size_t group,
72 std::vector<size_t> kernelShape, std::vector<size_t> outputPadding,
73 std::vector<size_t> outputShape, std::vector<size_t> pads, std::vector<size_t> strides,
74 std::string nameX, std::string nameW, std::string nameB, std::string nameY)
77 fNX(UTILITY::Clean_name(nameX)), fNW(UTILITY::Clean_name(nameW)), fNB(UTILITY::Clean_name(nameB)),
78 fNY(UTILITY::Clean_name(nameY))
79 {
80 if (std::is_same<T, float>::value) {
81 fType = "float";
82 } else {
83 throw std::runtime_error("TMVA SOFIE Encountered unsupported type parsing a Conv operator");
84 }
85 }
86
87 /*! \brief Infers the type of the output tensor
88 * \param input type of the input tensors
89 */
90 std::vector<ETensorType> TypeInference(std::vector<ETensorType> input) override
91 {
92 ETensorType out = input[0];
93 return {out};
94 }
95
96 /*! \brief Infers the shape of the input tensors
97 * \param input shape of the input tensors
98 */
99 std::vector<std::vector<size_t>> ShapeInference(std::vector<std::vector<size_t>> /*input*/) override;
100
101 /*! \brief Initialize the model
102 * \param model Model
103 */
104 void Initialize(RModel & /*model*/) override;
105
106 /*! \brief Generate code for initializing the op
107 */
108 std::string GenerateInitCode() override;
109
110 /*! \brief Generate code for Session data members (e.g. internal vectors)
111 * \param opName name of the operator
112 */
113 std::string GenerateSessionMembersCode(std::string /*opName*/) override;
114
115 /*! \brief Generate the inference code
116 * \param opName name of the operator
117 */
118 std::string Generate(std::string opName) override;
119
120 /*! \brief Returns the blas routines needed to compile the generated code
121 */
122 std::vector<std::string> GetBlasRoutines() override { return { std::string("Gemm"), std::string("Axpy") }; }
123};
124
125} // namespace SOFIE
126} // namespace Experimental
127} // namespace TMVA
128
129// Implementation of the ROperator_ConvTranspose class
131
132#endif
Documented members (signature followed by brief description):
std::vector< std::string > GetBlasRoutines() override
Returns the blas routines needed to compile the generated code.
ROperator_ConvTranspose(std::string autopad, std::vector< size_t > dilations, size_t group, std::vector< size_t > kernelShape, std::vector< size_t > outputPadding, std::vector< size_t > outputShape, std::vector< size_t > pads, std::vector< size_t > strides, std::string nameX, std::string nameW, std::string nameB, std::string nameY)
Constructor of ROperator_ConvTranspose from the attributes.
void Initialize(RModel &) override
Initialize the model.
ROperator_ConvTranspose()
Default constructor of ROperator_ConvTranspose.
std::string GenerateSessionMembersCode(std::string) override
Generate code for Session data members (e.g. internal vectors).
std::vector< ETensorType > TypeInference(std::vector< ETensorType > input) override
Infers the type of the output tensor.
std::string GenerateInitCode() override
Generate code for initializing the op.
std::string Generate(std::string opName) override
Generate the inference code.
std::vector< std::vector< size_t > > ShapeInference(std::vector< std::vector< size_t > >) override
Infers the shape of the output tensor from the shapes of the input tensors.
create variable transformations