[ROCm] Split AITER unified attention into its own backend (#25507)

Signed-off-by: Gregory Shtrasberg <Gregory.Shtrasberg@amd.com>
This commit is contained in:
Gregory Shtrasberg
2025-10-06 18:49:23 -04:00
committed by GitHub
parent 2161efe978
commit f231e5bc21
8 changed files with 325 additions and 301 deletions

View File

@ -7,9 +7,7 @@ import pytest
import torch._dynamo
from tests.compile.backend import LazyInitPass, TestBackend
from tests.models.utils import check_outputs_equal
from tests.v1.attention.utils import BatchSpec, create_common_attn_metadata
from vllm import LLM, SamplingParams
from vllm._custom_ops import cutlass_scaled_fp4_mm, scaled_fp4_quant
from vllm.attention import Attention, AttentionMetadata
from vllm.attention.backends.registry import _Backend
@ -31,7 +29,6 @@ from vllm.config import (
)
from vllm.forward_context import get_forward_context, set_forward_context
from vllm.model_executor.layers.quantization.utils.quant_utils import (
QuantKey,
kFp8StaticTensorSym,
kNvfp4Quant,
)
@ -48,132 +45,6 @@ backend: Optional[TestBackend] = None
backend_unfused: Optional[TestBackend] = None
@pytest.mark.parametrize(
"model, quant_key", [("amd/Llama-3.1-8B-Instruct-FP8-KV", kFp8StaticTensorSym)]
)
@pytest.mark.parametrize("use_triton_fa", [True, False])
@pytest.mark.skipif(not current_platform.supports_fp8(), reason="Need FP8")
@pytest.mark.skipif(
not current_platform.is_rocm(), reason="V0 attn quant fusion only on ROCm"
)
def test_attention_fusion_v0(
example_prompts, monkeypatch, model: str, quant_key: QuantKey, use_triton_fa: bool
):
# Clean Dynamo cache to avoid reusing other test cases
# (for some reason the reset at the end is not enough)
torch._dynamo.reset()
# Use global backends
global backend, backend_unfused
monkeypatch.setenv("VLLM_USE_V1", "1")
monkeypatch.setenv("VLLM_USE_TRITON_FLASH_ATTN", str(int(use_triton_fa)))
# Prompt 4 seems too open-ended, differs between fused and unfused
# (both outputs look reasonable though)
prompts = example_prompts[:4] + example_prompts[5:]
compile_config = CompilationConfig(
# DYNAMO_AS_IS triggers custom backend & does full Dynamo compilation
# DYNAMO_ONCE does not properly propagate shapes.
level=CompilationLevel.DYNAMO_AS_IS,
backend="tests.compile.test_fusion_attn.backend_unfused",
custom_ops=["+quant_fp8"],
)
vllm_config = VllmConfig(
compilation_config=compile_config,
model_config=ModelConfig(
model=model,
dtype=torch.bfloat16,
),
)
backend_unfused = TestBackend(NoOpEliminationPass(vllm_config))
llm = LLM(
model,
enforce_eager=True,
compilation_config=compile_config,
gpu_memory_utilization=0.5,
max_model_len=2048,
)
sampling_params = SamplingParams(temperature=0.0, max_tokens=10, top_p=0.95)
unfused_output = llm.generate(prompts, sampling_params)
backend_unfused = None # Reset backend to make sure llm gets released
del llm
compile_config = CompilationConfig(
# DYNAMO_AS_IS triggers custom backend & does full Dynamo compilation
# DYNAMO_ONCE does not properly propagate shapes.
level=CompilationLevel.DYNAMO_AS_IS,
backend="tests.compile.test_fusion_attn.backend",
custom_ops=["+quant_fp8"],
)
vllm_config = VllmConfig(
compilation_config=compile_config,
model_config=ModelConfig(
model=model,
dtype=torch.bfloat16,
),
)
# AttnFusionPass needs attention layers to be registered in config upon init
# so we initialize it during compilation.
attn_pass = LazyInitPass(AttnFusionPass, vllm_config)
backend = TestBackend(NoOpEliminationPass(vllm_config), attn_pass)
llm2 = LLM(
model,
enforce_eager=True,
compilation_config=compile_config,
gpu_memory_utilization=0.5,
max_model_len=2048,
)
# check support
attn_fusion_supported = [
layer.impl.fused_output_quant_supported(quant_key)
for key, layer in compile_config.static_forward_context.items()
]
print(f"{attn_fusion_supported=}")
if any(attn_fusion_supported):
# Check quant ops
backend.check_before_ops([QUANT_OPS[quant_key]], fully_replaced=False)
# attention ops present in both, just output_scale param changes
attn_nodes_pre = list(find_op_nodes(ATTN_OP, backend.graph_pre_pass))
attn_nodes_post = list(find_op_nodes(ATTN_OP, backend.graph_post_pass))
assert len(attn_nodes_pre) == len(attn_nodes_post)
for i in range(len(attn_nodes_pre)):
assert attn_nodes_pre[i].kwargs["output_scale"] is None
fused = attn_nodes_post[i].kwargs["output_scale"] is not None
assert fused == attn_fusion_supported[i], (
f"Node {i} {'' if fused else 'not '} expected to have fused output quant"
)
# check outputs
fused_output = llm2.generate(prompts, sampling_params)
# transform outputs to format expected by check_outputs_equal
sample_outs = lambda s: (list(s.token_ids), s.text)
outs_lst = lambda ros: [sample_outs(ro.outputs[0]) for ro in ros]
check_outputs_equal(
outputs_0_lst=outs_lst(unfused_output),
outputs_1_lst=outs_lst(fused_output),
name_0="unfused",
name_1="fused",
)
# Clean Dynamo cache to avoid polluting other case(s)
torch._dynamo.reset()
# Reset backend to make sure llm2 gets released
backend = None
class AttentionQuantPatternModel(torch.nn.Module):
"""Base model for AttentionQuantPattern fusion."""
@ -221,7 +92,7 @@ class AttentionQuantPatternModel(torch.nn.Module):
device=self.device,
)
def build_attn_metadata(self, batch_size: int, use_hnd: bool) -> AttentionMetadata:
def build_attn_metadata(self, batch_size: int) -> AttentionMetadata:
"""Initialize attention metadata."""
# Create common attn metadata
@ -232,30 +103,57 @@ class AttentionQuantPatternModel(torch.nn.Module):
max_blocks = (max(batch_spec.seq_lens) + self.block_size - 1) // self.block_size
num_blocks = batch_size * max_blocks
backend = self.attn.backend
# Create dummy KV cache for FlashInfer TRTLLM
# - NHD: [num_blocks, block_size, num_kv_heads, head_size]
# - HND: [num_blocks, num_kv_heads, block_size, head_size]
kv_cache = torch.zeros(
num_blocks,
2,
self.num_kv_heads,
self.block_size,
self.head_size,
dtype=self.kv_cache_dtype,
device=self.device,
)
if current_platform.is_rocm():
# Create dummy KV cache for the selected backend
if backend == _Backend.ROCM_ATTN:
# k/v as 1st dimention
if use_hnd:
kv_cache = kv_cache.permute(1, 0, 2, 3, 4)
else:
kv_cache = kv_cache.permute(1, 0, 3, 2, 4)
else:
# HND: [num_blocks, num_kv_heads, block_size, head_size]
kv_cache = torch.zeros(
2,
num_blocks,
self.num_kv_heads,
self.block_size,
self.head_size,
dtype=self.kv_cache_dtype,
device=self.device,
)
elif backend == _Backend.ROCM_AITER_UNIFIED_ATTN:
# k/v as 1st dimention
# NHD: [num_blocks, block_size, num_kv_heads, head_size]
kv_cache = torch.zeros(
2,
num_blocks,
self.block_size,
self.num_kv_heads,
self.head_size,
dtype=self.kv_cache_dtype,
device=self.device,
)
elif backend == _Backend.TRITON_ATTN:
# k/v as 2nd dimention
# Create kv_cache in HND layout and permute to NHD layout
# (later will be permuted back to HND layout in forward pass)
kv_cache = kv_cache.permute(0, 1, 3, 2, 4)
# NHD: [num_blocks, block_size, num_kv_heads, head_size]
kv_cache = torch.zeros(
num_blocks,
2,
self.num_kv_heads,
self.block_size,
self.head_size,
dtype=self.kv_cache_dtype,
device=self.device,
)
elif backend == _Backend.FLASHINFER:
kv_cache = torch.zeros(
num_blocks,
2,
self.num_kv_heads,
self.block_size,
self.head_size,
dtype=self.kv_cache_dtype,
device=self.device,
).permute(0, 1, 3, 2, 4)
else:
raise ValueError(f"Unsupported backend: {backend}")
self.attn.kv_cache = [kv_cache]
# Build attn metadata
@ -375,10 +273,9 @@ else:
@pytest.mark.parametrize("model_name, model_class", MODELS)
@pytest.mark.parametrize(
"backend",
[_Backend.FLASHINFER] if current_platform.is_cuda() else [_Backend.TRITON_ATTN],
)
@pytest.mark.parametrize(
"split_attention", [False, True] if current_platform.is_rocm() else [False]
[_Backend.FLASHINFER]
if current_platform.is_cuda()
else [_Backend.ROCM_AITER_UNIFIED_ATTN, _Backend.ROCM_ATTN, _Backend.TRITON_ATTN],
)
# TODO(boyuan): test inductor graph partition on rocm
@pytest.mark.parametrize(
@ -405,7 +302,6 @@ def test_attention_quant_pattern(
model_name: str,
model_class: type[AttentionQuantPatternModel],
backend: _Backend,
split_attention: bool,
use_inductor_graph_partition: bool,
monkeypatch,
dist_init,
@ -417,8 +313,6 @@ def test_attention_quant_pattern(
pytest.skip("inductor graph partition is only available in PyTorch 2.9+")
monkeypatch.setenv("VLLM_USE_V1", "1")
if split_attention:
monkeypatch.setenv("VLLM_V1_USE_PREFILL_DECODE_ATTENTION", "1")
device = torch.device("cuda:0")
torch.manual_seed(42)
@ -466,9 +360,7 @@ def test_attention_quant_pattern(
model_unfused = model_unfused.to(device)
forward_ctx = get_forward_context()
forward_ctx.attn_metadata = model_unfused.build_attn_metadata(
batch_size, use_hnd=split_attention
)
forward_ctx.attn_metadata = model_unfused.build_attn_metadata(batch_size)
# Run model directly without compilation and fusion
result_unfused = model_unfused(q, k, v)
@ -494,9 +386,7 @@ def test_attention_quant_pattern(
model_fused = model_fused.to(device)
forward_ctx = get_forward_context()
forward_ctx.attn_metadata = model_fused.build_attn_metadata(
batch_size, use_hnd=split_attention
)
forward_ctx.attn_metadata = model_fused.build_attn_metadata(batch_size)
# Create test backend with fusion passes enabled
noop_pass = NoOpEliminationPass(vllm_config)

View File

@ -25,3 +25,4 @@ class _Backend(enum.Enum):
FLEX_ATTENTION = enum.auto()
TREE_ATTN = enum.auto()
ROCM_ATTN = enum.auto()
ROCM_AITER_UNIFIED_ATTN = enum.auto()

View File

@ -254,3 +254,4 @@ def global_force_attn_backend_context_manager(
finally:
# Revert the original global backend override, if any
global_force_attn_backend(original_value)
_cached_get_attn_backend.cache_clear()

View File

@ -1623,6 +1623,7 @@ class EngineArgs:
"TREE_ATTN",
"XFORMERS",
"ROCM_ATTN",
"ROCM_AITER_UNIFIED_ATTN",
]
if (
envs.is_set("VLLM_ATTENTION_BACKEND")

View File

@ -18,7 +18,6 @@ if TYPE_CHECKING:
LD_LIBRARY_PATH: Optional[str] = None
VLLM_USE_TRITON_FLASH_ATTN: bool = True
VLLM_V1_USE_PREFILL_DECODE_ATTENTION: bool = False
VLLM_USE_AITER_UNIFIED_ATTENTION: bool = False
VLLM_FLASH_ATTN_VERSION: Optional[int] = None
LOCAL_RANK: int = 0
CUDA_VISIBLE_DEVICES: Optional[str] = None
@ -109,6 +108,7 @@ if TYPE_CHECKING:
VLLM_ROCM_USE_AITER_FP4_ASM_GEMM: bool = False
VLLM_ROCM_USE_TRITON_ROPE: bool = False
VLLM_ROCM_USE_AITER_FP8BMM: bool = True
VLLM_ROCM_USE_AITER_UNIFIED_ATTENTION: bool = False
VLLM_ROCM_USE_SKINNY_GEMM: bool = True
VLLM_ROCM_FP8_PADDING: bool = True
VLLM_ROCM_MOE_PADDING: bool = True
@ -475,10 +475,6 @@ environment_variables: dict[str, Callable[[], Any]] = {
os.getenv("VLLM_V1_USE_PREFILL_DECODE_ATTENTION", "False").lower()
in ("true", "1")
),
# Use AITER triton unified attention for V1 attention
"VLLM_USE_AITER_UNIFIED_ATTENTION": lambda: (
os.getenv("VLLM_USE_AITER_UNIFIED_ATTENTION", "False").lower() in ("true", "1")
),
# Force vllm to use a specific flash-attention version (2 or 3), only valid
# when using the flash-attention backend.
"VLLM_FLASH_ATTN_VERSION": lambda: maybe_convert_int(
@ -896,6 +892,11 @@ environment_variables: dict[str, Callable[[], Any]] = {
"VLLM_ROCM_USE_AITER_FP8BMM": lambda: (
os.getenv("VLLM_ROCM_USE_AITER_FP8BMM", "True").lower() in ("true", "1")
),
# Use AITER triton unified attention for V1 attention
"VLLM_ROCM_USE_AITER_UNIFIED_ATTENTION": lambda: (
os.getenv("VLLM_ROCM_USE_AITER_UNIFIED_ATTENTION", "False").lower()
in ("true", "1")
),
# use rocm skinny gemms
"VLLM_ROCM_USE_SKINNY_GEMM": lambda: (
os.getenv("VLLM_ROCM_USE_SKINNY_GEMM", "True").lower() in ("true", "1")
@ -1434,7 +1435,6 @@ def compute_hash() -> str:
"VLLM_FUSED_MOE_CHUNK_SIZE",
"VLLM_FLASHINFER_MOE_BACKEND",
"VLLM_V1_USE_PREFILL_DECODE_ATTENTION",
"VLLM_USE_AITER_UNIFIED_ATTENTION",
"VLLM_ATTENTION_BACKEND",
"VLLM_USE_FLASHINFER_SAMPLER",
"VLLM_DISABLED_KERNELS",
@ -1462,6 +1462,7 @@ def compute_hash() -> str:
"VLLM_ROCM_USE_AITER_FP4_ASM_GEMM",
"VLLM_ROCM_USE_TRITON_ROPE",
"VLLM_ROCM_USE_AITER_FP8BMM",
"VLLM_ROCM_USE_AITER_UNIFIED_ATTENTION",
"VLLM_ROCM_USE_SKINNY_GEMM",
"VLLM_ROCM_FP8_PADDING",
"VLLM_ROCM_MOE_PADDING",

View File

@ -276,25 +276,33 @@ class RocmPlatform(Platform):
)
if envs.VLLM_USE_V1:
if envs.VLLM_ROCM_USE_AITER and envs.VLLM_ROCM_USE_AITER_MHA and on_gfx9():
logger.info("Using Flash Attention backend on V1 engine.")
if (
envs.VLLM_ROCM_USE_AITER and envs.VLLM_ROCM_USE_AITER_MHA and on_gfx9()
) or selected_backend == _Backend.ROCM_AITER_FA:
logger.info("Using Aiter Flash Attention backend on V1 engine.")
return (
"vllm.v1.attention.backends."
"rocm_aiter_fa.AiterFlashAttentionBackend"
)
elif (
(envs.VLLM_ROCM_USE_AITER and envs.VLLM_USE_AITER_UNIFIED_ATTENTION)
or envs.VLLM_V1_USE_PREFILL_DECODE_ATTENTION
if (
envs.VLLM_ROCM_USE_AITER and envs.VLLM_ROCM_USE_AITER_UNIFIED_ATTENTION
) or selected_backend == _Backend.ROCM_AITER_UNIFIED_ATTN:
logger.info("Using Aiter Unified Attention backend on V1 engine.")
return (
"vllm.v1.attention.backends."
"rocm_aiter_unified_attn.RocmAiterUnifiedAttentionBackend"
)
if (
envs.VLLM_V1_USE_PREFILL_DECODE_ATTENTION
or selected_backend == _Backend.ROCM_ATTN
):
# rocm specific backend, with aiter and/or
# triton prefix-prefill
logger.info("Using Rocm/Aiter Attention backend on V1 engine.")
logger.info("Using Rocm Attention backend on V1 engine.")
return "vllm.v1.attention.backends.rocm_attn.RocmAttentionBackend"
else:
# default case, using triton unified attention
logger.info("Using Triton Attention backend on V1 engine.")
return "vllm.v1.attention.backends.triton_attn.TritonAttentionBackend"
# default case, using triton unified attention
logger.info("Using Triton Attention backend on V1 engine.")
return "vllm.v1.attention.backends.triton_attn.TritonAttentionBackend"
raise RuntimeError(
"V0 attention backends have been removed. Set VLLM_USE_V1=1 "
"to select a supported backend."

View File

@ -0,0 +1,203 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""Attention layer with PagedAttention and Triton prefix prefill."""
from typing import Optional
import torch
from vllm import _custom_ops as ops
from vllm.attention.backends.abstract import AttentionMetadata, AttentionType
from vllm.logger import init_logger
from vllm.model_executor.layers.quantization.utils.quant_utils import (
QuantKey,
kFp8StaticTensorSym,
)
from vllm.v1.attention.backends.flash_attn import FlashAttentionMetadata
from vllm.v1.attention.backends.rocm_attn import (
RocmAttentionBackend,
RocmAttentionImpl,
RocmAttentionMetadata,
RocmAttentionMetadataBuilder,
)
logger = init_logger(__name__)
class RocmAiterUnifiedAttentionBackend(RocmAttentionBackend):
accept_output_buffer: bool = True
@staticmethod
def get_name() -> str:
return "ROCM_AITER_UNIFIED_ATTN"
@staticmethod
def get_impl_cls() -> type["RocmAiterUnifiedAttentionImpl"]:
return RocmAiterUnifiedAttentionImpl
@staticmethod
def get_metadata_cls() -> type["AttentionMetadata"]:
return RocmAttentionMetadata
@staticmethod
def get_kv_cache_shape(
num_blocks: int,
block_size: int,
num_kv_heads: int,
head_size: int,
cache_dtype_str: str = "auto",
) -> tuple[int, ...]:
if block_size % 16 != 0:
raise ValueError("Block size must be a multiple of 16.")
return (2, num_blocks, block_size, num_kv_heads, head_size)
@staticmethod
def use_cascade_attention(*args, **kwargs) -> bool:
return False
@staticmethod
def get_builder_cls() -> type["RocmAttentionMetadataBuilder"]:
return RocmAttentionMetadataBuilder
class RocmAiterUnifiedAttentionImpl(RocmAttentionImpl):
def fused_output_quant_supported(self, quant_key: QuantKey):
return quant_key == kFp8StaticTensorSym
def __init__(
self,
num_heads: int,
head_size: int,
scale: float,
num_kv_heads: int,
alibi_slopes: Optional[list[float]],
sliding_window: Optional[int],
kv_cache_dtype: str,
logits_soft_cap: Optional[float] = None,
attn_type: AttentionType = AttentionType.DECODER,
kv_sharing_target_layer_name: Optional[int] = None,
sinks: Optional[torch.Tensor] = None,
) -> None:
super().__init__(
num_heads,
head_size,
scale,
num_kv_heads,
alibi_slopes,
sliding_window,
kv_cache_dtype,
logits_soft_cap,
attn_type,
kv_sharing_target_layer_name,
sinks,
)
logger.info_once(
"Using aiter unified attention for RocmAiterUnifiedAttentionImpl"
)
from aiter.ops.triton.unified_attention import unified_attention
self.unified_attention = unified_attention
def forward(
self,
layer: torch.nn.Module,
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
kv_cache: torch.Tensor,
attn_metadata: FlashAttentionMetadata,
output: Optional[torch.Tensor] = None,
output_scale: Optional[torch.Tensor] = None,
output_block_scale: Optional[torch.Tensor] = None,
) -> torch.Tensor:
"""Forward pass with FlashAttention.
Args:
query: shape = [num_tokens, num_heads, head_size]
key: shape = [num_tokens, num_kv_heads, head_size]
value: shape = [num_tokens, num_kv_heads, head_size]
kv_cache: shape =
[2, num_blocks, block_size, num_kv_heads, head_size]
attn_metadata: Metadata for attention.
Returns:
shape = [num_tokens, num_heads * head_size]
"""
assert output is not None, "Output tensor must be provided."
if output_block_scale is not None:
raise NotImplementedError(
"fused block_scale output quantization is not yet supported"
" for RocmAttentionImpl"
)
if attn_metadata is None:
# Profiling run.
return output
assert attn_metadata.use_cascade is False
# IMPORTANT!
# NOTE(woosuk): With piece-wise CUDA graphs, this method is executed in
# eager-mode PyTorch. Thus, we need to be careful about any CPU overhead
# in this method. For example, `view` and `slice` (or `[:n]`) operations
# are surprisingly slow even in the case they do not invoke any GPU ops.
# Minimize the PyTorch ops in this method as much as possible.
# Whenever making a change in this method, please benchmark the
# performance to make sure it does not introduce any overhead.
num_actual_tokens = attn_metadata.num_actual_tokens
key_cache, value_cache = kv_cache.unbind(0)
if self.kv_sharing_target_layer_name is None:
# Reshape the input keys and values and store them in the cache.
# Skip this if sharing KV cache with an earlier attention layer.
ops.reshape_and_cache_flash(
key,
value,
key_cache,
value_cache,
attn_metadata.slot_mapping,
self.kv_cache_dtype,
layer._k_scale,
layer._v_scale,
)
if self.kv_cache_dtype.startswith("fp8"):
key_cache = key_cache.view(self.fp8_dtype)
value_cache = value_cache.view(self.fp8_dtype)
assert layer._q_scale_float == 1.0, (
"A non 1.0 q_scale is not currently supported."
)
cu_seqlens_q = attn_metadata.query_start_loc
seqused_k = attn_metadata.seq_lens
max_seqlen_q = attn_metadata.max_query_len
max_seqlen_k = attn_metadata.max_seq_len
block_table = attn_metadata.block_table
descale_shape = (cu_seqlens_q.shape[0] - 1, key.shape[1])
self.unified_attention(
q=query[:num_actual_tokens],
k=key_cache,
v=value_cache,
out=output[:num_actual_tokens],
cu_seqlens_q=cu_seqlens_q,
max_seqlen_q=max_seqlen_q,
seqused_k=seqused_k,
max_seqlen_k=max_seqlen_k,
softmax_scale=self.scale,
causal=True,
alibi_slopes=self.alibi_slopes,
window_size=self.sliding_window,
block_table=block_table,
softcap=self.logits_soft_cap,
q_descale=None, # Not supported
k_descale=layer._k_scale.expand(descale_shape),
v_descale=layer._v_scale.expand(descale_shape),
sinks=self.sinks,
output_scale=output_scale,
)
return output

View File

@ -3,13 +3,10 @@
"""Attention layer with PagedAttention and Triton prefix prefill."""
from dataclasses import dataclass
from functools import cache
from typing import ClassVar, Optional
import torch
from vllm import _custom_ops as ops
from vllm import envs
from vllm.attention.backends.abstract import (
AttentionBackend,
AttentionImpl,
@ -96,12 +93,11 @@ class RocmAttentionMetadataBuilder(AttentionMetadataBuilder[RocmAttentionMetadat
# slow, so here we set it to 1.
attn_metadata.seq_lens.fill_(1)
if envs.VLLM_V1_USE_PREFILL_DECODE_ATTENTION:
# Here we set the query start locs to 0. This is to
# cover up an invalid memory access in the prefix_prefil kernel
# that we run into during graph capture (#25985)
common_attn_metadata.query_start_loc.zero_()
common_attn_metadata.query_start_loc_cpu.zero_()
# Here we set the query start locs to 0. This is to
# cover up an invalid memory access in the prefix_prefil kernel
# that we run into during graph capture (#25985)
common_attn_metadata.query_start_loc.zero_()
common_attn_metadata.query_start_loc_cpu.zero_()
return attn_metadata
@ -211,14 +207,6 @@ class RocmAttentionBackend(AttentionBackend):
return RocmAttentionMetadataBuilder
@cache
def use_aiter_unified_attention() -> bool:
"""Check if aiter unified attention should be used."""
# VLLM_ROCM_USE_AITER_MHA needs to set to 0 as well as it is set
# to 1 as default
return envs.VLLM_ROCM_USE_AITER and envs.VLLM_USE_AITER_UNIFIED_ATTENTION
class RocmAttentionImpl(AttentionImpl):
def fused_output_quant_supported(self, quant_key: QuantKey):
return quant_key == kFp8StaticTensorSym
@ -268,23 +256,6 @@ class RocmAttentionImpl(AttentionImpl):
)
self.fp8_dtype = current_platform.fp8_dtype()
self.force_prefill_decode_attn = envs.VLLM_V1_USE_PREFILL_DECODE_ATTENTION
if not self.force_prefill_decode_attn:
# If not using prefill decode attention, we use the Triton
# unified attention implementation.
if use_aiter_unified_attention():
logger.info_once("Using aiter unified attention for RocmAttentionImpl")
from aiter.ops.triton.unified_attention import unified_attention
self.unified_attention = unified_attention
else:
logger.info_once("Using vllm unified attention for RocmAttentionImpl")
from vllm.attention.ops.triton_unified_attention import (
unified_attention,
)
self.unified_attention = unified_attention
self.sinks = sinks
if sinks is not None:
@ -341,58 +312,32 @@ class RocmAttentionImpl(AttentionImpl):
# Whenever making a change in this method, please benchmark the
# performance to make sure it does not introduce any overhead.
use_prefill_decode_attn = self.force_prefill_decode_attn
num_actual_tokens = attn_metadata.num_actual_tokens
if use_prefill_decode_attn:
key_cache, value_cache = PagedAttention.split_kv_cache(
kv_cache, self.num_kv_heads, self.head_size
)
else:
key_cache, value_cache = kv_cache.unbind(0)
key_cache, value_cache = PagedAttention.split_kv_cache(
kv_cache, self.num_kv_heads, self.head_size
)
if self.kv_sharing_target_layer_name is None:
# Reshape the input keys and values and store them in the cache.
# Skip this if sharing KV cache with an earlier attention layer.
if use_prefill_decode_attn:
PagedAttention.write_to_paged_cache(
key,
value,
key_cache,
value_cache,
attn_metadata.slot_mapping,
self.kv_cache_dtype,
layer._k_scale,
layer._v_scale,
)
else:
ops.reshape_and_cache_flash(
key,
value,
key_cache,
value_cache,
attn_metadata.slot_mapping,
self.kv_cache_dtype,
layer._k_scale,
layer._v_scale,
)
PagedAttention.write_to_paged_cache(
key,
value,
key_cache,
value_cache,
attn_metadata.slot_mapping,
self.kv_cache_dtype,
layer._k_scale,
layer._v_scale,
)
if self.kv_cache_dtype.startswith("fp8"):
key_cache = key_cache.view(self.fp8_dtype)
value_cache = value_cache.view(self.fp8_dtype)
num_tokens, num_heads, head_size = query.shape
assert layer._q_scale_float == 1.0, (
"A non 1.0 q_scale is not currently supported."
)
if current_platform.is_cuda():
# Skip Q quantization on ROCm and XPU, enable this on cuda
# only, since dequantizing back to f32 in the attention kernel
# is not supported.
query, _ = ops.scaled_fp8_quant(
query.reshape((num_tokens, num_heads * head_size)).contiguous(),
layer._q_scale,
)
query = query.reshape((num_tokens, num_heads, head_size))
cu_seqlens_q = attn_metadata.query_start_loc
seqused_k = attn_metadata.seq_lens
@ -400,53 +345,27 @@ class RocmAttentionImpl(AttentionImpl):
max_seqlen_k = attn_metadata.max_seq_len
block_table = attn_metadata.block_table
if use_prefill_decode_attn:
# Compute attention and update output up to `num_actual_tokens`.
chunked_prefill_paged_decode(
query=query[:num_actual_tokens],
key=key[:num_actual_tokens],
value=value[:num_actual_tokens],
output=output[:num_actual_tokens],
kv_cache_dtype=self.kv_cache_dtype,
key_cache=key_cache,
value_cache=value_cache,
block_table=block_table,
query_start_loc=cu_seqlens_q,
seq_lens=seqused_k,
max_seq_len=max_seqlen_k,
max_query_len=max_seqlen_q,
k_scale=layer._k_scale,
v_scale=layer._v_scale,
alibi_slopes=self.alibi_slopes,
sliding_window=self.sliding_window[0],
sm_scale=self.scale,
output_scale=output_scale,
sinks=self.sinks,
)
else:
descale_shape = (cu_seqlens_q.shape[0] - 1, key.shape[1])
self.unified_attention(
q=query[:num_actual_tokens],
k=key_cache,
v=value_cache,
out=output[:num_actual_tokens],
cu_seqlens_q=cu_seqlens_q,
max_seqlen_q=max_seqlen_q,
seqused_k=seqused_k,
max_seqlen_k=max_seqlen_k,
softmax_scale=self.scale,
causal=True,
alibi_slopes=self.alibi_slopes,
window_size=self.sliding_window,
block_table=block_table,
softcap=self.logits_soft_cap,
q_descale=None, # Not supported
k_descale=layer._k_scale.expand(descale_shape),
v_descale=layer._v_scale.expand(descale_shape),
sinks=self.sinks,
output_scale=output_scale,
)
# Compute attention and update output up to `num_actual_tokens`.
chunked_prefill_paged_decode(
query=query[:num_actual_tokens],
key=key[:num_actual_tokens],
value=value[:num_actual_tokens],
output=output[:num_actual_tokens],
kv_cache_dtype=self.kv_cache_dtype,
key_cache=key_cache,
value_cache=value_cache,
block_table=block_table,
query_start_loc=cu_seqlens_q,
seq_lens=seqused_k,
max_seq_len=max_seqlen_k,
max_query_len=max_seqlen_q,
k_scale=layer._k_scale,
v_scale=layer._v_scale,
alibi_slopes=self.alibi_slopes,
sliding_window=self.sliding_window[0],
sm_scale=self.scale,
output_scale=output_scale,
sinks=self.sinks,
)
return output