Upgrades dlpack to v1.1 to include fp8/fp4 (#162195)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/162195
Approved by: https://github.com/eqy, https://github.com/albanD, https://github.com/Skylion007, https://github.com/rgommers
Committed by: PyTorch MergeBot
Parent: f2206b1ed8
Commit: 928ac57c2a
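For context, a minimal sketch of what this upgrade enables: float8 tensors can now cross the DLPack boundary directly, since DLPack v1.1 defines type codes for them. The dtype choice and the uint8-view comparison below are illustrative assumptions, not code from this PR.

import torch
from torch.utils.dlpack import from_dlpack

# Sketch: with DLPack v1.1, a float8 tensor can be exchanged through the
# DLPack protocol directly (previously this raised a BufferError).
x = torch.randn(5).to(torch.float8_e4m3fn)
z = from_dlpack(x)  # round-trips through x.__dlpack__()

# float8 lacks most comparison kernels, so compare the raw bytes instead.
assert z.view(torch.uint8).equal(x.view(torch.uint8))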
@@ -93,15 +93,7 @@ class TestTorchDlPack(TestCase):
         z[0] = z[0] + 20.0
         self.assertEqual(z, x)
 
-    @skipMeta
-    @onlyCUDA
-    @dtypes(*all_types_and_complex_and(torch.half, torch.bfloat16, torch.bool))
-    def test_dlpack_conversion_with_streams(self, device, dtype):
-        # Create a stream where the tensor will reside
-        stream = torch.cuda.Stream()
-        with torch.cuda.stream(stream):
-            # Do an operation in the actual stream
-            x = make_tensor((5,), dtype=dtype, device=device) + 1
+    def _dlpack_conversion_with_streams(self, stream, x):
         # DLPack protocol helps establish a correct stream order
         # (hence data dependency) at the exchange boundary.
         # DLPack manages this synchronization for us, so we don't need to
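The comment in the hunk above describes the stream-ordering guarantee that these tests exercise. As a standalone illustration (an assumption-laden sketch requiring a CUDA device, not part of this diff), the handshake works like this: the consumer advertises its stream to the producer via __dlpack__(stream=...), and the producer records the dependency on work already queued on its own stream.

import torch

producer = torch.cuda.Stream()
consumer = torch.cuda.Stream()
with torch.cuda.stream(producer):
    x = torch.ones(5, device="cuda") * 2  # work queued on the producer stream
# Passing the consumer stream lets DLPack order the exchange for us, so no
# blanket torch.cuda.synchronize() is needed at the boundary.
z = torch.from_dlpack(x.__dlpack__(stream=consumer.cuda_stream))
producer.synchronize()
consumer.synchronize()  # the host-side check below needs completed kernels
assert torch.equal(z, x)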
@@ -114,8 +106,38 @@ class TestTorchDlPack(TestCase):
         with torch.cuda.stream(stream):
             z = from_dlpack(x)
         stream.synchronize()
         return z
 
+    @skipMeta
+    @onlyCUDA
+    @dtypes(*all_types_and_complex_and(torch.half, torch.bfloat16, torch.bool))
+    def test_dlpack_conversion_with_streams(self, device, dtype):
+        # Create a stream where the tensor will reside
+        stream = torch.cuda.Stream()
+        with torch.cuda.stream(stream):
+            # Do an operation in the actual stream
+            x = make_tensor((5,), dtype=dtype, device=device) + 1
+        z = self._dlpack_conversion_with_streams(stream, x)
+        self.assertEqual(z, x)
+
+    @skipMeta
+    @onlyCUDA
+    @dtypes(
+        torch.float8_e5m2,
+        torch.float8_e5m2fnuz,
+        torch.float8_e4m3fn,
+        torch.float8_e4m3fnuz,
+        torch.float8_e8m0fnu,
+        torch.float4_e2m1fn_x2,
+    )
+    def test_dlpack_conversion_with_streams_narrow_precision(self, device, dtype):
+        stream = torch.cuda.Stream()
+        with torch.cuda.stream(stream):
+            x = make_tensor((5,), dtype=torch.uint8, device=device) + 1
+            x = x.view(dtype)
+        z = self._dlpack_conversion_with_streams(stream, x)
+        self.assertEqual(z.view(torch.uint8), x.view(torch.uint8))
+
     @skipMeta
     @onlyNativeDeviceTypes
     @dtypes(
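A note on the pattern in the narrow-precision test above: make_tensor cannot sample float8/float4 values directly, so the test materializes random bytes and reinterprets them in place. A standalone sketch of that trick (the dtype choice here is an arbitrary assumption):

import torch
from torch.testing import make_tensor

# Generate random bytes, then view the same storage as a narrow-precision
# dtype. No data is copied or converted.
raw = make_tensor((5,), dtype=torch.uint8, device="cpu") + 1
x = raw.view(torch.float8_e5m2)  # 1 byte per element, storage unchanged
# Equality kernels are largely undefined for float8/float4, which is why the
# tests compare the underlying bytes through uint8 views.
assert x.view(torch.uint8).equal(raw)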
@@ -187,6 +209,27 @@ class TestTorchDlPack(TestCase):
         stream_b.synchronize()
         self.assertEqual(z, x)
 
+    @skipMeta
+    @onlyCUDA
+    @dtypes(
+        torch.float8_e5m2,
+        torch.float8_e5m2fnuz,
+        torch.float8_e4m3fn,
+        torch.float8_e4m3fnuz,
+        torch.float8_e8m0fnu,
+        torch.float4_e2m1fn_x2,
+    )
+    def test_dlpack_conversion_with_diff_streams_narrow_precision(self, device, dtype):
+        stream_a = torch.cuda.Stream()
+        stream_b = torch.cuda.Stream()
+        with torch.cuda.stream(stream_a):
+            x = make_tensor((5,), dtype=torch.uint8, device=device) + 1
+            x = x.view(dtype)
+            z = torch.from_dlpack(x.__dlpack__(stream=stream_b.cuda_stream))
+        stream_a.synchronize()
+        stream_b.synchronize()
+        self.assertEqual(z.view(torch.uint8), x.view(torch.uint8))
+
     @skipMeta
     @onlyNativeDeviceTypes
     @dtypes(
@@ -484,9 +527,7 @@ class TestTorchDlPack(TestCase):
     @skipMeta
     @onlyCPU
     def test_dlpack_unsupported_dtype_error(self, device):
-        inp = make_tensor((5,), dtype=torch.float32, device=device).to(
-            torch.float8_e4m3fn
-        )
+        inp = torch.quantize_per_tensor(torch.randn(()), 0.1, 10, torch.qint8)
 
         with self.assertRaisesRegex(
             BufferError, ".* types are not supported by dlpack"
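Since float8 is now representable in DLPack, the unsupported-dtype test above switches to a quantized tensor, for which there is still no DLPack type code. A hedged sketch of the failure mode; the diff's actual raising call is truncated above, so using __dlpack__() as the entry point is an assumption:

import torch

# qint8 has no DLPack type code, so exporting it should raise BufferError.
q = torch.quantize_per_tensor(torch.randn(()), 0.1, 10, torch.qint8)
try:
    q.__dlpack__()  # assumed entry point
except BufferError as err:
    print(err)  # matches ".* types are not supported by dlpack"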