[CPU] Fix memory access for sbgemm bf16 (#156585)

Fixes #156022.

1. The original dtype conversion read and wrote the whole `n_*ldc_` contiguous range instead of the `n_*m_` valid elements strided by `ldc_`. Whenever `ldc_ > m_` this goes out of bounds of the logical C matrix, causing the potential memory issue (see the sketch after this list).
2. Fix the `None` gradient values in the attention backward unit test, since the sbgemm bf16 path can now be exercised there.
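
For intuition, here is a minimal, self-contained C++ sketch of the indexing involved (hypothetical sizes, plain floats, not the ATen code): with a column-major m x n output C whose leading dimension ldc is larger than m, C's valid storage spans only (n - 1) * ldc + m elements, so a contiguous copy of n * ldc values overruns it, while a strided copy over n * m elements stays inside the m valid rows of every column.

// Minimal, self-contained sketch with hypothetical sizes (not the ATen code):
// C is a column-major m x n view with leading dimension ldc > m, so its valid
// storage spans only (n - 1) * ldc + m elements.
#include <cstdio>
#include <vector>

int main() {
  const int m = 3, n = 2, ldc = 5;             // assumed sizes with ldc > m
  const int valid_extent = (n - 1) * ldc + m;  // 8: the last valid element is c[7]

  // A contiguous copy of n * ldc = 10 elements runs 2 past the end of C and
  // also touches the ldc - m padding rows inside each column that C does not own.
  std::printf("contiguous copy: %d elements, valid extent: %d\n", n * ldc, valid_extent);

  // Strided copy-back: n * m results from a dense temporary buffer (leading
  // dimension m) are written only to the m valid rows of each column.
  std::vector<float> temp(n * m, 1.0f);        // stands in for the float sbgemm output
  std::vector<float> c(valid_extent, -1.0f);   // -1 marks memory the copy must not touch
  for (int j = 0; j < n; ++j)
    for (int i = 0; i < m; ++i)
      c[j * ldc + i] = temp[j * m + i];

  for (int idx = 0; idx < valid_extent; ++idx)
    std::printf("c[%d] = %g\n", idx, c[idx]);  // c[3] and c[4] keep the sentinel -1
  return 0;
}

The old code also wrote to those padding slots and past the end of the buffer, which is the potential memory issue the commit message refers to.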

Pull Request resolved: https://github.com/pytorch/pytorch/pull/156585
Approved by: https://github.com/mingfeima, https://github.com/aditew01, https://github.com/ezyang
Author: Valentine233
Date: 2025-07-08 02:36:24 +00:00
Committed by: PyTorch MergeBot
Parent: 12f9942b10
Commit: f56bfb3030
2 changed files with 34 additions and 26 deletions


@@ -358,18 +358,25 @@ void gemm(
       int m_ = m, n_ = n, k_ = k, lda_ = lda, ldb_ = ldb, ldc_ = ldc;
       char transa_ = to_blas(transa), transb_ = to_blas(transb);
       float alpha_ = alpha, beta_ = beta;
-      int c_size = n_ * ldc_;
+      int c_size = n_ * m_;
       // C matrix in OpenBLAS sbgemm are of type "float" so we have to convert, copy and copy back.
-      std::vector<float> float_v(c, c + c_size);
+      std::vector<float> float_v(c_size, 0.0f);
+      for (const auto j : c10::irange(n)) {
+        for (const auto i : c10::irange(m)) {
+          float_v[j * m_ + i] = c10::convert<float>(c[j * ldc_ + i]);
+        }
+      }
       sbgemm_(&transa_, &transb_,
               &m_, &n_, &k_,
               &alpha_,
               a, &lda_,
               b, &ldb_,
               &beta_,
-              float_v.data(), &ldc_);
-      for (auto cv: float_v) {
-        *(c++) = c10::convert<at::BFloat16>(cv);
+              float_v.data(), &m_);
+      for (const auto j : c10::irange(n)) {
+        for (const auto i : c10::irange(m)) {
+          c[j * ldc_ + i] = c10::convert<at::BFloat16>(float_v[j * m_ + i]);
+        }
       }
       return;
     }
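
To make the shape of the fix easier to follow outside the diff, the sketch below mirrors the pack / compute / unpack pattern under simplifying assumptions: plain float instead of BFloat16, a naive reference loop standing in for OpenBLAS's sbgemm_, no transpose handling, and hypothetical helper names. The point it illustrates is that the temporary buffer is dense, so the inner GEMM is told its leading dimension is m (the counterpart of the call now passing &m_ instead of &ldc_), while C keeps its original stride ldc only in the pack and unpack loops.

// Rough standalone sketch of the pack / compute / unpack pattern (not the ATen code).
#include <cstdio>
#include <vector>

// Naive column-major C = alpha * A * B + beta * C; stands in for sbgemm_.
static void naive_gemm(int m, int n, int k, float alpha,
                       const float* a, int lda, const float* b, int ldb,
                       float beta, float* c, int ldc) {
  for (int j = 0; j < n; ++j)
    for (int i = 0; i < m; ++i) {
      float acc = 0.0f;
      for (int p = 0; p < k; ++p)
        acc += a[p * lda + i] * b[j * ldb + p];
      c[j * ldc + i] = alpha * acc + beta * c[j * ldc + i];
    }
}

// Hypothetical helper mirroring the fixed flow: the temp buffer is dense, so
// the GEMM runs with leading dimension m; C is touched only with stride ldc.
static void gemm_via_dense_buffer(int m, int n, int k, float alpha,
                                  const float* a, int lda, const float* b, int ldb,
                                  float beta, float* c, int ldc) {
  std::vector<float> dense(static_cast<size_t>(n) * m, 0.0f);
  for (int j = 0; j < n; ++j)                 // pack: only the m valid rows per column
    for (int i = 0; i < m; ++i)
      dense[j * m + i] = c[j * ldc + i];
  naive_gemm(m, n, k, alpha, a, lda, b, ldb, beta, dense.data(), /*ldc=*/m);
  for (int j = 0; j < n; ++j)                 // unpack with the original stride
    for (int i = 0; i < m; ++i)
      c[j * ldc + i] = dense[j * m + i];
}

int main() {
  // Tiny usage example: a 2x2 result written into a 2x2 block of a 3-row parent (ldc = 3).
  const int m = 2, n = 2, k = 2, ldc = 3;
  std::vector<float> a = {1, 2, 3, 4};        // 2x2, column-major, lda = 2
  std::vector<float> b = {5, 6, 7, 8};        // 2x2, column-major, ldb = 2
  std::vector<float> c((n - 1) * ldc + m, 0.0f);
  gemm_via_dense_buffer(m, n, k, /*alpha=*/1.0f, a.data(), 2, b.data(), 2,
                        /*beta=*/0.0f, c.data(), ldc);
  for (int j = 0; j < n; ++j)
    for (int i = 0; i < m; ++i)
      std::printf("C(%d,%d) = %g\n", i, j, c[j * ldc + i]);  // 23 34 31 46
  return 0;
}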


@@ -2108,19 +2108,28 @@ class TestSDPACpuOnly(NNTestCase):
             tol = Tolerances(5e-2, 5e-2)
         if dtype is torch.float16:
             tol = Tolerances(1e-2, 1e-2)
+        tol_grad = Tolerances(1e-5, 5e-6)
+        if dtype is torch.bfloat16:
+            tol_grad = Tolerances(5e-2, 5e-2)
+        if dtype is torch.float16:
+            tol_grad = Tolerances(1e-1, 1e-1)
         for mask_shape in itertools.product(
             [q_seq_len, 1], [kv_seq_len, 1]
         ) if mask_dim == 2 else itertools.product(
             [batch_size, 1], [n_head, 1], [q_seq_len, 1], [kv_seq_len, 1]
         ):
-            make_tensor = partial(rand_sdpa_tensor, type="dense", device=device, dtype=dtype, requires_grad=False)
-            q_shape = SdpaShape(batch_size, n_head, q_seq_len, head_dim)
-            kv_shape = SdpaShape(batch_size, n_head, kv_seq_len, head_dim)
-            q = make_tensor(q_shape)
-            k = make_tensor(kv_shape)
-            v = make_tensor(kv_shape)
-            q2, k2, v2 = q.clone(), k.clone(), v.clone()
+            def sdpa_helper():
+                torch.manual_seed(777)
+                make_tensor = partial(rand_sdpa_tensor, type="dense", device=device, dtype=dtype, requires_grad=False)
+                q_shape = SdpaShape(batch_size, n_head, q_seq_len, head_dim)
+                kv_shape = SdpaShape(batch_size, n_head, kv_seq_len, head_dim)
+                q = make_tensor(q_shape).transpose(1, 2)
+                k = make_tensor(kv_shape).transpose(1, 2)
+                v = make_tensor(kv_shape).transpose(1, 2)
+                return q, k, v
+            q, k, v = sdpa_helper()
+            q2, k2, v2 = sdpa_helper()
             if train:
                 q.requires_grad_(True)
                 k.requires_grad_(True)
@@ -2129,12 +2138,6 @@ class TestSDPACpuOnly(NNTestCase):
                 k2.requires_grad_(True)
                 v2.requires_grad_(True)
-            if dtype in [torch.bfloat16, torch.float16]:
-                q2, k2, v2 = q2.float(), k2.float(), v2.float()
-            # (B, nh, T, hs)
-            q = q.view(batch_size, q_seq_len, n_head, head_dim).transpose(1, 2)
-            k = k.view(batch_size, kv_seq_len, n_head, head_dim).transpose(1, 2)
-            v = v.view(batch_size, kv_seq_len, n_head, head_dim).transpose(1, 2)
             if set_attn_mask and not casual:
                 if bool_mask:
                     attn_mask = torch.randint(0, 2, size=mask_shape, dtype=torch.bool, device=device)
@@ -2142,16 +2145,11 @@ class TestSDPACpuOnly(NNTestCase):
                     attn_mask = torch.randn(mask_shape, dtype=dtype, device=device)
             else:
                 attn_mask = None
-            q2 = q2.view(batch_size, q_seq_len, n_head, head_dim).transpose(1, 2)
-            k2 = k2.view(batch_size, kv_seq_len, n_head, head_dim).transpose(1, 2)
-            v2 = v2.view(batch_size, kv_seq_len, n_head, head_dim).transpose(1, 2)
             with sdpa_kernel(backends=[fused_kernel]):
                 actual = torch.nn.functional.scaled_dot_product_attention(
                     q, k, v, attn_mask=attn_mask, dropout_p=0.0, is_causal=casual)
             with sdpa_kernel(backends=[SDPBackend.MATH]):
-                if not bool_mask and dtype in [torch.bfloat16, torch.float16] and attn_mask is not None:
-                    attn_mask = attn_mask.float()
                 math_ref = torch.nn.functional.scaled_dot_product_attention(
                     q2, k2, v2, attn_mask=attn_mask, dropout_p=0.0, is_causal=casual)
@@ -2170,9 +2168,12 @@ class TestSDPACpuOnly(NNTestCase):
                 grad_q_actual, grad_k_actual, grad_v_actual = q.grad, k.grad, v.grad
                 grad_q_ref, grad_k_ref, grad_v_ref = q2.grad, k2.grad, v2.grad
-                self.assertEqual(grad_q_actual, grad_q_ref, atol=tol.atol, rtol=tol.rtol)
-                self.assertEqual(grad_k_actual, grad_k_ref, atol=tol.atol, rtol=tol.rtol)
-                self.assertEqual(grad_v_actual, grad_v_ref, atol=tol.atol, rtol=tol.rtol)
+                self.assertFalse(grad_q_actual is None)
+                self.assertFalse(grad_k_actual is None)
+                self.assertFalse(grad_v_actual is None)
+                self.assertEqual(grad_q_actual, grad_q_ref, atol=tol_grad.atol, rtol=tol_grad.rtol)
+                self.assertEqual(grad_k_actual, grad_k_ref, atol=tol_grad.atol, rtol=tol_grad.rtol)
+                self.assertEqual(grad_v_actual, grad_v_ref, atol=tol_grad.atol, rtol=tol_grad.rtol)
     def test_sdpa_with_inf(self, device):
         # https://github.com/pytorch/pytorch/issues/127055.