Revert "[Model] Mamba2 Prefill Performance Tweaks: Fixing Flurry of U… (#14848)

Tyler Michael Smith
2025-03-14 23:45:42 -04:00
committed by GitHub
parent acaea3bb07
commit ccf02fcbae


@@ -466,17 +466,10 @@ class MambaMixer2(CustomOp):
         if has_prefill:
             initial_states = None
-            if has_initial_states is not None and torch.any(
-                    has_initial_states):
-                # vectorized ssm_state zero init
-                batched_zero_init_func = torch.vmap(
-                    lambda idx: mamba_cache_params.ssm_state[idx].zero_())
-                batched_zero_init_func(
-                    mamba_cache_params.
-                    state_indices_tensor[~has_initial_states].unsqueeze(
-                        dim=-1), )
+            if has_initial_states is not None and any(has_initial_states):
+                for idx in mamba_cache_params.state_indices_tensor[
+                        ~has_initial_states]:
+                    mamba_cache_params.ssm_state[idx].zero_()
                 initial_states = mamba_cache_params.ssm_state[
                     mamba_cache_params.state_indices_tensor]
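
For context, the restored loop above zeroes the cached SSM-state slots of prefill sequences that carry no initial state, so the chunked scan starts those sequences from a clean state. A minimal, self-contained sketch of the same pattern is shown below; the slot count, shapes, and tensor contents are illustrative assumptions, not vLLM's actual cache layout:

```python
import torch

# Assumed cache layout: one (nheads, headdim, dstate) slot per sequence.
num_slots, nheads, headdim, dstate = 8, 4, 16, 32
ssm_state = torch.randn(num_slots, nheads, headdim, dstate)

# Which cache slot each batched sequence uses, and whether it already
# carries an initial state (e.g. from an earlier chunk).
state_indices_tensor = torch.tensor([5, 2, 7])
has_initial_states = torch.tensor([True, False, False])

# Loop form restored by the revert: zero each slot that has no initial state.
for idx in state_indices_tensor[~has_initial_states]:
    ssm_state[idx].zero_()

# Equivalent batched form (illustration only, not what the commit does):
ssm_state.index_fill_(0, state_indices_tensor[~has_initial_states], 0.0)
```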
@@ -500,17 +493,10 @@ class MambaMixer2(CustomOp):
                 dt_limit=(0.0, float("inf")),
             )
-            # vectorized ssm state update using vmap
-            # the 1d state_indices_tensor needs to be unsqueezed to avoid vmap
-            # limitation which doesn't allow use of `item()`
-            # Note: the lambda capture can happen where ssm_state is initialized
-            # instead of here
-            batched_copy = torch.vmap(
-                lambda idx, source_state: mamba_cache_params.ssm_state[
-                    idx].copy_(source_state))
-            batched_copy(
-                mamba_cache_params.state_indices_tensor.unsqueeze(dim=-1),
-                varlen_state)
+            # update ssm states
+            # - varlen state is a (batch, nheads, headdim, dstate) tensor
+            for i, idx in enumerate(mamba_cache_params.state_indices_tensor):
+                mamba_cache_params.ssm_state[idx].copy_(varlen_state[i])
             # - reshape
             hidden_states = scan_output.view(seq_len, -1)
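
The second hunk restores the per-sequence write-back of each prefill sequence's final scan state (`varlen_state`) into its cache slot. A short sketch of that scatter pattern follows, again with assumed shapes and names rather than vLLM's real cache objects:

```python
import torch

num_slots, nheads, headdim, dstate = 8, 4, 16, 32
ssm_state = torch.zeros(num_slots, nheads, headdim, dstate)
state_indices_tensor = torch.tensor([5, 2, 7])

# Final per-sequence scan state, shaped (batch, nheads, headdim, dstate).
varlen_state = torch.randn(3, nheads, headdim, dstate)

# Loop form restored by the revert: write each sequence's state to its slot.
for i, idx in enumerate(state_indices_tensor):
    ssm_state[idx].copy_(varlen_state[i])

# Equivalent single-call scatter (illustration only, not what the commit does):
ssm_state.index_copy_(0, state_indices_tensor, varlen_state)
```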