Split batch-num-heads grid dim between y and z (#157745)

for #157018

doesn't totally fix the problem but should help alot

Pull Request resolved: https://github.com/pytorch/pytorch/pull/157745
Approved by: https://github.com/Chillee
This commit is contained in:
drisspg
2025-07-07 16:09:00 -07:00
committed by PyTorch MergeBot
parent 39a8f66d59
commit 987314aa96
2 changed files with 42 additions and 12 deletions

View File

@ -4205,6 +4205,36 @@ class GraphModule(torch.nn.Module):
# vanilla compiled vs TMA compiled
torch.testing.assert_close(out_tma_compiled, out_compiled, atol=2e-1, rtol=2e-1)
@supported_platform
@skip_on_cpu
def test_large_batch_heads_grid_dimension(self, device):
B, H, S, D = 22720, 3, 64, 32
make_tensor = functools.partial(
torch.randn,
(B, H, S, D),
device=device,
dtype=torch.float16,
requires_grad=True,
)
query, key, value = make_tensor(), make_tensor(), make_tensor()
flex_compile = torch.compile(flex_attention, fullgraph=True, dynamic=True)
out_compiled = flex_compile(query, key, value)
self.assertEqual(out_compiled.shape, (B, H, S, D))
grad_output = torch.randn_like(out_compiled)
out_compiled.backward(grad_output)
self.assertIsNotNone(query.grad)
self.assertIsNotNone(key.grad)
self.assertIsNotNone(value.grad)
self.assertEqual(query.grad.shape, query.shape)
self.assertEqual(key.grad.shape, key.shape)
self.assertEqual(value.grad.shape, value.shape)
class TestBlockMask(InductorTestCase):
def setUp(self):

View File

@ -98,11 +98,11 @@ def infer_dense_strides(size: Sequence[int], orig_strides: Sequence[int]):
@SymbolicGridFn
def flex_attention_grid(batch_size, q_heads, num_queries, d_model, meta, *, cdiv):
"""How is this kernel parallelized?
We create a grid of (batch_size * num_heads, ceil_div(n_queries, query_block_size), 1)
We create a grid of (ceil_div(n_queries, query_block_size), batch_size, num_heads)
Each block is responsible for iterating over blocks of keys and values calculating
the final attention output.
"""
return (cdiv(num_queries, meta["BLOCK_M"]), batch_size * q_heads, 1)
return (cdiv(num_queries, meta["BLOCK_M"]), batch_size, q_heads)
def create_placeholder(
@ -390,8 +390,8 @@ compute_flex_attention = r"""
MATMUL_PRECISION = Q.dtype.element_ty
q_start = tl.program_id(0)
off_zq = tl.program_id(1) // HQ
off_hq = tl.program_id(1) % HQ
off_zq = tl.program_id(1)
off_hq = tl.program_id(2)
# Setting up the TMA descriptors for Q, K, V
@ -573,8 +573,8 @@ compute_flex_attention = r"""
l_i = tl.where(l_i == 0.0, 1, l_i)
acc = acc / l_i[:, None]
idx_zq = tl.program_id(1) // HQ
idx_hq = tl.program_id(1) % HQ
idx_zq = tl.program_id(1)
idx_hq = tl.program_id(2)
idx_m = offs_m[:, None]
idx_d = tl.arange(0, V_HEAD_DIM_ROUNDED)[None, :]
@ -583,7 +583,7 @@ compute_flex_attention = r"""
{{store_output(("idx_zq", "idx_hq", "idx_m", "idx_d"), "acc", "mask")}}
if OUTPUT_LOGSUMEXP:
off_hz = tl.program_id(1)
off_hz = off_zq * HQ + off_hq
l_ptrs = LSE + off_hz * Q_LEN + offs_m
lse = m_i + tl.math.log2(l_i)
if IS_DIVISIBLE:
@ -1572,6 +1572,7 @@ def flex_attention_backward_grid(
batch_size, q_heads, num_queries, d_model, kv_heads, num_key_value, meta
):
"""How is this kernel parallelized?
We create a grid of (ceil_div(n_queries, query_block_size) * heads_ratio + ceil_div(n_kv, kv_block_size), batch_size, kv_heads)
Currently this is only parallelizing over batch* kv_heads, but we can, and want to
parallelize over ceil_div(q_heads//kv_heads * num_key_value, key_value_block_size).
To do this will either require atomic updates to some grad values or to have a two pass kernel design.
@ -1581,8 +1582,8 @@ def flex_attention_backward_grid(
return (
triton.cdiv(num_queries, meta["BLOCK_M2"]) * (q_heads // kv_heads)
+ triton.cdiv(num_key_value, meta["BLOCK_N1"]),
1,
batch_size * kv_heads,
batch_size,
kv_heads,
)
@ -1647,9 +1648,8 @@ flex_attention_backward_template = TritonTemplate(
NUM_KV_BLOCKS = tl.cdiv(KV_LEN, BLOCK_N1)
NUM_Q_BLOCKS = tl.cdiv(Q_LEN, BLOCK_M2)
off_hz = tl.program_id(2)
off_zq = off_hz // HKV # q batch idx
off_hkv = off_hz % HKV # kv head idx
off_zq = tl.program_id(1) # q batch idx
off_hkv = tl.program_id(2) # kv head idx
off_zkv = off_zq % ZKV # kv batch idx
SPARSE_Z = {{size("KV_NUM_BLKS", 0)}}