[MPS] Fixes GELU, LeakyRELU and MISH on non-contiguous tensors (#123049)

Fixes the GELU, LeakyReLU and Mish activation functions on non-contiguous tensors (for instance, when a transpose operation was applied to the tensor prior to the MPS operator), for both the forward and backward passes.
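
For illustration, a minimal repro of the kind of case this addresses (shapes and tolerances are arbitrary, not taken from the tests, and an MPS-enabled build is assumed): transpose the inputs so the MPS tensor is non-contiguous, then compare the forward output and the gradient against the CPU result.

```python
import torch

# Arbitrary 2D shape; transposing makes both inputs non-contiguous.
cpu_x = torch.randn(5, 4, dtype=torch.float32)
mps_x = cpu_x.detach().clone().to("mps")
cpu_x, mps_x = cpu_x.transpose(0, 1), mps_x.transpose(0, 1)
assert not mps_x.is_contiguous()
cpu_x.requires_grad_()
mps_x.requires_grad_()

for act in (torch.nn.GELU(), torch.nn.Mish(), torch.nn.LeakyReLU(0.2)):
    cpu_y, mps_y = act(cpu_x), act(mps_x)
    torch.testing.assert_close(cpu_y, mps_y.cpu(), atol=1e-5, rtol=1e-3)            # forward
    cpu_y.backward(gradient=torch.ones_like(cpu_y))
    mps_y.backward(gradient=torch.ones_like(mps_y))
    torch.testing.assert_close(cpu_x.grad, mps_x.grad.cpu(), atol=1e-5, rtol=1e-3)  # backward
    cpu_x.grad, mps_x.grad = None, None  # reset accumulated grads between activations
```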

I also extended the tests for the three activation functions to check full precision and half precision, contiguous and non-contiguous inputs, and several tensor shapes: scalar, empty, 1D, 2D, and 3D and above.
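
Schematically, the extended coverage sweeps the following grid for each activation (a condensed sketch, not the exact test code; the shapes mirror the ones used in the updated tests below):

```python
import itertools
import torch

dtypes = (torch.float, torch.half)                    # full and half precision
shapes = ((), (0,), (0, 3), (4,), (4, 3), (5, 4, 3))  # scalar, empty, 1D, 2D, 3D

for dtype, shape, contiguous in itertools.product(dtypes, shapes, (True, False)):
    x = torch.randn(shape, device="mps", dtype=dtype)
    if not contiguous and 0 not in shape and len(shape) >= 2:
        x = x.transpose(0, 1)  # non-contiguous view over the same data
        assert not x.is_contiguous()
    # ...run the activation forward/backward here and compare against the CPU result
```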

In some cases I had issues with the Mish and GELU activations when asserting the gradients against the CPU with sum(), so I reverted to the previous setup of passing an explicit gradient argument to .backward().
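
In other words, instead of reducing the output to a scalar with sum() and backpropagating from that, the tests supply the upstream gradient explicitly (an illustrative snippet, not the exact test code):

```python
import torch

x = torch.randn(4, 3, device="mps", requires_grad=True)
y = torch.nn.GELU()(x)

# Avoided comparison setup: reduce to a scalar, then backprop.
# y.sum().backward()

# Setup kept in the tests: pass the upstream gradient explicitly.
y.backward(gradient=torch.ones_like(y))
```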
This PR also fixes an issue with LeakyReLU on empty tensors.
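
With the fix, the empty-tensor case simply produces empty outputs and gradients; something along these lines is expected to pass (a sketch, assuming an MPS-enabled build):

```python
import torch

x = torch.empty(0, 3, device="mps", requires_grad=True)
y = torch.nn.LeakyReLU(0.1)(x)
y.backward(gradient=torch.ones_like(y))
assert y.shape == (0, 3) and x.grad.shape == (0, 3)
```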

Fixes #98212 huggingface/transformers#22468 huggingface/transformers#19353
Pull Request resolved: https://github.com/pytorch/pytorch/pull/123049
Approved by: https://github.com/kulinseth
Joël Tang
2024-04-21 00:12:29 +00:00
committed by PyTorch MergeBot
parent 98f3e0214b
commit a6a3f2e06b
2 changed files with 131 additions and 25 deletions

aten/src/ATen/native/mps/operations/Activation.mm

@@ -131,8 +131,17 @@ TORCH_IMPL_FUNC(leaky_relu_out_mps)(const Tensor& self, const Scalar& negative_s
using CachedGraph = MPSUnaryCachedGraph;
TORCH_CHECK(output.is_mps());
if (self.numel() == 0) {
return;
}
MPSStream* stream = getCurrentMPSStream();
bool executeGatherOp =
!(self.is_contiguous(MemoryFormat::Contiguous) || self.is_contiguous(MemoryFormat::ChannelsLast) ||
self.is_contiguous(MemoryFormat::ChannelsLast3d));
Tensor output_ = at::empty_like(self, executeGatherOp ? MemoryFormat::Contiguous : MemoryFormat::Preserve);
@autoreleasepool {
string key = "leaky_relu" + getTensorsStringKey({self}) + ":" + to_string(negative_slope.to<double>());
auto cachedGraph = LookUpOrCreateCachedGraph<CachedGraph>(key, [&](auto mpsGraph, auto newCachedGraph) {
@@ -152,13 +161,17 @@ TORCH_IMPL_FUNC(leaky_relu_out_mps)(const Tensor& self, const Scalar& negative_s
newCachedGraph->outputTensor_ = outputTensor;
});
Placeholder selfPlaceholder = Placeholder(cachedGraph->inputTensor_, self);
Placeholder outputPlaceholder = Placeholder(cachedGraph->outputTensor_, output);
Placeholder selfPlaceholder = Placeholder(cachedGraph->inputTensor_, self, nil, executeGatherOp);
Placeholder outputPlaceholder =
Placeholder(cachedGraph->outputTensor_, executeGatherOp ? output_ : output, nil, false);
// Create dictionary of inputs and outputs
auto feeds = dictionaryFromPlaceholders(selfPlaceholder);
runMPSGraph(stream, cachedGraph->graph(), feeds, outputPlaceholder);
}
if (executeGatherOp) {
output.copy_(output_);
}
}
TORCH_IMPL_FUNC(leaky_relu_backward_out_mps)
@@ -171,8 +184,14 @@ TORCH_IMPL_FUNC(leaky_relu_backward_out_mps)
using CachedGraph = MPSUnaryGradCachedGraph;
TORCH_CHECK(output.is_mps());
if (self.numel() == 0) {
return;
}
MPSStream* stream = getCurrentMPSStream();
Tensor output_ = at::empty_like(self, self.suggest_memory_format());
@autoreleasepool {
string key =
"leaky_relu_backward" + getTensorsStringKey({self, grad_output}) + ":" + to_string(negative_slope.to<double>());
@@ -202,12 +221,13 @@ TORCH_IMPL_FUNC(leaky_relu_backward_out_mps)
Placeholder selfPlaceholder = Placeholder(cachedGraph->inputTensor_, self);
Placeholder gradOutputPlaceholder = Placeholder(cachedGraph->gradOutputTensor_, grad_output);
Placeholder outputPlaceholder = Placeholder(cachedGraph->gradInputTensor_, output);
Placeholder outputPlaceholder = Placeholder(cachedGraph->gradInputTensor_, output_);
// Create dictionary of inputs and outputs
auto feeds = dictionaryFromPlaceholders(gradOutputPlaceholder, selfPlaceholder);
runMPSGraph(stream, cachedGraph->graph(), feeds, outputPlaceholder);
}
output.copy_(output_);
}
TORCH_IMPL_FUNC(log_softmax_mps_out)
@@ -656,6 +676,11 @@ TORCH_IMPL_FUNC(gelu_out_mps)(const Tensor& self, c10::string_view approximate,
auto approximate_type = get_gelutype_enum(approximate);
MPSStream* stream = getCurrentMPSStream();
bool executeGatherOp =
!(self.is_contiguous(MemoryFormat::Contiguous) || self.is_contiguous(MemoryFormat::ChannelsLast) ||
self.is_contiguous(MemoryFormat::ChannelsLast3d));
Tensor output_ = at::empty_like(self, executeGatherOp ? MemoryFormat::Contiguous : MemoryFormat::Preserve);
@autoreleasepool {
const auto key = "gelu_out_mps" + getTensorsStringKey({self}) + ":" + gelutype_to_string(approximate_type);
auto cachedGraph = LookUpOrCreateCachedGraph<CachedGraph>(key, [&](auto mpsGraph, auto newCachedGraph) {
@@ -672,12 +697,17 @@ TORCH_IMPL_FUNC(gelu_out_mps)(const Tensor& self, c10::string_view approximate,
newCachedGraph->outputTensor_ = outputTensor;
});
Placeholder selfPlaceholder = Placeholder(cachedGraph->inputTensor_, self);
Placeholder outputPlaceholder = Placeholder(cachedGraph->outputTensor_, output);
Placeholder selfPlaceholder = Placeholder(cachedGraph->inputTensor_, self, nil, executeGatherOp);
Placeholder outputPlaceholder =
Placeholder(cachedGraph->outputTensor_, executeGatherOp ? output_ : output, nil, false);
auto feeds = dictionaryFromPlaceholders(selfPlaceholder);
runMPSGraph(stream, cachedGraph->graph(), feeds, outputPlaceholder);
}
if (executeGatherOp) {
output.copy_(output_);
}
}
TORCH_IMPL_FUNC(gelu_backward_out_mps)
@@ -686,8 +716,11 @@ TORCH_IMPL_FUNC(gelu_backward_out_mps)
using CachedGraph = MPSUnaryGradCachedGraph;
// Empty output
if (grad_input.numel() == 0)
if (self.numel() == 0) {
return;
}
Tensor grad_input_ = at::empty_like(self, self.suggest_memory_format());
auto approximate_type = get_gelutype_enum(approximate);
MPSStream* stream = getCurrentMPSStream();
@@ -761,11 +794,12 @@ TORCH_IMPL_FUNC(gelu_backward_out_mps)
Placeholder gradPlaceholder = Placeholder(cachedGraph->gradOutputTensor_, grad);
Placeholder selfPlaceholder = Placeholder(cachedGraph->inputTensor_, self);
Placeholder outputPlaceholder = Placeholder(cachedGraph->gradInputTensor_, grad_input);
Placeholder outputPlaceholder = Placeholder(cachedGraph->gradInputTensor_, grad_input_);
auto feeds = dictionaryFromPlaceholders(gradPlaceholder, selfPlaceholder);
runMPSGraph(stream, cachedGraph->graph(), feeds, outputPlaceholder);
}
grad_input.copy_(grad_input_);
}
static void elu_variants_out_mps(const Tensor& self,
@@ -1241,6 +1275,11 @@ TORCH_IMPL_FUNC(mish_out_mps)
MPSStream* stream = getCurrentMPSStream();
bool executeGatherOp =
!(self.is_contiguous(MemoryFormat::Contiguous) || self.is_contiguous(MemoryFormat::ChannelsLast) ||
self.is_contiguous(MemoryFormat::ChannelsLast3d));
Tensor result_ = at::empty_like(self, executeGatherOp ? MemoryFormat::Contiguous : MemoryFormat::Preserve);
@autoreleasepool {
string key = "mish_out_mps:" + getTensorsStringKey({self});
@@ -1257,12 +1296,16 @@ TORCH_IMPL_FUNC(mish_out_mps)
newCachedGraph->inputTensor_ = inputTensor;
newCachedGraph->outputTensor_ = outputTensor;
});
Placeholder selfPlaceholder = Placeholder(cachedGraph->inputTensor_, self);
Placeholder outputPlaceholder = Placeholder(cachedGraph->outputTensor_, result);
Placeholder selfPlaceholder = Placeholder(cachedGraph->inputTensor_, self, nil, executeGatherOp);
Placeholder outputPlaceholder =
Placeholder(cachedGraph->outputTensor_, executeGatherOp ? result_ : result, nil, false);
auto feeds = dictionaryFromPlaceholders(selfPlaceholder);
runMPSGraph(stream, cachedGraph->graph(), feeds, outputPlaceholder);
}
if (executeGatherOp) {
result.copy_(result_);
}
}
Tensor mish_backward_mps(const Tensor& grad_output, const Tensor& self) {

test/test_mps.py

@@ -1470,9 +1470,19 @@ class MPSLeakyReluTest(TestCaseMPS):
0.9]]),
negative_slope=0.1))
def _testLeakyRelu(self, np_features, negative_slope, device):
cpu_x = torch.from_numpy(np_features).requires_grad_()
mps_x = torch.from_numpy(np_features).to('mps').requires_grad_()
def _testLeakyRelu(self, shape, dtype, negative_slope, contiguous):
cpu_x = torch.randn(shape, device='cpu', dtype=dtype)
mps_x = cpu_x.detach().clone().to('mps')
if not contiguous and not (0 in shape or len(shape) < 2):
# Transposing will make the tensor non-contiguous
cpu_x = cpu_x.transpose(0, 1)
mps_x = mps_x.transpose(0, 1)
assert not mps_x.is_contiguous()
cpu_x.requires_grad_()
mps_x.requires_grad_()
relu_op = torch.nn.LeakyReLU(negative_slope)
cpu_leaky_relu = relu_op(cpu_x)
@@ -1480,19 +1490,24 @@ class MPSLeakyReluTest(TestCaseMPS):
torch.testing.assert_close(cpu_leaky_relu, mps_leaky_relu.to('cpu'))
# test backward pass
cpu_grad = torch.ones_like(cpu_leaky_relu)
mps_grad = cpu_grad.to('mps')
cpu_leaky_relu.backward(gradient=cpu_grad)
mps_leaky_relu.backward(gradient=mps_grad)
torch.testing.assert_close(cpu_x.grad, mps_x.grad.to('cpu'))
cpu_leaky_relu.backward(gradient=cpu_grad)
assert cpu_x.grad is not None # Check that the grad is well-populated
self.assertEqual(cpu_x.grad, mps_x.grad)
def testNumbersCPU(self):
for t in [np.float32]:
self._testLeakyRelu(
np.array([[-9, 7, -5, 3, -1], [1, -3, 5, -7, 9]]).astype(t),
for t in [torch.float, torch.half]:
for shape in [[], (0,), (0, 3), (4,), (4, 3), (5, 4, 3)]:
for contiguous in [True, False]:
self._testLeakyRelu(shape,
dtype=t,
negative_slope=0.2,
device="cpu")
contiguous=contiguous)
class TestAvgPool(TestCaseMPS):
def _sum_pool2d(self, x, kernel_size):
@@ -6631,9 +6646,18 @@ class TestMPS(TestCaseMPS):
helper((2, 16, 16), (4, 4), return_indices, dtype)
def test_gelu_simple(self):
def helper(shape, dtype=torch.float):
cpu_x = torch.randn(shape, device='cpu', dtype=dtype, requires_grad=True)
x = cpu_x.detach().clone().to('mps').requires_grad_()
def helper(shape, dtype=torch.float, contiguous=True):
cpu_x = torch.randn(shape, device='cpu', dtype=dtype)
x = cpu_x.detach().clone().to('mps')
if not contiguous and (0 not in shape and len(shape) >= 2):
# Transposing will make the tensor non-contiguous
cpu_x = cpu_x.transpose(0, 1)
x = x.transpose(0, 1)
assert not x.is_contiguous()
cpu_x.requires_grad_()
x.requires_grad_()
gelu_result = torch.nn.GELU()(x)
# GELU is not supported on CPU, so cast it to float
@@ -6648,16 +6672,55 @@ class TestMPS(TestCaseMPS):
atol = 1e-5 if dtype == torch.float else 1e-2
rtol = 1e-3 if dtype == torch.float else 1e-2
self.assertEqual(gelu_result, gelu_result_cpu.to(dtype), atol=atol, rtol=rtol)
assert x.grad is not None # Check that the grad is well-populated
self.assertEqual(x.grad, cpu_x.grad, atol=atol, rtol=rtol)
# Test empty shape too
for dtype in [torch.float, torch.half]:
for shape in [(0, 3), [], (2, 3), (2, 8, 4, 5)]:
helper(shape, dtype)
for shape in [[], (0,), (0, 3), (4,), (4, 3), (5, 4, 3)]:
for contiguous in [True, False]:
helper(shape, dtype, contiguous)
# Test that gelu would raise an assert for integral types
for dtype in [torch.int8, torch.int16, torch.int32, torch.int64]:
self.assertRaises(RuntimeError, lambda: torch.nn.GELU()(torch.randint(100, (2,), dtype=dtype, device="mps")))
def test_mish_simple(self):
def helper(shape, dtype=torch.float, contiguous=True):
cpu_x = torch.randn(shape, device='cpu', dtype=dtype)
x = cpu_x.detach().clone().to('mps')
if not contiguous and (0 not in shape and len(shape) >= 2):
# Transposing will make the tensor non-contiguous
cpu_x = cpu_x.transpose(0, 1)
x = x.transpose(0, 1)
assert not x.is_contiguous()
cpu_x.requires_grad_()
x.requires_grad_()
mish_result = torch.nn.Mish()(x)
mish_result_cpu = torch.nn.Mish()(cpu_x)
cpu_grad = torch.ones_like(mish_result_cpu)
grad = cpu_grad.to('mps')
mish_result.backward(gradient=grad)
mish_result_cpu.backward(gradient=cpu_grad)
atol = 1e-5 if dtype == torch.float else 1e-2
rtol = 1e-3 if dtype == torch.float else 1e-2
self.assertEqual(mish_result, mish_result_cpu.to(dtype), atol=atol, rtol=rtol)
assert x.grad is not None # Check that the grad is well-populated
self.assertEqual(x.grad, cpu_x.grad, atol=atol, rtol=rtol)
# Test empty shape too
for dtype in [torch.float, torch.half]:
for shape in [[], (0,), (0, 3), (4,), (4, 3), (5, 4, 3)]:
for contiguous in [True, False]:
helper(shape, dtype, contiguous)
def test_gelu(self):
def _test_gelu(n, m, dtype, contiguous, atol=None, rtol=None):
numpy_dtype = {