
[Spectral Normalization] KeyError on load_state_dict #21251

@frgfm

Hi there,

As I couldn't find a SpectralNorm class that I could add to an nn.Sequential just like nn.BatchNorm2d, I used the nn.utils.spectral_norm function and applied it to my conv layers.
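Because `nn.utils.spectral_norm` returns the module it was applied to, it can still be used inline when building an `nn.Sequential`. A minimal sketch (layer sizes are arbitrary) that also lists the state_dict keys spectral_norm introduces, which are the keys involved in the loading problem:

```python
import torch.nn as nn

# spectral_norm returns the same (now wrapped) module, so it can be
# nested directly inside nn.Sequential like any other layer.
model = nn.Sequential(
    nn.utils.spectral_norm(nn.Conv2d(2, 4, 3)),
    nn.ReLU(),
)

# The original 'weight' parameter is replaced by 'weight_orig', and the
# power-iteration vectors are stored as the buffers 'weight_u' and 'weight_v'.
keys = sorted(model.state_dict().keys())
print(keys)
```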

When I load a state_dict into a modified version of an nn.Module, I always pass strict=False. Here, however, I received a KeyError when loading weights into a layer that had spectral_norm applied to it and whose keys are absent from the state_dict.

This does not happen if I remove spectral_norm, or if I use common normalization modules such as nn.BatchNorm2d or nn.InstanceNorm2d instead.
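For comparison, here is a sketch of the same pattern with plain nn.Conv2d layers (the submodule names `a` and `b` mirror the reproduction further down). Assuming PyTorch 1.1 or later, where load_state_dict returns the missing and unexpected keys, the keys of the new layer are simply reported as missing:

```python
import torch.nn as nn

# Same pattern as the reproduction, but without spectral_norm.
mod = nn.Module()
mod.add_module('a', nn.Conv2d(2, 4, 3))
state_dict = mod.state_dict()
# 'b' is added after the state_dict was saved, so its keys are absent.
mod.add_module('b', nn.Conv2d(2, 4, 3))

# No exception: the missing keys are collected and returned instead.
result = mod.load_state_dict(state_dict, strict=False)
print(result.missing_keys, result.unexpected_keys)
```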

I may have missed something in the documentation, but it looks like strict is not honored (perhaps on purpose) by the hook that spectral_norm registers on the module: the hook does receive strict (see its signature in the traceback below), yet it indexes the state_dict regardless of it.
The code section involved appears to be this one:
https://github.com/pytorch/pytorch/blob/master/torch/nn/utils/spectral_norm.py#L165-L171

To Reproduce

Steps to reproduce the behavior:

import torch.nn as nn
mod = nn.Module()
mod.add_module('a', nn.utils.spectral_norm(nn.Conv2d(2, 4, 3)))
state_dict = mod.state_dict()
mod.add_module('b', nn.utils.spectral_norm(nn.Conv2d(2, 4, 3)))
mod.load_state_dict(state_dict, strict=False)

which raises a KeyError. Here is the end of the traceback:

~/miniconda3/lib/python3.7/site-packages/torch/nn/utils/spectral_norm.py in __call__(self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs)
    163         if version is None or version < 1:
    164             with torch.no_grad():
--> 165                 weight_orig = state_dict[prefix + fn.name + '_orig']
    166                 weight = state_dict.pop(prefix + fn.name)
    167                 sigma = (weight_orig / weight).mean()

KeyError: 'b.weight_orig'
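The traceback suggests why strict=False does not help. Below is a pure-Python model of the hook's logic, not PyTorch code: `legacy_prehook_sketch` is a hypothetical name, and it assumes (as the linked source appears to do) that the version is read from a `spectral_norm` entry in the submodule's metadata, with a missing version treated as a legacy checkpoint. Since the saved state_dict has no metadata for the newly added `b`, its version is `None`, the legacy conversion runs, and `b.weight_orig` is looked up unconditionally.

```python
def legacy_prehook_sketch(state_dict, prefix, local_metadata, name="weight"):
    """Hypothetical stand-in for the spectral-norm load pre-hook (not PyTorch code)."""
    # Assumption: the version is read from the metadata saved for this submodule.
    version = local_metadata.get("spectral_norm", {}).get(name + ".version", None)
    if version is None or version < 1:
        # Legacy conversion: assumes old-style keys are present and never
        # consults `strict`, so a missing key raises KeyError.
        _weight_orig = state_dict[prefix + name + "_orig"]
        state_dict.pop(prefix + name)
        return "converted"
    return "skipped"


# Keys saved from submodule 'a' only (tensor values replaced by placeholders).
saved = {"a.bias": 0, "a.weight_orig": 1, "a.weight_u": 2, "a.weight_v": 3}
# Per-submodule metadata: an entry exists for 'a' but not for the new 'b'.
metadata = {"a": {"spectral_norm": {"weight.version": 1}}}

# 'a' carries a version, so the legacy conversion is skipped.
result_a = legacy_prehook_sketch(saved, "a.", metadata.get("a", {}))

# 'b' has no metadata entry, so its version is None and the conversion runs.
failure = None
try:
    legacy_prehook_sketch(saved, "b.", metadata.get("b", {}))
except KeyError as err:
    failure = err.args[0]

print(result_a, failure)
```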

Expected behavior

With the same code (including strict=False), I expected the same behavior as for any other missing parameters, i.e. the missing keys being reported:

IncompatibleKeys(missing_keys=['b.bias', 'b.weight_orig', 'b.weight_u', 'b.weight_v'], unexpected_keys=[])
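A possible workaround, sketched under the assumption that the state_dict returned by `state_dict()` keeps its per-submodule metadata (stored in the internal `_metadata` attribute) through `dict.update`: start from the target model's own state_dict, which records a spectral-norm version for `b` as well, and overlay the checkpoint on top of it. Every key is then present and the legacy conversion is skipped.

```python
import torch.nn as nn

# Reproduce the setup from the report.
mod = nn.Module()
mod.add_module('a', nn.utils.spectral_norm(nn.Conv2d(2, 4, 3)))
state_dict = mod.state_dict()
mod.add_module('b', nn.utils.spectral_norm(nn.Conv2d(2, 4, 3)))

# Start from the current model's state_dict (it carries metadata for every
# submodule, including 'b'), then overwrite it with the checkpoint's entries.
merged = mod.state_dict()
merged.update(state_dict)

# Keys the checkpoint did not provide, i.e. what strict=False should report.
missing = [k for k in merged if k not in state_dict]

# All keys are now present, so even a strict load succeeds.
mod.load_state_dict(merged)
print(missing)
```

The trade-off is that load_state_dict no longer reports the new layer's keys as missing, which is why they are collected explicitly in `missing`.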

Environment

PyTorch version: 1.1.0
Is debug build: No
CUDA used to build PyTorch: 10.0.130

OS: Ubuntu 18.04.2 LTS
GCC version: (Ubuntu 7.4.0-1ubuntu1~18.04) 7.4.0
CMake version: version 3.10.2

Python version: 3.7
Is CUDA available: Yes
CUDA runtime version: 10.1.168
GPU models and configuration: GPU 0: GeForce GTX 1050
Nvidia driver version: 418.67
cuDNN version: /usr/lib/x86_64-linux-gnu/libcudnn.so.7.6.0

Versions of relevant libraries:
[pip3] numpy==1.13.3
[conda] blas 1.0 mkl
[conda] mkl 2019.3 199
[conda] mkl_fft 1.0.12 py37ha843d7b_0
[conda] mkl_random 1.0.2 py37hd81dba3_0
[conda] pytorch 1.1.0 py3.7_cuda10.0.130_cudnn7.5.1_0 pytorch
[conda] torchsummary 1.5.1 pypi_0 pypi
[conda] torchvision 0.3.0 py37_cu10.0.130_1 pytorch

Labels: module: nn, triaged
