
Conversation

@titaiwangms
Collaborator

Features:
(1) Add support for tree structure.
(2) Add user warning before axes to shapes conversion
(3) Add suggestion of providing dynamic_shapes when conversion fails

Notes:
(1) input_names is crucial to the conversion, as we don't know the ONNX graph inputs.
(2) min and max are set as defaults, so LLMs have a higher chance of failing if users use dynamic_axes, because of the min/max constraint dependencies between attention_mask and sequence_length, etc. (found in llama-3.2-1B_Instruct)
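As a rough sketch of what the axes-to-shapes conversion does (function name matches the PR, but the body, the sanitization regex, and the stand-in `Dim` class are illustrative assumptions, not the actual torch.onnx implementation):

```python
import re


class FakeDim:
    """Stand-in for torch.export.Dim: a named symbolic dimension with bounds."""

    def __init__(self, name, min=None, max=None):
        self.name, self.min, self.max = name, min, max

    def __repr__(self):
        return f"Dim({self.name!r})"


def from_dynamic_axes_to_dynamic_shapes(dynamic_axes, input_names):
    """Map torch.onnx-style dynamic_axes onto torch.export-style dynamic_shapes.

    input_names is required because the ONNX graph input names cannot be
    recovered from the model signature alone.
    """
    dynamic_shapes = {}
    for input_name in input_names:
        axes = dynamic_axes.get(input_name)
        if axes is None:
            dynamic_shapes[input_name] = None  # fully static input
            continue
        # Dim names must be valid identifiers, so sanitize user-given names.
        dynamic_shapes[input_name] = {
            axis: FakeDim(re.sub(r"\W", "_", axis_name))
            for axis, axis_name in axes.items()
        }
    return dynamic_shapes


shapes = from_dynamic_axes_to_dynamic_shapes(
    {"input_ids": {0: "batch", 1: "sequence len"}},
    ["input_ids", "attention_mask"],
)
print(shapes["input_ids"][1].name)  # sequence_len
print(shapes["attention_mask"])     # None
```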

@pytorch-bot

pytorch-bot bot commented Nov 13, 2024

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/140488

Note: Links to docs will display an error until the docs builds have been completed.

❗ 1 Active SEV

There is 1 currently active SEV. If your PR is affected, please view it below:

✅ No Failures

As of commit 123e85d with merge base f98c601:
💚 Looks good so far! There are no failures yet. 💚

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@pytorch-bot pytorch-bot bot added the release notes: onnx torch.onnx related changes that should show up in the release notes label Nov 13, 2024
@titaiwangms titaiwangms changed the title from "[ONNX] Improve from dynamic axes to shapes" to "[ONNX] Improve the conversion of from dynamic axes to shapes" Nov 13, 2024
@titaiwangms titaiwangms added the topic: improvements topic category label Nov 13, 2024
@soulitzer soulitzer added the triaged This issue has been looked at by a team member, and triaged and prioritized into an appropriate module label Nov 13, 2024
@titaiwangms
Collaborator Author

titaiwangms commented Nov 13, 2024

I have tested this PR on Olive with llama, and it works:

olive capture-onnx-graph -m meta-llama/Llama-3.2-1B-Instruct --torch_dtype float32 -o models/dynamo_export --log_level 0 --use_dynamo_exporter

However, I would say that (1) the re.sub on torch.export.Dim naming and (2) max=99999 on constraints are somewhat over-customized. An alternative is to not do these two things for users and instead ask them to provide the correct information (we would have to change code on the Olive side). But still, I think min and max are hard to decide automatically.
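To illustrate the min/max concern with made-up numbers: if the model constrains attention_mask length to equal sequence_length plus cache length (as in the llama case mentioned above), applying one blanket default bound to every axis can contradict the constraint solver:

```python
# Illustrative numbers only: suppose the model requires
#   mask_len == seq_len + cache_len
# and every dynamic axis gets the same default bounds.
DEFAULT_MIN, DEFAULT_MAX = 2, 99999

# From mask_len <= DEFAULT_MAX and cache_len >= DEFAULT_MIN, the largest
# seq_len that can actually satisfy the constraint is:
seq_len_max_implied = DEFAULT_MAX - DEFAULT_MIN

print(seq_len_max_implied)                # 99997
print(seq_len_max_implied < DEFAULT_MAX)  # True: the blanket default max
                                          # on seq_len is already too loose
```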

raise ValueError("model has no forward method and is not callable")


def _from_dynamic_axes_to_dynamic_shapes(
Collaborator


Possible to unit test this function a little more? Just to ensure robustness

Collaborator

@justinchuby justinchuby left a comment


Thanks for this PR! Nice use of pytree. LGTM overall; more unit tests should help us be more confident.

@justinchuby
Collaborator

For testing, I suggest reproducing the huggingface model signature and making sure that works.

@titaiwangms
Collaborator Author

@justinchuby I added some extra tests for pytree, and there are already existing small-model tests in test_api.py (they were added when the conversion was introduced). I find it hard to include hf models in the test of _from_dynamic_axes_to_dynamic_shapes, because the function basically requires everything: model, args, kwargs, input_names, etc. Do you have a better way to test this? Otherwise, we might need to build up LLM testing...

@justinchuby
Collaborator

justinchuby commented Nov 14, 2024

We don't need to include the hf models. (And we should not.) I would just replicate its forward signature. The test model doesn't need to do anything except for having the same forward signature as hf models.
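A sketch of that suggestion: a throwaway class whose forward mirrors a huggingface causal-LM signature (the class name is hypothetical; the parameter names are the inputs discussed in this PR), so the conversion can be unit-tested without downloading any weights:

```python
import inspect


class TinyLlamaLike:
    """Minimal shim: same forward signature as a HF causal LM, no real logic."""

    def forward(self, input_ids, attention_mask=None, position_ids=None,
                past_key_values=None):
        # A real test model would return tensors; for signature-driven
        # conversion tests, only the parameter names and defaults matter.
        return input_ids


sig = inspect.signature(TinyLlamaLike.forward)
print(list(sig.parameters))
# ['self', 'input_ids', 'attention_mask', 'position_ids', 'past_key_values']
```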

dynamic_axes=dynamic_axes,
)

# NOTE: torch.export.Dim being an object makes it impossible to compare the objects directly.
Collaborator


Should we implement __eq__ on Dim? Just a thought

Collaborator Author


I think it's a valid request?

class _Dim(type):
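Until such an __eq__ exists, tests can compare Dims structurally. A sketch under the assumption noted above that Dim("x") returns a class built by a _Dim metaclass (the stand-in metaclass, the attribute names, and the dim_eq helper are all hypothetical):

```python
class _Dim(type):
    """Stand-in for torch.export's internal _Dim metaclass."""


def make_dim(name, min=0, max=1000):
    # Like torch.export.Dim, this returns a *class*, not an instance,
    # so == on two Dims falls back to object identity.
    return _Dim(name, (), {"min": min, "max": max})


def dim_eq(a, b):
    """Hypothetical test helper: compare Dim-like classes by name and bounds."""
    return (a.__name__, a.min, a.max) == (b.__name__, b.min, b.max)


d1, d2 = make_dim("seq"), make_dim("seq")
print(d1 == d2)        # False: two distinct class objects
print(dim_eq(d1, d2))  # True: same name and bounds
```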

"input_ids": torch.randn(2, 16),
"attention_mask": torch.randn(2, 32),
"position_ids": torch.randn(2, 16),
"past_key_values": [
Collaborator


Since 16 entries and 3 entries are effectively the same, we can shorten this test case a little

Collaborator Author


Done

@titaiwangms
Collaborator Author

@pytorchbot merge

@pytorch-bot pytorch-bot bot added the ciflow/trunk Trigger trunk jobs on your pull request label Nov 15, 2024
@pytorchmergebot
Collaborator

Merge started

Your change will be merged once all checks pass (ETA 0-4 Hours).

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team


pobin6 pushed a commit to pobin6/pytorch that referenced this pull request Dec 5, 2024
…ch#140488)

Pull Request resolved: pytorch#140488
Approved by: https://github.com/justinchuby

Co-authored-by: Justin Chu <[email protected]>

Development

Successfully merging this pull request may close these issues.

onnx.export(dynamo=True): input_names processing is broken when dynamic_axes and list inputs are used

5 participants