[jit] support torch.as_tensor in script #23247
Conversation
eellison
left a comment
I haven't really looked into it, but I think all of the implementation is more or less a copy-paste of the torch.tensor code. Can you refactor the code?
[jit] support torch.as_tensor in script
gh-metadata: pytorch pytorch 23247 gh/wanchaol/32/head
Looks good! I think there are a couple of errors in shape analysis that need fixing.
Also, in the interest of supporting GRU in script, it's worth mentioning that you could rewrite
torch.as_tensor(lengths, dtype=torch.int64) as lengths.to(dtype=torch.int64), which is currently supported.
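As a quick sketch of the suggested rewrite: when `lengths` is already a tensor, the two spellings produce the same result in eager mode (`lengths` here is just an example int32 tensor, not taken from the PR).

```python
import torch

# Example input; in the GRU use case this would be a tensor of sequence lengths.
lengths = torch.tensor([3, 5, 2], dtype=torch.int32)

# The two spellings discussed in the review:
a = torch.as_tensor(lengths, dtype=torch.int64)  # needs torch.as_tensor in script
b = lengths.to(dtype=torch.int64)                # already supported in script

assert a.dtype == b.dtype == torch.int64
assert torch.equal(a, b)
```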
IValue dtype;
IValue device;
if (if_set_requires_grad) {
  pop(stack, data, dtype, device, requires_grad);
It's a little confusing: to check whether this function is correct, you have to look back at all of its call sites and verify that the schema matches, which kind of breaks the abstraction of the function.
What do you mean by this? The abstraction only differs in a flag: one schema contains requires_grad and the other does not.
Maybe add a comment that torch::tensor has a fourth requires_grad arg that as_tensor does not have. Or not.
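To illustrate the point being debated: the two schemas differ only in a trailing requires_grad argument, so a single helper can pop either three or four values from the stack depending on a flag. The following is a minimal Python sketch of that shared-pop pattern (hypothetical names; the actual PR code is C++ and pops IValues):

```python
def pop_args(stack, has_requires_grad):
    """Pop (data, dtype, device[, requires_grad]) off the end of the stack.

    has_requires_grad mirrors the if_set_requires_grad flag: torch.tensor's
    schema has a fourth requires_grad argument, torch.as_tensor's does not.
    """
    if has_requires_grad:
        data, dtype, device, requires_grad = stack[-4:]
        del stack[-4:]
    else:
        data, dtype, device = stack[-3:]
        del stack[-3:]
        requires_grad = False  # as_tensor never sets requires_grad
    return data, dtype, device, requires_grad
```

This is the shape of the abstraction the review discusses: correctness depends on the caller's schema matching the flag, which is why a comment at the call sites was suggested.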
eellison
left a comment
Looks good, just one copy-pasted function left to refactor, and this should be ready to go.
eellison
left a comment
LGTM. One small comment about refactoring, but it should be good to go.
Summary: Pull Request resolved: pytorch/pytorch#23247
Test Plan: Imported from OSS
Differential Revision: D16466590
Pulled By: wanchaol
fbshipit-source-id: cf52721eacd177d9040564790382db13a9fcc2fe
Stack from ghstack:
Differential Revision: D16466590