
Conversation

@eee4017
Collaborator

@eee4017 eee4017 commented Sep 18, 2025

Many extensions (including pybind helpers) call Tensor.__dlpack__() without a stream argument. Before #150217, stream=None behaved like “no cross-stream sync” and was safe inside CUDA Graph capture. After #150217, stream=None maps to the legacy default stream, adding a cross-stream wait that invalidates capture when running on a non-default stream.

See this example:

```python
import torch

s = torch.cuda.Stream()
x = torch.randn(8, device="cuda")
g = torch.cuda.CUDAGraph()

with torch.cuda.stream(s):
    with torch.cuda.graph(g):
        _ = x + 1
        # Implicit stream=None: after #150217 the exporter makes the legacy
        # default stream wait on the capturing stream, which invalidates the
        # capture.
        cap = x.__dlpack__()
        _ = torch.utils.dlpack.from_dlpack(cap)
```

This PR partially reverts #150217 so that stream=None defaults to no sync again.
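
For reference, a hedged sketch of the explicit workaround that should already be possible (assuming the exporter accepts stream=-1 as the spec's "no sync" sentinel, as discussed below): request the capsule with stream=-1 so no wait on the legacy default stream is inserted during capture.

```python
import torch

s = torch.cuda.Stream()
x = torch.randn(8, device="cuda")
g = torch.cuda.CUDAGraph()

with torch.cuda.stream(s):
    with torch.cuda.graph(g):
        _ = x + 1
        # Explicitly request "no cross-stream sync": nothing touches the
        # legacy default stream, so the capture stays valid.
        cap = x.__dlpack__(stream=-1)
        _ = torch.utils.dlpack.from_dlpack(cap)
```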

cc @mcarilli @ezyang @eellison @penguinwu @BoyuanFeng

@pytorch-bot

pytorch-bot bot commented Sep 18, 2025

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/163242

Note: Links to docs will display an error until the docs builds have been completed.

✅ No Failures

As of commit f21bd9e with merge base 28c42cc:
💚 Looks good so far! There are no failures yet. 💚

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@eee4017 added the module: cuda graphs, module: dlpack, and release notes: cuda labels Sep 18, 2025
@zou3519 added the triaged label Sep 18, 2025
@zou3519 requested review from eellison and ngimel September 18, 2025 14:11
@ngimel
Collaborator

ngimel commented Sep 18, 2025

Why don't we want to keep the no-sync behavior if stream is None? Having different behavior during capture and outside of it is a nightmare.

@eee4017
Collaborator Author

eee4017 commented Sep 18, 2025

Why don't we want to keep the no-sync behavior if stream is None? Having different behavior during capture and outside of it is a nightmare.

The problem is that we are not allowed to sync the default stream during capture (cudaErrorStreamCaptureInvalidated). Before #150217 we did not have this behavior.
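
A rough, hedged illustration of the mechanism (not the actual exporter code): after #150217 the exporter effectively records an event on the current (capturing) stream and makes the consumer stream wait on it; when stream=None maps to the legacy default stream, that wait touches the default stream mid-capture and is expected to fail with cudaErrorStreamCaptureInvalidated.

```python
import torch

s = torch.cuda.Stream()
x = torch.randn(8, device="cuda")
g = torch.cuda.CUDAGraph()

with torch.cuda.stream(s):
    with torch.cuda.graph(g):
        _ = x + 1
        # Approximation of the cross-stream handshake __dlpack__ performs for a
        # consumer stream that differs from the producer's current stream:
        ev = torch.cuda.Event()
        ev.record(torch.cuda.current_stream())   # producer side (being captured)
        # Consumer side: with stream=None treated as the legacy default stream,
        # this wait is issued on the default stream during capture and
        # invalidates the capture.
        torch.cuda.default_stream().wait_event(ev)
```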

@eqy
Collaborator

eqy commented Sep 19, 2025

CC @tqchen as we are considering changing the behavior to use current_stream() when the passed stream is None, which effectively partially reverts #150217

Note that with the PR that makes stream == None assume default_stream(), graph capture breaks due to the sync.

@ngimel
Collaborator

ngimel commented Sep 19, 2025

The problem is that we are not allowed to sync the default stream during capture (cudaErrorStreamCaptureInvalidated). Before #150217 we did not have this behavior.

Exactly, and I'm questioning the changes introduced by #150217, as they effectively make capture impossible, and I'm strongly against introducing different behavior under capture and without capture.

@tqchen
Contributor

tqchen commented Sep 19, 2025

Note that there is stream = -1, which indicates no sync; however, it seems most implementations still default to None (which is specified as syncing to the default stream).

On a related note, the new speed convert RFC #162845 adds a stream query, which will make the behavior more streamlined (and compatible with cudagraph)

Of course we are talking about the default behavior of dlpack. The original reasoning is that None is a safe choice (without considering cudagraph), and the dlpack interface would require an explicit sync to the default stream. Considering cudagraph, I think defaulting to the current stream in that mode indeed seems more sensible, but the no-sync default behavior in stream mode might indeed cause some issues, as some consumers may expect a sync on the stream passed.
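
For reference, the CUDA stream sentinels in the array-API __dlpack__(stream=...) signature, as I read the spec (worth double-checking against the standard; the calls below assume the exporter honors these sentinels):

```python
import torch

x = torch.randn(8, device="cuda")
s = torch.cuda.Stream()

# stream=None -> consumer passed nothing; spec says assume the legacy default stream
# stream=1    -> legacy default stream
# stream=2    -> per-thread default stream
# stream=-1   -> no synchronization (stream-preserving)
# other ints  -> a raw cudaStream_t handle passed as a Python integer
cap_no_sync = x.__dlpack__(stream=-1)             # stream-preserving, no sync
cap_synced = x.__dlpack__(stream=s.cuda_stream)   # producer syncs s with its pending work
```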

cc @leofang

@eee4017 added the topic: not user facing label Sep 19, 2025
@ngimel
Collaborator

ngimel commented Sep 19, 2025

Please fix the PR description to match what the PR does.

@tqchen
Contributor

tqchen commented Sep 20, 2025

Just want to send a follow-up note: it might be worthwhile to raise awareness and discuss the DLPack spec; ideally the spec should get updated to reflect these needs.

Note that the same behavior can be achieved by passing in stream = -1 per the spec, and we should implement that. Maybe we can start by first supporting the stream = -1 behavior and keep the old default for a bit. Hopefully that covers the need: packages that want stream-preserving behavior can intentionally pass in -1, which should unblock most cases.

Personally I am OK with either, and the stream-preserving call is an important path. If I had to pick, I might even prefer the preserving one. I just want to make sure we keep some consistency with the data-apis spec so there are no surprises when people call __dlpack__ here.

@ezyang
Contributor

ezyang commented Sep 21, 2025

@ysiraichi are you able to weigh in here?

@tqchen
Contributor

tqchen commented Sep 21, 2025

Just to elaborate a bit further on the choices here for clarity

The canonical way of using dlpack is to call mylib.from_dlpack, where mylib decides on the stream parameter to pass in.

```python
import torch
# mylib / mylib_kernel are hypothetical consumer-library stand-ins here.

s = torch.cuda.Stream()
x = torch.randn(8, device="cuda")
g = torch.cuda.CUDAGraph()

with torch.cuda.stream(s):
    with torch.cuda.graph(g):
        _ = x + 1
        mylib_tensor = mylib.from_dlpack(x)
        mylib_kernel(mylib_tensor)
```

We have two choices for mylib to implement cudagraph support

C0: simply call Tensor.__dlpack__() without passing in a stream (so stream=None), and have stream=None default to no sync in torch

This would require an update in torch (this PR).

  • Pros: the intended behavior of many mylib.from_dlpack implementations is indeed stream preserving, so mylib does not need an update to support cuda graph.
  • Cons: divergence from the data-apis spec, which some frameworks (myframework) may rely on, specifically those that do not expect stream-preserving behavior (the use case where another framework does not expect a stream context and uses its own internal stream) and would like to "play it safe" (although that is not as efficient).

Of course, in the C0 (Cons) case, mylib can also update its call to __dlpack__() to pass in a stream indicating the legacy default. But the default behavior (no stream passed) will result in potential issues when calling from_dlpack under a torch stream context.

C1: simply call Tensor.__dlpack__() with stream = -1

This would require an update to mylib in its implementation of from_dlpack, explicitly passing in stream=-1 (see the sketch after this list).

  • Cons: mylib will need an explicit update to its call to __dlpack__ to pass in stream=-1, although the original intention of many mylib implementations is stream preserving.
    • In this case, we should also cross-check that torch supports stream = -1 correctly.
  • Pros: for other frameworks (myframework) that have a different stream-handling mechanism, behavior under a torch stream context won't be broken, and for now it aligns with the data-apis spec, although one could argue that the importance of stream passing would need an update to the spec.
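
To make C1 concrete, a hedged sketch of what a hypothetical stream-preserving consumer could do (mylib_from_dlpack is illustrative, not an existing API; a real library would wrap the capsule in its own tensor type):

```python
import torch

def mylib_from_dlpack(torch_tensor):
    # C1: the consumer is intentionally stream-preserving, so it asks the
    # producer for "no sync" explicitly; this is safe under CUDA Graph capture
    # and under any ambient stream context.
    capsule = torch_tensor.__dlpack__(stream=-1)
    # Round-trip through torch purely for illustration.
    return torch.utils.dlpack.from_dlpack(capsule)

s = torch.cuda.Stream()
x = torch.randn(8, device="cuda")
g = torch.cuda.CUDAGraph()

with torch.cuda.stream(s):
    with torch.cuda.graph(g):
        y = mylib_from_dlpack(x) + 1
```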

Discussions and Possible Action Items

Note that regardless of C0 or C1, assuming mylib or myframework is aware of its intention in from_dlpack and passes in the stream parameter accordingly, we get cudagraph compatibility for the mylib cases and correctness for the myframework case. These updates are invisible to the user, who can continue to use mylib.from_dlpack(torch_tensor).

The only difference is the course of action: whether the legacy default option should be preserved as the conservative choice, and whether we would need an update in mylib.from_dlpack or myframework.from_dlpack. The same thing applies to the __dlpack_versioned__ dunder in the latest post-1.0 version, which will be used for exchange going forward.

One action item that I think falls out of this discussion is that the data-apis spec should be updated to discuss stream-preserving behavior and provide guidelines for mylib-style use cases to pass in stream correctly when stream preserving is the intention.

@eee4017
Collaborator Author

eee4017 commented Sep 22, 2025

How about we change Tensor.__dlpack__'s default stream to -1 (stream-preserving, no sync)? Implicit calls would then be CUDA Graph-friendly and avoid default-stream traffic. This matches what most extensions intend, requires no library changes unless they explicitly want default-stream sync (which they can still pass), and keeps behavior unchanged for the large majority of users. While it does slightly diverge from the data-apis guidance, we can document the deviation and note that anyone who needs default-stream synchronization should pass it explicitly.
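
To illustrate the proposal with a hedged sketch (assuming the exporter honors the spec's sentinels, with 1 meaning the legacy default stream):

```python
import torch

x = torch.randn(8, device="cuda")

# Proposed default: a bare call behaves like an explicit stream=-1.
cap_implicit = x.__dlpack__()            # proposed: stream-preserving, no sync
cap_explicit = x.__dlpack__(stream=-1)   # same behavior, spelled out

# Callers that want the old synchronize-with-the-default-stream behavior
# opt in explicitly.
cap_synced = x.__dlpack__(stream=1)
```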

@tqchen
Contributor

tqchen commented Sep 22, 2025

opened data-apis/array-api#974 for discussion in data-api

@ysiraichi
Collaborator

The motivation behind PR #150217 was to make PyTorch's DLPack support compliant with the DLPack 1.0 spec, specifically the behavior for different values of the stream parameter. It currently does support stream=-1, so that could be used as a workaround (as mentioned by @tqchen).

@ngimel
Collaborator

ngimel commented Sep 22, 2025

Flipping default to -1 sounds good to me

@ysiraichi
Collaborator

Since this makes it so __dlpack__ is inconsistent with the data-api spec, I think it would be nice to call it out in the documentation.

@eee4017 added the ciflow/trunk and ciflow/rocm labels Sep 24, 2025
@eee4017
Collaborator Author

eee4017 commented Sep 24, 2025

@pytorchbot merge

@pytorchmergebot
Collaborator

Merge started

Your change will be merged once all checks pass (ETA 0-4 Hours).

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team

Advanced Debugging: check the merge workflow status here.

jainapurva pushed a commit that referenced this pull request Sep 29, 2025
…capture (#163242)


Pull Request resolved: #163242
Approved by: https://github.com/ngimel