
Conversation

@dzhulgakov (Collaborator)

Motivation

We allow overriding JIT module serialization with `__getstate__`/`__setstate__` to cover cases where parameters are not serializable. Use cases include MKLDNN integration (https://github.com/pytorch/pytorch/blob/a388c783505987363717bd4da0b166e8d1d7ccb9/torch/utils/mkldnn.py#L18-L26):

```python
@torch.jit.script_method
def __getstate__(self):
    return (self.weight.to_dense(), self.bias.to_dense())

@torch.jit.script_method
def __setstate__(self, state):
    # type: (Tuple[Tensor, Tensor]) -> None
    self.weight = state[0].to_mkldnn()
    self.bias = state[1].to_mkldnn()
```

and also the fbgemm prepacked format integration for quantized tensors.

However, many eager-mode scripts use the `torch.save(module.state_dict())` form of serialization. There are several ways to make it work:

  • make packed_weight itself pickleable (e.g. by binding `__getstate__`/`__setstate__` at the C++ UDT level)
    • change: we’d need to allow module buffers to be of arbitrary, non-Tensor types
    • pro: no change to state_dict behavior
    • cons: might not be directly inspectable by a user calling .state_dict(), especially if packed weights represent several tensors fused together
  • make packed_weight a proper Tensor layout
    • pro: no change to state_dict or buffers behavior
    • cons: adding new tensor layouts is pretty costly today
    • cons: doesn’t work if multiple tensors are packed in one interleaved representation
  • [this approach] allow Modules to override state_dict and return regular tensors (sketched below)
    • pro: most flexible and hackable
    • pro: maintains the semantic meaning of state_dict as all the data necessary to represent the module’s state
    • cons: complicates state_dict logic
    • cons: potential code duplication between `__getstate__`/`__setstate__`

Based on discussions with @zdevito and @gchanan, we decided to pick the latter approach. Rationale: this behavior is fully opt-in and will impact only the modules that need it. For those modules the requirement listed above won't hold, but we do preserve the requirement that all elements of state_dict are tensors. (See https://fburl.com/qgybrug4 for the internal discussion.)

In the future we might also implement one of the other approaches above, but those are more involved.
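To make the chosen approach concrete, here is a minimal sketch (not code from this PR; the `MkldnnLinear` class and its packing logic are illustrative) of a module that keeps MKLDNN-packed buffers in memory but overrides `_save_to_state_dict` so that `state_dict()` only ever yields plain dense tensors:

```python
import torch
import torch.nn as nn

class MkldnnLinear(nn.Module):
    # Illustrative module: buffers live in MKLDNN (packed) layout, but
    # state_dict() exposes ordinary dense tensors. Requires a PyTorch
    # build with MKL-DNN support.
    def __init__(self, dense_linear):
        super(MkldnnLinear, self).__init__()
        self.register_buffer('weight', dense_linear.weight.detach().to_mkldnn())
        self.register_buffer('bias', dense_linear.bias.detach().to_mkldnn())

    def _save_to_state_dict(self, destination, prefix, keep_vars):
        # Unpack to dense on the way out, preserving the invariant that
        # every state_dict entry is a regular Tensor.
        destination[prefix + 'weight'] = self.weight.to_dense()
        destination[prefix + 'bias'] = self.bias.to_dense()
```

With an override like this, `torch.save(m.state_dict())` works unchanged, and a containing module's `state_dict()` picks the override up through recursion (see the discussion below).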

@pytorchbot pytorchbot added the module: nn Related to torch.nn label Jun 19, 2019
@soumith (Contributor) left a comment:


From what I can tell, this doesn't do what you want it to do.

If you nest m inside a Sequential, for example, it won't call m's custom state_dict function.

@ezyang (Contributor) commented Jun 19, 2019:

> From what I can tell, this doesn't do what you want it to do. If you nest m inside a Sequential, for example, it won't call m's custom state_dict function.

Err, I don't think I agree? We make recursive calls to the nested modules' state_dict, which will then handle the dispatch.
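For reference, a condensed sketch of that recursion (simplified from `nn.Module.state_dict`; the real method also threads `keep_vars`, version metadata, and state-dict hooks):

```python
from collections import OrderedDict

def state_dict(self, destination=None, prefix=''):
    if destination is None:
        destination = OrderedDict()
    # Virtual dispatch: a subclass's overridden _save_to_state_dict runs here,
    # even when this module sits inside a Sequential or other container.
    self._save_to_state_dict(destination, prefix, keep_vars=False)
    for name, child in self._modules.items():
        if child is not None:
            # Recursing via the child's own state_dict preserves any
            # customized behavior on nested modules.
            child.state_dict(destination, prefix + name + '.')
    return destination
```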

Contributor:

It's probably better to give a specific example of when using this method is appropriate.

test/test_nn.py Outdated
Contributor:

Though I suppose you can appease soumith by testing the nested case explicitly here ;)

test/test_nn.py Outdated
Contributor:

Do we really want to be overriding this method directly?

Collaborator (Author):

It seems to be symmetric to `_save_to_state_dict` (mostly, except hooks). Better suggestions?

@ezyang (Contributor) left a comment:

Seems reasonable? I do wonder about `_load_from_state_dict`, though.
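For the loading side, a hedged sketch of what a `_load_from_state_dict` override might look like as a method on the illustrative `MkldnnLinear` above (the signature matches the real hook; the re-packing logic is an assumption, not code from this PR):

```python
def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict,
                          missing_keys, unexpected_keys, error_msgs):
    # Mirror of _save_to_state_dict: re-pack the dense tensors saved above
    # instead of letting the default copy path run against MKLDNN buffers.
    for name in ('weight', 'bias'):
        key = prefix + name
        if key in state_dict:
            # Assigning a Tensor to a registered buffer name updates _buffers.
            setattr(self, name, state_dict[key].to_mkldnn())
        elif strict:
            missing_keys.append(key)
```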

@pytorchbot pytorchbot added the oncall: jit Add this issue/PR to JIT oncall triage queue label Jun 20, 2019
@facebook-github-bot (Contributor) left a comment:

@dzhulgakov is landing this pull request. If you are a Facebook employee, you can view this diff on Phabricator.

@facebook-github-bot (Contributor):

@dzhulgakov merged this pull request in 82dd693.

iotamudelta pushed a commit to ROCm/pytorch that referenced this pull request Jun 21, 2019
…1933)

Pull Request resolved: pytorch#21933

Differential Revision: D15937678

Pulled By: dzhulgakov

fbshipit-source-id: 3cb5d1a8304d04def7aabc0969d0a2e7be182367
