[ROCm] Add missing gemm_a8w8_blockscale import (#28378)

Signed-off-by: Yong Hoon Shin <yhshin@meta.com>

commit 021143561f
parent 30700b1cd7
Author: Yong Hoon Shin <yhshin@meta.com>
Date:   2025-11-10 13:13:36 -10:00
Committer: GitHub

@@ -316,38 +316,39 @@ class W8A8BlockFp8LinearOp:
         assert self.act_quant_group_shape == GroupShape(1, 128)
         n, k = weight.shape
-        if input_scale is not None:
-            q_input = input_2d
-        # MI350 case uses triton kernel
-        if (
+        use_triton = (
             not current_platform.is_fp8_fnuz()
             and rocm_aiter_ops.is_triton_gemm_w8a8_tuned(n, k)
-        ):
+        )
+        if use_triton:
+            gemm_a8w8_blockscale_op = rocm_aiter_ops.triton_gemm_a8w8_blockscale
+        else:
+            gemm_a8w8_blockscale_op = rocm_aiter_ops.gemm_w8a8_blockscale
+        if input_scale is not None:
+            q_input = input_2d
+        # MI350 case uses triton kernel
+        elif use_triton:
             q_input, input_scale = per_token_group_quant_fp8(
                 input_2d,
                 self.act_quant_group_shape.col,
                 column_major_scales=False,
                 use_ue8m0=False,
             )
-            return rocm_aiter_ops.triton_gemm_a8w8_blockscale(
-                q_input,
-                weight,
-                input_scale,
-                weight_scale,
-                input_2d.dtype,
-            )
         # MI300 uses tuned AITER ASM/C++ kernel
         else:
             q_input, input_scale = rocm_aiter_ops.per_1x128_fp8_quant(input_2d)
-            return rocm_aiter_ops.gemm_w8a8_blockscale(
-                q_input,
-                weight,
-                input_scale,
-                weight_scale,
-                input_2d.dtype,
-            )
+        return gemm_a8w8_blockscale_op(
+            q_input,
+            weight,
+            input_scale,
+            weight_scale,
+            list(self.weight_group_shape),
+            output_dtype=input_2d.dtype,
+        )

     def _run_triton(
         self,
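
For context, the hunk above replaces two early returns with a single call site: the backend (the Triton kernel on MI350-class parts, the tuned AITER ASM/C++ kernel on MI300) is chosen once up front and bound to gemm_a8w8_blockscale_op, and the activation is quantized only when the caller did not already supply input_scale. Below is a minimal, runnable sketch of that shape; every name in it (triton_gemm_stub, asm_gemm_stub, quant_per_token, quant_per_1x128, w8a8_blockscale) is a hypothetical stand-in, not a vLLM or AITER API.

    # Hypothetical stand-ins; the real ops live on rocm_aiter_ops.
    def triton_gemm_stub(q_input, weight, input_scale, weight_scale,
                         group_shape, output_dtype):
        # Stand-in for the Triton blockscale GEMM (MI350 path).
        return ("triton_gemm", output_dtype)

    def asm_gemm_stub(q_input, weight, input_scale, weight_scale,
                      group_shape, output_dtype):
        # Stand-in for the tuned AITER ASM/C++ blockscale GEMM (MI300 path).
        return ("asm_gemm", output_dtype)

    def quant_per_token(x):
        # Stand-in for per_token_group_quant_fp8.
        return x, "per-token-group scales"

    def quant_per_1x128(x):
        # Stand-in for rocm_aiter_ops.per_1x128_fp8_quant.
        return x, "1x128 scales"

    def w8a8_blockscale(input_2d, weight, weight_scale,
                        input_scale=None, use_triton=True):
        # 1) Pick the kernel once, independent of quantization state.
        gemm_op = triton_gemm_stub if use_triton else asm_gemm_stub

        # 2) Quantize the activation only when no scale was supplied.
        if input_scale is not None:
            q_input = input_2d
        elif use_triton:
            q_input, input_scale = quant_per_token(input_2d)
        else:
            q_input, input_scale = quant_per_1x128(input_2d)

        # 3) One call site serves both backends.
        return gemm_op(q_input, weight, input_scale, weight_scale,
                       [128, 128], output_dtype="bf16")

    # Both hardware paths now funnel through the same return:
    print(w8a8_blockscale("x", "w", "w_scales", use_triton=True))
    print(w8a8_blockscale("x", "w", "w_scales", use_triton=False))

Funneling both backends through one call site also keeps the argument lists in sync, which is what lets the new list(self.weight_group_shape) / output_dtype= signature in the hunk apply uniformly to both kernels.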