--- a/vllm/v1/attention/backends/mla/flashinfer_mla.py
+++ b/vllm/v1/attention/backends/mla/flashinfer_mla.py
@@ -1,16 +1,18 @@
 # SPDX-License-Identifier: Apache-2.0
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
 
-from typing import Optional, Union
+from typing import ClassVar, Optional, Union
 
 import torch
 from flashinfer.decode import trtllm_batch_decode_with_kv_cache_mla
 
 from vllm.attention.backends.abstract import AttentionLayer, AttentionType
 from vllm.logger import init_logger
+from vllm.v1.attention.backends.utils import AttentionCGSupport
 from vllm.v1.attention.backends.mla.common import (MLACommonBackend,
                                                    MLACommonImpl,
-                                                   MLACommonMetadata)
+                                                   MLACommonMetadata,
+                                                   MLACommonMetadataBuilder)
 
 logger = init_logger(__name__)
 
@@ -23,6 +25,10 @@ class FlashInferMLABackend(MLACommonBackend):
     def get_name() -> str:
         return "FLASHINFER_MLA"
 
+    @staticmethod
+    def get_builder_cls() -> type["FlashInferMLAMetadataBuilder"]:
+        return FlashInferMLAMetadataBuilder
+
     @staticmethod
     def get_impl_cls() -> type["FlashInferMLAImpl"]:
         return FlashInferMLAImpl
@@ -34,6 +40,11 @@ g_fi_workspace = torch.zeros(
     device="cuda",
 )
 
+class FlashInferMLAMetadataBuilder(MLACommonMetadataBuilder[MLACommonMetadata]):
+    cudagraph_support: ClassVar[
+        AttentionCGSupport] = AttentionCGSupport.UNIFORM_SINGLE_TOKEN_DECODE
+    pass
+
 
 class FlashInferMLAImpl(MLACommonImpl[MLACommonMetadata]):
 
--- a/vllm/v1/attention/backends/mla/triton_mla.py
+++ b/vllm/v1/attention/backends/mla/triton_mla.py
@@ -1,7 +1,7 @@
 # SPDX-License-Identifier: Apache-2.0
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
 
-from typing import Optional, Union
+from typing import ClassVar, Optional, Union
 
 import torch
 
@@ -13,9 +13,11 @@ from vllm.attention.ops.triton_flash_attention import triton_attention
 from vllm.logger import init_logger
 from vllm.platforms import current_platform
 from vllm.triton_utils import HAS_TRITON
+from vllm.v1.attention.backends.utils import AttentionCGSupport
 from vllm.v1.attention.backends.mla.common import (MLACommonBackend,
                                                    MLACommonImpl,
-                                                   MLACommonMetadata)
+                                                   MLACommonMetadata,
+                                                   MLACommonMetadataBuilder)
 
 logger = init_logger(__name__)
 
@@ -24,12 +26,21 @@ class TritonMLABackend(MLACommonBackend):
 
     @staticmethod
     def get_name() -> str:
-        return "TRITON_MLA"
+        return "TRITON_MLA_VLLM_V1"
 
+    @staticmethod
+    def get_builder_cls() -> type["TritonMLAMetadataBuilder"]:
+        return TritonMLAMetadataBuilder
+
     @staticmethod
     def get_impl_cls() -> type["TritonMLAImpl"]:
         return TritonMLAImpl
 
 
+class TritonMLAMetadataBuilder(MLACommonMetadataBuilder[MLACommonMetadata]):
+    cudagraph_support: ClassVar[AttentionCGSupport] = \
+        AttentionCGSupport.UNIFORM_BATCH
+    pass
+
 class TritonMLAImpl(MLACommonImpl[MLACommonMetadata]):
     can_return_lse_for_decode: bool = True
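
For readers skimming the diff: both backends now follow the same pattern of publishing a metadata-builder class whose cudagraph_support class attribute advertises which CUDA-graph capture mode the backend tolerates, and pointing at that builder through get_builder_cls(). The sketch below illustrates the pattern on its own; the enum values and base classes are simplified stand-ins for the real vLLM types, kept only so the snippet runs in isolation.

# Minimal, self-contained sketch (not the real vLLM classes).
import enum
from typing import ClassVar


class AttentionCGSupport(enum.Enum):  # simplified stand-in for vLLM's enum
    NEVER = 0
    UNIFORM_SINGLE_TOKEN_DECODE = 1
    UNIFORM_BATCH = 2


class MLACommonMetadataBuilder:  # simplified stand-in for the common builder
    # Default: no CUDA-graph support unless a subclass overrides this.
    cudagraph_support: ClassVar[AttentionCGSupport] = AttentionCGSupport.NEVER


class TritonMLAMetadataBuilder(MLACommonMetadataBuilder):
    # Mirrors the value added in the diff above.
    cudagraph_support: ClassVar[AttentionCGSupport] = \
        AttentionCGSupport.UNIFORM_BATCH


class TritonMLABackend:  # simplified stand-in for the backend class
    @staticmethod
    def get_builder_cls() -> type[TritonMLAMetadataBuilder]:
        return TritonMLAMetadataBuilder


# Engine-side code can query support without instantiating anything:
builder_cls = TritonMLABackend.get_builder_cls()
print(builder_cls.cudagraph_support)  # AttentionCGSupport.UNIFORM_BATCH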