
Commit 09ce760

Revert "Add missing data types at torch export serialization (#138561)"
This reverts commit 1ef1b3b. Reverted #138561 on behalf of https://github.com/facebook-github-bot because the diff was reverted internally (see the comment on #138561).
1 parent 4959784 · commit 09ce760
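The reverted change (#138561) had added the two float8 dtypes to the export serde schema; this revert removes them again from the ScalarType enum and the dtype mapping. A minimal sketch of what that implies (hypothetical repro, not part of the commit): serializing an exported program whose graph carries a float8 tensor should once more fail, since the dtype no longer maps to a ScalarType.

# Hypothetical repro sketch (not from the commit): after this revert,
# torch.float8_e5m2 is absent from the dtype -> ScalarType table in
# torch/_export/serde/serialize.py, so saving should fail again.
import torch
from torch.export import export

class M(torch.nn.Module):
    def forward(self, x):
        return x.to(torch.float8_e5m2)  # float8 tensor meta ends up in the graph

ep = export(M(), (torch.randn(4),))
torch.export.save(ep, "m.pt2")  # expected to raise: unmapped float8 dtype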

File tree

6 files changed: +41 -57 lines changed


test/quantization/pt2e/test_quantize_pt2e.py

Lines changed: 34 additions & 42 deletions
@@ -1174,44 +1174,40 @@ def validate(self, model: torch.fx.GraphModule) -> None:
         self.assertIsNot(observers[0], observers[2])
         self.assertIsNot(observers[1], observers[2])
 
-    class DtypeActQuantizer(Quantizer):
-        def __init__(self, quant_dtype, op_name):
-            self.quant_dtype = quant_dtype
-            self.op_name = op_name
-
-        def annotate(self, model: torch.fx.GraphModule) -> torch.fx.GraphModule:
-            quant_dtype = self.quant_dtype
-            info_fun = torch.iinfo if quant_dtype == torch.int16 else torch.finfo
-            activate_qspec = QuantizationSpec(
-                dtype=quant_dtype,
-                quant_min=int(info_fun(quant_dtype).min),
-                quant_max=int(info_fun(quant_dtype).max),
-                qscheme=torch.per_tensor_affine,
-                is_dynamic=False,
-                observer_or_fake_quant_ctr=observer.default_observer,
-            )
-            int8_qspec = QuantizationSpec(
-                dtype=torch.int8,
-                quant_min=-128,
-                quant_max=127,
-                qscheme=torch.per_tensor_symmetric,
-                is_dynamic=False,
-                observer_or_fake_quant_ctr=observer.default_weight_observer,
-            )
-            quantization_config = QuantizationConfig(
-                input_activation=activate_qspec,
-                weight=int8_qspec,
-                bias=None,
-                output_activation=activate_qspec,
-            )
-            OP_TO_ANNOTATOR[self.op_name](model, quantization_config)
-
-        def validate(self, model: torch.fx.GraphModule) -> None:
-            pass
-
+    @skipIfHpu
     @parametrize("dtype", (torch.float32, torch.bfloat16))
     @parametrize("quant_dtype", (torch.int16, torch.float8_e5m2, torch.float8_e4m3fn))
     def test_quantization_dtype(self, dtype, quant_dtype):
+        class DtypeActQuantizer(Quantizer):
+            def annotate(self, model: torch.fx.GraphModule) -> torch.fx.GraphModule:
+                info_fun = torch.iinfo if quant_dtype == torch.int16 else torch.finfo
+                activate_qspec = QuantizationSpec(
+                    dtype=quant_dtype,
+                    quant_min=int(info_fun(quant_dtype).min),
+                    quant_max=int(info_fun(quant_dtype).max),
+                    qscheme=torch.per_tensor_affine,
+                    is_dynamic=False,
+                    observer_or_fake_quant_ctr=observer.default_observer,
+                )
+                int8_qspec = QuantizationSpec(
+                    dtype=torch.int8,
+                    quant_min=-128,
+                    quant_max=127,
+                    qscheme=torch.per_tensor_symmetric,
+                    is_dynamic=False,
+                    observer_or_fake_quant_ctr=observer.default_weight_observer,
+                )
+                quantization_config = QuantizationConfig(
+                    input_activation=activate_qspec,
+                    weight=int8_qspec,
+                    bias=None,
+                    output_activation=activate_qspec,
+                )
+                OP_TO_ANNOTATOR["conv"](model, quantization_config)
+
+            def validate(self, model: torch.fx.GraphModule) -> None:
+                pass
+
         class M(torch.nn.Module):
             def __init__(self, dtype):
                 super().__init__()
@@ -1220,7 +1216,7 @@ def __init__(self, dtype):
             def forward(self, x):
                 return self.conv(x)
 
-        quantizer = self.DtypeActQuantizer(quant_dtype=quant_dtype, op_name="conv")
+        quantizer = DtypeActQuantizer()
         node_occurrence = {
             # one for input of the first conv, one for output for the first conv
             torch.ops.quantized_decomposed.quantize_per_tensor.default: 2,
@@ -1456,13 +1452,9 @@ def forward(self, x):
             for key in n.meta:
                 self.assertEqual(n.meta[key], weight_meta[key])
 
-    @parametrize("quant_dtype", (torch.float32, torch.float8_e5m2, torch.float8_e4m3fn))
-    def test_save_load(self, quant_dtype=None):
+    def test_save_load(self):
         """Test save/load a quantized model"""
-        quantizer = None
-        if quant_dtype != torch.float32:
-            quantizer = self.DtypeActQuantizer(quant_dtype=quant_dtype, op_name="conv")
-        m = self._get_pt2e_quantized_linear(quantizer=quantizer)
+        m = self._get_pt2e_quantized_linear()
         example_inputs = (torch.randn(2, 2),)
         ref_res = m(*example_inputs)
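For context, a condensed sketch of how a Quantizer such as the nested DtypeActQuantizer above is typically driven end to end. The helper name run_quantizer and the export call are assumptions mirroring the usual prepare_pt2e/convert_pt2e flow, not this test's literal body:

# Assumed end-to-end pt2e flow (sketch, not the test's exact code):
import torch
from torch.ao.quantization.quantize_pt2e import convert_pt2e, prepare_pt2e

def run_quantizer(model, quantizer, example_inputs):
    # Trace to an FX graph, insert observers per the quantizer's annotations,
    # run one calibration pass, then lower observers to quant/dequant ops.
    m = torch.export.export_for_training(model, example_inputs).module()
    m = prepare_pt2e(m, quantizer)
    m(*example_inputs)
    return convert_pt2e(m)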

torch/_export/serde/schema.py

Lines changed: 0 additions & 2 deletions
@@ -27,8 +27,6 @@ class ScalarType(IntEnum):
     COMPLEXDOUBLE = 11
     BOOL = 12
     BFLOAT16 = 13
-    FLOAT8_E5M2 = 23
-    FLOAT8_E4M3FN = 24
     UINT16 = 28
 
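One practical consequence worth noting (illustrative sketch, not code from the diff): IntEnum lookups by value raise on members that no longer exist, so an artifact serialized with the float8 values cannot be deserialized against the reverted schema.

# Illustrative only: a trimmed stand-in for the serde ScalarType enum.
from enum import IntEnum

class ScalarType(IntEnum):
    BOOL = 12
    BFLOAT16 = 13
    UINT16 = 28

ScalarType(13)  # -> ScalarType.BFLOAT16
ScalarType(23)  # raises ValueError: 23 is not a valid ScalarType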

torch/_export/serde/schema.yaml

Lines changed: 1 addition & 3 deletions
@@ -1,5 +1,5 @@
 # @generated by update_schema.py
-# checksum<<976f1a95674e0e9ca72f7bb7df2e648172aaae0ca43a52f7c6f814c79a96ddf4>>
+# checksum<<19d86105f895a10d5eedbc6e13d4d96cf5d9182c0367d6825ef2438e124cc536>>
 Argument:
   kind: union
   fields:
@@ -338,8 +338,6 @@ ScalarType:
   COMPLEXDOUBLE: 11
   BOOL: 12
   BFLOAT16: 13
-  FLOAT8_E5M2: 23
-  FLOAT8_E4M3FN: 24
   UINT16: 28
 SchemaVersion:
   kind: struct
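The checksum header is 64 hex characters, which is consistent with a SHA-256 digest; as a hedged sketch of how such a header could be recomputed (exactly what update_schema.py hashes, and whether it is SHA-256 at all, is an assumption inferred from the format):

# Sketch only: hashlib is real, but the hashed input is assumed.
import hashlib

def schema_checksum(schema_text: str) -> str:
    return hashlib.sha256(schema_text.encode("utf-8")).hexdigest()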

torch/_export/serde/serialize.py

Lines changed: 0 additions & 2 deletions
@@ -140,8 +140,6 @@ def _reverse_map(d: Dict[Any, Enum]):
     torch.complex128: ScalarType.COMPLEXDOUBLE,
     torch.bool: ScalarType.BOOL,
     torch.bfloat16: ScalarType.BFLOAT16,
-    torch.float8_e4m3fn: ScalarType.FLOAT8_E4M3FN,
-    torch.float8_e5m2: ScalarType.FLOAT8_E5M2,
 }
 
147145

torch/csrc/utils/generated_serialization_types.h

Lines changed: 1 addition & 3 deletions
Some generated files are not rendered by default.

torch/testing/_internal/common_quantization.py

Lines changed: 5 additions & 5 deletions
@@ -1333,18 +1333,18 @@ def _quantize(self, m, quantizer, example_inputs, is_qat: bool = False):
         m = convert_pt2e(m)
         return m
 
-    def _get_pt2e_quantized_linear(self, is_per_channel=False, quantizer=None) -> torch.fx.GraphModule:
+    def _get_pt2e_quantized_linear(self, is_per_channel=False) -> torch.fx.GraphModule:
         class M(torch.nn.Module):
             def __init__(self) -> None:
                 super().__init__()
                 self.linear = torch.nn.Linear(2, 2)
 
             def forward(self, x):
                 return self.linear(x)
-        if quantizer is None:
-            quantizer = XNNPACKQuantizer()
-            operator_config = get_symmetric_quantization_config(is_per_channel=is_per_channel)
-            quantizer.set_global(operator_config)
+
+        quantizer = XNNPACKQuantizer()
+        operator_config = get_symmetric_quantization_config(is_per_channel=is_per_channel)
+        quantizer.set_global(operator_config)
         example_inputs = (torch.randn(2, 2),)
         m = M().eval()
         return self._quantize(m, quantizer, example_inputs)
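After the revert, the helper always builds an XNNPACKQuantizer itself. A standalone sketch of that setup; the import path is an assumption based on where XNNPACKQuantizer lived in this era of the codebase:

# Assumed import path for this vintage of PyTorch; the two calls mirror the
# restored helper body above.
from torch.ao.quantization.quantizer.xnnpack_quantizer import (
    XNNPACKQuantizer,
    get_symmetric_quantization_config,
)

quantizer = XNNPACKQuantizer()
quantizer.set_global(get_symmetric_quantization_config(is_per_channel=False))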
