Revert "[MPS] Expand fused forloop to bfloat16 (#141104)"
This reverts commit 9a729390420570cd2528ce2e9947e3eab209660b. Reverted https://github.com/pytorch/pytorch/pull/141104 on behalf of https://github.com/malfet due to Want to add test script to the commit message ([comment](https://github.com/pytorch/pytorch/pull/141104#issuecomment-2492659931))
@@ -1027,11 +1027,8 @@ class TestOptimRenewed(TestCase):
         if _get_device_type(device) == "mps" and dtype not in (
             torch.float16,
             torch.float32,
-            torch.bfloat16,
         ):
-            self.skipTest(
-                "MPS supports only torch.float16, torch.float32 and torch.bfloat16"
-            )
+            self.skipTest("MPS supports only torch.float16 and torch.float32")
         self._test_derived_optimizers(device, dtype, optim_info, "fused")
 
     @optims(
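The revert reason asks for a test script to accompany the change. The sketch below is only an illustration of the kind of standalone check the guarded test performs: it attempts one fused optimizer step on MPS for each dtype the test mentions. Everything beyond the diff (the single-parameter setup, the choice of Adam, the error handling) is an assumption for illustration, not the script the maintainers requested.

```python
# Hypothetical sketch: try a fused optimizer step on MPS for the dtypes the
# guarded test covers. bfloat16 support on MPS depends on the macOS and
# PyTorch versions, so failures are caught and reported rather than raised.
import torch

if torch.backends.mps.is_available():
    for dtype in (torch.float16, torch.float32, torch.bfloat16):
        try:
            # A single parameter stands in for a model; fused=True selects the
            # fused optimizer implementation exercised by the test.
            param = torch.nn.Parameter(torch.randn(8, device="mps", dtype=dtype))
            opt = torch.optim.Adam([param], lr=1e-3, fused=True)
            param.grad = torch.randn_like(param)
            opt.step()
            print(f"{dtype}: fused step ok")
        except Exception as err:  # exact error type varies across versions
            print(f"{dtype}: not supported here ({err})")
else:
    print("MPS backend not available; nothing to check.")
```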