Mirror of https://github.com/pytorch/pytorch.git, synced 2025-10-20 21:14:14 +08:00

Revert D19825127: [pytorch][PR] Move where cuda implementation to TensorIterator

Test Plan: revert-hammer

Differential Revision: D19825127

Original commit changeset: bbf4682349d9

fbshipit-source-id: 0c439b8c9a00a5aa46fd196396cf7cc83cddb1b4

Committed by: Facebook Github Bot
Parent commit: 000a5e2b7f
This commit: 74c8a8f7bc
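For orientation before the file-by-file diff: D19825127 moved the CUDA `where` onto TensorIterator with a device-generic kernel, and this commit reverts that. The sketch below is an editorial illustration of that pattern, not code from either commit; it uses only calls that appear later in this diff, the wrapper name is hypothetical, the condition dtype is assumed to be bool, and it would have to be built as CUDA inside ATen.

// Hedged sketch: a TensorIterator-based elementwise `where`.
#include <ATen/ATen.h>
#include <ATen/Dispatch.h>
#include <ATen/native/TensorIterator.h>
#include <ATen/native/cuda/Loops.cuh>

static void where_via_iterator(at::Tensor& ret,
                               const at::Tensor& condition,
                               const at::Tensor& self,
                               const at::Tensor& other) {
  // Build a 4-operand iterator: one output, three inputs.
  auto iter = at::TensorIterator();
  iter.set_check_mem_overlap(true);
  iter.add_output(ret);
  iter.add_input(condition);
  iter.add_input(self);
  iter.add_input(other);
  iter.dont_compute_common_dtype();  // condition keeps its bool dtype
  iter.build();

  // Dispatch on the output dtype and run an elementwise select on the GPU.
  AT_DISPATCH_ALL_TYPES_AND_COMPLEX(ret.scalar_type(), "where_example", [&] {
    at::native::gpu_kernel(
        iter,
        [] GPU_LAMBDA (bool cond_val, scalar_t self_val, scalar_t other_val) -> scalar_t {
          return cond_val ? self_val : other_val;
        });
  });
}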
@@ -6,12 +6,42 @@
#include <ATen/native/ReduceOpsUtils.h>
#include <c10/util/Exception.h>
#include <ATen/native/cpu/TensorCompareKernel.h>
#include <ATen/native/cpu/Loops.h>
#include <ATen/NamedTensorUtils.h>

namespace {

template <typename scalar_t>
void where_cpu(
    at::Tensor& ret,
    const at::Tensor& condition,
    const at::Tensor& self,
    const at::Tensor& other) {
  auto iter = at::TensorIterator();
  iter.set_check_mem_overlap(true);
  iter.add_output(ret);
  iter.add_input(condition);
  iter.add_input(self);
  iter.add_input(other);
  iter.dont_compute_common_dtype();
  iter.build();
  if (condition.scalar_type() == at::ScalarType::Byte) {
    at::native::cpu_kernel(
      iter,
      [=](uint8_t cond_val, scalar_t self_val, scalar_t other_val) -> scalar_t {
        return cond_val ? self_val : other_val;
      });
  } else {
    at::native::cpu_kernel(
      iter,
      [=](bool cond_val, scalar_t self_val, scalar_t other_val) -> scalar_t {
        return cond_val ? self_val : other_val;
      });
  }
}

} // namespace

namespace at { namespace native {

DEFINE_DISPATCH(where_kernel);
DEFINE_DISPATCH(max_kernel);
DEFINE_DISPATCH(min_kernel);

@@ -118,18 +148,12 @@ std::vector<Tensor> where(const Tensor& condition) {
  return condition.nonzero_numpy();
}

Tensor _s_where(const Tensor& condition, const Tensor& self, const Tensor& other) {
Tensor _s_where_cpu(const Tensor& condition, const Tensor& self, const Tensor& other) {
  TORCH_CHECK(self.dtype() == other.dtype(), "expected scalar type ", self.dtype(), " but found ", other.dtype());
  Tensor ret = at::empty(self.sizes(), self.options());
  auto iter = at::TensorIterator();
  iter.set_check_mem_overlap(true);
  iter.add_output(ret);
  iter.add_input(condition);
  iter.add_input(self);
  iter.add_input(other);
  iter.dont_compute_common_dtype();
  iter.build();
  where_kernel(iter.device_type(), iter, condition.scalar_type());
  AT_DISPATCH_ALL_TYPES_AND_COMPLEX(ret.scalar_type(), "where_cpu", [&] {
    where_cpu<scalar_t>(ret, condition, self, other);
  });
  return ret;
}

@@ -1,5 +1,4 @@
#include <ATen/native/cpu/TensorCompareKernel.h>
#include <ATen/native/cpu/Loops.h>

#include <numeric>
#include <iterator>

@@ -102,28 +101,9 @@ static void min_kernel_impl(
  });
}

static void where_kernel_impl(TensorIterator &iter, ScalarType condition_type) {
  AT_DISPATCH_ALL_TYPES_AND_COMPLEX(iter.dtype(), "where_cpu", [&] {
    if (condition_type == at::ScalarType::Byte) {
      at::native::cpu_kernel(
        iter,
        [=](uint8_t cond_val, scalar_t self_val, scalar_t other_val) -> scalar_t {
          return cond_val ? self_val : other_val;
        });
    } else {
      at::native::cpu_kernel(
        iter,
        [=](bool cond_val, scalar_t self_val, scalar_t other_val) -> scalar_t {
          return cond_val ? self_val : other_val;
        });
    }
  });
}

} // anonymous namespace

REGISTER_DISPATCH(max_kernel, &max_kernel_impl);
REGISTER_DISPATCH(min_kernel, &min_kernel_impl);
REGISTER_DISPATCH(where_kernel, &where_kernel_impl);

}} // namespace at::native

@@ -3,7 +3,6 @@
#include <ATen/ATen.h>
#include <ATen/native/DispatchStub.h>
#include <c10/util/Optional.h>
#include <ATen/native/TensorIterator.h>

namespace at { namespace native {

@@ -13,7 +12,4 @@ using reduce_fn =
DECLARE_DISPATCH(reduce_fn, max_kernel);
DECLARE_DISPATCH(reduce_fn, min_kernel);

using where_fn = void (*)(TensorIterator &, ScalarType);
DECLARE_DISPATCH(where_fn, where_kernel);

}} // namespace at::native

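The header hunk above declares the `where_kernel` stub that the earlier hunks define and register. Below is a hedged, schematic consolidation of ATen's DispatchStub pattern, not code from this commit: the `example_*` names are made up, and in the real tree the declare/define/register pieces live in separate translation units.

// Hedged sketch of the DECLARE/DEFINE/REGISTER_DISPATCH pattern.
#include <ATen/native/DispatchStub.h>
#include <ATen/native/TensorIterator.h>

namespace at { namespace native {

// Header: declare a stub with the function-pointer signature shared by all
// backends (cf. `where_fn` / `where_kernel` above).
using example_fn = void (*)(TensorIterator&, ScalarType);
DECLARE_DISPATCH(example_fn, example_kernel);

// Device-independent .cpp: define the stub exactly once
// (cf. DEFINE_DISPATCH(where_kernel) in the first file of this diff).
DEFINE_DISPATCH(example_kernel);

// Per-device kernel file: provide an implementation and register it
// (cf. where_kernel_impl + REGISTER_DISPATCH above).
static void example_kernel_impl(TensorIterator& iter, ScalarType condition_type) {
  // backend-specific elementwise work would go here
}
REGISTER_DISPATCH(example_kernel, &example_kernel_impl);

// Call site: the stub picks the registered implementation for the
// iterator's device type (cf. where_kernel(iter.device_type(), ...)).
void run_example(TensorIterator& iter, ScalarType condition_type) {
  example_kernel(iter.device_type(), iter, condition_type);
}

}} // namespace at::native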
@@ -255,7 +255,7 @@ __global__ void elementwise_kernel(int N, func_t f, array_t data) {
}

// TODO (@zasdfgbnm): this function assumes trivial 1d and no dynamic casting
template<int nt, int vt, typename func_t, typename array_t, std::enable_if_t<detail::has_same_arg_types<func_t>::value, int> = 0>
template<int nt, int vt, typename func_t, typename array_t>
static void launch_kernel(int64_t N, const func_t& f, array_t data) {
  TORCH_INTERNAL_ASSERT(N >= 0 && N <= std::numeric_limits<int32_t>::max());
  if (N == 0) {

@@ -281,9 +281,191 @@ static void launch_kernel(int64_t N, const func_t& f, array_t data) {
  AT_CUDA_CHECK(cudaGetLastError());
}

template<int nt, int vt, typename func_t, typename array_t, std::enable_if_t<!detail::has_same_arg_types<func_t>::value, int> = 0>
static void launch_kernel(int64_t N, const func_t& f, array_t data) {}

} // namespace modern

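The paired `launch_kernel` overloads above are selected with `std::enable_if_t` on `has_same_arg_types`: the vectorized `modern` path is only instantiated for functors whose arguments all share one type, while other functors get an empty stub (and are handled by the legacy path instead). A self-contained, hedged sketch of that SFINAE selection with generic names, not ATen code:

// Standalone illustration of enable_if-gated overloads; any C++14 compiler.
#include <cstdio>
#include <type_traits>

// Selected when the predicate holds: the "fast" specialized path.
template <typename T, std::enable_if_t<std::is_floating_point<T>::value, int> = 0>
void launch(T x) { std::printf("fast path: %f\n", static_cast<double>(x)); }

// Selected otherwise; deliberately a do-nothing stub, analogous to the
// second modern::launch_kernel overload in the diff.
template <typename T, std::enable_if_t<!std::is_floating_point<T>::value, int> = 0>
void launch(T) { std::printf("fallback stub\n"); }

int main() {
  launch(1.5f);  // fast path
  launch(42);    // fallback stub
  return 0;
}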
template<typename func_t, int nargs=function_traits<func_t>::arity>
struct needs_dynamic_casting {
  static bool check(TensorIterator& iter) {
    using traits = function_traits<func_t>;
    if (iter.dtype(nargs) != c10::impl::CPPTypeToScalarType<typename traits::template arg<nargs - 1>::type>::value) {
      return true;
    }
    return needs_dynamic_casting<func_t, nargs - 1>::check(iter);
  }
};

template<typename func_t>
struct needs_dynamic_casting<func_t, 0> {
  static bool check(TensorIterator& iter) {
    using traits = function_traits<func_t>;
    return iter.dtype(0) != c10::impl::CPPTypeToScalarType<typename traits::result_type>::value;
  }
};

template <typename func_t>
void gpu_kernel_impl(TensorIterator& iter, const func_t& f) {
  using traits = function_traits<func_t>;
  using arg0_t = typename traits::result_type;
  constexpr int ntensors = traits::arity + 1;

  TORCH_INTERNAL_ASSERT(iter.can_use_32bit_indexing());
  TORCH_INTERNAL_ASSERT(iter.ntensors() == traits::arity + 1);

  at::detail::Array<char*, ntensors> data;
  for (int i = 0; i < ntensors; i++) {
    data[i] = (char*)iter.data_ptr(i);
  }

  at::detail::Array<ScalarType, ntensors> dtypes;
  for (int i = 0; i < ntensors; i++) {
    dtypes[i] = iter.tensor(i).scalar_type();
  }

  int64_t numel = iter.numel();
  if (iter.is_trivial_1d()) {
    auto inner_strides = iter.get_inner_strides();
    at::detail::Array<int, ntensors> strides;
    for (int i = 0; i < ntensors; i++) {
      strides[i] = inner_strides[i];
    }

    if (needs_dynamic_casting<func_t>::check(iter)) {
      legacy::launch_kernel<launch_size_1d, 1>(numel, [=]GPU_LAMBDA(int idx) {
        void* out = data[0] + strides[0] * idx;
        arg0_t result = legacy::invoke(f, &data.data[1], &strides.data[1], &dtypes.data[1], idx);
        c10::cast_and_store<arg0_t>(dtypes[0], out, result);
      });
    } else if (iter.has_contiguous_first_dim()) {
      modern::launch_kernel<C10_WARP_SIZE * 2, 4>(numel, f, data);
    } else {
      legacy::launch_kernel<launch_size_1d, 1>(numel, [=]GPU_LAMBDA(int idx) {
        arg0_t* out = (arg0_t*)(data[0] + strides[0] * idx);
        *out = legacy::invoke(f, &data.data[1], &strides.data[1], idx);
      });
    }
  } else {
    auto offset_calc = legacy::make_offset_calculator<traits::arity + 1>(iter);
    if (needs_dynamic_casting<func_t>::check(iter)) {
      legacy::launch_kernel<launch_size_nd, launch_bound2>(numel, [=]GPU_LAMBDA(int idx) {
        auto offsets = offset_calc.get(idx);
        void* out = data[0] + offsets[0];
        arg0_t result = legacy::invoke(f, &data.data[1], &offsets.data[1], &dtypes.data[1], 1);
        c10::cast_and_store<arg0_t>(dtypes[0], out, result);
      });
    } else {
      legacy::launch_kernel<launch_size_nd, launch_bound2>(numel, [=]GPU_LAMBDA(int idx) {
        auto offsets = offset_calc.get(idx);
        arg0_t* out = (arg0_t*)(data[0] + offsets[0]);
        *out = legacy::invoke(f, &data.data[1], &offsets.data[1], 1);
      });
    }
  }
}

template <typename func_t>
void gpu_kernel(TensorIterator& iter, const func_t& f) {
  ASSERT_HOST_DEVICE_LAMBDA(func_t);

  for (int arg = 0; arg < iter.ntensors(); arg++) {
    TORCH_INTERNAL_ASSERT(iter.device(arg).is_cuda());
  }

  if (iter.numel() == 0) {
    return;
  }

  if (!iter.can_use_32bit_indexing()) {
    for (auto& sub_iter : iter.with_32bit_indexing()) {
      gpu_kernel(sub_iter, f);
    }
    return;
  }

  gpu_kernel_impl(iter, f);
}

template <typename func_t>
void gpu_kernel_with_scalars(TensorIterator& iter, const func_t& f) {
  ASSERT_HOST_DEVICE_LAMBDA(func_t);
  TORCH_INTERNAL_ASSERT(iter.ntensors() == 3);

  using traits = function_traits<func_t>;
  static_assert(
      traits::arity == 2,
      "gpu_kernel_with_scalars only supports two input arguments");

  if (iter.is_cpu_scalar(1)) {
    using arg1_t = typename traits::template arg<0>::type;
    using arg2_t = typename traits::template arg<1>::type;
    auto a = iter.scalar_value<arg1_t>(1);
    iter.remove_operand(1);
    gpu_kernel(iter, [=]GPU_LAMBDA(arg2_t b) {
      return f(a, b);
    });
  } else if (iter.is_cpu_scalar(2)) {
    using arg1_t = typename traits::template arg<0>::type;
    using arg2_t = typename traits::template arg<1>::type;
    auto b = iter.scalar_value<arg2_t>(2);
    iter.remove_operand(2);
    gpu_kernel(iter, [=]GPU_LAMBDA(arg1_t a) {
      return f(a, b);
    });
  } else {
    gpu_kernel(iter, f);
  }
}

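A hedged usage sketch for `gpu_kernel_with_scalars`, inserted for illustration only: the caller, the functor, and the float-only assumption are not from this commit; the iterator calls are the ones that appear elsewhere in this diff.

// Hypothetical caller: elementwise alpha*x + y on float CUDA tensors, where
// y may be a CPU scalar. gpu_kernel_with_scalars strips the scalar operand
// and closes over its value, so the launched kernel only loads the
// remaining tensor operand.
void axpy_example(at::Tensor& out, const at::Tensor& x, const at::Tensor& y, float alpha) {
  auto iter = at::TensorIterator();
  iter.set_check_mem_overlap(true);
  iter.add_output(out);
  iter.add_input(x);
  iter.add_input(y);
  iter.build();
  at::native::gpu_kernel_with_scalars(iter, [=] GPU_LAMBDA (float a, float b) -> float {
    return alpha * a + b;
  });
}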
template <typename func_t>
void gpu_kernel_with_index_impl(TensorIterator& iter, const func_t& f) {
  using traits = function_traits<func_t>;
  using arg0_t = typename traits::result_type;

  // Note:
  // `gpu_kernel_with_index` was originally implemented in PR #28175 with support
  // for an arbitrary number of tensors as arguments. That support was removed
  // while refactoring Loops.cuh for vectorized memory access in PR #32777
  // (see also issue #31975), solely because no operator used the functionality
  // at the time. If you need it, feel free to add it back.
  static_assert(traits::arity == 1, "Functor for gpu_kernel_with_index can only have one argument which is the index");

  TORCH_INTERNAL_ASSERT(iter.ntensors() == 1);

  char* data = (char*)iter.data_ptr(0);

  int64_t numel = iter.numel();
  if (iter.is_trivial_1d()) {
    int stride = iter.get_inner_strides()[0];
    legacy::launch_kernel<launch_size_1d, 1>(numel, [=]GPU_LAMBDA(int idx) {
      arg0_t* out = (arg0_t*)(data + stride * idx);
      *out = f(idx);
    });
  } else {
    auto offset_calc = legacy::make_offset_calculator<traits::arity>(iter);
    legacy::launch_kernel<launch_size_nd, launch_bound2>(numel, [=]GPU_LAMBDA(int idx) {
      auto offsets = offset_calc.get(idx);
      arg0_t* out = (arg0_t*)(data + offsets[0]);
      *out = f(idx);
    });
  }
}

template <typename func_t>
void gpu_kernel_with_index(TensorIterator& iter, const func_t& f) {
  ASSERT_HOST_DEVICE_LAMBDA(func_t);

  TORCH_INTERNAL_ASSERT(iter.device(0).is_cuda(), "gpu_kernel_with_index only support cuda tensor.");

  if (iter.numel() == 0) {
    return;
  }

  // Splitting would change the index, so it is not supported here;
  // the caller should handle the split and pass in a different func.
  TORCH_INTERNAL_ASSERT(iter.can_use_32bit_indexing(), "gpu_kernel_with_index only support 32-bit indexing.");

  gpu_kernel_with_index_impl(iter, f);
}

}} // namespace at::native

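A hedged sketch of how `gpu_kernel_with_index` can be used (not from this commit; the iota-style fill and the wrapper name are assumptions). The functor receives only the flattened element index, and its return value is written to the single output operand:

// Hypothetical: given an iterator already configured with exactly one float
// output operand (matching TORCH_INTERNAL_ASSERT(iter.ntensors() == 1)
// above), write 0, step, 2*step, ... into it.
void fill_linear_example(at::TensorIterator& iter, float step) {
  at::native::gpu_kernel_with_index(iter, [=] GPU_LAMBDA (int idx) -> float {
    return step * static_cast<float>(idx);
  });
}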
@@ -1,31 +1,3 @@
#pragma once

#include <ATen/detail/FunctionTraits.h>

namespace at { namespace native { namespace modern { namespace detail {

template<typename func_t, int remaining=function_traits<func_t>::arity-1>
struct has_same_arg_types {
  using traits = function_traits<func_t>;
  static constexpr bool value = std::is_same<
      typename traits::template arg<remaining>::type,
      typename traits::template arg<remaining-1>::type
    >::value && has_same_arg_types<func_t, remaining-1>::value;
};

template<typename func_t>
struct has_same_arg_types<func_t, 0> {
  static constexpr bool value = true;
};

template<typename func_t>
struct has_same_arg_types<func_t, -1> {
  static constexpr bool value = true;
};

}}}} // namespace at::native::modern::detail

// Note:
// CUDA and ROCm diverged in this PR:
// https://github.com/pytorch/pytorch/pull/32383

@@ -37,195 +9,3 @@ struct has_same_arg_types<func_t, -1> {
#else
#include <ATen/native/cuda/ROCmLoops.cuh>
#endif

namespace at { namespace native {

// `needs_dynamic_casting` compares the types expected by the iterator
// (i.e. the dtypes of the operands) with the actual argument types of
// func_t.
template<typename func_t, int nargs=function_traits<func_t>::arity>
struct needs_dynamic_casting {
  static bool check(TensorIterator& iter) {
    using traits = function_traits<func_t>;
    if (iter.dtype(nargs) != c10::impl::CPPTypeToScalarType<typename traits::template arg<nargs - 1>::type>::value) {
      return true;
    }
    return needs_dynamic_casting<func_t, nargs - 1>::check(iter);
  }
};

template<typename func_t>
struct needs_dynamic_casting<func_t, 0> {
  static bool check(TensorIterator& iter) {
    using traits = function_traits<func_t>;
    return iter.dtype(0) != c10::impl::CPPTypeToScalarType<typename traits::result_type>::value;
  }
};

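A hedged illustration of what that check amounts to (the functor and wrapper below are hypothetical; only the trait itself comes from the code above):

// Hypothetical functor that computes in float. needs_dynamic_casting walks
// the iterator's dtypes: index 0 is the output, indices 1..arity are the
// inputs. If any of them is not kFloat here, the legacy launch path with
// c10::cast_and_store is taken instead of the vectorized modern path.
static float add_in_float(float a, float b) { return a + b; }

bool example_needs_cast(at::TensorIterator& iter) {
  using func_t = decltype(&add_in_float);
  return at::native::needs_dynamic_casting<func_t>::check(iter);
}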
template <typename func_t>
void gpu_kernel_impl(TensorIterator& iter, const func_t& f) {
  using traits = function_traits<func_t>;
  using arg0_t = typename traits::result_type;
  constexpr int ntensors = traits::arity + 1;

  TORCH_INTERNAL_ASSERT(iter.can_use_32bit_indexing());
  TORCH_INTERNAL_ASSERT(iter.ntensors() == traits::arity + 1);

  at::detail::Array<char*, ntensors> data;
  for (int i = 0; i < ntensors; i++) {
    data[i] = (char*)iter.data_ptr(i);
  }

  at::detail::Array<ScalarType, ntensors> dtypes;
  for (int i = 0; i < ntensors; i++) {
    dtypes[i] = iter.tensor(i).scalar_type();
  }

  int64_t numel = iter.numel();
  if (iter.is_trivial_1d()) {
    auto inner_strides = iter.get_inner_strides();
    at::detail::Array<int, ntensors> strides;
    for (int i = 0; i < ntensors; i++) {
      strides[i] = inner_strides[i];
    }

    if (needs_dynamic_casting<func_t>::check(iter)) {
      legacy::launch_kernel<launch_size_1d, 1>(numel, [=]GPU_LAMBDA(int idx) {
        void* out = data[0] + strides[0] * idx;
        arg0_t result = legacy::invoke(f, &data.data[1], &strides.data[1], &dtypes.data[1], idx);
        c10::cast_and_store<arg0_t>(dtypes[0], out, result);
      });
    } else if (iter.has_contiguous_first_dim() && modern::detail::has_same_arg_types<func_t>::value) {
      modern::launch_kernel<C10_WARP_SIZE * 2, 4>(numel, f, data);
    } else {
      legacy::launch_kernel<launch_size_1d, 1>(numel, [=]GPU_LAMBDA(int idx) {
        arg0_t* out = (arg0_t*)(data[0] + strides[0] * idx);
        *out = legacy::invoke(f, &data.data[1], &strides.data[1], idx);
      });
    }
  } else {
    auto offset_calc = legacy::make_offset_calculator<traits::arity + 1>(iter);
    if (needs_dynamic_casting<func_t>::check(iter)) {
      legacy::launch_kernel<launch_size_nd, launch_bound2>(numel, [=]GPU_LAMBDA(int idx) {
        auto offsets = offset_calc.get(idx);
        void* out = data[0] + offsets[0];
        arg0_t result = legacy::invoke(f, &data.data[1], &offsets.data[1], &dtypes.data[1], 1);
        c10::cast_and_store<arg0_t>(dtypes[0], out, result);
      });
    } else {
      legacy::launch_kernel<launch_size_nd, launch_bound2>(numel, [=]GPU_LAMBDA(int idx) {
        auto offsets = offset_calc.get(idx);
        arg0_t* out = (arg0_t*)(data[0] + offsets[0]);
        *out = legacy::invoke(f, &data.data[1], &offsets.data[1], 1);
      });
    }
  }
}

template <typename func_t>
void gpu_kernel(TensorIterator& iter, const func_t& f) {
  ASSERT_HOST_DEVICE_LAMBDA(func_t);

  for (int arg = 0; arg < iter.ntensors(); arg++) {
    TORCH_INTERNAL_ASSERT(iter.device(arg).is_cuda());
  }

  if (iter.numel() == 0) {
    return;
  }

  if (!iter.can_use_32bit_indexing()) {
    for (auto& sub_iter : iter.with_32bit_indexing()) {
      gpu_kernel(sub_iter, f);
    }
    return;
  }

  gpu_kernel_impl(iter, f);
}

template <typename func_t>
void gpu_kernel_with_scalars(TensorIterator& iter, const func_t& f) {
  ASSERT_HOST_DEVICE_LAMBDA(func_t);
  TORCH_INTERNAL_ASSERT(iter.ntensors() == 3);

  using traits = function_traits<func_t>;
  static_assert(
      traits::arity == 2,
      "gpu_kernel_with_scalars only supports two input arguments");

  if (iter.is_cpu_scalar(1)) {
    using arg1_t = typename traits::template arg<0>::type;
    using arg2_t = typename traits::template arg<1>::type;
    auto a = iter.scalar_value<arg1_t>(1);
    iter.remove_operand(1);
    gpu_kernel(iter, [=]GPU_LAMBDA(arg2_t b) {
      return f(a, b);
    });
  } else if (iter.is_cpu_scalar(2)) {
    using arg1_t = typename traits::template arg<0>::type;
    using arg2_t = typename traits::template arg<1>::type;
    auto b = iter.scalar_value<arg2_t>(2);
    iter.remove_operand(2);
    gpu_kernel(iter, [=]GPU_LAMBDA(arg1_t a) {
      return f(a, b);
    });
  } else {
    gpu_kernel(iter, f);
  }
}

template <typename func_t>
void gpu_kernel_with_index_impl(TensorIterator& iter, const func_t& f) {
  using traits = function_traits<func_t>;
  using arg0_t = typename traits::result_type;

  // Note:
  // `gpu_kernel_with_index` was originally implemented in PR #28175 with support
  // for an arbitrary number of tensors as arguments. That support was removed
  // while refactoring Loops.cuh for vectorized memory access in PR #32777
  // (see also issue #31975), solely because no operator used the functionality
  // at the time. If you need it, feel free to add it back.
  static_assert(traits::arity == 1, "Functor for gpu_kernel_with_index can only have one argument which is the index");

  TORCH_INTERNAL_ASSERT(iter.ntensors() == 1);

  char* data = (char*)iter.data_ptr(0);

  int64_t numel = iter.numel();
  if (iter.is_trivial_1d()) {
    int stride = iter.get_inner_strides()[0];
    legacy::launch_kernel<launch_size_1d, 1>(numel, [=]GPU_LAMBDA(int idx) {
      arg0_t* out = (arg0_t*)(data + stride * idx);
      *out = f(idx);
    });
  } else {
    auto offset_calc = legacy::make_offset_calculator<traits::arity>(iter);
    legacy::launch_kernel<launch_size_nd, launch_bound2>(numel, [=]GPU_LAMBDA(int idx) {
      auto offsets = offset_calc.get(idx);
      arg0_t* out = (arg0_t*)(data + offsets[0]);
      *out = f(idx);
    });
  }
}

template <typename func_t>
void gpu_kernel_with_index(TensorIterator& iter, const func_t& f) {
  ASSERT_HOST_DEVICE_LAMBDA(func_t);

  TORCH_INTERNAL_ASSERT(iter.device(0).is_cuda(), "gpu_kernel_with_index only support cuda tensor.");

  if (iter.numel() == 0) {
    return;
  }

  // Splitting would change the index, so it is not supported here;
  // the caller should handle the split and pass in a different func.
  TORCH_INTERNAL_ASSERT(iter.can_use_32bit_indexing(), "gpu_kernel_with_index only support 32-bit indexing.");

  gpu_kernel_with_index_impl(iter, f);
}

}} // namespace at::native

@@ -240,7 +240,7 @@ __global__ void elementwise_kernel(int N, func_t f, array_t data) {
}

// TODO (@zasdfgbnm): this function assumes trivial 1d and no dynamic casting
template<int nt, int vt, typename func_t, typename array_t, std::enable_if_t<detail::has_same_arg_types<func_t>::value, int> = 0>
template<int nt, int vt, typename func_t, typename array_t>
static void launch_kernel(int64_t N, const func_t& f, array_t data) {
  TORCH_INTERNAL_ASSERT(N >= 0 && N <= std::numeric_limits<int32_t>::max());
  if (N == 0) {

@@ -253,9 +253,191 @@ static void launch_kernel(int64_t N, const func_t& f, array_t data) {
  AT_CUDA_CHECK(cudaGetLastError());
}

template<int nt, int vt, typename func_t, typename array_t, std::enable_if_t<!detail::has_same_arg_types<func_t>::value, int> = 0>
static void launch_kernel(int64_t N, const func_t& f, array_t data) {}

} // namespace modern

template<typename func_t, int nargs=function_traits<func_t>::arity>
struct needs_dynamic_casting {
  static bool check(TensorIterator& iter) {
    using traits = function_traits<func_t>;
    if (iter.dtype(nargs) != c10::impl::CPPTypeToScalarType<typename traits::template arg<nargs - 1>::type>::value) {
      return true;
    }
    return needs_dynamic_casting<func_t, nargs - 1>::check(iter);
  }
};

template<typename func_t>
struct needs_dynamic_casting<func_t, 0> {
  static bool check(TensorIterator& iter) {
    using traits = function_traits<func_t>;
    return iter.dtype(0) != c10::impl::CPPTypeToScalarType<typename traits::result_type>::value;
  }
};

template <typename func_t>
void gpu_kernel_impl(TensorIterator& iter, const func_t& f) {
  using traits = function_traits<func_t>;
  using arg0_t = typename traits::result_type;
  constexpr int ntensors = traits::arity + 1;

  TORCH_INTERNAL_ASSERT(iter.can_use_32bit_indexing());
  TORCH_INTERNAL_ASSERT(iter.ntensors() == traits::arity + 1);

  at::detail::Array<char*, ntensors> data;
  for (int i = 0; i < ntensors; i++) {
    data[i] = (char*)iter.data_ptr(i);
  }

  at::detail::Array<ScalarType, ntensors> dtypes;
  for (int i = 0; i < ntensors; i++) {
    dtypes[i] = iter.tensor(i).scalar_type();
  }

  int64_t numel = iter.numel();
  if (iter.is_trivial_1d()) {
    auto inner_strides = iter.get_inner_strides();
    at::detail::Array<int, ntensors> strides;
    for (int i = 0; i < ntensors; i++) {
      strides[i] = inner_strides[i];
    }

    if (needs_dynamic_casting<func_t>::check(iter)) {
      legacy::launch_kernel<launch_size_1d, 1>(numel, [=]GPU_LAMBDA(int idx) {
        void* out = data[0] + strides[0] * idx;
        arg0_t result = legacy::invoke(f, &data.data[1], &strides.data[1], &dtypes.data[1], idx);
        c10::cast_and_store<arg0_t>(dtypes[0], out, result);
      });
    } else if (iter.has_contiguous_first_dim()) {
      modern::launch_kernel<C10_WARP_SIZE * 2, 4>(numel, f, data);
    } else {
      legacy::launch_kernel<launch_size_1d, 1>(numel, [=]GPU_LAMBDA(int idx) {
        arg0_t* out = (arg0_t*)(data[0] + strides[0] * idx);
        *out = legacy::invoke(f, &data.data[1], &strides.data[1], idx);
      });
    }
  } else {
    auto offset_calc = legacy::make_offset_calculator<traits::arity + 1>(iter);
    if (needs_dynamic_casting<func_t>::check(iter)) {
      legacy::launch_kernel<launch_size_nd, launch_bound2>(numel, [=]GPU_LAMBDA(int idx) {
        auto offsets = offset_calc.get(idx);
        void* out = data[0] + offsets[0];
        arg0_t result = legacy::invoke(f, &data.data[1], &offsets.data[1], &dtypes.data[1], 1);
        c10::cast_and_store<arg0_t>(dtypes[0], out, result);
      });
    } else {
      legacy::launch_kernel<launch_size_nd, launch_bound2>(numel, [=]GPU_LAMBDA(int idx) {
        auto offsets = offset_calc.get(idx);
        arg0_t* out = (arg0_t*)(data[0] + offsets[0]);
        *out = legacy::invoke(f, &data.data[1], &offsets.data[1], 1);
      });
    }
  }
}

template <typename func_t>
void gpu_kernel(TensorIterator& iter, const func_t& f) {
  ASSERT_HOST_DEVICE_LAMBDA(func_t);

  for (int arg = 0; arg < iter.ntensors(); arg++) {
    TORCH_INTERNAL_ASSERT(iter.device(arg).is_cuda());
  }

  if (iter.numel() == 0) {
    return;
  }

  if (!iter.can_use_32bit_indexing()) {
    for (auto& sub_iter : iter.with_32bit_indexing()) {
      gpu_kernel(sub_iter, f);
    }
    return;
  }

  gpu_kernel_impl(iter, f);
}

template <typename func_t>
void gpu_kernel_with_scalars(TensorIterator& iter, const func_t& f) {
  ASSERT_HOST_DEVICE_LAMBDA(func_t);
  TORCH_INTERNAL_ASSERT(iter.ntensors() == 3);

  using traits = function_traits<func_t>;
  static_assert(
      traits::arity == 2,
      "gpu_kernel_with_scalars only supports two input arguments");

  if (iter.is_cpu_scalar(1)) {
    using arg1_t = typename traits::template arg<0>::type;
    using arg2_t = typename traits::template arg<1>::type;
    auto a = iter.scalar_value<arg1_t>(1);
    iter.remove_operand(1);
    gpu_kernel(iter, [=]GPU_LAMBDA(arg2_t b) {
      return f(a, b);
    });
  } else if (iter.is_cpu_scalar(2)) {
    using arg1_t = typename traits::template arg<0>::type;
    using arg2_t = typename traits::template arg<1>::type;
    auto b = iter.scalar_value<arg2_t>(2);
    iter.remove_operand(2);
    gpu_kernel(iter, [=]GPU_LAMBDA(arg1_t a) {
      return f(a, b);
    });
  } else {
    gpu_kernel(iter, f);
  }
}

template <typename func_t>
void gpu_kernel_with_index_impl(TensorIterator& iter, const func_t& f) {
  using traits = function_traits<func_t>;
  using arg0_t = typename traits::result_type;

  // Note:
  // `gpu_kernel_with_index` was originally implemented in PR #28175 with support
  // for an arbitrary number of tensors as arguments. That support was removed
  // while refactoring Loops.cuh for vectorized memory access in PR #32777
  // (see also issue #31975), solely because no operator used the functionality
  // at the time. If you need it, feel free to add it back.
  static_assert(traits::arity == 1, "Functor for gpu_kernel_with_index can only have one argument which is the index");

  TORCH_INTERNAL_ASSERT(iter.ntensors() == 1);

  char* data = (char*)iter.data_ptr(0);

  int64_t numel = iter.numel();
  if (iter.is_trivial_1d()) {
    int stride = iter.get_inner_strides()[0];
    legacy::launch_kernel<launch_size_1d, 1>(numel, [=]GPU_LAMBDA(int idx) {
      arg0_t* out = (arg0_t*)(data + stride * idx);
      *out = f(idx);
    });
  } else {
    auto offset_calc = legacy::make_offset_calculator<traits::arity>(iter);
    legacy::launch_kernel<launch_size_nd, launch_bound2>(numel, [=]GPU_LAMBDA(int idx) {
      auto offsets = offset_calc.get(idx);
      arg0_t* out = (arg0_t*)(data + offsets[0]);
      *out = f(idx);
    });
  }
}

template <typename func_t>
void gpu_kernel_with_index(TensorIterator& iter, const func_t& f) {
  ASSERT_HOST_DEVICE_LAMBDA(func_t);

  TORCH_INTERNAL_ASSERT(iter.device(0).is_cuda(), "gpu_kernel_with_index only support cuda tensor.");

  if (iter.numel() == 0) {
    return;
  }

  // Splitting would change the index, so it is not supported here;
  // the caller should handle the split and pass in a different func.
  TORCH_INTERNAL_ASSERT(iter.can_use_32bit_indexing(), "gpu_kernel_with_index only support 32-bit indexing.");

  gpu_kernel_with_index_impl(iter, f);
}

}} // namespace at::native

@@ -1,38 +1,56 @@
#include <ATen/NativeFunctions.h>
#include <ATen/Dispatch.h>
#include <ATen/native/DispatchStub.h>
#include <ATen/native/cuda/Loops.cuh>

#include <ATen/cuda/CUDAApplyUtils.cuh>

namespace {

template <typename scalar_t>
void where_cuda(
    at::Tensor& ret,
    const at::Tensor& condition,
    const at::Tensor& self,
    const at::Tensor& other) {
  if (condition.scalar_type() == at::ScalarType::Byte) {
    // Yes this name is repetitive, but the CPU version is called
    // CPU_tensor_apply4 and we don't have a CPU namespace or directory.
    at::cuda::CUDA_tensor_apply4<scalar_t, uint8_t, scalar_t, scalar_t>(
        ret,
        condition,
        self,
        other,
        [] __device__(
            scalar_t & ret_val,
            const uint8_t& cond_val,
            const scalar_t& self_val,
            const scalar_t& other_val) {
          ret_val = cond_val ? self_val : other_val;
        });
  } else {
    at::cuda::CUDA_tensor_apply4<scalar_t, bool, scalar_t, scalar_t>(
        ret,
        condition,
        self,
        other,
        [] __device__(
            scalar_t & ret_val,
            const bool& cond_val,
            const scalar_t& self_val,
            const scalar_t& other_val) {
          ret_val = cond_val ? self_val : other_val;
        });
  }
}

} // namespace

namespace at { namespace native {

using where_fn = void (*)(TensorIterator &, ScalarType);
DECLARE_DISPATCH(where_fn, where_kernel);

namespace {

void where_kernel_impl(TensorIterator &iter, ScalarType condition_type) {
  AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND2(kHalf, kBool, iter.dtype(), "where_cuda", [&] {
    if (condition_type == at::ScalarType::Byte) {
      gpu_kernel(
        iter,
        [=] GPU_LAMBDA (uint8_t cond_val, scalar_t self_val, scalar_t other_val) -> scalar_t {
          return cond_val ? self_val : other_val;
        });
    } else {
      gpu_kernel(
        iter,
        [=] GPU_LAMBDA (bool cond_val, scalar_t self_val, scalar_t other_val) -> scalar_t {
          return cond_val ? self_val : other_val;
        });
    }
  });
}

Tensor _s_where_cuda(
    const Tensor& condition,
    const Tensor& self,
    const Tensor& other) {
  Tensor ret = at::empty(self.sizes(), self.options());
  AT_DISPATCH_ALL_TYPES_AND2(at::ScalarType::Half, at::ScalarType::Bool, ret.scalar_type(), "where_cuda", [&] {
    where_cuda<scalar_t>(ret, condition, self, other);
  });
  return ret;
}

} // anonymous namespace

REGISTER_DISPATCH(where_kernel, &where_kernel_impl);

}} // namespace at::native

@@ -2942,6 +2942,9 @@
- func: _s_where(Tensor condition, Tensor self, Tensor other) -> Tensor
  use_c10_dispatcher: full
  variants: function
  dispatch:
    CPU: _s_where_cpu
    CUDA: _s_where_cuda

- func: norm_except_dim(Tensor v, int pow=2, int dim=0) -> Tensor
  variants: function

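For orientation, and as a hedged assumption about the standard native_functions.yaml codegen rather than anything shown in this diff: the `dispatch` block above is what routes the generated wrapper to the per-backend implementations, so a call on CPU tensors reaches `_s_where_cpu` and a call on CUDA tensors reaches `_s_where_cuda`. The caller below is hypothetical.

// Hypothetical caller of the generated wrapper for the entry above.
#include <ATen/ATen.h>

at::Tensor select_example(const at::Tensor& cond,
                          const at::Tensor& a,
                          const at::Tensor& b) {
  // CPU tensors -> _s_where_cpu, CUDA tensors -> _s_where_cuda.
  return at::_s_where(cond, a, b);
}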
@@ -1,7 +1,6 @@
#include <gtest/gtest.h>
#include <ATen/ATen.h>
#include <ATen/native/cuda/MemoryAccess.cuh>
#include <ATen/native/cuda/Loops.cuh>
#include <ATen/cuda/CUDAContext.h>

using namespace at::native::memory;

@@ -22,21 +21,6 @@ void reset_buffers() {
  }
}

TEST(TestLoops, HasSameArgTypes) {
  // This is a compile-time unit test. If this file compiles without error,
  // then the test passes and during runtime, we just need to return.
  using namespace at::native::modern::detail;
  using func1_t = int (*)(float, float);
  using func2_t = int (*)(bool, float, float);
  using func3_t = int (*)(float);
  using func4_t = int (*)();
  static_assert(has_same_arg_types<func1_t>::value, "func1_t has the same argument types");
  static_assert(!has_same_arg_types<func2_t>::value, "func2_t does not have the same argument types");
  static_assert(has_same_arg_types<func3_t>::value, "func3_t has the same argument types");
  static_assert(has_same_arg_types<func4_t>::value, "func4_t has the same argument types");
  return;
}

TEST(TestVectorizedMemoryAccess, CanVectorizeUpTo) {
  char *ptr = reinterpret_cast<char *>(buffer1);