Migrate _cat from TH to ATen (CUDA) (#33237)
Summary:
Fixes https://github.com/pytorch/pytorch/issues/24520

Benchmarks (Forward Execution Time in us), collected in Eager mode with:

```
$ python -m pt.cat_test --tag_filter all --device cuda --omp_num_threads 1 --mkl_num_threads 1
```

| Input (sizes, N, dim; device=cuda) | Upstream | New version |
|---|---:|---:|
| (1, 1, 1), N=2, dim=0 | 17.355 | 24.419 |
| (512, 512, 2), N=2, dim=1 | 30.718 | 25.025 |
| (128, 1024, 2), N=2, dim=1 | 17.329 | 24.247 |
| (512, 512, 2), N=2, dim=1 | 30.176 | 25.098 |
| (1024, 1024, 2), N=2, dim=0 | 74.417 | 74.441 |
| (1025, 1023, 2), N=2, dim=1 | 75.728 | 74.866 |
| (1024, 1024, 2), N=2, dim=2 | 190.165 | 189.280 |
| [<lambda>, 111, 65], N=5, dim=0 | 57.711 | 57.629 |
| [96, <lambda>, 64], N=5, dim=1 | 49.903 | 49.975 |
| [128, 64, <lambda>], N=5, dim=2 | 84.181 | 83.643 |
| [<lambda>, 32, 64], N=50, dim=0 | 82.339 | 82.307 |
| [32, <lambda>, 64], N=50, dim=1 | 82.312 | 82.323 |
| [33, 65, <lambda>], N=50, dim=2 | 90.715 | 90.549 |
| (64, 32, 4, 16, 32), N=2, dim=2 | 129.021 | 129.022 |
| (16, 32, 4, 16, 32), N=8, dim=2 | 142.966 | 142.969 |
| (9, 31, 5, 15, 33), N=17, dim=4 | 387.023 | 386.973 |
| [<lambda>], N=100, dim=0 | 36.647 | 43.800 |
| [<lambda>], N=1000, dim=0 | 278.890 | 279.023 |
| [<lambda>], N=2000, dim=0 | 557.752 | 565.790 |
| [<lambda>], N=3000, dim=0 | 842.512 | 845.153 |

Entries shown as `<lambda>` use a size drawn at random by the benchmark config (see the updated `pt/cat_test.py` below).

Pull Request resolved: https://github.com/pytorch/pytorch/pull/33237

Differential Revision: D20069181

Pulled By: ngimel

fbshipit-source-id: b392e1ffd72c0d8df0c5a2d3ac96f59b37c84e32
Committed by: Facebook Github Bot
Parent: 97da60d511
Commit: b10a39bb32
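For orientation, a minimal sketch of the operation these benchmarks exercise (not part of the PR; it assumes a CUDA-capable build, and the shape matches one of the benchmark configurations above):

```python
import torch

# Concatenate two contiguous CUDA tensors along dim 1, mirroring the
# "(512, 512, 2), N: 2, dim: 1" benchmark entry.
if torch.cuda.is_available():
    inputs = [torch.rand(512, 512, 2, device="cuda") for _ in range(2)]
    out = torch.cat(inputs, dim=1)  # goes through the CUDA cat path benchmarked here
    assert out.shape == (512, 1024, 2)
```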
@@ -1190,19 +1190,3 @@
    - THTensor* self
    - THTensor* src
]]

[[
  name: _th_cat
  cname: catArray
  variants: [function]
  backends:
    - CUDA
  cuda_bool: True
  cuda_bfloat16: True
  return: self
  arguments:
    - arg: THTensor* self
      output: True
    - TensorList tensors
    - int64_t dim
]]

@@ -74,6 +74,8 @@ bool canUse32BitIndexMath(const Tensor& t, int64_t max_elem) {
  int64_t elements = t.numel();
  if (elements >= max_elem) {
    return false;
  } else if (elements == 0) {
    return true;
  }

  int64_t offset = 0;

aten/src/ATen/native/cuda/Shape.cu (new file, 360 lines)
@@ -0,0 +1,360 @@
#include <ATen/ATen.h>
#include <ATen/cuda/CUDAContext.h>
#include <ATen/DeviceGuard.h>
#include <ATen/Context.h>
#include <ATen/Utils.h>
#include <ATen/MemoryOverlap.h>
#include <ATen/cuda/detail/IndexUtils.cuh>
#include <ATen/Dispatch.h>

#include <c10/core/TensorImpl.h>
#include <c10/core/ScalarType.h>
#include <c10/core/UndefinedTensorImpl.h>
#include <c10/cuda/CUDACachingAllocator.h>
#include <c10/util/typeid.h>
#include <c10/util/intrusive_ptr.h>

#include <THC/THC.h>

namespace at {
namespace native {

constexpr int CAT_ARRAY_BATCH_SIZE = 1024;
constexpr int CAT_ARRAY_MAX_INPUT_DIMS = 4;

namespace {

inline bool getCatGrid(ptrdiff_t nTensors, dim3& grid) {
  const int numSM = at::cuda::getCurrentDeviceProperties()->multiProcessorCount;

  // X dim of grid for cat array cooperates on a single tensor in the cat.
  // Given half of the GPU, full utilization will always occur.
  grid = dim3( 2LL * numSM, (long long) nTensors );

  return true;
}

// Similar to any other IndexToOffset calculation for copying along a given
// dimension.
template <typename IndexType, int Dims>
struct CatArrIndexToOffset {
  static inline __device__ IndexType compute(
      const IndexType outputSize[Dims],
      const IndexType outputStride[Dims],
      const IndexType dimSize,
      const unsigned int concatDim,
      IndexType linearIndex) {
    IndexType offset = 0;

#pragma unroll
    for (int i = Dims - 1; i >= 1; --i) {
      IndexType curDimSize = i == concatDim ? dimSize : outputSize[i];
      IndexType nextDimIndex = linearIndex / curDimSize;
      IndexType curDimIndex = linearIndex - curDimSize * nextDimIndex;
      IndexType curDimOffset = curDimIndex * outputStride[i];
      offset += curDimOffset;
      linearIndex = nextDimIndex;
    }

    return offset + linearIndex * outputStride[0];
  }
};

template <typename T, typename IndexType>
struct CatArrInputTensor {
  T* input;
  IndexType offset;
  IndexType dimSize;
  IndexType nElements;
};

template<typename IndexType, unsigned int MaxDims>
struct OutputTensorSizeStride {
  IndexType outputSize[MaxDims];
  IndexType outputStride[MaxDims];
};

/**
 * Kernel used to concatenate gridDim.y tensors into an output tensor. Uses a
 * grid-stride loop based off of the blockIdx.x, threadIdx.x for each input to
 * copy each element from each input tensor into the output.
 *
 * output: base pointer to the storage associated with the output tensor
 * inputs: GPU-allocated array of input metadata for each input to concatenate
 *         in the kernel
 * os: the size/stride vectors for the output tensor
 * concatDim: dimension along which we are concatenating
 * dimStride: the stride of the output tensor at the concatDim
 *
 * The most important assumption made is that the input tensors are contiguous.
 */
template <typename T, typename IndexType, int Dims>
#ifdef __HIP_PLATFORM_HCC__
C10_LAUNCH_BOUNDS_1(512)
#endif
__global__ void CatArrayBatchedCopy(
    T* output,
    CatArrInputTensor<T, IndexType>* inputs,
    OutputTensorSizeStride<IndexType, CAT_ARRAY_MAX_INPUT_DIMS> os,
    const int concatDim,
    IndexType dimStride) {

  IndexType tid = blockIdx.x * blockDim.x + threadIdx.x;
  IndexType nElements = inputs[blockIdx.y].nElements;

  if (tid >= nElements) return;

  T* data = inputs[blockIdx.y].input;
  IndexType offset = inputs[blockIdx.y].offset;
  IndexType dimSize = inputs[blockIdx.y].dimSize;
  IndexType dataOffset = offset * dimStride;

  IndexType stride = gridDim.x * blockDim.x;

  while (tid < nElements) {
    IndexType elementOffset = CatArrIndexToOffset<IndexType, Dims>::compute(
        os.outputSize, os.outputStride, dimSize, concatDim, tid);
    output[dataOffset + elementOffset] = data[tid];

    tid += stride;
  }
}

void check_shape_except_dim(const Tensor &first, const Tensor &second,
                            int dimension)
{
  int first_dims = first.dim();
  int second_dims = second.dim();
  TORCH_CHECK(first_dims == second_dims,
      "Tensors must have same number of dimensions: got ", first_dims,
      " and ", second_dims);
  for (int dim = 0; dim < first_dims; dim++) {
    if (dim == dimension) {
      continue;
    }
    int64_t first_dim_size = at::native::size(first, dim);
    int64_t second_dim_size = at::native::size(second, dim);
    TORCH_CHECK(first_dim_size == second_dim_size,
        "Sizes of tensors must match except in dimension ", dim, ". Got ",
        static_cast<long long>(first_dim_size), " and ",
        static_cast<long long>(second_dim_size));
  }
}

template <typename scalar_t>
void parallel_cat(Tensor &out, const TensorList &inputs, int64_t dimension,
                  int nDims) {
  // First, let's set up our kernel parameters. We start with a raw pointer to
  // the storage for the output Tensor.
  scalar_t *data = out.data_ptr<scalar_t>();

  // Kernel Parameter
  long tensorMetadataSize =
    sizeof(CatArrInputTensor<scalar_t, unsigned int>) * CAT_ARRAY_BATCH_SIZE;
  auto d_inputs_storage = at::empty(
    {tensorMetadataSize}, out.options().dtype(at::kByte));
  auto d_inputs = static_cast<CatArrInputTensor<scalar_t, unsigned int> *>(
    d_inputs_storage.data_ptr());

  OutputTensorSizeStride<unsigned int, CAT_ARRAY_MAX_INPUT_DIMS> param;

  // Next, let's initialize the size, stride arrays for the output Tensor.
  for (int i = 0; i < nDims; ++i) {
    param.outputSize[i] = at::native::size(out, i);
    param.outputStride[i] = out.stride(i);
  }

  at::cuda::CUDAStream stream = at::cuda::getCurrentCUDAStream();

  // Now we loop
  int batchCounter = 0;
  int64_t offset = 0;
  for (int i = 0; i < inputs.size() ; i += CAT_ARRAY_BATCH_SIZE) {
    // Re-allocate stackInputs every iteration to avoid read-after-write hazard
    {
      auto stackInputs_storage = at::empty({tensorMetadataSize},
          out.options().dtype(at::kByte).device(at::kCPU).pinned_memory(true));
      auto stackInputs =
        static_cast<CatArrInputTensor<scalar_t, unsigned int> *>(
          stackInputs_storage.data_ptr());
      for (batchCounter = 0;
           batchCounter < CAT_ARRAY_BATCH_SIZE &&
             (i+batchCounter) < inputs.size();
           ++batchCounter) {
        int64_t dimSize = at::native::size(inputs[i+batchCounter], dimension);

        stackInputs[batchCounter].input =
          inputs[i+batchCounter].data_ptr<scalar_t>();
        stackInputs[batchCounter].offset = offset;
        stackInputs[batchCounter].dimSize = dimSize;
        stackInputs[batchCounter].nElements = inputs[i+batchCounter].numel();

        // update offset
        offset += dimSize;
      }
      at::native::copy_(d_inputs_storage, stackInputs_storage,
                        /* non_blocking= */ true);
    }

    // Next, let's consider how we set our kernel launch parameters.
    // We borrow from THCApply, which the kernel's internal indexing
    // is based on.
    dim3 applyBlock = dim3(32*16);

    // Get grid where the x dim fills half the GPU and the y dim is the number
    // of tensors. This makes concatenating two tensors fill the entire grid,
    // but prevents many threads from needlessly loading metadata if their
    // sizes are small.
    dim3 catGrid;
    getCatGrid(batchCounter, catGrid);


    // Template Declarations for dim = 1, 2, 3, 4
#define HANDLE_CASE(DIMS) \
    CatArrayBatchedCopy<scalar_t, unsigned int, DIMS><<<\
        catGrid, applyBlock, 0, stream.stream()>>>(\
            data, d_inputs, param, dimension, param.outputStride[dimension]);
    switch (nDims) {
      case 1:
        HANDLE_CASE(1);
        break;
      case 2:
        HANDLE_CASE(2);
        break;
      case 3:
        HANDLE_CASE(3);
        break;
      case 4:
        HANDLE_CASE(4);
        break;
    }
#undef HANDLE_CASE
    THCudaCheck(cudaGetLastError());
  }
}

} // namespace

Tensor cat_cuda(TensorList inputs, int64_t dimension) {
  Tensor out = at::empty({0}, inputs.front().options());
  cat_out_cuda(out, inputs, dimension);
  return out;
}

Tensor& cat_out_cuda(Tensor& out, TensorList inputs, int64_t dimension) {

  // previously, size [0] tensors were the only possible empty tensors; thus, it
  // wasn't possible to cat empty tensors unless all the other tensors were
  // 1-dimensional, so we allowed these tensors to be "skipped". We maintain
  // this behavior for backwards compatibility, but only for this specific size
  // (i.e. other empty sizes are not skipped).
  // FIXME: warn if this is the case
  auto should_skip = [](const Tensor &t) {
    return t.dim() == 1 && at::native::size(t, 0) == 0;
  };
  bool hasSkippedInput = false;

  const Tensor *notSkippedTensor = NULL;  // non-owning reference
  int nDims = 0;

  // Inputs cannot alias the output tensor
  for (int i = 0; i < inputs.size(); i++) {
    auto lap = at::get_overlap_status(out, inputs[i]);
    TORCH_CHECK(lap != at::MemOverlapStatus::PARTIAL &&
        lap != at::MemOverlapStatus::FULL,
        "unsupported operation: the input tensors cannot refer to any "
        "of the output memory locations. Found overlap in input "
        "tensor ", i);
  }

  for (int i = 0; i < inputs.size(); i++)
  {
    if (should_skip(inputs[i])) {
      hasSkippedInput = true;
      continue;
    }
    nDims = inputs[i].dim();
    notSkippedTensor = &inputs[i];
  }

  // If all inputs are empty tensors, return an empty tensor
  if (notSkippedTensor == NULL) {
    return out;
  }

  TORCH_CHECK(inputs.size() > 0, "invalid number of inputs ", inputs.size());
  TORCH_CHECK(dimension >= 0, "invalid dimension ", dimension);

  std::vector<int64_t> size(notSkippedTensor->sizes().vec());

  // Compute size of the result in the cat dimension
  int64_t cat_dim_size = 0;
  for (int i = 0; i < inputs.size(); i++) {
    const Tensor &tensor = inputs[i];
    if (should_skip(tensor)) {
      continue;
    }
    check_shape_except_dim(*notSkippedTensor, tensor, dimension);
    cat_dim_size += at::native::size(tensor, dimension);
  }

  // Compute the size of the result
  size[dimension] = cat_dim_size;
  out.resize_(size);
  if (out.numel() == 0) {
    return out;
  }

  // We parallelize the copy if all 7 conditions pass:
  //
  // 1. There is more than one input tensor
  // 2. No empty inputs
  // 3. The out tensor is 32-bit indexable
  // 4. The number of dimensions is <= 4
  // 5. All input tensors are contiguous (output tensor may be non-contig)
  // 6. All input tensors can use 32-bit indexing
  // 7. All input tensors are on the same device

  const bool all32BitIndexable = std::all_of(inputs.begin(), inputs.end(),
    [] (const Tensor& t) {
      return at::cuda::detail::canUse32BitIndexMath(t);
    });
  Device firstDevice = notSkippedTensor->device();
  const bool allSameDevice = std::all_of(inputs.begin(), inputs.end(),
    [firstDevice](const Tensor& t) {
      return t.device() == firstDevice;
    });
  const bool allContiguous = std::all_of(inputs.begin(), inputs.end(),
    [](const Tensor& t) {
      return !t.defined() || t.is_contiguous();
    });
  if (inputs.size() > 1 &&
      !hasSkippedInput &&
      out.dim() <= CAT_ARRAY_MAX_INPUT_DIMS &&
      at::cuda::detail::canUse32BitIndexMath(out) &&
      allContiguous &&
      all32BitIndexable &&
      allSameDevice) {

    AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND3(
        at::ScalarType::Half, at::ScalarType::Bool, at::ScalarType::BFloat16,
        out.scalar_type(), "cat_cuda", [&]() {
          parallel_cat<scalar_t>(out, inputs, dimension, nDims);
        });

  } else {
    int64_t offset = 0;
    for (int j = 0; j < inputs.size(); j++)
    {
      if (should_skip(inputs[j])) continue;
      int64_t dimSize = at::native::size(inputs[j], dimension);
      Tensor nt = at::narrow(out, dimension, offset, dimSize);
      copy_(nt, inputs[j]);
      offset += dimSize;
    }
  }

  return out;
}

} // namespace native
} // namespace at
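The non-parallel fallback in `cat_out_cuda` above is a narrow-and-copy loop. A rough Python rendering of that branch, shown only to illustrate the logic (it ignores the size-[0] skip behavior and the fast-path conditions; names are illustrative, not part of the PR):

```python
import torch

def cat_fallback(inputs, dim):
    # Resize the output along `dim`, then narrow + copy_ each input into its
    # slice, just like the else-branch of cat_out_cuda.
    out_size = list(inputs[0].size())
    out_size[dim] = sum(t.size(dim) for t in inputs)
    out = torch.empty(out_size, dtype=inputs[0].dtype, device=inputs[0].device)
    offset = 0
    for t in inputs:
        out.narrow(dim, offset, t.size(dim)).copy_(t)
        offset += t.size(dim)
    return out

xs = [torch.rand(3, 4), torch.rand(3, 2)]
assert torch.equal(cat_fallback(xs, dim=1), torch.cat(xs, dim=1))
```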
@@ -5290,13 +5290,13 @@
 - func: _cat(Tensor[] tensors, int dim=0) -> Tensor
   dispatch:
     CPU: _cat_cpu
-    CUDA: legacy::cuda::_th_cat
+    CUDA: cat_cuda
     QuantizedCPU: quantized_cat

 - func: _cat.out(Tensor[] tensors, int dim=0, *, Tensor(a!) out) -> Tensor(a!)
   dispatch:
     CPU: _cat_out_cpu
-    CUDA: legacy::cuda::_th_cat_out
+    CUDA: cat_out_cuda
     QuantizedCPU: quantized_cat_out

 - func: _mode(Tensor self, int dim=-1, bool keepdim=False) -> (Tensor, Tensor)
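With the dispatch entries above, both the functional and the out= variants of `_cat` route to the new ATen kernels on CUDA. A quick sanity-check sketch (assuming `torch.cat` still forwards to `_cat`, as it did when this PR landed; requires a CUDA build):

```python
import torch

if torch.cuda.is_available():
    xs = [torch.rand(128, 1024, 2, device="cuda") for _ in range(2)]
    out = torch.empty(0, device="cuda")
    torch.cat(xs, dim=1, out=out)                   # out= variant (cat_out_cuda)
    assert torch.equal(out, torch.cat(xs, dim=1))   # functional variant (cat_cuda)
```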
@@ -31,111 +31,4 @@ __global__ void THCTensor_copyToDiagonal(T* a, T* b, ptrdiff_t start, ptrdiff_t
  }
}

#define CAT_ARRAY_BATCH_SIZE 1024
#define CAT_ARRAY_MAX_INPUT_DIMS 4

inline bool getCatGrid(THCState* state, ptrdiff_t nTensors, dim3& grid) {
  int curDevice = -1;
  cudaGetDevice(&curDevice);

  if (curDevice == -1) {
    return false;
  }

  // Assume a reasonable number of SMs if no state is available
  int numSM =
    state ? at::cuda::getCurrentDeviceProperties()->multiProcessorCount : 15;
  //X dim of grid for cat array cooperates on a single tensor in the cat.
  //Given half of the GPU, full utilization will always occur.
  grid = dim3( 2LL * numSM, (long long) nTensors );

  return true;
}

// Similar to any other IndexToOffset calculation for copying along a given dimension.
template <typename IndexType, int Dims>
struct CatArrIndexToOffset {
  static inline __device__ IndexType compute(
      const IndexType outputSize[Dims],
      const IndexType outputStride[Dims],
      const IndexType dimSize,
      const unsigned int concatDim,
      IndexType linearIndex) {
    IndexType offset = 0;

#pragma unroll
    for (int i = Dims - 1; i >= 1; --i) {
      IndexType curDimSize = i == concatDim ? dimSize : outputSize[i];
      IndexType nextDimIndex = linearIndex / curDimSize;
      IndexType curDimIndex = linearIndex - curDimSize * nextDimIndex;
      IndexType curDimOffset = curDimIndex * outputStride[i];
      offset += curDimOffset;
      linearIndex = nextDimIndex;
    }

    return offset + linearIndex * outputStride[0];
  }
};

template <typename T, typename IndexType>
struct CatArrInputTensor {
  T* input;
  IndexType offset;
  IndexType dimSize;
  IndexType nElements;
};

template<typename IndexType, unsigned int MaxDims>
struct OutputTensorSizeStride {
  IndexType outputSize[MaxDims];
  IndexType outputStride[MaxDims];
};

/**
 * Kernel used to concatenated grimDim.y tensors into an output tensor. Uses a grid-stride loop based off of
 * the blockIdx.x, threadIdx.x for each input to copy each element from each input tensor into the output.
 *
 * output: base pointer to the storage associated with the output tensor
 * inputs: GPU-allocated array of input metadata for each input to concatenate in the kernel
 * os: the size/stride vectors for the output tensor
 * concatDim: dimension along which we are concatenating
 * dimStride: the stride of the output tensor at the concatDim
 *
 * The most important assumption made is that the input tensors are contiguous.
 */



template <typename T, typename IndexType, int Dims>
#ifdef __HIP_PLATFORM_HCC__
C10_LAUNCH_BOUNDS_1(512)
#endif
__global__ void CatArrayBatchedCopy(
    T* output,
    CatArrInputTensor<T, IndexType>* inputs,
    OutputTensorSizeStride<IndexType, CAT_ARRAY_MAX_INPUT_DIMS> os,
    const int concatDim,
    IndexType dimStride) {

  IndexType tid = blockIdx.x * blockDim.x + threadIdx.x;
  IndexType nElements = inputs[blockIdx.y].nElements;

  if(tid >= nElements) return;

  T* data = inputs[blockIdx.y].input;
  IndexType offset = inputs[blockIdx.y].offset;
  IndexType dimSize = inputs[blockIdx.y].dimSize;
  IndexType dataOffset = offset * dimStride;

  IndexType stride = gridDim.x * blockDim.x;

  while( tid < nElements){
    IndexType elementOffset = CatArrIndexToOffset<IndexType, Dims>::compute(
        os.outputSize, os.outputStride, dimSize, concatDim, tid);
    output[dataOffset + elementOffset] = data[tid];

    tid += stride;
  }
}

#endif

@@ -42,15 +42,6 @@ THCTensor_(numel)(THCState *state, THCTensor *t)
  return THCTensor_(nElement)(state, t);
}

void THCTensor_(cat)(THCState *state, THCTensor *result,
                     THCTensor *ta, THCTensor *tb, int dimension)
{
  THCTensor* inputs[2];
  inputs[0] = ta;
  inputs[1] = tb;
  THCTensor_(catArray)(state, result, inputs, 2, dimension);
}

void THCTensor_(check_shape_except_dim)(THCState *state,
                                        THCTensor *first, THCTensor *second, int dimension);
inline void THCTensor_(check_shape_except_dim)(THCState *state,
@@ -73,185 +64,6 @@ inline void THCTensor_(check_shape_except_dim)(THCState *state,
  }
}

void THCTensor_(catArray)(THCState *state, THCTensor *result,
                          THCTensor **inputs, int numInputs, int dimension)
{
  // previously, size [0] tensors were the only possible empty tensors; thus, it wasn't possible
  // to cat empty tensors unless all the other tensors were 1-dimensional, so we allowed these tensors
  // to be "skipped". We maintain this behavior for backwards compatibility, but only for this specific
  // size (i.e. other empty sizes are not skipped).
  // FIXME: warn if this is the case
  int i, j, cohortMax;
  int64_t offset;
  bool hasSkippedInput = false;
  THCTensor *notSkippedTensor = NULL;  // non-owning reference
  auto should_skip = [](THCTensor *t) { return t->is_empty() && t->dim() == 1; };
  int nDims = 0;

  // Inputs cannot alias the output tensor
  for (int i = 0; i < numInputs; i++) {
    auto lap = at::get_overlap_status(result, inputs[i]);
    THArgCheck(lap != at::MemOverlapStatus::PARTIAL &&
        lap != at::MemOverlapStatus::FULL, 0,
        "unsupported operation: the input tensors cannot refer to any of the "
        "output memory locations. Found overlap in input tensor %d.", i);
  }

  for (i = 0; i < numInputs; i++)
  {
    if (should_skip(inputs[i])) {
      hasSkippedInput = true;
      continue;
    }
    nDims = inputs[i]->dim();
    notSkippedTensor = inputs[i];
  }

  // If all inputs are empty tensors, return an empty tensor
  if (notSkippedTensor == NULL) {
    return;
  }

  THArgCheck(numInputs > 0, 3, "invalid number of inputs %d", numInputs);
  THArgCheck(dimension >= 0, 4, "invalid dimension %d", dimension);

  std::vector<int64_t> size(nDims);

  // Compute size of the result in the cat dimension
  int64_t cat_dim_size = 0;
  for (int i = 0; i < numInputs; i++) {
    THCTensor *tensor = inputs[i];
    if (should_skip(tensor)) {
      continue;
    }
    THCTensor_(check_shape_except_dim)(state, notSkippedTensor, tensor, dimension);
    cat_dim_size += THCTensor_(size)(state, tensor, dimension);
  }

  // Compute the size of the result
  for (int dim = 0; dim < nDims; dim++) {
    int64_t result_dim_size = THCTensor_(size)(state, notSkippedTensor, dim);
    if (dim == dimension) {
      result_dim_size = cat_dim_size;
    }
    size[dim] = result_dim_size;
  }
  THCTensor_(resize)(state, result, size, {});

  // We parallelize the copy if all 6 conditions pass:
  //
  // 1. There is more than one input tensor
  // 2. No empty inputs
  // 3. The result tensor is 32-bit indexable
  // 4. The number of dimensions is <= 4
  // 5. All input tensors are contiguous (output tensor may be non-contig)
  // 6. All input tensors can use 32-bit indexing
  // 7. All input tensors are on the same device

  if (numInputs > 1 &&
      !hasSkippedInput &&
      result->dim() <= CAT_ARRAY_MAX_INPUT_DIMS &&
      THCTensor_canUse32BitIndexMath(state, result) &&
      THCTensor_allContiguous(state, inputs, numInputs) &&
      THCTensor_all32BitIndexable(state, inputs, numInputs) &&
      THCTensor_allSameDevice(state, inputs, numInputs)) {

    // First, let's set up our kernel parameters. We start with a raw pointer to the storage
    // for the output Tensor.
    scalar_t *data = THCTensor_(data)(state, result);

    // Kernel Parameter
    size_t tensorMetadataSize = sizeof(CatArrInputTensor<scalar_t, unsigned int>) * CAT_ARRAY_BATCH_SIZE;
    auto d_inputs = static_cast<CatArrInputTensor<scalar_t, unsigned int> *>(THCudaMalloc(state, tensorMetadataSize));

    OutputTensorSizeStride<unsigned int, CAT_ARRAY_MAX_INPUT_DIMS> param;

    // Next, let's initialize the size, stride arrays for the output Tensor.
    for (i = 0; i < nDims; ++i) {
      param.outputSize[i] = THCTensor_(size)(state, result, i);
      param.outputStride[i] = THCTensor_(stride)(state, result, i);
    }

    at::cuda::CUDAStream stream = at::cuda::getCurrentCUDAStream();

    // Template Declarations for dim = 1, 2, 3, 4
#define HANDLE_CASE(DIMS) \
    CatArrayBatchedCopy<scalar_t, unsigned int, DIMS><<<catGrid, applyBlock, 0, stream.stream()>>>(data, d_inputs, param, dimension, param.outputStride[dimension]);

    // Now we loop
    offset = 0;
    for (i = 0; i < numInputs; i += CAT_ARRAY_BATCH_SIZE) {
      // Re-allocate stackInputs every iteration to avoid read-after-write hazard
      {
        auto stackInputs_owner = THCudaHostAlloc(state, tensorMetadataSize);
        CatArrInputTensor<scalar_t, unsigned int>* stackInputs = static_cast<CatArrInputTensor<scalar_t, unsigned int>*>(stackInputs_owner.get());
        cohortMax = 0;
        for (j = 0; j < CAT_ARRAY_BATCH_SIZE && (i+j) < numInputs; ++j) {
          int64_t dimSize = THCTensor_(size)(state, inputs[i+j], dimension);

          stackInputs[j].input = THCTensor_(data)(state, inputs[i+j]);
          stackInputs[j].offset = offset;
          stackInputs[j].dimSize = dimSize;
          stackInputs[j].nElements = THCTensor_(nElement)(state, inputs[i+j]);
          cohortMax = cohortMax > (int) stackInputs[j].nElements ? cohortMax : (int) stackInputs[j].nElements;

          // update offset
          offset += dimSize;
        }
        THCudaCheck(cudaMemcpyAsync(
            d_inputs,
            stackInputs,
            j * sizeof(CatArrInputTensor<scalar_t, unsigned int>),
            cudaMemcpyHostToDevice,
            stream.stream()));
        THCudaHostRecord(state, stackInputs);
      }

      // Next, let's consider how we set our kernel launch parameters.
      // We borrow from THCApply, which the kernel's internal indexing
      // is based on.
      dim3 applyBlock = getApplyBlock();

      //Get grid where x dim fills half gpu and y dim is number of tensors.
      //This will have cating two tensors fill the entire grid, but prevent
      //many threads from needlessly load meta data if their sizes is small.
      dim3 catGrid;
      getCatGrid(state, j, catGrid);


      switch (nDims) {
        case 1:
          HANDLE_CASE(1);
          break;
        case 2:
          HANDLE_CASE(2);
          break;
        case 3:
          HANDLE_CASE(3);
          break;
        case 4:
          HANDLE_CASE(4);
          break;
      }
      THCudaCheck(cudaGetLastError());
    }
    THCudaFree(state, d_inputs);
#undef HANDLE_CASE
  } else {
    offset = 0;
    for (j = 0; j < numInputs; j++)
    {
      if (should_skip(inputs[j])) continue;
      int64_t dimSize = THCTensor_(size)(state, inputs[j], dimension);
      THCTensor *nt = THCTensor_(newWithTensor)(state, result);
      THCTensor_(narrow)(state, nt, NULL, dimension, offset, dimSize);
      THCTensor_(copy)(state, nt, inputs[j]);
      THCTensor_(free)(state, nt);
      offset += dimSize;
    }
  }
}

void THCTensor_(nonzero)(THCState* state, THCudaLongTensor *tensor,
                         THCTensor *self)
{

@@ -4,8 +4,6 @@

THC_API void THCTensor_(fill)(THCState *state, THCTensor *self, scalar_t value);
THC_API void THCTensor_(zero)(THCState *state, THCTensor *self);
THC_API void THCTensor_(cat)(THCState *state, THCTensor *result, THCTensor *ta, THCTensor *tb, int dimension);
THC_API void THCTensor_(catArray)(THCState *state, THCTensor *result, THCTensor **inputs, int numInputs, int dimension);
THC_API void THCTensor_(nonzero)(THCState* state, THCudaLongTensor *tensor, THCTensor *self);
THC_API ptrdiff_t THCTensor_(numel)(THCState *state, THCTensor *t);

@@ -5,6 +5,7 @@ from __future__ import unicode_literals

 import operator_benchmark as op_bench
 import torch
+import random


 """Microbenchmarks for Cat operator"""
@@ -12,11 +13,11 @@ import torch

 # Configs for PT Cat operator
 cat_configs_short = op_bench.config_list(
-    attr_names=['M', 'N', 'K', 'dim'],
+    attr_names=['sizes', 'N', 'dim'],
     attrs=[
-        [1, 1, 1, 0],
-        [256, 512, 1, 0],
-        [512, 512, 2, 1],
+        [(1, 1, 1), 2, 0], # noqa
+        [(512, 512, 2), 2, 1], # noqa
+        [(128, 1024, 2), 2, 1], # noqa
     ],
     cross_product_configs={
         'device': ['cpu', 'cuda'],
@@ -24,29 +25,81 @@ cat_configs_short = op_bench.config_list(
     tags=['short'],
 )

-cat_configs_long = op_bench.cross_product_configs(
-    M=[128],
-    N=[128, 1024],
-    K=[1, 2],
-    dim=[0, 1, 2],
-    device=['cpu', 'cuda'],
-    tags=['long']
+cat_configs_long = op_bench.config_list(
+    attr_names=['sizes', 'N', 'dim'],
+    attrs=[
+        [(2**10, 2**10, 2), 2, 0], # noqa
+        [(2**10+1, 2**10-1, 2), 2, 1], # noqa
+        [(2**10, 2**10, 2), 2, 2], # noqa
+
+        [[ lambda: random.randint(2**6, 2**7), 2**7-17, 2**6+1], # noqa
+         5, 0],
+        [[ 2**6+2**5, lambda: random.randint(2**6, 2**7), 2**6], # noqa
+         5, 1],
+        [[ 2**7, 2**6, lambda: random.randint(2**6, 2**7)], # noqa
+         5, 2],
+
+        [[lambda: random.randint(2**5, 2**6), 2**5, 2**6], # noqa
+         50, 0],
+        [[2**5, lambda: random.randint(2**5, 2**6), 2**6], # noqa
+         50, 1],
+        [[2**5+1, 2**6+1, lambda: random.randint(2**5, 2**6)], # noqa
+         50, 2],
+    ],
+    cross_product_configs={
+        'device': ['cpu', 'cuda'],
+    },
+    tags=['long'],
 )

+# There is a different codepath on CUDA for >4 dimensions
+cat_configs_multidim = op_bench.config_list(
+    attr_names=['sizes', 'N', 'dim'],
+    attrs=[
+        [(2**6, 2**5, 2**2, 2**4, 2**5), 2, 2], # noqa
+        [(2**4, 2**5, 2**2, 2**4, 2**5), 8, 2], # noqa
+        [(2**3+1, 2**5-1, 2**2+1, 2**4-1, 2**5+1), 17, 4], # noqa
+    ],
+    cross_product_configs={
+        'device': ['cpu', 'cuda'],
+    },
+    tags=['multidim'],
+)
+
+cat_configs_manyinputs = op_bench.config_list(
+    attr_names=['sizes', 'N', 'dim'],
+    attrs=[
+        [[lambda: random.randint(1, 10000)], 100, 0],
+        [[lambda: random.randint(1, 1000)], 1000, 0],
+        [[lambda: random.randint(1, 500)], 2000, 0],
+        [[lambda: random.randint(1, 300)], 3000, 0],
+    ],
+    cross_product_configs={
+        'device': ['cpu', 'cuda'],
+    },
+    tags=['manyinputs'],
+)
+
 class CatBenchmark(op_bench.TorchBenchmarkBase):
-    def init(self, M, N, K, dim, device):
-        self.input_one = torch.rand(M, N, K, device=device)
+    def init(self, sizes, N, dim, device):
+        random.seed(42)
+        self.inputs = []
+        for i in range(N):
+            current_sizes = [old_size() if callable(old_size) else old_size
+                             for old_size in sizes]
+            self.inputs.append(torch.rand(current_sizes, device=device))
         self.dim = dim
         self.set_module_name('cat')

     def forward(self):
-        return torch.cat((self.input_one, self.input_one), dim=self.dim)
+        return torch.cat(self.inputs, dim=self.dim)


-op_bench.generate_pt_test(cat_configs_short + cat_configs_long,
+op_bench.generate_pt_test(cat_configs_short +
+                          cat_configs_long +
+                          cat_configs_multidim +
+                          cat_configs_manyinputs,
                           CatBenchmark)


 if __name__ == "__main__":
     op_bench.benchmark_runner.main()
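The new configs mix fixed dimensions with callables that draw a random size per input. A small standalone sketch of how `init` above resolves them into concrete shapes (illustrative only, outside the op_bench harness):

```python
import random
import torch

random.seed(42)
sizes = [lambda: random.randint(2**6, 2**7), 2**7 - 17, 2**6 + 1]  # one random dim, two fixed
N, dim = 5, 0

inputs = []
for _ in range(N):
    # Callables are re-evaluated per tensor, so the cat dimension differs across inputs.
    current_sizes = [s() if callable(s) else s for s in sizes]
    inputs.append(torch.rand(current_sizes))

result = torch.cat(inputs, dim=dim)
print(result.shape)  # dim 0 is the sum of the five random sizes; dims 1 and 2 are fixed
```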