mirror of
https://github.com/pytorch/pytorch.git
synced 2025-11-11 22:34:53 +08:00
torch.special.gamma (#78904)
```Python
gamma(input, *, out=None) -> Tensor
```
Gamma function $\Gamma\left(\text{input}\right)$.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/78904
Approved by: https://github.com/mruberry
This commit is contained in:
committed by
PyTorch MergeBot
parent
d67ce755ad
commit
f563f25efd
@ -3013,6 +3013,160 @@ static inline C10_HOST_DEVICE T chebyshev_polynomial_w_forward(T x, T n) {
|
||||
return chebyshev_polynomial_w_forward(x, static_cast<int64_t>(n));
|
||||
} // chebyshev_polynomial_w_forward(T x, T n)
|
||||
|
||||
template<typename T>
|
||||
static inline C10_HOST_DEVICE
|
||||
typename std::enable_if<std::is_floating_point<T>::value, T>::type
|
||||
gamma_forward(T x) {
|
||||
static const T P[] = {
|
||||
+1.60119522476751861407e-4,
|
||||
+1.19135147006586384913e-3,
|
||||
+1.04213797561761569935e-2,
|
||||
+4.76367800457137231464e-2,
|
||||
+2.07448227648435975150e-1,
|
||||
+4.94214826801497100753e-1,
|
||||
+9.99999999999999996796e-1,
|
||||
};
|
||||
|
||||
static const T Q[] = {
|
||||
-2.31581873324120129819e-5,
|
||||
+5.39605580493303397842e-4,
|
||||
-4.45641913851797240494e-3,
|
||||
+1.18139785222060435552e-2,
|
||||
+3.58236398605498653373e-2,
|
||||
-2.34591795718243348568e-1,
|
||||
+7.14304917030273074085e-2,
|
||||
+1.00000000000000000320e+0,
|
||||
};
|
||||
|
||||
static const T R[] = {
|
||||
+7.87311395793093628397e-4,
|
||||
-2.29549961613378126380e-4,
|
||||
-2.68132617805781232825e-3,
|
||||
+3.47222221605458667310e-3,
|
||||
+8.33333333333482257126e-2,
|
||||
};
|
||||
|
||||
int sign_gamma = 1;
|
||||
|
||||
if (!std::isfinite(x)) {
|
||||
return x;
|
||||
}
|
||||
|
||||
if (std::abs(x) > T(33.0)) {
|
||||
if (x < T(0.0)) {
|
||||
T p = std::floor(std::abs(x));
|
||||
|
||||
if (p == std::abs(x)) {
|
||||
return std::numeric_limits<T>::infinity();
|
||||
}
|
||||
|
||||
int previous_p = p;
|
||||
|
||||
if ((previous_p & 1) == 0) {
|
||||
sign_gamma = -1;
|
||||
}
|
||||
|
||||
T z = std::abs(x) - p;
|
||||
|
||||
if (z > T(0.5)) {
|
||||
z = std::abs(x) - (p + T(1.0));
|
||||
}
|
||||
|
||||
z = std::abs(x) * std::sin(c10::pi<T> * z);
|
||||
|
||||
if (z == T(0.0)) {
|
||||
return sign_gamma * std::numeric_limits<T>::infinity();
|
||||
}
|
||||
|
||||
if (std::abs(x) >= T(171.624376956302725)) {
|
||||
return std::numeric_limits<T>::infinity();
|
||||
}
|
||||
|
||||
T r = 0.0;
|
||||
|
||||
for (uint8_t index = 0; index <= 4; index++) {
|
||||
r = r * (T(1.0) / std::abs(x)) + R[index];
|
||||
}
|
||||
|
||||
if (std::abs(x) > T(143.01608)) {
|
||||
return sign_gamma * c10::pi<T> / (std::abs(z) * (T(2.50662827463100050242) * (std::pow(std::abs(x), T(0.5) * std::abs(x) - T(0.25)) * (std::pow(std::abs(x), T(0.5) * std::abs(x) - T(0.25)) / std::exp(std::abs(x)))) * (T(1.0) + T(1.0) / std::abs(x) * r)));
|
||||
}
|
||||
|
||||
return sign_gamma * c10::pi<T> / (std::abs(z) * (T(2.50662827463100050242) * (std::pow(std::abs(x), std::abs(x) - T(0.5)) / std::exp(std::abs(x))) * (T(1.0) + T(1.0) / std::abs(x) * r)));
|
||||
}
|
||||
|
||||
if (x >= T(171.624376956302725)) {
|
||||
return std::numeric_limits<T>::infinity();
|
||||
}
|
||||
|
||||
T r = 0.0;
|
||||
|
||||
for (uint8_t index = 0; index <= 4; index++) {
|
||||
r = r * (T(1.0) / x) + R[index];
|
||||
}
|
||||
|
||||
if (x > T(143.01608)) {
|
||||
return sign_gamma * (T(2.50662827463100050242) * (std::pow(x, T(0.5) * x - T(0.25)) * (std::pow(x, T(0.5) * x - T(0.25)) / std::exp(x))) * (T(1.0) + T(1.0) / x * r));
|
||||
}
|
||||
|
||||
return sign_gamma * (T(2.50662827463100050242) * (std::pow(x, x - T(0.5)) / std::exp(x)) * (T(1.0) + T(1.0) / x * r));
|
||||
}
|
||||
|
||||
T z = 1.0;
|
||||
|
||||
while (x >= T(3.0)) {
|
||||
x = x - T(1.0);
|
||||
|
||||
z = z * x;
|
||||
}
|
||||
|
||||
while (x < T(0.0)) {
|
||||
if (x > -0.000000001) {
|
||||
if (x == T(0.0)) {
|
||||
return std::numeric_limits<T>::infinity();
|
||||
}
|
||||
|
||||
return z / ((T(1.0) + c10::euler<T> * x) * x);
|
||||
}
|
||||
|
||||
z = z / x;
|
||||
|
||||
x = x + T(1.0);
|
||||
}
|
||||
|
||||
while (x < T(2.0)) {
|
||||
if (x < 0.000000001) {
|
||||
if (x == T(0.0)) {
|
||||
return std::numeric_limits<T>::infinity();
|
||||
}
|
||||
|
||||
return z / ((T(1.0) + c10::euler<T> * x) * x);
|
||||
}
|
||||
|
||||
z = z / x;
|
||||
|
||||
x = x + T(1.0);
|
||||
}
|
||||
|
||||
if (x == T(2.0)) {
|
||||
return z;
|
||||
}
|
||||
|
||||
T p = 0.0;
|
||||
|
||||
for (uint8_t index = 0; index <= 6; index++) {
|
||||
p = p * (x - T(2.0)) + P[index];
|
||||
}
|
||||
|
||||
T q = 0.0;
|
||||
|
||||
for (uint8_t index = 0; index <= 7; index++) {
|
||||
q = q * (x - T(2.0)) + Q[index];
|
||||
}
|
||||
|
||||
return z * p / q;
|
||||
} // T gamma_forward(T x)
|
||||
|
||||
template<typename T>
|
||||
static inline C10_HOST_DEVICE T hermite_polynomial_h_forward(T x, int64_t n) {
|
||||
if (n < 0) {
|
||||
|
||||
@ -76,6 +76,7 @@ CREATE_UNARY_FLOAT_META_FUNC(special_bessel_j0)
|
||||
CREATE_UNARY_FLOAT_META_FUNC(special_bessel_j1)
|
||||
CREATE_UNARY_FLOAT_META_FUNC(special_bessel_y0)
|
||||
CREATE_UNARY_FLOAT_META_FUNC(special_bessel_y1)
|
||||
CREATE_UNARY_FLOAT_META_FUNC(special_gamma)
|
||||
CREATE_UNARY_FLOAT_META_FUNC(special_modified_bessel_i0)
|
||||
CREATE_UNARY_FLOAT_META_FUNC(special_modified_bessel_i1)
|
||||
CREATE_UNARY_FLOAT_META_FUNC(special_modified_bessel_k0)
|
||||
@ -205,6 +206,7 @@ CREATE_UNARY_TORCH_IMPL_FUNC(special_bessel_j0_out, special_bessel_j0_stub)
|
||||
CREATE_UNARY_TORCH_IMPL_FUNC(special_bessel_j1_out, special_bessel_j1_stub)
|
||||
CREATE_UNARY_TORCH_IMPL_FUNC(special_bessel_y0_out, special_bessel_y0_stub)
|
||||
CREATE_UNARY_TORCH_IMPL_FUNC(special_bessel_y1_out, special_bessel_y1_stub)
|
||||
CREATE_UNARY_TORCH_IMPL_FUNC(special_gamma_out, special_gamma_stub)
|
||||
CREATE_UNARY_TORCH_IMPL_FUNC(special_modified_bessel_i0_out, special_modified_bessel_i0_stub)
|
||||
CREATE_UNARY_TORCH_IMPL_FUNC(special_modified_bessel_i1_out, special_modified_bessel_i1_stub)
|
||||
CREATE_UNARY_TORCH_IMPL_FUNC(special_modified_bessel_k0_out, special_modified_bessel_k0_stub)
|
||||
@ -888,6 +890,7 @@ DEFINE_DISPATCH(special_bessel_j0_stub); // NOLINT(cppcoreguidelines-avoid-non-c
|
||||
DEFINE_DISPATCH(special_bessel_j1_stub); // NOLINT(cppcoreguidelines-avoid-non-const-global-variables)
|
||||
DEFINE_DISPATCH(special_bessel_y0_stub); // NOLINT(cppcoreguidelines-avoid-non-const-global-variables)
|
||||
DEFINE_DISPATCH(special_bessel_y1_stub); // NOLINT(cppcoreguidelines-avoid-non-const-global-variables)
|
||||
DEFINE_DISPATCH(special_gamma_stub); // NOLINT(cppcoreguidelines-avoid-non-const-global-variables)
|
||||
DEFINE_DISPATCH(special_modified_bessel_i0_stub); // NOLINT(cppcoreguidelines-avoid-non-const-global-variables)
|
||||
DEFINE_DISPATCH(special_modified_bessel_i1_stub); // NOLINT(cppcoreguidelines-avoid-non-const-global-variables)
|
||||
DEFINE_DISPATCH(special_modified_bessel_k0_stub); // NOLINT(cppcoreguidelines-avoid-non-const-global-variables)
|
||||
|
||||
@ -75,6 +75,7 @@ DECLARE_DISPATCH(unary_fn, special_bessel_j0_stub);
|
||||
DECLARE_DISPATCH(unary_fn, special_bessel_j1_stub);
|
||||
DECLARE_DISPATCH(unary_fn, special_bessel_y0_stub);
|
||||
DECLARE_DISPATCH(unary_fn, special_bessel_y1_stub);
|
||||
DECLARE_DISPATCH(unary_fn, special_gamma_stub);
|
||||
DECLARE_DISPATCH(unary_fn, special_modified_bessel_i0_stub);
|
||||
DECLARE_DISPATCH(unary_fn, special_modified_bessel_i1_stub);
|
||||
DECLARE_DISPATCH(unary_fn, special_modified_bessel_k0_stub);
|
||||
|
||||
40
aten/src/ATen/native/cpu/gamma.cpp
Normal file
40
aten/src/ATen/native/cpu/gamma.cpp
Normal file
@ -0,0 +1,40 @@
|
||||
#define TORCH_ASSERT_NO_OPERATORS
|
||||
|
||||
#include <ATen/native/UnaryOps.h>
|
||||
|
||||
#include <cmath>
|
||||
#include <limits>
|
||||
#include <type_traits>
|
||||
|
||||
#include <ATen/Config.h>
|
||||
#include <ATen/Context.h>
|
||||
#include <ATen/Dispatch.h>
|
||||
#include <ATen/Parallel.h>
|
||||
#include <ATen/cpu/vec/functional.h>
|
||||
#include <ATen/cpu/vec/vec.h>
|
||||
#include <ATen/cpu/vml.h>
|
||||
#include <ATen/native/TensorIterator.h>
|
||||
#include <ATen/native/cpu/Loops.h>
|
||||
#include <ATen/native/cpu/zmath.h>
|
||||
#include <ATen/OpMathType.h>
|
||||
|
||||
#include <c10/util/math_compat.h>
|
||||
#include <c10/util/MathConstants.h>
|
||||
#include <c10/core/Scalar.h>
|
||||
#include <c10/util/irange.h>
|
||||
|
||||
namespace at {
|
||||
namespace native {
|
||||
static void gamma_kernel(TensorIteratorBase& iterator) {
|
||||
TORCH_INTERNAL_ASSERT(iterator.ntensors() == 2);
|
||||
|
||||
AT_DISPATCH_FLOATING_TYPES(iterator.common_dtype(), "gamma_cpu", [&]() {
|
||||
cpu_kernel(iterator, [](scalar_t x) {
|
||||
return gamma_forward(x);
|
||||
});
|
||||
});
|
||||
} // gamma_kernel(TensorIteratorBase& iterator)
|
||||
|
||||
REGISTER_DISPATCH(special_gamma_stub, &gamma_kernel);
|
||||
} // namespace native
|
||||
} // namespace at
|
||||
@ -2119,6 +2119,162 @@ const auto chebyshev_polynomial_w_string = jiterator_stringify(
|
||||
} // chebyshev_polynomial_w_forward(T x, T n)
|
||||
); // chebyshev_polynomial_w_string
|
||||
|
||||
const auto gamma_string = jiterator_stringify(
|
||||
template<typename T>
|
||||
T gamma_forward(T x) {
|
||||
static const T P[] = {
|
||||
+1.60119522476751861407e-4,
|
||||
+1.19135147006586384913e-3,
|
||||
+1.04213797561761569935e-2,
|
||||
+4.76367800457137231464e-2,
|
||||
+2.07448227648435975150e-1,
|
||||
+4.94214826801497100753e-1,
|
||||
+9.99999999999999996796e-1,
|
||||
};
|
||||
|
||||
static const T Q[] = {
|
||||
-2.31581873324120129819e-5,
|
||||
+5.39605580493303397842e-4,
|
||||
-4.45641913851797240494e-3,
|
||||
+1.18139785222060435552e-2,
|
||||
+3.58236398605498653373e-2,
|
||||
-2.34591795718243348568e-1,
|
||||
+7.14304917030273074085e-2,
|
||||
+1.00000000000000000320e+0,
|
||||
};
|
||||
|
||||
static const T R[] = {
|
||||
+7.87311395793093628397e-4,
|
||||
-2.29549961613378126380e-4,
|
||||
-2.68132617805781232825e-3,
|
||||
+3.47222221605458667310e-3,
|
||||
+8.33333333333482257126e-2,
|
||||
};
|
||||
|
||||
constexpr T PI = 3.14159265358979323846;
|
||||
|
||||
int sign_gamma = 1;
|
||||
|
||||
if (!isfinite(x)) {
|
||||
return x;
|
||||
}
|
||||
|
||||
if (abs(x) > T(33.0)) {
|
||||
if (x < T(0.0)) {
|
||||
T p = floor(abs(x));
|
||||
|
||||
if (p == abs(x)) {
|
||||
return INFINITY;
|
||||
}
|
||||
|
||||
int previous_p = p;
|
||||
|
||||
if ((previous_p & 1) == 0) {
|
||||
sign_gamma = -1;
|
||||
}
|
||||
|
||||
T z = abs(x) - p;
|
||||
|
||||
if (z > T(0.5)) {
|
||||
z = abs(x) - (p + T(1.0));
|
||||
}
|
||||
|
||||
z = abs(x) * sin(PI * z);
|
||||
|
||||
if (z == T(0.0)) {
|
||||
return sign_gamma * INFINITY;
|
||||
}
|
||||
|
||||
if (abs(x) >= T(171.624376956302725)) {
|
||||
return INFINITY;
|
||||
}
|
||||
|
||||
T r = 0.0;
|
||||
|
||||
for (uint8_t index = 0; index <= 4; index++) {
|
||||
r = r * (T(1.0) / abs(x)) + R[index];
|
||||
}
|
||||
|
||||
if (abs(x) > T(143.01608)) {
|
||||
return sign_gamma * PI / (abs(z) * (T(2.50662827463100050242) * (pow(abs(x), T(0.5) * abs(x) - T(0.25)) * (pow(abs(x), T(0.5) * abs(x) - T(0.25)) / exp(abs(x)))) * (T(1.0) + T(1.0) / abs(x) * r)));
|
||||
}
|
||||
|
||||
return sign_gamma * PI / (abs(z) * (T(2.50662827463100050242) * (pow(abs(x), abs(x) - T(0.5)) / exp(abs(x))) * (T(1.0) + T(1.0) / abs(x) * r)));
|
||||
}
|
||||
|
||||
if (x >= T(171.624376956302725)) {
|
||||
return INFINITY;
|
||||
}
|
||||
|
||||
T r = 0.0;
|
||||
|
||||
for (uint8_t index = 0; index <= 4; index++) {
|
||||
r = r * (T(1.0) / x) + R[index];
|
||||
}
|
||||
|
||||
if (x > T(143.01608)) {
|
||||
return sign_gamma * (T(2.50662827463100050242) * (pow(x, T(0.5) * x - T(0.25)) * (pow(x, T(0.5) * x - T(0.25)) / exp(x))) * (T(1.0) + T(1.0) / x * r));
|
||||
}
|
||||
|
||||
return sign_gamma * (T(2.50662827463100050242) * (pow(x, x - T(0.5)) / exp(x)) * (T(1.0) + T(1.0) / x * r));
|
||||
}
|
||||
|
||||
T z = 1.0;
|
||||
|
||||
while (x >= T(3.0)) {
|
||||
x = x - T(1.0);
|
||||
|
||||
z = z * x;
|
||||
}
|
||||
|
||||
while (x < T(0.0)) {
|
||||
if (x > -0.000000001) {
|
||||
if (x == T(0.0)) {
|
||||
return INFINITY;
|
||||
}
|
||||
|
||||
return z / ((T(1.0) + T(0.5772156649015329) * x) * x);
|
||||
}
|
||||
|
||||
z = z / x;
|
||||
|
||||
x = x + T(1.0);
|
||||
}
|
||||
|
||||
while (x < T(2.0)) {
|
||||
if (x < 0.000000001) {
|
||||
if (x == T(0.0)) {
|
||||
return INFINITY;
|
||||
}
|
||||
|
||||
return z / ((T(1.0) + T(0.5772156649015329) * x) * x);
|
||||
}
|
||||
|
||||
z = z / x;
|
||||
|
||||
x = x + T(1.0);
|
||||
}
|
||||
|
||||
if (x == T(2.0)) {
|
||||
return z;
|
||||
}
|
||||
|
||||
T p = 0.0;
|
||||
|
||||
for (uint8_t index = 0; index <= 6; index++) {
|
||||
p = p * (x - T(2.0)) + P[index];
|
||||
}
|
||||
|
||||
T q = 0.0;
|
||||
|
||||
for (uint8_t index = 0; index <= 7; index++) {
|
||||
q = q * (x - T(2.0)) + Q[index];
|
||||
}
|
||||
|
||||
return z * p / q;
|
||||
} // T gamma_forward(T x)
|
||||
); // gamma_string
|
||||
|
||||
const auto hermite_polynomial_h_string = jiterator_stringify(
|
||||
template<typename T>
|
||||
T hermite_polynomial_h_forward(T x, int64_t n) {
|
||||
|
||||
43
aten/src/ATen/native/cuda/gamma.cu
Normal file
43
aten/src/ATen/native/cuda/gamma.cu
Normal file
@ -0,0 +1,43 @@
|
||||
#define TORCH_ASSERT_NO_OPERATORS
|
||||
|
||||
#include <ATen/native/UnaryOps.h>
|
||||
|
||||
#include <limits>
|
||||
|
||||
#include <ATen/AccumulateType.h>
|
||||
#include <ATen/Dispatch.h>
|
||||
#include <ATen/native/DispatchStub.h>
|
||||
#include <ATen/native/Math.h>
|
||||
#include <ATen/native/TensorIterator.h>
|
||||
#include <ATen/native/cuda/JitLoops.cuh>
|
||||
#include <ATen/native/cuda/Loops.cuh>
|
||||
#include <ATen/native/cuda/Math.cuh>
|
||||
#include <ATen/native/cuda/jit_utils.h>
|
||||
#include <ATen/NumericUtils.h>
|
||||
#include <c10/core/Scalar.h>
|
||||
#include <c10/cuda/CUDAMathCompat.h>
|
||||
#include <c10/util/complex.h>
|
||||
|
||||
namespace at {
|
||||
namespace native {
|
||||
namespace {
|
||||
const char gamma_name[] = "gamma_forward";
|
||||
|
||||
void gamma_kernel_cuda(TensorIteratorBase& iterator) {
|
||||
#if AT_USE_JITERATOR()
|
||||
AT_DISPATCH_FLOATING_TYPES(iterator.common_dtype(), "gamma_cuda", [&]() {
|
||||
jitted_gpu_kernel<gamma_name, scalar_t, scalar_t, 1>(iterator, gamma_string);
|
||||
});
|
||||
#else
|
||||
AT_DISPATCH_FLOATING_TYPES(iterator.common_dtype(), "gamma_cuda", [&]() {
|
||||
gpu_kernel(iterator, []GPU_LAMBDA(scalar_t a) -> scalar_t {
|
||||
return gamma_forward(a);
|
||||
});
|
||||
});
|
||||
#endif // AT_USE_JITERATOR()
|
||||
}
|
||||
}
|
||||
|
||||
REGISTER_DISPATCH(special_gamma_stub, &gamma_kernel_cuda);
|
||||
} // namespace native
|
||||
} // namespace at
|
||||
@ -12625,6 +12625,19 @@
|
||||
python_module: special
|
||||
variants: function
|
||||
|
||||
- func: special_gamma(Tensor x) -> Tensor
|
||||
python_module: special
|
||||
structured_delegate: special_gamma.out
|
||||
variants: function
|
||||
|
||||
- func: special_gamma.out(Tensor x, *, Tensor(a!) out) -> Tensor(a!)
|
||||
dispatch:
|
||||
CPU, CUDA: special_gamma_out
|
||||
python_module: special
|
||||
structured_inherits: TensorIteratorBase
|
||||
structured: True
|
||||
variants: function
|
||||
|
||||
- func: special_hermite_polynomial_h(Tensor x, Tensor n) -> Tensor
|
||||
device_check: NoCheck
|
||||
python_module: special
|
||||
|
||||
@ -1150,6 +1150,7 @@ aten_native_source_codegen_list = [
|
||||
"aten/src/ATen/native/cpu/WeightNormKernel.cpp",
|
||||
"aten/src/ATen/native/cpu/airy_ai.cpp",
|
||||
"aten/src/ATen/native/cpu/batch_norm_kernel.cpp",
|
||||
"aten/src/ATen/native/cpu/gamma.cpp",
|
||||
"aten/src/ATen/native/cpu/group_norm_kernel.cpp",
|
||||
"aten/src/ATen/native/cpu/layer_norm_kernel.cpp",
|
||||
"aten/src/ATen/native/cpu/scaled_modified_bessel_k1.cpp",
|
||||
|
||||
@ -22,6 +22,7 @@ Functions
|
||||
.. autofunction:: exp2
|
||||
.. autofunction:: expit
|
||||
.. autofunction:: expm1
|
||||
.. autofunction:: gamma
|
||||
.. autofunction:: gammainc
|
||||
.. autofunction:: gammaincc
|
||||
.. autofunction:: gammaln
|
||||
|
||||
@ -2451,6 +2451,7 @@
|
||||
"exp2",
|
||||
"expit",
|
||||
"expm1",
|
||||
"gamma",
|
||||
"gammainc",
|
||||
"gammaincc",
|
||||
"gammaln",
|
||||
|
||||
@ -2853,6 +2853,9 @@
|
||||
- name: special_chebyshev_polynomial_w.n_scalar(Tensor x, Scalar n) -> Tensor
|
||||
x: non_differentiable
|
||||
|
||||
- name: special_gamma(Tensor x) -> Tensor
|
||||
x: non_differentiable
|
||||
|
||||
- name: special_hermite_polynomial_h(Tensor x, Tensor n) -> Tensor
|
||||
x: non_differentiable
|
||||
n: non_differentiable
|
||||
|
||||
@ -894,6 +894,25 @@ inline Tensor& chebyshev_polynomial_w_out(
|
||||
return torch::special_chebyshev_polynomial_w_out(output, x, n);
|
||||
}
|
||||
|
||||
/// Gamma function.
|
||||
///
|
||||
/// See https://pytorch.org/docs/master/special.html#torch.special.gamma.
|
||||
///
|
||||
/// Example:
|
||||
///
|
||||
/// ```
|
||||
/// auto x = torch::randn(128, dtype=kDouble);
|
||||
///
|
||||
/// torch::special::gamma(x);
|
||||
/// ```
|
||||
inline Tensor gamma(const Tensor& x) {
|
||||
return torch::special_gamma(x);
|
||||
}
|
||||
|
||||
inline Tensor& gamma_out(Tensor& y, const Tensor& x) {
|
||||
return torch::special_gamma_out(y, x);
|
||||
}
|
||||
|
||||
/// Physicist’s Hermite polynomial.
|
||||
///
|
||||
/// See
|
||||
|
||||
@ -1008,6 +1008,7 @@ def get_testing_overrides() -> Dict[Callable, Callable]:
|
||||
torch.special.exp2: lambda input: -1,
|
||||
torch.special.expit: lambda input: -1,
|
||||
torch.special.expm1: lambda input: -1,
|
||||
torch.special.gamma: lambda input: -1,
|
||||
torch.special.gammainc: lambda input, other, out=None: -1,
|
||||
torch.special.gammaincc: lambda input, other, out=None: -1,
|
||||
torch.special.gammaln: lambda input: -1,
|
||||
|
||||
@ -21,6 +21,7 @@ __all__ = [
|
||||
'exp2',
|
||||
'expit',
|
||||
'expm1',
|
||||
'gamma',
|
||||
'gammainc',
|
||||
'gammaincc',
|
||||
'gammaln',
|
||||
@ -1030,6 +1031,31 @@ Keyword args:
|
||||
{out}
|
||||
""".format(**common_args))
|
||||
|
||||
gamma = _add_docstr(_special.special_gamma,
|
||||
r"""
|
||||
gamma(input, *, out=None) -> Tensor
|
||||
|
||||
Gamma function :math:`\Gamma\left(\text{input}\right)` defined as the
|
||||
convergent improper integral:
|
||||
|
||||
.. math::
|
||||
\int_{0}^{\infty}x^{z - 1}e^{-x}dx
|
||||
|
||||
The gamma function is often referred to as the generalized factorial function
|
||||
since :math:`\Gamma\left(n + 1\right) = n!` for natural numbers
|
||||
:math:`n \in \mathbb{N}`. It satisfies the recurrence relation
|
||||
:math:`\Gamma\left(z + 1\right) = z \Gamma\left(z\right)` for complex
|
||||
:math:`z \in \mathbb{C}`, which, combined with the fact that
|
||||
:math:`\Gamma\left(1\right) = 1`, implies the above identity for :math:`z = n`.
|
||||
|
||||
""" + r"""
|
||||
Args:
|
||||
{input}
|
||||
|
||||
Keyword args:
|
||||
{out}
|
||||
""".format(**common_args))
|
||||
|
||||
hermite_polynomial_h = _add_docstr(_special.special_hermite_polynomial_h,
|
||||
r"""
|
||||
hermite_polynomial_h(input, n, *, out=None) -> Tensor
|
||||
|
||||
@ -19678,6 +19678,34 @@ op_db: List[OpInfo] = [
|
||||
supports_one_python_scalar=True,
|
||||
supports_autograd=False,
|
||||
),
|
||||
UnaryUfuncInfo(
|
||||
'special.gamma',
|
||||
decorators=(
|
||||
toleranceOverride(
|
||||
{
|
||||
torch.float32: tol(atol=1e-03, rtol=1.3e-05),
|
||||
torch.float64: tol(atol=1e-05, rtol=1.3e-05)
|
||||
}
|
||||
),
|
||||
),
|
||||
dtypes=all_types_and(torch.bool),
|
||||
ref=scipy.special.gamma if TEST_SCIPY else _NOTHING,
|
||||
skips=(
|
||||
DecorateInfo(
|
||||
unittest.skip("Skipped!"),
|
||||
'TestUnaryUfuncs',
|
||||
'test_reference_numerics_large',
|
||||
dtypes=[torch.float32, torch.float64],
|
||||
),
|
||||
DecorateInfo(
|
||||
unittest.skip("Skipped!"),
|
||||
'TestUnaryUfuncs',
|
||||
'test_reference_numerics_small',
|
||||
dtypes=[torch.uint8, torch.int8, torch.int16, torch.int32, torch.int64],
|
||||
),
|
||||
),
|
||||
supports_autograd=False,
|
||||
),
|
||||
BinaryUfuncInfo(
|
||||
'special.hermite_polynomial_h',
|
||||
dtypes=all_types_and(torch.bool),
|
||||
|
||||
Reference in New Issue
Block a user