[CI Failure] fix_test_auto_prefix_cache_support (#26053)

Signed-off-by: Huamin Li <3ericli@gmail.com>
This commit is contained in:
Huamin Li
2025-10-04 02:44:49 -07:00
committed by GitHub
parent 7c2e91c4e0
commit 7d6b03381e
2 changed files with 14 additions and 7 deletions

View File

@ -1917,7 +1917,7 @@ def test_priority_scheduling_preemption_when_out_of_kv():
def test_chunked_prefill_disabled_for_encoder_decoder(
enable_chunked_prefill: bool, is_encoder_decoder: bool,
expect_enabled: bool) -> None:
"""Validate that chunked prefill is appropriately disabled for
"""Validate that chunked prefill is appropriately disabled for
encoder-decoder models."""
scheduler_config = SchedulerConfig(
enable_chunked_prefill=enable_chunked_prefill,
@ -1942,7 +1942,7 @@ def test_chunked_prefill_disabled_for_encoder_decoder(
def _validate_chunked_prefill_settings_for_encoder_decoder(
scheduler_config: SchedulerConfig, is_encoder_decoder: bool,
expect_enabled: bool) -> None:
"""Validate chunked prefill settings in the scheduler config for
"""Validate chunked prefill settings in the scheduler config for
encoder-decoder models."""
assert scheduler_config.chunked_prefill_enabled is expect_enabled
assert scheduler_config.enable_chunked_prefill is expect_enabled

View File

@ -396,10 +396,17 @@ class VllmConfig:
"try setting 'VLLM_WORKER_MULTIPROC_METHOD' "
"to 'spawn'.")
# Disable prefix caching only if chunked prefill is explicitly disabled
# (and not merely unset)
if (self.scheduler_config.chunked_prefill_enabled is False
or disable_chunked_prefill_reasons):
# Final off-switch for CP/APC:
# Disable for (a) collected blockers, (b) encoderdecoder, or
# (c) explicit CP=False when APC wasn't requested.
# Do NOT disable merely because the resolved CP flag is False.
apc_requested = (self.cache_config is not None
and self.cache_config.enable_prefix_caching)
if (disable_chunked_prefill_reasons
or (self.model_config is not None
and self.model_config.is_encoder_decoder)
or (self.scheduler_config.enable_chunked_prefill is False
and not apc_requested)):
for reason in disable_chunked_prefill_reasons:
logger.info(reason)
self.scheduler_config.chunked_prefill_enabled = False
@ -668,7 +675,7 @@ class VllmConfig:
f"Model: {self.model_config.model}")
def compile_debug_dump_path(self) -> Optional[Path]:
"""Returns a rank-aware path for dumping
"""Returns a rank-aware path for dumping
torch.compile debug information.
"""
if self.compilation_config.debug_dump_path is None: