[Bugfix] Add contiguous call inside rope kernel wrapper (#17091)

Signed-off-by: 苏政渊 <suzhengyuan@moonshot.cn>
Co-authored-by: 苏政渊 <suzhengyuan@moonshot.cn>
Authored by Zhengyuan Su (苏政渊) on 2025-04-29 10:24:07 +08:00, committed by GitHub
parent 165cb56329
commit 17eb306fcc
2 changed files with 17 additions and 7 deletions

@@ -158,8 +158,13 @@ def rotary_embedding(
     cos_sin_cache: torch.Tensor,
     is_neox: bool,
 ) -> None:
-    torch.ops._C.rotary_embedding(positions, query, key, head_size,
-                                  cos_sin_cache, is_neox)
+    # TODO: Remove this contiguous call when the kernel is updated to support tensor slices
+    query_contiguous = query.contiguous()
+    key_contiguous = key.contiguous()
+    torch.ops._C.rotary_embedding(positions, query_contiguous, key_contiguous,
+                                  head_size, cos_sin_cache, is_neox)
+    query.copy_(query_contiguous)
+    key.copy_(key_contiguous)
 
 
 def batched_rotary_embedding(positions: torch.Tensor, query: torch.Tensor,
@@ -167,9 +172,15 @@ def batched_rotary_embedding(positions: torch.Tensor, query: torch.Tensor,
                              cos_sin_cache: torch.Tensor, is_neox: bool,
                              rot_dim: int,
                              cos_sin_cache_offsets: torch.Tensor) -> None:
-    torch.ops._C.batched_rotary_embedding(positions, query, key, head_size,
-                                          cos_sin_cache, is_neox, rot_dim,
-                                          cos_sin_cache_offsets)
+    # TODO: Remove this contiguous call when the kernel is updated to support tensor slices
+    query_contiguous = query.contiguous()
+    key_contiguous = key.contiguous()
+    torch.ops._C.batched_rotary_embedding(positions, query_contiguous,
+                                          key_contiguous, head_size,
+                                          cos_sin_cache, is_neox, rot_dim,
+                                          cos_sin_cache_offsets)
+    query.copy_(query_contiguous)
+    key.copy_(key_contiguous)
 
 
 # layer norm ops

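For context, the sketch below shows the pattern the updated wrappers follow, in plain PyTorch: run the kernel on contiguous copies, then copy the results back into the caller's possibly strided views so the call still behaves as an in-place op. The names fake_rope_kernel_ and rope_wrapper are hypothetical stand-ins for illustration, not vLLM APIs.

import torch


def fake_rope_kernel_(q: torch.Tensor, k: torch.Tensor) -> None:
    # Hypothetical stand-in for torch.ops._C.rotary_embedding: an in-place
    # kernel that assumes its inputs are contiguous in memory.
    assert q.is_contiguous() and k.is_contiguous()
    q.mul_(2.0)
    k.mul_(2.0)


def rope_wrapper(query: torch.Tensor, key: torch.Tensor) -> None:
    # The pattern added by this commit: run the kernel on contiguous copies,
    # then copy the results back so callers may pass strided slices.
    query_contiguous = query.contiguous()
    key_contiguous = key.contiguous()
    fake_rope_kernel_(query_contiguous, key_contiguous)
    query.copy_(query_contiguous)
    key.copy_(key_contiguous)


q = torch.randn(2, 4, 64)
k_pe = torch.randn(2, 1, 32)
q_ref = q.clone()

q_pe = q[..., 32:]          # non-contiguous slice, as in the MLA code path
rope_wrapper(q_pe, k_pe)    # writes back into q through the strided view

assert torch.allclose(q[..., 32:], q_ref[..., 32:] * 2.0)
assert torch.equal(q[..., :32], q_ref[..., :32])  # untouched half unchanged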

@@ -938,8 +938,7 @@ class MLACommonImpl(MLAAttentionImpl[M], Generic[M]):
             decode_ql_nope, decode_q_pe = \
                 self._q_proj_and_k_up_proj(decode_hs_or_q_c)
             decode_q_pe[...], decode_k_pe[...] = self.rotary_emb(
-                attn_metadata.decode.input_positions, decode_q_pe.contiguous(),
-                decode_k_pe)
+                attn_metadata.decode.input_positions, decode_q_pe, decode_k_pe)
 
         if has_prefill:
             assert attn_metadata.prefill is not None
@@ -948,8 +947,8 @@ class MLACommonImpl(MLAAttentionImpl[M], Generic[M]):
             prefill_q_pe = prefill_q[..., self.qk_nope_head_dim:]
             prefill_q_pe[...], prefill_k_pe[...] = self.rotary_emb(
-                attn_metadata.prefill.input_positions,
-                prefill_q_pe.contiguous(), prefill_k_pe)
+                attn_metadata.prefill.input_positions, prefill_q_pe,
+                prefill_k_pe)
 
         # write the latent and rope to kv cache
         if kv_cache.numel() > 0:
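
The MLA call sites above can now pass the strided decode_q_pe and prefill_q_pe slices directly, since the wrapper handles contiguity and copies the results back. The toy example below illustrates the underlying failure mode, using torch.as_strided to model how a kernel that assumes row-major contiguous layout would walk a slice's storage; it is an illustration only, not the actual CUDA kernel.

import torch

x = torch.arange(12.0).reshape(3, 4)
pe_slice = x[:, 2:]                  # strided view, like prefill_q_pe above
print(pe_slice)                      # [[ 2.,  3.], [ 6.,  7.], [10., 11.]]

# A kernel that assumes contiguous layout walks the slice's storage with
# stride (2, 1) instead of the view's real stride (4, 1), so it reads and
# writes the wrong elements:
assumed = torch.as_strided(pe_slice, pe_slice.shape, (pe_slice.shape[1], 1))
print(assumed)                       # [[2., 3.], [4., 5.], [6., 7.]]  (wrong rows)

# The wrapper's fix: give the kernel a contiguous scratch copy, then copy
# the result back into the original view, preserving in-place semantics.
scratch = pe_slice.contiguous()
scratch.mul_(10.0)                   # stand-in for the in-place rope update
pe_slice.copy_(scratch)
print(x[:, 2:])                      # [[ 20.,  30.], [ 60.,  70.], [100., 110.]]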