Mirror of https://github.com/pytorch/pytorch.git, synced 2025-10-20 21:14:14 +08:00
Revert "[MPS] Add support for autocast in MPS (#99272)"
This reverts commit 6240cfd5c751bea6ca91dc765085e1d871b22345. Reverted https://github.com/pytorch/pytorch/pull/99272 on behalf of https://github.com/jeanschmidt due to introduced breakages in trunk ([comment](https://github.com/pytorch/pytorch/pull/99272#issuecomment-2203033719))
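For context, #99272 enabled torch.autocast(device_type="mps"), and the diff below removes the test that exercised it. The following sketch is illustration only and not part of this commit; it shows the basic usage pattern that test relied on, and it assumes an Apple-silicon PyTorch build where MPS is available and MPS autocast support is present (on builds without that support, torch.autocast with device_type="mps" raises an error).

import torch

# Illustrative sketch; assumes MPS is available and this build still ships
# (or has re-landed) MPS autocast support.
if torch.backends.mps.is_available():
    x = torch.randn(2, 3, device="mps")
    w = torch.randn(4, 3, device="mps")
    with torch.autocast(device_type="mps"):
        # Matmul-style ops run in autocast's fast dtype (float16 on MPS),
        # so the float32 inputs are cast down before the op executes.
        y = torch.nn.functional.linear(x, w)
    print(y.dtype)  # torch.float16 when MPS autocast is in effect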
@@ -344,55 +344,6 @@ class TestAutocastGPU(TestCase):
         torch._C._set_cached_tensors_enabled(False)


-@unittest.skipIf(not torch.backends.mps.is_available(), "requires mps")
-class TestAutocastMPS(TestCase):
-    def test_cast_cache_is_global(self):
-        class CustomLinear(torch.autograd.Function):
-            @staticmethod
-            def forward(ctx, x, w_t):
-                ctx.save_for_backward(x, w_t)
-                return torch.nn.functional.linear(x, w_t)
-
-            @staticmethod
-            def backward(ctx, grad_output):
-                x, w_t = ctx.saved_tensors
-                with torch.autocast(device_type="mps"):
-                    dL_dX = torch.matmul(grad_output, w_t)
-                    dL_dW = torch.matmul(x.transpose(0, 1), grad_output).transpose(0, 1)
-                return dL_dX, dL_dW
-
-        data = torch.randn(2, 3).to("mps")
-        weight = torch.nn.Parameter(torch.randn(4, 3).to("mps"))
-        weight_dtype_cast_counter = 0
-
-        class WeightDTypeCastCounterMode(TorchDispatchMode):
-            def __torch_dispatch__(self, func, types, args=(), kwargs=None):
-                if (
-                    func is torch.ops.aten._to_copy.default
-                    and args[0] is weight
-                    and kwargs["dtype"] is torch.float16
-                ):
-                    nonlocal weight_dtype_cast_counter
-                    weight_dtype_cast_counter += 1
-                return func(*args, **kwargs)
-
-            def __enter__(self):
-                # self.old_clear_cache = torch.clear_autocast_cache
-                # torch.clear_autocast_cache = lambda: None
-                return super().__enter__()
-
-            def __exit__(self, exc_type, exc_val, exc_tb):
-                # torch.clear_autocast_cache = self.old_clear_cache
-                return super().__exit__(exc_type, exc_val, exc_tb)
-
-        with WeightDTypeCastCounterMode():
-            with torch.autocast(device_type="mps"):
-                output = CustomLinear.apply(data, weight)
-                s = output.sum()
-                s.backward()
-        self.assertEqual(weight_dtype_cast_counter, 2)
-
-
 class TestTorchAutocast(TestCase):
     def test_autocast_fast_dtype(self):
         gpu_fast_dtype = torch.get_autocast_gpu_dtype()