CP: make correct_attn_out robust to 4‑D views and fix Triton arg binding (#26509)

Signed-off-by: Huamin Li <3ericli@gmail.com>
This commit is contained in:
Huamin Li
2025-10-11 13:50:57 -07:00
committed by GitHub
parent 5be7ca1b99
commit 0cd103e7cb

View File

@ -117,14 +117,52 @@ def correct_attn_out(
if ctx is None:
ctx = CPTritonContext()
lse = torch.empty_like(lses[0])
# --- Normalize to 3D views ---
if out.ndim == 4 and out.shape[1] == 1:
out = out.squeeze(1)
assert out.ndim == 3, f"expected out [B,H,D] or [B,1,H,D], got {tuple(out.shape)}"
grid = (out.shape[0], out.shape[1], 1)
regular_args = (out, out, lses, lse, *out.stride(), *lses.stride(), cp_rank)
const_args = {
"HEAD_DIM": out.shape[-1],
"N_ROUNDED": lses.shape[0],
}
if lses.ndim == 4 and lses.shape[-1] == 1:
lses = lses.squeeze(-1)
if lses.ndim == 4 and lses.shape[1] == 1:
lses = lses.squeeze(1)
assert lses.ndim == 3, (
f"expected lses [N,B,H] (optionally with a 1-sized extra dim), "
f"got {tuple(lses.shape)}"
)
B, H, D = out.shape
N = lses.shape[0]
# Strides after we normalized shapes to 3-D views. The kernel computes
# offsets for `vlse_ptr` using lses_stride_B/H, so the output buffer must
# have the same B/H stride layout as a slice of `lses`.
o_sB, o_sH, o_sD = out.stride()
l_sN, l_sB, l_sH = lses.stride()
# Allocate LSE with the same B/H strides as `lses` so writes land correctly
# even when `lses` is a non-contiguous view (e.g., 4-D to 3-D squeeze).
lse = torch.empty_strided(
(B, H), (l_sB, l_sH), device=lses.device, dtype=lses.dtype
)
# Kernel launch config
grid = (B, H, 1)
regular_args = (
out,
out,
lses,
lse,
o_sB,
o_sH,
o_sD,
l_sN,
l_sB,
l_sH,
cp_rank,
)
const_args = {"HEAD_DIM": D, "N_ROUNDED": N}
ctx.call_kernel(_correct_attn_cp_out_kernel, grid, *regular_args, **const_args)
return out, lse