[special] Add special.ndtri (#58650)

Summary:
Reference: https://github.com/pytorch/pytorch/issues/50345

TODO
* [x] Add docs https://13865352-65600975-gh.circle-artifacts.com/0/docs/special.html#torch.special.ndtri
* [x] Add comments on implementation
* [x] Clean-up
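For quick orientation, a minimal usage sketch of the new operator (not part of the diff; the expected values mirror the docstring example added in this PR, and it assumes a build that includes this change):

```python
import torch

# ndtri is the inverse of the standard normal CDF (ndtr): it maps a
# probability p in [0, 1] to the x with ndtr(x) == p.
p = torch.tensor([0.0, 0.25, 0.5, 0.75, 1.0])
print(torch.special.ndtri(p))
# tensor([   -inf, -0.6745,  0.0000,  0.6745,     inf])

# Round trip through the forward CDF (up to floating-point error):
x = torch.tensor([-2.0, -1.0, 0.0, 1.0, 2.0])
print(torch.special.ndtri(torch.special.ndtr(x)))  # approximately x
```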

Pull Request resolved: https://github.com/pytorch/pytorch/pull/58650

Reviewed By: H-Huang

Differential Revision: D29160170

Pulled By: mruberry

fbshipit-source-id: 50e4ea663920e97b8437d03d5b52bcd9dedc1a8d
Author: kshitij12345
Date: 2021-06-19 18:35:11 -07:00
Committed by: Facebook GitHub Bot
Parent: 5824a866b7
Commit: 5ec4ad7f54

14 changed files with 246 additions and 30 deletions

View File

@ -533,6 +533,7 @@ _(aten, native_tensor) \
_(aten, native_zero) \
_(aten, special_ndtr) \
_(aten, nextafter) \
_(aten, special_ndtri) \
_(aten, bitwise_and) \
_(aten, bitwise_not) \
_(aten, bitwise_or) \

View File

@ -2,6 +2,7 @@
#include <ATen/ATen.h>
#include <ATen/ExpandUtils.h>
#include <ATen/native/Math.h>
#include <c10/macros/Macros.h>
#include <c10/util/MathConstants.h>
@ -118,15 +119,6 @@ C10_DEVICE scalar_t sample_gamma(scalar_t alpha, BaseSampler<accscalar_t, unifor
}
}
template <typename scalar_t>
C10_DEVICE static inline scalar_t polevl(const scalar_t x, const scalar_t A[], size_t len) {
scalar_t result = 0;
for (size_t i = 0; i <= len; i++) {
result = result * x + A[i];
}
return result;
}
/* the functions stirling_approx_tail, binomial_inversion, and btrs are adapted
* from TensorFlow's random_binomial_op.cc implementation. That code is under
* copyright: 2019 The TensorFlow Authors.

View File

@ -220,16 +220,24 @@ static inline double zeta(double x, double q) {
return s;
}
static inline double polevl(double x, double *A, size_t len) {
double result = 0;
for (size_t i = 0; i <= len; i++) {
result = result * x + A[i];
}
return result;
}
static inline float polevlf(float x, float *A, size_t len) {
float result = 0;
/*
* This function is derived from the implementation of the digamma function in the Cephes Math Library.
* See note [3-Clause BSD License for the Cephes Math Library].
*
* Evaluates polynomial of degree N:
*
* y = C_0 + C_1 x + C_2 x^2 + ... + C_N x^N
*
* Coefficients are stored in reverse order:
*
* coef[0] = C_N, ..., coef[N] = C_0.
*/
template <typename T>
C10_HOST_DEVICE static inline T polevl(const T x, const T A[], size_t len) {
T result = 0;
for (size_t i = 0; i <= len; i++) {
result = result * x + A[i];
}
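To illustrate the coefficient convention the templated `polevl` expects (Horner evaluation, with the highest-degree coefficient first), here is an equivalent Python sketch; `polevl_py` is a hypothetical name used only for this example:

```python
def polevl_py(x, coef):
    # Horner's scheme; coef[0] multiplies the highest power of x.
    result = 0.0
    for c in coef:
        result = result * x + c
    return result

# Evaluate 3x^2 + 2x + 1 at x = 2 with coefficients stored highest-degree first.
assert polevl_py(2.0, [3.0, 2.0, 1.0]) == 17.0
```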
@ -312,7 +320,7 @@ static inline double calc_digamma(double x) {
}
// Compute asymptotic digamma
static double A[] = {
static const double A[] = {
8.33333333333333333333E-2,
-2.10927960927960927961E-2,
7.57575757575757575758E-3,
@ -371,7 +379,7 @@ static inline float calc_digamma(float x) {
}
// Compute asymptotic digamma
static float A[] = {
static const float A[] = {
8.33333333333333333333E-2f,
-2.10927960927960927961E-2f,
7.57575757575757575758E-3f,
@ -384,7 +392,7 @@ static inline float calc_digamma(float x) {
float y = 0;
if (x < 1.0e17f) {
float z = 1 / (x * x);
y = z * polevlf(z, A, 6);
y = z * polevl(z, A, 6);
}
return result + logf(x) - (0.5f / x) - y;
}
@ -1196,7 +1204,7 @@ chbevl(const T x, const T array[], size_t len) {
* of all inputs to convert them into the domain of the approximation.
*/
template <typename T>
inline std::tuple<const T*, size_t> chebyshev_coefficients_i0e_A() {
static inline std::tuple<const T*, size_t> chebyshev_coefficients_i0e_A() {
/* Chebyshev coefficients for exp(-x) I0(x)
* in the interval [0,8].
*
@ -1222,7 +1230,7 @@ inline std::tuple<const T*, size_t> chebyshev_coefficients_i0e_A() {
};
template <typename T>
inline std::tuple<const T*, size_t> chebyshev_coefficients_i0e_B() {
static inline std::tuple<const T*, size_t> chebyshev_coefficients_i0e_B() {
/* Chebyshev coefficients for exp(-x) sqrt(x) I0(x)
* in the inverted interval [8,infinity].
*
@ -1247,7 +1255,7 @@ inline std::tuple<const T*, size_t> chebyshev_coefficients_i0e_B() {
};
template <typename T>
inline typename std::enable_if<std::is_same<double, T>::value, std::tuple<const T*, size_t>>::type
static inline typename std::enable_if<std::is_same<double, T>::value, std::tuple<const T*, size_t>>::type
chebyshev_coefficients_i1e_A() {
/* Chebyshev coefficients for exp(-x) I1(x)
* in the interval [0,8].
@ -1274,7 +1282,7 @@ chebyshev_coefficients_i1e_A() {
};
template <typename T>
inline typename std::enable_if<std::is_same<float, T>::value, std::tuple<const T*, size_t>>::type
static inline typename std::enable_if<std::is_same<float, T>::value, std::tuple<const T*, size_t>>::type
chebyshev_coefficients_i1e_A() {
/* Chebyshev coefficients for exp(-x) I1(x)
* in the interval [0,8].
@ -1303,7 +1311,7 @@ chebyshev_coefficients_i1e_A() {
};
template <typename T>
inline typename std::enable_if<std::is_same<double, T>::value, std::tuple<const T*, size_t>>::type
static inline typename std::enable_if<std::is_same<double, T>::value, std::tuple<const T*, size_t>>::type
chebyshev_coefficients_i1e_B() {
/* Chebyshev coefficients for exp(-x) sqrt(x) I1(x)
* in the inverted interval [8,infinity].
@ -1329,7 +1337,7 @@ chebyshev_coefficients_i1e_B() {
};
template <typename T>
inline typename std::enable_if<std::is_same<float, T>::value, std::tuple<const T*, size_t>>::type
static inline typename std::enable_if<std::is_same<float, T>::value, std::tuple<const T*, size_t>>::type
chebyshev_coefficients_i1e_B() {
/* Chebyshev coefficients for exp(-x) sqrt(x) I1(x)
* in the inverted interval [8,infinity].
@ -1368,7 +1376,7 @@ calc_i0(T _x) {
}
// Upcast bfloat16 input to float for numerical accuracy purposes
inline c10::BFloat16 calc_i0(c10::BFloat16 a) { return calc_i0(static_cast<float>(a)); }
static inline c10::BFloat16 calc_i0(c10::BFloat16 a) { return calc_i0(static_cast<float>(a)); }
/*
* This function is derived from the implementation of the i0e function in the Cephes Math Library.
@ -1400,7 +1408,7 @@ calc_i0e(T _x) {
}
// Upcast bfloat16 input to float for numerical accuracy purposes
inline c10::BFloat16 calc_i0e(c10::BFloat16 a) { return calc_i0e(static_cast<float>(a)); }
static inline c10::BFloat16 calc_i0e(c10::BFloat16 a) { return calc_i0e(static_cast<float>(a)); }
/*
* This function is derived from the implementation of the i1 function in the Cephes Math Library.
@ -1461,3 +1469,135 @@ calc_i1e(T _x) {
static_cast<T>(chbevl(static_cast<T>(32.0 / x - 2.0), B, len) / std::sqrt(x));
return (_x < 0.0) ? -out : out;
}
/*
* This function is derived from the implementation of the ndtri function in the Cephes Math Library.
* See note [3-Clause BSD License for the Cephes Math Library].
*
* Computes the argument, x, for which the area under the Gaussian probability density function
* (integrated from minus infinity to x) is equal to y.
*/
template <typename T>
static inline C10_HOST_DEVICE T calc_ndtri(T y0) {
/* sqrt(2pi) */
constexpr T s2pi = 2.50662827463100050242E0;
constexpr T one = 1;
constexpr T zero = 0;
/* approximation for 0 <= |y - 0.5| <= 3/8 */
static const T P0[5] = {
-5.99633501014107895267E1,
9.80010754185999661536E1,
-5.66762857469070293439E1,
1.39312609387279679503E1,
-1.23916583867381258016E0,
};
static const T Q0[9] = {
1.00000000000000000000E0,
1.95448858338141759834E0,
4.67627912898881538453E0,
8.63602421390890590575E1,
-2.25462687854119370527E2,
2.00260212380060660359E2,
-8.20372256168333339912E1,
1.59056225126211695515E1,
-1.18331621121330003142E0,
};
/* Approximation for interval z = sqrt(-2 log y ) between 2 and 8
* i.e., y between exp(-2) = .135 and exp(-32) = 1.27e-14.
*/
static const T P1[9] = {
4.05544892305962419923E0,
3.15251094599893866154E1,
5.71628192246421288162E1,
4.40805073893200834700E1,
1.46849561928858024014E1,
2.18663306850790267539E0,
-1.40256079171354495875E-1,
-3.50424626827848203418E-2,
-8.57456785154685413611E-4,
};
static const T Q1[9] = {
1.00000000000000000000E0,
1.57799883256466749731E1,
4.53907635128879210584E1,
4.13172038254672030440E1,
1.50425385692907503408E1,
2.50464946208309415979E0,
-1.42182922854787788574E-1,
-3.80806407691578277194E-2,
-9.33259480895457427372E-4,
};
/* Approximation for interval z = sqrt(-2 log y ) between 8 and 64
* i.e., y between exp(-32) = 1.27e-14 and exp(-2048) = 3.67e-890.
*/
static const T P2[9] = {
3.23774891776946035970E0,
6.91522889068984211695E0,
3.93881025292474443415E0,
1.33303460815807542389E0,
2.01485389549179081538E-1,
1.23716634817820021358E-2,
3.01581553508235416007E-4,
2.65806974686737550832E-6,
6.23974539184983293730E-9,
};
static const T Q2[9] = {
1.00000000000000000000E0,
6.02427039364742014255E0,
3.67983563856160859403E0,
1.37702099489081330271E0,
2.16236993594496635890E-1,
1.34204006088543189037E-2,
3.28014464682127739104E-4,
2.89247864745380683936E-6,
6.79019408009981274425E-9,
};
if (y0 == zero) {
return -std::numeric_limits<T>::infinity();
}
if (y0 == one) {
return std::numeric_limits<T>::infinity();
}
if (y0 < zero || y0 > one) {
return std::numeric_limits<T>::quiet_NaN();
}
bool code = true;
T y = y0;
if (y > one - T{0.13533528323661269189}) { /* 0.135... = exp(-2) */
y = one - y;
code = false;
}
if (y > T{0.13533528323661269189}) {
y = y - T{0.5};
const T y2 = y * y;
T x = y + y * (y2 * polevl(y2, P0, 4) / polevl(y2, Q0, 8));
return (x * s2pi);
}
T x = ::sqrt(T{-2.0} * ::log(y));
const T x0 = x - ::log(x) / x;
const T z = one / x;
T x1;
if (x < T{8.0}) /* y > exp(-32) = 1.2664165549e-14 */
{
x1 = z * polevl(z, P1, 8) / polevl(z, Q1, 8);
} else {
x1 = z * polevl(z, P2, 8) / polevl(z, Q2, 8);
}
x = x0 - x1;
if (code) {
x = -x;
}
return x;
}
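One way to sanity-check this port of the Cephes rational approximation is to compare it against `scipy.special.ndtri`, the same reference the OpInfo entry at the end of this diff uses. A hedged sketch, assuming SciPy is installed:

```python
import torch
from scipy import special as sp

# Probabilities chosen to exercise all three approximation intervals:
# the central polynomial and both sqrt(-2 log y) tail expansions.
p = torch.tensor([1e-20, 1e-10, 1e-3, 0.2, 0.5, 0.8, 1 - 1e-3], dtype=torch.float64)
expected = torch.from_numpy(sp.ndtri(p.numpy()))
torch.testing.assert_close(torch.special.ndtri(p), expected)
```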

View File

@ -66,6 +66,7 @@ CREATE_UNARY_FLOAT_META_FUNC(special_entr)
CREATE_UNARY_FLOAT_META_FUNC(special_i0e)
CREATE_UNARY_FLOAT_META_FUNC(special_i1)
CREATE_UNARY_FLOAT_META_FUNC(special_i1e)
CREATE_UNARY_FLOAT_META_FUNC(special_ndtri)
CREATE_UNARY_FLOAT_META_FUNC(sqrt)
CREATE_UNARY_FLOAT_META_FUNC(tan)
CREATE_UNARY_FLOAT_META_FUNC(tanh)
@ -170,6 +171,7 @@ CREATE_UNARY_TORCH_IMPL_FUNC(special_entr_out, special_entr_stub)
CREATE_UNARY_TORCH_IMPL_FUNC(special_i0e_out, special_i0e_stub)
CREATE_UNARY_TORCH_IMPL_FUNC(special_i1e_out, special_i1e_stub)
CREATE_UNARY_TORCH_IMPL_FUNC(special_i1_out, special_i1_stub)
CREATE_UNARY_TORCH_IMPL_FUNC(special_ndtri_out, special_ndtri_stub)
CREATE_UNARY_TORCH_IMPL_FUNC(sqrt_out, sqrt_stub)
CREATE_UNARY_TORCH_IMPL_FUNC(tan_out, tan_stub)
CREATE_UNARY_TORCH_IMPL_FUNC(tanh_out, tanh_stub)
@ -759,6 +761,7 @@ DEFINE_DISPATCH(log10_stub); // NOLINT(cppcoreguidelines-avoid-non-const-global-
DEFINE_DISPATCH(log1p_stub); // NOLINT(cppcoreguidelines-avoid-non-const-global-variables)
DEFINE_DISPATCH(log2_stub); // NOLINT(cppcoreguidelines-avoid-non-const-global-variables)
DEFINE_DISPATCH(logical_not_stub); // NOLINT(cppcoreguidelines-avoid-non-const-global-variables)
DEFINE_DISPATCH(special_ndtri_stub); // NOLINT(cppcoreguidelines-avoid-non-const-global-variables)
DEFINE_DISPATCH(neg_stub); // NOLINT(cppcoreguidelines-avoid-non-const-global-variables)
DEFINE_DISPATCH(nan_to_num_stub); // NOLINT(cppcoreguidelines-avoid-non-const-global-variables)
DEFINE_DISPATCH(polygamma_stub); // NOLINT(cppcoreguidelines-avoid-non-const-global-variables)

View File

@ -45,6 +45,7 @@ DECLARE_DISPATCH(unary_fn, log_stub);
DECLARE_DISPATCH(unary_fn, log10_stub);
DECLARE_DISPATCH(unary_fn, log1p_stub);
DECLARE_DISPATCH(unary_fn, log2_stub);
DECLARE_DISPATCH(unary_fn, special_ndtri_stub);
DECLARE_DISPATCH(unary_fn, neg_stub);
DECLARE_DISPATCH(unary_fn, reciprocal_stub);

View File

@ -600,6 +600,13 @@ static void frexp_kernel(TensorIteratorBase& iter) {
});
}
static void ndtri_kernel(TensorIteratorBase& iter) {
TORCH_INTERNAL_ASSERT(iter.ntensors() == 2);
AT_DISPATCH_FLOATING_TYPES(iter.common_dtype(), "ndtri_cpu", [&]() {
cpu_kernel(iter, [](scalar_t x) { return calc_ndtri(x); });
});
}
static void i0e_kernel(TensorIteratorBase& iter) {
TORCH_INTERNAL_ASSERT(iter.ntensors() == 2);
AT_DISPATCH_FLOATING_TYPES_AND(
@ -765,6 +772,8 @@ REGISTER_DISPATCH(frexp_stub, &CPU_CAPABILITY::frexp_kernel);
// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables)
REGISTER_DISPATCH(special_i0e_stub, &CPU_CAPABILITY::i0e_kernel);
// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables)
REGISTER_DISPATCH(special_ndtri_stub, &CPU_CAPABILITY::ndtri_kernel);
// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables)
REGISTER_DISPATCH(special_i1_stub, &CPU_CAPABILITY::i1_kernel);
// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables)
REGISTER_DISPATCH(special_i1e_stub, &CPU_CAPABILITY::i1e_kernel);

View File

@ -114,6 +114,13 @@ void logit_kernel_cuda(TensorIteratorBase& iter, const Scalar& eps_scalar) {
});
}
void ndtri_kernel_cuda(TensorIteratorBase& iter) {
AT_DISPATCH_FLOATING_TYPES(iter.common_dtype(), "ndtri_cuda", [&]() {
gpu_kernel(
iter, [] GPU_LAMBDA(scalar_t a) -> scalar_t { return calc_ndtri(a); });
});
}
void erf_kernel_cuda(TensorIteratorBase& iter) {
AT_DISPATCH_FLOATING_TYPES_AND2(at::ScalarType::Half, at::ScalarType::BFloat16, iter.common_dtype(), "erf_cuda", [&]() {
gpu_kernel(iter, []GPU_LAMBDA(scalar_t a) -> scalar_t {
@ -188,6 +195,7 @@ REGISTER_DISPATCH(erfc_stub, &erfc_kernel_cuda);
REGISTER_DISPATCH(erfinv_stub, &erfinv_kernel_cuda);
REGISTER_DISPATCH(kaiser_window_stub, &kaiser_window_kernel_cuda);
REGISTER_DISPATCH(special_entr_stub, &entr_kernel_cuda);
REGISTER_DISPATCH(special_ndtri_stub, &ndtri_kernel_cuda);
} // namespace native
} // namespace at
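Because the op goes through the unary float-op machinery, the same Python call reaches either the CPU or the CUDA kernel above, and integral/bool inputs should be promoted to a floating dtype first. A hedged sketch:

```python
import torch

p = torch.tensor([0.25, 0.75])
print(torch.special.ndtri(p))             # CPU float32 kernel
if torch.cuda.is_available():
    print(torch.special.ndtri(p.cuda()))  # CUDA kernel, same values

# Integral/bool inputs are promoted to a floating dtype before dispatch.
print(torch.special.ndtri(torch.tensor([0, 1])))  # tensor([-inf, inf])
```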

View File

@ -9445,6 +9445,19 @@
dispatch:
CPU, CUDA: special_entr_out
- func: special_ndtri(Tensor self) -> Tensor
structured_delegate: special_ndtri.out
python_module: special
variants: function
- func: special_ndtri.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!)
structured: True
structured_inherits: TensorIteratorBase
python_module: special
variants: function
dispatch:
CPU, CUDA: special_ndtri_out
- func: special_expm1(Tensor self) -> Tensor
python_module: special
variants: function
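The structured `special_ndtri.out` overload registered above is reachable from Python through the `out=` keyword; a small hedged sketch:

```python
import torch

p = torch.rand(4, dtype=torch.float64)
out = torch.empty_like(p)
torch.special.ndtri(p, out=out)  # fills `out` via the special_ndtri.out kernel
print(torch.equal(out, torch.special.ndtri(p)))  # True
```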

View File

@ -34,4 +34,5 @@ Functions
.. autofunction:: i1e
.. autofunction:: logit
.. autofunction:: ndtr
.. autofunction:: ndtri
.. autofunction:: xlog1py

View File

@ -1108,6 +1108,9 @@
- name: special_entr(Tensor self) -> Tensor
self: grad * (-(1 + self.log()))
- name: special_ndtri(Tensor self) -> Tensor
self: grad * std::sqrt(2 * M_PI) * (result.square() / 2).exp()
# DO NOT define a backward for reshape!
# reshape is special in that it sometimes returns a view, and sometimes not.
# Defining a backward will make codegen spit out the forward call as
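The `special_ndtri` derivative above follows from the inverse-function rule: since ndtr'(x) = exp(-x^2/2)/sqrt(2*pi), d/dp ndtri(p) = sqrt(2*pi) * exp(ndtri(p)^2/2). A hedged numerical check (not part of the diff):

```python
import math
import torch

p = torch.tensor([0.1, 0.3, 0.7, 0.9], dtype=torch.float64, requires_grad=True)
x = torch.special.ndtri(p)
(grad,) = torch.autograd.grad(x.sum(), p)
closed_form = math.sqrt(2 * math.pi) * (x.detach().square() / 2).exp()
torch.testing.assert_close(grad, closed_form)

# Cross-check the analytic formula against finite differences as well.
assert torch.autograd.gradcheck(torch.special.ndtri, (p,))
```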

View File

@ -117,6 +117,14 @@ inline Tensor& erfinv_out(Tensor& result, const Tensor& self) {
return torch::special_erfinv_out(result, self);
}
inline Tensor ndtri(const Tensor& self) {
return torch::special_ndtri(self);
}
inline Tensor& ndtri_out(Tensor& result, const Tensor& self) {
return torch::special_ndtri_out(result, self);
}
/// Computes the logit of input, elementwise.
/// See https://pytorch.org/docs/master/special.html#torch.special.logit.
///

View File

@ -882,6 +882,7 @@ def get_testing_overrides() -> Dict[Callable, Callable]:
torch.special.i1: lambda input: -1,
torch.special.i1e: lambda input: -1,
torch.special.logit: lambda input: -1,
torch.special.ndtri: lambda input: -1,
torch.special.ndtr: lambda input: -1,
torch.special.xlog1py: lambda input, other, out=None: -1,
torch.t: lambda input: -1,

View File

@ -342,8 +342,10 @@ for each element of :attr:`input`.
""" + r"""
Args:
{input}
Keyword args:
{out}
Example::
>>> torch.special.i0e(torch.arange(5, dtype=torch.float32))
tensor([1.0000, 0.4658, 0.3085, 0.2430, 0.2070])
@ -361,8 +363,10 @@ for each element of :attr:`input`.
""" + r"""
Args:
{input}
Keyword args:
{out}
Example::
>>> torch.special.i1(torch.arange(5, dtype=torch.float32))
tensor([0.0000, 0.5652, 1.5906, 3.9534, 9.7595])
@ -381,8 +385,10 @@ for each element of :attr:`input`.
""" + r"""
Args:
{input}
Keyword args:
{out}
Example::
>>> torch.special.i1e(torch.arange(5, dtype=torch.float32))
tensor([0.0000, 0.2079, 0.2153, 0.1968, 0.1788])
@ -408,3 +414,27 @@ Example::
>>> torch.special.ndtr(torch.tensor([-3., -2, -1, 0, 1, 2, 3]))
tensor([0.0013, 0.0228, 0.1587, 0.5000, 0.8413, 0.9772, 0.9987])
""".format(**common_args))
ndtri = _add_docstr(_special.special_ndtri,
r"""
ndtri(input, *, out=None) -> Tensor
Computes the argument, x, for which the area under the Gaussian probability density function
(integrated from minus infinity to x) is equal to :attr:`input`, elementwise.
.. math::
\text{ndtri}(p) = \sqrt{2}\text{erf}^{-1}(2p - 1)
.. note::
Also known as the quantile function of the normal distribution.
""" + r"""
Args:
{input}
Keyword args:
{out}
Example::
>>> torch.special.ndtri(torch.tensor([0, 0.25, 0.5, 0.75, 1]))
tensor([ -inf, -0.6745, 0.0000, 0.6745, inf])
""".format(**common_args))

View File

@ -7341,6 +7341,12 @@ op_db: List[OpInfo] = [
supports_inplace_autograd=False,
safe_casts_outputs=True,
sample_inputs_func=sample_inputs_entr),
UnaryUfuncInfo('special.ndtri',
ref=scipy.special.ndtri if TEST_SCIPY else _NOTHING,
domain=(0, 1),
aten_name='special_ndtri',
dtypes=all_types_and(torch.bool),
safe_casts_outputs=True),
UnaryUfuncInfo('erf',
ref=scipy.special.erf if TEST_SCIPY else _NOTHING,
aliases=('special.erf', ),