Skip to content

Conversation

@anijain2305
Copy link
Contributor

@anijain2305 anijain2305 commented Feb 10, 2025

Stack from ghstack (oldest at bottom):

This PR adds support for list subclasses. Among other things are

  1. Tracking the mutations on internal vts like _dict_vt and _list_vt using sources. This helps identify if there was a mutation in the underlying data structures, and we need to reconstruct it.
  2. UserDefinedObjectVariable now has a new method - is_modified which side_effect infra relies upon to check mutations in the underlying vts (like _dict_vt).
  3. reconstruction logic ensures that we use dict.__getitem__ and list.__getitem__ methods. This is super important because we don't want to call the overridden __getitem__ methods.

If this PR is hard to review, please let me know. I can break it into several small PRs.

cc @voznesenskym @penguinwu @EikanWang @jgong5 @Guobing-Chen @XiaobingSuper @zhuhaozhe @blzheng @wenzhe-nrv @jiayisunx @chenyang78 @kadeng @chauhang @amjames

@pytorch-bot
Copy link

pytorch-bot bot commented Feb 10, 2025

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/146819

Note: Links to docs will display an error until the docs builds have been completed.

❌ 1 New Failure, 5 Pending, 5 Unrelated Failures

As of commit 3e734ab with merge base 6105b6f (image):

NEW FAILURE - The following job has failed:

FLAKY - The following job failed but was likely due to flakiness present on trunk:

BROKEN TRUNK - The following jobs failed but was present on the merge base:

👉 Rebase onto the `viable/strict` branch to avoid these failures

This comment was automatically generated by Dr. CI and updates every 15 minutes.

cc voznesenskym penguinwu EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng wenzhe-nrv jiayisunx chenyang78 kadeng chauhang amjames

[ghstack-poisoned]
anijain2305 added a commit that referenced this pull request Feb 10, 2025
ghstack-source-id: a0de941
Pull Request resolved: #146819

tuple_new = tuple.__new__
tuple_methods = {method for method in tuple.__dict__.values() if callable(method)}
list_methods = {method for method in list.__dict__.values() if callable(method)}
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nit, probably better for another PR, but these global sets are good candidates to become frozen sets

cc voznesenskym penguinwu EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng wenzhe-nrv jiayisunx chenyang78 kadeng chauhang amjames

[ghstack-poisoned]
cc voznesenskym penguinwu EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng wenzhe-nrv jiayisunx chenyang78 kadeng chauhang amjames

[ghstack-poisoned]
@anijain2305 anijain2305 added ciflow/trunk Trigger trunk jobs on your pull request topic: not user facing topic category labels Feb 11, 2025
cc voznesenskym penguinwu EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng wenzhe-nrv jiayisunx chenyang78 kadeng chauhang amjames

[ghstack-poisoned]
anijain2305 added a commit that referenced this pull request Feb 11, 2025
ghstack-source-id: 5d63e95
Pull Request resolved: #146819
cc voznesenskym penguinwu EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng wenzhe-nrv jiayisunx chenyang78 kadeng chauhang amjames

[ghstack-poisoned]
@anijain2305 anijain2305 changed the title [dynamo][lists] Support list subclasses [dynamo] Support list subclasses and fix dict subclasses mutation bugs Feb 11, 2025
…utation bugs"


This PR adds support for list subclasses. Among other things are

1) Tracking the mutations on internal vts like `_dict_vt` and `_list_vt` using sources. This helps identify if there was a mutation in the underlying data structures, and we need to reconstruct it.
2) `UserDefinedObjectVariable` now has a new method - `is_modified` which `side_effect` infra relies upon to check mutations in the underlying vts (like `_dict_vt`).
3) `reconstruction` logic ensures that we use `dict.__getitem__` and `list.__getitem__` methods. This is super important because we don't want to call the overridden `__getitem__` methods.

If this PR is hard to review, please let me know. I can break it into several small PRs.

cc voznesenskym penguinwu EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng wenzhe-nrv jiayisunx chenyang78 kadeng chauhang amjames

[ghstack-poisoned]
@anijain2305
Copy link
Contributor Author

CI failures are not relevant to this PR

Copy link
Contributor

@StrongerXi StrongerXi left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Cool! Dropped some questions.

x = x * sd.attr
sd.attr = 10
x = x * sd.attr
return x
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Did the removed portion cause a failure?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, it uncovered cases where there is mutation on the dict but no attribute mutation. Earlier, this was missed. We have other test that test for attr mutation, so we are covered in general.

d["baz"] = 4
return x * d["foo"] * d["bar"]

fn(torch.randn(4), d)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why do we call this twice?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Unintentional, I was debugging an issue. Let me remove it.

Comment on lines 2411 to 2420
return next(itertools.islice(iter(d), n, n + 1))
# Call dict(d) to prevent calling overridden __iter__
dict_class = dict
if isinstance(d, OrderedDict):
dict_class = OrderedDict
return next(itertools.islice(iter(dict_class(d)), n, n + 1))
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

qqs:

  1. do we have a test that forced this change?
  2. can we use dict/OrderedDict.keys(d) (to avoid reconstructing the entire object via dict_class(d)), given that's how we accessed the keys in builder?

    pytorch/torch/_dynamo/utils.py

    Lines 2367 to 2375 in 6f15a60

    def get_items_from_dict(obj):
    # Get items without calling the user defined __getitem__ or keys method.
    assert isinstance(obj, dict)
    if istype(obj, (dict, OrderedDict)):
    return obj.items()
    elif isinstance(obj, OrderedDict):
    return [(k, OrderedDict.__getitem__(obj, k)) for k in OrderedDict.keys(obj)]
    else:
    return [(k, dict.__getitem__(obj, k)) for k in dict.keys(obj)]

Copy link
Contributor Author

@anijain2305 anijain2305 Feb 11, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Very good idea. Let me try it out.

Yes, there was a failure in one of the tests in test_dicts.py.

Comment on lines 1482 to 1484
return side_effects.is_attribute_mutation(self) or side_effects.is_modified(
self._dict_vt
)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think side_effects.is_attribute_mutation(self) would always return true for UserDefinedListVariable etc.?
Doesn't that cause redundant codegen calls if the underlying user-defined list was never modified?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for catching this. yes, this is wrong. I wanted to use something like

self in side_effects.store_attr_mutations

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I see; in that case the better abstraction to use might be just side_effects.is_modified(self).

dict_class.__setitem__(dict_to, k, v)


def _manual_list_setitem(list_from, list_to):
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This isn't setitem, it is more of list_replace.

]
)

list_update_insts = bytecode_from_template(
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Will this work if the user does:

list = None

in global scope?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

No. This is an existing problem at other places as well, like dict = None would also cause similar issues. I will merge the PR and send another after auditing other places.

…utation bugs"


This PR adds support for list subclasses. Among other things are

1) Tracking the mutations on internal vts like `_dict_vt` and `_list_vt` using sources. This helps identify if there was a mutation in the underlying data structures, and we need to reconstruct it.
2) `UserDefinedObjectVariable` now has a new method - `is_modified` which `side_effect` infra relies upon to check mutations in the underlying vts (like `_dict_vt`).
3) `reconstruction` logic ensures that we use `dict.__getitem__` and `list.__getitem__` methods. This is super important because we don't want to call the overridden `__getitem__` methods.

If this PR is hard to review, please let me know. I can break it into several small PRs.

cc voznesenskym penguinwu EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng wenzhe-nrv jiayisunx chenyang78 kadeng chauhang amjames

[ghstack-poisoned]
anijain2305 added a commit that referenced this pull request Feb 12, 2025
ghstack-source-id: 1228ee6
Pull Request resolved: #146819
@anijain2305
Copy link
Contributor Author

@pytorcchnot merge

@anijain2305
Copy link
Contributor Author

@pytorchbot merge

@pytorchmergebot
Copy link
Collaborator

Merge started

Your change will be merged once all checks pass (ETA 0-4 Hours).

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team

Advanced Debugging
Check the merge workflow status
here

@pytorchmergebot
Copy link
Collaborator

Merge failed

Reason: 1 jobs have failed, first few of them are: inductor / unit-test / cuda12.4-py3.12-gcc9-sm86 / test (inductor, 1, 2, linux.g5.4xlarge.nvidia.gpu)

Details for Dev Infra team Raised by workflow job

@anijain2305
Copy link
Contributor Author

@pytorchbot merge -f "flaky CI"

@pytorchmergebot
Copy link
Collaborator

Merge started

Your change will be merged immediately since you used the force (-f) flag, bypassing any CI checks (ETA: 1-5 minutes). Please use -f as last resort and instead consider -i/--ignore-current to continue the merge ignoring current failures. This will allow currently pending tests to finish and report signal before the merge.

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team

Advanced Debugging
Check the merge workflow status
here

facebook-github-bot pushed a commit to pytorch/benchmark that referenced this pull request Feb 12, 2025
Summary:
This PR adds support for list subclasses. Among other things are

1) Tracking the mutations on internal vts like `_dict_vt` and `_list_vt` using sources. This helps identify if there was a mutation in the underlying data structures, and we need to reconstruct it.
2) `UserDefinedObjectVariable` now has a new method - `is_modified` which `side_effect` infra relies upon to check mutations in the underlying vts (like `_dict_vt`).
3) `reconstruction` logic ensures that we use `dict.__getitem__` and `list.__getitem__` methods. This is super important because we don't want to call the overridden `__getitem__` methods.

If this PR is hard to review, please let me know. I can break it into several small PRs.

X-link: pytorch/pytorch#146819
Approved by: https://github.com/StrongerXi, https://github.com/jansel

Reviewed By: huydhn

Differential Revision: D69537369

fbshipit-source-id: 9c20f4ee84c91639c320a3a04a1a153859623ab6
desai0007 pushed a commit to desai0007/test-repo-pytorch that referenced this pull request Feb 26, 2025
@github-actions github-actions bot deleted the gh/anijain2305/677/head branch March 23, 2025 02:17
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

Projects

None yet

Development

Successfully merging this pull request may close these issues.

6 participants