mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
[MPS/inductor] Fix the approximation of polygamma for n == 0. (#152214)
Fixes #152205 Pull Request resolved: https://github.com/pytorch/pytorch/pull/152214 Approved by: https://github.com/malfet
This commit is contained in:
committed by
PyTorch MergeBot
parent
cf101d66ee
commit
e28864fc0f
@ -81,7 +81,8 @@ kernel void polygamma(
|
||||
constant int64_t& order [[buffer(2)]],
|
||||
uint id [[thread_position_in_grid]]) {
|
||||
// already blocked if n <= 1
|
||||
output[id] = static_cast<T1>(c10::metal::polygamma(order, input[id]));
|
||||
output[id] = static_cast<T1>(
|
||||
c10::metal::polygamma(order, static_cast<float>(input[id])));
|
||||
}
|
||||
|
||||
#define INSTANTIATE_GAMMA_KERNELS(DTYPE0, DTYPE1) \
|
||||
|
@ -478,14 +478,6 @@ inline float zeta(float x, float q) {
|
||||
return s;
|
||||
}
|
||||
|
||||
template <typename T0>
|
||||
inline float polygamma(const int64_t order, const T0 input) {
|
||||
float x = input;
|
||||
float n = order;
|
||||
float sgn = ((order % 2) ? 1 : -1);
|
||||
return sgn * gamma(n + 1) * zeta(n + 1, x);
|
||||
}
|
||||
|
||||
inline float calc_digamma_positive_domain(float x) {
|
||||
constexpr float DIGAMMA_COEF[7] = {
|
||||
8.33333333333333333333E-2,
|
||||
@ -546,6 +538,19 @@ inline float digamma(T0 x) {
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T0>
|
||||
inline float polygamma(const int64_t order, const T0 input) {
|
||||
// Filter out n == 0.
|
||||
if (order == 0) {
|
||||
return digamma(input);
|
||||
}
|
||||
|
||||
float x = input;
|
||||
float n = order;
|
||||
float sgn = ((order % 2) ? 1 : -1);
|
||||
return sgn * gamma(n + 1) * zeta(n + 1, x);
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
inline ::metal::enable_if_t<is_scalar_floating_point_v<T>, T> sinc(T a) {
|
||||
if (a == static_cast<T>(0)) {
|
||||
|
@ -13223,7 +13223,6 @@ def forward(self, arg0_1: "Sym(s77)", arg1_1: "Sym(s27)", arg2_1: "Sym(s53)", ar
|
||||
ignore_empty_lines=True,
|
||||
)
|
||||
|
||||
@xfail_if_mps
|
||||
@expectedFailureCodegenDynamic
|
||||
def test_special_polygamma(self):
|
||||
fn = torch.special.polygamma
|
||||
|
Reference in New Issue
Block a user