Files
pytorch/aten/src/ATen/native/cuda/Shape.cu
Nikita Shulga a229e78544 [BE] Enforce sign-compare (#96723)
A number of OSS PRs were reverted because of new signed-unsigned comparison warnings, which are treated as errors in some internal builds.
Not sure how those selective rules are applied, but this PR removes `-Wno-sign-compare` from PyTorch codebase.

The only tricky part of this PR was making sure that non-ASCII character detection works for both signed and unsigned chars here:
6e3d51b08a/torch/csrc/jit/serialization/python_print.cpp (L926)

Exclude several files from sign-compare if flash attention is used, due to the violation in cutlass, to be fixed by https://github.com/NVIDIA/cutlass/pull/869
Do not try to fix sign compare violations in caffe2 codebase
Pull Request resolved: https://github.com/pytorch/pytorch/pull/96723
Approved by: https://github.com/albanD
2023-03-15 06:04:20 +00:00

324 lines
12 KiB
Plaintext

#define TORCH_ASSERT_ONLY_METHOD_OPERATORS
#include <ATen/core/Tensor.h>
#include <ATen/cuda/CUDAContext.h>
#include <ATen/MemoryOverlap.h>
#include <ATen/cuda/detail/IndexUtils.cuh>
#include <ATen/native/Resize.h>
#include <ATen/native/TypeProperties.h>
#include <ATen/native/TensorShape.h>
#include <ATen/Dispatch.h>
#include <c10/core/MemoryFormat.h>
#include <c10/util/Optional.h>
#ifndef AT_PER_OPERATOR_HEADERS
#include <ATen/Functions.h>
#include <ATen/NativeFunctions.h>
#else
#include <ATen/ops/cat_native.h>
#include <ATen/ops/copy_native.h>
#include <ATen/ops/empty.h>
#include <ATen/ops/empty_like.h>
#include <ATen/ops/narrow.h>
#endif
namespace at::native {
// Number of input tensors whose metadata is packed into a single
// CatArrayBatchedCopy launch when all inputs are contiguous; the strided path
// in cat_out_cuda uses half of this value.
constexpr int CAT_ARRAY_BATCH_SIZE = 128;
// Maximum tensor rank handled by the batched kernel. Also the fixed length of
// the size/stride arrays that are passed to the kernel by value.
constexpr int CAT_ARRAY_MAX_INPUT_DIMS = 4;
namespace {
// Builds the launch grid for the batched cat kernel.
//   grid.x: 2 * (multiprocessor count of the current device); these blocks
//           cooperate on a single input tensor.
//   grid.y: one entry per input tensor in the batch (nTensors).
// Always returns true.
inline bool getCatGrid(ptrdiff_t nTensors, dim3& grid) {
  const auto* deviceProps = at::cuda::getCurrentDeviceProperties();
  const int smCount = deviceProps->multiProcessorCount;
  const long long blocksPerTensor = 2LL * smCount;
  grid = dim3(blocksPerTensor, static_cast<long long>(nTensors));
  return true;
}
// Index-to-offset helper for copying along a given dimension, in the spirit of
// the other IndexToOffset calculations.
//
// compute() peels the dimensions Dims-1 .. 1 off `linearIndex` one at a time.
// For each dimension it takes the remainder with respect to that dimension's
// extent (using `dimSize` instead of tensorSize[i] for the concat dimension),
// scales it by the dimension's stride and accumulates. Whatever is left of
// the index afterwards is scaled by the stride of dimension 0.
template <typename IndexType, int Dims>
struct CatArrIndexToOffset {
  static inline __device__ IndexType compute(
      const IndexType tensorSize[Dims],
      const IndexType tensorStride[Dims],
      const IndexType dimSize,
      const unsigned int concatDim,
      IndexType linearIndex) {
    // Note: linearIndex is really the offset inside the input tensor. For a
    // contiguous input that coincides with the linear index; for a channels
    // last input it is the linear index of the permuted contiguous tensor.
    IndexType accumulated = 0;

    #pragma unroll
    for (int i = Dims - 1; i >= 1; --i) {
      const IndexType extent = (i == concatDim) ? dimSize : tensorSize[i];
      const IndexType remaining = linearIndex / extent;
      // Index within dimension i, i.e. linearIndex modulo extent.
      const IndexType indexInDim = linearIndex - extent * remaining;
      accumulated += indexInDim * tensorStride[i];
      linearIndex = remaining;
    }

    return accumulated + linearIndex * tensorStride[0];
  }
};
// Fixed-capacity bundle of per-dimension sizes and strides that is handed to
// the kernel by value. The host code in parallel_cat fills only the leading
// entries (one per tensor dimension); the remaining entries are left unset.
template<typename IndexType, unsigned int MaxDims>
struct TensorSizeStride {
// Extent of each dimension.
IndexType tensorSize[MaxDims];
// Stride of each dimension; the kernel uses these to index elements, not bytes.
IndexType tensorStride[MaxDims];
};
/**
* Kernel used to concatenate gridDim.y tensors into an output tensor. Uses a
* grid-stride loop based off of the blockIdx.x, threadIdx.x for each input to
* copy each element from each input tensor into the output.
*
* output: base pointer to the storage associated with the output tensor
* inputs: GPU-allocated array of input metadata for each input to concatenate
* in the kernel
* os: the size/stride vectors for the output tensor
* concatDim: dimension along which we are concatenating
* dimStride: the stride of the output tensor at the concatDim
*
* The most important assumption made is that the input tensors are contiguous.
*/
// Pass the metadata directly as a kernel argument instead of through pinned
// memory.
// In the contiguous case the per-input size/stride entries are not needed, so
// stride_size is set to 1 as a placeholder to make the code compile.
//
// Slot k of every array describes the input handled by the blocks with
// blockIdx.y == k.
template <typename T, typename IndexType, int n, int stride_size>
struct CatArrInputTensorMetadata {
// Base data pointer of each input.
T* input[n];
// Start of each input along the concat dimension of the output (running sum
// of the dimSize values of the inputs that precede it).
IndexType offset[n];
// Extent of each input along the concat dimension (0 for an empty input).
IndexType dimSize[n];
// Total number of elements of each input.
IndexType nElements[n];
// True: the kernel reads the input with its flat index.
// False: the kernel maps the index through tensorStride first.
bool isContiguous[n];
// Per-input sizes and strides; only populated by the host when
// stride_size > 1.
TensorSizeStride<IndexType, CAT_ARRAY_MAX_INPUT_DIMS> tensorStride[stride_size];
};
// Copies every input tensor of one batch into its slice of `output`.
//
// Launch layout: blockIdx.y selects the input slot in `inputs`; the threads
// along x walk that input's elements, each starting at its global x index and
// advancing by gridDim.x * blockDim.x per iteration.
//
//   output:    base pointer of the output tensor
//   inputs:    per-input metadata (pointer, offset, sizes), passed by value
//   os:        size/stride vectors of the output tensor
//   concatDim: dimension along which we concatenate
//   dimStride: stride of the output tensor at concatDim
template <typename T, typename IndexType, int Dims, int batch_size, int stride_size>
__global__ void CatArrayBatchedCopy(
    T* output,
    CatArrInputTensorMetadata<T, IndexType, batch_size, stride_size> inputs,
    TensorSizeStride<IndexType, CAT_ARRAY_MAX_INPUT_DIMS> os,
    const int concatDim,
    IndexType dimStride) {
  const unsigned int slot = blockIdx.y;
  const IndexType count = inputs.nElements[slot];

  IndexType idx = blockIdx.x * blockDim.x + threadIdx.x;
  if (idx >= count) {
    return;
  }

  // With stride_size == 1 only entry 0 exists (placeholder, never consulted
  // because isContiguous is true for every slot in that configuration).
  const TensorSizeStride<IndexType, CAT_ARRAY_MAX_INPUT_DIMS> ins =
      inputs.tensorStride[stride_size > 1 ? slot : 0];
  const bool readLinear = inputs.isContiguous[slot];

  T* src = inputs.input[slot];
  const IndexType dimSize = inputs.dimSize[slot];
  // Where this input's slice begins inside the output.
  const IndexType outBase = inputs.offset[slot] * dimStride;
  const IndexType step = gridDim.x * blockDim.x;

  for (; idx < count; idx += step) {
    const IndexType outOffset = CatArrIndexToOffset<IndexType, Dims>::compute(
        os.tensorSize, os.tensorStride, dimSize, concatDim, idx);
    const IndexType srcOffset = readLinear
        ? idx
        : CatArrIndexToOffset<IndexType, Dims>::compute(
              ins.tensorSize, ins.tensorStride, dimSize, concatDim, idx);
    output[outBase + outOffset] = src[srcOffset];
  }
}
// Concatenates `inputs` into `out` along `dimension` by launching
// CatArrayBatchedCopy once per group of up to `batch_size` inputs.
//
//   out:           destination tensor
//   inputs:        tensors to concatenate, in order
//   dimension:     concat dimension, expressed in the tensors' own dimension
//                  order (i.e. NOT permuted for channels last)
//   nDims:         number of dimensions of the (non-empty) inputs; a kernel is
//                  only instantiated for 1..4
//   memory_format: Contiguous, ChannelsLast or ChannelsLast3d; anything else
//                  fails a TORCH_CHECK
//
// stride_size == 1 selects the contiguous-input path (no per-input strides are
// sent to the kernel); stride_size > 1 sends sizes/strides for every input.
template <typename scalar_t, int batch_size, int stride_size>
void parallel_cat(const Tensor &out, const MaterializedITensorListRef& inputs, int64_t dimension,
                  int nDims, c10::MemoryFormat memory_format) {
  // First, let's set up our kernel parameters. We start with a raw pointer to
  // the storage for the output Tensor.
  scalar_t *data = out.data_ptr<scalar_t>();
  CatArrInputTensorMetadata<scalar_t, unsigned int, batch_size, stride_size> catMetaData;
  TensorSizeStride<unsigned int, CAT_ARRAY_MAX_INPUT_DIMS> outputParam;

  // Next, let's initialize the size, stride arrays for the output Tensor.
  if (memory_format == c10::MemoryFormat::Contiguous) {
    for (int i = 0; i < nDims; ++i) {
      outputParam.tensorSize[i] = out.size(i);
      outputParam.tensorStride[i] = out.stride(i);
    }
  } else if (memory_format == c10::MemoryFormat::ChannelsLast || memory_format == c10::MemoryFormat::ChannelsLast3d) {
    // permute the semantics of dims from NCHW to NHWC so that the input
    // tensor is now contiguous
    outputParam.tensorSize[0] = out.size(0);
    outputParam.tensorStride[0] = out.stride(0);
    for (int i = 1; i < nDims - 1; ++i) {
      outputParam.tensorSize[i] = out.size(i + 1);
      outputParam.tensorStride[i] = out.stride(i + 1);
    }
    outputParam.tensorSize[nDims - 1] = out.size(1);
    outputParam.tensorStride[nDims - 1] = out.stride(1);
  } else {
    TORCH_CHECK(false, "unsupported memory format");
  }

  // Concat dimension as the kernel sees it. For the permuted (channels last)
  // layout set up above, dim 1 moves to the last position and every later dim
  // shifts down by one. This is computed exactly once and kept separate from
  // `dimension`: the loop below still needs the unpermuted `dimension` for
  // Tensor::size(), and remapping per batch would permute the value again on
  // every batch after the first.
  int64_t kernelDim = dimension;
  if (memory_format != c10::MemoryFormat::Contiguous) {
    switch (dimension) {
      case 0:
        break;
      case 1:
        kernelDim = nDims - dimension;
        break;
      default:
        kernelDim = dimension - 1;
    }
  }

  at::cuda::CUDAStream stream = at::cuda::getCurrentCUDAStream();

  // Now we loop over the inputs, one kernel launch per batch.
  int batchCounter = 0;
  // Running position along the concat dimension of the output.
  int64_t offset = 0;
  for (unsigned i = 0; i < inputs.size() ; i += batch_size) {
    for (batchCounter = 0;
          batchCounter < batch_size &&
            (i+batchCounter) < inputs.size();
          ++batchCounter) {
      const Tensor& input = inputs[i+batchCounter].get();
      int64_t dimSize = 0;
      // There is a legacy case where a 1-D empty tensor can be concat with
      // high-dimensional tensor
      if (input.numel() > 0) {
        dimSize = input.size(dimension);
      }

      catMetaData.input[batchCounter] = input.data_ptr<scalar_t>();
      catMetaData.offset[batchCounter] = offset;
      catMetaData.dimSize[batchCounter] = dimSize;
      catMetaData.nElements[batchCounter] = input.numel();
      if (stride_size > 1) {
        auto strides = input.strides();
        auto sizes = input.sizes();
        // The legacy 1-D empty input has fewer than nDims dimensions; do not
        // index past the end of its size/stride arrays. The kernel returns
        // before touching these entries when nElements is 0, so zero is fine.
        const int64_t inputDims = input.dim();
        for(int j = 0; j < nDims; j++){
          const bool inRange = j < inputDims;
          catMetaData.tensorStride[batchCounter].tensorSize[j] = inRange ? sizes[j] : 0;
          catMetaData.tensorStride[batchCounter].tensorStride[j] = inRange ? strides[j] : 0;
        }
        catMetaData.isContiguous[batchCounter] = false;
      } else {
        catMetaData.isContiguous[batchCounter] = true;
      }
      // update offset
      offset += dimSize;
    }

    // Next, let's consider how we set our kernel launch parameters.
    // We borrow from THCApply, which the kernel's internal indexing
    // is based on.
    dim3 applyBlock = dim3(32*16);

    //Get grid where x dim fills half gpu and y dim is number of tensors.
    //This will have cating two tensors fill the entire grid, but prevent
    //many threads from needlessly load meta data if their sizes is small.
    dim3 catGrid;
    getCatGrid(batchCounter, catGrid);

    // Template Declarations for dim = 1, 2, 3, 4
#define HANDLE_CASE(DIMS) \
    CatArrayBatchedCopy<scalar_t, unsigned int, DIMS, batch_size, stride_size><<<\
        catGrid, applyBlock, 0, stream.stream()>>>(\
            data, catMetaData, outputParam, kernelDim, outputParam.tensorStride[kernelDim]); \
    C10_CUDA_KERNEL_LAUNCH_CHECK();
    switch (nDims) {
      case 1:
        HANDLE_CASE(1);
        break;
      case 2:
        HANDLE_CASE(2);
        break;
      case 3:
        HANDLE_CASE(3);
        break;
      case 4:
        HANDLE_CASE(4);
        break;
    }
#undef HANDLE_CASE
  }
}
} // namespace
// CUDA implementation of cat: writes the concatenation of `tensors` along
// `dim` into `result`. `valid` is the index of the input whose dimensionality
// is used as nDims.
//
// Three strategies, tried in order:
//   1. Batched kernel, contiguous inputs (CAT_ARRAY_BATCH_SIZE inputs per
//      launch, no stride metadata sent to the kernel).
//   2. Batched kernel, strided inputs (half as many inputs per launch because
//      sizes/strides travel with every input); only for the Contiguous memory
//      format and at most CAT_ARRAY_MAX_INPUT_DIMS input dimensions.
//   3. Fallback: narrow() the result and copy_() each input into its slice.
//
// Both batched strategies additionally require all of:
//   - more than one input tensor
//   - result has at most CAT_ARRAY_MAX_INPUT_DIMS dimensions
//   - result can use 32-bit index math
//   - every input can use 32-bit index math
//   - every input has the same dtype
TORCH_IMPL_FUNC(cat_out_cuda)
(const ITensorListRef& tensors,
 int64_t dim,
 int64_t valid,
 bool all_contiguous,
 bool all_same_dtype,
 bool all_same_sizes_and_stride,
 MemoryFormat memory_format,
 const Tensor& result) {
  // Nothing to write into an empty result.
  if (result.numel() == 0) {
    return;
  }

  auto materialized = tensors.materialize();

  const bool all32BitIndexable = std::all_of(materialized.begin(), materialized.end(),
    [] (const Tensor& t) {
      return at::cuda::detail::canUse32BitIndexMath(t);
    });

  int nDims = materialized[valid].get().dim();

  // Requirements shared by both batched-kernel strategies.
  const bool batchedKernelEligible =
      materialized.size() > 1 &&
      result.dim() <= CAT_ARRAY_MAX_INPUT_DIMS &&
      at::cuda::detail::canUse32BitIndexMath(result) &&
      all32BitIndexable &&
      all_same_dtype;

  // Strategy 1: contiguous inputs, full batch size.
  if (batchedKernelEligible && all_contiguous) {
    AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND4(
        kComplexHalf, kHalf, kBool, kBFloat16,
        result.scalar_type(), "cat_cuda", [&]() {
      parallel_cat<scalar_t, CAT_ARRAY_BATCH_SIZE, 1>(result, materialized, dim, nDims, memory_format);
    });
    return;
  }

  // Strategy 2: strided inputs, half batch size with per-input strides.
  if (batchedKernelEligible &&
      nDims <= CAT_ARRAY_MAX_INPUT_DIMS &&
      memory_format == c10::MemoryFormat::Contiguous) {
    AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND4(
        kComplexHalf, kHalf, kBool, kBFloat16,
        result.scalar_type(), "cat_cuda", [&]() {
      parallel_cat<scalar_t, CAT_ARRAY_BATCH_SIZE/2, CAT_ARRAY_BATCH_SIZE/2>(result, materialized, dim, nDims, memory_format);
    });
    return;
  }

  // Strategy 3: copy the inputs one by one into consecutive slices of result.
  int64_t sliceStart = 0;
  for (const Tensor& src : materialized) {
    if (cat_should_skip_tensor(src)) {
      continue;
    }
    const int64_t sliceLen = src.size(dim);
    Tensor dst = at::narrow(result, dim, sliceStart, sliceLen);
    copy_(dst, src);
    sliceStart += sliceLen;
  }
}
} // namespace at::native