[MPS] Add support for bf16 autocast (#139390)

This PR adds support for bf16 autocast. Most of the code and ideas are copied from #99272.
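
For example, with this change the following runs matmuls in bfloat16 under autocast on MPS (a minimal usage sketch, assuming an Apple Silicon machine running macOS 14 or newer):

    import torch

    with torch.autocast(device_type="mps", dtype=torch.bfloat16):
        x = torch.randn(8, 16, device="mps")
        w = torch.randn(16, 4, device="mps")
        out = torch.mm(x, w)  # mm is autocast-eligible, so it runs in bf16
    print(out.dtype)  # torch.bfloat16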

Most of the heavy lifting was done by AI.

Fixes #139386

Pull Request resolved: https://github.com/pytorch/pytorch/pull/139390
Approved by: https://github.com/malfet

Co-authored-by: Kulin Seth <kulin_seth@apple.com>
Co-authored-by: Nikita Shulga <2453524+malfet@users.noreply.github.com>
Author:    Roy Hvaara
Date:      2024-11-20 19:52:26 +00:00
Committer: PyTorch MergeBot
Parent:    808f0f656d
Commit:    bc69a19139

2 changed files with 25 additions and 6 deletions

test/test_autocast.py

@@ -7,6 +7,7 @@ from torch.testing._internal.autocast_test_lists import (
     AutocastCPUTestLists,
     TestAutocast,
 )
+from torch.testing._internal.common_device_type import expectedFailureMPSPre14
 from torch.testing._internal.common_utils import (
     IS_WINDOWS,
     run_tests,
@@ -350,11 +351,21 @@ class TestAutocastMPS(TestCase):
     def test_mps_autocast_error_message(self):
         with self.assertWarnsRegex(
-            UserWarning, "MPS Autocast only supports dtype of torch.float16 currently."
+            UserWarning,
+            "MPS Autocast only supports dtype of torch.bfloat16 and torch.float16 currently.",
         ):
-            with torch.autocast(device_type="mps", dtype=torch.bfloat16):
+            with torch.autocast(device_type="mps", dtype=torch.float32):
                 _ = torch.ones(10)
 
+    # torch.bfloat16 is only supported on macOS 14 and above.
+    @expectedFailureMPSPre14
+    def test_mps_autocast_bfloat16_supported(self):
+        with torch.amp.autocast(device_type="mps", dtype=torch.bfloat16):
+            x = torch.randn(2, 3, device="mps")
+            y = torch.randn(3, 3, device="mps")
+            result = torch.mm(x, y)
+            self.assertEqual(result.dtype, torch.bfloat16)
+
 
 class TestTorchAutocast(TestCase):
     def test_autocast_fast_dtype(self):
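
Assuming the hunk above lands in test/test_autocast.py, the new test can be exercised directly (usage sketch; the command is not part of this PR):

    python test/test_autocast.py -k test_mps_autocast_bfloat16_supported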

torch/amp/autocast_mode.py

@@ -323,11 +323,19 @@ class autocast:
                         "Current CUDA Device does not support bfloat16. Please switch dtype to float16."
                     )
             elif self.device == "mps":
-                supported_dtype = [torch.float16]
+                supported_dtype = [torch.bfloat16, torch.float16]
                 if self.fast_dtype not in supported_dtype:
-                    error_message = "In MPS autocast, but the target dtype is not supported. Disabling autocast.\n"
-                    error_message += (
-                        "MPS Autocast only supports dtype of torch.float16 currently."
+                    error_message = (
+                        "In MPS autocast, but the target dtype is not supported. Disabling autocast.\n"
+                        "MPS Autocast only supports dtype of torch.bfloat16 and torch.float16 currently."
                     )
                     warnings.warn(error_message)
                     enabled = False
+                elif self.fast_dtype == torch.bfloat16:
+                    if not torch.backends.mps.is_macos_or_newer(14, 0):
+                        error_message = (
+                            "In MPS autocast, but the target dtype torch.bfloat16 is not supported "
+                            "on macOS versions below 14. Disabling autocast."
+                        )
+                        warnings.warn(error_message)
+                        enabled = False
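
Behavior implied by the hunk above, as a sketch (the exact fallback dtype is an assumption): on macOS below 14, requesting bfloat16 triggers the warning and autocast is disabled, so ops keep their input dtype instead of running in bf16.

    import torch

    with torch.autocast(device_type="mps", dtype=torch.bfloat16):
        x = torch.randn(4, 4, device="mps")
        y = torch.mm(x, x)
    # torch.bfloat16 on macOS 14+; torch.float32 (the input dtype) on older macOS
    print(y.dtype)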