[TPU] Temporary fix for VMEM OOM with long model len: reduce page size (#20278)

Signed-off-by: Chenyaaang <chenyangli@google.com>
This commit is contained in:
Chenyaaang
2025-07-07 22:16:16 -07:00
committed by GitHub
parent 7721ef1786
commit e34d130c16

View File

@ -86,6 +86,12 @@ class PallasAttentionBackend(AttentionBackend):
# spill less likely. Meanwhile we make sure the page size is in [16, 256].
@staticmethod
def get_page_size(vllm_config: VllmConfig) -> int:
# TODO: This is a temporary fix for vmem OOM.
# For long model length, we use 16 page-size to avoid too much
# VMEM spill. A more robust solution should be implemented to
# handle VREG spills.
if vllm_config.model_config.max_model_len > 8192:
return 16
page_size = next_power_of_2(
vllm_config.model_config.max_model_len) // 16
if page_size <= 16: