[MPS] Add support for autocast in MPS (#99272)

Fixes https://github.com/pytorch/pytorch/issues/88415

Co-authored-by: Siddharth Kotapati <skotapati@apple.com>
Pull Request resolved: https://github.com/pytorch/pytorch/pull/99272
Approved by: https://github.com/malfet
Kulin Seth authored 2024-08-05 17:02:30 +00:00, committed by PyTorch MergeBot
parent d532c00c81
commit 6919e8baab
11 changed files with 231 additions and 3 deletions
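For context, autocast on MPS is used the same way as on CUDA: wrap the region in torch.autocast(device_type="mps") and eligible float32 ops execute in float16. A minimal illustrative sketch (not part of this diff; the toy Linear module and shapes are made up for demonstration):

    # Assumes a machine where torch.backends.mps.is_available() is True.
    import torch

    model = torch.nn.Linear(3, 4).to("mps")   # illustrative toy module
    data = torch.randn(2, 3, device="mps")

    with torch.autocast(device_type="mps"):
        out = model(data)                     # linear is autocast-eligible

    print(out.dtype)  # torch.float16: the output was produced under autocast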

@@ -344,6 +344,55 @@ class TestAutocastGPU(TestCase):
            torch._C._set_cached_tensors_enabled(False)


@unittest.skipIf(not torch.backends.mps.is_available(), "requires mps")
class TestAutocastMPS(TestCase):
    def test_cast_cache_is_global(self):
        class CustomLinear(torch.autograd.Function):
            @staticmethod
            def forward(ctx, x, w_t):
                ctx.save_for_backward(x, w_t)
                return torch.nn.functional.linear(x, w_t)

            @staticmethod
            def backward(ctx, grad_output):
                x, w_t = ctx.saved_tensors
                with torch.autocast(device_type="mps"):
                    dL_dX = torch.matmul(grad_output, w_t)
                    dL_dW = torch.matmul(x.transpose(0, 1), grad_output).transpose(0, 1)
                return dL_dX, dL_dW

        data = torch.randn(2, 3).to("mps")
        weight = torch.nn.Parameter(torch.randn(4, 3).to("mps"))
        weight_dtype_cast_counter = 0

        class WeightDTypeCastCounterMode(TorchDispatchMode):
            def __torch_dispatch__(self, func, types, args=(), kwargs=None):
                if (
                    func is torch.ops.aten._to_copy.default
                    and args[0] is weight
                    and kwargs["dtype"] is torch.float16
                ):
                    nonlocal weight_dtype_cast_counter
                    weight_dtype_cast_counter += 1
                return func(*args, **kwargs)

            def __enter__(self):
                # self.old_clear_cache = torch.clear_autocast_cache
                # torch.clear_autocast_cache = lambda: None
                return super().__enter__()

            def __exit__(self, exc_type, exc_val, exc_tb):
                # torch.clear_autocast_cache = self.old_clear_cache
                return super().__exit__(exc_type, exc_val, exc_tb)

        with WeightDTypeCastCounterMode():
            with torch.autocast(device_type="mps"):
                output = CustomLinear.apply(data, weight)
                s = output.sum()
            s.backward()
        self.assertEqual(weight_dtype_cast_counter, 2)


class TestTorchAutocast(TestCase):
    def test_autocast_fast_dtype(self):
        gpu_fast_dtype = torch.get_autocast_gpu_dtype()