Upgrades dlpack to v1.1 to include fp8/fp4 (#162195)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/162195
Approved by: https://github.com/eqy, https://github.com/albanD, https://github.com/Skylion007, https://github.com/rgommers
This commit is contained in:
Syed Tousif Ahmed
2025-09-17 00:20:48 -07:00
committed by PyTorch MergeBot
parent f2206b1ed8
commit 928ac57c2a
3 changed files with 176 additions and 22 deletions

View File

@ -93,15 +93,7 @@ class TestTorchDlPack(TestCase):
z[0] = z[0] + 20.0
self.assertEqual(z, x)
@skipMeta
@onlyCUDA
@dtypes(*all_types_and_complex_and(torch.half, torch.bfloat16, torch.bool))
def test_dlpack_conversion_with_streams(self, device, dtype):
# Create a stream where the tensor will reside
stream = torch.cuda.Stream()
with torch.cuda.stream(stream):
# Do an operation in the actual stream
x = make_tensor((5,), dtype=dtype, device=device) + 1
def _dlpack_conversion_with_streams(self, stream, x):
# DLPack protocol helps establish a correct stream order
# (hence data dependency) at the exchange boundary.
# DLPack manages this synchronization for us, so we don't need to
@ -114,8 +106,38 @@ class TestTorchDlPack(TestCase):
with torch.cuda.stream(stream):
z = from_dlpack(x)
stream.synchronize()
return z
@skipMeta
@onlyCUDA
@dtypes(*all_types_and_complex_and(torch.half, torch.bfloat16, torch.bool))
def test_dlpack_conversion_with_streams(self, device, dtype):
    """Round-trip a tensor through DLPack while work is queued on a
    non-default CUDA stream, and check the data comes back unchanged."""
    # Create the tensor on a dedicated stream so the DLPack exchange has a
    # real cross-stream dependency that the protocol must order correctly.
    producer = torch.cuda.Stream()
    with torch.cuda.stream(producer):
        # Enqueue an actual kernel on the producer stream.
        source = make_tensor((5,), dtype=dtype, device=device) + 1
    roundtripped = self._dlpack_conversion_with_streams(producer, source)
    self.assertEqual(roundtripped, source)
@skipMeta
@onlyCUDA
@dtypes(
    torch.float8_e5m2,
    torch.float8_e5m2fnuz,
    torch.float8_e4m3fn,
    torch.float8_e4m3fnuz,
    torch.float8_e8m0fnu,
    torch.float4_e2m1fn_x2,
)
def test_dlpack_conversion_with_streams_narrow_precision(self, device, dtype):
    """Stream-ordered DLPack round trip for narrow (fp8/fp4) dtypes.

    The narrow-precision tensor is produced by reinterpreting uint8 data,
    and the comparison goes through a uint8 view of both sides.
    """
    producer = torch.cuda.Stream()
    with torch.cuda.stream(producer):
        raw = make_tensor((5,), dtype=torch.uint8, device=device) + 1
        narrow = raw.view(dtype)
    roundtripped = self._dlpack_conversion_with_streams(producer, narrow)
    self.assertEqual(roundtripped.view(torch.uint8), narrow.view(torch.uint8))
@skipMeta
@onlyNativeDeviceTypes
@dtypes(
@ -187,6 +209,27 @@ class TestTorchDlPack(TestCase):
stream_b.synchronize()
self.assertEqual(z, x)
@skipMeta
@onlyCUDA
@dtypes(
    torch.float8_e5m2,
    torch.float8_e5m2fnuz,
    torch.float8_e4m3fn,
    torch.float8_e4m3fnuz,
    torch.float8_e8m0fnu,
    torch.float4_e2m1fn_x2,
)
def test_dlpack_conversion_with_diff_streams_narrow_precision(self, device, dtype):
    """DLPack exchange across two distinct CUDA streams for fp8/fp4 dtypes.

    The producer stream creates the tensor; the consumer stream's handle is
    passed to ``__dlpack__`` so the protocol has to establish the ordering
    between the two streams itself.
    """
    producer = torch.cuda.Stream()
    consumer = torch.cuda.Stream()
    with torch.cuda.stream(producer):
        raw = make_tensor((5,), dtype=torch.uint8, device=device) + 1
        narrow = raw.view(dtype)
    # Export on the consumer stream, then re-import via from_dlpack.
    imported = torch.from_dlpack(narrow.__dlpack__(stream=consumer.cuda_stream))
    producer.synchronize()
    consumer.synchronize()
    # Narrow dtypes are compared through a uint8 reinterpretation.
    self.assertEqual(imported.view(torch.uint8), narrow.view(torch.uint8))
@skipMeta
@onlyNativeDeviceTypes
@dtypes(
@ -484,9 +527,7 @@ class TestTorchDlPack(TestCase):
@skipMeta
@onlyCPU
def test_dlpack_unsupported_dtype_error(self, device):
inp = make_tensor((5,), dtype=torch.float32, device=device).to(
torch.float8_e4m3fn
)
inp = torch.quantize_per_tensor(torch.randn(()), 0.1, 10, torch.qint8)
with self.assertRaisesRegex(
BufferError, ".* types are not supported by dlpack"