Complex support for expm1 (#96644)
Fixes #92619

Pull Request resolved: https://github.com/pytorch/pytorch/pull/96644
Approved by: https://github.com/soulitzer
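A quick sketch of the user-visible effect (assumes a build that contains this change): torch.expm1 now accepts complex tensors and evaluates exp(z) - 1 without the cancellation the naive formula suffers for small |z|.

import torch

z = torch.tensor([0.1 + 1.2j, 1e-8 + 1e-8j], dtype=torch.complex128)
print(torch.expm1(z))    # fused evaluation, full precision
print(torch.exp(z) - 1)  # naive form; the tiny entry loses about half its digits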
@@ -241,7 +241,7 @@ public:
     return scaled_values.exp();
   }
   Vectorized<c10::complex<double>> expm1() const {
-    AT_ERROR("not supported for complex numbers");
+    return map(std::expm1);
   }
   Vectorized<c10::complex<double>> sin() const {
     return map(std::sin);
@@ -275,7 +275,7 @@ public:
     return scaled_values.exp();
   }
   Vectorized<c10::complex<float>> expm1() const {
-    AT_ERROR("not supported for complex numbers");
+    return map(std::expm1);
  }
  Vectorized<c10::complex<float>> sin() const {
    return map(std::sin);
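A note on the two hunks above: there is no true SIMD implementation of complex expm1, so expm1() now delegates to map(std::expm1), which applies the scalar routine to each lane of the vector and repacks the results. Roughly, in Python terms (illustrative only; vec_map is a made-up stand-in for the C++ Vectorized fallback):

import cmath

def vec_map(f, lanes):
    # Apply a scalar function to every lane, as Vectorized<T>::map does.
    return [f(z) for z in lanes]

# Python's cmath has no complex expm1, so an exp(z) - 1 lambda stands in.
print(vec_map(lambda z: cmath.exp(z) - 1, [0.1 + 1.2j, 2.0 - 1.0j]))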
@@ -456,6 +456,9 @@ class Vectorized<ComplexDbl> {
   Vectorized<ComplexDbl> exp2() const {
     return map(exp2_impl);
   }
+  Vectorized<ComplexDbl> expm1() const {
+    return map(std::expm1);
+  }

   Vectorized<ComplexDbl> pow(const Vectorized<ComplexDbl>& exp) const {
     __at_align__ ComplexDbl x_tmp[size()];
@@ -498,10 +501,6 @@ class Vectorized<ComplexDbl> {
     TORCH_CHECK(false, "not supported for complex numbers");
   }

-  Vectorized<ComplexDbl> expm1() const {
-    TORCH_CHECK(false, "not supported for complex numbers");
-  }
-
   Vectorized<ComplexDbl> operator<(const Vectorized<ComplexDbl>& other) const {
     TORCH_CHECK(false, "not supported for complex numbers");
   }
@@ -535,6 +535,9 @@ class Vectorized<ComplexFlt> {
   Vectorized<ComplexFlt> exp2() const {
     return map(exp2_impl);
   }
+  Vectorized<ComplexFlt> expm1() const {
+    return map(std::expm1);
+  }

   Vectorized<ComplexFlt> eq(const Vectorized<ComplexFlt>& other) const {
     auto ret = (*this == other);
@@ -575,10 +578,6 @@ class Vectorized<ComplexFlt> {
     TORCH_CHECK(false,"not supported for complex numbers");
   }

-  Vectorized<ComplexFlt> expm1() const {
-    TORCH_CHECK(false,"not supported for complex numbers");
-  }
-
   Vectorized<ComplexFlt> operator<(const Vectorized<ComplexFlt>& other) const {
     TORCH_CHECK(false, "not supported for complex numbers");
   }
@@ -304,7 +304,7 @@ public:
     return scaled_values.exp();
   }
   Vectorized<c10::complex<double>> expm1() const {
-    AT_ERROR("not supported for complex numbers");
+    return map(std::expm1);
  }
  Vectorized<c10::complex<double>> sin() const {
    return map(std::sin);
@@ -806,7 +806,7 @@ public:
     return scaled_values.exp();
   }
   Vectorized<c10::complex<float>> expm1() const {
-    AT_ERROR("not supported for complex numbers");
+    return map(std::expm1);
  }
  Vectorized<c10::complex<float>> sin() const {
    return map(std::sin);
@@ -774,7 +774,7 @@ IMPLEMENT_FLOAT_KERNEL(erfinv)
 // NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables,modernize-avoid-c-arrays,cppcoreguidelines-avoid-c-arrays)
 IMPLEMENT_COMPLEX_KERNEL(exp)
 // NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables,modernize-avoid-c-arrays,cppcoreguidelines-avoid-c-arrays)
-IMPLEMENT_FLOAT_KERNEL(expm1)
+IMPLEMENT_COMPLEX_KERNEL(expm1)
 // NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables,modernize-avoid-c-arrays,cppcoreguidelines-avoid-c-arrays)
 IMPLEMENT_FLOAT_KERNEL(floor)
 // NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables,modernize-avoid-c-arrays,cppcoreguidelines-avoid-c-arrays)
@@ -191,7 +191,6 @@ STD_FUNCTOR(op_name, functor_name); \
 OP_CUSTOM_FUNCTOR(function, op_name, functor_name); \

 OP(floating_half_bfloat16, erfc, Erfc);
-OP(floating_half_bfloat16, expm1, Expm1);
 OP(floating_half, lgamma, Lgamma);
 OP(floating_half_bfloat16, trunc, Truncf);
 OP(floating_half_bfloat16, floor, Floor);
@@ -206,6 +205,7 @@ OP(floating_complex_half_bfloat16, sin, Sin);
 OP(floating_complex_half_bfloat16, sinh, Sinh);

 OP(floating_complex_half_bfloat16, exp, Exp);
+OP(floating_complex_half_bfloat16, expm1, Expm1);
 OP(floating_complex_half_bfloat16, tanh, Tanh);
 OP(floating_complex_half_bfloat16, log, Log);
 OP(floating_complex_half_bfloat16, log10, Log10);
@@ -69,7 +69,7 @@ void exp_kernel_cuda(TensorIteratorBase& iter) {
 }

 void expm1_kernel_cuda(TensorIteratorBase& iter) {
-  AT_DISPATCH_FLOATING_TYPES_AND2(
+  AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES_AND2(
       ScalarType::BFloat16, ScalarType::Half,
       iter.common_dtype(), "expm1_cuda",
       [&]() {
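Swapping the dispatch macro is what instantiates the CUDA kernel for complex64/complex128 in addition to the floating types plus BFloat16 and Half. A quick behavioral check, assuming a CUDA build with this change:

import torch

if torch.cuda.is_available():
    z = torch.tensor([0.1 + 1.2j], dtype=torch.complex64, device="cuda")
    print(torch.expm1(z))  # dispatches to expm1_cuda instead of raising a dtype error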
@@ -74,6 +74,41 @@ C10_DEFINE_TEST(TestExponential, EulerFormula) {
   }
 }

+C10_DEFINE_TEST(TestExpm1, Normal) {
+  // expm1(x) = exp(x) - 1
+  {
+    c10::complex<float> x(0.1, 1.2);
+    c10::complex<float> l1 = std::expm1(x);
+    c10::complex<float> l2 = std::exp(x) - 1.0f;
+    C10_ASSERT_NEAR(l1.real(), l2.real(), tol);
+    C10_ASSERT_NEAR(l1.imag(), l2.imag(), tol);
+  }
+  {
+    c10::complex<double> x(0.1, 1.2);
+    c10::complex<double> l1 = std::expm1(x);
+    c10::complex<double> l2 = std::exp(x) - 1.0;
+    C10_ASSERT_NEAR(l1.real(), l2.real(), tol);
+    C10_ASSERT_NEAR(l1.imag(), l2.imag(), tol);
+  }
+}
+
+C10_DEFINE_TEST(TestExpm1, Small) {
+  // expm1(x) = exp(x) - 1
+  // expm1(x) provides greater precision than exp(x) - 1 for small values of x
+  {
+    c10::complex<float> x(1e-30, 1e-30);
+    c10::complex<float> l1 = std::expm1(x);
+    C10_ASSERT_NEAR(l1.real(), 1e-30, tol);
+    C10_ASSERT_NEAR(l1.imag(), 1e-30, tol);
+  }
+  {
+    c10::complex<double> x(1e-100, 1e-100);
+    c10::complex<double> l1 = std::expm1(x);
+    C10_ASSERT_NEAR(l1.real(), 1e-100, tol);
+    C10_ASSERT_NEAR(l1.imag(), 1e-100, tol);
+  }
+}
+
 C10_DEFINE_TEST(TestLog, Definition) {
   // log(x) = log(r) + i*theta
   {
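The Small test pins down the motivation for a dedicated expm1: when |z| is tiny, adding and then subtracting 1 cancels the real part to zero, while expm1 retains the leading term. The scalar analogue is easy to reproduce:

import cmath
import math

z = 1e-30 + 1e-30j
print(cmath.exp(z) - 1)   # 1e-30j: the real part cancelled to 0.0
print(math.expm1(1e-30))  # 1e-30: expm1 keeps the leading term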
@@ -318,6 +318,23 @@ C10_HOST_DEVICE inline c10::complex<T> log1p(const c10::complex<T>& z) {
   }
 }

+template <typename T>
+C10_HOST_DEVICE inline c10::complex<T> expm1(const c10::complex<T>& z) {
+  // expm1(z) = exp(z) - 1
+  // Define z = x + i * y
+  // f = e ^ (x + i * y) - 1
+  //   = e ^ x * e ^ (i * y) - 1
+  //   = (e ^ x * cos(y) - 1) + i * (e ^ x * sin(y))
+  //   = (e ^ x - 1) * cos(y) - (1 - cos(y)) + i * e ^ x * sin(y)
+  //   = expm1(x) * cos(y) - 2 * sin(y / 2) ^ 2 + i * e ^ x * sin(y)
+  T x = z.real();
+  T y = z.imag();
+  T a = std::sin(y / 2);
+  T er = std::expm1(x) * std::cos(y) - T(2) * a * a;
+  T ei = std::exp(x) * std::sin(y);
+  return {er, ei};
+}
+
 } // namespace c10_complex_math

 using c10_complex_math::acos;
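The step worth spelling out in the comment above is the real-part rewrite: e^x * cos(y) - 1 cancels catastrophically when x and y are both small, so it is split into two individually well-conditioned terms via the half-angle identity 1 - cos(y) = 2 sin^2(y/2). In LaTeX:

\operatorname{Re}\left(e^{z}-1\right)
  = e^{x}\cos y - 1
  = \left(e^{x}-1\right)\cos y - \left(1-\cos y\right)
  = \operatorname{expm1}(x)\cos y - 2\sin^{2}\!\left(\frac{y}{2}\right)

As z tends to 0 this evaluates to x - y^2/2 + O(|z|^3) with no loss of leading digits; the imaginary part e^x * sin(y) is already safe since sin(y) tends to y.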
@@ -329,6 +346,7 @@ using c10_complex_math::atanh;
 using c10_complex_math::cos;
 using c10_complex_math::cosh;
 using c10_complex_math::exp;
+using c10_complex_math::expm1;
 using c10_complex_math::log;
 using c10_complex_math::log10;
 using c10_complex_math::log1p;
@@ -351,6 +369,7 @@ using c10_complex_math::atanh;
 using c10_complex_math::cos;
 using c10_complex_math::cosh;
 using c10_complex_math::exp;
+using c10_complex_math::expm1;
 using c10_complex_math::log;
 using c10_complex_math::log10;
 using c10_complex_math::log1p;
@@ -621,7 +621,7 @@
   result: auto_element_wise

 - name: expm1(Tensor self) -> Tensor
-  self: grad * (result + 1)
+  self: grad * (result.conj() + 1)
   result: auto_element_wise

 # TODO: this derivative is not SymInt safe, need sum_to support
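Why the conj(): expm1 is holomorphic with derivative d/dz expm1(z) = e^z = result + 1, and PyTorch's autograd convention multiplies the incoming gradient by the conjugate of a holomorphic derivative; since 1 is real, conj(result + 1) = result.conj() + 1. A numerical sanity check (a sketch, not part of the PR):

import torch

z = torch.randn(4, dtype=torch.complex128, requires_grad=True)
# gradcheck compares the analytic backward against finite differences.
assert torch.autograd.gradcheck(torch.expm1, (z,))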
@@ -264,6 +264,7 @@ GRADIENT_IMPLEMENTED_FOR_COMPLEX = {
     "fill_",
     "exp",
     "exp2",
+    "expm1",
    "nonzero",
    "mean",
    "std_mean",
@@ -8324,8 +8324,8 @@ foreach_unary_op_db: List[OpInfo] = [

     ForeachFuncInfo(
         'expm1',
-        dtypes=floating_types_and(torch.bfloat16),
-        dtypesIfCUDA=floating_types_and(torch.half, torch.bfloat16),
+        dtypes=floating_and_complex_types_and(torch.bfloat16),
+        dtypesIfCUDA=floating_and_complex_types_and(torch.half, torch.bfloat16),
         sample_inputs_func=foreach_inputs_sample_func(1, False, False),
         supports_autograd=True,
     ),
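The widened dtype lists mean the foreach variant gets complex coverage too; the private API this entry exercises, as a usage sketch:

import torch

ts = [torch.randn(3, dtype=torch.complex64) for _ in range(2)]
# Applies expm1 to every tensor in the list in a single call.
print(torch._foreach_expm1(ts))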
@@ -14456,8 +14456,8 @@ op_db: List[OpInfo] = [
     UnaryUfuncInfo('expm1',
                    aliases=('special.expm1', ),
                    ref=np_unary_ufunc_integer_promotion_wrapper(np.expm1),
-                   dtypes=all_types_and(torch.bool, torch.bfloat16),
-                   dtypesIfCUDA=all_types_and(torch.bool, torch.half, torch.bfloat16),
+                   dtypes=all_types_and_complex_and(torch.bool, torch.bfloat16),
+                   dtypesIfCUDA=all_types_and_complex_and(torch.bool, torch.half, torch.bfloat16),
                    supports_forward_ad=True,
                    supports_fwgrad_bwgrad=True,
                    supports_sparse=True,
@@ -14472,6 +14472,8 @@ op_db: List[OpInfo] = [
                                 device_type='cpu', dtypes=[torch.bfloat16]),
                    DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', 'test_reference_numerics_large',
                                 device_type='cpu', dtypes=[torch.bfloat16]),
+                   DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', 'test_reference_numerics_large',
+                                device_type='cuda', dtypes=[torch.complex128]),
                    DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', 'test_reference_numerics_small',
                                 device_type='cpu', dtypes=[torch.bfloat16]),
                    DecorateInfo(unittest.skip("Skipped! sparse backward not supported"),