Mirror of https://github.com/pytorch/pytorch.git
synced 2025-10-20 12:54:11 +08:00
We previously only supported the case where the v_head dim and the qk_head dim were the same. When we allowed different head dims, I accidentally kept the query strides for the output. This PR fixes that bug and also ensures that we always produce output in the same stride order as the input query. Pull Request resolved: https://github.com/pytorch/pytorch/pull/135882 Approved by: https://github.com/yanboliang, https://github.com/Chillee
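As a minimal illustration of the property this PR describes (not part of the test file below): with a qk head dim of 64 and a v head dim of 32, the compiled output should be laid out in the same stride order as the (here deliberately non-contiguous) query. This is a sketch only — it assumes a CUDA device with Triton available, and the shapes and the `stride_order` helper are illustrative, not part of the PyTorch API.

```python
import torch
from torch.nn.attention.flex_attention import flex_attention

B, H, S, QK_D, V_D = 2, 4, 256, 64, 32  # qk head dim differs from v head dim

# Build a query whose stride order is not the default contiguous (B, H, S, D) layout.
q = torch.randn(B, S, H, QK_D, device="cuda").transpose(1, 2)
k = torch.randn(B, H, S, QK_D, device="cuda")
v = torch.randn(B, H, S, V_D, device="cuda")

out = torch.compile(flex_attention)(q, k, v)


def stride_order(t):
    # Dimensions sorted from largest to smallest stride.
    return sorted(range(t.dim()), key=lambda i: t.stride(i), reverse=True)


assert out.shape == (B, H, S, V_D)
# After this fix, the output follows the query's stride order rather than a fixed layout.
assert stride_order(out) == stride_order(q)
```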
2493 lines
89 KiB
Python
# Owner(s): ["module: inductor"]
# flake8: noqa: B950

import functools
import string
import unittest
from collections import namedtuple
from contextlib import contextmanager, nullcontext
from typing import Callable, Optional, Tuple
from unittest import expectedFailure, skip, skipUnless
from unittest.mock import patch

import torch
from torch._dynamo.testing import CompileCounterWithBackend, normalize_gm
from torch._inductor import metrics
from torch._inductor.test_case import TestCase as InductorTestCase
from torch._inductor.utils import run_and_get_code
from torch.nn.attention.flex_attention import (
    _create_empty_block_mask,
    _DEFAULT_SPARSE_BLOCK_SIZE,
    _identity,
    _score_mod_signature,
    and_masks,
    BlockMask,
    create_block_mask,
    flex_attention,
    noop_mask,
    or_masks,
)
from torch.testing import FileCheck
from torch.testing._internal import common_utils
from torch.testing._internal.common_cuda import PLATFORM_SUPPORTS_BF16, TEST_MULTIGPU
from torch.testing._internal.common_utils import skipIfRocm, TEST_WITH_ROCM
from torch.utils._triton import has_triton

# Skip tests if Triton is not available
|
|
supported_platform = skipUnless(
|
|
torch.cuda.is_available()
|
|
and torch.version.hip is None
|
|
and has_triton()
|
|
and torch.cuda.get_device_capability() >= (8, 0),
|
|
"Requires CUDA and Triton",
|
|
)
|
|
|
|
# Use this decorator only when hitting Triton bugs on H100
|
|
running_on_a100_only = skipUnless(
|
|
torch.cuda.is_available()
|
|
and torch.version.hip is None
|
|
and has_triton()
|
|
and torch.cuda.get_device_capability() == (8, 0),
|
|
"Requires A100 and Triton",
|
|
)
|
|
|
|
Tolerances = namedtuple("Tolerances", ["atol", "rtol"])
|
|
torch.set_float32_matmul_precision("high")
|
|
|
|
index = torch.ops.aten.index
|
|
Tensor = torch.Tensor
|
|
|
|
|
|
@contextmanager
|
|
def temp_float32_matmul_precision(precision: str):
|
|
"""
|
|
Temporarily set the float32 matmul precision and restore it after the context is exited.
|
|
|
|
Args:
|
|
precision (str): The precision to set ('highest', 'high', or 'medium').
|
|
"""
|
|
original_precision = torch.get_float32_matmul_precision()
|
|
try:
|
|
torch.set_float32_matmul_precision(precision)
|
|
yield
|
|
finally:
|
|
torch.set_float32_matmul_precision(original_precision)
|
|
|
|
|
|
def rmse(ref, res):
|
|
"""
|
|
Calculate root mean squared error
|
|
"""
|
|
return torch.sqrt(torch.mean(torch.square(ref - res)))
|
|
|
|
|
|
def create_attention(score_mod, block_mask, enable_gqa=False):
|
|
return functools.partial(
|
|
flex_attention,
|
|
score_mod=score_mod,
|
|
block_mask=block_mask,
|
|
enable_gqa=enable_gqa,
|
|
)
|
|
|
|
|
|
def create_block_mask_test(score_mod, query, key):
|
|
block_mask = create_block_mask(
|
|
score_mod,
|
|
1,
|
|
1,
|
|
query.shape[-2],
|
|
key.shape[-2],
|
|
query.device,
|
|
)
|
|
return block_mask
|
|
|
|
|
|
test_dtypes = (
|
|
[torch.float16, torch.bfloat16, torch.float32]
|
|
if PLATFORM_SUPPORTS_BF16
|
|
else [torch.float16, torch.float32]
|
|
)
|
|
|
|
test_dtypes_fast = [torch.float16]
|
|
|
|
|
|
# --------- Useful score mod functions for testing ---------
|
|
def _causal(
|
|
score: Tensor,
|
|
batch: Tensor,
|
|
head: Tensor,
|
|
token_q: Tensor,
|
|
token_kv: Tensor,
|
|
) -> Tensor:
|
|
return torch.where(token_q >= token_kv, score, float("-inf"))
|
|
|
|
|
|
def _rel_bias(
|
|
score: Tensor,
|
|
batch: Tensor,
|
|
head: Tensor,
|
|
token_q: Tensor,
|
|
token_kv: Tensor,
|
|
) -> Tensor:
|
|
return score + (token_q - token_kv)
|
|
|
|
|
|
def _rel_causal(
|
|
score: Tensor,
|
|
batch: Tensor,
|
|
head: Tensor,
|
|
token_q: Tensor,
|
|
token_kv: Tensor,
|
|
) -> Tensor:
|
|
return torch.where(token_q >= token_kv, score + (token_q - token_kv), float("-inf"))
|
|
|
|
|
|
def _generate_alibi_bias(num_heads: int):
|
|
def _alibi_bias(
|
|
score: Tensor,
|
|
batch: Tensor,
|
|
head: Tensor,
|
|
token_q: Tensor,
|
|
token_kv: Tensor,
|
|
) -> Tensor:
|
|
scale = torch.exp2(-((head + 1) * 8.0 / num_heads))
|
|
return score + (token_kv - token_q) * scale
|
|
|
|
return _alibi_bias
|
|
|
|
|
|
def _inverse_causal(score, b, h, m, n):
|
|
return torch.where(m <= n, score, float("-inf"))
|
|
|
|
|
|
def _times_two(score, b, h, m, n):
|
|
"""Joint graph needed for correctness"""
|
|
return score * 2
|
|
|
|
|
|
def _squared(score, b, h, m, n):
|
|
"""Joint graph needed for correctness"""
|
|
return score * score
|
|
|
|
|
|
def _head_offset(dtype: torch.dtype):
|
|
"""Captured Buffer"""
|
|
head_offset = torch.rand(H, device="cuda", dtype=dtype)
|
|
|
|
def score_mod(score, b, h, m, n):
|
|
return score * head_offset[h]
|
|
|
|
return score_mod
|
|
|
|
|
|
def _trig(score, b, h, m, n):
|
|
"""Joint graph needed for correctness"""
|
|
return torch.sin(torch.cos(score)) + torch.tan(b)
|
|
|
|
|
|
def _trig2(score, b, h, m, n):
|
|
"""Branching joint graph"""
|
|
cos_score = torch.cos(score)
|
|
sin_score = torch.sin(score)
|
|
z = cos_score * sin_score + torch.tan(b)
|
|
return z
|
|
|
|
|
|
test_score_mods = [
|
|
_identity,
|
|
_times_two,
|
|
_squared,
|
|
_causal,
|
|
_inverse_causal,
|
|
_rel_bias,
|
|
_rel_causal,
|
|
_generate_alibi_bias(8),
|
|
]
|
|
|
|
captured_buffers_map = {
|
|
"_head_offset": _head_offset,
|
|
}
|
|
|
|
B = 4
|
|
H = 8
|
|
S = 2048
|
|
D = 64
|
|
|
|
test_Hq_Hkv = [
|
|
(4, 2),
|
|
(4, 1),
|
|
]
|
|
|
|
test_Bq_Bkv = [
|
|
(3, 1),
|
|
(4, 1),
|
|
(5, 1),
|
|
]
|
|
|
|
|
|
def query_key_value_clones(
|
|
query: torch.Tensor,
|
|
key: torch.Tensor,
|
|
value: torch.Tensor,
|
|
dtype: torch.dtype = None,
|
|
):
|
|
"""Clones the query, key, and value tensors and moves them to the specified dtype."""
|
|
if dtype is None:
|
|
dtype = query.dtype
|
|
query_ref = query.clone().detach().to(dtype).requires_grad_(query.requires_grad)
|
|
key_ref = key.clone().detach().to(dtype).requires_grad_(key.requires_grad)
|
|
value_ref = value.clone().detach().to(dtype).requires_grad_(value.requires_grad)
|
|
return query_ref, key_ref, value_ref
|
|
|
|
|
|
class TestFlexAttention(InductorTestCase):
|
|
def _check_equal(
|
|
self,
|
|
golden_out: torch.Tensor,
|
|
ref_out: torch.Tensor,
|
|
compiled_out: torch.Tensor,
|
|
fudge_factor: float,
|
|
tensor_name: Optional[str] = None,
|
|
):
|
|
compiled_error = (golden_out - compiled_out).abs().mean()
|
|
ref_error = (golden_out - ref_out).abs().mean()
|
|
if torch.isnan(compiled_error).any() or torch.isnan(ref_error).any():
|
|
self.assertTrue(False, "Output/Grad with NaN")
|
|
if compiled_error > ref_error * fudge_factor:
|
|
name = tensor_name if tensor_name is not None else ""
|
|
msg = f"{name} Compiled error {compiled_error} is greater than ref error {ref_error} by more than {fudge_factor}X."
|
|
self.assertTrue(False, msg)
|
|
|
|
def _check_out_and_grad(
|
|
self,
|
|
golden_out: torch.Tensor,
|
|
ref_out: torch.Tensor,
|
|
compiled_out: torch.Tensor,
|
|
q_gold: torch.Tensor,
|
|
q_ref: torch.Tensor,
|
|
q: torch.Tensor,
|
|
k_gold: torch.Tensor,
|
|
k_ref: torch.Tensor,
|
|
k: torch.Tensor,
|
|
v_gold: torch.Tensor,
|
|
v_ref: torch.Tensor,
|
|
v: torch.Tensor,
|
|
):
|
|
dtype = ref_out.dtype
|
|
with torch.no_grad():
|
|
# Note, it seems like we really are less accurate than the float32
|
|
# computation, likely due to the online softmax
|
|
if dtype == torch.float32:
|
|
fudge_factor = 10.0
|
|
else:
|
|
fudge_factor = 1.1
|
|
|
|
# Checkout output
|
|
self._check_equal(golden_out, ref_out, compiled_out, fudge_factor, "Out")
|
|
|
|
# Check gradients
|
|
q_fudge_factor = 1.0 * fudge_factor
|
|
self._check_equal(
|
|
q_gold.grad, q_ref.grad, q.grad, q_fudge_factor, "Grad_Query"
|
|
)
|
|
k_fudge_factor = 1.0 * fudge_factor
|
|
self._check_equal(
|
|
k_gold.grad, k_ref.grad, k.grad, k_fudge_factor, "Grad_Key"
|
|
)
|
|
v_fudge_factor = 1.0 * fudge_factor
|
|
self._check_equal(
|
|
v_gold.grad, v_ref.grad, v.grad, v_fudge_factor, "Grad_Value"
|
|
)
|
|
|
|
def run_test(
|
|
self,
|
|
score_mod: _score_mod_signature,
|
|
dtype: torch.dtype = torch.float16,
|
|
Q_B: int = B,
|
|
Q_H: int = H,
|
|
Q_S: int = S,
|
|
Q_D: int = D,
|
|
KV_B: int = B,
|
|
KV_H: int = H,
|
|
KV_S: int = S,
|
|
V_D: int = D,
|
|
block_mask: Optional[BlockMask] = None,
|
|
):
|
|
if TEST_WITH_ROCM and Q_H != KV_H:
|
|
self.skipTest("enable_gqa=True is unsupported on ROCM, for now")
|
|
q = torch.randn(
|
|
(Q_B, Q_H, Q_S, Q_D), dtype=dtype, device="cuda", requires_grad=True
|
|
)
|
|
k = torch.randn(
|
|
(KV_B, KV_H, KV_S, Q_D), dtype=dtype, device="cuda", requires_grad=True
|
|
)
|
|
v = torch.randn(
|
|
(KV_B, KV_H, KV_S, V_D), dtype=dtype, device="cuda", requires_grad=True
|
|
)
|
|
q_ref, k_ref, v_ref = query_key_value_clones(q, k, v)
|
|
q_gold, k_gold, v_gold = query_key_value_clones(q, k, v, torch.float64)
|
|
sdpa_partial = create_attention(
|
|
score_mod, block_mask, enable_gqa=(not Q_H == KV_H)
|
|
)
|
|
compiled_sdpa = torch.compile(sdpa_partial)
|
|
golden_out = sdpa_partial(q_gold, k_gold, v_gold)
|
|
ref_out = sdpa_partial(q_ref, k_ref, v_ref)
|
|
compiled_out = compiled_sdpa(q, k, v)
|
|
|
|
backward_grad = torch.randn((Q_B, Q_H, Q_S, V_D), dtype=dtype, device="cuda")
|
|
|
|
golden_out.backward(backward_grad.to(torch.float64))
|
|
ref_out.backward(backward_grad)
|
|
compiled_out.backward(backward_grad)
|
|
|
|
self._check_out_and_grad(
|
|
golden_out,
|
|
ref_out,
|
|
compiled_out,
|
|
q_gold,
|
|
q_ref,
|
|
q,
|
|
k_gold,
|
|
k_ref,
|
|
k,
|
|
v_gold,
|
|
v_ref,
|
|
v,
|
|
)
|
|
|
|
def run_test_with_call(
|
|
self,
|
|
sdpa_call: Callable,
|
|
dtype: torch.dtype = torch.float16,
|
|
Q_B: int = B,
|
|
Q_H: int = H,
|
|
Q_S: int = S,
|
|
Q_D: int = D,
|
|
KV_B: int = B,
|
|
KV_H: int = H,
|
|
KV_S: int = S,
|
|
V_D: int = D,
|
|
):
|
|
q = torch.randn(
|
|
(Q_B, Q_H, Q_S, Q_D), dtype=dtype, device="cuda", requires_grad=True
|
|
)
|
|
k = torch.randn(
|
|
(KV_B, KV_H, KV_S, Q_D), dtype=dtype, device="cuda", requires_grad=True
|
|
)
|
|
v = torch.randn(
|
|
(KV_B, KV_H, KV_S, V_D), dtype=dtype, device="cuda", requires_grad=True
|
|
)
|
|
q_ref, k_ref, v_ref = query_key_value_clones(q, k, v)
|
|
q_gold, k_gold, v_gold = query_key_value_clones(q, k, v, torch.float64)
|
|
compiled_sdpa = torch.compile(sdpa_call)
|
|
golden_out = sdpa_call(q_gold, k_gold, v_gold)
|
|
ref_out = sdpa_call(q_ref, k_ref, v_ref)
|
|
compiled_out = compiled_sdpa(q, k, v)
|
|
|
|
backward_grad = torch.randn((Q_B, Q_H, Q_S, V_D), dtype=dtype, device="cuda")
|
|
|
|
golden_out.backward(backward_grad.to(torch.float64))
|
|
ref_out.backward(backward_grad)
|
|
compiled_out.backward(backward_grad)
|
|
|
|
self._check_out_and_grad(
|
|
golden_out,
|
|
ref_out,
|
|
compiled_out,
|
|
q_gold,
|
|
q_ref,
|
|
q,
|
|
k_gold,
|
|
k_ref,
|
|
k,
|
|
v_gold,
|
|
v_ref,
|
|
v,
|
|
)
|
|
|
|
def run_dynamic_test(
|
|
self,
|
|
score_mod: Callable,
|
|
dtype: torch.dtype = torch.float16,
|
|
B: int = B,
|
|
H: int = H,
|
|
S: int = S,
|
|
D: int = D,
|
|
):
|
|
# If the seqlen becomes smaller than the seqlen of the previous batch,
|
|
# we can still reuse the block_mask created from a larger seqlen.
|
|
MAX_S = S
|
|
block_mask = create_block_mask(noop_mask, 1, 1, MAX_S, MAX_S)
|
|
sdpa_partial = create_attention(score_mod, block_mask=block_mask)
|
|
# The first eager batch, shape (B, H, S, D)
|
|
q1 = torch.randn((B, H, S, D), dtype=dtype, device="cuda", requires_grad=True)
|
|
k1 = torch.randn((B, H, S, D), dtype=dtype, device="cuda", requires_grad=True)
|
|
v1 = torch.randn((B, H, S, D), dtype=dtype, device="cuda", requires_grad=True)
|
|
q1_ref, k1_ref, v1_ref = query_key_value_clones(q1, k1, v1)
|
|
q1_gold, k1_gold, v1_gold = query_key_value_clones(q1, k1, v1, torch.float64)
|
|
ref_out1 = sdpa_partial(q1_ref, k1_ref, v1_ref)
|
|
golden_out1 = sdpa_partial(q1_gold, k1_gold, v1_gold)
|
|
|
|
backward_grad1 = torch.randn((B, H, S, D), dtype=dtype, device="cuda")
|
|
|
|
golden_out1.backward(backward_grad1.to(torch.float64))
|
|
ref_out1.backward(backward_grad1)
|
|
|
|
# The second eager batch, shape (B * 2, H, S / 2, D)
|
|
B = int(B * 2)
|
|
S = int(S / 2)
|
|
q2 = torch.randn((B, H, S, D), dtype=dtype, device="cuda", requires_grad=True)
|
|
k2 = torch.randn((B, H, S, D), dtype=dtype, device="cuda", requires_grad=True)
|
|
v2 = torch.randn((B, H, S, D), dtype=dtype, device="cuda", requires_grad=True)
|
|
q2_ref, k2_ref, v2_ref = query_key_value_clones(q2, k2, v2)
|
|
q2_gold, k2_gold, v2_gold = query_key_value_clones(q2, k2, v2, torch.float64)
|
|
ref_out2 = sdpa_partial(q2_ref, k2_ref, v2_ref)
|
|
golden_out2 = sdpa_partial(q2_gold, k2_gold, v2_gold)
|
|
|
|
backward_grad2 = torch.randn((B, H, S, D), dtype=dtype, device="cuda")
|
|
|
|
golden_out2.backward(backward_grad2.to(torch.float64))
|
|
ref_out2.backward(backward_grad2)
|
|
|
|
# Need to clear dynamo counters, since flex attention eager mode also uses dynamo tracing.
|
|
# We check dynamo counters["frames"]["ok"] to ensure there is no re-compilation.
|
|
torch._dynamo.reset()
|
|
# Compiling with dynamic shape in the first batch.
|
|
compiled_sdpa = torch.compile(sdpa_partial, dynamic=True)
|
|
compiled_out1 = compiled_sdpa(q1, k1, v1)
|
|
compiled_out1.backward(backward_grad1)
|
|
|
|
self._check_out_and_grad(
|
|
golden_out1,
|
|
ref_out1,
|
|
compiled_out1,
|
|
q1_gold,
|
|
q1_ref,
|
|
q1,
|
|
k1_gold,
|
|
k1_ref,
|
|
k1,
|
|
v1_gold,
|
|
v1_ref,
|
|
v1,
|
|
)
|
|
self.assertEqual(torch._dynamo.utils.counters["frames"]["ok"], 1)
|
|
|
|
# No re-compilation, use the compiled dynamic shape version.
|
|
compiled_out2 = compiled_sdpa(q2, k2, v2)
|
|
compiled_out2.backward(backward_grad2)
|
|
self._check_out_and_grad(
|
|
golden_out2,
|
|
ref_out2,
|
|
compiled_out2,
|
|
q2_gold,
|
|
q2_ref,
|
|
q2,
|
|
k2_gold,
|
|
k2_ref,
|
|
k2,
|
|
v2_gold,
|
|
v2_ref,
|
|
v2,
|
|
)
|
|
self.assertEqual(torch._dynamo.utils.counters["frames"]["ok"], 1)
|
|
|
|
# The third iteration, shape (B * 2, H, S * 2, D)
|
|
        # Since the seqlen is larger than the seqlen in block_mask, this should raise an error.
|
|
S = int(S * 4)
|
|
q3 = torch.randn((B, H, S, D), dtype=dtype, device="cuda", requires_grad=True)
|
|
k3 = torch.randn((B, H, S, D), dtype=dtype, device="cuda", requires_grad=True)
|
|
v3 = torch.randn((B, H, S, D), dtype=dtype, device="cuda", requires_grad=True)
|
|
with self.assertRaisesRegex(
|
|
torch._dynamo.exc.BackendCompilerFailed, "Q seqlen must be smaller than"
|
|
):
|
|
compiled_sdpa(q3, k3, v3)
|
|
|
|
def run_automatic_dynamic_test(
|
|
self,
|
|
score_mod: Callable,
|
|
dtype: torch.dtype = torch.float16,
|
|
B: int = B,
|
|
H: int = H,
|
|
S: int = S,
|
|
D: int = D,
|
|
):
|
|
MAX_S = S
|
|
block_mask = create_block_mask(noop_mask, 1, 1, MAX_S, MAX_S)
|
|
sdpa_partial = create_attention(score_mod, block_mask=block_mask)
|
|
# The first eager batch, shape (B, H, S, D)
|
|
q1 = torch.randn((B, H, S, D), dtype=dtype, device="cuda")
|
|
k1 = torch.randn((B, H, S, D), dtype=dtype, device="cuda")
|
|
v1 = torch.randn((B, H, S, D), dtype=dtype, device="cuda")
|
|
golden_out1 = sdpa_partial(
|
|
q1.to(torch.float64), k1.to(torch.float64), v1.to(torch.float64)
|
|
)
|
|
ref_out1 = sdpa_partial(q1, k1, v1)
|
|
|
|
# The second eager batch, shape (B * 2, H, S / 2, D)
|
|
B = int(B * 2)
|
|
S = int(S / 2)
|
|
q2 = torch.randn((B, H, S, D), dtype=dtype, device="cuda")
|
|
k2 = torch.randn((B, H, S, D), dtype=dtype, device="cuda")
|
|
v2 = torch.randn((B, H, S, D), dtype=dtype, device="cuda")
|
|
golden_out2 = sdpa_partial(
|
|
q2.to(torch.float64), k2.to(torch.float64), v2.to(torch.float64)
|
|
)
|
|
ref_out2 = sdpa_partial(q2, k2, v2)
|
|
|
|
# The third eager batch, shape (B * 4, H, S / 4, D)
|
|
B = int(B * 2)
|
|
S = int(S / 2)
|
|
q3 = torch.randn((B, H, S, D), dtype=dtype, device="cuda")
|
|
k3 = torch.randn((B, H, S, D), dtype=dtype, device="cuda")
|
|
v3 = torch.randn((B, H, S, D), dtype=dtype, device="cuda")
|
|
golden_out3 = sdpa_partial(
|
|
q3.to(torch.float64), k3.to(torch.float64), v3.to(torch.float64)
|
|
)
|
|
ref_out3 = sdpa_partial(q3, k3, v3)
|
|
|
|
# Need to clear dynamo counters, since flex attention eager mode also uses dynamo tracing.
|
|
# We check dynamo counters["frames"]["ok"] to ensure:
|
|
# 1, the first batch is compiled with static shape
|
|
# 2, the second batch is compiled with dynamic shape
|
|
# 3, no re-compilation in the third batch
|
|
torch._dynamo.reset()
|
|
|
|
# Note, it seems like we really are less accurate than the float32
|
|
# computation, likely due to the online softmax
|
|
if dtype == torch.float32:
|
|
fudge_factor = 10.0
|
|
else:
|
|
fudge_factor = 1.1
|
|
|
|
# The first batch.
|
|
compiled_sdpa = torch.compile(sdpa_partial)
|
|
compiled_out1 = compiled_sdpa(q1, k1, v1)
|
|
self._check_equal(golden_out1, ref_out1, compiled_out1, fudge_factor)
|
|
self.assertEqual(torch._dynamo.utils.counters["frames"]["ok"], 1)
|
|
|
|
# The second batch (automatic dynamic).
|
|
compiled_out2 = compiled_sdpa(q2, k2, v2)
|
|
self._check_equal(golden_out2, ref_out2, compiled_out2, fudge_factor)
|
|
self.assertEqual(torch._dynamo.utils.counters["frames"]["ok"], 2)
|
|
|
|
# The third batch (no re-compilation).
|
|
compiled_out3 = compiled_sdpa(q3, k3, v3)
|
|
self._check_equal(golden_out3, ref_out3, compiled_out3, fudge_factor)
|
|
self.assertEqual(torch._dynamo.utils.counters["frames"]["ok"], 2)
|
|
|
|
@supported_platform
|
|
@common_utils.parametrize("dtype", test_dtypes)
|
|
@common_utils.parametrize("score_mod", test_score_mods)
|
|
def test_builtin_score_mods(self, dtype: torch.dtype, score_mod: Callable):
|
|
self.run_test(score_mod, dtype)
|
|
|
|
@running_on_a100_only
|
|
@common_utils.parametrize("dtype", test_dtypes_fast)
|
|
@common_utils.parametrize("score_mod", test_score_mods)
|
|
def test_builtin_score_mods_seqlen_lt_default_sparse_block_size(
|
|
self, dtype: torch.dtype, score_mod: Callable
|
|
):
|
|
# _DEFAULT_SPARSE_BLOCK_SIZE is 128
|
|
attention = functools.partial(
|
|
flex_attention,
|
|
score_mod=score_mod,
|
|
kernel_options={"FORCE_USE_FLEX_ATTENTION": True},
|
|
)
|
|
self.run_test_with_call(attention, dtype, B, H, 64, D, B, H, 64, D)
|
|
|
|
@running_on_a100_only
|
|
@common_utils.parametrize("dtype", test_dtypes_fast)
|
|
@common_utils.parametrize("score_mod", test_score_mods)
|
|
def test_builtin_score_mods_seqlen_lt_custom_sparse_block_size(
|
|
self, dtype: torch.dtype, score_mod: Callable
|
|
):
|
|
def causal_mask(b, h, q, kv):
|
|
return q >= kv
|
|
|
|
block_mask = create_block_mask(causal_mask, 1, 1, 64, 64, BLOCK_SIZE=256)
|
|
attention = functools.partial(
|
|
flex_attention,
|
|
score_mod=score_mod,
|
|
block_mask=block_mask,
|
|
kernel_options={"FORCE_USE_FLEX_ATTENTION": True},
|
|
)
|
|
self.run_test_with_call(attention, dtype, B, H, 64, D, B, H, 64, D)
|
|
|
|
@supported_platform
|
|
@common_utils.parametrize("dtype", test_dtypes_fast)
|
|
@common_utils.parametrize("score_mod", test_score_mods)
|
|
def test_builtin_score_mods_dynamic(self, dtype: torch.dtype, score_mod: Callable):
|
|
self.run_dynamic_test(score_mod, dtype)
|
|
|
|
@supported_platform
|
|
@common_utils.parametrize("dtype", test_dtypes_fast)
|
|
@common_utils.parametrize("score_mod", test_score_mods)
|
|
def test_builtin_score_mods_automatic_dynamic(
|
|
self, dtype: torch.dtype, score_mod: Callable
|
|
):
|
|
self.run_automatic_dynamic_test(score_mod, dtype)
|
|
|
|
@supported_platform
|
|
@common_utils.parametrize("dtype", test_dtypes_fast)
|
|
@common_utils.parametrize("score_mod", test_score_mods)
|
|
def test_builtin_score_mods_different_seqlen(
|
|
self, dtype: torch.dtype, score_mod: Callable
|
|
):
|
|
self.run_test(
|
|
score_mod,
|
|
dtype,
|
|
B,
|
|
H,
|
|
S // 2, # Seqlen of Q is different from seqlen of K/V
|
|
D,
|
|
B,
|
|
H,
|
|
S,
|
|
D,
|
|
)
|
|
|
|
@supported_platform
|
|
@common_utils.parametrize("dtype", test_dtypes_fast)
|
|
@common_utils.parametrize("batch_dims", test_Bq_Bkv)
|
|
@common_utils.parametrize("head_dims", test_Hq_Hkv)
|
|
@common_utils.parametrize("score_mod", test_score_mods)
|
|
def test_kv_batch_broadcast(
|
|
self,
|
|
dtype: torch.dtype,
|
|
batch_dims: Tuple[int, int],
|
|
head_dims: Tuple[int, int],
|
|
score_mod: Callable,
|
|
):
|
|
Hq, Hkv = head_dims
|
|
assert Hq % Hkv == 0
|
|
|
|
Bq, Bkv = batch_dims
|
|
assert Bq > 1 and Bkv == 1
|
|
|
|
self.run_test(
|
|
score_mod,
|
|
dtype,
|
|
Bq,
|
|
Hq,
|
|
S,
|
|
D,
|
|
Bkv,
|
|
Hkv,
|
|
S,
|
|
D,
|
|
)
|
|
|
|
@supported_platform
|
|
@common_utils.parametrize("dtype", test_dtypes_fast)
|
|
@common_utils.parametrize("batch_dims", test_Bq_Bkv)
|
|
@common_utils.parametrize("head_dims", test_Hq_Hkv)
|
|
@common_utils.parametrize("score_mod", test_score_mods)
|
|
def test_kv_batch_broadcast_causal_mask(
|
|
self,
|
|
dtype: torch.dtype,
|
|
batch_dims: Tuple[int, int],
|
|
head_dims: Tuple[int, int],
|
|
score_mod: Callable,
|
|
):
|
|
Hq, Hkv = head_dims
|
|
assert Hq % Hkv == 0
|
|
|
|
Bq, Bkv = batch_dims
|
|
assert Bq > 1 and Bkv == 1
|
|
|
|
def mask_mod(b, h, q, kv):
|
|
return q >= kv
|
|
|
|
block_mask = create_block_mask(mask_mod, 1, 1, S, S)
|
|
attention = functools.partial(
|
|
flex_attention, block_mask=block_mask, enable_gqa=(not Hq == Hkv)
|
|
)
|
|
|
|
self.run_test_with_call(
|
|
attention,
|
|
torch.float16,
|
|
Bq,
|
|
Hq,
|
|
S,
|
|
D,
|
|
Bkv,
|
|
Hkv,
|
|
S,
|
|
D,
|
|
)
|
|
|
|
@supported_platform
|
|
@common_utils.parametrize("dtype", test_dtypes_fast)
|
|
@common_utils.parametrize("score_mod", test_score_mods)
|
|
def test_GQA(self, dtype: torch.dtype, score_mod: Callable):
|
|
self.run_test(
|
|
score_mod,
|
|
dtype,
|
|
B,
|
|
H * 4, # Hq = 4*Hkv.
|
|
S // 8,
|
|
D,
|
|
B,
|
|
H,
|
|
S,
|
|
D,
|
|
)
|
|
|
|
test_strides = [
|
|
((H * S * D, S * D, D, 1), 997), # offset
|
|
((H * D, D, B * H * D, 1), 499), # transposed dimensions
|
|
((H * S * D, D, H * D, 1), 0), # heads/sequence transposed
|
|
(
|
|
(S * (D + 1), B * S * (D + 1), (D + 1), 1),
|
|
293,
|
|
), # additional buffer on one dim
|
|
(
|
|
(1, D, (B + 1) * (H + 1) * D, 1),
|
|
97,
|
|
), # additional buffer on multiple dim + shared dimension
|
|
]
|
|
|
|
@supported_platform
|
|
@common_utils.parametrize("dtype", test_dtypes_fast)
|
|
@common_utils.parametrize(
|
|
"q_s", test_strides[:2]
|
|
    )  # TODO: fix layout for query broadcasting
|
|
@common_utils.parametrize(
|
|
"k_s,v_s",
|
|
[
|
|
(test_strides[0], test_strides[0]),
|
|
(test_strides[0], test_strides[1]),
|
|
(test_strides[2], test_strides[3]),
|
|
(test_strides[3], test_strides[1]),
|
|
# (test_strides[2], test_strides[4]), # TODO: Doesn't work for
|
|
            # broadcasting reasons, I think
|
|
],
|
|
)
|
|
@common_utils.parametrize("do_s", test_strides[:3])
|
|
def test_strided_inputs(self, dtype: torch.dtype, q_s, k_s, v_s, do_s):
|
|
q1 = torch.randn((B * H * S * D * 2), dtype=dtype, device="cuda")
|
|
k1 = torch.randn((B * H * S * D * 2), dtype=dtype, device="cuda")
|
|
v1 = torch.randn((B * H * S * D * 2), dtype=dtype, device="cuda")
|
|
do1 = torch.randn((B * H * S * D * 2), dtype=dtype, device="cuda")
|
|
|
|
q_shape = (B, H, S // 2, D)
|
|
k_shape = (B, H, S, D)
|
|
v_shape = (B, H, S, D)
|
|
do_shape = (B, H, S // 2, D)
|
|
|
|
def coerce_to_strides(val, shape, strides):
|
|
strides, offset = strides
|
|
val_max = [x * (y - 1) for x, y in zip(strides, shape)]
|
|
assert sum(val_max) + offset < B * H * S * D * 2
|
|
assert strides[-1] == 1
|
|
return torch.as_strided(val, shape, strides, offset).requires_grad_(True)
|
|
|
|
q = coerce_to_strides(q1, q_shape, q_s)
|
|
k = coerce_to_strides(k1, k_shape, k_s)
|
|
v = coerce_to_strides(v1, v_shape, v_s)
|
|
do = coerce_to_strides(do1, do_shape, do_s)
|
|
|
|
block_mask = _create_empty_block_mask(q, k)
|
|
sdpa_partial = create_attention(
|
|
score_mod=_generate_alibi_bias(8), block_mask=block_mask
|
|
)
|
|
compiled_sdpa = torch.compile(sdpa_partial)
|
|
ref_out = sdpa_partial(q, k, v)
|
|
compiled_out = compiled_sdpa(q, k, v)
|
|
|
|
tolerance = Tolerances(atol=2e-1, rtol=2e-1)
|
|
torch.testing.assert_close(
|
|
ref_out, compiled_out, atol=tolerance.atol, rtol=tolerance.rtol
|
|
)
|
|
ref_out.backward(do)
|
|
ref_grads = [q.grad, k.grad, v.grad]
|
|
q.grad = None
|
|
k.grad = None
|
|
v.grad = None
|
|
|
|
compiled_out.backward(do)
|
|
compiled_grads = [q.grad, k.grad, v.grad]
|
|
q.grad = None
|
|
k.grad = None
|
|
v.grad = None
|
|
torch.testing.assert_close(
|
|
compiled_grads[0], ref_grads[0], atol=tolerance.atol, rtol=tolerance.rtol
|
|
)
|
|
torch.testing.assert_close(
|
|
compiled_grads[1], ref_grads[1], atol=tolerance.atol, rtol=tolerance.rtol
|
|
)
|
|
torch.testing.assert_close(
|
|
compiled_grads[2], ref_grads[2], atol=tolerance.atol, rtol=tolerance.rtol
|
|
)
|
|
|
|
@supported_platform
|
|
def test_doc_mask_sparse(self):
|
|
document_id = torch.zeros(S, dtype=torch.int, device="cuda")
|
|
for i in range(0, S, 256):
|
|
document_id[i : i + 256] = i // 256
|
|
|
|
def document_masking_causal(score, b, h, q_idx, kv_idx):
|
|
causal_mask = q_idx >= kv_idx
|
|
document_mask = document_id[q_idx] == document_id[kv_idx]
|
|
return torch.where(causal_mask & document_mask, score, -float("inf"))
|
|
|
|
self.run_test(document_masking_causal, torch.float16)
|
|
|
|
@supported_platform
|
|
def test_index_multiple(self):
|
|
bias = torch.randn(B, S, device="cuda")
|
|
|
|
def index_multiple(score, b, h, q_idx, kv_idx):
|
|
return score + bias[b][q_idx]
|
|
|
|
self.run_test(index_multiple, torch.float16)
|
|
|
|
@supported_platform
|
|
def test_index_weird1(self):
|
|
bias = torch.randn(4, B, H, S, device="cuda")
|
|
|
|
def index_weird1(score, b, h, q_idx, kv_idx):
|
|
return score + bias[0][b, h][q_idx]
|
|
|
|
self.run_test(index_weird1, torch.float16)
|
|
|
|
@supported_platform
|
|
def test_index_weird2(self):
|
|
bias = torch.randn(B, H, 4, S, device="cuda")
|
|
which_bias = torch.tensor(0, device="cuda")
|
|
|
|
def index_weird2(score, b, h, q_idx, kv_idx):
|
|
return score + bias[b][h][which_bias, q_idx]
|
|
|
|
self.run_test(index_weird2, torch.float16)
|
|
|
|
@supported_platform
|
|
@common_utils.parametrize("dtype", test_dtypes)
|
|
def test_skip_odd_keys(self, dtype: torch.dtype):
|
|
def score_mod(score, b, h, q, kv):
|
|
return torch.where(kv % 2 == 0, score, float("-inf"))
|
|
|
|
self.run_test(score_mod, dtype)
|
|
|
|
@supported_platform
|
|
@common_utils.parametrize("dtype", test_dtypes)
|
|
def test_function_composition(self, dtype: torch.dtype):
|
|
def score_mod_1(score, b, h, m, n):
|
|
return score + (m - n)
|
|
|
|
def score_mod_2(score, b, h, m, n):
|
|
return torch.where(m <= n, score, float("-inf"))
|
|
|
|
def composed_score_mod(score, b, h, m, n):
|
|
return score_mod_2(score_mod_1(score, b, h, m, n), b, h, m, n)
|
|
|
|
self.run_test(composed_score_mod, dtype)
|
|
|
|
@supported_platform
|
|
@common_utils.parametrize("dtype", test_dtypes)
|
|
def test_captured_buffers(self, dtype: torch.dtype):
|
|
head_offset = torch.rand(H, device="cuda", dtype=dtype)
|
|
|
|
def score_mod(score, b, h, m, n):
|
|
return score + head_offset[h]
|
|
|
|
self.run_test(score_mod, dtype)
|
|
|
|
@supported_platform
|
|
@common_utils.parametrize("dtype", test_dtypes)
|
|
def test_captured_buffers_all_dims(self, dtype: torch.dtype):
|
|
head_scale = torch.randn(H, device="cuda")
|
|
batch_scale = torch.randn(B, device="cuda")
|
|
tok_scale = torch.randn(S, device="cuda")
|
|
|
|
def all_bias(score, batch, head, token_q, token_kv):
|
|
score = score + tok_scale[token_q]
|
|
score = score + batch_scale[batch]
|
|
score = score + head_scale[head]
|
|
return score
|
|
|
|
self.run_test(all_bias, dtype)
|
|
|
|
@supported_platform
|
|
@common_utils.parametrize("dtype", test_dtypes_fast)
|
|
def test_seq_masking(self, dtype):
|
|
seq_idx = torch.zeros(S, device="cuda", dtype=torch.bool)
|
|
seq_idx[S // 2 :] = 1
|
|
|
|
def seq_mask_mod(score, b, h, q, kv):
|
|
return torch.where(seq_idx[q] == seq_idx[kv], score, float("-inf"))
|
|
|
|
self.run_test(seq_mask_mod, dtype)
|
|
|
|
@supported_platform
|
|
@common_utils.parametrize("dtype", test_dtypes_fast)
|
|
def test_load_from_bias_seq_only(self, dtype):
|
|
bias = torch.randn(S, S, device="cuda", dtype=dtype)
|
|
|
|
def bias_mod(score, b, h, q, kv):
|
|
return score + bias[q, kv]
|
|
|
|
self.run_test(bias_mod, dtype)
|
|
|
|
@supported_platform
|
|
@common_utils.parametrize("dtype", test_dtypes_fast)
|
|
def test_load_from_bias_seq_batch(self, dtype):
|
|
bias = torch.randn(B, S, S, device="cuda", dtype=dtype)
|
|
|
|
def bias_mod(score, b, h, q, kv):
|
|
return score + bias[b, q, kv]
|
|
|
|
self.run_test(bias_mod, dtype)
|
|
|
|
@supported_platform
|
|
@common_utils.parametrize("dtype", test_dtypes_fast)
|
|
def test_load_from_bias_head_seq_batch(self, dtype):
|
|
bias = torch.randn(B, H, S, S, device="cuda", dtype=dtype)
|
|
|
|
def bias_mod(score, b, h, q, kv):
|
|
return score + bias[b, h, q, kv]
|
|
|
|
self.run_test(bias_mod, dtype)
|
|
|
|
@supported_platform
|
|
@common_utils.parametrize("dtype", test_dtypes_fast)
|
|
def test_load_rel_bias(self, dtype):
|
|
rel_bias = torch.randn(2 * S, device="cuda", dtype=dtype)
|
|
|
|
def bias_mod(score, b, h, q, kv):
|
|
return score + rel_bias[(q - kv) + S]
|
|
|
|
self.run_test(bias_mod, dtype)
|
|
|
|
@supported_platform
|
|
@common_utils.parametrize("dtype", test_dtypes_fast)
|
|
def test_dependent_causal_bidirectional(self, dtype):
|
|
num_bidirectional = torch.randint(0, S, (B,), device="cuda", dtype=torch.int32)
|
|
|
|
def bias_mod(score, b, h, q, kv):
|
|
causal_attention = q >= kv
|
|
cur_num_bidirectional = num_bidirectional[b]
|
|
bidirectional_attention_on_video = (q <= cur_num_bidirectional) & (
|
|
kv <= cur_num_bidirectional
|
|
)
|
|
return torch.where(
|
|
bidirectional_attention_on_video | causal_attention,
|
|
score,
|
|
-float("inf"),
|
|
)
|
|
|
|
self.run_test(bias_mod, dtype)
|
|
|
|
@supported_platform
|
|
@common_utils.parametrize("dtype", test_dtypes_fast)
|
|
def test_natten_2d(self, dtype):
|
|
H = 32
|
|
W = S // H
|
|
WINDOW = 3
|
|
assert W * H == S
|
|
|
|
def get_x_y(idx):
|
|
# This should be a floor divide, but we don't support that properly
|
|
return idx / W, idx % W
|
|
|
|
def natten_mask(score, b, h, q, kv):
|
|
q_x, q_y = get_x_y(q)
|
|
kv_x, kv_y = get_x_y(kv)
|
|
return torch.where(
|
|
((q_x - kv_x).abs() <= WINDOW) | ((q_y - kv_y).abs() <= WINDOW),
|
|
score,
|
|
float("-inf"),
|
|
)
|
|
|
|
self.run_test(natten_mask, dtype)
|
|
|
|
@supported_platform
|
|
@common_utils.parametrize("dtype", test_dtypes_fast)
|
|
def test_subgraph_respect_decompostion(self, dtype):
|
|
from torch._decomp import core_aten_decompositions
|
|
from torch.fx.experimental.proxy_tensor import make_fx
|
|
|
|
def score_mod_func(score, b, h, q, kv):
|
|
return score - q // (1 + kv)
|
|
|
|
make_tensor = functools.partial(
|
|
torch.randn,
|
|
(2, 2, 128, 4),
|
|
device="cuda",
|
|
dtype=torch.float64,
|
|
requires_grad=True,
|
|
)
|
|
query, key, value = make_tensor(), make_tensor(), make_tensor()
|
|
        # floor_div is not decomposed when the decomposition_table is empty
|
|
attention = functools.partial(flex_attention, score_mod=score_mod_func)
|
|
gm = make_fx(attention, decomposition_table={})(query, key, value)
|
|
self.assertExpectedInline(
|
|
gm.sdpa_score0.code.strip(),
|
|
"""\
|
|
def forward(self, arg0_1, arg1_1, arg2_1, arg3_1, arg4_1):
|
|
add = torch.ops.aten.add.Tensor(arg4_1, 1); arg4_1 = None
|
|
floor_divide = torch.ops.aten.floor_divide.default(arg3_1, add); arg3_1 = add = None
|
|
sub = torch.ops.aten.sub.Tensor(arg0_1, floor_divide); arg0_1 = floor_divide = None
|
|
return sub""",
|
|
)
|
|
|
|
        # floor_div is decomposed when using core_aten_decompositions
|
|
gm = make_fx(attention, decomposition_table=core_aten_decompositions())(
|
|
query, key, value
|
|
)
|
|
self.assertExpectedInline(
|
|
gm.sdpa_score0.code.strip(),
|
|
"""\
|
|
def forward(self, arg0_1, arg1_1, arg2_1, arg3_1, arg4_1):
|
|
add = torch.ops.aten.add.Tensor(arg4_1, 1); arg4_1 = None
|
|
div = torch.ops.aten.div.Tensor_mode(arg3_1, add, rounding_mode = 'floor'); arg3_1 = add = None
|
|
sub = torch.ops.aten.sub.Tensor(arg0_1, div); arg0_1 = div = None
|
|
return sub""",
|
|
)
|
|
|
|
@supported_platform
|
|
@common_utils.parametrize("dtype", test_dtypes_fast)
|
|
def test_silu_on_score(self, dtype):
|
|
def silu_score(score, b, h, q, kv):
|
|
return torch.nn.functional.silu(score)
|
|
|
|
self.run_test(silu_score, dtype)
|
|
|
|
@supported_platform
|
|
@common_utils.parametrize("dtype", test_dtypes_fast)
|
|
def test_padded_dense_causal(self, dtype):
|
|
seq_len = torch.arange(B, device="cuda", dtype=torch.int32) + 1
|
|
|
|
def create_padded_dense_wrapper(orig_score_mod):
|
|
def njt_score_mod(qk, b, h, q, kv):
|
|
return torch.where(
|
|
qk <= seq_len[b], orig_score_mod(qk, b, h, q, kv), -float("inf")
|
|
)
|
|
|
|
return njt_score_mod
|
|
|
|
causal_njt = create_padded_dense_wrapper(_causal)
|
|
|
|
self.run_test(causal_njt, dtype)
|
|
|
|
@supported_platform
|
|
@common_utils.parametrize("dtype", test_dtypes_fast)
|
|
def test_captured_scale(self, dtype):
|
|
scale = torch.ones((), device="cuda", dtype=torch.int32)
|
|
|
|
def score_mod_scale(qk, b, h, q, kv):
|
|
return qk + scale
|
|
|
|
self.run_test(score_mod_scale, dtype)
|
|
|
|
@supported_platform
|
|
@common_utils.parametrize("dtype", test_dtypes_fast)
|
|
def test_recompile_changed_score_mod(self, dtype):
|
|
scale = torch.ones((), device="cuda", dtype=torch.int32)
|
|
ADD = True
|
|
|
|
def score_mod_scale(qk, b, h, q, kv):
|
|
if ADD:
|
|
return qk + scale
|
|
else:
|
|
return qk * scale
|
|
|
|
self.run_test(score_mod_scale, dtype)
|
|
ADD = False
|
|
self.run_test(score_mod_scale, dtype)
|
|
|
|
@supported_platform
|
|
@expectedFailure # If we capture a tensor then we can perform a reduction on it, and that shouldn't be allowed
|
|
@common_utils.parametrize("dtype", test_dtypes_fast)
|
|
def test_captured_reduction(self, dtype):
|
|
scale = torch.randn((B, 8), device="cuda")
|
|
|
|
def score_mod_scale(qk, b, h, q, kv):
|
|
return qk + scale[b].sum(dim=-1)
|
|
|
|
self.run_test(score_mod_scale, dtype)
|
|
|
|
@supported_platform
|
|
def test_multiple_score_mod_calls(self):
|
|
query = torch.randn((1, 8, 1024, 64), dtype=torch.float32, device="cuda")
|
|
keys = [
|
|
torch.randn((1, 8, 1024, 64), dtype=torch.float32, device="cuda")
|
|
for _ in range(2)
|
|
]
|
|
values = [
|
|
torch.randn((1, 8, 1024, 64), dtype=torch.float32, device="cuda")
|
|
for _ in range(2)
|
|
]
|
|
|
|
def scoremod_1(qk, b, h, q, kv):
|
|
return qk + (q - kv)
|
|
|
|
def scoremod_2(qk, b, h, q, kv):
|
|
return torch.where(q >= kv, qk, -float("inf"))
|
|
|
|
def f(q, k1, k2, v1, v2):
|
|
q2 = flex_attention(q, k1, v1, score_mod=scoremod_1)
|
|
return flex_attention(q2, k2, v2, score_mod=scoremod_2)
|
|
|
|
out = f(query, *keys, *values)
|
|
out2 = torch.compile(f)(query, *keys, *values)
|
|
tolerance = Tolerances(atol=2e-1, rtol=2e-1)
|
|
torch.testing.assert_close(out, out2, atol=tolerance.atol, rtol=tolerance.rtol)
|
|
|
|
@supported_platform
|
|
def test_multiple_score_mod_calls2(self):
|
|
query = torch.randn((1, 8, 1024, 64), dtype=torch.float32, device="cuda")
|
|
keys = [
|
|
torch.randn((1, 8, 1024, 64), dtype=torch.float32, device="cuda")
|
|
for _ in range(3)
|
|
]
|
|
values = [
|
|
torch.randn((1, 8, 1024, 64), dtype=torch.float32, device="cuda")
|
|
for _ in range(3)
|
|
]
|
|
|
|
def scoremod_1(qk, b, h, q, kv):
|
|
return qk + (q - kv)
|
|
|
|
def scoremod_2(qk, b, h, q, kv):
|
|
return torch.where(q >= kv, qk, -float("inf"))
|
|
|
|
attention1 = functools.partial(flex_attention, score_mod=scoremod_1)
|
|
|
|
def f(q, k1, k2, k3, v1, v2, v3):
|
|
q2 = attention1(q, k1, v1)
|
|
q3 = flex_attention(q2, k2, v2, score_mod=scoremod_2)
|
|
return flex_attention(q3, k3, v3, score_mod=scoremod_1)
|
|
|
|
out = f(query, *keys, *values)
|
|
out2 = torch.compile(f)(query, *keys, *values)
|
|
self.assertTrue((out - out2).abs().mean() < 1e-2)
|
|
|
|
@supported_platform
|
|
def test_inputs_are_realized(self):
|
|
def f(q, k, v):
|
|
x = torch.randn(1024, device="cuda")
|
|
x = x * 2
|
|
|
|
def func(qk, b, h, q, kv):
|
|
return qk + x[q]
|
|
|
|
return flex_attention(q.sin(), k, v, score_mod=func).cos()
|
|
|
|
q, k, v = (
|
|
torch.randn(1, 8, 1024, 64, device="cuda", requires_grad=True)
|
|
for _ in range(3)
|
|
)
|
|
ref = f(q, k, v)
|
|
out = torch.compile(f)(q, k, v)
|
|
self.assertTrue((ref - out).abs().mean() < 1e-2)
|
|
gradOut = torch.randn_like(q)
|
|
|
|
ref_grads = torch.autograd.grad(ref, (q, k, v), gradOut)
|
|
out_grads = torch.autograd.grad(out, (q, k, v), gradOut)
|
|
for ref, out in zip(ref_grads, out_grads):
|
|
self.assertTrue((ref - out).abs().mean() < 1e-2)
|
|
|
|
@supported_platform
|
|
def test_make_block_mask(self):
|
|
def causal_mask(b, h, q_idx, kv_idx):
|
|
return q_idx >= kv_idx
|
|
|
|
block_mask_a = create_block_mask(causal_mask, 1, 1, 512, 512, _compile=True)
|
|
block_mask_b = create_block_mask(causal_mask, 1, 1, 512, 512, _compile=False)
|
|
self.assertEqual(block_mask_a.kv_num_blocks, block_mask_b.kv_num_blocks)
|
|
self.assertEqual(block_mask_a.kv_indices, block_mask_b.kv_indices)
|
|
self.assertEqual(block_mask_a.q_num_blocks, block_mask_b.q_num_blocks)
|
|
|
|
@supported_platform
|
|
def test_mask_mod_combiners(self):
|
|
def causal_mask(b, h, q, kv):
|
|
return q >= kv
|
|
|
|
def neg_causal_mask(b, h, q, kv):
|
|
return q < kv
|
|
|
|
def sliding_window(b, h, q, kv):
|
|
return (q - kv) <= 512
|
|
|
|
block_mask = create_block_mask(
|
|
and_masks(causal_mask, sliding_window), 1, 1, S, S
|
|
)
|
|
self.assertExpectedInline(block_mask.kv_num_blocks.sum().item(), """28""")
|
|
attention = functools.partial(flex_attention, block_mask=block_mask)
|
|
self.run_test_with_call(attention)
|
|
|
|
block_mask = create_block_mask(
|
|
and_masks(causal_mask, neg_causal_mask), 1, 1, S, S
|
|
)
|
|
self.assertEqual(block_mask.kv_num_blocks.sum(), 0)
|
|
|
|
block_mask1 = create_block_mask(
|
|
or_masks(causal_mask, neg_causal_mask), 1, 1, S, S
|
|
)
|
|
block_mask2 = create_block_mask(noop_mask, 1, 1, S, S)
|
|
self.assertEqual(block_mask1.sparsity(), block_mask2.sparsity())
|
|
|
|
@supported_platform
|
|
def test_epilogue_fused(self):
|
|
@torch.compile
|
|
def f(q, k, v):
|
|
out = flex_attention(q, k, v)
|
|
return out.cos()
|
|
|
|
q, k, v = (torch.randn(1, 8, 1024, 64, device="cuda") for _ in range(3))
|
|
metrics.reset()
|
|
_, code = run_and_get_code(f, q, k, v)
|
|
fc = FileCheck()
|
|
fc.check("triton_tem_fused") # template call
|
|
fc.check_not("poi_fused_cos") # No cos pointwise operation
|
|
fc.run(code[0])
|
|
accessed_bytes = 1 * 8 * 1024 * 64 * torch.float32.itemsize
|
|
num_accesses = 4 # q, k, v reads, one output.
|
|
# TODO: Get rid of this fudge factor
|
|
# We need this fudge factor for now as we write the extraneous logsumexp
|
|
num_accesses += 1
|
|
self.assertLess(metrics.num_bytes_accessed, accessed_bytes * num_accesses)
|
|
|
|
@supported_platform
|
|
@common_utils.parametrize("dtype", test_dtypes)
|
|
def test_njt_causal(self, dtype):
|
|
offsets = torch.tensor(
|
|
[0, 1024, 1024 + 512, S], device="cuda", dtype=torch.int32
|
|
)
|
|
seq_idx = torch.zeros(S, device="cuda", dtype=torch.int32)
|
|
for idx in range(len(offsets) - 1):
|
|
seq_idx[offsets[idx] : offsets[idx + 1]] = idx
|
|
|
|
def create_njt_wrapper(orig_score_mod, offsets, seq_idx):
|
|
def njt_score_mod(qk, b, h, q, kv):
|
|
q_nested = q - offsets[seq_idx[q]]
|
|
kv_nested = kv - offsets[seq_idx[kv]]
|
|
return orig_score_mod(qk, b, h, q_nested, kv_nested)
|
|
|
|
return njt_score_mod
|
|
|
|
causal_njt = create_njt_wrapper(_causal, offsets, seq_idx)
|
|
|
|
self.run_test(causal_njt, dtype)
|
|
|
|
@supported_platform
|
|
def test_mixed_dtypes_fails(self):
|
|
query = torch.randn((1, 1, 1024, 64), dtype=torch.float32, device="cuda")
|
|
key = torch.randn((1, 1, 1024, 64), dtype=torch.float16, device="cuda")
|
|
value = torch.randn((1, 1, 1024, 64), dtype=torch.float16, device="cuda")
|
|
with self.assertRaisesRegex(
|
|
ValueError, "Expected query, key, and value to have the same dtype"
|
|
):
|
|
flex_attention(query, key, value, _identity)
|
|
|
|
@supported_platform
|
|
@patch.object(torch._inductor.config, "max_autotune", True)
|
|
def test_max_autotune(self):
|
|
def score_mod(score, b, h, m, n):
|
|
return score * 2
|
|
|
|
self.run_test(score_mod)
|
|
|
|
@supported_platform
|
|
@skip("TODO: Figure out why this is erroring")
|
|
@patch.object(torch._inductor.config, "max_autotune", True)
|
|
def test_max_autotune_with_captured(self):
|
|
head_scale = torch.randn(H, device="cuda")
|
|
batch_scale = torch.randn(B, device="cuda")
|
|
tok_scale = torch.randn(S, device="cuda")
|
|
|
|
def bias_mod(score, batch, head, token_q, token_kv):
|
|
score = score + tok_scale[token_q]
|
|
score = score + batch_scale[batch]
|
|
score = score + head_scale[head]
|
|
return score
|
|
|
|
self.run_test(bias_mod)
|
|
|
|
# TODO this config segfaults with Triton without:
|
|
# https://github.com/triton-lang/triton/pull/4540
|
|
@supported_platform
|
|
@common_utils.parametrize("score_mod", test_score_mods)
|
|
@common_utils.parametrize("dtype", test_dtypes)
|
|
@common_utils.parametrize("head_dims", [(D, D // 2), (D // 2, D)])
|
|
def test_non_equal_head_dims(self, dtype, score_mod, head_dims):
|
|
qk_d, v_d = head_dims
|
|
context = nullcontext() if qk_d > v_d else self.assertRaises(ValueError)
|
|
with context:
|
|
self.run_test(score_mod, dtype, B, H, S, qk_d, B, H, S, V_D=v_d)
|
|
|
|
@supported_platform
|
|
def test_autograd_function_in_score_mod(self):
|
|
class ApplyMask(torch.autograd.Function):
|
|
generate_vmap_rule = True
|
|
|
|
@staticmethod
|
|
def forward(a, mask):
|
|
return torch.where(mask, a, -float("inf"))
|
|
|
|
@staticmethod
|
|
def setup_context(ctx, inputs, output):
|
|
_, mask = inputs
|
|
ctx.mark_non_differentiable(mask)
|
|
|
|
@staticmethod
|
|
def backward(ctx, i):
|
|
return i, None
|
|
|
|
def score_mod(score, b, h, q, kv):
|
|
return ApplyMask.apply(score, q <= kv)
|
|
|
|
func = torch.compile(flex_attention, fullgraph=True)
|
|
|
|
q, k, v = (
|
|
torch.randn(1, 8, 1024, 64, device="cuda", requires_grad=True)
|
|
for _ in range(3)
|
|
)
|
|
|
|
# Just checking that it runs
|
|
func(q, k, v)
|
|
|
|
# expectedFailure
|
|
# This doesn't work due to vmap + autograd.Function + torch.compile not composing
|
|
# self.run_test(score_mod)
|
|
|
|
@supported_platform
|
|
def test_causal_block(self):
|
|
def mask_mod(b, h, q, kv):
|
|
return q >= kv
|
|
|
|
block_mask = create_block_mask(mask_mod, 1, 1, S, S)
|
|
attention = functools.partial(flex_attention, block_mask=block_mask)
|
|
|
|
self.run_test_with_call(attention)
|
|
|
|
@skipIfRocm
|
|
@supported_platform
|
|
def test_GQA_causal_mask(self):
|
|
def mask_mod(b, h, q, kv):
|
|
return q >= kv
|
|
|
|
block_mask = create_block_mask(mask_mod, 1, 1, S // 8, S // 8)
|
|
attention = functools.partial(
|
|
flex_attention, block_mask=block_mask, enable_gqa=True
|
|
)
|
|
|
|
self.run_test_with_call(
|
|
attention,
|
|
torch.float16,
|
|
B,
|
|
H * 4, # Hq = 4*Hkv.
|
|
S // 8,
|
|
D,
|
|
B,
|
|
H,
|
|
S // 8,
|
|
D,
|
|
)
|
|
|
|
@supported_platform
|
|
def test_custom_block_mask_generator(self):
|
|
def mask_mod(b, h, q, kv):
|
|
return q >= kv
|
|
|
|
auto_mask = create_block_mask(mask_mod, 1, 1, S, S)
|
|
BLOCK_SIZE = 128
|
|
|
|
def causal_constructor(S):
|
|
num_blocks = torch.arange(S // BLOCK_SIZE, device="cuda") + 1
|
|
indices = torch.arange(S // BLOCK_SIZE, device="cuda").expand(
|
|
S // BLOCK_SIZE, S // BLOCK_SIZE
|
|
)
|
|
num_blocks = num_blocks[None, None, :]
|
|
indices = indices[None, None, :]
|
|
return BlockMask.from_kv_blocks(
|
|
num_blocks, indices, BLOCK_SIZE=BLOCK_SIZE, mask_mod=mask_mod
|
|
)
|
|
|
|
manual_mask = causal_constructor(S)
|
|
self.assertEqual(auto_mask.to_dense(), manual_mask.to_dense())
|
|
|
|
@supported_platform
|
|
@common_utils.parametrize("dtype", test_dtypes)
|
|
@common_utils.parametrize("score_mod", [_identity, _causal])
|
|
def test_logsumexp_correctness(self, dtype, score_mod):
|
|
make_tensor = functools.partial(
|
|
torch.randn,
|
|
(B, H, S, D),
|
|
dtype=dtype,
|
|
device="cuda",
|
|
requires_grad=True,
|
|
)
|
|
q, k, v = make_tensor(), make_tensor(), make_tensor()
|
|
|
|
@torch.compile
|
|
def sdpa_hop(q, k, v, score_mod):
|
|
return flex_attention(q, k, v, score_mod, return_lse=True)
|
|
|
|
@torch.compile(backend="aot_eager")
|
|
def eager_sdpa_hop(q, k, v, score_mod):
|
|
return flex_attention(q, k, v, score_mod, return_lse=True)
|
|
|
|
ref_out, ref_lse = eager_sdpa_hop(
|
|
q.to(torch.float64),
|
|
k.to(torch.float64),
|
|
v.to(torch.float64),
|
|
score_mod,
|
|
)
|
|
compiled_out, compiled_lse = sdpa_hop(q, k, v, score_mod)
|
|
|
|
self.assertTrue(ref_lse.dtype == torch.float64)
|
|
self.assertTrue(compiled_lse.dtype == torch.float32)
|
|
|
|
tolerance = Tolerances(atol=2e-2, rtol=2e-2)
|
|
torch.testing.assert_close(
|
|
ref_out.to(dtype=torch.float32),
|
|
compiled_out.to(dtype=torch.float32),
|
|
atol=tolerance.atol,
|
|
rtol=tolerance.rtol,
|
|
)
|
|
torch.testing.assert_close(
|
|
ref_lse.to(dtype=torch.float32),
|
|
compiled_lse.to(dtype=torch.float32),
|
|
atol=tolerance.atol,
|
|
rtol=tolerance.rtol,
|
|
)
|
|
|
|
@supported_platform
|
|
def test_logsumexp_only_return(self):
|
|
make_tensor = functools.partial(
|
|
torch.randn,
|
|
(B, H, S, D),
|
|
dtype=torch.float32,
|
|
device="cuda",
|
|
requires_grad=True,
|
|
)
|
|
q, k, v = make_tensor(), make_tensor(), make_tensor()
|
|
|
|
@torch.compile
|
|
def func(q, k, v, score_mod):
|
|
_, lse = flex_attention(q, k, v, score_mod, return_lse=True)
|
|
lse_2 = lse * 2
|
|
return lse_2
|
|
|
|
_, code = run_and_get_code(func, q, k, v, _identity)
|
|
# Ensure that we're still generating the flexattention kernel
|
|
FileCheck().check_count(".run(primals_1, primals_2, primals_3", 1, True).run(
|
|
code[0]
|
|
)
|
|
|
|
@supported_platform
|
|
@common_utils.parametrize(
|
|
"score_mod", [_identity, _causal, _times_two, _squared, _trig, _trig2]
|
|
)
|
|
def test_aot_eager_gradcheck(self, score_mod):
|
|
make_tensor = functools.partial(
|
|
torch.randn,
|
|
(2, 2, 128, 4),
|
|
device="cuda",
|
|
dtype=torch.float64,
|
|
requires_grad=True,
|
|
)
|
|
query, key, value = make_tensor(), make_tensor(), make_tensor()
|
|
|
|
func = torch.compile(flex_attention, backend="aot_eager", fullgraph=True)
|
|
|
|
self.assertTrue(
|
|
torch.autograd.gradcheck(
|
|
func, (query, key, value, score_mod), raise_exception=True
|
|
)
|
|
)
|
|
|
|
@supported_platform
|
|
def test_eager_backward_strides(self):
|
|
class Repro(torch.nn.Module):
|
|
def __init__(self):
|
|
super().__init__()
|
|
self.qkv_proj = torch.nn.Linear(256, 256 * 3)
|
|
self.n_head = 256 // 64
|
|
self.d_attn = 256
|
|
|
|
def forward(self, x):
|
|
n_batch, n_ctx, _ = x.shape
|
|
q, k, v = self.qkv_proj(x).split(
|
|
[self.d_attn, self.d_attn, self.d_attn], dim=2
|
|
)
|
|
q = q.reshape(n_batch, n_ctx, self.n_head, -1)
|
|
k = k.reshape(n_batch, n_ctx, self.n_head, -1)
|
|
v = v.reshape(n_batch, n_ctx, self.n_head, -1)
|
|
q = q.transpose(1, 2)
|
|
k = k.transpose(1, 2)
|
|
v = v.transpose(1, 2)
|
|
x = torch.nn.attention.flex_attention.flex_attention(q, k, v)
|
|
return x
|
|
|
|
model = Repro().cuda()
|
|
x = torch.randn((1, 512, 256), device="cuda", requires_grad=True)
|
|
out = torch.compile(model, backend="aot_eager")(x)
|
|
out.backward(torch.ones_like(out))
|
|
|
|
@supported_platform
|
|
def test_differentiable_logsumexp_gradcheck(self):
|
|
make_tensor = functools.partial(
|
|
torch.randn,
|
|
(2, 2, 128, 4),
|
|
device="cuda",
|
|
dtype=torch.float64,
|
|
requires_grad=True,
|
|
)
|
|
query, key, value = make_tensor(), make_tensor(), make_tensor()
|
|
|
|
def flex_attention_lse_only(q, k, v):
|
|
return flex_attention(q, k, v, return_lse=True)[1]
|
|
|
|
func = torch.compile(
|
|
flex_attention_lse_only, backend="aot_eager", fullgraph=True
|
|
)
|
|
|
|
self.assertTrue(
|
|
torch.autograd.gradcheck(func, (query, key, value), raise_exception=True)
|
|
)
|
|
|
|
@supported_platform
|
|
def test_differentiable_logsumexp_compiled(self):
|
|
make_tensor = functools.partial(
|
|
torch.randn,
|
|
(2, 2, 128, 64),
|
|
device="cuda",
|
|
dtype=torch.float32,
|
|
requires_grad=True,
|
|
)
|
|
q, k, v = make_tensor(), make_tensor(), make_tensor()
|
|
lse_mask = torch.randn(2, 2, 128, device="cuda")
|
|
|
|
out, lse = flex_attention(q, k, v, return_lse=True)
|
|
(out.mean() + (lse * lse_mask).sum()).backward()
|
|
q_grad, k_grad, v_grad = q.grad, k.grad, v.grad
|
|
q.grad = None
|
|
k.grad = None
|
|
v.grad = None
|
|
|
|
out2, lse2 = torch.compile(flex_attention)(q, k, v, return_lse=True)
|
|
(out2.mean() + (lse2 * lse_mask).sum()).backward()
|
|
q_grad2, k_grad2, v_grad2 = q.grad, k.grad, v.grad
|
|
tolerance = Tolerances(atol=1e-1, rtol=1e-1)
|
|
|
|
torch.testing.assert_close(out, out2, atol=tolerance.atol, rtol=tolerance.rtol)
|
|
torch.testing.assert_close(lse, lse2, atol=tolerance.atol, rtol=tolerance.rtol)
|
|
torch.testing.assert_close(
|
|
q_grad, q_grad2, atol=tolerance.atol, rtol=tolerance.rtol
|
|
)
|
|
torch.testing.assert_close(
|
|
k_grad, k_grad2, atol=tolerance.atol, rtol=tolerance.rtol
|
|
)
|
|
torch.testing.assert_close(
|
|
v_grad, v_grad2, atol=tolerance.atol, rtol=tolerance.rtol
|
|
)
|
|
|
|
@supported_platform
|
|
def test_float32_matmul_precision(self):
|
|
make_tensor = functools.partial(
|
|
torch.zeros,
|
|
(2, 2, 128, 32),
|
|
device="cuda",
|
|
dtype=torch.float32,
|
|
requires_grad=False,
|
|
)
|
|
query, key, value = make_tensor(), make_tensor(), make_tensor()
|
|
query.fill_(0.2)
|
|
key.fill_(0.3)
|
|
value.fill_(0.4)
|
|
|
|
query.requires_grad = True
|
|
key.requires_grad = True
|
|
value.requires_grad = True
|
|
|
|
def score_mod(score, b, h, q, kv):
|
|
return score * 2
|
|
|
|
with temp_float32_matmul_precision("highest"):
|
|
out_eager = flex_attention(query, key, value, score_mod)
|
|
flex_compiled = torch.compile(flex_attention, fullgraph=True)
|
|
out_compiled = flex_compiled(query, key, value, score_mod)
|
|
|
|
grads_eager = torch.autograd.grad(out_eager.sum(), (query, key, value))
|
|
grads_compile = torch.autograd.grad(out_compiled.sum(), (query, key, value))
|
|
|
|
torch.testing.assert_close(grads_eager, grads_compile)
|
|
|
|
@supported_platform
|
|
@common_utils.parametrize("score_mod_name", ["_head_offset"])
|
|
@common_utils.parametrize("mode", ["eager", "aot_eager"])
|
|
def test_captured_score_mod_aot_eager_gradcheck(
|
|
self, score_mod_name: str, mode: str
|
|
):
|
|
make_tensor = functools.partial(
|
|
torch.randn,
|
|
(2, 2, 128, 4),
|
|
device="cuda",
|
|
dtype=torch.float64,
|
|
requires_grad=True,
|
|
)
|
|
query, key, value = make_tensor(), make_tensor(), make_tensor()
|
|
|
|
func = torch.compile(flex_attention, backend=mode, fullgraph=True)
|
|
score_mod = captured_buffers_map[score_mod_name](torch.float64)
|
|
|
|
self.assertTrue(
|
|
torch.autograd.gradcheck(
|
|
func, (query, key, value, score_mod), raise_exception=True
|
|
)
|
|
)
|
|
|
|
@supported_platform
|
|
@common_utils.parametrize("mode", ["eager", "aot_eager"])
|
|
def test_document_masking_edge_case(self, mode):
|
|
document_masks = torch.full((2, 128), 0, dtype=torch.int32, device="cuda")
|
|
document_masks[:, 64:] = 1
|
|
|
|
def mask_mod(b, h, q, kv):
|
|
same_doc = document_masks[b, q] == document_masks[b, kv]
|
|
return same_doc
|
|
|
|
make_tensor = functools.partial(
|
|
torch.randn,
|
|
(2, 1, 128, 4),
|
|
device="cuda",
|
|
dtype=torch.float64,
|
|
requires_grad=True,
|
|
)
|
|
query, key, value = make_tensor(), make_tensor(), make_tensor()
|
|
func = torch.compile(flex_attention, backend=mode, fullgraph=True)
|
|
|
|
block_mask = create_block_mask(mask_mod, 2, 1, 128, 128)
|
|
out = func(query, key, value, block_mask=block_mask)
|
|
out.sum().backward()
|
|
|
|
@supported_platform
|
|
@common_utils.parametrize("mode", ["eager", "inductor"])
|
|
@common_utils.parametrize(
|
|
"permute_order",
|
|
[
|
|
(0, 1, 2, 3), # Default order
|
|
(1, 0, 2, 3), # Reverse order
|
|
(0, 2, 1, 3), # Mixed order
|
|
(2, 0, 1, 3), # Another mixed order
|
|
],
|
|
)
|
|
@common_utils.parametrize("shape", [(2, 1, 128, 16), (4, 2, 64, 16)])
|
|
def test_flex_attention_stride_ordering(self, mode, permute_order, shape):
|
|
from torch._inductor.ir import get_stride_order
|
|
|
|
# Setup
|
|
make_tensor = functools.partial(
|
|
torch.randn,
|
|
shape,
|
|
device="cuda",
|
|
dtype=torch.float32,
|
|
requires_grad=True,
|
|
)
|
|
|
|
# Create and permute tensors
|
|
query, key, value = make_tensor(), make_tensor(), make_tensor()
|
|
query = query.permute(permute_order)
|
|
key = key.permute(permute_order)
|
|
value = value.permute(permute_order)
|
|
|
|
if mode == "inductor":
|
|
func = torch.compile(flex_attention, backend=mode, fullgraph=True)
|
|
else:
|
|
func = flex_attention
|
|
|
|
out = func(query, key, value)
|
|
|
|
out_stride_order = get_stride_order(out.stride())
|
|
query_stride_order = get_stride_order(query.stride())
|
|
|
|
self.assertEqual(
|
|
out_stride_order,
|
|
query_stride_order,
|
|
f"Stride order mismatch: out {out_stride_order}, query {query_stride_order}",
|
|
)
|
|
|
|
@supported_platform
|
|
@common_utils.parametrize("compile", [True, False])
|
|
def test_fully_masked_out_rows_0_check(self, compile: bool):
|
|
# Ensure fully masked out rows won't cause NaNs.
|
|
query = torch.randn(
|
|
(B, H, S, D), dtype=torch.float32, device="cuda", requires_grad=True
|
|
)
|
|
key = torch.randn(
|
|
(B, H, S, D), dtype=torch.float32, device="cuda", requires_grad=True
|
|
)
|
|
value = torch.randn(
|
|
(B, H, S, D), dtype=torch.float32, device="cuda", requires_grad=True
|
|
)
|
|
|
|
M = S // 2
|
|
|
|
def mask_mod(b, h, q, kv):
|
|
return q < M
|
|
|
|
block_mask = create_block_mask(mask_mod, 1, 1, S, S)
|
|
|
|
flex = (
|
|
torch.compile(flex_attention, dynamic=False) if compile else flex_attention
|
|
)
|
|
out, lse = flex(query, key, value, block_mask=block_mask, return_lse=True)
|
|
self.assertEqual(out[:, :, M:, :].sum(), 0)
|
|
self.assertTrue((lse[:, :, M:] == -float("inf")).all())
|
|
|
|
loss = out.sum() + lse.sum()
|
|
loss.backward()
|
|
self.assertEqual(query.grad[:, :, M:, :].sum(), 0)
|
|
|
|
@supported_platform
|
|
@common_utils.parametrize("compile", [True, False])
|
|
def test_fully_masked_out_rows(self, compile: bool):
|
|
M = S // 2
|
|
|
|
def mask_mod(b, h, q, kv):
|
|
return q < M
|
|
|
|
block_mask = create_block_mask(mask_mod, 1, 1, S, S)
|
|
|
|
def noop_mod(score, b, h, q_idx, kv_idx):
|
|
return score
|
|
|
|
self.run_test(noop_mod, torch.float32, B, H, S, D, B, H, S, D, block_mask)
|
|
|
|
@supported_platform
|
|
def test_kernel_options_argument_is_respected(self):
|
|
make_tensor = functools.partial(
|
|
torch.randn,
|
|
(2, 2, 128, 64),
|
|
device="cuda",
|
|
dtype=torch.float32,
|
|
requires_grad=True,
|
|
)
|
|
q, k, v = make_tensor(), make_tensor(), make_tensor()
|
|
|
|
# Ensure we respect user's input kernel options.
|
|
_, code = run_and_get_code(
|
|
torch.compile(flex_attention), q, k, v, kernel_options={"BLOCK_M": 16}
|
|
)
|
|
FileCheck().check("BLOCK_M : tl.constexpr = 16").run(code[0])
|
|
|
|
@supported_platform
|
|
def test_comparison_vs_sdpa(self):
|
|
        def causal(score, b, h, q_idx, kv_idx):
            return torch.where(q_idx >= kv_idx, score, -float("inf"))

        def causal_mask(b, h, q_idx, kv_idx):
            return q_idx >= kv_idx

        no_sparse_flex = functools.partial(flex_attention, score_mod=causal)
        score_mod_sparse_flex = functools.partial(
            flex_attention,
            score_mod=causal,
            block_mask=create_block_mask(causal_mask, 1, 1, 2048, 2048),
        )
        mask_mod_sparse_flex = functools.partial(
            flex_attention, block_mask=create_block_mask(causal_mask, 1, 1, 2048, 2048)
        )
        for attention_call in [
            no_sparse_flex,
            score_mod_sparse_flex,
            mask_mod_sparse_flex,
        ]:
            inputs = [
                torch.randn(
                    2,
                    2,
                    2048,
                    64,
                    device="cuda",
                    dtype=torch.float16,
                    requires_grad=True,
                )
                for _ in range(3)
            ]
            gradOut = torch.randn(2, 2, 2048, 64, device="cuda", dtype=torch.float16)
            out_ref = torch.nn.functional.scaled_dot_product_attention(
                *inputs, is_causal=True
            )
            out_ref.backward(gradOut)

            inputs_flex = [i.detach().clone().requires_grad_(True) for i in inputs]
            out_flex = torch.compile(attention_call)(*inputs_flex)
            out_flex.backward(gradOut)
            inputs_golden = [
                i.detach().clone().to(dtype=torch.float64).requires_grad_(True)
                for i in inputs
            ]
            out_golden = torch.nn.functional.scaled_dot_product_attention(
                *inputs_golden, is_causal=True
            )
            out_golden.backward(gradOut.to(dtype=torch.float64))

            for ref, flex, golden in [
                (out_ref, out_flex, out_golden),
                (inputs[0].grad, inputs_flex[0].grad, inputs_golden[0].grad),
                (inputs[1].grad, inputs_flex[1].grad, inputs_golden[1].grad),
                (inputs[2].grad, inputs_flex[2].grad, inputs_golden[2].grad),
            ]:
                ref_error = rmse(ref, golden)
                flex_error = rmse(flex, golden)
                # Note: This has been carefully tested that FlexAttention is within
                # 20% of the average error of SDPA! Do not bump this tolerance
                # unless you are absolutely sure you are not worsening the accuracy
                # of FlexAttention!
                self.assertTrue(
                    ref_error * 1.2 > flex_error,
                    f"Ref error: {ref_error}, Flex Error: {flex_error}",
                )

    @supported_platform
    def test_causal_block_non_divisible(self):
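        # Causal masking with sequence lengths that are not a multiple of the block size.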
        def mask_mod(b, h, q, kv):
            return q >= kv

        block_mask = create_block_mask(mask_mod, 1, 1, S - 1, S - 1)
        attention = functools.partial(flex_attention, block_mask=block_mask)

        self.run_test_with_call(attention, Q_S=S - 1, KV_S=S - 1)

    @supported_platform
    def test_force_write_lse(self):
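        # The logsumexp returned by the eager and compiled paths should match.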
        make_tensor = functools.partial(
            torch.randn,
            (2, 2, 128, 16),
            device="cuda",
            dtype=torch.float32,
            requires_grad=False,
        )
        query, key, value = make_tensor(), make_tensor(), make_tensor()
        out_eager, lse_eager = flex_attention(query, key, value, return_lse=True)

        flex_compile = torch.compile(flex_attention, fullgraph=True)
        out_compiled, lse_compiled = flex_compile(query, key, value, return_lse=True)

        torch.testing.assert_close(lse_eager, lse_compiled, atol=3e-3, rtol=0)

    @supported_platform
    @common_utils.parametrize("backend", ["flex_attention", "flex_decode", "eager"])
    def test_lse_masked_output(self, backend):
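        # Split causal attention into sliding-window and global pieces, recombine the
        # partial outputs via their LSEs, and compare against full causal SDPA.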
        if backend == "flex_decode":
            kernel_options = {"FORCE_USE_FLEX_ATTENTION": False}
            flex_call = torch.compile(flex_attention, fullgraph=True)
        elif backend == "flex_attention":
            kernel_options = {"FORCE_USE_FLEX_ATTENTION": True}
            flex_call = torch.compile(flex_attention, fullgraph=True)
        else:
            kernel_options = {}
            flex_call = flex_attention

        N_CTX = 96
        SLIDING_WINDOW = 64
        make_tensor = functools.partial(
            torch.randn,
            (2, 2, N_CTX, 64),
            device="cuda",
            dtype=torch.float32,
            requires_grad=True,
        )

        def sliding_window_causal(b, h, q_idx, kv_idx):
            causal_mask = q_idx >= kv_idx
            window_mask = q_idx - kv_idx <= SLIDING_WINDOW
            return causal_mask & window_mask

        def global_causal(b, h, q_idx, kv_idx):
            causal_mask = q_idx >= kv_idx
            window_mask = q_idx - kv_idx > SLIDING_WINDOW
            return causal_mask & window_mask

        sliding_window_causal = torch.nn.attention.flex_attention.create_block_mask(
            sliding_window_causal, B=None, H=None, Q_LEN=N_CTX, KV_LEN=N_CTX
        )
        global_causal = torch.nn.attention.flex_attention.create_block_mask(
            global_causal, B=None, H=None, Q_LEN=N_CTX, KV_LEN=N_CTX
        )

        local_attn = functools.partial(
            flex_call,
            block_mask=sliding_window_causal,
            return_lse=True,
            kernel_options=kernel_options,
        )
        global_attn = functools.partial(
            flex_call,
            block_mask=global_causal,
            return_lse=True,
            kernel_options=kernel_options,
        )
        q, k, v = make_tensor(), make_tensor(), make_tensor()
        gradOut = make_tensor(requires_grad=False)

        x_local, lse_local = local_attn(q, k, v)
        x_global, lse_global = global_attn(q, k, v)

        max_lse = torch.maximum(lse_local, lse_global)
        lse_global = lse_global - max_lse
        lse_local = lse_local - max_lse
        lse_global = torch.exp(lse_global)
        lse_local = torch.exp(lse_local)
        x = ((x_local * lse_local[..., None]) + (x_global * lse_global[..., None])) / (
            lse_global[..., None] + lse_local[..., None]
        )
        x.backward(gradOut)
        flex_q_grad, flex_k_grad, flex_v_grad = q.grad, k.grad, v.grad
        q.grad = None
        k.grad = None
        v.grad = None

        out = torch.nn.functional.scaled_dot_product_attention(q, k, v, is_causal=True)
        out.backward(gradOut)

        torch.testing.assert_close(x, out, atol=3e-3, rtol=2e-3)
        torch.testing.assert_close(flex_q_grad, q.grad, atol=3e-3, rtol=2e-3)
        torch.testing.assert_close(flex_k_grad, k.grad, atol=3e-3, rtol=2e-3)
        torch.testing.assert_close(flex_v_grad, v.grad, atol=3e-3, rtol=2e-3)

    @supported_platform
    def test_small_q_kv_len(self):
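        # A single-element query/key length should produce identical results in eager
        # and compiled modes.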
        make_tensor = functools.partial(
            torch.ones,
            (1, 1, 1, 16),
            device="cuda",
            dtype=torch.float32,
            requires_grad=True,
        )
        query, key, value = make_tensor(), make_tensor(), make_tensor()
        kernel_options = {"FORCE_USE_FLEX_ATTENTION": True}
        out_eager, lse_eager = flex_attention(
            query, key, value, return_lse=True, kernel_options=kernel_options
        )

        flex_compile = torch.compile(flex_attention, fullgraph=True)
        out_compiled, lse_compiled = flex_compile(
            query, key, value, return_lse=True, kernel_options=kernel_options
        )

        assert torch.equal(out_eager, out_compiled)
        assert torch.equal(lse_eager, lse_compiled)

        grads_eager = torch.autograd.grad(out_eager.sum(), (query, key, value))
        grads_compile = torch.autograd.grad(out_compiled.sum(), (query, key, value))

        torch.testing.assert_close(grads_eager, grads_compile)

    @supported_platform
    def test_causal_block_non_divisible_with_captured_buffer(self):
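        # Non-divisible sequence lengths combined with buffers captured by the score_mod.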
        Q_S = S - 3
        KV_S = S - 3
        offset_q = torch.randn(Q_S, device="cuda", dtype=torch.bfloat16)
        offset_kv = torch.randn(KV_S, device="cuda", dtype=torch.bfloat16)

        def score_mod(score, b, h, q, kv):
            return score + offset_q[q] + offset_kv[kv]

        def mask_mod(b, h, q, kv):
            return q >= kv

        block_mask = create_block_mask(mask_mod, 1, 1, Q_S, KV_S)
        # block_mask = None
        attention = functools.partial(flex_attention, block_mask=block_mask)

        self.run_test_with_call(attention, Q_S=Q_S, KV_S=KV_S)

    @unittest.skipIf(not TEST_MULTIGPU, "detected only one GPU")
    def test_qkv_and_block_mask_on_the_same_device(self):
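        # A block mask created on a different device than q/k/v should be rejected
        # with a clear error.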
        make_tensor = functools.partial(
            torch.ones,
            (2, 2, 256, 32),
            device="cuda:0",
            dtype=torch.float32,
            requires_grad=True,
        )
        query, key, value = make_tensor(), make_tensor(), make_tensor()

        def mask_mod(b, h, q, kv):
            return q >= kv

        block_mask = create_block_mask(mask_mod, 1, 1, 256, 256, device="cuda:1")
        with self.assertRaisesRegex(
            RuntimeError, "Expect q/k/v and block_mask to be on the same device"
        ):
            torch.compile(flex_attention)(query, key, value, block_mask=block_mask)

    @supported_platform
    def test_fw_bw_graph_correctness(self):
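        # Check the captured Dynamo forward graph and the AOT joint graph against
        # their expected printouts.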
        cnt = CompileCounterWithBackend("aot_eager")
        make_tensor = functools.partial(
            torch.randn,
            (2, 2, 128, 4),
            device="cuda",
            dtype=torch.float64,
            requires_grad=True,
        )
        query, key, value = make_tensor(), make_tensor(), make_tensor()

        def causal_mask(b, h, q_idx, kv_idx):
            return q_idx >= kv_idx

        block_mask = create_block_mask(causal_mask, 1, 1, 128, 128)

        func = torch.compile(flex_attention, backend=cnt, fullgraph=True)
        out = func(query, key, value, _squared, block_mask=block_mask)
        out.sum().backward()
        self.assertEqual(cnt.frame_count, 1)
        self.assertEqual(len(cnt.graphs), 1)
        graph = cnt.graphs[0]
        norm_graph = normalize_gm(graph.print_readable(print_output=False))

        self.assertExpectedInline(
            norm_graph,
            """\
class GraphModule(torch.nn.Module):
    def forward(self, L_query_: "f64[2, 2, 128, 4]", L_key_: "f64[2, 2, 128, 4]", L_value_: "f64[2, 2, 128, 4]", L_block_mask_kv_num_blocks: "i32[1, 1, 1]", L_block_mask_kv_indices: "i32[1, 1, 1, 1]", L_block_mask_full_kv_num_blocks: "i32[1, 1, 1]", L_block_mask_full_kv_indices: "i32[1, 1, 1, 1]", L_block_mask_q_num_blocks: "i32[1, 1, 1]", L_block_mask_q_indices: "i32[1, 1, 1, 1]", L_block_mask_full_q_num_blocks: "i32[1, 1, 1]", L_block_mask_full_q_indices: "i32[1, 1, 1, 1]"):
        l_query_ = L_query_
        l_key_ = L_key_
        l_value_ = L_value_
        l_block_mask_kv_num_blocks = L_block_mask_kv_num_blocks
        l_block_mask_kv_indices = L_block_mask_kv_indices
        l_block_mask_full_kv_num_blocks = L_block_mask_full_kv_num_blocks
        l_block_mask_full_kv_indices = L_block_mask_full_kv_indices
        l_block_mask_q_num_blocks = L_block_mask_q_num_blocks
        l_block_mask_q_indices = L_block_mask_q_indices
        l_block_mask_full_q_num_blocks = L_block_mask_full_q_num_blocks
        l_block_mask_full_q_indices = L_block_mask_full_q_indices

        child_1: "i32[]" = l_query_.new_empty([], dtype = torch.int32); child_1 = None
        child_2: "i32[]" = l_query_.new_empty([], dtype = torch.int32); child_2 = None
        child_3: "i32[]" = l_query_.new_empty([], dtype = torch.int32); child_3 = None
        child_4: "i32[]" = l_query_.new_empty([], dtype = torch.int32); child_4 = None
        child: "f64[]" = l_query_.new_empty([], requires_grad = True); child = None
        score_mod_0 = self.score_mod_0
        child_5: "i32[]" = l_query_.new_empty([], dtype = torch.int32); child_5 = None
        child_6: "i32[]" = l_query_.new_empty([], dtype = torch.int32); child_6 = None
        child_7: "i32[]" = l_query_.new_empty([], dtype = torch.int32); child_7 = None
        child_8: "i32[]" = l_query_.new_empty([], dtype = torch.int32); child_8 = None
        mask_fn_0 = self.mask_fn_0
        flex_attention = torch.ops.higher_order.flex_attention(l_query_, l_key_, l_value_, score_mod_0, (l_block_mask_kv_num_blocks, l_block_mask_kv_indices, l_block_mask_full_kv_num_blocks, l_block_mask_full_kv_indices, l_block_mask_q_num_blocks, l_block_mask_q_indices, l_block_mask_full_q_num_blocks, l_block_mask_full_q_indices, 128, 128, mask_fn_0), 0.5, {'ROWS_GUARANTEED_SAFE': False, 'PRESCALE_QK': False, 'OUTPUT_LOGSUMEXP': True}, (), ()); l_query_ = l_key_ = l_value_ = score_mod_0 = l_block_mask_kv_num_blocks = l_block_mask_kv_indices = l_block_mask_full_kv_num_blocks = l_block_mask_full_kv_indices = l_block_mask_q_num_blocks = l_block_mask_q_indices = l_block_mask_full_q_num_blocks = l_block_mask_full_q_indices = mask_fn_0 = None
        out: "f64[2, 2, 128, 4]" = flex_attention[0]; flex_attention = None
        return (out,)

    class score_mod_0(torch.nn.Module):
        def forward(self, child: "f64[]", child_1: "i32[]", child_2: "i32[]", child_3: "i32[]", child_4: "i32[]"):
            mul: "f64[]" = child * child; child = None
            return mul

    class mask_fn_0(torch.nn.Module):
        def forward(self, child_5: "i32[]", child_6: "i32[]", child_7: "i32[]", child_8: "i32[]"):
            ge: "b8[]" = child_7 >= child_8; child_7 = child_8 = None
            return ge
""",  # noqa: B950
        )
        # Save the AOT graphs
        aot_graphs = []
        from torch._inductor import compile_fx

        def debug_compile_fx_inner(graph, example_inputs, *args, **kwargs):
            aot_graphs.append(graph)
            return graph

        backend = functools.partial(
            compile_fx.compile_fx, inner_compile=debug_compile_fx_inner
        )
        func = torch.compile(func, backend=backend, fullgraph=True)
        out = func(query, key, value, _squared)
        out.sum().backward()

        joint_graph = normalize_gm(aot_graphs[1].print_readable(print_output=False))

        self.assertExpectedInline(
            joint_graph,
            """\
class GraphModule(torch.nn.Module):
    def forward(self, primals_1: "f64[2, 2, 128, 4]", primals_2: "f64[2, 2, 128, 4]", primals_3: "f64[2, 2, 128, 4]", full: "i32[1, 1, 1]", full_default: "i32[1, 1, 1, 1]", convert_element_type: "i32[1, 1, 1]", convert_element_type_1: "i32[1, 1, 1, 1]", getitem_2: "f64[2, 2, 128, 4]", getitem_3: "f32[2, 2, 128]", tangents_1: "f64[2, 2, 128, 4]"):
        full_default_4: "f32[2, 2, 128]" = torch.ops.aten.full.default([2, 2, 128], 0, dtype = torch.float32, layout = torch.strided, device = device(type='cuda', index=0), pin_memory = False)
        fw_graph = self.fw_graph
        joint_graph = self.joint_graph
        mask_graph = self.mask_graph
        flex_attention_backward = torch.ops.higher_order.flex_attention_backward(primals_1, primals_2, primals_3, getitem_2, getitem_3, tangents_1, full_default_4, fw_graph, joint_graph, (full, full_default, None, None, convert_element_type, convert_element_type_1, None, None, 1073741824, 1073741824, mask_graph), 0.5, {'ROWS_GUARANTEED_SAFE': False, 'PRESCALE_QK': False, 'OUTPUT_LOGSUMEXP': True}, (), ()); primals_1 = primals_2 = primals_3 = getitem_2 = getitem_3 = tangents_1 = full_default_4 = fw_graph = joint_graph = full = full_default = convert_element_type = convert_element_type_1 = mask_graph = None
        getitem_4: "f64[2, 2, 128, 4]" = flex_attention_backward[0]
        getitem_5: "f64[2, 2, 128, 4]" = flex_attention_backward[1]
        getitem_6: "f64[2, 2, 128, 4]" = flex_attention_backward[2]; flex_attention_backward = None
        return (getitem_4, getitem_5, getitem_6)

    class fw_graph(torch.nn.Module):
        def forward(self, arg0_1: "f64[]", arg1_1: "i32[]", arg2_1: "i32[]", arg3_1: "i32[]", arg4_1: "i32[]"):
            mul: "f64[]" = torch.ops.aten.mul.Tensor(arg0_1, arg0_1); arg0_1 = None
            return mul

    class joint_graph(torch.nn.Module):
        def forward(self, arg0_1: "f64[]", arg1_1: "i32[]", arg2_1: "i32[]", arg3_1: "i32[]", arg4_1: "i32[]", arg5_1: "f64[]"):
            mul: "f64[]" = torch.ops.aten.mul.Tensor(arg0_1, arg0_1); mul = None
            mul_1: "f64[]" = torch.ops.aten.mul.Tensor(arg5_1, arg0_1)
            mul_2: "f64[]" = torch.ops.aten.mul.Tensor(arg5_1, arg0_1); arg5_1 = arg0_1 = None
            add: "f64[]" = torch.ops.aten.add.Tensor(mul_2, mul_1); mul_2 = mul_1 = None
            return [add, None, None, None, None]

    class mask_graph(torch.nn.Module):
        def forward(self, arg0_1: "i32[]", arg1_1: "i32[]", arg2_1: "i32[]", arg3_1: "i32[]"):
            full: "b8[]" = torch.ops.aten.full.default([], True, dtype = torch.bool, layout = torch.strided, device = device(type='cuda', index=0), pin_memory = False)
            return full
""",  # noqa: B950
        )


class TestBlockMask(InductorTestCase):
    @supported_platform
    def test_block_mask_attributes(self):
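        # Shape, numel, and sparsity accessors of BlockMask, including per-batch/head views.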
        offset = torch.zeros(8, device="cuda")

        def causal_mask(b, h, q, kv):
            return (q + (offset[b] * 128)) >= kv

        block_mask = create_block_mask(causal_mask, 4, 2, 2048, 2048)
        self.assertEqual(block_mask.shape, (4, 2, 2048, 2048))
        self.assertEqual(block_mask[0].shape, (2, 2048, 2048))
        self.assertEqual(block_mask[0, 0].shape, (2048, 2048))
        self.assertEqual(block_mask.numel(), 4 * 2 * 2048 * 2048)
        self.assertEqual(block_mask.sparsity(), 46.875)
        self.assertEqual(block_mask[0].sparsity(), 46.875)
        self.assertEqual(block_mask[1, 0].sparsity(), 46.875)
        self.assertEqual(block_mask.sparsity(), block_mask[1].sparsity())

        offset = torch.arange(8, device="cuda")
        block_mask = create_block_mask(causal_mask, 8, 1, 2048, 2048)
        self.assertEqual(block_mask.sparsity(), 29.1015625)
        self.assertTrue(block_mask.sparsity() < block_mask[0].sparsity())
        self.assertTrue(block_mask[0].sparsity() > block_mask[1].sparsity())

    @supported_platform
    def test_getitem(self):
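        # Indexing and slicing a BlockMask should slice the underlying block tensors
        # consistently.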
        offset = torch.zeros(8, device="cuda")

        def causal_mask(b, h, q, kv):
            return (q + (offset[b] * 128)) >= kv

        block_mask = create_block_mask(causal_mask, 4, 2, 512, 512)
        assert block_mask.kv_num_blocks.shape == (4, 2, 4)
        assert block_mask.kv_indices.shape == (4, 2, 4, 4)

        # Index on batch dimension
        new_block_mask = block_mask[0]
        assert new_block_mask.kv_num_blocks.shape == (2, 4)
        assert new_block_mask.kv_indices.shape == (2, 4, 4)

        # Index on batch and head dimension
        new_block_mask = block_mask[0, 1]
        assert new_block_mask.kv_num_blocks.shape == (4,)
        assert new_block_mask.kv_indices.shape == (4, 4)

        # slicing on batch and head dimension
        new_block_mask = block_mask[0:2, 1:2]
        assert new_block_mask.kv_num_blocks.shape == (2, 1, 4)
        assert new_block_mask.kv_indices.shape == (2, 1, 4, 4)

        # slicing on batch, head, and query dimension
        new_block_mask = block_mask[0:2, 1:2, torch.tensor([1], dtype=torch.int32)]
        assert new_block_mask.kv_num_blocks.shape == (2, 1, 1)
        assert new_block_mask.kv_indices.shape == (2, 1, 1, 4)

        # slicing on batch, head, and query dimension
        q_index = torch.tensor([0], dtype=torch.int32)
        new_block_mask = block_mask[:, :, q_index]

        self.assertEqual(new_block_mask.kv_num_blocks.ndim, 3)
        self.assertEqual(new_block_mask.kv_indices.ndim, 4)
        torch.testing.assert_close(
            new_block_mask.kv_num_blocks,
            block_mask.kv_num_blocks[:, :, q_index],
        )
        torch.testing.assert_close(
            new_block_mask.kv_indices, block_mask.kv_indices[:, :, q_index, :]
        )

        if block_mask.full_kv_num_blocks is not None:
            assert new_block_mask.full_kv_num_blocks is not None
            assert new_block_mask.full_kv_indices is not None
            torch.testing.assert_close(
                new_block_mask.full_kv_num_blocks,
                block_mask.full_kv_num_blocks[:, :, q_index],
            )
            torch.testing.assert_close(
                new_block_mask.full_kv_indices,
                block_mask.full_kv_indices[:, :, q_index, :],
            )

    @supported_platform
    def test_block_mask_device_change(self):
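        # BlockMask.to() should move all of its block tensors between devices.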
        offset = torch.zeros(8, device="cuda")

        def causal_mask(b, h, q, kv):
            return (q + (offset[b] * 128)) >= kv

        block_mask = create_block_mask(causal_mask, 1, 1, 512, 512)
        assert block_mask.kv_indices.is_cuda
        assert block_mask.kv_num_blocks.is_cuda
        assert block_mask.q_indices.is_cuda
        assert block_mask.q_num_blocks.is_cuda

        block_mask = block_mask.to("cpu")
        assert block_mask.kv_indices.is_cpu
        assert block_mask.kv_num_blocks.is_cpu
        assert block_mask.q_indices.is_cpu
        assert block_mask.q_num_blocks.is_cpu

        block_mask = block_mask.to("cuda")
        assert block_mask.kv_indices.is_cuda
        assert block_mask.kv_num_blocks.is_cuda
        assert block_mask.q_indices.is_cuda
        assert block_mask.q_num_blocks.is_cuda

    @supported_platform
    def test_compiling_create_block_mask(self):
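        # create_block_mask(_compile=True) should still return a regular BlockMask.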
        def mask_mod(b, h, q, kv):
            return q >= kv

        block_mask = create_block_mask(mask_mod, 1, 1, 512, 512, _compile=True)
        self.assertIsInstance(block_mask, BlockMask)
        self.assertEqual(block_mask.kv_num_blocks.shape, torch.Size((1, 1, 4)))
        self.assertEqual(block_mask.kv_indices.shape, torch.Size((1, 1, 4, 4)))

    @supported_platform
    def test_block_mask_viz(self):
        def causal_mask(b, h, q, kv):
            return q >= kv

        block_mask = create_block_mask(causal_mask, 1, 1, 2048, 2048)

        def replace_non_printable(s):
            def replace(c):
                if c not in string.printable:
                    return "@"
                elif c == " ":
                    return "s"
                return c

            return "".join(replace(c) for c in s)

        self.assertExpectedInline(
            replace_non_printable(str(block_mask)),
            """\
BlockMask(shape=(1,s1,s2048,s2048),ssparsity=46.88%,s
(0,s0)
@@ssssssssssssssssssssssssssssss
@@@@ssssssssssssssssssssssssssss
@@@@@@ssssssssssssssssssssssssss
@@@@@@@@ssssssssssssssssssssssss
@@@@@@@@@@ssssssssssssssssssssss
@@@@@@@@@@@@ssssssssssssssssssss
@@@@@@@@@@@@@@ssssssssssssssssss
@@@@@@@@@@@@@@@@ssssssssssssssss
@@@@@@@@@@@@@@@@@@ssssssssssssss
@@@@@@@@@@@@@@@@@@@@ssssssssssss
@@@@@@@@@@@@@@@@@@@@@@ssssssssss
@@@@@@@@@@@@@@@@@@@@@@@@ssssssss
@@@@@@@@@@@@@@@@@@@@@@@@@@ssssss
@@@@@@@@@@@@@@@@@@@@@@@@@@@@ssss
@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@ss
@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@
)""",
        )

        offset = torch.arange(8, device="cuda")

        def causal_offset_mask(b, h, q, kv):
            return (q + offset[b] * 128) >= kv

        block_mask = create_block_mask(causal_offset_mask, 8, 1, 2048, 2048)
        str_block_mask = str(block_mask)
        self.assertTrue("sparsity=29.10" in str_block_mask)

    def generate_test_inputs(self, full_seq_len: bool, device):
        if full_seq_len:
            kv_num_blocks = torch.tensor([1], dtype=torch.int32, device=device).view(
                1, 1, 1
            )
            kv_indices = torch.tensor([1, -1], dtype=torch.int32, device=device).view(
                1, 1, 1, 2
            )
            full_kv_num_blocks = torch.tensor(
                [1], dtype=torch.int32, device=device
            ).view(1, 1, 1)
            full_kv_indices = torch.tensor(
                [0, -1], dtype=torch.int32, device=device
            ).view(1, 1, 1, 2)
        else:
            kv_num_blocks = torch.tensor([2], dtype=torch.int32, device=device).view(
                1, 1, 1
            )
            kv_indices = torch.tensor([0, 1], dtype=torch.int32, device=device).view(
                1, 1, 1, 2
            )
            full_kv_indices = None
            full_kv_num_blocks = None
        return kv_num_blocks, kv_indices, full_kv_num_blocks, full_kv_indices

    @supported_platform
    @common_utils.parametrize("full_indices", [False, True])
    def test_from_kv_blocks(self, full_indices: bool):
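        # BlockMask.from_kv_blocks should populate the kv_* fields and derive the
        # corresponding q_* fields.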
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        (
            kv_num_blocks,
            kv_indices,
            full_kv_num_blocks,
            full_kv_indices,
        ) = self.generate_test_inputs(full_indices, device=device)

        block_mask = BlockMask.from_kv_blocks(
            kv_num_blocks, kv_indices, full_kv_num_blocks, full_kv_indices
        )

        self.assertIsInstance(block_mask, BlockMask)
        torch.testing.assert_close(block_mask.kv_num_blocks, kv_num_blocks)
        torch.testing.assert_close(block_mask.kv_indices, kv_indices)

        if full_indices:
            torch.testing.assert_close(
                block_mask.full_kv_num_blocks, full_kv_num_blocks
            )
            torch.testing.assert_close(block_mask.full_kv_indices, full_kv_indices)
            torch.testing.assert_close(
                block_mask.q_num_blocks,
                torch.tensor([0, 1], dtype=torch.int32, device=device).view(1, 1, 2),
            )
            torch.testing.assert_close(
                block_mask.q_indices,
                torch.tensor([0, 0], dtype=torch.int32, device=device).view(1, 1, 2, 1),
            )
            torch.testing.assert_close(
                block_mask.full_q_num_blocks,
                torch.tensor([1, 0], dtype=torch.int32, device=device).view(1, 1, 2),
            )
            torch.testing.assert_close(
                block_mask.full_q_indices,
                torch.tensor([0, 0], dtype=torch.int32, device=device).view(1, 1, 2, 1),
            )

        else:
            torch.testing.assert_close(
                block_mask.q_num_blocks,
                torch.tensor([1, 1], dtype=torch.int32, device=device).view(1, 1, 2),
            )
            torch.testing.assert_close(
                block_mask.q_indices,
                torch.tensor([0, 0], dtype=torch.int32, device=device).view(1, 1, 2, 1),
            )
            self.assertIsNone(block_mask.full_kv_num_blocks)
            self.assertIsNone(block_mask.full_kv_indices)
            self.assertIsNone(block_mask.full_q_num_blocks)
            self.assertIsNone(block_mask.full_q_indices)

    @supported_platform
    def test_block_size(self):
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        kv_num_blocks, kv_indices, _, _ = self.generate_test_inputs(False, device)
        block_mask = BlockMask.from_kv_blocks(kv_num_blocks, kv_indices)
        self.assertEqual(
            block_mask.BLOCK_SIZE,
            (_DEFAULT_SPARSE_BLOCK_SIZE, _DEFAULT_SPARSE_BLOCK_SIZE),
        )

        custom_block_size = (64, 64)
        block_mask_custom = BlockMask.from_kv_blocks(
            kv_num_blocks, kv_indices, BLOCK_SIZE=custom_block_size
        )
        self.assertEqual(block_mask_custom.BLOCK_SIZE, custom_block_size)

    @supported_platform
    def test_init_mismatched_full_kv(self):
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        kv_num_blocks, kv_indices, full_kv_num_blocks, _ = self.generate_test_inputs(
            True, device
        )

        with self.assertRaises(AssertionError):
            BlockMask(
                kv_num_blocks=kv_num_blocks,
                kv_indices=kv_indices,
                full_kv_num_blocks=full_kv_num_blocks,
                full_kv_indices=None,  # Mismatched, should raise error
                q_num_blocks=kv_num_blocks,
                q_indices=kv_indices,
                full_q_num_blocks=None,
                full_q_indices=None,
                BLOCK_SIZE=(64, 64),
                mask_mod=noop_mask,
            )

    @supported_platform
    def test_init_mismatched_full_q(self):
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        kv_num_blocks, kv_indices, _, _ = self.generate_test_inputs(False, device)

        with self.assertRaises(AssertionError):
            BlockMask(
                kv_num_blocks=kv_num_blocks,
                kv_indices=kv_indices,
                full_kv_num_blocks=None,
                full_kv_indices=None,
                q_num_blocks=kv_num_blocks,
                q_indices=kv_indices,
                full_q_num_blocks=kv_num_blocks,
                full_q_indices=None,  # Mismatched, should raise error
                BLOCK_SIZE=(64, 64),
                mask_mod=noop_mask,
            )

    @supported_platform
    @common_utils.parametrize("compile", [False, True])
    def test_no_q_info(self, compile: bool):
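        # FlexAttention should still run when only kv block info is present
        # (q_* entries set to None).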
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

        def causal_mask(b, h, q_idx, kv_idx):
            return q_idx >= kv_idx

        block_mask = create_block_mask(causal_mask, 1, 1, 2048, 2048)
        # manually set q_num_blocks and q_indices to None
        block_mask.q_num_blocks = None
        block_mask.q_indices = None
        block_mask.full_q_num_blocks = None
        block_mask.full_q_indices = None

        mask_mod_sparse_flex = functools.partial(flex_attention, block_mask=block_mask)
        if compile:
            mask_mod_sparse_flex = torch.compile(
                mask_mod_sparse_flex, backend="inductor"
            )
        inputs = [
            torch.randn(
                2,
                2,
                2048,
                64,
                device="cuda",
                dtype=torch.float16,
                requires_grad=True,
            )
            for _ in range(3)
        ]

        causal_mask_out = mask_mod_sparse_flex(*inputs)
        sdpa_mask_out = torch.nn.functional.scaled_dot_product_attention(
            *inputs, is_causal=True
        )

        torch.testing.assert_close(causal_mask_out, sdpa_mask_out, atol=5e-3, rtol=0.0)


common_utils.instantiate_parametrized_tests(TestFlexAttention)
common_utils.instantiate_parametrized_tests(TestBlockMask)

if __name__ == "__main__":
    from torch._inductor.test_case import run_tests

    run_tests()