[MPS] Add native_dropout and native_dropout_backward (#162108)

Fixes #162002
Pull Request resolved: https://github.com/pytorch/pytorch/pull/162108
Approved by: https://github.com/malfet
This commit is contained in:
Kurt Mohler
2025-09-08 15:22:37 -05:00
committed by PyTorch MergeBot
parent e025c0f459
commit 583bbf7761
7 changed files with 96 additions and 1 deletions

View File

@ -39,6 +39,13 @@ struct lerp_alpha_functor {
}
};
// Elementwise functor used for dropout: multiplies a value by its mask entry
// and by a scale factor. Both inputs are converted to the scale's type `TA`
// before multiplying, and the product is returned as `TA`.
struct native_dropout_mask_and_scale_functor {
  template <typename TI, typename TA>
  inline TA operator()(const TI a, const TI b, const TA scale) {
    const TA value = static_cast<TA>(a);
    const TA keep = static_cast<TA>(b);
    return value * keep * scale;
  }
};
struct fmax_functor {
template <typename T>
inline T operator()(const T a, const T b) {
@ -427,6 +434,10 @@ REGISTER_BINARY_ALPHA_OP(lerp_alpha, uchar, uchar, uchar);
REGISTER_BINARY_ALPHA_OP(lerp_alpha, char, char, char);
REGISTER_BINARY_ALPHA_OP(lerp_alpha, bool, bool, bool);
// Dropout mask-and-scale kernel: in the registrations visible here it is
// instantiated only for the floating-point types float, bfloat and half.
REGISTER_BINARY_ALPHA_OP(native_dropout_mask_and_scale, float, float, float);
REGISTER_BINARY_ALPHA_OP(native_dropout_mask_and_scale, bfloat, bfloat, bfloat);
REGISTER_BINARY_ALPHA_OP(native_dropout_mask_and_scale, half, half, half);
REGISTER_BINARY_ALPHA_OP(add_alpha, bfloat, bfloat, bfloat);
REGISTER_BINARY_ALPHA_OP(sub_alpha, bfloat, bfloat, bfloat);
REGISTER_BINARY_ALPHA_OP(lerp_alpha, bfloat, bfloat, bfloat);

View File

@ -168,6 +168,10 @@ static void lerp_scalar_mps_kernel(at::TensorIteratorBase& iter, const Scalar& w
lib.exec_binary_kernel(iter, "lerp_alpha", weight);
}
// Runs the "native_dropout_mask_and_scale" binary kernel over `iter`,
// forwarding `scale` as the kernel's extra scalar argument.
static void native_dropout_mask_and_scale_mps_kernel(at::TensorIteratorBase& iter, const Scalar& scale) {
  const char* const kernel_name = "native_dropout_mask_and_scale";
  lib.exec_binary_kernel(iter, kernel_name, scale);
}
// Runs the "mul" binary kernel over `iter`; no extra scalar argument is passed.
static void mul_mps_kernel(TensorIteratorBase& iter) {
  const char* const kernel_name = "mul";
  lib.exec_binary_kernel(iter, kernel_name);
}

View File

@ -0,0 +1,45 @@
#define TORCH_ASSERT_ONLY_METHOD_OPERATORS
#include <ATen/TensorOperators.h>
#include <ATen/mps/MPSGeneratorImpl.h>
#include <ATen/native/Distributions.h>
#include <ATen/native/mps/OperationUtils.h>
#include <ATen/native/mps/operations/BinaryKernel.h>
#ifndef AT_PER_OPERATOR_HEADERS
#include <ATen/Functions.h>
#include <ATen/NativeFunctions.h>
#else
#include <ATen/ops/bernoulli.h>
#include <ATen/ops/empty_like.h>
#include <ATen/ops/native_dropout_backward_native.h>
#include <ATen/ops/native_dropout_native.h>
#include <ATen/ops/ones_like.h>
#endif
namespace at::native {
// Allocates a result tensor shaped like `input` and fills it by running the
// "native_dropout_mask_and_scale" binary kernel on (`input`, `mask`) with
// `scale` as the scalar argument. Returns the freshly allocated result.
static Tensor native_dropout_mask_and_scale(const Tensor& input, const Tensor& mask, float scale) {
  Tensor result = at::empty_like(input);
  mps::binary_op_kernel("native_dropout_mask_and_scale", input, mask, result, scale);
  return result;
}
// MPS implementation of native_dropout.
//
// Returns {output, mask}, where `mask` is a bool tensor shaped like `input`.
//
// Pass-through case: when `input` is empty, `p` is 0, or `train` is false or
// unset, the output is a clone of `input` and the mask is all ones.
// NOTE(review): an unset `train` is treated as "not training" here; confirm
// this matches the behaviour expected from the other backends.
//
// Otherwise each mask element is drawn from Bernoulli(1 - p), and the output
// is produced by the mask-and-scale helper with scale 1 / (1 - p). The scale
// is forced to 0 when p == 1 to avoid dividing by zero.
std::tuple<Tensor, Tensor> native_dropout_mps(const Tensor& input, double p, std::optional<bool> train) {
  const auto mask_options = input.options().dtype(c10::kBool);
  const bool is_training = train.value_or(false);

  if (!is_training || p == 0 || input.numel() == 0) {
    return {input.clone(), at::ones_like(input, mask_options)};
  }

  // Keep probability, deliberately held in single precision.
  const float keep_prob = static_cast<float>(1.0f - p);
  const float inv_keep_prob = (keep_prob != 0) ? 1.0f / keep_prob : 0.0f;

  Tensor mask = at::empty_like(input, mask_options);
  mask.bernoulli_(keep_prob);

  Tensor output = native_dropout_mask_and_scale(input, mask, inv_keep_prob);
  return {std::move(output), std::move(mask)};
}
// MPS implementation of native_dropout_backward: applies `mask` and `scale`
// to the incoming gradient via the mask-and-scale helper. A gradient whose
// dtype is not floating point is converted to float first, so the result is
// float in that case.
Tensor native_dropout_backward_mps(const Tensor& grad, const Tensor& mask, double scale) {
  // The helper takes a single-precision scale.
  const float scale_f = static_cast<float>(scale);
  if (isFloatingType(grad.scalar_type())) {
    return native_dropout_mask_and_scale(grad, mask, scale_f);
  }
  return native_dropout_mask_and_scale(grad.to(c10::kFloat), mask, scale_f);
}
} // namespace at::native

View File

@ -288,6 +288,7 @@
dispatch:
CPU: native_dropout_cpu
CUDA: native_dropout_cuda
MPS: native_dropout_mps
NestedTensorCPU, NestedTensorHPU, NestedTensorCUDA: native_dropout_nested
tags: [nondeterministic_seeded, core]
autogen: native_dropout.out
@ -296,6 +297,7 @@
dispatch:
CPU, NestedTensorCPU, NestedTensorHPU, NestedTensorCUDA: native_dropout_backward
CUDA: native_dropout_backward_cuda
MPS: native_dropout_backward_mps
autogen: native_dropout_backward.out
tags: pointwise

View File

@ -7524,6 +7524,39 @@ class TestMPS(TestCaseMPS):
uniq = mps_out.unique()
self.assertEqual(uniq, torch.arange(2, device='mps', dtype=dtype))
@parametrize("dtype", [torch.float16, torch.bfloat16, torch.float32])
def test_dropout(self, dtype):
    """Check torch.native_dropout forward and backward on the MPS device.

    For every combination of shape, drop probability and train flag:
      * train=False: the output equals the input and the mask is all True.
      * train=True: the observed drop fraction is close to p, dropped
        elements are zero and kept elements equal input / (1 - p).
      * backward: the input gradient equals the upstream gradient, masked
        and scaled when training (scale is 0 when p == 1).
    """
    tensor_shapes = [
        (100_000,),
        (100, 1000),
        (10, 100, 100),
        (10, 10, 10, 10, 10),
    ]
    drop_probs = [0, 0.34, 0.78, 1]
    train_flags = [False, True]

    for shape, p, train in itertools.product(tensor_shapes, drop_probs, train_flags):
        x = torch.randn(shape, device='mps', dtype=dtype, requires_grad=True)
        y, mask = torch.native_dropout(x, p, train=train)
        dropped_frac = 1 - (mask.sum() / mask.numel())

        # Forward checks.
        if not train:
            self.assertEqual(y, x)
            self.assertTrue(mask.all())
        else:
            self.assertEqual(dropped_frac, p, atol=1e-2, rtol=1e-2)
            self.assertTrue((y[~mask] == 0).all())
            self.assertEqual(y[mask], x[mask] / (1 - p))

        # Backward checks.
        grad_out = torch.randn_like(y)
        y.backward(grad_out)

        if not train:
            expected_grad = grad_out
        else:
            # p == 1 drops everything, so the gradient scale is defined as 0.
            scale = 1 / (1 - p) if p != 1 else 0
            expected_grad = grad_out * mask * scale
        self.assertEqual(x.grad, expected_grad)
def test_mps_generator(self):
# explicit manual seeding by creating an MPS Generator
g_mps = torch.Generator(device='mps')

View File

@ -78,6 +78,7 @@ AOTI_TORCH_EXPORT AOTITorchError aoti_torch_mps_mul_Scalar(AtenTensorHandle self
AOTI_TORCH_EXPORT AOTITorchError aoti_torch_mps_mul_Tensor(AtenTensorHandle self, AtenTensorHandle other, AtenTensorHandle* ret0);
AOTI_TORCH_EXPORT AOTITorchError aoti_torch_mps_nanmedian(AtenTensorHandle self, AtenTensorHandle* ret0);
AOTI_TORCH_EXPORT AOTITorchError aoti_torch_mps_narrow(AtenTensorHandle self, int64_t dim, int64_t start, int64_t length, AtenTensorHandle* ret0);
AOTI_TORCH_EXPORT AOTITorchError aoti_torch_mps_native_dropout(AtenTensorHandle input, double p, int32_t* train, AtenTensorHandle* ret0, AtenTensorHandle* ret1);
AOTI_TORCH_EXPORT AOTITorchError aoti_torch_mps_nonzero(AtenTensorHandle self, AtenTensorHandle* ret0);
AOTI_TORCH_EXPORT AOTITorchError aoti_torch_mps_normal_functional(AtenTensorHandle self, double mean, double std, AtenGeneratorHandle* generator, AtenTensorHandle* ret0);
AOTI_TORCH_EXPORT AOTITorchError aoti_torch_mps_pad(AtenTensorHandle self, const int64_t* pad, int64_t pad_len_, const char* mode, double* value, AtenTensorHandle* ret0);

View File

@ -340,7 +340,6 @@ if torch.backends.mps.is_available():
"masked.median": None,
"matrix_exp": None,
"mode": None,
"native_dropout_backward": None,
"normnuc": None,
"nn.functional.fractional_max_pool2d": None,
"nn.functional.fractional_max_pool3d": None,