
Cannot print LayerNorm in ScriptModule #20978

@draplater

🐛 Bug

Version: master (1.2.0.dev20190526)
Printing a ScriptModule that contains a LayerNorm raises KeyError: 'elementwise_affine'.

To Reproduce

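```python
import torch
from torch.nn import LayerNorm
from torch.jit import ScriptModule


class Test(ScriptModule):
    def __init__(self, dim):
        super().__init__()
        # LayerNorm submodule inside a ScriptModule
        self.layer_norm = LayerNorm(dim)


if __name__ == '__main__':
    m = Test(100)
    # On 1.2.0.dev20190526 this print raises KeyError: 'elementwise_affine'
    print(m)
```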

The result is:

```
Traceback (most recent call last):
  File "/home/chenyufei/Development/nn-parser/local_scripts/pytorch_1_2_master_bug.py", line 14, in <module>
    print(m)
  File "/home/chenyufei/.local/anaconda3.7/lib/python3.7/site-packages/torch/nn/modules/module.py", line 1044, in __repr__
    mod_str = repr(module)
  File "/home/chenyufei/.local/anaconda3.7/lib/python3.7/site-packages/torch/nn/modules/module.py", line 1038, in __repr__
    extra_repr = self.extra_repr()
  File "/home/chenyufei/.local/anaconda3.7/lib/python3.7/site-packages/torch/nn/modules/normalization.py", line 161, in extra_repr
    'elementwise_affine={elementwise_affine}'.format(**self.__dict__)
KeyError: 'elementwise_affine'
```
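The failing line builds the repr string with `format(**self.__dict__)`, which only sees attributes stored directly in the instance `__dict__`. The stdlib-only sketch below reproduces that failure mode, assuming (this is not confirmed by the traceback) that the proxy object exposes `elementwise_affine` through some other route, such as `__getattr__`, while `normalized_shape` and `eps` are plain instance attributes. `ProxyLike`, `_hidden`, `extra_repr_dict` and `extra_repr_getattr` are hypothetical names, not PyTorch APIs.

```python
class ProxyLike:
    """Hypothetical stand-in for a proxy module; not a PyTorch class."""

    def __init__(self, normalized_shape, eps, elementwise_affine):
        # Ordinary instance attributes: these land in self.__dict__.
        self.normalized_shape = normalized_shape
        self.eps = eps
        # Hypothetical side table: this value is NOT a top-level key of __dict__.
        self._hidden = {"elementwise_affine": elementwise_affine}

    def __getattr__(self, name):
        # Only called when normal lookup (including __dict__) fails.
        hidden = self.__dict__.get("_hidden", {})
        if name in hidden:
            return hidden[name]
        raise AttributeError(name)

    def extra_repr_dict(self):
        # Same pattern as the traceback: format(**self.__dict__) never consults __getattr__.
        return ("{normalized_shape}, eps={eps}, "
                "elementwise_affine={elementwise_affine}").format(**self.__dict__)

    def extra_repr_getattr(self):
        # Explicit attribute access falls back to __getattr__, so it succeeds.
        return "{}, eps={}, elementwise_affine={}".format(
            self.normalized_shape, self.eps, self.elementwise_affine)


m = ProxyLike((100,), 1e-5, True)
print(m.extra_repr_getattr())  # (100,), eps=1e-05, elementwise_affine=True
try:
    m.extra_repr_dict()
except KeyError as e:
    print("KeyError:", e)  # KeyError: 'elementwise_affine'
```

Reading attributes explicitly instead of unpacking `__dict__` is one way such a repr could tolerate proxies; whether that matches the actual upstream fix is not stated in this issue.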

Expected behavior

In the stable release, the output is:

```
Test(
  (layer_norm): WeakScriptModuleProxy()
)
```

Environment

Collecting environment information...
PyTorch version: 1.2.0.dev20190526
Is debug build: No
CUDA used to build PyTorch: 9.0.176

OS: Ubuntu 18.04.1 LTS
GCC version: (Ubuntu 6.5.0-2ubuntu1~18.04) 6.5.0 20181026
CMake version: version 3.10.2

Python version: 3.7
Is CUDA available: Yes
CUDA runtime version: Could not collect
GPU models and configuration:
GPU 0: GeForce GTX 980 Ti
GPU 1: GeForce GTX TITAN X

Nvidia driver version: 390.116
cuDNN version: /usr/lib/x86_64-linux-gnu/libcudnn.so.7.1.4

Versions of relevant libraries:
[pip3] msgpack-numpy==0.4.3.1
[pip3] numpy==1.16.2
[pip3] numpydoc==0.8.0
[pip3] pytorch-pretrained-bert==0.6.1
[pip3] torch==1.0.0
[pip3] torchfile==0.1.0
[pip3] torchtext==0.4.0
[pip3] torchvision==0.2.1
[conda] Could not collect

Labels: oncall: jit, triaged
