Compare commits

...

1 Commits

13 changed files with 3301 additions and 1 deletions

View File

@ -563,7 +563,8 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
cuda_archs_loose_intersection(MLA_ARCHS "10.0a" "${CUDA_ARCHS}")
if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 12.8 AND MLA_ARCHS)
set(SRCS
"csrc/attention/mla/cutlass_mla_kernels.cu")
"csrc/attention/mla/cutlass_mla_kernels.cu"
"csrc/attention/mla/sm100_cutlass_mla_kernel.cu")
set_gencode_flags_for_srcs(
SRCS "${SRCS}"
CUDA_ARCHS "${MLA_ARCHS}")

View File

@ -0,0 +1,364 @@
/***************************************************************************************************
* Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: BSD-3-Clause
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
*
* 1. Redistributions of source code must retain the above copyright notice,
*this list of conditions and the following disclaimer.
*
* 2. Redistributions in binary form must reproduce the above copyright notice,
* this list of conditions and the following disclaimer in the documentation
* and/or other materials provided with the distribution.
*
* 3. Neither the name of the copyright holder nor the names of its
* contributors may be used to endorse or promote products derived from
* this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
*ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE
*LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
*CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
*SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
*INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
*CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
*ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
*POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
/*
* Taken from SGLANG PR https://github.com/sgl-project/sglang/pull/6929
* by Alcanderian JieXin Liang
*/
/*!
\file
\brief An universal device layer for cutlass 3.x-style kernels.
*/
// clang-format off
#pragma once
// common
#include "cutlass/cutlass.h"
#include "cutlass/device_kernel.h"
#if !defined(__CUDACC_RTC__)
#include "cutlass/cluster_launch.hpp"
#include "cutlass/trace.h"
#endif // !defined(__CUDACC_RTC__)
#include "../kernel/sm100_fmha_mla_tma_warpspecialized.hpp"
#include "../kernel/sm100_fmha_mla_reduction.hpp"
////////////////////////////////////////////////////////////////////////////////
namespace cutlass::fmha::device {
using namespace cute;
using namespace cutlass::fmha::kernel;
////////////////////////////////////////////////////////////////////////////////
////////////////////////////// CUTLASS 3.x API /////////////////////////////////
////////////////////////////////////////////////////////////////////////////////
template<
class Kernel_
>
class MLA {
public:
using Kernel = Kernel_;
using ReductionKernel = cutlass::fmha::kernel::Sm100FmhaMlaReductionKernel<
typename Kernel::ElementOut,
typename Kernel::ElementAcc,
typename Kernel::ElementAcc,
Kernel::TileShapeH::value,
Kernel::TileShapeL::value,
256 /*Max split*/
>;
/// Argument structure: User API
using KernelArguments = typename Kernel::Arguments;
using ReductionArguments = typename ReductionKernel::Arguments;
using Arguments = KernelArguments;
/// Argument structure: Kernel API
using KernelParams = typename Kernel::Params;
using ReductionParams = typename ReductionKernel::Params;
struct Params {
KernelParams fmha_params;
ReductionParams reduction_params;
};
private:
/// Kernel API parameters object
Params params_;
bool is_initialized(bool set = false) {
static bool initialized = false;
if (set) initialized = true;
return initialized;
}
static ReductionArguments to_reduction_args(Arguments const& args) {
auto [H, K, D, B] = args.problem_shape;
return ReductionArguments{
nullptr, args.epilogue.ptr_o, nullptr, args.epilogue.ptr_lse,
args.mainloop.softmax_scale, B, args.split_kv, K, args.mainloop.ptr_seq,
args.ptr_split_kv, Kernel::TileShapeS::value
};
}
public:
/// Access the Params structure
Params const& params() const {
return params_;
}
static void set_split_kv (KernelArguments& args) {
if (args.split_kv >= 1) return;
auto [H, K, D, B] = args.problem_shape;
int sm_count = args.hw_info.sm_count;
int max_splits = ceil_div(K, 128);
int sms_per_batch = max(1, sm_count / B);
int split_heur = min(max_splits, sms_per_batch);
int waves = ceil_div(B * split_heur, sm_count);
int k_waves = ceil_div(max_splits, split_heur);
int split_wave_aware = ceil_div(max_splits, k_waves);
args.split_kv = split_wave_aware;
}
/// Determines whether the GEMM can execute the given problem.
static Status
can_implement(Arguments const& args) {
if (! Kernel::can_implement(args)) {
return Status::kInvalid;
}
if (! ReductionKernel::can_implement(to_reduction_args(args))) {
return Status::kInvalid;
}
return Status::kSuccess;
}
/// Gets the workspace size
static size_t
get_workspace_size(Arguments const& args) {
size_t workspace_bytes = 0;
workspace_bytes += Kernel::get_workspace_size(args);
workspace_bytes += ReductionKernel::get_workspace_size(to_reduction_args(args));
return workspace_bytes;
}
/// Computes the maximum number of active blocks per multiprocessor
static int maximum_active_blocks(int /* smem_capacity */ = -1) {
CUTLASS_TRACE_HOST("MLA::maximum_active_blocks()");
int max_active_blocks = -1;
int smem_size = Kernel::SharedStorageSize;
// first, account for dynamic smem capacity if needed
cudaError_t result;
if (smem_size >= (48 << 10)) {
CUTLASS_TRACE_HOST(" Setting smem size to " << smem_size);
result = cudaFuncSetAttribute(
device_kernel<Kernel>,
cudaFuncAttributeMaxDynamicSharedMemorySize,
smem_size);
if (cudaSuccess != result) {
result = cudaGetLastError(); // to clear the error bit
CUTLASS_TRACE_HOST(
" cudaFuncSetAttribute() returned error: "
<< cudaGetErrorString(result));
return -1;
}
}
// query occupancy after setting smem size
result = cudaOccupancyMaxActiveBlocksPerMultiprocessor(
&max_active_blocks,
device_kernel<Kernel>,
Kernel::MaxThreadsPerBlock,
smem_size);
if (cudaSuccess != result) {
result = cudaGetLastError(); // to clear the error bit
CUTLASS_TRACE_HOST(
" cudaOccupancyMaxActiveBlocksPerMultiprocessor() returned error: "
<< cudaGetErrorString(result));
return -1;
}
CUTLASS_TRACE_HOST(" max_active_blocks: " << max_active_blocks);
return max_active_blocks;
}
/// Initializes GEMM state from arguments.
Status
initialize(Arguments const& args, void* workspace = nullptr, cudaStream_t stream = nullptr) {
CUTLASS_TRACE_HOST("MLA::initialize() - workspace "
<< workspace << ", stream: " << (stream ? "non-null" : "null"));
// Initialize the workspace
Status status = Kernel::initialize_workspace(args, workspace, stream);
if (status != Status::kSuccess) {
return status;
}
status = ReductionKernel::initialize_workspace(to_reduction_args(args), workspace, stream);
if (status != Status::kSuccess) {
return status;
}
KernelParams kernel_params = Kernel::to_underlying_arguments(args, workspace);
ReductionArguments reduction_args = to_reduction_args(args);
if (reduction_args.split_kv > 1) {
reduction_args.ptr_oaccum = kernel_params.epilogue.ptr_o_acc;
reduction_args.ptr_lseaccum = kernel_params.epilogue.ptr_lse_acc;
}
ReductionParams reduction_params = ReductionKernel::to_underlying_arguments(reduction_args, workspace);
// Initialize the Params structure
params_ = Params {kernel_params, reduction_params};
if (is_initialized()) return Status::kSuccess;
// account for dynamic smem capacity if needed
// no dynamic smem is needed for reduction kernel
int smem_size = Kernel::SharedStorageSize;
if (smem_size >= (48 << 10)) {
CUTLASS_TRACE_HOST(" Setting smem size to " << smem_size);
cudaError_t result = cudaFuncSetAttribute(
device_kernel<Kernel>,
cudaFuncAttributeMaxDynamicSharedMemorySize,
smem_size);
if (cudaSuccess != result) {
result = cudaGetLastError(); // to clear the error bit
CUTLASS_TRACE_HOST(" cudaFuncSetAttribute() returned error: " << cudaGetErrorString(result));
return Status::kErrorInternal;
}
}
is_initialized(true);
return Status::kSuccess;
}
/// Update API is preserved in 3.0, but does not guarantee a lightweight update of params.
Status
update(Arguments const& args, void* workspace = nullptr) {
CUTLASS_TRACE_HOST("MLA()::update() - workspace: " << workspace);
size_t workspace_bytes = get_workspace_size(args);
if (workspace_bytes > 0 && nullptr == workspace) {
return Status::kErrorWorkspaceNull;
}
auto fmha_params = Kernel::to_underlying_arguments(args, workspace);
ReductionArguments reduction_args = to_reduction_args(args);
if (reduction_args.split_kv > 1) {
reduction_args.ptr_oaccum = fmha_params.epilogue.ptr_o_acc;
reduction_args.ptr_lseaccum = fmha_params.epilogue.ptr_lse_acc;
}
ReductionParams reduction_params = ReductionKernel::to_underlying_arguments(reduction_args, workspace);
// Initialize the Params structure
params_ = Params {fmha_params, reduction_params};
return Status::kSuccess;
}
/// Primary run() entry point API that is static allowing users to create and manage their own params.
/// Supplied params struct must be construct by calling Kernel::to_underling_arguments()
static Status
run(Params& params, cudaStream_t stream = nullptr) {
CUTLASS_TRACE_HOST("MLA::run()");
dim3 const block = Kernel::get_block_shape();
dim3 const grid = Kernel::get_grid_shape(params.fmha_params);
// configure smem size and carveout
int smem_size = Kernel::SharedStorageSize;
Status launch_result;
// Use extended launch API only for mainloops that use it
if constexpr(Kernel::ArchTag::kMinComputeCapability >= 90) {
dim3 cluster(cute::size<0>(typename Kernel::ClusterShape{}),
cute::size<1>(typename Kernel::ClusterShape{}),
cute::size<2>(typename Kernel::ClusterShape{}));
void const* kernel = (void const*) device_kernel<Kernel>;
void* kernel_params[] = {&params.fmha_params};
launch_result = ClusterLauncher::launch(grid, cluster, block, smem_size, stream, kernel, kernel_params);
}
else {
launch_result = Status::kSuccess;
device_kernel<Kernel><<<grid, block, smem_size, stream>>>(params.fmha_params);
}
cudaError_t result = cudaGetLastError();
if (cudaSuccess != result or Status::kSuccess != launch_result) {
//return Status::kSuccess;
CUTLASS_TRACE_HOST(" Kernel launch failed. Reason: " << result);
return Status::kErrorInternal;
}
if (params.reduction_params.split_kv > 1) {
// launch reduction kernel
dim3 const block = ReductionKernel::get_block_shape();
dim3 const grid = ReductionKernel::get_grid_shape(params.reduction_params);
device_kernel<ReductionKernel><<<grid, block, 0, stream>>>(params.reduction_params);
cudaError_t result = cudaGetLastError();
if (cudaSuccess == result) {
return Status::kSuccess;
}
else {
CUTLASS_TRACE_HOST(" Kernel launch failed. Reason: " << result);
return Status::kErrorInternal;
}
}
else {
return Status::kSuccess;
}
}
//
// Non-static launch overloads that first create and set the internal params struct of this kernel handle.
//
/// Launches the kernel after first constructing Params internal state from supplied arguments.
Status
run(Arguments const& args, void* workspace = nullptr, cudaStream_t stream = nullptr) {
Status status = initialize(args, workspace, stream);
if (Status::kSuccess == status) {
status = run(params_, stream);
}
return status;
}
/// Launches the kernel after first constructing Params internal state from supplied arguments.
Status
operator()(Arguments const& args, void* workspace = nullptr, cudaStream_t stream = nullptr) {
return run(args, workspace, stream);
}
/// Overload that allows a user to re-launch the same kernel without updating internal params struct.
Status
run(cudaStream_t stream = nullptr) {
return run(params_, stream);
}
/// Overload that allows a user to re-launch the same kernel without updating internal params struct.
Status
operator()(cudaStream_t stream = nullptr) {
return run(params_, stream);
}
};
////////////////////////////////////////////////////////////////////////////////
} // namespace cutlass::fmha::device
////////////////////////////////////////////////////////////////////////////////

View File

@ -0,0 +1,203 @@
/***************************************************************************************************
* Copyright (c) 2024 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights
*reserved. SPDX-License-Identifier: BSD-3-Clause
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
*
* 1. Redistributions of source code must retain the above copyright notice,
*this list of conditions and the following disclaimer.
*
* 2. Redistributions in binary form must reproduce the above copyright notice,
* this list of conditions and the following disclaimer in the documentation
* and/or other materials provided with the distribution.
*
* 3. Neither the name of the copyright holder nor the names of its
* contributors may be used to endorse or promote products derived from
* this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
*ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE
*LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
*CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
*SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
*INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
*CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
*ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
*POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
/*
* Taken from SGLANG PR https://github.com/sgl-project/sglang/pull/6929
* by Alcanderian JieXin Liang
*/
// clang-format off
#pragma once
#include "cutlass/cutlass.h"
#include "cutlass/arch/arch.h"
#include "cute/tensor.hpp"
namespace cutlass::fmha::kernel {
using namespace cute;
template<
class ElementOut,
class ElementAcc,
class ElementScale,
size_t kNumHeads,
size_t kHeadDimLatent,
int kMaxSplits
>
struct Sm100FmhaMlaReductionKernel {
static const int SharedStorageSize = 0;
static const int MaxThreadsPerBlock = 128;
static const int MinBlocksPerMultiprocessor = 1;
using ArchTag = cutlass::arch::Sm100;
static_assert(kHeadDimLatent % MaxThreadsPerBlock == 0);
struct Arguments {
ElementAcc* ptr_oaccum = nullptr;
ElementOut* ptr_o = nullptr;
ElementAcc* ptr_lseaccum = nullptr;
ElementAcc* ptr_lse = nullptr;
ElementScale scale = 1.f;
int num_batches = 0;
int split_kv = -1;
int dim_k = -1;
int* ptr_seq = nullptr;
int* ptr_split_kv = nullptr;
int tile_shape_s = 128;
};
using Params = Arguments;
static Params to_underlying_arguments(Arguments const& args, void* workspace) {
return {args.ptr_oaccum, args.ptr_o, args.ptr_lseaccum, args.ptr_lse,
args.scale, args.num_batches, args.split_kv, args.dim_k, args.ptr_seq,
args.ptr_split_kv, args.tile_shape_s};
}
static size_t get_workspace_size(Arguments const& /*args*/) {
return 0;
}
static Status initialize_workspace(
Arguments const& /*args*/, void* /*ws*/, cudaStream_t /*stream*/) {
return Status::kSuccess;
}
static dim3 get_grid_shape(Params const& params) {
return dim3(kNumHeads, 1, params.num_batches);
}
static dim3 get_block_shape() {
return dim3(MaxThreadsPerBlock, 1, 1);
}
static bool can_implement(Arguments const& args) {
if (args.num_batches <= 0) return false;
if (args.split_kv <= 0) return false;
return true;
}
CUTLASS_DEVICE void operator() (Params const& params, char* smem_raw) {
if (params.split_kv <= 1) return;
auto blk_coord = make_coord(blockIdx.x, _0{}, blockIdx.z);
__shared__ ElementAcc sLseScale[kMaxSplits];
const size_t offset_lseaccum = get<0>(blk_coord) + kNumHeads * params.split_kv * get<2>(blk_coord);
const size_t offset_lse = get<0>(blk_coord) + kNumHeads * get<2>(blk_coord);
Tensor gLSEaccum = make_tensor(make_gmem_ptr(params.ptr_lseaccum + offset_lseaccum),
make_shape(params.split_kv), Stride<Int<kNumHeads>>{});
Tensor gLSE = make_tensor(make_gmem_ptr(params.ptr_lse + offset_lse),
Shape<_1>{}, Stride<_1>{});
auto dim_k = params.ptr_seq == nullptr ? params.dim_k : params.ptr_seq[get<2>(blk_coord)];
auto local_split_kv = params.ptr_split_kv == nullptr ? params.split_kv : params.ptr_split_kv[get<2>(blk_coord)];
auto k_tile_total = ceil_div(dim_k, params.tile_shape_s);
auto k_tile_per_cta = ceil_div(k_tile_total, local_split_kv);
local_split_kv = ceil_div(k_tile_total, k_tile_per_cta);
int warp_idx = cutlass::canonical_warp_idx_sync();
if (warp_idx == 0) {
constexpr int kNLsePerThread = cute::ceil_div(kMaxSplits, 32);
ElementAcc local_lse[kNLsePerThread];
CUTLASS_PRAGMA_UNROLL
for (int i = 0; i < kNLsePerThread; ++i) {
const int split = i * 32 + threadIdx.x;
local_lse[i] = split < local_split_kv ? gLSEaccum(split) : -std::numeric_limits<ElementAcc>::infinity();
}
ElementAcc lse_max = -std::numeric_limits<ElementAcc>::infinity();
CUTLASS_PRAGMA_UNROLL
for (int i = 0; i < kNLsePerThread; ++i) {
lse_max = max(lse_max, local_lse[i]);
}
CUTLASS_PRAGMA_UNROLL
for (int offset = 16; offset >= 1; offset /= 2) {
lse_max = max(lse_max, __shfl_xor_sync(0xffffffff, lse_max, offset));
}
lse_max = lse_max == -std::numeric_limits<ElementAcc>::infinity() ? 0.0f : lse_max; // In case all local LSEs are -inf
lse_max = __shfl_sync(0xffffffff, lse_max, 0);
ElementAcc sum_lse = 0;
CUTLASS_PRAGMA_UNROLL
for (int i = 0; i < kNLsePerThread; ++i) {
sum_lse = sum_lse + expf(local_lse[i] - lse_max);
}
CUTLASS_PRAGMA_UNROLL
for (int offset = 16; offset >= 1; offset /= 2) {
sum_lse = sum_lse + __shfl_xor_sync(0xffffffff, sum_lse, offset);
}
sum_lse = __shfl_sync(0xffffffff, sum_lse, 0);
ElementAcc global_lse = (sum_lse == 0.f || sum_lse != sum_lse) ? std::numeric_limits<ElementAcc>::infinity() : logf(sum_lse) + lse_max;
if (threadIdx.x == 0 and params.ptr_lse != nullptr) {
gLSE(0) = global_lse;
}
CUTLASS_PRAGMA_UNROLL
for (int i = 0; i < kNLsePerThread; ++i) {
const int split = i * 32 + threadIdx.x;
if (split < local_split_kv) {
sLseScale[split] = expf(local_lse[i] - global_lse);
}
}
}
__syncthreads();
constexpr int Elements = kHeadDimLatent / MaxThreadsPerBlock;
const size_t offset_oaccum = kHeadDimLatent * params.split_kv * (get<0>(blk_coord) + kNumHeads * get<2>(blk_coord));
Tensor gOaccum = make_tensor(make_gmem_ptr(params.ptr_oaccum + offset_oaccum),
Shape<Int<kHeadDimLatent>>{}, Stride<_1>{});
ElementAcc local_val[Elements] = {0};
for (int split = 0; split < local_split_kv; ++split) {
ElementAcc lse_scale = sLseScale[split];
CUTLASS_PRAGMA_UNROLL
for(int i = 0; i < Elements; ++i) {
local_val[i] += lse_scale * gOaccum(threadIdx.x + MaxThreadsPerBlock * i);
}
gOaccum.data() = gOaccum.data() + kHeadDimLatent;
}
auto ptr_o_local = params.ptr_o + (get<0>(blk_coord) + get<2>(blk_coord) * kNumHeads) * kHeadDimLatent;
Tensor gO = make_tensor(make_gmem_ptr(ptr_o_local), Shape<Int<kHeadDimLatent>>{}, Stride<_1>{});
CUTLASS_PRAGMA_UNROLL
for(int i = 0; i < Elements; ++i) {
gO(threadIdx.x + MaxThreadsPerBlock * i) = static_cast<ElementOut>(local_val[i]);
}
}
};
} // namespace cutlass::fmha::kernel

View File

@ -0,0 +1,165 @@
/***************************************************************************************************
* Copyright (c) 2024 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights
*reserved. SPDX-License-Identifier: BSD-3-Clause
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
*
* 1. Redistributions of source code must retain the above copyright notice,
*this list of conditions and the following disclaimer.
*
* 2. Redistributions in binary form must reproduce the above copyright notice,
* this list of conditions and the following disclaimer in the documentation
* and/or other materials provided with the distribution.
*
* 3. Neither the name of the copyright holder nor the names of its
* contributors may be used to endorse or promote products derived from
* this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
*ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE
*LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
*CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
*SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
*INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
*CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
*ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
*POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
/*
* Taken from SGLANG PR https://github.com/sgl-project/sglang/pull/6929
* by Alcanderian JieXin Liang
*/
// clang-format off
#pragma once
#include "cutlass/cutlass.h"
#include "cutlass/fast_math.h"
#include "cutlass/kernel_hardware_info.h"
namespace cutlass::fmha::kernel {
////////////////////////////////////////////////////////////////////////////////
struct Sm100MlaIndividualTileScheduler {
struct Params {
dim3 grid;
};
bool valid_ = true;
CUTLASS_DEVICE
Sm100MlaIndividualTileScheduler(Params const&) {}
template<class ProblemShape, class ClusterShape>
static Params to_underlying_arguments(
ProblemShape const& problem_shape, KernelHardwareInfo hw_info,
ClusterShape const& cluster_shape, int const& split_kv) {
using namespace cute;
dim3 grid(get<0>(cluster_shape), get<3>(problem_shape) /* Batch */, split_kv /*Maximum Split KV*/);
return Params{ grid };
}
static dim3 get_grid_shape(Params const& params) {
return params.grid;
}
CUTLASS_DEVICE
bool is_valid() {
return valid_;
}
CUTLASS_DEVICE
auto get_block_coord() {
using namespace cute;
return make_coord(blockIdx.x, _0{}, blockIdx.y, blockIdx.z);
}
CUTLASS_DEVICE
Sm100MlaIndividualTileScheduler& operator++() {
valid_ = false;
return *this;
}
};
////////////////////////////////////////////////////////////////////////////////
struct Sm100MlaPersistentTileScheduler {
struct Params {
int num_blocks;
FastDivmod divmod_m_block;
FastDivmod divmod_b;
FastDivmod divmod_split_kv;
KernelHardwareInfo hw_info;
};
int block_idx = 0;
Params params;
CUTLASS_DEVICE
Sm100MlaPersistentTileScheduler(Params const& params) : block_idx(blockIdx.x), params(params) {}
template<class ProblemShape, class ClusterShape>
static Params to_underlying_arguments(
ProblemShape const& problem_shape, KernelHardwareInfo hw_info,
ClusterShape const& cluster_shape, int const& split_kv) {
using namespace cute;
// Get SM count if needed, otherwise use user supplied SM count
int sm_count = hw_info.sm_count;
if (sm_count <= 1 || sm_count % size<0>(cluster_shape) != 0) {
CUTLASS_TRACE_HOST(" WARNING: Arguments do not include a valid SM count.\n"
" For optimal performance, populate the arguments KernelHardwareInfo struct with the SM count.");
sm_count = KernelHardwareInfo::query_device_multiprocessor_count(hw_info.device_id);
}
CUTLASS_TRACE_HOST("to_underlying_arguments(): Setting persistent grid SM count to " << sm_count);
hw_info.sm_count = sm_count;
int num_m_blocks = size<0>(cluster_shape);
int num_blocks = num_m_blocks * get<3>(problem_shape) /* Batch */;
num_blocks *= split_kv; /* Maximum Split KV*/
return Params {
num_blocks,
{ num_m_blocks}, { get<3>(problem_shape) }, {split_kv},
hw_info
};
}
static dim3 get_grid_shape(Params const& params) {
dim3 grid(std::min(params.num_blocks, params.hw_info.sm_count), 1, 1);
return grid;
}
CUTLASS_DEVICE
bool is_valid() {
return block_idx < params.num_blocks;
}
CUTLASS_DEVICE
auto get_block_coord() {
using namespace cute;
int block_decode = block_idx;
int m_block, bidb, n_split_kv;
params.divmod_m_block(block_decode, m_block, block_decode);
params.divmod_b(block_decode, bidb, block_decode);
params.divmod_split_kv(block_decode, n_split_kv, block_decode);
return make_coord(m_block, _0{}, bidb, n_split_kv);
}
CUTLASS_DEVICE
Sm100MlaPersistentTileScheduler& operator++() {
block_idx += gridDim.x;
return *this;
}
};
////////////////////////////////////////////////////////////////////////////////
} // namespace cutlass::fmha::kernel

View File

@ -0,0 +1,273 @@
/*
Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
Copyright 2025 SGLang Team. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
/*
* Taken from SGLANG PR https://github.com/sgl-project/sglang/pull/6929
* by Alcanderian JieXin Liang
*/
#include <ATen/cuda/CUDAContext.h>
#include <c10/cuda/CUDAGuard.h>
#include <cutlass/cutlass.h>
#include <cutlass/kernel_hardware_info.h>
#include <torch/all.h>
#include <cute/tensor.hpp>
#include <iostream>
#include "cutlass_sm100_mla/device/sm100_mla.hpp"
#include "cutlass_sm100_mla/kernel/sm100_mla_tile_scheduler.hpp"
// clang-format off
#if !defined(CUDA_VERSION) || CUDA_VERSION < 12040
void sm100_cutlass_mla_decode(
torch::Tensor const& out,
torch::Tensor const& q_nope,
torch::Tensor const& q_pe,
torch::Tensor const& kv_c_and_k_pe_cache,
torch::Tensor const& seq_lens,
torch::Tensor const& page_table,
torch::Tensor const& workspace,
int64_t num_kv_splits) {
TORCH_CHECK(false, "CUDA version must be >= 12.4 for cutlass_mla_decode");
}
int64_t sm100_cutlass_mla_get_workspace_size(int64_t max_seq_len, int64_t num_batches, int64_t sm_count, int64_t num_kv_splits) {
TORCH_CHECK(false, "CUDA version must be >= 12.4 for cutlass_mla_get_workspace_size");
}
#else
#define CUTLASS_CHECK(status) \
{ \
cutlass::Status error = status; \
TORCH_CHECK(error == cutlass::Status::kSuccess, cutlassGetStatusString(error)); \
}
using namespace cute;
using namespace cutlass::fmha::kernel;
template <bool v>
struct IsPersistent {
static const bool value = v;
};
template <typename T, bool IsPaged128, typename PersistenceOption = IsPersistent<true>>
struct MlaSm100 {
using Element = T;
using ElementAcc = float;
using ElementOut = T;
using TileShape = Shape<_128, _128, Shape<_512, _64>>;
using TileShapeH = cute::tuple_element_t<0, TileShape>;
using TileShapeD = cute::tuple_element_t<2, TileShape>;
// H K (D_latent D_rope) B
using ProblemShape = cute::tuple<TileShapeH, int, TileShapeD, int>;
using StrideQ = cute::tuple<int64_t, _1, int64_t>; // H D B
using StrideK = cute::tuple<int64_t, _1, int64_t>; // K D B
using StrideO = StrideK; // H D B
using StrideLSE = cute::tuple<_1, int>; // H B
using TileScheduler =
std::conditional_t<PersistenceOption::value, Sm100MlaPersistentTileScheduler, Sm100MlaIndividualTileScheduler>;
using FmhaKernel = cutlass::fmha::kernel::Sm100FmhaMlaKernelTmaWarpspecialized<
TileShape,
Element,
ElementAcc,
ElementOut,
ElementAcc,
TileScheduler,
/*kIsCpAsync=*/!IsPaged128>;
using Fmha = cutlass::fmha::device::MLA<FmhaKernel>;
};
template <typename T>
typename T::Fmha::Arguments args_from_options(
at::Tensor const& out,
at::Tensor const& q_nope,
at::Tensor const& q_pe,
at::Tensor const& kv_c_and_k_pe_cache,
at::Tensor const& seq_lens,
at::Tensor const& page_table,
double sm_scale,
int64_t num_kv_splits) {
cutlass::KernelHardwareInfo hw_info;
hw_info.device_id = q_nope.device().index();
hw_info.sm_count = cutlass::KernelHardwareInfo::query_device_multiprocessor_count(hw_info.device_id);
int batches = q_nope.sizes()[0];
int page_count_per_seq = page_table.sizes()[1];
int page_count_total = kv_c_and_k_pe_cache.sizes()[0];
int page_size = kv_c_and_k_pe_cache.sizes()[1];
int max_seq_len = page_size * page_count_per_seq;
using TileShapeH = typename T::TileShapeH;
using TileShapeD = typename T::TileShapeD;
auto problem_shape = cute::make_tuple(TileShapeH{}, max_seq_len, TileShapeD{}, batches);
auto [H, K, D, B] = problem_shape;
auto [D_latent, D_rope] = D;
float scale = float(sm_scale);
using StrideQ = typename T::StrideQ;
using StrideK = typename T::StrideK;
using StrideO = typename T::StrideO;
using StrideLSE = typename T::StrideLSE;
StrideQ stride_Q_nope = cute::make_tuple(
static_cast<int64_t>(q_nope.stride(1)), _1{}, static_cast<int64_t>(q_nope.stride(0)));
StrideQ stride_Q_pe = cute::make_tuple(
static_cast<int64_t>(q_pe.stride(1)), _1{}, static_cast<int64_t>(q_pe.stride(0)));
StrideK stride_C = cute::make_tuple(
static_cast<int64_t>(0 + D_latent + D_rope), _1{}, static_cast<int64_t>(page_size * (D_latent + D_rope)));
StrideLSE stride_PT = cute::make_stride(_1{}, page_count_per_seq);
StrideLSE stride_LSE = cute::make_tuple(_1{}, 0 + H);
StrideO stride_O = cute::make_tuple(static_cast<int64_t>(0 + D_latent), _1{}, static_cast<int64_t>(0 + H * D_latent));
using Element = typename T::Element;
using ElementOut = typename T::ElementOut;
using ElementAcc = typename T::ElementAcc;
auto Q_nope_ptr = static_cast<Element*>(q_nope.data_ptr());
auto Q_pe_ptr = static_cast<Element*>(q_pe.data_ptr());
auto C_ptr = static_cast<Element*>(kv_c_and_k_pe_cache.data_ptr());
typename T::Fmha::Arguments arguments{
problem_shape,
{scale,
Q_nope_ptr,
stride_Q_nope,
Q_pe_ptr,
stride_Q_pe,
C_ptr,
stride_C,
C_ptr + D_latent,
stride_C,
static_cast<int*>(seq_lens.data_ptr()),
static_cast<int*>(page_table.data_ptr()),
stride_PT,
page_count_total,
page_size},
{static_cast<ElementOut*>(out.data_ptr()), stride_O, static_cast<ElementAcc*>(nullptr), stride_LSE},
hw_info,
// TODO(trevor-m): Change split_kv back to -1 when
// https://github.com/NVIDIA/cutlass/issues/2274 is fixed. Split_kv=1 will
// perform worse with larger context length and smaller batch sizes.
num_kv_splits, // split_kv
nullptr, // is_var_split_kv
};
// TODO(kaixih@nvidia): When split_kv=-1 and is_var_split_kv=false, we compute
// split_kv automatically based on batch size and sequence length to balance
// workload across available SMs. Consider using var_split_kv for manual
// control if needed.
T::Fmha::set_split_kv(arguments);
return arguments;
}
template <typename Element, bool IsPaged128, typename PersistenceOption>
void runMla(
at::Tensor const& out,
at::Tensor const& q_nope,
at::Tensor const& q_pe,
at::Tensor const& kv_c_and_k_pe_cache,
at::Tensor const& seq_lens,
at::Tensor const& page_table,
at::Tensor const& workspace,
double sm_scale,
int64_t num_kv_splits,
cudaStream_t stream) {
using MlaSm100Type = MlaSm100<Element, IsPaged128, PersistenceOption>;
typename MlaSm100Type::Fmha fmha;
auto arguments = args_from_options<MlaSm100Type>(out, q_nope, q_pe, kv_c_and_k_pe_cache, seq_lens, page_table, sm_scale, num_kv_splits);
CUTLASS_CHECK(fmha.can_implement(arguments));
CUTLASS_CHECK(fmha.initialize(arguments, workspace.data_ptr(), stream));
CUTLASS_CHECK(fmha.run(arguments, workspace.data_ptr(), stream));
}
#define DISPATCH_BOOL(expr, const_expr, ...) \
[&]() -> bool { \
if (expr) { \
constexpr bool const_expr = true; \
return __VA_ARGS__(); \
} else { \
constexpr bool const_expr = false; \
return __VA_ARGS__(); \
} \
}()
void sm100_cutlass_mla_decode(
torch::Tensor const& out,
torch::Tensor const& q_nope,
torch::Tensor const& q_pe,
torch::Tensor const& kv_c_and_k_pe_cache,
torch::Tensor const& seq_lens,
torch::Tensor const& page_table,
torch::Tensor const& workspace,
double sm_scale,
int64_t num_kv_splits) {
auto in_dtype = q_nope.dtype();
at::cuda::CUDAGuard device_guard{(char)q_nope.get_device()};
const cudaStream_t stream = at::cuda::getCurrentCUDAStream(q_nope.get_device());
const int page_size = kv_c_and_k_pe_cache.sizes()[1];
// NOTE(alcanderian): IsPersistent has bug with manual split_kv.
// Kernel will hang if batch is too large with large num_kv_splits. (for example bs=8, num_kv_splits=8)
// Maybe per batch split kv will fix this.
DISPATCH_BOOL(page_size == 128, IsPaged128, [&] {
DISPATCH_BOOL(num_kv_splits <= 1, NotManualSplitKV, [&] {
if (in_dtype == at::ScalarType::Half) {
runMla<cutlass::half_t, IsPaged128, IsPersistent<NotManualSplitKV>>(
out, q_nope, q_pe, kv_c_and_k_pe_cache, seq_lens, page_table, workspace, sm_scale, num_kv_splits, stream);
} else if (in_dtype == at::ScalarType::BFloat16) {
runMla<cutlass::bfloat16_t, IsPaged128, IsPersistent<NotManualSplitKV>>(
out, q_nope, q_pe, kv_c_and_k_pe_cache, seq_lens, page_table, workspace, sm_scale, num_kv_splits, stream);
} else if (in_dtype == at::ScalarType::Float8_e4m3fn) {
runMla<cutlass::float_e4m3_t, IsPaged128, IsPersistent<NotManualSplitKV>>(
out, q_nope, q_pe, kv_c_and_k_pe_cache, seq_lens, page_table, workspace, sm_scale, num_kv_splits, stream);
} else {
TORCH_CHECK(false, "Unsupported input data type of MLA");
}
return true;
});
return true;
});
}
int64_t sm100_cutlass_mla_get_workspace_size(int64_t max_seq_len, int64_t num_batches, int64_t sm_count, int64_t num_kv_splits) {
// Workspace size depends on ElementAcc and ElementLSE (same as ElementAcc)
// which are float, so Element type here doesn't matter.
using MlaSm100Type = MlaSm100<cutlass::half_t, true>;
// Get split kv. Requires problem shape and sm_count only.
typename MlaSm100Type::Fmha::Arguments arguments;
using TileShapeH = typename MlaSm100Type::TileShapeH;
using TileShapeD = typename MlaSm100Type::TileShapeD;
arguments.problem_shape =
cute::make_tuple(TileShapeH{}, static_cast<int>(max_seq_len), TileShapeD{}, static_cast<int>(num_batches));
// Assumes device 0 when getting sm_count.
arguments.hw_info.sm_count =
sm_count <= 0 ? cutlass::KernelHardwareInfo::query_device_multiprocessor_count(/*device_id=*/0) : sm_count;
arguments.split_kv = num_kv_splits;
MlaSm100Type::Fmha::set_split_kv(arguments);
return MlaSm100Type::Fmha::get_workspace_size(arguments);
}
#endif
// clang-format on

View File

@ -167,6 +167,19 @@ void cutlass_mla_decode(torch::Tensor const& out, torch::Tensor const& q_nope,
torch::Tensor const& seq_lens,
torch::Tensor const& page_table, double scale);
void sm100_cutlass_mla_decode(
torch::Tensor const& out, torch::Tensor const& q_nope,
torch::Tensor const& q_pe, torch::Tensor const& kv_c_and_k_pe_cache,
torch::Tensor const& seq_lens, torch::Tensor const& page_table,
torch::Tensor const& workspace, double sm_scale,
int64_t num_kv_splits =
1 /* Set to 1 to avoid cuda_graph issue by default. */);
int64_t sm100_cutlass_mla_get_workspace_size(
int64_t max_seq_len, int64_t num_batches, int64_t sm_count = 0,
int64_t num_kv_splits =
1 /* Set to 1 to avoid cuda_graph issue by default. */);
torch::Tensor get_cuda_view_from_cpu_tensor(torch::Tensor& cpu_tensor);
#ifndef USE_ROCM

View File

@ -514,6 +514,23 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
" Tensor page_table, float scale) -> ()");
ops.impl("cutlass_mla_decode", torch::kCUDA, &cutlass_mla_decode);
// SM100 CUTLASS MLA decode
ops.def(
"sm100_cutlass_mla_decode(Tensor! out, Tensor q_nope, Tensor q_pe,"
" Tensor kv_c_and_k_pe_cache, Tensor seq_lens,"
" Tensor page_table, Tensor workspace, float "
"scale,"
" int num_kv_splits) -> ()");
ops.impl("sm100_cutlass_mla_decode", torch::kCUDA, &sm100_cutlass_mla_decode);
// SM100 CUTLASS MLA workspace
ops.def(
"sm100_cutlass_mla_get_workspace_size(int max_seq_len, int num_batches,"
" int sm_count, int num_kv_splits) "
"-> int");
ops.impl("sm100_cutlass_mla_get_workspace_size",
&sm100_cutlass_mla_get_workspace_size);
// Compute NVFP4 block quantized tensor.
ops.def(
"scaled_fp4_quant(Tensor! output, Tensor input,"

View File

@ -1843,6 +1843,26 @@ def cutlass_mla_decode(out: torch.Tensor, q_nope: torch.Tensor,
return out
def sm100_cutlass_mla_decode(out: torch.Tensor, q_nope: torch.Tensor,
q_pe: torch.Tensor,
kv_c_and_k_pe_cache: torch.Tensor,
seq_lens: torch.Tensor, page_table: torch.Tensor,
workspace: torch.Tensor, scale: float,
num_kv_splits: int) -> torch.Tensor:
torch.ops._C.sm100_cutlass_mla_decode(out, q_nope, q_pe,
kv_c_and_k_pe_cache, seq_lens,
page_table, workspace, scale,
num_kv_splits)
return out
def sm100_cutlass_mla_get_workspace_size(max_seq_len: int, num_batches: int,
sm_count: int,
num_kv_splits: int) -> int:
return torch.ops._C.sm100_cutlass_mla_get_workspace_size(
max_seq_len, num_batches, sm_count, num_kv_splits)
if hasattr(torch.ops._C, "weight_packed_linear"):
@register_fake("_C::weight_packed_linear")

View File

@ -1502,6 +1502,7 @@ class EngineArgs:
"TRITON_ATTN_VLLM_V1",
"TRITON_MLA",
"CUTLASS_MLA_VLLM_V1",
"SM100_CUTLASS_MLA_VLLM_V1",
"FLASHMLA",
"FLASHINFER",
"FLASHINFER_VLLM_V1",

View File

@ -166,6 +166,13 @@ class CudaPlatformBase(Platform):
logger.info(
"Forcing kv cache block size to 64 for FlashMLA backend.")
use_sm100_cutlass = (envs.VLLM_ATTENTION_BACKEND is not None \
and envs.VLLM_ATTENTION_BACKEND == "SM100_CUTLASS_MLA_VLLM_V1")
if use_sm100_cutlass and cache_config.block_size != 128:
cache_config.block_size = 128
logger.info("Forcing kv cache block size to 128 for "
"SM100_CUTLASS_MLA_VLLM_V1 backend.")
compilation_config = vllm_config.compilation_config
if (envs.VLLM_ALL2ALL_BACKEND == "deepep_high_throughput"
and parallel_config.data_parallel_size > 1
@ -196,6 +203,17 @@ class CudaPlatformBase(Platform):
kv_cache_dtype, block_size, use_v1,
use_mla) -> str:
if use_mla:
if selected_backend == _Backend.SM100_CUTLASS_MLA_VLLM_V1:
if use_v1:
logger.info_once(
"Using Sm100Cutlass MLA backend on V1 engine.")
return ("vllm.v1.attention.backends.mla."
"sm100_cutlass_mla.Sm100CutlassMLABackend")
else:
logger.warning(
"Sm100 Cutlass MLA backend is only supported on "
"V1 engine")
# TODO(lucas): refactor to be more concise
# we should probably consider factoring out V1 here
if selected_backend == _Backend.CUTLASS_MLA_VLLM_V1:

View File

@ -54,6 +54,7 @@ class _Backend(enum.Enum):
FLASHMLA_VLLM_V1 = enum.auto()
FLASHMLA = enum.auto() # Supported by V1
CUTLASS_MLA_VLLM_V1 = enum.auto()
SM100_CUTLASS_MLA_VLLM_V1 = enum.auto()
HPU_ATTN = enum.auto()
PALLAS = enum.auto()
PALLAS_VLLM_V1 = enum.auto()

View File

@ -0,0 +1,201 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from typing import Any, Optional
import torch
import vllm._custom_ops as ops
from vllm.attention.backends.abstract import (AttentionType,
is_quantized_kv_cache)
from vllm.logger import init_logger
from vllm.v1.attention.backends.mla.common import (MLACommonBackend,
MLACommonImpl,
MLACommonMetadata)
logger = init_logger(__name__)
class Sm100CutlassMLABackend(MLACommonBackend):
@staticmethod
def get_name() -> str:
return "SM100_CUTLASS_MLA_VLLM_V1"
@staticmethod
def get_impl_cls() -> type["Sm100CutlassMLAImpl"]:
return Sm100CutlassMLAImpl
workspace_size = 1024 * 1024 * 1024 # 1GB
g_workspace = torch.zeros(workspace_size, device="cuda", dtype=torch.uint8)
class Sm100CutlassMLAImpl(MLACommonImpl[MLACommonMetadata]):
def __init__(
self,
num_heads: int,
head_size: int,
scale: float,
num_kv_heads: int,
alibi_slopes: Optional[list[float]],
sliding_window: Optional[int],
kv_cache_dtype: str,
blocksparse_params: Optional[dict[str, Any]],
logits_soft_cap: Optional[float],
attn_type: str,
kv_sharing_target_layer_name: Optional[str],
# MLA Specific Arguments
**mla_args) -> None:
super().__init__(num_heads, head_size, scale, num_kv_heads,
alibi_slopes, sliding_window, kv_cache_dtype,
blocksparse_params, logits_soft_cap, attn_type,
kv_sharing_target_layer_name, **mla_args)
unsupported_features = [
alibi_slopes, sliding_window, blocksparse_params, logits_soft_cap
]
if any(unsupported_features):
raise NotImplementedError(
"CutlassMLAImpl does not support one of the following: "
"alibi_slopes, sliding_window, blocksparse_params, "
"logits_soft_cap")
if attn_type != AttentionType.DECODER:
raise NotImplementedError("Encoder self-attention and "
"encoder/decoder cross-attention "
"are not implemented for "
"CutlassMLAImpl")
if is_quantized_kv_cache(self.kv_cache_dtype):
raise NotImplementedError(
"CutlassMLA V1 with FP8 KV cache not yet supported")
# TODO: Ensure the kernel does not hang.
# If it happens, then set _num_kv_splits to 1
self._num_kv_splits = 8
# This is a workaround to set a workspace for the sm100 cutlass mla.
# We do need to execute sm100_cutlass_mla_get_workspace_size(..)
# since it calls set_kv_split() internally, however, the workspace
# size we use is large and shared over all executions.
# TODO: Refactor this code to a place that can observe params
# more fine-grained
bs = 128 # dummy
PAGE_SIZE = 128 # dummy
max_seqlen_pad = 131072 # dummy
sm_count = 0 # dummy
_ = ops.sm100_cutlass_mla_get_workspace_size(
max_seqlen_pad * PAGE_SIZE,
bs,
sm_count,
num_kv_splits=self._num_kv_splits)
self._workspace = g_workspace
def sm100_cutlass_mla_decode(
self,
q_nope: torch.Tensor,
q_pe: torch.Tensor,
kv_c_and_k_pe_cache: torch.Tensor,
seq_lens: torch.Tensor,
page_table: torch.Tensor,
workspace: torch.Tensor,
sm_scale: float,
num_kv_splits:
int = 1, # Set to 1 to avoid cuda_graph issue by default.
) -> torch.Tensor:
assert (q_nope.ndim == 3
), f"q_nope must be a 3D tensor, but got {q_nope.ndim}"
assert (
q_pe.ndim == 3), f"q_pe must be a 3D tensor, but got {q_pe.ndim}"
assert (
kv_c_and_k_pe_cache.ndim == 3
), "kv_c_and_k_pe_cache must be a 3D tensor, but got {}".format(
kv_c_and_k_pe_cache.ndim)
B_q, H, D_q_nope = q_nope.shape
B_q_2, H_2, D_q_pe = q_pe.shape
assert (B_q == B_q_2) and (H == H_2)
_, PAGE_SIZE, D_ckv = kv_c_and_k_pe_cache.shape
D_latent = 512
D_rope = 64
assert D_q_nope == D_latent
assert D_q_pe == D_rope
assert D_ckv == D_latent + D_rope
MAX_HEADS = 128
assert H <= MAX_HEADS, f"H must be <= {MAX_HEADS}, but got {H}"
if H < MAX_HEADS:
q_nope_padded = q_nope.new_empty((B_q, MAX_HEADS, D_q_nope))
q_nope_padded[:, :H] = q_nope
q_nope = q_nope_padded
q_pe_padded = q_pe.new_empty((B_q, MAX_HEADS, D_q_pe))
q_pe_padded[:, :H] = q_pe
q_pe = q_pe_padded
assert len(page_table.shape) == 2
B_block_table, block_num = page_table.shape
assert B_block_table == B_q
assert (block_num
> 0), f"block num must be greater than 0, got {block_num}"
assert block_num % (128 / PAGE_SIZE) == 0
# TODO(kaixih@nvidia): support fp8
assert q_nope.dtype in (
torch.float16,
torch.bfloat16,
), f"q_nope.dtype needs to be fp16 or bf16 but got {q_nope.dtype}."
assert q_nope.dtype == q_pe.dtype == kv_c_and_k_pe_cache.dtype
assert (
seq_lens.dtype == torch.int32
), f"seq_lens.dtype needs to be int32 but got {seq_lens.dtype}."
assert (
page_table.dtype == torch.int32
), f"page_table.dtype needs to be int32 but got {page_table.dtype}."
out = q_nope.new_empty((B_q, MAX_HEADS, D_latent))
ops.sm100_cutlass_mla_decode(
out,
q_nope,
q_pe,
kv_c_and_k_pe_cache,
seq_lens,
page_table,
workspace,
sm_scale,
num_kv_splits,
)
return out[:, :H].contiguous()
def _forward_decode(
self,
q_nope: torch.Tensor,
q_pe: torch.Tensor,
kv_c_and_k_pe_cache: torch.Tensor,
attn_metadata: MLACommonMetadata,
) -> torch.Tensor:
assert kv_c_and_k_pe_cache.numel() > 0
assert attn_metadata.decode is not None
if self.kv_cache_dtype.startswith("fp8"):
raise NotImplementedError("FP8 Cutlass MLA not yet supported")
# Run MLA
# Clone q_nope and q_pe to make sure strides computation is correct.
# TODO: Check if we really need it
q_nope = q_nope.clone()
q_pe = q_pe.clone()
o = self.sm100_cutlass_mla_decode(q_nope, q_pe, kv_c_and_k_pe_cache,
attn_metadata.decode.seq_lens,
attn_metadata.decode.block_table,
self._workspace, self.scale,
self._num_kv_splits)
return self._v_up_proj(o)