[Kernel] Add support for block FP8 on SM120 (NVIDIA 5090 and RTX PRO 6000) (#22131)

Signed-off-by: Junhao Li <junhao@ubicloud.com>
This commit is contained in:
Junhao Li
2025-08-07 22:18:28 -04:00
committed by GitHub
parent b2c8ce57c6
commit 3303f134e0
6 changed files with 229 additions and 18 deletions

View File

@ -427,6 +427,7 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
set(SRCS
"csrc/quantization/cutlass_w8a8/scaled_mm_c3x_sm120.cu"
"csrc/quantization/cutlass_w8a8/c3x/scaled_mm_sm120_fp8.cu"
"csrc/quantization/cutlass_w8a8/c3x/scaled_mm_blockwise_sm120_fp8.cu"
)
set_gencode_flags_for_srcs(
SRCS "${SRCS}"

View File

@ -60,3 +60,13 @@ struct enable_sm100_only : Kernel {
#endif
}
};
template <typename Kernel>
struct enable_sm120_only : Kernel {
template <typename... Args>
CUTLASS_DEVICE void operator()(Args&&... args) {
#if defined __CUDA_ARCH__ && __CUDA_ARCH__ == 1200
Kernel::operator()(std::forward<Args>(args)...);
#endif
}
};

View File

@ -0,0 +1,23 @@
#include "scaled_mm_kernels.hpp"
#include "scaled_mm_blockwise_sm120_fp8_dispatch.cuh"
#include "cutlass_extensions/epilogue/scaled_mm_epilogues_c3x.hpp"
namespace vllm {
void cutlass_scaled_mm_blockwise_sm120_fp8(torch::Tensor& out,
torch::Tensor const& a,
torch::Tensor const& b,
torch::Tensor const& a_scales,
torch::Tensor const& b_scales) {
if (out.dtype() == torch::kBFloat16) {
cutlass_gemm_blockwise_sm120_fp8_dispatch<cutlass::bfloat16_t>(
out, a, b, a_scales, b_scales);
} else {
TORCH_CHECK(out.dtype() == torch::kFloat16);
cutlass_gemm_blockwise_sm120_fp8_dispatch<cutlass::half_t>(
out, a, b, a_scales, b_scales);
}
}
} // namespace vllm

View File

@ -0,0 +1,183 @@
#pragma once
#include "cuda_utils.h"
#include "cutlass/cutlass.h"
#include "cutlass/numeric_types.h"
#include "cute/tensor.hpp"
#include "cutlass/tensor_ref.h"
#include "cutlass/gemm/dispatch_policy.hpp"
#include "cutlass/gemm/collective/collective_builder.hpp"
#include "cutlass/gemm/device/gemm_universal_adapter.h"
#include "cutlass/gemm/kernel/gemm_universal.hpp"
#include "cutlass/gemm/kernel/tile_scheduler_params.h"
#include "cutlass/epilogue/dispatch_policy.hpp"
#include "cutlass/epilogue/collective/collective_builder.hpp"
#include "cutlass_extensions/gemm/dispatch_policy.hpp"
#include "cutlass_extensions/gemm/collective/collective_builder.hpp"
#include "cutlass_gemm_caller.cuh"
namespace vllm {
using namespace cute;
// clang-format off
template <class OutType, int ScaleGranularityM,
int ScaleGranularityN, int ScaleGranularityK,
class MmaTileShape, class ClusterShape,
class EpilogueScheduler, class MainloopScheduler>
struct cutlass_3x_gemm_fp8_blockwise {
using ElementAB = cutlass::float_e4m3_t;
using ElementA = ElementAB;
using LayoutA = cutlass::layout::RowMajor;
using LayoutA_Transpose = typename cutlass::layout::LayoutTranspose<LayoutA>::type;
static constexpr int AlignmentA = 128 / cutlass::sizeof_bits<ElementA>::value;
using ElementB = ElementAB;
// ColumnMajor is used for B to match the CUTLASS convention.
using LayoutB = cutlass::layout::ColumnMajor;
using LayoutB_Transpose = typename cutlass::layout::LayoutTranspose<LayoutB>::type;
static constexpr int AlignmentB = 128 / cutlass::sizeof_bits<ElementB>::value;
using ElementD = OutType;
using LayoutD = cutlass::layout::RowMajor;
using LayoutD_Transpose = typename cutlass::layout::LayoutTranspose<LayoutD>::type;
static constexpr int AlignmentD = 128 / cutlass::sizeof_bits<ElementD>::value;
using ElementC = void; // TODO: support bias
using LayoutC = LayoutD;
using LayoutC_Transpose = LayoutD_Transpose;
static constexpr int AlignmentC = AlignmentD;
using ElementAccumulator = float;
using ElementCompute = float;
using ElementBlockScale = float;
using ScaleConfig = cutlass::detail::Sm120BlockwiseScaleConfig<
ScaleGranularityM, ScaleGranularityN, ScaleGranularityK,
cute::UMMA::Major::MN, cute::UMMA::Major::K>;
// layout_SFA and layout_SFB cannot be swapped since they are deduced.
using LayoutSFA = decltype(ScaleConfig::deduce_layoutSFA());
using LayoutSFB = decltype(ScaleConfig::deduce_layoutSFB());
using ArchTag = cutlass::arch::Sm120;
using OperatorClass = cutlass::arch::OpClassTensorOp;
static constexpr auto RoundStyle = cutlass::FloatRoundStyle::round_to_nearest;
using ElementScalar = float;
using DefaultOperation = cutlass::epilogue::fusion::LinearCombination<ElementD, ElementCompute, ElementC, ElementScalar, RoundStyle>;
using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder<
ArchTag,
OperatorClass,
MmaTileShape,
ClusterShape,
cutlass::epilogue::collective::EpilogueTileAuto,
ElementAccumulator,
ElementCompute,
ElementC,
LayoutC,
AlignmentC,
ElementD,
LayoutD,
AlignmentD,
EpilogueScheduler,
DefaultOperation
>::CollectiveOp;
using StageCountType = cutlass::gemm::collective::StageCountAuto;
using CollectiveMainloop =
typename cutlass::gemm::collective::CollectiveBuilder<
ArchTag,
OperatorClass,
ElementA,
cute::tuple<LayoutA, LayoutSFA>,
AlignmentA,
ElementB,
cute::tuple<LayoutB, LayoutSFB>,
AlignmentB,
ElementAccumulator,
MmaTileShape,
ClusterShape,
cutlass::gemm::collective::StageCountAutoCarveout<static_cast<int>(sizeof(typename CollectiveEpilogue::SharedStorage))>,
MainloopScheduler
>::CollectiveOp;
using KernelType = enable_sm120_only<cutlass::gemm::kernel::GemmUniversal<
Shape<int, int, int, int>, CollectiveMainloop, CollectiveEpilogue>>;
struct GemmKernel : public KernelType {};
};
template <typename Gemm>
void cutlass_gemm_caller_blockwise(torch::Tensor& out, torch::Tensor const& a,
torch::Tensor const& b,
torch::Tensor const& a_scales,
torch::Tensor const& b_scales) {
using GemmKernel = typename Gemm::GemmKernel;
using StrideA = typename Gemm::GemmKernel::StrideA;
using StrideB = typename Gemm::GemmKernel::StrideB;
using StrideD = typename Gemm::GemmKernel::StrideD;
using StrideC = typename Gemm::GemmKernel::StrideC;
using LayoutSFA = typename Gemm::LayoutSFA;
using LayoutSFB = typename Gemm::LayoutSFB;
using ScaleConfig = typename Gemm::ScaleConfig;
using ElementAB = typename Gemm::ElementAB;
using ElementD = typename Gemm::ElementD;
int32_t m = a.size(0), n = b.size(1), k = a.size(1);
StrideA a_stride;
StrideB b_stride;
StrideC c_stride;
a_stride =
cutlass::make_cute_packed_stride(StrideA{}, cute::make_shape(m, k, 1));
b_stride =
cutlass::make_cute_packed_stride(StrideB{}, cute::make_shape(n, k, 1));
c_stride =
cutlass::make_cute_packed_stride(StrideC{}, cute::make_shape(m, n, 1));
LayoutSFA layout_SFA =
ScaleConfig::tile_atom_to_shape_SFA(make_shape(m, n, k, 1));
LayoutSFB layout_SFB =
ScaleConfig::tile_atom_to_shape_SFB(make_shape(m, n, k, 1));
auto a_ptr = static_cast<ElementAB*>(a.data_ptr());
auto b_ptr = static_cast<ElementAB*>(b.data_ptr());
auto a_scales_ptr = static_cast<float*>(a_scales.data_ptr());
auto b_scales_ptr = static_cast<float*>(b_scales.data_ptr());
auto mainloop_args = [&](){
return typename GemmKernel::MainloopArguments{
a_ptr, a_stride, b_ptr, b_stride,
a_scales_ptr, layout_SFA, b_scales_ptr, layout_SFB
};
}();
auto prob_shape = cute::make_shape(m, n, k, 1);
auto c_ptr = static_cast<ElementD*>(out.data_ptr());
typename GemmKernel::EpilogueArguments epilogue_args{
{}, c_ptr, c_stride, c_ptr, c_stride};
c3x::cutlass_gemm_caller<GemmKernel>(a.device(), prob_shape, mainloop_args,
epilogue_args);
}
template <typename OutType>
void cutlass_gemm_blockwise_sm120_fp8_dispatch(torch::Tensor& out,
torch::Tensor const& a,
torch::Tensor const& b,
torch::Tensor const& a_scales,
torch::Tensor const& b_scales) {
// TODO: better heuristics
cutlass_gemm_caller_blockwise<cutlass_3x_gemm_fp8_blockwise<
OutType, 1, 128, 128, Shape<_128, _128, _128>,
Shape<_1, _1, _1>, cutlass::epilogue::collective::EpilogueScheduleAuto,
cutlass::gemm::collective::KernelScheduleAuto>>(
out, a, b, a_scales, b_scales);
}
} // namespace vllm

View File

@ -47,4 +47,10 @@ void cutlass_scaled_mm_blockwise_sm100_fp8(torch::Tensor& out,
torch::Tensor const& b,
torch::Tensor const& a_scales,
torch::Tensor const& b_scales);
void cutlass_scaled_mm_blockwise_sm120_fp8(torch::Tensor& out,
torch::Tensor const& a,
torch::Tensor const& b,
torch::Tensor const& a_scales,
torch::Tensor const& b_scales);
} // namespace vllm

View File

@ -1,11 +1,9 @@
#include <cudaTypedefs.h>
#include "c3x/scaled_mm_helper.hpp"
#include "c3x/scaled_mm_kernels.hpp"
#include "cuda_utils.h"
/*
This file defines quantized GEMM operations using the CUTLASS 3.x API, for
NVIDIA GPUs with sm120 (Blackwell Geforce).
NVIDIA GPUs with sm120 (Blackwell).
*/
#if defined ENABLE_SCALED_MM_SM120 && ENABLE_SCALED_MM_SM120
@ -15,20 +13,10 @@ void cutlass_scaled_mm_sm120(torch::Tensor& c, torch::Tensor const& a,
torch::Tensor const& a_scales,
torch::Tensor const& b_scales,
std::optional<torch::Tensor> const& bias) {
TORCH_CHECK(a_scales.dtype() == torch::kFloat32);
TORCH_CHECK(b_scales.dtype() == torch::kFloat32);
int M = a.size(0), N = b.size(1), K = a.size(1);
TORCH_CHECK(
(a_scales.numel() == 1 || a_scales.numel() == a.size(0)) &&
(b_scales.numel() == 1 || b_scales.numel() == b.size(1)),
"Currently, block scaled fp8 gemm is not implemented for Blackwell");
// Standard per-tensor/per-token/per-channel scaling
TORCH_CHECK(a_scales.is_contiguous() && b_scales.is_contiguous());
TORCH_CHECK(a.dtype() == torch::kFloat8_e4m3fn,
"Currently, only fp8 gemm is implemented for Blackwell");
vllm::cutlass_scaled_mm_sm120_fp8(c, a, b, a_scales, b_scales, bias);
dispatch_scaled_mm(c, a, b, a_scales, b_scales, bias,
vllm::cutlass_scaled_mm_sm120_fp8,
nullptr, // int8 not supported on SM120
vllm::cutlass_scaled_mm_blockwise_sm120_fp8);
}
#endif