mirror of
https://github.com/vllm-project/vllm.git
synced 2025-11-12 00:54:47 +08:00
[ROCm] Add missing gemm_a8w8_blockscale import (#28378)
Signed-off-by: Yong Hoon Shin <yhshin@meta.com>
This commit is contained in:
@ -316,38 +316,39 @@ class W8A8BlockFp8LinearOp:
|
||||
assert self.act_quant_group_shape == GroupShape(1, 128)
|
||||
|
||||
n, k = weight.shape
|
||||
if input_scale is not None:
|
||||
q_input = input_2d
|
||||
|
||||
# MI350 case uses triton kernel
|
||||
if (
|
||||
use_triton = (
|
||||
not current_platform.is_fp8_fnuz()
|
||||
and rocm_aiter_ops.is_triton_gemm_w8a8_tuned(n, k)
|
||||
):
|
||||
)
|
||||
|
||||
if use_triton:
|
||||
gemm_a8w8_blockscale_op = rocm_aiter_ops.triton_gemm_a8w8_blockscale
|
||||
else:
|
||||
gemm_a8w8_blockscale_op = rocm_aiter_ops.gemm_w8a8_blockscale
|
||||
|
||||
if input_scale is not None:
|
||||
q_input = input_2d
|
||||
# MI350 case uses triton kernel
|
||||
elif use_triton:
|
||||
q_input, input_scale = per_token_group_quant_fp8(
|
||||
input_2d,
|
||||
self.act_quant_group_shape.col,
|
||||
column_major_scales=False,
|
||||
use_ue8m0=False,
|
||||
)
|
||||
return rocm_aiter_ops.triton_gemm_a8w8_blockscale(
|
||||
q_input,
|
||||
weight,
|
||||
input_scale,
|
||||
weight_scale,
|
||||
input_2d.dtype,
|
||||
)
|
||||
|
||||
# MI300 uses tuned AITER ASM/C++ kernel
|
||||
else:
|
||||
q_input, input_scale = rocm_aiter_ops.per_1x128_fp8_quant(input_2d)
|
||||
return rocm_aiter_ops.gemm_w8a8_blockscale(
|
||||
q_input,
|
||||
weight,
|
||||
input_scale,
|
||||
weight_scale,
|
||||
input_2d.dtype,
|
||||
)
|
||||
|
||||
return gemm_a8w8_blockscale_op(
|
||||
q_input,
|
||||
weight,
|
||||
input_scale,
|
||||
weight_scale,
|
||||
list(self.weight_group_shape),
|
||||
output_dtype=input_2d.dtype,
|
||||
)
|
||||
|
||||
def _run_triton(
|
||||
self,
|
||||
|
||||
Reference in New Issue
Block a user