Mirror of https://github.com/pytorch/pytorch.git (synced 2025-10-20 21:14:14 +08:00)
Use accumulate type in BF16 gemm (including dot, mv) ref path (#96074)
Fixes https://github.com/pytorch/pytorch/issues/95125 and https://github.com/pytorch/pytorch/issues/83863: wrong results from BF16 accumulation in the gemm reference path.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/96074
Approved by: https://github.com/lezcano, https://github.com/peterbell10
Committed by: PyTorch MergeBot
parent: b45880c537
commit: fe0afc5852
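Why the accumulate type matters: bfloat16 keeps an 8-bit significand, so every representable value in [256, 512) is an even integer, and 256.0 + 1.0 rounds straight back to 256.0. A dot product over 300 ones therefore stalls at 256 when partial sums are kept in bfloat16, which is the wrong-result behavior the linked issues report. A minimal repro sketch (assumes a CPU build; disabling the mkldnn fast path forces the reference path, exactly as the new test below does):

    import torch

    # Dot product of 300 ones: with bf16 partial sums the running total
    # saturates at 256 (256 + 1 rounds back to 256 in bf16); with float
    # (opmath_t) accumulation the result is exactly 300.
    a = torch.ones(300, dtype=torch.bfloat16)
    b = torch.ones(300, dtype=torch.bfloat16)
    with torch.backends.mkldnn.flags(enabled=False):
        print(torch.dot(a, b))  # 256 before this fix, 300 after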
@@ -1,13 +1,14 @@
 #define TORCH_ASSERT_ONLY_METHOD_OPERATORS
-#include <limits>
-#include <algorithm>
-#include <climits>
 #include <ATen/Config.h>
+#include <ATen/OpMathType.h>
 #include <c10/core/ScalarType.h>
 #include <c10/util/irange.h>
 #include <c10/util/Exception.h>
 #include <c10/util/complex.h>
 
-#include <c10/util/irange.h>
+#include <algorithm>
+#include <climits>
+#include <iostream>
+#include <limits>
 #if AT_BUILD_WITH_BLAS()
 extern "C" double ddot_(int *n, double *x, int *incx, double *y, int *incy);
 extern "C" void dscal_(int *n, double *a, double *x, int *incx);
@@ -180,9 +181,10 @@ void gemv(char trans, int64_t m, int64_t n, scalar_t alpha, scalar_t *a, int64_t
     return;
   }
 
+  using opmath_t = at::opmath_type<scalar_t>;
   if ((trans == 'T') || (trans == 't')) {
     for (const auto i : c10::irange(n)) {
-      scalar_t sum = 0;
+      opmath_t sum = 0;
       scalar_t *row_ = a + lda * i;
       for (const auto j : c10::irange(m)) {
         sum += x[j * incx] * row_[j];
@@ -196,15 +198,37 @@
   } else {
     if (beta != scalar_t(1) && beta != scalar_t(0)) scal<scalar_t>(m, beta, y, incy);
 
+    bool is_low_precision = !std::is_same<opmath_t, scalar_t>::value;
+    std::vector<opmath_t> sum;
+    if (is_low_precision) {
+      sum.resize(m);
+    }
     for (const auto j : c10::irange(n)) {
       scalar_t *column_ = a + lda * j;
-      scalar_t z = alpha * x[j * incx];
+      opmath_t z = alpha * static_cast<opmath_t>(x[j * incx]);
       for (const auto i : c10::irange(m)) {
         //output values are ignored if beta is 0, and set to 0, nans and infs are not propagated
         if (j==0 && beta==scalar_t(0)) {
-          y[i * incy] = scalar_t(0);
+          if (!is_low_precision) {
+            y[i * incy] = 0;
+          }
         }
-        y[i * incy] += z * column_[i];
+        if (is_low_precision) {
+          sum[i] += z * column_[i];
+        } else {
+          y[i * incy] += z * column_[i];
+        }
       }
     }
+    if (is_low_precision) {
+      if (beta == scalar_t(0)) {
+        for (const auto i : c10::irange(m)) {
+          y[i * incy] = sum[i];
+        }
+      } else {
+        for (const auto i : c10::irange(m)) {
+          y[i * incy] += sum[i];
+        }
+      }
+    }
   }
@@ -263,11 +287,12 @@ scalar_t dot_naive(
     Functor op) {
   // NOLINTNEXTLINE(cppcoreguidelines-init-variables)
   int64_t i;
-  scalar_t sum = 0;
+  using opmath_t = at::opmath_type<scalar_t>;
+  opmath_t sum = 0;
   for (i = 0; i < n; i++) {
-    sum += op(x[i * incx], y[i * incy]);
+    sum += op(static_cast<opmath_t>(x[i * incx]), static_cast<opmath_t>(y[i * incy]));
   }
-  return sum;
+  return static_cast<scalar_t>(sum);
 }
 
 } // namespace blas_impl
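The dot_naive change above is the whole fix in miniature: keep the inputs in scalar_t, do every multiply-add in at::opmath_type<scalar_t> (float for at::BFloat16 and at::Half), and round back to scalar_t exactly once. A standalone sketch of that pattern (simplified, not the verbatim PyTorch source; dot_ref is a hypothetical name):

    #include <ATen/OpMathType.h>

    #include <cstdint>

    // Accumulate-in-opmath_t pattern: products and the running sum are
    // computed in opmath_t; only the final value is rounded to scalar_t.
    template <typename scalar_t>
    scalar_t dot_ref(const scalar_t* x, const scalar_t* y, int64_t n) {
      using opmath_t = at::opmath_type<scalar_t>;
      opmath_t sum = 0;
      for (int64_t i = 0; i < n; i++) {
        sum += static_cast<opmath_t>(x[i]) * static_cast<opmath_t>(y[i]);
      }
      return static_cast<scalar_t>(sum);  // single rounding step
    }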
@@ -1,23 +1,23 @@
 #define TORCH_ASSERT_ONLY_METHOD_OPERATORS
-#include <ATen/core/Tensor.h>
 #include <ATen/Context.h>
 #include <ATen/Dispatch.h>
 #include <ATen/ExpandUtils.h>
 #include <ATen/NamedTensorUtils.h>
+#include <ATen/OpMathType.h>
+#include <ATen/native/mkldnn/Matmul.h>
+#include <ATen/Parallel.h>
+#include <ATen/TensorIndexing.h>
+#include <ATen/TensorIterator.h>
+#include <ATen/TensorOperators.h>
+#include <ATen/TensorSubclassLikeUtils.h>
+#include <ATen/TensorUtils.h>
+#include <ATen/core/Tensor.h>
 #include <ATen/native/CPUBlas.h>
 #include <ATen/native/LinearAlgebra.h>
 #include <ATen/native/LinearAlgebraUtils.h>
 #include <ATen/native/ReduceOps.h>
 #include <ATen/native/ReduceOpsUtils.h>
 #include <ATen/native/Resize.h>
-#include <ATen/Parallel.h>
-#include <ATen/TensorIndexing.h>
-#include <ATen/TensorIterator.h>
-#include <ATen/TensorOperators.h>
-#include <ATen/TensorUtils.h>
-#include <ATen/TensorSubclassLikeUtils.h>
-#include <ATen/native/mkldnn/Matmul.h>
 #include <c10/util/accumulate.h>
 #include <c10/util/irange.h>
 #include <c10/util/variant.h>
@@ -1533,14 +1533,16 @@ inline void baddbmm_cpu_kernel(const Tensor& result, const Tensor& self, const T
   int64_t js = result.size(2);
   int64_t ks = self.size(2);
 
-  scalar_t alpha = alpha_.to<scalar_t>();
-  scalar_t beta = beta_.to<scalar_t>();
+  using opmath_t = at::opmath_type<scalar_t>;
+  opmath_t alpha = alpha_.to<opmath_t>();
+  opmath_t beta = beta_.to<opmath_t>();
 
   auto r0 = result.accessor<scalar_t, 3>();
   auto s0 = self.accessor<scalar_t, 3>();
   auto m0 = mat2.accessor<scalar_t, 3>();
 
   int64_t grain_size = std::min(internal::GRAIN_SIZE / (is * js * ks), (int64_t)1);
-  using opmath_t = at::opmath_type<scalar_t>;
   parallel_for(0, bs, grain_size, [&](int64_t b_begin, int64_t b_end) {
     for (const auto b : c10::irange(b_begin, b_end)) {
       auto r1 = r0[b];
@@ -1550,17 +1552,19 @@ inline void baddbmm_cpu_kernel(const Tensor& result, const Tensor& self, const T
       auto r2 = r1[i];
       auto s2 = s1[i];
       for (const auto j : c10::irange(js)) {
-        scalar_t &r = r2[j];
+        opmath_t acc_value = 0;//is_bmm ? opmath_t(0) : opmath_t(r2[j]);
+        for (const auto k : c10::irange(ks)) {
+          acc_value += static_cast<opmath_t>(s2[k]) *
+              static_cast<opmath_t>(m1[k][j]);
+        }
         if (is_bmm) {
-          r = 0;
-          for (const auto k : c10::irange(ks)) {
-            r += s2[k] * m1[k][j];
-          }
+          r2[j] = acc_value;
         } else {
           // For beta == 0, the r's value will be ignored, especially for nan value.
-          r = beta == scalar_t(0) ? scalar_t(0) : beta * r;
-          for (const auto k : c10::irange(ks)) {
-            r += alpha * s2[k] * m1[k][j];
-          }
+          if (beta == opmath_t{0}) {
+            r2[j] = alpha * acc_value;
+          } else {
+            r2[j] = static_cast<opmath_t>(r2[j]) * beta + alpha * acc_value;
+          }
         }
       }
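The beta == 0 branch kept above is load-bearing: torch.baddbmm documents that when beta is 0 the input is ignored and NaN or inf in it is not propagated, so the acc_value rewrite must not fold r2[j] into the result in that case. A quick check of that contract (a sketch; exact print formatting may differ):

    import torch

    # With beta=0 the input tensor is ignored by contract, so a NaN in it
    # must not leak into the output.
    inp = torch.full((1, 1, 1), float("nan"))
    a = torch.ones(1, 1, 4)
    b = torch.ones(1, 4, 1)
    print(torch.baddbmm(inp, a, b, beta=0.0))  # tensor([[[4.]]]), no NaN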
@@ -53,15 +53,20 @@ auto sum(int64_t N, Func f) {
   return partial_sums[0];
 }
 
 template <typename scalar_t, typename opmath_t>
-void gemm_notrans_(
-    int64_t m, int64_t n, int64_t k,
+typename std::enable_if<std::is_same<scalar_t, opmath_t>::value, void>::type
+gemm_notrans_(
+    int64_t m,
+    int64_t n,
+    int64_t k,
     opmath_t alpha,
-    const scalar_t *a, int64_t lda,
-    const scalar_t *b, int64_t ldb,
+    const scalar_t* a,
+    int64_t lda,
+    const scalar_t* b,
+    int64_t ldb,
     opmath_t beta,
-    scalar_t *c, int64_t ldc) {
+    scalar_t* c,
+    int64_t ldc) {
   // c *= beta
   scale_(m, n, beta, c, ldc);
@@ -83,6 +88,37 @@ void gemm_notrans_(
   }
 }
 
+// std::is_same<scalar_t, at::BFloat16> || std::is_same<scalar_t, at::Half>
+template <typename scalar_t, typename opmath_t>
+typename std::enable_if<!std::is_same<scalar_t, opmath_t>::value, void>::type
+gemm_notrans_(
+    int64_t m,
+    int64_t n,
+    int64_t k,
+    opmath_t alpha,
+    const scalar_t* a,
+    int64_t lda,
+    const scalar_t* b,
+    int64_t ldb,
+    opmath_t beta,
+    scalar_t* c,
+    int64_t ldc) {
+  // c += alpha * (a @ b)
+  for (const auto i : c10::irange(m)) {
+    for (const auto j : c10::irange(n)) {
+      const auto dot = sum(k, [&](int64_t l) -> opmath_t {
+        return static_cast<opmath_t>(a[l * lda + i]) *
+            static_cast<opmath_t>(b[j * ldb + l]);
+      });
+      if (beta == opmath_t(0)) {
+        c[j * ldc + i] = alpha * dot;
+      } else {
+        c[j * ldc + i] = beta * c[j * ldc + i] + alpha * dot;
+      }
+    }
+  }
+}
+
 template <typename scalar_t, typename opmath_t>
 void gemm_transa_(
     int64_t m, int64_t n, int64_t k,
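The pair of gemm_notrans_ definitions above (and gemm_transb_ below) dispatch by SFINAE: std::enable_if leaves exactly one overload viable, depending on whether the storage type already equals the math type. A minimal standalone sketch of the mechanism (kernel is a hypothetical name, not PyTorch code):

    #include <iostream>
    #include <type_traits>

    // Chosen when accumulation can happen directly in the storage type.
    template <typename scalar_t, typename opmath_t>
    typename std::enable_if<std::is_same<scalar_t, opmath_t>::value, void>::type
    kernel() { std::cout << "full-precision path\n"; }

    // Chosen for low-precision storage (e.g. bf16 stored, float math).
    template <typename scalar_t, typename opmath_t>
    typename std::enable_if<!std::is_same<scalar_t, opmath_t>::value, void>::type
    kernel() { std::cout << "accumulate in opmath_t, round once\n"; }

    int main() {
      kernel<float, float>();  // first overload
      kernel<short, float>();  // second overload (stand-in for bf16/float)
    }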
@@ -111,13 +147,19 @@ void gemm_transa_(
 }
 
 template <typename scalar_t, typename opmath_t>
-void gemm_transb_(
-    int64_t m, int64_t n, int64_t k,
+typename std::enable_if<std::is_same<scalar_t, opmath_t>::value, void>::type
+gemm_transb_(
+    int64_t m,
+    int64_t n,
+    int64_t k,
     opmath_t alpha,
-    const scalar_t *a, int64_t lda,
-    const scalar_t *b, int64_t ldb,
+    const scalar_t* a,
+    int64_t lda,
+    const scalar_t* b,
+    int64_t ldb,
     opmath_t beta,
-    scalar_t *c, int64_t ldc) {
+    scalar_t* c,
+    int64_t ldc) {
   // c *= beta
   scale_(m, n, beta, c, ldc);
@@ -139,6 +181,37 @@ void gemm_transb_(
   }
 }
 
+// std::is_same<scalar_t, at::BFloat16> || std::is_same<scalar_t, at::Half>
+template <typename scalar_t, typename opmath_t>
+typename std::enable_if<!std::is_same<scalar_t, opmath_t>::value, void>::type
+gemm_transb_(
+    int64_t m,
+    int64_t n,
+    int64_t k,
+    opmath_t alpha,
+    const scalar_t* a,
+    int64_t lda,
+    const scalar_t* b,
+    int64_t ldb,
+    opmath_t beta,
+    scalar_t* c,
+    int64_t ldc) {
+  // c += alpha * (a @ b.T)
+  for (const auto i : c10::irange(m)) {
+    for (const auto j : c10::irange(n)) {
+      const auto dot = sum(k, [&](int64_t l) -> opmath_t {
+        return static_cast<opmath_t>(a[l * lda + i]) *
+            static_cast<opmath_t>(b[l * ldb + j]);
+      });
+      if (beta == opmath_t(0)) {
+        c[j * ldc + i] = alpha * dot;
+      } else {
+        c[j * ldc + i] = beta * c[j * ldc + i] + alpha * dot;
+      }
+    }
+  }
+}
+
 template <typename scalar_t, typename opmath_t>
 void gemm_transab_(
     int64_t m, int64_t n, int64_t k,
@@ -173,13 +246,19 @@ void gemm_core_(
     const scalar_t *b, int64_t ldb,
     opmath_t beta,
     scalar_t *c, int64_t ldc) {
-  if(transa == TransposeType::NoTranspose && transb == TransposeType::NoTranspose) {
+  if (transa == TransposeType::NoTranspose &&
+      transb == TransposeType::NoTranspose) {
     return gemm_notrans_(m, n, k, alpha, a, lda, b, ldb, beta, c, ldc);
-  } else if(transa == TransposeType::Transpose && transb != TransposeType::Transpose) {
+  } else if (
+      transa == TransposeType::Transpose &&
+      transb != TransposeType::Transpose) {
     gemm_transa_(m, n, k, alpha, a, lda, b, ldb, beta, c, ldc);
-  } else if(transa == TransposeType::NoTranspose && transb == TransposeType::Transpose) {
+  } else if (
+      transa == TransposeType::NoTranspose &&
+      transb == TransposeType::Transpose) {
     gemm_transb_(m, n, k, alpha, a, lda, b, ldb, beta, c, ldc);
-  } else { // transa == TransposeType::Transpose && transb == TransposeType::Transpose
+  } else { // transa == TransposeType::Transpose && transb ==
+           // TransposeType::Transpose
     gemm_transab_(m, n, k, alpha, a, lda, b, ldb, beta, c, ldc);
   }
 }
@@ -7347,6 +7347,53 @@ scipy_lobpcg | {:10.2e} | {:10.2e} | {:6} | N/A
         c = a.permute(0, 1, 3, 2).matmul(b)
         self.assertEqual([c.min(), c.max(), c.sum()], [24, 24, 414720])
 
+    def test_bfloat16_accumulation_with_ref_path(self):
+        # fix https://github.com/pytorch/pytorch/issues/95125
+        # and https://github.com/pytorch/pytorch/issues/83863
+        # for bf16 accumulation in gemm ref path
+        def check_correctness(fn, *args):
+            expected = fn(*args).bfloat16()
+            with torch.backends.mkldnn.flags(enabled=False):
+                def test():
+                    bf16_args = (arg.bfloat16() for arg in args)
+                    tmp_result = fn(*bf16_args)
+                    return tmp_result
+                c = test()
+                assert (torch.all(c == expected)), "Incorrect result with\n" \
+                    f"expected: {expected}\n" \
+                    f"got: {c}\n"
+        # test matmul
+        for transa in [True, False]:
+            for transb in [True, False]:
+                a = torch.ones(300, 300)
+                b = torch.ones(300, 300)
+                if transa:
+                    a = a.transpose(0, 1).contiguous().transpose(0, 1)
+                if transb:
+                    b = b.transpose(0, 1).contiguous().transpose(0, 1)
+                check_correctness(torch.matmul, a, b)
+        # test bmm
+        a = torch.ones(1, 1, 300)
+        b = torch.ones(1, 300, 1)
+        check_correctness(torch.bmm, a, b)
+        # test baddbmm
+        a = torch.ones(1, 1, 300)
+        b = torch.ones(1, 300, 1)
+        c = torch.ones(1, 1, 1)
+        check_correctness(torch.baddbmm, c, a, b)
+        # test mv/addmv
+        for trans in [True, False]:
+            c = torch.ones(300) * -300
+            a = torch.ones(300, 300)
+            if trans:
+                a = a.transpose(0, 1).contiguous().transpose(0, 1)
+            b = torch.ones(300)
+            check_correctness(torch.mv, a, b)
+            check_correctness(torch.addmv, c, a, b)
+        # test dot
+        a = torch.ones(300)
+        b = torch.ones(300)
+        check_correctness(torch.dot, a, b)
+
 instantiate_device_type_tests(TestLinalg, globals())
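The test vectors above are length 300 on purpose: 300 is just past 256, the point where a bfloat16 running sum of ones stops growing, so any kernel that accumulates in bfloat16 fails the exact-equality check, while float accumulation gives exactly 300 (an even integer in [256, 512), hence representable in bfloat16). A manual spot check of the mv case (a sketch; assumes a CPU build):

    import torch

    # Each mv output element is a dot product of 300 ones: exactly 300 with
    # float accumulation, 256 if partial sums stay in bfloat16.
    a = torch.ones(300, 300).bfloat16()
    b = torch.ones(300).bfloat16()
    with torch.backends.mkldnn.flags(enabled=False):
        print(torch.mv(a, b).unique())  # tensor([300.], dtype=torch.bfloat16)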
@@ -2466,6 +2466,7 @@ class TestSparseCSR(TestCase):
 
     @onlyCPU
     @dtypes(torch.float32, torch.float64, torch.bfloat16)
+    @precisionOverride({torch.bfloat16: 0.01})
     def test_sparse_mm_reduce_sum(self, device, dtype):
         def run_test(m, n, k, nnz, train):
             sparse = self.genSparseCSRTensor((m, k), nnz, dtype=dtype, device=device, index_dtype=torch.int64)