[reland][chalf] where(cpu and cuda), pow(cuda) (#78665)

Reland: https://github.com/pytorch/pytorch/pull/77640
Ref: #74537
Pull Request resolved: https://github.com/pytorch/pytorch/pull/78665
Approved by: https://github.com/ngimel
Author: Kshiteej K
Date: 2022-06-02 18:04:06 +00:00
Committed by: PyTorch MergeBot
Parent: d578197747
Commit: 849b08f14b
7 changed files with 137 additions and 34 deletions

View File

@@ -205,7 +205,7 @@ static void aminmax_kernel(
}
static void where_kernel_impl(TensorIterator &iter) {
AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND3(at::ScalarType::Half, at::ScalarType::BFloat16, at::ScalarType::Bool,
AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND4(kComplexHalf, kHalf, kBFloat16, kBool,
iter.dtype(), "where_cpu", [&] {
cpu_kernel(
iter,
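With kComplexHalf added to the dispatch above, torch.where accepts complex32 ("chalf") operands on CPU; the analogous CUDA dispatch change appears later in this diff. A minimal usage sketch (not part of the diff), assuming a build that contains this commit; the tensor values are arbitrary:

import torch

cond = torch.tensor([True, False, True])
a = torch.tensor([1 + 1j, 2 + 2j, 3 + 3j], dtype=torch.complex32)  # chalf inputs
b = torch.tensor([0j, 0j, 0j], dtype=torch.complex32)
out = torch.where(cond, a, b)  # previously raised: "where_cpu" not implemented for 'ComplexHalf'
print(out.dtype)               # torch.complex32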

View File

@@ -2,6 +2,7 @@
#include <ATen/Context.h>
#include <ATen/Dispatch.h>
#include <ATen/native/cuda/Loops.cuh>
#include <ATen/native/cuda/JitLoops.cuh>
#include <ATen/native/DispatchStub.h>
#include <ATen/native/TensorIterator.h>
#include <ATen/native/Pow.h>
@@ -83,9 +84,69 @@ void pow_scalar_tensor_impl(TensorIteratorBase& iter, c10::complex<value_t> base
});
}
/* complex<Half> support impl */
const char pow_scalar_base_name[] = "pow_scalar_base_kernel";
template <>
void pow_scalar_tensor_impl(TensorIteratorBase& iter, c10::complex<at::Half> base) {
using scalar_t = c10::complex<at::Half>;
using opmath_t = at::opmath_type<scalar_t>;
// For complex, thrust::pow uses the identity
// pow(a, b) = exp(log(a) * b)
const auto fct = std::log(opmath_t{base});
#if AT_USE_JITERATOR()
static const auto pow_kernel_string =
jiterator_stringify(template <typename T> T pow_scalar_base_kernel(T exp, T fct) {
return std::exp(fct * exp);
});
jitted_gpu_kernel<pow_scalar_base_name, scalar_t, scalar_t, 1>(
iter,
pow_kernel_string,
/*scalar_pos=*/at::cuda::jit::BinaryFuncVariant::NoScalar,
/*scalar_val=*/0,
/*extra_args=*/std::make_tuple(fct));
#else
gpu_kernel(iter, [=] GPU_LAMBDA(scalar_t exp) -> scalar_t {
return std::exp(fct * opmath_t{exp});
});
#endif
}
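The specialization above folds the scalar base into fct = log(base), computed once in opmath (complex<float>) precision, so the kernel only has to evaluate exp(fct * exp) per element, following the identity pow(a, b) = exp(log(a) * b) noted in the comment. A quick numerical check of that identity in plain Python (illustrative only; values are arbitrary):

import cmath

base, exp = 2 + 3j, 1.5 - 0.5j
direct = base ** exp
via_identity = cmath.exp(cmath.log(base) * exp)  # pow(a, b) == exp(log(a) * b)
assert abs(direct - via_identity) < 1e-9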
namespace {
#if AT_USE_JITERATOR()
/* complex<Half> support impl */
const char pow_name[] = "pow_kernel";
static const auto pow_kernel_string =
jiterator_stringify(template <typename T> T pow_kernel(T base, T exp) {
return std::pow(base, exp);
});
#endif
/* complex<Half> support impl */
void pow_chalf_tensor_scalar_impl(TensorIteratorBase& iter, const Scalar& exp_scalar) {
using scalar_t = c10::complex<at::Half>;
using opmath_t = at::opmath_type<scalar_t>;
auto exp = exp_scalar.to<opmath_t>();
#if AT_USE_JITERATOR()
jitted_gpu_kernel<pow_name, scalar_t, scalar_t, 1>(
iter,
pow_kernel_string,
/*scalar_pos=*/at::cuda::jit::BinaryFuncVariant::NoScalar,
/*scalar_val=*/0,
/*extra_args=*/std::make_tuple(exp));
#else
gpu_kernel(iter, [=] GPU_LAMBDA(scalar_t base) -> scalar_t {
return std::pow(opmath_t{base}, exp);
});
#endif
}
} // anonymous namespace
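pow_chalf_tensor_scalar_impl above backs the chalf tensor-with-scalar-exponent case on CUDA; it is also reused further down when pow_tensor_scalar_kernel sees a ComplexHalf common dtype. A usage sketch (not part of the diff), assuming a CUDA device and a build with this commit:

import torch

if torch.cuda.is_available():
    x = torch.tensor([1 + 1j, 2 - 1j, 0.5 + 0.25j], dtype=torch.complex32, device="cuda")
    y = torch.pow(x, 2.5)    # Python-scalar exponent: routed to the chalf scalar path
    z = x ** (1 + 1j)        # complex scalar exponent takes the same path
    print(y.dtype, z.dtype)  # torch.complex32 torch.complex32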
void pow_tensor_tensor_kernel(TensorIteratorBase& iter) {
AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND2(
kHalf, kBFloat16, iter.common_dtype(), "pow_cuda", [&] {
auto common_dtype = iter.common_dtype();
if (common_dtype == kComplexHalf) {
using scalar_t = c10::complex<at::Half>;
if (iter.is_cpu_scalar(1)) {
const auto base = iter.scalar_value<scalar_t>(1);
iter.remove_operand(1);
@@ -93,13 +154,38 @@ void pow_tensor_tensor_kernel(TensorIteratorBase& iter) {
} else if (iter.is_cpu_scalar(2)) {
const auto exp = iter.scalar_value<scalar_t>(2);
iter.remove_operand(2);
pow_tensor_scalar_kernel(iter, exp);
pow_chalf_tensor_scalar_impl(iter, exp);
} else {
gpu_kernel(iter, [=]GPU_LAMBDA(scalar_t base, scalar_t exp) -> scalar_t {
return pow_(base, exp);
});
using opmath_t = at::opmath_type<scalar_t>;
TORCH_INTERNAL_ASSERT(!iter.is_cpu_scalar(1) && !iter.is_cpu_scalar(2));
#if AT_USE_JITERATOR()
jitted_gpu_kernel<pow_name, scalar_t, scalar_t, 2>(
iter, pow_kernel_string);
#else
gpu_kernel(iter, [=] GPU_LAMBDA(scalar_t base, scalar_t exp) -> scalar_t {
using opmath_t = at::opmath_type<scalar_t>;
return pow_(opmath_t{base}, opmath_t{exp});
});
#endif
}
});
} else {
AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND2(
kHalf, kBFloat16, iter.common_dtype(), "pow_cuda", [&] {
if (iter.is_cpu_scalar(1)) {
const auto base = iter.scalar_value<scalar_t>(1);
iter.remove_operand(1);
pow_scalar_tensor_impl(iter, base);
} else if (iter.is_cpu_scalar(2)) {
const auto exp = iter.scalar_value<scalar_t>(2);
iter.remove_operand(2);
pow_tensor_scalar_kernel(iter, exp);
} else {
gpu_kernel(iter, [=]GPU_LAMBDA(scalar_t base, scalar_t exp) -> scalar_t {
return pow_(base, exp);
});
}
});
}
}
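The branch above handles a ComplexHalf common dtype before the generic AT_DISPATCH path: a tensor-tensor pow runs through the jiterator pow_kernel (or the gpu_kernel fallback), while CPU-scalar operands are peeled off via is_cpu_scalar() to the scalar implementations. A usage sketch (not part of the diff), assuming a CUDA device and a build with this commit:

import torch

if torch.cuda.is_available():
    base = torch.tensor([1 + 1j, 2 - 1j], dtype=torch.complex32, device="cuda")
    exp = torch.tensor([2 + 0j, 0.5 + 0j], dtype=torch.complex32, device="cuda")
    out = torch.pow(base, exp)   # chalf ** chalf on CUDA, handled by the branch above
    print(out.dtype)             # torch.complex32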
@@ -140,6 +226,11 @@ void pow_tensor_scalar_kernel(TensorIteratorBase& iter, const Scalar& exp_scalar
}
}
if (isComplexType(iter.common_dtype()) || exp_scalar.isComplex()) {
if (iter.common_dtype() == kComplexHalf) {
using scalar_t = c10::complex<at::Half>;
pow_chalf_tensor_scalar_impl(iter, exp_scalar);
return;
}
AT_DISPATCH_COMPLEX_TYPES(iter.common_dtype(), "pow_cuda", [&]() {
const auto exp = exp_scalar.to<scalar_t>();
gpu_kernel(iter, [=]GPU_LAMBDA(scalar_t base) -> scalar_t {

View File

@@ -12,7 +12,7 @@ namespace at { namespace native {
namespace {
void where_kernel_impl(TensorIterator &iter) {
AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND3(kHalf, kBFloat16, kBool, iter.dtype(), "where_cuda", [&] {
AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND4(kComplexHalf, kHalf, kBFloat16, kBool, iter.dtype(), "where_cuda", [&] {
gpu_kernel(
iter,
[=] GPU_LAMBDA (bool cond_val, scalar_t self_val, scalar_t other_val) -> scalar_t {

View File

@@ -36,7 +36,7 @@ class C10_API Scalar {
#define DEFINE_IMPLICIT_CTOR(type, name) \
Scalar(type vv) : Scalar(vv, true) {}
AT_FORALL_SCALAR_TYPES_AND2(Half, BFloat16, DEFINE_IMPLICIT_CTOR)
AT_FORALL_SCALAR_TYPES_AND3(Half, BFloat16, ComplexHalf, DEFINE_IMPLICIT_CTOR)
AT_FORALL_COMPLEX_TYPES(DEFINE_IMPLICIT_CTOR)
#undef DEFINE_IMPLICIT_CTOR

View File

@@ -1456,13 +1456,15 @@ class TestBinaryUfuncs(TestCase):
self._do_pow_for_exponents(m1, exponents + complex_exponents, pow, 10e-4)
else:
self._do_pow_for_exponents(m1, exponents, math.pow, None)
if dtype != torch.half:
self._do_pow_for_exponents(m1, complex_exponents, pow, 10e-4)
else:
will_raise_error = dtype is torch.half and torch.device(device).type == 'cpu'
if will_raise_error:
# On CPU,
# a Half tensor with complex exponents leads to a computation dtype
# of ComplexHalf, for which this op is not supported yet
with self.assertRaisesRegex(RuntimeError, "not implemented for 'ComplexHalf'"):
self._do_pow_for_exponents(m1, complex_exponents, pow, 10e-4)
else:
self._do_pow_for_exponents(m1, complex_exponents, pow, 10e-4)
# base - number, exponent - tensor
# contiguous
@@ -1751,11 +1753,14 @@ class TestBinaryUfuncs(TestCase):
first_exp[0] = first_exp[10] = first_exp[20] = 0
second_exp[0] = second_exp[10] = second_exp[20] = 0
for base in complexes:
# On CPU,
# a Half tensor with a complex base leads to a computation dtype
# of ComplexHalf, for which this op is not supported yet
# NOTE: pow has a fast-path when the base is 1, which does support
# ComplexHalf
if dtype is torch.half and base != (1 + 0j):
will_raise_error = torch.device(device).type == 'cpu' and \
dtype is torch.half and base != (1 + 0j)
if will_raise_error:
with self.assertRaisesRegex(RuntimeError, "not implemented for 'ComplexHalf'"):
self._test_pow(base, first_exp)
self._test_pow(base, second_exp)
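These test changes encode the device split: a Half tensor combined with a complex exponent (or a complex base) promotes to a ComplexHalf computation, which pow now supports on CUDA but still rejects on CPU. A sketch of the asserted behaviour (not part of the diff), assuming a build with this commit:

import torch

m = torch.rand(3, dtype=torch.half)
try:
    m ** (1 + 1j)                 # CPU: ComplexHalf computation dtype, still unsupported
except RuntimeError as e:
    print(e)                      # ... not implemented for 'ComplexHalf'

if torch.cuda.is_available():
    out = m.cuda() ** (1 + 1j)    # CUDA: handled by the new chalf pow kernels
    print(out.dtype)              # torch.complex32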

View File

@@ -48,6 +48,7 @@ from torch.testing._internal.common_device_type import (
OpDTypes,
skipMeta,
)
from torch.utils._pytree import tree_map
import torch._prims as prims
from torch._prims.context import TorchRefsMode
@@ -1106,6 +1107,16 @@ class TestCommon(TestCase):
*transformed_sample.args,
**transformed_sample.kwargs,
)
# Since the range of chalf is much smaller than that of cfloat,
# we get `inf`s easily (e.g. with `pow`, `exp`),
# so we cast the `cfloat` reference back to `chalf`.
expected = tree_map(lambda x: x.to(torch.complex32) if isinstance(
x, torch.Tensor) and x.dtype is torch.complex64 else x, expected)
# `exact_dtype` is False because for ops like `real` and `imag`
# we get different dtypes for `actual` and `expected`:
# `chalf` input -> `half` output
# `cfloat` input -> `float` output
self.assertEqual(actual, expected, exact_dtype=False)
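The cast above keeps the comparison meaningful: `expected` is produced in cfloat and stays finite, while the chalf computation saturates to inf much sooner, so the reference is downcast before assertEqual and both sides overflow identically. A standalone illustration of the same tree_map pattern (not part of the diff), with hypothetical values chosen to overflow chalf:

import torch
from torch.utils._pytree import tree_map

expected = {"out": torch.tensor([7.0e4 + 0j], dtype=torch.complex64)}       # cfloat reference
actual = {"out": torch.tensor([float("inf") + 0j], dtype=torch.complex32)}  # saturated chalf output
expected = tree_map(
    lambda x: x.to(torch.complex32)
    if isinstance(x, torch.Tensor) and x.dtype is torch.complex64
    else x,
    expected,
)
torch.testing.assert_close(actual["out"], expected["out"])  # both sides are now inf+0j in chalf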

View File

@@ -12590,15 +12590,7 @@ op_db: List[OpInfo] = [
supports_forward_ad=True,
supports_fwgrad_bwgrad=True,
check_batched_forward_grad=False,
supports_out=False,
skips=(
# RuntimeError: "where_cpu" not implemented for 'ComplexHalf'
# RuntimeError: "where_cuda" not implemented for 'ComplexHalf'
DecorateInfo(unittest.expectedFailure, 'TestDecomp', 'test_comprehensive', dtypes=(torch.chalf,)),
# RuntimeError: "where_cpu" not implemented for 'ComplexHalf'
# RuntimeError: "where_cuda" not implemented for 'ComplexHalf'
DecorateInfo(unittest.expectedFailure, 'TestDecomp', 'test_quick', dtypes=(torch.chalf,)),
)),
supports_out=False),
OpInfo('masked_scatter',
dtypes=all_types_and_complex_and(torch.bool, torch.half, torch.bfloat16),
sample_inputs_func=sample_inputs_masked_scatter,
@@ -14819,12 +14811,13 @@ op_db: List[OpInfo] = [
reference_inputs_func=reference_inputs_permute),
BinaryUfuncInfo('pow',
dtypes=all_types_and_complex_and(torch.half, torch.bfloat16),
dtypesIfCUDA=all_types_and_complex_and(torch.half, torch.bfloat16, torch.chalf),
ref=np.power,
# Due to AVX2 currently not being fully supported for Float16, log_vml_cpu can't be enabled
# for Float16, causing this test to fail. pow's autograd for Float16 is thus currently
# unsupported on CPU.
backward_dtypes=floating_and_complex_types_and(torch.bfloat16),
backward_dtypesIfCUDA=floating_and_complex_types_and(torch.bfloat16, torch.half),
backward_dtypesIfCUDA=floating_and_complex_types_and(torch.bfloat16, torch.half, torch.chalf),
supports_inplace_autograd=False,
supports_forward_ad=True,
supports_fwgrad_bwgrad=True,
@@ -14848,12 +14841,18 @@ op_db: List[OpInfo] = [
DecorateInfo(unittest.expectedFailure, 'TestBinaryUfuncs', 'test_reference_numerics_large_values',
dtypes=[torch.int16, torch.int32, torch.int64]),
# FIXME Complex values error with: Greatest absolute difference: nan at index
# Ref: https://github.com/pytorch/pytorch/issues/76853
# For `chalf`, the reference computation in `numpy` is done in `cfloat`.
# The `chalf` output saturates to `inf` sooner than the reference due to its smaller range,
# which leads to failure of this test.
DecorateInfo(unittest.skip("Skipped!"), 'TestBinaryUfuncs', 'test_reference_numerics',
dtypes=(torch.complex32,)),
DecorateInfo(unittest.skip("Skipped!"), 'TestBinaryUfuncs', 'test_reference_numerics_small_values',
dtypes=[torch.complex64, torch.complex128]),
dtypes=(torch.complex32, torch.complex64, torch.complex128)),
DecorateInfo(unittest.skip("Skipped!"), 'TestBinaryUfuncs', 'test_reference_numerics_large_values',
dtypes=[torch.complex64, torch.complex128]),
dtypes=(torch.complex32, torch.complex64, torch.complex128)),
DecorateInfo(unittest.skip("Skipped!"), 'TestBinaryUfuncs', 'test_reference_numerics_extremal_values',
dtypes=[torch.complex64, torch.complex128]),
dtypes=(torch.complex32, torch.complex64, torch.complex128)),
)),
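The added chalf skips follow from the same range gap: the numpy reference for chalf is computed in cfloat, where values stay finite long after the chalf kernel has saturated to inf. A small illustration (not part of the diff) with arbitrary values above float16's maximum of 65504:

import numpy as np
import torch

ref = np.power(np.complex64(10 + 0j), np.complex64(5 + 0j))  # (1e5+0j), finite in cfloat
out = torch.tensor(ref, dtype=torch.complex32)                # 1e5 exceeds float16 range -> inf+0j
print(ref, out)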
BinaryUfuncInfo('float_power',
ref=np.float_power,
@@ -15103,7 +15102,8 @@ op_db: List[OpInfo] = [
UnaryUfuncInfo('sgn',
ref=reference_sgn,
dtypes=all_types_and_complex_and(torch.bool, torch.bfloat16, torch.half, torch.chalf),
backward_dtypes=floating_and_complex_types_and(torch.half, torch.bfloat16),
backward_dtypes=floating_and_complex_types_and(torch.bfloat16, torch.half),
backward_dtypesIfCUDA=floating_and_complex_types_and(torch.bfloat16, torch.half, torch.chalf),
supports_forward_ad=True,
supports_fwgrad_bwgrad=True,
supports_sparse=True,
@@ -15360,7 +15360,6 @@ op_db: List[OpInfo] = [
ref=np.tan,
dtypes=all_types_and_complex_and(torch.bool, torch.bfloat16),
dtypesIfCUDA=all_types_and_complex_and(torch.chalf, torch.bool, torch.half, torch.bfloat16),
backward_dtypesIfCUDA=floating_and_complex_types_and(torch.half, torch.bfloat16),
assert_autodiffed=True,
supports_forward_ad=True,
supports_fwgrad_bwgrad=True,
@@ -15673,11 +15672,8 @@ op_db: List[OpInfo] = [
dtypesIfCUDA=all_types_and_complex_and(torch.chalf, torch.bool),
decorators=(precisionOverride({torch.float16: 1e-2,
torch.bfloat16: 1e-2}),),
# TODO: add `torch.chalf` backward dtype support.
# AssertionError: The supported dtypes for angle on device type cuda are incorrect!
# The following dtypes did not work in backward but are listed by the OpInfo: {torch.complex32}.
backward_dtypes=floating_and_complex_types_and(torch.bfloat16, torch.float16),
backward_dtypesIfCUDA=floating_and_complex_types(),
backward_dtypesIfCUDA=floating_and_complex_types_and(torch.chalf),
supports_forward_ad=True,
supports_fwgrad_bwgrad=True,
supports_sparse_csr=True,
@@ -17587,7 +17583,7 @@ op_db: List[OpInfo] = [
DecorateInfo(unittest.expectedFailure, "TestNormalizeOperators", "test_normalize_operator_exhaustive"),
DecorateInfo(unittest.skip("Skipped!"), 'TestJit', 'test_variant_consistency_jit'),
),
dtypes=all_types_and_complex_and(torch.bool, torch.half, torch.bfloat16)),
dtypes=all_types_and_complex_and(torch.bool, torch.half, torch.bfloat16, torch.chalf)),
OpInfo('nonzero',
dtypes=all_types_and_complex_and(torch.bool, torch.bfloat16, torch.float16),
sample_inputs_func=sample_inputs_nonzero,