mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
Revert "don't check memory format for empty tensors (#126593)"
This reverts commit 12dee4f2046d07db97cddc7b3c5bdf06fc304ae3.
Reverted https://github.com/pytorch/pytorch/pull/126593 on behalf of https://github.com/clee2000 due to broke tests on inductor? test_modules.py::TestModuleCUDA::test_cpu_gpu_parity_nn_CTCLoss_cuda_float64 43f2f43eb3
https://github.com/pytorch/pytorch/actions/runs/9200644034/job/25308511495 ([comment](https://github.com/pytorch/pytorch/pull/126586#issuecomment-2126228689))
This commit is contained in:
@ -663,10 +663,10 @@ class TestModule(TestCase):
|
||||
d = output.dim()
|
||||
if (d == 4 and ((input_mem_format == torch.channels_last)
|
||||
or (module_mem_format == torch.channels_last and module_memformat_affects_out))):
|
||||
self.assertTrue(output.numel() == 0 or output.is_contiguous(memory_format=torch.channels_last))
|
||||
self.assertTrue(output.is_contiguous(memory_format=torch.channels_last))
|
||||
elif (d == 5 and ((input_mem_format == torch.channels_last_3d)
|
||||
or (module_mem_format == torch.channels_last_3d and module_memformat_affects_out))):
|
||||
self.assertTrue(output.numel() == 0 or output.is_contiguous(memory_format=torch.channels_last_3d))
|
||||
self.assertTrue(output.is_contiguous(memory_format=torch.channels_last_3d))
|
||||
else:
|
||||
self.assertTrue(output.is_contiguous())
|
||||
return self._traverse_obj(output, inner_check_out_mem_format)
|
||||
|
Reference in New Issue
Block a user