Revert D19825127: [pytorch][PR] Move where cuda implementation to TensorIterator

Test Plan: revert-hammer

Differential Revision: D19825127

Original commit changeset: bbf4682349d9

fbshipit-source-id: 0c439b8c9a00a5aa46fd196396cf7cc83cddb1b4
Saurabh Aggarwal
2020-02-11 19:46:48 -08:00
committed by Facebook Github Bot
parent 000a5e2b7f
commit 74c8a8f7bc
9 changed files with 457 additions and 308 deletions
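
For context, the operator whose CUDA path this revert touches is an elementwise select. A minimal ATen usage sketch, illustrative only and not part of the diff:

#include <ATen/ATen.h>

// at::where picks elements from `self` where `condition` is true and from
// `other` where it is false, elementwise.
void where_example() {
  at::Tensor x = at::randn({4});
  at::Tensor cond = x > 0;                              // Bool condition tensor
  at::Tensor y = at::where(cond, x, at::zeros_like(x)); // keep positives, zero the rest
}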


@@ -6,12 +6,42 @@
#include <ATen/native/ReduceOpsUtils.h>
#include <c10/util/Exception.h>
#include <ATen/native/cpu/TensorCompareKernel.h>
#include <ATen/native/cpu/Loops.h>
#include <ATen/NamedTensorUtils.h>
namespace {
template <typename scalar_t>
void where_cpu(
at::Tensor& ret,
const at::Tensor& condition,
const at::Tensor& self,
const at::Tensor& other) {
auto iter = at::TensorIterator();
iter.set_check_mem_overlap(true);
iter.add_output(ret);
iter.add_input(condition);
iter.add_input(self);
iter.add_input(other);
iter.dont_compute_common_dtype();
iter.build();
if (condition.scalar_type() == at::ScalarType::Byte) {
at::native::cpu_kernel(
iter,
[=](uint8_t cond_val, scalar_t self_val, scalar_t other_val) -> scalar_t {
return cond_val ? self_val : other_val;
});
} else {
at::native::cpu_kernel(
iter,
[=](bool cond_val, scalar_t self_val, scalar_t other_val) -> scalar_t {
return cond_val ? self_val : other_val;
});
}
}
} // namespace
namespace at { namespace native {
DEFINE_DISPATCH(where_kernel);
DEFINE_DISPATCH(max_kernel);
DEFINE_DISPATCH(min_kernel);
@@ -118,18 +148,12 @@ std::vector<Tensor> where(const Tensor& condition) {
return condition.nonzero_numpy();
}
Tensor _s_where(const Tensor& condition, const Tensor& self, const Tensor& other) {
Tensor _s_where_cpu(const Tensor& condition, const Tensor& self, const Tensor& other) {
TORCH_CHECK(self.dtype() == other.dtype(), "expected scalar type ", self.dtype(), " but found ", other.dtype());
Tensor ret = at::empty(self.sizes(), self.options());
auto iter = at::TensorIterator();
iter.set_check_mem_overlap(true);
iter.add_output(ret);
iter.add_input(condition);
iter.add_input(self);
iter.add_input(other);
iter.dont_compute_common_dtype();
iter.build();
where_kernel(iter.device_type(), iter, condition.scalar_type());
AT_DISPATCH_ALL_TYPES_AND_COMPLEX(ret.scalar_type(), "where_cpu", [&] {
where_cpu<scalar_t>(ret, condition, self, other);
});
return ret;
}
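
The CPU path above instantiates the templated kernel once per dtype through the AT_DISPATCH_* macro family (AT_DISPATCH_ALL_TYPES_AND_COMPLEX in this hunk). A minimal sketch of that pattern; fill_one is a hypothetical helper, not part of this change:

#include <ATen/ATen.h>
#include <ATen/Dispatch.h>

// The macro expands the lambda once per supported dtype, binding `scalar_t`
// to the matching C++ type, so the body stays generic.
void fill_one(at::Tensor& t) {
  AT_DISPATCH_ALL_TYPES(t.scalar_type(), "fill_one", [&] {
    scalar_t* ptr = t.data_ptr<scalar_t>(); // assumes `t` is contiguous
    for (int64_t i = 0; i < t.numel(); i++) {
      ptr[i] = scalar_t(1);
    }
  });
}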


@@ -1,5 +1,4 @@
#include <ATen/native/cpu/TensorCompareKernel.h>
#include <ATen/native/cpu/Loops.h>
#include <numeric>
#include <iterator>
@@ -102,28 +101,9 @@ static void min_kernel_impl(
});
}
static void where_kernel_impl(TensorIterator &iter, ScalarType condition_type) {
AT_DISPATCH_ALL_TYPES_AND_COMPLEX(iter.dtype(), "where_cpu", [&] {
if (condition_type == at::ScalarType::Byte) {
at::native::cpu_kernel(
iter,
[=](uint8_t cond_val, scalar_t self_val, scalar_t other_val) -> scalar_t {
return cond_val ? self_val : other_val;
});
} else {
at::native::cpu_kernel(
iter,
[=](bool cond_val, scalar_t self_val, scalar_t other_val) -> scalar_t {
return cond_val ? self_val : other_val;
});
}
});
}
} // anonymous namespace
REGISTER_DISPATCH(max_kernel, &max_kernel_impl);
REGISTER_DISPATCH(min_kernel, &min_kernel_impl);
REGISTER_DISPATCH(where_kernel, &where_kernel_impl);
}} // namespace at::native


@@ -3,7 +3,6 @@
#include <ATen/ATen.h>
#include <ATen/native/DispatchStub.h>
#include <c10/util/Optional.h>
#include <ATen/native/TensorIterator.h>
namespace at { namespace native {
@@ -13,7 +12,4 @@ using reduce_fn =
DECLARE_DISPATCH(reduce_fn, max_kernel);
DECLARE_DISPATCH(reduce_fn, min_kernel);
using where_fn = void (*)(TensorIterator &, ScalarType);
DECLARE_DISPATCH(where_fn, where_kernel);
}} // namespace at::native
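
The where_fn/where_kernel lines above and the DEFINE_DISPATCH/REGISTER_DISPATCH lines in the previous hunks follow ATen's DispatchStub pattern. A condensed sketch of how the pieces fit together; my_kernel and my_kernel_impl are hypothetical names:

#include <ATen/native/DispatchStub.h>
#include <ATen/native/TensorIterator.h>

namespace at { namespace native {

// In the header: declare the stub together with its function-pointer signature.
using my_fn = void (*)(TensorIterator&);
DECLARE_DISPATCH(my_fn, my_kernel);

// In the device-agnostic .cpp: define the stub once.
DEFINE_DISPATCH(my_kernel);

// In a cpu/ or cuda/ kernel file (not here): register the implementation, e.g.
//   static void my_kernel_impl(TensorIterator& iter) { /* ... */ }
//   REGISTER_DISPATCH(my_kernel, &my_kernel_impl);
// Call sites then dispatch on the device type:
//   my_kernel(iter.device_type(), iter);

}} // namespace at::native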


@@ -255,7 +255,7 @@ __global__ void elementwise_kernel(int N, func_t f, array_t data) {
}
// TODO (@zasdfgbnm): this function assumes trivial 1d and no dynamic casting
template<int nt, int vt, typename func_t, typename array_t, std::enable_if_t<detail::has_same_arg_types<func_t>::value, int> = 0>
template<int nt, int vt, typename func_t, typename array_t>
static void launch_kernel(int64_t N, const func_t& f, array_t data) {
TORCH_INTERNAL_ASSERT(N >= 0 && N <= std::numeric_limits<int32_t>::max());
if (N == 0) {
@@ -281,9 +281,191 @@ static void launch_kernel(int64_t N, const func_t& f, array_t data) {
AT_CUDA_CHECK(cudaGetLastError());
}
template<int nt, int vt, typename func_t, typename array_t, std::enable_if_t<!detail::has_same_arg_types<func_t>::value, int> = 0>
static void launch_kernel(int64_t N, const func_t& f, array_t data) {}
} // namespace modern
template<typename func_t, int nargs=function_traits<func_t>::arity>
struct needs_dynamic_casting {
static bool check(TensorIterator& iter) {
using traits = function_traits<func_t>;
if (iter.dtype(nargs) != c10::impl::CPPTypeToScalarType<typename traits::template arg<nargs - 1>::type>::value) {
return true;
}
return needs_dynamic_casting<func_t, nargs - 1>::check(iter);
}
};
template<typename func_t>
struct needs_dynamic_casting<func_t, 0> {
static bool check(TensorIterator& iter) {
using traits = function_traits<func_t>;
return iter.dtype(0) != c10::impl::CPPTypeToScalarType<typename traits::result_type>::value;
}
};
template <typename func_t>
void gpu_kernel_impl(TensorIterator& iter, const func_t& f) {
using traits = function_traits<func_t>;
using arg0_t = typename traits::result_type;
constexpr int ntensors = traits::arity + 1;
TORCH_INTERNAL_ASSERT(iter.can_use_32bit_indexing());
TORCH_INTERNAL_ASSERT(iter.ntensors() == traits::arity + 1);
at::detail::Array<char*, ntensors> data;
for (int i = 0; i < ntensors; i++) {
data[i] = (char*)iter.data_ptr(i);
}
at::detail::Array<ScalarType, ntensors> dtypes;
for (int i = 0; i < ntensors; i++) {
dtypes[i] = iter.tensor(i).scalar_type();
}
int64_t numel = iter.numel();
if (iter.is_trivial_1d()) {
auto inner_strides = iter.get_inner_strides();
at::detail::Array<int, ntensors> strides;
for (int i = 0; i < ntensors; i++) {
strides[i] = inner_strides[i];
}
if (needs_dynamic_casting<func_t>::check(iter)) {
legacy::launch_kernel<launch_size_1d, 1>(numel, [=]GPU_LAMBDA(int idx) {
void* out = data[0] + strides[0] * idx;
arg0_t result = legacy::invoke(f, &data.data[1], &strides.data[1], &dtypes.data[1], idx);
c10::cast_and_store<arg0_t>(dtypes[0], out, result);
});
} else if (iter.has_contiguous_first_dim()) {
modern::launch_kernel<C10_WARP_SIZE * 2, 4>(numel, f, data);
} else {
legacy::launch_kernel<launch_size_1d, 1>(numel, [=]GPU_LAMBDA(int idx) {
arg0_t* out = (arg0_t*)(data[0] + strides[0] * idx);
*out = legacy::invoke(f, &data.data[1], &strides.data[1], idx);
});
}
} else {
auto offset_calc = legacy::make_offset_calculator<traits::arity + 1>(iter);
if (needs_dynamic_casting<func_t>::check(iter)) {
legacy::launch_kernel<launch_size_nd, launch_bound2>(numel, [=]GPU_LAMBDA(int idx) {
auto offsets = offset_calc.get(idx);
void* out = data[0] + offsets[0];
arg0_t result = legacy::invoke(f, &data.data[1], &offsets.data[1], &dtypes.data[1], 1);
c10::cast_and_store<arg0_t>(dtypes[0], out, result);
});
} else {
legacy::launch_kernel<launch_size_nd, launch_bound2>(numel, [=]GPU_LAMBDA(int idx) {
auto offsets = offset_calc.get(idx);
arg0_t* out = (arg0_t*)(data[0] + offsets[0]);
*out = legacy::invoke(f, &data.data[1], &offsets.data[1], 1);
});
}
}
}
template <typename func_t>
void gpu_kernel(TensorIterator& iter, const func_t& f) {
ASSERT_HOST_DEVICE_LAMBDA(func_t);
for (int arg = 0; arg < iter.ntensors(); arg++) {
TORCH_INTERNAL_ASSERT(iter.device(arg).is_cuda());
}
if (iter.numel() == 0) {
return;
}
if (!iter.can_use_32bit_indexing()) {
for (auto& sub_iter : iter.with_32bit_indexing()) {
gpu_kernel(sub_iter, f);
}
return;
}
gpu_kernel_impl(iter, f);
}
template <typename func_t>
void gpu_kernel_with_scalars(TensorIterator& iter, const func_t& f) {
ASSERT_HOST_DEVICE_LAMBDA(func_t);
TORCH_INTERNAL_ASSERT(iter.ntensors() == 3);
using traits = function_traits<func_t>;
static_assert(
traits::arity == 2,
"gpu_kernel_with_scalars only supports two input arguments");
if (iter.is_cpu_scalar(1)) {
using arg1_t = typename traits::template arg<0>::type;
using arg2_t = typename traits::template arg<1>::type;
auto a = iter.scalar_value<arg1_t>(1);
iter.remove_operand(1);
gpu_kernel(iter, [=]GPU_LAMBDA(arg2_t b) {
return f(a, b);
});
} else if (iter.is_cpu_scalar(2)) {
using arg1_t = typename traits::template arg<0>::type;
using arg2_t = typename traits::template arg<1>::type;
auto b = iter.scalar_value<arg2_t>(2);
iter.remove_operand(2);
gpu_kernel(iter, [=]GPU_LAMBDA(arg1_t a) {
return f(a, b);
});
} else {
gpu_kernel(iter, f);
}
}
template <typename func_t>
void gpu_kernel_with_index_impl(TensorIterator& iter, const func_t& f) {
using traits = function_traits<func_t>;
using arg0_t = typename traits::result_type;
// Note:
// `gpu_kernel_with_index` was originally implemented in PR #28175 with support
// for an arbitrary number of tensors as arguments. That support was removed
// while refactoring Loops.cuh for vectorized memory access in PR #32777
// (see also issue #31975), solely because at the time no operator used that
// functionality. If you need it, feel free to add it back.
static_assert(traits::arity == 1, "Functor for gpu_kernel_with_index can only have one argument which is the index");
TORCH_INTERNAL_ASSERT(iter.ntensors() == 1);
char* data = (char*)iter.data_ptr(0);
int64_t numel = iter.numel();
if (iter.is_trivial_1d()) {
int stride = iter.get_inner_strides()[0];
legacy::launch_kernel<launch_size_1d, 1>(numel, [=]GPU_LAMBDA(int idx) {
arg0_t* out = (arg0_t*)(data + stride * idx);
*out = f(idx);
});
} else {
auto offset_calc = legacy::make_offset_calculator<traits::arity>(iter);
legacy::launch_kernel<launch_size_nd, launch_bound2>(numel, [=]GPU_LAMBDA(int idx) {
auto offsets = offset_calc.get(idx);
arg0_t* out = (arg0_t*)(data + offsets[0]);
*out = f(idx);
});
}
}
template <typename func_t>
void gpu_kernel_with_index(TensorIterator& iter, const func_t& f) {
ASSERT_HOST_DEVICE_LAMBDA(func_t);
TORCH_INTERNAL_ASSERT(iter.device(0).is_cuda(), "gpu_kernel_with_index only support cuda tensor.");
if (iter.numel() == 0) {
return;
}
// Split will change index, thus is not supported
// The caller should handle the split and pass in different func
TORCH_INTERNAL_ASSERT(iter.can_use_32bit_indexing(), "gpu_kernel_with_index only support 32-bit indexing.");
gpu_kernel_with_index_impl(iter, f);
}
}} // namespace at::native
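
gpu_kernel, shown in the hunk above, is the generic entry point for elementwise CUDA kernels driven by a TensorIterator. A minimal call-site sketch; add_sketch is a hypothetical kernel, not part of this change:

#include <ATen/Dispatch.h>
#include <ATen/native/TensorIterator.h>
#include <ATen/native/cuda/Loops.cuh>

// Dispatch on the iterator's dtype, then hand a device lambda to gpu_kernel;
// the iterator supplies shapes, strides, and the 32-bit-indexing splits.
void add_sketch(at::TensorIterator& iter) {
  AT_DISPATCH_FLOATING_TYPES(iter.dtype(), "add_sketch", [&] {
    at::native::gpu_kernel(iter, [] GPU_LAMBDA (scalar_t a, scalar_t b) -> scalar_t {
      return a + b;
    });
  });
}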


@@ -1,31 +1,3 @@
#pragma once
#include <ATen/detail/FunctionTraits.h>
namespace at { namespace native { namespace modern { namespace detail {
template<typename func_t, int remaining=function_traits<func_t>::arity-1>
struct has_same_arg_types {
using traits = function_traits<func_t>;
static constexpr bool value = std::is_same<
typename traits::template arg<remaining>::type,
typename traits::template arg<remaining-1>::type
>::value && has_same_arg_types<func_t, remaining-1>::value;
};
template<typename func_t>
struct has_same_arg_types<func_t, 0> {
static constexpr bool value = true;
};
template<typename func_t>
struct has_same_arg_types<func_t, -1> {
static constexpr bool value = true;
};
}}}} // namespace at::native::modern::detail
// Note:
// CUDA and ROCm get diverged in this PR:
// https://github.com/pytorch/pytorch/pull/32383
@@ -37,195 +9,3 @@ struct has_same_arg_types<func_t, -1> {
#else
#include <ATen/native/cuda/ROCmLoops.cuh>
#endif
namespace at { namespace native {
// `needs_dynamic_casting` compares the types expected by iterator
// (i.e. dtypes of the operands) with the actual type of the arguments
// of func_t
template<typename func_t, int nargs=function_traits<func_t>::arity>
struct needs_dynamic_casting {
static bool check(TensorIterator& iter) {
using traits = function_traits<func_t>;
if (iter.dtype(nargs) != c10::impl::CPPTypeToScalarType<typename traits::template arg<nargs - 1>::type>::value) {
return true;
}
return needs_dynamic_casting<func_t, nargs - 1>::check(iter);
}
};
template<typename func_t>
struct needs_dynamic_casting<func_t, 0> {
static bool check(TensorIterator& iter) {
using traits = function_traits<func_t>;
return iter.dtype(0) != c10::impl::CPPTypeToScalarType<typename traits::result_type>::value;
}
};
template <typename func_t>
void gpu_kernel_impl(TensorIterator& iter, const func_t& f) {
using traits = function_traits<func_t>;
using arg0_t = typename traits::result_type;
constexpr int ntensors = traits::arity + 1;
TORCH_INTERNAL_ASSERT(iter.can_use_32bit_indexing());
TORCH_INTERNAL_ASSERT(iter.ntensors() == traits::arity + 1);
at::detail::Array<char*, ntensors> data;
for (int i = 0; i < ntensors; i++) {
data[i] = (char*)iter.data_ptr(i);
}
at::detail::Array<ScalarType, ntensors> dtypes;
for (int i = 0; i < ntensors; i++) {
dtypes[i] = iter.tensor(i).scalar_type();
}
int64_t numel = iter.numel();
if (iter.is_trivial_1d()) {
auto inner_strides = iter.get_inner_strides();
at::detail::Array<int, ntensors> strides;
for (int i = 0; i < ntensors; i++) {
strides[i] = inner_strides[i];
}
if (needs_dynamic_casting<func_t>::check(iter)) {
legacy::launch_kernel<launch_size_1d, 1>(numel, [=]GPU_LAMBDA(int idx) {
void* out = data[0] + strides[0] * idx;
arg0_t result = legacy::invoke(f, &data.data[1], &strides.data[1], &dtypes.data[1], idx);
c10::cast_and_store<arg0_t>(dtypes[0], out, result);
});
} else if (iter.has_contiguous_first_dim() && modern::detail::has_same_arg_types<func_t>::value) {
modern::launch_kernel<C10_WARP_SIZE * 2, 4>(numel, f, data);
} else {
legacy::launch_kernel<launch_size_1d, 1>(numel, [=]GPU_LAMBDA(int idx) {
arg0_t* out = (arg0_t*)(data[0] + strides[0] * idx);
*out = legacy::invoke(f, &data.data[1], &strides.data[1], idx);
});
}
} else {
auto offset_calc = legacy::make_offset_calculator<traits::arity + 1>(iter);
if (needs_dynamic_casting<func_t>::check(iter)) {
legacy::launch_kernel<launch_size_nd, launch_bound2>(numel, [=]GPU_LAMBDA(int idx) {
auto offsets = offset_calc.get(idx);
void* out = data[0] + offsets[0];
arg0_t result = legacy::invoke(f, &data.data[1], &offsets.data[1], &dtypes.data[1], 1);
c10::cast_and_store<arg0_t>(dtypes[0], out, result);
});
} else {
legacy::launch_kernel<launch_size_nd, launch_bound2>(numel, [=]GPU_LAMBDA(int idx) {
auto offsets = offset_calc.get(idx);
arg0_t* out = (arg0_t*)(data[0] + offsets[0]);
*out = legacy::invoke(f, &data.data[1], &offsets.data[1], 1);
});
}
}
}
template <typename func_t>
void gpu_kernel(TensorIterator& iter, const func_t& f) {
ASSERT_HOST_DEVICE_LAMBDA(func_t);
for (int arg = 0; arg < iter.ntensors(); arg++) {
TORCH_INTERNAL_ASSERT(iter.device(arg).is_cuda());
}
if (iter.numel() == 0) {
return;
}
if (!iter.can_use_32bit_indexing()) {
for (auto& sub_iter : iter.with_32bit_indexing()) {
gpu_kernel(sub_iter, f);
}
return;
}
gpu_kernel_impl(iter, f);
}
template <typename func_t>
void gpu_kernel_with_scalars(TensorIterator& iter, const func_t& f) {
ASSERT_HOST_DEVICE_LAMBDA(func_t);
TORCH_INTERNAL_ASSERT(iter.ntensors() == 3);
using traits = function_traits<func_t>;
static_assert(
traits::arity == 2,
"gpu_kernel_with_scalars only supports two input arguments");
if (iter.is_cpu_scalar(1)) {
using arg1_t = typename traits::template arg<0>::type;
using arg2_t = typename traits::template arg<1>::type;
auto a = iter.scalar_value<arg1_t>(1);
iter.remove_operand(1);
gpu_kernel(iter, [=]GPU_LAMBDA(arg2_t b) {
return f(a, b);
});
} else if (iter.is_cpu_scalar(2)) {
using arg1_t = typename traits::template arg<0>::type;
using arg2_t = typename traits::template arg<1>::type;
auto b = iter.scalar_value<arg2_t>(2);
iter.remove_operand(2);
gpu_kernel(iter, [=]GPU_LAMBDA(arg1_t a) {
return f(a, b);
});
} else {
gpu_kernel(iter, f);
}
}
template <typename func_t>
void gpu_kernel_with_index_impl(TensorIterator& iter, const func_t& f) {
using traits = function_traits<func_t>;
using arg0_t = typename traits::result_type;
// Note:
// `gpu_kernel_with_index` was originally implemented in PR #28175 with support
// for an arbitrary number of tensors as arguments. That support was removed
// while refactoring Loops.cuh for vectorized memory access in PR #32777
// (see also issue #31975), solely because at the time no operator used that
// functionality. If you need it, feel free to add it back.
static_assert(traits::arity == 1, "Functor for gpu_kernel_with_index can only have one argument which is the index");
TORCH_INTERNAL_ASSERT(iter.ntensors() == 1);
char* data = (char*)iter.data_ptr(0);
int64_t numel = iter.numel();
if (iter.is_trivial_1d()) {
int stride = iter.get_inner_strides()[0];
legacy::launch_kernel<launch_size_1d, 1>(numel, [=]GPU_LAMBDA(int idx) {
arg0_t* out = (arg0_t*)(data + stride * idx);
*out = f(idx);
});
} else {
auto offset_calc = legacy::make_offset_calculator<traits::arity>(iter);
legacy::launch_kernel<launch_size_nd, launch_bound2>(numel, [=]GPU_LAMBDA(int idx) {
auto offsets = offset_calc.get(idx);
arg0_t* out = (arg0_t*)(data + offsets[0]);
*out = f(idx);
});
}
}
template <typename func_t>
void gpu_kernel_with_index(TensorIterator& iter, const func_t& f) {
ASSERT_HOST_DEVICE_LAMBDA(func_t);
TORCH_INTERNAL_ASSERT(iter.device(0).is_cuda(), "gpu_kernel_with_index only support cuda tensor.");
if (iter.numel() == 0) {
return;
}
// Split will change index, thus is not supported
// The caller should handle the split and pass in different func
TORCH_INTERNAL_ASSERT(iter.can_use_32bit_indexing(), "gpu_kernel_with_index only support 32-bit indexing.");
gpu_kernel_with_index_impl(iter, f);
}
}} // namespace at::native
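
gpu_kernel_with_scalars, shown above, handles the common case where one operand of a binary op is a CPU scalar: the scalar is read on the host, its operand is removed from the iterator, and a unary kernel capturing the value is launched. A sketch of the call pattern; mul_sketch is hypothetical:

#include <ATen/Dispatch.h>
#include <ATen/native/TensorIterator.h>
#include <ATen/native/cuda/Loops.cuh>

// If neither operand 1 nor 2 is a CPU scalar, this falls through to the plain
// binary gpu_kernel path.
void mul_sketch(at::TensorIterator& iter) {
  AT_DISPATCH_FLOATING_TYPES(iter.dtype(), "mul_sketch", [&] {
    at::native::gpu_kernel_with_scalars(iter, [] GPU_LAMBDA (scalar_t a, scalar_t b) -> scalar_t {
      return a * b;
    });
  });
}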


@@ -240,7 +240,7 @@ __global__ void elementwise_kernel(int N, func_t f, array_t data) {
}
// TODO (@zasdfgbnm): this function assumes trivial 1d and no dynamic casting
template<int nt, int vt, typename func_t, typename array_t, std::enable_if_t<detail::has_same_arg_types<func_t>::value, int> = 0>
template<int nt, int vt, typename func_t, typename array_t>
static void launch_kernel(int64_t N, const func_t& f, array_t data) {
TORCH_INTERNAL_ASSERT(N >= 0 && N <= std::numeric_limits<int32_t>::max());
if (N == 0) {
@@ -253,9 +253,191 @@ static void launch_kernel(int64_t N, const func_t& f, array_t data) {
AT_CUDA_CHECK(cudaGetLastError());
}
template<int nt, int vt, typename func_t, typename array_t, std::enable_if_t<!detail::has_same_arg_types<func_t>::value, int> = 0>
static void launch_kernel(int64_t N, const func_t& f, array_t data) {}
} // namespace modern
template<typename func_t, int nargs=function_traits<func_t>::arity>
struct needs_dynamic_casting {
static bool check(TensorIterator& iter) {
using traits = function_traits<func_t>;
if (iter.dtype(nargs) != c10::impl::CPPTypeToScalarType<typename traits::template arg<nargs - 1>::type>::value) {
return true;
}
return needs_dynamic_casting<func_t, nargs - 1>::check(iter);
}
};
template<typename func_t>
struct needs_dynamic_casting<func_t, 0> {
static bool check(TensorIterator& iter) {
using traits = function_traits<func_t>;
return iter.dtype(0) != c10::impl::CPPTypeToScalarType<typename traits::result_type>::value;
}
};
template <typename func_t>
void gpu_kernel_impl(TensorIterator& iter, const func_t& f) {
using traits = function_traits<func_t>;
using arg0_t = typename traits::result_type;
constexpr int ntensors = traits::arity + 1;
TORCH_INTERNAL_ASSERT(iter.can_use_32bit_indexing());
TORCH_INTERNAL_ASSERT(iter.ntensors() == traits::arity + 1);
at::detail::Array<char*, ntensors> data;
for (int i = 0; i < ntensors; i++) {
data[i] = (char*)iter.data_ptr(i);
}
at::detail::Array<ScalarType, ntensors> dtypes;
for (int i = 0; i < ntensors; i++) {
dtypes[i] = iter.tensor(i).scalar_type();
}
int64_t numel = iter.numel();
if (iter.is_trivial_1d()) {
auto inner_strides = iter.get_inner_strides();
at::detail::Array<int, ntensors> strides;
for (int i = 0; i < ntensors; i++) {
strides[i] = inner_strides[i];
}
if (needs_dynamic_casting<func_t>::check(iter)) {
legacy::launch_kernel<launch_size_1d, 1>(numel, [=]GPU_LAMBDA(int idx) {
void* out = data[0] + strides[0] * idx;
arg0_t result = legacy::invoke(f, &data.data[1], &strides.data[1], &dtypes.data[1], idx);
c10::cast_and_store<arg0_t>(dtypes[0], out, result);
});
} else if (iter.has_contiguous_first_dim()) {
modern::launch_kernel<C10_WARP_SIZE * 2, 4>(numel, f, data);
} else {
legacy::launch_kernel<launch_size_1d, 1>(numel, [=]GPU_LAMBDA(int idx) {
arg0_t* out = (arg0_t*)(data[0] + strides[0] * idx);
*out = legacy::invoke(f, &data.data[1], &strides.data[1], idx);
});
}
} else {
auto offset_calc = legacy::make_offset_calculator<traits::arity + 1>(iter);
if (needs_dynamic_casting<func_t>::check(iter)) {
legacy::launch_kernel<launch_size_nd, launch_bound2>(numel, [=]GPU_LAMBDA(int idx) {
auto offsets = offset_calc.get(idx);
void* out = data[0] + offsets[0];
arg0_t result = legacy::invoke(f, &data.data[1], &offsets.data[1], &dtypes.data[1], 1);
c10::cast_and_store<arg0_t>(dtypes[0], out, result);
});
} else {
legacy::launch_kernel<launch_size_nd, launch_bound2>(numel, [=]GPU_LAMBDA(int idx) {
auto offsets = offset_calc.get(idx);
arg0_t* out = (arg0_t*)(data[0] + offsets[0]);
*out = legacy::invoke(f, &data.data[1], &offsets.data[1], 1);
});
}
}
}
template <typename func_t>
void gpu_kernel(TensorIterator& iter, const func_t& f) {
ASSERT_HOST_DEVICE_LAMBDA(func_t);
for (int arg = 0; arg < iter.ntensors(); arg++) {
TORCH_INTERNAL_ASSERT(iter.device(arg).is_cuda());
}
if (iter.numel() == 0) {
return;
}
if (!iter.can_use_32bit_indexing()) {
for (auto& sub_iter : iter.with_32bit_indexing()) {
gpu_kernel(sub_iter, f);
}
return;
}
gpu_kernel_impl(iter, f);
}
template <typename func_t>
void gpu_kernel_with_scalars(TensorIterator& iter, const func_t& f) {
ASSERT_HOST_DEVICE_LAMBDA(func_t);
TORCH_INTERNAL_ASSERT(iter.ntensors() == 3);
using traits = function_traits<func_t>;
static_assert(
traits::arity == 2,
"gpu_kernel_with_scalars only supports two input arguments");
if (iter.is_cpu_scalar(1)) {
using arg1_t = typename traits::template arg<0>::type;
using arg2_t = typename traits::template arg<1>::type;
auto a = iter.scalar_value<arg1_t>(1);
iter.remove_operand(1);
gpu_kernel(iter, [=]GPU_LAMBDA(arg2_t b) {
return f(a, b);
});
} else if (iter.is_cpu_scalar(2)) {
using arg1_t = typename traits::template arg<0>::type;
using arg2_t = typename traits::template arg<1>::type;
auto b = iter.scalar_value<arg2_t>(2);
iter.remove_operand(2);
gpu_kernel(iter, [=]GPU_LAMBDA(arg1_t a) {
return f(a, b);
});
} else {
gpu_kernel(iter, f);
}
}
template <typename func_t>
void gpu_kernel_with_index_impl(TensorIterator& iter, const func_t& f) {
using traits = function_traits<func_t>;
using arg0_t = typename traits::result_type;
// Note:
// `gpu_kernel_with_index` was originally implemented in PR #28175 with support
// for an arbitrary number of tensors as arguments. That support was removed
// while refactoring Loops.cuh for vectorized memory access in PR #32777
// (see also issue #31975), solely because at the time no operator used that
// functionality. If you need it, feel free to add it back.
static_assert(traits::arity == 1, "Functor for gpu_kernel_with_index can only have one argument which is the index");
TORCH_INTERNAL_ASSERT(iter.ntensors() == 1);
char* data = (char*)iter.data_ptr(0);
int64_t numel = iter.numel();
if (iter.is_trivial_1d()) {
int stride = iter.get_inner_strides()[0];
legacy::launch_kernel<launch_size_1d, 1>(numel, [=]GPU_LAMBDA(int idx) {
arg0_t* out = (arg0_t*)(data + stride * idx);
*out = f(idx);
});
} else {
auto offset_calc = legacy::make_offset_calculator<traits::arity>(iter);
legacy::launch_kernel<launch_size_nd, launch_bound2>(numel, [=]GPU_LAMBDA(int idx) {
auto offsets = offset_calc.get(idx);
arg0_t* out = (arg0_t*)(data + offsets[0]);
*out = f(idx);
});
}
}
template <typename func_t>
void gpu_kernel_with_index(TensorIterator& iter, const func_t& f) {
ASSERT_HOST_DEVICE_LAMBDA(func_t);
TORCH_INTERNAL_ASSERT(iter.device(0).is_cuda(), "gpu_kernel_with_index only support cuda tensor.");
if (iter.numel() == 0) {
return;
}
// Split will change index, thus is not supported
// The caller should handle the split and pass in different func
TORCH_INTERNAL_ASSERT(iter.can_use_32bit_indexing(), "gpu_kernel_with_index only support 32-bit indexing.");
gpu_kernel_with_index_impl(iter, f);
}
}} // namespace at::native
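
gpu_kernel_with_index, shown above, hands the flattened element index to a single-argument functor and writes the result into the lone output tensor. A rough sketch, assuming a contiguous int64 CUDA output and that TensorIterator::nullary_op is the appropriate way to build the output-only iterator (as the range/arange-style kernels do); iota_sketch is hypothetical:

#include <ATen/ATen.h>
#include <ATen/native/TensorIterator.h>
#include <ATen/native/cuda/Loops.cuh>

// Fill `out` with 0..numel-1. The functor's only argument is the index, and
// its return type must match the output dtype since no dynamic casting is done.
void iota_sketch(at::Tensor& out /* contiguous int64 CUDA tensor */) {
  auto iter = at::TensorIterator::nullary_op(out);
  at::native::gpu_kernel_with_index(iter, [] GPU_LAMBDA (int idx) -> int64_t {
    return static_cast<int64_t>(idx);
  });
}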


@@ -1,38 +1,56 @@
#include <ATen/NativeFunctions.h>
#include <ATen/Dispatch.h>
#include <ATen/native/DispatchStub.h>
#include <ATen/native/cuda/Loops.cuh>
#include <ATen/cuda/CUDAApplyUtils.cuh>
namespace {
template <typename scalar_t>
void where_cuda(
at::Tensor& ret,
const at::Tensor& condition,
const at::Tensor& self,
const at::Tensor& other) {
if (condition.scalar_type() == at::ScalarType::Byte) {
// Yes this name is repetitive, but the CPU version is called
// CPU_tensor_apply4 and we don't have a CPU namespace or directory.
at::cuda::CUDA_tensor_apply4<scalar_t, uint8_t, scalar_t, scalar_t>(
ret,
condition,
self,
other,
[] __device__(
scalar_t & ret_val,
const uint8_t& cond_val,
const scalar_t& self_val,
const scalar_t& other_val) {
ret_val = cond_val ? self_val : other_val;
});
} else {
at::cuda::CUDA_tensor_apply4<scalar_t, bool, scalar_t, scalar_t>(
ret,
condition,
self,
other,
[] __device__(
scalar_t & ret_val,
const bool& cond_val,
const scalar_t& self_val,
const scalar_t& other_val) {
ret_val = cond_val ? self_val : other_val;
});
}
}
} // namespace
namespace at { namespace native {
using where_fn = void (*)(TensorIterator &, ScalarType);
DECLARE_DISPATCH(where_fn, where_kernel);
namespace {
void where_kernel_impl(TensorIterator &iter, ScalarType condition_type) {
AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND2(kHalf, kBool, iter.dtype(), "where_cuda", [&] {
if (condition_type == at::ScalarType::Byte) {
gpu_kernel(
iter,
[=] GPU_LAMBDA (uint8_t cond_val, scalar_t self_val, scalar_t other_val) -> scalar_t {
return cond_val ? self_val : other_val;
});
} else {
gpu_kernel(
iter,
[=] GPU_LAMBDA (bool cond_val, scalar_t self_val, scalar_t other_val) -> scalar_t {
return cond_val ? self_val : other_val;
});
}
Tensor _s_where_cuda(
const Tensor& condition,
const Tensor& self,
const Tensor& other) {
Tensor ret = at::empty(self.sizes(), self.options());
AT_DISPATCH_ALL_TYPES_AND2(at::ScalarType::Half, at::ScalarType::Bool, ret.scalar_type(), "where_cuda", [&] {
where_cuda<scalar_t>(ret, condition, self, other);
});
return ret;
}
} // anonymous namespace
REGISTER_DISPATCH(where_kernel, &where_kernel_impl);
}} // namespace at::native


@@ -2942,6 +2942,9 @@
- func: _s_where(Tensor condition, Tensor self, Tensor other) -> Tensor
  use_c10_dispatcher: full
  variants: function
  dispatch:
    CPU: _s_where_cpu
    CUDA: _s_where_cuda

- func: norm_except_dim(Tensor v, int pow=2, int dim=0) -> Tensor
  variants: function
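
With the per-backend dispatch entries above in place, the same native call routes to _s_where_cpu or _s_where_cuda depending on device. An illustrative sketch, not part of this change:

#include <ATen/ATen.h>

// `_s_where` expects same-sized (already broadcast) tensors with a
// Byte or Bool condition.
void s_where_routing_example() {
  at::Tensor a = at::randn({3});
  at::Tensor b = at::randn({3});
  at::Tensor cond = a > b;                      // Bool tensor, same size
  at::Tensor r_cpu = at::_s_where(cond, a, b);  // dispatches to _s_where_cpu
  if (at::hasCUDA()) {
    at::Tensor r_cuda = at::_s_where(cond.cuda(), a.cuda(), b.cuda()); // -> _s_where_cuda
  }
}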


@@ -1,7 +1,6 @@
#include <gtest/gtest.h>
#include <ATen/ATen.h>
#include <ATen/native/cuda/MemoryAccess.cuh>
#include <ATen/native/cuda/Loops.cuh>
#include <ATen/cuda/CUDAContext.h>
using namespace at::native::memory;
@@ -22,21 +21,6 @@ void reset_buffers() {
}
}
TEST(TestLoops, HasSameArgTypes) {
// This is a compile-time unit test. If this file compiles without error,
// then the test passes and during runtime, we just need to return.
using namespace at::native::modern::detail;
using func1_t = int (*)(float, float);
using func2_t = int (*)(bool, float, float);
using func3_t = int (*)(float);
using func4_t = int (*)();
static_assert(has_same_arg_types<func1_t>::value, "func1_t has the same argument types");
static_assert(!has_same_arg_types<func2_t>::value, "func2_t does not have the same argument types");
static_assert(has_same_arg_types<func3_t>::value, "func3_t has the same argument types");
static_assert(has_same_arg_types<func4_t>::value, "func4_t has the same argument types");
return;
}
TEST(TestVectorizedMemoryAccess, CanVectorizeUpTo) {
char *ptr = reinterpret_cast<char *>(buffer1);