[MPS/inductor] Fix the approximation of polygamma for n == 0. (#152214)

Fixes #152205

Pull Request resolved: https://github.com/pytorch/pytorch/pull/152214
Approved by: https://github.com/malfet
This commit is contained in:
Davide Italiano
2025-04-25 22:42:45 +00:00
committed by PyTorch MergeBot
parent cf101d66ee
commit e28864fc0f
3 changed files with 15 additions and 10 deletions

View File

@ -81,7 +81,8 @@ kernel void polygamma(
constant int64_t& order [[buffer(2)]],
uint id [[thread_position_in_grid]]) {
// already blocked if n <= 1
output[id] = static_cast<T1>(c10::metal::polygamma(order, input[id]));
output[id] = static_cast<T1>(
c10::metal::polygamma(order, static_cast<float>(input[id])));
}
#define INSTANTIATE_GAMMA_KERNELS(DTYPE0, DTYPE1) \

View File

@ -478,14 +478,6 @@ inline float zeta(float x, float q) {
return s;
}
template <typename T0>
inline float polygamma(const int64_t order, const T0 input) {
float x = input;
float n = order;
float sgn = ((order % 2) ? 1 : -1);
return sgn * gamma(n + 1) * zeta(n + 1, x);
}
inline float calc_digamma_positive_domain(float x) {
constexpr float DIGAMMA_COEF[7] = {
8.33333333333333333333E-2,
@ -546,6 +538,19 @@ inline float digamma(T0 x) {
}
}
template <typename T0>
inline float polygamma(const int64_t order, const T0 input) {
// Filter out n == 0.
if (order == 0) {
return digamma(input);
}
float x = input;
float n = order;
float sgn = ((order % 2) ? 1 : -1);
return sgn * gamma(n + 1) * zeta(n + 1, x);
}
template <typename T>
inline ::metal::enable_if_t<is_scalar_floating_point_v<T>, T> sinc(T a) {
if (a == static_cast<T>(0)) {

View File

@ -13223,7 +13223,6 @@ def forward(self, arg0_1: "Sym(s77)", arg1_1: "Sym(s27)", arg2_1: "Sym(s53)", ar
ignore_empty_lines=True,
)
@xfail_if_mps
@expectedFailureCodegenDynamic
def test_special_polygamma(self):
fn = torch.special.polygamma