mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
[MPS] Add native_dropout and native_dropout_backward (#162108)
Fixes #162002
Pull Request resolved: https://github.com/pytorch/pytorch/pull/162108
Approved by: https://github.com/malfet
This commit is contained in:
committed by
PyTorch MergeBot
parent
e025c0f459
commit
583bbf7761
@ -39,6 +39,13 @@ struct lerp_alpha_functor {
|
||||
}
|
||||
};
|
||||
|
||||
// Element-wise functor used for dropout: multiplies an input element by its
// mask element and by a scalar `scale`.
//   TI - element type shared by both operands `a` and `b`
//   TA - type of `scale`; also the type the product is computed and returned in
struct native_dropout_mask_and_scale_functor {
  template <typename TI, typename TA>
  inline TA operator()(const TI a, const TI b, const TA scale) {
    // Convert both operands to the scale's type first, then multiply in the
    // same left-to-right order: (a * b) * scale.
    const TA value = static_cast<TA>(a);
    const TA keep = static_cast<TA>(b);
    return value * keep * scale;
  }
};
|
||||
|
||||
struct fmax_functor {
|
||||
template <typename T>
|
||||
inline T operator()(const T a, const T b) {
|
||||
@ -427,6 +434,10 @@ REGISTER_BINARY_ALPHA_OP(lerp_alpha, uchar, uchar, uchar);
|
||||
REGISTER_BINARY_ALPHA_OP(lerp_alpha, char, char, char);
|
||||
REGISTER_BINARY_ALPHA_OP(lerp_alpha, bool, bool, bool);
|
||||
|
||||
REGISTER_BINARY_ALPHA_OP(native_dropout_mask_and_scale, float, float, float);
|
||||
REGISTER_BINARY_ALPHA_OP(native_dropout_mask_and_scale, bfloat, bfloat, bfloat);
|
||||
REGISTER_BINARY_ALPHA_OP(native_dropout_mask_and_scale, half, half, half);
|
||||
|
||||
REGISTER_BINARY_ALPHA_OP(add_alpha, bfloat, bfloat, bfloat);
|
||||
REGISTER_BINARY_ALPHA_OP(sub_alpha, bfloat, bfloat, bfloat);
|
||||
REGISTER_BINARY_ALPHA_OP(lerp_alpha, bfloat, bfloat, bfloat);
|
||||
|
@ -168,6 +168,10 @@ static void lerp_scalar_mps_kernel(at::TensorIteratorBase& iter, const Scalar& w
|
||||
lib.exec_binary_kernel(iter, "lerp_alpha", weight);
|
||||
}
|
||||
|
||||
// Runs the binary kernel registered as "native_dropout_mask_and_scale" over
// `iter` via `lib.exec_binary_kernel`, forwarding `scale` as the extra scalar
// argument.
static void native_dropout_mask_and_scale_mps_kernel(at::TensorIteratorBase& iter, const Scalar& scale) {
  static constexpr const char* kKernelName = "native_dropout_mask_and_scale";
  lib.exec_binary_kernel(iter, kKernelName, scale);
}
|
||||
|
||||
// Runs the binary kernel registered as "mul" over `iter` via
// `lib.exec_binary_kernel`; no extra scalar argument is passed.
static void mul_mps_kernel(TensorIteratorBase& iter) {
  static constexpr const char* kKernelName = "mul";
  lib.exec_binary_kernel(iter, kKernelName);
}
|
||||
|
45
aten/src/ATen/native/mps/operations/Dropout.mm
Normal file
45
aten/src/ATen/native/mps/operations/Dropout.mm
Normal file
@ -0,0 +1,45 @@
|
||||
#define TORCH_ASSERT_ONLY_METHOD_OPERATORS
|
||||
#include <ATen/TensorOperators.h>
|
||||
#include <ATen/mps/MPSGeneratorImpl.h>
|
||||
#include <ATen/native/Distributions.h>
|
||||
#include <ATen/native/mps/OperationUtils.h>
|
||||
#include <ATen/native/mps/operations/BinaryKernel.h>
|
||||
|
||||
#ifndef AT_PER_OPERATOR_HEADERS
|
||||
#include <ATen/Functions.h>
|
||||
#include <ATen/NativeFunctions.h>
|
||||
#else
|
||||
#include <ATen/ops/bernoulli.h>
|
||||
#include <ATen/ops/empty_like.h>
|
||||
#include <ATen/ops/native_dropout_backward_native.h>
|
||||
#include <ATen/ops/native_dropout_native.h>
|
||||
#include <ATen/ops/ones_like.h>
|
||||
#endif
|
||||
|
||||
namespace at::native {
|
||||
|
||||
// Shared helper for the dropout forward and backward paths.
// Allocates a tensor like `input` and fills it by invoking the
// "native_dropout_mask_and_scale" binary kernel on (`input`, `mask`) with
// `scale` as the scalar argument. Returns the newly allocated tensor.
static Tensor native_dropout_mask_and_scale(const Tensor& input, const Tensor& mask, float scale) {
  Tensor result = at::empty_like(input);
  mps::binary_op_kernel("native_dropout_mask_and_scale", input, mask, result, scale);
  return result;
}
|
||||
|
||||
std::tuple<Tensor, Tensor> native_dropout_mps(const Tensor& input, double p, std::optional<bool> train) {
|
||||
if (input.numel() == 0 || !train.value_or(false) || p == 0) {
|
||||
return {input.clone(), at::ones_like(input, input.options().dtype(c10::kBool))};
|
||||
}
|
||||
|
||||
float p_comp = 1.0f - p;
|
||||
Tensor mask = at::empty_like(input, input.options().dtype(c10::kBool));
|
||||
mask.bernoulli_(p_comp);
|
||||
auto scale = p_comp == 0 ? 0.0f : 1.0f / p_comp;
|
||||
Tensor output = native_dropout_mask_and_scale(input, mask, scale);
|
||||
return {std::move(output), std::move(mask)};
|
||||
}
|
||||
|
||||
// MPS implementation of native_dropout_backward.
//
// Applies the forward helper to the incoming gradient:
// native_dropout_mask_and_scale(grad, mask, scale). A gradient whose dtype is
// not a floating type is first converted to kFloat; `scale` is narrowed to
// float because that is what the helper accepts.
Tensor native_dropout_backward_mps(const Tensor& grad, const Tensor& mask, double scale) {
  Tensor grad_fp = grad;
  if (!isFloatingType(grad.scalar_type())) {
    grad_fp = grad.to(c10::kFloat);
  }
  return native_dropout_mask_and_scale(grad_fp, mask, static_cast<float>(scale));
}
|
||||
|
||||
} // namespace at::native
|
@ -288,6 +288,7 @@
|
||||
dispatch:
|
||||
CPU: native_dropout_cpu
|
||||
CUDA: native_dropout_cuda
|
||||
MPS: native_dropout_mps
|
||||
NestedTensorCPU, NestedTensorHPU, NestedTensorCUDA: native_dropout_nested
|
||||
tags: [nondeterministic_seeded, core]
|
||||
autogen: native_dropout.out
|
||||
@ -296,6 +297,7 @@
|
||||
dispatch:
|
||||
CPU, NestedTensorCPU, NestedTensorHPU, NestedTensorCUDA: native_dropout_backward
|
||||
CUDA: native_dropout_backward_cuda
|
||||
MPS: native_dropout_backward_mps
|
||||
autogen: native_dropout_backward.out
|
||||
tags: pointwise
|
||||
|
||||
|
@ -7524,6 +7524,39 @@ class TestMPS(TestCaseMPS):
|
||||
uniq = mps_out.unique()
|
||||
self.assertEqual(uniq, torch.arange(2, device='mps', dtype=dtype))
|
||||
|
||||
@parametrize("dtype", [torch.float16, torch.bfloat16, torch.float32])
def test_dropout(self, dtype):
    # Exercises torch.native_dropout on MPS, forward and backward, over
    # several shapes and drop probabilities, in both eval and train mode.
    shapes = [
        (100_000,),
        (100, 1000),
        (10, 100, 100),
        (10, 10, 10, 10, 10),
    ]
    drop_probs = [0, 0.34, 0.78, 1]

    for shape in shapes:
        for p in drop_probs:
            for train in (False, True):
                x = torch.randn(shape, device='mps', dtype=dtype, requires_grad=True)
                out, keep_mask = torch.native_dropout(x, p, train=train)

                # Fraction of elements the mask marks as dropped.
                dropped_fraction = 1 - (keep_mask.sum() / keep_mask.numel())
                if not train:
                    # Eval mode: identity output and an all-true mask.
                    self.assertEqual(out, x)
                    self.assertTrue(keep_mask.all())
                else:
                    # Train mode: the empirical drop rate is close to p, dropped
                    # positions are zero, kept positions are scaled by 1 / (1 - p).
                    self.assertEqual(dropped_fraction, p, atol=1e-2, rtol=1e-2)
                    self.assertTrue((out[keep_mask.logical_not()] == 0).all())
                    self.assertEqual(out[keep_mask], x[keep_mask] / (1 - p))

                upstream_grad = torch.randn_like(out)
                out.backward(upstream_grad)

                # With p == 1 everything is dropped, so the gradient scale is 0.
                grad_scale = 0 if p == 1 else 1 / (1 - p)
                expected_grad = upstream_grad * keep_mask * grad_scale if train else upstream_grad
                self.assertEqual(x.grad, expected_grad)
|
||||
|
||||
|
||||
def test_mps_generator(self):
|
||||
# explicit manual seeding by creating an MPS Generator
|
||||
g_mps = torch.Generator(device='mps')
|
||||
|
@ -78,6 +78,7 @@ AOTI_TORCH_EXPORT AOTITorchError aoti_torch_mps_mul_Scalar(AtenTensorHandle self
|
||||
AOTI_TORCH_EXPORT AOTITorchError aoti_torch_mps_mul_Tensor(AtenTensorHandle self, AtenTensorHandle other, AtenTensorHandle* ret0);
|
||||
AOTI_TORCH_EXPORT AOTITorchError aoti_torch_mps_nanmedian(AtenTensorHandle self, AtenTensorHandle* ret0);
|
||||
AOTI_TORCH_EXPORT AOTITorchError aoti_torch_mps_narrow(AtenTensorHandle self, int64_t dim, int64_t start, int64_t length, AtenTensorHandle* ret0);
|
||||
AOTI_TORCH_EXPORT AOTITorchError aoti_torch_mps_native_dropout(AtenTensorHandle input, double p, int32_t* train, AtenTensorHandle* ret0, AtenTensorHandle* ret1);
|
||||
AOTI_TORCH_EXPORT AOTITorchError aoti_torch_mps_nonzero(AtenTensorHandle self, AtenTensorHandle* ret0);
|
||||
AOTI_TORCH_EXPORT AOTITorchError aoti_torch_mps_normal_functional(AtenTensorHandle self, double mean, double std, AtenGeneratorHandle* generator, AtenTensorHandle* ret0);
|
||||
AOTI_TORCH_EXPORT AOTITorchError aoti_torch_mps_pad(AtenTensorHandle self, const int64_t* pad, int64_t pad_len_, const char* mode, double* value, AtenTensorHandle* ret0);
|
||||
|
@ -340,7 +340,6 @@ if torch.backends.mps.is_available():
|
||||
"masked.median": None,
|
||||
"matrix_exp": None,
|
||||
"mode": None,
|
||||
"native_dropout_backward": None,
|
||||
"normnuc": None,
|
||||
"nn.functional.fractional_max_pool2d": None,
|
||||
"nn.functional.fractional_max_pool3d": None,
|
||||
|
Reference in New Issue
Block a user