From 583bbf7761d7a1e7dd28ac3f73f6bb7d98018c67 Mon Sep 17 00:00:00 2001 From: Kurt Mohler Date: Mon, 8 Sep 2025 15:22:37 -0500 Subject: [PATCH] [MPS] Add `native_dropout` and `native_dropout_backward` (#162108) Fixes #162002 Pull Request resolved: https://github.com/pytorch/pytorch/pull/162108 Approved by: https://github.com/malfet --- .../native/mps/kernels/BinaryKernel.metal | 11 +++++ .../native/mps/operations/BinaryKernel.mm | 4 ++ .../src/ATen/native/mps/operations/Dropout.mm | 45 +++++++++++++++++++ aten/src/ATen/native/native_functions.yaml | 2 + test/test_mps.py | 33 ++++++++++++++ .../aoti_torch/generated/c_shim_mps.h | 1 + torch/testing/_internal/common_mps.py | 1 - 7 files changed, 96 insertions(+), 1 deletion(-) create mode 100644 aten/src/ATen/native/mps/operations/Dropout.mm diff --git a/aten/src/ATen/native/mps/kernels/BinaryKernel.metal b/aten/src/ATen/native/mps/kernels/BinaryKernel.metal index 5b908e7b882f..0539eab79500 100644 --- a/aten/src/ATen/native/mps/kernels/BinaryKernel.metal +++ b/aten/src/ATen/native/mps/kernels/BinaryKernel.metal @@ -39,6 +39,13 @@ struct lerp_alpha_functor { } }; +struct native_dropout_mask_and_scale_functor { + template + inline TA operator()(const TI a, const TI b, const TA scale) { + return static_cast(a) * static_cast(b) * scale; + } +}; + struct fmax_functor { template inline T operator()(const T a, const T b) { @@ -427,6 +434,10 @@ REGISTER_BINARY_ALPHA_OP(lerp_alpha, uchar, uchar, uchar); REGISTER_BINARY_ALPHA_OP(lerp_alpha, char, char, char); REGISTER_BINARY_ALPHA_OP(lerp_alpha, bool, bool, bool); +REGISTER_BINARY_ALPHA_OP(native_dropout_mask_and_scale, float, float, float); +REGISTER_BINARY_ALPHA_OP(native_dropout_mask_and_scale, bfloat, bfloat, bfloat); +REGISTER_BINARY_ALPHA_OP(native_dropout_mask_and_scale, half, half, half); + REGISTER_BINARY_ALPHA_OP(add_alpha, bfloat, bfloat, bfloat); REGISTER_BINARY_ALPHA_OP(sub_alpha, bfloat, bfloat, bfloat); REGISTER_BINARY_ALPHA_OP(lerp_alpha, bfloat, bfloat, bfloat); diff --git a/aten/src/ATen/native/mps/operations/BinaryKernel.mm b/aten/src/ATen/native/mps/operations/BinaryKernel.mm index 7d812648e3bf..0b303f48028f 100644 --- a/aten/src/ATen/native/mps/operations/BinaryKernel.mm +++ b/aten/src/ATen/native/mps/operations/BinaryKernel.mm @@ -168,6 +168,10 @@ static void lerp_scalar_mps_kernel(at::TensorIteratorBase& iter, const Scalar& w lib.exec_binary_kernel(iter, "lerp_alpha", weight); } +static void native_dropout_mask_and_scale_mps_kernel(at::TensorIteratorBase& iter, const Scalar& scale) { + lib.exec_binary_kernel(iter, "native_dropout_mask_and_scale", scale); +} + static void mul_mps_kernel(TensorIteratorBase& iter) { lib.exec_binary_kernel(iter, "mul"); } diff --git a/aten/src/ATen/native/mps/operations/Dropout.mm b/aten/src/ATen/native/mps/operations/Dropout.mm new file mode 100644 index 000000000000..116367d809eb --- /dev/null +++ b/aten/src/ATen/native/mps/operations/Dropout.mm @@ -0,0 +1,45 @@ +#define TORCH_ASSERT_ONLY_METHOD_OPERATORS +#include +#include +#include +#include +#include + +#ifndef AT_PER_OPERATOR_HEADERS +#include +#include +#else +#include +#include +#include +#include +#include +#endif + +namespace at::native { + +static Tensor native_dropout_mask_and_scale(const Tensor& input, const Tensor& mask, float scale) { + auto output = at::empty_like(input); + mps::binary_op_kernel("native_dropout_mask_and_scale", input, mask, output, scale); + return output; +} + +std::tuple native_dropout_mps(const Tensor& input, double p, std::optional train) { + if (input.numel() == 0 || !train.value_or(false) || p == 0) { + return {input.clone(), at::ones_like(input, input.options().dtype(c10::kBool))}; + } + + float p_comp = 1.0f - p; + Tensor mask = at::empty_like(input, input.options().dtype(c10::kBool)); + mask.bernoulli_(p_comp); + auto scale = p_comp == 0 ? 0.0f : 1.0f / p_comp; + Tensor output = native_dropout_mask_and_scale(input, mask, scale); + return {std::move(output), std::move(mask)}; +} + +Tensor native_dropout_backward_mps(const Tensor& grad, const Tensor& mask, double scale) { + auto grad_float = isFloatingType(grad.scalar_type()) ? grad : grad.to(c10::kFloat); + return native_dropout_mask_and_scale(grad_float, mask, scale); +} + +} // namespace at::native \ No newline at end of file diff --git a/aten/src/ATen/native/native_functions.yaml b/aten/src/ATen/native/native_functions.yaml index 027d3e1f2cb2..abb061afc5c9 100644 --- a/aten/src/ATen/native/native_functions.yaml +++ b/aten/src/ATen/native/native_functions.yaml @@ -288,6 +288,7 @@ dispatch: CPU: native_dropout_cpu CUDA: native_dropout_cuda + MPS: native_dropout_mps NestedTensorCPU, NestedTensorHPU, NestedTensorCUDA: native_dropout_nested tags: [nondeterministic_seeded, core] autogen: native_dropout.out @@ -296,6 +297,7 @@ dispatch: CPU, NestedTensorCPU, NestedTensorHPU, NestedTensorCUDA: native_dropout_backward CUDA: native_dropout_backward_cuda + MPS: native_dropout_backward_mps autogen: native_dropout_backward.out tags: pointwise diff --git a/test/test_mps.py b/test/test_mps.py index 36e9f96079ee..756b2cd20567 100644 --- a/test/test_mps.py +++ b/test/test_mps.py @@ -7524,6 +7524,39 @@ class TestMPS(TestCaseMPS): uniq = mps_out.unique() self.assertEqual(uniq, torch.arange(2, device='mps', dtype=dtype)) + @parametrize("dtype", [torch.float16, torch.bfloat16, torch.float32]) + def test_dropout(self, dtype): + shapes = [ + (100_000,), + (100, 1000), + (10, 100, 100), + (10, 10, 10, 10, 10), + ] + p_list = [0, 0.34, 0.78, 1] + + for shape, p, train in itertools.product(shapes, p_list, [False, True]): + input = torch.randn(shape, device='mps', dtype=dtype, requires_grad=True) + output, mask = torch.native_dropout(input, p, train=train) + + p_actual_mps = 1 - (mask.sum() / mask.numel()) + if train: + self.assertEqual(p_actual_mps, p, atol=1e-2, rtol=1e-2) + self.assertTrue((output[mask.logical_not()] == 0).all()) + self.assertEqual(output[mask], input[mask] / (1 - p)) + else: + self.assertEqual(output, input) + self.assertTrue(mask.all()) + + output_grad = torch.randn_like(output) + output.backward(output_grad) + + grad_scale = 0 if p == 1 else 1 / (1 - p) + if train: + self.assertEqual(input.grad, output_grad * mask * grad_scale) + else: + self.assertEqual(input.grad, output_grad) + + def test_mps_generator(self): # explicit manual seeding by creating an MPS Generator g_mps = torch.Generator(device='mps') diff --git a/torch/csrc/inductor/aoti_torch/generated/c_shim_mps.h b/torch/csrc/inductor/aoti_torch/generated/c_shim_mps.h index 56bd07115858..179c0074b3cd 100644 --- a/torch/csrc/inductor/aoti_torch/generated/c_shim_mps.h +++ b/torch/csrc/inductor/aoti_torch/generated/c_shim_mps.h @@ -78,6 +78,7 @@ AOTI_TORCH_EXPORT AOTITorchError aoti_torch_mps_mul_Scalar(AtenTensorHandle self AOTI_TORCH_EXPORT AOTITorchError aoti_torch_mps_mul_Tensor(AtenTensorHandle self, AtenTensorHandle other, AtenTensorHandle* ret0); AOTI_TORCH_EXPORT AOTITorchError aoti_torch_mps_nanmedian(AtenTensorHandle self, AtenTensorHandle* ret0); AOTI_TORCH_EXPORT AOTITorchError aoti_torch_mps_narrow(AtenTensorHandle self, int64_t dim, int64_t start, int64_t length, AtenTensorHandle* ret0); +AOTI_TORCH_EXPORT AOTITorchError aoti_torch_mps_native_dropout(AtenTensorHandle input, double p, int32_t* train, AtenTensorHandle* ret0, AtenTensorHandle* ret1); AOTI_TORCH_EXPORT AOTITorchError aoti_torch_mps_nonzero(AtenTensorHandle self, AtenTensorHandle* ret0); AOTI_TORCH_EXPORT AOTITorchError aoti_torch_mps_normal_functional(AtenTensorHandle self, double mean, double std, AtenGeneratorHandle* generator, AtenTensorHandle* ret0); AOTI_TORCH_EXPORT AOTITorchError aoti_torch_mps_pad(AtenTensorHandle self, const int64_t* pad, int64_t pad_len_, const char* mode, double* value, AtenTensorHandle* ret0); diff --git a/torch/testing/_internal/common_mps.py b/torch/testing/_internal/common_mps.py index cc5d63582c69..ea07fd3c0514 100644 --- a/torch/testing/_internal/common_mps.py +++ b/torch/testing/_internal/common_mps.py @@ -340,7 +340,6 @@ if torch.backends.mps.is_available(): "masked.median": None, "matrix_exp": None, "mode": None, - "native_dropout_backward": None, "normnuc": None, "nn.functional.fractional_max_pool2d": None, "nn.functional.fractional_max_pool3d": None,