[AMP][Refactor] Simplify dtype support logic in autocast context manager (#163446)

## Description:

This PR refactors the autocast context manager in `autocast_mode.py` to simplify and centralize the logic for checking supported dtypes for each device. The previous implementation repeated similar checks for multiple device types. Now, a single mapping `device_supported_dtypes` is used to associate device types with their supported dtypes, and the validation logic is unified.

In my view, this makes the code easier to maintain and extend for new devices.

Please share any suggestions and comments with me.

BTW, in the original `xla` branch, the `supported_dtype` are `[torch.float16, torch.bfloat16]`, 5d8a226e23/torch/amp/autocast_mode.py (L358-L363) but the warning message has only `torch.bfloat16`.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/163446
Approved by: https://github.com/FFFrog, https://github.com/albanD
This commit is contained in:
KarhouTam
2025-10-10 12:30:02 +00:00
committed by PyTorch MergeBot
parent e0abcee3b5
commit 960b0d5f0d
2 changed files with 44 additions and 96 deletions

View File

@ -346,7 +346,7 @@ class TestAutocastMPS(TestCase):
def test_mps_autocast_error_message(self):
with self.assertWarnsRegex(
UserWarning,
"MPS Autocast only supports dtype of torch.bfloat16 and torch.float16 currently.",
"MPS Autocast only supports dtypes of torch.bfloat16, torch.float16 currently.",
):
with torch.autocast(device_type="mps", dtype=torch.float32):
_ = torch.ones(10)