mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 12:54:11 +08:00
[MPSInductor] Add gamma op (#145341)
By moving the `gamma` and `log_gamma` implementations from `Gamma.metal` to `c10/metal/special_math.h`.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/145341
Approved by: https://github.com/Skylion007, https://github.com/dcci
ghstack dependencies: #145309
This commit is contained in:
committed by
PyTorch MergeBot
parent
b81209557b
commit
70ccbade83
@ -10,125 +10,10 @@
|
||||
* See note [3-Clause BSD License for the Cephes Math Library].
|
||||
*/
|
||||
|
||||
#include <c10/metal/special_math.h>
|
||||
#include <metal_stdlib>
|
||||
using namespace metal;
|
||||
|
||||
template <typename T>
|
||||
float LogGamma(const T);
|
||||
|
||||
// Gamma(x), evaluated in single precision.
//
// The argument selects one of three evaluation strategies:
//   - x >= 12:    exp(LogGamma(x)).
//   - x < 0.001:  the truncated power series 1/Gamma(x) = x + gamma*x^2
//                 (gamma being the Euler-Mascheroni constant), whose error is
//                 on the order of x^3; the relative error over this interval
//                 is less than 6e-7.
//   - otherwise:  a rational approximation of Gamma over (1,2), extended to x
//                 with the recurrence Gamma(z+1) = z*Gamma(z).
//
// NOTE(review): every x below 0.001, negative values included, is routed to
// the series branch, so the result looks meaningful only for positive
// arguments -- the call in LogGamma passes |x|. Confirm before using this
// with negative inputs.
template <typename T>
float Gamma(const T x) {
  // Large arguments: go through the logarithm.
  if (x >= 12.0) {
    return exp(LogGamma(x));
  }

  // Tiny arguments: two-term series for the reciprocal.
  if (x < 0.001) {
    constexpr float EULER_MASCHERONI = 0.577215664901532860606512090;
    return 1.0 / (x * (1.0 + EULER_MASCHERONI * x));
  }

  // Numerator coefficients of the rational approximation over (1,2).
  const float NUMER[8] = {
      -1.71618513886549492533811E+0,
      2.47656508055759199108314E+1,
      -3.79804256470945635097577E+2,
      6.29331155312818442661052E+2,
      8.66966202790413211295064E+2,
      -3.14512729688483675254357E+4,
      -3.61444134186911729807069E+4,
      6.64561438202405440627855E+4};

  // Denominator coefficients of the rational approximation over (1,2).
  const float DENOM[8] = {
      -3.08402300119738975254353E+1,
      3.15350626979604161529144E+2,
      -1.01515636749021914166146E+3,
      -3.10777167157231109440444E+3,
      2.25381184209801510330112E+4,
      4.75584627752788110767815E+3,
      -1.34659959864969306392456E+5,
      -1.15132259675553483497211E+5};

  // Map the argument onto [1,2) by keeping only its fractional part.
  float shifted = 1.0 + fract(x);
  const float frac_part = shifted - 1;

  // Evaluate numerator and denominator together, one coefficient per step.
  float numer = 0.0;
  float denom = 1.0;
  for (int k = 0; k < 8; ++k) {
    numer = (numer + NUMER[k]) * frac_part;
    denom = denom * frac_part + DENOM[k];
  }
  float value = numer / denom + 1.0;

  if (x < 1.0) {
    // Step down once: gamma(z) = gamma(z+1)/z
    value /= (shifted - 1.0);
    return value;
  }

  // Step up floor(x)-1 times:
  // gamma(z+n) = z*(z+1)* ... *(z+n-1)*gamma(z)
  const int whole = static_cast<int>(floor(x));
  for (int left = whole - 1; left > 0; --left) {
    value *= shifted;
    shifted++;
  }
  return value;
}
|
||||
|
||||
template <typename T>
|
||||
float LogGamma(const T x) {
|
||||
constexpr float LOG_PI = 1.14472988584940017414342735135305;
|
||||
constexpr float HALF_LOG_TWO_PI = 0.91893853320467274178032973640562;
|
||||
constexpr float LGAMMA_EXPANSION_COEF[8] = {
|
||||
1.0 / 12.0,
|
||||
-1.0 / 360.0,
|
||||
1.0 / 1260.0,
|
||||
-1.0 / 1680.0,
|
||||
1.0 / 1188.0,
|
||||
-691.0 / 360360.0,
|
||||
1.0 / 156.0,
|
||||
-3617.0 / 122400.0};
|
||||
|
||||
float logGamma;
|
||||
|
||||
const auto abs_x = metal::abs(static_cast<float>(x));
|
||||
if (abs_x == 0) {
|
||||
return INFINITY;
|
||||
}
|
||||
if (abs_x < 12.0) {
|
||||
logGamma = log(fabs(Gamma(abs_x)));
|
||||
} else {
|
||||
// Abramowitz and Stegun 6.1.41
|
||||
// Asymptotic series should be good to at least 11 or 12 figures
|
||||
// For error analysis, see Whittiker and Watson
|
||||
// A Course in Modern Analysis (1927), page 252
|
||||
|
||||
float z = 1.0 / (abs_x * abs_x);
|
||||
float sum = LGAMMA_EXPANSION_COEF[7];
|
||||
|
||||
for (int i = 6; i >= 0; i--) {
|
||||
sum *= z;
|
||||
sum += LGAMMA_EXPANSION_COEF[i];
|
||||
}
|
||||
float series = sum / abs_x;
|
||||
|
||||
logGamma = (abs_x - 0.5) * log(abs_x) - abs_x + HALF_LOG_TWO_PI + series;
|
||||
}
|
||||
|
||||
if (x >= 0) {
|
||||
return logGamma;
|
||||
}
|
||||
|
||||
return LOG_PI - logGamma -
|
||||
log(fabs(abs_x * sinpi(abs_x))); // Reflection Formula
|
||||
}
|
||||
|
||||
float calc_digamma_positive_domain(float x) {
|
||||
const float DIGAMMA_COEF[7] = {
|
||||
8.33333333333333333333E-2,
|
||||
@ -270,7 +155,8 @@ kernel void lgamma(
|
||||
constant T0* input [[buffer(0)]],
|
||||
device T1* output [[buffer(1)]],
|
||||
uint id [[thread_position_in_grid]]) {
|
||||
output[id] = static_cast<T1>(LogGamma(static_cast<float>(input[id])));
|
||||
output[id] =
|
||||
static_cast<T1>(c10::metal::log_gamma(static_cast<float>(input[id])));
|
||||
}
|
||||
|
||||
template <typename T0, typename T1>
|
||||
@ -322,7 +208,8 @@ kernel void polygamma(
|
||||
float x = input[id];
|
||||
float n = order;
|
||||
float sgn = ((order % 2) ? 1 : -1);
|
||||
output[id] = static_cast<T1>(sgn * Gamma(n + 1) * calc_zeta(n + 1, x));
|
||||
output[id] =
|
||||
static_cast<T1>(sgn * c10::metal::gamma(n + 1) * calc_zeta(n + 1, x));
|
||||
}
|
||||
|
||||
#define INSTANTIATE_GAMMA_KERNELS(DTYPE0, DTYPE1) \
|
||||
|
@ -161,5 +161,123 @@ T i1(T _x) {
|
||||
return static_cast<T>(_x < T(0.) ? -out : out);
|
||||
}
|
||||
|
||||
// gamma, lgamma
|
||||
template <typename T>
|
||||
float log_gamma(const T);
|
||||
|
||||
template <typename T>
|
||||
float gamma(const T x) {
|
||||
if (x < 0.001) {
|
||||
constexpr float EULER_MASCHERONI = 0.577215664901532860606512090;
|
||||
// For small x, 1/gamma(x) has power series x + gamma x^2 - ...
|
||||
// So in this range, 1/gamma(x) = x + gamma x^2 with error on the order of
|
||||
// x^3. The relative error over this interval is less than 6e-7.
|
||||
|
||||
return 1.0 / (x * (1.0 + EULER_MASCHERONI * x));
|
||||
}
|
||||
if (x >= 12.0) {
|
||||
return ::metal::exp(log_gamma(x));
|
||||
}
|
||||
// The algorithm directly approximates gamma over (1,2) and uses
|
||||
// reduction identities to reduce other arguments to this interval.
|
||||
// numerator coefficients for gamma approximation over the interval (1,2)
|
||||
constexpr float GAMMA_NUMERATOR_COEF[8] = {
|
||||
-1.71618513886549492533811E+0,
|
||||
2.47656508055759199108314E+1,
|
||||
-3.79804256470945635097577E+2,
|
||||
6.29331155312818442661052E+2,
|
||||
8.66966202790413211295064E+2,
|
||||
-3.14512729688483675254357E+4,
|
||||
-3.61444134186911729807069E+4,
|
||||
6.64561438202405440627855E+4};
|
||||
|
||||
// denominator coefficients for gamma approximation over the interval (1,2)
|
||||
constexpr float GAMMA_DENOMINATOR_COEF[8] = {
|
||||
-3.08402300119738975254353E+1,
|
||||
3.15350626979604161529144E+2,
|
||||
-1.01515636749021914166146E+3,
|
||||
-3.10777167157231109440444E+3,
|
||||
2.25381184209801510330112E+4,
|
||||
4.75584627752788110767815E+3,
|
||||
-1.34659959864969306392456E+5,
|
||||
-1.15132259675553483497211E+5};
|
||||
|
||||
// Add or subtract integers as necessary to bring y into (1,2)
|
||||
float y = 1.0 + ::metal::fract(x);
|
||||
|
||||
float num = 0.0;
|
||||
float den = 1.0;
|
||||
|
||||
float z = y - 1;
|
||||
for (int i = 0; i < 8; i++) {
|
||||
num = (num + GAMMA_NUMERATOR_COEF[i]) * z;
|
||||
den = den * z + GAMMA_DENOMINATOR_COEF[i];
|
||||
}
|
||||
float result = num / den + 1.0;
|
||||
|
||||
// Apply correction if argument was not initially in (1,2)
|
||||
if (x < 1.0) {
|
||||
// identity gamma(z) = gamma(z+1)/z
|
||||
result /= (y - 1.0);
|
||||
} else {
|
||||
// identity gamma(z+n) = z*(z+1)* ... *(z+n-1)*gamma(z)
|
||||
auto n = static_cast<int>(::metal::floor(x));
|
||||
for (int i = 1; i < n; i++) {
|
||||
result *= y++;
|
||||
}
|
||||
}
|
||||
|
||||
return result;
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
float log_gamma(const T x) {
|
||||
constexpr float LOG_PI = 1.14472988584940017414342735135305;
|
||||
constexpr float HALF_LOG_TWO_PI = 0.91893853320467274178032973640562;
|
||||
constexpr float LGAMMA_EXPANSION_COEF[8] = {
|
||||
1.0 / 12.0,
|
||||
-1.0 / 360.0,
|
||||
1.0 / 1260.0,
|
||||
-1.0 / 1680.0,
|
||||
1.0 / 1188.0,
|
||||
-691.0 / 360360.0,
|
||||
1.0 / 156.0,
|
||||
-3617.0 / 122400.0};
|
||||
|
||||
float rc;
|
||||
|
||||
const auto abs_x = ::metal::abs(static_cast<float>(x));
|
||||
if (abs_x == 0) {
|
||||
return INFINITY;
|
||||
}
|
||||
if (abs_x < 12.0) {
|
||||
rc = ::metal::log(::metal::abs(gamma(abs_x)));
|
||||
} else {
|
||||
// Abramowitz and Stegun 6.1.41
|
||||
// Asymptotic series should be good to at least 11 or 12 figures
|
||||
// For error analysis, see Whittiker and Watson
|
||||
// A Course in Modern Analysis (1927), page 252
|
||||
|
||||
float z = 1.0 / (abs_x * abs_x);
|
||||
float sum = LGAMMA_EXPANSION_COEF[7];
|
||||
|
||||
for (int i = 6; i >= 0; i--) {
|
||||
sum *= z;
|
||||
sum += LGAMMA_EXPANSION_COEF[i];
|
||||
}
|
||||
float series = sum / abs_x;
|
||||
|
||||
rc = (abs_x - 0.5) * ::metal::log(abs_x) - abs_x + HALF_LOG_TWO_PI + series;
|
||||
}
|
||||
|
||||
if (x >= 0) {
|
||||
return rc;
|
||||
}
|
||||
|
||||
// Reflection formula
|
||||
return LOG_PI - rc -
|
||||
::metal::log(::metal::abs(abs_x * ::metal::sinpi(abs_x)));
|
||||
}
|
||||
|
||||
} // namespace metal
|
||||
} // namespace c10
|
||||
|
@ -146,7 +146,11 @@ class MPSBasicTests(TestCase):
|
||||
|
||||
|
||||
# Copy tests
|
||||
for test_name in ["test_builtins_round", "test_builtins_round_float_ndigits_neg"]:
|
||||
for test_name in [
|
||||
"test_builtins_round",
|
||||
"test_builtins_round_float_ndigits_neg",
|
||||
"test_lgamma",
|
||||
]:
|
||||
setattr(MPSBasicTests, test_name, getattr(CommonTemplate, test_name))
|
||||
|
||||
instantiate_parametrized_tests(MPSBasicTests)
|
||||
|
@ -222,6 +222,10 @@ class MetalOverrides(OpOverrides):
|
||||
def erf(x: CSEVariable) -> str:
|
||||
return f"c10::metal::erf({x})"
|
||||
|
||||
@staticmethod
def lgamma(x: CSEVariable) -> str:
    """Return the Metal expression text that calls ``c10::metal::log_gamma`` on ``x``."""
    call_expr = f"c10::metal::log_gamma({x})"
    return call_expr
|
||||
|
||||
@staticmethod
def tan(x: CSEVariable) -> str:
    """Return the Metal expression text that calls ``metal::tan`` on ``x``."""
    call_expr = f"metal::tan({x})"
    return call_expr
|
||||
|
Reference in New Issue
Block a user