Skip to content

Conversation

@haifeng-jin
Copy link
Collaborator

@haifeng-jin haifeng-jin commented Oct 23, 2024

Summary

Support torch.compile() to trace through numpy.random ops when used with NumPy 2.

Details:

Fixes #136559
As we upgrade to NumPy 2, torch falsely filtered out numpy.random as unsupported in dynamo tracking.
This PR changes the filtering rules to include them while keeping behavior with numpy 1 unchanged.

Before this PR, the following tests failed:

PYTORCH_TEST_WITH_ASAN=1 PYTORCH_TEST_WITH_UBSAN=1 python test/dynamo/test_functions.py -k FunctionTests.test_numpy_random
PYTORCH_TEST_WITH_ASAN=1 PYTORCH_TEST_WITH_UBSAN=1 python test/dynamo/test_unspec.py -k UnspecTests.test_to_tensor
PYTORCH_TEST_WITH_ASAN=1 PYTORCH_TEST_WITH_UBSAN=1 python test/test_fake_tensor.py -k FakeTensorTest.test_export_numpy
PYTORCH_TEST_WITH_ASAN=1 PYTORCH_TEST_WITH_UBSAN=1 python test/test_fake_tensor.py -k PropagateRealTensorsFakeTensorTest.test_export_numpy_propagate_real_tensors

With this PR, the supported/unsupported ops in NumPy 1 are not changed.
For NumPy 2, only the numpy.random ops that are already supported with NumPy 1 are added to the supported list.

I used the following scripts to check the differences before and after the change for both NumPy 1 & 2.
The output is empty for NumPy 1 since there is no change.
The output is a list of numpy.random that considered supported for NumPy 2.

from torch._dynamo import trace_rules
import numpy as np


def new_numpy_function_ids():
    unsupported_funcs = {"seed", "ranf", "get_bit_generator", "RandomState", "set_bit_generator", "sample"}

    def is_supported(k, v, mod):
        if not callable(v):
            return False
        if not getattr(v, "__module__", None):
            return True
        if v.__module__ == mod.__name__:
            return True
        if v.__module__ == "numpy.random.mtrand" and mod.__name__== "numpy.random" and k not in unsupported_funcs:
            return True
        return False
    rv = {}
    for mod in trace_rules.NP_SUPPORTED_MODULES:
        for k, v in mod.__dict__.items():
            if is_supported(k, v, mod):
                rv[id(v)] = f"{mod.__name__}.{k}"
    return rv

def old_numpy_function_ids():
    rv = {}
    for mod in trace_rules.NP_SUPPORTED_MODULES:
        rv.update(
            {
                id(v): f"{mod.__name__}.{k}"
                for k, v in mod.__dict__.items()
                if callable(v)
                and (getattr(v, "__module__", None) or mod.__name__) == mod.__name__
            }
        )
    return rv

rv1 = set(old_numpy_function_ids().values())
rv2 = set(new_numpy_function_ids().values())

for v in (rv1 - rv2):
    print(v)
print("****")
for v in (rv2 - rv1):
    print(v)

cc @voznesenskym @penguinwu @EikanWang @jgong5 @Guobing-Chen @XiaobingSuper @zhuhaozhe @blzheng @wenzhe-nrv @jiayisunx @chenyang78 @kadeng @chauhang @amjames @rec @kiukchung @lezcano

@pytorch-bot
Copy link

pytorch-bot bot commented Oct 23, 2024

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/138686

Note: Links to docs will display an error until the docs builds have been completed.

❗ 1 Active SEVs

There are 1 currently active SEVs. If your PR is affected, please view them below:

✅ No Failures

As of commit abe1b9f with merge base 73fde0d (image):
💚 Looks good so far! There are no failures yet. 💚

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@colesbury colesbury requested a review from zou3519 October 23, 2024 20:09
@colesbury colesbury added the triaged This issue has been looked at a team member, and triaged and prioritized into an appropriate module label Oct 23, 2024
@haifeng-jin
Copy link
Collaborator Author

@pytorchbot label "topic: not user facing"

@pytorch-bot pytorch-bot bot added the topic: not user facing topic category label Oct 28, 2024
@zou3519
Copy link
Contributor

zou3519 commented Oct 29, 2024

I'm not the right person to review this, @anijain2305 @williamwen42 do one of you want to take this?

@zou3519 zou3519 removed their request for review October 29, 2024 20:54
@haifeng-jin
Copy link
Collaborator Author

Can we also request a review from @lezcano?
Thanks!

@williamwen42 williamwen42 requested a review from lezcano October 29, 2024 22:53
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

You might not need this set - if we actually encounter an unsupported numpy function, I believe we would just graph break when attempting to trace the call in dynamo. You can try removing this set and seeing if the tests pass.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I have update the PR.
I did some quick local tests. They passed.
Thanks!

Copy link
Collaborator Author

@haifeng-jin haifeng-jin Nov 7, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

After removing the list and allow them to graph break, it somehow created more graph breaks for the Hugging face BigBird model and failed the inductor tests.

So I added back the list and excluded those ops from being tracked. The tests then passed locally.

lezcano
lezcano previously approved these changes Oct 30, 2024
williamwen42
williamwen42 previously approved these changes Oct 30, 2024
@williamwen42
Copy link
Member

@pytorchbot rebase

@pytorchmergebot
Copy link
Collaborator

@pytorchbot started a rebase job onto refs/remotes/origin/viable/strict. Check the current status here

@pytorchmergebot
Copy link
Collaborator

Successfully rebased dynamo onto refs/remotes/origin/viable/strict, please pull locally before adding more changes (for example, via git checkout dynamo && git pull --rebase)

@williamwen42
Copy link
Member

@pytorchbot merge

@pytorch-bot pytorch-bot bot added the ciflow/trunk Trigger trunk jobs on your pull request label Nov 1, 2024
@pytorchmergebot
Copy link
Collaborator

Merge started

Your change will be merged once all checks pass (ETA 0-4 Hours).

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team

Advanced Debugging
Check the merge workflow status
here

@pytorch-bot pytorch-bot bot dismissed stale reviews from lezcano and williamwen42 November 1, 2024 23:34

This PR was reopened (likely due to being reverted), so your approval was removed. Please request another review.

@huydhn
Copy link
Contributor

huydhn commented Nov 1, 2024

@haifeng-jin
Copy link
Collaborator Author

haifeng-jin commented Nov 3, 2024

@huydhn Thank you!
Would you mind share how should I reproduce the errors that caused by the PR?
like the command for running the failed tests.

I will debug my changes and make sure they pass the tests before it is merged.

@huydhn
Copy link
Contributor

huydhn commented Nov 3, 2024

Here is an example failed job GH job link HUD commit link

I have added ciflow/inductor in your PR, so you can see they are failing on the PR too #138686 (comment)

rahulsingh-intel pushed a commit to rahulsingh-intel/pytorch that referenced this pull request Nov 5, 2024
Fixes pytorch#136559
As we upgrade to NumPy 2, torch falsely filtered out `numpy.random` as unsupported in dynamo tracking.
This PR changes the filtering rules to include them while keeping behavior with numpy 1 unchanged.

Before this PR, the following tests failed:

```
PYTORCH_TEST_WITH_ASAN=1 PYTORCH_TEST_WITH_UBSAN=1 python test/dynamo/test_functions.py -k FunctionTests.test_numpy_random
PYTORCH_TEST_WITH_ASAN=1 PYTORCH_TEST_WITH_UBSAN=1 python test/dynamo/test_unspec.py -k UnspecTests.test_to_tensor
PYTORCH_TEST_WITH_ASAN=1 PYTORCH_TEST_WITH_UBSAN=1 python test/test_fake_tensor.py -k FakeTensorTest.test_export_numpy
PYTORCH_TEST_WITH_ASAN=1 PYTORCH_TEST_WITH_UBSAN=1 python test/test_fake_tensor.py -k PropagateRealTensorsFakeTensorTest.test_export_numpy_propagate_real_tensors
```

With this PR, the supported/unsupported ops in NumPy 1 are not changed.
For NumPy 2, only the `numpy.random` ops that are already supported with NumPy 1 are added to the supported list.

I used the following scripts to check the differences before and after the change for both NumPy 1 & 2.
The output is empty for NumPy 1 since there is no change.
The output is a list of `numpy.random` that considered supported for NumPy 2.

```py
from torch._dynamo import trace_rules
import numpy as np

def new_numpy_function_ids():
    unsupported_funcs = {"seed", "ranf", "get_bit_generator", "RandomState", "set_bit_generator", "sample"}

    def is_supported(k, v, mod):
        if not callable(v):
            return False
        if not getattr(v, "__module__", None):
            return True
        if v.__module__ == mod.__name__:
            return True
        if v.__module__ == "numpy.random.mtrand" and mod.__name__== "numpy.random" and k not in unsupported_funcs:
            return True
        return False
    rv = {}
    for mod in trace_rules.NP_SUPPORTED_MODULES:
        for k, v in mod.__dict__.items():
            if is_supported(k, v, mod):
                rv[id(v)] = f"{mod.__name__}.{k}"
    return rv

def old_numpy_function_ids():
    rv = {}
    for mod in trace_rules.NP_SUPPORTED_MODULES:
        rv.update(
            {
                id(v): f"{mod.__name__}.{k}"
                for k, v in mod.__dict__.items()
                if callable(v)
                and (getattr(v, "__module__", None) or mod.__name__) == mod.__name__
            }
        )
    return rv

rv1 = set(old_numpy_function_ids().values())
rv2 = set(new_numpy_function_ids().values())

for v in (rv1 - rv2):
    print(v)
print("****")
for v in (rv2 - rv1):
    print(v)
```

Pull Request resolved: pytorch#138686
Approved by: https://github.com/lezcano, https://github.com/williamwen42
rahulsingh-intel pushed a commit to rahulsingh-intel/pytorch that referenced this pull request Nov 5, 2024
This reverts commit 124eac2.

Reverted pytorch#138686 on behalf of https://github.com/huydhn due to Sorry for reverting your change, but I am seeing inductor failure with hf_BigBird number of graph breaks after it lands ([comment](pytorch#138686 (comment)))
This reverts commit 3b87731350752f2c3c16543fa57ff51fe22a54b5.
@haifeng-jin
Copy link
Collaborator Author

I have locally reproduced the Hugging face BigBird graph breaks and fixed it.
See more about the cause and fix in this thread.

More details on reproducing it locally

I cloned the torch benchmark repo and followed the installation instructions.
Added it to the Python path and run the following command in the PyTorch root repo.

python benchmarks/dynamo/torchbench.py --ci --accuracy --timing --explain --inductor --device cpu --inference --float32 --total-partitions 2 --partition-id 0 --output inference_torchbench.csv -k hf_BigBird

@haifeng-jin
Copy link
Collaborator Author

@williamwen42 @huydhn Would you please help review and merge the PR? Thanks!

@williamwen42
Copy link
Member

@pytorchbot merge

@pytorchmergebot
Copy link
Collaborator

Merge started

Your change will be merged once all checks pass (ETA 0-4 Hours).

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team

Advanced Debugging
Check the merge workflow status
here

@haifeng-jin
Copy link
Collaborator Author

@pytorchbot label "topic: bug fixes"

@pytorch-bot pytorch-bot bot added the topic: bug fixes topic category label Nov 14, 2024
@atalman atalman removed the topic: not user facing topic category label Nov 15, 2024
@haifeng-jin
Copy link
Collaborator Author

@pytorchbot label "release notes: dynamo"

pobin6 pushed a commit to pobin6/pytorch that referenced this pull request Dec 5, 2024
Fixes pytorch#136559
As we upgrade to NumPy 2, torch falsely filtered out `numpy.random` as unsupported in dynamo tracking.
This PR changes the filtering rules to include them while keeping behavior with numpy 1 unchanged.

Before this PR, the following tests failed:

```
PYTORCH_TEST_WITH_ASAN=1 PYTORCH_TEST_WITH_UBSAN=1 python test/dynamo/test_functions.py -k FunctionTests.test_numpy_random
PYTORCH_TEST_WITH_ASAN=1 PYTORCH_TEST_WITH_UBSAN=1 python test/dynamo/test_unspec.py -k UnspecTests.test_to_tensor
PYTORCH_TEST_WITH_ASAN=1 PYTORCH_TEST_WITH_UBSAN=1 python test/test_fake_tensor.py -k FakeTensorTest.test_export_numpy
PYTORCH_TEST_WITH_ASAN=1 PYTORCH_TEST_WITH_UBSAN=1 python test/test_fake_tensor.py -k PropagateRealTensorsFakeTensorTest.test_export_numpy_propagate_real_tensors
```

With this PR, the supported/unsupported ops in NumPy 1 are not changed.
For NumPy 2, only the `numpy.random` ops that are already supported with NumPy 1 are added to the supported list.

I used the following scripts to check the differences before and after the change for both NumPy 1 & 2.
The output is empty for NumPy 1 since there is no change.
The output is a list of `numpy.random` that considered supported for NumPy 2.

```py
from torch._dynamo import trace_rules
import numpy as np

def new_numpy_function_ids():
    unsupported_funcs = {"seed", "ranf", "get_bit_generator", "RandomState", "set_bit_generator", "sample"}

    def is_supported(k, v, mod):
        if not callable(v):
            return False
        if not getattr(v, "__module__", None):
            return True
        if v.__module__ == mod.__name__:
            return True
        if v.__module__ == "numpy.random.mtrand" and mod.__name__== "numpy.random" and k not in unsupported_funcs:
            return True
        return False
    rv = {}
    for mod in trace_rules.NP_SUPPORTED_MODULES:
        for k, v in mod.__dict__.items():
            if is_supported(k, v, mod):
                rv[id(v)] = f"{mod.__name__}.{k}"
    return rv

def old_numpy_function_ids():
    rv = {}
    for mod in trace_rules.NP_SUPPORTED_MODULES:
        rv.update(
            {
                id(v): f"{mod.__name__}.{k}"
                for k, v in mod.__dict__.items()
                if callable(v)
                and (getattr(v, "__module__", None) or mod.__name__) == mod.__name__
            }
        )
    return rv

rv1 = set(old_numpy_function_ids().values())
rv2 = set(new_numpy_function_ids().values())

for v in (rv1 - rv2):
    print(v)
print("****")
for v in (rv2 - rv1):
    print(v)
```

Pull Request resolved: pytorch#138686
Approved by: https://github.com/williamwen42
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

ci-no-td Do not run TD on this PR ciflow/inductor ciflow/trunk Trigger trunk jobs on your pull request Merged module: dynamo open source release notes: dynamo Reverted topic: bug fixes topic category triaged This issue has been looked at a team member, and triaged and prioritized into an appropriate module

Projects

None yet

Development

Successfully merging this pull request may close these issues.

torch.compile errors when tracing numpy.random.uniform with numpy2

9 participants