mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
error message for instantiating CUDA Stream if CUDA not available (#159868)
Fixes #159744 Summary: ``` import torch # Generate input data input_tensor = torch.randn(3, 3) stream = torch.cuda.Stream() # Call the API input_tensor.record_stream(stream) ``` ⚠️ will now show an error message `torch.cuda.Stream requires CUDA support` Pull Request resolved: https://github.com/pytorch/pytorch/pull/159868 Approved by: https://github.com/malfet, https://github.com/isuruf
This commit is contained in:
committed by
PyTorch MergeBot
parent
8d49cd5b26
commit
df26c51478
@ -10495,7 +10495,7 @@ tensor([[[1.+1.j, 1.+1.j, 1.+1.j, ..., 1.+1.j, 1.+1.j, 1.+1.j],
|
|||||||
def test_no_cuda_monkeypatch(self):
|
def test_no_cuda_monkeypatch(self):
|
||||||
# Note that this is not in test_cuda.py as this whole file is skipped when cuda
|
# Note that this is not in test_cuda.py as this whole file is skipped when cuda
|
||||||
# is not available.
|
# is not available.
|
||||||
with self.assertRaisesRegex(RuntimeError, "Tried to instantiate dummy base class Stream"):
|
with self.assertRaisesRegex(RuntimeError, "torch.cuda.Stream requires CUDA support"):
|
||||||
torch.cuda.Stream()
|
torch.cuda.Stream()
|
||||||
|
|
||||||
with self.assertRaisesRegex(RuntimeError, "Tried to instantiate dummy base class Event"):
|
with self.assertRaisesRegex(RuntimeError, "Tried to instantiate dummy base class Event"):
|
||||||
|
@ -32,6 +32,9 @@ class Stream(torch._C._CudaStreamBase):
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
def __new__(cls, device=None, priority=0, **kwargs):
|
def __new__(cls, device=None, priority=0, **kwargs):
|
||||||
|
# Check CUDA availability
|
||||||
|
if not torch.backends.cuda.is_built():
|
||||||
|
raise RuntimeError("torch.cuda.Stream requires CUDA support")
|
||||||
# setting device manager is expensive, so we avoid it unless necessary
|
# setting device manager is expensive, so we avoid it unless necessary
|
||||||
if device is None or ("stream_id" in kwargs and "device_index" in kwargs):
|
if device is None or ("stream_id" in kwargs and "device_index" in kwargs):
|
||||||
return super().__new__(cls, priority=priority, **kwargs)
|
return super().__new__(cls, priority=priority, **kwargs)
|
||||||
|
Reference in New Issue
Block a user