mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
Update channel shuffle to return alias instead of self as-is (#99745)
Partially addresses https://github.com/pytorch/pytorch/issues/99655 Pull Request resolved: https://github.com/pytorch/pytorch/pull/99745 Approved by: https://github.com/albanD
This commit is contained in:
committed by
PyTorch MergeBot
parent
ab0a8215bb
commit
5ee5afb82c
@ -47,12 +47,12 @@ Tensor channel_shuffle(const Tensor& self, int64_t groups) {
|
||||
#if defined(C10_MOBILE) && defined(USE_XNNPACK)
|
||||
if (self.is_contiguous(MemoryFormat::ChannelsLast) &&
|
||||
xnnpack::use_channel_shuffle(self, groups)) {
|
||||
auto output = self.numel() == 0 ? self : xnnpack::channel_shuffle(self, groups);
|
||||
auto output = self.numel() == 0 ? self.alias() : xnnpack::channel_shuffle(self, groups);
|
||||
return output;
|
||||
}
|
||||
#endif
|
||||
|
||||
auto output = self.numel() == 0 ? self : at::native_channel_shuffle(self, groups);
|
||||
auto output = self.numel() == 0 ? self.alias() : at::native_channel_shuffle(self, groups);
|
||||
return namedinference::propagate_names_if_nonempty(
|
||||
output,
|
||||
self.has_names() ? self.names() : at::ArrayRef<Dimname>{});
|
||||
|
@ -6414,8 +6414,8 @@ tensor(..., device='meta', size=(1,), requires_grad=True)""")
|
||||
self.assertEqual(y, y_ref)
|
||||
|
||||
|
||||
def test_channel_shuffle_return_self(self):
|
||||
# gh-76616: nn.ChannelShuffle will return self with an empty input tensor
|
||||
def test_channel_shuffle_return_alias_of_self(self):
|
||||
# gh-76616: nn.ChannelShuffle will return alias of self with an empty input tensor
|
||||
groups = 3
|
||||
input_tensor = torch.rand([0, 9, 4, 4])
|
||||
output = torch.nn.ChannelShuffle(groups)(input_tensor)
|
||||
|
Reference in New Issue
Block a user