Use vectorized stores for all dtypes in cat (#161649)
Resurrecting #151818.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/161649
Approved by: https://github.com/Skylion007
Committed by: PyTorch MergeBot
Parent: f612045ce1
Commit: 377033757a
@@ -226,6 +226,38 @@ __global__ void CatArrayBatchedCopy_contig(
   }
 }
 
+template <typename T, typename IndexType, int Dims, int batch_size, int stride_size, int alignment, int elems_per_vec>
+__global__ void CatArrayBatchedCopy_vectorized(
+    char* output,
+    CatArrInputTensorMetadata<T, IndexType, batch_size, stride_size> inputs,
+    TensorSizeStride<IndexType, CAT_ARRAY_MAX_INPUT_DIMS> os,
+    const int concatDim,
+    IndexType trailingSize) {
+
+    IndexType tid = blockIdx.x * blockDim.x + threadIdx.x;
+    IndexType nElements = inputs.nElements[blockIdx.y] / elems_per_vec;
+
+    if (tid >= nElements) return;
+
+    const char * data = (char*)inputs.input[blockIdx.y];
+    IndexType offset = inputs.offset[blockIdx.y] * trailingSize / elems_per_vec;
+    IndexType dimSize = inputs.dimSize[blockIdx.y] * trailingSize / elems_per_vec;
+    IndexType dataOffset = offset * alignment; // in bytes
+
+    IndexType stride = gridDim.x * blockDim.x;
+
+    while (tid < nElements) {
+      IndexType elementOffset = CatArrIndexToOffset<IndexType, Dims>::compute(
+          os.tensorSize, os.tensorStride, dimSize, concatDim, tid) * alignment; // in bytes
+      auto vec = at::native::memory::ld_vec<alignment>(data + alignment * tid);
+      at::native::memory::st_vec<alignment>(output + dataOffset + elementOffset, vec);
+      tid += stride;
+    }
+}
+
+
 /*
   Specialized implementation of the CatArrayBatchedCopy written to generate wide memory loads
   to improve memory bandwidth throughput.
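The new kernel leans on at::native::memory::ld_vec / st_vec to move one alignment-sized packet per loop iteration. As a rough illustration of the idea (not the actual PyTorch helpers), a plain-CUDA 16-byte copy looks like the sketch below; the kernel name and the use of uint4 are assumptions for the example, and both pointers are assumed 16-byte aligned.

// Illustration only: copy n 16-byte packets from src to dst with one
// vectorized load/store per packet. uint4 is 16 bytes, so the compiler can
// lower each access to a single 128-bit load/store instruction.
__global__ void copy_vec16(const char* src, char* dst, unsigned int n) {
  unsigned int tid = blockIdx.x * blockDim.x + threadIdx.x;
  unsigned int stride = gridDim.x * blockDim.x;
  for (; tid < n; tid += stride) {
    uint4 v = reinterpret_cast<const uint4*>(src)[tid];
    reinterpret_cast<uint4*>(dst)[tid] = v;
  }
}

This is also why CatArrayBatchedCopy_vectorized divides nElements, offset and dimSize by elems_per_vec: indexing is done in packets rather than scalar elements, and byte offsets are recovered by multiplying by alignment.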
@@ -296,12 +328,27 @@ void parallel_cat(const Tensor &out, const MaterializedITensorListRef& inputs, i
   scalar_t *data = (scalar_t *)(out.mutable_data_ptr());
   CatArrInputTensorMetadata<scalar_t, unsigned int, batch_size, stride_size> catMetaData;
   TensorSizeStride<unsigned int, CAT_ARRAY_MAX_INPUT_DIMS> outputParam;
+  // If all batches are contiguous we can call a specialized implementation
+  // which requires the input tensor addresses to be aligned to a
+  // 16 Byte boundary.
+
+  constexpr bool isContig = stride_size == 1;
+  bool isAligned = true;
+  constexpr int alignment = 16;
 
   // Next, let's initialize the size, stride arrays for the output Tensor.
+  // for contig case, we'll canonicalize output strides, so that
+  // we don't have arbitrary strides for dims of size 0
+  size_t stride0 = 1;
   if (memory_format == c10::MemoryFormat::Contiguous) {
-    for (int i = 0; i < nDims; ++i) {
+    for (int i = nDims - 1; i >= 0; --i) {
       outputParam.tensorSize[i] = out.size(i);
-      outputParam.tensorStride[i] = out.stride(i);
+      if (isContig) {
+        outputParam.tensorStride[i] = stride0;
+        stride0 *= out.size(i);
+      } else {
+        outputParam.tensorStride[i] = out.stride(i);
+      }
     }
   } else if (memory_format == c10::MemoryFormat::ChannelsLast || memory_format == c10::MemoryFormat::ChannelsLast3d) {
     // permute the semantics of dims from NCHW to NHWC so that the input
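The stride canonicalization above derives the output strides from the sizes alone instead of trusting out.stride(i). A minimal host-side sketch of the same computation (hypothetical helper, not part of the patch):

#include <cstddef>
#include <vector>

// Contiguous strides computed purely from sizes, as in the isContig branch
// above: the stride of dim i is the product of all later sizes. For sizes
// {2, 3, 4} this yields {12, 4, 1}, independent of whatever strides a
// degenerate (size-0 or size-1) dim happened to carry in the tensor.
std::vector<size_t> contiguous_strides(const std::vector<size_t>& sizes) {
  std::vector<size_t> strides(sizes.size(), 1);
  size_t stride0 = 1;
  for (int i = static_cast<int>(sizes.size()) - 1; i >= 0; --i) {
    strides[i] = stride0;
    stride0 *= sizes[i];
  }
  return strides;
}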
@@ -320,12 +367,15 @@ void parallel_cat(const Tensor &out, const MaterializedITensorListRef& inputs, i
 
   at::cuda::CUDAStream stream = at::cuda::getCurrentCUDAStream();
 
-  // If all batches are contiguous we can call a specialized implementation
-  // which requires the input tensor addresses to be aligned to a
-  // 16 Byte boundary.
 
-  bool isContig = true;
-  bool isAligned = true;
+  // for channels last computing slice size correctly is much more involved, so we never send it
+  // on the fully vectorized path
+  // we need output stride in cat dimension to be multiple of alignment,
+  // if we ever use it to compute offsets
+  // for catting in 0th dimension it doesn't matter
+  bool isInOutAligned = isContig && at::native::memory::get_alignment(data) >= alignment &&
+      memory_format == c10::MemoryFormat::Contiguous && (dimension == 0 ||
+      outputParam.tensorStride[dimension - 1] * sizeof(scalar_t) % alignment == 0);
   unsigned int max_elements_per_tensor = 0;
 
   // Now we loop
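isInOutAligned gates the fully vectorized path on the alignment of the output pointer as reported by at::native::memory::get_alignment, compared against the 16-byte alignment constant. That helper is not shown in this diff, but the check it enables amounts to the sketch below (illustrative only; the constant and function name are assumptions).

#include <cstdint>

// Assumed constant matching `alignment` above.
constexpr int kAlignment = 16;

// True when ptr can be accessed with 16-byte vector loads/stores.
inline bool is_16_byte_aligned(const void* ptr) {
  return reinterpret_cast<std::uintptr_t>(ptr) % kAlignment == 0;
}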
@@ -341,6 +391,16 @@ void parallel_cat(const Tensor &out, const MaterializedITensorListRef& inputs, i
       // high-dimensional tensor
       if (inputs[i+batchCounter].get().numel() > 0) {
         dimSize = inputs[i+batchCounter].get().size(dimension);
+        if (isInOutAligned) {
+          auto t = inputs[i+batchCounter].get();
+          // similarly to output stride, we cannot trust stride value to
+          // determine slice size if the corresponding dimension is 1
+          // we have to multiply all the subsequent sizes
+          int64_t slice_size = dimension == 0 ? t.numel() : t.sizes()[dimension - 1] != 1 ?
+              t.strides()[dimension - 1] : c10::multiply_integers(t.sizes().begin() + dimension, t.sizes().end());
+          slice_size *= sizeof(scalar_t);
+          isInOutAligned &= (slice_size % alignment == 0);
+        }
       }
 
       catMetaData.input[batchCounter] = (scalar_t*)(inputs[i+batchCounter].get().const_data_ptr());
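The size-1 guard matters because a sliced view keeps the strides of its base tensor. A toy host-side illustration (not from the patch) of the two ways to get the slice extent:

#include <cstdint>
#include <functional>
#include <numeric>
#include <vector>

// For x = torch.randn(16, 16)[:1, :1], sizes are {1, 1} but strides stay
// {16, 1}; strides[dim - 1] would report 16 elements per slice while only
// one element is actually concatenated. When the preceding dim has size 1,
// the trailing sizes are multiplied instead, mirroring the branch above.
int64_t slice_elems(const std::vector<int64_t>& sizes,
                    const std::vector<int64_t>& strides,
                    int64_t dim) {
  auto prod = [](auto first, auto last) {
    return std::accumulate(first, last, int64_t{1}, std::multiplies<int64_t>());
  };
  if (dim == 0) {
    return prod(sizes.begin(), sizes.end());
  }
  return sizes[dim - 1] != 1 ? strides[dim - 1]
                             : prod(sizes.begin() + dim, sizes.end());
}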
@@ -351,10 +411,12 @@ void parallel_cat(const Tensor &out, const MaterializedITensorListRef& inputs, i
 #ifdef USE_ROCM
       // On ROCm, CatArrayBatchedCopy_contig is faster
       isAligned = false;
+      isInOutAligned = false;
 #else
       // If at least one of the inputs is not aligned, we can't call the
       // CatArrayBatchedCopy_alignedK_contig
       isAligned &= is_aligned_vec4(catMetaData.input[batchCounter]);
+      isInOutAligned &= at::native::memory::get_alignment(catMetaData.input[batchCounter]) >= alignment;
 #endif
 
       if (stride_size > 1) {
@@ -365,7 +427,6 @@ void parallel_cat(const Tensor &out, const MaterializedITensorListRef& inputs, i
           catMetaData.tensorStride[batchCounter].tensorStride[j] = strides[j];
         }
         catMetaData.isContiguous[batchCounter] = false;
-        isContig = false;
       } else {
         catMetaData.isContiguous[batchCounter] = true;
       }
@@ -388,10 +449,13 @@ void parallel_cat(const Tensor &out, const MaterializedITensorListRef& inputs, i
           max_elements_per_tensor, batchCounter);
 #else
       dim3 applyBlock, catGrid;
-      if (isContig && sizeof(scalar_t) > 2) {
+      if (isInOutAligned) {
+        std::tie(catGrid, applyBlock) = getCatGridContig<scalar_t, alignment>(
+            max_elements_per_tensor, batchCounter);
+      } else if (isContig && isAligned && sizeof(scalar_t) > 2) {
         std::tie(catGrid, applyBlock) = getCatGridContig<scalar_t, ALIGNED_VEC_LOAD_BYTES_16>(
             max_elements_per_tensor, batchCounter);
-      } else if (isContig && sizeof(scalar_t) == 2) {
+      } else if (isContig && isAligned && sizeof(scalar_t) == 2) {
        std::tie(catGrid, applyBlock) = getCatGridContig<scalar_t, ALIGNED_VEC_LOAD_BYTES_8>(
             max_elements_per_tensor, batchCounter);
       } else {
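getCatGridContig itself is not part of this diff, so the sketch below only shows the usual shape of such a launch configuration under assumed parameters: x covers the elements of the largest input (in whatever unit the kernel indexes, scalars or packets) and y selects the batch slot, matching the blockIdx.y indexing in the kernels above.

#include <cuda_runtime.h>

// Hypothetical sizing helper, not the PyTorch implementation.
inline void sketch_cat_grid(unsigned int max_elems, unsigned int batch_count,
                            dim3& catGrid, dim3& applyBlock) {
  const unsigned int threads = 256;  // assumed block size
  applyBlock = dim3(threads);
  catGrid = dim3((max_elems + threads - 1) / threads, batch_count);
}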
@@ -399,6 +463,30 @@ void parallel_cat(const Tensor &out, const MaterializedITensorListRef& inputs, i
         getCatGrid(batchCounter, catGrid);
       }
 #endif
+      int32_t trailingSize;
+      TensorSizeStride<unsigned int, CAT_ARRAY_MAX_INPUT_DIMS> kernelOutputParam;
+      if (isInOutAligned) {
+        // in this case we can and should flatten the tensors after the cat dim
+        // we want to view the tensors as if consisting of `alignment`-sized elements
+        // however, we might not be able to cleanly divide just the last dim -
+        // it might not be the multiple of alignment.
+        // however, we know that the full concatted slice is multiple of alignment,
+        // so if we flatten all the dims after and including concat dim,
+        // it will be divisible by alignment
+        // then we need to divide last out size by elems_per_vec,
+        // and divide all strides except last by elems_per_vec (last stride is 1 always)
+        // for input, we will fix up the sizes and strides in the kernel directly
+        kernelOutputParam = outputParam;
+        nDims = dimension + 1;
+        constexpr auto elems_per_vec = alignment / sizeof(scalar_t);
+        auto out_size = dimension == 0 ? out.numel() : kernelOutputParam.tensorStride[dimension-1];
+        kernelOutputParam.tensorSize[dimension] = out_size / elems_per_vec;
+        trailingSize = outputParam.tensorStride[dimension];
+        kernelOutputParam.tensorStride[dimension] = 1;
+        for (int i = 0; i < dimension; ++i) {
+          kernelOutputParam.tensorStride[i] /= elems_per_vec;
+        }
+      }
 
       if (memory_format != c10::MemoryFormat::Contiguous) {
         switch (dimension) {
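A worked example of the flattening, reusing the shapes from test_cat_trailing_dim below (float32, alignment = 16, so elems_per_vec = 4); whether this path is actually taken still depends on the pointer alignment checks above.

#include <cstdio>

int main() {
  // Two [16, 16, 23] inputs catted along dim = 1 give a [16, 32, 23] output.
  // The last dim alone (23 floats = 92 bytes) does not divide into 16-byte
  // packets, but the flattened trailing block (32 * 23 = 736 floats = 2944
  // bytes) does, so each output row becomes 184 packets with packet stride 1.
  constexpr int alignment = 16;
  constexpr int elems_per_vec = alignment / sizeof(float);  // 4
  const int out_sizes[3] = {16, 32, 23};
  const int dimension = 1;
  int trailing_elems = 1;
  for (int i = dimension; i < 3; ++i) trailing_elems *= out_sizes[i];  // 736
  std::printf("last-dim bytes: %zu\n", out_sizes[2] * sizeof(float));
  std::printf("flattened row: %d floats = %d packets of %d bytes\n",
              trailing_elems, trailing_elems / elems_per_vec, alignment);
  return 0;
}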
@@ -413,7 +501,12 @@ void parallel_cat(const Tensor &out, const MaterializedITensorListRef& inputs, i
   }
   // Template Declarations for dim = 1, 2, 3, 4
 #define HANDLE_CASE(DIMS) \
-  if (isContig && isAligned && sizeof(scalar_t) > 2 && sizeof(scalar_t) <= 8) {\
+  if (isInOutAligned) {\
+    constexpr auto elems_per_vec = alignment / sizeof(scalar_t); \
+    CatArrayBatchedCopy_vectorized<scalar_t, unsigned int, DIMS, batch_size, stride_size, alignment, elems_per_vec><<<\
+        catGrid, applyBlock, 0, stream.stream()>>>(\
+        (char*)data, catMetaData, kernelOutputParam, dimension, trailingSize);\
+  } else if (isContig && isAligned && sizeof(scalar_t) > 2 && sizeof(scalar_t) <= 8) {\
     CatArrayBatchedCopy_alignedK_contig<scalar_t, unsigned int, DIMS, batch_size, stride_size, ALIGNED_VEC_LOAD_BYTES_16><<<\
         catGrid, applyBlock, 0, stream.stream()>>>(\
         data, catMetaData, outputParam, dimension, outputParam.tensorStride[dimension]);\
@@ -1151,6 +1151,41 @@ class TestTensorCreation(TestCase):
         z = torch.cat([x, y])
         self.assertEqual(z.size(), (21, SIZE, SIZE))
 
+    @dtypes(torch.float)
+    def test_cat_size1(self, device, dtype):
+        # create a tensor that has aligned stride along dim - 1 dimension
+        # but catted slice size is not aligned
+        x1 = torch.randn(16, 16, device=device, dtype=dtype)[:1, :1]
+        xref = x1.clone().view(-1).view(x1.shape)
+        # make sure output size is aligned, need at least 4 elements for this
+        res = torch.cat([x1, x1, x1, x1], dim=-1)
+        ref = torch.cat([xref, xref, xref, xref], dim=-1)
+        self.assertEqual(res, ref)
+
+    @dtypes(torch.float)
+    def test_cat_trailing_dim(self, device, dtype):
+        x1 = torch.randn(16, 16, 23, device=device, dtype=dtype)
+        x2 = torch.rand_like(x1)
+        res = torch.cat([x1, x2], dim=1)
+        ref = torch.cat([x1.cpu(), x2.cpu()], dim=1)
+        self.assertEqual(res, ref)
+
+    @dtypes(torch.float)
+    def test_cat_misaligned(self, device, dtype):
+        x1 = torch.randn(14, device=device, dtype=dtype)[2:]
+        x2 = torch.rand_like(x1)
+        res = torch.cat([x1, x2], dim=-1)
+        ref = torch.cat([x1.cpu(), x2.cpu()], dim=-1)
+        self.assertEqual(res, ref)
+
+    @dtypes(torch.float)
+    def test_cat_multi_batch(self, device, dtype):
+        xs = [torch.randn(16, 16, device=device, dtype=dtype) for _ in range(130)]
+        xs_cpu = [x.cpu() for x in xs]
+        res = torch.cat(xs, dim=-1)
+        ref = torch.cat(xs_cpu, dim=-1)
+        self.assertEqual(res, ref)
+
     # FIXME: Create an OpInfo-based tensor creation method test that verifies this for all tensor
     # creation methods and verify all dtypes and layouts
     @dtypes(torch.bool, torch.uint8, torch.int16, torch.int64, torch.float16, torch.float32, torch.complex64)