mirror of
https://github.com/vllm-project/vllm.git
synced 2025-10-20 23:03:52 +08:00
Compare commits
1 Commits
v0.10.2rc3
...
v0.9.0.1
Author | SHA1 | Date | |
---|---|---|---|
5fbbfe9a4c |
@ -143,6 +143,14 @@ void merge_attn_states_launcher(torch::Tensor& output,
|
||||
const uint pack_size = 16 / sizeof(scalar_t);
|
||||
TORCH_CHECK(head_size % pack_size == 0,
|
||||
"headsize must be multiple of pack_size:", pack_size);
|
||||
TORCH_CHECK(output.stride(-2) == head_size && output.stride(-1) == 1,
|
||||
"output heads must be contiguous in memory");
|
||||
TORCH_CHECK(
|
||||
prefix_output.stride(-2) == head_size && prefix_output.stride(-1) == 1,
|
||||
"prefix_output heads must be contiguous in memory");
|
||||
TORCH_CHECK(
|
||||
suffix_output.stride(-2) == head_size && suffix_output.stride(-1) == 1,
|
||||
"suffix_output heads must be contiguous in memory");
|
||||
float* output_lse_ptr = nullptr;
|
||||
if (output_lse.has_value()) {
|
||||
output_lse_ptr = output_lse.value().data_ptr<float>();
|
||||
|
@ -1093,10 +1093,6 @@ class MLACommonImpl(MLAAttentionImpl[T], Generic[T]):
|
||||
if isinstance(attn_out, tuple):
|
||||
attn_out, *rest = attn_out
|
||||
|
||||
# unpad if necessary
|
||||
if self._pad_v:
|
||||
attn_out = attn_out[..., :v.shape[-1]]
|
||||
|
||||
# Remain consistent with old `flash_attn_varlen_func` where there
|
||||
# is only one output tensor if `return_softmax_lse` is False.
|
||||
if return_softmax_lse:
|
||||
@ -1294,6 +1290,10 @@ class MLACommonImpl(MLAAttentionImpl[T], Generic[T]):
|
||||
suffix_lse=suffix_lse,
|
||||
)
|
||||
|
||||
# unpad if necessary
|
||||
if self._pad_v:
|
||||
output = output[..., :v.shape[-1]]
|
||||
|
||||
return output.flatten(start_dim=-2)
|
||||
|
||||
@abstractmethod
|
||||
|
@ -653,10 +653,6 @@ class MLACommonImpl(MLAAttentionImpl[M], Generic[M]):
|
||||
if isinstance(attn_out, tuple):
|
||||
attn_out, lse = attn_out[0], attn_out[1]
|
||||
|
||||
# unpad if necessary
|
||||
if self._pad_v:
|
||||
attn_out = attn_out[..., :v.shape[-1]]
|
||||
|
||||
# Remain consistent with old `flash_attn_varlen_func` where there
|
||||
# is only one output tensor if `return_softmax_lse` is False.
|
||||
if return_softmax_lse:
|
||||
@ -839,6 +835,10 @@ class MLACommonImpl(MLAAttentionImpl[M], Generic[M]):
|
||||
suffix_lse=suffix_lse,
|
||||
)
|
||||
|
||||
# unpad if necessary
|
||||
if self._pad_v:
|
||||
output = output[..., :v.shape[-1]]
|
||||
|
||||
return output.flatten(start_dim=-2)
|
||||
|
||||
@abstractmethod
|
||||
|
Reference in New Issue
Block a user