Implement torch.igamma (#46183)

Summary:
Fixes https://github.com/pytorch/pytorch/issues/41637
This is the regularized lower incomplete gamma function, equivalent to SciPy's `gammainc` and TensorFlow's `igamma`.
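
For orientation, the regularized lower incomplete gamma function is P(a, x) = (1 / Gamma(a)) * integral from 0 to x of t^(a-1) * exp(-t) dt. The sketch below evaluates it with the same power series that `_igam_helper_series` in this diff uses. The name `igamma_series_sketch` is hypothetical and not part of the PR, and the sketch assumes a > 0 with small to moderate x >= 0. The PR itself switches between several expansions depending on the region.

```cpp
#include <cassert>
#include <cmath>
#include <limits>

// Hypothetical helper, not part of the PR. Evaluates
//   P(a, x) = x^a * exp(-x) / Gamma(a + 1) * sum_{n>=0} x^n / ((a+1)...(a+n))
// Assumption: a > 0 and x >= 0, with x not much larger than a, so the
// series converges quickly.
double igamma_series_sketch(double a, double x) {
  assert(a > 0.0 && x >= 0.0);
  if (x == 0.0) {
    return 0.0;
  }
  // Compute the prefactor in log space so that neither x^a nor Gamma(a + 1)
  // has to be formed explicitly.
  const double prefactor =
      std::exp(a * std::log(x) - x - std::lgamma(a + 1.0));
  double term = 1.0;  // n = 0 term of the sum
  double sum = 1.0;
  double r = a;
  for (int i = 0; i < 2000; ++i) {
    r += 1.0;
    term *= x / r;  // next term: previous term times x / (a + n)
    sum += term;
    if (term <= std::numeric_limits<double>::epsilon() * sum) {
      break;
    }
  }
  return prefactor * sum;
}
```

Two closed forms make convenient sanity checks: P(1, x) = 1 - exp(-x), and P(1/2, x) = erf(sqrt(x)).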

cc fritzo mruberry

Pull Request resolved: https://github.com/pytorch/pytorch/pull/46183

Reviewed By: gchanan

Differential Revision: D24479126

Pulled By: mruberry

fbshipit-source-id: fdf8ea289fe4ca1b408810732192411e948fcdfe
mfkasim91
2020-10-29 11:38:18 -07:00
committed by Facebook GitHub Bot
parent dd95bf65b6
commit 6eaa324c9f
26 changed files with 1590 additions and 8 deletions

106
NOTICE

@ -284,6 +284,112 @@ Apache License Version 2.0:
incurred by, or claims asserted against, such Contributor by reason
of your accepting any such warranty or additional liability.
=======================================================================
Cephes's 3-Clause BSD License
=======================================================================
Code derived from implementations in the Cephes Math Library should mention
its derivation and reference the following license:
3-Clause BSD License for the Cephes Math Library
Copyright (c) 2018, Steven Moshier
All rights reserved.
Redistribution and use in source and binary forms, with or without
modification, are permitted provided that the following conditions are met:
* Redistributions of source code must retain the above copyright
notice, this list of conditions and the following disclaimer.
* Redistributions in binary form must reproduce the above copyright
notice, this list of conditions and the following disclaimer in the
documentation and/or other materials provided with the distribution.
* Neither the name of the <organization> nor the
names of its contributors may be used to endorse or promote products
derived from this software without specific prior written permission.
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
DISCLAIMED. IN NO EVENT SHALL Steven Moshier BE LIABLE FOR ANY
DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
(INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND
ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
=======================================================================
SciPy's 3-Clause BSD License
=======================================================================
Code derived from implementations in SciPy should mention its derivation
and reference the following license:
Copyright (c) 2001-2002 Enthought, Inc. 2003-2019, SciPy Developers.
All rights reserved.
Redistribution and use in source and binary forms, with or without
modification, are permitted provided that the following conditions
are met:
1. Redistributions of source code must retain the above copyright
notice, this list of conditions and the following disclaimer.
2. Redistributions in binary form must reproduce the above
copyright notice, this list of conditions and the following
disclaimer in the documentation and/or other materials provided
with the distribution.
3. Neither the name of the copyright holder nor the names of its
contributors may be used to endorse or promote products derived
from this software without specific prior written permission.
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
"AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
=======================================================================
Boost's 1.0 Software License
=======================================================================
Code derived from implementations in Boost 1.0 should mention its
derivation and reference the following license:
Boost Software License - Version 1.0 - August 17th, 2003
Permission is hereby granted, free of charge, to any person or organization
obtaining a copy of the software and accompanying documentation covered by
this license (the "Software") to use, reproduce, display, distribute,
execute, and transmit the Software, and to prepare derivative works of the
Software, and to permit third-parties to whom the Software is furnished to
do so, all subject to the following:
The copyright notices in the Software and this entire statement, including
the above license grant, this restriction and the following disclaimer,
must be included in all copies of the Software, in whole or in part, and
all derivative works of the Software, unless such copies or derivative
works are solely in the form of machine-executable object code generated by
a source language processor.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE, TITLE AND NON-INFRINGEMENT. IN NO EVENT
SHALL THE COPYRIGHT HOLDERS OR ANYONE DISTRIBUTING THE SOFTWARE BE LIABLE
FOR ANY DAMAGES OR OTHER LIABILITY, WHETHER IN CONTRACT, TORT OR OTHERWISE,
ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
DEALINGS IN THE SOFTWARE.
END OF TERMS AND CONDITIONS
APPENDIX: How to apply the Apache License to your work.

@ -210,6 +210,9 @@ TORCH_LIBRARY_IMPL(aten, Named, m) {
m.impl("i0", CppFunction::makeFallthrough());
m.impl("i0.out", CppFunction::makeFallthrough());
m.impl("i0_", CppFunction::makeFallthrough());
m.impl("igamma", CppFunction::makeFallthrough());
m.impl("igamma.out", CppFunction::makeFallthrough());
m.impl("igamma_", CppFunction::makeFallthrough());
m.impl("imag", CppFunction::makeFallthrough());
m.impl("index_fill.Dimname_Scalar", CppFunction::makeFallthrough());
m.impl("index_fill.Dimname_Tensor", CppFunction::makeFallthrough());

@ -371,6 +371,8 @@ _(aten, hstack) \
_(aten, hypot) \
_(aten, i0) \
_(aten, i0_) \
_(aten, igamma) \
_(aten, igamma_) \
_(aten, ifft) \
_(aten, index) \
_(aten, index_add) \

@ -394,6 +394,13 @@ public:
Vec256<T> i0() const {
return map(calc_i0);
}
Vec256<T> igamma(const Vec256<T> &x) const {
Vec256<T> ret;
for (int64_t i = 0; i < size(); i++) {
ret[i] = calc_igamma(values[i], x[i]);
}
return ret;
}
Vec256<T> neg() const {
// NB: the trailing return type is needed because we need to coerce the
// return value back to T in the case of unary operator- incurring a

@ -290,6 +290,25 @@ public:
auto o2 = _mm256_loadu_ps(tmp2);
return cvtfp32_bf16(o1, o2);
}
Vec256<BFloat16> igamma(const Vec256<BFloat16> &x) const {
__m256 lo, hi;
__m256 xlo, xhi;
cvtbf16_fp32(values, lo, hi);
cvtbf16_fp32(x.values, xlo, xhi);
__at_align32__ float tmp1[size() / 2], tmp2[size() / 2];
_mm256_storeu_ps(reinterpret_cast<float*>(tmp1), lo);
_mm256_storeu_ps(reinterpret_cast<float*>(tmp2), hi);
__at_align32__ float tmpx1[size() / 2], tmpx2[size() / 2];
_mm256_storeu_ps(reinterpret_cast<float*>(tmpx1), xlo);
_mm256_storeu_ps(reinterpret_cast<float*>(tmpx2), xhi);
for (int64_t i = 0; i < size() / 2; ++i) {
tmp1[i] = calc_igamma(tmp1[i], tmpx1[i]);
tmp2[i] = calc_igamma(tmp2[i], tmpx2[i]);
}
auto o1 = _mm256_loadu_ps(tmp1);
auto o2 = _mm256_loadu_ps(tmp2);
return cvtfp32_bf16(o1, o2);
}
Vec256<BFloat16> log() const {
return map(Sleef_logf8_u10);
}

@ -252,6 +252,9 @@ public:
Vec256<c10::complex<double>> hypot(const Vec256<c10::complex<double>> &b) const {
AT_ERROR("not supported for complex numbers");
}
Vec256<c10::complex<double>> igamma(const Vec256<c10::complex<double>> &x) const {
AT_ERROR("not supported for complex numbers");
}
Vec256<c10::complex<double>> neg() const {
auto zero = _mm256_setzero_pd();
return _mm256_sub_pd(zero, values);

@ -290,6 +290,9 @@ public:
Vec256<c10::complex<float>> hypot(const Vec256<c10::complex<float>> &b) const {
AT_ERROR("not supported for complex numbers");
}
Vec256<c10::complex<float>> igamma(const Vec256<c10::complex<float>> &x) const {
AT_ERROR("not supported for complex numbers");
}
Vec256<c10::complex<float>> neg() const {
auto zero = _mm256_setzero_ps();
return _mm256_sub_ps(zero, values);

@ -155,6 +155,16 @@ public:
Vec256<double> i0() const {
return map(calc_i0);
}
Vec256<double> igamma(const Vec256<double> &x) const {
__at_align32__ double tmp[size()];
__at_align32__ double tmp_x[size()];
store(tmp);
x.store(tmp_x);
for (int64_t i = 0; i < size(); i++) {
tmp[i] = calc_igamma(tmp[i], tmp_x[i]);
}
return loadu(tmp);
}
Vec256<double> log() const {
return Vec256<double>(Sleef_logd4_u10(values));
}

@ -193,6 +193,16 @@ public:
Vec256<float> i0() const {
return map(calc_i0);
}
Vec256<float> igamma(const Vec256<float> &x) const {
__at_align32__ float tmp[size()];
__at_align32__ float tmp_x[size()];
store(tmp);
x.store(tmp_x);
for (int64_t i = 0; i < size(); i++) {
tmp[i] = calc_igamma(tmp[i], tmp_x[i]);
}
return loadu(tmp);
}
Vec256<float> neg() const {
return _mm256_xor_ps(_mm256_set1_ps(-0.f), values);
}

@ -362,6 +362,16 @@ public:
Vec256<float> i0() const {
return map(calc_i0);
}
Vec256<float> igamma(const Vec256<float> &x) const {
__at_align32__ float tmp[size()];
__at_align32__ float tmp_x[size()];
store(tmp);
x.store(tmp_x);
for (int64_t i = 0; i < size(); i++) {
tmp[i] = calc_igamma(tmp[i], tmp_x[i]);
}
return loadu(tmp);
}
Vec256<float> log() const {
return map(std::log);
}
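
The `igamma` overloads in the vector headers above share one pattern: both operands are stored to temporary arrays, the scalar `calc_igamma` is applied lane by lane, and the result is loaded back into a vector. The following is a minimal sketch of that pattern only. `Vec4` and `binary_scalar_fallback` are hypothetical stand-ins rather than ATen's `Vec256` API, and the scalar function is passed in by the caller instead of being `calc_igamma`.

```cpp
#include <array>
#include <cassert>
#include <cmath>
#include <cstddef>

// Hypothetical stand-in for a SIMD vector type; this is not ATen's Vec256.
struct Vec4 {
  static constexpr std::size_t size() { return 4; }
  std::array<double, 4> values{};

  // Copy the lanes out to a plain array.
  void store(double* out) const {
    assert(out != nullptr);
    for (std::size_t i = 0; i < size(); ++i) {
      out[i] = values[i];
    }
  }

  // Build a vector from a plain array (no alignment requirement).
  static Vec4 loadu(const double* in) {
    assert(in != nullptr);
    Vec4 v;
    for (std::size_t i = 0; i < size(); ++i) {
      v.values[i] = in[i];
    }
    return v;
  }

  // Apply a scalar binary function lane by lane through temporary buffers,
  // mirroring the store / loop / loadu shape of the igamma overloads.
  template <typename F>
  Vec4 binary_scalar_fallback(const Vec4& x, F scalar_fn) const {
    alignas(32) double tmp[size()];
    alignas(32) double tmp_x[size()];
    store(tmp);
    x.store(tmp_x);
    for (std::size_t i = 0; i < size(); ++i) {
      tmp[i] = scalar_fn(tmp[i], tmp_x[i]);
    }
    return loadu(tmp);
  }
};
```

This fallback gains nothing from SIMD for the function itself. It only lets a scalar kernel plug into code that is written against the vector interface.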

@ -46,6 +46,7 @@ DEFINE_DISPATCH(logaddexp2_stub);
DEFINE_DISPATCH(gcd_stub);
DEFINE_DISPATCH(lcm_stub);
DEFINE_DISPATCH(hypot_stub);
DEFINE_DISPATCH(igamma_stub);
DEFINE_DISPATCH(nextafter_stub);
DEFINE_DISPATCH(heaviside_stub);
@ -968,6 +969,23 @@ Tensor& hypot_(Tensor& self, const Tensor& other) {
return at::hypot_out(self, self, other);
}
Tensor& igamma_out(Tensor& result, const Tensor& self, const Tensor& other) {
auto iter = TensorIterator::binary_op(result, self, other);
igamma_stub(iter.device_type(), iter);
return result;
}
Tensor igamma(const Tensor& self, const Tensor& other) {
Tensor result;
auto iter = TensorIterator::binary_op(result, self, other);
igamma_stub(iter.device_type(), iter);
return iter.output();
}
Tensor& igamma_(Tensor& self, const Tensor& other) {
return at::igamma_out(self, self, other);
}
Tensor& nextafter_out(Tensor& result, const Tensor& self, const Tensor& other) {
auto iter = TensorIterator::binary_op(result, self, other);
nextafter_stub(iter.device_type(), iter);

@ -10,7 +10,7 @@ namespace at { namespace native {
inline void alpha_check(const ScalarType dtype, Scalar alpha) {
TORCH_CHECK(! alpha.isBoolean() || dtype == ScalarType::Bool,
"Boolean alpha only supported for Boolean results.");
TORCH_CHECK(isFloatingType(dtype) || isComplexType(dtype)
|| alpha.isIntegral(true),
"For integral input tensors, argument alpha must not be a floating point number.");
}
@ -68,6 +68,7 @@ DECLARE_DISPATCH(binary_fn, logaddexp2_stub);
DECLARE_DISPATCH(binary_fn, gcd_stub);
DECLARE_DISPATCH(binary_fn, lcm_stub);
DECLARE_DISPATCH(binary_fn, hypot_stub);
DECLARE_DISPATCH(binary_fn, igamma_stub);
DECLARE_DISPATCH(binary_fn, nextafter_stub);
DECLARE_DISPATCH(binary_fn, heaviside_stub);

@ -381,6 +381,716 @@ static inline float calc_polygamma(int64_t n, float x) {
zeta(double(n + 1), x);
}
// regularized lower incomplete gamma
// the regularized lower, upper incomplete gamma, as well as their
// helper functions follow SciPy's implementation
/* References
* [igam1] "The Digital Library of Mathematical Functions", dlmf.nist.gov
* [igam2] Maddock et al., "Incomplete Gamma Functions",
* https://www.boost.org/doc/libs/1_61_0/libs/math/doc/html/math_toolkit/sf_gamma/igamma.html
*/
/*
* The implementations of the regularized incomplete gamma functions and
* their helper functions are derived from SciPy's gammainc, Cephes's
* igam and igamc, and Boost's Lanczos approximations.
* See NOTICE for the licenses.
*/
template <typename scalar_t>
static scalar_t ratevl(scalar_t x, const scalar_t num[], int64_t M,
const scalar_t denom[], int64_t N) {
// evaluate a rational function, i.e., the ratio of two polynomials;
// the coefficients of the numerator are given by `num`, while those of
// the denominator are given by `denom`
int64_t i, dir;
scalar_t y, num_ans, denom_ans;
scalar_t absx = std::fabs(x);
const scalar_t *p;
if (absx > 1) {
/* Evaluate as a polynomial in 1/x. */
dir = -1;
p = num + M;
y = 1 / x;
}
else {
dir = 1;
p = num;
y = x;
}
/* Evaluate the numerator */
num_ans = *p;
p += dir;
for (i = 1; i <= M; i++) {
num_ans = num_ans * y + *p;
p += dir;
}
/* Evaluate the denominator */
if (absx > 1) {
p = denom + N;
}
else {
p = denom;
}
denom_ans = *p;
p += dir;
for (i = 1; i <= N; i++) {
denom_ans = denom_ans * y + *p;
p += dir;
}
if (absx > 1) {
i = N - M;
return std::pow(x, i) * num_ans / denom_ans;
}
else {
return num_ans / denom_ans;
}
}
// SciPy's lanczos implementation is taken from Boost
/* (C) Copyright John Maddock 2006.
* Use, modification and distribution are subject to the
* Boost Software License, Version 1.0. See
* https://www.boost.org/LICENSE_1_0.txt or see NOTICE.
*/
template <typename scalar_t>
static scalar_t lanczos_sum_expg_scaled(scalar_t x) {
// lanczos approximation
static const scalar_t lanczos_sum_expg_scaled_num[13] = {
0.006061842346248906525783753964555936883222,
0.5098416655656676188125178644804694509993,
19.51992788247617482847860966235652136208,
449.9445569063168119446858607650988409623,
6955.999602515376140356310115515198987526,
75999.29304014542649875303443598909137092,
601859.6171681098786670226533699352302507,
3481712.15498064590882071018964774556468,
14605578.08768506808414169982791359218571,
43338889.32467613834773723740590533316085,
86363131.28813859145546927288977868422342,
103794043.1163445451906271053616070238554,
56906521.91347156388090791033559122686859
};
static const scalar_t lanczos_sum_expg_scaled_denom[13] = {
1.,
66.,
1925.,
32670.,
357423.,
2637558.,
13339535.,
45995730.,
105258076.,
150917976.,
120543840.,
39916800.,
0.
};
return ratevl(x, lanczos_sum_expg_scaled_num,
sizeof(lanczos_sum_expg_scaled_num) / sizeof(lanczos_sum_expg_scaled_num[0]) - 1,
lanczos_sum_expg_scaled_denom,
sizeof(lanczos_sum_expg_scaled_denom) / sizeof(lanczos_sum_expg_scaled_denom[0]) - 1);
}
template <typename scalar_t>
static scalar_t _igam_helper_fac(scalar_t a, scalar_t x) {
// compute x^a * exp(-x) / gamma(a)
// corrected from (15) and (16) in [igam2] by replacing exp(x - a) with
// exp(a - x).
scalar_t ax, fac, res, num, numfac;
static scalar_t MAXLOG = std::is_same<scalar_t,double>::value ?
7.09782712893383996843E2 : 88.72283905206835;
static scalar_t EXP1 = 2.718281828459045;
static scalar_t lanczos_g = 6.024680040776729583740234375;
if (std::fabs(a - x) > 0.4 * std::fabs(a)) {
ax = a * std::log(x) - x - std::lgamma(a);
if (ax < -MAXLOG) {
return 0.0;
}
return std::exp(ax);
}
fac = a + lanczos_g - 0.5;
res = std::sqrt(fac / EXP1) / lanczos_sum_expg_scaled(a);
if ((a < 200) && (x < 200)) {
res *= std::exp(a - x) * std::pow(x / fac, a);
}
else {
num = x - a - lanczos_g + 0.5;
numfac = num / fac;
res *= std::exp(a * (std::log1p(numfac) - numfac) + x * (0.5 - lanczos_g) / fac);
}
return res;
}
template <typename scalar_t>
static scalar_t _igam_helper_series(scalar_t a, scalar_t x) {
// Compute igam using DLMF 8.11.4. [igam1]
static scalar_t MACHEP = std::is_same<scalar_t, double>::value ?
1.11022302462515654042E-16 : 5.9604644775390625E-8;
static int MAXITER = 2000;
int i;
scalar_t ans, ax, c, r;
ax = _igam_helper_fac(a, x);
if (ax == 0.0) {
return 0.0;
}
/* power series */
r = a;
c = 1.0;
ans = 1.0;
for (i = 0; i < MAXITER; i++) {
r += 1.0;
c *= x / r;
ans += c;
if (c <= MACHEP * ans) {
break;
}
}
return (ans * ax / a);
}
template <typename scalar_t>
static scalar_t _igamc_helper_series(scalar_t a, scalar_t x) {
// Compute igamc using DLMF 8.7.3 [igam1]. This is related to the series in
// _igam_helper_series but extra care is taken to avoid cancellation.
int n;
scalar_t fac = 1;
scalar_t sum = 0;
scalar_t term, logx;
static scalar_t MAXITER = 2000;
static scalar_t MACHEP = std::is_same<scalar_t, double>::value ?
1.11022302462515654042E-16 : 5.9604644775390625E-8;
for (n = 1; n < MAXITER; n++) {
fac *= -x / n;
term = fac / (a + n);
sum += term;
if (std::fabs(term) <= MACHEP * std::fabs(sum)) {
break;
}
}
logx = std::log(x);
term = -std::expm1(a * logx - std::lgamma(1+a));
return term - std::exp(a * logx - std::lgamma(a)) * sum;
}
template <typename scalar_t>
static scalar_t _igam_helper_asymptotic_series(scalar_t a, scalar_t x, bool igam) {
// Compute igam/igamc using DLMF 8.12.3/8.12.4 [igam1]
static const scalar_t d[25][25] =
{{-3.3333333333333333e-1, 8.3333333333333333e-2, -1.4814814814814815e-2,
1.1574074074074074e-3, 3.527336860670194e-4, -1.7875514403292181e-4,
3.9192631785224378e-5, -2.1854485106799922e-6, -1.85406221071516e-6,
8.296711340953086e-7, -1.7665952736826079e-7, 6.7078535434014986e-9,
1.0261809784240308e-8, -4.3820360184533532e-9, 9.1476995822367902e-10,
-2.551419399494625e-11, -5.8307721325504251e-11, 2.4361948020667416e-11,
-5.0276692801141756e-12, 1.1004392031956135e-13, 3.3717632624009854e-13,
-1.3923887224181621e-13, 2.8534893807047443e-14, -5.1391118342425726e-16,
-1.9752288294349443e-15},
{-1.8518518518518519e-3, -3.4722222222222222e-3, 2.6455026455026455e-3,
-9.9022633744855967e-4, 2.0576131687242798e-4, -4.0187757201646091e-7,
-1.8098550334489978e-5, 7.6491609160811101e-6, -1.6120900894563446e-6,
4.6471278028074343e-9, 1.378633446915721e-7, -5.752545603517705e-8,
1.1951628599778147e-8, -1.7543241719747648e-11, -1.0091543710600413e-9,
4.1627929918425826e-10, -8.5639070264929806e-11, 6.0672151016047586e-14,
7.1624989648114854e-12, -2.9331866437714371e-12, 5.9966963656836887e-13,
-2.1671786527323314e-16, -4.9783399723692616e-14, 2.0291628823713425e-14,
-4.13125571381061e-15},
{4.1335978835978836e-3, -2.6813271604938272e-3, 7.7160493827160494e-4,
2.0093878600823045e-6, -1.0736653226365161e-4, 5.2923448829120125e-5,
-1.2760635188618728e-5, 3.4235787340961381e-8, 1.3721957309062933e-6,
-6.298992138380055e-7, 1.4280614206064242e-7, -2.0477098421990866e-10,
-1.4092529910867521e-8, 6.228974084922022e-9, -1.3670488396617113e-9,
9.4283561590146782e-13, 1.2872252400089318e-10, -5.5645956134363321e-11,
1.1975935546366981e-11, -4.1689782251838635e-15, -1.0940640427884594e-12,
4.6622399463901357e-13, -9.905105763906906e-14, 1.8931876768373515e-17,
8.8592218725911273e-15},
{6.4943415637860082e-4, 2.2947209362139918e-4, -4.6918949439525571e-4,
2.6772063206283885e-4, -7.5618016718839764e-5, -2.3965051138672967e-7,
1.1082654115347302e-5, -5.6749528269915966e-6, 1.4230900732435884e-6,
-2.7861080291528142e-11, -1.6958404091930277e-7, 8.0994649053880824e-8,
-1.9111168485973654e-8, 2.3928620439808118e-12, 2.0620131815488798e-9,
-9.4604966618551322e-10, 2.1541049775774908e-10, -1.388823336813903e-14,
-2.1894761681963939e-11, 9.7909989511716851e-12, -2.1782191880180962e-12,
6.2088195734079014e-17, 2.126978363279737e-13, -9.3446887915174333e-14,
2.0453671226782849e-14},
{-8.618882909167117e-4, 7.8403922172006663e-4, -2.9907248030319018e-4,
-1.4638452578843418e-6, 6.6414982154651222e-5, -3.9683650471794347e-5,
1.1375726970678419e-5, 2.5074972262375328e-10, -1.6954149536558306e-6,
8.9075075322053097e-7, -2.2929348340008049e-7, 2.956794137544049e-11,
2.8865829742708784e-8, -1.4189739437803219e-8, 3.4463580499464897e-9,
-2.3024517174528067e-13, -3.9409233028046405e-10, 1.8602338968504502e-10,
-4.356323005056618e-11, 1.2786001016296231e-15, 4.6792750266579195e-12,
-2.1492464706134829e-12, 4.9088156148096522e-13, -6.3385914848915603e-18,
-5.0453320690800944e-14},
{-3.3679855336635815e-4, -6.9728137583658578e-5, 2.7727532449593921e-4,
-1.9932570516188848e-4, 6.7977804779372078e-5, 1.419062920643967e-7,
-1.3594048189768693e-5, 8.0184702563342015e-6, -2.2914811765080952e-6,
-3.252473551298454e-10, 3.4652846491085265e-7, -1.8447187191171343e-7,
4.8240967037894181e-8, -1.7989466721743515e-14, -6.3061945000135234e-9,
3.1624176287745679e-9, -7.8409242536974293e-10, 5.1926791652540407e-15,
9.3589442423067836e-11, -4.5134262161632782e-11, 1.0799129993116827e-11,
-3.661886712685252e-17, -1.210902069055155e-12, 5.6807435849905643e-13,
-1.3249659916340829e-13},
{5.3130793646399222e-4, -5.9216643735369388e-4, 2.7087820967180448e-4,
7.9023532326603279e-7, -8.1539693675619688e-5, 5.6116827531062497e-5,
-1.8329116582843376e-5, -3.0796134506033048e-9, 3.4651553688036091e-6,
-2.0291327396058604e-6, 5.7887928631490037e-7, 2.338630673826657e-13,
-8.8286007463304835e-8, 4.7435958880408128e-8, -1.2545415020710382e-8,
8.6496488580102925e-14, 1.6846058979264063e-9, -8.5754928235775947e-10,
2.1598224929232125e-10, -7.6132305204761539e-16, -2.6639822008536144e-11,
1.3065700536611057e-11, -3.1799163902367977e-12, 4.7109761213674315e-18,
3.6902800842763467e-13},
{3.4436760689237767e-4, 5.1717909082605922e-5, -3.3493161081142236e-4,
2.812695154763237e-4, -1.0976582244684731e-4, -1.2741009095484485e-7,
2.7744451511563644e-5, -1.8263488805711333e-5, 5.7876949497350524e-6,
4.9387589339362704e-10, -1.0595367014026043e-6, 6.1667143761104075e-7,
-1.7562973359060462e-7, -1.2974473287015439e-12, 2.695423606288966e-8,
-1.4578352908731271e-8, 3.887645959386175e-9, -3.8810022510194121e-17,
-5.3279941738772867e-10, 2.7437977643314845e-10, -6.9957960920705679e-11,
2.5899863874868481e-17, 8.8566890996696381e-12, -4.403168815871311e-12,
1.0865561947091654e-12},
{-6.5262391859530942e-4, 8.3949872067208728e-4, -4.3829709854172101e-4,
-6.969091458420552e-7, 1.6644846642067548e-4, -1.2783517679769219e-4,
4.6299532636913043e-5, 4.5579098679227077e-9, -1.0595271125805195e-5,
6.7833429048651666e-6, -2.1075476666258804e-6, -1.7213731432817145e-11,
3.7735877416110979e-7, -2.1867506700122867e-7, 6.2202288040189269e-8,
6.5977038267330006e-16, -9.5903864974256858e-9, 5.2132144922808078e-9,
-1.3991589583935709e-9, 5.382058999060575e-16, 1.9484714275467745e-10,
-1.0127287556389682e-10, 2.6077347197254926e-11, -5.0904186999932993e-18,
-3.3721464474854592e-12},
{-5.9676129019274625e-4, -7.2048954160200106e-5, 6.7823088376673284e-4,
-6.4014752602627585e-4, 2.7750107634328704e-4, 1.8197008380465151e-7,
-8.4795071170685032e-5, 6.105192082501531e-5, -2.1073920183404862e-5,
-8.8585890141255994e-10, 4.5284535953805377e-6, -2.8427815022504408e-6,
8.7082341778646412e-7, 3.6886101871706965e-12, -1.5344695190702061e-7,
8.862466778790695e-8, -2.5184812301826817e-8, -1.0225912098215092e-14,
3.8969470758154777e-9, -2.1267304792235635e-9, 5.7370135528051385e-10,
-1.887749850169741e-19, -8.0931538694657866e-11, 4.2382723283449199e-11,
-1.1002224534207726e-11},
{1.3324454494800656e-3, -1.9144384985654775e-3, 1.1089369134596637e-3,
9.932404122642299e-7, -5.0874501293093199e-4, 4.2735056665392884e-4,
-1.6858853767910799e-4, -8.1301893922784998e-9, 4.5284402370562147e-5,
-3.127053674781734e-5, 1.044986828530338e-5, 4.8435226265680926e-11,
-2.1482565873456258e-6, 1.329369701097492e-6, -4.0295693092101029e-7,
-1.7567877666323291e-13, 7.0145043163668257e-8, -4.040787734999483e-8,
1.1474026743371963e-8, 3.9642746853563325e-18, -1.7804938269892714e-9,
9.7480262548731646e-10, -2.6405338676507616e-10, 5.794875163403742e-18,
3.7647749553543836e-11},
{1.579727660730835e-3, 1.6251626278391582e-4, -2.0633421035543276e-3,
2.1389686185689098e-3, -1.0108559391263003e-3, -3.9912705529919201e-7,
3.6235025084764691e-4, -2.8143901463712154e-4, 1.0449513336495887e-4,
2.1211418491830297e-9, -2.5779417251947842e-5, 1.7281818956040463e-5,
-5.6413773872904282e-6, -1.1024320105776174e-11, 1.1223224418895175e-6,
-6.8693396379526735e-7, 2.0653236975414887e-7, 4.6714772409838506e-14,
-3.5609886164949055e-8, 2.0470855345905963e-8, -5.8091738633283358e-9,
-1.332821287582869e-16, 9.0354604391335133e-10, -4.9598782517330834e-10,
1.3481607129399749e-10},
{-4.0725121195140166e-3, 6.4033628338080698e-3, -4.0410161081676618e-3,
-2.183732802866233e-6, 2.1740441801254639e-3, -1.9700440518418892e-3,
8.3595469747962458e-4, 1.9445447567109655e-8, -2.5779387120421696e-4,
1.9009987368139304e-4, -6.7696499937438965e-5, -1.4440629666426572e-10,
1.5712512518742269e-5, -1.0304008744776893e-5, 3.304517767401387e-6,
7.9829760242325709e-13, -6.4097794149313004e-7, 3.8894624761300056e-7,
-1.1618347644948869e-7, -2.816808630596451e-15, 1.9878012911297093e-8,
-1.1407719956357511e-8, 3.2355857064185555e-9, 4.1759468293455945e-20,
-5.0423112718105824e-10},
{-5.9475779383993003e-3, -5.4016476789260452e-4, 8.7910413550767898e-3,
-9.8576315587856125e-3, 5.0134695031021538e-3, 1.2807521786221875e-6,
-2.0626019342754683e-3, 1.7109128573523058e-3, -6.7695312714133799e-4,
-6.9011545676562133e-9, 1.8855128143995902e-4, -1.3395215663491969e-4,
4.6263183033528039e-5, 4.0034230613321351e-11, -1.0255652921494033e-5,
6.612086372797651e-6, -2.0913022027253008e-6, -2.0951775649603837e-13,
3.9756029041993247e-7, -2.3956211978815887e-7, 7.1182883382145864e-8,
8.925574873053455e-16, -1.2101547235064676e-8, 6.9350618248334386e-9,
-1.9661464453856102e-9},
{1.7402027787522711e-2, -2.9527880945699121e-2, 2.0045875571402799e-2,
7.0289515966903407e-6, -1.2375421071343148e-2, 1.1976293444235254e-2,
-5.4156038466518525e-3, -6.3290893396418616e-8, 1.8855118129005065e-3,
-1.473473274825001e-3, 5.5515810097708387e-4, 5.2406834412550662e-10,
-1.4357913535784836e-4, 9.9181293224943297e-5, -3.3460834749478311e-5,
-3.5755837291098993e-12, 7.1560851960630076e-6, -4.5516802628155526e-6,
1.4236576649271475e-6, 1.8803149082089664e-14, -2.6623403898929211e-7,
1.5950642189595716e-7, -4.7187514673841102e-8, -6.5107872958755177e-17,
7.9795091026746235e-9},
{3.0249124160905891e-2, 2.4817436002649977e-3, -4.9939134373457022e-2,
5.9915643009307869e-2, -3.2483207601623391e-2, -5.7212968652103441e-6,
1.5085251778569354e-2, -1.3261324005088445e-2, 5.5515262632426148e-3,
3.0263182257030016e-8, -1.7229548406756723e-3, 1.2893570099929637e-3,
-4.6845138348319876e-4, -1.830259937893045e-10, 1.1449739014822654e-4,
-7.7378565221244477e-5, 2.5625836246985201e-5, 1.0766165333192814e-12,
-5.3246809282422621e-6, 3.349634863064464e-6, -1.0381253128684018e-6,
-5.608909920621128e-15, 1.9150821930676591e-7, -1.1418365800203486e-7,
3.3654425209171788e-8},
{-9.9051020880159045e-2, 1.7954011706123486e-1, -1.2989606383463778e-1,
-3.1478872752284357e-5, 9.0510635276848131e-2, -9.2828824411184397e-2,
4.4412112839877808e-2, 2.7779236316835888e-7, -1.7229543805449697e-2,
1.4182925050891573e-2, -5.6214161633747336e-3, -2.39598509186381e-9,
1.6029634366079908e-3, -1.1606784674435773e-3, 4.1001337768153873e-4,
1.8365800754090661e-11, -9.5844256563655903e-5, 6.3643062337764708e-5,
-2.076250624489065e-5, -1.1806020912804483e-13, 4.2131808239120649e-6,
-2.6262241337012467e-6, 8.0770620494930662e-7, 6.0125912123632725e-16,
-1.4729737374018841e-7},
{-1.9994542198219728e-1, -1.5056113040026424e-2, 3.6470239469348489e-1,
-4.6435192311733545e-1, 2.6640934719197893e-1, 3.4038266027147191e-5,
-1.3784338709329624e-1, 1.276467178337056e-1, -5.6213828755200985e-2,
-1.753150885483011e-7, 1.9235592956768113e-2, -1.5088821281095315e-2,
5.7401854451350123e-3, 1.0622382710310225e-9, -1.5335082692563998e-3,
1.0819320643228214e-3, -3.7372510193945659e-4, -6.6170909729031985e-12,
8.4263617380909628e-5, -5.5150706827483479e-5, 1.7769536448348069e-5,
3.8827923210205533e-14, -3.53513697488768e-6, 2.1865832130045269e-6,
-6.6812849447625594e-7},
{7.2438608504029431e-1, -1.3918010932653375, 1.0654143352413968,
1.876173868950258e-4, -8.2705501176152696e-1, 8.9352433347828414e-1,
-4.4971003995291339e-1, -1.6107401567546652e-6, 1.9235590165271091e-1,
-1.6597702160042609e-1, 6.8882222681814333e-2, 1.3910091724608687e-8,
-2.146911561508663e-2, 1.6228980898865892e-2, -5.9796016172584256e-3,
-1.1287469112826745e-10, 1.5167451119784857e-3, -1.0478634293553899e-3,
3.5539072889126421e-4, 8.1704322111801517e-13, -7.7773013442452395e-5,
5.0291413897007722e-5, -1.6035083867000518e-5, 1.2469354315487605e-14,
3.1369106244517615e-6},
{1.6668949727276811, 1.165462765994632e-1, -3.3288393225018906,
4.4692325482864037, -2.6977693045875807, -2.600667859891061e-4,
1.5389017615694539, -1.4937962361134612, 6.8881964633233148e-1,
1.3077482004552385e-6, -2.5762963325596288e-1, 2.1097676102125449e-1,
-8.3714408359219882e-2, -7.7920428881354753e-9, 2.4267923064833599e-2,
-1.7813678334552311e-2, 6.3970330388900056e-3, 4.9430807090480523e-11,
-1.5554602758465635e-3, 1.0561196919903214e-3, -3.5277184460472902e-4,
9.3002334645022459e-14, 7.5285855026557172e-5, -4.8186515569156351e-5,
1.5227271505597605e-5},
{-6.6188298861372935, 1.3397985455142589e+1, -1.0789350606845146e+1,
-1.4352254537875018e-3, 9.2333694596189809, -1.0456552819547769e+1,
5.5105526029033471, 1.2024439690716742e-5, -2.5762961164755816,
2.3207442745387179, -1.0045728797216284, -1.0207833290021914e-7,
3.3975092171169466e-1, -2.6720517450757468e-1, 1.0235252851562706e-1,
8.4329730484871625e-10, -2.7998284958442595e-2, 2.0066274144976813e-2,
-7.0554368915086242e-3, 1.9402238183698188e-12, 1.6562888105449611e-3,
-1.1082898580743683e-3, 3.654545161310169e-4, -5.1290032026971794e-11,
-7.6340103696869031e-5},
{-1.7112706061976095e+1, -1.1208044642899116, 3.7131966511885444e+1,
-5.2298271025348962e+1, 3.3058589696624618e+1, 2.4791298976200222e-3,
-2.061089403411526e+1, 2.088672775145582e+1, -1.0045703956517752e+1,
-1.2238783449063012e-5, 4.0770134274221141, -3.473667358470195,
1.4329352617312006, 7.1359914411879712e-8, -4.4797257159115612e-1,
3.4112666080644461e-1, -1.2699786326594923e-1, -2.8953677269081528e-10,
3.3125776278259863e-2, -2.3274087021036101e-2, 8.0399993503648882e-3,
-1.177805216235265e-9, -1.8321624891071668e-3, 1.2108282933588665e-3,
-3.9479941246822517e-4},
{7.389033153567425e+1, -1.5680141270402273e+2, 1.322177542759164e+2,
1.3692876877324546e-2, -1.2366496885920151e+2, 1.4620689391062729e+2,
-8.0365587724865346e+1, -1.1259851148881298e-4, 4.0770132196179938e+1,
-3.8210340013273034e+1, 1.719522294277362e+1, 9.3519707955168356e-7,
-6.2716159907747034, 5.1168999071852637, -2.0319658112299095,
-4.9507215582761543e-9, 5.9626397294332597e-1, -4.4220765337238094e-1,
1.6079998700166273e-1, -2.4733786203223402e-8, -4.0307574759979762e-2,
2.7849050747097869e-2, -9.4751858992054221e-3, 6.419922235909132e-6,
2.1250180774699461e-3},
{2.1216837098382522e+2, 1.3107863022633868e+1, -4.9698285932871748e+2,
7.3121595266969204e+2, -4.8213821720890847e+2, -2.8817248692894889e-2,
3.2616720302947102e+2, -3.4389340280087117e+2, 1.7195193870816232e+2,
1.4038077378096158e-4, -7.52594195897599e+1, 6.651969984520934e+1,
-2.8447519748152462e+1, -7.613702615875391e-7, 9.5402237105304373,
-7.5175301113311376, 2.8943997568871961, -4.6612194999538201e-7,
-8.0615149598794088e-1, 5.8483006570631029e-1, -2.0845408972964956e-1,
1.4765818959305817e-4, 5.1000433863753019e-2, -3.3066252141883665e-2,
1.5109265210467774e-2},
{-9.8959643098322368e+2, 2.1925555360905233e+3, -1.9283586782723356e+3,
-1.5925738122215253e-1, 1.9569985945919857e+3, -2.4072514765081556e+3,
1.3756149959336496e+3, 1.2920735237496668e-3, -7.525941715948055e+2,
7.3171668742208716e+2, -3.4137023466220065e+2, -9.9857390260608043e-6,
1.3356313181291573e+2, -1.1276295161252794e+2, 4.6310396098204458e+1,
-7.9237387133614756e-6, -1.4510726927018646e+1, 1.1111771248100563e+1,
-4.1690817945270892, 3.1008219800117808e-3, 1.1220095449981468,
-7.6052379926149916e-1, 3.6262236505085254e-1, 2.216867741940747e-1,
4.8683443692930507e-1}};
int k, n, sgn;
int maxpow = 0;
static scalar_t MACHEP = std::is_same<scalar_t, double>::value ?
1.11022302462515654042E-16 : 5.9604644775390625E-8;
scalar_t lambda = x / a;
scalar_t sigma = (x - a) / a;
scalar_t eta, res, ck, ckterm, term, absterm;
scalar_t absoldterm = INFINITY;
scalar_t etapow[25] = {1};
scalar_t sum = 0;
scalar_t afac = 1;
if (igam) {
sgn = -1;
}
else {
sgn = 1;
}
if (lambda > 1) {
eta = std::sqrt(-2 * (std::log1p(sigma) - sigma));
}
else if (lambda < 1) {
eta = -std::sqrt(-2 * (std::log1p(sigma) - sigma));
}
else {
eta = 0;
}
res = 0.5 * std::erfc(sgn * eta * std::sqrt(a / 2));
for (k = 0; k < 25; k++) {
ck = d[k][0];
for (n = 1; n < 25; n++) {
if (n > maxpow) {
etapow[n] = eta * etapow[n-1];
maxpow += 1;
}
ckterm = d[k][n]*etapow[n];
ck += ckterm;
if (std::fabs(ckterm) < MACHEP * std::fabs(ck)) {
break;
}
}
term = ck * afac;
absterm = std::fabs(term);
if (absterm > absoldterm) {
break;
}
sum += term;
if (absterm < MACHEP * std::fabs(sum)) {
break;
}
absoldterm = absterm;
afac /= a;
}
res += sgn * std::exp(-0.5 * a * eta * eta) * sum / std::sqrt(2 * M_PIf * a);
return res;
}
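The change of variable used by the asymptotic series above can be isolated in a few lines. The following is a minimal sketch in double precision, assuming a > 0 and x > 0; the names `temme_eta` and `leading_term` are hypothetical and do not appear in the commit. It reproduces only the `eta` computation and the leading `erfc` term, not the correction sum over the `d` table.

```cpp
#include <cmath>

// Hypothetical helper (not part of the commit): maps (a, x) to the variable
// eta of the uniform expansion, defined by
//   eta^2 / 2 = lambda - 1 - log(lambda),  lambda = x / a,
// with eta taking the sign of (lambda - 1).
inline double temme_eta(double a, double x) {
  const double lambda = x / a;
  const double sigma = (x - a) / a;  // lambda - 1, so log(lambda) = log1p(sigma)
  // log1p(sigma) - sigma <= 0 for sigma > -1, so the sqrt argument is >= 0.
  const double magnitude = std::sqrt(-2.0 * (std::log1p(sigma) - sigma));
  if (lambda > 1.0) return magnitude;
  if (lambda < 1.0) return -magnitude;
  return 0.0;
}

// Leading term only: 0.5 * erfc(sgn * eta * sqrt(a / 2)), with sgn = -1 for
// the lower function (igam) and sgn = +1 for the upper function (igamc).
inline double leading_term(double a, double x, bool lower) {
  const double sgn = lower ? -1.0 : 1.0;
  return 0.5 * std::erfc(sgn * temme_eta(a, x) * std::sqrt(a / 2.0));
}
```

Because erfc(z) + erfc(-z) = 2, the two leading terms sum to 1, matching the identity between the lower and upper regularized functions.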
template <typename scalar_t>
static scalar_t _igamc_helper_continued_fraction(scalar_t a, scalar_t x) {
// Compute igamc using DLMF 8.9.2. [igam1]
int i;
scalar_t ans, ax, c, yc, r, t, y, z;
scalar_t pk, pkm1, pkm2, qk, qkm1, qkm2;
int MAXITER = 2000;
static scalar_t MACHEP = std::is_same<scalar_t, double>::value ?
1.11022302462515654042E-16 : 5.9604644775390625E-8;
static scalar_t BIG = std::is_same<scalar_t,double>::value ?
4.503599627370496e15 : 16777216.;
static scalar_t BIGINV = std::is_same<scalar_t,double>::value ?
2.22044604925031308085e-16 : 5.9604644775390625E-8;
ax = _igam_helper_fac(a, x);
if (ax == 0.0) {
return 0.0;
}
/* continued fraction */
y = 1.0 - a;
z = x + y + 1.0;
c = 0.0;
pkm2 = 1.0;
qkm2 = x;
pkm1 = x + 1.0;
qkm1 = z * x;
ans = pkm1 / qkm1;
for (i = 0; i < MAXITER; i++) {
c += 1.0;
y += 1.0;
z += 2.0;
yc = y * c;
pk = pkm1 * z - pkm2 * yc;
qk = qkm1 * z - qkm2 * yc;
if (qk != 0) {
r = pk / qk;
t = std::fabs((ans - r) / r);
ans = r;
}
else {
t = 1.0;
}
pkm2 = pkm1;
pkm1 = pk;
qkm2 = qkm1;
qkm1 = qk;
if (std::fabs(pk) > BIG) {
pkm2 *= BIGINV;
pkm1 *= BIGINV;
qkm2 *= BIGINV;
qkm1 *= BIGINV;
}
if (t <= MACHEP) {
break;
}
}
return ans * ax;
}
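As a standalone illustration of the recurrence in `_igamc_helper_continued_fraction`, here is a minimal double-precision sketch. It assumes a > 0 and x > 0 with x larger than a, and it computes the prefactor x^a * exp(-x) / gamma(a) directly in log space instead of calling `_igam_helper_fac`, so it is less robust when a and x are both large. The name `upper_gamma_cf` is hypothetical.

```cpp
#include <cmath>

// Sketch only: regularized upper incomplete gamma Q(a, x) evaluated with the
// same three-term recurrence on the continued-fraction numerators (pk) and
// denominators (qk) as the function above.
inline double upper_gamma_cf(double a, double x) {
  const double kEps = 1.11022302462515654042e-16;     // 2^-53
  const double kBig = 4.503599627370496e15;           // 2^52
  const double kBigInv = 2.22044604925031308085e-16;  // 2^-52
  // Assumption: a direct log-space prefactor is accurate enough here.
  const double prefactor = std::exp(a * std::log(x) - x - std::lgamma(a));
  if (prefactor == 0.0) return 0.0;

  double y = 1.0 - a;
  double z = x + y + 1.0;
  double c = 0.0;
  double pkm2 = 1.0, qkm2 = x;
  double pkm1 = x + 1.0, qkm1 = z * x;
  double ans = pkm1 / qkm1;
  for (int i = 0; i < 2000; ++i) {
    c += 1.0;
    y += 1.0;
    z += 2.0;
    const double yc = y * c;
    const double pk = pkm1 * z - pkm2 * yc;
    const double qk = qkm1 * z - qkm2 * yc;
    double t = 1.0;  // relative change between successive convergents
    if (qk != 0.0) {
      const double r = pk / qk;
      t = std::fabs((ans - r) / r);
      ans = r;
    }
    pkm2 = pkm1; pkm1 = pk;
    qkm2 = qkm1; qkm1 = qk;
    if (std::fabs(pk) > kBig) {
      // Rescale both sequences to avoid overflow; the ratio pk / qk is unchanged.
      pkm2 *= kBigInv; pkm1 *= kBigInv;
      qkm2 *= kBigInv; qkm1 *= kBigInv;
    }
    if (t <= kEps) break;
  }
  return ans * prefactor;
}
```

Closed forms give easy sanity checks: Q(1, x) = exp(-x) and Q(1/2, x) = erfc(sqrt(x)).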
template <typename scalar_t>
static inline scalar_t calc_igammac(scalar_t a, scalar_t x) {
/* the regularized upper incomplete gamma function is computed differently
* depending on the values of a and x:
* - if x and/or a is on the boundary of the defined region, return the
* result at the boundary
* - if a is large and a ~ x, use the Uniform Asymptotic Expansions for
* Large Parameter (see DLMF 8.12.4 [igam1])
* - if x > 1.1 and x < a, subtract the regularized lower incomplete gamma
* from 1
* - otherwise, calculate the series from [igam2] eq (5)
*/
scalar_t absxma_a;
static scalar_t SMALL = 20.0;
static scalar_t LARGE = 200.0;
static scalar_t SMALLRATIO = 0.3;
static scalar_t LARGERATIO = 4.5;
// note that in SciPy, a and x are non-negative, with exclusive 0s (i.e.,
// at most 1 of them can be 0), where igammac(0, x) = 0.0 iff x > 0.
if ((x < 0) || (a < 0)) {
// out of defined-region of the function
return std::numeric_limits<scalar_t>::quiet_NaN();
}
else if (a == 0) {
if (x > 0) {
return 0.0;
}
else {
return std::numeric_limits<scalar_t>::quiet_NaN();
}
}
else if (x == 0) {
return 1.0;
}
else if (std::isinf(a)) {
if (std::isinf(x)) {
return std::numeric_limits<scalar_t>::quiet_NaN();
}
return 1.0;
}
else if (std::isinf(x)) {
return 0.0;
}
absxma_a = std::fabs(x - a) / a;
if ((a > SMALL) && (a < LARGE) && (absxma_a < SMALLRATIO)) {
return _igam_helper_asymptotic_series(a, x, 0);
}
else if ((a > LARGE) && (absxma_a < LARGERATIO / std::sqrt(a))) {
return _igam_helper_asymptotic_series(a, x, 0);
}
if (x > 1.1) {
if (x < a) {
return 1.0 - _igam_helper_series(a, x);
}
else {
return _igamc_helper_continued_fraction(a, x);
}
}
else if (x <= 0.5) {
if (-0.4 / std::log(x) < a) {
return 1.0 - _igam_helper_series(a, x);
}
else {
return _igamc_helper_series(a, x);
}
}
else {
if (x * 1.1 < a) {
return 1.0 - _igam_helper_series(a, x);
}
else {
return _igamc_helper_series(a, x);
}
}
}
template <typename scalar_t>
static inline scalar_t calc_igamma(scalar_t a, scalar_t x) {
/* the regularized lower incomplete gamma function is computed differently
* depending on the values of a and x:
* - if x and/or a is on the boundary of the defined region, return the
* result at the boundary
* - if a is large and a ~ x, use the Uniform Asymptotic Expansions for
* Large Parameter (see DLMF 8.12.3 [igam1])
* - if x > 1 and x > a, subtract the regularized upper incomplete gamma
* from 1
* - otherwise, calculate the series from [igam2] eq (4)
*/
scalar_t absxma_a;
static scalar_t SMALL = 20.0;
static scalar_t LARGE = 200.0;
static scalar_t SMALLRATIO = 0.3;
static scalar_t LARGERATIO = 4.5;
// boundary values following SciPy
// note that in SciPy, a and x are non-negative, with exclusive 0s (i.e.,
// at most 1 of them can be 0), where igamma(0, x) = 1.0 iff x > 0.
if ((x < 0) || (a < 0)) {
// out of defined-region of the function
return std::numeric_limits<scalar_t>::quiet_NaN();
}
else if (a == 0) {
if (x > 0) {
return 1.0;
}
else {
return std::numeric_limits<scalar_t>::quiet_NaN();
}
}
else if (x == 0) {
return 0.0; // zero integration limit
}
else if (std::isinf(a)) {
if (std::isinf(x)) {
return std::numeric_limits<scalar_t>::quiet_NaN();
}
return 0.0;
}
else if (std::isinf(x)) {
return 1.0;
}
/* Asymptotic regime where a ~ x. See [igam2] */
absxma_a = std::fabs(x - a) / a;
if ((a > SMALL) && (a < LARGE) && (absxma_a < SMALLRATIO)) {
return _igam_helper_asymptotic_series(a, x, 1);
}
else if ((a > LARGE) && (absxma_a < LARGERATIO / std::sqrt(a))) {
return _igam_helper_asymptotic_series(a, x, 1);
}
if ((x > 1.0) && (x > a)) {
return 1.0 - calc_igammac(a, x);
}
return _igam_helper_series(a, x);
}
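The boundary handling at the top of `calc_igamma` can be read as a small decision table. The sketch below restates it in isolation for double inputs; `igamma_boundary_or` is a hypothetical name, and the `fallback` argument stands in for "no boundary case applies, continue with the numerical evaluation".

```cpp
#include <cmath>
#include <limits>

// Hypothetical helper (not in the commit): returns the boundary value of the
// regularized lower incomplete gamma P(a, x) when (a, x) hits one of the
// special cases, and `fallback` otherwise. The order of the checks mirrors
// calc_igamma above.
inline double igamma_boundary_or(double a, double x, double fallback) {
  const double nan = std::numeric_limits<double>::quiet_NaN();
  if (x < 0 || a < 0) return nan;             // outside the defined region
  if (a == 0) return (x > 0) ? 1.0 : nan;     // P(0, x) = 1 only for x > 0
  if (x == 0) return 0.0;                     // zero integration limit
  if (std::isinf(a)) return std::isinf(x) ? nan : 0.0;
  if (std::isinf(x)) return 1.0;
  return fallback;                            // interior point, no special case
}
```

The corresponding table for `calc_igammac` is the complement: 0.0 and 1.0 swap places, while the NaN cases stay the same.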
template <>
c10::BFloat16 calc_igamma<c10::BFloat16>(c10::BFloat16 a, c10::BFloat16 x) {
return calc_igamma<float>(float(a), float(x));
}
template <>
c10::Half calc_igamma<c10::Half>(c10::Half a, c10::Half x) {
return calc_igamma<float>(float(a), float(x));
}
inline c10::BFloat16 calc_erfinv(c10::BFloat16 a) { return calc_erfinv(float(a)); }
template <typename T>


@ -766,6 +766,19 @@ void hypot_kernel(TensorIterator& iter) {
});
}
void igamma_kernel(TensorIterator& iter) {
AT_DISPATCH_FLOATING_TYPES_AND2(kHalf, kBFloat16, iter.dtype(), "igamma_cpu", [&]() {
cpu_kernel_vec(
iter,
[=](scalar_t a, scalar_t b) -> scalar_t {
return calc_igamma(a, b);
},
[=](Vec256<scalar_t> a, Vec256<scalar_t> b) {
return a.igamma(b);
});
});
}
void nextafter_kernel(TensorIterator& iter) {
AT_DISPATCH_FLOATING_TYPES(iter.dtype(), "nextafter_cpu", [&]() {
cpu_kernel_vec(
@ -824,6 +837,7 @@ REGISTER_DISPATCH(logaddexp2_stub, &logaddexp2_kernel);
REGISTER_DISPATCH(gcd_stub, &gcd_kernel);
REGISTER_DISPATCH(lcm_stub, &lcm_kernel);
REGISTER_DISPATCH(hypot_stub, &hypot_kernel);
REGISTER_DISPATCH(igamma_stub, &igamma_kernel);
REGISTER_DISPATCH(nextafter_stub, &nextafter_kernel);
REGISTER_DISPATCH(heaviside_stub, &heaviside_kernel);


@ -92,6 +92,14 @@ void hypot_kernel_cuda(TensorIterator& iter) {
});
}
void igamma_kernel_cuda(TensorIterator& iter) {
AT_DISPATCH_FLOATING_TYPES(iter.common_dtype(), "igamma_cuda", [&]() {
gpu_kernel_with_scalars(iter, []GPU_LAMBDA(scalar_t a, scalar_t b) -> scalar_t {
return calc_igamma(a, b);
});
});
}
void nextafter_kernel_cuda(TensorIterator& iter) {
AT_DISPATCH_FLOATING_TYPES(iter.common_dtype(), "nextafter_cuda", [&]() {
gpu_kernel_with_scalars(iter, []GPU_LAMBDA(scalar_t a, scalar_t b) -> scalar_t {
@ -116,6 +124,7 @@ REGISTER_DISPATCH(logaddexp2_stub, &logaddexp2_kernel_cuda);
REGISTER_DISPATCH(gcd_stub, &gcd_kernel_cuda);
REGISTER_DISPATCH(lcm_stub, &lcm_kernel_cuda);
REGISTER_DISPATCH(hypot_stub, &hypot_kernel_cuda);
REGISTER_DISPATCH(igamma_stub, &igamma_kernel_cuda);
REGISTER_DISPATCH(nextafter_stub, &nextafter_kernel_cuda);
REGISTER_DISPATCH(heaviside_stub, &heaviside_kernel_cuda);


@ -54,12 +54,12 @@ static inline __host__ __device__ scalar_t zeta(scalar_t _x, scalar_t _q) {
a = q;
i = 0;
b = 0.0;
while((i < 9) || (a <= 9.0)){
while ((i < 9) || (a <= 9.0)) {
i += 1;
a += 1.0;
b = ::pow( a, -x );
s += b;
if((-MACHEP < (b / s)) && ((b / s) < MACHEP)) {
if ((-MACHEP < (b / s)) && ((b / s) < MACHEP)) {
return static_cast<scalar_t>(s);
}
};
@ -68,16 +68,16 @@ static inline __host__ __device__ scalar_t zeta(scalar_t _x, scalar_t _q) {
s -= 0.5 * b;
a = 1.0;
k = 0.0;
for(int i=0; i < 12; i++) {
for (int i=0; i < 12; i++) {
a *= x + k;
b /= w;
t = a * b / A[i];
s = s + t;
t = t / s;
if(t < 0){
if (t < 0){
t = -t;
}
if((-MACHEP <t) && (t < MACHEP)){
if ((-MACHEP <t) && (t < MACHEP)){
return static_cast<scalar_t>(s);
}
k += 1.0;
@ -174,6 +174,503 @@ static inline __host__ __device__ scalar_t calc_polygamma(int n, scalar_t x) {
return ((n % 2) ? 1.0 : -1.0) * ::exp(::lgamma(static_cast<scalar_t>(n) + 1.0)) * zeta(static_cast<scalar_t>(n + 1), x);
}
/*
* The implementations of the regularized incomplete gamma functions and
* their helper functions are derived from SciPy's gammainc, Cephes's igam
* and igamc, and Boost's Lanczos approximations.
* See NOTICE for the licenses.
*/
// regularized lower & upper incomplete gamma
template <typename scalar_t>
static __host__ __device__ scalar_t ratevl(scalar_t x, const scalar_t num[], int64_t M,
const scalar_t denom[], int64_t N) {
// evaluate a rational function, i.e., the ratio of two polynomials;
// the coefficients of the numerator are given by `num` and the coefficients
// of the denominator are given by `denom`
using accscalar_t = at::acc_type<scalar_t, /*is_cuda=*/true>;
int64_t i, dir;
accscalar_t y, num_ans, denom_ans;
accscalar_t absx = ::fabs(x);
const accscalar_t *p;
if (absx > 1) {
/* Evaluate as a polynomial in 1/x. */
dir = -1;
p = num + M;
y = 1 / x;
}
else {
dir = 1;
p = num;
y = x;
}
/* Evaluate the numerator */
num_ans = *p;
p += dir;
for (i = 1; i <= M; i++) {
num_ans = num_ans * y + *p;
p += dir;
}
/* Evaluate the denominator */
if (absx > 1) {
p = denom + N;
}
else {
p = denom;
}
denom_ans = *p;
p += dir;
for (i = 1; i <= N; i++) {
denom_ans = denom_ans * y + *p;
p += dir;
}
if (absx > 1) {
i = N - M;
return ::pow(x, static_cast<accscalar_t>(i)) * num_ans / denom_ans;
}
else {
return num_ans / denom_ans;
}
}
template <typename scalar_t>
static __host__ __device__ scalar_t lanczos_sum_expg_scaled(scalar_t x) {
// lanczos approximation
using accscalar_t = at::acc_type<scalar_t, /*is_cuda=*/true>;
static const accscalar_t lanczos_sum_expg_scaled_num[13] = {
0.006061842346248906525783753964555936883222,
0.5098416655656676188125178644804694509993,
19.51992788247617482847860966235652136208,
449.9445569063168119446858607650988409623,
6955.999602515376140356310115515198987526,
75999.29304014542649875303443598909137092,
601859.6171681098786670226533699352302507,
3481712.15498064590882071018964774556468,
14605578.08768506808414169982791359218571,
43338889.32467613834773723740590533316085,
86363131.28813859145546927288977868422342,
103794043.1163445451906271053616070238554,
56906521.91347156388090791033559122686859
};
static const accscalar_t lanczos_sum_expg_scaled_denom[13] = {
1.,
66.,
1925.,
32670.,
357423.,
2637558.,
13339535.,
45995730.,
105258076.,
150917976.,
120543840.,
39916800.,
0
};
return ratevl(static_cast<accscalar_t>(x), lanczos_sum_expg_scaled_num,
sizeof(lanczos_sum_expg_scaled_num) / sizeof(lanczos_sum_expg_scaled_num[0]) - 1,
lanczos_sum_expg_scaled_denom,
sizeof(lanczos_sum_expg_scaled_denom) / sizeof(lanczos_sum_expg_scaled_denom[0]) - 1);
}
template <typename scalar_t>
static __host__ __device__ scalar_t _igam_helper_fac(scalar_t a, scalar_t x) {
// compute x^a * exp(-x) / gamma(a)
// corrected from (15) and (16) in [igam2] by replacing exp(x - a) with
// exp(a - x).
using accscalar_t = at::acc_type<scalar_t, /*is_cuda=*/true>;
accscalar_t ax, fac, res, num, numfac;
static accscalar_t MAXLOG = std::is_same<accscalar_t,double>::value ?
7.09782712893383996843E2 : 88.72283905206835;
static accscalar_t EXP1 = 2.718281828459045;
static accscalar_t lanczos_g = 6.024680040776729583740234375;
if (::fabs(a - x) > 0.4 * ::fabs(a)) {
ax = a * ::log(x) - x - ::lgamma(a);
if (ax < -MAXLOG) {
return 0.0;
}
return ::exp(ax);
}
fac = a + lanczos_g - 0.5;
res = ::sqrt(fac / EXP1) / lanczos_sum_expg_scaled(a);
if ((a < 200) && (x < 200)) {
res *= ::exp(a - x) * ::pow(x / fac, a);
}
else {
num = x - a - lanczos_g + 0.5;
numfac = num / fac;
res *= ::exp(a * (::log1p(numfac) - numfac) + x * (0.5 - lanczos_g) / fac);
}
return res;
}
template <typename scalar_t>
static __host__ __device__ scalar_t _igam_helper_series(scalar_t a, scalar_t x) {
// Compute igam using DLMF 8.11.4. [igam1]
using accscalar_t = at::acc_type<scalar_t, /*is_cuda=*/true>;
static accscalar_t MACHEP = std::is_same<accscalar_t, double>::value ?
1.11022302462515654042E-16 : 5.9604644775390625E-8;
static int MAXITER = 2000;
int i;
accscalar_t ans, ax, c, r;
ax = _igam_helper_fac(a, x);
if (ax == 0.0) {
return 0.0;
}
/* power series */
r = a;
c = 1.0;
ans = 1.0;
for (i = 0; i < MAXITER; i++) {
r += 1.0;
c *= x / r;
ans += c;
if (c <= MACHEP * ans) {
break;
}
}
return (ans * ax / a);
}
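The power series in `_igam_helper_series` is short enough to check against closed forms. Here is a minimal host-only sketch in double precision, assuming a > 0 and x > 0. The name `lower_gamma_series` is hypothetical, and the prefactor is computed directly in log space instead of through `_igam_helper_fac`, which is an assumption that trades robustness for brevity.

```cpp
#include <cmath>

// Sketch only: regularized lower incomplete gamma
//   P(a, x) = x^a * exp(-x) / gamma(a + 1) * sum_{n >= 0} x^n / ((a+1)...(a+n))
// summed term by term until the latest term is negligible.
inline double lower_gamma_series(double a, double x) {
  const double kEps = 1.11022302462515654042e-16;  // 2^-53
  const double prefactor = std::exp(a * std::log(x) - x - std::lgamma(a));
  if (prefactor == 0.0) return 0.0;

  double r = a;     // running denominator a + n
  double c = 1.0;   // current term x^n / ((a+1)...(a+n))
  double ans = 1.0; // partial sum, starting with the n = 0 term
  for (int i = 0; i < 2000; ++i) {
    r += 1.0;
    c *= x / r;
    ans += c;
    if (c <= kEps * ans) break;
  }
  // Dividing by a turns 1 / gamma(a) into 1 / gamma(a + 1).
  return ans * prefactor / a;
}
```

Two identities are convenient for checking the result: P(1, x) = 1 - exp(-x) and P(1/2, x) = erf(sqrt(x)).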
template <typename scalar_t>
static __host__ __device__ scalar_t _igamc_helper_series(scalar_t a, scalar_t x) {
// Compute igamc using DLMF 8.7.3 [igam1]. This is related to the series in
// _igam_helper_series but extra care is taken to avoid cancellation.
using accscalar_t = at::acc_type<scalar_t, /*is_cuda=*/true>;
int n;
accscalar_t fac = 1;
accscalar_t sum = 0;
accscalar_t term, logx;
static accscalar_t MAXITER = 2000;
static accscalar_t MACHEP = std::is_same<accscalar_t, double>::value ?
1.11022302462515654042E-16 : 5.9604644775390625E-8;
for (n = 1; n < MAXITER; n++) {
fac *= -x / n;
term = fac / (a + n);
sum += term;
if (::fabs(term) <= MACHEP * ::fabs(sum)) {
break;
}
}
logx = ::log(x);
term = -::expm1(a * logx - ::lgamma(1+a));
return term - ::exp(a * logx - ::lgamma(a)) * sum;
}
template <typename scalar_t>
static __host__ __device__ scalar_t _igam_helper_asymptotic_series(scalar_t a, scalar_t x, bool igam) {
// Compute igam/igamc using DLMF 8.12.3/8.12.4 [igam1]
using accscalar_t = at::acc_type<scalar_t, /*is_cuda=*/true>;
static const accscalar_t d[25][25] =
{{-3.3333333333333333e-1, 8.3333333333333333e-2, -1.4814814814814815e-2, 1.1574074074074074e-3, 3.527336860670194e-4, -1.7875514403292181e-4, 3.9192631785224378e-5, -2.1854485106799922e-6, -1.85406221071516e-6, 8.296711340953086e-7, -1.7665952736826079e-7, 6.7078535434014986e-9, 1.0261809784240308e-8, -4.3820360184533532e-9, 9.1476995822367902e-10, -2.551419399494625e-11, -5.8307721325504251e-11, 2.4361948020667416e-11, -5.0276692801141756e-12, 1.1004392031956135e-13, 3.3717632624009854e-13, -1.3923887224181621e-13, 2.8534893807047443e-14, -5.1391118342425726e-16, -1.9752288294349443e-15},
{-1.8518518518518519e-3, -3.4722222222222222e-3, 2.6455026455026455e-3, -9.9022633744855967e-4, 2.0576131687242798e-4, -4.0187757201646091e-7, -1.8098550334489978e-5, 7.6491609160811101e-6, -1.6120900894563446e-6, 4.6471278028074343e-9, 1.378633446915721e-7, -5.752545603517705e-8, 1.1951628599778147e-8, -1.7543241719747648e-11, -1.0091543710600413e-9, 4.1627929918425826e-10, -8.5639070264929806e-11, 6.0672151016047586e-14, 7.1624989648114854e-12, -2.9331866437714371e-12, 5.9966963656836887e-13, -2.1671786527323314e-16, -4.9783399723692616e-14, 2.0291628823713425e-14, -4.13125571381061e-15},
{4.1335978835978836e-3, -2.6813271604938272e-3, 7.7160493827160494e-4, 2.0093878600823045e-6, -1.0736653226365161e-4, 5.2923448829120125e-5, -1.2760635188618728e-5, 3.4235787340961381e-8, 1.3721957309062933e-6, -6.298992138380055e-7, 1.4280614206064242e-7, -2.0477098421990866e-10, -1.4092529910867521e-8, 6.228974084922022e-9, -1.3670488396617113e-9, 9.4283561590146782e-13, 1.2872252400089318e-10, -5.5645956134363321e-11, 1.1975935546366981e-11, -4.1689782251838635e-15, -1.0940640427884594e-12, 4.6622399463901357e-13, -9.905105763906906e-14, 1.8931876768373515e-17, 8.8592218725911273e-15},
{6.4943415637860082e-4, 2.2947209362139918e-4, -4.6918949439525571e-4, 2.6772063206283885e-4, -7.5618016718839764e-5, -2.3965051138672967e-7, 1.1082654115347302e-5, -5.6749528269915966e-6, 1.4230900732435884e-6, -2.7861080291528142e-11, -1.6958404091930277e-7, 8.0994649053880824e-8, -1.9111168485973654e-8, 2.3928620439808118e-12, 2.0620131815488798e-9, -9.4604966618551322e-10, 2.1541049775774908e-10, -1.388823336813903e-14, -2.1894761681963939e-11, 9.7909989511716851e-12, -2.1782191880180962e-12, 6.2088195734079014e-17, 2.126978363279737e-13, -9.3446887915174333e-14, 2.0453671226782849e-14},
{-8.618882909167117e-4, 7.8403922172006663e-4, -2.9907248030319018e-4, -1.4638452578843418e-6, 6.6414982154651222e-5, -3.9683650471794347e-5, 1.1375726970678419e-5, 2.5074972262375328e-10, -1.6954149536558306e-6, 8.9075075322053097e-7, -2.2929348340008049e-7, 2.956794137544049e-11, 2.8865829742708784e-8, -1.4189739437803219e-8, 3.4463580499464897e-9, -2.3024517174528067e-13, -3.9409233028046405e-10, 1.8602338968504502e-10, -4.356323005056618e-11, 1.2786001016296231e-15, 4.6792750266579195e-12, -2.1492464706134829e-12, 4.9088156148096522e-13, -6.3385914848915603e-18, -5.0453320690800944e-14},
{-3.3679855336635815e-4, -6.9728137583658578e-5, 2.7727532449593921e-4, -1.9932570516188848e-4, 6.7977804779372078e-5, 1.419062920643967e-7, -1.3594048189768693e-5, 8.0184702563342015e-6, -2.2914811765080952e-6, -3.252473551298454e-10, 3.4652846491085265e-7, -1.8447187191171343e-7, 4.8240967037894181e-8, -1.7989466721743515e-14, -6.3061945000135234e-9, 3.1624176287745679e-9, -7.8409242536974293e-10, 5.1926791652540407e-15, 9.3589442423067836e-11, -4.5134262161632782e-11, 1.0799129993116827e-11, -3.661886712685252e-17, -1.210902069055155e-12, 5.6807435849905643e-13, -1.3249659916340829e-13},
{5.3130793646399222e-4, -5.9216643735369388e-4, 2.7087820967180448e-4, 7.9023532326603279e-7, -8.1539693675619688e-5, 5.6116827531062497e-5, -1.8329116582843376e-5, -3.0796134506033048e-9, 3.4651553688036091e-6, -2.0291327396058604e-6, 5.7887928631490037e-7, 2.338630673826657e-13, -8.8286007463304835e-8, 4.7435958880408128e-8, -1.2545415020710382e-8, 8.6496488580102925e-14, 1.6846058979264063e-9, -8.5754928235775947e-10, 2.1598224929232125e-10, -7.6132305204761539e-16, -2.6639822008536144e-11, 1.3065700536611057e-11, -3.1799163902367977e-12, 4.7109761213674315e-18, 3.6902800842763467e-13},
{3.4436760689237767e-4, 5.1717909082605922e-5, -3.3493161081142236e-4, 2.812695154763237e-4, -1.0976582244684731e-4, -1.2741009095484485e-7, 2.7744451511563644e-5, -1.8263488805711333e-5, 5.7876949497350524e-6, 4.9387589339362704e-10, -1.0595367014026043e-6, 6.1667143761104075e-7, -1.7562973359060462e-7, -1.2974473287015439e-12, 2.695423606288966e-8, -1.4578352908731271e-8, 3.887645959386175e-9, -3.8810022510194121e-17, -5.3279941738772867e-10, 2.7437977643314845e-10, -6.9957960920705679e-11, 2.5899863874868481e-17, 8.8566890996696381e-12, -4.403168815871311e-12, 1.0865561947091654e-12},
{-6.5262391859530942e-4, 8.3949872067208728e-4, -4.3829709854172101e-4, -6.969091458420552e-7, 1.6644846642067548e-4, -1.2783517679769219e-4, 4.6299532636913043e-5, 4.5579098679227077e-9, -1.0595271125805195e-5, 6.7833429048651666e-6, -2.1075476666258804e-6, -1.7213731432817145e-11, 3.7735877416110979e-7, -2.1867506700122867e-7, 6.2202288040189269e-8, 6.5977038267330006e-16, -9.5903864974256858e-9, 5.2132144922808078e-9, -1.3991589583935709e-9, 5.382058999060575e-16, 1.9484714275467745e-10, -1.0127287556389682e-10, 2.6077347197254926e-11, -5.0904186999932993e-18, -3.3721464474854592e-12},
{-5.9676129019274625e-4, -7.2048954160200106e-5, 6.7823088376673284e-4, -6.4014752602627585e-4, 2.7750107634328704e-4, 1.8197008380465151e-7, -8.4795071170685032e-5, 6.105192082501531e-5, -2.1073920183404862e-5, -8.8585890141255994e-10, 4.5284535953805377e-6, -2.8427815022504408e-6, 8.7082341778646412e-7, 3.6886101871706965e-12, -1.5344695190702061e-7, 8.862466778790695e-8, -2.5184812301826817e-8, -1.0225912098215092e-14, 3.8969470758154777e-9, -2.1267304792235635e-9, 5.7370135528051385e-10, -1.887749850169741e-19, -8.0931538694657866e-11, 4.2382723283449199e-11, -1.1002224534207726e-11},
{1.3324454494800656e-3, -1.9144384985654775e-3, 1.1089369134596637e-3, 9.932404122642299e-7, -5.0874501293093199e-4, 4.2735056665392884e-4, -1.6858853767910799e-4, -8.1301893922784998e-9, 4.5284402370562147e-5, -3.127053674781734e-5, 1.044986828530338e-5, 4.8435226265680926e-11, -2.1482565873456258e-6, 1.329369701097492e-6, -4.0295693092101029e-7, -1.7567877666323291e-13, 7.0145043163668257e-8, -4.040787734999483e-8, 1.1474026743371963e-8, 3.9642746853563325e-18, -1.7804938269892714e-9, 9.7480262548731646e-10, -2.6405338676507616e-10, 5.794875163403742e-18, 3.7647749553543836e-11},
{1.579727660730835e-3, 1.6251626278391582e-4, -2.0633421035543276e-3, 2.1389686185689098e-3, -1.0108559391263003e-3, -3.9912705529919201e-7, 3.6235025084764691e-4, -2.8143901463712154e-4, 1.0449513336495887e-4, 2.1211418491830297e-9, -2.5779417251947842e-5, 1.7281818956040463e-5, -5.6413773872904282e-6, -1.1024320105776174e-11, 1.1223224418895175e-6, -6.8693396379526735e-7, 2.0653236975414887e-7, 4.6714772409838506e-14, -3.5609886164949055e-8, 2.0470855345905963e-8, -5.8091738633283358e-9, -1.332821287582869e-16, 9.0354604391335133e-10, -4.9598782517330834e-10, 1.3481607129399749e-10},
{-4.0725121195140166e-3, 6.4033628338080698e-3, -4.0410161081676618e-3, -2.183732802866233e-6, 2.1740441801254639e-3, -1.9700440518418892e-3, 8.3595469747962458e-4, 1.9445447567109655e-8, -2.5779387120421696e-4, 1.9009987368139304e-4, -6.7696499937438965e-5, -1.4440629666426572e-10, 1.5712512518742269e-5, -1.0304008744776893e-5, 3.304517767401387e-6, 7.9829760242325709e-13, -6.4097794149313004e-7, 3.8894624761300056e-7, -1.1618347644948869e-7, -2.816808630596451e-15, 1.9878012911297093e-8, -1.1407719956357511e-8, 3.2355857064185555e-9, 4.1759468293455945e-20, -5.0423112718105824e-10},
{-5.9475779383993003e-3, -5.4016476789260452e-4, 8.7910413550767898e-3, -9.8576315587856125e-3, 5.0134695031021538e-3, 1.2807521786221875e-6, -2.0626019342754683e-3, 1.7109128573523058e-3, -6.7695312714133799e-4, -6.9011545676562133e-9, 1.8855128143995902e-4, -1.3395215663491969e-4, 4.6263183033528039e-5, 4.0034230613321351e-11, -1.0255652921494033e-5, 6.612086372797651e-6, -2.0913022027253008e-6, -2.0951775649603837e-13, 3.9756029041993247e-7, -2.3956211978815887e-7, 7.1182883382145864e-8, 8.925574873053455e-16, -1.2101547235064676e-8, 6.9350618248334386e-9, -1.9661464453856102e-9},
{1.7402027787522711e-2, -2.9527880945699121e-2, 2.0045875571402799e-2, 7.0289515966903407e-6, -1.2375421071343148e-2, 1.1976293444235254e-2, -5.4156038466518525e-3, -6.3290893396418616e-8, 1.8855118129005065e-3, -1.473473274825001e-3, 5.5515810097708387e-4, 5.2406834412550662e-10, -1.4357913535784836e-4, 9.9181293224943297e-5, -3.3460834749478311e-5, -3.5755837291098993e-12, 7.1560851960630076e-6, -4.5516802628155526e-6, 1.4236576649271475e-6, 1.8803149082089664e-14, -2.6623403898929211e-7, 1.5950642189595716e-7, -4.7187514673841102e-8, -6.5107872958755177e-17, 7.9795091026746235e-9},
{3.0249124160905891e-2, 2.4817436002649977e-3, -4.9939134373457022e-2, 5.9915643009307869e-2, -3.2483207601623391e-2, -5.7212968652103441e-6, 1.5085251778569354e-2, -1.3261324005088445e-2, 5.5515262632426148e-3, 3.0263182257030016e-8, -1.7229548406756723e-3, 1.2893570099929637e-3, -4.6845138348319876e-4, -1.830259937893045e-10, 1.1449739014822654e-4, -7.7378565221244477e-5, 2.5625836246985201e-5, 1.0766165333192814e-12, -5.3246809282422621e-6, 3.349634863064464e-6, -1.0381253128684018e-6, -5.608909920621128e-15, 1.9150821930676591e-7, -1.1418365800203486e-7, 3.3654425209171788e-8},
{-9.9051020880159045e-2, 1.7954011706123486e-1, -1.2989606383463778e-1, -3.1478872752284357e-5, 9.0510635276848131e-2, -9.2828824411184397e-2, 4.4412112839877808e-2, 2.7779236316835888e-7, -1.7229543805449697e-2, 1.4182925050891573e-2, -5.6214161633747336e-3, -2.39598509186381e-9, 1.6029634366079908e-3, -1.1606784674435773e-3, 4.1001337768153873e-4, 1.8365800754090661e-11, -9.5844256563655903e-5, 6.3643062337764708e-5, -2.076250624489065e-5, -1.1806020912804483e-13, 4.2131808239120649e-6, -2.6262241337012467e-6, 8.0770620494930662e-7, 6.0125912123632725e-16, -1.4729737374018841e-7},
{-1.9994542198219728e-1, -1.5056113040026424e-2, 3.6470239469348489e-1, -4.6435192311733545e-1, 2.6640934719197893e-1, 3.4038266027147191e-5, -1.3784338709329624e-1, 1.276467178337056e-1, -5.6213828755200985e-2, -1.753150885483011e-7, 1.9235592956768113e-2, -1.5088821281095315e-2, 5.7401854451350123e-3, 1.0622382710310225e-9, -1.5335082692563998e-3, 1.0819320643228214e-3, -3.7372510193945659e-4, -6.6170909729031985e-12, 8.4263617380909628e-5, -5.5150706827483479e-5, 1.7769536448348069e-5, 3.8827923210205533e-14, -3.53513697488768e-6, 2.1865832130045269e-6, -6.6812849447625594e-7},
{7.2438608504029431e-1, -1.3918010932653375, 1.0654143352413968, 1.876173868950258e-4, -8.2705501176152696e-1, 8.9352433347828414e-1, -4.4971003995291339e-1, -1.6107401567546652e-6, 1.9235590165271091e-1, -1.6597702160042609e-1, 6.8882222681814333e-2, 1.3910091724608687e-8, -2.146911561508663e-2, 1.6228980898865892e-2, -5.9796016172584256e-3, -1.1287469112826745e-10, 1.5167451119784857e-3, -1.0478634293553899e-3, 3.5539072889126421e-4, 8.1704322111801517e-13, -7.7773013442452395e-5, 5.0291413897007722e-5, -1.6035083867000518e-5, 1.2469354315487605e-14, 3.1369106244517615e-6},
{1.6668949727276811, 1.165462765994632e-1, -3.3288393225018906, 4.4692325482864037, -2.6977693045875807, -2.600667859891061e-4, 1.5389017615694539, -1.4937962361134612, 6.8881964633233148e-1, 1.3077482004552385e-6, -2.5762963325596288e-1, 2.1097676102125449e-1, -8.3714408359219882e-2, -7.7920428881354753e-9, 2.4267923064833599e-2, -1.7813678334552311e-2, 6.3970330388900056e-3, 4.9430807090480523e-11, -1.5554602758465635e-3, 1.0561196919903214e-3, -3.5277184460472902e-4, 9.3002334645022459e-14, 7.5285855026557172e-5, -4.8186515569156351e-5, 1.5227271505597605e-5},
{-6.6188298861372935, 1.3397985455142589e+1, -1.0789350606845146e+1, -1.4352254537875018e-3, 9.2333694596189809, -1.0456552819547769e+1, 5.5105526029033471, 1.2024439690716742e-5, -2.5762961164755816, 2.3207442745387179, -1.0045728797216284, -1.0207833290021914e-7, 3.3975092171169466e-1, -2.6720517450757468e-1, 1.0235252851562706e-1, 8.4329730484871625e-10, -2.7998284958442595e-2, 2.0066274144976813e-2, -7.0554368915086242e-3, 1.9402238183698188e-12, 1.6562888105449611e-3, -1.1082898580743683e-3, 3.654545161310169e-4, -5.1290032026971794e-11, -7.6340103696869031e-5},
{-1.7112706061976095e+1, -1.1208044642899116, 3.7131966511885444e+1, -5.2298271025348962e+1, 3.3058589696624618e+1, 2.4791298976200222e-3, -2.061089403411526e+1, 2.088672775145582e+1, -1.0045703956517752e+1, -1.2238783449063012e-5, 4.0770134274221141, -3.473667358470195, 1.4329352617312006, 7.1359914411879712e-8, -4.4797257159115612e-1, 3.4112666080644461e-1, -1.2699786326594923e-1, -2.8953677269081528e-10, 3.3125776278259863e-2, -2.3274087021036101e-2, 8.0399993503648882e-3, -1.177805216235265e-9, -1.8321624891071668e-3, 1.2108282933588665e-3, -3.9479941246822517e-4},
{7.389033153567425e+1, -1.5680141270402273e+2, 1.322177542759164e+2, 1.3692876877324546e-2, -1.2366496885920151e+2, 1.4620689391062729e+2, -8.0365587724865346e+1, -1.1259851148881298e-4, 4.0770132196179938e+1, -3.8210340013273034e+1, 1.719522294277362e+1, 9.3519707955168356e-7, -6.2716159907747034, 5.1168999071852637, -2.0319658112299095, -4.9507215582761543e-9, 5.9626397294332597e-1, -4.4220765337238094e-1, 1.6079998700166273e-1, -2.4733786203223402e-8, -4.0307574759979762e-2, 2.7849050747097869e-2, -9.4751858992054221e-3, 6.419922235909132e-6, 2.1250180774699461e-3},
{2.1216837098382522e+2, 1.3107863022633868e+1, -4.9698285932871748e+2, 7.3121595266969204e+2, -4.8213821720890847e+2, -2.8817248692894889e-2, 3.2616720302947102e+2, -3.4389340280087117e+2, 1.7195193870816232e+2, 1.4038077378096158e-4, -7.52594195897599e+1, 6.651969984520934e+1, -2.8447519748152462e+1, -7.613702615875391e-7, 9.5402237105304373, -7.5175301113311376, 2.8943997568871961, -4.6612194999538201e-7, -8.0615149598794088e-1, 5.8483006570631029e-1, -2.0845408972964956e-1, 1.4765818959305817e-4, 5.1000433863753019e-2, -3.3066252141883665e-2, 1.5109265210467774e-2},
{-9.8959643098322368e+2, 2.1925555360905233e+3, -1.9283586782723356e+3, -1.5925738122215253e-1, 1.9569985945919857e+3, -2.4072514765081556e+3, 1.3756149959336496e+3, 1.2920735237496668e-3, -7.525941715948055e+2, 7.3171668742208716e+2, -3.4137023466220065e+2, -9.9857390260608043e-6, 1.3356313181291573e+2, -1.1276295161252794e+2, 4.6310396098204458e+1, -7.9237387133614756e-6, -1.4510726927018646e+1, 1.1111771248100563e+1, -4.1690817945270892, 3.1008219800117808e-3, 1.1220095449981468, -7.6052379926149916e-1, 3.6262236505085254e-1, 2.216867741940747e-1, 4.8683443692930507e-1}};
int k, n, sgn;
int maxpow = 0;
static accscalar_t MACHEP = std::is_same<accscalar_t, double>::value ?
1.11022302462515654042E-16 : 5.9604644775390625E-8;
accscalar_t lambda = x / a;
accscalar_t sigma = (x - a) / a;
accscalar_t eta, res, ck, ckterm, term, absterm;
accscalar_t absoldterm = INFINITY;
accscalar_t etapow[25] = {1};
accscalar_t sum = 0;
accscalar_t afac = 1;
if (igam) {
sgn = -1;
}
else {
sgn = 1;
}
if (lambda > 1) {
eta = ::sqrt(-2 * (::log1p(sigma) - sigma));
}
else if (lambda < 1) {
eta = -::sqrt(-2 * (::log1p(sigma) - sigma));
}
else {
eta = 0;
}
res = 0.5 * ::erfc(sgn * eta * ::sqrt(a / 2));
for (k = 0; k < 25; k++) {
ck = d[k][0];
for (n = 1; n < 25; n++) {
if (n > maxpow) {
etapow[n] = eta * etapow[n-1];
maxpow += 1;
}
ckterm = d[k][n]*etapow[n];
ck += ckterm;
if (std::fabs(ckterm) < MACHEP * std::fabs(ck)) {
break;
}
}
term = ck * afac;
absterm = std::fabs(term);
if (absterm > absoldterm) {
break;
}
sum += term;
if (absterm < MACHEP * std::fabs(sum)) {
break;
}
absoldterm = absterm;
afac /= a;
}
res += sgn * ::exp(-0.5 * a * eta * eta) * sum / ::sqrt(2 * 3.1415926535 * a);
return res;
}
template <typename scalar_t>
static __host__ __device__ scalar_t _igamc_helper_continued_fraction(scalar_t a, scalar_t x) {
// Compute igamc using DLMF 8.9.2. [igam1]
using accscalar_t = at::acc_type<scalar_t, /*is_cuda=*/true>;
int i;
accscalar_t ans, ax, c, yc, r, t, y, z;
accscalar_t pk, pkm1, pkm2, qk, qkm1, qkm2;
int MAXITER = 2000;
static accscalar_t MACHEP = std::is_same<accscalar_t, double>::value ?
1.11022302462515654042E-16 : 5.9604644775390625E-8;
static accscalar_t BIG = std::is_same<accscalar_t,double>::value ?
4.503599627370496e15 : 16777216.;
static accscalar_t BIGINV = std::is_same<accscalar_t,double>::value ?
2.22044604925031308085e-16 : 5.9604644775390625E-8;
ax = _igam_helper_fac(a, x);
if (ax == 0.0) {
return 0.0;
}
/* continued fraction */
y = 1.0 - a;
z = x + y + 1.0;
c = 0.0;
pkm2 = 1.0;
qkm2 = x;
pkm1 = x + 1.0;
qkm1 = z * x;
ans = pkm1 / qkm1;
for (i = 0; i < MAXITER; i++) {
c += 1.0;
y += 1.0;
z += 2.0;
yc = y * c;
pk = pkm1 * z - pkm2 * yc;
qk = qkm1 * z - qkm2 * yc;
if (qk != 0) {
r = pk / qk;
t = ::fabs((ans - r) / r);
ans = r;
}
else {
t = 1.0;
}
pkm2 = pkm1;
pkm1 = pk;
qkm2 = qkm1;
qkm1 = qk;
if (::fabs(pk) > BIG) {
pkm2 *= BIGINV;
pkm1 *= BIGINV;
qkm2 *= BIGINV;
qkm1 *= BIGINV;
}
if (t <= MACHEP) {
break;
}
}
return ans * ax;
}
template <typename scalar_t>
static inline __host__ __device__ scalar_t calc_igammac(scalar_t a, scalar_t x) {
/* the calculation of the regularized upper incomplete gamma function
* is done differently based on the values of a and x:
* - if x and/or a is at the boundary of defined region, then assign the
* result at the boundary
* - if a is large and a ~ x, then using Uniform Asymptotic Expansions for
* Large Parameter (see DLMF 8.12.4 [igam1])
* - if x > 1.1 and x < a, using the subtraction from the regularized lower
* incomplete gamma
* - otherwise, calculate the series from [igam2] eq (5)
*/
using accscalar_t = at::acc_type<scalar_t, /*is_cuda=*/true>;
accscalar_t absxma_a;
static accscalar_t SMALL = 20.0;
static accscalar_t LARGE = 200.0;
static accscalar_t SMALLRATIO = 0.3;
static accscalar_t LARGERATIO = 4.5;
if ((x < 0) || (a < 0)) {
// out of defined-region of the function
return std::numeric_limits<accscalar_t>::quiet_NaN();
}
else if (a == 0) {
if (x > 0) {
return 0.0;
}
else {
return std::numeric_limits<accscalar_t>::quiet_NaN();
}
}
else if (x == 0) {
return 1.0;
}
else if (::isinf(static_cast<accscalar_t>(a))) {
if (::isinf(static_cast<accscalar_t>(x))) {
return std::numeric_limits<accscalar_t>::quiet_NaN();
}
return 1.0;
}
else if (::isinf(static_cast<accscalar_t>(x))) {
return 0.0;
}
absxma_a = ::fabs(x - a) / a;
if ((a > SMALL) && (a < LARGE) && (absxma_a < SMALLRATIO)) {
return _igam_helper_asymptotic_series(a, x, 0);
}
else if ((a > LARGE) && (absxma_a < LARGERATIO / ::sqrt(a))) {
return _igam_helper_asymptotic_series(a, x, 0);
}
if (x > 1.1) {
if (x < a) {
return 1.0 - _igam_helper_series(a, x);
}
else {
return _igamc_helper_continued_fraction(a, x);
}
}
else if (x <= 0.5) {
if (-0.4 / ::log(x) < a) {
return 1.0 - _igam_helper_series(a, x);
}
else {
return _igamc_helper_series(a, x);
}
}
else {
if (x * 1.1 < a) {
return 1.0 - _igam_helper_series(a, x);
}
else {
return _igamc_helper_series(a, x);
}
}
}
template <typename scalar_t>
static inline __host__ __device__ scalar_t calc_igamma(scalar_t a, scalar_t x) {
/* the calculation of the regularized lower incomplete gamma function
* is done differently based on the values of a and x:
* - if x and/or a is at the boundary of defined region, then assign the
* result at the boundary
* - if a is large and a ~ x, then using Uniform Asymptotic Expansions for
* Large Parameter (see DLMF 8.12.3 [igam1])
* - if x > 1 and x > a, using the subtraction from the regularized upper
* incomplete gamma
* - otherwise, calculate the series from [igam2] eq (4)
*/
using accscalar_t = at::acc_type<scalar_t, /*is_cuda=*/true>;
accscalar_t absxma_a;
static accscalar_t SMALL = 20.0;
static accscalar_t LARGE = 200.0;
static accscalar_t SMALLRATIO = 0.3;
static accscalar_t LARGERATIO = 4.5;
// boundary values following SciPy
if ((x < 0) || (a < 0)) {
// out of defined-region of the function
return std::numeric_limits<accscalar_t>::quiet_NaN();
}
else if (a == 0) {
if (x > 0) {
return 1.0;
}
else {
return std::numeric_limits<accscalar_t>::quiet_NaN();
}
}
else if (x == 0) {
return 0.0; // zero integration limit
}
else if (::isinf(static_cast<accscalar_t>(a))) {
if (::isinf(static_cast<accscalar_t>(x))) {
return std::numeric_limits<accscalar_t>::quiet_NaN();
}
return 0.0;
}
else if (::isinf(static_cast<accscalar_t>(x))) {
return 1.0;
}
/* Asymptotic regime where a ~ x. */
absxma_a = ::fabs(x - a) / a;
if ((a > SMALL) && (a < LARGE) && (absxma_a < SMALLRATIO)) {
return _igam_helper_asymptotic_series(a, x, 1);
}
else if ((a > LARGE) && (absxma_a < LARGERATIO / ::sqrt(a))) {
return _igam_helper_asymptotic_series(a, x, 1);
}
if ((x > 1.0) && (x > a)) {
return 1.0 - calc_igammac(a, x);
}
return _igam_helper_series(a, x);
}
// end of regularized lower & upper incomplete gamma
template <typename scalar_t>
static inline C10_HOST_DEVICE scalar_t calc_gcd(scalar_t a_in, scalar_t b_in) {
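
The continued-fraction helper above (DLMF 8.9.2) can be mirrored in plain Python as a rough sketch. It assumes that the prefactor returned by `_igam_helper_fac` equals x^a * exp(-x) / Gamma(a); that helper is not shown in this part of the diff, so this is an assumption. The name `igamc_continued_fraction` is hypothetical, and only the double-precision constants are used.

```python
import math

MACHEP = 1.11022302462515654042e-16  # double-precision value from the diff
BIG = 4.503599627370496e15
BIGINV = 2.22044604925031308085e-16


def igamc_continued_fraction(a, x, max_iter=2000):
    """Sketch of the regularized upper incomplete gamma Q(a, x) for x > 0, a > 0."""
    # Assumed form of the prefactor: x^a * exp(-x) / Gamma(a).
    ax = math.exp(a * math.log(x) - x - math.lgamma(a))
    if ax == 0.0:
        return 0.0

    y = 1.0 - a
    z = x + y + 1.0
    c = 0.0
    pkm2, qkm2 = 1.0, x
    pkm1, qkm1 = x + 1.0, z * x
    ans = pkm1 / qkm1

    for _ in range(max_iter):
        c += 1.0
        y += 1.0
        z += 2.0
        yc = y * c
        pk = pkm1 * z - pkm2 * yc
        qk = qkm1 * z - qkm2 * yc
        if qk != 0:
            r = pk / qk
            t = abs((ans - r) / r)  # relative change between convergents
            ans = r
        else:
            t = 1.0
        pkm2, pkm1, qkm2, qkm1 = pkm1, pk, qkm1, qk
        if abs(pk) > BIG:
            # Rescale numerator and denominator together to avoid overflow.
            pkm2 *= BIGINV
            pkm1 *= BIGINV
            qkm2 *= BIGINV
            qkm1 *= BIGINV
        if t <= MACHEP:
            break
    return ans * ax
```

For a = 1 the result should reduce to exp(-x), which gives a quick sanity check of the recurrence.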


@ -6496,6 +6496,21 @@
dispatch:
DefaultBackend: hypot_
- func: igamma.out(Tensor self, Tensor other, *, Tensor(a!) out) -> Tensor(a!)
  dispatch:
    CPU, CUDA: igamma_out

- func: igamma(Tensor self, Tensor other) -> Tensor
  use_c10_dispatcher: full
  variants: method, function
  dispatch:
    CPU, CUDA: igamma

- func: igamma_(Tensor(a!) self, Tensor other) -> Tensor(a!)
  variants: method
  dispatch:
    CPU, CUDA: igamma_

- func: nextafter.out(Tensor self, Tensor other, *, Tensor(a!) out) -> Tensor(a!)
dispatch:
CPU, CUDA: nextafter_out


@ -59,6 +59,14 @@ namespace std {
throw std::runtime_error("std::hypot is not implemented on older Android");
}
// TODO: this function needs to be implemented and tested. Currently just throw an error.
inline float igamma(float x, float y) {
throw std::runtime_error("igamma is not implemented on older Android");
}
inline double igamma(double x, double y) {
throw std::runtime_error("igamma is not implemented on older Android");
}
// TODO: this function needs to be implemented and tested. Currently just throw an error.
inline float nextafter(float x, float y) {
throw std::runtime_error("std::nextafter is not implemented on older Android");
@ -66,7 +74,7 @@ namespace std {
inline double nextafter(double x, double y) {
throw std::runtime_error("std::nextafter is not implemented on older Android");
}
// TODO: this function needs to be implemented and tested. Currently just throw an error.
inline float exp2(float x) {
throw std::runtime_error("std::exp2 is not implemented on older Android");


@ -349,6 +349,8 @@ view of a storage and defines numeric operations on it.
.. automethod:: hypot_
.. automethod:: i0
.. automethod:: i0_
.. automethod:: igamma
.. automethod:: igamma_
.. automethod:: ifft
.. automethod:: index_add_
.. automethod:: index_add


@ -310,6 +310,7 @@ Pointwise Ops
logit
hypot
i0
igamma
mul
multiply
mvlgamma


@ -2881,6 +2881,14 @@ class TestAutograd(TestCase):
a = torch.arange(1, 13, dtype=torch.double).view(3, 4).requires_grad_()
gradcheck(lambda a: torch.pow(2, a), (a,))
    def test_igamma(self):
        # 1e-3 offset to avoid zeros
        # NOTE: derivative for s is not implemented
        s = (torch.rand(100, dtype=torch.double) + 1e-3)
        x = (torch.rand(100, dtype=torch.double) + 1e-3).requires_grad_()
        gradcheck(torch.igamma, (s, x))
        gradgradcheck(torch.igamma, (s, x))
@skipIfNoLapack
def test_pinverse(self):
# Why is pinverse tested this way, and not ordinarily as other linear algebra methods?


@ -13690,6 +13690,8 @@ class TestTorchDeviceType(TestCase):
("atan2", True, True, 'cuda'),
("hypot", True, True, 'cpu'),
("hypot", True, True, 'cuda'),
("igamma", True, True, 'cpu'),
("igamma", True, True, 'cuda'),
("nextafter", True, True, 'cpu'),
("nextafter", True, True, 'cuda'),
("le", True, True, 'cpu'),
@ -17473,6 +17475,70 @@ scipy_lobpcg | {:10.2e} | {:10.2e} | {:6} | N/A
expected = np.hypot(input[0].cpu().numpy(), input[1].cpu().numpy())
self.assertEqual(actual, expected)
    def _helper_test_igamma(self, loglo, loghi, device, dtype):
        exp1 = 2.71828182846
        vec1 = torch.logspace(loglo, loghi, steps=500, base=exp1,
                              dtype=torch.float64, device=device).unsqueeze(-1)
        vec1 = vec1.to(dtype)
        inputs = [
            (vec1, vec1.transpose(0, 1)),
            (vec1, vec1),  # for large number, it should approach 0.5
            (vec1, 0.5 * vec1),  # test for considerable ratio
            (vec1, 2.0 * vec1),
            (vec1[::2, :], vec1[::2, :]),  # contiguous/discontiguous tests
            (vec1[::2, :], vec1[:vec1.shape[0] // 2, :]),
            (vec1[:vec1.shape[0] // 2, :], vec1[::2, :]),
        ]
        half_prec = dtype in [torch.bfloat16, torch.float16]
        for input0, input1 in inputs:
            actual = torch.igamma(input0, input1)
            if half_prec:
                input0 = input0.to(torch.float)
                input1 = input1.to(torch.float)
            expected = scipy.special.gammainc(input0.cpu().numpy(), input1.cpu().numpy())
            expected = torch.from_numpy(expected).to(dtype)
            self.assertEqual(actual, expected)
    @skipCUDAIfRocm  # see issue https://github.com/pytorch/pytorch/issues/46531
    @dtypesIfCPU(torch.float16, torch.bfloat16, torch.float32, torch.float64)
    @dtypes(torch.float32, torch.float64)
    @unittest.skipIf(not TEST_SCIPY, "SciPy not found")
    @onlyOnCPUAndCUDA
    def test_igamma_common(self, device, dtype):
        # test igamma for reasonable range of values
        loglo = -4  # approx 0.018
        loghi = 4  # approx 54.6
        self._helper_test_igamma(loglo, loghi, device, dtype)
    @dtypesIfCPU(torch.float16, torch.bfloat16, torch.float32, torch.float64)
    @dtypes(torch.float32, torch.float64)
    @onlyOnCPUAndCUDA
    def test_igamma_edge_cases(self, device, dtype):
        tkwargs = {"dtype": dtype, "device": device}
        infs = torch.zeros((3,), **tkwargs) + float("inf")
        zeros = torch.zeros((3,), **tkwargs)
        ones = torch.ones((3,), **tkwargs)
        zero_to_large = torch.tensor([0., 1., 1e3], **tkwargs)
        small_to_inf = torch.tensor([1e-3, 1., float("inf")], **tkwargs)
        nans = torch.zeros((3,), **tkwargs) + float("nan")
        inpouts = [
            # (a , x), out
            ((zeros, small_to_inf), ones),
            ((small_to_inf, zeros), zeros),
            ((infs, zero_to_large), zeros),
            ((zero_to_large, infs), ones),
            ((zeros, zeros), nans),
            ((infs, infs), nans),
            ((-small_to_inf, small_to_inf), nans),
        ]
        for inputs, output in inpouts:
            input0, input1 = inputs
            calc = torch.igamma(input0, input1)
            if torch.all(torch.isnan(output)):
                self.assertTrue(torch.all(torch.isnan(calc)))
            else:
                self.assertEqual(calc, output)
@dtypes(torch.int64, torch.float64)
def test_remainder_edge_cases(self, device, dtype):
# Test variations of negative values used as input
@ -17496,8 +17562,8 @@ scipy_lobpcg | {:10.2e} | {:10.2e} | {:6} | N/A
r = a.remainder(b)
r_expected = torch.tensor([0, 0, 0, 0, -3, 3, -2, 2] * 10000, dtype=dtype, device=device)
self.assertEqual(r, r_expected)
# Test nan cases
a = torch.tensor([-34, 0, 34] * 20000, dtype=dtype, device=device)
b = torch.zeros(3 * 20000, dtype=dtype, device=device)
self.assertTrue(torch.isnan(a.remainder(b)).all())
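
The table exercised by `test_igamma_edge_cases` corresponds to the early returns at the top of `calc_igamma`. As a minimal sketch, the same decision order can be written for Python scalars; the function name `igamma_boundary` is hypothetical, and returning `None` for interior points is a choice made only for this illustration.

```python
import math


def igamma_boundary(a, x):
    """Return the boundary value of P(a, x), or None when (a, x) is an interior point."""
    if x < 0 or a < 0:
        return math.nan                      # outside the defined region
    if a == 0:
        return 1.0 if x > 0 else math.nan    # a == 0 and x == 0 is undefined
    if x == 0:
        return 0.0                           # zero integration limit
    if math.isinf(a):
        return math.nan if math.isinf(x) else 0.0
    if math.isinf(x):
        return 1.0
    return None                              # fall through to the numerical branches
```

The order of the checks matters: `a == 0` is tested before `x == 0`, so the pair (0, 0) yields NaN rather than 0.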


@ -540,6 +540,10 @@
- name: i0(Tensor self) -> Tensor
self: not_implemented("i0")
- name: igamma(Tensor self, Tensor other) -> Tensor
  self: 'not_implemented("igamma: input")'
  other: grad * exp((self - 1) * log(other) - other - lgamma(self))

- name: index.Tensor(Tensor self, Tensor?[] indices) -> Tensor
self: index_backward(zeros_like(self), indices, grad)
indices: TensorList()
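
The `other` formula in the igamma entry above is the integrand of the defining integral evaluated at `other` and divided by Gamma(self), written in log space. A small sketch can compare it with a finite difference; it assumes an integer first argument so that a closed form for P(n, x) is available, and all function names here are hypothetical.

```python
import math


def igamma_integer_a(n, x):
    """Closed form of P(n, x) for a positive integer n."""
    partial = sum(x ** k / math.factorial(k) for k in range(n))
    return 1.0 - math.exp(-x) * partial


def igamma_grad_other(a, x):
    """d/dx P(a, x), written the same way as the autograd formula above."""
    return math.exp((a - 1) * math.log(x) - x - math.lgamma(a))


def central_difference(f, x, h=1e-5):
    """Second-order finite-difference estimate of f'(x)."""
    return (f(x + h) - f(x - h)) / (2 * h)


numeric = central_difference(lambda t: igamma_integer_a(4, t), 3.0)
analytic = igamma_grad_other(4.0, 3.0)
```

The two values should agree to well within the finite-difference error, which is the same kind of check that `gradcheck` performs on the tensor implementation.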


@ -1542,6 +1542,20 @@ i0_() -> Tensor
In-place version of :meth:`~Tensor.i0`
""")
add_docstr_all('igamma',
               r"""
igamma(other) -> Tensor

See :func:`torch.igamma`
""")

add_docstr_all('igamma_',
               r"""
igamma_(other) -> Tensor

In-place version of :meth:`~Tensor.igamma`
""")

add_docstr_all('indices',
r"""
indices() -> Tensor


@ -3316,6 +3316,47 @@ Example::
""".format(**common_args))
add_docstr(torch.igamma,
           r"""
igamma(input, other, *, out=None) -> Tensor

Computes the regularized lower incomplete gamma function:

.. math::
    \text{out}_{i} = \frac{1}{\Gamma(\text{input}_i)} \int_0^{\text{other}_i} t^{\text{input}_i-1} e^{-t} dt

where both :math:`\text{input}_i` and :math:`\text{other}_i` are weakly positive
and at least one is strictly positive.
If both are zero or either is negative, then :math:`\text{out}_i=\text{nan}`.
:math:`\Gamma(\cdot)` in the equation above is the gamma function,

.. math::
    \Gamma(\text{input}_i) = \int_0^\infty t^{(\text{input}_i-1)} e^{-t} dt.

See :func:`torch.lgamma` for a related function.

Supports :ref:`broadcasting to a common shape <broadcasting-semantics>`
and float inputs.

.. note::
    The backward pass with respect to :attr:`input` is not yet supported.
    Please open an issue on PyTorch's GitHub to request it.

""" + r"""
Args:
    input (Tensor): the first non-negative input tensor
    other (Tensor): the second non-negative input tensor

Keyword args:
    {out}

Example::

    >>> torch.igamma(torch.tensor([4.0]), torch.tensor([3.0, 4.0, 5.0]))
    tensor([0.3528, 0.5665, 0.7350])

""".format(**common_args))

add_docstr(torch.index_select,
r"""
index_select(input, dim, index, *, out=None) -> Tensor
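
The integral in the `torch.igamma` docstring can be checked independently with a simple quadrature. The sketch below uses composite Simpson's rule from the standard library only; it assumes `a >= 1` so that the integrand is finite at t = 0, and the name `igamma_by_quadrature` is hypothetical rather than part of PyTorch.

```python
import math


def igamma_by_quadrature(a, x, steps=2000):
    """Approximate P(a, x) by Simpson's rule on [0, x]; assumes a >= 1 and even steps."""
    def integrand(t):
        return t ** (a - 1) * math.exp(-t)

    h = x / steps
    total = integrand(0.0) + integrand(x)
    for i in range(1, steps):
        weight = 4 if i % 2 else 2  # Simpson weights alternate 4, 2, 4, ...
        total += weight * integrand(i * h)
    return total * h / 3 / math.gamma(a)


# Same arguments as the docstring example: input = 4, other = 3, 4, 5.
values = [igamma_by_quadrature(4.0, x) for x in (3.0, 4.0, 5.0)]
```

Rounded to four decimals, these values should match the tensor shown in the docstring example.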


@ -388,6 +388,7 @@ def get_testing_overrides() -> Dict[Callable, Callable]:
torch.hstack: lambda tensors, out=None: -1,
torch.hypot: lambda input, other, out=None: -1,
torch.ifft: lambda input, signal_ndim, normalized=False: -1,
torch.igamma: lambda input, other, out=None: -1,
torch.imag: lambda input, out=None: -1,
torch.index_add: lambda input, dim, index, source: -1,
torch.index_copy: lambda input, dim, index, source: -1,