Skip to content

Commit 0411324

Browse files
committed
Fix nparray >= smth
1 parent ec26e23 commit 0411324

File tree

2 files changed

+14
-2
lines changed

2 files changed

+14
-2
lines changed

test/dynamo/test_misc.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1369,6 +1369,17 @@ def fn(y):
13691369
self.assertEqual(r, x * 3)
13701370
del x
13711371

1372+
def test_numpy_gt(self):
1373+
x = np.arange(10)
1374+
1375+
@torch.compile
1376+
def fn(y):
1377+
return y >= 3
1378+
1379+
r = fn(x)
1380+
self.assertEqual(type(r), np.ndarray)
1381+
self.assertEqual(r, x >= 3)
1382+
13721383
def test_graph_break_correctly_when_passing_numpy_ndarray_to_torch_function(self):
13731384
# from transformers/models/big_bird/modeling_big_bird.py
13741385
def fn(x: int, y: torch.Tensor):

torch/_dynamo/variables/builtin.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1339,11 +1339,12 @@ def _unimplemented():
13391339
return ConstantVariable(op(left._underlying_items, right._underlying_items))
13401340

13411341
if isinstance(left, TensorVariable):
1342-
from .builder import wrap_fx_proxy
1342+
from .builder import wrap_fx_proxy_cls
13431343

13441344
if op not in supported_tensor_comparison_ops.values():
13451345
_unimplemented()
1346-
return wrap_fx_proxy(
1346+
return wrap_fx_proxy_cls(
1347+
type(left), # handle Ndarrays and Tensors
13471348
tx,
13481349
op(left.as_proxy(), right.as_proxy()),
13491350
)

0 commit comments

Comments
 (0)