Mirror of https://github.com/vllm-project/vllm.git
Compare commits: gemma3n-mm...v0.9.0.1 (1 commit)

Commit: 5fbbfe9a4c
@@ -143,6 +143,14 @@ void merge_attn_states_launcher(torch::Tensor& output,
   const uint pack_size = 16 / sizeof(scalar_t);
   TORCH_CHECK(head_size % pack_size == 0,
               "headsize must be multiple of pack_size:", pack_size);
+  TORCH_CHECK(output.stride(-2) == head_size && output.stride(-1) == 1,
+              "output heads must be contiguous in memory");
+  TORCH_CHECK(
+      prefix_output.stride(-2) == head_size && prefix_output.stride(-1) == 1,
+      "prefix_output heads must be contiguous in memory");
+  TORCH_CHECK(
+      suffix_output.stride(-2) == head_size && suffix_output.stride(-1) == 1,
+      "suffix_output heads must be contiguous in memory");
   float* output_lse_ptr = nullptr;
   if (output_lse.has_value()) {
     output_lse_ptr = output_lse.value().data_ptr<float>();
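The added checks require that each head's output vector is laid out contiguously: the last dimension has unit stride and the second-to-last dimension strides by exactly head_size, which is what the 16-byte packed loads (pack_size) assume. Below is a minimal Python sketch of the same condition; the [num_tokens, num_heads, head_size] layout and the sizes are assumptions of the sketch, not taken from the diff.

import torch

num_tokens, num_heads, head_size = 4, 8, 128

# A freshly allocated [num_tokens, num_heads, head_size] tensor passes the check.
out = torch.empty(num_tokens, num_heads, head_size)
assert out.stride(-1) == 1 and out.stride(-2) == head_size

# A view that slices the head dimension does not: each head row still strides
# by the original (padded) head_size, so heads are no longer back-to-back.
sliced = out[..., :64]
assert sliced.stride(-1) == 1
assert sliced.stride(-2) == head_size and sliced.stride(-2) != sliced.size(-1)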
@@ -1093,10 +1093,6 @@ class MLACommonImpl(MLAAttentionImpl[T], Generic[T]):
         if isinstance(attn_out, tuple):
             attn_out, *rest = attn_out

-        # unpad if necessary
-        if self._pad_v:
-            attn_out = attn_out[..., :v.shape[-1]]
-
         # Remain consistent with old `flash_attn_varlen_func` where there
         # is only one output tensor if `return_softmax_lse` is False.
         if return_softmax_lse:
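The early unpad is dropped here and reintroduced after the merge (next hunk), so the tensors handed to merge_attn_states keep the padded, head-contiguous layout the kernel now enforces. Because the merge is an element-wise, per-(token, head) weighted combination, slicing before or after it yields the same values; the sketch below uses illustrative shapes and a stand-in weight tensor rather than real LSE weights.

import torch

num_tokens, num_heads, padded_v, v_dim = 4, 8, 128, 96   # illustrative sizes
prefix = torch.randn(num_tokens, num_heads, padded_v)
suffix = torch.randn(num_tokens, num_heads, padded_v)
w = torch.rand(num_tokens, num_heads, 1)   # stand-in for the LSE-derived weights

# Slicing the padded value dimension commutes with the element-wise merge,
# so deferring the unpad changes only memory layout, not results.
merge_then_slice = (w * prefix + (1 - w) * suffix)[..., :v_dim]
slice_then_merge = w * prefix[..., :v_dim] + (1 - w) * suffix[..., :v_dim]
torch.testing.assert_close(merge_then_slice, slice_then_merge)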
@@ -1294,6 +1290,10 @@ class MLACommonImpl(MLAAttentionImpl[T], Generic[T]):
             suffix_lse=suffix_lse,
         )

+        # unpad if necessary
+        if self._pad_v:
+            output = output[..., :v.shape[-1]]
+
         return output.flatten(start_dim=-2)

     @abstractmethod
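With the unpad moved here, the padded value dimension is trimmed exactly once on the merged output before the heads are folded into the hidden dimension. A shape sketch under assumed illustrative sizes:

import torch

num_tokens, num_heads, padded_v, v_dim = 4, 8, 128, 96   # illustrative sizes
output = torch.randn(num_tokens, num_heads, padded_v)     # merged, still padded
v = torch.randn(num_tokens, num_heads, v_dim)

output = output[..., :v.shape[-1]]           # drop the value padding once
flat = output.flatten(start_dim=-2)          # -> [num_tokens, num_heads * v_dim]
assert flat.shape == (num_tokens, num_heads * v_dim)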
@@ -653,10 +653,6 @@ class MLACommonImpl(MLAAttentionImpl[M], Generic[M]):
         if isinstance(attn_out, tuple):
             attn_out, lse = attn_out[0], attn_out[1]

-        # unpad if necessary
-        if self._pad_v:
-            attn_out = attn_out[..., :v.shape[-1]]
-
         # Remain consistent with old `flash_attn_varlen_func` where there
         # is only one output tensor if `return_softmax_lse` is False.
         if return_softmax_lse:
@@ -839,6 +835,10 @@ class MLACommonImpl(MLAAttentionImpl[M], Generic[M]):
             suffix_lse=suffix_lse,
         )

+        # unpad if necessary
+        if self._pad_v:
+            output = output[..., :v.shape[-1]]
+
         return output.flatten(start_dim=-2)

     @abstractmethod
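For context on what the merge step computes: combining prefix and suffix attention results via their log-sum-exp values follows the standard numerically stable re-weighting. The sketch below is a plain PyTorch reference of that formula, not the CUDA kernel, and the [num_heads, num_tokens] LSE layout is an assumption of the sketch.

import torch

def merge_states_reference(prefix_output, prefix_lse, suffix_output, suffix_lse):
    # prefix_output / suffix_output: [num_tokens, num_heads, head_size]
    # prefix_lse / suffix_lse:       [num_heads, num_tokens]  (assumed layout)
    p_lse = prefix_lse.t().unsqueeze(-1)     # -> [num_tokens, num_heads, 1]
    s_lse = suffix_lse.t().unsqueeze(-1)
    max_lse = torch.maximum(p_lse, s_lse)    # subtract the max for stability
    p_w = torch.exp(p_lse - max_lse)
    s_w = torch.exp(s_lse - max_lse)
    output = (prefix_output * p_w + suffix_output * s_w) / (p_w + s_w)
    output_lse = (max_lse + torch.log(p_w + s_w)).squeeze(-1).t()
    return output, output_lse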