Mirror of https://github.com/vllm-project/vllm.git (synced 2025-10-20 14:53:52 +08:00)
[Model] Mamba2 preallocate SSM output tensor to avoid d2d copy overhead (#21075)
Signed-off-by: Chih-Chieh Yang <7364402+cyang49@users.noreply.github.com>
Signed-off-by: Chih-Chieh-Yang <7364402+cyang49@users.noreply.github.com>
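In short: the kernels used to allocate their own output tensor and return it, and callers then merged prefill and decode outputs with `torch.vstack`, paying an extra device-to-device copy. After this change the caller preallocates `out` and the kernels write into it in place. A minimal sketch of the calling-convention change (the update bodies are stand-ins, not the real Triton kernel):

```python
import torch

# Before: the wrapper allocated and returned its own output tensor.
def update_old(state, x):
    out = torch.empty_like(x)
    out.copy_(torch.sigmoid(x))  # stand-in for the real SSM update
    return out

# After: the caller owns `out`; the wrapper writes into it in place.
def update_new(state, x, out):
    assert out.shape == x.shape
    out.copy_(torch.sigmoid(x))  # same math, no allocation here

state = torch.zeros(4, 8, 16)
x = torch.randn(4, 8)
out = torch.empty_like(x)        # preallocated once by the caller
update_new(state, x, out=out)
torch.testing.assert_close(out, update_old(state, x))
```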
@@ -365,6 +365,7 @@ def test_selective_state_update(dim, dstate, has_z, itype):
     batch_size = 1
     state = torch.randn(batch_size, dim, dstate, dtype=itype, device=device)
     x = torch.randn(batch_size, dim, device=device, dtype=itype)
+    out = torch.empty_like(x)
     dt = torch.randn(batch_size, dim, device=device, dtype=itype)
     dt_bias = torch.rand(dim, device=device) - 4.0
     A = -torch.rand(dim, dstate, device=device) - 1.0
@@ -373,16 +374,17 @@ def test_selective_state_update(dim, dstate, has_z, itype):
     D = torch.randn(dim, device=device)
     z = torch.randn_like(x) if has_z else None
     state_ref = state.detach().clone()
-    out = selective_state_update(state,
-                                 x,
-                                 dt,
-                                 A,
-                                 B,
-                                 C,
-                                 D=D,
-                                 z=z,
-                                 dt_bias=dt_bias,
-                                 dt_softplus=True)
+    selective_state_update(state,
+                           x,
+                           dt,
+                           A,
+                           B,
+                           C,
+                           D=D,
+                           z=z,
+                           dt_bias=dt_bias,
+                           dt_softplus=True,
+                           out=out)
     out_ref = selective_state_update_ref(state_ref,
                                          x,
                                          dt,
@@ -581,6 +583,7 @@ def test_selective_state_update_with_batch_indices(with_padding, dim, dstate,
         ],
         dim=0)
     x = torch.randn(padded_batch_size, dim, device=device, dtype=itype)
+    out = torch.empty_like(x)
     dt = torch.randn(padded_batch_size, dim, device=device, dtype=itype)
     dt_bias = torch.rand(dim, device=device) - 4.0
     A = -torch.rand(dim, dstate, device=device) - 1.0
@@ -590,18 +593,19 @@ def test_selective_state_update_with_batch_indices(with_padding, dim, dstate,
     z = torch.randn_like(x) if has_z else None
     state_ref = state[state_indices, :].clone()
     state_before = state.clone()
-    out = selective_state_update(state,
-                                 x,
-                                 dt,
-                                 A,
-                                 B,
-                                 C,
-                                 D=D,
-                                 z=z,
-                                 dt_bias=dt_bias,
-                                 dt_softplus=True,
-                                 state_batch_indices=padded_state_indices,
-                                 pad_slot_id=PAD_SLOT_ID)
+    selective_state_update(state,
+                           x,
+                           dt,
+                           A,
+                           B,
+                           C,
+                           D=D,
+                           z=z,
+                           dt_bias=dt_bias,
+                           dt_softplus=True,
+                           state_batch_indices=padded_state_indices,
+                           pad_slot_id=PAD_SLOT_ID,
+                           out=out)
     out_ref = selective_state_update_ref(state_ref,
                                          x[:batch_size],
                                          dt[:batch_size],
@@ -665,6 +669,7 @@ def test_selective_state_update_with_heads_with_batch_indices(
                                  dtype=torch.int32, device=device)

     x = torch.randn(batch_size, nheads, headdim, device=device, dtype=itype)
+    out = torch.empty_like(x)
     if not tie_hdim:
         dt = torch.randn(batch_size,
                          nheads,
@@ -691,18 +696,19 @@ def test_selective_state_update_with_heads_with_batch_indices(
     C = torch.randn(batch_size, ngroups, dstate, device=device)
     z = torch.randn_like(x) if has_z else None
     state_ref = state[state_indices, :].detach().clone()
-    out = selective_state_update(state,
-                                 x,
-                                 dt,
-                                 A,
-                                 B,
-                                 C,
-                                 D=D,
-                                 z=z,
-                                 dt_bias=dt_bias,
-                                 dt_softplus=True,
-                                 state_batch_indices=state_indices,
-                                 pad_slot_id=PAD_SLOT_ID)
+    selective_state_update(state,
+                           x,
+                           dt,
+                           A,
+                           B,
+                           C,
+                           D=D,
+                           z=z,
+                           dt_bias=dt_bias,
+                           dt_softplus=True,
+                           state_batch_indices=state_indices,
+                           pad_slot_id=PAD_SLOT_ID,
+                           out=out)
     out_ref = selective_state_update_ref(state_ref,
                                          x,
                                          dt,

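The updated tests all follow the same pattern: allocate `out` with `torch.empty_like(x)`, pass it through the new `out=` keyword, and compare the in-place result against the reference. A self-contained sketch of that pattern with a stand-in kernel:

```python
import torch

def fake_selective_state_update(x, out):
    # Stand-in for the Triton kernel: writes its result in place.
    torch.sigmoid(x, out=out)

x = torch.randn(2, 4)
out = torch.empty_like(x)          # preallocated, as in the updated tests
fake_selective_state_update(x, out=out)
torch.testing.assert_close(out, torch.sigmoid(x))
```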
@@ -212,15 +212,16 @@ def test_mamba_chunk_scan_single_example(d_head, n_heads, seq_len_chunk_size,

     Y_min, final_state_min = ssd_minimal_discrete(X * dt.unsqueeze(-1), A * dt,
                                                   B, C, chunk_size)

-    Y, final_state = mamba_chunk_scan_combined(X,
-                                               dt,
-                                               A,
-                                               B,
-                                               C,
-                                               chunk_size,
-                                               D=None,
-                                               return_final_states=True)
+    Y = torch.empty_like(X)
+    final_state = mamba_chunk_scan_combined(X,
+                                            dt,
+                                            A,
+                                            B,
+                                            C,
+                                            chunk_size,
+                                            D=None,
+                                            return_final_states=True,
+                                            out=Y)

     # just test the last in sequence
     torch.testing.assert_close(Y[:, -1], Y_min[:, -1], atol=atol, rtol=rtol)
@@ -292,7 +293,8 @@ def test_mamba_chunk_scan_cont_batch(d_head, n_heads, seq_len_chunk_size_cases,
         _query_start_loc_to_chunk_indices_offsets(
             cu_seqlens, chunk_size, cu_seqlens[-1])

-    Y, new_states = mamba_chunk_scan_combined(
+    Y = torch.empty_like(X)
+    new_states = mamba_chunk_scan_combined(
         X,
         dt,
         A,
@@ -306,6 +308,7 @@ def test_mamba_chunk_scan_cont_batch(d_head, n_heads, seq_len_chunk_size_cases,
         chunk_offsets=chunk_offsets,
         return_varlen_states=True,
         initial_states=states,
+        out=Y,
     )

     # just test the last in sequence

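Note the changed return contract in these tests: `mamba_chunk_scan_combined` no longer returns `Y`; the scan output lands in the preallocated `out` tensor, and only the requested states come back. A sketch of the new shape of a call (stand-in scan, shapes taken from the tests):

```python
import torch

def chunk_scan_sketch(X, out, return_final_states=False):
    # Stand-in for mamba_chunk_scan_combined: the result goes into `out`.
    out.copy_(X.cumsum(dim=1))
    return out[:, -1].clone() if return_final_states else None

X = torch.randn(1, 8, 4, 16)       # (batch, seqlen, nheads, headdim)
Y = torch.empty_like(X)
final_state = chunk_scan_sketch(X, out=Y, return_final_states=True)
assert final_state.shape == (1, 4, 16)
```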
@@ -220,7 +220,8 @@ class MambaMixer(CustomOp):
                 has_initial_state=attn_metadata.context_lens_tensor > 0,
                 query_start_loc=attn_metadata.query_start_loc)
         else:
-            scan_outputs = selective_state_update(
+            scan_outputs = torch.empty_like(hidden_states.transpose(0, 1))
+            selective_state_update(
                 mamba_cache_params.ssm_state,
                 hidden_states.transpose(0, 1),
                 discrete_time_step.transpose(0, 1),
@@ -231,7 +232,8 @@ class MambaMixer(CustomOp):
                 gate.transpose(0, 1),
                 time_proj_bias,
                 dt_softplus=True,
-                state_batch_indices=mamba_cache_params.state_indices_tensor)
+                state_batch_indices=mamba_cache_params.state_indices_tensor,
+                out=scan_outputs)
             scan_outputs = scan_outputs.transpose(0, 1)

         # 4. Final linear projection

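One detail worth calling out in this decode path: `torch.empty_like(hidden_states.transpose(0, 1))` allocates a buffer shaped like the transposed view, so the kernel writes it directly and the trailing `.transpose(0, 1)` is a zero-copy view. A small demonstration of why that transpose is free:

```python
import torch

hidden_states = torch.randn(16, 3)                              # (dim, batch)
scan_outputs = torch.empty_like(hidden_states.transpose(0, 1))  # (batch, dim)
scan_outputs.copy_(hidden_states.transpose(0, 1))  # stand-in for the kernel
result = scan_outputs.transpose(0, 1)              # a view, not a copy
assert result.shape == hidden_states.shape
assert result.data_ptr() == scan_outputs.data_ptr()  # shared storage
```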
@@ -541,7 +541,6 @@ class MambaMixer2(MambaBase, CustomOp):
-        # NOTE: V0 put prefill before decode, v1 puts decode before prefill
         # Separate prefill and decode by splitting varlen input
         # Split along token dimension
         # NOTE: V0 put prefill before decode, v1 puts decode before prefill
         if envs.VLLM_USE_V1:
             hidden_states_B_C_d, hidden_states_B_C_p = torch.split(
                 hidden_states_B_C[:num_actual_tokens],
@@ -583,7 +582,28 @@ class MambaMixer2(MambaBase, CustomOp):
                                                            1]
                              if has_prefill else None)

-        ssd_output_list = []
+        # Preallocate output tensor to avoid memcpy cost for merging prefill
+        # and decode outputs
+        preallocated_ssm_out = torch.empty(
+            [
+                num_prefill_tokens + num_decodes,
+                (self.num_heads // self.tp_size) * self.head_dim
+            ],
+            dtype=hidden_states.dtype,
+            device=hidden_states.device,
+        )
+        if envs.VLLM_USE_V1:
+            preallocated_ssm_out_d, preallocated_ssm_out_p = torch.split(
+                preallocated_ssm_out,
+                [num_decodes, num_prefill_tokens],
+                dim=0,
+            )
+        else:
+            preallocated_ssm_out_p, preallocated_ssm_out_d = torch.split(
+                preallocated_ssm_out,
+                [num_prefill_tokens, num_decodes],
+                dim=0,
+            )

         # Process prefill requests
         if has_prefill:
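The mechanism that makes the preallocation pay off: `torch.split` returns views of the parent tensor, so when the prefill and decode kernels write into `preallocated_ssm_out_p` and `preallocated_ssm_out_d`, they fill disjoint slices of one contiguous buffer and no merge copy is needed afterwards. A minimal demonstration:

```python
import torch

num_decodes, num_prefill_tokens, hidden = 2, 6, 8
ssm_out = torch.empty(num_decodes + num_prefill_tokens, hidden)
out_d, out_p = torch.split(ssm_out, [num_decodes, num_prefill_tokens], dim=0)

out_d.fill_(1.0)   # stand-in for the decode kernel writing in place
out_p.fill_(2.0)   # stand-in for the prefill kernel writing in place

assert out_d.data_ptr() == ssm_out.data_ptr()   # the views share storage
assert torch.equal(ssm_out[num_decodes:],
                   torch.full((num_prefill_tokens, hidden), 2.0))
```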
@@ -623,7 +643,8 @@ class MambaMixer2(MambaBase, CustomOp):
                 has_initial_states_p[:num_prefills, None, None, None],
                 ssm_state[state_indices_tensor_p], 0)

-            scan_output, varlen_state = mamba_chunk_scan_combined(
+            # NOTE: final output is an in-place update of out tensor
+            varlen_state = mamba_chunk_scan_combined(
                 hidden_states_p.view(1, num_prefill_tokens,
                                      self.num_heads // self.tp_size,
                                      self.head_dim),
@@ -646,15 +667,14 @@ class MambaMixer2(MambaBase, CustomOp):
                 return_final_states=False,
                 dt_softplus=True,
                 dt_limit=(0.0, float("inf")),
+                out=preallocated_ssm_out_p.view(1, num_prefill_tokens, -1,
+                                                self.head_dim),
             )

             # update ssm states
             # - varlen state is a (num_prefills, nheads, headdim, dstate) tensor
             ssm_state[state_indices_tensor_p] = varlen_state

-            # - reshape
-            ssd_output_list.append(scan_output.view(num_prefill_tokens, -1))
-
         # Process decode requests
         if has_decode:
             # 2. Convolution sequence transformation
@@ -684,8 +704,8 @@ class MambaMixer2(MambaBase, CustomOp):
             # - the hidden is reshaped into (bs, num_heads, head_dim)
             # - mamba_cache_params.ssm_state's slots will be selected
             #   using state_indices_tensor_d
-
-            hidden_states_d = selective_state_update(
+            # NOTE: final output is an in-place update of out tensor
+            selective_state_update(
                 ssm_state,
                 hidden_states_d,
                 dt_d,
@@ -697,26 +717,16 @@ class MambaMixer2(MambaBase, CustomOp):
                 dt_bias=dt_bias,
                 dt_softplus=True,
                 state_batch_indices=state_indices_tensor_d,
+                out=preallocated_ssm_out_d.view(num_decodes, -1,
+                                                self.head_dim),
             )

-            if envs.VLLM_USE_V1:
-                ssd_output_list.insert(
-                    0,
-                    hidden_states_d.view(-1, (self.num_heads // self.tp_size) *
-                                         self.head_dim))
-            else:
-                ssd_output_list.append(
-                    hidden_states_d.view(-1, (self.num_heads // self.tp_size) *
-                                         self.head_dim))
-
-        # Merge prefill and decode outputs before passing to gated MLP
-        hidden_states = torch.vstack(ssd_output_list)
-
         # 4. gated MLP
         # GatedRMSNorm internally applying SiLU to the gate
         # SiLU is applied internally before normalization, unlike standard
         # norm usage
-        hidden_states = self.norm(hidden_states, gate[:num_actual_tokens])
+        hidden_states = self.norm(preallocated_ssm_out,
+                                  gate[:num_actual_tokens])

         # 5. Final linear projection
         output[:num_actual_tokens], _ = self.out_proj(hidden_states)

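This hunk is where the saving materializes: the old path collected results in `ssd_output_list` and merged them with `torch.vstack`, which allocates a fresh tensor and copies every element device-to-device; the new path hands `preallocated_ssm_out` straight to the norm. A sketch of the two paths:

```python
import torch

a, b = torch.randn(2, 8), torch.randn(6, 8)

# Old path: list + vstack -> a new allocation plus two d2d copies.
merged_old = torch.vstack([a, b])

# New path: one buffer up front; the kernels write into its halves.
merged_new = torch.empty(8, 8)
half_a, half_b = torch.split(merged_new, [2, 6], dim=0)
half_a.copy_(a)    # in the real code the kernel writes here directly
half_b.copy_(b)
torch.testing.assert_close(merged_new, merged_old)
```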
@@ -205,7 +205,8 @@ def selective_state_update(state,
                            dt_bias=None,
                            dt_softplus=False,
                            state_batch_indices=None,
-                           pad_slot_id=PAD_SLOT_ID):
+                           pad_slot_id=PAD_SLOT_ID,
+                           out=None):
     """
     Argument:
         state: (batch, dim, dstate) or (batch, nheads, dim, dstate)
@@ -223,10 +224,9 @@ def selective_state_update(state,
                     for example: cache_indices = [pad_slot_id, 1, 20, pad_slot_id]
                     in this case, the kernel will not process entries at
                     indices 0 and 3
-    Return:
-        out: (batch, dim) or (batch, nheads, dim)
+        out: Preallocated ssm output tensor. Assume same shape as x.
+             In-place updated.
     """
     has_heads = state.dim() > 3
     if state.dim() == 3:
         state = state.unsqueeze(1)
     if x.dim() == 2:
@@ -245,6 +245,8 @@ def selective_state_update(state,
         z = z.unsqueeze(1)
     if dt_bias is not None and dt_bias.dim() == 1:
         dt_bias = dt_bias.unsqueeze(0)
+    if out.dim() == 2:
+        out = out.unsqueeze(1)

     _, nheads, dim, dstate = state.shape
     batch = x.shape[0]
@@ -264,7 +266,8 @@ def selective_state_update(state,
         assert dt_bias.shape == (nheads, dim)
     if state_batch_indices is not None:
         assert state_batch_indices.shape == (batch, )
-    out = torch.empty_like(x)
+    assert out.shape == x.shape

     grid = lambda META: (triton.cdiv(dim, META['BLOCK_SIZE_M']), batch, nheads)
     z_strides = ((z.stride(0), z.stride(1), z.stride(2)) if z is not None else
                  (0, 0, 0))
@@ -328,9 +331,6 @@ def selective_state_update(state,
         BLOCK_SIZE_M,
         num_warps=num_warps,
     )
-    if not has_heads:
-        out = out.squeeze(1)
-    return out


 def selective_scan_fn(u,

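With `out` now a caller-supplied buffer, the wrapper unsqueezes it in lockstep with `x` and asserts matching shapes, and the old `squeeze`/`return` tail is gone. A sketch of that shape handling (the copy is a stand-in for the kernel launch):

```python
import torch

x = torch.randn(4, 32)             # (batch, dim): the head-less case
out = torch.empty_like(x)

x_k, out_k = x, out
if x_k.dim() == 2:
    x_k = x_k.unsqueeze(1)         # (batch, 1, dim)
if out_k.dim() == 2:
    out_k = out_k.unsqueeze(1)     # unsqueeze returns a view of `out`
assert out_k.shape == x_k.shape

out_k.copy_(x_k)                   # stand-in for the Triton kernel write
torch.testing.assert_close(out, x) # the caller's 2-D `out` sees the result
```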
@@ -454,6 +454,7 @@ def _chunk_scan_fwd(
     chunk_indices=None,
     chunk_offsets=None,
     initial_states=None,
+    out=None,
 ):
     batch, seqlen, nheads, headdim = x.shape
     _, _, nchunks, chunk_size = dt.shape
@@ -483,20 +484,10 @@ def _chunk_scan_fwd(
     else:
         chunk_indices, chunk_offsets = None, None

-    # Allocates output.
-    out = torch.empty(batch,
-                      seqlen,
-                      nheads,
-                      headdim,
-                      device=x.device,
-                      dtype=x.dtype)
+    assert out.shape == x.shape

     if z is not None:
-        out_x = torch.empty(batch,
-                            seqlen,
-                            nheads,
-                            headdim,
-                            device=x.device,
-                            dtype=x.dtype)
+        out_x = torch.empty_like(x)
         assert out_x.stride() == out.stride()
     else:
         out_x = None
@@ -579,4 +570,4 @@ def _chunk_scan_fwd(
         IS_TRITON_22=TRITON_22,
         HAS_INITSTATES=initial_states is not None,
     )
-    return out, out_x
+    return out_x

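`_chunk_scan_fwd` now validates the caller's buffer instead of allocating one, and the `out_x` allocation collapses to `torch.empty_like(x)`, which is equivalent to the spelled-out `torch.empty(batch, seqlen, nheads, headdim, ...)` it replaces:

```python
import torch

x = torch.randn(1, 8, 4, 16)
batch, seqlen, nheads, headdim = x.shape
a = torch.empty(batch, seqlen, nheads, headdim, device=x.device, dtype=x.dtype)
b = torch.empty_like(x)            # same shape, dtype, and device
assert a.shape == b.shape and a.dtype == b.dtype and a.device == b.device
```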
@@ -36,7 +36,8 @@ def _mamba_chunk_scan_combined_fwd(x,
                                    chunk_offsets=None,
                                    cu_seqlens=None,
                                    dt_softplus=False,
-                                   dt_limit=(0.0, float("inf"))):
+                                   dt_limit=(0.0, float("inf")),
+                                   out=None):
     batch, seqlen, nheads, headdim = x.shape
     _, _, ngroups, dstate = B.shape
     assert nheads % ngroups == 0
@@ -134,7 +135,7 @@ def _mamba_chunk_scan_combined_fwd(x,
     #   - in each (pseudo) chunk, we detect if the previous (pseudo) chunk had
     #     a seq_idx change, in which case we take states information from
     #     init_states.
-    out, out_x = _chunk_scan_fwd(
+    out_x = _chunk_scan_fwd(
         CB,
         x,
         dt,
@@ -147,9 +148,10 @@ def _mamba_chunk_scan_combined_fwd(x,
         chunk_indices=chunk_indices,
         chunk_offsets=chunk_offsets,
         initial_states=initial_states,
+        out=out,
     )
     if cu_seqlens is None:
-        return out, out_x, dt, dA_cumsum, states, final_states
+        return out_x, dt, dA_cumsum, states, final_states
     else:
         assert batch == 1, "passing cu_seqlens to get the varlen states is only supported if batch dimension is 1"
         varlen_states = chunk_state_varlen(
@@ -161,7 +163,7 @@ def _mamba_chunk_scan_combined_fwd(x,
             states.squeeze(0),
             initial_states=initial_states,
         )
-        return out, out_x, dt, dA_cumsum, states, final_states, varlen_states
+        return out_x, dt, dA_cumsum, states, final_states, varlen_states


 def mamba_chunk_scan_combined(x,
@@ -180,6 +182,7 @@ def mamba_chunk_scan_combined(x,
                               cu_seqlens=None,
                               dt_softplus=False,
                               dt_limit=(0.0, float("inf")),
+                              out=None,
                               return_final_states=False,
                               return_varlen_states=False):
     """
@@ -197,15 +200,14 @@ def mamba_chunk_scan_combined(x,
         seq_idx: (batch, seqlen)
         cu_seqlens: (num_sequences + 1) or None, only used if return_varlen_states is True
         dt_softplus: Whether to apply softplus to dt
-    Return:
-        out: (batch, seqlen, nheads, headdim)
+        out: Preallocated output tensor
     """

     if not return_varlen_states:
         cu_seqlens = None
     else:
         assert cu_seqlens is not None, "cu_seqlens must be provided if return_varlen_states is True"
-    out, out_x, dt_out, dA_cumsum, states, final_states, *rest = _mamba_chunk_scan_combined_fwd(
+    out_x, dt_out, dA_cumsum, states, final_states, *rest = _mamba_chunk_scan_combined_fwd(
         x,
         dt,
         A,
@@ -221,12 +223,14 @@ def mamba_chunk_scan_combined(x,
         chunk_offsets=chunk_offsets,
         cu_seqlens=cu_seqlens,
         dt_softplus=dt_softplus,
-        dt_limit=dt_limit)
+        dt_limit=dt_limit,
+        out=out)
     if not return_varlen_states:
-        return out if not return_final_states else (out, final_states)
+        if not return_final_states:
+            return
+        else:
+            return final_states
     else:
         varlen_states = rest[0]
-        return (out,
-                varlen_states) if not return_final_states else (out,
-                                                                final_states,
-                                                                varlen_states)
+        return (varlen_states) if not return_final_states else (final_states,
                                                                 varlen_states)

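Reading the new control flow off this hunk, `mamba_chunk_scan_combined` now returns only states; the scan result itself always lands in `out`:

```python
# Return contract after this change (as read from the diff):
#
#   return_varlen_states  return_final_states  returns
#   --------------------  -------------------  -----------------------------
#   False                 False                None (result is in `out`)
#   False                 True                 final_states
#   True                  False                varlen_states
#   True                  True                 (final_states, varlen_states)
```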
@@ -387,7 +387,8 @@ class Phi4Mamba(nn.Module):
                 has_initial_state=attn_metadata.context_lens_tensor > 0,
                 query_start_loc=attn_metadata.query_start_loc)
         else:
-            scan_outputs = selective_state_update(
+            scan_outputs = torch.empty_like(hidden_states.transpose(0, 1))
+            selective_state_update(
                 mamba_cache_params.ssm_state,
                 hidden_states.transpose(0, 1),
                 discrete_time_step.transpose(0, 1),
@@ -400,7 +401,8 @@ class Phi4Mamba(nn.Module):
                 None if self.yoco_kv else gate.transpose(0, 1),
                 time_proj_bias,
                 dt_softplus=True,
-                state_batch_indices=mamba_cache_params.state_indices_tensor)
+                state_batch_indices=mamba_cache_params.state_indices_tensor,
+                out=scan_outputs)
             scan_outputs = scan_outputs.transpose(0, 1)

         # 4. Final linear projection

@@ -257,7 +257,21 @@ class Plamo2MambaMixer(nn.Module):
         query_start_loc_p = (attn_metadata.query_start_loc[:num_prefills + 1]
                              if has_prefill else None)

-        ssd_output_list = []
+        # Preallocate output tensor to avoid memcpy cost for merging prefill
+        # and decode outputs
+        preallocated_ssm_out = torch.empty(
+            [
+                num_prefill_tokens + num_decodes,
+                (self.num_heads // self.tp_size) * self.head_dim
+            ],
+            dtype=hidden_states.dtype,
+            device=hidden_states.device,
+        )
+        preallocated_ssm_out_p, preallocated_ssm_out_d = torch.split(
+            preallocated_ssm_out,
+            [num_prefill_tokens, num_decodes],
+            dim=0,
+        )

         # Process prefill requests
         if has_prefill:
@@ -290,7 +304,7 @@ class Plamo2MambaMixer(nn.Module):
             initial_states = torch.where(
                 mamba2_metadata.has_initial_states[:, None, None, None],
                 mamba_cache_params.ssm_state[state_indices_tensor_p], 0)
-            scan_output, varlen_state = mamba_chunk_scan_combined(
+            varlen_state = mamba_chunk_scan_combined(
                 hidden_states_p.view(1, num_prefill_tokens,
                                      self.num_heads // self.tp_size,
                                      self.head_dim),
@@ -312,15 +326,14 @@ class Plamo2MambaMixer(nn.Module):
                 return_final_states=False,
                 dt_softplus=True,
                 dt_limit=(0.0, float("inf")),
+                out=preallocated_ssm_out_p.view(1, num_prefill_tokens, -1,
+                                                self.head_dim),
             )

             # update ssm states
             # - varlen state is a (batch, nheads, headdim, dstate) tensor
             mamba_cache_params.ssm_state[state_indices_tensor_p] = varlen_state

-            # - reshape
-            ssd_output_list.append(scan_output.view(num_prefill_tokens, -1))
-
         # Process decode requests
         if has_decode:
             # 2. Convolution sequence transformation
@@ -349,8 +362,7 @@ class Plamo2MambaMixer(nn.Module):
             # - the hidden is reshaped into (bs, num_heads, head_dim)
             # - mamba_cache_params.ssm_state's slots will be selected
             #   using state_indices_tensor_d
-
-            hidden_states_d = selective_state_update(
+            selective_state_update(
                 mamba_cache_params.ssm_state,
                 hidden_states_d,
                 dt,
@@ -362,17 +374,13 @@ class Plamo2MambaMixer(nn.Module):
                 dt_bias=dt_bias,
                 dt_softplus=True,
                 state_batch_indices=state_indices_tensor_d,
+                out=preallocated_ssm_out_d.view(num_decodes, -1,
+                                                self.head_dim),
             )
-            assert self.num_heads % self.tp_size == 0
-            ssd_output_list.append(
-                hidden_states_d.view(-1, (self.num_heads // self.tp_size) *
-                                     self.head_dim))
-
-        # Merge prefill and decode outputs before passing to MLP
-        hidden_states = torch.vstack(ssd_output_list)

         # 4. Final linear projection
-        out = self.out_proj(hidden_states)
+        out = self.out_proj(preallocated_ssm_out)
         return out