
Conversation

@shaoyuyoung
Contributor

@shaoyuyoung shaoyuyoung commented Jan 7, 2025

Fixes #144310

We just need to add a check in lowering.

Updated: we added the error checking in the meta registration instead.

UT:

 pytest -s -v test/inductor/test_torchinductor.py -k test_avg_pool_errors_with_uint

cc @voznesenskym @penguinwu @EikanWang @jgong5 @Guobing-Chen @XiaobingSuper @zhuhaozhe @blzheng @wenzhe-nrv @jiayisunx @ipiszy @yf225 @chenyang78 @kadeng @muchulee8 @ColinPeppler @amjames @desertfire @chauhang @aakhundov

@pytorch-bot

pytorch-bot bot commented Jan 7, 2025

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/144313

Note: Links to docs will display an error until the docs builds have been completed.

✅ You can merge normally! (1 Unrelated Failure)

As of commit 467c37f with merge base 69b883d (image):

UNSTABLE - The following job failed but was likely due to flakiness present on trunk and has been marked as unstable:

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@shaoyuyoung
Contributor Author

@pytorchbot label "topic: not user facing"

@pytorch-bot pytorch-bot bot added the topic: not user facing topic category label Jan 7, 2025
@cpuhrsch cpuhrsch requested review from jgong5 and yanboliang January 7, 2025 06:39
@cpuhrsch cpuhrsch added the triaged This issue has been looked at by a team member, and triaged and prioritized into an appropriate module label Jan 7, 2025
model(x)

def test_avg_pool_errors_with_uint(self):
torch._dynamo.config.recompile_limit = 12
Collaborator

why do we do this?

Contributor Author

In my local debugging, I found that the default is recompile_limit=8, but in this case we need to compile the model 12 times (3 different dims * 4 different uint dtypes).
I'm not sure how the UTs in test_torchinductor.py set the recompile_limit, but if I understand correctly, different UTs should be independent of each other and not affect each other, so I added this config here.
Feel free to correct me if I'm wrong.

Contributor

use @torch._dynamo.config.patch(recompile_limit=12)
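The reason to prefer `torch._dynamo.config.patch` over direct assignment is that the decorator scopes the override to the one test and restores the default afterwards, so other tests are unaffected. A minimal self-contained sketch of the same patch pattern (plain Python, no torch dependency; `Config`, `config`, and `patch` here are illustrative stand-ins, not the real Dynamo API):

```python
import contextlib

# Illustrative stand-in for a module-level config object like torch._dynamo.config.
class Config:
    recompile_limit = 8  # default value, restored after each patched scope

config = Config()

@contextlib.contextmanager
def patch(**overrides):
    """Temporarily override config attributes, restoring them on exit."""
    saved = {name: getattr(config, name) for name in overrides}
    try:
        for name, value in overrides.items():
            setattr(config, name, value)
        yield
    finally:
        for name, value in saved.items():
            setattr(config, name, value)

with patch(recompile_limit=12):
    assert config.recompile_limit == 12  # raised only inside the scope
assert config.recompile_limit == 8       # restored: later tests see the default
```

Bare assignment (`torch._dynamo.config.recompile_limit = 12`) would leave the changed value behind, which is exactly the cross-test leakage the reviewer's suggestion avoids.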

assert len(stride) == dim
assert len(padding) == dim
assert len(x.get_size()) in (dim + 1, dim + 2)
if x.get_dtype() in (torch.uint8, torch.uint16, torch.uint32, torch.uint64):
Contributor

@shaoyuyoung, the general contract is that the operator meta function should handle error checking. Can we move this to

@register_meta(aten.avg_pool2d.default)
def meta_avg_pool2d(

? Not sure where the avg_pool1d registration is.
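For readers unfamiliar with the pattern, the check being discussed is ordinary dtype validation performed before shape inference. A pure-Python sketch of that idea (string dtype names stand in for torch dtypes so this runs without a torch installation; the helper name and exact error wording are assumptions, not the actual PyTorch code):

```python
# Dtypes the linked issue reports as rejected by eager but accepted by inductor.
UNSUPPORTED_UINT_DTYPES = {"uint8", "uint16", "uint32", "uint64"}

def check_avg_pool_dtype(op_name: str, dtype: str) -> None:
    # The kind of error check the review suggests moving into the meta
    # function (e.g. the one registered for aten.avg_pool2d.default), so
    # the compiled path rejects uint inputs just like eager does.
    if dtype in UNSUPPORTED_UINT_DTYPES:
        raise RuntimeError(f'"{op_name}" not implemented for \'{dtype}\'')

check_avg_pool_dtype("avg_pool2d", "float32")  # supported dtype: passes silently
try:
    check_avg_pool_dtype("avg_pool2d", "uint8")
except RuntimeError as e:
    print(e)  # mirrors the style of eager's "not implemented" error
```

Placing the check in the meta function rather than in the inductor lowering means every tracing frontend that consults the meta registration gets the same error, instead of only the inductor backend.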

Contributor Author

Sounds great; I'm currently looking for the avg_pool1d registration.
@jansel, any comments? This PR follows the previous PR #143762.

Contributor

Putting this check in the meta function is a better fix.

Contributor Author

Thanks for reviewing; I will do this work in my spare time.

Contributor Author

> not sure where the avg_pool1d registration is

It seems that avg_pool1d shares the same registration as 2d.
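If 1d really does route through the 2d registration, a single check covers both ops. A toy illustration of that sharing (function names and the delegation are illustrative, not the actual PyTorch decomposition):

```python
def avg_pool2d_meta(dtype: str) -> str:
    # Shared registration path: one dtype check here guards every op
    # that routes through it.
    if dtype.startswith("uint"):
        raise RuntimeError(f'"avg_pool2d" not implemented for \'{dtype}\'')
    return dtype

def avg_pool1d_meta(dtype: str) -> str:
    # 1d reuses the 2d path, so it inherits the same error checking
    # without needing a second copy of the check.
    return avg_pool2d_meta(dtype)
```

This is why the PR's test can exercise 1d, 2d, and 3d against all four uint dtypes and expect the same error from each.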

model(x)

def test_avg_pool_errors_with_uint(self):
torch._dynamo.config.recompile_limit = 12
Contributor

use @torch._dynamo.config.patch(recompile_limit=12)

assert len(stride) == dim
assert len(padding) == dim
assert len(x.get_size()) in (dim + 1, dim + 2)
if x.get_dtype() in (torch.uint8, torch.uint16, torch.uint32, torch.uint64):
Contributor

Putting this check in the meta function is a better fix.

@shaoyuyoung
Contributor Author

I have updated; mind helping me review again? :)

@jansel
Contributor

jansel commented Jan 15, 2025

@pytorchbot merge

@pytorch-bot pytorch-bot bot added the ciflow/trunk Trigger trunk jobs on your pull request label Jan 15, 2025
@pytorchmergebot
Collaborator

Merge started

Your change will be merged once all checks pass (ETA 0-4 Hours).

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team

Advanced Debugging
Check the merge workflow status here.

@pytorchmergebot
Collaborator

Merge failed

Reason: 1 mandatory check(s) failed. The first few are:

Dig deeper by viewing the failures on hud

Details for Dev Infra team Raised by workflow job

Failing merge rule: Core Maintainers

@shaoyuyoung
Contributor Author

shaoyuyoung commented Jan 16, 2025

CI seems broken just now (?); the merge failed.

@eellison
Contributor

@pytorchbot merge

@pytorchmergebot
Collaborator

Merge started

Your change will be merged once all checks pass (ETA 0-4 Hours).

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team

Advanced Debugging
Check the merge workflow status here.

@shaoyuyoung shaoyuyoung deleted the fx_avg_pool_uint8 branch January 17, 2025 02:44

Labels

ciflow/trunk Trigger trunk jobs on your pull request Merged module: inductor open source topic: not user facing topic category triaged This issue has been looked at by a team member, and triaged and prioritized into an appropriate module

Projects

None yet

Development

Successfully merging this pull request may close these issues.

[inductor] [dtype propagation] avg_pool1d,2d,3d pass the check when handling uint8,16,32,64 while eager throws the error

7 participants