mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
[mps] Move polygamma to special_math.h. (#146253)
In preparation to implement it in inductor. Pull Request resolved: https://github.com/pytorch/pytorch/pull/146253 Approved by: https://github.com/Skylion007, https://github.com/malfet
This commit is contained in:
committed by
PyTorch MergeBot
parent
07dbd539b4
commit
dca5cc0255
@ -136,11 +136,7 @@ kernel void polygamma(
|
||||
constant int64_t& order [[buffer(2)]],
|
||||
uint id [[thread_position_in_grid]]) {
|
||||
// already blocked if n <= 1
|
||||
float x = input[id];
|
||||
float n = order;
|
||||
float sgn = ((order % 2) ? 1 : -1);
|
||||
output[id] = static_cast<T1>(
|
||||
sgn * c10::metal::gamma(n + 1) * c10::metal::zeta(n + 1, x));
|
||||
output[id] = static_cast<T1>(c10::metal::polygamma(input[id], order));
|
||||
}
|
||||
|
||||
#define INSTANTIATE_GAMMA_KERNELS(DTYPE0, DTYPE1) \
|
||||
|
@ -382,5 +382,13 @@ float zeta(float x, float q) {
|
||||
return s;
|
||||
}
|
||||
|
||||
template <typename T0, typename T1>
|
||||
float polygamma(const T0 input, const T1 order) {
|
||||
float x = input;
|
||||
float n = order;
|
||||
float sgn = ((order % 2) ? 1 : -1);
|
||||
return sgn * gamma(n + 1) * zeta(n + 1, x);
|
||||
}
|
||||
|
||||
} // namespace metal
|
||||
} // namespace c10
|
||||
|
Reference in New Issue
Block a user