mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-21 05:34:18 +08:00
[MPS] Update error message for supported autocast type (#139192)
Autocast in MPS currently only supports dtype of `torch.float16`. This PR updates the error message to reflect this. This PR was created using [Copilot Workspace](https://copilot-workspace.githubnext.com/pytorch/pytorch/issues/139190?shareId=5b510fda-380c-4e86-8e91-6b67a078f180) with no human input other than clicking buttons. Fixes #139190 Pull Request resolved: https://github.com/pytorch/pytorch/pull/139192 Approved by: https://github.com/malfet
This commit is contained in:
committed by
PyTorch MergeBot
parent
996c40e85e
commit
4b83302585
@ -348,6 +348,13 @@ class TestAutocastMPS(TestCase):
|
|||||||
s.backward()
|
s.backward()
|
||||||
self.assertEqual(weight_dtype_cast_counter, 2)
|
self.assertEqual(weight_dtype_cast_counter, 2)
|
||||||
|
|
||||||
|
def test_mps_autocast_error_message(self):
|
||||||
|
with self.assertWarnsRegex(
|
||||||
|
UserWarning, "MPS Autocast only supports dtype of torch.float16 currently."
|
||||||
|
):
|
||||||
|
with torch.autocast(device_type="mps", dtype=torch.bfloat16):
|
||||||
|
_ = torch.ones(10)
|
||||||
|
|
||||||
|
|
||||||
class TestTorchAutocast(TestCase):
|
class TestTorchAutocast(TestCase):
|
||||||
def test_autocast_fast_dtype(self):
|
def test_autocast_fast_dtype(self):
|
||||||
|
@ -327,7 +327,7 @@ class autocast:
|
|||||||
if self.fast_dtype not in supported_dtype:
|
if self.fast_dtype not in supported_dtype:
|
||||||
error_message = "In MPS autocast, but the target dtype is not supported. Disabling autocast.\n"
|
error_message = "In MPS autocast, but the target dtype is not supported. Disabling autocast.\n"
|
||||||
error_message += (
|
error_message += (
|
||||||
"MPS Autocast only supports dtype of torch.bfloat16 currently."
|
"MPS Autocast only supports dtype of torch.float16 currently."
|
||||||
)
|
)
|
||||||
warnings.warn(error_message)
|
warnings.warn(error_message)
|
||||||
enabled = False
|
enabled = False
|
||||||
|
Reference in New Issue
Block a user