Mirror of https://github.com/pytorch/pytorch.git (synced 2025-10-20 21:14:14 +08:00)
[AMP][Refactor] Simplify dtype support logic in autocast context manager (#163446)
## Description:
This PR refactors the autocast context manager in `autocast_mode.py` to simplify and centralize the logic for checking supported dtypes for each device. The previous implementation repeated similar checks for multiple device types. Now, a single mapping `device_supported_dtypes` is used to associate device types with their supported dtypes, and the validation logic is unified.
In my view, this makes the code easier to maintain and extend for new devices.
Please share any suggestions and comments with me.
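To make the refactor concrete, here is a minimal sketch of the centralized validation the PR describes: one `device_supported_dtypes` mapping plus a single check, instead of per-device branches. The mapping name comes from the PR; the helper name `check_autocast_dtype`, the exact device entries, and the use of strings as dtype stand-ins (so the sketch runs without PyTorch, where the real code compares `torch.dtype` objects) are illustrative assumptions, not the actual implementation.

```python
import warnings

# Hypothetical string stand-ins for torch.dtype objects (e.g. torch.bfloat16);
# the entries per device are illustrative, not the authoritative support matrix.
device_supported_dtypes = {
    "cpu": ["torch.bfloat16", "torch.float16"],
    "cuda": ["torch.bfloat16", "torch.float16"],
    "mps": ["torch.bfloat16", "torch.float16"],
    "xla": ["torch.float16", "torch.bfloat16"],
}


def check_autocast_dtype(device_type, dtype):
    """Return True if `dtype` is a supported autocast dtype for `device_type`;
    otherwise warn (mirroring autocast's disable-and-warn behavior) and
    return False. Unified check replacing repeated per-device branches."""
    supported = device_supported_dtypes.get(device_type)
    if supported is None:
        raise RuntimeError(f"Unsupported autocast device type: {device_type!r}")
    if dtype not in supported:
        warnings.warn(
            f"In {device_type} autocast, but the target dtype {dtype} is not "
            f"supported. Disabling autocast. {device_type} Autocast only "
            f"supports dtypes of " + ", ".join(supported) + " currently."
        )
        return False
    return True
```

Adding a new device then only requires a new entry in the mapping, which is the maintainability win the description argues for.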
BTW, in the original `xla` branch the supported dtypes are `[torch.float16, torch.bfloat16]` (5d8a226e23/torch/amp/autocast_mode.py, L358-L363), but the warning message mentions only `torch.bfloat16`.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/163446
Approved by: https://github.com/FFFrog, https://github.com/albanD
Committed by: PyTorch MergeBot
Parent: e0abcee3b5
Commit: 960b0d5f0d
```diff
@@ -346,7 +346,7 @@ class TestAutocastMPS(TestCase):
     def test_mps_autocast_error_message(self):
         with self.assertWarnsRegex(
             UserWarning,
-            "MPS Autocast only supports dtype of torch.bfloat16 and torch.float16 currently.",
+            "MPS Autocast only supports dtypes of torch.bfloat16, torch.float16 currently.",
         ):
             with torch.autocast(device_type="mps", dtype=torch.float32):
                 _ = torch.ones(10)
```