File tree Expand file tree Collapse file tree 2 files changed +14
-2
lines changed
Expand file tree Collapse file tree 2 files changed +14
-2
lines changed Original file line number Diff line number Diff line change @@ -1369,6 +1369,17 @@ def fn(y):
13691369 self .assertEqual (r , x * 3 )
13701370 del x
13711371
1372+ def test_numpy_gt (self ):
1373+ x = np .arange (10 )
1374+
1375+ @torch .compile
1376+ def fn (y ):
1377+ return y >= 3
1378+
1379+ r = fn (x )
1380+ self .assertEqual (type (r ), np .ndarray )
1381+ self .assertEqual (r , x >= 3 )
1382+
13721383 def test_graph_break_correctly_when_passing_numpy_ndarray_to_torch_function (self ):
13731384 # from transformers/models/big_bird/modeling_big_bird.py
13741385 def fn (x : int , y : torch .Tensor ):
Original file line number Diff line number Diff line change @@ -1339,11 +1339,12 @@ def _unimplemented():
13391339 return ConstantVariable (op (left ._underlying_items , right ._underlying_items ))
13401340
13411341 if isinstance (left , TensorVariable ):
1342- from .builder import wrap_fx_proxy
1342+ from .builder import wrap_fx_proxy_cls
13431343
13441344 if op not in supported_tensor_comparison_ops .values ():
13451345 _unimplemented ()
1346- return wrap_fx_proxy (
1346+ return wrap_fx_proxy_cls (
1347+ type (left ), # handle Ndarrays and Tensors
13471348 tx ,
13481349 op (left .as_proxy (), right .as_proxy ()),
13491350 )
You can’t perform that action at this time.
0 commit comments