mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
Fix DLPack stream logic. (#150217)
This PR fixes the logic for dealing with CUDA and ROCm streams whenever we are trying to create a DLPack capsule from a tensor. In summary, this PR: - Uses the legacy default stream if `tensor.__dlpack__(stream=None)` is called for a CUDA tensor. - Errors if `tensor.__dlpack__(stream=2)` is called for a CUDA tensor: PyTorch doesn't support the per-thread default stream. - Errors if `tensor.__dlpack__(stream=stream)`, where `stream` is 1 or 2, is called for a CUDA tensor using ROCm. For more details, see [the documentation][1]. [1]: https://data-apis.org/array-api/latest/API_specification/generated/array_api.array.__dlpack__.html Pull Request resolved: https://github.com/pytorch/pytorch/pull/150217 Approved by: https://github.com/msaroufim, https://github.com/albanD ghstack dependencies: #150216
This commit is contained in:
committed by
PyTorch MergeBot
parent
b64f338da4
commit
1d526fe78f
@ -3,11 +3,13 @@
|
||||
import torch
|
||||
from torch.testing import make_tensor
|
||||
from torch.testing._internal.common_device_type import (
|
||||
deviceCountAtLeast,
|
||||
dtypes,
|
||||
instantiate_device_type_tests,
|
||||
onlyCPU,
|
||||
onlyCUDA,
|
||||
onlyNativeDeviceTypes,
|
||||
skipCUDAIfNotRocm,
|
||||
skipCUDAIfRocm,
|
||||
skipMeta,
|
||||
)
|
||||
@ -242,6 +244,62 @@ class TestTorchDlPack(TestCase):
|
||||
x = make_tensor((5,), dtype=dtype, device=device)
|
||||
x.__dlpack__(stream=object())
|
||||
|
||||
@skipMeta
|
||||
@onlyCUDA
|
||||
@skipCUDAIfRocm
|
||||
def test_dlpack_cuda_per_thread_stream(self, device):
|
||||
# Test whether we raise an error if we are trying to use per-thread default
|
||||
# stream, which is currently not supported by PyTorch.
|
||||
x = make_tensor((5,), dtype=torch.float32, device=device)
|
||||
with self.assertRaisesRegex(
|
||||
BufferError, "per-thread default stream is not supported"
|
||||
):
|
||||
x.__dlpack__(stream=2)
|
||||
|
||||
@skipMeta
|
||||
@onlyCUDA
|
||||
@skipCUDAIfNotRocm
|
||||
def test_dlpack_invalid_rocm_streams(self, device):
|
||||
# Test that we correctly raise errors on unsupported ROCm streams.
|
||||
def test(x, stream):
|
||||
with self.assertRaisesRegex(
|
||||
AssertionError, r"unsupported stream on ROCm: \d"
|
||||
):
|
||||
x.__dlpack__(stream=stream)
|
||||
|
||||
x = make_tensor((5,), dtype=torch.float32, device=device)
|
||||
test(x, stream=1)
|
||||
test(x, stream=2)
|
||||
|
||||
@skipMeta
|
||||
@onlyCUDA
|
||||
@skipCUDAIfRocm
|
||||
def test_dlpack_invalid_cuda_streams(self, device):
|
||||
x = make_tensor((5,), dtype=torch.float32, device=device)
|
||||
with self.assertRaisesRegex(AssertionError, r"unsupported stream on CUDA: \d"):
|
||||
x.__dlpack__(stream=0)
|
||||
|
||||
@skipMeta
|
||||
def test_dlpack_invalid_cpu_stream(self):
|
||||
x = make_tensor((5,), dtype=torch.float32, device="cpu")
|
||||
with self.assertRaisesRegex(AssertionError, r"stream should be None on cpu."):
|
||||
x.__dlpack__(stream=0)
|
||||
|
||||
@skipMeta
|
||||
@onlyCUDA
|
||||
@deviceCountAtLeast(2)
|
||||
def test_dlpack_tensor_on_different_device(self, devices):
|
||||
dev0, dev1 = devices[:2]
|
||||
|
||||
with torch.device(dev0):
|
||||
x = make_tensor((5,), dtype=torch.float32, device=dev0)
|
||||
|
||||
with self.assertRaisesRegex(
|
||||
BufferError, r"Can't export tensors on a different CUDA device"
|
||||
):
|
||||
with torch.device(dev1):
|
||||
x.__dlpack__()
|
||||
|
||||
# TODO: add interchange tests once NumPy 1.22 (dlpack support) is required
|
||||
@skipMeta
|
||||
def test_dlpack_export_requires_grad(self):
|
||||
|
@ -1703,27 +1703,49 @@ class Tensor(torch._C.TensorBase):
|
||||
"Can't export tensors with layout other than torch.strided"
|
||||
)
|
||||
|
||||
if (
|
||||
self.device.type == "cuda"
|
||||
and self.device.index != torch.cuda.current_device()
|
||||
):
|
||||
raise BufferError(
|
||||
"Can't export tensors on a different CUDA device. "
|
||||
f"Expected: {self.device}. "
|
||||
f"Current device: {torch.cuda.current_device()}."
|
||||
)
|
||||
|
||||
if stream is not None and type(stream) is not int:
|
||||
# Stream pointers in CUDA/ROCm are uniquely numbered and can
|
||||
# be retrieved from their integer value.
|
||||
raise TypeError("stream must be ``int`` or ``none``")
|
||||
elif stream is not None and stream != -1:
|
||||
if self.device.type == "cuda":
|
||||
# NB: This logic handles the special case values for default
|
||||
# streams and must be kept in sync with from_dlpack in
|
||||
# torch/utils/dlpack.py
|
||||
if stream == 1 and torch.version.hip is None:
|
||||
stream = torch.cuda.default_stream()
|
||||
elif stream == 0 and torch.version.hip is not None:
|
||||
stream = torch.cuda.default_stream()
|
||||
else:
|
||||
stream = torch.cuda.ExternalStream(stream)
|
||||
# Only synchronize on different streams
|
||||
sync_stream = torch.cuda.current_stream()
|
||||
if stream != sync_stream:
|
||||
event = torch.cuda.Event()
|
||||
event.record(sync_stream)
|
||||
stream.wait_event(event)
|
||||
elif self.device.type == "cuda" and stream != -1:
|
||||
# NB: This logic handles the special case values for default
|
||||
# streams and must be kept in sync with from_dlpack in
|
||||
# torch/utils/dlpack.py
|
||||
is_rocm = torch.version.hip is not None
|
||||
is_cuda = not is_rocm
|
||||
|
||||
if stream is None or (is_rocm and stream == 0) or (is_cuda and stream == 1):
|
||||
stream = torch.cuda.default_stream()
|
||||
else:
|
||||
if is_cuda and stream == 2:
|
||||
raise BufferError("per-thread default stream is not supported.")
|
||||
|
||||
device_str = "CUDA" if is_cuda else "ROCm"
|
||||
assert (is_cuda and stream != 0) or (
|
||||
is_rocm and stream not in (1, 2)
|
||||
), f"unsupported stream on {device_str}: {stream}."
|
||||
|
||||
stream = torch.cuda.ExternalStream(stream)
|
||||
|
||||
# Only synchronize on different streams
|
||||
current_stream = torch.cuda.current_stream()
|
||||
if stream != current_stream:
|
||||
event = torch.cuda.Event()
|
||||
event.record(current_stream)
|
||||
stream.wait_event(event)
|
||||
elif self.device.type == "cpu":
|
||||
assert stream is None, "stream should be None on cpu."
|
||||
|
||||
if self.device.type == "xla":
|
||||
import torch_xla
|
||||
import torch_xla.utils.dlpack as xla_dlpack
|
||||
|
Reference in New Issue
Block a user