-
Notifications
You must be signed in to change notification settings - Fork 26.3k
C++ API parity: at::Tensor::requires_grad_ #26332
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
yf225
left a comment
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Overall looks awesome! I left a minor comment.
|
This will break things in |
|
@eellison Do you mind elaborating more on the use cases that this will break? The |
All of the functions in |
yf225
left a comment
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks a lot @pbelevich !
yf225
left a comment
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks a lot @pbelevich!
Could you remove the register_prim_ops implementation ? Those are for registering ops that are not bound to the torch c++ library. There is no need to have it in C++ and in register_prim_ops, since the c++ ops are exposed to the JIT already. |
| return 0; | ||
| }, | ||
| aliasAnalysisConservative()), | ||
| Operator( |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This looks good and should fix the issue in JIT
bwasti
left a comment
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
assuming all tests pass
|
@pbelevich edited my comment |
|
@eellison , this one wouldn't be exposed correctly. note the "conservative" annotation in the prim_ops registration. |
Do you know that the register_prim_ops schema is being matched to before the native_functions one? |
|
@eellison that seems to be how it is used in the file, I don't know if bugs have crept in -- the c10 dispatch code is extremely hard to understand. I believe, if there is an issue, there will be a runtime error |
From looking at the code I would suspect this should raise, since it's the same schema with a different options. If doesn't than i would think that's a bug. cc @smessmer |
smessmer
left a comment
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
lgtm
smessmer
left a comment
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
oh wait I'm a bit confused here. The entry in native_functions.yaml should already create a jit op for this in register_aten_ops.cpp. Why is the one in register_prim_ops.cpp needed?
@eellison This doesn't crash because register_prim_ops.cpp isn't the c10 operator library, that's a shortcut directly to jit which should only be used if absolutely needed.
smessmer
left a comment
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
withdrawing my concerns with recent changes
Differential Revision: [D17427575](https://our.internmc.facebook.com/intern/diff/D17427575) [ghstack-poisoned]
Differential Revision: [D17427575](https://our.internmc.facebook.com/intern/diff/D17427575) [ghstack-poisoned]
|
@pbelevich merged this pull request in 46f96d1. |
Summary: Pull Request resolved: pytorch/pytorch#26332 Test Plan: Imported from OSS Differential Revision: D17427575 Pulled By: pbelevich fbshipit-source-id: 5500169a4fa0ef9cc2a7272e13b6e2d89df09260
Summary: Pull Request resolved: pytorch#26332 Test Plan: Imported from OSS Differential Revision: D17427575 Pulled By: pbelevich fbshipit-source-id: 5500169a4fa0ef9cc2a7272e13b6e2d89df09260
| }, | ||
| aliasAnalysisConservative()), | ||
| Operator( | ||
| "aten::requires_grad_(Tensor(a!) self, bool _requires_grad=True) -> Tensor(a!)", |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
@pbelevich hey any reason we have this schema with _requires_grad instead of requires_grad? This is creating discrepancy between the jit api and the python api as it take requires_grad..
Stack from ghstack:
Differential Revision: D17427575