Fix MPS interaction with autograd engine

Pull Request resolved: https://github.com/pytorch/pytorch/pull/77644

Approved by: https://github.com/kulinseth, https://github.com/soulitzer, https://github.com/seemethere
This commit is contained in:
Alban Desmaison
2022-05-17 09:49:21 -04:00
committed by PyTorch MergeBot
parent f274558018
commit 090eddf1c7
3 changed files with 8 additions and 4 deletions

View File

@ -384,7 +384,7 @@ class TestMultiprocessing(TestCase):
ctx = mp.get_context('fork')
simple_autograd_function()
# Autograd only uses thread when GPUs are involved
if torch.cuda.is_available():
if torch.cuda.is_available() or torch.backends.mps.is_available():
with self.assertRaisesRegex(RuntimeError, r'Unable to handle autograd'):
with ctx.Pool(3) as pool:
pool.map(simple_autograd_function, [1, 2, 3])