-
Notifications
You must be signed in to change notification settings - Fork 26.3k
[export] Add support for symbool to make it usable for torch.cond #138765
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
🔗 Helpful Links🧪 See artifacts and rendered test results at hud.pytorch.org/pr/138765
Note: Links to docs will display an error until the docs builds have been completed. ✅ No FailuresAs of commit dbd5906 with merge base 30a83ca ( This comment was automatically generated by Dr. CI and updates every 15 minutes. |
|
This pull request was exported from Phabricator. Differential Revision: D64867504 |
Summary: todo Test Plan: ci Differential Revision: D64867504
ec14551 to
94742f0
Compare
|
This pull request was exported from Phabricator. Differential Revision: D64867504 |
Summary: expect failure right now need pytorch#138765 pytorch#138760 Differential Revision: D64936442
torch/export/graph_signature.py
Outdated
| elif isinstance(val, SymInt): | ||
| return SymIntArgument(name=node.name) | ||
| elif isinstance(val, SymBool): | ||
| return SymIntArgument(name=node.name) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Maye need to add a SymBoolArgument and some test.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
cc @angelayi for help on the implementation
94742f0 to
4e53e47
Compare
Summary: todo Test Plan: ci Differential Revision: D64867504
|
This pull request was exported from Phabricator. Differential Revision: D64867504 |
|
@ydwu4 Added Symbool argument, added tests in export and aot inductor. |
Summary: todo Test Plan: ci Differential Revision: D64867504
4e53e47 to
76b2ecb
Compare
|
This pull request was exported from Phabricator. Differential Revision: D64867504 |
|
@ydwu4 I am getting the following rst error. Any idea? |
|
@pytorchbot rebase -b viable/strict |
|
@pytorchbot started a rebase job onto refs/remotes/origin/viable/strict. Check the current status here |
|
Successfully rebased |
76b2ecb to
b67372e
Compare
Summary: todo Test Plan: ci Differential Revision: D64867504
| ), | ||
| "test_size_from_multi_output": fail_stack_allocation(is_skip=True), | ||
| "test_torchvision_transforms_functional_tensor_resize": fail_minimal_arrayref_interface(), | ||
| # TODO: AttributeError: 'ShapeAsConstantBuffer' object has no attribute 'dtype' |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Can you paste the full error stack trace? cc @desertfire for review on this change.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
/torch/_inductor/codegen/cpp_wrapper_cpu_array_ref.py", line 134, in <genexpr>
f"ArrayRefTensor<{DTYPE_TO_CPP[x.get_dtype()]}>"
/torch/_inductor/ir.py", line 393, in get_dtype
return self.dtype
AttributeError: 'ShapeAsConstantBuffer' object has no attribute 'dtype'
|
|
||
| for input_spec, node in zip(gs.input_specs, input_node_names): | ||
| if isinstance(input_spec.arg, (TensorArgument, SymIntArgument)): | ||
| if isinstance(input_spec.arg, (TensorArgument, SymIntArgument, SymBoolArgument)): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
When we have SymBoolArgument inputs? Can add a test for it?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
@ydwu4 I am not sure, I was just following the pattern of SymIntArgument.
Any SymIntArgument test that I can take a look?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
may check test_lifted_constants::Foo
Appears to be some doc issue. Not familiar with doc issues. Probably can follow https://github.com/pytorch/pytorch/blob/main/docs/source/export.rst?plain=1#L896 . But not sure why this is not triggered for SymIntArgument. Worth digging deeper. |
b67372e to
ea17f08
Compare
Summary: todo Test Plan: ci Differential Revision: D64867504
|
This pull request was exported from Phabricator. Differential Revision: D64867504 |
1 similar comment
|
This pull request was exported from Phabricator. Differential Revision: D64867504 |
ea17f08 to
fdc9f01
Compare
Summary: Pull Request resolved: #138765 todo Test Plan: ci Differential Revision: D64867504
ydwu4
left a comment
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Add another test case for test_lifted_constants of boolean?
Summary: todo Test Plan: ci Differential Revision: D64867504
fdc9f01 to
dbd5906
Compare
|
This pull request was exported from Phabricator. Differential Revision: D64867504 |
my bad, added now |
|
@pytorchbot merge |
Merge startedYour change will be merged once all checks pass (ETA 0-4 Hours). Learn more about merging in the wiki. Questions? Feedback? Please reach out to the PyTorch DevX Team |
Why?
I want the following code to work.
minimal repro:
error: AssertionError: Encountered an unsupported object of type <class 'torch.SymBool'> while writing the metadata for exported program
second error will be handled by #138760
Motivation
I could technically bypass it with a torch.int tensor. However, it doesn't work with torch.cond. I want the following to work. It would also require #138760 for aot compile to work.
Differential Revision: D64867504
cc @voznesenskym @penguinwu @EikanWang @jgong5 @Guobing-Chen @XiaobingSuper @zhuhaozhe @blzheng @wenzhe-nrv @jiayisunx @ipiszy @yf225 @chenyang78 @kadeng @muchulee8 @ColinPeppler @amjames @desertfire @chauhang @aakhundov