Split nn.Module._save_to_state_dict to make it overridable #21933
Conversation
soumith
left a comment
from what I can tell, this doesn't do what you want it to do.
If you nest m inside a Sequential, for example, it won't call m's custom state_dict function
Err, I don't think I agree? We make recursive calls to the nested modules.
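For context, a minimal sketch of why the nested case works (the `PackedLinear` module and its transposed layout are illustrative, not part of this PR): `Module.state_dict()` calls `_save_to_state_dict` on the module itself and then recurses into its children, so an override on a module nested inside `nn.Sequential` is still invoked, with the child's key prefix applied.

```python
import torch
import torch.nn as nn

class PackedLinear(nn.Module):
    # Illustrative module: internally keeps a transposed ("packed") weight,
    # but exposes the conventional layout through state_dict().
    def __init__(self, in_features, out_features):
        super(PackedLinear, self).__init__()
        self.weight = nn.Parameter(torch.randn(in_features, out_features))

    def _save_to_state_dict(self, destination, prefix, keep_vars):
        # Export a plain tensor in the conventional (out, in) layout.
        destination[prefix + 'weight'] = self.weight.t().detach()

model = nn.Sequential(PackedLinear(4, 2), nn.ReLU())
state = model.state_dict()
# The override on the nested module is honored; keys get the child's prefix.
print(list(state.keys()))       # ['0.weight']
print(state['0.weight'].shape)  # torch.Size([2, 4])
```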
torch/nn/modules/module.py
Outdated
It's probably better to give a specific example of when using this method is appropriate.
test/test_nn.py
Outdated
Though I suppose you can appease soumith by testing the nested case explicitly here ;)
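A minimal sketch of what such a nested-case test could look like (the `DoubleWeight` module and the assertions are illustrative, not the test added in this PR):

```python
import unittest
import torch
import torch.nn as nn

class TestStateDictOverride(unittest.TestCase):
    def test_save_to_state_dict_nested(self):
        class DoubleWeight(nn.Module):
            def __init__(self):
                super(DoubleWeight, self).__init__()
                self.weight = nn.Parameter(torch.ones(3))

            def _save_to_state_dict(self, destination, prefix, keep_vars):
                # Expose a transformed tensor instead of the raw parameter.
                destination[prefix + 'weight'] = self.weight.detach() * 2

        # Nest twice to exercise prefix handling through the recursion.
        model = nn.Sequential(nn.Sequential(DoubleWeight()))
        sd = model.state_dict()
        self.assertEqual(list(sd.keys()), ['0.0.weight'])
        self.assertTrue(torch.equal(sd['0.0.weight'], torch.ones(3) * 2))

if __name__ == '__main__':
    unittest.main()
```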
test/test_nn.py
Outdated
Do we really want to be overriding this method directly?
It seems to be symmetric to _save_to_state_dict (mostly, except for hooks). Better suggestions?
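For illustration, a hedged sketch of the symmetric `_load_from_state_dict` override (the transposed-weight module is hypothetical; `load_state_dict` copies the dict before dispatching, so modifying entries in place and delegating to the default implementation keeps the key bookkeeping intact):

```python
import torch
import torch.nn as nn

class PackedLinear(nn.Module):
    # Hypothetical module that stores its weight transposed internally.
    def __init__(self, in_features, out_features):
        super(PackedLinear, self).__init__()
        self.weight = nn.Parameter(torch.randn(in_features, out_features))

    def _save_to_state_dict(self, destination, prefix, keep_vars):
        # Serialize the conventional (out, in) layout.
        destination[prefix + 'weight'] = self.weight.t().detach()

    def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict,
                              missing_keys, unexpected_keys, error_msgs):
        key = prefix + 'weight'
        if key in state_dict:
            # Convert the serialized layout back to the internal one, then let
            # the default machinery do the copy and the key bookkeeping.
            state_dict[key] = state_dict[key].t()
        super(PackedLinear, self)._load_from_state_dict(
            state_dict, prefix, local_metadata, strict,
            missing_keys, unexpected_keys, error_msgs)

m = PackedLinear(4, 2)
sd = m.state_dict()      # {'weight': tensor of shape (2, 4)}
m2 = PackedLinear(4, 2)
m2.load_state_dict(sd)   # repacked back into the internal (4, 2) layout
```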
ezyang
left a comment
Seems reasonable? I do wonder about _load_from_state_dict though
facebook-github-bot
left a comment
@dzhulgakov is landing this pull request. If you are a Facebook employee, you can view this diff on Phabricator.
@dzhulgakov merged this pull request in 82dd693.
Summary:

# Motivation

We allow overriding JIT module serialization with `__getstate__`/`__setstate__` in order to cover cases where parameters are not serializable. Use cases include MKLDNN integration (https://github.com/pytorch/pytorch/blob/a388c783505987363717bd4da0b166e8d1d7ccb9/torch/utils/mkldnn.py#L18-L26) and also fbgemm prepacked format integration for quantized tensors.

However, many Eager scripts use the `torch.save(module.state_dict())` form of serialization. There are several ways to make it work (a sketch of the chosen approach follows this summary):

* make packed_weight itself pickleable (e.g. by binding `__getstate__`/`__setstate__` at the C++ UDT level)
  * change: we'd need to allow module buffers to be of arbitrary, non-Tensor types
  * pro: no change to state_dict behavior
  * cons: might not be directly inspectable by a user calling .state_dict(), especially if packed weights represent several tensors fused together
* make packed_weight a proper Tensor layout
  * pro: no change to state_dict or buffers behavior
  * cons: adding new tensor layouts is pretty costly today
  * cons: doesn't work if multiple tensors are packed in one interleaved representation
* *[this approach]* allow Modules to override state_dict and return regular tensors
  * pro: most flexible and hackable
  * pro: maintains the semantic meaning of state_dict as all data necessary to represent a module's state
  * cons: complicates state_dict logic
  * cons: potential code duplication with `__getstate__`/`__setstate__`

Based on discussions with zdevito and gchanan we decided to pick the latter approach. Rationale: this behavior is fully opt-in and will impact only modules that need it. For those modules the requirement listed above won't be true, but we do preserve the requirement that all elements of state_dict are tensors. (https://fburl.com/qgybrug4 for internal discussion)

In the future we might also implement one of the approaches above, but those are more involved.

Pull Request resolved: pytorch#21933
Differential Revision: D15937678
Pulled By: dzhulgakov
fbshipit-source-id: 3cb5d1a8304d04def7aabc0969d0a2e7be182367
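To make the motivation concrete, a hedged end-to-end sketch (the `MKLDNNLikeLinear` module and its flattened "packed" buffer are purely illustrative stand-ins for an opaque packed representation): the module keeps an internal packed format, yet `torch.save(module.state_dict())` sees only plain, inspectable tensors, and loading repacks them.

```python
import io
import torch
import torch.nn as nn

class MKLDNNLikeLinear(nn.Module):
    # Illustrative stand-in: the "packed" weight is just a flattened buffer
    # here, but it could be any opaque, non-pickleable representation.
    def __init__(self, in_features, out_features):
        super(MKLDNNLikeLinear, self).__init__()
        self.in_features = in_features
        self.out_features = out_features
        self.register_buffer('packed_weight',
                             torch.randn(out_features * in_features))

    def _save_to_state_dict(self, destination, prefix, keep_vars):
        # state_dict exposes only a plain 2-D tensor, not the packed buffer.
        destination[prefix + 'weight'] = self.packed_weight.reshape(
            self.out_features, self.in_features)

    def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict,
                              missing_keys, unexpected_keys, error_msgs):
        key = prefix + 'weight'
        if key in state_dict:
            # Repack the plain tensor into the internal representation.
            self.packed_weight.copy_(state_dict[key].reshape(-1))
        elif strict:
            missing_keys.append(key)

m = MKLDNNLikeLinear(4, 2)
buf = io.BytesIO()
torch.save(m.state_dict(), buf)   # the usual Eager workflow is unchanged
buf.seek(0)
m2 = MKLDNNLikeLinear(4, 2)
m2.load_state_dict(torch.load(buf))
assert torch.equal(m.packed_weight, m2.packed_weight)
```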