-
Notifications
You must be signed in to change notification settings - Fork 26.3k
Closed
Labels
oncall: jit (Add this issue/PR to JIT oncall triage queue); oncall: quantization (Quantization support in PyTorch); triaged (This issue has been looked at by a team member, and triaged and prioritized into an appropriate module)
Description
🐛 Bug
torch.quantization.convert() doesn't remove the observer from an empty Sequential() module, which causes errors when scripting the model because observers are not scriptable.
Steps to reproduce the behavior:
class Test(nn.Module):
    """Minimal module for the repro: holds one empty ``nn.Sequential`` child.

    The empty ``branch1`` container is never called in ``forward``; it exists
    only so that the quantization passes have an empty Sequential to visit.
    """

    def __init__(self):
        super(Test, self).__init__()
        # Intentionally empty container -- this is the module the report is about.
        self.branch1 = nn.Sequential()

    def forward(self, x):
        """Return the input unchanged (identity)."""
        return x
# Build the model and put it in eval mode before running the quantization steps.
model = Test()
model.eval()
# Attach the default qconfig to the root module.
model.qconfig = torch.quantization.get_default_qconfig()
# Run the prepare and convert steps in place on the same model object.
torch.quantization.prepare(model, inplace=True)
torch.quantization.convert(model, inplace=True)
# The printed structure (shown below) still lists an activation_post_process
# observer inside branch1.
print(model)
# This call raises the UnsupportedNodeError shown in the traceback below.
torch.jit.script(model)
The printed model still has an observer in it:
Test(
(branch1): Sequential(
(activation_post_process): HistogramObserver()
)
)
Got error:
---------------------------------------------------------------------------
UnsupportedNodeError Traceback (most recent call last)
<ipython-input-2-199c23a9d1bf> in <module>
16 torch.quantization.convert(model, inplace=True)
17 print(model)
---> 18 torch.jit.script(model)
/mnt/xarfuse/uid-136047/9da60e26-ns-4026531840/torch/jit/__init__.py in script(obj, optimize, _frames_up, _rcb)
1237
1238 if isinstance(obj, torch.nn.Module):
-> 1239 return torch.jit.torch.jit._recursive.recursive_script(obj)
1240
1241 qualified_name = _qualified_name(obj)
/mnt/xarfuse/uid-136047/9da60e26-ns-4026531840/torch/jit/_recursive.py in recursive_script(nn_module)
506 return create_constant_iterable_module(nn_module)
507
--> 508 return create_script_module(nn_module, infer_methods_to_compile(nn_module))
509
510 def try_compile_fn(fn, loc):
/mnt/xarfuse/uid-136047/9da60e26-ns-4026531840/torch/jit/_recursive.py in create_script_module(nn_module, stubs)
303 """
304 check_module_initialized(nn_module)
--> 305 concrete_type = concrete_type_store.get_or_create_concrete_type(nn_module)
306 cpp_module = torch._C._create_module_with_type(concrete_type.jit_type)
307
/mnt/xarfuse/uid-136047/9da60e26-ns-4026531840/torch/jit/_recursive.py in get_or_create_concrete_type(self, nn_module)
241 return scripted._concrete_type
242
--> 243 raw_concrete_type = infer_raw_concrete_type(nn_module)
244
245 nn_module_type = type(nn_module)
/mnt/xarfuse/uid-136047/9da60e26-ns-4026531840/torch/jit/_recursive.py in infer_raw_concrete_type(nn_module)
89
90 for name, item in nn_module._modules.items():
---> 91 sub_concrete_type = concrete_type_store.get_or_create_concrete_type(item)
92 concrete_type.add_module(name, sub_concrete_type)
93 added_names.add(name)
/mnt/xarfuse/uid-136047/9da60e26-ns-4026531840/torch/jit/_recursive.py in get_or_create_concrete_type(self, nn_module)
238 # compilation path. But for now, just mimic what compilation does when
239 # generating a ConcreteType
--> 240 scripted = create_constant_iterable_module(nn_module)
241 return scripted._concrete_type
242
/mnt/xarfuse/uid-136047/9da60e26-ns-4026531840/torch/jit/_recursive.py in create_constant_iterable_module(module)
537 modules[key] = create_constant_iterable_module(submodule)
538 else:
--> 539 modules[key] = recursive_script(submodule)
540
541 if isinstance(module, Sequential):
/mnt/xarfuse/uid-136047/9da60e26-ns-4026531840/torch/jit/_recursive.py in recursive_script(nn_module)
506 return create_constant_iterable_module(nn_module)
507
--> 508 return create_script_module(nn_module, infer_methods_to_compile(nn_module))
509
510 def try_compile_fn(fn, loc):
/mnt/xarfuse/uid-136047/9da60e26-ns-4026531840/torch/jit/_recursive.py in infer_methods_to_compile(nn_module)
489 stubs = []
490 for method in uniqued_methods:
--> 491 stubs.append(make_stub_from_method(nn_module, method))
492 return overload_stubs + stubs
493
/mnt/xarfuse/uid-136047/9da60e26-ns-4026531840/torch/jit/_recursive.py in make_stub_from_method(nn_module, method)
39 if isinstance(func, ScriptMethodStub):
40 return func
---> 41 return make_stub(func)
42
43 # base types that can be constants
/mnt/xarfuse/uid-136047/9da60e26-ns-4026531840/torch/jit/_recursive.py in make_stub(func)
32 def make_stub(func):
33 rcb = _jit_internal.createResolutionCallbackFromClosure(func)
---> 34 ast = torch.jit.get_jit_def(func, self_name="RecursiveScriptModule")
35 return ScriptMethodStub(rcb, ast, func)
36
/mnt/xarfuse/uid-136047/9da60e26-ns-4026531840/torch/jit/frontend.py in get_jit_def(fn, self_name)
167 type_line = torch.jit.annotations.get_type_line(source)
168 ctx = SourceContext(source, filename, file_lineno, leading_whitespace_len, _uses_true_division(fn))
--> 169 return build_def(ctx, py_ast.body[0], type_line, self_name)
170
171
/mnt/xarfuse/uid-136047/9da60e26-ns-4026531840/torch/jit/frontend.py in build_def(ctx, py_def, type_line, self_name)
208 return Def(Ident(r, py_def.name),
209 decl,
--> 210 build_stmts(ctx, body))
211
212
/mnt/xarfuse/uid-136047/9da60e26-ns-4026531840/torch/jit/frontend.py in build_stmts(ctx, stmts)
125
126 def build_stmts(ctx, stmts):
--> 127 stmts = [build_stmt(ctx, s) for s in stmts]
128 return list(filter(None, stmts))
129
/mnt/xarfuse/uid-136047/9da60e26-ns-4026531840/torch/jit/frontend.py in <listcomp>(.0)
125
126 def build_stmts(ctx, stmts):
--> 127 stmts = [build_stmt(ctx, s) for s in stmts]
128 return list(filter(None, stmts))
129
/mnt/xarfuse/uid-136047/9da60e26-ns-4026531840/torch/jit/frontend.py in __call__(self, ctx, node)
182 method = getattr(self, 'build_' + node.__class__.__name__, None)
183 if method is None:
--> 184 raise UnsupportedNodeError(ctx, node)
185 return method(ctx, node)
186
UnsupportedNodeError: with statements aren't supported:
at /mnt/xarfuse/uid-136047/9da60e26-ns-4026531840/torch/quantization/observer.py:550:8
def forward(self, x):
with torch.no_grad():
~~~~ <--- HERE
min_val = self.min_val
max_val = self.max_val
Metadata
Metadata
Assignees
Labels
oncall: jit (Add this issue/PR to JIT oncall triage queue); oncall: quantization (Quantization support in PyTorch); triaged (This issue has been looked at by a team member, and triaged and prioritized into an appropriate module)