[Qwen][ROCm] Flash Attention Rotary Embeddings (#24642)
Signed-off-by: vllmellm <vllm.ellm@embeddedllm.com>
@@ -2,15 +2,21 @@
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
 
 import math
+from functools import cache
+from importlib.util import find_spec
+from typing import Callable
 
 import torch
 
+from vllm.logger import init_logger
 from vllm.platforms import current_platform
 from vllm.utils import direct_register_custom_op
 
 if current_platform.is_cuda():
     from vllm.vllm_flash_attn.layers.rotary import apply_rotary_emb
 
+logger = init_logger(__name__)
+
 
 # common functions
 def rotate_neox(x: torch.Tensor) -> torch.Tensor:
@@ -65,6 +71,23 @@ def apply_rotary_emb_dispatch(x: torch.Tensor, cos: torch.Tensor,
     return apply_rotary_emb_torch(x, cos, sin, is_neox_style)
 
 
+@cache
+def dispatch_rotary_emb_function() -> Callable[..., torch.Tensor]:
+    if current_platform.is_cuda():
+        return apply_rotary_emb
+
+    if current_platform.is_rocm():
+        if find_spec("flash_attn") is not None:
+            from flash_attn.ops.triton.rotary import apply_rotary
+            return apply_rotary
+        else:
+            logger.warning(
+                "flash_attn is not installed. Falling back to PyTorch "
+                "implementation for rotary embeddings.")
+
+    return apply_rotary_emb_torch
+
+
 # yarn functions
 # Inverse dim formula to find dim based on number of rotations
 def yarn_find_correction_dim(num_rotations: int,
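Not part of the commit: a minimal sketch of how the new dispatcher resolves on each platform, using only the import path that the Qwen model hunk below adds. The mapping in the comments restates the branches of the function above.

# Illustrative sketch only: inspect which rotary kernel
# dispatch_rotary_emb_function() selects on the current platform.
# The result is memoized by @cache, so the platform check runs once per process.
from vllm.model_executor.layers.rotary_embedding.common import (
    dispatch_rotary_emb_function)

rotary_fn = dispatch_rotary_emb_function()
# CUDA                -> vllm.vllm_flash_attn.layers.rotary.apply_rotary_emb
# ROCm + flash_attn   -> flash_attn.ops.triton.rotary.apply_rotary
# otherwise           -> apply_rotary_emb_torch (pure-PyTorch fallback)
print(f"selected rotary kernel: {rotary_fn.__module__}.{rotary_fn.__name__}")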
@@ -50,6 +50,8 @@ from vllm.model_executor.layers.activation import QuickGELU
 from vllm.model_executor.layers.linear import (ColumnParallelLinear,
                                                RowParallelLinear)
 from vllm.model_executor.layers.quantization import QuantizationConfig
+from vllm.model_executor.layers.rotary_embedding.common import (
+    dispatch_rotary_emb_function)
 from vllm.model_executor.model_loader.weight_utils import default_weight_loader
 from vllm.model_executor.models.module_mapping import MultiModelKeys
 from vllm.multimodal import MULTIMODAL_REGISTRY
@@ -63,7 +65,7 @@ from vllm.multimodal.processing import (BaseMultiModalProcessor,
                                         BaseProcessingInfo, PromptReplacement,
                                         PromptUpdate)
 from vllm.multimodal.profiling import BaseDummyInputsBuilder
-from vllm.platforms import _Backend, current_platform
+from vllm.platforms import _Backend
 from vllm.sequence import IntermediateTensors
 from vllm.transformers_utils.tokenizer import AnyTokenizer
 from vllm.utils.tensor_schema import TensorSchema, TensorShape
@@ -272,13 +274,11 @@ def apply_rotary_emb_torch(x: torch.Tensor,
 
 def apply_rotary_pos_emb_vision(t: torch.Tensor,
                                 freqs: torch.Tensor) -> torch.Tensor:
+    rotary_emb_function = dispatch_rotary_emb_function()
     t_ = t.float()
     cos = freqs.cos()
     sin = freqs.sin()
-    apply_rotary_emb = apply_rotary_emb_torch
-    if current_platform.is_cuda():
-        from vllm.vllm_flash_attn.layers.rotary import apply_rotary_emb
-    output = apply_rotary_emb(t_, cos, sin).type_as(t)
+    output = rotary_emb_function(t_, cos, sin).type_as(t)
     return output
 
 
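Also not part of the commit: a hedged end-to-end sketch of the rewritten vision path. The frequency construction, tensor shapes, and dtype are illustrative assumptions; only the (t, freqs) call and the float32 upcast followed by the cast back to the input dtype mirror the hunk above.

# Illustrative sketch only; assumes apply_rotary_pos_emb_vision from the patched
# model file is in scope, and that freqs has shape (seq_len, head_dim // 2).
import torch

head_dim, seq_len, num_heads = 64, 9, 4            # made-up sizes
inv_freq = 1.0 / (10000.0 ** (torch.arange(0, head_dim, 2, dtype=torch.float32)
                              / head_dim))
freqs = torch.outer(torch.arange(seq_len, dtype=torch.float32), inv_freq)

q = torch.randn(1, seq_len, num_heads, head_dim, dtype=torch.float16)
# Upcasts to float32, applies the dispatched rotary kernel (or the torch
# fallback), then restores the input dtype via type_as(t).
q_rot = apply_rotary_pos_emb_vision(q, freqs)
assert q_rot.dtype == q.dtype and q_rot.shape == q.shape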