
Conversation

@madsbk (Contributor) commented May 16, 2019

This PR implements auto-conversion of GPU arrays that support the __cuda_array_interface__ protocol (fixes #15601).

If an object exposes the __cuda_array_interface__ attribute, torch.as_tensor() and torch.tensor() will use the exposed device memory. The resulting copy semantics are summarized below (see the usage sketch after the list).

Zero-copy

When using torch.as_tensor(..., device=D) where D is the same device as the one used in __cuda_array_interface__.

Implicit copy

When using torch.as_tensor(..., device=D) where D is the CPU or another non-CUDA device.

Explicit copy

When using torch.tensor().

Exception

When using torch.as_tensor(..., device=D) where D is a CUDA device other than the one used in __cuda_array_interface__.

Lifetime

The tensor returned by torch.as_tensor(obj) holds a reference to obj, so obj is kept alive for at least the lifetime of the tensor.
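
A minimal usage sketch of these four cases, assuming Numba is installed and a single CUDA device is available (Numba device arrays export __cuda_array_interface__; the device strings here are illustrative):

import numba.cuda
import numpy as np
import torch

# A device array that exposes __cuda_array_interface__.
gpu_arr = numba.cuda.to_device(np.arange(4, dtype=np.float32))

# Zero-copy: the target device matches the exporting device.
t0 = torch.as_tensor(gpu_arr, device="cuda")

# Implicit copy: a non-CUDA target forces a copy to host memory.
t1 = torch.as_tensor(gpu_arr, device="cpu")

# Explicit copy: torch.tensor() always copies the data.
t2 = torch.tensor(gpu_arr, device="cuda")

# Exception: requesting a CUDA device other than the exporter's
# (e.g. device="cuda:1" on a multi-GPU box) raises instead of copying.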

@pytorchbot added the module: internals, module: numpy, and module: numba labels May 16, 2019
@li-roy added the triaged label May 16, 2019
@ezyang (Contributor) left a comment

Logically, the code looks great! However, I would like the memory leaks to be fixed before merge. If you insist, just manually inserting the necessary decrefs is acceptable; however, I think using an RAII class from pybind11 would be much safer.
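
A sketch of that suggestion (illustrative, not the PR's actual code): pybind11's py::object is an RAII wrapper that releases its Python reference automatically, so no manual Py_DECREF is needed on any exit path. The helper name below is hypothetical.

#include <pybind11/pybind11.h>
namespace py = pybind11;

// Hypothetical helper: fetch the protocol dict with RAII ownership.
py::object get_cuda_array_interface(PyObject* obj) {
  // PyObject_GetAttrString returns a new reference; reinterpret_steal
  // takes ownership, so the reference is decremented automatically
  // when the py::object goes out of scope.
  py::object iface = py::reinterpret_steal<py::object>(
      PyObject_GetAttrString(obj, "__cuda_array_interface__"));
  if (!iface) {
    throw py::error_already_set();  // propagate the Python AttributeError
  }
  return iface;
}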

@madsbk (Contributor, Author) commented May 17, 2019

Thanks, I will use pybind11 to fix the memory leaks and also address the other issues.

if (!PyTuple_Check(py_data) || PyTuple_GET_SIZE(py_data) != 2) {
  throw TypeError("`data` must be a 2-tuple of (int, bool)");
}
PyTuple_GET_ITEM(py_data, 0);

I don't think you intended to have this line?

@ezyang (Contributor) left a comment

Will merge when tests pass.

@ezyang self-requested a review May 20, 2019 14:04
@facebook-github-bot (Contributor) left a comment

@ezyang is landing this pull request. If you are a Facebook employee, you can view this diff on Phabricator.

@facebook-github-bot (Contributor) commented
@ezyang merged this pull request in 5d8879c.

@mrocklin (Contributor) commented

FYI @seibert, thought you'd like to know that PyTorch now supports __cuda_array_interface__ in both directions.

