mirror of
https://github.com/vllm-project/vllm.git
synced 2025-10-20 23:03:52 +08:00
@ -1,16 +1,18 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
from typing import Optional, Union
|
||||
from typing import ClassVar, Optional, Union
|
||||
|
||||
import torch
|
||||
from flashinfer.decode import trtllm_batch_decode_with_kv_cache_mla
|
||||
|
||||
from vllm.attention.backends.abstract import AttentionLayer, AttentionType
|
||||
from vllm.logger import init_logger
|
||||
from vllm.v1.attention.backends.utils import AttentionCGSupport
|
||||
from vllm.v1.attention.backends.mla.common import (MLACommonBackend,
|
||||
MLACommonImpl,
|
||||
MLACommonMetadata)
|
||||
MLACommonMetadata,
|
||||
MLACommonMetadataBuilder)
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
@ -23,6 +25,10 @@ class FlashInferMLABackend(MLACommonBackend):
|
||||
def get_name() -> str:
    """Return the registry identifier for the FlashInfer MLA backend."""
    backend_name = "FLASHINFER_MLA"
    return backend_name
|
||||
|
||||
@staticmethod
def get_builder_cls() -> type["FlashInferMLAMetadataBuilder"]:
    """Expose the metadata-builder class paired with this backend."""
    builder_cls = FlashInferMLAMetadataBuilder
    return builder_cls
|
||||
|
||||
@staticmethod
def get_impl_cls() -> type["FlashInferMLAImpl"]:
    """Expose the attention implementation class paired with this backend."""
    impl_cls = FlashInferMLAImpl
    return impl_cls
|
||||
@ -34,6 +40,11 @@ g_fi_workspace = torch.zeros(
|
||||
device="cuda",
|
||||
)
|
||||
|
||||
class FlashInferMLAMetadataBuilder(MLACommonMetadataBuilder[MLACommonMetadata]):
    """Metadata builder for the FlashInfer MLA backend.

    Inherits all building behavior from ``MLACommonMetadataBuilder``; this
    subclass only declares the CUDA-graph support level for the backend.
    """

    # NOTE(review): UNIFORM_SINGLE_TOKEN_DECODE presumably means CUDA graphs
    # are captured only for uniform single-token decode batches — confirm
    # against AttentionCGSupport in vllm.v1.attention.backends.utils.
    # The trailing `pass` from the original (a leftover from when the class
    # body was empty) is removed: it was dead code once this attribute existed.
    cudagraph_support: ClassVar[
        AttentionCGSupport] = AttentionCGSupport.UNIFORM_SINGLE_TOKEN_DECODE
|
||||
|
||||
|
||||
class FlashInferMLAImpl(MLACommonImpl[MLACommonMetadata]):
|
||||
|
||||
|
@ -1,7 +1,7 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
from typing import Optional, Union
|
||||
from typing import ClassVar, Optional, Union
|
||||
|
||||
import torch
|
||||
|
||||
@ -13,9 +13,11 @@ from vllm.attention.ops.triton_flash_attention import triton_attention
|
||||
from vllm.logger import init_logger
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.triton_utils import HAS_TRITON
|
||||
from vllm.v1.attention.backends.utils import AttentionCGSupport
|
||||
from vllm.v1.attention.backends.mla.common import (MLACommonBackend,
|
||||
MLACommonImpl,
|
||||
MLACommonMetadata)
|
||||
MLACommonMetadata,
|
||||
MLACommonMetadataBuilder)
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
@ -24,12 +26,21 @@ class TritonMLABackend(MLACommonBackend):
|
||||
|
||||
@staticmethod
|
||||
def get_name() -> str:
|
||||
return "TRITON_MLA"
|
||||
return "TRITON_MLA_VLLM_V1"
|
||||
|
||||
@staticmethod
def get_builder_cls() -> type["TritonMLAMetadataBuilder"]:
    """Expose the metadata-builder class paired with this backend."""
    builder_cls = TritonMLAMetadataBuilder
    return builder_cls
|
||||
|
||||
@staticmethod
def get_impl_cls() -> type["TritonMLAImpl"]:
    """Expose the attention implementation class paired with this backend."""
    impl_cls = TritonMLAImpl
    return impl_cls
|
||||
|
||||
|
||||
|
||||
class TritonMLAMetadataBuilder(MLACommonMetadataBuilder[MLACommonMetadata]):
    """Metadata builder for the Triton MLA backend.

    Inherits all building behavior from ``MLACommonMetadataBuilder``; this
    subclass only declares the CUDA-graph support level for the backend.
    """

    # NOTE(review): UNIFORM_BATCH presumably means CUDA graphs are captured
    # for uniform batches — confirm against AttentionCGSupport in
    # vllm.v1.attention.backends.utils. The trailing `pass` from the original
    # (leftover from an empty class body) is removed as dead code.
    cudagraph_support: ClassVar[AttentionCGSupport] = \
        AttentionCGSupport.UNIFORM_BATCH
|
||||
|
||||
class TritonMLAImpl(MLACommonImpl[MLACommonMetadata]):
|
||||
can_return_lse_for_decode: bool = True
|
||||
|
Reference in New Issue
Block a user