[MPS] Add support for autocast in MPS (#99272)

Fixes https://github.com/pytorch/pytorch/issues/88415

Co-authored-by: Siddharth Kotapati <skotapati@apple.com>
Pull Request resolved: https://github.com/pytorch/pytorch/pull/99272
Approved by: https://github.com/malfet
Author: Kulin Seth
Date: 2024-08-05 17:02:30 +00:00
Committed by: PyTorch MergeBot
Parent: d532c00c81
Commit: 6919e8baab
11 changed files with 231 additions and 3 deletions
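A minimal usage sketch of what this change enables, assuming an Apple Silicon machine and a PyTorch build that includes this patch:

import torch

# Assumes an MPS-capable build containing this patch.
model = torch.nn.Linear(8, 8).to("mps")
x = torch.randn(4, 8, device="mps")
with torch.autocast(device_type="mps"):
    y = model(x)       # linear is on the lower_precision_fp list -> runs in float16
print(y.dtype)         # torch.float16
loss = y.float().sum()
loss.backward()        # grads are produced in the parameters' dtype (float32)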

View File

@@ -69,7 +69,7 @@ thread_local std::array<at::ScalarType, at::COMPILE_TIME_MAX_DEVICE_TYPES>
at::ScalarType::Undefined, // Vulkan
at::ScalarType::Undefined, // Metal
at::kHalf, // XPU
-at::ScalarType::Undefined, // MPS
+at::kHalf, // MPS
at::ScalarType::Undefined, // Meta (tensors with no data)
at::kBFloat16, // HPU / HABANA
at::ScalarType::Undefined, // SX-Aurora / NEC
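The one-line change above switches the default autocast dtype for MPS from Undefined to float16. A quick check, assuming the torch.get_autocast_dtype helper present in recent builds:

import torch

# Assumes torch.get_autocast_dtype exists in this build and an MPS device is available.
print(torch.get_autocast_dtype("mps"))   # torch.float16 after this change
with torch.autocast(device_type="mps"):  # dtype now defaults to float16
    pass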
@@ -206,6 +206,117 @@ TORCH_LIBRARY_IMPL(aten, Autocast, m) {
TORCH_FN((&at::autocast::binary_cross_entropy_banned)));
}
TORCH_LIBRARY_IMPL(_, AutocastMPS, m) {
m.fallback(torch::CppFunction::makeFallthrough());
}
TORCH_LIBRARY_IMPL(aten, AutocastMPS, m) {
// lower_precision_fp
KERNEL_MPS2(_convolution, deprecated, lower_precision_fp)
KERNEL_MPS(_convolution, lower_precision_fp)
KERNEL_MPS(conv1d, lower_precision_fp)
KERNEL_MPS(conv2d, lower_precision_fp)
KERNEL_MPS(conv_tbc, lower_precision_fp)
KERNEL_MPS(conv_transpose1d, lower_precision_fp)
KERNEL_MPS2(conv_transpose2d, input, lower_precision_fp)
KERNEL_MPS(convolution, lower_precision_fp)
KERNEL_MPS(_mps_convolution, lower_precision_fp)
KERNEL_MPS(prelu, lower_precision_fp)
KERNEL_MPS(addmm, lower_precision_fp)
KERNEL_MPS(addmv, lower_precision_fp)
KERNEL_MPS(addr, lower_precision_fp)
KERNEL_MPS(matmul, lower_precision_fp)
KERNEL_MPS(einsum, lower_precision_fp)
KERNEL_MPS(mm, lower_precision_fp)
KERNEL_MPS(mv, lower_precision_fp)
KERNEL_MPS(linear, lower_precision_fp)
KERNEL_MPS(addbmm, lower_precision_fp)
KERNEL_MPS(baddbmm, lower_precision_fp)
KERNEL_MPS(bmm, lower_precision_fp)
KERNEL_MPS(chain_matmul, lower_precision_fp)
KERNEL_MPS(linalg_multi_dot, lower_precision_fp)
KERNEL_MPS(lstm_cell, lower_precision_fp)
// fp32
KERNEL_MPS(acos, fp32)
KERNEL_MPS(asin, fp32)
KERNEL_MPS(cosh, fp32)
KERNEL_MPS(erfinv, fp32)
KERNEL_MPS(exp, fp32)
KERNEL_MPS(expm1, fp32)
KERNEL_MPS(log, fp32)
KERNEL_MPS(log10, fp32)
KERNEL_MPS(log2, fp32)
KERNEL_MPS(log1p, fp32)
KERNEL_MPS(reciprocal, fp32)
KERNEL_MPS(rsqrt, fp32)
KERNEL_MPS(sinh, fp32)
KERNEL_MPS(tan, fp32)
KERNEL_MPS2(pow, Tensor_Scalar, fp32)
KERNEL_MPS2(pow, Tensor_Tensor, fp32)
KERNEL_MPS2(pow, Scalar, fp32)
KERNEL_MPS(softplus, fp32)
KERNEL_MPS(layer_norm, fp32)
KERNEL_MPS(native_layer_norm, fp32)
KERNEL_MPS(group_norm, fp32)
KERNEL_MPS2(frobenius_norm, dim, fp32)
KERNEL_MPS(nuclear_norm, fp32)
KERNEL_MPS2(nuclear_norm, dim, fp32)
KERNEL_MPS(cosine_similarity, fp32)
KERNEL_MPS(poisson_nll_loss, fp32)
KERNEL_MPS(cosine_embedding_loss, fp32)
KERNEL_MPS(nll_loss, fp32)
KERNEL_MPS(nll_loss2d, fp32)
KERNEL_MPS(hinge_embedding_loss, fp32)
KERNEL_MPS(kl_div, fp32)
KERNEL_MPS(l1_loss, fp32)
KERNEL_MPS(smooth_l1_loss, fp32)
KERNEL_MPS(huber_loss, fp32)
KERNEL_MPS(mse_loss, fp32)
KERNEL_MPS(margin_ranking_loss, fp32)
KERNEL_MPS(multilabel_margin_loss, fp32)
KERNEL_MPS(soft_margin_loss, fp32)
KERNEL_MPS(triplet_margin_loss, fp32)
KERNEL_MPS(multi_margin_loss, fp32)
KERNEL_MPS(binary_cross_entropy_with_logits, fp32)
KERNEL_MPS(dist, fp32)
KERNEL_MPS(pdist, fp32)
KERNEL_MPS(cdist, fp32)
KERNEL_MPS(renorm, fp32)
KERNEL_MPS(logsumexp, fp32)
// fp32_set_opt_dtype
KERNEL_MPS(prod, fp32)
KERNEL_MPS2(prod, dim_int, fp32)
KERNEL_MPS2(prod, dim_Dimname, fp32)
KERNEL_MPS2(softmax, int, fp32)
KERNEL_MPS2(softmax, Dimname, fp32)
KERNEL_MPS2(log_softmax, int, fp32)
KERNEL_MPS2(log_softmax, Dimname, fp32)
KERNEL_MPS(cumprod, fp32)
KERNEL_MPS2(cumprod, dimname, fp32)
KERNEL_MPS(cumsum, fp32)
KERNEL_MPS2(cumsum, dimname, fp32)
KERNEL_MPS(linalg_vector_norm, fp32)
KERNEL_MPS(linalg_matrix_norm, fp32)
KERNEL_MPS2(linalg_matrix_norm, str_ord, fp32)
KERNEL_MPS(sum, fp32)
KERNEL_MPS2(sum, dim_IntList, fp32)
KERNEL_MPS2(sum, dim_DimnameList, fp32)
//
// promote
KERNEL_MPS(addcdiv, promote)
KERNEL_MPS(addcmul, promote)
KERNEL_MPS(atan2, promote)
KERNEL_MPS(bilinear, promote)
KERNEL_MPS(cross, promote)
KERNEL_MPS(dot, promote)
KERNEL_MPS(grid_sampler, promote)
KERNEL_MPS(index_put, promote)
KERNEL_MPS(tensordot, promote)
KERNEL_MPS(scatter_add, promote)
}
TORCH_LIBRARY_IMPL(_, AutocastCPU, m) {
m.fallback(torch::CppFunction::makeFallthrough());
}
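A sketch of what the three policy groups registered above mean at the Python level, assuming an MPS-enabled build with this patch: lower_precision_fp ops run in float16, fp32 ops are forced back to float32, and promote ops compute in the widest input dtype.

import torch

a = torch.randn(8, 8, device="mps")
b = torch.randn(8, 8, device="mps")
with torch.autocast(device_type="mps"):
    mm_out = torch.mm(a, b)                    # lower_precision_fp -> torch.float16
    log_out = torch.log(a.abs() + 1)           # fp32 policy        -> torch.float32
    mixed = torch.addcmul(a, mm_out, mm_out)   # promote            -> widest input dtype (float32)
print(mm_out.dtype, log_out.dtype, mixed.dtype)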

View File

@@ -145,6 +145,8 @@ inline bool is_autocast_eligible(
return tensor.is_xla() && tensor.is_floating_point();
case c10::DeviceType::PrivateUse1:
return tensor.is_privateuseone() && tensor.is_floating_point();
case c10::DeviceType::MPS:
return tensor.is_mps() && tensor.is_floating_point();
default:
return false;
}
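The eligibility check above only admits floating-point MPS tensors; a small sketch of the consequence (assuming an MPS build):

import torch

f16 = torch.randn(4, 4, device="mps", dtype=torch.float16)
i64 = torch.randint(0, 10, (4, 4), device="mps")
with torch.autocast(device_type="mps"):
    print(torch.log(f16.abs() + 1).dtype)  # torch.float32: eligible float tensor is cast up by the fp32 policy
    print((i64 + i64).dtype)               # torch.int64: integer tensors are never autocast-eligible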
@@ -168,6 +170,8 @@ inline DispatchKey get_autocast_dispatch_key_from_device_type(
return DispatchKey::AutocastXLA;
case c10::DeviceType::PrivateUse1:
return DispatchKey::AutocastPrivateUse1;
case c10::DeviceType::MPS:
return DispatchKey::AutocastMPS;
default:
throw std::runtime_error(
"unknown device type for autocast in get_autocast_dispatch_key_from_device_type");
@@ -178,7 +182,7 @@ inline bool is_autocast_available(c10::DeviceType device_type) {
if (device_type == at::kCPU || device_type == at::kCUDA ||
device_type == at::kXPU || device_type == at::kIPU ||
device_type == at::kHPU || device_type == at::kXLA ||
-device_type == at::kPrivateUse1) {
+device_type == at::kPrivateUse1 || device_type == at::kMPS) {
return true;
} else {
return false;
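With MPS added to is_autocast_available, the generic availability query should now report true for "mps". A sketch, assuming the torch.amp.is_autocast_available wrapper present in recent builds:

import torch

print(torch.amp.is_autocast_available("mps"))   # True after this change
print(torch.amp.is_autocast_available("cuda"))  # True (reports key availability, not hardware presence)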
@@ -745,6 +749,27 @@ copy pasted in from VariableTypeEverything.cpp with appropriate substitutions.
REDISPATCH_SIGNATURE, \
POLICY)
// KERNEL_MPS registration for AutocastMPS
#define KERNEL_MPS(OP, POLICY) \
m.impl( \
TORCH_SELECTIVE_NAME("aten::" #OP), \
&WrapFunction< \
CastPolicy::POLICY, \
DeviceType::MPS, \
decltype(ATEN_FN(OP)), \
decltype(ATEN_FN(OP)), \
&ATEN_FN(OP)>::type::call);
#define KERNEL_MPS2(OP, OVERLOAD, POLICY) \
m.impl( \
TORCH_SELECTIVE_NAME("aten::" #OP "." #OVERLOAD), \
&WrapFunction< \
CastPolicy::POLICY, \
DeviceType::MPS, \
decltype(ATEN_FN2(OP, OVERLOAD)), \
decltype(ATEN_FN2(OP, OVERLOAD)), \
&ATEN_FN2(OP, OVERLOAD)>::type::call);
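The two macros above wrap an ATen op so that its arguments are cast according to the chosen policy before redispatching. A rough Python analogue for a custom op is sketched below; the mylib::my_mm op and its kernels are hypothetical (not part of this PR), and the sketch assumes torch.library accepts the AutocastMPS key exposed later in this diff and that torch.get_autocast_dtype is available.

import torch

lib = torch.library.Library("mylib", "DEF")        # hypothetical extension library
lib.define("my_mm(Tensor a, Tensor b) -> Tensor")

def my_mm_impl(a, b):
    return torch.mm(a, b)                          # plain kernel

def my_mm_autocast(a, b):
    # Mimic the lower_precision_fp policy: cast floating inputs to the autocast
    # dtype, then redispatch with autocast disabled to avoid re-entering this wrapper.
    dtype = torch.get_autocast_dtype("mps")
    with torch.autocast(device_type="mps", enabled=False):
        return torch.ops.mylib.my_mm(
            a.to(dtype) if a.is_floating_point() else a,
            b.to(dtype) if b.is_floating_point() else b,
        )

lib.impl("my_mm", my_mm_impl, "CompositeExplicitAutograd")
lib.impl("my_mm", my_mm_autocast, "AutocastMPS")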
// Op lists for different policies.
// To make sure other backends can reuse the policy op list.
#define AT_FORALL_LOWER_PRECISION_FP(_) \

View File

@@ -228,6 +228,7 @@ namespace c10 {
_(aten, is_autocast_cpu_enabled) \
_(aten, is_autocast_xla_enabled) \
_(aten, get_autocast_dtype) \
_(aten, is_autocast_mps_enabled) \
FORALL_ATEN_BASE_SYMBOLS(_) \
_(onnx, Add) \
_(onnx, Concat) \

View File

@@ -149,6 +149,8 @@ const char* toString(DispatchKey t) {
return "AutocastXLA";
case DispatchKey::AutocastPrivateUse1:
return "AutocastPrivateUse1";
case DispatchKey::AutocastMPS:
return "AutocastMPS";
case DispatchKey::FuncTorchBatched:
return "FuncTorchBatched";
@@ -297,6 +299,7 @@ c10::DispatchKey parseDispatchKey(const std::string& k) {
{"AutocastCUDA", c10::DispatchKey::AutocastCUDA},
{"AutocastXLA", c10::DispatchKey::AutocastXLA},
{"AutocastPrivateUse1", c10::DispatchKey::AutocastPrivateUse1},
{"AutocastMPS", c10::DispatchKey::AutocastMPS},
{"FuncTorchBatched", c10::DispatchKey::FuncTorchBatched},
{"BatchedNestedTensor", c10::DispatchKey::BatchedNestedTensor},
{"FuncTorchVmapMode", c10::DispatchKey::FuncTorchVmapMode},

View File

@@ -359,6 +359,7 @@ enum class DispatchKey : uint16_t {
AutocastXLA,
// AutocastXLA is only being used for TPUs. XLA GPUs continue to use
// AutocastCUDA.
AutocastMPS,
AutocastCUDA,
AutocastPrivateUse1,

View File

@@ -655,6 +655,7 @@ constexpr DispatchKeySet autograd_dispatch_keyset = DispatchKeySet({
constexpr DispatchKeySet autocast_dispatch_keyset = DispatchKeySet({
DispatchKey::AutocastCPU,
DispatchKey::AutocastMPS,
DispatchKey::AutocastCUDA,
DispatchKey::AutocastXPU,
DispatchKey::AutocastIPU,
@@ -671,6 +672,7 @@ constexpr DispatchKeySet default_included_set = DispatchKeySet({
constexpr DispatchKeySet default_excluded_set = DispatchKeySet({
DispatchKey::AutocastCPU,
DispatchKey::AutocastMPS,
DispatchKey::AutocastCUDA,
DispatchKey::AutocastXPU,
DispatchKey::AutocastIPU,
@@ -863,6 +865,7 @@ inline DispatchKeySet getAutocastRelatedKeySetFromBackend(BackendComponent t) {
constexpr auto autocast_xla_ks = DispatchKeySet(DispatchKey::AutocastXLA);
constexpr auto autocast_privateuse1_ks =
DispatchKeySet(DispatchKey::AutocastPrivateUse1);
constexpr auto autocast_mps_ks = DispatchKeySet(DispatchKey::AutocastMPS);
switch (t) {
case BackendComponent::CPUBit:
return autocast_cpu_ks;
@@ -878,6 +881,8 @@ inline DispatchKeySet getAutocastRelatedKeySetFromBackend(BackendComponent t) {
return autocast_xla_ks;
case BackendComponent::PrivateUse1Bit:
return autocast_privateuse1_ks;
case BackendComponent::MPSBit:
return autocast_mps_ks;
default:
return DispatchKeySet();
}
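Because getAutocastRelatedKeySetFromBackend now maps MPSBit to AutocastMPS, MPS tensors carry that key in their dispatch key set (it stays excluded until autocast is entered). A rough way to observe this from Python, assuming the private torch._C dispatch helpers behave as in current builds:

import torch

x = torch.randn(2, 2, device="mps")
keys = torch._C._dispatch_keys(x)                  # the tensor's DispatchKeySet
print(keys.has(torch._C.DispatchKey.AutocastMPS))  # True: backend-derived autocast bit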

View File

@@ -344,6 +344,55 @@ class TestAutocastGPU(TestCase):
torch._C._set_cached_tensors_enabled(False)
@unittest.skipIf(not torch.backends.mps.is_available(), "requires mps")
class TestAutocastMPS(TestCase):
    def test_cast_cache_is_global(self):
        class CustomLinear(torch.autograd.Function):
            @staticmethod
            def forward(ctx, x, w_t):
                ctx.save_for_backward(x, w_t)
                return torch.nn.functional.linear(x, w_t)

            @staticmethod
            def backward(ctx, grad_output):
                x, w_t = ctx.saved_tensors
                with torch.autocast(device_type="mps"):
                    dL_dX = torch.matmul(grad_output, w_t)
                    dL_dW = torch.matmul(x.transpose(0, 1), grad_output).transpose(0, 1)
                return dL_dX, dL_dW

        data = torch.randn(2, 3).to("mps")
        weight = torch.nn.Parameter(torch.randn(4, 3).to("mps"))
        weight_dtype_cast_counter = 0

        class WeightDTypeCastCounterMode(TorchDispatchMode):
            def __torch_dispatch__(self, func, types, args=(), kwargs=None):
                if (
                    func is torch.ops.aten._to_copy.default
                    and args[0] is weight
                    and kwargs["dtype"] is torch.float16
                ):
                    nonlocal weight_dtype_cast_counter
                    weight_dtype_cast_counter += 1
                return func(*args, **kwargs)

            def __enter__(self):
                # self.old_clear_cache = torch.clear_autocast_cache
                # torch.clear_autocast_cache = lambda: None
                return super().__enter__()

            def __exit__(self, exc_type, exc_val, exc_tb):
                # torch.clear_autocast_cache = self.old_clear_cache
                return super().__exit__(exc_type, exc_val, exc_tb)

        with WeightDTypeCastCounterMode():
            with torch.autocast(device_type="mps"):
                output = CustomLinear.apply(data, weight)
                s = output.sum()
            s.backward()
        self.assertEqual(weight_dtype_cast_counter, 2)


class TestTorchAutocast(TestCase):
    def test_autocast_fast_dtype(self):
        gpu_fast_dtype = torch.get_autocast_gpu_dtype()

View File

@@ -1203,6 +1203,29 @@ class MpsMemoryLeakCheck:
raise RuntimeError(msg)
class TestAutocastMPS(TestCase):
    def test_matmul_autocast(self):
        autocast_tensor_A = torch.rand((8, 8), device="mps")
        autocast_tensor_B = torch.rand((8, 8), device="mps")
        tensor_A = autocast_tensor_A.clone().detach()
        tensor_B = autocast_tensor_B.clone().detach()
        autocast_output_tensor = torch.empty(8, 8)
        output_tensor = autocast_output_tensor.clone().detach()

        with torch.autocast(device_type="mps"):
            autocast_output_tensor = torch.mm(autocast_tensor_A, autocast_tensor_B)
            autocast_output_tensor = torch.mm(autocast_tensor_A, autocast_output_tensor)

        output_tensor = torch.mm(tensor_A, tensor_B)
        output_tensor = torch.mm(tensor_A, output_tensor)

        self.assertEqual(autocast_output_tensor.dtype, torch.float16, "Autocast output tensor was not expected type float16")
        self.assertEqual(autocast_output_tensor,
                         output_tensor.to(torch.float16),
                         f"Autocast & non-autocast tensors did not match, \
got:\n{autocast_output_tensor} \n{output_tensor.to(torch.float16)}")


# Expand TestCase class with Memory Leak Detection on MPS device
class TestCaseMPS(TestCase):
    _do_mps_memory_leak_check = True

View File

@@ -322,6 +322,15 @@ class autocast:
                raise RuntimeError(
                    "Current CUDA Device does not support bfloat16. Please switch dtype to float16."
                )
        elif self.device == "mps":
            supported_dtype = [torch.float16]
            if self.fast_dtype not in supported_dtype:
                error_message = "In MPS autocast, but the target dtype is not supported. Disabling autocast.\n"
                error_message += (
                    "MPS Autocast only supports dtype of torch.float16 currently."
                )
                warnings.warn(error_message)
                enabled = False
        elif self.device == "xla":
            supported_dtype = [torch.float16, torch.bfloat16]
            if self.fast_dtype not in supported_dtype:
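Note the behavior of the MPS branch added above: an unsupported dtype does not raise, it warns and disables autocast for the region. A sketch, reflecting the state of this PR where only float16 is supported on MPS:

import torch

a = torch.randn(8, 8, device="mps")
b = torch.randn(8, 8, device="mps")
with torch.autocast(device_type="mps", dtype=torch.bfloat16):  # warns, autocast disabled
    print(torch.mm(a, b).dtype)                                # torch.float32
with torch.autocast(device_type="mps", dtype=torch.float16):   # supported
    print(torch.mm(a, b).dtype)                                # torch.float16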

View File

@@ -108,7 +108,7 @@ std::optional<AutocastScope> parseAutocast(
TORCH_CHECK(
dtype != c10::ScalarType::Undefined,
"Autocast has invalid fast_dtype attribute");
if (device == "cuda") {
if (device == "cuda" || device == "mps") {
scope.context.gpu_enabled = enabled.value();
scope.context.gpu_scalar_type = dtype;
} else if (device == "cpu") {

View File

@@ -726,6 +726,7 @@ void initDispatchBindings(PyObject* module) {
DEF_ONE(PreDispatch)
DEF_ONE(Functionalize)
DEF_ONE(AutocastCPU)
DEF_ONE(AutocastMPS)
DEF_ONE(AutocastXPU)
DEF_ONE(AutocastHPU)
DEF_ONE(AutocastIPU)