[torch.compile] hide slicing under custom op for inductor (#8384)
tests/compile/test_full_graph.py

@@ -16,5 +16,7 @@ def test_full_graph(model):
         "The future of AI is",
     ]
     sampling_params = SamplingParams(temperature=0)
-    llm = LLM(model="meta-llama/Meta-Llama-3-8B")
+    llm = LLM(model="meta-llama/Meta-Llama-3-8B",
+              enforce_eager=True,
+              load_format="dummy")
     llm.generate(prompts, sampling_params)
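The two new constructor arguments are what keep this full-graph compile smoke test cheap: load_format="dummy" fills the model with randomly initialized weights instead of downloading a checkpoint, and enforce_eager=True skips CUDA graph capture. A stand-alone usage sketch (illustrative only, not part of the commit):

from vllm import LLM, SamplingParams

# Illustrative usage, not from the diff: dummy weights avoid a checkpoint
# download and eager mode skips CUDA graph capture, so the test exercises
# the torch.compile path quickly.
llm = LLM(model="meta-llama/Meta-Llama-3-8B",
          enforce_eager=True,
          load_format="dummy")
outputs = llm.generate(["The future of AI is"], SamplingParams(temperature=0))
print(outputs[0].outputs[0].text)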
vllm/attention/backends/flash_attn.py

@@ -122,6 +122,40 @@ def _(
     return torch.empty_like(decode_query)
 
 
+@torch.library.custom_op("vllm::reshape_and_cache_flash",
+                         mutates_args=["kv_cache"])
+def reshape_and_cache_flash(
+    key: torch.Tensor,
+    value: torch.Tensor,
+    kv_cache: torch.Tensor,
+    slot_mapping: torch.Tensor,
+    kv_cache_dtype: str,
+    k_scale: float,
+    v_scale: float,
+) -> None:
+    """Inductor cannot deal with inplace operations on views.
+    See https://github.com/pytorch/pytorch/issues/131192
+    and https://github.com/pytorch/pytorch/issues/130174
+    This is a workaround to hide the view operation from the inductor.
+    """
+    return torch.ops._C_cache_ops.reshape_and_cache_flash(
+        key, value, kv_cache[0], kv_cache[1], slot_mapping, kv_cache_dtype,
+        k_scale, v_scale)
+
+
+@reshape_and_cache_flash.register_fake  # type: ignore
+def _(
+    key: torch.Tensor,
+    value: torch.Tensor,
+    kv_cache: torch.Tensor,
+    slot_mapping: torch.Tensor,
+    kv_cache_dtype: str,
+    k_scale: float,
+    v_scale: float,
+) -> None:
+    pass
+
+
 class FlashAttentionBackend(AttentionBackend):
 
     @staticmethod
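The pattern introduced here is general: an in-place update that goes through a view of a larger tensor can be hidden from torch.compile/inductor by wrapping it in a torch.library custom op and registering a fake for tracing. Below is a minimal, self-contained sketch of that pattern, not vLLM code; the op name demo::write_into_slice and the helper names are invented for illustration, and it assumes a PyTorch version where torch.library.custom_op is available (2.4+).

import torch


# An in-place write into a slice of a larger buffer is wrapped in a custom
# op, so the compiler sees one opaque call instead of a mutation through a
# view.
@torch.library.custom_op("demo::write_into_slice", mutates_args=["buffer"])
def write_into_slice(buffer: torch.Tensor, value: torch.Tensor) -> None:
    # The view (buffer[0]) and the in-place copy stay hidden inside the op.
    buffer[0].copy_(value)


@write_into_slice.register_fake  # shape/aliasing stub used during tracing
def _(buffer: torch.Tensor, value: torch.Tensor) -> None:
    pass


@torch.compile
def step(buffer: torch.Tensor, value: torch.Tensor) -> torch.Tensor:
    torch.ops.demo.write_into_slice(buffer, value)
    return buffer.sum()


buf = torch.zeros(2, 4)
print(step(buf, torch.ones(4)))  # tensor(4.)

The register_fake stub tells the compiler the op returns nothing, while mutates_args declares which input is written so functionalization can handle it.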
@@ -653,11 +687,10 @@ class FlashAttentionImpl(AttentionImpl):
             # Reshape the input keys and values and store them in the cache.
             # If kv_cache is not provided, the new key and value tensors are
             # not cached. This happens during the initial memory profiling run.
-            ops.reshape_and_cache_flash(
+            torch.ops.vllm.reshape_and_cache_flash(
                 key,
                 value,
-                key_cache,
-                value_cache,
+                kv_cache,
                 attn_metadata.slot_mapping.flatten(),
                 self.kv_cache_dtype,
                 k_scale,
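The call site now passes the whole kv_cache tensor instead of the pre-sliced key_cache and value_cache, so the slicing (kv_cache[0] and kv_cache[1] in the wrapper above) happens inside the custom op where inductor cannot see it. A rough sketch of the cache layout this relies on, with hypothetical sizes chosen only for illustration:

import torch

# Hypothetical sizes. In this backend the KV cache is one tensor whose
# leading dimension of 2 separates keys from values, so kv_cache[0] and
# kv_cache[1] are views into the same storage, not copies.
num_blocks, block_size, num_kv_heads, head_size = 128, 16, 8, 64
kv_cache = torch.empty(2, num_blocks, block_size, num_kv_heads, head_size)
key_cache, value_cache = kv_cache[0], kv_cache[1]
assert key_cache.data_ptr() == kv_cache.data_ptr()  # same underlying buffer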
@@ -669,7 +702,6 @@ class FlashAttentionImpl(AttentionImpl):
         assert key.shape[0] == num_prefill_tokens + num_decode_tokens
         assert value.shape[0] == num_prefill_tokens + num_decode_tokens
 
-        output = torch.empty_like(query)
         # Query for decode. KV is not needed because it is already cached.
         decode_query = query[num_prefill_tokens:]
         # QKV for prefill.
@@ -680,6 +712,9 @@ class FlashAttentionImpl(AttentionImpl):
         assert query.shape[0] == num_prefill_tokens
         assert decode_query.shape[0] == num_decode_tokens
 
+        prefill_output: Optional[torch.Tensor] = None
+        decode_output: Optional[torch.Tensor] = None
+
         if prefill_meta := attn_metadata.prefill_metadata:
             # Prompt run.
             if (kv_cache is None or prefill_meta.block_tables is None
@@ -687,7 +722,7 @@ class FlashAttentionImpl(AttentionImpl):
                 # normal attention
                 # When block_tables are not filled, it means q and k are the
                 # prompt, and they have the same length.
-                out = torch.ops.vllm.flash_attn_varlen_func(
+                prefill_output = torch.ops.vllm.flash_attn_varlen_func(
                     q=query,
                     k=key,
                     v=value,
@@ -701,42 +736,44 @@ class FlashAttentionImpl(AttentionImpl):
                     alibi_slopes=self.alibi_slopes,
                     softcap=self.logits_soft_cap,
                 )
-                assert output[:num_prefill_tokens].shape == out.shape
-                output[:num_prefill_tokens] = out
             else:
                 # prefix-enabled attention
                 assert prefill_meta.seq_lens is not None
                 max_seq_len = max(prefill_meta.seq_lens)
-                output[:
-                       num_prefill_tokens] = torch.ops.vllm.flash_attn_varlen_func(  # noqa
-                           q=query,
-                           k=key_cache,
-                           v=value_cache,
-                           cu_seqlens_q=prefill_meta.query_start_loc,
-                           max_seqlen_q=prefill_meta.max_query_len,
-                           cu_seqlens_k=prefill_meta.seq_start_loc,
-                           max_seqlen_k=max_seq_len,
-                           softmax_scale=self.scale,
-                           causal=True,
-                           alibi_slopes=self.alibi_slopes,
-                           block_table=prefill_meta.block_tables,
-                           softcap=self.logits_soft_cap,
-                       )
-
-        if decode_meta := attn_metadata.decode_metadata:
-            # Decoding run.
-            output[
-                num_prefill_tokens:] = torch.ops.vllm.flash_attn_with_kvcache(
-                    decode_query.unsqueeze(1),
-                    key_cache,
-                    value_cache,
-                    block_table=decode_meta.block_tables,
-                    cache_seqlens=decode_meta.seq_lens_tensor,
-                    softmax_scale=self.scale,
-                    causal=True,
-                    alibi_slopes=self.alibi_slopes,
-                    softcap=self.logits_soft_cap,
-                ).squeeze(1)
-
-        # Reshape the output tensor.
-        return output.view(num_tokens, hidden_size)
+                prefill_output = torch.ops.vllm.flash_attn_varlen_func(  # noqa
+                    q=query,
+                    k=key_cache,
+                    v=value_cache,
+                    cu_seqlens_q=prefill_meta.query_start_loc,
+                    max_seqlen_q=prefill_meta.max_query_len,
+                    cu_seqlens_k=prefill_meta.seq_start_loc,
+                    max_seqlen_k=max_seq_len,
+                    softmax_scale=self.scale,
+                    causal=True,
+                    alibi_slopes=self.alibi_slopes,
+                    block_table=prefill_meta.block_tables,
+                    softcap=self.logits_soft_cap,
+                )
+
+        if decode_meta := attn_metadata.decode_metadata:
+            # Decoding run.
+            decode_output = torch.ops.vllm.flash_attn_with_kvcache(
+                decode_query.unsqueeze(1),
+                key_cache,
+                value_cache,
+                block_table=decode_meta.block_tables,
+                cache_seqlens=decode_meta.seq_lens_tensor,
+                softmax_scale=self.scale,
+                causal=True,
+                alibi_slopes=self.alibi_slopes,
+                softcap=self.logits_soft_cap,
+            ).squeeze(1)
+
+        if prefill_output is None:
+            assert decode_output is not None
+            return decode_output.view(num_decode_tokens, hidden_size)
+        if decode_output is None:
+            assert prefill_output is not None
+            return prefill_output.view(num_prefill_tokens, hidden_size)
+        output = torch.cat([prefill_output, decode_output], dim=0)
+        return output.view(num_tokens, hidden_size)
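The net effect of this last hunk is a functional rewrite of the output assembly: instead of allocating one output buffer and writing the prefill and decode results into slices of it, each phase produces its own tensor and the two are concatenated at the end. An illustrative contrast (not vLLM code, shapes and names invented):

import torch


def combine_old(prefill: torch.Tensor, decode: torch.Tensor) -> torch.Tensor:
    # Old style: preallocate, then write into slices, i.e. in-place
    # mutation of views, which is exactly what inductor struggles with.
    out = torch.empty(prefill.shape[0] + decode.shape[0], prefill.shape[1])
    out[:prefill.shape[0]] = prefill
    out[prefill.shape[0]:] = decode
    return out


def combine_new(prefill: torch.Tensor, decode: torch.Tensor) -> torch.Tensor:
    # New style: no mutation, a single concatenation the compiler can trace.
    return torch.cat([prefill, decode], dim=0)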