Mirror of https://github.com/pytorch/pytorch.git, synced 2025-10-20 21:14:14 +08:00
[CPU] Fix memory access for sbgemm bf16 (#156585)
Fixes #156022.
1. The original dtype conversion copied and wrote back the whole `n_*ldc_` span instead of the `n_*m_` elements of C laid out with stride `ldc_`, causing a potential out-of-bounds memory access.
2. Fix the None-gradient issue in the attention backward unit test, since the sbgemm bf16 path may now be used there.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/156585
Approved by: https://github.com/mingfeima, https://github.com/aditew01, https://github.com/ezyang
Committed by: PyTorch MergeBot
Parent: 12f9942b10
Commit: f56bfb3030
@@ -358,18 +358,25 @@ void gemm(
   int m_ = m, n_ = n, k_ = k, lda_ = lda, ldb_ = ldb, ldc_ = ldc;
   char transa_ = to_blas(transa), transb_ = to_blas(transb);
   float alpha_ = alpha, beta_ = beta;
-  int c_size = n_ * ldc_;
+  int c_size = n_ * m_;
   // C matrix in OpenBLAS sbgemm are of type "float" so we have to convert, copy and copy back.
-  std::vector<float> float_v(c, c + c_size);
+  std::vector<float> float_v(c_size, 0.0f);
+  for (const auto j : c10::irange(n)) {
+    for (const auto i : c10::irange(m)) {
+      float_v[j * m_ + i] = c10::convert<float>(c[j * ldc_ + i]);
+    }
+  }
   sbgemm_(&transa_, &transb_,
           &m_, &n_, &k_,
           &alpha_,
           a, &lda_,
           b, &ldb_,
           &beta_,
-          float_v.data(), &ldc_);
-  for (auto cv: float_v) {
-    *(c++) = c10::convert<at::BFloat16>(cv);
+          float_v.data(), &m_);
+  for (const auto j : c10::irange(n)) {
+    for (const auto i : c10::irange(m)) {
+      c[j * ldc_ + i] = c10::convert<at::BFloat16>(float_v[j * m_ + i]);
+    }
   }
   return;
 }
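To make point 1 of the commit message concrete, here is a minimal Python sketch of the size arithmetic behind the fix above; the values of m, n and ldc are hypothetical and chosen only for illustration. C is an m x n column-major block stored with leading dimension ldc >= m, so element (i, j) lives at c[j * ldc + i]: only (n - 1) * ldc + m entries of the buffer are reachable through C, and only n * m of them belong to the logical matrix, so converting and writing back a flat run of n * ldc entries both clobbers the padding between columns and can run past the end of the block.

# Hypothetical sizes, for illustration only: an m x n block whose leading dimension ldc is larger than m.
m, n, ldc = 4, 3, 10

addressable = (n - 1) * ldc + m   # last valid element of C is c[(n - 1) * ldc + (m - 1)]
old_span = n * ldc                # what the old code copied into float_v and wrote back in one flat pass
fixed_count = n * m               # what the fixed code touches, element by element at c[j * ldc + i]

print("reachable through C:", addressable)   # 24
print("old flat copy span: ", old_span)      # 30 -> 6 elements past the end of the block
print("elements in the fix:", fixed_count)   # 12 -> exactly the logical m x n matrix

Because float_v is now a densely packed m x n buffer, its own leading dimension is m_, which is why the fixed sbgemm_ call passes &m_ rather than &ldc_ for the output matrix.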
@@ -2108,19 +2108,28 @@ class TestSDPACpuOnly(NNTestCase):
             tol = Tolerances(5e-2, 5e-2)
         if dtype is torch.float16:
             tol = Tolerances(1e-2, 1e-2)
+        tol_grad = Tolerances(1e-5, 5e-6)
+        if dtype is torch.bfloat16:
+            tol_grad = Tolerances(5e-2, 5e-2)
+        if dtype is torch.float16:
+            tol_grad = Tolerances(1e-1, 1e-1)
         for mask_shape in itertools.product(
             [q_seq_len, 1], [kv_seq_len, 1]
         ) if mask_dim == 2 else itertools.product(
             [batch_size, 1], [n_head, 1], [q_seq_len, 1], [kv_seq_len, 1]
         ):
-            make_tensor = partial(rand_sdpa_tensor, type="dense", device=device, dtype=dtype, requires_grad=False)
-            q_shape = SdpaShape(batch_size, n_head, q_seq_len, head_dim)
-            kv_shape = SdpaShape(batch_size, n_head, kv_seq_len, head_dim)
-            q = make_tensor(q_shape)
-            k = make_tensor(kv_shape)
-            v = make_tensor(kv_shape)
-            q2, k2, v2 = q.clone(), k.clone(), v.clone()
+            def sdpa_helper():
+                torch.manual_seed(777)
+                make_tensor = partial(rand_sdpa_tensor, type="dense", device=device, dtype=dtype, requires_grad=False)
+                q_shape = SdpaShape(batch_size, n_head, q_seq_len, head_dim)
+                kv_shape = SdpaShape(batch_size, n_head, kv_seq_len, head_dim)
+                q = make_tensor(q_shape).transpose(1, 2)
+                k = make_tensor(kv_shape).transpose(1, 2)
+                v = make_tensor(kv_shape).transpose(1, 2)
+                return q, k, v
+            q, k, v = sdpa_helper()
+            q2, k2, v2 = sdpa_helper()
 
             if train:
                 q.requires_grad_(True)
                 k.requires_grad_(True)
@@ -2129,12 +2138,6 @@ class TestSDPACpuOnly(NNTestCase):
                 k2.requires_grad_(True)
                 v2.requires_grad_(True)
 
-            if dtype in [torch.bfloat16, torch.float16]:
-                q2, k2, v2 = q2.float(), k2.float(), v2.float()
-            # (B, nh, T, hs)
-            q = q.view(batch_size, q_seq_len, n_head, head_dim).transpose(1, 2)
-            k = k.view(batch_size, kv_seq_len, n_head, head_dim).transpose(1, 2)
-            v = v.view(batch_size, kv_seq_len, n_head, head_dim).transpose(1, 2)
             if set_attn_mask and not casual:
                 if bool_mask:
                     attn_mask = torch.randint(0, 2, size=mask_shape, dtype=torch.bool, device=device)
@@ -2142,16 +2145,11 @@ class TestSDPACpuOnly(NNTestCase):
                     attn_mask = torch.randn(mask_shape, dtype=dtype, device=device)
             else:
                 attn_mask = None
-            q2 = q2.view(batch_size, q_seq_len, n_head, head_dim).transpose(1, 2)
-            k2 = k2.view(batch_size, kv_seq_len, n_head, head_dim).transpose(1, 2)
-            v2 = v2.view(batch_size, kv_seq_len, n_head, head_dim).transpose(1, 2)
 
             with sdpa_kernel(backends=[fused_kernel]):
                 actual = torch.nn.functional.scaled_dot_product_attention(
                     q, k, v, attn_mask=attn_mask, dropout_p=0.0, is_causal=casual)
             with sdpa_kernel(backends=[SDPBackend.MATH]):
-                if not bool_mask and dtype in [torch.bfloat16, torch.float16] and attn_mask is not None:
-                    attn_mask = attn_mask.float()
                 math_ref = torch.nn.functional.scaled_dot_product_attention(
                     q2, k2, v2, attn_mask=attn_mask, dropout_p=0.0, is_causal=casual)
 
@@ -2170,9 +2168,12 @@ class TestSDPACpuOnly(NNTestCase):
                 grad_q_actual, grad_k_actual, grad_v_actual = q.grad, k.grad, v.grad
                 grad_q_ref, grad_k_ref, grad_v_ref = q2.grad, k2.grad, v2.grad
 
-                self.assertEqual(grad_q_actual, grad_q_ref, atol=tol.atol, rtol=tol.rtol)
-                self.assertEqual(grad_k_actual, grad_k_ref, atol=tol.atol, rtol=tol.rtol)
-                self.assertEqual(grad_v_actual, grad_v_ref, atol=tol.atol, rtol=tol.rtol)
+                self.assertFalse(grad_q_actual is None)
+                self.assertFalse(grad_k_actual is None)
+                self.assertFalse(grad_v_actual is None)
+                self.assertEqual(grad_q_actual, grad_q_ref, atol=tol_grad.atol, rtol=tol_grad.rtol)
+                self.assertEqual(grad_k_actual, grad_k_ref, atol=tol_grad.atol, rtol=tol_grad.rtol)
+                self.assertEqual(grad_v_actual, grad_v_ref, atol=tol_grad.atol, rtol=tol_grad.rtol)
 
     def test_sdpa_with_inf(self, device):
         # https://github.com/pytorch/pytorch/issues/127055.
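Point 2 of the commit message concerns the gradients of the math reference in the test above: in the old version, q2, k2 and v2 were marked requires_grad_ and then rebound to their .float() copies, so the rebound names referred to non-leaf tensors whose .grad stays None after backward, which appears to be the None value the fix addresses (hence the new assertFalse(... is None) checks). A minimal standalone sketch of that autograd behavior; the shape and the q2/q3 names are arbitrary and only illustrative:

import torch

q2 = torch.randn(2, 3, dtype=torch.bfloat16)
q2.requires_grad_(True)      # q2 is a leaf tensor and will accumulate .grad

q2 = q2.float()              # rebinding: q2 now names a non-leaf float32 copy
q2.sum().backward()
print(q2.grad)               # None (PyTorch warns about reading .grad of a non-leaf tensor)

# Keeping the tensor in its original dtype, as the updated test does, keeps the leaf intact:
q3 = torch.randn(2, 3, dtype=torch.bfloat16, requires_grad=True)
q3.sum().backward()
print(q3.grad is None)       # False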