[torch] Format repeat_interleave op files (#58313)

Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/58313

Same as title.

I am planning to send a follow-up diff for this op, so I am sending the formatting diff ahead to keep that PR simple.

Test Plan: Rely on existing signals, since this is a simple formatting diff.

Reviewed By: ngimel

Differential Revision: D28447685

fbshipit-source-id: c7cd473b61e40e6f50178aca88b9af197a759099
Author: Serhat Yilmaz
Date: 2021-05-17 13:50:56 -07:00
Committed by: Facebook GitHub Bot
Parent: 06c1094ea0
Commit: d645088f2f
3 changed files with 71 additions and 40 deletions

aten/src/ATen/native/Repeat.cpp

@@ -1,9 +1,13 @@
 #include <ATen/ATen.h>
-#include <ATen/native/Repeat.h>
 #include <ATen/Parallel.h>
+#include <ATen/native/Repeat.h>
 
 template <typename index_t>
-static void compute_cpu(index_t *repeat_ptr, int64_t *cumsum_ptr, index_t *result_ptr, int64_t size) {
+static void compute_cpu(
+    index_t* repeat_ptr,
+    int64_t* cumsum_ptr,
+    index_t* result_ptr,
+    int64_t size) {
   at::parallel_for(0, size, 1, [&](int64_t i_begin, int64_t i_end) {
     for (int64_t i = i_begin; i < i_end; i++) {
       int64_t end = cumsum_ptr[i];
@@ -16,9 +20,10 @@ static void compute_cpu(index_t *repeat_ptr, int64_t *cumsum_ptr, index_t *result_ptr, int64_t size) {
   });
 }
 
-namespace at { namespace native {
+namespace at {
+namespace native {
 
-Tensor repeat_interleave_cpu(const Tensor &repeat) {
+Tensor repeat_interleave_cpu(const Tensor& repeat) {
   Tensor output;
   AT_DISPATCH_INDEX_TYPES(repeat.scalar_type(), "repeat_interleave_cpu", [&]() {
     output = repeat_interleave_common<index_t, compute_cpu<index_t>>(repeat);
@@ -27,9 +32,12 @@ Tensor repeat_interleave_cpu(const Tensor &repeat) {
   return output;
 }
 
-Tensor repeat_interleave(const Tensor &self, const Tensor &repeats, c10::optional<int64_t> dim) {
+Tensor repeat_interleave(
+    const Tensor& self,
+    const Tensor& repeats,
+    c10::optional<int64_t> dim) {
   Tensor input = self;
-  if(!dim) {
+  if (!dim) {
     input = self.flatten();
     dim = 0;
   }
@@ -48,8 +56,13 @@ Tensor repeat_interleave(const Tensor &self, const Tensor &repeats, c10::optional<int64_t> dim) {
   return input.index_select(dim.value(), at::repeat_interleave(repeats_));
 }
 
-Tensor repeat_interleave(const Tensor &self, int64_t repeats, c10::optional<int64_t> dim) {
-  return at::native::repeat_interleave(self, at::tensor({repeats}, self.options().dtype(kLong)), dim);
+Tensor repeat_interleave(
+    const Tensor& self,
+    int64_t repeats,
+    c10::optional<int64_t> dim) {
+  return at::native::repeat_interleave(
+      self, at::tensor({repeats}, self.options().dtype(kLong)), dim);
 }
 
-}}
+} // namespace native
+} // namespace at
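For reference, the two functions reformatted above implement the public repeat_interleave overloads: the tensor-repeats overload flattens the input when no dim is given and index-selects with the expanded indices, while the scalar overload wraps the repeat count in a one-element Long tensor. A minimal libtorch sketch of what the op computes; the tensor values and the standalone main() below are illustrative and not part of this commit:

// Illustrative sketch (assumes a working libtorch/ATen build; values are made up).
// at::repeat_interleave(self, repeats, dim) repeats self[i] repeats[i] times along dim.
#include <ATen/ATen.h>
#include <iostream>

int main() {
  at::Tensor self = at::tensor({10, 20, 30}, at::kLong);
  at::Tensor repeats = at::tensor({2, 0, 3}, at::kLong);
  at::Tensor out = at::repeat_interleave(self, repeats, /*dim=*/0);
  std::cout << out << std::endl; // values: 10, 10, 30, 30, 30
  return 0;
}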

aten/src/ATen/native/Repeat.h

@@ -2,15 +2,18 @@
 
 #include <ATen/ATen.h>
 
-namespace at { namespace native {
+namespace at {
+namespace native {
 
 template <typename index_t, void compute(index_t*, int64_t*, index_t*, int64_t)>
-static inline Tensor repeat_interleave_common(const Tensor &repeats) {
-  TORCH_CHECK(repeats.dim() == 1, "repeat_interleave only accept 1D vector as repeat");
+static inline Tensor repeat_interleave_common(const Tensor& repeats) {
+  TORCH_CHECK(
+      repeats.dim() == 1, "repeat_interleave only accept 1D vector as repeat");
   TORCH_CHECK(
       repeats.scalar_type() == at::kLong || repeats.scalar_type() == at::kInt,
       "repeats has to be Long or Int tensor");
-  TORCH_CHECK((repeats >= 0).all().item<uint8_t>(), "repeats can not be negative");
+  TORCH_CHECK(
+      (repeats >= 0).all().item<uint8_t>(), "repeats can not be negative");
   if (repeats.size(0) == 0) {
     return at::empty_like(repeats, LEGACY_CONTIGUOUS_MEMORY_FORMAT);
   }
@@ -25,4 +28,5 @@ static inline Tensor repeat_interleave_common(const Tensor &repeats) {
   return result;
 }
 
-}}
+} // namespace native
+} // namespace at
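The helper above is shared by both backends: it validates repeats and then hands the repeat counts, their cumulative sum, and the output buffer to the compute callback (compute_cpu or compute_cuda). A standalone plain-C++ sketch of that indexing scheme, with illustrative values and no ATen dependency:

// Plain-C++ sketch (no ATen) of the indexing scheme used by compute_cpu /
// compute_cuda_kernel: an inclusive prefix sum of the repeat counts gives each
// input index i the output range [cumsum[i] - repeats[i], cumsum[i]),
// which is filled with i. The values below are illustrative.
#include <cstdint>
#include <iostream>
#include <vector>

int main() {
  std::vector<int64_t> repeats = {2, 0, 3};
  std::vector<int64_t> cumsum(repeats.size());
  int64_t total = 0;
  for (size_t i = 0; i < repeats.size(); ++i) {
    total += repeats[i];
    cumsum[i] = total; // inclusive prefix sum
  }
  std::vector<int64_t> result(total);
  for (size_t i = 0; i < repeats.size(); ++i) {
    int64_t end = cumsum[i];
    int64_t start = end - repeats[i];
    for (int64_t j = start; j < end; ++j) {
      result[j] = static_cast<int64_t>(i); // same loop body as compute_cpu
    }
  }
  for (int64_t v : result) {
    std::cout << v << ' '; // prints: 0 0 2 2 2
  }
  std::cout << '\n';
  return 0;
}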

aten/src/ATen/native/cuda/Repeat.cu

@@ -3,39 +3,53 @@
 #include <ATen/native/Repeat.h>
 
 template <typename index_t>
-__global__ static void compute_cuda_kernel(index_t *repeat_ptr, int64_t *cumsum_ptr, index_t *result_ptr, int64_t size) {
-    int64_t idx = blockIdx.x * blockDim.x + threadIdx.x;
-    int64_t stride = (blockDim.x * gridDim.x) / C10_WARP_SIZE;
-    int warp_id = idx / C10_WARP_SIZE;
-    int tid_in_warp = idx % C10_WARP_SIZE;
-    for (int64_t i = warp_id; i < size; i += stride) {
-        int64_t end = cumsum_ptr[i];
-        index_t repeat = repeat_ptr[i];
-        int64_t start = end - repeat;
-        for(int64_t j = start + tid_in_warp; j < end; j += C10_WARP_SIZE) {
-            result_ptr[j] = i;
-        }
-    }
+__global__ static void compute_cuda_kernel(
+    index_t* repeat_ptr,
+    int64_t* cumsum_ptr,
+    index_t* result_ptr,
+    int64_t size) {
+  int64_t idx = blockIdx.x * blockDim.x + threadIdx.x;
+  int64_t stride = (blockDim.x * gridDim.x) / C10_WARP_SIZE;
+  int warp_id = idx / C10_WARP_SIZE;
+  int tid_in_warp = idx % C10_WARP_SIZE;
+  for (int64_t i = warp_id; i < size; i += stride) {
+    int64_t end = cumsum_ptr[i];
+    index_t repeat = repeat_ptr[i];
+    int64_t start = end - repeat;
+    for (int64_t j = start + tid_in_warp; j < end; j += C10_WARP_SIZE) {
+      result_ptr[j] = i;
+    }
+  }
 }
 
 template <typename index_t>
-static void compute_cuda(index_t *repeat_ptr, int64_t *cumsum_ptr, index_t *result_ptr, int64_t size) {
-    int64_t block = 512;
-    int64_t warps_per_block = block / C10_WARP_SIZE;
-    int64_t grid = std::min<int64_t>((size + warps_per_block - 1) / warps_per_block, 2048L);
+static void compute_cuda(
+    index_t* repeat_ptr,
+    int64_t* cumsum_ptr,
+    index_t* result_ptr,
+    int64_t size) {
+  int64_t block = 512;
+  int64_t warps_per_block = block / C10_WARP_SIZE;
+  int64_t grid =
+      std::min<int64_t>((size + warps_per_block - 1) / warps_per_block, 2048L);
 
-    compute_cuda_kernel<<<grid, block, 0, at::cuda::getCurrentCUDAStream()>>>(repeat_ptr, cumsum_ptr, result_ptr, size);
-    C10_CUDA_KERNEL_LAUNCH_CHECK();
+  compute_cuda_kernel<<<grid, block, 0, at::cuda::getCurrentCUDAStream()>>>(
+      repeat_ptr, cumsum_ptr, result_ptr, size);
+  C10_CUDA_KERNEL_LAUNCH_CHECK();
 }
 
-namespace at { namespace native {
+namespace at {
+namespace native {
 
-Tensor repeat_interleave_cuda(const Tensor &repeat) {
-    Tensor output;
-    AT_DISPATCH_INDEX_TYPES(repeat.scalar_type(), "repeat_interleave_cuda", [&]() {
-        output = repeat_interleave_common<index_t, compute_cuda<index_t>>(repeat);
-    });
-    return output;
+Tensor repeat_interleave_cuda(const Tensor& repeat) {
+  Tensor output;
+  AT_DISPATCH_INDEX_TYPES(
+      repeat.scalar_type(), "repeat_interleave_cuda", [&]() {
+        output =
+            repeat_interleave_common<index_t, compute_cuda<index_t>>(repeat);
+      });
+  return output;
 }
 
-}}
+} // namespace native
+} // namespace at
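In the CUDA path above, compute_cuda_kernel assigns one warp per input element (each lane writes a strided slice of that element's output range) and grid-strides when there are more elements than launched warps. A quick standalone check of the launch arithmetic in compute_cuda, assuming a 32-lane warp (the usual C10_WARP_SIZE on NVIDIA hardware) and an illustrative input size:

#include <algorithm>
#include <cstdint>
#include <iostream>

int main() {
  const int64_t warp_size = 32;       // assumed value of C10_WARP_SIZE
  const int64_t block = 512;
  const int64_t warps_per_block = block / warp_size; // 16 warps per block
  const int64_t size = 100000;        // illustrative number of input elements
  const int64_t grid =
      std::min<int64_t>((size + warps_per_block - 1) / warps_per_block, 2048L);
  // 100000 / 16 = 6250 blocks would give one warp per element, so the grid is
  // capped at 2048 and each warp grid-strides over several elements instead.
  std::cout << "grid=" << grid << " block=" << block << std::endl;
  return 0;
}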