[Kernel] Support deep_gemm for linear methods (#19085)

Signed-off-by: artetaout <lulala341@gmail.com>
Author: artetaout
Committed by: GitHub
Date: 2025-06-11 15:14:45 +08:00
Parent: 5039ec2336
Commit: b8e809a057

3 changed files with 124 additions and 1 deletion


@@ -0,0 +1,84 @@
# SPDX-License-Identifier: Apache-2.0
import importlib.util
import logging

import torch

from vllm.platforms import current_platform
from vllm.triton_utils import triton
from vllm.utils import direct_register_custom_op

# deep_gemm is an optional dependency; only import it when it is installed.
has_deep_gemm = importlib.util.find_spec("deep_gemm") is not None
if has_deep_gemm:
    import deep_gemm

logger = logging.getLogger(__name__)


def prepare_block_fp8_matmul_inputs(
    A: torch.Tensor,
    B: torch.Tensor,
    As: torch.Tensor,
    Bs: torch.Tensor,
    block_size: list[int],
    output_dtype: torch.dtype = torch.float16,
) -> tuple[int, int, int, torch.Tensor]:
    """Validate block-quantized inputs and allocate the output tensor.

    A/As are the activations and their per-token-group scales; B/Bs are the
    weight and its per-block scales. Returns the GEMM sizes (M, N, K) and an
    empty output tensor C of shape A.shape[:-1] + (N,).
    """
    assert len(block_size) == 2
    block_n, block_k = block_size[0], block_size[1]

    assert A.shape[-1] == B.shape[-1]
    assert A.shape[:-1] == As.shape[:-1]
    assert A.is_contiguous()
    assert triton.cdiv(A.shape[-1], block_k) == As.shape[-1]

    M = A.numel() // A.shape[-1]

    assert B.ndim == 2
    assert B.is_contiguous()
    assert Bs.ndim == 2
    N, K = B.shape
    assert triton.cdiv(N, block_n) == Bs.shape[0]
    assert triton.cdiv(K, block_k) == Bs.shape[1]

    C_shape = A.shape[:-1] + (N, )
    C = A.new_empty(C_shape, dtype=output_dtype)

    return M, N, K, C


def w8a8_block_fp8_matmul_deepgemm(
    A: torch.Tensor,
    B: torch.Tensor,
    As: torch.Tensor,
    Bs: torch.Tensor,
    block_size: list[int],
    output_dtype: torch.dtype,
) -> torch.Tensor:
    M, N, K, C = prepare_block_fp8_matmul_inputs(A, B, As, Bs, block_size,
                                                 output_dtype)
    # DeepGEMM only supports bfloat16 output tensors.
    assert C.dtype == torch.bfloat16
    deep_gemm.gemm_fp8_fp8_bf16_nt((A, As), (B, Bs), C)
    return C


def w8a8_block_fp8_matmul_deepgemm_fake(
    A: torch.Tensor,
    B: torch.Tensor,
    As: torch.Tensor,
    Bs: torch.Tensor,
    block_size: list[int],
    output_dtype: torch.dtype,
) -> torch.Tensor:
    # Fake (meta) implementation used for tracing/compilation: it only
    # allocates an output of the correct shape and dtype.
    M, N, K, C = prepare_block_fp8_matmul_inputs(A, B, As, Bs, block_size,
                                                 output_dtype)
    return C


direct_register_custom_op(
    op_name="w8a8_block_fp8_matmul_deepgemm",
    op_func=w8a8_block_fp8_matmul_deepgemm,
    mutates_args=[],
    fake_impl=w8a8_block_fp8_matmul_deepgemm_fake,
    dispatch_key=current_platform.dispatch_key,
)
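For reference, a minimal usage sketch of the op registered above, once this module has been imported so the registration runs. The sizes, the 128x128 weight blocks, and the random scale values are illustrative assumptions; the real DeepGEMM kernel additionally expects an SM90 GPU and activation scales in the layout produced by per_token_group_quant_fp8(..., column_major_scales=True):

# Illustrative only: M x K fp8 activations, N x K fp8 weights, bf16 output.
M, N, K = 16, 512, 256
block_size = [128, 128]
A = torch.randn(M, K, device="cuda").to(torch.float8_e4m3fn)
B = torch.randn(N, K, device="cuda").to(torch.float8_e4m3fn)
As = torch.rand(M, K // 128, device="cuda", dtype=torch.float32)        # per-token-group scales
Bs = torch.rand(N // 128, K // 128, device="cuda", dtype=torch.float32)  # per-block scales

C = torch.ops.vllm.w8a8_block_fp8_matmul_deepgemm(
    A, B, As, Bs, block_size, output_dtype=torch.bfloat16)
assert C.shape == (M, N) and C.dtype == torch.bfloat16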


@@ -402,6 +402,7 @@ class Fp8LinearMethod(LinearMethodBase):
        if self.block_quant:
            assert self.quant_config.weight_block_size is not None
            return torch.ops.vllm.apply_w8a8_block_fp8_linear(
                input=x,
                weight=layer.weight,


@@ -3,12 +3,14 @@
# Adapted from https://github.com/sgl-project/sglang/pull/2575
import functools
import importlib.util
import json
import os
from typing import Any, Callable, Optional, Union

import torch

import vllm.envs as envs
from vllm import _custom_ops as ops
from vllm.logger import init_logger
from vllm.model_executor.layers.quantization.utils.quant_utils import (
@@ -20,6 +22,7 @@ from vllm.triton_utils import tl, triton
from vllm.utils import direct_register_custom_op

logger = init_logger(__name__)

has_deep_gemm = importlib.util.find_spec("deep_gemm") is not None
def is_fp8(x: Union[torch.dtype, torch.Tensor]) -> bool:
@@ -98,6 +101,19 @@ def dispatch_w8a8_blockscale_func(
    return w8a8_block_fp8_matmul


def should_use_deepgemm(output_dtype: torch.dtype, weight: torch.Tensor):
    """
    Check whether DeepGEMM should be used, based on the output dtype and the
    weight shape. DeepGEMM is only supported when the output dtype is
    bfloat16 and both weight dimensions are divisible by 128.
    """
    return (current_platform.is_cuda()
            and current_platform.is_device_capability(90) and has_deep_gemm
            and envs.VLLM_USE_DEEP_GEMM and output_dtype == torch.bfloat16
            and weight.shape[0] % 128 == 0 and weight.shape[1] % 128 == 0)
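A quick illustration of the gate (not part of the diff): assuming an SM90 GPU with deep_gemm installed and VLLM_USE_DEEP_GEMM enabled, only bfloat16 outputs with 128-aligned weight shapes take the DeepGEMM path:

# Illustrative only; assumes the platform and env checks above already pass.
w_aligned = torch.empty(4096, 4096)
w_ragged = torch.empty(4000, 4096)
should_use_deepgemm(torch.bfloat16, w_aligned)   # True: both dims are multiples of 128
should_use_deepgemm(torch.bfloat16, w_ragged)    # False: 4000 is not a multiple of 128
should_use_deepgemm(torch.float16, w_aligned)    # False: only bfloat16 output is supported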
# TODO fix ROCm->Triton custom path:
# https://github.com/vllm-project/vllm/issues/14397
def apply_w8a8_block_fp8_linear(
@@ -114,6 +130,29 @@ def apply_w8a8_block_fp8_linear(
    # View input as 2D matrix for fp8 methods
    input_2d = input.view(-1, input.shape[-1])
    output_shape = [*input.shape[:-1], weight.shape[0]]
    output_dtype = input.dtype
    if should_use_deepgemm(output_dtype, weight):

        input_2d = input.view(-1, input.shape[-1])
        output_shape = [*input.shape[:-1], weight.shape[0]]

        # Quantize the input per token group (group size = block_size[1]),
        # producing column-major scales as DeepGEMM expects.
        q_input, x_scale = per_token_group_quant_fp8(
            input_2d,
            block_size[1],
            column_major_scales=True,
        )

        output = torch.ops.vllm.w8a8_block_fp8_matmul_deepgemm(
            q_input,
            weight,
            x_scale,
            weight_scale,
            block_size,
            output_dtype=output_dtype)
        if bias is not None:
            output += bias
        return output.to(dtype=output_dtype).view(*output_shape)
    if current_platform.is_cuda():
        if current_platform.has_device_capability(100):
@@ -134,7 +173,6 @@ def apply_w8a8_block_fp8_linear(
    w8a8_blockscale_func = dispatch_w8a8_blockscale_func(
        use_cutlass, use_aiter_and_is_supported)

    if use_cutlass:
        q_input, x_scale = per_token_group_quant_fp8(
            input_2d, block_size[1], column_major_scales=use_cutlass)
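Both the DeepGEMM branch above and the CUTLASS path quantize the activations per token group before the matmul. For context, here is a minimal reference sketch of what that quantization computes, one scale per group of block_size[1] elements in each row; it ignores the column-major scale layout and kernel-level details that the real per_token_group_quant_fp8 handles, and per_token_group_quant_fp8_ref is a hypothetical name used only for this sketch:

def per_token_group_quant_fp8_ref(x: torch.Tensor, group_size: int, eps: float = 1e-10):
    # Reference only: split each row into groups, scale each group by its
    # max-abs value so it fits the fp8 e4m3 range, then cast to fp8.
    fp8_max = torch.finfo(torch.float8_e4m3fn).max
    groups = x.view(*x.shape[:-1], x.shape[-1] // group_size, group_size)
    amax = groups.abs().amax(dim=-1, keepdim=True).clamp(min=eps).to(torch.float32)
    scale = amax / fp8_max
    q = (groups.to(torch.float32) / scale).clamp(-fp8_max, fp8_max).to(torch.float8_e4m3fn)
    return q.view(x.shape), scale.squeeze(-1)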