@@ -1323,12 +1323,15 @@ def _op_with_optional_float_cast(g: jit_utils.GraphContext, op_name, *args, **kw
13231323
13241324 if require_cast :
13251325 for input in inputs :
1326- input_scalar_type = _type_utils .JitScalarType .from_value (input )
1327- if input .isCompleteTensor () and input_scalar_type != dtype_0 :
1328- raise errors .SymbolicValueError (
1329- f"Inputs of { op_name } must have same dtype. Got { dtype_0 .scalar_name ()} and { input_scalar_type .scalar_name ()} " ,
1330- input ,
1331- )
1326+
1327+ if input .isCompleteTensor ():
1328+ input_scalar_type = _type_utils .JitScalarType .from_value (input )
1329+ if input_scalar_type != dtype_0 :
1330+ raise errors .SymbolicValueError (
1331+ f"Inputs of { op_name } must have same dtype."
1332+ f"Got { dtype_0 .scalar_name ()} and { input_scalar_type .scalar_name ()} " ,
1333+ input ,
1334+ )
13321335 for i , input in enumerate (inputs ):
13331336 if input .isCompleteTensor () and not symbolic_helper ._is_fp (input ):
13341337 inputs [i ] = g .op (
@@ -3617,7 +3620,7 @@ def tensor(
36173620 for t in symbolic_helper ._unpack_list (data ):
36183621 shape_reference = g .op ("Constant" , value_t = torch .LongTensor ([1 ]))
36193622 t = symbolic_helper ._reshape_helper (g , t , shape_reference )
3620- t = g .op ("Cast" , t , to_i = dtype .onnx_type ())
3623+ t = g .op ("Cast" , t , to_i = _type_utils . JitScalarType ( dtype ) .onnx_type ())
36213624 input_list .append (t )
36223625 return g .op ("Concat" , * input_list , axis_i = 0 )
36233626 else :
0 commit comments