[reland][chalf] where(cpu and cuda), pow(cuda) (#78665)
Reland of: https://github.com/pytorch/pytorch/pull/77640
Ref: #74537
Pull Request resolved: https://github.com/pytorch/pytorch/pull/78665
Approved by: https://github.com/ngimel
Committed by: PyTorch MergeBot
Parent: d578197747
Commit: 849b08f14b
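
In user-facing terms, the commit makes torch.where accept torch.complex32 (chalf) inputs on both CPU and CUDA, and extends torch.pow to chalf on CUDA. The following is a minimal usage sketch, not part of the commit itself; it assumes a build that contains this change and, for the pow lines, an available CUDA device.

    import torch

    # where: complex32 (chalf) is now dispatched on both CPU and CUDA.
    cond = torch.tensor([True, False, True])
    a = torch.tensor([1 + 1j, 2 + 2j, 3 + 3j], dtype=torch.chalf)  # CPU tensor
    b = torch.zeros(3, dtype=torch.chalf)
    print(torch.where(cond, a, b))  # complex32 result

    # pow: chalf is now dispatched on CUDA (CPU still raises for chalf).
    if torch.cuda.is_available():
        x = a.cuda()
        print(torch.pow(x, 2))  # tensor ** real scalar
        print(torch.pow(x, x))  # tensor ** tensor
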
@@ -205,7 +205,7 @@ static void aminmax_kernel(
}

static void where_kernel_impl(TensorIterator &iter) {
-  AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND3(at::ScalarType::Half, at::ScalarType::BFloat16, at::ScalarType::Bool,
+  AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND4(kComplexHalf, kHalf, kBFloat16, kBool,
    iter.dtype(), "where_cpu", [&] {
      cpu_kernel(
        iter,
@@ -2,6 +2,7 @@
#include <ATen/Context.h>
#include <ATen/Dispatch.h>
#include <ATen/native/cuda/Loops.cuh>
#include <ATen/native/cuda/JitLoops.cuh>
#include <ATen/native/DispatchStub.h>
#include <ATen/native/TensorIterator.h>
#include <ATen/native/Pow.h>
@@ -83,9 +84,69 @@ void pow_scalar_tensor_impl(TensorIteratorBase& iter, c10::complex<value_t> base
  });
}

/* complex<Half> support impl */
const char pow_scalar_base_name[] = "pow_scalar_base_kernel";
template <>
void pow_scalar_tensor_impl(TensorIteratorBase& iter, c10::complex<at::Half> base) {
  using scalar_t = c10::complex<at::Half>;
  using opmath_t = at::opmath_type<scalar_t>;
  // For complex, thrust::pow uses the identity
  // pow(a, b) = exp(log(a) * b)
  const auto fct = std::log(opmath_t{base});
#if AT_USE_JITERATOR()
  static const auto pow_kernel_string =
      jiterator_stringify(template <typename T> T pow_scalar_base_kernel(T exp, T fct) {
        return std::exp(fct * exp);
      });
  jitted_gpu_kernel<pow_scalar_base_name, scalar_t, scalar_t, 1>(
      iter,
      pow_kernel_string,
      /*scalar_pos=*/at::cuda::jit::BinaryFuncVariant::NoScalar,
      /*scalar_val=*/0,
      /*extra_args=*/std::make_tuple(fct));
#else
  gpu_kernel(iter, [=] GPU_LAMBDA(scalar_t exp) -> scalar_t {
    return std::exp(fct * opmath_t{exp});
  });
#endif
}

namespace {

#if AT_USE_JITERATOR()
/* complex<Half> support impl */
const char pow_name[] = "pow_kernel";
static const auto pow_kernel_string =
    jiterator_stringify(template <typename T> T pow_kernel(T base, T exp) {
      return std::pow(base, exp);
    });
#endif

/* complex<Half> support impl */
void pow_chalf_tensor_scalar_impl(TensorIteratorBase& iter, const Scalar& exp_scalar) {
  using scalar_t = c10::complex<at::Half>;
  using opmath_t = at::opmath_type<scalar_t>;
  auto exp = exp_scalar.to<opmath_t>();
#if AT_USE_JITERATOR()
  jitted_gpu_kernel<pow_name, scalar_t, scalar_t, 1>(
      iter,
      pow_kernel_string,
      /*scalar_pos=*/at::cuda::jit::BinaryFuncVariant::NoScalar,
      /*scalar_val=*/0,
      /*extra_args=*/std::make_tuple(exp));
#else
  gpu_kernel(iter, [=] GPU_LAMBDA(scalar_t base) -> scalar_t {
    return std::pow(opmath_t{base}, exp);
  });
#endif
}

} // anonymous namespace

void pow_tensor_tensor_kernel(TensorIteratorBase& iter) {
  AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND2(
      kHalf, kBFloat16, iter.common_dtype(), "pow_cuda", [&] {
  auto common_dtype = iter.common_dtype();
  if (common_dtype == kComplexHalf) {
    using scalar_t = c10::complex<at::Half>;
    if (iter.is_cpu_scalar(1)) {
      const auto base = iter.scalar_value<scalar_t>(1);
      iter.remove_operand(1);
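
The comment inside the new pow_scalar_tensor_impl specialization above relies on the complex identity pow(a, b) = exp(log(a) * b); both the jiterator string and the non-jiterator fallback compute exactly that. A quick standalone numerical check of the identity (illustrative values, plain Python cmath, nothing PyTorch-specific):

    import cmath

    a = 1.5 + 0.5j   # arbitrary example base
    b = 2.0 - 1.0j   # arbitrary example exponent

    direct = a ** b
    via_identity = cmath.exp(cmath.log(a) * b)

    # Both evaluate the principal branch, so they agree up to rounding.
    assert abs(direct - via_identity) < 1e-12
    print(direct, via_identity)
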
@@ -93,13 +154,38 @@ void pow_tensor_tensor_kernel(TensorIteratorBase& iter) {
    } else if (iter.is_cpu_scalar(2)) {
      const auto exp = iter.scalar_value<scalar_t>(2);
      iter.remove_operand(2);
-     pow_tensor_scalar_kernel(iter, exp);
+     pow_chalf_tensor_scalar_impl(iter, exp);
    } else {
      gpu_kernel(iter, [=]GPU_LAMBDA(scalar_t base, scalar_t exp) -> scalar_t {
        return pow_(base, exp);
      });
      using opmath_t = at::opmath_type<scalar_t>;
      TORCH_INTERNAL_ASSERT(!iter.is_cpu_scalar(1) && !iter.is_cpu_scalar(2));
#if AT_USE_JITERATOR()
      jitted_gpu_kernel<pow_name, scalar_t, scalar_t, 2>(
          iter, pow_kernel_string);
#else
      gpu_kernel(iter, [=] GPU_LAMBDA(scalar_t base, scalar_t exp) -> scalar_t {
        using opmath_t = at::opmath_type<scalar_t>;
        return pow_(opmath_t{base}, opmath_t{exp});
      });
#endif
    }
    });
  } else {
    AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND2(
        kHalf, kBFloat16, iter.common_dtype(), "pow_cuda", [&] {
          if (iter.is_cpu_scalar(1)) {
            const auto base = iter.scalar_value<scalar_t>(1);
            iter.remove_operand(1);
            pow_scalar_tensor_impl(iter, base);
          } else if (iter.is_cpu_scalar(2)) {
            const auto exp = iter.scalar_value<scalar_t>(2);
            iter.remove_operand(2);
            pow_tensor_scalar_kernel(iter, exp);
          } else {
            gpu_kernel(iter, [=]GPU_LAMBDA(scalar_t base, scalar_t exp) -> scalar_t {
              return pow_(base, exp);
            });
          }
        });
  }
}
@@ -140,6 +226,11 @@ void pow_tensor_scalar_kernel(TensorIteratorBase& iter, const Scalar& exp_scalar
    }
  }
  if (isComplexType(iter.common_dtype()) || exp_scalar.isComplex()) {
    if (iter.common_dtype() == kComplexHalf) {
      using scalar_t = c10::complex<at::Half>;
      pow_chalf_tensor_scalar_impl(iter, exp_scalar);
      return;
    }
    AT_DISPATCH_COMPLEX_TYPES(iter.common_dtype(), "pow_cuda", [&]() {
      const auto exp = exp_scalar.to<scalar_t>();
      gpu_kernel(iter, [=]GPU_LAMBDA(scalar_t base) -> scalar_t {
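
The hunk above routes the chalf and complex-scalar cases of the tensor-with-scalar-exponent path through the new pow_chalf_tensor_scalar_impl. A minimal sketch of what that path looks like from Python, assuming a CUDA device and a build with this change (values are arbitrary examples):

    import torch

    if torch.cuda.is_available():
        x = torch.tensor([1 + 1j, 2 - 1j, 0.5 + 0.25j],
                         dtype=torch.chalf, device="cuda")
        print(x ** 2)         # real scalar exponent, chalf tensor base
        print(x ** (1 + 1j))  # complex scalar exponent, same scalar path
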
@@ -12,7 +12,7 @@ namespace at { namespace native {
namespace {

void where_kernel_impl(TensorIterator &iter) {
-  AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND3(kHalf, kBFloat16, kBool, iter.dtype(), "where_cuda", [&] {
+  AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND4(kComplexHalf, kHalf, kBFloat16, kBool, iter.dtype(), "where_cuda", [&] {
    gpu_kernel(
      iter,
      [=] GPU_LAMBDA (bool cond_val, scalar_t self_val, scalar_t other_val) -> scalar_t {
@@ -36,7 +36,7 @@ class C10_API Scalar {
#define DEFINE_IMPLICIT_CTOR(type, name) \
  Scalar(type vv) : Scalar(vv, true) {}

-  AT_FORALL_SCALAR_TYPES_AND2(Half, BFloat16, DEFINE_IMPLICIT_CTOR)
+  AT_FORALL_SCALAR_TYPES_AND3(Half, BFloat16, ComplexHalf, DEFINE_IMPLICIT_CTOR)
  AT_FORALL_COMPLEX_TYPES(DEFINE_IMPLICIT_CTOR)

#undef DEFINE_IMPLICIT_CTOR
@@ -1456,13 +1456,15 @@ class TestBinaryUfuncs(TestCase):
            self._do_pow_for_exponents(m1, exponents + complex_exponents, pow, 10e-4)
        else:
            self._do_pow_for_exponents(m1, exponents, math.pow, None)
            if dtype != torch.half:
                self._do_pow_for_exponents(m1, complex_exponents, pow, 10e-4)
            else:
            will_raise_error = dtype is torch.half and torch.device(device).type == 'cpu'
            if will_raise_error:
                # On CPU,
                # Half Tensor with complex exponents leads to computation dtype
                # of ComplexHalf for which this ops is not supported yet
                with self.assertRaisesRegex(RuntimeError, "not implemented for 'ComplexHalf'"):
                    self._do_pow_for_exponents(m1, complex_exponents, pow, 10e-4)
            else:
                self._do_pow_for_exponents(m1, complex_exponents, pow, 10e-4)

        # base - number, exponent - tensor
        # contiguous
@@ -1751,11 +1753,14 @@ class TestBinaryUfuncs(TestCase):
        first_exp[0] = first_exp[10] = first_exp[20] = 0
        second_exp[0] = second_exp[10] = second_exp[20] = 0
        for base in complexes:
            # On CPU,
            # Half Tensor with complex base leads to computation dtype
            # of ComplexHalf for which this ops is not supported yet
            # NOTE: pow has fast-path when base is 1 which supports
            # ComplexHalf
            if dtype is torch.half and base != (1 + 0j):
            will_raise_error = torch.device(device).type == 'cpu' and \
                dtype is torch.half and base != (1 + 0j)
            if will_raise_error:
                with self.assertRaisesRegex(RuntimeError, "not implemented for 'ComplexHalf'"):
                    self._test_pow(base, first_exp)
                    self._test_pow(base, second_exp)
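
The two test changes above pin down the behavior split this commit leaves in place: on CPU a torch.half tensor combined with a complex base or exponent promotes to ComplexHalf, for which pow is still unimplemented, while on CUDA the new kernels handle it. A small illustrative sketch (the error text is the one the tests assert on; the CUDA branch assumes a device is available):

    import torch

    m = torch.rand(3, dtype=torch.half)

    # CPU: Half tensor ** complex exponent promotes to ComplexHalf -> still unsupported.
    try:
        m ** (1 + 1j)
    except RuntimeError as e:
        print(e)  # "... not implemented for 'ComplexHalf'"

    # CUDA: the same expression now goes through the new chalf pow kernel.
    if torch.cuda.is_available():
        print(m.cuda() ** (1 + 1j))  # complex32 result
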
@@ -48,6 +48,7 @@ from torch.testing._internal.common_device_type import (
    OpDTypes,
    skipMeta,
)
from torch.utils._pytree import tree_map
import torch._prims as prims
from torch._prims.context import TorchRefsMode

@@ -1106,6 +1107,16 @@ class TestCommon(TestCase):
                *transformed_sample.args,
                **transformed_sample.kwargs,
            )
            # Since range of chalf is much less compared to cfloat,
            # we get `inf`s easily (eg. with `pow`, `exp`),
            # so we cast `cfloat` back to `chalf`.
            expected = tree_map(lambda x: x.to(torch.complex32) if isinstance(
                x, torch.Tensor) and x.dtype is torch.complex64 else x, expected)

            # `exact_dtype` is False because for ops like real, imag
            # we get different dtypes for `actual` and `expected`
            # `chalf` input -> `half` output
            # `cfloat` input -> `float` output
            self.assertEqual(actual, expected, exact_dtype=False)
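
The cast back to chalf above exists because complex32 stores each component as float16, whose largest finite value is 65504, so pow and exp saturate to inf much sooner than the cfloat reference. A short sketch of the effect (the CUDA branch assumes a device with the new chalf pow kernel):

    import torch

    # complex32 components are float16; the largest finite magnitude is 65504.
    print(torch.finfo(torch.float16).max)  # 65504.0

    if torch.cuda.is_available():
        x = torch.tensor([2 + 0j], dtype=torch.chalf, device="cuda")
        print(x ** 17)                   # 2**17 = 131072 > 65504 -> inf in chalf
        print(x.to(torch.cfloat) ** 17)  # 131072+0j in cfloat
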
@@ -12590,15 +12590,7 @@ op_db: List[OpInfo] = [
           supports_forward_ad=True,
           supports_fwgrad_bwgrad=True,
           check_batched_forward_grad=False,
           supports_out=False,
           skips=(
               # RuntimeError: "where_cpu" not implemented for 'ComplexHalf'
               # RuntimeError: "where_cuda" not implemented for 'ComplexHalf'
               DecorateInfo(unittest.expectedFailure, 'TestDecomp', 'test_comprehensive', dtypes=(torch.chalf,)),
               # RuntimeError: "where_cpu" not implemented for 'ComplexHalf'
               # RuntimeError: "where_cuda" not implemented for 'ComplexHalf'
               DecorateInfo(unittest.expectedFailure, 'TestDecomp', 'test_quick', dtypes=(torch.chalf,)),
           )),
           supports_out=False),
    OpInfo('masked_scatter',
           dtypes=all_types_and_complex_and(torch.bool, torch.half, torch.bfloat16),
           sample_inputs_func=sample_inputs_masked_scatter,
@@ -14819,12 +14811,13 @@ op_db: List[OpInfo] = [
           reference_inputs_func=reference_inputs_permute),
    BinaryUfuncInfo('pow',
                    dtypes=all_types_and_complex_and(torch.half, torch.bfloat16),
                    dtypesIfCUDA=all_types_and_complex_and(torch.half, torch.bfloat16, torch.chalf),
                    ref=np.power,
                    # Due to AVX2 curently not being fully supported for Float16, log_vml_cpu can't be enabled
                    # for Float16, causing this test to fail. pow's autograd for Float16 is thus currently
                    # unsupported on CPU.
                    backward_dtypes=floating_and_complex_types_and(torch.bfloat16),
-                   backward_dtypesIfCUDA=floating_and_complex_types_and(torch.bfloat16, torch.half),
+                   backward_dtypesIfCUDA=floating_and_complex_types_and(torch.bfloat16, torch.half, torch.chalf),
                    supports_inplace_autograd=False,
                    supports_forward_ad=True,
                    supports_fwgrad_bwgrad=True,
@@ -14848,12 +14841,18 @@ op_db: List[OpInfo] = [
                        DecorateInfo(unittest.expectedFailure, 'TestBinaryUfuncs', 'test_reference_numerics_large_values',
                                     dtypes=[torch.int16, torch.int32, torch.int64]),
                        # FIXME Complex values error with: Greatest absolute difference: nan at index
                        # Ref: https://github.com/pytorch/pytorch/issues/76853
                        # For `chalf`, reference computation in `numpy` is computed in `cfloat`.
                        # Output of `chalf` saturates to `inf` quicker than reference due to its small range
                        # which leads to failure of this test.
                        DecorateInfo(unittest.skip("Skipped!"), 'TestBinaryUfuncs', 'test_reference_numerics',
                                     dtypes=(torch.complex32,)),
                        DecorateInfo(unittest.skip("Skipped!"), 'TestBinaryUfuncs', 'test_reference_numerics_small_values',
-                                    dtypes=[torch.complex64, torch.complex128]),
+                                    dtypes=(torch.complex32, torch.complex64, torch.complex128)),
                        DecorateInfo(unittest.skip("Skipped!"), 'TestBinaryUfuncs', 'test_reference_numerics_large_values',
-                                    dtypes=[torch.complex64, torch.complex128]),
+                                    dtypes=(torch.complex32, torch.complex64, torch.complex128)),
                        DecorateInfo(unittest.skip("Skipped!"), 'TestBinaryUfuncs', 'test_reference_numerics_extremal_values',
-                                    dtypes=[torch.complex64, torch.complex128]),
+                                    dtypes=(torch.complex32, torch.complex64, torch.complex128)),
                    )),
    BinaryUfuncInfo('float_power',
                    ref=np.float_power,
@@ -15103,7 +15102,8 @@ op_db: List[OpInfo] = [
    UnaryUfuncInfo('sgn',
                   ref=reference_sgn,
                   dtypes=all_types_and_complex_and(torch.bool, torch.bfloat16, torch.half, torch.chalf),
-                  backward_dtypes=floating_and_complex_types_and(torch.half, torch.bfloat16),
+                  backward_dtypes=floating_and_complex_types_and(torch.bfloat16, torch.half),
                   backward_dtypesIfCUDA=floating_and_complex_types_and(torch.bfloat16, torch.half, torch.chalf),
                   supports_forward_ad=True,
                   supports_fwgrad_bwgrad=True,
                   supports_sparse=True,
@@ -15360,7 +15360,6 @@ op_db: List[OpInfo] = [
                   ref=np.tan,
                   dtypes=all_types_and_complex_and(torch.bool, torch.bfloat16),
                   dtypesIfCUDA=all_types_and_complex_and(torch.chalf, torch.bool, torch.half, torch.bfloat16),
                   backward_dtypesIfCUDA=floating_and_complex_types_and(torch.half, torch.bfloat16),
                   assert_autodiffed=True,
                   supports_forward_ad=True,
                   supports_fwgrad_bwgrad=True,
@@ -15673,11 +15672,8 @@ op_db: List[OpInfo] = [
                   dtypesIfCUDA=all_types_and_complex_and(torch.chalf, torch.bool),
                   decorators=(precisionOverride({torch.float16: 1e-2,
                                                  torch.bfloat16: 1e-2}),),
                   # TODO: add `torch.chalf` backward dtype support.
                   # AssertionError: The supported dtypes for angle on device type cuda are incorrect!
                   # The following dtypes did not work in backward but are listed by the OpInfo: {torch.complex32}.
                   backward_dtypes=floating_and_complex_types_and(torch.bfloat16, torch.float16),
-                  backward_dtypesIfCUDA=floating_and_complex_types(),
+                  backward_dtypesIfCUDA=floating_and_complex_types_and(torch.chalf),
                   supports_forward_ad=True,
                   supports_fwgrad_bwgrad=True,
                   supports_sparse_csr=True,
@@ -17587,7 +17583,7 @@ op_db: List[OpInfo] = [
               DecorateInfo(unittest.expectedFailure, "TestNormalizeOperators", "test_normalize_operator_exhaustive"),
               DecorateInfo(unittest.skip("Skipped!"), 'TestJit', 'test_variant_consistency_jit'),
           ),
-          dtypes=all_types_and_complex_and(torch.bool, torch.half, torch.bfloat16)),
+          dtypes=all_types_and_complex_and(torch.bool, torch.half, torch.bfloat16, torch.chalf)),
    OpInfo('nonzero',
           dtypes=all_types_and_complex_and(torch.bool, torch.bfloat16, torch.float16),
           sample_inputs_func=sample_inputs_nonzero,