[Qwen][ROCm] Flash Attention Rotary Embeddings (#24642)

Signed-off-by: vllmellm <vllm.ellm@embeddedllm.com>
This commit is contained in:
vllmellm
2025-10-02 23:26:08 +08:00
committed by GitHub
parent e51de388a2
commit 5e4a8223c6
2 changed files with 28 additions and 5 deletions

View File

@ -2,15 +2,21 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import math
from functools import cache
from importlib.util import find_spec
from typing import Callable
import torch
from vllm.logger import init_logger
from vllm.platforms import current_platform
from vllm.utils import direct_register_custom_op
if current_platform.is_cuda():
from vllm.vllm_flash_attn.layers.rotary import apply_rotary_emb
logger = init_logger(__name__)
# common functions
def rotate_neox(x: torch.Tensor) -> torch.Tensor:
@ -65,6 +71,23 @@ def apply_rotary_emb_dispatch(x: torch.Tensor, cos: torch.Tensor,
return apply_rotary_emb_torch(x, cos, sin, is_neox_style)
@cache
def dispatch_rotary_emb_function() -> Callable[..., torch.Tensor]:
if current_platform.is_cuda():
return apply_rotary_emb
if current_platform.is_rocm():
if find_spec("flash_attn") is not None:
from flash_attn.ops.triton.rotary import apply_rotary
return apply_rotary
else:
logger.warning(
"flash_attn is not installed. Falling back to PyTorch "
"implementation for rotary embeddings.")
return apply_rotary_emb_torch
# yarn functions
# Inverse dim formula to find dim based on number of rotations
def yarn_find_correction_dim(num_rotations: int,

View File

@ -50,6 +50,8 @@ from vllm.model_executor.layers.activation import QuickGELU
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
RowParallelLinear)
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.rotary_embedding.common import (
dispatch_rotary_emb_function)
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.models.module_mapping import MultiModelKeys
from vllm.multimodal import MULTIMODAL_REGISTRY
@ -63,7 +65,7 @@ from vllm.multimodal.processing import (BaseMultiModalProcessor,
BaseProcessingInfo, PromptReplacement,
PromptUpdate)
from vllm.multimodal.profiling import BaseDummyInputsBuilder
from vllm.platforms import _Backend, current_platform
from vllm.platforms import _Backend
from vllm.sequence import IntermediateTensors
from vllm.transformers_utils.tokenizer import AnyTokenizer
from vllm.utils.tensor_schema import TensorSchema, TensorShape
@ -272,13 +274,11 @@ def apply_rotary_emb_torch(x: torch.Tensor,
def apply_rotary_pos_emb_vision(t: torch.Tensor,
freqs: torch.Tensor) -> torch.Tensor:
rotary_emb_function = dispatch_rotary_emb_function()
t_ = t.float()
cos = freqs.cos()
sin = freqs.sin()
apply_rotary_emb = apply_rotary_emb_torch
if current_platform.is_cuda():
from vllm.vllm_flash_attn.layers.rotary import apply_rotary_emb
output = apply_rotary_emb(t_, cos, sin).type_as(t)
output = rotary_emb_function(t_, cos, sin).type_as(t)
return output