mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 12:54:11 +08:00
Add an error message when instantiating a CUDA Stream if CUDA is not available (#159868)
Fixes #159744

Summary:

```python
import torch

# Generate input data
input_tensor = torch.randn(3, 3)
stream = torch.cuda.Stream()

# Call the API
input_tensor.record_stream(stream)
```

⚠️ This will now show the error message `torch.cuda.Stream requires CUDA support` instead of the previous generic dummy-base-class error.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/159868
Approved by: https://github.com/malfet, https://github.com/isuruf
This commit is contained in:
committed by
PyTorch MergeBot
parent
8d49cd5b26
commit
df26c51478
@@ -10495,7 +10495,7 @@ tensor([[[1.+1.j, 1.+1.j, 1.+1.j, ..., 1.+1.j, 1.+1.j, 1.+1.j],
     def test_no_cuda_monkeypatch(self):
         # Note that this is not in test_cuda.py as this whole file is skipped when cuda
         # is not available.
-        with self.assertRaisesRegex(RuntimeError, "Tried to instantiate dummy base class Stream"):
+        with self.assertRaisesRegex(RuntimeError, "torch.cuda.Stream requires CUDA support"):
             torch.cuda.Stream()

         with self.assertRaisesRegex(RuntimeError, "Tried to instantiate dummy base class Event"):
@ -32,6 +32,9 @@ class Stream(torch._C._CudaStreamBase):
|
||||
"""
|
||||
|
||||
def __new__(cls, device=None, priority=0, **kwargs):
|
||||
# Check CUDA availability
|
||||
if not torch.backends.cuda.is_built():
|
||||
raise RuntimeError("torch.cuda.Stream requires CUDA support")
|
||||
# setting device manager is expensive, so we avoid it unless necessary
|
||||
if device is None or ("stream_id" in kwargs and "device_index" in kwargs):
|
||||
return super().__new__(cls, priority=priority, **kwargs)
|
||||
|
Reference in New Issue
Block a user