Description
Hi there,
Since I couldn't find a SpectralNorm class that can be added to an nn.Sequential just like nn.BatchNorm2d, I used the nn.utils.spectral_norm function and applied it to my conv layers.
When I load a state_dict into a modified version of an nn.Module, I always pass strict=False. Here, however, I received a KeyError when loading the weights of a layer that had spectral_norm applied to it.
This does not happen if I remove spectral_norm, or if I add common normalization modules such as nn.BatchNorm2d or nn.InstanceNorm2d instead.
I may have missed something in the documentation, but my guess is that the strict parameter is not honored (perhaps on purpose) by functions applied to modules.
Apparently, the code section involved here is this one:
https://github.com/pytorch/pytorch/blob/master/torch/nn/utils/spectral_norm.py#L165-L171
To Reproduce
Steps to reproduce the behavior:
import torch.nn as nn

mod = nn.Module()
mod.add_module('a', nn.utils.spectral_norm(nn.Conv2d(2, 4, 3)))
state_dict = mod.state_dict()  # the saved state only contains layer 'a'
mod.add_module('b', nn.utils.spectral_norm(nn.Conv2d(2, 4, 3)))  # new layer, absent from state_dict
mod.load_state_dict(state_dict, strict=False)  # raises KeyError: 'b.weight_orig'
This produces a KeyError. Here is the end of the traceback:
~/miniconda3/lib/python3.7/site-packages/torch/nn/utils/spectral_norm.py in __call__(self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs)
    163         if version is None or version < 1:
    164             with torch.no_grad():
--> 165                 weight_orig = state_dict[prefix + fn.name + '_orig']
    166                 weight = state_dict.pop(prefix + fn.name)
    167                 sigma = (weight_orig / weight).mean()

KeyError: 'b.weight_orig'
Expected behavior
With the same code (including strict=False), I was expecting the same behavior as any other missing parameters:
IncompatibleKeys(missing_keys=['b.bias', 'b.weight_orig', 'b.weight_u', 'b.weight_v'], unexpected_keys=[])
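To illustrate the kind of guard that would give this behavior, here is a minimal sketch that does not use PyTorch. It is not the actual hook: `guarded_legacy_conversion` is a hypothetical name, plain floats stand in for tensors, and the returned ratio is only a placeholder for whatever the real conversion computes. The idea is simply to check that the keys exist before indexing, and to return early otherwise so that the regular loading logic can report them as missing.

```python
def guarded_legacy_conversion(state_dict, prefix, name, version):
    """Hypothetical stand-in for the legacy-format branch of the hook.

    Returns the weight_orig / weight ratio when the conversion runs,
    or None when it is skipped.
    """
    # Same version test as in the traceback above.
    if version is not None and version >= 1:
        return None

    orig_key = prefix + name + "_orig"
    weight_key = prefix + name

    # Guard: if either key is absent, skip the conversion instead of
    # raising KeyError; the caller can then list the keys as missing.
    if orig_key not in state_dict or weight_key not in state_dict:
        return None

    weight_orig = state_dict[orig_key]
    weight = state_dict.pop(weight_key)
    return weight_orig / weight  # floats here; tensors in the real hook


# Layer 'a' is present in old format, layer 'b' is absent entirely.
checkpoint = {"a.weight_orig": 2.0, "a.weight": 1.0}
ratio_a = guarded_legacy_conversion(checkpoint, "a.", "weight", None)
ratio_b = guarded_legacy_conversion(checkpoint, "b.", "weight", None)
```

With such a guard, the call for prefix `b.` returns without touching the dictionary, which is what would let `strict=False` behave as it does for other modules.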
Environment
PyTorch version: 1.1.0
Is debug build: No
CUDA used to build PyTorch: 10.0.130
OS: Ubuntu 18.04.2 LTS
GCC version: (Ubuntu 7.4.0-1ubuntu1~18.04) 7.4.0
CMake version: version 3.10.2
Python version: 3.7
Is CUDA available: Yes
CUDA runtime version: 10.1.168
GPU models and configuration: GPU 0: GeForce GTX 1050
Nvidia driver version: 418.67
cuDNN version: /usr/lib/x86_64-linux-gnu/libcudnn.so.7.6.0
Versions of relevant libraries:
[pip3] numpy==1.13.3
[conda] blas 1.0 mkl
[conda] mkl 2019.3 199
[conda] mkl_fft 1.0.12 py37ha843d7b_0
[conda] mkl_random 1.0.2 py37hd81dba3_0
[conda] pytorch 1.1.0 py3.7_cuda10.0.130_cudnn7.5.1_0 pytorch
[conda] torchsummary 1.5.1 pypi_0 pypi
[conda] torchvision 0.3.0 py37_cu10.0.130_1 pytorch