[special] Add special.ndtri (#58650)
Summary:
Reference: https://github.com/pytorch/pytorch/issues/50345

TODO
* [x] Add docs https://13865352-65600975-gh.circle-artifacts.com/0/docs/special.html#torch.special.ndtri
* [x] Add comments on implementation
* [x] Clean-up

Pull Request resolved: https://github.com/pytorch/pytorch/pull/58650
Reviewed By: H-Huang
Differential Revision: D29160170
Pulled By: mruberry
fbshipit-source-id: 50e4ea663920e97b8437d03d5b52bcd9dedc1a8d
Commit: 5ec4ad7f54
Parent: 5824a866b7
Committed by: Facebook GitHub Bot
@@ -533,6 +533,7 @@ _(aten, native_tensor) \
_(aten, native_zero) \
_(aten, special_ndtr) \
_(aten, nextafter) \
+_(aten, special_ndtri) \
_(aten, bitwise_and) \
_(aten, bitwise_not) \
_(aten, bitwise_or) \
@@ -2,6 +2,7 @@
#include <ATen/ATen.h>
#include <ATen/ExpandUtils.h>
#include <ATen/native/Math.h>
#include <c10/macros/Macros.h>
#include <c10/util/MathConstants.h>

@@ -118,15 +119,6 @@ C10_DEVICE scalar_t sample_gamma(scalar_t alpha, BaseSampler<accscalar_t, unifor
  }
}

-template <typename scalar_t>
-C10_DEVICE static inline scalar_t polevl(const scalar_t x, const scalar_t A[], size_t len) {
-  scalar_t result = 0;
-  for (size_t i = 0; i <= len; i++) {
-    result = result * x + A[i];
-  }
-  return result;
-}
-
/* the functions stirling_approx_tail, binomial_inversion, and btrs are adapted
 * from TensorFlow's random_binomial_op.cc implementation. That code is under
 * copyright: 2019 The TensorFlow Authors.
@@ -220,16 +220,24 @@ static inline double zeta(double x, double q) {
  return s;
}

-static inline double polevl(double x, double *A, size_t len) {
-  double result = 0;
-  for (size_t i = 0; i <= len; i++) {
-    result = result * x + A[i];
-  }
-  return result;
-}
-
-static inline float polevlf(float x, float *A, size_t len) {
-  float result = 0;
+/*
+ * This function is derived from the implementation of the digamma function in the Cephes Math Library.
+ * See note [3-Clause BSD License for the Cephes Math Library].
+ *
+ * Evaluates polynomial of degree N:
+ *
+ *   y = C_0 + C_1 x + C_2 x^2 + ... + C_N x^N
+ *
+ * Coefficients are stored in reverse order:
+ *
+ *   coef[0] = C_N, ..., coef[N] = C_0.
+ */
+template <typename T>
+C10_HOST_DEVICE static inline T polevl(const T x, const T A[], size_t len) {
+  T result = 0;
+  for (size_t i = 0; i <= len; i++) {
+    result = result * x + A[i];
+  }
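As a usage note (not part of the diff): a minimal standalone sketch of how the polevl template above is meant to be called. The local copy of polevl simply mirrors the definition in this hunk; main() and the sample coefficients are illustrative only.

    #include <cstddef>
    #include <iostream>

    template <typename T>
    static inline T polevl(const T x, const T A[], size_t len) {
      // Horner's rule; len is the polynomial degree N, so N + 1 coefficients are read.
      T result = 0;
      for (size_t i = 0; i <= len; i++) {
        result = result * x + A[i];
      }
      return result;
    }

    int main() {
      // p(x) = 1 + 2x + 3x^2, stored highest degree first: coef[0] = C_N, ..., coef[N] = C_0.
      const double coef[] = {3.0, 2.0, 1.0};
      std::cout << polevl(2.0, coef, 2) << "\n";  // 3*4 + 2*2 + 1 = 17
      return 0;
    }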
@@ -312,7 +320,7 @@ static inline double calc_digamma(double x) {
  }

  // Compute asymptotic digamma
-  static double A[] = {
+  static const double A[] = {
      8.33333333333333333333E-2,
      -2.10927960927960927961E-2,
      7.57575757575757575758E-3,
@@ -371,7 +379,7 @@ static inline float calc_digamma(float x) {
  }

  // Compute asymptotic digamma
-  static float A[] = {
+  static const float A[] = {
      8.33333333333333333333E-2f,
      -2.10927960927960927961E-2f,
      7.57575757575757575758E-3f,
@@ -384,7 +392,7 @@ static inline float calc_digamma(float x) {
  float y = 0;
  if (x < 1.0e17f) {
    float z = 1 / (x * x);
-    y = z * polevlf(z, A, 6);
+    y = z * polevl(z, A, 6);
  }
  return result + logf(x) - (0.5f / x) - y;
}
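For context on the polevl call sites in the two calc_digamma hunks above (a sketch, not text from the commit): for large x the code evaluates the standard asymptotic series for the digamma function, with z = 1/x^2 and y = z * polevl(z, A, 6) accumulating the correction terms. The coefficients in A[] are the Bernoulli-number terms B_{2k}/(2k), stored with the highest power of 1/x^2 first.

    \[
      \psi(x) \;\approx\; \ln x \;-\; \frac{1}{2x} \;-\; \sum_{k=1}^{7} \frac{B_{2k}}{2k}\, x^{-2k}
    \]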
@@ -1196,7 +1204,7 @@ chbevl(const T x, const T array[], size_t len) {
 * of all inputs to convert them into the domain of the approximation.
 */
template <typename T>
-inline std::tuple<const T*, size_t> chebyshev_coefficients_i0e_A() {
+static inline std::tuple<const T*, size_t> chebyshev_coefficients_i0e_A() {
  /* Chebyshev coefficients for exp(-x) I0(x)
   * in the interval [0,8].
   *
@@ -1222,7 +1230,7 @@ inline std::tuple<const T*, size_t> chebyshev_coefficients_i0e_A() {
};

template <typename T>
-inline std::tuple<const T*, size_t> chebyshev_coefficients_i0e_B() {
+static inline std::tuple<const T*, size_t> chebyshev_coefficients_i0e_B() {
  /* Chebyshev coefficients for exp(-x) sqrt(x) I0(x)
   * in the inverted interval [8,infinity].
   *
@@ -1247,7 +1255,7 @@ inline std::tuple<const T*, size_t> chebyshev_coefficients_i0e_B() {
};

template <typename T>
-inline typename std::enable_if<std::is_same<double, T>::value, std::tuple<const T*, size_t>>::type
+static inline typename std::enable_if<std::is_same<double, T>::value, std::tuple<const T*, size_t>>::type
chebyshev_coefficients_i1e_A() {
  /* Chebyshev coefficients for exp(-x) I1(x)
   * in the interval [0,8].
@@ -1274,7 +1282,7 @@ chebyshev_coefficients_i1e_A() {
};

template <typename T>
-inline typename std::enable_if<std::is_same<float, T>::value, std::tuple<const T*, size_t>>::type
+static inline typename std::enable_if<std::is_same<float, T>::value, std::tuple<const T*, size_t>>::type
chebyshev_coefficients_i1e_A() {
  /* Chebyshev coefficients for exp(-x) I1(x)
   * in the interval [0,8].
@@ -1303,7 +1311,7 @@ chebyshev_coefficients_i1e_A() {
};

template <typename T>
-inline typename std::enable_if<std::is_same<double, T>::value, std::tuple<const T*, size_t>>::type
+static inline typename std::enable_if<std::is_same<double, T>::value, std::tuple<const T*, size_t>>::type
chebyshev_coefficients_i1e_B() {
  /* Chebyshev coefficients for exp(-x) sqrt(x) I1(x)
   * in the inverted interval [8,infinity].
@@ -1329,7 +1337,7 @@ chebyshev_coefficients_i1e_B() {
};

template <typename T>
-inline typename std::enable_if<std::is_same<float, T>::value, std::tuple<const T*, size_t>>::type
+static inline typename std::enable_if<std::is_same<float, T>::value, std::tuple<const T*, size_t>>::type
chebyshev_coefficients_i1e_B() {
  /* Chebyshev coefficients for exp(-x) sqrt(x) I1(x)
   * in the inverted interval [8,infinity].
@@ -1368,7 +1376,7 @@ calc_i0(T _x) {
}

// Upcast bfloat16 input to float for numerical accuracy purposes
-inline c10::BFloat16 calc_i0(c10::BFloat16 a) { return calc_i0(static_cast<float>(a)); }
+static inline c10::BFloat16 calc_i0(c10::BFloat16 a) { return calc_i0(static_cast<float>(a)); }

/*
 * This function is derived from the implementation of the i0e function in the Cephes Math Library.
@@ -1400,7 +1408,7 @@ calc_i0e(T _x) {
}

// Upcast bfloat16 input to float for numerical accuracy purposes
-inline c10::BFloat16 calc_i0e(c10::BFloat16 a) { return calc_i0e(static_cast<float>(a)); }
+static inline c10::BFloat16 calc_i0e(c10::BFloat16 a) { return calc_i0e(static_cast<float>(a)); }

/*
 * This function is derived from the implementation of the i1 function in the Cephes Math Library.
@@ -1461,3 +1469,135 @@ calc_i1e(T _x) {
      static_cast<T>(chbevl(static_cast<T>(32.0 / x - 2.0), B, len) / std::sqrt(x));
  return (_x < 0.0) ? -out : out;
}
+
+/*
+ * This function is derived from the implementation of the i1e function in the Cephes Math Library.
+ * See note [3-Clause BSD License for the Cephes Math Library].
+ *
+ * Computes the argument, x, for which the area under the Gaussian probability density function
+ * (integrated from minus infinity to x) is equal to y.
+ */
+template <typename T>
+static inline C10_HOST_DEVICE T calc_ndtri(T y0) {
+
+  /* sqrt(2pi) */
+  constexpr T s2pi = 2.50662827463100050242E0;
+  constexpr T one = 1;
+  constexpr T zero = 0;
+
+  /* approximation for 0 <= |y - 0.5| <= 3/8 */
+  static const T P0[5] = {
+      -5.99633501014107895267E1,
+      9.80010754185999661536E1,
+      -5.66762857469070293439E1,
+      1.39312609387279679503E1,
+      -1.23916583867381258016E0,
+  };
+
+  static const T Q0[9] = {
+      1.00000000000000000000E0,
+      1.95448858338141759834E0,
+      4.67627912898881538453E0,
+      8.63602421390890590575E1,
+      -2.25462687854119370527E2,
+      2.00260212380060660359E2,
+      -8.20372256168333339912E1,
+      1.59056225126211695515E1,
+      -1.18331621121330003142E0,
+  };
+
+  /* Approximation for interval z = sqrt(-2 log y ) between 2 and 8
+   * i.e., y between exp(-2) = .135 and exp(-32) = 1.27e-14.
+   */
+  static const T P1[9] = {
+      4.05544892305962419923E0,
+      3.15251094599893866154E1,
+      5.71628192246421288162E1,
+      4.40805073893200834700E1,
+      1.46849561928858024014E1,
+      2.18663306850790267539E0,
+      -1.40256079171354495875E-1,
+      -3.50424626827848203418E-2,
+      -8.57456785154685413611E-4,
+  };
+
+  static const T Q1[9] = {
+      1.00000000000000000000E0,
+      1.57799883256466749731E1,
+      4.53907635128879210584E1,
+      4.13172038254672030440E1,
+      1.50425385692907503408E1,
+      2.50464946208309415979E0,
+      -1.42182922854787788574E-1,
+      -3.80806407691578277194E-2,
+      -9.33259480895457427372E-4,
+  };
+
+  /* Approximation for interval z = sqrt(-2 log y ) between 8 and 64
+   * i.e., y between exp(-32) = 1.27e-14 and exp(-2048) = 3.67e-890.
+   */
+
+  static const T P2[9] = {
+      3.23774891776946035970E0,
+      6.91522889068984211695E0,
+      3.93881025292474443415E0,
+      1.33303460815807542389E0,
+      2.01485389549179081538E-1,
+      1.23716634817820021358E-2,
+      3.01581553508235416007E-4,
+      2.65806974686737550832E-6,
+      6.23974539184983293730E-9,
+  };
+
+  static const T Q2[9] = {
+      1.00000000000000000000E0,
+      6.02427039364742014255E0,
+      3.67983563856160859403E0,
+      1.37702099489081330271E0,
+      2.16236993594496635890E-1,
+      1.34204006088543189037E-2,
+      3.28014464682127739104E-4,
+      2.89247864745380683936E-6,
+      6.79019408009981274425E-9,
+  };
+
+  if (y0 == zero) {
+    return -std::numeric_limits<T>::infinity();
+  }
+  if (y0 == one) {
+    return std::numeric_limits<T>::infinity();
+  }
+  if (y0 < zero || y0 > one) {
+    return std::numeric_limits<T>::quiet_NaN();
+  }
+  bool code = true;
+  T y = y0;
+  if (y > one - T{0.13533528323661269189}) { /* 0.135... = exp(-2) */
+    y = one - y;
+    code = false;
+  }
+
+  if (y > T{0.13533528323661269189}) {
+    y = y - T{0.5};
+    const T y2 = y * y;
+    T x = y + y * (y2 * polevl(y2, P0, 4) / polevl(y2, Q0, 8));
+    return (x * s2pi);
+  }
+
+  T x = ::sqrt(T{-2.0} * ::log(y));
+  const T x0 = x - ::log(x) / x;
+
+  const T z = one / x;
+  T x1;
+  if (x < T{8.0}) /* y > exp(-32) = 1.2664165549e-14 */
+  {
+    x1 = z * polevl(z, P1, 8) / polevl(z, Q1, 8);
+  } else {
+    x1 = z * polevl(z, P2, 8) / polevl(z, Q2, 8);
+  }
+  x = x0 - x1;
+  if (code) {
+    x = -x;
+  }
+  return x;
+}
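A sketch of the math behind calc_ndtri, assembled from the comments and branch structure above (not text from the commit itself): the function inverts the standard normal CDF \Phi,

    \[
      x = \operatorname{ndtri}(p) \iff \Phi(x) = p, \qquad
      \Phi(x) = \tfrac{1}{2}\bigl(1 + \operatorname{erf}(x/\sqrt{2})\bigr).
    \]

    % Central region |p - 1/2| <= 3/8: with w = (p - 1/2)^2,
    \[
      x \approx \sqrt{2\pi}\,\Bigl(p - \tfrac{1}{2}\Bigr)\Bigl(1 + w\,\frac{P_0(w)}{Q_0(w)}\Bigr).
    \]

    % Tails: with y = min(p, 1 - p) and z = sqrt(-2 ln y),
    \[
      x \approx \pm\Bigl(z - \frac{\ln z}{z} - \frac{1}{z}\,\frac{P_i(1/z)}{Q_i(1/z)}\Bigr),
    \]
    % where P_1/Q_1 covers 2 <= z <= 8 and P_2/Q_2 covers 8 <= z <= 64; the sign is
    % negative on the lower tail (code == true) and positive on the upper tail.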
@@ -66,6 +66,7 @@ CREATE_UNARY_FLOAT_META_FUNC(special_entr)
CREATE_UNARY_FLOAT_META_FUNC(special_i0e)
CREATE_UNARY_FLOAT_META_FUNC(special_i1)
CREATE_UNARY_FLOAT_META_FUNC(special_i1e)
+CREATE_UNARY_FLOAT_META_FUNC(special_ndtri)
CREATE_UNARY_FLOAT_META_FUNC(sqrt)
CREATE_UNARY_FLOAT_META_FUNC(tan)
CREATE_UNARY_FLOAT_META_FUNC(tanh)
@@ -170,6 +171,7 @@ CREATE_UNARY_TORCH_IMPL_FUNC(special_entr_out, special_entr_stub)
CREATE_UNARY_TORCH_IMPL_FUNC(special_i0e_out, special_i0e_stub)
CREATE_UNARY_TORCH_IMPL_FUNC(special_i1e_out, special_i1e_stub)
CREATE_UNARY_TORCH_IMPL_FUNC(special_i1_out, special_i1_stub)
+CREATE_UNARY_TORCH_IMPL_FUNC(special_ndtri_out, special_ndtri_stub)
CREATE_UNARY_TORCH_IMPL_FUNC(sqrt_out, sqrt_stub)
CREATE_UNARY_TORCH_IMPL_FUNC(tan_out, tan_stub)
CREATE_UNARY_TORCH_IMPL_FUNC(tanh_out, tanh_stub)
@@ -759,6 +761,7 @@ DEFINE_DISPATCH(log10_stub); // NOLINT(cppcoreguidelines-avoid-non-const-global-
DEFINE_DISPATCH(log1p_stub); // NOLINT(cppcoreguidelines-avoid-non-const-global-variables)
DEFINE_DISPATCH(log2_stub); // NOLINT(cppcoreguidelines-avoid-non-const-global-variables)
DEFINE_DISPATCH(logical_not_stub); // NOLINT(cppcoreguidelines-avoid-non-const-global-variables)
+DEFINE_DISPATCH(special_ndtri_stub); // NOLINT(cppcoreguidelines-avoid-non-const-global-variables)
DEFINE_DISPATCH(neg_stub); // NOLINT(cppcoreguidelines-avoid-non-const-global-variables)
DEFINE_DISPATCH(nan_to_num_stub); // NOLINT(cppcoreguidelines-avoid-non-const-global-variables)
DEFINE_DISPATCH(polygamma_stub); // NOLINT(cppcoreguidelines-avoid-non-const-global-variables)
@@ -45,6 +45,7 @@ DECLARE_DISPATCH(unary_fn, log_stub);
DECLARE_DISPATCH(unary_fn, log10_stub);
DECLARE_DISPATCH(unary_fn, log1p_stub);
DECLARE_DISPATCH(unary_fn, log2_stub);
+DECLARE_DISPATCH(unary_fn, special_ndtri_stub);
DECLARE_DISPATCH(unary_fn, neg_stub);

DECLARE_DISPATCH(unary_fn, reciprocal_stub);
@@ -600,6 +600,13 @@ static void frexp_kernel(TensorIteratorBase& iter) {
  });
}

+static void ndtri_kernel(TensorIteratorBase& iter) {
+  TORCH_INTERNAL_ASSERT(iter.ntensors() == 2);
+  AT_DISPATCH_FLOATING_TYPES(iter.common_dtype(), "ndtri_cpu", [&]() {
+    cpu_kernel(iter, [](scalar_t x) { return calc_ndtri(x); });
+  });
+}
+
static void i0e_kernel(TensorIteratorBase& iter) {
  TORCH_INTERNAL_ASSERT(iter.ntensors() == 2);
  AT_DISPATCH_FLOATING_TYPES_AND(

@@ -765,6 +772,8 @@ REGISTER_DISPATCH(frexp_stub, &CPU_CAPABILITY::frexp_kernel);
// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables)
REGISTER_DISPATCH(special_i0e_stub, &CPU_CAPABILITY::i0e_kernel);
+// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables)
+REGISTER_DISPATCH(special_ndtri_stub, &CPU_CAPABILITY::ndtri_kernel);
// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables)
REGISTER_DISPATCH(special_i1_stub, &CPU_CAPABILITY::i1_kernel);
// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables)
REGISTER_DISPATCH(special_i1e_stub, &CPU_CAPABILITY::i1e_kernel);
@@ -114,6 +114,13 @@ void logit_kernel_cuda(TensorIteratorBase& iter, const Scalar& eps_scalar) {
  });
}

+void ndtri_kernel_cuda(TensorIteratorBase& iter) {
+  AT_DISPATCH_FLOATING_TYPES(iter.common_dtype(), "ndtri_cuda", [&]() {
+    gpu_kernel(
+        iter, [] GPU_LAMBDA(scalar_t a) -> scalar_t { return calc_ndtri(a); });
+  });
+}
+
void erf_kernel_cuda(TensorIteratorBase& iter) {
  AT_DISPATCH_FLOATING_TYPES_AND2(at::ScalarType::Half, at::ScalarType::BFloat16, iter.common_dtype(), "erf_cuda", [&]() {
    gpu_kernel(iter, []GPU_LAMBDA(scalar_t a) -> scalar_t {

@@ -188,6 +195,7 @@ REGISTER_DISPATCH(erfc_stub, &erfc_kernel_cuda);
REGISTER_DISPATCH(erfinv_stub, &erfinv_kernel_cuda);
REGISTER_DISPATCH(kaiser_window_stub, &kaiser_window_kernel_cuda);
REGISTER_DISPATCH(special_entr_stub, &entr_kernel_cuda);
+REGISTER_DISPATCH(special_ndtri_stub, &ndtri_kernel_cuda);

} // namespace native
} // namespace at
@@ -9445,6 +9445,19 @@
  dispatch:
    CPU, CUDA: special_entr_out

+- func: special_ndtri(Tensor self) -> Tensor
+  structured_delegate: special_ndtri.out
+  python_module: special
+  variants: function
+
+- func: special_ndtri.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!)
+  structured: True
+  structured_inherits: TensorIteratorBase
+  python_module: special
+  variants: function
+  dispatch:
+    CPU, CUDA: special_ndtri_out
+
- func: special_expm1(Tensor self) -> Tensor
  python_module: special
  variants: function
@@ -34,4 +34,5 @@ Functions
.. autofunction:: i1e
.. autofunction:: logit
.. autofunction:: ndtr
+.. autofunction:: ndtri
.. autofunction:: xlog1py
@@ -1108,6 +1108,9 @@
- name: special_entr(Tensor self) -> Tensor
  self: grad * (-(1 + self.log()))

+- name: special_ndtri(Tensor self) -> Tensor
+  self: grad * std::sqrt(2 * M_PI) * (result.square() / 2).exp()
+
# DO NOT define a backward for reshape!
# reshape is special in that it sometimes returns a view, and sometimes not.
# Defining a backward will make codegen spit out the forward call as
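For reference (a short derivation, not part of derivatives.yaml): differentiating Phi(ndtri(p)) = p, where Phi is the standard normal CDF with density phi, gives

    \[
      \frac{d}{dp}\,\operatorname{ndtri}(p)
        = \frac{1}{\varphi\bigl(\operatorname{ndtri}(p)\bigr)}
        = \sqrt{2\pi}\; e^{\operatorname{ndtri}(p)^{2}/2},
    \]

which is exactly what grad * std::sqrt(2 * M_PI) * (result.square() / 2).exp() computes, with result being the forward output ndtri(self).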
@@ -117,6 +117,14 @@ inline Tensor& erfinv_out(Tensor& result, const Tensor& self) {
  return torch::special_erfinv_out(result, self);
}

+inline Tensor ndtri(const Tensor& self) {
+  return torch::special_ndtri(self);
+}
+
+inline Tensor& ndtri_out(Tensor& result, const Tensor& self) {
+  return torch::special_ndtri_out(result, self);
+}
+
/// Computes the logit of input, elementwise.
/// See https://pytorch.org/docs/master/special.html#torch.special.logit.
///
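A hedged usage sketch for the new C++ helpers (assuming, as for the surrounding functions in torch/special.h, that they sit in the torch::special namespace; the tensor values are illustrative only):

    #include <torch/torch.h>
    #include <iostream>

    int main() {
      // Quantiles of the standard normal distribution.
      auto p = torch::tensor({0.1, 0.25, 0.5, 0.75, 0.9});
      std::cout << torch::special::ndtri(p) << "\n";
      // Expected roughly: -1.2816, -0.6745, 0.0000, 0.6745, 1.2816

      // Out variant, mirroring ndtri_out(Tensor& result, const Tensor& self).
      auto out = torch::empty_like(p);
      torch::special::ndtri_out(out, p);
      std::cout << out << "\n";
      return 0;
    }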
@@ -882,6 +882,7 @@ def get_testing_overrides() -> Dict[Callable, Callable]:
        torch.special.i1: lambda input: -1,
        torch.special.i1e: lambda input: -1,
        torch.special.logit: lambda input: -1,
+        torch.special.ndtri: lambda input: -1,
        torch.special.ndtr: lambda input: -1,
        torch.special.xlog1py: lambda input, other, out=None: -1,
        torch.t: lambda input: -1,
@@ -342,8 +342,10 @@ for each element of :attr:`input`.
""" + r"""
Args:
    {input}
+
Keyword args:
    {out}
+
Example::
    >>> torch.special.i0e(torch.arange(5, dtype=torch.float32))
    tensor([1.0000, 0.4658, 0.3085, 0.2430, 0.2070])
@@ -361,8 +363,10 @@ for each element of :attr:`input`.
""" + r"""
Args:
    {input}
+
Keyword args:
    {out}
+
Example::
    >>> torch.special.i1(torch.arange(5, dtype=torch.float32))
    tensor([0.0000, 0.5652, 1.5906, 3.9534, 9.7595])
@@ -381,8 +385,10 @@ for each element of :attr:`input`.
""" + r"""
Args:
    {input}
+
Keyword args:
    {out}
+
Example::
    >>> torch.special.i1e(torch.arange(5, dtype=torch.float32))
    tensor([0.0000, 0.2079, 0.2153, 0.1968, 0.1788])
@@ -408,3 +414,27 @@ Example::
    >>> torch.special.ndtr(torch.tensor([-3., -2, -1, 0, 1, 2, 3]))
    tensor([0.0013, 0.0228, 0.1587, 0.5000, 0.8413, 0.9772, 0.9987])
""".format(**common_args))
+
+ndtri = _add_docstr(_special.special_ndtri,
+                    r"""
+ndtri(input, *, out=None) -> Tensor
+Computes the argument, x, for which the area under the Gaussian probability density function
+(integrated from minus infinity to x) is equal to :attr:`input`, elementwise.
+
+.. math::
+    \text{ndtri}(p) = \sqrt{2}\text{erf}^{-1}(2p - 1)
+
+.. note::
+    Also known as quantile function for Normal Distribution.
+
+""" + r"""
+Args:
+    {input}
+
+Keyword args:
+    {out}
+
+Example::
+    >>> torch.special.ndtri(torch.tensor([0, 0.25, 0.5, 0.75, 1]))
+    tensor([ -inf, -0.6745, 0.0000, 0.6745, inf])
+""".format(**common_args))
@@ -7341,6 +7341,12 @@ op_db: List[OpInfo] = [
                   supports_inplace_autograd=False,
                   safe_casts_outputs=True,
                   sample_inputs_func=sample_inputs_entr),
+    UnaryUfuncInfo('special.ndtri',
+                   ref=scipy.special.ndtri if TEST_SCIPY else _NOTHING,
+                   domain=(0, 1),
+                   aten_name='special_ndtri',
+                   dtypes=all_types_and(torch.bool),
+                   safe_casts_outputs=True),
    UnaryUfuncInfo('erf',
                   ref=scipy.special.erf if TEST_SCIPY else _NOTHING,
                   aliases=('special.erf', ),