[FlashInfer] Cache hyper params in metadata builder (#23732)
Signed-off-by: Woosuk Kwon <woosuk.kwon@berkeley.edu>
This commit is contained in:
@@ -214,6 +214,10 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
         # TODO: discard this for trtllm-gen backend
         self.global_hyperparameters = infer_global_hyperparameters(
             get_per_layer_parameters(vllm_config, layer_names, FlashInferImpl))
+        self.sm_scale = self.global_hyperparameters.sm_scale
+        self.window_left = self.global_hyperparameters.window_left
+        self.logits_soft_cap = self.global_hyperparameters.logits_soft_cap
+        self.has_sinks = self.global_hyperparameters.has_sinks
 
         # Preparing persistent buffers (device-side)
         self.paged_kv_indptr = torch.zeros(max_num_reqs + 1,
@@ -381,8 +385,6 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
         )
 
-        # Check if any layer uses sinks (requires TRTLLM attention)
-        has_sinks = self.global_hyperparameters.has_sinks
 
         prefill_use_trtllm = use_trtllm_attention(self.num_qo_heads,
                                                   self.num_kv_heads,
                                                   num_prefill_tokens,
@@ -390,7 +392,7 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
                                                   self.cache_dtype,
                                                   self.q_data_type,
                                                   is_prefill=True,
-                                                  has_sinks=has_sinks)
+                                                  has_sinks=self.has_sinks)
         decode_use_trtllm = use_trtllm_attention(self.num_qo_heads,
                                                  self.num_kv_heads,
                                                  num_decode_tokens,
@@ -398,7 +400,7 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
                                                  self.cache_dtype,
                                                  self.q_data_type,
                                                  is_prefill=False,
-                                                 has_sinks=has_sinks)
+                                                 has_sinks=self.has_sinks)
 
         attn_metadata = FlashInferMetadata(
             num_actual_tokens=num_actual_tokens,
@@ -433,9 +435,9 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
                     self.head_dim,
                     self.page_size,
                     causal=True,
-                    sm_scale=self.global_hyperparameters.sm_scale,
-                    window_left=self.global_hyperparameters.window_left,
-                    logits_soft_cap=self.global_hyperparameters.logits_soft_cap,
+                    sm_scale=self.sm_scale,
+                    window_left=self.window_left,
+                    logits_soft_cap=self.logits_soft_cap,
                     q_data_type=self.q_data_type,
                     kv_data_type=self.kv_cache_dtype,
                 )
@@ -472,10 +474,9 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
                     self.head_dim,
                     self.page_size,
                     causal=True,
-                    sm_scale=self.global_hyperparameters.sm_scale,
-                    window_left=self.global_hyperparameters.window_left,
-                    logits_soft_cap=self.global_hyperparameters.
-                    logits_soft_cap,
+                    sm_scale=self.sm_scale,
+                    window_left=self.window_left,
+                    logits_soft_cap=self.logits_soft_cap,
                     q_data_type=self.q_data_type,
                     kv_data_type=self.kv_cache_dtype,
                 )
@@ -525,10 +526,9 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
                     self.page_size,
                     # Disable flashinfer's pos encoding and use vllm's rope.
                     pos_encoding_mode="NONE",
-                    sm_scale=self.global_hyperparameters.sm_scale,
-                    window_left=self.global_hyperparameters.window_left,
-                    logits_soft_cap=self.global_hyperparameters.
-                    logits_soft_cap,
+                    sm_scale=self.sm_scale,
+                    window_left=self.window_left,
+                    logits_soft_cap=self.logits_soft_cap,
                     q_data_type=self.q_data_type,
                     kv_data_type=self.kv_cache_dtype,
                 )
|
Reference in New Issue
Block a user