diff --git a/aten/src/ATen/Declarations.cwrap b/aten/src/ATen/Declarations.cwrap index 07f5667bec0e..6a305f8546e3 100644 --- a/aten/src/ATen/Declarations.cwrap +++ b/aten/src/ATen/Declarations.cwrap @@ -1190,19 +1190,3 @@ - THTensor* self - THTensor* src ]] - -[[ - name: _th_cat - cname: catArray - variants: [function] - backends: - - CUDA - cuda_bool: True - cuda_bfloat16: True - return: self - arguments: - - arg: THTensor* self - output: True - - TensorList tensors - - int64_t dim -]] diff --git a/aten/src/ATen/cuda/detail/IndexUtils.cu b/aten/src/ATen/cuda/detail/IndexUtils.cu index c5fbe9fa0846..31b6f385ba77 100644 --- a/aten/src/ATen/cuda/detail/IndexUtils.cu +++ b/aten/src/ATen/cuda/detail/IndexUtils.cu @@ -74,6 +74,8 @@ bool canUse32BitIndexMath(const Tensor& t, int64_t max_elem) { int64_t elements = t.numel(); if (elements >= max_elem) { return false; + } else if (elements == 0) { + return true; } int64_t offset = 0; diff --git a/aten/src/ATen/native/cuda/Shape.cu b/aten/src/ATen/native/cuda/Shape.cu new file mode 100644 index 000000000000..6f4505323a0f --- /dev/null +++ b/aten/src/ATen/native/cuda/Shape.cu @@ -0,0 +1,360 @@ +#include +#include +#include +#include +#include +#include +#include +#include + +#include +#include +#include +#include +#include +#include + +#include + +namespace at { +namespace native { + +constexpr int CAT_ARRAY_BATCH_SIZE = 1024; +constexpr int CAT_ARRAY_MAX_INPUT_DIMS = 4; + +namespace { + +inline bool getCatGrid(ptrdiff_t nTensors, dim3& grid) { + const int numSM = at::cuda::getCurrentDeviceProperties()->multiProcessorCount; + + //X dim of grid for cat array cooperates on a single tensor in the cat. + //Given half of the GPU, full utilization will always occur. + grid = dim3( 2LL * numSM, (long long) nTensors ); + + return true; +} + +// Similar to any other IndexToOffset calculation for copying along a given +// dimension. +template +struct CatArrIndexToOffset { + static inline __device__ IndexType compute( + const IndexType outputSize[Dims], + const IndexType outputStride[Dims], + const IndexType dimSize, + const unsigned int concatDim, + IndexType linearIndex) { + IndexType offset = 0; + +#pragma unroll + for (int i = Dims - 1; i >= 1; --i) { + IndexType curDimSize = i == concatDim ? dimSize : outputSize[i]; + IndexType nextDimIndex = linearIndex / curDimSize; + IndexType curDimIndex = linearIndex - curDimSize * nextDimIndex; + IndexType curDimOffset = curDimIndex * outputStride[i]; + offset += curDimOffset; + linearIndex = nextDimIndex; + } + + return offset + linearIndex * outputStride[0]; + } +}; + +template +struct CatArrInputTensor { + T* input; + IndexType offset; + IndexType dimSize; + IndexType nElements; +}; + +template +struct OutputTensorSizeStride { + IndexType outputSize[MaxDims]; + IndexType outputStride[MaxDims]; +}; + +/** + * Kernel used to concatenated grimDim.y tensors into an output tensor. Uses a + * grid-stride loop based off of the blockIdx.x, threadIdx.x for each input to + * copy each element from each input tensor into the output. + * + * output: base pointer to the storage associated with the output tensor + * inputs: GPU-allocated array of input metadata for each input to concatenate + * in the kernel + * os: the size/stride vectors for the output tensor + * concatDim: dimension along which we are concatenating + * dimStride: the stride of the output tensor at the concatDim + * + * The most important assumption made is that the input tensors are contiguous. 
+ */ +template +#ifdef __HIP_PLATFORM_HCC__ +C10_LAUNCH_BOUNDS_1(512) +#endif +__global__ void CatArrayBatchedCopy( + T* output, + CatArrInputTensor* inputs, + OutputTensorSizeStride os, + const int concatDim, + IndexType dimStride) { + + IndexType tid = blockIdx.x * blockDim.x + threadIdx.x; + IndexType nElements = inputs[blockIdx.y].nElements; + + if(tid >= nElements) return; + + T* data = inputs[blockIdx.y].input; + IndexType offset = inputs[blockIdx.y].offset; + IndexType dimSize = inputs[blockIdx.y].dimSize; + IndexType dataOffset = offset * dimStride; + + IndexType stride = gridDim.x * blockDim.x; + + while( tid < nElements){ + IndexType elementOffset = CatArrIndexToOffset::compute( + os.outputSize, os.outputStride, dimSize, concatDim, tid); + output[dataOffset + elementOffset] = data[tid]; + + tid += stride; + } +} + +void check_shape_except_dim(const Tensor &first, const Tensor &second, + int dimension) +{ + int first_dims = first.dim(); + int second_dims = second.dim(); + TORCH_CHECK(first_dims == second_dims, + "Tensors must have same number of dimensions: got ", first_dims, + " and ", second_dims); + for (int dim = 0; dim < first_dims; dim++) { + if (dim == dimension) { + continue; + } + int64_t first_dim_size = at::native::size(first, dim); + int64_t second_dim_size = at::native::size(second, dim); + TORCH_CHECK(first_dim_size == second_dim_size, + "Sizes of tensors must match except in dimension ", dim, ". Got ", + static_cast(first_dim_size), " and ", + static_cast(second_dim_size)); + } +} + +template +void parallel_cat(Tensor &out, const TensorList &inputs, int64_t dimension, + int nDims) { + // First, let's set up our kernel parameters. We start with a raw pointer to + // the storage for the output Tensor. + scalar_t *data = out.data_ptr(); + + // Kernel Parameter + long tensorMetadataSize = + sizeof(CatArrInputTensor) * CAT_ARRAY_BATCH_SIZE; + auto d_inputs_storage = at::empty( + {tensorMetadataSize}, out.options().dtype(at::kByte)); + auto d_inputs = static_cast *>( + d_inputs_storage.data_ptr()); + + OutputTensorSizeStride param; + + // Next, let's initialize the size, stride arrays for the output Tensor. + for (int i = 0; i < nDims; ++i) { + param.outputSize[i] = at::native::size(out, i); + param.outputStride[i] = out.stride(i); + } + + at::cuda::CUDAStream stream = at::cuda::getCurrentCUDAStream(); + + // Now we loop + int batchCounter = 0; + int64_t offset = 0; + for (int i = 0; i < inputs.size() ; i += CAT_ARRAY_BATCH_SIZE) { + // Re-allocate stackInputs every iteration to avoid read-after-write hazard + { + auto stackInputs_storage = at::empty({tensorMetadataSize}, + out.options().dtype(at::kByte).device(at::kCPU).pinned_memory(true)); + auto stackInputs = + static_cast *>( + stackInputs_storage.data_ptr()); + for (batchCounter = 0; + batchCounter < CAT_ARRAY_BATCH_SIZE && + (i+batchCounter) < inputs.size(); + ++batchCounter) { + int64_t dimSize = at::native::size(inputs[i+batchCounter], dimension); + + stackInputs[batchCounter].input = + inputs[i+batchCounter].data_ptr(); + stackInputs[batchCounter].offset = offset; + stackInputs[batchCounter].dimSize = dimSize; + stackInputs[batchCounter].nElements = inputs[i+batchCounter].numel(); + + // update offset + offset += dimSize; + } + at::native::copy_(d_inputs_storage, stackInputs_storage, + /* non_blocking= */ true); + } + + // Next, let's consider how we set our kernel launch parameters. + // We borrow from THCApply, which the kernel's internal indexing + // is based on. 
+      dim3 applyBlock = dim3(32*16);
+
+      // Get a grid whose x dim fills half the GPU and whose y dim is the number
+      // of tensors. This makes cat'ing two tensors fill the entire grid, but
+      // prevents many threads from needlessly loading metadata when the
+      // inputs are small.
+      dim3 catGrid;
+      getCatGrid(batchCounter, catGrid);
+
+      // Template Declarations for dim = 1, 2, 3, 4
+#define HANDLE_CASE(DIMS) \
+      CatArrayBatchedCopy<scalar_t, unsigned int, DIMS><<<\
+          catGrid, applyBlock, 0, stream.stream()>>>(\
+              data, d_inputs, param, dimension, param.outputStride[dimension]);
+      switch (nDims) {
+        case 1:
+          HANDLE_CASE(1);
+          break;
+        case 2:
+          HANDLE_CASE(2);
+          break;
+        case 3:
+          HANDLE_CASE(3);
+          break;
+        case 4:
+          HANDLE_CASE(4);
+          break;
+      }
+#undef HANDLE_CASE
+      THCudaCheck(cudaGetLastError());
+  }
+}
+
+} // namespace
+
+Tensor cat_cuda(TensorList inputs, int64_t dimension) {
+  Tensor out = at::empty({0}, inputs.front().options());
+  cat_out_cuda(out, inputs, dimension);
+  return out;
+}
+
+Tensor& cat_out_cuda(Tensor& out, TensorList inputs, int64_t dimension) {
+
+  // previously, size [0] tensors were the only possible empty tensors; thus, it
+  // wasn't possible to cat empty tensors unless all the other tensors were
+  // 1-dimensional, so we allowed these tensors to be "skipped". We maintain
+  // this behavior for backwards compatibility, but only for this specific size
+  // (i.e. other empty sizes are not skipped).
+  // FIXME: warn if this is the case
+  auto should_skip = [](const Tensor &t) {
+    return t.dim() == 1 && at::native::size(t, 0) == 0;
+  };
+  bool hasSkippedInput = false;
+
+  const Tensor *notSkippedTensor = NULL;  // non-owning reference
+  int nDims = 0;
+
+  // Inputs cannot alias the output tensor
+  for (int i = 0; i < inputs.size(); i++) {
+    auto lap = at::get_overlap_status(out, inputs[i]);
+    TORCH_CHECK(lap != at::MemOverlapStatus::PARTIAL &&
+        lap != at::MemOverlapStatus::FULL,
+        "unsupported operation: the input tensors cannot refer to any "
+        "of the output memory locations. Found overlap in input "
+        "tensor ", i);
+  }
+
+  for (int i = 0; i < inputs.size(); i++)
+  {
+    if (should_skip(inputs[i])) {
+      hasSkippedInput = true;
+      continue;
+    }
+    nDims = inputs[i].dim();
+    notSkippedTensor = &inputs[i];
+  }
+
+  // If all inputs are empty tensors, return an empty tensor
+  if (notSkippedTensor == NULL) {
+    return out;
+  }
+
+  TORCH_CHECK(inputs.size() > 0, "invalid number of inputs ", inputs.size());
+  TORCH_CHECK(dimension >= 0, "invalid dimension ", dimension);
+
+  std::vector<int64_t> size(notSkippedTensor->sizes().vec());
+
+  // Compute size of the result in the cat dimension
+  int64_t cat_dim_size = 0;
+  for (int i = 0; i < inputs.size(); i++) {
+    const Tensor &tensor = inputs[i];
+    if (should_skip(tensor)) {
+      continue;
+    }
+    check_shape_except_dim(*notSkippedTensor, tensor, dimension);
+    cat_dim_size += at::native::size(tensor, dimension);
+  }
+
+  // Compute the size of the result
+  size[dimension] = cat_dim_size;
+  out.resize_(size);
+  if (out.numel() == 0) {
+    return out;
+  }
+
+  // We parallelize the copy if all 7 conditions pass:
+  //
+  // 1. There is more than one input tensor
+  // 2. No empty inputs
+  // 3. The out tensor is 32-bit indexable
+  // 4. The number of dimensions is <= 4
+  // 5. All input tensors are contiguous (output tensor may be non-contig)
+  // 6. All input tensors can use 32-bit indexing
+  // 7.
All input tensors are on the same device + + const bool all32BitIndexable = std::all_of(inputs.begin(), inputs.end(), + [] (const Tensor& t) { + return at::cuda::detail::canUse32BitIndexMath(t); + }); + Device firstDevice = notSkippedTensor->device(); + const bool allSameDevice = std::all_of(inputs.begin(), inputs.end(), + [firstDevice](const Tensor& t) { + return t.device() == firstDevice; + }); + const bool allContiguous = std::all_of(inputs.begin(), inputs.end(), + [](const Tensor& t) { + return !t.defined() || t.is_contiguous(); + }); + if (inputs.size() > 1 && + !hasSkippedInput && + out.dim() <= CAT_ARRAY_MAX_INPUT_DIMS && + at::cuda::detail::canUse32BitIndexMath(out) && + allContiguous && + all32BitIndexable && + allSameDevice) { + + AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND3( + at::ScalarType::Half, at::ScalarType::Bool, at::ScalarType::BFloat16, + out.scalar_type(), "cat_cuda", [&]() { + parallel_cat(out, inputs, dimension, nDims); + }); + + } else { + int64_t offset = 0; + for (int j = 0; j < inputs.size(); j++) + { + if (should_skip(inputs[j])) continue; + int64_t dimSize = at::native::size(inputs[j], dimension); + Tensor nt = at::narrow(out, dimension, offset, dimSize); + copy_(nt, inputs[j]); + offset += dimSize; + } + } + + return out; +} + +} // namespace native +} // namespace at diff --git a/aten/src/ATen/native/native_functions.yaml b/aten/src/ATen/native/native_functions.yaml index 61027accce41..47b75afb97a4 100644 --- a/aten/src/ATen/native/native_functions.yaml +++ b/aten/src/ATen/native/native_functions.yaml @@ -5290,13 +5290,13 @@ - func: _cat(Tensor[] tensors, int dim=0) -> Tensor dispatch: CPU: _cat_cpu - CUDA: legacy::cuda::_th_cat + CUDA: cat_cuda QuantizedCPU: quantized_cat - func: _cat.out(Tensor[] tensors, int dim=0, *, Tensor(a!) out) -> Tensor(a!) dispatch: CPU: _cat_out_cpu - CUDA: legacy::cuda::_th_cat_out + CUDA: cat_out_cuda QuantizedCPU: quantized_cat_out - func: _mode(Tensor self, int dim=-1, bool keepdim=False) -> (Tensor, Tensor) diff --git a/aten/src/THC/THCTensorMath.cuh b/aten/src/THC/THCTensorMath.cuh index 4613fbf2cd48..79b66e80ed5a 100644 --- a/aten/src/THC/THCTensorMath.cuh +++ b/aten/src/THC/THCTensorMath.cuh @@ -31,111 +31,4 @@ __global__ void THCTensor_copyToDiagonal(T* a, T* b, ptrdiff_t start, ptrdiff_t } } -#define CAT_ARRAY_BATCH_SIZE 1024 -#define CAT_ARRAY_MAX_INPUT_DIMS 4 - -inline bool getCatGrid(THCState* state, ptrdiff_t nTensors, dim3& grid) { - int curDevice = -1; - cudaGetDevice(&curDevice); - - if (curDevice == -1) { - return false; - } - - // Assume a reasonable number of SMs if no state is available - int numSM = - state ? at::cuda::getCurrentDeviceProperties()->multiProcessorCount : 15; - //X dim of grid for cat array cooperates on a single tensor in the cat. - //Given half of the GPU, full utilization will always occur. - grid = dim3( 2LL * numSM, (long long) nTensors ); - - return true; -} - -// Similar to any other IndexToOffset calculation for copying along a given dimension. -template -struct CatArrIndexToOffset { - static inline __device__ IndexType compute( - const IndexType outputSize[Dims], - const IndexType outputStride[Dims], - const IndexType dimSize, - const unsigned int concatDim, - IndexType linearIndex) { - IndexType offset = 0; - -#pragma unroll - for (int i = Dims - 1; i >= 1; --i) { - IndexType curDimSize = i == concatDim ? 
dimSize : outputSize[i]; - IndexType nextDimIndex = linearIndex / curDimSize; - IndexType curDimIndex = linearIndex - curDimSize * nextDimIndex; - IndexType curDimOffset = curDimIndex * outputStride[i]; - offset += curDimOffset; - linearIndex = nextDimIndex; - } - - return offset + linearIndex * outputStride[0]; - } -}; - -template -struct CatArrInputTensor { - T* input; - IndexType offset; - IndexType dimSize; - IndexType nElements; -}; - -template -struct OutputTensorSizeStride { - IndexType outputSize[MaxDims]; - IndexType outputStride[MaxDims]; -}; - -/** - * Kernel used to concatenated grimDim.y tensors into an output tensor. Uses a grid-stride loop based off of - * the blockIdx.x, threadIdx.x for each input to copy each element from each input tensor into the output. - * - * output: base pointer to the storage associated with the output tensor - * inputs: GPU-allocated array of input metadata for each input to concatenate in the kernel - * os: the size/stride vectors for the output tensor - * concatDim: dimension along which we are concatenating - * dimStride: the stride of the output tensor at the concatDim - * - * The most important assumption made is that the input tensors are contiguous. - */ - - - -template -#ifdef __HIP_PLATFORM_HCC__ -C10_LAUNCH_BOUNDS_1(512) -#endif -__global__ void CatArrayBatchedCopy( - T* output, - CatArrInputTensor* inputs, - OutputTensorSizeStride os, - const int concatDim, - IndexType dimStride) { - - IndexType tid = blockIdx.x * blockDim.x + threadIdx.x; - IndexType nElements = inputs[blockIdx.y].nElements; - - if(tid >= nElements) return; - - T* data = inputs[blockIdx.y].input; - IndexType offset = inputs[blockIdx.y].offset; - IndexType dimSize = inputs[blockIdx.y].dimSize; - IndexType dataOffset = offset * dimStride; - - IndexType stride = gridDim.x * blockDim.x; - - while( tid < nElements){ - IndexType elementOffset = CatArrIndexToOffset::compute( - os.outputSize, os.outputStride, dimSize, concatDim, tid); - output[dataOffset + elementOffset] = data[tid]; - - tid += stride; - } -} - #endif diff --git a/aten/src/THC/generic/THCTensorMath.cu b/aten/src/THC/generic/THCTensorMath.cu index af53fa0314eb..d86199129140 100644 --- a/aten/src/THC/generic/THCTensorMath.cu +++ b/aten/src/THC/generic/THCTensorMath.cu @@ -42,15 +42,6 @@ THCTensor_(numel)(THCState *state, THCTensor *t) return THCTensor_(nElement)(state, t); } -void THCTensor_(cat)(THCState *state, THCTensor *result, - THCTensor *ta, THCTensor *tb, int dimension) -{ - THCTensor* inputs[2]; - inputs[0] = ta; - inputs[1] = tb; - THCTensor_(catArray)(state, result, inputs, 2, dimension); -} - void THCTensor_(check_shape_except_dim)(THCState *state, THCTensor *first, THCTensor *second, int dimension); inline void THCTensor_(check_shape_except_dim)(THCState *state, @@ -73,185 +64,6 @@ inline void THCTensor_(check_shape_except_dim)(THCState *state, } } -void THCTensor_(catArray)(THCState *state, THCTensor *result, - THCTensor **inputs, int numInputs, int dimension) -{ - // previously, size [0] tensors were the only possible empty tensors; thus, it wasn't possible - // to cat empty tensors unless all the other tensors were 1-dimensional, so we allowed these tensors - // to be "skipped". We maintain this behavior for backwards compatibility, but only for this specific - // size (i.e. other empty sizes are not skipped). 
- // FIXME: warn if this is the case - int i, j, cohortMax; - int64_t offset; - bool hasSkippedInput = false; - THCTensor *notSkippedTensor = NULL; // non-owning reference - auto should_skip = [](THCTensor *t) { return t->is_empty() && t->dim() == 1; }; - int nDims = 0; - - // Inputs cannot alias the output tensor - for (int i = 0; i < numInputs; i++) { - auto lap = at::get_overlap_status(result, inputs[i]); - THArgCheck(lap != at::MemOverlapStatus::PARTIAL && - lap != at::MemOverlapStatus::FULL, 0, - "unsupported operation: the input tensors cannot refer to any of the " - "output memory locations. Found overlap in input tensor %d.", i); - } - - for (i = 0; i < numInputs; i++) - { - if (should_skip(inputs[i])) { - hasSkippedInput = true; - continue; - } - nDims = inputs[i]->dim(); - notSkippedTensor = inputs[i]; - } - - // If all inputs are empty tensors, return an empty tensor - if (notSkippedTensor == NULL) { - return; - } - - THArgCheck(numInputs > 0, 3, "invalid number of inputs %d", numInputs); - THArgCheck(dimension >= 0, 4, "invalid dimension %d", dimension); - - std::vector size(nDims); - - // Compute size of the result in the cat dimension - int64_t cat_dim_size = 0; - for (int i = 0; i < numInputs; i++) { - THCTensor *tensor = inputs[i]; - if (should_skip(tensor)) { - continue; - } - THCTensor_(check_shape_except_dim)(state, notSkippedTensor, tensor, dimension); - cat_dim_size += THCTensor_(size)(state, tensor, dimension); - } - - // Compute the size of the result - for (int dim = 0; dim < nDims; dim++) { - int64_t result_dim_size = THCTensor_(size)(state, notSkippedTensor, dim); - if (dim == dimension) { - result_dim_size = cat_dim_size; - } - size[dim] = result_dim_size; - } - THCTensor_(resize)(state, result, size, {}); - - // We parallelize the copy if all 6 conditions pass: - // - // 1. There is more than one input tensor - // 2. No empty inputs - // 3. The result tensor is 32-bit indexable - // 4. The number of dimensions is <= 4 - // 5. All input tensors are contiguous (output tensor may be non-contig) - // 6. All input tensors can use 32-bit indexing - // 7. All input tensors are on the same device - - if (numInputs > 1 && - !hasSkippedInput && - result->dim() <= CAT_ARRAY_MAX_INPUT_DIMS && - THCTensor_canUse32BitIndexMath(state, result) && - THCTensor_allContiguous(state, inputs, numInputs) && - THCTensor_all32BitIndexable(state, inputs, numInputs) && - THCTensor_allSameDevice(state, inputs, numInputs)) { - - // First, let's set up our kernel parameters. We start with a raw pointer to the storage - // for the output Tensor. - scalar_t *data = THCTensor_(data)(state, result); - - // Kernel Parameter - size_t tensorMetadataSize = sizeof(CatArrInputTensor) * CAT_ARRAY_BATCH_SIZE; - auto d_inputs = static_cast *>(THCudaMalloc(state, tensorMetadataSize)); - - OutputTensorSizeStride param; - - // Next, let's initialize the size, stride arrays for the output Tensor. 
- for (i = 0; i < nDims; ++i) { - param.outputSize[i] = THCTensor_(size)(state, result, i); - param.outputStride[i] = THCTensor_(stride)(state, result, i); - } - - at::cuda::CUDAStream stream = at::cuda::getCurrentCUDAStream(); - - // Template Declarations for dim = 1, 2, 3, 4 -#define HANDLE_CASE(DIMS) \ - CatArrayBatchedCopy<<>>(data, d_inputs, param, dimension, param.outputStride[dimension]); - - // Now we loop - offset = 0; - for (i = 0; i < numInputs; i += CAT_ARRAY_BATCH_SIZE) { - // Re-allocate stackInputs every iteration to avoid read-after-write hazard - { - auto stackInputs_owner = THCudaHostAlloc(state, tensorMetadataSize); - CatArrInputTensor* stackInputs = static_cast*>(stackInputs_owner.get()); - cohortMax = 0; - for (j = 0; j < CAT_ARRAY_BATCH_SIZE && (i+j) < numInputs; ++j) { - int64_t dimSize = THCTensor_(size)(state, inputs[i+j], dimension); - - stackInputs[j].input = THCTensor_(data)(state, inputs[i+j]); - stackInputs[j].offset = offset; - stackInputs[j].dimSize = dimSize; - stackInputs[j].nElements = THCTensor_(nElement)(state, inputs[i+j]); - cohortMax = cohortMax > (int) stackInputs[j].nElements ? cohortMax : (int) stackInputs[j].nElements; - - // update offset - offset += dimSize; - } - THCudaCheck(cudaMemcpyAsync( - d_inputs, - stackInputs, - j * sizeof(CatArrInputTensor), - cudaMemcpyHostToDevice, - stream.stream())); - THCudaHostRecord(state, stackInputs); - } - - // Next, let's consider how we set our kernel launch parameters. - // We borrow from THCApply, which the kernel's internal indexing - // is based on. - dim3 applyBlock = getApplyBlock(); - - //Get grid where x dim fills half gpu and y dim is number of tensors. - //This will have cating two tensors fill the entire grid, but prevent - //many threads from needlessly load meta data if their sizes is small. 
- dim3 catGrid; - getCatGrid(state, j, catGrid); - - - switch (nDims) { - case 1: - HANDLE_CASE(1); - break; - case 2: - HANDLE_CASE(2); - break; - case 3: - HANDLE_CASE(3); - break; - case 4: - HANDLE_CASE(4); - break; - } - THCudaCheck(cudaGetLastError()); - } - THCudaFree(state, d_inputs); -#undef HANDLE_CASE - } else { - offset = 0; - for (j = 0; j < numInputs; j++) - { - if (should_skip(inputs[j])) continue; - int64_t dimSize = THCTensor_(size)(state, inputs[j], dimension); - THCTensor *nt = THCTensor_(newWithTensor)(state, result); - THCTensor_(narrow)(state, nt, NULL, dimension, offset, dimSize); - THCTensor_(copy)(state, nt, inputs[j]); - THCTensor_(free)(state, nt); - offset += dimSize; - } - } -} - void THCTensor_(nonzero)(THCState* state, THCudaLongTensor *tensor, THCTensor *self) { diff --git a/aten/src/THC/generic/THCTensorMath.h b/aten/src/THC/generic/THCTensorMath.h index a565cec20872..3860b7177aa2 100644 --- a/aten/src/THC/generic/THCTensorMath.h +++ b/aten/src/THC/generic/THCTensorMath.h @@ -4,8 +4,6 @@ THC_API void THCTensor_(fill)(THCState *state, THCTensor *self, scalar_t value); THC_API void THCTensor_(zero)(THCState *state, THCTensor *self); -THC_API void THCTensor_(cat)(THCState *state, THCTensor *result, THCTensor *ta, THCTensor *tb, int dimension); -THC_API void THCTensor_(catArray)(THCState *state, THCTensor *result, THCTensor **inputs, int numInputs, int dimension); THC_API void THCTensor_(nonzero)(THCState* state, THCudaLongTensor *tensor, THCTensor *self); THC_API ptrdiff_t THCTensor_(numel)(THCState *state, THCTensor *t); diff --git a/benchmarks/operator_benchmark/pt/cat_test.py b/benchmarks/operator_benchmark/pt/cat_test.py index 3ba1426602b7..1f5eb9e5ce87 100644 --- a/benchmarks/operator_benchmark/pt/cat_test.py +++ b/benchmarks/operator_benchmark/pt/cat_test.py @@ -5,6 +5,7 @@ from __future__ import unicode_literals import operator_benchmark as op_bench import torch +import random """Microbenchmarks for Cat operator""" @@ -12,11 +13,11 @@ import torch # Configs for PT Cat operator cat_configs_short = op_bench.config_list( - attr_names=['M', 'N', 'K', 'dim'], + attr_names=['sizes', 'N', 'dim'], attrs=[ - [1, 1, 1, 0], - [256, 512, 1, 0], - [512, 512, 2, 1], + [(1, 1, 1), 2, 0], # noqa + [(512, 512, 2), 2, 1], # noqa + [(128, 1024, 2), 2, 1], # noqa ], cross_product_configs={ 'device': ['cpu', 'cuda'], @@ -24,29 +25,81 @@ cat_configs_short = op_bench.config_list( tags=['short'], ) -cat_configs_long = op_bench.cross_product_configs( - M=[128], - N=[128, 1024], - K=[1, 2], - dim=[0, 1, 2], - device=['cpu', 'cuda'], - tags=['long'] +cat_configs_long = op_bench.config_list( + attr_names=['sizes', 'N', 'dim'], + attrs=[ + [(2**10, 2**10, 2), 2, 0], # noqa + [(2**10+1, 2**10-1, 2), 2, 1], # noqa + [(2**10, 2**10, 2), 2, 2], # noqa + + [[ lambda: random.randint(2**6, 2**7), 2**7-17, 2**6+1], # noqa + 5, 0], + [[ 2**6+2**5, lambda: random.randint(2**6, 2**7), 2**6], # noqa + 5, 1], + [[ 2**7, 2**6, lambda: random.randint(2**6, 2**7)], # noqa + 5, 2], + + [[lambda: random.randint(2**5, 2**6), 2**5, 2**6], # noqa + 50, 0], + [[2**5, lambda: random.randint(2**5, 2**6), 2**6], # noqa + 50, 1], + [[2**5+1, 2**6+1, lambda: random.randint(2**5, 2**6)], # noqa + 50, 2], + ], + cross_product_configs={ + 'device': ['cpu', 'cuda'], + }, + tags=['long'], ) +# There is a different codepath on CUDA for >4 dimensions +cat_configs_multidim = op_bench.config_list( + attr_names=['sizes', 'N', 'dim'], + attrs=[ + [(2**6, 2**5, 2**2, 2**4, 2**5), 2, 2], # noqa + [(2**4, 2**5, 2**2, 
2**4, 2**5), 8, 2], # noqa + [(2**3+1, 2**5-1, 2**2+1, 2**4-1, 2**5+1), 17, 4], # noqa + ], + cross_product_configs={ + 'device': ['cpu', 'cuda'], + }, + tags=['multidim'], +) + +cat_configs_manyinputs = op_bench.config_list( + attr_names=['sizes', 'N', 'dim'], + attrs=[ + [[lambda: random.randint(1, 10000)], 100, 0], + [[lambda: random.randint(1, 1000)], 1000, 0], + [[lambda: random.randint(1, 500)], 2000, 0], + [[lambda: random.randint(1, 300)], 3000, 0], + ], + cross_product_configs={ + 'device': ['cpu', 'cuda'], + }, + tags=['manyinputs'], +) class CatBenchmark(op_bench.TorchBenchmarkBase): - def init(self, M, N, K, dim, device): - self.input_one = torch.rand(M, N, K, device=device) + def init(self, sizes, N, dim, device): + random.seed(42) + self.inputs = [] + for i in range(N): + current_sizes = [old_size() if callable(old_size) else old_size + for old_size in sizes] + self.inputs.append(torch.rand(current_sizes, device=device)) self.dim = dim self.set_module_name('cat') def forward(self): - return torch.cat((self.input_one, self.input_one), dim=self.dim) + return torch.cat(self.inputs, dim=self.dim) -op_bench.generate_pt_test(cat_configs_short + cat_configs_long, +op_bench.generate_pt_test(cat_configs_short + + cat_configs_long + + cat_configs_multidim + + cat_configs_manyinputs, CatBenchmark) - if __name__ == "__main__": op_bench.benchmark_runner.main()
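
Note (not part of the patch): a minimal host-side sketch of the index arithmetic that CatArrIndexToOffset::compute performs in Shape.cu. It peels per-dimension indices off a linear index using the input's extents (the output sizes, except dimSize along the concat dimension) and re-linearizes them with the output's strides. The function and variable names below mirror the kernel; kDims and the example shapes are made up for illustration.

// Standalone C++ sketch, assuming only the logic shown in Shape.cu above.
#include <cassert>
#include <cstdint>
#include <cstdio>

constexpr int kDims = 3;  // the kernel templates over Dims = 1..4

int64_t catIndexToOffset(const int64_t outputSize[kDims],
                         const int64_t outputStride[kDims],
                         int64_t dimSize, int concatDim, int64_t linearIndex) {
  int64_t offset = 0;
  for (int i = kDims - 1; i >= 1; --i) {
    // Along the concat dimension, the input's extent is dimSize, not the
    // output's size; every other dimension matches the output.
    int64_t curDimSize = (i == concatDim) ? dimSize : outputSize[i];
    int64_t nextDimIndex = linearIndex / curDimSize;
    int64_t curDimIndex = linearIndex - curDimSize * nextDimIndex;
    offset += curDimIndex * outputStride[i];
    linearIndex = nextDimIndex;
  }
  return offset + linearIndex * outputStride[0];
}

int main() {
  // Contiguous output of shape 2x5x4, concatenating along dim 1; this input
  // contributes dimSize = 2 slices. The kernel adds offset * dimStride
  // (dataOffset) separately, so that term is omitted here.
  const int64_t outSize[kDims] = {2, 5, 4};
  const int64_t outStride[kDims] = {20, 4, 1};
  const int concatDim = 1;
  const int64_t dimSize = 2;
  for (int64_t d0 = 0; d0 < 2; ++d0) {
    for (int64_t d1 = 0; d1 < dimSize; ++d1) {
      for (int64_t d2 = 0; d2 < 4; ++d2) {
        int64_t linear = (d0 * dimSize + d1) * 4 + d2;  // contiguous input index
        int64_t direct = d0 * outStride[0] + d1 * outStride[1] + d2 * outStride[2];
        assert(catIndexToOffset(outSize, outStride, dimSize, concatDim, linear) == direct);
      }
    }
  }
  std::puts("CatArrIndexToOffset sketch matches the direct stride computation");
  return 0;
}

This split is why the kernel only needs the output's size/stride arrays plus each input's dimSize, offset, and element count: the inputs are assumed contiguous, so a linear index fully encodes an input element's coordinates, and real strides are only needed on the write side.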
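
A second illustrative snippet (again not part of the patch, and assuming a libtorch build with a CUDA device) shows the backwards-compatibility rule that cat_out_cuda keeps from the THC path: only a 1-dimensional size-[0] tensor is "skipped"; empty tensors of any other shape must pass check_shape_except_dim and simply contribute nothing along the cat dimension.

#include <torch/torch.h>
#include <iostream>

int main() {
  auto a = torch::rand({2, 3}, torch::kCUDA);
  auto b = torch::rand({4, 3}, torch::kCUDA);
  auto skipped = torch::empty({0}, torch::kCUDA);    // dim() == 1, size(0) == 0

  // The size-[0] 1-D tensor is ignored even though its rank does not match.
  auto out = torch::cat({a, skipped, b}, /*dim=*/0);
  std::cout << out.sizes() << std::endl;             // [6, 3]

  // An empty tensor of any other shape is not skipped: it goes through the
  // shape check and contributes zero elements along the cat dimension.
  auto empty_rows = torch::empty({0, 3}, torch::kCUDA);
  auto out2 = torch::cat({a, empty_rows, b}, /*dim=*/0);
  std::cout << out2.sizes() << std::endl;            // [6, 3]
  return 0;
}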