[Refactor] Remove DeepGEMM OP Register (#25710)

Signed-off-by: yewentao256 <zhyanwentao@126.com>
Author: Wentao Ye, 2025-09-25 20:13:41 -04:00 (committed by GitHub)
Commit: 9fe4c2bdb9 (parent: 081b5594a2)
2 changed files with 5 additions and 90 deletions

vllm/model_executor/layers/quantization/deepgemm.py (deleted):

@@ -1,78 +0,0 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import logging

import torch

from vllm.triton_utils import triton
from vllm.utils import direct_register_custom_op
from vllm.utils.deep_gemm import fp8_gemm_nt

logger = logging.getLogger(__name__)


def prepare_block_fp8_matmul_inputs(
    A: torch.Tensor,
    B: torch.Tensor,
    As: torch.Tensor,
    Bs: torch.Tensor,
    block_size: list[int],
    output_dtype: torch.dtype = torch.float16,
) -> tuple[int, int, int, torch.Tensor]:
    assert len(block_size) == 2
    block_n, block_k = block_size[0], block_size[1]

    assert A.shape[-1] == B.shape[-1]
    assert A.shape[:-1] == As.shape[:-1]
    assert A.is_contiguous()
    assert triton.cdiv(A.shape[-1], block_k) == As.shape[-1]
    M = A.numel() // A.shape[-1]

    assert B.ndim == 2
    assert B.is_contiguous()
    assert Bs.ndim == 2
    N, K = B.shape
    assert triton.cdiv(N, block_n) == Bs.shape[0]
    assert triton.cdiv(K, block_k) == Bs.shape[1]

    C_shape = A.shape[:-1] + (N, )
    C = A.new_empty(C_shape, dtype=output_dtype)

    return M, N, K, C


def w8a8_block_fp8_matmul_deepgemm(
    A: torch.Tensor,
    B: torch.Tensor,
    As: torch.Tensor,
    Bs: torch.Tensor,
    block_size: list[int],
    output_dtype: torch.dtype,
) -> torch.Tensor:
    M, N, K, C = prepare_block_fp8_matmul_inputs(A, B, As, Bs, block_size,
                                                 output_dtype)
    # Deepgemm only supports output tensor type as bfloat16
    assert C.dtype == torch.bfloat16
    fp8_gemm_nt((A, As), (B, Bs), C)
    return C


def w8a8_block_fp8_matmul_deepgemm_fake(
    A: torch.Tensor,
    B: torch.Tensor,
    As: torch.Tensor,
    Bs: torch.Tensor,
    block_size: list[int],
    output_dtype: torch.dtype,
) -> torch.Tensor:
    M, N, K, C = prepare_block_fp8_matmul_inputs(A, B, As, Bs, block_size,
                                                 output_dtype)
    return C


direct_register_custom_op(
    op_name="w8a8_block_fp8_matmul_deepgemm",
    op_func=w8a8_block_fp8_matmul_deepgemm,
    fake_impl=w8a8_block_fp8_matmul_deepgemm_fake,
)
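
For context, the deleted module wired its kernel into torch.ops through vLLM's direct_register_custom_op helper, pairing the real implementation with a shape-only fake implementation so that torch.compile and meta-tensor tracing can infer output metadata without launching the kernel. Below is a minimal, self-contained sketch of that pattern using the underlying torch.library API directly; the demo:: namespace, the plain-matmul body, and the 2-D shapes are illustrative assumptions, not vLLM code.

import torch


@torch.library.custom_op("demo::block_fp8_matmul", mutates_args=())
def block_fp8_matmul(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
    # Real implementation: a backend kernel would be invoked here; a plain
    # matmul stands in so the sketch runs anywhere.
    return (a.float() @ b.float().t()).to(torch.bfloat16)


@block_fp8_matmul.register_fake
def _(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
    # Fake implementation: only allocate an output with the right shape and
    # dtype, mirroring w8a8_block_fp8_matmul_deepgemm_fake above.
    return a.new_empty((a.shape[0], b.shape[0]), dtype=torch.bfloat16)


# Once registered, the op is addressable under torch.ops, which is how the
# removed torch.ops.vllm.w8a8_block_fp8_matmul_deepgemm was reached.
out = torch.ops.demo.block_fp8_matmul(torch.randn(4, 8), torch.randn(16, 8))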

Second changed file (the module defining apply_w8a8_block_fp8_linear):

@@ -23,7 +23,7 @@ from vllm.model_executor.parameter import (BlockQuantScaleParameter,
from vllm.platforms import current_platform
from vllm.triton_utils import tl, triton
from vllm.utils import direct_register_custom_op
-from vllm.utils.deep_gemm import (is_deep_gemm_e8m0_used,
+from vllm.utils.deep_gemm import (fp8_gemm_nt, is_deep_gemm_e8m0_used,
should_use_deepgemm_for_fp8_linear)
logger = init_logger(__name__)
@@ -141,17 +141,10 @@ def apply_w8a8_block_fp8_linear(
block_size[1],
column_major_scales=True,
)
-# ensure DeepGEMM-backed custom op is registered before use
-import vllm.model_executor.layers.quantization.deepgemm  # noqa: F401
-output = torch.ops.vllm.w8a8_block_fp8_matmul_deepgemm(
-    q_input,
-    weight,
-    x_scale,
-    weight_scale,
-    block_size,
-    output_dtype=output_dtype)
+output = torch.empty((q_input.shape[0], weight.shape[0]),
+                     dtype=torch.bfloat16,
+                     device=q_input.device)
+fp8_gemm_nt((q_input, x_scale), (weight, weight_scale), output)
if bias is not None:
output += bias
return output.to(dtype=output_dtype).view(*output_shape)
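
The replacement path drops the op-registration indirection and calls fp8_gemm_nt directly on the quantized activation/weight pairs. A rough sketch of that flow, with the block-scale shapes implied by the deleted prepare_block_fp8_matmul_inputs asserts, follows; fake_fp8_gemm_nt and block_fp8_linear_sketch are hypothetical stand-ins (the real fp8_gemm_nt requires DeepGEMM and FP8 tensors), so the plain matmul and float inputs are illustrative only.

import torch


def fake_fp8_gemm_nt(a_pair, b_pair, out):
    # Stand-in for vllm.utils.deep_gemm.fp8_gemm_nt((A, As), (B, Bs), C):
    # ignore the scales and write a plain matmul result into the
    # preallocated output buffer.
    a, _a_scale = a_pair
    b, _b_scale = b_pair
    out.copy_((a.float() @ b.float().t()).to(out.dtype))


def block_fp8_linear_sketch(q_input, x_scale, weight, weight_scale,
                            bias=None, output_dtype=torch.float16):
    # Mirror the new inline path: bfloat16 output buffer, direct GEMM call,
    # optional bias add, then cast back to the caller's dtype.
    M, N = q_input.shape[0], weight.shape[0]
    output = torch.empty((M, N), dtype=torch.bfloat16, device=q_input.device)
    fake_fp8_gemm_nt((q_input, x_scale), (weight, weight_scale), output)
    if bias is not None:
        output += bias
    return output.to(dtype=output_dtype)


# Block-scale shapes for 128x128 blocks, matching the deleted asserts:
# x_scale is (M, ceil(K / block_k)); weight_scale is
# (ceil(N / block_n), ceil(K / block_k)).
M, K, N, block = 4, 512, 256, 128
q_input = torch.randn(M, K)          # FP8 in the real path
x_scale = torch.ones(M, K // block)
weight = torch.randn(N, K)
weight_scale = torch.ones(N // block, K // block)
out = block_fp8_linear_sketch(q_input, x_scale, weight, weight_scale)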