diff --git a/csrc/attention/merge_attn_states.cu b/csrc/attention/merge_attn_states.cu
index 14e5edd7e2..6bee9e4ce1 100644
--- a/csrc/attention/merge_attn_states.cu
+++ b/csrc/attention/merge_attn_states.cu
@@ -143,6 +143,14 @@ void merge_attn_states_launcher(torch::Tensor& output,
   const uint pack_size = 16 / sizeof(scalar_t);
   TORCH_CHECK(head_size % pack_size == 0,
               "headsize must be multiple of pack_size:", pack_size);
+  TORCH_CHECK(output.stride(-2) == head_size && output.stride(-1) == 1,
+              "output heads must be contiguous in memory");
+  TORCH_CHECK(
+      prefix_output.stride(-2) == head_size && prefix_output.stride(-1) == 1,
+      "prefix_output heads must be contiguous in memory");
+  TORCH_CHECK(
+      suffix_output.stride(-2) == head_size && suffix_output.stride(-1) == 1,
+      "suffix_output heads must be contiguous in memory");
   float* output_lse_ptr = nullptr;
   if (output_lse.has_value()) {
     output_lse_ptr = output_lse.value().data_ptr<float>();
diff --git a/vllm/attention/backends/mla/common.py b/vllm/attention/backends/mla/common.py
index d484626849..1007140ef3 100644
--- a/vllm/attention/backends/mla/common.py
+++ b/vllm/attention/backends/mla/common.py
@@ -1093,10 +1093,6 @@ class MLACommonImpl(MLAAttentionImpl[T], Generic[T]):
         if isinstance(attn_out, tuple):
             attn_out, *rest = attn_out
 
-        # unpad if necessary
-        if self._pad_v:
-            attn_out = attn_out[..., :v.shape[-1]]
-
         # Remain consistent with old `flash_attn_varlen_func` where there
         # is only one output tensor if `return_softmax_lse` is False.
         if return_softmax_lse:
@@ -1294,6 +1290,10 @@ class MLACommonImpl(MLAAttentionImpl[T], Generic[T]):
             suffix_lse=suffix_lse,
         )
 
+        # unpad if necessary
+        if self._pad_v:
+            output = output[..., :v.shape[-1]]
+
         return output.flatten(start_dim=-2)
 
     @abstractmethod
diff --git a/vllm/v1/attention/backends/mla/common.py b/vllm/v1/attention/backends/mla/common.py
index 83e1811165..1edfab26b6 100644
--- a/vllm/v1/attention/backends/mla/common.py
+++ b/vllm/v1/attention/backends/mla/common.py
@@ -653,10 +653,6 @@ class MLACommonImpl(MLAAttentionImpl[M], Generic[M]):
         if isinstance(attn_out, tuple):
             attn_out, lse = attn_out[0], attn_out[1]
 
-        # unpad if necessary
-        if self._pad_v:
-            attn_out = attn_out[..., :v.shape[-1]]
-
         # Remain consistent with old `flash_attn_varlen_func` where there
         # is only one output tensor if `return_softmax_lse` is False.
         if return_softmax_lse:
@@ -839,6 +835,10 @@ class MLACommonImpl(MLAAttentionImpl[M], Generic[M]):
             suffix_lse=suffix_lse,
         )
 
+        # unpad if necessary
+        if self._pad_v:
+            output = output[..., :v.shape[-1]]
+
         return output.flatten(start_dim=-2)
 
     @abstractmethod
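
A minimal PyTorch sketch of why the unpad is moved to after the prefix/suffix merge (shapes and sizes here are illustrative only, not taken from the diff): slicing the padded value dimension off before the merge leaves a view whose stride(-2) no longer equals head_size, which the new TORCH_CHECKs in merge_attn_states would reject; slicing the merged result instead keeps the merge inputs head-contiguous.

import torch

# Illustrative shapes: 4 tokens, 2 heads, v head size 128 padded up to 192
# when _pad_v is in effect. These numbers are assumptions for the example.
num_tokens, num_heads, v_head_size, padded_head_size = 4, 2, 128, 192

prefix_output = torch.randn(num_tokens, num_heads, padded_head_size)
suffix_output = torch.randn(num_tokens, num_heads, padded_head_size)

# Unpadding first: the slice is a view, so stride(-2) stays 192 while the
# last dim shrinks to 128, violating the "heads contiguous" requirement.
sliced = prefix_output[..., :v_head_size]
assert sliced.stride(-2) != sliced.shape[-1]

# Unpadding after the merge: the still-padded tensors satisfy the checks
# (stride(-2) == head_size, stride(-1) == 1) when passed to the kernel.
assert prefix_output.stride(-2) == prefix_output.shape[-1]
assert prefix_output.stride(-1) == 1

merged = torch.empty_like(prefix_output)  # stand-in for the merged output
output = merged[..., :v_head_size]        # unpad only the final result
print(output.flatten(start_dim=-2).shape)  # torch.Size([4, 256])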