[special] Add i0e (#54409)

Summary:
Reference: https://github.com/pytorch/pytorch/issues/50345

Changes:
* Add `i0e`
* Move some kernels from `UnaryOpsKernel.cu` to `UnarySpecialOpsKernel.cu` to decrease compilation time per file.

Time taken by i0e_vs_scipy tests: around 6.33s

<details>

<summary>Test Run Log</summary>

```
(pytorch-cuda-dev) kshiteej@qgpu1:~/Pytorch/pytorch_module_special$ pytest test/test_unary_ufuncs.py -k _i0e_vs
======================================================================= test session starts ========================================================================
platform linux -- Python 3.8.6, pytest-6.1.2, py-1.9.0, pluggy-0.13.1
rootdir: /home/kshiteej/Pytorch/pytorch_module_special, configfile: pytest.ini
plugins: hypothesis-5.38.1
collected 8843 items / 8833 deselected / 10 selected

test/test_unary_ufuncs.py ...sss....                                                                                                                         [100%]

========================================================================= warnings summary =========================================================================
../../.conda/envs/pytorch-cuda-dev/lib/python3.8/site-packages/torch/backends/cudnn/__init__.py:73
test/test_unary_ufuncs.py::TestUnaryUfuncsCUDA::test_special_i0e_vs_scipy_cuda_bfloat16
  /home/kshiteej/.conda/envs/pytorch-cuda-dev/lib/python3.8/site-packages/torch/backends/cudnn/__init__.py:73: UserWarning: PyTorch was compiled without cuDNN/MIOpen support. To use cuDNN/MIOpen, rebuild PyTorch making sure the library is visible to the build system.
    warnings.warn(

-- Docs: https://docs.pytest.org/en/stable/warnings.html
===================================================================== short test summary info ======================================================================
SKIPPED [3] test/test_unary_ufuncs.py:1182: not implemented: Could not run 'aten::_copy_from' with arguments from the 'Meta' backend. This could be because the operator doesn't exist for this backend, or was omitted during the selective/custom build process (if using custom build). If you are a Facebook employee using PyTorch on mobile, please visit https://fburl.com/ptmfixes for possible resolutions. 'aten::_copy_from' is only available for these backends: [BackendSelect, Named, InplaceOrView, AutogradOther, AutogradCPU, AutogradCUDA, AutogradXLA, UNKNOWN_TENSOR_TYPE_ID, AutogradMLC, AutogradNestedTensor, AutogradPrivateUse1, AutogradPrivateUse2, AutogradPrivateUse3, Tracer, Autocast, Batched, VmapMode].

BackendSelect: fallthrough registered at ../aten/src/ATen/core/BackendSelectFallbackKernel.cpp:3 [backend fallback]
Named: registered at ../aten/src/ATen/core/NamedRegistrations.cpp:7 [backend fallback]
InplaceOrView: fallthrough registered at ../aten/src/ATen/core/VariableFallbackKernel.cpp:56 [backend fallback]
AutogradOther: registered at ../torch/csrc/autograd/generated/VariableType_4.cpp:8761 [autograd kernel]
AutogradCPU: registered at ../torch/csrc/autograd/generated/VariableType_4.cpp:8761 [autograd kernel]
AutogradCUDA: registered at ../torch/csrc/autograd/generated/VariableType_4.cpp:8761 [autograd kernel]
AutogradXLA: registered at ../torch/csrc/autograd/generated/VariableType_4.cpp:8761 [autograd kernel]
UNKNOWN_TENSOR_TYPE_ID: registered at ../torch/csrc/autograd/generated/VariableType_4.cpp:8761 [autograd kernel]
AutogradMLC: registered at ../torch/csrc/autograd/generated/VariableType_4.cpp:8761 [autograd kernel]
AutogradNestedTensor: registered at ../torch/csrc/autograd/generated/VariableType_4.cpp:8761 [autograd kernel]
AutogradPrivateUse1: registered at ../torch/csrc/autograd/generated/VariableType_4.cpp:8761 [autograd kernel]
AutogradPrivateUse2: registered at ../torch/csrc/autograd/generated/VariableType_4.cpp:8761 [autograd kernel]
AutogradPrivateUse3: registered at ../torch/csrc/autograd/generated/VariableType_4.cpp:8761 [autograd kernel]
Tracer: registered at ../torch/csrc/autograd/generated/TraceType_4.cpp:9348 [kernel]
Autocast: fallthrough registered at ../aten/src/ATen/autocast_mode.cpp:250 [backend fallback]
Batched: registered at ../aten/src/ATen/BatchingRegistrations.cpp:1016 [backend fallback]
VmapMode: fallthrough registered at ../aten/src/ATen/VmapModeRegistrations.cpp:33 [backend fallback]
==================================================== 7 passed, 3 skipped, 8833 deselected, 2 warnings in 6.33s =====================================================
```

</details>

TODO:
* [x] Check rendered docs (https://11743402-65600975-gh.circle-artifacts.com/0/docs/special.html)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/54409

Reviewed By: jbschlosser

Differential Revision: D27760472

Pulled By: mruberry

fbshipit-source-id: bdfbcaa798b00c51dc9513c34626246c8fc10548
This commit is contained in:
kshitij12345
2021-04-15 06:04:44 -07:00
committed by Facebook GitHub Bot
parent 2f895f790a
commit 50057e560b
24 changed files with 468 additions and 273 deletions

View File

@ -508,6 +508,7 @@ filegroup(
"aten/src/ATen/native/cuda/TensorTransformations.cu.cc",
"aten/src/ATen/native/cuda/TriangularOps.cu.cc",
"aten/src/ATen/native/cuda/UnaryOpsKernel.cu.cc",
"aten/src/ATen/native/cuda/UnarySpecialOpsKernel.cu.cc",
"aten/src/ATen/native/cuda/Unique.cu.cc",
"aten/src/ATen/native/cuda/UpSampleBicubic2d.cu.cc",
"aten/src/ATen/native/cuda/UpSampleBilinear2d.cu.cc",

View File

@ -327,6 +327,7 @@ namespace c10 {
_(aten, special_expm1) \
_(aten, exp2) \
_(aten, special_exp2) \
_(aten, special_i0e) \
_(aten, has_torch_function) \
FORALL_ATEN_BASE_SYMBOLS(_) \
_(onnx, Add) \

View File

@ -392,6 +392,9 @@ public:
Vec256<T> i0() const {
return map(calc_i0);
}
// Exponentially scaled i0: forwards calc_i0e to map().
// NOTE(review): map() is defined outside this excerpt; presumably it applies
// calc_i0e to each element of the vector — confirm.
Vec256<T> i0e() const {
return map(calc_i0e);
}
Vec256<T> igamma(const Vec256<T> &x) const {
Vec256<T> ret;
for (int64_t i = 0; i < size(); i++) {

View File

@ -315,6 +315,22 @@ public:
auto o2 = _mm256_loadu_ps(tmp2);
return cvtfp32_bf16(o1, o2);
}
// Exponentially scaled i0 for a bfloat16 vector.
// The packed values are split into two __m256 float registers, spilled to
// scratch arrays, run through calc_i0e one element at a time, and packed back.
// NOTE(review): cvtbf16_fp32 / cvtfp32_bf16 are assumed to convert between the
// bfloat16 representation and two fp32 halves — confirm against their definitions.
Vec256<BFloat16> i0e() const {
  __m256 lower_half, upper_half;
  cvtbf16_fp32(values, lower_half, upper_half);

  auto lanes = size();
  __at_align32__ float lower_buf[lanes / 2], upper_buf[lanes / 2];
  // The buffers are already float arrays, so they can be handed to the
  // unaligned store/load intrinsics directly.
  _mm256_storeu_ps(lower_buf, lower_half);
  _mm256_storeu_ps(upper_buf, upper_half);

  for (decltype(lanes) idx = 0; idx < lanes / 2; ++idx) {
    lower_buf[idx] = calc_i0e(lower_buf[idx]);
    upper_buf[idx] = calc_i0e(upper_buf[idx]);
  }

  return cvtfp32_bf16(_mm256_loadu_ps(lower_buf), _mm256_loadu_ps(upper_buf));
}
Vec256<BFloat16> igamma(const Vec256<BFloat16> &x) const {
__m256 lo, hi;
__m256 xlo, xhi;

View File

@ -164,6 +164,9 @@ public:
Vec256<double> i0() const {
return map(calc_i0);
}
// Exponentially scaled i0: forwards calc_i0e to map().
// NOTE(review): map() is defined outside this excerpt; presumably it applies
// calc_i0e to each double element — confirm.
Vec256<double> i0e() const {
return map(calc_i0e);
}
Vec256<double> igamma(const Vec256<double> &x) const {
__at_align32__ double tmp[size()];
__at_align32__ double tmp_x[size()];

View File

@ -202,6 +202,9 @@ public:
Vec256<float> i0() const {
return map(calc_i0);
}
// Exponentially scaled i0: forwards calc_i0e to map().
// NOTE(review): map() is defined outside this excerpt; presumably it applies
// calc_i0e to each float element — confirm.
Vec256<float> i0e() const {
return map(calc_i0e);
}
Vec256<float> igamma(const Vec256<float> &x) const {
__at_align32__ float tmp[size()];
__at_align32__ float tmp_x[size()];

View File

@ -362,6 +362,9 @@ public:
Vec256<float> i0() const {
return map(calc_i0);
}
// Exponentially scaled i0: forwards calc_i0e to map().
// NOTE(review): map() is defined outside this excerpt; presumably it applies
// calc_i0e to each float element — confirm.
Vec256<float> i0e() const {
return map(calc_i0e);
}
Vec256<float> igamma(const Vec256<float> &x) const {
__at_align32__ float tmp[size()];
__at_align32__ float tmp_x[size()];

View File

@ -351,6 +351,10 @@ class Vec256<double> {
return map(calc_i0);
}
// Exponentially scaled i0: forwards calc_i0e to map().
// NOTE(review): map() is defined outside this excerpt; presumably it applies
// calc_i0e to each double element — confirm.
Vec256<double> i0e() const {
return map(calc_i0e);
}
DEFINE_MEMBER_OP(operator==, double, vec_cmpeq)
DEFINE_MEMBER_OP(operator!=, double, vec_cmpne)
DEFINE_MEMBER_OP(operator<, double, vec_cmplt)

View File

@ -637,6 +637,10 @@ class Vec256<float> {
return map(calc_i0);
}
// Exponentially scaled i0: forwards calc_i0e to map().
// NOTE(review): map() is defined outside this excerpt; presumably it applies
// calc_i0e to each float element — confirm.
Vec256<float> i0e() const {
return map(calc_i0e);
}
DEFINE_MEMBER_OP(operator==, float, vec_cmpeq)
DEFINE_MEMBER_OP(operator!=, float, vec_cmpne)
DEFINE_MEMBER_OP(operator<, float, vec_cmplt)

View File

@ -129,6 +129,7 @@ IMPLEMENT_VML_BUG(exp)
IMPLEMENT_VML_BUG(expm1)
IMPLEMENT_VML_BUG(floor)
IMPLEMENT_VML(i0)
IMPLEMENT_VML(i0e)
IMPLEMENT_VML(reciprocal)
IMPLEMENT_VML_BUG(log)
IMPLEMENT_VML_BUG(log10)

View File

@ -1161,7 +1161,7 @@ calc_gcd(T a, T b) {
*/
template <typename T>
static inline typename std::enable_if<std::is_floating_point<T>::value, T>::type
chbevl(T x, T array[], size_t len) {
chbevl(const T x,const T array[], size_t len) {
T b0, b1, b2;
b0 = array[0];
@ -1186,87 +1186,98 @@ static inline typename std::enable_if<std::is_floating_point<T>::value, T>::type
* of all inputs to convert them into the domain of the approximation.
*/
template <typename T>
static inline typename std::enable_if<std::is_floating_point<T>::value, T>::type
calc_i0(T _x) {
T x = std::abs(_x);
inline const T* chebyshev_coefficients_A(){
/* Chebyshev coefficients for exp(-x) I0(x)
* in the interval [0,8].
*
* lim(x->0){ exp(-x) I0(x) } = 1.
*/
static T A[] = {
-4.41534164647933937950E-18,
3.33079451882223809783E-17,
-2.43127984654795469359E-16,
1.71539128555513303061E-15,
-1.16853328779934516808E-14,
7.67618549860493561688E-14,
-4.85644678311192946090E-13,
2.95505266312963983461E-12,
-1.72682629144155570723E-11,
9.67580903537323691224E-11,
-5.18979560163526290666E-10,
2.65982372468238665035E-9,
-1.30002500998624804212E-8,
6.04699502254191894932E-8,
-2.67079385394061173391E-7,
1.11738753912010371815E-6,
-4.41673835845875056359E-6,
1.64484480707288970893E-5,
-5.75419501008210370398E-5,
1.88502885095841655729E-4,
-5.76375574538582365885E-4,
1.63947561694133579842E-3,
-4.32430999505057594430E-3,
1.05464603945949983183E-2,
-2.37374148058994688156E-2,
4.93052842396707084878E-2,
-9.49010970480476444210E-2,
1.71620901522208775349E-1,
-3.04682672343198398683E-1,
6.76795274409476084995E-1
};
static const T coeff[] = {
-4.41534164647933937950E-18, 3.33079451882223809783E-17,
-2.43127984654795469359E-16, 1.71539128555513303061E-15,
-1.16853328779934516808E-14, 7.67618549860493561688E-14,
-4.85644678311192946090E-13, 2.95505266312963983461E-12,
-1.72682629144155570723E-11, 9.67580903537323691224E-11,
-5.18979560163526290666E-10, 2.65982372468238665035E-9,
-1.30002500998624804212E-8, 6.04699502254191894932E-8,
-2.67079385394061173391E-7, 1.11738753912010371815E-6,
-4.41673835845875056359E-6, 1.64484480707288970893E-5,
-5.75419501008210370398E-5, 1.88502885095841655729E-4,
-5.76375574538582365885E-4, 1.63947561694133579842E-3,
-4.32430999505057594430E-3, 1.05464603945949983183E-2,
-2.37374148058994688156E-2, 4.93052842396707084878E-2,
-9.49010970480476444210E-2, 1.71620901522208775349E-1,
-3.04682672343198398683E-1, 6.76795274409476084995E-1};
return coeff;
};
template <typename T>
inline const T* chebyshev_coefficients_B(){
/* Chebyshev coefficients for exp(-x) sqrt(x) I0(x)
* in the inverted interval [8,infinity].
*
* lim(x->inf){ exp(-x) sqrt(x) I0(x) } = 1/sqrt(2pi).
*/
static T B[] = {
-7.23318048787475395456E-18,
-4.83050448594418207126E-18,
4.46562142029675999901E-17,
3.46122286769746109310E-17,
-2.82762398051658348494E-16,
-3.42548561967721913462E-16,
1.77256013305652638360E-15,
3.81168066935262242075E-15,
-9.55484669882830764870E-15,
-4.15056934728722208663E-14,
1.54008621752140982691E-14,
3.85277838274214270114E-13,
7.18012445138366623367E-13,
-1.79417853150680611778E-12,
-1.32158118404477131188E-11,
-3.14991652796324136454E-11,
1.18891471078464383424E-11,
4.94060238822496958910E-10,
3.39623202570838634515E-9,
2.26666899049817806459E-8,
2.04891858946906374183E-7,
2.89137052083475648297E-6,
6.88975834691682398426E-5,
3.36911647825569408990E-3,
8.04490411014108831608E-1
};
static const T coeff[] = {
-7.23318048787475395456E-18, -4.83050448594418207126E-18,
4.46562142029675999901E-17, 3.46122286769746109310E-17,
-2.82762398051658348494E-16, -3.42548561967721913462E-16,
1.77256013305652638360E-15, 3.81168066935262242075E-15,
-9.55484669882830764870E-15, -4.15056934728722208663E-14,
1.54008621752140982691E-14, 3.85277838274214270114E-13,
7.18012445138366623367E-13, -1.79417853150680611778E-12,
-1.32158118404477131188E-11, -3.14991652796324136454E-11,
1.18891471078464383424E-11, 4.94060238822496958910E-10,
3.39623202570838634515E-9, 2.26666899049817806459E-8,
2.04891858946906374183E-7, 2.89137052083475648297E-6,
6.88975834691682398426E-5, 3.36911647825569408990E-3,
8.04490411014108831608E-1};
return coeff;
};
/*
 * Zeroth order modified Bessel function of the first kind, evaluated at |_x|.
 *
 * Two Chebyshev expansions are used, selected by the magnitude of the input:
 *   |_x| <= 8 : 30 terms from chebyshev_coefficients_A, argument |_x|/2 - 2
 *   otherwise : 25 terms from chebyshev_coefficients_B, argument 32/|_x| - 2,
 *               with the series additionally divided by sqrt(|_x|)
 * Both tables describe exp(-x)-scaled quantities, so the series is multiplied
 * by exp(|_x|) to undo the scaling.
 */
template <typename T>
static inline typename std::enable_if<std::is_floating_point<T>::value, T>::type
calc_i0(T _x) {
  const T magnitude = std::abs(_x);

  if (magnitude <= 8.0) {
    const T* small_coeffs = chebyshev_coefficients_A<T>();
    const T shifted = (magnitude / 2.0) - 2.0;
    const T series = chbevl(shifted, small_coeffs, 30);
    return static_cast<T>(std::exp(magnitude) * series);
  } else {
    const T* large_coeffs = chebyshev_coefficients_B<T>();
    const T inverted = static_cast<T>(32.0 / magnitude - 2.0);
    const T series = chbevl(inverted, large_coeffs, 25);
    return static_cast<T>(std::exp(magnitude) * series / std::sqrt(magnitude));
  }
}
// Upcast bfloat16 input to float for numerical accuracy purposes.
// The computation runs through the floating-point calc_i0 template with a
// float argument; its float result is converted back to c10::BFloat16 by the
// return statement.
inline c10::BFloat16 calc_i0(c10::BFloat16 a) { return calc_i0(static_cast<float>(a)); }
/*
* This function is derived from the implementation of the i0e function in the Cephes Math Library.
* See note [3-Clause BSD License for the Cephes Math Library].
*
* Computes an approximation of the exponentially scaled zeroth order modified Bessel function of the first kind.
* The approximation is actually two (sub)approximations, both using a Chebyshev polynomial expansion.
* One approximates the function over [0, 8], and the other over (8, infinity). This function takes the absolute value
* of all inputs to convert them into the domain of the approximation.
*/
/*
 * Exponentially scaled zeroth order modified Bessel function of the first
 * kind, evaluated at |_x|.
 *
 * Same two Chebyshev expansions as calc_i0, but without the exp(|_x|) factor:
 *   |_x| <= 8 : 30 terms from chebyshev_coefficients_A, argument |_x|/2 - 2
 *   otherwise : 25 terms from chebyshev_coefficients_B, argument 32/|_x| - 2,
 *               divided by sqrt(|_x|)
 */
template <typename T>
static inline typename std::enable_if<std::is_floating_point<T>::value, T>::type
calc_i0e(T _x) {
  const T magnitude = std::abs(_x);

  if (magnitude <= 8.0) {
    const T* small_coeffs = chebyshev_coefficients_A<T>();
    const T shifted = (magnitude / 2.0) - 2.0;
    return static_cast<T>(chbevl(shifted, small_coeffs, 30));
  } else {
    const T* large_coeffs = chebyshev_coefficients_B<T>();
    const T inverted = static_cast<T>(32.0 / magnitude - 2.0);
    return static_cast<T>(chbevl(inverted, large_coeffs, 25) / std::sqrt(magnitude));
  }
}
// Upcast bfloat16 input to float for numerical accuracy purposes.
// The computation runs through the floating-point calc_i0e template with a
// float argument; its float result is converted back to c10::BFloat16 by the
// return statement.
inline c10::BFloat16 calc_i0e(c10::BFloat16 a) { return calc_i0e(static_cast<float>(a)); }

View File

@ -40,6 +40,7 @@ CREATE_UNARY_META_FUNC(sinh)
CREATE_UNARY_META_FUNC(cosh)
CREATE_UNARY_META_FUNC(acosh)
CREATE_UNARY_META_FUNC(cos)
CREATE_UNARY_META_FUNC(special_i0e)
} // namespace meta
@ -354,6 +355,10 @@ Tensor& i0_out(const Tensor& self, Tensor& result) { return unary_op_impl_out(re
Tensor i0(const Tensor& self) { return unary_op_impl(self, at::i0_out); }
Tensor& i0_(Tensor& self) { return unary_op_impl_(self, at::i0_out); }
// Implementation body for special_i0e_out, defined through TORCH_IMPL_FUNC:
// invokes i0e_stub with the current device type and `*this`.
// NOTE(review): `self` and `result` are not referenced directly here; `*this`
// presumably carries the iterator prepared by the special_i0e meta function —
// confirm against the TORCH_IMPL_FUNC / CREATE_UNARY_META_FUNC definitions.
TORCH_IMPL_FUNC(special_i0e_out) (const Tensor& self, const Tensor& result) {
i0e_stub(device_type(), *this);
}
Tensor& log_out(const Tensor& self, Tensor& result) { return unary_op_impl_float_out(result, self, log_stub); }
Tensor log(const Tensor& self) { return unary_op_impl_float(self, log_stub); }
Tensor& log_(Tensor& self) { return unary_op_impl_(self, at::log_out); }
@ -804,6 +809,7 @@ DEFINE_DISPATCH(floor_stub); // NOLINT(cppcoreguidelines-avoid-non-const-global-
DEFINE_DISPATCH(frac_stub); // NOLINT(cppcoreguidelines-avoid-non-const-global-variables)
DEFINE_DISPATCH(frexp_stub); // NOLINT(cppcoreguidelines-avoid-non-const-global-variables)
DEFINE_DISPATCH(i0_stub); // NOLINT(cppcoreguidelines-avoid-non-const-global-variables)
DEFINE_DISPATCH(i0e_stub); // NOLINT(cppcoreguidelines-avoid-non-const-global-variables)
DEFINE_DISPATCH(log_stub); // NOLINT(cppcoreguidelines-avoid-non-const-global-variables)
DEFINE_DISPATCH(log10_stub); // NOLINT(cppcoreguidelines-avoid-non-const-global-variables)
DEFINE_DISPATCH(log1p_stub); // NOLINT(cppcoreguidelines-avoid-non-const-global-variables)
@ -829,5 +835,6 @@ DEFINE_DISPATCH(tanh_stub); // NOLINT(cppcoreguidelines-avoid-non-const-global-v
DEFINE_DISPATCH(trigamma_stub); // NOLINT(cppcoreguidelines-avoid-non-const-global-variables)
DEFINE_DISPATCH(trunc_stub); // NOLINT(cppcoreguidelines-avoid-non-const-global-variables)
DEFINE_DISPATCH(lgamma_stub); // NOLINT(cppcoreguidelines-avoid-non-const-global-variables)
} // namespace native
} // namespace at

View File

@ -43,6 +43,7 @@ DECLARE_DISPATCH(unary_fn, floor_stub);
DECLARE_DISPATCH(unary_fn, frac_stub);
DECLARE_DISPATCH(unary_fn, frexp_stub);
DECLARE_DISPATCH(unary_fn, i0_stub);
DECLARE_DISPATCH(structured_unary_fn, i0e_stub);
DECLARE_DISPATCH(unary_fn, log_stub);
DECLARE_DISPATCH(unary_fn, log10_stub);
DECLARE_DISPATCH(unary_fn, log1p_stub);

View File

@ -631,6 +631,17 @@ static void frexp_kernel(TensorIterator& iter) {
});
}
// CPU kernel for i0e.
// Asserts that the iterator holds exactly two tensors, then dispatches on
// iter.common_dtype() (floating types plus kBFloat16 per the dispatch macro
// arguments) and hands cpu_kernel_vec a scalar functor built on calc_i0e and
// a vector functor built on Vec256<scalar_t>::i0e.
static void i0e_kernel(TensorIteratorBase& iter) {
  TORCH_INTERNAL_ASSERT(iter.ntensors() == 2);
  AT_DISPATCH_FLOATING_TYPES_AND(
      kBFloat16, iter.common_dtype(), "i0e_cpu", [&]() {
        cpu_kernel_vec(
            iter,
            [](scalar_t value) -> scalar_t { return calc_i0e(value); },
            [](Vec256<scalar_t> lanes) -> Vec256<scalar_t> {
              return lanes.i0e();
            });
      });
}
// TODO: Disable cont. branch to test more risky code
#define IMPLEMENT_ITERATOR_LAMBDA(op) \
@ -736,6 +747,7 @@ REGISTER_DISPATCH(clamp_min_stub, &clamp_min_kernel);
REGISTER_DISPATCH(kaiser_window_stub, &kaiser_window_kernel);
REGISTER_DISPATCH(entr_stub, &entr_kernel);
REGISTER_DISPATCH(frexp_stub, &frexp_kernel);
REGISTER_DISPATCH(i0e_stub, &i0e_kernel);
IMPLEMENT_COMPLEX_KERNEL(acos)

View File

@ -10,7 +10,7 @@ namespace native {
* For licensing information, please refer to the CPU implementation located in "ATen/native/Math.h".
*/
template <typename scalar_t>
static inline __host__ __device__ scalar_t zeta(scalar_t _x, scalar_t _q) {
static inline C10_HOST_DEVICE scalar_t zeta(scalar_t _x, scalar_t _q) {
using accscalar_t = at::acc_type<scalar_t, true>;
static const accscalar_t MACHEP = 1.11022302462515654042E-16;
const accscalar_t A[] = {
@ -92,7 +92,7 @@ static inline __host__ __device__ scalar_t zeta(scalar_t _x, scalar_t _q) {
* For licensing information, please refer to the CPU implementation located in "ATen/native/Math.h".
*/
template <typename scalar_t>
static inline __host__ __device__ scalar_t calc_digamma(scalar_t in) {
static inline C10_HOST_DEVICE scalar_t calc_digamma(scalar_t in) {
// [C++ Standard Reference: Gamma Function] https://en.cppreference.com/w/cpp/numeric/math/tgamma
using accscalar_t = at::acc_type<scalar_t, /*is_cuda=*/true>;
static const double PI_f64 = 3.14159265358979323846;
@ -151,7 +151,7 @@ static inline __host__ __device__ scalar_t calc_digamma(scalar_t in) {
}
template <typename scalar_t>
static inline __host__ __device__ scalar_t calc_trigamma(scalar_t in) {
static inline C10_HOST_DEVICE scalar_t calc_trigamma(scalar_t in) {
using accscalar_t = at::acc_type<scalar_t, /*is_cuda=*/true>;
const accscalar_t PI = 3.14159265358979323846;
accscalar_t x = static_cast<accscalar_t>(in);
@ -174,7 +174,7 @@ static inline __host__ __device__ scalar_t calc_trigamma(scalar_t in) {
}
template <typename scalar_t>
static inline __host__ __device__ scalar_t calc_polygamma(int n, scalar_t x) {
static inline C10_HOST_DEVICE scalar_t calc_polygamma(int n, scalar_t x) {
// already blocked if n <= 1
return ((n % 2) ? 1.0 : -1.0) * ::exp(::lgamma(static_cast<scalar_t>(n) + 1.0)) * zeta(static_cast<scalar_t>(n + 1), x);
}
@ -216,6 +216,58 @@ static inline C10_HOST_DEVICE scalar_t chbevl(scalar_t _x, const scalar_t array[
/*
* For licensing information and documentation, please refer to the CPU implementation located in "ATen/native/Math.h".
*/
// Returns a pointer to the first element of the 30-entry Chebyshev coefficient
// table used for the small-argument branch (x <= 8) of calc_i0 / calc_i0e.
// The table is a function-local `static const` array of T; the element count
// is not carried with the pointer, so callers pass it to chbevl explicitly.
template <typename T>
C10_HOST_DEVICE inline const T* chebyshev_coefficients_A() {
/* Chebyshev coefficients for exp(-x) I0(x)
* in the interval [0,8].
*
* lim(x->0){ exp(-x) I0(x) } = 1.
*/
static const T coefficients[] = {
-4.41534164647933937950E-18, 3.33079451882223809783E-17,
-2.43127984654795469359E-16, 1.71539128555513303061E-15,
-1.16853328779934516808E-14, 7.67618549860493561688E-14,
-4.85644678311192946090E-13, 2.95505266312963983461E-12,
-1.72682629144155570723E-11, 9.67580903537323691224E-11,
-5.18979560163526290666E-10, 2.65982372468238665035E-9,
-1.30002500998624804212E-8, 6.04699502254191894932E-8,
-2.67079385394061173391E-7, 1.11738753912010371815E-6,
-4.41673835845875056359E-6, 1.64484480707288970893E-5,
-5.75419501008210370398E-5, 1.88502885095841655729E-4,
-5.76375574538582365885E-4, 1.63947561694133579842E-3,
-4.32430999505057594430E-3, 1.05464603945949983183E-2,
-2.37374148058994688156E-2, 4.93052842396707084878E-2,
-9.49010970480476444210E-2, 1.71620901522208775349E-1,
-3.04682672343198398683E-1, 6.76795274409476084995E-1};
return coefficients;
}
// Returns a pointer to the first element of the 25-entry Chebyshev coefficient
// table used for the large-argument branch (x > 8) of calc_i0 / calc_i0e.
// The table is a function-local `static const` array of T; the element count
// is not carried with the pointer, so callers pass it to chbevl explicitly.
template <typename T>
C10_HOST_DEVICE inline const T* chebyshev_coefficients_B() {
/* Chebyshev coefficients for exp(-x) sqrt(x) I0(x)
* in the inverted interval [8,infinity].
*
* lim(x->inf){ exp(-x) sqrt(x) I0(x) } = 1/sqrt(2pi).
*/
static const T coefficients[] = {
-7.23318048787475395456E-18, -4.83050448594418207126E-18,
4.46562142029675999901E-17, 3.46122286769746109310E-17,
-2.82762398051658348494E-16, -3.42548561967721913462E-16,
1.77256013305652638360E-15, 3.81168066935262242075E-15,
-9.55484669882830764870E-15, -4.15056934728722208663E-14,
1.54008621752140982691E-14, 3.85277838274214270114E-13,
7.18012445138366623367E-13, -1.79417853150680611778E-12,
-1.32158118404477131188E-11, -3.14991652796324136454E-11,
1.18891471078464383424E-11, 4.94060238822496958910E-10,
3.39623202570838634515E-9, 2.26666899049817806459E-8,
2.04891858946906374183E-7, 2.89137052083475648297E-6,
6.88975834691682398426E-5, 3.36911647825569408990E-3,
8.04490411014108831608E-1};
return coefficients;
}
template <typename scalar_t>
static inline C10_HOST_DEVICE scalar_t calc_i0(scalar_t _x) {
using accscalar_t = at::acc_type<scalar_t, true>;
@ -224,84 +276,33 @@ static inline C10_HOST_DEVICE scalar_t calc_i0(scalar_t _x) {
// Needed for accurate results if input is bfloat16 or float16
accscalar_t x = ::abs(static_cast<accscalar_t>(_x));
/* Chebyshev coefficients for exp(-x) I0(x)
* in the interval [0,8].
*
* lim(x->0){ exp(-x) I0(x) } = 1.
*/
const accscalar_t A[] = {
-4.41534164647933937950E-18,
3.33079451882223809783E-17,
-2.43127984654795469359E-16,
1.71539128555513303061E-15,
-1.16853328779934516808E-14,
7.67618549860493561688E-14,
-4.85644678311192946090E-13,
2.95505266312963983461E-12,
-1.72682629144155570723E-11,
9.67580903537323691224E-11,
-5.18979560163526290666E-10,
2.65982372468238665035E-9,
-1.30002500998624804212E-8,
6.04699502254191894932E-8,
-2.67079385394061173391E-7,
1.11738753912010371815E-6,
-4.41673835845875056359E-6,
1.64484480707288970893E-5,
-5.75419501008210370398E-5,
1.88502885095841655729E-4,
-5.76375574538582365885E-4,
1.63947561694133579842E-3,
-4.32430999505057594430E-3,
1.05464603945949983183E-2,
-2.37374148058994688156E-2,
4.93052842396707084878E-2,
-9.49010970480476444210E-2,
1.71620901522208775349E-1,
-3.04682672343198398683E-1,
6.76795274409476084995E-1
};
/* Chebyshev coefficients for exp(-x) sqrt(x) I0(x)
* in the inverted interval [8,infinity].
*
* lim(x->inf){ exp(-x) sqrt(x) I0(x) } = 1/sqrt(2pi).
*/
const accscalar_t B[] = {
-7.23318048787475395456E-18,
-4.83050448594418207126E-18,
4.46562142029675999901E-17,
3.46122286769746109310E-17,
-2.82762398051658348494E-16,
-3.42548561967721913462E-16,
1.77256013305652638360E-15,
3.81168066935262242075E-15,
-9.55484669882830764870E-15,
-4.15056934728722208663E-14,
1.54008621752140982691E-14,
3.85277838274214270114E-13,
7.18012445138366623367E-13,
-1.79417853150680611778E-12,
-1.32158118404477131188E-11,
-3.14991652796324136454E-11,
1.18891471078464383424E-11,
4.94060238822496958910E-10,
3.39623202570838634515E-9,
2.26666899049817806459E-8,
2.04891858946906374183E-7,
2.89137052083475648297E-6,
6.88975834691682398426E-5,
3.36911647825569408990E-3,
8.04490411014108831608E-1
};
if (x <= 8.0) {
const auto A = chebyshev_coefficients_A<accscalar_t>();
accscalar_t y = static_cast<accscalar_t>((x / 2.0) - 2.0);
return static_cast<scalar_t>(::exp(x) * chbevl(y, A, 30));
}
const auto B = chebyshev_coefficients_B<accscalar_t>();
return static_cast<scalar_t>(::exp(x) * chbevl(static_cast<accscalar_t>(32.0 / x - 2.0), B, 25) / ::sqrt(x));
}
/*
 * Exponentially scaled zeroth order modified Bessel function of the first
 * kind, evaluated at |_x|.
 *
 * The input is first converted to at::acc_type<scalar_t, true> so that
 * bfloat16 / float16 inputs are evaluated accurately; the result is cast back
 * to scalar_t on return.
 *   |_x| <= 8 : 30-term Chebyshev series, argument |_x|/2 - 2
 *   otherwise : 25-term Chebyshev series, argument 32/|_x| - 2, over sqrt(|_x|)
 */
template <typename scalar_t>
static inline C10_HOST_DEVICE scalar_t calc_i0e(scalar_t _x) {
  using accscalar_t = at::acc_type<scalar_t, true>;

  const accscalar_t magnitude = ::abs(static_cast<accscalar_t>(_x));

  if (magnitude <= 8.0) {
    const accscalar_t* small_coeffs = chebyshev_coefficients_A<accscalar_t>();
    const accscalar_t shifted = static_cast<accscalar_t>((magnitude / 2.0) - 2.0);
    return static_cast<scalar_t>(chbevl(shifted, small_coeffs, 30));
  } else {
    const accscalar_t* large_coeffs = chebyshev_coefficients_B<accscalar_t>();
    const accscalar_t inverted = static_cast<accscalar_t>(32.0 / magnitude - 2.0);
    return static_cast<scalar_t>(chbevl(inverted, large_coeffs, 25) / ::sqrt(magnitude));
  }
}
}
} // namespace native
} // namespace at

View File

@ -41,14 +41,6 @@ void exp_kernel_cuda(TensorIterator& iter) {
});
}
void exp2_kernel_cuda(TensorIterator& iter) {
AT_DISPATCH_FLOATING_TYPES_AND_HALF(iter.common_dtype(), "exp2_cuda", [&]() {
gpu_kernel(iter, [] GPU_LAMBDA(scalar_t a) -> scalar_t {
return ::exp2(a);
});
});
}
void expm1_kernel_cuda(TensorIterator& iter) {
AT_DISPATCH_FLOATING_TYPES_AND_HALF(iter.common_dtype(), "expm1_cuda", [&]() {
gpu_kernel(iter, []GPU_LAMBDA(scalar_t a) -> scalar_t {
@ -57,14 +49,6 @@ void expm1_kernel_cuda(TensorIterator& iter) {
});
}
void i0_kernel_cuda(TensorIterator& iter) {
AT_DISPATCH_FLOATING_TYPES_AND2(ScalarType::Half, ScalarType::BFloat16, iter.dtype(), "i0_cuda", [&]() {
gpu_kernel(iter, []GPU_LAMBDA(scalar_t a) -> scalar_t {
return calc_i0(a);
});
});
}
// We manually overload rsqrt because std::rsqrt does not work with complex types.
template<typename scalar_t>
__host__ __device__ static inline scalar_t rsqrt_wrapper(scalar_t v) {
@ -95,80 +79,6 @@ void sqrt_kernel_cuda(TensorIterator& iter) {
});
}
void sigmoid_kernel_cuda(TensorIterator& iter) {
AT_DISPATCH_FLOATING_TYPES_AND2(at::ScalarType::Half, at::ScalarType::BFloat16, iter.common_dtype(), "sigmoid_cuda", [&]() {
gpu_kernel(iter, []GPU_LAMBDA(scalar_t a) -> scalar_t {
scalar_t one = scalar_t(1);
return one / (one + std::exp(- a));
});
});
}
void sinc_kernel_cuda(TensorIteratorBase& iter) {
AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES_AND1(ScalarType::Half, iter.common_dtype(), "sinc_cuda", [&]() {
gpu_kernel(iter, []GPU_LAMBDA(scalar_t a) -> scalar_t {
if (a == scalar_t(0)) {
return scalar_t(1);
} else {
// NVCC says constexpr var is not accessible from device
scalar_t product = c10::detail::pi<scalar_t>() * a;
return std::sin(product) / product;
}
});
});
}
void logit_kernel_cuda(TensorIterator& iter, const Scalar& eps_scalar) {
AT_DISPATCH_FLOATING_TYPES_AND2(
at::ScalarType::Half,
at::ScalarType::BFloat16,
iter.common_dtype(),
"logit_cuda",
[&]() {
using T_ACC = acc_type<scalar_t, true>;
const T_ACC eps = eps_scalar.to<T_ACC>();
if (eps < T_ACC(0)) {
gpu_kernel(iter, [] GPU_LAMBDA(scalar_t x) -> scalar_t {
const T_ACC x_acc = static_cast<T_ACC>(x);
return c10::cuda::compat::log(x_acc / (T_ACC(1) - x_acc));
});
} else {
const T_ACC lo = eps;
const T_ACC hi = T_ACC(1) - eps;
gpu_kernel(
iter, [lo, hi] GPU_LAMBDA(scalar_t x) -> scalar_t {
const T_ACC x_acc = static_cast<T_ACC>(x);
T_ACC z = x_acc < lo ? lo : (x_acc > hi ? hi : x_acc);
return c10::cuda::compat::log(z / (T_ACC(1) - z));
});
}
});
}
void erf_kernel_cuda(TensorIterator& iter) {
AT_DISPATCH_FLOATING_TYPES_AND2(at::ScalarType::Half, at::ScalarType::BFloat16, iter.common_dtype(), "erf_cuda", [&]() {
gpu_kernel(iter, []GPU_LAMBDA(scalar_t a) -> scalar_t {
return ::erf(a);
});
});
}
void erfc_kernel_cuda(TensorIterator& iter) {
AT_DISPATCH_FLOATING_TYPES_AND_HALF(iter.common_dtype(), "erfc_cuda", [&]() {
gpu_kernel(iter, []GPU_LAMBDA(scalar_t a) -> scalar_t {
return ::erfc(a);
});
});
}
void erfinv_kernel_cuda(TensorIterator& iter) {
AT_DISPATCH_FLOATING_TYPES_AND_HALF(iter.common_dtype(), "erfinv_cuda", [&]() {
gpu_kernel(iter, []GPU_LAMBDA(scalar_t a) -> scalar_t {
return ::erfinv(a);
});
});
}
void clamp_kernel_cuda(TensorIterator& iter, const Scalar& min_value, const Scalar& max_value) {
AT_DISPATCH_ALL_TYPES_AND2(kHalf, kBFloat16, iter.dtype(), "clamp_cuda", [&]() {
auto lower = min_value.to<scalar_t>();
@ -238,40 +148,6 @@ void nan_to_num_kernel_cuda(
});
}
void kaiser_window_kernel_cuda(TensorIterator& iter, int64_t window_length, double beta_){
AT_DISPATCH_FLOATING_TYPES_AND2(ScalarType::Half, ScalarType::BFloat16, iter.dtype(), "kaiser_window_cuda", [&](){
using T_ACC = acc_type<scalar_t, true>;
const T_ACC inv_alpha = static_cast<T_ACC>(2.0 / (window_length - 1));
const T_ACC beta = static_cast<T_ACC>(beta_);
const T_ACC inv_i0_beta = 1.0 / calc_i0(beta);
gpu_kernel(iter, [=]GPU_LAMBDA(scalar_t a) -> scalar_t {
T_ACC x = static_cast<T_ACC>(a) * inv_alpha - 1;
T_ACC y = std::max<T_ACC>(0, 1 - x * x);
return calc_i0(beta * ::sqrt(y)) * inv_i0_beta;
});
});
}
void entr_kernel_cuda(TensorIterator& iter) {
AT_DISPATCH_FLOATING_TYPES_AND2(
ScalarType::Half,
ScalarType::BFloat16,
iter.common_dtype(),
"entr_cuda",
[&]() {
gpu_kernel(iter, [=] GPU_LAMBDA(scalar_t x) -> scalar_t {
if (at::_isnan(x)) {
return x;
} else if (x > 0) {
return -x * std::log(x);
} else if (x == 0) {
return 0;
}
return static_cast<scalar_t>(-INFINITY);
});
});
}
void frexp_kernel_cuda(TensorIterator& iter) {
#ifdef __HIP_PLATFORM_HCC__
// Reference: https://rocmdocs.amd.com/en/latest/ROCm_API_References/HIP-MATH.html
@ -295,23 +171,13 @@ void frexp_kernel_cuda(TensorIterator& iter) {
REGISTER_DISPATCH(bitwise_not_stub, &bitwise_not_kernel_cuda);
REGISTER_DISPATCH(exp_stub, &exp_kernel_cuda);
REGISTER_DISPATCH(exp2_stub, &exp2_kernel_cuda);
REGISTER_DISPATCH(expm1_stub, &expm1_kernel_cuda);
REGISTER_DISPATCH(i0_stub, &i0_kernel_cuda);
REGISTER_DISPATCH(rsqrt_stub, &rsqrt_kernel_cuda);
REGISTER_DISPATCH(sqrt_stub, &sqrt_kernel_cuda);
REGISTER_DISPATCH(sigmoid_stub, &sigmoid_kernel_cuda);
REGISTER_DISPATCH(sinc_stub, &sinc_kernel_cuda);
REGISTER_DISPATCH(logit_stub, &logit_kernel_cuda);
REGISTER_DISPATCH(erf_stub, &erf_kernel_cuda);
REGISTER_DISPATCH(erfc_stub, &erfc_kernel_cuda);
REGISTER_DISPATCH(erfinv_stub, &erfinv_kernel_cuda);
REGISTER_DISPATCH(clamp_stub, &clamp_kernel_cuda);
REGISTER_DISPATCH(clamp_min_stub, &clamp_min_kernel_cuda);
REGISTER_DISPATCH(clamp_max_stub, &clamp_max_kernel_cuda);
REGISTER_DISPATCH(nan_to_num_stub, &nan_to_num_kernel_cuda);
REGISTER_DISPATCH(kaiser_window_stub, &kaiser_window_kernel_cuda);
REGISTER_DISPATCH(entr_stub, &entr_kernel_cuda);
REGISTER_DISPATCH(frexp_stub, &frexp_kernel_cuda);
} // namespace native

View File

@ -0,0 +1,167 @@
#include <ATen/native/UnaryOps.h>
#include <limits>
#include <ATen/AccumulateType.h>
#include <ATen/Context.h>
#include <ATen/Dispatch.h>
#include <ATen/native/DispatchStub.h>
#include <ATen/native/Math.h>
#include <ATen/native/TensorFactories.h>
#include <ATen/native/TensorIterator.h>
#include <ATen/native/cuda/Loops.cuh>
#include <ATen/native/cuda/Math.cuh>
#include <ATen/NumericUtils.h>
#include <c10/cuda/CUDAMathCompat.h>
#include <ATen/NumericUtils.h>
#include <c10/util/complex.h>
namespace at {
namespace native {
// CUDA kernel for exp2: dispatches on iter.common_dtype() through
// AT_DISPATCH_FLOATING_TYPES_AND_HALF and hands gpu_kernel a lambda that
// returns ::exp2(a) for its scalar argument.
void exp2_kernel_cuda(TensorIterator& iter) {
AT_DISPATCH_FLOATING_TYPES_AND_HALF(iter.common_dtype(), "exp2_cuda", [&]() {
gpu_kernel(iter, [] GPU_LAMBDA(scalar_t a) -> scalar_t {
return ::exp2(a);
});
});
}
// CUDA kernel for i0: dispatches over floating types plus Half and BFloat16
// (per the macro arguments) and hands gpu_kernel a lambda that returns
// calc_i0(a) for its scalar argument.
// NOTE(review): this dispatches on iter.dtype(), whereas the neighbouring
// kernels in this file (e.g. i0e_kernel_cuda) use iter.common_dtype() —
// confirm the difference is intentional.
void i0_kernel_cuda(TensorIterator& iter) {
AT_DISPATCH_FLOATING_TYPES_AND2(ScalarType::Half, ScalarType::BFloat16, iter.dtype(), "i0_cuda", [&]() {
gpu_kernel(iter, []GPU_LAMBDA(scalar_t a) -> scalar_t {
return calc_i0(a);
});
});
}
// CUDA kernel for i0e: dispatches on iter.common_dtype() over floating types
// plus Half and BFloat16 (per the macro arguments) and hands gpu_kernel a
// lambda that returns calc_i0e(a) for its scalar argument.
// Takes TensorIteratorBase&, unlike i0_kernel_cuda which takes TensorIterator&.
void i0e_kernel_cuda(TensorIteratorBase& iter) {
AT_DISPATCH_FLOATING_TYPES_AND2(ScalarType::Half, ScalarType::BFloat16, iter.common_dtype(), "i0e_cuda", [&]() {
gpu_kernel(iter, []GPU_LAMBDA(scalar_t a) -> scalar_t {
return calc_i0e(a);
});
});
}
// CUDA kernel for sigmoid: the lambda returns 1 / (1 + exp(-a)).
// NOTE(review): the arithmetic is carried out in scalar_t (no upcast to an
// accumulate type), unlike logit_kernel_cuda in this file — confirm this is
// acceptable for Half / BFloat16 inputs.
void sigmoid_kernel_cuda(TensorIterator& iter) {
AT_DISPATCH_FLOATING_TYPES_AND2(at::ScalarType::Half, at::ScalarType::BFloat16, iter.common_dtype(), "sigmoid_cuda", [&]() {
gpu_kernel(iter, []GPU_LAMBDA(scalar_t a) -> scalar_t {
scalar_t one = scalar_t(1);
return one / (one + std::exp(- a));
});
});
}
// CUDA kernel for sinc: the lambda returns 1 when the input equals zero and
// sin(pi * a) / (pi * a) otherwise. Dispatches on iter.common_dtype() through
// AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES_AND1 with ScalarType::Half.
void sinc_kernel_cuda(TensorIteratorBase& iter) {
AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES_AND1(ScalarType::Half, iter.common_dtype(), "sinc_cuda", [&]() {
gpu_kernel(iter, []GPU_LAMBDA(scalar_t a) -> scalar_t {
// Special-case zero so the division below never sees a zero denominator.
if (a == scalar_t(0)) {
return scalar_t(1);
} else {
// NVCC says constexpr var is not accessible from device
scalar_t product = c10::detail::pi<scalar_t>() * a;
return std::sin(product) / product;
}
});
});
}
// Elementwise logit on CUDA: log(x / (1 - x)), evaluated in the accumulate
// type T_ACC and converted back to scalar_t on return.
//
// eps_scalar selects the clamping behaviour:
//   * eps <  0 : the input is used as-is;
//   * otherwise: the input is first clamped to [eps, 1 - eps].
// Dispatches on the iterator's common dtype over the floating-point types plus
// Half and BFloat16.
void logit_kernel_cuda(TensorIterator& iter, const Scalar& eps_scalar) {
  const auto dtype = iter.common_dtype();
  AT_DISPATCH_FLOATING_TYPES_AND2(
      at::ScalarType::Half,
      at::ScalarType::BFloat16,
      dtype,
      "logit_cuda",
      [&]() {
        using T_ACC = acc_type<scalar_t, true>;
        const T_ACC eps = eps_scalar.to<T_ACC>();
        if (eps < T_ACC(0)) {
          // Negative eps: no clamping requested.
          gpu_kernel(iter, [] GPU_LAMBDA(scalar_t x) -> scalar_t {
            const T_ACC value = static_cast<T_ACC>(x);
            return c10::cuda::compat::log(value / (T_ACC(1) - value));
          });
        } else {
          const T_ACC lo = eps;
          const T_ACC hi = T_ACC(1) - eps;
          gpu_kernel(iter, [lo, hi] GPU_LAMBDA(scalar_t x) -> scalar_t {
            const T_ACC value = static_cast<T_ACC>(x);
            // Clamp into [lo, hi]; the lower bound is checked first.
            T_ACC clamped = value;
            if (value < lo) {
              clamped = lo;
            } else if (value > hi) {
              clamped = hi;
            }
            return c10::cuda::compat::log(clamped / (T_ACC(1) - clamped));
          });
        }
      });
}
// Elementwise error function on CUDA: applies ::erf to every element.
// Dispatches on the iterator's common dtype over the floating-point types plus
// Half and BFloat16.
void erf_kernel_cuda(TensorIterator& iter) {
  const auto dtype = iter.common_dtype();
  AT_DISPATCH_FLOATING_TYPES_AND2(at::ScalarType::Half, at::ScalarType::BFloat16, dtype, "erf_cuda", [&]() {
    gpu_kernel(iter, [] GPU_LAMBDA(scalar_t value) -> scalar_t {
      return ::erf(value);
    });
  });
}
// Elementwise complementary error function on CUDA: applies ::erfc to every
// element.
// Dispatches on the iterator's common dtype over the floating-point types plus
// Half.
void erfc_kernel_cuda(TensorIterator& iter) {
  const auto dtype = iter.common_dtype();
  AT_DISPATCH_FLOATING_TYPES_AND_HALF(dtype, "erfc_cuda", [&]() {
    gpu_kernel(iter, [] GPU_LAMBDA(scalar_t value) -> scalar_t {
      return ::erfc(value);
    });
  });
}
// Elementwise inverse error function on CUDA: applies ::erfinv to every
// element.
// Dispatches on the iterator's common dtype over the floating-point types plus
// Half.
void erfinv_kernel_cuda(TensorIterator& iter) {
  const auto dtype = iter.common_dtype();
  AT_DISPATCH_FLOATING_TYPES_AND_HALF(dtype, "erfinv_cuda", [&]() {
    gpu_kernel(iter, [] GPU_LAMBDA(scalar_t value) -> scalar_t {
      return ::erfinv(value);
    });
  });
}
// Kaiser window kernel on CUDA. For each input element `a` it evaluates
//   calc_i0(beta * sqrt(max(0, 1 - (a * 2 / (window_length - 1) - 1)^2))) / calc_i0(beta)
// in the accumulate type T_ACC and converts the result back to scalar_t.
//
// Parameters:
//   iter          - tensor iterator supplying inputs and outputs
//   window_length - window size; only used through 2 / (window_length - 1)
//   beta_         - shape parameter, cast to T_ACC before use
//
// NOTE(review): window_length == 1 makes the divisor zero; presumably callers
// handle that case before dispatching here - confirm.
void kaiser_window_kernel_cuda(TensorIterator& iter, int64_t window_length, double beta_){
  AT_DISPATCH_FLOATING_TYPES_AND2(ScalarType::Half, ScalarType::BFloat16, iter.dtype(), "kaiser_window_cuda", [&](){
    using T_ACC = acc_type<scalar_t, true>;
    // Loop-invariant factors are computed once on the host and captured by value.
    const T_ACC beta = static_cast<T_ACC>(beta_);
    const T_ACC inv_alpha = static_cast<T_ACC>(2.0 / (window_length - 1));
    const T_ACC inv_i0_beta = 1.0 / calc_i0(beta);
    gpu_kernel(iter, [inv_alpha, beta, inv_i0_beta] GPU_LAMBDA(scalar_t a) -> scalar_t {
      const T_ACC x = static_cast<T_ACC>(a) * inv_alpha - 1;
      // Clamp at zero so the square root below never sees a negative operand.
      const T_ACC y = std::max<T_ACC>(0, 1 - x * x);
      const auto bessel_arg = beta * ::sqrt(y);
      return calc_i0(bessel_arg) * inv_i0_beta;
    });
  });
}
// Elementwise entropy term on CUDA. For each element x:
//   x is NaN -> x (NaN is passed through)
//   x >  0   -> -x * log(x)
//   x == 0   -> 0
//   x <  0   -> -infinity
// Dispatches on the iterator's common dtype over the floating-point types plus
// Half and BFloat16.
void entr_kernel_cuda(TensorIterator& iter) {
  const auto dtype = iter.common_dtype();
  AT_DISPATCH_FLOATING_TYPES_AND2(
      ScalarType::Half,
      ScalarType::BFloat16,
      dtype,
      "entr_cuda",
      [&]() {
        gpu_kernel(iter, [=] GPU_LAMBDA(scalar_t x) -> scalar_t {
          // NaN must be tested first: every ordered comparison below is false for it.
          if (at::_isnan(x)) {
            return x;
          }
          if (x > 0) {
            return -x * std::log(x);
          }
          if (x == 0) {
            return 0;
          }
          // Only negative inputs reach this point.
          return static_cast<scalar_t>(-INFINITY);
        });
      });
}
// Bind each CUDA kernel defined in this file to its dispatch stub, one
// REGISTER_DISPATCH entry per kernel.
REGISTER_DISPATCH(exp2_stub, &exp2_kernel_cuda);
REGISTER_DISPATCH(i0_stub, &i0_kernel_cuda);
REGISTER_DISPATCH(i0e_stub, &i0e_kernel_cuda);
REGISTER_DISPATCH(sigmoid_stub, &sigmoid_kernel_cuda);
REGISTER_DISPATCH(sinc_stub, &sinc_kernel_cuda);
REGISTER_DISPATCH(logit_stub, &logit_kernel_cuda);
REGISTER_DISPATCH(erf_stub, &erf_kernel_cuda);
REGISTER_DISPATCH(erfc_stub, &erfc_kernel_cuda);
REGISTER_DISPATCH(erfinv_stub, &erfinv_kernel_cuda);
REGISTER_DISPATCH(kaiser_window_stub, &kaiser_window_kernel_cuda);
REGISTER_DISPATCH(entr_stub, &entr_kernel_cuda);
} // namespace native
} // namespace at

View File

@ -8418,6 +8418,18 @@
- func: special_erfinv.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!)
python_module: special
- func: special_i0e(Tensor self) -> Tensor
python_module: special
variants: function
structured_delegate: special_i0e.out
- func: special_i0e.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!)
python_module: special
structured: True
structured_inherits: TensorIteratorBase
dispatch:
CPU, CUDA: special_i0e_out
- func: special_logit(Tensor self, float? eps=None) -> Tensor
python_module: special
variants: function

View File

@ -26,4 +26,5 @@ Functions
.. autofunction:: expm1
.. autofunction:: exp2
.. autofunction:: gammaln
.. autofunction:: i0e
.. autofunction:: logit

View File

@ -1195,6 +1195,38 @@ class TestUnaryUfuncs(TestCase):
t = torch.tensor([inf, -inf, nan], device=device, dtype=dtype)
self.assertTrue(torch.i0(t).isnan().all())
@dtypesIfCUDA(*torch.testing.get_all_fp_dtypes())
@dtypes(torch.bfloat16, torch.float32, torch.float64)
@unittest.skipIf(not TEST_SCIPY, "SciPy not found")
def test_special_i0e_vs_scipy(self, device, dtype):
    # Checks torch.special.i0e against scipy.special.i0e on an empty tensor,
    # a wide linspace, and the finfo extremes of the dtype under test.
    def check_equal(t):
        # Test by comparing to scipy
        actual = torch.special.i0e(t)
        # bfloat16 inputs are widened to float32 before the NumPy round trip
        # (presumably because NumPy has no bfloat16 - confirm).
        reference_input = t.to(torch.float32) if dtype is torch.bfloat16 else t
        expected = scipy.special.i0e(reference_input.cpu().numpy())

        # Casting down for dtype float16 is required since scipy upcasts to float32
        if dtype in (torch.bfloat16, torch.float16):
            expected = torch.from_numpy(expected).to(dtype)
        self.assertEqual(actual, expected)

    # Empty input.
    check_equal(torch.tensor([], device=device, dtype=dtype))

    # Dense sweep; the half-precision sweep uses narrower end points.
    if dtype == torch.half:
        low, high = -65000, 65000
    else:
        low, high = -1e7, 1e7
    check_equal(torch.linspace(low, high, int(1e4), device=device, dtype=dtype))

    # NaN, inf, -inf are tested in reference_numerics tests.
    finfo = torch.finfo(dtype)
    extremes = [finfo.min, finfo.max, finfo.eps, finfo.tiny]
    check_equal(torch.tensor(extremes, dtype=dtype, device=device))
# TODO: allow large opinfo values to be opted-into via metadata
@dtypes(torch.long)
def test_abs_big_number(self, device, dtype):

View File

@ -149,4 +149,20 @@ inline Tensor& expm1_out(Tensor& result, const Tensor& self) {
return torch::special_expm1_out(result, self);
}
/// Computes the exponentially scaled zeroth order modified Bessel function of the first kind
/// for each element of `self` and returns the result as a new tensor.
/// See https://pytorch.org/docs/master/special.html#torch.special.i0e.
///
/// Example:
/// ```
/// auto t = torch::randn({128}, torch::kDouble);
/// torch::special::i0e(t);
/// ```
inline Tensor i0e(const Tensor& self) {
return torch::special_i0e(self);
}
/// Out-variant of `i0e`: computes the exponentially scaled zeroth order
/// modified Bessel function of the first kind of `self`, writes it into
/// `result`, and returns a reference to `result`.
/// See https://pytorch.org/docs/master/special.html#torch.special.i0e.
///
/// Example:
/// ```
/// auto t = torch::randn({128}, torch::kDouble);
/// auto out = torch::empty_like(t);
/// torch::special::i0e_out(out, t);
/// ```
inline Tensor& i0e_out(Tensor& result, const Tensor& self) {
  return torch::special_i0e_out(result, self);
}

/// Legacy single-argument overload, kept only so existing callers keep
/// compiling. It takes no output tensor and behaves exactly like `i0e`,
/// returning a new tensor; prefer the two-argument overload.
inline Tensor i0e_out(const Tensor& self) {
  return torch::special_i0e(self);
}
}} // torch::special

View File

@ -848,6 +848,7 @@ def get_testing_overrides() -> Dict[Callable, Callable]:
torch.special.expm1: lambda input: -1,
torch.special.expit: lambda input: -1,
torch.special.gammaln: lambda input: -1,
torch.special.i0e: lambda input: -1,
torch.special.logit: lambda input: -1,
torch.t: lambda input: -1,
torch.take: lambda input, index: -1,

View File

@ -229,3 +229,22 @@ Example::
>>> torch.special.expm1(torch.tensor([0, math.log(2.)]))
tensor([ 0., 1.])
""".format(**common_args))
# NOTE: the docstring is deliberately assembled from two raw strings.
# `.format(**common_args)` binds only to the second literal, so the LaTeX
# braces in the first part are never interpreted by str.format.
i0e = _add_docstr(_special.special_i0e,
                  r"""
i0e(input, *, out=None) -> Tensor

Computes the exponentially scaled zeroth order modified Bessel function of the first kind (as defined below)
for each element of :attr:`input`.

.. math::
    \text{out}_{i} = \exp(-|\text{input}_{i}|) * i0(\text{input}_{i}) = \exp(-|\text{input}_{i}|) * \sum_{k=0}^{\infty} \frac{(\text{input}_{i}^2/4)^k}{(k!)^2}

""" + r"""
Args:
    {input}

Keyword args:
    {out}

Example::

    >>> torch.special.i0e(torch.arange(5, dtype=torch.float32))
    tensor([1.0000, 0.4658, 0.3085, 0.2430, 0.2070])
""".format(**common_args))

View File

@ -3302,6 +3302,16 @@ op_db: List[OpInfo] = [
dtypesIfCPU=floating_types_and(torch.bfloat16),
dtypesIfCUDA=floating_types_and(torch.half, torch.bfloat16),
supports_autograd=False),
UnaryUfuncInfo('special.i0e',
aten_name='special_i0e',
ref=scipy.special.i0e,
decorators=(precisionOverride({torch.bfloat16: 3e-1,
torch.float16: 3e-1}),),
dtypes=all_types_and(torch.bool, torch.bfloat16),
dtypesIfCPU=all_types_and(torch.bool, torch.bfloat16),
dtypesIfCUDA=all_types_and(torch.bool, torch.half, torch.bfloat16),
supports_autograd=False,
safe_casts_outputs=True),
OpInfo('floor_divide',
dtypes=all_types_and(torch.half, torch.bfloat16),
sample_inputs_func=sample_inputs_floor_divide,