Use accumulation type in BF16 gemm (including dot, mv) ref path (#96074)

Fixes https://github.com/pytorch/pytorch/issues/95125 and https://github.com/pytorch/pytorch/issues/83863: the gemm reference path (including dot and mv) now accumulates BF16 inputs in the wider opmath type (float) instead of accumulating in BF16.
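
BF16 keeps only 7 mantissa bits, so once a running sum reaches 256 the spacing between representable values is 2 and adding 1.0 rounds straight back to 256; a dot product of two 300-element vectors of ones therefore returns 256 instead of 300 if every partial sum is rounded to BF16 (the scenario the new test below exercises). The standalone sketch below illustrates the effect and the fix; the `to_bfloat16` helper only simulates BF16 round-to-nearest-even rounding and is not the ATen conversion code.

```cpp
#include <cstdint>
#include <cstring>
#include <iostream>

// Simulate BF16 rounding: keep the top 16 bits of a float, rounding to
// nearest-even. Illustration only, not ATen's implementation.
float to_bfloat16(float x) {
  uint32_t bits;
  std::memcpy(&bits, &x, sizeof(bits));
  bits += 0x7FFFu + ((bits >> 16) & 1u);  // round to nearest, ties to even
  bits &= 0xFFFF0000u;                    // drop the low 16 mantissa bits
  float out;
  std::memcpy(&out, &bits, sizeof(out));
  return out;
}

int main() {
  // Dot product of two length-300 vectors of ones; the exact answer is 300.
  float acc_bf16 = 0.f;  // every partial sum rounded to BF16: stalls at 256
  float acc_fp32 = 0.f;  // accumulate in float, round once at the end
  for (int i = 0; i < 300; ++i) {
    acc_bf16 = to_bfloat16(acc_bf16 + 1.f);
    acc_fp32 += 1.f;
  }
  std::cout << "BF16 accumulation:  " << acc_bf16 << "\n";               // 256
  std::cout << "float accumulation: " << to_bfloat16(acc_fp32) << "\n";  // 300
}
```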

Pull Request resolved: https://github.com/pytorch/pytorch/pull/96074
Approved by: https://github.com/lezcano, https://github.com/peterbell10
Author: haozhe.zhu
Date: 2023-03-22 14:05:48 +00:00
Committed by: PyTorch MergeBot
Parent: b45880c537
Commit: fe0afc5852
5 changed files with 201 additions and 45 deletions

View File

@@ -1,13 +1,14 @@
#define TORCH_ASSERT_ONLY_METHOD_OPERATORS
#include <limits>
#include <algorithm>
#include <climits>
#include <ATen/Config.h>
#include <ATen/OpMathType.h>
#include <c10/core/ScalarType.h>
#include <c10/util/irange.h>
#include <c10/util/Exception.h>
#include <c10/util/complex.h>
#include <c10/util/irange.h>
#include <algorithm>
#include <climits>
#include <iostream>
#include <limits>
#if AT_BUILD_WITH_BLAS()
extern "C" double ddot_(int *n, double *x, int *incx, double *y, int *incy);
extern "C" void dscal_(int *n, double *a, double *x, int *incx);
@@ -180,9 +181,10 @@ void gemv(char trans, int64_t m, int64_t n, scalar_t alpha, scalar_t *a, int64_t
return;
}
using opmath_t = at::opmath_type<scalar_t>;
if ((trans == 'T') || (trans == 't')) {
for (const auto i : c10::irange(n)) {
scalar_t sum = 0;
opmath_t sum = 0;
scalar_t *row_ = a + lda * i;
for (const auto j : c10::irange(m)) {
sum += x[j * incx] * row_[j];
@@ -196,15 +198,37 @@ void gemv(char trans, int64_t m, int64_t n, scalar_t alpha, scalar_t *a, int64_t
} else {
if (beta != scalar_t(1) && beta != scalar_t(0)) scal<scalar_t>(m, beta, y, incy);
bool is_low_precision = !std::is_same<opmath_t, scalar_t>::value;
std::vector<opmath_t> sum;
if (is_low_precision) {
sum.resize(m);
}
for (const auto j : c10::irange(n)) {
scalar_t *column_ = a + lda * j;
scalar_t z = alpha * x[j * incx];
opmath_t z = alpha * static_cast<opmath_t>(x[j * incx]);
for (const auto i : c10::irange(m)) {
//output values are ignored if beta is 0, and set to 0, nans and infs are not propagated
if (j==0 && beta==scalar_t(0)) {
y[i * incy] = scalar_t(0);
if (!is_low_precision) {
y[i * incy] = 0;
}
}
if (is_low_precision) {
sum[i] += z * column_[i];
} else {
y[i * incy] += z * column_[i];
}
}
}
if (is_low_precision) {
if (beta == scalar_t(0)) {
for (const auto i : c10::irange(m)) {
y[i * incy] = sum[i];
}
} else {
for (const auto i : c10::irange(m)) {
y[i * incy] += sum[i];
}
y[i * incy] += z * column_[i];
}
}
}
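
Since the rendered diff interleaves removed and added lines, here is a condensed, standalone restatement of what the non-transposed reference gemv does for low-precision inputs after this change: partial sums are kept in a float (`opmath_t`) buffer and written back to `y` only once, so no intermediate value is rounded to BF16. This is a simplified sketch with hypothetical names, not the exact ATen source; it assumes the same column-major layout and strides as the code above.

```cpp
#include <cstdint>
#include <vector>

// Sketch: y = alpha * A * x + beta * y for column-major A (m x n),
// accumulating in acc_t (float for BF16/Half) and casting to scalar_t once.
template <typename scalar_t, typename acc_t>
void gemv_notrans_ref(int64_t m, int64_t n, acc_t alpha,
                      const scalar_t* a, int64_t lda,
                      const scalar_t* x, int64_t incx,
                      acc_t beta, scalar_t* y, int64_t incy) {
  std::vector<acc_t> sum(m, acc_t(0));  // float accumulation buffer
  for (int64_t j = 0; j < n; ++j) {
    const scalar_t* column = a + lda * j;
    const acc_t z = alpha * static_cast<acc_t>(x[j * incx]);
    for (int64_t i = 0; i < m; ++i) {
      sum[i] += z * static_cast<acc_t>(column[i]);
    }
  }
  for (int64_t i = 0; i < m; ++i) {
    // With beta == 0 the old contents of y are ignored entirely, so a NaN or
    // Inf already sitting in y is not propagated into the result.
    const acc_t prev = (beta == acc_t(0))
        ? acc_t(0)
        : beta * static_cast<acc_t>(y[i * incy]);
    y[i * incy] = static_cast<scalar_t>(prev + sum[i]);
  }
}
```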
@@ -263,11 +287,12 @@ scalar_t dot_naive(
Functor op) {
// NOLINTNEXTLINE(cppcoreguidelines-init-variables)
int64_t i;
scalar_t sum = 0;
using opmath_t = at::opmath_type<scalar_t>;
opmath_t sum = 0;
for (i = 0; i < n; i++) {
sum += op(x[i * incx], y[i * incy]);
sum += op(static_cast<opmath_t>(x[i * incx]), static_cast<opmath_t>(y[i * incy]));
}
return sum;
return static_cast<scalar_t>(sum);
}
} // namespace blas_impl
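
The change above leans entirely on `at::opmath_type<scalar_t>`, which maps a storage type to the type the arithmetic should be performed in. Roughly, reduced-precision floating types map to `float` and everything else maps to itself; the sketch below is a simplified stand-in with placeholder types, not the real trait from `ATen/OpMathType.h` (which also handles complex types).

```cpp
#include <type_traits>

// Placeholder tag types standing in for at::BFloat16 / at::Half.
struct bfloat16_t {};
struct half_t {};

// Minimal "opmath" mapping: BF16/Half accumulate in float, other types keep
// their own precision.
template <typename T> struct opmath { using type = T; };
template <> struct opmath<bfloat16_t> { using type = float; };
template <> struct opmath<half_t>     { using type = float; };

template <typename T>
using opmath_type = typename opmath<T>::type;

static_assert(std::is_same<opmath_type<bfloat16_t>, float>::value,
              "BF16 math is done in float");
static_assert(std::is_same<opmath_type<double>, double>::value,
              "double keeps accumulating in double");
```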

View File

@@ -1,23 +1,23 @@
#define TORCH_ASSERT_ONLY_METHOD_OPERATORS
#include <ATen/core/Tensor.h>
#include <ATen/Context.h>
#include <ATen/Dispatch.h>
#include <ATen/ExpandUtils.h>
#include <ATen/NamedTensorUtils.h>
#include <ATen/OpMathType.h>
#include <ATen/native/mkldnn/Matmul.h>
#include <ATen/Parallel.h>
#include <ATen/TensorIndexing.h>
#include <ATen/TensorIterator.h>
#include <ATen/TensorOperators.h>
#include <ATen/TensorSubclassLikeUtils.h>
#include <ATen/TensorUtils.h>
#include <ATen/core/Tensor.h>
#include <ATen/native/CPUBlas.h>
#include <ATen/native/LinearAlgebra.h>
#include <ATen/native/LinearAlgebraUtils.h>
#include <ATen/native/ReduceOps.h>
#include <ATen/native/ReduceOpsUtils.h>
#include <ATen/native/Resize.h>
#include <ATen/Parallel.h>
#include <ATen/TensorIndexing.h>
#include <ATen/TensorIterator.h>
#include <ATen/TensorOperators.h>
#include <ATen/TensorUtils.h>
#include <ATen/TensorSubclassLikeUtils.h>
#include <ATen/native/mkldnn/Matmul.h>
#include <c10/util/accumulate.h>
#include <c10/util/irange.h>
#include <c10/util/variant.h>
@@ -1533,14 +1533,16 @@ inline void baddbmm_cpu_kernel(const Tensor& result, const Tensor& self, const T
int64_t js = result.size(2);
int64_t ks = self.size(2);
scalar_t alpha = alpha_.to<scalar_t>();
scalar_t beta = beta_.to<scalar_t>();
using opmath_t = at::opmath_type<scalar_t>;
opmath_t alpha = alpha_.to<opmath_t>();
opmath_t beta = beta_.to<opmath_t>();
auto r0 = result.accessor<scalar_t, 3>();
auto s0 = self.accessor<scalar_t, 3>();
auto m0 = mat2.accessor<scalar_t, 3>();
int64_t grain_size = std::min(internal::GRAIN_SIZE / (is * js * ks), (int64_t)1);
using opmath_t = at::opmath_type<scalar_t>;
parallel_for(0, bs, grain_size, [&](int64_t b_begin, int64_t b_end) {
for (const auto b : c10::irange(b_begin, b_end)) {
auto r1 = r0[b];
@@ -1550,17 +1552,19 @@ inline void baddbmm_cpu_kernel(const Tensor& result, const Tensor& self, const T
auto r2 = r1[i];
auto s2 = s1[i];
for (const auto j : c10::irange(js)) {
scalar_t &r = r2[j];
opmath_t acc_value = 0;//is_bmm ? opmath_t(0) : opmath_t(r2[j]);
for (const auto k : c10::irange(ks)) {
acc_value += static_cast<opmath_t>(s2[k]) *
static_cast<opmath_t>(m1[k][j]);
}
if (is_bmm) {
r = 0;
for (const auto k : c10::irange(ks)) {
r += s2[k] * m1[k][j];
}
r2[j] = acc_value;
} else {
// For beta == 0, the r's value will be ignored, especially for nan value.
r = beta == scalar_t(0) ? scalar_t(0) : beta * r;
for (const auto k : c10::irange(ks)) {
r += alpha * s2[k] * m1[k][j];
if (beta == opmath_t{0}) {
r2[j] = alpha * acc_value;
} else {
r2[j] = static_cast<opmath_t>(r2[j]) * beta + alpha * acc_value;
}
}
}
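
Restated without the diff interleaving: the batched kernel now computes the whole k-reduction in `opmath_t` and converts to the output dtype only on the final store, with the `beta == 0` branch deliberately ignoring whatever value (possibly NaN from an uninitialized buffer) is already in the output. A simplified per-element sketch with hypothetical helper names:

```cpp
#include <cstdint>

// Per-output-element update of the reference baddbmm/bmm kernel:
// r = beta * r + alpha * sum_k self_row[k] * mat2_col[k]   (baddbmm case)
// r = sum_k self_row[k] * mat2_col[k]                      (bmm case)
template <typename scalar_t, typename opmath_t>
void baddbmm_element_ref(scalar_t& r,
                         const scalar_t* self_row,   // self[b][i][:]
                         const scalar_t* mat2_col,   // mat2[b][:][j]
                         int64_t mat2_col_stride,
                         int64_t ks,
                         opmath_t alpha, opmath_t beta, bool is_bmm) {
  opmath_t acc = 0;  // float accumulator for BF16/Half inputs
  for (int64_t k = 0; k < ks; ++k) {
    acc += static_cast<opmath_t>(self_row[k]) *
           static_cast<opmath_t>(mat2_col[k * mat2_col_stride]);
  }
  if (is_bmm) {
    r = static_cast<scalar_t>(acc);
  } else if (beta == opmath_t(0)) {
    // beta == 0: do not read r at all, so NaN/Inf in r cannot leak through.
    r = static_cast<scalar_t>(alpha * acc);
  } else {
    r = static_cast<scalar_t>(beta * static_cast<opmath_t>(r) + alpha * acc);
  }
}
```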

View File

@@ -53,15 +53,20 @@ auto sum(int64_t N, Func f) {
return partial_sums[0];
}
template <typename scalar_t, typename opmath_t>
void gemm_notrans_(
int64_t m, int64_t n, int64_t k,
typename std::enable_if<std::is_same<scalar_t, opmath_t>::value, void>::type
gemm_notrans_(
int64_t m,
int64_t n,
int64_t k,
opmath_t alpha,
const scalar_t *a, int64_t lda,
const scalar_t *b, int64_t ldb,
const scalar_t* a,
int64_t lda,
const scalar_t* b,
int64_t ldb,
opmath_t beta,
scalar_t *c, int64_t ldc) {
scalar_t* c,
int64_t ldc) {
// c *= beta
scale_(m, n, beta, c, ldc);
@@ -83,6 +88,37 @@ void gemm_notrans_(
}
}
// std::is_same<scalar_t, at::BFloat16> || std::is_same<scalar_t, at::Half>
template <typename scalar_t, typename opmath_t>
typename std::enable_if<!std::is_same<scalar_t, opmath_t>::value, void>::type
gemm_notrans_(
int64_t m,
int64_t n,
int64_t k,
opmath_t alpha,
const scalar_t* a,
int64_t lda,
const scalar_t* b,
int64_t ldb,
opmath_t beta,
scalar_t* c,
int64_t ldc) {
// c += alpha * (a @ b)
for (const auto i : c10::irange(m)) {
for (const auto j : c10::irange(n)) {
const auto dot = sum(k, [&](int64_t l) -> opmath_t {
return static_cast<opmath_t>(a[l * lda + i]) *
static_cast<opmath_t>(b[j * ldb + l]);
});
if (beta == opmath_t(0)) {
c[j * ldc + i] = alpha * dot;
} else {
c[j * ldc + i] = beta * c[j * ldc + i] + alpha * dot;
}
}
}
}
template <typename scalar_t, typename opmath_t>
void gemm_transa_(
int64_t m, int64_t n, int64_t k,
@@ -111,13 +147,19 @@ void gemm_transa_(
}
template <typename scalar_t, typename opmath_t>
void gemm_transb_(
int64_t m, int64_t n, int64_t k,
typename std::enable_if<std::is_same<scalar_t, opmath_t>::value, void>::type
gemm_transb_(
int64_t m,
int64_t n,
int64_t k,
opmath_t alpha,
const scalar_t *a, int64_t lda,
const scalar_t *b, int64_t ldb,
const scalar_t* a,
int64_t lda,
const scalar_t* b,
int64_t ldb,
opmath_t beta,
scalar_t *c, int64_t ldc) {
scalar_t* c,
int64_t ldc) {
// c *= beta
scale_(m, n, beta, c, ldc);
@@ -139,6 +181,37 @@ void gemm_transb_(
}
}
// std::is_same<scalar_t, at::BFloat16> || std::is_same<scalar_t, at::Half>
template <typename scalar_t, typename opmath_t>
typename std::enable_if<!std::is_same<scalar_t, opmath_t>::value, void>::type
gemm_transb_(
int64_t m,
int64_t n,
int64_t k,
opmath_t alpha,
const scalar_t* a,
int64_t lda,
const scalar_t* b,
int64_t ldb,
opmath_t beta,
scalar_t* c,
int64_t ldc) {
// c += alpha * (a @ b.T)
for (const auto i : c10::irange(m)) {
for (const auto j : c10::irange(n)) {
const auto dot = sum(k, [&](int64_t l) -> opmath_t {
return static_cast<opmath_t>(a[l * lda + i]) *
static_cast<opmath_t>(b[l * ldb + j]);
});
if (beta == opmath_t(0)) {
c[j * ldc + i] = alpha * dot;
} else {
c[j * ldc + i] = beta * c[j * ldc + i] + alpha * dot;
}
}
}
}
template <typename scalar_t, typename opmath_t>
void gemm_transab_(
int64_t m, int64_t n, int64_t k,
@@ -173,13 +246,19 @@ void gemm_core_(
const scalar_t *b, int64_t ldb,
opmath_t beta,
scalar_t *c, int64_t ldc) {
if(transa == TransposeType::NoTranspose && transb == TransposeType::NoTranspose) {
if (transa == TransposeType::NoTranspose &&
transb == TransposeType::NoTranspose) {
return gemm_notrans_(m, n, k, alpha, a, lda, b, ldb, beta, c, ldc);
} else if(transa == TransposeType::Transpose && transb != TransposeType::Transpose) {
} else if (
transa == TransposeType::Transpose &&
transb != TransposeType::Transpose) {
gemm_transa_(m, n, k, alpha, a, lda, b, ldb, beta, c, ldc);
} else if(transa == TransposeType::NoTranspose && transb == TransposeType::Transpose) {
} else if (
transa == TransposeType::NoTranspose &&
transb == TransposeType::Transpose) {
gemm_transb_(m, n, k, alpha, a, lda, b, ldb, beta, c, ldc);
} else { // transa == TransposeType::Transpose && transb == TransposeType::Transpose
} else { // transa == TransposeType::Transpose && transb ==
// TransposeType::Transpose
gemm_transab_(m, n, k, alpha, a, lda, b, ldb, beta, c, ldc);
}
}
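
Both `gemm_notrans_` and `gemm_transb_` are now split into two overloads selected with `std::enable_if` on whether `scalar_t` already equals its opmath type: the original in-place scale-then-accumulate version still serves float/double, while the new overload (taken for BF16/Half, where `opmath_t` is float) computes each dot product in float and writes `c` exactly once. A stripped-down sketch of that selection pattern, with hypothetical names unrelated to the ATen functions:

```cpp
#include <iostream>
#include <type_traits>

// Chosen when scalar_t == opmath_t (e.g. float/float): accumulate in place.
template <typename scalar_t, typename opmath_t>
typename std::enable_if<std::is_same<scalar_t, opmath_t>::value, void>::type
fma_store(scalar_t& c, scalar_t a, scalar_t b) {
  c += a * b;
}

// Chosen when scalar_t != opmath_t (e.g. BF16/float): do the math in
// opmath_t and round only on the final store.
template <typename scalar_t, typename opmath_t>
typename std::enable_if<!std::is_same<scalar_t, opmath_t>::value, void>::type
fma_store(scalar_t& c, scalar_t a, scalar_t b) {
  opmath_t acc = static_cast<opmath_t>(c) +
                 static_cast<opmath_t>(a) * static_cast<opmath_t>(b);
  c = static_cast<scalar_t>(acc);
}

int main() {
  float c1 = 0.f;
  fma_store<float, float>(c1, 2.f, 3.f);   // first overload (same types)
  float c2 = 0.f;
  fma_store<float, double>(c2, 2.f, 3.f);  // second overload (wider opmath)
  std::cout << c1 << " " << c2 << "\n";    // prints: 6 6
}
```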

View File

@@ -7347,6 +7347,53 @@ scipy_lobpcg | {:10.2e} | {:10.2e} | {:6} | N/A
c = a.permute(0, 1, 3, 2).matmul(b)
self.assertEqual([c.min(), c.max(), c.sum()], [24, 24, 414720])
def test_bfloat16_accumulation_with_ref_path(self):
# fix https://github.com/pytorch/pytorch/issues/95125
# and https://github.com/pytorch/pytorch/issues/83863
# for bf16 accumulation in gemm ref path
def check_correctness(fn, *args):
expected = fn(*args).bfloat16()
with torch.backends.mkldnn.flags(enabled=False):
def test():
bf16_args = (arg.bfloat16() for arg in args)
tmp_result = fn(*bf16_args)
return tmp_result
c = test()
assert (torch.all(c == expected)), "Incorrect result with\n" \
f"expected: {expected}\n" \
f"got: {c}\n"
# test matmul
for transa in [True, False]:
for transb in [True, False]:
a = torch.ones(300, 300)
b = torch.ones(300, 300)
if transa:
a = a.transpose(0, 1).contiguous().transpose(0, 1)
if transb:
b = b.transpose(0, 1).contiguous().transpose(0, 1)
check_correctness(torch.matmul, a, b)
# test bmm
a = torch.ones(1, 1, 300)
b = torch.ones(1, 300, 1)
check_correctness(torch.bmm, a, b)
# test baddbmm
a = torch.ones(1, 1, 300)
b = torch.ones(1, 300, 1)
c = torch.ones(1, 1, 1)
check_correctness(torch.baddbmm, c, a, b)
# test mv/addmv
for trans in [True, False]:
c = torch.ones(300) * -300
a = torch.ones(300, 300)
if trans:
a = a.transpose(0, 1).contiguous().transpose(0, 1)
b = torch.ones(300)
check_correctness(torch.mv, a, b)
check_correctness(torch.addmv, c, a, b)
# test dot
a = torch.ones(300)
b = torch.ones(300)
check_correctness(torch.dot, a, b)
instantiate_device_type_tests(TestLinalg, globals())

View File

@@ -2466,6 +2466,7 @@ class TestSparseCSR(TestCase):
@onlyCPU
@dtypes(torch.float32, torch.float64, torch.bfloat16)
@precisionOverride({torch.bfloat16: 0.01})
def test_sparse_mm_reduce_sum(self, device, dtype):
def run_test(m, n, k, nnz, train):
sparse = self.genSparseCSRTensor((m, k), nnz, dtype=dtype, device=device, index_dtype=torch.int64)