Compare commits

...

3 Commits

SHA1        Message                                  Date
ff1d7b5a79  Update [ghstack-poisoned]                2025-11-14 11:17:20 -08:00
babc4d67dc  Update [ghstack-poisoned]                2025-11-14 11:08:00 -08:00
8f83f099d4  Update (base update) [ghstack-poisoned]  2025-11-14 11:08:00 -08:00
2 changed files with 27 additions and 0 deletions


@@ -1096,6 +1096,18 @@ _scaled_mxfp8_mxfp8(
  return _scaled_gemm(mat_a, mat_b, scale_a, scale_b, scaling_choice_a, scaling_choice_b, bias, false /* use_fast_accum */, out);
}

void
_check_mxfp4_support() {
#ifndef USE_ROCM
  auto dprops = at::cuda::getCurrentDeviceProperties();
  // Only on B200 GPUs
  TORCH_CHECK_NOT_IMPLEMENTED(
      dprops->major == 10 && dprops->minor == 0,
      "MXFP4 scaling only supported in CUDA for B200"
  );
#endif
}

Tensor&
_scaled_mxfp4_mxfp4(
@@ -1108,6 +1120,7 @@ _scaled_mxfp4_mxfp4(
#if defined(_WIN32) || (!defined(USE_ROCM) && !defined(USE_FBGEMM_GENAI))
  TORCH_CHECK_NOT_IMPLEMENTED(false, "MXFP4 scaling supported on ROCM and CUDA+FBGEMM_GENAI only");
#else
  _check_mxfp4_support();
  // Restrictions:
  // A, B are FP4, scales are e8m0, A: shape K//32, B: K, N//32
  TORCH_CHECK_VALUE(mat_a.scalar_type() == at::kFloat4_e2m1fn_x2 && mat_b.scalar_type() == at::kFloat4_e2m1fn_x2, "mat_a and mat_b must be fp4 types, got: ",
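
Note: the new guard keys off CUDA compute capability 10.0, which is what B200 reports. A minimal Python-side sketch of the same gate, using only public torch.cuda calls (the helper name here is illustrative, not part of this patch):

import torch

def is_sm100() -> bool:
    # Illustrative only: mirrors the C++ check above,
    # dprops->major == 10 && dprops->minor == 0 (compute capability 10.0).
    if not torch.cuda.is_available():
        return False
    major, minor = torch.cuda.get_device_capability(0)
    return major == 10 and minor == 0

The test changes below rely on the equivalent IS_SM100 flag from torch.testing._internal.common_cuda rather than an ad-hoc check like this.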


@@ -14,6 +14,7 @@ import torch
from torch.nn.functional import pad, scaled_mm, scaled_grouped_mm, ScalingType, SwizzleType
from torch.testing._internal.common_cuda import (
    IS_SM90,
    IS_SM100,
    _get_torch_cuda_version,
    PLATFORM_SUPPORTS_FP8,
    PLATFORM_SUPPORTS_FP8_GROUPED_GEMM,
@@ -53,6 +54,7 @@ from torch.testing._internal.common_quantized import (
_IS_SM8X = False
if TEST_CUDA:
    _IS_SM8X = torch.cuda.get_device_capability(0)[0] == 8
@@ -736,6 +738,10 @@ class TestFP8Matmul(TestCase):
    @parametrize("format", ["mxfp8"] + (["nvfp4", "mxfp4"] if torch.version.cuda else []))
    def test_mxfp8_nvfp4_scaled_grouped_mm_2d_2d(self, G, M, N, K, format):
        torch.manual_seed(42)
        if format == "mxfp4" and not IS_SM100:
            raise unittest.SkipTest("MXFP4 on CUDA only supported on B200")
        total_K = K  # Alias for clarity, communicating this consists of several groups along this dim
        input_group_end_offsets = generate_jagged_offs(
            G, total_K, multiple_of=32, device="cuda"
@@ -799,6 +805,10 @@ class TestFP8Matmul(TestCase):
    @parametrize("format", ["mxfp8"] + (["nvfp4", "mxfp4"] if torch.version.cuda else []))
    def test_mxfp8_scaled_grouped_mm_2d_3d(self, G, M, N, K, format):
        torch.manual_seed(42)
        if format == "mxfp4" and not IS_SM100:
            raise unittest.SkipTest("MXFP4 on CUDA only supported on B200")
        # Simulate 2d-3d grouped gemm `out = input @ weight.t()`
        # 2D inputs with groups along M, 3D weights.
        block_size = 32
@@ -1870,6 +1880,8 @@ class TestFP8Matmul(TestCase):
            raise unittest.SkipTest("nvfp4 not supported on ROCm, skipping")
        if (recipe == "nvfp4" or recipe == "mxfp4") and fast_accum:
            raise unittest.SkipTest("fast_accum not supported in nvfp4/mxfp4 cublas gemm, skipping")
        if recipe == "mxfp4" and not IS_SM100:
            raise unittest.SkipTest("MXFP4 on CUDA only supported on B200")
        device = "cuda"
        M, K, N = mkn
@@ -2090,6 +2102,8 @@ class TestFP8Matmul(TestCase):
    @unittest.skipIf(not PLATFORM_SUPPORTS_MX_GEMM or IS_WINDOWS, mx_skip_msg)
    @parametrize("recipe", ["mxfp8", "mxfp4" if torch.version.hip else "nvfp4"])
    def test_blockwise_mxfp8_nvfp4_error_messages(self, device, recipe) -> None:
        if recipe == "mxfp4" and not IS_SM100:
            raise unittest.SkipTest("MXFP4 on CUDA only supported on B200")
        M, K, N = (1024, 512, 2048)
        BLOCK_SIZE_K = 16 if recipe == "nvfp4" else 32
        BLOCK_SIZE_MN = 128
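
The same skip guard is repeated in all four tests touched here; a small hypothetical helper (not part of this patch, name invented for illustration) could keep the condition and message in one place:

import unittest

from torch.testing._internal.common_cuda import IS_SM100  # same flag the test file imports above

def skip_unless_mxfp4_supported(fmt: str) -> None:
    # Hypothetical helper: centralizes the repeated
    # "MXFP4 on CUDA only supported on B200" skip used in the tests above.
    if fmt == "mxfp4" and not IS_SM100:
        raise unittest.SkipTest("MXFP4 on CUDA only supported on B200")

Each test would then call skip_unless_mxfp4_supported(format) (or recipe) before building its inputs.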