File tree Expand file tree Collapse file tree 2 files changed +10
-0
lines changed
Expand file tree Collapse file tree 2 files changed +10
-0
lines changed Original file line number Diff line number Diff line change @@ -3210,6 +3210,11 @@ def fn(x, m):
32103210 res = opt_fn (x , m )
32113211 self .assertEqual (ref , res )
32123212
3213+ # Test now the other path
3214+ ref = fn (x , x )
3215+ res = opt_fn (x , x )
3216+ self .assertEqual (ref , res )
3217+
32133218 def test_tensor_dot_grad_no_graph_break (self ):
32143219 def fn (a , b ):
32153220 y = 3 * a ** 3 - b ** 2
Original file line number Diff line number Diff line change 33import types
44from typing import Dict , List
55
6+ import numpy as np
7+
68import sympy
79
810import torch ._numpy as tnp
@@ -955,6 +957,9 @@ def call_method(
955957 )
956958 return NumpyNdarrayVariable .create (tx , proxy , ** options )
957959
960+ def python_type (self ):
961+ return np .ndarray
962+
958963
959964class UnspecializedPythonVariable (TensorVariable ):
960965 """
You can’t perform that action at this time.
0 commit comments