[Kernel] Triton implementation of causal-conv1d for Mamba-based models (#18218)
Signed-off-by: Tuan M. Hoang-Trong <tmhoangt@us.ibm.com> Co-authored-by: Tuan M. Hoang-Trong <tmhoangt@us.ibm.com> Co-authored-by: Tyler Michael Smith <tysmith@redhat.com> Co-authored-by: Tyler Michael Smith <tyler@neuralmagic.com>
This commit is contained in:
committed by
GitHub
parent
31b96d1c64
commit
47043eb678
@ -232,7 +232,6 @@ endif()
|
||||
|
||||
set(VLLM_EXT_SRC
|
||||
"csrc/mamba/mamba_ssm/selective_scan_fwd.cu"
|
||||
"csrc/mamba/causal_conv1d/causal_conv1d.cu"
|
||||
"csrc/cache_kernels.cu"
|
||||
"csrc/attention/paged_attention_v1.cu"
|
||||
"csrc/attention/paged_attention_v2.cu"
|
||||
|
||||
@ -1,656 +0,0 @@
|
||||
// clang-format off
|
||||
// adapted from https://github.com/Dao-AILab/causal-conv1d/blob/main/csrc/causal_conv1d_fwd.cu
|
||||
// and https://github.com/Dao-AILab/causal-conv1d/blob/main/csrc/causal_conv1d_update.cu
|
||||
#include <torch/all.h>
|
||||
#include <ATen/cuda/CUDAContext.h>
|
||||
#include <c10/cuda/CUDAGuard.h>
|
||||
|
||||
#include "causal_conv1d.h"
|
||||
#include <c10/util/BFloat16.h>
|
||||
#include <c10/util/Half.h>
|
||||
#include <c10/cuda/CUDAException.h> // For C10_CUDA_CHECK and C10_CUDA_KERNEL_LAUNCH_CHECK
|
||||
|
||||
#include <cub/block/block_load.cuh>
|
||||
#include <cub/block/block_store.cuh>
|
||||
|
||||
#ifdef USE_ROCM
|
||||
namespace cub = hipcub;
|
||||
#endif
|
||||
|
||||
#include "static_switch.h"
|
||||
|
||||
|
||||
|
||||
#define CHECK_SHAPE(x, ...) TORCH_CHECK(x.sizes() == torch::IntArrayRef({__VA_ARGS__}), #x " must have shape (" #__VA_ARGS__ ")")
|
||||
|
||||
#define DISPATCH_WTYPE_ITYPE_FLOAT_AND_HALF_AND_BF16(ITYPE, NAME, ...) \
|
||||
if (ITYPE == at::ScalarType::Half) { \
|
||||
using input_t = at::Half; \
|
||||
using weight_t = at::Half; \
|
||||
__VA_ARGS__(); \
|
||||
} else if (ITYPE == at::ScalarType::BFloat16) { \
|
||||
using input_t = at::BFloat16; \
|
||||
using weight_t = at::BFloat16; \
|
||||
__VA_ARGS__(); \
|
||||
} else if (ITYPE == at::ScalarType::Float) { \
|
||||
using input_t = float; \
|
||||
using weight_t = float; \
|
||||
__VA_ARGS__(); \
|
||||
} else { \
|
||||
AT_ERROR(#NAME, " not implemented for input type '", toString(ITYPE), "'"); \
|
||||
}
|
||||
|
||||
|
||||
template<typename input_t, typename weight_t>
|
||||
void causal_conv1d_fwd_cuda(ConvParamsBase ¶ms, cudaStream_t stream);
|
||||
|
||||
template<typename input_t, typename weight_t>
|
||||
void causal_conv1d_update_cuda(ConvParamsBase ¶ms, cudaStream_t stream);
|
||||
|
||||
void set_conv_params_fwd(ConvParamsBase ¶ms,
|
||||
// sizes
|
||||
const size_t batch,
|
||||
const size_t dim,
|
||||
const size_t seqlen,
|
||||
const size_t width,
|
||||
// device pointers
|
||||
const at::Tensor x,
|
||||
const at::Tensor weight,
|
||||
const at::Tensor out,
|
||||
const std::optional<at::Tensor>& bias,
|
||||
bool silu_activation,
|
||||
int64_t pad_slot_id,
|
||||
const std::optional<at::Tensor>& query_start_loc = std::nullopt,
|
||||
const std::optional<at::Tensor>& cache_indices = std::nullopt,
|
||||
const std::optional<at::Tensor>& has_initial_state = std::nullopt) {
|
||||
|
||||
// Reset the parameters
|
||||
memset(¶ms, 0, sizeof(params));
|
||||
|
||||
params.batch = batch;
|
||||
params.dim = dim;
|
||||
params.seqlen = seqlen;
|
||||
params.width = width;
|
||||
params.pad_slot_id = pad_slot_id;
|
||||
|
||||
params.silu_activation = silu_activation;
|
||||
|
||||
// Set the pointers and strides.
|
||||
params.x_ptr = x.data_ptr();
|
||||
params.weight_ptr = weight.data_ptr();
|
||||
params.bias_ptr = bias.has_value() ? bias.value().data_ptr() : nullptr;
|
||||
params.out_ptr = out.data_ptr();
|
||||
// All stride are in elements, not bytes.
|
||||
params.query_start_loc_ptr = query_start_loc.has_value() ? query_start_loc.value().data_ptr() : nullptr;
|
||||
params.cache_indices_ptr = cache_indices.has_value() ? cache_indices.value().data_ptr() : nullptr;
|
||||
params.has_initial_state_ptr = has_initial_state.has_value() ? has_initial_state.value().data_ptr() : nullptr;
|
||||
const bool varlen = params.query_start_loc_ptr != nullptr;
|
||||
params.x_batch_stride = x.stride(varlen ? 1 : 0);
|
||||
params.x_c_stride = x.stride(varlen ? 0 : 1);
|
||||
params.x_l_stride = x.stride(varlen ? 1 : -1);
|
||||
params.weight_c_stride = weight.stride(0);
|
||||
params.weight_width_stride = weight.stride(1);
|
||||
params.out_batch_stride = out.stride(varlen ? 1 : 0);
|
||||
params.out_c_stride = out.stride(varlen ? 0 : 1);
|
||||
params.out_l_stride = out.stride(varlen ? 1 : -1);
|
||||
}
|
||||
|
||||
|
||||
void causal_conv1d_fwd(const at::Tensor &x, const at::Tensor &weight,
|
||||
const std::optional<at::Tensor> &bias_,
|
||||
const std::optional<at::Tensor> &conv_states,
|
||||
const std::optional<at::Tensor> &query_start_loc,
|
||||
const std::optional<at::Tensor> &cache_indices,
|
||||
const std::optional<at::Tensor> &has_initial_state,
|
||||
bool silu_activation,
|
||||
// used to identify padding entries if cache_indices provided
|
||||
// in case of padding, the kernel will return early
|
||||
int64_t pad_slot_id) {
|
||||
auto input_type = x.scalar_type();
|
||||
auto weight_type = weight.scalar_type();
|
||||
TORCH_CHECK(input_type == at::ScalarType::Float || input_type == at::ScalarType::Half || input_type == at::ScalarType::BFloat16);
|
||||
TORCH_CHECK(weight_type == at::ScalarType::Float || weight_type == at::ScalarType::Half || weight_type == at::ScalarType::BFloat16);
|
||||
|
||||
TORCH_CHECK(x.is_cuda());
|
||||
TORCH_CHECK(weight.is_cuda());
|
||||
|
||||
const bool varlen = query_start_loc.has_value() ? true : false;
|
||||
const auto sizes = x.sizes();
|
||||
const int batch_size = varlen ? query_start_loc.value().sizes()[0] - 1 : sizes[0];
|
||||
const int dim = varlen ? sizes[0] : sizes[1];
|
||||
const int seqlen = varlen ? sizes[1] : sizes[2];
|
||||
const int width = weight.size(-1);
|
||||
if (varlen){
|
||||
CHECK_SHAPE(x, dim, seqlen);
|
||||
}
|
||||
else {
|
||||
CHECK_SHAPE(x, batch_size, dim, seqlen);
|
||||
}
|
||||
CHECK_SHAPE(weight, dim, width);
|
||||
|
||||
|
||||
|
||||
if (bias_.has_value()) {
|
||||
auto bias = bias_.value();
|
||||
TORCH_CHECK(bias.scalar_type() == weight_type);
|
||||
TORCH_CHECK(bias.is_cuda());
|
||||
TORCH_CHECK(bias.stride(-1) == 1);
|
||||
CHECK_SHAPE(bias, dim);
|
||||
}
|
||||
|
||||
|
||||
if (has_initial_state.has_value()) {
|
||||
auto has_initial_state_ = has_initial_state.value();
|
||||
TORCH_CHECK(has_initial_state_.scalar_type() == at::ScalarType::Bool);
|
||||
TORCH_CHECK(has_initial_state_.is_cuda());
|
||||
CHECK_SHAPE(has_initial_state_, batch_size);
|
||||
}
|
||||
|
||||
|
||||
if (query_start_loc.has_value()) {
|
||||
auto query_start_loc_ = query_start_loc.value();
|
||||
TORCH_CHECK(query_start_loc_.scalar_type() == at::ScalarType::Int);
|
||||
TORCH_CHECK(query_start_loc_.is_cuda());
|
||||
}
|
||||
|
||||
|
||||
if (cache_indices.has_value()) {
|
||||
auto cache_indices_ = cache_indices.value();
|
||||
TORCH_CHECK(cache_indices_.scalar_type() == at::ScalarType::Int);
|
||||
TORCH_CHECK(cache_indices_.is_cuda());
|
||||
CHECK_SHAPE(cache_indices_, batch_size);
|
||||
}
|
||||
|
||||
at::Tensor out = x;
|
||||
|
||||
ConvParamsBase params;
|
||||
set_conv_params_fwd(params, batch_size, dim, seqlen, width, x, weight, out,
|
||||
bias_,
|
||||
silu_activation,
|
||||
pad_slot_id,
|
||||
query_start_loc,
|
||||
cache_indices,
|
||||
has_initial_state
|
||||
);
|
||||
|
||||
if (conv_states.has_value()) {
|
||||
auto conv_states_ = conv_states.value();
|
||||
TORCH_CHECK(conv_states_.scalar_type() == input_type);
|
||||
TORCH_CHECK(conv_states_.is_cuda());
|
||||
params.conv_states_ptr = conv_states_.data_ptr();
|
||||
params.conv_states_batch_stride = conv_states_.stride(0);
|
||||
params.conv_states_c_stride = conv_states_.stride(1);
|
||||
params.conv_states_l_stride = conv_states_.stride(2);
|
||||
} else {
|
||||
params.conv_states_ptr = nullptr;
|
||||
}
|
||||
|
||||
const at::cuda::OptionalCUDAGuard device_guard(device_of(x));
|
||||
auto stream = at::cuda::getCurrentCUDAStream().stream();
|
||||
DISPATCH_WTYPE_ITYPE_FLOAT_AND_HALF_AND_BF16(x.scalar_type(), "causal_conv1d_fwd", [&] {
|
||||
causal_conv1d_fwd_cuda<input_t, weight_t>(params, stream);
|
||||
});
|
||||
}
|
||||
|
||||
|
||||
void causal_conv1d_update(const at::Tensor &x,
|
||||
const at::Tensor &conv_state,
|
||||
const at::Tensor &weight,
|
||||
const std::optional<at::Tensor> &bias_,
|
||||
bool silu_activation,
|
||||
const std::optional<at::Tensor> &cache_seqlens_,
|
||||
const std::optional<at::Tensor> &conv_state_indices_,
|
||||
// used to identify padding entries if cache_indices provided
|
||||
// in case of padding, the kernel will return early
|
||||
int64_t pad_slot_id) {
|
||||
auto input_type = x.scalar_type();
|
||||
auto weight_type = weight.scalar_type();
|
||||
TORCH_CHECK(input_type == at::ScalarType::Float || input_type == at::ScalarType::Half || input_type == at::ScalarType::BFloat16);
|
||||
TORCH_CHECK(weight_type == at::ScalarType::Float || weight_type == at::ScalarType::Half || weight_type == at::ScalarType::BFloat16);
|
||||
TORCH_CHECK(weight_type == input_type, "weight type must equal to input type, other variations are disabled due to binary size limitations");
|
||||
TORCH_CHECK(conv_state.scalar_type() == input_type);
|
||||
|
||||
TORCH_CHECK(x.is_cuda());
|
||||
TORCH_CHECK(conv_state.is_cuda());
|
||||
TORCH_CHECK(weight.is_cuda());
|
||||
|
||||
const auto sizes = x.sizes();
|
||||
const int batch_size = sizes[0];
|
||||
const int dim = sizes[1];
|
||||
const int seqlen = sizes[2];
|
||||
const int width = weight.size(-1);
|
||||
const int conv_state_len = conv_state.size(2);
|
||||
TORCH_CHECK(conv_state_len >= width - 1);
|
||||
|
||||
CHECK_SHAPE(x, batch_size, dim, seqlen);
|
||||
CHECK_SHAPE(weight, dim, width);
|
||||
|
||||
TORCH_CHECK(width >= 2 && width <= 4, "causal_conv1d only supports width between 2 and 4");
|
||||
|
||||
if (bias_.has_value()) {
|
||||
auto bias = bias_.value();
|
||||
TORCH_CHECK(bias.scalar_type() == weight_type);
|
||||
TORCH_CHECK(bias.is_cuda());
|
||||
TORCH_CHECK(bias.stride(-1) == 1);
|
||||
CHECK_SHAPE(bias, dim);
|
||||
}
|
||||
|
||||
at::Tensor out = x;
|
||||
|
||||
ConvParamsBase params;
|
||||
set_conv_params_fwd(params, batch_size, dim, seqlen, width, x, weight, out,
|
||||
bias_,
|
||||
silu_activation,
|
||||
pad_slot_id);
|
||||
params.conv_state_ptr = conv_state.data_ptr();
|
||||
params.conv_state_len = conv_state_len;
|
||||
// All stride are in elements, not bytes.
|
||||
params.conv_state_batch_stride = conv_state.stride(0);
|
||||
params.conv_state_c_stride = conv_state.stride(1);
|
||||
params.conv_state_l_stride = conv_state.stride(2);
|
||||
|
||||
if (cache_seqlens_.has_value()) {
|
||||
auto cache_seqlens = cache_seqlens_.value();
|
||||
TORCH_CHECK(cache_seqlens.scalar_type() == torch::kInt32);
|
||||
TORCH_CHECK(cache_seqlens.is_cuda());
|
||||
TORCH_CHECK(cache_seqlens.stride(-1) == 1);
|
||||
CHECK_SHAPE(cache_seqlens, batch_size);
|
||||
params.cache_seqlens = cache_seqlens.data_ptr<int32_t>();
|
||||
} else {
|
||||
params.cache_seqlens = nullptr;
|
||||
}
|
||||
|
||||
if (conv_state_indices_.has_value()) {
|
||||
auto conv_state_indices = conv_state_indices_.value();
|
||||
TORCH_CHECK(conv_state_indices.scalar_type() == torch::kInt32)
|
||||
TORCH_CHECK(conv_state_indices.is_cuda());
|
||||
TORCH_CHECK(conv_state_indices.stride(0) == 1)
|
||||
CHECK_SHAPE(conv_state_indices, batch_size);
|
||||
|
||||
int conv_state_entries = conv_state.size(0);
|
||||
CHECK_SHAPE(conv_state, conv_state_entries, dim, conv_state_len);
|
||||
|
||||
params.conv_state_indices_ptr = conv_state_indices.data_ptr<int32_t>();
|
||||
} else {
|
||||
CHECK_SHAPE(conv_state, batch_size, dim, conv_state_len);
|
||||
params.conv_state_indices_ptr = nullptr;
|
||||
}
|
||||
|
||||
const at::cuda::OptionalCUDAGuard device_guard(device_of(x));
|
||||
auto stream = at::cuda::getCurrentCUDAStream().stream();
|
||||
DISPATCH_WTYPE_ITYPE_FLOAT_AND_HALF_AND_BF16(x.scalar_type(), "causal_conv1d_update", [&] {
|
||||
causal_conv1d_update_cuda<input_t, weight_t>(params, stream);
|
||||
});
|
||||
}
|
||||
|
||||
template<int kNThreads_, int kWidth_, bool kIsVecLoad_, typename input_t_, typename weight_t_>
|
||||
struct Causal_conv1d_fwd_kernel_traits {
|
||||
using input_t = input_t_;
|
||||
using weight_t = weight_t_;
|
||||
static constexpr int kNThreads = kNThreads_;
|
||||
static constexpr int kWidth = kWidth_;
|
||||
static constexpr int kNBytes = sizeof(input_t);
|
||||
static_assert(kNBytes == 2 || kNBytes == 4);
|
||||
static constexpr int kNElts = kNBytes == 4 ? 4 : 8;
|
||||
static_assert(kWidth <= kNElts);
|
||||
static constexpr bool kIsVecLoad = kIsVecLoad_;
|
||||
using vec_t = typename BytesToType<kNBytes * kNElts>::Type;
|
||||
using BlockLoadT = cub::BlockLoad<input_t, kNThreads, kNElts, cub::BLOCK_LOAD_WARP_TRANSPOSE>;
|
||||
using BlockLoadVecT = cub::BlockLoad<vec_t, kNThreads, 1, cub::BLOCK_LOAD_DIRECT>;
|
||||
using BlockStoreT = cub::BlockStore<input_t, kNThreads, kNElts, cub::BLOCK_STORE_WARP_TRANSPOSE>;
|
||||
using BlockStoreVecT = cub::BlockStore<vec_t, kNThreads, 1, cub::BLOCK_STORE_DIRECT>;
|
||||
static constexpr int kSmemIOSize = kIsVecLoad
|
||||
? 0
|
||||
: custom_max({sizeof(typename BlockLoadT::TempStorage), sizeof(typename BlockStoreT::TempStorage)});
|
||||
static constexpr int kSmemExchangeSize = kNThreads * kNBytes * kNElts;
|
||||
static constexpr int kSmemSize = kSmemIOSize + kSmemExchangeSize;
|
||||
};
|
||||
|
||||
template<typename Ktraits>
|
||||
__global__ __launch_bounds__(Ktraits::kNThreads)
|
||||
void causal_conv1d_fwd_kernel(ConvParamsBase params) {
|
||||
constexpr int kWidth = Ktraits::kWidth;
|
||||
constexpr int kNThreads = Ktraits::kNThreads;
|
||||
constexpr int kNElts = Ktraits::kNElts;
|
||||
constexpr bool kIsVecLoad = Ktraits::kIsVecLoad;
|
||||
using input_t = typename Ktraits::input_t;
|
||||
using vec_t = typename Ktraits::vec_t;
|
||||
using weight_t = typename Ktraits::weight_t;
|
||||
|
||||
// Shared memory.
|
||||
extern __shared__ char smem_[];
|
||||
auto& smem_load = reinterpret_cast<typename Ktraits::BlockLoadT::TempStorage&>(smem_);
|
||||
auto& smem_load_vec = reinterpret_cast<typename Ktraits::BlockLoadVecT::TempStorage&>(smem_);
|
||||
auto& smem_store = reinterpret_cast<typename Ktraits::BlockStoreT::TempStorage&>(smem_);
|
||||
auto& smem_store_vec = reinterpret_cast<typename Ktraits::BlockStoreVecT::TempStorage&>(smem_);
|
||||
vec_t *smem_exchange = reinterpret_cast<vec_t *>(smem_ + Ktraits::kSmemIOSize);
|
||||
|
||||
const bool kVarlen = params.query_start_loc_ptr != nullptr;
|
||||
const int tidx = threadIdx.x;
|
||||
const int batch_id = blockIdx.x;
|
||||
const int channel_id = blockIdx.y;
|
||||
const int *query_start_loc = kVarlen ? reinterpret_cast<int *>(params.query_start_loc_ptr) : nullptr;
|
||||
const int sequence_start_index = kVarlen ? query_start_loc[batch_id] : batch_id;
|
||||
const int seqlen = kVarlen ? query_start_loc[batch_id + 1] - sequence_start_index : params.seqlen;
|
||||
|
||||
input_t *x = reinterpret_cast<input_t *>(params.x_ptr) + sequence_start_index * params.x_batch_stride
|
||||
+ channel_id * params.x_c_stride;
|
||||
weight_t *weight = reinterpret_cast<weight_t *>(params.weight_ptr) + channel_id * params.weight_c_stride;
|
||||
input_t *out = reinterpret_cast<input_t *>(params.out_ptr) + sequence_start_index * params.out_batch_stride
|
||||
+ channel_id * params.out_c_stride;
|
||||
float bias_val = params.bias_ptr == nullptr ? 0.f : float(reinterpret_cast<weight_t *>(params.bias_ptr)[channel_id]);
|
||||
|
||||
bool has_initial_state = params.has_initial_state_ptr == nullptr ? false
|
||||
: reinterpret_cast<bool *>(params.has_initial_state_ptr)[batch_id];
|
||||
|
||||
int* cache_indices = params.cache_indices_ptr == nullptr ? nullptr
|
||||
: reinterpret_cast<int *>(params.cache_indices_ptr);
|
||||
int cache_index = cache_indices == nullptr ? batch_id : cache_indices[batch_id];
|
||||
// cache_index == params.pad_slot_id is defined as padding, so we exit early
|
||||
if (cache_index == params.pad_slot_id){
|
||||
return;
|
||||
}
|
||||
input_t *conv_states = params.conv_states_ptr == nullptr ? nullptr
|
||||
: reinterpret_cast<input_t *>(params.conv_states_ptr) + cache_index * params.conv_states_batch_stride + channel_id * params.conv_states_c_stride;
|
||||
|
||||
// Thread 0 will load the last elements of the previous chunk, so we initialize those to 0.
|
||||
if (tidx == 0) {
|
||||
input_t initial_state[kNElts] = {0};
|
||||
if (has_initial_state) {
|
||||
#pragma unroll
|
||||
for (int w = 0; w < kWidth - 1; ++w){ initial_state[kNElts - 1 - (kWidth - 2) + w ] = conv_states[w]; }
|
||||
}
|
||||
smem_exchange[kNThreads - 1] = reinterpret_cast<vec_t *>(initial_state)[0];
|
||||
}
|
||||
|
||||
float weight_vals[kWidth];
|
||||
#pragma unroll
|
||||
for (int i = 0; i < kWidth; ++i) { weight_vals[i] = float(weight[i * params.weight_width_stride]); }
|
||||
|
||||
constexpr int kChunkSize = kNThreads * kNElts;
|
||||
const int n_chunks = (seqlen + kChunkSize - 1) / kChunkSize;
|
||||
for (int chunk = 0; chunk < n_chunks; ++chunk) {
|
||||
input_t x_vals_load[2 * kNElts] = {0};
|
||||
if constexpr(kIsVecLoad) {
|
||||
typename Ktraits::BlockLoadVecT(smem_load_vec).Load(reinterpret_cast<vec_t*>(x), *reinterpret_cast<vec_t (*)[1]>(&x_vals_load[kNElts]), (seqlen - chunk * kChunkSize) / kNElts);
|
||||
} else {
|
||||
__syncthreads();
|
||||
typename Ktraits::BlockLoadT(smem_load).Load(x, *reinterpret_cast<input_t (*)[kNElts]>(&x_vals_load[kNElts]), seqlen - chunk * kChunkSize);
|
||||
}
|
||||
x += kChunkSize;
|
||||
__syncthreads();
|
||||
// Thread kNThreads - 1 don't write yet, so that thread 0 can read
|
||||
// the last elements of the previous chunk.
|
||||
if (tidx < kNThreads - 1) { smem_exchange[tidx] = reinterpret_cast<vec_t *>(x_vals_load)[1]; }
|
||||
__syncthreads();
|
||||
reinterpret_cast<vec_t *>(x_vals_load)[0] = smem_exchange[tidx > 0 ? tidx - 1 : kNThreads - 1];
|
||||
__syncthreads();
|
||||
// Now thread kNThreads - 1 can write the last elements of the current chunk.
|
||||
if (tidx == kNThreads - 1) { smem_exchange[tidx] = reinterpret_cast<vec_t *>(x_vals_load)[1]; }
|
||||
|
||||
float x_vals[2 * kNElts];
|
||||
#pragma unroll
|
||||
for (int i = 0; i < 2 * kNElts; ++i) { x_vals[i] = float(x_vals_load[i]); }
|
||||
|
||||
float out_vals[kNElts];
|
||||
#pragma unroll
|
||||
for (int i = 0; i < kNElts; ++i) {
|
||||
out_vals[i] = bias_val;
|
||||
#pragma unroll
|
||||
for (int w = 0; w < kWidth; ++w) {
|
||||
out_vals[i] += weight_vals[w] * x_vals[kNElts + i - (kWidth - w - 1)];
|
||||
}
|
||||
}
|
||||
|
||||
if (params.silu_activation) {
|
||||
#pragma unroll
|
||||
for (int i = 0; i < kNElts; ++i) {
|
||||
out_vals[i] = out_vals[i] / (1 + expf(-out_vals[i]));
|
||||
}
|
||||
}
|
||||
|
||||
input_t out_vals_store[kNElts];
|
||||
#pragma unroll
|
||||
for (int i = 0; i < kNElts; ++i) { out_vals_store[i] = out_vals[i]; }
|
||||
if constexpr(kIsVecLoad) {
|
||||
typename Ktraits::BlockStoreVecT(smem_store_vec).Store(reinterpret_cast<vec_t*>(out), reinterpret_cast<vec_t (&)[1]>(out_vals_store), (seqlen - chunk * kChunkSize) / kNElts);
|
||||
} else {
|
||||
typename Ktraits::BlockStoreT(smem_store).Store(out, out_vals_store, seqlen - chunk * kChunkSize);
|
||||
}
|
||||
out += kChunkSize;
|
||||
|
||||
int final_state_position = ((seqlen - (kWidth - 1)) - (n_chunks - 1) * kChunkSize);
|
||||
// in case the final state is separated between the last "smem_exchange" and
|
||||
// and the one before it (chunk = n_chunks - 1 and chunk = n_chunks - 2),
|
||||
// (which occurs when `final_state_position` is a non-positive index)
|
||||
// we load the correct data from smem_exchange from both chunks, the last chunk iteration and the one before it
|
||||
if (conv_states != nullptr && final_state_position < 0 && seqlen > kWidth){
|
||||
input_t vals_load[kNElts] = {0};
|
||||
if ((chunk == n_chunks - 2) && (tidx == kNThreads - 1)){
|
||||
// chunk = n_chunks - 2, a segment of the final state sits in the last index
|
||||
reinterpret_cast<vec_t *>(vals_load)[0] = smem_exchange[kNThreads - 1];
|
||||
#pragma unroll
|
||||
for (int w = 0; w < -final_state_position; ++w){
|
||||
conv_states[w] = vals_load[kNElts + final_state_position + w];
|
||||
}
|
||||
}
|
||||
if ((chunk == n_chunks - 1) && tidx == 0){
|
||||
// chunk = n_chunks - 1, the second segment of the final state first positions
|
||||
reinterpret_cast<vec_t *>(vals_load)[0] = smem_exchange[0];
|
||||
for (int w = -final_state_position; w < kWidth - 1; ++w){
|
||||
conv_states[w] = vals_load[w + final_state_position];
|
||||
}
|
||||
return;
|
||||
}
|
||||
}
|
||||
}
|
||||
// Final state is stored in the smem_exchange last token slot,
|
||||
// in case seqlen < kWidth, we would need to take the final state from the
|
||||
// initial state which is stored in conv_states
|
||||
// in case seqlen > kWidth, we would need to load the last kWidth - 1 data
|
||||
// and load it into conv_state accordingly
|
||||
int last_thread = ((seqlen - (kWidth - 1)) - (n_chunks - 1) * kChunkSize) / kNElts;
|
||||
if (conv_states != nullptr && tidx == last_thread) {
|
||||
input_t x_vals_load[kNElts * 2] = {0};
|
||||
// in case we are on the first kWidth tokens
|
||||
if (last_thread == 0 && seqlen < kWidth){
|
||||
// Need to take the initial state
|
||||
reinterpret_cast<vec_t *>(x_vals_load)[0] = smem_exchange[0];
|
||||
const int offset = seqlen - (kWidth - 1);
|
||||
#pragma unroll
|
||||
for (int w = 0; w < kWidth - 1; ++w){
|
||||
// pad the existing state
|
||||
if ((w - seqlen) >= 0 && has_initial_state) { conv_states[w - seqlen] = conv_states[w]; }
|
||||
else if ((w - seqlen) >= 0 && !has_initial_state) { conv_states[w - seqlen] = input_t(0.0f); }
|
||||
}
|
||||
#pragma unroll
|
||||
for (int w = 0; w < kWidth - 1; ++w){
|
||||
if (offset + w >= 0)
|
||||
conv_states[w] = x_vals_load[offset + w ];
|
||||
}
|
||||
}
|
||||
else {
|
||||
// in case the final state is in between the threads data
|
||||
const int offset = ((seqlen - (kWidth - 1)) % (kNElts));
|
||||
if ((offset + kWidth - 2) >= kNElts && (last_thread + 1 < kNThreads)){
|
||||
// In case last_thread == kNThreads - 1, accessing last_thread + 1 will result in a
|
||||
// illegal access error on H100.
|
||||
// Therefore, we access last_thread + 1, only if the final state data sits there
|
||||
reinterpret_cast<vec_t *>(x_vals_load)[1] = smem_exchange[last_thread + 1];
|
||||
}
|
||||
reinterpret_cast<vec_t *>(x_vals_load)[0] = smem_exchange[last_thread];
|
||||
#pragma unroll
|
||||
for (int w = 0; w < kWidth - 1; ++w){
|
||||
conv_states[w] = x_vals_load[offset + w ];
|
||||
}
|
||||
}
|
||||
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
template<int kNThreads, int kWidth, typename input_t, typename weight_t>
|
||||
void causal_conv1d_fwd_launch(ConvParamsBase ¶ms, cudaStream_t stream) {
|
||||
static constexpr int kNElts = sizeof(input_t) == 4 ? 4 : 8;
|
||||
const bool kVarlen = params.query_start_loc_ptr != nullptr;
|
||||
BOOL_SWITCH(params.seqlen % kNElts == 0 && !kVarlen, kIsVecLoad, [&] {
|
||||
using Ktraits = Causal_conv1d_fwd_kernel_traits<kNThreads, kWidth, kIsVecLoad, input_t, weight_t>;
|
||||
constexpr int kSmemSize = Ktraits::kSmemSize;
|
||||
dim3 grid(params.batch, params.dim);
|
||||
|
||||
auto kernel = &causal_conv1d_fwd_kernel<Ktraits>;
|
||||
|
||||
if (kSmemSize >= 48 * 1024) {
|
||||
C10_CUDA_CHECK(cudaFuncSetAttribute(
|
||||
(void *) kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, kSmemSize));
|
||||
std::cerr << "Warning (causal_conv1d fwd launch): attempting to set maxDynamicSharedMemorySize on an AMD GPU which is currently a non-op (in ROCm versions <= 6.1). This might lead to undefined behavior. \n" << std::endl;
|
||||
}
|
||||
kernel<<<grid, Ktraits::kNThreads, kSmemSize, stream>>>(params);
|
||||
|
||||
C10_CUDA_KERNEL_LAUNCH_CHECK();
|
||||
});
|
||||
}
|
||||
|
||||
template<typename input_t, typename weight_t>
|
||||
void causal_conv1d_fwd_cuda(ConvParamsBase ¶ms, cudaStream_t stream) {
|
||||
if (params.width == 2) {
|
||||
causal_conv1d_fwd_launch<128, 2, input_t, weight_t>(params, stream);
|
||||
} else if (params.width == 3) {
|
||||
causal_conv1d_fwd_launch<128, 3, input_t, weight_t>(params, stream);
|
||||
} else if (params.width == 4) {
|
||||
causal_conv1d_fwd_launch<128, 4, input_t, weight_t>(params, stream);
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
template void causal_conv1d_fwd_cuda<float, float>(ConvParamsBase ¶ms, cudaStream_t stream);
|
||||
template void causal_conv1d_fwd_cuda<at::Half, at::Half>(ConvParamsBase ¶ms, cudaStream_t stream);
|
||||
template void causal_conv1d_fwd_cuda<at::BFloat16, at::BFloat16>(ConvParamsBase ¶ms, cudaStream_t stream);
|
||||
|
||||
|
||||
|
||||
|
||||
template<int kNThreads_, int kWidth_, typename input_t_, typename weight_t_>
|
||||
struct Causal_conv1d_update_kernel_traits {
|
||||
using input_t = input_t_;
|
||||
using weight_t = weight_t_;
|
||||
static constexpr int kNThreads = kNThreads_;
|
||||
static constexpr int kWidth = kWidth_;
|
||||
static constexpr int kNBytes = sizeof(input_t);
|
||||
static_assert(kNBytes == 2 || kNBytes == 4);
|
||||
};
|
||||
|
||||
template<typename Ktraits, bool kIsCircularBuffer>
|
||||
__global__ __launch_bounds__(Ktraits::kNThreads)
|
||||
void causal_conv1d_update_kernel(ConvParamsBase params) {
|
||||
constexpr int kWidth = Ktraits::kWidth;
|
||||
constexpr int kNThreads = Ktraits::kNThreads;
|
||||
using input_t = typename Ktraits::input_t;
|
||||
using weight_t = typename Ktraits::weight_t;
|
||||
|
||||
const int tidx = threadIdx.x;
|
||||
const int batch_id = blockIdx.x;
|
||||
const int channel_id = blockIdx.y * kNThreads + tidx;
|
||||
if (channel_id >= params.dim) return;
|
||||
|
||||
input_t *x = reinterpret_cast<input_t *>(params.x_ptr) + batch_id * params.x_batch_stride
|
||||
+ channel_id * params.x_c_stride;
|
||||
|
||||
// If params.conv_state_batch_indices is set, then the conv state is gathered from the conv state tensor
|
||||
// along the batch axis. Otherwise, the conv state coordinate is the same as the batch id.
|
||||
const int conv_state_batch_coord = params.conv_state_indices_ptr == nullptr
|
||||
? batch_id
|
||||
: params.conv_state_indices_ptr[batch_id];
|
||||
// conv_state_batch_coord == params.pad_slot_id is defined as padding so we exit early
|
||||
if (conv_state_batch_coord == params.pad_slot_id){
|
||||
return;
|
||||
}
|
||||
input_t *conv_state = reinterpret_cast<input_t *>(params.conv_state_ptr)
|
||||
+ conv_state_batch_coord * params.conv_state_batch_stride
|
||||
+ channel_id * params.conv_state_c_stride;
|
||||
|
||||
weight_t *weight = reinterpret_cast<weight_t *>(params.weight_ptr) + channel_id * params.weight_c_stride;
|
||||
input_t *out = reinterpret_cast<input_t *>(params.out_ptr) + batch_id * params.out_batch_stride
|
||||
+ channel_id * params.out_c_stride;
|
||||
float bias_val = params.bias_ptr == nullptr ? 0.f : float(reinterpret_cast<weight_t *>(params.bias_ptr)[channel_id]);
|
||||
|
||||
int state_len = params.conv_state_len;
|
||||
int advance_len = params.seqlen;
|
||||
int cache_seqlen = kIsCircularBuffer ? params.cache_seqlens[batch_id] % state_len : 0;
|
||||
int update_idx = cache_seqlen - (kWidth - 1);
|
||||
update_idx = update_idx < 0 ? update_idx + state_len : update_idx;
|
||||
|
||||
float weight_vals[kWidth] = {0};
|
||||
#pragma unroll
|
||||
for (int i = 0; i < kWidth; ++i) { weight_vals[i] = float(weight[i * params.weight_width_stride]); }
|
||||
|
||||
float x_vals[kWidth] = {0};
|
||||
if constexpr (!kIsCircularBuffer) {
|
||||
#pragma unroll 2
|
||||
for (int i = 0; i < state_len - advance_len - (kWidth - 1); ++i) {
|
||||
conv_state[i * params.conv_state_l_stride] = conv_state[(i + advance_len) * params.conv_state_l_stride];
|
||||
}
|
||||
#pragma unroll
|
||||
for (int i = 0; i < kWidth - 1; ++i) {
|
||||
input_t state_val = conv_state[(state_len - (kWidth - 1) + i) * params.conv_state_l_stride];
|
||||
if (i < advance_len + (kWidth - 1) && state_len - advance_len - (kWidth - 1) + i >= 0) {
|
||||
conv_state[(state_len - advance_len - (kWidth - 1) + i) * params.conv_state_l_stride] = state_val;
|
||||
}
|
||||
x_vals[i] = float(state_val);
|
||||
}
|
||||
} else {
|
||||
#pragma unroll
|
||||
for (int i = 0; i < kWidth - 1; ++i, update_idx = update_idx + 1 >= state_len ? update_idx + 1 - state_len : update_idx + 1) {
|
||||
input_t state_val = conv_state[update_idx * params.conv_state_l_stride];
|
||||
x_vals[i] = float(state_val);
|
||||
}
|
||||
}
|
||||
#pragma unroll 2
|
||||
for (int i = 0; i < params.seqlen; ++i) {
|
||||
input_t x_val = x[i * params.x_l_stride];
|
||||
if constexpr (!kIsCircularBuffer) {
|
||||
if (i < advance_len && state_len - advance_len + i >= 0) {
|
||||
conv_state[(state_len - advance_len + i) * params.conv_state_l_stride] = x_val;
|
||||
}
|
||||
} else {
|
||||
conv_state[update_idx * params.conv_state_l_stride] = x_val;
|
||||
++update_idx;
|
||||
update_idx = update_idx >= state_len ? update_idx - state_len : update_idx;
|
||||
}
|
||||
x_vals[kWidth - 1] = float(x_val);
|
||||
float out_val = bias_val;
|
||||
#pragma unroll
|
||||
for (int j = 0; j < kWidth; ++j) { out_val += weight_vals[j] * x_vals[j]; }
|
||||
if (params.silu_activation) { out_val = out_val / (1 + expf(-out_val)); }
|
||||
out[i * params.out_l_stride] = input_t(out_val);
|
||||
// Shift the input buffer by 1
|
||||
#pragma unroll
|
||||
for (int i = 0; i < kWidth - 1; ++i) { x_vals[i] = x_vals[i + 1]; }
|
||||
}
|
||||
}
|
||||
|
||||
template<int kNThreads, int kWidth, typename input_t, typename weight_t>
|
||||
void causal_conv1d_update_launch(ConvParamsBase ¶ms, cudaStream_t stream) {
|
||||
using Ktraits = Causal_conv1d_update_kernel_traits<kNThreads, kWidth, input_t, weight_t>;
|
||||
dim3 grid(params.batch, (params.dim + kNThreads - 1) / kNThreads);
|
||||
auto kernel = params.cache_seqlens == nullptr
|
||||
? &causal_conv1d_update_kernel<Ktraits, false>
|
||||
: &causal_conv1d_update_kernel<Ktraits, true>;
|
||||
kernel<<<grid, Ktraits::kNThreads, 0, stream>>>(params);
|
||||
C10_CUDA_KERNEL_LAUNCH_CHECK();
|
||||
}
|
||||
|
||||
template<typename input_t, typename weight_t>
|
||||
void causal_conv1d_update_cuda(ConvParamsBase ¶ms, cudaStream_t stream) {
|
||||
if (params.width == 2) {
|
||||
causal_conv1d_update_launch<64, 2, input_t, weight_t>(params, stream);
|
||||
} else if (params.width == 3) {
|
||||
causal_conv1d_update_launch<64, 3, input_t, weight_t>(params, stream);
|
||||
} else if (params.width == 4) {
|
||||
causal_conv1d_update_launch<64, 4, input_t, weight_t>(params, stream);
|
||||
}
|
||||
}
|
||||
|
||||
template void causal_conv1d_update_cuda<float, float>(ConvParamsBase ¶ms, cudaStream_t stream);
|
||||
template void causal_conv1d_update_cuda<at::Half, at::Half>(ConvParamsBase ¶ms, cudaStream_t stream);
|
||||
template void causal_conv1d_update_cuda<at::BFloat16, at::BFloat16>(ConvParamsBase ¶ms, cudaStream_t stream);
|
||||
@ -1,159 +0,0 @@
|
||||
/******************************************************************************
|
||||
* Copyright (c) 2024, Tri Dao.
|
||||
******************************************************************************/
|
||||
// clang-format off
|
||||
// adapted from https://github.com/Dao-AILab/causal-conv1d/blob/main/csrc/causal_conv1d.h
|
||||
#pragma once
|
||||
|
||||
#include <cuda_bf16.h>
|
||||
#include <cuda_fp16.h>
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
struct ConvParamsBase {
|
||||
using index_t = uint32_t;
|
||||
|
||||
int batch, dim, seqlen, width;
|
||||
int64_t pad_slot_id;
|
||||
bool silu_activation;
|
||||
|
||||
index_t x_batch_stride;
|
||||
index_t x_c_stride;
|
||||
index_t x_l_stride;
|
||||
index_t weight_c_stride;
|
||||
index_t weight_width_stride;
|
||||
index_t out_batch_stride;
|
||||
index_t out_c_stride;
|
||||
index_t out_l_stride;
|
||||
|
||||
int conv_state_len;
|
||||
index_t conv_state_batch_stride;
|
||||
index_t conv_state_c_stride;
|
||||
index_t conv_state_l_stride;
|
||||
|
||||
// Common data pointers.
|
||||
void *__restrict__ x_ptr;
|
||||
void *__restrict__ weight_ptr;
|
||||
void *__restrict__ bias_ptr;
|
||||
void *__restrict__ out_ptr;
|
||||
|
||||
void *__restrict__ conv_state_ptr;
|
||||
void *__restrict__ query_start_loc_ptr;
|
||||
void *__restrict__ has_initial_state_ptr;
|
||||
void *__restrict__ cache_indices_ptr;
|
||||
int32_t *__restrict__ cache_seqlens;
|
||||
|
||||
// For the continuous batching case. Makes it so that the mamba state for
|
||||
// the current batch doesn't need to be a contiguous tensor.
|
||||
int32_t *__restrict__ conv_state_indices_ptr;
|
||||
|
||||
void *__restrict__ seq_idx_ptr;
|
||||
|
||||
// No __restrict__ since initial_states could be the same as final_states.
|
||||
void * initial_states_ptr;
|
||||
index_t initial_states_batch_stride;
|
||||
index_t initial_states_l_stride;
|
||||
index_t initial_states_c_stride;
|
||||
|
||||
void * final_states_ptr;
|
||||
index_t final_states_batch_stride;
|
||||
index_t final_states_l_stride;
|
||||
index_t final_states_c_stride;
|
||||
|
||||
void * conv_states_ptr;
|
||||
index_t conv_states_batch_stride;
|
||||
index_t conv_states_l_stride;
|
||||
index_t conv_states_c_stride;
|
||||
};
|
||||
|
||||
|
||||
#ifndef USE_ROCM
|
||||
#include <cuda_bf16.h>
|
||||
|
||||
template<typename T>
|
||||
__device__ inline T shuffle_xor(T val, int offset) {
|
||||
return __shfl_xor_sync(uint32_t(-1), val, offset);
|
||||
}
|
||||
|
||||
constexpr size_t custom_max(std::initializer_list<size_t> ilist)
|
||||
{
|
||||
return std::max(ilist);
|
||||
}
|
||||
|
||||
template<typename T>
|
||||
constexpr T constexpr_min(T a, T b) {
|
||||
return std::min(a, b);
|
||||
}
|
||||
|
||||
#else
|
||||
#include <hip/hip_bf16.h>
|
||||
|
||||
template<typename T>
|
||||
__device__ inline T shuffle_xor(T val, int offset) {
|
||||
return __shfl_xor(val, offset);
|
||||
}
|
||||
constexpr size_t custom_max(std::initializer_list<size_t> ilist)
|
||||
{
|
||||
return *std::max_element(ilist.begin(), ilist.end());
|
||||
}
|
||||
|
||||
template<typename T>
|
||||
constexpr T constexpr_min(T a, T b) {
|
||||
return a < b ? a : b;
|
||||
}
|
||||
#endif
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
template<int BYTES> struct BytesToType {};
|
||||
|
||||
template<> struct BytesToType<16> {
|
||||
using Type = uint4;
|
||||
static_assert(sizeof(Type) == 16);
|
||||
};
|
||||
|
||||
template<> struct BytesToType<8> {
|
||||
using Type = uint64_t;
|
||||
static_assert(sizeof(Type) == 8);
|
||||
};
|
||||
|
||||
template<> struct BytesToType<4> {
|
||||
using Type = uint32_t;
|
||||
static_assert(sizeof(Type) == 4);
|
||||
};
|
||||
|
||||
template<> struct BytesToType<2> {
|
||||
using Type = uint16_t;
|
||||
static_assert(sizeof(Type) == 2);
|
||||
};
|
||||
|
||||
template<> struct BytesToType<1> {
|
||||
using Type = uint8_t;
|
||||
static_assert(sizeof(Type) == 1);
|
||||
};
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
template<typename T>
|
||||
struct SumOp {
|
||||
__device__ inline T operator()(T const & x, T const & y) { return x + y; }
|
||||
};
|
||||
|
||||
template<int THREADS>
|
||||
struct Allreduce {
|
||||
static_assert(THREADS == 32 || THREADS == 16 || THREADS == 8 || THREADS == 4);
|
||||
template<typename T, typename Operator>
|
||||
static __device__ inline T run(T x, Operator &op) {
|
||||
constexpr int OFFSET = THREADS / 2;
|
||||
x = op(x, __shfl_xor_sync(uint32_t(-1), x, OFFSET));
|
||||
return Allreduce<OFFSET>::run(x, op);
|
||||
}
|
||||
};
|
||||
|
||||
template<>
|
||||
struct Allreduce<2> {
|
||||
template<typename T, typename Operator>
|
||||
static __device__ inline T run(T x, Operator &op) {
|
||||
x = op(x, __shfl_xor_sync(uint32_t(-1), x, 1));
|
||||
return x;
|
||||
}
|
||||
};
|
||||
@ -1,28 +0,0 @@
|
||||
// Inspired by
|
||||
// https://github.com/NVIDIA/DALI/blob/main/include/dali/core/static_switch.h
|
||||
// and https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/Dispatch.h
|
||||
// clang-format off
|
||||
// adapted from https://github.com/Dao-AILab/causal-conv1d/blob/main/csrc/static_switch.h
|
||||
|
||||
#pragma once
|
||||
|
||||
/// @param COND - a boolean expression to switch by
|
||||
/// @param CONST_NAME - a name given for the constexpr bool variable.
|
||||
/// @param ... - code to execute for true and false
|
||||
///
|
||||
/// Usage:
|
||||
/// ```
|
||||
/// BOOL_SWITCH(flag, BoolConst, [&] {
|
||||
/// some_function<BoolConst>(...);
|
||||
/// });
|
||||
/// ```
|
||||
#define BOOL_SWITCH(COND, CONST_NAME, ...) \
|
||||
[&] { \
|
||||
if (COND) { \
|
||||
static constexpr bool CONST_NAME = true; \
|
||||
return __VA_ARGS__(); \
|
||||
} else { \
|
||||
static constexpr bool CONST_NAME = false; \
|
||||
return __VA_ARGS__(); \
|
||||
} \
|
||||
}()
|
||||
16
csrc/ops.h
16
csrc/ops.h
@ -326,22 +326,6 @@ void selective_scan_fwd(const torch::Tensor& u, const torch::Tensor& delta,
|
||||
const std::optional<torch::Tensor>& has_initial_state,
|
||||
const torch::Tensor& ssm_states, int64_t pad_slot_id);
|
||||
|
||||
void causal_conv1d_update(const at::Tensor& x, const at::Tensor& conv_state,
|
||||
const at::Tensor& weight,
|
||||
const std::optional<at::Tensor>& bias_,
|
||||
bool silu_activation,
|
||||
const std::optional<at::Tensor>& cache_seqlens_,
|
||||
const std::optional<at::Tensor>& conv_state_indices_,
|
||||
int64_t pad_slot_id);
|
||||
|
||||
void causal_conv1d_fwd(const at::Tensor& x, const at::Tensor& weight,
|
||||
const std::optional<at::Tensor>& bias_,
|
||||
const std::optional<at::Tensor>& conv_states,
|
||||
const std::optional<at::Tensor>& query_start_loc,
|
||||
const std::optional<at::Tensor>& cache_indices,
|
||||
const std::optional<at::Tensor>& has_initial_state,
|
||||
bool silu_activation, int64_t pad_slot_id);
|
||||
|
||||
using fptr_t = int64_t;
|
||||
fptr_t init_custom_ar(const std::vector<int64_t>& fake_ipc_ptrs,
|
||||
torch::Tensor& rank_data, int64_t rank,
|
||||
|
||||
@ -594,28 +594,6 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
|
||||
"int pad_slot_id) -> ()");
|
||||
ops.impl("selective_scan_fwd", torch::kCUDA, &selective_scan_fwd);
|
||||
|
||||
ops.def(
|
||||
"causal_conv1d_update(Tensor! x,"
|
||||
"Tensor! conv_state,"
|
||||
"Tensor! weight,"
|
||||
"Tensor? bias_,"
|
||||
"bool silu_activation,"
|
||||
"Tensor? cache_seqlens_,"
|
||||
"Tensor? conv_state_indices,"
|
||||
"int pad_slot_id) -> ()");
|
||||
ops.impl("causal_conv1d_update", torch::kCUDA, &causal_conv1d_update);
|
||||
|
||||
ops.def(
|
||||
"causal_conv1d_fwd(Tensor! x, Tensor! weight,"
|
||||
"Tensor? bias_,"
|
||||
"Tensor!? conv_states,"
|
||||
"Tensor? query_start_loc,"
|
||||
"Tensor? cache_indices,"
|
||||
"Tensor? has_initial_state,"
|
||||
"bool silu_activation,"
|
||||
"int pad_slot_id) -> ()");
|
||||
ops.impl("causal_conv1d_fwd", torch::kCUDA, &causal_conv1d_fwd);
|
||||
|
||||
#ifndef USE_ROCM
|
||||
// reorder weight for AllSpark Ampere W8A16 Fused Gemm kernel
|
||||
ops.def(
|
||||
|
||||
@ -6,9 +6,8 @@ from typing import Optional
|
||||
import pytest
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
from einops import rearrange
|
||||
|
||||
from tests.kernels.utils import opcheck
|
||||
from vllm import _custom_ops as ops # noqa: F401
|
||||
from vllm.attention.backends.utils import PAD_SLOT_ID
|
||||
from vllm.model_executor.layers.mamba.ops.causal_conv1d import (
|
||||
causal_conv1d_fn, causal_conv1d_update)
|
||||
@ -144,79 +143,6 @@ def causal_conv1d_opcheck_fn(x: torch.Tensor,
|
||||
x = x.contiguous()
|
||||
bias = bias.contiguous() if bias is not None else None
|
||||
|
||||
opcheck(torch.ops._C.causal_conv1d_fwd,
|
||||
(x, weight, bias, conv_states, cu_seq_len, cache_indices,
|
||||
has_initial_state, activation in ["silu", "swish"], pad_slot_id))
|
||||
|
||||
|
||||
@pytest.mark.parametrize("itype", [torch.bfloat16, torch.float])
|
||||
@pytest.mark.parametrize("silu_activation", [True])
|
||||
@pytest.mark.parametrize("has_bias", [True])
|
||||
@pytest.mark.parametrize("has_initial_state", [True, False])
|
||||
@pytest.mark.parametrize("width", [4])
|
||||
@pytest.mark.parametrize(
|
||||
'seqlen', [1, 8, 16, 32, 64, 128, 256, 512, 784, 1024, 1025, 2048, 4096])
|
||||
@pytest.mark.parametrize('dim', [64])
|
||||
@pytest.mark.parametrize('batch', [1])
|
||||
def test_causal_conv1d(batch, dim, seqlen, width, has_bias, silu_activation,
|
||||
has_initial_state, itype):
|
||||
device = "cuda"
|
||||
rtol, atol = (3e-4, 1e-3) if itype == torch.float32 else (3e-3, 5e-3)
|
||||
if itype == torch.bfloat16:
|
||||
rtol, atol = 1e-2, 5e-2
|
||||
# set seed
|
||||
current_platform.seed_everything(0)
|
||||
x = torch.randn(batch, dim, seqlen, device=device,
|
||||
dtype=itype).contiguous()
|
||||
|
||||
weight = torch.randn(dim, width, device=device, dtype=itype)
|
||||
bias = torch.randn(dim, device=device, dtype=itype) if has_bias else None
|
||||
if has_initial_state:
|
||||
initial_states = torch.randn(batch,
|
||||
dim,
|
||||
width - 1,
|
||||
device=device,
|
||||
dtype=itype)
|
||||
has_initial_state_tensor = torch.ones(batch,
|
||||
dtype=torch.bool,
|
||||
device=x.device)
|
||||
else:
|
||||
initial_states = None
|
||||
has_initial_state_tensor = None
|
||||
x_ref = x.clone()
|
||||
weight_ref = weight.clone()
|
||||
bias_ref = bias.clone() if bias is not None else None
|
||||
initial_states_ref = initial_states.clone(
|
||||
) if initial_states is not None else None
|
||||
activation = None if not silu_activation else "silu"
|
||||
out = causal_conv1d_fn(x,
|
||||
weight,
|
||||
bias,
|
||||
activation=activation,
|
||||
conv_states=initial_states,
|
||||
has_initial_state=has_initial_state_tensor)
|
||||
out_ref, final_states_ref = causal_conv1d_ref(
|
||||
x_ref,
|
||||
weight_ref,
|
||||
bias_ref,
|
||||
initial_states=initial_states_ref,
|
||||
return_final_states=True,
|
||||
activation=activation)
|
||||
if has_initial_state:
|
||||
assert initial_states is not None and final_states_ref is not None
|
||||
assert torch.allclose(initial_states,
|
||||
final_states_ref,
|
||||
rtol=rtol,
|
||||
atol=atol)
|
||||
assert torch.allclose(out, out_ref, rtol=rtol, atol=atol)
|
||||
|
||||
causal_conv1d_opcheck_fn(x,
|
||||
weight,
|
||||
bias,
|
||||
activation=activation,
|
||||
conv_states=initial_states,
|
||||
has_initial_state=has_initial_state_tensor)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("itype", [torch.bfloat16])
|
||||
@pytest.mark.parametrize("silu_activation", [False, True])
|
||||
@ -255,22 +181,19 @@ def test_causal_conv1d_update(dim, width, seqlen, has_bias, silu_activation,
|
||||
assert torch.equal(conv_state, conv_state_ref)
|
||||
assert torch.allclose(out, out_ref, rtol=rtol, atol=atol)
|
||||
|
||||
opcheck(torch.ops._C.causal_conv1d_update,
|
||||
(x, conv_state, weight, bias, activation
|
||||
in ["silu", "swish"], None, None, PAD_SLOT_ID))
|
||||
|
||||
|
||||
@pytest.mark.parametrize("itype",
|
||||
[torch.float32, torch.float16, torch.bfloat16])
|
||||
@pytest.mark.parametrize("silu_activation", [False, True])
|
||||
@pytest.mark.parametrize("has_bias", [False, True])
|
||||
@pytest.mark.parametrize("seqlen", [1, 4, 5])
|
||||
@pytest.mark.parametrize("width", [2, 3, 4])
|
||||
@pytest.mark.parametrize("dim", [2048, 2048 + 16, 4096])
|
||||
@pytest.mark.parametrize("seqlen", [1, 3])
|
||||
@pytest.mark.parametrize("width", [3, 4])
|
||||
@pytest.mark.parametrize("dim", [2048 + 16, 4096])
|
||||
# tests correctness in case subset of the sequences are padded
|
||||
@pytest.mark.parametrize("with_padding", [True, False])
|
||||
def test_causal_conv1d_update_with_batch_gather(with_padding, dim, width,
|
||||
seqlen, has_bias,
|
||||
@pytest.mark.parametrize("batch_size", [3])
|
||||
def test_causal_conv1d_update_with_batch_gather(batch_size, with_padding, dim,
|
||||
width, seqlen, has_bias,
|
||||
silu_activation, itype):
|
||||
device = "cuda"
|
||||
rtol, atol = (3e-4, 1e-3) if itype == torch.float32 else (3e-3, 5e-3)
|
||||
@ -280,12 +203,15 @@ def test_causal_conv1d_update_with_batch_gather(with_padding, dim, width,
|
||||
# set seed
|
||||
current_platform.seed_everything(0)
|
||||
|
||||
batch_size = 3
|
||||
padding = 5 if with_padding else 0
|
||||
padded_batch_size = batch_size + padding
|
||||
# total_entries = number of cache line
|
||||
total_entries = 10 * batch_size
|
||||
|
||||
x = torch.randn(padded_batch_size, dim, 1, device=device, dtype=itype)
|
||||
# x will be (batch, dim, seqlen) with contiguous along dim-axis
|
||||
x = torch.randn(padded_batch_size, seqlen, dim, device=device,
|
||||
dtype=itype).transpose(1, 2)
|
||||
|
||||
x_ref = x.clone()
|
||||
|
||||
conv_state_indices = torch.randperm(total_entries)[:batch_size].to(
|
||||
@ -300,17 +226,22 @@ def test_causal_conv1d_update_with_batch_gather(with_padding, dim, width,
|
||||
[PAD_SLOT_ID] * padding, dtype=torch.int32, device=device)
|
||||
],
|
||||
dim=0)
|
||||
|
||||
# conv_state will be (cache_lines, dim, state_len)
|
||||
# with contiguous along dim-axis
|
||||
conv_state = torch.randn(total_entries,
|
||||
dim,
|
||||
width - 1,
|
||||
dim,
|
||||
device=device,
|
||||
dtype=itype)
|
||||
dtype=itype).transpose(1, 2)
|
||||
|
||||
conv_state_for_padding_test = conv_state.clone()
|
||||
|
||||
weight = torch.randn(dim, width, device=device, dtype=itype)
|
||||
bias = torch.randn(dim, device=device, dtype=itype) if has_bias else None
|
||||
conv_state_ref = conv_state[conv_state_indices, :].detach().clone()
|
||||
activation = None if not silu_activation else "silu"
|
||||
|
||||
out = causal_conv1d_update(x,
|
||||
conv_state,
|
||||
weight,
|
||||
@ -325,26 +256,21 @@ def test_causal_conv1d_update_with_batch_gather(with_padding, dim, width,
|
||||
activation=activation)
|
||||
|
||||
assert torch.equal(conv_state[conv_state_indices, :], conv_state_ref)
|
||||
assert torch.allclose(out[:batch_size], out_ref, rtol=rtol, atol=atol)
|
||||
assert torch.equal(conv_state[unused_states_bool],
|
||||
conv_state_for_padding_test[unused_states_bool])
|
||||
|
||||
opcheck(torch.ops._C.causal_conv1d_update,
|
||||
(x, conv_state, weight, bias, activation
|
||||
in ["silu", "swish"], None, padded_state_indices, PAD_SLOT_ID))
|
||||
assert torch.allclose(out[:batch_size], out_ref, rtol=rtol, atol=atol)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("itype", [torch.bfloat16])
|
||||
@pytest.mark.parametrize("silu_activation", [True])
|
||||
@pytest.mark.parametrize("has_bias", [True])
|
||||
@pytest.mark.parametrize("width", [4])
|
||||
@pytest.mark.parametrize(
|
||||
'seqlen', [8, 16, 32, 64, 128, 256, 512, 784, 1024, 2048, 2049, 4096])
|
||||
@pytest.mark.parametrize('seqlen', [8, 30, 249, 2049, 4096])
|
||||
@pytest.mark.parametrize('dim', [64, 4096])
|
||||
# tests correctness in case subset of the sequences are padded
|
||||
@pytest.mark.parametrize('with_padding', [True, False])
|
||||
def test_causal_conv1d_varlen(with_padding, dim, seqlen, width, has_bias,
|
||||
silu_activation, itype):
|
||||
@pytest.mark.parametrize('batch', [4, 10])
|
||||
def test_causal_conv1d_varlen(batch, with_padding, dim, seqlen, width,
|
||||
has_bias, silu_activation, itype):
|
||||
device = "cuda"
|
||||
torch.cuda.empty_cache()
|
||||
rtol, atol = (3e-4, 1e-3) if itype == torch.float32 else (3e-3, 5e-3)
|
||||
@ -353,14 +279,13 @@ def test_causal_conv1d_varlen(with_padding, dim, seqlen, width, has_bias,
|
||||
# set seed
|
||||
current_platform.seed_everything(0)
|
||||
seqlens = []
|
||||
batch_size = 4
|
||||
if seqlen < 10:
|
||||
batch_size = 1
|
||||
batch_size = batch
|
||||
padding = 3 if with_padding else 0
|
||||
padded_batch_size = batch_size + padding
|
||||
nsplits = padded_batch_size - 1
|
||||
|
||||
eos_pos = torch.randperm(seqlen - 1)[:nsplits].sort().values
|
||||
|
||||
seqlens.append(
|
||||
torch.diff(
|
||||
torch.cat(
|
||||
@ -373,19 +298,22 @@ def test_causal_conv1d_varlen(with_padding, dim, seqlen, width, has_bias,
|
||||
cumsum = torch.cumsum(torch.tensor(seqlens[0]), dim=0).to(torch.int32)
|
||||
cumsum = torch.concat([torch.tensor([0], dtype=torch.int32), cumsum],
|
||||
dim=0)
|
||||
x = torch.randn(1, 4096 + dim + 64, seqlen, device=device,
|
||||
dtype=itype)[:, 4096:4096 + dim, :]
|
||||
x = rearrange(
|
||||
torch.randn(1, seqlen, 4096 + dim + 64, device=device, dtype=itype),
|
||||
"b s d -> b d s")[:, 4096:4096 + dim, :]
|
||||
|
||||
weight = torch.randn(dim, width, device=device, dtype=itype)
|
||||
|
||||
bias = torch.randn(dim, device=device, dtype=itype) if has_bias else None
|
||||
x_ref = x.clone()
|
||||
weight_ref = weight.clone()
|
||||
bias_ref = bias.clone() if bias is not None else None
|
||||
activation = None if not silu_activation else "silu"
|
||||
final_states = torch.randn(total_entries,
|
||||
dim,
|
||||
width - 1,
|
||||
dim,
|
||||
device=x.device,
|
||||
dtype=x.dtype)
|
||||
dtype=x.dtype).transpose(1, 2)
|
||||
final_states_ref = final_states.clone()
|
||||
has_initial_states = torch.randint(0,
|
||||
2, (cumsum.shape[0] - 1, ),
|
||||
@ -400,10 +328,16 @@ def test_causal_conv1d_varlen(with_padding, dim, seqlen, width, has_bias,
|
||||
[PAD_SLOT_ID] * padding, dtype=torch.int32, device=device),
|
||||
],
|
||||
dim=-1)
|
||||
out = causal_conv1d_fn(x.squeeze(0),
|
||||
weight,
|
||||
bias=bias,
|
||||
conv_states=final_states,
|
||||
query_start_loc=cumsum.cuda(),
|
||||
cache_indices=padded_state_indices,
|
||||
has_initial_state=has_initial_states,
|
||||
activation=activation,
|
||||
pad_slot_id=PAD_SLOT_ID)
|
||||
|
||||
out = causal_conv1d_fn(x.squeeze(0), weight, bias, cumsum.cuda(),
|
||||
padded_state_indices, has_initial_states,
|
||||
final_states, activation, PAD_SLOT_ID)
|
||||
out_ref = []
|
||||
out_ref_b = []
|
||||
|
||||
@ -426,13 +360,9 @@ def test_causal_conv1d_varlen(with_padding, dim, seqlen, width, has_bias,
|
||||
out_ref.append(torch.cat([t[0] for t in out_ref_b], dim=2))
|
||||
out_ref_tensor = torch.cat(out_ref, dim=0)
|
||||
|
||||
unpadded_out = out[:, :out_ref_tensor.shape[-1]]
|
||||
assert torch.allclose(unpadded_out, out_ref_tensor, rtol=rtol, atol=atol)
|
||||
assert torch.allclose(final_states[state_indices],
|
||||
final_states_ref[state_indices],
|
||||
rtol=rtol,
|
||||
atol=atol)
|
||||
|
||||
causal_conv1d_opcheck_fn(x.squeeze(0), weight, bias, cumsum.cuda(),
|
||||
padded_state_indices, has_initial_states,
|
||||
final_states, activation)
|
||||
unpadded_out = out[:, :out_ref_tensor.shape[-1]]
|
||||
assert torch.allclose(unpadded_out, out_ref_tensor, rtol=rtol, atol=atol)
|
||||
|
||||
@ -6,11 +6,11 @@ import torch
|
||||
import torch.nn.functional as F
|
||||
from einops import rearrange, repeat
|
||||
|
||||
from vllm.model_executor.layers.mamba.mamba2_metadata import (
|
||||
_query_start_loc_to_chunk_indices_offsets)
|
||||
from vllm.model_executor.layers.mamba.ops.ssd_combined import (
|
||||
mamba_chunk_scan_combined)
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.v1.attention.backends.mamba_attn import (
|
||||
_query_start_loc_to_chunk_indices_offsets)
|
||||
|
||||
# Added by the IBM Team, 2024
|
||||
|
||||
|
||||
@ -963,17 +963,17 @@ def cutlass_fp4_moe_mm(a_tensors: torch.Tensor, b_tensors: torch.Tensor,
|
||||
expert_offsets: torch.Tensor, sf_offsets: torch.Tensor,
|
||||
out_dtype: torch.dtype, device: torch.device):
|
||||
"""
|
||||
An FP4 Blockscaled Group Gemm that takes in a_tensors, b_tensors and runs
|
||||
An FP4 Blockscaled Group Gemm that takes in a_tensors, b_tensors and runs
|
||||
the gemms for each combination based on the specified problem sizes.
|
||||
|
||||
This is used as the MoE gemm during NVFP4 Quantized FusedMoE forward.
|
||||
- a/b_tensors: the NVFP4 a_ptrs and b_ptrs tensors which are quantized
|
||||
input and expert weights.
|
||||
- a_/b_scales: The blockscales in FP8-E4M3 precision
|
||||
- expert_offsets/sf_offsets: Indices that mark at which token index
|
||||
each expert begins its computation. The number of tokens
|
||||
computed with expert E is expert_offsets[E + 1] -
|
||||
expert_offsets[E] And the sf_size per expert is
|
||||
- expert_offsets/sf_offsets: Indices that mark at which token index
|
||||
each expert begins its computation. The number of tokens
|
||||
computed with expert E is expert_offsets[E + 1] -
|
||||
expert_offsets[E] And the sf_size per expert is
|
||||
sf_offset[E+1] - sf_offset[E]
|
||||
- problem_sizes: MxNxK sizes of each expert's multiplication in two grouped
|
||||
MMs used in the fused MoE operation.
|
||||
@ -1464,30 +1464,6 @@ def ggml_moe_get_block_size(quant_type: int) -> int:
|
||||
|
||||
|
||||
# mamba
|
||||
def causal_conv1d_fwd(x: torch.Tensor, weight: torch.Tensor,
|
||||
bias_: Optional[torch.Tensor],
|
||||
conv_states: Optional[torch.Tensor],
|
||||
query_start_loc: Optional[torch.Tensor],
|
||||
cache_indices: Optional[torch.Tensor],
|
||||
has_initial_state: Optional[torch.Tensor],
|
||||
silu_activation: bool, pad_slot_id: int):
|
||||
torch.ops._C.causal_conv1d_fwd(x, weight, bias_, conv_states,
|
||||
query_start_loc, cache_indices,
|
||||
has_initial_state, silu_activation,
|
||||
pad_slot_id)
|
||||
|
||||
|
||||
def causal_conv1d_update(x: torch.Tensor, conv_state: torch.Tensor,
|
||||
weight: torch.Tensor, bias_: Optional[torch.Tensor],
|
||||
silu_activation: bool,
|
||||
cache_seqlens: Optional[torch.Tensor],
|
||||
conv_state_indices: Optional[torch.Tensor],
|
||||
pad_slot_id: int):
|
||||
torch.ops._C.causal_conv1d_update(x, conv_state, weight, bias_,
|
||||
silu_activation, cache_seqlens,
|
||||
conv_state_indices, pad_slot_id)
|
||||
|
||||
|
||||
def selective_scan_fwd(u: torch.Tensor, delta: torch.Tensor, A: torch.Tensor,
|
||||
B: torch.Tensor, C: torch.Tensor,
|
||||
D_: Optional[torch.Tensor], z_: Optional[torch.Tensor],
|
||||
|
||||
@ -1,14 +1,18 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
import math
|
||||
from dataclasses import dataclass
|
||||
from typing import Optional, Union
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
|
||||
from vllm.attention.backends.abstract import AttentionMetadata
|
||||
from vllm.attention.backends.placeholder_attn import (
|
||||
PlaceholderAttentionMetadata)
|
||||
from vllm.attention.backends.utils import PAD_SLOT_ID
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.v1.attention.backends.mamba_attn import (
|
||||
Mamba2AttentionMetadata, _query_start_loc_to_chunk_indices_offsets)
|
||||
|
||||
|
||||
@dataclass
|
||||
@ -21,6 +25,29 @@ class Mamba2Metadata:
|
||||
seq_idx: torch.Tensor
|
||||
chunk_indices: torch.Tensor
|
||||
chunk_offsets: torch.Tensor
|
||||
"""
|
||||
With continuous batching layout of `x` in vLLM, to enable a Triton program
|
||||
to handle a request in parallel, two supporting tensors are used
|
||||
(batch_ptr, token_chunk_offset_ptr)
|
||||
BLOCK_M = the # tokens to be handled by a Triton program
|
||||
(can be customized for different hardware)
|
||||
|
||||
nums_dict:
|
||||
tracks the data associated with a given value of BLOCK_M
|
||||
BLOCK_M = #tokens handled by a Triton program
|
||||
cu_seqlen: total tokens per batch
|
||||
(used as flag to update other data at each new input)
|
||||
batch_ptr: tracks batch-id handled by the Triton program
|
||||
token_chunk_offset_ptr: tracks token group_idx handled by the Triton program
|
||||
(Triton implementation of causal_conv1d handles parallelism in 3-axes
|
||||
- feature-axis
|
||||
- batch-axis
|
||||
- sequence-axis)
|
||||
"""
|
||||
nums_dict: Optional[dict] = None
|
||||
cu_seqlen: Optional[int] = None
|
||||
batch_ptr: Optional[torch.tensor] = None
|
||||
token_chunk_offset_ptr: Optional[torch.tensor] = None
|
||||
|
||||
|
||||
def get_platform_metadata_classes() -> tuple[type[AttentionMetadata], ...]:
|
||||
@ -38,45 +65,10 @@ def get_platform_metadata_classes() -> tuple[type[AttentionMetadata], ...]:
|
||||
f"Unsupported platform for Mamba2: {current_platform.device_type}")
|
||||
|
||||
|
||||
def _query_start_loc_to_chunk_indices_offsets(query_start_loc: torch.Tensor,
|
||||
chunk_size: int,
|
||||
total_seqlens: int):
|
||||
|
||||
cu_seqlens = query_start_loc[1:] # remove prepended 0
|
||||
|
||||
# outputs will have length expansion of chunks that do not divide
|
||||
# chunk_size
|
||||
N = math.ceil(total_seqlens / chunk_size) + (cu_seqlens[:-1] % chunk_size
|
||||
> 0).sum()
|
||||
chunk_indices = torch.arange(N,
|
||||
dtype=torch.int,
|
||||
device=query_start_loc.device)
|
||||
chunk_offsets = torch.zeros((N, ),
|
||||
dtype=torch.int,
|
||||
device=query_start_loc.device)
|
||||
|
||||
p = 0 # num of insertions
|
||||
for s, e in zip(cu_seqlens[:-1], cu_seqlens[1:]):
|
||||
|
||||
# if does not divide chunk_size, then there is one chunk insertion
|
||||
p += (s % chunk_size > 0)
|
||||
|
||||
# get the dimensions
|
||||
# - the + 1 for _e is to shift the boundary by one chunk
|
||||
# - this shifting is not needed if chunk_size divides e
|
||||
_s, _e = s // chunk_size + p, e // chunk_size + p + (e % chunk_size
|
||||
> 0)
|
||||
|
||||
# adjust inidces and offsets
|
||||
chunk_indices[_s:_e] -= p
|
||||
chunk_offsets[_s] = s % chunk_size
|
||||
|
||||
return chunk_indices, chunk_offsets
|
||||
|
||||
|
||||
def prepare_mamba2_metadata(
|
||||
chunk_size: int,
|
||||
attn_metadata: AttentionMetadata,
|
||||
mamba2_metadata=None,
|
||||
) -> Mamba2Metadata:
|
||||
|
||||
# compute number of prefill and decode requests
|
||||
@ -96,12 +88,12 @@ def prepare_mamba2_metadata(
|
||||
attn_metadata_instances = get_platform_metadata_classes()
|
||||
if (isinstance(attn_metadata, attn_metadata_instances)
|
||||
and attn_metadata.context_lens_tensor is not None):
|
||||
has_initial_states = \
|
||||
attn_metadata.context_lens_tensor[:num_prefills] > 0 #[batch,]
|
||||
# precompute flag to avoid device syncs in mamba2 layer forwards
|
||||
# precompute flag to avoid device syncs later in mamba2 layer
|
||||
# forwards
|
||||
# prep is only needed for mamba2 ssd prefill processing
|
||||
prep_initial_states = torch.any(has_initial_states).item()
|
||||
|
||||
has_initial_states = attn_metadata.context_lens_tensor > 0
|
||||
prep_initial_states = torch.any(
|
||||
has_initial_states[:num_prefills]).item()
|
||||
query_start_loc = attn_metadata.query_start_loc[:num_prefills + 1]
|
||||
seq_idx = torch.repeat_interleave(torch.arange(
|
||||
num_prefills, dtype=torch.int32, device=query_start_loc.device),
|
||||
@ -117,9 +109,78 @@ def prepare_mamba2_metadata(
|
||||
_query_start_loc_to_chunk_indices_offsets(
|
||||
query_start_loc, chunk_size, num_prefill_tokens)
|
||||
|
||||
if mamba2_metadata is not None:
|
||||
mamba2_metadata.has_initial_states = has_initial_states
|
||||
mamba2_metadata.prep_initial_states = prep_initial_states
|
||||
mamba2_metadata.chunk_size = chunk_size
|
||||
mamba2_metadata.seq_idx = seq_idx
|
||||
mamba2_metadata.chunk_indices = chunk_indices
|
||||
mamba2_metadata.chunk_offsets = chunk_offsets
|
||||
# We use 1 reset flag:
|
||||
# * mamba2_metadata.cu_seqlen is None
|
||||
# update config specific to (each input)
|
||||
# (become available at first layer, e.g. conv_weights)
|
||||
mamba2_metadata.cu_seqlen = None # suppose to be updated at each input
|
||||
|
||||
return mamba2_metadata
|
||||
return Mamba2Metadata(has_initial_states=has_initial_states,
|
||||
prep_initial_states=prep_initial_states,
|
||||
chunk_size=chunk_size,
|
||||
seq_idx=seq_idx,
|
||||
chunk_indices=chunk_indices,
|
||||
chunk_offsets=chunk_offsets)
|
||||
|
||||
|
||||
def update_metadata(x: torch.Tensor, query_start_loc: torch.Tensor,
|
||||
mamba2_metadata: Union[Mamba2Metadata,
|
||||
Mamba2AttentionMetadata]):
|
||||
"""
|
||||
this is triggered upon handling a new input at the first layer
|
||||
"""
|
||||
dim, cu_seqlen = x.shape
|
||||
mamba2_metadata.cu_seqlen = cu_seqlen
|
||||
seqlens = np.diff(query_start_loc.to('cpu'))
|
||||
nums_dict = {} # type: ignore
|
||||
for BLOCK_M in [8]: # cover all BLOCK_M values
|
||||
nums = -(-seqlens // BLOCK_M)
|
||||
nums_dict[BLOCK_M] = {}
|
||||
nums_dict[BLOCK_M]['nums'] = nums
|
||||
nums_dict[BLOCK_M]['tot'] = nums.sum().item()
|
||||
mlist = torch.from_numpy(np.repeat(np.arange(len(nums)), nums))
|
||||
nums_dict[BLOCK_M]['mlist'] = mlist
|
||||
mlist_len = len(nums_dict[BLOCK_M]['mlist'])
|
||||
nums_dict[BLOCK_M]['mlist_len'] = mlist_len
|
||||
MAX_NUM_PROGRAMS = max(1024, mlist_len) * 2
|
||||
offsetlist = [] # type: ignore
|
||||
for idx, num in enumerate(nums):
|
||||
offsetlist.extend(range(num))
|
||||
offsetlist = torch.tensor(offsetlist, dtype=torch.int32)
|
||||
nums_dict[BLOCK_M]['offsetlist'] = offsetlist
|
||||
|
||||
if mamba2_metadata.batch_ptr is None:
|
||||
# Update default value after class definition
|
||||
#mamba2_metadata.MAX_NUM_PROGRAMS *= 2
|
||||
mamba2_metadata.batch_ptr = torch.full((MAX_NUM_PROGRAMS, ),
|
||||
PAD_SLOT_ID,
|
||||
dtype=torch.int32,
|
||||
device='cuda')
|
||||
mamba2_metadata.token_chunk_offset_ptr = torch.full(
|
||||
(MAX_NUM_PROGRAMS, ),
|
||||
PAD_SLOT_ID,
|
||||
dtype=torch.int32,
|
||||
device='cuda')
|
||||
else:
|
||||
if mamba2_metadata.batch_ptr.nelement() < MAX_NUM_PROGRAMS:
|
||||
mamba2_metadata.batch_ptr.resize_(MAX_NUM_PROGRAMS).fill_(
|
||||
PAD_SLOT_ID)
|
||||
mamba2_metadata.token_chunk_offset_ptr.resize_( # type: ignore
|
||||
MAX_NUM_PROGRAMS).fill_(PAD_SLOT_ID)
|
||||
|
||||
mamba2_metadata.batch_ptr[0:mlist_len].copy_(mlist)
|
||||
mamba2_metadata.token_chunk_offset_ptr[ # type: ignore
|
||||
0:mlist_len].copy_(offsetlist)
|
||||
nums_dict[BLOCK_M]['batch_ptr'] = mamba2_metadata.batch_ptr
|
||||
nums_dict[BLOCK_M]['token_chunk_offset_ptr'] = (
|
||||
mamba2_metadata.token_chunk_offset_ptr) # type: ignore
|
||||
mamba2_metadata.nums_dict = nums_dict
|
||||
return mamba2_metadata
|
||||
|
||||
@ -159,7 +159,7 @@ class MambaMixer(CustomOp):
|
||||
hidden_states = causal_conv1d_fn(
|
||||
hidden_states,
|
||||
conv_weights,
|
||||
self.conv1d.bias,
|
||||
bias=self.conv1d.bias,
|
||||
activation=self.activation,
|
||||
conv_states=mamba_cache_params.conv_state,
|
||||
has_initial_state=attn_metadata.context_lens_tensor > 0,
|
||||
|
||||
@ -17,7 +17,8 @@ from vllm.forward_context import get_forward_context
|
||||
from vllm.model_executor.custom_op import CustomOp
|
||||
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
|
||||
RowParallelLinear)
|
||||
from vllm.model_executor.layers.mamba.mamba2_metadata import Mamba2Metadata
|
||||
from vllm.model_executor.layers.mamba.mamba2_metadata import (Mamba2Metadata,
|
||||
update_metadata)
|
||||
from vllm.model_executor.layers.mamba.ops.causal_conv1d import (
|
||||
causal_conv1d_fn, causal_conv1d_update)
|
||||
from vllm.model_executor.layers.mamba.ops.mamba_ssm import (
|
||||
@ -161,9 +162,9 @@ def mamba_v2_sharded_weight_loader(
|
||||
tp_size: int,
|
||||
tp_rank: int,
|
||||
) -> LoaderFunction:
|
||||
"""Create a weight loader for mamba v2. This ensures that the projections
|
||||
are correctly sharded so that they can be split into x, B, C. It also
|
||||
ensures that all the groups corresponding to a head shard is placed
|
||||
"""Create a weight loader for mamba v2. This ensures that the projections
|
||||
are correctly sharded so that they can be split into x, B, C. It also
|
||||
ensures that all the groups corresponding to a head shard is placed
|
||||
together with it.
|
||||
"""
|
||||
|
||||
@ -458,9 +459,11 @@ class MambaMixer2(CustomOp):
|
||||
if attn_metadata is not None:
|
||||
assert isinstance(attn_metadata, dict)
|
||||
attn_metadata = attn_metadata[self.prefix]
|
||||
mamba2_metadata = attn_metadata
|
||||
assert isinstance(attn_metadata, Mamba2AttentionMetadata)
|
||||
self_kv_cache = self.kv_cache[forward_context.virtual_engine]
|
||||
conv_state = self_kv_cache[0]
|
||||
# conv_state = (..., dim, width-1) yet contiguous along 'dim'
|
||||
conv_state = self_kv_cache[0].transpose(-1, -2)
|
||||
ssm_state = self_kv_cache[1]
|
||||
state_indices_tensor = attn_metadata.state_indices_tensor
|
||||
has_initial_states_p = attn_metadata.has_initial_states
|
||||
@ -531,6 +534,7 @@ class MambaMixer2(CustomOp):
|
||||
# NOTE: V0 put prefill before decode, v1 puts decode before prefill
|
||||
# Separate prefill and decode by splitting varlen input
|
||||
# Split along token dimension
|
||||
# NOTE: V0 put prefill before decode, v1 puts decode before prefill
|
||||
if envs.VLLM_USE_V1:
|
||||
hidden_states_B_C_d, hidden_states_B_C_p = torch.split(
|
||||
hidden_states_B_C,
|
||||
@ -579,8 +583,13 @@ class MambaMixer2(CustomOp):
|
||||
# 2. Convolution sequence transformation
|
||||
# - "cache_indices" updates the conv_state cache in positions
|
||||
# pointed to by "state_indices_tensor"
|
||||
x = hidden_states_B_C_p.transpose(
|
||||
0, 1) # this is the form that causal-conv see
|
||||
if mamba2_metadata.cu_seqlen is None:
|
||||
mamba2_metadata = update_metadata(
|
||||
x, attn_metadata.query_start_loc, mamba2_metadata)
|
||||
hidden_states_B_C_p = causal_conv1d_fn(
|
||||
hidden_states_B_C_p.transpose(0, 1),
|
||||
x,
|
||||
conv_weights,
|
||||
self.conv1d.bias,
|
||||
activation=self.activation,
|
||||
@ -590,8 +599,6 @@ class MambaMixer2(CustomOp):
|
||||
query_start_loc=query_start_loc_p).transpose(
|
||||
0, 1)[:num_prefill_tokens]
|
||||
|
||||
# TODO: Why is this needed?
|
||||
hidden_states_B_C_p = hidden_states_B_C_p.contiguous()
|
||||
hidden_states_p, B_p, C_p = split_hidden_states_B_C_fn(
|
||||
hidden_states_B_C_p)
|
||||
|
||||
@ -715,9 +722,10 @@ class MambaMixer2(CustomOp):
|
||||
# - heads and n_groups are TP-ed
|
||||
conv_dim = (self.intermediate_size +
|
||||
2 * n_groups * self.ssm_state_size)
|
||||
# contiguous along 'dim' axis
|
||||
conv_state_shape = (
|
||||
divide(conv_dim, world_size),
|
||||
self.conv_kernel_size - 1,
|
||||
divide(conv_dim, world_size),
|
||||
)
|
||||
|
||||
# These are not TP-ed as they depend on A, dt_bias, D
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@ -36,10 +36,12 @@ class MambaCacheManager(ConstantSizeCache):
|
||||
# Initialize parent class
|
||||
super().__init__(max_batch_size)
|
||||
|
||||
# assume conv_state = (dim, state_len)
|
||||
assert conv_state_shape[0] > conv_state_shape[1]
|
||||
conv_state = torch.empty(size=(num_mamba_layers, max_batch_size) +
|
||||
conv_state_shape,
|
||||
(conv_state_shape[1], conv_state_shape[0]),
|
||||
dtype=dtype,
|
||||
device="cuda")
|
||||
device="cuda").transpose(-1, -2)
|
||||
temporal_state = torch.empty(size=(num_mamba_layers, max_batch_size) +
|
||||
temporal_state_shape,
|
||||
dtype=dtype,
|
||||
|
||||
@ -1,14 +1,13 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
import math
|
||||
from dataclasses import dataclass
|
||||
from typing import TYPE_CHECKING
|
||||
from typing import TYPE_CHECKING, Optional
|
||||
|
||||
import torch
|
||||
|
||||
from vllm.attention.backends.abstract import AttentionBackend
|
||||
from vllm.config import VllmConfig, get_layers_from_vllm_config
|
||||
from vllm.model_executor.layers.mamba.mamba2_metadata import (
|
||||
_query_start_loc_to_chunk_indices_offsets)
|
||||
from vllm.v1.attention.backends.utils import (AttentionMetadataBuilder,
|
||||
CommonAttentionMetadata)
|
||||
from vllm.v1.kv_cache_interface import MambaSpec
|
||||
@ -29,6 +28,42 @@ def get_mamba2_chunk_size(vllm_config: VllmConfig) -> int:
|
||||
return chunk_sizes.pop()
|
||||
|
||||
|
||||
def _query_start_loc_to_chunk_indices_offsets(query_start_loc: torch.Tensor,
|
||||
chunk_size: int,
|
||||
total_seqlens: int):
|
||||
|
||||
cu_seqlens = query_start_loc[1:] # remove prepended 0
|
||||
|
||||
# outputs will have length expansion of chunks that do not divide
|
||||
# chunk_size
|
||||
N = math.ceil(total_seqlens / chunk_size) + (cu_seqlens[:-1] % chunk_size
|
||||
> 0).sum()
|
||||
chunk_indices = torch.arange(N,
|
||||
dtype=torch.int,
|
||||
device=query_start_loc.device)
|
||||
chunk_offsets = torch.zeros((N, ),
|
||||
dtype=torch.int,
|
||||
device=query_start_loc.device)
|
||||
|
||||
p = 0 # num of insertions
|
||||
for s, e in zip(cu_seqlens[:-1], cu_seqlens[1:]):
|
||||
|
||||
# if does not divide chunk_size, then there is one chunk insertion
|
||||
p += (s % chunk_size > 0)
|
||||
|
||||
# get the dimensions
|
||||
# - the + 1 for _e is to shift the boundary by one chunk
|
||||
# - this shifting is not needed if chunk_size divides e
|
||||
_s, _e = s // chunk_size + p, e // chunk_size + p + (e % chunk_size
|
||||
> 0)
|
||||
|
||||
# adjust indices and offsets
|
||||
chunk_indices[_s:_e] -= p
|
||||
chunk_offsets[_s] = s % chunk_size
|
||||
|
||||
return chunk_indices, chunk_offsets
|
||||
|
||||
|
||||
class Mamba2AttentionBackend(AttentionBackend):
|
||||
|
||||
@staticmethod
|
||||
@ -53,6 +88,10 @@ class Mamba2AttentionMetadata:
|
||||
chunk_offsets: torch.Tensor
|
||||
|
||||
state_indices_tensor: torch.Tensor # shape: [batch,]
|
||||
nums_dict: Optional[dict] = None
|
||||
cu_seqlen: Optional[int] = None
|
||||
batch_ptr: Optional[torch.tensor] = None
|
||||
token_chunk_offset_ptr: Optional[torch.tensor] = None
|
||||
|
||||
|
||||
class Mamba2AttentionMetadataBuilder(
|
||||
|
||||
Reference in New Issue
Block a user