mirror of
https://github.com/vllm-project/vllm.git
synced 2025-10-20 14:53:52 +08:00
[ROCm] Split AITER unified attention into its own backend (#25507)
Signed-off-by: Gregory Shtrasberg <Gregory.Shtrasberg@amd.com>
This commit is contained in:
committed by
GitHub
parent
2161efe978
commit
f231e5bc21
@ -7,9 +7,7 @@ import pytest
|
||||
import torch._dynamo
|
||||
|
||||
from tests.compile.backend import LazyInitPass, TestBackend
|
||||
from tests.models.utils import check_outputs_equal
|
||||
from tests.v1.attention.utils import BatchSpec, create_common_attn_metadata
|
||||
from vllm import LLM, SamplingParams
|
||||
from vllm._custom_ops import cutlass_scaled_fp4_mm, scaled_fp4_quant
|
||||
from vllm.attention import Attention, AttentionMetadata
|
||||
from vllm.attention.backends.registry import _Backend
|
||||
@ -31,7 +29,6 @@ from vllm.config import (
|
||||
)
|
||||
from vllm.forward_context import get_forward_context, set_forward_context
|
||||
from vllm.model_executor.layers.quantization.utils.quant_utils import (
|
||||
QuantKey,
|
||||
kFp8StaticTensorSym,
|
||||
kNvfp4Quant,
|
||||
)
|
||||
@ -48,132 +45,6 @@ backend: Optional[TestBackend] = None
|
||||
backend_unfused: Optional[TestBackend] = None
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"model, quant_key", [("amd/Llama-3.1-8B-Instruct-FP8-KV", kFp8StaticTensorSym)]
|
||||
)
|
||||
@pytest.mark.parametrize("use_triton_fa", [True, False])
|
||||
@pytest.mark.skipif(not current_platform.supports_fp8(), reason="Need FP8")
|
||||
@pytest.mark.skipif(
|
||||
not current_platform.is_rocm(), reason="V0 attn quant fusion only on ROCm"
|
||||
)
|
||||
def test_attention_fusion_v0(
|
||||
example_prompts, monkeypatch, model: str, quant_key: QuantKey, use_triton_fa: bool
|
||||
):
|
||||
# Clean Dynamo cache to avoid reusing other test cases
|
||||
# (for some reason the reset at the end is not enough)
|
||||
torch._dynamo.reset()
|
||||
|
||||
# Use global backends
|
||||
global backend, backend_unfused
|
||||
|
||||
monkeypatch.setenv("VLLM_USE_V1", "1")
|
||||
monkeypatch.setenv("VLLM_USE_TRITON_FLASH_ATTN", str(int(use_triton_fa)))
|
||||
|
||||
# Prompt 4 seems too open-ended, differs between fused and unfused
|
||||
# (both outputs look reasonable though)
|
||||
prompts = example_prompts[:4] + example_prompts[5:]
|
||||
|
||||
compile_config = CompilationConfig(
|
||||
# DYNAMO_AS_IS triggers custom backend & does full Dynamo compilation
|
||||
# DYNAMO_ONCE does not properly propagate shapes.
|
||||
level=CompilationLevel.DYNAMO_AS_IS,
|
||||
backend="tests.compile.test_fusion_attn.backend_unfused",
|
||||
custom_ops=["+quant_fp8"],
|
||||
)
|
||||
vllm_config = VllmConfig(
|
||||
compilation_config=compile_config,
|
||||
model_config=ModelConfig(
|
||||
model=model,
|
||||
dtype=torch.bfloat16,
|
||||
),
|
||||
)
|
||||
backend_unfused = TestBackend(NoOpEliminationPass(vllm_config))
|
||||
|
||||
llm = LLM(
|
||||
model,
|
||||
enforce_eager=True,
|
||||
compilation_config=compile_config,
|
||||
gpu_memory_utilization=0.5,
|
||||
max_model_len=2048,
|
||||
)
|
||||
|
||||
sampling_params = SamplingParams(temperature=0.0, max_tokens=10, top_p=0.95)
|
||||
|
||||
unfused_output = llm.generate(prompts, sampling_params)
|
||||
backend_unfused = None # Reset backend to make sure llm gets released
|
||||
del llm
|
||||
|
||||
compile_config = CompilationConfig(
|
||||
# DYNAMO_AS_IS triggers custom backend & does full Dynamo compilation
|
||||
# DYNAMO_ONCE does not properly propagate shapes.
|
||||
level=CompilationLevel.DYNAMO_AS_IS,
|
||||
backend="tests.compile.test_fusion_attn.backend",
|
||||
custom_ops=["+quant_fp8"],
|
||||
)
|
||||
vllm_config = VllmConfig(
|
||||
compilation_config=compile_config,
|
||||
model_config=ModelConfig(
|
||||
model=model,
|
||||
dtype=torch.bfloat16,
|
||||
),
|
||||
)
|
||||
|
||||
# AttnFusionPass needs attention layers to be registered in config upon init
|
||||
# so we initialize it during compilation.
|
||||
attn_pass = LazyInitPass(AttnFusionPass, vllm_config)
|
||||
backend = TestBackend(NoOpEliminationPass(vllm_config), attn_pass)
|
||||
llm2 = LLM(
|
||||
model,
|
||||
enforce_eager=True,
|
||||
compilation_config=compile_config,
|
||||
gpu_memory_utilization=0.5,
|
||||
max_model_len=2048,
|
||||
)
|
||||
|
||||
# check support
|
||||
attn_fusion_supported = [
|
||||
layer.impl.fused_output_quant_supported(quant_key)
|
||||
for key, layer in compile_config.static_forward_context.items()
|
||||
]
|
||||
|
||||
print(f"{attn_fusion_supported=}")
|
||||
if any(attn_fusion_supported):
|
||||
# Check quant ops
|
||||
backend.check_before_ops([QUANT_OPS[quant_key]], fully_replaced=False)
|
||||
|
||||
# attention ops present in both, just output_scale param changes
|
||||
attn_nodes_pre = list(find_op_nodes(ATTN_OP, backend.graph_pre_pass))
|
||||
attn_nodes_post = list(find_op_nodes(ATTN_OP, backend.graph_post_pass))
|
||||
assert len(attn_nodes_pre) == len(attn_nodes_post)
|
||||
|
||||
for i in range(len(attn_nodes_pre)):
|
||||
assert attn_nodes_pre[i].kwargs["output_scale"] is None
|
||||
fused = attn_nodes_post[i].kwargs["output_scale"] is not None
|
||||
assert fused == attn_fusion_supported[i], (
|
||||
f"Node {i} {'' if fused else 'not '} expected to have fused output quant"
|
||||
)
|
||||
|
||||
# check outputs
|
||||
fused_output = llm2.generate(prompts, sampling_params)
|
||||
|
||||
# transform outputs to format expected by check_outputs_equal
|
||||
sample_outs = lambda s: (list(s.token_ids), s.text)
|
||||
outs_lst = lambda ros: [sample_outs(ro.outputs[0]) for ro in ros]
|
||||
|
||||
check_outputs_equal(
|
||||
outputs_0_lst=outs_lst(unfused_output),
|
||||
outputs_1_lst=outs_lst(fused_output),
|
||||
name_0="unfused",
|
||||
name_1="fused",
|
||||
)
|
||||
|
||||
# Clean Dynamo cache to avoid polluting other case(s)
|
||||
torch._dynamo.reset()
|
||||
|
||||
# Reset backend to make sure llm2 gets released
|
||||
backend = None
|
||||
|
||||
|
||||
class AttentionQuantPatternModel(torch.nn.Module):
|
||||
"""Base model for AttentionQuantPattern fusion."""
|
||||
|
||||
@ -221,7 +92,7 @@ class AttentionQuantPatternModel(torch.nn.Module):
|
||||
device=self.device,
|
||||
)
|
||||
|
||||
def build_attn_metadata(self, batch_size: int, use_hnd: bool) -> AttentionMetadata:
|
||||
def build_attn_metadata(self, batch_size: int) -> AttentionMetadata:
|
||||
"""Initialize attention metadata."""
|
||||
|
||||
# Create common attn metadata
|
||||
@ -232,30 +103,57 @@ class AttentionQuantPatternModel(torch.nn.Module):
|
||||
|
||||
max_blocks = (max(batch_spec.seq_lens) + self.block_size - 1) // self.block_size
|
||||
num_blocks = batch_size * max_blocks
|
||||
backend = self.attn.backend
|
||||
|
||||
# Create dummy KV cache for FlashInfer TRTLLM
|
||||
# - NHD: [num_blocks, block_size, num_kv_heads, head_size]
|
||||
# - HND: [num_blocks, num_kv_heads, block_size, head_size]
|
||||
kv_cache = torch.zeros(
|
||||
num_blocks,
|
||||
2,
|
||||
self.num_kv_heads,
|
||||
self.block_size,
|
||||
self.head_size,
|
||||
dtype=self.kv_cache_dtype,
|
||||
device=self.device,
|
||||
)
|
||||
if current_platform.is_rocm():
|
||||
# Create dummy KV cache for the selected backend
|
||||
if backend == _Backend.ROCM_ATTN:
|
||||
# k/v as 1st dimention
|
||||
if use_hnd:
|
||||
kv_cache = kv_cache.permute(1, 0, 2, 3, 4)
|
||||
else:
|
||||
kv_cache = kv_cache.permute(1, 0, 3, 2, 4)
|
||||
else:
|
||||
# HND: [num_blocks, num_kv_heads, block_size, head_size]
|
||||
kv_cache = torch.zeros(
|
||||
2,
|
||||
num_blocks,
|
||||
self.num_kv_heads,
|
||||
self.block_size,
|
||||
self.head_size,
|
||||
dtype=self.kv_cache_dtype,
|
||||
device=self.device,
|
||||
)
|
||||
elif backend == _Backend.ROCM_AITER_UNIFIED_ATTN:
|
||||
# k/v as 1st dimention
|
||||
# NHD: [num_blocks, block_size, num_kv_heads, head_size]
|
||||
kv_cache = torch.zeros(
|
||||
2,
|
||||
num_blocks,
|
||||
self.block_size,
|
||||
self.num_kv_heads,
|
||||
self.head_size,
|
||||
dtype=self.kv_cache_dtype,
|
||||
device=self.device,
|
||||
)
|
||||
elif backend == _Backend.TRITON_ATTN:
|
||||
# k/v as 2nd dimention
|
||||
# Create kv_cache in HND layout and permute to NHD layout
|
||||
# (later will be permuted back to HND layout in forward pass)
|
||||
kv_cache = kv_cache.permute(0, 1, 3, 2, 4)
|
||||
# NHD: [num_blocks, block_size, num_kv_heads, head_size]
|
||||
kv_cache = torch.zeros(
|
||||
num_blocks,
|
||||
2,
|
||||
self.num_kv_heads,
|
||||
self.block_size,
|
||||
self.head_size,
|
||||
dtype=self.kv_cache_dtype,
|
||||
device=self.device,
|
||||
)
|
||||
elif backend == _Backend.FLASHINFER:
|
||||
kv_cache = torch.zeros(
|
||||
num_blocks,
|
||||
2,
|
||||
self.num_kv_heads,
|
||||
self.block_size,
|
||||
self.head_size,
|
||||
dtype=self.kv_cache_dtype,
|
||||
device=self.device,
|
||||
).permute(0, 1, 3, 2, 4)
|
||||
else:
|
||||
raise ValueError(f"Unsupported backend: {backend}")
|
||||
self.attn.kv_cache = [kv_cache]
|
||||
|
||||
# Build attn metadata
|
||||
@ -375,10 +273,9 @@ else:
|
||||
@pytest.mark.parametrize("model_name, model_class", MODELS)
|
||||
@pytest.mark.parametrize(
|
||||
"backend",
|
||||
[_Backend.FLASHINFER] if current_platform.is_cuda() else [_Backend.TRITON_ATTN],
|
||||
)
|
||||
@pytest.mark.parametrize(
|
||||
"split_attention", [False, True] if current_platform.is_rocm() else [False]
|
||||
[_Backend.FLASHINFER]
|
||||
if current_platform.is_cuda()
|
||||
else [_Backend.ROCM_AITER_UNIFIED_ATTN, _Backend.ROCM_ATTN, _Backend.TRITON_ATTN],
|
||||
)
|
||||
# TODO(boyuan): test inductor graph partition on rocm
|
||||
@pytest.mark.parametrize(
|
||||
@ -405,7 +302,6 @@ def test_attention_quant_pattern(
|
||||
model_name: str,
|
||||
model_class: type[AttentionQuantPatternModel],
|
||||
backend: _Backend,
|
||||
split_attention: bool,
|
||||
use_inductor_graph_partition: bool,
|
||||
monkeypatch,
|
||||
dist_init,
|
||||
@ -417,8 +313,6 @@ def test_attention_quant_pattern(
|
||||
pytest.skip("inductor graph partition is only available in PyTorch 2.9+")
|
||||
|
||||
monkeypatch.setenv("VLLM_USE_V1", "1")
|
||||
if split_attention:
|
||||
monkeypatch.setenv("VLLM_V1_USE_PREFILL_DECODE_ATTENTION", "1")
|
||||
|
||||
device = torch.device("cuda:0")
|
||||
torch.manual_seed(42)
|
||||
@ -466,9 +360,7 @@ def test_attention_quant_pattern(
|
||||
model_unfused = model_unfused.to(device)
|
||||
|
||||
forward_ctx = get_forward_context()
|
||||
forward_ctx.attn_metadata = model_unfused.build_attn_metadata(
|
||||
batch_size, use_hnd=split_attention
|
||||
)
|
||||
forward_ctx.attn_metadata = model_unfused.build_attn_metadata(batch_size)
|
||||
|
||||
# Run model directly without compilation and fusion
|
||||
result_unfused = model_unfused(q, k, v)
|
||||
@ -494,9 +386,7 @@ def test_attention_quant_pattern(
|
||||
model_fused = model_fused.to(device)
|
||||
|
||||
forward_ctx = get_forward_context()
|
||||
forward_ctx.attn_metadata = model_fused.build_attn_metadata(
|
||||
batch_size, use_hnd=split_attention
|
||||
)
|
||||
forward_ctx.attn_metadata = model_fused.build_attn_metadata(batch_size)
|
||||
|
||||
# Create test backend with fusion passes enabled
|
||||
noop_pass = NoOpEliminationPass(vllm_config)
|
||||
|
@ -25,3 +25,4 @@ class _Backend(enum.Enum):
|
||||
FLEX_ATTENTION = enum.auto()
|
||||
TREE_ATTN = enum.auto()
|
||||
ROCM_ATTN = enum.auto()
|
||||
ROCM_AITER_UNIFIED_ATTN = enum.auto()
|
||||
|
@ -254,3 +254,4 @@ def global_force_attn_backend_context_manager(
|
||||
finally:
|
||||
# Revert the original global backend override, if any
|
||||
global_force_attn_backend(original_value)
|
||||
_cached_get_attn_backend.cache_clear()
|
||||
|
@ -1623,6 +1623,7 @@ class EngineArgs:
|
||||
"TREE_ATTN",
|
||||
"XFORMERS",
|
||||
"ROCM_ATTN",
|
||||
"ROCM_AITER_UNIFIED_ATTN",
|
||||
]
|
||||
if (
|
||||
envs.is_set("VLLM_ATTENTION_BACKEND")
|
||||
|
13
vllm/envs.py
13
vllm/envs.py
@ -18,7 +18,6 @@ if TYPE_CHECKING:
|
||||
LD_LIBRARY_PATH: Optional[str] = None
|
||||
VLLM_USE_TRITON_FLASH_ATTN: bool = True
|
||||
VLLM_V1_USE_PREFILL_DECODE_ATTENTION: bool = False
|
||||
VLLM_USE_AITER_UNIFIED_ATTENTION: bool = False
|
||||
VLLM_FLASH_ATTN_VERSION: Optional[int] = None
|
||||
LOCAL_RANK: int = 0
|
||||
CUDA_VISIBLE_DEVICES: Optional[str] = None
|
||||
@ -109,6 +108,7 @@ if TYPE_CHECKING:
|
||||
VLLM_ROCM_USE_AITER_FP4_ASM_GEMM: bool = False
|
||||
VLLM_ROCM_USE_TRITON_ROPE: bool = False
|
||||
VLLM_ROCM_USE_AITER_FP8BMM: bool = True
|
||||
VLLM_ROCM_USE_AITER_UNIFIED_ATTENTION: bool = False
|
||||
VLLM_ROCM_USE_SKINNY_GEMM: bool = True
|
||||
VLLM_ROCM_FP8_PADDING: bool = True
|
||||
VLLM_ROCM_MOE_PADDING: bool = True
|
||||
@ -475,10 +475,6 @@ environment_variables: dict[str, Callable[[], Any]] = {
|
||||
os.getenv("VLLM_V1_USE_PREFILL_DECODE_ATTENTION", "False").lower()
|
||||
in ("true", "1")
|
||||
),
|
||||
# Use AITER triton unified attention for V1 attention
|
||||
"VLLM_USE_AITER_UNIFIED_ATTENTION": lambda: (
|
||||
os.getenv("VLLM_USE_AITER_UNIFIED_ATTENTION", "False").lower() in ("true", "1")
|
||||
),
|
||||
# Force vllm to use a specific flash-attention version (2 or 3), only valid
|
||||
# when using the flash-attention backend.
|
||||
"VLLM_FLASH_ATTN_VERSION": lambda: maybe_convert_int(
|
||||
@ -896,6 +892,11 @@ environment_variables: dict[str, Callable[[], Any]] = {
|
||||
"VLLM_ROCM_USE_AITER_FP8BMM": lambda: (
|
||||
os.getenv("VLLM_ROCM_USE_AITER_FP8BMM", "True").lower() in ("true", "1")
|
||||
),
|
||||
# Use AITER triton unified attention for V1 attention
|
||||
"VLLM_ROCM_USE_AITER_UNIFIED_ATTENTION": lambda: (
|
||||
os.getenv("VLLM_ROCM_USE_AITER_UNIFIED_ATTENTION", "False").lower()
|
||||
in ("true", "1")
|
||||
),
|
||||
# use rocm skinny gemms
|
||||
"VLLM_ROCM_USE_SKINNY_GEMM": lambda: (
|
||||
os.getenv("VLLM_ROCM_USE_SKINNY_GEMM", "True").lower() in ("true", "1")
|
||||
@ -1434,7 +1435,6 @@ def compute_hash() -> str:
|
||||
"VLLM_FUSED_MOE_CHUNK_SIZE",
|
||||
"VLLM_FLASHINFER_MOE_BACKEND",
|
||||
"VLLM_V1_USE_PREFILL_DECODE_ATTENTION",
|
||||
"VLLM_USE_AITER_UNIFIED_ATTENTION",
|
||||
"VLLM_ATTENTION_BACKEND",
|
||||
"VLLM_USE_FLASHINFER_SAMPLER",
|
||||
"VLLM_DISABLED_KERNELS",
|
||||
@ -1462,6 +1462,7 @@ def compute_hash() -> str:
|
||||
"VLLM_ROCM_USE_AITER_FP4_ASM_GEMM",
|
||||
"VLLM_ROCM_USE_TRITON_ROPE",
|
||||
"VLLM_ROCM_USE_AITER_FP8BMM",
|
||||
"VLLM_ROCM_USE_AITER_UNIFIED_ATTENTION",
|
||||
"VLLM_ROCM_USE_SKINNY_GEMM",
|
||||
"VLLM_ROCM_FP8_PADDING",
|
||||
"VLLM_ROCM_MOE_PADDING",
|
||||
|
@ -276,25 +276,33 @@ class RocmPlatform(Platform):
|
||||
)
|
||||
|
||||
if envs.VLLM_USE_V1:
|
||||
if envs.VLLM_ROCM_USE_AITER and envs.VLLM_ROCM_USE_AITER_MHA and on_gfx9():
|
||||
logger.info("Using Flash Attention backend on V1 engine.")
|
||||
if (
|
||||
envs.VLLM_ROCM_USE_AITER and envs.VLLM_ROCM_USE_AITER_MHA and on_gfx9()
|
||||
) or selected_backend == _Backend.ROCM_AITER_FA:
|
||||
logger.info("Using Aiter Flash Attention backend on V1 engine.")
|
||||
return (
|
||||
"vllm.v1.attention.backends."
|
||||
"rocm_aiter_fa.AiterFlashAttentionBackend"
|
||||
)
|
||||
elif (
|
||||
(envs.VLLM_ROCM_USE_AITER and envs.VLLM_USE_AITER_UNIFIED_ATTENTION)
|
||||
or envs.VLLM_V1_USE_PREFILL_DECODE_ATTENTION
|
||||
if (
|
||||
envs.VLLM_ROCM_USE_AITER and envs.VLLM_ROCM_USE_AITER_UNIFIED_ATTENTION
|
||||
) or selected_backend == _Backend.ROCM_AITER_UNIFIED_ATTN:
|
||||
logger.info("Using Aiter Unified Attention backend on V1 engine.")
|
||||
return (
|
||||
"vllm.v1.attention.backends."
|
||||
"rocm_aiter_unified_attn.RocmAiterUnifiedAttentionBackend"
|
||||
)
|
||||
if (
|
||||
envs.VLLM_V1_USE_PREFILL_DECODE_ATTENTION
|
||||
or selected_backend == _Backend.ROCM_ATTN
|
||||
):
|
||||
# rocm specific backend, with aiter and/or
|
||||
# triton prefix-prefill
|
||||
logger.info("Using Rocm/Aiter Attention backend on V1 engine.")
|
||||
logger.info("Using Rocm Attention backend on V1 engine.")
|
||||
return "vllm.v1.attention.backends.rocm_attn.RocmAttentionBackend"
|
||||
else:
|
||||
# default case, using triton unified attention
|
||||
logger.info("Using Triton Attention backend on V1 engine.")
|
||||
return "vllm.v1.attention.backends.triton_attn.TritonAttentionBackend"
|
||||
# default case, using triton unified attention
|
||||
logger.info("Using Triton Attention backend on V1 engine.")
|
||||
return "vllm.v1.attention.backends.triton_attn.TritonAttentionBackend"
|
||||
raise RuntimeError(
|
||||
"V0 attention backends have been removed. Set VLLM_USE_V1=1 "
|
||||
"to select a supported backend."
|
||||
|
203
vllm/v1/attention/backends/rocm_aiter_unified_attn.py
Normal file
203
vllm/v1/attention/backends/rocm_aiter_unified_attn.py
Normal file
@ -0,0 +1,203 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
"""Attention layer with PagedAttention and Triton prefix prefill."""
|
||||
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
|
||||
from vllm import _custom_ops as ops
|
||||
from vllm.attention.backends.abstract import AttentionMetadata, AttentionType
|
||||
from vllm.logger import init_logger
|
||||
from vllm.model_executor.layers.quantization.utils.quant_utils import (
|
||||
QuantKey,
|
||||
kFp8StaticTensorSym,
|
||||
)
|
||||
from vllm.v1.attention.backends.flash_attn import FlashAttentionMetadata
|
||||
from vllm.v1.attention.backends.rocm_attn import (
|
||||
RocmAttentionBackend,
|
||||
RocmAttentionImpl,
|
||||
RocmAttentionMetadata,
|
||||
RocmAttentionMetadataBuilder,
|
||||
)
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
class RocmAiterUnifiedAttentionBackend(RocmAttentionBackend):
|
||||
accept_output_buffer: bool = True
|
||||
|
||||
@staticmethod
|
||||
def get_name() -> str:
|
||||
return "ROCM_AITER_UNIFIED_ATTN"
|
||||
|
||||
@staticmethod
|
||||
def get_impl_cls() -> type["RocmAiterUnifiedAttentionImpl"]:
|
||||
return RocmAiterUnifiedAttentionImpl
|
||||
|
||||
@staticmethod
|
||||
def get_metadata_cls() -> type["AttentionMetadata"]:
|
||||
return RocmAttentionMetadata
|
||||
|
||||
@staticmethod
|
||||
def get_kv_cache_shape(
|
||||
num_blocks: int,
|
||||
block_size: int,
|
||||
num_kv_heads: int,
|
||||
head_size: int,
|
||||
cache_dtype_str: str = "auto",
|
||||
) -> tuple[int, ...]:
|
||||
if block_size % 16 != 0:
|
||||
raise ValueError("Block size must be a multiple of 16.")
|
||||
return (2, num_blocks, block_size, num_kv_heads, head_size)
|
||||
|
||||
@staticmethod
|
||||
def use_cascade_attention(*args, **kwargs) -> bool:
|
||||
return False
|
||||
|
||||
@staticmethod
|
||||
def get_builder_cls() -> type["RocmAttentionMetadataBuilder"]:
|
||||
return RocmAttentionMetadataBuilder
|
||||
|
||||
|
||||
class RocmAiterUnifiedAttentionImpl(RocmAttentionImpl):
|
||||
def fused_output_quant_supported(self, quant_key: QuantKey):
|
||||
return quant_key == kFp8StaticTensorSym
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
num_heads: int,
|
||||
head_size: int,
|
||||
scale: float,
|
||||
num_kv_heads: int,
|
||||
alibi_slopes: Optional[list[float]],
|
||||
sliding_window: Optional[int],
|
||||
kv_cache_dtype: str,
|
||||
logits_soft_cap: Optional[float] = None,
|
||||
attn_type: AttentionType = AttentionType.DECODER,
|
||||
kv_sharing_target_layer_name: Optional[int] = None,
|
||||
sinks: Optional[torch.Tensor] = None,
|
||||
) -> None:
|
||||
super().__init__(
|
||||
num_heads,
|
||||
head_size,
|
||||
scale,
|
||||
num_kv_heads,
|
||||
alibi_slopes,
|
||||
sliding_window,
|
||||
kv_cache_dtype,
|
||||
logits_soft_cap,
|
||||
attn_type,
|
||||
kv_sharing_target_layer_name,
|
||||
sinks,
|
||||
)
|
||||
logger.info_once(
|
||||
"Using aiter unified attention for RocmAiterUnifiedAttentionImpl"
|
||||
)
|
||||
from aiter.ops.triton.unified_attention import unified_attention
|
||||
|
||||
self.unified_attention = unified_attention
|
||||
|
||||
def forward(
|
||||
self,
|
||||
layer: torch.nn.Module,
|
||||
query: torch.Tensor,
|
||||
key: torch.Tensor,
|
||||
value: torch.Tensor,
|
||||
kv_cache: torch.Tensor,
|
||||
attn_metadata: FlashAttentionMetadata,
|
||||
output: Optional[torch.Tensor] = None,
|
||||
output_scale: Optional[torch.Tensor] = None,
|
||||
output_block_scale: Optional[torch.Tensor] = None,
|
||||
) -> torch.Tensor:
|
||||
"""Forward pass with FlashAttention.
|
||||
|
||||
Args:
|
||||
query: shape = [num_tokens, num_heads, head_size]
|
||||
key: shape = [num_tokens, num_kv_heads, head_size]
|
||||
value: shape = [num_tokens, num_kv_heads, head_size]
|
||||
kv_cache: shape =
|
||||
[2, num_blocks, block_size, num_kv_heads, head_size]
|
||||
attn_metadata: Metadata for attention.
|
||||
Returns:
|
||||
shape = [num_tokens, num_heads * head_size]
|
||||
"""
|
||||
assert output is not None, "Output tensor must be provided."
|
||||
|
||||
if output_block_scale is not None:
|
||||
raise NotImplementedError(
|
||||
"fused block_scale output quantization is not yet supported"
|
||||
" for RocmAttentionImpl"
|
||||
)
|
||||
|
||||
if attn_metadata is None:
|
||||
# Profiling run.
|
||||
return output
|
||||
|
||||
assert attn_metadata.use_cascade is False
|
||||
|
||||
# IMPORTANT!
|
||||
# NOTE(woosuk): With piece-wise CUDA graphs, this method is executed in
|
||||
# eager-mode PyTorch. Thus, we need to be careful about any CPU overhead
|
||||
# in this method. For example, `view` and `slice` (or `[:n]`) operations
|
||||
# are surprisingly slow even in the case they do not invoke any GPU ops.
|
||||
# Minimize the PyTorch ops in this method as much as possible.
|
||||
# Whenever making a change in this method, please benchmark the
|
||||
# performance to make sure it does not introduce any overhead.
|
||||
|
||||
num_actual_tokens = attn_metadata.num_actual_tokens
|
||||
|
||||
key_cache, value_cache = kv_cache.unbind(0)
|
||||
|
||||
if self.kv_sharing_target_layer_name is None:
|
||||
# Reshape the input keys and values and store them in the cache.
|
||||
# Skip this if sharing KV cache with an earlier attention layer.
|
||||
ops.reshape_and_cache_flash(
|
||||
key,
|
||||
value,
|
||||
key_cache,
|
||||
value_cache,
|
||||
attn_metadata.slot_mapping,
|
||||
self.kv_cache_dtype,
|
||||
layer._k_scale,
|
||||
layer._v_scale,
|
||||
)
|
||||
|
||||
if self.kv_cache_dtype.startswith("fp8"):
|
||||
key_cache = key_cache.view(self.fp8_dtype)
|
||||
value_cache = value_cache.view(self.fp8_dtype)
|
||||
assert layer._q_scale_float == 1.0, (
|
||||
"A non 1.0 q_scale is not currently supported."
|
||||
)
|
||||
|
||||
cu_seqlens_q = attn_metadata.query_start_loc
|
||||
seqused_k = attn_metadata.seq_lens
|
||||
max_seqlen_q = attn_metadata.max_query_len
|
||||
max_seqlen_k = attn_metadata.max_seq_len
|
||||
block_table = attn_metadata.block_table
|
||||
|
||||
descale_shape = (cu_seqlens_q.shape[0] - 1, key.shape[1])
|
||||
|
||||
self.unified_attention(
|
||||
q=query[:num_actual_tokens],
|
||||
k=key_cache,
|
||||
v=value_cache,
|
||||
out=output[:num_actual_tokens],
|
||||
cu_seqlens_q=cu_seqlens_q,
|
||||
max_seqlen_q=max_seqlen_q,
|
||||
seqused_k=seqused_k,
|
||||
max_seqlen_k=max_seqlen_k,
|
||||
softmax_scale=self.scale,
|
||||
causal=True,
|
||||
alibi_slopes=self.alibi_slopes,
|
||||
window_size=self.sliding_window,
|
||||
block_table=block_table,
|
||||
softcap=self.logits_soft_cap,
|
||||
q_descale=None, # Not supported
|
||||
k_descale=layer._k_scale.expand(descale_shape),
|
||||
v_descale=layer._v_scale.expand(descale_shape),
|
||||
sinks=self.sinks,
|
||||
output_scale=output_scale,
|
||||
)
|
||||
|
||||
return output
|
@ -3,13 +3,10 @@
|
||||
"""Attention layer with PagedAttention and Triton prefix prefill."""
|
||||
|
||||
from dataclasses import dataclass
|
||||
from functools import cache
|
||||
from typing import ClassVar, Optional
|
||||
|
||||
import torch
|
||||
|
||||
from vllm import _custom_ops as ops
|
||||
from vllm import envs
|
||||
from vllm.attention.backends.abstract import (
|
||||
AttentionBackend,
|
||||
AttentionImpl,
|
||||
@ -96,12 +93,11 @@ class RocmAttentionMetadataBuilder(AttentionMetadataBuilder[RocmAttentionMetadat
|
||||
# slow, so here we set it to 1.
|
||||
attn_metadata.seq_lens.fill_(1)
|
||||
|
||||
if envs.VLLM_V1_USE_PREFILL_DECODE_ATTENTION:
|
||||
# Here we set the query start locs to 0. This is to
|
||||
# cover up an invalid memory access in the prefix_prefil kernel
|
||||
# that we run into during graph capture (#25985)
|
||||
common_attn_metadata.query_start_loc.zero_()
|
||||
common_attn_metadata.query_start_loc_cpu.zero_()
|
||||
# Here we set the query start locs to 0. This is to
|
||||
# cover up an invalid memory access in the prefix_prefil kernel
|
||||
# that we run into during graph capture (#25985)
|
||||
common_attn_metadata.query_start_loc.zero_()
|
||||
common_attn_metadata.query_start_loc_cpu.zero_()
|
||||
|
||||
return attn_metadata
|
||||
|
||||
@ -211,14 +207,6 @@ class RocmAttentionBackend(AttentionBackend):
|
||||
return RocmAttentionMetadataBuilder
|
||||
|
||||
|
||||
@cache
|
||||
def use_aiter_unified_attention() -> bool:
|
||||
"""Check if aiter unified attention should be used."""
|
||||
# VLLM_ROCM_USE_AITER_MHA needs to set to 0 as well as it is set
|
||||
# to 1 as default
|
||||
return envs.VLLM_ROCM_USE_AITER and envs.VLLM_USE_AITER_UNIFIED_ATTENTION
|
||||
|
||||
|
||||
class RocmAttentionImpl(AttentionImpl):
|
||||
def fused_output_quant_supported(self, quant_key: QuantKey):
|
||||
return quant_key == kFp8StaticTensorSym
|
||||
@ -268,23 +256,6 @@ class RocmAttentionImpl(AttentionImpl):
|
||||
)
|
||||
|
||||
self.fp8_dtype = current_platform.fp8_dtype()
|
||||
self.force_prefill_decode_attn = envs.VLLM_V1_USE_PREFILL_DECODE_ATTENTION
|
||||
|
||||
if not self.force_prefill_decode_attn:
|
||||
# If not using prefill decode attention, we use the Triton
|
||||
# unified attention implementation.
|
||||
if use_aiter_unified_attention():
|
||||
logger.info_once("Using aiter unified attention for RocmAttentionImpl")
|
||||
from aiter.ops.triton.unified_attention import unified_attention
|
||||
|
||||
self.unified_attention = unified_attention
|
||||
else:
|
||||
logger.info_once("Using vllm unified attention for RocmAttentionImpl")
|
||||
from vllm.attention.ops.triton_unified_attention import (
|
||||
unified_attention,
|
||||
)
|
||||
|
||||
self.unified_attention = unified_attention
|
||||
|
||||
self.sinks = sinks
|
||||
if sinks is not None:
|
||||
@ -341,58 +312,32 @@ class RocmAttentionImpl(AttentionImpl):
|
||||
# Whenever making a change in this method, please benchmark the
|
||||
# performance to make sure it does not introduce any overhead.
|
||||
|
||||
use_prefill_decode_attn = self.force_prefill_decode_attn
|
||||
num_actual_tokens = attn_metadata.num_actual_tokens
|
||||
|
||||
if use_prefill_decode_attn:
|
||||
key_cache, value_cache = PagedAttention.split_kv_cache(
|
||||
kv_cache, self.num_kv_heads, self.head_size
|
||||
)
|
||||
else:
|
||||
key_cache, value_cache = kv_cache.unbind(0)
|
||||
key_cache, value_cache = PagedAttention.split_kv_cache(
|
||||
kv_cache, self.num_kv_heads, self.head_size
|
||||
)
|
||||
|
||||
if self.kv_sharing_target_layer_name is None:
|
||||
# Reshape the input keys and values and store them in the cache.
|
||||
# Skip this if sharing KV cache with an earlier attention layer.
|
||||
if use_prefill_decode_attn:
|
||||
PagedAttention.write_to_paged_cache(
|
||||
key,
|
||||
value,
|
||||
key_cache,
|
||||
value_cache,
|
||||
attn_metadata.slot_mapping,
|
||||
self.kv_cache_dtype,
|
||||
layer._k_scale,
|
||||
layer._v_scale,
|
||||
)
|
||||
else:
|
||||
ops.reshape_and_cache_flash(
|
||||
key,
|
||||
value,
|
||||
key_cache,
|
||||
value_cache,
|
||||
attn_metadata.slot_mapping,
|
||||
self.kv_cache_dtype,
|
||||
layer._k_scale,
|
||||
layer._v_scale,
|
||||
)
|
||||
PagedAttention.write_to_paged_cache(
|
||||
key,
|
||||
value,
|
||||
key_cache,
|
||||
value_cache,
|
||||
attn_metadata.slot_mapping,
|
||||
self.kv_cache_dtype,
|
||||
layer._k_scale,
|
||||
layer._v_scale,
|
||||
)
|
||||
|
||||
if self.kv_cache_dtype.startswith("fp8"):
|
||||
key_cache = key_cache.view(self.fp8_dtype)
|
||||
value_cache = value_cache.view(self.fp8_dtype)
|
||||
num_tokens, num_heads, head_size = query.shape
|
||||
assert layer._q_scale_float == 1.0, (
|
||||
"A non 1.0 q_scale is not currently supported."
|
||||
)
|
||||
if current_platform.is_cuda():
|
||||
# Skip Q quantization on ROCm and XPU, enable this on cuda
|
||||
# only, since dequantizing back to f32 in the attention kernel
|
||||
# is not supported.
|
||||
query, _ = ops.scaled_fp8_quant(
|
||||
query.reshape((num_tokens, num_heads * head_size)).contiguous(),
|
||||
layer._q_scale,
|
||||
)
|
||||
query = query.reshape((num_tokens, num_heads, head_size))
|
||||
|
||||
cu_seqlens_q = attn_metadata.query_start_loc
|
||||
seqused_k = attn_metadata.seq_lens
|
||||
@ -400,53 +345,27 @@ class RocmAttentionImpl(AttentionImpl):
|
||||
max_seqlen_k = attn_metadata.max_seq_len
|
||||
block_table = attn_metadata.block_table
|
||||
|
||||
if use_prefill_decode_attn:
|
||||
# Compute attention and update output up to `num_actual_tokens`.
|
||||
chunked_prefill_paged_decode(
|
||||
query=query[:num_actual_tokens],
|
||||
key=key[:num_actual_tokens],
|
||||
value=value[:num_actual_tokens],
|
||||
output=output[:num_actual_tokens],
|
||||
kv_cache_dtype=self.kv_cache_dtype,
|
||||
key_cache=key_cache,
|
||||
value_cache=value_cache,
|
||||
block_table=block_table,
|
||||
query_start_loc=cu_seqlens_q,
|
||||
seq_lens=seqused_k,
|
||||
max_seq_len=max_seqlen_k,
|
||||
max_query_len=max_seqlen_q,
|
||||
k_scale=layer._k_scale,
|
||||
v_scale=layer._v_scale,
|
||||
alibi_slopes=self.alibi_slopes,
|
||||
sliding_window=self.sliding_window[0],
|
||||
sm_scale=self.scale,
|
||||
output_scale=output_scale,
|
||||
sinks=self.sinks,
|
||||
)
|
||||
|
||||
else:
|
||||
descale_shape = (cu_seqlens_q.shape[0] - 1, key.shape[1])
|
||||
|
||||
self.unified_attention(
|
||||
q=query[:num_actual_tokens],
|
||||
k=key_cache,
|
||||
v=value_cache,
|
||||
out=output[:num_actual_tokens],
|
||||
cu_seqlens_q=cu_seqlens_q,
|
||||
max_seqlen_q=max_seqlen_q,
|
||||
seqused_k=seqused_k,
|
||||
max_seqlen_k=max_seqlen_k,
|
||||
softmax_scale=self.scale,
|
||||
causal=True,
|
||||
alibi_slopes=self.alibi_slopes,
|
||||
window_size=self.sliding_window,
|
||||
block_table=block_table,
|
||||
softcap=self.logits_soft_cap,
|
||||
q_descale=None, # Not supported
|
||||
k_descale=layer._k_scale.expand(descale_shape),
|
||||
v_descale=layer._v_scale.expand(descale_shape),
|
||||
sinks=self.sinks,
|
||||
output_scale=output_scale,
|
||||
)
|
||||
# Compute attention and update output up to `num_actual_tokens`.
|
||||
chunked_prefill_paged_decode(
|
||||
query=query[:num_actual_tokens],
|
||||
key=key[:num_actual_tokens],
|
||||
value=value[:num_actual_tokens],
|
||||
output=output[:num_actual_tokens],
|
||||
kv_cache_dtype=self.kv_cache_dtype,
|
||||
key_cache=key_cache,
|
||||
value_cache=value_cache,
|
||||
block_table=block_table,
|
||||
query_start_loc=cu_seqlens_q,
|
||||
seq_lens=seqused_k,
|
||||
max_seq_len=max_seqlen_k,
|
||||
max_query_len=max_seqlen_q,
|
||||
k_scale=layer._k_scale,
|
||||
v_scale=layer._v_scale,
|
||||
alibi_slopes=self.alibi_slopes,
|
||||
sliding_window=self.sliding_window[0],
|
||||
sm_scale=self.scale,
|
||||
output_scale=output_scale,
|
||||
sinks=self.sinks,
|
||||
)
|
||||
|
||||
return output
|
||||
|
Reference in New Issue
Block a user