mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
[MPS] Fix relu for 0-element input case (#133191)
Fixes #133182 Should already be tested by `test/test_mps.py::MPSReluTest::testNumbersGPU`. Pull Request resolved: https://github.com/pytorch/pytorch/pull/133191 Approved by: https://github.com/albanD
This commit is contained in:
committed by
PyTorch MergeBot
parent
666362865c
commit
cc1cc71c46
@ -49,17 +49,16 @@ Tensor relu_mps(const Tensor& self) {
|
||||
using namespace mps;
|
||||
using CachedGraph = MPSUnaryCachedGraph;
|
||||
|
||||
if (self.numel() == 0) {
|
||||
return self;
|
||||
}
|
||||
|
||||
MPSStream* stream = getCurrentMPSStream();
|
||||
|
||||
bool executeGatherOp =
|
||||
!(self.is_contiguous(MemoryFormat::Contiguous) || self.is_contiguous(MemoryFormat::ChannelsLast) ||
|
||||
self.is_contiguous(MemoryFormat::ChannelsLast3d));
|
||||
Tensor output = at::empty_like(self, executeGatherOp ? MemoryFormat::Contiguous : MemoryFormat::Preserve);
|
||||
|
||||
if (output.numel() == 0) {
|
||||
return output;
|
||||
}
|
||||
|
||||
MPSStream* stream = getCurrentMPSStream();
|
||||
@autoreleasepool {
|
||||
string key = "relu" + getTensorsStringKey({self});
|
||||
auto cachedGraph = LookUpOrCreateCachedGraph<CachedGraph>(key, [&](auto mpsGraph, auto newCachedGraph) {
|
||||
|
@ -659,8 +659,6 @@ def mps_ops_modifier(ops):
|
||||
UNIMPLEMENTED_XFAILLIST = {
|
||||
# Failures due to lack of op implementation on MPS backend
|
||||
'login': None,
|
||||
'log_sigmoid': None,
|
||||
'log_sigmoid_forward': None,
|
||||
'linalg.eig': None,
|
||||
'linalg.eigvals': None,
|
||||
'put': None,
|
||||
|
Reference in New Issue
Block a user