Skip to content

Commit 0455d36

Browse files
author
Thiago Crepaldi
committed
Address comments
1 parent 629c075 commit 0455d36

File tree

2 files changed

+22
-40
lines changed

2 files changed

+22
-40
lines changed

test/onnx/test_onnx_export.py

Lines changed: 10 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77

88
import onnx
99
import torch
10+
from typing import Callable
1011

1112
from test_pytorch_common import TestCase
1213
from torch.onnx.symbolic_helper import _onnx_unsupported
@@ -62,38 +63,23 @@ def break_is_registered_op_api(opname, domain, version):
6263
self.assertAtenOp(onnx_model, "clamp", "Tensor")
6364

6465

65-
def _helper_test_to_(self, cast_fn):
66-
"""Run a generic test with the specified cast_fn
66+
def _helper_test_to_(self, cast_fn: Callable[[torch.Tensor], torch.Tensor]):
67+
"""Helper to test aten::to(device) variants
6768
68-
`cast_fn` will be converted into a `torch.jit.script` function
69-
and it is meant to wrap a `aten::to` during export to prevent
70-
hard-coded devices
69+
`cast_fn` is converted into a `torch.jit.script`. It wraps `aten::to`
70+
during export to prevent the devices from being hard-coded.
7171
72-
`cast_fn` signature is:
73-
def cast_fn(input: torch.Tensor) -> torch.Tensor
72+
Needed by detectron2 after https://github.com/facebookresearch/detectron2/pull/4132/
7473
"""
75-
cast_fn = torch.jit.script_if_tracing(cast_fn)
76-
77-
class MyModel(torch.nn.Module):
78-
def __init__(self):
79-
super().__init__()
80-
self.conv1 = torch.nn.Conv2d(3, 20, 5)
81-
self.conv2 = torch.nn.Conv2d(20, 20, 5)
82-
83-
def forward(self, x):
84-
x = cast_fn(x)
85-
x = torch.nn.functional.relu(self.conv1(x))
86-
return torch.nn.functional.relu(self.conv2(x))
74+
cast_fn = torch.jit.script(cast_fn)
8775

8876
f = io.BytesIO()
89-
model = MyModel()
9077
x = torch.zeros([1, 3, 32, 32])
91-
torch.onnx.export(model, (x,), f,
92-
operator_export_type=OperatorExportTypes.ONNX)
78+
torch.onnx.export(cast_fn, (x,), f)
9379
onnx_model = onnx.load_from_string(f.getvalue())
9480
for n in onnx_model.graph.node:
95-
assert n.op_type != "To"
96-
assert n.op_type != "Cast"
81+
self.assertNotEqual(n.op_type, "To")
82+
self.assertNotEqual(n.op_type, "Cast")
9783

9884
def test_to__cpu_string(self):
9985
def cast_cpu_string(src: torch.Tensor) -> torch.Tensor:

torch/onnx/symbolic_opset9.py

Lines changed: 12 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -2249,33 +2249,29 @@ def topk(g, self, k, dim, largest, sorted, out=None):
22492249

22502250
def to(g, self, *args):
22512251

2252-
def is_aten_to_device_only(*args):
2252+
def is_aten_to_device_only(args):
22532253
if len(args) == 4:
22542254
# aten::to(Tensor, Device, bool, bool, memory_format)
2255-
if args[0].node().kind() == "prim::device" or \
2256-
args[0].type().isSubtypeOf(ListType.ofInts()):
2257-
return True
2258-
if sym_help._is_value(args[0]) and \
2259-
args[0].node().kind() == "onnx::Constant" and \
2260-
isinstance(args[0].node()["value"], str):
2261-
return True
2255+
return args[0].node().kind() == "prim::device" or \
2256+
args[0].type().isSubtypeOf(ListType.ofInts()) or \
2257+
(sym_help._is_value(args[0]) and
2258+
args[0].node().kind() == "onnx::Constant" and
2259+
isinstance(args[0].node()["value"], str))
22622260
elif len(args) == 5:
22632261
# aten::to(Tensor, Device, ScalarType, bool, bool, memory_format)
2262+
# When dtype is None, this is an aten::to(device) call
22642263
dtype = sym_help._get_const(args[1], "i", "dtype")
2265-
if not dtype:
2266-
# When dtype is None, this is a aten::to(device) call
2267-
return True
2268-
elif len(args) >= 6 and len(args) <= 7:
2264+
return dtype is None
2265+
elif len(args) in (6, 7):
22692266
# aten::to(Tensor, ScalarType, Layout, Device, bool, bool, memory_format) -> Tensor
22702267
# aten::to(Tensor, ScalarType, Layout, Device, bool, bool, bool, memory_format) -> Tensor
2268+
# When dtype is None, this is an aten::to(device) call
22712269
dtype = sym_help._get_const(args[0], "i", "dtype")
2272-
if not dtype:
2273-
# When dtype is None, this is a aten::to(device) call
2274-
return True
2270+
return dtype is None
22752271
return False
22762272

22772273
# ONNX doesn't have a concept of a device, so we ignore device-only casts
2278-
if is_aten_to_device_only(*args):
2274+
if is_aten_to_device_only(args):
22792275
return self
22802276

22812277
if len(args) == 4:

0 commit comments

Comments
 (0)