# Owner(s): ["module: inductor"] # flake8: noqa: B950 import functools import random import string import unittest import warnings from collections import namedtuple from collections.abc import Callable from contextlib import contextmanager from dataclasses import dataclass from itertools import product from typing import Optional, TypeVar, Union from unittest import expectedFailure, skip, skipUnless from unittest.mock import patch import torch import torch.nn as nn from torch._dynamo.testing import CompileCounterWithBackend, normalize_gm from torch._inductor import config, metrics from torch._inductor.runtime.triton_compat import HAS_WARP_SPEC from torch._inductor.test_case import TestCase as InductorTestCase from torch._inductor.utils import run_and_get_code from torch.nn.attention import SDPBackend from torch.nn.attention.experimental._paged_attention import PagedAttention from torch.nn.attention.flex_attention import ( _create_empty_block_mask, _DEFAULT_SPARSE_BLOCK_SIZE, _identity, _mask_mod_signature, _score_mod_signature, _WARNINGS_SHOWN, and_masks, AuxOutput, AuxRequest, BlockMask, create_block_mask, flex_attention, flex_attention_hop, noop_mask, or_masks, ) from torch.testing import FileCheck from torch.testing._internal import common_utils from torch.testing._internal.common_cuda import PLATFORM_SUPPORTS_BF16, TEST_MULTIGPU from torch.testing._internal.common_device_type import ( dtypes, dtypesIfCUDA, dtypesIfXPU, flex_attention_supported_platform as supported_platform, instantiate_device_type_tests, largeTensorTest, skipCPUIf, skipCUDAIf, skipXPUIf, ) from torch.testing._internal.inductor_utils import HAS_GPU from torch.utils._triton import has_triton, has_triton_tma_device # Use this decorator only when hitting Triton bugs on H100 running_on_a100_only = skipUnless( ( (torch.cuda.is_available() and has_triton()) and (torch.cuda.get_device_capability() == (8, 0) or torch.version.hip) ) or (torch.xpu.is_available() and has_triton()), "Requires Triton + A100 or Triton + ROCm or Triton + Intel GPU", ) Tolerances = namedtuple("Tolerances", ["atol", "rtol"]) torch.set_float32_matmul_precision("high") index = torch.ops.aten.index Tensor = torch.Tensor T = TypeVar("T") M = TypeVar("M", bound=Callable) def large_tensor_test_class( size: str, device: Optional[Union[torch.device, str]] = None ) -> Callable[[type[T]], type[T]]: def decorator(cls: type[T]) -> type[T]: for name, method in list(cls.__dict__.items()): if callable(method) and name.startswith("test_"): setattr(cls, name, largeTensorTest(size, device)(method)) return cls return decorator @contextmanager def temp_float32_matmul_precision(precision: str): """ Temporarily set the float32 matmul precision and restore it after the context is exited. Args: precision (str): The precision to set ('highest', 'high', or 'medium'). 
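
    Example::

        with temp_float32_matmul_precision("highest"):
            ...  # float32 matmuls in this block run at full fp32 precision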
""" def set_float32_matmul_precision_xpu(precision: str): if precision == "highest": torch._C._set_onednn_allow_tf32(False) if precision == "high": torch._C._set_onednn_allow_tf32(True) original_precision = torch.get_float32_matmul_precision() try: torch.set_float32_matmul_precision(precision) if TEST_ON_XPU: set_float32_matmul_precision_xpu(precision) yield finally: torch.set_float32_matmul_precision(original_precision) if TEST_ON_XPU: set_float32_matmul_precision_xpu(original_precision) def skip_on_cpu(test_func): """Decorator to skip tests that are not supported on CPU.""" decorated_func = skipCPUIf(True, "Not supported on CPU")(test_func) return decorated_func def skip_on_cuda(test_func): """Decorator to skip tests that are not supported on CUDA.""" decorated_func = skipCUDAIf(True, "Not supported on CUDA")(test_func) return decorated_func def skip_on_rocm(test_func): """Decorator to skip tests that are not supported on CUDA.""" IS_ROCM = torch.cuda.is_available() and torch.version.hip decorated_func = skipCUDAIf(IS_ROCM, "Not supported on ROCM")(test_func) return decorated_func def skip_on_xpu(test_func): """Decorator to skip tests that are not supported on Intel GPU.""" decorated_func = skipXPUIf(True, "Not supported on Intel GPU")(test_func) return decorated_func def rmse(ref, res): """ Calculate root mean squared error """ ref = ref.to(torch.float64) res = res.to(torch.float64) return torch.sqrt(torch.mean(torch.square(ref - res))) def create_attention(score_mod, block_mask, enable_gqa=False, kernel_options=None): return functools.partial( flex_attention, score_mod=score_mod, block_mask=block_mask, enable_gqa=enable_gqa, kernel_options=kernel_options, ) def create_block_mask_test(score_mod, query, key): block_mask = create_block_mask( score_mod, 1, 1, query.shape[-2], key.shape[-2], query.device, ) return block_mask @dataclass class DeviceConfig: dtypes: list[torch.dtype] dtypes_fast: list[torch.dtype] TEST_ON_CUDA = ( torch.cuda.is_available() and torch.utils._triton.has_triton() and torch.cuda.get_device_capability() >= (8, 0) ) TEST_ON_XPU = torch.xpu.is_available() and torch.utils._triton.has_triton() device_configs = {} if HAS_GPU: if TEST_ON_CUDA: test_device = ( "cuda", "cpu", ) elif TEST_ON_XPU: torch._C._set_onednn_allow_tf32(True) test_device = ("xpu",) else: test_device = ("cpu",) class SubstringSet: def __init__(self, items): self.items = set(items) def __contains__(self, item): if "cuda" in item: item = "cuda" if "xpu" in item: item = "xpu" return item in self.items DEVICE_SUPPORTS_BACKWARDS = SubstringSet( [ "cuda", ] ) device_configs["cuda"] = DeviceConfig( dtypes=( [torch.float32, torch.bfloat16, torch.float16] if PLATFORM_SUPPORTS_BF16 else [torch.float16, torch.float32] ), dtypes_fast=[torch.float16], ) device_configs["xpu"] = DeviceConfig( dtypes=([torch.float32, torch.bfloat16, torch.float16]), dtypes_fast=[torch.float16], ) device_configs["cpu"] = DeviceConfig( dtypes=( [torch.float32, torch.bfloat16, torch.float16] if torch.backends.mkldnn.is_available() and torch.ops.mkldnn._is_mkldnn_bf16_supported() else [torch.float32] ), dtypes_fast=[torch.float32], ) torch_config_string = torch.__config__.show() LONG_COMPILATION_ON_CPU = False if "CLANG" in torch_config_string.upper(): # if the compiler is clang, skip UT for CPU due to long compilation time found in CI # TODO: check reason of long compile time LONG_COMPILATION_ON_CPU = True # --------- Useful score mod functions for testing --------- def _causal( score: Tensor, batch: Tensor, head: Tensor, token_q: 
Tensor, token_kv: Tensor, ) -> Tensor: return torch.where(token_q >= token_kv, score, float("-inf")) def _rel_bias( score: Tensor, batch: Tensor, head: Tensor, token_q: Tensor, token_kv: Tensor, ) -> Tensor: return score + (token_q - token_kv) def _rel_causal( score: Tensor, batch: Tensor, head: Tensor, token_q: Tensor, token_kv: Tensor, ) -> Tensor: return torch.where(token_q >= token_kv, score + (token_q - token_kv), float("-inf")) def _generate_alibi_bias(num_heads: int): def _alibi_bias( score: Tensor, batch: Tensor, head: Tensor, token_q: Tensor, token_kv: Tensor, ) -> Tensor: scale = torch.exp2(-((head + 1) * 8.0 / num_heads)) return score + (token_kv - token_q) * scale return _alibi_bias def _inverse_causal(score, b, h, m, n): return torch.where(m <= n, score, float("-inf")) def _times_two(score, b, h, m, n): """Joint graph needed for correctness""" return score * 2 def _squared(score, b, h, m, n): """Joint graph needed for correctness""" return score * score def _head_offset(dtype: torch.dtype, device: str): """Captured Buffer""" head_offset = torch.rand(H, device=device, dtype=dtype) def score_mod(score, b, h, m, n): return score * head_offset[h] return score_mod def _trig(score, b, h, m, n): """Joint graph needed for correctness""" return torch.sin(torch.cos(score)) + torch.tan(b) def _trig2(score, b, h, m, n): """Branching joint graph""" cos_score = torch.cos(score) sin_score = torch.sin(score) z = cos_score * sin_score + torch.tan(b) return z # --------- Useful mask mod functions for testing --------- def _causal_mask( batch: Tensor, head: Tensor, token_q: Tensor, token_kv: Tensor, ) -> Tensor: return token_q >= token_kv def _inverse_causal_mask( batch: Tensor, head: Tensor, token_q: Tensor, token_kv: Tensor, ) -> Tensor: return token_q <= token_kv test_score_mods = [ _identity, _times_two, _squared, _causal, _inverse_causal, _rel_bias, _rel_causal, _generate_alibi_bias(8), ] test_score_mask_mod_map = { _identity: noop_mask, _times_two: noop_mask, _squared: noop_mask, _causal: _causal_mask, _inverse_causal: _inverse_causal_mask, _rel_bias: noop_mask, _rel_causal: _causal_mask, _generate_alibi_bias(8): noop_mask, } captured_buffers_map = { "_head_offset": _head_offset, } B = 2 H = 4 S = 256 D = 64 test_Hq_Hkv = [ (4, 2), (4, 1), ] test_Bq_Bkv = [ (3, 1), (4, 1), (5, 1), ] test_block_size = [ 128, 256, (128, 256), (256, 128), ] test_strides = [ ((H * S * D, S * D, D, 1), 997), # offset ((H * D, D, B * H * D, 1), 499), # transposed dimensions ((H * S * D, D, H * D, 1), 0), # heads/sequence transposed ( (S * (D + 1), B * S * (D + 1), (D + 1), 1), 293, ), # additional buffer on one dim ( (1, D, (B + 1) * (H + 1) * D, 1), 97, ), # additional buffer on multiple dim + shared dimension ] def query_key_value_clones( query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, dtype: Optional[torch.dtype] = None, ): """Clones the query, key, and value tensors and moves them to the specified dtype.""" if dtype is None: dtype = query.dtype query_ref = query.detach().clone().to(dtype).requires_grad_(query.requires_grad) key_ref = key.detach().clone().to(dtype).requires_grad_(key.requires_grad) value_ref = value.detach().clone().to(dtype).requires_grad_(value.requires_grad) return query_ref, key_ref, value_ref def batch_reserve(paged_attention: PagedAttention, target_seq_len: Tensor): (B,) = target_seq_len.shape for b in range(B): paged_attention.reserve( torch.tensor(b), target_seq_len[b], ) @large_tensor_test_class("2GB", device=test_device[0]) class TestFlexAttention(InductorTestCase): def 
setUp(self):
        super().setUp()
        skipCPUIf(
            LONG_COMPILATION_ON_CPU,
            "skip UT for CPU due to long compilation time found in CI",
        )

    def _check_equal(
        self,
        golden_out: torch.Tensor,
        ref_out: torch.Tensor,
        compiled_out: torch.Tensor,
        fudge_factor: float,
        tensor_name: Optional[str] = None,
        fudge_atol: float = 0,
    ):
        compiled_error = (golden_out - compiled_out).abs().mean()
        ref_error = (golden_out - ref_out).abs().mean()
        if torch.isnan(compiled_error).any() or torch.isnan(ref_error).any():
            self.fail("Output/Grad with NaN")
        name = tensor_name if tensor_name is not None else ""
        msg = f"{name} Compiled error {compiled_error} is greater than ref error {ref_error} by more than {fudge_factor}X."
        torch.testing.assert_close(
            compiled_error, ref_error, rtol=fudge_factor, atol=1e-7, msg=msg
        )

    def _check_out(
        self,
        golden_out: torch.Tensor,
        ref_out: torch.Tensor,
        compiled_out: torch.Tensor,
        is_paged_attention: bool = False,
    ):
        dtype = ref_out.dtype
        with torch.no_grad():
            # Note, it seems like we really are less accurate than the float32
            # computation, likely due to the online softmax
            if dtype == torch.float32:
                fudge_factor = 10.0
                if is_paged_attention:
                    # paged attention is less accurate since it may reorder
                    # the blocks from block mask
                    fudge_factor = 20.0
            else:
                fudge_factor = 1.1

            # Check output
            self._check_equal(golden_out, ref_out, compiled_out, fudge_factor, "Out")

    def _check_out_and_grad(
        self,
        golden_out: torch.Tensor,
        ref_out: torch.Tensor,
        compiled_out: torch.Tensor,
        q_gold: torch.Tensor,
        q_ref: torch.Tensor,
        q: torch.Tensor,
        k_gold: torch.Tensor,
        k_ref: torch.Tensor,
        k: torch.Tensor,
        v_gold: torch.Tensor,
        v_ref: torch.Tensor,
        v: torch.Tensor,
    ):
        dtype = ref_out.dtype
        with torch.no_grad():
            # Note, it seems like we really are less accurate than the float32
            # computation, likely due to the online softmax
            if dtype == torch.float32:
                fudge_factor = 10.0
            else:
                fudge_factor = 1.1

            # Check output
            self._check_equal(golden_out, ref_out, compiled_out, fudge_factor, "Out")

            # Check gradients
            q_fudge_factor = 1.0 * fudge_factor
            self._check_equal(
                q_gold.grad, q_ref.grad, q.grad, q_fudge_factor, "Grad_Query"
            )
            k_fudge_factor = 1.0 * fudge_factor
            self._check_equal(
                k_gold.grad, k_ref.grad, k.grad, k_fudge_factor, "Grad_Key"
            )
            v_fudge_factor = 1.0 * fudge_factor
            self._check_equal(
                v_gold.grad, v_ref.grad, v.grad, v_fudge_factor, "Grad_Value"
            )

    def run_test(
        self,
        score_mod: _score_mod_signature,
        dtype: torch.dtype,
        device: str,
        Q_B: int = B,
        Q_H: int = H,
        Q_S: int = S,
        Q_D: int = D,
        KV_B: Optional[int] = None,
        KV_H: Optional[int] = None,
        KV_S: Optional[int] = None,
        V_D: Optional[int] = None,
        block_mask: Optional[BlockMask] = None,
    ):
        requires_grad = device in DEVICE_SUPPORTS_BACKWARDS
        if KV_B is None:
            KV_B = Q_B
        if KV_H is None:
            KV_H = Q_H
        if KV_S is None:
            KV_S = Q_S
        if V_D is None:
            V_D = Q_D
        if device == "cpu" and dtype is torch.float16:
            dtype = torch.float32
        requires_grad = device in DEVICE_SUPPORTS_BACKWARDS
        q = torch.randn(
            (Q_B, Q_H, Q_S, Q_D),
            dtype=dtype,
            device=device,
            requires_grad=requires_grad,
        )
        k = torch.randn(
            (KV_B, KV_H, KV_S, Q_D),
            dtype=dtype,
            device=device,
            requires_grad=requires_grad,
        )
        v = torch.randn(
            (KV_B, KV_H, KV_S, V_D),
            dtype=dtype,
            device=device,
            requires_grad=requires_grad,
        )
        if block_mask is None:
            block_mask = create_block_mask(
                noop_mask, Q_B, Q_H, Q_S, KV_S, device=device
            )
        q_ref, k_ref, v_ref = query_key_value_clones(q, k, v)
        q_gold, k_gold, v_gold = query_key_value_clones(q, k, v, torch.float64)
        sdpa_partial = create_attention(
            score_mod, block_mask, enable_gqa=(not Q_H == KV_H)
        )
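
        # The three-way comparison below compiles sdpa_partial and checks that the
        # compiled output's error (vs. the float64 "golden" run) stays within
        # fudge_factor of the eager reference error; see _check_equal.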
compiled_sdpa = torch.compile(sdpa_partial) golden_out = sdpa_partial(q_gold, k_gold, v_gold) ref_out = sdpa_partial(q_ref, k_ref, v_ref) compiled_out = compiled_sdpa(q, k, v) assert isinstance(golden_out, torch.Tensor) assert isinstance(ref_out, torch.Tensor) assert isinstance(compiled_out, torch.Tensor) if not requires_grad: self._check_out( golden_out, ref_out, compiled_out, is_paged_attention=False, ) else: backward_grad = torch.randn( (Q_B, Q_H, Q_S, V_D), dtype=dtype, device=device ) golden_out.backward(backward_grad.to(torch.float64)) ref_out.backward(backward_grad) compiled_out.backward(backward_grad) self._check_out_and_grad( golden_out, ref_out, compiled_out, q_gold, q_ref, q, k_gold, k_ref, k, v_gold, v_ref, v, ) def preprocess_paged_attention( self, score_mod: Optional[Callable], q: Tensor, k: Tensor, v: Tensor, block_mask, dtype: torch.dtype, device: str, page_size: int = 128, ) -> tuple[Tensor, Tensor, BlockMask, _score_mod_signature]: assert block_mask is not None, "Must provide block_mask" Q_B, Q_H, Q_S, _ = q.shape KV_B, KV_H, KV_S, QK_D = k.shape _, _, _, V_D = v.shape # test with different batch size max_batch_size = max(Q_B, KV_B) + 3 n_pages = (KV_S + page_size - 1) // page_size * max_batch_size # allocate cache MAX_CACHED_SEQ_LEN = n_pages * page_size k_cache = torch.zeros( 1, KV_H, MAX_CACHED_SEQ_LEN, QK_D, device=device, dtype=dtype, ) v_cache = torch.zeros( 1, KV_H, MAX_CACHED_SEQ_LEN, V_D, device=device, dtype=dtype, ) # For testing purposes, we randomly initialize the page table, which maps # (batch_idx, logical_block_idx) to physical_block_idx. Specifically, PagedAttention # maintains a stack empty_pages of unused physical_block_idx. The `batch_reserve` # function grabs physical_block_idx from the top of empty_pages until there are enough # pages for each batch index (i.e., num pages for batch_idx >= target_seq_len[batch_idx]). # For example, at the first batch_reserve call, physical block indices (1,...,KV_S//4) # are allocated to batch index 0, and physical block indices # (KV_S//4+1, ..., KV_S//4 + KV_S//2) are allocated to batch index 1, etc. # Thus, kv tensors of batch index 1 will be scattered in the kv cache, simulating # a real use case of paged attention. 
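        # The sequence below follows that recipe: staggered batch_reserve calls
        # fragment each batch's pages, assign() scatters k/v into the shared
        # caches, and convert_logical_block_mask / get_score_mod remap the block
        # mask and score mod onto the physical page layout.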
paged_attention = PagedAttention( n_pages, page_size, max_batch_size, device=device ) batch_reserve( paged_attention, torch.tensor([KV_S // 4, KV_S // 2, KV_S // 4, KV_S // 3], device=device), ) batch_reserve( paged_attention, torch.tensor([KV_S // 4, KV_S // 2, KV_S // 2, KV_S // 2], device=device), ) batch_reserve( paged_attention, torch.tensor([KV_S // 2, KV_S, KV_S // 2, KV_S], device=device), ) batch_reserve( paged_attention, torch.tensor([KV_S, KV_S, KV_S, KV_S], device=device) ) # update cache with k and v input_pos = ( torch.arange(KV_S, device=device, dtype=torch.int32) .unsqueeze(0) .expand(KV_B, KV_S) ) batch_idx = torch.arange(KV_B, device=device, dtype=torch.int32) paged_attention.assign(batch_idx, input_pos, k, v, k_cache, v_cache) # convert block mask and score mod kv_len_tensor = torch.full((KV_B,), KV_S, device=device, dtype=torch.int64) converted_block_mask = paged_attention.convert_logical_block_mask( block_mask, kv_len=kv_len_tensor ) converted_score_mod = paged_attention.get_score_mod( score_mod, kv_len=kv_len_tensor ) return k_cache, v_cache, converted_block_mask, converted_score_mod def run_paged_attention( self, score_mod: Optional[Callable], q: Tensor, k: Tensor, v: Tensor, dtype: torch.dtype, device: str, block_mask: Optional[BlockMask] = None, kernel_options: Optional[dict] = None, ) -> tuple[Tensor, Tensor]: B, Q_H, Q_S, KV_H, KV_S = ( q.shape[0], q.shape[1], q.shape[2], k.shape[1], k.shape[2], ) if block_mask is None: block_mask = create_block_mask(noop_mask, B, 1, Q_S, KV_S, device=device) ( k_cache, v_cache, converted_block_mask, converted_score_mod, ) = self.preprocess_paged_attention( score_mod, q, k, v, block_mask, dtype, device, block_mask.BLOCK_SIZE[1] ) compiled_sdpa = torch.compile(flex_attention) # compute return_lse = True requires_grad = device in DEVICE_SUPPORTS_BACKWARDS if requires_grad: compiled_out, compiled_lse = compiled_sdpa( q, k_cache, v_cache, return_lse=return_lse, block_mask=converted_block_mask, score_mod=converted_score_mod, enable_gqa=(not Q_H == KV_H), kernel_options=kernel_options, ) else: return_lse = False compiled_lse = None compiled_out = compiled_sdpa( q, k_cache, v_cache, return_lse=return_lse, block_mask=converted_block_mask, score_mod=converted_score_mod, enable_gqa=(not Q_H == KV_H), kernel_options=kernel_options, ) return compiled_out, compiled_lse def run_test_with_paged_attention( self, score_mod: Optional[Callable], dtype: torch.dtype, device, Q_B: int = B, Q_H: int = H, Q_S: int = S, QK_D: int = D, KV_B: int = B, KV_H: int = H, KV_S: int = S, V_D: int = D, block_mask: Optional[BlockMask] = None, ): assert Q_H % KV_H == 0 if device == "cpu" and dtype is torch.float16: dtype = torch.float32 q = torch.randn( (Q_B, Q_H, Q_S, QK_D), dtype=dtype, device=device, requires_grad=False ) k = torch.randn( (KV_B, KV_H, KV_S, QK_D), dtype=dtype, device=device, requires_grad=False, ) v = torch.randn( (KV_B, KV_H, KV_S, V_D), dtype=dtype, device=device, requires_grad=False, ) q_ref, k_ref, v_ref = query_key_value_clones(q, k, v) q_gold, k_gold, v_gold = query_key_value_clones(q, k, v, torch.float64) if block_mask is None: block_mask = create_block_mask(noop_mask, Q_B, 1, Q_S, KV_S, device=device) sdpa_partial = create_attention( score_mod, block_mask, enable_gqa=(not Q_H == KV_H) ) golden_out, golden_lse = sdpa_partial(q_gold, k_gold, v_gold, return_lse=True) ref_out, ref_lse = sdpa_partial(q_ref, k_ref, v_ref, return_lse=True) compiled_out, compiled_lse = self.run_paged_attention( score_mod, q, k, v, dtype, device, block_mask ) 
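
        # Paged attention may visit KV blocks in a different order than the dense
        # kernel, so the checks below use the looser paged-attention fudge factor.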
self._check_out( golden_out, ref_out, compiled_out, is_paged_attention=True, ) requires_grad = device in DEVICE_SUPPORTS_BACKWARDS if requires_grad: self._check_out( golden_lse, ref_lse, compiled_lse, is_paged_attention=True, ) def run_test_with_call( self, sdpa_call: Callable, dtype: torch.dtype, device: str, Q_B: int = B, Q_H: int = H, Q_S: int = S, Q_D: int = D, KV_B: int = B, KV_H: int = H, KV_S: int = S, V_D: int = D, ): if device == "cpu" and dtype is torch.float16: dtype = torch.float32 requires_grad = device in DEVICE_SUPPORTS_BACKWARDS q = torch.randn( (Q_B, Q_H, Q_S, Q_D), dtype=dtype, device=device, requires_grad=requires_grad, ) k = torch.randn( (KV_B, KV_H, KV_S, Q_D), dtype=dtype, device=device, requires_grad=requires_grad, ) v = torch.randn( (KV_B, KV_H, KV_S, V_D), dtype=dtype, device=device, requires_grad=requires_grad, ) q_ref, k_ref, v_ref = query_key_value_clones(q, k, v) q_gold, k_gold, v_gold = query_key_value_clones(q, k, v, torch.float64) compiled_sdpa = torch.compile(sdpa_call) golden_out = sdpa_call(q_gold, k_gold, v_gold) ref_out = sdpa_call(q_ref, k_ref, v_ref) compiled_out = compiled_sdpa(q, k, v) if not requires_grad: self._check_out( golden_out, ref_out, compiled_out, is_paged_attention=False, ) else: backward_grad = torch.randn( (Q_B, Q_H, Q_S, V_D), dtype=dtype, device=device ) golden_out.backward(backward_grad.to(torch.float64)) ref_out.backward(backward_grad) compiled_out.backward(backward_grad) self._check_out_and_grad( golden_out, ref_out, compiled_out, q_gold, q_ref, q, k_gold, k_ref, k, v_gold, v_ref, v, ) def run_dynamic_test( self, score_mask_mod: tuple[Callable, Callable], dtype: torch.dtype, device, B: int = B, H: int = H, S: int = S, D: int = D, ): if device == "cpu" and dtype is torch.float16: dtype = torch.float32 score_mod, mask_mod = score_mask_mod # First batch with original dimensions (B, H, S, D) block_mask1 = create_block_mask(mask_mod, 1, 1, S, S, device=device) sdpa_partial1 = create_attention(score_mod, block_mask=block_mask1) requires_grad = device in DEVICE_SUPPORTS_BACKWARDS q1 = torch.randn( (B, H, S, D), dtype=dtype, device=device, requires_grad=requires_grad, ) k1 = torch.randn( (B, H, S, D), dtype=dtype, device=device, requires_grad=requires_grad, ) v1 = torch.randn( (B, H, S, D), dtype=dtype, device=device, requires_grad=requires_grad, ) q1_ref, k1_ref, v1_ref = query_key_value_clones(q1, k1, v1) q1_gold, k1_gold, v1_gold = query_key_value_clones(q1, k1, v1, torch.float64) ref_out1 = sdpa_partial1(q1_ref, k1_ref, v1_ref) golden_out1 = sdpa_partial1(q1_gold, k1_gold, v1_gold) if requires_grad: backward_grad1 = torch.randn((B, H, S, D), dtype=dtype, device=device) golden_out1.backward(backward_grad1.to(torch.float64)) ref_out1.backward(backward_grad1) # Second batch with modified dimensions (B * 2, H, S / 2, D) B = int(B * 2) S = int(S / 2) block_mask2 = create_block_mask(mask_mod, 1, 1, S, S, device=device) sdpa_partial2 = create_attention(score_mod, block_mask=block_mask2) q2 = torch.randn( (B, H, S, D), dtype=dtype, device=device, requires_grad=requires_grad, ) k2 = torch.randn( (B, H, S, D), dtype=dtype, device=device, requires_grad=requires_grad, ) v2 = torch.randn( (B, H, S, D), dtype=dtype, device=device, requires_grad=requires_grad, ) q2_ref, k2_ref, v2_ref = query_key_value_clones(q2, k2, v2) q2_gold, k2_gold, v2_gold = query_key_value_clones(q2, k2, v2, torch.float64) ref_out2 = sdpa_partial2(q2_ref, k2_ref, v2_ref) golden_out2 = sdpa_partial2(q2_gold, k2_gold, v2_gold) if requires_grad: backward_grad2 = torch.randn((B, 
H, S, D), dtype=dtype, device=device) golden_out2.backward(backward_grad2.to(torch.float64)) ref_out2.backward(backward_grad2) # Third batch with modified dimensions (B * 2, H, S / 4, D) S = int(S / 2) block_mask3 = create_block_mask(mask_mod, 1, 1, S, S, device=device) sdpa_partial3 = create_attention(score_mod, block_mask=block_mask3) q3 = torch.randn( (B, H, S, D), dtype=dtype, device=device, requires_grad=requires_grad, ) k3 = torch.randn( (B, H, S, D), dtype=dtype, device=device, requires_grad=requires_grad, ) v3 = torch.randn( (B, H, S, D), dtype=dtype, device=device, requires_grad=requires_grad, ) q3_ref, k3_ref, v3_ref = query_key_value_clones(q3, k3, v3) q3_gold, k3_gold, v3_gold = query_key_value_clones(q3, k3, v3, torch.float64) ref_out3 = sdpa_partial3(q3_ref, k3_ref, v3_ref) golden_out3 = sdpa_partial3(q3_gold, k3_gold, v3_gold) if requires_grad: backward_grad3 = torch.randn((B, H, S, D), dtype=dtype, device=device) golden_out3.backward(backward_grad3.to(torch.float64)) ref_out3.backward(backward_grad3) # Clear dynamo counters torch._dynamo.reset() # First compilation with original dimensions backend = torch._dynamo.testing.CompileCounterWithBackend("inductor") compiled_sdpa1 = torch.compile(sdpa_partial1, backend=backend, dynamic=True) compiled_out1 = compiled_sdpa1(q1, k1, v1) if requires_grad: compiled_out1.backward(backward_grad1) self._check_out_and_grad( golden_out1, ref_out1, compiled_out1, q1_gold, q1_ref, q1, k1_gold, k1_ref, k1, v1_gold, v1_ref, v1, ) else: self._check_out(golden_out1, ref_out1, compiled_out1) self.assertEqual(backend.frame_count, 1) # Second compilation with new dimensions compiled_sdpa2 = torch.compile(sdpa_partial2, backend=backend, dynamic=True) compiled_out2 = compiled_sdpa2(q2, k2, v2) if requires_grad: compiled_out2.backward(backward_grad2) self._check_out_and_grad( golden_out2, ref_out2, compiled_out2, q2_gold, q2_ref, q2, k2_gold, k2_ref, k2, v2_gold, v2_ref, v2, ) else: self._check_out(golden_out2, ref_out2, compiled_out2) self.assertEqual(backend.frame_count, 1) # Third compilation with new dimensions compiled_sdpa3 = torch.compile(sdpa_partial3, backend=backend, dynamic=True) compiled_out3 = compiled_sdpa3(q3, k3, v3) if requires_grad: compiled_out3.backward(backward_grad3) self._check_out_and_grad( golden_out3, ref_out3, compiled_out3, q3_gold, q3_ref, q3, k3_gold, k3_ref, k3, v3_gold, v3_ref, v3, ) else: self._check_out(golden_out3, ref_out3, compiled_out3) self.assertEqual(backend.frame_count, 1) def run_automatic_dynamic_test( self, score_mod: Callable, dtype: torch.dtype, device: str, B: int = B, H: int = H, S: int = S, D: int = D, ): if device == "cpu" and dtype is torch.float16: dtype = torch.float32 block_mask1 = create_block_mask(noop_mask, 1, 1, S, S, device=device) sdpa_partial1 = create_attention(score_mod, block_mask=block_mask1) # The first eager batch, shape (B, H, S, D) requires_grad = device in DEVICE_SUPPORTS_BACKWARDS q1 = torch.randn( (B, H, S, D), dtype=dtype, device=device, requires_grad=requires_grad, ) k1 = torch.randn( (B, H, S, D), dtype=dtype, device=device, requires_grad=requires_grad, ) v1 = torch.randn( (B, H, S, D), dtype=dtype, device=device, requires_grad=requires_grad, ) golden_out1 = sdpa_partial1( q1.to(torch.float64), k1.to(torch.float64), v1.to(torch.float64) ) ref_out1 = sdpa_partial1(q1, k1, v1) # The second eager batch, shape (B * 2, H, S / 2, D) B = int(B * 2) S = int(S / 2) block_mask2 = create_block_mask(noop_mask, 1, 1, S, S, device=device) sdpa_partial2 = create_attention(score_mod, 
block_mask=block_mask2) q2 = torch.randn( (B, H, S, D), dtype=dtype, device=device, requires_grad=requires_grad, ) k2 = torch.randn( (B, H, S, D), dtype=dtype, device=device, requires_grad=requires_grad, ) v2 = torch.randn( (B, H, S, D), dtype=dtype, device=device, requires_grad=requires_grad, ) golden_out2 = sdpa_partial2( q2.to(torch.float64), k2.to(torch.float64), v2.to(torch.float64) ) ref_out2 = sdpa_partial2(q2, k2, v2) # The third eager batch, shape (B * 4, H, S / 4, D) B = int(B * 2) S = int(S / 2) block_mask3 = create_block_mask(noop_mask, 1, 1, S, S, device=device) sdpa_partial3 = create_attention(score_mod, block_mask=block_mask3) q3 = torch.randn( (B, H, S, D), dtype=dtype, device=device, requires_grad=requires_grad, ) k3 = torch.randn( (B, H, S, D), dtype=dtype, device=device, requires_grad=requires_grad, ) v3 = torch.randn( (B, H, S, D), dtype=dtype, device=device, requires_grad=requires_grad, ) golden_out3 = sdpa_partial3( q3.to(torch.float64), k3.to(torch.float64), v3.to(torch.float64) ) ref_out3 = sdpa_partial3(q3, k3, v3) # Need to clear dynamo counters, since flex attention eager mode also uses dynamo tracing. # We check dynamo counters["frames"]["ok"] to ensure: # 1, the first batch is compiled with static shape # 2, the second batch is compiled with dynamic shape # 3, no re-compilation in the third batch torch._dynamo.reset() # Note, it seems like we really are less accurate than the float32 # computation, likely due to the online softmax if dtype == torch.float32: fudge_factor = 10.0 else: fudge_factor = 1.1 # The first batch. backend = torch._dynamo.testing.CompileCounterWithBackend("inductor") compiled_out1 = torch.compile(sdpa_partial1, backend=backend, fullgraph=True)( q1, k1, v1 ) self._check_equal(golden_out1, ref_out1, compiled_out1, fudge_factor) self.assertEqual(backend.frame_count, 1) # The second batch (automatic dynamic). compiled_out2 = torch.compile(sdpa_partial2, backend=backend, fullgraph=True)( q2, k2, v2 ) self._check_equal(golden_out2, ref_out2, compiled_out2, fudge_factor) self.assertEqual(backend.frame_count, 2) # The third batch (no re-compilation). 
compiled_out3 = torch.compile(sdpa_partial3, backend=backend, fullgraph=True)( q3, k3, v3 ) self._check_equal(golden_out3, ref_out3, compiled_out3, fudge_factor) self.assertEqual(backend.frame_count, 2) @supported_platform @dtypes(*device_configs["cpu"].dtypes) @dtypesIfCUDA(*device_configs["cuda"].dtypes) @dtypesIfXPU(*device_configs["xpu"].dtypes) @common_utils.parametrize("score_mod", test_score_mods) def test_builtin_score_mods(self, device, dtype, score_mod: Callable): self.run_test(score_mod, dtype, device=device) self.run_test_with_paged_attention(score_mod, dtype, device=device) @running_on_a100_only @common_utils.parametrize("score_mod", test_score_mods) @dtypes(*device_configs["cpu"].dtypes_fast) @dtypesIfCUDA(*device_configs["cuda"].dtypes_fast) @dtypesIfXPU(*device_configs["xpu"].dtypes_fast) def test_builtin_score_mods_seqlen_lt_default_sparse_block_size( self, device, dtype, score_mod: Callable ): # _DEFAULT_SPARSE_BLOCK_SIZE is 128 attention = functools.partial( flex_attention, score_mod=score_mod, kernel_options={"FORCE_USE_FLEX_ATTENTION": True}, ) self.run_test_with_call(attention, dtype, device, B, H, 64, D, B, H, 64, D) @running_on_a100_only @dtypes(*device_configs["cpu"].dtypes_fast) @dtypesIfCUDA(*device_configs["cuda"].dtypes_fast) @dtypesIfXPU(*device_configs["xpu"].dtypes_fast) @common_utils.parametrize("score_mod", test_score_mods) def test_builtin_score_mods_seqlen_lt_custom_sparse_block_size( self, device, dtype: torch.dtype, score_mod: Callable ): def causal_mask(b, h, q, kv): return q >= kv block_mask = create_block_mask( causal_mask, 1, 1, 64, 64, BLOCK_SIZE=256, device=device ) attention = functools.partial( flex_attention, score_mod=score_mod, block_mask=block_mask, kernel_options={"FORCE_USE_FLEX_ATTENTION": True}, ) self.run_test_with_call( attention, dtype, device, B, H, 64, D, B, H, 64, D, ) @supported_platform @dtypes(*device_configs["cpu"].dtypes_fast) @dtypesIfCUDA(*device_configs["cuda"].dtypes_fast) @dtypesIfXPU(*device_configs["xpu"].dtypes_fast) @common_utils.parametrize("score_mask_mod", test_score_mask_mod_map.items()) def test_builtin_score_mods_dynamic( self, device, dtype: torch.dtype, score_mask_mod: tuple[Callable, Callable] ): self.run_dynamic_test(score_mask_mod, dtype, S=1024, device=device) @supported_platform @dtypes(*device_configs["cpu"].dtypes_fast) @dtypesIfCUDA(*device_configs["cuda"].dtypes_fast) @dtypesIfXPU(*device_configs["xpu"].dtypes_fast) @common_utils.parametrize("score_mod", test_score_mods) def test_builtin_score_mods_automatic_dynamic( self, device, dtype: torch.dtype, score_mod: Callable ): self.run_automatic_dynamic_test(score_mod, dtype, S=1024, device=device) @supported_platform @dtypes(*device_configs["cpu"].dtypes_fast) @dtypesIfCUDA(*device_configs["cuda"].dtypes_fast) @dtypesIfXPU(*device_configs["xpu"].dtypes_fast) @common_utils.parametrize("score_mod", test_score_mods) def test_builtin_score_mods_different_seqlen( self, device, dtype: torch.dtype, score_mod: Callable ): inputs = ( score_mod, dtype, device, B, H, S // 2, # Seqlen of Q is different from seqlen of K/V D, B, H, S, D, ) self.run_test(*inputs) self.run_test_with_paged_attention(*inputs) @supported_platform @dtypes(*device_configs["cpu"].dtypes) @dtypesIfCUDA(*device_configs["cuda"].dtypes) @dtypesIfXPU(*device_configs["xpu"].dtypes) @common_utils.parametrize("score_mod", test_score_mods) @common_utils.parametrize("BLOCK_SIZE", test_block_size) def test_builtin_score_mods_different_block_size( self, device, dtype: torch.dtype, score_mod: Callable, 
BLOCK_SIZE: Union[int, tuple[int, int]], ): block_mask = create_block_mask( noop_mask, B, H, S, S, BLOCK_SIZE=BLOCK_SIZE, device=device ) self.run_test(score_mod, dtype, block_mask=block_mask, device=device) self.run_test_with_paged_attention( score_mod, dtype, block_mask=block_mask, device=device ) @supported_platform @dtypes(*device_configs["cpu"].dtypes_fast) @dtypesIfCUDA(*device_configs["cuda"].dtypes_fast) @dtypesIfXPU(*device_configs["xpu"].dtypes_fast) @common_utils.parametrize("batch_dims", test_Bq_Bkv) @common_utils.parametrize("head_dims", test_Hq_Hkv) @common_utils.parametrize("score_mod", test_score_mods) def test_kv_batch_broadcast( self, device, dtype: torch.dtype, batch_dims: tuple[int, int], head_dims: tuple[int, int], score_mod: Callable, ): Hq, Hkv = head_dims assert Hq % Hkv == 0 Bq, Bkv = batch_dims assert Bq > 1 and Bkv == 1 block_mask = create_block_mask(noop_mask, Bq, 1, S, S, device=device) self.run_test( score_mod, dtype, device, Bq, Hq, S, D, Bkv, Hkv, S, D, block_mask ) @supported_platform @skip_on_cpu def test_small_block_mask(self, device): compiled_create_block_mask = torch.compile(create_block_mask) def create_block_mask_from_seqlens( q_batch: torch.Tensor, kv_batch: torch.Tensor, ) -> BlockMask: B, H = None, None Q_LEN = q_batch.size(0) KV_LEN = kv_batch.size(0) def batch_mask_mod( b: torch.Tensor, h: torch.Tensor, q_idx: torch.Tensor, kv_idx: torch.Tensor, ): q_idx_batch = q_batch[q_idx] kv_idx_batch = kv_batch[kv_idx] batch_mask = ( (q_idx_batch == kv_idx_batch) & (q_idx_batch != -1) & (kv_idx_batch != -1) ) return batch_mask return compiled_create_block_mask( batch_mask_mod, B=B, H=H, Q_LEN=Q_LEN, KV_LEN=KV_LEN, device=device, ) a = torch.tensor([2, 42, 18, 21, 4, 2, 7, 1, 1], device=device) b = torch.tensor([57, 21, 16, 8], device=device) for seqlen in [a, b]: create_block_mask_from_seqlens(seqlen, seqlen) @supported_platform @dtypes(*device_configs["cpu"].dtypes_fast) @dtypesIfCUDA(*device_configs["cuda"].dtypes_fast) @dtypesIfXPU(*device_configs["xpu"].dtypes_fast) @common_utils.parametrize("batch_dims", test_Bq_Bkv) @common_utils.parametrize("head_dims", test_Hq_Hkv) @common_utils.parametrize("score_mod", test_score_mods) def test_kv_batch_broadcast_causal_mask( self, device, dtype: torch.dtype, batch_dims: tuple[int, int], head_dims: tuple[int, int], score_mod: Callable, ): Hq, Hkv = head_dims assert Hq % Hkv == 0 Bq, Bkv = batch_dims assert Bq > 1 and Bkv == 1 def mask_mod(b, h, q, kv): return q >= kv block_mask = create_block_mask(mask_mod, Bq, 1, S, S, device=device) attention = functools.partial( flex_attention, block_mask=block_mask, enable_gqa=(not Hq == Hkv) ) self.run_test_with_call(attention, dtype, device, Bq, Hq, S, D, Bkv, Hkv, S, D) @supported_platform @dtypes(*device_configs["cpu"].dtypes_fast) @dtypesIfCUDA(*device_configs["cuda"].dtypes_fast) @dtypesIfXPU(*device_configs["xpu"].dtypes_fast) @common_utils.parametrize("score_mod", test_score_mods) @skip_on_rocm # TODO: NaNs on ROCM @skip_on_xpu # TODO: NaNs on XPU like ROCM, need another PR to fix. def test_GQA(self, device, dtype: torch.dtype, score_mod: Callable): inputs = ( score_mod, dtype, device, B, H * 4, # Hq = 4*Hkv. 
S // 8,
            D,
            B,
            H,
            S,
            D,
        )
        self.run_test(*inputs)
        self.run_test_with_paged_attention(*inputs)

    @supported_platform
    @dtypes(*device_configs["cpu"].dtypes_fast)
    @dtypesIfCUDA(*device_configs["cuda"].dtypes_fast)
    @dtypesIfXPU(*device_configs["xpu"].dtypes_fast)
    @common_utils.parametrize(
        "q_s", test_strides[:2]
    )  # TODO: fix layout for query broadcasting
    @common_utils.parametrize(
        "k_s,v_s",
        [
            (test_strides[0], test_strides[0]),
            (test_strides[0], test_strides[1]),
            (test_strides[2], test_strides[3]),
            (test_strides[3], test_strides[1]),
            # (test_strides[2], test_strides[4]),  # TODO: Doesn't work for
            # broadcasting reasons, I think
        ],
    )
    @common_utils.parametrize("do_s", test_strides[:3])
    def test_strided_inputs(self, device, dtype: torch.dtype, q_s, k_s, v_s, do_s):
        q1 = torch.randn((B * H * S * D * 2), dtype=dtype, device=device)
        k1 = torch.randn((B * H * S * D * 2), dtype=dtype, device=device)
        v1 = torch.randn((B * H * S * D * 2), dtype=dtype, device=device)
        do1 = torch.randn((B * H * S * D * 2), dtype=dtype, device=device)
        q_shape = (B, H, S // 2, D)
        k_shape = (B, H, S, D)
        v_shape = (B, H, S, D)
        do_shape = (B, H, S // 2, D)

        requires_grad = device in DEVICE_SUPPORTS_BACKWARDS

        def coerce_to_strides(val, shape, strides):
            strides, offset = strides
            val_max = [x * (y - 1) for x, y in zip(strides, shape)]
            assert sum(val_max) + offset < B * H * S * D * 2
            assert strides[-1] == 1
            return torch.as_strided(val, shape, strides, offset).requires_grad_(
                requires_grad
            )

        q = coerce_to_strides(q1, q_shape, q_s)
        k = coerce_to_strides(k1, k_shape, k_s)
        v = coerce_to_strides(v1, v_shape, v_s)
        do = coerce_to_strides(do1, do_shape, do_s)

        kernel_options = {"USE_TMA": True}
        block_mask = _create_empty_block_mask(q, k)
        score_mod = _generate_alibi_bias(8)
        sdpa_partial = create_attention(
            score_mod=score_mod, block_mask=block_mask, kernel_options=kernel_options
        )
        compiled_sdpa = torch.compile(sdpa_partial, fullgraph=True)
        ref_out = sdpa_partial(q, k, v)
        compiled_out = compiled_sdpa(q, k, v)

        tolerance = Tolerances(atol=2e-1, rtol=2e-1)
        torch.testing.assert_close(
            ref_out, compiled_out, atol=tolerance.atol, rtol=tolerance.rtol
        )

        if requires_grad:
            ref_out.backward(do)
            ref_grads = [q.grad, k.grad, v.grad]
            q.grad = None
            k.grad = None
            v.grad = None

            compiled_out.backward(do)
            compiled_grads = [q.grad, k.grad, v.grad]
            q.grad = None
            k.grad = None
            v.grad = None
            torch.testing.assert_close(
                compiled_grads[0],
                ref_grads[0],
                atol=tolerance.atol,
                rtol=tolerance.rtol,
            )
            torch.testing.assert_close(
                compiled_grads[1],
                ref_grads[1],
                atol=tolerance.atol,
                rtol=tolerance.rtol,
            )
            torch.testing.assert_close(
                compiled_grads[2],
                ref_grads[2],
                atol=tolerance.atol,
                rtol=tolerance.rtol,
            )

        # test paged attention which does not support backward
        q.requires_grad, k.requires_grad, v.requires_grad = False, False, False
        paged_compiled_out, _ = self.run_paged_attention(
            score_mod, q, k, v, dtype, device=device, kernel_options=kernel_options
        )
        torch.testing.assert_close(
            ref_out, paged_compiled_out, atol=tolerance.atol, rtol=tolerance.rtol
        )

    @supported_platform
    def test_doc_mask_sparse(self, device):
        document_id = torch.zeros(S, dtype=torch.int, device=device)
        for i in range(0, S, 256):
            document_id[i : i + 256] = i // 256

        def document_masking_causal(score, b, h, q_idx, kv_idx):
            causal_mask = q_idx >= kv_idx
            document_mask = document_id[q_idx] == document_id[kv_idx]
            return torch.where(causal_mask & document_mask, score, -float("inf"))

        self.run_test(document_masking_causal, torch.float16, device=device)
        self.run_test_with_paged_attention(
            document_masking_causal,
torch.float16, device=device ) @supported_platform def test_index_multiple(self, device): bias = torch.randn(B, S, device=device) def index_multiple(score, b, h, q_idx, kv_idx): return score + bias[b][q_idx] self.run_test(index_multiple, torch.float16, device=device) self.run_test_with_paged_attention(index_multiple, torch.float16, device=device) @supported_platform def test_index_weird1(self, device): bias = torch.randn(4, B, H, S, device=device) def index_weird1(score, b, h, q_idx, kv_idx): return score + bias[0][b, h][q_idx] self.run_test(index_weird1, torch.float16, device=device) self.run_test_with_paged_attention(index_weird1, torch.float16, device=device) @supported_platform def test_index_weird2(self, device): bias = torch.randn(B, H, 4, S, device=device) which_bias = torch.tensor(0, device=device) def index_weird2(score, b, h, q_idx, kv_idx): return score + bias[b][h][which_bias, q_idx] self.run_test(index_weird2, torch.float16, device=device) self.run_test_with_paged_attention(index_weird2, torch.float16, device=device) @supported_platform @dtypes(*device_configs["cpu"].dtypes) @dtypesIfCUDA(*device_configs["cuda"].dtypes) @dtypesIfXPU(*device_configs["xpu"].dtypes) def test_skip_odd_keys(self, device, dtype: torch.dtype): def score_mod(score, b, h, q, kv): return torch.where(kv % 2 == 0, score, float("-inf")) self.run_test(score_mod, dtype, device=device) self.run_test_with_paged_attention(score_mod, dtype, device=device) @supported_platform @dtypes(*device_configs["cpu"].dtypes) @dtypesIfCUDA(*device_configs["cuda"].dtypes) @dtypesIfXPU(*device_configs["xpu"].dtypes) def test_function_composition(self, device, dtype: torch.dtype): def score_mod_1(score, b, h, m, n): return score + (m - n) def score_mod_2(score, b, h, m, n): return torch.where(m <= n, score, float("-inf")) def composed_score_mod(score, b, h, m, n): return score_mod_2(score_mod_1(score, b, h, m, n), b, h, m, n) self.run_test(composed_score_mod, dtype, device=device) self.run_test_with_paged_attention(composed_score_mod, dtype, device=device) @supported_platform @dtypes(*device_configs["cpu"].dtypes) @dtypesIfCUDA(*device_configs["cuda"].dtypes) @dtypesIfXPU(*device_configs["xpu"].dtypes) def test_captured_buffers_all_dims(self, device, dtype: torch.dtype): head_scale = torch.randn(H, device=device) batch_scale = torch.randn(B, device=device) tok_scale = torch.randn(S, device=device) def all_bias(score, batch, head, token_q, token_kv): score = score + tok_scale[token_q] score = score + batch_scale[batch] score = score + head_scale[head] return score self.run_test(all_bias, dtype, device=device) self.run_test_with_paged_attention(all_bias, dtype, device=device) @supported_platform @dtypes(*device_configs["cpu"].dtypes_fast) @dtypesIfCUDA(*device_configs["cuda"].dtypes_fast) @dtypesIfXPU(*device_configs["xpu"].dtypes_fast) def test_seq_masking(self, device, dtype): seq_idx = torch.zeros(S, device=device, dtype=torch.bool) seq_idx[S // 2 :] = 1 def seq_mask_mod(score, b, h, q, kv): return torch.where(seq_idx[q] == seq_idx[kv], score, float("-inf")) self.run_test(seq_mask_mod, dtype, device=device) self.run_test_with_paged_attention(seq_mask_mod, dtype, device=device) @supported_platform @dtypes(*device_configs["cpu"].dtypes_fast) @dtypesIfCUDA(*device_configs["cuda"].dtypes_fast) @dtypesIfXPU(*device_configs["xpu"].dtypes_fast) def test_load_from_bias_seq_only(self, device, dtype): bias = torch.randn(S, S, device=device, dtype=dtype) def bias_mod(score, b, h, q, kv): return score + bias[q, kv] self.run_test(bias_mod, 
dtype, device=device) self.run_test_with_paged_attention(bias_mod, dtype, device=device) @supported_platform @dtypes(*device_configs["cpu"].dtypes_fast) @dtypesIfCUDA(*device_configs["cuda"].dtypes_fast) @dtypesIfXPU(*device_configs["xpu"].dtypes_fast) def test_load_from_bias_seq_batch(self, device, dtype): bias = torch.randn(B, S, S, device=device, dtype=dtype) def bias_mod(score, b, h, q, kv): return score + bias[b, q, kv] self.run_test(bias_mod, dtype, device=device) self.run_test_with_paged_attention(bias_mod, dtype, device=device) @supported_platform @skip_on_cpu def test_load_from_view_buffer(self, device): dtype = torch.float16 W = 8 class SimpleAttention(torch.nn.Module): def __init__(self): super().__init__() self.rel_pos_h = torch.randn(2 * H - 1, D, device=device, dtype=dtype) def forward(self, q, k, v): q = q.view(B * H, H * W, -1) score_mod = self.generate_score_mod(q) q = q.view(B, H, H * W, -1) return flex_attention(q, k, v, score_mod=score_mod) def generate_score_mod(self, q): rel_h = self.add_decomposed_rel_pos(q) rel_h = rel_h.view( B, H, rel_h.size(1), rel_h.size(2), rel_h.size(3) ).squeeze(-1) def score_mod(score, batch, head, q_idx, k_idx): h_idx = k_idx // W return score + rel_h[batch, head, q_idx, h_idx] return score_mod @torch.no_grad() def add_decomposed_rel_pos(self, q): q_coords = torch.arange(H, device=self.rel_pos_h.device)[:, None] k_coords = torch.arange(H, device=self.rel_pos_h.device)[None, :] relative_coords = (q_coords - k_coords) + (H - 1) Rh = self.rel_pos_h[relative_coords.long()] r_q = q.reshape(B * H, H, W, D) rel_h = torch.einsum("bhwc,hkc->bhwk", r_q, Rh) return rel_h.reshape(B * H, H * W, H, 1) m = SimpleAttention().to(device).eval() m = torch.compile(m, mode="max-autotune", fullgraph=True) q = torch.randn(B, H, H * W, D, device=device, dtype=dtype, requires_grad=True) k = torch.randn(B, H, H * W, D, device=device, dtype=dtype, requires_grad=True) v = torch.randn(B, H, H * W, D, device=device, dtype=dtype, requires_grad=True) out = m(q, k, v) out.sum().backward() @supported_platform @dtypes(*device_configs["cpu"].dtypes_fast) @dtypesIfCUDA(*device_configs["cuda"].dtypes_fast) @dtypesIfXPU(*device_configs["xpu"].dtypes_fast) def test_load_from_bias_head_seq_batch(self, device, dtype): bias = torch.randn(B, H, S, S, device=device, dtype=dtype) def bias_mod(score, b, h, q, kv): return score + bias[b, h, q, kv] self.run_test(bias_mod, dtype, device=device) self.run_test_with_paged_attention(bias_mod, dtype, device=device) @supported_platform @dtypes(*device_configs["cpu"].dtypes_fast) @dtypesIfCUDA(*device_configs["cuda"].dtypes_fast) @dtypesIfXPU(*device_configs["xpu"].dtypes_fast) def test_load_rel_bias(self, device, dtype): rel_bias = torch.randn(2 * S, device=device, dtype=dtype) def bias_mod(score, b, h, q, kv): return score + rel_bias[(q - kv) + S] self.run_test(bias_mod, dtype, device=device) self.run_test_with_paged_attention(bias_mod, dtype, device=device) @supported_platform @dtypes(*device_configs["cpu"].dtypes_fast) @dtypesIfCUDA(*device_configs["cuda"].dtypes_fast) @dtypesIfXPU(*device_configs["xpu"].dtypes_fast) def test_dependent_causal_bidirectional(self, device, dtype): num_bidirectional = torch.randint(0, S, (B,), device=device, dtype=torch.int32) def bias_mod(score, b, h, q, kv): causal_attention = q >= kv cur_num_bidirectional = num_bidirectional[b] bidirectional_attention_on_video = (q <= cur_num_bidirectional) & ( kv <= cur_num_bidirectional ) return torch.where( bidirectional_attention_on_video | causal_attention, score, 
-float("inf"), ) self.run_test(bias_mod, dtype, device=device) self.run_test_with_paged_attention(bias_mod, dtype, device=device) @supported_platform @dtypes(*device_configs["cpu"].dtypes_fast) @dtypesIfCUDA(*device_configs["cuda"].dtypes_fast) @dtypesIfXPU(*device_configs["xpu"].dtypes_fast) def test_natten_2d(self, device, dtype): H = 32 W = S // H WINDOW = 3 assert W * H == S def get_x_y(idx): # This should be a floor divide, but we don't support that properly return idx / W, idx % W def natten_mask(score, b, h, q, kv): q_x, q_y = get_x_y(q) kv_x, kv_y = get_x_y(kv) return torch.where( ((q_x - kv_x).abs() <= WINDOW) | ((q_y - kv_y).abs() <= WINDOW), score, float("-inf"), ) self.run_test(natten_mask, dtype, device=device) self.run_test_with_paged_attention(natten_mask, dtype, device=device) @supported_platform def test_subgraph_respect_decompostion(self, device): from torch._decomp import core_aten_decompositions from torch.fx.experimental.proxy_tensor import make_fx def score_mod_func(score, b, h, q, kv): return score - q // (1 + kv) make_tensor = functools.partial( torch.randn, (2, 2, 128, 4), device=device, dtype=torch.float64, requires_grad=True, ) query, key, value = make_tensor(), make_tensor(), make_tensor() # floor_div is not decomposed in decompostion_table is empty attention = functools.partial(flex_attention, score_mod=score_mod_func) gm = make_fx(attention, decomposition_table={})(query, key, value) self.assertExpectedInline( gm.sdpa_score0.code.strip(), """\ def forward(self, arg0_1, arg1_1, arg2_1, arg3_1, arg4_1): add = torch.ops.aten.add.Tensor(arg4_1, 1); arg4_1 = None floor_divide = torch.ops.aten.floor_divide.default(arg3_1, add); arg3_1 = add = None sub = torch.ops.aten.sub.Tensor(arg0_1, floor_divide); arg0_1 = floor_divide = None return sub""", ) # floor_div is decomposed for core_aten_decompositions gm = make_fx(attention, decomposition_table=core_aten_decompositions())( query, key, value ) self.assertExpectedInline( gm.sdpa_score0.code.strip(), """\ def forward(self, arg0_1, arg1_1, arg2_1, arg3_1, arg4_1): add = torch.ops.aten.add.Tensor(arg4_1, 1); arg4_1 = None div = torch.ops.aten.div.Tensor_mode(arg3_1, add, rounding_mode = 'floor'); arg3_1 = add = None sub = torch.ops.aten.sub.Tensor(arg0_1, div); arg0_1 = div = None return sub""", ) @supported_platform @dtypes(*device_configs["cpu"].dtypes_fast) @dtypesIfCUDA(*device_configs["cuda"].dtypes_fast) @dtypesIfXPU(*device_configs["xpu"].dtypes_fast) def test_silu_on_score(self, device, dtype): def silu_score(score, b, h, q, kv): return torch.nn.functional.silu(score) self.run_test(silu_score, dtype, device=device) self.run_test_with_paged_attention(silu_score, dtype, device=device) @supported_platform @dtypes(*device_configs["cpu"].dtypes_fast) @dtypesIfCUDA(*device_configs["cuda"].dtypes_fast) @dtypesIfXPU(*device_configs["xpu"].dtypes_fast) def test_padded_dense_causal(self, device, dtype): seq_len = torch.arange(B, device=device, dtype=torch.int32) + 1 def create_padded_dense_wrapper(orig_score_mod): def njt_score_mod(qk, b, h, q, kv): return torch.where( qk <= seq_len[b], orig_score_mod(qk, b, h, q, kv), -float("inf") ) return njt_score_mod causal_njt = create_padded_dense_wrapper(_causal) self.run_test(causal_njt, dtype, device=device) @supported_platform @dtypes(*device_configs["cpu"].dtypes_fast) @dtypesIfCUDA(*device_configs["cuda"].dtypes_fast) @dtypesIfXPU(*device_configs["xpu"].dtypes_fast) def test_captured_scale(self, device, dtype): scale = torch.ones((), device=device, dtype=torch.int32) def 
score_mod_scale(qk, b, h, q, kv): return qk + scale self.run_test(score_mod_scale, dtype, device=device) self.run_test_with_paged_attention(score_mod_scale, dtype, device=device) @supported_platform @dtypes(*device_configs["cpu"].dtypes_fast) @dtypesIfCUDA(*device_configs["cuda"].dtypes_fast) @dtypesIfXPU(*device_configs["xpu"].dtypes_fast) def test_recompile_changed_score_mod(self, device, dtype): scale = torch.ones((), device=device, dtype=torch.int32) ADD = True def score_mod_scale(qk, b, h, q, kv): if ADD: return qk + scale else: return qk * scale self.run_test(score_mod_scale, dtype, device=device) self.run_test_with_paged_attention(score_mod_scale, dtype, device=device) ADD = False self.run_test(score_mod_scale, dtype, device=device) self.run_test_with_paged_attention(score_mod_scale, dtype, device=device) @supported_platform @expectedFailure # If we capture a tensor then we can perform a reduction on it, and that shouldn't be allowed @dtypes(*device_configs["cpu"].dtypes_fast) @dtypesIfCUDA(*device_configs["cuda"].dtypes_fast) @dtypesIfXPU(*device_configs["xpu"].dtypes_fast) def test_captured_reduction(self, device, dtype): scale = torch.randn((B, 8), device=device) def score_mod_scale(qk, b, h, q, kv): return qk + scale[b].sum(dim=-1) self.run_test(score_mod_scale, dtype, device=device) @supported_platform @skip_on_cpu @dtypes(torch.float16) @dtypesIfCUDA(torch.float16) def test_dynamic_captured_buffer(self, device, dtype): def run_with_head_count(compiled_fa, head_count): head_scale = torch.randn( head_count, device=device, dtype=dtype, requires_grad=True ) def score_mod(score, batch, head, token_q, token_kv): return score * head_scale[head] q = torch.randn( B, head_count, S, D, device=device, dtype=dtype, requires_grad=True ) k = torch.randn_like(q, requires_grad=True) v = torch.randn_like(q, requires_grad=True) block_mask = create_block_mask(noop_mask, B, 1, S, S, device=device) out = compiled_fa(q, k, v, score_mod=score_mod, block_mask=block_mask) loss = out.sum() loss.backward() return out compiled_fa = torch.compile(flex_attention, fullgraph=True, dynamic=True) head_counts = [4, 8, 4, 16, 4] for head_count in head_counts: run_with_head_count(compiled_fa, head_count) @supported_platform @dtypes(*device_configs["cpu"].dtypes_fast) @dtypesIfCUDA(*device_configs["cuda"].dtypes_fast) @dtypesIfXPU(*device_configs["xpu"].dtypes_fast) @common_utils.parametrize( "score_mod", test_score_mods, name_fn=lambda score_mod: score_mod.__name__ ) @skip_on_cpu def test_return_max(self, device, dtype, score_mod): make_tensor = functools.partial( torch.randn, (2, 2, 243, 16), device=device, dtype=dtype, requires_grad=True, ) query, key, value = make_tensor(), make_tensor(), make_tensor() out_only = flex_attention(query, key, value, score_mod) out_max, aux_max = flex_attention( query, key, value, score_mod, return_aux=AuxRequest(max_scores=True), ) out_both, aux_both = flex_attention( query, key, value, score_mod, return_aux=AuxRequest(lse=True, max_scores=True), ) flex_compile = torch.compile(flex_attention, fullgraph=True) out_compiled, aux_compiled = flex_compile( query, key, value, score_mod, return_aux=AuxRequest(max_scores=True), ) torch.testing.assert_close(out_only, out_max, atol=1e-6, rtol=1e-6) torch.testing.assert_close(out_only, out_both, atol=1e-6, rtol=1e-6) torch.testing.assert_close( aux_max.max_scores, aux_both.max_scores, atol=1e-6, rtol=1e-6 ) # we are calculating slightly different scores so add a lil fudge # Extra tolerance for squared score_mod with float16 due to limited 
dynamic range if score_mod.__name__ == "_squared" and dtype == torch.float16: atol, rtol = 2e-2, 2e-2 else: atol, rtol = 5e-3, 5e-3 torch.testing.assert_close(out_max, out_compiled, atol=atol, rtol=rtol) torch.testing.assert_close( aux_max.max_scores, aux_compiled.max_scores, atol=atol, rtol=rtol ) B, H, L = query.shape[:3] self.assertEqual(aux_max.max_scores.shape, (B, H, L)) max_score_tensors = [ aux_max.max_scores, aux_both.max_scores, aux_compiled.max_scores, ] for max_tensor in max_score_tensors: self.assertFalse( max_tensor.requires_grad, "max_scores should not require gradients" ) self.assertEqual( max_tensor.dtype, torch.float32, "max_scores should be kept in fp32" ) # Test gradient computation for both eager and compiled versions test_cases = [ ("eager", out_max, "eager mode"), ("compiled", out_compiled, "compiled mode"), ] for mode_name, output, description in test_cases: loss = output.sum() grads = torch.autograd.grad(loss, (query, key, value)) # Verify gradients are computed for all inputs input_names = ["query", "key", "value"] for grad, input_name in zip(grads, input_names): self.assertIsNotNone( grad, f"{input_name} should receive gradients in {description}" ) @supported_platform @dtypes(*device_configs["cpu"].dtypes_fast) @dtypesIfCUDA(*device_configs["cuda"].dtypes_fast) @dtypesIfXPU(*device_configs["xpu"].dtypes_fast) @common_utils.parametrize( "score_mod", test_score_mods, name_fn=lambda score_mod: score_mod.__name__ ) @skip_on_cpu def test_return_aux(self, device, dtype, score_mod): """Test the new return_aux API with AuxRequest/Output""" make_tensor = functools.partial( torch.randn, (2, 2, 243, 16), device=device, dtype=dtype, requires_grad=True, ) query, key, value = make_tensor(), make_tensor(), make_tensor() flex_compile = torch.compile(flex_attention, fullgraph=True) flex_compile_partial = torch.compile(flex_attention, fullgraph=False) # Test 1: No auxiliary outputs (default behavior) out_only = flex_compile(query, key, value, score_mod) self.assertIsInstance(out_only, torch.Tensor) # Test 2: Request only LSE out, aux_lse = flex_compile( query, key, value, score_mod, return_aux=AuxRequest(lse=True) ) self.assertIsInstance(aux_lse, AuxOutput) self.assertIsInstance(aux_lse.lse, torch.Tensor) self.assertIsNone(aux_lse.max_scores) self.assertEqual(aux_lse.lse.shape, (2, 2, 243)) self.assertEqual(aux_lse.lse.dtype, torch.float32) # Test 3: Request only max_scores out, aux_max = flex_compile( query, key, value, score_mod, return_aux=AuxRequest(max_scores=True), ) self.assertIsInstance(aux_max, AuxOutput) self.assertIsNone(aux_max.lse) self.assertIsInstance(aux_max.max_scores, torch.Tensor) self.assertEqual(aux_max.max_scores.shape, (2, 2, 243)) self.assertEqual(aux_max.max_scores.dtype, torch.float32) # Test 4: Request both auxiliary outputs out, aux_both = flex_compile( query, key, value, score_mod, return_aux=AuxRequest(lse=True, max_scores=True), ) self.assertIsInstance(aux_both, AuxOutput) self.assertIsInstance(aux_both.lse, torch.Tensor) self.assertIsInstance(aux_both.max_scores, torch.Tensor) self.assertEqual(aux_both.lse.shape, (2, 2, 243)) self.assertEqual(aux_both.max_scores.shape, (2, 2, 243)) # Test 5: Request no auxiliary outputs explicitly out, aux_none = flex_compile( query, key, value, score_mod, return_aux=AuxRequest(), # Default is lse=False, max_scores=False ) self.assertIsInstance(aux_none, AuxOutput) self.assertIsNone(aux_none.lse) self.assertIsNone(aux_none.max_scores) # Test 6: Verify outputs are consistent with legacy API, can't fullgraph through 
warnings out_legacy, lse_legacy = flex_compile_partial( query, key, value, score_mod, return_lse=True ) torch.testing.assert_close(out_only, out_legacy, atol=1e-6, rtol=1e-6) torch.testing.assert_close(aux_lse.lse, lse_legacy, atol=1e-6, rtol=1e-6) @supported_platform @dtypes(*device_configs["cpu"].dtypes_fast) @dtypesIfCUDA(*device_configs["cuda"].dtypes_fast) @dtypesIfXPU(*device_configs["xpu"].dtypes_fast) @skip_on_cpu def test_return_aux_deprecation_warnings(self, device, dtype): """Test that deprecation warnings are issued for legacy parameters""" import warnings make_tensor = functools.partial( torch.randn, (2, 2, 64, 16), device=device, dtype=dtype, ) query, key, value = make_tensor(), make_tensor(), make_tensor() # Clear shown warnings to ensure we can test them original_shown = _WARNINGS_SHOWN.copy() _WARNINGS_SHOWN.clear() try: # Test deprecation warning for return_lse with warnings.catch_warnings(record=True) as w: warnings.simplefilter("always") flex_attention(query, key, value, return_lse=True) self.assertTrue( any( "return_lse is deprecated" in str(warning.message) for warning in w ) ) # Clear for next test _WARNINGS_SHOWN.clear() # Test error when both old and new API are used with self.assertRaises(ValueError) as cm: flex_attention( query, key, value, return_lse=True, return_aux=AuxRequest(lse=True), ) self.assertIn( "Cannot specify both return_lse and return_aux", str(cm.exception) ) finally: # Restore original warnings state _WARNINGS_SHOWN.clear() _WARNINGS_SHOWN.update(original_shown) @supported_platform @dtypes(*device_configs["cpu"].dtypes_fast) @dtypesIfCUDA(*device_configs["cuda"].dtypes_fast) @dtypesIfXPU(*device_configs["xpu"].dtypes_fast) @skip_on_cpu def test_dynamic_divisibility_guards(self, device, dtype): """Test guards for divisible/non-divisible shape transitions""" if device == "cpu" and dtype is torch.float16: dtype = torch.float32 def score_mod(qk, b, h, q, kv): return torch.where(q >= kv, qk, -float("inf")) def test_shape(S, backend): """Test a single shape configuration""" block_mask = create_block_mask(noop_mask, 1, 1, S, S, device=device) sdpa_partial = create_attention(score_mod, block_mask=block_mask) tensors = [ torch.randn( 2, 4, S, 64, dtype=dtype, device=device, requires_grad=False ) for _ in range(3) ] compiled_sdpa = torch.compile(sdpa_partial, backend=backend) out, code = run_and_get_code(compiled_sdpa, *tensors) # Check divisibility flag is_divisible = S % 128 == 0 expected_flag = f"IS_DIVISIBLE : tl.constexpr = {is_divisible}" self.assertIn( expected_flag, str(code), f"S={S} should have {expected_flag}" ) self.assertEqual(out.shape, (2, 4, S, 64)) return out, code torch._dynamo.reset() backend = CompileCounterWithBackend("inductor") # Test divisible and non-divisible shapes test_shapes = [256, 255, 383, 384] _ = [test_shape(S, backend) for S in test_shapes] @supported_platform def test_multiple_score_mod_calls(self, device): query = torch.randn((1, 8, 1024, 64), dtype=torch.float32, device=device) keys = [ torch.randn((1, 8, 1024, 64), dtype=torch.float32, device=device) for _ in range(2) ] values = [ torch.randn((1, 8, 1024, 64), dtype=torch.float32, device=device) for _ in range(2) ] def scoremod_1(qk, b, h, q, kv): return qk + (q - kv) def scoremod_2(qk, b, h, q, kv): return torch.where(q >= kv, qk, -float("inf")) def f(q, k1, k2, v1, v2): q2 = flex_attention(q, k1, v1, score_mod=scoremod_1) return flex_attention(q2, k2, v2, score_mod=scoremod_2) out = f(query, *keys, *values) out2 = torch.compile(f)(query, *keys, *values) tolerance = 
Tolerances(atol=2e-1, rtol=2e-1) torch.testing.assert_close(out, out2, atol=tolerance.atol, rtol=tolerance.rtol) @supported_platform @skip_on_cpu @skip_on_rocm # TODO: Investigate def test_multiple_mask_calls(self, device): make_tensor = functools.partial( torch.randn, (1, 4, 512, 64), dtype=torch.float32, device=device, requires_grad=True, ) query, key, value = make_tensor(), make_tensor(), make_tensor() window_size = 32 def causal_mask(b, h, q_idx, kv_idx): return q_idx >= kv_idx def causal_mask_slidewindow_mod(b, h, q_idx, kv_idx): return (q_idx >= kv_idx) & (q_idx <= kv_idx + window_size) mask1 = create_block_mask( causal_mask, 1, None, 512, 512, _compile=False, device=device ) mask2 = create_block_mask( causal_mask_slidewindow_mod, 1, None, 512, 512, _compile=False, device=device, ) def f(q, k, v): out1 = flex_attention(q, k, v, block_mask=mask1) out2 = flex_attention(q, k, v, block_mask=mask2) return out1 + out2 f_compiled = torch.compile(f, fullgraph=True) out = f(query, key, value) out_compiled = f_compiled(query, key, value) grads = torch.autograd.grad((out,), (query, key, value), torch.ones_like(out)) grads_compile = torch.autograd.grad( (out_compiled,), (query, key, value), torch.ones_like(out_compiled) ) for grad, grad_compiled in zip(grads, grads_compile): torch.testing.assert_close(grad, grad_compiled, atol=3e-2, rtol=3e-2) @supported_platform def test_multiple_score_mod_calls2(self, device): query = torch.randn((1, 8, 1024, 64), dtype=torch.float32, device=device) keys = [ torch.randn((1, 8, 1024, 64), dtype=torch.float32, device=device) for _ in range(3) ] values = [ torch.randn((1, 8, 1024, 64), dtype=torch.float32, device=device) for _ in range(3) ] def scoremod_1(qk, b, h, q, kv): return qk + (q - kv) def scoremod_2(qk, b, h, q, kv): return torch.where(q >= kv, qk, -float("inf")) attention1 = functools.partial(flex_attention, score_mod=scoremod_1) def f(q, k1, k2, k3, v1, v2, v3): q2 = attention1(q, k1, v1) q3 = flex_attention(q2, k2, v2, score_mod=scoremod_2) return flex_attention(q3, k3, v3, score_mod=scoremod_1) out = f(query, *keys, *values) out2 = torch.compile(f, fullgraph=True)(query, *keys, *values) self.assertTrue((out - out2).abs().mean() < 1e-2) @supported_platform def test_multiple_score_mod_calls_paged_attention(self, device): query = torch.randn((1, 8, 1024, 64), dtype=torch.float32, device=device) keys = [ torch.randn((1, 8, 1024, 64), dtype=torch.float32, device=device) for _ in range(2) ] values = [ torch.randn((1, 8, 1024, 64), dtype=torch.float32, device=device) for _ in range(2) ] def scoremod_1(qk, b, h, q, kv): return qk + (q - kv) def scoremod_2(qk, b, h, q, kv): return torch.where(q >= kv, qk, -float("inf")) def f(q, k1, k2, v1, v2): q2 = flex_attention(q, k1, v1, score_mod=scoremod_1) return flex_attention(q2, k2, v2, score_mod=scoremod_2) eager_out = f(query, *keys, *values) block_mask = create_block_mask(noop_mask, 1, 1, 1024, 1024, device=device) ( k_cache1, v_cache1, converted_block_mask1, converted_score_mod1, ) = self.preprocess_paged_attention( scoremod_1, query, keys[0], values[0], block_mask, torch.float32, device=device, ) ( k_cache2, v_cache2, converted_block_mask2, converted_score_mod2, ) = self.preprocess_paged_attention( scoremod_2, query, keys[1], values[1], block_mask, torch.float32, device=device, ) def paged_f(q, k1, k2, v1, v2): q2 = flex_attention( q, k1, v1, score_mod=converted_score_mod1, block_mask=converted_block_mask1, ) return flex_attention( q2, k2, v2, score_mod=converted_score_mod2, block_mask=converted_block_mask2, ) 
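# paged_f mirrors f above, but every flex_attention call now reads K/V from the paged
# caches filled by preprocess_paged_attention; the converted block masks / score mods
# remap logical token positions to their physical pages, so the result should match eager f.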
compiled_out = torch.compile(paged_f, fullgraph=True)( query, k_cache1, k_cache2, v_cache1, v_cache2 ) tolerance = Tolerances(atol=2e-1, rtol=2e-1) torch.testing.assert_close( eager_out, compiled_out, atol=tolerance.atol, rtol=tolerance.rtol ) @supported_platform def test_multiple_score_mod_calls2_paged_attention(self, device): query = torch.randn((1, 8, 1024, 64), dtype=torch.float32, device=device) keys = [ torch.randn((1, 8, 1024, 64), dtype=torch.float32, device=device) for _ in range(3) ] values = [ torch.randn((1, 8, 1024, 64), dtype=torch.float32, device=device) for _ in range(3) ] def scoremod_1(qk, b, h, q, kv): return qk + (q - kv) def scoremod_2(qk, b, h, q, kv): return torch.where(q >= kv, qk, -float("inf")) attention1 = functools.partial(flex_attention, score_mod=scoremod_1) def f(q, k1, k2, k3, v1, v2, v3): q2 = attention1(q, k1, v1) q3 = flex_attention(q2, k2, v2, score_mod=scoremod_2) return flex_attention(q3, k3, v3, score_mod=scoremod_1) eager_out = f(query, *keys, *values) block_mask = create_block_mask(noop_mask, 1, 1, 1024, 1024, device=device) ( k_cache1, v_cache1, converted_block_mask1, converted_score_mod1, ) = self.preprocess_paged_attention( scoremod_1, query, keys[0], values[0], block_mask, torch.float32, device=device, ) ( k_cache2, v_cache2, converted_block_mask2, converted_score_mod2, ) = self.preprocess_paged_attention( scoremod_2, query, keys[1], values[1], block_mask, torch.float32, device=device, ) ( k_cache3, v_cache3, converted_block_mask3, converted_score_mod3, ) = self.preprocess_paged_attention( scoremod_1, query, keys[2], values[2], block_mask, torch.float32, device=device, ) paged_attention1 = functools.partial( flex_attention, score_mod=converted_score_mod1, block_mask=converted_block_mask1, ) def paged_f(q, k1, k2, k3, v1, v2, v3): q2 = paged_attention1(q, k1, v1) q3 = flex_attention( q2, k2, v2, score_mod=converted_score_mod2, block_mask=converted_block_mask2, ) return flex_attention( q3, k3, v3, score_mod=converted_score_mod3, block_mask=converted_block_mask3, ) compiled_out = torch.compile(paged_f, fullgraph=True)( query, k_cache1, k_cache2, k_cache3, v_cache1, v_cache2, v_cache3 ) tolerance = Tolerances(atol=2e-1, rtol=2e-1) torch.testing.assert_close( eager_out, compiled_out, atol=tolerance.atol, rtol=tolerance.rtol ) @supported_platform @skip_on_cpu def test_inputs_are_realized(self, device): def f(q, k, v): x = torch.randn(1024, device=device) x = x * 2 def func(qk, b, h, q, kv): return qk + x[q] return flex_attention(q.sin(), k, v, score_mod=func).cos() q, k, v = ( torch.randn(1, 8, 1024, 64, device=device, requires_grad=True) for _ in range(3) ) ref = f(q, k, v) out = torch.compile(f)(q, k, v) self.assertTrue((ref - out).abs().mean() < 1e-2) gradOut = torch.randn_like(q) ref_grads = torch.autograd.grad(ref, (q, k, v), gradOut) out_grads = torch.autograd.grad(out, (q, k, v), gradOut) for ref, out in zip(ref_grads, out_grads): self.assertTrue((ref - out).abs().mean() < 1e-2) @supported_platform @skip_on_cpu def test_make_block_mask(self, device): def causal_mask(b, h, q_idx, kv_idx): return q_idx >= kv_idx block_mask_a = torch.compile(create_block_mask, fullgraph=True)( causal_mask, 1, 1, 512, 512, device=device ) block_mask_b = create_block_mask(causal_mask, 1, 1, 512, 512, device=device) self.assertEqual(block_mask_a.kv_num_blocks, block_mask_b.kv_num_blocks) self.assertEqual(block_mask_a.kv_indices, block_mask_b.kv_indices) self.assertEqual(block_mask_a.q_num_blocks, block_mask_b.q_num_blocks) @supported_platform def 
test_mask_mod_combiners(self, device): def causal_mask(b, h, q, kv): return q >= kv def neg_causal_mask(b, h, q, kv): return q < kv def sliding_window(b, h, q, kv): return (q - kv) <= 512 local_s = 2048 block_mask = create_block_mask( and_masks(causal_mask, sliding_window), 1, 1, local_s, local_s, device=device, ) self.assertExpectedInline(block_mask.kv_num_blocks.sum().item(), """28""") attention = functools.partial(flex_attention, block_mask=block_mask) self.run_test_with_call( attention, Q_S=local_s, KV_S=local_s, dtype=torch.float16, device=device ) block_mask = create_block_mask( and_masks(causal_mask, neg_causal_mask), 1, 1, local_s, local_s, device=device, ) self.assertEqual(block_mask.kv_num_blocks.sum(), 0) block_mask1 = create_block_mask( or_masks(causal_mask, neg_causal_mask), 1, 1, local_s, local_s, device=device, ) block_mask2 = create_block_mask( noop_mask, 1, 1, local_s, local_s, device=device ) self.assertEqual(block_mask1.sparsity(), block_mask2.sparsity()) @supported_platform @skip_on_cpu def test_epilogue_fused(self, device): # set so that metrics appear torch._logging.set_logs(inductor_metrics=True) @torch.compile def f(q, k, v): out = flex_attention(q, k, v) return out.cos() q, k, v = (torch.randn(1, 8, 1024, 64, device=device) for _ in range(3)) metrics.reset() _, code = run_and_get_code(f, q, k, v) fc = FileCheck() fc.check("triton_tem_fused") # template call fc.check_not("poi_fused_cos") # No cos pointwise operation fc.run(code[0]) accessed_bytes = 1 * 8 * 1024 * 64 * torch.float32.itemsize num_accesses = 4 # q, k, v reads, one output. # TODO: Get rid of this fudge factor # We need this fudge factor for now as we write the extraneous logsumexp num_accesses += 1 self.assertLess(metrics.num_bytes_accessed, accessed_bytes * num_accesses) torch._logging.set_logs() @supported_platform @dtypes(*device_configs["cpu"].dtypes) @dtypesIfCUDA(*device_configs["cuda"].dtypes) @dtypesIfXPU(*device_configs["xpu"].dtypes) def test_njt_causal(self, device, dtype): offsets = torch.tensor( [0, 1024, 1024 + 512, S], device=device, dtype=torch.int32 ) seq_idx = torch.zeros(S, device=device, dtype=torch.int32) for idx in range(len(offsets) - 1): seq_idx[offsets[idx] : offsets[idx + 1]] = idx def create_njt_wrapper(orig_score_mod, offsets, seq_idx): def njt_score_mod(qk, b, h, q, kv): q_nested = q - offsets[seq_idx[q]] kv_nested = kv - offsets[seq_idx[kv]] return orig_score_mod(qk, b, h, q_nested, kv_nested) return njt_score_mod causal_njt = create_njt_wrapper(_causal, offsets, seq_idx) self.run_test(causal_njt, dtype, device=device) self.run_test_with_paged_attention(causal_njt, dtype, device=device) @supported_platform def test_mixed_dtypes_fails(self, device): query = torch.randn((1, 1, 1024, 64), dtype=torch.float32, device=device) key = torch.randn((1, 1, 1024, 64), dtype=torch.float16, device=device) value = torch.randn((1, 1, 1024, 64), dtype=torch.float16, device=device) with self.assertRaisesRegex( ValueError, "Expected query, key, and value to have the same dtype" ): flex_attention(query, key, value, _identity) @supported_platform @patch.object(torch._inductor.config, "max_autotune", True) def test_max_autotune(self, device): def score_mod(score, b, h, m, n): return score * 2 self.run_test(score_mod, dtype=torch.float16, device=device) self.run_test_with_paged_attention( score_mod, dtype=torch.float16, device=device ) self.run_test_with_paged_attention( score_mod=score_mod, dtype=torch.bfloat16, KV_S=64, device=device, ) @supported_platform @skip("TODO: Figure out why this is 
erroring") @patch.object(torch._inductor.config, "max_autotune", True) def test_max_autotune_with_captured(self, device): head_scale = torch.randn(H, device=device) batch_scale = torch.randn(B, device=device) tok_scale = torch.randn(S, device=device) def bias_mod(score, batch, head, token_q, token_kv): score = score + tok_scale[token_q] score = score + batch_scale[batch] score = score + head_scale[head] return score self.run_test(bias_mod, dtype=torch.float32, device=device) @supported_platform @common_utils.parametrize("score_mod", test_score_mods) @dtypes(*device_configs["cpu"].dtypes) @dtypesIfCUDA(*device_configs["cuda"].dtypes) @dtypesIfXPU(*device_configs["xpu"].dtypes) @common_utils.parametrize("head_dims", [(D, D // 2), (D // 2, D)]) def test_non_equal_head_dims(self, device, dtype, score_mod, head_dims): qk_d, v_d = head_dims self.run_test(score_mod, dtype, device, B, H, S, qk_d, B, H, S, V_D=v_d) self.run_test_with_paged_attention( score_mod, dtype, device, B, H, S, qk_d, B, H, S, V_D=v_d ) @supported_platform @skip_on_cpu def test_autograd_function_in_score_mod(self, device): class ApplyMask(torch.autograd.Function): generate_vmap_rule = True @staticmethod def forward(a, mask): return torch.where(mask, a, -float("inf")) @staticmethod def setup_context(ctx, inputs, output): _, mask = inputs ctx.mark_non_differentiable(mask) @staticmethod def backward(ctx, i): return i, None def score_mod(score, b, h, q, kv): return ApplyMask.apply(score, q <= kv) func = torch.compile(flex_attention, fullgraph=True) q, k, v = ( torch.randn(1, 8, 1024, 64, device=device, requires_grad=True) for _ in range(3) ) # Just checking that it runs func(q, k, v) # expectedFailure # This doesn't work due to vmap + autograd.Function + torch.compile not composing # self.run_test(score_mod) @supported_platform def test_causal_block(self, device): def mask_mod(b, h, q, kv): return q >= kv block_mask = create_block_mask(mask_mod, 1, 1, S, S, device=device) attention = functools.partial(flex_attention, block_mask=block_mask) self.run_test_with_call(attention, dtype=torch.float16, device=device) @supported_platform def test_causal_block_paged_attention(self, device): def mask_mod(b, h, q, kv): return q >= kv block_mask = create_block_mask(mask_mod, B, 1, S, S, device=device) self.run_test_with_paged_attention( score_mod=_identity, dtype=torch.float16, device=device, block_mask=block_mask, ) @supported_platform def test_new_empty_mask_mod(self, device): S = 128 q, k, v = (torch.randn(4, 1, S, 64, device=device) for _ in range(3)) attn_mask = torch.ones(4, 1, S, S, dtype=torch.bool, device=device).tril() def score_mod(score, b, h, q_idx, kv_idx): h_ = h.new_zeros(h.shape) return score + attn_mask[b, h_, q_idx, kv_idx] def causal(b, h, q_idx, kv_idx): h_ = h.new_zeros(h.shape) return attn_mask[b, h_, q_idx, kv_idx] block_mask = create_block_mask( causal, B=4, H=None, Q_LEN=S, KV_LEN=S, device=device ) torch.compile(flex_attention, fullgraph=True)( q, k, v, score_mod, block_mask=block_mask ) @supported_platform @common_utils.parametrize("head_dim", [17, 24, 94, 121]) @dtypes(*device_configs["cpu"].dtypes_fast) @dtypesIfCUDA(*device_configs["cuda"].dtypes_fast) @dtypesIfXPU(*device_configs["xpu"].dtypes_fast) def test_non_pow_2_headdim(self, device, dtype, head_dim): self.run_test(_rel_bias, dtype, device, B, H, S, head_dim, B, H, S, head_dim) @supported_platform def test_GQA_causal_mask(self, device): def mask_mod(b, h, q, kv): return q >= kv block_mask = create_block_mask(mask_mod, B, 1, S // 8, S // 8, device=device) 
attention = functools.partial( flex_attention, block_mask=block_mask, enable_gqa=True ) self.run_test_with_call( attention, torch.float16, device, B, H * 4, # Hq = 4*Hkv. S // 8, D, B, H, S // 8, D, ) self.run_test_with_paged_attention( _identity, dtype=torch.float16, device=device, Q_H=H * 4, Q_S=S // 8, KV_H=H, KV_S=S // 8, block_mask=block_mask, ) @supported_platform def test_custom_block_mask_generator(self, device): def mask_mod(b, h, q, kv): return q >= kv auto_mask = create_block_mask(mask_mod, 1, 1, S, S, device=device) BLOCK_SIZE = 128 def causal_constructor(S): num_blocks = torch.arange(S // BLOCK_SIZE, device=device) + 1 indices = torch.arange(S // BLOCK_SIZE, device=device).expand( S // BLOCK_SIZE, S // BLOCK_SIZE ) num_blocks = num_blocks[None, None, :] indices = indices[None, None, :] return BlockMask.from_kv_blocks( num_blocks, indices, BLOCK_SIZE=BLOCK_SIZE, mask_mod=mask_mod ) manual_mask = causal_constructor(S) self.assertEqual(auto_mask.to_dense(), manual_mask.to_dense()) @supported_platform @skip_on_cpu @dtypes(*device_configs["cpu"].dtypes) @dtypesIfCUDA(*device_configs["cuda"].dtypes) @dtypesIfXPU(*device_configs["xpu"].dtypes) @common_utils.parametrize("score_mod", [_identity, _causal]) def test_logsumexp_correctness(self, device, dtype, score_mod): make_tensor = functools.partial( torch.randn, (B, H, S, D), dtype=dtype, device=device, requires_grad=True, ) q, k, v = make_tensor(), make_tensor(), make_tensor() @torch.compile def sdpa_hop(q, k, v, score_mod): return flex_attention(q, k, v, score_mod, return_lse=True) @torch.compile(backend="aot_eager") def eager_sdpa_hop(q, k, v, score_mod): return flex_attention(q, k, v, score_mod, return_lse=True) ref_out, ref_lse = eager_sdpa_hop( q.to(torch.float64), k.to(torch.float64), v.to(torch.float64), score_mod, ) compiled_out, compiled_lse = sdpa_hop(q, k, v, score_mod) self.assertTrue(ref_lse.dtype == torch.float64) self.assertTrue(compiled_lse.dtype == torch.float32) tolerance = Tolerances(atol=2e-2, rtol=2e-2) torch.testing.assert_close( ref_out.to(dtype=torch.float32), compiled_out.to(dtype=torch.float32), atol=tolerance.atol, rtol=tolerance.rtol, ) torch.testing.assert_close( ref_lse.to(dtype=torch.float32), compiled_lse.to(dtype=torch.float32), atol=tolerance.atol, rtol=tolerance.rtol, ) @supported_platform @skip_on_cpu def test_logsumexp_only_return(self, device): make_tensor = functools.partial( torch.randn, (B, H, S, D), dtype=torch.float32, device=device, requires_grad=True, ) q, k, v = make_tensor(), make_tensor(), make_tensor() @torch.compile def func(q, k, v, score_mod): _, lse = flex_attention(q, k, v, score_mod, return_lse=True) lse_2 = lse * 2 return lse_2 _, code = run_and_get_code(func, q, k, v, _identity) # Ensure that we're still generating the flexattention kernel FileCheck().check_count(".run(primals_1, primals_2, primals_3", 1, True).run( code[0] ) @supported_platform @skip_on_cpu @common_utils.parametrize( "score_mod", [_identity, _causal, _times_two, _squared, _trig, _trig2] ) def test_aot_eager_gradcheck(self, device, score_mod): make_tensor = functools.partial( torch.randn, (2, 2, 11, 4), device=device, dtype=torch.float64, requires_grad=True, ) query, key, value = make_tensor(), make_tensor(), make_tensor() func = torch.compile(flex_attention, backend="aot_eager", fullgraph=True) self.assertTrue( torch.autograd.gradcheck( func, (query, key, value, score_mod), raise_exception=True ) ) @supported_platform @skip_on_cpu def test_eager_backward_strides(self, device): class Repro(torch.nn.Module): def 
__init__(self): super().__init__() self.qkv_proj = torch.nn.Linear(256, 256 * 3) self.n_head = 256 // 64 self.d_attn = 256 def forward(self, x): n_batch, n_ctx, _ = x.shape q, k, v = self.qkv_proj(x).split( [self.d_attn, self.d_attn, self.d_attn], dim=2 ) q = q.reshape(n_batch, n_ctx, self.n_head, -1) k = k.reshape(n_batch, n_ctx, self.n_head, -1) v = v.reshape(n_batch, n_ctx, self.n_head, -1) q = q.transpose(1, 2) k = k.transpose(1, 2) v = v.transpose(1, 2) x = torch.nn.attention.flex_attention.flex_attention(q, k, v) return x model = Repro().to(device) x = torch.randn((1, 512, 256), device=device, requires_grad=True) out = torch.compile(model, backend="aot_eager", fullgraph=True)(x) out.backward(torch.ones_like(out)) @supported_platform @skip_on_cpu def test_differentiable_logsumexp_gradcheck(self, device): make_tensor = functools.partial( torch.randn, (2, 2, 11, 4), device=device, dtype=torch.float64, requires_grad=True, ) query, key, value = make_tensor(), make_tensor(), make_tensor() def flex_attention_lse_only(q, k, v): return flex_attention(q, k, v, return_lse=True)[1] func = torch.compile(flex_attention_lse_only, backend="aot_eager") self.assertTrue( torch.autograd.gradcheck(func, (query, key, value), raise_exception=True) ) @supported_platform @skip_on_cpu def test_differentiable_logsumexp_compiled(self, device): make_tensor = functools.partial( torch.randn, (2, 2, 128, 64), device=device, dtype=torch.float32, requires_grad=True, ) q, k, v = make_tensor(), make_tensor(), make_tensor() lse_mask = torch.randn(2, 2, 128, device=device) out, lse = flex_attention(q, k, v, return_lse=True) (out.mean() + (lse * lse_mask).sum()).backward() q_grad, k_grad, v_grad = q.grad, k.grad, v.grad q.grad = None k.grad = None v.grad = None out2, lse2 = torch.compile(flex_attention)(q, k, v, return_lse=True) (out2.mean() + (lse2 * lse_mask).sum()).backward() q_grad2, k_grad2, v_grad2 = q.grad, k.grad, v.grad tolerance = Tolerances(atol=1e-1, rtol=1e-1) torch.testing.assert_close(out, out2, atol=tolerance.atol, rtol=tolerance.rtol) torch.testing.assert_close(lse, lse2, atol=tolerance.atol, rtol=tolerance.rtol) torch.testing.assert_close( q_grad, q_grad2, atol=tolerance.atol, rtol=tolerance.rtol ) torch.testing.assert_close( k_grad, k_grad2, atol=tolerance.atol, rtol=tolerance.rtol ) torch.testing.assert_close( v_grad, v_grad2, atol=tolerance.atol, rtol=tolerance.rtol ) # Use weird mask to test reusing block_mask does work well. @supported_platform @skip_on_cpu def _test_block_mask_reuse_with_weird_mask(self, device): def mask(b, h, q, kv): return (kv < 256) | (kv >= 2048) make_tensor = functools.partial( torch.randn, (4, 4, 4096, 64), device=device, dtype=torch.float32, requires_grad=True, ) block_mask = create_block_mask(mask, None, None, 4096, 4096, device=device) # Compile 1st version with q/k/v(seqlen=4096) and block_mask(seqlen=4096) torch.compile(flex_attention, dynamic=True, fullgraph=True)( make_tensor(), make_tensor(), make_tensor(), block_mask=block_mask ) make_tensor2 = functools.partial( torch.randn, (4, 4, 2048, 64), device=device, dtype=torch.float32, requires_grad=True, ) q, k, v = make_tensor2(), make_tensor2(), make_tensor2() # Compile 2nd version with q/k/v(seqlen=2048) and block_mask(seqlen=4096), # The graph includes the BlockMask._adjust part. 
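# Roughly: _adjust is expected to crop the seqlen-4096 mask metadata down to the
# 2048-token inputs inside the traced graph, which is what lets the same BlockMask
# object be reused here without rebuilding it.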
out = torch.compile(flex_attention, dynamic=True, fullgraph=True)( q, k, v, block_mask=block_mask ) out.sum().backward() q_grad, k_grad, v_grad = q.grad, k.grad, v.grad q.grad = None k.grad = None v.grad = None block_mask2 = create_block_mask(mask, None, None, 2048, 2048, device=device) # Reuse the 1st version with q/k/v(seqlen=2048) and block_mask(seqlen=2048) out2 = torch.compile(flex_attention, dynamic=True, fullgraph=True)( q, k, v, block_mask=block_mask2 ) out2.sum().backward() q_grad2, k_grad2, v_grad2 = q.grad, k.grad, v.grad tolerance = Tolerances(atol=1e-3, rtol=1e-3) torch.testing.assert_close(out, out2, atol=tolerance.atol, rtol=tolerance.rtol) torch.testing.assert_close( q_grad, q_grad2, atol=tolerance.atol, rtol=tolerance.rtol ) torch.testing.assert_close( k_grad, k_grad2, atol=tolerance.atol, rtol=tolerance.rtol ) torch.testing.assert_close( v_grad, v_grad2, atol=tolerance.atol, rtol=tolerance.rtol ) @supported_platform @skip_on_cpu def test_float32_matmul_precision(self, device): make_tensor = functools.partial( torch.zeros, (2, 2, 128, 32), device=device, dtype=torch.float32, requires_grad=False, ) query, key, value = make_tensor(), make_tensor(), make_tensor() query.fill_(0.2) key.fill_(0.3) value.fill_(0.4) query.requires_grad = True key.requires_grad = True value.requires_grad = True def score_mod(score, b, h, q, kv): return score * 2 with temp_float32_matmul_precision("highest"): out_eager = flex_attention(query, key, value, score_mod) flex_compiled = torch.compile(flex_attention, fullgraph=True) out_compiled = flex_compiled(query, key, value, score_mod) grads_eager = torch.autograd.grad(out_eager.sum(), (query, key, value)) grads_compile = torch.autograd.grad(out_compiled.sum(), (query, key, value)) torch.testing.assert_close(grads_eager, grads_compile) @supported_platform @skip_on_cpu @common_utils.parametrize("score_mod_name", ["_head_offset"]) @common_utils.parametrize("mode", ["eager", "aot_eager"]) def test_captured_score_mod_aot_eager_gradcheck( self, device, score_mod_name: str, mode: str ): make_tensor = functools.partial( torch.randn, (2, 2, 11, 4), device=device, dtype=torch.float64, requires_grad=True, ) query, key, value = make_tensor(), make_tensor(), make_tensor() func = torch.compile(flex_attention, backend=mode, fullgraph=True) score_mod = captured_buffers_map[score_mod_name](torch.float64, device) self.assertTrue( torch.autograd.gradcheck( func, (query, key, value, score_mod), raise_exception=True ) ) @supported_platform @skip_on_cpu @common_utils.parametrize("mode", ["eager", "aot_eager"]) def test_document_masking_edge_case(self, device, mode): requires_grad = device in DEVICE_SUPPORTS_BACKWARDS document_masks = torch.full((2, 128), 0, dtype=torch.int32, device=device) document_masks[:, 64:] = 1 def mask_mod(b, h, q, kv): same_doc = document_masks[b, q] == document_masks[b, kv] return same_doc make_tensor = functools.partial( torch.randn, (2, 1, 128, 4), device=device, dtype=torch.float64, requires_grad=requires_grad, ) query, key, value = make_tensor(), make_tensor(), make_tensor() func = torch.compile(flex_attention, backend=mode, fullgraph=True) block_mask = create_block_mask(mask_mod, 2, 1, 128, 128, device=device) out = func(query, key, value, block_mask=block_mask) if requires_grad: out.sum().backward() @supported_platform @skip_on_cpu def test_strided_backwards(self, device): shape = (1, 2, 4096, 64) Q = torch.randn(shape, requires_grad=True, device=device) K = torch.randn(shape, requires_grad=True, device=device) V = torch.randn(shape, 
requires_grad=True, device=device) func = torch.compile(flex_attention, dynamic=True, fullgraph=True) K_sliced = K[:, :, :-128] V_sliced = V[:, :, :-128] out_eager = flex_attention(Q, K_sliced, V_sliced) out_compiled = func(Q, K_sliced, V_sliced) grad = torch.rand_like(out_eager) eager_grads = torch.autograd.grad(out_eager, (Q, K, V), grad) compiled_grads = torch.autograd.grad(out_compiled, (Q, K, V), grad) for eager, compiled in zip(eager_grads, compiled_grads): torch.testing.assert_close(eager, compiled, atol=9e-3, rtol=0) @supported_platform @skip_on_cpu @common_utils.parametrize("mode", ["eager", "inductor", "paged_attention"]) @common_utils.parametrize( "permute_order", [ (0, 1, 2, 3), # Default order (1, 0, 2, 3), # Reverse order (0, 2, 1, 3), # Mixed order (2, 0, 1, 3), # Another mixed order (0, 1, 3, 2), # Non contiguous last dim ], ) @common_utils.parametrize("shape", [(2, 1, 128, 16), (4, 2, 64, 16)]) def test_flex_attention_stride_ordering(self, device, mode, permute_order, shape): from torch._inductor.ir import get_stride_order if torch.version.hip and mode == "paged_attention": raise self.skipTest( "TODO: figure out why mode_paged_attention_permute_order3_shape0 on MI200 caused mem fault" ) dtype = torch.float32 # Setup requires_grad = device in DEVICE_SUPPORTS_BACKWARDS make_tensor = functools.partial( torch.randn, shape, device=device, dtype=dtype, requires_grad=False if mode == "paged_attention" else requires_grad, ) # Create and permute tensors query, key, value = make_tensor(), make_tensor(), make_tensor() query = query.permute(permute_order) key = key.permute(permute_order) value = value.permute(permute_order) if mode == "inductor": func = torch.compile(flex_attention, backend=mode, fullgraph=True) out = func(query, key, value) elif mode == "paged_attention": out, _ = self.run_paged_attention( _identity, query, key, value, dtype, device=device ) else: func = flex_attention out = func(query, key, value) out_stride_order = get_stride_order(out.stride()) query_stride_order = get_stride_order(query.stride()) self.assertEqual( out_stride_order, query_stride_order, f"Stride order mismatch: out {out_stride_order}, query {query_stride_order}", ) @supported_platform @skip_on_cpu @common_utils.parametrize("mode", ["eager", "inductor"]) @common_utils.parametrize( "permute_order", [(0, 1, 2, 3), (1, 0, 2, 3), (0, 2, 1, 3), (2, 0, 1, 3), (0, 1, 3, 2)], ) @common_utils.parametrize("shape", [(2, 5, 128, 16), (4, 2, 64, 16)]) def test_flex_attention_backward_stride_ordering( self, device, mode, permute_order, shape ): from torch._inductor.ir import get_stride_order dtype = torch.float32 make_tensor = functools.partial( torch.randn, shape, device=device, dtype=dtype, requires_grad=False ) query, key, value = make_tensor(), make_tensor(), make_tensor() query = query.permute(permute_order) key = key.permute(permute_order) value = value.permute(permute_order) query.requires_grad_() key.requires_grad_() value.requires_grad_() func = ( torch.compile(flex_attention, backend=mode, fullgraph=True) if mode == "inductor" else flex_attention ) out = func(query, key, value) grad_output = torch.randn_like(out) out.backward(grad_output) for leaf, grad, name in [ (query, query.grad, "query"), (key, key.grad, "key"), (value, value.grad, "value"), ]: input_stride_order = get_stride_order(grad.stride()) orig_stride_order = get_stride_order(leaf.stride()) self.assertEqual( input_stride_order, orig_stride_order, f"Mode: {mode}, Stride order mismatch for {name}: grad {input_stride_order}, input 
{orig_stride_order}.", ) @supported_platform def test_non_contiguous_last_dim(self, device): """Test flex_attention with tensors having non contiguous last dimension.""" B, H, D = 4, 8, 64 dtype = torch.float16 if device in DEVICE_SUPPORTS_BACKWARDS else torch.float32 for S in [16, 64]: def column_major_tensor(): tensor = torch.randn( (B, H, S, D), dtype=dtype, device=device, ) # Column major in last 2 dims return tensor.transpose(-1, -2).contiguous().transpose(-1, -2) q = column_major_tensor() k = column_major_tensor() v = column_major_tensor() requires_grad = device in DEVICE_SUPPORTS_BACKWARDS if requires_grad: q.requires_grad_(True) k.requires_grad_(True) v.requires_grad_(True) self.assertNotEqual(q.stride()[-1], 1) self.assertNotEqual(k.stride()[-1], 1) self.assertNotEqual(v.stride()[-1], 1) q_ref, k_ref, v_ref = query_key_value_clones(q, k, v) q_gold, k_gold, v_gold = query_key_value_clones(q, k, v, torch.float64) golden_out = flex_attention(q_gold, k_gold, v_gold) ref_out = flex_attention(q_ref, k_ref, v_ref) flex_compiled = torch.compile(flex_attention, fullgraph=True, dynamic=True) compiled_out = flex_compiled(q, k, v) self._check_out(golden_out, ref_out, compiled_out) if requires_grad: backward_grad = torch.randn_like(ref_out) golden_out.backward(backward_grad.to(torch.float64)) ref_out.backward(backward_grad) compiled_out.backward(backward_grad) self._check_out_and_grad( golden_out, ref_out, compiled_out, q_gold, q_ref, q, k_gold, k_ref, k, v_gold, v_ref, v, ) @supported_platform @common_utils.parametrize("compile", [True, False]) def test_fully_masked_out_rows_0_check(self, device, compile: bool): # Ensure fully masked out rows won't cause NaNs. requires_grad = device in DEVICE_SUPPORTS_BACKWARDS query = torch.randn( (B, H, S, D), dtype=torch.float32, device=device, requires_grad=requires_grad, ) key = torch.randn( (B, H, S, D), dtype=torch.float32, device=device, requires_grad=requires_grad, ) value = torch.randn( (B, H, S, D), dtype=torch.float32, device=device, requires_grad=requires_grad, ) M = S // 2 def mask_mod(b, h, q, kv): return q < M block_mask = create_block_mask(mask_mod, B, 1, S, S, device=device) flex = ( torch.compile(flex_attention, dynamic=False) if compile else flex_attention ) if requires_grad: out, lse = flex(query, key, value, block_mask=block_mask, return_lse=True) self.assertEqual(out[:, :, M:, :].sum(), 0) self.assertTrue((lse[:, :, M:] == -float("inf")).all()) loss = out.sum() + lse.sum() loss.backward() self.assertEqual(query.grad[:, :, M:, :].sum(), 0) else: out = flex(query, key, value, block_mask=block_mask, return_lse=False) self.assertEqual(out[:, :, M:, :].sum(), 0) @supported_platform def test_fully_masked_out_rows(self, device): M = S // 2 def mask_mod(b, h, q, kv): return q < M block_mask = create_block_mask(mask_mod, B, 1, S, S, device=device) def noop_mod(score, b, h, q_idx, kv_idx): return score self.run_test( noop_mod, torch.float32, device, B, H, S, D, B, H, S, D, block_mask ) @supported_platform @skip_on_cpu def test_kernel_options_argument_is_respected(self, device): make_tensor = functools.partial( torch.randn, (2, 2, 128, 64), device=device, dtype=torch.float32, requires_grad=True, ) q, k, v = make_tensor(), make_tensor(), make_tensor() # Ensure we respect user's input kernel options. 
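# The FileCheck below looks for BLOCK_M as a tl.constexpr equal to 16 in the generated
# Triton source, i.e. the user-supplied kernel_options value should survive lowering
# instead of being replaced by an autotuned default.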
_, code = run_and_get_code( torch.compile(flex_attention, fullgraph=True), q, k, v, kernel_options={"BLOCK_M": 16}, ) FileCheck().check("BLOCK_M : tl.constexpr = 16").run(code[0]) @supported_platform def test_block_mask_non_divisible(self, device): seq = torch.arange(1023, device=device) // 128 def mod(b, h, q, kv): return seq[q] == seq[kv] block_mask = create_block_mask(mod, None, None, 1023, 1023, device=device) torch.compile(create_block_mask)(mod, None, None, 1023, 1023, device=device) self.run_test_with_call( lambda q, k, v: flex_attention(q, k, v, block_mask=block_mask), torch.float16, device, Q_S=1023, KV_S=1023, ) @supported_platform def test_causal_block_non_divisible(self, device): def mask_mod(b, h, q, kv): return q >= kv block_mask = create_block_mask(mask_mod, B, 1, S - 1, S - 1, device=device) attention = functools.partial(flex_attention, block_mask=block_mask) self.run_test_with_call(attention, torch.float16, device, Q_S=S - 1, KV_S=S - 1) @supported_platform @skip_on_cpu def test_modular_indexing(self, device): B, H, N, D = 100, 12, 128, 64 dtype = torch.bfloat16 device = torch.device(device) class Attention(torch.nn.Module): def __init__(self): super().__init__() self.bias = torch.randn(B, N, N, H, device=device, dtype=dtype) def forward( self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor ) -> torch.Tensor: score_mod = generate_score_mod(self.bias) o = flex_attention(q, k, v, score_mod=score_mod) return o def generate_score_mod(bias): bias = (2 * bias).view(B, H, N, N).contiguous() def score_mod(score, batch, head, q_idx, k_idx): attn_bias = bias[batch, head, q_idx, k_idx] return score + attn_bias return score_mod m = Attention().to(device).eval().to(dtype) m = torch.compile(m, mode="default", fullgraph=False) q = torch.randn(B, H, N, D, device=device, dtype=dtype) k = torch.randn(B, H, N, D, device=device, dtype=dtype) v = torch.randn(B, H, N, D, device=device, dtype=dtype) m(q, k, v) @supported_platform @skip_on_cpu def test_force_write_lse(self, device): dtype = torch.float32 make_tensor = functools.partial( torch.randn, (2, 2, 128, 16), device=device, dtype=dtype, requires_grad=False, ) query, key, value = make_tensor(), make_tensor(), make_tensor() out_eager, lse_eager = flex_attention(query, key, value, return_lse=True) flex_compile = torch.compile(flex_attention) out_compiled, lse_compiled = flex_compile(query, key, value, return_lse=True) out_paged, lse_paged = self.run_paged_attention( score_mod=_identity, q=query, k=key, v=value, dtype=dtype, device=device ) torch.testing.assert_close(lse_eager, lse_compiled, atol=3e-3, rtol=0) requires_grad = device in DEVICE_SUPPORTS_BACKWARDS if requires_grad: torch.testing.assert_close(lse_eager, lse_paged, atol=3e-3, rtol=0) @supported_platform @skip_on_cpu @common_utils.parametrize("backend", ["flex_attention", "flex_decode", "eager"]) def test_lse_masked_output(self, device, backend): if backend == "flex_decode": kernel_options = {"FORCE_USE_FLEX_ATTENTION": False} flex_call = torch.compile(flex_attention, fullgraph=True) N_CTX = 96 elif backend == "flex_attention": kernel_options = {"FORCE_USE_FLEX_ATTENTION": True} flex_call = torch.compile(flex_attention, fullgraph=True) N_CTX = 196 else: kernel_options = {} flex_call = flex_attention N_CTX = 196 SLIDING_WINDOW = 64 make_tensor = functools.partial( torch.randn, (2, 2, N_CTX, 64), device=device, dtype=torch.float32, requires_grad=True, ) def sliding_window_causal(b, h, q_idx, kv_idx): causal_mask = q_idx >= kv_idx window_mask = q_idx - kv_idx <= SLIDING_WINDOW 
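# causal_mask keeps kv_idx <= q_idx and window_mask keeps only the last SLIDING_WINDOW
# tokens; global_causal below is the causal complement (q_idx - kv_idx > SLIDING_WINDOW),
# so the two block masks partition the causal region and their partial outputs can be
# recombined exactly via the returned LSEs, which is what the rest of this test checks.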
return causal_mask & window_mask def global_causal(b, h, q_idx, kv_idx): causal_mask = q_idx >= kv_idx window_mask = q_idx - kv_idx > SLIDING_WINDOW return causal_mask & window_mask sliding_window_causal = torch.nn.attention.flex_attention.create_block_mask( sliding_window_causal, B=None, H=None, Q_LEN=N_CTX, KV_LEN=N_CTX, device=device, ) global_causal = torch.nn.attention.flex_attention.create_block_mask( global_causal, B=None, H=None, Q_LEN=N_CTX, KV_LEN=N_CTX, device=device ) local_attn = functools.partial( flex_call, block_mask=sliding_window_causal, return_lse=True, kernel_options=kernel_options, ) global_attn = functools.partial( flex_call, block_mask=global_causal, return_lse=True, kernel_options=kernel_options, ) q, k, v = make_tensor(), make_tensor(), make_tensor() gradOut = make_tensor(requires_grad=False) x_local, lse_local = local_attn(q, k, v) x_global, lse_global = global_attn(q, k, v) max_lse = torch.maximum(lse_local, lse_global) lse_global = lse_global - max_lse lse_local = lse_local - max_lse lse_global = torch.exp(lse_global) lse_local = torch.exp(lse_local) x = ((x_local * lse_local[..., None]) + (x_global * lse_global[..., None])) / ( lse_global[..., None] + lse_local[..., None] ) x.backward(gradOut) flex_q_grad, flex_k_grad, flex_v_grad = q.grad, k.grad, v.grad q.grad = None k.grad = None v.grad = None out = torch.nn.functional.scaled_dot_product_attention(q, k, v, is_causal=True) out.backward(gradOut) torch.testing.assert_close(x, out, atol=3e-3, rtol=2e-3) torch.testing.assert_close(flex_q_grad, q.grad, atol=3e-3, rtol=2e-3) torch.testing.assert_close(flex_k_grad, k.grad, atol=3e-3, rtol=2e-3) torch.testing.assert_close(flex_v_grad, v.grad, atol=3e-3, rtol=2e-3) @supported_platform @skip_on_cpu def test_mixed_device_error_message(self, device): # Create tensors on different devices cpu_tensor = torch.randn(2, 2, 128, 16, device="cpu") gpu_tensor = torch.randn(2, 2, 128, 16, device=device) # Use different devices for query, key, and value query, key, value = cpu_tensor, gpu_tensor, cpu_tensor expected_error_message = ( "Expected query, key, and value to have the same device type, " f"but got query.device: {query.device}, key.device: {key.device}, " f"and value.device: {value.device} instead." 
) with self.assertRaisesRegex(ValueError, expected_error_message): flex_attention(query, key, value) @supported_platform @skip_on_cpu def test_captured_wrong_device_error_message(self, device): # length_scales is intentionally created on CPU while q/k/v live on the test device, so capturing it in the score_mod should fail means = torch.randn(64, 3, device=device) length_scales = torch.logspace(0.001, 0.1, 8, device="cpu") def euclidean_dist_pos_embed(score, b, h, q_idx, k_idx): q_pos = means[q_idx] k_pos = means[k_idx] dist = (q_pos - k_pos).pow(2).sum(-1).sqrt() scale = length_scales[h] inv_dist = torch.exp(-dist / scale) return inv_dist * score expected_error_message = "Buffers cannot be created" q, k, v = (torch.randn(1, 8, 64, 64, device=device) for _ in range(3)) with self.assertRaisesRegex(RuntimeError, expected_error_message): torch.compile(flex_attention)(q, k, v, score_mod=euclidean_dist_pos_embed) @supported_platform @skip_on_cpu def test_cant_lower_error_message(self, device): # We can't lower a 256-element reduction inside a pointwise reduction means = torch.randn(64, 256, device=device) length_scales = torch.logspace(0.001, 0.1, 8, device=device) def euclidean_dist_pos_embed(score, b, h, q_idx, k_idx): q_pos = means[q_idx] k_pos = means[k_idx] dist = (q_pos - k_pos).pow(2).sum(-1).sqrt() scale = length_scales[h] inv_dist = torch.exp(-dist / scale) return inv_dist * score expected_error_message = "Buffers cannot be created" q, k, v = (torch.randn(1, 8, 64, 64, device=device) for _ in range(3)) with self.assertRaisesRegex(RuntimeError, expected_error_message): torch.compile(flex_attention)(q, k, v, score_mod=euclidean_dist_pos_embed) @supported_platform @skip_on_cpu def test_reduction_unrolled(self, device): # Small per-score reduction (3 elements) inside the score_mod; this should be unrolled and lower fine means = torch.randn(S, 3, device=device) length_scales = torch.logspace(0.001, 0.1, H, device=device) def euclidean_dist_pos_embed(score, b, h, q_idx, k_idx): q_pos = means[q_idx] k_pos = means[k_idx] dist = (q_pos - k_pos).pow(2).sum(-1).sqrt() scale = length_scales[h] inv_dist = torch.exp(-dist / scale) return inv_dist * score self.run_test(euclidean_dist_pos_embed, torch.bfloat16, device=device) @supported_platform @skip_on_cpu def test_invalid_block_size(self, device): # BLOCK_SIZE=96 below is not divisible by the kernel's BLOCK_M/BLOCK_N, so compilation should error q, k, v = (torch.randn(1, 8, 128, 64, device=device) for _ in range(3)) expected_error_message = ( "ValueError: Q and KV block size must be divisible by BLOCK_M and BLOCK_N."
) block_mask = create_block_mask( noop_mask, 1, 8, 128, 128, BLOCK_SIZE=96, device=device ) with self.assertRaisesRegex(RuntimeError, expected_error_message): torch.compile(flex_attention)(q, k, v, block_mask=block_mask) @supported_platform @skip_on_cpu def test_small_q_kv_len(self, device): make_tensor = functools.partial( torch.ones, (1, 1, 1, 16), device=device, dtype=torch.float32, requires_grad=True, ) query, key, value = make_tensor(), make_tensor(), make_tensor() kernel_options = {"FORCE_USE_FLEX_ATTENTION": True} out_eager, lse_eager = flex_attention( query, key, value, return_lse=True, kernel_options=kernel_options ) flex_compile = torch.compile(flex_attention, fullgraph=True) out_compiled, lse_compiled = flex_compile( query, key, value, return_lse=True, kernel_options=kernel_options ) assert torch.equal(out_eager, out_compiled) assert torch.equal(lse_eager, lse_compiled) grads_eager = torch.autograd.grad(out_eager.sum(), (query, key, value)) grads_compile = torch.autograd.grad(out_compiled.sum(), (query, key, value)) torch.testing.assert_close(grads_eager, grads_compile) @supported_platform @skip_on_cpu def test_dynamic_shapes_bug_dynamic_batch(self, device): def _flex_attention_mask(b, h, q_idx, kv_idx, input_lengths): padding_condition = (q_idx < input_lengths[b]) & (kv_idx < input_lengths[b]) return padding_condition counter = CompileCounterWithBackend("inductor") class Model(torch.nn.Module): def __init__(self, dim=1024): super().__init__() self.subsampler = torch.nn.Conv1d(256, 256, 5) self.projector = torch.nn.Linear(256, dim) self.num_heads = 4 def forward(self, x, input_lengths): x = self.subsampler(x.transpose(-1, -2)).transpose(-1, -2) x = self.projector(x).transpose(0, 1) head_dim = x.size(-1) // self.num_heads x = x.view(-1, x.size(1), self.num_heads, head_dim) x = x.permute(1, 2, 0, 3) max_time = x.size(-2) mask = torch.compile(create_block_mask, dynamic=True, fullgraph=False)( functools.partial( _flex_attention_mask, input_lengths=input_lengths, ), B=input_lengths.size(0), H=None, Q_LEN=max_time, KV_LEN=max_time, device=device, ) x = torch.compile( flex_attention, dynamic=True, fullgraph=True, backend=counter )( query=x, key=x, value=x, block_mask=mask, ) return x model = Model(128).to(device) B, F, T = 16, 256, 12 for _ in range(5): x = torch.randn(B, T, F, device=device) l = torch.randint(0, T, (B,), device=device) model(x, l) assert counter.frame_count == 1, ( f"Expected 1 graph, but got {counter.frame_count} graphs" ) @supported_platform @skip_on_cpu def test_dynamic_shapes_with_custom_kernel_options(self, device): make_tensor = functools.partial( torch.ones, (8, 8, 1024, 64), device=device, dtype=torch.bfloat16, ) query, key, value = make_tensor(), make_tensor(), make_tensor() kernel_options = {"BLOCK_M": 64, "BLOCK_N": 64} out_eager = flex_attention(query, key, value, kernel_options=kernel_options) flex_compile = torch.compile(flex_attention, fullgraph=True, dynamic=True) out_compiled = flex_compile(query, key, value, kernel_options=kernel_options) torch.testing.assert_close(out_eager, out_compiled, atol=3e-3, rtol=2e-3) @supported_platform def test_dynamic_shapes_with_max_autotune(self, device): make_tensor = functools.partial( torch.ones, (8, 8, 1024, 64), device=device, dtype=torch.float if device == "cpu" else torch.bfloat16, ) query, key, value = make_tensor(), make_tensor(), make_tensor() block_mask = create_block_mask( _causal_mask, None, None, 1024, 1024, device=device ) out_eager = flex_attention(query, key, value, block_mask=block_mask) flex_compile 
= torch.compile( flex_attention, fullgraph=True, dynamic=True, mode="max-autotune" ) out_compiled = flex_compile(query, key, value, block_mask=block_mask) torch.testing.assert_close(out_eager, out_compiled, atol=3e-3, rtol=2e-3) @supported_platform @skip_on_cpu def test_zero_length_sequence_error(self, device): make_tensor = functools.partial( torch.ones, (8, 8, 0, 64), # Zero in sequence dimension device=device, dtype=torch.bfloat16, ) query, key, value = make_tensor(), make_tensor(), make_tensor() # Test compiled mode - should also raise assertion error flex_compile = torch.compile(flex_attention, fullgraph=True) with self.assertRaisesRegex( torch._inductor.exc.InductorError, "Query length must be greater than 0" ): flex_compile(query, key, value) @supported_platform def test_causal_block_non_divisible_with_captured_buffer( self, device, ): Q_S = S - 3 KV_S = S - 3 offset_q = torch.randn(Q_S, device=device, dtype=torch.bfloat16) offset_kv = torch.randn(KV_S, device=device, dtype=torch.bfloat16) def score_mod(score, b, h, q, kv): return score + offset_q[q] + offset_kv[kv] def mask_mod(b, h, q, kv): return q >= kv block_mask = create_block_mask(mask_mod, B, 1, Q_S, KV_S, device=device) attention = functools.partial(flex_attention, block_mask=block_mask) self.run_test_with_call( attention, Q_S=Q_S, KV_S=KV_S, dtype=torch.bfloat16, device=device ) @supported_platform def test_non_divisible_with_captured_buffer(self, device): Q_S = S + 3 KV_S = S + 3 multiplier = torch.randn(Q_S, device=device, dtype=torch.bfloat16) def apply_multiplicative_bias(score, b, h, q_idx, kv_idx): return score * multiplier[q_idx] attention = functools.partial( flex_attention, score_mod=apply_multiplicative_bias ) self.run_test_with_call( attention, Q_S=Q_S, KV_S=KV_S, dtype=torch.bfloat16, device=device ) @supported_platform def test_num_warps_8_error(self, device): attention = functools.partial(flex_attention, score_mod=_identity) self.run_test_with_call( attention, dtype=torch.float16, device=device, Q_S=128, KV_S=128, Q_D=128, V_D=128, ) @supported_platform @unittest.skipIf(not TEST_MULTIGPU, "detected only one GPU") def test_qkv_and_block_mask_on_the_same_device(self, device): make_tensor = functools.partial( torch.ones, (2, 2, 256, 32), device="cuda:0", dtype=torch.float32, requires_grad=True, ) query, key, value = make_tensor(), make_tensor(), make_tensor() def mask_mod(b, h, q, kv): return q >= kv block_mask = create_block_mask(mask_mod, 1, 1, 256, 256, device="cuda:1") with self.assertRaisesRegex( RuntimeError, "Expect q/k/v and block_mask to be on the same device" ): torch.compile(flex_attention)(query, key, value, block_mask=block_mask) @supported_platform @skip_on_cpu @unittest.skipIf(config.triton.native_matmul, "different dynamo counters") def test_free_symbol_dynamic(self, device): def batch_flip_causal(b, h, q_idx, kv_idx): return (q_idx >= kv_idx) & (b % 2 == 0) class SimpleAttention(torch.nn.Module): def __init__(self, dim=512, n_head=8): super().__init__() self.qkv = torch.nn.Linear(dim, 3 * dim) self.n_head = n_head self.head_dim = dim // n_head def forward(self, x, block_mask=None): B, T, C = x.size() qkv = self.qkv(x).view(B, T, 3, self.n_head, self.head_dim) qkv = qkv.permute(2, 0, 3, 1, 4) q, k, v = qkv y = flex_attention( q, k, v, block_mask=block_mask, ) return y.transpose(1, 2).contiguous().view(B, T, C) model = SimpleAttention().to(device) model.compile(mode="default", dynamic=True) sequence_len = 256 # Test different batch shapes with dense masks torch._dynamo.reset() for batch_shape in 
[4, 16, 32]: # Create dense mask rand_mask = torch.randint( 0, 2, (batch_shape, sequence_len), device=device ).bool() block_mask = torch.compile(create_block_mask, dynamic=True)( B=batch_shape, BLOCK_SIZE=128, mask_mod=lambda b, h, q_idx, kv_idx: ~rand_mask[b, q_idx], H=None, Q_LEN=sequence_len, KV_LEN=sequence_len, device=device, ) # Run forward pass x = torch.randn(batch_shape, sequence_len, 512, device=device) model(x, block_mask=block_mask) self.assertEqual(torch._dynamo.utils.counters["aot_autograd"]["ok"], 2) @supported_platform @skip_on_cpu def test_symbol_closure_in_score_mod(self, device): class SimpleAttention(torch.nn.Module): def __init__(self, dim=512, n_head=8): super().__init__() self.qkv = torch.nn.Linear(dim, 3 * dim) self.n_head = n_head self.head_dim = dim // n_head def forward(self, x, block_mask=None): B, T, C = x.size() qkv = self.qkv(x).view(B, T, 3, self.n_head, self.head_dim) qkv = qkv.permute(2, 0, 3, 1, 4) q, k, v = qkv return flex_attention( q, k, v, score_mod=lambda s, b, h, q, k: s + B, block_mask=block_mask, ) model = SimpleAttention().to(device) from torch._dynamo.testing import EagerAndRecordGraphs backend = EagerAndRecordGraphs() model.compile(mode="default", dynamic=True, backend=backend) sequence_len = 256 torch._dynamo.reset() for batch_shape in [4, 16, 32]: x = torch.randn(batch_shape, sequence_len, 512, device=device) model(x) self.assertEqual(len(backend.graphs), 1) self.assertExpectedInline( backend.graphs[0].score_mod_0.code.strip(), """\ def forward(self, child : torch.Tensor, child_1 : torch.Tensor, child_2 : torch.Tensor, child_3 : torch.Tensor, child_4 : torch.Tensor, getitem : torch.SymInt): add = child + getitem; child = getitem = None return add""", ) @supported_platform @skip_on_cpu def test_fw_bw_graph_correctness(self, device): cnt = CompileCounterWithBackend("aot_eager") make_tensor = functools.partial( torch.randn, (2, 2, 128, 4), device=device, dtype=torch.float64, requires_grad=True, ) query, key, value = make_tensor(), make_tensor(), make_tensor() def causal_mask(b, h, q_idx, kv_idx): return q_idx >= kv_idx block_mask = create_block_mask(causal_mask, 1, 1, 128, 128, device=device) func = torch.compile(flex_attention, backend=cnt, fullgraph=True) out = func(query, key, value, _squared, block_mask=block_mask) out.sum().backward() self.assertEqual(cnt.frame_count, 1) self.assertEqual(len(cnt.graphs), 1) graph = cnt.graphs[0] norm_graph = normalize_gm(graph.print_readable(print_output=False)) self.assertExpectedInline( norm_graph, """\ class GraphModule(torch.nn.Module): def forward(self, L_query_: "f64[2, 2, 128, 4]", L_key_: "f64[2, 2, 128, 4]", L_value_: "f64[2, 2, 128, 4]", L_block_mask_kv_indices: "i32[1, 1, 1, 1]", L_block_mask_kv_num_blocks: "i32[1, 1, 1]", L_block_mask_full_kv_num_blocks: "i32[1, 1, 1]", L_block_mask_full_kv_indices: "i32[1, 1, 1, 1]", L_block_mask_q_num_blocks: "i32[1, 1, 1]", L_block_mask_q_indices: "i32[1, 1, 1, 1]", L_block_mask_full_q_num_blocks: "i32[1, 1, 1]", L_block_mask_full_q_indices: "i32[1, 1, 1, 1]"): l_query_ = L_query_ l_key_ = L_key_ l_value_ = L_value_ l_block_mask_kv_indices = L_block_mask_kv_indices l_block_mask_kv_num_blocks = L_block_mask_kv_num_blocks l_block_mask_full_kv_num_blocks = L_block_mask_full_kv_num_blocks l_block_mask_full_kv_indices = L_block_mask_full_kv_indices l_block_mask_q_num_blocks = L_block_mask_q_num_blocks l_block_mask_q_indices = L_block_mask_q_indices l_block_mask_full_q_num_blocks = L_block_mask_full_q_num_blocks l_block_mask_full_q_indices = 
L_block_mask_full_q_indices score_mod_0 = self.score_mod_0 mask_fn_0 = self.mask_fn_0 flex_attention = torch.ops.higher_order.flex_attention(l_query_, l_key_, l_value_, score_mod_0, (128, 128, l_block_mask_kv_num_blocks, l_block_mask_kv_indices, l_block_mask_full_kv_num_blocks, l_block_mask_full_kv_indices, l_block_mask_q_num_blocks, l_block_mask_q_indices, l_block_mask_full_q_num_blocks, l_block_mask_full_q_indices, 128, 128, mask_fn_0), 0.5, {'PRESCALE_QK': False, 'ROWS_GUARANTEED_SAFE': False, 'BLOCKS_ARE_CONTIGUOUS': False, 'WRITE_DQ': True, 'OUTPUT_LOGSUMEXP': True, 'OUTPUT_MAX': False}, (), ()); l_query_ = l_key_ = l_value_ = score_mod_0 = l_block_mask_kv_num_blocks = l_block_mask_kv_indices = l_block_mask_full_kv_num_blocks = l_block_mask_full_kv_indices = l_block_mask_q_num_blocks = l_block_mask_q_indices = l_block_mask_full_q_num_blocks = l_block_mask_full_q_indices = mask_fn_0 = None out: "f64[2, 2, 128, 4]" = flex_attention[0]; flex_attention = None return (out,) class score_mod_0(torch.nn.Module): def forward(self, child: "f64[]", child_1: "i32[]", child_2: "i32[]", child_3: "i32[]", child_4: "i32[]"): mul: "f64[]" = child * child; child = None return mul class mask_fn_0(torch.nn.Module): def forward(self, child: "i32[]", child_1: "i32[]", child_2: "i32[]", child_3: "i32[]"): ge: "b8[]" = child_2 >= child_3; child_2 = child_3 = None return ge """, # noqa: B950 ) # Save the AOT graphs aot_graphs = [] from torch._inductor import compile_fx def debug_compile_fx_inner(graph, example_inputs, *args, **kwargs): aot_graphs.append(graph) return graph backend = functools.partial( compile_fx.compile_fx, inner_compile=debug_compile_fx_inner ) func = torch.compile(func, backend=backend, fullgraph=True) out = func(query, key, value, _squared) out.sum().backward() joint_graph = normalize_gm(aot_graphs[1].print_readable(print_output=False)) self.assertExpectedInline( joint_graph, """\ class GraphModule(torch.nn.Module): def forward(self, primals_1: "f64[2, 2, 128, 4]", primals_2: "f64[2, 2, 128, 4]", primals_3: "f64[2, 2, 128, 4]", full: "i32[1, 1, 1]", full_default: "i32[1, 1, 1, 1]", convert_element_type: "i32[1, 1, 1]", convert_element_type_1: "i32[1, 1, 1, 1]", getitem_2: "f64[2, 2, 128, 4]", getitem_3: "f32[2, 2, 128]", tangents_1: "f64[2, 2, 128, 4]"): full_default_4: "f32[2, 2, 128]" = torch.ops.aten.full.default([2, 2, 128], 0, dtype = torch.float32, layout = torch.strided, device = device(type='GPU_TYPE', index=0), pin_memory = False) fw_graph0 = self.fw_graph0 joint_graph0 = self.joint_graph0 mask_graph0 = self.mask_graph0 flex_attention_backward = torch.ops.higher_order.flex_attention_backward(primals_1, primals_2, primals_3, getitem_2, getitem_3, tangents_1, full_default_4, fw_graph0, joint_graph0, (1, 1, full, full_default, None, None, convert_element_type, convert_element_type_1, None, None, 1073741824, 1073741824, mask_graph0), 0.5, {'PRESCALE_QK': False, 'ROWS_GUARANTEED_SAFE': False, 'BLOCKS_ARE_CONTIGUOUS': False, 'WRITE_DQ': True, 'OUTPUT_LOGSUMEXP': True, 'OUTPUT_MAX': False}, (), ()); primals_1 = primals_2 = primals_3 = getitem_2 = getitem_3 = tangents_1 = full_default_4 = fw_graph0 = joint_graph0 = full = full_default = convert_element_type = convert_element_type_1 = mask_graph0 = None getitem_5: "f64[2, 2, 128, 4]" = flex_attention_backward[0] getitem_6: "f64[2, 2, 128, 4]" = flex_attention_backward[1] getitem_7: "f64[2, 2, 128, 4]" = flex_attention_backward[2]; flex_attention_backward = None return (getitem_5, getitem_6, getitem_7) class fw_graph0(torch.nn.Module): def 
forward(self, arg0_1: "f64[]", arg1_1: "i32[]", arg2_1: "i32[]", arg3_1: "i32[]", arg4_1: "i32[]"): mul: "f64[]" = torch.ops.aten.mul.Tensor(arg0_1, arg0_1); arg0_1 = None return mul class joint_graph0(torch.nn.Module): def forward(self, arg0_1: "f64[]", arg1_1: "i32[]", arg2_1: "i32[]", arg3_1: "i32[]", arg4_1: "i32[]", arg5_1: "f64[]"): mul_1: "f64[]" = torch.ops.aten.mul.Tensor(arg5_1, arg0_1) mul_2: "f64[]" = torch.ops.aten.mul.Tensor(arg5_1, arg0_1); arg5_1 = arg0_1 = None add: "f64[]" = torch.ops.aten.add.Tensor(mul_2, mul_1); mul_2 = mul_1 = None return [add, None, None, None, None] class mask_graph0(torch.nn.Module): def forward(self, arg0_1: "i32[]", arg1_1: "i32[]", arg2_1: "i32[]", arg3_1: "i32[]"): full_default: "b8[]" = torch.ops.aten.full.default([], True, dtype = torch.bool, layout = torch.strided, device = device(type='GPU_TYPE', index=0), pin_memory = False) return full_default """.replace( # noqa: B950 "GPU_TYPE", torch.device(device).type ), ) @supported_platform def test_tensor_subclass_dispatch_order(self, device): """Test that tensor subclasses get proper dispatch priority over modes. This test verifies the fix that allows tensor subclasses' pyimpl to run before FakeTensorMode/FunctionalTensorMode implementations, preventing issues where subclasses that error on as_strided would fail in flex_attention. """ import torch.utils._pytree as pytree from torch.utils._python_dispatch import return_and_correct_aliasing class AsStridedErrorTensor(torch.Tensor): @staticmethod def __new__(cls, elem): assert isinstance(elem, torch.Tensor) return torch.Tensor._make_wrapper_subclass( cls, elem.shape, strides=elem.stride(), storage_offset=elem.storage_offset(), dtype=elem.dtype, layout=elem.layout, device=elem.device, requires_grad=elem.requires_grad, ) def __init__(self, elem): self.elem = elem def __repr__(self): return f"AsStridedErrorTensor({self.elem})" def __tensor_flatten__(self): return ["elem"], None @staticmethod def __tensor_unflatten__(inner_tensors, meta, outer_size, outer_stride): assert meta is None elem = inner_tensors["elem"] return AsStridedErrorTensor(elem) @classmethod def __torch_dispatch__(cls, func, types, args, kwargs=None): # Error if as_strided is called if func is torch.ops.aten.as_strided.default: raise RuntimeError("as_strided was called on AsStridedErrorTensor!") if kwargs is None: kwargs = {} args_elem = pytree.tree_map_only( AsStridedErrorTensor, lambda x: x.elem, args ) kwargs_elem = pytree.tree_map_only( AsStridedErrorTensor, lambda x: x.elem, kwargs ) out = func(*args_elem, **kwargs_elem) def wrap_output(x): if isinstance(x, torch.Tensor): return AsStridedErrorTensor(x) return x out_wrapped = pytree.tree_map(wrap_output, out) return return_and_correct_aliasing(func, args, kwargs, out_wrapped) from torch._higher_order_ops.flex_attention import ( flex_attention as flex_attention_hop, ) @flex_attention_hop.py_impl(AsStridedErrorTensor) def flex_attention_as_strided_error_tensor( query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, score_mod, block_mask, scale, kernel_options, score_mod_other_buffers=(), mask_mod_other_buffers=(), ): inner_q, inner_k, inner_v = query.elem, key.elem, value.elem out, lse, max_scores = flex_attention_hop( inner_q, inner_k, inner_v, score_mod, block_mask, scale, kernel_options, score_mod_other_buffers, mask_mod_other_buffers, ) return ( AsStridedErrorTensor(out), AsStridedErrorTensor(lse), AsStridedErrorTensor(max_scores), ) # Test setup B, H, S, D = 2, 1, 128, 16 dtype = torch.float32 # Create regular tensors 
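# (plain torch.Tensors first: one is wrapped to check the direct as_strided error, and the
# unwrapped trio gives the baseline output; all three are wrapped in AsStridedErrorTensor
# further down for the actual dispatch-order check)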
query_elem = torch.randn(B, H, S, D, device=device, dtype=dtype) key_elem = torch.randn(B, H, S, D, device=device, dtype=dtype) value_elem = torch.randn(B, H, S, D, device=device, dtype=dtype) # Test 1: Verify as_strided raises error when called directly on AsStridedErrorTensor test_tensor = AsStridedErrorTensor(query_elem) with self.assertRaisesRegex( RuntimeError, "as_strided was called on AsStridedErrorTensor!" ): torch.as_strided( test_tensor, size=(B, H, S, D), stride=test_tensor.stride() ) # Test 2: Run flex_attention with normal tensors first compiled_fn = torch.compile(flex_attention, backend="aot_eager") normal_out, normal_lse = compiled_fn( query_elem, key_elem, value_elem, return_lse=True ) # Test 3: Wrap in our subclass query = AsStridedErrorTensor(query_elem) key = AsStridedErrorTensor(key_elem) value = AsStridedErrorTensor(value_elem) # This should NOT error with as_strided after the fix # Before the fix, it would error because FakeTensorMode would directly # call flex_attention_fake_impl which uses as_strided out, lse = compiled_fn(query, key, value, return_lse=True) # Verify we got valid output self.assertIsInstance(out, AsStridedErrorTensor) self.assertIsInstance(lse, AsStridedErrorTensor) self.assertEqual(out.shape, (B, H, S, D)) self.assertEqual(lse.shape, (B, H, S)) # Test 4: Compare outputs between normal tensors and subclassed tensors torch.testing.assert_close(out.elem, normal_out, rtol=1e-5, atol=1e-5) torch.testing.assert_close(lse.elem, normal_lse, rtol=1e-5, atol=1e-5) @supported_platform @skip_on_cuda def test_cpu_error_message_return_lse(self, device): make_tensor = functools.partial( torch.randn, (2, 2, 128, 16), device="cpu", dtype=torch.float32, requires_grad=False, ) query, key, value = make_tensor(), make_tensor(), make_tensor() attention = torch.compile(flex_attention) with self.assertRaisesRegex( torch._inductor.exc.InductorError, r"NotImplementedError: torch.compile on CPU only supports inference and `return_lse` is not supported yet.", ): attention(query, key, value, return_lse=True) @unittest.skipIf(not TEST_MULTIGPU, "detected only one GPU") def test_device_cuda_1(self, device): class TestModule(torch.nn.Module): def forward(self, q, k, v, block_mask): return flex_attention(q, k, v, block_mask=block_mask) q = torch.randn(1, 1, 256, 32, device="cuda:1", dtype=torch.bfloat16) k = torch.randn(1, 1, 256, 32, device="cuda:1", dtype=torch.bfloat16) v = torch.randn(1, 1, 256, 32, device="cuda:1", dtype=torch.bfloat16) mask = create_block_mask( lambda b, h, q_idx, kv_idx: q_idx >= kv_idx, B=None, H=None, Q_LEN=256, KV_LEN=256, device="cuda:1", ) mod = torch.compile(TestModule()) attn_output = mod(q, k, v, mask) self.assertEqual(attn_output.device, torch.device("cuda:1")) @supported_platform @skip_on_cpu def test_custom_score_mod_layout_freeze(self, device): torch.manual_seed(0) class FlexAttentionCPB(nn.Module): def __init__(self, N: int, R: int, H: int = 4, hidden: int = 32): super().__init__() self.mlp = nn.Sequential( nn.Linear(2, hidden), nn.GELU(), nn.Linear(hidden, H, bias=False), ) self.gamma = nn.Parameter(torch.zeros(H)) self.H = H self._init_tables(N, R) self.register_buffer( "r_cutoff", torch.tensor(R, dtype=torch.long), persistent=False ) def _init_tables(self, N: int, R: int) -> None: P = N - R S = int(P**0.5) assert S * S == P rng = torch.arange(-(S - 1), S, dtype=torch.float32) dY, dX = torch.meshgrid(rng, rng, indexing="ij") rel = torch.stack( [dY / max(S - 1, 1), dX / max(S - 1, 1)], dim=-1 ).reshape(-1, 2) rel_table = torch.sign(rel) * 
torch.log1p(rel.abs()) self.register_buffer("rel_table", rel_table, persistent=False) yy, xx = torch.arange(S), torch.arange(S) Y, X = torch.meshgrid(yy, xx, indexing="ij") flat = torch.stack([Y, X], 0).flatten(1) d = flat[:, :, None] - flat[:, None, :] d = d.permute(1, 2, 0).contiguous() d[:, :, 0] += S - 1 d[:, :, 1] += S - 1 d[:, :, 0] *= 2 * S - 1 l_idx = d.sum(-1).to(torch.long) idx = torch.full((N, N), 0, dtype=torch.long) idx[R:, R:] = l_idx self.register_buffer("idx_table", idx, persistent=False) def _score_mod(self, mu: torch.Tensor): bt = self.mlp(self.rel_table) idx = self.idx_table mu_q, mu_k = mu.unbind(2) gam_sig = torch.sigmoid(self.gamma) def score_mod(score, b, h, q, kv): has_bias = (q >= self.r_cutoff) & (kv >= self.r_cutoff) l2 = idx[q, kv] bias = bt[l2, h] w_gate = gam_sig[h] * (mu_q[b, h, q] + mu_k[b, h, kv]) return score + has_bias.to(score.dtype) * w_gate * bias return score_mod def forward(self, q, k, v, mu): return flex_attention(q, k, v, score_mod=self._score_mod(mu)) dtype = torch.bfloat16 if PLATFORM_SUPPORTS_BF16 else torch.float16 device_obj = torch.device(device) module = FlexAttentionCPB(N=18, R=2).to(device_obj) compiled_module = torch.compile(module, backend="inductor", dynamic=False) q = torch.randn(2, 4, 18, 32, device=device_obj, dtype=dtype) k = torch.randn_like(q) v = torch.randn_like(q) mu = torch.randn(2, 4, 2, 18, device=device_obj) with torch.no_grad(): with torch.nn.attention.sdpa_kernel(SDPBackend.FLASH_ATTENTION): eager_out = module(q, k, v, mu) compiled_out = compiled_module(q, k, v, mu) self.assertEqual(compiled_out.shape, eager_out.shape) torch.testing.assert_close( compiled_out.float(), eager_out.float(), atol=2e-2, rtol=2e-2 ) @supported_platform @skip_on_cpu @common_utils.parametrize( "ops_to_save", [ [ torch.ops.aten.mm.default, ], [ flex_attention_hop, ], [torch.ops.aten.mm.default, flex_attention_hop], ], ) def test_selective_ac(self, device, ops_to_save): class FlexAttentionModule(nn.Module): def __init__(self, hidden_size, num_heads): super().__init__() self.hidden_size = hidden_size self.num_heads = num_heads self.head_dim = hidden_size // num_heads # In-projections (query, key, value) self.q_proj = nn.Linear(hidden_size, hidden_size) self.k_proj = nn.Linear(hidden_size, hidden_size) self.v_proj = nn.Linear(hidden_size, hidden_size) # Out-projection self.out_proj = nn.Linear(hidden_size, hidden_size) def forward(self, x): batch_size, seq_len, _ = x.size() # Project queries, keys, and values q = ( self.q_proj(x) .view(batch_size, seq_len, self.num_heads, self.head_dim) .transpose(1, 2) ) k = ( self.k_proj(x) .view(batch_size, seq_len, self.num_heads, self.head_dim) .transpose(1, 2) ) v = ( self.v_proj(x) .view(batch_size, seq_len, self.num_heads, self.head_dim) .transpose(1, 2) ) # Apply flex attention attn_output = flex_attention( q, k, v, ) # Reshape output attn_output = ( attn_output.transpose(1, 2) .contiguous() .view(batch_size, seq_len, self.hidden_size) ) # Out projection output = self.out_proj(attn_output) return output from torch.utils.checkpoint import ( checkpoint, create_selective_checkpoint_contexts, ) context_fn = functools.partial( create_selective_checkpoint_contexts, ops_to_save ) # Define a model that uses FlexAttention with selective activation checkpointing class SacModule(nn.Module): def __init__(self, hidden_size, num_heads, context_fn): super().__init__() self.flex_attn = FlexAttentionModule(hidden_size, num_heads) self.context_fn = context_fn def forward(self, x): def flex_attn_fn(x): return self.flex_attn(x) 
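                # Selective activation checkpointing: `context_fn` (built from
                # create_selective_checkpoint_contexts with `ops_to_save`) tells
                # checkpoint() which ops' outputs to keep for backward; everything
                # else inside flex_attn_fn is recomputed during the backward pass.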
output = checkpoint( flex_attn_fn, x, use_reentrant=False, context_fn=self.context_fn, ) return output flex_module = SacModule(hidden_size=512, num_heads=8, context_fn=context_fn).to( device, dtype=torch.bfloat16 ) x = torch.ones(8, 1024, 512, device=device, dtype=torch.bfloat16) # Run without compilation output_module = flex_module(x) compiled_module = torch.compile(flex_module) output_compiled = compiled_module(x) torch.testing.assert_close(output_module, output_compiled, rtol=1e-2, atol=1e-2) # Calculate gradients and compare them x.requires_grad_(True) output_module = flex_module(x) output_compiled = compiled_module(x) grad_output = torch.ones_like(output_module) grad_module = torch.autograd.grad( outputs=output_module, inputs=x, grad_outputs=grad_output, retain_graph=True )[0] grad_compiled = torch.autograd.grad( outputs=output_compiled, inputs=x, grad_outputs=grad_output )[0] torch.testing.assert_close(grad_module, grad_compiled, rtol=1e-2, atol=1e-2) @supported_platform @skip_on_cpu def test_selective_ac_with_max_autotune_short_query(self, device): from functools import partial from torch.utils.checkpoint import ( checkpoint, CheckpointPolicy, create_selective_checkpoint_contexts, ) compute_intensive_ops = [ torch.ops.aten.mm, torch.ops.aten.bmm, ] def policy_fn(ctx, op, *args, **kwargs): if op in compute_intensive_ops: return CheckpointPolicy.MUST_SAVE else: return CheckpointPolicy.PREFER_RECOMPUTE def causal_mask(b, h, q_idx, kv_idx): return q_idx >= kv_idx class DummyAttentionModule(nn.Module): def __init__(self, dim=64, num_heads=4): super().__init__() self.dim = dim self.num_heads = num_heads self.head_dim = dim // num_heads self.q_proj = nn.Linear(dim, dim) self.k_proj = nn.Linear(dim, dim) self.v_proj = nn.Linear(dim, dim) self.out_proj = nn.Linear(dim, dim) self._activation_checkpoint_context_fn = partial( create_selective_checkpoint_contexts, policy_fn ) self._flex_attention = torch.compile( partial( checkpoint, flex_attention, use_reentrant=False, context_fn=self._activation_checkpoint_context_fn, ), mode="max-autotune-no-cudagraphs", ) def forward(self, x, block_mask): batch_size, seq_len, _ = x.shape q = self.q_proj(x) k = self.k_proj(x) v = self.v_proj(x) q = q.view( batch_size, seq_len, self.num_heads, self.head_dim ).transpose(1, 2) k = k.view( batch_size, seq_len, self.num_heads, self.head_dim ).transpose(1, 2) v = v.view( batch_size, seq_len, self.num_heads, self.head_dim ).transpose(1, 2) attn_out = self._flex_attention(q, k, v, block_mask=block_mask) attn_out = ( attn_out.transpose(1, 2) .contiguous() .view(batch_size, seq_len, self.dim) ) out = self.out_proj(attn_out) return out batch_size = 2 seq_len = 64 dim = 64 num_heads = 4 model = DummyAttentionModule(dim=dim, num_heads=num_heads).to(device) x = torch.randn(batch_size, seq_len, dim, device=device, requires_grad=True) block_mask = create_block_mask( causal_mask, B=batch_size, H=num_heads, Q_LEN=seq_len, KV_LEN=seq_len, device=device, ) out = model(x, block_mask) loss = out.sum() loss.backward() self.assertIsNotNone(x.grad) @supported_platform @skip_on_cpu def test_validate_small_embedding_size_error_message(self, device): # eager support for small embedding size q, k, v = [torch.randn(2, 2, 128, 8, device=device) for _ in range(3)] flex_attention(q, k, v) # compiled cpu support for small embedding size q, k, v = [torch.randn(2, 2, 128, 8, device=device) for _ in range(3)] flex_attention(q, k, v) # compiled gpu kernel does not support small embedding size q, k, v = [torch.randn(2, 2, 128, 8, device=device) 
for _ in range(3)] compiled_fa = torch.compile(flex_attention) with self.assertRaisesRegex( torch._inductor.exc.InductorError, "NYI: embedding dimension of the query, key, and value must be " "at least 16 but got E=8 and Ev=8", ): compiled_fa(q, k, v) # compiled gpu kernel supports large embedding size q, k, v = [torch.randn(2, 2, 128, 16, device=device) for _ in range(3)] compiled_fa = torch.compile(flex_attention) @unittest.skipIf( not has_triton() or not HAS_WARP_SPEC, reason="FBCODE Triton is required for this test", ) def test_triton_template_warp_specialization(self, device): def make_tensor(): return torch.rand(4, 16, 4096, 64, device=device, dtype=torch.bfloat16) q, k, v = make_tensor(), make_tensor(), make_tensor() flex_compiled = torch.compile(flex_attention, fullgraph=True) positional_args = (q, k, v) keyword_args = { "kernel_options": { "num_warps": 4, "num_consumer_groups": 2, "num_buffers_warp_spec": 3, } } # Check if kernel code contains warp specialization parameters _, kernel_code = run_and_get_code( flex_compiled, *positional_args, **keyword_args, ) assert kernel_code is not None, "Failed to retrieve compiled kernel code" assert "num_consumer_groups" in kernel_code[0], ( "num_consumer_groups missing in kernel definition" ) assert "num_buffers_warp_spec" in kernel_code[0], ( "num_buffers_warp_spec missing in kernel definition" ) # Validate correctness C1 = flex_compiled(q, k, v) C2 = flex_attention(q, k, v) assert torch.allclose(C1, C2, atol=1e-2, rtol=1e-2), ( "Warp specialized kernel result differs from reference" ) @supported_platform @skip_on_cpu @skipCUDAIf(not has_triton_tma_device(), "Requires TMA enabled CUDA device") def test_tma_with_customer_kernel_options(self, device): make_tensor = functools.partial( torch.ones, (1, 1, 256, 128), device=device, dtype=torch.bfloat16, ) query, key, value = make_tensor(), make_tensor(), make_tensor() kernel_options_1 = { "BLOCK_M": 128, "BLOCK_N": 128, "USE_TMA": False, } kernel_options_2 = {"BLOCK_M": 128, "BLOCK_N": 128, "USE_TMA": True} flex_compile = torch.compile(flex_attention, fullgraph=True, dynamic=True) out_compiled = flex_compile(query, key, value, kernel_options=kernel_options_1) out_tma_compiled = flex_compile( query, key, value, kernel_options=kernel_options_2 ) # vanilla compiled vs TMA compiled torch.testing.assert_close(out_tma_compiled, out_compiled, atol=2e-1, rtol=2e-1) @supported_platform @skip_on_cpu def test_large_batch_heads_grid_dimension(self, device): B, H, S, D = 22720, 3, 64, 32 make_tensor = functools.partial( torch.randn, (B, H, S, D), device=device, dtype=torch.float16, requires_grad=True, ) query, key, value = make_tensor(), make_tensor(), make_tensor() flex_compile = torch.compile(flex_attention, fullgraph=True, dynamic=True) out_compiled = flex_compile(query, key, value) self.assertEqual(out_compiled.shape, (B, H, S, D)) grad_output = torch.randn_like(out_compiled) out_compiled.backward(grad_output) self.assertIsNotNone(query.grad) self.assertIsNotNone(key.grad) self.assertIsNotNone(value.grad) self.assertEqual(query.grad.shape, query.shape) self.assertEqual(key.grad.shape, key.shape) self.assertEqual(value.grad.shape, value.shape) @supported_platform def test_debug_flag_disables_internal_compilation(self, device): """Test that _FLEX_ATTENTION_DISABLE_COMPILE_DEBUG flag bypasses internal compilation.""" import torch.nn.attention.flex_attention as fa original_flag = fa._FLEX_ATTENTION_DISABLE_COMPILE_DEBUG original_warnings_shown = fa._WARNINGS_SHOWN.copy() try: B, H, S, D = 1, 1, 128, 64 query 
= torch.randn(B, H, S, D, device=device, dtype=torch.float32) key = torch.randn(B, H, S, D, device=device, dtype=torch.float32) value = torch.randn(B, H, S, D, device=device, dtype=torch.float32) def simple_score_mod(score, b, h, q_idx, kv_idx): return score # Test with debug flag False - should warn fa._FLEX_ATTENTION_DISABLE_COMPILE_DEBUG = False fa._WARNINGS_SHOWN.clear() with self.assertWarns(UserWarning) as cm: out_compiled = fa.flex_attention( query, key, value, score_mod=simple_score_mod ) self.assertIn( "flex_attention called without torch.compile", str(cm.warning) ) # Test with debug flag True - should NOT warn fa._FLEX_ATTENTION_DISABLE_COMPILE_DEBUG = True # Should not error with warnings.catch_warnings(): warnings.simplefilter("error") out_debug = fa.flex_attention( query, key, value, score_mod=simple_score_mod ) torch.testing.assert_close(out_compiled, out_debug, rtol=1e-4, atol=1e-4) finally: fa._FLEX_ATTENTION_DISABLE_COMPILE_DEBUG = original_flag fa._WARNINGS_SHOWN = original_warnings_shown class TestBlockMask(InductorTestCase): def setUp(self): super().setUp() @supported_platform def test_block_mask_attributes(self, device): offset = torch.zeros(8, device=device) def causal_mask(b, h, q, kv): return (q + (offset[b] * 128)) >= kv block_mask = create_block_mask(causal_mask, 4, 2, 2048, 2048, device=device) self.assertEqual(block_mask.shape, (4, 2, 2048, 2048)) self.assertEqual(block_mask[0].shape, (1, 2, 2048, 2048)) self.assertEqual(block_mask[0, 0].shape, (1, 1, 2048, 2048)) self.assertEqual(block_mask.numel(), 4 * 2 * 2048 * 2048) self.assertEqual(block_mask.sparsity(), 46.875) self.assertEqual(block_mask[0].sparsity(), 46.875) self.assertEqual(block_mask[1, 0].sparsity(), 46.875) self.assertEqual(block_mask.sparsity(), block_mask[1].sparsity()) offset = torch.arange(8, device=device) block_mask = create_block_mask(causal_mask, 8, 1, 2048, 2048, device=device) self.assertEqual(block_mask.sparsity(), 29.1015625) self.assertTrue(block_mask.sparsity() < block_mask[0].sparsity()) self.assertTrue(block_mask[0].sparsity() > block_mask[1].sparsity()) @supported_platform @common_utils.parametrize("BLOCK_SIZE", [32, 64, 128, 256, (32, 64), (64, 32)]) def test_block_size_changes(self, device, BLOCK_SIZE: Union[int, tuple[int, int]]): B, H, Q_LEN, KV_LEN = 4, 2, 2048, 2048 if isinstance(BLOCK_SIZE, int): Q_BLOCK_SIZE = BLOCK_SIZE KV_BLOCK_SIZE = BLOCK_SIZE else: Q_BLOCK_SIZE, KV_BLOCK_SIZE = BLOCK_SIZE block_mask = create_block_mask( noop_mask, B, H, Q_LEN, KV_LEN, BLOCK_SIZE=BLOCK_SIZE, device=device ) self.assertEqual(block_mask.BLOCK_SIZE, (Q_BLOCK_SIZE, KV_BLOCK_SIZE)) self.assertEqual(block_mask.shape, (B, H, Q_LEN, KV_LEN)) @supported_platform def test_getitem(self, device): offset = torch.zeros(8, device=device) def causal_mask(b, h, q, kv): return (q + (offset[b] * 128)) >= kv block_mask = create_block_mask(causal_mask, 4, 2, 512, 512, device=device) assert block_mask.kv_num_blocks.shape == (4, 2, 4) assert block_mask.kv_indices.shape == (4, 2, 4, 4) # Index on batch dimension new_block_mask = block_mask[0] assert new_block_mask.kv_num_blocks.shape == (1, 2, 4) assert new_block_mask.kv_indices.shape == (1, 2, 4, 4) # Index on batch and head dimension new_block_mask = block_mask[0, 1] assert new_block_mask.kv_num_blocks.shape == ( 1, 1, 4, ) assert new_block_mask.kv_indices.shape == (1, 1, 4, 4) # Index on batch and head dimension with -1 semantics new_block_mask = block_mask[-1, -2] assert new_block_mask.kv_num_blocks.shape == ( 1, 1, 4, ) assert 
new_block_mask.kv_indices.shape == (1, 1, 4, 4) # slicing on batch and head dimension new_block_mask = block_mask[0:2, 1:2] assert new_block_mask.kv_num_blocks.shape == (2, 1, 4) assert new_block_mask.kv_indices.shape == (2, 1, 4, 4) # slicing on batch, head, and query dimension new_block_mask = block_mask[0:2, 1:2, torch.tensor([1], dtype=torch.int32)] assert new_block_mask.kv_num_blocks.shape == (2, 1, 1) assert new_block_mask.kv_indices.shape == (2, 1, 1, 4) # slicing on batch, head, and query dimension q_index = torch.tensor([0], dtype=torch.int32) new_block_mask = block_mask[:, :, q_index] self.assertEqual(new_block_mask.kv_num_blocks.ndim, 3) self.assertEqual(new_block_mask.kv_indices.ndim, 4) torch.testing.assert_close( new_block_mask.kv_num_blocks, block_mask.kv_num_blocks[:, :, q_index], ) torch.testing.assert_close( new_block_mask.kv_indices, block_mask.kv_indices[:, :, q_index, :] ) if block_mask.full_kv_num_blocks is not None: assert new_block_mask.full_kv_num_blocks is not None assert new_block_mask.full_kv_indices is not None torch.testing.assert_close( new_block_mask.full_kv_num_blocks, block_mask.full_kv_num_blocks[:, :, q_index], ) torch.testing.assert_close( new_block_mask.full_kv_indices, block_mask.full_kv_indices[:, :, q_index, :], ) @supported_platform def test_block_mask_device_change(self, device): device = torch.device(device) offset = torch.zeros(8, device=device) def causal_mask(b, h, q, kv): return (q + (offset[b] * 128)) >= kv block_mask = create_block_mask(causal_mask, 1, 1, 512, 512, device=device) assert block_mask.kv_indices.device.type == device.type assert block_mask.kv_num_blocks.device.type == device.type assert block_mask.q_indices.device.type == device.type assert block_mask.q_num_blocks.device.type == device.type block_mask = block_mask.to("cpu") assert block_mask.kv_indices.is_cpu assert block_mask.kv_num_blocks.is_cpu assert block_mask.q_indices.is_cpu assert block_mask.q_num_blocks.is_cpu block_mask = block_mask.to(device) assert block_mask.kv_indices.device.type == device.type assert block_mask.kv_num_blocks.device.type == device.type assert block_mask.q_indices.device.type == device.type assert block_mask.q_num_blocks.device.type == device.type @supported_platform def test_compiling_create_block_mask(self, device): seq = torch.arange(512, device=device) // 127 def mask_mod(b, h, q, kv): return (q >= kv) & (seq[q] == seq[kv]) block_mask = torch.compile(create_block_mask, fullgraph=True)( mask_mod, 1, 1, 512, 512, device=device ) self.assertIsInstance(block_mask, BlockMask) self.assertEqual(block_mask.kv_num_blocks.shape, torch.Size((1, 1, 4))) self.assertEqual(block_mask.kv_indices.shape, torch.Size((1, 1, 4, 4))) @supported_platform def test_compiling_create_block_mask_no_recompile(self, device): def mask_mod(b, h, q, kv): return q >= kv torch._dynamo.reset() block_mask = torch.compile(create_block_mask)( mask_mod, 2, 4, 1024, 1024, device=device ) self.assertIsInstance(block_mask, BlockMask) self.assertEqual(block_mask.kv_num_blocks.shape, torch.Size((2, 4, 8))) self.assertEqual(block_mask.kv_indices.shape, torch.Size((2, 4, 8, 8))) self.assertEqual(torch._dynamo.utils.counters["aot_autograd"]["ok"], 1) # automatic dynamic shapes triggered and recompilation. 
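        # The first call above specialized on static shapes, so changing B, H, and the
        # sequence lengths here triggers Dynamo's automatic dynamic shapes and exactly one
        # extra compile (counter 1 -> 2); the third call below with yet another set of
        # sizes should then reuse the dynamic graph with no further recompilation.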
block_mask = torch.compile(create_block_mask)( mask_mod, 4, 8, 2048, 2048, device=device ) self.assertIsInstance(block_mask, BlockMask) self.assertEqual(block_mask.kv_num_blocks.shape, torch.Size((4, 8, 16))) self.assertEqual(block_mask.kv_indices.shape, torch.Size((4, 8, 16, 16))) self.assertEqual(torch._dynamo.utils.counters["aot_autograd"]["ok"], 2) # no recompilation. block_mask = torch.compile(create_block_mask)( mask_mod, 6, 16, 3072, 3072, device=device ) self.assertIsInstance(block_mask, BlockMask) self.assertEqual(block_mask.kv_num_blocks.shape, torch.Size((6, 16, 24))) self.assertEqual(block_mask.kv_indices.shape, torch.Size((6, 16, 24, 24))) self.assertEqual(torch._dynamo.utils.counters["aot_autograd"]["ok"], 2) @supported_platform def test_block_mask_viz(self, device): def causal_mask(b, h, q, kv): return q >= kv block_mask = create_block_mask(causal_mask, 1, 1, 2048, 2048, device=device) def replace_non_printable(s): def replace(c): if c not in string.printable: return "@" elif c == " ": return "s" return c return "".join(replace(c) for c in s) self.assertExpectedInline( replace_non_printable(str(block_mask)), """\ BlockMask(shape=(1,s1,s2048,s2048),ssparsity=46.88%,s (0,s0) @@ssssssssssssssssssssssssssssss @@@@ssssssssssssssssssssssssssss @@@@@@ssssssssssssssssssssssssss @@@@@@@@ssssssssssssssssssssssss @@@@@@@@@@ssssssssssssssssssssss @@@@@@@@@@@@ssssssssssssssssssss @@@@@@@@@@@@@@ssssssssssssssssss @@@@@@@@@@@@@@@@ssssssssssssssss @@@@@@@@@@@@@@@@@@ssssssssssssss @@@@@@@@@@@@@@@@@@@@ssssssssssss @@@@@@@@@@@@@@@@@@@@@@ssssssssss @@@@@@@@@@@@@@@@@@@@@@@@ssssssss @@@@@@@@@@@@@@@@@@@@@@@@@@ssssss @@@@@@@@@@@@@@@@@@@@@@@@@@@@ssss @@@@@@@@@@@@@@@@@@@@@@@@@@@@@@ss @@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@ )""", ) offset = torch.arange(8, device=device) def causal_offset_mask(b, h, q, kv): return (q + offset[b] * 128) >= kv block_mask = create_block_mask( causal_offset_mask, 8, 1, 2048, 2048, device=device ) str_block_mask = str(block_mask) self.assertTrue("sparsity=29.10" in str_block_mask) def generate_test_inputs(self, full_seq_len: bool, device): if full_seq_len: kv_num_blocks = torch.tensor([1], dtype=torch.int32, device=device).view( 1, 1, 1 ) kv_indices = torch.tensor([1, -1], dtype=torch.int32, device=device).view( 1, 1, 1, 2 ) full_kv_num_blocks = torch.tensor( [1], dtype=torch.int32, device=device ).view(1, 1, 1) full_kv_indices = torch.tensor( [0, -1], dtype=torch.int32, device=device ).view(1, 1, 1, 2) else: kv_num_blocks = torch.tensor([2], dtype=torch.int32, device=device).view( 1, 1, 1 ) kv_indices = torch.tensor([0, 1], dtype=torch.int32, device=device).view( 1, 1, 1, 2 ) full_kv_indices = None full_kv_num_blocks = None return kv_num_blocks, kv_indices, full_kv_num_blocks, full_kv_indices @supported_platform @common_utils.parametrize("full_indices", [False, True]) def test_from_kv_blocks(self, device, full_indices: bool): ( kv_num_blocks, kv_indices, full_kv_num_blocks, full_kv_indices, ) = self.generate_test_inputs(full_indices, device=device) block_mask = BlockMask.from_kv_blocks( kv_num_blocks, kv_indices, full_kv_num_blocks, full_kv_indices ) self.assertIsInstance(block_mask, BlockMask) torch.testing.assert_close(block_mask.kv_num_blocks, kv_num_blocks) torch.testing.assert_close(block_mask.kv_indices, kv_indices) if full_indices: torch.testing.assert_close( block_mask.full_kv_num_blocks, full_kv_num_blocks ) torch.testing.assert_close(block_mask.full_kv_indices, full_kv_indices) torch.testing.assert_close( block_mask.q_num_blocks, torch.tensor([0, 1], dtype=torch.int32, 
device=device).view(1, 1, 2), ) torch.testing.assert_close( block_mask.q_indices, torch.tensor([0, 0], dtype=torch.int32, device=device).view(1, 1, 2, 1), ) torch.testing.assert_close( block_mask.full_q_num_blocks, torch.tensor([1, 0], dtype=torch.int32, device=device).view(1, 1, 2), ) torch.testing.assert_close( block_mask.full_q_indices, torch.tensor([0, 0], dtype=torch.int32, device=device).view(1, 1, 2, 1), ) else: torch.testing.assert_close( block_mask.q_num_blocks, torch.tensor([1, 1], dtype=torch.int32, device=device).view(1, 1, 2), ) torch.testing.assert_close( block_mask.q_indices, torch.tensor([0, 0], dtype=torch.int32, device=device).view(1, 1, 2, 1), ) self.assertIsNone(block_mask.full_kv_num_blocks) self.assertIsNone(block_mask.full_kv_indices) self.assertIsNone(block_mask.full_q_num_blocks) self.assertIsNone(block_mask.full_q_indices) @supported_platform def test_block_size(self, device): kv_num_blocks, kv_indices, _, _ = self.generate_test_inputs(False, device) block_mask = BlockMask.from_kv_blocks(kv_num_blocks, kv_indices) self.assertEqual( block_mask.BLOCK_SIZE, (_DEFAULT_SPARSE_BLOCK_SIZE, _DEFAULT_SPARSE_BLOCK_SIZE), ) custom_block_size = (64, 64) block_mask_custom = BlockMask.from_kv_blocks( kv_num_blocks, kv_indices, BLOCK_SIZE=custom_block_size ) self.assertEqual(block_mask_custom.BLOCK_SIZE, custom_block_size) @supported_platform def test_upcast_appropriately(self, device): q = torch.randn((1, 1, 128, 16), dtype=torch.float16, device=device) k = torch.randn((1, 1, 128, 16), dtype=torch.float16, device=device) v = torch.randn((1, 1, 128, 16), dtype=torch.float16, device=device) mass = torch.ones((1), dtype=torch.float16, device=device) def score_mod(score, b, h, q_idx, kv_idx): return score + torch.log(mass[0]) torch.compile(flex_attention)(q, k, v, score_mod=score_mod) @supported_platform def test_init_mismatched_full_kv(self, device): kv_num_blocks, kv_indices, full_kv_num_blocks, _ = self.generate_test_inputs( True, device ) with self.assertRaises(AssertionError): BlockMask( kv_num_blocks=kv_num_blocks, kv_indices=kv_indices, full_kv_num_blocks=full_kv_num_blocks, full_kv_indices=None, # Mismatched, should raise error q_num_blocks=kv_num_blocks, q_indices=kv_indices, full_q_num_blocks=None, full_q_indices=None, BLOCK_SIZE=(64, 64), mask_mod=noop_mask, seq_lengths=(1, 1), ) @supported_platform def test_init_mismatched_full_q(self, device): kv_num_blocks, kv_indices, _, _ = self.generate_test_inputs(False, device) with self.assertRaises(AssertionError): BlockMask( kv_num_blocks=kv_num_blocks, kv_indices=kv_indices, full_kv_num_blocks=None, full_kv_indices=None, q_num_blocks=kv_num_blocks, q_indices=kv_indices, full_q_num_blocks=kv_num_blocks, full_q_indices=None, # Mismatched, should raise error BLOCK_SIZE=(64, 64), mask_mod=noop_mask, seq_lengths=(1, 1), ) @supported_platform def test_doc_mask_clamped_repro(self, device): def _offsets_to_doc_ids_tensor(offsets): device = offsets.device counts = offsets[1:] - offsets[:-1] return torch.repeat_interleave( torch.arange(len(counts), device=device, dtype=torch.int32), counts ) def length_to_offsets( lengths: list[int], device: Union[str, torch.device] ) -> Tensor: offsets = [0] offsets.extend(lengths) offsets = torch.tensor(offsets, device=device, dtype=torch.int32) offsets = torch.cumsum(offsets, dim=-1) return offsets def generate_doc_mask_mod(offsets: Tensor) -> _mask_mod_signature: document_id = _offsets_to_doc_ids_tensor(offsets) def doc_mask_mod(b, h, q_idx, kv_idx): same_doc = document_id[q_idx] == 
document_id[kv_idx] return same_doc return doc_mask_mod random.seed(0) def generate_random_lengths(total_length, num_documents): lengths = [1] * num_documents remaining_length = total_length - num_documents for _ in range(remaining_length): index = random.randint(0, num_documents - 1) lengths[index] += 1 return lengths max_seq_len, doc_count = 128, 4 SEQ_LEN = max_seq_len lengths = generate_random_lengths(max_seq_len, doc_count) offsets = length_to_offsets(lengths, device) document_causal_mask = generate_doc_mask_mod(offsets) block_mask_compiled = torch.compile(create_block_mask)( document_causal_mask, 1, 1, SEQ_LEN, SEQ_LEN, device=device, ) block_mask = torch.compile(create_block_mask)( document_causal_mask, 1, 1, SEQ_LEN, SEQ_LEN, device=device, ) self.assertEqual(block_mask_compiled.kv_indices, block_mask.kv_indices) self.assertEqual( block_mask_compiled.full_kv_indices, block_mask.full_kv_indices ) for i in range(5): lengths = generate_random_lengths(1024 + i, 5) offsets = length_to_offsets(lengths, device) doc_ids = _offsets_to_doc_ids_tensor(offsets) def doc_mask_mod(b, h, q_idx, kv_idx): return ( doc_ids[q_idx.clamp(0, doc_ids.shape[0] - 1)] == doc_ids[kv_idx.clamp(0, doc_ids.shape[0] - 1)] ) q, k, v = ( torch.randn(1, 12, 1024 + i, 64, device=device) for _ in range(3) ) block_mask = create_block_mask( doc_mask_mod, None, None, 1024 + i, 1024 + i, device=device ) torch.compile(flex_attention)(q, k, v, block_mask=block_mask) @supported_platform def test_eager_tracing_correctness(self, device): qk_dims = 64 v_dims = 128 q_heads = 4 kv_heads = 2 seq_len = 256 batch_size = 1 make_tensor = functools.partial(torch.randn, device=device, dtype=torch.float16) q = make_tensor(*(batch_size, q_heads, seq_len, qk_dims)) k = make_tensor(*(batch_size, kv_heads, seq_len, qk_dims)) v = make_tensor(*(batch_size, kv_heads, seq_len, v_dims)) def flex_attention_fn(): out = flex_attention(q, k, v, enable_gqa=True) return out.view(batch_size, q_heads, seq_len, 2, 64) # Run with compilation compiled_fn = torch.compile(flex_attention_fn, fullgraph=True) result = compiled_fn() # Assert expected output shape expected_shape = (batch_size, q_heads, seq_len, 2, 64) self.assertEqual( result.shape, expected_shape, f"Expected output shape {expected_shape}, but got {result.shape}", ) @supported_platform @skip_on_xpu def test_create_is_cuda_graphable(self, device): def mask_mod(b, h, q, kv): return q >= kv g = torch.cuda.CUDAGraph() with torch.cuda.graph(g): create_block_mask(mask_mod, None, None, 256, 256) g.replay() @common_utils.parametrize("compile", [False, True]) @supported_platform def test_block_mask_vs_sequence_lengths(self, device, compile): if compile: flex_attention_call = torch.compile(flex_attention) else: flex_attention_call = flex_attention def mask_mod(b, h, q_idx, kv_idx): return q_idx >= kv_idx def create_inputs(S): q, k, v = ( torch.randn( 1, 8, S, 64, dtype=torch.float16, requires_grad=True, device=device ) for _ in range(3) ) return q, k, v block_mask = create_block_mask(mask_mod, None, None, 1024, 1024, device=device) flex_attention_call(*create_inputs(1024), block_mask=block_mask) with self.assertRaisesRegex(ValueError, "block_mask was created for"): flex_attention_call(*create_inputs(2048), block_mask=block_mask) block_mask = create_block_mask(mask_mod, None, None, 1023, 1023, device=device) with self.assertRaisesRegex(ValueError, "block_mask was created for"): flex_attention_call(*create_inputs(1024), block_mask=block_mask) @supported_platform @common_utils.parametrize("full_indices", 
[False, True]) def test_from_kv_blocks_without_q_computation(self, device, full_indices: bool): ( kv_num_blocks, kv_indices, full_kv_num_blocks, full_kv_indices, ) = self.generate_test_inputs(full_indices, device=device) block_mask = BlockMask.from_kv_blocks( kv_num_blocks, kv_indices, full_kv_num_blocks, full_kv_indices, compute_q_blocks=False, ) self.assertIsInstance(block_mask, BlockMask) self.assertEqual(block_mask.kv_num_blocks, kv_num_blocks) self.assertEqual(block_mask.kv_indices, kv_indices) self.assertIsNone(block_mask.q_num_blocks) self.assertIsNone(block_mask.q_indices) self.assertIsNone(block_mask.full_q_num_blocks) self.assertIsNone(block_mask.full_q_indices) if full_indices: self.assertEqual(block_mask.full_kv_num_blocks, full_kv_num_blocks) self.assertEqual(block_mask.full_kv_indices, full_kv_indices) else: self.assertIsNone(block_mask.full_kv_num_blocks) self.assertIsNone(block_mask.full_kv_indices) @supported_platform @skip_on_cpu def test_backward_error_with_none_q_indices(self, device): N_BLOCKS = 4 B, H, S, D = 1, 1, 128, 64 S_KV = N_BLOCKS * S kv_num_blocks = torch.tensor([[[N_BLOCKS]]], dtype=torch.int32, device=device) kv_indices = torch.tensor([[[[0, 1, 2, 3]]]], dtype=torch.int32, device=device) block_mask = BlockMask.from_kv_blocks( kv_num_blocks, kv_indices, compute_q_blocks=False ) q = torch.randn( B, H, S, D, dtype=torch.float16, device=device, requires_grad=True ) k = torch.randn( B, H, S_KV, D, dtype=torch.float16, device=device, requires_grad=True ) v = torch.randn( B, H, S_KV, D, dtype=torch.float16, device=device, requires_grad=True ) flex_compile = torch.compile(flex_attention, fullgraph=True) with torch.no_grad(): out_no_grad = flex_compile(q, k, v, block_mask=block_mask) self.assertEqual(out_no_grad.shape, (B, H, S, D)) # Forward pass with grad enabled should error immediately with self.assertRaisesRegex( RuntimeError, "BlockMask q_indices is None. Backward pass requires q_indices to be computed. 
" "Please create the BlockMask with compute_q_blocks=True", ): flex_compile(q, k, v, block_mask=block_mask) @supported_platform @skip_on_cpu def test_flex_attention_poisoned_rel_logits(self, device): B = 1 H = 1 S = 1025 D = 64 q, k, v = [ torch.randn(B, H, S, D, requires_grad=True, device=device) for _ in range(3) ] rel_logits = torch.randn(2 * B, H, S, S, device=device) rel_logits[B:] = float("nan") def score_mod(score, b, h, q, kv): return score + rel_logits[b, h, q, kv] def causal( b: torch.Tensor, h: torch.Tensor, q: torch.Tensor, kv: torch.Tensor ) -> torch.Tensor: return q >= kv block_mask = create_block_mask(causal, B, H, S, S, device=device) out = torch.compile(flex_attention)( q, k, v, score_mod=score_mod, block_mask=block_mask ) out.sum().backward() assert out.isfinite().all().item() assert q.grad.isfinite().all().item() assert k.grad.isfinite().all().item() assert v.grad.isfinite().all().item() @supported_platform @skip_on_cpu def test_flex_attention_poison_mod_fwd(self, device): """Div by score should cause our edge case handiling to NaN""" B = 1 H = 1 S = 257 D = 16 q, k, v = [ torch.randn(B, H, S, D, requires_grad=True, device=device) for _ in range(3) ] def score_mod(score, b, h, q, kv): return 1 / score def causal( b: torch.Tensor, h: torch.Tensor, q: torch.Tensor, kv: torch.Tensor ) -> torch.Tensor: return q >= kv block_mask = create_block_mask(causal, B, H, S, S, device=device) out = torch.compile(flex_attention, backend="inductor")( q, k, v, score_mod=score_mod, block_mask=block_mask ) out.sum().backward() assert out.isfinite().all().item() assert q.grad.isfinite().all().item() # assert k.grad.isfinite().all().item() assert v.grad.isfinite().all().item() @supported_platform @skip_on_cpu def test_flex_attention_poison_mod_bwd(self, device): """log score should cause our edge case handiling for NaN in grad score""" B = 1 H = 1 S = 257 D = 16 q, k, v = [ torch.randn(B, H, S, D, requires_grad=True, device=device) for _ in range(3) ] def score_mod(score, b, h, q, kv): return torch.where(score > 0, torch.log(score), score) def causal( b: torch.Tensor, h: torch.Tensor, q: torch.Tensor, kv: torch.Tensor ) -> torch.Tensor: return q >= kv block_mask = create_block_mask(causal, B, H, S, S, device=device) out = torch.compile(flex_attention, backend="inductor")( q, k, v, score_mod=score_mod, block_mask=block_mask ) out.sum().backward() assert out.isfinite().all().item() assert q.grad.isfinite().all().item() # assert k.grad.isfinite().all().item() assert v.grad.isfinite().all().item() @supported_platform @skip_on_cpu def test_forward_pass_with_none_q_indices(self, device): N_BLOCKS = 4 B, H, S, D = 1, 1, 128, 64 S_KV = N_BLOCKS * S kv_num_blocks = torch.tensor([[[N_BLOCKS]]], dtype=torch.int32, device=device) kv_indices = torch.tensor([[[[0, 1, 2, 3]]]], dtype=torch.int32, device=device) block_mask = BlockMask.from_kv_blocks( kv_num_blocks, kv_indices, compute_q_blocks=False ) q = torch.randn( B, H, S, D, dtype=torch.float16, device=device, ) k = torch.randn( B, H, S_KV, D, dtype=torch.float16, device=device, ) v = torch.randn( B, H, S_KV, D, dtype=torch.float16, device=device, ) flex_compile = torch.compile(flex_attention, fullgraph=True) out = flex_compile(q, k, v, block_mask=block_mask) self.assertEqual(out.shape, (B, H, S, D)) self.assertIsInstance(out, torch.Tensor) self.assertEqual(out.dtype, torch.float16) @supported_platform def test_block_mask_operations_with_none_q_indices(self, device): kv_num_blocks = torch.tensor([[[4]]], dtype=torch.int32, device=device) kv_indices = 
torch.tensor([[[[0, 1, 2, 3]]]], dtype=torch.int32, device=device) block_mask = BlockMask.from_kv_blocks( kv_num_blocks, kv_indices, compute_q_blocks=False ) self.assertEqual(block_mask.shape, (1, 1, 128, 512)) self.assertEqual(block_mask.BLOCK_SIZE, (128, 128)) sliced_mask = block_mask[0] self.assertEqual(sliced_mask.shape, (1, 1, 128, 512)) self.assertIsNone(sliced_mask.q_indices) self.assertIsNone(sliced_mask.q_num_blocks) # Test device movement if device != "cpu": cpu_mask = block_mask.to("cpu") self.assertEqual(cpu_mask.kv_num_blocks.device.type, "cpu") self.assertIsNone(cpu_mask.q_indices) @supported_platform @skip_on_cpu def test_broadcasted_head_block_mask(self, device): torch.manual_seed(42) def causal_mask(b, h, q_idx, kv_idx): return q_idx >= kv_idx def get_mask_mod_with_offset(mask_mod, offset_tensor): def _mask_mod(b, h, q, kv): return mask_mod(b, h, q + offset_tensor, kv) return _mask_mod B, T, H, D, current_pos = 4, 512, 8, 64, 128 dtype = torch.float32 q = torch.randn(B, H, 1, D, device=device, dtype=dtype) k_cache = torch.randn(B, H, T, D, device=device, dtype=dtype) v_cache = torch.randn(B, H, T, D, device=device, dtype=dtype) # Keep future tokens tiny to avoid numerical issues when using full caches k_cache[:, :, current_pos + 1 :, :] = ( torch.randn_like(k_cache[:, :, current_pos + 1 :, :]) * 1e-10 ) v_cache[:, :, current_pos + 1 :, :] = ( torch.randn_like(v_cache[:, :, current_pos + 1 :, :]) * 1e-10 ) k_cropped = k_cache[:, :, : current_pos + 1, :] v_cropped = v_cache[:, :, : current_pos + 1, :] sdpa_output = torch.nn.functional.scaled_dot_product_attention( q, k_cropped, v_cropped, attn_mask=None ) base_mask = create_block_mask( causal_mask, B=B, H=None, # broadcast across heads Q_LEN=T, KV_LEN=T, device=device, _compile=True, ) q_block_size = base_mask.BLOCK_SIZE[0] block_offset = current_pos // q_block_size mask_slice = base_mask[:, :, block_offset] offset_tensor = torch.tensor(current_pos, device=device) mask_slice.mask_mod = get_mask_mod_with_offset( base_mask.mask_mod, offset_tensor ) mask_slice.seq_lengths = (1, mask_slice.seq_lengths[1]) fa = torch.compile(flex_attention, dynamic=True) flex_output = fa(q, k_cache, v_cache, block_mask=mask_slice) self.assertEqual(flex_output, sdpa_output, atol=1e-3, rtol=1e-3) @large_tensor_test_class("2GB", device=test_device[0]) class TestPagedAttention(InductorTestCase): def setUp(self): super().setUp() skipCPUIf( LONG_COMPILATION_ON_CPU, "skip UT for CPU due to long compilation time found in CI", ) def _check_equal( self, golden_out: torch.Tensor, ref_out: torch.Tensor, compiled_out: torch.Tensor, fudge_factor: float, tensor_name: Optional[str] = None, ): compiled_error = (golden_out - compiled_out).abs().mean() ref_error = (golden_out - ref_out).abs().mean() if torch.isnan(compiled_error).any() or torch.isnan(ref_error).any(): self.assertTrue(False, "Output/Grad with NaN") if compiled_error > ref_error * fudge_factor: name = tensor_name if tensor_name is not None else "" msg = f"{name} Compiled error {compiled_error} is greater than ref error {ref_error} by more than {fudge_factor}X." 
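            # Accuracy policy: the compiled output's mean absolute error against the
            # golden (higher-precision) reference may be at most `fudge_factor` times the
            # eager reference's error; exceeding that is reported as a failure.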
self.assertTrue(False, msg) def allocate_page_cache(self, n_pages: int, page_size: int, device: str): max_batch_size = 3 paged_cache = PagedAttention(n_pages, page_size, max_batch_size, device=device) return paged_cache def cdiv(self, x, y): return (x + y - 1) // y def roundup(self, x, y): return (x + y - 1) // y * y @supported_platform def test_page_allocation(self, device): n_pages, page_size = 12, 4 paged_cache = self.allocate_page_cache(n_pages, page_size, device=device) batch_reserve(paged_cache, torch.tensor([8, 24, 16])) with self.assertRaisesRegex( AssertionError, "requested 2 pages but there are only 0 empty pages" ): paged_cache.reserve( torch.tensor([0], device=device), torch.tensor([16], device=device), ) paged_cache.erase(torch.tensor([1], device=device)) paged_cache.reserve( torch.tensor([0], device=device), torch.tensor([16], device=device), ) @supported_platform def test_allocate(self, device): n_pages, page_size = 12, 4 paged_cache = self.allocate_page_cache(n_pages, page_size, device=device) target_seq_len = torch.tensor([3, 11, 8]) batch_reserve(paged_cache, target_seq_len) expected_allocated_pages = self.cdiv(target_seq_len, page_size).sum() self.assertEqual(paged_cache.capacity, self.roundup(target_seq_len, page_size)) self.assertEqual( len(paged_cache.empty_pages), n_pages - expected_allocated_pages ) # deallocate batch 1 paged_cache.erase(torch.tensor([1], device=device)) target_seq_len = torch.tensor([3, 0, 8]) expected_allocated_pages = self.cdiv(target_seq_len, page_size).sum() self.assertEqual(paged_cache.capacity, self.roundup(target_seq_len, page_size)) self.assertEqual( len(paged_cache.empty_pages), n_pages - expected_allocated_pages ) # re-allocate target_seq_len = torch.tensor([7, 2, 10]) batch_reserve(paged_cache, target_seq_len) expected_allocated_pages = self.cdiv(target_seq_len, page_size).sum() self.assertEqual(paged_cache.capacity, self.roundup(target_seq_len, page_size)) self.assertEqual( len(paged_cache.empty_pages), n_pages - expected_allocated_pages ) # deallocate all batches paged_cache.erase(torch.tensor([0, 1, 2])) self.assertEqual(paged_cache.capacity, torch.tensor([0, 0, 0])) self.assertEqual(len(paged_cache.empty_pages), n_pages) @supported_platform def test_convert_logical_block_mask(self, device): n_pages, page_size, max_batch_size, max_seq_len = 8, 128, 2, 512 paged_cache = PagedAttention(n_pages, page_size, max_batch_size, device=device) batch_reserve(paged_cache, torch.tensor([100, 200], device=device)) batch_reserve(paged_cache, torch.tensor([150, 300], device=device)) batch_reserve(paged_cache, torch.tensor([300, 512], device=device)) batch_reserve(paged_cache, torch.tensor([512, 512], device=device)) expected_page_table = torch.tensor( [[0, 3, 5, 7, -1, -1, -1, -1], [2, 1, 4, 6, -1, -1, -1, -1]], device=device, ) self.assertEqual( paged_cache.capacity, torch.tensor([512, 512], device=device), ) self.assertEqual(paged_cache.page_table, expected_page_table) # Get a block mask def causal_mask(b, h, q, kv): return q >= kv block_mask = create_block_mask( causal_mask, max_batch_size, 1, max_seq_len, max_seq_len, device=device ) kv_len_tensor = torch.full( (max_batch_size,), max_seq_len, device=device, dtype=torch.int64 ) new_block_mask = paged_cache.convert_logical_block_mask( block_mask, kv_len=kv_len_tensor ) zeros = [0, 0, 0, 0] # Check that the new block mask is correct expected_kv_num_blocks = torch.tensor( [[[1, 1, 1, 1]], [[1, 1, 1, 1]]], device=device, dtype=torch.int32 ) expected_kv_indices = torch.tensor( [ [ [ [0, 3, 5, 7, 
*zeros], [3, 0, 5, 7, *zeros], [5, 0, 3, 7, *zeros], [7, 0, 3, 5, *zeros], ] ], [ [ [2, 1, 4, 6, *zeros], [1, 2, 4, 6, *zeros], [4, 2, 1, 6, *zeros], [6, 2, 1, 4, *zeros], ] ], ], device=device, dtype=torch.int32, ) expected_full_kv_num_blocks = torch.tensor( [[[0, 1, 2, 3]], [[0, 1, 2, 3]]], device=device, dtype=torch.int32 ) expected_full_kv_indices = torch.tensor( [ [ [ [0, 3, 5, 7, *zeros], [0, 3, 5, 7, *zeros], [0, 3, 5, 7, *zeros], [0, 3, 5, 7, *zeros], ] ], [ [ [2, 1, 4, 6, *zeros], [2, 1, 4, 6, *zeros], [2, 1, 4, 6, *zeros], [2, 1, 4, 6, *zeros], ] ], ], device=device, dtype=torch.int32, ) self.assertEqual(new_block_mask.kv_num_blocks, expected_kv_num_blocks) self.assertEqual(new_block_mask.kv_indices, expected_kv_indices) self.assertEqual(new_block_mask.full_kv_num_blocks, expected_full_kv_num_blocks) self.assertEqual(new_block_mask.full_kv_indices, expected_full_kv_indices) @supported_platform def test_convert_mask_mod(self, device): n_pages, page_size, max_batch_size = 8, 128, 2 paged_cache = PagedAttention(n_pages, page_size, max_batch_size, device=device) batch_reserve(paged_cache, torch.tensor([100, 200], device=device)) batch_reserve(paged_cache, torch.tensor([150, 300], device=device)) batch_reserve(paged_cache, torch.tensor([300, 512], device=device)) batch_reserve(paged_cache, torch.tensor([512, 512], device=device)) expected_page_table = torch.tensor( [[0, 3, 5, 7, -1, -1, -1, -1], [2, 1, 4, 6, -1, -1, -1, -1]], device=device, ) self.assertEqual( paged_cache.capacity, torch.tensor([512, 512], device=device), ) self.assertEqual(paged_cache.page_table, expected_page_table) expected_physical_to_logical = torch.tensor( [[0, -1, -1, 1, -1, 2, -1, 3], [-1, 1, 0, -1, 2, -1, 3, -1]], device=device, ) self.assertEqual(paged_cache.physical_to_logical, expected_physical_to_logical) # Get a block mask def causal_mask(b, h, q, kv): return q >= kv converted_causal_mask = paged_cache.get_mask_mod(causal_mask) # Equivalent to: causal_mask(0, 0, 256, 128) self.assertEqual(converted_causal_mask(0, 0, 256, 384), True) # Equivalent to: causal_mask(0, 1, 256, 128) self.assertEqual(converted_causal_mask(0, 1, 256, 384), True) # Not found corresponding logical block self.assertEqual(converted_causal_mask(1, 0, 256, 384), False) # Equivalent to: causal_mask(1, 0, 64, 14) self.assertEqual(converted_causal_mask(1, 0, 64, 270), True) @supported_platform def test_update(self, device): dtype = torch.float32 n_pages, page_size, max_batch_size, max_seq_len = 6, 2, 2, 6 paged_cache = PagedAttention(n_pages, page_size, max_batch_size, device=device) n_heads, head_dim = 2, 3 cache_shape = (1, n_heads, n_pages * page_size, head_dim) k_cache = torch.zeros(cache_shape, dtype=dtype, device=device) batch_reserve(paged_cache, torch.tensor([1, 3], device=device)) batch_reserve(paged_cache, torch.tensor([4, 5], device=device)) batch_reserve(paged_cache, torch.tensor([6, 6], device=device)) expected_page_table = torch.tensor( [[0, 3, 5, -1, -1, -1], [2, 1, 4, -1, -1, -1]], device=device, ) self.assertEqual(paged_cache.page_table, expected_page_table) batch_idx = torch.arange(max_batch_size, device=device, dtype=torch.int32) input_pos = ( torch.arange(max_seq_len, device=device, dtype=torch.int32) .unsqueeze(0) .expand(max_batch_size, max_seq_len) ) k = torch.arange( max_batch_size * n_heads * max_seq_len * head_dim, device=device, dtype=dtype, ).view(max_batch_size, n_heads, max_seq_len, head_dim) v = k.detach().clone() v_cache = k_cache.detach().clone() paged_cache.assign(batch_idx, input_pos, k, v, k_cache, 
v_cache) expected_cache = torch.tensor( [ [ # h = 0 [ # page = 0 [0.0, 1.0, 2.0], [3.0, 4.0, 5.0], # page = 1 [42.0, 43.0, 44.0], [45.0, 46.0, 47.0], # page = 2 [36.0, 37.0, 38.0], [39.0, 40.0, 41.0], # page = 3 [6.0, 7.0, 8.0], [9.0, 10.0, 11.0], # page = 4 [48.0, 49.0, 50.0], [51.0, 52.0, 53.0], # page = 5 [12.0, 13.0, 14.0], [15.0, 16.0, 17.0], ], # h = 1 [ # page = 0 [18.0, 19.0, 20.0], [21.0, 22.0, 23.0], # page = 1 [60.0, 61.0, 62.0], [63.0, 64.0, 65.0], # page = 2 [54.0, 55.0, 56.0], [57.0, 58.0, 59.0], # page = 3 [24.0, 25.0, 26.0], [27.0, 28.0, 29.0], # page = 4 [66.0, 67.0, 68.0], [69.0, 70.0, 71.0], # page = 5 [30.0, 31.0, 32.0], [33.0, 34.0, 35.0], ], ] ], device=device, dtype=dtype, ) self.assertEqual(k_cache, expected_cache) @supported_platform @dtypes(*device_configs["cpu"].dtypes) @dtypesIfCUDA(*device_configs["cuda"].dtypes) @dtypesIfXPU(*device_configs["xpu"].dtypes) @common_utils.parametrize("score_mod", test_score_mods) def test_paged_builtin_score_mods( self, device, dtype: torch.dtype, score_mod: Callable ): n_pages, page_size, max_batch_size, max_seq_len = 32, 128, 4, 512 n_heads, head_dim = 4, 16 def causal_mask(b, h, q, kv): return q >= kv block_mask = create_block_mask( causal_mask, max_batch_size, 1, max_seq_len, max_seq_len, device=device ) q = torch.randn( max_batch_size, n_heads, max_seq_len, head_dim, device=device, dtype=dtype, requires_grad=False, ) k = torch.randn( max_batch_size, n_heads, max_seq_len, head_dim, device=device, dtype=dtype, requires_grad=False, ) v = torch.randn( max_batch_size, n_heads, max_seq_len, head_dim, device=device, dtype=dtype, requires_grad=False, ) q_ref, k_ref, v_ref = query_key_value_clones(q, k, v) q_gold, k_gold, v_gold = query_key_value_clones(q, k, v, torch.float64) sdpa_partial = create_attention(score_mod, block_mask, enable_gqa=False) golden_out = sdpa_partial(q_gold, k_gold, v_gold) ref_out = sdpa_partial(q_ref, k_ref, v_ref) MAX_CACHED_SEQ_LEN = n_pages * page_size k_cache = torch.zeros( 1, n_heads, MAX_CACHED_SEQ_LEN, head_dim, device=device, dtype=dtype, ) v_cache = torch.zeros( 1, n_heads, MAX_CACHED_SEQ_LEN, head_dim, device=device, dtype=dtype, ) paged_cache = PagedAttention(n_pages, page_size, max_batch_size, device=device) batch_reserve(paged_cache, torch.tensor([100, 200, 50, 300], device=device)) batch_reserve(paged_cache, torch.tensor([100, 512, 300, 300], device=device)) batch_reserve(paged_cache, torch.tensor([512, 512, 300, 300], device=device)) batch_reserve(paged_cache, torch.tensor([512, 512, 512, 300], device=device)) batch_reserve(paged_cache, torch.tensor([512, 512, 512, 512], device=device)) batch_idx = torch.arange(max_batch_size, device=device, dtype=torch.int32) input_pos = ( torch.arange(max_seq_len, device=device, dtype=torch.int32) .unsqueeze(0) .expand(max_batch_size, max_seq_len) ) paged_cache.assign(batch_idx, input_pos, k, v, k_cache, v_cache) kv_len_tensor = torch.full( (max_batch_size,), max_seq_len, device=device, dtype=torch.int64 ) new_block_mask = paged_cache.convert_logical_block_mask( block_mask, kv_len=kv_len_tensor ) compiled_sdpa = torch.compile( create_attention( paged_cache.get_score_mod(score_mod, kv_len=kv_len_tensor), block_mask, enable_gqa=False, ) ) paged_out = compiled_sdpa(q, k_cache, v_cache, block_mask=new_block_mask) with torch.no_grad(): dtype = ref_out.dtype if dtype == torch.float32: fudge_factor = 10.0 else: fudge_factor = 1.1 # Checkout output self._check_equal(golden_out, ref_out, paged_out, fudge_factor, "Out") @dataclass class Params: batch_size: int 
num_heads: int seq_length: int head_dim: int dtype: torch.dtype config_str: Optional[str] = None def __str__(self): return f"batch:{self.batch_size}_head:{self.num_heads}_seq_len:{self.seq_length}_headdim:{self.head_dim}_dtype:{str(self.dtype).split('.')[-1]}" def get_params(dtypes: list[torch.dtype]) -> list[Params]: params = [] seq_lengths = [37, 256, 277] for seq_len, dtype in product(seq_lengths, dtypes): params.append( Params( batch_size=2, num_heads=4, seq_length=seq_len, head_dim=16, dtype=dtype ) ) return params supports_learnable_bias = unittest.skipUnless( ( (torch.cuda.is_available() and has_triton()) and (torch.cuda.get_device_capability() >= (8, 0) or torch.version.hip) ), "Requires Triton + A100 or Triton + ROCm", ) @supports_learnable_bias @large_tensor_test_class("2GB", device=test_device[0]) class TestLearnableBiases(InductorTestCase): def setUp(self): super().setUp() skipCPUIf( LONG_COMPILATION_ON_CPU, "skip UT for CPU due to long compilation time found in CI", ) self.dtype = torch.float32 self.atol = 3e-2 self.rtol = 3e-2 def _init_tensors(self, params: Params, device: str): make_tensor = functools.partial( torch.randn, (params.batch_size, params.num_heads, params.seq_length, params.head_dim), device=device, dtype=params.dtype, requires_grad=True, ) return (make_tensor(), make_tensor(), make_tensor()) @torch.no_grad() def _gold_check(self, eager, compiled, gold, tensor_name, fudge_factor=1.35): ref_error = rmse(eager, gold) comp_error = rmse(compiled, gold) # Note: This has been carefully tested that FlexAttention is within # 20% of the average error of SDPA! Do not bump this tolerance # unless you are absolutely sure you are not worsening the accuracy # of FlexAttention! if eager.dtype == torch.float32: fudge_factor = 10.0 * fudge_factor comp_error = comp_error.item() ref_error = ref_error.item() * fudge_factor if ( tensor_name == "out" and eager.dtype == torch.float32 and comp_error > ref_error ): self.skipTest("Compiled FlexAttention is less accurate than eager in fp32") self.assertLessEqual( comp_error, (ref_error * fudge_factor), f"\nTensor: {tensor_name}\nCompiled error ({comp_error:.8f}) exceeds " f"reference error ({ref_error:.8f}) * fudge_factor ({fudge_factor})", ) def _check_outputs_and_grads( self, out_eager, out_compiled, out_gold, tensors, names=None ): backwards_grad = torch.randn_like(out_eager, device="cpu").to(out_eager.device) grads_eager = torch.autograd.grad((out_eager,), tensors, backwards_grad) grads_compiled = torch.autograd.grad((out_compiled,), tensors, backwards_grad) grads_gold = torch.autograd.grad((out_gold,), tensors, backwards_grad) tensor_names = ( ["out", "grad_query", "grad_key", "grad_value", "grad_bias"] if names is None else names ) eager_tensors = (out_eager, *grads_eager) compiled_tensors = (out_compiled, *grads_compiled) gold_tensors = (out_gold, *grads_gold) for eager, compiled, gold, name in zip( eager_tensors, compiled_tensors, gold_tensors, tensor_names, strict=True ): self._gold_check(eager, compiled, gold, name) @skip_on_cpu @common_utils.parametrize( "params", get_params(device_configs["cuda"].dtypes), name_fn=lambda x: f"{x}" ) @common_utils.parametrize("mode", ["default", "max-autotune-no-cudagraphs"]) def test_relative_1d_bias(self, device, params, mode: str): query, key, value = self._init_tensors(params, device=device) bias = torch.randn( 2 * params.seq_length, device=device, dtype=params.dtype, requires_grad=True, ) def bias_func(score, b, h, q_idx, kv_idx): return score + bias[torch.abs(q_idx - kv_idx)] flex_compiled 
= torch.compile(flex_attention, mode=mode) out_eager = flex_attention(query, key, value, score_mod=bias_func) out_compiled = flex_compiled(query, key, value, score_mod=bias_func) out_gold = flex_attention( query.to(torch.float64), key.to(torch.float64), value.to(torch.float64), score_mod=bias_func, ) self._check_outputs_and_grads( out_eager, out_compiled, out_gold, (query, key, value, bias), ) @skip_on_cpu @common_utils.parametrize( "params", get_params(device_configs["cuda"].dtypes), name_fn=lambda x: f"{x}" ) def test_absolute_2d_bias(self, device, params): query, key, value = self._init_tensors(params, device=device) bias = torch.randn( params.seq_length, params.seq_length, device=device, dtype=params.dtype, requires_grad=True, ) def bias_func(score, b, h, q_idx, kv_idx): return score + bias[q_idx, kv_idx] flex_compiled = torch.compile(flex_attention) out_eager = flex_attention(query, key, value, score_mod=bias_func) out_compiled = flex_compiled(query, key, value, score_mod=bias_func) out_gold = flex_attention( query.to(torch.float64), key.to(torch.float64), value.to(torch.float64), score_mod=bias_func, ) self._check_outputs_and_grads( out_eager, out_compiled, out_gold, (query, key, value, bias), ) @skip_on_cpu @common_utils.parametrize( "params", get_params(device_configs["cuda"].dtypes), name_fn=lambda x: f"{x}" ) def test_head_specific_bias(self, device, params): query, key, value = self._init_tensors(params, device=device) bias = torch.randn( params.num_heads, params.seq_length, params.seq_length, device=device, dtype=params.dtype, requires_grad=True, ) def bias_func(score, b, h, q_idx, kv_idx): return score + bias[h, q_idx, kv_idx] flex_compiled = torch.compile(flex_attention) out_eager = flex_attention(query, key, value, score_mod=bias_func) out_compiled = flex_compiled(query, key, value, score_mod=bias_func) out_gold = flex_attention( query.to(torch.float64), key.to(torch.float64), value.to(torch.float64), score_mod=bias_func, ) self._check_outputs_and_grads( out_eager, out_compiled, out_gold, (query, key, value, bias), ) @skip_on_cpu @common_utils.parametrize( "params", get_params(device_configs["cuda"].dtypes), name_fn=lambda x: f"{x}" ) def test_batch_head_bias(self, device, params): query, key, value = self._init_tensors(params, device=device) bias = torch.randn( params.batch_size, params.num_heads, params.seq_length, params.seq_length, device=device, dtype=params.dtype, requires_grad=True, ) def bias_func(score, b, h, q_idx, kv_idx): return score + bias[b, h, q_idx, kv_idx] flex_compiled = torch.compile(flex_attention) out_eager = flex_attention(query, key, value, score_mod=bias_func) out_compiled = flex_compiled(query, key, value, score_mod=bias_func) out_gold = flex_attention( query.to(torch.float64), key.to(torch.float64), value.to(torch.float64), score_mod=bias_func, ) self._check_outputs_and_grads( out_eager, out_compiled, out_gold, (query, key, value, bias), ) @skip_on_cpu @common_utils.parametrize( "params", get_params(device_configs["cuda"].dtypes), name_fn=lambda x: f"{x}" ) def test_multiplicative_bias(self, device, params): query, key, value = self._init_tensors(params, device=device) bias = torch.randn( params.seq_length, device=device, dtype=params.dtype, requires_grad=True, ) def bias_func(score, b, h, q_idx, kv_idx): return score * bias[q_idx] flex_compiled = torch.compile(flex_attention) out_eager = flex_attention(query, key, value, score_mod=bias_func) out_compiled = flex_compiled(query, key, value, score_mod=bias_func) out_gold = flex_attention( 
query.to(torch.float64), key.to(torch.float64), value.to(torch.float64), score_mod=bias_func, ) self._check_outputs_and_grads( out_eager, out_compiled, out_gold, (query, key, value, bias), ) @skip_on_cpu @common_utils.parametrize( "params", get_params(device_configs["cuda"].dtypes), name_fn=lambda x: f"{x}" ) def test_local_window_bias(self, device, params): query, key, value = self._init_tensors(params, device=device) window_size = 8 bias = torch.randn( 2 * window_size + 1, device=device, dtype=torch.float32, requires_grad=True, ) def bias_func(score, b, h, q_idx, kv_idx): window_idx = torch.clamp(q_idx - kv_idx + window_size, 0, 2 * window_size) return score + bias[window_idx] flex_compiled = torch.compile(flex_attention) out_eager = flex_attention(query, key, value, score_mod=bias_func) out_compiled = flex_compiled(query, key, value, score_mod=bias_func) out_gold = flex_attention( query.to(torch.float64), key.to(torch.float64), value.to(torch.float64), score_mod=bias_func, ) self._check_outputs_and_grads( out_eager, out_compiled, out_gold, (query, key, value, bias), ) @skip_on_cpu @common_utils.parametrize( "params", get_params(device_configs["cuda"].dtypes), name_fn=lambda x: f"{x}" ) def test_global_tokens_bias(self, device, params): query, key, value = self._init_tensors(params, device=device) bias = torch.randn( params.seq_length, device=device, dtype=torch.float32, requires_grad=True, ) def bias_func(score, b, h, q_idx, kv_idx): return score + bias[kv_idx] flex_compiled = torch.compile(flex_attention) out_eager = flex_attention(query, key, value, score_mod=bias_func) out_compiled = flex_compiled(query, key, value, score_mod=bias_func) out_gold = flex_attention( query.to(torch.float64), key.to(torch.float64), value.to(torch.float64), score_mod=bias_func, ) self._check_outputs_and_grads( out_eager, out_compiled, out_gold, (query, key, value, bias), ) @skip_on_cpu @common_utils.parametrize( "params", get_params(device_configs["cuda"].dtypes), name_fn=lambda x: f"{x}" ) def test_weird_bias(self, device, params): query, key, value = self._init_tensors(params, device=device) bias = torch.randn( params.batch_size, params.num_heads, 4, params.seq_length, device=device, dtype=params.dtype, requires_grad=True, ) which_bias = torch.tensor(0, device=device) def bias_func(score, b, h, q_idx, kv_idx): return score + bias[b, h, which_bias, q_idx] flex_compiled = torch.compile(flex_attention) out_eager = flex_attention(query, key, value, score_mod=bias_func) out_compiled = flex_compiled(query, key, value, score_mod=bias_func) out_gold = flex_attention( query.to(torch.float64), key.to(torch.float64), value.to(torch.float64), score_mod=bias_func, ) self._check_outputs_and_grads( out_eager, out_compiled, out_gold, (query, key, value, bias), ) @skip_on_cpu @common_utils.parametrize( "params", get_params(device_configs["cuda"].dtypes), name_fn=lambda x: f"{x}" ) def test_indirect_bias(self, device, params): query, key, value = self._init_tensors(params, device=device) bias = torch.randn( params.seq_length, device=device, dtype=params.dtype, requires_grad=True, ) offset = torch.randint( 0, params.seq_length, (params.seq_length,), device=device, ) def bias_func(score, b, h, q_idx, kv_idx): return score + bias[offset[q_idx]] flex_compiled = torch.compile(flex_attention) out_eager = flex_attention(query, key, value, score_mod=bias_func) out_compiled = flex_compiled(query, key, value, score_mod=bias_func) out_gold = flex_attention( query.to(torch.float64), key.to(torch.float64), value.to(torch.float64), 
score_mod=bias_func, ) self._check_outputs_and_grads( out_eager, out_compiled, out_gold, (query, key, value, bias), ) @skip_on_cpu @common_utils.parametrize( "params", get_params(device_configs["cuda"].dtypes), name_fn=lambda x: f"{x}" ) @common_utils.parametrize("mode", ["default", "max-autotune-no-cudagraphs"]) def test_symmetric_bias(self, device, params, mode: str): query, key, value = self._init_tensors(params, device=device) bias = torch.randn( params.seq_length, device=device, dtype=params.dtype, requires_grad=True, ) def bias_func(score, b, h, q_idx, kv_idx): return score + bias[q_idx] + bias[kv_idx] flex_compiled = torch.compile(flex_attention, mode=mode) out_eager = flex_attention(query, key, value, score_mod=bias_func) out_compiled = flex_compiled(query, key, value, score_mod=bias_func) out_gold = flex_attention( query.to(torch.float64), key.to(torch.float64), value.to(torch.float64), score_mod=bias_func, ) # Error in backwards with self.assertRaisesRegex( torch._inductor.exc.LoweringException, "Using multiple indexing operations on the same tensor that requires gradients", ): self._check_outputs_and_grads( out_eager, out_compiled, out_gold, (query, key, value, bias), ) @skip_on_cpu @common_utils.parametrize( "params", get_params(device_configs["cuda"].dtypes), name_fn=lambda x: f"{x}" ) def test_flipped_indexed_bias(self, device, params): query, key, value = self._init_tensors(params, device=device) bias = torch.randn( params.seq_length, params.seq_length, device=device, dtype=params.dtype, requires_grad=True, ) def bias_func(score, b, h, q_idx, kv_idx): return score + bias[kv_idx, q_idx] flex_compiled = torch.compile(flex_attention) out_eager = flex_attention(query, key, value, score_mod=bias_func) out_compiled = flex_compiled(query, key, value, score_mod=bias_func) out_gold = flex_attention( query.to(torch.float64), key.to(torch.float64), value.to(torch.float64), score_mod=bias_func, ) self._check_outputs_and_grads( out_eager, out_compiled, out_gold, (query, key, value, bias), ) @skip_on_cpu @common_utils.parametrize( "params", get_params(device_configs["cuda"].dtypes), name_fn=lambda x: f"{x}" ) @common_utils.parametrize("mode", ["default", "max-autotune-no-cudagraphs"]) def test_head_specific_gate(self, device, params, mode: str): query, key, value = self._init_tensors(params, device=device) gate_score = torch.randn( params.num_heads, device=device, dtype=torch.float32, requires_grad=True, ) def bias_func(score, b, h, q_idx, kv_idx): return score * torch.sigmoid(gate_score[h]) flex_compiled = torch.compile(flex_attention, mode=mode) out_eager = flex_attention(query, key, value, score_mod=bias_func) out_compiled = flex_compiled(query, key, value, score_mod=bias_func) out_gold = flex_attention( query.to(torch.float64), key.to(torch.float64), value.to(torch.float64), score_mod=bias_func, ) self._check_outputs_and_grads( out_eager, out_compiled, out_gold, (query, key, value, gate_score), ) @skip_on_cpu @common_utils.parametrize( "params", get_params(device_configs["cuda"].dtypes), name_fn=lambda x: f"{x}" ) def test_distinct_biases(self, device, params): query, key, value = self._init_tensors(params, device=device) # Create two separate bias tensors bias1 = torch.randn( params.seq_length, device=device, dtype=params.dtype, requires_grad=True, ) bias2 = torch.randn( params.seq_length, device=device, dtype=torch.float32, requires_grad=True, ) def bias_func(score, b, h, q_idx, kv_idx): return score + bias1[q_idx] + bias2[kv_idx] flex_compiled = torch.compile(flex_attention) 
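        # Gradients must reach both independent bias tensors; the check below
        # reports them separately as grad_bias1 and grad_bias2.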
        out_eager = flex_attention(query, key, value, score_mod=bias_func)
        out_compiled = flex_compiled(query, key, value, score_mod=bias_func)
        out_gold = flex_attention(
            query.to(torch.float64),
            key.to(torch.float64),
            value.to(torch.float64),
            score_mod=bias_func,
        )

        # Include both bias tensors in the tuple for gradient checking
        self._check_outputs_and_grads(
            out_eager,
            out_compiled,
            out_gold,
            (query, key, value, bias1, bias2),
            names=[
                "out",
                "grad_query",
                "grad_key",
                "grad_value",
                "grad_bias1",
                "grad_bias2",
            ],
        )

    @skip_on_cpu
    @common_utils.parametrize(
        "params", get_params(device_configs["cuda"].dtypes), name_fn=lambda x: f"{x}"
    )
    @torch.compile
    def test_learnable_bias_global_compiled(self, device, params):
        batch_size = 1
        num_heads = 1
        seq_len = 128
        head_dim = 16
        d_model = num_heads * head_dim

        query = torch.randn(batch_size, num_heads, seq_len, head_dim, device=device)
        key = torch.randn(batch_size, num_heads, seq_len, head_dim, device=device)
        value = torch.randn(batch_size, num_heads, seq_len, head_dim, device=device)
        out_proj = nn.Linear(d_model, d_model, device=device)

        query.requires_grad = True
        key.requires_grad = True
        value.requires_grad = True

        bias = torch.randn(
            batch_size,
            num_heads,
            seq_len,
            seq_len,
            device=device,
            requires_grad=True,
        )

        def bias_mod(score, b, h, q_idx, kv_idx):
            return score + bias[b, h, q_idx, kv_idx]

        out = flex_attention(
            query=query,
            key=key,
            value=value,
            score_mod=bias_mod,
        )
        out = out.transpose(1, 2).contiguous().view(batch_size, seq_len, d_model)
        attn_output = out_proj(out)

        random_target = torch.randn(batch_size, seq_len, d_model, device=device)
        loss = torch.nn.functional.mse_loss(attn_output, random_target)
        loss.backward()

        # bias.grad is a multi-element tensor; check for its presence explicitly
        # rather than relying on its (ambiguous) truth value.
        assert bias.grad is not None, "No gradient computed for bias"
        assert torch.any(bias.grad != 0), "Gradient for bias is 0"

    @skip_on_cpu
    def test_backprop_error_case(self, device):
        @torch.compile()
        def test(x, y):
            # Materialize a bias matrix
            B, L, device = x.shape[0], x.shape[1], x.device
            b = torch.arange(B, device=device, dtype=torch.long).view(B, 1, 1)
            q_idx = torch.arange(L, device=device, dtype=torch.long).view(1, L, 1)
            kv_idx = torch.arange(L, device=device, dtype=torch.long).view(1, 1, L)
            bias_mat = y[b, q_idx] + y[b, kv_idx]  # (B, L, L)

            # Dummy score_mod retrieving bias values
            def score_mod(score, b, h, q_idx, kv_idx):
                return score + bias_mat[b, q_idx, kv_idx]

            x_ = x[:, :, None].repeat(1, 1, 16, 1)
            # torch._dynamo.graph_break()
            return flex_attention(x_, x_, x_, score_mod=score_mod)

        B, L, D = 2, 16, 64
        x = torch.randn(B, L, D, device=device, requires_grad=True)
        y = torch.randn(B, L, device=device, requires_grad=True)
        _ = test(x, y).mean().backward()
        assert x.grad.norm() > 0
        assert y.grad.norm() > 0

    @skip_on_cpu
    @common_utils.parametrize(
        "params", get_params(device_configs["cuda"].dtypes), name_fn=lambda x: f"{x}"
    )
    def test_relative_1d_bias_only_grad(self, device, params):
        query, key, value = self._init_tensors(params, device=device)
        query = query.detach().requires_grad_(False)
        key = key.detach().requires_grad_(False)
        value = value.detach().requires_grad_(False)

        # Only bias requires gradients
        bias = torch.randn(
            2 * params.seq_length,
            device=device,
            dtype=params.dtype,
            requires_grad=True,  # Only bias needs gradients
        )

        def bias_func(score, b, h, q_idx, kv_idx):
            return score + bias[torch.abs(q_idx - kv_idx)]

        flex_compiled = torch.compile(flex_attention)
        out_eager = flex_attention(query, key, value, score_mod=bias_func)
        out_compiled = flex_compiled(query, key, value, score_mod=bias_func)
        out_gold = flex_attention(
            query.to(torch.float64),
key.to(torch.float64), value.to(torch.float64), score_mod=bias_func, ) # For gradient checking, we only pass the bias tensor since it's the only one requiring gradients self._check_outputs_and_grads( out_eager, out_compiled, out_gold, (bias,), names=["out", "bias"] ) def _test_flex_attention_with_dynamic_max_autotune(self, device): query = torch.randn(2, 16, 512, 64, device=device) key = torch.randn(2, 16, 512, 64, device=device) value = torch.randn(2, 16, 512, 64, device=device) query.requires_grad = True key.requires_grad = True value.requires_grad = True shape = (2, 16, 512, 16, 512, 64) B, Hq, M, Hkv, N, D = shape score_mod = _generate_alibi_bias(8) def causal(b, h, m, n): return m >= n mask_shape = (1, 1, M, N) block_mask = torch.compile(create_block_mask)( causal, *mask_shape, device=device ) compiled_sdpa = torch.compile( flex_attention, dynamic=True, mode="max-autotune-no-cudagraphs" ) out = compiled_sdpa( query=query, key=key, value=value, score_mod=score_mod, block_mask=block_mask, enable_gqa=True, kernel_options=None, ) out.sum().backward() self.assertEqual( out.shape, query.shape, f"Expected shape {query.shape}, got {out.shape}" ) @skip_on_cpu def test_flex_attention_with_dynamic_max_autotune(self, device): self._test_flex_attention_with_dynamic_max_autotune(device) @skip_on_cpu @torch._inductor.config.patch("graph_partition", True) def test_flex_attention_with_dynamic_max_autotune_graph_partition(self, device): self._test_flex_attention_with_dynamic_max_autotune(device) @skip_on_cpu def test_inspect_bug(self, device): # https://github.com/pytorch/pytorch/issues/139374 def sliding_window(b, h, q_idx, kv_idx, val): return (q_idx - kv_idx).abs() < val sliding_window2 = functools.partial( sliding_window, val=torch.randn((), device=device) ) opt_fn = torch.compile(create_block_mask, fullgraph=True) create_block_mask(sliding_window2, None, None, 1024, 1024, device=device) # checks that the compile is working opt_fn(sliding_window2, None, None, 1024, 1024, device=device) @supported_platform @skip_on_cpu def test_head_bias_req_grad(self, device): B, H, S, D = 1, 4, 256, 64 bias = torch.randn(H, device=device, dtype=torch.float16, requires_grad=True) bias_flex = bias.detach().clone().requires_grad_(True) def head_bias(score, b, h, q_idx, kv_idx): return score + bias_flex[h] bias_sdpa_ref = bias.detach().clone().requires_grad_(True) implicit_bias_sdpa_ref = bias_sdpa_ref implicit_bias_sdpa_ref = implicit_bias_sdpa_ref.view(H, 1, 1).expand(H, S, S) bias_sdpa_gold = ( bias.detach().clone().to(dtype=torch.float64).requires_grad_(True) ) implicit_bias_sdpa_gold = bias_sdpa_gold implicit_bias_sdpa_gold = implicit_bias_sdpa_gold.view(H, 1, 1).expand(H, S, S) self._test_learnable_bias_inner( B, H, S, D, head_bias, bias_flex, implicit_bias_sdpa_ref, bias_sdpa_ref, implicit_bias_sdpa_gold, bias_sdpa_gold, device, ) @supported_platform @skip_on_cpu def test_comparison_vs_sdpa_with_learnable_bias(self, device): # 1-dimensional bias: B, H, S, D = 1, 1, 256, 64 bias = torch.randn( 2 * S, device=device, dtype=torch.float16, requires_grad=True ) bias_flex = bias.detach().clone().requires_grad_(True) def rel_pos_1d(score, b, h, q_idx, kv_idx): return score + bias_flex[q_idx + kv_idx] bias_indices = torch.arange(S)[:, None] + torch.arange(S) bias_sdpa_ref = bias.detach().clone().requires_grad_(True) implicit_bias_sdpa_ref = bias_sdpa_ref[bias_indices] bias_sdpa_gold = ( bias.detach().clone().to(dtype=torch.float64).requires_grad_(True) ) implicit_bias_sdpa_gold = bias_sdpa_gold[bias_indices] 
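        # Same learnable bias seen two ways: flex_attention reads it through the
        # rel_pos_1d score_mod, while the SDPA references receive it materialized
        # as an explicit attn_mask built from bias_indices.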
self._test_learnable_bias_inner( B, H, S, D, rel_pos_1d, bias_flex, implicit_bias_sdpa_ref, bias_sdpa_ref, implicit_bias_sdpa_gold, bias_sdpa_gold, device, ) # 2-dimensional bias: B, H, S, D = 1, 1, 256, 64 bias = torch.randn(S, S, device=device, dtype=torch.float16, requires_grad=True) bias_flex = bias.detach().clone().requires_grad_(True) def rel_pos_2d(score, b, h, q_idx, kv_idx): return score + bias_flex[q_idx, kv_idx] bias_sdpa_ref = bias.detach().clone().requires_grad_(True) implicit_bias_sdpa_ref = bias_sdpa_ref bias_sdpa_gold = ( bias.detach().clone().to(dtype=torch.float64).requires_grad_(True) ) implicit_bias_sdpa_gold = bias_sdpa_gold self._test_learnable_bias_inner( B, H, S, D, rel_pos_2d, bias_flex, implicit_bias_sdpa_ref, bias_sdpa_ref, implicit_bias_sdpa_gold, bias_sdpa_gold, device, ) # 2-dimensional bias + index multiple B, H, S, D = 1, 1, 256, 64 bias = torch.randn(S, S, device=device, dtype=torch.float16, requires_grad=True) bias_flex = bias.detach().clone().requires_grad_(True) def rel_pos_2d(score, b, h, q_idx, kv_idx): return score + bias_flex[q_idx][kv_idx] bias_sdpa_ref = bias.detach().clone().requires_grad_(True) implicit_bias_sdpa_ref = bias_sdpa_ref bias_sdpa_gold = ( bias.detach().clone().to(dtype=torch.float64).requires_grad_(True) ) implicit_bias_sdpa_gold = bias_sdpa_gold self._test_learnable_bias_inner( B, H, S, D, rel_pos_2d, bias_flex, implicit_bias_sdpa_ref, bias_sdpa_ref, implicit_bias_sdpa_gold, bias_sdpa_gold, device, ) # 2-dimensional bias + transposed: B, H, S, D = 1, 1, 256, 64 bias = torch.randn(S, S, device=device, dtype=torch.float16, requires_grad=True) bias_flex = bias.detach().clone().requires_grad_(True) def rel_pos_2d_transposed(score, b, h, q_idx, kv_idx): return score + bias_flex[kv_idx, q_idx] bias_sdpa_ref = bias.detach().clone().requires_grad_(True) implicit_bias_sdpa_ref = bias_sdpa_ref.transpose(-1, -2) bias_sdpa_gold = ( bias.detach().clone().to(dtype=torch.float64).requires_grad_(True) ) implicit_bias_sdpa_gold = bias_sdpa_gold.transpose(-1, -2) self._test_learnable_bias_inner( B, H, S, D, rel_pos_2d_transposed, bias_flex, implicit_bias_sdpa_ref, bias_sdpa_ref, implicit_bias_sdpa_gold, bias_sdpa_gold, device, ) # 3-dimensional bias + transposed B, H, S, D = 4, 8, 256, 64 bias = torch.randn( H, S, S, device=device, dtype=torch.float16, requires_grad=True ) bias_flex = bias.detach().clone().requires_grad_(True) def rel_pos_3d_transposed(score, b, h, q_idx, kv_idx): return score + bias_flex[h, kv_idx, q_idx] bias_sdpa_ref = bias.detach().clone().requires_grad_(True) implicit_bias_sdpa_ref = bias_sdpa_ref.transpose(-1, -2) bias_sdpa_gold = ( bias.detach().clone().to(dtype=torch.float64).requires_grad_(True) ) implicit_bias_sdpa_gold = bias_sdpa_gold.transpose(-1, -2) self._test_learnable_bias_inner( B, H, S, D, rel_pos_3d_transposed, bias_flex, implicit_bias_sdpa_ref, bias_sdpa_ref, implicit_bias_sdpa_gold, bias_sdpa_gold, device, ) def _test_learnable_bias_inner( self, B, H, S, D, score_mod, bias_flex, implicit_bias_sdpa_ref, bias_sdpa_ref, implicit_bias_sdpa_gold, bias_sdpa_gold, device, ): make_tensor = functools.partial( torch.ones, (B, H, S, D), device=device, dtype=torch.float16, requires_grad=True, ) q_ref, k_ref, v_ref = make_tensor(), make_tensor(), make_tensor() q_gold, k_gold, v_gold = query_key_value_clones( q_ref, k_ref, v_ref, torch.float64 ) q_flex, k_flex, v_flex = query_key_value_clones(q_ref, k_ref, v_ref) out_ref = torch.nn.functional.scaled_dot_product_attention( q_ref, k_ref, v_ref, attn_mask=implicit_bias_sdpa_ref 
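            # fp16 SDPA with the materialized bias is the eager reference;
            # the fp64 SDPA below serves as the gold baseline.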
) out_ref.sum().backward() out_gold = torch.nn.functional.scaled_dot_product_attention( q_gold, k_gold, v_gold, attn_mask=implicit_bias_sdpa_gold ) out_gold.sum().backward() out_flex = flex_attention(q_flex, k_flex, v_flex, score_mod=score_mod) out_flex.sum().backward() name = score_mod.__name__ for ref, flex, gold in [ (out_ref, out_flex, out_gold), (q_ref.grad, q_flex.grad, q_gold.grad), (k_ref.grad, k_flex.grad, k_gold.grad), (v_ref.grad, v_flex.grad, v_gold.grad), (bias_sdpa_ref.grad, bias_flex.grad, bias_sdpa_gold.grad), ]: ref_error = rmse(ref, gold) flex_error = rmse(flex, gold) self.assertTrue( ref_error * 1.2 >= flex_error, f"{name} -> Ref error: {ref_error}, Flex eager Error: {flex_error}", ) instantiate_device_type_tests( TestFlexAttention, globals(), only_for=test_device, allow_xpu=True ) instantiate_device_type_tests( TestPagedAttention, globals(), only_for=test_device, allow_xpu=True ) instantiate_device_type_tests( TestBlockMask, globals(), only_for=(test_device[0] if HAS_GPU else "cuda",), allow_xpu=True, ) instantiate_device_type_tests( TestLearnableBiases, globals(), only_for=test_device, allow_xpu=True ) if __name__ == "__main__": from torch._inductor.test_case import run_tests run_tests()