From 377033757ae5ca524ea842f1b0a5f446ed3d8fe0 Mon Sep 17 00:00:00 2001
From: Natalia Gimelshein
Date: Sun, 31 Aug 2025 05:42:41 +0000
Subject: [PATCH] Use vectorized stores for all dtypes in cat (#161649)

resurrecting #151818

Pull Request resolved: https://github.com/pytorch/pytorch/pull/161649
Approved by: https://github.com/Skylion007
---
 aten/src/ATen/native/cuda/Shape.cu | 115 ++++++++++++++++++++++++++---
 test/test_tensor_creation_ops.py   |  35 +++++++++
 2 files changed, 139 insertions(+), 11 deletions(-)

diff --git a/aten/src/ATen/native/cuda/Shape.cu b/aten/src/ATen/native/cuda/Shape.cu
index e2eb2226acf4..92029732b449 100644
--- a/aten/src/ATen/native/cuda/Shape.cu
+++ b/aten/src/ATen/native/cuda/Shape.cu
@@ -226,6 +226,38 @@ __global__ void CatArrayBatchedCopy_contig(
 }
 }
+
+template
+__global__ void CatArrayBatchedCopy_vectorized(
+    char* output,
+    CatArrInputTensorMetadata inputs,
+    TensorSizeStride os,
+    const int concatDim,
+    IndexType trailingSize) {
+
+    IndexType tid = blockIdx.x * blockDim.x + threadIdx.x;
+    IndexType nElements = inputs.nElements[blockIdx.y] / elems_per_vec;
+
+    if(tid >= nElements) return;
+
+    const char * data = (char*)inputs.input[blockIdx.y];
+    IndexType offset = inputs.offset[blockIdx.y] * trailingSize / elems_per_vec;
+    IndexType dimSize = inputs.dimSize[blockIdx.y] * trailingSize / elems_per_vec;
+    IndexType dataOffset = offset * alignment; // in bytes
+
+    IndexType stride = gridDim.x * blockDim.x;
+
+    while( tid < nElements){
+      IndexType elementOffset = CatArrIndexToOffset::compute(
+          os.tensorSize, os.tensorStride, dimSize, concatDim, tid) * alignment; // in bytes
+      auto vec = at::native::memory::ld_vec(data + alignment * tid);
+      at::native::memory::st_vec(output + dataOffset + elementOffset, vec);
+      tid += stride;
+    }
+}
+
+
+

 /*
   Specialized implementation of the CatArrayBatchedCopy written to generate wide memory loads
   to improve memory bandwidth throughput.
@@ -296,12 +328,27 @@ void parallel_cat(const Tensor &out, const MaterializedITensorListRef& inputs, i
   scalar_t *data = (scalar_t *)(out.mutable_data_ptr());
   CatArrInputTensorMetadata catMetaData;
   TensorSizeStride outputParam;
+  // If all batches are contiguous we can call a specialized implementation
+  // which requires the input tensor addresses to be aligned to a
+  // 16 Byte boundary.
+
+  constexpr bool isContig = stride_size == 1;
+  bool isAligned = true;
+  constexpr int alignment = 16;

   // Next, let's initialize the size, stride arrays for the output Tensor.
+  // for contig case, we'll canonicalize output strides, so that
+  // we don't have arbitrary strides for dims of size 0
+  size_t stride0 = 1;
   if (memory_format == c10::MemoryFormat::Contiguous) {
-    for (int i = 0; i < nDims; ++i) {
+    for (int i = nDims - 1; i >= 0; --i) {
       outputParam.tensorSize[i] = out.size(i);
-      outputParam.tensorStride[i] = out.stride(i);
+      if (isContig) {
+        outputParam.tensorStride[i] = stride0;
+        stride0 *= out.size(i);
+      } else {
+        outputParam.tensorStride[i] = out.stride(i);
+      }
     }
   } else if (memory_format == c10::MemoryFormat::ChannelsLast || memory_format == c10::MemoryFormat::ChannelsLast3d) {
     // permute the semantics of dims from NCHW to NHWC so that the input
@@ -320,12 +367,15 @@ void parallel_cat(const Tensor &out, const MaterializedITensorListRef& inputs, i

   at::cuda::CUDAStream stream = at::cuda::getCurrentCUDAStream();

-  // If all batches are contiguous we can call a specialized implementation
-  // which requires the input tensor addresses to be aligned to a
-  // 16 Byte boundary.
-  bool isContig = true;
-  bool isAligned = true;
+  // for channels last computing slice size correctly is much more involved, so we never send it
+  // on the fully vectorized path
+  // we need output stride in cat dimension to be multiple of alignment,
+  // if we ever use it to compute offsets
+  // for catting in 0th dimension it doesn't matter
+  bool isInOutAligned = isContig && at::native::memory::get_alignment(data) >= alignment &&
+      memory_format == c10::MemoryFormat::Contiguous && (dimension == 0 ||
+      outputParam.tensorStride[dimension - 1] * sizeof(scalar_t) % alignment == 0);

   unsigned int max_elements_per_tensor = 0;

   // Now we loop
@@ -341,6 +391,16 @@ void parallel_cat(const Tensor &out, const MaterializedITensorListRef& inputs, i
       // high-dimensional tensor
       if (inputs[i+batchCounter].get().numel() > 0) {
         dimSize = inputs[i+batchCounter].get().size(dimension);
+        if (isInOutAligned) {
+          auto t = inputs[i+batchCounter].get();
+          // similarly to output stride, we cannot trust stride value to
+          // determine slice size if the corresponding dimension is 1
+          // we have to multiply all the subsequent sizes
+          int64_t slice_size = dimension == 0 ? t.numel() : t.sizes()[dimension - 1] != 1 ?
+              t.strides()[dimension - 1] : c10::multiply_integers(t.sizes().begin() + dimension, t.sizes().end());
+          slice_size *= sizeof(scalar_t);
+          isInOutAligned &= (slice_size % alignment == 0);
+        }
       }

       catMetaData.input[batchCounter] = (scalar_t*)(inputs[i+batchCounter].get().const_data_ptr());
@@ -351,10 +411,12 @@ void parallel_cat(const Tensor &out, const MaterializedITensorListRef& inputs, i
 #ifdef USE_ROCM
       // On ROCm, CatArrayBatchedCopy_contig is faster
       isAligned = false;
+      isInOutAligned = false;
 #else
       // If at least one of the inputs is not aligned, we can't call the
       // CatArrayBatchedCopy_alignedK_contig
       isAligned &= is_aligned_vec4(catMetaData.input[batchCounter]);
+      isInOutAligned &= at::native::memory::get_alignment(catMetaData.input[batchCounter]) >= alignment;
 #endif

       if (stride_size > 1) {
@@ -365,7 +427,6 @@ void parallel_cat(const Tensor &out, const MaterializedITensorListRef& inputs, i
           catMetaData.tensorStride[batchCounter].tensorStride[j] = strides[j];
         }
         catMetaData.isContiguous[batchCounter] = false;
-        isContig = false;
       } else {
         catMetaData.isContiguous[batchCounter] = true;
       }
@@ -388,10 +449,13 @@ void parallel_cat(const Tensor &out, const MaterializedITensorListRef& inputs, i
           max_elements_per_tensor, batchCounter);
 #else
     dim3 applyBlock, catGrid;
-    if (isContig && sizeof(scalar_t) > 2) {
+    if (isInOutAligned) {
+      std::tie(catGrid, applyBlock) = getCatGridContig(
+          max_elements_per_tensor, batchCounter);
+    } else if (isContig && isAligned && sizeof(scalar_t) > 2) {
       std::tie(catGrid, applyBlock) = getCatGridContig(
           max_elements_per_tensor, batchCounter);
-    } else if (isContig && sizeof(scalar_t) == 2) {
+    } else if (isContig && isAligned && sizeof(scalar_t) == 2) {
       std::tie(catGrid, applyBlock) = getCatGridContig(
           max_elements_per_tensor, batchCounter);
     } else {
@@ -399,6 +463,30 @@ void parallel_cat(const Tensor &out, const MaterializedITensorListRef& inputs, i
       getCatGrid(batchCounter, catGrid);
     }
 #endif
+    int32_t trailingSize;
+    TensorSizeStride kernelOutputParam;
+    if (isInOutAligned) {
+      // in this case we can and should flatten the tensors after the cat dim
+      // we want to view the tensors as if consisting of `alignment`-sized elements
+      // however, we might not be able to cleanly divide just the last dim -
+      // it might not be the multiple of alignment.
+      // however, we know that the full concatted slice is multiple of alignment,
+      // so if we flatten all the dims after and including concat dim,
+      // it will be divisible by alignment
+      // then we need to divide last out size by elems_per_vec,
+      // and divide all strides except last by elems_per_vec (last stride is 1 always)
+      // for input, we will fix up the sizes and strides in the kernel directly
+      kernelOutputParam = outputParam;
+      nDims = dimension + 1;
+      constexpr auto elems_per_vec = alignment / sizeof(scalar_t);
+      auto out_size = dimension == 0 ? out.numel() : kernelOutputParam.tensorStride[dimension-1];
+      kernelOutputParam.tensorSize[dimension] = out_size / elems_per_vec;
+      trailingSize = outputParam.tensorStride[dimension];
+      kernelOutputParam.tensorStride[dimension] = 1;
+      for (int i = 0; i < dimension; ++i) {
+        kernelOutputParam.tensorStride[i] /= elems_per_vec;
+      }
+    }

     if (memory_format != c10::MemoryFormat::Contiguous) {
       switch (dimension) {
@@ -413,7 +501,12 @@ void parallel_cat(const Tensor &out, const MaterializedITensorListRef& inputs, i
     }
     // Template Declarations for dim = 1, 2, 3, 4
 #define HANDLE_CASE(DIMS) \
-    if (isContig && isAligned && sizeof(scalar_t) > 2 && sizeof(scalar_t) <= 8) {\
+    if (isInOutAligned) {\
+      constexpr auto elems_per_vec = alignment / sizeof(scalar_t); \
+      CatArrayBatchedCopy_vectorized<<<\
+          catGrid, applyBlock, 0, stream.stream()>>>(\
+          (char*)data, catMetaData, kernelOutputParam, dimension, trailingSize);\
+    } else if (isContig && isAligned && sizeof(scalar_t) > 2 && sizeof(scalar_t) <= 8) {\
       CatArrayBatchedCopy_alignedK_contig<<<\
           catGrid, applyBlock, 0, stream.stream()>>>(\
           data, catMetaData, outputParam, dimension, outputParam.tensorStride[dimension]);\
diff --git a/test/test_tensor_creation_ops.py b/test/test_tensor_creation_ops.py
index 02cb1d31d563..0ff55c62ae1c 100644
--- a/test/test_tensor_creation_ops.py
+++ b/test/test_tensor_creation_ops.py
@@ -1151,6 +1151,41 @@ class TestTensorCreation(TestCase):
         z = torch.cat([x, y])
         self.assertEqual(z.size(), (21, SIZE, SIZE))

+    @dtypes(torch.float)
+    def test_cat_size1(self, device, dtype):
+        # create a tensor that has aligned stride along dim - 1 dimension
+        # but catted slice size is not aligned
+        x1 = torch.randn(16, 16, device=device, dtype=dtype)[:1, :1]
+        xref = x1.clone().view(-1).view(x1.shape)
+        # make sure output size is aligned, need at least 4 elements for this
+        res = torch.cat([x1, x1, x1, x1], dim=-1)
+        ref = torch.cat([xref, xref, xref, xref], dim=-1)
+        self.assertEqual(res, ref)
+
+    @dtypes(torch.float)
+    def test_cat_trailing_dim(self, device, dtype):
+        x1 = torch.randn(16, 16, 23, device=device, dtype=dtype)
+        x2 = torch.rand_like(x1)
+        res = torch.cat([x1, x2], dim=1)
+        ref = torch.cat([x1.cpu(), x2.cpu()], dim=1)
+        self.assertEqual(res, ref)
+
+    @dtypes(torch.float)
+    def test_cat_misaligned(self, device, dtype):
+        x1 = torch.randn(14, device=device, dtype=dtype)[2:]
+        x2 = torch.rand_like(x1)
+        res = torch.cat([x1, x2], dim=-1)
+        ref = torch.cat([x1.cpu(), x2.cpu()], dim=-1)
+        self.assertEqual(res, ref)
+
+    @dtypes(torch.float)
+    def test_cat_multi_batch(self, device, dtype):
+        xs = [torch.randn(16, 16, device=device, dtype=dtype) for _ in range(130)]
+        xs_cpu = [x.cpu() for x in xs]
+        res = torch.cat(xs, dim=-1)
+        ref = torch.cat(xs_cpu, dim=-1)
+        self.assertEqual(res, ref)
+
     # FIXME: Create an OpInfo-based tensor creation method test that verifies this for all tensor
     # creation methods and verify all dtypes and layouts
     @dtypes(torch.bool, torch.uint8, torch.int16,
             torch.int64, torch.float16, torch.float32, torch.complex64)
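
The kernel above indexes in units of 16-byte packets, which is why nElements, offset, and dimSize are divided by elems_per_vec. The following is a minimal standalone CUDA sketch of that access pattern, for illustration only: it assumes raw uint4 reinterpret_casts in place of the at::native::memory::ld_vec/st_vec helpers used in the patch, and buffers that are 16-byte aligned whose byte count is a multiple of 16.

// Sketch: grid-stride copy moving 16 bytes per thread per iteration,
// analogous to the elems_per_vec/alignment bookkeeping in the patch.
// Assumes src and dst are 16-byte aligned and nbytes is a multiple of 16.
#include <cuda_runtime.h>
#include <cstdio>

__global__ void copy_vec16(const char* __restrict__ src,
                           char* __restrict__ dst,
                           size_t nbytes) {
  size_t i = ((size_t)blockIdx.x * blockDim.x + threadIdx.x) * 16;
  size_t stride = (size_t)gridDim.x * blockDim.x * 16;
  for (; i < nbytes; i += stride) {
    // One 16-byte (uint4) load and one 16-byte store per iteration.
    uint4 v = *reinterpret_cast<const uint4*>(src + i);
    *reinterpret_cast<uint4*>(dst + i) = v;
  }
}

int main() {
  constexpr size_t nbytes = 1 << 20;  // 1 MiB, a multiple of 16
  char *src = nullptr, *dst = nullptr;
  cudaMalloc(&src, nbytes);
  cudaMalloc(&dst, nbytes);
  cudaMemset(src, 0xAB, nbytes);
  copy_vec16<<<256, 256>>>(src, dst, nbytes);
  cudaDeviceSynchronize();
  printf("copied %zu bytes\n", nbytes);
  cudaFree(src);
  cudaFree(dst);
  return 0;
}

In the sketch the 16-byte alignment assumption holds because cudaMalloc returns pointers aligned to at least 256 bytes; the patch instead verifies alignment of the actual input and output pointers at runtime via at::native::memory::get_alignment before taking the vectorized path.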