[MPS] Implement support for modified_bessel_i1 in eager. (#149368)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/149368
Approved by: https://github.com/malfet

Co-authored-by: Nikita Shulga <2453524+malfet@users.noreply.github.com>
This commit is contained in:
Davide Italiano
2025-03-18 03:29:10 +00:00
committed by PyTorch MergeBot
parent bb42e4d137
commit c43e35d6f7
5 changed files with 77 additions and 2 deletions

View File

@ -6,6 +6,7 @@ using namespace metal;
DEFINE_UNARY_FLOATING_FUNCTOR(bessel_j0_forward);
DEFINE_UNARY_FLOATING_FUNCTOR(bessel_j1_forward);
DEFINE_UNARY_FLOATING_FUNCTOR(modified_bessel_i0_forward);
DEFINE_UNARY_FLOATING_FUNCTOR(modified_bessel_i1_forward);
DEFINE_UNARY_FLOATING_FUNCTOR(i0);
DEFINE_UNARY_FLOATING_FUNCTOR(i0e);
DEFINE_UNARY_FLOATING_FUNCTOR(i1);
@ -49,6 +50,7 @@ struct bessel_y1_forward_functor {
REGISTER_UNARY_OP(bessel_j0_forward, DTI, DTO); \
REGISTER_UNARY_OP(bessel_j1_forward, DTI, DTO); \
REGISTER_UNARY_OP(modified_bessel_i0_forward, DTI, DTO); \
REGISTER_UNARY_OP(modified_bessel_i1_forward, DTI, DTO); \
REGISTER_UNARY_OP(bessel_y0_forward, DTI, DTO); \
REGISTER_UNARY_OP(bessel_y1_forward, DTI, DTO); \
REGISTER_UNARY_OP(i0, DTI, DTO); \

View File

@ -48,6 +48,10 @@ static void modified_bessel_i0_kernel_mps(TensorIteratorBase& iter) {
lib.exec_unary_kernel(iter, "modified_bessel_i0_forward");
}
static void modified_bessel_i1_kernel_mps(TensorIteratorBase& iter) {
lib.exec_unary_kernel(iter, "modified_bessel_i1_forward");
}
static void bessel_y0_kernel_mps(TensorIteratorBase& iter) {
lib.exec_unary_kernel(iter, "bessel_y0_forward");
}
@ -63,6 +67,7 @@ REGISTER_DISPATCH(special_i1e_stub, &i1e_kernel_mps)
REGISTER_DISPATCH(special_bessel_j0_stub, &bessel_j0_kernel_mps)
REGISTER_DISPATCH(special_bessel_j1_stub, &bessel_j1_kernel_mps)
REGISTER_DISPATCH(special_modified_bessel_i0_stub, &modified_bessel_i0_kernel_mps)
REGISTER_DISPATCH(special_modified_bessel_i1_stub, &modified_bessel_i1_kernel_mps)
REGISTER_DISPATCH(special_bessel_y0_stub, &bessel_y0_kernel_mps)
REGISTER_DISPATCH(special_bessel_y1_stub, &bessel_y1_kernel_mps)
REGISTER_DISPATCH(special_spherical_bessel_j0_stub, &spherical_bessel_j0_kernel_mps)

View File

@ -15454,7 +15454,7 @@
- func: special_modified_bessel_i1.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!)
dispatch:
CPU, CUDA: special_modified_bessel_i1_out
CPU, CUDA, MPS: special_modified_bessel_i1_out
python_module: special
structured_inherits: TensorIteratorBase
structured: True

View File

@ -1185,5 +1185,74 @@ inline float modified_bessel_i0_forward(T x) {
::metal::precise::sqrt(::metal::fabs(x));
} // modified_bessel_i0_forward(T x)
template <typename T>
inline float modified_bessel_i1_forward(T x) {
constexpr float A[] = {
+2.77791411276104639959e-18, -2.11142121435816608115e-17,
+1.55363195773620046921e-16, -1.10559694773538630805e-15,
+7.60068429473540693410e-15, -5.04218550472791168711e-14,
+3.22379336594557470981e-13, -1.98397439776494371520e-12,
+1.17361862988909016308e-11, -6.66348972350202774223e-11,
+3.62559028155211703701e-10, -1.88724975172282928790e-09,
+9.38153738649577178388e-09, -4.44505912879632808065e-08,
+2.00329475355213526229e-07, -8.56872026469545474066e-07,
+3.47025130813767847674e-06, -1.32731636560394358279e-05,
+4.78156510755005422638e-05, -1.61760815825896745588e-04,
+5.12285956168575772895e-04, -1.51357245063125314899e-03,
+4.15642294431288815669e-03, -1.05640848946261981558e-02,
+2.47264490306265168283e-02, -5.29459812080949914269e-02,
+1.02643658689847095384e-01, -1.76416518357834055153e-01,
+2.52587186443633654823e-01,
};
constexpr float B[] = {
+7.51729631084210481353e-18, +4.41434832307170791151e-18,
-4.65030536848935832153e-17, -3.20952592199342395980e-17,
+2.96262899764595013876e-16, +3.30820231092092828324e-16,
-1.88035477551078244854e-15, -3.81440307243700780478e-15,
+1.04202769841288027642e-14, +4.27244001671195135429e-14,
-2.10154184277266431302e-14, -4.08355111109219731823e-13,
-7.19855177624590851209e-13, +2.03562854414708950722e-12,
+1.41258074366137813316e-11, +3.25260358301548823856e-11,
-1.89749581235054123450e-11, -5.58974346219658380687e-10,
-3.83538038596423702205e-09, -2.63146884688951950684e-08,
-2.51223623787020892529e-07, -3.88256480887769039346e-06,
-1.10588938762623716291e-04, -9.76109749136146840777e-03,
+7.78576235018280120474e-01,
};
float p;
float q = 0.0;
if (::metal::fabs(x) <= T(8.0)) {
float a = A[0];
for (uint8_t index = 1; index < 29; index++) {
p = q;
q = a;
a = (.5 * ::metal::fabs(x) - 2.0) * q - p + A[index];
}
return .5 * (a - p) * x * ::metal::precise::exp(::metal::fabs(x));
}
float b = B[0];
for (uint8_t index = 1; index < 25; index++) {
p = q;
q = b;
b = (32.0 / ::metal::fabs(x) - 2.0) * q - p + B[index];
}
if (x < 0.0) {
return -(
::metal::precise::exp(::metal::fabs(x)) * (0.5 * (b - p)) /
::metal::precise::sqrt(::metal::fabs(x)));
}
return ::metal::precise::exp(::metal::fabs(x)) * (0.5 * (b - p)) /
::metal::precise::sqrt(::metal::fabs(x));
} // modified_bessel_i1_forward(T x)
} // namespace metal
} // namespace c10

View File

@ -652,7 +652,6 @@ def mps_ops_modifier(ops):
'special.hermite_polynomial_he': None,
'special.laguerre_polynomial_l': None,
'special.log_ndtr': None,
'special.modified_bessel_i1': None,
'special.modified_bessel_k0': None,
'special.modified_bessel_k1': None,
'special.ndtri': None,