[Model] Mamba2 preallocate SSM output tensor to avoid d2d copy overhead (#21075)

Signed-off-by: Chih-Chieh Yang <7364402+cyang49@users.noreply.github.com>
Signed-off-by: Chih-Chieh-Yang <7364402+cyang49@users.noreply.github.com>
This commit is contained in:
Chih-Chieh Yang
2025-08-02 04:59:34 -04:00
committed by GitHub
parent 25373b6c6c
commit b690e34824
9 changed files with 144 additions and 118 deletions

View File

@ -365,6 +365,7 @@ def test_selective_state_update(dim, dstate, has_z, itype):
batch_size = 1
state = torch.randn(batch_size, dim, dstate, dtype=itype, device=device)
x = torch.randn(batch_size, dim, device=device, dtype=itype)
out = torch.empty_like(x)
dt = torch.randn(batch_size, dim, device=device, dtype=itype)
dt_bias = torch.rand(dim, device=device) - 4.0
A = -torch.rand(dim, dstate, device=device) - 1.0
@ -373,16 +374,17 @@ def test_selective_state_update(dim, dstate, has_z, itype):
D = torch.randn(dim, device=device)
z = torch.randn_like(x) if has_z else None
state_ref = state.detach().clone()
out = selective_state_update(state,
x,
dt,
A,
B,
C,
D=D,
z=z,
dt_bias=dt_bias,
dt_softplus=True)
selective_state_update(state,
x,
dt,
A,
B,
C,
D=D,
z=z,
dt_bias=dt_bias,
dt_softplus=True,
out=out)
out_ref = selective_state_update_ref(state_ref,
x,
dt,
@ -581,6 +583,7 @@ def test_selective_state_update_with_batch_indices(with_padding, dim, dstate,
],
dim=0)
x = torch.randn(padded_batch_size, dim, device=device, dtype=itype)
out = torch.empty_like(x)
dt = torch.randn(padded_batch_size, dim, device=device, dtype=itype)
dt_bias = torch.rand(dim, device=device) - 4.0
A = -torch.rand(dim, dstate, device=device) - 1.0
@ -590,18 +593,19 @@ def test_selective_state_update_with_batch_indices(with_padding, dim, dstate,
z = torch.randn_like(x) if has_z else None
state_ref = state[state_indices, :].clone()
state_before = state.clone()
out = selective_state_update(state,
x,
dt,
A,
B,
C,
D=D,
z=z,
dt_bias=dt_bias,
dt_softplus=True,
state_batch_indices=padded_state_indices,
pad_slot_id=PAD_SLOT_ID)
selective_state_update(state,
x,
dt,
A,
B,
C,
D=D,
z=z,
dt_bias=dt_bias,
dt_softplus=True,
state_batch_indices=padded_state_indices,
pad_slot_id=PAD_SLOT_ID,
out=out)
out_ref = selective_state_update_ref(state_ref,
x[:batch_size],
dt[:batch_size],
@ -665,6 +669,7 @@ def test_selective_state_update_with_heads_with_batch_indices(
dtype=torch.int32, device=device)
x = torch.randn(batch_size, nheads, headdim, device=device, dtype=itype)
out = torch.empty_like(x)
if not tie_hdim:
dt = torch.randn(batch_size,
nheads,
@ -691,18 +696,19 @@ def test_selective_state_update_with_heads_with_batch_indices(
C = torch.randn(batch_size, ngroups, dstate, device=device)
z = torch.randn_like(x) if has_z else None
state_ref = state[state_indices, :].detach().clone()
out = selective_state_update(state,
x,
dt,
A,
B,
C,
D=D,
z=z,
dt_bias=dt_bias,
dt_softplus=True,
state_batch_indices=state_indices,
pad_slot_id=PAD_SLOT_ID)
selective_state_update(state,
x,
dt,
A,
B,
C,
D=D,
z=z,
dt_bias=dt_bias,
dt_softplus=True,
state_batch_indices=state_indices,
pad_slot_id=PAD_SLOT_ID,
out=out)
out_ref = selective_state_update_ref(state_ref,
x,
dt,

View File

@ -212,15 +212,16 @@ def test_mamba_chunk_scan_single_example(d_head, n_heads, seq_len_chunk_size,
Y_min, final_state_min = ssd_minimal_discrete(X * dt.unsqueeze(-1), A * dt,
B, C, chunk_size)
Y, final_state = mamba_chunk_scan_combined(X,
dt,
A,
B,
C,
chunk_size,
D=None,
return_final_states=True)
Y = torch.empty_like(X)
final_state = mamba_chunk_scan_combined(X,
dt,
A,
B,
C,
chunk_size,
D=None,
return_final_states=True,
out=Y)
# just test the last in sequence
torch.testing.assert_close(Y[:, -1], Y_min[:, -1], atol=atol, rtol=rtol)
@ -292,7 +293,8 @@ def test_mamba_chunk_scan_cont_batch(d_head, n_heads, seq_len_chunk_size_cases,
_query_start_loc_to_chunk_indices_offsets(
cu_seqlens, chunk_size, cu_seqlens[-1])
Y, new_states = mamba_chunk_scan_combined(
Y = torch.empty_like(X)
new_states = mamba_chunk_scan_combined(
X,
dt,
A,
@ -306,6 +308,7 @@ def test_mamba_chunk_scan_cont_batch(d_head, n_heads, seq_len_chunk_size_cases,
chunk_offsets=chunk_offsets,
return_varlen_states=True,
initial_states=states,
out=Y,
)
# just test the last in sequence

View File

@ -220,7 +220,8 @@ class MambaMixer(CustomOp):
has_initial_state=attn_metadata.context_lens_tensor > 0,
query_start_loc=attn_metadata.query_start_loc)
else:
scan_outputs = selective_state_update(
scan_outputs = torch.empty_like(hidden_states.transpose(0, 1))
selective_state_update(
mamba_cache_params.ssm_state,
hidden_states.transpose(0, 1),
discrete_time_step.transpose(0, 1),
@ -231,7 +232,8 @@ class MambaMixer(CustomOp):
gate.transpose(0, 1),
time_proj_bias,
dt_softplus=True,
state_batch_indices=mamba_cache_params.state_indices_tensor)
state_batch_indices=mamba_cache_params.state_indices_tensor,
out=scan_outputs)
scan_outputs = scan_outputs.transpose(0, 1)
# 4. Final linear projection

View File

@ -541,7 +541,6 @@ class MambaMixer2(MambaBase, CustomOp):
# NOTE: V0 put prefill before decode, v1 puts decode before prefill
# Separate prefill and decode by splitting varlen input
# Split along token dimension
# NOTE: V0 put prefill before decode, v1 puts decode before prefill
if envs.VLLM_USE_V1:
hidden_states_B_C_d, hidden_states_B_C_p = torch.split(
hidden_states_B_C[:num_actual_tokens],
@ -583,7 +582,28 @@ class MambaMixer2(MambaBase, CustomOp):
1]
if has_prefill else None)
ssd_output_list = []
# Preallocate output tensor to avoid memcpy cost for merging prefill
# and decode outputs
preallocated_ssm_out = torch.empty(
[
num_prefill_tokens + num_decodes,
(self.num_heads // self.tp_size) * self.head_dim
],
dtype=hidden_states.dtype,
device=hidden_states.device,
)
if envs.VLLM_USE_V1:
preallocated_ssm_out_d, preallocated_ssm_out_p = torch.split(
preallocated_ssm_out,
[num_decodes, num_prefill_tokens],
dim=0,
)
else:
preallocated_ssm_out_p, preallocated_ssm_out_d = torch.split(
preallocated_ssm_out,
[num_prefill_tokens, num_decodes],
dim=0,
)
# Process prefill requests
if has_prefill:
@ -623,7 +643,8 @@ class MambaMixer2(MambaBase, CustomOp):
has_initial_states_p[:num_prefills, None, None, None],
ssm_state[state_indices_tensor_p], 0)
scan_output, varlen_state = mamba_chunk_scan_combined(
# NOTE: final output is an in-place update of out tensor
varlen_state = mamba_chunk_scan_combined(
hidden_states_p.view(1, num_prefill_tokens,
self.num_heads // self.tp_size,
self.head_dim),
@ -646,15 +667,14 @@ class MambaMixer2(MambaBase, CustomOp):
return_final_states=False,
dt_softplus=True,
dt_limit=(0.0, float("inf")),
out=preallocated_ssm_out_p.view(1, num_prefill_tokens, -1,
self.head_dim),
)
# update ssm states
# - varlen state is a (num_prefills, nheads, headdim, dstate) tensor
ssm_state[state_indices_tensor_p] = varlen_state
# - reshape
ssd_output_list.append(scan_output.view(num_prefill_tokens, -1))
# Process decode requests
if has_decode:
# 2. Convolution sequence transformation
@ -684,8 +704,8 @@ class MambaMixer2(MambaBase, CustomOp):
# - the hidden is reshaped into (bs, num_heads, head_dim)
# - mamba_cache_params.ssm_state's slots will be selected
# using state_indices_tensor_d
hidden_states_d = selective_state_update(
# NOTE: final output is an in-place update of out tensor
selective_state_update(
ssm_state,
hidden_states_d,
dt_d,
@ -697,26 +717,16 @@ class MambaMixer2(MambaBase, CustomOp):
dt_bias=dt_bias,
dt_softplus=True,
state_batch_indices=state_indices_tensor_d,
out=preallocated_ssm_out_d.view(num_decodes, -1,
self.head_dim),
)
if envs.VLLM_USE_V1:
ssd_output_list.insert(
0,
hidden_states_d.view(-1, (self.num_heads // self.tp_size) *
self.head_dim))
else:
ssd_output_list.append(
hidden_states_d.view(-1, (self.num_heads // self.tp_size) *
self.head_dim))
# Merge prefill and decode outputs before passing to gated MLP
hidden_states = torch.vstack(ssd_output_list)
# 4. gated MLP
# GatedRMSNorm internally applying SiLU to the gate
# SiLU is applied internally before normalization, unlike standard
# norm usage
hidden_states = self.norm(hidden_states, gate[:num_actual_tokens])
hidden_states = self.norm(preallocated_ssm_out,
gate[:num_actual_tokens])
# 5. Final linear projection
output[:num_actual_tokens], _ = self.out_proj(hidden_states)

View File

@ -205,7 +205,8 @@ def selective_state_update(state,
dt_bias=None,
dt_softplus=False,
state_batch_indices=None,
pad_slot_id=PAD_SLOT_ID):
pad_slot_id=PAD_SLOT_ID,
out=None):
"""
Argument:
state: (batch, dim, dstate) or (batch, nheads, dim, dstate)
@ -223,10 +224,9 @@ def selective_state_update(state,
for example: cache_indices = [pad_slot_id, 1, 20, pad_slot_id]
in this case, the kernel will not process entries at
indices 0 and 3
Return:
out: (batch, dim) or (batch, nheads, dim)
out: Preallocated ssm output tensor. Assume same shape as x.
In-place updated.
"""
has_heads = state.dim() > 3
if state.dim() == 3:
state = state.unsqueeze(1)
if x.dim() == 2:
@ -245,6 +245,8 @@ def selective_state_update(state,
z = z.unsqueeze(1)
if dt_bias is not None and dt_bias.dim() == 1:
dt_bias = dt_bias.unsqueeze(0)
if out.dim() == 2:
out = out.unsqueeze(1)
_, nheads, dim, dstate = state.shape
batch = x.shape[0]
@ -264,7 +266,8 @@ def selective_state_update(state,
assert dt_bias.shape == (nheads, dim)
if state_batch_indices is not None:
assert state_batch_indices.shape == (batch, )
out = torch.empty_like(x)
assert out.shape == x.shape
grid = lambda META: (triton.cdiv(dim, META['BLOCK_SIZE_M']), batch, nheads)
z_strides = ((z.stride(0), z.stride(1), z.stride(2)) if z is not None else
(0, 0, 0))
@ -328,9 +331,6 @@ def selective_state_update(state,
BLOCK_SIZE_M,
num_warps=num_warps,
)
if not has_heads:
out = out.squeeze(1)
return out
def selective_scan_fn(u,

View File

@ -454,6 +454,7 @@ def _chunk_scan_fwd(
chunk_indices=None,
chunk_offsets=None,
initial_states=None,
out=None,
):
batch, seqlen, nheads, headdim = x.shape
_, _, nchunks, chunk_size = dt.shape
@ -483,20 +484,10 @@ def _chunk_scan_fwd(
else:
chunk_indices, chunk_offsets = None, None
# Allocates output.
out = torch.empty(batch,
seqlen,
nheads,
headdim,
device=x.device,
dtype=x.dtype)
assert out.shape == x.shape
if z is not None:
out_x = torch.empty(batch,
seqlen,
nheads,
headdim,
device=x.device,
dtype=x.dtype)
out_x = torch.empty_like(x)
assert out_x.stride() == out.stride()
else:
out_x = None
@ -579,4 +570,4 @@ def _chunk_scan_fwd(
IS_TRITON_22=TRITON_22,
HAS_INITSTATES=initial_states is not None,
)
return out, out_x
return out_x

View File

@ -36,7 +36,8 @@ def _mamba_chunk_scan_combined_fwd(x,
chunk_offsets=None,
cu_seqlens=None,
dt_softplus=False,
dt_limit=(0.0, float("inf"))):
dt_limit=(0.0, float("inf")),
out=None):
batch, seqlen, nheads, headdim = x.shape
_, _, ngroups, dstate = B.shape
assert nheads % ngroups == 0
@ -134,7 +135,7 @@ def _mamba_chunk_scan_combined_fwd(x,
# - in each (pseudo) chunk, we detect if the previous (pseudo) chunk had
# a seq_idx change, in which case we take states information from
# init_states.
out, out_x = _chunk_scan_fwd(
out_x = _chunk_scan_fwd(
CB,
x,
dt,
@ -147,9 +148,10 @@ def _mamba_chunk_scan_combined_fwd(x,
chunk_indices=chunk_indices,
chunk_offsets=chunk_offsets,
initial_states=initial_states,
out=out,
)
if cu_seqlens is None:
return out, out_x, dt, dA_cumsum, states, final_states
return out_x, dt, dA_cumsum, states, final_states
else:
assert batch == 1, "passing cu_seqlens to get the varlen states is only supported if batch dimension is 1"
varlen_states = chunk_state_varlen(
@ -161,7 +163,7 @@ def _mamba_chunk_scan_combined_fwd(x,
states.squeeze(0),
initial_states=initial_states,
)
return out, out_x, dt, dA_cumsum, states, final_states, varlen_states
return out_x, dt, dA_cumsum, states, final_states, varlen_states
def mamba_chunk_scan_combined(x,
@ -180,6 +182,7 @@ def mamba_chunk_scan_combined(x,
cu_seqlens=None,
dt_softplus=False,
dt_limit=(0.0, float("inf")),
out=None,
return_final_states=False,
return_varlen_states=False):
"""
@ -197,15 +200,14 @@ def mamba_chunk_scan_combined(x,
seq_idx: (batch, seqlen)
cu_seqlens: (num_sequences + 1) or None, only used if return_varlen_states is True
dt_softplus: Whether to apply softplus to dt
Return:
out: (batch, seqlen, nheads, headdim)
out: Preallocated output tensor
"""
if not return_varlen_states:
cu_seqlens = None
else:
assert cu_seqlens is not None, "cu_seqlens must be provided if return_varlen_states is True"
out, out_x, dt_out, dA_cumsum, states, final_states, *rest = _mamba_chunk_scan_combined_fwd(
out_x, dt_out, dA_cumsum, states, final_states, *rest = _mamba_chunk_scan_combined_fwd(
x,
dt,
A,
@ -221,12 +223,14 @@ def mamba_chunk_scan_combined(x,
chunk_offsets=chunk_offsets,
cu_seqlens=cu_seqlens,
dt_softplus=dt_softplus,
dt_limit=dt_limit)
dt_limit=dt_limit,
out=out)
if not return_varlen_states:
return out if not return_final_states else (out, final_states)
if not return_final_states:
return
else:
return final_states
else:
varlen_states = rest[0]
return (out,
varlen_states) if not return_final_states else (out,
final_states,
return (varlen_states) if not return_final_states else (final_states,
varlen_states)

View File

@ -387,7 +387,8 @@ class Phi4Mamba(nn.Module):
has_initial_state=attn_metadata.context_lens_tensor > 0,
query_start_loc=attn_metadata.query_start_loc)
else:
scan_outputs = selective_state_update(
scan_outputs = torch.empty_like(hidden_states.transpose(0, 1))
selective_state_update(
mamba_cache_params.ssm_state,
hidden_states.transpose(0, 1),
discrete_time_step.transpose(0, 1),
@ -400,7 +401,8 @@ class Phi4Mamba(nn.Module):
None if self.yoco_kv else gate.transpose(0, 1),
time_proj_bias,
dt_softplus=True,
state_batch_indices=mamba_cache_params.state_indices_tensor)
state_batch_indices=mamba_cache_params.state_indices_tensor,
out=scan_outputs)
scan_outputs = scan_outputs.transpose(0, 1)
# 4. Final linear projection

View File

@ -257,7 +257,21 @@ class Plamo2MambaMixer(nn.Module):
query_start_loc_p = (attn_metadata.query_start_loc[:num_prefills + 1]
if has_prefill else None)
ssd_output_list = []
# Preallocate output tensor to avoid memcpy cost for merging prefill
# and decode outputs
preallocated_ssm_out = torch.empty(
[
num_prefill_tokens + num_decodes,
(self.num_heads // self.tp_size) * self.head_dim
],
dtype=hidden_states.dtype,
device=hidden_states.device,
)
preallocated_ssm_out_p, preallocated_ssm_out_d = torch.split(
preallocated_ssm_out,
[num_prefill_tokens, num_decodes],
dim=0,
)
# Process prefill requests
if has_prefill:
@ -290,7 +304,7 @@ class Plamo2MambaMixer(nn.Module):
initial_states = torch.where(
mamba2_metadata.has_initial_states[:, None, None, None],
mamba_cache_params.ssm_state[state_indices_tensor_p], 0)
scan_output, varlen_state = mamba_chunk_scan_combined(
varlen_state = mamba_chunk_scan_combined(
hidden_states_p.view(1, num_prefill_tokens,
self.num_heads // self.tp_size,
self.head_dim),
@ -312,15 +326,14 @@ class Plamo2MambaMixer(nn.Module):
return_final_states=False,
dt_softplus=True,
dt_limit=(0.0, float("inf")),
out=preallocated_ssm_out_p.view(1, num_prefill_tokens, -1,
self.head_dim),
)
# update ssm states
# - varlen state is a (batch, nheads, headdim, dstate) tensor
mamba_cache_params.ssm_state[state_indices_tensor_p] = varlen_state
# - reshape
ssd_output_list.append(scan_output.view(num_prefill_tokens, -1))
# Process decode requests
if has_decode:
# 2. Convolution sequence transformation
@ -349,8 +362,7 @@ class Plamo2MambaMixer(nn.Module):
# - the hidden is reshaped into (bs, num_heads, head_dim)
# - mamba_cache_params.ssm_state's slots will be selected
# using state_indices_tensor_d
hidden_states_d = selective_state_update(
selective_state_update(
mamba_cache_params.ssm_state,
hidden_states_d,
dt,
@ -362,17 +374,13 @@ class Plamo2MambaMixer(nn.Module):
dt_bias=dt_bias,
dt_softplus=True,
state_batch_indices=state_indices_tensor_d,
out=preallocated_ssm_out_d.view(num_decodes, -1,
self.head_dim),
)
assert self.num_heads % self.tp_size == 0
ssd_output_list.append(
hidden_states_d.view(-1, (self.num_heads // self.tp_size) *
self.head_dim))
# Merge prefill and decode outputs before passing to MLP
hidden_states = torch.vstack(ssd_output_list)
# 4. Final linear projection
out = self.out_proj(hidden_states)
out = self.out_proj(preallocated_ssm_out)
return out