Skip to content

Conversation

@pianpwk
Copy link
Contributor

@pianpwk pianpwk commented Sep 9, 2024

Previously we were accomodating torch._dynamo.mark_dynamic() for export's dynamic shapes. Here we clean things up and ignore it, requiring users to specify an export input for dynamic_shapes.

Note: there's 4 decorators relevant to export, mark_dynamic, maybe_mark_dynamic, mark_static, mark_unbacked. User calls that involve export have only been mark_dynamic(), and we use maybe_mark_dynamic under the hood for Dim.AUTO, but we could start using others. One reason I decided to not warn and just silently ignore is these decorators cause the tensors to carry dynamic info, and it'll be hard to tell whether the markers are from export or user calls when re-exporting with the same inputs.

cc @voznesenskym @penguinwu @EikanWang @jgong5 @Guobing-Chen @XiaobingSuper @zhuhaozhe @blzheng @wenzhe-nrv @jiayisunx @ipiszy @yf225 @chenyang78 @kadeng @muchulee8 @ColinPeppler @amjames @desertfire @chauhang @rec

@pytorch-bot
Copy link

pytorch-bot bot commented Sep 9, 2024

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/135536

Note: Links to docs will display an error until the docs builds have been completed.

✅ You can merge normally! (1 Unrelated Failure)

As of commit 4e0171d with merge base 525bec8 (image):

BROKEN TRUNK - The following job failed but were present on the merge base:

👉 Rebase onto the `viable/strict` branch to avoid these failures

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@pianpwk pianpwk changed the title dynamic shapes ignore mark dynamic and see what fails Sep 9, 2024
@facebook-github-bot
Copy link
Contributor

@pianpwk has imported this pull request. If you are a Meta employee, you can view this diff on Phabricator.

facebook-github-bot pushed a commit that referenced this pull request Sep 10, 2024
Summary:
Fixes #ISSUE_NUMBER


Differential Revision: D62404528

Pulled By: pianpwk
@facebook-github-bot facebook-github-bot force-pushed the pianpwk/export_ignore_mark_dynamic branch from 9b747ca to 1ea23d1 Compare September 10, 2024 00:19
@facebook-github-bot
Copy link
Contributor

This pull request was exported from Phabricator. Differential Revision: D62404528

facebook-github-bot pushed a commit that referenced this pull request Sep 10, 2024
Summary:
Fixes #ISSUE_NUMBER


Differential Revision: D62404528

Pulled By: pianpwk
@facebook-github-bot facebook-github-bot force-pushed the pianpwk/export_ignore_mark_dynamic branch from 1ea23d1 to e88a7a9 Compare September 10, 2024 01:26
@facebook-github-bot
Copy link
Contributor

This pull request was exported from Phabricator. Differential Revision: D62404528

@pianpwk pianpwk changed the title ignore mark dynamic and see what fails [export] ignore mark_dynamic() in export Sep 10, 2024
@pianpwk pianpwk marked this pull request as ready for review September 10, 2024 18:08
@pianpwk pianpwk requested a review from desertfire September 10, 2024 18:08
# see https://github.com/pytorch/pytorch/issues/113029
example_outputs = copy.deepcopy(model)(*example_args, **example_kwargs)

def _produce_dynamic_shapes(path, x):
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@desertfire tagging you for the AOTInductor changes in benchmarking

@facebook-github-bot
Copy link
Contributor

@pianpwk has imported this pull request. If you are a Meta employee, you can view this diff on Phabricator.

# use this to produce dynamic_shapes spec instead.
if not isinstance(x, torch.Tensor):
return None
return {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

You sure you don't have to handle other markers?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

In these tests I'm pretty sure it's only mark_dynamic that's used..

else:
_register_dataclass_output_as_pytree(example_outputs)

combined_args = tuple(example_args) + tuple(example_kwargs.values())
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

curious, no need to look at the model signature here?

Copy link
Contributor Author

@pianpwk pianpwk Sep 11, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

That could be better if the kwargs are out of order, but I don't think that's happening here, so it's fine to just feed the shapes spec as a tuple. Will check with tests though..

for i, val in enumerate(tensor.shape):
dim = shape.get(i, _DimHint.STATIC)
if _marked_dynamic(tensor, i) or dim == _DimHint.AUTO:
if dim == _DimHint.AUTO:
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

lol get rid of the dupe if

facebook-github-bot pushed a commit that referenced this pull request Sep 11, 2024
Summary:
Previously we were accomodating `torch._dynamo.mark_dynamic()` for export's dynamic shapes. Here we clean things up and ignore it, requiring users to specify an export input for `dynamic_shapes`.

Note: there's 4 decorators relevant to export, `mark_dynamic, maybe_mark_dynamic, mark_static, mark_unbacked`. User calls that involve export have only been `mark_dynamic()`, and we use `maybe_mark_dynamic` under the hood for `Dim.AUTO`, but we could start using others. One reason I decided to not warn and just silently ignore is these decorators cause the tensors to carry dynamic info, and it'll be hard to tell whether the markers are from export or user calls when re-exporting with the same inputs.

cc voznesenskym penguinwu EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng wenzhe-nrv jiayisunx ipiszy yf225 chenyang78 kadeng muchulee8 ColinPeppler amjames desertfire chauhang rec


Reviewed By: avikchaudhuri

Differential Revision: D62404528

Pulled By: pianpwk
@facebook-github-bot facebook-github-bot force-pushed the pianpwk/export_ignore_mark_dynamic branch from 4a08a2f to e382f53 Compare September 11, 2024 02:40
@facebook-github-bot
Copy link
Contributor

This pull request was exported from Phabricator. Differential Revision: D62404528

facebook-github-bot pushed a commit to pytorch/benchmark that referenced this pull request Sep 11, 2024
Summary:
Previously we were accomodating `torch._dynamo.mark_dynamic()` for export's dynamic shapes. Here we clean things up and ignore it, requiring users to specify an export input for `dynamic_shapes`.

Note: there's 4 decorators relevant to export, `mark_dynamic, maybe_mark_dynamic, mark_static, mark_unbacked`. User calls that involve export have only been `mark_dynamic()`, and we use `maybe_mark_dynamic` under the hood for `Dim.AUTO`, but we could start using others. One reason I decided to not warn and just silently ignore is these decorators cause the tensors to carry dynamic info, and it'll be hard to tell whether the markers are from export or user calls when re-exporting with the same inputs.

cc voznesenskym penguinwu EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng wenzhe-nrv jiayisunx ipiszy yf225 chenyang78 kadeng muchulee8 ColinPeppler amjames desertfire chauhang rec

X-link: pytorch/pytorch#135536

Reviewed By: avikchaudhuri

Differential Revision: D62404528

Pulled By: pianpwk
facebook-github-bot pushed a commit that referenced this pull request Sep 11, 2024
Summary:
X-link: pytorch/benchmark#2451

Previously we were accomodating `torch._dynamo.mark_dynamic()` for export's dynamic shapes. Here we clean things up and ignore it, requiring users to specify an export input for `dynamic_shapes`.

Note: there's 4 decorators relevant to export, `mark_dynamic, maybe_mark_dynamic, mark_static, mark_unbacked`. User calls that involve export have only been `mark_dynamic()`, and we use `maybe_mark_dynamic` under the hood for `Dim.AUTO`, but we could start using others. One reason I decided to not warn and just silently ignore is these decorators cause the tensors to carry dynamic info, and it'll be hard to tell whether the markers are from export or user calls when re-exporting with the same inputs.

cc voznesenskym penguinwu EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng wenzhe-nrv jiayisunx ipiszy yf225 chenyang78 kadeng muchulee8 ColinPeppler amjames desertfire chauhang rec


Reviewed By: avikchaudhuri

Differential Revision: D62404528

Pulled By: pianpwk
@facebook-github-bot facebook-github-bot force-pushed the pianpwk/export_ignore_mark_dynamic branch from e382f53 to d046d07 Compare September 11, 2024 03:04
@facebook-github-bot
Copy link
Contributor

This pull request was exported from Phabricator. Differential Revision: D62404528

pianpwk added a commit to pytorch/benchmark that referenced this pull request Sep 11, 2024
Summary:
Pull Request resolved: #2451

Previously we were accomodating `torch._dynamo.mark_dynamic()` for export's dynamic shapes. Here we clean things up and ignore it, requiring users to specify an export input for `dynamic_shapes`.

Note: there's 4 decorators relevant to export, `mark_dynamic, maybe_mark_dynamic, mark_static, mark_unbacked`. User calls that involve export have only been `mark_dynamic()`, and we use `maybe_mark_dynamic` under the hood for `Dim.AUTO`, but we could start using others. One reason I decided to not warn and just silently ignore is these decorators cause the tensors to carry dynamic info, and it'll be hard to tell whether the markers are from export or user calls when re-exporting with the same inputs.

cc voznesenskym penguinwu EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng wenzhe-nrv jiayisunx ipiszy yf225 chenyang78 kadeng muchulee8 ColinPeppler amjames desertfire chauhang rec

X-link: pytorch/pytorch#135536

Reviewed By: avikchaudhuri

Differential Revision: D62404528

Pulled By: pianpwk
facebook-github-bot pushed a commit that referenced this pull request Sep 11, 2024
Summary:
X-link: pytorch/benchmark#2451

Previously we were accomodating `torch._dynamo.mark_dynamic()` for export's dynamic shapes. Here we clean things up and ignore it, requiring users to specify an export input for `dynamic_shapes`.

Note: there's 4 decorators relevant to export, `mark_dynamic, maybe_mark_dynamic, mark_static, mark_unbacked`. User calls that involve export have only been `mark_dynamic()`, and we use `maybe_mark_dynamic` under the hood for `Dim.AUTO`, but we could start using others. One reason I decided to not warn and just silently ignore is these decorators cause the tensors to carry dynamic info, and it'll be hard to tell whether the markers are from export or user calls when re-exporting with the same inputs.

cc voznesenskym penguinwu EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng wenzhe-nrv jiayisunx ipiszy yf225 chenyang78 kadeng muchulee8 ColinPeppler amjames desertfire chauhang rec


Reviewed By: avikchaudhuri

Differential Revision: D62404528

Pulled By: pianpwk
@facebook-github-bot facebook-github-bot force-pushed the pianpwk/export_ignore_mark_dynamic branch from d046d07 to 148ee2f Compare September 11, 2024 17:51
@facebook-github-bot
Copy link
Contributor

This pull request was exported from Phabricator. Differential Revision: D62404528

pianpwk added a commit to pytorch/benchmark that referenced this pull request Sep 11, 2024
Summary:
Pull Request resolved: #2451

Previously we were accomodating `torch._dynamo.mark_dynamic()` for export's dynamic shapes. Here we clean things up and ignore it, requiring users to specify an export input for `dynamic_shapes`.

Note: there's 4 decorators relevant to export, `mark_dynamic, maybe_mark_dynamic, mark_static, mark_unbacked`. User calls that involve export have only been `mark_dynamic()`, and we use `maybe_mark_dynamic` under the hood for `Dim.AUTO`, but we could start using others. One reason I decided to not warn and just silently ignore is these decorators cause the tensors to carry dynamic info, and it'll be hard to tell whether the markers are from export or user calls when re-exporting with the same inputs.

cc voznesenskym penguinwu EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng wenzhe-nrv jiayisunx ipiszy yf225 chenyang78 kadeng muchulee8 ColinPeppler amjames desertfire chauhang rec

X-link: pytorch/pytorch#135536

Reviewed By: avikchaudhuri

Differential Revision: D62404528

Pulled By: pianpwk
pianpwk added a commit to pytorch/benchmark that referenced this pull request Sep 11, 2024
Summary:
Pull Request resolved: #2451

Previously we were accomodating `torch._dynamo.mark_dynamic()` for export's dynamic shapes. Here we clean things up and ignore it, requiring users to specify an export input for `dynamic_shapes`.

Note: there's 4 decorators relevant to export, `mark_dynamic, maybe_mark_dynamic, mark_static, mark_unbacked`. User calls that involve export have only been `mark_dynamic()`, and we use `maybe_mark_dynamic` under the hood for `Dim.AUTO`, but we could start using others. One reason I decided to not warn and just silently ignore is these decorators cause the tensors to carry dynamic info, and it'll be hard to tell whether the markers are from export or user calls when re-exporting with the same inputs.

cc voznesenskym penguinwu EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng wenzhe-nrv jiayisunx ipiszy yf225 chenyang78 kadeng muchulee8 ColinPeppler amjames desertfire chauhang rec

X-link: pytorch/pytorch#135536

Reviewed By: desertfire, avikchaudhuri

Differential Revision: D62404528

Pulled By: pianpwk
Summary:
X-link: pytorch/benchmark#2451

Previously we were accomodating `torch._dynamo.mark_dynamic()` for export's dynamic shapes. Here we clean things up and ignore it, requiring users to specify an export input for `dynamic_shapes`.

Note: there's 4 decorators relevant to export, `mark_dynamic, maybe_mark_dynamic, mark_static, mark_unbacked`. User calls that involve export have only been `mark_dynamic()`, and we use `maybe_mark_dynamic` under the hood for `Dim.AUTO`, but we could start using others. One reason I decided to not warn and just silently ignore is these decorators cause the tensors to carry dynamic info, and it'll be hard to tell whether the markers are from export or user calls when re-exporting with the same inputs.

cc voznesenskym penguinwu EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng wenzhe-nrv jiayisunx ipiszy yf225 chenyang78 kadeng muchulee8 ColinPeppler amjames desertfire chauhang rec


Reviewed By: desertfire, avikchaudhuri

Differential Revision: D62404528

Pulled By: pianpwk
@facebook-github-bot facebook-github-bot force-pushed the pianpwk/export_ignore_mark_dynamic branch from 148ee2f to 4e0171d Compare September 12, 2024 17:54
@facebook-github-bot
Copy link
Contributor

This pull request was exported from Phabricator. Differential Revision: D62404528

facebook-github-bot pushed a commit to pytorch/benchmark that referenced this pull request Sep 12, 2024
Summary:
Pull Request resolved: #2451

Previously we were accomodating `torch._dynamo.mark_dynamic()` for export's dynamic shapes. Here we clean things up and ignore it, requiring users to specify an export input for `dynamic_shapes`.

Note: there's 4 decorators relevant to export, `mark_dynamic, maybe_mark_dynamic, mark_static, mark_unbacked`. User calls that involve export have only been `mark_dynamic()`, and we use `maybe_mark_dynamic` under the hood for `Dim.AUTO`, but we could start using others. One reason I decided to not warn and just silently ignore is these decorators cause the tensors to carry dynamic info, and it'll be hard to tell whether the markers are from export or user calls when re-exporting with the same inputs.

cc voznesenskym penguinwu EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng wenzhe-nrv jiayisunx ipiszy yf225 chenyang78 kadeng muchulee8 ColinPeppler amjames desertfire chauhang rec

X-link: pytorch/pytorch#135536

Reviewed By: desertfire, avikchaudhuri

Differential Revision: D62404528

Pulled By: pianpwk

fbshipit-source-id: 1da81cbc71c337b40cc336dfdb93292210ab66aa
@facebook-github-bot
Copy link
Contributor

@pytorchbot merge -f 'Landed internally'

(Initiating merge automatically since Phabricator Diff has merged, using force because this PR might not pass merge_rules.json but landed internally)

@pytorchmergebot
Copy link
Collaborator

Merge started

Your change will be merged immediately since you used the force (-f) flag, bypassing any CI checks (ETA: 1-5 minutes). Please use -f as last resort and instead consider -i/--ignore-current to continue the merge ignoring current failures. This will allow currently pending tests to finish and report signal before the merge.

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team

Advanced Debugging
Check the merge workflow status
here

tolleybot pushed a commit to tolleybot/pytorch that referenced this pull request Sep 14, 2024
Previously we were accomodating `torch._dynamo.mark_dynamic()` for export's dynamic shapes. Here we clean things up and ignore it, requiring users to specify an export input for `dynamic_shapes`.

Note: there's 4 decorators relevant to export, `mark_dynamic, maybe_mark_dynamic, mark_static, mark_unbacked`. User calls that involve export have only been `mark_dynamic()`, and we use `maybe_mark_dynamic` under the hood for `Dim.AUTO`, but we could start using others. One reason I decided to not warn and just silently ignore is these decorators cause the tensors to carry dynamic info, and it'll be hard to tell whether the markers are from export or user calls when re-exporting with the same inputs.

Pull Request resolved: pytorch#135536
Approved by: https://github.com/avikchaudhuri
Chao1Han pushed a commit to Chao1Han/pytorch that referenced this pull request Sep 20, 2024
Previously we were accomodating `torch._dynamo.mark_dynamic()` for export's dynamic shapes. Here we clean things up and ignore it, requiring users to specify an export input for `dynamic_shapes`.

Note: there's 4 decorators relevant to export, `mark_dynamic, maybe_mark_dynamic, mark_static, mark_unbacked`. User calls that involve export have only been `mark_dynamic()`, and we use `maybe_mark_dynamic` under the hood for `Dim.AUTO`, but we could start using others. One reason I decided to not warn and just silently ignore is these decorators cause the tensors to carry dynamic info, and it'll be hard to tell whether the markers are from export or user calls when re-exporting with the same inputs.

Pull Request resolved: pytorch#135536
Approved by: https://github.com/avikchaudhuri
facebook-github-bot pushed a commit that referenced this pull request Sep 25, 2024
Summary:
Removing `_transform_shapes_for_default_dynamic` and `assume_static_by_default=False` as added in #133620.

This reverts back to `assume_static_by_default=True` with the use of dynamo decorators (e.g. `maybe_mark_dynamic, mark_static`, instead) for handling Dim.AUTO & Dim.STATIC instead. This is easier to maintain, as it doesn't requiring reasoning about "inverting" the dynamic_shapes specs, and also opens up usage of other decorators (`mark_dynamic, mark_unbacked`).

On the user side this change has no effect, but internally this means dynamic behavior is determined only by the `dynamic_shapes` specs (ignoring user-side input decorators following #135536), but transferring this information for _DimHints via decorators, for Dynamo/non-strict to create symbolic_contexts accordingly, e.g. https://github.com/pytorch/pytorch/blob/7c6d543a5b3d65d8f49420e60cda150faaa5b8a0/torch/_dynamo/variables/builder.py#L2646-L2666

One caveat is we don't raise errors for dynamic decorators on the user side, since we don't know if they're from user markings, or from re-exporting with inputs we've previously marked.


Differential Revision: D63358628

Pulled By: pianpwk
facebook-github-bot pushed a commit that referenced this pull request Sep 25, 2024
Summary:
Removing `_transform_shapes_for_default_dynamic` and `assume_static_by_default=False` as added in #133620.

This reverts back to `assume_static_by_default=True` with the use of dynamo decorators (e.g. `maybe_mark_dynamic, mark_static`, instead) for handling Dim.AUTO & Dim.STATIC instead. This is easier to maintain, as it doesn't requiring reasoning about "inverting" the dynamic_shapes specs, and also opens up usage of other decorators (`mark_dynamic, mark_unbacked`).

On the user side this change has no effect, but internally this means dynamic behavior is determined only by the `dynamic_shapes` specs (ignoring user-side input decorators following #135536), but transferring this information for _DimHints via decorators, for Dynamo/non-strict to create symbolic_contexts accordingly, e.g. https://github.com/pytorch/pytorch/blob/7c6d543a5b3d65d8f49420e60cda150faaa5b8a0/torch/_dynamo/variables/builder.py#L2646-L2666

One caveat is we don't raise errors for dynamic decorators on the user side, since we don't know if they're from user markings, or from re-exporting with inputs we've previously marked.


Differential Revision: D63358628

Pulled By: pianpwk
pianpwk added a commit that referenced this pull request Sep 25, 2024
Summary:
Removing `_transform_shapes_for_default_dynamic` and `assume_static_by_default=False` as added in #133620.

This reverts back to `assume_static_by_default=True` with the use of dynamo decorators (e.g. `maybe_mark_dynamic, mark_static`, instead) for handling Dim.AUTO & Dim.STATIC instead. This is easier to maintain, as it doesn't requiring reasoning about "inverting" the dynamic_shapes specs, and also opens up usage of other decorators (`mark_dynamic, mark_unbacked`).

On the user side this change has no effect, but internally this means dynamic behavior is determined only by the `dynamic_shapes` specs (ignoring user-side input decorators following #135536), but transferring this information for _DimHints via decorators, for Dynamo/non-strict to create symbolic_contexts accordingly, e.g. https://github.com/pytorch/pytorch/blob/7c6d543a5b3d65d8f49420e60cda150faaa5b8a0/torch/_dynamo/variables/builder.py#L2646-L2666

One caveat is we don't raise errors for dynamic decorators on the user side, since we don't know if they're from user markings, or from re-exporting with inputs we've previously marked.

Pull Request resolved: #136591

Differential Revision: D63358628

Pulled By: pianpwk
facebook-github-bot pushed a commit that referenced this pull request Sep 26, 2024
Summary:
Removing `_transform_shapes_for_default_dynamic` and `assume_static_by_default=False` as added in #133620.

This reverts back to `assume_static_by_default=True` with the use of dynamo decorators (e.g. `maybe_mark_dynamic, mark_static`, instead) for handling Dim.AUTO & Dim.STATIC instead. This is easier to maintain, as it doesn't requiring reasoning about "inverting" the dynamic_shapes specs, and also opens up usage of other decorators (`mark_dynamic, mark_unbacked`).

On the user side this change has no effect, but internally this means dynamic behavior is determined only by the `dynamic_shapes` specs (ignoring user-side input decorators following #135536), but transferring this information for _DimHints via decorators, for Dynamo/non-strict to create symbolic_contexts accordingly, e.g. https://github.com/pytorch/pytorch/blob/7c6d543a5b3d65d8f49420e60cda150faaa5b8a0/torch/_dynamo/variables/builder.py#L2646-L2666

One caveat is we don't raise errors for dynamic decorators on the user side, since we don't know if they're from user markings, or from re-exporting with inputs we've previously marked.


Differential Revision: D63358628

Pulled By: pianpwk
pytorchmergebot pushed a commit that referenced this pull request Sep 27, 2024
Removing `_transform_shapes_for_default_dynamic` and `assume_static_by_default=False` as added in #133620.

This reverts back to `assume_static_by_default=True` with the use of dynamo decorators (e.g. `maybe_mark_dynamic, mark_static`, instead) for handling Dim.AUTO & Dim.STATIC instead. This is easier to maintain, as it doesn't requiring reasoning about "inverting" the dynamic_shapes specs, and also opens up usage of other decorators (`mark_dynamic, mark_unbacked`).

On the user side this change has no effect, but internally this means dynamic behavior is determined only by the `dynamic_shapes` specs (ignoring user-side input decorators following #135536), but transferring this information for _DimHints via decorators, for Dynamo/non-strict to create symbolic_contexts accordingly, e.g. https://github.com/pytorch/pytorch/blob/7c6d543a5b3d65d8f49420e60cda150faaa5b8a0/torch/_dynamo/variables/builder.py#L2646-L2666

One caveat is we don't raise errors for dynamic decorators on the user side, since we don't know if they're from user markings, or from re-exporting with inputs we've previously marked.

Differential Revision: D63358628

Pull Request resolved: #136591
Approved by: https://github.com/avikchaudhuri
@github-actions github-actions bot deleted the pianpwk/export_ignore_mark_dynamic branch October 13, 2024 02:10
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Projects

None yet

Development

Successfully merging this pull request may close these issues.

5 participants