Compare commits

...

1 Commit

Author SHA1 Message Date
161e1643ed Relocate channels-last cat regression test 2025-10-09 17:32:59 +00:00
2 changed files with 24 additions and 8 deletions

View File

@@ -488,15 +488,16 @@ void parallel_cat(const Tensor &out, const MaterializedITensorListRef& inputs, i
     }
   }
+  int cat_dim = dimension;
   if (memory_format != c10::MemoryFormat::Contiguous) {
-    switch (dimension) {
+    switch (cat_dim) {
       case 0:
         break;
       case 1:
-        dimension = nDims - dimension;
+        cat_dim = nDims - cat_dim;
         break;
       default:
-        dimension--;
+        cat_dim--;
     }
   }
   // Template Declarations for dim = 1, 2, 3, 4
@@ -505,23 +506,23 @@ void parallel_cat(const Tensor &out, const MaterializedITensorListRef& inputs, i
       constexpr auto elems_per_vec = alignment / sizeof(scalar_t); \
       CatArrayBatchedCopy_vectorized<scalar_t, unsigned int, DIMS, batch_size, stride_size, alignment, elems_per_vec><<<\
           catGrid, applyBlock, 0, stream.stream()>>>(\
-          (char*)data, catMetaData, kernelOutputParam, dimension, trailingSize);\
+          (char*)data, catMetaData, kernelOutputParam, cat_dim, trailingSize);\
     } else if (isContig && isAligned && sizeof(scalar_t) > 2 && sizeof(scalar_t) <= 8) {\
       CatArrayBatchedCopy_alignedK_contig<scalar_t, unsigned int, DIMS, batch_size, stride_size, ALIGNED_VEC_LOAD_BYTES_16><<<\
           catGrid, applyBlock, 0, stream.stream()>>>(\
-          data, catMetaData, outputParam, dimension, outputParam.tensorStride[dimension]);\
+          data, catMetaData, outputParam, cat_dim, outputParam.tensorStride[cat_dim]);\
     } else if (isContig && isAligned && sizeof(scalar_t) == 2) { \
       CatArrayBatchedCopy_alignedK_contig<scalar_t, unsigned int, DIMS, batch_size, stride_size, ALIGNED_VEC_LOAD_BYTES_8><<<\
           catGrid, applyBlock, 0, stream.stream()>>>(\
-          data, catMetaData, outputParam, dimension, outputParam.tensorStride[dimension]);\
+          data, catMetaData, outputParam, cat_dim, outputParam.tensorStride[cat_dim]);\
     } else if (isContig) {\
       CatArrayBatchedCopy_contig<scalar_t, unsigned int, DIMS, batch_size, stride_size><<<\
           catGrid, applyBlock, 0, stream.stream()>>>(\
-          data, catMetaData, outputParam, dimension, outputParam.tensorStride[dimension]);\
+          data, catMetaData, outputParam, cat_dim, outputParam.tensorStride[cat_dim]);\
     } else {\
       CatArrayBatchedCopy<scalar_t, unsigned int, DIMS, batch_size, stride_size><<<\
           catGrid, applyBlock, 0, stream.stream()>>>(\
-          data, catMetaData, outputParam, dimension, outputParam.tensorStride[dimension]);\
+          data, catMetaData, outputParam, cat_dim, outputParam.tensorStride[cat_dim]);\
     }\
     C10_CUDA_KERNEL_LAUNCH_CHECK();
   switch (nDims) {
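
Note on the hunks above: parallel_cat previously wrote the remapped channels-last dimension back into its `dimension` argument, so any code that consulted that argument again saw an already-remapped value; the fix leaves the argument intact and threads a local `cat_dim` through every kernel launch instead. Below is a minimal Python model of the remapping (the helper name is mine, not PyTorch's; that the remap could be applied twice is an assumption inferred from the regression test in the next file, which uses more inputs than one kernel batch holds):

# Illustrative model of the switch on cat_dim above (hypothetical helper,
# not part of PyTorch). For a channels-last (NHWC) output, dim 0 stays,
# the channel dim moves to the end, and spatial dims shift down by one.
def remap_cat_dim(dim: int, ndims: int) -> int:
    if dim == 0:
        return 0
    if dim == 1:
        return ndims - dim  # e.g. 1 -> 3 when ndims == 4 (NCHW -> NHWC)
    return dim - 1

# Why the local copy matters: applying the remap to an already-remapped
# value yields a different, wrong dimension.
assert remap_cat_dim(1, 4) == 3                    # correct NHWC position
assert remap_cat_dim(remap_cat_dim(1, 4), 4) == 2  # applied twice: wrong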

View File

@@ -688,6 +688,21 @@ class TestTensorCreation(TestCase):
         self.assertEqual(res1, res2)
         self.assertTrue(res1.is_contiguous(memory_format=torch.channels_last))
 
+    @onlyCUDA
+    def test_cat_channels_last_large_inputs(self, device):
+        num_tensors = 130
+        inputs_cuda = [
+            torch.randn((2, 3, 4, 4), device=device).contiguous(memory_format=torch.channels_last)
+            for _ in range(num_tensors)
+        ]
+        inputs_cpu = [t.cpu() for t in inputs_cuda]
+        result = torch.cat(inputs_cuda, dim=1)
+        expected = torch.cat(inputs_cpu, dim=1)
+        self.assertEqual(result.cpu(), expected)
+        self.assertTrue(result.is_contiguous(memory_format=torch.channels_last))
+
     @onlyCUDA
     def test_cat_out_memory_format(self, device):
         inp_size = (4, 4, 4, 4)
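
The regression test uses 130 inputs so the CUDA cat has to process them in more than one kernel batch (the batch size is the `batch_size` template parameter in the first file; that 128 is the boundary is an assumption here, not stated in the diff). Outside the test harness, the same check can be written as a standalone script (a sketch; requires a CUDA build of PyTorch):

import torch

# Build more channels-last inputs than one cat-kernel batch holds and
# compare the CUDA result against the CPU reference.
xs = [
    torch.randn(2, 3, 4, 4, device="cuda").contiguous(memory_format=torch.channels_last)
    for _ in range(130)
]
got = torch.cat(xs, dim=1)
want = torch.cat([x.cpu() for x in xs], dim=1)
torch.testing.assert_close(got.cpu(), want)
assert got.is_contiguous(memory_format=torch.channels_last)

To run the bundled test instead, `pytest -k test_cat_channels_last_large_inputs` against the test file (assuming it is PyTorch's test/test_tensor_creation_ops.py, where TestTensorCreation is defined) exercises the same path under the device-test harness.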