Fix warning msg print (#3421)

### What this PR does / why we need it?
Avoid printing warning messages like the one below:
UserWarning: To copy construct from a tensor, it is recommended to use
sourceTensor.clone().detach() ...

### Does this PR introduce _any_ user-facing change?
No

### How was this patch tested?

- vLLM version: v0.11.0rc3
- vLLM main: https://github.com/vllm-project/vllm/commit/v0.11.0

---------

Signed-off-by: SunnyLee219 <3294305115@qq.com>
This commit is contained in:
LeeWenquan
2025-10-15 11:30:30 +08:00
committed by GitHub
parent 16cb3cc45d
commit 4e720936d8
2 changed files with 6 additions and 10 deletions

View File

@ -79,7 +79,7 @@ class AscendMLAPrefillMetadata:
chunk_seq_lens: torch.Tensor
attn_mask: torch.Tensor
-    query_lens: list[int]
+    query_lens: torch.Tensor
seq_lens: list[int]
context_lens: torch.Tensor
input_positions: torch.Tensor
@ -380,7 +380,7 @@ class AscendMLAMetadataBuilder:
1).unsqueeze(2)
prefill_metadata = AscendMLAPrefillMetadata(
attn_mask=common_attn_metadata.attn_mask,
-                query_lens=query_lens[reqs_start:],
+                query_lens=query_lens[reqs_start:].to(torch.int32),
seq_lens=seq_lens,
context_lens=seq_lens[reqs_start:],
input_positions=prefill_input_positions,
@ -837,9 +837,7 @@ class AscendMLAImpl(MLAAttentionImpl):
k_rope=k_pe,
value=value,
mask=self.prefill_mask,
-                seqlen=torch.tensor(
-                    attn_metadata.prefill.query_lens,
-                    dtype=torch.int32),
+                seqlen=attn_metadata.prefill.query_lens,
head_num=self.num_heads,
kv_head_num=self.num_heads,
pre_out=None,

View File

@ -74,7 +74,7 @@ class AscendMLATorchairPrefillMetadata:
chunk_seq_lens: torch.Tensor
attn_mask: torch.Tensor
-    query_lens: list[int]
+    query_lens: torch.Tensor
seq_lens: list[int]
context_lens: torch.Tensor
input_positions: torch.Tensor
@ -473,7 +473,7 @@ class AscendMLATorchairMetadataBuilder:
1).unsqueeze(2)
prefill_metadata = AscendMLATorchairPrefillMetadata(
attn_mask=common_attn_metadata.attn_mask,
-                query_lens=query_lens[tokens_start:],
+                query_lens=query_lens[tokens_start:].to(torch.int32),
seq_lens=seq_lens,
context_lens=seq_lens[tokens_start:],
input_positions=prefill_input_positions,
@ -880,9 +880,7 @@ class AscendMLATorchairImpl(MLAAttentionImpl):
k_rope=k_pe,
value=value,
mask=self.prefill_mask,
-                seqlen=torch.tensor(
-                    attn_metadata.prefill.query_lens,
-                    dtype=torch.int32),
+                seqlen=attn_metadata.prefill.query_lens,
head_num=self.num_heads,
kv_head_num=self.num_heads,
pre_out=None,