diff --git a/aten/src/ATen/native/mps/operations/Activation.mm b/aten/src/ATen/native/mps/operations/Activation.mm
index c5e610223215..e8491e565be7 100644
--- a/aten/src/ATen/native/mps/operations/Activation.mm
+++ b/aten/src/ATen/native/mps/operations/Activation.mm
@@ -131,8 +131,17 @@ TORCH_IMPL_FUNC(leaky_relu_out_mps)(const Tensor& self, const Scalar& negative_s
   using CachedGraph = MPSUnaryCachedGraph;
   TORCH_CHECK(output.is_mps());
 
+  if (self.numel() == 0) {
+    return;
+  }
+
   MPSStream* stream = getCurrentMPSStream();
 
+  bool executeGatherOp =
+      !(self.is_contiguous(MemoryFormat::Contiguous) || self.is_contiguous(MemoryFormat::ChannelsLast) ||
+        self.is_contiguous(MemoryFormat::ChannelsLast3d));
+  Tensor output_ = at::empty_like(self, executeGatherOp ? MemoryFormat::Contiguous : MemoryFormat::Preserve);
+
   @autoreleasepool {
     string key = "leaky_relu" + getTensorsStringKey({self}) + ":" + to_string(negative_slope.to<double>());
     auto cachedGraph = LookUpOrCreateCachedGraph<CachedGraph>(key, [&](auto mpsGraph, auto newCachedGraph) {
@@ -152,13 +161,17 @@ TORCH_IMPL_FUNC(leaky_relu_out_mps)(const Tensor& self, const Scalar& negative_s
       newCachedGraph->outputTensor_ = outputTensor;
     });
 
-    Placeholder selfPlaceholder = Placeholder(cachedGraph->inputTensor_, self);
-    Placeholder outputPlaceholder = Placeholder(cachedGraph->outputTensor_, output);
+    Placeholder selfPlaceholder = Placeholder(cachedGraph->inputTensor_, self, nil, executeGatherOp);
+    Placeholder outputPlaceholder =
+        Placeholder(cachedGraph->outputTensor_, executeGatherOp ? output_ : output, nil, false);
 
     // Create dictionary of inputs and outputs
     auto feeds = dictionaryFromPlaceholders(selfPlaceholder);
     runMPSGraph(stream, cachedGraph->graph(), feeds, outputPlaceholder);
   }
+  if (executeGatherOp) {
+    output.copy_(output_);
+  }
 }
 
 TORCH_IMPL_FUNC(leaky_relu_backward_out_mps)
@@ -171,8 +184,14 @@ TORCH_IMPL_FUNC(leaky_relu_backward_out_mps)
   using CachedGraph = MPSUnaryGradCachedGraph;
   TORCH_CHECK(output.is_mps());
 
+  if (self.numel() == 0) {
+    return;
+  }
+
   MPSStream* stream = getCurrentMPSStream();
 
+  Tensor output_ = at::empty_like(self, self.suggest_memory_format());
+
   @autoreleasepool {
     string key = "leaky_relu_backward" + getTensorsStringKey({self, grad_output}) + ":" +
         to_string(negative_slope.to<double>());
@@ -202,12 +221,13 @@ TORCH_IMPL_FUNC(leaky_relu_backward_out_mps)
 
     Placeholder selfPlaceholder = Placeholder(cachedGraph->inputTensor_, self);
     Placeholder gradOutputPlaceholder = Placeholder(cachedGraph->gradOutputTensor_, grad_output);
-    Placeholder outputPlaceholder = Placeholder(cachedGraph->gradInputTensor_, output);
+    Placeholder outputPlaceholder = Placeholder(cachedGraph->gradInputTensor_, output_);
 
     // Create dictionary of inputs and outputs
     auto feeds = dictionaryFromPlaceholders(gradOutputPlaceholder, selfPlaceholder);
     runMPSGraph(stream, cachedGraph->graph(), feeds, outputPlaceholder);
   }
+  output.copy_(output_);
 }
 
 TORCH_IMPL_FUNC(log_softmax_mps_out)
@@ -656,6 +676,11 @@ TORCH_IMPL_FUNC(gelu_out_mps)(const Tensor& self, c10::string_view approximate,
   auto approximate_type = get_gelutype_enum(approximate);
   MPSStream* stream = getCurrentMPSStream();
 
+  bool executeGatherOp =
+      !(self.is_contiguous(MemoryFormat::Contiguous) || self.is_contiguous(MemoryFormat::ChannelsLast) ||
+        self.is_contiguous(MemoryFormat::ChannelsLast3d));
+  Tensor output_ = at::empty_like(self, executeGatherOp ? MemoryFormat::Contiguous : MemoryFormat::Preserve);
+
   @autoreleasepool {
     const auto key = "gelu_out_mps" + getTensorsStringKey({self}) + ":" + gelutype_to_string(approximate_type);
     auto cachedGraph = LookUpOrCreateCachedGraph<CachedGraph>(key, [&](auto mpsGraph, auto newCachedGraph) {
@@ -672,12 +697,17 @@ TORCH_IMPL_FUNC(gelu_out_mps)(const Tensor& self, c10::string_view approximate,
       newCachedGraph->outputTensor_ = outputTensor;
     });
 
-    Placeholder selfPlaceholder = Placeholder(cachedGraph->inputTensor_, self);
-    Placeholder outputPlaceholder = Placeholder(cachedGraph->outputTensor_, output);
+    Placeholder selfPlaceholder = Placeholder(cachedGraph->inputTensor_, self, nil, executeGatherOp);
+    Placeholder outputPlaceholder =
+        Placeholder(cachedGraph->outputTensor_, executeGatherOp ? output_ : output, nil, false);
 
     auto feeds = dictionaryFromPlaceholders(selfPlaceholder);
     runMPSGraph(stream, cachedGraph->graph(), feeds, outputPlaceholder);
   }
+
+  if (executeGatherOp) {
+    output.copy_(output_);
+  }
 }
 
 TORCH_IMPL_FUNC(gelu_backward_out_mps)
@@ -686,8 +716,11 @@ TORCH_IMPL_FUNC(gelu_backward_out_mps)
   using CachedGraph = MPSUnaryGradCachedGraph;
 
   // Empty output
-  if (grad_input.numel() == 0)
+  if (self.numel() == 0) {
     return;
+  }
+
+  Tensor grad_input_ = at::empty_like(self, self.suggest_memory_format());
 
   auto approximate_type = get_gelutype_enum(approximate);
   MPSStream* stream = getCurrentMPSStream();
@@ -761,11 +794,12 @@ TORCH_IMPL_FUNC(gelu_backward_out_mps)
 
     Placeholder gradPlaceholder = Placeholder(cachedGraph->gradOutputTensor_, grad);
     Placeholder selfPlaceholder = Placeholder(cachedGraph->inputTensor_, self);
-    Placeholder outputPlaceholder = Placeholder(cachedGraph->gradInputTensor_, grad_input);
+    Placeholder outputPlaceholder = Placeholder(cachedGraph->gradInputTensor_, grad_input_);
 
     auto feeds = dictionaryFromPlaceholders(gradPlaceholder, selfPlaceholder);
     runMPSGraph(stream, cachedGraph->graph(), feeds, outputPlaceholder);
   }
+  grad_input.copy_(grad_input_);
 }
 
 static void elu_variants_out_mps(const Tensor& self,
@@ -1241,6 +1275,11 @@ TORCH_IMPL_FUNC(mish_out_mps)
 
   MPSStream* stream = getCurrentMPSStream();
 
+  bool executeGatherOp =
+      !(self.is_contiguous(MemoryFormat::Contiguous) || self.is_contiguous(MemoryFormat::ChannelsLast) ||
+        self.is_contiguous(MemoryFormat::ChannelsLast3d));
+  Tensor result_ = at::empty_like(self, executeGatherOp ? MemoryFormat::Contiguous : MemoryFormat::Preserve);
+
   @autoreleasepool {
     string key = "mish_out_mps:" + getTensorsStringKey({self});
 
@@ -1257,12 +1296,16 @@ TORCH_IMPL_FUNC(mish_out_mps)
       newCachedGraph->inputTensor_ = inputTensor;
      newCachedGraph->outputTensor_ = outputTensor;
    });
-    Placeholder selfPlaceholder = Placeholder(cachedGraph->inputTensor_, self);
-    Placeholder outputPlaceholder = Placeholder(cachedGraph->outputTensor_, result);
+    Placeholder selfPlaceholder = Placeholder(cachedGraph->inputTensor_, self, nil, executeGatherOp);
+    Placeholder outputPlaceholder =
+        Placeholder(cachedGraph->outputTensor_, executeGatherOp ? result_ : result, nil, false);
 
     auto feeds = dictionaryFromPlaceholders(selfPlaceholder);
     runMPSGraph(stream, cachedGraph->graph(), feeds, outputPlaceholder);
   }
+  if (executeGatherOp) {
+    result.copy_(result_);
+  }
 }
 
 Tensor mish_backward_mps(const Tensor& grad_output, const Tensor& self) {
diff --git a/test/test_mps.py b/test/test_mps.py
index 862bda96c729..eb5e45aa0fed 100644
--- a/test/test_mps.py
+++ b/test/test_mps.py
@@ -1470,9 +1470,19 @@ class MPSLeakyReluTest(TestCaseMPS):
                                                      0.9]]),
                 negative_slope=0.1))
 
-    def _testLeakyRelu(self, np_features, negative_slope, device):
-        cpu_x = torch.from_numpy(np_features).requires_grad_()
-        mps_x = torch.from_numpy(np_features).to('mps').requires_grad_()
+    def _testLeakyRelu(self, shape, dtype, negative_slope, contiguous):
+        cpu_x = torch.randn(shape, device='cpu', dtype=dtype)
+        mps_x = cpu_x.detach().clone().to('mps')
+
+        if not contiguous and not (0 in shape or len(shape) < 2):
+            # Transposing will make the tensor non-contiguous
+            cpu_x = cpu_x.transpose(0, 1)
+            mps_x = mps_x.transpose(0, 1)
+            assert not mps_x.is_contiguous()
+
+        cpu_x.requires_grad_()
+        mps_x.requires_grad_()
+
         relu_op = torch.nn.LeakyReLU(negative_slope)
 
         cpu_leaky_relu = relu_op(cpu_x)
@@ -1480,19 +1490,24 @@ class MPSLeakyReluTest(TestCaseMPS):
         torch.testing.assert_close(cpu_leaky_relu, mps_leaky_relu.to('cpu'))
 
         # test backward pass
+
         cpu_grad = torch.ones_like(cpu_leaky_relu)
         mps_grad = cpu_grad.to('mps')
-        cpu_leaky_relu.backward(gradient=cpu_grad)
+
         mps_leaky_relu.backward(gradient=mps_grad)
-        torch.testing.assert_close(cpu_x.grad, mps_x.grad.to('cpu'))
+        cpu_leaky_relu.backward(gradient=cpu_grad)
+
+        assert cpu_x.grad is not None  # Check that the grad is well-populated
+        self.assertEqual(cpu_x.grad, mps_x.grad)
 
     def testNumbersCPU(self):
-        for t in [np.float32]:
-            self._testLeakyRelu(
-                np.array([[-9, 7, -5, 3, -1], [1, -3, 5, -7, 9]]).astype(t),
-                negative_slope=0.2,
-                device="cpu")
-
+        for t in [torch.float, torch.half]:
+            for shape in [[], (0,), (0, 3), (4,), (4, 3), (5, 4, 3)]:
+                for contiguous in [True, False]:
+                    self._testLeakyRelu(shape,
+                                        dtype=t,
+                                        negative_slope=0.2,
+                                        contiguous=contiguous)
 
 class TestAvgPool(TestCaseMPS):
     def _sum_pool2d(self, x, kernel_size):
@@ -6631,9 +6646,18 @@ class TestMPS(TestCaseMPS):
                 helper((2, 16, 16), (4, 4), return_indices, dtype)
 
     def test_gelu_simple(self):
-        def helper(shape, dtype=torch.float):
-            cpu_x = torch.randn(shape, device='cpu', dtype=dtype, requires_grad=True)
-            x = cpu_x.detach().clone().to('mps').requires_grad_()
+        def helper(shape, dtype=torch.float, contiguous=True):
+            cpu_x = torch.randn(shape, device='cpu', dtype=dtype)
+            x = cpu_x.detach().clone().to('mps')
+
+            if not contiguous and (0 not in shape and len(shape) >= 2):
+                # Transposing will make the tensor non-contiguous
+                cpu_x = cpu_x.transpose(0, 1)
+                x = x.transpose(0, 1)
+                assert not x.is_contiguous()
+
+            cpu_x.requires_grad_()
+            x.requires_grad_()
 
             gelu_result = torch.nn.GELU()(x)
             # GELU is not supported on CPU, so cast it to float
@@ -6648,16 +6672,55 @@ class TestMPS(TestCaseMPS):
             atol = 1e-5 if dtype == torch.float else 1e-2
             rtol = 1e-3 if dtype == torch.float else 1e-2
             self.assertEqual(gelu_result, gelu_result_cpu.to(dtype), atol=atol, rtol=rtol)
+
+            assert x.grad is not None  # Check that the grad is well-populated
             self.assertEqual(x.grad, cpu_x.grad, atol=atol, rtol=rtol)
 
         # Test empty shape too
         for dtype in [torch.float, torch.half]:
-            for shape in [(0, 3), [], (2, 3), (2, 8, 4, 5)]:
-                helper(shape, dtype)
+            for shape in [[], (0,), (0, 3), (4,), (4, 3), (5, 4, 3)]:
+                for contiguous in [True, False]:
+                    helper(shape, dtype, contiguous)
         # Test that gelu would raise an assert for integral types
         for dtype in [torch.int8, torch.int16, torch.int32, torch.int64]:
             self.assertRaises(RuntimeError, lambda: torch.nn.GELU()(torch.randint(100, (2,), dtype=dtype, device="mps")))
 
+    def test_mish_simple(self):
+        def helper(shape, dtype=torch.float, contiguous=True):
+            cpu_x = torch.randn(shape, device='cpu', dtype=dtype)
+            x = cpu_x.detach().clone().to('mps')
+
+            if not contiguous and (0 not in shape and len(shape) >= 2):
+                # Transposing will make the tensor non-contiguous
+                cpu_x = cpu_x.transpose(0, 1)
+                x = x.transpose(0, 1)
+                assert not x.is_contiguous()
+
+            cpu_x.requires_grad_()
+            x.requires_grad_()
+
+            mish_result = torch.nn.Mish()(x)
+            mish_result_cpu = torch.nn.Mish()(cpu_x)
+
+            cpu_grad = torch.ones_like(mish_result_cpu)
+            grad = cpu_grad.to('mps')
+
+            mish_result.backward(gradient=grad)
+            mish_result_cpu.backward(gradient=cpu_grad)
+
+            atol = 1e-5 if dtype == torch.float else 1e-2
+            rtol = 1e-3 if dtype == torch.float else 1e-2
+            self.assertEqual(mish_result, mish_result_cpu.to(dtype), atol=atol, rtol=rtol)
+
+            assert x.grad is not None  # Check that the grad is well-populated
+            self.assertEqual(x.grad, cpu_x.grad, atol=atol, rtol=rtol)
+
+        # Test empty shape too
+        for dtype in [torch.float, torch.half]:
+            for shape in [[], (0,), (0, 3), (4,), (4, 3), (5, 4, 3)]:
+                for contiguous in [True, False]:
+                    helper(shape, dtype, contiguous)
+
     def test_gelu(self):
         def _test_gelu(n, m, dtype, contiguous, atol=None, rtol=None):
             numpy_dtype = {