Mirror of https://github.com/pytorch/pytorch.git (synced 2025-10-20 12:54:11 +08:00)
[torch] Format repeat_interleave op files (#58313)
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/58313
Same as the title. I am planning to send a follow-up diff to this op, so I am sending the formatting diff ahead to keep the PR simple.

Test Plan: Rely on existing signals, since this is a simple formatting diff.

Reviewed By: ngimel

Differential Revision: D28447685

fbshipit-source-id: c7cd473b61e40e6f50178aca88b9af197a759099
Committed by: Facebook GitHub Bot
Parent: 06c1094ea0
Commit: d645088f2f
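For context only (not part of this change): the commit reformats the native kernels behind at::repeat_interleave. Below is a minimal usage sketch, assuming a working libtorch/ATen build; it exercises the two overloads that appear in the diff that follows.

#include <ATen/ATen.h>
#include <iostream>

int main() {
  // Tensor-of-repeats overload: element i of x is repeated repeats[i] times.
  at::Tensor x = at::tensor({10, 20, 30});
  at::Tensor repeats = at::tensor({1, 0, 2});
  std::cout << at::repeat_interleave(x, repeats, /*dim=*/0) << "\n"; // 10, 30, 30

  // Scalar overload: every element is repeated the same number of times.
  std::cout << at::repeat_interleave(x, 2, /*dim=*/0) << "\n"; // 10, 10, 20, 20, 30, 30
  return 0;
}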
aten/src/ATen/native/Repeat.cpp:

@@ -1,9 +1,13 @@
 #include <ATen/ATen.h>
-#include <ATen/native/Repeat.h>
 #include <ATen/Parallel.h>
+#include <ATen/native/Repeat.h>

 template <typename index_t>
-static void compute_cpu(index_t *repeat_ptr, int64_t *cumsum_ptr, index_t *result_ptr, int64_t size) {
+static void compute_cpu(
+    index_t* repeat_ptr,
+    int64_t* cumsum_ptr,
+    index_t* result_ptr,
+    int64_t size) {
   at::parallel_for(0, size, 1, [&](int64_t i_begin, int64_t i_end) {
     for (int64_t i = i_begin; i < i_end; i++) {
       int64_t end = cumsum_ptr[i];
@@ -16,9 +20,10 @@ static void compute_cpu(index_t *repeat_ptr, int64_t *cumsum_ptr, index_t *resul
   });
 }

-namespace at { namespace native {
+namespace at {
+namespace native {

-Tensor repeat_interleave_cpu(const Tensor &repeat) {
+Tensor repeat_interleave_cpu(const Tensor& repeat) {
   Tensor output;
   AT_DISPATCH_INDEX_TYPES(repeat.scalar_type(), "repeat_interleave_cpu", [&]() {
     output = repeat_interleave_common<index_t, compute_cpu<index_t>>(repeat);
@@ -27,9 +32,12 @@ Tensor repeat_interleave_cpu(const Tensor &repeat) {
   return output;
 }

-Tensor repeat_interleave(const Tensor &self, const Tensor &repeats, c10::optional<int64_t> dim) {
+Tensor repeat_interleave(
+    const Tensor& self,
+    const Tensor& repeats,
+    c10::optional<int64_t> dim) {
   Tensor input = self;
-  if(!dim) {
+  if (!dim) {
     input = self.flatten();
     dim = 0;
   }
@@ -48,8 +56,13 @@ Tensor repeat_interleave(const Tensor &self, const Tensor &repeats, c10::optiona
   return input.index_select(dim.value(), at::repeat_interleave(repeats_));
 }

-Tensor repeat_interleave(const Tensor &self, int64_t repeats, c10::optional<int64_t> dim) {
-  return at::native::repeat_interleave(self, at::tensor({repeats}, self.options().dtype(kLong)), dim);
+Tensor repeat_interleave(
+    const Tensor& self,
+    int64_t repeats,
+    c10::optional<int64_t> dim) {
+  return at::native::repeat_interleave(
+      self, at::tensor({repeats}, self.options().dtype(kLong)), dim);
 }

-}}
+} // namespace native
+} // namespace at
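The reformatted compute_cpu keeps its behavior: for each entry i of repeats it writes the index i into the output slots [cumsum[i] - repeats[i], cumsum[i]), with at::parallel_for splitting the outer loop; the resulting index map is what the repeat_interleave overloads feed into index_select. A standalone, single-threaded sketch of that expansion (plain C++, hypothetical helper name, no ATen dependency):

// Simplified sketch of the cumsum-based expansion performed by compute_cpu;
// the real code runs the outer loop under at::parallel_for.
#include <cstdint>
#include <iostream>
#include <vector>

// For each i, write index i into result[cumsum[i] - repeat[i], cumsum[i]).
static void compute_cpu_sketch(const int64_t* repeat_ptr,
                               const int64_t* cumsum_ptr,
                               int64_t* result_ptr,
                               int64_t size) {
  for (int64_t i = 0; i < size; i++) {
    int64_t end = cumsum_ptr[i];
    int64_t start = end - repeat_ptr[i];
    for (int64_t j = start; j < end; j++) {
      result_ptr[j] = i;
    }
  }
}

int main() {
  std::vector<int64_t> repeat = {1, 0, 2};
  std::vector<int64_t> cumsum = {1, 1, 3}; // inclusive prefix sum of repeat
  std::vector<int64_t> result(3);
  compute_cpu_sketch(repeat.data(), cumsum.data(), result.data(), repeat.size());
  for (int64_t v : result) std::cout << v << ' '; // prints: 0 2 2
  std::cout << '\n';
  return 0;
}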
aten/src/ATen/native/Repeat.h:

@@ -2,15 +2,18 @@

 #include <ATen/ATen.h>

-namespace at { namespace native {
+namespace at {
+namespace native {

 template <typename index_t, void compute(index_t*, int64_t*, index_t*, int64_t)>
-static inline Tensor repeat_interleave_common(const Tensor &repeats) {
-  TORCH_CHECK(repeats.dim() == 1, "repeat_interleave only accept 1D vector as repeat");
+static inline Tensor repeat_interleave_common(const Tensor& repeats) {
+  TORCH_CHECK(
+      repeats.dim() == 1, "repeat_interleave only accept 1D vector as repeat");
   TORCH_CHECK(
       repeats.scalar_type() == at::kLong || repeats.scalar_type() == at::kInt,
       "repeats has to be Long or Int tensor");
-  TORCH_CHECK((repeats >= 0).all().item<uint8_t>(), "repeats can not be negative");
+  TORCH_CHECK(
+      (repeats >= 0).all().item<uint8_t>(), "repeats can not be negative");
   if (repeats.size(0) == 0) {
     return at::empty_like(repeats, LEGACY_CONTIGUOUS_MEMORY_FORMAT);
   }
@@ -25,4 +28,5 @@ static inline Tensor repeat_interleave_common(const Tensor &repeats) {
   return result;
 }

-}}
+} // namespace native
+} // namespace at
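Repeat.h holds the driver shared by both backends: repeat_interleave_common validates repeats and takes the backend fill routine (compute_cpu or compute_cuda) as a non-type template parameter. The middle of the function sits outside the hunk shown above, so the sketch below assumes the conventional cumsum-then-fill flow implied by the compute signature; names and container types are simplified stand-ins, not the ATen code.

// Sketch of the "fill routine as a template parameter" pattern used in Repeat.h.
#include <cstdint>
#include <iostream>
#include <numeric>
#include <vector>

using index_t = int64_t;

template <void compute(const index_t*, const int64_t*, index_t*, int64_t)>
std::vector<index_t> repeat_interleave_common_sketch(const std::vector<index_t>& repeats) {
  // Prefix-sum the repeat counts, size the output, then delegate the fill.
  std::vector<int64_t> cumsum(repeats.size());
  std::partial_sum(repeats.begin(), repeats.end(), cumsum.begin());
  std::vector<index_t> result(cumsum.empty() ? 0 : cumsum.back());
  compute(repeats.data(), cumsum.data(), result.data(), repeats.size());
  return result;
}

// Stand-in for the backend-specific routine (compute_cpu / compute_cuda).
void fill_cpu(const index_t* repeat_ptr, const int64_t* cumsum_ptr,
              index_t* result_ptr, int64_t size) {
  for (int64_t i = 0; i < size; i++) {
    for (int64_t j = cumsum_ptr[i] - repeat_ptr[i]; j < cumsum_ptr[i]; j++) {
      result_ptr[j] = i;
    }
  }
}

int main() {
  for (index_t v : repeat_interleave_common_sketch<fill_cpu>({2, 1})) {
    std::cout << v << ' '; // prints: 0 0 1
  }
  std::cout << '\n';
  return 0;
}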
aten/src/ATen/native/cuda/Repeat.cu:

@@ -3,39 +3,53 @@
 #include <ATen/native/Repeat.h>

 template <typename index_t>
-__global__ static void compute_cuda_kernel(index_t *repeat_ptr, int64_t *cumsum_ptr, index_t *result_ptr, int64_t size) {
-  int64_t idx = blockIdx.x * blockDim.x + threadIdx.x;
-  int64_t stride = (blockDim.x * gridDim.x) / C10_WARP_SIZE;
-  int warp_id = idx / C10_WARP_SIZE;
-  int tid_in_warp = idx % C10_WARP_SIZE;
-  for (int64_t i = warp_id; i < size; i += stride) {
-    int64_t end = cumsum_ptr[i];
-    index_t repeat = repeat_ptr[i];
-    int64_t start = end - repeat;
-    for(int64_t j = start + tid_in_warp; j < end; j += C10_WARP_SIZE) {
-      result_ptr[j] = i;
-    }
+__global__ static void compute_cuda_kernel(
+    index_t* repeat_ptr,
+    int64_t* cumsum_ptr,
+    index_t* result_ptr,
+    int64_t size) {
+  int64_t idx = blockIdx.x * blockDim.x + threadIdx.x;
+  int64_t stride = (blockDim.x * gridDim.x) / C10_WARP_SIZE;
+  int warp_id = idx / C10_WARP_SIZE;
+  int tid_in_warp = idx % C10_WARP_SIZE;
+  for (int64_t i = warp_id; i < size; i += stride) {
+    int64_t end = cumsum_ptr[i];
+    index_t repeat = repeat_ptr[i];
+    int64_t start = end - repeat;
+    for (int64_t j = start + tid_in_warp; j < end; j += C10_WARP_SIZE) {
+      result_ptr[j] = i;
+    }
   }
 }

 template <typename index_t>
-static void compute_cuda(index_t *repeat_ptr, int64_t *cumsum_ptr, index_t *result_ptr, int64_t size) {
-  int64_t block = 512;
-  int64_t warps_per_block = block / C10_WARP_SIZE;
-  int64_t grid = std::min<int64_t>((size + warps_per_block - 1) / warps_per_block, 2048L);
+static void compute_cuda(
+    index_t* repeat_ptr,
+    int64_t* cumsum_ptr,
+    index_t* result_ptr,
+    int64_t size) {
+  int64_t block = 512;
+  int64_t warps_per_block = block / C10_WARP_SIZE;
+  int64_t grid =
+      std::min<int64_t>((size + warps_per_block - 1) / warps_per_block, 2048L);

-  compute_cuda_kernel<<<grid, block, 0, at::cuda::getCurrentCUDAStream()>>>(repeat_ptr, cumsum_ptr, result_ptr, size);
-  C10_CUDA_KERNEL_LAUNCH_CHECK();
+  compute_cuda_kernel<<<grid, block, 0, at::cuda::getCurrentCUDAStream()>>>(
+      repeat_ptr, cumsum_ptr, result_ptr, size);
+  C10_CUDA_KERNEL_LAUNCH_CHECK();
 }

-namespace at { namespace native {
+namespace at {
+namespace native {

-Tensor repeat_interleave_cuda(const Tensor &repeat) {
-  Tensor output;
-  AT_DISPATCH_INDEX_TYPES(repeat.scalar_type(), "repeat_interleave_cuda", [&]() {
-    output = repeat_interleave_common<index_t, compute_cuda<index_t>>(repeat);
-  });
-  return output;
+Tensor repeat_interleave_cuda(const Tensor& repeat) {
+  Tensor output;
+  AT_DISPATCH_INDEX_TYPES(
+      repeat.scalar_type(), "repeat_interleave_cuda", [&]() {
+        output =
+            repeat_interleave_common<index_t, compute_cuda<index_t>>(repeat);
+      });
+  return output;
 }

-}}
+} // namespace native
+} // namespace at
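In the CUDA file, compute_cuda_kernel assigns one warp per entry of repeats (warp_id picks the entry, tid_in_warp strides across that entry's output range), and compute_cuda caps the grid at 2048 blocks of 512 threads, relying on the i += stride loop for larger inputs. A host-side sketch of that launch arithmetic, assuming the usual warp size of 32 in place of C10_WARP_SIZE:

// Back-of-the-envelope check of the launch configuration used by compute_cuda
// (host-side arithmetic only; illustrative values).
#include <algorithm>
#include <cstdint>
#include <iostream>

int main() {
  const int64_t warp_size = 32;        // stands in for C10_WARP_SIZE
  const int64_t block = 512;           // threads per block, as in the launch
  const int64_t warps_per_block = block / warp_size; // 16 warps per block
  const int64_t size = 100000;         // hypothetical number of repeat entries

  // One warp per entry of `repeats`, rounded up, capped at 2048 blocks.
  int64_t grid =
      std::min<int64_t>((size + warps_per_block - 1) / warps_per_block, 2048L);
  std::cout << "grid=" << grid << " blocks, "
            << grid * warps_per_block << " warps in flight\n";
  // When size exceeds grid * warps_per_block (here 32768), the kernel's
  // grid-stride loop (i += stride) covers the remaining entries.
  return 0;
}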