11# Owner(s): ["module: meta tensors"]
22
33from torch .testing ._internal .common_utils import (
4- TestCase , run_tests , skipIfCrossRef , skipIfRocm , skipIfTorchDynamo , parametrize ,
4+ TestCase , TEST_WITH_TORCHDYNAMO , run_tests , skipIfCrossRef , skipIfRocm , skipIfTorchDynamo , parametrize ,
55 instantiate_parametrized_tests )
66import torch
77import torch ._dynamo
@@ -111,6 +111,7 @@ def test_non_parameter_grad(self):
111111 fake_t = mode .from_tensor (t )
112112 self .assertEqual (fake_t .requires_grad , t .requires_grad )
113113
114+ @unittest .skipIf (TEST_WITH_TORCHDYNAMO , "isinstance check for FakeTensor won't work with compile" )
114115 @unittest .skipIf (not RUN_CUDA , "requires cuda" )
115116 def test_index_cuda_with_cpu (self ):
116117 with FakeTensorMode ():
@@ -253,6 +254,7 @@ def test_fake_mode_error(self):
253254 with FakeTensorMode ():
254255 y = x [0 ]
255256
257+ @unittest .skipIf (TEST_WITH_TORCHDYNAMO , "isinstance check for FakeTensor won't work with compile" )
256258 def test_fake_grad_copy (self ):
257259 x = torch .rand ([4 , 4 ], requires_grad = True )
258260 x .grad = torch .rand ([4 , 4 ])
@@ -282,6 +284,7 @@ def test_binary_op_type_promotion(self):
282284 self .assertEqual (out .dtype , torch .float )
283285 self .assertEqual (out .device .type , "cpu" )
284286
287+ @unittest .skipIf (TEST_WITH_TORCHDYNAMO , "isinstance check for FakeTensor won't work with compile" )
285288 def test_from_numpy (self ):
286289 with FakeTensorMode ():
287290 x = torch .tensor (np .zeros ([4 , 4 ]))
@@ -353,6 +356,7 @@ def test_out_multi_device(self):
353356 x .add_ (y )
354357
355358
359+ @unittest .skipIf (TEST_WITH_TORCHDYNAMO , "isinstance check for FakeTensor won't work with compile" )
356360 @unittest .skipIf (not RUN_CUDA , "requires cuda" )
357361 def test_normalize_device (self ):
358362 with FakeTensorMode ():
@@ -369,6 +373,7 @@ def test_recursive_invocation(self):
369373 y = x + x
370374 self .assertTrue (mode .in_kernel_invocation )
371375
376+ @unittest .skipIf (TEST_WITH_TORCHDYNAMO , "isinstance check for FakeTensor won't work with compile" )
372377 @skipIfRocm
373378 @parametrize ("allow_fallback_kernels" , [False , True ],
374379 lambda a : 'with_fallback' if a else 'without_fallback' )
@@ -540,6 +545,7 @@ def __init__(self):
540545 self .assertIs (mod_copied .a , mod_copied .b )
541546 self .assertEqual (mod_copied .b .storage ()._cdata , mod_copied .a .storage ()._cdata )
542547
548+ @unittest .skipIf (TEST_WITH_TORCHDYNAMO , "isinstance check for FakeTensor won't work with compile" )
543549 @unittest .skipIf (not RUN_CUDA , "requires cuda" )
544550 def test_new (self ):
545551 with FakeTensorMode ():
@@ -550,13 +556,15 @@ def test_new(self):
550556 self .checkType (b .new (device = 'cuda' ), "cuda" , [0 ])
551557 self .checkType (a .new (torch .rand ([1 ])), "cpu" , [1 ])
552558
559+ @unittest .skipIf (TEST_WITH_TORCHDYNAMO , "isinstance check for FakeTensor won't work with compile" )
553560 def test_scalar_inputs (self ):
554561 with FakeTensorMode ():
555562 self .checkType (torch .div (3 , 2 ), "cpu" , [])
556563 ten = torch .zeros (2 , dtype = torch .int32 ) * 2.0
557564 self .assertEqual (ten .dtype , torch .float )
558565 self .checkType (ten , "cpu" , [2 ])
559566
567+ @unittest .skipIf (TEST_WITH_TORCHDYNAMO , "isinstance check for FakeTensor won't work with compile" )
560568 def test_allow_meta (self ):
561569 def run_meta ():
562570 with FakeTensorMode ():
@@ -585,6 +593,7 @@ def f():
585593 self .assertEqual (r .size (), f .size ())
586594 self .assertEqual (r .device , f .device )
587595
596+ @unittest .skipIf (TEST_WITH_TORCHDYNAMO , "isinstance check for FakeTensor won't work with compile" )
588597 def test_mixed_real_and_fake_inputs (self ):
589598 class _TestPattern (torch .nn .Module ):
590599 def __init__ (self ):
@@ -613,6 +622,7 @@ def forward(self, input):
613622 out = mod (torch .randn (1 , 1 , 3 , 3 ))
614623 self .checkType (out , "cpu" , (1 , 1 , 3 , 3 ))
615624
625+ @unittest .skipIf (TEST_WITH_TORCHDYNAMO , "isinstance check for FakeTensor won't work with compile" )
616626 @unittest .skipIf (not RUN_CUDA , "requires cuda" )
617627 def test_aten_copy_multi_device (self ):
618628 with FakeTensorMode ():
@@ -626,6 +636,7 @@ def test_aten_copy_multi_device(self):
626636 self .checkType (copy2 , "cuda" , (4 ,))
627637 self .checkType (out , "cpu" , (4 ,))
628638
639+ @unittest .skipIf (TEST_WITH_TORCHDYNAMO , "isinstance check for FakeTensor won't work with compile" )
629640 @unittest .skipIf (not RUN_CUDA , "requires cuda" )
630641 def test_aten_index_multi_device (self ):
631642 with FakeTensorMode ():
@@ -647,6 +658,7 @@ def test_aten_index_multi_device(self):
647658 self .checkType (r3 , "cpu" , (4 , 4 ))
648659 self .checkType (r4 , "cuda" , (4 , 4 ))
649660
661+ @unittest .skipIf (TEST_WITH_TORCHDYNAMO , "isinstance check for FakeTensor won't work with compile" )
650662 @unittest .skipIf (not RUN_CUDA , "requires cuda" )
651663 def test_aten_slice_scatter_multi_device (self ):
652664 with FakeTensorMode ():
0 commit comments