Complex support for expm1 (#96644)

Fixes #92619

Pull Request resolved: https://github.com/pytorch/pytorch/pull/96644
Approved by: https://github.com/soulitzer
Author: yhl48
Date: 2023-03-24 17:24:50 +00:00
Committed by: PyTorch MergeBot
Parent: 1b8b82f835
Commit: 6fcd671574
14 changed files with 75 additions and 20 deletions

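For context, the user-visible effect of this commit: torch.expm1 (and its alias torch.special.expm1) now accepts complex tensors instead of raising "not supported for complex numbers". A minimal sketch, assuming a build that contains this change:

```python
import torch

# expm1 on complex inputs previously raised an error on both CPU and CUDA.
z = torch.tensor([0.1 + 1.2j, 1e-8 + 1e-8j], dtype=torch.complex128)
out = torch.expm1(z)

# expm1(z) is mathematically exp(z) - 1, but more accurate for inputs close to zero.
print(torch.allclose(out, torch.exp(z) - 1))  # True
```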
View File

@@ -241,7 +241,7 @@ public:
return scaled_values.exp();
}
Vectorized<c10::complex<double>> expm1() const {
AT_ERROR("not supported for complex numbers");
return map(std::expm1);
}
Vectorized<c10::complex<double>> sin() const {
return map(std::sin);

View File

@@ -275,7 +275,7 @@ public:
return scaled_values.exp();
}
Vectorized<c10::complex<float>> expm1() const {
AT_ERROR("not supported for complex numbers");
return map(std::expm1);
}
Vectorized<c10::complex<float>> sin() const {
return map(std::sin);

View File

@@ -456,6 +456,9 @@ class Vectorized<ComplexDbl> {
Vectorized<ComplexDbl> exp2() const {
return map(exp2_impl);
}
Vectorized<ComplexDbl> expm1() const {
return map(std::expm1);
}
Vectorized<ComplexDbl> pow(const Vectorized<ComplexDbl>& exp) const {
__at_align__ ComplexDbl x_tmp[size()];
@@ -498,10 +501,6 @@ class Vectorized<ComplexDbl> {
TORCH_CHECK(false, "not supported for complex numbers");
}
Vectorized<ComplexDbl> expm1() const {
TORCH_CHECK(false, "not supported for complex numbers");
}
Vectorized<ComplexDbl> operator<(const Vectorized<ComplexDbl>& other) const {
TORCH_CHECK(false, "not supported for complex numbers");
}

View File

@@ -535,6 +535,9 @@ class Vectorized<ComplexFlt> {
Vectorized<ComplexFlt> exp2() const {
return map(exp2_impl);
}
Vectorized<ComplexFlt> expm1() const {
return map(std::expm1);
}
Vectorized<ComplexFlt> eq(const Vectorized<ComplexFlt>& other) const {
auto ret = (*this == other);
@@ -575,10 +578,6 @@ class Vectorized<ComplexFlt> {
TORCH_CHECK(false,"not supported for complex numbers");
}
Vectorized<ComplexFlt> expm1() const {
TORCH_CHECK(false,"not supported for complex numbers");
}
Vectorized<ComplexFlt> operator<(const Vectorized<ComplexFlt>& other) const {
TORCH_CHECK(false, "not supported for complex numbers");
}

View File

@@ -304,7 +304,7 @@ public:
return scaled_values.exp();
}
Vectorized<c10::complex<double>> expm1() const {
AT_ERROR("not supported for complex numbers");
return map(std::expm1);
}
Vectorized<c10::complex<double>> sin() const {
return map(std::sin);

View File

@@ -806,7 +806,7 @@ public:
return scaled_values.exp();
}
Vectorized<c10::complex<float>> expm1() const {
AT_ERROR("not supported for complex numbers");
return map(std::expm1);
}
Vectorized<c10::complex<float>> sin() const {
return map(std::sin);

View File

@@ -774,7 +774,7 @@ IMPLEMENT_FLOAT_KERNEL(erfinv)
// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables,modernize-avoid-c-arrays,cppcoreguidelines-avoid-c-arrays)
IMPLEMENT_COMPLEX_KERNEL(exp)
// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables,modernize-avoid-c-arrays,cppcoreguidelines-avoid-c-arrays)
IMPLEMENT_FLOAT_KERNEL(expm1)
IMPLEMENT_COMPLEX_KERNEL(expm1)
// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables,modernize-avoid-c-arrays,cppcoreguidelines-avoid-c-arrays)
IMPLEMENT_FLOAT_KERNEL(floor)
// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables,modernize-avoid-c-arrays,cppcoreguidelines-avoid-c-arrays)

View File

@@ -191,7 +191,6 @@ STD_FUNCTOR(op_name, functor_name); \
OP_CUSTOM_FUNCTOR(function, op_name, functor_name); \
OP(floating_half_bfloat16, erfc, Erfc);
OP(floating_half_bfloat16, expm1, Expm1);
OP(floating_half, lgamma, Lgamma);
OP(floating_half_bfloat16, trunc, Truncf);
OP(floating_half_bfloat16, floor, Floor);
@@ -206,6 +205,7 @@ OP(floating_complex_half_bfloat16, sin, Sin);
OP(floating_complex_half_bfloat16, sinh, Sinh);
OP(floating_complex_half_bfloat16, exp, Exp);
OP(floating_complex_half_bfloat16, expm1, Expm1);
OP(floating_complex_half_bfloat16, tanh, Tanh);
OP(floating_complex_half_bfloat16, log, Log);
OP(floating_complex_half_bfloat16, log10, Log10);

View File

@@ -69,7 +69,7 @@ void exp_kernel_cuda(TensorIteratorBase& iter) {
}
void expm1_kernel_cuda(TensorIteratorBase& iter) {
AT_DISPATCH_FLOATING_TYPES_AND2(
AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES_AND2(
ScalarType::BFloat16, ScalarType::Half,
iter.common_dtype(), "expm1_cuda",
[&]() {

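The dispatch macro swap above is what routes complex64/complex128 CUDA tensors into this kernel. A quick parity check against the CPU path, as an illustrative sketch (assumes a CUDA build containing this commit):

```python
import torch

z = torch.randn(1024, dtype=torch.complex64)
cpu_out = torch.expm1(z)
# With complex dtypes in the dispatch list, this no longer raises on CUDA.
cuda_out = torch.expm1(z.cuda()).cpu()
print(torch.allclose(cpu_out, cuda_out, atol=1e-6))  # True up to float32 rounding
```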
View File

@@ -74,6 +74,41 @@ C10_DEFINE_TEST(TestExponential, EulerFormula) {
}
}
C10_DEFINE_TEST(TestExpm1, Normal) {
// expm1(x) = exp(x) - 1
{
c10::complex<float> x(0.1, 1.2);
c10::complex<float> l1 = std::expm1(x);
c10::complex<float> l2 = std::exp(x) - 1.0f;
C10_ASSERT_NEAR(l1.real(), l2.real(), tol);
C10_ASSERT_NEAR(l1.imag(), l2.imag(), tol);
}
{
c10::complex<double> x(0.1, 1.2);
c10::complex<double> l1 = std::expm1(x);
c10::complex<double> l2 = std::exp(x) - 1.0;
C10_ASSERT_NEAR(l1.real(), l2.real(), tol);
C10_ASSERT_NEAR(l1.imag(), l2.imag(), tol);
}
}
C10_DEFINE_TEST(TestExpm1, Small) {
// expm1(x) = exp(x) - 1
// expm1(x) provides greater precision than exp(x) - 1 for small values of x
{
c10::complex<float> x(1e-30, 1e-30);
c10::complex<float> l1 = std::expm1(x);
C10_ASSERT_NEAR(l1.real(), 1e-30, tol);
C10_ASSERT_NEAR(l1.imag(), 1e-30, tol);
}
{
c10::complex<double> x(1e-100, 1e-100);
c10::complex<double> l1 = std::expm1(x);
C10_ASSERT_NEAR(l1.real(), 1e-100, tol);
C10_ASSERT_NEAR(l1.imag(), 1e-100, tol);
}
}
C10_DEFINE_TEST(TestLog, Definition) {
// log(x) = log(r) + i*theta
{

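A rough Python analogue of the two unit tests above, showing the same behaviour through the public torch API (illustrative only; the values and dtypes are arbitrary choices):

```python
import torch

# Normal values: expm1(z) agrees with exp(z) - 1.
z = torch.tensor(0.1 + 1.2j, dtype=torch.complex64)
print(torch.allclose(torch.expm1(z), torch.exp(z) - 1))  # True

# Small values: expm1 preserves the tiny real part that exp(z) - 1 cancels away.
z = torch.tensor(1e-30 + 1e-30j, dtype=torch.complex64)
print(torch.expm1(z))    # ~(1e-30 + 1e-30j)
print(torch.exp(z) - 1)  # real part lost to cancellation: (0 + 1e-30j)
```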
View File

@@ -318,6 +318,23 @@ C10_HOST_DEVICE inline c10::complex<T> log1p(const c10::complex<T>& z) {
}
}
template <typename T>
C10_HOST_DEVICE inline c10::complex<T> expm1(const c10::complex<T>& z) {
// expm1(z) = exp(z) - 1
// Define z = x + i * y
// f = e ^ (x + i * y) - 1
// = e ^ x * e ^ (i * y) - 1
// = (e ^ x * cos(y) - 1) + i * (e ^ x * sin(y))
// = (e ^ x - 1) * cos(y) - (1 - cos(y)) + i * e ^ x * sin(y)
// = expm1(x) * cos(y) - 2 * sin(y / 2) ^ 2 + i * e ^ x * sin(y)
T x = z.real();
T y = z.imag();
T a = std::sin(y / 2);
T er = std::expm1(x) * std::cos(y) - T(2) * a * a;
T ei = std::exp(x) * std::sin(y);
return {er, ei};
}
} // namespace c10_complex_math
using c10_complex_math::acos;
@@ -329,6 +346,7 @@ using c10_complex_math::atanh;
using c10_complex_math::cos;
using c10_complex_math::cosh;
using c10_complex_math::exp;
using c10_complex_math::expm1;
using c10_complex_math::log;
using c10_complex_math::log10;
using c10_complex_math::log1p;
@@ -351,6 +369,7 @@ using c10_complex_math::atanh;
using c10_complex_math::cos;
using c10_complex_math::cosh;
using c10_complex_math::exp;
using c10_complex_math::expm1;
using c10_complex_math::log;
using c10_complex_math::log10;
using c10_complex_math::log1p;

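The derivation in the new expm1 above avoids the cancellation in e^x * cos(y) - 1 by rewriting it with expm1(x) and the identity 1 - cos(y) = 2 * sin(y/2)^2. A Python transcription of the same formula, for illustration only:

```python
import cmath
import math

def complex_expm1(z: complex) -> complex:
    # expm1(x + iy) = expm1(x)*cos(y) - 2*sin(y/2)**2 + i*exp(x)*sin(y)
    x, y = z.real, z.imag
    a = math.sin(y / 2)
    return complex(math.expm1(x) * math.cos(y) - 2 * a * a,
                   math.exp(x) * math.sin(y))

# Agrees with the naive exp(z) - 1 for moderate arguments ...
z = 0.1 + 1.2j
print(abs(complex_expm1(z) - (cmath.exp(z) - 1)))  # ~1e-16

# ... and avoids the cancellation in the real part for tiny arguments.
z = 1e-20 + 1e-20j
print(complex_expm1(z))   # ~(1e-20 + 1e-20j)
print(cmath.exp(z) - 1)   # 1e-20j (real part rounds to 0.0)
```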
View File

@@ -621,7 +621,7 @@
result: auto_element_wise
- name: expm1(Tensor self) -> Tensor
self: grad * (result + 1)
self: grad * (result.conj() + 1)
result: auto_element_wise
# TODO: this derivative is not SymInt safe, need sum_to support

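For reference on the derivative change: d/dz expm1(z) = exp(z) = result + 1, and for complex inputs autograd multiplies the incoming gradient by the conjugate of the holomorphic derivative, hence result.conj() + 1. An illustrative sanity check (not part of the diff):

```python
import torch

z = torch.randn(4, dtype=torch.complex128, requires_grad=True)

# gradcheck exercises the complex backward (enabled by the
# GRADIENT_IMPLEMENTED_FOR_COMPLEX entry in the next file).
print(torch.autograd.gradcheck(torch.expm1, (z,)))  # True

# Manual form of the new formula: grad_input = grad_output * (result.conj() + 1).
out = torch.expm1(z)
grad_out = torch.randn_like(out)
(grad_in,) = torch.autograd.grad(out, z, grad_out)
print(torch.allclose(grad_in, grad_out * (out.conj() + 1)))  # True
```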
View File

@@ -264,6 +264,7 @@ GRADIENT_IMPLEMENTED_FOR_COMPLEX = {
"fill_",
"exp",
"exp2",
"expm1",
"nonzero",
"mean",
"std_mean",

View File

@@ -8324,8 +8324,8 @@ foreach_unary_op_db: List[OpInfo] = [
ForeachFuncInfo(
'expm1',
dtypes=floating_types_and(torch.bfloat16),
dtypesIfCUDA=floating_types_and(torch.half, torch.bfloat16),
dtypes=floating_and_complex_types_and(torch.bfloat16),
dtypesIfCUDA=floating_and_complex_types_and(torch.half, torch.bfloat16),
sample_inputs_func=foreach_inputs_sample_func(1, False, False),
supports_autograd=True,
),
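The ForeachFuncInfo change above means the fused foreach variant gains the same complex coverage. A small illustrative check (torch._foreach_expm1 applies expm1 to each tensor in a list):

```python
import torch

tensors = [torch.randn(3, dtype=torch.complex64) for _ in range(4)]
for out, t in zip(torch._foreach_expm1(tensors), tensors):
    print(torch.allclose(out, torch.expm1(t)))  # True for each tensor
```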
@@ -14456,8 +14456,8 @@ op_db: List[OpInfo] = [
UnaryUfuncInfo('expm1',
aliases=('special.expm1', ),
ref=np_unary_ufunc_integer_promotion_wrapper(np.expm1),
dtypes=all_types_and(torch.bool, torch.bfloat16),
dtypesIfCUDA=all_types_and(torch.bool, torch.half, torch.bfloat16),
dtypes=all_types_and_complex_and(torch.bool, torch.bfloat16),
dtypesIfCUDA=all_types_and_complex_and(torch.bool, torch.half, torch.bfloat16),
supports_forward_ad=True,
supports_fwgrad_bwgrad=True,
supports_sparse=True,
@@ -14472,6 +14472,8 @@ op_db: List[OpInfo] = [
device_type='cpu', dtypes=[torch.bfloat16]),
DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', 'test_reference_numerics_large',
device_type='cpu', dtypes=[torch.bfloat16]),
DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', 'test_reference_numerics_large',
device_type='cuda', dtypes=[torch.complex128]),
DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', 'test_reference_numerics_small',
device_type='cpu', dtypes=[torch.bfloat16]),
DecorateInfo(unittest.skip("Skipped! sparse backward not supported"),