pytorch/test/inductor/test_flex_flash.py
drisspg 0747d95994 Add Loads from fixed inputs (#162031)
## TODO
Check on multi-indices
```Python

    @cute.jit
    def score_mod(tSrS_ssa, b_idx, h_idx, q_idx, kv_idx, buffers):
        in_ptr4 = buffers[0]
        tmp0 = tSrS_ssa
        tmp1 = b_idx
        tmp2 = h_idx
        tmp3 = cute.make_fragment(1, cutlass.Int32)
        tmp4 = tmp3.store(32*tmp1 + tmp2)
        tmp5 = cute.make_fragment(1, cutlass.BFloat16)
        tmp6 = tmp3[0]
        tmp7 = tmp5[0] = (in_ptr4[tmp6])
        tmp8 = (tmp5.load()).to(cutlass.Float32)
        tmp9 = (tmp0 + tmp8)
        tSrS_ssa = tmp9

        return tSrS_ssa

```
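For reference, the kind of user-level `score_mod` that lowers to a fixed-input load like the one above is a per-(batch, head) bias lookup, similar to `create_batch_head_bias` in the test file below. A minimal sketch (not the exact repro; shapes are assumptions, with H == 32 chosen only to match the `32*tmp1 + tmp2` offset):

```Python
import torch
from torch.nn.attention.flex_attention import flex_attention

# Hypothetical shapes; the stride of 32 in the generated offset suggests H == 32.
B, H, S, D = 4, 32, 256, 64
bias = torch.randn(B, H, device="cuda", dtype=torch.bfloat16)


def batch_head_bias(score, b, h, q_idx, kv_idx):
    # Load a per-(batch, head) scalar from a captured buffer; this is the kind
    # of fixed-input load the generated kernel flattens into 32*b + h.
    return score + bias[b, h]


q = k = v = torch.randn(B, H, S, D, device="cuda", dtype=torch.bfloat16)
out = torch.compile(flex_attention)(q, k, v, score_mod=batch_head_bias)
```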

I don't think that
```Python
        tmp4 = tmp3.store(32*tmp1 + tmp2)
        tmp5 = cute.make_fragment(1, cutlass.BFloat16)
        tmp6 = tmp3[0]
        tmp7 = tmp5[0] = (in_ptr4[tmp6])
```

is right, since this tmp6 value will be larger than the actual index dim (in this case it's B) -> see if it's possible to 1-D index.
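One way to read the 1-D indexing idea, as a hedged sketch (shapes are hypothetical): for a contiguous buffer, a 2-D `(B, H)` lookup and a flattened 1-D lookup at offset `b*H + h` are equivalent, so a precomputed offset like `32*b + h` is only in bounds when the buffer really has `B*H` elements, not just `B`.

```Python
import torch

B, H = 4, 32
bias_2d = torch.randn(B, H)
flat = bias_2d.reshape(-1)  # contiguous 1-D view with B * H elements

b, h = 2, 5
# The flattened offset (stride 32 == H here) matches 2-D indexing only because
# the buffer has B * H elements; a 1-D per-batch buffer of length B would be
# out of bounds for any offset >= B.
assert flat[b * H + h] == bias_2d[b, h]
```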

Pull Request resolved: https://github.com/pytorch/pytorch/pull/162031
Approved by: https://github.com/v0i0
ghstack dependencies: #161118
2025-10-10 01:23:37 +00:00


# Owner(s): ["module: inductor"]
import unittest
from contextlib import contextmanager

import torch
from torch._inductor.kernel.flex.flex_flash_attention import ensure_flash_available
from torch._inductor.test_case import TestCase as InductorTestCase
from torch.nn.attention.flex_attention import create_block_mask, flex_attention
from torch.profiler import profile, ProfilerActivity
from torch.testing._internal.common_device_type import (
    dtypes,
    instantiate_device_type_tests,
)
from torch.testing._internal.common_utils import parametrize


def _times_two(score, _b, _h, _m, _n):
    return score * 2


def _causal(score, _b, _h, token_q, token_kv):
    return torch.where(token_q >= token_kv, score, float("-inf"))


def _rel_bias(score, _b, _h, token_q, token_kv):
    return score + (token_q - token_kv)


def create_alibi_learned(num_heads=4, dtype=torch.float16):
    """ALiBi with learned per-head slopes (tests tensor loading)."""
    slopes = torch.exp2(-torch.linspace(1, 8, num_heads, device="cuda", dtype=dtype))

    def alibi_score_mod(score, b, h, q_idx, kv_idx):
        bias = (kv_idx - q_idx) * slopes[h]
        return score + bias

    return alibi_score_mod


def create_pos_bias_table(seq_len=512, dtype=torch.float16):
    """Relative position bias table (tests computed indexing)."""
    max_len = seq_len
    table = torch.randn(2 * max_len - 1, device="cuda", dtype=dtype) * 0.1

    def pos_bias_mod(score, b, h, q_idx, kv_idx):
        rel_pos = kv_idx - q_idx + max_len - 1
        bias = table[rel_pos]
        return score + bias

    return pos_bias_mod


def create_head_scale(num_heads=4, dtype=torch.float16):
    """Per-head scaling factors (tests multiplication with tensor loading)."""
    scales = torch.rand(num_heads, device="cuda", dtype=dtype) + 0.5

    def head_scale_mod(score, b, h, q_idx, kv_idx):
        return score * scales[h]

    return head_scale_mod


def create_batch_bias(batch_size=2, dtype=torch.float16):
    """Per-batch bias (tests batch indexing)."""
    bias = torch.randn(batch_size, device="cuda", dtype=dtype) * 0.1

    def batch_bias_mod(score, b, h, q_idx, kv_idx):
        return score + bias[b]

    return batch_bias_mod


def create_batch_head_bias(batch_size=2, num_heads=4, dtype=torch.float16):
    """Per-batch-head bias matrix (tests 2D indexing with batch + head)."""
    bias_matrix = torch.randn(batch_size, num_heads, device="cuda", dtype=dtype) * 0.5

    def batch_head_mod(score, b, h, q_idx, kv_idx):
        bias = bias_matrix[b, h]
        return score + bias

    return batch_head_mod


def create_dual_buffer_bias(num_heads=4, seq_len=512, dtype=torch.float16):
    """Dual buffer loading (tests loading from 2 separate tensors)."""
    head_bias = torch.randn(num_heads, device="cuda", dtype=dtype) * 0.2
    pos_scale = torch.arange(seq_len, device="cuda", dtype=dtype)

    def dual_buffer_mod(score, b, h, q_idx, kv_idx):
        head_component = head_bias[h]
        pos_component = pos_scale[q_idx] * 0.01
        return score + head_component + pos_component

    return dual_buffer_mod


def create_test_tensors(
    batch_size=2, num_heads=4, seq_len=512, dim=64, dtype=torch.float16, device="cuda"
):
    shape = (batch_size, num_heads, seq_len, dim)
    q = torch.randn(shape, device=device, dtype=dtype, requires_grad=False)
    k = torch.randn(shape, device=device, dtype=dtype, requires_grad=False)
    v = torch.randn(shape, device=device, dtype=dtype, requires_grad=False)
    return q, k, v


@contextmanager
def cuda_kernel_profiler(kernel_pattern="flash_attncute"):
    """Context manager for profiling CUDA kernels."""
    result = {"found": False, "kernel_names": []}
    with profile(activities=[ProfilerActivity.CUDA]) as prof:
        yield result

    kernel_names = [
        evt.name
        for evt in prof.events()
        if evt.device_type == torch.autograd.DeviceType.CUDA and evt.name
    ]
    result["kernel_names"] = kernel_names
    result["found"] = any(kernel_pattern in name for name in kernel_names)


def flash_vs_triton(q, k, v, score_mod=None, rtol=5e-3, atol=5e-3):
    compiled_fn = torch.compile(flex_attention)
    out_flash = compiled_fn(
        q, k, v, score_mod=score_mod, kernel_options={"force_flash": True}
    )
    out_no_flash = compiled_fn(
        q, k, v, score_mod=score_mod, kernel_options={"force_flash": False}
    )
    torch.testing.assert_close(out_flash, out_no_flash, rtol=rtol, atol=atol)
    return out_flash, out_no_flash


def name_fn(score_mod):
    return score_mod.__name__.lstrip("_")


@unittest.skipIf(
    not ensure_flash_available(), "Flash attention (CUTE) library is not available"
)
class TestFlexFlash(InductorTestCase):
    @dtypes(torch.float16, torch.bfloat16)
    def test_flash_attention_basic(self, device, dtype):
        q, k, v = create_test_tensors(dtype=dtype, device=device)
        flash_vs_triton(q, k, v)

    @dtypes(torch.float16, torch.bfloat16)
    @parametrize("score_mod", [_times_two, _causal, _rel_bias], name_fn=name_fn)
    def test_flash_attention_with_score_mod(self, device, dtype, score_mod):
        q, k, v = create_test_tensors(dtype=dtype, device=device)
        flash_vs_triton(q, k, v, score_mod=score_mod)

    @dtypes(torch.float16, torch.bfloat16)
    @parametrize("seq_len", [127, 255, 383, 511])
    def test_flash_attention_unfriendly_seqlen_with_causal(
        self, device, dtype, seq_len
    ):
        """Test flash attention with unfriendly sequence lengths and causal masking."""
        q, k, v = create_test_tensors(seq_len=seq_len, dtype=dtype, device=device)
        flash_vs_triton(q, k, v, score_mod=_causal)

    @dtypes(torch.float16, torch.bfloat16)
    def test_force_flash_error_with_block_mask(self, device, dtype):
        """Test that force_flash=True raises error when BlockMask is provided."""
        q, k, v = create_test_tensors(dtype=dtype, device=device)

        # Create a causal block mask
        def causal_mask(b, h, q_idx, kv_idx):
            return q_idx >= kv_idx

        block_mask = create_block_mask(causal_mask, 2, 4, 512, 512, device=device)
        compiled_fn = torch.compile(flex_attention)
        with self.assertRaisesRegex(
            RuntimeError,
            r"force_flash=True but flash attention cannot be used.*BlockMask.*not supported",
        ):
            compiled_fn(
                q, k, v, block_mask=block_mask, kernel_options={"force_flash": True}
            )

    @dtypes(torch.float16, torch.bfloat16)
    def test_flash_attention_kernel_called(self, device, dtype):
        """Test that flash attention kernel is actually called when force_flash=True."""
        q, k, v = create_test_tensors(dtype=dtype, device=device)
        compiled_fn = torch.compile(flex_attention)

        # Test that flash kernel is called with force_flash=True
        with cuda_kernel_profiler("flash_attncute") as prof_result:
            compiled_fn(
                q, k, v, score_mod=_causal, kernel_options={"force_flash": True}
            )
        self.assertTrue(
            prof_result["found"],
            f"Flash attention kernel not found. Available kernels: {prof_result['kernel_names']}",
        )

        # Test that flash kernel is NOT called with force_flash=False
        with cuda_kernel_profiler("flash_attncute") as prof_result:
            compiled_fn(
                q, k, v, score_mod=_causal, kernel_options={"force_flash": False}
            )
        self.assertFalse(
            prof_result["found"],
            f"Flash attention kernel unexpectedly found when force_flash=False. Kernels: {prof_result['kernel_names']}",
        )

    @dtypes(torch.float16, torch.bfloat16)
    def test_flash_attention_with_alibi_learned(self, device, dtype):
        """Test flash attention with ALiBi learned slopes (tensor loading)."""
        q, k, v = create_test_tensors(dtype=dtype, device=device)
        score_mod = create_alibi_learned(num_heads=4, dtype=dtype)
        flash_vs_triton(q, k, v, score_mod=score_mod)

    @dtypes(torch.float16, torch.bfloat16)
    def test_flash_attention_with_pos_bias_table(self, device, dtype):
        """Test flash attention with position bias table (tensor loading)."""
        q, k, v = create_test_tensors(dtype=dtype, device=device)
        score_mod = create_pos_bias_table(seq_len=512, dtype=dtype)
        flash_vs_triton(q, k, v, score_mod=score_mod)

    @dtypes(torch.float16, torch.bfloat16)
    def test_flash_attention_with_head_scale(self, device, dtype):
        """Test flash attention with head scaling (tensor loading)."""
        q, k, v = create_test_tensors(dtype=dtype, device=device)
        score_mod = create_head_scale(num_heads=4, dtype=dtype)
        flash_vs_triton(q, k, v, score_mod=score_mod)

    @dtypes(torch.float16, torch.bfloat16)
    def test_flash_attention_with_batch_bias(self, device, dtype):
        """Test flash attention with batch bias (tensor loading)."""
        q, k, v = create_test_tensors(dtype=dtype, device=device)
        score_mod = create_batch_bias(batch_size=2, dtype=dtype)
        flash_vs_triton(q, k, v, score_mod=score_mod)

    @dtypes(torch.float16, torch.bfloat16)
    def test_flash_attention_with_batch_head_bias(self, device, dtype):
        """Test flash attention with batch-head bias matrix (tensor loading)."""
        q, k, v = create_test_tensors(dtype=dtype, device=device)
        score_mod = create_batch_head_bias(batch_size=2, num_heads=4, dtype=dtype)
        flash_vs_triton(q, k, v, score_mod=score_mod)

    @dtypes(torch.float16, torch.bfloat16)
    def test_flash_attention_with_dual_buffer_bias(self, device, dtype):
        """Test flash attention with dual buffer loading (tensor loading)."""
        q, k, v = create_test_tensors(dtype=dtype, device=device)
        score_mod = create_dual_buffer_bias(num_heads=4, seq_len=512, dtype=dtype)
        flash_vs_triton(q, k, v, score_mod=score_mod)

    @dtypes(torch.float16, torch.bfloat16)
    def test_force_flash_error_with_requires_grad(self, device, dtype):
        """Test that force_flash=True raises error when tensor requires gradients."""
        q, k, v = create_test_tensors(dtype=dtype, device=device)

        # Create a score mod with requires_grad tensor
        bias = torch.randn(4, device=device, dtype=dtype, requires_grad=True)

        def score_mod_with_grad(score, b, h, q_idx, kv_idx):
            return score + bias[h]

        compiled_fn = torch.compile(flex_attention)
        with self.assertRaisesRegex(
            RuntimeError,
            r"force_flash=True but flash attention cannot be used.*require gradients",
        ):
            compiled_fn(
                q,
                k,
                v,
                score_mod=score_mod_with_grad,
                kernel_options={"force_flash": True},
            )


instantiate_device_type_tests(TestFlexFlash, globals(), only_for="cuda")


if __name__ == "__main__":
    from torch._inductor.test_case import run_tests

    run_tests()