Closed
Labels: oncall: pt2, oncall: quantization, triaged
Description
🐛 Describe the bug
A graph that concatenates a node with itself (e.g. `torch.cat([x, x])`) raises a `RecursionError` in `prepare_pt2e`.
```python
import torch
import torch.export._trace
from torch import nn
from torch.ao.quantization.quantize_pt2e import prepare_pt2e
from torch.ao.quantization.quantizer.xnnpack_quantizer import (
    XNNPACKQuantizer,
    get_symmetric_quantization_config,
)


class Net(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.conv = nn.Conv2d(3, 3, 3)

    def forward(self, x):
        x = self.conv(x)  # not necessary to have a conv to reproduce the bug
        x = torch.cat([x, x])
        return x


model = Net()
quantizer = XNNPACKQuantizer().set_global(get_symmetric_quantization_config())
model = torch.export._trace._export(
    model, (torch.rand(1, 3, 32, 32),), pre_dispatch=True
).module()
model = prepare_pt2e(model, quantizer)
```

Output
```
Traceback (most recent call last):
  File "reproduce_cat_recursion_error_bug.py", line 28, in <module>
    model = prepare_pt2e(model, quantizer)
  File "/home/chensf/git/pytorch-quantization/env/lib/python3.8/site-packages/torch/ao/quantization/quantize_pt2e.py", line 109, in prepare_pt2e
    model = prepare(model, node_name_to_scope, is_qat=False)
  File "/home/chensf/git/pytorch-quantization/env/lib/python3.8/site-packages/torch/ao/quantization/pt2e/prepare.py", line 470, in prepare
    edge_or_node_to_group_id = _get_edge_or_node_to_group_id(edge_or_node_to_qspec)
  File "/home/chensf/git/pytorch-quantization/env/lib/python3.8/site-packages/torch/ao/quantization/pt2e/prepare.py", line 196, in _get_edge_or_node_to_group_id
    input_edge_root_qspec = _unwrap_shared_qspec(qspec, edge_or_node_to_qspec, shared_with_map)
  File "/home/chensf/git/pytorch-quantization/env/lib/python3.8/site-packages/torch/ao/quantization/pt2e/prepare.py", line 83, in _unwrap_shared_qspec
    return _unwrap_shared_qspec(qspec, edge_or_node_to_qspec, shared_with_map)
  File "/home/chensf/git/pytorch-quantization/env/lib/python3.8/site-packages/torch/ao/quantization/pt2e/prepare.py", line 83, in _unwrap_shared_qspec
    return _unwrap_shared_qspec(qspec, edge_or_node_to_qspec, shared_with_map)
  File "/home/chensf/git/pytorch-quantization/env/lib/python3.8/site-packages/torch/ao/quantization/pt2e/prepare.py", line 83, in _unwrap_shared_qspec
    return _unwrap_shared_qspec(qspec, edge_or_node_to_qspec, shared_with_map)
  [Previous line repeated 991 more times]
  File "/home/chensf/git/pytorch-quantization/env/lib/python3.8/site-packages/torch/ao/quantization/pt2e/prepare.py", line 81, in _unwrap_shared_qspec
    root = _find_root_edge_or_node(sharing_with, shared_with_map)
  File "/home/chensf/git/pytorch-quantization/env/lib/python3.8/site-packages/torch/ao/quantization/pt2e/prepare.py", line 42, in _find_root_edge_or_node
    parent = shared_with_map[edge_or_node]
RecursionError: maximum recursion depth exceeded in comparison
```
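The traceback shows `_unwrap_shared_qspec` calling itself until the recursion limit is hit. The sketch below is a simplified pure-Python model of that pattern, not the PyTorch source: `QSpec`, `SharedQSpec`, `find_root`, `unwrap_shared_naive` and `unwrap_shared` are hypothetical stand-ins. It assumes a shared spec is resolved by finding the union-find root of the edge it points to and then unwrapping that root's spec. If an edge's spec shares with the edge itself, the naive recursion never reaches a concrete spec; the guarded variant tracks visited roots and raises a readable error instead.

```python
from dataclasses import dataclass
from typing import Dict, Set, Tuple, Union

# Hypothetical model: an edge is a (producer, consumer) pair of node names.
Edge = Tuple[str, str]


@dataclass(frozen=True)
class QSpec:
    """Stand-in for a concrete quantization spec."""
    dtype: str


@dataclass(frozen=True)
class SharedQSpec:
    """Stand-in for a spec meaning 'reuse the qparams of this edge'."""
    edge_or_node: Edge


Spec = Union[QSpec, SharedQSpec]


def find_root(edge: Edge, shared_with_map: Dict[Edge, Edge]) -> Edge:
    """Union-find lookup; a root edge maps to itself."""
    parent = shared_with_map[edge]
    if parent == edge:
        return edge
    root = find_root(parent, shared_with_map)
    shared_with_map[edge] = root  # path compression
    return root


def unwrap_shared_naive(spec: Spec, edge_to_spec: Dict[Edge, Spec],
                        shared_with_map: Dict[Edge, Edge]) -> QSpec:
    """Recursive unwrap with no cycle check (loops forever on a self-share)."""
    if isinstance(spec, SharedQSpec):
        root = find_root(spec.edge_or_node, shared_with_map)
        return unwrap_shared_naive(edge_to_spec[root], edge_to_spec, shared_with_map)
    return spec


def unwrap_shared(spec: Spec, edge_to_spec: Dict[Edge, Spec],
                  shared_with_map: Dict[Edge, Edge]) -> QSpec:
    """Iterative unwrap that reports a cycle instead of recursing forever."""
    seen: Set[Edge] = set()
    while isinstance(spec, SharedQSpec):
        root = find_root(spec.edge_or_node, shared_with_map)
        if root in seen:
            raise ValueError(f"SharedQSpec cycle through edge {root}")
        seen.add(root)
        spec = edge_to_spec[root]
    return spec
```

With an illustrative edge `("conv2d", "cat")` whose spec is `SharedQSpec(("conv2d", "cat"))` and `shared_with_map = {edge: edge}`, `unwrap_shared_naive` raises `RecursionError`, mirroring the traceback, while `unwrap_shared` raises `ValueError` naming the cyclic edge. A normal chain (edge `b` sharing with edge `a`, which has a concrete spec) resolves to `a`'s spec in both versions.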
Note that `edge_or_node` is `(conv2d, cat)` when it fails.
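One plausible way for `(conv2d, cat)` to end up sharing with itself, assuming the quantizer stores per-input specs in a dict keyed by input node (as the `input_qspec_map` field of a pt2e `QuantizationAnnotation` is): the first `cat` input gets a concrete spec and every other input gets a spec that shares with the edge `(inputs[0], cat)`. For `torch.cat([x, x])` both inputs are the same key, so the second assignment overwrites the first. The functions below are hypothetical and only model this bookkeeping with node names as strings; the guarded variant skips inputs that are already in the map.

```python
from dataclasses import dataclass
from typing import Dict, List, Tuple, Union


@dataclass(frozen=True)
class QSpec:
    """Stand-in for a concrete quantization spec."""
    dtype: str


@dataclass(frozen=True)
class SharedQSpec:
    """Stand-in for 'share qparams with this (producer, consumer) edge'."""
    edge_or_node: Tuple[str, str]


Spec = Union[QSpec, SharedQSpec]


def annotate_cat_inputs(inputs: List[str], cat: str) -> Dict[str, Spec]:
    """Hypothetical cat annotation keyed by input node name."""
    first = inputs[0]
    input_qspec_map: Dict[str, Spec] = {first: QSpec("int8")}
    shared = SharedQSpec((first, cat))
    for node in inputs[1:]:
        # For cat([x, x]), node == first, so this overwrites the concrete
        # spec and edge (first, cat) now shares with itself.
        input_qspec_map[node] = shared
    return input_qspec_map


def annotate_cat_inputs_guarded(inputs: List[str], cat: str) -> Dict[str, Spec]:
    """Same as above, but a repeated input keeps its existing spec."""
    first = inputs[0]
    input_qspec_map: Dict[str, Spec] = {first: QSpec("int8")}
    shared = SharedQSpec((first, cat))
    for node in inputs[1:]:
        if node not in input_qspec_map:
            input_qspec_map[node] = shared
    return input_qspec_map
```

For example, `annotate_cat_inputs(["conv2d", "conv2d"], "cat")` returns a map whose only entry maps `"conv2d"` to `SharedQSpec(("conv2d", "cat"))`, i.e. the self-referencing edge noted above, while the guarded variant keeps `QSpec("int8")` for that input.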
Versions
torch==2.3.1
cc @jerryzh168 @jianyuh @raghuramank100 @jamesr66a @vkuzo @jgong5 @Xia-Weiwen @leslie-fang-intel @msaroufim @ezyang @anijain2305 @chauhang @penguinwu @avikchaudhuri @gmagogsfm @zhxchen17 @tugsbayasgalan @angelayi @suo @ydwu4
dgcnz