Description
This is a proposal to add C function pointer APIs for DLPack exchange, bringing the overhead of DLPack-based Python function calls down to the 0.1-0.5 us level and enabling more effective use in GPU settings.
As of now, the exchange PyTorch provides relies on Python functions such as tensor.__dlpack__(). While they work well for common cases, the overhead of such an exchange is at the level of 0.2-1 us, depending on the optimization and machine. A typical conversion-based API runs as follows:
import torch
import mylib

def my_op(a: torch.Tensor, b: torch.Tensor, c: torch.Tensor):
    # Each argument goes through a separate Python-level DLPack conversion.
    a0 = mylib.from_dlpack(a)
    b0 = mylib.from_dlpack(b)
    c0 = mylib.from_dlpack(c)
    run_kernel(a0, b0, c0)

For a function that takes three arguments f(a, b, c), assuming we run a DLPack exchange for each argument, the total conversion overhead is usually around 0.7-3 us. While such overhead can be acceptable in many settings, in GPU applications the extra 1-3 us can still be significant.
After some investigation, we found that one main source of this overhead is Python itself. Calling a Python function generally incurs 0.1-0.2 us of overhead, and the from_dlpack and __dlpack__ conversions involve several steps that cross the C++/Python boundary.
To address this issue, we propose functions for fast exchange of DLPack tensors without going through the Python interpreter. See also the proposal on the DLPack RFC, dmlc/dlpack#175.
Proposed Functions
//----------------------------------------------------------------------
// DLPack `__c_dlpack_exchange_api__` fast exchange protocol definitions
//----------------------------------------------------------------------
/*!
* \brief Request a producer library to create a new tensor.
*
* Create a new `DLManagedTensorVersioned` within the context of the producer
* library. The allocation is defined via the prototype DLTensor.
*
* This function is exposed by the framework through the DLPackExchangeAPI.
*
* \param prototype The prototype DLTensor. Only the dtype, ndim, shape,
* and device fields are used.
* \param out The output DLManagedTensorVersioned.
* \param error_ctx Context for `SetError`.
* \param SetError The function to set the error.
* \return 0 on success, -1 on failure. SetError is called exactly when
* -1 is returned (the implementor must ensure this).
* \note - As a C function, must not throw C++ exceptions.
* - Error propagation via SetError to avoid any direct need
* of Python API. Due to this `SetError` may have to ensure the GIL is
* held since it will presumably set a Python error.
*
* \sa DLPackExchangeAPI
*/
typedef int (*DLPackManagedTensorAllocator)( //
DLTensor* prototype, DLManagedTensorVersioned** out, void* error_ctx, //
void (*SetError)(void* error_ctx, const char* kind, const char* message) //
);
/*!
* \brief Exports a PyObject* Tensor/NDArray to a DLManagedTensorVersioned.
*
* This function does not perform any stream synchronization. The consumer should query
* DLPackCurrentWorkStream to get the current work stream and launch kernels on it.
*
* This function is exposed by the framework through the DLPackExchangeAPI.
*
* \param py_object The Python object to convert. Must have the same type
* as the one the `DLPackExchangeAPI` was discovered from.
* \param out The output owning DLManagedTensorVersioned.
* \return 0 on success, -1 on failure with a Python exception set.
* If the data cannot be described using DLPack
* this should be a BufferError if possible.
* \note - As a C function, must not throw C++ exceptions.
*
* \sa DLPackExchangeAPI, DLPackCurrentWorkStream
*/
typedef int (*DLPackManagedTensorFromPyObjectNoSync)( //
void* py_object, //
DLManagedTensorVersioned** out //
);
/*!
* \brief Exports a PyObject* Tensor/NDArray to a provided DLTensor.
*
* This function provides a faster interface for temporary, non-owning, exchange.
* The producer (implementor) still owns the memory of data, strides, shape.
* The liveness of the DLTensor and the data it views is only guaranteed until
* control is returned.
*
* This function currently assumes that the producer (implementor) can fill
* in the DLTensor shape and strides without the need for temporary allocations.
*
* This function does not perform any stream synchronization. The consumer should query
* DLPackCurrentWorkStream to get the current work stream and launch kernels on it.
*
* This function is exposed by the framework through the DLPackExchangeAPI.
*
* \param py_object The Python object to convert. Must have the same type
* as the one the `DLPackExchangeAPI` was discovered from.
* \param out The output DLTensor, whose space is pre-allocated on stack.
* \return 0 on success, -1 on failure with a Python exception set.
* \note - As a C function, must not throw C++ exceptions.
*
* \sa DLPackExchangeAPI, DLPackCurrentWorkStream
*/
typedef int (*DLPackDLTensorFromPyObjectNoSync)( //
void* py_object, //
DLTensor* out //
);
/*!
* \brief Obtain the current work stream of a device.
*
* Obtain the current work stream of a device from the producer framework.
* For example, it should map to torch.cuda.current_stream in PyTorch.
*
* When device_type is kDLCPU, the consumer does not have to query the stream
* and the producer can simply return NULL when queried.
* The consumer does not have to do anything about stream synchronization or setting.
* So a CPU-only framework can just provide a dummy implementation that
* always sets out_current_stream[0] to NULL.
*
* \param device_type The device type.
* \param device_id The device id.
* \param out_current_stream The output current work stream.
*
* \return 0 on success, -1 on failure with a Python exception set.
* \note - As a C function, must not throw C++ exceptions.
*
* \sa DLPackExchangeAPI
*/
typedef int (*DLPackCurrentWorkStream)( //
DLDeviceType device_type, //
int32_t device_id, //
void** out_current_stream //
);
/*!
* \brief Imports a DLManagedTensorVersioned to a PyObject* Tensor/NDArray.
*
* Convert an owning DLManagedTensorVersioned* to the Python tensor of the
* producer (implementor) library with the correct type.
*
* This function does not perform any stream synchronization.
*
* This function is exposed by the framework through the DLPackExchangeAPI.
*
* \param tensor The DLManagedTensorVersioned to convert; ownership of the
* tensor is stolen.
* \param out_py_object The output Python object.
* \return 0 on success, -1 on failure with a Python exception set.
*
* \sa DLPackExchangeAPI
*/
typedef int (*DLPackManagedTensorToPyObjectNoSync)( //
DLManagedTensorVersioned* tensor, //
void** out_py_object //
);
/*!
* \brief DLPackExchangeAPI stable header.
* \sa DLPackExchangeAPI
*/
typedef struct DLPackExchangeAPIHeader {
/*!
* \brief The provided DLPack version. The consumer must check major version
* compatibility before using this struct.
*/
DLPackVersion version;
/*!
* \brief Optional pointer to an older DLPackExchangeAPI in the chain.
*
* It must be NULL if the framework does not support older versions.
* If the current major version is larger than the one supported by the
* consumer, the consumer may walk this to find an earlier supported version.
*
* \sa DLPackExchangeAPI
*/
struct DLPackExchangeAPIHeader* prev_api;
} DLPackExchangeAPIHeader;
/*!
* \brief Framework-specific function pointers table for DLPack exchange.
*
* In addition to `__dlpack__()`, we define a C function table sharable by
* Python implementations via `__c_dlpack_exchange_api__`.
* This attribute must be set on the type as a Python integer compatible
* with `PyLong_FromVoidPtr`/`PyLong_AsVoidPtr`.
*
* A consumer library may use a pattern such as:
*
* \code
*
* PyObject *api_obj = type(tensor_obj).__c_dlpack_exchange_api__; // as C-code
* MyDLPackExchangeAPI *api = PyLong_AsVoidPtr(api_obj);
* if (api == NULL && PyErr_Occurred()) { goto handle_error; }
*
* \endcode
*
* Note that this must be defined on the type. The consumer should look up the
* attribute on the type and may cache the result for each unique type.
*
* The precise API table is given by:
* \code
* struct MyDLPackExchangeAPI : public DLPackExchangeAPI {
* MyDLPackExchangeAPI() {
* header.version.major = DLPACK_MAJOR_VERSION;
* header.version.minor = DLPACK_MINOR_VERSION;
* header.prev_api = nullptr;
*
* managed_tensor_allocator = MyDLPackManagedTensorAllocator;
* managed_tensor_from_py_object_no_sync = MyDLPackManagedTensorFromPyObjectNoSync;
* managed_tensor_to_py_object_no_sync = MyDLPackManagedTensorToPyObjectNoSync;
* dltensor_from_py_object_no_sync = MyDLPackDLTensorFromPyObjectNoSync;
* current_work_stream = MyDLPackCurrentWorkStream;
* }
*
* static const DLPackExchangeAPI* Global() {
* static MyDLPackExchangeAPI inst;
* return &inst;
* }
* };
* \endcode
*
* Guidelines for leveraging DLPackExchangeAPI:
*
* There are generally two kinds of consumer needs for DLPack exchange:
* - N0: library support, where consumer.kernel(x, y, z) would like to run a kernel
* with the data from x, y, z. The consumer is also expected to run the kernel with the same
* stream context as the producer. For example, when x, y, z are torch.Tensor,
* consumer should query exchange_api->current_work_stream to get the
* current stream and launch the kernel with the same stream.
* This setup is necessary for no synchronization in kernel launch and maximum compatibility
* with CUDA graph capture in the producer.
* This is the desirable behavior for library extension support for frameworks like PyTorch.
* - N1: data ingestion and retention
*
* Note that the obj.__dlpack__() API should already serve N1 well.
* The primary focus of the current DLPackExchangeAPI is to enable faster exchange for N0
* with the support of the function pointer current_work_stream.
*
* Array/Tensor libraries should statically create and initialize this structure,
* then expose a pointer to DLPackExchangeAPI as an integer attribute on the Tensor/Array type.
* The DLPackExchangeAPI* must stay alive throughout the lifetime of the process.
*
* One simple way to do so is to create a static instance of DLPackExchangeAPI
* within the framework and return a pointer to it. The code above
* shows an example of doing so in C++. It should also be reasonably easy
* to do so in other languages.
*/
typedef struct DLPackExchangeAPI {
/*!
* \brief The header that remains stable across versions.
*/
DLPackExchangeAPIHeader header;
/*!
* \brief Producer function pointer for DLPackManagedTensorAllocator
* This function must not be NULL.
* \sa DLPackManagedTensorAllocator
*/
DLPackManagedTensorAllocator managed_tensor_allocator;
/*!
* \brief Producer function pointer for DLPackManagedTensorFromPyObjectNoSync
* This function must not be NULL.
* \sa DLPackManagedTensorFromPyObjectNoSync
*/
DLPackManagedTensorFromPyObjectNoSync managed_tensor_from_py_object_no_sync;
/*!
* \brief Producer function pointer for DLPackManagedTensorToPyObjectNoSync
* This function must not be NULL.
* \sa DLPackManagedTensorToPyObjectNoSync
*/
DLPackManagedTensorToPyObjectNoSync managed_tensor_to_py_object_no_sync;
/*!
* \brief Producer function pointer for DLPackDLTensorFromPyObjectNoSync
* This function can be NULL when the producer does not support this function.
* \sa DLPackDLTensorFromPyObjectNoSync
*/
DLPackDLTensorFromPyObjectNoSync dltensor_from_py_object_no_sync;
/*!
* \brief Producer function pointer for DLPackCurrentWorkStream
* This function must not be NULL.
* \sa DLPackCurrentWorkStream
*/
DLPackCurrentWorkStream current_work_stream;
} DLPackExchangeAPI;

Bring up in the Context of PyTorch
In the context of PyTorch, this means supporting the functions below, built on top of the current DLConvertor.cpp:
#include <ATen/DLConvertor.h>
#include <ATen/Functions.h>
#include <c10/cuda/CUDAStream.h>
#include <torch/csrc/autograd/python_variable.h>  // THPVariable_Wrap
#include <torch/csrc/utils/pybind.h>  // pybind11 and the at::Tensor caster

// Export a torch.Tensor (passed as PyObject*) to an owning DLManagedTensorVersioned.
int TorchDLPackFromPyObject(void* py_obj, DLManagedTensorVersioned** out) {
    try {
        py::handle handle(static_cast<PyObject*>(py_obj));
        at::Tensor tensor = handle.cast<at::Tensor>();
        *out = at::toDLPackImpl<DLManagedTensorVersioned>(tensor);
        return 0;
    } catch (const std::exception& e) {
        PyErr_SetString(PyExc_RuntimeError, e.what());
        return -1;
    }
}

// Import an owning DLManagedTensorVersioned as a torch.Tensor PyObject*.
int TorchDLPackToPyObject(DLManagedTensorVersioned* src, void** py_obj_out) {
    try {
        at::Tensor tensor = at::fromDLPackImpl<DLManagedTensorVersioned>(src, nullptr);
        *py_obj_out = THPVariable_Wrap(tensor);
        return 0;
    } catch (const std::exception& e) {
        PyErr_SetString(PyExc_RuntimeError, e.what());
        return -1;
    }
}

// Allocate a new tensor matching the prototype's shape, dtype, and device.
// Errors go through SetError rather than the Python API.
int TorchDLPackTensorAllocator(
    DLTensor* prototype, DLManagedTensorVersioned** out, void* error_ctx,
    void (*SetError)(void* error_ctx, const char* kind, const char* message)
) {
    try {
        at::IntArrayRef shape(prototype->shape, prototype->shape + prototype->ndim);
        at::TensorOptions options = at::TensorOptions()
            .dtype(at::toScalarType(prototype->dtype))
            .device(at::getATenDevice(prototype->device.device_type, prototype->device.device_id));
        at::Tensor tensor = at::empty(shape, options);
        *out = at::toDLPackImpl<DLManagedTensorVersioned>(tensor);
        return 0;
    } catch (const std::exception& e) {
        SetError(error_ctx, "TorchDLPackTensorAllocator", e.what());
        return -1;
    }
}

int TorchCurrentWorkStream(DLDeviceType device_type, int32_t device_id, void **out_stream) {
    try {
        // CPU has no stream; report NULL as the header documents.
        *out_stream = nullptr;
        if (device_type != kDLCPU) {
            *out_stream = at::cuda::getCurrentCUDAStream(device_id).stream();
        }
        return 0;
    } catch (const std::exception& e) {
        PyErr_SetString(PyExc_RuntimeError, e.what());
        return -1;
    }
}

struct TorchDLPackExchangeAPI : public DLPackExchangeAPI {
    TorchDLPackExchangeAPI() {
        header.version.major = DLPACK_MAJOR_VERSION;
        header.version.minor = DLPACK_MINOR_VERSION;
        header.prev_api = nullptr;
        managed_tensor_allocator = TorchDLPackTensorAllocator;
        managed_tensor_from_py_object_no_sync = TorchDLPackFromPyObject;
        managed_tensor_to_py_object_no_sync = TorchDLPackToPyObject;
        // TorchDLTensorFromPyObject (the non-owning export) is not shown in this sketch.
        dltensor_from_py_object_no_sync = TorchDLTensorFromPyObject;
        current_work_stream = TorchCurrentWorkStream;
    }

    // Process-lifetime singleton, so the exposed pointer never dangles.
    static const DLPackExchangeAPI* Global() {
        static TorchDLPackExchangeAPI inst;
        return &inst;
    }
};

// The following code should be part of the Python binding that exposes
// these function pointers, likely as an update to torch._C.
int64_t TorchDLPackExchangeAPIPtr() {
    return reinterpret_cast<int64_t>(TorchDLPackExchangeAPI::Global());
}
// ...

The functions will be exposed through a static function table, which will be set as a constant attribute on the torch.Tensor class:
torch.Tensor.__c_dlpack_exchange_api__
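On the consumer side, reading this attribute yields an integer that is turned back into a pointer, after which the header's version has to be checked before any function pointer is used. The sketch below illustrates one way a consumer might do that, walking `prev_api` when the newest table has a major version it does not understand. The structs are simplified stand-ins for the ones in the header above (the real definitions would come from dlpack.h), and `FindCompatibleAPI` and `FromAttributeValue` are hypothetical helper names, not part of the proposal.

```cpp
#include <cstdint>

// Simplified stand-ins for the proposal's structs (assumption: only the
// fields needed for version negotiation are mirrored here).
struct DLPackVersion {
  uint32_t major;
  uint32_t minor;
};

struct DLPackExchangeAPIHeader {
  DLPackVersion version;
  DLPackExchangeAPIHeader* prev_api;
};

// Hypothetical helper: recover the table pointer from the integer value of
// __c_dlpack_exchange_api__ (what PyLong_AsVoidPtr would produce).
inline const DLPackExchangeAPIHeader* FromAttributeValue(uintptr_t value) {
  return reinterpret_cast<const DLPackExchangeAPIHeader*>(value);
}

// Hypothetical helper: walk the prev_api chain and return the first table
// whose major version matches the consumer's, or nullptr if none does.
inline const DLPackExchangeAPIHeader* FindCompatibleAPI(
    const DLPackExchangeAPIHeader* head, uint32_t consumer_major) {
  for (const DLPackExchangeAPIHeader* api = head; api != nullptr;
       api = api->prev_api) {
    if (api->version.major == consumer_major) {
      return api;
    }
  }
  return nullptr;
}
```

Since the attribute lives on the type and the table stays alive for the whole process, a consumer could run this lookup once per tensor type and cache the resulting pointer.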
Other DLPack-based packages can then leverage this C-based exchange to write C++ extensions that depend only on these function handles, PyObject*, and DLPack.
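As an illustration of what such an extension might look like, here is a minimal sketch of a consumer calling the allocator handle and collecting the error through `SetError` instead of the Python API. The `DLTensor` and `DLManagedTensorVersioned` structs are heavily reduced stand-ins (assumption: real code would include dlpack.h), and `ErrorInfo`, `RecordError`, `AllocateOrNull`, and `ToyAllocator` are hypothetical names used only to exercise the calling convention. The sketch treats any nonzero return value as failure, matching the `return 0` / `return -1` convention in the PyTorch code above.

```cpp
#include <cstdint>
#include <string>

// Reduced stand-ins for the DLPack structs (assumption: illustration only).
struct DLTensor {
  int32_t ndim;
  int64_t* shape;
};

struct DLManagedTensorVersioned {
  DLTensor dl_tensor;
};

// Same shape as the allocator typedef in the proposal.
typedef int (*DLPackManagedTensorAllocator)(
    DLTensor* prototype, DLManagedTensorVersioned** out, void* error_ctx,
    void (*SetError)(void* error_ctx, const char* kind, const char* message));

// Hypothetical consumer-side error context filled in by SetError.
struct ErrorInfo {
  std::string kind;
  std::string message;
};

// SetError callback: stores the error without touching the Python API.
inline void RecordError(void* error_ctx, const char* kind, const char* message) {
  ErrorInfo* info = static_cast<ErrorInfo*>(error_ctx);
  info->kind = kind;
  info->message = message;
}

// Hypothetical consumer helper: returns the new tensor, or nullptr with
// *err populated when the producer reports a failure.
inline DLManagedTensorVersioned* AllocateOrNull(
    DLPackManagedTensorAllocator alloc, DLTensor* prototype, ErrorInfo* err) {
  DLManagedTensorVersioned* out = nullptr;
  if (alloc(prototype, &out, err, RecordError) != 0) {
    return nullptr;
  }
  return out;
}

// Toy producer used only to exercise the helper: it rejects a negative ndim
// and otherwise hands back a single static tensor (no real allocation).
inline int ToyAllocator(
    DLTensor* prototype, DLManagedTensorVersioned** out, void* error_ctx,
    void (*SetError)(void* error_ctx, const char* kind, const char* message)) {
  if (prototype->ndim < 0) {
    SetError(error_ctx, "ValueError", "ndim must be non-negative");
    return -1;
  }
  static DLManagedTensorVersioned storage;
  storage.dl_tensor = *prototype;
  *out = &storage;
  return 0;
}
```

Keeping the error in a plain struct lets the consumer decide later, once it holds the GIL, whether and how to turn it into a Python exception.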
cc @jbschlosser