torch.special.gamma (#78904)

```Python
gamma(input, *, out=None) -> Tensor
```

Gamma function $\Gamma\left(\text{input}\right)$.
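
A quick numerical sketch of the new operator (assuming a build that includes this commit; `torch.special.gamma` is the op added here):

```Python
import torch

x = torch.tensor([0.5, 1.0, 2.5, 5.0], dtype=torch.float64)

# gamma(0.5) = sqrt(pi), gamma(1) = 1, and gamma(n + 1) = n! for integer n
torch.special.gamma(x)  # tensor([ 1.7725,  1.0000,  1.3293, 24.0000], dtype=torch.float64)
```
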
Pull Request resolved: https://github.com/pytorch/pytorch/pull/78904
Approved by: https://github.com/mruberry
Author: Allen Goodman
Date: 2022-06-27 19:36:17 +00:00
Committed by: PyTorch MergeBot
Parent: d67ce755ad
Commit: f563f25efd

15 changed files with 490 additions and 0 deletions


@@ -3013,6 +3013,160 @@ static inline C10_HOST_DEVICE T chebyshev_polynomial_w_forward(T x, T n) {
    return chebyshev_polynomial_w_forward(x, static_cast<int64_t>(n));
} // chebyshev_polynomial_w_forward(T x, T n)
template<typename T>
static inline C10_HOST_DEVICE
typename std::enable_if<std::is_floating_point<T>::value, T>::type
gamma_forward(T x) {
    // Rational approximation of Γ(x) on [2, 3], P(w)/Q(w) in w = x - 2
    // (coefficients from the Cephes mathematical library).
    static const T P[] = {
        +1.60119522476751861407e-4,
        +1.19135147006586384913e-3,
        +1.04213797561761569935e-2,
        +4.76367800457137231464e-2,
        +2.07448227648435975150e-1,
        +4.94214826801497100753e-1,
        +9.99999999999999996796e-1,
    };

    static const T Q[] = {
        -2.31581873324120129819e-5,
        +5.39605580493303397842e-4,
        -4.45641913851797240494e-3,
        +1.18139785222060435552e-2,
        +3.58236398605498653373e-2,
        -2.34591795718243348568e-1,
        +7.14304917030273074085e-2,
        +1.00000000000000000320e+0,
    };

    // Stirling series correction, evaluated in powers of 1/x.
    static const T R[] = {
        +7.87311395793093628397e-4,
        -2.29549961613378126380e-4,
        -2.68132617805781232825e-3,
        +3.47222221605458667310e-3,
        +8.33333333333482257126e-2,
    };

    int sign_gamma = 1;

    if (!std::isfinite(x)) {
        return x;
    }

    if (std::abs(x) > T(33.0)) {
        // Large |x|: Stirling's approximation; negative arguments go through
        // the reflection formula Γ(x)Γ(1 - x) = π / sin(πx).
        if (x < T(0.0)) {
            T p = std::floor(std::abs(x));

            if (p == std::abs(x)) {
                // Γ has poles at the non-positive integers.
                return std::numeric_limits<T>::infinity();
            }

            int previous_p = p;

            if ((previous_p & 1) == 0) {
                sign_gamma = -1;
            }

            T z = std::abs(x) - p;

            if (z > T(0.5)) {
                z = std::abs(x) - (p + T(1.0));
            }

            z = std::abs(x) * std::sin(c10::pi<T> * z);

            if (z == T(0.0)) {
                return sign_gamma * std::numeric_limits<T>::infinity();
            }

            if (std::abs(x) >= T(171.624376956302725)) {
                // Γ overflows double precision beyond this threshold.
                return std::numeric_limits<T>::infinity();
            }

            T r = 0.0;

            for (uint8_t index = 0; index <= 4; index++) {
                r = r * (T(1.0) / std::abs(x)) + R[index];
            }

            if (std::abs(x) > T(143.01608)) {
                // Split the power in two to avoid intermediate overflow in pow.
                return sign_gamma * c10::pi<T> / (std::abs(z) * (T(2.50662827463100050242) * (std::pow(std::abs(x), T(0.5) * std::abs(x) - T(0.25)) * (std::pow(std::abs(x), T(0.5) * std::abs(x) - T(0.25)) / std::exp(std::abs(x)))) * (T(1.0) + T(1.0) / std::abs(x) * r)));
            }

            return sign_gamma * c10::pi<T> / (std::abs(z) * (T(2.50662827463100050242) * (std::pow(std::abs(x), std::abs(x) - T(0.5)) / std::exp(std::abs(x))) * (T(1.0) + T(1.0) / std::abs(x) * r)));
        }

        if (x >= T(171.624376956302725)) {
            return std::numeric_limits<T>::infinity();
        }

        T r = 0.0;

        for (uint8_t index = 0; index <= 4; index++) {
            r = r * (T(1.0) / x) + R[index];
        }

        if (x > T(143.01608)) {
            return sign_gamma * (T(2.50662827463100050242) * (std::pow(x, T(0.5) * x - T(0.25)) * (std::pow(x, T(0.5) * x - T(0.25)) / std::exp(x))) * (T(1.0) + T(1.0) / x * r));
        }

        return sign_gamma * (T(2.50662827463100050242) * (std::pow(x, x - T(0.5)) / std::exp(x)) * (T(1.0) + T(1.0) / x * r));
    }

    // Moderate x: use Γ(x + 1) = x·Γ(x) to shift x into [2, 3], then apply
    // the rational approximation P/Q.
    T z = 1.0;

    while (x >= T(3.0)) {
        x = x - T(1.0);
        z = z * x;
    }

    while (x < T(0.0)) {
        if (x > -0.000000001) {
            if (x == T(0.0)) {
                return std::numeric_limits<T>::infinity();
            }

            // Near-zero argument: Γ(x) ≈ 1 / (x(1 + γx)), γ = Euler's constant.
            return z / ((T(1.0) + c10::euler<T> * x) * x);
        }

        z = z / x;
        x = x + T(1.0);
    }

    while (x < T(2.0)) {
        if (x < 0.000000001) {
            if (x == T(0.0)) {
                return std::numeric_limits<T>::infinity();
            }

            return z / ((T(1.0) + c10::euler<T> * x) * x);
        }

        z = z / x;
        x = x + T(1.0);
    }

    if (x == T(2.0)) {
        return z;
    }

    T p = 0.0;

    for (uint8_t index = 0; index <= 6; index++) {
        p = p * (x - T(2.0)) + P[index];
    }

    T q = 0.0;

    for (uint8_t index = 0; index <= 7; index++) {
        q = q * (x - T(2.0)) + Q[index];
    }

    return z * p / q;
} // T gamma_forward(T x)
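
The routine above follows the classic Cephes `gamma` algorithm: Stirling's approximation for |x| > 33 (with the reflection formula Γ(x)Γ(1 − x) = π / sin(πx) for negative arguments), and otherwise a shift into [2, 3] via Γ(x + 1) = x·Γ(x) followed by the rational approximation P/Q. A minimal Python sketch of the positive-argument path (illustrative only, not the committed code; it omits the negative-x and Stirling branches):

```Python
import math

# Cephes coefficients copied from the kernel above: P/Q approximate
# gamma(x) on [2, 3] via Horner's rule in w = x - 2.
P = [+1.60119522476751861407e-4, +1.19135147006586384913e-3,
     +1.04213797561761569935e-2, +4.76367800457137231464e-2,
     +2.07448227648435975150e-1, +4.94214826801497100753e-1,
     +9.99999999999999996796e-1]
Q = [-2.31581873324120129819e-5, +5.39605580493303397842e-4,
     -4.45641913851797240494e-3, +1.18139785222060435552e-2,
     +3.58236398605498653373e-2, -2.34591795718243348568e-1,
     +7.14304917030273074085e-2, +1.00000000000000000320e+0]

def gamma_positive(x: float) -> float:
    z = 1.0
    while x >= 3.0:        # gamma(x) = (x - 1) * gamma(x - 1)
        x -= 1.0
        z *= x
    while x < 2.0:         # gamma(x) = gamma(x + 1) / x
        z /= x
        x += 1.0
    w = x - 2.0            # Horner evaluation of P(w) / Q(w) on [2, 3]
    p = 0.0
    for c in P:
        p = p * w + c
    q = 0.0
    for c in Q:
        q = q * w + c
    return z * p / q

print(gamma_positive(4.5), math.gamma(4.5))  # both ~11.6317
```
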
template<typename T>
static inline C10_HOST_DEVICE T hermite_polynomial_h_forward(T x, int64_t n) {
    if (n < 0) {


@@ -76,6 +76,7 @@ CREATE_UNARY_FLOAT_META_FUNC(special_bessel_j0)
CREATE_UNARY_FLOAT_META_FUNC(special_bessel_j1)
CREATE_UNARY_FLOAT_META_FUNC(special_bessel_y0)
CREATE_UNARY_FLOAT_META_FUNC(special_bessel_y1)
CREATE_UNARY_FLOAT_META_FUNC(special_gamma)
CREATE_UNARY_FLOAT_META_FUNC(special_modified_bessel_i0)
CREATE_UNARY_FLOAT_META_FUNC(special_modified_bessel_i1)
CREATE_UNARY_FLOAT_META_FUNC(special_modified_bessel_k0)
@@ -205,6 +206,7 @@ CREATE_UNARY_TORCH_IMPL_FUNC(special_bessel_j0_out, special_bessel_j0_stub)
CREATE_UNARY_TORCH_IMPL_FUNC(special_bessel_j1_out, special_bessel_j1_stub)
CREATE_UNARY_TORCH_IMPL_FUNC(special_bessel_y0_out, special_bessel_y0_stub)
CREATE_UNARY_TORCH_IMPL_FUNC(special_bessel_y1_out, special_bessel_y1_stub)
CREATE_UNARY_TORCH_IMPL_FUNC(special_gamma_out, special_gamma_stub)
CREATE_UNARY_TORCH_IMPL_FUNC(special_modified_bessel_i0_out, special_modified_bessel_i0_stub)
CREATE_UNARY_TORCH_IMPL_FUNC(special_modified_bessel_i1_out, special_modified_bessel_i1_stub)
CREATE_UNARY_TORCH_IMPL_FUNC(special_modified_bessel_k0_out, special_modified_bessel_k0_stub)
@@ -888,6 +890,7 @@ DEFINE_DISPATCH(special_bessel_j0_stub); // NOLINT(cppcoreguidelines-avoid-non-c
DEFINE_DISPATCH(special_bessel_j1_stub); // NOLINT(cppcoreguidelines-avoid-non-const-global-variables)
DEFINE_DISPATCH(special_bessel_y0_stub); // NOLINT(cppcoreguidelines-avoid-non-const-global-variables)
DEFINE_DISPATCH(special_bessel_y1_stub); // NOLINT(cppcoreguidelines-avoid-non-const-global-variables)
DEFINE_DISPATCH(special_gamma_stub); // NOLINT(cppcoreguidelines-avoid-non-const-global-variables)
DEFINE_DISPATCH(special_modified_bessel_i0_stub); // NOLINT(cppcoreguidelines-avoid-non-const-global-variables)
DEFINE_DISPATCH(special_modified_bessel_i1_stub); // NOLINT(cppcoreguidelines-avoid-non-const-global-variables)
DEFINE_DISPATCH(special_modified_bessel_k0_stub); // NOLINT(cppcoreguidelines-avoid-non-const-global-variables)


@@ -75,6 +75,7 @@ DECLARE_DISPATCH(unary_fn, special_bessel_j0_stub);
DECLARE_DISPATCH(unary_fn, special_bessel_j1_stub);
DECLARE_DISPATCH(unary_fn, special_bessel_y0_stub);
DECLARE_DISPATCH(unary_fn, special_bessel_y1_stub);
DECLARE_DISPATCH(unary_fn, special_gamma_stub);
DECLARE_DISPATCH(unary_fn, special_modified_bessel_i0_stub);
DECLARE_DISPATCH(unary_fn, special_modified_bessel_i1_stub);
DECLARE_DISPATCH(unary_fn, special_modified_bessel_k0_stub);


@@ -0,0 +1,40 @@
#define TORCH_ASSERT_NO_OPERATORS
#include <ATen/native/UnaryOps.h>
#include <cmath>
#include <limits>
#include <type_traits>
#include <ATen/Config.h>
#include <ATen/Context.h>
#include <ATen/Dispatch.h>
#include <ATen/Parallel.h>
#include <ATen/cpu/vec/functional.h>
#include <ATen/cpu/vec/vec.h>
#include <ATen/cpu/vml.h>
#include <ATen/native/TensorIterator.h>
#include <ATen/native/cpu/Loops.h>
#include <ATen/native/cpu/zmath.h>
#include <ATen/OpMathType.h>
#include <c10/util/math_compat.h>
#include <c10/util/MathConstants.h>
#include <c10/core/Scalar.h>
#include <c10/util/irange.h>
namespace at {
namespace native {

static void gamma_kernel(TensorIteratorBase& iterator) {
    TORCH_INTERNAL_ASSERT(iterator.ntensors() == 2);

    AT_DISPATCH_FLOATING_TYPES(iterator.common_dtype(), "gamma_cpu", [&]() {
        cpu_kernel(iterator, [](scalar_t x) {
            return gamma_forward(x);
        });
    });
} // gamma_kernel(TensorIteratorBase& iterator)

REGISTER_DISPATCH(special_gamma_stub, &gamma_kernel);
} // namespace native
} // namespace at
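
The CPU path is a plain elementwise `cpu_kernel` over a `TensorIterator`, so semantically it just maps the scalar gamma function over each element. An equivalence sketch against the C library's `math.gamma` (illustrative, assuming a build with this commit):

```Python
import math
import torch

x = torch.tensor([0.5, 1.5, 3.0, 6.5], dtype=torch.float64)

# element-by-element reference
expected = torch.tensor([math.gamma(v) for v in x.tolist()], dtype=torch.float64)
torch.testing.assert_close(torch.special.gamma(x), expected)
```
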


@@ -2119,6 +2119,162 @@ const auto chebyshev_polynomial_w_string = jiterator_stringify(
    } // chebyshev_polynomial_w_forward(T x, T n)
); // chebyshev_polynomial_w_string
const auto gamma_string = jiterator_stringify(
    template<typename T>
    T gamma_forward(T x) {
        static const T P[] = {
            +1.60119522476751861407e-4,
            +1.19135147006586384913e-3,
            +1.04213797561761569935e-2,
            +4.76367800457137231464e-2,
            +2.07448227648435975150e-1,
            +4.94214826801497100753e-1,
            +9.99999999999999996796e-1,
        };

        static const T Q[] = {
            -2.31581873324120129819e-5,
            +5.39605580493303397842e-4,
            -4.45641913851797240494e-3,
            +1.18139785222060435552e-2,
            +3.58236398605498653373e-2,
            -2.34591795718243348568e-1,
            +7.14304917030273074085e-2,
            +1.00000000000000000320e+0,
        };

        static const T R[] = {
            +7.87311395793093628397e-4,
            -2.29549961613378126380e-4,
            -2.68132617805781232825e-3,
            +3.47222221605458667310e-3,
            +8.33333333333482257126e-2,
        };

        constexpr T PI = 3.14159265358979323846;

        int sign_gamma = 1;

        if (!isfinite(x)) {
            return x;
        }

        if (abs(x) > T(33.0)) {
            if (x < T(0.0)) {
                T p = floor(abs(x));

                if (p == abs(x)) {
                    return INFINITY;
                }

                int previous_p = p;

                if ((previous_p & 1) == 0) {
                    sign_gamma = -1;
                }

                T z = abs(x) - p;

                if (z > T(0.5)) {
                    z = abs(x) - (p + T(1.0));
                }

                z = abs(x) * sin(PI * z);

                if (z == T(0.0)) {
                    return sign_gamma * INFINITY;
                }

                if (abs(x) >= T(171.624376956302725)) {
                    return INFINITY;
                }

                T r = 0.0;

                for (uint8_t index = 0; index <= 4; index++) {
                    r = r * (T(1.0) / abs(x)) + R[index];
                }

                if (abs(x) > T(143.01608)) {
                    return sign_gamma * PI / (abs(z) * (T(2.50662827463100050242) * (pow(abs(x), T(0.5) * abs(x) - T(0.25)) * (pow(abs(x), T(0.5) * abs(x) - T(0.25)) / exp(abs(x)))) * (T(1.0) + T(1.0) / abs(x) * r)));
                }

                return sign_gamma * PI / (abs(z) * (T(2.50662827463100050242) * (pow(abs(x), abs(x) - T(0.5)) / exp(abs(x))) * (T(1.0) + T(1.0) / abs(x) * r)));
            }

            if (x >= T(171.624376956302725)) {
                return INFINITY;
            }

            T r = 0.0;

            for (uint8_t index = 0; index <= 4; index++) {
                r = r * (T(1.0) / x) + R[index];
            }

            if (x > T(143.01608)) {
                return sign_gamma * (T(2.50662827463100050242) * (pow(x, T(0.5) * x - T(0.25)) * (pow(x, T(0.5) * x - T(0.25)) / exp(x))) * (T(1.0) + T(1.0) / x * r));
            }

            return sign_gamma * (T(2.50662827463100050242) * (pow(x, x - T(0.5)) / exp(x)) * (T(1.0) + T(1.0) / x * r));
        }

        T z = 1.0;

        while (x >= T(3.0)) {
            x = x - T(1.0);
            z = z * x;
        }

        while (x < T(0.0)) {
            if (x > -0.000000001) {
                if (x == T(0.0)) {
                    return INFINITY;
                }

                return z / ((T(1.0) + T(0.5772156649015329) * x) * x);
            }

            z = z / x;
            x = x + T(1.0);
        }

        while (x < T(2.0)) {
            if (x < 0.000000001) {
                if (x == T(0.0)) {
                    return INFINITY;
                }

                return z / ((T(1.0) + T(0.5772156649015329) * x) * x);
            }

            z = z / x;
            x = x + T(1.0);
        }

        if (x == T(2.0)) {
            return z;
        }

        T p = 0.0;

        for (uint8_t index = 0; index <= 6; index++) {
            p = p * (x - T(2.0)) + P[index];
        }

        T q = 0.0;

        for (uint8_t index = 0; index <= 7; index++) {
            q = q * (x - T(2.0)) + Q[index];
        }

        return z * p / q;
    } // T gamma_forward(T x)
); // gamma_string
); // gamma_string
const auto hermite_polynomial_h_string = jiterator_stringify(
    template<typename T>
    T hermite_polynomial_h_forward(T x, int64_t n) {


@@ -0,0 +1,43 @@
#define TORCH_ASSERT_NO_OPERATORS
#include <ATen/native/UnaryOps.h>
#include <limits>
#include <ATen/AccumulateType.h>
#include <ATen/Dispatch.h>
#include <ATen/native/DispatchStub.h>
#include <ATen/native/Math.h>
#include <ATen/native/TensorIterator.h>
#include <ATen/native/cuda/JitLoops.cuh>
#include <ATen/native/cuda/Loops.cuh>
#include <ATen/native/cuda/Math.cuh>
#include <ATen/native/cuda/jit_utils.h>
#include <ATen/NumericUtils.h>
#include <c10/core/Scalar.h>
#include <c10/cuda/CUDAMathCompat.h>
#include <c10/util/complex.h>
namespace at {
namespace native {
namespace {

const char gamma_name[] = "gamma_forward";

void gamma_kernel_cuda(TensorIteratorBase& iterator) {
#if AT_USE_JITERATOR()
    AT_DISPATCH_FLOATING_TYPES(iterator.common_dtype(), "gamma_cuda", [&]() {
        jitted_gpu_kernel<gamma_name, scalar_t, scalar_t, 1>(iterator, gamma_string);
    });
#else
    AT_DISPATCH_FLOATING_TYPES(iterator.common_dtype(), "gamma_cuda", [&]() {
        gpu_kernel(iterator, []GPU_LAMBDA(scalar_t a) -> scalar_t {
            return gamma_forward(a);
        });
    });
#endif // AT_USE_JITERATOR()
}

} // anonymous namespace

REGISTER_DISPATCH(special_gamma_stub, &gamma_kernel_cuda);
} // namespace native
} // namespace at
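
With `AT_USE_JITERATOR()` enabled, the stringified `gamma_forward` is compiled on first use instead of being baked into the CUDA binary, which keeps binary size down at the cost of a one-time JIT hit. The same mechanism is exposed to Python as an experimental API; a toy sketch (the `double_it` kernel is hypothetical, not the gamma source above):

```Python
import torch
from torch.cuda.jiterator import _create_jit_fn

# runtime-compiled elementwise CUDA kernel, analogous to gamma_string above
code = "template <typename T> T double_it(T x) { return x + x; }"
fn = _create_jit_fn(code)

x = torch.randn(8, device="cuda")
print(fn(x))  # equals 2 * x
```
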


@@ -12625,6 +12625,19 @@
  python_module: special
  variants: function
- func: special_gamma(Tensor x) -> Tensor
  python_module: special
  structured_delegate: special_gamma.out
  variants: function

- func: special_gamma.out(Tensor x, *, Tensor(a!) out) -> Tensor(a!)
  dispatch:
    CPU, CUDA: special_gamma_out
  python_module: special
  structured_inherits: TensorIteratorBase
  structured: True
  variants: function
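
Because `special_gamma` is structured and delegates to `special_gamma.out`, the functional and `out=` forms share one kernel. Illustrative usage (assuming a build with this commit):

```Python
import torch

x = torch.linspace(1.0, 5.0, steps=5, dtype=torch.float64)
out = torch.empty_like(x)

torch.special.gamma(x, out=out)  # routes through the structured .out kernel
print(out)  # tensor([ 1.,  1.,  2.,  6., 24.], dtype=torch.float64)
```
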
- func: special_hermite_polynomial_h(Tensor x, Tensor n) -> Tensor
  device_check: NoCheck
  python_module: special


@@ -1150,6 +1150,7 @@ aten_native_source_codegen_list = [
    "aten/src/ATen/native/cpu/WeightNormKernel.cpp",
    "aten/src/ATen/native/cpu/airy_ai.cpp",
    "aten/src/ATen/native/cpu/batch_norm_kernel.cpp",
    "aten/src/ATen/native/cpu/gamma.cpp",
    "aten/src/ATen/native/cpu/group_norm_kernel.cpp",
    "aten/src/ATen/native/cpu/layer_norm_kernel.cpp",
    "aten/src/ATen/native/cpu/scaled_modified_bessel_k1.cpp",


@@ -22,6 +22,7 @@ Functions
.. autofunction:: exp2
.. autofunction:: expit
.. autofunction:: expm1
.. autofunction:: gamma
.. autofunction:: gammainc
.. autofunction:: gammaincc
.. autofunction:: gammaln


@@ -2451,6 +2451,7 @@
    "exp2",
    "expit",
    "expm1",
    "gamma",
    "gammainc",
    "gammaincc",
    "gammaln",


@@ -2853,6 +2853,9 @@
- name: special_chebyshev_polynomial_w.n_scalar(Tensor x, Scalar n) -> Tensor
  x: non_differentiable

- name: special_gamma(Tensor x) -> Tensor
  x: non_differentiable

- name: special_hermite_polynomial_h(Tensor x, Tensor n) -> Tensor
  x: non_differentiable
  n: non_differentiable
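
Marking `x: non_differentiable` registers the op with autograd as gradient-free, so its outputs never require grad. A quick check (assuming a build with this commit):

```Python
import torch

x = torch.rand(4, dtype=torch.float64, requires_grad=True)
y = torch.special.gamma(x)

print(y.requires_grad)  # False: special_gamma has no derivative registered
```
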


@@ -894,6 +894,25 @@ inline Tensor& chebyshev_polynomial_w_out(
    return torch::special_chebyshev_polynomial_w_out(output, x, n);
}
/// Gamma function.
///
/// See https://pytorch.org/docs/master/special.html#torch.special.gamma.
///
/// Example:
///
/// ```
/// auto x = torch::randn({128}, torch::kFloat64);
///
/// torch::special::gamma(x);
/// ```
inline Tensor gamma(const Tensor& x) {
    return torch::special_gamma(x);
}

inline Tensor& gamma_out(Tensor& y, const Tensor& x) {
    return torch::special_gamma_out(y, x);
}

/// Physicists' Hermite polynomial.
///
/// See


@@ -1008,6 +1008,7 @@ def get_testing_overrides() -> Dict[Callable, Callable]:
        torch.special.exp2: lambda input: -1,
        torch.special.expit: lambda input: -1,
        torch.special.expm1: lambda input: -1,
        torch.special.gamma: lambda input: -1,
        torch.special.gammainc: lambda input, other, out=None: -1,
        torch.special.gammaincc: lambda input, other, out=None: -1,
        torch.special.gammaln: lambda input: -1,


@@ -21,6 +21,7 @@ __all__ = [
    'exp2',
    'expit',
    'expm1',
    'gamma',
    'gammainc',
    'gammaincc',
    'gammaln',
@@ -1030,6 +1031,31 @@ Keyword args:
    {out}
""".format(**common_args))

gamma = _add_docstr(_special.special_gamma,
                    r"""
gamma(input, *, out=None) -> Tensor

Gamma function :math:`\Gamma\left(\text{input}\right)`, defined for
:math:`\operatorname{Re}\left(z\right) > 0` by the convergent improper
integral:

.. math::

    \Gamma\left(z\right) = \int_{0}^{\infty} x^{z - 1} e^{-x} \, dx

The gamma function is often called the generalized factorial function, since
:math:`\Gamma\left(n + 1\right) = n!` for natural numbers
:math:`n \in \mathbb{N}`. It satisfies the recurrence relation
:math:`\Gamma\left(z + 1\right) = z \Gamma\left(z\right)` for complex
:math:`z`, which, combined with :math:`\Gamma\left(1\right) = 1`, implies the
identity above for :math:`z = n`.

""" + r"""
Args:
    {input}

Keyword args:
    {out}
""".format(**common_args))
hermite_polynomial_h = _add_docstr(_special.special_hermite_polynomial_h,
r"""
hermite_polynomial_h(input, n, *, out=None) -> Tensor


@@ -19678,6 +19678,34 @@ op_db: List[OpInfo] = [
        supports_one_python_scalar=True,
        supports_autograd=False,
    ),
    UnaryUfuncInfo(
        'special.gamma',
        decorators=(
            toleranceOverride(
                {
                    torch.float32: tol(atol=1e-03, rtol=1.3e-05),
                    torch.float64: tol(atol=1e-05, rtol=1.3e-05),
                }
            ),
        ),
        dtypes=all_types_and(torch.bool),
        ref=scipy.special.gamma if TEST_SCIPY else _NOTHING,
        skips=(
            DecorateInfo(
                unittest.skip("Skipped!"),
                'TestUnaryUfuncs',
                'test_reference_numerics_large',
                dtypes=[torch.float32, torch.float64],
            ),
            DecorateInfo(
                unittest.skip("Skipped!"),
                'TestUnaryUfuncs',
                'test_reference_numerics_small',
                dtypes=[torch.uint8, torch.int8, torch.int16, torch.int32, torch.int64],
            ),
        ),
        supports_autograd=False,
    ),
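
The OpInfo pins `scipy.special.gamma` as the reference implementation, with a loosened float32 tolerance. An equivalent standalone check mirroring those tolerances (illustrative):

```Python
import scipy.special
import torch

x = torch.linspace(0.1, 10.0, steps=100, dtype=torch.float32)
expected = torch.from_numpy(scipy.special.gamma(x.numpy()))

torch.testing.assert_close(torch.special.gamma(x), expected,
                           atol=1e-03, rtol=1.3e-05)
```
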
    BinaryUfuncInfo(
        'special.hermite_polynomial_h',
        dtypes=all_types_and(torch.bool),