Bessel functions (#78451)

This PR adds the following operators to `torch.special`:

```Python
bessel_j0(input, *, out=None) -> Tensor
```

Bessel function of the first kind of order $0$, $J_{0}(\text{input})$.

```Python
bessel_j1(input, *, out=None) -> Tensor
```

Bessel function of the first kind of order $1$, $J_{1}(\text{input})$.

```Python
bessel_y0(input, *, out=None) -> Tensor
```

Bessel function of the second kind of order $0$, $Y_{0}(\text{input})$.

```Python
bessel_y1(input, *, out=None) -> Tensor
```

Bessel function of the second kind of order $1$, $Y_{1}(\text{input})$.
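
A minimal usage sketch for the four functions above (illustrative only; the sample points and the Wronskian sanity check are not part of this PR, they simply exercise the new `torch.special` bindings):

```Python
import math

import torch

x = torch.linspace(0.5, 10.0, 20, dtype=torch.float64)

j0 = torch.special.bessel_j0(x)
j1 = torch.special.bessel_j1(x)
y0 = torch.special.bessel_y0(x)
y1 = torch.special.bessel_y1(x)

# Sanity check via the Wronskian identity J1(x) Y0(x) - J0(x) Y1(x) = 2 / (pi x).
torch.testing.assert_close(j1 * y0 - j0 * y1, 2.0 / (math.pi * x))
```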

```Python
modified_bessel_i0(input, *, out=None) -> Tensor
```

Modified Bessel function of the first kind of order $0$, $I_{0}(\text{input})$.

```Python
modified_bessel_i1(input, *, out=None) -> Tensor
```

Modified Bessel function of the first kind of order $1$, $I_{1}(\text{input})$.

```Python
modified_bessel_k0(input, *, out=None) -> Tensor
```

Modified Bessel function of the second kind of order $0$, $K_{0}(\text{input})$.

```Python
modified_bessel_k1(input, *, out=None) -> Tensor
```

Modified Bessel function of the second kind of order $1$, $K_{1}(\text{input})$.
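
A similar illustrative sketch for the modified Bessel functions; the check relies on the standard identity $I_{0}(x) K_{1}(x) + I_{1}(x) K_{0}(x) = 1/x$, not on anything introduced by this PR:

```Python
import torch

x = torch.linspace(0.5, 8.0, 16, dtype=torch.float64)

i0 = torch.special.modified_bessel_i0(x)
i1 = torch.special.modified_bessel_i1(x)
k0 = torch.special.modified_bessel_k0(x)
k1 = torch.special.modified_bessel_k1(x)

# Sanity check: I0(x) * K1(x) + I1(x) * K0(x) should equal 1 / x.
torch.testing.assert_close(i0 * k1 + i1 * k0, 1.0 / x)
```
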
Pull Request resolved: https://github.com/pytorch/pytorch/pull/78451
Approved by: https://github.com/mruberry
Author: Allen Goodman
Date: 2022-06-02 14:06:20 +00:00
Committed by: PyTorch MergeBot
Parent: 78824a7d54
Commit: 4a5381ab40
20 changed files with 2587 additions and 2 deletions

View File

@ -2177,6 +2177,455 @@ static inline C10_HOST_DEVICE T calc_log_ndtr(T x) {
}
}
template<typename T>
static inline C10_HOST_DEVICE T bessel_j0_forward(T x) {
static const T PP[] = {
+7.96936729297347051624e-04,
+8.28352392107440799803e-02,
+1.23953371646414299388e+00,
+5.44725003058768775090e+00,
+8.74716500199817011941e+00,
+5.30324038235394892183e+00,
+9.99999999999999997821e-01,
};
static const T PQ[] = {
+9.24408810558863637013e-04,
+8.56288474354474431428e-02,
+1.25352743901058953537e+00,
+5.47097740330417105182e+00,
+8.76190883237069594232e+00,
+5.30605288235394617618e+00,
+1.00000000000000000218e+00,
};
static const T QP[] = {
-1.13663838898469149931e-02,
-1.28252718670509318512e+00,
-1.95539544257735972385e+01,
-9.32060152123768231369e+01,
-1.77681167980488050595e+02,
-1.47077505154951170175e+02,
-5.14105326766599330220e+01,
-6.05014350600728481186e+00,
};
static const T QQ[] = {
+6.43178256118178023184e+01,
+8.56430025976980587198e+02,
+3.88240183605401609683e+03,
+7.24046774195652478189e+03,
+5.93072701187316984827e+03,
+2.06209331660327847417e+03,
+2.42005740240291393179e+02,
};
static const T RP[] = {
-4.79443220978201773821e+09,
+1.95617491946556577543e+12,
-2.49248344360967716204e+14,
+9.70862251047306323952e+15,
};
static const T RQ[] = {
+4.99563147152651017219e+02,
+1.73785401676374683123e+05,
+4.84409658339962045305e+07,
+1.11855537045356834862e+10,
+2.11277520115489217587e+12,
+3.10518229857422583814e+14,
+3.18121955943204943306e+16,
+1.71086294081043136091e+18,
};
if (x < T(0)) {
x = -x;
}
if (x <= T(5.0)) {
if (x < T(0.00001)) {
return T(1.0) - x * x / T(4.0);
}
T rp = 0.0;
for (uint8_t index = 0; index <= 3; index++) {
rp = rp * (x * x) + RP[index];
}
T rq = 0.0;
for (uint8_t index = 0; index <= 7; index++) {
rq = rq * (x * x) + RQ[index];
}
return (x * x - T(5.78318596294678452118e+00)) * (x * x - T(3.04712623436620863991e+01)) * rp / rq;
}
T pp = 0.0;
for (uint8_t index = 0; index <= 6; index++) {
pp = pp * (T(25.0) / (x * x)) + PP[index];
}
T pq = 0.0;
for (uint8_t index = 0; index <= 6; index++) {
pq = pq * (T(25.0) / (x * x)) + PQ[index];
}
T qp = 0.0;
for (uint8_t index = 0; index <= 7; index++) {
qp = qp * (T(25.0) / (x * x)) + QP[index];
}
T qq = 0.0;
for (uint8_t index = 0; index <= 6; index++) {
qq = qq * (T(25.0) / (x * x)) + QQ[index];
}
return (pp / pq * std::cos(x - T(0.785398163397448309615660845819875721)) - T(5.0) / x * (qp / qq) * std::sin(x - T(0.785398163397448309615660845819875721))) * T(0.797884560802865355879892119868763737) / std::sqrt(x);
} // bessel_j0_forward(T x)
template<typename T>
static inline C10_HOST_DEVICE T bessel_j1_forward(T x) {
static const T PP[] = {
+7.62125616208173112003e-04,
+7.31397056940917570436e-02,
+1.12719608129684925192e+00,
+5.11207951146807644818e+00,
+8.42404590141772420927e+00,
+5.21451598682361504063e+00,
+1.00000000000000000254e+00,
};
static const T PQ[] = {
+5.71323128072548699714e-04,
+6.88455908754495404082e-02,
+1.10514232634061696926e+00,
+5.07386386128601488557e+00,
+8.39985554327604159757e+00,
+5.20982848682361821619e+00,
+9.99999999999999997461e-01,
};
static const T QP[] = {
+5.10862594750176621635e-02,
+4.98213872951233449420e+00,
+7.58238284132545283818e+01,
+3.66779609360150777800e+02,
+7.10856304998926107277e+02,
+5.97489612400613639965e+02,
+2.11688757100572135698e+02,
+2.52070205858023719784e+01,
};
static const T QQ[] = {
+7.42373277035675149943e+01,
+1.05644886038262816351e+03,
+4.98641058337653607651e+03,
+9.56231892404756170795e+03,
+7.99704160447350683650e+03,
+2.82619278517639096600e+03,
+3.36093607810698293419e+02,
};
static const T RP[] = {
-8.99971225705559398224e+08,
+4.52228297998194034323e+11,
-7.27494245221818276015e+13,
+3.68295732863852883286e+15,
};
static const T RQ[] = {
+6.20836478118054335476e+02,
+2.56987256757748830383e+05,
+8.35146791431949253037e+07,
+2.21511595479792499675e+10,
+4.74914122079991414898e+12,
+7.84369607876235854894e+14,
+8.95222336184627338078e+16,
+5.32278620332680085395e+18,
};
if (x < T(0.0)) {
return -bessel_j1_forward(-x);
}
if (x <= T(5.0)) {
T rp = 0.0;
for (uint8_t index = 0; index <= 3; index++) {
rp = rp * (x * x) + RP[index];
}
T rq = 0.0;
for (uint8_t index = 0; index <= 7; index++) {
rq = rq * (x * x) + RQ[index];
}
return rp / rq * x * (x * x - T(1.46819706421238932572e+01)) * (x * x - T(4.92184563216946036703e+01));
}
T pp = 0.0;
for (uint8_t index = 0; index <= 6; index++) {
pp = pp * (T(5.0) / x * (T(5.0) / x)) + PP[index];
}
T pq = 0.0;
for (uint8_t index = 0; index <= 6; index++) {
pq = pq * (T(5.0) / x * (T(5.0) / x)) + PQ[index];
}
T qp = 0.0;
for (uint8_t index = 0; index <= 7; index++) {
qp = qp * (T(5.0) / x * (T(5.0) / x)) + QP[index];
}
T qq = 0.0;
for (uint8_t index = 0; index <= 6; index++) {
qq = qq * (T(5.0) / x * (T(5.0) / x)) + QQ[index];
}
return (pp / pq * std::cos(x - T(2.356194490192344928846982537459627163)) - T(5.0) / x * (qp / qq) * std::sin(x - T(2.356194490192344928846982537459627163))) * T(0.797884560802865355879892119868763737) / std::sqrt(x);
} // bessel_j1_forward(T x)
template<typename T>
static inline C10_HOST_DEVICE T bessel_y0_forward(T x) {
static const T PP[] = {
+7.96936729297347051624e-04,
+8.28352392107440799803e-02,
+1.23953371646414299388e+00,
+5.44725003058768775090e+00,
+8.74716500199817011941e+00,
+5.30324038235394892183e+00,
+9.99999999999999997821e-01,
};
static const T PQ[] = {
+9.24408810558863637013e-04,
+8.56288474354474431428e-02,
+1.25352743901058953537e+00,
+5.47097740330417105182e+00,
+8.76190883237069594232e+00,
+5.30605288235394617618e+00,
+1.00000000000000000218e+00,
};
static const T QP[] = {
-1.13663838898469149931e-02,
-1.28252718670509318512e+00,
-1.95539544257735972385e+01,
-9.32060152123768231369e+01,
-1.77681167980488050595e+02,
-1.47077505154951170175e+02,
-5.14105326766599330220e+01,
-6.05014350600728481186e+00,
};
static const T QQ[] = {
+6.43178256118178023184e+01,
+8.56430025976980587198e+02,
+3.88240183605401609683e+03,
+7.24046774195652478189e+03,
+5.93072701187316984827e+03,
+2.06209331660327847417e+03,
+2.42005740240291393179e+02,
};
static const T YP[] = {
+1.55924367855235737965e+04,
-1.46639295903971606143e+07,
+5.43526477051876500413e+09,
-9.82136065717911466409e+11,
+8.75906394395366999549e+13,
-3.46628303384729719441e+15,
+4.42733268572569800351e+16,
-1.84950800436986690637e+16,
};
static const T YQ[] = {
+1.04128353664259848412e+03,
+6.26107330137134956842e+05,
+2.68919633393814121987e+08,
+8.64002487103935000337e+10,
+2.02979612750105546709e+13,
+3.17157752842975028269e+15,
+2.50596256172653059228e+17,
};
if (x <= T(5.0)) {
if (x == T(0.0)) {
return -std::numeric_limits<T>::infinity();
}
if (x < T(0.0)) {
return std::numeric_limits<T>::quiet_NaN();
}
T yp = 0.0;
for (uint8_t index = 0; index <= 7; index++) {
yp = yp * (x * x) + YP[index];
}
T yq = 0.0;
for (uint8_t index = 0; index <= 6; index++) {
yq = yq * (x * x) + YQ[index];
}
return yp / yq + (T(0.636619772367581343075535053490057448) * std::log(x) * bessel_j0_forward(x));
}
T pp = 0.0;
for (uint8_t index = 0; index <= 6; index++) {
pp = pp * (T(25.0) / (x * x)) + PP[index];
}
T pq = 0.0;
for (uint8_t index = 0; index <= 6; index++) {
pq = pq * (T(25.0) / (x * x)) + PQ[index];
}
T qp = 0.0;
for (uint8_t index = 0; index <= 7; index++) {
qp = qp * (T(25.0) / (x * x)) + QP[index];
}
T qq = 0.0;
for (uint8_t index = 0; index <= 6; index++) {
qq = qq * (T(25.0) / (x * x)) + QQ[index];
}
return (pp / pq * std::sin(x - T(0.785398163397448309615660845819875721)) + T(5.0) / x * (qp / qq) * std::cos(x - T(0.785398163397448309615660845819875721))) * T(0.797884560802865355879892119868763737) / std::sqrt(x);
} // bessel_y0_forward(T x)
template<typename T>
static inline C10_HOST_DEVICE T bessel_y1_forward(T x) {
static const T PP[] = {
+7.62125616208173112003e-04,
+7.31397056940917570436e-02,
+1.12719608129684925192e+00,
+5.11207951146807644818e+00,
+8.42404590141772420927e+00,
+5.21451598682361504063e+00,
+1.00000000000000000254e+00,
};
static const T PQ[] = {
+5.71323128072548699714e-04,
+6.88455908754495404082e-02,
+1.10514232634061696926e+00,
+5.07386386128601488557e+00,
+8.39985554327604159757e+00,
+5.20982848682361821619e+00,
+9.99999999999999997461e-01,
};
static const T QP[] = {
+5.10862594750176621635e-02,
+4.98213872951233449420e+00,
+7.58238284132545283818e+01,
+3.66779609360150777800e+02,
+7.10856304998926107277e+02,
+5.97489612400613639965e+02,
+2.11688757100572135698e+02,
+2.52070205858023719784e+01,
};
static const T QQ[] = {
+7.42373277035675149943e+01,
+1.05644886038262816351e+03,
+4.98641058337653607651e+03,
+9.56231892404756170795e+03,
+7.99704160447350683650e+03,
+2.82619278517639096600e+03,
+3.36093607810698293419e+02,
};
static const T YP[] = {
+1.26320474790178026440e+09,
-6.47355876379160291031e+11,
+1.14509511541823727583e+14,
-8.12770255501325109621e+15,
+2.02439475713594898196e+17,
-7.78877196265950026825e+17,
};
static const T YQ[] = {
+5.94301592346128195359e+02,
+2.35564092943068577943e+05,
+7.34811944459721705660e+07,
+1.87601316108706159478e+10,
+3.88231277496238566008e+12,
+6.20557727146953693363e+14,
+6.87141087355300489866e+16,
+3.97270608116560655612e+18,
};
if (x <= T(5.0)) {
if (x == T(0.0)) {
return -std::numeric_limits<T>::infinity();
}
if (x <= T(0.0)) {
return std::numeric_limits<T>::quiet_NaN();
}
T yp = 0.0;
for (uint8_t index = 0; index <= 5; index++) {
yp = yp * (x * x) + YP[index];
}
T yq = 0.0;
for (uint8_t index = 0; index <= 7; index++) {
yq = yq * (x * x) + YQ[index];
}
return x * (yp / yq) + (T(0.636619772367581343075535053490057448) * (bessel_j1_forward(x) * std::log(x) - T(1.0) / x));
}
T pp = 0.0;
for (uint8_t index = 0; index <= 6; index++) {
pp = pp * (T(5.0) / x * (T(5.0) / x)) + PP[index];
}
T pq = 0.0;
for (uint8_t index = 0; index <= 6; index++) {
pq = pq * (T(5.0) / x * (T(5.0) / x)) + PQ[index];
}
T qp = 0.0;
for (uint8_t index = 0; index <= 7; index++) {
qp = qp * (T(5.0) / x * (T(5.0) / x)) + QP[index];
}
T qq = 0.0;
for (uint8_t index = 0; index <= 6; index++) {
qq = qq * (T(5.0) / x * (T(5.0) / x)) + QQ[index];
}
return (pp / pq * std::sin(x - T(2.356194490192344928846982537459627163)) + T(5.0) / x * (qp / qq) * std::cos(x - T(2.356194490192344928846982537459627163))) * T(0.797884560802865355879892119868763737) / std::sqrt(x);
} // bessel_y1_forward(T x)
template<typename T>
static inline C10_HOST_DEVICE T chebyshev_polynomial_t_forward(T x, int64_t n) {
if (n < 0) {
@ -2369,4 +2818,344 @@ static inline C10_HOST_DEVICE T laguerre_polynomial_l_forward(T x, T n) {
return laguerre_polynomial_l_forward(x, static_cast<int64_t>(n));
} // laguerre_polynomial_l_forward(T x, T n)
template<typename T>
static inline C10_HOST_DEVICE T modified_bessel_i0_forward(T x) {
static const T A[] = {
-4.41534164647933937950e-18,
+3.33079451882223809783e-17,
-2.43127984654795469359e-16,
+1.71539128555513303061e-15,
-1.16853328779934516808e-14,
+7.67618549860493561688e-14,
-4.85644678311192946090e-13,
+2.95505266312963983461e-12,
-1.72682629144155570723e-11,
+9.67580903537323691224e-11,
-5.18979560163526290666e-10,
+2.65982372468238665035e-09,
-1.30002500998624804212e-08,
+6.04699502254191894932e-08,
-2.67079385394061173391e-07,
+1.11738753912010371815e-06,
-4.41673835845875056359e-06,
+1.64484480707288970893e-05,
-5.75419501008210370398e-05,
+1.88502885095841655729e-04,
-5.76375574538582365885e-04,
+1.63947561694133579842e-03,
-4.32430999505057594430e-03,
+1.05464603945949983183e-02,
-2.37374148058994688156e-02,
+4.93052842396707084878e-02,
-9.49010970480476444210e-02,
+1.71620901522208775349e-01,
-3.04682672343198398683e-01,
+6.76795274409476084995e-01,
};
static const T B[] = {
-7.23318048787475395456e-18,
-4.83050448594418207126e-18,
+4.46562142029675999901e-17,
+3.46122286769746109310e-17,
-2.82762398051658348494e-16,
-3.42548561967721913462e-16,
+1.77256013305652638360e-15,
+3.81168066935262242075e-15,
-9.55484669882830764870e-15,
-4.15056934728722208663e-14,
+1.54008621752140982691e-14,
+3.85277838274214270114e-13,
+7.18012445138366623367e-13,
-1.79417853150680611778e-12,
-1.32158118404477131188e-11,
-3.14991652796324136454e-11,
+1.18891471078464383424e-11,
+4.94060238822496958910e-10,
+3.39623202570838634515e-09,
+2.26666899049817806459e-08,
+2.04891858946906374183e-07,
+2.89137052083475648297e-06,
+6.88975834691682398426e-05,
+3.36911647825569408990e-03,
+8.04490411014108831608e-01,
};
T p;
T q = 0.0;
if (std::abs(x) <= T(8.0)) {
T a = A[0];
for (uint8_t index = 1; index < 30; index++) {
p = q;
q = a;
a = ((std::abs(x) / T(2.0)) - T(2.0)) * q - p + A[index];
}
return std::exp(std::abs(x)) * (T(0.5) * (a - p));
}
T b = B[0];
for (uint8_t index = 1; index < 25; index++) {
p = q;
q = b;
b = (T(32.0) / std::abs(x) - T(2.0)) * q - p + B[index];
}
return std::exp(std::abs(x)) * (T(0.5) * (b - p)) / std::sqrt(std::abs(x));
} // modified_bessel_i0_forward(T x)
template<typename T>
static inline C10_HOST_DEVICE T modified_bessel_i1_forward(T x) {
static const T A[] = {
+2.77791411276104639959e-18,
-2.11142121435816608115e-17,
+1.55363195773620046921e-16,
-1.10559694773538630805e-15,
+7.60068429473540693410e-15,
-5.04218550472791168711e-14,
+3.22379336594557470981e-13,
-1.98397439776494371520e-12,
+1.17361862988909016308e-11,
-6.66348972350202774223e-11,
+3.62559028155211703701e-10,
-1.88724975172282928790e-09,
+9.38153738649577178388e-09,
-4.44505912879632808065e-08,
+2.00329475355213526229e-07,
-8.56872026469545474066e-07,
+3.47025130813767847674e-06,
-1.32731636560394358279e-05,
+4.78156510755005422638e-05,
-1.61760815825896745588e-04,
+5.12285956168575772895e-04,
-1.51357245063125314899e-03,
+4.15642294431288815669e-03,
-1.05640848946261981558e-02,
+2.47264490306265168283e-02,
-5.29459812080949914269e-02,
+1.02643658689847095384e-01,
-1.76416518357834055153e-01,
+2.52587186443633654823e-01,
};
static const T B[] = {
+7.51729631084210481353e-18,
+4.41434832307170791151e-18,
-4.65030536848935832153e-17,
-3.20952592199342395980e-17,
+2.96262899764595013876e-16,
+3.30820231092092828324e-16,
-1.88035477551078244854e-15,
-3.81440307243700780478e-15,
+1.04202769841288027642e-14,
+4.27244001671195135429e-14,
-2.10154184277266431302e-14,
-4.08355111109219731823e-13,
-7.19855177624590851209e-13,
+2.03562854414708950722e-12,
+1.41258074366137813316e-11,
+3.25260358301548823856e-11,
-1.89749581235054123450e-11,
-5.58974346219658380687e-10,
-3.83538038596423702205e-09,
-2.63146884688951950684e-08,
-2.51223623787020892529e-07,
-3.88256480887769039346e-06,
-1.10588938762623716291e-04,
-9.76109749136146840777e-03,
+7.78576235018280120474e-01,
};
T p;
T q = 0.0;
if (std::abs(x) <= T(8.0)) {
T a = A[0];
for (uint8_t index = 1; index < 29; index++) {
p = q;
q = a;
a = ((std::abs(x) / T(2.0)) - T(2.0)) * q - p + A[index];
}
if (x < T(0.0)) {
return -(T(0.5) * (a - p) * std::abs(x) * std::exp(std::abs(x)));
}
return T(0.5) * (a - p) * std::abs(x) * std::exp(std::abs(x));
}
T b = B[0];
for (uint8_t index = 1; index < 25; index++) {
p = q;
q = b;
b = (T(32.0) / std::abs(x) - T(2.0)) * q - p + B[index];
}
if (x < T(0.0)) {
return -(std::exp(std::abs(x)) * (T(0.5) * (b - p)) / std::sqrt(std::abs(x)));
}
return std::exp(std::abs(x)) * (T(0.5) * (b - p)) / std::sqrt(std::abs(x));
} // modified_bessel_i1_forward(T x)
template<typename T>
static inline C10_HOST_DEVICE T modified_bessel_k0_forward(T x) {
static const T A[] = {
+1.37446543561352307156e-16,
+4.25981614279661018399e-14,
+1.03496952576338420167e-11,
+1.90451637722020886025e-09,
+2.53479107902614945675e-07,
+2.28621210311945178607e-05,
+1.26461541144692592338e-03,
+3.59799365153615016266e-02,
+3.44289899924628486886e-01,
-5.35327393233902768720e-01,
};
static const T B[] = {
+5.30043377268626276149e-18,
-1.64758043015242134646e-17,
+5.21039150503902756861e-17,
-1.67823109680541210385e-16,
+5.51205597852431940784e-16,
-1.84859337734377901440e-15,
+6.34007647740507060557e-15,
-2.22751332699166985548e-14,
+8.03289077536357521100e-14,
-2.98009692317273043925e-13,
+1.14034058820847496303e-12,
-4.51459788337394416547e-12,
+1.85594911495471785253e-11,
-7.95748924447710747776e-11,
+3.57739728140030116597e-10,
-1.69753450938905987466e-09,
+8.57403401741422608519e-09,
-4.66048989768794782956e-08,
+2.76681363944501510342e-07,
-1.83175552271911948767e-06,
+1.39498137188764993662e-05,
-1.28495495816278026384e-04,
+1.56988388573005337491e-03,
-3.14481013119645005427e-02,
+2.44030308206595545468e+00,
};
if (x == T(0.0)) {
return std::numeric_limits<T>::infinity();
}
if (x < T(0.0)) {
return std::numeric_limits<T>::quiet_NaN();
}
T p;
T q = 0.0;
if (x <= T(2.0)) {
T a = A[0];
for (uint8_t index = 1; index < 10; index++) {
p = q;
q = a;
a = (x * x - T(2.0)) * q - p + A[index];
}
return T(0.5) * (a - p) - std::log(T(0.5) * x) * modified_bessel_i0_forward(x);
}
T b = B[0];
for (uint8_t index = 1; index < 25; index++) {
p = q;
q = b;
b = (T(8.0) / x - T(2.0)) * q - p + B[index];
}
return std::exp(-x) * (T(0.5) * (b - p)) / std::sqrt(x);
} // modified_bessel_k0_forward(T x)
template<typename T>
static inline C10_HOST_DEVICE T modified_bessel_k1_forward(T x) {
static const T A[] = {
-7.02386347938628759343e-18,
-2.42744985051936593393e-15,
-6.66690169419932900609e-13,
-1.41148839263352776110e-10,
-2.21338763073472585583e-08,
-2.43340614156596823496e-06,
-1.73028895751305206302e-04,
-6.97572385963986435018e-03,
-1.22611180822657148235e-01,
-3.53155960776544875667e-01,
+1.52530022733894777053e+00,
};
static const T B[] = {
-5.75674448366501715755e-18,
+1.79405087314755922667e-17,
-5.68946255844285935196e-17,
+1.83809354436663880070e-16,
-6.05704724837331885336e-16,
+2.03870316562433424052e-15,
-7.01983709041831346144e-15,
+2.47715442448130437068e-14,
-8.97670518232499435011e-14,
+3.34841966607842919884e-13,
-1.28917396095102890680e-12,
+5.13963967348173025100e-12,
-2.12996783842756842877e-11,
+9.21831518760500529508e-11,
-4.19035475934189648750e-10,
+2.01504975519703286596e-09,
-1.03457624656780970260e-08,
+5.74108412545004946722e-08,
-3.50196060308781257119e-07,
+2.40648494783721712015e-06,
-1.93619797416608296024e-05,
+1.95215518471351631108e-04,
-2.85781685962277938680e-03,
+1.03923736576817238437e-01,
+2.72062619048444266945e+00,
};
if (x == T(0.0)) {
return std::numeric_limits<T>::infinity();
}
if (x < T(0.0)) {
return std::numeric_limits<T>::quiet_NaN();
}
T p;
T q = 0.0;
if (x <= T(2.0)) {
T a = A[0];
for (uint8_t index = 1; index < 11; index++) {
p = q;
q = a;
a = (x * x - T(2.0)) * q - p + A[index];
}
return std::log(T(0.5) * x) * modified_bessel_i1_forward(x) + T(0.5) * (a - p) / x;
}
T b = B[0];
for (uint8_t index = 1; index < 25; index++) {
p = q;
q = b;
b = (T(8.0) / x - T(2.0)) * q - p + B[index];
}
return std::exp(-x) * (T(0.5) * (b - p)) / std::sqrt(x);
} // modified_bessel_k1_forward(T x)
C10_CLANG_DIAGNOSTIC_POP()

View File

@ -71,6 +71,14 @@ CREATE_UNARY_FLOAT_META_FUNC(special_log_ndtr)
CREATE_UNARY_FLOAT_META_FUNC(sqrt)
CREATE_UNARY_FLOAT_META_FUNC(tan)
CREATE_UNARY_FLOAT_META_FUNC(tanh)
CREATE_UNARY_FLOAT_META_FUNC(special_bessel_j0)
CREATE_UNARY_FLOAT_META_FUNC(special_bessel_j1)
CREATE_UNARY_FLOAT_META_FUNC(special_bessel_y0)
CREATE_UNARY_FLOAT_META_FUNC(special_bessel_y1)
CREATE_UNARY_FLOAT_META_FUNC(special_modified_bessel_i0)
CREATE_UNARY_FLOAT_META_FUNC(special_modified_bessel_i1)
CREATE_UNARY_FLOAT_META_FUNC(special_modified_bessel_k0)
CREATE_UNARY_FLOAT_META_FUNC(special_modified_bessel_k1)
TORCH_META_FUNC(polygamma)(int64_t n, const Tensor& self) {
TORCH_CHECK(n >= 0, "polygamma(n, x) does not support negative n.");
@ -190,6 +198,14 @@ CREATE_UNARY_TORCH_IMPL_FUNC(sqrt_out, sqrt_stub)
CREATE_UNARY_TORCH_IMPL_FUNC(tan_out, tan_stub)
CREATE_UNARY_TORCH_IMPL_FUNC(tanh_out, tanh_stub)
CREATE_UNARY_TORCH_IMPL_FUNC(trunc_out, trunc_stub)
CREATE_UNARY_TORCH_IMPL_FUNC(special_bessel_j0_out, special_bessel_j0_stub)
CREATE_UNARY_TORCH_IMPL_FUNC(special_bessel_j1_out, special_bessel_j1_stub)
CREATE_UNARY_TORCH_IMPL_FUNC(special_bessel_y0_out, special_bessel_y0_stub)
CREATE_UNARY_TORCH_IMPL_FUNC(special_bessel_y1_out, special_bessel_y1_stub)
CREATE_UNARY_TORCH_IMPL_FUNC(special_modified_bessel_i0_out, special_modified_bessel_i0_stub)
CREATE_UNARY_TORCH_IMPL_FUNC(special_modified_bessel_i1_out, special_modified_bessel_i1_stub)
CREATE_UNARY_TORCH_IMPL_FUNC(special_modified_bessel_k0_out, special_modified_bessel_k0_stub)
CREATE_UNARY_TORCH_IMPL_FUNC(special_modified_bessel_k1_out, special_modified_bessel_k1_stub)
TORCH_IMPL_FUNC(round_decimals_out)
(const Tensor& self, int64_t decimals, const Tensor& result) {
@ -863,6 +879,14 @@ DEFINE_DISPATCH(tanh_stub); // NOLINT(cppcoreguidelines-avoid-non-const-global-v
DEFINE_DISPATCH(trigamma_stub); // NOLINT(cppcoreguidelines-avoid-non-const-global-variables)
DEFINE_DISPATCH(trunc_stub); // NOLINT(cppcoreguidelines-avoid-non-const-global-variables)
DEFINE_DISPATCH(lgamma_stub); // NOLINT(cppcoreguidelines-avoid-non-const-global-variables)
DEFINE_DISPATCH(special_bessel_j0_stub); // NOLINT(cppcoreguidelines-avoid-non-const-global-variables)
DEFINE_DISPATCH(special_bessel_j1_stub); // NOLINT(cppcoreguidelines-avoid-non-const-global-variables)
DEFINE_DISPATCH(special_bessel_y0_stub); // NOLINT(cppcoreguidelines-avoid-non-const-global-variables)
DEFINE_DISPATCH(special_bessel_y1_stub); // NOLINT(cppcoreguidelines-avoid-non-const-global-variables)
DEFINE_DISPATCH(special_modified_bessel_i0_stub); // NOLINT(cppcoreguidelines-avoid-non-const-global-variables)
DEFINE_DISPATCH(special_modified_bessel_i1_stub); // NOLINT(cppcoreguidelines-avoid-non-const-global-variables)
DEFINE_DISPATCH(special_modified_bessel_k0_stub); // NOLINT(cppcoreguidelines-avoid-non-const-global-variables)
DEFINE_DISPATCH(special_modified_bessel_k1_stub); // NOLINT(cppcoreguidelines-avoid-non-const-global-variables)
} // namespace native
} // namespace at

View File

@ -70,6 +70,14 @@ DECLARE_DISPATCH(unary_fn, tanh_stub);
DECLARE_DISPATCH(unary_fn, trigamma_stub);
DECLARE_DISPATCH(unary_fn, trunc_stub);
DECLARE_DISPATCH(unary_fn, lgamma_stub);
DECLARE_DISPATCH(unary_fn, special_bessel_j0_stub);
DECLARE_DISPATCH(unary_fn, special_bessel_j1_stub);
DECLARE_DISPATCH(unary_fn, special_bessel_y0_stub);
DECLARE_DISPATCH(unary_fn, special_bessel_y1_stub);
DECLARE_DISPATCH(unary_fn, special_modified_bessel_i0_stub);
DECLARE_DISPATCH(unary_fn, special_modified_bessel_i1_stub);
DECLARE_DISPATCH(unary_fn, special_modified_bessel_k0_stub);
DECLARE_DISPATCH(unary_fn, special_modified_bessel_k1_stub);
// NB: these are actually defined in Distribution
DECLARE_DISPATCH(void(*)(const TensorBase&, const TensorBase&, c10::optional<Generator>), bernoulli_tensor_stub);

View File

@ -566,6 +566,86 @@ void round_decimals_kernel(TensorIteratorBase& iter, int64_t decimals) {
});
}
static void bessel_j0_kernel(TensorIteratorBase& iterator) {
TORCH_INTERNAL_ASSERT(iterator.ntensors() == 2);
AT_DISPATCH_FLOATING_TYPES(iterator.common_dtype(), "bessel_j0_cpu", [&]() {
cpu_kernel(iterator, [](scalar_t x) {
return bessel_j0_forward(x);
});
});
} // bessel_j0_kernel(TensorIteratorBase& iterator)
static void bessel_j1_kernel(TensorIteratorBase& iterator) {
TORCH_INTERNAL_ASSERT(iterator.ntensors() == 2);
AT_DISPATCH_FLOATING_TYPES(iterator.common_dtype(), "bessel_j1_cpu", [&]() {
cpu_kernel(iterator, [](scalar_t x) {
return bessel_j1_forward(x);
});
});
} // bessel_j1_kernel(TensorIteratorBase& iterator)
static void bessel_y0_kernel(TensorIteratorBase& iterator) {
TORCH_INTERNAL_ASSERT(iterator.ntensors() == 2);
AT_DISPATCH_FLOATING_TYPES(iterator.common_dtype(), "bessel_y0_cpu", [&]() {
cpu_kernel(iterator, [](scalar_t x) {
return bessel_y0_forward(x);
});
});
} // bessel_y0_kernel(TensorIteratorBase& iterator)
static void bessel_y1_kernel(TensorIteratorBase& iterator) {
TORCH_INTERNAL_ASSERT(iterator.ntensors() == 2);
AT_DISPATCH_FLOATING_TYPES(iterator.common_dtype(), "bessel_y1_cpu", [&]() {
cpu_kernel(iterator, [](scalar_t x) {
return bessel_y1_forward(x);
});
});
} // bessel_y1_kernel(TensorIteratorBase& iterator)
static void modified_bessel_i0_kernel(TensorIteratorBase& iterator) {
TORCH_INTERNAL_ASSERT(iterator.ntensors() == 2);
AT_DISPATCH_FLOATING_TYPES(iterator.common_dtype(), "modified_bessel_i0_cpu", [&]() {
cpu_kernel(iterator, [](scalar_t x) {
return modified_bessel_i0_forward(x);
});
});
} // modified_bessel_i0_kernel(TensorIteratorBase& iterator)
static void modified_bessel_i1_kernel(TensorIteratorBase& iterator) {
TORCH_INTERNAL_ASSERT(iterator.ntensors() == 2);
AT_DISPATCH_FLOATING_TYPES(iterator.common_dtype(), "modified_bessel_i1_cpu", [&]() {
cpu_kernel(iterator, [](scalar_t x) {
return modified_bessel_i1_forward(x);
});
});
} // modified_bessel_i1_kernel(TensorIteratorBase& iterator)
static void modified_bessel_k0_kernel(TensorIteratorBase& iterator) {
TORCH_INTERNAL_ASSERT(iterator.ntensors() == 2);
AT_DISPATCH_FLOATING_TYPES(iterator.common_dtype(), "modified_bessel_k0_cpu", [&]() {
cpu_kernel(iterator, [](scalar_t x) {
return modified_bessel_k0_forward(x);
});
});
} // modified_bessel_k0_kernel(TensorIteratorBase& iterator)
static void modified_bessel_k1_kernel(TensorIteratorBase& iterator) {
TORCH_INTERNAL_ASSERT(iterator.ntensors() == 2);
AT_DISPATCH_FLOATING_TYPES(iterator.common_dtype(), "modified_bessel_k1_cpu", [&]() {
cpu_kernel(iterator, [](scalar_t x) {
return modified_bessel_k1_forward(x);
});
});
} // modified_bessel_k1_kernel(TensorIteratorBase& iterator)
// TODO: Disable cont. branch to test more risky code
#define IMPLEMENT_ITERATOR_LAMBDA(op) \
@ -656,7 +736,14 @@ REGISTER_DISPATCH(special_i1_stub, &CPU_CAPABILITY::i1_kernel);
REGISTER_DISPATCH(special_i1e_stub, &CPU_CAPABILITY::i1e_kernel);
REGISTER_DISPATCH(special_erfcx_stub, &CPU_CAPABILITY::erfcx_kernel);
REGISTER_DISPATCH(round_decimals_stub, &CPU_CAPABILITY::round_decimals_kernel);
REGISTER_DISPATCH(special_bessel_j0_stub, &CPU_CAPABILITY::bessel_j0_kernel);
REGISTER_DISPATCH(special_bessel_j1_stub, &CPU_CAPABILITY::bessel_j1_kernel);
REGISTER_DISPATCH(special_bessel_y0_stub, &CPU_CAPABILITY::bessel_y0_kernel);
REGISTER_DISPATCH(special_bessel_y1_stub, &CPU_CAPABILITY::bessel_y1_kernel);
REGISTER_DISPATCH(special_modified_bessel_i0_stub, &CPU_CAPABILITY::modified_bessel_i0_kernel);
REGISTER_DISPATCH(special_modified_bessel_i1_stub, &CPU_CAPABILITY::modified_bessel_i1_kernel);
REGISTER_DISPATCH(special_modified_bessel_k0_stub, &CPU_CAPABILITY::modified_bessel_k0_kernel);
REGISTER_DISPATCH(special_modified_bessel_k1_stub, &CPU_CAPABILITY::modified_bessel_k1_kernel);
// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables,modernize-avoid-c-arrays,cppcoreguidelines-avoid-c-arrays)
IMPLEMENT_COMPLEX_KERNEL(acos)

View File

@ -1265,6 +1265,463 @@ const auto erfcx_string = jiterator_stringify(
}
); // erfcx_string
const auto bessel_j0_string = jiterator_stringify(
template<typename T>
T bessel_j0_forward(T x) {
static const T PP[] = {
+7.96936729297347051624e-04,
+8.28352392107440799803e-02,
+1.23953371646414299388e+00,
+5.44725003058768775090e+00,
+8.74716500199817011941e+00,
+5.30324038235394892183e+00,
+9.99999999999999997821e-01,
};
static const T PQ[] = {
+9.24408810558863637013e-04,
+8.56288474354474431428e-02,
+1.25352743901058953537e+00,
+5.47097740330417105182e+00,
+8.76190883237069594232e+00,
+5.30605288235394617618e+00,
+1.00000000000000000218e+00,
};
static const T QP[] = {
-1.13663838898469149931e-02,
-1.28252718670509318512e+00,
-1.95539544257735972385e+01,
-9.32060152123768231369e+01,
-1.77681167980488050595e+02,
-1.47077505154951170175e+02,
-5.14105326766599330220e+01,
-6.05014350600728481186e+00,
};
static const T QQ[] = {
+6.43178256118178023184e+01,
+8.56430025976980587198e+02,
+3.88240183605401609683e+03,
+7.24046774195652478189e+03,
+5.93072701187316984827e+03,
+2.06209331660327847417e+03,
+2.42005740240291393179e+02,
};
static const T RP[] = {
-4.79443220978201773821e+09,
+1.95617491946556577543e+12,
-2.49248344360967716204e+14,
+9.70862251047306323952e+15,
};
static const T RQ[] = {
+4.99563147152651017219e+02,
+1.73785401676374683123e+05,
+4.84409658339962045305e+07,
+1.11855537045356834862e+10,
+2.11277520115489217587e+12,
+3.10518229857422583814e+14,
+3.18121955943204943306e+16,
+1.71086294081043136091e+18,
};
if (x < T(0)) {
x = -x;
}
if (x <= T(5.0)) {
if (x < T(0.00001)) {
return T(1.0) - x * x / T(4.0);
}
T rp = 0.0;
for (uint8_t index = 0; index <= 3; index++) {
rp = rp * (x * x) + RP[index];
}
T rq = 0.0;
for (uint8_t index = 0; index <= 7; index++) {
rq = rq * (x * x) + RQ[index];
}
return (x * x - T(5.78318596294678452118e+00)) * (x * x - T(3.04712623436620863991e+01)) * rp / rq;
}
T pp = 0.0;
for (uint8_t index = 0; index <= 6; index++) {
pp = pp * (T(25.0) / (x * x)) + PP[index];
}
T pq = 0.0;
for (uint8_t index = 0; index <= 6; index++) {
pq = pq * (T(25.0) / (x * x)) + PQ[index];
}
T qp = 0.0;
for (uint8_t index = 0; index <= 7; index++) {
qp = qp * (T(25.0) / (x * x)) + QP[index];
}
T qq = 0.0;
for (uint8_t index = 0; index <= 6; index++) {
qq = qq * (T(25.0) / (x * x)) + QQ[index];
}
return (pp / pq * cos(x - T(0.785398163397448309615660845819875721)) - T(5.0) / x * (qp / qq) * sin(x - T(0.785398163397448309615660845819875721))) * T(0.797884560802865355879892119868763737) / sqrt(x);
} // bessel_j0_forward(T x)
); // bessel_j0_string
const auto bessel_y0_string = bessel_j0_string + jiterator_stringify(
template<typename T>
T bessel_y0_forward(T x) {
static const T PP[] = {
+7.96936729297347051624e-04,
+8.28352392107440799803e-02,
+1.23953371646414299388e+00,
+5.44725003058768775090e+00,
+8.74716500199817011941e+00,
+5.30324038235394892183e+00,
+9.99999999999999997821e-01,
};
static const T PQ[] = {
+9.24408810558863637013e-04,
+8.56288474354474431428e-02,
+1.25352743901058953537e+00,
+5.47097740330417105182e+00,
+8.76190883237069594232e+00,
+5.30605288235394617618e+00,
+1.00000000000000000218e+00,
};
static const T QP[] = {
-1.13663838898469149931e-02,
-1.28252718670509318512e+00,
-1.95539544257735972385e+01,
-9.32060152123768231369e+01,
-1.77681167980488050595e+02,
-1.47077505154951170175e+02,
-5.14105326766599330220e+01,
-6.05014350600728481186e+00,
};
static const T QQ[] = {
+6.43178256118178023184e+01,
+8.56430025976980587198e+02,
+3.88240183605401609683e+03,
+7.24046774195652478189e+03,
+5.93072701187316984827e+03,
+2.06209331660327847417e+03,
+2.42005740240291393179e+02,
};
static const T YP[] = {
+1.55924367855235737965e+04,
-1.46639295903971606143e+07,
+5.43526477051876500413e+09,
-9.82136065717911466409e+11,
+8.75906394395366999549e+13,
-3.46628303384729719441e+15,
+4.42733268572569800351e+16,
-1.84950800436986690637e+16,
};
static const T YQ[] = {
+1.04128353664259848412e+03,
+6.26107330137134956842e+05,
+2.68919633393814121987e+08,
+8.64002487103935000337e+10,
+2.02979612750105546709e+13,
+3.17157752842975028269e+15,
+2.50596256172653059228e+17,
};
if (x <= T(5.0)) {
if (x == T(0.0)) {
return NEG_INFINITY;
}
if (x < T(0.0)) {
return NAN;
}
T yp = 0.0;
for (uint8_t index = 0; index <= 7; index++) {
yp = yp * (x * x) + YP[index];
}
T yq = 0.0;
for (uint8_t index = 0; index <= 6; index++) {
yq = yq * (x * x) + YQ[index];
}
return yp / yq + (T(0.636619772367581343075535053490057448) * log(x) * bessel_j0_forward(x));
}
T pp = 0.0;
for (uint8_t index = 0; index <= 6; index++) {
pp = pp * (T(25.0) / (x * x)) + PP[index];
}
T pq = 0.0;
for (uint8_t index = 0; index <= 6; index++) {
pq = pq * (T(25.0) / (x * x)) + PQ[index];
}
T qp = 0.0;
for (uint8_t index = 0; index <= 7; index++) {
qp = qp * (T(25.0) / (x * x)) + QP[index];
}
T qq = 0.0;
for (uint8_t index = 0; index <= 6; index++) {
qq = qq * (T(25.0) / (x * x)) + QQ[index];
}
return (pp / pq * sin(x - T(0.785398163397448309615660845819875721)) + T(5.0) / x * (qp / qq) * cos(x - T(0.785398163397448309615660845819875721))) * T(0.797884560802865355879892119868763737) / sqrt(x);
} // bessel_y0_forward(T x)
); // bessel_y0_string
const auto bessel_j1_string = jiterator_stringify(
template<typename T>
T bessel_j1_forward(T x) {
static const T PP[] = {
+7.62125616208173112003e-04,
+7.31397056940917570436e-02,
+1.12719608129684925192e+00,
+5.11207951146807644818e+00,
+8.42404590141772420927e+00,
+5.21451598682361504063e+00,
+1.00000000000000000254e+00,
};
static const T PQ[] = {
+5.71323128072548699714e-04,
+6.88455908754495404082e-02,
+1.10514232634061696926e+00,
+5.07386386128601488557e+00,
+8.39985554327604159757e+00,
+5.20982848682361821619e+00,
+9.99999999999999997461e-01,
};
static const T QP[] = {
+5.10862594750176621635e-02,
+4.98213872951233449420e+00,
+7.58238284132545283818e+01,
+3.66779609360150777800e+02,
+7.10856304998926107277e+02,
+5.97489612400613639965e+02,
+2.11688757100572135698e+02,
+2.52070205858023719784e+01,
};
static const T QQ[] = {
+7.42373277035675149943e+01,
+1.05644886038262816351e+03,
+4.98641058337653607651e+03,
+9.56231892404756170795e+03,
+7.99704160447350683650e+03,
+2.82619278517639096600e+03,
+3.36093607810698293419e+02,
};
static const T RP[] = {
-8.99971225705559398224e+08,
+4.52228297998194034323e+11,
-7.27494245221818276015e+13,
+3.68295732863852883286e+15,
};
static const T RQ[] = {
+6.20836478118054335476e+02,
+2.56987256757748830383e+05,
+8.35146791431949253037e+07,
+2.21511595479792499675e+10,
+4.74914122079991414898e+12,
+7.84369607876235854894e+14,
+8.95222336184627338078e+16,
+5.32278620332680085395e+18,
};
if (x < T(0.0)) {
return -bessel_j1_forward(-x);
}
if (x <= T(5.0)) {
T rp = 0.0;
for (uint8_t index = 0; index <= 3; index++) {
rp = rp * (x * x) + RP[index];
}
T rq = 0.0;
for (uint8_t index = 0; index <= 7; index++) {
rq = rq * (x * x) + RQ[index];
}
return rp / rq * x * (x * x - T(1.46819706421238932572e+01)) * (x * x - T(4.92184563216946036703e+01));
}
T pp = 0.0;
for (uint8_t index = 0; index <= 6; index++) {
pp = pp * (T(5.0) / x * (T(5.0) / x)) + PP[index];
}
T pq = 0.0;
for (uint8_t index = 0; index <= 6; index++) {
pq = pq * (T(5.0) / x * (T(5.0) / x)) + PQ[index];
}
T qp = 0.0;
for (uint8_t index = 0; index <= 7; index++) {
qp = qp * (T(5.0) / x * (T(5.0) / x)) + QP[index];
}
T qq = 0.0;
for (uint8_t index = 0; index <= 6; index++) {
qq = qq * (T(5.0) / x * (T(5.0) / x)) + QQ[index];
}
return (pp / pq * cos(x - T(2.356194490192344928846982537459627163)) - T(5.0) / x * (qp / qq) * sin(x - T(2.356194490192344928846982537459627163))) * T(0.797884560802865355879892119868763737) / sqrt(x);
} // bessel_j1_forward(T x)
); // bessel_j1_string
const auto bessel_y1_string = bessel_j1_string + jiterator_stringify(
template<typename T>
T bessel_y1_forward(T x) {
static const T PP[] = {
+7.62125616208173112003e-04,
+7.31397056940917570436e-02,
+1.12719608129684925192e+00,
+5.11207951146807644818e+00,
+8.42404590141772420927e+00,
+5.21451598682361504063e+00,
+1.00000000000000000254e+00,
};
static const T PQ[] = {
+5.71323128072548699714e-04,
+6.88455908754495404082e-02,
+1.10514232634061696926e+00,
+5.07386386128601488557e+00,
+8.39985554327604159757e+00,
+5.20982848682361821619e+00,
+9.99999999999999997461e-01,
};
static const T QP[] = {
+5.10862594750176621635e-02,
+4.98213872951233449420e+00,
+7.58238284132545283818e+01,
+3.66779609360150777800e+02,
+7.10856304998926107277e+02,
+5.97489612400613639965e+02,
+2.11688757100572135698e+02,
+2.52070205858023719784e+01,
};
static const T QQ[] = {
+7.42373277035675149943e+01,
+1.05644886038262816351e+03,
+4.98641058337653607651e+03,
+9.56231892404756170795e+03,
+7.99704160447350683650e+03,
+2.82619278517639096600e+03,
+3.36093607810698293419e+02,
};
static const T YP[] = {
+1.26320474790178026440e+09,
-6.47355876379160291031e+11,
+1.14509511541823727583e+14,
-8.12770255501325109621e+15,
+2.02439475713594898196e+17,
-7.78877196265950026825e+17,
};
static const T YQ[] = {
+5.94301592346128195359e+02,
+2.35564092943068577943e+05,
+7.34811944459721705660e+07,
+1.87601316108706159478e+10,
+3.88231277496238566008e+12,
+6.20557727146953693363e+14,
+6.87141087355300489866e+16,
+3.97270608116560655612e+18,
};
if (x <= T(5.0)) {
if (x == T(0.0)) {
return NEG_INFINITY;
}
if (x <= T(0.0)) {
return NAN;
}
T yp = 0.0;
for (uint8_t index = 0; index <= 5; index++) {
yp = yp * (x * x) + YP[index];
}
T yq = 0.0;
for (uint8_t index = 0; index <= 7; index++) {
yq = yq * (x * x) + YQ[index];
}
return x * (yp / yq) + (T(0.636619772367581343075535053490057448) * (bessel_j1_forward(x) * log(x) - T(1.0) / x));
}
T pp = 0.0;
for (uint8_t index = 0; index <= 6; index++) {
pp = pp * (T(5.0) / x * (T(5.0) / x)) + PP[index];
}
T pq = 0.0;
for (uint8_t index = 0; index <= 6; index++) {
pq = pq * (T(5.0) / x * (T(5.0) / x)) + PQ[index];
}
T qp = 0.0;
for (uint8_t index = 0; index <= 7; index++) {
qp = qp * (T(5.0) / x * (T(5.0) / x)) + QP[index];
}
T qq = 0.0;
for (uint8_t index = 0; index <= 6; index++) {
qq = qq * (T(5.0) / x * (T(5.0) / x)) + QQ[index];
}
return (pp / pq * sin(x - T(2.356194490192344928846982537459627163)) + T(5.0) / x * (qp / qq) * cos(x - T(2.356194490192344928846982537459627163))) * T(0.797884560802865355879892119868763737) / sqrt(x);
} // bessel_y1_forward(T x)
); // bessel_y1_string
const auto chebyshev_polynomial_t_string = jiterator_stringify(
template<typename T>
T chebyshev_polynomial_t_forward(T x, int64_t n) {
@ -1467,6 +1924,354 @@ const auto laguerre_polynomial_l_string = jiterator_stringify(
} // laguerre_polynomial_l_forward(T x, T n)
); // laguerre_polynomial_l_string
const auto modified_bessel_i0_string = jiterator_stringify(
template<typename T>
T modified_bessel_i0_forward(T x) {
static const T A[] = {
-4.41534164647933937950e-18,
+3.33079451882223809783e-17,
-2.43127984654795469359e-16,
+1.71539128555513303061e-15,
-1.16853328779934516808e-14,
+7.67618549860493561688e-14,
-4.85644678311192946090e-13,
+2.95505266312963983461e-12,
-1.72682629144155570723e-11,
+9.67580903537323691224e-11,
-5.18979560163526290666e-10,
+2.65982372468238665035e-09,
-1.30002500998624804212e-08,
+6.04699502254191894932e-08,
-2.67079385394061173391e-07,
+1.11738753912010371815e-06,
-4.41673835845875056359e-06,
+1.64484480707288970893e-05,
-5.75419501008210370398e-05,
+1.88502885095841655729e-04,
-5.76375574538582365885e-04,
+1.63947561694133579842e-03,
-4.32430999505057594430e-03,
+1.05464603945949983183e-02,
-2.37374148058994688156e-02,
+4.93052842396707084878e-02,
-9.49010970480476444210e-02,
+1.71620901522208775349e-01,
-3.04682672343198398683e-01,
+6.76795274409476084995e-01,
};
static const T B[] = {
-7.23318048787475395456e-18,
-4.83050448594418207126e-18,
+4.46562142029675999901e-17,
+3.46122286769746109310e-17,
-2.82762398051658348494e-16,
-3.42548561967721913462e-16,
+1.77256013305652638360e-15,
+3.81168066935262242075e-15,
-9.55484669882830764870e-15,
-4.15056934728722208663e-14,
+1.54008621752140982691e-14,
+3.85277838274214270114e-13,
+7.18012445138366623367e-13,
-1.79417853150680611778e-12,
-1.32158118404477131188e-11,
-3.14991652796324136454e-11,
+1.18891471078464383424e-11,
+4.94060238822496958910e-10,
+3.39623202570838634515e-09,
+2.26666899049817806459e-08,
+2.04891858946906374183e-07,
+2.89137052083475648297e-06,
+6.88975834691682398426e-05,
+3.36911647825569408990e-03,
+8.04490411014108831608e-01,
};
T p;
T q = 0.0;
if (abs(x) <= T(8.0)) {
T a = A[0];
for (uint8_t index = 1; index < 30; index++) {
p = q;
q = a;
a = ((abs(x) / T(2.0)) - T(2.0)) * q - p + A[index];
}
return exp(abs(x)) * (T(0.5) * (a - p));
}
T b = B[0];
for (uint8_t index = 1; index < 25; index++) {
p = q;
q = b;
b = (T(32.0) / abs(x) - T(2.0)) * q - p + B[index];
}
return exp(abs(x)) * (T(0.5) * (b - p)) / sqrt(abs(x));
} // modified_bessel_i0_forward(T x)
); // modified_bessel_i0_string
const auto modified_bessel_i1_string = jiterator_stringify(
template<typename T>
T modified_bessel_i1_forward(T x) {
static const T A[] = {
+2.77791411276104639959e-18,
-2.11142121435816608115e-17,
+1.55363195773620046921e-16,
-1.10559694773538630805e-15,
+7.60068429473540693410e-15,
-5.04218550472791168711e-14,
+3.22379336594557470981e-13,
-1.98397439776494371520e-12,
+1.17361862988909016308e-11,
-6.66348972350202774223e-11,
+3.62559028155211703701e-10,
-1.88724975172282928790e-09,
+9.38153738649577178388e-09,
-4.44505912879632808065e-08,
+2.00329475355213526229e-07,
-8.56872026469545474066e-07,
+3.47025130813767847674e-06,
-1.32731636560394358279e-05,
+4.78156510755005422638e-05,
-1.61760815825896745588e-04,
+5.12285956168575772895e-04,
-1.51357245063125314899e-03,
+4.15642294431288815669e-03,
-1.05640848946261981558e-02,
+2.47264490306265168283e-02,
-5.29459812080949914269e-02,
+1.02643658689847095384e-01,
-1.76416518357834055153e-01,
+2.52587186443633654823e-01,
};
static const T B[] = {
+7.51729631084210481353e-18,
+4.41434832307170791151e-18,
-4.65030536848935832153e-17,
-3.20952592199342395980e-17,
+2.96262899764595013876e-16,
+3.30820231092092828324e-16,
-1.88035477551078244854e-15,
-3.81440307243700780478e-15,
+1.04202769841288027642e-14,
+4.27244001671195135429e-14,
-2.10154184277266431302e-14,
-4.08355111109219731823e-13,
-7.19855177624590851209e-13,
+2.03562854414708950722e-12,
+1.41258074366137813316e-11,
+3.25260358301548823856e-11,
-1.89749581235054123450e-11,
-5.58974346219658380687e-10,
-3.83538038596423702205e-09,
-2.63146884688951950684e-08,
-2.51223623787020892529e-07,
-3.88256480887769039346e-06,
-1.10588938762623716291e-04,
-9.76109749136146840777e-03,
+7.78576235018280120474e-01,
};
T p;
T q = 0.0;
if (abs(x) <= T(8.0)) {
T a = A[0];
for (uint8_t index = 1; index < 29; index++) {
p = q;
q = a;
a = ((abs(x) / T(2.0)) - T(2.0)) * q - p + A[index];
}
if (x < T(0.0)) {
return -(T(0.5) * (a - p) * abs(x) * exp(abs(x)));
}
return T(0.5) * (a - p) * abs(x) * exp(abs(x));
}
T b = B[0];
for (uint8_t index = 1; index < 25; index++) {
p = q;
q = b;
b = (T(32.0) / abs(x) - T(2.0)) * q - p + B[index];
}
if (x < T(0.0)) {
return -(exp(abs(x)) * (T(0.5) * (b - p)) / sqrt(abs(x)));
}
return exp(abs(x)) * (T(0.5) * (b - p)) / sqrt(abs(x));
} // modified_bessel_i1_forward(T x)
); // modified_bessel_i1_string
const auto modified_bessel_k0_string = modified_bessel_i0_string + jiterator_stringify(
template<typename T>
T modified_bessel_k0_forward(T x) {
static const T A[] = {
+1.37446543561352307156e-16,
+4.25981614279661018399e-14,
+1.03496952576338420167e-11,
+1.90451637722020886025e-09,
+2.53479107902614945675e-07,
+2.28621210311945178607e-05,
+1.26461541144692592338e-03,
+3.59799365153615016266e-02,
+3.44289899924628486886e-01,
-5.35327393233902768720e-01,
};
static const T B[] = {
+5.30043377268626276149e-18,
-1.64758043015242134646e-17,
+5.21039150503902756861e-17,
-1.67823109680541210385e-16,
+5.51205597852431940784e-16,
-1.84859337734377901440e-15,
+6.34007647740507060557e-15,
-2.22751332699166985548e-14,
+8.03289077536357521100e-14,
-2.98009692317273043925e-13,
+1.14034058820847496303e-12,
-4.51459788337394416547e-12,
+1.85594911495471785253e-11,
-7.95748924447710747776e-11,
+3.57739728140030116597e-10,
-1.69753450938905987466e-09,
+8.57403401741422608519e-09,
-4.66048989768794782956e-08,
+2.76681363944501510342e-07,
-1.83175552271911948767e-06,
+1.39498137188764993662e-05,
-1.28495495816278026384e-04,
+1.56988388573005337491e-03,
-3.14481013119645005427e-02,
+2.44030308206595545468e+00,
};
if (x == T(0.0)) {
return INFINITY;
}
if (x < T(0.0)) {
return NAN;
}
T p;
T q = 0.0;
if (x <= T(2.0)) {
T a = A[0];
for (uint8_t index = 1; index < 10; index++) {
p = q;
q = a;
a = (x * x - T(2.0)) * q - p + A[index];
}
return T(0.5) * (a - p) - log(T(0.5) * x) * modified_bessel_i0_forward(x);
}
T b = B[0];
for (uint8_t index = 1; index < 25; index++) {
p = q;
q = b;
b = (T(8.0) / x - T(2.0)) * q - p + B[index];
}
return exp(-x) * (T(0.5) * (b - p)) / sqrt(x);
} // modified_bessel_k0_forward(T x)
); // modified_bessel_k0_string
const auto modified_bessel_k1_string = modified_bessel_i1_string + jiterator_stringify(
template<typename T>
T modified_bessel_k1_forward(T x) {
static const T A[] = {
-7.02386347938628759343e-18,
-2.42744985051936593393e-15,
-6.66690169419932900609e-13,
-1.41148839263352776110e-10,
-2.21338763073472585583e-08,
-2.43340614156596823496e-06,
-1.73028895751305206302e-04,
-6.97572385963986435018e-03,
-1.22611180822657148235e-01,
-3.53155960776544875667e-01,
+1.52530022733894777053e+00,
};
static const T B[] = {
-5.75674448366501715755e-18,
+1.79405087314755922667e-17,
-5.68946255844285935196e-17,
+1.83809354436663880070e-16,
-6.05704724837331885336e-16,
+2.03870316562433424052e-15,
-7.01983709041831346144e-15,
+2.47715442448130437068e-14,
-8.97670518232499435011e-14,
+3.34841966607842919884e-13,
-1.28917396095102890680e-12,
+5.13963967348173025100e-12,
-2.12996783842756842877e-11,
+9.21831518760500529508e-11,
-4.19035475934189648750e-10,
+2.01504975519703286596e-09,
-1.03457624656780970260e-08,
+5.74108412545004946722e-08,
-3.50196060308781257119e-07,
+2.40648494783721712015e-06,
-1.93619797416608296024e-05,
+1.95215518471351631108e-04,
-2.85781685962277938680e-03,
+1.03923736576817238437e-01,
+2.72062619048444266945e+00,
};
if (x == T(0.0)) {
return INFINITY;
}
if (x < T(0.0)) {
return NAN;
}
T p;
T q = 0.0;
if (x <= T(2.0)) {
T a = A[0];
for (uint8_t index = 1; index < 11; index++) {
p = q;
q = a;
a = (x * x - T(2.0)) * q - p + A[index];
}
return log(T(0.5) * x) * modified_bessel_i1_forward(x) + T(0.5) * (a - p) / x;
}
T b = B[0];
for (uint8_t index = 1; index < 25; index++) {
p = q;
q = b;
b = (T(8.0) / x - T(2.0)) * q - p + B[index];
}
return exp(-x) * (T(0.5) * (b - p)) / sqrt(x);
} // modified_bessel_k1_forward(T x)
); // modified_bessel_k1_string
#else // !AT_USE_JITERATOR() -- kernels must be precompiled
template <typename scalar_t>

View File

@ -0,0 +1,43 @@
#define TORCH_ASSERT_NO_OPERATORS
#include <ATen/native/UnaryOps.h>
#include <limits>
#include <ATen/AccumulateType.h>
#include <ATen/Dispatch.h>
#include <ATen/native/DispatchStub.h>
#include <ATen/native/Math.h>
#include <ATen/native/TensorIterator.h>
#include <ATen/native/cuda/JitLoops.cuh>
#include <ATen/native/cuda/Loops.cuh>
#include <ATen/native/cuda/Math.cuh>
#include <ATen/native/cuda/jit_utils.h>
#include <ATen/NumericUtils.h>
#include <c10/core/Scalar.h>
#include <c10/cuda/CUDAMathCompat.h>
#include <c10/util/complex.h>
namespace at {
namespace native {
namespace {
const char bessel_j0_name[] = "bessel_j0_forward";
void bessel_j0_kernel_cuda(TensorIteratorBase& iterator) {
#if AT_USE_JITERATOR()
AT_DISPATCH_FLOATING_TYPES(iterator.common_dtype(), "bessel_j0_cuda", [&]() {
jitted_gpu_kernel<bessel_j0_name, scalar_t, scalar_t, 1>(iterator, bessel_j0_string);
});
#else
AT_DISPATCH_FLOATING_TYPES(iterator.common_dtype(), "bessel_j0_cuda", [&]() {
gpu_kernel(iterator, []GPU_LAMBDA(scalar_t a) -> scalar_t {
return bessel_j0_forward(a);
});
});
#endif // AT_USE_JITERATOR()
}
}
REGISTER_DISPATCH(special_bessel_j0_stub, &bessel_j0_kernel_cuda);
} // namespace native
} // namespace at

View File

@ -0,0 +1,43 @@
#define TORCH_ASSERT_NO_OPERATORS
#include <ATen/native/UnaryOps.h>
#include <limits>
#include <ATen/AccumulateType.h>
#include <ATen/Dispatch.h>
#include <ATen/native/DispatchStub.h>
#include <ATen/native/Math.h>
#include <ATen/native/TensorIterator.h>
#include <ATen/native/cuda/JitLoops.cuh>
#include <ATen/native/cuda/Loops.cuh>
#include <ATen/native/cuda/Math.cuh>
#include <ATen/native/cuda/jit_utils.h>
#include <ATen/NumericUtils.h>
#include <c10/core/Scalar.h>
#include <c10/cuda/CUDAMathCompat.h>
#include <c10/util/complex.h>
namespace at {
namespace native {
namespace {
const char bessel_j1_name[] = "bessel_j1_forward";
void bessel_j1_kernel_cuda(TensorIteratorBase& iterator) {
#if AT_USE_JITERATOR()
AT_DISPATCH_FLOATING_TYPES(iterator.common_dtype(), "bessel_j1_cuda", [&]() {
jitted_gpu_kernel<bessel_j1_name, scalar_t, scalar_t, 1>(iterator, bessel_j1_string);
});
#else
AT_DISPATCH_FLOATING_TYPES(iterator.common_dtype(), "bessel_j1_cuda", [&]() {
gpu_kernel(iterator, []GPU_LAMBDA(scalar_t a) -> scalar_t {
return bessel_j1_forward(a);
});
});
#endif // AT_USE_JITERATOR()
}
}
REGISTER_DISPATCH(special_bessel_j1_stub, &bessel_j1_kernel_cuda);
} // namespace native
} // namespace at

View File

@ -0,0 +1,43 @@
#define TORCH_ASSERT_NO_OPERATORS
#include <ATen/native/UnaryOps.h>
#include <limits>
#include <ATen/AccumulateType.h>
#include <ATen/Dispatch.h>
#include <ATen/native/DispatchStub.h>
#include <ATen/native/Math.h>
#include <ATen/native/TensorIterator.h>
#include <ATen/native/cuda/JitLoops.cuh>
#include <ATen/native/cuda/Loops.cuh>
#include <ATen/native/cuda/Math.cuh>
#include <ATen/native/cuda/jit_utils.h>
#include <ATen/NumericUtils.h>
#include <c10/core/Scalar.h>
#include <c10/cuda/CUDAMathCompat.h>
#include <c10/util/complex.h>
namespace at {
namespace native {
namespace {
const char bessel_y0_name[] = "bessel_y0_forward";
void bessel_y0_kernel_cuda(TensorIteratorBase& iterator) {
#if AT_USE_JITERATOR()
AT_DISPATCH_FLOATING_TYPES(iterator.common_dtype(), "bessel_y0_cuda", [&]() {
jitted_gpu_kernel<bessel_y0_name, scalar_t, scalar_t, 1>(iterator, bessel_y0_string);
});
#else
AT_DISPATCH_FLOATING_TYPES(iterator.common_dtype(), "bessel_y0_cuda", [&]() {
gpu_kernel(iterator, []GPU_LAMBDA(scalar_t a) -> scalar_t {
return bessel_y0_forward(a);
});
});
#endif // AT_USE_JITERATOR()
}
}
REGISTER_DISPATCH(special_bessel_y0_stub, &bessel_y0_kernel_cuda);
} // namespace native
} // namespace at

View File

@ -0,0 +1,43 @@
#define TORCH_ASSERT_NO_OPERATORS
#include <ATen/native/UnaryOps.h>
#include <limits>
#include <ATen/AccumulateType.h>
#include <ATen/Dispatch.h>
#include <ATen/native/DispatchStub.h>
#include <ATen/native/Math.h>
#include <ATen/native/TensorIterator.h>
#include <ATen/native/cuda/JitLoops.cuh>
#include <ATen/native/cuda/Loops.cuh>
#include <ATen/native/cuda/Math.cuh>
#include <ATen/native/cuda/jit_utils.h>
#include <ATen/NumericUtils.h>
#include <c10/core/Scalar.h>
#include <c10/cuda/CUDAMathCompat.h>
#include <c10/util/complex.h>
namespace at {
namespace native {
namespace {
const char bessel_y1_name[] = "bessel_y1_forward";
void bessel_y1_kernel_cuda(TensorIteratorBase& iterator) {
#if AT_USE_JITERATOR()
AT_DISPATCH_FLOATING_TYPES(iterator.common_dtype(), "bessel_y1_cuda", [&]() {
jitted_gpu_kernel<bessel_y1_name, scalar_t, scalar_t, 1>(iterator, bessel_y1_string);
});
#else
AT_DISPATCH_FLOATING_TYPES(iterator.common_dtype(), "bessel_y1_cuda", [&]() {
gpu_kernel(iterator, []GPU_LAMBDA(scalar_t a) -> scalar_t {
return bessel_y1_forward(a);
});
});
#endif // AT_USE_JITERATOR()
}
}
REGISTER_DISPATCH(special_bessel_y1_stub, &bessel_y1_kernel_cuda);
} // namespace native
} // namespace at

View File

@ -0,0 +1,43 @@
#define TORCH_ASSERT_NO_OPERATORS
#include <ATen/native/UnaryOps.h>
#include <limits>
#include <ATen/AccumulateType.h>
#include <ATen/Dispatch.h>
#include <ATen/native/DispatchStub.h>
#include <ATen/native/Math.h>
#include <ATen/native/TensorIterator.h>
#include <ATen/native/cuda/JitLoops.cuh>
#include <ATen/native/cuda/Loops.cuh>
#include <ATen/native/cuda/Math.cuh>
#include <ATen/native/cuda/jit_utils.h>
#include <ATen/NumericUtils.h>
#include <c10/core/Scalar.h>
#include <c10/cuda/CUDAMathCompat.h>
#include <c10/util/complex.h>
namespace at {
namespace native {
namespace {
const char modified_bessel_i0_name[] = "modified_bessel_i0_forward";
void modified_bessel_i0_kernel_cuda(TensorIteratorBase& iterator) {
#if AT_USE_JITERATOR()
AT_DISPATCH_FLOATING_TYPES(iterator.common_dtype(), "modified_bessel_i0_cuda", [&]() {
jitted_gpu_kernel<modified_bessel_i0_name, scalar_t, scalar_t, 1>(iterator, modified_bessel_i0_string);
});
#else
AT_DISPATCH_FLOATING_TYPES(iterator.common_dtype(), "modified_bessel_i0_cuda", [&]() {
gpu_kernel(iterator, []GPU_LAMBDA(scalar_t a) -> scalar_t {
return modified_bessel_i0_forward(a);
});
});
#endif // AT_USE_JITERATOR()
}
}
REGISTER_DISPATCH(special_modified_bessel_i0_stub, &modified_bessel_i0_kernel_cuda);
} // namespace native
} // namespace at

View File

@ -0,0 +1,43 @@
#define TORCH_ASSERT_NO_OPERATORS
#include <ATen/native/UnaryOps.h>
#include <limits>
#include <ATen/AccumulateType.h>
#include <ATen/Dispatch.h>
#include <ATen/native/DispatchStub.h>
#include <ATen/native/Math.h>
#include <ATen/native/TensorIterator.h>
#include <ATen/native/cuda/JitLoops.cuh>
#include <ATen/native/cuda/Loops.cuh>
#include <ATen/native/cuda/Math.cuh>
#include <ATen/native/cuda/jit_utils.h>
#include <ATen/NumericUtils.h>
#include <c10/core/Scalar.h>
#include <c10/cuda/CUDAMathCompat.h>
#include <c10/util/complex.h>
namespace at {
namespace native {
namespace {
const char modified_bessel_i1_name[] = "modified_bessel_i1_forward";
void modified_bessel_i1_kernel_cuda(TensorIteratorBase& iterator) {
#if AT_USE_JITERATOR()
AT_DISPATCH_FLOATING_TYPES(iterator.common_dtype(), "modified_bessel_i1_cuda", [&]() {
jitted_gpu_kernel<modified_bessel_i1_name, scalar_t, scalar_t, 1>(iterator, modified_bessel_i1_string);
});
#else
AT_DISPATCH_FLOATING_TYPES(iterator.common_dtype(), "modified_bessel_i1_cuda", [&]() {
gpu_kernel(iterator, []GPU_LAMBDA(scalar_t a) -> scalar_t {
return modified_bessel_i1_forward(a);
});
});
#endif // AT_USE_JITERATOR()
}
}
REGISTER_DISPATCH(special_modified_bessel_i1_stub, &modified_bessel_i1_kernel_cuda);
} // namespace native
} // namespace at

View File

@ -0,0 +1,43 @@
#define TORCH_ASSERT_NO_OPERATORS
#include <ATen/native/UnaryOps.h>
#include <limits>
#include <ATen/AccumulateType.h>
#include <ATen/Dispatch.h>
#include <ATen/native/DispatchStub.h>
#include <ATen/native/Math.h>
#include <ATen/native/TensorIterator.h>
#include <ATen/native/cuda/JitLoops.cuh>
#include <ATen/native/cuda/Loops.cuh>
#include <ATen/native/cuda/Math.cuh>
#include <ATen/native/cuda/jit_utils.h>
#include <ATen/NumericUtils.h>
#include <c10/core/Scalar.h>
#include <c10/cuda/CUDAMathCompat.h>
#include <c10/util/complex.h>
namespace at {
namespace native {
namespace {
const char modified_bessel_k0_name[] = "modified_bessel_k0_forward";
void modified_bessel_k0_kernel_cuda(TensorIteratorBase& iterator) {
#if AT_USE_JITERATOR()
AT_DISPATCH_FLOATING_TYPES(iterator.common_dtype(), "modified_bessel_k0_cuda", [&]() {
jitted_gpu_kernel<modified_bessel_k0_name, scalar_t, scalar_t, 1>(iterator, modified_bessel_k0_string);
});
#else
AT_DISPATCH_FLOATING_TYPES(iterator.common_dtype(), "modified_bessel_k0_cuda", [&]() {
gpu_kernel(iterator, []GPU_LAMBDA(scalar_t a) -> scalar_t {
return modified_bessel_k0_forward(a);
});
});
#endif // AT_USE_JITERATOR()
}
}
REGISTER_DISPATCH(special_modified_bessel_k0_stub, &modified_bessel_k0_kernel_cuda);
} // namespace native
} // namespace at

View File

@ -0,0 +1,43 @@
#define TORCH_ASSERT_NO_OPERATORS
#include <ATen/native/UnaryOps.h>
#include <limits>
#include <ATen/AccumulateType.h>
#include <ATen/Dispatch.h>
#include <ATen/native/DispatchStub.h>
#include <ATen/native/Math.h>
#include <ATen/native/TensorIterator.h>
#include <ATen/native/cuda/JitLoops.cuh>
#include <ATen/native/cuda/Loops.cuh>
#include <ATen/native/cuda/Math.cuh>
#include <ATen/native/cuda/jit_utils.h>
#include <ATen/NumericUtils.h>
#include <c10/core/Scalar.h>
#include <c10/cuda/CUDAMathCompat.h>
#include <c10/util/complex.h>
namespace at {
namespace native {
namespace {
const char modified_bessel_k1_name[] = "modified_bessel_k1_forward";
void modified_bessel_k1_kernel_cuda(TensorIteratorBase& iterator) {
#if AT_USE_JITERATOR()
AT_DISPATCH_FLOATING_TYPES(iterator.common_dtype(), "modified_bessel_k1_cuda", [&]() {
jitted_gpu_kernel<modified_bessel_k1_name, scalar_t, scalar_t, 1>(iterator, modified_bessel_k1_string);
});
#else
AT_DISPATCH_FLOATING_TYPES(iterator.common_dtype(), "modified_bessel_k1_cuda", [&]() {
gpu_kernel(iterator, []GPU_LAMBDA(scalar_t a) -> scalar_t {
return modified_bessel_k1_forward(a);
});
});
#endif // AT_USE_JITERATOR()
}
} // anonymous namespace
REGISTER_DISPATCH(special_modified_bessel_k1_stub, &modified_bessel_k1_kernel_cuda);
} // namespace native
} // namespace at

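The three CUDA kernels above share one pattern: when the jiterator is enabled, the kernel body is JIT-compiled from the corresponding `*_string` source; otherwise a precompiled `gpu_kernel` lambda calls the shared `*_forward` math routine, and either path is registered through the `special_modified_bessel_*_stub` dispatch. A minimal sketch of exercising these stubs from Python, assuming a CUDA build is available (the device check is illustrative, not part of the PR):

```Python
import torch

# Minimal sketch (assumes a CUDA build): the dispatch stubs registered above
# back the torch.special.* entry points when the input lives on a GPU.
if torch.cuda.is_available():
    x = torch.linspace(0.5, 8.0, steps=6, device="cuda", dtype=torch.float64)
    print(torch.special.modified_bessel_i1(x))
    print(torch.special.modified_bessel_k0(x))
    print(torch.special.modified_bessel_k1(x))
```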
View File

@ -12404,6 +12404,58 @@
dispatch:
CPU, CUDA, NestedTensorCPU, NestedTensorCUDA: native_multi_head_attention
- func: special_bessel_j0(Tensor self) -> Tensor
python_module: special
structured_delegate: special_bessel_j0.out
variants: function
- func: special_bessel_j0.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!)
dispatch:
CPU, CUDA: special_bessel_j0_out
python_module: special
structured_inherits: TensorIteratorBase
structured: True
variants: function
- func: special_bessel_j1(Tensor self) -> Tensor
python_module: special
structured_delegate: special_bessel_j1.out
variants: function
- func: special_bessel_j1.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!)
dispatch:
CPU, CUDA: special_bessel_j1_out
python_module: special
structured_inherits: TensorIteratorBase
structured: True
variants: function
- func: special_bessel_y0(Tensor self) -> Tensor
python_module: special
structured_delegate: special_bessel_y0.out
variants: function
- func: special_bessel_y0.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!)
dispatch:
CPU, CUDA: special_bessel_y0_out
python_module: special
structured_inherits: TensorIteratorBase
structured: True
variants: function
- func: special_bessel_y1(Tensor self) -> Tensor
python_module: special
structured_delegate: special_bessel_y1.out
variants: function
- func: special_bessel_y1.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!)
dispatch:
CPU, CUDA: special_bessel_y1_out
python_module: special
structured_inherits: TensorIteratorBase
structured: True
variants: function
- func: special_chebyshev_polynomial_t(Tensor x, Tensor n) -> Tensor
device_check: NoCheck
python_module: special
@ -12588,3 +12640,55 @@
device_check: NoCheck
python_module: special
variants: function
- func: special_modified_bessel_i0(Tensor self) -> Tensor
python_module: special
structured_delegate: special_modified_bessel_i0.out
variants: function
- func: special_modified_bessel_i0.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!)
dispatch:
CPU, CUDA: special_modified_bessel_i0_out
python_module: special
structured_inherits: TensorIteratorBase
structured: True
variants: function
- func: special_modified_bessel_i1(Tensor self) -> Tensor
python_module: special
structured_delegate: special_modified_bessel_i1.out
variants: function
- func: special_modified_bessel_i1.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!)
dispatch:
CPU, CUDA: special_modified_bessel_i1_out
python_module: special
structured_inherits: TensorIteratorBase
structured: True
variants: function
- func: special_modified_bessel_k0(Tensor self) -> Tensor
python_module: special
structured_delegate: special_modified_bessel_k0.out
variants: function
- func: special_modified_bessel_k0.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!)
dispatch:
CPU, CUDA: special_modified_bessel_k0_out
python_module: special
structured_inherits: TensorIteratorBase
structured: True
variants: function
- func: special_modified_bessel_k1(Tensor self) -> Tensor
python_module: special
structured_delegate: special_modified_bessel_k1.out
variants: function
- func: special_modified_bessel_k1.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!)
dispatch:
CPU, CUDA: special_modified_bessel_k1_out
python_module: special
structured_inherits: TensorIteratorBase
structured: True
variants: function

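Each schema above generates a functional variant plus a structured `out=` overload dispatched to CPU and CUDA. A hedged usage sketch of both forms from Python (the input values are arbitrary):

```Python
import torch

# Sketch of the functional and out= variants produced by the schemas above.
x = torch.linspace(0.5, 4.0, steps=8, dtype=torch.float64)
y = torch.special.bessel_j0(x)        # functional variant (special_bessel_j0)
out = torch.empty_like(x)
torch.special.bessel_y1(x, out=out)   # out= overload (special_bessel_y1.out)
```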
View File

@ -2955,6 +2955,10 @@
"softmax"
],
"torch.special": [
"bessel_j0",
"bessel_j1",
"bessel_y0",
"bessel_y1",
"chebyshev_polynomial_t",
"chebyshev_polynomial_u",
"digamma",
@ -2981,6 +2985,10 @@
"log_softmax",
"logit",
"logsumexp",
"modified_bessel_i0",
"modified_bessel_i1",
"modified_bessel_k0",
"modified_bessel_k1",
"multigammaln",
"ndtr",
"ndtri",

View File

@ -2757,6 +2757,18 @@
self, src: scatter_reduce_backward(grad, self, dim, index, src, reduce, include_self, result)
index: non_differentiable
- name: special_bessel_j0(Tensor self) -> Tensor
self: non_differentiable
- name: special_bessel_j1(Tensor self) -> Tensor
self: non_differentiable
- name: special_bessel_y0(Tensor self) -> Tensor
self: non_differentiable
- name: special_bessel_y1(Tensor self) -> Tensor
self: non_differentiable
- name: special_chebyshev_polynomial_t(Tensor x, Tensor n) -> Tensor
x: non_differentiable
n: non_differentiable
@ -2806,3 +2818,15 @@
- name: special_laguerre_polynomial_l.n_scalar(Tensor x, Scalar n) -> Tensor
x: non_differentiable
- name: special_modified_bessel_i0(Tensor self) -> Tensor
self: non_differentiable
- name: special_modified_bessel_i1(Tensor self) -> Tensor
self: non_differentiable
- name: special_modified_bessel_k0(Tensor self) -> Tensor
self: non_differentiable
- name: special_modified_bessel_k1(Tensor self) -> Tensor
self: non_differentiable

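All eight operators are declared `non_differentiable`, so autograd does not build a graph through them. A sketch of the expected behavior (the exact semantics are an assumption of this note, not asserted by the PR itself):

```Python
import torch

# Sketch: since the entries above mark the inputs as non_differentiable,
# the result of these ops is expected not to track gradients.
x = torch.linspace(0.1, 5.0, steps=4, dtype=torch.float64, requires_grad=True)
y = torch.special.bessel_j0(x)
print(y.requires_grad)  # expected: False
```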
View File

@ -565,6 +565,82 @@ inline Tensor softmax(const Tensor& self, int64_t dim, c10::optional<ScalarType>
return torch::special_softmax(self, dim, dtype);
}
/// Bessel function of the first kind of order 0.
///
/// See https://pytorch.org/docs/master/special.html#torch.special.bessel_j0.
///
/// Example:
///
/// ```
/// auto x = torch::randn(128, torch::kDouble);
///
/// torch::special::bessel_j0(x);
/// ```
inline Tensor bessel_j0(const Tensor& self) {
return torch::special_bessel_j0(self);
}
inline Tensor& bessel_j0_out(Tensor& result, const Tensor& self) {
return torch::special_bessel_j0_out(result, self);
}
/// Bessel function of the first kind of order 1.
///
/// See https://pytorch.org/docs/master/special.html#torch.special.bessel_j1.
///
/// Example:
///
/// ```
/// auto x = torch::randn(128, torch::kDouble);
///
/// torch::special::bessel_j1(x);
/// ```
inline Tensor bessel_j1(const Tensor& self) {
return torch::special_bessel_j1(self);
}
inline Tensor& bessel_j1_out(Tensor& result, const Tensor& self) {
return torch::special_bessel_j1_out(result, self);
}
/// Bessel function of the second kind of order 0.
///
/// See https://pytorch.org/docs/master/special.html#torch.special.bessel_y0.
///
/// Example:
///
/// ```
/// auto x = torch::randn(128, torch::kDouble);
///
/// torch::special::bessel_y0(x);
/// ```
inline Tensor bessel_y0(const Tensor& self) {
return torch::special_bessel_y0(self);
}
inline Tensor& bessel_y0_out(Tensor& result, const Tensor& self) {
return torch::special_bessel_y0_out(result, self);
}
/// Bessel function of the second kind of order 1.
///
/// See https://pytorch.org/docs/master/special.html#torch.special.bessel_y1.
///
/// Example:
///
/// ```
/// auto x = torch::randn(128, torch::kDouble);
///
/// torch::special::bessel_y1(x);
/// ```
inline Tensor bessel_y1(const Tensor& self) {
return torch::special_bessel_y1(self);
}
inline Tensor& bessel_y1_out(Tensor& result, const Tensor& self) {
return torch::special_bessel_y1_out(result, self);
}
/// Chebyshev polynomial of the first kind.
///
/// See https://pytorch.org/docs/master/special.html#torch.special.chebyshev_polynomial_t.
@ -745,4 +821,80 @@ inline Tensor& laguerre_polynomial_l_out(Tensor& output, const Tensor& x, const
return torch::special_laguerre_polynomial_l_out(output, x, n);
}
/// Modified Bessel function of the first kind of order 0.
///
/// See https://pytorch.org/docs/master/special.html#torch.special.modified_bessel_i0.
///
/// Example:
///
/// ```
/// auto x = torch::randn(128, torch::kDouble);
///
/// torch::special::modified_bessel_i0(x);
/// ```
inline Tensor modified_bessel_i0(const Tensor& self) {
return torch::special_modified_bessel_i0(self);
}
inline Tensor& modified_bessel_i0_out(Tensor& result, const Tensor& self) {
return torch::special_modified_bessel_i0_out(result, self);
}
/// Modified Bessel function of the first kind of order 1.
///
/// See https://pytorch.org/docs/master/special.html#torch.special.modified_bessel_i1.
///
/// Example:
///
/// ```
/// auto x = torch::randn(128, torch::kDouble);
///
/// torch::special::modified_bessel_i1(x);
/// ```
inline Tensor modified_bessel_i1(const Tensor& self) {
return torch::special_modified_bessel_i1(self);
}
inline Tensor& modified_bessel_i1_out(Tensor& result, const Tensor& self) {
return torch::special_modified_bessel_i1_out(result, self);
}
/// Modified Bessel function of the second kind of order 0.
///
/// See https://pytorch.org/docs/master/special.html#torch.special.modified_bessel_k0.
///
/// Example:
///
/// ```
/// auto x = torch::randn(128, torch::kDouble);
///
/// torch::special::modified_bessel_k0(x);
/// ```
inline Tensor modified_bessel_k0(const Tensor& self) {
return torch::special_modified_bessel_k0(self);
}
inline Tensor& modified_bessel_k0_out(Tensor& result, const Tensor& self) {
return torch::special_modified_bessel_k0_out(result, self);
}
/// Modified Bessel function of the second kind of order 1.
///
/// See https://pytorch.org/docs/master/special.html#torch.special.modified_bessel_k1.
///
/// Example:
///
/// ```
/// auto x = torch::randn(128, torch::kDouble);
///
/// torch::special::modified_bessel_k1(x);
/// ```
inline Tensor modified_bessel_k1(const Tensor& self) {
return torch::special_modified_bessel_k1(self);
}
inline Tensor& modified_bessel_k1_out(Tensor& result, const Tensor& self) {
return torch::special_modified_bessel_k1_out(result, self);
}
}} // torch::special

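Beyond the per-function examples in the comments above, the new entry points compose in the usual ways; for instance, the classical identity $J_{0}'(x) = -J_{1}(x)$ can be checked numerically. A sketch using the Python bindings (the finite-difference step `h` is an illustrative choice):

```Python
import torch

# Numerical sanity sketch of the identity d/dx J0(x) = -J1(x),
# using the new torch.special entry points (h is an illustrative step size).
x = torch.linspace(0.5, 6.0, steps=8, dtype=torch.float64)
h = 1e-6
dj0 = (torch.special.bessel_j0(x + h) - torch.special.bessel_j0(x - h)) / (2 * h)
print(torch.allclose(dj0, -torch.special.bessel_j1(x), atol=1e-6))  # expected: True
```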
View File

@ -986,6 +986,10 @@ def get_testing_overrides() -> Dict[Callable, Callable]:
torch.symeig: lambda input, eigenvectors=False, upper=True, out=None: -1,
torch.swapaxes: lambda input, dim0, dim1: -1,
torch.swapdims: lambda input, axis0, axis1: -1,
torch.special.bessel_j0: lambda input: -1,
torch.special.bessel_j1: lambda input: -1,
torch.special.bessel_y0: lambda input: -1,
torch.special.bessel_y1: lambda input: -1,
torch.special.chebyshev_polynomial_t: lambda input, n, out=None: -1,
torch.special.chebyshev_polynomial_u: lambda input, n, out=None: -1,
torch.special.digamma: lambda input: -1,
@ -1012,6 +1016,10 @@ def get_testing_overrides() -> Dict[Callable, Callable]:
torch.special.log_softmax: lambda input, dim, dtype=None: -1,
torch.special.logit: lambda input: -1,
torch.special.logsumexp: lambda input, dim, keepdim=False, out=None: -1,
torch.special.modified_bessel_i0: lambda input: -1,
torch.special.modified_bessel_i1: lambda input: -1,
torch.special.modified_bessel_k0: lambda input: -1,
torch.special.modified_bessel_k1: lambda input: -1,
torch.special.multigammaln: lambda input, p: -1,
torch.special.ndtr: lambda input: -1,
torch.special.ndtri: lambda input: -1,

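The lambdas above register the new operators with the `__torch_function__` override machinery; each lambda only encodes the expected signature for override testing. A minimal sketch of what this enables, where `LoggingTensor` is a hypothetical subclass used purely for illustration:

```Python
import torch

# Hypothetical Tensor subclass intercepting one of the new ops via
# __torch_function__, which the override table above makes testable.
class LoggingTensor(torch.Tensor):
    @classmethod
    def __torch_function__(cls, func, types, args=(), kwargs=None):
        kwargs = kwargs or {}
        if func is torch.special.bessel_j0:
            print("intercepted torch.special.bessel_j0")
        return super().__torch_function__(func, types, args, kwargs)

x = torch.linspace(0.0, 3.0, steps=4, dtype=torch.float64).as_subclass(LoggingTensor)
torch.special.bessel_j0(x)
```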
View File

@ -3,6 +3,10 @@ from torch._C import _add_docstr, _special # type: ignore[attr-defined]
from torch._torch_docs import common_args, multi_dim_common
__all__ = [
'bessel_j0',
'bessel_j1',
'bessel_y0',
'bessel_y1',
'chebyshev_polynomial_t',
'chebyshev_polynomial_u',
'digamma',
@ -24,11 +28,15 @@ __all__ = [
'i1',
'i1e',
'laguerre_polynomial_l',
'log1p',
'log_ndtr',
'log_softmax',
'logit',
'logsumexp',
'modified_bessel_i0',
'modified_bessel_i1',
'modified_bessel_k0',
'modified_bessel_k1',
'multigammaln',
'ndtr',
'ndtri',
@ -856,6 +864,62 @@ Example::
""".format(**common_args))
bessel_j0 = _add_docstr(_special.special_bessel_j0,
r"""
bessel_j0(input, *, out=None) -> Tensor
Bessel function of the first kind of order :math:`0`.
""" + r"""
Args:
{input}
Keyword args:
{out}
""".format(**common_args))
bessel_j1 = _add_docstr(_special.special_bessel_j1,
r"""
bessel_j1(input, *, out=None) -> Tensor
Bessel function of the first kind of order :math:`1`.
""" + r"""
Args:
{input}
Keyword args:
{out}
""".format(**common_args))
bessel_y0 = _add_docstr(_special.special_bessel_y0,
r"""
bessel_y0(input, *, out=None) -> Tensor
Bessel function of the second kind of order :math:`0`.
""" + r"""
Args:
{input}
Keyword args:
{out}
""".format(**common_args))
bessel_y1 = _add_docstr(_special.special_bessel_y1,
r"""
bessel_y1(input, *, out=None) -> Tensor
Bessel function of the second kind of order :math:`1`.
""" + r"""
Args:
{input}
Keyword args:
{out}
""".format(**common_args))
chebyshev_polynomial_t = _add_docstr(_special.special_chebyshev_polynomial_t,
r"""
chebyshev_polynomial_t(input, n, *, out=None) -> Tensor
@ -976,3 +1040,59 @@ Args:
Keyword args:
{out}
""".format(**common_args))
modified_bessel_i0 = _add_docstr(_special.special_modified_bessel_i0,
r"""
modified_bessel_i0(input, *, out=None) -> Tensor
Modified Bessel function of the first kind of order :math:`0`.
""" + r"""
Args:
{input}
Keyword args:
{out}
""".format(**common_args))
modified_bessel_i1 = _add_docstr(_special.special_modified_bessel_i1,
r"""
modified_bessel_i1(input, *, out=None) -> Tensor
Modified Bessel function of the first kind of order :math:`1`.
""" + r"""
Args:
{input}
Keyword args:
{out}
""".format(**common_args))
modified_bessel_k0 = _add_docstr(_special.special_modified_bessel_k0,
r"""
modified_bessel_k0(input, *, out=None) -> Tensor
Modified Bessel function of the second kind of order :math:`0`.
""" + r"""
Args:
{input}
Keyword args:
{out}
""".format(**common_args))
modified_bessel_k1 = _add_docstr(_special.special_modified_bessel_k1,
r"""
modified_bessel_k1(input, *, out=None) -> Tensor
Modified Bessel function of the second kind of order :math:`1`.
""" + r"""
Args:
{input}
Keyword args:
{out}
""".format(**common_args))

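As a quick numerical spot check for the documented functions, the Wronskian identity $I_{0}(x) K_{1}(x) + I_{1}(x) K_{0}(x) = 1/x$ (valid for $x > 0$) ties the four modified Bessel functions together; a sketch:

```Python
import torch

# Spot-check sketch: Wronskian identity I0(x)*K1(x) + I1(x)*K0(x) = 1/x, x > 0.
x = torch.linspace(0.5, 10.0, steps=8, dtype=torch.float64)
lhs = (torch.special.modified_bessel_i0(x) * torch.special.modified_bessel_k1(x)
       + torch.special.modified_bessel_i1(x) * torch.special.modified_bessel_k0(x))
print(torch.allclose(lhs, 1.0 / x))  # expected: True
```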
View File

@ -19084,6 +19084,62 @@ op_db: List[OpInfo] = [
dtypesIfCUDA=floating_types_and(torch.float16, torch.bfloat16),
sample_inputs_func=sample_inputs_scatter_reduce,
),
UnaryUfuncInfo(
'special.bessel_j0',
decorators=(
precisionOverride(
{
torch.float32: 1e-04,
torch.float64: 1e-05,
},
),
),
dtypes=all_types_and(torch.bool),
ref=scipy.special.j0 if TEST_SCIPY else _NOTHING,
supports_autograd=False,
),
UnaryUfuncInfo(
'special.bessel_j1',
decorators=(
precisionOverride(
{
torch.float32: 1e-04,
torch.float64: 1e-05,
},
),
),
dtypes=all_types_and(torch.bool),
ref=scipy.special.j1 if TEST_SCIPY else _NOTHING,
supports_autograd=False,
),
UnaryUfuncInfo(
'special.bessel_y0',
decorators=(
precisionOverride(
{
torch.float32: 1e-04,
torch.float64: 1e-05,
},
),
),
dtypes=all_types_and(torch.bool),
ref=scipy.special.y0 if TEST_SCIPY else _NOTHING,
supports_autograd=False,
),
UnaryUfuncInfo(
'special.bessel_y1',
decorators=(
precisionOverride(
{
torch.float32: 1e-04,
torch.float64: 1e-05,
},
),
),
dtypes=all_types_and(torch.bool),
ref=scipy.special.y1 if TEST_SCIPY else _NOTHING,
supports_autograd=False,
),
BinaryUfuncInfo(
'special.chebyshev_polynomial_t',
dtypes=all_types_and(torch.bool),
@ -19139,6 +19195,62 @@ op_db: List[OpInfo] = [
supports_one_python_scalar=True,
supports_autograd=False,
),
UnaryUfuncInfo(
'special.modified_bessel_i0',
decorators=(
precisionOverride(
{
torch.float32: 1e-03,
torch.float64: 1e-05,
},
),
),
dtypes=all_types_and(torch.bool),
ref=scipy.special.i0 if TEST_SCIPY else _NOTHING,
supports_autograd=False,
),
UnaryUfuncInfo(
'special.modified_bessel_i1',
decorators=(
precisionOverride(
{
torch.float32: 1e-03,
torch.float64: 1e-05,
},
),
),
dtypes=all_types_and(torch.bool),
ref=scipy.special.i1 if TEST_SCIPY else _NOTHING,
supports_autograd=False,
),
UnaryUfuncInfo(
'special.modified_bessel_k0',
decorators=(
precisionOverride(
{
torch.float32: 1e-03,
torch.float64: 1e-05,
},
),
),
dtypes=all_types_and(torch.bool),
ref=scipy.special.k0 if TEST_SCIPY else _NOTHING,
supports_autograd=False,
),
UnaryUfuncInfo(
'special.modified_bessel_k1',
decorators=(
precisionOverride(
{
torch.float32: 1e-03,
torch.float64: 1e-05,
},
),
),
dtypes=all_types_and(torch.bool),
ref=scipy.special.k1 if TEST_SCIPY else _NOTHING,
supports_autograd=False,
),
]
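Each `UnaryUfuncInfo` above names the corresponding SciPy routine as its reference and loosens tolerances through `precisionOverride`. A sketch of the comparison the test suite performs, assuming SciPy is installed (the tolerance mirrors the float32 override of `1e-04`):

```Python
import torch
from scipy import special as sp  # assumed available, as in the OpInfo refs above

# Sketch of the reference comparison encoded by the OpInfo entries,
# using the float32 precision override (1e-4) as the tolerance.
x = torch.linspace(0.1, 10.0, steps=16, dtype=torch.float32)
torch.testing.assert_close(
    torch.special.bessel_j0(x),
    torch.from_numpy(sp.j0(x.numpy())).to(torch.float32),
    rtol=1e-4,
    atol=1e-4,
)
```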
# NOTE [Python References]