fix cuda graph (#22721)

Signed-off-by: fsx950223 <fsx950223@outlook.com>
fsx950223
2025-08-20 14:24:37 +08:00
committed by GitHub
parent 8fd920924c
commit d983769c41


@@ -2,7 +2,7 @@
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
 """Attention layer with AiterFlashAttention."""
 from dataclasses import dataclass
-from typing import ClassVar, Optional
+from typing import Optional

 import torch
@@ -11,7 +11,8 @@ from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
 from vllm.config import VllmConfig
 from vllm.logger import init_logger
 from vllm.platforms import current_platform
-from vllm.v1.attention.backends.utils import (AttentionMetadataBuilder,
+from vllm.v1.attention.backends.utils import (AttentionCGSupport,
+                                              AttentionMetadataBuilder,
                                               CommonAttentionMetadata)
 from vllm.v1.kv_cache_interface import AttentionSpec
@@ -231,7 +232,7 @@ class AiterFlashAttentionMetadata:
 class AiterFlashAttentionMetadataBuilder(
         AttentionMetadataBuilder[AiterFlashAttentionMetadata]):
-    full_cudagraph_supported: ClassVar[bool] = True
+    cudagraph_support = AttentionCGSupport.ALWAYS

     def __init__(self, kv_cache_spec: AttentionSpec, layer_names: list[str],
                  vllm_config: VllmConfig, device: torch.device):
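
For context: the functional change swaps the boolean full_cudagraph_supported ClassVar flag for the AttentionCGSupport enum imported from vllm.v1.attention.backends.utils, declaring that the AiterFlashAttention metadata builder supports CUDA-graph capture in all cases. Below is a minimal, self-contained sketch of that pattern; the enum members other than ALWAYS and the supports_full_cudagraph helper are illustrative assumptions, not vLLM's actual API.

# Simplified illustration of declaring CUDA-graph support on an attention
# metadata builder via an enum instead of a boolean ClassVar flag.
# NOTE: members other than ALWAYS and the helper below are assumptions.
import enum
from typing import ClassVar


class AttentionCGSupport(enum.Enum):
    """How much CUDA-graph capture a backend's metadata builder supports."""
    NEVER = enum.auto()             # no CUDA-graph capture (assumed member)
    PURE_DECODE_ONLY = enum.auto()  # decode-only batches (assumed member)
    ALWAYS = enum.auto()            # full CUDA-graph capture, as in this commit


class AiterFlashAttentionMetadataBuilder:
    # Enum-valued declaration replacing the old
    # `full_cudagraph_supported: ClassVar[bool] = True`.
    cudagraph_support: ClassVar[AttentionCGSupport] = AttentionCGSupport.ALWAYS


def supports_full_cudagraph(builder_cls: type) -> bool:
    """Hypothetical helper: does this builder allow full-graph capture?"""
    return getattr(builder_cls, "cudagraph_support",
                   AttentionCGSupport.NEVER) is AttentionCGSupport.ALWAYS


if __name__ == "__main__":
    # The AITER FlashAttention builder advertises full CUDA-graph support.
    assert supports_full_cudagraph(AiterFlashAttentionMetadataBuilder)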