[FlashInfer] Cache hyper params in metadata builder (#23732)

Signed-off-by: Woosuk Kwon <woosuk.kwon@berkeley.edu>
Author: Woosuk Kwon
Date: 2025-08-27 03:45:04 -07:00
Committed by: GitHub
Parent: 04ff1e43fb
Commit: 11eddf02f0
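
In short, the builder now reads the FlashInfer hyperparameters from
`self.global_hyperparameters` once in `__init__` and stores them as plain
instance attributes, so the per-step build path no longer repeats the chained
attribute lookups. A minimal sketch of the pattern, using a simplified
stand-in for vLLM's hyperparameters object (the `GlobalHyperparameters`
dataclass and `MetadataBuilder` class below are illustrative, not the actual
vLLM types):

    from dataclasses import dataclass
    from typing import Optional


    @dataclass
    class GlobalHyperparameters:
        # Illustrative stand-in for vLLM's inferred global hyperparameters.
        sm_scale: float
        window_left: int
        logits_soft_cap: Optional[float]
        has_sinks: bool


    class MetadataBuilder:
        def __init__(self, params: GlobalHyperparameters) -> None:
            self.global_hyperparameters = params
            # Cache each field once at construction time so the hot build()
            # path reads flat instance attributes instead of chasing the
            # global_hyperparameters.<field> chain on every call.
            self.sm_scale = params.sm_scale
            self.window_left = params.window_left
            self.logits_soft_cap = params.logits_soft_cap
            self.has_sinks = params.has_sinks

        def build(self) -> dict:
            # Called once per scheduling step; only flat reads remain.
            return {
                "sm_scale": self.sm_scale,
                "window_left": self.window_left,
                "logits_soft_cap": self.logits_soft_cap,
                "has_sinks": self.has_sinks,
            }


    builder = MetadataBuilder(
        GlobalHyperparameters(sm_scale=0.125, window_left=-1,
                              logits_soft_cap=None, has_sinks=False))
    print(builder.build())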

@@ -214,6 +214,10 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
         # TODO: discard this for trtllm-gen backend
         self.global_hyperparameters = infer_global_hyperparameters(
             get_per_layer_parameters(vllm_config, layer_names, FlashInferImpl))
+        self.sm_scale = self.global_hyperparameters.sm_scale
+        self.window_left = self.global_hyperparameters.window_left
+        self.logits_soft_cap = self.global_hyperparameters.logits_soft_cap
+        self.has_sinks = self.global_hyperparameters.has_sinks

         # Preparing persistent buffers (device-side)
         self.paged_kv_indptr = torch.zeros(max_num_reqs + 1,
@@ -381,8 +385,6 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
         )
-        # Check if any layer uses sinks (requires TRTLLM attention)
-        has_sinks = self.global_hyperparameters.has_sinks
         prefill_use_trtllm = use_trtllm_attention(self.num_qo_heads,
                                                   self.num_kv_heads,
                                                   num_prefill_tokens,
@@ -390,7 +392,7 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
                                                   self.cache_dtype,
                                                   self.q_data_type,
                                                   is_prefill=True,
-                                                  has_sinks=has_sinks)
+                                                  has_sinks=self.has_sinks)
         decode_use_trtllm = use_trtllm_attention(self.num_qo_heads,
                                                  self.num_kv_heads,
                                                  num_decode_tokens,
@@ -398,7 +400,7 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
                                                  self.cache_dtype,
                                                  self.q_data_type,
                                                  is_prefill=False,
-                                                 has_sinks=has_sinks)
+                                                 has_sinks=self.has_sinks)

         attn_metadata = FlashInferMetadata(
             num_actual_tokens=num_actual_tokens,
@@ -433,9 +435,9 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
                 self.head_dim,
                 self.page_size,
                 causal=True,
-                sm_scale=self.global_hyperparameters.sm_scale,
-                window_left=self.global_hyperparameters.window_left,
-                logits_soft_cap=self.global_hyperparameters.logits_soft_cap,
+                sm_scale=self.sm_scale,
+                window_left=self.window_left,
+                logits_soft_cap=self.logits_soft_cap,
                 q_data_type=self.q_data_type,
                 kv_data_type=self.kv_cache_dtype,
             )
@@ -472,10 +474,9 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
                     self.head_dim,
                     self.page_size,
                     causal=True,
-                    sm_scale=self.global_hyperparameters.sm_scale,
-                    window_left=self.global_hyperparameters.window_left,
-                    logits_soft_cap=self.global_hyperparameters.
-                    logits_soft_cap,
+                    sm_scale=self.sm_scale,
+                    window_left=self.window_left,
+                    logits_soft_cap=self.logits_soft_cap,
                     q_data_type=self.q_data_type,
                     kv_data_type=self.kv_cache_dtype,
                 )
@@ -525,10 +526,9 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
                     self.page_size,
                     # Disable flashinfer's pos encoding and use vllm's rope.
                     pos_encoding_mode="NONE",
-                    sm_scale=self.global_hyperparameters.sm_scale,
-                    window_left=self.global_hyperparameters.window_left,
-                    logits_soft_cap=self.global_hyperparameters.
-                    logits_soft_cap,
+                    sm_scale=self.sm_scale,
+                    window_left=self.window_left,
+                    logits_soft_cap=self.logits_soft_cap,
                     q_data_type=self.q_data_type,
                     kv_data_type=self.kv_cache_dtype,
                 )
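
For context on the design choice: beyond shortening the call sites, the cached
attributes skip one level of attribute lookup per read on the hot path. A
rough, illustrative micro-benchmark of that difference (class names here are
hypothetical, and absolute numbers are machine-dependent):

    import timeit


    class Params:
        sm_scale = 0.125


    class Holder:
        def __init__(self) -> None:
            self.global_hyperparameters = Params()
            # Cached copy, mirroring the pattern in this commit.
            self.sm_scale = self.global_hyperparameters.sm_scale


    h = Holder()
    chained = timeit.timeit(lambda: h.global_hyperparameters.sm_scale,
                            number=1_000_000)
    cached = timeit.timeit(lambda: h.sm_scale, number=1_000_000)
    print(f"chained: {chained:.3f}s  cached: {cached:.3f}s")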