Fix MPS interaction with autograd engine

Pull Request resolved: https://github.com/pytorch/pytorch/pull/77644
Approved by: https://github.com/kulinseth, https://github.com/soulitzer, https://github.com/seemethere

Committed by: PyTorch MergeBot
Commit: 090eddf1c7
Parent: f274558018
@@ -74,8 +74,12 @@ struct TORCH_API MPSGuardImpl final : public c10::impl::DeviceGuardImplInterface
     return Stream(Stream::DEFAULT, Device(DeviceType::MPS, 0));
   }
   DeviceIndex deviceCount() const noexcept override {
-    //TODO: extend it for multi-device case
-    return 1;
+    if (at::hasMPS()) {
+      //TODO: extend it for multi-device case
+      return 1;
+    } else {
+      return 0;
+    }
   }
 
   // Event-related functions
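Note on the hunk above: the device count reported by MPSGuardImpl is what the autograd engine consults when deciding whether to spawn per-device worker threads, so returning 0 on machines without MPS keeps the engine CPU-only there. A minimal Python sketch of exercising that path, assuming a machine where torch.backends.mps.is_available() is True (otherwise the CPU branch runs):

    import torch

    if torch.backends.mps.is_available():
        # With MPS reported as a real device, backward() is routed through
        # the autograd engine's MPS worker thread.
        x = torch.randn(4, device="mps", requires_grad=True)
        (x * x).sum().backward()
        print(x.grad)
    else:
        # deviceCount() now reports 0 here, so no MPS worker thread is created
        # and backward() stays on the calling (CPU) thread.
        x = torch.randn(4, requires_grad=True)
        (x * x).sum().backward()
        print(x.grad)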
@@ -4328,7 +4328,7 @@ for shape in [(1,), ()]:
         # The autograd engine creates worker threads only when GPU devices are present.
         # So make sure that we do shutdown threads when we're testing cuda and make sure
         # that there is no thread to shutdown when we're not using cuda.
-        if TEST_CUDA:
+        if TEST_CUDA or torch.backends.mps.is_available():
             self.assertRegex(s, "PYTORCH_API_USAGE torch.autograd.thread_shutdown")
         else:
             self.assertNotRegex(s, "PYTORCH_API_USAGE torch.autograd.thread_shutdown")
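For context, the assertion above inspects the stderr of a subprocess run with PyTorch's API-usage logging enabled; the thread_shutdown line only appears when device worker threads were actually started. A rough sketch of reproducing that check outside the test suite, assuming the PYTORCH_API_USAGE_STDERR=1 logging switch used by the tests:

    import os
    import subprocess
    import sys

    # Tiny autograd workload run in a child process with API-usage logging on.
    code = "import torch; x = torch.randn(3, requires_grad=True); x.sum().backward()"
    proc = subprocess.run(
        [sys.executable, "-c", code],
        env={**os.environ, "PYTORCH_API_USAGE_STDERR": "1"},
        capture_output=True,
        text=True,
    )
    # On CUDA or MPS machines the engine spawns (and later shuts down) device
    # worker threads, so the shutdown marker appears; on CPU-only machines it does not.
    print("torch.autograd.thread_shutdown" in proc.stderr)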
@@ -384,7 +384,7 @@ class TestMultiprocessing(TestCase):
         ctx = mp.get_context('fork')
         simple_autograd_function()
         # Autograd only uses thread when GPUs are involved
-        if torch.cuda.is_available():
+        if torch.cuda.is_available() or torch.backends.mps.is_available():
             with self.assertRaisesRegex(RuntimeError, r'Unable to handle autograd'):
                 with ctx.Pool(3) as pool:
                     pool.map(simple_autograd_function, [1, 2, 3])
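The behaviour this test pins down can be sketched standalone: once the parent process has run backward() with a CUDA or MPS device visible, the autograd engine holds worker threads that do not survive fork(), so the same backward() in a fork-based worker is expected to fail, while on CPU-only machines the pool completes normally. A hedged reproduction, modelled on the test above:

    import multiprocessing as mp
    import torch

    def simple_autograd_function(_=None):
        # Small backward pass; enough to start autograd's device threads, if any.
        x = torch.rand(3, requires_grad=True)
        x.sum().backward()

    if __name__ == "__main__":
        simple_autograd_function()          # warm up autograd in the parent
        ctx = mp.get_context("fork")        # fork is POSIX-only (Linux/macOS)
        try:
            with ctx.Pool(3) as pool:
                pool.map(simple_autograd_function, [1, 2, 3])
            print("completed (no GPU/MPS worker threads in the parent)")
        except RuntimeError as e:
            print("raised as expected with CUDA/MPS present:", e)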