Compare commits

...

1 Commits

Author SHA1 Message Date
ebfce922f9 full cg support
Signed-off-by: Lucas Wilkinson <lwilkins@redhat.com>
2025-09-26 12:51:46 -07:00
2 changed files with 28 additions and 6 deletions

View File

@ -1,16 +1,18 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from typing import Optional, Union
from typing import ClassVar, Optional, Union
import torch
from flashinfer.decode import trtllm_batch_decode_with_kv_cache_mla
from vllm.attention.backends.abstract import AttentionLayer, AttentionType
from vllm.logger import init_logger
from vllm.v1.attention.backends.utils import AttentionCGSupport
from vllm.v1.attention.backends.mla.common import (MLACommonBackend,
MLACommonImpl,
MLACommonMetadata)
MLACommonMetadata,
MLACommonMetadataBuilder)
logger = init_logger(__name__)
@ -23,6 +25,10 @@ class FlashInferMLABackend(MLACommonBackend):
def get_name() -> str:
return "FLASHINFER_MLA"
@staticmethod
def get_builder_cls() -> type["FlashInferMLAMetadataBuilder"]:
return FlashInferMLAMetadataBuilder
@staticmethod
def get_impl_cls() -> type["FlashInferMLAImpl"]:
return FlashInferMLAImpl
@ -34,6 +40,11 @@ g_fi_workspace = torch.zeros(
device="cuda",
)
class FlashInferMLAMetadataBuilder(MLACommonMetadataBuilder[MLACommonMetadata]):
cudagraph_support: ClassVar[
AttentionCGSupport] = AttentionCGSupport.UNIFORM_SINGLE_TOKEN_DECODE
pass
class FlashInferMLAImpl(MLACommonImpl[MLACommonMetadata]):

View File

@ -1,7 +1,7 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from typing import Optional, Union
from typing import ClassVar, Optional, Union
import torch
@ -13,9 +13,11 @@ from vllm.attention.ops.triton_flash_attention import triton_attention
from vllm.logger import init_logger
from vllm.platforms import current_platform
from vllm.triton_utils import HAS_TRITON
from vllm.v1.attention.backends.utils import AttentionCGSupport
from vllm.v1.attention.backends.mla.common import (MLACommonBackend,
MLACommonImpl,
MLACommonMetadata)
MLACommonMetadata,
MLACommonMetadataBuilder)
logger = init_logger(__name__)
@ -24,12 +26,21 @@ class TritonMLABackend(MLACommonBackend):
@staticmethod
def get_name() -> str:
return "TRITON_MLA"
return "TRITON_MLA_VLLM_V1"
@staticmethod
def get_builder_cls() -> type["TritonMLAMetadataBuilder"]:
return TritonMLAMetadataBuilder
@staticmethod
def get_impl_cls() -> type["TritonMLAImpl"]:
return TritonMLAImpl
class TritonMLAMetadataBuilder(MLACommonMetadataBuilder[MLACommonMetadata]):
cudagraph_support: ClassVar[AttentionCGSupport] = \
AttentionCGSupport.UNIFORM_BATCH
pass
class TritonMLAImpl(MLACommonImpl[MLACommonMetadata]):
can_return_lse_for_decode: bool = True