mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
Bessel functions (#78451)
Adds: ```Python bessel_j0(input, *, out=None) -> Tensor ``` Bessel function of the first kind of order $0$, $J_{0}(\text{input})$. ```Python bessel_j1(input, *, out=None) -> Tensor ``` Bessel function of the first kind of order $1$, $J_{1}(\text{input})$. ```Python bessel_j0(input, *, out=None) -> Tensor ``` Bessel function of the second kind of order $0$, $Y_{0}(\text{input})$. ```Python bessel_j1(input, *, out=None) -> Tensor ``` Bessel function of the second kind of order $1$, $Y_{1}(\text{input})$. ```Python modified_bessel_i0(input, *, out=None) -> Tensor ``` Modified Bessel function of the first kind of order $0$, $I_{0}(\text{input})$. ```Python modified_bessel_i1(input, *, out=None) -> Tensor ``` Modified Bessel function of the first kind of order $1$, $I_{1}(\text{input})$. ```Python modified_bessel_k0(input, *, out=None) -> Tensor ``` Modified Bessel function of the second kind of order $0$, $K_{0}(\text{input})$. ```Python modified_bessel_k1(input, *, out=None) -> Tensor ``` Modified Bessel function of the second kind of order $1$, $K_{1}(\text{input})$. Pull Request resolved: https://github.com/pytorch/pytorch/pull/78451 Approved by: https://github.com/mruberry
This commit is contained in:
committed by
PyTorch MergeBot
parent
78824a7d54
commit
4a5381ab40
@ -2177,6 +2177,455 @@ static inline C10_HOST_DEVICE T calc_log_ndtr(T x) {
|
||||
}
|
||||
}
|
||||
|
||||
template<typename T>
|
||||
static inline C10_HOST_DEVICE T bessel_j0_forward(T x) {
|
||||
static const T PP[] = {
|
||||
+7.96936729297347051624e-04,
|
||||
+8.28352392107440799803e-02,
|
||||
+1.23953371646414299388e+00,
|
||||
+5.44725003058768775090e+00,
|
||||
+8.74716500199817011941e+00,
|
||||
+5.30324038235394892183e+00,
|
||||
+9.99999999999999997821e-01,
|
||||
};
|
||||
|
||||
static const T PQ[] = {
|
||||
+9.24408810558863637013e-04,
|
||||
+8.56288474354474431428e-02,
|
||||
+1.25352743901058953537e+00,
|
||||
+5.47097740330417105182e+00,
|
||||
+8.76190883237069594232e+00,
|
||||
+5.30605288235394617618e+00,
|
||||
+1.00000000000000000218e+00,
|
||||
};
|
||||
|
||||
static const T QP[] = {
|
||||
-1.13663838898469149931e-02,
|
||||
-1.28252718670509318512e+00,
|
||||
-1.95539544257735972385e+01,
|
||||
-9.32060152123768231369e+01,
|
||||
-1.77681167980488050595e+02,
|
||||
-1.47077505154951170175e+02,
|
||||
-5.14105326766599330220e+01,
|
||||
-6.05014350600728481186e+00,
|
||||
};
|
||||
|
||||
static const T QQ[] = {
|
||||
+6.43178256118178023184e+01,
|
||||
+8.56430025976980587198e+02,
|
||||
+3.88240183605401609683e+03,
|
||||
+7.24046774195652478189e+03,
|
||||
+5.93072701187316984827e+03,
|
||||
+2.06209331660327847417e+03,
|
||||
+2.42005740240291393179e+02,
|
||||
};
|
||||
|
||||
static const T RP[] = {
|
||||
-4.79443220978201773821e+09,
|
||||
+1.95617491946556577543e+12,
|
||||
-2.49248344360967716204e+14,
|
||||
+9.70862251047306323952e+15,
|
||||
};
|
||||
|
||||
static const T RQ[] = {
|
||||
+4.99563147152651017219e+02,
|
||||
+1.73785401676374683123e+05,
|
||||
+4.84409658339962045305e+07,
|
||||
+1.11855537045356834862e+10,
|
||||
+2.11277520115489217587e+12,
|
||||
+3.10518229857422583814e+14,
|
||||
+3.18121955943204943306e+16,
|
||||
+1.71086294081043136091e+18,
|
||||
};
|
||||
|
||||
if (x < T(0)) {
|
||||
x = -x;
|
||||
}
|
||||
|
||||
if (x <= T(5.0)) {
|
||||
if (x < T(0.00001)) {
|
||||
return T(1.0) - x * x / T(4.0);
|
||||
}
|
||||
|
||||
T rp = 0.0;
|
||||
|
||||
for (uint8_t index = 0; index <= 3; index++) {
|
||||
rp = rp * (x * x) + RP[index];
|
||||
}
|
||||
|
||||
T rq = 0.0;
|
||||
|
||||
for (uint8_t index = 0; index <= 7; index++) {
|
||||
rq = rq * (x * x) + RQ[index];
|
||||
}
|
||||
|
||||
return (x * x - T(5.78318596294678452118e+00)) * (x * x - T(3.04712623436620863991e+01)) * rp / rq;
|
||||
}
|
||||
|
||||
T pp = 0.0;
|
||||
|
||||
for (uint8_t index = 0; index <= 6; index++) {
|
||||
pp = pp * (T(25.0) / (x * x)) + PP[index];
|
||||
}
|
||||
|
||||
T pq = 0.0;
|
||||
|
||||
for (uint8_t index = 0; index <= 6; index++) {
|
||||
pq = pq * (T(25.0) / (x * x)) + PQ[index];
|
||||
}
|
||||
|
||||
T qp = 0.0;
|
||||
|
||||
for (uint8_t index = 0; index <= 7; index++) {
|
||||
qp = qp * (T(25.0) / (x * x)) + QP[index];
|
||||
}
|
||||
|
||||
T qq = 0.0;
|
||||
|
||||
for (uint8_t index = 0; index <= 6; index++) {
|
||||
qq = qq * (T(25.0) / (x * x)) + QQ[index];
|
||||
}
|
||||
|
||||
return (pp / pq * std::cos(x - T(0.785398163397448309615660845819875721)) - T(5.0) / x * (qp / qq) * std::sin(x - T(0.785398163397448309615660845819875721))) * T(0.797884560802865355879892119868763737) / std::sqrt(x);
|
||||
} // bessel_j0_forward(T x)
|
||||
|
||||
template<typename T>
|
||||
static inline C10_HOST_DEVICE T bessel_j1_forward(T x) {
|
||||
static const T PP[] = {
|
||||
+7.62125616208173112003e-04,
|
||||
+7.31397056940917570436e-02,
|
||||
+1.12719608129684925192e+00,
|
||||
+5.11207951146807644818e+00,
|
||||
+8.42404590141772420927e+00,
|
||||
+5.21451598682361504063e+00,
|
||||
+1.00000000000000000254e+00,
|
||||
};
|
||||
|
||||
static const T PQ[] = {
|
||||
+5.71323128072548699714e-04,
|
||||
+6.88455908754495404082e-02,
|
||||
+1.10514232634061696926e+00,
|
||||
+5.07386386128601488557e+00,
|
||||
+8.39985554327604159757e+00,
|
||||
+5.20982848682361821619e+00,
|
||||
+9.99999999999999997461e-01,
|
||||
};
|
||||
|
||||
static const T QP[] = {
|
||||
+5.10862594750176621635e-02,
|
||||
+4.98213872951233449420e+00,
|
||||
+7.58238284132545283818e+01,
|
||||
+3.66779609360150777800e+02,
|
||||
+7.10856304998926107277e+02,
|
||||
+5.97489612400613639965e+02,
|
||||
+2.11688757100572135698e+02,
|
||||
+2.52070205858023719784e+01,
|
||||
};
|
||||
|
||||
static const T QQ[] = {
|
||||
+7.42373277035675149943e+01,
|
||||
+1.05644886038262816351e+03,
|
||||
+4.98641058337653607651e+03,
|
||||
+9.56231892404756170795e+03,
|
||||
+7.99704160447350683650e+03,
|
||||
+2.82619278517639096600e+03,
|
||||
+3.36093607810698293419e+02,
|
||||
};
|
||||
|
||||
static const T RP[] = {
|
||||
-8.99971225705559398224e+08,
|
||||
+4.52228297998194034323e+11,
|
||||
-7.27494245221818276015e+13,
|
||||
+3.68295732863852883286e+15,
|
||||
};
|
||||
|
||||
static const T RQ[] = {
|
||||
+6.20836478118054335476e+02,
|
||||
+2.56987256757748830383e+05,
|
||||
+8.35146791431949253037e+07,
|
||||
+2.21511595479792499675e+10,
|
||||
+4.74914122079991414898e+12,
|
||||
+7.84369607876235854894e+14,
|
||||
+8.95222336184627338078e+16,
|
||||
+5.32278620332680085395e+18,
|
||||
};
|
||||
|
||||
if (x < T(0.0)) {
|
||||
return -bessel_j1_forward(-x);
|
||||
}
|
||||
|
||||
if (x <= T(5.0)) {
|
||||
T rp = 0.0;
|
||||
|
||||
for (uint8_t index = 0; index <= 3; index++) {
|
||||
rp = rp * (x * x) + RP[index];
|
||||
}
|
||||
|
||||
T rq = 0.0;
|
||||
|
||||
for (uint8_t index = 0; index <= 7; index++) {
|
||||
rq = rq * (x * x) + RQ[index];
|
||||
}
|
||||
|
||||
return rp / rq * x * (x * x - T(1.46819706421238932572e+01)) * (x * x - T(4.92184563216946036703e+01));
|
||||
}
|
||||
|
||||
T pp = 0.0;
|
||||
|
||||
for (uint8_t index = 0; index <= 6; index++) {
|
||||
pp = pp * (T(5.0) / x * (T(5.0) / x)) + PP[index];
|
||||
}
|
||||
|
||||
T pq = 0.0;
|
||||
|
||||
for (uint8_t index = 0; index <= 6; index++) {
|
||||
pq = pq * (T(5.0) / x * (T(5.0) / x)) + PQ[index];
|
||||
}
|
||||
|
||||
T qp = 0.0;
|
||||
|
||||
for (uint8_t index = 0; index <= 7; index++) {
|
||||
qp = qp * (T(5.0) / x * (T(5.0) / x)) + QP[index];
|
||||
}
|
||||
|
||||
T qq = 0.0;
|
||||
|
||||
for (uint8_t index = 0; index <= 6; index++) {
|
||||
qq = qq * (T(5.0) / x * (T(5.0) / x)) + QQ[index];
|
||||
}
|
||||
|
||||
return (pp / pq * std::cos(x - T(2.356194490192344928846982537459627163)) - T(5.0) / x * (qp / qq) * std::sin(x - T(2.356194490192344928846982537459627163))) * T(0.797884560802865355879892119868763737) / std::sqrt(x);
|
||||
} // bessel_j1_forward(T x)
|
||||
|
||||
template<typename T>
|
||||
static inline C10_HOST_DEVICE T bessel_y0_forward(T x) {
|
||||
static const T PP[] = {
|
||||
+7.96936729297347051624e-04,
|
||||
+8.28352392107440799803e-02,
|
||||
+1.23953371646414299388e+00,
|
||||
+5.44725003058768775090e+00,
|
||||
+8.74716500199817011941e+00,
|
||||
+5.30324038235394892183e+00,
|
||||
+9.99999999999999997821e-01,
|
||||
};
|
||||
|
||||
static const T PQ[] = {
|
||||
+9.24408810558863637013e-04,
|
||||
+8.56288474354474431428e-02,
|
||||
+1.25352743901058953537e+00,
|
||||
+5.47097740330417105182e+00,
|
||||
+8.76190883237069594232e+00,
|
||||
+5.30605288235394617618e+00,
|
||||
+1.00000000000000000218e+00,
|
||||
};
|
||||
|
||||
static const T QP[] = {
|
||||
-1.13663838898469149931e-02,
|
||||
-1.28252718670509318512e+00,
|
||||
-1.95539544257735972385e+01,
|
||||
-9.32060152123768231369e+01,
|
||||
-1.77681167980488050595e+02,
|
||||
-1.47077505154951170175e+02,
|
||||
-5.14105326766599330220e+01,
|
||||
-6.05014350600728481186e+00,
|
||||
};
|
||||
|
||||
static const T QQ[] = {
|
||||
+6.43178256118178023184e+01,
|
||||
+8.56430025976980587198e+02,
|
||||
+3.88240183605401609683e+03,
|
||||
+7.24046774195652478189e+03,
|
||||
+5.93072701187316984827e+03,
|
||||
+2.06209331660327847417e+03,
|
||||
+2.42005740240291393179e+02,
|
||||
};
|
||||
|
||||
static const T YP[] = {
|
||||
+1.55924367855235737965e+04,
|
||||
-1.46639295903971606143e+07,
|
||||
+5.43526477051876500413e+09,
|
||||
-9.82136065717911466409e+11,
|
||||
+8.75906394395366999549e+13,
|
||||
-3.46628303384729719441e+15,
|
||||
+4.42733268572569800351e+16,
|
||||
-1.84950800436986690637e+16,
|
||||
};
|
||||
|
||||
static const T YQ[] = {
|
||||
+1.04128353664259848412e+03,
|
||||
+6.26107330137134956842e+05,
|
||||
+2.68919633393814121987e+08,
|
||||
+8.64002487103935000337e+10,
|
||||
+2.02979612750105546709e+13,
|
||||
+3.17157752842975028269e+15,
|
||||
+2.50596256172653059228e+17,
|
||||
};
|
||||
|
||||
if (x <= T(5.0)) {
|
||||
if (x == T(0.0)) {
|
||||
return -std::numeric_limits<T>::infinity();
|
||||
}
|
||||
|
||||
if (x < T(0.0)) {
|
||||
return std::numeric_limits<T>::quiet_NaN();
|
||||
}
|
||||
|
||||
T yp = 0.0;
|
||||
|
||||
for (uint8_t index = 0; index <= 7; index++) {
|
||||
yp = yp * (x * x) + YP[index];
|
||||
}
|
||||
|
||||
T yq = 0.0;
|
||||
|
||||
for (uint8_t index = 0; index <= 6; index++) {
|
||||
yq = yq * (x * x) + YQ[index];
|
||||
}
|
||||
|
||||
return yp / yq + (T(0.636619772367581343075535053490057448) * std::log(x) * bessel_j0_forward(x));
|
||||
}
|
||||
|
||||
T pp = 0.0;
|
||||
|
||||
for (uint8_t index = 0; index <= 6; index++) {
|
||||
pp = pp * (T(25.0) / (x * x)) + PP[index];
|
||||
}
|
||||
|
||||
T pq = 0.0;
|
||||
|
||||
for (uint8_t index = 0; index <= 6; index++) {
|
||||
pq = pq * (T(25.0) / (x * x)) + PQ[index];
|
||||
}
|
||||
|
||||
T qp = 0.0;
|
||||
|
||||
for (uint8_t index = 0; index <= 7; index++) {
|
||||
qp = qp * (T(25.0) / (x * x)) + QP[index];
|
||||
}
|
||||
|
||||
T qq = 0.0;
|
||||
|
||||
for (uint8_t index = 0; index <= 6; index++) {
|
||||
qq = qq * (T(25.0) / (x * x)) + QQ[index];
|
||||
}
|
||||
|
||||
return (pp / pq * std::sin(x - T(0.785398163397448309615660845819875721)) + T(5.0) / x * (qp / qq) * std::cos(x - T(0.785398163397448309615660845819875721))) * T(0.797884560802865355879892119868763737) / std::sqrt(x);
|
||||
} // bessel_y0_forward(T x)
|
||||
|
||||
template<typename T>
|
||||
static inline C10_HOST_DEVICE T bessel_y1_forward(T x) {
|
||||
static const T PP[] = {
|
||||
+7.62125616208173112003e-04,
|
||||
+7.31397056940917570436e-02,
|
||||
+1.12719608129684925192e+00,
|
||||
+5.11207951146807644818e+00,
|
||||
+8.42404590141772420927e+00,
|
||||
+5.21451598682361504063e+00,
|
||||
+1.00000000000000000254e+00,
|
||||
};
|
||||
|
||||
static const T PQ[] = {
|
||||
+5.71323128072548699714e-04,
|
||||
+6.88455908754495404082e-02,
|
||||
+1.10514232634061696926e+00,
|
||||
+5.07386386128601488557e+00,
|
||||
+8.39985554327604159757e+00,
|
||||
+5.20982848682361821619e+00,
|
||||
+9.99999999999999997461e-01,
|
||||
};
|
||||
|
||||
static const T QP[] = {
|
||||
+5.10862594750176621635e-02,
|
||||
+4.98213872951233449420e+00,
|
||||
+7.58238284132545283818e+01,
|
||||
+3.66779609360150777800e+02,
|
||||
+7.10856304998926107277e+02,
|
||||
+5.97489612400613639965e+02,
|
||||
+2.11688757100572135698e+02,
|
||||
+2.52070205858023719784e+01,
|
||||
};
|
||||
|
||||
static const T QQ[] = {
|
||||
+7.42373277035675149943e+01,
|
||||
+1.05644886038262816351e+03,
|
||||
+4.98641058337653607651e+03,
|
||||
+9.56231892404756170795e+03,
|
||||
+7.99704160447350683650e+03,
|
||||
+2.82619278517639096600e+03,
|
||||
+3.36093607810698293419e+02,
|
||||
};
|
||||
|
||||
static const T YP[] = {
|
||||
+1.26320474790178026440e+09,
|
||||
-6.47355876379160291031e+11,
|
||||
+1.14509511541823727583e+14,
|
||||
-8.12770255501325109621e+15,
|
||||
+2.02439475713594898196e+17,
|
||||
-7.78877196265950026825e+17,
|
||||
};
|
||||
|
||||
static const T YQ[] = {
|
||||
+5.94301592346128195359e+02,
|
||||
+2.35564092943068577943e+05,
|
||||
+7.34811944459721705660e+07,
|
||||
+1.87601316108706159478e+10,
|
||||
+3.88231277496238566008e+12,
|
||||
+6.20557727146953693363e+14,
|
||||
+6.87141087355300489866e+16,
|
||||
+3.97270608116560655612e+18,
|
||||
};
|
||||
|
||||
if (x <= T(5.0)) {
|
||||
if (x == T(0.0)) {
|
||||
return -std::numeric_limits<T>::infinity();
|
||||
}
|
||||
|
||||
if (x <= T(0.0)) {
|
||||
return std::numeric_limits<T>::quiet_NaN();
|
||||
}
|
||||
|
||||
T yp = 0.0;
|
||||
|
||||
for (uint8_t index = 0; index <= 5; index++) {
|
||||
yp = yp * (x * x) + YP[index];
|
||||
}
|
||||
|
||||
T yq = 0.0;
|
||||
|
||||
for (uint8_t index = 0; index <= 7; index++) {
|
||||
yq = yq * (x * x) + YQ[index];
|
||||
}
|
||||
|
||||
return x * (yp / yq) + (T(0.636619772367581343075535053490057448) * (bessel_j1_forward(x) * std::log(x) - T(1.0) / x));
|
||||
}
|
||||
|
||||
T pp = 0.0;
|
||||
|
||||
for (uint8_t index = 0; index <= 6; index++) {
|
||||
pp = pp * (T(5.0) / x * (T(5.0) / x)) + PP[index];
|
||||
}
|
||||
|
||||
T pq = 0.0;
|
||||
|
||||
for (uint8_t index = 0; index <= 6; index++) {
|
||||
pq = pq * (T(5.0) / x * (T(5.0) / x)) + PQ[index];
|
||||
}
|
||||
|
||||
T qp = 0.0;
|
||||
|
||||
for (uint8_t index = 0; index <= 7; index++) {
|
||||
qp = qp * (T(5.0) / x * (T(5.0) / x)) + QP[index];
|
||||
}
|
||||
|
||||
T qq = 0.0;
|
||||
|
||||
for (uint8_t index = 0; index <= 6; index++) {
|
||||
qq = qq * (T(5.0) / x * (T(5.0) / x)) + QQ[index];
|
||||
}
|
||||
|
||||
return (pp / pq * std::sin(x - T(2.356194490192344928846982537459627163)) + T(5.0) / x * (qp / qq) * std::cos(x - T(2.356194490192344928846982537459627163))) * T(0.797884560802865355879892119868763737) / std::sqrt(x);
|
||||
} // bessel_y1_forward(T x)
|
||||
|
||||
template<typename T>
|
||||
static inline C10_HOST_DEVICE T chebyshev_polynomial_t_forward(T x, int64_t n) {
|
||||
if (n < 0) {
|
||||
@ -2369,4 +2818,344 @@ static inline C10_HOST_DEVICE T laguerre_polynomial_l_forward(T x, T n) {
|
||||
return laguerre_polynomial_l_forward(x, static_cast<int64_t>(n));
|
||||
} // laguerre_polynomial_l_forward(T x, T n)
|
||||
|
||||
template<typename T>
|
||||
static inline C10_HOST_DEVICE T modified_bessel_i0_forward(T x) {
|
||||
static const T A[] = {
|
||||
-4.41534164647933937950e-18,
|
||||
+3.33079451882223809783e-17,
|
||||
-2.43127984654795469359e-16,
|
||||
+1.71539128555513303061e-15,
|
||||
-1.16853328779934516808e-14,
|
||||
+7.67618549860493561688e-14,
|
||||
-4.85644678311192946090e-13,
|
||||
+2.95505266312963983461e-12,
|
||||
-1.72682629144155570723e-11,
|
||||
+9.67580903537323691224e-11,
|
||||
-5.18979560163526290666e-10,
|
||||
+2.65982372468238665035e-09,
|
||||
-1.30002500998624804212e-08,
|
||||
+6.04699502254191894932e-08,
|
||||
-2.67079385394061173391e-07,
|
||||
+1.11738753912010371815e-06,
|
||||
-4.41673835845875056359e-06,
|
||||
+1.64484480707288970893e-05,
|
||||
-5.75419501008210370398e-05,
|
||||
+1.88502885095841655729e-04,
|
||||
-5.76375574538582365885e-04,
|
||||
+1.63947561694133579842e-03,
|
||||
-4.32430999505057594430e-03,
|
||||
+1.05464603945949983183e-02,
|
||||
-2.37374148058994688156e-02,
|
||||
+4.93052842396707084878e-02,
|
||||
-9.49010970480476444210e-02,
|
||||
+1.71620901522208775349e-01,
|
||||
-3.04682672343198398683e-01,
|
||||
+6.76795274409476084995e-01,
|
||||
};
|
||||
|
||||
static const T B[] = {
|
||||
-7.23318048787475395456e-18,
|
||||
-4.83050448594418207126e-18,
|
||||
+4.46562142029675999901e-17,
|
||||
+3.46122286769746109310e-17,
|
||||
-2.82762398051658348494e-16,
|
||||
-3.42548561967721913462e-16,
|
||||
+1.77256013305652638360e-15,
|
||||
+3.81168066935262242075e-15,
|
||||
-9.55484669882830764870e-15,
|
||||
-4.15056934728722208663e-14,
|
||||
+1.54008621752140982691e-14,
|
||||
+3.85277838274214270114e-13,
|
||||
+7.18012445138366623367e-13,
|
||||
-1.79417853150680611778e-12,
|
||||
-1.32158118404477131188e-11,
|
||||
-3.14991652796324136454e-11,
|
||||
+1.18891471078464383424e-11,
|
||||
+4.94060238822496958910e-10,
|
||||
+3.39623202570838634515e-09,
|
||||
+2.26666899049817806459e-08,
|
||||
+2.04891858946906374183e-07,
|
||||
+2.89137052083475648297e-06,
|
||||
+6.88975834691682398426e-05,
|
||||
+3.36911647825569408990e-03,
|
||||
+8.04490411014108831608e-01,
|
||||
};
|
||||
|
||||
T p;
|
||||
T q = 0.0;
|
||||
|
||||
if (std::abs(x) <= T(8.0)) {
|
||||
T a = A[0];
|
||||
|
||||
for (uint8_t index = 1; index < 30; index++) {
|
||||
p = q;
|
||||
q = a;
|
||||
a = ((std::abs(x) / T(2.0)) - T(2.0)) * q - p + A[index];
|
||||
}
|
||||
|
||||
return std::exp(std::abs(x)) * (T(0.5) * (a - p));
|
||||
}
|
||||
|
||||
T b = B[0];
|
||||
|
||||
for (uint8_t index = 1; index < 25; index++) {
|
||||
p = q;
|
||||
q = b;
|
||||
b = (T(32.0) / std::abs(x) - T(2.0)) * q - p + B[index];
|
||||
}
|
||||
|
||||
return std::exp(std::abs(x)) * (T(0.5) * (b - p)) / std::sqrt(std::abs(x));
|
||||
} // modified_bessel_i0_forward(T x)
|
||||
|
||||
template<typename T>
|
||||
static inline C10_HOST_DEVICE T modified_bessel_i1_forward(T x) {
|
||||
static const T A[] = {
|
||||
+2.77791411276104639959e-18,
|
||||
-2.11142121435816608115e-17,
|
||||
+1.55363195773620046921e-16,
|
||||
-1.10559694773538630805e-15,
|
||||
+7.60068429473540693410e-15,
|
||||
-5.04218550472791168711e-14,
|
||||
+3.22379336594557470981e-13,
|
||||
-1.98397439776494371520e-12,
|
||||
+1.17361862988909016308e-11,
|
||||
-6.66348972350202774223e-11,
|
||||
+3.62559028155211703701e-10,
|
||||
-1.88724975172282928790e-09,
|
||||
+9.38153738649577178388e-09,
|
||||
-4.44505912879632808065e-08,
|
||||
+2.00329475355213526229e-07,
|
||||
-8.56872026469545474066e-07,
|
||||
+3.47025130813767847674e-06,
|
||||
-1.32731636560394358279e-05,
|
||||
+4.78156510755005422638e-05,
|
||||
-1.61760815825896745588e-04,
|
||||
+5.12285956168575772895e-04,
|
||||
-1.51357245063125314899e-03,
|
||||
+4.15642294431288815669e-03,
|
||||
-1.05640848946261981558e-02,
|
||||
+2.47264490306265168283e-02,
|
||||
-5.29459812080949914269e-02,
|
||||
+1.02643658689847095384e-01,
|
||||
-1.76416518357834055153e-01,
|
||||
+2.52587186443633654823e-01,
|
||||
};
|
||||
|
||||
static const T B[] = {
|
||||
+7.51729631084210481353e-18,
|
||||
+4.41434832307170791151e-18,
|
||||
-4.65030536848935832153e-17,
|
||||
-3.20952592199342395980e-17,
|
||||
+2.96262899764595013876e-16,
|
||||
+3.30820231092092828324e-16,
|
||||
-1.88035477551078244854e-15,
|
||||
-3.81440307243700780478e-15,
|
||||
+1.04202769841288027642e-14,
|
||||
+4.27244001671195135429e-14,
|
||||
-2.10154184277266431302e-14,
|
||||
-4.08355111109219731823e-13,
|
||||
-7.19855177624590851209e-13,
|
||||
+2.03562854414708950722e-12,
|
||||
+1.41258074366137813316e-11,
|
||||
+3.25260358301548823856e-11,
|
||||
-1.89749581235054123450e-11,
|
||||
-5.58974346219658380687e-10,
|
||||
-3.83538038596423702205e-09,
|
||||
-2.63146884688951950684e-08,
|
||||
-2.51223623787020892529e-07,
|
||||
-3.88256480887769039346e-06,
|
||||
-1.10588938762623716291e-04,
|
||||
-9.76109749136146840777e-03,
|
||||
+7.78576235018280120474e-01,
|
||||
};
|
||||
|
||||
T p;
|
||||
T q = 0.0;
|
||||
|
||||
if (std::abs(x) <= T(8.0)) {
|
||||
T a = A[0];
|
||||
|
||||
for (uint8_t index = 1; index < 29; index++) {
|
||||
p = q;
|
||||
q = a;
|
||||
a = ((std::abs(x) / T(2.0)) - T(2.0)) * q - p + A[index];
|
||||
}
|
||||
|
||||
if (x < T(0.0)) {
|
||||
return -(T(0.5) * (a - p) * std::abs(x) * std::exp(std::abs(x)));
|
||||
}
|
||||
|
||||
return T(0.5) * (a - p) * std::abs(x) * std::exp(std::abs(x));
|
||||
}
|
||||
|
||||
T b = B[0];
|
||||
|
||||
for (uint8_t index = 1; index < 25; index++) {
|
||||
p = q;
|
||||
q = b;
|
||||
b = (T(32.0) / std::abs(x) - T(2.0)) * q - p + B[index];
|
||||
}
|
||||
|
||||
if (x < T(0.0)) {
|
||||
return -(std::exp(std::abs(x)) * (T(0.5) * (b - p)) / std::sqrt(std::abs(x)));
|
||||
}
|
||||
|
||||
return std::exp(std::abs(x)) * (T(0.5) * (b - p)) / std::sqrt(std::abs(x));
|
||||
} // modified_bessel_i1_forward(T x)
|
||||
|
||||
template<typename T>
|
||||
static inline C10_HOST_DEVICE T modified_bessel_k0_forward(T x) {
|
||||
static const T A[] = {
|
||||
+1.37446543561352307156e-16,
|
||||
+4.25981614279661018399e-14,
|
||||
+1.03496952576338420167e-11,
|
||||
+1.90451637722020886025e-09,
|
||||
+2.53479107902614945675e-07,
|
||||
+2.28621210311945178607e-05,
|
||||
+1.26461541144692592338e-03,
|
||||
+3.59799365153615016266e-02,
|
||||
+3.44289899924628486886e-01,
|
||||
-5.35327393233902768720e-01,
|
||||
};
|
||||
|
||||
static const T B[] = {
|
||||
+5.30043377268626276149e-18,
|
||||
-1.64758043015242134646e-17,
|
||||
+5.21039150503902756861e-17,
|
||||
-1.67823109680541210385e-16,
|
||||
+5.51205597852431940784e-16,
|
||||
-1.84859337734377901440e-15,
|
||||
+6.34007647740507060557e-15,
|
||||
-2.22751332699166985548e-14,
|
||||
+8.03289077536357521100e-14,
|
||||
-2.98009692317273043925e-13,
|
||||
+1.14034058820847496303e-12,
|
||||
-4.51459788337394416547e-12,
|
||||
+1.85594911495471785253e-11,
|
||||
-7.95748924447710747776e-11,
|
||||
+3.57739728140030116597e-10,
|
||||
-1.69753450938905987466e-09,
|
||||
+8.57403401741422608519e-09,
|
||||
-4.66048989768794782956e-08,
|
||||
+2.76681363944501510342e-07,
|
||||
-1.83175552271911948767e-06,
|
||||
+1.39498137188764993662e-05,
|
||||
-1.28495495816278026384e-04,
|
||||
+1.56988388573005337491e-03,
|
||||
-3.14481013119645005427e-02,
|
||||
+2.44030308206595545468e+00,
|
||||
};
|
||||
|
||||
if (x == T(0.0)) {
|
||||
return std::numeric_limits<T>::infinity();
|
||||
}
|
||||
|
||||
if (x < T(0.0)) {
|
||||
return std::numeric_limits<T>::quiet_NaN();
|
||||
}
|
||||
|
||||
T p;
|
||||
T q = 0.0;
|
||||
|
||||
if (x <= T(2.0)) {
|
||||
T a = A[0];
|
||||
|
||||
for (uint8_t index = 1; index < 10; index++) {
|
||||
p = q;
|
||||
q = a;
|
||||
a = (x * x - T(2.0)) * q - p + A[index];
|
||||
}
|
||||
|
||||
return T(0.5) * (a - p) - std::log(0.5 * x) * modified_bessel_i0_forward(x);
|
||||
}
|
||||
|
||||
T b = B[0];
|
||||
|
||||
for (uint8_t index = 1; index < 25; index++) {
|
||||
p = q;
|
||||
q = b;
|
||||
b = (T(8.0) / x - T(2.0)) * q - p + B[index];
|
||||
}
|
||||
|
||||
return std::exp(-x) * (T(0.5) * (b - p)) / std::sqrt(x);
|
||||
} // modified_bessel_k0_forward(T x)
|
||||
|
||||
template<typename T>
|
||||
static inline C10_HOST_DEVICE T modified_bessel_k1_forward(T x) {
|
||||
static const T A[] = {
|
||||
-7.02386347938628759343e-18,
|
||||
-2.42744985051936593393e-15,
|
||||
-6.66690169419932900609e-13,
|
||||
-1.41148839263352776110e-10,
|
||||
-2.21338763073472585583e-08,
|
||||
-2.43340614156596823496e-06,
|
||||
-1.73028895751305206302e-04,
|
||||
-6.97572385963986435018e-03,
|
||||
-1.22611180822657148235e-01,
|
||||
-3.53155960776544875667e-01,
|
||||
+1.52530022733894777053e+00,
|
||||
};
|
||||
|
||||
static const T B[] = {
|
||||
-5.75674448366501715755e-18,
|
||||
+1.79405087314755922667e-17,
|
||||
-5.68946255844285935196e-17,
|
||||
+1.83809354436663880070e-16,
|
||||
-6.05704724837331885336e-16,
|
||||
+2.03870316562433424052e-15,
|
||||
-7.01983709041831346144e-15,
|
||||
+2.47715442448130437068e-14,
|
||||
-8.97670518232499435011e-14,
|
||||
+3.34841966607842919884e-13,
|
||||
-1.28917396095102890680e-12,
|
||||
+5.13963967348173025100e-12,
|
||||
-2.12996783842756842877e-11,
|
||||
+9.21831518760500529508e-11,
|
||||
-4.19035475934189648750e-10,
|
||||
+2.01504975519703286596e-09,
|
||||
-1.03457624656780970260e-08,
|
||||
+5.74108412545004946722e-08,
|
||||
-3.50196060308781257119e-07,
|
||||
+2.40648494783721712015e-06,
|
||||
-1.93619797416608296024e-05,
|
||||
+1.95215518471351631108e-04,
|
||||
-2.85781685962277938680e-03,
|
||||
+1.03923736576817238437e-01,
|
||||
+2.72062619048444266945e+00,
|
||||
};
|
||||
|
||||
if (x == T(0.0)) {
|
||||
return std::numeric_limits<T>::infinity();
|
||||
}
|
||||
|
||||
if (x < T(0.0)) {
|
||||
return std::numeric_limits<T>::quiet_NaN();
|
||||
}
|
||||
|
||||
T p;
|
||||
T q = 0.0;
|
||||
|
||||
if (x <= T(2.0)) {
|
||||
T a = A[0];
|
||||
|
||||
for (uint8_t index = 1; index < 11; index++) {
|
||||
p = q;
|
||||
q = a;
|
||||
a = (x * x - T(2.0)) * q - p + A[index];
|
||||
}
|
||||
|
||||
return std::log(T(0.5) * x) * modified_bessel_i1_forward(x) + T(0.5) * (a - p) / x;
|
||||
}
|
||||
|
||||
T b = B[0];
|
||||
|
||||
for (uint8_t index = 1; index < 25; index++) {
|
||||
p = q;
|
||||
q = b;
|
||||
b = (T(8.0) / x - T(2.0)) * q - p + B[index];
|
||||
}
|
||||
|
||||
return std::exp(-x) * (T(0.5) * (b - p)) / std::sqrt(x);
|
||||
} // modified_bessel_k1_forward(T x)
|
||||
|
||||
C10_CLANG_DIAGNOSTIC_POP()
|
||||
|
@ -71,6 +71,14 @@ CREATE_UNARY_FLOAT_META_FUNC(special_log_ndtr)
|
||||
CREATE_UNARY_FLOAT_META_FUNC(sqrt)
|
||||
CREATE_UNARY_FLOAT_META_FUNC(tan)
|
||||
CREATE_UNARY_FLOAT_META_FUNC(tanh)
|
||||
CREATE_UNARY_FLOAT_META_FUNC(special_bessel_j0)
|
||||
CREATE_UNARY_FLOAT_META_FUNC(special_bessel_j1)
|
||||
CREATE_UNARY_FLOAT_META_FUNC(special_bessel_y0)
|
||||
CREATE_UNARY_FLOAT_META_FUNC(special_bessel_y1)
|
||||
CREATE_UNARY_FLOAT_META_FUNC(special_modified_bessel_i0)
|
||||
CREATE_UNARY_FLOAT_META_FUNC(special_modified_bessel_i1)
|
||||
CREATE_UNARY_FLOAT_META_FUNC(special_modified_bessel_k0)
|
||||
CREATE_UNARY_FLOAT_META_FUNC(special_modified_bessel_k1)
|
||||
|
||||
TORCH_META_FUNC(polygamma)(int64_t n, const Tensor& self) {
|
||||
TORCH_CHECK(n >= 0, "polygamma(n, x) does not support negative n.");
|
||||
@ -190,6 +198,14 @@ CREATE_UNARY_TORCH_IMPL_FUNC(sqrt_out, sqrt_stub)
|
||||
CREATE_UNARY_TORCH_IMPL_FUNC(tan_out, tan_stub)
|
||||
CREATE_UNARY_TORCH_IMPL_FUNC(tanh_out, tanh_stub)
|
||||
CREATE_UNARY_TORCH_IMPL_FUNC(trunc_out, trunc_stub)
|
||||
CREATE_UNARY_TORCH_IMPL_FUNC(special_bessel_j0_out, special_bessel_j0_stub)
|
||||
CREATE_UNARY_TORCH_IMPL_FUNC(special_bessel_j1_out, special_bessel_j1_stub)
|
||||
CREATE_UNARY_TORCH_IMPL_FUNC(special_bessel_y0_out, special_bessel_y0_stub)
|
||||
CREATE_UNARY_TORCH_IMPL_FUNC(special_bessel_y1_out, special_bessel_y1_stub)
|
||||
CREATE_UNARY_TORCH_IMPL_FUNC(special_modified_bessel_i0_out, special_modified_bessel_i0_stub)
|
||||
CREATE_UNARY_TORCH_IMPL_FUNC(special_modified_bessel_i1_out, special_modified_bessel_i1_stub)
|
||||
CREATE_UNARY_TORCH_IMPL_FUNC(special_modified_bessel_k0_out, special_modified_bessel_k0_stub)
|
||||
CREATE_UNARY_TORCH_IMPL_FUNC(special_modified_bessel_k1_out, special_modified_bessel_k1_stub)
|
||||
|
||||
TORCH_IMPL_FUNC(round_decimals_out)
|
||||
(const Tensor& self, int64_t decimals, const Tensor& result) {
|
||||
@ -863,6 +879,14 @@ DEFINE_DISPATCH(tanh_stub); // NOLINT(cppcoreguidelines-avoid-non-const-global-v
|
||||
DEFINE_DISPATCH(trigamma_stub); // NOLINT(cppcoreguidelines-avoid-non-const-global-variables)
|
||||
DEFINE_DISPATCH(trunc_stub); // NOLINT(cppcoreguidelines-avoid-non-const-global-variables)
|
||||
DEFINE_DISPATCH(lgamma_stub); // NOLINT(cppcoreguidelines-avoid-non-const-global-variables)
|
||||
DEFINE_DISPATCH(special_bessel_j0_stub); // NOLINT(cppcoreguidelines-avoid-non-const-global-variables)
|
||||
DEFINE_DISPATCH(special_bessel_j1_stub); // NOLINT(cppcoreguidelines-avoid-non-const-global-variables)
|
||||
DEFINE_DISPATCH(special_bessel_y0_stub); // NOLINT(cppcoreguidelines-avoid-non-const-global-variables)
|
||||
DEFINE_DISPATCH(special_bessel_y1_stub); // NOLINT(cppcoreguidelines-avoid-non-const-global-variables)
|
||||
DEFINE_DISPATCH(special_modified_bessel_i0_stub); // NOLINT(cppcoreguidelines-avoid-non-const-global-variables)
|
||||
DEFINE_DISPATCH(special_modified_bessel_i1_stub); // NOLINT(cppcoreguidelines-avoid-non-const-global-variables)
|
||||
DEFINE_DISPATCH(special_modified_bessel_k0_stub); // NOLINT(cppcoreguidelines-avoid-non-const-global-variables)
|
||||
DEFINE_DISPATCH(special_modified_bessel_k1_stub); // NOLINT(cppcoreguidelines-avoid-non-const-global-variables)
|
||||
|
||||
} // namespace native
|
||||
} // namespace at
|
||||
|
@ -70,6 +70,14 @@ DECLARE_DISPATCH(unary_fn, tanh_stub);
|
||||
DECLARE_DISPATCH(unary_fn, trigamma_stub);
|
||||
DECLARE_DISPATCH(unary_fn, trunc_stub);
|
||||
DECLARE_DISPATCH(unary_fn, lgamma_stub);
|
||||
DECLARE_DISPATCH(unary_fn, special_bessel_j0_stub);
|
||||
DECLARE_DISPATCH(unary_fn, special_bessel_j1_stub);
|
||||
DECLARE_DISPATCH(unary_fn, special_bessel_y0_stub);
|
||||
DECLARE_DISPATCH(unary_fn, special_bessel_y1_stub);
|
||||
DECLARE_DISPATCH(unary_fn, special_modified_bessel_i0_stub);
|
||||
DECLARE_DISPATCH(unary_fn, special_modified_bessel_i1_stub);
|
||||
DECLARE_DISPATCH(unary_fn, special_modified_bessel_k0_stub);
|
||||
DECLARE_DISPATCH(unary_fn, special_modified_bessel_k1_stub);
|
||||
|
||||
// NB: these are actually defined in Distribution
|
||||
DECLARE_DISPATCH(void(*)(const TensorBase&, const TensorBase&, c10::optional<Generator>), bernoulli_tensor_stub);
|
||||
|
@ -566,6 +566,86 @@ void round_decimals_kernel(TensorIteratorBase& iter, int64_t decimals) {
|
||||
});
|
||||
}
|
||||
|
||||
static void bessel_j0_kernel(TensorIteratorBase& iterator) {
|
||||
TORCH_INTERNAL_ASSERT(iterator.ntensors() == 2);
|
||||
|
||||
AT_DISPATCH_FLOATING_TYPES(iterator.common_dtype(), "bessel_j0_cpu", [&]() {
|
||||
cpu_kernel(iterator, [](scalar_t x) {
|
||||
return bessel_j0_forward(x);
|
||||
});
|
||||
});
|
||||
} // bessel_j0_kernel(TensorIteratorBase& iterator)
|
||||
|
||||
static void bessel_j1_kernel(TensorIteratorBase& iterator) {
|
||||
TORCH_INTERNAL_ASSERT(iterator.ntensors() == 2);
|
||||
|
||||
AT_DISPATCH_FLOATING_TYPES(iterator.common_dtype(), "bessel_j1_cpu", [&]() {
|
||||
cpu_kernel(iterator, [](scalar_t x) {
|
||||
return bessel_j1_forward(x);
|
||||
});
|
||||
});
|
||||
} // bessel_j1_kernel(TensorIteratorBase& iterator)
|
||||
|
||||
static void bessel_y0_kernel(TensorIteratorBase& iterator) {
|
||||
TORCH_INTERNAL_ASSERT(iterator.ntensors() == 2);
|
||||
|
||||
AT_DISPATCH_FLOATING_TYPES(iterator.common_dtype(), "bessel_y0_cpu", [&]() {
|
||||
cpu_kernel(iterator, [](scalar_t x) {
|
||||
return bessel_y0_forward(x);
|
||||
});
|
||||
});
|
||||
} // bessel_y0_kernel(TensorIteratorBase& iterator)
|
||||
|
||||
static void bessel_y1_kernel(TensorIteratorBase& iterator) {
|
||||
TORCH_INTERNAL_ASSERT(iterator.ntensors() == 2);
|
||||
|
||||
AT_DISPATCH_FLOATING_TYPES(iterator.common_dtype(), "bessel_y1_cpu", [&]() {
|
||||
cpu_kernel(iterator, [](scalar_t x) {
|
||||
return bessel_y1_forward(x);
|
||||
});
|
||||
});
|
||||
} // bessel_y1_kernel(TensorIteratorBase& iterator)
|
||||
|
||||
static void modified_bessel_i0_kernel(TensorIteratorBase& iterator) {
|
||||
TORCH_INTERNAL_ASSERT(iterator.ntensors() == 2);
|
||||
|
||||
AT_DISPATCH_FLOATING_TYPES(iterator.common_dtype(), "modified_bessel_i0_cpu", [&]() {
|
||||
cpu_kernel(iterator, [](scalar_t x) {
|
||||
return modified_bessel_i0_forward(x);
|
||||
});
|
||||
});
|
||||
} // modified_bessel_i0_kernel(TensorIteratorBase& iterator)
|
||||
|
||||
static void modified_bessel_i1_kernel(TensorIteratorBase& iterator) {
|
||||
TORCH_INTERNAL_ASSERT(iterator.ntensors() == 2);
|
||||
|
||||
AT_DISPATCH_FLOATING_TYPES(iterator.common_dtype(), "modified_bessel_i1_cpu", [&]() {
|
||||
cpu_kernel(iterator, [](scalar_t x) {
|
||||
return modified_bessel_i1_forward(x);
|
||||
});
|
||||
});
|
||||
} // modified_bessel_i1_kernel(TensorIteratorBase& iterator)
|
||||
|
||||
static void modified_bessel_k0_kernel(TensorIteratorBase& iterator) {
|
||||
TORCH_INTERNAL_ASSERT(iterator.ntensors() == 2);
|
||||
|
||||
AT_DISPATCH_FLOATING_TYPES(iterator.common_dtype(), "modified_bessel_k0_cpu", [&]() {
|
||||
cpu_kernel(iterator, [](scalar_t x) {
|
||||
return modified_bessel_k0_forward(x);
|
||||
});
|
||||
});
|
||||
} // modified_bessel_k0_kernel(TensorIteratorBase& iterator)
|
||||
|
||||
static void modified_bessel_k1_kernel(TensorIteratorBase& iterator) {
|
||||
TORCH_INTERNAL_ASSERT(iterator.ntensors() == 2);
|
||||
|
||||
AT_DISPATCH_FLOATING_TYPES(iterator.common_dtype(), "modified_bessel_k1_cpu", [&]() {
|
||||
cpu_kernel(iterator, [](scalar_t x) {
|
||||
return modified_bessel_k1_forward(x);
|
||||
});
|
||||
});
|
||||
} // modified_bessel_k1_kernel(TensorIteratorBase& iterator)
|
||||
|
||||
// TODO: Disable cont. branch to test more risky code
|
||||
|
||||
#define IMPLEMENT_ITERATOR_LAMBDA(op) \
|
||||
@ -656,7 +736,14 @@ REGISTER_DISPATCH(special_i1_stub, &CPU_CAPABILITY::i1_kernel);
|
||||
REGISTER_DISPATCH(special_i1e_stub, &CPU_CAPABILITY::i1e_kernel);
|
||||
REGISTER_DISPATCH(special_erfcx_stub, &CPU_CAPABILITY::erfcx_kernel);
|
||||
REGISTER_DISPATCH(round_decimals_stub, &CPU_CAPABILITY::round_decimals_kernel);
|
||||
|
||||
REGISTER_DISPATCH(special_bessel_j0_stub, &CPU_CAPABILITY::bessel_j0_kernel);
|
||||
REGISTER_DISPATCH(special_bessel_j1_stub, &CPU_CAPABILITY::bessel_j1_kernel);
|
||||
REGISTER_DISPATCH(special_bessel_y0_stub, &CPU_CAPABILITY::bessel_y0_kernel);
|
||||
REGISTER_DISPATCH(special_bessel_y1_stub, &CPU_CAPABILITY::bessel_y1_kernel);
|
||||
REGISTER_DISPATCH(special_modified_bessel_i0_stub, &CPU_CAPABILITY::modified_bessel_i0_kernel);
|
||||
REGISTER_DISPATCH(special_modified_bessel_i1_stub, &CPU_CAPABILITY::modified_bessel_i1_kernel);
|
||||
REGISTER_DISPATCH(special_modified_bessel_k0_stub, &CPU_CAPABILITY::modified_bessel_k0_kernel);
|
||||
REGISTER_DISPATCH(special_modified_bessel_k1_stub, &CPU_CAPABILITY::modified_bessel_k1_kernel);
|
||||
|
||||
// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables,modernize-avoid-c-arrays,cppcoreguidelines-avoid-c-arrays)
|
||||
IMPLEMENT_COMPLEX_KERNEL(acos)
|
||||
|
@ -1265,6 +1265,463 @@ const auto erfcx_string = jiterator_stringify(
|
||||
}
|
||||
); // erfcx_string
|
||||
|
||||
const auto bessel_j0_string = jiterator_stringify(
|
||||
template<typename T>
|
||||
T bessel_j0_forward(T x) {
|
||||
static const T PP[] = {
|
||||
+7.96936729297347051624e-04,
|
||||
+8.28352392107440799803e-02,
|
||||
+1.23953371646414299388e+00,
|
||||
+5.44725003058768775090e+00,
|
||||
+8.74716500199817011941e+00,
|
||||
+5.30324038235394892183e+00,
|
||||
+9.99999999999999997821e-01,
|
||||
};
|
||||
|
||||
static const T PQ[] = {
|
||||
+9.24408810558863637013e-04,
|
||||
+8.56288474354474431428e-02,
|
||||
+1.25352743901058953537e+00,
|
||||
+5.47097740330417105182e+00,
|
||||
+8.76190883237069594232e+00,
|
||||
+5.30605288235394617618e+00,
|
||||
+1.00000000000000000218e+00,
|
||||
};
|
||||
|
||||
static const T QP[] = {
|
||||
-1.13663838898469149931e-02,
|
||||
-1.28252718670509318512e+00,
|
||||
-1.95539544257735972385e+01,
|
||||
-9.32060152123768231369e+01,
|
||||
-1.77681167980488050595e+02,
|
||||
-1.47077505154951170175e+02,
|
||||
-5.14105326766599330220e+01,
|
||||
-6.05014350600728481186e+00,
|
||||
};
|
||||
|
||||
static const T QQ[] = {
|
||||
+6.43178256118178023184e+01,
|
||||
+8.56430025976980587198e+02,
|
||||
+3.88240183605401609683e+03,
|
||||
+7.24046774195652478189e+03,
|
||||
+5.93072701187316984827e+03,
|
||||
+2.06209331660327847417e+03,
|
||||
+2.42005740240291393179e+02,
|
||||
};
|
||||
|
||||
static const T RP[] = {
|
||||
-4.79443220978201773821e+09,
|
||||
+1.95617491946556577543e+12,
|
||||
-2.49248344360967716204e+14,
|
||||
+9.70862251047306323952e+15,
|
||||
};
|
||||
|
||||
static const T RQ[] = {
|
||||
+4.99563147152651017219e+02,
|
||||
+1.73785401676374683123e+05,
|
||||
+4.84409658339962045305e+07,
|
||||
+1.11855537045356834862e+10,
|
||||
+2.11277520115489217587e+12,
|
||||
+3.10518229857422583814e+14,
|
||||
+3.18121955943204943306e+16,
|
||||
+1.71086294081043136091e+18,
|
||||
};
|
||||
|
||||
if (x < T(0)) {
|
||||
x = -x;
|
||||
}
|
||||
|
||||
if (x <= T(5.0)) {
|
||||
if (x < T(0.00001)) {
|
||||
return T(1.0) - x * x / T(4.0);
|
||||
}
|
||||
|
||||
T rp = 0.0;
|
||||
|
||||
for (uint8_t index = 0; index <= 3; index++) {
|
||||
rp = rp * (x * x) + RP[index];
|
||||
}
|
||||
|
||||
T rq = 0.0;
|
||||
|
||||
for (uint8_t index = 0; index <= 7; index++) {
|
||||
rq = rq * (x * x) + RQ[index];
|
||||
}
|
||||
|
||||
return (x * x - T(5.78318596294678452118e+00)) * (x * x - T(3.04712623436620863991e+01)) * rp / rq;
|
||||
}
|
||||
|
||||
T pp = 0.0;
|
||||
|
||||
for (uint8_t index = 0; index <= 6; index++) {
|
||||
pp = pp * (T(25.0) / (x * x)) + PP[index];
|
||||
}
|
||||
|
||||
T pq = 0.0;
|
||||
|
||||
for (uint8_t index = 0; index <= 6; index++) {
|
||||
pq = pq * (T(25.0) / (x * x)) + PQ[index];
|
||||
}
|
||||
|
||||
T qp = 0.0;
|
||||
|
||||
for (uint8_t index = 0; index <= 7; index++) {
|
||||
qp = qp * (T(25.0) / (x * x)) + QP[index];
|
||||
}
|
||||
|
||||
T qq = 0.0;
|
||||
|
||||
for (uint8_t index = 0; index <= 6; index++) {
|
||||
qq = qq * (T(25.0) / (x * x)) + QQ[index];
|
||||
}
|
||||
|
||||
return (pp / pq * cos(x - T(0.785398163397448309615660845819875721)) - T(5.0) / x * (qp / qq) * sin(x - T(0.785398163397448309615660845819875721))) * T(0.797884560802865355879892119868763737) / sqrt(x);
|
||||
} // bessel_j0_forward(T x)
|
||||
); // bessel_j0_string
|
||||
|
||||
const auto bessel_y0_string = bessel_j0_string + jiterator_stringify(
|
||||
template<typename T>
|
||||
T bessel_y0_forward(T x) {
|
||||
static const T PP[] = {
|
||||
+7.96936729297347051624e-04,
|
||||
+8.28352392107440799803e-02,
|
||||
+1.23953371646414299388e+00,
|
||||
+5.44725003058768775090e+00,
|
||||
+8.74716500199817011941e+00,
|
||||
+5.30324038235394892183e+00,
|
||||
+9.99999999999999997821e-01,
|
||||
};
|
||||
|
||||
static const T PQ[] = {
|
||||
+9.24408810558863637013e-04,
|
||||
+8.56288474354474431428e-02,
|
||||
+1.25352743901058953537e+00,
|
||||
+5.47097740330417105182e+00,
|
||||
+8.76190883237069594232e+00,
|
||||
+5.30605288235394617618e+00,
|
||||
+1.00000000000000000218e+00,
|
||||
};
|
||||
|
||||
static const T QP[] = {
|
||||
-1.13663838898469149931e-02,
|
||||
-1.28252718670509318512e+00,
|
||||
-1.95539544257735972385e+01,
|
||||
-9.32060152123768231369e+01,
|
||||
-1.77681167980488050595e+02,
|
||||
-1.47077505154951170175e+02,
|
||||
-5.14105326766599330220e+01,
|
||||
-6.05014350600728481186e+00,
|
||||
};
|
||||
|
||||
static const T QQ[] = {
|
||||
+6.43178256118178023184e+01,
|
||||
+8.56430025976980587198e+02,
|
||||
+3.88240183605401609683e+03,
|
||||
+7.24046774195652478189e+03,
|
||||
+5.93072701187316984827e+03,
|
||||
+2.06209331660327847417e+03,
|
||||
+2.42005740240291393179e+02,
|
||||
};
|
||||
|
||||
static const T YP[] = {
|
||||
+1.55924367855235737965e+04,
|
||||
-1.46639295903971606143e+07,
|
||||
+5.43526477051876500413e+09,
|
||||
-9.82136065717911466409e+11,
|
||||
+8.75906394395366999549e+13,
|
||||
-3.46628303384729719441e+15,
|
||||
+4.42733268572569800351e+16,
|
||||
-1.84950800436986690637e+16,
|
||||
};
|
||||
|
||||
static const T YQ[] = {
|
||||
+1.04128353664259848412e+03,
|
||||
+6.26107330137134956842e+05,
|
||||
+2.68919633393814121987e+08,
|
||||
+8.64002487103935000337e+10,
|
||||
+2.02979612750105546709e+13,
|
||||
+3.17157752842975028269e+15,
|
||||
+2.50596256172653059228e+17,
|
||||
};
|
||||
|
||||
if (x <= T(5.0)) {
|
||||
if (x == T(0.0)) {
|
||||
return NEG_INFINITY;
|
||||
}
|
||||
|
||||
if (x < T(0.0)) {
|
||||
NAN;
|
||||
}
|
||||
|
||||
T yp = 0.0;
|
||||
|
||||
for (uint8_t index = 0; index <= 7; index++) {
|
||||
yp = yp * (x * x) + YP[index];
|
||||
}
|
||||
|
||||
T yq = 0.0;
|
||||
|
||||
for (uint8_t index = 0; index <= 6; index++) {
|
||||
yq = yq * (x * x) + YQ[index];
|
||||
}
|
||||
|
||||
return yp / yq + (T(0.636619772367581343075535053490057448) * log(x) * bessel_j0_forward(x));
|
||||
}
|
||||
|
||||
T pp = 0.0;
|
||||
|
||||
for (uint8_t index = 0; index <= 6; index++) {
|
||||
pp = pp * (T(25.0) / (x * x)) + PP[index];
|
||||
}
|
||||
|
||||
T pq = 0.0;
|
||||
|
||||
for (uint8_t index = 0; index <= 6; index++) {
|
||||
pq = pq * (T(25.0) / (x * x)) + PQ[index];
|
||||
}
|
||||
|
||||
T qp = 0.0;
|
||||
|
||||
for (uint8_t index = 0; index <= 7; index++) {
|
||||
qp = qp * (T(25.0) / (x * x)) + QP[index];
|
||||
}
|
||||
|
||||
T qq = 0.0;
|
||||
|
||||
for (uint8_t index = 0; index <= 6; index++) {
|
||||
qq = qq * (T(25.0) / (x * x)) + QQ[index];
|
||||
}
|
||||
|
||||
return (pp / pq * sin(x - T(0.785398163397448309615660845819875721)) + T(5.0) / x * (qp / qq) * cos(x - T(0.785398163397448309615660845819875721))) * T(0.797884560802865355879892119868763737) / sqrt(x);
|
||||
} // bessel_y0_forward(T x)
|
||||
); // bessel_y0_string
|
||||
|
||||
const auto bessel_j1_string = jiterator_stringify(
|
||||
template<typename T>
|
||||
T bessel_j1_forward(T x) {
|
||||
static const T PP[] = {
|
||||
+7.62125616208173112003e-04,
|
||||
+7.31397056940917570436e-02,
|
||||
+1.12719608129684925192e+00,
|
||||
+5.11207951146807644818e+00,
|
||||
+8.42404590141772420927e+00,
|
||||
+5.21451598682361504063e+00,
|
||||
+1.00000000000000000254e+00,
|
||||
};
|
||||
|
||||
static const T PQ[] = {
|
||||
+5.71323128072548699714e-04,
|
||||
+6.88455908754495404082e-02,
|
||||
+1.10514232634061696926e+00,
|
||||
+5.07386386128601488557e+00,
|
||||
+8.39985554327604159757e+00,
|
||||
+5.20982848682361821619e+00,
|
||||
+9.99999999999999997461e-01,
|
||||
};
|
||||
|
||||
static const T QP[] = {
|
||||
+5.10862594750176621635e-02,
|
||||
+4.98213872951233449420e+00,
|
||||
+7.58238284132545283818e+01,
|
||||
+3.66779609360150777800e+02,
|
||||
+7.10856304998926107277e+02,
|
||||
+5.97489612400613639965e+02,
|
||||
+2.11688757100572135698e+02,
|
||||
+2.52070205858023719784e+01,
|
||||
};
|
||||
|
||||
static const T QQ[] = {
|
||||
+7.42373277035675149943e+01,
|
||||
+1.05644886038262816351e+03,
|
||||
+4.98641058337653607651e+03,
|
||||
+9.56231892404756170795e+03,
|
||||
+7.99704160447350683650e+03,
|
||||
+2.82619278517639096600e+03,
|
||||
+3.36093607810698293419e+02,
|
||||
};
|
||||
|
||||
static const T RP[] = {
|
||||
-8.99971225705559398224e+08,
|
||||
+4.52228297998194034323e+11,
|
||||
-7.27494245221818276015e+13,
|
||||
+3.68295732863852883286e+15,
|
||||
};
|
||||
|
||||
static const T RQ[] = {
|
||||
+6.20836478118054335476e+02,
|
||||
+2.56987256757748830383e+05,
|
||||
+8.35146791431949253037e+07,
|
||||
+2.21511595479792499675e+10,
|
||||
+4.74914122079991414898e+12,
|
||||
+7.84369607876235854894e+14,
|
||||
+8.95222336184627338078e+16,
|
||||
+5.32278620332680085395e+18,
|
||||
};
|
||||
|
||||
if (x < T(0.0)) {
|
||||
return -bessel_j1_forward(-x);
|
||||
}
|
||||
|
||||
if (x <= T(5.0)) {
|
||||
T rp = 0.0;
|
||||
|
||||
for (uint8_t index = 0; index <= 3; index++) {
|
||||
rp = rp * (x * x) + RP[index];
|
||||
}
|
||||
|
||||
T rq = 0.0;
|
||||
|
||||
for (uint8_t index = 0; index <= 7; index++) {
|
||||
rq = rq * (x * x) + RQ[index];
|
||||
}
|
||||
|
||||
return rp / rq * x * (x * x - T(1.46819706421238932572e+01)) * (x * x - T(4.92184563216946036703e+01));
|
||||
}
|
||||
|
||||
T pp = 0.0;
|
||||
|
||||
for (uint8_t index = 0; index <= 6; index++) {
|
||||
pp = pp * (T(5.0) / x * (T(5.0) / x)) + PP[index];
|
||||
}
|
||||
|
||||
T pq = 0.0;
|
||||
|
||||
for (uint8_t index = 0; index <= 6; index++) {
|
||||
pq = pq * (T(5.0) / x * (T(5.0) / x)) + PQ[index];
|
||||
}
|
||||
|
||||
T qp = 0.0;
|
||||
|
||||
for (uint8_t index = 0; index <= 7; index++) {
|
||||
qp = qp * (T(5.0) / x * (T(5.0) / x)) + QP[index];
|
||||
}
|
||||
|
||||
T qq = 0.0;
|
||||
|
||||
for (uint8_t index = 0; index <= 6; index++) {
|
||||
qq = qq * (T(5.0) / x * (T(5.0) / x)) + QQ[index];
|
||||
}
|
||||
|
||||
return (pp / pq * cos(x - T(2.356194490192344928846982537459627163)) - T(5.0) / x * (qp / qq) * sin(x - T(2.356194490192344928846982537459627163))) * T(0.797884560802865355879892119868763737) / sqrt(x);
|
||||
} // bessel_j1_forward(T x)
|
||||
); // bessel_j1_string
|
||||
|
||||
const auto bessel_y1_string = bessel_j1_string + jiterator_stringify(
|
||||
template<typename T>
|
||||
T bessel_y1_forward(T x) {
|
||||
static const T PP[] = {
|
||||
+7.62125616208173112003e-04,
|
||||
+7.31397056940917570436e-02,
|
||||
+1.12719608129684925192e+00,
|
||||
+5.11207951146807644818e+00,
|
||||
+8.42404590141772420927e+00,
|
||||
+5.21451598682361504063e+00,
|
||||
+1.00000000000000000254e+00,
|
||||
};
|
||||
|
||||
static const T PQ[] = {
|
||||
+5.71323128072548699714e-04,
|
||||
+6.88455908754495404082e-02,
|
||||
+1.10514232634061696926e+00,
|
||||
+5.07386386128601488557e+00,
|
||||
+8.39985554327604159757e+00,
|
||||
+5.20982848682361821619e+00,
|
||||
+9.99999999999999997461e-01,
|
||||
};
|
||||
|
||||
static const T QP[] = {
|
||||
+5.10862594750176621635e-02,
|
||||
+4.98213872951233449420e+00,
|
||||
+7.58238284132545283818e+01,
|
||||
+3.66779609360150777800e+02,
|
||||
+7.10856304998926107277e+02,
|
||||
+5.97489612400613639965e+02,
|
||||
+2.11688757100572135698e+02,
|
||||
+2.52070205858023719784e+01,
|
||||
};
|
||||
|
||||
static const T QQ[] = {
|
||||
+7.42373277035675149943e+01,
|
||||
+1.05644886038262816351e+03,
|
||||
+4.98641058337653607651e+03,
|
||||
+9.56231892404756170795e+03,
|
||||
+7.99704160447350683650e+03,
|
||||
+2.82619278517639096600e+03,
|
||||
+3.36093607810698293419e+02,
|
||||
};
|
||||
|
||||
static const T YP[] = {
|
||||
+1.26320474790178026440e+09,
|
||||
-6.47355876379160291031e+11,
|
||||
+1.14509511541823727583e+14,
|
||||
-8.12770255501325109621e+15,
|
||||
+2.02439475713594898196e+17,
|
||||
-7.78877196265950026825e+17,
|
||||
};
|
||||
|
||||
static const T YQ[] = {
|
||||
+5.94301592346128195359e+02,
|
||||
+2.35564092943068577943e+05,
|
||||
+7.34811944459721705660e+07,
|
||||
+1.87601316108706159478e+10,
|
||||
+3.88231277496238566008e+12,
|
||||
+6.20557727146953693363e+14,
|
||||
+6.87141087355300489866e+16,
|
||||
+3.97270608116560655612e+18,
|
||||
};
|
||||
|
||||
if (x <= T(5.0)) {
|
||||
if (x == T(0.0)) {
|
||||
return NEG_INFINITY;
|
||||
}
|
||||
|
||||
if (x <= T(0.0)) {
|
||||
return NAN;
|
||||
}
|
||||
|
||||
T yp = 0.0;
|
||||
|
||||
for (uint8_t index = 0; index <= 5; index++) {
|
||||
yp = yp * (x * x) + YP[index];
|
||||
}
|
||||
|
||||
T yq = 0.0;
|
||||
|
||||
for (uint8_t index = 0; index <= 7; index++) {
|
||||
yq = yq * (x * x) + YQ[index];
|
||||
}
|
||||
|
||||
return x * (yp / yq) + (T(0.636619772367581343075535053490057448) * (bessel_j1_forward(x) * log(x) - T(1.0) / x));
|
||||
}
|
||||
|
||||
T pp = 0.0;
|
||||
|
||||
for (uint8_t index = 0; index <= 6; index++) {
|
||||
pp = pp * (T(5.0) / x * (T(5.0) / x)) + PP[index];
|
||||
}
|
||||
|
||||
T pq = 0.0;
|
||||
|
||||
for (uint8_t index = 0; index <= 6; index++) {
|
||||
pq = pq * (T(5.0) / x * (T(5.0) / x)) + PQ[index];
|
||||
}
|
||||
|
||||
T qp = 0.0;
|
||||
|
||||
for (uint8_t index = 0; index <= 7; index++) {
|
||||
qp = qp * (T(5.0) / x * (T(5.0) / x)) + QP[index];
|
||||
}
|
||||
|
||||
T qq = 0.0;
|
||||
|
||||
for (uint8_t index = 0; index <= 6; index++) {
|
||||
qq = qq * (T(5.0) / x * (T(5.0) / x)) + QQ[index];
|
||||
}
|
||||
|
||||
return (pp / pq * sin(x - T(2.356194490192344928846982537459627163)) + T(5.0) / x * (qp / qq) * cos(x - T(2.356194490192344928846982537459627163))) * T(0.797884560802865355879892119868763737) / sqrt(x);
|
||||
} // bessel_y1_forward(T x)
|
||||
); // bessel_y1_string
|
||||
|
||||
const auto chebyshev_polynomial_t_string = jiterator_stringify(
|
||||
template<typename T>
|
||||
T chebyshev_polynomial_t_forward(T x, int64_t n) {
|
||||
@ -1467,6 +1924,354 @@ const auto laguerre_polynomial_l_string = jiterator_stringify(
|
||||
} // laguerre_polynomial_l_forward(T x, T n)
|
||||
); // laguerre_polynomial_l_string
|
||||
|
||||
const auto modified_bessel_i0_string = jiterator_stringify(
|
||||
template<typename T>
|
||||
T modified_bessel_i0_forward(T x) {
|
||||
static const T A[] = {
|
||||
-4.41534164647933937950e-18,
|
||||
+3.33079451882223809783e-17,
|
||||
-2.43127984654795469359e-16,
|
||||
+1.71539128555513303061e-15,
|
||||
-1.16853328779934516808e-14,
|
||||
+7.67618549860493561688e-14,
|
||||
-4.85644678311192946090e-13,
|
||||
+2.95505266312963983461e-12,
|
||||
-1.72682629144155570723e-11,
|
||||
+9.67580903537323691224e-11,
|
||||
-5.18979560163526290666e-10,
|
||||
+2.65982372468238665035e-09,
|
||||
-1.30002500998624804212e-08,
|
||||
+6.04699502254191894932e-08,
|
||||
-2.67079385394061173391e-07,
|
||||
+1.11738753912010371815e-06,
|
||||
-4.41673835845875056359e-06,
|
||||
+1.64484480707288970893e-05,
|
||||
-5.75419501008210370398e-05,
|
||||
+1.88502885095841655729e-04,
|
||||
-5.76375574538582365885e-04,
|
||||
+1.63947561694133579842e-03,
|
||||
-4.32430999505057594430e-03,
|
||||
+1.05464603945949983183e-02,
|
||||
-2.37374148058994688156e-02,
|
||||
+4.93052842396707084878e-02,
|
||||
-9.49010970480476444210e-02,
|
||||
+1.71620901522208775349e-01,
|
||||
-3.04682672343198398683e-01,
|
||||
+6.76795274409476084995e-01,
|
||||
};
|
||||
|
||||
static const T B[] = {
|
||||
-7.23318048787475395456e-18,
|
||||
-4.83050448594418207126e-18,
|
||||
+4.46562142029675999901e-17,
|
||||
+3.46122286769746109310e-17,
|
||||
-2.82762398051658348494e-16,
|
||||
-3.42548561967721913462e-16,
|
||||
+1.77256013305652638360e-15,
|
||||
+3.81168066935262242075e-15,
|
||||
-9.55484669882830764870e-15,
|
||||
-4.15056934728722208663e-14,
|
||||
+1.54008621752140982691e-14,
|
||||
+3.85277838274214270114e-13,
|
||||
+7.18012445138366623367e-13,
|
||||
-1.79417853150680611778e-12,
|
||||
-1.32158118404477131188e-11,
|
||||
-3.14991652796324136454e-11,
|
||||
+1.18891471078464383424e-11,
|
||||
+4.94060238822496958910e-10,
|
||||
+3.39623202570838634515e-09,
|
||||
+2.26666899049817806459e-08,
|
||||
+2.04891858946906374183e-07,
|
||||
+2.89137052083475648297e-06,
|
||||
+6.88975834691682398426e-05,
|
||||
+3.36911647825569408990e-03,
|
||||
+8.04490411014108831608e-01,
|
||||
};
|
||||
|
||||
T p;
|
||||
T q = 0.0;
|
||||
|
||||
if (abs(x) <= T(8.0)) {
|
||||
T a = A[0];
|
||||
|
||||
for (uint8_t index = 1; index < 30; index++) {
|
||||
p = q;
|
||||
q = a;
|
||||
a = ((abs(x) / T(2.0)) - T(2.0)) * q - p + A[index];
|
||||
}
|
||||
|
||||
return exp(abs(x)) * (T(0.5) * (a - p));
|
||||
}
|
||||
|
||||
T b = B[0];
|
||||
|
||||
for (uint8_t index = 1; index < 25; index++) {
|
||||
p = q;
|
||||
q = b;
|
||||
b = (T(32.0) / abs(x) - T(2.0)) * q - p + B[index];
|
||||
}
|
||||
|
||||
return exp(abs(x)) * (T(0.5) * (b - p)) / sqrt(abs(x));
|
||||
} // modified_bessel_i0_forward(T x)
|
||||
); // modified_bessel_i0_string
|
||||
|
||||
const auto modified_bessel_i1_string = jiterator_stringify(
|
||||
template<typename T>
|
||||
T modified_bessel_i1_forward(T x) {
|
||||
static const T A[] = {
|
||||
+2.77791411276104639959e-18,
|
||||
-2.11142121435816608115e-17,
|
||||
+1.55363195773620046921e-16,
|
||||
-1.10559694773538630805e-15,
|
||||
+7.60068429473540693410e-15,
|
||||
-5.04218550472791168711e-14,
|
||||
+3.22379336594557470981e-13,
|
||||
-1.98397439776494371520e-12,
|
||||
+1.17361862988909016308e-11,
|
||||
-6.66348972350202774223e-11,
|
||||
+3.62559028155211703701e-10,
|
||||
-1.88724975172282928790e-09,
|
||||
+9.38153738649577178388e-09,
|
||||
-4.44505912879632808065e-08,
|
||||
+2.00329475355213526229e-07,
|
||||
-8.56872026469545474066e-07,
|
||||
+3.47025130813767847674e-06,
|
||||
-1.32731636560394358279e-05,
|
||||
+4.78156510755005422638e-05,
|
||||
-1.61760815825896745588e-04,
|
||||
+5.12285956168575772895e-04,
|
||||
-1.51357245063125314899e-03,
|
||||
+4.15642294431288815669e-03,
|
||||
-1.05640848946261981558e-02,
|
||||
+2.47264490306265168283e-02,
|
||||
-5.29459812080949914269e-02,
|
||||
+1.02643658689847095384e-01,
|
||||
-1.76416518357834055153e-01,
|
||||
+2.52587186443633654823e-01,
|
||||
};
|
||||
|
||||
static const T B[] = {
|
||||
+7.51729631084210481353e-18,
|
||||
+4.41434832307170791151e-18,
|
||||
-4.65030536848935832153e-17,
|
||||
-3.20952592199342395980e-17,
|
||||
+2.96262899764595013876e-16,
|
||||
+3.30820231092092828324e-16,
|
||||
-1.88035477551078244854e-15,
|
||||
-3.81440307243700780478e-15,
|
||||
+1.04202769841288027642e-14,
|
||||
+4.27244001671195135429e-14,
|
||||
-2.10154184277266431302e-14,
|
||||
-4.08355111109219731823e-13,
|
||||
-7.19855177624590851209e-13,
|
||||
+2.03562854414708950722e-12,
|
||||
+1.41258074366137813316e-11,
|
||||
+3.25260358301548823856e-11,
|
||||
-1.89749581235054123450e-11,
|
||||
-5.58974346219658380687e-10,
|
||||
-3.83538038596423702205e-09,
|
||||
-2.63146884688951950684e-08,
|
||||
-2.51223623787020892529e-07,
|
||||
-3.88256480887769039346e-06,
|
||||
-1.10588938762623716291e-04,
|
||||
-9.76109749136146840777e-03,
|
||||
+7.78576235018280120474e-01,
|
||||
};
|
||||
|
||||
T p;
|
||||
T q = 0.0;
|
||||
|
||||
if (abs(x) <= T(8.0)) {
|
||||
T a = A[0];
|
||||
|
||||
for (uint8_t index = 1; index < 29; index++) {
|
||||
p = q;
|
||||
q = a;
|
||||
a = ((abs(x) / T(2.0)) - T(2.0)) * q - p + A[index];
|
||||
}
|
||||
|
||||
if (x < T(0.0)) {
|
||||
return -(T(0.5) * (a - p) * abs(x) * exp(abs(x)));
|
||||
}
|
||||
|
||||
return T(0.5) * (a - p) * abs(x) * exp(abs(x));
|
||||
}
|
||||
|
||||
T b = B[0];
|
||||
|
||||
for (uint8_t index = 1; index < 25; index++) {
|
||||
p = q;
|
||||
q = b;
|
||||
b = (T(32.0) / abs(x) - T(2.0)) * q - p + B[index];
|
||||
}
|
||||
|
||||
if (x < T(0.0)) {
|
||||
return -(exp(abs(x)) * (T(0.5) * (b - p)) / sqrt(abs(x)));
|
||||
}
|
||||
|
||||
return exp(abs(x)) * (T(0.5) * (b - p)) / sqrt(abs(x));
|
||||
} // modified_bessel_i1_forward(T x)
|
||||
); // modified_bessel_i1_string
|
||||
|
||||
const auto modified_bessel_k0_string = modified_bessel_i0_string + jiterator_stringify(
|
||||
template<typename T>
|
||||
T modified_bessel_k0_forward(T x) {
|
||||
static const T A[] = {
|
||||
+1.37446543561352307156e-16,
|
||||
+4.25981614279661018399e-14,
|
||||
+1.03496952576338420167e-11,
|
||||
+1.90451637722020886025e-09,
|
||||
+2.53479107902614945675e-07,
|
||||
+2.28621210311945178607e-05,
|
||||
+1.26461541144692592338e-03,
|
||||
+3.59799365153615016266e-02,
|
||||
+3.44289899924628486886e-01,
|
||||
-5.35327393233902768720e-01,
|
||||
};
|
||||
|
||||
static const T B[] = {
|
||||
+5.30043377268626276149e-18,
|
||||
-1.64758043015242134646e-17,
|
||||
+5.21039150503902756861e-17,
|
||||
-1.67823109680541210385e-16,
|
||||
+5.51205597852431940784e-16,
|
||||
-1.84859337734377901440e-15,
|
||||
+6.34007647740507060557e-15,
|
||||
-2.22751332699166985548e-14,
|
||||
+8.03289077536357521100e-14,
|
||||
-2.98009692317273043925e-13,
|
||||
+1.14034058820847496303e-12,
|
||||
-4.51459788337394416547e-12,
|
||||
+1.85594911495471785253e-11,
|
||||
-7.95748924447710747776e-11,
|
||||
+3.57739728140030116597e-10,
|
||||
-1.69753450938905987466e-09,
|
||||
+8.57403401741422608519e-09,
|
||||
-4.66048989768794782956e-08,
|
||||
+2.76681363944501510342e-07,
|
||||
-1.83175552271911948767e-06,
|
||||
+1.39498137188764993662e-05,
|
||||
-1.28495495816278026384e-04,
|
||||
+1.56988388573005337491e-03,
|
||||
-3.14481013119645005427e-02,
|
||||
+2.44030308206595545468e+00,
|
||||
};
|
||||
|
||||
if (x == T(0.0)) {
|
||||
return INFINITY;
|
||||
}
|
||||
|
||||
if (x < T(0.0)) {
|
||||
return NAN;
|
||||
}
|
||||
|
||||
T p;
|
||||
T q = 0.0;
|
||||
|
||||
if (x <= T(2.0)) {
|
||||
T a = A[0];
|
||||
|
||||
for (uint8_t index = 1; index < 10; index++) {
|
||||
p = q;
|
||||
q = a;
|
||||
a = (x * x - T(2.0)) * q - p + A[index];
|
||||
}
|
||||
|
||||
return T(0.5) * (a - p) - log(0.5 * x) * modified_bessel_i0_forward(x);
|
||||
}
|
||||
|
||||
T b = B[0];
|
||||
|
||||
for (uint8_t index = 1; index < 25; index++) {
|
||||
p = q;
|
||||
q = b;
|
||||
b = (T(8.0) / x - T(2.0)) * q - p + B[index];
|
||||
}
|
||||
|
||||
return exp(-x) * (T(0.5) * (b - p)) / sqrt(x);
|
||||
} // modified_bessel_k0_forward(T x)
|
||||
); // modified_bessel_k0_string
|
||||
|
||||
const auto modified_bessel_k1_string = modified_bessel_i1_string + jiterator_stringify(
|
||||
template<typename T>
|
||||
T modified_bessel_k1_forward(T x) {
|
||||
static const T A[] = {
|
||||
-7.02386347938628759343e-18,
|
||||
-2.42744985051936593393e-15,
|
||||
-6.66690169419932900609e-13,
|
||||
-1.41148839263352776110e-10,
|
||||
-2.21338763073472585583e-08,
|
||||
-2.43340614156596823496e-06,
|
||||
-1.73028895751305206302e-04,
|
||||
-6.97572385963986435018e-03,
|
||||
-1.22611180822657148235e-01,
|
||||
-3.53155960776544875667e-01,
|
||||
+1.52530022733894777053e+00,
|
||||
};
|
||||
|
||||
static const T B[] = {
|
||||
-5.75674448366501715755e-18,
|
||||
+1.79405087314755922667e-17,
|
||||
-5.68946255844285935196e-17,
|
||||
+1.83809354436663880070e-16,
|
||||
-6.05704724837331885336e-16,
|
||||
+2.03870316562433424052e-15,
|
||||
-7.01983709041831346144e-15,
|
||||
+2.47715442448130437068e-14,
|
||||
-8.97670518232499435011e-14,
|
||||
+3.34841966607842919884e-13,
|
||||
-1.28917396095102890680e-12,
|
||||
+5.13963967348173025100e-12,
|
||||
-2.12996783842756842877e-11,
|
||||
+9.21831518760500529508e-11,
|
||||
-4.19035475934189648750e-10,
|
||||
+2.01504975519703286596e-09,
|
||||
-1.03457624656780970260e-08,
|
||||
+5.74108412545004946722e-08,
|
||||
-3.50196060308781257119e-07,
|
||||
+2.40648494783721712015e-06,
|
||||
-1.93619797416608296024e-05,
|
||||
+1.95215518471351631108e-04,
|
||||
-2.85781685962277938680e-03,
|
||||
+1.03923736576817238437e-01,
|
||||
+2.72062619048444266945e+00,
|
||||
};
|
||||
|
||||
if (x == T(0.0)) {
|
||||
return INFINITY;
|
||||
}
|
||||
|
||||
if (x < T(0.0)) {
|
||||
return NAN;
|
||||
}
|
||||
|
||||
T p;
|
||||
T q = 0.0;
|
||||
|
||||
if (x <= T(2.0)) {
|
||||
T a = A[0];
|
||||
|
||||
for (uint8_t index = 1; index < 11; index++) {
|
||||
p = q;
|
||||
q = a;
|
||||
a = (x * x - T(2.0)) * q - p + A[index];
|
||||
}
|
||||
|
||||
return log(T(0.5) * x) * modified_bessel_i1_forward(x) + T(0.5) * (a - p) / x;
|
||||
}
|
||||
|
||||
T b = B[0];
|
||||
|
||||
for (uint8_t index = 1; index < 25; index++) {
|
||||
p = q;
|
||||
q = b;
|
||||
b = (T(8.0) / x - T(2.0)) * q - p + B[index];
|
||||
}
|
||||
|
||||
return exp(-x) * (T(0.5) * (b - p)) / sqrt(x);
|
||||
} // modified_bessel_k1_forward(T x)
|
||||
); // modified_bessel_k1_string
|
||||
|
||||
#else // !AT_USE_JITERATOR() -- kernels must be precompiled
|
||||
|
||||
template <typename scalar_t>
|
||||
|
43
aten/src/ATen/native/cuda/bessel_j0.cu
Normal file
43
aten/src/ATen/native/cuda/bessel_j0.cu
Normal file
@ -0,0 +1,43 @@
|
||||
#define TORCH_ASSERT_NO_OPERATORS
|
||||
|
||||
#include <ATen/native/UnaryOps.h>
|
||||
|
||||
#include <limits>
|
||||
|
||||
#include <ATen/AccumulateType.h>
|
||||
#include <ATen/Dispatch.h>
|
||||
#include <ATen/native/DispatchStub.h>
|
||||
#include <ATen/native/Math.h>
|
||||
#include <ATen/native/TensorIterator.h>
|
||||
#include <ATen/native/cuda/JitLoops.cuh>
|
||||
#include <ATen/native/cuda/Loops.cuh>
|
||||
#include <ATen/native/cuda/Math.cuh>
|
||||
#include <ATen/native/cuda/jit_utils.h>
|
||||
#include <ATen/NumericUtils.h>
|
||||
#include <c10/core/Scalar.h>
|
||||
#include <c10/cuda/CUDAMathCompat.h>
|
||||
#include <c10/util/complex.h>
|
||||
|
||||
namespace at {
|
||||
namespace native {
|
||||
namespace {
|
||||
const char bessel_j0_name[] = "bessel_j0_forward";
|
||||
|
||||
void bessel_j0_kernel_cuda(TensorIteratorBase& iterator) {
|
||||
#if AT_USE_JITERATOR()
|
||||
AT_DISPATCH_FLOATING_TYPES(iterator.common_dtype(), "bessel_j0_cuda", [&]() {
|
||||
jitted_gpu_kernel<bessel_j0_name, scalar_t, scalar_t, 1>(iterator, bessel_j0_string);
|
||||
});
|
||||
#else
|
||||
AT_DISPATCH_FLOATING_TYPES(iterator.common_dtype(), "bessel_j0_cuda", [&]() {
|
||||
gpu_kernel(iterator, []GPU_LAMBDA(scalar_t a) -> scalar_t {
|
||||
return bessel_j0_forward(a);
|
||||
});
|
||||
});
|
||||
#endif // AT_USE_JITERATOR()
|
||||
}
|
||||
}
|
||||
|
||||
REGISTER_DISPATCH(special_bessel_j0_stub, &bessel_j0_kernel_cuda);
|
||||
} // namespace native
|
||||
} // namespace at
|
43
aten/src/ATen/native/cuda/bessel_j1.cu
Normal file
43
aten/src/ATen/native/cuda/bessel_j1.cu
Normal file
@ -0,0 +1,43 @@
|
||||
#define TORCH_ASSERT_NO_OPERATORS
|
||||
|
||||
#include <ATen/native/UnaryOps.h>
|
||||
|
||||
#include <limits>
|
||||
|
||||
#include <ATen/AccumulateType.h>
|
||||
#include <ATen/Dispatch.h>
|
||||
#include <ATen/native/DispatchStub.h>
|
||||
#include <ATen/native/Math.h>
|
||||
#include <ATen/native/TensorIterator.h>
|
||||
#include <ATen/native/cuda/JitLoops.cuh>
|
||||
#include <ATen/native/cuda/Loops.cuh>
|
||||
#include <ATen/native/cuda/Math.cuh>
|
||||
#include <ATen/native/cuda/jit_utils.h>
|
||||
#include <ATen/NumericUtils.h>
|
||||
#include <c10/core/Scalar.h>
|
||||
#include <c10/cuda/CUDAMathCompat.h>
|
||||
#include <c10/util/complex.h>
|
||||
|
||||
namespace at {
|
||||
namespace native {
|
||||
namespace {
|
||||
const char bessel_j1_name[] = "bessel_j1_forward";
|
||||
|
||||
void bessel_j1_kernel_cuda(TensorIteratorBase& iterator) {
|
||||
#if AT_USE_JITERATOR()
|
||||
AT_DISPATCH_FLOATING_TYPES(iterator.common_dtype(), "bessel_j1_cuda", [&]() {
|
||||
jitted_gpu_kernel<bessel_j1_name, scalar_t, scalar_t, 1>(iterator, bessel_j1_string);
|
||||
});
|
||||
#else
|
||||
AT_DISPATCH_FLOATING_TYPES(iterator.common_dtype(), "bessel_j1_cuda", [&]() {
|
||||
gpu_kernel(iterator, []GPU_LAMBDA(scalar_t a) -> scalar_t {
|
||||
return bessel_j1_forward(a);
|
||||
});
|
||||
});
|
||||
#endif // AT_USE_JITERATOR()
|
||||
}
|
||||
}
|
||||
|
||||
REGISTER_DISPATCH(special_bessel_j1_stub, &bessel_j1_kernel_cuda);
|
||||
} // namespace native
|
||||
} // namespace at
|
43
aten/src/ATen/native/cuda/bessel_y0.cu
Normal file
43
aten/src/ATen/native/cuda/bessel_y0.cu
Normal file
@ -0,0 +1,43 @@
|
||||
#define TORCH_ASSERT_NO_OPERATORS
|
||||
|
||||
#include <ATen/native/UnaryOps.h>
|
||||
|
||||
#include <limits>
|
||||
|
||||
#include <ATen/AccumulateType.h>
|
||||
#include <ATen/Dispatch.h>
|
||||
#include <ATen/native/DispatchStub.h>
|
||||
#include <ATen/native/Math.h>
|
||||
#include <ATen/native/TensorIterator.h>
|
||||
#include <ATen/native/cuda/JitLoops.cuh>
|
||||
#include <ATen/native/cuda/Loops.cuh>
|
||||
#include <ATen/native/cuda/Math.cuh>
|
||||
#include <ATen/native/cuda/jit_utils.h>
|
||||
#include <ATen/NumericUtils.h>
|
||||
#include <c10/core/Scalar.h>
|
||||
#include <c10/cuda/CUDAMathCompat.h>
|
||||
#include <c10/util/complex.h>
|
||||
|
||||
namespace at {
|
||||
namespace native {
|
||||
namespace {
|
||||
const char bessel_y0_name[] = "bessel_y0_forward";
|
||||
|
||||
void bessel_y0_kernel_cuda(TensorIteratorBase& iterator) {
|
||||
#if AT_USE_JITERATOR()
|
||||
AT_DISPATCH_FLOATING_TYPES(iterator.common_dtype(), "bessel_y0_cuda", [&]() {
|
||||
jitted_gpu_kernel<bessel_y0_name, scalar_t, scalar_t, 1>(iterator, bessel_y0_string);
|
||||
});
|
||||
#else
|
||||
AT_DISPATCH_FLOATING_TYPES(iterator.common_dtype(), "bessel_y0_cuda", [&]() {
|
||||
gpu_kernel(iterator, []GPU_LAMBDA(scalar_t a) -> scalar_t {
|
||||
return bessel_y0_forward(a);
|
||||
});
|
||||
});
|
||||
#endif // AT_USE_JITERATOR()
|
||||
}
|
||||
}
|
||||
|
||||
REGISTER_DISPATCH(special_bessel_y0_stub, &bessel_y0_kernel_cuda);
|
||||
} // namespace native
|
||||
} // namespace at
|
43
aten/src/ATen/native/cuda/bessel_y1.cu
Normal file
43
aten/src/ATen/native/cuda/bessel_y1.cu
Normal file
@ -0,0 +1,43 @@
|
||||
#define TORCH_ASSERT_NO_OPERATORS
|
||||
|
||||
#include <ATen/native/UnaryOps.h>
|
||||
|
||||
#include <limits>
|
||||
|
||||
#include <ATen/AccumulateType.h>
|
||||
#include <ATen/Dispatch.h>
|
||||
#include <ATen/native/DispatchStub.h>
|
||||
#include <ATen/native/Math.h>
|
||||
#include <ATen/native/TensorIterator.h>
|
||||
#include <ATen/native/cuda/JitLoops.cuh>
|
||||
#include <ATen/native/cuda/Loops.cuh>
|
||||
#include <ATen/native/cuda/Math.cuh>
|
||||
#include <ATen/native/cuda/jit_utils.h>
|
||||
#include <ATen/NumericUtils.h>
|
||||
#include <c10/core/Scalar.h>
|
||||
#include <c10/cuda/CUDAMathCompat.h>
|
||||
#include <c10/util/complex.h>
|
||||
|
||||
namespace at {
|
||||
namespace native {
|
||||
namespace {
|
||||
const char bessel_y1_name[] = "bessel_y1_forward";
|
||||
|
||||
void bessel_y1_kernel_cuda(TensorIteratorBase& iterator) {
|
||||
#if AT_USE_JITERATOR()
|
||||
AT_DISPATCH_FLOATING_TYPES(iterator.common_dtype(), "bessel_y1_cuda", [&]() {
|
||||
jitted_gpu_kernel<bessel_y1_name, scalar_t, scalar_t, 1>(iterator, bessel_y1_string);
|
||||
});
|
||||
#else
|
||||
AT_DISPATCH_FLOATING_TYPES(iterator.common_dtype(), "bessel_y1_cuda", [&]() {
|
||||
gpu_kernel(iterator, []GPU_LAMBDA(scalar_t a) -> scalar_t {
|
||||
return bessel_y1_forward(a);
|
||||
});
|
||||
});
|
||||
#endif // AT_USE_JITERATOR()
|
||||
}
|
||||
}
|
||||
|
||||
REGISTER_DISPATCH(special_bessel_y1_stub, &bessel_y1_kernel_cuda);
|
||||
} // namespace native
|
||||
} // namespace at
|
43
aten/src/ATen/native/cuda/modified_bessel_i0.cu
Normal file
43
aten/src/ATen/native/cuda/modified_bessel_i0.cu
Normal file
@ -0,0 +1,43 @@
|
||||
#define TORCH_ASSERT_NO_OPERATORS
|
||||
|
||||
#include <ATen/native/UnaryOps.h>
|
||||
|
||||
#include <limits>
|
||||
|
||||
#include <ATen/AccumulateType.h>
|
||||
#include <ATen/Dispatch.h>
|
||||
#include <ATen/native/DispatchStub.h>
|
||||
#include <ATen/native/Math.h>
|
||||
#include <ATen/native/TensorIterator.h>
|
||||
#include <ATen/native/cuda/JitLoops.cuh>
|
||||
#include <ATen/native/cuda/Loops.cuh>
|
||||
#include <ATen/native/cuda/Math.cuh>
|
||||
#include <ATen/native/cuda/jit_utils.h>
|
||||
#include <ATen/NumericUtils.h>
|
||||
#include <c10/core/Scalar.h>
|
||||
#include <c10/cuda/CUDAMathCompat.h>
|
||||
#include <c10/util/complex.h>
|
||||
|
||||
namespace at {
|
||||
namespace native {
|
||||
namespace {
|
||||
const char modified_bessel_i0_name[] = "modified_bessel_i0_forward";
|
||||
|
||||
void modified_bessel_i0_kernel_cuda(TensorIteratorBase& iterator) {
|
||||
#if AT_USE_JITERATOR()
|
||||
AT_DISPATCH_FLOATING_TYPES(iterator.common_dtype(), "modified_bessel_i0_cuda", [&]() {
|
||||
jitted_gpu_kernel<modified_bessel_i0_name, scalar_t, scalar_t, 1>(iterator, modified_bessel_i0_string);
|
||||
});
|
||||
#else
|
||||
AT_DISPATCH_FLOATING_TYPES(iterator.common_dtype(), "modified_bessel_i0_cuda", [&]() {
|
||||
gpu_kernel(iterator, []GPU_LAMBDA(scalar_t a) -> scalar_t {
|
||||
return modified_bessel_i0_forward(a);
|
||||
});
|
||||
});
|
||||
#endif // AT_USE_JITERATOR()
|
||||
}
|
||||
}
|
||||
|
||||
REGISTER_DISPATCH(special_modified_bessel_i0_stub, &modified_bessel_i0_kernel_cuda);
|
||||
} // namespace native
|
||||
} // namespace at
|
43
aten/src/ATen/native/cuda/modified_bessel_i1.cu
Normal file
43
aten/src/ATen/native/cuda/modified_bessel_i1.cu
Normal file
@ -0,0 +1,43 @@
|
||||
#define TORCH_ASSERT_NO_OPERATORS
|
||||
|
||||
#include <ATen/native/UnaryOps.h>
|
||||
|
||||
#include <limits>
|
||||
|
||||
#include <ATen/AccumulateType.h>
|
||||
#include <ATen/Dispatch.h>
|
||||
#include <ATen/native/DispatchStub.h>
|
||||
#include <ATen/native/Math.h>
|
||||
#include <ATen/native/TensorIterator.h>
|
||||
#include <ATen/native/cuda/JitLoops.cuh>
|
||||
#include <ATen/native/cuda/Loops.cuh>
|
||||
#include <ATen/native/cuda/Math.cuh>
|
||||
#include <ATen/native/cuda/jit_utils.h>
|
||||
#include <ATen/NumericUtils.h>
|
||||
#include <c10/core/Scalar.h>
|
||||
#include <c10/cuda/CUDAMathCompat.h>
|
||||
#include <c10/util/complex.h>
|
||||
|
||||
namespace at {
|
||||
namespace native {
|
||||
namespace {
|
||||
const char modified_bessel_i1_name[] = "modified_bessel_i1_forward";
|
||||
|
||||
void modified_bessel_i1_kernel_cuda(TensorIteratorBase& iterator) {
|
||||
#if AT_USE_JITERATOR()
|
||||
AT_DISPATCH_FLOATING_TYPES(iterator.common_dtype(), "modified_bessel_i1_cuda", [&]() {
|
||||
jitted_gpu_kernel<modified_bessel_i1_name, scalar_t, scalar_t, 1>(iterator, modified_bessel_i1_string);
|
||||
});
|
||||
#else
|
||||
AT_DISPATCH_FLOATING_TYPES(iterator.common_dtype(), "modified_bessel_i1_cuda", [&]() {
|
||||
gpu_kernel(iterator, []GPU_LAMBDA(scalar_t a) -> scalar_t {
|
||||
return modified_bessel_i1_forward(a);
|
||||
});
|
||||
});
|
||||
#endif // AT_USE_JITERATOR()
|
||||
}
|
||||
}
|
||||
|
||||
REGISTER_DISPATCH(special_modified_bessel_i1_stub, &modified_bessel_i1_kernel_cuda);
|
||||
} // namespace native
|
||||
} // namespace at
|
43
aten/src/ATen/native/cuda/modified_bessel_k0.cu
Normal file
43
aten/src/ATen/native/cuda/modified_bessel_k0.cu
Normal file
@ -0,0 +1,43 @@
|
||||
#define TORCH_ASSERT_NO_OPERATORS
|
||||
|
||||
#include <ATen/native/UnaryOps.h>
|
||||
|
||||
#include <limits>
|
||||
|
||||
#include <ATen/AccumulateType.h>
|
||||
#include <ATen/Dispatch.h>
|
||||
#include <ATen/native/DispatchStub.h>
|
||||
#include <ATen/native/Math.h>
|
||||
#include <ATen/native/TensorIterator.h>
|
||||
#include <ATen/native/cuda/JitLoops.cuh>
|
||||
#include <ATen/native/cuda/Loops.cuh>
|
||||
#include <ATen/native/cuda/Math.cuh>
|
||||
#include <ATen/native/cuda/jit_utils.h>
|
||||
#include <ATen/NumericUtils.h>
|
||||
#include <c10/core/Scalar.h>
|
||||
#include <c10/cuda/CUDAMathCompat.h>
|
||||
#include <c10/util/complex.h>
|
||||
|
||||
namespace at {
|
||||
namespace native {
|
||||
namespace {
|
||||
const char modified_bessel_k0_name[] = "modified_bessel_k0_forward";
|
||||
|
||||
void modified_bessel_k0_kernel_cuda(TensorIteratorBase& iterator) {
|
||||
#if AT_USE_JITERATOR()
|
||||
AT_DISPATCH_FLOATING_TYPES(iterator.common_dtype(), "modified_bessel_k0_cuda", [&]() {
|
||||
jitted_gpu_kernel<modified_bessel_k0_name, scalar_t, scalar_t, 1>(iterator, modified_bessel_k0_string);
|
||||
});
|
||||
#else
|
||||
AT_DISPATCH_FLOATING_TYPES(iterator.common_dtype(), "modified_bessel_k0_cuda", [&]() {
|
||||
gpu_kernel(iterator, []GPU_LAMBDA(scalar_t a) -> scalar_t {
|
||||
return modified_bessel_k0_forward(a);
|
||||
});
|
||||
});
|
||||
#endif // AT_USE_JITERATOR()
|
||||
}
|
||||
}
|
||||
|
||||
REGISTER_DISPATCH(special_modified_bessel_k0_stub, &modified_bessel_k0_kernel_cuda);
|
||||
} // namespace native
|
||||
} // namespace at
|
43
aten/src/ATen/native/cuda/modified_bessel_k1.cu
Normal file
43
aten/src/ATen/native/cuda/modified_bessel_k1.cu
Normal file
@ -0,0 +1,43 @@
|
||||
#define TORCH_ASSERT_NO_OPERATORS
|
||||
|
||||
#include <ATen/native/UnaryOps.h>
|
||||
|
||||
#include <limits>
|
||||
|
||||
#include <ATen/AccumulateType.h>
|
||||
#include <ATen/Dispatch.h>
|
||||
#include <ATen/native/DispatchStub.h>
|
||||
#include <ATen/native/Math.h>
|
||||
#include <ATen/native/TensorIterator.h>
|
||||
#include <ATen/native/cuda/JitLoops.cuh>
|
||||
#include <ATen/native/cuda/Loops.cuh>
|
||||
#include <ATen/native/cuda/Math.cuh>
|
||||
#include <ATen/native/cuda/jit_utils.h>
|
||||
#include <ATen/NumericUtils.h>
|
||||
#include <c10/core/Scalar.h>
|
||||
#include <c10/cuda/CUDAMathCompat.h>
|
||||
#include <c10/util/complex.h>
|
||||
|
||||
namespace at {
|
||||
namespace native {
|
||||
namespace {
|
||||
const char modified_bessel_k1_name[] = "modified_bessel_k1_forward";
|
||||
|
||||
void modified_bessel_k1_kernel_cuda(TensorIteratorBase& iterator) {
|
||||
#if AT_USE_JITERATOR()
|
||||
AT_DISPATCH_FLOATING_TYPES(iterator.common_dtype(), "modified_bessel_k1_cuda", [&]() {
|
||||
jitted_gpu_kernel<modified_bessel_k1_name, scalar_t, scalar_t, 1>(iterator, modified_bessel_k1_string);
|
||||
});
|
||||
#else
|
||||
AT_DISPATCH_FLOATING_TYPES(iterator.common_dtype(), "modified_bessel_k1_cuda", [&]() {
|
||||
gpu_kernel(iterator, []GPU_LAMBDA(scalar_t a) -> scalar_t {
|
||||
return modified_bessel_k1_forward(a);
|
||||
});
|
||||
});
|
||||
#endif // AT_USE_JITERATOR()
|
||||
}
|
||||
}
|
||||
|
||||
REGISTER_DISPATCH(special_modified_bessel_k1_stub, &modified_bessel_k1_kernel_cuda);
|
||||
} // namespace native
|
||||
} // namespace at
|
@ -12404,6 +12404,58 @@
|
||||
dispatch:
|
||||
CPU, CUDA, NestedTensorCPU, NestedTensorCUDA: native_multi_head_attention
|
||||
|
||||
- func: special_bessel_j0(Tensor self) -> Tensor
|
||||
python_module: special
|
||||
structured_delegate: special_bessel_j0.out
|
||||
variants: function
|
||||
|
||||
- func: special_bessel_j0.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!)
|
||||
dispatch:
|
||||
CPU, CUDA: special_bessel_j0_out
|
||||
python_module: special
|
||||
structured_inherits: TensorIteratorBase
|
||||
structured: True
|
||||
variants: function
|
||||
|
||||
- func: special_bessel_j1(Tensor self) -> Tensor
|
||||
python_module: special
|
||||
structured_delegate: special_bessel_j1.out
|
||||
variants: function
|
||||
|
||||
- func: special_bessel_j1.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!)
|
||||
dispatch:
|
||||
CPU, CUDA: special_bessel_j1_out
|
||||
python_module: special
|
||||
structured_inherits: TensorIteratorBase
|
||||
structured: True
|
||||
variants: function
|
||||
|
||||
- func: special_bessel_y0(Tensor self) -> Tensor
|
||||
python_module: special
|
||||
structured_delegate: special_bessel_y0.out
|
||||
variants: function
|
||||
|
||||
- func: special_bessel_y0.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!)
|
||||
dispatch:
|
||||
CPU, CUDA: special_bessel_y0_out
|
||||
python_module: special
|
||||
structured_inherits: TensorIteratorBase
|
||||
structured: True
|
||||
variants: function
|
||||
|
||||
- func: special_bessel_y1(Tensor self) -> Tensor
|
||||
python_module: special
|
||||
structured_delegate: special_bessel_y1.out
|
||||
variants: function
|
||||
|
||||
- func: special_bessel_y1.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!)
|
||||
dispatch:
|
||||
CPU, CUDA: special_bessel_y1_out
|
||||
python_module: special
|
||||
structured_inherits: TensorIteratorBase
|
||||
structured: True
|
||||
variants: function
|
||||
|
||||
- func: special_chebyshev_polynomial_t(Tensor x, Tensor n) -> Tensor
|
||||
device_check: NoCheck
|
||||
python_module: special
|
||||
@ -12588,3 +12640,55 @@
|
||||
device_check: NoCheck
|
||||
python_module: special
|
||||
variants: function
|
||||
|
||||
- func: special_modified_bessel_i0(Tensor self) -> Tensor
|
||||
python_module: special
|
||||
structured_delegate: special_modified_bessel_i0.out
|
||||
variants: function
|
||||
|
||||
- func: special_modified_bessel_i0.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!)
|
||||
dispatch:
|
||||
CPU, CUDA: special_modified_bessel_i0_out
|
||||
python_module: special
|
||||
structured_inherits: TensorIteratorBase
|
||||
structured: True
|
||||
variants: function
|
||||
|
||||
- func: special_modified_bessel_i1(Tensor self) -> Tensor
|
||||
python_module: special
|
||||
structured_delegate: special_modified_bessel_i1.out
|
||||
variants: function
|
||||
|
||||
- func: special_modified_bessel_i1.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!)
|
||||
dispatch:
|
||||
CPU, CUDA: special_modified_bessel_i1_out
|
||||
python_module: special
|
||||
structured_inherits: TensorIteratorBase
|
||||
structured: True
|
||||
variants: function
|
||||
|
||||
- func: special_modified_bessel_k0(Tensor self) -> Tensor
|
||||
python_module: special
|
||||
structured_delegate: special_modified_bessel_k0.out
|
||||
variants: function
|
||||
|
||||
- func: special_modified_bessel_k0.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!)
|
||||
dispatch:
|
||||
CPU, CUDA: special_modified_bessel_k0_out
|
||||
python_module: special
|
||||
structured_inherits: TensorIteratorBase
|
||||
structured: True
|
||||
variants: function
|
||||
|
||||
- func: special_modified_bessel_k1(Tensor self) -> Tensor
|
||||
python_module: special
|
||||
structured_delegate: special_modified_bessel_k1.out
|
||||
variants: function
|
||||
|
||||
- func: special_modified_bessel_k1.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!)
|
||||
dispatch:
|
||||
CPU, CUDA: special_modified_bessel_k1_out
|
||||
python_module: special
|
||||
structured_inherits: TensorIteratorBase
|
||||
structured: True
|
||||
variants: function
|
||||
|
@ -2955,6 +2955,10 @@
|
||||
"softmax"
|
||||
],
|
||||
"torch.special": [
|
||||
"bessel_j0",
|
||||
"bessel_j1",
|
||||
"bessel_y0",
|
||||
"bessel_y1",
|
||||
"chebyshev_polynomial_t",
|
||||
"chebyshev_polynomial_u",
|
||||
"digamma",
|
||||
@ -2981,6 +2985,10 @@
|
||||
"log_softmax",
|
||||
"logit",
|
||||
"logsumexp",
|
||||
"modified_bessel_i0",
|
||||
"modified_bessel_i1",
|
||||
"modified_bessel_k0",
|
||||
"modified_bessel_k1",
|
||||
"multigammaln",
|
||||
"ndtr",
|
||||
"ndtri",
|
||||
|
@ -2757,6 +2757,18 @@
|
||||
self, src: scatter_reduce_backward(grad, self, dim, index, src, reduce, include_self, result)
|
||||
index: non_differentiable
|
||||
|
||||
- name: special_bessel_j0(Tensor self) -> Tensor
|
||||
self: non_differentiable
|
||||
|
||||
- name: special_bessel_j1(Tensor self) -> Tensor
|
||||
self: non_differentiable
|
||||
|
||||
- name: special_bessel_y0(Tensor self) -> Tensor
|
||||
self: non_differentiable
|
||||
|
||||
- name: special_bessel_y1(Tensor self) -> Tensor
|
||||
self: non_differentiable
|
||||
|
||||
- name: special_chebyshev_polynomial_t(Tensor x, Tensor n) -> Tensor
|
||||
x: non_differentiable
|
||||
n: non_differentiable
|
||||
@ -2806,3 +2818,15 @@
|
||||
|
||||
- name: special_laguerre_polynomial_l.n_scalar(Tensor x, Scalar n) -> Tensor
|
||||
x: non_differentiable
|
||||
|
||||
- name: special_modified_bessel_i0(Tensor self) -> Tensor
|
||||
self: non_differentiable
|
||||
|
||||
- name: special_modified_bessel_i1(Tensor self) -> Tensor
|
||||
self: non_differentiable
|
||||
|
||||
- name: special_modified_bessel_k0(Tensor self) -> Tensor
|
||||
self: non_differentiable
|
||||
|
||||
- name: special_modified_bessel_k1(Tensor self) -> Tensor
|
||||
self: non_differentiable
|
||||
|
@ -565,6 +565,82 @@ inline Tensor softmax(const Tensor& self, int64_t dim, c10::optional<ScalarType>
|
||||
return torch::special_softmax(self, dim, dtype);
|
||||
}
|
||||
|
||||
/// Bessel function of the first kind of order 0.
|
||||
///
|
||||
/// See https://pytorch.org/docs/master/special.html#torch.special.bessel_j0.
|
||||
///
|
||||
/// Example:
|
||||
///
|
||||
/// ```
|
||||
/// auto x = torch::randn(128, dtype=kDouble);
|
||||
///
|
||||
/// torch::special::bessel_j0(x);
|
||||
/// ```
|
||||
inline Tensor bessel_j0(const Tensor& self) {
|
||||
return torch::special_bessel_j0(self);
|
||||
}
|
||||
|
||||
inline Tensor& bessel_j0_out(Tensor& result, const Tensor& self) {
|
||||
return torch::special_bessel_j0_out(result, self);
|
||||
}
|
||||
|
||||
/// Bessel function of the first kind of order 1.
|
||||
///
|
||||
/// See https://pytorch.org/docs/master/special.html#torch.special.bessel_j1.
|
||||
///
|
||||
/// Example:
|
||||
///
|
||||
/// ```
|
||||
/// auto x = torch::randn(128, dtype=kDouble);
|
||||
///
|
||||
/// torch::special::bessel_j1(x);
|
||||
/// ```
|
||||
inline Tensor bessel_j1(const Tensor& self) {
|
||||
return torch::special_bessel_j1(self);
|
||||
}
|
||||
|
||||
inline Tensor& bessel_j1_out(Tensor& result, const Tensor& self) {
|
||||
return torch::special_bessel_j1_out(result, self);
|
||||
}
|
||||
|
||||
/// Bessel function of the second kind of order 0.
|
||||
///
|
||||
/// See https://pytorch.org/docs/master/special.html#torch.special.bessel_y0.
|
||||
///
|
||||
/// Example:
|
||||
///
|
||||
/// ```
|
||||
/// auto x = torch::randn(128, dtype=kDouble);
|
||||
///
|
||||
/// torch::special::bessel_y0(x);
|
||||
/// ```
|
||||
inline Tensor bessel_y0(const Tensor& self) {
|
||||
return torch::special_bessel_y0(self);
|
||||
}
|
||||
|
||||
inline Tensor& bessel_y0_out(Tensor& result, const Tensor& self) {
|
||||
return torch::special_bessel_y0_out(result, self);
|
||||
}
|
||||
|
||||
/// Bessel function of the second kind of order 1.
|
||||
///
|
||||
/// See https://pytorch.org/docs/master/special.html#torch.special.bessel_y1.
|
||||
///
|
||||
/// Example:
|
||||
///
|
||||
/// ```
|
||||
/// auto x = torch::randn(128, dtype=kDouble);
|
||||
///
|
||||
/// torch::special::bessel_y1(x);
|
||||
/// ```
|
||||
inline Tensor bessel_y1(const Tensor& self) {
|
||||
return torch::special_bessel_y1(self);
|
||||
}
|
||||
|
||||
inline Tensor& bessel_y1_out(Tensor& result, const Tensor& self) {
|
||||
return torch::special_bessel_y1_out(result, self);
|
||||
}
|
||||
|
||||
/// Chebyshev polynomial of the first kind.
|
||||
///
|
||||
/// See https://pytorch.org/docs/master/special.html#torch.special.chebyshev_polynomial_t.
|
||||
@ -745,4 +821,80 @@ inline Tensor& laguerre_polynomial_l_out(Tensor& output, const Tensor& x, const
|
||||
return torch::special_laguerre_polynomial_l_out(output, x, n);
|
||||
}
|
||||
|
||||
/// Modified Bessel function of the first kind of order 0.
|
||||
///
|
||||
/// See https://pytorch.org/docs/master/special.html#torch.special.modified_bessel_i0.
|
||||
///
|
||||
/// Example:
|
||||
///
|
||||
/// ```
|
||||
/// auto x = torch::randn(128, dtype=kDouble);
|
||||
///
|
||||
/// torch::special::modified_bessel_i0(x);
|
||||
/// ```
|
||||
inline Tensor modified_bessel_i0(const Tensor& self) {
|
||||
return torch::special_modified_bessel_i0(self);
|
||||
}
|
||||
|
||||
inline Tensor& modified_bessel_i0_out(Tensor& result, const Tensor& self) {
|
||||
return torch::special_modified_bessel_i0_out(result, self);
|
||||
}
|
||||
|
||||
/// Modified Bessel function of the first kind of order 1.
|
||||
///
|
||||
/// See https://pytorch.org/docs/master/special.html#torch.special.modified_bessel_i1.
|
||||
///
|
||||
/// Example:
|
||||
///
|
||||
/// ```
|
||||
/// auto x = torch::randn(128, dtype=kDouble);
|
||||
///
|
||||
/// torch::special::modified_bessel_i1(x);
|
||||
/// ```
|
||||
inline Tensor modified_bessel_i1(const Tensor& self) {
|
||||
return torch::special_modified_bessel_i1(self);
|
||||
}
|
||||
|
||||
inline Tensor& modified_bessel_i1_out(Tensor& result, const Tensor& self) {
|
||||
return torch::special_modified_bessel_i1_out(result, self);
|
||||
}
|
||||
|
||||
/// Modified Bessel function of the second kind of order 0.
|
||||
///
|
||||
/// See https://pytorch.org/docs/master/special.html#torch.special.modified_bessel_k0.
|
||||
///
|
||||
/// Example:
|
||||
///
|
||||
/// ```
|
||||
/// auto x = torch::randn(128, dtype=kDouble);
|
||||
///
|
||||
/// torch::special::modified_bessel_k0(x);
|
||||
/// ```
|
||||
inline Tensor modified_bessel_k0(const Tensor& self) {
|
||||
return torch::special_modified_bessel_k0(self);
|
||||
}
|
||||
|
||||
inline Tensor& modified_bessel_k0_out(Tensor& result, const Tensor& self) {
|
||||
return torch::special_modified_bessel_k0_out(result, self);
|
||||
}
|
||||
|
||||
/// Modified Bessel function of the second kind of order 1.
|
||||
///
|
||||
/// See https://pytorch.org/docs/master/special.html#torch.special.modified_bessel_k1.
|
||||
///
|
||||
/// Example:
|
||||
///
|
||||
/// ```
|
||||
/// auto x = torch::randn(128, dtype=kDouble);
|
||||
///
|
||||
/// torch::special::modified_bessel_k1(x);
|
||||
/// ```
|
||||
inline Tensor modified_bessel_k1(const Tensor& self) {
|
||||
return torch::special_modified_bessel_k1(self);
|
||||
}
|
||||
|
||||
inline Tensor& modified_bessel_k1_out(Tensor& result, const Tensor& self) {
|
||||
return torch::special_modified_bessel_k1_out(result, self);
|
||||
}
|
||||
|
||||
}} // torch::special
|
||||
|
@ -986,6 +986,10 @@ def get_testing_overrides() -> Dict[Callable, Callable]:
|
||||
torch.symeig: lambda input, eigenvectors=False, upper=True, out=None: -1,
|
||||
torch.swapaxes: lambda input, dim0, dim1: -1,
|
||||
torch.swapdims: lambda input, axis0, axis1: -1,
|
||||
torch.special.bessel_j0: lambda input: -1,
|
||||
torch.special.bessel_j1: lambda input: -1,
|
||||
torch.special.bessel_y0: lambda input: -1,
|
||||
torch.special.bessel_y1: lambda input: -1,
|
||||
torch.special.chebyshev_polynomial_t: lambda input, n, out=None: -1,
|
||||
torch.special.chebyshev_polynomial_u: lambda input, n, out=None: -1,
|
||||
torch.special.digamma: lambda input: -1,
|
||||
@ -1012,6 +1016,10 @@ def get_testing_overrides() -> Dict[Callable, Callable]:
|
||||
torch.special.log_softmax: lambda input, dim, dtype=None: -1,
|
||||
torch.special.logit: lambda input: -1,
|
||||
torch.special.logsumexp: lambda input, dim, keepdim=False, out=None: -1,
|
||||
torch.special.modified_bessel_i0: lambda input: -1,
|
||||
torch.special.modified_bessel_i1: lambda input: -1,
|
||||
torch.special.modified_bessel_k0: lambda input: -1,
|
||||
torch.special.modified_bessel_k1: lambda input: -1,
|
||||
torch.special.multigammaln: lambda input, p: -1,
|
||||
torch.special.ndtr: lambda input: -1,
|
||||
torch.special.ndtri: lambda input: -1,
|
||||
|
@ -3,6 +3,10 @@ from torch._C import _add_docstr, _special # type: ignore[attr-defined]
|
||||
from torch._torch_docs import common_args, multi_dim_common
|
||||
|
||||
__all__ = [
|
||||
'bessel_j0',
|
||||
'bessel_j1',
|
||||
'bessel_y0',
|
||||
'bessel_y1',
|
||||
'chebyshev_polynomial_t',
|
||||
'chebyshev_polynomial_u',
|
||||
'digamma',
|
||||
@ -24,11 +28,15 @@ __all__ = [
|
||||
'i1',
|
||||
'i1e',
|
||||
'laguerre_polynomial_l',
|
||||
'log1p',
|
||||
'log_ndtr',
|
||||
'log_softmax',
|
||||
'log1p',
|
||||
'logit',
|
||||
'logsumexp',
|
||||
'modified_bessel_i0',
|
||||
'modified_bessel_i1',
|
||||
'modified_bessel_k0',
|
||||
'modified_bessel_k1',
|
||||
'multigammaln',
|
||||
'ndtr',
|
||||
'ndtri',
|
||||
@ -856,6 +864,62 @@ Example::
|
||||
|
||||
""".format(**common_args))
|
||||
|
||||
bessel_j0 = _add_docstr(_special.special_bessel_j0,
|
||||
r"""
|
||||
bessel_j0(input, *, out=None) -> Tensor
|
||||
|
||||
Bessel function of the first kind of order :math:`0`.
|
||||
|
||||
""" + r"""
|
||||
Args:
|
||||
{input}
|
||||
|
||||
Keyword args:
|
||||
{out}
|
||||
""".format(**common_args))
|
||||
|
||||
bessel_j1 = _add_docstr(_special.special_bessel_j1,
|
||||
r"""
|
||||
bessel_j1(input, *, out=None) -> Tensor
|
||||
|
||||
Bessel function of the first kind of order :math:`1`.
|
||||
|
||||
""" + r"""
|
||||
Args:
|
||||
{input}
|
||||
|
||||
Keyword args:
|
||||
{out}
|
||||
""".format(**common_args))
|
||||
|
||||
bessel_y0 = _add_docstr(_special.special_bessel_y0,
|
||||
r"""
|
||||
bessel_y0(input, *, out=None) -> Tensor
|
||||
|
||||
Bessel function of the second kind of order :math:`0`.
|
||||
|
||||
""" + r"""
|
||||
Args:
|
||||
{input}
|
||||
|
||||
Keyword args:
|
||||
{out}
|
||||
""".format(**common_args))
|
||||
|
||||
bessel_y1 = _add_docstr(_special.special_bessel_y1,
|
||||
r"""
|
||||
bessel_y1(input, *, out=None) -> Tensor
|
||||
|
||||
Bessel function of the second kind of order :math:`1`.
|
||||
|
||||
""" + r"""
|
||||
Args:
|
||||
{input}
|
||||
|
||||
Keyword args:
|
||||
{out}
|
||||
""".format(**common_args))
|
||||
|
||||
chebyshev_polynomial_t = _add_docstr(_special.special_chebyshev_polynomial_t,
|
||||
r"""
|
||||
chebyshev_polynomial_t(input, n, *, out=None) -> Tensor
|
||||
@ -976,3 +1040,59 @@ Args:
|
||||
Keyword args:
|
||||
{out}
|
||||
""".format(**common_args))
|
||||
|
||||
modified_bessel_i0 = _add_docstr(_special.special_modified_bessel_i0,
|
||||
r"""
|
||||
modified_bessel_i0(input, *, out=None) -> Tensor
|
||||
|
||||
Modified Bessel function of the first kind of order :math:`0`.
|
||||
|
||||
""" + r"""
|
||||
Args:
|
||||
{input}
|
||||
|
||||
Keyword args:
|
||||
{out}
|
||||
""".format(**common_args))
|
||||
|
||||
modified_bessel_i1 = _add_docstr(_special.special_modified_bessel_i1,
|
||||
r"""
|
||||
modified_bessel_i1(input, *, out=None) -> Tensor
|
||||
|
||||
Modified Bessel function of the first kind of order :math:`1`.
|
||||
|
||||
""" + r"""
|
||||
Args:
|
||||
{input}
|
||||
|
||||
Keyword args:
|
||||
{out}
|
||||
""".format(**common_args))
|
||||
|
||||
modified_bessel_k0 = _add_docstr(_special.special_modified_bessel_k0,
|
||||
r"""
|
||||
modified_bessel_k0(input, *, out=None) -> Tensor
|
||||
|
||||
Modified Bessel function of the second kind of order :math:`0`.
|
||||
|
||||
""" + r"""
|
||||
Args:
|
||||
{input}
|
||||
|
||||
Keyword args:
|
||||
{out}
|
||||
""".format(**common_args))
|
||||
|
||||
modified_bessel_k1 = _add_docstr(_special.special_modified_bessel_k1,
|
||||
r"""
|
||||
modified_bessel_k1(input, *, out=None) -> Tensor
|
||||
|
||||
Modified Bessel function of the second kind of order :math:`1`.
|
||||
|
||||
""" + r"""
|
||||
Args:
|
||||
{input}
|
||||
|
||||
Keyword args:
|
||||
{out}
|
||||
""".format(**common_args))
|
||||
|
@ -19084,6 +19084,62 @@ op_db: List[OpInfo] = [
|
||||
dtypesIfCUDA=floating_types_and(torch.float16, torch.bfloat16),
|
||||
sample_inputs_func=sample_inputs_scatter_reduce,
|
||||
),
|
||||
UnaryUfuncInfo(
|
||||
'special.bessel_j0',
|
||||
decorators=(
|
||||
precisionOverride(
|
||||
{
|
||||
torch.float32: 1e-04,
|
||||
torch.float64: 1e-05,
|
||||
},
|
||||
),
|
||||
),
|
||||
dtypes=all_types_and(torch.bool),
|
||||
ref=scipy.special.j0 if TEST_SCIPY else _NOTHING,
|
||||
supports_autograd=False,
|
||||
),
|
||||
UnaryUfuncInfo(
|
||||
'special.bessel_j1',
|
||||
decorators=(
|
||||
precisionOverride(
|
||||
{
|
||||
torch.float32: 1e-04,
|
||||
torch.float64: 1e-05,
|
||||
},
|
||||
),
|
||||
),
|
||||
dtypes=all_types_and(torch.bool),
|
||||
ref=scipy.special.j1 if TEST_SCIPY else _NOTHING,
|
||||
supports_autograd=False,
|
||||
),
|
||||
UnaryUfuncInfo(
|
||||
'special.bessel_y0',
|
||||
decorators=(
|
||||
precisionOverride(
|
||||
{
|
||||
torch.float32: 1e-04,
|
||||
torch.float64: 1e-05,
|
||||
},
|
||||
),
|
||||
),
|
||||
dtypes=all_types_and(torch.bool),
|
||||
ref=scipy.special.y0 if TEST_SCIPY else _NOTHING,
|
||||
supports_autograd=False,
|
||||
),
|
||||
UnaryUfuncInfo(
|
||||
'special.bessel_y1',
|
||||
decorators=(
|
||||
precisionOverride(
|
||||
{
|
||||
torch.float32: 1e-04,
|
||||
torch.float64: 1e-05,
|
||||
},
|
||||
),
|
||||
),
|
||||
dtypes=all_types_and(torch.bool),
|
||||
ref=scipy.special.y1 if TEST_SCIPY else _NOTHING,
|
||||
supports_autograd=False,
|
||||
),
|
||||
BinaryUfuncInfo(
|
||||
'special.chebyshev_polynomial_t',
|
||||
dtypes=all_types_and(torch.bool),
|
||||
@ -19139,6 +19195,62 @@ op_db: List[OpInfo] = [
|
||||
supports_one_python_scalar=True,
|
||||
supports_autograd=False,
|
||||
),
|
||||
UnaryUfuncInfo(
|
||||
'special.modified_bessel_i0',
|
||||
decorators=(
|
||||
precisionOverride(
|
||||
{
|
||||
torch.float32: 1e-03,
|
||||
torch.float64: 1e-05,
|
||||
},
|
||||
),
|
||||
),
|
||||
dtypes=all_types_and(torch.bool),
|
||||
ref=scipy.special.i0 if TEST_SCIPY else _NOTHING,
|
||||
supports_autograd=False,
|
||||
),
|
||||
UnaryUfuncInfo(
|
||||
'special.modified_bessel_i1',
|
||||
decorators=(
|
||||
precisionOverride(
|
||||
{
|
||||
torch.float32: 1e-03,
|
||||
torch.float64: 1e-05,
|
||||
},
|
||||
),
|
||||
),
|
||||
dtypes=all_types_and(torch.bool),
|
||||
ref=scipy.special.i1 if TEST_SCIPY else _NOTHING,
|
||||
supports_autograd=False,
|
||||
),
|
||||
UnaryUfuncInfo(
|
||||
'special.modified_bessel_k0',
|
||||
decorators=(
|
||||
precisionOverride(
|
||||
{
|
||||
torch.float32: 1e-03,
|
||||
torch.float64: 1e-05,
|
||||
},
|
||||
),
|
||||
),
|
||||
dtypes=all_types_and(torch.bool),
|
||||
ref=scipy.special.k0 if TEST_SCIPY else _NOTHING,
|
||||
supports_autograd=False,
|
||||
),
|
||||
UnaryUfuncInfo(
|
||||
'special.modified_bessel_k1',
|
||||
decorators=(
|
||||
precisionOverride(
|
||||
{
|
||||
torch.float32: 1e-03,
|
||||
torch.float64: 1e-05,
|
||||
},
|
||||
),
|
||||
),
|
||||
dtypes=all_types_and(torch.bool),
|
||||
ref=scipy.special.k1 if TEST_SCIPY else _NOTHING,
|
||||
supports_autograd=False,
|
||||
),
|
||||
]
|
||||
|
||||
# NOTE [Python References]
|
||||
|
Reference in New Issue
Block a user