Revert "don't check memory format for empty tensors (#126593)"

This reverts commit 12dee4f2046d07db97cddc7b3c5bdf06fc304ae3.

Reverted https://github.com/pytorch/pytorch/pull/126593 on behalf of https://github.com/clee2000 due to broke tests on inductor? test_modules.py::TestModuleCUDA::test_cpu_gpu_parity_nn_CTCLoss_cuda_float64 43f2f43eb3 https://github.com/pytorch/pytorch/actions/runs/9200644034/job/25308511495 ([comment](https://github.com/pytorch/pytorch/pull/126586#issuecomment-2126228689))
This commit is contained in:
PyTorch MergeBot
2024-05-23 04:54:28 +00:00
parent df4b7cb5f7
commit b1e214ceb1

View File

@ -663,10 +663,10 @@ class TestModule(TestCase):
d = output.dim()
if (d == 4 and ((input_mem_format == torch.channels_last)
or (module_mem_format == torch.channels_last and module_memformat_affects_out))):
self.assertTrue(output.numel() == 0 or output.is_contiguous(memory_format=torch.channels_last))
self.assertTrue(output.is_contiguous(memory_format=torch.channels_last))
elif (d == 5 and ((input_mem_format == torch.channels_last_3d)
or (module_mem_format == torch.channels_last_3d and module_memformat_affects_out))):
self.assertTrue(output.numel() == 0 or output.is_contiguous(memory_format=torch.channels_last_3d))
self.assertTrue(output.is_contiguous(memory_format=torch.channels_last_3d))
else:
self.assertTrue(output.is_contiguous())
return self._traverse_obj(output, inner_check_out_mem_format)