[Pytorch] Add option to CPU Blas GEMM to avoid output downcast (#154012)

Summary:
The dot product for a single output element consists of three steps (both input vectors hold elements of type scalar_t):
1. elementwise vector multiply (scalar_t x scalar_t -> opmath_t)
2. vector reduction to a scalar value (opmath_t -> opmath_t)
3. optional downcast if opmath_t != out_t

The current BLAS kernel performs steps 1 and 2 correctly, but in step 3 it always downcasts to scalar_t even when opmath_t == out_t (and then upcasts back to out_t), which results in precision loss. This diff fixes that precision loss in the BlasKernel.
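A minimal sketch of these three steps (illustrative only, not the kernel code), assuming scalar_t = at::BFloat16 and opmath_t = out_t = float:

#include <cstdint>
#include <c10/util/BFloat16.h>

// One output element of the dot product; the function name is hypothetical.
float dot_one(const c10::BFloat16* a, const c10::BFloat16* b, int64_t k) {
  float acc = 0.f;
  for (int64_t i = 0; i < k; ++i) {
    // step 1: scalar_t x scalar_t -> opmath_t
    acc += static_cast<float>(a[i]) * static_cast<float>(b[i]);
  }
  // step 2: the reduction stayed in opmath_t (float)
  // step 3: out_t == opmath_t here, so no downcast is needed; the pre-fix
  // kernel effectively returned static_cast<float>(c10::BFloat16(acc)),
  // a lossy round-trip through scalar_t.
  return acc;
}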

Test Plan: Attention CI passes

Differential Revision: D75023858

topic: not user facing

Pull Request resolved: https://github.com/pytorch/pytorch/pull/154012
Approved by: https://github.com/Valentine233, https://github.com/aditew01, https://github.com/CaoE, https://github.com/drisspg
Author: Cyrus Daruwala
Date: 2025-05-27 17:43:21 +00:00
Committed by: PyTorch MergeBot
Parent: 1ca082d9a1
Commit: cfbd99fdfd

3 changed files with 53 additions and 17 deletions


@@ -135,6 +135,7 @@ CBLAS_TRANSPOSE to_apple_accelerate_transpose(TransposeType trans) {
 } // namespace (anonymous)

 DEFINE_DISPATCH(gemm_stub);
+DEFINE_DISPATCH(gemm_no_downcast_stub);

 void gemm(
     TransposeType transa, TransposeType transb,
@@ -452,18 +453,18 @@ void gemm(
     // for the fallback path, first compute gemm with beta = 0,
     // and then add c in full precision.
     int64_t c_size = n * m;
-    std::vector<at::BFloat16> bfloat_c(c_size, 0.f);
-    gemm_stub(
+    std::vector<float> float_c(c_size, 0.f);
+    gemm_no_downcast_stub(
         at::kCPU, at::kBFloat16,
-        transa, transb, m, n, k, alpha, a, lda, b, ldb, 0.f, bfloat_c.data(), m);
+        transa, transb, m, n, k, alpha, a, lda, b, ldb, 0.f, float_c.data(), m);
     for (const auto j : c10::irange(n)) {
       for (const auto i : c10::irange(m)) {
         auto offset = j * ldc + i;
         // beta == 0 won't propagate NaN from C
         if (beta == 0.f) {
-          c[offset] = c10::convert<float>(bfloat_c[j * m + i]);
+          c[offset] = float_c[j * m + i];
         } else {
-          c[offset] = beta * c[offset] + c10::convert<float>(bfloat_c[j * m + i]);
+          c[offset] = beta * c[offset] + float_c[j * m + i];
         }
       }
     }
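The staging buffer is the whole fix in this hunk: the old bfloat_c rounded every accumulated result to bfloat16's 8-bit significand before the upcast, so the upcast could not recover the lost bits. A small numeric illustration (values chosen for the example):

#include <c10/util/BFloat16.h>

float acc = 1.0f + 1.0f / 512.0f;  // 1.001953125, exactly representable in float
float via_bf16 = static_cast<float>(c10::BFloat16(acc));  // rounds to 1.0f
// With the new float_c buffer, acc reaches c[offset] unchanged.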


@@ -29,6 +29,18 @@ using gemm_fn = void(*)(
 DECLARE_DISPATCH(gemm_fn, gemm_stub)

+using gemm_no_downcast_fn = void(*)(
+    at::ScalarType type,
+    TransposeType transa, TransposeType transb,
+    int64_t m, int64_t n, int64_t k,
+    const Scalar& alpha,
+    const void *a, int64_t lda,
+    const void *b, int64_t ldb,
+    const Scalar& beta,
+    void *c, int64_t ldc);
+DECLARE_DISPATCH(gemm_no_downcast_fn, gemm_no_downcast_stub)
+
 template <typename scalar_t>
 void gemm(
     TransposeType transa, TransposeType transb,
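A hypothetical call site for the new stub, mirroring the bf16 fallback shown above (m, n, k, transa, transb, a, lda, b, ldb are placeholders, and the at::native::cpublas namespace qualification is assumed): the inputs stay bfloat16 while c points at float (opmath_t) storage, so results are never rounded through bf16 on output.

std::vector<float> out(m * n, 0.f);
at::native::cpublas::gemm_no_downcast_stub(
    at::kCPU, at::kBFloat16,
    transa, transb, m, n, k,
    /*alpha=*/1.f,
    a, lda, b, ldb,
    /*beta=*/0.f,
    out.data(), /*ldc=*/m);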


@@ -99,7 +99,7 @@ auto sum(int64_t N, Func f) {
   return partial_sums[0];
 }

-template <typename scalar_t, typename opmath_t>
+template <typename scalar_t, typename opmath_t, typename out_t>
 std::enable_if_t<std::is_same_v<scalar_t, opmath_t>, void>
 gemm_notrans_(
     int64_t m,
@@ -111,7 +111,7 @@ gemm_notrans_(
     const scalar_t* b,
     int64_t ldb,
     opmath_t beta,
-    scalar_t* c,
+    out_t* c,
     int64_t ldc) {
   // c *= beta
   scale_(m, n, beta, c, ldc);
@@ -135,7 +135,7 @@ gemm_notrans_(
 }

 // std::is_same<scalar_t, at::BFloat16> || std::is_same<scalar_t, at::Half>
-template <typename scalar_t, typename opmath_t>
+template <typename scalar_t, typename opmath_t, typename out_t>
 std::enable_if_t<!std::is_same_v<scalar_t, opmath_t>, void>
 gemm_notrans_(
     int64_t m,
@@ -147,7 +147,7 @@ gemm_notrans_(
     const scalar_t* b,
     int64_t ldb,
     opmath_t beta,
-    scalar_t* c,
+    out_t* c,
     int64_t ldc) {
   // c += alpha * (a @ b)
   for (const auto i : c10::irange(m)) {
@@ -165,7 +165,7 @@ gemm_notrans_(
   }
 }

-template <typename scalar_t, typename opmath_t>
+template <typename scalar_t, typename opmath_t, typename out_t>
 void gemm_transa_(
     TransposeType transa,
     int64_t m, int64_t n, int64_t k,
@@ -173,7 +173,7 @@ void gemm_transa_(
     const scalar_t *a, int64_t lda,
     const scalar_t *b, int64_t ldb,
     opmath_t beta,
-    scalar_t *c, int64_t ldc) {
+    out_t *c, int64_t ldc) {
   // c = alpha * (a.T @ b) + beta * c
   const scalar_t *a_ = a;
   for (const auto i : c10::irange(m)) {
@@ -225,6 +225,7 @@ void gemm_transb_impl(
   }
 }

+// in this case, scalar_t == opmath_t == out_t so out_t template param is not needed
 template <typename scalar_t, typename opmath_t>
 std::enable_if_t<std::is_same_v<scalar_t, opmath_t>, void>
 gemm_transb_(
@@ -247,7 +248,7 @@ gemm_transb_(
 }

 // std::is_same<scalar_t, at::BFloat16> || std::is_same<scalar_t, at::Half>
-template <typename scalar_t, typename opmath_t>
+template <typename scalar_t, typename opmath_t, typename out_t>
 std::enable_if_t<!std::is_same_v<scalar_t, opmath_t>, void>
 gemm_transb_(
     TransposeType transb,
@@ -260,7 +261,7 @@ gemm_transb_(
     const scalar_t* b,
     int64_t ldb,
     opmath_t beta,
-    scalar_t* c,
+    out_t* c,
     int64_t ldc) {
   // We need to calculate full-precision dot products for correctness;
   // users notice error accumulation with reduced-width types (e.g.,
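These paired overloads are selected by SFINAE on whether scalar_t already equals opmath_t, and only the reduced-precision variants need to widen per element. A stripped-down sketch of the selection mechanism (hypothetical reduce function, not the real kernels):

#include <cstdint>
#include <type_traits>

// Chosen when storage and math types match (e.g. float/float).
template <typename scalar_t, typename opmath_t, typename out_t>
std::enable_if_t<std::is_same_v<scalar_t, opmath_t>, void>
reduce(const scalar_t* x, int64_t n, out_t* out) {
  opmath_t acc = 0;
  for (int64_t i = 0; i < n; ++i) acc += x[i];
  *out = acc;
}

// Chosen for reduced-precision storage (e.g. BFloat16 with float math).
template <typename scalar_t, typename opmath_t, typename out_t>
std::enable_if_t<!std::is_same_v<scalar_t, opmath_t>, void>
reduce(const scalar_t* x, int64_t n, out_t* out) {
  opmath_t acc = 0;
  for (int64_t i = 0; i < n; ++i) acc += static_cast<opmath_t>(x[i]);  // widen per element
  *out = acc;  // when out_t == opmath_t, this store no longer rounds
}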
@@ -304,7 +305,7 @@ gemm_transb_(
   }
 }

-template <typename scalar_t, typename opmath_t>
+template <typename scalar_t, typename opmath_t, typename out_t>
 void gemm_transab_(
     TransposeType transa, TransposeType transb,
     int64_t m, int64_t n, int64_t k,
@@ -312,7 +313,7 @@ void gemm_transab_(
     const scalar_t *a, int64_t lda,
     const scalar_t *b, int64_t ldb,
     opmath_t beta,
-    scalar_t *c, int64_t ldc) {
+    out_t *c, int64_t ldc) {
   // c = beta * c + alpha * (a.T @ b.T)
   for (const auto i : c10::irange(m)) {
     for (const auto j : c10::irange(n)) {
@@ -436,7 +437,7 @@ void gemm_transa_(
 }
 #endif // !defined(C10_MOBILE)

-template <typename scalar_t, typename opmath_t>
+template <typename scalar_t, typename opmath_t, typename out_t>
 void gemm_core_(
     TransposeType transa, TransposeType transb,
     int64_t m, int64_t n, int64_t k,
@@ -444,7 +445,7 @@ void gemm_core_(
     const scalar_t *a, int64_t lda,
     const scalar_t *b, int64_t ldb,
     opmath_t beta,
-    scalar_t *c, int64_t ldc) {
+    out_t *c, int64_t ldc) {
   if (transa == TransposeType::NoTranspose &&
       transb == TransposeType::NoTranspose) {
     return gemm_notrans_(m, n, k, alpha, a, lda, b, ldb, beta, c, ldc);
@@ -493,6 +494,27 @@ void cpublas_gemm_impl(
   });
 }

+void cpublas_gemm_no_downcast_impl(
+    at::ScalarType type,
+    TransposeType transa, TransposeType transb,
+    int64_t m, int64_t n, int64_t k,
+    const Scalar& alpha,
+    const void *a, int64_t lda,
+    const void *b, int64_t ldb,
+    const Scalar& beta,
+    void *c, int64_t ldc) {
+  _AT_DISPATCH_GEMM_TYPES(type, "cpublas_gemm_no_downcast_impl", [&]{
+      using opmath_t = at::opmath_type<scalar_t>;
+      gemm_core_(
+          transa, transb, m, n, k,
+          alpha.to<opmath_t>(),
+          static_cast<const scalar_t *>(a), lda,
+          static_cast<const scalar_t *>(b), ldb,
+          beta.to<opmath_t>(),
+          static_cast<opmath_t *>(c), ldc);
+  });
+}
+
 void cpublas_axpy_impl(at::ScalarType type, int64_t n, const Scalar& _a, const void *_x, int64_t incx, void *_y, int64_t incy){
   if (type == at::kBool) {
     auto a = _a.to<bool>();
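Relative to cpublas_gemm_impl, the only difference is the type that c is cast to: scalar_t* there, opmath_t* here. For scalar_t = at::BFloat16 the call above therefore deduces out_t = float, i.e. roughly the following (a sketch with placeholder arguments a_bf16, b_bf16, c_float):

// out_t is deduced from the c pointer: float, not BFloat16.
gemm_core_<at::BFloat16, float, float>(
    transa, transb, m, n, k,
    /*alpha=*/1.f,
    a_bf16, lda, b_bf16, ldb,
    /*beta=*/0.f,
    c_float, ldc);  // results are written as float, never rounded to bf16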
@@ -530,6 +552,7 @@ void cpublas_copy_impl(at::ScalarType type, int64_t n, const void *_x, int64_t i
 REGISTER_DISPATCH(cpublas::gemm_stub, &cpublas::cpublas_gemm_impl)
+REGISTER_DISPATCH(cpublas::gemm_no_downcast_stub, &cpublas::cpublas_gemm_no_downcast_impl)
 REGISTER_DISPATCH(cpublas::axpy_stub, &cpublas::cpublas_axpy_impl)
 REGISTER_DISPATCH(cpublas::copy_stub, &cpublas::cpublas_copy_impl)