[Bugfix] Add contiguous call inside rope kernel wrapper (#17091)
Signed-off-by: 苏政渊 <suzhengyuan@moonshot.cn>
Co-authored-by: 苏政渊 <suzhengyuan@moonshot.cn>
commit 17eb306fcc
parent 165cb56329
committed by GitHub
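What the change does: the rope kernel wrappers now make contiguous copies of `query` and `key` before invoking the CUDA kernel and copy the results back afterwards, so callers may pass strided tensor slices directly. Below is a minimal standalone sketch of that copy-in/copy-out pattern in plain PyTorch; `toy_rope_kernel` and `rope_wrapper_sketch` are hypothetical stand-ins for illustration, not vLLM APIs.

import torch


def toy_rope_kernel(q: torch.Tensor, k: torch.Tensor) -> None:
    # Stand-in for torch.ops._C.rotary_embedding: mutates q and k in place
    # and, like the real kernel, assumes densely packed (contiguous) memory.
    q.mul_(2.0)
    k.mul_(2.0)


def rope_wrapper_sketch(q: torch.Tensor, k: torch.Tensor) -> None:
    # The pattern this commit adds to the wrappers: copy in, run, copy out.
    q_contiguous = q.contiguous()
    k_contiguous = k.contiguous()
    toy_rope_kernel(q_contiguous, k_contiguous)
    q.copy_(q_contiguous)  # write the results back through the original views
    k.copy_(k_contiguous)


q = torch.randn(4, 8)
q_before = q.clone()
q_pe = q[..., 4:]                  # a slice: a strided view, not contiguous
assert not q_pe.is_contiguous()
k_pe = torch.randn(4, 4)

rope_wrapper_sketch(q_pe, k_pe)    # safe even though q_pe is a strided view
assert torch.allclose(q[..., 4:], q_before[..., 4:] * 2.0)  # parent updated
assert torch.allclose(q[..., :4], q_before[..., :4])        # rest untouched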
@@ -158,8 +158,13 @@ def rotary_embedding(
     cos_sin_cache: torch.Tensor,
     is_neox: bool,
 ) -> None:
-    torch.ops._C.rotary_embedding(positions, query, key, head_size,
-                                  cos_sin_cache, is_neox)
+    # TODO: Remove this contiguous call when the kernel is updated to support tensor slices
+    query_contiguous = query.contiguous()
+    key_contiguous = key.contiguous()
+    torch.ops._C.rotary_embedding(positions, query_contiguous, key_contiguous,
+                                  head_size, cos_sin_cache, is_neox)
+    query.copy_(query_contiguous)
+    key.copy_(key_contiguous)


 def batched_rotary_embedding(positions: torch.Tensor, query: torch.Tensor,
@@ -167,9 +172,15 @@ def batched_rotary_embedding(positions: torch.Tensor, query: torch.Tensor,
                              cos_sin_cache: torch.Tensor, is_neox: bool,
                              rot_dim: int,
                              cos_sin_cache_offsets: torch.Tensor) -> None:
-    torch.ops._C.batched_rotary_embedding(positions, query, key, head_size,
-                                          cos_sin_cache, is_neox, rot_dim,
-                                          cos_sin_cache_offsets)
+    # TODO: Remove this contiguous call when the kernel is updated to support tensor slices
+    query_contiguous = query.contiguous()
+    key_contiguous = key.contiguous()
+    torch.ops._C.batched_rotary_embedding(positions, query_contiguous,
+                                          key_contiguous, head_size,
+                                          cos_sin_cache, is_neox, rot_dim,
+                                          cos_sin_cache_offsets)
+    query.copy_(query_contiguous)
+    key.copy_(key_contiguous)


 # layer norm ops
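Why the copies matter: a kernel that reads its input as densely packed memory sees the parent tensor's storage layout, not the slice. A quick illustration in plain PyTorch (a simplified picture, not vLLM code; the real kernel's indexing is not shown here) of what a contiguous-only kernel would roughly end up reading from a sliced query:

import torch

q = torch.arange(32, dtype=torch.float32).reshape(4, 8)
q_pe = q[..., 4:]              # rope half of each row: a strided view
print(q_pe.is_contiguous())    # False
print(q_pe.stride())           # (8, 1): rows sit 8 floats apart in storage

# A dense reinterpretation of the first q_pe.numel() storage elements is
# NOT the slice's data; this is roughly what a contiguous-only kernel
# would end up operating on if handed the raw view.
dense_guess = q.flatten()[:q_pe.numel()].reshape(q_pe.shape)
print(torch.equal(dense_guess, q_pe))        # False
print(torch.equal(q_pe.contiguous(), q_pe))  # True: same values, repacked densely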
@@ -938,8 +938,7 @@ class MLACommonImpl(MLAAttentionImpl[M], Generic[M]):
             decode_ql_nope, decode_q_pe = \
                 self._q_proj_and_k_up_proj(decode_hs_or_q_c)
             decode_q_pe[...], decode_k_pe[...] = self.rotary_emb(
-                attn_metadata.decode.input_positions, decode_q_pe.contiguous(),
-                decode_k_pe)
+                attn_metadata.decode.input_positions, decode_q_pe, decode_k_pe)

         if has_prefill:
             assert attn_metadata.prefill is not None
@@ -948,8 +947,8 @@ class MLACommonImpl(MLAAttentionImpl[M], Generic[M]):
             prefill_q_pe = prefill_q[..., self.qk_nope_head_dim:]

             prefill_q_pe[...], prefill_k_pe[...] = self.rotary_emb(
-                attn_metadata.prefill.input_positions,
-                prefill_q_pe.contiguous(), prefill_k_pe)
+                attn_metadata.prefill.input_positions, prefill_q_pe,
+                prefill_k_pe)

         # write the latent and rope to kv cache
         if kv_cache.numel() > 0:
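On the caller side, the MLA code above can now drop its explicit `.contiguous()` calls: it passes the strided `*_q_pe` views straight to `self.rotary_emb`, and assigning the results back with `[...]` still writes through the views into the parent query tensor. A small sketch of that indexing behaviour (hypothetical shapes, plain PyTorch; `rotated` stands in for the rotary embedding output):

import torch

num_tokens, num_heads = 3, 2
qk_nope_head_dim, qk_rope_head_dim = 4, 2
q = torch.randn(num_tokens, num_heads, qk_nope_head_dim + qk_rope_head_dim)
q_pe = q[..., qk_nope_head_dim:]   # rope dims: a strided view of q

rotated = -q_pe                    # stand-in for the rotary embedding output
q_pe[...] = rotated                # ellipsis assignment writes through the view
assert torch.equal(q[..., qk_nope_head_dim:], rotated)  # parent q updated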