[feature request] Implement __cuda_array_interface__ for numba compatibility. #11914

@asford

Issue description

Numba provides generic support for jit-compiled operations over array data structures, on both cpu and cuda devices, which makes it a natural match for accelerating PyTorch operations. Numba has proposed a standardized "__cuda_array_interface__", inspired by the numpy __array_interface__ standard, to provide interop between libraries that hold cuda-based tensors.

Implementing this standard would dramatically ease PyTorch/Numba integration. It can be enabled today via a monkey-patch of the current torch.Tensor type, as in the example below, but PyTorch should provide an implementation itself.

Code example

import torch


def torch_cuda_array_interface(tensor):
    """__cuda_array_interface__ getter implementation for torch.Tensor."""

    if tensor.device.type != "cuda":
        # raise AttributeError for non-cuda tensors, so that
        # hasattr(cpu_tensor, "__cuda_array_interface__") is False.
        raise AttributeError("Tensor is not on cuda device: %r" % tensor.device)

    if tensor.requires_grad:
        # RuntimeError, matching existing tensor.__array__() behavior.
        raise RuntimeError(
            "Can't get __cuda_array_interface__ on Variable that requires grad. "
            "Use var.detach().__cuda_array_interface__ instead."
        )

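    # typestr follows the numpy __array_interface__ format: byte-order
    # character, type code, item size in bytes. Little-endian is assumed;
    # "|" marks single-byte types, where byte order is not relevant.
    typestr = {
        torch.float16: "<f2",
        torch.float32: "<f4",
        torch.float64: "<f8",
        torch.uint8: "|u1",
        torch.int8: "|i1",
        torch.int16: "<i2",
        torch.int32: "<i4",
        torch.int64: "<i8",
    }[tensor.dtype]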

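    # Size of one element in bytes, read directly from the tensor.
    itemsize = tensor.element_size()

    # Plain tuple of ints rather than torch.Size.
    shape = tuple(tensor.shape)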
    strides = tuple(s * itemsize for s in tensor.stride())
    data = (tensor.data_ptr(), False)

    return dict(typestr=typestr, shape=shape, strides=strides, data=data, version=0)


torch.Tensor.__cuda_array_interface__ = property(torch_cuda_array_interface)
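
Two choices in the getter above are easy to miss: raising AttributeError from a property so that hasattr() reports False, and converting element strides to byte strides. Both can be checked without a GPU. The sketch below uses a hypothetical stand-in class, FakeTensor, which is not part of PyTorch or numba; its pointer value is made up, and only the control flow mirrors the getter.

```python
class FakeTensor:
    """Hypothetical stand-in for torch.Tensor; not a real PyTorch or numba type."""

    def __init__(self, device_type, shape, element_strides, itemsize, ptr=0):
        self.device_type = device_type
        self.shape = tuple(shape)
        self.element_strides = tuple(element_strides)
        self.itemsize = itemsize
        self.ptr = ptr  # made-up address, never dereferenced

    @property
    def __cuda_array_interface__(self):
        if self.device_type != "cuda":
            # An AttributeError raised inside a property getter makes
            # hasattr() return False for this attribute.
            raise AttributeError("not on a cuda device: %r" % self.device_type)
        # The interface carries strides in bytes; torch-style strides
        # count elements, so scale by the item size.
        byte_strides = tuple(s * self.itemsize for s in self.element_strides)
        return dict(
            typestr="<f4",
            shape=self.shape,
            strides=byte_strides,
            data=(self.ptr, False),
            version=0,
        )


# A contiguous 2x3 float32 layout: element strides (3, 1), 4 bytes per item.
gpu = FakeTensor("cuda", shape=(2, 3), element_strides=(3, 1), itemsize=4)
cpu = FakeTensor("cpu", shape=(2, 3), element_strides=(3, 1), itemsize=4)

print(hasattr(cpu, "__cuda_array_interface__"))  # False
print(hasattr(gpu, "__cuda_array_interface__"))  # True
print(gpu.__cuda_array_interface__["strides"])   # (12, 4)
```

A consumer can therefore probe any object with hasattr() before reading the interface, and cpu tensors fall through without a special case.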
