[DLPack] Add support for missing keyword-arguments. (#150218)

This PR introduces the rest of the keyword-arguments added in DLPack
version 2023.12: `dl_device` and `copy`.

In summary, we handle these arguments in the C++ implementation of
`to_dlpack(...)` at _torch/csrc/Module.cpp_, by calling the
`maybeCopyTensor` function at _aten/src/ATen/DLConvertor.cpp_. It also
introduces the following changes:

- Add a new Python API `torchDeviceToDLDevice()`, which is simply a
  refactoring of the `getDLDevice()` function at
  _aten/src/ATen/DLConvertor.cpp_.
- Add both keyword-arguments to the `from_dlpack()` function at
  _torch/utils/dlpack.py_ and to the `Tensor.__dlpack__()` dunder
  method.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/150218
Approved by: https://github.com/albanD
ghstack dependencies: #150216, #150217
This commit is contained in:
Yukio Siraichi
2025-07-19 16:36:07 -03:00
committed by PyTorch MergeBot
parent 1d526fe78f
commit a10f15718d
8 changed files with 266 additions and 52 deletions

View File

@ -96,10 +96,14 @@ DLDataType getDLDataType(const Tensor& t) {
return dtype;
}
static DLDevice getDLDevice(const Tensor& tensor, c10::DeviceIndex device_id) {
DLDevice torchDeviceToDLDevice(at::Device device) {
DLDevice ctx;
ctx.device_id = static_cast<int32_t>(static_cast<unsigned char>(device_id));
switch (tensor.device().type()) {
ctx.device_id = (device.is_cuda() || device.is_privateuseone())
? static_cast<int32_t>(static_cast<unsigned char>(device.index()))
: 0;
switch (device.type()) {
case DeviceType::CPU:
ctx.device_type = DLDeviceType::kDLCPU;
break;
@ -120,8 +124,7 @@ static DLDevice getDLDevice(const Tensor& tensor, c10::DeviceIndex device_id) {
break;
case DeviceType::XPU:
ctx.device_type = DLDeviceType::kDLOneAPI;
ctx.device_id =
at::detail::getXPUHooks().getGlobalIdxFromDevice(tensor.device());
ctx.device_id = at::detail::getXPUHooks().getGlobalIdxFromDevice(device);
break;
case DeviceType::MAIA:
ctx.device_type = DLDeviceType::kDLMAIA;
@ -130,38 +133,40 @@ static DLDevice getDLDevice(const Tensor& tensor, c10::DeviceIndex device_id) {
ctx.device_type = DLDeviceType::kDLExtDev;
break;
default:
TORCH_CHECK(false, "Cannot pack tensors on " + tensor.device().str());
TORCH_CHECK(false, "Cannot pack tensors on " + device.str());
}
return ctx;
}
static Device getATenDevice(const DLDevice& ctx, void* data) {
switch (ctx.device_type) {
static Device getATenDevice(DLDeviceType type, c10::DeviceIndex index, void* data = nullptr) {
switch (type) {
case DLDeviceType::kDLCPU:
return at::Device(DeviceType::CPU);
#ifndef USE_ROCM
// if we are compiled under HIP, we cannot do cuda
case DLDeviceType::kDLCUDA:
return at::Device(DeviceType::CUDA, static_cast<c10::DeviceIndex>(ctx.device_id));
return at::Device(DeviceType::CUDA, index);
#endif
case DLDeviceType::kDLOpenCL:
return at::Device(DeviceType::OPENCL, static_cast<c10::DeviceIndex>(ctx.device_id));
return at::Device(DeviceType::OPENCL, index);
case DLDeviceType::kDLROCM:
#ifdef USE_ROCM
// this looks funny, we need to return CUDA here to masquerade
return at::Device(DeviceType::CUDA, static_cast<c10::DeviceIndex>(ctx.device_id));
return at::Device(DeviceType::CUDA, index);
#else
return at::Device(DeviceType::HIP, static_cast<c10::DeviceIndex>(ctx.device_id));
return at::Device(DeviceType::HIP, index);
#endif
case DLDeviceType::kDLOneAPI:
TORCH_CHECK(data != nullptr, "Can't get ATen device for XPU without XPU data.");
return at::detail::getXPUHooks().getDeviceFromPtr(data);
case DLDeviceType::kDLMAIA:
return at::Device(DeviceType::MAIA, static_cast<c10::DeviceIndex>(ctx.device_id));
return at::Device(DeviceType::MAIA, index);
case DLDeviceType::kDLExtDev:
return at::Device(DeviceType::PrivateUse1, static_cast<c10::DeviceIndex>(ctx.device_id));
return at::Device(DeviceType::PrivateUse1, index);
default:
TORCH_CHECK(
false, "Unsupported device_type: ", std::to_string(ctx.device_type));
false, "Unsupported device_type: ", std::to_string(type));
}
}
@ -314,11 +319,7 @@ T* toDLPackImpl(const Tensor& src) {
atDLMTensor->tensor.manager_ctx = atDLMTensor;
atDLMTensor->tensor.deleter = &deleter<T>;
atDLMTensor->tensor.dl_tensor.data = view.data_ptr();
c10::DeviceIndex device_id = 0;
if (src.is_cuda() || src.is_privateuseone()) {
device_id = src.get_device();
}
atDLMTensor->tensor.dl_tensor.device = getDLDevice(src, device_id);
atDLMTensor->tensor.dl_tensor.device = torchDeviceToDLDevice(src.device());
atDLMTensor->tensor.dl_tensor.ndim = static_cast<int32_t>(src.dim());
atDLMTensor->tensor.dl_tensor.dtype = getDLDataType(src);
atDLMTensor->tensor.dl_tensor.shape = view.sizes().data();
@ -346,7 +347,7 @@ at::Tensor fromDLPackImpl(T* src, std::function<void(void*)> deleter) {
}
DLTensor& dl_tensor = src->dl_tensor;
Device device = getATenDevice(dl_tensor.device, dl_tensor.data);
Device device = getATenDevice(dl_tensor.device.device_type, dl_tensor.device.device_id, dl_tensor.data);
ScalarType stype = toScalarType(dl_tensor.dtype);
if (!dl_tensor.strides) {
@ -388,4 +389,35 @@ Tensor fromDLPackVersioned(DLManagedTensorVersioned* src, std::function<void(voi
return fromDLPackImpl<DLManagedTensorVersioned>(src, std::move(deleter));
}
Tensor maybeCopyTensor(
const Tensor& data,
std::optional<DLDevice> optional_dl_device,
std::optional<bool> copy) {
bool force_copy = copy.has_value() && *copy;
bool force_move = copy.has_value() && !*copy;
if (optional_dl_device.has_value()) {
auto device = at::getATenDevice(
optional_dl_device->device_type,
static_cast<c10::DeviceIndex>(optional_dl_device->device_id));
if (device != data.device()) {
TORCH_CHECK_VALUE(
!force_move,
"cannot move (i.e. copy=False) tensor from ",
data.device(),
" to ",
device,
" without copying.");
return data.to(device);
}
}
if (force_copy) {
return data.clone();
}
return data;
}
} // namespace at

View File

@ -21,6 +21,16 @@ TORCH_API Tensor fromDLPackVersioned(
TORCH_API DLDataType getDLDataType(const Tensor& t);
TORCH_API DLDevice getDLContext(const Tensor& tensor, const int64_t& device_id);
// Copies the Tensor if there's a device mismatch or copy is forced.
// This should be used before actually creating the DLPack capsule.
TORCH_API Tensor maybeCopyTensor(
const Tensor& data,
std::optional<DLDevice> optional_dl_device,
std::optional<bool> copy);
// Converts the given at::Device into a DLDevice.
TORCH_API DLDevice torchDeviceToDLDevice(at::Device device);
// This trait class is used for retrieving different attributes, such as the
// PyCapsule names and conversion functions for both DLPack tensor classes:
// `DLManagedTensor` and `DLManagedTensorVersioned`.

View File

@ -410,6 +410,55 @@ class TestTorchDlPack(TestCase):
self.assertEqual(t, res)
self.assertEqual(t.data_ptr(), res.data_ptr())
def _test_from_dlpack(self, device, out_device=None, copy=None):
if isinstance(device, str):
device = torch.device(device)
inp = make_tensor((5,), dtype=torch.float32, device=device)
out = torch.from_dlpack(inp, device=out_device, copy=copy)
if out_device is None:
out_device = device
if isinstance(out_device, str):
out_device = torch.device(out_device)
self.assertEqual(inp, out)
self.assertEqual(out.device, out_device)
# They should be moved (i.e. not copied) only if:
# (a) we are forcing move, i.e. copy=False
# (b) the output device is the same as the input one AND copy is None
if copy is False or (copy is None and device == out_device):
self.assertEqual(inp.data_ptr(), out.data_ptr())
else:
# Otherwise, inp should be copied.
self.assertNotEqual(inp.data_ptr(), out.data_ptr())
@skipMeta
@onlyCUDA
def test_copy(self, device):
# Force-copy same device tensor.
self._test_from_dlpack(device, copy=True)
self._test_from_dlpack(device, out_device=device, copy=True)
# Output should be in a different device, i.e. should have been copied.
self._test_from_dlpack(device, out_device="cpu")
self._test_from_dlpack(device, out_device="cpu", copy=True)
@skipMeta
@onlyCUDA
def test_no_copy(self, device):
# No copy, since tensor lives in the same device.
self._test_from_dlpack(device)
self._test_from_dlpack(device, copy=False)
self._test_from_dlpack(device, out_device=device)
self._test_from_dlpack(device, out_device=device, copy=False)
@skipMeta
@onlyCUDA
def test_needs_copy_error(self, device):
with self.assertRaisesRegex(ValueError, r"cannot move .* tensor from .*"):
self._test_from_dlpack(device, out_device="cpu", copy=False)
instantiate_device_type_tests(TestTorchDlPack, globals())

View File

@ -1301,9 +1301,20 @@ def _initCrashHandler() -> None: ...
# NB: There is no Capsule type in typing, see
# https://github.com/python/cpython/issues/109562
def _to_dlpack(data: Tensor) -> Any: ... # THPModule_toDLPack
def _to_dlpack_versioned(data: Tensor) -> Any: ... # THPModule_toDLPackVersioned
def _to_dlpack(
data: Tensor,
dl_device: tuple[IntEnum, _int] | None = None,
copy: _bool | None = None,
) -> Any: ... # THPModule_toDLPack
def _to_dlpack_versioned(
data: Tensor,
dl_device: tuple[IntEnum, _int] | None = None,
copy: _bool | None = None,
) -> Any: ... # THPModule_toDLPackVersioned
def _from_dlpack(data: Any) -> Tensor: ... # THPModule_fromDLPack
def _torchDeviceToDLDevice(
device: torch.device,
) -> tuple[_int, _int]: ... # THPModule_torchDeviceToDLDevice
def _get_cpp_backtrace(
frames_to_skip: _int,
maximum_number_of_frames: _int,

View File

@ -1659,7 +1659,14 @@ class Tensor(torch._C.TensorBase):
__torch_dispatch__ = _C._disabled_torch_dispatch_impl
def __dlpack__(self, *, stream=None, max_version=None):
def __dlpack__(
self,
*,
stream: Optional[Any] = None,
max_version: Optional[tuple[int, int]] = None,
dl_device: Optional[tuple[enum.IntEnum, int]] = None,
copy: Optional[bool] = None,
):
"""
Creates a DLpack `capsule https://data-apis.org/array-api/latest/design_topics/data_interchange.html#data-interchange`_
of the current tensor to be exported to other libraries.
@ -1670,22 +1677,31 @@ class Tensor(torch._C.TensorBase):
Args:
stream (integer or None): An optional Python integer representing a
pointer to a CUDA stream. The current stream is synchronized with
this stream before the capsule is created, and since the capsule
shares its storage with the tensor this make it safe to access from
both streams. If None or -1 is passed then no synchronization is performed.
If 1 (on CUDA) or 0 (on ROCM) then the default stream is used for
synchronization.
pointer to a CUDA stream. The current stream is synchronized with
this stream before the capsule is created, and since the capsule
shares its storage with the tensor this make it safe to access from
both streams. If None or -1 is passed then no synchronization is performed.
If 1 (on CUDA) or 0 (on ROCM) then the default stream is used for
synchronization.
max_version (tuple[int, int] or None): An optional Python tuple with
2 integers, representing the maximum version the caller supports. If
None (default), PyTorch will fallback to DLPack 0.8.
2 integers, representing the maximum version the caller supports. If
None (default), PyTorch will fallback to DLPack 0.8.
dl_device (tuple[DLDeviceType, int] or None): An optional tuple specifying
in which device the exported DLPack capsule should be on. If None (default),
the exported DLPack capsule will be on the same device as ``self``.
copy (bool or None): An optional boolean indicating whether or not to copy
``self``. If None, PyTorch will copy only if necessary.
"""
if has_torch_function_unary(self):
args = (self,)
kwargs = {
"stream": stream,
"max_version": max_version,
"dl_device": dl_device,
"copy": copy,
}
return handle_torch_function(Tensor.__dlpack__, (self,), *args, **kwargs)
@ -1763,9 +1779,9 @@ class Tensor(torch._C.TensorBase):
if max_version is None or max_version[0] < 1:
# Fallback to the old, unversioned variant.
return torch.to_dlpack(self)
return _C._to_dlpack(self, dl_device=dl_device, copy=copy)
return _C._to_dlpack_versioned(self)
return _C._to_dlpack_versioned(self, dl_device=dl_device, copy=copy)
def __dlpack_device__(self) -> tuple[enum.IntEnum, int]:
if has_torch_function_unary(self):

View File

@ -607,25 +607,56 @@ void DLPack_Capsule_Destructor(PyObject* data) {
}
template <class T>
PyObject* THPModule_toDLPackImpl(PyObject* _unused, PyObject* data) {
PyObject* THPModule_toDLPackImpl(
PyObject* self,
PyObject* args,
PyObject* kwargs) {
HANDLE_TH_ERRORS
TORCH_CHECK(THPVariable_Check(data), "data must be a Tensor");
auto tensor = at::DLPackTraits<T>::toDLPack(THPVariable_Unpack(data));
static torch::PythonArgParser parser(
{"_to_dlpack(Tensor data, *, IntArrayRef? dl_device=None, bool? copy=None)"});
torch::ParsedArgs<3> parsed_args{};
auto r = parser.parse(args, kwargs, parsed_args);
TORCH_INTERNAL_ASSERT(r.idx == 0);
auto data = r.tensor(0);
auto dl_device = r.intlist(1);
auto copy = r.toBoolOptional(2);
// Parse the int list into a tuple.
std::optional<DLDevice> optional_dl_device;
if (!dl_device.empty()) {
TORCH_CHECK(
dl_device.size() == 2,
"dl_device must be either None or a tuple of ints");
optional_dl_device = DLDevice{
static_cast<DLDeviceType>(dl_device[0]),
static_cast<int32_t>(dl_device[1])};
}
auto tensor = at::DLPackTraits<T>::toDLPack(
at::maybeCopyTensor(data, optional_dl_device, copy));
return PyCapsule_New(
tensor, at::DLPackTraits<T>::capsule, DLPack_Capsule_Destructor<T>);
END_HANDLE_TH_ERRORS
}
} // namespace
static PyObject* THPModule_toDLPack(PyObject* _unused, PyObject* data) {
return THPModule_toDLPackImpl<DLManagedTensor>(_unused, data);
static PyObject* THPModule_toDLPack(
PyObject* self,
PyObject* args,
PyObject* kwargs) {
return THPModule_toDLPackImpl<DLManagedTensor>(self, args, kwargs);
}
static PyObject* THPModule_toDLPackVersioned(
PyObject* _unused,
PyObject* data) {
return THPModule_toDLPackImpl<DLManagedTensorVersioned>(_unused, data);
PyObject* self,
PyObject* args,
PyObject* kwargs) {
return THPModule_toDLPackImpl<DLManagedTensorVersioned>(self, args, kwargs);
}
static PyObject* THPModule_fromDLPack(PyObject* _unused, PyObject* data) {
@ -636,6 +667,28 @@ static PyObject* THPModule_fromDLPack(PyObject* _unused, PyObject* data) {
END_HANDLE_TH_ERRORS
}
static PyObject* THPModule_torchDeviceToDLDevice(
PyObject* _unused,
PyObject* data) {
HANDLE_TH_ERRORS
TORCH_CHECK(
THPDevice_Check(data),
"torchDeviceToDLDevice: expected torch.device argument.");
auto device = reinterpret_cast<THPDevice*>(data)->device;
auto dl_device = at::torchDeviceToDLDevice(device);
auto tuple = PyTuple_New(2);
if (!tuple) {
throw python_error();
}
PyTuple_SET_ITEM(tuple, 0, THPUtils_packInt64(dl_device.device_type));
PyTuple_SET_ITEM(tuple, 1, THPUtils_packInt64(dl_device.device_id));
return tuple;
END_HANDLE_TH_ERRORS
}
static PyObject* THModule_getCppBacktrace(PyObject* _unused, PyObject* args) {
HANDLE_TH_ERRORS
size_t frames_to_skip = 0;
@ -1687,9 +1740,19 @@ static std::initializer_list<PyMethodDef> TorchMethods = {
THPModule_are_vmap_fallback_warnings_enabled,
METH_NOARGS,
nullptr},
{"_to_dlpack", THPModule_toDLPack, METH_O, nullptr},
{"_to_dlpack_versioned", THPModule_toDLPackVersioned, METH_O, nullptr},
{"_to_dlpack",
castPyCFunctionWithKeywords(THPModule_toDLPack),
METH_VARARGS | METH_KEYWORDS,
nullptr},
{"_to_dlpack_versioned",
castPyCFunctionWithKeywords(THPModule_toDLPackVersioned),
METH_VARARGS | METH_KEYWORDS,
nullptr},
{"_from_dlpack", THPModule_fromDLPack, METH_O, nullptr},
{"_torchDeviceToDLDevice",
THPModule_torchDeviceToDLDevice,
METH_O,
nullptr},
{"_get_cpp_backtrace", THModule_getCppBacktrace, METH_VARARGS, nullptr},
{"_rename_privateuse1_backend",
THModule_rename_privateuse1_backend,

View File

@ -1512,7 +1512,7 @@ def get_testing_overrides() -> dict[Callable, Callable]:
Tensor.view: lambda self, shape: -1,
Tensor.view_as: lambda self, other: -1,
Tensor.zero_: lambda self: -1,
Tensor.__dlpack__: lambda self, stream=None, max_version=None: -1,
Tensor.__dlpack__: lambda self, stream=None, max_version=None, dl_device=None, copy=None: -1,
Tensor.__dlpack_device__: lambda self: -1,
torch.linalg.lstsq: lambda self, b, cond=None, driver=None: -1,
} # fmt: skip

View File

@ -1,9 +1,10 @@
from typing import Any
from typing import Any, Optional
import torch
import enum
from torch._C import _to_dlpack as to_dlpack
from torch.types import Device as _Device
__all__ = [
"DLDeviceType",
@ -54,7 +55,12 @@ The DLPack capsule shares the tensor's memory.
# TODO: add a typing.Protocol to be able to tell Mypy that only objects with
# __dlpack__ and __dlpack_device__ methods are accepted.
def from_dlpack(ext_tensor: Any) -> 'torch.Tensor':
def from_dlpack(
ext_tensor: Any,
*,
device: Optional[_Device] = None,
copy: Optional[bool] = None
) -> 'torch.Tensor':
"""from_dlpack(ext_tensor) -> Tensor
Converts a tensor from an external library into a ``torch.Tensor``.
@ -76,6 +82,13 @@ def from_dlpack(ext_tensor: Any) -> 'torch.Tensor':
an opaque ``PyCapsule`` instance, typically produced by a
``to_dlpack`` function or method.
device (torch.device or str or None): An optional PyTorch device
specifying where to place the new tensor. If None (default), the
new tensor will be on the same device as ``ext_tensor``.
copy (bool or None): An optional boolean indicating whether or not to copy
``self``. If None, PyTorch will copy only if necessary.
Examples::
>>> import torch.utils.dlpack
@ -106,20 +119,36 @@ def from_dlpack(ext_tensor: Any) -> 'torch.Tensor':
"""
if hasattr(ext_tensor, '__dlpack__'):
# Only populate kwargs if any of the optional arguments are, in fact, not None. Otherwise,
# leave them out, since we might end up falling back to no-extra-kwargs __dlpack__ call.
kwargs: dict[str, Any] = {}
kwargs["max_version"] = (1, 0)
device = ext_tensor.__dlpack_device__()
# device is either CUDA or ROCm, we need to pass the current
if copy is not None:
kwargs["copy"] = copy
# Parse the device parameter.
# At this moment, it can either be a torch.device or a str representing
# a torch.device, e.g. "cpu", "cuda", etc.
if device is not None:
if isinstance(device, str):
device = torch.device(device)
assert isinstance(device, torch.device), (
f"from_dlpack: unsupported device type: {type(device)}"
)
kwargs["dl_device"] = torch._C._torchDeviceToDLDevice(device)
ext_device = ext_tensor.__dlpack_device__()
# ext_device is either CUDA or ROCm, we need to pass the current
# stream
if device[0] in (DLDeviceType.kDLCUDA, DLDeviceType.kDLROCM):
stream = torch.cuda.current_stream(f'cuda:{device[1]}')
if ext_device[0] in (DLDeviceType.kDLCUDA, DLDeviceType.kDLROCM):
stream = torch.cuda.current_stream(f'cuda:{ext_device[1]}')
# cuda_stream is the pointer to the stream and it is a public
# attribute, but it is not documented
# The array API specify that the default legacy stream must be passed
# with a value of 1 for CUDA
# https://data-apis.org/array-api/latest/API_specification/array_object.html?dlpack-self-stream-none#dlpack-self-stream-none
is_cuda = device[0] == DLDeviceType.kDLCUDA
is_cuda = ext_device[0] == DLDeviceType.kDLCUDA
# Since pytorch is not using PTDS by default, lets directly pass
# the legacy stream
stream_ptr = 1 if is_cuda and stream.cuda_stream == 0 else stream.cuda_stream
@ -134,6 +163,10 @@ def from_dlpack(ext_tensor: Any) -> 'torch.Tensor':
dlpack = ext_tensor.__dlpack__(**kwargs)
else:
assert device is None and copy is None, (
"device and copy kwargs not supported when ext_tensor is "
"already a DLPack capsule."
)
# Old versions just call the converter
dlpack = ext_tensor
return torch._C._from_dlpack(dlpack)