[MPS] Fixes GELU, LeakyReLU and Mish on non-contiguous tensors (#123049)

Fixes the GELU, LeakyReLU and Mish activation functions on non-contiguous tensors (for instance, when a transpose operation was applied to the tensor prior to the MPS operator), for both the forward and backward passes. I also extended the tests on the 3 activation functions to check: full precision and half precision, contiguous and non-contiguous inputs, and several tensor dims: scalars, 1D, empty, 2D, > 3D. I had issues with Mish and GELU when asserting the gradients vs. CPU with sum() in some cases, so I reverted to the previous setup of passing an explicit gradient to .backward(). This PR also fixes an issue with LeakyReLU on empty tensors.

Fixes #98212, huggingface/transformers#22468, huggingface/transformers#19353

Pull Request resolved: https://github.com/pytorch/pytorch/pull/123049
Approved by: https://github.com/kulinseth
Committed by: PyTorch MergeBot
Parent: 98f3e0214b
Commit: a6a3f2e06b
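The failure mode being fixed: the MPS kernels below used to read their input as if it were contiguous, so a tensor that had been transposed (a non-contiguous view, not a copy) produced wrong values in both the forward and backward passes. A minimal repro sketch in the spirit of the new tests; the helper name, shape, and dtype here are illustrative assumptions, not code from the commit:

import torch

# Minimal repro sketch (assumed names/shapes, not part of the commit).
# Before this fix, MPS GELU/LeakyReLU/Mish returned wrong values for
# non-contiguous inputs such as transposed views.
def check_gelu_non_contiguous(shape=(5, 4), dtype=torch.float32):
    cpu_x = torch.randn(shape, dtype=dtype)
    mps_x = cpu_x.detach().clone().to("mps")

    # transpose() returns a non-contiguous view without copying data
    cpu_x = cpu_x.transpose(0, 1)
    mps_x = mps_x.transpose(0, 1)
    assert not mps_x.is_contiguous()

    cpu_x.requires_grad_()
    mps_x.requires_grad_()

    cpu_out = torch.nn.GELU()(cpu_x)
    mps_out = torch.nn.GELU()(mps_x)
    torch.testing.assert_close(cpu_out, mps_out.cpu())  # forward pass

    # Backward pass: pass an explicit gradient, as the updated tests do
    grad = torch.ones_like(cpu_out)
    cpu_out.backward(gradient=grad)
    mps_out.backward(gradient=grad.to("mps"))
    torch.testing.assert_close(cpu_x.grad, mps_x.grad.cpu())  # backward pass

if torch.backends.mps.is_available():
    check_gelu_non_contiguous()

On an affected build the forward comparison would already fail for the transposed MPS tensor; with the gather path added in the kernels below, both the forward and backward checks are expected to pass.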
@@ -131,8 +131,17 @@ TORCH_IMPL_FUNC(leaky_relu_out_mps)(const Tensor& self, const Scalar& negative_s
   using CachedGraph = MPSUnaryCachedGraph;
   TORCH_CHECK(output.is_mps());
 
+  if (self.numel() == 0) {
+    return;
+  }
+
   MPSStream* stream = getCurrentMPSStream();
 
+  bool executeGatherOp =
+      !(self.is_contiguous(MemoryFormat::Contiguous) || self.is_contiguous(MemoryFormat::ChannelsLast) ||
+        self.is_contiguous(MemoryFormat::ChannelsLast3d));
+  Tensor output_ = at::empty_like(self, executeGatherOp ? MemoryFormat::Contiguous : MemoryFormat::Preserve);
+
   @autoreleasepool {
     string key = "leaky_relu" + getTensorsStringKey({self}) + ":" + to_string(negative_slope.to<double>());
     auto cachedGraph = LookUpOrCreateCachedGraph<CachedGraph>(key, [&](auto mpsGraph, auto newCachedGraph) {
@@ -152,13 +161,17 @@ TORCH_IMPL_FUNC(leaky_relu_out_mps)(const Tensor& self, const Scalar& negative_s
       newCachedGraph->outputTensor_ = outputTensor;
     });
 
-    Placeholder selfPlaceholder = Placeholder(cachedGraph->inputTensor_, self);
-    Placeholder outputPlaceholder = Placeholder(cachedGraph->outputTensor_, output);
+    Placeholder selfPlaceholder = Placeholder(cachedGraph->inputTensor_, self, nil, executeGatherOp);
+    Placeholder outputPlaceholder =
+        Placeholder(cachedGraph->outputTensor_, executeGatherOp ? output_ : output, nil, false);
 
     // Create dictionary of inputs and outputs
     auto feeds = dictionaryFromPlaceholders(selfPlaceholder);
     runMPSGraph(stream, cachedGraph->graph(), feeds, outputPlaceholder);
   }
+  if (executeGatherOp) {
+    output.copy_(output_);
+  }
 }
 
 TORCH_IMPL_FUNC(leaky_relu_backward_out_mps)
@@ -171,8 +184,14 @@ TORCH_IMPL_FUNC(leaky_relu_backward_out_mps)
   using CachedGraph = MPSUnaryGradCachedGraph;
   TORCH_CHECK(output.is_mps());
 
+  if (self.numel() == 0) {
+    return;
+  }
+
   MPSStream* stream = getCurrentMPSStream();
 
+  Tensor output_ = at::empty_like(self, self.suggest_memory_format());
+
   @autoreleasepool {
     string key =
         "leaky_relu_backward" + getTensorsStringKey({self, grad_output}) + ":" + to_string(negative_slope.to<double>());
@@ -202,12 +221,13 @@ TORCH_IMPL_FUNC(leaky_relu_backward_out_mps)
 
     Placeholder selfPlaceholder = Placeholder(cachedGraph->inputTensor_, self);
     Placeholder gradOutputPlaceholder = Placeholder(cachedGraph->gradOutputTensor_, grad_output);
-    Placeholder outputPlaceholder = Placeholder(cachedGraph->gradInputTensor_, output);
+    Placeholder outputPlaceholder = Placeholder(cachedGraph->gradInputTensor_, output_);
 
     // Create dictionary of inputs and outputs
     auto feeds = dictionaryFromPlaceholders(gradOutputPlaceholder, selfPlaceholder);
     runMPSGraph(stream, cachedGraph->graph(), feeds, outputPlaceholder);
   }
+  output.copy_(output_);
 }
 
 TORCH_IMPL_FUNC(log_softmax_mps_out)
@@ -656,6 +676,11 @@ TORCH_IMPL_FUNC(gelu_out_mps)(const Tensor& self, c10::string_view approximate,
   auto approximate_type = get_gelutype_enum(approximate);
   MPSStream* stream = getCurrentMPSStream();
 
+  bool executeGatherOp =
+      !(self.is_contiguous(MemoryFormat::Contiguous) || self.is_contiguous(MemoryFormat::ChannelsLast) ||
+        self.is_contiguous(MemoryFormat::ChannelsLast3d));
+  Tensor output_ = at::empty_like(self, executeGatherOp ? MemoryFormat::Contiguous : MemoryFormat::Preserve);
+
   @autoreleasepool {
     const auto key = "gelu_out_mps" + getTensorsStringKey({self}) + ":" + gelutype_to_string(approximate_type);
     auto cachedGraph = LookUpOrCreateCachedGraph<CachedGraph>(key, [&](auto mpsGraph, auto newCachedGraph) {
@@ -672,12 +697,17 @@ TORCH_IMPL_FUNC(gelu_out_mps)(const Tensor& self, c10::string_view approximate,
       newCachedGraph->outputTensor_ = outputTensor;
     });
 
-    Placeholder selfPlaceholder = Placeholder(cachedGraph->inputTensor_, self);
-    Placeholder outputPlaceholder = Placeholder(cachedGraph->outputTensor_, output);
+    Placeholder selfPlaceholder = Placeholder(cachedGraph->inputTensor_, self, nil, executeGatherOp);
+    Placeholder outputPlaceholder =
+        Placeholder(cachedGraph->outputTensor_, executeGatherOp ? output_ : output, nil, false);
 
     auto feeds = dictionaryFromPlaceholders(selfPlaceholder);
     runMPSGraph(stream, cachedGraph->graph(), feeds, outputPlaceholder);
   }
+
+  if (executeGatherOp) {
+    output.copy_(output_);
+  }
 }
 
 TORCH_IMPL_FUNC(gelu_backward_out_mps)
@@ -686,8 +716,11 @@ TORCH_IMPL_FUNC(gelu_backward_out_mps)
   using CachedGraph = MPSUnaryGradCachedGraph;
 
   // Empty output
-  if (grad_input.numel() == 0)
+  if (self.numel() == 0) {
     return;
+  }
+
+  Tensor grad_input_ = at::empty_like(self, self.suggest_memory_format());
 
   auto approximate_type = get_gelutype_enum(approximate);
   MPSStream* stream = getCurrentMPSStream();
@@ -761,11 +794,12 @@ TORCH_IMPL_FUNC(gelu_backward_out_mps)
 
     Placeholder gradPlaceholder = Placeholder(cachedGraph->gradOutputTensor_, grad);
     Placeholder selfPlaceholder = Placeholder(cachedGraph->inputTensor_, self);
-    Placeholder outputPlaceholder = Placeholder(cachedGraph->gradInputTensor_, grad_input);
+    Placeholder outputPlaceholder = Placeholder(cachedGraph->gradInputTensor_, grad_input_);
 
     auto feeds = dictionaryFromPlaceholders(gradPlaceholder, selfPlaceholder);
     runMPSGraph(stream, cachedGraph->graph(), feeds, outputPlaceholder);
   }
+  grad_input.copy_(grad_input_);
 }
 
 static void elu_variants_out_mps(const Tensor& self,
@@ -1241,6 +1275,11 @@ TORCH_IMPL_FUNC(mish_out_mps)
 
   MPSStream* stream = getCurrentMPSStream();
 
+  bool executeGatherOp =
+      !(self.is_contiguous(MemoryFormat::Contiguous) || self.is_contiguous(MemoryFormat::ChannelsLast) ||
+        self.is_contiguous(MemoryFormat::ChannelsLast3d));
+  Tensor result_ = at::empty_like(self, executeGatherOp ? MemoryFormat::Contiguous : MemoryFormat::Preserve);
+
   @autoreleasepool {
     string key = "mish_out_mps:" + getTensorsStringKey({self});
 
@@ -1257,12 +1296,16 @@ TORCH_IMPL_FUNC(mish_out_mps)
       newCachedGraph->inputTensor_ = inputTensor;
       newCachedGraph->outputTensor_ = outputTensor;
     });
-    Placeholder selfPlaceholder = Placeholder(cachedGraph->inputTensor_, self);
-    Placeholder outputPlaceholder = Placeholder(cachedGraph->outputTensor_, result);
+    Placeholder selfPlaceholder = Placeholder(cachedGraph->inputTensor_, self, nil, executeGatherOp);
+    Placeholder outputPlaceholder =
+        Placeholder(cachedGraph->outputTensor_, executeGatherOp ? result_ : result, nil, false);
 
     auto feeds = dictionaryFromPlaceholders(selfPlaceholder);
     runMPSGraph(stream, cachedGraph->graph(), feeds, outputPlaceholder);
   }
+  if (executeGatherOp) {
+    result.copy_(result_);
+  }
 }
 
 Tensor mish_backward_mps(const Tensor& grad_output, const Tensor& self) {
@@ -1470,9 +1470,19 @@ class MPSLeakyReluTest(TestCaseMPS):
                                           0.9]]),
                           negative_slope=0.1))
 
-    def _testLeakyRelu(self, np_features, negative_slope, device):
-        cpu_x = torch.from_numpy(np_features).requires_grad_()
-        mps_x = torch.from_numpy(np_features).to('mps').requires_grad_()
+    def _testLeakyRelu(self, shape, dtype, negative_slope, contiguous):
+        cpu_x = torch.randn(shape, device='cpu', dtype=dtype)
+        mps_x = cpu_x.detach().clone().to('mps')
+
+        if not contiguous and not (0 in shape or len(shape) < 2):
+            # Transposing will make the tensor non-contiguous
+            cpu_x = cpu_x.transpose(0, 1)
+            mps_x = mps_x.transpose(0, 1)
+            assert not mps_x.is_contiguous()
+
+        cpu_x.requires_grad_()
+        mps_x.requires_grad_()
+
         relu_op = torch.nn.LeakyReLU(negative_slope)
 
         cpu_leaky_relu = relu_op(cpu_x)
@@ -1480,19 +1490,24 @@ class MPSLeakyReluTest(TestCaseMPS):
         torch.testing.assert_close(cpu_leaky_relu, mps_leaky_relu.to('cpu'))
 
         # test backward pass
 
         cpu_grad = torch.ones_like(cpu_leaky_relu)
         mps_grad = cpu_grad.to('mps')
-        cpu_leaky_relu.backward(gradient=cpu_grad)
+
         mps_leaky_relu.backward(gradient=mps_grad)
-        torch.testing.assert_close(cpu_x.grad, mps_x.grad.to('cpu'))
+        cpu_leaky_relu.backward(gradient=cpu_grad)
+
+        assert cpu_x.grad is not None  # Check that the grad is well-populated
+        self.assertEqual(cpu_x.grad, mps_x.grad)
 
     def testNumbersCPU(self):
-        for t in [np.float32]:
-            self._testLeakyRelu(
-                np.array([[-9, 7, -5, 3, -1], [1, -3, 5, -7, 9]]).astype(t),
-                negative_slope=0.2,
-                device="cpu")
+        for t in [torch.float, torch.half]:
+            for shape in [[], (0,), (0, 3), (4,), (4, 3), (5, 4, 3)]:
+                for contiguous in [True, False]:
+                    self._testLeakyRelu(shape,
+                                        dtype=t,
+                                        negative_slope=0.2,
+                                        contiguous=contiguous)
 
 class TestAvgPool(TestCaseMPS):
     def _sum_pool2d(self, x, kernel_size):
@@ -6631,9 +6646,18 @@ class TestMPS(TestCaseMPS):
                 helper((2, 16, 16), (4, 4), return_indices, dtype)
 
     def test_gelu_simple(self):
-        def helper(shape, dtype=torch.float):
-            cpu_x = torch.randn(shape, device='cpu', dtype=dtype, requires_grad=True)
-            x = cpu_x.detach().clone().to('mps').requires_grad_()
+        def helper(shape, dtype=torch.float, contiguous=True):
+            cpu_x = torch.randn(shape, device='cpu', dtype=dtype)
+            x = cpu_x.detach().clone().to('mps')
+
+            if not contiguous and (0 not in shape and len(shape) >= 2):
+                # Transposing will make the tensor non-contiguous
+                cpu_x = cpu_x.transpose(0, 1)
+                x = x.transpose(0, 1)
+                assert not x.is_contiguous()
+
+            cpu_x.requires_grad_()
+            x.requires_grad_()
 
             gelu_result = torch.nn.GELU()(x)
             # GELU is not supported on CPU, so cast it to float
@@ -6648,16 +6672,55 @@ class TestMPS(TestCaseMPS):
             atol = 1e-5 if dtype == torch.float else 1e-2
             rtol = 1e-3 if dtype == torch.float else 1e-2
             self.assertEqual(gelu_result, gelu_result_cpu.to(dtype), atol=atol, rtol=rtol)
+
+            assert x.grad is not None  # Check that the grad is well-populated
             self.assertEqual(x.grad, cpu_x.grad, atol=atol, rtol=rtol)
 
         # Test empty shape too
         for dtype in [torch.float, torch.half]:
-            for shape in [(0, 3), [], (2, 3), (2, 8, 4, 5)]:
-                helper(shape, dtype)
+            for shape in [[], (0,), (0, 3), (4,), (4, 3), (5, 4, 3)]:
+                for contiguous in [True, False]:
+                    helper(shape, dtype, contiguous)
        # Test that gelu would raise an assert for integral types
        for dtype in [torch.int8, torch.int16, torch.int32, torch.int64]:
            self.assertRaises(RuntimeError, lambda: torch.nn.GELU()(torch.randint(100, (2,), dtype=dtype, device="mps")))
 
+    def test_mish_simple(self):
+        def helper(shape, dtype=torch.float, contiguous=True):
+            cpu_x = torch.randn(shape, device='cpu', dtype=dtype)
+            x = cpu_x.detach().clone().to('mps')
+
+            if not contiguous and (0 not in shape and len(shape) >= 2):
+                # Transposing will make the tensor non-contiguous
+                cpu_x = cpu_x.transpose(0, 1)
+                x = x.transpose(0, 1)
+                assert not x.is_contiguous()
+
+            cpu_x.requires_grad_()
+            x.requires_grad_()
+
+            mish_result = torch.nn.Mish()(x)
+            mish_result_cpu = torch.nn.Mish()(cpu_x)
+
+            cpu_grad = torch.ones_like(mish_result_cpu)
+            grad = cpu_grad.to('mps')
+
+            mish_result.backward(gradient=grad)
+            mish_result_cpu.backward(gradient=cpu_grad)
+
+            atol = 1e-5 if dtype == torch.float else 1e-2
+            rtol = 1e-3 if dtype == torch.float else 1e-2
+            self.assertEqual(mish_result, mish_result_cpu.to(dtype), atol=atol, rtol=rtol)
+
+            assert x.grad is not None  # Check that the grad is well-populated
+            self.assertEqual(x.grad, cpu_x.grad, atol=atol, rtol=rtol)
+
+        # Test empty shape too
+        for dtype in [torch.float, torch.half]:
+            for shape in [[], (0,), (0, 3), (4,), (4, 3), (5, 4, 3)]:
+                for contiguous in [True, False]:
+                    helper(shape, dtype, contiguous)
+
     def test_gelu(self):
         def _test_gelu(n, m, dtype, contiguous, atol=None, rtol=None):
             numpy_dtype = {