Revert "[v1] Add fp32 support to v1 engine through flex attn" (#19404)

This commit is contained in:
Isotr0py
2025-06-10 16:30:20 +08:00
committed by GitHub
parent 9368cc90b2
commit 5f1ac1e1d1
4 changed files with 7 additions and 38 deletions

View File

@ -183,34 +183,6 @@ def test_env(
assert backend.get_name() == expected
@pytest.mark.parametrize("device", ["cpu", "cuda"])
@pytest.mark.parametrize("use_v1", [True, False])
def test_fp32_fallback(
device: str,
use_v1: bool,
monkeypatch: pytest.MonkeyPatch,
):
"""Test attention backend selection with fp32."""
with monkeypatch.context() as m:
m.setenv("VLLM_USE_V1", "1" if use_v1 else "0")
if device == "cpu":
with patch("vllm.attention.selector.current_platform",
CpuPlatform()):
backend = get_attn_backend(16, torch.float32, torch.float32,
16, False)
assert (backend.get_name() == "TORCH_SDPA_VLLM_V1"
if use_v1 else "TORCH_SDPA")
elif device == "cuda":
with patch("vllm.attention.selector.current_platform",
CudaPlatform()):
backend = get_attn_backend(16, torch.float32, torch.float32,
16, False)
assert (backend.get_name() == "FLEX_ATTENTION"
if use_v1 else "XFORMERS")
def test_flash_attn(monkeypatch: pytest.MonkeyPatch):
"""Test FlashAttn validation."""
# TODO: When testing for v1, pipe in `use_v1` as an argument to

View File

@ -4,7 +4,6 @@
import random
import pytest
import torch
from vllm.attention import Attention
from vllm.config import (CacheConfig, ModelConfig, ParallelConfig,
@ -400,7 +399,6 @@ def test_load_model_weights_inplace(dist_init, model_runner, model_runner_2):
def test_init_kv_cache_with_kv_sharing_invalid_target_layer_order():
torch.set_default_dtype(torch.float16)
layer_0 = "model.layers.0.self_attn.attn"
layer_1 = "model.layers.1.self_attn.attn"
error_msg = f"{layer_1} must come before the current layer"
@ -429,7 +427,6 @@ def test_init_kv_cache_with_kv_sharing_invalid_target_layer_order():
def test_init_kv_cache_with_kv_sharing_target_layer_not_exist():
torch.set_default_dtype(torch.float16)
layer_0 = "model.layers.0.self_attn.attn"
layer_1 = "model.layers.1.self_attn.attn"
invalid_layer = "model.layers.0.cross_attn.attn"
@ -458,7 +455,6 @@ def test_init_kv_cache_with_kv_sharing_target_layer_not_exist():
def test_init_kv_cache_with_kv_sharing_target_same_as_current():
torch.set_default_dtype(torch.float16)
layer_0 = "model.layers.0.self_attn.attn"
layer_1 = "model.layers.1.self_attn.attn"
error_msg = f"{layer_1} cannot be the same as the current layer"
@ -487,7 +483,6 @@ def test_init_kv_cache_with_kv_sharing_target_same_as_current():
def test_init_kv_cache_without_kv_sharing():
torch.set_default_dtype(torch.float16)
layer_0 = "model.layers.0.self_attn.attn"
layer_1 = "model.layers.1.self_attn.attn"
vllm_config = get_vllm_config()
@ -555,7 +550,6 @@ def test_init_kv_cache_without_kv_sharing():
def test_init_kv_cache_with_kv_sharing_valid():
torch.set_default_dtype(torch.float16)
layer_0 = "model.layers.0.self_attn.attn"
layer_1 = "model.layers.1.self_attn.attn"
vllm_config = get_vllm_config()

View File

@ -1337,6 +1337,13 @@ class EngineArgs:
recommend_to_remove=False)
return False
# Only Fp16 and Bf16 dtypes since we only support FA.
V1_SUPPORTED_DTYPES = [torch.bfloat16, torch.float16]
if model_config.dtype not in V1_SUPPORTED_DTYPES:
_raise_or_fallback(feature_name=f"--dtype {model_config.dtype}",
recommend_to_remove=False)
return False
# No Embedding Models so far.
if model_config.task not in ["generate"]:
_raise_or_fallback(feature_name=f"--task {model_config.task}",

View File

@ -233,10 +233,6 @@ class CudaPlatformBase(Platform):
logger.info_once("Using Triton backend on V1 engine.")
return ("vllm.v1.attention.backends."
"triton_attn.TritonAttentionBackend")
if dtype not in (torch.float16, torch.bfloat16):
logger.info_once(
f"Using FlexAttenion backend for {dtype} on V1 engine.")
return "vllm.v1.attention.backends.flex_attention.FlexAttentionBackend" # noqa: E501
if cls.is_device_capability(100):
# Prefer FlashInfer for V1 on Blackwell GPUs if installed
try: