From 07450e971348a03c5d4141bf56d1ca09abe5441c Mon Sep 17 00:00:00 2001 From: PyTorch MergeBot Date: Tue, 2 Jul 2024 12:29:50 +0000 Subject: [PATCH] Revert "[MPS] Add support for autocast in MPS (#99272)" This reverts commit 6240cfd5c751bea6ca91dc765085e1d871b22345. Reverted https://github.com/pytorch/pytorch/pull/99272 on behalf of https://github.com/jeanschmidt due to introduced breakages in trunk ([comment](https://github.com/pytorch/pytorch/pull/99272#issuecomment-2203033719)) --- aten/src/ATen/autocast_mode.cpp | 113 +------------------------- aten/src/ATen/autocast_mode.h | 27 +----- aten/src/ATen/core/interned_strings.h | 1 - c10/core/DispatchKey.cpp | 3 - c10/core/DispatchKey.h | 1 - c10/core/DispatchKeySet.h | 5 -- test/test_autocast.py | 49 ----------- test/test_mps.py | 23 ------ torch/amp/autocast_mode.py | 9 -- torch/csrc/jit/passes/autocast.cpp | 2 +- torch/csrc/utils/python_dispatch.cpp | 1 - torch/cuda/amp/common.py | 6 +- 12 files changed, 4 insertions(+), 236 deletions(-) diff --git a/aten/src/ATen/autocast_mode.cpp b/aten/src/ATen/autocast_mode.cpp index 9ec7d1373739..10fb72796fc6 100644 --- a/aten/src/ATen/autocast_mode.cpp +++ b/aten/src/ATen/autocast_mode.cpp @@ -69,7 +69,7 @@ thread_local std::array at::ScalarType::Undefined, // Vulkan at::ScalarType::Undefined, // Metal at::kHalf, // XPU - at::kHalf, // MPS + at::ScalarType::Undefined, // MPS at::ScalarType::Undefined, // Meta (tensors with no data) at::kBFloat16, // HPU / HABANA at::ScalarType::Undefined, // SX-Aurora / NEC @@ -206,117 +206,6 @@ TORCH_LIBRARY_IMPL(aten, Autocast, m) { TORCH_FN((&at::autocast::binary_cross_entropy_banned))); } -TORCH_LIBRARY_IMPL(_, AutocastMPS, m) { - m.fallback(torch::CppFunction::makeFallthrough()); -} - -TORCH_LIBRARY_IMPL(aten, AutocastMPS, m) { - // lower_precision_fp - KERNEL_MPS2(_convolution, deprecated, lower_precision_fp) - KERNEL_MPS(_convolution, lower_precision_fp) - KERNEL_MPS(conv1d, lower_precision_fp) - KERNEL_MPS(conv2d, lower_precision_fp) - KERNEL_MPS(conv_tbc, lower_precision_fp) - KERNEL_MPS(conv_transpose1d, lower_precision_fp) - KERNEL_MPS2(conv_transpose2d, input, lower_precision_fp) - KERNEL_MPS(convolution, lower_precision_fp) - KERNEL_MPS(_mps_convolution, lower_precision_fp) - KERNEL_MPS(prelu, lower_precision_fp) - KERNEL_MPS(addmm, lower_precision_fp) - KERNEL_MPS(addmv, lower_precision_fp) - KERNEL_MPS(addr, lower_precision_fp) - KERNEL_MPS(matmul, lower_precision_fp) - KERNEL_MPS(einsum, lower_precision_fp) - KERNEL_MPS(mm, lower_precision_fp) - KERNEL_MPS(mv, lower_precision_fp) - KERNEL_MPS(linear, lower_precision_fp) - KERNEL_MPS(addbmm, lower_precision_fp) - KERNEL_MPS(baddbmm, lower_precision_fp) - KERNEL_MPS(bmm, lower_precision_fp) - KERNEL_MPS(chain_matmul, lower_precision_fp) - KERNEL_MPS(linalg_multi_dot, lower_precision_fp) - KERNEL_MPS(lstm_cell, lower_precision_fp) - - // fp32 - KERNEL_MPS(acos, fp32) - KERNEL_MPS(asin, fp32) - KERNEL_MPS(cosh, fp32) - KERNEL_MPS(erfinv, fp32) - KERNEL_MPS(exp, fp32) - KERNEL_MPS(expm1, fp32) - KERNEL_MPS(log, fp32) - KERNEL_MPS(log10, fp32) - KERNEL_MPS(log2, fp32) - KERNEL_MPS(log1p, fp32) - KERNEL_MPS(reciprocal, fp32) - KERNEL_MPS(rsqrt, fp32) - KERNEL_MPS(sinh, fp32) - KERNEL_MPS(tan, fp32) - KERNEL_MPS2(pow, Tensor_Scalar, fp32) - KERNEL_MPS2(pow, Tensor_Tensor, fp32) - KERNEL_MPS2(pow, Scalar, fp32) - KERNEL_MPS(softplus, fp32) - KERNEL_MPS(layer_norm, fp32) - KERNEL_MPS(native_layer_norm, fp32) - KERNEL_MPS(group_norm, fp32) - KERNEL_MPS2(frobenius_norm, dim, fp32) - KERNEL_MPS(nuclear_norm, 
fp32) - KERNEL_MPS2(nuclear_norm, dim, fp32) - KERNEL_MPS(cosine_similarity, fp32) - KERNEL_MPS(poisson_nll_loss, fp32) - KERNEL_MPS(cosine_embedding_loss, fp32) - KERNEL_MPS(nll_loss, fp32) - KERNEL_MPS(nll_loss2d, fp32) - KERNEL_MPS(hinge_embedding_loss, fp32) - KERNEL_MPS(kl_div, fp32) - KERNEL_MPS(l1_loss, fp32) - KERNEL_MPS(smooth_l1_loss, fp32) - KERNEL_MPS(huber_loss, fp32) - KERNEL_MPS(mse_loss, fp32) - KERNEL_MPS(margin_ranking_loss, fp32) - KERNEL_MPS(multilabel_margin_loss, fp32) - KERNEL_MPS(soft_margin_loss, fp32) - KERNEL_MPS(triplet_margin_loss, fp32) - KERNEL_MPS(multi_margin_loss, fp32) - KERNEL_MPS(binary_cross_entropy_with_logits, fp32) - KERNEL_MPS(dist, fp32) - KERNEL_MPS(pdist, fp32) - KERNEL_MPS(cdist, fp32) - KERNEL_MPS(renorm, fp32) - KERNEL_MPS(logsumexp, fp32) - - // fp32_set_opt_dtype - KERNEL_MPS(prod, fp32) - KERNEL_MPS2(prod, dim_int, fp32) - KERNEL_MPS2(prod, dim_Dimname, fp32) - KERNEL_MPS2(softmax, int, fp32) - KERNEL_MPS2(softmax, Dimname, fp32) - KERNEL_MPS2(log_softmax, int, fp32) - KERNEL_MPS2(log_softmax, Dimname, fp32) - KERNEL_MPS(cumprod, fp32) - KERNEL_MPS2(cumprod, dimname, fp32) - KERNEL_MPS(cumsum, fp32) - KERNEL_MPS2(cumsum, dimname, fp32) - KERNEL_MPS(linalg_vector_norm, fp32) - KERNEL_MPS(linalg_matrix_norm, fp32) - KERNEL_MPS2(linalg_matrix_norm, str_ord, fp32) - KERNEL_MPS(sum, fp32) - KERNEL_MPS2(sum, dim_IntList, fp32) - KERNEL_MPS2(sum, dim_DimnameList, fp32) - // - // promote - KERNEL_MPS(addcdiv, promote) - KERNEL_MPS(addcmul, promote) - KERNEL_MPS(atan2, promote) - KERNEL_MPS(bilinear, promote) - KERNEL_MPS(cross, promote) - KERNEL_MPS(dot, promote) - KERNEL_MPS(grid_sampler, promote) - KERNEL_MPS(index_put, promote) - KERNEL_MPS(tensordot, promote) - KERNEL_MPS(scatter_add, promote) -} - TORCH_LIBRARY_IMPL(_, AutocastCPU, m) { m.fallback(torch::CppFunction::makeFallthrough()); } diff --git a/aten/src/ATen/autocast_mode.h b/aten/src/ATen/autocast_mode.h index 3cd2921c5057..c36030db5b04 100644 --- a/aten/src/ATen/autocast_mode.h +++ b/aten/src/ATen/autocast_mode.h @@ -145,8 +145,6 @@ inline bool is_autocast_eligible( return tensor.is_xla() && tensor.is_floating_point(); case c10::DeviceType::PrivateUse1: return tensor.is_privateuseone() && tensor.is_floating_point(); - case c10::DeviceType::MPS: - return tensor.is_mps() && tensor.is_floating_point(); default: return false; } @@ -170,8 +168,6 @@ inline DispatchKey get_autocast_dispatch_key_from_device_type( return DispatchKey::AutocastXLA; case c10::DeviceType::PrivateUse1: return DispatchKey::AutocastPrivateUse1; - case c10::DeviceType::MPS: - return DispatchKey::AutocastMPS; default: throw std::runtime_error( "unknown device type for autocast in get_autocast_dispatch_key_from_device_type"); @@ -182,7 +178,7 @@ inline bool is_autocast_available(c10::DeviceType device_type) { if (device_type == at::kCPU || device_type == at::kCUDA || device_type == at::kXPU || device_type == at::kIPU || device_type == at::kHPU || device_type == at::kXLA || - device_type == at::kPrivateUse1 || device_type == at::kMPS) { + device_type == at::kPrivateUse1) { return true; } else { return false; @@ -749,27 +745,6 @@ copy pasted in from VariableTypeEverything.cpp with appropriate substitutions. 
REDISPATCH_SIGNATURE, \ POLICY) -// KERNEL_MPS registration for AutocastMPS -#define KERNEL_MPS(OP, POLICY) \ - m.impl( \ - TORCH_SELECTIVE_NAME("aten::" #OP), \ - &WrapFunction< \ - CastPolicy::POLICY, \ - DeviceType::MPS, \ - decltype(ATEN_FN(OP)), \ - decltype(ATEN_FN(OP)), \ - &ATEN_FN(OP)>::type::call); - -#define KERNEL_MPS2(OP, OVERLOAD, POLICY) \ - m.impl( \ - TORCH_SELECTIVE_NAME("aten::" #OP "." #OVERLOAD), \ - &WrapFunction< \ - CastPolicy::POLICY, \ - DeviceType::MPS, \ - decltype(ATEN_FN2(OP, OVERLOAD)), \ - decltype(ATEN_FN2(OP, OVERLOAD)), \ - &ATEN_FN2(OP, OVERLOAD)>::type::call); - // Op lists for different policies. // To make sure other backends can reuse the policy op list. #define AT_FORALL_LOWER_PRECISION_FP(_) \ diff --git a/aten/src/ATen/core/interned_strings.h b/aten/src/ATen/core/interned_strings.h index 38942031befc..4f6abd66cb88 100644 --- a/aten/src/ATen/core/interned_strings.h +++ b/aten/src/ATen/core/interned_strings.h @@ -228,7 +228,6 @@ namespace c10 { _(aten, is_autocast_cpu_enabled) \ _(aten, is_autocast_xla_enabled) \ _(aten, get_autocast_dtype) \ - _(aten, is_autocast_mps_enabled) \ FORALL_ATEN_BASE_SYMBOLS(_) \ _(onnx, Add) \ _(onnx, Concat) \ diff --git a/c10/core/DispatchKey.cpp b/c10/core/DispatchKey.cpp index 526e7f079ee5..0388234efd5b 100644 --- a/c10/core/DispatchKey.cpp +++ b/c10/core/DispatchKey.cpp @@ -149,8 +149,6 @@ const char* toString(DispatchKey t) { return "AutocastXLA"; case DispatchKey::AutocastPrivateUse1: return "AutocastPrivateUse1"; - case DispatchKey::AutocastMPS: - return "AutocastMPS"; case DispatchKey::FuncTorchBatched: return "FuncTorchBatched"; @@ -299,7 +297,6 @@ c10::DispatchKey parseDispatchKey(const std::string& k) { {"AutocastCUDA", c10::DispatchKey::AutocastCUDA}, {"AutocastXLA", c10::DispatchKey::AutocastXLA}, {"AutocastPrivateUse1", c10::DispatchKey::AutocastPrivateUse1}, - {"AutocastMPS", c10::DispatchKey::AutocastMPS}, {"FuncTorchBatched", c10::DispatchKey::FuncTorchBatched}, {"BatchedNestedTensor", c10::DispatchKey::BatchedNestedTensor}, {"FuncTorchVmapMode", c10::DispatchKey::FuncTorchVmapMode}, diff --git a/c10/core/DispatchKey.h b/c10/core/DispatchKey.h index e08e8d2f2d01..71277ebfd891 100644 --- a/c10/core/DispatchKey.h +++ b/c10/core/DispatchKey.h @@ -359,7 +359,6 @@ enum class DispatchKey : uint16_t { AutocastXLA, // AutocastXLA is only being used for TPUs. XLA GPUs continue to use // AutocastCUDA. 
- AutocastMPS, AutocastCUDA, AutocastPrivateUse1, diff --git a/c10/core/DispatchKeySet.h b/c10/core/DispatchKeySet.h index 1a9db6da19ed..4c391d60f2b0 100644 --- a/c10/core/DispatchKeySet.h +++ b/c10/core/DispatchKeySet.h @@ -655,7 +655,6 @@ constexpr DispatchKeySet autograd_dispatch_keyset = DispatchKeySet({ constexpr DispatchKeySet autocast_dispatch_keyset = DispatchKeySet({ DispatchKey::AutocastCPU, - DispatchKey::AutocastMPS, DispatchKey::AutocastCUDA, DispatchKey::AutocastXPU, DispatchKey::AutocastIPU, @@ -672,7 +671,6 @@ constexpr DispatchKeySet default_included_set = DispatchKeySet({ constexpr DispatchKeySet default_excluded_set = DispatchKeySet({ DispatchKey::AutocastCPU, - DispatchKey::AutocastMPS, DispatchKey::AutocastCUDA, DispatchKey::AutocastXPU, DispatchKey::AutocastIPU, @@ -865,7 +863,6 @@ inline DispatchKeySet getAutocastRelatedKeySetFromBackend(BackendComponent t) { constexpr auto autocast_xla_ks = DispatchKeySet(DispatchKey::AutocastXLA); constexpr auto autocast_privateuse1_ks = DispatchKeySet(DispatchKey::AutocastPrivateUse1); - constexpr auto autocast_mps_ks = DispatchKeySet(DispatchKey::AutocastMPS); switch (t) { case BackendComponent::CPUBit: return autocast_cpu_ks; @@ -881,8 +878,6 @@ inline DispatchKeySet getAutocastRelatedKeySetFromBackend(BackendComponent t) { return autocast_xla_ks; case BackendComponent::PrivateUse1Bit: return autocast_privateuse1_ks; - case BackendComponent::MPSBit: - return autocast_mps_ks; default: return DispatchKeySet(); } diff --git a/test/test_autocast.py b/test/test_autocast.py index 9ff1cdfbc0ea..24f87944990d 100644 --- a/test/test_autocast.py +++ b/test/test_autocast.py @@ -344,55 +344,6 @@ class TestAutocastGPU(TestCase): torch._C._set_cached_tensors_enabled(False) -@unittest.skipIf(not torch.backends.mps.is_available(), "requires mps") -class TestAutocastMPS(TestCase): - def test_cast_cache_is_global(self): - class CustomLinear(torch.autograd.Function): - @staticmethod - def forward(ctx, x, w_t): - ctx.save_for_backward(x, w_t) - return torch.nn.functional.linear(x, w_t) - - @staticmethod - def backward(ctx, grad_output): - x, w_t = ctx.saved_tensors - with torch.autocast(device_type="mps"): - dL_dX = torch.matmul(grad_output, w_t) - dL_dW = torch.matmul(x.transpose(0, 1), grad_output).transpose(0, 1) - return dL_dX, dL_dW - - data = torch.randn(2, 3).to("mps") - weight = torch.nn.Parameter(torch.randn(4, 3).to("mps")) - weight_dtype_cast_counter = 0 - - class WeightDTypeCastCounterMode(TorchDispatchMode): - def __torch_dispatch__(self, func, types, args=(), kwargs=None): - if ( - func is torch.ops.aten._to_copy.default - and args[0] is weight - and kwargs["dtype"] is torch.float16 - ): - nonlocal weight_dtype_cast_counter - weight_dtype_cast_counter += 1 - return func(*args, **kwargs) - - def __enter__(self): - # self.old_clear_cache = torch.clear_autocast_cache - # torch.clear_autocast_cache = lambda: None - return super().__enter__() - - def __exit__(self, exc_type, exc_val, exc_tb): - # torch.clear_autocast_cache = self.old_clear_cache - return super().__exit__(exc_type, exc_val, exc_tb) - - with WeightDTypeCastCounterMode(): - with torch.autocast(device_type="mps"): - output = CustomLinear.apply(data, weight) - s = output.sum() - s.backward() - self.assertEqual(weight_dtype_cast_counter, 2) - - class TestTorchAutocast(TestCase): def test_autocast_fast_dtype(self): gpu_fast_dtype = torch.get_autocast_gpu_dtype() diff --git a/test/test_mps.py b/test/test_mps.py index c071a133f02a..804e78ef8419 100644 --- a/test/test_mps.py +++ 
b/test/test_mps.py @@ -1199,29 +1199,6 @@ class MpsMemoryLeakCheck: raise RuntimeError(msg) -class TestAutocastMPS(TestCase): - - def test_matmul_autocast(self): - autocast_tensor_A = torch.rand((8, 8), device="mps") - autocast_tensor_B = torch.rand((8, 8), device="mps") - tensor_A = autocast_tensor_A.clone().detach() - tensor_B = autocast_tensor_B.clone().detach() - autocast_output_tensor = torch.empty(8, 8) - output_tensor = autocast_output_tensor.clone().detach() - - with torch.autocast(device_type="mps"): - autocast_output_tensor = torch.mm(autocast_tensor_A, autocast_tensor_B) - autocast_output_tensor = torch.mm(autocast_tensor_A, autocast_output_tensor) - - output_tensor = torch.mm(tensor_A, tensor_B) - output_tensor = torch.mm(tensor_A, output_tensor) - - self.assertEqual(autocast_output_tensor.dtype, torch.float16, "Autocast output tensor was not expected type float16") - self.assertEqual(autocast_output_tensor, - output_tensor.to(torch.float16), - f"Autocast & non-autocast tensors did not match, \ - got:\n{autocast_output_tensor} \n{output_tensor.to(torch.float16)}") - # Expand TestCase class with Memory Leak Detection on MPS device class TestCaseMPS(TestCase): _do_mps_memory_leak_check = True diff --git a/torch/amp/autocast_mode.py b/torch/amp/autocast_mode.py index e27593910a2c..f5a50bbe2b3e 100644 --- a/torch/amp/autocast_mode.py +++ b/torch/amp/autocast_mode.py @@ -322,15 +322,6 @@ class autocast: raise RuntimeError( "Current CUDA Device does not support bfloat16. Please switch dtype to float16." ) - elif self.device == "mps": - supported_dtype = [torch.float16] - if self.fast_dtype not in supported_dtype: - error_message = "In MPS autocast, but the target dtype is not supported. Disabling autocast.\n" - error_message += ( - "MPS Autocast only supports dtype of torch.float16 currently." - ) - warnings.warn(error_message) - enabled = False elif self.device == "xla": supported_dtype = [torch.float16, torch.bfloat16] if self.fast_dtype not in supported_dtype: diff --git a/torch/csrc/jit/passes/autocast.cpp b/torch/csrc/jit/passes/autocast.cpp index 573261382e51..635162e04953 100644 --- a/torch/csrc/jit/passes/autocast.cpp +++ b/torch/csrc/jit/passes/autocast.cpp @@ -109,7 +109,7 @@ std::optional parseAutocast( TORCH_CHECK( dtype != c10::ScalarType::Undefined, "Autocast has invalid fast_dtype attribute"); - if (device == "cuda" || device == "mps") { + if (device == "cuda") { scope.context.gpu_enabled = enabled.value(); scope.context.gpu_scalar_type = dtype; } else if (device == "cpu") { diff --git a/torch/csrc/utils/python_dispatch.cpp b/torch/csrc/utils/python_dispatch.cpp index bab2e583c2c0..ec0af99842d2 100644 --- a/torch/csrc/utils/python_dispatch.cpp +++ b/torch/csrc/utils/python_dispatch.cpp @@ -689,7 +689,6 @@ void initDispatchBindings(PyObject* module) { DEF_ONE(PreDispatch) DEF_ONE(Functionalize) DEF_ONE(AutocastCPU) - DEF_ONE(AutocastMPS) DEF_ONE(AutocastXPU) DEF_ONE(AutocastHPU) DEF_ONE(AutocastIPU) diff --git a/torch/cuda/amp/common.py b/torch/cuda/amp/common.py index c81ca6f8d14c..30ccaeede8d9 100644 --- a/torch/cuda/amp/common.py +++ b/torch/cuda/amp/common.py @@ -7,8 +7,4 @@ __all__ = ["amp_definitely_not_available"] def amp_definitely_not_available(): - return not ( - torch.cuda.is_available() - or find_spec("torch_xla") - or torch.backends.mps.is_available() - ) + return not (torch.cuda.is_available() or find_spec("torch_xla"))
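
For context, the user-visible effect of this revert is that autocast is no longer reported as available for the "mps" device type: the hunks above remove AutocastMPS from the dispatch keys, drop the MPS branch from is_autocast_available, and take torch.backends.mps out of amp_definitely_not_available. A minimal sketch for probing this from Python, assuming a PyTorch build that exposes the public torch.amp.is_autocast_available helper:

import torch

# Probe which device types report autocast support in the installed build.
# With this revert applied, "mps" is expected to report False, while "cpu"
# and "cuda" are untouched by the change.
for device_type in ("cpu", "cuda", "mps"):
    try:
        available = torch.amp.is_autocast_available(device_type)
    except RuntimeError:
        # Defensive: device strings the build does not recognize can raise
        # instead of returning False.
        available = False
    print(f"{device_type}: autocast available = {available}")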