[ONNX] Improve the conversion of from dynamic axes to shapes
#140488
🔗 Helpful Links: 🧪 See artifacts and rendered test results at hud.pytorch.org/pr/140488
✅ No Failures as of commit 123e85d with merge base f98c601. (This comment was automatically generated by Dr. CI and updates every 15 minutes.)
I have tested this PR on Olive with Llama, and it works. However, I would say that (1) applying re.sub to the torch.export.Dim names and (2) setting max=99999 on the constraints are somewhat over-customized. An alternative is to not do these two things for users and instead ask them to provide the correct information (we would have to change code on the Olive side). Still, I think min and max are hard to determine automatically.
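The two customizations mentioned in this comment can be pictured with a small, framework-free sketch. It is not the PR's code: the regex pattern and the helper names `sanitize_dim_name` and `default_dim_spec` are assumptions made for illustration, and only the use of re.sub on dimension names and the value 99999 come from the comment above.

```python
import re

# Upper bound quoted in the comment above; treating it as a module-level
# default is an assumption of this sketch.
DEFAULT_MAX = 99999


def sanitize_dim_name(name: str) -> str:
    """Turn a free-form dynamic_axes label into an identifier-like name.

    Hypothetical helper: the exact pattern used in the PR is not shown here.
    """
    # Replace every character outside [A-Za-z0-9_] with an underscore.
    cleaned = re.sub(r"[^A-Za-z0-9_]", "_", name)
    # Identifiers cannot be empty or start with a digit.
    if not cleaned or cleaned[0].isdigit():
        cleaned = "_" + cleaned
    return cleaned


def default_dim_spec(name: str) -> dict:
    """Describe a dimension by its sanitized name and the assumed default max."""
    return {"name": sanitize_dim_name(name), "max": DEFAULT_MAX}
```

Doing this on the user's behalf is convenient, but it also hides the renaming and the bound from them, which is the trade-off the comment raises.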
raise ValueError("model has no forward method and is not callable")

def _from_dynamic_axes_to_dynamic_shapes(
Is it possible to unit test this function a little more? Just to ensure robustness.
justinchuby
left a comment
Thanks for this PR! Nice use of pytree. LGTM overall; more unit tests should help us be more confident.
Co-authored-by: Justin Chu <[email protected]>
For the tests, I suggest reproducing the Hugging Face model signature and making sure that works.
@justinchuby I added some extra tests for pytree, and there are already small-model tests in test_api.py (they were added when the conversion was introduced). I find it hard to include hf models in the test of
We don't need to include the hf models. (And we should not.) I would just replicate their forward signature. The test model doesn't need to do anything except have the same forward signature as the hf models.
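A minimal sketch of what such a signature-only stand-in might look like. The parameter names are taken from the test inputs shown later in this PR; the class name `SignatureOnlyModel` and the helper `forward_parameter_names` are hypothetical, and plain Python is used instead of a torch.nn.Module so the sketch runs without torch installed.

```python
import inspect


class SignatureOnlyModel:
    """Stand-in whose forward() mirrors the parameter names of the test inputs.

    In an actual test this would presumably subclass torch.nn.Module.
    """

    def forward(
        self,
        input_ids,
        attention_mask=None,
        position_ids=None,
        past_key_values=None,
    ):
        # The body is irrelevant; only the parameter names and order matter.
        return input_ids


def forward_parameter_names(model) -> list:
    """List forward() parameter names in declaration order (self excluded)."""
    return list(inspect.signature(model.forward).parameters)
```

A test built this way exercises the name and structure handling of the conversion without downloading or running a real model.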
dynamic_axes=dynamic_axes,
)

# NOTE: torch.export.Dim being an object makes it impossible to compare the objects directly.
Should we implement __eq__ on Dim? Just a thought
I think it's a valid request?
pytorch/torch/export/dynamic_shapes.py
Line 56 in 33191bb
class _Dim(type):
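To illustrate why direct comparison fails when dimensions are created through a metaclass, here is a simplified stand-in. The `_Dim` below only borrows the name from the line referenced above and is not torch's implementation; `make_dim` and `same_dim` are hypothetical helpers showing an attribute-based comparison a test could fall back on.

```python
class _Dim(type):
    """Simplified stand-in for the referenced metaclass (not torch's code)."""


def make_dim(name, min_value=None, max_value=None):
    # Each call builds a brand-new type object, even for identical arguments.
    return _Dim(name, (), {"min": min_value, "max": max_value})


def same_dim(a, b) -> bool:
    """Compare two dims by name and bounds instead of by identity."""
    return (a.__name__, a.min, a.max) == (b.__name__, b.min, b.max)


first = make_dim("batch", min_value=1, max_value=99999)
second = make_dim("batch", min_value=1, max_value=99999)

print(first == second)          # False: types compare by identity by default
print(same_dim(first, second))  # True: name and bounds match
```

Defining `__eq__` on the metaclass would make `==` behave like `same_dim`, at the cost of also needing a matching `__hash__`.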
"input_ids": torch.randn(2, 16),
"attention_mask": torch.randn(2, 32),
"position_ids": torch.randn(2, 16),
"past_key_values": [
Since 16 entries and 3 entries are effectively the same, we can shorten this test case a little
Done
@pytorchbot merge
Merge started. Your change will be merged once all checks pass (ETA 0-4 Hours).
…ch#140488)

Features:
(1) Add support for tree structure.
(2) Add user warning before axes to shapes conversion
(3) Add suggestion of providing `dynamic_shapes` when conversion fails

Notes:
(1) `input_names` is crucial to the conversion, as we don't know the ONNX graph inputs.
(2) min and max are set as default, so LLM has higher chance to fail if users use `dynamic_axes` in terms of the min/max constraints dependency between `attention_mask` and `sequence_length`, etc. (Found in llama-3.2-1B_Instruct)

Pull Request resolved: pytorch#140488
Approved by: https://github.com/justinchuby
Co-authored-by: Justin Chu <[email protected]>
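One way to picture the tree-structure support and the role of `input_names` is the pure-Python sketch below. It is an illustration under stated assumptions, not the merged implementation (which, per the review, relies on pytree): `map_leaves` and `axes_to_shapes` are hypothetical helpers, the sketch assumes `input_names` lists the flattened input leaves in traversal order, and the `dim_<axis>` labels for list-style `dynamic_axes` entries are made up.

```python
def map_leaves(tree, fn):
    """Apply fn to every leaf of a nested dict/list/tuple, keeping the structure."""
    if isinstance(tree, dict):
        return {key: map_leaves(value, fn) for key, value in tree.items()}
    if isinstance(tree, (list, tuple)):
        return type(tree)(map_leaves(value, fn) for value in tree)
    return fn(tree)


def axes_to_shapes(inputs, input_names, dynamic_axes):
    """Mirror `inputs` with a per-leaf {axis: label} dict, or None if static.

    Assumption: input_names names the leaves of `inputs` in traversal order.
    """
    remaining = list(input_names)

    def spec_for_leaf(_leaf):
        if not remaining:
            raise ValueError("fewer input_names than input leaves")
        axes = dynamic_axes.get(remaining.pop(0))
        if axes is None:
            return None  # no dynamic axes declared for this input
        if isinstance(axes, dict):
            return dict(axes)
        # dynamic_axes may also be a plain list of axis indices; invent labels.
        return {axis: f"dim_{axis}" for axis in axes}

    return map_leaves(inputs, spec_for_leaf)


# Strings stand in for tensors; the ONNX-style input names are hypothetical.
example_inputs = {"input_ids": "t0", "past_key_values": [("k0", "v0")]}
example_names = ["input_ids", "past_key.0", "past_value.0"]
example_axes = {"input_ids": {0: "batch", 1: "seq"}, "past_key.0": [0]}
example_shapes = axes_to_shapes(example_inputs, example_names, example_axes)
```

Without `input_names`, there would be no way to tell which leaf a `dynamic_axes` key refers to, which is why note (1) calls it crucial.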