Skip to content

Commit 7e4c956

Browse files
BowenBao authored and facebook-github-bot committed
[ONNX] Support opset13 Squeeze and Unsqueeze (#50150) (#50906)
Summary: Pull Request resolved: #50906 In opset 13, squeeze/unsqueeze is updated to take axes as input, instead of attribute. Test Plan: Imported from OSS Reviewed By: pbelevich Differential Revision: D26050883 Pulled By: SplitInfinity fbshipit-source-id: 7b5faf0e016d476bc75cbf2bfee6918d77e8aecd
1 parent 1c9347c commit 7e4c956

File tree

11 files changed

+205
-266
lines changed

11 files changed

+205
-266
lines changed

test/onnx/test_pytorch_onnx_onnxruntime.py

Lines changed: 8 additions & 111 deletions
Large diffs are not rendered by default.

test/onnx/test_utility_funs.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
from torch.onnx import utils, OperatorExportTypes, TrainingMode
66
from torch.onnx.symbolic_helper import _set_opset_version, _set_operator_export_type
77
import torch.utils.cpp_extension
8-
from test_pytorch_common import skipIfUnsupportedMinOpsetVersion
8+
from test_pytorch_common import skipIfUnsupportedMinOpsetVersion, skipIfUnsupportedOpsetVersion
99
import caffe2.python.onnx.backend as backend
1010
from verify import verify
1111

@@ -618,6 +618,8 @@ def forward(self, x):
618618
assert next(iter).kind() == "aten::quantize_per_tensor"
619619
assert next(iter).kind() == "aten::dequantize"
620620

621+
# prim::ListConstruct is exported as onnx::SequenceConstruct for opset >= 11
622+
@skipIfUnsupportedOpsetVersion([11, 12])
621623
def test_prim_fallthrough(self):
622624
# Test prim op
623625
class PrimModule(torch.jit.ScriptModule):

torch/csrc/jit/passes/onnx/helper.cpp

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -97,5 +97,28 @@ Value* addInputToBlock(Block* block) {
9797
return block->addInput();
9898
}
9999

100+
Node* createONNXUnsqueeze(
101+
Graph* graph,
102+
Node* n_to_insert_before,
103+
Value* input,
104+
int axis,
105+
int opset_version) {
106+
Node* unsqueeze_node = graph->create(onnx::Unsqueeze, 1);
107+
unsqueeze_node->addInput(input);
108+
unsqueeze_node->insertBefore(n_to_insert_before);
109+
if (opset_version >= OPSET_VERSION_13) {
110+
// ONNX spec sets `axes` as input for opset >= 13.
111+
Node* unsqueeze_axes = graph->create(onnx::Constant, 1);
112+
unsqueeze_axes->insertBefore(unsqueeze_node);
113+
unsqueeze_axes->t_(
114+
attr::value, at::unsqueeze(at::scalar_to_tensor(at::Scalar(axis)), 0));
115+
unsqueeze_node->addInput(unsqueeze_axes->output());
116+
} else {
117+
// ONNX spec sets `axes` as attribute for opset < 13.
118+
unsqueeze_node->is_(attr::axes, {0});
119+
}
120+
return unsqueeze_node;
121+
}
122+
100123
} // namespace jit
101124
} // namespace torch

torch/csrc/jit/passes/onnx/helper.h

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@ static const int OPSET_VERSION_9 = 9;
1313
static const int OPSET_VERSION_10 = 10;
1414
static const int OPSET_VERSION_11 = 11;
1515
static const int OPSET_VERSION_12 = 12;
16+
static const int OPSET_VERSION_13 = 13;
1617

1718
using ValueToParamPairMap = std::map<Value*, std::pair<std::string, IValue>>;
1819

@@ -33,5 +34,13 @@ Node* addNodeToBlock(Block* block, Symbol kind, ArrayRef<Value*> inputs);
3334
Value* addInputToBlock(Block* block);
3435

3536
TORCH_API c10::optional<at::ScalarType> ONNXTypeToATenType(int32_t onnx_type);
37+
38+
Node* createONNXUnsqueeze(
39+
Graph* graph,
40+
Node* n_to_insert_before,
41+
Value* input,
42+
int axis,
43+
int opset_version);
44+
3645
} // namespace jit
3746
} // namespace torch

torch/csrc/jit/passes/onnx/peephole.cpp

Lines changed: 65 additions & 72 deletions
Original file line numberDiff line numberDiff line change
@@ -416,10 +416,8 @@ void fixDefaultRNNState(
416416
batch_size->addInput(shape_of_input->outputs()[0]);
417417
batch_size->addInput(gather_indices->outputs()[0]);
418418

419-
Node* unsqueezed_batch_size = graph->create(onnx::Unsqueeze, 1);
420-
unsqueezed_batch_size->insertBefore(n);
421-
unsqueezed_batch_size->addInput(batch_size->outputs()[0]);
422-
unsqueezed_batch_size->is_(attr::axes, {0});
419+
Node* unsqueezed_batch_size =
420+
createONNXUnsqueeze(graph, n, batch_size->outputs()[0], 0, opset_version);
423421

424422
Node* hidden_size = graph->create(onnx::Constant, 1);
425423
hidden_size->insertBefore(n);
@@ -440,10 +438,8 @@ void fixDefaultRNNState(
440438
? 2
441439
: 1)));
442440

443-
Node* unsqueezed_num_directions = graph->create(onnx::Unsqueeze, 1);
444-
unsqueezed_num_directions->insertBefore(n);
445-
unsqueezed_num_directions->addInput(num_directions->outputs()[0]);
446-
unsqueezed_num_directions->is_(attr::axes, {0});
441+
Node* unsqueezed_num_directions = createONNXUnsqueeze(
442+
graph, n, num_directions->outputs()[0], 0, opset_version);
447443

448444
Node* concated_dims = graph->create(onnx::Concat, 1);
449445
concated_dims->insertBefore(n);
@@ -555,6 +551,65 @@ static void replaceInputWithList(Node* node, size_t i, ArrayRef<Value*> to) {
555551
}
556552
}
557553

554+
static void eraseListConstruct(Block* block, int opset_version);
555+
556+
static void eraseListConstruct(Node* n, int opset_version) {
557+
for (auto b : n->blocks()) {
558+
eraseListConstruct(b, opset_version);
559+
}
560+
std::vector<std::tuple<size_t, std::vector<Value*>>> replacements;
561+
562+
auto block = n->owningBlock();
563+
size_t i = 0;
564+
for (auto* input : n->inputs()) {
565+
if (input->node()->kind() == prim::ListConstruct) {
566+
auto* lc_node = input->node();
567+
TypePtr elem =
568+
lc_node->output()->type()->cast<ListType>()->getElementType();
569+
if (elem->cast<IntType>()) {
570+
// ListConstruct Int[] output case, we need to transform to ONNX
571+
// Concat to ensure the output is a single tensor(dynamic) type in
572+
// order to be consumed as inputs
573+
std::vector<Value*> unsqueezed;
574+
Graph* g = block->owningGraph();
575+
for (auto* input : lc_node->inputs()) {
576+
Node* unsqueezed_node =
577+
createONNXUnsqueeze(g, lc_node, input, 0, opset_version);
578+
unsqueezed.emplace_back(unsqueezed_node->output());
579+
}
580+
Node* concat_node = g->create(onnx::Concat, 1);
581+
concat_node->i_(attr::axis, 0);
582+
for (auto v : unsqueezed) {
583+
concat_node->addInput(v);
584+
}
585+
concat_node->insertBefore(lc_node);
586+
587+
// make concat node output as new input, then ListConstruct should
588+
// become dead
589+
replacements.emplace_back(
590+
i, std::vector<Value*>({concat_node->output()}));
591+
592+
} else {
593+
if (opset_version >= OPSET_VERSION_11) {
594+
c10::Symbol seq_node_kind = lc_node->inputs().size() > 0
595+
? onnx::SequenceConstruct
596+
: onnx::SequenceEmpty;
597+
Node* seq_node = block->owningGraph()->create(
598+
seq_node_kind, {lc_node->inputs()}, 1);
599+
seq_node->insertBefore(lc_node);
600+
seq_node->output()->copyMetadata(lc_node->output());
601+
lc_node->replaceAllUsesWith(seq_node);
602+
}
603+
}
604+
}
605+
i++;
606+
}
607+
608+
for (auto ritr = replacements.rbegin(); ritr != replacements.rend(); ++ritr) {
609+
replaceInputWithList(n, std::get<0>(*ritr), std::get<1>(*ritr));
610+
}
611+
}
612+
558613
static void eraseListConstruct(Block* block, int opset_version) {
559614
// TODO: Fix this pass/maybe get rid of this part.
560615
// Tensor lists might be used for meshgrid and such ops as well.
@@ -563,71 +618,9 @@ static void eraseListConstruct(Block* block, int opset_version) {
563618
Node* n = *it;
564619
++it;
565620

566-
for (auto b : n->blocks()) {
567-
eraseListConstruct(b, opset_version);
568-
}
569-
std::vector<std::tuple<size_t, std::vector<Value*>>> replacements;
570-
571-
size_t i = 0;
572-
for (auto* input : n->inputs()) {
573-
if (input->node()->kind() == prim::ListConstruct) {
574-
auto* lc_node = input->node();
575-
TypePtr elem =
576-
lc_node->output()->type()->cast<ListType>()->getElementType();
577-
if (elem->cast<IntType>()) {
578-
// ListConstruct Int[] output case, we need to transform to ONNX
579-
// Concat to ensure the output is a single tensor(dynamic) type in
580-
// order to be consumed as inputs
581-
std::vector<Value*> unsqueezed;
582-
Graph* g = block->owningGraph();
583-
for (auto* input : lc_node->inputs()) {
584-
Node* unsqueezed_node = g->create(onnx::Unsqueeze, 1);
585-
unsqueezed_node->insertBefore(lc_node);
586-
unsqueezed_node->addInput(input);
587-
unsqueezed_node->is_(attr::axes, {0});
588-
unsqueezed.emplace_back(unsqueezed_node->output());
589-
}
590-
Node* concat_node = g->create(onnx::Concat, 1);
591-
concat_node->i_(attr::axis, 0);
592-
for (auto v : unsqueezed) {
593-
concat_node->addInput(v);
594-
}
595-
concat_node->insertBefore(lc_node);
596-
597-
// make concat node output as new input, then ListConstruct should
598-
// become dead
599-
replacements.emplace_back(
600-
i, std::vector<Value*>({concat_node->output()}));
601-
602-
} else {
603-
if (opset_version < OPSET_VERSION_11) {
604-
// Tensor lists are used mostly for inputs to cat/stack. They are
605-
// already handled in those symbolics, and should become dead
606-
// afterwards.
607-
replacements.emplace_back(
608-
i,
609-
std::vector<Value*>(
610-
lc_node->inputs().begin(), lc_node->inputs().end()));
611-
} else {
612-
c10::Symbol seq_node_kind = lc_node->inputs().size() > 0
613-
? onnx::SequenceConstruct
614-
: onnx::SequenceEmpty;
615-
Node* seq_node = block->owningGraph()->create(
616-
seq_node_kind, {lc_node->inputs()}, 1);
617-
seq_node->insertBefore(lc_node);
618-
seq_node->output()->copyMetadata(lc_node->output());
619-
lc_node->replaceAllUsesWith(seq_node);
620-
}
621-
}
622-
}
623-
i++;
624-
}
625-
626-
for (auto ritr = replacements.rbegin(); ritr != replacements.rend();
627-
++ritr) {
628-
replaceInputWithList(n, std::get<0>(*ritr), std::get<1>(*ritr));
629-
}
621+
eraseListConstruct(n, opset_version);
630622
}
623+
eraseListConstruct(block->return_node(), opset_version);
631624
}
632625

633626
// For ops such as meshgrid where output is a list of Tensors

torch/csrc/jit/passes/onnx/shape_type_inference.cpp

Lines changed: 17 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -201,7 +201,10 @@ bool IsSupportedNode(const Node* n) {
201201
return true;
202202
}
203203

204-
Value* CloneValueFromListConstruct(Value* v, std::shared_ptr<Graph> n_graph) {
204+
Value* CloneValueFromListConstruct(
205+
Value* v,
206+
std::shared_ptr<Graph> n_graph,
207+
int opset_version) {
205208
auto lc_node = v->node();
206209
TORCH_INTERNAL_ASSERT(lc_node->kind() == ::c10::prim::ListConstruct);
207210
// In jit/passes/onnx/peephole.cpp::eraseListConstruct,
@@ -221,12 +224,10 @@ Value* CloneValueFromListConstruct(Value* v, std::shared_ptr<Graph> n_graph) {
221224
// order to be consumed as inputs
222225
std::vector<Value*> unsqueezed;
223226
for (auto* input : lc_node->inputs()) {
224-
Node* unsqueezed_node =
225-
n_graph->insertNode(n_graph->create(::c10::onnx::Unsqueeze, 1));
226227
auto new_input = n_graph->addInput();
227228
new_input->copyMetadata(input);
228-
unsqueezed_node->addInput(new_input);
229-
unsqueezed_node->is_(attr::axes, {0});
229+
Node* unsqueezed_node = createONNXUnsqueeze(
230+
n_graph.get(), n_graph->return_node(), new_input, 0, opset_version);
230231
unsqueezed.emplace_back(unsqueezed_node->output());
231232
}
232233
Node* concat_node =
@@ -261,11 +262,12 @@ Value* CloneValueFromListConstruct(Value* v, std::shared_ptr<Graph> n_graph) {
261262
Node* CloneNodeToGraph(
262263
Node* n,
263264
std::shared_ptr<Graph> n_graph,
264-
const ParamMap& params_dict) {
265+
const ParamMap& params_dict,
266+
int opset_version) {
265267
auto vals_to_params_map =
266268
buildValueToParamsMap(n->owningGraph()->block(), params_dict);
267-
auto clone_node =
268-
n_graph->createClone(n, [&n_graph, &vals_to_params_map](Value* v) {
269+
auto clone_node = n_graph->createClone(
270+
n, [&n_graph, &vals_to_params_map, opset_version](Value* v) {
269271
auto v_n = v->node();
270272
switch (v_n->kind()) {
271273
case ::c10::onnx::Constant: {
@@ -275,7 +277,7 @@ Node* CloneNodeToGraph(
275277
return constant_n->output();
276278
}
277279
case ::c10::prim::ListConstruct: {
278-
return CloneValueFromListConstruct(v, n_graph);
280+
return CloneValueFromListConstruct(v, n_graph, opset_version);
279281
}
280282
case ::c10::prim::PackPadded: {
281283
auto input = n_graph->addInput();
@@ -476,7 +478,7 @@ void ONNXShapeTypeInference(
476478
// Create a Graph containing only the single node n.
477479
// This graph is later converted to ONNX to run shape inference.
478480
auto n_graph = std::make_shared<Graph>();
479-
auto clone_node = CloneNodeToGraph(n, n_graph, params_dict);
481+
auto clone_node = CloneNodeToGraph(n, n_graph, params_dict, opset_version);
480482
n_graph->insertNode(clone_node);
481483

482484
// Register all node outputs as graph outputs.
@@ -507,12 +509,16 @@ void ONNXShapeTypeInference(
507509
} catch (std::runtime_error& ex) {
508510
// TODO: include this as warning once we have a more consolidated warning
509511
// system.
512+
GRAPH_DEBUG(
513+
"ONNX shape inference fails with: ",
514+
ex.what(),
515+
" on graph: ",
516+
n_graph->toString());
510517
const char shape_err[] = "ShapeInferenceError";
511518
const char type_err[] = "TypeInferenceError";
512519
if ((strstr(ex.what(), shape_err) == NULL) &&
513520
(strstr(ex.what(), type_err) == NULL))
514521
throw;
515-
GRAPH_DEBUG("ONNX shape inference fails with: ", ex.what());
516522
}
517523
GRAPH_DEBUG(
518524
"ONNX graph after shape inference: ", prettyPrint(*model_proto));

torch/onnx/symbolic_helper.py

Lines changed: 17 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -321,9 +321,19 @@ def _interpolate_warning(interpolate_mode):
321321
"to support Pytorch's behavior (like coordinate_transformation_mode and nearest_mode).\n"
322322
"We recommend using opset 11 and above for models using this operator. ")
323323

324-
def _unsqueeze_helper(g, input, dim):
325-
from torch.onnx.symbolic_opset9 import unsqueeze
326-
return unsqueeze(g, input, dim)
324+
def _unsqueeze_helper(g, input, axes_i):
325+
if _export_onnx_opset_version >= 13:
326+
axes = g.op("Constant", value_t=torch.tensor(axes_i, dtype=torch.long))
327+
return g.op("Unsqueeze", input, axes)
328+
else:
329+
return g.op("Unsqueeze", input, axes_i=axes_i)
330+
331+
def _squeeze_helper(g, input, axes_i):
332+
if _export_onnx_opset_version >= 13:
333+
axes = g.op("Constant", value_t=torch.tensor(axes_i, dtype=torch.long))
334+
return g.op("Squeeze", input, axes)
335+
else:
336+
return g.op("Squeeze", input, axes_i=axes_i)
327337

328338
def _interpolate_size_to_scales(g, input, output_size, dim):
329339
output_size = _maybe_get_const(output_size, 'is')
@@ -371,7 +381,7 @@ def _interpolate_get_scales(g, scale_factor, dim):
371381
if isinstance(scale_factor.type(), torch._C.ListType) or (scale_factor_rank is not None and scale_factor_rank > 0):
372382
return g.op("Concat", offsets, scale_factor, axis_i=0)
373383
else:
374-
scale_factor = _unsqueeze_helper(g, scale_factor, 0)
384+
scale_factor = _unsqueeze_helper(g, scale_factor, [0])
375385
scale_factor = g.op("Cast", scale_factor, to_i=cast_pytorch_to_onnx["Float"])
376386
scales = [scale_factor for i in range(dim - 2)]
377387
scale_factor = g.op("Concat", offsets, *scales, axis_i=0)
@@ -400,7 +410,7 @@ def _interpolate_get_scales_and_mode(g, input, size, scale_factor, mode , align_
400410
if not _is_packed_list(size):
401411
is_scalar = ((_maybe_get_const(size, 't').dim() == 0))
402412
if is_scalar:
403-
size = _unsqueeze_helper(g, size, 0)
413+
size = _unsqueeze_helper(g, size, [0])
404414
size = [size for i in range(dim - 2)]
405415
size = g.op("Concat", *size, axis_i=0)
406416
scale_factor = _interpolate_size_to_scales(g, input, size, dim)
@@ -477,9 +487,9 @@ def _index_fill_reshape_helper(g, self, dim, index):
477487
return _unimplemented("index_fill", "input rank not accesible")
478488
self_dim = self.type().dim()
479489
dim_value = _parse_arg(dim, 'i')
480-
unsqueezed_index = g.op("Unsqueeze", index, axes_i=[i for i in range(self_dim) if i != dim_value])
490+
unsqueezed_index = _unsqueeze_helper(g, index, [i for i in range(self_dim) if i != dim_value])
481491
expanded_index_shape = scatter(g, g.op("Shape", self), 0,
482-
g.op("Unsqueeze", dim, axes_i=[0]), g.op("Shape", index))
492+
_unsqueeze_helper(g, dim, [0]), g.op("Shape", index))
483493
expanded_index = expand(g, unsqueezed_index, expanded_index_shape, None)
484494
return expanded_index_shape, expanded_index
485495

0 commit comments

Comments (0)