[mps] Move polygamma to special_math.h. (#146253)

In preparation to implement it in inductor.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/146253
Approved by: https://github.com/Skylion007, https://github.com/malfet
This commit is contained in:
Davide Italiano
2025-02-01 21:45:23 +00:00
committed by PyTorch MergeBot
parent 07dbd539b4
commit dca5cc0255
2 changed files with 9 additions and 5 deletions

View File

@ -136,11 +136,7 @@ kernel void polygamma(
constant int64_t& order [[buffer(2)]],
uint id [[thread_position_in_grid]]) {
// already blocked if n <= 1
float x = input[id];
float n = order;
float sgn = ((order % 2) ? 1 : -1);
output[id] = static_cast<T1>(
sgn * c10::metal::gamma(n + 1) * c10::metal::zeta(n + 1, x));
output[id] = static_cast<T1>(c10::metal::polygamma(input[id], order));
}
#define INSTANTIATE_GAMMA_KERNELS(DTYPE0, DTYPE1) \

View File

@ -382,5 +382,13 @@ float zeta(float x, float q) {
return s;
}
template <typename T0, typename T1>
float polygamma(const T0 input, const T1 order) {
float x = input;
float n = order;
float sgn = ((order % 2) ? 1 : -1);
return sgn * gamma(n + 1) * zeta(n + 1, x);
}
} // namespace metal
} // namespace c10