mirror of https://github.com/pytorch/pytorch.git (synced 2025-10-20 21:14:14 +08:00)
[inductor] Freeze layouts in FlexAttention (#163434)

Fixes #163300
Pull Request resolved: https://github.com/pytorch/pytorch/pull/163434
Approved by: https://github.com/drisspg
ghstack dependencies: #163386, #163398, #163387, #163414, #163415, #163419
committed by PyTorch MergeBot
parent 518c320676
commit ed84e808f0
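The new test below exercises a score_mod that gathers from tensors captured by the module (a learned bias table and a gating buffer). Freezing the layouts of the built subgraph buffers and captured buffers is meant to keep the generated kernel's stride assumptions from drifting after codegen. For orientation, here is a minimal sketch of the user-facing pattern involved; it is illustrative only (not code from this commit), and the shapes, names, and CUDA device are assumptions:

import torch
from torch.nn.attention.flex_attention import flex_attention

# Illustrative shapes; assumes a CUDA device is available.
B, H, S, D = 2, 4, 128, 64
q = torch.randn(B, H, S, D, device="cuda", dtype=torch.float16)
k = torch.randn_like(q)
v = torch.randn_like(q)

# A captured bias table that score_mod gathers from; Inductor lowers this
# read as a load from a buffer whose layout the generated kernel depends on.
bias_table = torch.randn(S, S, H, device="cuda")

def score_mod(score, b, h, q_idx, kv_idx):
    return score + bias_table[q_idx, kv_idx, h]

compiled_flex = torch.compile(flex_attention)
out = compiled_flex(q, k, v, score_mod=score_mod)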
@@ -21,6 +21,7 @@ from torch._inductor import metrics
from torch._inductor.runtime.triton_compat import HAS_WARP_SPEC
from torch._inductor.test_case import TestCase as InductorTestCase
from torch._inductor.utils import run_and_get_code
from torch.nn.attention import SDPBackend
from torch.nn.attention.experimental._paged_attention import PagedAttention
from torch.nn.attention.flex_attention import (
    _create_empty_block_mask,
@@ -4360,6 +4361,89 @@ class GraphModule(torch.nn.Module):
        attn_output = mod(q, k, v, mask)
        self.assertEqual(attn_output.device, torch.device("cuda:1"))

    @supported_platform
    @skip_on_cpu
    def test_custom_score_mod_layout_freeze(self, device):
        torch.manual_seed(0)

        class FlexAttentionCPB(nn.Module):
            def __init__(self, N: int, R: int, H: int = 4, hidden: int = 32):
                super().__init__()
                self.mlp = nn.Sequential(
                    nn.Linear(2, hidden),
                    nn.GELU(),
                    nn.Linear(hidden, H, bias=False),
                )
                self.gamma = nn.Parameter(torch.zeros(H))
                self.H = H
                self._init_tables(N, R)
                self.register_buffer(
                    "r_cutoff", torch.tensor(R, dtype=torch.long), persistent=False
                )

            def _init_tables(self, N: int, R: int) -> None:
                P = N - R
                S = int(P**0.5)
                assert S * S == P
                rng = torch.arange(-(S - 1), S, dtype=torch.float32)
                dY, dX = torch.meshgrid(rng, rng, indexing="ij")
                rel = torch.stack(
                    [dY / max(S - 1, 1), dX / max(S - 1, 1)], dim=-1
                ).reshape(-1, 2)
                rel_table = torch.sign(rel) * torch.log1p(rel.abs())
                self.register_buffer("rel_table", rel_table, persistent=False)

                yy, xx = torch.arange(S), torch.arange(S)
                Y, X = torch.meshgrid(yy, xx, indexing="ij")
                flat = torch.stack([Y, X], 0).flatten(1)
                d = flat[:, :, None] - flat[:, None, :]
                d = d.permute(1, 2, 0).contiguous()
                d[:, :, 0] += S - 1
                d[:, :, 1] += S - 1
                d[:, :, 0] *= 2 * S - 1
                l_idx = d.sum(-1).to(torch.long)
                idx = torch.full((N, N), 0, dtype=torch.long)
                idx[R:, R:] = l_idx
                self.register_buffer("idx_table", idx, persistent=False)

            def _score_mod(self, mu: torch.Tensor):
                bt = self.mlp(self.rel_table)
                idx = self.idx_table
                mu_q, mu_k = mu.unbind(2)
                gam_sig = torch.sigmoid(self.gamma)

                def score_mod(score, b, h, q, kv):
                    has_bias = (q >= self.r_cutoff) & (kv >= self.r_cutoff)
                    l2 = idx[q, kv]
                    bias = bt[l2, h]
                    w_gate = gam_sig[h] * (mu_q[b, h, q] + mu_k[b, h, kv])
                    return score + has_bias.to(score.dtype) * w_gate * bias

                return score_mod

            def forward(self, q, k, v, mu):
                return flex_attention(q, k, v, score_mod=self._score_mod(mu))

        dtype = torch.bfloat16 if PLATFORM_SUPPORTS_BF16 else torch.float16
        device_obj = torch.device(device)
        module = FlexAttentionCPB(N=18, R=2).to(device_obj)
        compiled_module = torch.compile(module, backend="inductor", dynamic=False)

        q = torch.randn(2, 4, 18, 32, device=device_obj, dtype=dtype)
        k = torch.randn_like(q)
        v = torch.randn_like(q)
        mu = torch.randn(2, 4, 2, 18, device=device_obj)

        with torch.no_grad():
            with torch.nn.attention.sdpa_kernel(SDPBackend.FLASH_ATTENTION):
                eager_out = module(q, k, v, mu)
                compiled_out = compiled_module(q, k, v, mu)

        self.assertEqual(compiled_out.shape, eager_out.shape)
        torch.testing.assert_close(
            compiled_out.float(), eager_out.float(), atol=2e-2, rtol=2e-2
        )

    @supported_platform
    @skip_on_cpu
    @common_utils.parametrize(
@@ -11,7 +11,7 @@ import sympy
import torch
from torch._inductor.virtualized import V
from torch.utils._ordered_set import OrderedSet
from torch.utils._pytree import tree_map
from torch.utils._pytree import tree_map, tree_map_only

from ...ir import (
    ComputedBuffer,
@@ -173,6 +173,22 @@ def maybe_realize(args: list[Optional[IRNode]]):
    )


def freeze_irnodes(tree: Any) -> Any:
    """Freeze layouts for every IRNode contained in a pytree."""

    if tree is None:
        return None

    def _freeze(node: IRNode) -> IRNode:
        try:
            node.freeze_layout()
        except NotImplementedError:
            pass
        return node

    return tree_map_only(IRNode, _freeze, tree)


def create_placeholder(
    name: str,
    dtype: torch.dtype,
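freeze_irnodes walks an arbitrary pytree and freezes the layout of every IRNode leaf via torch.utils._pytree.tree_map_only, which applies a function only to leaves of the given type and passes everything else through untouched. A self-contained illustration of tree_map_only with plain Python values (not Inductor IR):

from torch.utils._pytree import tree_map_only

# Only the int leaves are transformed; the str, the float, and the
# container structure are returned unchanged.
tree = {"a": 1, "b": ("x", 2.5, [3, 4])}
print(tree_map_only(int, lambda n: n * 10, tree))
# {'a': 10, 'b': ('x', 2.5, [30, 40])}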
@@ -26,6 +26,7 @@ from .common import (
    create_indices_fake,
    create_num_blocks_fake_generator,
    create_placeholder,
    freeze_irnodes,
    get_fwd_subgraph_outputs,
    infer_dense_strides,
    load_template,
@@ -149,6 +150,7 @@ def flex_attention(
    subgraph_buffer = build_subgraph_buffer(
        placeholder_inps + list(score_mod_other_buffers), subgraph
    )
    freeze_irnodes(subgraph_buffer)

    mask_graph_placeholder_inps = [
        create_placeholder(name, dtype, query.get_device())
@@ -162,6 +164,7 @@ def flex_attention(
    mask_graph_buffer = build_subgraph_buffer(
        mask_graph_placeholder_inps + list(mask_mod_other_buffers), mask_graph
    )
    freeze_irnodes(mask_graph_buffer)

    kernel_options = dict(kernel_options)
    # Mark symbols in custom kernel options as static shapes and add guards.
@@ -218,6 +221,9 @@ def flex_attention(
    score_mod_other_buffers = maybe_realize(score_mod_other_buffers)
    mask_mod_other_buffers = maybe_realize(mask_mod_other_buffers)

    freeze_irnodes(score_mod_other_buffers)
    freeze_irnodes(mask_mod_other_buffers)

    Bq, Hq, seq_len_q, qk_head_dim = query.get_size()
    Bkv, Hkv, seq_len_kv, v_head_dim = value.get_size()
    assert V.graph.sizevars.evaluate_expr(sympy.Eq(Bq, Bkv) | sympy.Eq(Bkv, 1)), (
@@ -625,6 +631,7 @@ def flex_attention_backward(*args, **kwargs):
    fw_subgraph_buffer = build_subgraph_buffer(
        fwd_placeholder_inps + list(score_mod_other_buffers), fw_graph
    )
    freeze_irnodes(fw_subgraph_buffer)

    joint_placeholder_inps = fwd_placeholder_inps + [
        create_placeholder("grad_score_mod", dtype, device)
@@ -640,6 +647,7 @@ def flex_attention_backward(*args, **kwargs):
        joint_placeholder_inps + list(score_mod_other_buffers),
        joint_graph,
    )
    freeze_irnodes(all_joint_outputs)

    joint_outputs = process_joint_outputs(
        all_joint_outputs, len(joint_placeholder_inps)
@@ -657,8 +665,7 @@ def flex_attention_backward(*args, **kwargs):
    mask_graph_buffer = build_subgraph_buffer(
        mask_graph_placeholder_inps + list(mask_mod_other_buffers), mask_graph
    )

    mask_graph_buffer = mask_graph_buffer
    freeze_irnodes(mask_graph_buffer)

    # Construct layout with stride order matching K
    key_size = [Bq, Hkv, seq_len_kv, qk_head_dim]
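The backward hunks mirror the forward path: each subgraph buffer is frozen right after build_subgraph_buffer. These lowerings are exercised whenever a compiled flex_attention call is differentiated, e.g. (illustrative sketch, not part of this commit; assumes a CUDA device):

import torch
from torch.nn.attention.flex_attention import flex_attention

q = torch.randn(2, 4, 128, 64, device="cuda", requires_grad=True)
k = torch.randn_like(q, requires_grad=True)
v = torch.randn_like(q, requires_grad=True)

def relative_bias(score, b, h, q_idx, kv_idx):
    # Simple relative-position score_mod; a captured-tensor variant works too.
    return score + (q_idx - kv_idx)

compiled_flex = torch.compile(flex_attention)
out = compiled_flex(q, k, v, score_mod=relative_bias)
out.sum().backward()  # hits the flex_attention_backward lowering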
@@ -20,6 +20,7 @@ from ...select_algorithm import (
from .common import (
    create_indices_fake,
    create_num_blocks_fake_generator,
    freeze_irnodes,
    get_fwd_subgraph_outputs,
    load_template,
    maybe_realize,
@@ -208,6 +209,9 @@ def create_flex_decoding_kernel(*args, **kwargs):
    score_mod_other_buffers = maybe_realize(score_mod_other_buffers)
    mask_mod_other_buffers = maybe_realize(mask_mod_other_buffers)

    freeze_irnodes(score_mod_other_buffers)
    freeze_irnodes(mask_mod_other_buffers)

    choices: list[Any] = []
    dtype = key.get_dtype()
    head_dim = V.graph.sizevars.guard_int(key.get_size()[-1])
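create_flex_decoding_kernel receives the same realized score_mod / mask_mod buffers, so they are frozen here as well. The decoding path targets decode-shaped workloads, e.g. a single query token attending over a longer KV cache; a rough sketch of such a call (illustrative only; whether Inductor actually selects the decoding template depends on its normal dispatch and autotuning):

import torch
from torch.nn.attention.flex_attention import flex_attention

# Decode-like shapes: one query position, a long key/value cache.
q = torch.randn(1, 8, 1, 64, device="cuda", dtype=torch.float16)
k = torch.randn(1, 8, 2048, 64, device="cuda", dtype=torch.float16)
v = torch.randn_like(k)

compiled_flex = torch.compile(flex_attention)
with torch.no_grad():
    out = compiled_flex(q, k, v)  # default score_mod / block_mask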