[inductor] Freeze layouts in FlexAttention (#163434)

Fixes #163300

Pull Request resolved: https://github.com/pytorch/pytorch/pull/163434
Approved by: https://github.com/drisspg
ghstack dependencies: #163386, #163398, #163387, #163414, #163415, #163419
This commit is contained in:
Jason Ansel
2025-09-22 21:41:15 -07:00
committed by PyTorch MergeBot
parent 518c320676
commit ed84e808f0
4 changed files with 114 additions and 3 deletions

View File

@ -21,6 +21,7 @@ from torch._inductor import metrics
from torch._inductor.runtime.triton_compat import HAS_WARP_SPEC
from torch._inductor.test_case import TestCase as InductorTestCase
from torch._inductor.utils import run_and_get_code
from torch.nn.attention import SDPBackend
from torch.nn.attention.experimental._paged_attention import PagedAttention
from torch.nn.attention.flex_attention import (
_create_empty_block_mask,
@ -4360,6 +4361,89 @@ class GraphModule(torch.nn.Module):
attn_output = mod(q, k, v, mask)
self.assertEqual(attn_output.device, torch.device("cuda:1"))
@supported_platform
@skip_on_cpu
def test_custom_score_mod_layout_freeze(self, device):
torch.manual_seed(0)
class FlexAttentionCPB(nn.Module):
def __init__(self, N: int, R: int, H: int = 4, hidden: int = 32):
super().__init__()
self.mlp = nn.Sequential(
nn.Linear(2, hidden),
nn.GELU(),
nn.Linear(hidden, H, bias=False),
)
self.gamma = nn.Parameter(torch.zeros(H))
self.H = H
self._init_tables(N, R)
self.register_buffer(
"r_cutoff", torch.tensor(R, dtype=torch.long), persistent=False
)
def _init_tables(self, N: int, R: int) -> None:
P = N - R
S = int(P**0.5)
assert S * S == P
rng = torch.arange(-(S - 1), S, dtype=torch.float32)
dY, dX = torch.meshgrid(rng, rng, indexing="ij")
rel = torch.stack(
[dY / max(S - 1, 1), dX / max(S - 1, 1)], dim=-1
).reshape(-1, 2)
rel_table = torch.sign(rel) * torch.log1p(rel.abs())
self.register_buffer("rel_table", rel_table, persistent=False)
yy, xx = torch.arange(S), torch.arange(S)
Y, X = torch.meshgrid(yy, xx, indexing="ij")
flat = torch.stack([Y, X], 0).flatten(1)
d = flat[:, :, None] - flat[:, None, :]
d = d.permute(1, 2, 0).contiguous()
d[:, :, 0] += S - 1
d[:, :, 1] += S - 1
d[:, :, 0] *= 2 * S - 1
l_idx = d.sum(-1).to(torch.long)
idx = torch.full((N, N), 0, dtype=torch.long)
idx[R:, R:] = l_idx
self.register_buffer("idx_table", idx, persistent=False)
def _score_mod(self, mu: torch.Tensor):
bt = self.mlp(self.rel_table)
idx = self.idx_table
mu_q, mu_k = mu.unbind(2)
gam_sig = torch.sigmoid(self.gamma)
def score_mod(score, b, h, q, kv):
has_bias = (q >= self.r_cutoff) & (kv >= self.r_cutoff)
l2 = idx[q, kv]
bias = bt[l2, h]
w_gate = gam_sig[h] * (mu_q[b, h, q] + mu_k[b, h, kv])
return score + has_bias.to(score.dtype) * w_gate * bias
return score_mod
def forward(self, q, k, v, mu):
return flex_attention(q, k, v, score_mod=self._score_mod(mu))
dtype = torch.bfloat16 if PLATFORM_SUPPORTS_BF16 else torch.float16
device_obj = torch.device(device)
module = FlexAttentionCPB(N=18, R=2).to(device_obj)
compiled_module = torch.compile(module, backend="inductor", dynamic=False)
q = torch.randn(2, 4, 18, 32, device=device_obj, dtype=dtype)
k = torch.randn_like(q)
v = torch.randn_like(q)
mu = torch.randn(2, 4, 2, 18, device=device_obj)
with torch.no_grad():
with torch.nn.attention.sdpa_kernel(SDPBackend.FLASH_ATTENTION):
eager_out = module(q, k, v, mu)
compiled_out = compiled_module(q, k, v, mu)
self.assertEqual(compiled_out.shape, eager_out.shape)
torch.testing.assert_close(
compiled_out.float(), eager_out.float(), atol=2e-2, rtol=2e-2
)
@supported_platform
@skip_on_cpu
@common_utils.parametrize(