Skip to content

Conversation

@zdevito
Copy link
Contributor

@zdevito zdevito commented Jul 17, 2019

Stack from ghstack:

Differential Revision: D16340214

@pytorchbot pytorchbot added oncall: jit Add this issue/PR to JIT oncall triage queue module: pybind Related to our Python bindings / interactions with other Python libraries labels Jul 17, 2019
zdevito added 3 commits July 16, 2019 18:27
Add initial support for serializing classes

gh-metadata: pytorch pytorch 22953 gh/zdevito/73/head
Add initial support for serializing classes

gh-metadata: pytorch pytorch 22953 gh/zdevito/73/head
Add initial support for serializing classes

gh-metadata: pytorch pytorch 22953 gh/zdevito/73/head
@zdevito zdevito requested a review from driazati July 17, 2019 17:56
Add initial support for serializing classes

gh-metadata: pytorch pytorch 22953 gh/zdevito/73/head
push<OpCode>(OpCode::NEWOBJ);
push<OpCode>(OpCode::EMPTY_DICT);
push<OpCode>(OpCode::MARK);
for (size_t i = 0, n = type->numAttributes(); i < n; ++i) {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We should check if __getstate__ is present and use that if so instead of the state dict, but that's fine to leave for a later PR

# "the attr {} on module {} is not the the class".format(name, name, module_name))

# __main__ is a builtin module, so rewrite it to "__torch__".
if module_name == "__main__":
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why do we do this? How can the names collide if this is all happening in the compiler? This would make pickle files with classes in them un-readable from Python without a custom unpickler to translate __torch__ to __main__.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't know, this is just a move of _qualname so it is accessible in attributes.py

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

From @suo, it was originally for some namespace uniqueness issues but can probably be deleted, we should do that before the 1.2 release

raise RuntimeError("Could not get qualified name for class '{}': "
"__module__ can't be None.".format(name))

# if getattr(sys.modules[module_name], name) is not obj:
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This should be deleted if it's not used

self.checkModule(M(), (torch.randn(5, 5),))

def test_attributes(self):
@torch.jit.script
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This should also test a class that uses another class as an attribute

auto module_name = readString();
auto class_name = readString();
// TODO [unpickler refactor] __main__ isn't used by the pickler anymore
if (module_name == "__main__") {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

You can delete the whole if (module_name == "__main__") { block and the comment above as well as long as it passes in fbcode

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ill land this and then do that as a PR whose only purpose is to see if we can drop the old pickler stuff.

}

template <typename T>
static IValue toSpecializedList(const IValue& generic) {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is the extra copy this is now doing going to be a perf bottleneck later on that we'll end up reverting back to specializing when the list is created?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

So, before the patch I landed today, we were already accidentally doing this copy anyway. I have a plan to try to remove specialization entirely that I am going to try. If it end ups being a perf bottleneck, I can restore the copy-free version, but it is a lot of confusing logic so I'd like to avoid it if it is not needed.

zdevito added 4 commits July 17, 2019 13:40
Add initial support for serializing classes

gh-metadata: pytorch pytorch 22953 gh/zdevito/73/head
Add initial support for serializing classes

gh-metadata: pytorch pytorch 22953 gh/zdevito/73/head
Add initial support for serializing classes

gh-metadata: pytorch pytorch 22953 gh/zdevito/73/head
Add initial support for serializing classes

gh-metadata: pytorch pytorch 22953 gh/zdevito/73/head
@zou3519 zou3519 deleted the gh/zdevito/73/head branch July 19, 2019 21:55
@facebook-github-bot
Copy link
Contributor

@zdevito merged this pull request in c09e922.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

Merged module: pybind Related to our Python bindings / interactions with other Python libraries oncall: jit Add this issue/PR to JIT oncall triage queue

Projects

None yet

Development

Successfully merging this pull request may close these issues.

6 participants