[jit] Fix weak module cuda() _flat_weights bug #21107
Conversation
eellison left a comment
Do you mind explaining what the MRO issue was and how exactly this fixed the cuda() bug (and maybe add a test for that bug)?
    self.checkScript(int1, ())

    def test_number_all(self):
is this meant to be in here?
    weak_mod.weight = torch.nn.Parameter(torch.ones(5, 5) * 100)
    self.assertFalse(strong_mod(inp).allclose(weak_mod(inp)))

    @unittest.skipIf(hasattr(torch.jit, 'WeakScriptModuleProxy'), "# TODO: re-enable"
re-enable this when what lol
The full message is split over two strings.
For one, it was making an inheritance diamond on …

Could you add a test?
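For context, the inheritance-diamond problem mentioned above can be sketched in plain Python (hypothetical class names, not PyTorch's actual hierarchy): dynamically creating a wrapper type with `type()` that inherits from both a proxy base and the user's module class puts the shared base on two paths, and an incompatible base order makes C3 linearization fail outright:

```python
# Hypothetical sketch: how creating a type at runtime can produce an MRO diamond.
class Module:             # stands in for a shared base like nn.Module
    pass

class ProxyBase(Module):  # stands in for a proxy base that subclasses Module
    pass

class UserModule(Module): # stands in for the user's module class
    pass

# A wrapper type synthesized at runtime inherits Module via two branches.
Wrapper = type("Wrapper", (ProxyBase, UserModule), {})
print([c.__name__ for c in Wrapper.__mro__])
# ['Wrapper', 'ProxyBase', 'UserModule', 'Module', 'object']

# With a conflicting base order, Python cannot build a consistent MRO at all.
try:
    Bad = type("Bad", (UserModule, Module, ProxyBase), {})
except TypeError as e:
    print("MRO conflict:", e)
```

This is why a synthesized type participating in the MRO is fragile: its linearization depends on the user's class hierarchy, which the proxy code does not control.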
eellison left a comment
LGTM with a follow-up to add the test. Once this lands we should re-open the isinstance bug.
    self.lstm = torch.nn.LSTM(5, 5)
    self.lstm.cuda()

    m = M()
Can you assert that the output is a cuda tensor?
What do you mean? This isn't testing the forward pass, just that a weak module can be made into a cuda module.
Yes, could you verify that the output of the module is a cuda tensor? This is just testing that it doesn't error; it's not testing the runtime.
Dynamically creating a type at runtime was messing up the MRO and has been causing many other problems. I think it's best to delete it. This causes a regression, since … will now be False again, but that will be fixed once recursive script mode is the default (#20939).

Differential Revision: D15560549
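The isinstance regression described above can be sketched in plain Python (hypothetical names, not PyTorch's actual proxy classes): the old approach synthesized a subclass of the user's type at runtime, so isinstance checks passed; once that dynamic type is removed and the proxy merely wraps the module, the same check is False again:

```python
# Hypothetical sketch of the isinstance regression after removing the dynamic type.
class MyModule:
    pass

class Proxy:
    """Wraps a module without subclassing its type."""
    def __init__(self, wrapped):
        self.wrapped = wrapped

mod = MyModule()

# Old approach: a type created at runtime subclasses the user's class,
# so isinstance passes -- but the synthesized type pollutes the MRO.
DynamicProxy = type("DynamicProxy", (Proxy, MyModule), {})
print(isinstance(DynamicProxy(mod), MyModule))  # True

# New approach: plain wrapping keeps the MRO clean, but the isinstance
# check against the user's class is now False again.
print(isinstance(Proxy(mod), MyModule))  # False
```

This is the trade-off the comment accepts: a cleaner MRO now, with the isinstance behavior restored later by recursive script mode.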