Skip to content

Conversation

@pianpwk
Copy link
Contributor

@pianpwk pianpwk commented Sep 25, 2024

Removing _transform_shapes_for_default_dynamic and assume_static_by_default=False as added in #133620.

This reverts back to assume_static_by_default=True with the use of dynamo decorators (e.g. maybe_mark_dynamic, mark_static, instead) for handling Dim.AUTO & Dim.STATIC instead. This is easier to maintain, as it doesn't requiring reasoning about "inverting" the dynamic_shapes specs, and also opens up usage of other decorators (mark_dynamic, mark_unbacked).

On the user side this change has no effect, but internally this means dynamic behavior is determined only by the dynamic_shapes specs (ignoring user-side input decorators following #135536), but transferring this information for _DimHints via decorators, for Dynamo/non-strict to create symbolic_contexts accordingly, e.g.

if marked_unbacked:
dynamic_size = DimDynamic.SIZE_LIKE_UNBACKED
elif (
constraint_size is not None
or marked_dynamic
or marked_weak_dynamic
or is_nested_int(e.size()[i])
):
# NB: We could assert static_shapes is False here, but it
# seems better to allow the user to override symbolic_context in this
# case
dynamic_size = DimDynamic.DYNAMIC
elif static_shapes or config.assume_static_by_default or marked_static:
dynamic_size = DimDynamic.STATIC
else:
dynamic_size = DimDynamic.DUCK
if constraint_stride is not None:
dynamic_stride = DimDynamic.DYNAMIC
else:
dynamic_stride = DimDynamic.INFER_STRIDE

One caveat is we don't raise errors for dynamic decorators on the user side, since we don't know if they're from user markings, or from re-exporting with inputs we've previously marked.

Differential Revision: D63358628

@pytorch-bot
Copy link

pytorch-bot bot commented Sep 25, 2024

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/136591

Note: Links to docs will display an error until the docs builds have been completed.

❌ 3 New Failures

As of commit db71c3d with merge base c878ea2 (image):

NEW FAILURES - The following jobs have failed:

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@facebook-github-bot
Copy link
Contributor

This pull request was exported from Phabricator. Differential Revision: D63358628

facebook-github-bot pushed a commit that referenced this pull request Sep 25, 2024
Summary: Pull Request resolved: #136591

Differential Revision: D63358628
@facebook-github-bot
Copy link
Contributor

This pull request was exported from Phabricator. Differential Revision: D63358628

@facebook-github-bot
Copy link
Contributor

This pull request was exported from Phabricator. Differential Revision: D63358628

Copy link
Contributor

@avikchaudhuri avikchaudhuri left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Looks fine as long as tests pass, but I'm not entirely sure what the effect of the change is in this PR, could you write a summary please?

What I'd like is the some kind of invariant to hold around STATIC, constant dims, DYNAMIC, AUTO, and None here. Looks like we're not transforming specs any more, but using a marking scheme to get what we want. Who consumes these markings, is it just dynamo, or is it make_fx as well?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

OK, trying to understand...so you start off with mostly STATIC here.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah, for now everything here can be static except those marked with Dim.AUTO. This doesn't interfere with the old Dim() specs.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Where does AUTO come in?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

DimHints now get handled in _process_dynamic_shapes

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What does this buy?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nothing functionality wise, we just don't have to maintain this specs inverting which is already pretty confusing to me today

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Interesting, so we'll only allow specs to dictate which ones are marked? Should we warn if any of these were set, asking to use specs instead?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ah added a comment in the last PR description line

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

else what? Maybe handle that explicitly with a pass and then assert / give a useful error if anything unexpected happens?

@pytorch-bot pytorch-bot bot added the ciflow/trunk Trigger trunk jobs on your pull request label Sep 25, 2024
@pianpwk
Copy link
Contributor Author

pianpwk commented Sep 25, 2024

Looks fine as long as tests pass, but I'm not entirely sure what the effect of the change is in this PR, could you write a summary please?

What I'd like is the some kind of invariant to hold around STATIC, constant dims, DYNAMIC, AUTO, and None here. Looks like we're not transforming specs any more, but using a marking scheme to get what we want. Who consumes these markings, is it just dynamo, or is it make_fx as well?

Ah sorry, just realized the diff description didn't update the PR description too

@facebook-github-bot
Copy link
Contributor

@pianpwk has imported this pull request. If you are a Meta employee, you can view this diff on Phabricator.

@facebook-github-bot
Copy link
Contributor

This pull request was exported from Phabricator. Differential Revision: D63358628

@facebook-github-bot
Copy link
Contributor

This pull request was exported from Phabricator. Differential Revision: D63358628

1 similar comment
@facebook-github-bot
Copy link
Contributor

This pull request was exported from Phabricator. Differential Revision: D63358628

pianpwk added a commit that referenced this pull request Sep 25, 2024
Summary:
Removing `_transform_shapes_for_default_dynamic` and `assume_static_by_default=False` as added in #133620.

This reverts back to `assume_static_by_default=True` with the use of dynamo decorators (e.g. `maybe_mark_dynamic, mark_static`, instead) for handling Dim.AUTO & Dim.STATIC instead. This is easier to maintain, as it doesn't requiring reasoning about "inverting" the dynamic_shapes specs, and also opens up usage of other decorators (`mark_dynamic, mark_unbacked`).

On the user side this change has no effect, but internally this means dynamic behavior is determined only by the `dynamic_shapes` specs (ignoring user-side input decorators following #135536), but transferring this information for _DimHints via decorators, for Dynamo/non-strict to create symbolic_contexts accordingly, e.g. https://github.com/pytorch/pytorch/blob/7c6d543a5b3d65d8f49420e60cda150faaa5b8a0/torch/_dynamo/variables/builder.py#L2646-L2666

One caveat is we don't raise errors for dynamic decorators on the user side, since we don't know if they're from user markings, or from re-exporting with inputs we've previously marked.

Pull Request resolved: #136591

Differential Revision: D63358628

Pulled By: pianpwk
Summary:
Removing `_transform_shapes_for_default_dynamic` and `assume_static_by_default=False` as added in #133620.

This reverts back to `assume_static_by_default=True` with the use of dynamo decorators (e.g. `maybe_mark_dynamic, mark_static`, instead) for handling Dim.AUTO & Dim.STATIC instead. This is easier to maintain, as it doesn't requiring reasoning about "inverting" the dynamic_shapes specs, and also opens up usage of other decorators (`mark_dynamic, mark_unbacked`).

On the user side this change has no effect, but internally this means dynamic behavior is determined only by the `dynamic_shapes` specs (ignoring user-side input decorators following #135536), but transferring this information for _DimHints via decorators, for Dynamo/non-strict to create symbolic_contexts accordingly, e.g. https://github.com/pytorch/pytorch/blob/7c6d543a5b3d65d8f49420e60cda150faaa5b8a0/torch/_dynamo/variables/builder.py#L2646-L2666

One caveat is we don't raise errors for dynamic decorators on the user side, since we don't know if they're from user markings, or from re-exporting with inputs we've previously marked.


Differential Revision: D63358628

Pulled By: pianpwk
@facebook-github-bot
Copy link
Contributor

This pull request was exported from Phabricator. Differential Revision: D63358628

@facebook-github-bot
Copy link
Contributor

@pytorchbot merge

(Initiating merge automatically since Phabricator Diff has merged)

@pytorchmergebot
Copy link
Collaborator

Merge started

Your change will be merged once all checks pass (ETA 0-4 Hours).

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team

Advanced Debugging
Check the merge workflow status
here

@pytorchmergebot
Copy link
Collaborator

Merge failed

Reason: 1 mandatory check(s) failed. The first few are:

Dig deeper by viewing the failures on hud

Details for Dev Infra team Raised by workflow job

Failing merge rule: Core Maintainers

@pianpwk
Copy link
Contributor Author

pianpwk commented Sep 27, 2024

@pytorchbot merge -f "unrelated lint"

@pytorchmergebot
Copy link
Collaborator

Merge started

Your change will be merged immediately since you used the force (-f) flag, bypassing any CI checks (ETA: 1-5 minutes). Please use -f as last resort and instead consider -i/--ignore-current to continue the merge ignoring current failures. This will allow currently pending tests to finish and report signal before the merge.

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team

Advanced Debugging
Check the merge workflow status
here

desertfire added a commit that referenced this pull request Sep 30, 2024
Summary: sam_fast changes from timeout to fail_to_run after #136591, which "regressed" in a good way. Update the expected result file and continue investigating.

[ghstack-poisoned]
desertfire added a commit that referenced this pull request Sep 30, 2024
Summary: sam_fast changes from timeout to fail_to_run after #136591, which "regressed" in a good way. Update the expected result file and continue investigating.

ghstack-source-id: b627ce5
Pull Request resolved: #136996
pytorchmergebot pushed a commit that referenced this pull request Sep 30, 2024
Summary: sam_fast changes from timeout to fail_to_run after #136591, which "regressed" in a good way. Update the expected result file and continue investigating.

Pull Request resolved: #136996
Approved by: https://github.com/ezyang
@desertfire
Copy link
Contributor

@pytorchbot merge -f "unrelated lint"

@pianpwk , as a retrospect, the sam_fast failure was really and shouldn't been force merged.

AnantGulati pushed a commit to AnantGulati/pytorch that referenced this pull request Oct 2, 2024
Summary: sam_fast changes from timeout to fail_to_run after pytorch#136591, which "regressed" in a good way. Update the expected result file and continue investigating.

Pull Request resolved: pytorch#136996
Approved by: https://github.com/ezyang
@pianpwk
Copy link
Contributor Author

pianpwk commented Oct 2, 2024

@pytorchbot merge -f "unrelated lint"

@pianpwk , as a retrospect, the sam_fast failure was really and shouldn't been force merged.

@desertfire Ah sorry. I should've realized the test no longer timing out was related

@github-actions github-actions bot deleted the export-D63358628 branch November 3, 2024 02:11
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Projects

None yet

Development

Successfully merging this pull request may close these issues.

6 participants