mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
[DLPack] add NumPy exchange tests. (#150216)
This PR resolves an old TODO that requested NumPy DLPack exchange tests once version 1.22 was required. Pull Request resolved: https://github.com/pytorch/pytorch/pull/150216 Approved by: https://github.com/msaroufim, https://github.com/albanD
This commit is contained in:
committed by
PyTorch MergeBot
parent
a1cfe7f1df
commit
b64f338da4
@ -5,6 +5,7 @@ from torch.testing import make_tensor
|
||||
from torch.testing._internal.common_device_type import (
|
||||
dtypes,
|
||||
instantiate_device_type_tests,
|
||||
onlyCPU,
|
||||
onlyCUDA,
|
||||
onlyNativeDeviceTypes,
|
||||
skipCUDAIfRocm,
|
||||
@ -317,6 +318,40 @@ class TestTorchDlPack(TestCase):
|
||||
# Consumer should still be able to process a smaller version capsule.
|
||||
test(device, max_version=(2, 0))
|
||||
|
||||
@skipMeta
|
||||
@onlyCPU
|
||||
@dtypes(
|
||||
# Note: NumPy DLPack bool support only landed in 1.25.
|
||||
*all_types_and_complex_and(
|
||||
torch.half,
|
||||
torch.uint16,
|
||||
torch.uint32,
|
||||
torch.uint64,
|
||||
)
|
||||
)
|
||||
def test_numpy_dlpack_protocol_conversion(self, device, dtype):
|
||||
import numpy as np
|
||||
|
||||
t = make_tensor((5,), dtype=dtype, device=device)
|
||||
|
||||
if hasattr(np, "from_dlpack"):
|
||||
# DLPack support only available from NumPy 1.22 onwards.
|
||||
# Here, we test having another framework (NumPy) calling our
|
||||
# Tensor.__dlpack__ implementation.
|
||||
arr = np.from_dlpack(t)
|
||||
self.assertEqual(t, arr)
|
||||
|
||||
# We can't use the array created above as input to from_dlpack.
|
||||
# That's because DLPack imported NumPy arrays are read-only.
|
||||
# Thus, we need to convert it to NumPy by using the numpy() method.
|
||||
t_arr = t.numpy()
|
||||
|
||||
# Transform the NumPy array back using DLPack.
|
||||
res = from_dlpack(t_arr)
|
||||
|
||||
self.assertEqual(t, res)
|
||||
self.assertEqual(t.data_ptr(), res.data_ptr())
|
||||
|
||||
|
||||
instantiate_device_type_tests(TestTorchDlPack, globals())
|
||||
|
||||
|
||||
Reference in New Issue
Block a user