[Hardware/NVIDIA/Kernel] Enable nvidia/DeepSeek-R1-FP4 Model (#16362)

This commit is contained in:
Pavani Majety
2025-05-09 16:24:41 -07:00
committed by GitHub
parent 3b602cdea7
commit 0c0fdae84f
16 changed files with 1994 additions and 112 deletions

View File

@ -288,6 +288,7 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
"csrc/quantization/cutlass_w8a8/scaled_mm_entry.cu"
"csrc/quantization/fp4/nvfp4_quant_entry.cu"
"csrc/quantization/fp4/nvfp4_scaled_mm_entry.cu"
"csrc/quantization/fp4/nvfp4_blockwise_moe_kernel.cu"
"csrc/sparse/cutlass/sparse_scaled_mm_entry.cu"
"csrc/cutlass_extensions/common.cpp"
"csrc/attention/mla/cutlass_mla_entry.cu")
@ -495,7 +496,9 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER 12.8 AND FP4_ARCHS)
set(SRCS
"csrc/quantization/fp4/nvfp4_quant_kernels.cu"
"csrc/quantization/fp4/nvfp4_scaled_mm_kernels.cu")
"csrc/quantization/fp4/nvfp4_experts_quant.cu"
"csrc/quantization/fp4/nvfp4_scaled_mm_kernels.cu"
"csrc/quantization/fp4/nvfp4_blockwise_moe_kernel.cu")
set_gencode_flags_for_srcs(
SRCS "${SRCS}"
CUDA_ARCHS "${FP4_ARCHS}")
@ -533,7 +536,7 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
# The MoE kernel cutlass_moe_mm requires CUDA 12.3 or later (and only works
# on Hopper). get_cutlass_moe_mm_data should only be compiled if it's possible
# to compile MoE kernels that use its output.
cuda_archs_loose_intersection(SCALED_MM_ARCHS "9.0a;" "${CUDA_ARCHS}")
cuda_archs_loose_intersection(SCALED_MM_ARCHS "9.0a;10.0a" "${CUDA_ARCHS}")
if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 12.3 AND SCALED_MM_ARCHS)
set(SRCS "csrc/quantization/cutlass_w8a8/moe/grouped_mm_c3x.cu"
"csrc/quantization/cutlass_w8a8/moe/moe_data.cu")

View File

@ -0,0 +1,408 @@
# SPDX-License-Identifier: Apache-2.0
"""
Benchmark the performance of the cutlass_moe_fp4 kernel vs the triton_moe
kernel. The cutlass_moe_fp4 kernel takes in fp4 quantized weights and 16-bit
activations. The triton_moe kernel takes in fp8 weights(tensor scaled to fp8)
and 16-bit activations.
"""
import nvtx
import torch
import torch.utils.benchmark as benchmark
from vllm import _custom_ops as ops
from vllm.config import ParallelConfig, VllmConfig, set_current_vllm_config
from vllm.model_executor.layers.fused_moe.cutlass_moe import cutlass_moe_fp4
from vllm.model_executor.layers.fused_moe.fused_moe import (fused_experts,
fused_topk)
from vllm.scalar_type import scalar_types
from vllm.utils import FlexibleArgumentParser
WEIGHT_SHAPES_MOE = {
"nvidia/DeepSeek-R1-FP4": [
[256, 8, 2048, 7168],
],
}
DEFAULT_MODELS = [
"nvidia/DeepSeek-R1-FP4",
]
DEFAULT_BATCH_SIZES = [4, 8, 16, 32, 64, 128, 256, 512, 1024, 2048]
DEFAULT_TP_SIZES = [1]
PER_ACT_TOKEN_OPTS = [False]
PER_OUT_CH_OPTS = [False]
FLOAT4_E2M1_MAX = scalar_types.float4_e2m1f.max()
FLOAT8_E4M3_MAX = torch.finfo(torch.float8_e4m3fn).max
def to_fp8(tensor: torch.Tensor):
finfo = torch.finfo(torch.float8_e4m3fn)
return torch.round(tensor.clamp(
min=finfo.min, max=finfo.max)).to(dtype=torch.float8_e4m3fn)
def bench_run(results: list[benchmark.Measurement], model: str,
num_experts: int, topk: int, per_act_token: bool,
per_out_ch: bool, mkn: tuple[int, int, int]):
label = "NVFP4 Blockscaled CUTLASS MOE vs FP8 Tensor Scaled Triton"
sub_label = (
"{}, num_experts={}, topk={}, per_act_token={} per_out_ch={}, "
"MKN=({})".format(model, num_experts, topk, per_act_token, per_out_ch,
mkn))
print(f"Testing: {sub_label}")
(m, k, n) = mkn
dtype = torch.half
device = "cuda"
a = torch.randn((m, k), device=device, dtype=dtype) / 10
w1 = torch.randn((num_experts, 2 * n, k), device=device, dtype=dtype) / 10
w2 = torch.randn((num_experts, k, n), device=device, dtype=dtype) / 10
_, a_fp8_scale = ops.scaled_fp8_quant(a)
w1_fp8q = torch.empty((num_experts, 2 * n, k),
device=device,
dtype=torch.float8_e4m3fn)
w2_fp8q = torch.empty((num_experts, k, n),
device=device,
dtype=torch.float8_e4m3fn)
w1_fp8scale = torch.empty((num_experts, 1, 1),
device=device,
dtype=torch.float32)
w2_fp8scale = torch.empty((num_experts, 1, 1),
device=device,
dtype=torch.float32)
for expert in range(num_experts):
w1_fp8q[expert], w1_fp8scale[expert] = ops.scaled_fp8_quant(w1[expert])
w2_fp8q[expert], w2_fp8scale[expert] = ops.scaled_fp8_quant(w2[expert])
w1_fp8q_notransp = w1_fp8q.clone()
w2_fp8q_notransp = w2_fp8q.clone()
w1_fp8q = w1_fp8q.transpose(1, 2)
w2_fp8q = w2_fp8q.transpose(1, 2)
score = torch.randn((m, num_experts), device=device, dtype=dtype)
topk_weights, topk_ids = fused_topk(a, score, topk, renormalize=False)
quant_blocksize = 16
w1_blockscale = torch.empty((num_experts, 2 * n, k // quant_blocksize),
device=device,
dtype=torch.float8_e4m3fn)
w2_blockscale = torch.empty((num_experts, k, n // quant_blocksize),
device=device,
dtype=torch.float8_e4m3fn)
# n_b_scales = 2 * n if per_out_ch else 1
# k_b_scales = k if per_out_ch else 1
w1_fp4 = torch.empty((num_experts, 2 * n, k // 2),
device=device,
dtype=torch.uint8)
w2_fp4 = torch.empty((num_experts, k, n // 2),
device=device,
dtype=torch.uint8)
w1_gs = torch.empty((num_experts, ), device=device, dtype=torch.float32)
w2_gs = torch.empty((num_experts, ), device=device, dtype=torch.float32)
a1_gs = torch.ones((num_experts, ), device=device, dtype=torch.float32)
a2_gs = torch.ones((num_experts, ), device=device, dtype=torch.float32)
for expert in range(num_experts):
w1_e = w1[expert]
w2_e = w2[expert]
w1_amax = torch.abs(w1_e).max().to(torch.float32)
w2_amax = torch.abs(w2_e).max().to(torch.float32)
w1_gs[expert] = FLOAT8_E4M3_MAX * FLOAT4_E2M1_MAX / w1_amax
w2_gs[expert] = FLOAT8_E4M3_MAX * FLOAT4_E2M1_MAX / w2_amax
w1_fp4[expert], w1_blockscale[expert] = ops.scaled_fp4_quant(
w1_e, w1_gs[expert])
w2_fp4[expert], w2_blockscale[expert] = ops.scaled_fp4_quant(
w2_e, w2_gs[expert])
def run_triton_moe(a: torch.Tensor, w1: torch.Tensor, w2: torch.Tensor,
topk_weights: torch.Tensor, topk_ids: torch.Tensor,
w1_scale: torch.Tensor, w2_scale: torch.Tensor,
a_fp8_scale: torch.Tensor, num_repeats: int):
for _ in range(num_repeats):
fused_experts(a,
w1,
w2,
topk_weights,
topk_ids,
use_fp8_w8a8=True,
w1_scale=w1_scale,
w2_scale=w2_scale,
a1_scale=a_fp8_scale)
def run_cutlass_moe_fp4(a: torch.Tensor, w1_fp4: torch.Tensor,
w2_fp4: torch.Tensor, w1_blockscale: torch.Tensor,
w2_blockscale: torch.Tensor, w1_gs: torch.Tensor,
w2_gs: torch.Tensor, a1_gs: torch.Tensor,
a2_gs: torch.Tensor, topk_weights: torch.Tensor,
topk_ids: torch.Tensor, m: int, n: int, k: int,
e: int, device: torch.device, num_repeats: int):
for _ in range(num_repeats):
with nvtx.annotate("cutlass_moe_fp4", color="green"):
cutlass_moe_fp4(a=a,
a1_gscale=a1_gs,
a2_gscale=a2_gs,
w1_fp4=w1_fp4,
w1_blockscale=w1_blockscale,
w1_alphas=w1_gs,
w2_fp4=w2_fp4,
w2_blockscale=w2_blockscale,
w2_alphas=w2_gs,
topk_weights=topk_weights,
topk_ids=topk_ids,
m=m,
n=n,
k=k,
e=num_experts,
device=device)
def run_cutlass_from_graph(
a: torch.Tensor, a1_gscale: torch.Tensor, w1_fp4: torch.Tensor,
w1_blockscale: torch.Tensor, w1_alphas: torch.Tensor,
a2_gscale: torch.Tensor, w2_fp4: torch.Tensor,
w2_blockscale: torch.Tensor, w2_alphas: torch.Tensor,
topk_weights: torch.Tensor, topk_ids: torch.Tensor, m: int, n: int,
k: int, e: int, device: torch.device):
with set_current_vllm_config(
VllmConfig(parallel_config=ParallelConfig(
pipeline_parallel_size=1))):
return cutlass_moe_fp4(a=a,
a1_gscale=a1_gs,
w1_fp4=w1_fp4,
w1_blockscale=w1_blockscale,
w1_alphas=w1_alphas,
a2_gscale=a2_gs,
w2_fp4=w2_fp4,
w2_blockscale=w2_blockscale,
w2_alphas=w2_alphas,
topk_weights=topk_weights,
topk_ids=topk_ids,
m=m,
n=n,
k=k,
e=num_experts,
device=device)
def run_triton_from_graph(a: torch.Tensor, w1: torch.Tensor,
w2: torch.Tensor, topk_weights: torch.Tensor,
topk_ids: torch.Tensor, w1_scale: torch.Tensor,
w2_scale: torch.Tensor,
a_fp8_scale: torch.Tensor):
with set_current_vllm_config(
VllmConfig(parallel_config=ParallelConfig(
pipeline_parallel_size=1))):
return fused_experts(a,
w1,
w2,
topk_weights,
topk_ids,
use_fp8_w8a8=True,
w1_scale=w1_scale,
w2_scale=w2_scale,
a1_scale=a_fp8_scale)
def replay_graph(graph, num_repeats):
for _ in range(num_repeats):
graph.replay()
torch.cuda.synchronize()
cutlass_stream = torch.cuda.Stream()
cutlass_graph = torch.cuda.CUDAGraph()
with torch.cuda.graph(cutlass_graph, stream=cutlass_stream):
run_cutlass_from_graph(a=a,
a1_gscale=a1_gs,
w1_fp4=w1_fp4,
w1_blockscale=w1_blockscale,
w1_alphas=w1_gs,
a2_gscale=a2_gs,
w2_fp4=w2_fp4,
w2_blockscale=w2_blockscale,
w2_alphas=w2_gs,
topk_weights=topk_weights,
topk_ids=topk_ids,
m=m,
n=n,
k=k,
e=num_experts,
device=device)
torch.cuda.synchronize()
triton_stream = torch.cuda.Stream()
triton_graph = torch.cuda.CUDAGraph()
with torch.cuda.graph(triton_graph, stream=triton_stream):
run_triton_from_graph(a, w1_fp8q_notransp, w2_fp8q_notransp,
topk_weights, topk_ids, w1_fp8scale, w2_fp8scale,
a_fp8_scale)
torch.cuda.synchronize()
min_run_time = 5
num_warmup = 5
num_runs = 25
globals = {
# Baseline params
"w1": w1,
"w2": w2,
"score": score,
"topk": topk,
"w1_fp8q_notransp": w1_fp8q_notransp,
"w2_fp8q_notransp": w2_fp8q_notransp,
"w1_fp8scale": w1_fp8scale,
"w2_fp8scale": w2_fp8scale,
"a_fp8_scale": a_fp8_scale,
# Cutlass params
"a": a,
"a1_gscale": a1_gs,
"w1_fp4": w1_fp4,
"w1_blockscale": w1_blockscale,
"w1_alphas": w1_gs,
"a2_gscale": a2_gs,
"w2_fp4": w2_fp4,
"w2_blockscale": w2_blockscale,
"w2_alphas": w2_gs,
"topk_weights": topk_weights,
"topk_ids": topk_ids,
"m": m,
"n": n,
"k": k,
"e": num_experts,
"device": device,
# cuda graph params
"cutlass_graph": cutlass_graph,
"triton_graph": triton_graph,
# Gen params
"num_runs": num_runs,
# Kernels
"run_triton_moe": run_triton_moe,
"run_cutlass_moe_fp4": run_cutlass_moe_fp4,
"replay_graph": replay_graph,
}
# Warmup
run_triton_moe(a, w1_fp8q_notransp, w2_fp8q_notransp, topk_weights,
topk_ids, w1_fp8scale, w2_fp8scale, a_fp8_scale, num_warmup)
results.append(
benchmark.Timer(
stmt=
"run_triton_moe(a, w1_fp8q_notransp, w2_fp8q_notransp, topk_weights, topk_ids, w1_fp8scale, w2_fp8scale, a_fp8_scale, num_runs)", # noqa: E501
globals=globals,
label=label,
sub_label=sub_label,
description="triton_moe",
).blocked_autorange(min_run_time=min_run_time))
# Warmup
replay_graph(triton_graph, num_warmup)
results.append(
benchmark.Timer(
stmt="replay_graph(triton_graph, num_runs)",
globals=globals,
label=label,
sub_label=sub_label,
description="triton_moe_cuda_graphs",
).blocked_autorange(min_run_time=min_run_time))
# Warmup
run_cutlass_moe_fp4(a, w1_fp4, w2_fp4, w1_blockscale, w2_blockscale, w1_gs,
w2_gs, a1_gs, a2_gs, topk_weights, topk_ids, m, n, k,
num_experts, device, num_warmup)
results.append(
benchmark.Timer(
stmt=
"run_cutlass_moe_fp4(a, w1_fp4, w2_fp4, w1_blockscale, w2_blockscale, w1_alphas, w2_alphas, a1_gscale, a2_gscale, topk_weights, topk_ids, m, n, k, e, device, num_runs)", # noqa: E501
globals=globals,
label=label,
sub_label=sub_label,
description="cutlass_moe_fp4",
).blocked_autorange(min_run_time=min_run_time))
# Warmup
replay_graph(cutlass_graph, num_warmup)
results.append(
benchmark.Timer(
stmt="replay_graph(cutlass_graph, num_runs)",
globals=globals,
label=label,
sub_label=sub_label,
description="cutlass_moe_fp4_cuda_graphs",
).blocked_autorange(min_run_time=min_run_time))
def main(args):
print("Benchmarking models:")
for i, model in enumerate(args.models):
print(f"[{i}] {model}")
results: list[benchmark.Measurement] = []
for model in args.models:
for tp in args.tp_sizes:
for layer in WEIGHT_SHAPES_MOE[model]:
num_experts = layer[0]
topk = layer[1]
size_k = layer[2]
size_n = layer[3] // tp
if len(args.limit_k) > 0 and size_k not in args.limit_k:
continue
if len(args.limit_n) > 0 and size_n not in args.limit_n:
continue
for per_act_token in PER_ACT_TOKEN_OPTS:
for per_out_ch in PER_OUT_CH_OPTS:
for size_m in args.batch_sizes:
mkn = (size_m, size_k, size_n)
bench_run(results, model, num_experts, topk,
per_act_token, per_out_ch, mkn)
compare = benchmark.Compare(results)
compare.print()
if __name__ == "__main__":
parser = FlexibleArgumentParser(
description="Benchmark NVFP4 CUTLASS MOE across specified "
"models/shapes/batches")
parser.add_argument(
"--models",
nargs="+",
type=str,
default=DEFAULT_MODELS,
choices=WEIGHT_SHAPES_MOE.keys(),
)
parser.add_argument("--tp-sizes",
nargs="+",
type=int,
default=DEFAULT_TP_SIZES)
parser.add_argument("--batch-sizes",
nargs="+",
type=int,
default=DEFAULT_BATCH_SIZES)
parser.add_argument("--limit-k", nargs="+", type=int, default=[])
parser.add_argument("--limit-n", nargs="+", type=int, default=[])
parser.add_argument("--limit-num-groups", nargs="+", type=int, default=[])
parser.add_argument("--limit-per-act-token",
nargs="+",
type=int,
default=[])
parser.add_argument("--limit-per-out-ch", nargs="+", type=int, default=[])
args = parser.parse_args()
main(args)

View File

@ -208,6 +208,12 @@ void cutlass_moe_mm(
torch::Tensor const& problem_sizes, torch::Tensor const& a_strides,
torch::Tensor const& b_strides, torch::Tensor const& c_strides);
void cutlass_fp4_group_mm(
torch::Tensor& output, const torch::Tensor& a, const torch::Tensor& b,
const torch::Tensor& a_blockscale, const torch::Tensor& b_blockscales,
const torch::Tensor& alphas, const torch::Tensor& problem_sizes,
const torch::Tensor& expert_offsets, const torch::Tensor& sf_offsets);
void get_cutlass_moe_mm_data(
const torch::Tensor& topk_ids, torch::Tensor& expert_offsets,
torch::Tensor& problem_sizes1, torch::Tensor& problem_sizes2,
@ -235,6 +241,12 @@ std::vector<torch::Tensor> cutlass_sparse_compress(torch::Tensor const& a);
void scaled_fp4_quant(torch::Tensor& output, torch::Tensor const& input,
torch::Tensor& output_scale,
torch::Tensor const& input_scale);
void scaled_fp4_experts_quant(
torch::Tensor& output, torch::Tensor& output_scale,
torch::Tensor const& input, torch::Tensor const& input_global_scale,
torch::Tensor const& input_offset_by_experts,
torch::Tensor const& output_scale_offset_by_experts);
#endif
void static_scaled_int8_quant(torch::Tensor& out, torch::Tensor const& input,

View File

@ -37,12 +37,6 @@ void cutlass_moe_mm_sm90(
torch::Tensor const& problem_sizes, torch::Tensor const& a_strides,
torch::Tensor const& b_strides, torch::Tensor const& c_strides);
void get_cutlass_moe_mm_data_caller(
const torch::Tensor& topk_ids, torch::Tensor& expert_offsets,
torch::Tensor& problem_sizes1, torch::Tensor& problem_sizes2,
torch::Tensor& input_permutation, torch::Tensor& output_permutation,
const int64_t num_experts, const int64_t n, const int64_t k);
#endif
#if defined ENABLE_SCALED_MM_SM100 && ENABLE_SCALED_MM_SM100
@ -53,6 +47,15 @@ void cutlass_scaled_mm_sm100(torch::Tensor& c, torch::Tensor const& a,
std::optional<torch::Tensor> const& bias);
#endif
#if defined(ENABLE_SCALED_MM_SM90) && ENABLE_SCALED_MM_SM90 || \
defined(ENABLE_SCALED_MM_SM100) && ENABLE_SCALED_MM_SM100
void get_cutlass_moe_mm_data_caller(
const torch::Tensor& topk_ids, torch::Tensor& expert_offsets,
torch::Tensor& problem_sizes1, torch::Tensor& problem_sizes2,
torch::Tensor& input_permutation, torch::Tensor& output_permutation,
const int64_t num_experts, const int64_t n, const int64_t k);
#endif
void cutlass_scaled_mm_azp_sm75(torch::Tensor& c, torch::Tensor const& a,
torch::Tensor const& b,
torch::Tensor const& a_scales,
@ -224,7 +227,8 @@ void get_cutlass_moe_mm_data(
// This function currently gets compiled only if we have a valid cutlass moe
// mm to run it for.
int32_t version_num = get_sm_version_num();
#if defined ENABLE_CUTLASS_MOE_SM90 && ENABLE_CUTLASS_MOE_SM90
#if (defined ENABLE_CUTLASS_MOE_SM90 && ENABLE_CUTLASS_MOE_SM90) || \
(defined ENABLE_SCALED_MM_SM100 && ENABLE_SCALED_MM_SM90)
get_cutlass_moe_mm_data_caller(topk_ids, expert_offsets, problem_sizes1,
problem_sizes2, input_permutation,
output_permutation, num_experts, n, k);

View File

@ -0,0 +1,402 @@
#include <torch/all.h>
#include <cutlass/arch/arch.h>
#include <ATen/cuda/CUDAContext.h>
#include <c10/cuda/CUDAGuard.h>
#include <c10/cuda/CUDAStream.h>
#include "cute/tensor.hpp"
#include "cutlass/tensor_ref.h"
#include "cutlass/epilogue/collective/default_epilogue.hpp"
#include "cutlass/epilogue/thread/linear_combination.h"
#include "cutlass/gemm/dispatch_policy.hpp"
#include "cutlass/gemm/group_array_problem_shape.hpp"
#include "cutlass/gemm/collective/collective_builder.hpp"
#include "cutlass/epilogue/collective/collective_builder.hpp"
#include "cutlass/gemm/device/gemm_universal_adapter.h"
#include "cutlass/gemm/kernel/gemm_universal.hpp"
#include "cutlass/util/command_line.h"
#include "cutlass/util/distribution.h"
#include "cutlass/util/host_tensor.h"
#include "cutlass/util/packed_stride.hpp"
#include "cutlass/util/tensor_view_io.h"
#include "cutlass/util/reference/device/gemm.h"
#include "cutlass/util/reference/device/tensor_compare.h"
#include "cutlass/util/reference/host/tensor_fill.h"
#include "cutlass/util/reference/host/gett.hpp"
#include "cutlass/util/reference/host/tensor_norm.h"
#include "cutlass/util/reference/host/tensor_compare.h"
#include <cassert>
using namespace cute;
template <typename ElementAB, typename ElementC, typename ElementSF,
typename ElementAccumulator, typename LayoutSFA, typename LayoutSFB,
typename ScaleConfig>
__global__ void __get_group_gemm_starts(
ElementAB** a_offsets, ElementAB** b_offsets, ElementC** out_offsets,
ElementSF** a_scales_offsets, ElementSF** b_scales_offsets,
ElementAccumulator** alpha_offsets, LayoutSFA* layout_sfa_base_as_int,
LayoutSFB* layout_sfb_base_as_int, ElementAB* a_base_as_int,
ElementAB* b_base_as_int, ElementC* out_base_as_int,
ElementSF* a_scales_base_as_int, ElementSF* b_scales_base_as_int,
ElementAccumulator* alphas_base_as_int, const int32_t* expert_offsets,
const int32_t* sf_offsets, const int32_t* problem_sizes_as_shapes,
const int K, const int N) {
int64_t expert_id = threadIdx.x;
if (expert_id >= gridDim.x * blockDim.x) {
return;
}
// Originally int32_t but upcasting to int64_t to avoid overflow
// during offset calculations
int64_t expert_offset = static_cast<int64_t>(expert_offsets[expert_id]);
int64_t sf_offset = static_cast<int64_t>(sf_offsets[expert_id]);
// size for block in block scale.
int64_t group_size = 16;
int64_t m = static_cast<int64_t>(problem_sizes_as_shapes[expert_id * 3]);
int64_t n = static_cast<int64_t>(problem_sizes_as_shapes[expert_id * 3 + 1]);
int64_t k = static_cast<int64_t>(problem_sizes_as_shapes[expert_id * 3 + 2]);
assert((m >= 0 && n == N && k == K && k % 2 == 0) &&
"unexpected problem sizes");
int64_t half_k = static_cast<int64_t>(k / 2);
int64_t group_k = static_cast<int64_t>(k / group_size);
// Shape of A as uint8/byte = [M, K // 2]
// Shape of B as uint8/byte = [E, N, K // 2]
a_offsets[expert_id] = a_base_as_int + expert_offset * half_k;
b_offsets[expert_id] = b_base_as_int + expert_id * n * half_k;
// Shape of C = [M, N]
out_offsets[expert_id] = out_base_as_int + expert_offset * n;
// Shape of a_scale = [sum(sf_sizes), K // group_size]
a_scales_offsets[expert_id] = a_scales_base_as_int + sf_offset * group_k;
assert((reinterpret_cast<uintptr_t>(a_scales_offsets[expert_id]) % 128) ==
0 &&
"TMA requires 128-byte alignment");
// Shape of B scale = [E, N, K // group_size]
b_scales_offsets[expert_id] = b_scales_base_as_int + expert_id * n * group_k;
assert((reinterpret_cast<uintptr_t>(b_scales_offsets[expert_id]) % 128) ==
0 &&
"TMA requires 128-byte alignment");
// Shape of alpha = [E]
alpha_offsets[expert_id] = alphas_base_as_int + expert_id;
LayoutSFA* layout_sfa_ptr = layout_sfa_base_as_int + expert_id;
LayoutSFB* layout_sfb_ptr = layout_sfb_base_as_int + expert_id;
*layout_sfa_ptr = ScaleConfig::tile_atom_to_shape_SFA(cute::make_shape(
static_cast<int>(m), static_cast<int>(n), static_cast<int>(k), 1));
*layout_sfb_ptr = ScaleConfig::tile_atom_to_shape_SFB(cute::make_shape(
static_cast<int>(m), static_cast<int>(n), static_cast<int>(k), 1));
}
#define __CALL_GET_STARTS_KERNEL_BLOCKSCALE(ELEMENT_AB_TYPE, SF_TYPE, \
TENSOR_C_TYPE, C_TYPE, LayoutSFA, \
LayoutSFB, ScaleConfig) \
else if (out_tensors.dtype() == TENSOR_C_TYPE) { \
__get_group_gemm_starts<ELEMENT_AB_TYPE, C_TYPE, SF_TYPE, float, \
LayoutSFA, LayoutSFB, ScaleConfig> \
<<<1, num_experts, 0, stream>>>( \
static_cast<ELEMENT_AB_TYPE**>(a_starts.data_ptr()), \
static_cast<ELEMENT_AB_TYPE**>(b_starts.data_ptr()), \
static_cast<C_TYPE**>(out_starts.data_ptr()), \
static_cast<SF_TYPE**>(a_scales_starts.data_ptr()), \
static_cast<SF_TYPE**>(b_scales_starts.data_ptr()), \
static_cast<float**>(alpha_starts.data_ptr()), \
reinterpret_cast<LayoutSFA*>(layout_sfa.data_ptr()), \
reinterpret_cast<LayoutSFB*>(layout_sfb.data_ptr()), \
static_cast<ELEMENT_AB_TYPE*>(a_tensors.data_ptr()), \
static_cast<ELEMENT_AB_TYPE*>(b_tensors.data_ptr()), \
static_cast<C_TYPE*>(out_tensors.data_ptr()), \
static_cast<SF_TYPE*>(a_scales.data_ptr()), \
static_cast<SF_TYPE*>(b_scales.data_ptr()), \
static_cast<float*>(alphas.data_ptr()), \
static_cast<int32_t*>(expert_offsets.data_ptr()), \
static_cast<int32_t*>(sf_offsets.data_ptr()), \
static_cast<int32_t*>(problem_sizes.data_ptr()), K, N); \
}
template <typename LayoutSFA, typename LayoutSFB, typename ScaleConfig>
void run_get_group_gemm_starts(
const torch::Tensor& a_starts, const torch::Tensor& b_starts,
const torch::Tensor& out_starts, const torch::Tensor& a_scales_starts,
const torch::Tensor& b_scales_starts, const torch::Tensor& alpha_starts,
const torch::Tensor& layout_sfa, const torch::Tensor& layout_sfb,
/*these are used for their base addresses*/
torch::Tensor const& a_tensors, torch::Tensor const& b_tensors,
torch::Tensor const& out_tensors, torch::Tensor const& a_scales,
torch::Tensor const& b_scales, torch::Tensor const& alphas,
torch::Tensor const& expert_offsets, torch::Tensor const& sf_offsets,
torch::Tensor const& problem_sizes, int M, int N, int K) {
int num_experts = (int)expert_offsets.size(0);
auto stream = at::cuda::getCurrentCUDAStream(a_tensors.device().index());
TORCH_CHECK(out_tensors.size(1) == N,
"Output tensor shape doesn't match expected shape");
TORCH_CHECK(K / 2 == b_tensors.size(2),
"b_tensors(dim = 2) and a_tensors(dim = 1) trailing"
" dimension must match");
if (false) {
}
//(ELEMENT_AB_TYPE, BS_TYPE, TENSOR_C_TYPE, C_TYPE, LayoutSFA, LayoutSFB,
// ScaleConfig)
__CALL_GET_STARTS_KERNEL_BLOCKSCALE(
cutlass::float_e2m1_t, cutlass::float_ue4m3_t, torch::kBFloat16,
cutlass::bfloat16_t, LayoutSFA, LayoutSFB, ScaleConfig)
__CALL_GET_STARTS_KERNEL_BLOCKSCALE(cutlass::float_e2m1_t,
cutlass::float_ue4m3_t, torch::kFloat16,
half, LayoutSFA, LayoutSFB, ScaleConfig)
else {
TORCH_CHECK(false, "Invalid output type (must be float16 or bfloat16)");
}
}
template <typename OutType>
void run_fp4_blockwise_scaled_group_mm(
torch::Tensor& output, const torch::Tensor& a, const torch::Tensor& b,
const torch::Tensor& a_blockscale, const torch::Tensor& b_blockscales,
const torch::Tensor& alphas, const torch::Tensor& problem_sizes,
const torch::Tensor& expert_offsets, const torch::Tensor& sf_offsets, int M,
int N, int K) {
using ProblemShape =
cutlass::gemm::GroupProblemShape<Shape<int32_t, int32_t, int32_t>>;
using ElementType = cutlass::float_e2m1_t;
using ElementSFType = cutlass::float_ue4m3_t;
using ElementA = cutlass::nv_float4_t<cutlass::float_e2m1_t>;
using ElementB = cutlass::nv_float4_t<cutlass::float_e2m1_t>;
using ElementC = OutType;
using ElementD = ElementC;
using ElementAccumulator = float;
// Layout definitions
using LayoutA = cutlass::layout::RowMajor;
using LayoutB = cutlass::layout::ColumnMajor;
using LayoutC = cutlass::layout::RowMajor;
using LayoutD = LayoutC;
// Alignment constraints
static constexpr int AlignmentA = 32;
static constexpr int AlignmentB = 32;
static constexpr int AlignmentC = 128 / cutlass::sizeof_bits<ElementC>::value;
static constexpr int AlignmentD = 128 / cutlass::sizeof_bits<ElementD>::value;
// Architecture definitions
using ArchTag = cutlass::arch::Sm100;
using EpilogueOperatorClass =
cutlass::arch::OpClassTensorOp; // Epilogue Operator class tag
using MainloopOperatorClass =
cutlass::arch::OpClassBlockScaledTensorOp; // Mainloop Operator class tag
using StageCountType =
cutlass::gemm::collective::StageCountAuto; // Stage count maximized based
// on the tile size
using ClusterShape = Shape<_1, _1, _1>;
struct MMA1SMConfig {
using MmaTileShape = Shape<_128, _128, _128>;
using KernelSchedule = cutlass::gemm::
KernelPtrArrayTmaWarpSpecialized1SmNvf4Sm100; // Kernel to launch
using EpilogueSchedule =
cutlass::epilogue::PtrArrayTmaWarpSpecialized1Sm; // Epilogue to launch
};
using CollectiveEpilogue =
typename cutlass::epilogue::collective::CollectiveBuilder<
ArchTag, EpilogueOperatorClass, typename MMA1SMConfig::MmaTileShape,
ClusterShape, Shape<_128, _64>, ElementAccumulator,
ElementAccumulator, ElementC, LayoutC*, AlignmentC, ElementD,
LayoutC*, AlignmentD,
typename MMA1SMConfig::EpilogueSchedule>::CollectiveOp;
using CollectiveMainloop =
typename cutlass::gemm::collective::CollectiveBuilder<
ArchTag, MainloopOperatorClass, ElementA, LayoutA*, AlignmentA,
ElementB, LayoutB*, AlignmentB, ElementAccumulator,
typename MMA1SMConfig::MmaTileShape, ClusterShape,
cutlass::gemm::collective::StageCountAutoCarveout<static_cast<int>(
sizeof(typename CollectiveEpilogue::SharedStorage))>,
typename MMA1SMConfig::KernelSchedule>::CollectiveOp;
using GemmKernel =
cutlass::gemm::kernel::GemmUniversal<ProblemShape, CollectiveMainloop,
CollectiveEpilogue>;
using Gemm1SM = cutlass::gemm::device::GemmUniversalAdapter<GemmKernel>;
using Gemm = Gemm1SM;
using StrideA = typename Gemm::GemmKernel::InternalStrideA;
using StrideB = typename Gemm::GemmKernel::InternalStrideB;
using StrideC = typename Gemm::GemmKernel::InternalStrideC;
using StrideD = typename Gemm::GemmKernel::InternalStrideD;
using LayoutSFA =
typename Gemm::GemmKernel::CollectiveMainloop::InternalLayoutSFA;
using LayoutSFB =
typename Gemm::GemmKernel::CollectiveMainloop::InternalLayoutSFB;
using ScaleConfig =
typename Gemm::GemmKernel::CollectiveMainloop::Sm1xxBlkScaledConfig;
using UnderlyingProblemShape = ProblemShape::UnderlyingProblemShape;
int num_experts = static_cast<int>(expert_offsets.size(0));
auto options_int =
torch::TensorOptions().dtype(torch::kInt64).device(a.device());
torch::Tensor a_ptrs = torch::empty(num_experts, options_int);
torch::Tensor b_ptrs = torch::empty(num_experts, options_int);
torch::Tensor out_ptrs = torch::empty(num_experts, options_int);
torch::Tensor a_scales_ptrs = torch::empty(num_experts, options_int);
torch::Tensor b_scales_ptrs = torch::empty(num_experts, options_int);
torch::Tensor alpha_ptrs = torch::empty(num_experts, options_int);
torch::Tensor layout_sfa = torch::empty({num_experts, 5}, options_int);
torch::Tensor layout_sfb = torch::empty({num_experts, 5}, options_int);
torch::Tensor c_strides1 =
torch::full({num_experts}, output.stride(0), options_int);
torch::Tensor a_strides1 =
torch::full({num_experts}, a.stride(0) * 2, options_int);
torch::Tensor b_strides1 =
torch::full({num_experts}, b.stride(1) * 2, options_int);
run_get_group_gemm_starts<LayoutSFA, LayoutSFB, ScaleConfig>(
a_ptrs, b_ptrs, out_ptrs, a_scales_ptrs, b_scales_ptrs, alpha_ptrs,
layout_sfa, layout_sfb, a, b, output, a_blockscale, b_blockscales, alphas,
expert_offsets, sf_offsets, problem_sizes, M, N, K);
// Create an instance of the GEMM
Gemm gemm_op;
// Initialize problem_sizes_as_shapes correctly
UnderlyingProblemShape* problem_sizes_as_shapes =
static_cast<UnderlyingProblemShape*>(problem_sizes.data_ptr());
// Set the Scheduler info
cutlass::KernelHardwareInfo hw_info;
using RasterOrderOptions = typename cutlass::gemm::kernel::detail::
PersistentTileSchedulerSm100GroupParams<
typename ProblemShape::UnderlyingProblemShape>::RasterOrderOptions;
typename Gemm::GemmKernel::TileSchedulerArguments scheduler;
scheduler.raster_order = RasterOrderOptions::AlongM;
hw_info.device_id = a.get_device();
static std::unordered_map<int, int> cached_sm_counts;
if (cached_sm_counts.find(hw_info.device_id) == cached_sm_counts.end()) {
cached_sm_counts[hw_info.device_id] =
cutlass::KernelHardwareInfo::query_device_multiprocessor_count(
hw_info.device_id);
}
hw_info.sm_count = min(cached_sm_counts[hw_info.device_id], INT_MAX);
// Mainloop Arguments
typename GemmKernel::MainloopArguments mainloop_args{
static_cast<const ElementType**>(a_ptrs.data_ptr()),
static_cast<StrideA*>(a_strides1.data_ptr()),
static_cast<const ElementType**>(b_ptrs.data_ptr()),
static_cast<StrideB*>(b_strides1.data_ptr()),
static_cast<const ElementSFType**>(a_scales_ptrs.data_ptr()),
reinterpret_cast<LayoutSFA*>(layout_sfa.data_ptr()),
static_cast<const ElementSFType**>(b_scales_ptrs.data_ptr()),
reinterpret_cast<LayoutSFB*>(layout_sfb.data_ptr())};
// Epilogue Arguments
typename GemmKernel::EpilogueArguments epilogue_args{
{}, // epilogue.thread
nullptr,
static_cast<StrideC*>(c_strides1.data_ptr()),
static_cast<ElementD**>(out_ptrs.data_ptr()),
static_cast<StrideC*>(c_strides1.data_ptr())};
auto& fusion_args = epilogue_args.thread;
fusion_args.alpha_ptr_array =
reinterpret_cast<float**>(alpha_ptrs.data_ptr());
fusion_args.dAlpha = {_0{}, _0{}, 1};
// Gemm Arguments
typename GemmKernel::Arguments args{
cutlass::gemm::GemmUniversalMode::kGrouped,
{num_experts, problem_sizes_as_shapes, nullptr},
mainloop_args,
epilogue_args,
hw_info,
scheduler};
size_t workspace_size = Gemm::get_workspace_size(args);
auto const workspace_options =
torch::TensorOptions().dtype(torch::kUInt8).device(a.device());
auto workspace = torch::empty(workspace_size, workspace_options);
const cudaStream_t stream = at::cuda::getCurrentCUDAStream(a.get_device());
auto can_implement_status = gemm_op.can_implement(args);
TORCH_CHECK(can_implement_status == cutlass::Status::kSuccess,
"Failed to implement GEMM");
// Run the GEMM
auto status = gemm_op.initialize(args, workspace.data_ptr());
TORCH_CHECK(status == cutlass::Status::kSuccess, "Failed to initialize GEMM");
status = gemm_op.run(args, workspace.data_ptr(), stream);
TORCH_CHECK(status == cutlass::Status::kSuccess, "Failed to run GEMM");
}
constexpr auto FLOAT4_E2M1X2 = at::ScalarType::Byte;
constexpr auto SF_DTYPE = at::ScalarType::Float8_e4m3fn;
#define CHECK_TYPE(x, st, m) \
TORCH_CHECK(x.scalar_type() == st, ": Inconsistency of Tensor type:", m)
#define CHECK_TH_CUDA(x, m) \
TORCH_CHECK(x.is_cuda(), m, ": must be a CUDA tensor.")
#define CHECK_CONTIGUOUS(x, m) \
TORCH_CHECK(x.is_contiguous(), m, ": must be contiguous.")
#define CHECK_INPUT(x, st, m) \
CHECK_TH_CUDA(x, m); \
CHECK_CONTIGUOUS(x, m); \
CHECK_TYPE(x, st, m)
void cutlass_fp4_group_mm(
torch::Tensor& output, const torch::Tensor& a, const torch::Tensor& b,
const torch::Tensor& a_blockscale, const torch::Tensor& b_blockscales,
const torch::Tensor& alphas, const torch::Tensor& problem_sizes,
const torch::Tensor& expert_offsets, const torch::Tensor& sf_offsets) {
#if defined ENABLE_NVFP4 && ENABLE_NVFP4
// Input validation
CHECK_INPUT(a, FLOAT4_E2M1X2, "a");
CHECK_INPUT(b, FLOAT4_E2M1X2, "b");
CHECK_INPUT(a_blockscale, SF_DTYPE, "a_blockscale");
CHECK_INPUT(b_blockscales, SF_DTYPE, "b_blockscales");
CHECK_INPUT(alphas, at::ScalarType::Float, "alphas");
TORCH_CHECK(a_blockscale.dim() == 2,
"expected a_blockscale to be of shape [num_experts, rounded_m,"
" k // group_size], observed rank: ",
a_blockscale.dim())
TORCH_CHECK(b_blockscales.dim() == 3,
"expected b_blockscale to be of shape: "
" [num_experts, n, k // group_size], observed rank: ",
b_blockscales.dim())
TORCH_CHECK(problem_sizes.dim() == 2, "problem_sizes must be a 2D tensor");
TORCH_CHECK(problem_sizes.size(1) == 3,
"problem_sizes must have the shape (num_experts, 3)");
TORCH_CHECK(problem_sizes.size(0) == expert_offsets.size(0),
"Number of experts in problem_sizes must match expert_offsets");
TORCH_CHECK(problem_sizes.dtype() == torch::kInt32,
"problem_sizes must be int32.");
int M = static_cast<int>(a.size(0));
int N = static_cast<int>(b.size(1));
int E = static_cast<int>(b.size(0));
int K = static_cast<int>(2 * b.size(2));
if (output.scalar_type() == torch::kBFloat16) {
run_fp4_blockwise_scaled_group_mm<cutlass::bfloat16_t>(
output, a, b, a_blockscale, b_blockscales, alphas, problem_sizes,
expert_offsets, sf_offsets, M, N, K);
} else {
run_fp4_blockwise_scaled_group_mm<cutlass::half_t>(
output, a, b, a_blockscale, b_blockscales, alphas, problem_sizes,
expert_offsets, sf_offsets, M, N, K);
}
#else
TORCH_CHECK_NOT_IMPLEMENTED(
false,
"No compiled cutlass_fp4_group_mm kernel, vLLM must "
"be compiled with ENABLE_NVFP4 for SM100+ and CUDA "
"12.8 or above.");
#endif
}

View File

@ -0,0 +1,404 @@
#include <torch/all.h>
#include <ATen/cuda/CUDAContext.h>
#include <c10/cuda/CUDAGuard.h>
#include <cuda_runtime.h>
#include <cuda_fp8.h>
template <typename T>
struct TypeConverter {
using Type = half2;
}; // keep for generality
template <>
struct TypeConverter<half2> {
using Type = half;
};
template <>
struct TypeConverter<half> {
using Type = half2;
};
template <>
struct TypeConverter<__nv_bfloat162> {
using Type = __nv_bfloat16;
};
template <>
struct TypeConverter<__nv_bfloat16> {
using Type = __nv_bfloat162;
};
#define ELTS_PER_THREAD 8
constexpr int CVT_FP4_ELTS_PER_THREAD = 8;
constexpr int CVT_FP4_SF_VEC_SIZE = 16;
// Convert 8 float32 values into 8 e2m1 values (represented as one uint32_t).
inline __device__ uint32_t fp32_vec_to_e2m1(float (&array)[8]) {
#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000)
uint32_t val;
asm volatile(
"{\n"
".reg .b8 byte0;\n"
".reg .b8 byte1;\n"
".reg .b8 byte2;\n"
".reg .b8 byte3;\n"
"cvt.rn.satfinite.e2m1x2.f32 byte0, %2, %1;\n"
"cvt.rn.satfinite.e2m1x2.f32 byte1, %4, %3;\n"
"cvt.rn.satfinite.e2m1x2.f32 byte2, %6, %5;\n"
"cvt.rn.satfinite.e2m1x2.f32 byte3, %8, %7;\n"
"mov.b32 %0, {byte0, byte1, byte2, byte3};\n"
"}"
: "=r"(val)
: "f"(array[0]), "f"(array[1]), "f"(array[2]), "f"(array[3]),
"f"(array[4]), "f"(array[5]), "f"(array[6]), "f"(array[7]));
return val;
#else
return 0;
#endif
}
// Convert 4 float2 values into 8 e2m1 values (represented as one uint32_t).
inline __device__ uint32_t fp32_vec_to_e2m1(float2 (&array)[4]) {
#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000)
uint32_t val;
asm volatile(
"{\n"
".reg .b8 byte0;\n"
".reg .b8 byte1;\n"
".reg .b8 byte2;\n"
".reg .b8 byte3;\n"
"cvt.rn.satfinite.e2m1x2.f32 byte0, %2, %1;\n"
"cvt.rn.satfinite.e2m1x2.f32 byte1, %4, %3;\n"
"cvt.rn.satfinite.e2m1x2.f32 byte2, %6, %5;\n"
"cvt.rn.satfinite.e2m1x2.f32 byte3, %8, %7;\n"
"mov.b32 %0, {byte0, byte1, byte2, byte3};\n"
"}"
: "=r"(val)
: "f"(array[0].x), "f"(array[0].y), "f"(array[1].x), "f"(array[1].y),
"f"(array[2].x), "f"(array[2].y), "f"(array[3].x), "f"(array[3].y));
return val;
#else
return 0;
#endif
}
// Fast reciprocal.
inline __device__ float reciprocal_approximate_ftz(float a) {
float b;
asm volatile("rcp.approx.ftz.f32 %0, %1;\n" : "=f"(b) : "f"(a));
return b;
}
template <class SFType, int CVT_FP4_NUM_THREADS_PER_SF>
__device__ uint8_t* cvt_quant_to_fp4_get_sf_out_offset(int rowIdx, int colIdx,
int numCols,
SFType* SFout) {
#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000)
static_assert(CVT_FP4_NUM_THREADS_PER_SF == 1 ||
CVT_FP4_NUM_THREADS_PER_SF == 2);
// One pair of threads write one SF to global memory.
// TODO: stage through smem for packed STG.32
// is it better than STG.8 from 4 threads ?
if (threadIdx.x % CVT_FP4_NUM_THREADS_PER_SF == 0) {
// SF vector index (16 elements share one SF in the K dimension).
int32_t kIdx = colIdx / CVT_FP4_NUM_THREADS_PER_SF;
int32_t mIdx = rowIdx;
// SF layout [numMTiles, numKTiles, 32 (mTile), 4 (mTile), 4(kTile)]
// --> index [mTileIdx, kTileIdx, outerMIdx, innerMIdx, innerKIdx]
int32_t mTileIdx = mIdx / (32 * 4);
// SF vector size 16.
int factor = CVT_FP4_SF_VEC_SIZE * 4;
int32_t numKTiles = (numCols + factor - 1) / factor;
int64_t mTileStride = numKTiles * 32 * 4 * 4;
int32_t kTileIdx = (kIdx / 4);
int64_t kTileStride = 32 * 4 * 4;
// M tile layout [32, 4] is column-major.
int32_t outerMIdx = (mIdx % 32);
int64_t outerMStride = 4 * 4;
int32_t innerMIdx = (mIdx % (32 * 4)) / 32;
int64_t innerMStride = 4;
int32_t innerKIdx = (kIdx % 4);
int64_t innerKStride = 1;
// Compute the global offset.
int64_t SFOffset = mTileIdx * mTileStride + kTileIdx * kTileStride +
outerMIdx * outerMStride + innerMIdx * innerMStride +
innerKIdx * innerKStride;
return reinterpret_cast<uint8_t*>(SFout) + SFOffset;
}
#endif
return nullptr;
}
// Define a 16 bytes packed data type.
template <class Type>
struct PackedVec {
typename TypeConverter<Type>::Type elts[4];
};
template <>
struct PackedVec<__nv_fp8_e4m3> {
__nv_fp8x2_e4m3 elts[8];
};
// Quantizes the provided PackedVec into the uint32_t output
template <class Type, bool UE8M0_SF = false>
__device__ uint32_t cvt_warp_fp16_to_fp4(PackedVec<Type>& vec, float SFScaleVal,
uint8_t* SFout) {
#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000)
// Get absolute maximum values among the local 8 values.
auto localMax = __habs2(vec.elts[0]);
// Local maximum value.
#pragma unroll
for (int i = 1; i < CVT_FP4_ELTS_PER_THREAD / 2; i++) {
localMax = __hmax2(localMax, __habs2(vec.elts[i]));
}
// Get the absolute maximum among all 16 values (two threads).
localMax = __hmax2(__shfl_xor_sync(uint32_t(-1), localMax, 1), localMax);
// Get the final absolute maximum values.
float vecMax = float(__hmax(localMax.x, localMax.y));
// Get the SF (max value of the vector / max value of e2m1).
// maximum value of e2m1 = 6.0.
// TODO: use half as compute data type.
float SFValue = SFScaleVal * (vecMax * reciprocal_approximate_ftz(6.0f));
// 8 bits representation of the SF.
uint8_t fp8SFVal;
// Write the SF to global memory (STG.8).
if constexpr (UE8M0_SF) {
// Extract the 8 exponent bits from float32.
// float 32bits = 1 sign bit + 8 exponent bits + 23 mantissa bits.
uint32_t tmp = reinterpret_cast<uint32_t&>(SFValue) >> 23;
fp8SFVal = tmp & 0xff;
// Convert back to fp32.
reinterpret_cast<uint32_t&>(SFValue) = tmp << 23;
} else {
// Here SFValue is always positive, so E4M3 is the same as UE4M3.
__nv_fp8_e4m3 tmp = __nv_fp8_e4m3(SFValue);
reinterpret_cast<__nv_fp8_e4m3&>(fp8SFVal) = tmp;
// Convert back to fp32.
SFValue = float(tmp);
}
// Get the output scale.
// Recipe: final_scale = reciprocal(fp32(fp8(SFValue * SFScaleVal))) *
// reciprocal(SFScaleVal))
float outputScale =
SFValue != 0 ? reciprocal_approximate_ftz(
SFValue * reciprocal_approximate_ftz(SFScaleVal))
: 0.0f;
if (SFout) {
// Write the SF to global memory (STG.8).
*SFout = fp8SFVal;
}
// Convert the input to float.
float2 fp2Vals[CVT_FP4_ELTS_PER_THREAD / 2];
#pragma unroll
for (int i = 0; i < CVT_FP4_ELTS_PER_THREAD / 2; i++) {
if constexpr (std::is_same_v<Type, half>) {
fp2Vals[i] = __half22float2(vec.elts[i]);
} else {
fp2Vals[i] = __bfloat1622float2(vec.elts[i]);
}
fp2Vals[i].x *= outputScale;
fp2Vals[i].y *= outputScale;
}
// Convert to e2m1 values.
uint32_t e2m1Vec = fp32_vec_to_e2m1(fp2Vals);
// Write the e2m1 values to global memory.
return e2m1Vec;
#else
return 0;
#endif
}
// Use UE4M3 by default.
template <class Type, bool UE8M0_SF = false>
__global__ void
#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000)
__launch_bounds__(512, 4) cvt_fp16_to_fp4(
#else
cvt_fp16_to_fp4(
#endif
int32_t numRows, int32_t numCols, Type const* in, float const* SFScale,
uint32_t* out, uint32_t* SFout, uint32_t* input_offset_by_experts,
uint32_t* output_scale_offset_by_experts, int n_experts) {
#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000)
using PackedVec = PackedVec<Type>;
static constexpr int CVT_FP4_NUM_THREADS_PER_SF =
(CVT_FP4_SF_VEC_SIZE / CVT_FP4_ELTS_PER_THREAD);
static_assert(sizeof(PackedVec) == sizeof(Type) * CVT_FP4_ELTS_PER_THREAD,
"Vec size is not matched.");
// Input tensor row/col loops.
for (int rowIdx = blockIdx.x; rowIdx < numRows; rowIdx += gridDim.x) {
for (int colIdx = threadIdx.x; colIdx < numCols / CVT_FP4_ELTS_PER_THREAD;
colIdx += blockDim.x) {
int64_t inOffset = rowIdx * (numCols / CVT_FP4_ELTS_PER_THREAD) + colIdx;
PackedVec in_vec = reinterpret_cast<PackedVec const*>(in)[inOffset];
// Get the output tensor offset.
// Same as inOffset because 8 elements are packed into one uint32_t.
int64_t outOffset = inOffset;
auto& out_pos = out[outOffset];
// Find index within the experts.
int rowIdx_in_expert = 0;
int expert_idx = 0;
for (int i = 0; i < n_experts; i++) {
if (rowIdx >= input_offset_by_experts[i] &&
rowIdx < input_offset_by_experts[i + 1]) {
rowIdx_in_expert = rowIdx - input_offset_by_experts[i];
expert_idx = i;
break;
}
}
// Get the global scaling factor, which will be applied to the SF.
// Note SFScale is the same as next GEMM's alpha, which is
// (448.f / (Alpha_A / 6.f)).
float const SFScaleVal = SFScale == nullptr ? 1.0f : SFScale[expert_idx];
int factor = CVT_FP4_SF_VEC_SIZE * 4;
// The actual output_scales dim is computed from the padded numCols.
int32_t numCols_padded = (numCols + factor - 1) / factor * factor;
int numCols_SFout = numCols_padded / CVT_FP4_SF_VEC_SIZE / 4;
uint32_t* SFout_in_expert =
SFout + output_scale_offset_by_experts[expert_idx] * numCols_SFout;
auto sf_out =
cvt_quant_to_fp4_get_sf_out_offset<uint32_t,
CVT_FP4_NUM_THREADS_PER_SF>(
rowIdx_in_expert, colIdx, numCols, SFout_in_expert);
out_pos =
cvt_warp_fp16_to_fp4<Type, UE8M0_SF>(in_vec, SFScaleVal, sf_out);
}
}
#endif
}
template <typename T>
void quant_impl(void* output, void* output_scale, void* input,
void* input_global_scale, void* input_offset_by_experts,
void* output_scale_offset_by_experts, int m_topk, int k,
int n_experts, cudaStream_t stream) {
// TODO: this multiProcessorCount should be cached.
int device;
cudaGetDevice(&device);
int multiProcessorCount;
cudaDeviceGetAttribute(&multiProcessorCount, cudaDevAttrMultiProcessorCount,
device);
// Grid, Block size.
// Each thread converts 8 values.
dim3 block(std::min(int(k / ELTS_PER_THREAD), 512));
// Get number of blocks per SM (assume we can fully utilize the SM).
int const numBlocksPerSM = 2048 / block.x;
dim3 grid(std::min(int(m_topk), multiProcessorCount * numBlocksPerSM));
cvt_fp16_to_fp4<T, false><<<grid, block, 0, stream>>>(
m_topk, k, reinterpret_cast<T*>(input),
reinterpret_cast<float*>(input_global_scale),
reinterpret_cast<uint32_t*>(output),
reinterpret_cast<uint32_t*>(output_scale),
reinterpret_cast<uint32_t*>(input_offset_by_experts),
reinterpret_cast<uint32_t*>(output_scale_offset_by_experts), n_experts);
}
/*Quantization entry for fp4 experts quantization*/
#define CHECK_TH_CUDA(x, m) TORCH_CHECK(x.is_cuda(), m, "must be a CUDA tensor")
#define CHECK_CONTIGUOUS(x, m) \
TORCH_CHECK(x.is_contiguous(), m, "must be contiguous")
#define CHECK_INPUT(x, m) \
CHECK_TH_CUDA(x, m); \
CHECK_CONTIGUOUS(x, m);
constexpr auto HALF = at::ScalarType::Half;
constexpr auto BF16 = at::ScalarType::BFloat16;
constexpr auto FLOAT = at::ScalarType::Float;
constexpr auto INT = at::ScalarType::Int;
constexpr auto UINT8 = at::ScalarType::Byte;
void scaled_fp4_experts_quant_sm100a(
torch::Tensor& output, torch::Tensor& output_scale,
torch::Tensor const& input, torch::Tensor const& input_global_scale,
torch::Tensor const& input_offset_by_experts,
torch::Tensor const& output_scale_offset_by_experts) {
CHECK_INPUT(output, "output must be a CUDA tensor");
CHECK_INPUT(output_scale, "output_scale must be a CUDA tensor");
CHECK_INPUT(input, "input must be a CUDA tensor");
CHECK_INPUT(input_global_scale, "input_global_scale must be a CUDA tensor");
CHECK_INPUT(input_offset_by_experts,
"input_offset_by_experts must be a CUDA tensor");
CHECK_INPUT(output_scale_offset_by_experts,
"output_scale_offset_by_experts must be a CUDA tensor");
TORCH_CHECK(output.dim() == 2);
TORCH_CHECK(output_scale.dim() == 2);
TORCH_CHECK(input.dim() == 2);
TORCH_CHECK(input_global_scale.dim() == 1);
TORCH_CHECK(input_offset_by_experts.dim() == 1);
TORCH_CHECK(output_scale_offset_by_experts.dim() == 1);
TORCH_CHECK(input.scalar_type() == HALF || input.scalar_type() == BF16);
TORCH_CHECK(input_global_scale.scalar_type() == FLOAT);
TORCH_CHECK(input_offset_by_experts.scalar_type() == INT);
TORCH_CHECK(output_scale_offset_by_experts.scalar_type() == INT);
// output is uint8 (two nvfp4 values are packed into one uint8)
// output_scale is int32 (four fp8 values are packed into one int32)
TORCH_CHECK(output.scalar_type() == UINT8);
TORCH_CHECK(output_scale.scalar_type() == INT);
const int BLOCK_SIZE = 16;
auto m_topk = input.size(0);
auto k = input.size(1);
TORCH_CHECK(k % BLOCK_SIZE == 0, "k must be a multiple of 16");
auto n_experts = input_global_scale.size(0);
TORCH_CHECK(input_offset_by_experts.size(0) == n_experts + 1);
TORCH_CHECK(output_scale_offset_by_experts.size(0) == n_experts + 1);
TORCH_CHECK(output.size(0) == m_topk);
TORCH_CHECK(output.size(1) == k / 2);
int scales_k = k / BLOCK_SIZE;
// 4 means the swizzle requirement by nvidia nvfp4.
int padded_k = (scales_k + (4 - 1)) / 4 * 4;
// 4 means 4 fp8 values are packed into one int32
TORCH_CHECK(output_scale.size(1) * 4 == padded_k);
auto in_dtype = input.dtype();
at::cuda::CUDAGuard device_guard{(char)input.get_device()};
const cudaStream_t stream =
at::cuda::getCurrentCUDAStream(input.get_device());
if (in_dtype == at::ScalarType::Half) {
quant_impl<half>(output.data_ptr(), output_scale.data_ptr(),
input.data_ptr(), input_global_scale.data_ptr(),
input_offset_by_experts.data_ptr(),
output_scale_offset_by_experts.data_ptr(), m_topk, k,
n_experts, stream);
} else if (in_dtype == at::ScalarType::BFloat16) {
quant_impl<__nv_bfloat16>(output.data_ptr(), output_scale.data_ptr(),
input.data_ptr(), input_global_scale.data_ptr(),
input_offset_by_experts.data_ptr(),
output_scale_offset_by_experts.data_ptr(), m_topk,
k, n_experts, stream);
} else {
TORCH_CHECK(false, "Expected input data type to be half or bfloat16");
}
}

View File

@ -23,10 +23,32 @@ void scaled_fp4_quant_sm100a(torch::Tensor const& output,
torch::Tensor const& input_sf);
#endif
#if defined ENABLE_NVFP4 && ENABLE_NVFP4
void scaled_fp4_experts_quant_sm100a(
torch::Tensor& output, torch::Tensor& output_scale,
torch::Tensor const& input, torch::Tensor const& input_global_scale,
torch::Tensor const& input_offset_by_experts,
torch::Tensor const& output_scale_offset_by_experts);
#endif
void scaled_fp4_quant(torch::Tensor& output, torch::Tensor const& input,
torch::Tensor& output_sf, torch::Tensor const& input_sf) {
#if defined ENABLE_NVFP4 && ENABLE_NVFP4
return scaled_fp4_quant_sm100a(output, input, output_sf, input_sf);
#endif
TORCH_CHECK_NOT_IMPLEMENTED(false, "No compiled nvfp4 quantization");
TORCH_CHECK_NOT_IMPLEMENTED(false, "No compiled nvfp4 quantization kernel");
}
void scaled_fp4_experts_quant(
torch::Tensor& output, torch::Tensor& output_scale,
torch::Tensor const& input, torch::Tensor const& input_global_scale,
torch::Tensor const& input_offset_by_experts,
torch::Tensor const& output_scale_offset_by_experts) {
#if defined ENABLE_NVFP4 && ENABLE_NVFP4
return scaled_fp4_experts_quant_sm100a(
output, output_scale, input, input_global_scale, input_offset_by_experts,
output_scale_offset_by_experts);
#endif
TORCH_CHECK_NOT_IMPLEMENTED(false,
"No compiled nvfp4 experts quantization kernel");
}

View File

@ -363,6 +363,14 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
{stride_tag});
ops.impl("cutlass_scaled_fp4_mm", torch::kCUDA, &cutlass_scaled_fp4_mm);
// cutlass nvfp4 block scaled group GEMM
ops.def(
"cutlass_fp4_group_mm(Tensor! out, Tensor a, Tensor b,"
" Tensor a_blockscale, Tensor b_blockscales, Tensor alphas,"
" Tensor problem_sizes, Tensor expert_offsets, Tensor sf_offsets) -> ()",
{stride_tag});
ops.impl("cutlass_fp4_group_mm", torch::kCUDA, &cutlass_fp4_group_mm);
// CUTLASS w8a8 GEMM, supporting symmetric per-tensor or per-row/column
// quantization, as well as bias
ops.def(
@ -492,6 +500,13 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
" Tensor! output_scale, Tensor input_scale) -> ()");
ops.impl("scaled_fp4_quant", torch::kCUDA, &scaled_fp4_quant);
// Compute NVFP4 experts quantization.
ops.def(
"scaled_fp4_experts_quant(Tensor! output, Tensor! output_scale,"
"Tensor input, Tensor input_global_scale, Tensor input_offset_by_experts,"
"Tensor output_scale_offset_by_experts) -> ()");
ops.impl("scaled_fp4_experts_quant", torch::kCUDA, &scaled_fp4_experts_quant);
// Check if cutlass_scaled_mm_fp4 is supported for CUDA devices
// of the given capability
ops.def("cutlass_scaled_mm_supports_fp4(int cuda_device_capability) -> bool");

View File

@ -0,0 +1,144 @@
# SPDX-License-Identifier: Apache-2.0
import pytest
import torch
from tests.kernels.quantization.nvfp4_utils import (FLOAT4_E2M1_MAX,
FLOAT8_E4M3_MAX,
dequantize_nvfp4_to_dtype)
from tests.kernels.utils import torch_moe
from vllm import _custom_ops as ops
from vllm.config import ParallelConfig, VllmConfig, set_current_vllm_config
from vllm.model_executor.layers.fused_moe.cutlass_moe import cutlass_moe_fp4
from vllm.model_executor.layers.fused_moe.fused_moe import fused_topk
from vllm.platforms import current_platform
if not current_platform.has_device_capability(100):
pytest.skip(reason="Nvfp4 Requires compute capability of 10 or above.",
allow_module_level=True)
MNK_FACTORS = [
(2, 1024, 1024),
(2, 1024, 1536),
(2, 3072, 1024),
(2, 3072, 1536),
(64, 1024, 1024),
(64, 1024, 1536),
(64, 3072, 1024),
(64, 2048, 1536),
(224, 1024, 1024),
(224, 1024, 1536),
]
@pytest.mark.parametrize("m,n,k", MNK_FACTORS)
@pytest.mark.parametrize("e", [40, 64, 256])
@pytest.mark.parametrize("topk", [1, 6, 8])
@pytest.mark.parametrize("dtype", [torch.half, torch.bfloat16])
@torch.inference_mode()
def test_cutlass_fp4_moe_no_graph(m: int, n: int, k: int, e: int, topk: int,
dtype: torch.dtype):
current_platform.seed_everything(7)
with set_current_vllm_config(
VllmConfig(parallel_config=ParallelConfig(
pipeline_parallel_size=1))):
a = torch.randn((m, k), device="cuda", dtype=dtype) / 10
w1 = torch.randn((e, 2 * n, k), device="cuda", dtype=dtype) / 10
quant_blocksize = 16
round_up = lambda x, y: (x + y - 1) // y * y
sf_w1_2n = round_up(2 * n, 128)
sf_w1_k = round_up(k // quant_blocksize, 4)
w1_blockscale = torch.empty((e, sf_w1_2n, sf_w1_k),
device="cuda",
dtype=torch.float8_e4m3fn)
w2 = torch.randn((e, k, n), device="cuda", dtype=dtype) / 10
sf_w2_k = round_up(k, 128)
sf_w2_n = round_up(n // quant_blocksize, 4)
w2_blockscale = torch.empty((e, sf_w2_k, sf_w2_n),
device="cuda",
dtype=torch.float8_e4m3fn)
w1_q = torch.empty((e, 2 * n, k // 2),
device="cuda",
dtype=torch.uint8)
w2_q = torch.empty((e, k, n // 2), device="cuda", dtype=torch.uint8)
w1_gs = torch.empty((e, ), device="cuda", dtype=torch.float32)
w2_gs = torch.empty((e, ), device="cuda", dtype=torch.float32)
for expert in range(e):
w1_amax = torch.abs(w1).max().to(torch.float32)
w2_amax = torch.abs(w2).max().to(torch.float32)
w1_gs[expert] = FLOAT8_E4M3_MAX * FLOAT4_E2M1_MAX / w1_amax
w2_gs[expert] = FLOAT8_E4M3_MAX * FLOAT4_E2M1_MAX / w2_amax
w1_q[expert], w1_blockscale[expert] = ops.scaled_fp4_quant(
w1[expert], w1_gs[expert])
w2_q[expert], w2_blockscale[expert] = ops.scaled_fp4_quant(
w2[expert], w2_gs[expert])
score = torch.randn((m, e), device="cuda", dtype=dtype)
topk_weights, topk_ids = fused_topk(a, score, topk, renormalize=False)
a1_gs = torch.ones((e, ), device="cuda", dtype=torch.float32)
a2_gs = torch.ones((e, ), device="cuda", dtype=torch.float32)
cutlass_output = cutlass_moe_fp4(
a=a,
a1_gscale=a1_gs,
w1_fp4=w1_q,
w1_blockscale=w1_blockscale,
w1_alphas=(1 / w1_gs),
a2_gscale=a2_gs,
w2_fp4=w2_q,
w2_blockscale=w2_blockscale,
w2_alphas=(1 / w2_gs),
topk_weights=topk_weights,
topk_ids=topk_ids,
m=m,
n=n,
k=k,
e=e,
device=a.device,
)
# Reference check:
a_global_scale = ((FLOAT8_E4M3_MAX * FLOAT4_E2M1_MAX) /
torch.amax(a.flatten(), dim=-1)).to(torch.float32)
a_fp4, a_scale_interleaved = ops.scaled_fp4_quant(a, a_global_scale)
_, m_k = a_fp4.shape
a_in_dtype = dequantize_nvfp4_to_dtype(a_fp4,
a_scale_interleaved,
a_global_scale,
dtype=a.dtype,
device=a.device,
block_size=quant_blocksize)
w1_d = torch.empty((e, 2 * n, k), device="cuda", dtype=dtype)
w2_d = torch.empty((e, k, n), device="cuda", dtype=dtype)
for idx in range(0, e):
w1_d[idx] = dequantize_nvfp4_to_dtype(w1_q[idx],
w1_blockscale[idx],
w1_gs[idx],
dtype=w1.dtype,
device=w1.device,
block_size=quant_blocksize)
w2_d[idx] = dequantize_nvfp4_to_dtype(w2_q[idx],
w2_blockscale[idx],
w2_gs[idx],
dtype=w2.dtype,
device=w2.device,
block_size=quant_blocksize)
torch_output = torch_moe(a_in_dtype, w1_d, w2_d, score, topk, None)
torch.testing.assert_close(torch_output,
cutlass_output,
atol=1e-1,
rtol=1e-1)
if __name__ == "__main__":
test_cutlass_fp4_moe_no_graph((2, 1024, 1024), 40, 1, torch.half)

View File

@ -0,0 +1,66 @@
# SPDX-License-Identifier: Apache-2.0
import torch
from vllm.scalar_type import scalar_types
FLOAT4_E2M1_MAX = scalar_types.float4_e2m1f.max()
FLOAT8_E4M3_MAX = torch.finfo(torch.float8_e4m3fn).max
kE2M1ToFloat = torch.tensor([0., 0.5, 1., 1.5, 2., 3., 4., 6.],
dtype=torch.float32)
def convert_swizzled_to_linear(a_sf_swizzled: torch.Tensor, m, k, block_size):
m_tiles = (m + 128 - 1) // 128
f = block_size * 4
k_tiles = (k + f - 1) // f
tmp = torch.reshape(a_sf_swizzled, (1, m_tiles, k_tiles, 32, 4, 4))
tmp = torch.permute(tmp, (0, 1, 4, 3, 2, 5))
out = tmp.reshape(m_tiles * 128, k_tiles * f // block_size)
return out[0:m, 0:k]
def dequantize_nvfp4_to_dtype(tensor_fp4,
tensor_sf,
global_scale,
dtype,
device,
block_size=16):
"""Dequantize the fp4 tensor back to high precision."""
# Two fp4 values are packed into one uint8.
assert tensor_fp4.dtype == torch.uint8
m, packed_k = tensor_fp4.shape
k = packed_k * 2
tensor_f32 = break_fp4_bytes(tensor_fp4, dtype)
tensor_f32 = tensor_f32.reshape(m, k // block_size, block_size)
tensor_sf = tensor_sf.view(torch.float8_e4m3fn)
tensor_sf = convert_swizzled_to_linear(tensor_sf, m, k, block_size)
tensor_sf_dtype = tensor_sf.to(torch.float32) / global_scale
# scale the tensor
out = (tensor_f32 * tensor_sf_dtype.unsqueeze(-1)).reshape(m, k)
return out.to(dtype=dtype)
def break_fp4_bytes(a, dtype):
assert a.dtype == torch.uint8
m, n = a.shape
# Vectorized nibble processing
a_flat = a.flatten()
high = (a_flat & 0xF0) >> 4 # Upper nibbles
low = a_flat & 0x0F # Lower nibbles
# Combine nibbles for batch processing
combined = torch.stack((low, high), dim=1).flatten()
# Vectorized sign and magnitude extraction
signs = (combined & 0x08).to(torch.bool) # Sign bits
abs_vals = (combined & 0x07).to(torch.long) # Magnitude indices
# Device-aware lookup and sign application
kE2M1 = kE2M1ToFloat.to(device=a.device)
values = kE2M1[abs_vals] * torch.where(signs, -1.0, 1.0)
# Reshape to final form
return values.reshape(m, n * 2).to(dtype=dtype)

View File

@ -1,10 +1,11 @@
# SPDX-License-Identifier: Apache-2.0
import pytest
import torch
from nvfp4_utils import (FLOAT4_E2M1_MAX, FLOAT8_E4M3_MAX,
dequantize_nvfp4_to_dtype)
from vllm import _custom_ops as ops
from vllm.platforms import current_platform
from vllm.scalar_type import scalar_types
if not current_platform.has_device_capability(100):
pytest.skip(reason="Nvfp4 Requires compute capability of 10 or above.",
@ -19,95 +20,24 @@ SHAPES.extend(PAD_SHAPES)
SEEDS = [42]
CUDA_DEVICES = ['cuda:0']
FLOAT4_E2M1_MAX = scalar_types.float4_e2m1fn.max()
FLOAT8_E4M3_MAX = torch.finfo(torch.float8_e4m3fn).max
kE2M1ToFloatArray = [
0.,
0.5,
1.,
1.5,
2.,
3.,
4.,
6.,
]
def e2m1_to_fp32(int4_value):
signBit = (int4_value & 0x8)
int4_absValue = int4_value & 0x7
float_result = kE2M1ToFloatArray[int4_absValue]
if (signBit):
float_result = -float_result
return float_result
def break_fp4_bytes(a, dtype):
assert (a.dtype == torch.uint8)
m, n = a.shape
a = a.flatten()
# Get upper 4 bits
highHalfByte = (a & 0xF0) >> 4
# Get lower 4 bits
lowHalfByte = a & 0x0F
fH = torch.tensor([e2m1_to_fp32(x) for x in highHalfByte]).to(a.device)
fL = torch.tensor([e2m1_to_fp32(x) for x in lowHalfByte]).to(a.device)
# [0xAB, 0xCD] -> [0xB, 0xA, 0xD, 0xC]
out = torch.stack((fL, fH), dim=-1).reshape(m, n * 2)
return out
def convert_swizzled_to_linear(a_sf_swizzled: torch.Tensor, m, k, block_size):
sf_m, sf_k = a_sf_swizzled.shape
m_tiles = (m + 128 - 1) // 128
f = block_size * 4
k_tiles = (k + f - 1) // f
tmp = torch.reshape(a_sf_swizzled, (1, m_tiles, k_tiles, 32, 4, 4))
tmp = torch.permute(tmp, (0, 1, 4, 3, 2, 5))
out = tmp.reshape(m_tiles * 128, k_tiles * f // block_size)
return out[0:m, 0:k]
def dequantize_to_dtype(tensor_fp4,
tensor_sf,
global_scale,
dtype,
device,
block_size=16):
"""Dequantize the fp4 tensor back to high precision."""
# Two fp4 values are packed into one uint8.
assert tensor_fp4.dtype == torch.uint8
m, packed_k = tensor_fp4.shape
k = packed_k * 2
tensor_f32 = break_fp4_bytes(tensor_fp4, dtype)
tensor_f32 = tensor_f32.reshape(m, k // block_size, block_size)
tensor_sf = tensor_sf.view(torch.float8_e4m3fn)
tensor_sf = convert_swizzled_to_linear(tensor_sf, m, k, block_size)
tensor_sf_dtype = tensor_sf.to(torch.float32) / global_scale
# scale the tensor
out = (tensor_f32 * tensor_sf_dtype.unsqueeze(-1)).reshape(m, k)
return out
def get_ref_results(a_fp4, b_fp4, a_sf, b_sf, a_global_scale, b_global_scale,
m, n, dtype, block_size, device):
_, m_k = a_fp4.shape
_, n_k = b_fp4.shape
assert (m_k == n_k)
a_in_dtype = dequantize_to_dtype(a_fp4,
a_sf,
a_global_scale,
dtype=dtype,
device=device,
block_size=block_size)
b_in_dtype = dequantize_to_dtype(b_fp4,
b_sf,
b_global_scale,
dtype=dtype,
device=device,
block_size=block_size)
a_in_dtype = dequantize_nvfp4_to_dtype(a_fp4,
a_sf,
a_global_scale,
dtype=dtype,
device=device,
block_size=block_size)
b_in_dtype = dequantize_nvfp4_to_dtype(b_fp4,
b_sf,
b_global_scale,
dtype=dtype,
device=device,
block_size=block_size)
return torch.matmul(a_in_dtype, b_in_dtype.t())

View File

@ -745,10 +745,11 @@ def get_cutlass_moe_mm_data(
- output_permutation: Permutation that must be used to shuffle the output
after executing the MMs.
"""
torch.ops._C.get_cutlass_moe_mm_data(topk_ids, expert_offsets,
problem_sizes1, problem_sizes2,
input_permutation, output_permutation,
num_experts, n, k)
return torch.ops._C.get_cutlass_moe_mm_data(topk_ids, expert_offsets,
problem_sizes1, problem_sizes2,
input_permutation,
output_permutation,
num_experts, n, k)
def cutlass_moe_mm(out_tensors: torch.Tensor, a_tensors: torch.Tensor,
@ -767,9 +768,41 @@ def cutlass_moe_mm(out_tensors: torch.Tensor, a_tensors: torch.Tensor,
MMs used in the fused MoE operation.
- a/b/c_strides: The data strides passed to grouped matrix multiplication.
"""
torch.ops._C.cutlass_moe_mm(out_tensors, a_tensors, b_tensors, a_scales,
b_scales, expert_offsets, problem_sizes,
a_strides, b_strides, c_strides)
return torch.ops._C.cutlass_moe_mm(out_tensors, a_tensors, b_tensors,
a_scales, b_scales, expert_offsets,
problem_sizes, a_strides, b_strides,
c_strides)
def cutlass_fp4_moe_mm(a_tensors: torch.Tensor, b_tensors: torch.Tensor,
a_scales: torch.Tensor, b_scales: torch.Tensor,
alphas: torch.Tensor, problem_sizes: torch.Tensor,
expert_offsets: torch.Tensor, sf_offsets: torch.Tensor,
out_dtype: torch.dtype, device: torch.device):
"""
An FP4 Blockscaled Group Gemm that takes in a_tensors, b_tensors and runs
the gemms for each combination based on the specified problem sizes.
This is used as the MoE gemm during NVFP4 Quantized FusedMoE forward.
- a/b_tensors: the NVFP4 a_ptrs and b_ptrs tensors which are quantized
input and expert weights.
- a_/b_scales: The blockscales in FP8-E4M3 precision
- expert_offsets/sf_offsets: Indices that mark at which token index
each expert begins its computation. The number of tokens
computed with expert E is expert_offsets[E + 1] -
expert_offsets[E] And the sf_size per expert is
sf_offset[E+1] - sf_offset[E]
- problem_sizes: MxNxK sizes of each expert's multiplication in two grouped
MMs used in the fused MoE operation.
"""
m_topk = a_tensors.shape[0]
n = b_tensors.shape[1]
c_shape = (m_topk, n)
c = torch.empty(c_shape, device=device, dtype=out_dtype)
torch.ops._C.cutlass_fp4_group_mm(c, a_tensors, b_tensors, a_scales,
b_scales, alphas, problem_sizes,
expert_offsets, sf_offsets)
return c.to(out_dtype)
# aqlm
@ -960,6 +993,57 @@ def scaled_fp4_quant(
return output, output_scale
def scaled_fp4_experts_quant(
input_tensor: torch.Tensor,
input_global_scale: torch.Tensor,
expert_offsets: torch.Tensor,
blockscale_offsets: torch.Tensor,
topk: int,
expert_map: Optional[torch.Tensor] = None,
MAX_TOKENS_PER_EXPERT: int = 163840,
) -> tuple[torch.Tensor, torch.Tensor]:
"""
Quantize input tensor to FP4 and return quantized tensor and scale, for
packed MoE Inputs.
Args:
input: The input tensor to be quantized to FP4
expert_map: The expert map tensor
input_global_scale: A scalar scaling factor for the entire tensor.
expert_offsets: The expert offsets tensor
blockscale_offsets: The blockscale offsets tensor
Outputs:
output: The quantized tensor in FP4
output_scales: The blockscale tensor in FP8-E4M3
"""
assert not current_platform.is_rocm()
assert input_tensor.ndim == 2, (
f'input.ndim needs to be == 2, but got {input_tensor.ndim}.')
input_tensor = input_tensor[
expert_map] if expert_map is not None else input_tensor
m_numtopk, k = input_tensor.shape
assert (m_numtopk <= MAX_TOKENS_PER_EXPERT * topk), (
f"m_numtopk must be less than MAX_TOKENS_PER_EXPERT * topk for"
f" scaled_fp4_experts_quant kernel, observed m_numtopk = {m_numtopk}")
scales_k = k // 16
padded_k = (scales_k + (4 - 1)) // 4
# output is uint8 and packed fp4 values
output = torch.empty(m_numtopk,
k // 2,
device=input_tensor.device,
dtype=torch.uint8)
output_scales = torch.empty(MAX_TOKENS_PER_EXPERT * topk,
padded_k,
dtype=torch.int32,
device=input_tensor.device)
torch.ops._C.scaled_fp4_experts_quant(output, output_scales, input_tensor,
input_global_scale, expert_offsets,
blockscale_offsets)
output_scales = output_scales.view(torch.float8_e4m3fn)
return output, output_scales
# fp8
def scaled_fp8_quant(
input: torch.Tensor,

View File

@ -36,7 +36,7 @@ if HAS_TRITON:
import vllm.model_executor.layers.fused_moe.fused_marlin_moe # noqa
import vllm.model_executor.layers.fused_moe.fused_moe # noqa
from vllm.model_executor.layers.fused_moe.cutlass_moe import (
cutlass_moe_fp8)
cutlass_moe_fp4, cutlass_moe_fp8)
from vllm.model_executor.layers.fused_moe.fused_moe import (
fused_experts, fused_moe, fused_topk, get_config_file_name,
grouped_topk)
@ -48,4 +48,5 @@ if HAS_TRITON:
"get_config_file_name",
"grouped_topk",
"cutlass_moe_fp8",
"cutlass_moe_fp4",
]

View File

@ -1,10 +1,11 @@
# SPDX-License-Identifier: Apache-2.0
"""Fused MoE kernel."""
""" CUTLASS based Fused MoE kernels."""
from typing import Optional
import torch
from vllm import _custom_ops as ops
from vllm.scalar_type import scalar_types
#TODO make the grouped gemm kernel consistent with scaled gemm kernel
@ -178,3 +179,126 @@ def cutlass_moe_fp8(
if not apply_router_weight_on_input:
c2 = c2 * topk_weights.view(m, topk, 1).to(out_dtype)
return c2.sum(dim=1)
FLOAT4_E2M1_MAX = scalar_types.float4_e2m1f.max()
FLOAT8_E4M3_MAX = torch.finfo(torch.float8_e4m3fn).max
MAX_TOKENS_PER_EXPERT = 65536
def cutlass_moe_fp4(a: torch.Tensor, a1_gscale: torch.Tensor,
w1_fp4: torch.Tensor, w1_blockscale: torch.Tensor,
w1_alphas: torch.Tensor, a2_gscale: torch.Tensor,
w2_fp4: torch.Tensor, w2_blockscale: torch.Tensor,
w2_alphas: torch.Tensor, topk_weights: torch.Tensor,
topk_ids: torch.Tensor, m: int, n: int, k: int, e: int,
device: torch.device):
"""
MoE implementation for FP4 Inputs
# Gemm 1
a: Input tensor: [m, k] (half/bfloat16)
a1_gscale: Activation scale per expert: [e] (float32)
w1(gate up) (not an argument to cutlass_moe_fp4): [e, 2 * n, k]
w1_fp4: [e, 2 * n, k // 2], dtype: torch.uint8 (stacked fp4: E2M1)
(Note: `n` is the up projection output dim, `k` is the input dim in
full precision)
w1_blockscale: [e, 2 * n, k // block_size] (float8_e4m3)
(Block size = 16 for NVFP4)
# Gemm 2
a2_gscale: Activation scale per expert: [e]
w2(down projection) (not an argument to cutlass_moe_fp4): [e, k, n]
w2_fp4: [e, k, n // 2], dtype: torch.uint8 (stacked E2M1)
w2_blockscale: [e, k, n // block_size], dtype: float8_e4m3
topk_weights: [m, topk] dtype: float8
topk_ids: [m, topk] dtype: float8
m, n, k: Unquantized weight shapes, dtype: int
e: number of experts, dtype: int
assumes that topk < k < n to satisfy - up/down projection expectations.
"""
assert topk_weights.shape == topk_ids.shape, "topk shape mismatch"
assert w1_fp4.dtype == torch.uint8, "weight 1 must be uint8"
assert w2_fp4.dtype == torch.uint8, "weight 2 must be uint8"
assert (w1_fp4.ndim == 3 and w2_fp4.ndim == 3 and w1_blockscale.ndim == 3
and w2_blockscale.ndim
== 3), ("All Weights must be of rank 3 for cutlass_moe_fp4")
m_a, k_a = a.shape
e_w1, nx2_w1, half_k_w1 = w1_fp4.shape
e_w2, k_w2, half_n_w2 = w2_fp4.shape
assert (e_w1 == e_w2 and e_w1 == e), ("Number of experts must match",
" between weights.")
assert (k_a // 2 == half_k_w1
and k == k_w2), ("Hidden size mismatch between a, w1 and w2")
assert (nx2_w1 == n * 2 and half_n_w2 == n // 2), ("mismatch in "
"expected `n`")
assert (m == m_a), "input shape mismatch"
assert 2 * half_k_w1 == k_w2, "Hidden size mismatch w2 and w1"
assert a.dtype in [torch.half, torch.bfloat16], "Invalid input dtype"
assert (topk_weights.shape[0] == m and topk_ids.shape[0]
== m), ("topk must be provided for each row of a")
assert (m <= MAX_TOKENS_PER_EXPERT), (
f"m must be less than MAX_TOKENS_PER_EXPERT({MAX_TOKENS_PER_EXPERT})"
f" for cutlass_moe_fp4, observed m = {m}")
out_dtype = a.dtype
num_topk = topk_ids.shape[1]
expert_offsets = torch.empty((e + 1), dtype=torch.int32, device=device)
# Problem size: (num_experts, (m,2n,k))
problem_sizes1 = torch.empty((e, 3), dtype=torch.int32, device=device)
# Problem size: (num_experts, (m,n,k))
problem_sizes2 = torch.empty((e, 3), dtype=torch.int32, device=device)
a_map = torch.empty((topk_ids.numel()), dtype=torch.int32, device=device)
c_map = torch.empty((topk_ids.numel()), dtype=torch.int32, device=device)
# problem shapes should have [m, n, k]
# Note that problem sizes are based on logical number of elements.
ops.get_cutlass_moe_mm_data(topk_ids, expert_offsets, problem_sizes1,
problem_sizes2, a_map, c_map, e, n, k)
tokens_per_expert = problem_sizes1[:, 0]
rounded_tokens_per_expert = (tokens_per_expert + (128 - 1)) // 128 * 128
blockscale_offsets = torch.zeros(e + 1, dtype=torch.int32, device=device)
blockscale_offsets[1:] = torch.cumsum(rounded_tokens_per_expert, dim=0)
rep_a_fp4, rep_a_blockscale = ops.scaled_fp4_experts_quant(
a,
a1_gscale,
expert_offsets,
blockscale_offsets,
num_topk,
expert_map=a_map,
MAX_TOKENS_PER_EXPERT=MAX_TOKENS_PER_EXPERT)
c1 = ops.cutlass_fp4_moe_mm(rep_a_fp4, w1_fp4, rep_a_blockscale,
w1_blockscale, w1_alphas, problem_sizes1,
expert_offsets[:-1], blockscale_offsets[:-1],
out_dtype, device)
del rep_a_fp4, rep_a_blockscale
# hidden size dimension is split to one halfpytho sized tensor.
intermediate = torch.empty((m * num_topk, w1_fp4.shape[1] // 2),
device=device,
dtype=out_dtype)
torch.ops._C.silu_and_mul(intermediate, c1)
int_fp4, int_blockscale = ops.scaled_fp4_experts_quant(
intermediate,
a2_gscale,
expert_offsets,
blockscale_offsets,
num_topk,
MAX_TOKENS_PER_EXPERT=MAX_TOKENS_PER_EXPERT)
c2 = ops.cutlass_fp4_moe_mm(int_fp4, w2_fp4, int_blockscale, w2_blockscale,
w2_alphas, problem_sizes2, expert_offsets[:-1],
blockscale_offsets[:-1], out_dtype, device)
del int_fp4, int_blockscale
out = (c2[c_map].view(m, num_topk, k) *
topk_weights.view(m, num_topk, 1).half()).sum(dim=1)
return out.to(dtype=out_dtype)

View File

@ -643,7 +643,7 @@ class FusedMoE(torch.nn.Module):
expert_id = self._map_global_expert_id_to_local_expert_id(expert_id)
if expert_id == -1:
return
quant_method_name = self.quant_method.__class__.__name__
# compressed-tensors checkpoints with packed weights are stored flipped
# TODO (mgoin): check self.quant_method.quant_config.quant_format
# against known CompressionFormat enum values that have this quality
@ -697,8 +697,9 @@ class FusedMoE(torch.nn.Module):
# this is needed for compressed-tensors only
loaded_weight = loaded_weight.to(param.data.device)
if param.data[expert_id] != 1 and (param.data[expert_id] -
loaded_weight).abs() > 1e-5:
if ("compressed" in quant_method_name.lower()
and param.data[expert_id] != 1
and (param.data[expert_id] - loaded_weight).abs() > 1e-5):
raise ValueError(
"input_scales of w1 and w3 of a layer "
f"must be equal. But got {param.data[expert_id]} "
@ -718,6 +719,22 @@ class FusedMoE(torch.nn.Module):
tp_rank=self.tp_rank)
return
if "ModelOpt" in quant_method_name:
if ('weight_scale_2' in weight_name
or 'input_scale' in weight_name):
self._load_per_tensor_weight_scale(shard_id=shard_id,
param=param,
loaded_weight=loaded_weight,
expert_id=expert_id)
elif "weight" in weight_name:
self._load_model_weight_or_group_weight_scale(
shard_id=shard_id,
shard_dim=shard_dim,
loaded_weight=loaded_weight,
expert_data=expert_data,
tp_rank=self.tp_rank)
return
# Case weight scales, zero_points and offset
if ("scale" in weight_name or "zero" in weight_name
or "offset" in weight_name):

View File

@ -1,6 +1,6 @@
# SPDX-License-Identifier: Apache-2.0
from typing import Any, Dict, List, Optional, Union
from typing import Any, Callable, Dict, List, Optional, Union
import torch
from torch.nn import Module
@ -9,6 +9,8 @@ from torch.nn.parameter import Parameter
from vllm._custom_ops import (cutlass_scaled_fp4_mm,
cutlass_scaled_mm_supports_fp4, scaled_fp4_quant)
from vllm.logger import init_logger
from vllm.model_executor.layers.fused_moe.layer import (
FusedMoE, FusedMoEMethodBase, FusedMoeWeightScaleSupported)
from vllm.model_executor.layers.linear import (LinearBase, LinearMethodBase,
UnquantizedLinearMethod)
from vllm.model_executor.layers.quantization import QuantizationMethods
@ -210,25 +212,37 @@ class ModelOptNvFp4Config(QuantizationConfig):
"`hf_quant_config.json` file for your model's "
"quant configuration.")
is_checkpoint_nvfp4_serialized = ("NVFP4" in quant_method)
kv_cache_quant_algo = quant_config["kv_cache_quant_algo"]
group_size = quant_config["group_size"]
exclude_modules = quant_config["exclude_modules"]
if not (group_size and kv_cache_quant_algo and exclude_modules):
if ("group_size" and "kv_cache_quant_algo"
and "exclude_modules") not in quant_config:
raise ValueError("NVFP4 quantization requires group size and "
"kv_cache_quant_algo specified in "
"hf_quant_config.json")
kv_cache_quant_algo = quant_config["kv_cache_quant_algo"]
group_size = quant_config["group_size"]
exclude_modules = quant_config["exclude_modules"]
return cls(is_checkpoint_nvfp4_serialized, kv_cache_quant_algo,
exclude_modules, group_size)
def is_layer_excluded(self, prefix: str, exclude_modules: List):
import re
for pattern in exclude_modules:
regex_str = pattern.replace('.', r'\.').replace('*', r'.*')
if re.fullmatch(regex_str, prefix):
return True
return False
def get_quant_method(self, layer: torch.nn.Module,
prefix: str) -> Optional["QuantizeMethodBase"]:
from vllm.attention.layer import Attention # Avoid circular import
if isinstance(layer, LinearBase):
if is_layer_skipped(prefix, self.exclude_modules):
if (is_layer_skipped(prefix, self.exclude_modules)
or self.is_layer_excluded(prefix, self.exclude_modules)):
return UnquantizedLinearMethod()
return ModelOptNvFp4LinearMethod(self)
elif isinstance(layer, Attention):
return ModelOptFp8KVCacheMethod(self)
elif isinstance(layer, FusedMoE):
return ModelOptNvFp4FusedMoE(self)
return None
@ -409,3 +423,235 @@ class ModelOptNvFp4LinearMethod(LinearMethodBase):
if bias is not None:
out = out + bias
return out.view(*output_shape)
class ModelOptNvFp4FusedMoE(FusedMoEMethodBase):
"""
MoE Method for FP4 Quantization.
Args:
quant_config: NVFP4 Quant Config
"""
def __init__(self, quant_config: ModelOptNvFp4Config):
self.quant_config = quant_config
def create_weights(self, layer: torch.nn.Module, num_experts: int,
hidden_size: int, intermediate_size_per_partition: int,
params_dtype: torch.dtype, **extra_weight_attrs):
if not self.quant_config.is_checkpoint_nvfp4_serialized:
raise ValueError("NVFP4 quantization was selected, "
" dynamic quantization is not supported.")
layer.quant_config = self.quant_config
weight_dtype = torch.uint8
weight_scale_dtype = torch.float8_e4m3fn
weight_loader = extra_weight_attrs.get("weight_loader")
# GEMM 1
w13_weight = ModelWeightParameter(
data=torch.empty(
num_experts,
2 * intermediate_size_per_partition,
# 2 fp4 items are packed in the input dimension
hidden_size // 2,
dtype=weight_dtype),
input_dim=1,
output_dim=2,
weight_loader=weight_loader)
layer.register_parameter("w13_weight", w13_weight)
# GEMM 2
w2_weight = ModelWeightParameter(
data=torch.empty(
num_experts,
hidden_size,
# 2 fp4 items are packed in the input dimension
intermediate_size_per_partition // 2,
dtype=weight_dtype),
input_dim=1,
output_dim=2,
weight_loader=weight_loader)
layer.register_parameter("w2_weight", w2_weight)
w13_weight_scale = ModelWeightParameter(
data=torch.empty(
num_experts,
2 * intermediate_size_per_partition,
# 2 fp4 items are packed in the input dimension
hidden_size // self.quant_config.group_size,
dtype=weight_scale_dtype),
input_dim=1,
output_dim=2,
weight_loader=weight_loader)
layer.register_parameter("w13_weight_scale", w13_weight_scale)
w2_weight_scale = ModelWeightParameter(
data=torch.empty(
num_experts,
hidden_size,
# 2 fp4 items are packed in the input dimension
intermediate_size_per_partition //
self.quant_config.group_size,
dtype=weight_scale_dtype),
input_dim=1,
output_dim=2,
weight_loader=weight_loader)
layer.register_parameter("w2_weight_scale", w2_weight_scale)
extra_weight_attrs.update(
{"quant_method": FusedMoeWeightScaleSupported.BLOCK.value})
w13_weight_scale_2 = PerTensorScaleParameter(
data=torch.empty(num_experts, 2, dtype=torch.float32),
weight_loader=weight_loader)
layer.register_parameter("w13_weight_scale_2", w13_weight_scale_2)
w2_weight_scale_2 = PerTensorScaleParameter(
data=torch.empty(num_experts, dtype=torch.float32),
weight_loader=weight_loader)
layer.register_parameter("w2_weight_scale_2", w2_weight_scale_2)
extra_weight_attrs.update(
{"quant_method": FusedMoeWeightScaleSupported.TENSOR.value})
w13_input_scale = PerTensorScaleParameter(data=torch.empty(
num_experts, 2, dtype=torch.float32),
weight_loader=weight_loader)
layer.register_parameter("w13_input_scale", w13_input_scale)
w2_input_scale = PerTensorScaleParameter(data=torch.empty(
num_experts, dtype=torch.float32),
weight_loader=weight_loader)
layer.register_parameter("w2_input_scale", w2_input_scale)
def swizzle_blockscale(self, scale: torch.tensor):
assert (scale.dtype == torch.float8_e4m3fn)
# Pad and blockwise interleave weight_scale
scale_ndim = scale.ndim
if scale.ndim == 2:
scale = scale.unsqueeze(0)
assert scale.ndim == 3
B, M, K = scale.shape
round_up_multiple = lambda x, m: (x + m - 1) // m * m
M_padded = round_up_multiple(M, 128)
K_padded = round_up_multiple(K, 4)
padded_scale = torch.zeros((B, M_padded, K_padded), dtype=scale.dtype)
padded_scale[:B, :M, :K] = scale
batches, rows, cols = padded_scale.shape
assert rows % 128 == 0
assert cols % 4 == 0
padded_scale = padded_scale.reshape(batches, rows // 128, 4, 32,
cols // 4, 4)
swizzled_scale = padded_scale.permute((0, 1, 4, 3, 2, 5))
swizzled_scale = swizzled_scale.contiguous().cuda()
return (swizzled_scale.reshape(M, K)
if scale_ndim == 2 else swizzled_scale.reshape(B, M, K))
def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
# GEMM 1
assert torch.allclose(
layer.w13_weight_scale_2[:, 0], layer.w13_weight_scale_2[:, 1]), (
"Expected w1_weight_scale_2 to equal w3_weight_scale_2")
w13_weight_scale_2 = layer.w13_weight_scale_2[:, 0]
layer.w13_weight_scale_2 = Parameter(w13_weight_scale_2,
requires_grad=False)
w13_input_scale = layer.w13_input_scale.max(dim=1).values.to(
torch.float32)
layer.g1_alphas = Parameter(
(w13_input_scale * w13_weight_scale_2).to(torch.float32),
requires_grad=False)
assert (layer.w13_weight_scale.shape[2] % 16 == 0), (
"Expected weight_scale.dim(1) to be divisible by 16")
assert (layer.w13_weight_scale.dtype == torch.float8_e4m3fn), (
"Weight Blockscale must be represented as FP8-E4M3")
w13_blockscale_swizzled = self.swizzle_blockscale(
layer.w13_weight_scale)
layer.w13_blockscale_swizzled = Parameter(w13_blockscale_swizzled,
requires_grad=False)
# This is for quantization, so we need to invert it.
layer.w13_input_scale_quant = Parameter(
(1 / w13_input_scale).to(torch.float32), requires_grad=False)
# GEMM 2
layer.g2_alphas = Parameter(
(layer.w2_input_scale * layer.w2_weight_scale_2).to(torch.float32),
requires_grad=False)
# This is for quantization, so we need to invert it.
layer.w2_input_scale_quant = Parameter(
(1 / layer.w2_input_scale).to(torch.float32), requires_grad=False)
assert (layer.w2_weight_scale.shape[2] % 16 == 0), (
"Expected weight_scale.dim(1) to be divisible by 16")
assert (layer.w2_weight_scale.dtype == torch.float8_e4m3fn), (
"Weight Blockscale must be represented as FP8-E4M3")
w2_blockscale_swizzled = self.swizzle_blockscale(layer.w2_weight_scale)
layer.w2_blockscale_swizzled = Parameter(w2_blockscale_swizzled,
requires_grad=False)
return
def apply(
self,
layer: torch.nn.Module,
x: torch.Tensor,
router_logits: torch.Tensor,
top_k: int,
renormalize: bool,
use_grouped_topk: bool = False,
topk_group: Optional[int] = None,
num_expert_group: Optional[int] = None,
global_num_experts: int = -1,
expert_map: Optional[torch.Tensor] = None,
custom_routing_function: Optional[Callable] = None,
scoring_func: str = "softmax",
e_score_correction_bias: Optional[torch.Tensor] = None,
apply_router_weight_on_input: bool = False,
activation: str = "silu",
):
assert activation == "silu", "Only SiLU activation is supported."
assert not apply_router_weight_on_input, (
"Router weight on input is not "
"supported for ModelOptNvFp4FusedMoE.")
assert expert_map is None, ("Expert Parallelism /expert_map "
"is currently not supported for "
"ModelOptNvFp4FusedMoE.")
topk_weights, topk_ids = FusedMoE.select_experts(
hidden_states=x,
router_logits=router_logits,
use_grouped_topk=use_grouped_topk,
top_k=top_k,
renormalize=renormalize,
topk_group=topk_group,
num_expert_group=num_expert_group,
custom_routing_function=custom_routing_function,
scoring_func=scoring_func,
e_score_correction_bias=e_score_correction_bias)
from vllm.model_executor.layers.fused_moe.cutlass_moe import (
cutlass_moe_fp4)
# Cutlass moe takes in activations in BF16/Half precision
# and fp4 quantized weights loaded from the checkpoint
return cutlass_moe_fp4(a=x,
w1_fp4=layer.w13_weight,
w1_blockscale=layer.w13_blockscale_swizzled,
w1_alphas=layer.g1_alphas,
w2_fp4=layer.w2_weight,
w2_blockscale=layer.w2_blockscale_swizzled,
w2_alphas=layer.g2_alphas,
topk_weights=topk_weights,
topk_ids=topk_ids,
m=x.shape[0],
n=layer.w2_weight.shape[2] * 2,
k=x.shape[1],
e=layer.w13_weight.shape[0],
a1_gscale=layer.w13_input_scale_quant,
a2_gscale=layer.w2_input_scale_quant,
device=x.device).to(x.dtype)