Conversation

@guangyey (Collaborator) commented Sep 10, 2024

Stack from ghstack (oldest at bottom):

Motivation

fix #135550
In PyTorch, tensor.data_ptr() is reinterpreted as a signed int64 (via the wrap overload in torch/csrc/autograd/utils/wrap_outputs.h), which can overflow and produce a negative value, as below:

```python
import torch
a = torch.randn(2).to('xpu')
a.data_ptr()
# one possible output is
-23453392437248
# this is inconsistent with storage.data_ptr()
a.untyped_storage().data_ptr()
# one possible output is
18446720620317114368
```
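For reference, the two outputs are the same 64-bit pattern: the negative value is the unsigned address minus 2^64. A minimal standalone C++ sketch of the reinterpretation (the address value is taken from the output above; variable names are illustrative):

```cpp
#include <cstdint>
#include <cstdio>

int main() {
  // Device address from the example above; it lies in the upper half of the
  // 64-bit address space, so its top bit is set.
  uint64_t device_addr = 18446720620317114368ULL;
  // Reading the same bit pattern as a signed int64 (what the old wrap()
  // overload effectively did) yields device_addr - 2^64, a negative value.
  // This conversion is well-defined modular arithmetic since C++20 and
  // behaves as two's complement on mainstream platforms.
  int64_t as_signed = static_cast<int64_t>(device_addr);
  printf("%lld\n", static_cast<long long>(as_signed));  // -23453392437248
  return 0;
}
```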

This PR fixes the representation overflow so that tensor.data_ptr() is consistent with tensor.untyped_storage().data_ptr(). With this PR, the output becomes:

```python
import torch
a = torch.randn(2).to('xpu')
a.data_ptr()
# one possible output is
18446720620317114368
# this is consistent with storage.data_ptr()
a.untyped_storage().data_ptr()
# one possible output is
18446720620317114368
```

Solution

Use PyLong_FromVoidPtr to prevent the overflow and match the semantics of wrap.
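For illustration, a minimal sketch of the before/after in CPython C API terms. The helper name below is hypothetical, not the actual PyTorch function; the real change touches the wrap helpers referenced above:

```cpp
#include <Python.h>
#include <cstdint>

// Hypothetical helper, for illustration only.
static PyObject* wrap_data_ptr(void* ptr) {
  // Before (overflow-prone): reinterpret the pointer as a signed int64, so
  // an address above 2^63 becomes a negative Python int:
  //   return PyLong_FromLongLong(reinterpret_cast<int64_t>(ptr));
  // After: PyLong_FromVoidPtr converts via an unsigned integer type, so the
  // resulting Python int is always the non-negative address value.
  return PyLong_FromVoidPtr(ptr);
}
```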

Additional Context

This PR was previously reverted in place (revert commit 2e8d431), and this reland contains no additional code changes. The revert was needed because the change to tensor.data_ptr() has to be synced with the Intel XPU Triton side (see intel/intel-xpu-backend-for-triton#2192), so the XPU Triton commit pin has to be updated together with this PR.


pytorch-bot bot commented Sep 10, 2024

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/135567

Note: Links to docs will display an error until the docs builds have been completed.

❌ 4 New Failures, 2 Unrelated Failures

As of commit 4ae47a0 with merge base 9b89fa4:

NEW FAILURES - The following jobs have failed:

FLAKY - The following job failed but was likely due to flakiness present on trunk:

UNSTABLE - The following job failed but was likely due to flakiness present on trunk and has been marked as unstable:

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@guangyey changed the title from "fix tensor data_ptr overflow" to "Fix tensor data_ptr overflow" on Sep 10, 2024
@guangyey added the labels topic: bug fixes and topic: not user facing on Sep 10, 2024
@guangyey added the ciflow/trunk (Trigger trunk jobs on your pull request) label on Sep 10, 2024
@guangyey changed the title from "Fix tensor data_ptr overflow" to "Fix tensor.data_ptr() overflow" on Sep 10, 2024
@guangyey changed the title from "Fix tensor.data_ptr() overflow" to "Fix tensor.data_ptr() representation overflow" on Sep 10, 2024
guangyey added a commit that referenced this pull request Sep 10, 2024
ghstack-source-id: 5d776a5
Pull Request resolved: #135567

@guangyey (Collaborator, Author) commented on this diff, Sep 10, 2024:

```cpp
static PyObject* THPStorage_dataPtr(PyObject* self, PyObject* noargs) {
  HANDLE_TH_ERRORS
  // PyLong_FromVoidPtr should not need to mutate the pointer in order
```
@dvrogozh (Contributor) left a comment:

This patch works for me and fixes #135550. I also checked that it does not regress the same test on CUDA.

@dvrogozh (Contributor) left a comment:

Actually, wait. This patch regresses the following torch.compile-related case in Hugging Face Accelerate (note: it still passes on CUDA). I am on intel/intel-xpu-backend-for-triton@1111a28 (the latest as of today) for XPU Triton.

```
$ python3 -m pytest --pspec tests/test_utils.py::UtilsTester::test_dynamo
========================================================================================= test session starts =========================================================================================
platform linux -- Python 3.10.12, pytest-7.4.4, pluggy-1.5.0
rootdir: /home/dvrogozh/git/huggingface/accelerate
plugins: pspec-0.0.4, dash-2.17.1, timeout-2.3.1, subtests-0.13.1, hypothesis-6.108.4, xdist-3.6.1, rich-0.1.1, typeguard-4.3.0
collected 1 item

tests/test_utils.py                                                                                                                                                                                    
Utils Tester
 ✗ dynamo
                                                                                                                                                                                                [100%]

============================================================================================== FAILURES ===============================================================================================
_______________________________________________________________________________________ UtilsTester.test_dynamo _______________________________________________________________________________________

self = <test_utils.UtilsTester testMethod=test_dynamo>

    @require_triton
    @require_non_cpu
    @require_torch_min_version(version="2.0")
    def test_dynamo(self):
        model = RegressionModel()
        model._original_forward = model.forward
        model.forward = torch.autocast(device_type=torch_device, dtype=torch.float16)(model.forward)
        model.forward = convert_outputs_to_fp32(model.forward)
        model.forward = torch.compile(model.forward, backend="inductor")
        inputs = torch.randn(4, 10).to(torch_device)
>       _ = model(inputs)

tests/test_utils.py:203:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
../../pytorch/pytorch/torch/nn/modules/module.py:1736: in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
../../pytorch/pytorch/torch/nn/modules/module.py:1747: in _call_impl
    return forward_call(*args, **kwargs)
../../pytorch/pytorch/torch/_dynamo/eval_frame.py:465: in _fn
    return fn(*args, **kwargs)
src/accelerate/utils/operations.py:820: in forward
    return model_forward(*args, **kwargs)
src/accelerate/utils/operations.py:808: in __call__
    return convert_to_fp32(self.model_forward(*args, **kwargs))
../../pytorch/pytorch/torch/amp/autocast_mode.py:44: in decorate_autocast
    return func(*args, **kwargs)
src/accelerate/test_utils/training.py:59: in forward
    print(f"Model dtype: {self.a.dtype}, {self.b.dtype}. Input dtype: {x.dtype}")
src/accelerate/test_utils/training.py:59: in torch_dynamo_resume_in_forward_at_59
    print(f"Model dtype: {self.a.dtype}, {self.b.dtype}. Input dtype: {x.dtype}")
../../pytorch/pytorch/torch/_dynamo/eval_frame.py:632: in _fn
    return fn(*args, **kwargs)
../../pytorch/pytorch/torch/_functorch/aot_autograd.py:1100: in forward
    return compiled_fn(full_args)
../../pytorch/pytorch/torch/_functorch/_aot_autograd/runtime_wrappers.py:308: in runtime_wrapper
    all_outs = call_func_at_runtime_with_args(
../../pytorch/pytorch/torch/_functorch/_aot_autograd/utils.py:124: in call_func_at_runtime_with_args
    out = normalize_as_list(f(args))
../../pytorch/pytorch/torch/_functorch/_aot_autograd/utils.py:98: in g
    return f(*args)
../../pytorch/pytorch/torch/autograd/function.py:575: in apply
    return super().apply(*args, **kwargs)  # type: ignore[misc]
../../pytorch/pytorch/torch/_functorch/_aot_autograd/runtime_wrappers.py:1525: in forward
    fw_outs = call_func_at_runtime_with_args(
../../pytorch/pytorch/torch/_functorch/_aot_autograd/utils.py:124: in call_func_at_runtime_with_args
    out = normalize_as_list(f(args))
../../pytorch/pytorch/torch/_functorch/_aot_autograd/runtime_wrappers.py:488: in wrapper
    return compiled_fn(runtime_args)
../../pytorch/pytorch/torch/_functorch/_aot_autograd/runtime_wrappers.py:667: in inner_fn
    outs = compiled_fn(args)
../../pytorch/pytorch/torch/_inductor/codecache.py:1478: in __call__
    return self.current_callable(inputs)
../../pytorch/pytorch/torch/_inductor/utils.py:1977: in run
    return model(new_inputs)
/tmp/torchinductor_dvrogozh/ay/caym5gxlqa26bbtaw5s2fw4cjejcdgrzilczkobnzwl5vfmctfzz.py:87: in call
    triton_poi_fused_add_mul_0.run(primals_2, primals_1.item(), primals_3.item(), buf0, 40, grid=grid(40), stream=stream0)
../../pytorch/pytorch/torch/_inductor/runtime/triton_heuristics.py:879: in run
    return launcher(
<string>:13: in launcher
    ???
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _

self = <intel.XPULauncher object at 0x7f5428db82e0>
args = (1, 1, 1, 94355215467936, <capsule object "kernel" at 0x7f5428db86f0>, KernelMetadata(allow_fp8e4b15=True, allow_fp8e4...39607552, 'type': 'gpu', 'vendor': 'Intel(R) Corporation', 'version': '1.3'}, warp_size=32), threads_per_warp=32), ...)
kwargs = {}

    def __call__(self, *args, **kwargs):
>       self.launch(*args, **kwargs)
E       ValueError: Pointer argument (at 0) doesn't reference XPU device memory (cpu tensor?)

../../intel-xpu-backend-for-triton/python/triton/backends/intel/driver.py:391: ValueError
---------------------------------------------------------------------------------------- Captured stdout call -----------------------------------------------------------------------------------------
Model dtype: torch.float32, torch.float32. Input dtype: torch.float32
======================================================================================= short test summary info =======================================================================================
FAILED tests/test_utils.py::UtilsTester::test_dynamo - ValueError: Pointer argument (at 0) doesn't reference XPU device memory (cpu tensor?)
========================================================================================== 1 failed in 3.99s =========================================================================
```

cc: @vlad-penkin, I filed intel/intel-xpu-backend-for-triton#2188 on your side as we discussed offline.

@guangyey (Collaborator, Author) commented Sep 11, 2024

@dvrogozh I think the failure is a separate issue: Intel Triton needs to change the API it uses to access data_ptr. This PR only aims to fix the overflow issue.

@dvrogozh (Contributor) commented:

> @dvrogozh I think the failure is another issue. This PR only aims to fix the overflow issue.

@guangyey: Yes, you are right. The failure is due to wrong casting of the device memory pointer in Triton. I believe I have fixed it in intel/intel-xpu-backend-for-triton#2192.
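As an illustration of the kind of consumer-side change this implies, here is a sketch in CPython C API terms. It assumes the launcher previously converted the Python int with a signed conversion; it is not the actual Triton driver code:

```cpp
#include <Python.h>

// Illustrative only: extract a device pointer from the Python int returned
// by tensor.data_ptr(), whose value may now exceed INT64_MAX.
static void* get_device_ptr(PyObject* obj) {
  // Before (assumed): a signed conversion, which raises OverflowError for
  // addresses above 2^63:
  //   void* ptr = reinterpret_cast<void*>(PyLong_AsLongLong(obj));
  // After: PyLong_AsVoidPtr accepts the full 64-bit address range.
  return PyLong_AsVoidPtr(obj);
}
```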

@guangyey: I will still keep my -1 on this PR, for the reason that the same new test I provided uncovers one more issue in PyTorch XPU. Specifically, if PyTorch is built with DEBUG=1, this assert fails:

```cpp
void decrease(size_t amount) {
  current -= static_cast<int64_t>(amount);
  TORCH_INTERNAL_ASSERT_DEBUG_ONLY(
      current >= 0,
      "Negative tracked stat in device allocator (likely logic error).");
}
```

Can you please help fix it? To reproduce, run:

```
python3 -m pytest --pspec tests/test_utils.py::UtilsTester::test_dynamo
```

@guangyey (Collaborator, Author) commented:

@dvrogozh IMO we shouldn't leave messages unrelated to this PR here, to avoid distracting the PyTorch code reviewers.
Could you please file a new issue for the assertion failure? Thanks very much.

@EikanWang (Collaborator) commented:

@guangyey, have you rebased the PR?

@guangyey (Collaborator, Author) commented:

> @guangyey, have you rebased the PR?

Done.

@malfet (Contributor) commented Nov 19, 2024

If this is a reland, can you please reference the original PR that was reverted in the description, and describe what has changed since the revert?

@guangyey (Collaborator, Author) commented:

Thanks. This PR was reverted in place; the revert commit is 2e8d431, and this reland has no code changes beyond the original commit. The revert reason is that the change to tensor.data_ptr() needs to be synced with the Intel XPU Triton side (see intel/intel-xpu-backend-for-triton#2192), so we have to update the XPU Triton commit pin together with this PR.
This PR doesn't impact the CUDA side, as CUDA device data_ptr() values don't overflow.
I updated this in Additional Context as well.

@pytorch-bot added the ci-no-td (Do not run TD on this PR) label on Nov 26, 2024
@pytorchmergebot (Collaborator) commented:

Rebased gh/guangyey/76/orig onto refs/remotes/origin/main because #137886 was rebased, please pull locally before adding more changes (for example, via ghstack checkout https://github.com/pytorch/pytorch/pull/135567)

@pytorchmergebot (Collaborator) commented:

Rebased gh/guangyey/76/orig onto refs/remotes/origin/viable/strict because #137886 was rebased, please pull locally before adding more changes (for example, via ghstack checkout https://github.com/pytorch/pytorch/pull/135567)

@pytorchmergebot (Collaborator) commented:

Rebased gh/guangyey/76/orig onto refs/remotes/origin/viable/strict because #137886 was rebased, please pull locally before adding more changes (for example, via ghstack checkout https://github.com/pytorch/pytorch/pull/135567)

pytorchmergebot pushed a commit that referenced this pull request Nov 28, 2024
pobin6 pushed a commit to pobin6/pytorch that referenced this pull request Dec 5, 2024
Pull Request resolved: pytorch#135567
Approved by: https://github.com/dvrogozh, https://github.com/EikanWang, https://github.com/albanD
pobin6 pushed a commit to pobin6/pytorch that referenced this pull request Dec 5, 2024
@github-actions bot deleted the gh/guangyey/68/head branch on December 30, 2024 02:08