[MPS] Update error message for supported autocast type (#139192)

Autocast in MPS currently only supports dtype of `torch.float16`. This PR updates the error message to reflect this.

This PR was created using [Copilot Workspace](https://copilot-workspace.githubnext.com/pytorch/pytorch/issues/139190?shareId=5b510fda-380c-4e86-8e91-6b67a078f180) with no human input other than clicking buttons.

Fixes #139190

Pull Request resolved: https://github.com/pytorch/pytorch/pull/139192
Approved by: https://github.com/malfet
This commit is contained in:
Roy Hvaara
2024-10-30 16:48:29 +00:00
committed by PyTorch MergeBot
parent 996c40e85e
commit 4b83302585
2 changed files with 8 additions and 1 deletions

View File

@ -348,6 +348,13 @@ class TestAutocastMPS(TestCase):
s.backward()
self.assertEqual(weight_dtype_cast_counter, 2)
def test_mps_autocast_error_message(self):
with self.assertWarnsRegex(
UserWarning, "MPS Autocast only supports dtype of torch.float16 currently."
):
with torch.autocast(device_type="mps", dtype=torch.bfloat16):
_ = torch.ones(10)
class TestTorchAutocast(TestCase):
def test_autocast_fast_dtype(self):

View File

@ -327,7 +327,7 @@ class autocast:
if self.fast_dtype not in supported_dtype:
error_message = "In MPS autocast, but the target dtype is not supported. Disabling autocast.\n"
error_message += (
"MPS Autocast only supports dtype of torch.bfloat16 currently."
"MPS Autocast only supports dtype of torch.float16 currently."
)
warnings.warn(error_message)
enabled = False