[Kernel] cuda kernels for upcoming decode context parallel feature (#23791)
Co-authored-by: hongchao <hongchao@msh.team>

csrc/cache.h
@@ -36,6 +36,13 @@ void concat_and_cache_mla(torch::Tensor& kv_c, torch::Tensor& k_pe,
                          const std::string& kv_cache_dtype,
                          torch::Tensor& scale);

void cp_fused_concat_and_cache_mla(torch::Tensor& kv_c, torch::Tensor& k_pe,
                                   torch::Tensor& cp_local_token_select_indices,
                                   torch::Tensor& kv_cache,
                                   torch::Tensor& slot_mapping,
                                   const std::string& kv_cache_dtype,
                                   torch::Tensor& scale);

// Just for unittest
void convert_fp8(torch::Tensor& dst_cache, torch::Tensor& src_cache,
                 const double scale, const std::string& kv_cache_dtype);
@@ -47,4 +54,12 @@ void gather_and_maybe_dequant_cache(
    torch::Tensor const& cu_seq_lens,  // [BATCH+1]
    int64_t batch_size, const std::string& kv_cache_dtype,
    torch::Tensor const& scale,
    std::optional<torch::Tensor> seq_starts = std::nullopt);

// TODO(hc): cp_gather_cache needs to support a scaled kv-cache in the future.
void cp_gather_cache(
    torch::Tensor const& src_cache,    // [NUM_BLOCKS, BLOCK_SIZE, ENTRIES...]
    torch::Tensor const& dst,          // [TOT_TOKENS, ENTRIES...]
    torch::Tensor const& block_table,  // [BATCH, BLOCK_INDICES]
    torch::Tensor const& cu_seq_lens,  // [BATCH+1]
    int64_t batch_size, std::optional<torch::Tensor> seq_starts = std::nullopt);
@@ -1,6 +1,7 @@
#include <torch/all.h>
#include <ATen/cuda/CUDAContext.h>
#include <c10/cuda/CUDAGuard.h>
#include <c10/cuda/CUDAException.h>

#include "cuda_utils.h"
#include "cuda_compat.h"
@@ -395,6 +396,51 @@ __global__ void concat_and_cache_mla_kernel(
  copy(k_pe, kv_cache, k_pe_stride, block_stride, pe_dim, kv_lora_rank);
}

template <typename scalar_t, typename cache_t, Fp8KVCacheDataType kv_dt>
__global__ void cp_fused_concat_and_cache_mla_kernel(
    const scalar_t* __restrict__ kv_c,  // [num_full_tokens, kv_lora_rank]
    const scalar_t* __restrict__ k_pe,  // [num_full_tokens, pe_dim]
    const int64_t* __restrict__ cp_local_token_select_indices,  // [num_tokens]
    cache_t* __restrict__ kv_cache,  // [num_blocks, block_size, (kv_lora_rank
                                     //  + pe_dim)]
    const int64_t* __restrict__ slot_mapping,  // [num_tokens]
    const int block_stride,                    //
    const int entry_stride,                    //
    const int kv_c_stride,                     //
    const int k_pe_stride,                     //
    const int kv_lora_rank,                    //
    const int pe_dim,                          //
    const int block_size,                      //
    const float* scale                         //
) {
  const int64_t token_idx = cp_local_token_select_indices[blockIdx.x];
  const int64_t slot_idx = slot_mapping[blockIdx.x];
  // NOTE: slot_idx can be -1 if the token is padded
  if (slot_idx < 0) {
    return;
  }
  const int64_t block_idx = slot_idx / block_size;
  const int64_t block_offset = slot_idx % block_size;

  auto copy = [&](const scalar_t* __restrict__ src, cache_t* __restrict__ dst,
                  int src_stride, int dst_stride, int size, int offset) {
    for (int i = threadIdx.x; i < size; i += blockDim.x) {
      const int64_t src_idx = token_idx * src_stride + i;
      const int64_t dst_idx =
          block_idx * block_stride + block_offset * entry_stride + i + offset;
      if constexpr (kv_dt == Fp8KVCacheDataType::kAuto) {
        dst[dst_idx] = src[src_idx];
      } else {
        dst[dst_idx] =
            fp8::scaled_convert<cache_t, scalar_t, kv_dt>(src[src_idx], *scale);
      }
    }
  };

  copy(kv_c, kv_cache, kv_c_stride, block_stride, kv_lora_rank, 0);
  copy(k_pe, kv_cache, k_pe_stride, block_stride, pe_dim, kv_lora_rank);
}

}  // namespace vllm

// KV_T is the data type of key and value tensors.
@@ -508,6 +554,20 @@ void reshape_and_cache_flash(
          kv_c_stride, k_pe_stride, kv_lora_rank, pe_dim, block_size,       \
          reinterpret_cast<const float*>(scale.data_ptr()));

// KV_T is the data type of key and value tensors.
// CACHE_T is the stored data type of kv-cache.
// KV_DTYPE is the real data type of kv-cache.
#define CALL_CP_FUSED_CONCAT_AND_CACHE_MLA(KV_T, CACHE_T, KV_DTYPE)         \
  vllm::cp_fused_concat_and_cache_mla_kernel<KV_T, CACHE_T, KV_DTYPE>       \
      <<<grid, block, 0, stream>>>(                                         \
          reinterpret_cast<KV_T*>(kv_c.data_ptr()),                         \
          reinterpret_cast<KV_T*>(k_pe.data_ptr()),                         \
          cp_local_token_select_indices.data_ptr<int64_t>(),                \
          reinterpret_cast<CACHE_T*>(kv_cache.data_ptr()),                  \
          slot_mapping.data_ptr<int64_t>(), block_stride, entry_stride,     \
          kv_c_stride, k_pe_stride, kv_lora_rank, pe_dim, block_size,       \
          reinterpret_cast<const float*>(scale.data_ptr()));

void concat_and_cache_mla(
    torch::Tensor& kv_c,  // [num_tokens, kv_lora_rank]
    torch::Tensor& k_pe,  // [num_tokens, pe_dim]
@@ -546,6 +606,50 @@ void concat_and_cache_mla(
                             CALL_CONCAT_AND_CACHE_MLA);
}

// Note(hc): cp_fused_concat_and_cache_mla fuses the following three kernel
// calls into one:
//   k_c_normed.index_select(0, cp_local_token_select_indices) +
//   k_pe.squeeze(1).index_select(0, cp_local_token_select_indices) +
//   concat_and_cache_mla.
void cp_fused_concat_and_cache_mla(
    torch::Tensor& kv_c,  // [num_total_tokens, kv_lora_rank]
    torch::Tensor& k_pe,  // [num_total_tokens, pe_dim]
    torch::Tensor& cp_local_token_select_indices,  // [num_tokens]
    torch::Tensor& kv_cache,  // [num_blocks, block_size, (kv_lora_rank +
                              //  pe_dim)]
    torch::Tensor& slot_mapping,  // [num_tokens] or [num_actual_tokens]
    const std::string& kv_cache_dtype, torch::Tensor& scale) {
  // NOTE(woosuk): In vLLM V1, key.size(0) can be different from
  // slot_mapping.size(0) because of padding for CUDA graphs.
  // In vLLM V0, key.size(0) is always equal to slot_mapping.size(0) because
  // both include padding.
  // In vLLM V1, however, key.size(0) can be larger than slot_mapping.size(0)
  // since key includes padding for CUDA graphs, while slot_mapping does not.
  // In this case, slot_mapping.size(0) represents the actual number of tokens
  // before padding.
  // For compatibility with both cases, we use slot_mapping.size(0) as the
  // number of tokens.
  int num_tokens = slot_mapping.size(0);
  int kv_lora_rank = kv_c.size(1);
  int pe_dim = k_pe.size(1);
  int block_size = kv_cache.size(1);

  TORCH_CHECK(kv_cache.size(2) == kv_lora_rank + pe_dim);

  int kv_c_stride = kv_c.stride(0);
  int k_pe_stride = k_pe.stride(0);
  int block_stride = kv_cache.stride(0);
  int entry_stride = kv_cache.stride(1);

  dim3 grid(num_tokens);
  dim3 block(std::min(kv_lora_rank, 512));
  const at::cuda::OptionalCUDAGuard device_guard(device_of(kv_c));
  const cudaStream_t stream = at::cuda::getCurrentCUDAStream();

  DISPATCH_BY_KV_CACHE_DTYPE(kv_c.dtype(), kv_cache_dtype,
                             CALL_CP_FUSED_CONCAT_AND_CACHE_MLA);
}
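For reference, here is a minimal Python sketch of the unfused sequence described in the Note(hc) above, written against the vllm._custom_ops wrappers. It is illustrative only (shapes follow the note's assumptions) and is not part of this commit.

# Hedged sketch: the three unfused steps that cp_fused_concat_and_cache_mla
# replaces with a single kernel launch.
import torch
from vllm import _custom_ops as ops

def unfused_reference(kv_c, k_pe, cp_local_token_select_indices, kv_cache,
                      slot_mapping, kv_cache_dtype, scale):
    # 1) pick out the tokens this context-parallel rank must cache
    kv_c_local = kv_c.index_select(0, cp_local_token_select_indices)
    # 2) same selection on the rotary part; squeeze(1) mirrors the Note(hc)
    #    shape [num_total_tokens, 1, pe_dim] -> [num_total_tokens, pe_dim]
    k_pe_local = k_pe.squeeze(1).index_select(0, cp_local_token_select_indices)
    # 3) concatenate and write both pieces into the paged MLA cache
    ops.concat_and_cache_mla(kv_c_local, k_pe_local, kv_cache, slot_mapping,
                             kv_cache_dtype, scale)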

namespace vllm {

template <typename Tout, typename Tin, Fp8KVCacheDataType kv_dt>
@@ -779,3 +883,146 @@ void gather_and_maybe_dequant_cache(

  DISPATCH_BY_KV_CACHE_DTYPE(dst.dtype(), kv_cache_dtype, CALL_GATHER_CACHE);
}

namespace vllm {
template <typename scalar_t>
// Note(hc): cp_gather_cache allows seq_starts that are not divisible by
// block_size.
__global__ void cp_gather_cache(
    const scalar_t* __restrict__ src_cache,   // [NUM_BLOCKS, BLOCK_SIZE,
                                              //  ENTRY_SIZE]
    scalar_t* __restrict__ dst,               // [TOT_TOKENS, ENTRY_SIZE]
    const int32_t* __restrict__ block_table,  // [BATCH, BLOCK_INDICES]
    const int32_t* __restrict__ cu_seq_lens,  // [BATCH+1]
    const int32_t block_size, const int32_t entry_size,
    const int64_t block_table_stride, const int64_t cache_block_stride,
    const int64_t cache_entry_stride, const int64_t dst_entry_stride,
    const int32_t* __restrict__ seq_starts  // Optional: starting offsets per
                                            // batch
) {
  const int64_t bid = blockIdx.x;  // Batch ID
  const int32_t num_splits = gridDim.y;
  const int32_t split = blockIdx.y;
  const int32_t seq_start = cu_seq_lens[bid];
  const int32_t seq_end = cu_seq_lens[bid + 1];
  const int32_t seq_len = seq_end - seq_start;
  const int32_t tot_slots = seq_len;
  const int32_t split_slots = cuda_utils::ceil_div(tot_slots, num_splits);

  const int32_t split_start = split * split_slots;
  const int32_t split_end = min((split + 1) * split_slots, tot_slots);

  const bool is_active_split = (split_start < tot_slots);
  const bool is_last_split = (split_end == tot_slots);

  if (!is_active_split) return;

  // Adjust the pointer for the block_table for this batch.
  // If seq_starts is provided, compute an offset based on it.
  const int32_t batch_offset = bid * block_table_stride;
  int32_t offset = split_start;
  if (seq_starts != nullptr) {
    offset += seq_starts[bid];
  }
  int32_t offset_div = offset / block_size;
  offset = offset % block_size;
  const int32_t* batch_block_table = block_table + batch_offset;

  // Adjust dst pointer based on the cumulative sequence lengths.
  dst += seq_start * dst_entry_stride;

  auto copy_entry = [&](const scalar_t* __restrict__ _src,
                        scalar_t* __restrict__ _dst) {
    for (int i = threadIdx.x; i < entry_size; i += blockDim.x)
      _dst[i] = _src[i];
  };

  for (int pid = split_start; pid < split_end; ++pid) {
    auto block_id = batch_block_table[offset_div];
    auto block_start_ptr = src_cache + block_id * cache_block_stride;
    auto block_dst_ptr = dst + pid * dst_entry_stride;
    copy_entry(block_start_ptr + offset * cache_entry_stride, block_dst_ptr);
    offset += 1;
    // bump to next block
    if (offset == block_size) {
      offset_div += 1;
      offset = 0;
    }
  }
}
}  // namespace vllm

// Macro to dispatch the kernel based on the data type.
#define CALL_CP_GATHER_CACHE(CPY_DTYPE)                                  \
  vllm::cp_gather_cache<CPY_DTYPE><<<grid, block, 0, stream>>>(          \
      reinterpret_cast<CPY_DTYPE*>(src_cache.data_ptr()),                \
      reinterpret_cast<CPY_DTYPE*>(dst.data_ptr()),                      \
      block_table.data_ptr<int32_t>(), cu_seq_lens.data_ptr<int32_t>(),  \
      block_size, entry_size, block_table_stride, cache_block_stride,    \
      cache_entry_stride, dst_entry_stride, seq_starts_ptr);

// Gather sequences from the cache into the destination tensor.
// - cu_seq_lens contains the cumulative sequence lengths for each batch
// - block_table contains the cache block indices for each sequence
// - Optionally, seq_starts (if provided) offsets the starting slot index by
//   seq_starts[bid]
void cp_gather_cache(
    torch::Tensor const& src_cache,    // [NUM_BLOCKS, BLOCK_SIZE, ENTRIES...]
    torch::Tensor const& dst,          // [TOT_TOKENS, ENTRIES...]
    torch::Tensor const& block_table,  // [BATCH, BLOCK_INDICES]
    torch::Tensor const& cu_seq_lens,  // [BATCH+1]
    int64_t batch_size,
    std::optional<torch::Tensor> seq_starts = std::nullopt) {
  at::cuda::OptionalCUDAGuard device_guard(src_cache.device());
  const cudaStream_t stream = at::cuda::getCurrentCUDAStream();

  int32_t block_size = src_cache.size(1);
  int32_t entry_size = src_cache.flatten(2, -1).size(2);

  TORCH_CHECK(block_table.dtype() == torch::kInt32,
              "block_table must be int32");
  TORCH_CHECK(cu_seq_lens.dtype() == torch::kInt32,
              "cu_seq_lens must be int32");
  if (seq_starts.has_value()) {
    TORCH_CHECK(seq_starts.value().dtype() == torch::kInt32,
                "seq_starts must be int32");
  }

  TORCH_CHECK(src_cache.device() == dst.device(),
              "src_cache and dst must be on the same device");
  TORCH_CHECK(src_cache.device() == block_table.device(),
              "src_cache and block_table must be on the same device");
  TORCH_CHECK(src_cache.device() == cu_seq_lens.device(),
              "src_cache and cu_seq_lens must be on the same device");
  if (seq_starts.has_value()) {
    TORCH_CHECK(src_cache.device() == seq_starts.value().device(),
                "src_cache and seq_starts must be on the same device");
  }

  int64_t block_table_stride = block_table.stride(0);
  int64_t cache_block_stride = src_cache.stride(0);
  int64_t cache_entry_stride = src_cache.stride(1);
  int64_t dst_entry_stride = dst.stride(0);

  // Decide on the number of splits based on the batch size.
  int num_splits = batch_size > 128 ? 2 : batch_size > 64 ? 4 : 16;
  dim3 grid(batch_size, num_splits);
  dim3 block(1024);

  TORCH_CHECK(src_cache.dtype() == dst.dtype(),
              "src_cache and dst must have the same dtype");

  const int dtype_bits = src_cache.element_size() * 8;
  const int32_t* seq_starts_ptr =
      seq_starts.has_value() ? seq_starts.value().data_ptr<int32_t>() : nullptr;

  if (dtype_bits == 32) {
    CALL_CP_GATHER_CACHE(uint32_t);
  } else if (dtype_bits == 16) {
    CALL_CP_GATHER_CACHE(uint16_t);
  } else if (dtype_bits == 8) {
    CALL_CP_GATHER_CACHE(uint8_t);
  } else {
    TORCH_CHECK(false, "Unsupported data type width: ", dtype_bits);
  }
}
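To make the indexing above easier to check, here is a hedged pure-PyTorch sketch of the gather that cp_gather_cache performs, including the optional per-batch seq_starts offset. It is illustrative only; the kernel additionally splits each sequence across blockIdx.y for parallelism.

import torch

def cp_gather_cache_reference(src_cache, dst, block_table, cu_seq_lens,
                              batch_size, seq_starts=None):
    # Sequential reference of the kernel's copy loop (assumed shapes:
    # src_cache [NUM_BLOCKS, BLOCK_SIZE, ENTRY_SIZE], dst [TOT_TOKENS, ENTRY_SIZE]).
    block_size = src_cache.size(1)
    for b in range(batch_size):
        seq_start = int(cu_seq_lens[b])
        seq_len = int(cu_seq_lens[b + 1]) - seq_start
        # seq_starts shifts the first gathered slot and does not have to be a
        # multiple of block_size.
        offset = int(seq_starts[b]) if seq_starts is not None else 0
        for i in range(seq_len):
            slot = offset + i
            block_id = int(block_table[b, slot // block_size])
            dst[seq_start + i] = src_cache[block_id, slot % block_size]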

@@ -686,6 +686,16 @@ TORCH_LIBRARY_EXPAND(CONCAT(TORCH_EXTENSION_NAME, _cache_ops), cache_ops) {
      " Tensor scale) -> ()");
  cache_ops.impl("concat_and_cache_mla", torch::kCUDA, &concat_and_cache_mla);

  cache_ops.def(
      "cp_fused_concat_and_cache_mla(Tensor kv_c, Tensor k_pe,"
      " Tensor cp_local_token_select_indices,"
      " Tensor! kv_cache,"
      " Tensor slot_mapping,"
      " str kv_cache_dtype,"
      " Tensor scale) -> ()");
  cache_ops.impl("cp_fused_concat_and_cache_mla", torch::kCUDA,
                 &cp_fused_concat_and_cache_mla);

  // Convert the key and value cache to fp8 data type.
  cache_ops.def(
      "convert_fp8(Tensor! dst_cache, Tensor src_cache, float scale, "

@@ -702,6 +712,11 @@ TORCH_LIBRARY_EXPAND(CONCAT(TORCH_EXTENSION_NAME, _cache_ops), cache_ops) {
      " Tensor scale, Tensor? seq_starts) -> ()");
  cache_ops.impl("gather_and_maybe_dequant_cache", torch::kCUDA,
                 &gather_and_maybe_dequant_cache);

  cache_ops.def(
      "cp_gather_cache(Tensor src_cache, Tensor! dst, Tensor block_table, "
      "Tensor cu_seq_lens, int batch_size, Tensor? seq_starts) -> ()");
  cache_ops.impl("cp_gather_cache", torch::kCUDA, &cp_gather_cache);
}

TORCH_LIBRARY_EXPAND(CONCAT(TORCH_EXTENSION_NAME, _cuda_utils), cuda_utils) {
@@ -790,6 +790,78 @@ def test_gather_and_maybe_dequant_cache_mla(kv_lora_rank, qk_rope_head_dim,
    torch.testing.assert_close(dst, expected)


@pytest.mark.parametrize("kv_lora_rank", [512])
@pytest.mark.parametrize("qk_rope_head_dim", [64])
@pytest.mark.parametrize("block_size", [16])
@pytest.mark.parametrize("num_blocks", [1024])
@pytest.mark.parametrize("max_seq_len", [512])
@pytest.mark.parametrize("batch_size", [8])
@pytest.mark.parametrize("dtype", [torch.float32])
@pytest.mark.parametrize("kv_cache_dtype",
                         ["auto"])  # You can also test "fp8" if needed.
@pytest.mark.parametrize("device", CUDA_DEVICES)
@torch.inference_mode()
def test_cp_gather_cache_mla(kv_lora_rank, qk_rope_head_dim, block_size,
                             num_blocks, max_seq_len, batch_size, dtype,
                             kv_cache_dtype, device):
    entry_size = kv_lora_rank + qk_rope_head_dim
    src_cache = _create_mla_cache(num_blocks, block_size, entry_size, dtype,
                                  kv_cache_dtype, device)
    _fill_mla_cache(src_cache, kv_cache_dtype=kv_cache_dtype)

    seq_len_tensor = torch.randint(0,
                                   max_seq_len + 1, (batch_size, ),
                                   device=device)

    total_tokens = seq_len_tensor.sum()
    cu_seq_lens = torch.empty((batch_size + 1),
                              dtype=torch.int32,
                              device=device)
    cu_seq_lens[0] = 0
    cu_seq_lens[1:] = seq_len_tensor.cumsum(dim=0).to(dtype=torch.int32)

    tot_blocks_tensor = (seq_len_tensor + block_size - 1) // block_size
    block_table = torch.empty((batch_size, num_blocks),
                              dtype=torch.int32,
                              device=device)

    for b in range(batch_size):
        perm = torch.randperm(num_blocks, device=device)
        block_table[b, :] = perm

    dst = torch.zeros((total_tokens, entry_size),
                      dtype=src_cache.dtype,
                      device=device)

    # Build the expected result: full blocks first, then the trailing
    # partial block of each sequence.
    expected_batches = []
    for b in range(batch_size):
        s = seq_len_tensor[b]
        if s == 0:
            continue
        tot = tot_blocks_tensor[b]
        blocks = block_table[b, :tot].tolist()

        gathered_rows = []
        for i in range(tot - 1):
            gathered_rows.append(src_cache[blocks[i]])
        remaining = s - (tot - 1) * block_size
        gathered_rows.append(src_cache[blocks[-1], :remaining, :])

        batch_expected = torch.cat(gathered_rows, dim=0)
        expected_batches.append(batch_expected)
    expected = torch.cat(expected_batches, dim=0)

    opcheck(
        torch.ops._C_cache_ops.cp_gather_cache,
        (src_cache, dst, block_table, cu_seq_lens, batch_size, None),
        test_utils=DEFAULT_OPCHECK_TEST_UTILS,
    )

    ops.cp_gather_cache(src_cache, dst, block_table, cu_seq_lens, batch_size)
    torch.testing.assert_close(dst, expected)


@pytest.mark.parametrize("kv_lora_rank", KV_LORA_RANKS)
@pytest.mark.parametrize("qk_rope_head_dim", QK_ROPE_HEAD_DIMS)
@pytest.mark.parametrize("num_tokens", NUM_TOKENS_MLA)
@@ -1625,6 +1625,20 @@ def concat_and_cache_mla(
                                                scale)


def cp_fused_concat_and_cache_mla(
    kv_c: torch.Tensor,
    k_pe: torch.Tensor,
    cp_local_token_select_indices: torch.Tensor,
    kv_cache: torch.Tensor,
    slot_mapping: torch.Tensor,
    kv_cache_dtype: str,
    scale: torch.Tensor,
) -> None:
    torch.ops._C_cache_ops.cp_fused_concat_and_cache_mla(
        kv_c, k_pe, cp_local_token_select_indices, kv_cache, slot_mapping,
        kv_cache_dtype, scale)
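A minimal usage sketch for this wrapper; the sizes and tensors below are hypothetical and for illustration only, not taken from this commit.

import torch
from vllm import _custom_ops as ops

# Hypothetical sizes for illustration only.
num_total_tokens, num_local_tokens = 32, 16
kv_lora_rank, pe_dim, block_size, num_blocks = 512, 64, 16, 8

kv_c = torch.randn(num_total_tokens, kv_lora_rank, device="cuda")
k_pe = torch.randn(num_total_tokens, pe_dim, device="cuda")
kv_cache = torch.zeros(num_blocks, block_size, kv_lora_rank + pe_dim,
                       device="cuda")
# Tokens owned by this decode-context-parallel rank and their cache slots.
idx = torch.randperm(num_total_tokens, device="cuda")[:num_local_tokens]
slot_mapping = torch.randperm(num_blocks * block_size,
                              device="cuda")[:num_local_tokens]
scale = torch.tensor(1.0, dtype=torch.float32, device="cuda")

ops.cp_fused_concat_and_cache_mla(kv_c, k_pe, idx, kv_cache, slot_mapping,
                                  "auto", scale)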


def copy_blocks(key_caches: list[torch.Tensor],
                value_caches: list[torch.Tensor],
                block_mapping: torch.Tensor) -> None:
@@ -1662,6 +1676,16 @@ def gather_and_maybe_dequant_cache(
                                                    scale, seq_starts)


def cp_gather_cache(src_cache: torch.Tensor,
                    dst: torch.Tensor,
                    block_table: torch.Tensor,
                    cu_seq_lens: torch.Tensor,
                    batch_size: int,
                    seq_starts: Optional[torch.Tensor] = None) -> None:
    torch.ops._C_cache_ops.cp_gather_cache(src_cache, dst, block_table,
                                           cu_seq_lens, batch_size, seq_starts)
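A hedged example of calling this wrapper with seq_starts, which (per the kernel note above) may start each sequence at an offset that is not a multiple of block_size; all sizes are illustrative.

import torch
from vllm import _custom_ops as ops

# Illustrative sizes only.
batch_size, block_size, entry_size, num_blocks = 2, 16, 576, 64

src_cache = torch.randn(num_blocks, block_size, entry_size, device="cuda")
block_table = torch.randint(0, num_blocks, (batch_size, num_blocks),
                            dtype=torch.int32, device="cuda")
seq_lens = torch.tensor([40, 25], device="cuda")
cu_seq_lens = torch.zeros(batch_size + 1, dtype=torch.int32, device="cuda")
cu_seq_lens[1:] = seq_lens.cumsum(0).to(torch.int32)
dst = torch.empty(int(seq_lens.sum()), entry_size, device="cuda")
# Skip 3 and 20 leading slots of each sequence; the offsets need not be
# divisible by block_size.
seq_starts = torch.tensor([3, 20], dtype=torch.int32, device="cuda")

ops.cp_gather_cache(src_cache, dst, block_table, cu_seq_lens, batch_size,
                    seq_starts)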


def get_device_attribute(attribute: int, device: int) -> int:
    return torch.ops._C_cuda_utils.get_device_attribute(attribute, device)