Upgrade to DLPack 1.0. (#145000)

This PR makes the necessary changes in order to upgrade PyTorch DLPack
support to version 1.0. In summary, we add support for the following:

- Support both `DLManagedTensor` and `DLManagedTensorVersioned` when
  producing and consuming DLPack capsules
- New parameter for `__dlpack__` method: `max_version`
- Version checks:
    - Fallback to old implementation if no `max_version` or if version
      lower than 1.0
    - Check that the to-be-consumed capsule is of version up to 1.X

In order to accommodate these new specifications, this PR adds the
following main changes:

- `torch._C._to_dlpack_versioned` Python API (Module.cpp): new Python
API for creating a versioned DLPack capsule (called by `__dlpack__`
method)
- `DLPackTraits<T>` class (DLConvertor.h): select the correct
traits (e.g. capsule name, conversion functions) depending on which
DLPack tensor class is being used
- `toDLPackImpl<T>` function (DLConvertor.cpp): populates the
common fields of both classes
- `fromDLPackImpl<T>` function (DLConvertor.cpp): constructs a tensor
from a DLPAck capsule
- `fillVersion<T>` function (DLConvertor.cpp): populates the version
field for `DLManagedTensorVersioned` (no-op for `DLManagedTensor`)
- `tensor_fromDLPackImpl<T>` function (tensor_new.cpp): outer function
for constructing a tensor out of a DLPack capsule that also marks the
capsule as used

Pull Request resolved: https://github.com/pytorch/pytorch/pull/145000
Approved by: https://github.com/albanD
This commit is contained in:
Yukio Siraichi
2025-05-30 17:54:58 -03:00
committed by PyTorch MergeBot
parent 6eb6f198e1
commit 6e185c5312
12 changed files with 460 additions and 89 deletions

View File

@ -11,7 +11,12 @@ from torch.testing._internal.common_device_type import (
skipMeta,
)
from torch.testing._internal.common_dtype import all_types_and_complex_and
from torch.testing._internal.common_utils import IS_JETSON, run_tests, TestCase
from torch.testing._internal.common_utils import (
IS_JETSON,
run_tests,
skipIfTorchDynamo,
TestCase,
)
from torch.utils.dlpack import from_dlpack, to_dlpack
@ -164,7 +169,7 @@ class TestTorchDlPack(TestCase):
# in the current stream to make sure that it was correctly populated.
with torch.cuda.stream(stream_a):
x = make_tensor((5,), dtype=dtype, device=device) + 1
z = torch.from_dlpack(x.__dlpack__(stream_b.cuda_stream))
z = torch.from_dlpack(x.__dlpack__(stream=stream_b.cuda_stream))
stream_a.synchronize()
stream_b.synchronize()
self.assertEqual(z, x)
@ -201,7 +206,7 @@ class TestTorchDlPack(TestCase):
assert stream == 1
else:
assert stream == 0
capsule = self.tensor.__dlpack__(stream)
capsule = self.tensor.__dlpack__(stream=stream)
return capsule
# CUDA-based tests runs on non-default streams
@ -224,7 +229,7 @@ class TestTorchDlPack(TestCase):
x = torch.zeros(1, device=device)
torch.cuda._sleep(2**20)
self.assertTrue(torch.cuda.default_stream().query())
x.__dlpack__(1)
x.__dlpack__(stream=1)
# check that the default stream has work (a pending cudaStreamWaitEvent)
self.assertFalse(torch.cuda.default_stream().query())
@ -281,6 +286,37 @@ class TestTorchDlPack(TestCase):
new_tensor = torch.tensor(wrap)
self.assertEqual(tensor, new_tensor)
@skipMeta
@skipIfTorchDynamo("__dlpack__ doesn't work with dynamo")
@onlyNativeDeviceTypes
def test_max_version(self, device):
def capsule_name(kwargs):
is_versioned = "max_version" in kwargs and kwargs["max_version"][0] >= 1
return "dltensor_versioned" if is_versioned else "dltensor"
def test(device, **kwargs):
inp = make_tensor((5,), dtype=torch.float32, device=device)
# Make sure we are actually using the (un)versioned DLPack tensor, based on the
# informed keyword arguments.
capsule = inp.__dlpack__(**kwargs)
self.assertRegex(
str(capsule), f"""capsule object "{capsule_name(kwargs)}" at"""
)
out = torch.from_dlpack(capsule)
self.assertEqual(inp, out)
# Use the DLPack 0.X version implementation, since max_version=None.
test(device)
# Use the DLPack 0.X version implementation.
test(device, max_version=(0, 8))
# Current highest DLPack version implemented.
test(device, max_version=(1, 0))
# Newer DLPack version.
# Consumer should still be able to process a smaller version capsule.
test(device, max_version=(2, 0))
instantiate_device_type_tests(TestTorchDlPack, globals())