mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
[MPS] Add support for bf16 autocast (#139390)
This PR adds support for bf16 autocast. Most of the code and ideas are copied from #99272. Most of the heavy lifting was done by AI. Fixes #139386 Pull Request resolved: https://github.com/pytorch/pytorch/pull/139390 Approved by: https://github.com/malfet Co-authored-by: Kulin Seth <kulin_seth@apple.com> Co-authored-by: Nikita Shulga <2453524+malfet@users.noreply.github.com>
This commit is contained in:
committed by
PyTorch MergeBot
parent
808f0f656d
commit
bc69a19139
@ -7,6 +7,7 @@ from torch.testing._internal.autocast_test_lists import (
|
||||
AutocastCPUTestLists,
|
||||
TestAutocast,
|
||||
)
|
||||
from torch.testing._internal.common_device_type import expectedFailureMPSPre14
|
||||
from torch.testing._internal.common_utils import (
|
||||
IS_WINDOWS,
|
||||
run_tests,
|
||||
@ -350,11 +351,21 @@ class TestAutocastMPS(TestCase):
|
||||
|
||||
def test_mps_autocast_error_message(self):
|
||||
with self.assertWarnsRegex(
|
||||
UserWarning, "MPS Autocast only supports dtype of torch.float16 currently."
|
||||
UserWarning,
|
||||
"MPS Autocast only supports dtype of torch.bfloat16 and torch.float16 currently.",
|
||||
):
|
||||
with torch.autocast(device_type="mps", dtype=torch.bfloat16):
|
||||
with torch.autocast(device_type="mps", dtype=torch.float32):
|
||||
_ = torch.ones(10)
|
||||
|
||||
# torch.bfloat16 is only supported on macOS 14 and above.
|
||||
@expectedFailureMPSPre14
|
||||
def test_mps_autocast_bfloat16_supported(self):
|
||||
with torch.amp.autocast(device_type="mps", dtype=torch.bfloat16):
|
||||
x = torch.randn(2, 3, device="mps")
|
||||
y = torch.randn(3, 3, device="mps")
|
||||
result = torch.mm(x, y)
|
||||
self.assertEqual(result.dtype, torch.bfloat16)
|
||||
|
||||
|
||||
class TestTorchAutocast(TestCase):
|
||||
def test_autocast_fast_dtype(self):
|
||||
|
||||
@ -323,11 +323,19 @@ class autocast:
|
||||
"Current CUDA Device does not support bfloat16. Please switch dtype to float16."
|
||||
)
|
||||
elif self.device == "mps":
|
||||
supported_dtype = [torch.float16]
|
||||
supported_dtype = [torch.bfloat16, torch.float16]
|
||||
if self.fast_dtype not in supported_dtype:
|
||||
error_message = "In MPS autocast, but the target dtype is not supported. Disabling autocast.\n"
|
||||
error_message += (
|
||||
"MPS Autocast only supports dtype of torch.float16 currently."
|
||||
error_message = (
|
||||
"In MPS autocast, but the target dtype is not supported. Disabling autocast.\n"
|
||||
"MPS Autocast only supports dtype of torch.bfloat16 and torch.float16 currently."
|
||||
)
|
||||
warnings.warn(error_message)
|
||||
enabled = False
|
||||
elif self.fast_dtype == torch.bfloat16:
|
||||
if not torch.backends.mps.is_macos_or_newer(14, 0):
|
||||
error_message = (
|
||||
"In MPS autocast, but the target dtype torch.bfloat16 is not supported "
|
||||
"on macOS versions below 14. Disabling autocast."
|
||||
)
|
||||
warnings.warn(error_message)
|
||||
enabled = False
|
||||
|
||||
Reference in New Issue
Block a user