Update channel shuffle to return alias instead of self as-is (#99745)

Partially addresses https://github.com/pytorch/pytorch/issues/99655
Pull Request resolved: https://github.com/pytorch/pytorch/pull/99745
Approved by: https://github.com/albanD
This commit is contained in:
soulitzer
2023-04-22 15:29:17 -04:00
committed by PyTorch MergeBot
parent ab0a8215bb
commit 5ee5afb82c
2 changed files with 4 additions and 4 deletions

View File

@ -47,12 +47,12 @@ Tensor channel_shuffle(const Tensor& self, int64_t groups) {
#if defined(C10_MOBILE) && defined(USE_XNNPACK)
if (self.is_contiguous(MemoryFormat::ChannelsLast) &&
xnnpack::use_channel_shuffle(self, groups)) {
auto output = self.numel() == 0 ? self : xnnpack::channel_shuffle(self, groups);
auto output = self.numel() == 0 ? self.alias() : xnnpack::channel_shuffle(self, groups);
return output;
}
#endif
auto output = self.numel() == 0 ? self : at::native_channel_shuffle(self, groups);
auto output = self.numel() == 0 ? self.alias() : at::native_channel_shuffle(self, groups);
return namedinference::propagate_names_if_nonempty(
output,
self.has_names() ? self.names() : at::ArrayRef<Dimname>{});

View File

@ -6414,8 +6414,8 @@ tensor(..., device='meta', size=(1,), requires_grad=True)""")
self.assertEqual(y, y_ref)
def test_channel_shuffle_return_self(self):
# gh-76616: nn.ChannelShuffle will return self with an empty input tensor
def test_channel_shuffle_return_alias_of_self(self):
# gh-76616: nn.ChannelShuffle will return alias of self with an empty input tensor
groups = 3
input_tensor = torch.rand([0, 9, 4, 4])
output = torch.nn.ChannelShuffle(groups)(input_tensor)