[V1][Mamba1] - FP32 SSM Kernel Support (#23506)

Signed-off-by: asafg <39553475+Josephasafg@users.noreply.github.com>
This commit is contained in:
Asaf Joseph Gardin
2025-09-02 06:53:00 +03:00
committed by GitHub
parent 0235103cbb
commit 2b41cbbf03
3 changed files with 65 additions and 32 deletions

View File

@ -27,11 +27,12 @@
template<int kNThreads_, int kNItems_, int kNRows_, bool kIsEvenLen_, template<int kNThreads_, int kNItems_, int kNRows_, bool kIsEvenLen_,
bool kIsVariableB_, bool kIsVariableC_, bool kIsVariableB_, bool kIsVariableC_,
bool kHasZ_, bool kVarlen_, typename input_t_, typename weight_t_> bool kHasZ_, bool kVarlen_, typename input_t_, typename weight_t_, typename state_t_>
struct Selective_Scan_fwd_kernel_traits { struct Selective_Scan_fwd_kernel_traits {
static_assert(kNItems_ % 4 == 0); static_assert(kNItems_ % 4 == 0);
using input_t = input_t_; using input_t = input_t_;
using weight_t = weight_t_; using weight_t = weight_t_;
using state_t = state_t_;
static constexpr int kNThreads = kNThreads_; static constexpr int kNThreads = kNThreads_;
// Setting MinBlocksPerMP to be 3 (instead of 2) for 128 threads improves occupancy. // Setting MinBlocksPerMP to be 3 (instead of 2) for 128 threads improves occupancy.
static constexpr int kMinBlocks = kNThreads < 128 ? 5 : 3; static constexpr int kMinBlocks = kNThreads < 128 ? 5 : 3;
@ -132,7 +133,7 @@ void selective_scan_fwd_kernel(SSMParamsBase params) {
input_t *Bvar = reinterpret_cast<input_t *>(params.B_ptr) + sequence_start_index * params.B_batch_stride + group_id * params.B_group_stride; input_t *Bvar = reinterpret_cast<input_t *>(params.B_ptr) + sequence_start_index * params.B_batch_stride + group_id * params.B_group_stride;
weight_t *C = reinterpret_cast<weight_t *>(params.C_ptr) + dim_id * kNRows * params.C_d_stride; weight_t *C = reinterpret_cast<weight_t *>(params.C_ptr) + dim_id * kNRows * params.C_d_stride;
input_t *Cvar = reinterpret_cast<input_t *>(params.C_ptr) + sequence_start_index * params.C_batch_stride + group_id * params.C_group_stride; input_t *Cvar = reinterpret_cast<input_t *>(params.C_ptr) + sequence_start_index * params.C_batch_stride + group_id * params.C_group_stride;
input_t *ssm_states = reinterpret_cast<input_t *>(params.ssm_states_ptr) + typename Ktraits::state_t *ssm_states = reinterpret_cast<typename Ktraits::state_t *>(params.ssm_states_ptr) +
cache_index * params.ssm_states_batch_stride + cache_index * params.ssm_states_batch_stride +
dim_id * kNRows * params.ssm_states_dim_stride; dim_id * kNRows * params.ssm_states_dim_stride;
@ -261,7 +262,7 @@ void selective_scan_fwd_kernel(SSMParamsBase params) {
if (threadIdx.x == 0) { if (threadIdx.x == 0) {
smem_running_prefix[state_idx] = prefix_op.running_prefix; smem_running_prefix[state_idx] = prefix_op.running_prefix;
if (chunk == n_chunks - 1) { if (chunk == n_chunks - 1) {
ssm_states[state_idx * params.ssm_states_dstate_stride] = input_t(prefix_op.running_prefix.y); ssm_states[state_idx * params.ssm_states_dstate_stride] = typename Ktraits::state_t(prefix_op.running_prefix.y);
} }
} }
#pragma unroll #pragma unroll
@ -310,7 +311,7 @@ void selective_scan_fwd_kernel(SSMParamsBase params) {
} }
} }
template<int kNThreads, int kNItems, typename input_t, typename weight_t> template<int kNThreads, int kNItems, typename input_t, typename weight_t, typename state_t>
void selective_scan_fwd_launch(SSMParamsBase &params, cudaStream_t stream) { void selective_scan_fwd_launch(SSMParamsBase &params, cudaStream_t stream) {
// Only kNRows == 1 is tested for now, which ofc doesn't differ from previously when we had each block // Only kNRows == 1 is tested for now, which ofc doesn't differ from previously when we had each block
// processing 1 row. // processing 1 row.
@ -321,7 +322,7 @@ void selective_scan_fwd_launch(SSMParamsBase &params, cudaStream_t stream) {
BOOL_SWITCH(params.seqlen % (kNThreads * kNItems) == 0, kIsEvenLen, [&] { BOOL_SWITCH(params.seqlen % (kNThreads * kNItems) == 0, kIsEvenLen, [&] {
BOOL_SWITCH(params.z_ptr != nullptr , kHasZ, [&] { BOOL_SWITCH(params.z_ptr != nullptr , kHasZ, [&] {
BOOL_SWITCH(params.query_start_loc_ptr != nullptr , kVarlen, [&] { BOOL_SWITCH(params.query_start_loc_ptr != nullptr , kVarlen, [&] {
using Ktraits = Selective_Scan_fwd_kernel_traits<kNThreads, kNItems, kNRows, kIsEvenLen, kIsVariableB, kIsVariableC, kHasZ, kVarlen, input_t, weight_t>; using Ktraits = Selective_Scan_fwd_kernel_traits<kNThreads, kNItems, kNRows, kIsEvenLen, kIsVariableB, kIsVariableC, kHasZ, kVarlen, input_t, weight_t, state_t>;
constexpr int kSmemSize = Ktraits::kSmemSize + kNRows * MAX_DSTATE * sizeof(typename Ktraits::scan_t); constexpr int kSmemSize = Ktraits::kSmemSize + kNRows * MAX_DSTATE * sizeof(typename Ktraits::scan_t);
dim3 grid(params.batch, params.dim / kNRows); dim3 grid(params.batch, params.dim / kNRows);
auto kernel = &selective_scan_fwd_kernel<Ktraits>; auto kernel = &selective_scan_fwd_kernel<Ktraits>;
@ -341,59 +342,78 @@ void selective_scan_fwd_launch(SSMParamsBase &params, cudaStream_t stream) {
}); });
} }
template<typename input_t, typename weight_t> template<typename input_t, typename weight_t, typename state_t>
void selective_scan_fwd_cuda(SSMParamsBase &params, cudaStream_t stream) { void selective_scan_fwd_cuda(SSMParamsBase &params, cudaStream_t stream) {
#ifndef USE_ROCM #ifndef USE_ROCM
if (params.seqlen <= 128) { if (params.seqlen <= 128) {
selective_scan_fwd_launch<32, 4, input_t, weight_t>(params, stream); selective_scan_fwd_launch<32, 4, input_t, weight_t, state_t>(params, stream);
} else if (params.seqlen <= 256) { } else if (params.seqlen <= 256) {
selective_scan_fwd_launch<32, 8, input_t, weight_t>(params, stream); selective_scan_fwd_launch<32, 8, input_t, weight_t, state_t>(params, stream);
} else if (params.seqlen <= 512) { } else if (params.seqlen <= 512) {
selective_scan_fwd_launch<32, 16, input_t, weight_t>(params, stream); selective_scan_fwd_launch<32, 16, input_t, weight_t, state_t>(params, stream);
} else if (params.seqlen <= 1024) { } else if (params.seqlen <= 1024) {
selective_scan_fwd_launch<64, 16, input_t, weight_t>(params, stream); selective_scan_fwd_launch<64, 16, input_t, weight_t, state_t>(params, stream);
} else { } else {
selective_scan_fwd_launch<128, 16, input_t, weight_t>(params, stream); selective_scan_fwd_launch<128, 16, input_t, weight_t, state_t>(params, stream);
} }
#else #else
if (params.seqlen <= 256) { if (params.seqlen <= 256) {
selective_scan_fwd_launch<64, 4, input_t, weight_t>(params, stream); selective_scan_fwd_launch<64, 4, input_t, weight_t, state_t>(params, stream);
} else if (params.seqlen <= 512) { } else if (params.seqlen <= 512) {
selective_scan_fwd_launch<64, 8, input_t, weight_t>(params, stream); selective_scan_fwd_launch<64, 8, input_t, weight_t, state_t>(params, stream);
} else if (params.seqlen <= 1024) { } else if (params.seqlen <= 1024) {
selective_scan_fwd_launch<64, 16, input_t, weight_t>(params, stream); selective_scan_fwd_launch<64, 16, input_t, weight_t, state_t>(params, stream);
} else { } else {
selective_scan_fwd_launch<128, 16, input_t, weight_t>(params, stream); selective_scan_fwd_launch<128, 16, input_t, weight_t, state_t>(params, stream);
} }
#endif #endif
} }
template void selective_scan_fwd_cuda<at::BFloat16, float>(SSMParamsBase &params, cudaStream_t stream); template void selective_scan_fwd_cuda<at::BFloat16, float, at::BFloat16>(SSMParamsBase &params, cudaStream_t stream);
template void selective_scan_fwd_cuda<at::Half, float>(SSMParamsBase &params, cudaStream_t stream); template void selective_scan_fwd_cuda<at::BFloat16, float, float>(SSMParamsBase &params, cudaStream_t stream);
template void selective_scan_fwd_cuda<float, float>(SSMParamsBase &params, cudaStream_t stream); template void selective_scan_fwd_cuda<at::Half, float, at::Half>(SSMParamsBase &params, cudaStream_t stream);
template void selective_scan_fwd_cuda<at::Half, float, float>(SSMParamsBase &params, cudaStream_t stream);
template void selective_scan_fwd_cuda<float, float, float>(SSMParamsBase &params, cudaStream_t stream);
#define CHECK_SHAPE(x, ...) TORCH_CHECK(x.sizes() == torch::IntArrayRef({__VA_ARGS__}), #x " must have shape (" #__VA_ARGS__ ")") #define CHECK_SHAPE(x, ...) TORCH_CHECK(x.sizes() == torch::IntArrayRef({__VA_ARGS__}), #x " must have shape (" #__VA_ARGS__ ")")
#define DISPATCH_WTYPE_ITYPE_FLOAT_AND_HALF_AND_BF16(ITYPE, NAME, ...) \ #define DISPATCH_WTYPE_ITYPE_FLOAT_AND_HALF_AND_BF16(ITYPE, STYPE, NAME, ...) \
if (ITYPE == at::ScalarType::Half) { \ if (ITYPE == at::ScalarType::Half) { \
using input_t = at::Half; \ using input_t = at::Half; \
using weight_t = float; \ using weight_t = float; \
__VA_ARGS__(); \ if (STYPE == at::ScalarType::Half) { \
using state_t = at::Half; \
__VA_ARGS__(); \
} else if (STYPE == at::ScalarType::Float) { \
using state_t = float; \
__VA_ARGS__(); \
} else { \
AT_ERROR(#NAME, " not implemented for state type '", toString(STYPE), "'"); \
} \
} else if (ITYPE == at::ScalarType::BFloat16) { \ } else if (ITYPE == at::ScalarType::BFloat16) { \
using input_t = at::BFloat16; \ using input_t = at::BFloat16; \
using weight_t = float; \ using weight_t = float; \
__VA_ARGS__(); \ if (STYPE == at::ScalarType::BFloat16) { \
using state_t = at::BFloat16; \
__VA_ARGS__(); \
} else if (STYPE == at::ScalarType::Float) { \
using state_t = float; \
__VA_ARGS__(); \
} else { \
AT_ERROR(#NAME, " not implemented for state type '", toString(STYPE), "'"); \
} \
} else if (ITYPE == at::ScalarType::Float) { \ } else if (ITYPE == at::ScalarType::Float) { \
using input_t = float; \ using input_t = float; \
using weight_t = float; \ using weight_t = float; \
using state_t = float; \
__VA_ARGS__(); \ __VA_ARGS__(); \
} else { \ } else { \
AT_ERROR(#NAME, " not implemented for input type '", toString(ITYPE), "'"); \ AT_ERROR(#NAME, " not implemented for input type '", toString(ITYPE), "'"); \
} }
template<typename input_t, typename weight_t> template<typename input_t, typename weight_t, typename state_t>
void selective_scan_fwd_cuda(SSMParamsBase &params, cudaStream_t stream); void selective_scan_fwd_cuda(SSMParamsBase &params, cudaStream_t stream);
void set_ssm_params_fwd(SSMParamsBase &params, void set_ssm_params_fwd(SSMParamsBase &params,
@ -648,7 +668,9 @@ void selective_scan_fwd(const torch::Tensor &u, const torch::Tensor &delta,
// Right now u has BHL layout and delta has HBL layout, and we want out to have HBL layout // Right now u has BHL layout and delta has HBL layout, and we want out to have HBL layout
at::Tensor out = delta; at::Tensor out = delta;
TORCH_CHECK(ssm_states.scalar_type() == input_type); // ssm_states can now be either the same as input_type or float32
auto state_type = ssm_states.scalar_type();
TORCH_CHECK(state_type == input_type || state_type == at::ScalarType::Float);
TORCH_CHECK(ssm_states.is_cuda()); TORCH_CHECK(ssm_states.is_cuda());
TORCH_CHECK(ssm_states.stride(-1) == 1); TORCH_CHECK(ssm_states.stride(-1) == 1);
@ -670,7 +692,7 @@ void selective_scan_fwd(const torch::Tensor &u, const torch::Tensor &delta,
const at::cuda::OptionalCUDAGuard device_guard(device_of(u)); const at::cuda::OptionalCUDAGuard device_guard(device_of(u));
auto stream = at::cuda::getCurrentCUDAStream().stream(); auto stream = at::cuda::getCurrentCUDAStream().stream();
DISPATCH_WTYPE_ITYPE_FLOAT_AND_HALF_AND_BF16(u.scalar_type(), "selective_scan_fwd", [&] { DISPATCH_WTYPE_ITYPE_FLOAT_AND_HALF_AND_BF16(u.scalar_type(), ssm_states.scalar_type(), "selective_scan_fwd", [&] {
selective_scan_fwd_cuda<input_t, weight_t>(params, stream); selective_scan_fwd_cuda<input_t, weight_t, state_t>(params, stream);
}); });
} }

View File

@ -65,6 +65,11 @@ V0_UNSUPPORTED_MODELS = [
"LiquidAI/LFM2-1.2B", "LiquidAI/LFM2-1.2B",
] ]
FP32_STATE_MODELS = [
"state-spaces/mamba-130m-hf",
"Zyphra/Zamba2-1.2B-instruct",
]
# Avoid OOM # Avoid OOM
MAX_NUM_SEQS = 4 MAX_NUM_SEQS = 4
@ -434,7 +439,7 @@ def test_full_cuda_graph(
) )
@pytest.mark.parametrize("model", ["Zyphra/Zamba2-1.2B-instruct"]) @pytest.mark.parametrize("model", FP32_STATE_MODELS)
@pytest.mark.parametrize("max_tokens", [64]) @pytest.mark.parametrize("max_tokens", [64])
@pytest.mark.parametrize("num_logprobs", [5]) @pytest.mark.parametrize("num_logprobs", [5])
def test_fp32_state( def test_fp32_state(

View File

@ -30,12 +30,8 @@ class MambaStateDtypeCalculator:
mamba_cache_dtype: MambaDType, mamba_cache_dtype: MambaDType,
mamba_ssm_cache_dtype: MambaDType, mamba_ssm_cache_dtype: MambaDType,
) -> tuple[torch.dtype, ...]: ) -> tuple[torch.dtype, ...]:
# TODO (tdoublep) requires kernel changes return cls._mamba_state_dtype(model_dtype, mamba_cache_dtype,
if mamba_cache_dtype == "float32" or mamba_ssm_cache_dtype == "float32": mamba_ssm_cache_dtype)
raise ValueError("fp32 state for mamba1 is not yet supported")
else:
return MambaStateDtypeCalculator.mamba2_state_dtype(
model_dtype, mamba_cache_dtype, mamba_ssm_cache_dtype)
@classmethod @classmethod
def mamba2_state_dtype( def mamba2_state_dtype(
@ -43,6 +39,16 @@ class MambaStateDtypeCalculator:
model_dtype: Union[ModelDType, torch.dtype], model_dtype: Union[ModelDType, torch.dtype],
mamba_cache_dtype: MambaDType, mamba_cache_dtype: MambaDType,
mamba_ssm_cache_dtype: MambaDType, mamba_ssm_cache_dtype: MambaDType,
) -> tuple[torch.dtype, ...]:
return cls._mamba_state_dtype(model_dtype, mamba_cache_dtype,
mamba_ssm_cache_dtype)
@classmethod
def _mamba_state_dtype(
cls,
model_dtype: Union[ModelDType, torch.dtype],
mamba_cache_dtype: MambaDType,
mamba_ssm_cache_dtype: MambaDType,
) -> tuple[torch.dtype, ...]: ) -> tuple[torch.dtype, ...]:
conv_state_dtype = get_kv_cache_torch_dtype(mamba_cache_dtype, conv_state_dtype = get_kv_cache_torch_dtype(mamba_cache_dtype,
model_dtype) model_dtype)