Skip to content

Commit f0c231c

Browse files
author
Thiago Crepaldi
committed
Fix scalar type resolution for optional tensor
1 parent fe0e28a commit f0c231c

File tree

1 file changed

+10
-7
lines changed

1 file changed

+10
-7
lines changed

torch/onnx/symbolic_opset9.py

Lines changed: 10 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1323,12 +1323,15 @@ def _op_with_optional_float_cast(g: jit_utils.GraphContext, op_name, *args, **kw
13231323

13241324
if require_cast:
13251325
for input in inputs:
1326-
input_scalar_type = _type_utils.JitScalarType.from_value(input)
1327-
if input.isCompleteTensor() and input_scalar_type != dtype_0:
1328-
raise errors.SymbolicValueError(
1329-
f"Inputs of {op_name} must have same dtype. Got {dtype_0.scalar_name()} and {input_scalar_type.scalar_name()}",
1330-
input,
1331-
)
1326+
1327+
if input.isCompleteTensor():
1328+
input_scalar_type = _type_utils.JitScalarType.from_value(input)
1329+
if input_scalar_type != dtype_0:
1330+
raise errors.SymbolicValueError(
1331+
f"Inputs of {op_name} must have same dtype."
1332+
f"Got {dtype_0.scalar_name()} and {input_scalar_type.scalar_name()}",
1333+
input,
1334+
)
13321335
for i, input in enumerate(inputs):
13331336
if input.isCompleteTensor() and not symbolic_helper._is_fp(input):
13341337
inputs[i] = g.op(
@@ -3617,7 +3620,7 @@ def tensor(
36173620
for t in symbolic_helper._unpack_list(data):
36183621
shape_reference = g.op("Constant", value_t=torch.LongTensor([1]))
36193622
t = symbolic_helper._reshape_helper(g, t, shape_reference)
3620-
t = g.op("Cast", t, to_i=dtype.onnx_type())
3623+
t = g.op("Cast", t, to_i=_type_utils.JitScalarType(dtype).onnx_type())
36213624
input_list.append(t)
36223625
return g.op("Concat", *input_list, axis_i=0)
36233626
else:

0 commit comments

Comments
 (0)