mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
[MPS] Update error message for supported autocast type (#139192)
Autocast in MPS currently only supports dtype of `torch.float16`. This PR updates the error message to reflect this. This PR was created using [Copilot Workspace](https://copilot-workspace.githubnext.com/pytorch/pytorch/issues/139190?shareId=5b510fda-380c-4e86-8e91-6b67a078f180) with no human input other than clicking buttons. Fixes #139190 Pull Request resolved: https://github.com/pytorch/pytorch/pull/139192 Approved by: https://github.com/malfet
This commit is contained in:
committed by
PyTorch MergeBot
parent
996c40e85e
commit
4b83302585
@ -348,6 +348,13 @@ class TestAutocastMPS(TestCase):
|
||||
s.backward()
|
||||
self.assertEqual(weight_dtype_cast_counter, 2)
|
||||
|
||||
def test_mps_autocast_error_message(self):
|
||||
with self.assertWarnsRegex(
|
||||
UserWarning, "MPS Autocast only supports dtype of torch.float16 currently."
|
||||
):
|
||||
with torch.autocast(device_type="mps", dtype=torch.bfloat16):
|
||||
_ = torch.ones(10)
|
||||
|
||||
|
||||
class TestTorchAutocast(TestCase):
|
||||
def test_autocast_fast_dtype(self):
|
||||
|
Reference in New Issue
Block a user