mirror of
https://github.com/vllm-project/vllm.git
synced 2025-10-20 23:03:52 +08:00
[V1][Mamba1] - FP32 SSM Kernel Support (#23506)
Signed-off-by: asafg <39553475+Josephasafg@users.noreply.github.com>
This commit is contained in:
committed by
GitHub
parent
0235103cbb
commit
2b41cbbf03
@ -27,11 +27,12 @@
|
|||||||
|
|
||||||
template<int kNThreads_, int kNItems_, int kNRows_, bool kIsEvenLen_,
|
template<int kNThreads_, int kNItems_, int kNRows_, bool kIsEvenLen_,
|
||||||
bool kIsVariableB_, bool kIsVariableC_,
|
bool kIsVariableB_, bool kIsVariableC_,
|
||||||
bool kHasZ_, bool kVarlen_, typename input_t_, typename weight_t_>
|
bool kHasZ_, bool kVarlen_, typename input_t_, typename weight_t_, typename state_t_>
|
||||||
struct Selective_Scan_fwd_kernel_traits {
|
struct Selective_Scan_fwd_kernel_traits {
|
||||||
static_assert(kNItems_ % 4 == 0);
|
static_assert(kNItems_ % 4 == 0);
|
||||||
using input_t = input_t_;
|
using input_t = input_t_;
|
||||||
using weight_t = weight_t_;
|
using weight_t = weight_t_;
|
||||||
|
using state_t = state_t_;
|
||||||
static constexpr int kNThreads = kNThreads_;
|
static constexpr int kNThreads = kNThreads_;
|
||||||
// Setting MinBlocksPerMP to be 3 (instead of 2) for 128 threads improves occupancy.
|
// Setting MinBlocksPerMP to be 3 (instead of 2) for 128 threads improves occupancy.
|
||||||
static constexpr int kMinBlocks = kNThreads < 128 ? 5 : 3;
|
static constexpr int kMinBlocks = kNThreads < 128 ? 5 : 3;
|
||||||
@ -132,7 +133,7 @@ void selective_scan_fwd_kernel(SSMParamsBase params) {
|
|||||||
input_t *Bvar = reinterpret_cast<input_t *>(params.B_ptr) + sequence_start_index * params.B_batch_stride + group_id * params.B_group_stride;
|
input_t *Bvar = reinterpret_cast<input_t *>(params.B_ptr) + sequence_start_index * params.B_batch_stride + group_id * params.B_group_stride;
|
||||||
weight_t *C = reinterpret_cast<weight_t *>(params.C_ptr) + dim_id * kNRows * params.C_d_stride;
|
weight_t *C = reinterpret_cast<weight_t *>(params.C_ptr) + dim_id * kNRows * params.C_d_stride;
|
||||||
input_t *Cvar = reinterpret_cast<input_t *>(params.C_ptr) + sequence_start_index * params.C_batch_stride + group_id * params.C_group_stride;
|
input_t *Cvar = reinterpret_cast<input_t *>(params.C_ptr) + sequence_start_index * params.C_batch_stride + group_id * params.C_group_stride;
|
||||||
input_t *ssm_states = reinterpret_cast<input_t *>(params.ssm_states_ptr) +
|
typename Ktraits::state_t *ssm_states = reinterpret_cast<typename Ktraits::state_t *>(params.ssm_states_ptr) +
|
||||||
cache_index * params.ssm_states_batch_stride +
|
cache_index * params.ssm_states_batch_stride +
|
||||||
dim_id * kNRows * params.ssm_states_dim_stride;
|
dim_id * kNRows * params.ssm_states_dim_stride;
|
||||||
|
|
||||||
@ -261,7 +262,7 @@ void selective_scan_fwd_kernel(SSMParamsBase params) {
|
|||||||
if (threadIdx.x == 0) {
|
if (threadIdx.x == 0) {
|
||||||
smem_running_prefix[state_idx] = prefix_op.running_prefix;
|
smem_running_prefix[state_idx] = prefix_op.running_prefix;
|
||||||
if (chunk == n_chunks - 1) {
|
if (chunk == n_chunks - 1) {
|
||||||
ssm_states[state_idx * params.ssm_states_dstate_stride] = input_t(prefix_op.running_prefix.y);
|
ssm_states[state_idx * params.ssm_states_dstate_stride] = typename Ktraits::state_t(prefix_op.running_prefix.y);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
#pragma unroll
|
#pragma unroll
|
||||||
@ -310,7 +311,7 @@ void selective_scan_fwd_kernel(SSMParamsBase params) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
template<int kNThreads, int kNItems, typename input_t, typename weight_t>
|
template<int kNThreads, int kNItems, typename input_t, typename weight_t, typename state_t>
|
||||||
void selective_scan_fwd_launch(SSMParamsBase ¶ms, cudaStream_t stream) {
|
void selective_scan_fwd_launch(SSMParamsBase ¶ms, cudaStream_t stream) {
|
||||||
// Only kNRows == 1 is tested for now, which ofc doesn't differ from previously when we had each block
|
// Only kNRows == 1 is tested for now, which ofc doesn't differ from previously when we had each block
|
||||||
// processing 1 row.
|
// processing 1 row.
|
||||||
@ -321,7 +322,7 @@ void selective_scan_fwd_launch(SSMParamsBase ¶ms, cudaStream_t stream) {
|
|||||||
BOOL_SWITCH(params.seqlen % (kNThreads * kNItems) == 0, kIsEvenLen, [&] {
|
BOOL_SWITCH(params.seqlen % (kNThreads * kNItems) == 0, kIsEvenLen, [&] {
|
||||||
BOOL_SWITCH(params.z_ptr != nullptr , kHasZ, [&] {
|
BOOL_SWITCH(params.z_ptr != nullptr , kHasZ, [&] {
|
||||||
BOOL_SWITCH(params.query_start_loc_ptr != nullptr , kVarlen, [&] {
|
BOOL_SWITCH(params.query_start_loc_ptr != nullptr , kVarlen, [&] {
|
||||||
using Ktraits = Selective_Scan_fwd_kernel_traits<kNThreads, kNItems, kNRows, kIsEvenLen, kIsVariableB, kIsVariableC, kHasZ, kVarlen, input_t, weight_t>;
|
using Ktraits = Selective_Scan_fwd_kernel_traits<kNThreads, kNItems, kNRows, kIsEvenLen, kIsVariableB, kIsVariableC, kHasZ, kVarlen, input_t, weight_t, state_t>;
|
||||||
constexpr int kSmemSize = Ktraits::kSmemSize + kNRows * MAX_DSTATE * sizeof(typename Ktraits::scan_t);
|
constexpr int kSmemSize = Ktraits::kSmemSize + kNRows * MAX_DSTATE * sizeof(typename Ktraits::scan_t);
|
||||||
dim3 grid(params.batch, params.dim / kNRows);
|
dim3 grid(params.batch, params.dim / kNRows);
|
||||||
auto kernel = &selective_scan_fwd_kernel<Ktraits>;
|
auto kernel = &selective_scan_fwd_kernel<Ktraits>;
|
||||||
@ -341,59 +342,78 @@ void selective_scan_fwd_launch(SSMParamsBase ¶ms, cudaStream_t stream) {
|
|||||||
});
|
});
|
||||||
}
|
}
|
||||||
|
|
||||||
template<typename input_t, typename weight_t>
|
template<typename input_t, typename weight_t, typename state_t>
|
||||||
void selective_scan_fwd_cuda(SSMParamsBase ¶ms, cudaStream_t stream) {
|
void selective_scan_fwd_cuda(SSMParamsBase ¶ms, cudaStream_t stream) {
|
||||||
|
|
||||||
#ifndef USE_ROCM
|
#ifndef USE_ROCM
|
||||||
if (params.seqlen <= 128) {
|
if (params.seqlen <= 128) {
|
||||||
selective_scan_fwd_launch<32, 4, input_t, weight_t>(params, stream);
|
selective_scan_fwd_launch<32, 4, input_t, weight_t, state_t>(params, stream);
|
||||||
} else if (params.seqlen <= 256) {
|
} else if (params.seqlen <= 256) {
|
||||||
selective_scan_fwd_launch<32, 8, input_t, weight_t>(params, stream);
|
selective_scan_fwd_launch<32, 8, input_t, weight_t, state_t>(params, stream);
|
||||||
} else if (params.seqlen <= 512) {
|
} else if (params.seqlen <= 512) {
|
||||||
selective_scan_fwd_launch<32, 16, input_t, weight_t>(params, stream);
|
selective_scan_fwd_launch<32, 16, input_t, weight_t, state_t>(params, stream);
|
||||||
} else if (params.seqlen <= 1024) {
|
} else if (params.seqlen <= 1024) {
|
||||||
selective_scan_fwd_launch<64, 16, input_t, weight_t>(params, stream);
|
selective_scan_fwd_launch<64, 16, input_t, weight_t, state_t>(params, stream);
|
||||||
} else {
|
} else {
|
||||||
selective_scan_fwd_launch<128, 16, input_t, weight_t>(params, stream);
|
selective_scan_fwd_launch<128, 16, input_t, weight_t, state_t>(params, stream);
|
||||||
}
|
}
|
||||||
#else
|
#else
|
||||||
if (params.seqlen <= 256) {
|
if (params.seqlen <= 256) {
|
||||||
selective_scan_fwd_launch<64, 4, input_t, weight_t>(params, stream);
|
selective_scan_fwd_launch<64, 4, input_t, weight_t, state_t>(params, stream);
|
||||||
} else if (params.seqlen <= 512) {
|
} else if (params.seqlen <= 512) {
|
||||||
selective_scan_fwd_launch<64, 8, input_t, weight_t>(params, stream);
|
selective_scan_fwd_launch<64, 8, input_t, weight_t, state_t>(params, stream);
|
||||||
} else if (params.seqlen <= 1024) {
|
} else if (params.seqlen <= 1024) {
|
||||||
selective_scan_fwd_launch<64, 16, input_t, weight_t>(params, stream);
|
selective_scan_fwd_launch<64, 16, input_t, weight_t, state_t>(params, stream);
|
||||||
} else {
|
} else {
|
||||||
selective_scan_fwd_launch<128, 16, input_t, weight_t>(params, stream);
|
selective_scan_fwd_launch<128, 16, input_t, weight_t, state_t>(params, stream);
|
||||||
}
|
}
|
||||||
#endif
|
#endif
|
||||||
}
|
}
|
||||||
|
|
||||||
template void selective_scan_fwd_cuda<at::BFloat16, float>(SSMParamsBase ¶ms, cudaStream_t stream);
|
template void selective_scan_fwd_cuda<at::BFloat16, float, at::BFloat16>(SSMParamsBase ¶ms, cudaStream_t stream);
|
||||||
template void selective_scan_fwd_cuda<at::Half, float>(SSMParamsBase ¶ms, cudaStream_t stream);
|
template void selective_scan_fwd_cuda<at::BFloat16, float, float>(SSMParamsBase ¶ms, cudaStream_t stream);
|
||||||
template void selective_scan_fwd_cuda<float, float>(SSMParamsBase ¶ms, cudaStream_t stream);
|
template void selective_scan_fwd_cuda<at::Half, float, at::Half>(SSMParamsBase ¶ms, cudaStream_t stream);
|
||||||
|
template void selective_scan_fwd_cuda<at::Half, float, float>(SSMParamsBase ¶ms, cudaStream_t stream);
|
||||||
|
template void selective_scan_fwd_cuda<float, float, float>(SSMParamsBase ¶ms, cudaStream_t stream);
|
||||||
|
|
||||||
#define CHECK_SHAPE(x, ...) TORCH_CHECK(x.sizes() == torch::IntArrayRef({__VA_ARGS__}), #x " must have shape (" #__VA_ARGS__ ")")
|
#define CHECK_SHAPE(x, ...) TORCH_CHECK(x.sizes() == torch::IntArrayRef({__VA_ARGS__}), #x " must have shape (" #__VA_ARGS__ ")")
|
||||||
|
|
||||||
#define DISPATCH_WTYPE_ITYPE_FLOAT_AND_HALF_AND_BF16(ITYPE, NAME, ...) \
|
#define DISPATCH_WTYPE_ITYPE_FLOAT_AND_HALF_AND_BF16(ITYPE, STYPE, NAME, ...) \
|
||||||
if (ITYPE == at::ScalarType::Half) { \
|
if (ITYPE == at::ScalarType::Half) { \
|
||||||
using input_t = at::Half; \
|
using input_t = at::Half; \
|
||||||
using weight_t = float; \
|
using weight_t = float; \
|
||||||
__VA_ARGS__(); \
|
if (STYPE == at::ScalarType::Half) { \
|
||||||
|
using state_t = at::Half; \
|
||||||
|
__VA_ARGS__(); \
|
||||||
|
} else if (STYPE == at::ScalarType::Float) { \
|
||||||
|
using state_t = float; \
|
||||||
|
__VA_ARGS__(); \
|
||||||
|
} else { \
|
||||||
|
AT_ERROR(#NAME, " not implemented for state type '", toString(STYPE), "'"); \
|
||||||
|
} \
|
||||||
} else if (ITYPE == at::ScalarType::BFloat16) { \
|
} else if (ITYPE == at::ScalarType::BFloat16) { \
|
||||||
using input_t = at::BFloat16; \
|
using input_t = at::BFloat16; \
|
||||||
using weight_t = float; \
|
using weight_t = float; \
|
||||||
__VA_ARGS__(); \
|
if (STYPE == at::ScalarType::BFloat16) { \
|
||||||
|
using state_t = at::BFloat16; \
|
||||||
|
__VA_ARGS__(); \
|
||||||
|
} else if (STYPE == at::ScalarType::Float) { \
|
||||||
|
using state_t = float; \
|
||||||
|
__VA_ARGS__(); \
|
||||||
|
} else { \
|
||||||
|
AT_ERROR(#NAME, " not implemented for state type '", toString(STYPE), "'"); \
|
||||||
|
} \
|
||||||
} else if (ITYPE == at::ScalarType::Float) { \
|
} else if (ITYPE == at::ScalarType::Float) { \
|
||||||
using input_t = float; \
|
using input_t = float; \
|
||||||
using weight_t = float; \
|
using weight_t = float; \
|
||||||
|
using state_t = float; \
|
||||||
__VA_ARGS__(); \
|
__VA_ARGS__(); \
|
||||||
} else { \
|
} else { \
|
||||||
AT_ERROR(#NAME, " not implemented for input type '", toString(ITYPE), "'"); \
|
AT_ERROR(#NAME, " not implemented for input type '", toString(ITYPE), "'"); \
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
template<typename input_t, typename weight_t>
|
template<typename input_t, typename weight_t, typename state_t>
|
||||||
void selective_scan_fwd_cuda(SSMParamsBase ¶ms, cudaStream_t stream);
|
void selective_scan_fwd_cuda(SSMParamsBase ¶ms, cudaStream_t stream);
|
||||||
|
|
||||||
void set_ssm_params_fwd(SSMParamsBase ¶ms,
|
void set_ssm_params_fwd(SSMParamsBase ¶ms,
|
||||||
@ -648,7 +668,9 @@ void selective_scan_fwd(const torch::Tensor &u, const torch::Tensor &delta,
|
|||||||
|
|
||||||
// Right now u has BHL layout and delta has HBL layout, and we want out to have HBL layout
|
// Right now u has BHL layout and delta has HBL layout, and we want out to have HBL layout
|
||||||
at::Tensor out = delta;
|
at::Tensor out = delta;
|
||||||
TORCH_CHECK(ssm_states.scalar_type() == input_type);
|
// ssm_states can now be either the same as input_type or float32
|
||||||
|
auto state_type = ssm_states.scalar_type();
|
||||||
|
TORCH_CHECK(state_type == input_type || state_type == at::ScalarType::Float);
|
||||||
TORCH_CHECK(ssm_states.is_cuda());
|
TORCH_CHECK(ssm_states.is_cuda());
|
||||||
TORCH_CHECK(ssm_states.stride(-1) == 1);
|
TORCH_CHECK(ssm_states.stride(-1) == 1);
|
||||||
|
|
||||||
@ -670,7 +692,7 @@ void selective_scan_fwd(const torch::Tensor &u, const torch::Tensor &delta,
|
|||||||
|
|
||||||
const at::cuda::OptionalCUDAGuard device_guard(device_of(u));
|
const at::cuda::OptionalCUDAGuard device_guard(device_of(u));
|
||||||
auto stream = at::cuda::getCurrentCUDAStream().stream();
|
auto stream = at::cuda::getCurrentCUDAStream().stream();
|
||||||
DISPATCH_WTYPE_ITYPE_FLOAT_AND_HALF_AND_BF16(u.scalar_type(), "selective_scan_fwd", [&] {
|
DISPATCH_WTYPE_ITYPE_FLOAT_AND_HALF_AND_BF16(u.scalar_type(), ssm_states.scalar_type(), "selective_scan_fwd", [&] {
|
||||||
selective_scan_fwd_cuda<input_t, weight_t>(params, stream);
|
selective_scan_fwd_cuda<input_t, weight_t, state_t>(params, stream);
|
||||||
});
|
});
|
||||||
}
|
}
|
||||||
|
@ -65,6 +65,11 @@ V0_UNSUPPORTED_MODELS = [
|
|||||||
"LiquidAI/LFM2-1.2B",
|
"LiquidAI/LFM2-1.2B",
|
||||||
]
|
]
|
||||||
|
|
||||||
|
FP32_STATE_MODELS = [
|
||||||
|
"state-spaces/mamba-130m-hf",
|
||||||
|
"Zyphra/Zamba2-1.2B-instruct",
|
||||||
|
]
|
||||||
|
|
||||||
# Avoid OOM
|
# Avoid OOM
|
||||||
MAX_NUM_SEQS = 4
|
MAX_NUM_SEQS = 4
|
||||||
|
|
||||||
@ -434,7 +439,7 @@ def test_full_cuda_graph(
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize("model", ["Zyphra/Zamba2-1.2B-instruct"])
|
@pytest.mark.parametrize("model", FP32_STATE_MODELS)
|
||||||
@pytest.mark.parametrize("max_tokens", [64])
|
@pytest.mark.parametrize("max_tokens", [64])
|
||||||
@pytest.mark.parametrize("num_logprobs", [5])
|
@pytest.mark.parametrize("num_logprobs", [5])
|
||||||
def test_fp32_state(
|
def test_fp32_state(
|
||||||
|
@ -30,12 +30,8 @@ class MambaStateDtypeCalculator:
|
|||||||
mamba_cache_dtype: MambaDType,
|
mamba_cache_dtype: MambaDType,
|
||||||
mamba_ssm_cache_dtype: MambaDType,
|
mamba_ssm_cache_dtype: MambaDType,
|
||||||
) -> tuple[torch.dtype, ...]:
|
) -> tuple[torch.dtype, ...]:
|
||||||
# TODO (tdoublep) requires kernel changes
|
return cls._mamba_state_dtype(model_dtype, mamba_cache_dtype,
|
||||||
if mamba_cache_dtype == "float32" or mamba_ssm_cache_dtype == "float32":
|
mamba_ssm_cache_dtype)
|
||||||
raise ValueError("fp32 state for mamba1 is not yet supported")
|
|
||||||
else:
|
|
||||||
return MambaStateDtypeCalculator.mamba2_state_dtype(
|
|
||||||
model_dtype, mamba_cache_dtype, mamba_ssm_cache_dtype)
|
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def mamba2_state_dtype(
|
def mamba2_state_dtype(
|
||||||
@ -43,6 +39,16 @@ class MambaStateDtypeCalculator:
|
|||||||
model_dtype: Union[ModelDType, torch.dtype],
|
model_dtype: Union[ModelDType, torch.dtype],
|
||||||
mamba_cache_dtype: MambaDType,
|
mamba_cache_dtype: MambaDType,
|
||||||
mamba_ssm_cache_dtype: MambaDType,
|
mamba_ssm_cache_dtype: MambaDType,
|
||||||
|
) -> tuple[torch.dtype, ...]:
|
||||||
|
return cls._mamba_state_dtype(model_dtype, mamba_cache_dtype,
|
||||||
|
mamba_ssm_cache_dtype)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def _mamba_state_dtype(
|
||||||
|
cls,
|
||||||
|
model_dtype: Union[ModelDType, torch.dtype],
|
||||||
|
mamba_cache_dtype: MambaDType,
|
||||||
|
mamba_ssm_cache_dtype: MambaDType,
|
||||||
) -> tuple[torch.dtype, ...]:
|
) -> tuple[torch.dtype, ...]:
|
||||||
conv_state_dtype = get_kv_cache_torch_dtype(mamba_cache_dtype,
|
conv_state_dtype = get_kv_cache_torch_dtype(mamba_cache_dtype,
|
||||||
model_dtype)
|
model_dtype)
|
||||||
|
Reference in New Issue
Block a user