[Bugfix] Fix persistent_masked_m_silu_mul_quant tests (#28366)

Signed-off-by: Varun Sundar Rabindranath <vsundarr@redhat.com>
Co-authored-by: Varun Sundar Rabindranath <vsundarr@redhat.com>
Authored by Varun Sundar Rabindranath on 2025-11-10 12:21:52 -05:00; committed by GitHub
parent d0e186c16f
commit b039bfda8f
3 changed files with 16 additions and 7 deletions
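
In short: the kernel now selects its wide 8-warp launch configuration only when H / GROUP_SIZE is a multiple of 8, the divisibility check on the input's last dimension is rewritten in terms of GROUP_SIZE, the Python wrapper gains an optional use_ue8m0 override, and the test suite pins use_ue8m0=False and adds an H = 128 * 33 case that lands on the single-warp fallback.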


@@ -578,11 +578,13 @@ void persistent_masked_m_silu_mul_quant(
// This kernel currently only supports H % 128 == 0 and assumes a
// fixed GROUP_SIZE of 128.
+ static constexpr int GROUP_SIZE = 128;
TORCH_CHECK(input.dtype() == torch::kBFloat16);
TORCH_CHECK(y_q.dtype() == torch::kFloat8_e4m3fn ||
y_q.dtype() == torch::kFloat8_e4m3fnuz);
TORCH_CHECK(y_s.dtype() == torch::kFloat32);
- TORCH_CHECK(input.size(-1) % 256 == 0);
+ TORCH_CHECK(input.size(-1) % (GROUP_SIZE * 2) == 0);
using Idx_t = int64_t;
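
For reference, the preconditions enforced by the checks above, restated as a small Python sketch (illustrative only; the tensor names mirror the C++ arguments and the helper is not part of the change):

    import torch

    GROUP_SIZE = 128  # fixed group size assumed by the kernel

    def check_inputs(input: torch.Tensor, y_q: torch.Tensor, y_s: torch.Tensor) -> None:
        # input holds the gate/up activations with shape (E, T, 2*H).
        assert input.dtype == torch.bfloat16
        assert y_q.dtype in (torch.float8_e4m3fn, torch.float8_e4m3fnuz)
        assert y_s.dtype == torch.float32
        # 2*H must be a multiple of 2*GROUP_SIZE, i.e. H % 128 == 0.
        assert input.size(-1) % (GROUP_SIZE * 2) == 0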
@@ -601,8 +603,6 @@ void persistent_masked_m_silu_mul_quant(
Idx_t stride_counts_e = tokens_per_expert.stride(0);
- static constexpr int GROUP_SIZE = 128;
const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
#define KERNEL(BLOCK_COUNT, USE_UE8M0, THREAD_COUNT, STAGES) \
@@ -628,21 +628,26 @@ void persistent_masked_m_silu_mul_quant(
static constexpr int SILU_V2_BLOCK_COUNT = 132 * 32;
+ int const NUM_GROUPS = H / GROUP_SIZE;
if (!use_ue8m0) {
- if (H >= 4096) {
+ if (H >= 4096 && (NUM_GROUPS % 8 == 0)) {
/* 8 warps config */
static constexpr int NUM_STAGES = 4;
static constexpr int THREAD_COUNT = 256;
KERNEL(SILU_V2_BLOCK_COUNT, false, THREAD_COUNT, NUM_STAGES);
} else {
/* 1 warp config */
static constexpr int THREAD_COUNT = 32;
KERNEL(SILU_V2_BLOCK_COUNT, false, THREAD_COUNT, 2);
}
} else {
- if (H >= 4096) {
+ if (H >= 4096 && (NUM_GROUPS % 8 == 0)) {
/* 8 warps config */
static constexpr int NUM_STAGES = 4;
static constexpr int THREAD_COUNT = 256;
KERNEL(SILU_V2_BLOCK_COUNT, true, THREAD_COUNT, NUM_STAGES);
} else {
/* 1 warp config */
static constexpr int THREAD_COUNT = 32;
KERNEL(SILU_V2_BLOCK_COUNT, true, THREAD_COUNT, 2);
}
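
The key kernel change is in this launch-configuration heuristic: the 256-thread (8-warp) variant is now chosen only when H is large and the number of 128-wide groups divides evenly by 8; every other shape takes the 32-thread (1-warp) variant. Both the UE8M0 and non-UE8M0 branches apply the same shape test; only the USE_UE8M0 template argument differs. A minimal Python sketch of the selection logic (illustrative; the real dispatch goes through the KERNEL macro above):

    GROUP_SIZE = 128

    def select_config(H: int) -> tuple[int, int]:
        """Return (thread_count, num_stages) the way the updated heuristic does."""
        num_groups = H // GROUP_SIZE
        if H >= 4096 and num_groups % 8 == 0:
            return 256, 4  # 8 warps, 4 stages
        return 32, 2       # 1 warp, 2 stages

    # H = 7168 has 56 groups (divisible by 8), so it keeps the wide
    # configuration; large H with an uneven group count now falls back
    # to the single-warp path instead.
    assert select_config(7168) == (256, 4)
    assert select_config(128 * 3) == (32, 2)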


@@ -25,6 +25,7 @@ CASES = [
(8, 16, 128 * 2, fp8_dtype),
(8, 16, 128 * 3, fp8_dtype),
(8, 64, 7168, fp8_dtype),
+ (8, 128, 128 * 33, fp8_dtype),
(8, 128, 7168, fp8_dtype),
(8, 512, 7168, fp8_dtype),
(8, 1024, 7168, fp8_dtype),
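
The added (8, 128, 128 * 33, fp8_dtype) case covers a shape that crosses the H >= 4096 threshold while its group count is not a multiple of 8, which is exactly the situation the kernel change above now guards against. A quick check of the arithmetic (illustrative):

    GROUP_SIZE = 128
    H = 128 * 33                  # 4224
    num_groups = H // GROUP_SIZE  # 33
    assert H >= 4096              # the old condition alone would pick the wide config
    assert num_groups % 8 != 0    # the new condition routes it to the 1-warp path

    # The 7168-wide cases have 56 groups (a multiple of 8), so they
    # continue to exercise the 8-warp configuration.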
@@ -54,8 +55,10 @@ def test_silu_mul_fp8_quant_deep_gemm(E, T, H, fp8_type):
)
# Run the SiLU V2 kernel
+ # TODO (varun): use_ue8m0 is set to False as the reference impl does
+ # not handle that case.
y_q, y_s = persistent_masked_m_silu_mul_quant(
- y, tokens_per_expert, group_size=group_size
+ y, tokens_per_expert, group_size=group_size, use_ue8m0=False
)
torch.cuda.synchronize()
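
Together with the wrapper change further down, use_ue8m0 becomes an optional argument: the test pins it to False so the output stays comparable to the non-UE8M0 reference implementation, while omitting it keeps the old behaviour of deferring to is_deep_gemm_e8m0_used(). A sketch of the two calling modes (tensor construction is illustrative; the int32 dtype for tokens_per_expert and the import of the wrapper are assumptions not shown in this diff):

    import torch
    # persistent_masked_m_silu_mul_quant is imported from its vLLM module
    # (module path not shown in this diff).

    E, T, H, group_size = 8, 16, 128 * 2, 128
    y = torch.randn(E, T, 2 * H, dtype=torch.bfloat16, device="cuda")
    tokens_per_expert = torch.randint(0, T + 1, (E,), dtype=torch.int32, device="cuda")

    # 1) Explicit override, as the test does: force plain FP8 group scales.
    y_q, y_s = persistent_masked_m_silu_mul_quant(
        y, tokens_per_expert, group_size=group_size, use_ue8m0=False
    )

    # 2) Default: leaving use_ue8m0 unset defers to is_deep_gemm_e8m0_used(),
    #    matching the behaviour before this change.
    y_q, y_s = persistent_masked_m_silu_mul_quant(
        y, tokens_per_expert, group_size=group_size
    )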


@@ -100,6 +100,7 @@ def persistent_masked_m_silu_mul_quant(
tokens_per_expert: torch.Tensor, # (E,) number of valid tokens per expert
num_parallel_tokens=16,
group_size: int = 128,
+ use_ue8m0: bool | None = None,
) -> tuple[torch.Tensor, torch.Tensor]:
"""Quantize silu(y[..., :H]) * y[..., H:] to FP8 with group per-token scales
y has shape (E, T, 2*H). The first half of the last dimension is
@@ -164,7 +165,7 @@ def persistent_masked_m_silu_mul_quant(
device=y.device,
)
- use_ue8m0 = is_deep_gemm_e8m0_used()
+ use_ue8m0 = use_ue8m0 if use_ue8m0 is not None else is_deep_gemm_e8m0_used()
cuda_arch = current_platform.get_device_capability(
device_id=y.device.index