mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
[DLPack] Add support for missing keyword-arguments. (#150218)
This PR introduces the rest of the keyword-arguments added in DLPack version 2023.12: `dl_device` and `copy`.

In summary, we handle these arguments in the C++ implementation of `to_dlpack(...)` at _torch/csrc/Module.cpp_, by calling the `maybeCopyTensor` function at _aten/src/ATen/DLConvertor.cpp_.

It also introduces the following changes:

- Add a new Python API `torchDeviceToDLDevice()`, which is simply a refactoring of the `getDLDevice()` function at _aten/src/ATen/DLConvertor.cpp_.
- Add both keyword-arguments to the `from_dlpack()` function at _torch/utils/dlpack.py_ and to the `Tensor.__dlpack__()` dunder method.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/150218
Approved by: https://github.com/albanD
ghstack dependencies: #150216, #150217
This commit is contained in:
committed by
PyTorch MergeBot
parent
1d526fe78f
commit
a10f15718d
@ -96,10 +96,14 @@ DLDataType getDLDataType(const Tensor& t) {
|
||||
return dtype;
|
||||
}
|
||||
|
||||
static DLDevice getDLDevice(const Tensor& tensor, c10::DeviceIndex device_id) {
|
||||
DLDevice torchDeviceToDLDevice(at::Device device) {
|
||||
DLDevice ctx;
|
||||
ctx.device_id = static_cast<int32_t>(static_cast<unsigned char>(device_id));
|
||||
switch (tensor.device().type()) {
|
||||
|
||||
ctx.device_id = (device.is_cuda() || device.is_privateuseone())
|
||||
? static_cast<int32_t>(static_cast<unsigned char>(device.index()))
|
||||
: 0;
|
||||
|
||||
switch (device.type()) {
|
||||
case DeviceType::CPU:
|
||||
ctx.device_type = DLDeviceType::kDLCPU;
|
||||
break;
|
||||
@ -120,8 +124,7 @@ static DLDevice getDLDevice(const Tensor& tensor, c10::DeviceIndex device_id) {
|
||||
break;
|
||||
case DeviceType::XPU:
|
||||
ctx.device_type = DLDeviceType::kDLOneAPI;
|
||||
ctx.device_id =
|
||||
at::detail::getXPUHooks().getGlobalIdxFromDevice(tensor.device());
|
||||
ctx.device_id = at::detail::getXPUHooks().getGlobalIdxFromDevice(device);
|
||||
break;
|
||||
case DeviceType::MAIA:
|
||||
ctx.device_type = DLDeviceType::kDLMAIA;
|
||||
@ -130,38 +133,40 @@ static DLDevice getDLDevice(const Tensor& tensor, c10::DeviceIndex device_id) {
|
||||
ctx.device_type = DLDeviceType::kDLExtDev;
|
||||
break;
|
||||
default:
|
||||
TORCH_CHECK(false, "Cannot pack tensors on " + tensor.device().str());
|
||||
TORCH_CHECK(false, "Cannot pack tensors on " + device.str());
|
||||
}
|
||||
|
||||
return ctx;
|
||||
}
|
||||
|
||||
static Device getATenDevice(const DLDevice& ctx, void* data) {
|
||||
switch (ctx.device_type) {
|
||||
static Device getATenDevice(DLDeviceType type, c10::DeviceIndex index, void* data = nullptr) {
|
||||
switch (type) {
|
||||
case DLDeviceType::kDLCPU:
|
||||
return at::Device(DeviceType::CPU);
|
||||
#ifndef USE_ROCM
|
||||
// if we are compiled under HIP, we cannot do cuda
|
||||
case DLDeviceType::kDLCUDA:
|
||||
return at::Device(DeviceType::CUDA, static_cast<c10::DeviceIndex>(ctx.device_id));
|
||||
return at::Device(DeviceType::CUDA, index);
|
||||
#endif
|
||||
case DLDeviceType::kDLOpenCL:
|
||||
return at::Device(DeviceType::OPENCL, static_cast<c10::DeviceIndex>(ctx.device_id));
|
||||
return at::Device(DeviceType::OPENCL, index);
|
||||
case DLDeviceType::kDLROCM:
|
||||
#ifdef USE_ROCM
|
||||
// this looks funny, we need to return CUDA here to masquerade
|
||||
return at::Device(DeviceType::CUDA, static_cast<c10::DeviceIndex>(ctx.device_id));
|
||||
return at::Device(DeviceType::CUDA, index);
|
||||
#else
|
||||
return at::Device(DeviceType::HIP, static_cast<c10::DeviceIndex>(ctx.device_id));
|
||||
return at::Device(DeviceType::HIP, index);
|
||||
#endif
|
||||
case DLDeviceType::kDLOneAPI:
|
||||
TORCH_CHECK(data != nullptr, "Can't get ATen device for XPU without XPU data.");
|
||||
return at::detail::getXPUHooks().getDeviceFromPtr(data);
|
||||
case DLDeviceType::kDLMAIA:
|
||||
return at::Device(DeviceType::MAIA, static_cast<c10::DeviceIndex>(ctx.device_id));
|
||||
return at::Device(DeviceType::MAIA, index);
|
||||
case DLDeviceType::kDLExtDev:
|
||||
return at::Device(DeviceType::PrivateUse1, static_cast<c10::DeviceIndex>(ctx.device_id));
|
||||
return at::Device(DeviceType::PrivateUse1, index);
|
||||
default:
|
||||
TORCH_CHECK(
|
||||
false, "Unsupported device_type: ", std::to_string(ctx.device_type));
|
||||
false, "Unsupported device_type: ", std::to_string(type));
|
||||
}
|
||||
}
|
||||
|
||||
@ -314,11 +319,7 @@ T* toDLPackImpl(const Tensor& src) {
|
||||
atDLMTensor->tensor.manager_ctx = atDLMTensor;
|
||||
atDLMTensor->tensor.deleter = &deleter<T>;
|
||||
atDLMTensor->tensor.dl_tensor.data = view.data_ptr();
|
||||
c10::DeviceIndex device_id = 0;
|
||||
if (src.is_cuda() || src.is_privateuseone()) {
|
||||
device_id = src.get_device();
|
||||
}
|
||||
atDLMTensor->tensor.dl_tensor.device = getDLDevice(src, device_id);
|
||||
atDLMTensor->tensor.dl_tensor.device = torchDeviceToDLDevice(src.device());
|
||||
atDLMTensor->tensor.dl_tensor.ndim = static_cast<int32_t>(src.dim());
|
||||
atDLMTensor->tensor.dl_tensor.dtype = getDLDataType(src);
|
||||
atDLMTensor->tensor.dl_tensor.shape = view.sizes().data();
|
||||
@ -346,7 +347,7 @@ at::Tensor fromDLPackImpl(T* src, std::function<void(void*)> deleter) {
|
||||
}
|
||||
|
||||
DLTensor& dl_tensor = src->dl_tensor;
|
||||
Device device = getATenDevice(dl_tensor.device, dl_tensor.data);
|
||||
Device device = getATenDevice(dl_tensor.device.device_type, dl_tensor.device.device_id, dl_tensor.data);
|
||||
ScalarType stype = toScalarType(dl_tensor.dtype);
|
||||
|
||||
if (!dl_tensor.strides) {
|
||||
@ -388,4 +389,35 @@ Tensor fromDLPackVersioned(DLManagedTensorVersioned* src, std::function<void(voi
|
||||
return fromDLPackImpl<DLManagedTensorVersioned>(src, std::move(deleter));
|
||||
}
|
||||
|
||||
// Decides whether `data` has to be duplicated before it is exported through
// DLPack, based on the requested target device and the tri-state `copy` flag:
//   - a target device is given and differs from `data.device()`: the tensor
//     is transferred with `data.to(...)`; this is rejected with a value error
//     when the caller explicitly passed copy=false;
//   - otherwise, copy=true: a clone of `data` is returned;
//   - otherwise: `data` is returned as-is.
Tensor maybeCopyTensor(
    const Tensor& data,
    std::optional<DLDevice> optional_dl_device,
    std::optional<bool> copy) {
  // `copy` unset means "copy only if necessary": both flags stay false.
  const bool copy_requested = copy.value_or(false);
  const bool copy_forbidden = !copy.value_or(true);

  if (optional_dl_device.has_value()) {
    const DLDevice& requested = *optional_dl_device;
    const auto target = at::getATenDevice(
        requested.device_type,
        static_cast<c10::DeviceIndex>(requested.device_id));

    if (target != data.device()) {
      // Crossing devices always needs a copy, so copy=false cannot be honored.
      TORCH_CHECK_VALUE(
          !copy_forbidden,
          "cannot move (i.e. copy=False) tensor from ",
          data.device(),
          " to ",
          target,
          " without copying.");
      return data.to(target);
    }
  }

  if (copy_requested) {
    return data.clone();
  }
  return data;
}
|
||||
|
||||
} // namespace at
|
||||
|
@ -21,6 +21,16 @@ TORCH_API Tensor fromDLPackVersioned(
|
||||
TORCH_API DLDataType getDLDataType(const Tensor& t);
|
||||
TORCH_API DLDevice getDLContext(const Tensor& tensor, const int64_t& device_id);
|
||||
|
||||
// Copies the Tensor if there's a device mismatch or copy is forced.
|
||||
// This should be used before actually creating the DLPack capsule.
|
||||
TORCH_API Tensor maybeCopyTensor(
|
||||
const Tensor& data,
|
||||
std::optional<DLDevice> optional_dl_device,
|
||||
std::optional<bool> copy);
|
||||
|
||||
// Converts the given at::Device into a DLDevice.
|
||||
TORCH_API DLDevice torchDeviceToDLDevice(at::Device device);
|
||||
|
||||
// This trait class is used for retrieving different attributes, such as the
|
||||
// PyCapsule names and conversion functions for both DLPack tensor classes:
|
||||
// `DLManagedTensor` and `DLManagedTensorVersioned`.
|
||||
|
@ -410,6 +410,55 @@ class TestTorchDlPack(TestCase):
|
||||
self.assertEqual(t, res)
|
||||
self.assertEqual(t.data_ptr(), res.data_ptr())
|
||||
|
||||
def _test_from_dlpack(self, device, out_device=None, copy=None):
    """Round-trip a small float32 tensor through ``torch.from_dlpack``.

    Creates the input on ``device``, imports it with the given ``device`` and
    ``copy`` keyword arguments, and checks the values, the resulting device
    and whether the output aliases the input's memory.
    """
    src_device = torch.device(device) if isinstance(device, str) else device

    inp = make_tensor((5,), dtype=torch.float32, device=src_device)
    out = torch.from_dlpack(inp, device=out_device, copy=copy)

    # With no explicit output device, the result must stay where the input is.
    expected_device = src_device if out_device is None else out_device
    if isinstance(expected_device, str):
        expected_device = torch.device(expected_device)

    self.assertEqual(inp, out)
    self.assertEqual(out.device, expected_device)

    # They should be moved (i.e. not copied) only if:
    #   (a) we are forcing move, i.e. copy=False
    #   (b) the output device is the same as the input one AND copy is None
    same_device = src_device == expected_device
    expect_alias = copy is False or (copy is None and same_device)

    if expect_alias:
        self.assertEqual(inp.data_ptr(), out.data_ptr())
    else:
        # Otherwise, inp should be copied.
        self.assertNotEqual(inp.data_ptr(), out.data_ptr())
|
||||
|
||||
@skipMeta
@onlyCUDA
def test_copy(self, device):
    """Cases in which ``from_dlpack`` is expected to produce a copy."""
    cases = (
        # Force-copy same device tensor.
        {"copy": True},
        {"out_device": device, "copy": True},
        # Output should be in a different device, i.e. should have been copied.
        {"out_device": "cpu"},
        {"out_device": "cpu", "copy": True},
    )
    for kwargs in cases:
        self._test_from_dlpack(device, **kwargs)
|
||||
|
||||
@skipMeta
@onlyCUDA
def test_no_copy(self, device):
    """Cases in which ``from_dlpack`` is expected to alias the input."""
    # No copy, since tensor lives in the same device.
    cases = (
        {},
        {"copy": False},
        {"out_device": device},
        {"out_device": device, "copy": False},
    )
    for kwargs in cases:
        self._test_from_dlpack(device, **kwargs)
|
||||
|
||||
@skipMeta
@onlyCUDA
def test_needs_copy_error(self, device):
    """Requesting another device together with copy=False must raise ValueError."""
    expected_message = r"cannot move .* tensor from .*"
    with self.assertRaisesRegex(ValueError, expected_message):
        self._test_from_dlpack(device, out_device="cpu", copy=False)
|
||||
|
||||
|
||||
instantiate_device_type_tests(TestTorchDlPack, globals())
|
||||
|
||||
|
@ -1301,9 +1301,20 @@ def _initCrashHandler() -> None: ...
|
||||
|
||||
# NB: There is no Capsule type in typing, see
|
||||
# https://github.com/python/cpython/issues/109562
|
||||
def _to_dlpack(data: Tensor) -> Any: ... # THPModule_toDLPack
|
||||
def _to_dlpack_versioned(data: Tensor) -> Any: ... # THPModule_toDLPackVersioned
|
||||
def _to_dlpack(
|
||||
data: Tensor,
|
||||
dl_device: tuple[IntEnum, _int] | None = None,
|
||||
copy: _bool | None = None,
|
||||
) -> Any: ... # THPModule_toDLPack
|
||||
def _to_dlpack_versioned(
|
||||
data: Tensor,
|
||||
dl_device: tuple[IntEnum, _int] | None = None,
|
||||
copy: _bool | None = None,
|
||||
) -> Any: ... # THPModule_toDLPackVersioned
|
||||
def _from_dlpack(data: Any) -> Tensor: ... # THPModule_fromDLPack
|
||||
def _torchDeviceToDLDevice(
|
||||
device: torch.device,
|
||||
) -> tuple[_int, _int]: ... # THPModule_torchDeviceToDLDevice
|
||||
def _get_cpp_backtrace(
|
||||
frames_to_skip: _int,
|
||||
maximum_number_of_frames: _int,
|
||||
|
@ -1659,7 +1659,14 @@ class Tensor(torch._C.TensorBase):
|
||||
|
||||
__torch_dispatch__ = _C._disabled_torch_dispatch_impl
|
||||
|
||||
def __dlpack__(self, *, stream=None, max_version=None):
|
||||
def __dlpack__(
|
||||
self,
|
||||
*,
|
||||
stream: Optional[Any] = None,
|
||||
max_version: Optional[tuple[int, int]] = None,
|
||||
dl_device: Optional[tuple[enum.IntEnum, int]] = None,
|
||||
copy: Optional[bool] = None,
|
||||
):
|
||||
"""
|
||||
Creates a DLpack `capsule <https://data-apis.org/array-api/latest/design_topics/data_interchange.html#data-interchange>`_
|
||||
of the current tensor to be exported to other libraries.
|
||||
@ -1670,22 +1677,31 @@ class Tensor(torch._C.TensorBase):
|
||||
|
||||
Args:
|
||||
stream (integer or None): An optional Python integer representing a
|
||||
pointer to a CUDA stream. The current stream is synchronized with
|
||||
this stream before the capsule is created, and since the capsule
|
||||
shares its storage with the tensor this make it safe to access from
|
||||
both streams. If None or -1 is passed then no synchronization is performed.
|
||||
If 1 (on CUDA) or 0 (on ROCM) then the default stream is used for
|
||||
synchronization.
|
||||
pointer to a CUDA stream. The current stream is synchronized with
|
||||
this stream before the capsule is created, and since the capsule
|
||||
shares its storage with the tensor this makes it safe to access from
|
||||
both streams. If None or -1 is passed then no synchronization is performed.
|
||||
If 1 (on CUDA) or 0 (on ROCM) then the default stream is used for
|
||||
synchronization.
|
||||
|
||||
max_version (tuple[int, int] or None): An optional Python tuple with
|
||||
2 integers, representing the maximum version the caller supports. If
|
||||
None (default), PyTorch will fallback to DLPack 0.8.
|
||||
2 integers, representing the maximum version the caller supports. If
|
||||
None (default), PyTorch will fallback to DLPack 0.8.
|
||||
|
||||
dl_device (tuple[DLDeviceType, int] or None): An optional tuple specifying
|
||||
in which device the exported DLPack capsule should be on. If None (default),
|
||||
the exported DLPack capsule will be on the same device as ``self``.
|
||||
|
||||
copy (bool or None): An optional boolean indicating whether or not to copy
|
||||
``self``. If None, PyTorch will copy only if necessary.
|
||||
"""
|
||||
if has_torch_function_unary(self):
|
||||
args = (self,)
|
||||
kwargs = {
|
||||
"stream": stream,
|
||||
"max_version": max_version,
|
||||
"dl_device": dl_device,
|
||||
"copy": copy,
|
||||
}
|
||||
return handle_torch_function(Tensor.__dlpack__, (self,), *args, **kwargs)
|
||||
|
||||
@ -1763,9 +1779,9 @@ class Tensor(torch._C.TensorBase):
|
||||
|
||||
if max_version is None or max_version[0] < 1:
|
||||
# Fallback to the old, unversioned variant.
|
||||
return torch.to_dlpack(self)
|
||||
return _C._to_dlpack(self, dl_device=dl_device, copy=copy)
|
||||
|
||||
return _C._to_dlpack_versioned(self)
|
||||
return _C._to_dlpack_versioned(self, dl_device=dl_device, copy=copy)
|
||||
|
||||
def __dlpack_device__(self) -> tuple[enum.IntEnum, int]:
|
||||
if has_torch_function_unary(self):
|
||||
|
@ -607,25 +607,56 @@ void DLPack_Capsule_Destructor(PyObject* data) {
|
||||
}
|
||||
|
||||
template <class T>
|
||||
PyObject* THPModule_toDLPackImpl(PyObject* _unused, PyObject* data) {
|
||||
PyObject* THPModule_toDLPackImpl(
|
||||
PyObject* self,
|
||||
PyObject* args,
|
||||
PyObject* kwargs) {
|
||||
HANDLE_TH_ERRORS
|
||||
TORCH_CHECK(THPVariable_Check(data), "data must be a Tensor");
|
||||
auto tensor = at::DLPackTraits<T>::toDLPack(THPVariable_Unpack(data));
|
||||
static torch::PythonArgParser parser(
|
||||
{"_to_dlpack(Tensor data, *, IntArrayRef? dl_device=None, bool? copy=None)"});
|
||||
torch::ParsedArgs<3> parsed_args{};
|
||||
auto r = parser.parse(args, kwargs, parsed_args);
|
||||
|
||||
TORCH_INTERNAL_ASSERT(r.idx == 0);
|
||||
|
||||
auto data = r.tensor(0);
|
||||
auto dl_device = r.intlist(1);
|
||||
auto copy = r.toBoolOptional(2);
|
||||
|
||||
// Parse the int list into a tuple.
|
||||
std::optional<DLDevice> optional_dl_device;
|
||||
|
||||
if (!dl_device.empty()) {
|
||||
TORCH_CHECK(
|
||||
dl_device.size() == 2,
|
||||
"dl_device must be either None or a tuple of ints");
|
||||
optional_dl_device = DLDevice{
|
||||
static_cast<DLDeviceType>(dl_device[0]),
|
||||
static_cast<int32_t>(dl_device[1])};
|
||||
}
|
||||
|
||||
auto tensor = at::DLPackTraits<T>::toDLPack(
|
||||
at::maybeCopyTensor(data, optional_dl_device, copy));
|
||||
return PyCapsule_New(
|
||||
tensor, at::DLPackTraits<T>::capsule, DLPack_Capsule_Destructor<T>);
|
||||
|
||||
END_HANDLE_TH_ERRORS
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
||||
static PyObject* THPModule_toDLPack(PyObject* _unused, PyObject* data) {
|
||||
return THPModule_toDLPackImpl<DLManagedTensor>(_unused, data);
|
||||
static PyObject* THPModule_toDLPack(
|
||||
PyObject* self,
|
||||
PyObject* args,
|
||||
PyObject* kwargs) {
|
||||
return THPModule_toDLPackImpl<DLManagedTensor>(self, args, kwargs);
|
||||
}
|
||||
|
||||
static PyObject* THPModule_toDLPackVersioned(
|
||||
PyObject* _unused,
|
||||
PyObject* data) {
|
||||
return THPModule_toDLPackImpl<DLManagedTensorVersioned>(_unused, data);
|
||||
PyObject* self,
|
||||
PyObject* args,
|
||||
PyObject* kwargs) {
|
||||
return THPModule_toDLPackImpl<DLManagedTensorVersioned>(self, args, kwargs);
|
||||
}
|
||||
|
||||
static PyObject* THPModule_fromDLPack(PyObject* _unused, PyObject* data) {
|
||||
@ -636,6 +667,28 @@ static PyObject* THPModule_fromDLPack(PyObject* _unused, PyObject* data) {
|
||||
END_HANDLE_TH_ERRORS
|
||||
}
|
||||
|
||||
// Python binding: converts a `torch.device` into its DLPack device
// description and returns it as a 2-tuple of Python ints,
// `(device_type, device_id)`.
//
// `data` must be a `torch.device` object; anything else fails the
// TORCH_CHECK below. Errors are translated into Python exceptions by the
// HANDLE_TH_ERRORS / END_HANDLE_TH_ERRORS pair.
static PyObject* THPModule_torchDeviceToDLDevice(
    PyObject* _unused,
    PyObject* data) {
  HANDLE_TH_ERRORS
  TORCH_CHECK(
      THPDevice_Check(data),
      "torchDeviceToDLDevice: expected torch.device argument.");
  const auto device = reinterpret_cast<THPDevice*>(data)->device;
  const auto dl_device = at::torchDeviceToDLDevice(device);

  // Build the tuple and both of its items in one call. Allocating the tuple
  // first and then filling it with separately created ints would leak the
  // tuple if creating one of the items failed.
  PyObject* tuple = Py_BuildValue(
      "(ii)",
      static_cast<int>(dl_device.device_type),
      static_cast<int>(dl_device.device_id));
  if (!tuple) {
    throw python_error();
  }

  return tuple;
  END_HANDLE_TH_ERRORS
}
|
||||
|
||||
static PyObject* THModule_getCppBacktrace(PyObject* _unused, PyObject* args) {
|
||||
HANDLE_TH_ERRORS
|
||||
size_t frames_to_skip = 0;
|
||||
@ -1687,9 +1740,19 @@ static std::initializer_list<PyMethodDef> TorchMethods = {
|
||||
THPModule_are_vmap_fallback_warnings_enabled,
|
||||
METH_NOARGS,
|
||||
nullptr},
|
||||
{"_to_dlpack", THPModule_toDLPack, METH_O, nullptr},
|
||||
{"_to_dlpack_versioned", THPModule_toDLPackVersioned, METH_O, nullptr},
|
||||
{"_to_dlpack",
|
||||
castPyCFunctionWithKeywords(THPModule_toDLPack),
|
||||
METH_VARARGS | METH_KEYWORDS,
|
||||
nullptr},
|
||||
{"_to_dlpack_versioned",
|
||||
castPyCFunctionWithKeywords(THPModule_toDLPackVersioned),
|
||||
METH_VARARGS | METH_KEYWORDS,
|
||||
nullptr},
|
||||
{"_from_dlpack", THPModule_fromDLPack, METH_O, nullptr},
|
||||
{"_torchDeviceToDLDevice",
|
||||
THPModule_torchDeviceToDLDevice,
|
||||
METH_O,
|
||||
nullptr},
|
||||
{"_get_cpp_backtrace", THModule_getCppBacktrace, METH_VARARGS, nullptr},
|
||||
{"_rename_privateuse1_backend",
|
||||
THModule_rename_privateuse1_backend,
|
||||
|
@ -1512,7 +1512,7 @@ def get_testing_overrides() -> dict[Callable, Callable]:
|
||||
Tensor.view: lambda self, shape: -1,
|
||||
Tensor.view_as: lambda self, other: -1,
|
||||
Tensor.zero_: lambda self: -1,
|
||||
Tensor.__dlpack__: lambda self, stream=None, max_version=None: -1,
|
||||
Tensor.__dlpack__: lambda self, stream=None, max_version=None, dl_device=None, copy=None: -1,
|
||||
Tensor.__dlpack_device__: lambda self: -1,
|
||||
torch.linalg.lstsq: lambda self, b, cond=None, driver=None: -1,
|
||||
} # fmt: skip
|
||||
|
@ -1,9 +1,10 @@
|
||||
from typing import Any
|
||||
from typing import Any, Optional
|
||||
|
||||
import torch
|
||||
import enum
|
||||
|
||||
from torch._C import _to_dlpack as to_dlpack
|
||||
from torch.types import Device as _Device
|
||||
|
||||
__all__ = [
|
||||
"DLDeviceType",
|
||||
@ -54,7 +55,12 @@ The DLPack capsule shares the tensor's memory.
|
||||
|
||||
# TODO: add a typing.Protocol to be able to tell Mypy that only objects with
|
||||
# __dlpack__ and __dlpack_device__ methods are accepted.
|
||||
def from_dlpack(ext_tensor: Any) -> 'torch.Tensor':
|
||||
def from_dlpack(
|
||||
ext_tensor: Any,
|
||||
*,
|
||||
device: Optional[_Device] = None,
|
||||
copy: Optional[bool] = None
|
||||
) -> 'torch.Tensor':
|
||||
"""from_dlpack(ext_tensor) -> Tensor
|
||||
|
||||
Converts a tensor from an external library into a ``torch.Tensor``.
|
||||
@ -76,6 +82,13 @@ def from_dlpack(ext_tensor: Any) -> 'torch.Tensor':
|
||||
an opaque ``PyCapsule`` instance, typically produced by a
|
||||
``to_dlpack`` function or method.
|
||||
|
||||
device (torch.device or str or None): An optional PyTorch device
|
||||
specifying where to place the new tensor. If None (default), the
|
||||
new tensor will be on the same device as ``ext_tensor``.
|
||||
|
||||
copy (bool or None): An optional boolean indicating whether or not to copy
|
||||
``ext_tensor``. If None, PyTorch will copy only if necessary.
|
||||
|
||||
Examples::
|
||||
|
||||
>>> import torch.utils.dlpack
|
||||
@ -106,20 +119,36 @@ def from_dlpack(ext_tensor: Any) -> 'torch.Tensor':
|
||||
|
||||
"""
|
||||
if hasattr(ext_tensor, '__dlpack__'):
|
||||
# Only populate kwargs if any of the optional arguments are, in fact, not None. Otherwise,
|
||||
# leave them out, since we might end up falling back to no-extra-kwargs __dlpack__ call.
|
||||
kwargs: dict[str, Any] = {}
|
||||
kwargs["max_version"] = (1, 0)
|
||||
|
||||
device = ext_tensor.__dlpack_device__()
|
||||
# device is either CUDA or ROCm, we need to pass the current
|
||||
if copy is not None:
|
||||
kwargs["copy"] = copy
|
||||
|
||||
# Parse the device parameter.
|
||||
# At this moment, it can either be a torch.device or a str representing
|
||||
# a torch.device, e.g. "cpu", "cuda", etc.
|
||||
if device is not None:
|
||||
if isinstance(device, str):
|
||||
device = torch.device(device)
|
||||
assert isinstance(device, torch.device), (
|
||||
f"from_dlpack: unsupported device type: {type(device)}"
|
||||
)
|
||||
kwargs["dl_device"] = torch._C._torchDeviceToDLDevice(device)
|
||||
|
||||
ext_device = ext_tensor.__dlpack_device__()
|
||||
# ext_device is either CUDA or ROCm, we need to pass the current
|
||||
# stream
|
||||
if device[0] in (DLDeviceType.kDLCUDA, DLDeviceType.kDLROCM):
|
||||
stream = torch.cuda.current_stream(f'cuda:{device[1]}')
|
||||
if ext_device[0] in (DLDeviceType.kDLCUDA, DLDeviceType.kDLROCM):
|
||||
stream = torch.cuda.current_stream(f'cuda:{ext_device[1]}')
|
||||
# cuda_stream is the pointer to the stream and it is a public
|
||||
# attribute, but it is not documented
|
||||
# The array API specify that the default legacy stream must be passed
|
||||
# with a value of 1 for CUDA
|
||||
# https://data-apis.org/array-api/latest/API_specification/array_object.html?dlpack-self-stream-none#dlpack-self-stream-none
|
||||
is_cuda = device[0] == DLDeviceType.kDLCUDA
|
||||
is_cuda = ext_device[0] == DLDeviceType.kDLCUDA
|
||||
# Since pytorch is not using PTDS by default, lets directly pass
|
||||
# the legacy stream
|
||||
stream_ptr = 1 if is_cuda and stream.cuda_stream == 0 else stream.cuda_stream
|
||||
@ -134,6 +163,10 @@ def from_dlpack(ext_tensor: Any) -> 'torch.Tensor':
|
||||
dlpack = ext_tensor.__dlpack__(**kwargs)
|
||||
|
||||
else:
|
||||
assert device is None and copy is None, (
|
||||
"device and copy kwargs not supported when ext_tensor is "
|
||||
"already a DLPack capsule."
|
||||
)
|
||||
# Old versions just call the converter
|
||||
dlpack = ext_tensor
|
||||
return torch._C._from_dlpack(dlpack)
|
||||
|
Reference in New Issue
Block a user