mirror of
https://github.com/vllm-project/vllm.git
synced 2025-10-20 14:53:52 +08:00
Compare commits
1 Commits
d31f7844f8
...
mla_decode
Author | SHA1 | Date | |
---|---|---|---|
32e4481626 |
@ -563,7 +563,8 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
|
||||
cuda_archs_loose_intersection(MLA_ARCHS "10.0a" "${CUDA_ARCHS}")
|
||||
if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 12.8 AND MLA_ARCHS)
|
||||
set(SRCS
|
||||
"csrc/attention/mla/cutlass_mla_kernels.cu")
|
||||
"csrc/attention/mla/cutlass_mla_kernels.cu"
|
||||
"csrc/attention/mla/sm100_cutlass_mla_kernel.cu")
|
||||
set_gencode_flags_for_srcs(
|
||||
SRCS "${SRCS}"
|
||||
CUDA_ARCHS "${MLA_ARCHS}")
|
||||
|
364
csrc/attention/mla/cutlass_sm100_mla/device/sm100_mla.hpp
Normal file
364
csrc/attention/mla/cutlass_sm100_mla/device/sm100_mla.hpp
Normal file
@ -0,0 +1,364 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
* SPDX-License-Identifier: BSD-3-Clause
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without
|
||||
* modification, are permitted provided that the following conditions are met:
|
||||
*
|
||||
* 1. Redistributions of source code must retain the above copyright notice,
|
||||
*this list of conditions and the following disclaimer.
|
||||
*
|
||||
* 2. Redistributions in binary form must reproduce the above copyright notice,
|
||||
* this list of conditions and the following disclaimer in the documentation
|
||||
* and/or other materials provided with the distribution.
|
||||
*
|
||||
* 3. Neither the name of the copyright holder nor the names of its
|
||||
* contributors may be used to endorse or promote products derived from
|
||||
* this software without specific prior written permission.
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
|
||||
*ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE
|
||||
*LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
|
||||
*CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
|
||||
*SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
|
||||
*INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
|
||||
*CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
|
||||
*ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
|
||||
*POSSIBILITY OF SUCH DAMAGE.
|
||||
*
|
||||
**************************************************************************************************/
|
||||
/*
|
||||
* Taken from SGLANG PR https://github.com/sgl-project/sglang/pull/6929
|
||||
* by Alcanderian JieXin Liang
|
||||
*/
|
||||
|
||||
/*!
|
||||
\file
|
||||
\brief An universal device layer for cutlass 3.x-style kernels.
|
||||
*/
|
||||
|
||||
// clang-format off
|
||||
#pragma once
|
||||
|
||||
// common
|
||||
#include "cutlass/cutlass.h"
|
||||
#include "cutlass/device_kernel.h"
|
||||
|
||||
#if !defined(__CUDACC_RTC__)
|
||||
#include "cutlass/cluster_launch.hpp"
|
||||
#include "cutlass/trace.h"
|
||||
#endif // !defined(__CUDACC_RTC__)
|
||||
|
||||
#include "../kernel/sm100_fmha_mla_tma_warpspecialized.hpp"
|
||||
#include "../kernel/sm100_fmha_mla_reduction.hpp"
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
namespace cutlass::fmha::device {
|
||||
|
||||
using namespace cute;
|
||||
using namespace cutlass::fmha::kernel;
|
||||
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
////////////////////////////// CUTLASS 3.x API /////////////////////////////////
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
template<
|
||||
class Kernel_
|
||||
>
|
||||
class MLA {
|
||||
public:
|
||||
|
||||
using Kernel = Kernel_;
|
||||
|
||||
using ReductionKernel = cutlass::fmha::kernel::Sm100FmhaMlaReductionKernel<
|
||||
typename Kernel::ElementOut,
|
||||
typename Kernel::ElementAcc,
|
||||
typename Kernel::ElementAcc,
|
||||
Kernel::TileShapeH::value,
|
||||
Kernel::TileShapeL::value,
|
||||
256 /*Max split*/
|
||||
>;
|
||||
|
||||
/// Argument structure: User API
|
||||
using KernelArguments = typename Kernel::Arguments;
|
||||
using ReductionArguments = typename ReductionKernel::Arguments;
|
||||
|
||||
using Arguments = KernelArguments;
|
||||
|
||||
/// Argument structure: Kernel API
|
||||
using KernelParams = typename Kernel::Params;
|
||||
using ReductionParams = typename ReductionKernel::Params;
|
||||
struct Params {
|
||||
KernelParams fmha_params;
|
||||
ReductionParams reduction_params;
|
||||
};
|
||||
|
||||
private:
|
||||
|
||||
/// Kernel API parameters object
|
||||
Params params_;
|
||||
|
||||
bool is_initialized(bool set = false) {
|
||||
static bool initialized = false;
|
||||
if (set) initialized = true;
|
||||
return initialized;
|
||||
}
|
||||
|
||||
static ReductionArguments to_reduction_args(Arguments const& args) {
|
||||
auto [H, K, D, B] = args.problem_shape;
|
||||
return ReductionArguments{
|
||||
nullptr, args.epilogue.ptr_o, nullptr, args.epilogue.ptr_lse,
|
||||
args.mainloop.softmax_scale, B, args.split_kv, K, args.mainloop.ptr_seq,
|
||||
args.ptr_split_kv, Kernel::TileShapeS::value
|
||||
};
|
||||
}
|
||||
|
||||
public:
|
||||
|
||||
/// Access the Params structure
|
||||
Params const& params() const {
|
||||
return params_;
|
||||
}
|
||||
|
||||
static void set_split_kv (KernelArguments& args) {
|
||||
if (args.split_kv >= 1) return;
|
||||
auto [H, K, D, B] = args.problem_shape;
|
||||
int sm_count = args.hw_info.sm_count;
|
||||
int max_splits = ceil_div(K, 128);
|
||||
int sms_per_batch = max(1, sm_count / B);
|
||||
int split_heur = min(max_splits, sms_per_batch);
|
||||
int waves = ceil_div(B * split_heur, sm_count);
|
||||
int k_waves = ceil_div(max_splits, split_heur);
|
||||
int split_wave_aware = ceil_div(max_splits, k_waves);
|
||||
args.split_kv = split_wave_aware;
|
||||
}
|
||||
|
||||
/// Determines whether the GEMM can execute the given problem.
|
||||
static Status
|
||||
can_implement(Arguments const& args) {
|
||||
if (! Kernel::can_implement(args)) {
|
||||
return Status::kInvalid;
|
||||
}
|
||||
if (! ReductionKernel::can_implement(to_reduction_args(args))) {
|
||||
return Status::kInvalid;
|
||||
}
|
||||
return Status::kSuccess;
|
||||
}
|
||||
|
||||
/// Gets the workspace size
|
||||
static size_t
|
||||
get_workspace_size(Arguments const& args) {
|
||||
size_t workspace_bytes = 0;
|
||||
workspace_bytes += Kernel::get_workspace_size(args);
|
||||
workspace_bytes += ReductionKernel::get_workspace_size(to_reduction_args(args));
|
||||
return workspace_bytes;
|
||||
}
|
||||
|
||||
/// Computes the maximum number of active blocks per multiprocessor
|
||||
static int maximum_active_blocks(int /* smem_capacity */ = -1) {
|
||||
CUTLASS_TRACE_HOST("MLA::maximum_active_blocks()");
|
||||
int max_active_blocks = -1;
|
||||
int smem_size = Kernel::SharedStorageSize;
|
||||
|
||||
// first, account for dynamic smem capacity if needed
|
||||
cudaError_t result;
|
||||
if (smem_size >= (48 << 10)) {
|
||||
CUTLASS_TRACE_HOST(" Setting smem size to " << smem_size);
|
||||
result = cudaFuncSetAttribute(
|
||||
device_kernel<Kernel>,
|
||||
cudaFuncAttributeMaxDynamicSharedMemorySize,
|
||||
smem_size);
|
||||
if (cudaSuccess != result) {
|
||||
result = cudaGetLastError(); // to clear the error bit
|
||||
CUTLASS_TRACE_HOST(
|
||||
" cudaFuncSetAttribute() returned error: "
|
||||
<< cudaGetErrorString(result));
|
||||
return -1;
|
||||
}
|
||||
}
|
||||
|
||||
// query occupancy after setting smem size
|
||||
result = cudaOccupancyMaxActiveBlocksPerMultiprocessor(
|
||||
&max_active_blocks,
|
||||
device_kernel<Kernel>,
|
||||
Kernel::MaxThreadsPerBlock,
|
||||
smem_size);
|
||||
|
||||
if (cudaSuccess != result) {
|
||||
result = cudaGetLastError(); // to clear the error bit
|
||||
CUTLASS_TRACE_HOST(
|
||||
" cudaOccupancyMaxActiveBlocksPerMultiprocessor() returned error: "
|
||||
<< cudaGetErrorString(result));
|
||||
return -1;
|
||||
}
|
||||
|
||||
CUTLASS_TRACE_HOST(" max_active_blocks: " << max_active_blocks);
|
||||
return max_active_blocks;
|
||||
}
|
||||
|
||||
/// Initializes GEMM state from arguments.
|
||||
Status
|
||||
initialize(Arguments const& args, void* workspace = nullptr, cudaStream_t stream = nullptr) {
|
||||
CUTLASS_TRACE_HOST("MLA::initialize() - workspace "
|
||||
<< workspace << ", stream: " << (stream ? "non-null" : "null"));
|
||||
|
||||
// Initialize the workspace
|
||||
Status status = Kernel::initialize_workspace(args, workspace, stream);
|
||||
if (status != Status::kSuccess) {
|
||||
return status;
|
||||
}
|
||||
status = ReductionKernel::initialize_workspace(to_reduction_args(args), workspace, stream);
|
||||
if (status != Status::kSuccess) {
|
||||
return status;
|
||||
}
|
||||
KernelParams kernel_params = Kernel::to_underlying_arguments(args, workspace);
|
||||
|
||||
ReductionArguments reduction_args = to_reduction_args(args);
|
||||
if (reduction_args.split_kv > 1) {
|
||||
reduction_args.ptr_oaccum = kernel_params.epilogue.ptr_o_acc;
|
||||
reduction_args.ptr_lseaccum = kernel_params.epilogue.ptr_lse_acc;
|
||||
}
|
||||
ReductionParams reduction_params = ReductionKernel::to_underlying_arguments(reduction_args, workspace);
|
||||
// Initialize the Params structure
|
||||
params_ = Params {kernel_params, reduction_params};
|
||||
|
||||
if (is_initialized()) return Status::kSuccess;
|
||||
|
||||
// account for dynamic smem capacity if needed
|
||||
// no dynamic smem is needed for reduction kernel
|
||||
int smem_size = Kernel::SharedStorageSize;
|
||||
if (smem_size >= (48 << 10)) {
|
||||
CUTLASS_TRACE_HOST(" Setting smem size to " << smem_size);
|
||||
cudaError_t result = cudaFuncSetAttribute(
|
||||
device_kernel<Kernel>,
|
||||
cudaFuncAttributeMaxDynamicSharedMemorySize,
|
||||
smem_size);
|
||||
if (cudaSuccess != result) {
|
||||
result = cudaGetLastError(); // to clear the error bit
|
||||
CUTLASS_TRACE_HOST(" cudaFuncSetAttribute() returned error: " << cudaGetErrorString(result));
|
||||
return Status::kErrorInternal;
|
||||
}
|
||||
}
|
||||
|
||||
is_initialized(true);
|
||||
|
||||
return Status::kSuccess;
|
||||
}
|
||||
|
||||
/// Update API is preserved in 3.0, but does not guarantee a lightweight update of params.
|
||||
Status
|
||||
update(Arguments const& args, void* workspace = nullptr) {
|
||||
CUTLASS_TRACE_HOST("MLA()::update() - workspace: " << workspace);
|
||||
|
||||
size_t workspace_bytes = get_workspace_size(args);
|
||||
if (workspace_bytes > 0 && nullptr == workspace) {
|
||||
return Status::kErrorWorkspaceNull;
|
||||
}
|
||||
|
||||
auto fmha_params = Kernel::to_underlying_arguments(args, workspace);
|
||||
|
||||
ReductionArguments reduction_args = to_reduction_args(args);
|
||||
if (reduction_args.split_kv > 1) {
|
||||
reduction_args.ptr_oaccum = fmha_params.epilogue.ptr_o_acc;
|
||||
reduction_args.ptr_lseaccum = fmha_params.epilogue.ptr_lse_acc;
|
||||
}
|
||||
ReductionParams reduction_params = ReductionKernel::to_underlying_arguments(reduction_args, workspace);
|
||||
// Initialize the Params structure
|
||||
params_ = Params {fmha_params, reduction_params};
|
||||
|
||||
return Status::kSuccess;
|
||||
}
|
||||
|
||||
/// Primary run() entry point API that is static allowing users to create and manage their own params.
|
||||
/// Supplied params struct must be construct by calling Kernel::to_underling_arguments()
|
||||
static Status
|
||||
run(Params& params, cudaStream_t stream = nullptr) {
|
||||
CUTLASS_TRACE_HOST("MLA::run()");
|
||||
dim3 const block = Kernel::get_block_shape();
|
||||
dim3 const grid = Kernel::get_grid_shape(params.fmha_params);
|
||||
|
||||
// configure smem size and carveout
|
||||
int smem_size = Kernel::SharedStorageSize;
|
||||
|
||||
Status launch_result;
|
||||
// Use extended launch API only for mainloops that use it
|
||||
if constexpr(Kernel::ArchTag::kMinComputeCapability >= 90) {
|
||||
dim3 cluster(cute::size<0>(typename Kernel::ClusterShape{}),
|
||||
cute::size<1>(typename Kernel::ClusterShape{}),
|
||||
cute::size<2>(typename Kernel::ClusterShape{}));
|
||||
void const* kernel = (void const*) device_kernel<Kernel>;
|
||||
void* kernel_params[] = {¶ms.fmha_params};
|
||||
launch_result = ClusterLauncher::launch(grid, cluster, block, smem_size, stream, kernel, kernel_params);
|
||||
}
|
||||
else {
|
||||
launch_result = Status::kSuccess;
|
||||
device_kernel<Kernel><<<grid, block, smem_size, stream>>>(params.fmha_params);
|
||||
}
|
||||
|
||||
cudaError_t result = cudaGetLastError();
|
||||
if (cudaSuccess != result or Status::kSuccess != launch_result) {
|
||||
//return Status::kSuccess;
|
||||
CUTLASS_TRACE_HOST(" Kernel launch failed. Reason: " << result);
|
||||
return Status::kErrorInternal;
|
||||
}
|
||||
if (params.reduction_params.split_kv > 1) {
|
||||
// launch reduction kernel
|
||||
dim3 const block = ReductionKernel::get_block_shape();
|
||||
dim3 const grid = ReductionKernel::get_grid_shape(params.reduction_params);
|
||||
device_kernel<ReductionKernel><<<grid, block, 0, stream>>>(params.reduction_params);
|
||||
cudaError_t result = cudaGetLastError();
|
||||
if (cudaSuccess == result) {
|
||||
return Status::kSuccess;
|
||||
}
|
||||
else {
|
||||
CUTLASS_TRACE_HOST(" Kernel launch failed. Reason: " << result);
|
||||
return Status::kErrorInternal;
|
||||
}
|
||||
}
|
||||
else {
|
||||
return Status::kSuccess;
|
||||
}
|
||||
}
|
||||
|
||||
//
|
||||
// Non-static launch overloads that first create and set the internal params struct of this kernel handle.
|
||||
//
|
||||
|
||||
/// Launches the kernel after first constructing Params internal state from supplied arguments.
|
||||
Status
|
||||
run(Arguments const& args, void* workspace = nullptr, cudaStream_t stream = nullptr) {
|
||||
Status status = initialize(args, workspace, stream);
|
||||
if (Status::kSuccess == status) {
|
||||
status = run(params_, stream);
|
||||
}
|
||||
return status;
|
||||
}
|
||||
|
||||
/// Launches the kernel after first constructing Params internal state from supplied arguments.
|
||||
Status
|
||||
operator()(Arguments const& args, void* workspace = nullptr, cudaStream_t stream = nullptr) {
|
||||
return run(args, workspace, stream);
|
||||
}
|
||||
|
||||
/// Overload that allows a user to re-launch the same kernel without updating internal params struct.
|
||||
Status
|
||||
run(cudaStream_t stream = nullptr) {
|
||||
return run(params_, stream);
|
||||
}
|
||||
|
||||
/// Overload that allows a user to re-launch the same kernel without updating internal params struct.
|
||||
Status
|
||||
operator()(cudaStream_t stream = nullptr) {
|
||||
return run(params_, stream);
|
||||
}
|
||||
};
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
} // namespace cutlass::fmha::device
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////
|
@ -0,0 +1,203 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2024 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights
|
||||
*reserved. SPDX-License-Identifier: BSD-3-Clause
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without
|
||||
* modification, are permitted provided that the following conditions are met:
|
||||
*
|
||||
* 1. Redistributions of source code must retain the above copyright notice,
|
||||
*this list of conditions and the following disclaimer.
|
||||
*
|
||||
* 2. Redistributions in binary form must reproduce the above copyright notice,
|
||||
* this list of conditions and the following disclaimer in the documentation
|
||||
* and/or other materials provided with the distribution.
|
||||
*
|
||||
* 3. Neither the name of the copyright holder nor the names of its
|
||||
* contributors may be used to endorse or promote products derived from
|
||||
* this software without specific prior written permission.
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
|
||||
*ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE
|
||||
*LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
|
||||
*CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
|
||||
*SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
|
||||
*INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
|
||||
*CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
|
||||
*ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
|
||||
*POSSIBILITY OF SUCH DAMAGE.
|
||||
*
|
||||
**************************************************************************************************/
|
||||
/*
|
||||
* Taken from SGLANG PR https://github.com/sgl-project/sglang/pull/6929
|
||||
* by Alcanderian JieXin Liang
|
||||
*/
|
||||
|
||||
// clang-format off
|
||||
#pragma once
|
||||
|
||||
#include "cutlass/cutlass.h"
|
||||
#include "cutlass/arch/arch.h"
|
||||
#include "cute/tensor.hpp"
|
||||
|
||||
namespace cutlass::fmha::kernel {
|
||||
|
||||
using namespace cute;
|
||||
template<
|
||||
class ElementOut,
|
||||
class ElementAcc,
|
||||
class ElementScale,
|
||||
size_t kNumHeads,
|
||||
size_t kHeadDimLatent,
|
||||
int kMaxSplits
|
||||
>
|
||||
struct Sm100FmhaMlaReductionKernel {
|
||||
|
||||
static const int SharedStorageSize = 0;
|
||||
static const int MaxThreadsPerBlock = 128;
|
||||
static const int MinBlocksPerMultiprocessor = 1;
|
||||
|
||||
using ArchTag = cutlass::arch::Sm100;
|
||||
|
||||
static_assert(kHeadDimLatent % MaxThreadsPerBlock == 0);
|
||||
struct Arguments {
|
||||
ElementAcc* ptr_oaccum = nullptr;
|
||||
ElementOut* ptr_o = nullptr;
|
||||
ElementAcc* ptr_lseaccum = nullptr;
|
||||
ElementAcc* ptr_lse = nullptr;
|
||||
ElementScale scale = 1.f;
|
||||
int num_batches = 0;
|
||||
int split_kv = -1;
|
||||
int dim_k = -1;
|
||||
int* ptr_seq = nullptr;
|
||||
int* ptr_split_kv = nullptr;
|
||||
int tile_shape_s = 128;
|
||||
};
|
||||
using Params = Arguments;
|
||||
|
||||
static Params to_underlying_arguments(Arguments const& args, void* workspace) {
|
||||
return {args.ptr_oaccum, args.ptr_o, args.ptr_lseaccum, args.ptr_lse,
|
||||
args.scale, args.num_batches, args.split_kv, args.dim_k, args.ptr_seq,
|
||||
args.ptr_split_kv, args.tile_shape_s};
|
||||
}
|
||||
|
||||
static size_t get_workspace_size(Arguments const& /*args*/) {
|
||||
return 0;
|
||||
}
|
||||
|
||||
static Status initialize_workspace(
|
||||
Arguments const& /*args*/, void* /*ws*/, cudaStream_t /*stream*/) {
|
||||
return Status::kSuccess;
|
||||
}
|
||||
|
||||
static dim3 get_grid_shape(Params const& params) {
|
||||
return dim3(kNumHeads, 1, params.num_batches);
|
||||
}
|
||||
|
||||
static dim3 get_block_shape() {
|
||||
return dim3(MaxThreadsPerBlock, 1, 1);
|
||||
}
|
||||
|
||||
static bool can_implement(Arguments const& args) {
|
||||
if (args.num_batches <= 0) return false;
|
||||
if (args.split_kv <= 0) return false;
|
||||
return true;
|
||||
}
|
||||
|
||||
CUTLASS_DEVICE void operator() (Params const& params, char* smem_raw) {
|
||||
if (params.split_kv <= 1) return;
|
||||
auto blk_coord = make_coord(blockIdx.x, _0{}, blockIdx.z);
|
||||
|
||||
__shared__ ElementAcc sLseScale[kMaxSplits];
|
||||
const size_t offset_lseaccum = get<0>(blk_coord) + kNumHeads * params.split_kv * get<2>(blk_coord);
|
||||
const size_t offset_lse = get<0>(blk_coord) + kNumHeads * get<2>(blk_coord);
|
||||
|
||||
Tensor gLSEaccum = make_tensor(make_gmem_ptr(params.ptr_lseaccum + offset_lseaccum),
|
||||
make_shape(params.split_kv), Stride<Int<kNumHeads>>{});
|
||||
|
||||
Tensor gLSE = make_tensor(make_gmem_ptr(params.ptr_lse + offset_lse),
|
||||
Shape<_1>{}, Stride<_1>{});
|
||||
|
||||
auto dim_k = params.ptr_seq == nullptr ? params.dim_k : params.ptr_seq[get<2>(blk_coord)];
|
||||
auto local_split_kv = params.ptr_split_kv == nullptr ? params.split_kv : params.ptr_split_kv[get<2>(blk_coord)];
|
||||
auto k_tile_total = ceil_div(dim_k, params.tile_shape_s);
|
||||
auto k_tile_per_cta = ceil_div(k_tile_total, local_split_kv);
|
||||
local_split_kv = ceil_div(k_tile_total, k_tile_per_cta);
|
||||
|
||||
int warp_idx = cutlass::canonical_warp_idx_sync();
|
||||
if (warp_idx == 0) {
|
||||
constexpr int kNLsePerThread = cute::ceil_div(kMaxSplits, 32);
|
||||
|
||||
ElementAcc local_lse[kNLsePerThread];
|
||||
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int i = 0; i < kNLsePerThread; ++i) {
|
||||
const int split = i * 32 + threadIdx.x;
|
||||
local_lse[i] = split < local_split_kv ? gLSEaccum(split) : -std::numeric_limits<ElementAcc>::infinity();
|
||||
}
|
||||
|
||||
ElementAcc lse_max = -std::numeric_limits<ElementAcc>::infinity();
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int i = 0; i < kNLsePerThread; ++i) {
|
||||
lse_max = max(lse_max, local_lse[i]);
|
||||
}
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int offset = 16; offset >= 1; offset /= 2) {
|
||||
lse_max = max(lse_max, __shfl_xor_sync(0xffffffff, lse_max, offset));
|
||||
}
|
||||
lse_max = lse_max == -std::numeric_limits<ElementAcc>::infinity() ? 0.0f : lse_max; // In case all local LSEs are -inf
|
||||
lse_max = __shfl_sync(0xffffffff, lse_max, 0);
|
||||
|
||||
ElementAcc sum_lse = 0;
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int i = 0; i < kNLsePerThread; ++i) {
|
||||
sum_lse = sum_lse + expf(local_lse[i] - lse_max);
|
||||
}
|
||||
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int offset = 16; offset >= 1; offset /= 2) {
|
||||
sum_lse = sum_lse + __shfl_xor_sync(0xffffffff, sum_lse, offset);
|
||||
}
|
||||
|
||||
sum_lse = __shfl_sync(0xffffffff, sum_lse, 0);
|
||||
|
||||
ElementAcc global_lse = (sum_lse == 0.f || sum_lse != sum_lse) ? std::numeric_limits<ElementAcc>::infinity() : logf(sum_lse) + lse_max;
|
||||
if (threadIdx.x == 0 and params.ptr_lse != nullptr) {
|
||||
gLSE(0) = global_lse;
|
||||
}
|
||||
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int i = 0; i < kNLsePerThread; ++i) {
|
||||
const int split = i * 32 + threadIdx.x;
|
||||
if (split < local_split_kv) {
|
||||
sLseScale[split] = expf(local_lse[i] - global_lse);
|
||||
}
|
||||
}
|
||||
}
|
||||
__syncthreads();
|
||||
|
||||
constexpr int Elements = kHeadDimLatent / MaxThreadsPerBlock;
|
||||
const size_t offset_oaccum = kHeadDimLatent * params.split_kv * (get<0>(blk_coord) + kNumHeads * get<2>(blk_coord));
|
||||
Tensor gOaccum = make_tensor(make_gmem_ptr(params.ptr_oaccum + offset_oaccum),
|
||||
Shape<Int<kHeadDimLatent>>{}, Stride<_1>{});
|
||||
ElementAcc local_val[Elements] = {0};
|
||||
for (int split = 0; split < local_split_kv; ++split) {
|
||||
ElementAcc lse_scale = sLseScale[split];
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for(int i = 0; i < Elements; ++i) {
|
||||
local_val[i] += lse_scale * gOaccum(threadIdx.x + MaxThreadsPerBlock * i);
|
||||
}
|
||||
gOaccum.data() = gOaccum.data() + kHeadDimLatent;
|
||||
}
|
||||
auto ptr_o_local = params.ptr_o + (get<0>(blk_coord) + get<2>(blk_coord) * kNumHeads) * kHeadDimLatent;
|
||||
Tensor gO = make_tensor(make_gmem_ptr(ptr_o_local), Shape<Int<kHeadDimLatent>>{}, Stride<_1>{});
|
||||
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for(int i = 0; i < Elements; ++i) {
|
||||
gO(threadIdx.x + MaxThreadsPerBlock * i) = static_cast<ElementOut>(local_val[i]);
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace cutlass::fmha::kernel
|
File diff suppressed because it is too large
Load Diff
@ -0,0 +1,165 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2024 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights
|
||||
*reserved. SPDX-License-Identifier: BSD-3-Clause
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without
|
||||
* modification, are permitted provided that the following conditions are met:
|
||||
*
|
||||
* 1. Redistributions of source code must retain the above copyright notice,
|
||||
*this list of conditions and the following disclaimer.
|
||||
*
|
||||
* 2. Redistributions in binary form must reproduce the above copyright notice,
|
||||
* this list of conditions and the following disclaimer in the documentation
|
||||
* and/or other materials provided with the distribution.
|
||||
*
|
||||
* 3. Neither the name of the copyright holder nor the names of its
|
||||
* contributors may be used to endorse or promote products derived from
|
||||
* this software without specific prior written permission.
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
|
||||
*ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE
|
||||
*LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
|
||||
*CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
|
||||
*SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
|
||||
*INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
|
||||
*CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
|
||||
*ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
|
||||
*POSSIBILITY OF SUCH DAMAGE.
|
||||
*
|
||||
**************************************************************************************************/
|
||||
/*
|
||||
* Taken from SGLANG PR https://github.com/sgl-project/sglang/pull/6929
|
||||
* by Alcanderian JieXin Liang
|
||||
*/
|
||||
|
||||
// clang-format off
|
||||
#pragma once
|
||||
|
||||
#include "cutlass/cutlass.h"
|
||||
#include "cutlass/fast_math.h"
|
||||
#include "cutlass/kernel_hardware_info.h"
|
||||
|
||||
namespace cutlass::fmha::kernel {
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
struct Sm100MlaIndividualTileScheduler {
|
||||
|
||||
struct Params {
|
||||
dim3 grid;
|
||||
};
|
||||
|
||||
bool valid_ = true;
|
||||
|
||||
CUTLASS_DEVICE
|
||||
Sm100MlaIndividualTileScheduler(Params const&) {}
|
||||
|
||||
template<class ProblemShape, class ClusterShape>
|
||||
static Params to_underlying_arguments(
|
||||
ProblemShape const& problem_shape, KernelHardwareInfo hw_info,
|
||||
ClusterShape const& cluster_shape, int const& split_kv) {
|
||||
using namespace cute;
|
||||
dim3 grid(get<0>(cluster_shape), get<3>(problem_shape) /* Batch */, split_kv /*Maximum Split KV*/);
|
||||
return Params{ grid };
|
||||
}
|
||||
|
||||
static dim3 get_grid_shape(Params const& params) {
|
||||
return params.grid;
|
||||
}
|
||||
|
||||
CUTLASS_DEVICE
|
||||
bool is_valid() {
|
||||
return valid_;
|
||||
}
|
||||
|
||||
CUTLASS_DEVICE
|
||||
auto get_block_coord() {
|
||||
using namespace cute;
|
||||
return make_coord(blockIdx.x, _0{}, blockIdx.y, blockIdx.z);
|
||||
}
|
||||
|
||||
CUTLASS_DEVICE
|
||||
Sm100MlaIndividualTileScheduler& operator++() {
|
||||
valid_ = false;
|
||||
return *this;
|
||||
}
|
||||
};
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
struct Sm100MlaPersistentTileScheduler {
|
||||
|
||||
struct Params {
|
||||
int num_blocks;
|
||||
FastDivmod divmod_m_block;
|
||||
FastDivmod divmod_b;
|
||||
FastDivmod divmod_split_kv;
|
||||
KernelHardwareInfo hw_info;
|
||||
};
|
||||
|
||||
int block_idx = 0;
|
||||
Params params;
|
||||
|
||||
CUTLASS_DEVICE
|
||||
Sm100MlaPersistentTileScheduler(Params const& params) : block_idx(blockIdx.x), params(params) {}
|
||||
|
||||
template<class ProblemShape, class ClusterShape>
|
||||
static Params to_underlying_arguments(
|
||||
ProblemShape const& problem_shape, KernelHardwareInfo hw_info,
|
||||
ClusterShape const& cluster_shape, int const& split_kv) {
|
||||
using namespace cute;
|
||||
// Get SM count if needed, otherwise use user supplied SM count
|
||||
int sm_count = hw_info.sm_count;
|
||||
if (sm_count <= 1 || sm_count % size<0>(cluster_shape) != 0) {
|
||||
CUTLASS_TRACE_HOST(" WARNING: Arguments do not include a valid SM count.\n"
|
||||
" For optimal performance, populate the arguments KernelHardwareInfo struct with the SM count.");
|
||||
sm_count = KernelHardwareInfo::query_device_multiprocessor_count(hw_info.device_id);
|
||||
}
|
||||
|
||||
CUTLASS_TRACE_HOST("to_underlying_arguments(): Setting persistent grid SM count to " << sm_count);
|
||||
hw_info.sm_count = sm_count;
|
||||
|
||||
int num_m_blocks = size<0>(cluster_shape);
|
||||
int num_blocks = num_m_blocks * get<3>(problem_shape) /* Batch */;
|
||||
num_blocks *= split_kv; /* Maximum Split KV*/
|
||||
|
||||
return Params {
|
||||
num_blocks,
|
||||
{ num_m_blocks}, { get<3>(problem_shape) }, {split_kv},
|
||||
hw_info
|
||||
};
|
||||
}
|
||||
|
||||
static dim3 get_grid_shape(Params const& params) {
|
||||
dim3 grid(std::min(params.num_blocks, params.hw_info.sm_count), 1, 1);
|
||||
return grid;
|
||||
}
|
||||
|
||||
CUTLASS_DEVICE
|
||||
bool is_valid() {
|
||||
return block_idx < params.num_blocks;
|
||||
}
|
||||
|
||||
CUTLASS_DEVICE
|
||||
auto get_block_coord() {
|
||||
using namespace cute;
|
||||
int block_decode = block_idx;
|
||||
int m_block, bidb, n_split_kv;
|
||||
params.divmod_m_block(block_decode, m_block, block_decode);
|
||||
params.divmod_b(block_decode, bidb, block_decode);
|
||||
params.divmod_split_kv(block_decode, n_split_kv, block_decode);
|
||||
return make_coord(m_block, _0{}, bidb, n_split_kv);
|
||||
}
|
||||
|
||||
CUTLASS_DEVICE
|
||||
Sm100MlaPersistentTileScheduler& operator++() {
|
||||
block_idx += gridDim.x;
|
||||
return *this;
|
||||
}
|
||||
};
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
} // namespace cutlass::fmha::kernel
|
273
csrc/attention/mla/sm100_cutlass_mla_kernel.cu
Normal file
273
csrc/attention/mla/sm100_cutlass_mla_kernel.cu
Normal file
@ -0,0 +1,273 @@
|
||||
/*
|
||||
Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
|
||||
Copyright 2025 SGLang Team. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
/*
|
||||
* Taken from SGLANG PR https://github.com/sgl-project/sglang/pull/6929
|
||||
* by Alcanderian JieXin Liang
|
||||
*/
|
||||
|
||||
#include <ATen/cuda/CUDAContext.h>
|
||||
#include <c10/cuda/CUDAGuard.h>
|
||||
#include <cutlass/cutlass.h>
|
||||
#include <cutlass/kernel_hardware_info.h>
|
||||
#include <torch/all.h>
|
||||
|
||||
#include <cute/tensor.hpp>
|
||||
#include <iostream>
|
||||
|
||||
#include "cutlass_sm100_mla/device/sm100_mla.hpp"
|
||||
#include "cutlass_sm100_mla/kernel/sm100_mla_tile_scheduler.hpp"
|
||||
|
||||
// clang-format off
|
||||
#if !defined(CUDA_VERSION) || CUDA_VERSION < 12040
|
||||
void sm100_cutlass_mla_decode(
|
||||
torch::Tensor const& out,
|
||||
torch::Tensor const& q_nope,
|
||||
torch::Tensor const& q_pe,
|
||||
torch::Tensor const& kv_c_and_k_pe_cache,
|
||||
torch::Tensor const& seq_lens,
|
||||
torch::Tensor const& page_table,
|
||||
torch::Tensor const& workspace,
|
||||
int64_t num_kv_splits) {
|
||||
TORCH_CHECK(false, "CUDA version must be >= 12.4 for cutlass_mla_decode");
|
||||
}
|
||||
int64_t sm100_cutlass_mla_get_workspace_size(int64_t max_seq_len, int64_t num_batches, int64_t sm_count, int64_t num_kv_splits) {
|
||||
TORCH_CHECK(false, "CUDA version must be >= 12.4 for cutlass_mla_get_workspace_size");
|
||||
}
|
||||
#else
|
||||
|
||||
#define CUTLASS_CHECK(status) \
|
||||
{ \
|
||||
cutlass::Status error = status; \
|
||||
TORCH_CHECK(error == cutlass::Status::kSuccess, cutlassGetStatusString(error)); \
|
||||
}
|
||||
|
||||
using namespace cute;
|
||||
using namespace cutlass::fmha::kernel;
|
||||
|
||||
template <bool v>
|
||||
struct IsPersistent {
|
||||
static const bool value = v;
|
||||
};
|
||||
|
||||
template <typename T, bool IsPaged128, typename PersistenceOption = IsPersistent<true>>
|
||||
struct MlaSm100 {
|
||||
using Element = T;
|
||||
using ElementAcc = float;
|
||||
using ElementOut = T;
|
||||
|
||||
using TileShape = Shape<_128, _128, Shape<_512, _64>>;
|
||||
using TileShapeH = cute::tuple_element_t<0, TileShape>;
|
||||
using TileShapeD = cute::tuple_element_t<2, TileShape>;
|
||||
|
||||
// H K (D_latent D_rope) B
|
||||
using ProblemShape = cute::tuple<TileShapeH, int, TileShapeD, int>;
|
||||
|
||||
using StrideQ = cute::tuple<int64_t, _1, int64_t>; // H D B
|
||||
using StrideK = cute::tuple<int64_t, _1, int64_t>; // K D B
|
||||
using StrideO = StrideK; // H D B
|
||||
using StrideLSE = cute::tuple<_1, int>; // H B
|
||||
|
||||
using TileScheduler =
|
||||
std::conditional_t<PersistenceOption::value, Sm100MlaPersistentTileScheduler, Sm100MlaIndividualTileScheduler>;
|
||||
|
||||
using FmhaKernel = cutlass::fmha::kernel::Sm100FmhaMlaKernelTmaWarpspecialized<
|
||||
TileShape,
|
||||
Element,
|
||||
ElementAcc,
|
||||
ElementOut,
|
||||
ElementAcc,
|
||||
TileScheduler,
|
||||
/*kIsCpAsync=*/!IsPaged128>;
|
||||
using Fmha = cutlass::fmha::device::MLA<FmhaKernel>;
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
typename T::Fmha::Arguments args_from_options(
|
||||
at::Tensor const& out,
|
||||
at::Tensor const& q_nope,
|
||||
at::Tensor const& q_pe,
|
||||
at::Tensor const& kv_c_and_k_pe_cache,
|
||||
at::Tensor const& seq_lens,
|
||||
at::Tensor const& page_table,
|
||||
double sm_scale,
|
||||
int64_t num_kv_splits) {
|
||||
cutlass::KernelHardwareInfo hw_info;
|
||||
hw_info.device_id = q_nope.device().index();
|
||||
hw_info.sm_count = cutlass::KernelHardwareInfo::query_device_multiprocessor_count(hw_info.device_id);
|
||||
|
||||
int batches = q_nope.sizes()[0];
|
||||
int page_count_per_seq = page_table.sizes()[1];
|
||||
int page_count_total = kv_c_and_k_pe_cache.sizes()[0];
|
||||
int page_size = kv_c_and_k_pe_cache.sizes()[1];
|
||||
int max_seq_len = page_size * page_count_per_seq;
|
||||
using TileShapeH = typename T::TileShapeH;
|
||||
using TileShapeD = typename T::TileShapeD;
|
||||
auto problem_shape = cute::make_tuple(TileShapeH{}, max_seq_len, TileShapeD{}, batches);
|
||||
|
||||
auto [H, K, D, B] = problem_shape;
|
||||
auto [D_latent, D_rope] = D;
|
||||
|
||||
float scale = float(sm_scale);
|
||||
|
||||
using StrideQ = typename T::StrideQ;
|
||||
using StrideK = typename T::StrideK;
|
||||
using StrideO = typename T::StrideO;
|
||||
using StrideLSE = typename T::StrideLSE;
|
||||
|
||||
StrideQ stride_Q_nope = cute::make_tuple(
|
||||
static_cast<int64_t>(q_nope.stride(1)), _1{}, static_cast<int64_t>(q_nope.stride(0)));
|
||||
StrideQ stride_Q_pe = cute::make_tuple(
|
||||
static_cast<int64_t>(q_pe.stride(1)), _1{}, static_cast<int64_t>(q_pe.stride(0)));
|
||||
|
||||
StrideK stride_C = cute::make_tuple(
|
||||
static_cast<int64_t>(0 + D_latent + D_rope), _1{}, static_cast<int64_t>(page_size * (D_latent + D_rope)));
|
||||
StrideLSE stride_PT = cute::make_stride(_1{}, page_count_per_seq);
|
||||
StrideLSE stride_LSE = cute::make_tuple(_1{}, 0 + H);
|
||||
StrideO stride_O = cute::make_tuple(static_cast<int64_t>(0 + D_latent), _1{}, static_cast<int64_t>(0 + H * D_latent));
|
||||
|
||||
using Element = typename T::Element;
|
||||
using ElementOut = typename T::ElementOut;
|
||||
using ElementAcc = typename T::ElementAcc;
|
||||
auto Q_nope_ptr = static_cast<Element*>(q_nope.data_ptr());
|
||||
auto Q_pe_ptr = static_cast<Element*>(q_pe.data_ptr());
|
||||
auto C_ptr = static_cast<Element*>(kv_c_and_k_pe_cache.data_ptr());
|
||||
typename T::Fmha::Arguments arguments{
|
||||
problem_shape,
|
||||
{scale,
|
||||
Q_nope_ptr,
|
||||
stride_Q_nope,
|
||||
Q_pe_ptr,
|
||||
stride_Q_pe,
|
||||
C_ptr,
|
||||
stride_C,
|
||||
C_ptr + D_latent,
|
||||
stride_C,
|
||||
static_cast<int*>(seq_lens.data_ptr()),
|
||||
static_cast<int*>(page_table.data_ptr()),
|
||||
stride_PT,
|
||||
page_count_total,
|
||||
page_size},
|
||||
{static_cast<ElementOut*>(out.data_ptr()), stride_O, static_cast<ElementAcc*>(nullptr), stride_LSE},
|
||||
hw_info,
|
||||
// TODO(trevor-m): Change split_kv back to -1 when
|
||||
// https://github.com/NVIDIA/cutlass/issues/2274 is fixed. Split_kv=1 will
|
||||
// perform worse with larger context length and smaller batch sizes.
|
||||
num_kv_splits, // split_kv
|
||||
nullptr, // is_var_split_kv
|
||||
};
|
||||
// TODO(kaixih@nvidia): When split_kv=-1 and is_var_split_kv=false, we compute
|
||||
// split_kv automatically based on batch size and sequence length to balance
|
||||
// workload across available SMs. Consider using var_split_kv for manual
|
||||
// control if needed.
|
||||
T::Fmha::set_split_kv(arguments);
|
||||
return arguments;
|
||||
}
|
||||
|
||||
template <typename Element, bool IsPaged128, typename PersistenceOption>
|
||||
void runMla(
|
||||
at::Tensor const& out,
|
||||
at::Tensor const& q_nope,
|
||||
at::Tensor const& q_pe,
|
||||
at::Tensor const& kv_c_and_k_pe_cache,
|
||||
at::Tensor const& seq_lens,
|
||||
at::Tensor const& page_table,
|
||||
at::Tensor const& workspace,
|
||||
double sm_scale,
|
||||
int64_t num_kv_splits,
|
||||
cudaStream_t stream) {
|
||||
using MlaSm100Type = MlaSm100<Element, IsPaged128, PersistenceOption>;
|
||||
typename MlaSm100Type::Fmha fmha;
|
||||
auto arguments = args_from_options<MlaSm100Type>(out, q_nope, q_pe, kv_c_and_k_pe_cache, seq_lens, page_table, sm_scale, num_kv_splits);
|
||||
|
||||
CUTLASS_CHECK(fmha.can_implement(arguments));
|
||||
|
||||
CUTLASS_CHECK(fmha.initialize(arguments, workspace.data_ptr(), stream));
|
||||
|
||||
CUTLASS_CHECK(fmha.run(arguments, workspace.data_ptr(), stream));
|
||||
}
|
||||
|
||||
#define DISPATCH_BOOL(expr, const_expr, ...) \
|
||||
[&]() -> bool { \
|
||||
if (expr) { \
|
||||
constexpr bool const_expr = true; \
|
||||
return __VA_ARGS__(); \
|
||||
} else { \
|
||||
constexpr bool const_expr = false; \
|
||||
return __VA_ARGS__(); \
|
||||
} \
|
||||
}()
|
||||
|
||||
void sm100_cutlass_mla_decode(
|
||||
torch::Tensor const& out,
|
||||
torch::Tensor const& q_nope,
|
||||
torch::Tensor const& q_pe,
|
||||
torch::Tensor const& kv_c_and_k_pe_cache,
|
||||
torch::Tensor const& seq_lens,
|
||||
torch::Tensor const& page_table,
|
||||
torch::Tensor const& workspace,
|
||||
double sm_scale,
|
||||
int64_t num_kv_splits) {
|
||||
auto in_dtype = q_nope.dtype();
|
||||
at::cuda::CUDAGuard device_guard{(char)q_nope.get_device()};
|
||||
const cudaStream_t stream = at::cuda::getCurrentCUDAStream(q_nope.get_device());
|
||||
const int page_size = kv_c_and_k_pe_cache.sizes()[1];
|
||||
|
||||
// NOTE(alcanderian): IsPersistent has bug with manual split_kv.
|
||||
// Kernel will hang if batch is too large with large num_kv_splits. (for example bs=8, num_kv_splits=8)
|
||||
// Maybe per batch split kv will fix this.
|
||||
DISPATCH_BOOL(page_size == 128, IsPaged128, [&] {
|
||||
DISPATCH_BOOL(num_kv_splits <= 1, NotManualSplitKV, [&] {
|
||||
if (in_dtype == at::ScalarType::Half) {
|
||||
runMla<cutlass::half_t, IsPaged128, IsPersistent<NotManualSplitKV>>(
|
||||
out, q_nope, q_pe, kv_c_and_k_pe_cache, seq_lens, page_table, workspace, sm_scale, num_kv_splits, stream);
|
||||
} else if (in_dtype == at::ScalarType::BFloat16) {
|
||||
runMla<cutlass::bfloat16_t, IsPaged128, IsPersistent<NotManualSplitKV>>(
|
||||
out, q_nope, q_pe, kv_c_and_k_pe_cache, seq_lens, page_table, workspace, sm_scale, num_kv_splits, stream);
|
||||
} else if (in_dtype == at::ScalarType::Float8_e4m3fn) {
|
||||
runMla<cutlass::float_e4m3_t, IsPaged128, IsPersistent<NotManualSplitKV>>(
|
||||
out, q_nope, q_pe, kv_c_and_k_pe_cache, seq_lens, page_table, workspace, sm_scale, num_kv_splits, stream);
|
||||
} else {
|
||||
TORCH_CHECK(false, "Unsupported input data type of MLA");
|
||||
}
|
||||
return true;
|
||||
});
|
||||
return true;
|
||||
});
|
||||
}
|
||||
|
||||
int64_t sm100_cutlass_mla_get_workspace_size(int64_t max_seq_len, int64_t num_batches, int64_t sm_count, int64_t num_kv_splits) {
|
||||
// Workspace size depends on ElementAcc and ElementLSE (same as ElementAcc)
|
||||
// which are float, so Element type here doesn't matter.
|
||||
using MlaSm100Type = MlaSm100<cutlass::half_t, true>;
|
||||
|
||||
// Get split kv. Requires problem shape and sm_count only.
|
||||
typename MlaSm100Type::Fmha::Arguments arguments;
|
||||
using TileShapeH = typename MlaSm100Type::TileShapeH;
|
||||
using TileShapeD = typename MlaSm100Type::TileShapeD;
|
||||
arguments.problem_shape =
|
||||
cute::make_tuple(TileShapeH{}, static_cast<int>(max_seq_len), TileShapeD{}, static_cast<int>(num_batches));
|
||||
// Assumes device 0 when getting sm_count.
|
||||
arguments.hw_info.sm_count =
|
||||
sm_count <= 0 ? cutlass::KernelHardwareInfo::query_device_multiprocessor_count(/*device_id=*/0) : sm_count;
|
||||
arguments.split_kv = num_kv_splits;
|
||||
MlaSm100Type::Fmha::set_split_kv(arguments);
|
||||
|
||||
return MlaSm100Type::Fmha::get_workspace_size(arguments);
|
||||
}
|
||||
|
||||
#endif
|
||||
// clang-format on
|
13
csrc/ops.h
13
csrc/ops.h
@ -167,6 +167,19 @@ void cutlass_mla_decode(torch::Tensor const& out, torch::Tensor const& q_nope,
|
||||
torch::Tensor const& seq_lens,
|
||||
torch::Tensor const& page_table, double scale);
|
||||
|
||||
void sm100_cutlass_mla_decode(
|
||||
torch::Tensor const& out, torch::Tensor const& q_nope,
|
||||
torch::Tensor const& q_pe, torch::Tensor const& kv_c_and_k_pe_cache,
|
||||
torch::Tensor const& seq_lens, torch::Tensor const& page_table,
|
||||
torch::Tensor const& workspace, double sm_scale,
|
||||
int64_t num_kv_splits =
|
||||
1 /* Set to 1 to avoid cuda_graph issue by default. */);
|
||||
|
||||
int64_t sm100_cutlass_mla_get_workspace_size(
|
||||
int64_t max_seq_len, int64_t num_batches, int64_t sm_count = 0,
|
||||
int64_t num_kv_splits =
|
||||
1 /* Set to 1 to avoid cuda_graph issue by default. */);
|
||||
|
||||
torch::Tensor get_cuda_view_from_cpu_tensor(torch::Tensor& cpu_tensor);
|
||||
|
||||
#ifndef USE_ROCM
|
||||
|
@ -514,6 +514,23 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
|
||||
" Tensor page_table, float scale) -> ()");
|
||||
ops.impl("cutlass_mla_decode", torch::kCUDA, &cutlass_mla_decode);
|
||||
|
||||
// SM100 CUTLASS MLA decode
|
||||
ops.def(
|
||||
"sm100_cutlass_mla_decode(Tensor! out, Tensor q_nope, Tensor q_pe,"
|
||||
" Tensor kv_c_and_k_pe_cache, Tensor seq_lens,"
|
||||
" Tensor page_table, Tensor workspace, float "
|
||||
"scale,"
|
||||
" int num_kv_splits) -> ()");
|
||||
ops.impl("sm100_cutlass_mla_decode", torch::kCUDA, &sm100_cutlass_mla_decode);
|
||||
|
||||
// SM100 CUTLASS MLA workspace
|
||||
ops.def(
|
||||
"sm100_cutlass_mla_get_workspace_size(int max_seq_len, int num_batches,"
|
||||
" int sm_count, int num_kv_splits) "
|
||||
"-> int");
|
||||
ops.impl("sm100_cutlass_mla_get_workspace_size",
|
||||
&sm100_cutlass_mla_get_workspace_size);
|
||||
|
||||
// Compute NVFP4 block quantized tensor.
|
||||
ops.def(
|
||||
"scaled_fp4_quant(Tensor! output, Tensor input,"
|
||||
|
@ -1843,6 +1843,26 @@ def cutlass_mla_decode(out: torch.Tensor, q_nope: torch.Tensor,
|
||||
return out
|
||||
|
||||
|
||||
def sm100_cutlass_mla_decode(out: torch.Tensor, q_nope: torch.Tensor,
|
||||
q_pe: torch.Tensor,
|
||||
kv_c_and_k_pe_cache: torch.Tensor,
|
||||
seq_lens: torch.Tensor, page_table: torch.Tensor,
|
||||
workspace: torch.Tensor, scale: float,
|
||||
num_kv_splits: int) -> torch.Tensor:
|
||||
torch.ops._C.sm100_cutlass_mla_decode(out, q_nope, q_pe,
|
||||
kv_c_and_k_pe_cache, seq_lens,
|
||||
page_table, workspace, scale,
|
||||
num_kv_splits)
|
||||
return out
|
||||
|
||||
|
||||
def sm100_cutlass_mla_get_workspace_size(max_seq_len: int, num_batches: int,
|
||||
sm_count: int,
|
||||
num_kv_splits: int) -> int:
|
||||
return torch.ops._C.sm100_cutlass_mla_get_workspace_size(
|
||||
max_seq_len, num_batches, sm_count, num_kv_splits)
|
||||
|
||||
|
||||
if hasattr(torch.ops._C, "weight_packed_linear"):
|
||||
|
||||
@register_fake("_C::weight_packed_linear")
|
||||
|
@ -1502,6 +1502,7 @@ class EngineArgs:
|
||||
"TRITON_ATTN_VLLM_V1",
|
||||
"TRITON_MLA",
|
||||
"CUTLASS_MLA_VLLM_V1",
|
||||
"SM100_CUTLASS_MLA_VLLM_V1",
|
||||
"FLASHMLA",
|
||||
"FLASHINFER",
|
||||
"FLASHINFER_VLLM_V1",
|
||||
|
@ -166,6 +166,13 @@ class CudaPlatformBase(Platform):
|
||||
logger.info(
|
||||
"Forcing kv cache block size to 64 for FlashMLA backend.")
|
||||
|
||||
use_sm100_cutlass = (envs.VLLM_ATTENTION_BACKEND is not None \
|
||||
and envs.VLLM_ATTENTION_BACKEND == "SM100_CUTLASS_MLA_VLLM_V1")
|
||||
if use_sm100_cutlass and cache_config.block_size != 128:
|
||||
cache_config.block_size = 128
|
||||
logger.info("Forcing kv cache block size to 128 for "
|
||||
"SM100_CUTLASS_MLA_VLLM_V1 backend.")
|
||||
|
||||
compilation_config = vllm_config.compilation_config
|
||||
if (envs.VLLM_ALL2ALL_BACKEND == "deepep_high_throughput"
|
||||
and parallel_config.data_parallel_size > 1
|
||||
@ -196,6 +203,17 @@ class CudaPlatformBase(Platform):
|
||||
kv_cache_dtype, block_size, use_v1,
|
||||
use_mla) -> str:
|
||||
if use_mla:
|
||||
if selected_backend == _Backend.SM100_CUTLASS_MLA_VLLM_V1:
|
||||
if use_v1:
|
||||
logger.info_once(
|
||||
"Using Sm100Cutlass MLA backend on V1 engine.")
|
||||
return ("vllm.v1.attention.backends.mla."
|
||||
"sm100_cutlass_mla.Sm100CutlassMLABackend")
|
||||
else:
|
||||
logger.warning(
|
||||
"Sm100 Cutlass MLA backend is only supported on "
|
||||
"V1 engine")
|
||||
|
||||
# TODO(lucas): refactor to be more concise
|
||||
# we should probably consider factoring out V1 here
|
||||
if selected_backend == _Backend.CUTLASS_MLA_VLLM_V1:
|
||||
|
@ -54,6 +54,7 @@ class _Backend(enum.Enum):
|
||||
FLASHMLA_VLLM_V1 = enum.auto()
|
||||
FLASHMLA = enum.auto() # Supported by V1
|
||||
CUTLASS_MLA_VLLM_V1 = enum.auto()
|
||||
SM100_CUTLASS_MLA_VLLM_V1 = enum.auto()
|
||||
HPU_ATTN = enum.auto()
|
||||
PALLAS = enum.auto()
|
||||
PALLAS_VLLM_V1 = enum.auto()
|
||||
|
201
vllm/v1/attention/backends/mla/sm100_cutlass_mla.py
Normal file
201
vllm/v1/attention/backends/mla/sm100_cutlass_mla.py
Normal file
@ -0,0 +1,201 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
from typing import Any, Optional
|
||||
|
||||
import torch
|
||||
|
||||
import vllm._custom_ops as ops
|
||||
from vllm.attention.backends.abstract import (AttentionType,
|
||||
is_quantized_kv_cache)
|
||||
from vllm.logger import init_logger
|
||||
from vllm.v1.attention.backends.mla.common import (MLACommonBackend,
|
||||
MLACommonImpl,
|
||||
MLACommonMetadata)
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
class Sm100CutlassMLABackend(MLACommonBackend):
|
||||
|
||||
@staticmethod
|
||||
def get_name() -> str:
|
||||
return "SM100_CUTLASS_MLA_VLLM_V1"
|
||||
|
||||
@staticmethod
|
||||
def get_impl_cls() -> type["Sm100CutlassMLAImpl"]:
|
||||
return Sm100CutlassMLAImpl
|
||||
|
||||
|
||||
workspace_size = 1024 * 1024 * 1024 # 1GB
|
||||
g_workspace = torch.zeros(workspace_size, device="cuda", dtype=torch.uint8)
|
||||
|
||||
|
||||
class Sm100CutlassMLAImpl(MLACommonImpl[MLACommonMetadata]):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
num_heads: int,
|
||||
head_size: int,
|
||||
scale: float,
|
||||
num_kv_heads: int,
|
||||
alibi_slopes: Optional[list[float]],
|
||||
sliding_window: Optional[int],
|
||||
kv_cache_dtype: str,
|
||||
blocksparse_params: Optional[dict[str, Any]],
|
||||
logits_soft_cap: Optional[float],
|
||||
attn_type: str,
|
||||
kv_sharing_target_layer_name: Optional[str],
|
||||
# MLA Specific Arguments
|
||||
**mla_args) -> None:
|
||||
super().__init__(num_heads, head_size, scale, num_kv_heads,
|
||||
alibi_slopes, sliding_window, kv_cache_dtype,
|
||||
blocksparse_params, logits_soft_cap, attn_type,
|
||||
kv_sharing_target_layer_name, **mla_args)
|
||||
|
||||
unsupported_features = [
|
||||
alibi_slopes, sliding_window, blocksparse_params, logits_soft_cap
|
||||
]
|
||||
if any(unsupported_features):
|
||||
raise NotImplementedError(
|
||||
"CutlassMLAImpl does not support one of the following: "
|
||||
"alibi_slopes, sliding_window, blocksparse_params, "
|
||||
"logits_soft_cap")
|
||||
|
||||
if attn_type != AttentionType.DECODER:
|
||||
raise NotImplementedError("Encoder self-attention and "
|
||||
"encoder/decoder cross-attention "
|
||||
"are not implemented for "
|
||||
"CutlassMLAImpl")
|
||||
|
||||
if is_quantized_kv_cache(self.kv_cache_dtype):
|
||||
raise NotImplementedError(
|
||||
"CutlassMLA V1 with FP8 KV cache not yet supported")
|
||||
|
||||
# TODO: Ensure the kernel does not hang.
|
||||
# If it happens, then set _num_kv_splits to 1
|
||||
self._num_kv_splits = 8
|
||||
|
||||
# This is a workaround to set a workspace for the sm100 cutlass mla.
|
||||
# We do need to execute sm100_cutlass_mla_get_workspace_size(..)
|
||||
# since it calls set_kv_split() internally, however, the workspace
|
||||
# size we use is large and shared over all executions.
|
||||
# TODO: Refactor this code to a place that can observe params
|
||||
# more fine-grained
|
||||
bs = 128 # dummy
|
||||
PAGE_SIZE = 128 # dummy
|
||||
max_seqlen_pad = 131072 # dummy
|
||||
sm_count = 0 # dummy
|
||||
_ = ops.sm100_cutlass_mla_get_workspace_size(
|
||||
max_seqlen_pad * PAGE_SIZE,
|
||||
bs,
|
||||
sm_count,
|
||||
num_kv_splits=self._num_kv_splits)
|
||||
|
||||
self._workspace = g_workspace
|
||||
|
||||
def sm100_cutlass_mla_decode(
|
||||
self,
|
||||
q_nope: torch.Tensor,
|
||||
q_pe: torch.Tensor,
|
||||
kv_c_and_k_pe_cache: torch.Tensor,
|
||||
seq_lens: torch.Tensor,
|
||||
page_table: torch.Tensor,
|
||||
workspace: torch.Tensor,
|
||||
sm_scale: float,
|
||||
num_kv_splits:
|
||||
int = 1, # Set to 1 to avoid cuda_graph issue by default.
|
||||
) -> torch.Tensor:
|
||||
assert (q_nope.ndim == 3
|
||||
), f"q_nope must be a 3D tensor, but got {q_nope.ndim}"
|
||||
assert (
|
||||
q_pe.ndim == 3), f"q_pe must be a 3D tensor, but got {q_pe.ndim}"
|
||||
assert (
|
||||
kv_c_and_k_pe_cache.ndim == 3
|
||||
), "kv_c_and_k_pe_cache must be a 3D tensor, but got {}".format(
|
||||
kv_c_and_k_pe_cache.ndim)
|
||||
|
||||
B_q, H, D_q_nope = q_nope.shape
|
||||
B_q_2, H_2, D_q_pe = q_pe.shape
|
||||
assert (B_q == B_q_2) and (H == H_2)
|
||||
|
||||
_, PAGE_SIZE, D_ckv = kv_c_and_k_pe_cache.shape
|
||||
|
||||
D_latent = 512
|
||||
D_rope = 64
|
||||
assert D_q_nope == D_latent
|
||||
assert D_q_pe == D_rope
|
||||
assert D_ckv == D_latent + D_rope
|
||||
|
||||
MAX_HEADS = 128
|
||||
assert H <= MAX_HEADS, f"H must be <= {MAX_HEADS}, but got {H}"
|
||||
if H < MAX_HEADS:
|
||||
q_nope_padded = q_nope.new_empty((B_q, MAX_HEADS, D_q_nope))
|
||||
q_nope_padded[:, :H] = q_nope
|
||||
q_nope = q_nope_padded
|
||||
|
||||
q_pe_padded = q_pe.new_empty((B_q, MAX_HEADS, D_q_pe))
|
||||
q_pe_padded[:, :H] = q_pe
|
||||
q_pe = q_pe_padded
|
||||
|
||||
assert len(page_table.shape) == 2
|
||||
B_block_table, block_num = page_table.shape
|
||||
assert B_block_table == B_q
|
||||
assert (block_num
|
||||
> 0), f"block num must be greater than 0, got {block_num}"
|
||||
assert block_num % (128 / PAGE_SIZE) == 0
|
||||
|
||||
# TODO(kaixih@nvidia): support fp8
|
||||
assert q_nope.dtype in (
|
||||
torch.float16,
|
||||
torch.bfloat16,
|
||||
), f"q_nope.dtype needs to be fp16 or bf16 but got {q_nope.dtype}."
|
||||
assert q_nope.dtype == q_pe.dtype == kv_c_and_k_pe_cache.dtype
|
||||
assert (
|
||||
seq_lens.dtype == torch.int32
|
||||
), f"seq_lens.dtype needs to be int32 but got {seq_lens.dtype}."
|
||||
assert (
|
||||
page_table.dtype == torch.int32
|
||||
), f"page_table.dtype needs to be int32 but got {page_table.dtype}."
|
||||
|
||||
out = q_nope.new_empty((B_q, MAX_HEADS, D_latent))
|
||||
|
||||
ops.sm100_cutlass_mla_decode(
|
||||
out,
|
||||
q_nope,
|
||||
q_pe,
|
||||
kv_c_and_k_pe_cache,
|
||||
seq_lens,
|
||||
page_table,
|
||||
workspace,
|
||||
sm_scale,
|
||||
num_kv_splits,
|
||||
)
|
||||
return out[:, :H].contiguous()
|
||||
|
||||
def _forward_decode(
|
||||
self,
|
||||
q_nope: torch.Tensor,
|
||||
q_pe: torch.Tensor,
|
||||
kv_c_and_k_pe_cache: torch.Tensor,
|
||||
attn_metadata: MLACommonMetadata,
|
||||
) -> torch.Tensor:
|
||||
assert kv_c_and_k_pe_cache.numel() > 0
|
||||
assert attn_metadata.decode is not None
|
||||
|
||||
if self.kv_cache_dtype.startswith("fp8"):
|
||||
raise NotImplementedError("FP8 Cutlass MLA not yet supported")
|
||||
|
||||
# Run MLA
|
||||
# Clone q_nope and q_pe to make sure strides computation is correct.
|
||||
# TODO: Check if we really need it
|
||||
q_nope = q_nope.clone()
|
||||
q_pe = q_pe.clone()
|
||||
|
||||
o = self.sm100_cutlass_mla_decode(q_nope, q_pe, kv_c_and_k_pe_cache,
|
||||
attn_metadata.decode.seq_lens,
|
||||
attn_metadata.decode.block_table,
|
||||
self._workspace, self.scale,
|
||||
self._num_kv_splits)
|
||||
|
||||
return self._v_up_proj(o)
|
Reference in New Issue
Block a user