mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
Summary: `-Wextra-semi` or `-Wextra-semi-stmt` If the code compiles, this is safe to land. Test Plan: Sandcastle Reviewed By: palmje Differential Revision: D51995078 Pull Request resolved: https://github.com/pytorch/pytorch/pull/115761 Approved by: https://github.com/Skylion007
407 lines
12 KiB
C++
407 lines
12 KiB
C++
// Inclusion guard: compilation is aborted with the message below unless
// C10_INTERNAL_INCLUDE_COMPLEX_REMAINING_H has been defined before this point.
// The message directs users to include c10/util/complex.h instead.
#if !defined(C10_INTERNAL_INCLUDE_COMPLEX_REMAINING_H)
|
|
#error \
|
|
"c10/util/complex_math.h is not meant to be individually included. Include c10/util/complex.h instead."
|
|
#endif
|
|
|
|
namespace c10_complex_math {
|
|
|
|
// Exponential functions
|
|
|
|
template <typename T>
|
|
C10_HOST_DEVICE inline c10::complex<T> exp(const c10::complex<T>& x) {
|
|
#if defined(__CUDACC__) || defined(__HIPCC__)
|
|
return static_cast<c10::complex<T>>(
|
|
thrust::exp(static_cast<thrust::complex<T>>(x)));
|
|
#else
|
|
return static_cast<c10::complex<T>>(
|
|
std::exp(static_cast<std::complex<T>>(x)));
|
|
#endif
|
|
}
|
|
|
|
template <typename T>
|
|
C10_HOST_DEVICE inline c10::complex<T> log(const c10::complex<T>& x) {
|
|
#if defined(__CUDACC__) || defined(__HIPCC__)
|
|
return static_cast<c10::complex<T>>(
|
|
thrust::log(static_cast<thrust::complex<T>>(x)));
|
|
#else
|
|
return static_cast<c10::complex<T>>(
|
|
std::log(static_cast<std::complex<T>>(x)));
|
|
#endif
|
|
}
|
|
|
|
template <typename T>
|
|
C10_HOST_DEVICE inline c10::complex<T> log10(const c10::complex<T>& x) {
|
|
#if defined(__CUDACC__) || defined(__HIPCC__)
|
|
return static_cast<c10::complex<T>>(
|
|
thrust::log10(static_cast<thrust::complex<T>>(x)));
|
|
#else
|
|
return static_cast<c10::complex<T>>(
|
|
std::log10(static_cast<std::complex<T>>(x)));
|
|
#endif
|
|
}
|
|
|
|
template <typename T>
|
|
C10_HOST_DEVICE inline c10::complex<T> log2(const c10::complex<T>& x) {
|
|
const c10::complex<T> log2 = c10::complex<T>(::log(2.0), 0.0);
|
|
return c10_complex_math::log(x) / log2;
|
|
}
|
|
|
|
// Power functions
|
|
//
|
|
// Out-of-line sqrt/acos implementations for float and double, declared only
// when building against libc++ (_LIBCPP_VERSION) or against libstdc++
// (__GLIBCXX__) without _GLIBCXX11_USE_C99_COMPLEX. Only the declarations
// live here; the definitions are not part of this header.
#if defined(_LIBCPP_VERSION) || \
    (defined(__GLIBCXX__) && !defined(_GLIBCXX11_USE_C99_COMPLEX))
namespace _detail {

// Single-precision overloads.
C10_API c10::complex<float> sqrt(const c10::complex<float>& in);
C10_API c10::complex<float> acos(const c10::complex<float>& in);

// Double-precision overloads.
C10_API c10::complex<double> sqrt(const c10::complex<double>& in);
C10_API c10::complex<double> acos(const c10::complex<double>& in);

} // namespace _detail
#endif
|
|
|
|
template <typename T>
|
|
C10_HOST_DEVICE inline c10::complex<T> sqrt(const c10::complex<T>& x) {
|
|
#if defined(__CUDACC__) || defined(__HIPCC__)
|
|
return static_cast<c10::complex<T>>(
|
|
thrust::sqrt(static_cast<thrust::complex<T>>(x)));
|
|
#elif !( \
|
|
defined(_LIBCPP_VERSION) || \
|
|
(defined(__GLIBCXX__) && !defined(_GLIBCXX11_USE_C99_COMPLEX)))
|
|
return static_cast<c10::complex<T>>(
|
|
std::sqrt(static_cast<std::complex<T>>(x)));
|
|
#else
|
|
return _detail::sqrt(x);
|
|
#endif
|
|
}
|
|
|
|
template <typename T>
|
|
C10_HOST_DEVICE inline c10::complex<T> pow(
|
|
const c10::complex<T>& x,
|
|
const c10::complex<T>& y) {
|
|
#if defined(__CUDACC__) || defined(__HIPCC__)
|
|
return static_cast<c10::complex<T>>(thrust::pow(
|
|
static_cast<thrust::complex<T>>(x), static_cast<thrust::complex<T>>(y)));
|
|
#else
|
|
return static_cast<c10::complex<T>>(std::pow(
|
|
static_cast<std::complex<T>>(x), static_cast<std::complex<T>>(y)));
|
|
#endif
|
|
}
|
|
|
|
template <typename T>
|
|
C10_HOST_DEVICE inline c10::complex<T> pow(
|
|
const c10::complex<T>& x,
|
|
const T& y) {
|
|
#if defined(__CUDACC__) || defined(__HIPCC__)
|
|
return static_cast<c10::complex<T>>(
|
|
thrust::pow(static_cast<thrust::complex<T>>(x), y));
|
|
#else
|
|
return static_cast<c10::complex<T>>(
|
|
std::pow(static_cast<std::complex<T>>(x), y));
|
|
#endif
|
|
}
|
|
|
|
template <typename T>
|
|
C10_HOST_DEVICE inline c10::complex<T> pow(
|
|
const T& x,
|
|
const c10::complex<T>& y) {
|
|
#if defined(__CUDACC__) || defined(__HIPCC__)
|
|
return static_cast<c10::complex<T>>(
|
|
thrust::pow(x, static_cast<thrust::complex<T>>(y)));
|
|
#else
|
|
return static_cast<c10::complex<T>>(
|
|
std::pow(x, static_cast<std::complex<T>>(y)));
|
|
#endif
|
|
}
|
|
|
|
template <typename T, typename U>
|
|
C10_HOST_DEVICE inline c10::complex<decltype(T() * U())> pow(
|
|
const c10::complex<T>& x,
|
|
const c10::complex<U>& y) {
|
|
#if defined(__CUDACC__) || defined(__HIPCC__)
|
|
return static_cast<c10::complex<T>>(thrust::pow(
|
|
static_cast<thrust::complex<T>>(x), static_cast<thrust::complex<T>>(y)));
|
|
#else
|
|
return static_cast<c10::complex<T>>(std::pow(
|
|
static_cast<std::complex<T>>(x), static_cast<std::complex<T>>(y)));
|
|
#endif
|
|
}
|
|
|
|
template <typename T, typename U>
|
|
C10_HOST_DEVICE inline c10::complex<decltype(T() * U())> pow(
|
|
const c10::complex<T>& x,
|
|
const U& y) {
|
|
#if defined(__CUDACC__) || defined(__HIPCC__)
|
|
return static_cast<c10::complex<T>>(
|
|
thrust::pow(static_cast<thrust::complex<T>>(x), y));
|
|
#else
|
|
return static_cast<c10::complex<T>>(
|
|
std::pow(static_cast<std::complex<T>>(x), y));
|
|
#endif
|
|
}
|
|
|
|
template <typename T, typename U>
|
|
C10_HOST_DEVICE inline c10::complex<decltype(T() * U())> pow(
|
|
const T& x,
|
|
const c10::complex<U>& y) {
|
|
#if defined(__CUDACC__) || defined(__HIPCC__)
|
|
return static_cast<c10::complex<T>>(
|
|
thrust::pow(x, static_cast<thrust::complex<T>>(y)));
|
|
#else
|
|
return static_cast<c10::complex<T>>(
|
|
std::pow(x, static_cast<std::complex<T>>(y)));
|
|
#endif
|
|
}
|
|
|
|
// Trigonometric functions
|
|
|
|
template <typename T>
|
|
C10_HOST_DEVICE inline c10::complex<T> sin(const c10::complex<T>& x) {
|
|
#if defined(__CUDACC__) || defined(__HIPCC__)
|
|
return static_cast<c10::complex<T>>(
|
|
thrust::sin(static_cast<thrust::complex<T>>(x)));
|
|
#else
|
|
return static_cast<c10::complex<T>>(
|
|
std::sin(static_cast<std::complex<T>>(x)));
|
|
#endif
|
|
}
|
|
|
|
template <typename T>
|
|
C10_HOST_DEVICE inline c10::complex<T> cos(const c10::complex<T>& x) {
|
|
#if defined(__CUDACC__) || defined(__HIPCC__)
|
|
return static_cast<c10::complex<T>>(
|
|
thrust::cos(static_cast<thrust::complex<T>>(x)));
|
|
#else
|
|
return static_cast<c10::complex<T>>(
|
|
std::cos(static_cast<std::complex<T>>(x)));
|
|
#endif
|
|
}
|
|
|
|
template <typename T>
|
|
C10_HOST_DEVICE inline c10::complex<T> tan(const c10::complex<T>& x) {
|
|
#if defined(__CUDACC__) || defined(__HIPCC__)
|
|
return static_cast<c10::complex<T>>(
|
|
thrust::tan(static_cast<thrust::complex<T>>(x)));
|
|
#else
|
|
return static_cast<c10::complex<T>>(
|
|
std::tan(static_cast<std::complex<T>>(x)));
|
|
#endif
|
|
}
|
|
|
|
template <typename T>
|
|
C10_HOST_DEVICE inline c10::complex<T> asin(const c10::complex<T>& x) {
|
|
#if defined(__CUDACC__) || defined(__HIPCC__)
|
|
return static_cast<c10::complex<T>>(
|
|
thrust::asin(static_cast<thrust::complex<T>>(x)));
|
|
#else
|
|
return static_cast<c10::complex<T>>(
|
|
std::asin(static_cast<std::complex<T>>(x)));
|
|
#endif
|
|
}
|
|
|
|
template <typename T>
|
|
C10_HOST_DEVICE inline c10::complex<T> acos(const c10::complex<T>& x) {
|
|
#if defined(__CUDACC__) || defined(__HIPCC__)
|
|
return static_cast<c10::complex<T>>(
|
|
thrust::acos(static_cast<thrust::complex<T>>(x)));
|
|
#elif !defined(_LIBCPP_VERSION)
|
|
return static_cast<c10::complex<T>>(
|
|
std::acos(static_cast<std::complex<T>>(x)));
|
|
#else
|
|
return _detail::acos(x);
|
|
#endif
|
|
}
|
|
|
|
template <typename T>
|
|
C10_HOST_DEVICE inline c10::complex<T> atan(const c10::complex<T>& x) {
|
|
#if defined(__CUDACC__) || defined(__HIPCC__)
|
|
return static_cast<c10::complex<T>>(
|
|
thrust::atan(static_cast<thrust::complex<T>>(x)));
|
|
#else
|
|
return static_cast<c10::complex<T>>(
|
|
std::atan(static_cast<std::complex<T>>(x)));
|
|
#endif
|
|
}
|
|
|
|
// Hyperbolic functions
|
|
|
|
template <typename T>
|
|
C10_HOST_DEVICE inline c10::complex<T> sinh(const c10::complex<T>& x) {
|
|
#if defined(__CUDACC__) || defined(__HIPCC__)
|
|
return static_cast<c10::complex<T>>(
|
|
thrust::sinh(static_cast<thrust::complex<T>>(x)));
|
|
#else
|
|
return static_cast<c10::complex<T>>(
|
|
std::sinh(static_cast<std::complex<T>>(x)));
|
|
#endif
|
|
}
|
|
|
|
template <typename T>
|
|
C10_HOST_DEVICE inline c10::complex<T> cosh(const c10::complex<T>& x) {
|
|
#if defined(__CUDACC__) || defined(__HIPCC__)
|
|
return static_cast<c10::complex<T>>(
|
|
thrust::cosh(static_cast<thrust::complex<T>>(x)));
|
|
#else
|
|
return static_cast<c10::complex<T>>(
|
|
std::cosh(static_cast<std::complex<T>>(x)));
|
|
#endif
|
|
}
|
|
|
|
template <typename T>
|
|
C10_HOST_DEVICE inline c10::complex<T> tanh(const c10::complex<T>& x) {
|
|
#if defined(__CUDACC__) || defined(__HIPCC__)
|
|
return static_cast<c10::complex<T>>(
|
|
thrust::tanh(static_cast<thrust::complex<T>>(x)));
|
|
#else
|
|
return static_cast<c10::complex<T>>(
|
|
std::tanh(static_cast<std::complex<T>>(x)));
|
|
#endif
|
|
}
|
|
|
|
template <typename T>
|
|
C10_HOST_DEVICE inline c10::complex<T> asinh(const c10::complex<T>& x) {
|
|
#if defined(__CUDACC__) || defined(__HIPCC__)
|
|
return static_cast<c10::complex<T>>(
|
|
thrust::asinh(static_cast<thrust::complex<T>>(x)));
|
|
#else
|
|
return static_cast<c10::complex<T>>(
|
|
std::asinh(static_cast<std::complex<T>>(x)));
|
|
#endif
|
|
}
|
|
|
|
template <typename T>
|
|
C10_HOST_DEVICE inline c10::complex<T> acosh(const c10::complex<T>& x) {
|
|
#if defined(__CUDACC__) || defined(__HIPCC__)
|
|
return static_cast<c10::complex<T>>(
|
|
thrust::acosh(static_cast<thrust::complex<T>>(x)));
|
|
#else
|
|
return static_cast<c10::complex<T>>(
|
|
std::acosh(static_cast<std::complex<T>>(x)));
|
|
#endif
|
|
}
|
|
|
|
template <typename T>
|
|
C10_HOST_DEVICE inline c10::complex<T> atanh(const c10::complex<T>& x) {
|
|
#if defined(__CUDACC__) || defined(__HIPCC__)
|
|
return static_cast<c10::complex<T>>(
|
|
thrust::atanh(static_cast<thrust::complex<T>>(x)));
|
|
#else
|
|
return static_cast<c10::complex<T>>(
|
|
std::atanh(static_cast<std::complex<T>>(x)));
|
|
#endif
|
|
}
|
|
|
|
template <typename T>
|
|
C10_HOST_DEVICE inline c10::complex<T> log1p(const c10::complex<T>& z) {
|
|
#if defined(__APPLE__) || defined(__MACOSX) || defined(__CUDACC__) || \
|
|
defined(__HIPCC__)
|
|
// For Mac, the new implementation yielded a high relative error. Falling back
|
|
// to the old version for now.
|
|
// See https://github.com/numpy/numpy/pull/22611#issuecomment-1667945354
|
|
// For CUDA we also use this one, as thrust::log(thrust::complex) takes
|
|
// *forever* to compile
|
|
|
|
// log1p(z) = log(1 + z)
|
|
// Let's define 1 + z = r * e ^ (i * a), then we have
|
|
// log(r * e ^ (i * a)) = log(r) + i * a
|
|
// With z = x + iy, the term r can be written as
|
|
// r = ((1 + x) ^ 2 + y ^ 2) ^ 0.5
|
|
// = (1 + x ^ 2 + 2 * x + y ^ 2) ^ 0.5
|
|
// So, log(r) is
|
|
// log(r) = 0.5 * log(1 + x ^ 2 + 2 * x + y ^ 2)
|
|
// = 0.5 * log1p(x * (x + 2) + y ^ 2)
|
|
// we need to use the expression only on certain condition to avoid overflow
|
|
// and underflow from `(x * (x + 2) + y ^ 2)`
|
|
T x = z.real();
|
|
T y = z.imag();
|
|
T zabs = std::abs(z);
|
|
T theta = std::atan2(y, x + T(1));
|
|
if (zabs < 0.5) {
|
|
T r = x * (T(2) + x) + y * y;
|
|
if (r == 0) { // handle underflow
|
|
return {x, theta};
|
|
}
|
|
return {T(0.5) * std::log1p(r), theta};
|
|
} else {
|
|
T z0 = std::hypot(x + 1, y);
|
|
return {std::log(z0), theta};
|
|
}
|
|
#else
|
|
// CPU path
|
|
// Based on https://github.com/numpy/numpy/pull/22611#issuecomment-1667945354
|
|
c10::complex<T> u = z + T(1);
|
|
if (u == T(1)) {
|
|
return z;
|
|
} else {
|
|
auto log_u = log(u);
|
|
if (u - T(1) == z) {
|
|
return log_u;
|
|
}
|
|
return log_u * (z / (u - T(1)));
|
|
}
|
|
#endif
|
|
}
|
|
|
|
template <typename T>
|
|
C10_HOST_DEVICE inline c10::complex<T> expm1(const c10::complex<T>& z) {
|
|
// expm1(z) = exp(z) - 1
|
|
// Define z = x + i * y
|
|
// f = e ^ (x + i * y) - 1
|
|
// = e ^ x * e ^ (i * y) - 1
|
|
// = (e ^ x * cos(y) - 1) + i * (e ^ x * sin(y))
|
|
// = (e ^ x - 1) * cos(y) - (1 - cos(y)) + i * e ^ x * sin(y)
|
|
// = expm1(x) * cos(y) - 2 * sin(y / 2) ^ 2 + i * e ^ x * sin(y)
|
|
T x = z.real();
|
|
T y = z.imag();
|
|
T a = std::sin(y / 2);
|
|
T er = std::expm1(x) * std::cos(y) - T(2) * a * a;
|
|
T ei = std::exp(x) * std::sin(y);
|
|
return {er, ei};
|
|
}
|
|
|
|
} // namespace c10_complex_math
|
|
|
|
// Bring every c10_complex_math function into the surrounding scope so it can
// be called without the c10_complex_math:: qualifier.

// Exponential and logarithmic functions
using c10_complex_math::exp;
using c10_complex_math::expm1;
using c10_complex_math::log;
using c10_complex_math::log10;
using c10_complex_math::log1p;
using c10_complex_math::log2;

// Power functions
using c10_complex_math::pow;
using c10_complex_math::sqrt;

// Trigonometric functions
using c10_complex_math::sin;
using c10_complex_math::cos;
using c10_complex_math::tan;
using c10_complex_math::asin;
using c10_complex_math::acos;
using c10_complex_math::atan;

// Hyperbolic functions
using c10_complex_math::sinh;
using c10_complex_math::cosh;
using c10_complex_math::tanh;
using c10_complex_math::asinh;
using c10_complex_math::acosh;
using c10_complex_math::atanh;
|
|
|
|
namespace std {
|
|
|
|
using c10_complex_math::acos;
|
|
using c10_complex_math::acosh;
|
|
using c10_complex_math::asin;
|
|
using c10_complex_math::asinh;
|
|
using c10_complex_math::atan;
|
|
using c10_complex_math::atanh;
|
|
using c10_complex_math::cos;
|
|
using c10_complex_math::cosh;
|
|
using c10_complex_math::exp;
|
|
using c10_complex_math::expm1;
|
|
using c10_complex_math::log;
|
|
using c10_complex_math::log10;
|
|
using c10_complex_math::log1p;
|
|
using c10_complex_math::log2;
|
|
using c10_complex_math::pow;
|
|
using c10_complex_math::sin;
|
|
using c10_complex_math::sinh;
|
|
using c10_complex_math::sqrt;
|
|
using c10_complex_math::tan;
|
|
using c10_complex_math::tanh;
|
|
|
|
} // namespace std
|