-
Notifications
You must be signed in to change notification settings - Fork 26.3k
Add initial support for serializing classes #22953
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
Add initial support for serializing classes gh-metadata: pytorch pytorch 22953 gh/zdevito/73/head
Add initial support for serializing classes gh-metadata: pytorch pytorch 22953 gh/zdevito/73/head
Add initial support for serializing classes gh-metadata: pytorch pytorch 22953 gh/zdevito/73/head
Add initial support for serializing classes gh-metadata: pytorch pytorch 22953 gh/zdevito/73/head
| push<OpCode>(OpCode::NEWOBJ); | ||
| push<OpCode>(OpCode::EMPTY_DICT); | ||
| push<OpCode>(OpCode::MARK); | ||
| for (size_t i = 0, n = type->numAttributes(); i < n; ++i) { |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
We should check if __getstate__ is present and use that if so instead of the state dict, but that's fine to leave for a later PR
| # "the attr {} on module {} is not the the class".format(name, name, module_name)) | ||
|
|
||
| # __main__ is a builtin module, so rewrite it to "__torch__". | ||
| if module_name == "__main__": |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Why do we do this? How can the names collide if this is all happening in the compiler? This would make pickle files with classes in them un-readable from Python without a custom unpickler to translate __torch__ to __main__.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I don't know, this is just a move of _qualname so it is accessible in attributes.py
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
From @suo, it was originally for some namespace uniqueness issues but can probably be deleted, we should do that before the 1.2 release
| raise RuntimeError("Could not get qualified name for class '{}': " | ||
| "__module__ can't be None.".format(name)) | ||
|
|
||
| # if getattr(sys.modules[module_name], name) is not obj: |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This should be deleted if it's not used
| self.checkModule(M(), (torch.randn(5, 5),)) | ||
|
|
||
| def test_attributes(self): | ||
| @torch.jit.script |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This should also test a class that uses another class as an attribute
| auto module_name = readString(); | ||
| auto class_name = readString(); | ||
| // TODO [unpickler refactor] __main__ isn't used by the pickler anymore | ||
| if (module_name == "__main__") { |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
You can delete the whole if (module_name == "__main__") { block and the comment above as well as long as it passes in fbcode
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Ill land this and then do that as a PR whose only purpose is to see if we can drop the old pickler stuff.
| } | ||
|
|
||
| template <typename T> | ||
| static IValue toSpecializedList(const IValue& generic) { |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Is the extra copy this is now doing going to be a perf bottleneck later on that we'll end up reverting back to specializing when the list is created?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
So, before the patch I landed today, we were already accidentally doing this copy anyway. I have a plan to try to remove specialization entirely that I am going to try. If it end ups being a perf bottleneck, I can restore the copy-free version, but it is a lot of confusing logic so I'd like to avoid it if it is not needed.
Add initial support for serializing classes gh-metadata: pytorch pytorch 22953 gh/zdevito/73/head
Add initial support for serializing classes gh-metadata: pytorch pytorch 22953 gh/zdevito/73/head
Add initial support for serializing classes gh-metadata: pytorch pytorch 22953 gh/zdevito/73/head
Add initial support for serializing classes gh-metadata: pytorch pytorch 22953 gh/zdevito/73/head
Stack from ghstack:
Differential Revision: D16340214