mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
[MPS] Implement support for modified_bessel_i1
in eager. (#149368)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/149368 Approved by: https://github.com/malfet Co-authored-by: Nikita Shulga <2453524+malfet@users.noreply.github.com>
This commit is contained in:
committed by
PyTorch MergeBot
parent
bb42e4d137
commit
c43e35d6f7
@ -6,6 +6,7 @@ using namespace metal;
|
||||
DEFINE_UNARY_FLOATING_FUNCTOR(bessel_j0_forward);
|
||||
DEFINE_UNARY_FLOATING_FUNCTOR(bessel_j1_forward);
|
||||
DEFINE_UNARY_FLOATING_FUNCTOR(modified_bessel_i0_forward);
|
||||
DEFINE_UNARY_FLOATING_FUNCTOR(modified_bessel_i1_forward);
|
||||
DEFINE_UNARY_FLOATING_FUNCTOR(i0);
|
||||
DEFINE_UNARY_FLOATING_FUNCTOR(i0e);
|
||||
DEFINE_UNARY_FLOATING_FUNCTOR(i1);
|
||||
@ -49,6 +50,7 @@ struct bessel_y1_forward_functor {
|
||||
REGISTER_UNARY_OP(bessel_j0_forward, DTI, DTO); \
|
||||
REGISTER_UNARY_OP(bessel_j1_forward, DTI, DTO); \
|
||||
REGISTER_UNARY_OP(modified_bessel_i0_forward, DTI, DTO); \
|
||||
REGISTER_UNARY_OP(modified_bessel_i1_forward, DTI, DTO); \
|
||||
REGISTER_UNARY_OP(bessel_y0_forward, DTI, DTO); \
|
||||
REGISTER_UNARY_OP(bessel_y1_forward, DTI, DTO); \
|
||||
REGISTER_UNARY_OP(i0, DTI, DTO); \
|
||||
|
@ -48,6 +48,10 @@ static void modified_bessel_i0_kernel_mps(TensorIteratorBase& iter) {
|
||||
lib.exec_unary_kernel(iter, "modified_bessel_i0_forward");
|
||||
}
|
||||
|
||||
static void modified_bessel_i1_kernel_mps(TensorIteratorBase& iter) {
|
||||
lib.exec_unary_kernel(iter, "modified_bessel_i1_forward");
|
||||
}
|
||||
|
||||
static void bessel_y0_kernel_mps(TensorIteratorBase& iter) {
|
||||
lib.exec_unary_kernel(iter, "bessel_y0_forward");
|
||||
}
|
||||
@ -63,6 +67,7 @@ REGISTER_DISPATCH(special_i1e_stub, &i1e_kernel_mps)
|
||||
REGISTER_DISPATCH(special_bessel_j0_stub, &bessel_j0_kernel_mps)
|
||||
REGISTER_DISPATCH(special_bessel_j1_stub, &bessel_j1_kernel_mps)
|
||||
REGISTER_DISPATCH(special_modified_bessel_i0_stub, &modified_bessel_i0_kernel_mps)
|
||||
REGISTER_DISPATCH(special_modified_bessel_i1_stub, &modified_bessel_i1_kernel_mps)
|
||||
REGISTER_DISPATCH(special_bessel_y0_stub, &bessel_y0_kernel_mps)
|
||||
REGISTER_DISPATCH(special_bessel_y1_stub, &bessel_y1_kernel_mps)
|
||||
REGISTER_DISPATCH(special_spherical_bessel_j0_stub, &spherical_bessel_j0_kernel_mps)
|
||||
|
@ -15454,7 +15454,7 @@
|
||||
|
||||
- func: special_modified_bessel_i1.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!)
|
||||
dispatch:
|
||||
CPU, CUDA: special_modified_bessel_i1_out
|
||||
CPU, CUDA, MPS: special_modified_bessel_i1_out
|
||||
python_module: special
|
||||
structured_inherits: TensorIteratorBase
|
||||
structured: True
|
||||
|
@ -1185,5 +1185,74 @@ inline float modified_bessel_i0_forward(T x) {
|
||||
::metal::precise::sqrt(::metal::fabs(x));
|
||||
} // modified_bessel_i0_forward(T x)
|
||||
|
||||
template <typename T>
|
||||
inline float modified_bessel_i1_forward(T x) {
|
||||
constexpr float A[] = {
|
||||
+2.77791411276104639959e-18, -2.11142121435816608115e-17,
|
||||
+1.55363195773620046921e-16, -1.10559694773538630805e-15,
|
||||
+7.60068429473540693410e-15, -5.04218550472791168711e-14,
|
||||
+3.22379336594557470981e-13, -1.98397439776494371520e-12,
|
||||
+1.17361862988909016308e-11, -6.66348972350202774223e-11,
|
||||
+3.62559028155211703701e-10, -1.88724975172282928790e-09,
|
||||
+9.38153738649577178388e-09, -4.44505912879632808065e-08,
|
||||
+2.00329475355213526229e-07, -8.56872026469545474066e-07,
|
||||
+3.47025130813767847674e-06, -1.32731636560394358279e-05,
|
||||
+4.78156510755005422638e-05, -1.61760815825896745588e-04,
|
||||
+5.12285956168575772895e-04, -1.51357245063125314899e-03,
|
||||
+4.15642294431288815669e-03, -1.05640848946261981558e-02,
|
||||
+2.47264490306265168283e-02, -5.29459812080949914269e-02,
|
||||
+1.02643658689847095384e-01, -1.76416518357834055153e-01,
|
||||
+2.52587186443633654823e-01,
|
||||
};
|
||||
|
||||
constexpr float B[] = {
|
||||
+7.51729631084210481353e-18, +4.41434832307170791151e-18,
|
||||
-4.65030536848935832153e-17, -3.20952592199342395980e-17,
|
||||
+2.96262899764595013876e-16, +3.30820231092092828324e-16,
|
||||
-1.88035477551078244854e-15, -3.81440307243700780478e-15,
|
||||
+1.04202769841288027642e-14, +4.27244001671195135429e-14,
|
||||
-2.10154184277266431302e-14, -4.08355111109219731823e-13,
|
||||
-7.19855177624590851209e-13, +2.03562854414708950722e-12,
|
||||
+1.41258074366137813316e-11, +3.25260358301548823856e-11,
|
||||
-1.89749581235054123450e-11, -5.58974346219658380687e-10,
|
||||
-3.83538038596423702205e-09, -2.63146884688951950684e-08,
|
||||
-2.51223623787020892529e-07, -3.88256480887769039346e-06,
|
||||
-1.10588938762623716291e-04, -9.76109749136146840777e-03,
|
||||
+7.78576235018280120474e-01,
|
||||
};
|
||||
|
||||
float p;
|
||||
float q = 0.0;
|
||||
|
||||
if (::metal::fabs(x) <= T(8.0)) {
|
||||
float a = A[0];
|
||||
|
||||
for (uint8_t index = 1; index < 29; index++) {
|
||||
p = q;
|
||||
q = a;
|
||||
a = (.5 * ::metal::fabs(x) - 2.0) * q - p + A[index];
|
||||
}
|
||||
|
||||
return .5 * (a - p) * x * ::metal::precise::exp(::metal::fabs(x));
|
||||
}
|
||||
|
||||
float b = B[0];
|
||||
|
||||
for (uint8_t index = 1; index < 25; index++) {
|
||||
p = q;
|
||||
q = b;
|
||||
b = (32.0 / ::metal::fabs(x) - 2.0) * q - p + B[index];
|
||||
}
|
||||
|
||||
if (x < 0.0) {
|
||||
return -(
|
||||
::metal::precise::exp(::metal::fabs(x)) * (0.5 * (b - p)) /
|
||||
::metal::precise::sqrt(::metal::fabs(x)));
|
||||
}
|
||||
|
||||
return ::metal::precise::exp(::metal::fabs(x)) * (0.5 * (b - p)) /
|
||||
::metal::precise::sqrt(::metal::fabs(x));
|
||||
} // modified_bessel_i1_forward(T x)
|
||||
|
||||
} // namespace metal
|
||||
} // namespace c10
|
||||
|
@ -652,7 +652,6 @@ def mps_ops_modifier(ops):
|
||||
'special.hermite_polynomial_he': None,
|
||||
'special.laguerre_polynomial_l': None,
|
||||
'special.log_ndtr': None,
|
||||
'special.modified_bessel_i1': None,
|
||||
'special.modified_bessel_k0': None,
|
||||
'special.modified_bessel_k1': None,
|
||||
'special.ndtri': None,
|
||||
|
Reference in New Issue
Block a user