[PT2E Quantization] Graph with concatenation of the same node will raise RecursionError when prepare_pt2e #129038

@siahuat0727

🐛 Describe the bug

A graph that concatenates a node with itself, e.g. torch.cat([x, x]), raises a RecursionError in prepare_pt2e.

import torch
import torch.export._trace
from torch import nn
from torch.ao.quantization.quantize_pt2e import prepare_pt2e
from torch.ao.quantization.quantizer.xnnpack_quantizer import (
    XNNPACKQuantizer,
    get_symmetric_quantization_config,
)

class Net(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.conv = nn.Conv2d(3, 3, 3)

    def forward(self, x):
        x = self.conv(x)  # the conv is not necessary to reproduce the bug
        x = torch.cat([x, x])
        return x

model = Net()
quantizer = XNNPACKQuantizer().set_global(get_symmetric_quantization_config())
model = torch.export._trace._export(
    model, (torch.rand(1, 3, 32, 32),), pre_dispatch=True
).module()
model = prepare_pt2e(model, quantizer)

Output

Traceback (most recent call last):
  File "reproduce_cat_recursion_error_bug.py", line 28, in <module>
    model = prepare_pt2e(model, quantizer)
  File "/home/chensf/git/pytorch-quantization/env/lib/python3.8/site-packages/torch/ao/quantization/quantize_pt2e.py", line 109, in prepare_pt2e
    model = prepare(model, node_name_to_scope, is_qat=False)
  File "/home/chensf/git/pytorch-quantization/env/lib/python3.8/site-packages/torch/ao/quantization/pt2e/prepare.py", line 470, in prepare
    edge_or_node_to_group_id = _get_edge_or_node_to_group_id(edge_or_node_to_qspec)
  File "/home/chensf/git/pytorch-quantization/env/lib/python3.8/site-packages/torch/ao/quantization/pt2e/prepare.py", line 196, in _get_edge_or_node_to_group_id
    input_edge_root_qspec = _unwrap_shared_qspec(qspec, edge_or_node_to_qspec, shared_with_map)
  File "/home/chensf/git/pytorch-quantization/env/lib/python3.8/site-packages/torch/ao/quantization/pt2e/prepare.py", line 83, in _unwrap_shared_qspec
    return _unwrap_shared_qspec(qspec, edge_or_node_to_qspec, shared_with_map)
  File "/home/chensf/git/pytorch-quantization/env/lib/python3.8/site-packages/torch/ao/quantization/pt2e/prepare.py", line 83, in _unwrap_shared_qspec
    return _unwrap_shared_qspec(qspec, edge_or_node_to_qspec, shared_with_map)
  File "/home/chensf/git/pytorch-quantization/env/lib/python3.8/site-packages/torch/ao/quantization/pt2e/prepare.py", line 83, in _unwrap_shared_qspec
    return _unwrap_shared_qspec(qspec, edge_or_node_to_qspec, shared_with_map)
  [Previous line repeated 991 more times]
  File "/home/chensf/git/pytorch-quantization/env/lib/python3.8/site-packages/torch/ao/quantization/pt2e/prepare.py", line 81, in _unwrap_shared_qspec
    root = _find_root_edge_or_node(sharing_with, shared_with_map)
  File "/home/chensf/git/pytorch-quantization/env/lib/python3.8/site-packages/torch/ao/quantization/pt2e/prepare.py", line 42, in _find_root_edge_or_node
    parent = shared_with_map[edge_or_node]
RecursionError: maximum recursion depth exceeded in comparison

Note that edge_or_node is (conv2d, cat) when the failure occurs.
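A pure-Python sketch of how a self-referencing shared spec could produce this kind of unbounded recursion. This is a toy model and not the PyTorch implementation: SharedSpec, annotate_cat, find_root, and unwrap are hypothetical stand-ins. It also assumes that the cat annotation gives the first input edge a concrete spec and points every later input edge at that first edge, which is inferred from the traceback and not confirmed here.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class SharedSpec:
    """Hypothetical stand-in for a spec meaning 'share with this edge'."""
    edge: tuple


def annotate_cat(inputs, act_spec="int8_act"):
    """Toy annotation (assumed behavior): the first input edge gets a
    concrete spec, every later input edge shares with the first one."""
    edge_to_qspec = {}
    first_edge = (inputs[0], "cat")
    edge_to_qspec[first_edge] = act_spec
    for inp in inputs[1:]:
        # If inp is the same node as inputs[0], this key equals first_edge,
        # so the concrete spec is overwritten by a spec that shares with itself.
        edge_to_qspec[(inp, "cat")] = SharedSpec(first_edge)
    return edge_to_qspec


def find_root(edge, shared_with_map):
    """Union-find style lookup with path compression."""
    parent = shared_with_map[edge]
    if parent == edge:
        return edge
    root = find_root(parent, shared_with_map)
    shared_with_map[edge] = root
    return root


def unwrap(qspec, edge_to_qspec, shared_with_map):
    """Follow shared specs until a concrete spec is reached."""
    if isinstance(qspec, SharedSpec):
        root = find_root(qspec.edge, shared_with_map)
        return unwrap(edge_to_qspec[root], edge_to_qspec, shared_with_map)
    return qspec


def resolve_last_input(inputs):
    """Resolve the spec of the last input edge of a toy cat node."""
    specs = annotate_cat(inputs)
    shared_with_map = {edge: edge for edge in specs}
    return unwrap(specs[(inputs[-1], "cat")], specs, shared_with_map)
```

In this toy model, resolve_last_input(["a", "b"]) reaches the concrete spec, while resolve_last_input(["conv2d", "conv2d"]) never does: the only edge, (conv2d, cat), shares with itself, so unwrap keeps calling itself until Python raises RecursionError.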

Versions

torch==2.3.1

cc @jerryzh168 @jianyuh @raghuramank100 @jamesr66a @vkuzo @jgong5 @Xia-Weiwen @leslie-fang-intel @msaroufim @ezyang @anijain2305 @chauhang @penguinwu @avikchaudhuri @gmagogsfm @zhxchen17 @tugsbayasgalan @angelayi @suo @ydwu4
