Skip to content

torch.quantization.convert() doesn't remove observer from empty Sequential() module, causing error when scripting #28375

@hx89

Description

@hx89

🐛 Bug

torch.quantization.convert() doesn't remove observer from empty Sequential() module, causing errors when scripting the model since observers are not scriptable.

Steps to reproduce the behavior:

class Test(nn.Module):
    """Minimal repro module: its only submodule is an empty ``nn.Sequential``.

    ``torch.quantization.prepare`` attaches an observer to ``branch1``;
    the bug is that ``convert`` then fails to remove it.
    """
    def __init__(self):
        super(Test, self).__init__()
        # Empty container module — the observer gets attached here.
        self.branch1 = nn.Sequential()
    
    def forward(self, x):
        # Identity forward; branch1 is intentionally never invoked.
        return x
    
model = Test()
model.eval()  # quantization workflow expects eval mode
model.qconfig = torch.quantization.get_default_qconfig()
torch.quantization.prepare(model, inplace=True)   # inserts observers, including into the empty Sequential
torch.quantization.convert(model, inplace=True)   # BUG: observer on the empty Sequential is not removed
print(model)  # printed model still shows (activation_post_process): HistogramObserver()
torch.jit.script(model)  # raises UnsupportedNodeError: observer's forward uses `with`, which the scripter rejects

The printed model has observer in it:

Test(
  (branch1): Sequential(
    (activation_post_process): HistogramObserver()
  )
)

Got error:

---------------------------------------------------------------------------
UnsupportedNodeError                      Traceback (most recent call last)
<ipython-input-2-199c23a9d1bf> in <module>
     16 torch.quantization.convert(model, inplace=True)
     17 print(model)
---> 18 torch.jit.script(model)

/mnt/xarfuse/uid-136047/9da60e26-ns-4026531840/torch/jit/__init__.py in script(obj, optimize, _frames_up, _rcb)
   1237 
   1238     if isinstance(obj, torch.nn.Module):
-> 1239         return torch.jit.torch.jit._recursive.recursive_script(obj)
   1240 
   1241     qualified_name = _qualified_name(obj)

/mnt/xarfuse/uid-136047/9da60e26-ns-4026531840/torch/jit/_recursive.py in recursive_script(nn_module)
    506         return create_constant_iterable_module(nn_module)
    507 
--> 508     return create_script_module(nn_module, infer_methods_to_compile(nn_module))
    509 
    510 def try_compile_fn(fn, loc):

/mnt/xarfuse/uid-136047/9da60e26-ns-4026531840/torch/jit/_recursive.py in create_script_module(nn_module, stubs)
    303     """
    304     check_module_initialized(nn_module)
--> 305     concrete_type = concrete_type_store.get_or_create_concrete_type(nn_module)
    306     cpp_module = torch._C._create_module_with_type(concrete_type.jit_type)
    307 

/mnt/xarfuse/uid-136047/9da60e26-ns-4026531840/torch/jit/_recursive.py in get_or_create_concrete_type(self, nn_module)
    241             return scripted._concrete_type
    242 
--> 243         raw_concrete_type = infer_raw_concrete_type(nn_module)
    244 
    245         nn_module_type = type(nn_module)

/mnt/xarfuse/uid-136047/9da60e26-ns-4026531840/torch/jit/_recursive.py in infer_raw_concrete_type(nn_module)
     89 
     90     for name, item in nn_module._modules.items():
---> 91         sub_concrete_type = concrete_type_store.get_or_create_concrete_type(item)
     92         concrete_type.add_module(name, sub_concrete_type)
     93         added_names.add(name)

/mnt/xarfuse/uid-136047/9da60e26-ns-4026531840/torch/jit/_recursive.py in get_or_create_concrete_type(self, nn_module)
    238             # compilation path. But for now, just mimic what compilation does when
    239             # generating a ConcreteType
--> 240             scripted = create_constant_iterable_module(nn_module)
    241             return scripted._concrete_type
    242 

/mnt/xarfuse/uid-136047/9da60e26-ns-4026531840/torch/jit/_recursive.py in create_constant_iterable_module(module)
    537             modules[key] = create_constant_iterable_module(submodule)
    538         else:
--> 539             modules[key] = recursive_script(submodule)
    540 
    541     if isinstance(module, Sequential):

/mnt/xarfuse/uid-136047/9da60e26-ns-4026531840/torch/jit/_recursive.py in recursive_script(nn_module)
    506         return create_constant_iterable_module(nn_module)
    507 
--> 508     return create_script_module(nn_module, infer_methods_to_compile(nn_module))
    509 
    510 def try_compile_fn(fn, loc):

/mnt/xarfuse/uid-136047/9da60e26-ns-4026531840/torch/jit/_recursive.py in infer_methods_to_compile(nn_module)
    489     stubs = []
    490     for method in uniqued_methods:
--> 491         stubs.append(make_stub_from_method(nn_module, method))
    492     return overload_stubs + stubs
    493 

/mnt/xarfuse/uid-136047/9da60e26-ns-4026531840/torch/jit/_recursive.py in make_stub_from_method(nn_module, method)
     39     if isinstance(func, ScriptMethodStub):
     40         return func
---> 41     return make_stub(func)
     42 
     43 # base types that can be constants

/mnt/xarfuse/uid-136047/9da60e26-ns-4026531840/torch/jit/_recursive.py in make_stub(func)
     32 def make_stub(func):
     33     rcb = _jit_internal.createResolutionCallbackFromClosure(func)
---> 34     ast = torch.jit.get_jit_def(func, self_name="RecursiveScriptModule")
     35     return ScriptMethodStub(rcb, ast, func)
     36 

/mnt/xarfuse/uid-136047/9da60e26-ns-4026531840/torch/jit/frontend.py in get_jit_def(fn, self_name)
    167     type_line = torch.jit.annotations.get_type_line(source)
    168     ctx = SourceContext(source, filename, file_lineno, leading_whitespace_len, _uses_true_division(fn))
--> 169     return build_def(ctx, py_ast.body[0], type_line, self_name)
    170 
    171 

/mnt/xarfuse/uid-136047/9da60e26-ns-4026531840/torch/jit/frontend.py in build_def(ctx, py_def, type_line, self_name)
    208     return Def(Ident(r, py_def.name),
    209                decl,
--> 210                build_stmts(ctx, body))
    211 
    212 

/mnt/xarfuse/uid-136047/9da60e26-ns-4026531840/torch/jit/frontend.py in build_stmts(ctx, stmts)
    125 
    126 def build_stmts(ctx, stmts):
--> 127     stmts = [build_stmt(ctx, s) for s in stmts]
    128     return list(filter(None, stmts))
    129 

/mnt/xarfuse/uid-136047/9da60e26-ns-4026531840/torch/jit/frontend.py in <listcomp>(.0)
    125 
    126 def build_stmts(ctx, stmts):
--> 127     stmts = [build_stmt(ctx, s) for s in stmts]
    128     return list(filter(None, stmts))
    129 

/mnt/xarfuse/uid-136047/9da60e26-ns-4026531840/torch/jit/frontend.py in __call__(self, ctx, node)
    182         method = getattr(self, 'build_' + node.__class__.__name__, None)
    183         if method is None:
--> 184             raise UnsupportedNodeError(ctx, node)
    185         return method(ctx, node)
    186 

UnsupportedNodeError: with statements aren't supported:
at /mnt/xarfuse/uid-136047/9da60e26-ns-4026531840/torch/quantization/observer.py:550:8
    def forward(self, x):
        with torch.no_grad():
        ~~~~ <--- HERE
            min_val = self.min_val
            max_val = self.max_val

cc @suo @jerryzh168 @jianyuh @dzhulgakov @raghuramank100

Metadata

Metadata

Labels

oncall: jit — Add this issue/PR to the JIT oncall triage queue; oncall: quantization — Quantization support in PyTorch; triaged — This issue has been looked at by a team member, and triaged and prioritized into an appropriate module

Type

No type

Projects

No projects

Milestone

No milestone

Relationships

None yet

Development

No branches or pull requests

Issue actions