[inductor] Freeze layouts in FlexAttention (#163434)
Fixes #163300

Pull Request resolved: https://github.com/pytorch/pytorch/pull/163434
Approved by: https://github.com/drisspg
ghstack dependencies: #163386, #163398, #163387, #163414, #163415, #163419
committed by: PyTorch MergeBot
parent: 518c320676
commit: ed84e808f0
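For context, the regression test added below compiles a module whose flex_attention score_mod gathers from tensors captured by closure (an MLP-produced bias table and precomputed index buffers). A minimal standalone sketch of that pattern, assuming CUDA is available; the shapes and the bias_table name are illustrative only and not taken from this PR:

import torch
from torch.nn.attention.flex_attention import flex_attention

device = "cuda"
B, H, S, D = 2, 4, 16, 32
q, k, v = (
    torch.randn(B, H, S, D, device=device, dtype=torch.float16) for _ in range(3)
)

# Hypothetical bias table captured by the score_mod closure and gathered
# per (q_idx, kv_idx, head) inside the kernel.
bias_table = torch.randn(S, S, H, device=device)


def score_mod(score, b, h, q_idx, kv_idx):
    return score + bias_table[q_idx, kv_idx, h]


compiled_flex = torch.compile(flex_attention, backend="inductor", dynamic=False)
out = compiled_flex(q, k, v, score_mod=score_mod)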
@@ -21,6 +21,7 @@ from torch._inductor import metrics
from torch._inductor.runtime.triton_compat import HAS_WARP_SPEC
from torch._inductor.test_case import TestCase as InductorTestCase
from torch._inductor.utils import run_and_get_code
from torch.nn.attention import SDPBackend
from torch.nn.attention.experimental._paged_attention import PagedAttention
from torch.nn.attention.flex_attention import (
    _create_empty_block_mask,
@@ -4360,6 +4361,89 @@ class GraphModule(torch.nn.Module):
        attn_output = mod(q, k, v, mask)
        self.assertEqual(attn_output.device, torch.device("cuda:1"))
    @supported_platform
    @skip_on_cpu
    def test_custom_score_mod_layout_freeze(self, device):
        torch.manual_seed(0)

        class FlexAttentionCPB(nn.Module):
            # Continuous position bias: a small MLP maps log-scaled 2-D relative
            # offsets to one bias value per attention head.
            def __init__(self, N: int, R: int, H: int = 4, hidden: int = 32):
                super().__init__()
                self.mlp = nn.Sequential(
                    nn.Linear(2, hidden),
                    nn.GELU(),
                    nn.Linear(hidden, H, bias=False),
                )
                self.gamma = nn.Parameter(torch.zeros(H))
                self.H = H
                self._init_tables(N, R)
                # The first R tokens receive no relative bias.
                self.register_buffer(
                    "r_cutoff", torch.tensor(R, dtype=torch.long), persistent=False
                )

            def _init_tables(self, N: int, R: int) -> None:
                # The last P = N - R tokens form an S x S spatial grid.
                P = N - R
                S = int(P**0.5)
                assert S * S == P
                # rel_table: all (2S - 1)^2 possible (dy, dx) offsets, log-scaled.
                rng = torch.arange(-(S - 1), S, dtype=torch.float32)
                dY, dX = torch.meshgrid(rng, rng, indexing="ij")
                rel = torch.stack(
                    [dY / max(S - 1, 1), dX / max(S - 1, 1)], dim=-1
                ).reshape(-1, 2)
                rel_table = torch.sign(rel) * torch.log1p(rel.abs())
                self.register_buffer("rel_table", rel_table, persistent=False)

                # idx_table: for every (q, kv) pair of grid tokens, the row of
                # rel_table that holds their relative offset.
                yy, xx = torch.arange(S), torch.arange(S)
                Y, X = torch.meshgrid(yy, xx, indexing="ij")
                flat = torch.stack([Y, X], 0).flatten(1)
                d = flat[:, :, None] - flat[:, None, :]
                d = d.permute(1, 2, 0).contiguous()
                d[:, :, 0] += S - 1
                d[:, :, 1] += S - 1
                d[:, :, 0] *= 2 * S - 1
                l_idx = d.sum(-1).to(torch.long)
                idx = torch.full((N, N), 0, dtype=torch.long)
                idx[R:, R:] = l_idx
                self.register_buffer("idx_table", idx, persistent=False)
            def _score_mod(self, mu: torch.Tensor):
                # Precompute the per-head bias table and gates once; the inner
                # closure only gathers from these captured tensors.
                bt = self.mlp(self.rel_table)
                idx = self.idx_table
                mu_q, mu_k = mu.unbind(2)
                gam_sig = torch.sigmoid(self.gamma)

                def score_mod(score, b, h, q, kv):
                    # Add the gated relative bias only when both query and key
                    # index past the first R tokens.
                    has_bias = (q >= self.r_cutoff) & (kv >= self.r_cutoff)
                    l2 = idx[q, kv]
                    bias = bt[l2, h]
                    w_gate = gam_sig[h] * (mu_q[b, h, q] + mu_k[b, h, kv])
                    return score + has_bias.to(score.dtype) * w_gate * bias

                return score_mod

            def forward(self, q, k, v, mu):
                return flex_attention(q, k, v, score_mod=self._score_mod(mu))

        dtype = torch.bfloat16 if PLATFORM_SUPPORTS_BF16 else torch.float16
        device_obj = torch.device(device)
        module = FlexAttentionCPB(N=18, R=2).to(device_obj)
        compiled_module = torch.compile(module, backend="inductor", dynamic=False)

        q = torch.randn(2, 4, 18, 32, device=device_obj, dtype=dtype)
        k = torch.randn_like(q)
        v = torch.randn_like(q)
        mu = torch.randn(2, 4, 2, 18, device=device_obj)

        # Eager and Inductor-compiled outputs should agree within tolerance.
        with torch.no_grad():
            with torch.nn.attention.sdpa_kernel(SDPBackend.FLASH_ATTENTION):
                eager_out = module(q, k, v, mu)
                compiled_out = compiled_module(q, k, v, mu)

        self.assertEqual(compiled_out.shape, eager_out.shape)
        torch.testing.assert_close(
            compiled_out.float(), eager_out.float(), atol=2e-2, rtol=2e-2
        )
    @supported_platform
    @skip_on_cpu
    @common_utils.parametrize(