pytorch/test/inductor/test_flex_decoding.py

# Owner(s): ["module: inductor"]
# flake8: noqa: B950
import functools
import sys
import unittest
from collections import namedtuple
from collections.abc import Callable
from typing import Optional, Union
from unittest import expectedFailure
from unittest.mock import patch
import torch
from torch._inductor.test_case import TestCase as InductorTestCase
from torch._inductor.utils import run_and_get_code
from torch.nn.attention.experimental._paged_attention import PagedAttention
from torch.nn.attention.flex_attention import (
_create_empty_block_mask,
_identity,
BlockMask,
create_block_mask,
flex_attention,
noop_mask,
)
from torch.testing import FileCheck
from torch.testing._internal import common_utils
from torch.testing._internal.common_cuda import PLATFORM_SUPPORTS_BF16, with_tf32_off
from torch.testing._internal.common_device_type import (
flex_attention_supported_platform as supported_platform,
instantiate_device_type_tests,
skipXPUIf,
)
from torch.testing._internal.common_utils import IS_CI, IS_WINDOWS
from torch.testing._internal.inductor_utils import HAS_GPU
from torch.utils._triton import has_triton_tma_device
if IS_WINDOWS and IS_CI:
# TODO(xuhancn): track whether this is a requirement on Windows.
sys.stderr.write("This UT crashes frequently on Windows; skipping it.\n")
if __name__ == "__main__":
sys.exit(0)
raise unittest.SkipTest("skip on Windows")
Tolerances = namedtuple("Tolerances", ["atol", "rtol"])
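# On ROCm use full float32 matmul precision; elsewhere allow the faster
# "high" precision (which may use TF32) for float32 matmuls.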
if torch.version.hip:
torch.set_float32_matmul_precision("highest")
else:
torch.set_float32_matmul_precision("high")
index = torch.ops.aten.index
Tensor = torch.Tensor
TEST_ON_CUDA = (
torch.cuda.is_available()
and torch.utils._triton.has_triton()
and torch.cuda.get_device_capability() >= (8, 0)
)
TEST_ON_XPU = torch.xpu.is_available() and torch.utils._triton.has_triton()
if HAS_GPU:
if TEST_ON_CUDA:
test_device = ("cuda",)
test_dtypes = (
[torch.float32, torch.bfloat16, torch.float16]
if PLATFORM_SUPPORTS_BF16
else [torch.float16, torch.float32]
)
test_dtypes_fast = [torch.float16]
SKIP_UT_ON_CPU = False
elif TEST_ON_XPU:
torch._C._set_onednn_allow_tf32(True)
test_device = ("xpu",)
test_dtypes = [torch.float32, torch.bfloat16, torch.float16]
test_dtypes_fast = [torch.float16]
SKIP_UT_ON_CPU = False
else:
test_device = ("cpu",)
torch_config_string = torch.__config__.show()
SKIP_UT_ON_CPU = True
LONG_COMPILATION_ON_CPU = False
if "CLANG" in torch_config_string.upper():
# If the compiler is clang, skip the CPU UTs due to the long compilation
# time observed in CI.
# TODO: investigate the cause of the long compile time.
LONG_COMPILATION_ON_CPU = True
test_dtypes = (
[torch.float32, torch.bfloat16]
if torch.backends.mkldnn.is_available()
and torch.ops.mkldnn._is_mkldnn_bf16_supported()
else [torch.float32]
)
test_dtypes_fast = [torch.float32]
def skip_on_xpu(test_func):
"""Decorator to skip tests that are not supported on Intel GPU."""
decorated_func = skipXPUIf(True, "Not supported on Intel GPU")(test_func)
return decorated_func
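# Helper: partially apply flex_attention with a fixed score_mod / block_mask /
# GQA setting so tests can call it with just (q, k, v).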
def create_attention(score_mod, block_mask, enable_gqa=False, kernel_options=None):
return functools.partial(
flex_attention,
score_mod=score_mod,
block_mask=block_mask,
enable_gqa=enable_gqa,
kernel_options=kernel_options,
)
def create_block_mask_test(score_mod, query, key):
block_mask = create_block_mask(
score_mod, 1, 1, query.shape[-2], key.shape[-2], query.device
)
return block_mask
test_page_sizes = [64, 128, 256]
# --------- Useful score mod functions for testing ---------
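# A score_mod takes (score, batch, head, q_index, kv_index) tensors and returns
# a modified score; e.g. _causal below masks future tokens by returning -inf:
#
#     flex_attention(q, k, v, score_mod=_causal)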
def _causal(
score: Tensor,
batch: Tensor,
head: Tensor,
token_q: Tensor,
token_kv: Tensor,
) -> Tensor:
return torch.where(token_q >= token_kv, score, float("-inf"))
def _generate_windowed(offset):
def _windowed(score, b, h, q, kv):
return torch.where(q + offset >= kv, score, float("-inf"))
return _windowed
def _get_windowed_sdpa_mask(Mq, Mkv, offset):
return torch.tril(torch.ones(Mkv, Mkv, dtype=torch.bool, device=test_device[0]))[
offset : offset + Mq
]
def _rel_bias(
score: Tensor,
batch: Tensor,
head: Tensor,
token_q: Tensor,
token_kv: Tensor,
) -> Tensor:
return score + (token_q - token_kv)
def _rel_causal(
score: Tensor,
batch: Tensor,
head: Tensor,
token_q: Tensor,
token_kv: Tensor,
) -> Tensor:
return torch.where(token_q >= token_kv, score + (token_q - token_kv), float("-inf"))
def _generate_alibi_bias(num_heads: int):
def _alibi_bias(
score: Tensor,
batch: Tensor,
head: Tensor,
token_q: Tensor,
token_kv: Tensor,
) -> Tensor:
scale = torch.exp2(-((head + 1) * 8.0 / num_heads))
return score + (token_kv - token_q) * scale
return _alibi_bias
def _inverse_causal(score, b, h, m, n):
return torch.where(m <= n, score, float("-inf"))
def _times_two(score, b, h, m, n):
"""Joint graph needed for correctness"""
return score * 2
def _squared(score, b, h, m, n):
"""Joint graph needed for correctness"""
return score * score
def _head_offset(dtype: torch.dtype):
"""Captured Buffer"""
head_offset = torch.rand(Hq, device=test_device[0], dtype=dtype)
def score_mod(score, b, h, m, n):
return score * head_offset[h]
return score_mod
def _trig(score, b, h, m, n):
"""Joint graph needed for correctness"""
return torch.sin(torch.cos(score)) + torch.tan(b)
def _trig2(score, b, h, m, n):
"""Branching joint graph"""
cos_score = torch.cos(score)
sin_score = torch.sin(score)
z = cos_score * sin_score + torch.tan(b)
return z
test_score_mods = [
_identity,
_times_two,
_squared,
_causal,
_inverse_causal,
_rel_bias,
_rel_causal,
_generate_alibi_bias(8),
_generate_windowed(1000),
]
captured_buffers_map = {
"_head_offset": _head_offset,
}
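# Default problem sizes: batch B, KV sequence length S, head dim D. The
# (Hq, Hkv) pairs below exercise GQA with different query/KV head ratios.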
B = 4
S = 2048
D = 64
test_Hq_Hkv = [
(16, 1),
(8, 2),
(16, 16),
]
test_Bq_Bkv = [
(3, 1),
(5, 1),
(8, 1),
(16, 1),
]
test_block_size = [
64,
128,
(1, 64),
(128, 64),
]
(Hq, Hkv) = (16, 8)
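# Each input_strides_* helper returns a (strides, storage_offset) pair that is
# fed to torch.as_strided to build non-contiguous K/V layouts, e.g.:
#
#     strides, offset = input_strides_1(B, Hkv, S, D)
#     k = torch.as_strided(k1, (B, Hkv, S, D), strides, offset)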
def input_strides_1(B, H, S, D):
return ((H * S * D, S * D, D, 1), 997) # offset
def input_strides_2(B, H, S, D):
return ((H * D, D, B * H * D, 1), 499) # transposed dimensions
def input_strides_3(B, H, S, D):
return ((S * (D + 1), B * S * (D + 1), (D + 1), 1), 293) # additional buffer
def input_strides_4(B, H, S, D):
return ((1, D, (B + 1) * (H + 1) * D, 1), 97) # shared dimension
test_input_strides = [
input_strides_1,
input_strides_2,
input_strides_3,
input_strides_4,
]
def query_key_value_clones(
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
dtype: Optional[torch.dtype] = None,
):
"""Clones the query, key, and value tensors and casts them to the given dtype (defaults to the query's dtype)."""
if dtype is None:
dtype = query.dtype
query_ref = query.detach().clone().to(dtype).requires_grad_(query.requires_grad)
key_ref = key.detach().clone().to(dtype).requires_grad_(key.requires_grad)
value_ref = value.detach().clone().to(dtype).requires_grad_(value.requires_grad)
return query_ref, key_ref, value_ref
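# Helper: reserve pages in the paged KV cache for every batch element up to its
# target sequence length.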
def batch_reserve(paged_attention: PagedAttention, target_seq_len: Tensor):
(B,) = target_seq_len.shape
for b in range(B):
paged_attention.reserve(
torch.tensor(b),
target_seq_len[b],
)
class TestFlexDecoding(InductorTestCase):
def setUp(self):
super().setUp()
self.test_inference_only = False
if test_device[0] == "cpu":
if LONG_COMPILATION_ON_CPU:
self.skipTest(
"skip UT for CPU due to long compilation time found in CI"
)
self.test_inference_only = True
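# Accuracy helper: compares the compiled output against the eager reference,
# both measured relative to a float64 "golden" run; the compiled error may
# exceed the reference error by at most `fudge_factor`.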
def _check_equal(
self,
golden_out: torch.Tensor,
ref_out: torch.Tensor,
compiled_out: torch.Tensor,
fudge_factor: float,
tensor_name: Optional[str] = None,
):
compiled_error = (golden_out - compiled_out).abs().mean()
ref_error = (golden_out - ref_out).abs().mean()
if torch.isnan(compiled_error).any() and not torch.isnan(ref_error).any():
self.assertTrue(False, "Output/Grad with NaN")
if ref_error < (1e-4) * golden_out.abs().mean():
print(
"very small ref error of ",
(ref_error.to(torch.float64) * (1e5) / golden_out.abs().mean()),
)
tolerance = Tolerances(atol=2e-1, rtol=2e-1)
torch.testing.assert_close(
golden_out.to(dtype=compiled_out.dtype),
compiled_out,
atol=tolerance.atol,
rtol=tolerance.rtol,
)
elif compiled_error > ref_error * fudge_factor:
name = tensor_name if tensor_name is not None else ""
msg = f"{name} Compiled error {compiled_error} is greater than ref error {ref_error} by more than {fudge_factor}X."
self.assertTrue(False, msg)
def _check_out(
self,
golden_out: torch.Tensor,
ref_out: torch.Tensor,
compiled_out: torch.Tensor,
):
dtype = ref_out.dtype
with torch.no_grad():
# Note: the compiled kernel appears to be less accurate than the float32
# reference computation, likely due to the online softmax.
if dtype == torch.float32:
fudge_factor = 10.0
else:
fudge_factor = 1.1
# Check output
self._check_equal(golden_out, ref_out, compiled_out, fudge_factor, "Out")
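# Core test driver: runs the same flex_attention call in eager mode (reference),
# in compiled mode, and in float64 (golden), then checks the outputs (and the
# LSE when not running inference-only) with _check_out.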
def run_test(
self,
score_mod: Optional[Callable] = None,
dtype: torch.dtype = torch.float16,
Q_B: int = B,
Q_H: int = Hq,
Q_S: int = 1,
Q_D: int = D,
KV_B: int = B,
KV_H: int = Hkv,
KV_S: int = S,
V_D: int = D,
block_mask: Optional[BlockMask] = None,
device="cuda",
kernel_options=None,
):
assert score_mod is not None or block_mask is not None, (
"Must provide score_mod or block_mask"
)
assert Q_H % KV_H == 0
if device == "cpu" and dtype is torch.float16:
dtype = torch.float32
q = torch.randn(
(Q_B, Q_H, Q_S, Q_D),
dtype=dtype,
device=device,
requires_grad=not self.test_inference_only,
)
k = torch.randn(
(KV_B, KV_H, KV_S, Q_D),
dtype=dtype,
device=device,
requires_grad=not self.test_inference_only,
)
v = torch.randn(
(KV_B, KV_H, KV_S, V_D),
dtype=dtype,
device=device,
requires_grad=not self.test_inference_only,
)
q_ref, k_ref, v_ref = query_key_value_clones(q, k, v)
q_gold, k_gold, v_gold = query_key_value_clones(q, k, v, torch.float64)
sdpa_partial = create_attention(
score_mod,
block_mask,
enable_gqa=(not Q_H == KV_H),
kernel_options=kernel_options,
)
compiled_sdpa = torch.compile(sdpa_partial)
if not self.test_inference_only:
golden_out, gold_lse = sdpa_partial(q_gold, k_gold, v_gold, return_lse=True)
ref_out, ref_lse = sdpa_partial(q_ref, k_ref, v_ref, return_lse=True)
compiled_out, compiled_lse = compiled_sdpa(q, k, v, return_lse=True)
self._check_out(
gold_lse,
ref_lse,
compiled_lse,
)
else:
golden_out = sdpa_partial(q_gold, k_gold, v_gold, return_lse=False)
ref_out = sdpa_partial(q_ref, k_ref, v_ref, return_lse=False)
compiled_out = compiled_sdpa(q, k, v, return_lse=False)
self._check_out(
golden_out,
ref_out,
compiled_out,
)
def run_test_with_call(
self,
sdpa_call: Callable,
golden_call: Optional[Callable] = None,
dtype: torch.dtype = torch.float16,
Q_B: int = B,
Q_H: int = Hq,
Q_S: int = 1,
Q_D: int = D,
KV_B: int = B,
KV_H: int = Hkv,
KV_S: int = S,
V_D: int = D,
device="cuda",
):
if not golden_call:
golden_call = sdpa_call
if device == "cpu" and dtype is torch.float16:
dtype = torch.float32
q = torch.randn(
(Q_B, KV_H, Q_S, Q_D),
dtype=dtype,
device=device,
requires_grad=False,
)
k = torch.randn(
(KV_B, KV_H, KV_S, Q_D),
dtype=dtype,
device=device,
requires_grad=False,
)
v = torch.randn(
(KV_B, KV_H, KV_S, V_D),
dtype=dtype,
device=device,
requires_grad=False,
)
q_ref, k_ref, v_ref = query_key_value_clones(q, k, v)
q_gold, k_gold, v_gold = query_key_value_clones(q, k, v, torch.float64)
compiled_sdpa = torch.compile(sdpa_call)
golden_out = golden_call(q_gold, k_gold, v_gold)
ref_out = golden_call(q_ref, k_ref, v_ref)
compiled_out = compiled_sdpa(q, k, v)
self._check_out(
golden_out,
ref_out,
compiled_out,
)
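# Builds a paged KV cache for the given k/v: allocates the cache, "randomly"
# reserves pages per batch to scramble the physical page layout, copies k/v in,
# and converts the logical block mask / score_mod to their paged equivalents.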
def preprocess_paged_attention(
self,
score_mod: Optional[Callable],
q: Tensor,
k: Tensor,
v: Tensor,
block_mask,
dtype: torch.dtype = torch.float16,
page_size: int = 128,
device="cuda",
):
assert block_mask is not None, "Must provide block_mask"
if device == "cpu" and dtype is torch.float16:
dtype = torch.float32
Q_B, Q_H, Q_S, _ = q.shape
KV_B, KV_H, KV_S, QK_D = k.shape
_, _, _, V_D = v.shape
# size the paged cache for a larger max batch size than the inputs use
max_batch_size = max(Q_B, KV_B) + 3
n_pages = (KV_S + page_size - 1) // page_size * max_batch_size
# allocate cache
MAX_CACHED_SEQ_LEN = n_pages * page_size
k_cache = torch.zeros(
1,
KV_H,
MAX_CACHED_SEQ_LEN,
QK_D,
device=device,
dtype=dtype,
)
v_cache = torch.zeros(
1,
KV_H,
MAX_CACHED_SEQ_LEN,
V_D,
device=device,
dtype=dtype,
)
# "randomly" initialize the page table
paged_attention = PagedAttention(
n_pages, page_size, max_batch_size, device=device
)
batch_reserve(
paged_attention,
torch.tensor([KV_S // 4, KV_S // 2, KV_S // 4, KV_S // 3], device=device),
)
batch_reserve(
paged_attention,
torch.tensor([KV_S // 4, KV_S // 2, KV_S // 2, KV_S // 2], device=device),
)
batch_reserve(
paged_attention,
torch.tensor([KV_S // 2, KV_S, KV_S // 2, KV_S], device=device),
)
batch_reserve(
paged_attention, torch.tensor([KV_S, KV_S, KV_S, KV_S], device=device)
)
# update cache with k and v
input_pos = (
torch.arange(KV_S, device=device, dtype=torch.int32)
.unsqueeze(0)
.expand(KV_B, KV_S)
)
batch_idx = torch.arange(KV_B, device=device, dtype=torch.int32)
paged_attention.assign(batch_idx, input_pos, k, v, k_cache, v_cache)
# convert block mask and score mod
kv_len_tensor = torch.full((KV_B,), KV_S, device=device, dtype=torch.int64)
converted_block_mask = paged_attention.convert_logical_block_mask(
block_mask, kv_len=kv_len_tensor
)
converted_score_mod = paged_attention.get_score_mod(
score_mod, kv_len=kv_len_tensor
)
return k_cache, v_cache, converted_block_mask, converted_score_mod
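# Runs compiled flex_attention against the paged KV cache produced by
# preprocess_paged_attention and returns the output and (when not
# inference-only) the LSE.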
def run_paged_attention(
self,
score_mod: Optional[Callable],
q: Tensor,
k: Tensor,
v: Tensor,
dtype: torch.dtype = torch.float16,
block_mask: Optional[BlockMask] = None,
device="cuda",
):
Q_B, Q_H, KV_H = q.shape[0], q.shape[1], k.shape[1]
if device == "cpu" and dtype is torch.float16:
dtype = torch.float32
if block_mask is None:
block_mask = create_block_mask(noop_mask, Q_B, 1, 1, S, device=device)
(
k_cache,
v_cache,
converted_block_mask,
converted_score_mod,
) = self.preprocess_paged_attention(
score_mod, q, k, v, block_mask, dtype, block_mask.BLOCK_SIZE[1], device
)
compiled_sdpa = torch.compile(flex_attention)
# compute
if not self.test_inference_only:
compiled_out, compiled_lse = compiled_sdpa(
q,
k_cache,
v_cache,
return_lse=True,
block_mask=converted_block_mask,
score_mod=converted_score_mod,
enable_gqa=(not Q_H == KV_H),
)
else:
compiled_lse = None
compiled_out = compiled_sdpa(
q,
k_cache,
v_cache,
return_lse=False,
block_mask=converted_block_mask,
score_mod=converted_score_mod,
enable_gqa=(not Q_H == KV_H),
)
return compiled_out, compiled_lse
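# Paged-attention test driver: compares eager flex_attention (reference and
# float64 golden) against the compiled paged-attention path from
# run_paged_attention.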
def run_test_with_paged_attention(
self,
score_mod: Optional[Callable],
dtype: torch.dtype = torch.float16,
Q_B: int = B,
Q_H: int = Hq,
Q_S: int = 1,
QK_D: int = D,
KV_B: int = B,
KV_H: int = Hkv,
KV_S: int = S,
V_D: int = D,
block_mask: Optional[BlockMask] = None,
device="cuda",
):
assert Q_H % KV_H == 0
if device == "cpu" and dtype is torch.float16:
dtype = torch.float32
q = torch.randn(
(Q_B, Q_H, Q_S, QK_D),
dtype=dtype,
device=device,
requires_grad=False,
)
k = torch.randn(
(KV_B, KV_H, KV_S, QK_D),
dtype=dtype,
device=device,
requires_grad=False,
)
v = torch.randn(
(KV_B, KV_H, KV_S, V_D),
dtype=dtype,
device=device,
requires_grad=False,
)
q_ref, k_ref, v_ref = query_key_value_clones(q, k, v)
q_gold, k_gold, v_gold = query_key_value_clones(q, k, v, torch.float64)
if block_mask is None:
block_mask = create_block_mask(noop_mask, Q_B, 1, 1, KV_S, device=device)
sdpa_partial = create_attention(
score_mod, block_mask, enable_gqa=(not Q_H == KV_H)
)
golden_out, gold_lse = sdpa_partial(q_gold, k_gold, v_gold, return_lse=True)
ref_out, ref_lse = sdpa_partial(q_ref, k_ref, v_ref, return_lse=True)
compiled_out, compiled_lse = self.run_paged_attention(
score_mod, q, k, v, dtype, block_mask, device
)
self._check_out(
golden_out,
ref_out,
compiled_out,
)
if not self.test_inference_only:
self._check_out(
gold_lse,
ref_lse,
compiled_lse,
)
def run_test_with_call_paged_attention(
self,
score_mod: Optional[Callable],
mask_mod: Optional[Callable],
sdpa_mask: Tensor,
dtype: torch.dtype = torch.float16,
Q_B: int = B,
Q_H: int = Hq,
Q_S: int = 1,
Q_D: int = D,
KV_B: int = B,
KV_H: int = Hkv,
KV_S: int = S,
V_D: int = D,
device="cuda",
):
if device == "cpu" and dtype is torch.float16:
dtype = torch.float32
q = torch.randn(
(Q_B, KV_H, Q_S * (Q_H // KV_H), Q_D),
dtype=dtype,
device=device,
requires_grad=False,
)
k = torch.randn(
(KV_B, KV_H, KV_S, Q_D),
dtype=dtype,
device=device,
requires_grad=False,
)
v = torch.randn(
(KV_B, KV_H, KV_S, V_D),
dtype=dtype,
device=device,
requires_grad=False,
)
q_ref, k_ref, v_ref = query_key_value_clones(q, k, v)
q_gold, k_gold, v_gold = query_key_value_clones(q, k, v, torch.float64)
golden_call = functools.partial(
torch.nn.functional.scaled_dot_product_attention, attn_mask=sdpa_mask
)
golden_out = golden_call(q_gold, k_gold, v_gold)
ref_out = golden_call(q_ref, k_ref, v_ref)
if mask_mod is not None:
block_mask = create_block_mask(mask_mod, Q_B, 1, Q_S, KV_S, device=device)
else:
block_mask = create_block_mask(noop_mask, Q_B, 1, Q_S, KV_S, device=device)
compiled_out, _ = self.run_paged_attention(
score_mod, q, k, v, dtype, block_mask, device
)
self._check_out(
golden_out,
ref_out,
compiled_out,
)
@supported_platform
@expectedFailure # tl.dot does not support embedding size less than 16
@unittest.skipIf(SKIP_UT_ON_CPU, "Skip on CPU as not supported")
@common_utils.parametrize("dtype", test_dtypes_fast)
def test_bw_decoding_fails(self, device, dtype):
make_kv = functools.partial(
torch.randn,
(2, 2, 128, 4),
dtype=dtype,
device=device,
requires_grad=True,
)
make_q = functools.partial(
torch.randn,
(2, 2, 8, 4),
dtype=dtype,
device=device,
requires_grad=True,
)
q, k, v, backward_grad = make_q(), make_kv(), make_kv(), make_q()
block_mask = _create_empty_block_mask(q, k)
@torch.compile
def sdpa_hop(q, k, v, score_mod, block_mask):
return flex_attention(q, k, v, score_mod)
output = sdpa_hop(q, k, v, _identity, block_mask)
output.backward(backward_grad)
@supported_platform
@common_utils.parametrize("dtype", test_dtypes)
@common_utils.parametrize("score_mod", test_score_mods)
@common_utils.parametrize("head_dims", test_Hq_Hkv)
@with_tf32_off
def test_builtin_score_mods(
self, device, dtype: torch.dtype, score_mod: Callable, head_dims
):
Hq, Hkv = head_dims
assert Hq % Hkv == 0
self.run_test(score_mod, dtype, Q_H=Hq, KV_H=Hkv, device=device)
self.run_test_with_paged_attention(
score_mod, dtype, Q_H=Hq, KV_H=Hkv, device=device
)
@supported_platform
@common_utils.parametrize("dtype", test_dtypes_fast)
@common_utils.parametrize("score_mod", test_score_mods)
@common_utils.parametrize("head_dims", test_Hq_Hkv)
@common_utils.parametrize("page_size", test_page_sizes)
def test_paged_attention_page_size(
self,
device,
dtype: torch.dtype,
score_mod: Callable,
head_dims: tuple[int, int],
page_size: int,
):
Hq, Hkv = head_dims
assert Hq % Hkv == 0
def generate_causal_offset(offset: torch.Tensor):
def causal_offset_mask(b, h, q_idx, kv_idx):
return (offset + q_idx) >= kv_idx
return causal_offset_mask
mod = generate_causal_offset(
torch.tensor(192, device=device, dtype=torch.int32)
)
block_mask = create_block_mask(
mod, B, 1, 1, S, BLOCK_SIZE=page_size, device=device
)
self.run_test_with_paged_attention(
score_mod,
dtype,
Q_B=B,
Q_H=Hq,
KV_B=B,
KV_H=Hkv,
KV_S=S,
block_mask=block_mask,
device=device,
)
@supported_platform
@common_utils.parametrize("dtype", test_dtypes)
@common_utils.parametrize("score_mod", test_score_mods)
@common_utils.parametrize("BLOCK_SIZE", test_block_size)
def test_builtin_score_mods_different_block_size(
self,
device,
dtype: torch.dtype,
score_mod: Callable,
BLOCK_SIZE: Union[int, tuple[int, int]],
):
block_mask = create_block_mask(
noop_mask, B, 1, 1, S, BLOCK_SIZE=BLOCK_SIZE, device=device
)
self.run_test(score_mod, dtype, block_mask=block_mask, device=device)
@unittest.skipIf(not has_triton_tma_device(), "Skip when TMA is not available")
@common_utils.parametrize("dtype", test_dtypes_fast)
def test_tma_decoding(self, device, dtype: torch.dtype):
n_heads, head_dim, seq_len = 4, 16, 128
score_mod = _generate_alibi_bias(n_heads)
kernel_options = {"USE_TMA": True}
self.run_test(
score_mod=score_mod,
dtype=dtype,
Q_B=1,
Q_H=n_heads,
Q_S=1,
Q_D=head_dim,
KV_B=1,
KV_H=n_heads,
KV_S=seq_len,
V_D=head_dim,
device=device,
kernel_options=kernel_options,
)
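# The next test checks that non-contiguous Q/K/V layouts (built from the
# stride/offset recipes above) give matching results in eager, compiled, and
# paged-attention runs.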
@supported_platform
@common_utils.parametrize("dtype", test_dtypes_fast)
@common_utils.parametrize("k_s", test_input_strides)
@common_utils.parametrize("v_s", test_input_strides)
@common_utils.parametrize("head_dims", test_Hq_Hkv)
def test_strided_inputs(self, device, dtype: torch.dtype, k_s, v_s, head_dims):
Hq, Hkv = head_dims
assert Hq % Hkv == 0
q1 = torch.randn((B * Hq * D), dtype=dtype, device=device)
k1 = torch.randn((B * Hkv * S * D * 4), dtype=dtype, device=device)
v1 = torch.randn((B * Hkv * S * D * 4), dtype=dtype, device=device)
k_shape = (B, Hkv, S, D)
v_shape = (B, Hkv, S, D)
q = q1.view(1, Hq, B, D).transpose(0, 2)
k_strides, k_offset = k_s(B, Hkv, S, D)
k_max = [x * (y - 1) for x, y in zip(k_strides, k_shape)]
assert sum(k_max) + k_offset < B * Hkv * S * D * 4
assert k_strides[-1] == 1
k = torch.as_strided(k1, k_shape, k_strides, k_offset)
v_strides, v_offset = v_s(B, Hkv, S, D)
v_max = [x * (y - 1) for x, y in zip(v_strides, v_shape)]
assert sum(v_max) + v_offset < B * Hkv * S * D * 4
assert v_strides[-1] == 1
v = torch.as_strided(v1, v_shape, v_strides, v_offset)
score_mod = _generate_alibi_bias(8)
sdpa_partial = create_attention(
score_mod=score_mod,
block_mask=None,
enable_gqa=(not Hq == Hkv),
)
compiled_sdpa = torch.compile(sdpa_partial)
ref_out = sdpa_partial(q, k, v)
compiled_out = compiled_sdpa(q, k, v)
tolerance = Tolerances(atol=2e-1, rtol=2e-1)
torch.testing.assert_close(
ref_out, compiled_out, atol=tolerance.atol, rtol=tolerance.rtol
)
paged_compiled_out, _ = self.run_paged_attention(
score_mod, q, k, v, dtype, device=device
)
torch.testing.assert_close(
ref_out, paged_compiled_out, atol=tolerance.atol, rtol=tolerance.rtol
)
@supported_platform
@common_utils.parametrize("dtype", test_dtypes_fast)
@common_utils.parametrize("head_dims", test_Hq_Hkv)
@common_utils.parametrize("batch_dims", test_Bq_Bkv)
@common_utils.parametrize("score_mod", test_score_mods)
def test_kv_batch_broadcast(
self,
device,
dtype: torch.dtype,
head_dims: tuple[int, int],
batch_dims: tuple[int, int],
score_mod: Callable,
):
Hq, Hkv = head_dims
assert Hq % Hkv == 0
Bq, Bkv = batch_dims
assert Bq > 1 and Bkv == 1
block_mask = create_block_mask(noop_mask, Bq, 1, 1, S, device=device)
self.run_test(
score_mod, dtype, Bq, Hq, 1, D, Bkv, Hkv, S, D, block_mask, device=device
)
@supported_platform
@common_utils.parametrize("dtype", test_dtypes)
def test_skip_odd_keys(self, device, dtype: torch.dtype):
def score_mod(score, b, h, q, kv):
return torch.where(kv % 2 == 0, score, float("-inf"))
self.run_test(score_mod, dtype, device=device)
self.run_test_with_paged_attention(score_mod, dtype, device=device)
@supported_platform
@common_utils.parametrize("dtype", test_dtypes)
def test_function_composition(self, device, dtype: torch.dtype):
def score_mod_1(score, b, h, m, n):
return score + (m - n)
def score_mod_2(score, b, h, m, n):
return torch.where(m <= n, score, float("-inf"))
def composed_score_mod(score, b, h, m, n):
return score_mod_2(score_mod_1(score, b, h, m, n), b, h, m, n)
self.run_test(composed_score_mod, dtype, device=device)
self.run_test_with_paged_attention(composed_score_mod, dtype, device=device)
@supported_platform
@common_utils.parametrize("dtype", test_dtypes)
def test_captured_buffers(self, device, dtype: torch.dtype):
head_offset = torch.rand(Hq, device=device, dtype=dtype)
def score_mod(score, b, h, m, n):
return score + head_offset[h]
self.run_test(score_mod, dtype, device=device)
self.run_test_with_paged_attention(score_mod, dtype, device=device)
@supported_platform
@common_utils.parametrize("dtype", test_dtypes)
def test_captured_buffers_all_dims(self, device, dtype: torch.dtype):
head_scale = torch.randn(Hq, device=device)
batch_scale = torch.randn(B, device=device)
kv_scale = torch.randn(S, device=device)
q_scale = torch.randn(1, device=device)
def all_bias(score, batch, head, token_q, token_kv):
score = score + kv_scale[token_kv]
score = score + q_scale[token_q]
score = score + head_scale[head]
score = score + batch_scale[batch]
return score
self.run_test(all_bias, dtype, device=device)
self.run_test_with_paged_attention(all_bias, dtype, device=device)
@supported_platform
@common_utils.parametrize("dtype", test_dtypes_fast)
def test_seq_masking(self, device, dtype):
seq_idx = torch.zeros(S, device=device, dtype=torch.bool)
seq_idx[S // 2 :] = 1
def seq_mask_mod(score, b, h, q, kv):
return torch.where(seq_idx[q] == seq_idx[kv], score, float("-inf"))
self.run_test(seq_mask_mod, dtype, device=device)
self.run_test_with_paged_attention(seq_mask_mod, dtype, device=device)
@supported_platform
def test_non_divisible_offset_mask(self, device):
KV_S = S - 3
offset_tensor = torch.tensor(S // 2 - 3, device=device, dtype=torch.int32)
def mask_mod(b, h, q, kv):
return kv >= q + offset_tensor
block_mask = create_block_mask(mask_mod, B, 1, 1, KV_S, device=device)
self.run_test(KV_S=KV_S, block_mask=block_mask, device=device)
@supported_platform
def test_non_divisible_offset_mask_with_captured_buffer(self, device):
KV_S = S - 3
offset_kv = torch.randn(KV_S, device=device, dtype=torch.bfloat16)
offset_tensor = torch.tensor(S // 2 - 3, device=device, dtype=torch.int32)
def score_mod(score, b, h, q, kv):
return score + offset_kv[kv]
def mask_mod(b, h, q, kv):
return kv >= q + offset_tensor
block_mask = create_block_mask(mask_mod, B, 1, 1, KV_S, device=device)
self.run_test(
KV_S=KV_S, block_mask=block_mask, score_mod=score_mod, device=device
)
@supported_platform
def test_non_divisible_multi_token_offset_mask(self, device):
KV_S = S - 3
Q_S = 3
offset_tensor = torch.tensor(S // 2 - 1, device=device, dtype=torch.int32)
def mask_mod(b, h, q, kv):
return kv >= q + offset_tensor
block_mask = create_block_mask(mask_mod, B, 1, Q_S, KV_S, device=device)
self.run_test(Q_S=Q_S, KV_S=KV_S, block_mask=block_mask, device=device)
@supported_platform
@unittest.skipIf(SKIP_UT_ON_CPU, "Skip on CPU as not supported")
def test_non_divisible_multi_token_offset_mask_with_captured_buffer(self, device):
KV_S = S - 3
Q_S = 3
offset_kv = torch.randn(KV_S, device=device, dtype=torch.bfloat16)
offset_q = torch.randn(Q_S, device=device, dtype=torch.bfloat16)
offset_tensor = torch.tensor(S // 2 - 3, device=device, dtype=torch.int32)
def score_mod(score, b, h, q, kv):
return score + offset_kv[kv] + offset_q[q]
def mask_mod(b, h, q, kv):
return kv >= q + offset_tensor
block_mask = create_block_mask(mask_mod, B, 1, Q_S, KV_S, device=device)
self.run_test(
Q_S=Q_S,
KV_S=KV_S,
block_mask=block_mask,
score_mod=score_mod,
device=device,
)
@supported_platform
@common_utils.parametrize("dtype", test_dtypes_fast)
def test_load_from_bias_seq_only(self, device, dtype):
bias = torch.randn(1, S, device=device, dtype=dtype)
def bias_mod(score, b, h, q, kv):
return score + bias[q, kv]
self.run_test(bias_mod, dtype, device=device)
self.run_test_with_paged_attention(bias_mod, dtype, device=device)
@supported_platform
@common_utils.parametrize("dtype", test_dtypes_fast)
def test_load_from_bias_seq_batch(self, device, dtype):
bias = torch.randn(B, 1, S, device=device, dtype=dtype)
def bias_mod(score, b, h, q, kv):
return score + bias[b, q, kv]
self.run_test(bias_mod, dtype, device=device)
self.run_test_with_paged_attention(bias_mod, dtype, device=device)
@supported_platform
@common_utils.parametrize("dtype", test_dtypes_fast)
def test_load_from_bias_head_seq_batch(self, device, dtype):
bias = torch.randn(
B,
Hq,
1,
S,
device=device,
dtype=dtype,
)
def bias_mod(score, b, h, q, kv):
return score + bias[b, h, q, kv]
self.run_test(bias_mod, dtype, device=device)
self.run_test_with_paged_attention(bias_mod, dtype, device=device)
@supported_platform
@common_utils.parametrize("score_mod", test_score_mods)
@common_utils.parametrize("dtype", test_dtypes)
@common_utils.parametrize("head_dims", [(D, D // 2), (D // 2, D)])
@with_tf32_off
def test_non_equal_head_dims(self, device, dtype, score_mod, head_dims):
qk_d, v_d = head_dims
self.run_test(
score_mod, dtype, B, Hq, 1, qk_d, B, Hkv, S, V_D=v_d, device=device
)
self.run_test_with_paged_attention(
score_mod, dtype, B, Hq, 1, qk_d, B, Hkv, S, V_D=v_d, device=device
)
@supported_platform
@common_utils.parametrize("dtype", test_dtypes_fast)
@common_utils.parametrize("score_mod", test_score_mods)
@common_utils.parametrize("head_dims", test_Hq_Hkv)
def test_head_dependent_mask_mod(
self, device, dtype: torch.dtype, score_mod, head_dims
):
Hq, Hkv = head_dims
assert Hq % Hkv == 0
def head_attention_mod(kv_head_num):
head_type = torch.tensor(
[False if i % kv_head_num == 0 else True for i in range(kv_head_num)],
dtype=torch.bool,
device=device,
)
def mask_mod(b, h, q_idx, kv_idx):
bi_mask = head_type[h]
causal_mask = q_idx >= kv_idx
return bi_mask & causal_mask
return mask_mod
mask_mod = head_attention_mod(Hq)
mask = create_block_mask(mask_mod, 1, Hq, 1, S, device=device)
self.run_test(
score_mod, dtype, Q_H=Hq, KV_H=Hkv, block_mask=mask, device=device
)
self.run_test_with_paged_attention(
score_mod, dtype, Q_H=Hq, KV_H=Hkv, device=device
)
@supported_platform
@common_utils.parametrize("dtype", test_dtypes_fast)
def test_subgraph_respect_decompostion(self, device, dtype):
from torch._decomp import core_aten_decompositions
from torch.fx.experimental.proxy_tensor import make_fx
def score_mod_func(score, b, h, q, kv):
return score - q // (1 + kv)
make_kv = functools.partial(
torch.randn,
(2, 2, 128, 4),
dtype=dtype,
device=device,
requires_grad=True,
)
make_q = functools.partial(
torch.randn,
(2, 2, 8, 4),
dtype=dtype,
device=device,
requires_grad=True,
)
query, key, value = make_q(), make_kv(), make_kv()
# floor_div is not decomposed when the decomposition_table is empty
attention = functools.partial(flex_attention, score_mod=score_mod_func)
gm = make_fx(attention, decomposition_table={})(query, key, value)
self.assertExpectedInline(
gm.sdpa_score0.code.strip(),
"""\
def forward(self, arg0_1, arg1_1, arg2_1, arg3_1, arg4_1):
add = torch.ops.aten.add.Tensor(arg4_1, 1); arg4_1 = None
floor_divide = torch.ops.aten.floor_divide.default(arg3_1, add); arg3_1 = add = None
sub = torch.ops.aten.sub.Tensor(arg0_1, floor_divide); arg0_1 = floor_divide = None
return sub""",
)
# floor_div is decomposed when using core_aten_decompositions
gm = make_fx(attention, decomposition_table=core_aten_decompositions())(
query, key, value
)
self.assertExpectedInline(
gm.sdpa_score0.code.strip(),
"""\
def forward(self, arg0_1, arg1_1, arg2_1, arg3_1, arg4_1):
add = torch.ops.aten.add.Tensor(arg4_1, 1); arg4_1 = None
div = torch.ops.aten.div.Tensor_mode(arg3_1, add, rounding_mode = 'floor'); arg3_1 = add = None
sub = torch.ops.aten.sub.Tensor(arg0_1, div); arg0_1 = div = None
return sub""",
)
@supported_platform
@common_utils.parametrize("dtype", test_dtypes_fast)
def test_silu_on_score(self, device, dtype):
def silu_score(score, b, h, q, kv):
return torch.nn.functional.silu(score)
self.run_test(silu_score, dtype, device=device)
self.run_test_with_paged_attention(silu_score, dtype, device=device)
@supported_platform
@common_utils.parametrize("dtype", test_dtypes_fast)
def test_padded_dense_causal(self, device, dtype):
seq_len = torch.arange(B, device=device, dtype=torch.int32) + 1
def create_padded_dense_wrapper(orig_score_mod):
def njt_score_mod(qk, b, h, q, kv):
return torch.where(
qk <= seq_len[b], orig_score_mod(qk, b, h, q, kv), -float("inf")
)
return njt_score_mod
causal_njt = create_padded_dense_wrapper(_causal)
self.run_test(causal_njt, dtype, device=device)
self.run_test_with_paged_attention(causal_njt, dtype, device=device)
@supported_platform
@common_utils.parametrize("dtype", test_dtypes_fast)
def test_captured_scale(self, device, dtype):
scale = torch.ones((), device=device, dtype=torch.int32)
def score_mod_scale(qk, b, h, q, kv):
return qk + scale
self.run_test(score_mod_scale, dtype, device=device)
self.run_test_with_paged_attention(score_mod_scale, dtype, device=device)
@supported_platform
@common_utils.parametrize("dtype", test_dtypes_fast)
def test_recompile_changed_score_mod(self, device, dtype):
scale = torch.ones((), device=device, dtype=torch.int32)
ADD = True
def score_mod_scale(qk, b, h, q, kv):
if ADD:
return qk + scale
else:
return qk * scale
self.run_test(score_mod_scale, dtype, device=device)
self.run_test_with_paged_attention(score_mod_scale, dtype, device=device)
ADD = False
self.run_test(score_mod_scale, dtype, device=device)
self.run_test_with_paged_attention(score_mod_scale, dtype, device=device)
@supported_platform
@common_utils.parametrize("head_dim", [17, 24, 94, 121])
@common_utils.parametrize("dtype", test_dtypes_fast)
@common_utils.serialTest()
def test_non_pow_2_headdim(self, device, dtype, head_dim):
self.run_test(
_rel_bias, dtype, B, Hq, S, head_dim, B, Hkv, S, head_dim, device=device
)
@supported_platform
@expectedFailure # If we capture a tensor then we can perform a reduction on it, and that shouldn't be allowed
@common_utils.parametrize("dtype", test_dtypes_fast)
def test_captured_reduction(self, device, dtype):
scale = torch.randn((B, 8), device=device)
def score_mod_scale(qk, b, h, q, kv):
return qk + scale[b].sum(dim=-1)
self.run_test(score_mod_scale, dtype, device=device)
@supported_platform
def test_multiple_score_mod_calls(self, device):
query = torch.randn((1, 8, 4, 64), dtype=torch.float32, device=device)
keys = [
torch.randn((1, 8, 1024, 64), dtype=torch.float32, device=device)
for _ in range(2)
]
values = [
torch.randn((1, 8, 1024, 64), dtype=torch.float32, device=device)
for _ in range(2)
]
def scoremod_1(qk, b, h, q, kv):
return qk + (q - kv)
def scoremod_2(qk, b, h, q, kv):
return torch.where(q >= kv, qk, -float("inf"))
def f(q, k1, k2, v1, v2):
q2 = flex_attention(q, k1, v1, score_mod=scoremod_1)
return flex_attention(q2, k2, v2, score_mod=scoremod_2)
out = f(query, *keys, *values)
out2 = torch.compile(f)(query, *keys, *values)
tolerance = Tolerances(atol=2e-1, rtol=2e-1)
torch.testing.assert_close(out, out2, atol=tolerance.atol, rtol=tolerance.rtol)
@supported_platform
def test_multiple_score_mod_calls2(self, device):
query = torch.randn((1, 8, 4, 64), dtype=torch.float32, device=device)
keys = [
torch.randn((1, 8, 1024, 64), dtype=torch.float32, device=device)
for _ in range(3)
]
values = [
torch.randn((1, 8, 1024, 64), dtype=torch.float32, device=device)
for _ in range(3)
]
def scoremod_1(qk, b, h, q, kv):
return qk + (q - kv)
def scoremod_2(qk, b, h, q, kv):
return torch.where(q >= kv, qk, -float("inf"))
attention1 = functools.partial(flex_attention, score_mod=scoremod_1)
def f(q, k1, k2, k3, v1, v2, v3):
q2 = attention1(q, k1, v1)
q3 = flex_attention(q2, k2, v2, score_mod=scoremod_2)
return flex_attention(q3, k3, v3, score_mod=scoremod_1)
out = f(query, *keys, *values)
out2 = torch.compile(f)(query, *keys, *values)
self.assertTrue((out - out2).abs().mean() < 1e-2)
@supported_platform
def test_multiple_score_mod_calls_paged_attention(self, device):
query = torch.randn((1, 8, 4, 64), dtype=torch.float32, device=device)
keys = [
torch.randn((1, 8, 1024, 64), dtype=torch.float32, device=device)
for _ in range(2)
]
values = [
torch.randn((1, 8, 1024, 64), dtype=torch.float32, device=device)
for _ in range(2)
]
def scoremod_1(qk, b, h, q, kv):
return qk + (q - kv)
def scoremod_2(qk, b, h, q, kv):
return torch.where(q >= kv, qk, -float("inf"))
block_mask = create_block_mask(noop_mask, 1, 1, 4, 1024, device=device)
def f(q, k1, k2, v1, v2):
q2 = flex_attention(q, k1, v1, score_mod=scoremod_1, block_mask=block_mask)
return flex_attention(
q2, k2, v2, score_mod=scoremod_2, block_mask=block_mask
)
eager_out = f(query, *keys, *values)
(
k_cache1,
v_cache1,
converted_block_mask1,
converted_score_mod1,
) = self.preprocess_paged_attention(
scoremod_1,
query,
keys[0],
values[0],
block_mask,
torch.float32,
device=device,
)
(
k_cache2,
v_cache2,
converted_block_mask2,
converted_score_mod2,
) = self.preprocess_paged_attention(
scoremod_2,
query,
keys[1],
values[1],
block_mask,
torch.float32,
device=device,
)
def paged_f(q, k1, k2, v1, v2):
q2 = flex_attention(
q,
k1,
v1,
score_mod=converted_score_mod1,
block_mask=converted_block_mask1,
)
return flex_attention(
q2,
k2,
v2,
score_mod=converted_score_mod2,
block_mask=converted_block_mask2,
)
compiled_out = torch.compile(paged_f)(
query, k_cache1, k_cache2, v_cache1, v_cache2
)
tolerance = Tolerances(atol=2e-1, rtol=2e-1)
torch.testing.assert_close(
eager_out, compiled_out, atol=tolerance.atol, rtol=tolerance.rtol
)
@supported_platform
def test_multiple_score_mod_calls_paged_attention2(self, device):
query = torch.randn((1, 8, 4, 64), dtype=torch.float32, device=device)
keys = [
torch.randn((1, 8, 1024, 64), dtype=torch.float32, device=device)
for _ in range(3)
]
values = [
torch.randn((1, 8, 1024, 64), dtype=torch.float32, device=device)
for _ in range(3)
]
def scoremod_1(qk, b, h, q, kv):
return qk + (q - kv)
def scoremod_2(qk, b, h, q, kv):
return torch.where(q >= kv, qk, -float("inf"))
block_mask = create_block_mask(noop_mask, 1, 1, 4, 1024, device=device)
attention1 = functools.partial(
flex_attention, score_mod=scoremod_1, block_mask=block_mask
)
def f(q, k1, k2, k3, v1, v2, v3):
q2 = attention1(q, k1, v1)
q3 = flex_attention(q2, k2, v2, score_mod=scoremod_2, block_mask=block_mask)
return flex_attention(
q3, k3, v3, score_mod=scoremod_1, block_mask=block_mask
)
eager_out = f(query, *keys, *values)
(
k_cache1,
v_cache1,
converted_block_mask1,
converted_score_mod1,
) = self.preprocess_paged_attention(
scoremod_1,
query,
keys[0],
values[0],
block_mask,
torch.float32,
device=device,
)
(
k_cache2,
v_cache2,
converted_block_mask2,
converted_score_mod2,
) = self.preprocess_paged_attention(
scoremod_2,
query,
keys[1],
values[1],
block_mask,
torch.float32,
device=device,
)
(
k_cache3,
v_cache3,
converted_block_mask3,
converted_score_mod3,
) = self.preprocess_paged_attention(
scoremod_1,
query,
keys[2],
values[2],
block_mask,
torch.float32,
device=device,
)
paged_attention1 = functools.partial(
flex_attention,
score_mod=converted_score_mod1,
block_mask=converted_block_mask1,
)
def paged_f(q, k1, k2, k3, v1, v2, v3):
q2 = paged_attention1(q, k1, v1)
q3 = flex_attention(
q2,
k2,
v2,
score_mod=converted_score_mod2,
block_mask=converted_block_mask2,
)
return flex_attention(
q3,
k3,
v3,
score_mod=converted_score_mod3,
block_mask=converted_block_mask3,
)
compiled_out = torch.compile(paged_f)(
query, k_cache1, k_cache2, k_cache3, v_cache1, v_cache2, v_cache3
)
tolerance = Tolerances(atol=2e-1, rtol=2e-1)
torch.testing.assert_close(
eager_out, compiled_out, atol=tolerance.atol, rtol=tolerance.rtol
)
@supported_platform
@common_utils.parametrize("dtype", test_dtypes)
def test_njt_causal(self, device, dtype):
offsets = torch.tensor(
[0, 1024, 1024 + 512, S], device=device, dtype=torch.int32
)
seq_idx = torch.zeros(S, device=device, dtype=torch.int32)
for idx in range(len(offsets) - 1):
seq_idx[offsets[idx] : offsets[idx + 1]] = idx
def create_njt_wrapper(orig_score_mod, offsets, seq_idx):
def njt_score_mod(qk, b, h, q, kv):
q_nested = q - offsets[seq_idx[q]]
kv_nested = kv - offsets[seq_idx[kv]]
return orig_score_mod(qk, b, h, q_nested, kv_nested)
return njt_score_mod
causal_njt = create_njt_wrapper(_causal, offsets, seq_idx)
self.run_test(causal_njt, dtype, device=device)
self.run_test_with_paged_attention(causal_njt, dtype, device=device)
@supported_platform
def test_mixed_dtypes_fails(self, device):
query = torch.randn((1, 1, 8, 64), dtype=torch.float32, device=device)
key = torch.randn((1, 1, 1024, 64), dtype=torch.float16, device=device)
value = torch.randn((1, 1, 1024, 64), dtype=torch.float16, device=device)
with self.assertRaisesRegex(
ValueError, "Expected query, key, and value to have the same dtype"
):
flex_attention(query, key, value, _identity)
@supported_platform
@patch.object(torch._inductor.config, "max_autotune", True)
def test_max_autotune(self, device):
def score_mod(score, b, h, m, n):
return score * 2
self.run_test(score_mod, device=device)
self.run_test_with_paged_attention(score_mod, device=device)
self.run_test_with_paged_attention(
score_mod=score_mod,
dtype=torch.bfloat16,
Q_B=4,
Q_H=1,
Q_S=1,
QK_D=16,
KV_B=4,
KV_H=1,
KV_S=64,
V_D=16,
device=device,
)
@supported_platform
@patch.object(torch._inductor.config, "max_autotune", True)
def test_max_autotune_with_captured(self, device):
head_scale = torch.randn(Hq, device=device)
batch_scale = torch.randn(B, device=device)
tok_scale = torch.randn(S, device=device)
q_scale = torch.randn(1, device=device)
def bias_mod(score, batch, head, token_q, token_kv):
score = score + tok_scale[token_kv]
score = score + q_scale[token_q]
score = score + batch_scale[batch]
score = score + head_scale[head]
return score
self.run_test(bias_mod, device=device)
self.run_test_with_paged_attention(bias_mod, device=device)
@supported_platform
def test_fully_masked_out_rows_0_check_gqa(self, device):
# Ensure fully masked out rows won't cause NaNs.
query = torch.randn(
(B, Hq, S, D),
dtype=torch.float32,
device=device,
requires_grad=not self.test_inference_only,
)
key = torch.randn(
(B, Hkv, S, D),
dtype=torch.float32,
device=device,
requires_grad=not self.test_inference_only,
)
value = torch.randn(
(B, Hkv, S, D),
dtype=torch.float32,
device=device,
requires_grad=not self.test_inference_only,
)
M = S // 2
def mask_mod(b, h, q, kv):
return q < M
block_mask = create_block_mask(mask_mod, 1, 1, S, S, device=device)
flex = torch.compile(flex_attention, dynamic=False)
if not self.test_inference_only:
out, lse = flex(
query,
key,
value,
block_mask=block_mask,
enable_gqa=True,
return_lse=True,
)
self.assertTrue((lse[:, :, M:] == -float("inf")).all())
loss = out.sum() + lse.sum()
loss.backward()
self.assertEqual(query.grad[:, :, M:, :].sum(), 0)
else:
out = flex(
query,
key,
value,
block_mask=block_mask,
enable_gqa=True,
return_lse=False,
)
self.assertEqual(out[:, :, M:, :].sum(), 0)
@supported_platform
def test_windowed_no_mask_vs_sdpa(self, device):
score_mod = _generate_windowed(1000)
attention = functools.partial(flex_attention, score_mod=score_mod)
sdpa_mask = _get_windowed_sdpa_mask(8, S, 1000)
sdpa_attention = functools.partial(
torch.nn.functional.scaled_dot_product_attention, attn_mask=sdpa_mask
)
self.run_test_with_call(
attention, sdpa_attention, Q_H=16, KV_H=16, Q_S=8, device=device
)
@supported_platform
def test_windowed_full_mask_vs_sdpa(self, device):
def mask_mod(b, h, q, kv):
return q + 1000 >= kv
score_mod = _generate_windowed(1000)
block_mask = create_block_mask(mask_mod, 1, 1, 8, S, device=device)
attention = functools.partial(
flex_attention, block_mask=block_mask, score_mod=score_mod
)
sdpa_mask = _get_windowed_sdpa_mask(8, S, 1000)
sdpa_attention = functools.partial(
torch.nn.functional.scaled_dot_product_attention, attn_mask=sdpa_mask
)
self.run_test_with_call(
attention, sdpa_attention, Q_H=16, KV_H=16, Q_S=8, device=device
)
@supported_platform
def test_windowed_partial_block_vs_sdpa(self, device):
def mask_mod(b, h, q, kv):
return q + 1000 >= kv
block_mask = create_block_mask(mask_mod, 1, 1, 8, S, device=device)
attention = functools.partial(flex_attention, block_mask=block_mask)
sdpa_mask = _get_windowed_sdpa_mask(8, S, 1000)
sdpa_attention = functools.partial(
torch.nn.functional.scaled_dot_product_attention, attn_mask=sdpa_mask
)
self.run_test_with_call(
attention, sdpa_attention, Q_H=16, KV_H=16, Q_S=8, device=device
)
@supported_platform
def test_windowed_no_mask_vs_sdpa_paged_attention(self, device):
score_mod = _generate_windowed(1000)
sdpa_mask = _get_windowed_sdpa_mask(8, S, 1000)
self.run_test_with_call_paged_attention(
score_mod, None, sdpa_mask, Q_H=16, KV_H=16, Q_S=8, device=device
)
@supported_platform
def test_windowed_full_mask_vs_sdpa_paged_attention(self, device):
def mask_mod(b, h, q, kv):
return q + 1000 >= kv
score_mod = _generate_windowed(1000)
sdpa_mask = _get_windowed_sdpa_mask(8, S, 1000)
self.run_test_with_call_paged_attention(
score_mod, mask_mod, sdpa_mask, Q_H=16, KV_H=16, Q_S=8, device=device
)
@supported_platform
def test_windowed_partial_block_vs_sdpa_paged_attention(self, device):
def mask_mod(b, h, q, kv):
return q + 1000 >= kv
sdpa_mask = _get_windowed_sdpa_mask(8, S, 1000)
self.run_test_with_call_paged_attention(
None, mask_mod, sdpa_mask, Q_H=16, KV_H=16, Q_S=8, device=device
)
@supported_platform
@unittest.skipIf(SKIP_UT_ON_CPU, "Skip on CPU as not supported")
@common_utils.parametrize("dtype", test_dtypes)
@common_utils.parametrize("score_mod", [_identity, _causal])
def test_logsumexp_correctness(self, device, dtype, score_mod):
make_kv = functools.partial(
torch.randn,
(B, Hkv, S, D),
dtype=dtype,
device=device,
requires_grad=True,
)
make_q = functools.partial(
torch.randn,
(B, Hkv, Hq // Hkv, D),
dtype=dtype,
device=device,
requires_grad=True,
)
q, k, v = make_q(), make_kv(), make_kv()
@torch.compile
def sdpa_hop(q, k, v, score_mod):
return flex_attention(q, k, v, score_mod, return_lse=True)
@torch.compile(backend="aot_eager")
def eager_sdpa_hop(q, k, v, score_mod):
return flex_attention(q, k, v, score_mod, return_lse=True)
ref_out, ref_lse = eager_sdpa_hop(
q.to(torch.float64),
k.to(torch.float64),
v.to(torch.float64),
score_mod,
)
compiled_out, compiled_lse = sdpa_hop(q, k, v, score_mod)
self.assertTrue(ref_lse.dtype == torch.float64)
self.assertTrue(compiled_lse.dtype == torch.float32)
tolerance = Tolerances(atol=2e-2, rtol=2e-2)
torch.testing.assert_close(
ref_out.to(dtype=torch.float32),
compiled_out.to(dtype=torch.float32),
atol=tolerance.atol,
rtol=tolerance.rtol,
)
torch.testing.assert_close(
ref_lse.to(dtype=torch.float32),
compiled_lse.to(dtype=torch.float32),
atol=tolerance.atol,
rtol=tolerance.rtol,
)
@supported_platform
@unittest.skipIf(SKIP_UT_ON_CPU, "Skip on CPU as not supported")
def test_not_pw_of_two(self, device):
query = torch.randn(1, 12, 1, 16, device=device)
key = torch.randn(1, 2, 128, 16, device=device)
value = torch.randn(1, 2, 128, 16, device=device)
flex_compiled = torch.compile(flex_attention)
flex_compiled(query, key, value, enable_gqa=True)
@supported_platform
@unittest.skipIf(SKIP_UT_ON_CPU, "Skip on CPU as not supported")
def test_logsumexp_only_return(self, device):
make_q = functools.partial(
torch.randn,
(B, Hkv, Hq // Hkv, D),
dtype=torch.float32,
device=device,
requires_grad=True,
)
make_kv = functools.partial(
torch.randn,
(B, Hkv, S, D),
dtype=torch.float32,
device=device,
requires_grad=True,
)
q, k, v = make_q(), make_kv(), make_kv()
@torch.compile
def func(q, k, v, score_mod):
_, lse = flex_attention(q, k, v, score_mod, return_lse=True)
lse_2 = lse * 2
return lse_2
_, code = run_and_get_code(func, q, k, v, _identity)
# Ensure that we're still generating the flexattention kernel
FileCheck().check_count(".run(primals_1, primals_2, primals_3", 1, True).run(
code[0]
)
@supported_platform
@skip_on_xpu # TODO: SYCL acc issue
def test_non_sparse_mulitple_block_size(self, device):
def generate_causal_offset(offset: torch.Tensor):
def causal_offset_mask(b, h, q_idx, kv_idx):
return (offset + q_idx) >= kv_idx
return causal_offset_mask
def noop(score, b, h, q_idx, kv_idx): # noqa: F841
return score
mod = generate_causal_offset(
torch.tensor(192, device=device, dtype=torch.int32)
)
block_mask = create_block_mask(mod, 1, 1, 1, 65, device=device)
self.run_test(
score_mod=None,
dtype=torch.float32,
block_mask=block_mask,
Q_B=1,
Q_H=1,
Q_S=1,
Q_D=16,
KV_B=1,
KV_H=1,
KV_S=65,
V_D=16,
device=device,
)
self.run_test_with_paged_attention(
score_mod=None,
dtype=torch.float32,
block_mask=block_mask,
Q_B=1,
Q_H=1,
Q_S=1,
QK_D=16,
KV_B=1,
KV_H=1,
KV_S=65,
V_D=16,
device=device,
)
@supported_platform
def test_do_not_trigger_dynamic_shapes_on_empty_block_mask(self, device):
torch._dynamo.reset()
H = Hq
q = torch.randn(B, H, 1, D, device=device)
for i in range(5):
k = torch.randn(B, H, S + i, D, device=device)
v = torch.randn(B, H, S + i, D, device=device)
compiled_flex_attention = torch.compile(flex_attention)
ref = flex_attention(q, k, v)
res = compiled_flex_attention(q, k, v)
tolerance = Tolerances(atol=2e-1, rtol=2e-1)
torch.testing.assert_close(
ref, res, atol=tolerance.atol, rtol=tolerance.rtol
)
# Ensure there is no further recompilation after the second compile (the automatic dynamic-shapes version).
if i == 0:
self.assertEqual(torch._dynamo.utils.counters["frames"]["ok"], 1)
else:
self.assertEqual(torch._dynamo.utils.counters["frames"]["ok"], 2)
@supported_platform
@common_utils.parametrize("dtype", test_dtypes_fast)
def test_larger_block_mask_bug(self, device, dtype):
def mask_mod(b, h, q_idx, kv_idx):
return q_idx >= kv_idx
mask_2 = create_block_mask(
mask_mod=mask_mod,
B=2,
H=None,
Q_LEN=2,
KV_LEN=2,
device=device,
)
# Compile flex attention
flex_attention_compiled = torch.compile(flex_attention, dynamic=False)
# Create input tensors
shape = (2, 1, 2, 16)
q = torch.normal(0.0, 3.0, shape, device=device, dtype=dtype)
k = torch.normal(0.0, 3.0, shape, device=device, dtype=dtype)
v = torch.normal(0.0, 3.0, shape, device=device, dtype=dtype)
eager = flex_attention(q, k, v, block_mask=mask_2)
out = flex_attention_compiled(q, k, v, block_mask=mask_2)
torch.testing.assert_close(eager, out, atol=5e-3, rtol=5e-3)
@common_utils.parametrize("dtype", test_dtypes_fast)
@common_utils.parametrize("score_mod", test_score_mods)
@supported_platform
def test_decode_at_different_input_position(
self, device, dtype: torch.dtype, score_mod: Callable
):
n_pages, page_size, max_batch_size, max_seq_len = 32, 64, 4, 512
n_heads, head_dim = 4, 16
def causal_mask(b, h, q, kv):
return q >= kv
block_mask = create_block_mask(
causal_mask,
max_batch_size,
1,
max_seq_len,
max_seq_len,
device=device,
BLOCK_SIZE=page_size,
)
# init 4 requests with different prefill lengths
prefill_length = [5, 98, 47, 194]
queries, keys, values = [], [], []
for seq_len in prefill_length:
q = torch.randn(
1,
n_heads,
1,
head_dim,
device=device,
dtype=dtype,
requires_grad=False,
)
k = torch.randn(
1,
n_heads,
seq_len,
head_dim,
device=device,
dtype=dtype,
requires_grad=False,
)
v = torch.randn(
1,
n_heads,
seq_len,
head_dim,
device=device,
dtype=dtype,
requires_grad=False,
)
queries.append(q)
keys.append(k)
values.append(v)
# get ground truth output
ref_outs, golden_outs = [], []
for q, k, v in zip(queries, keys, values):
q_ref, k_ref, v_ref = query_key_value_clones(q, k, v)
q_gold, k_gold, v_gold = query_key_value_clones(q, k, v, torch.float64)
slice_block_mask = block_mask._adjust(1, k_ref.shape[2])
slice_block_mask.seq_lengths = (1, k_ref.shape[2])
ref_out = flex_attention(
q_ref, k_ref, v_ref, score_mod, slice_block_mask, enable_gqa=False
)
golden_out = flex_attention(
q_gold, k_gold, v_gold, score_mod, slice_block_mask, enable_gqa=False
)
ref_outs.append(ref_out)
golden_outs.append(golden_out)
ref_outs = torch.cat(ref_outs)
golden_outs = torch.cat(golden_outs)
# init paged attention
paged_cache = PagedAttention(n_pages, page_size, max_batch_size, device=device)
batch_reserve(paged_cache, torch.tensor([100, 200, 50, 300], device=device))
batch_reserve(paged_cache, torch.tensor([100, 512, 300, 300], device=device))
batch_reserve(paged_cache, torch.tensor([512, 512, 300, 300], device=device))
batch_reserve(paged_cache, torch.tensor([512, 512, 512, 300], device=device))
batch_reserve(paged_cache, torch.tensor([512, 512, 512, 512], device=device))
# allocate paged kv cache
MAX_CACHED_SEQ_LEN = n_pages * page_size
k_cache = torch.zeros(
1,
n_heads,
MAX_CACHED_SEQ_LEN,
head_dim,
device=device,
dtype=dtype,
)
v_cache = torch.zeros(
1,
n_heads,
MAX_CACHED_SEQ_LEN,
head_dim,
device=device,
dtype=dtype,
)
# prefill paged kv cache
for i, seq_len in enumerate(prefill_length):
batch_idx = torch.tensor([i], device=device, dtype=torch.int32)
input_pos = torch.arange(seq_len, device=device, dtype=torch.int32).view(
1, seq_len
)
paged_cache.assign(
batch_idx, input_pos, keys[i], values[i], k_cache, v_cache
)
# get paged out and check correctness
batch_idx = torch.arange(max_batch_size, device=device, dtype=torch.int32)
input_pos = torch.tensor(prefill_length, device=device, dtype=torch.int32).view(
max_batch_size, 1
)
kv_len_tensor = torch.full(
(max_batch_size,), max_seq_len, device=device, dtype=torch.int64
)
new_block_mask = paged_cache.convert_logical_block_mask(
block_mask, kv_len=kv_len_tensor
)
new_block_mask.seq_lengths = (1, new_block_mask.seq_lengths[1])
compiled_sdpa = torch.compile(
create_attention(
paged_cache.get_score_mod(score_mod, kv_len=kv_len_tensor),
new_block_mask,
enable_gqa=False,
)
)
paged_out = compiled_sdpa(
torch.cat(queries, 0), k_cache, v_cache, block_mask=new_block_mask
)
with torch.no_grad():
dtype = paged_out.dtype
if dtype == torch.float32:
fudge_factor = 10.0
else:
fudge_factor = 1.1
# Check output
self._check_equal(golden_outs, ref_outs, paged_out, fudge_factor, "Out")
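# Generate device-specific (CUDA/XPU/CPU) variants of the tests above.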
instantiate_device_type_tests(
TestFlexDecoding, globals(), only_for=test_device, allow_xpu=True
)
if __name__ == "__main__":
from torch._inductor.test_case import run_tests
run_tests()