mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
Upgrade to DLPack 1.0. (#145000)
This PR makes the necessary changes in order to upgrade PyTorch DLPack support to version 1.0. In summary, we add support for the following: - Support both `DLManagedTensor` and `DLManagedTensorVersioned` when producing and consuming DLPack capsules - New parameter for `__dlpack__` method: `max_version` - Version checks: - Fallback to old implementation if no `max_version` or if version lower than 1.0 - Check that the to-be-consumed capsule is of version up to 1.X In order to accommodate these new specifications, this PR adds the following main changes: - `torch._C._to_dlpack_versioned` Python API (Module.cpp): new Python API for creating a versioned DLPack capsule (called by `__dlpack__` method) - `DLPackTraits<T>` class (DLConvertor.h): select the correct traits (e.g. capsule name, conversion functions) depending on which DLPack tensor class is being used - `toDLPackImpl<T>` function (DLConvertor.cpp): populates the common fields of both classes - `fromDLPackImpl<T>` function (DLConvertor.cpp): constructs a tensor from a DLPack capsule - `fillVersion<T>` function (DLConvertor.cpp): populates the version field for `DLManagedTensorVersioned` (no-op for `DLManagedTensor`) - `tensor_fromDLPackImpl<T>` function (tensor_new.cpp): outer function for constructing a tensor out of a DLPack capsule that also marks the capsule as used Pull Request resolved: https://github.com/pytorch/pytorch/pull/145000 Approved by: https://github.com/albanD
This commit is contained in:
committed by
PyTorch MergeBot
parent
6eb6f198e1
commit
6e185c5312
@ -11,7 +11,12 @@ from torch.testing._internal.common_device_type import (
|
||||
skipMeta,
|
||||
)
|
||||
from torch.testing._internal.common_dtype import all_types_and_complex_and
|
||||
from torch.testing._internal.common_utils import IS_JETSON, run_tests, TestCase
|
||||
from torch.testing._internal.common_utils import (
|
||||
IS_JETSON,
|
||||
run_tests,
|
||||
skipIfTorchDynamo,
|
||||
TestCase,
|
||||
)
|
||||
from torch.utils.dlpack import from_dlpack, to_dlpack
|
||||
|
||||
|
||||
@ -164,7 +169,7 @@ class TestTorchDlPack(TestCase):
|
||||
# in the current stream to make sure that it was correctly populated.
|
||||
with torch.cuda.stream(stream_a):
|
||||
x = make_tensor((5,), dtype=dtype, device=device) + 1
|
||||
z = torch.from_dlpack(x.__dlpack__(stream_b.cuda_stream))
|
||||
z = torch.from_dlpack(x.__dlpack__(stream=stream_b.cuda_stream))
|
||||
stream_a.synchronize()
|
||||
stream_b.synchronize()
|
||||
self.assertEqual(z, x)
|
||||
@ -201,7 +206,7 @@ class TestTorchDlPack(TestCase):
|
||||
assert stream == 1
|
||||
else:
|
||||
assert stream == 0
|
||||
capsule = self.tensor.__dlpack__(stream)
|
||||
capsule = self.tensor.__dlpack__(stream=stream)
|
||||
return capsule
|
||||
|
||||
# CUDA-based tests run on non-default streams
|
||||
@ -224,7 +229,7 @@ class TestTorchDlPack(TestCase):
|
||||
x = torch.zeros(1, device=device)
|
||||
torch.cuda._sleep(2**20)
|
||||
self.assertTrue(torch.cuda.default_stream().query())
|
||||
x.__dlpack__(1)
|
||||
x.__dlpack__(stream=1)
|
||||
# check that the default stream has work (a pending cudaStreamWaitEvent)
|
||||
self.assertFalse(torch.cuda.default_stream().query())
|
||||
|
||||
@ -281,6 +286,37 @@ class TestTorchDlPack(TestCase):
|
||||
new_tensor = torch.tensor(wrap)
|
||||
self.assertEqual(tensor, new_tensor)
|
||||
|
||||
@skipMeta
@skipIfTorchDynamo("__dlpack__ doesn't work with dynamo")
@onlyNativeDeviceTypes
def test_max_version(self, device):
    """Round-trip a tensor through __dlpack__/from_dlpack for several
    max_version values, checking that the producer emits the expected
    (versioned or unversioned) capsule kind each time.
    """

    def expected_capsule_name(kwargs):
        # A versioned capsule ("dltensor_versioned") is produced only when
        # the caller asks for max_version >= (1, ...); otherwise the legacy
        # "dltensor" capsule is used.
        if "max_version" in kwargs and kwargs["max_version"][0] >= 1:
            return "dltensor_versioned"
        return "dltensor"

    def roundtrip(dev, **kwargs):
        src = make_tensor((5,), dtype=torch.float32, device=dev)

        # Confirm the (un)versioned DLPack capsule kind via the capsule's
        # repr, based on the informed keyword arguments.
        capsule = src.__dlpack__(**kwargs)
        self.assertRegex(
            str(capsule), f"""capsule object "{expected_capsule_name(kwargs)}" at"""
        )

        # The consumed capsule must reconstruct the original values.
        self.assertEqual(src, torch.from_dlpack(capsule))

    # max_version=None: falls back to the DLPack 0.X implementation.
    roundtrip(device)
    # Explicit pre-1.0 version: still the DLPack 0.X implementation.
    roundtrip(device, max_version=(0, 8))
    # Current highest DLPack version implemented.
    roundtrip(device, max_version=(1, 0))
    # Newer DLPack version: the consumer should still be able to process
    # a smaller-version capsule.
    roundtrip(device, max_version=(2, 0))
|
||||
|
||||
|
||||
# Expand TestTorchDlPack into concrete per-device test classes in this
# module's namespace so the runner discovers them (device list depends on
# the test environment).
instantiate_device_type_tests(TestTorchDlPack, globals())
|
||||
|
||||
|
Reference in New Issue
Block a user