Revert "[MPS] Add support for autocast in MPS (#99272)"
This reverts commit 6240cfd5c751bea6ca91dc765085e1d871b22345. Reverted https://github.com/pytorch/pytorch/pull/99272 on behalf of https://github.com/jeanschmidt due to introduced breakages in trunk ([comment](https://github.com/pytorch/pytorch/pull/99272#issuecomment-2203033719))
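
For context, the reverted PR wired the MPS backend into torch.autocast end to end: a float16 default fast dtype, a dedicated AutocastMPS dispatch key, per-operator cast policies, and the Python-side dtype checks. A minimal sketch of the user-facing behavior this revert takes away, based on the tests removed below; it assumes a build that still contains #99272 and an Apple-silicon machine:

import torch

# Sketch only: with the reverted patch in place, an autocast region can be
# entered for the "mps" device and matmuls run in float16.
a = torch.rand((8, 8), device="mps")
b = torch.rand((8, 8), device="mps")

with torch.autocast(device_type="mps", dtype=torch.float16):
    out = torch.mm(a, b)

print(out.dtype)  # torch.float16 with the patch; after this revert, autocast
                  # for "mps" is no longer wired up
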
@@ -69,7 +69,7 @@ thread_local std::array<at::ScalarType, at::COMPILE_TIME_MAX_DEVICE_TYPES>
         at::ScalarType::Undefined, // Vulkan
         at::ScalarType::Undefined, // Metal
         at::kHalf, // XPU
-        at::kHalf, // MPS
+        at::ScalarType::Undefined, // MPS
         at::ScalarType::Undefined, // Meta (tensors with no data)
         at::kBFloat16, // HPU / HABANA
         at::ScalarType::Undefined, // SX-Aurora / NEC

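The hunk above restores the per-device default autocast dtype table: with the patch, MPS defaulted to at::kHalf; the revert puts the entry back to Undefined. A quick way to inspect these defaults from Python, assuming the device-generic torch.get_autocast_dtype helper available in recent builds:

import torch

# The values come from the thread-local table this hunk edits.
print(torch.get_autocast_dtype("cpu"))   # torch.bfloat16
print(torch.get_autocast_dtype("cuda"))  # torch.float16
# With the patch, "mps" reported torch.float16 here; after the revert the
# table entry for MPS is Undefined again.
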
@@ -206,117 +206,6 @@ TORCH_LIBRARY_IMPL(aten, Autocast, m) {
       TORCH_FN((&at::autocast::binary_cross_entropy_banned)));
 }
 
-TORCH_LIBRARY_IMPL(_, AutocastMPS, m) {
-  m.fallback(torch::CppFunction::makeFallthrough());
-}
-
-TORCH_LIBRARY_IMPL(aten, AutocastMPS, m) {
-  // lower_precision_fp
-  KERNEL_MPS2(_convolution, deprecated, lower_precision_fp)
-  KERNEL_MPS(_convolution, lower_precision_fp)
-  KERNEL_MPS(conv1d, lower_precision_fp)
-  KERNEL_MPS(conv2d, lower_precision_fp)
-  KERNEL_MPS(conv_tbc, lower_precision_fp)
-  KERNEL_MPS(conv_transpose1d, lower_precision_fp)
-  KERNEL_MPS2(conv_transpose2d, input, lower_precision_fp)
-  KERNEL_MPS(convolution, lower_precision_fp)
-  KERNEL_MPS(_mps_convolution, lower_precision_fp)
-  KERNEL_MPS(prelu, lower_precision_fp)
-  KERNEL_MPS(addmm, lower_precision_fp)
-  KERNEL_MPS(addmv, lower_precision_fp)
-  KERNEL_MPS(addr, lower_precision_fp)
-  KERNEL_MPS(matmul, lower_precision_fp)
-  KERNEL_MPS(einsum, lower_precision_fp)
-  KERNEL_MPS(mm, lower_precision_fp)
-  KERNEL_MPS(mv, lower_precision_fp)
-  KERNEL_MPS(linear, lower_precision_fp)
-  KERNEL_MPS(addbmm, lower_precision_fp)
-  KERNEL_MPS(baddbmm, lower_precision_fp)
-  KERNEL_MPS(bmm, lower_precision_fp)
-  KERNEL_MPS(chain_matmul, lower_precision_fp)
-  KERNEL_MPS(linalg_multi_dot, lower_precision_fp)
-  KERNEL_MPS(lstm_cell, lower_precision_fp)
-
-  // fp32
-  KERNEL_MPS(acos, fp32)
-  KERNEL_MPS(asin, fp32)
-  KERNEL_MPS(cosh, fp32)
-  KERNEL_MPS(erfinv, fp32)
-  KERNEL_MPS(exp, fp32)
-  KERNEL_MPS(expm1, fp32)
-  KERNEL_MPS(log, fp32)
-  KERNEL_MPS(log10, fp32)
-  KERNEL_MPS(log2, fp32)
-  KERNEL_MPS(log1p, fp32)
-  KERNEL_MPS(reciprocal, fp32)
-  KERNEL_MPS(rsqrt, fp32)
-  KERNEL_MPS(sinh, fp32)
-  KERNEL_MPS(tan, fp32)
-  KERNEL_MPS2(pow, Tensor_Scalar, fp32)
-  KERNEL_MPS2(pow, Tensor_Tensor, fp32)
-  KERNEL_MPS2(pow, Scalar, fp32)
-  KERNEL_MPS(softplus, fp32)
-  KERNEL_MPS(layer_norm, fp32)
-  KERNEL_MPS(native_layer_norm, fp32)
-  KERNEL_MPS(group_norm, fp32)
-  KERNEL_MPS2(frobenius_norm, dim, fp32)
-  KERNEL_MPS(nuclear_norm, fp32)
-  KERNEL_MPS2(nuclear_norm, dim, fp32)
-  KERNEL_MPS(cosine_similarity, fp32)
-  KERNEL_MPS(poisson_nll_loss, fp32)
-  KERNEL_MPS(cosine_embedding_loss, fp32)
-  KERNEL_MPS(nll_loss, fp32)
-  KERNEL_MPS(nll_loss2d, fp32)
-  KERNEL_MPS(hinge_embedding_loss, fp32)
-  KERNEL_MPS(kl_div, fp32)
-  KERNEL_MPS(l1_loss, fp32)
-  KERNEL_MPS(smooth_l1_loss, fp32)
-  KERNEL_MPS(huber_loss, fp32)
-  KERNEL_MPS(mse_loss, fp32)
-  KERNEL_MPS(margin_ranking_loss, fp32)
-  KERNEL_MPS(multilabel_margin_loss, fp32)
-  KERNEL_MPS(soft_margin_loss, fp32)
-  KERNEL_MPS(triplet_margin_loss, fp32)
-  KERNEL_MPS(multi_margin_loss, fp32)
-  KERNEL_MPS(binary_cross_entropy_with_logits, fp32)
-  KERNEL_MPS(dist, fp32)
-  KERNEL_MPS(pdist, fp32)
-  KERNEL_MPS(cdist, fp32)
-  KERNEL_MPS(renorm, fp32)
-  KERNEL_MPS(logsumexp, fp32)
-
-  // fp32_set_opt_dtype
-  KERNEL_MPS(prod, fp32)
-  KERNEL_MPS2(prod, dim_int, fp32)
-  KERNEL_MPS2(prod, dim_Dimname, fp32)
-  KERNEL_MPS2(softmax, int, fp32)
-  KERNEL_MPS2(softmax, Dimname, fp32)
-  KERNEL_MPS2(log_softmax, int, fp32)
-  KERNEL_MPS2(log_softmax, Dimname, fp32)
-  KERNEL_MPS(cumprod, fp32)
-  KERNEL_MPS2(cumprod, dimname, fp32)
-  KERNEL_MPS(cumsum, fp32)
-  KERNEL_MPS2(cumsum, dimname, fp32)
-  KERNEL_MPS(linalg_vector_norm, fp32)
-  KERNEL_MPS(linalg_matrix_norm, fp32)
-  KERNEL_MPS2(linalg_matrix_norm, str_ord, fp32)
-  KERNEL_MPS(sum, fp32)
-  KERNEL_MPS2(sum, dim_IntList, fp32)
-  KERNEL_MPS2(sum, dim_DimnameList, fp32)
-  //
-  // promote
-  KERNEL_MPS(addcdiv, promote)
-  KERNEL_MPS(addcmul, promote)
-  KERNEL_MPS(atan2, promote)
-  KERNEL_MPS(bilinear, promote)
-  KERNEL_MPS(cross, promote)
-  KERNEL_MPS(dot, promote)
-  KERNEL_MPS(grid_sampler, promote)
-  KERNEL_MPS(index_put, promote)
-  KERNEL_MPS(tensordot, promote)
-  KERNEL_MPS(scatter_add, promote)
-}
-
 TORCH_LIBRARY_IMPL(_, AutocastCPU, m) {
   m.fallback(torch::CppFunction::makeFallthrough());
 }

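The block above removes the entire AutocastMPS operator table: matmul- and convolution-style ops were run in the lower-precision dtype, numerically sensitive ops were forced to fp32, and mixed-dtype ops were promoted to the widest input type. A rough sketch of what those three policies mean from the Python side (illustrative only; assumes a build with the patch and an MPS device):

import torch

x = torch.rand(4, 4, device="mps")                       # float32 input
h = torch.rand(4, 4, device="mps", dtype=torch.float16)  # float16 input

with torch.autocast(device_type="mps"):
    print(torch.mm(x, x).dtype)     # lower_precision_fp policy -> torch.float16
    print(torch.log(x).dtype)       # fp32 policy -> torch.float32
    print(torch.atan2(x, h).dtype)  # promote policy -> widest input type, torch.float32
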
@@ -145,8 +145,6 @@ inline bool is_autocast_eligible(
       return tensor.is_xla() && tensor.is_floating_point();
     case c10::DeviceType::PrivateUse1:
       return tensor.is_privateuseone() && tensor.is_floating_point();
-    case c10::DeviceType::MPS:
-      return tensor.is_mps() && tensor.is_floating_point();
     default:
       return false;
   }

@@ -170,8 +168,6 @@ inline DispatchKey get_autocast_dispatch_key_from_device_type(
       return DispatchKey::AutocastXLA;
     case c10::DeviceType::PrivateUse1:
       return DispatchKey::AutocastPrivateUse1;
-    case c10::DeviceType::MPS:
-      return DispatchKey::AutocastMPS;
     default:
       throw std::runtime_error(
           "unknown device type for autocast in get_autocast_dispatch_key_from_device_type");

@@ -182,7 +178,7 @@ inline bool is_autocast_available(c10::DeviceType device_type) {
   if (device_type == at::kCPU || device_type == at::kCUDA ||
       device_type == at::kXPU || device_type == at::kIPU ||
       device_type == at::kHPU || device_type == at::kXLA ||
-      device_type == at::kPrivateUse1 || device_type == at::kMPS) {
+      device_type == at::kPrivateUse1) {
     return true;
   } else {
     return false;

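This hunk drops at::kMPS from the set of device types for which autocast is reported as available. Assuming the torch.amp.is_autocast_available binding that sits on top of this C++ check, the visible effect is roughly:

import torch

print(torch.amp.is_autocast_available("cuda"))  # True
print(torch.amp.is_autocast_available("cpu"))   # True
# With the patch this also returned True for "mps"; after the revert it is
# expected to return False again.
print(torch.amp.is_autocast_available("mps"))
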
@@ -749,27 +745,6 @@ copy pasted in from VariableTypeEverything.cpp with appropriate substitutions.
       REDISPATCH_SIGNATURE, \
       POLICY)
 
-// KERNEL_MPS registration for AutocastMPS
-#define KERNEL_MPS(OP, POLICY) \
-  m.impl( \
-      TORCH_SELECTIVE_NAME("aten::" #OP), \
-      &WrapFunction< \
-          CastPolicy::POLICY, \
-          DeviceType::MPS, \
-          decltype(ATEN_FN(OP)), \
-          decltype(ATEN_FN(OP)), \
-          &ATEN_FN(OP)>::type::call);
-
-#define KERNEL_MPS2(OP, OVERLOAD, POLICY) \
-  m.impl( \
-      TORCH_SELECTIVE_NAME("aten::" #OP "." #OVERLOAD), \
-      &WrapFunction< \
-          CastPolicy::POLICY, \
-          DeviceType::MPS, \
-          decltype(ATEN_FN2(OP, OVERLOAD)), \
-          decltype(ATEN_FN2(OP, OVERLOAD)), \
-          &ATEN_FN2(OP, OVERLOAD)>::type::call);
-
 // Op lists for different policies.
 // To make sure other backends can reuse the policy op list.
 #define AT_FORALL_LOWER_PRECISION_FP(_) \

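The KERNEL_MPS / KERNEL_MPS2 macros deleted above are what registered each operator's cast-policy wrapper under the AutocastMPS dispatch key. One way to check whether an op carries such a kernel is to dump its dispatcher table; torch._C._dispatch_dump is an internal helper and its output format is not stable, so treat this as a sketch:

import torch

# Prints the dispatcher registration dump for aten::mm.
print(torch._C._dispatch_dump("aten::mm"))
# With the patch, the dump should list an AutocastMPS entry alongside the
# existing AutocastCPU / AutocastCUDA ones; after this revert it should not.
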
@@ -228,7 +228,6 @@ namespace c10 {
   _(aten, is_autocast_cpu_enabled) \
   _(aten, is_autocast_xla_enabled) \
   _(aten, get_autocast_dtype) \
-  _(aten, is_autocast_mps_enabled) \
   FORALL_ATEN_BASE_SYMBOLS(_) \
   _(onnx, Add) \
   _(onnx, Concat) \

@@ -149,8 +149,6 @@ const char* toString(DispatchKey t) {
       return "AutocastXLA";
     case DispatchKey::AutocastPrivateUse1:
       return "AutocastPrivateUse1";
-    case DispatchKey::AutocastMPS:
-      return "AutocastMPS";
 
     case DispatchKey::FuncTorchBatched:
       return "FuncTorchBatched";

@@ -299,7 +297,6 @@ c10::DispatchKey parseDispatchKey(const std::string& k) {
       {"AutocastCUDA", c10::DispatchKey::AutocastCUDA},
       {"AutocastXLA", c10::DispatchKey::AutocastXLA},
       {"AutocastPrivateUse1", c10::DispatchKey::AutocastPrivateUse1},
-      {"AutocastMPS", c10::DispatchKey::AutocastMPS},
       {"FuncTorchBatched", c10::DispatchKey::FuncTorchBatched},
       {"BatchedNestedTensor", c10::DispatchKey::BatchedNestedTensor},
       {"FuncTorchVmapMode", c10::DispatchKey::FuncTorchVmapMode},

@@ -359,7 +359,6 @@ enum class DispatchKey : uint16_t {
   AutocastXLA,
   // AutocastXLA is only being used for TPUs. XLA GPUs continue to use
   // AutocastCUDA.
-  AutocastMPS,
   AutocastCUDA,
   AutocastPrivateUse1,
 

@@ -655,7 +655,6 @@ constexpr DispatchKeySet autograd_dispatch_keyset = DispatchKeySet({
 
 constexpr DispatchKeySet autocast_dispatch_keyset = DispatchKeySet({
     DispatchKey::AutocastCPU,
-    DispatchKey::AutocastMPS,
     DispatchKey::AutocastCUDA,
     DispatchKey::AutocastXPU,
     DispatchKey::AutocastIPU,

@@ -672,7 +671,6 @@ constexpr DispatchKeySet default_included_set = DispatchKeySet({
 
 constexpr DispatchKeySet default_excluded_set = DispatchKeySet({
     DispatchKey::AutocastCPU,
-    DispatchKey::AutocastMPS,
     DispatchKey::AutocastCUDA,
     DispatchKey::AutocastXPU,
     DispatchKey::AutocastIPU,

@@ -865,7 +863,6 @@ inline DispatchKeySet getAutocastRelatedKeySetFromBackend(BackendComponent t) {
   constexpr auto autocast_xla_ks = DispatchKeySet(DispatchKey::AutocastXLA);
   constexpr auto autocast_privateuse1_ks =
       DispatchKeySet(DispatchKey::AutocastPrivateUse1);
-  constexpr auto autocast_mps_ks = DispatchKeySet(DispatchKey::AutocastMPS);
   switch (t) {
     case BackendComponent::CPUBit:
       return autocast_cpu_ks;

@@ -881,8 +878,6 @@ inline DispatchKeySet getAutocastRelatedKeySetFromBackend(BackendComponent t) {
       return autocast_xla_ks;
     case BackendComponent::PrivateUse1Bit:
       return autocast_privateuse1_ks;
-    case BackendComponent::MPSBit:
-      return autocast_mps_ks;
     default:
       return DispatchKeySet();
   }

@@ -344,55 +344,6 @@ class TestAutocastGPU(TestCase):
             torch._C._set_cached_tensors_enabled(False)
 
 
-@unittest.skipIf(not torch.backends.mps.is_available(), "requires mps")
-class TestAutocastMPS(TestCase):
-    def test_cast_cache_is_global(self):
-        class CustomLinear(torch.autograd.Function):
-            @staticmethod
-            def forward(ctx, x, w_t):
-                ctx.save_for_backward(x, w_t)
-                return torch.nn.functional.linear(x, w_t)
-
-            @staticmethod
-            def backward(ctx, grad_output):
-                x, w_t = ctx.saved_tensors
-                with torch.autocast(device_type="mps"):
-                    dL_dX = torch.matmul(grad_output, w_t)
-                    dL_dW = torch.matmul(x.transpose(0, 1), grad_output).transpose(0, 1)
-                return dL_dX, dL_dW
-
-        data = torch.randn(2, 3).to("mps")
-        weight = torch.nn.Parameter(torch.randn(4, 3).to("mps"))
-        weight_dtype_cast_counter = 0
-
-        class WeightDTypeCastCounterMode(TorchDispatchMode):
-            def __torch_dispatch__(self, func, types, args=(), kwargs=None):
-                if (
-                    func is torch.ops.aten._to_copy.default
-                    and args[0] is weight
-                    and kwargs["dtype"] is torch.float16
-                ):
-                    nonlocal weight_dtype_cast_counter
-                    weight_dtype_cast_counter += 1
-                return func(*args, **kwargs)
-
-            def __enter__(self):
-                # self.old_clear_cache = torch.clear_autocast_cache
-                # torch.clear_autocast_cache = lambda: None
-                return super().__enter__()
-
-            def __exit__(self, exc_type, exc_val, exc_tb):
-                # torch.clear_autocast_cache = self.old_clear_cache
-                return super().__exit__(exc_type, exc_val, exc_tb)
-
-        with WeightDTypeCastCounterMode():
-            with torch.autocast(device_type="mps"):
-                output = CustomLinear.apply(data, weight)
-                s = output.sum()
-            s.backward()
-        self.assertEqual(weight_dtype_cast_counter, 2)
-
-
 class TestTorchAutocast(TestCase):
     def test_autocast_fast_dtype(self):
         gpu_fast_dtype = torch.get_autocast_gpu_dtype()

@@ -1199,29 +1199,6 @@ class MpsMemoryLeakCheck:
 
             raise RuntimeError(msg)
 
-class TestAutocastMPS(TestCase):
-
-    def test_matmul_autocast(self):
-        autocast_tensor_A = torch.rand((8, 8), device="mps")
-        autocast_tensor_B = torch.rand((8, 8), device="mps")
-        tensor_A = autocast_tensor_A.clone().detach()
-        tensor_B = autocast_tensor_B.clone().detach()
-        autocast_output_tensor = torch.empty(8, 8)
-        output_tensor = autocast_output_tensor.clone().detach()
-
-        with torch.autocast(device_type="mps"):
-            autocast_output_tensor = torch.mm(autocast_tensor_A, autocast_tensor_B)
-            autocast_output_tensor = torch.mm(autocast_tensor_A, autocast_output_tensor)
-
-        output_tensor = torch.mm(tensor_A, tensor_B)
-        output_tensor = torch.mm(tensor_A, output_tensor)
-
-        self.assertEqual(autocast_output_tensor.dtype, torch.float16, "Autocast output tensor was not expected type float16")
-        self.assertEqual(autocast_output_tensor,
-                         output_tensor.to(torch.float16),
-                         f"Autocast & non-autocast tensors did not match, \
-                         got:\n{autocast_output_tensor} \n{output_tensor.to(torch.float16)}")
-
 # Expand TestCase class with Memory Leak Detection on MPS device
 class TestCaseMPS(TestCase):
     _do_mps_memory_leak_check = True

@@ -322,15 +322,6 @@ class autocast:
                 raise RuntimeError(
                     "Current CUDA Device does not support bfloat16. Please switch dtype to float16."
                 )
-        elif self.device == "mps":
-            supported_dtype = [torch.float16]
-            if self.fast_dtype not in supported_dtype:
-                error_message = "In MPS autocast, but the target dtype is not supported. Disabling autocast.\n"
-                error_message += (
-                    "MPS Autocast only supports dtype of torch.float16 currently."
-                )
-                warnings.warn(error_message)
-                enabled = False
         elif self.device == "xla":
             supported_dtype = [torch.float16, torch.bfloat16]
             if self.fast_dtype not in supported_dtype:

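The removed branch above was the Python-side dtype validation for MPS autocast: only torch.float16 was accepted, and any other fast dtype produced the quoted warning and disabled autocast rather than raising. Roughly, on a build with the patch:

import torch

# float16 was the only accepted fast dtype for MPS autocast in the removed branch.
with torch.autocast(device_type="mps", dtype=torch.float16):
    pass

# Requesting bfloat16 warned ("MPS Autocast only supports dtype of torch.float16
# currently.") and silently fell back to enabled=False instead of raising.
with torch.autocast(device_type="mps", dtype=torch.bfloat16):
    pass
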
@@ -109,7 +109,7 @@ std::optional<AutocastScope> parseAutocast(
     TORCH_CHECK(
         dtype != c10::ScalarType::Undefined,
         "Autocast has invalid fast_dtype attribute");
-    if (device == "cuda" || device == "mps") {
+    if (device == "cuda") {
       scope.context.gpu_enabled = enabled.value();
       scope.context.gpu_scalar_type = dtype;
     } else if (device == "cpu") {

@@ -689,7 +689,6 @@ void initDispatchBindings(PyObject* module) {
       DEF_ONE(PreDispatch)
       DEF_ONE(Functionalize)
       DEF_ONE(AutocastCPU)
-      DEF_ONE(AutocastMPS)
       DEF_ONE(AutocastXPU)
       DEF_ONE(AutocastHPU)
       DEF_ONE(AutocastIPU)

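DEF_ONE(AutocastMPS) exposed the new dispatch key to Python as torch._C.DispatchKey.AutocastMPS; removing it here, together with the enum and parseDispatchKey hunks above, takes the key out of the bindings again. A small check against the existing torch._C.DispatchKey enum:

import torch

print(hasattr(torch._C.DispatchKey, "AutocastCPU"))  # True, still bound via DEF_ONE
print(hasattr(torch._C.DispatchKey, "AutocastMPS"))  # True with the patch, False after the revert
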
@@ -7,8 +7,4 @@ __all__ = ["amp_definitely_not_available"]
 
 
 def amp_definitely_not_available():
-    return not (
-        torch.cuda.is_available()
-        or find_spec("torch_xla")
-        or torch.backends.mps.is_available()
-    )
+    return not (torch.cuda.is_available() or find_spec("torch_xla"))

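amp_definitely_not_available() is used by the AMP gradient scaler to decide whether to warn and disable itself; with this hunk reverted, an MPS device no longer counts as AMP-capable. A sketch of the observable difference on a machine with MPS but neither CUDA nor torch_xla; the torch.cuda.amp.common import path is an assumption about where this helper lives:

from importlib.util import find_spec

import torch
from torch.cuda.amp.common import amp_definitely_not_available  # assumed module path

# With the patch: MPS availability made this return False, so GradScaler stayed active.
# After the revert: True again on a CUDA-less, XLA-less machine, and GradScaler
# warns and effectively disables scaling.
print(torch.backends.mps.is_available(), torch.cuda.is_available(), find_spec("torch_xla"))
print(amp_definitely_not_available())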