Fixes IMA for TP w/ flex-attention (#19712)

Signed-off-by: drisspg <drisspguessous@gmail.com>
Author: Driss Guessous
Date: 2025-06-16 21:01:50 -07:00
Committed by: GitHub
Parent: 5b3ad5ecf2
Commit: ddfed314f9
2 changed files with 2 additions and 10 deletions
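
In short: build_block_mask now always goes through the compiled block-mask builder and passes the target device explicitly, instead of falling back to the eager create_block_mask whenever tensor parallelism is enabled. A minimal standalone sketch of that pattern in plain PyTorch follows (not the vLLM code itself; the causal mask_mod and the sequence lengths are illustrative):

# Illustrative sketch only, assuming PyTorch >= 2.5 with flex attention available.
import torch
from torch.nn.attention.flex_attention import create_block_mask

def causal_mask_mod(b, h, q_idx, kv_idx):
    # Example mask_mod: standard causal masking.
    return q_idx >= kv_idx

device = "cuda" if torch.cuda.is_available() else "cpu"
# Compile the builder once (the fast path) and pass the device explicitly,
# mirroring the explicit device argument in the new code in the diff below.
create_block_mask_compiled = torch.compile(create_block_mask)

block_mask = create_block_mask_compiled(
    causal_mask_mod,
    None,   # B: broadcast over batch
    None,   # H: broadcast over heads
    1024,   # Q_LEN (num_actual_tokens in the diff below)
    1024,   # KV_LEN (total_cache_tokens in the diff below)
    device=device,
)
print(block_mask)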

@@ -51,7 +51,6 @@ def test_flex_attention_vs_default_backend(monkeypatch):
     with monkeypatch.context() as m:
         m.setenv("VLLM_USE_V1", "1")
         m.setenv("VLLM_ATTENTION_BACKEND", "FLEX_ATTENTION")
-        m.setenv("VLLM_ENABLE_V1_MULTIPROCESSING", "0")
         set_seed(seed)
@@ -66,7 +65,6 @@ def test_flex_attention_vs_default_backend(monkeypatch):
     # Run with default backend
     with monkeypatch.context() as m:
         m.setenv("VLLM_USE_V1", "1")
-        m.setenv("VLLM_ENABLE_V1_MULTIPROCESSING", "0")
         set_seed(seed)
         llm_default = LLM(
             model_name,
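
For context, the configuration this commit targets (the flex-attention backend under tensor parallelism) can be exercised roughly as follows. This is only a usage sketch: the model name and TP size are illustrative, not taken from the test.

# Illustrative only: run vLLM's flex-attention backend with TP > 1,
# the combination that previously hit the illegal memory access (IMA).
import os

os.environ["VLLM_USE_V1"] = "1"
os.environ["VLLM_ATTENTION_BACKEND"] = "FLEX_ATTENTION"

from vllm import LLM, SamplingParams

llm = LLM(
    model="Qwen/Qwen2.5-1.5B-Instruct",  # illustrative model name
    tensor_parallel_size=2,              # TP > 1 used to trigger the IMA
)
outputs = llm.generate(["Hello, world"], SamplingParams(max_tokens=16))
print(outputs[0].outputs[0].text)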

@@ -13,7 +13,6 @@ from torch.nn.attention.flex_attention import (BlockMask, _mask_mod_signature,
 from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
                                               AttentionMetadata, AttentionType,
                                               is_quantized_kv_cache)
-from vllm.distributed import get_tensor_model_parallel_world_size
 from vllm.logger import init_logger
 from vllm.platforms import current_platform
 from vllm.v1.attention.backends.utils import (AttentionMetadataBuilder,
@@ -237,17 +236,13 @@ class FlexAttentionMetadata:
     def build_block_mask(self) -> BlockMask:
         assert self.mask_mod is not None
-        # FIXME: With TP>1, create_block_mask_compiled will raise
-        # CUDA error: an illegal memory access was encountered
-        create_block_mask_fn = (create_block_mask_compiled
-                                if get_tensor_model_parallel_world_size() == 1
-                                else create_block_mask)
-        return create_block_mask_fn(
+        return create_block_mask_compiled(
             self.mask_mod,
             None,
             None,
             self.num_actual_tokens,
             self.total_cache_tokens,
+            device=self.block_table.device,
         )

     def __post_init__(self):
@@ -429,7 +424,6 @@ class FlexAttentionImpl(AttentionImpl):
             shape = [num_tokens, num_heads * head_size]
         """
         assert output is not None, "Output tensor must be provided."
         if output_scale is not None:
             raise NotImplementedError(
                 "fused output quantization is not yet supported"