Skip to content

Commit 3a2c3c8

Browse files
christinaburgepytorchmergebot
authored andcommitted
unskipped mobilenet_v3 quantization and mobilenet_v2 quantization plus tests from #125438 (#157786)
These tests now pass on AArch64 in our downstream CI. `test_quantization.py::TestNumericSuiteEager::test_mobilenet_v2 <- test/quantization/eager/test_numeric_suite_eager.py PASSED [2.4434s] [ 35%]` Pull Request resolved: #157786 Approved by: https://github.com/jerryzh168, https://github.com/malfet
1 parent 9fd5b5f commit 3a2c3c8

File tree

2 files changed

+1
-6
lines changed

2 files changed

+1
-6
lines changed

test/quantization/eager/test_numeric_suite_eager.py

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,6 @@
11
# Owner(s): ["oncall: quantization"]
22
# ruff: noqa: F841
33

4-
import unittest
54

65
import torch
76
import torch.ao.nn.quantized as nnq
@@ -38,7 +37,7 @@
3837
test_only_eval_fn,
3938
)
4039
from torch.testing._internal.common_quantized import override_qengines
41-
from torch.testing._internal.common_utils import IS_ARM64, raise_on_run_directly
40+
from torch.testing._internal.common_utils import raise_on_run_directly
4241

4342

4443
class SubModule(torch.nn.Module):
@@ -600,14 +599,12 @@ def compute_error(x, y):
600599
act_compare_dict = get_matching_activations(float_model, qmodel)
601600

602601
@skip_if_no_torchvision
603-
@unittest.skipIf(IS_ARM64, "Not working on arm right now")
604602
def test_mobilenet_v2(self):
605603
from torchvision.models.quantization import mobilenet_v2
606604

607605
self._test_vision_model(mobilenet_v2(pretrained=True, quantize=False))
608606

609607
@skip_if_no_torchvision
610-
@unittest.skipIf(IS_ARM64, "Not working on arm right now")
611608
def test_mobilenet_v3(self):
612609
from torchvision.models.quantization import mobilenet_v3_large
613610

test/test_linalg.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1401,8 +1401,6 @@ def run_test_case(input_size, ord, keepdim):
14011401

14021402
@dtypes(torch.float, torch.double, torch.cfloat, torch.cdouble, torch.bfloat16, torch.float16)
14031403
def test_vector_norm(self, device, dtype):
1404-
if IS_ARM64 and device == 'cpu' and dtype in [torch.float16, torch.bfloat16, torch.float32]:
1405-
raise unittest.SkipTest("Fails on ARM, see https://github.com/pytorch/pytorch/issues/125438")
14061404
# have to use torch.randn(...).to(bfloat16) instead of
14071405
# This test compares torch.linalg.vector_norm's output with
14081406
# torch.linalg.norm given a flattened tensor

0 commit comments

Comments
 (0)