Fix MPS interaction with autograd engine

Pull Request resolved: https://github.com/pytorch/pytorch/pull/77644

Approved by: https://github.com/kulinseth, https://github.com/soulitzer, https://github.com/seemethere
This commit is contained in:
Alban Desmaison
2022-05-17 09:49:21 -04:00
committed by PyTorch MergeBot
parent f274558018
commit 090eddf1c7
3 changed files with 8 additions and 4 deletions

View File

@ -74,8 +74,12 @@ struct TORCH_API MPSGuardImpl final : public c10::impl::DeviceGuardImplInterface
return Stream(Stream::DEFAULT, Device(DeviceType::MPS, 0)); return Stream(Stream::DEFAULT, Device(DeviceType::MPS, 0));
} }
DeviceIndex deviceCount() const noexcept override { DeviceIndex deviceCount() const noexcept override {
//TODO: extend it for multi-device case if (at::hasMPS()) {
return 1; //TODO: extend it for multi-device case
return 1;
} else {
return 0;
}
} }
// Event-related functions // Event-related functions

View File

@ -4328,7 +4328,7 @@ for shape in [(1,), ()]:
# The autograd engine creates worker threads only when GPU devices are present. # The autograd engine creates worker threads only when GPU devices are present.
# So make sure that we do shutdown threads when we're testing cuda and make sure # So make sure that we do shutdown threads when we're testing cuda and make sure
# that there is no thread to shutdown when we're not using cuda. # that there is no thread to shutdown when we're not using cuda.
if TEST_CUDA: if TEST_CUDA or torch.backends.mps.is_available():
self.assertRegex(s, "PYTORCH_API_USAGE torch.autograd.thread_shutdown") self.assertRegex(s, "PYTORCH_API_USAGE torch.autograd.thread_shutdown")
else: else:
self.assertNotRegex(s, "PYTORCH_API_USAGE torch.autograd.thread_shutdown") self.assertNotRegex(s, "PYTORCH_API_USAGE torch.autograd.thread_shutdown")

View File

@ -384,7 +384,7 @@ class TestMultiprocessing(TestCase):
ctx = mp.get_context('fork') ctx = mp.get_context('fork')
simple_autograd_function() simple_autograd_function()
# Autograd only uses thread when GPUs are involved # Autograd only uses thread when GPUs are involved
if torch.cuda.is_available(): if torch.cuda.is_available() or torch.backends.mps.is_available():
with self.assertRaisesRegex(RuntimeError, r'Unable to handle autograd'): with self.assertRaisesRegex(RuntimeError, r'Unable to handle autograd'):
with ctx.Pool(3) as pool: with ctx.Pool(3) as pool:
pool.map(simple_autograd_function, [1, 2, 3]) pool.map(simple_autograd_function, [1, 2, 3])