[Bugfix] Mamba2 remove bugged initial state condition in chunk scan (#22034)
Signed-off-by: Chih-Chieh-Yang <7364402+cyang49@users.noreply.github.com>
This commit is contained in:
@ -476,15 +476,8 @@ def _chunk_scan_fwd(
|
||||
# with initial states, we need to take care of how
|
||||
# seq_idx crosses the boundaries
|
||||
assert batch == 1, "chunk scan only supports initial states with batch 1"
|
||||
|
||||
if initial_states.shape[0] == 1:
|
||||
# no in this case no point to use initial states
|
||||
initial_states = None
|
||||
else:
|
||||
assert chunk_indices is not None and chunk_offsets is not None, \
|
||||
(
|
||||
"chunk_indices and chunk_offsets should have been set"
|
||||
)
|
||||
else:
|
||||
chunk_indices, chunk_offsets = None, None
|
||||
else:
|
||||
|
Reference in New Issue
Block a user