Skip to content

Commit 9cc44aa

Browse files
Zafar Takhirovfacebook-github-bot
authored andcommitted
[quant] AO migration of the quantize.py (resubmission) (#64445)
Summary: Pull Request resolved: #64445 AO Team is migrating the existing torch.quantization into torch.ao.quantization. We are doing it one file at a time to make sure that the internal callsites are updated properly. This migrates the quantize.py from torch.quantization to torch.ao.quantization. At this point both locations will be supported. Eventually the torch.quantization will be deprecated. Test Plan: `buck test mode/dev //caffe2/test:quantization` Reviewed By: HDCharles Differential Revision: D30734870 fbshipit-source-id: dc204f3cc46bff2cc81c95159eab9d333b43bb4b
1 parent 72274e2 commit 9cc44aa

File tree

10 files changed

+677
-581
lines changed

10 files changed

+677
-581
lines changed

test/quantization/ao_migration/__init__.py

Whitespace-only changes.
Lines changed: 63 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,63 @@
1+
from torch.testing._internal.common_utils import TestCase
2+
3+
import importlib
4+
from typing import List
5+
6+
7+
class AOMigrationTestCase(TestCase):
8+
def _test_package_import(self, package_name: str):
9+
r"""Tests the module import by making sure that all the internals match
10+
(except the dunder methods)."""
11+
old_module = importlib.import_module(f'torch.quantization.{package_name}')
12+
new_module = importlib.import_module(f'torch.ao.quantization.{package_name}')
13+
old_module_dir = set(dir(old_module))
14+
new_module_dir = set(dir(new_module))
15+
# Remove magic modules from checking in subsets
16+
for el in list(old_module_dir):
17+
if el[:2] == '__' and el[-2:] == '__':
18+
old_module_dir.remove(el)
19+
assert (old_module_dir <= new_module_dir), \
20+
f"Importing {old_module} vs. {new_module} does not match: " \
21+
f"{old_module_dir - new_module_dir}"
22+
23+
def _test_function_import(self, package_name: str, function_list: List[str]):
24+
r"""Tests individual function list import by comparing the functions
25+
and their hashes."""
26+
old_location = importlib.import_module(f'torch.quantization.{package_name}')
27+
new_location = importlib.import_module(f'torch.ao.quantization.{package_name}')
28+
for fn_name in function_list:
29+
old_function = getattr(old_location, fn_name)
30+
new_function = getattr(new_location, fn_name)
31+
assert old_function == new_function, f"Functions don't match: {fn_name}"
32+
assert hash(old_function) == hash(new_function), \
33+
f"Hashes don't match: {old_function}({hash(old_function)}) vs. " \
34+
f"{new_function}({hash(new_function)})"
35+
36+
37+
class TestAOMigrationQuantizePy(AOMigrationTestCase):
38+
def test_package_import(self):
39+
self._test_package_import('quantize')
40+
41+
def test_function_import(self):
42+
function_list = [
43+
'_convert',
44+
'_observer_forward_hook',
45+
'_propagate_qconfig_helper',
46+
'_remove_activation_post_process',
47+
'_remove_qconfig',
48+
'add_observer_',
49+
'add_quant_dequant',
50+
'convert',
51+
'get_observer_dict',
52+
'get_unique_devices_',
53+
'is_activation_post_process',
54+
'prepare',
55+
'prepare_qat',
56+
'propagate_qconfig_',
57+
'quantize',
58+
'quantize_dynamic',
59+
'quantize_qat',
60+
'register_activation_post_process_hook',
61+
'swap_module',
62+
]
63+
self._test_function_import('quantize', function_list)

test/test_quantization.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -100,6 +100,9 @@
100100
from quantization.jit.test_fusion_passes import TestFusionPasses # noqa: F401
101101
from quantization.jit.test_deprecated_jit_quant import TestDeprecatedJitQuantized # noqa: F401
102102

103+
# AO Migration tests
104+
from quantization.ao_migration.test_quantize import TestAOMigrationQuantizePy # noqa: F401
105+
103106

104107
if __name__ == '__main__':
105108
run_tests()

torch/ao/quantization/__init__.py

Whitespace-only changes.

0 commit comments

Comments
 (0)