[MPS] Fix relu for 0-element input case (#133191)

Fixes #133182

Should already be tested by `test/test_mps.py::MPSReluTest::testNumbersGPU`.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/133191
Approved by: https://github.com/albanD
This commit is contained in:
Li-Huai (Allan) Lin
2024-08-11 15:24:16 -07:00
committed by PyTorch MergeBot
parent 666362865c
commit cc1cc71c46
2 changed files with 5 additions and 8 deletions

View File

@@ -49,17 +49,16 @@ Tensor relu_mps(const Tensor& self) {
using namespace mps;
using CachedGraph = MPSUnaryCachedGraph;
if (self.numel() == 0) {
return self;
}
MPSStream* stream = getCurrentMPSStream();
bool executeGatherOp =
!(self.is_contiguous(MemoryFormat::Contiguous) || self.is_contiguous(MemoryFormat::ChannelsLast) ||
self.is_contiguous(MemoryFormat::ChannelsLast3d));
Tensor output = at::empty_like(self, executeGatherOp ? MemoryFormat::Contiguous : MemoryFormat::Preserve);
if (output.numel() == 0) {
return output;
}
MPSStream* stream = getCurrentMPSStream();
@autoreleasepool {
string key = "relu" + getTensorsStringKey({self});
auto cachedGraph = LookUpOrCreateCachedGraph<CachedGraph>(key, [&](auto mpsGraph, auto newCachedGraph) {

View File

@@ -659,8 +659,6 @@ def mps_ops_modifier(ops):
UNIMPLEMENTED_XFAILLIST = {
# Failures due to lack of op implementation on MPS backend
'login': None,
'log_sigmoid': None,
'log_sigmoid_forward': None,
'linalg.eig': None,
'linalg.eigvals': None,
'put': None,