mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-21 05:34:18 +08:00
[Pytorch] Add option to CPU Blas GEMM to avoid output downcast (#154012)
Summary: The dot product for a single output element consists of 3 steps (both input vectors have elements of type scalar_t): (1) elementwise vector multiply (scalar_t x scalar_t -> opmath_t); (2) vector reduction to a scalar value (opmath_t -> opmath_t); (3) an optional downcast if opmath_t != out_t. The current blas kernel performs steps 1 and 2 correctly, but in step 3 it always downcasts to scalar_t even when opmath_t == out_t (and then upcasts back to out_t), which results in precision loss. This diff fixes that precision loss in the BlasKernel. Test Plan: Attention CI passes. Differential Revision: D75023858 topic: not user facing Pull Request resolved: https://github.com/pytorch/pytorch/pull/154012 Approved by: https://github.com/Valentine233, https://github.com/aditew01, https://github.com/CaoE, https://github.com/drisspg
This commit is contained in:
committed by
PyTorch MergeBot
parent
1ca082d9a1
commit
cfbd99fdfd
@ -135,6 +135,7 @@ CBLAS_TRANSPOSE to_apple_accelerate_transpose(TransposeType trans) {
|
||||
} // namespace (anonymous)
|
||||
|
||||
DEFINE_DISPATCH(gemm_stub);
|
||||
DEFINE_DISPATCH(gemm_no_downcast_stub);
|
||||
|
||||
void gemm(
|
||||
TransposeType transa, TransposeType transb,
|
||||
@ -452,18 +453,18 @@ void gemm(
|
||||
// for the fallback path, first compute gemm with beta = 0,
|
||||
// and then add c in full precision.
|
||||
int64_t c_size = n * m;
|
||||
std::vector<at::BFloat16> bfloat_c(c_size, 0.f);
|
||||
gemm_stub(
|
||||
std::vector<float> float_c(c_size, 0.f);
|
||||
gemm_no_downcast_stub(
|
||||
at::kCPU, at::kBFloat16,
|
||||
transa, transb, m, n, k, alpha, a, lda, b, ldb, 0.f, bfloat_c.data(), m);
|
||||
transa, transb, m, n, k, alpha, a, lda, b, ldb, 0.f, float_c.data(), m);
|
||||
for (const auto j : c10::irange(n)) {
|
||||
for (const auto i : c10::irange(m)) {
|
||||
auto offset = j * ldc + i;
|
||||
// beta == 0 won't propagate NaN from C
|
||||
if (beta == 0.f) {
|
||||
c[offset] = c10::convert<float>(bfloat_c[j * m + i]);
|
||||
c[offset] = float_c[j * m + i];
|
||||
} else {
|
||||
c[offset] = beta * c[offset] + c10::convert<float>(bfloat_c[j * m + i]);
|
||||
c[offset] = beta * c[offset] + float_c[j * m + i];
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -29,6 +29,18 @@ using gemm_fn = void(*)(
|
||||
|
||||
DECLARE_DISPATCH(gemm_fn, gemm_stub)
|
||||
|
||||
using gemm_no_downcast_fn = void(*)(
|
||||
at::ScalarType type,
|
||||
TransposeType transa, TransposeType transb,
|
||||
int64_t m, int64_t n, int64_t k,
|
||||
const Scalar& alpha,
|
||||
const void *a, int64_t lda,
|
||||
const void *b, int64_t ldb,
|
||||
const Scalar& beta,
|
||||
void *c, int64_t ldc);
|
||||
|
||||
DECLARE_DISPATCH(gemm_no_downcast_fn, gemm_no_downcast_stub)
|
||||
|
||||
template <typename scalar_t>
|
||||
void gemm(
|
||||
TransposeType transa, TransposeType transb,
|
||||
|
@ -99,7 +99,7 @@ auto sum(int64_t N, Func f) {
|
||||
return partial_sums[0];
|
||||
}
|
||||
|
||||
template <typename scalar_t, typename opmath_t>
|
||||
template <typename scalar_t, typename opmath_t, typename out_t>
|
||||
std::enable_if_t<std::is_same_v<scalar_t, opmath_t>, void>
|
||||
gemm_notrans_(
|
||||
int64_t m,
|
||||
@ -111,7 +111,7 @@ gemm_notrans_(
|
||||
const scalar_t* b,
|
||||
int64_t ldb,
|
||||
opmath_t beta,
|
||||
scalar_t* c,
|
||||
out_t* c,
|
||||
int64_t ldc) {
|
||||
// c *= beta
|
||||
scale_(m, n, beta, c, ldc);
|
||||
@ -135,7 +135,7 @@ gemm_notrans_(
|
||||
}
|
||||
|
||||
// std::is_same<scalar_t, at::BFloat16> || std::is_same<scalar_t, at::Half>
|
||||
template <typename scalar_t, typename opmath_t>
|
||||
template <typename scalar_t, typename opmath_t, typename out_t>
|
||||
std::enable_if_t<!std::is_same_v<scalar_t, opmath_t>, void>
|
||||
gemm_notrans_(
|
||||
int64_t m,
|
||||
@ -147,7 +147,7 @@ gemm_notrans_(
|
||||
const scalar_t* b,
|
||||
int64_t ldb,
|
||||
opmath_t beta,
|
||||
scalar_t* c,
|
||||
out_t* c,
|
||||
int64_t ldc) {
|
||||
// c += alpha * (a @ b)
|
||||
for (const auto i : c10::irange(m)) {
|
||||
@ -165,7 +165,7 @@ gemm_notrans_(
|
||||
}
|
||||
}
|
||||
|
||||
template <typename scalar_t, typename opmath_t>
|
||||
template <typename scalar_t, typename opmath_t, typename out_t>
|
||||
void gemm_transa_(
|
||||
TransposeType transa,
|
||||
int64_t m, int64_t n, int64_t k,
|
||||
@ -173,7 +173,7 @@ void gemm_transa_(
|
||||
const scalar_t *a, int64_t lda,
|
||||
const scalar_t *b, int64_t ldb,
|
||||
opmath_t beta,
|
||||
scalar_t *c, int64_t ldc) {
|
||||
out_t *c, int64_t ldc) {
|
||||
// c = alpha * (a.T @ b) + beta * c
|
||||
const scalar_t *a_ = a;
|
||||
for (const auto i : c10::irange(m)) {
|
||||
@ -225,6 +225,7 @@ void gemm_transb_impl(
|
||||
}
|
||||
}
|
||||
|
||||
// in this case, scalar_t == opmath_t == out_t so out_t template param is not needed
|
||||
template <typename scalar_t, typename opmath_t>
|
||||
std::enable_if_t<std::is_same_v<scalar_t, opmath_t>, void>
|
||||
gemm_transb_(
|
||||
@ -247,7 +248,7 @@ gemm_transb_(
|
||||
}
|
||||
|
||||
// std::is_same<scalar_t, at::BFloat16> || std::is_same<scalar_t, at::Half>
|
||||
template <typename scalar_t, typename opmath_t>
|
||||
template <typename scalar_t, typename opmath_t, typename out_t>
|
||||
std::enable_if_t<!std::is_same_v<scalar_t, opmath_t>, void>
|
||||
gemm_transb_(
|
||||
TransposeType transb,
|
||||
@ -260,7 +261,7 @@ gemm_transb_(
|
||||
const scalar_t* b,
|
||||
int64_t ldb,
|
||||
opmath_t beta,
|
||||
scalar_t* c,
|
||||
out_t* c,
|
||||
int64_t ldc) {
|
||||
// We need to calculate full-precision dot products for correctness;
|
||||
// users notice error accumulation with reduced-width types (e.g.,
|
||||
@ -304,7 +305,7 @@ gemm_transb_(
|
||||
}
|
||||
}
|
||||
|
||||
template <typename scalar_t, typename opmath_t>
|
||||
template <typename scalar_t, typename opmath_t, typename out_t>
|
||||
void gemm_transab_(
|
||||
TransposeType transa, TransposeType transb,
|
||||
int64_t m, int64_t n, int64_t k,
|
||||
@ -312,7 +313,7 @@ void gemm_transab_(
|
||||
const scalar_t *a, int64_t lda,
|
||||
const scalar_t *b, int64_t ldb,
|
||||
opmath_t beta,
|
||||
scalar_t *c, int64_t ldc) {
|
||||
out_t *c, int64_t ldc) {
|
||||
// c = beta * c + alpha * (a.T @ b.T)
|
||||
for (const auto i : c10::irange(m)) {
|
||||
for (const auto j : c10::irange(n)) {
|
||||
@ -436,7 +437,7 @@ void gemm_transa_(
|
||||
}
|
||||
#endif // !defined(C10_MOBILE)
|
||||
|
||||
template <typename scalar_t, typename opmath_t>
|
||||
template <typename scalar_t, typename opmath_t, typename out_t>
|
||||
void gemm_core_(
|
||||
TransposeType transa, TransposeType transb,
|
||||
int64_t m, int64_t n, int64_t k,
|
||||
@ -444,7 +445,7 @@ void gemm_core_(
|
||||
const scalar_t *a, int64_t lda,
|
||||
const scalar_t *b, int64_t ldb,
|
||||
opmath_t beta,
|
||||
scalar_t *c, int64_t ldc) {
|
||||
out_t *c, int64_t ldc) {
|
||||
if (transa == TransposeType::NoTranspose &&
|
||||
transb == TransposeType::NoTranspose) {
|
||||
return gemm_notrans_(m, n, k, alpha, a, lda, b, ldb, beta, c, ldc);
|
||||
@ -493,6 +494,27 @@ void cpublas_gemm_impl(
|
||||
});
|
||||
}
|
||||
|
||||
void cpublas_gemm_no_downcast_impl(
|
||||
at::ScalarType type,
|
||||
TransposeType transa, TransposeType transb,
|
||||
int64_t m, int64_t n, int64_t k,
|
||||
const Scalar& alpha,
|
||||
const void *a, int64_t lda,
|
||||
const void *b, int64_t ldb,
|
||||
const Scalar& beta,
|
||||
void *c, int64_t ldc) {
|
||||
_AT_DISPATCH_GEMM_TYPES(type, "cpublas_gemm_no_downcast_impl", [&]{
|
||||
using opmath_t = at::opmath_type<scalar_t>;
|
||||
gemm_core_(
|
||||
transa, transb, m, n, k,
|
||||
alpha.to<opmath_t>(),
|
||||
static_cast<const scalar_t *>(a), lda,
|
||||
static_cast<const scalar_t *>(b), ldb,
|
||||
beta.to<opmath_t>(),
|
||||
static_cast<opmath_t *>(c), ldc);
|
||||
});
|
||||
}
|
||||
|
||||
void cpublas_axpy_impl(at::ScalarType type, int64_t n, const Scalar& _a, const void *_x, int64_t incx, void *_y, int64_t incy){
|
||||
if (type == at::kBool) {
|
||||
auto a = _a.to<bool>();
|
||||
@ -530,6 +552,7 @@ void cpublas_copy_impl(at::ScalarType type, int64_t n, const void *_x, int64_t i
|
||||
|
||||
|
||||
REGISTER_DISPATCH(cpublas::gemm_stub, &cpublas::cpublas_gemm_impl)
|
||||
REGISTER_DISPATCH(cpublas::gemm_no_downcast_stub, &cpublas::cpublas_gemm_no_downcast_impl)
|
||||
REGISTER_DISPATCH(cpublas::axpy_stub, &cpublas::cpublas_axpy_impl)
|
||||
REGISTER_DISPATCH(cpublas::copy_stub, &cpublas::cpublas_copy_impl)
|
||||
|
||||
|
Reference in New Issue
Block a user