[inductor] Freeze layouts in FlexAttention (#163434)

Fixes #163300

Pull Request resolved: https://github.com/pytorch/pytorch/pull/163434
Approved by: https://github.com/drisspg
ghstack dependencies: #163386, #163398, #163387, #163414, #163415, #163419
Jason Ansel authored on 2025-09-22 21:41:15 -07:00
committed by PyTorch MergeBot
parent 518c320676
commit ed84e808f0

4 changed files with 114 additions and 3 deletions


@@ -21,6 +21,7 @@ from torch._inductor import metrics
from torch._inductor.runtime.triton_compat import HAS_WARP_SPEC
from torch._inductor.test_case import TestCase as InductorTestCase
from torch._inductor.utils import run_and_get_code
from torch.nn.attention import SDPBackend
from torch.nn.attention.experimental._paged_attention import PagedAttention
from torch.nn.attention.flex_attention import (
    _create_empty_block_mask,
@@ -4360,6 +4361,89 @@ class GraphModule(torch.nn.Module):
        attn_output = mod(q, k, v, mask)
        self.assertEqual(attn_output.device, torch.device("cuda:1"))

    @supported_platform
    @skip_on_cpu
    def test_custom_score_mod_layout_freeze(self, device):
        torch.manual_seed(0)

        class FlexAttentionCPB(nn.Module):
            def __init__(self, N: int, R: int, H: int = 4, hidden: int = 32):
                super().__init__()
                self.mlp = nn.Sequential(
                    nn.Linear(2, hidden),
                    nn.GELU(),
                    nn.Linear(hidden, H, bias=False),
                )
                self.gamma = nn.Parameter(torch.zeros(H))
                self.H = H
                self._init_tables(N, R)
                self.register_buffer(
                    "r_cutoff", torch.tensor(R, dtype=torch.long), persistent=False
                )

            def _init_tables(self, N: int, R: int) -> None:
                P = N - R
                S = int(P**0.5)
                assert S * S == P
                rng = torch.arange(-(S - 1), S, dtype=torch.float32)
                dY, dX = torch.meshgrid(rng, rng, indexing="ij")
                rel = torch.stack(
                    [dY / max(S - 1, 1), dX / max(S - 1, 1)], dim=-1
                ).reshape(-1, 2)
                rel_table = torch.sign(rel) * torch.log1p(rel.abs())
                self.register_buffer("rel_table", rel_table, persistent=False)

                yy, xx = torch.arange(S), torch.arange(S)
                Y, X = torch.meshgrid(yy, xx, indexing="ij")
                flat = torch.stack([Y, X], 0).flatten(1)
                d = flat[:, :, None] - flat[:, None, :]
                d = d.permute(1, 2, 0).contiguous()
                d[:, :, 0] += S - 1
                d[:, :, 1] += S - 1
                d[:, :, 0] *= 2 * S - 1
                l_idx = d.sum(-1).to(torch.long)
                idx = torch.full((N, N), 0, dtype=torch.long)
                idx[R:, R:] = l_idx
                self.register_buffer("idx_table", idx, persistent=False)

            def _score_mod(self, mu: torch.Tensor):
                bt = self.mlp(self.rel_table)
                idx = self.idx_table
                mu_q, mu_k = mu.unbind(2)
                gam_sig = torch.sigmoid(self.gamma)

                def score_mod(score, b, h, q, kv):
                    has_bias = (q >= self.r_cutoff) & (kv >= self.r_cutoff)
                    l2 = idx[q, kv]
                    bias = bt[l2, h]
                    w_gate = gam_sig[h] * (mu_q[b, h, q] + mu_k[b, h, kv])
                    return score + has_bias.to(score.dtype) * w_gate * bias

                return score_mod

            def forward(self, q, k, v, mu):
                return flex_attention(q, k, v, score_mod=self._score_mod(mu))

        dtype = torch.bfloat16 if PLATFORM_SUPPORTS_BF16 else torch.float16
        device_obj = torch.device(device)
        module = FlexAttentionCPB(N=18, R=2).to(device_obj)
        compiled_module = torch.compile(module, backend="inductor", dynamic=False)

        q = torch.randn(2, 4, 18, 32, device=device_obj, dtype=dtype)
        k = torch.randn_like(q)
        v = torch.randn_like(q)
        mu = torch.randn(2, 4, 2, 18, device=device_obj)

        with torch.no_grad():
            with torch.nn.attention.sdpa_kernel(SDPBackend.FLASH_ATTENTION):
                eager_out = module(q, k, v, mu)
                compiled_out = compiled_module(q, k, v, mu)

        self.assertEqual(compiled_out.shape, eager_out.shape)
        torch.testing.assert_close(
            compiled_out.float(), eager_out.float(), atol=2e-2, rtol=2e-2
        )

    @supported_platform
    @skip_on_cpu
    @common_utils.parametrize(
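
Aside: the scenario this test covers (a score_mod that indexes into module buffers, compiled with the inductor backend and checked against eager) can also be reproduced outside the test harness. The sketch below is a simplified stand-in, not part of the commit: the distance_bias score_mod, the shapes, and the hard-coded CUDA device are illustrative assumptions.

import torch
from torch.nn.attention.flex_attention import flex_attention

# Hypothetical score_mod: a simple index-distance bias standing in for the
# buffer-indexing score_mod built by FlexAttentionCPB in the test above.
def distance_bias(score, b, h, q_idx, kv_idx):
    return score - 0.01 * (q_idx - kv_idx).abs().to(score.dtype)

q = torch.randn(2, 4, 128, 64, device="cuda", dtype=torch.float16)
k = torch.randn_like(q)
v = torch.randn_like(q)

eager_out = flex_attention(q, k, v, score_mod=distance_bias)
compiled_flex = torch.compile(flex_attention, backend="inductor", dynamic=False)
compiled_out = compiled_flex(q, k, v, score_mod=distance_bias)

torch.testing.assert_close(
    compiled_out.float(), eager_out.float(), atol=2e-2, rtol=2e-2
)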


@@ -11,7 +11,7 @@ import sympy
import torch
from torch._inductor.virtualized import V
from torch.utils._ordered_set import OrderedSet
from torch.utils._pytree import tree_map
from torch.utils._pytree import tree_map, tree_map_only
from ...ir import (
    ComputedBuffer,
@@ -173,6 +173,22 @@ def maybe_realize(args: list[Optional[IRNode]]):
)

def freeze_irnodes(tree: Any) -> Any:
    """Freeze layouts for every IRNode contained in a pytree."""
    if tree is None:
        return None

    def _freeze(node: IRNode) -> IRNode:
        try:
            node.freeze_layout()
        except NotImplementedError:
            pass
        return node

    return tree_map_only(IRNode, _freeze, tree)


def create_placeholder(
    name: str,
    dtype: torch.dtype,
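
For context on the tree_map_only import added above: it applies a function only to pytree leaves of a given type and passes every other leaf through untouched, which is why freeze_irnodes can be handed a single subgraph buffer, a list of other_buffers, or any other pytree container without special-casing. A tiny self-contained illustration with toy values (plain ints, not inductor IRNodes):

from torch.utils._pytree import tree_map_only

tree = {"score_mod": [1, 2.5], "mask_mod": (3, None)}
# Only int leaves are transformed; the float and the None pass through as-is.
out = tree_map_only(int, lambda v: v * 10, tree)
assert out == {"score_mod": [10, 2.5], "mask_mod": (30, None)}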


@@ -26,6 +26,7 @@ from .common import (
    create_indices_fake,
    create_num_blocks_fake_generator,
    create_placeholder,
    freeze_irnodes,
    get_fwd_subgraph_outputs,
    infer_dense_strides,
    load_template,
@@ -149,6 +150,7 @@ def flex_attention(
    subgraph_buffer = build_subgraph_buffer(
        placeholder_inps + list(score_mod_other_buffers), subgraph
    )
    freeze_irnodes(subgraph_buffer)

    mask_graph_placeholder_inps = [
        create_placeholder(name, dtype, query.get_device())
@@ -162,6 +164,7 @@ def flex_attention(
    mask_graph_buffer = build_subgraph_buffer(
        mask_graph_placeholder_inps + list(mask_mod_other_buffers), mask_graph
    )
    freeze_irnodes(mask_graph_buffer)

    kernel_options = dict(kernel_options)
    # Mark symbols in custom kernel options as static shapes and add guards.
@@ -218,6 +221,9 @@ def flex_attention(
    score_mod_other_buffers = maybe_realize(score_mod_other_buffers)
    mask_mod_other_buffers = maybe_realize(mask_mod_other_buffers)
    freeze_irnodes(score_mod_other_buffers)
    freeze_irnodes(mask_mod_other_buffers)

    Bq, Hq, seq_len_q, qk_head_dim = query.get_size()
    Bkv, Hkv, seq_len_kv, v_head_dim = value.get_size()
    assert V.graph.sizevars.evaluate_expr(sympy.Eq(Bq, Bkv) | sympy.Eq(Bkv, 1)), (
@@ -625,6 +631,7 @@ def flex_attention_backward(*args, **kwargs):
    fw_subgraph_buffer = build_subgraph_buffer(
        fwd_placeholder_inps + list(score_mod_other_buffers), fw_graph
    )
    freeze_irnodes(fw_subgraph_buffer)

    joint_placeholder_inps = fwd_placeholder_inps + [
        create_placeholder("grad_score_mod", dtype, device)
@@ -640,6 +647,7 @@ def flex_attention_backward(*args, **kwargs):
        joint_placeholder_inps + list(score_mod_other_buffers),
        joint_graph,
    )
    freeze_irnodes(all_joint_outputs)

    joint_outputs = process_joint_outputs(
        all_joint_outputs, len(joint_placeholder_inps)
@@ -657,8 +665,7 @@ def flex_attention_backward(*args, **kwargs):
    mask_graph_buffer = build_subgraph_buffer(
        mask_graph_placeholder_inps + list(mask_mod_other_buffers), mask_graph
    )
    mask_graph_buffer = mask_graph_buffer
    freeze_irnodes(mask_graph_buffer)

    # Construct layout with stride order matching K
    key_size = [Bq, Hkv, seq_len_kv, qk_head_dim]
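
All of the call sites above follow the same order: build or realize the buffers, then freeze their layouts before the kernel template is instantiated. As an illustration only (my reading of the change, not wording from the commit), the toy class below models the hazard that freezing guards against: strides baked into generated code going stale because a still-flexible layout was changed afterwards. ToyLayout is made up for this sketch; it is not inductor's IRNode/Layout API.

class ToyLayout:
    def __init__(self, strides):
        self.strides = tuple(strides)
        self.frozen = False

    def freeze_layout(self):
        self.frozen = True

    def try_reorder(self, new_strides):
        # A later pass may only touch layouts that are still flexible.
        if not self.frozen:
            self.strides = tuple(new_strides)


layout = ToyLayout([64, 1])
baked_in = layout.strides      # what a generated kernel would index with
layout.freeze_layout()         # pin the layout up front
layout.try_reorder([1, 64])    # now a no-op rather than a silent mismatch
assert layout.strides == baked_in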


@@ -20,6 +20,7 @@ from ...select_algorithm import (
from .common import (
    create_indices_fake,
    create_num_blocks_fake_generator,
    freeze_irnodes,
    get_fwd_subgraph_outputs,
    load_template,
    maybe_realize,
@@ -208,6 +209,9 @@ def create_flex_decoding_kernel(*args, **kwargs):
    score_mod_other_buffers = maybe_realize(score_mod_other_buffers)
    mask_mod_other_buffers = maybe_realize(mask_mod_other_buffers)
    freeze_irnodes(score_mod_other_buffers)
    freeze_irnodes(mask_mod_other_buffers)

    choices: list[Any] = []
    dtype = key.get_dtype()
    head_dim = V.graph.sizevars.guard_int(key.get_size()[-1])