[feat]: CUTLASS block scaled group gemm for SM100 (#19757)

Signed-off-by: Duncan Moss <djm.moss@gmail.com>
Co-authored-by: Duncan Moss <dmoss@nvidia.com>
This commit is contained in:
Duncan Moss
2025-07-04 11:58:04 -07:00
committed by GitHub
parent 2f35a022e6
commit 3d184b95b8
13 changed files with 726 additions and 30 deletions

View File

@ -259,7 +259,7 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
SET(CUTLASS_ENABLE_HEADERS_ONLY ON CACHE BOOL "Enable only the header library")
# Set CUTLASS_REVISION. Used for FetchContent. Also fixes some bogus messages when building.
set(CUTLASS_REVISION "v3.9.2" CACHE STRING "CUTLASS revision to use")
set(CUTLASS_REVISION "v4.0.0" CACHE STRING "CUTLASS revision to use")
# Use the specified CUTLASS source directory for compilation if VLLM_CUTLASS_SRC_DIR is provided
if (DEFINED ENV{VLLM_CUTLASS_SRC_DIR})
@ -615,6 +615,26 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
"in CUDA target architectures.")
endif()
endif()
cuda_archs_loose_intersection(SCALED_MM_ARCHS "10.0a" "${CUDA_ARCHS}")
if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 12.8 AND SCALED_MM_ARCHS)
set(SRCS "csrc/quantization/cutlass_w8a8/moe/blockwise_scaled_group_mm_sm100.cu")
set_gencode_flags_for_srcs(
SRCS "${SRCS}"
CUDA_ARCHS "${SCALED_MM_ARCHS}")
list(APPEND VLLM_EXT_SRC "${SRCS}")
list(APPEND VLLM_GPU_FLAGS "-DENABLE_CUTLASS_MOE_SM100=1")
message(STATUS "Building blockwise_scaled_group_mm_sm100 for archs: ${SCALED_MM_ARCHS}")
else()
if (NOT ${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 12.8 AND SCALED_MM_ARCHS)
message(STATUS "Not building blockwise_scaled_group_mm_sm100 kernels as CUDA Compiler version is "
"not >= 12.8, we recommend upgrading to CUDA 12.8 or later "
"if you intend on running FP8 quantized MoE models on Blackwell.")
else()
message(STATUS "Not building blockwise_scaled_group_mm_sm100 as no compatible archs found "
"in CUDA target architectures")
endif()
endif()
#
# Machete kernels

View File

@ -45,7 +45,6 @@
#include "cute/algorithm/functional.hpp"
#include "cute/atom/mma_atom.hpp"
#include "cute/algorithm/gemm.hpp"
#include "cute/tensor_predicate.hpp"
#include "cute/numeric/arithmetic_tuple.hpp"
#include "cutlass_extensions/gemm/dispatch_policy.hpp"

View File

@ -239,6 +239,11 @@ void cutlass_moe_mm(
torch::Tensor const& b_strides, torch::Tensor const& c_strides,
bool per_act_token, bool per_out_ch);
void cutlass_blockwise_scaled_grouped_mm(
torch::Tensor& output, const torch::Tensor& a, const torch::Tensor& b,
const torch::Tensor& scales_a, const torch::Tensor& scales_b,
const torch::Tensor& problem_sizes, const torch::Tensor& expert_offsets);
void cutlass_fp4_group_mm(
torch::Tensor& output, const torch::Tensor& a, const torch::Tensor& b,
const torch::Tensor& a_blockscale, const torch::Tensor& b_blockscales,

View File

@ -51,7 +51,8 @@ struct cutlass_3x_gemm {
// These are the minimum alignments needed for the kernels to compile
static constexpr int AlignmentAB =
128 / cutlass::sizeof_bits<ElementAB>::value;
static constexpr int AlignmentCD = 4;
static constexpr int AlignmentCD =
128 / cutlass::sizeof_bits<ElementD>::value;
using CollectiveEpilogue =
typename cutlass::epilogue::collective::CollectiveBuilder<

View File

@ -0,0 +1,367 @@
#include <torch/all.h>
#include <cutlass/arch/arch.h>
#include <ATen/cuda/CUDAContext.h>
#include <c10/cuda/CUDAGuard.h>
#include <c10/cuda/CUDAStream.h>
#include "cute/tensor.hpp"
#include "cutlass/tensor_ref.h"
#include "cutlass/epilogue/collective/default_epilogue.hpp"
#include "cutlass/epilogue/thread/linear_combination.h"
#include "cutlass/gemm/dispatch_policy.hpp"
#include "cutlass/gemm/group_array_problem_shape.hpp"
#include "cutlass/gemm/collective/collective_builder.hpp"
#include "cutlass/epilogue/collective/collective_builder.hpp"
#include "cutlass/gemm/device/gemm_universal_adapter.h"
#include "cutlass/gemm/kernel/gemm_universal.hpp"
#include "cutlass/util/command_line.h"
#include "cutlass/util/distribution.h"
#include "cutlass/util/host_tensor.h"
#include "cutlass/util/packed_stride.hpp"
#include "cutlass/util/tensor_view_io.h"
#include "cutlass/util/reference/device/gemm.h"
#include "cutlass/util/reference/device/tensor_compare.h"
#include "cutlass/util/reference/host/tensor_fill.h"
#include "cutlass/util/reference/host/gett.hpp"
#include "cutlass/util/reference/host/tensor_norm.h"
#include "cutlass/util/reference/host/tensor_compare.h"
#include <cassert>
using namespace cute;
template <typename ElementAB, typename ElementC, typename ElementAccumulator,
typename LayoutSFA, typename LayoutSFB, typename ScaleConfig>
__global__ void get_ggemm_starts(
int32_t* expert_offsets, ElementAB** a_offsets, ElementAB** b_offsets,
ElementC** out_offsets, ElementAccumulator** a_scale_offsets,
ElementAccumulator** b_scale_offsets, ElementAB* a_base_as_int,
ElementAB* b_base_as_int, ElementC* out_base_as_int,
ElementAccumulator* a_scale_base_as_int,
ElementAccumulator* b_scale_base_as_int, LayoutSFA* layout_sfa_base_as_int,
LayoutSFB* layout_sfb_base_as_int, int* problem_sizes) {
int expert_id = threadIdx.x;
if (expert_id >= gridDim.x * blockDim.x) {
return;
}
int m = problem_sizes[expert_id * 3];
int n = problem_sizes[expert_id * 3 + 1];
int k = problem_sizes[expert_id * 3 + 2];
int32_t expert_offset = expert_offsets[expert_id];
int a_stride = expert_offset * k;
int b_stride = expert_id * k * n;
int a_scale_stride = expert_offset * k / 128;
int b_scale_stride = expert_id * k * n / 128 / 128;
a_offsets[expert_id] = a_base_as_int + a_stride;
b_offsets[expert_id] = b_base_as_int + b_stride;
out_offsets[expert_id] = out_base_as_int + expert_offset * n;
a_scale_offsets[expert_id] = a_scale_base_as_int + a_scale_stride;
b_scale_offsets[expert_id] = b_scale_base_as_int + b_scale_stride;
LayoutSFA* layout_sfa_ptr = layout_sfa_base_as_int + expert_id;
LayoutSFB* layout_sfb_ptr = layout_sfb_base_as_int + expert_id;
*layout_sfa_ptr =
ScaleConfig::tile_atom_to_shape_SFA(cute::make_shape(m, n, k, 1));
*layout_sfb_ptr =
ScaleConfig::tile_atom_to_shape_SFB(cute::make_shape(m, n, k, 1));
}
#define __CALL_GET_STARTS_KERNEL(TENSOR_C_TYPE, C_TYPE, LayoutSFA, LayoutSFB, \
ScaleConfig) \
else if (out_tensors.dtype() == TENSOR_C_TYPE) { \
get_ggemm_starts<cutlass::float_e4m3_t, C_TYPE, float, LayoutSFA, \
LayoutSFB, ScaleConfig><<<1, num_experts, 0, stream>>>( \
static_cast<int32_t*>(expert_offsets.data_ptr()), \
static_cast<cutlass::float_e4m3_t**>(a_ptrs.data_ptr()), \
static_cast<cutlass::float_e4m3_t**>(b_ptrs.data_ptr()), \
static_cast<C_TYPE**>(out_ptrs.data_ptr()), \
static_cast<float**>(a_scales_ptrs.data_ptr()), \
static_cast<float**>(b_scales_ptrs.data_ptr()), \
static_cast<cutlass::float_e4m3_t*>(a_tensors.data_ptr()), \
static_cast<cutlass::float_e4m3_t*>(b_tensors.data_ptr()), \
static_cast<C_TYPE*>(out_tensors.data_ptr()), \
static_cast<float*>(a_scales.data_ptr()), \
static_cast<float*>(b_scales.data_ptr()), \
reinterpret_cast<LayoutSFA*>(layout_sfa.data_ptr()), \
reinterpret_cast<LayoutSFB*>(layout_sfb.data_ptr()), \
static_cast<int*>(problem_sizes.data_ptr())); \
}
template <typename LayoutSFA, typename LayoutSFB, typename ScaleConfig>
void run_get_ggemm_starts(
torch::Tensor const& expert_offsets, torch::Tensor& a_ptrs,
torch::Tensor& b_ptrs, torch::Tensor& out_ptrs,
torch::Tensor& a_scales_ptrs, torch::Tensor& b_scales_ptrs,
torch::Tensor const& a_tensors, torch::Tensor const& b_tensors,
torch::Tensor out_tensors, torch::Tensor const& a_scales,
torch::Tensor const& b_scales, torch::Tensor const& layout_sfa,
torch::Tensor const& layout_sfb, torch::Tensor const& problem_sizes) {
TORCH_CHECK(a_tensors.dtype() == torch::kFloat8_e4m3fn);
TORCH_CHECK(b_tensors.dtype() == torch::kFloat8_e4m3fn);
TORCH_CHECK(a_scales.dtype() == torch::kFloat32);
TORCH_CHECK(b_scales.dtype() == torch::kFloat32);
TORCH_CHECK(out_tensors.size(1) % 128 == 0 or out_tensors.size(0) % 128 == 0);
TORCH_CHECK(a_tensors.size(1) % 128 == 0 or a_tensors.size(0) % 128 == 0);
int num_experts = (int)expert_offsets.size(0);
auto stream = at::cuda::getCurrentCUDAStream(a_tensors.device().index());
if (false) {
}
__CALL_GET_STARTS_KERNEL(torch::kBFloat16, cutlass::bfloat16_t, LayoutSFA,
LayoutSFB, ScaleConfig)
__CALL_GET_STARTS_KERNEL(torch::kFloat16, cutlass::half_t, LayoutSFA,
LayoutSFB, ScaleConfig)
else {
TORCH_CHECK(false, "Unsupported output tensor type");
}
}
template <typename OutType, typename ScheduleConfig, typename LayoutD>
void run_blockwise_scaled_group_mm(
torch::Tensor& out_ptrs, const torch::Tensor& a_ptrs,
const torch::Tensor& b_ptrs, const torch::Tensor& a_scales_ptrs,
const torch::Tensor& b_scales_ptrs, const torch::Tensor& stride_a,
const torch::Tensor& stride_b, const torch::Tensor& stride_c,
const torch::Tensor& layout_sfa, const torch::Tensor& layout_sfb,
const torch::Tensor& problem_sizes, const torch::Tensor& expert_offsets) {
using ProblemShape = cutlass::gemm::GroupProblemShape<Shape<int, int, int>>;
// Types
using ElementA = cutlass::float_e4m3_t;
using ElementB = cutlass::float_e4m3_t;
using ElementC = OutType;
using ElementD = ElementC;
using ElementAccumulator = float;
using LayoutA = cutlass::layout::RowMajor;
using LayoutB = cutlass::layout::ColumnMajor;
using LayoutC = LayoutD;
// Alignments
static constexpr int AlignmentA = 128 / cutlass::sizeof_bits<ElementA>::value;
static constexpr int AlignmentB = 128 / cutlass::sizeof_bits<ElementB>::value;
static constexpr int AlignmentC = 128 / cutlass::sizeof_bits<ElementC>::value;
using ArchTag = cutlass::arch::Sm100;
using OperatorClass = cutlass::arch::OpClassTensorOp;
using CollectiveEpilogue =
typename cutlass::epilogue::collective::CollectiveBuilder<
ArchTag, OperatorClass, typename ScheduleConfig::MmaTileShape,
typename ScheduleConfig::ClusterShape,
cutlass::epilogue::collective::EpilogueTileAuto, ElementAccumulator,
ElementAccumulator, void, LayoutC*, AlignmentC, ElementD, LayoutC*,
AlignmentC, typename ScheduleConfig::EpilogueSchedule>::CollectiveOp;
using CollectiveMainloop =
typename cutlass::gemm::collective::CollectiveBuilder<
ArchTag, OperatorClass, ElementA,
cute::tuple<LayoutA*, typename ScheduleConfig::LayoutSFA*>,
AlignmentA, ElementB,
cute::tuple<LayoutB*, typename ScheduleConfig::LayoutSFB*>,
AlignmentB, ElementAccumulator, typename ScheduleConfig::MmaTileShape,
typename ScheduleConfig::ClusterShape,
cutlass::gemm::collective::StageCountAutoCarveout<static_cast<int>(
sizeof(typename CollectiveEpilogue::SharedStorage))>,
typename ScheduleConfig::KernelSchedule>::CollectiveOp;
using GemmKernel =
cutlass::gemm::kernel::GemmUniversal<ProblemShape, CollectiveMainloop,
CollectiveEpilogue, void>;
using Gemm = cutlass::gemm::device::GemmUniversalAdapter<GemmKernel>;
using StrideA = typename Gemm::GemmKernel::InternalStrideA;
using StrideB = typename Gemm::GemmKernel::InternalStrideB;
using StrideC = typename Gemm::GemmKernel::InternalStrideC;
using StrideD = typename Gemm::GemmKernel::InternalStrideD;
using UnderlyingProblemShape = ProblemShape::UnderlyingProblemShape;
int num_experts = (int)expert_offsets.size(0);
Gemm gemm_op;
// Mainloop Arguments
typename GemmKernel::MainloopArguments mainloop_args{
static_cast<const ElementA**>(a_ptrs.data_ptr()),
static_cast<StrideA*>(stride_a.data_ptr()),
static_cast<const ElementB**>(b_ptrs.data_ptr()),
static_cast<StrideB*>(stride_b.data_ptr()),
static_cast<const ElementAccumulator**>(a_scales_ptrs.data_ptr()),
reinterpret_cast<typename ScheduleConfig::LayoutSFA*>(
layout_sfa.data_ptr()),
static_cast<const ElementAccumulator**>(b_scales_ptrs.data_ptr()),
reinterpret_cast<typename ScheduleConfig::LayoutSFB*>(
layout_sfb.data_ptr())};
cutlass::KernelHardwareInfo hw_info;
hw_info.device_id = a_ptrs.get_device();
hw_info.sm_count =
cutlass::KernelHardwareInfo::query_device_multiprocessor_count(
hw_info.device_id);
// Epilogue Arguments
typename GemmKernel::EpilogueArguments epilogue_args{
{}, // epilogue.thread
nullptr,
static_cast<StrideC*>(stride_c.data_ptr()),
static_cast<ElementD**>(out_ptrs.data_ptr()),
static_cast<StrideC*>(stride_c.data_ptr())};
UnderlyingProblemShape* problem_sizes_as_shapes =
static_cast<UnderlyingProblemShape*>(problem_sizes.data_ptr());
// Gemm Arguments
typename GemmKernel::Arguments args{
cutlass::gemm::GemmUniversalMode::kGrouped,
{num_experts, problem_sizes_as_shapes, nullptr},
mainloop_args,
epilogue_args,
hw_info};
at::cuda::CUDAGuard device_guard{(char)a_ptrs.device().index()};
const cudaStream_t stream =
at::cuda::getCurrentCUDAStream(a_ptrs.get_device());
auto can_implement_status = gemm_op.can_implement(args);
TORCH_CHECK(can_implement_status == cutlass::Status::kSuccess,
"Failed to implement GEMM");
size_t workspace_size = gemm_op.get_workspace_size(args);
auto const workspace_options =
torch::TensorOptions().dtype(torch::kUInt8).device(a_ptrs.device());
auto workspace = torch::empty(workspace_size, workspace_options);
auto status = gemm_op.initialize(args, workspace.data_ptr(), stream);
TORCH_CHECK(status == cutlass::Status::kSuccess, "Failed to initialize GEMM");
status = gemm_op.run(stream);
TORCH_CHECK(status == cutlass::Status::kSuccess, "Failed to run GEMM");
}
template <typename OutType>
void blockwise_scaled_group_mm_dispatch_shape(
torch::Tensor& output, const torch::Tensor& a, const torch::Tensor& b,
const torch::Tensor& scales_a, const torch::Tensor& scales_b,
const torch::Tensor& problem_sizes, const torch::Tensor& expert_offsets) {
struct MmaConfig {
using ElementA = cutlass::float_e4m3_t;
using KernelSchedule =
cutlass::gemm::KernelPtrArrayTmaWarpSpecializedBlockwise1SmSm100;
using EpilogueSchedule = cutlass::epilogue::PtrArrayTmaWarpSpecialized1Sm;
using ScaleConfig = cutlass::detail::Sm100BlockwiseScaleConfig<
1, 128, 128, cute::UMMA::Major::K, cute::UMMA::Major::K>;
using LayoutSFA = decltype(ScaleConfig::deduce_layoutSFA());
using LayoutSFB = decltype(ScaleConfig::deduce_layoutSFB());
using LayoutC = cutlass::layout::RowMajor;
using MmaTileShape = Shape<_128, _128, _128>;
using ClusterShape = Shape<_1, _1, _1>;
};
int num_experts = (int)expert_offsets.size(0);
auto a_ptrs = torch::empty(
{num_experts},
torch::TensorOptions().dtype(torch::kInt64).device(a.device()));
auto b_ptrs = torch::empty(
{num_experts},
torch::TensorOptions().dtype(torch::kInt64).device(a.device()));
auto out_ptrs = torch::empty(
{num_experts},
torch::TensorOptions().dtype(torch::kInt64).device(a.device()));
auto a_scales_ptrs = torch::empty(
{num_experts},
torch::TensorOptions().dtype(torch::kInt64).device(a.device()));
auto b_scales_ptrs = torch::empty(
{num_experts},
torch::TensorOptions().dtype(torch::kInt64).device(a.device()));
auto layout_sfa = torch::empty(
{num_experts, 5},
torch::TensorOptions().dtype(torch::kInt32).device(a.device()));
auto layout_sfb = torch::empty(
{num_experts, 5},
torch::TensorOptions().dtype(torch::kInt32).device(a.device()));
auto stride_a = torch::full(
{num_experts}, a.size(1),
torch::TensorOptions().dtype(torch::kInt64).device(a.device()));
auto stride_b = torch::full(
{num_experts}, a.size(1),
torch::TensorOptions().dtype(torch::kInt64).device(a.device()));
auto stride_c = torch::full(
{num_experts}, output.size(1),
torch::TensorOptions().dtype(torch::kInt64).device(a.device()));
torch::TensorOptions options_int =
torch::TensorOptions().dtype(torch::kInt64).device(a.device());
run_get_ggemm_starts<typename MmaConfig::LayoutSFA,
typename MmaConfig::LayoutSFB,
typename MmaConfig::ScaleConfig>(
expert_offsets, a_ptrs, b_ptrs, out_ptrs, a_scales_ptrs, b_scales_ptrs, a,
b, output, scales_a, scales_b, layout_sfa, layout_sfb, problem_sizes);
run_blockwise_scaled_group_mm<OutType, MmaConfig,
typename MmaConfig::LayoutC>(
out_ptrs, a_ptrs, b_ptrs, a_scales_ptrs, b_scales_ptrs, stride_a,
stride_b, stride_c, layout_sfa, layout_sfb, problem_sizes,
expert_offsets);
}
void cutlass_blockwise_scaled_grouped_mm(
torch::Tensor& output, const torch::Tensor& a, const torch::Tensor& b,
const torch::Tensor& scales_a, const torch::Tensor& scales_b,
const torch::Tensor& problem_sizes, const torch::Tensor& expert_offsets) {
TORCH_CHECK(problem_sizes.dim() == 2, "problem_sizes must be 2D tensor");
TORCH_CHECK(problem_sizes.size(1) == 3,
"problem_sizes must have shape (num_experts, 3)");
TORCH_CHECK(problem_sizes.size(0) == expert_offsets.size(0),
"Number of experts in problem_sizes must match expert_offsets");
TORCH_CHECK(problem_sizes.dtype() == torch::kInt32,
"problem_sizes must be int32");
TORCH_CHECK(a.scalar_type() == torch::kFloat8_e4m3fn,
"a must be kFloat8_e4m3fn");
TORCH_CHECK(b.scalar_type() == torch::kFloat8_e4m3fn,
"b must be kFloat8_e4m3fn");
TORCH_CHECK(output.scalar_type() == torch::kBFloat16 ||
output.scalar_type() == torch::kHalf,
"output must be bfloat16 or half");
TORCH_CHECK(scales_a.scalar_type() == torch::kFloat32,
"scales_a must be float32");
TORCH_CHECK(scales_b.scalar_type() == torch::kFloat32,
"scales_b must be float32");
TORCH_CHECK(expert_offsets.scalar_type() == torch::kInt32,
"expert_offsets must be int32");
TORCH_CHECK(output.dim() == 2, "output must be 2D tensor");
TORCH_CHECK(a.dim() == 2, "a must be 2D tensor");
TORCH_CHECK(b.dim() == 3, "b must be 3D tensor");
TORCH_CHECK(scales_a.dim() == 2, "scales_a must be 2D tensor");
TORCH_CHECK(scales_b.dim() == 3, "scales_b must be 3D tensor");
TORCH_CHECK(problem_sizes.dim() == 2, "problem_sizes must be 2D tensor");
TORCH_CHECK(problem_sizes.size(1) == 3,
"problem_sizes must have shape (num_experts, 3)");
TORCH_CHECK(problem_sizes.size(0) == expert_offsets.size(0),
"Number of experts in problem_sizes must match expert_offsets");
TORCH_CHECK(problem_sizes.dtype() == torch::kInt32,
"problem_sizes must be int32");
TORCH_CHECK(expert_offsets.dim() == 1, "expert_offsets must be 1D tensor");
#if defined(ENABLE_CUTLASS_MOE_SM100) && ENABLE_CUTLASS_MOE_SM100
if (output.scalar_type() == torch::kBFloat16) {
blockwise_scaled_group_mm_dispatch_shape<cutlass::bfloat16_t>(
output, a, b, scales_a, scales_b, problem_sizes, expert_offsets);
} else if (output.scalar_type() == torch::kFloat16) {
blockwise_scaled_group_mm_dispatch_shape<cutlass::half_t>(
output, a, b, scales_a, scales_b, problem_sizes, expert_offsets);
} else {
TORCH_CHECK(false, "Unsupported output tensor type");
}
#endif
}

View File

@ -38,7 +38,6 @@
#include "cute/atom/mma_atom.hpp"
#include "cute/atom/copy_traits_sm90_tma.hpp"
#include "cute/algorithm/gemm.hpp"
#include "cute/tensor_predicate.hpp"
#include "cute/numeric/arithmetic_tuple.hpp"
#include "cutlass/pipeline/pipeline.hpp"
#include "cutlass/transform/collective/sm90_wgmma_transpose.hpp"

View File

@ -79,7 +79,8 @@ struct cutlass_sparse_3x_gemm {
// These are the minimum alignments needed for the kernels to compile
static constexpr int AlignmentAB =
128 / cutlass::sizeof_bits<ElementAB>::value;
static constexpr int AlignmentCD = 4;
static constexpr int AlignmentCD =
128 / cutlass::sizeof_bits<ElementD>::value;
using CollectiveEpilogue =
typename cutlass::epilogue::collective::CollectiveBuilder<

View File

@ -393,6 +393,15 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
{stride_tag});
ops.impl("cutlass_scaled_fp4_mm", torch::kCUDA, &cutlass_scaled_fp4_mm);
// cutlass blockwise scaledgroup GEMM
ops.def(
"cutlass_blockwise_scaled_grouped_mm(Tensor! output, Tensor a, Tensor b, "
"Tensor scales_a, Tensor scales_b, "
"Tensor problem_sizes, Tensor expert_offsets) -> ()",
{stride_tag});
ops.impl("cutlass_blockwise_scaled_grouped_mm", torch::kCUDA,
&cutlass_blockwise_scaled_grouped_mm);
// cutlass nvfp4 block scaled group GEMM
ops.def(
"cutlass_fp4_group_mm(Tensor! out, Tensor a, Tensor b,"

View File

@ -0,0 +1,115 @@
# SPDX-License-Identifier: Apache-2.0
# DeepGEMM Style Cutlass Grouped GEMM Test
# See https://github.com/deepseek-ai/DeepGEMM/blob/main/tests/test_core.py
import random
import pytest
import torch
from tests.kernels.utils import baseline_scaled_mm
from vllm import _custom_ops as ops
from vllm.platforms import current_platform
def cdiv(a, b):
return (a + b - 1) // b
def per_token_cast_to_fp8(
x: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
assert x.dim() == 2
m, n = x.shape
pad_size = (128 - (n % 128)) % 128
x = torch.nn.functional.pad(x,
(0, pad_size), value=0) if pad_size > 0 else x
x_view = x.view(m, -1, 128)
x_amax = x_view.abs().float().amax(dim=2).view(m, -1).clamp(1e-4)
fp8_data = (x_view *
(448.0 / x_amax.unsqueeze(2))).to(dtype=torch.float8_e4m3fn)
return fp8_data.view(m, n + pad_size)[:, :n], (x_amax / 448.0).view(m, -1)
def per_block_cast_to_fp8(
x: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
assert x.dim() == 2
m, n = x.shape
x_padded = torch.zeros((cdiv(m, 128) * 128, cdiv(n, 128) * 128),
device=x.device,
dtype=x.dtype)
x_padded[:m, :n] = x
x_view = x_padded.view(-1, 128, x_padded.size(1) // 128, 128)
x_amax = x_view.abs().float().amax(dim=(1, 3), keepdim=True).clamp(1e-4)
x_scaled = (x_view * (448.0 / x_amax)).to(dtype=torch.float8_e4m3fn)
return x_scaled.view_as(x_padded)[:m, :n].contiguous(), (
x_amax / 448.0).view(x_view.size(0), x_view.size(2))
@pytest.mark.parametrize("num_groups, expected_m_per_group, k, n", [
(4, 8192, 7168, 4096),
(4, 8192, 2048, 7168),
(8, 4096, 7168, 4096),
(8, 4096, 2048, 7168),
(32, 1024, 7168, 4096),
(32, 1024, 2048, 7168),
])
@pytest.mark.parametrize("out_dtype", [torch.float16])
@pytest.mark.skipif(
(lambda x: x is None or x.to_int() != 100)(
current_platform.get_device_capability()),
reason="Block Scaled Grouped GEMM is only supported on SM100.")
def test_cutlass_grouped_gemm(
num_groups: int,
expected_m_per_group: int,
k: int,
n: int,
out_dtype: torch.dtype,
):
device = "cuda"
alignment = 128
group_ms = [
int(expected_m_per_group * random.uniform(0.7, 1.3))
for _ in range(num_groups)
]
m = sum([cdiv(m, alignment) * alignment for m in group_ms])
x = torch.randn((m, k), device=device, dtype=out_dtype)
y = torch.randn((num_groups, n, k), device=device, dtype=out_dtype)
out = torch.empty((m, n), device=device, dtype=out_dtype)
ref_out = torch.randn((m, n), device=device, dtype=out_dtype)
ep_offset = [0] + [sum(group_ms[:i]) for i in range(1, num_groups)] + [m]
pb_size = []
for i in range(num_groups):
pb_size.append([ep_offset[i + 1] - ep_offset[i], n, k])
problem_sizes = torch.tensor(pb_size, device=device, dtype=torch.int32)
expert_offsets = torch.tensor(ep_offset, device=device, dtype=torch.int32)
x_fp8 = per_token_cast_to_fp8(x)
y_fp8 = (torch.empty_like(y, dtype=torch.float8_e4m3fn),
torch.empty((num_groups, cdiv(n, 128), k // 128),
device=device,
dtype=torch.float))
for i in range(num_groups):
y_fp8[0][i], y_fp8[1][i] = per_block_cast_to_fp8(y[i])
for i in range(num_groups):
a = x_fp8[0][ep_offset[i]:ep_offset[i + 1]]
a_scale = x_fp8[1][ep_offset[i]:ep_offset[i + 1]]
b = y_fp8[0][i].t()
b_scale = y_fp8[1][i].t()
baseline = baseline_scaled_mm(a, b, a_scale, b_scale, out_dtype)
ref_out[ep_offset[i]:ep_offset[i + 1]] = baseline
ops.cutlass_blockwise_scaled_grouped_mm(
out,
x_fp8[0],
y_fp8[0],
x_fp8[1],
y_fp8[1],
problem_sizes,
expert_offsets[:-1],
)
torch.testing.assert_close(ref_out, out, atol=5e-1, rtol=1e-3)

View File

@ -646,6 +646,20 @@ def cutlass_scaled_mm_supports_fp4(cuda_device_capability: int) -> bool:
return torch.ops._C.cutlass_scaled_mm_supports_fp4(cuda_device_capability)
def cutlass_blockwise_scaled_grouped_mm(
output: torch.Tensor,
a: torch.Tensor,
b: torch.Tensor,
scales_a: torch.Tensor,
scales_b: torch.Tensor,
problem_sizes: torch.Tensor,
expert_offsets: torch.Tensor,
):
torch.ops._C.cutlass_blockwise_scaled_grouped_mm(output, a, b, scales_a,
scales_b, problem_sizes,
expert_offsets)
def cutlass_scaled_fp4_mm(a: torch.Tensor, b: torch.Tensor,
block_scale_a: torch.Tensor,
block_scale_b: torch.Tensor, alpha: torch.Tensor,

View File

@ -7,12 +7,17 @@ import torch
import vllm.model_executor.layers.fused_moe.modular_kernel as mk
from vllm import _custom_ops as ops
from vllm.logger import init_logger
from vllm.model_executor.layers.fused_moe.config import FusedMoEQuantConfig
from vllm.model_executor.layers.fused_moe.prepare_finalize import (
MoEPrepareAndFinalizeNoEP)
from vllm.model_executor.layers.fused_moe.utils import _fp8_perm, _resize_cache
from vllm.model_executor.layers.fused_moe.utils import (_fp8_perm,
_fp8_quantize,
_resize_cache)
from vllm.scalar_type import scalar_types
logger = init_logger(__name__)
def run_cutlass_moe_fp8(
output: torch.Tensor,
@ -508,3 +513,130 @@ def cutlass_moe_fp4(a: torch.Tensor, a1_gscale: torch.Tensor,
out = (c2.view(m, num_topk, k) *
topk_weights.view(m, num_topk, 1).half()).sum(dim=1)
return out.to(dtype=out_dtype)
def _valid_cutlass_block_scaled_grouped_gemm(hidden_states: torch.Tensor,
w1: torch.Tensor,
w2: torch.Tensor) -> bool:
def _valid_cutlass_block_scaled_grouped_gemm_shape(M: int, N: int, K: int):
return M >= 128 and N % 128 == 0 and K % 128 == 0
m = hidden_states.size(0)
_, K, N = w2.size()
if not _valid_cutlass_block_scaled_grouped_gemm_shape(m, N, K):
logger.debug(
"CutlassBlockScaledGroupedGemm disabled: unalinged problem size.")
return False
if (w1.dtype != torch.float8_e4m3fn or w2.dtype != torch.float8_e4m3fn):
logger.debug(
"CutlassBlockScaledGroupedGemm disabled: invalid weight dtype(s).")
return False
return True
def run_cutlass_block_scaled_fused_experts(
a: torch.Tensor,
w1: torch.Tensor,
w2: torch.Tensor,
w1_scale: torch.Tensor,
w2_scale: torch.Tensor,
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
) -> torch.Tensor:
w1_q = w1.transpose(1, 2)
w2_q = w2.transpose(1, 2)
w1_scale = w1_scale.transpose(1, 2)
w2_scale = w2_scale.transpose(1, 2)
assert topk_weights.shape == topk_ids.shape, "topk shape mismatch"
assert a.shape[0] == topk_ids.shape[
0], "a and topk_ids must have the same batch size"
assert w1_q.dtype == torch.float8_e4m3fn, "w1_q must be float8_e4m3fn"
assert w2_q.dtype == torch.float8_e4m3fn, "w2_q must be float8_e4m3fn"
assert a.shape[1] == w1_q.shape[1], "Hidden size mismatch w1"
assert w1_q.shape[2] == w2_q.shape[1] * 2, "Hidden size mismatch w2"
assert w1_q.shape[0] == w2_q.shape[0], "Expert number mismatch"
assert w1_q.shape[0] == w1_scale.shape[
0], "w1_scale expert number mismatch"
assert w1_q.shape[0] == w2_scale.shape[
0], "w2_scale expert number mismatch"
assert a.dtype in [torch.half, torch.bfloat16], "Invalid output dtype"
out_dtype = a.dtype
num_experts = w1_q.size(0)
m = a.size(0)
k = w1_q.size(1)
n = w2_q.size(1)
expert_offsets = torch.empty((num_experts + 1, ),
dtype=torch.int32,
device="cuda")
problem_sizes1 = torch.empty((num_experts, 3),
dtype=torch.int32,
device="cuda")
problem_sizes2 = torch.empty((num_experts, 3),
dtype=torch.int32,
device="cuda")
topk = topk_ids.size(1)
a_q, a1_scale = _fp8_quantize(a,
A_scale=None,
per_act_token=False,
block_shape=[128, 128])
device = a_q.device
a_map = torch.empty((topk_ids.numel()), dtype=torch.int32, device=device)
c_map = torch.empty((topk_ids.numel()), dtype=torch.int32, device=device)
ops.get_cutlass_moe_mm_data(
topk_ids,
expert_offsets,
problem_sizes1,
problem_sizes2,
a_map,
c_map,
num_experts,
n,
k,
)
rep_a_q = a_q.view(dtype=torch.uint8)[a_map].view(dtype=a_q.dtype)
rep_a1_scales = a1_scale[a_map]
c1 = torch.empty((m * topk, n * 2), dtype=out_dtype, device=device)
c2 = torch.empty((m * topk, k), dtype=out_dtype, device=device)
ops.cutlass_blockwise_scaled_grouped_mm(
c1,
rep_a_q,
w1_q,
rep_a1_scales,
w1_scale,
problem_sizes1,
expert_offsets[:-1],
)
intermediate = torch.empty((m * topk, n), dtype=out_dtype, device=device)
torch.ops._C.silu_and_mul(intermediate, c1)
intermediate_q, a2_scale = _fp8_quantize(intermediate,
A_scale=None,
per_act_token=False,
block_shape=[128, 128])
ops.cutlass_blockwise_scaled_grouped_mm(
c2,
intermediate_q,
w2_q,
a2_scale,
w2_scale,
problem_sizes2,
expert_offsets[:-1],
)
return (c2[c_map].view(m, topk, k) *
topk_weights.view(m, topk, 1).to(out_dtype)).sum(dim=1)

View File

@ -15,6 +15,9 @@ from vllm.logger import init_logger
# yapf: disable
from vllm.model_executor.layers.fused_moe.config import (
FusedMoEQuantConfig, get_config_quant_dtype)
from vllm.model_executor.layers.fused_moe.cutlass_moe import (
_valid_cutlass_block_scaled_grouped_gemm,
run_cutlass_block_scaled_fused_experts)
# yapf: enable
from vllm.model_executor.layers.fused_moe.deep_gemm_moe import (
_valid_deep_gemm, deep_gemm_moe_fp8)
@ -1129,29 +1132,31 @@ def dispatch_fused_experts_func(inplace: bool) -> Callable[..., torch.Tensor]:
# TODO (bnell): replace this with modular op. Can get rid of inplace/outplace
# torch ops.
def fused_experts(hidden_states: torch.Tensor,
w1: torch.Tensor,
w2: torch.Tensor,
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
inplace: bool = False,
activation: str = "silu",
apply_router_weight_on_input: bool = False,
use_fp8_w8a8: bool = False,
use_int8_w8a8: bool = False,
use_int8_w8a16: bool = False,
use_int4_w4a16: bool = False,
per_channel_quant: bool = False,
global_num_experts: int = -1,
expert_map: Optional[torch.Tensor] = None,
w1_scale: Optional[torch.Tensor] = None,
w2_scale: Optional[torch.Tensor] = None,
w1_zp: Optional[torch.Tensor] = None,
w2_zp: Optional[torch.Tensor] = None,
a1_scale: Optional[torch.Tensor] = None,
a2_scale: Optional[torch.Tensor] = None,
block_shape: Optional[list[int]] = None,
allow_deep_gemm: bool = False) -> torch.Tensor:
def fused_experts(
hidden_states: torch.Tensor,
w1: torch.Tensor,
w2: torch.Tensor,
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
inplace: bool = False,
activation: str = "silu",
apply_router_weight_on_input: bool = False,
use_fp8_w8a8: bool = False,
use_int8_w8a8: bool = False,
use_int8_w8a16: bool = False,
use_int4_w4a16: bool = False,
per_channel_quant: bool = False,
global_num_experts: int = -1,
expert_map: Optional[torch.Tensor] = None,
w1_scale: Optional[torch.Tensor] = None,
w2_scale: Optional[torch.Tensor] = None,
w1_zp: Optional[torch.Tensor] = None,
w2_zp: Optional[torch.Tensor] = None,
a1_scale: Optional[torch.Tensor] = None,
a2_scale: Optional[torch.Tensor] = None,
block_shape: Optional[list[int]] = None,
allow_deep_gemm: bool = False,
allow_cutlass_block_scaled_grouped_gemm: bool = False) -> torch.Tensor:
# For now, disable DeepGemm for small N (<= 512) until better
# permute/unpermute ops are available.
N = w1.size(1)
@ -1174,6 +1179,17 @@ def fused_experts(hidden_states: torch.Tensor,
a2_scale=a2_scale,
apply_router_weight_on_input=apply_router_weight_on_input,
)
elif (allow_cutlass_block_scaled_grouped_gemm and use_fp8_w8a8
and _valid_cutlass_block_scaled_grouped_gemm(hidden_states, w1, w2)):
assert apply_router_weight_on_input is False
return run_cutlass_block_scaled_fused_experts(
a=hidden_states,
w1=w1,
w2=w2,
w1_scale=w1_scale,
w2_scale=w2_scale,
topk_weights=topk_weights,
topk_ids=topk_ids)
else:
return dispatch_fused_experts_func(inplace)(
hidden_states=hidden_states,

View File

@ -473,12 +473,30 @@ class Fp8MoEMethod(FusedMoEMethodBase):
logger.warning_once(
"DeepGemm not supported on the current platform.")
# Check for CutlassBlockScaledGroupedGemm support.
self.allow_cutlass_block_scaled_grouped_gemm = False
if not self.block_quant:
logger.warning_once("Model is not block quantized. Not using "
"CutlassBlockScaledGroupedGemm kernels")
elif (current_platform.is_cuda()
and current_platform.has_device_capability(100)):
logger.info_once(
"Using CutlassBlockScaledGroupedGemm kernels for Fp8MoEMethod."
)
self.allow_cutlass_block_scaled_grouped_gemm = True
else:
logger.warning_once(
"CutlassBlockScaledGroupedGemm not supported on the current "
"platform.")
self.topk_indices_dtype = None
self.fused_experts = functools.partial( # type: ignore
fused_experts,
use_fp8_w8a8=True,
block_shape=self.quant_config.weight_block_size,
allow_deep_gemm=self.allow_deep_gemm)
allow_deep_gemm=self.allow_deep_gemm,
allow_cutlass_block_scaled_grouped_gemm=(
self.allow_cutlass_block_scaled_grouped_gemm))
def create_weights(self, layer: Module, num_experts: int, hidden_size: int,
intermediate_size_per_partition: int,