diff --git a/CMakeLists.txt b/CMakeLists.txt index 9b1daeeed8..0129f85123 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -259,7 +259,7 @@ if(VLLM_GPU_LANG STREQUAL "CUDA") SET(CUTLASS_ENABLE_HEADERS_ONLY ON CACHE BOOL "Enable only the header library") # Set CUTLASS_REVISION. Used for FetchContent. Also fixes some bogus messages when building. - set(CUTLASS_REVISION "v3.9.2" CACHE STRING "CUTLASS revision to use") + set(CUTLASS_REVISION "v4.0.0" CACHE STRING "CUTLASS revision to use") # Use the specified CUTLASS source directory for compilation if VLLM_CUTLASS_SRC_DIR is provided if (DEFINED ENV{VLLM_CUTLASS_SRC_DIR}) @@ -615,6 +615,26 @@ if(VLLM_GPU_LANG STREQUAL "CUDA") "in CUDA target architectures.") endif() endif() + + cuda_archs_loose_intersection(SCALED_MM_ARCHS "10.0a" "${CUDA_ARCHS}") + if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 12.8 AND SCALED_MM_ARCHS) + set(SRCS "csrc/quantization/cutlass_w8a8/moe/blockwise_scaled_group_mm_sm100.cu") + set_gencode_flags_for_srcs( + SRCS "${SRCS}" + CUDA_ARCHS "${SCALED_MM_ARCHS}") + list(APPEND VLLM_EXT_SRC "${SRCS}") + list(APPEND VLLM_GPU_FLAGS "-DENABLE_CUTLASS_MOE_SM100=1") + message(STATUS "Building blockwise_scaled_group_mm_sm100 for archs: ${SCALED_MM_ARCHS}") + else() + if (NOT ${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 12.8 AND SCALED_MM_ARCHS) + message(STATUS "Not building blockwise_scaled_group_mm_sm100 kernels as CUDA Compiler version is " + "not >= 12.8. We recommend upgrading to CUDA 12.8 or later " + "if you intend to run FP8 quantized MoE models on Blackwell.") + else() + message(STATUS "Not building blockwise_scaled_group_mm_sm100 as no compatible archs found " + "in CUDA target architectures.") + endif() + endif() # # Machete kernels diff --git a/csrc/cutlass_extensions/gemm/collective/sm90_mma_tma_gmma_ss_warpspecialized_fp8_blockwise_scaling.hpp b/csrc/cutlass_extensions/gemm/collective/sm90_mma_tma_gmma_ss_warpspecialized_fp8_blockwise_scaling.hpp index d922a3349e..ce7f47cf72 100644 --- a/csrc/cutlass_extensions/gemm/collective/sm90_mma_tma_gmma_ss_warpspecialized_fp8_blockwise_scaling.hpp +++ b/csrc/cutlass_extensions/gemm/collective/sm90_mma_tma_gmma_ss_warpspecialized_fp8_blockwise_scaling.hpp @@ -45,7 +45,6 @@ #include "cute/algorithm/functional.hpp" #include "cute/atom/mma_atom.hpp" #include "cute/algorithm/gemm.hpp" -#include "cute/tensor_predicate.hpp" #include "cute/numeric/arithmetic_tuple.hpp" #include "cutlass_extensions/gemm/dispatch_policy.hpp" diff --git a/csrc/ops.h b/csrc/ops.h index 52c264d64c..56e51cc659 100644 --- a/csrc/ops.h +++ b/csrc/ops.h @@ -239,6 +239,11 @@ void cutlass_moe_mm( torch::Tensor const& b_strides, torch::Tensor const& c_strides, bool per_act_token, bool per_out_ch); +void cutlass_blockwise_scaled_grouped_mm( + torch::Tensor& output, const torch::Tensor& a, const torch::Tensor& b, + const torch::Tensor& scales_a, const torch::Tensor& scales_b, + const torch::Tensor& problem_sizes, const torch::Tensor& expert_offsets); + void cutlass_fp4_group_mm( torch::Tensor& output, const torch::Tensor& a, const torch::Tensor& b, const torch::Tensor& a_blockscale, const torch::Tensor& b_blockscales, diff --git a/csrc/quantization/cutlass_w8a8/c3x/scaled_mm.cuh b/csrc/quantization/cutlass_w8a8/c3x/scaled_mm.cuh index 2387ec57e8..2d67da9876 100644 --- a/csrc/quantization/cutlass_w8a8/c3x/scaled_mm.cuh +++ b/csrc/quantization/cutlass_w8a8/c3x/scaled_mm.cuh @@ -51,7 +51,8 @@ struct cutlass_3x_gemm { // These are the minimum alignments needed for the kernels to compile static
constexpr int AlignmentAB = 128 / cutlass::sizeof_bits::value; - static constexpr int AlignmentCD = 4; + static constexpr int AlignmentCD = + 128 / cutlass::sizeof_bits::value; using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< diff --git a/csrc/quantization/cutlass_w8a8/moe/blockwise_scaled_group_mm_sm100.cu b/csrc/quantization/cutlass_w8a8/moe/blockwise_scaled_group_mm_sm100.cu new file mode 100644 index 0000000000..ef57e503b2 --- /dev/null +++ b/csrc/quantization/cutlass_w8a8/moe/blockwise_scaled_group_mm_sm100.cu @@ -0,0 +1,367 @@ +#include +#include + +#include +#include +#include + +#include "cute/tensor.hpp" +#include "cutlass/tensor_ref.h" +#include "cutlass/epilogue/collective/default_epilogue.hpp" +#include "cutlass/epilogue/thread/linear_combination.h" +#include "cutlass/gemm/dispatch_policy.hpp" +#include "cutlass/gemm/group_array_problem_shape.hpp" +#include "cutlass/gemm/collective/collective_builder.hpp" +#include "cutlass/epilogue/collective/collective_builder.hpp" +#include "cutlass/gemm/device/gemm_universal_adapter.h" +#include "cutlass/gemm/kernel/gemm_universal.hpp" + +#include "cutlass/util/command_line.h" +#include "cutlass/util/distribution.h" +#include "cutlass/util/host_tensor.h" +#include "cutlass/util/packed_stride.hpp" +#include "cutlass/util/tensor_view_io.h" +#include "cutlass/util/reference/device/gemm.h" +#include "cutlass/util/reference/device/tensor_compare.h" +#include "cutlass/util/reference/host/tensor_fill.h" +#include "cutlass/util/reference/host/gett.hpp" +#include "cutlass/util/reference/host/tensor_norm.h" +#include "cutlass/util/reference/host/tensor_compare.h" +#include + +using namespace cute; + +template +__global__ void get_ggemm_starts( + int32_t* expert_offsets, ElementAB** a_offsets, ElementAB** b_offsets, + ElementC** out_offsets, ElementAccumulator** a_scale_offsets, + ElementAccumulator** b_scale_offsets, ElementAB* a_base_as_int, + ElementAB* b_base_as_int, ElementC* out_base_as_int, + ElementAccumulator* a_scale_base_as_int, + ElementAccumulator* b_scale_base_as_int, LayoutSFA* layout_sfa_base_as_int, + LayoutSFB* layout_sfb_base_as_int, int* problem_sizes) { + int expert_id = threadIdx.x; + + if (expert_id >= gridDim.x * blockDim.x) { + return; + } + + int m = problem_sizes[expert_id * 3]; + int n = problem_sizes[expert_id * 3 + 1]; + int k = problem_sizes[expert_id * 3 + 2]; + + int32_t expert_offset = expert_offsets[expert_id]; + int a_stride = expert_offset * k; + int b_stride = expert_id * k * n; + int a_scale_stride = expert_offset * k / 128; + int b_scale_stride = expert_id * k * n / 128 / 128; + + a_offsets[expert_id] = a_base_as_int + a_stride; + b_offsets[expert_id] = b_base_as_int + b_stride; + out_offsets[expert_id] = out_base_as_int + expert_offset * n; + a_scale_offsets[expert_id] = a_scale_base_as_int + a_scale_stride; + b_scale_offsets[expert_id] = b_scale_base_as_int + b_scale_stride; + + LayoutSFA* layout_sfa_ptr = layout_sfa_base_as_int + expert_id; + LayoutSFB* layout_sfb_ptr = layout_sfb_base_as_int + expert_id; + + *layout_sfa_ptr = + ScaleConfig::tile_atom_to_shape_SFA(cute::make_shape(m, n, k, 1)); + *layout_sfb_ptr = + ScaleConfig::tile_atom_to_shape_SFB(cute::make_shape(m, n, k, 1)); +} + +#define __CALL_GET_STARTS_KERNEL(TENSOR_C_TYPE, C_TYPE, LayoutSFA, LayoutSFB, \ + ScaleConfig) \ + else if (out_tensors.dtype() == TENSOR_C_TYPE) { \ + get_ggemm_starts<<<1, num_experts, 0, stream>>>( \ + static_cast(expert_offsets.data_ptr()), \ + static_cast(a_ptrs.data_ptr()), \ + 
static_cast(b_ptrs.data_ptr()), \ + static_cast(out_ptrs.data_ptr()), \ + static_cast(a_scales_ptrs.data_ptr()), \ + static_cast(b_scales_ptrs.data_ptr()), \ + static_cast(a_tensors.data_ptr()), \ + static_cast(b_tensors.data_ptr()), \ + static_cast(out_tensors.data_ptr()), \ + static_cast(a_scales.data_ptr()), \ + static_cast(b_scales.data_ptr()), \ + reinterpret_cast(layout_sfa.data_ptr()), \ + reinterpret_cast(layout_sfb.data_ptr()), \ + static_cast(problem_sizes.data_ptr())); \ + } + +template +void run_get_ggemm_starts( + torch::Tensor const& expert_offsets, torch::Tensor& a_ptrs, + torch::Tensor& b_ptrs, torch::Tensor& out_ptrs, + torch::Tensor& a_scales_ptrs, torch::Tensor& b_scales_ptrs, + torch::Tensor const& a_tensors, torch::Tensor const& b_tensors, + torch::Tensor out_tensors, torch::Tensor const& a_scales, + torch::Tensor const& b_scales, torch::Tensor const& layout_sfa, + torch::Tensor const& layout_sfb, torch::Tensor const& problem_sizes) { + TORCH_CHECK(a_tensors.dtype() == torch::kFloat8_e4m3fn); + TORCH_CHECK(b_tensors.dtype() == torch::kFloat8_e4m3fn); + TORCH_CHECK(a_scales.dtype() == torch::kFloat32); + TORCH_CHECK(b_scales.dtype() == torch::kFloat32); + TORCH_CHECK(out_tensors.size(1) % 128 == 0 or out_tensors.size(0) % 128 == 0); + TORCH_CHECK(a_tensors.size(1) % 128 == 0 or a_tensors.size(0) % 128 == 0); + + int num_experts = (int)expert_offsets.size(0); + auto stream = at::cuda::getCurrentCUDAStream(a_tensors.device().index()); + + if (false) { + } + __CALL_GET_STARTS_KERNEL(torch::kBFloat16, cutlass::bfloat16_t, LayoutSFA, + LayoutSFB, ScaleConfig) + __CALL_GET_STARTS_KERNEL(torch::kFloat16, cutlass::half_t, LayoutSFA, + LayoutSFB, ScaleConfig) + else { + TORCH_CHECK(false, "Unsupported output tensor type"); + } +} + +template +void run_blockwise_scaled_group_mm( + torch::Tensor& out_ptrs, const torch::Tensor& a_ptrs, + const torch::Tensor& b_ptrs, const torch::Tensor& a_scales_ptrs, + const torch::Tensor& b_scales_ptrs, const torch::Tensor& stride_a, + const torch::Tensor& stride_b, const torch::Tensor& stride_c, + const torch::Tensor& layout_sfa, const torch::Tensor& layout_sfb, + const torch::Tensor& problem_sizes, const torch::Tensor& expert_offsets) { + using ProblemShape = cutlass::gemm::GroupProblemShape>; + + // Types + using ElementA = cutlass::float_e4m3_t; + using ElementB = cutlass::float_e4m3_t; + using ElementC = OutType; + using ElementD = ElementC; + using ElementAccumulator = float; + using LayoutA = cutlass::layout::RowMajor; + using LayoutB = cutlass::layout::ColumnMajor; + using LayoutC = LayoutD; + + // Alignments + static constexpr int AlignmentA = 128 / cutlass::sizeof_bits::value; + static constexpr int AlignmentB = 128 / cutlass::sizeof_bits::value; + static constexpr int AlignmentC = 128 / cutlass::sizeof_bits::value; + + using ArchTag = cutlass::arch::Sm100; + using OperatorClass = cutlass::arch::OpClassTensorOp; + + using CollectiveEpilogue = + typename cutlass::epilogue::collective::CollectiveBuilder< + ArchTag, OperatorClass, typename ScheduleConfig::MmaTileShape, + typename ScheduleConfig::ClusterShape, + cutlass::epilogue::collective::EpilogueTileAuto, ElementAccumulator, + ElementAccumulator, void, LayoutC*, AlignmentC, ElementD, LayoutC*, + AlignmentC, typename ScheduleConfig::EpilogueSchedule>::CollectiveOp; + + using CollectiveMainloop = + typename cutlass::gemm::collective::CollectiveBuilder< + ArchTag, OperatorClass, ElementA, + cute::tuple, + AlignmentA, ElementB, + cute::tuple, + AlignmentB, ElementAccumulator, typename 
ScheduleConfig::MmaTileShape, + typename ScheduleConfig::ClusterShape, + cutlass::gemm::collective::StageCountAutoCarveout( + sizeof(typename CollectiveEpilogue::SharedStorage))>, + typename ScheduleConfig::KernelSchedule>::CollectiveOp; + + using GemmKernel = + cutlass::gemm::kernel::GemmUniversal; + + using Gemm = cutlass::gemm::device::GemmUniversalAdapter; + using StrideA = typename Gemm::GemmKernel::InternalStrideA; + using StrideB = typename Gemm::GemmKernel::InternalStrideB; + using StrideC = typename Gemm::GemmKernel::InternalStrideC; + using StrideD = typename Gemm::GemmKernel::InternalStrideD; + + using UnderlyingProblemShape = ProblemShape::UnderlyingProblemShape; + int num_experts = (int)expert_offsets.size(0); + + Gemm gemm_op; + + // Mainloop Arguments + typename GemmKernel::MainloopArguments mainloop_args{ + static_cast(a_ptrs.data_ptr()), + static_cast(stride_a.data_ptr()), + static_cast(b_ptrs.data_ptr()), + static_cast(stride_b.data_ptr()), + static_cast(a_scales_ptrs.data_ptr()), + reinterpret_cast( + layout_sfa.data_ptr()), + static_cast(b_scales_ptrs.data_ptr()), + reinterpret_cast( + layout_sfb.data_ptr())}; + + cutlass::KernelHardwareInfo hw_info; + hw_info.device_id = a_ptrs.get_device(); + hw_info.sm_count = + cutlass::KernelHardwareInfo::query_device_multiprocessor_count( + hw_info.device_id); + + // Epilogue Arguments + typename GemmKernel::EpilogueArguments epilogue_args{ + {}, // epilogue.thread + nullptr, + static_cast(stride_c.data_ptr()), + static_cast(out_ptrs.data_ptr()), + static_cast(stride_c.data_ptr())}; + + UnderlyingProblemShape* problem_sizes_as_shapes = + static_cast(problem_sizes.data_ptr()); + + // Gemm Arguments + typename GemmKernel::Arguments args{ + cutlass::gemm::GemmUniversalMode::kGrouped, + {num_experts, problem_sizes_as_shapes, nullptr}, + mainloop_args, + epilogue_args, + hw_info}; + + at::cuda::CUDAGuard device_guard{(char)a_ptrs.device().index()}; + const cudaStream_t stream = + at::cuda::getCurrentCUDAStream(a_ptrs.get_device()); + + auto can_implement_status = gemm_op.can_implement(args); + TORCH_CHECK(can_implement_status == cutlass::Status::kSuccess, + "Failed to implement GEMM"); + + size_t workspace_size = gemm_op.get_workspace_size(args); + auto const workspace_options = + torch::TensorOptions().dtype(torch::kUInt8).device(a_ptrs.device()); + auto workspace = torch::empty(workspace_size, workspace_options); + + auto status = gemm_op.initialize(args, workspace.data_ptr(), stream); + TORCH_CHECK(status == cutlass::Status::kSuccess, "Failed to initialize GEMM"); + + status = gemm_op.run(stream); + TORCH_CHECK(status == cutlass::Status::kSuccess, "Failed to run GEMM"); +} + +template +void blockwise_scaled_group_mm_dispatch_shape( + torch::Tensor& output, const torch::Tensor& a, const torch::Tensor& b, + const torch::Tensor& scales_a, const torch::Tensor& scales_b, + const torch::Tensor& problem_sizes, const torch::Tensor& expert_offsets) { + struct MmaConfig { + using ElementA = cutlass::float_e4m3_t; + using KernelSchedule = + cutlass::gemm::KernelPtrArrayTmaWarpSpecializedBlockwise1SmSm100; + using EpilogueSchedule = cutlass::epilogue::PtrArrayTmaWarpSpecialized1Sm; + using ScaleConfig = cutlass::detail::Sm100BlockwiseScaleConfig< + 1, 128, 128, cute::UMMA::Major::K, cute::UMMA::Major::K>; + using LayoutSFA = decltype(ScaleConfig::deduce_layoutSFA()); + using LayoutSFB = decltype(ScaleConfig::deduce_layoutSFB()); + using LayoutC = cutlass::layout::RowMajor; + using MmaTileShape = Shape<_128, _128, _128>; + using ClusterShape = 
Shape<_1, _1, _1>; + }; + + int num_experts = (int)expert_offsets.size(0); + + auto a_ptrs = torch::empty( + {num_experts}, + torch::TensorOptions().dtype(torch::kInt64).device(a.device())); + auto b_ptrs = torch::empty( + {num_experts}, + torch::TensorOptions().dtype(torch::kInt64).device(a.device())); + auto out_ptrs = torch::empty( + {num_experts}, + torch::TensorOptions().dtype(torch::kInt64).device(a.device())); + auto a_scales_ptrs = torch::empty( + {num_experts}, + torch::TensorOptions().dtype(torch::kInt64).device(a.device())); + auto b_scales_ptrs = torch::empty( + {num_experts}, + torch::TensorOptions().dtype(torch::kInt64).device(a.device())); + + auto layout_sfa = torch::empty( + {num_experts, 5}, + torch::TensorOptions().dtype(torch::kInt32).device(a.device())); + auto layout_sfb = torch::empty( + {num_experts, 5}, + torch::TensorOptions().dtype(torch::kInt32).device(a.device())); + + auto stride_a = torch::full( + {num_experts}, a.size(1), + torch::TensorOptions().dtype(torch::kInt64).device(a.device())); + auto stride_b = torch::full( + {num_experts}, a.size(1), + torch::TensorOptions().dtype(torch::kInt64).device(a.device())); + auto stride_c = torch::full( + {num_experts}, output.size(1), + torch::TensorOptions().dtype(torch::kInt64).device(a.device())); + + torch::TensorOptions options_int = + torch::TensorOptions().dtype(torch::kInt64).device(a.device()); + + run_get_ggemm_starts( + expert_offsets, a_ptrs, b_ptrs, out_ptrs, a_scales_ptrs, b_scales_ptrs, a, + b, output, scales_a, scales_b, layout_sfa, layout_sfb, problem_sizes); + + run_blockwise_scaled_group_mm( + out_ptrs, a_ptrs, b_ptrs, a_scales_ptrs, b_scales_ptrs, stride_a, + stride_b, stride_c, layout_sfa, layout_sfb, problem_sizes, + expert_offsets); +} + +void cutlass_blockwise_scaled_grouped_mm( + torch::Tensor& output, const torch::Tensor& a, const torch::Tensor& b, + const torch::Tensor& scales_a, const torch::Tensor& scales_b, + const torch::Tensor& problem_sizes, const torch::Tensor& expert_offsets) { + TORCH_CHECK(problem_sizes.dim() == 2, "problem_sizes must be 2D tensor"); + TORCH_CHECK(problem_sizes.size(1) == 3, + "problem_sizes must have shape (num_experts, 3)"); + TORCH_CHECK(problem_sizes.size(0) == expert_offsets.size(0), + "Number of experts in problem_sizes must match expert_offsets"); + TORCH_CHECK(problem_sizes.dtype() == torch::kInt32, + "problem_sizes must be int32"); + TORCH_CHECK(a.scalar_type() == torch::kFloat8_e4m3fn, + "a must be kFloat8_e4m3fn"); + TORCH_CHECK(b.scalar_type() == torch::kFloat8_e4m3fn, + "b must be kFloat8_e4m3fn"); + TORCH_CHECK(output.scalar_type() == torch::kBFloat16 || + output.scalar_type() == torch::kHalf, + "output must be bfloat16 or half"); + TORCH_CHECK(scales_a.scalar_type() == torch::kFloat32, + "scales_a must be float32"); + TORCH_CHECK(scales_b.scalar_type() == torch::kFloat32, + "scales_b must be float32"); + TORCH_CHECK(expert_offsets.scalar_type() == torch::kInt32, + "expert_offsets must be int32"); + + TORCH_CHECK(output.dim() == 2, "output must be 2D tensor"); + TORCH_CHECK(a.dim() == 2, "a must be 2D tensor"); + TORCH_CHECK(b.dim() == 3, "b must be 3D tensor"); + TORCH_CHECK(scales_a.dim() == 2, "scales_a must be 2D tensor"); + TORCH_CHECK(scales_b.dim() == 3, "scales_b must be 3D tensor"); + TORCH_CHECK(problem_sizes.dim() == 2, "problem_sizes must be 2D tensor"); + TORCH_CHECK(problem_sizes.size(1) == 3, + "problem_sizes must have shape (num_experts, 3)"); + TORCH_CHECK(problem_sizes.size(0) == expert_offsets.size(0), + "Number of experts in 
problem_sizes must match expert_offsets"); + TORCH_CHECK(problem_sizes.dtype() == torch::kInt32, + "problem_sizes must be int32"); + TORCH_CHECK(expert_offsets.dim() == 1, "expert_offsets must be 1D tensor"); + +#if defined(ENABLE_CUTLASS_MOE_SM100) && ENABLE_CUTLASS_MOE_SM100 + if (output.scalar_type() == torch::kBFloat16) { + blockwise_scaled_group_mm_dispatch_shape( + output, a, b, scales_a, scales_b, problem_sizes, expert_offsets); + } else if (output.scalar_type() == torch::kFloat16) { + blockwise_scaled_group_mm_dispatch_shape( + output, a, b, scales_a, scales_b, problem_sizes, expert_offsets); + } else { + TORCH_CHECK(false, "Unsupported output tensor type"); + } +#endif +} \ No newline at end of file diff --git a/csrc/quantization/machete/machete_mainloop.cuh b/csrc/quantization/machete/machete_mainloop.cuh index eca5d328b0..2f52a6b7a0 100644 --- a/csrc/quantization/machete/machete_mainloop.cuh +++ b/csrc/quantization/machete/machete_mainloop.cuh @@ -38,7 +38,6 @@ #include "cute/atom/mma_atom.hpp" #include "cute/atom/copy_traits_sm90_tma.hpp" #include "cute/algorithm/gemm.hpp" -#include "cute/tensor_predicate.hpp" #include "cute/numeric/arithmetic_tuple.hpp" #include "cutlass/pipeline/pipeline.hpp" #include "cutlass/transform/collective/sm90_wgmma_transpose.hpp" diff --git a/csrc/sparse/cutlass/sparse_scaled_mm_c3x.cuh b/csrc/sparse/cutlass/sparse_scaled_mm_c3x.cuh index c22523da4e..637bba1384 100644 --- a/csrc/sparse/cutlass/sparse_scaled_mm_c3x.cuh +++ b/csrc/sparse/cutlass/sparse_scaled_mm_c3x.cuh @@ -79,7 +79,8 @@ struct cutlass_sparse_3x_gemm { // These are the minimum alignments needed for the kernels to compile static constexpr int AlignmentAB = 128 / cutlass::sizeof_bits::value; - static constexpr int AlignmentCD = 4; + static constexpr int AlignmentCD = + 128 / cutlass::sizeof_bits::value; using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< diff --git a/csrc/torch_bindings.cpp b/csrc/torch_bindings.cpp index 8bb71cad29..04329e75db 100644 --- a/csrc/torch_bindings.cpp +++ b/csrc/torch_bindings.cpp @@ -393,6 +393,15 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) { {stride_tag}); ops.impl("cutlass_scaled_fp4_mm", torch::kCUDA, &cutlass_scaled_fp4_mm); + // cutlass blockwise scaled group GEMM + ops.def( + "cutlass_blockwise_scaled_grouped_mm(Tensor! output, Tensor a, Tensor b, " + "Tensor scales_a, Tensor scales_b, " + "Tensor problem_sizes, Tensor expert_offsets) -> ()", + {stride_tag}); + ops.impl("cutlass_blockwise_scaled_grouped_mm", torch::kCUDA, + &cutlass_blockwise_scaled_grouped_mm); + // cutlass nvfp4 block scaled group GEMM ops.def( "cutlass_fp4_group_mm(Tensor!
out, Tensor a, Tensor b," diff --git a/tests/kernels/moe/test_cutlass_grouped_gemm.py b/tests/kernels/moe/test_cutlass_grouped_gemm.py new file mode 100644 index 0000000000..bf228dcece --- /dev/null +++ b/tests/kernels/moe/test_cutlass_grouped_gemm.py @@ -0,0 +1,115 @@ +# SPDX-License-Identifier: Apache-2.0 + +# DeepGEMM Style Cutlass Grouped GEMM Test +# See https://github.com/deepseek-ai/DeepGEMM/blob/main/tests/test_core.py + +import random + +import pytest +import torch + +from tests.kernels.utils import baseline_scaled_mm +from vllm import _custom_ops as ops +from vllm.platforms import current_platform + + +def cdiv(a, b): + return (a + b - 1) // b + + +def per_token_cast_to_fp8( + x: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]: + assert x.dim() == 2 + m, n = x.shape + pad_size = (128 - (n % 128)) % 128 + x = torch.nn.functional.pad(x, + (0, pad_size), value=0) if pad_size > 0 else x + x_view = x.view(m, -1, 128) + x_amax = x_view.abs().float().amax(dim=2).view(m, -1).clamp(1e-4) + fp8_data = (x_view * + (448.0 / x_amax.unsqueeze(2))).to(dtype=torch.float8_e4m3fn) + return fp8_data.view(m, n + pad_size)[:, :n], (x_amax / 448.0).view(m, -1) + + +def per_block_cast_to_fp8( + x: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]: + assert x.dim() == 2 + m, n = x.shape + x_padded = torch.zeros((cdiv(m, 128) * 128, cdiv(n, 128) * 128), + device=x.device, + dtype=x.dtype) + x_padded[:m, :n] = x + x_view = x_padded.view(-1, 128, x_padded.size(1) // 128, 128) + x_amax = x_view.abs().float().amax(dim=(1, 3), keepdim=True).clamp(1e-4) + x_scaled = (x_view * (448.0 / x_amax)).to(dtype=torch.float8_e4m3fn) + return x_scaled.view_as(x_padded)[:m, :n].contiguous(), ( + x_amax / 448.0).view(x_view.size(0), x_view.size(2)) + + +@pytest.mark.parametrize("num_groups, expected_m_per_group, k, n", [ + (4, 8192, 7168, 4096), + (4, 8192, 2048, 7168), + (8, 4096, 7168, 4096), + (8, 4096, 2048, 7168), + (32, 1024, 7168, 4096), + (32, 1024, 2048, 7168), +]) +@pytest.mark.parametrize("out_dtype", [torch.float16]) +@pytest.mark.skipif( + (lambda x: x is None or x.to_int() != 100)( + current_platform.get_device_capability()), + reason="Block Scaled Grouped GEMM is only supported on SM100.") +def test_cutlass_grouped_gemm( + num_groups: int, + expected_m_per_group: int, + k: int, + n: int, + out_dtype: torch.dtype, +): + device = "cuda" + alignment = 128 + group_ms = [ + int(expected_m_per_group * random.uniform(0.7, 1.3)) + for _ in range(num_groups) + ] + m = sum([cdiv(m, alignment) * alignment for m in group_ms]) + + x = torch.randn((m, k), device=device, dtype=out_dtype) + y = torch.randn((num_groups, n, k), device=device, dtype=out_dtype) + out = torch.empty((m, n), device=device, dtype=out_dtype) + ref_out = torch.randn((m, n), device=device, dtype=out_dtype) + + ep_offset = [0] + [sum(group_ms[:i]) for i in range(1, num_groups)] + [m] + pb_size = [] + for i in range(num_groups): + pb_size.append([ep_offset[i + 1] - ep_offset[i], n, k]) + problem_sizes = torch.tensor(pb_size, device=device, dtype=torch.int32) + expert_offsets = torch.tensor(ep_offset, device=device, dtype=torch.int32) + + x_fp8 = per_token_cast_to_fp8(x) + y_fp8 = (torch.empty_like(y, dtype=torch.float8_e4m3fn), + torch.empty((num_groups, cdiv(n, 128), k // 128), + device=device, + dtype=torch.float)) + for i in range(num_groups): + y_fp8[0][i], y_fp8[1][i] = per_block_cast_to_fp8(y[i]) + + for i in range(num_groups): + a = x_fp8[0][ep_offset[i]:ep_offset[i + 1]] + a_scale = x_fp8[1][ep_offset[i]:ep_offset[i + 1]] + b = 
y_fp8[0][i].t() + b_scale = y_fp8[1][i].t() + baseline = baseline_scaled_mm(a, b, a_scale, b_scale, out_dtype) + ref_out[ep_offset[i]:ep_offset[i + 1]] = baseline + + ops.cutlass_blockwise_scaled_grouped_mm( + out, + x_fp8[0], + y_fp8[0], + x_fp8[1], + y_fp8[1], + problem_sizes, + expert_offsets[:-1], + ) + + torch.testing.assert_close(ref_out, out, atol=5e-1, rtol=1e-3) diff --git a/vllm/_custom_ops.py b/vllm/_custom_ops.py index 6b1b3f787c..eb9d0b4058 100644 --- a/vllm/_custom_ops.py +++ b/vllm/_custom_ops.py @@ -646,6 +646,20 @@ def cutlass_scaled_mm_supports_fp4(cuda_device_capability: int) -> bool: return torch.ops._C.cutlass_scaled_mm_supports_fp4(cuda_device_capability) +def cutlass_blockwise_scaled_grouped_mm( + output: torch.Tensor, + a: torch.Tensor, + b: torch.Tensor, + scales_a: torch.Tensor, + scales_b: torch.Tensor, + problem_sizes: torch.Tensor, + expert_offsets: torch.Tensor, +): + torch.ops._C.cutlass_blockwise_scaled_grouped_mm(output, a, b, scales_a, + scales_b, problem_sizes, + expert_offsets) + + def cutlass_scaled_fp4_mm(a: torch.Tensor, b: torch.Tensor, block_scale_a: torch.Tensor, block_scale_b: torch.Tensor, alpha: torch.Tensor, diff --git a/vllm/model_executor/layers/fused_moe/cutlass_moe.py b/vllm/model_executor/layers/fused_moe/cutlass_moe.py index d889f740a0..431fb290b2 100644 --- a/vllm/model_executor/layers/fused_moe/cutlass_moe.py +++ b/vllm/model_executor/layers/fused_moe/cutlass_moe.py @@ -7,12 +7,17 @@ import torch import vllm.model_executor.layers.fused_moe.modular_kernel as mk from vllm import _custom_ops as ops +from vllm.logger import init_logger from vllm.model_executor.layers.fused_moe.config import FusedMoEQuantConfig from vllm.model_executor.layers.fused_moe.prepare_finalize import ( MoEPrepareAndFinalizeNoEP) -from vllm.model_executor.layers.fused_moe.utils import _fp8_perm, _resize_cache +from vllm.model_executor.layers.fused_moe.utils import (_fp8_perm, + _fp8_quantize, + _resize_cache) from vllm.scalar_type import scalar_types +logger = init_logger(__name__) + def run_cutlass_moe_fp8( output: torch.Tensor, @@ -508,3 +513,130 @@ def cutlass_moe_fp4(a: torch.Tensor, a1_gscale: torch.Tensor, out = (c2.view(m, num_topk, k) * topk_weights.view(m, num_topk, 1).half()).sum(dim=1) return out.to(dtype=out_dtype) + + +def _valid_cutlass_block_scaled_grouped_gemm(hidden_states: torch.Tensor, + w1: torch.Tensor, + w2: torch.Tensor) -> bool: + + def _valid_cutlass_block_scaled_grouped_gemm_shape(M: int, N: int, K: int): + return M >= 128 and N % 128 == 0 and K % 128 == 0 + + m = hidden_states.size(0) + _, K, N = w2.size() + if not _valid_cutlass_block_scaled_grouped_gemm_shape(m, N, K): + logger.debug( + "CutlassBlockScaledGroupedGemm disabled: unaligned problem size.") + return False + + if (w1.dtype != torch.float8_e4m3fn or w2.dtype != torch.float8_e4m3fn): + logger.debug( + "CutlassBlockScaledGroupedGemm disabled: invalid weight dtype(s).") + return False + + return True + + +def run_cutlass_block_scaled_fused_experts( + a: torch.Tensor, + w1: torch.Tensor, + w2: torch.Tensor, + w1_scale: torch.Tensor, + w2_scale: torch.Tensor, + topk_weights: torch.Tensor, + topk_ids: torch.Tensor, +) -> torch.Tensor: + w1_q = w1.transpose(1, 2) + w2_q = w2.transpose(1, 2) + w1_scale = w1_scale.transpose(1, 2) + w2_scale = w2_scale.transpose(1, 2) + + assert topk_weights.shape == topk_ids.shape, "topk shape mismatch" + assert a.shape[0] == topk_ids.shape[ + 0], "a and topk_ids must have the same batch size" + assert w1_q.dtype == torch.float8_e4m3fn, "w1_q must be
float8_e4m3fn" + assert w2_q.dtype == torch.float8_e4m3fn, "w2_q must be float8_e4m3fn" + assert a.shape[1] == w1_q.shape[1], "Hidden size mismatch w1" + assert w1_q.shape[2] == w2_q.shape[1] * 2, "Hidden size mismatch w2" + assert w1_q.shape[0] == w2_q.shape[0], "Expert number mismatch" + assert w1_q.shape[0] == w1_scale.shape[ + 0], "w1_scale expert number mismatch" + assert w1_q.shape[0] == w2_scale.shape[ + 0], "w2_scale expert number mismatch" + assert a.dtype in [torch.half, torch.bfloat16], "Invalid output dtype" + + out_dtype = a.dtype + num_experts = w1_q.size(0) + m = a.size(0) + k = w1_q.size(1) + n = w2_q.size(1) + + expert_offsets = torch.empty((num_experts + 1, ), + dtype=torch.int32, + device="cuda") + problem_sizes1 = torch.empty((num_experts, 3), + dtype=torch.int32, + device="cuda") + problem_sizes2 = torch.empty((num_experts, 3), + dtype=torch.int32, + device="cuda") + + topk = topk_ids.size(1) + + a_q, a1_scale = _fp8_quantize(a, + A_scale=None, + per_act_token=False, + block_shape=[128, 128]) + device = a_q.device + + a_map = torch.empty((topk_ids.numel()), dtype=torch.int32, device=device) + c_map = torch.empty((topk_ids.numel()), dtype=torch.int32, device=device) + + ops.get_cutlass_moe_mm_data( + topk_ids, + expert_offsets, + problem_sizes1, + problem_sizes2, + a_map, + c_map, + num_experts, + n, + k, + ) + + rep_a_q = a_q.view(dtype=torch.uint8)[a_map].view(dtype=a_q.dtype) + rep_a1_scales = a1_scale[a_map] + + c1 = torch.empty((m * topk, n * 2), dtype=out_dtype, device=device) + c2 = torch.empty((m * topk, k), dtype=out_dtype, device=device) + + ops.cutlass_blockwise_scaled_grouped_mm( + c1, + rep_a_q, + w1_q, + rep_a1_scales, + w1_scale, + problem_sizes1, + expert_offsets[:-1], + ) + + intermediate = torch.empty((m * topk, n), dtype=out_dtype, device=device) + torch.ops._C.silu_and_mul(intermediate, c1) + + intermediate_q, a2_scale = _fp8_quantize(intermediate, + A_scale=None, + per_act_token=False, + block_shape=[128, 128]) + + ops.cutlass_blockwise_scaled_grouped_mm( + c2, + intermediate_q, + w2_q, + a2_scale, + w2_scale, + problem_sizes2, + expert_offsets[:-1], + ) + + return (c2[c_map].view(m, topk, k) * + topk_weights.view(m, topk, 1).to(out_dtype)).sum(dim=1) diff --git a/vllm/model_executor/layers/fused_moe/fused_moe.py b/vllm/model_executor/layers/fused_moe/fused_moe.py index 041819bb7b..fbbccbb34d 100644 --- a/vllm/model_executor/layers/fused_moe/fused_moe.py +++ b/vllm/model_executor/layers/fused_moe/fused_moe.py @@ -15,6 +15,9 @@ from vllm.logger import init_logger # yapf: disable from vllm.model_executor.layers.fused_moe.config import ( FusedMoEQuantConfig, get_config_quant_dtype) +from vllm.model_executor.layers.fused_moe.cutlass_moe import ( + _valid_cutlass_block_scaled_grouped_gemm, + run_cutlass_block_scaled_fused_experts) # yapf: enable from vllm.model_executor.layers.fused_moe.deep_gemm_moe import ( _valid_deep_gemm, deep_gemm_moe_fp8) @@ -1129,29 +1132,31 @@ def dispatch_fused_experts_func(inplace: bool) -> Callable[..., torch.Tensor]: # TODO (bnell): replace this with modular op. Can get rid of inplace/outplace # torch ops. 
-def fused_experts(hidden_states: torch.Tensor, - w1: torch.Tensor, - w2: torch.Tensor, - topk_weights: torch.Tensor, - topk_ids: torch.Tensor, - inplace: bool = False, - activation: str = "silu", - apply_router_weight_on_input: bool = False, - use_fp8_w8a8: bool = False, - use_int8_w8a8: bool = False, - use_int8_w8a16: bool = False, - use_int4_w4a16: bool = False, - per_channel_quant: bool = False, - global_num_experts: int = -1, - expert_map: Optional[torch.Tensor] = None, - w1_scale: Optional[torch.Tensor] = None, - w2_scale: Optional[torch.Tensor] = None, - w1_zp: Optional[torch.Tensor] = None, - w2_zp: Optional[torch.Tensor] = None, - a1_scale: Optional[torch.Tensor] = None, - a2_scale: Optional[torch.Tensor] = None, - block_shape: Optional[list[int]] = None, - allow_deep_gemm: bool = False) -> torch.Tensor: +def fused_experts( + hidden_states: torch.Tensor, + w1: torch.Tensor, + w2: torch.Tensor, + topk_weights: torch.Tensor, + topk_ids: torch.Tensor, + inplace: bool = False, + activation: str = "silu", + apply_router_weight_on_input: bool = False, + use_fp8_w8a8: bool = False, + use_int8_w8a8: bool = False, + use_int8_w8a16: bool = False, + use_int4_w4a16: bool = False, + per_channel_quant: bool = False, + global_num_experts: int = -1, + expert_map: Optional[torch.Tensor] = None, + w1_scale: Optional[torch.Tensor] = None, + w2_scale: Optional[torch.Tensor] = None, + w1_zp: Optional[torch.Tensor] = None, + w2_zp: Optional[torch.Tensor] = None, + a1_scale: Optional[torch.Tensor] = None, + a2_scale: Optional[torch.Tensor] = None, + block_shape: Optional[list[int]] = None, + allow_deep_gemm: bool = False, + allow_cutlass_block_scaled_grouped_gemm: bool = False) -> torch.Tensor: # For now, disable DeepGemm for small N (<= 512) until better # permute/unpermute ops are available. N = w1.size(1) @@ -1174,6 +1179,17 @@ def fused_experts(hidden_states: torch.Tensor, a2_scale=a2_scale, apply_router_weight_on_input=apply_router_weight_on_input, ) + elif (allow_cutlass_block_scaled_grouped_gemm and use_fp8_w8a8 + and _valid_cutlass_block_scaled_grouped_gemm(hidden_states, w1, w2)): + assert apply_router_weight_on_input is False + return run_cutlass_block_scaled_fused_experts( + a=hidden_states, + w1=w1, + w2=w2, + w1_scale=w1_scale, + w2_scale=w2_scale, + topk_weights=topk_weights, + topk_ids=topk_ids) else: return dispatch_fused_experts_func(inplace)( hidden_states=hidden_states, diff --git a/vllm/model_executor/layers/quantization/fp8.py b/vllm/model_executor/layers/quantization/fp8.py index a7d221780b..5a1a427d7d 100644 --- a/vllm/model_executor/layers/quantization/fp8.py +++ b/vllm/model_executor/layers/quantization/fp8.py @@ -473,12 +473,30 @@ class Fp8MoEMethod(FusedMoEMethodBase): logger.warning_once( "DeepGemm not supported on the current platform.") + # Check for CutlassBlockScaledGroupedGemm support. + self.allow_cutlass_block_scaled_grouped_gemm = False + if not self.block_quant: + logger.warning_once("Model is not block quantized. Not using " + "CutlassBlockScaledGroupedGemm kernels") + elif (current_platform.is_cuda() + and current_platform.has_device_capability(100)): + logger.info_once( + "Using CutlassBlockScaledGroupedGemm kernels for Fp8MoEMethod." 
+ ) + self.allow_cutlass_block_scaled_grouped_gemm = True + else: + logger.warning_once( + "CutlassBlockScaledGroupedGemm not supported on the current " + "platform.") + self.topk_indices_dtype = None self.fused_experts = functools.partial( # type: ignore fused_experts, use_fp8_w8a8=True, block_shape=self.quant_config.weight_block_size, - allow_deep_gemm=self.allow_deep_gemm) + allow_deep_gemm=self.allow_deep_gemm, + allow_cutlass_block_scaled_grouped_gemm=( + self.allow_cutlass_block_scaled_grouped_gemm)) def create_weights(self, layer: Module, num_experts: int, hidden_size: int, intermediate_size_per_partition: int,
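Usage sketch (not part of the patch): the snippet below illustrates the calling convention of the new op as inferred from the TORCH_CHECKs in cutlass_blockwise_scaled_grouped_mm and from tests/kernels/moe/test_cutlass_grouped_gemm.py above. It assumes a Blackwell (SM100) device and a build in which the CMake block above defined ENABLE_CUTLASS_MOE_SM100; the group sizes and the unit scale factors are placeholders for illustration only (real callers would compute 1x128 per-token-group and 128x128 per-block scales, e.g. with the helpers shown in the test).

# Hedged sketch, not from the patch: shows expected shapes/dtypes only.
import torch
from vllm import _custom_ops as ops

num_experts, n, k = 4, 4096, 7168          # n and k must be multiples of 128
group_ms = [256, 384, 128, 256]            # rows per expert, tokens pre-sorted by expert
m = sum(group_ms)

# expert_offsets[i] = starting row of expert i in `a` / `output` (1-D int32)
expert_offsets = torch.tensor(
    [sum(group_ms[:i]) for i in range(num_experts)],
    dtype=torch.int32, device="cuda")
# problem_sizes[i] = (m_i, n, k) for expert i (int32, shape [num_experts, 3])
problem_sizes = torch.tensor([[gm, n, k] for gm in group_ms],
                             dtype=torch.int32, device="cuda")

# a: (m, k) fp8e4m3 with 1x128 per-token-group scales -> scales_a: (m, k // 128) fp32
a = torch.randn(m, k, device="cuda").to(torch.float8_e4m3fn)
scales_a = torch.ones(m, k // 128, dtype=torch.float32, device="cuda")
# b: (num_experts, n, k) fp8e4m3 with 128x128 block scales
# -> scales_b: (num_experts, n // 128, k // 128) fp32
b = torch.randn(num_experts, n, k, device="cuda").to(torch.float8_e4m3fn)
scales_b = torch.ones(num_experts, n // 128, k // 128,
                      dtype=torch.float32, device="cuda")

# output must be 2-D bfloat16 or float16, one row per input token
output = torch.empty(m, n, dtype=torch.bfloat16, device="cuda")
ops.cutlass_blockwise_scaled_grouped_mm(output, a, b, scales_a, scales_b,
                                        problem_sizes, expert_offsets)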