Fix DLPack stream logic. (#150217)

This PR fixes the logic for dealing with CUDA and ROCm streams when
creating a DLPack capsule from a tensor.

In summary, this PR:

- Uses the legacy default stream if `tensor.__dlpack__(stream=None)` is
  called for a CUDA tensor.
- Errors if `tensor.__dlpack__(stream=2)` is called for a CUDA tensor:
  PyTorch doesn't support the per-thread default stream.
- Errors if `tensor.__dlpack__(stream=stream)` is called for a CUDA
  tensor on a ROCm build, where `stream` is 1 or 2: those values are
  CUDA-only sentinels for the legacy and per-thread default streams.
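
A minimal sketch of the resulting behavior on a CUDA (non-ROCm) build
(hypothetical snippet; the error message comes from the patch below):

    import torch

    x = torch.randn(5, device="cuda")

    # stream=None now maps to the legacy default stream.
    capsule = x.__dlpack__(stream=None)

    # stream=2 (the per-thread default stream) is rejected.
    try:
        x.__dlpack__(stream=2)
    except BufferError as e:
        print(e)  # per-thread default stream is not supported.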

For more details, see [the documentation][1].

[1]: https://data-apis.org/array-api/latest/API_specification/generated/array_api.array.__dlpack__.html
Pull Request resolved: https://github.com/pytorch/pytorch/pull/150217
Approved by: https://github.com/msaroufim, https://github.com/albanD
ghstack dependencies: #150216
Author: Yukio Siraichi
Date: 2025-07-19 16:36:07 -03:00
Committed by: PyTorch MergeBot
parent b64f338da4
commit 1d526fe78f
2 changed files with 97 additions and 17 deletions

test/test_dlpack.py

@@ -3,11 +3,13 @@
 import torch
 from torch.testing import make_tensor
 from torch.testing._internal.common_device_type import (
+    deviceCountAtLeast,
     dtypes,
     instantiate_device_type_tests,
     onlyCPU,
     onlyCUDA,
     onlyNativeDeviceTypes,
+    skipCUDAIfNotRocm,
     skipCUDAIfRocm,
     skipMeta,
 )
@@ -242,6 +244,62 @@ class TestTorchDlPack(TestCase):
             x = make_tensor((5,), dtype=dtype, device=device)
             x.__dlpack__(stream=object())
 
+    @skipMeta
+    @onlyCUDA
+    @skipCUDAIfRocm
+    def test_dlpack_cuda_per_thread_stream(self, device):
+        # Test whether we raise an error if we are trying to use per-thread default
+        # stream, which is currently not supported by PyTorch.
+        x = make_tensor((5,), dtype=torch.float32, device=device)
+        with self.assertRaisesRegex(
+            BufferError, "per-thread default stream is not supported"
+        ):
+            x.__dlpack__(stream=2)
+
+    @skipMeta
+    @onlyCUDA
+    @skipCUDAIfNotRocm
+    def test_dlpack_invalid_rocm_streams(self, device):
+        # Test that we correctly raise errors on unsupported ROCm streams.
+        def test(x, stream):
+            with self.assertRaisesRegex(
+                AssertionError, r"unsupported stream on ROCm: \d"
+            ):
+                x.__dlpack__(stream=stream)
+
+        x = make_tensor((5,), dtype=torch.float32, device=device)
+        test(x, stream=1)
+        test(x, stream=2)
+
+    @skipMeta
+    @onlyCUDA
+    @skipCUDAIfRocm
+    def test_dlpack_invalid_cuda_streams(self, device):
+        x = make_tensor((5,), dtype=torch.float32, device=device)
+        with self.assertRaisesRegex(AssertionError, r"unsupported stream on CUDA: \d"):
+            x.__dlpack__(stream=0)
+
+    @skipMeta
+    def test_dlpack_invalid_cpu_stream(self):
+        x = make_tensor((5,), dtype=torch.float32, device="cpu")
+        with self.assertRaisesRegex(AssertionError, r"stream should be None on cpu."):
+            x.__dlpack__(stream=0)
+
+    @skipMeta
+    @onlyCUDA
+    @deviceCountAtLeast(2)
+    def test_dlpack_tensor_on_different_device(self, devices):
+        dev0, dev1 = devices[:2]
+
+        with torch.device(dev0):
+            x = make_tensor((5,), dtype=torch.float32, device=dev0)
+
+        with self.assertRaisesRegex(
+            BufferError, r"Can't export tensors on a different CUDA device"
+        ):
+            with torch.device(dev1):
+                x.__dlpack__()
+
     # TODO: add interchange tests once NumPy 1.22 (dlpack support) is required
     @skipMeta
     def test_dlpack_export_requires_grad(self):

torch/_tensor.py

@@ -1703,27 +1703,49 @@ class Tensor(torch._C.TensorBase):
                 "Can't export tensors with layout other than torch.strided"
             )
+        if (
+            self.device.type == "cuda"
+            and self.device.index != torch.cuda.current_device()
+        ):
+            raise BufferError(
+                "Can't export tensors on a different CUDA device. "
+                f"Expected: {self.device}. "
+                f"Current device: {torch.cuda.current_device()}."
+            )
         if stream is not None and type(stream) is not int:
             # Stream pointers in CUDA/ROCm are uniquely numbered and can
             # be retrieved from their integer value.
             raise TypeError("stream must be ``int`` or ``none``")
-        elif stream is not None and stream != -1:
-            if self.device.type == "cuda":
-                # NB: This logic handles the special case values for default
-                # streams and must be kept in sync with from_dlpack in
-                # torch/utils/dlpack.py
-                if stream == 1 and torch.version.hip is None:
-                    stream = torch.cuda.default_stream()
-                elif stream == 0 and torch.version.hip is not None:
-                    stream = torch.cuda.default_stream()
-                else:
-                    stream = torch.cuda.ExternalStream(stream)
-                # Only synchronize on different streams
-                sync_stream = torch.cuda.current_stream()
-                if stream != sync_stream:
-                    event = torch.cuda.Event()
-                    event.record(sync_stream)
-                    stream.wait_event(event)
+        elif self.device.type == "cuda" and stream != -1:
+            # NB: This logic handles the special case values for default
+            # streams and must be kept in sync with from_dlpack in
+            # torch/utils/dlpack.py
+            is_rocm = torch.version.hip is not None
+            is_cuda = not is_rocm
+
+            if stream is None or (is_rocm and stream == 0) or (is_cuda and stream == 1):
+                stream = torch.cuda.default_stream()
+            else:
+                if is_cuda and stream == 2:
+                    raise BufferError("per-thread default stream is not supported.")
+
+                device_str = "CUDA" if is_cuda else "ROCm"
+                assert (is_cuda and stream != 0) or (
+                    is_rocm and stream not in (1, 2)
+                ), f"unsupported stream on {device_str}: {stream}."
+
+                stream = torch.cuda.ExternalStream(stream)
+
+            # Only synchronize on different streams
+            current_stream = torch.cuda.current_stream()
+            if stream != current_stream:
+                event = torch.cuda.Event()
+                event.record(current_stream)
+                stream.wait_event(event)
+        elif self.device.type == "cpu":
+            assert stream is None, "stream should be None on cpu."
+
         if self.device.type == "xla":
             import torch_xla
             import torch_xla.utils.dlpack as xla_dlpack
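
For context, consumers typically reach this logic through
`torch.from_dlpack`, which passes its current stream to the producer's
`__dlpack__`. A minimal sketch, assuming a CUDA build (the stream `s`
is illustrative):

    import torch

    x = torch.randn(5, device="cuda")

    # from_dlpack hands the consumer's current stream to x.__dlpack__(...),
    # so the producer can make that stream wait on work pending for x.
    s = torch.cuda.Stream()
    with torch.cuda.stream(s):
        y = torch.from_dlpack(x)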