Fix torch.accelerator api abort when passing invaild device (#143550)

# Motivation
Fix https://github.com/pytorch/pytorch/issues/143543

# Solution
We should raise python exception instead of aborting...

# Additional Context
without this PR:
```python
>>> import torch
>>> torch.accelerator.current_stream(torch.accelerator.device_count())
terminate called after throwing an instance of 'c10::Error'
  what():  device is out of range, device is 2, total number of device is 2.
Exception raised from check_device_index at /home/dvrogozh/git/pytorch/pytorch/c10/xpu/XPUFunctions.h:36 (most recent call first):
frame #0: c10::Error::Error(c10::SourceLocation, std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> >) + 0xac (0x7f30707eb95c in /home/dvrogozh/git/pytorch/pytorch/torch/lib/libc10.so)
frame #1: c10::detail::torchCheckFail(char const*, char const*, unsigned int, std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> > const&) + 0xf3 (0x7f307078fc57 in /home/dvrogozh/git/pytorch/pytorch/torch/lib/libc10.so)
frame #2: <unknown function> + 0x19a3e (0x7f3070c2ba3e in /home/dvrogozh/git/pytorch/pytorch/torch/lib/libc10_xpu.so)
frame #3: c10::xpu::getCurrentXPUStream(signed char) + 0x2f (0x7f3070c2c83f in /home/dvrogozh/git/pytorch/pytorch/torch/lib/libc10_xpu.so)
frame #4: <unknown function> + 0x1ca35 (0x7f3070c2ea35 in /home/dvrogozh/git/pytorch/pytorch/torch/lib/libc10_xpu.so)
frame #5: <unknown function> + 0x653f15 (0x7f3083391f15 in /home/dvrogozh/git/pytorch/pytorch/torch/lib/libtorch_python.so)
frame #6: <unknown function> + 0x39e5f2 (0x7f30830dc5f2 in /home/dvrogozh/git/pytorch/pytorch/torch/lib/libtorch_python.so)
<omitting python frames>
frame #20: <unknown function> + 0x29d90 (0x7f308b19bd90 in /lib/x86_64-linux-gnu/libc.so.6)
frame #21: __libc_start_main + 0x80 (0x7f308b19be40 in /lib/x86_64-linux-gnu/libc.so.6)

Aborted (core dumped)
```
with this PR:
```python
>>> import torch
>>> torch.accelerator.current_stream(torch.accelerator.device_count())
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
  File "/home/pt-gpu/4T-4652/guangyey/stock-pytorch/torch/accelerator/__init__.py", line 123, in current_stream
    return torch._C._accelerator_getStream(device_index)
RuntimeError: The device index is out of range. It must be in [0, 2), but got 2.
```

Pull Request resolved: https://github.com/pytorch/pytorch/pull/143550
Approved by: https://github.com/EikanWang, https://github.com/dvrogozh, https://github.com/albanD
This commit is contained in:
Yu, Guangye
2024-12-19 13:44:10 +00:00
committed by PyTorch MergeBot
parent eebc93d41e
commit 07fa6e2c8b
9 changed files with 23 additions and 17 deletions

View File

@ -82,7 +82,7 @@ struct HIPGuardImplMasqueradingAsCUDA final : public c10::impl::DeviceGuardImplI
void uncheckedSetDevice(Device d) const noexcept override { void uncheckedSetDevice(Device d) const noexcept override {
C10_HIP_CHECK_WARN(hipSetDevice(d.index())); C10_HIP_CHECK_WARN(hipSetDevice(d.index()));
} }
Stream getStream(Device d) const noexcept override { Stream getStream(Device d) const override {
return getCurrentHIPStreamMasqueradingAsCUDA(d.index()).unwrap(); return getCurrentHIPStreamMasqueradingAsCUDA(d.index()).unwrap();
} }
Stream getDefaultStream(Device d) const override { Stream getDefaultStream(Device d) const override {
@ -94,7 +94,7 @@ struct HIPGuardImplMasqueradingAsCUDA final : public c10::impl::DeviceGuardImplI
Stream getStreamFromGlobalPool(Device d, bool isHighPriority = false) const override { Stream getStreamFromGlobalPool(Device d, bool isHighPriority = false) const override {
return getStreamFromPoolMasqueradingAsCUDA(isHighPriority, d.index()); return getStreamFromPoolMasqueradingAsCUDA(isHighPriority, d.index());
} }
Stream exchangeStream(Stream s) const noexcept override { Stream exchangeStream(Stream s) const override {
HIPStreamMasqueradingAsCUDA cs(s); HIPStreamMasqueradingAsCUDA cs(s);
auto old_stream = getCurrentHIPStreamMasqueradingAsCUDA(s.device().index()); auto old_stream = getCurrentHIPStreamMasqueradingAsCUDA(s.device().index());
setCurrentHIPStreamMasqueradingAsCUDA(cs); setCurrentHIPStreamMasqueradingAsCUDA(cs);

View File

@ -64,7 +64,7 @@ struct TORCH_API MPSGuardImpl final
// TODO: Currently setting only device 0 // TODO: Currently setting only device 0
} }
Stream getStream(Device d) const noexcept override { Stream getStream(Device d) const override {
return Stream(Stream::DEFAULT, Device(c10::DeviceType::MPS, 0)); return Stream(Stream::DEFAULT, Device(c10::DeviceType::MPS, 0));
} }
@ -78,7 +78,7 @@ struct TORCH_API MPSGuardImpl final
} }
// NB: These do NOT set the current device // NB: These do NOT set the current device
Stream exchangeStream(Stream s) const noexcept override { Stream exchangeStream(Stream s) const override {
return Stream(Stream::DEFAULT, Device(c10::DeviceType::MPS, 0)); return Stream(Stream::DEFAULT, Device(c10::DeviceType::MPS, 0));
} }
DeviceIndex deviceCount() const noexcept override { DeviceIndex deviceCount() const noexcept override {

View File

@ -105,7 +105,7 @@ struct C10_API DeviceGuardImplInterface {
/** /**
* Get the current stream for a given device. * Get the current stream for a given device.
*/ */
virtual Stream getStream(Device) const noexcept = 0; virtual Stream getStream(Device) const = 0;
/** /**
* Get the default stream for a given device. * Get the default stream for a given device.
@ -138,7 +138,7 @@ struct C10_API DeviceGuardImplInterface {
* Return the previous stream for that device. You are NOT required * Return the previous stream for that device. You are NOT required
* to set the current device to match the device of this stream. * to set the current device to match the device of this stream.
*/ */
virtual Stream exchangeStream(Stream) const noexcept = 0; virtual Stream exchangeStream(Stream) const = 0;
/** /**
* Destroys the given event. * Destroys the given event.

View File

@ -37,7 +37,7 @@ class VirtualGuardImpl final : public DeviceGuardImplInterface {
void uncheckedSetDevice(Device d) const noexcept override { void uncheckedSetDevice(Device d) const noexcept override {
impl_->uncheckedSetDevice(d); impl_->uncheckedSetDevice(d);
} }
Stream getStream(Device d) const noexcept override { Stream getStream(Device d) const override {
return impl_->getStream(d); return impl_->getStream(d);
} }
Stream getNewStream(Device d, int priority = 0) const override { Stream getNewStream(Device d, int priority = 0) const override {
@ -50,7 +50,7 @@ class VirtualGuardImpl final : public DeviceGuardImplInterface {
const override { const override {
return impl_->getStreamFromGlobalPool(d, isHighPriority); return impl_->getStreamFromGlobalPool(d, isHighPriority);
} }
Stream exchangeStream(Stream s) const noexcept override { Stream exchangeStream(Stream s) const override {
return impl_->exchangeStream(s); return impl_->exchangeStream(s);
} }
DeviceIndex deviceCount() const noexcept override { DeviceIndex deviceCount() const noexcept override {

View File

@ -56,7 +56,7 @@ struct CUDAGuardImpl final : public c10::impl::DeviceGuardImplInterface {
void uncheckedSetDevice(Device d) const noexcept override { void uncheckedSetDevice(Device d) const noexcept override {
C10_CUDA_CHECK_WARN(c10::cuda::MaybeSetDevice(d.index())); C10_CUDA_CHECK_WARN(c10::cuda::MaybeSetDevice(d.index()));
} }
Stream getStream(Device d) const noexcept override { Stream getStream(Device d) const override {
return getCurrentCUDAStream(d.index()).unwrap(); return getCurrentCUDAStream(d.index()).unwrap();
} }
Stream getDefaultStream(Device d) const override { Stream getDefaultStream(Device d) const override {
@ -70,7 +70,7 @@ struct CUDAGuardImpl final : public c10::impl::DeviceGuardImplInterface {
return getStreamFromPool(isHighPriority, d.index()); return getStreamFromPool(isHighPriority, d.index());
} }
// NB: These do NOT set the current device // NB: These do NOT set the current device
Stream exchangeStream(Stream s) const noexcept override { Stream exchangeStream(Stream s) const override {
CUDAStream cs(s); CUDAStream cs(s);
auto old_stream = getCurrentCUDAStream(s.device().index()); auto old_stream = getCurrentCUDAStream(s.device().index());
setCurrentCUDAStream(cs); setCurrentCUDAStream(cs);

View File

@ -32,13 +32,13 @@ C10_XPU_API void get_device_properties(
C10_XPU_API DeviceIndex get_device_idx_from_pointer(void* ptr); C10_XPU_API DeviceIndex get_device_idx_from_pointer(void* ptr);
static inline void check_device_index(DeviceIndex device) { static inline void check_device_index(DeviceIndex device_index) {
TORCH_CHECK( TORCH_CHECK(
device >= 0 && device < c10::xpu::device_count(), device_index >= 0 && device_index < c10::xpu::device_count(),
"device is out of range, device is ", "The device index is out of range. It must be in [0, ",
static_cast<int>(device),
", total number of device is ",
static_cast<int>(c10::xpu::device_count()), static_cast<int>(c10::xpu::device_count()),
"), but got ",
static_cast<int>(device_index),
"."); ".");
} }

View File

@ -44,7 +44,7 @@ struct XPUGuardImpl final : public c10::impl::DeviceGuardImplInterface {
c10::xpu::set_device(d.index()); c10::xpu::set_device(d.index());
} }
Stream getStream(Device d) const noexcept override { Stream getStream(Device d) const override {
return getCurrentXPUStream(d.index()).unwrap(); return getCurrentXPUStream(d.index()).unwrap();
} }
@ -58,7 +58,7 @@ struct XPUGuardImpl final : public c10::impl::DeviceGuardImplInterface {
} }
// NB: These do NOT set the current device // NB: These do NOT set the current device
Stream exchangeStream(Stream s) const noexcept override { Stream exchangeStream(Stream s) const override {
const XPUStream stream(s); const XPUStream stream(s);
const auto old_stream = getCurrentXPUStream(s.device().index()); const auto old_stream = getCurrentXPUStream(s.device().index());
setCurrentXPUStream(stream); setCurrentXPUStream(stream);

View File

@ -766,6 +766,10 @@ class TestCuda(TestCase):
self.assertEqual(torch.accelerator.current_stream().stream_id, s1.stream_id) self.assertEqual(torch.accelerator.current_stream().stream_id, s1.stream_id)
torch.accelerator.set_stream(s2) torch.accelerator.set_stream(s2)
self.assertEqual(torch.accelerator.current_stream().stream_id, s2.stream_id) self.assertEqual(torch.accelerator.current_stream().stream_id, s2.stream_id)
with self.assertRaisesRegex(
RuntimeError, "device_index >= 0 && device_index < num_gpus"
):
torch.accelerator.current_stream(torch.accelerator.device_count())
def test_record_stream(self): def test_record_stream(self):
cycles_per_ms = get_cycles_per_ms() cycles_per_ms = get_cycles_per_ms()

View File

@ -306,6 +306,8 @@ print(torch.xpu.device_count())
self.assertEqual(torch.accelerator.current_stream().stream_id, s1.stream_id) self.assertEqual(torch.accelerator.current_stream().stream_id, s1.stream_id)
torch.accelerator.set_stream(s2) torch.accelerator.set_stream(s2)
self.assertEqual(torch.accelerator.current_stream().stream_id, s2.stream_id) self.assertEqual(torch.accelerator.current_stream().stream_id, s2.stream_id)
with self.assertRaisesRegex(RuntimeError, "The device index is out of range"):
torch.accelerator.current_stream(torch.accelerator.device_count())
def test_generator(self): def test_generator(self):
torch.manual_seed(2024) torch.manual_seed(2024)