80 std::vector<Dim>
DoShapeInference(
const std::vector<Dim> & input_shape,
const std::vector<Dim> & target_shape) {
84 std::vector<Dim> output_shape = target_shape;
85 bool hasMinusOne =
false;
87 for (
size_t i = 0; i < output_shape.size(); i++) {
89 if (!output_shape[i].isParam) {
90 if (output_shape[i].dim == 0) {
93 output_shape[i] =
Dim{0};
95 if (i > 0 && output_shape.size() != input_shape.size())
96 std::cout <<
"WARNING: TMVA Reshape Op : output shape has zero value at index " << i <<
97 " but input shape has a different rank than output shape" << std::endl;
98 if (i >= input_shape.size())
99 throw std::runtime_error(
"TMVA Reshape Op : output shape has zero value at index " + std::to_string(i) +
100 " but input shape does not have corresponding index");
102 output_shape[i] = input_shape[i];
103 }
else if (output_shape[i].dim ==
static_cast<size_t>(-1)) {
108 if (hasZero && hasMinusOne) {
109 throw std::runtime_error(
"TMVA Reshape Op : zero value in shape is not allowed when there is also a -1 in shape");
112 for (
size_t i = 0; i < output_shape.size(); i++) {
113 if (output_shape[i] ==
static_cast<size_t>(-1) && !output_shape[i].isParam) {
114 auto tmp = output_shape;
115 tmp.erase(tmp.begin() + i);
120 << input_length <<
" to " << tmp_length << std::endl;
123 output_shape[i] =
Dim{
static_cast<size_t>(std::stoi(input_length) / std::stoi(tmp_length))};
124 else if (
IsInteger(tmp_length) && std::stoi(tmp_length) == 1) {
125 output_shape[i] =
Dim{input_length,
static_cast<size_t>(-1)};
130 bool canSimplify =
false;
131 std::vector <Dim> reduced_input;
136 std::stringstream ss(input_length);
141 while(getline(ss, token,
'*'))
144 token.erase(std::remove_if(token.begin(), token.end(),
145 [](
unsigned char x) { return std::isspace(x); }), token.end());
146 if (token != tmp_length) {
148 size_t il =
static_cast<size_t>(std::stoi(input_length));
149 size_t tl =
static_cast<size_t>(std::stoi(tmp_length));
150 if ((il % tl) == 0) {
152 reduced_input.push_back(
Dim{il / tl});
155 reduced_input.push_back(
Dim{token});
166 if (res_shape.find(
'*') != std::string::npos)
167 output_shape[i] =
Dim{std::string(
"(") + res_shape +
")",
static_cast<size_t>(-1)};
169 output_shape[i] =
Dim{res_shape};
172 output_shape[i] =
Dim{std::string(
"(") + input_length +
" / (" + tmp_length +
"))",
static_cast<size_t>(-1)};
193 fAxis += input_shape.size();
194 auto s1 = std::vector<Dim>(input_shape.begin(), input_shape.begin() +
fAxis);
195 auto s2 = std::vector<Dim>(input_shape.begin() +
fAxis, input_shape.end());
198 std::vector<Dim> newShape = {
Dim{l1},
Dim{l2}};
203 auto output_shape = input_shape;
206 while (i < output_shape.size()) {
207 if (output_shape[i] ==
Dim{1}) {
208 output_shape.erase(output_shape.begin() + i);
215 for (
size_t i = 0; i < axes.size(); i++) {
217 axes[i] += input_shape.size();
218 if (!(output_shape[axes[i]] ==
Dim{1}))
219 throw std::runtime_error(
"TMVA Squeeze Op : Invalid axis value " + std::to_string(axes[i]) +
223 std::sort(axes.begin(), axes.end(), std::greater<int>());
224 for (
auto & axis : axes) {
225 output_shape.erase(output_shape.begin() + axis);
233 auto output_shape = input_shape;
236 int64_t
r = input_shape.size() + axes.size();
237 for (
auto &
a : axes) {
238 int64_t i =
static_cast<int64_t
>(
a);
239 if (i < -r || i >
r - 1)
240 throw std::runtime_error(
"TMVA Unsqueeze Op - axes input is not in correct range");
242 output_shape.insert(output_shape.begin() + i,
Dim{1});
245 output_shape.insert(output_shape.end() + i + 1,
Dim{1});
249 throw std::runtime_error(
"TMVA Reshape Op : Invalid ReshapeOpMode");