MXFP8 grouped GEMM support for torch._scaled_grouped_mm + submodule bump (#162209)

## Summary
- We just landed 2d-2d support for MXFP8 grouped GEMM in FBGEMM: https://github.com/pytorch/FBGEMM/pull/4816
- This is needed for the backward pass of MXFP8 MoE training with grouped GEMMs.
- Changes:
    - Add dispatching + input validation for MXFP8 grouped GEMM in `torch._scaled_grouped_mm` (see the usage sketch after this list)
    - Add meta registration input validation for MXFP8 grouped GEMM, for composability with `torch.compile`
    - Add unit tests exercising `torch._scaled_grouped_mm` with MXFP8 inputs
    - Bump FBGEMM third party submodule to include:
          - https://github.com/pytorch/FBGEMM/pull/4816
          - https://github.com/pytorch/FBGEMM/pull/4820
          - https://github.com/pytorch/FBGEMM/pull/4821
          - https://github.com/pytorch/FBGEMM/pull/4823
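
For reference, a minimal sketch of how the new MXFP8 path is invoked for the 2d-3d case (2D activations, 3D expert weights), condensed from the `test_mxfp8_scaled_grouped_mm_2d_3d` unit test added in this PR. It assumes an SM100 CUDA device and a build with `USE_FBGEMM_GENAI=1`; `to_mxfp8`, `to_blocked`, and `generate_jagged_offs` are the test helpers from `torch.testing._internal.common_quantized`, and the sizes are purely illustrative.
```
import torch
from torch.testing._internal.common_quantized import (
    generate_jagged_offs,
    to_blocked,
    to_mxfp8,
)

G, M, N, K = 4, 1024, 512, 256  # illustrative sizes; groups are along M
X = torch.randn(M, K, dtype=torch.bfloat16, device="cuda")
W = torch.randn(G, N, K, dtype=torch.bfloat16, device="cuda")
offs = generate_jagged_offs(G, M, multiple_of=32, device="cuda")  # int32 group end offsets

# Quantize each expert weight to mxfp8 and convert its scale to blocked format.
w_pairs = [to_mxfp8(W[g]) for g in range(G)]                # (scale, data) per expert
wq = torch.stack([q for _, q in w_pairs]).contiguous()      # (G, N, K) float8_e4m3fn
w_scale = torch.stack([to_blocked(s) for s, _ in w_pairs])  # (G, blocked_K * blocked_N) e8m0

# Quantize each input group along M separately, then reassemble.
ends = offs.tolist()
starts = [0] + ends[:-1]
x_pairs = [to_mxfp8(X[s:e]) for s, e in zip(starts, ends) if e > s]
xq = torch.cat([q for _, q in x_pairs]).contiguous()        # (M, K) float8_e4m3fn
x_scale = torch.cat([to_blocked(s) for s, _ in x_pairs]).reshape(-1, K // 32)  # e8m0 blocked scales

y = torch._scaled_grouped_mm(
    xq,
    wq.transpose(-2, -1),   # column-major in the last two dims, as required
    x_scale,
    w_scale,
    offs=offs,
    out_dtype=torch.bfloat16,
)
```
The 2d-2d variant (used for the backward pass) follows the same pattern with variable-size groups along the K dimension instead; see `test_mxfp8_scaled_grouped_mm_2d_2d` in the diff below.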

#### How the FBGEMM dependency was bumped
Documenting this here since I haven't found it documented elsewhere:
- `cd ~/pytorch/third_party/fbgemm`
- `git fetch`
- `git checkout <hash>`
- `cd ~/pytorch`
- `git add third_party/fbgemm`
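
To double-check which FBGEMM commits a bump actually pulls in before staging it, the submodule's history can be compared against the previous pointer (the hashes below are placeholders for the old and new submodule commits):
- `cd ~/pytorch/third_party/fbgemm`
- `git log --oneline <old hash>..<new hash>`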

## Test plan

#### Test build
```
USE_FBGEMM_GENAI=1 python -m pip install --no-build-isolation -v -e .
...
Successfully installed torch-2.9.0a0+gitf5070f3
```
[full build log](https://www.internalfb.com/phabricator/paste/view/P1933787581)
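
As a quick sanity check that the FBGEMM GenAI kernels were compiled in, the build config string can be inspected; this mirrors the gate used by `evaluate_platform_supports_mxfp8_grouped_gemm` in this PR:
```
python -c "import torch; print('USE_FBGEMM_GENAI' in torch.__config__.show())"
```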

#### Unit tests
```
pytest test/test_matmul_cuda.py -k test_mxfp8_scaled_grouped_mm_
...

test/test_matmul_cuda.py .........                                                                                                                        [100%]

============================================================== 9 passed, 1668 deselected in 5.34s ===============================================================
```

Pull Request resolved: https://github.com/pytorch/pytorch/pull/162209
Approved by: https://github.com/ngimel
Author: Daniel Vega-Myhre
Date: 2025-09-06 15:25:30 +00:00
Committed by: PyTorch MergeBot
Parent: 5985e28912
Commit: b6d0a9ea90
9 changed files with 531 additions and 70 deletions

View File

@@ -889,6 +889,12 @@ IF(USE_FBGEMM_GENAI AND USE_ROCM AND NOT "gfx942" IN_LIST PYTORCH_ROCM_ARCH)
set(USE_FBGEMM_GENAI off)
endif()
# Set USE_FBGEMM_GENAI to ON for CUDA build on SM100
if(USE_CUDA AND "$ENV{TORCH_CUDA_ARCH_LIST}" MATCHES "10.0a")
message(WARNING "Setting USE_FBGEMM_GENAI to ON for CUDA build on SM100")
set(USE_FBGEMM_GENAI ON)
endif()
# CAVEAT: Again, Flash Attention2 will error while building for sm52 while Mem
# Eff Attention won't
cmake_dependent_option(

View File

@@ -255,48 +255,77 @@ endif()
# FBGEMM GenAI
IF(USE_FBGEMM_GENAI)
set(FBGEMM_THIRD_PARTY ${PROJECT_SOURCE_DIR}/third_party/fbgemm/external/)
set(FBGEMM_GENAI_DIR ${PROJECT_SOURCE_DIR}/third_party/fbgemm/fbgemm_gpu/experimental/gen_ai/src/quantize)
set(FBGEMM_GENAI_SRCS ${PROJECT_SOURCE_DIR}/third_party/fbgemm/fbgemm_gpu/experimental/gen_ai/src/quantize)
if(USE_CUDA)
# To avoid increasing the build time/binary size unnecessarily, use an allow-list of kernels to build.
# If you want to integrate a kernel from FBGEMM into torch, you have to add it here.
set(FBGEMM_CUTLASS_KERNELS_REGEX ".*mx8mx8bf16_grouped.*")
file(GLOB_RECURSE fbgemm_genai_native_cuda_cu
"${FBGEMM_GENAI_SRCS}/cutlass_extensions/*.cu"
"${FBGEMM_GENAI_SRCS}/cutlass_extensions/**/*.cu")
list(FILTER fbgemm_genai_native_cuda_cu INCLUDE REGEX ${FBGEMM_CUTLASS_KERNELS_REGEX})
if(USE_ROCM)
# Only include the kernels we want to build to avoid increasing binary size.
file(GLOB_RECURSE fbgemm_genai_native_rocm_hip
"${FBGEMM_GENAI_DIR}/ck_extensions/fp8_rowwise_grouped/kernels/fp8_rowwise_grouped*.hip"
"${FBGEMM_GENAI_DIR}/ck_extensions/fp8_rowwise_grouped/fp8_rowwise_grouped_gemm.hip")
set_source_files_properties(${fbgemm_genai_native_rocm_hip} PROPERTIES HIP_SOURCE_PROPERTY_FORMAT 1)
file(GLOB_RECURSE fbgemm_genai_native_cuda_cpp
"${FBGEMM_GENAI_SRCS}/common/*.cpp"
)
# Add additional HIPCC compiler flags for performance
set(FBGEMM_GENAI_EXTRA_HIPCC_FLAGS
-mllvm
-amdgpu-coerce-illegal-types=1
-mllvm
-enable-post-misched=0
-mllvm
-greedy-reverse-local-assignment=1
-fhip-new-launch-api)
# Combine all source files into a single list
list(APPEND fbgemm_genai_all_sources
${fbgemm_genai_native_cuda_cu}
${fbgemm_genai_native_cuda_cpp}
)
# Only compile for gfx942 for now.
# This is rather hacky, I could not figure out a clean solution :(
set(HIP_CLANG_FLAGS_ORIGINAL ${HIP_CLANG_FLAGS})
string(REGEX REPLACE "--offload-arch=[^ ]*" "" FILTERED_HIP_CLANG_FLAGS "${HIP_CLANG_FLAGS}")
list(APPEND FILTERED_HIP_CLANG_FLAGS --offload-arch=gfx942;)
set(HIP_CLANG_FLAGS ${FILTERED_HIP_CLANG_FLAGS})
hip_add_library(
fbgemm_genai STATIC
${fbgemm_genai_native_rocm_hip}
HIPCC_OPTIONS ${FBGEMM_GENAI_EXTRA_HIPCC_FLAGS})
set(HIP_CLANG_FLAGS ${HIP_CLANG_FLAGS_ORIGINAL})
# Now, create the library and provide the sources at the same time
add_library(fbgemm_genai OBJECT ${fbgemm_genai_all_sources})
set_target_properties(fbgemm_genai PROPERTIES POSITION_INDEPENDENT_CODE ON)
target_compile_definitions(fbgemm_genai PRIVATE FBGEMM_GENAI_NO_EXTENDED_SHAPES)
set(fbgemm_genai_mx8mx8bf16_grouped
"${FBGEMM_GENAI_SRCS}/cutlass_extensions/mx8mx8bf16_grouped/"
)
target_include_directories(fbgemm_genai PUBLIC
# FBGEMM version of Composable Kernel is used due to some customizations
${FBGEMM_THIRD_PARTY}/composable_kernel/include
${FBGEMM_THIRD_PARTY}/composable_kernel/library/include
${FBGEMM_GENAI_DIR}/include/
${FBGEMM_GENAI_DIR}/common/include/
${FBGEMM_THIRD_PARTY}/cutlass/include
${FBGEMM_THIRD_PARTY}/cutlass/tools/util/include
${fbgemm_genai_mx8mx8bf16_grouped}
${FBGEMM_GENAI_SRCS}/common/include/ # includes fbgemm_gpu/quantize/utils.h, fbgemm_gpu/quantize/tuning_cache.hpp
${FBGEMM_GENAI_SRCS}/include/ # includes fbgemm_gpu/torch_ops.h
)
else()
if(USE_ROCM)
# Only include the kernels we want to build to avoid increasing binary size.
file(GLOB_RECURSE fbgemm_genai_native_rocm_hip
"${FBGEMM_GENAI_SRCS}/ck_extensions/fp8_rowwise_grouped/kernels/fp8_rowwise_grouped*.hip"
"${FBGEMM_GENAI_SRCS}/ck_extensions/fp8_rowwise_grouped/fp8_rowwise_grouped_gemm.hip")
set_source_files_properties(${fbgemm_genai_native_rocm_hip} PROPERTIES HIP_SOURCE_PROPERTY_FORMAT 1)
# Add additional HIPCC compiler flags for performance
set(FBGEMM_GENAI_EXTRA_HIPCC_FLAGS
-mllvm
-amdgpu-coerce-illegal-types=1
-mllvm
-enable-post-misched=0
-mllvm
-greedy-reverse-local-assignment=1
-fhip-new-launch-api)
hip_add_library(
fbgemm_genai STATIC
${fbgemm_genai_native_rocm_hip}
HIPCC_OPTIONS ${HIP_HCC_FLAGS} ${FBGEMM_GENAI_EXTRA_HIPCC_FLAGS})
set_target_properties(fbgemm_genai PROPERTIES POSITION_INDEPENDENT_CODE ON)
target_compile_definitions(fbgemm_genai PRIVATE FBGEMM_GENAI_NO_EXTENDED_SHAPES)
target_include_directories(fbgemm_genai PUBLIC
# FBGEMM version of Composable Kernel is used due to some customizations
${FBGEMM_THIRD_PARTY}/composable_kernel/include
${FBGEMM_THIRD_PARTY}/composable_kernel/library/include
${FBGEMM_THIRD_PARTY}/cutlass/include
${FBGEMM_THIRD_PARTY}/cutlass/tools/util/include
${FBGEMM_GENAI_SRCS}/common/include/ # includes fbgemm_gpu/quantize/utils.h, fbgemm_gpu/quantize/tuning_cache.hpp
${FBGEMM_GENAI_SRCS}/include/ # includes fbgemm_gpu/torch_ops.h
)
endif()
endif()
endif()
@@ -639,6 +668,13 @@ if(USE_CUDA AND NOT USE_ROCM)
add_definitions(-DCUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED)
list(APPEND ATen_CUDA_INCLUDE ${CMAKE_CURRENT_SOURCE_DIR}/../../../third_party/cutlass/include)
list(APPEND ATen_CUDA_INCLUDE ${CMAKE_CURRENT_SOURCE_DIR}/../../../third_party/cutlass/tools/util/include)
# Add FBGEMM_GENAI include directories for torch_ops.h
if(USE_FBGEMM_GENAI)
list(APPEND ATen_CUDA_INCLUDE ${CMAKE_CURRENT_SOURCE_DIR}/../../../third_party/fbgemm/fbgemm_gpu/experimental/gen_ai/src/quantize/include)
list(APPEND ATen_CUDA_INCLUDE ${CMAKE_CURRENT_SOURCE_DIR}/../../../third_party/fbgemm/fbgemm_gpu/experimental/gen_ai/src/quantize/common/include)
endif()
if($ENV{ATEN_STATIC_CUDA})
if(CUDA_VERSION VERSION_LESS_EQUAL 12.9)
list(APPEND ATen_CUDA_DEPENDENCY_LIBS

View File

@@ -1551,7 +1551,8 @@ _scaled_mm_out_cuda(const Tensor& mat1, const Tensor& mat2,
}
namespace {
void check_scale(const Tensor& mat, const Tensor& scale, const int dim, const int arg_idx, const int scale_multiplier=1) {
void _check_scales_fp8_rowwise(const Tensor& mat, const Tensor& scale, const int dim, const int arg_idx, const int scale_multiplier=1) {
// Checks scales for 2d or 3d target tensors (`mat`).
if (mat.dim() == 2) {
TORCH_CHECK(
scale.dim() == 1,
@@ -1585,9 +1586,66 @@ namespace {
"scale must have the same first dimension as mat for arg ",
arg_idx);
}
}
}
void _check_scales_mxfp8(const Tensor& mat, const Tensor& scale, const int dim, const int arg_idx) {
// Checks scales for 2d or 3d target tensors (`mat`).
if (mat.dim() == 2) {
// For MXFP8, 2d tensors have variable size groups represented as subtensors,
// that are converted to blocked padded format individually,
// so we can't check the scale sizes without doing a d2h sync to get the group sizes here.
TORCH_CHECK(
scale.dim() == mat.dim(),
"for mxfp8, scale must have same number of dimensions as parent tensor, but got mat.dim() = ", mat.dim(), " and scale.dim() = ", scale.dim(), " for arg ", arg_idx);
// LHS mat shape (M, total_K) -> scale shape (rounded_up(M, 128), rounded_up_per_group(K/32, 4))
// RHS mat shape (total_K, N) -> scale shape (rounded_up(N, 128), rounded_up_per_group(K/32, 4))
// * weight is transposed prior to the call, scale stays non-transposed.
bool LHS = arg_idx == 0;
int scale_dim_to_check = 0;
int mat_dim_to_check = LHS ? 0 : 1;
TORCH_CHECK(
scale.size(scale_dim_to_check) >= mat.size(mat_dim_to_check),
"for mxfp8, arg ", arg_idx, " tensor shape (", mat.size(0), ", ", mat.size(1), ") ",
"must have scale.shape[", scale_dim_to_check, "] >= ", mat.size(mat_dim_to_check), " but got scale.shape=(", scale.size(0), ", ", scale.size(1), ")");
} else {
// For MXFP8, 3d tensors have static group sizes (stack of 2d tensors),
// so we can check the exact expected scale sizes here without a d2h sync.
auto round_up = [](auto x, auto y) {
return ((x + y - 1) / y) * y;
};
// TODO: this is for 3d tensor in 2d-3d case specifically.
// We'll need to support 3d-3d and 3d-2d cases once mxfp8 grouped gemm supports them.
int64_t G = mat.size(0);
int64_t K = mat.size(1);
int64_t N = mat.size(2);
int64_t blocked_scale_K = round_up(K/32, 4);
int64_t blocked_scale_N = round_up(N, 128);
// fbgemm expects stack of flattened blocked scales for 3d tensor, shape (G, blocked_scale_K * blocked_scale_N).
TORCH_CHECK(
scale.dim() == mat.dim() - 1,
"for mxfp8 2d-3d grouped GEMM, the 3d tensor of shape (G,K,N) must have a 2d scale of shape (G, blocked_scale_K * blocked_scale_N), but scale is ", scale.dim(), "D for arg ", arg_idx
);
TORCH_CHECK(
scale.size(0) == G && scale.size(1) == blocked_scale_K * blocked_scale_N,
"for mxfp8, the tensor shape (", G, ", ", K, ", ", N, ") must have scale shape (", G, ",", blocked_scale_K, ",", blocked_scale_N, ") for arg ", arg_idx
);
}
}
void check_scale(const Tensor& mat, const Tensor& scale, const int dim, const int arg_idx, const int scale_multiplier=1) {
bool using_fp8_rowwise = scale.scalar_type() == kFloat;
bool using_mxfp8 = scale.scalar_type() == at::kFloat8_e8m0fnu;
if (using_fp8_rowwise) {
_check_scales_fp8_rowwise(mat, scale, dim, arg_idx, scale_multiplier);
} else if (using_mxfp8) {
_check_scales_mxfp8(mat, scale, dim, arg_idx);
} else {
TORCH_CHECK(false, "scale must be float32 or float8_e8m0fnu, but got ", scale.dtype());
}
}
}
Tensor
@@ -1612,8 +1670,8 @@ const std::optional<at::Tensor>& bias,
const std::optional<at::Tensor>& scale_result,
std::optional<c10::ScalarType> out_dtype,
bool use_fast_accum) {
bool allowed_device = _scaled_mm_allowed_device(/*sm90_only*/true, /*sm100_only*/false);
TORCH_CHECK(allowed_device, "torch._scaled_grouped_mm is only supported on CUDA devices with compute capability = 9.0, or ROCm MI300+");
bool allowed_device = _scaled_mm_allowed_device(/*sm90_only*/true, /*sm100_only*/true);
TORCH_CHECK(allowed_device, "torch._scaled_grouped_mm is only supported on CUDA devices with compute capability = [9.0, 10.0], or ROCm MI300+");
TORCH_CHECK(!check_valid_strides_and_return_transposed(mat_a), "Expected mat1 to not be transposed");
TORCH_CHECK(check_valid_strides_and_return_transposed(mat_b), "Expected mat2 to be transposed");
@@ -1646,10 +1704,12 @@ bool use_fast_accum) {
TORCH_CHECK(offs->dtype() == at::kInt, "Offsets have to be int32");
}
// Both Per-Tensor and Row-wise scaling expect fp32 tensors
// FP8 per-tensor and per-row scaling expect fp32 scales.
// MXFP8 expects float8_e8m0fnu scales.
TORCH_CHECK(
scale_a.scalar_type() == kFloat && scale_b.scalar_type() == kFloat,
"Both scale_a and scale_b must be float (fp32) tensors.");
(scale_a.scalar_type() == kFloat && scale_b.scalar_type() == kFloat) ||
(scale_a.scalar_type() == at::kFloat8_e8m0fnu && scale_b.scalar_type() == at::kFloat8_e8m0fnu),
"For FP8 tensorwise and rowwise, both scales must both be float32 tensors. For MXFP8, scales must both be float8_e8m0fnu tensors.");
const int scale_multiplier = (mat_a.dim() == 2 && mat_b.dim() == 2) ? offs->size(0) : 1;
check_scale(mat_a, scale_a, 0 ,0, scale_multiplier);
@@ -1660,6 +1720,32 @@ bool use_fast_accum) {
Tensor out = create_grouped_gemm_output_tensor(mat_a, mat_b, offs, out_dtype_);
#if defined(USE_FBGEMM_GENAI) && defined(USE_CUDA) && !defined(USE_ROCM)
// MXFP8 grouped GEMM dispatching
bool is_mx8mx8bf16 = (
mat_a.scalar_type() == at::kFloat8_e4m3fn && mat_b.scalar_type() == at::kFloat8_e4m3fn &&
scale_a.scalar_type() == at::kFloat8_e8m0fnu && scale_b.scalar_type() == at::kFloat8_e8m0fnu
);
TORCH_CHECK(out_dtype == at::kBFloat16, "Only bf16 out_dtype is supported for MXFP8 grouped gemm");
if (is_mx8mx8bf16) {
bool b_is_3d = mat_b.dim() == 3;
bool is_2d_2d = a_is_2d && b_is_2d;
bool is_2d_3d = a_is_2d && b_is_3d;
TORCH_CHECK(is_2d_2d || is_2d_3d, "MXFP8 grouped GEMM currently only supports 2d-2d and 2d-3d cases");
TORCH_CHECK(offs.has_value(), "MXFP8 2d-2d and 2d-3d grouped GEMMs require offsets");
fbgemm_gpu::mx8mx8bf16_grouped_mm(
mat_a,
mat_b,
scale_a,
scale_b,
offs.value(),
out);
return out;
}
#endif
#ifndef USE_ROCM
TORCH_CHECK(mat_a.dtype() == at::kFloat8_e4m3fn, "Expected mat_a to be Float8_e4m3 matrix got ", mat_a.scalar_type());
TORCH_CHECK(mat_b.dtype() == at::kFloat8_e4m3fn, "Expected mat_a to be Float8_e4m3 matrix got ", mat_b.scalar_type());
@@ -1691,6 +1777,7 @@ bool use_fast_accum) {
#else
TORCH_CHECK(false, "grouped gemm is not supported without USE_FBGEMM_GENAI on ROCM")
#endif
#endif
}

View File

@@ -1638,6 +1638,10 @@ if(USE_CUDA)
# order of the libraries in the linker call matters here when statically
# linking; libculibos and cublas must be last.
target_link_libraries(torch_cuda PUBLIC torch_cpu_library ${Caffe2_PUBLIC_CUDA_DEPENDENCY_LIBS})
if(USE_FBGEMM_GENAI)
# Link fbgemm_genai to torch_cuda (only for (1) CUDA build for SM100).
target_link_libraries(torch_cuda PRIVATE fbgemm_genai)
endif()
endif()
# ---[ XPU library.
@@ -1759,9 +1763,10 @@ if(USE_ROCM)
target_link_libraries(torch_hip PRIVATE ${Caffe2_HIP_DEPENDENCY_LIBS})
if(USE_FBGEMM_GENAI)
target_link_libraries(torch_hip PRIVATE fbgemm_genai)
if(USE_ROCM)
target_link_libraries(torch_hip PRIVATE fbgemm_genai)
endif()
endif()
# Since PyTorch files contain HIP headers, this is also needed to capture the includes.
# ROCM_INCLUDE_DIRS is defined in LoadHIP.cmake
target_include_directories(torch_hip PRIVATE ${Caffe2_HIP_INCLUDE} ${ROCM_INCLUDE_DIRS})

View File

@@ -30,6 +30,7 @@ from torch.testing._internal.common_cuda import (
PLATFORM_SUPPORTS_FP8,
PLATFORM_SUPPORTS_FP8_GROUPED_GEMM,
PLATFORM_SUPPORTS_MX_GEMM,
PLATFORM_SUPPORTS_MXFP8_GROUPED_GEMM,
IS_SM90,
)
from torch.testing._internal.common_device_type import (
@@ -55,7 +56,13 @@ from torch.testing._internal.common_utils import (
TEST_WITH_ROCM,
TestCase,
)
from torch.testing._internal.common_quantized import _f32_to_floatx_unpacked, _floatx_unpacked_to_f32, ceil_div, to_blocked
from torch.testing._internal.common_quantized import (
_f32_to_floatx_unpacked,
_floatx_unpacked_to_f32,
ceil_div, to_blocked,
to_mxfp8,
generate_jagged_offs,
)
_IS_SM8X = False
if TEST_CUDA:
@@ -771,6 +778,7 @@ class TestMatmulCuda(TestCase):
f8_msg = "FP8 is only supported on H100+, SM 8.9 and MI300+ devices"
f8_grouped_msg = "FP8 grouped is only supported on SM90 and MI300+ devices"
mx_skip_msg = "MX gemm is only supported on CUDA capability 10.0+"
mxfp8_grouped_mm_skip_msg = "MXFP8 grouped GEMM is only supported when PyTorch is built with USE_FBGEMM_GENAI=1 on SM100+"
# avoid division by zero when calculating scale
EPS = 1e-12
@@ -901,6 +909,8 @@ def to_fp8_saturated(
return x.to(fp8_dtype)
def compute_error(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
"""Computes the error between two tensors in dB.
@@ -1045,6 +1055,167 @@ class TestFP8Matmul(TestCase):
out_fp8_s = torch._scaled_mm(x, y, scale_a=scale_a, scale_b=scale_b)
self.assertEqual(out_fp8, out_fp8_s)
@unittest.skipIf(not PLATFORM_SUPPORTS_MXFP8_GROUPED_GEMM, mxfp8_grouped_mm_skip_msg)
@parametrize("G", [1, 4, 16])
@parametrize("M", [2048, 2049])
@parametrize("N", [8192])
@parametrize("K", [16640])
def test_mxfp8_scaled_grouped_mm_2d_2d(self, G, M, N, K):
torch.manual_seed(42)
total_K = K # Alias for clarity, communicating this consists of several groups along this dim
input_group_end_offsets = generate_jagged_offs(
G, total_K, multiple_of=32, device="cuda"
)
X = torch.randn((M, total_K), dtype=torch.bfloat16, device="cuda") * 0.1
W = torch.randn((N, total_K), dtype=torch.bfloat16, device="cuda") * 0.01
# Convert scales to blocked format.
x_list = []
w_list = []
x_blocked_scale_list = []
w_blocked_scale_list = []
def round_up(x: int, y: int) -> int:
return ((x + y - 1) // y) * y
for group_idx in range(G):
# to_mxfp8 per group
prev_group_end_offset = (
0 if group_idx == 0 else input_group_end_offsets[group_idx - 1]
)
curr_group_end_offset = input_group_end_offsets[group_idx]
group_size = curr_group_end_offset - prev_group_end_offset
if group_size > 0:
x_slice = X[
:, prev_group_end_offset:curr_group_end_offset
].contiguous() # (M, K_group)
w_slice = W[
:, prev_group_end_offset:curr_group_end_offset
].contiguous() # (N, K_group)
x_scale_slice, xq_slice = to_mxfp8(
x_slice
) # scale shape -> (M, K_group // 32)
w_scale_slice, wq_slice = to_mxfp8(
w_slice
) # scale shape -> (N, K_group // 32)
x_list.append(xq_slice)
w_list.append(wq_slice)
# Convert scales to blocked format.
x_scale_slice_blocked = to_blocked(
x_scale_slice
) # (round_up(M, 128), round_up(K_group//32, 4))
w_scale_slice_blocked = to_blocked(
w_scale_slice
) # (round_up(N, 128), round_up(K_group//32, 4))
x_blocked_scale_list.append(x_scale_slice_blocked)
w_blocked_scale_list.append(w_scale_slice_blocked)
# Assemble the full XQ and WQ
xq = torch.cat(x_list, dim=1).contiguous()
wq = torch.cat(w_list, dim=1).contiguous()
# Combine all XQ groups blocked scales into one tensor.
x_blocked_scales = torch.cat(x_blocked_scale_list, dim=0)
M_rounded = round_up(M, 128)
x_blocked_scales = x_blocked_scales.reshape(M_rounded, -1)
# Combine all WQ groups blocked scales into one tensor.
w_blocked_scales = torch.cat(w_blocked_scale_list, dim=0)
N_rounded = round_up(N, 128)
w_blocked_scales = w_blocked_scales.reshape(N_rounded, -1)
# Compute mxfp8 grouped mm output
y_mxfp8 = torch._scaled_grouped_mm(
xq, # (M, total_K)
wq.transpose(-2, -1), # (total_K, N)
x_blocked_scales, # to_blocked_per_group(M, total_K//32)
w_blocked_scales, # to_blocked_per_group(N, total_K//32)
offs=input_group_end_offsets, # (G,)
out_dtype=torch.bfloat16,
)
# bf16 reference output
y_bf16 = torch._grouped_mm(
X, W.t(), offs=input_group_end_offsets, out_dtype=torch.bfloat16
)
# Assert no NaNs
assert not y_mxfp8.isnan().any(), "mxfp8 output contains NaN"
# Assert outputs are close
torch.testing.assert_close(y_mxfp8, y_bf16, atol=8.0e-2, rtol=8.0e-2)
@unittest.skipIf(not PLATFORM_SUPPORTS_MXFP8_GROUPED_GEMM, mxfp8_grouped_mm_skip_msg)
@parametrize("G", [1, 4, 16])
@parametrize("M", [16640])
@parametrize("N", [8192])
@parametrize("K", [4096])
def test_mxfp8_scaled_grouped_mm_2d_3d(self, G, M, N, K):
torch.manual_seed(42)
# Simulate 2d-3d grouped gemm `out = input @ weight.t()`
# 2D inputs with groups along M, 3D weights.
block_size = 32
total_M = M # Alias for clarity that M dim contains groups.
X = torch.randn((total_M, K), dtype=torch.bfloat16, device="cuda") * 0.1
W = torch.randn((G, N, K), dtype=torch.bfloat16, device="cuda") * 0.01
input_group_end_offsets = generate_jagged_offs(
G, total_M, multiple_of=32, device="cuda"
)
# For each constituent 2d subtensor in the 3d weights, quantize and convert scale to blocked format separately,
# as each one is used for an independent gemm in the grouped gemm.
wq_list = []
w_scale_list = []
for i in range(G):
w_scale, wq = to_mxfp8(W[i])
w_scale = to_blocked(w_scale)
wq_list.append(wq)
w_scale_list.append(w_scale)
wq = torch.stack(wq_list, dim=0).contiguous()
w_scale = torch.stack(w_scale_list, dim=0).contiguous()
# For each group along `total_M` in the 2D tensor, quantize and convert scale to blocked format separately,
# as each one is used for an independent gemm in the grouped gemm.
xq_list = []
x_scale_list = []
for i in range(G):
prev_group_end = 0 if i == 0 else input_group_end_offsets[i - 1]
curr_group_end = input_group_end_offsets[i]
group_size = curr_group_end - prev_group_end
if group_size > 0:
x_slice = X[prev_group_end:curr_group_end, :]
x_scale, xq = to_mxfp8(x_slice)
x_scale = to_blocked(x_scale)
xq_list.append(xq)
x_scale_list.append(x_scale)
xq = torch.cat(xq_list, dim=0).contiguous()
x_scale = torch.cat(x_scale_list, dim=0).contiguous()
x_scale = x_scale.reshape(-1, K // block_size)
xq = xq.view(-1, xq.shape[-1])
# Compute mxfp8 grouped gemm.
y_mxfp8 = torch._scaled_grouped_mm(
xq,
wq.transpose(-2, -1),
x_scale,
w_scale,
offs=input_group_end_offsets,
out_dtype=torch.bfloat16,
)
# Compute reference bf16 grouped gemm.
y_bf16 = torch._grouped_mm(
X,
W.transpose(-2, -1),
offs=input_group_end_offsets,
out_dtype=torch.bfloat16,
)
# Assert outputs are close.
torch.testing.assert_close(y_mxfp8, y_bf16, atol=8.0e-2, rtol=8.0e-2)
@unittest.skipIf(not PLATFORM_SUPPORTS_FP8, f8_msg)
@parametrize("base_dtype", [torch.float16, torch.bfloat16, torch.float32])
def test_scaled_mm_vs_emulated(self, base_dtype):

View File

@@ -7424,17 +7424,17 @@ def _meta_grouped_mm_common(
fp8_dtype = torch.float8_e4m3fnuz if torch.version.hip else torch.float8_e4m3fn
torch._check(
mat_a.dtype == fp8_dtype and mat_b.dtype == fp8_dtype,
lambda: f"Expected inputs of E4M3 FP8 type but got mat_a.dtype={mat_a.dtype} and mat_b.dtype={mat_b.dtype}.",
lambda: f"Expected inputs of E4M3 FP8 type but got mat_a.dtype={mat_a.dtype} and mat_b.dtype={mat_b.dtype}.", # noqa: B950
)
else:
torch._check(
mat_a.dtype == torch.bfloat16 and mat_b.dtype == torch.bfloat16,
lambda: f"Expected inputs of BF16 type but got mat_a.dtype={mat_a.dtype} and mat_b.dtype={mat_b.dtype}.",
lambda: f"Expected inputs of BF16 type but got mat_a.dtype={mat_a.dtype} and mat_b.dtype={mat_b.dtype}.", # noqa: B950
)
torch._check(
mat_a.dim() in [2, 3] and mat_b.dim() in [2, 3],
lambda: f"Multiplicands must be 2D or 3D but got mat_a.dim()={mat_a.dim()} and mat_b.dim()={mat_b.dim()}",
lambda: f"Multiplicands must be 2D or 3D but got mat_a.dim()={mat_a.dim()} and mat_b.dim()={mat_b.dim()}", # noqa: B950
)
mat_a_is_2d = mat_a.dim() == 2
@@ -7458,11 +7458,11 @@ def _meta_grouped_mm_common(
torch._check(
is_row_major(mat_a),
lambda: f"Expected mat_a tensor to be row major in the last two dimensions, got strides {mat_a.stride()[-2:]}",
lambda: f"Expected mat_a tensor to be row major in the last two dimensions, got strides {mat_a.stride()[-2:]}", # noqa: B950
)
torch._check(
is_col_major(mat_b),
lambda: f"Expected mat_b tensor to be column major in the last two dimensions, got strides {mat_b.stride()[-2:]}",
lambda: f"Expected mat_b tensor to be column major in the last two dimensions, got strides {mat_b.stride()[-2:]}", # noqa: B950
)
def check_valid_strides(mat_name, mat):
@@ -7474,7 +7474,7 @@ def _meta_grouped_mm_common(
):
torch._check(
mat_stride[end_dim] % alignment == 0,
lambda: f"Expected {mat_name} stride along {end_dim} dim to be multiple of 16 bytes, got {mat_stride[end_dim]}.",
lambda: f"Expected {mat_name} stride along {end_dim} dim to be multiple of 16 bytes, got {mat_stride[end_dim]}.", # noqa: B950
)
elif mat_stride[end_dim] == 1 and mat_stride[end_dim - 1] >= max(
1, mat.shape[end_dim]
@@ -7494,41 +7494,81 @@ def _meta_grouped_mm_common(
if scale_a is not None and scale_b is not None:
torch._check(
scale_a.dtype == torch.float32 and scale_b.dtype == torch.float32,
lambda: "Both scale_a and scale_b must be float (fp32) tensors, but got scale_a.dtype={scale_a.dtype} and scale_b.dtype={scale_b.dtype}.", # noqa: B950
(scale_a.dtype == torch.float32 and scale_b.dtype == torch.float32)
or (
scale_a.dtype == torch.float8_e8m0fnu
and scale_b.dtype == torch.float8_e8m0fnu
),
lambda: f"For FP8 scales must both be float32, or for MXFP8 both scales must be float8_e8m0fnu. Got scale_a.dtype={scale_a.dtype} and scale_b.dtype={scale_b.dtype}.", # noqa: B950
)
is_mxfp8 = (
scale_a.dtype == torch.float8_e8m0fnu
and scale_b.dtype == torch.float8_e8m0fnu
)
def round_up(x, y):
"""Rounds up x to nearest multiple of y"""
return ((x + y - 1) // y) * y
def check_scale(scale_name, scale, mat, scaled_dim, scale_multiplier=1):
if mat.dim() == 2:
torch._check(
scale.dim() == 1,
lambda: f"Expected {scale_name} to be 1D tensor, but got {scale.dim()}D tensor.",
)
torch._check(
scale.is_contiguous(),
lambda: f"Expected {scale_name} to be contiguous.",
)
torch._check(
scale.shape[0] == mat.shape[scaled_dim] * scale_multiplier,
lambda: f"Expected {scale_name} to have {mat.shape[scaled_dim] * scale_multiplier} elements, got {scale.shape[0]} elements.", # noqa: B950
)
# For MXFP8, 2d tensors have variable size groups represented as subtensors,
# that are converted to blocked padded format individually. At compile time we don't know
# the group sizes yet, so we don't know the expected size of the blocked format scale.
# This limits what we can check here.
if is_mxfp8:
torch._check(
scale.dim() == mat.dim(),
lambda: f"For MXFP8, scale must have same number of dimensions as target tensor, but {scale_name} has mat.ndim={mat.ndim} and scale.ndim={scale.ndim}", # noqa: B950
)
else:
torch._check(
scale.dim() == 1,
lambda: f"Expected {scale_name} to be 1D tensor, but got {scale.dim()}D tensor.",
)
torch._check(
scale.shape[0] == mat.shape[scaled_dim] * scale_multiplier,
lambda: f"Expected {scale_name} to have {mat.shape[scaled_dim] * scale_multiplier} elements, got {scale.shape[0]} elements.", # noqa: B950
)
else:
torch._check(
scale.dim() == 2,
lambda: f"Expected {scale_name} to be 2D tensor, but got {scale.dim()}D tensor.",
)
torch._check(
scale.stride(1) == 1,
scale.stride(-1) == 1,
lambda: f"Expected {scale_name} to be contiguous in the last dimension.",
)
torch._check(
scale.shape[0] == mat.shape[0],
lambda: f"Expected {scale_name} batch dimension to be {mat.shape[0]}, got {scale.shape[0]}.",
)
torch._check(
scale.shape[1] == mat.shape[1 + scaled_dim],
lambda: f"Expected {scale_name} non-batch dimension to be {mat.shape[1 + scaled_dim]}, got {scale.shape[1]}.",
)
# For MXFP8, 3d tensors have static 'groups' (stack of 2d tensors) so we can know the expected blocked
# scale sizes at compile time.
if is_mxfp8:
torch._check(
mat.ndim == scale.ndim,
lambda: f"For MXFP8, scale should have same number of dimensions as target tensor, but {scale_name} has mat.ndim={mat.ndim} and scale.ndim={scale.ndim}", # noqa: B950
)
# TODO: This logic only holds for RHS tensor in 2d-3d case.
# We'll need to update it to handle LHS 3d tensor in 3d-2d and 3d-3d cases.
G, K, N = scale.shape
block_size = 32
blocked_K = round_up(K / block_size, 4)
blocked_N = round_up(N, 128)
torch._check(
mat.shape[-2] == blocked_K and mat.shape[-1] == blocked_N,
lambda: f"For MXFP8, expected mat.shape={mat.shape} to have scale shape of ({G},{blocked_K},{blocked_N}), but got {scale.shape}", # noqa: B950
)
else:
torch._check(
scale.dim() == 2,
lambda: f"Expected {scale_name} to be 2D tensor, but got {scale.dim()}D tensor.",
)
torch._check(
scale.shape[1] == mat.shape[1 + scaled_dim],
lambda: f"Expected {scale_name} non-batch dimension to be {mat.shape[1 + scaled_dim]}, got {scale.shape[1]}.", # noqa: B950
)
scale_multiplier = (
offs.shape[0] if offs is not None and mat_a_is_2d and mat_b_is_2d else 1

View File

@@ -41,6 +41,7 @@ IS_THOR = LazyVal(lambda: torch.cuda.is_available() and torch.cuda.get_device_ca
IS_JETSON = LazyVal(lambda: torch.cuda.is_available() and (torch.cuda.get_device_capability() in [(7, 2), (8, 7)] or IS_THOR))
IS_SM89 = LazyVal(lambda: torch.cuda.is_available() and torch.cuda.get_device_capability() == (8, 9))
IS_SM90 = LazyVal(lambda: torch.cuda.is_available() and torch.cuda.get_device_capability() == (9, 0))
IS_SM100 = LazyVal(lambda: torch.cuda.is_available() and torch.cuda.get_device_capability() == (10, 0))
def evaluate_gfx_arch_within(arch_list):
if not torch.cuda.is_available():
@@ -129,9 +130,17 @@ def evaluate_platform_supports_mx_gemm():
return SM100OrLater
return False
def evaluate_platform_supports_mxfp8_grouped_gemm():
if torch.cuda.is_available() and not torch.version.hip:
built_with_fbgemm_genai = "USE_FBGEMM_GENAI" in torch.__config__.show()
return built_with_fbgemm_genai and IS_SM100
return False
PLATFORM_SUPPORTS_MX_GEMM: bool = LazyVal(lambda: evaluate_platform_supports_mx_gemm())
PLATFORM_SUPPORTS_FP8: bool = LazyVal(lambda: evaluate_platform_supports_fp8())
PLATFORM_SUPPORTS_FP8_GROUPED_GEMM: bool = LazyVal(lambda: evaluate_platform_supports_fp8_grouped_gemm())
PLATFORM_SUPPORTS_MX_GEMM: bool = LazyVal(lambda: TEST_CUDA and SM100OrLater)
PLATFORM_SUPPORTS_MXFP8_GROUPED_GEMM: bool = LazyVal(lambda: evaluate_platform_supports_mxfp8_grouped_gemm())
if TEST_NUMBA:
try:

View File

@@ -479,3 +479,110 @@ def to_blocked(input_matrix) -> torch.Tensor:
rearranged = blocks.reshape(-1, 4, 32, 4).transpose(1, 2).reshape(-1, 32, 16)
return rearranged.flatten()
# This function is extracted from https://github.com/pytorch/ao/blob/v0.12.0/torchao/prototype/mx_formats/mx_tensor.py#L142
def to_mxfp8(
data_hp: torch.Tensor,
block_size: int = 32,
):
assert data_hp.dtype in (
torch.bfloat16,
torch.float,
), f"{data_hp.dtype} is not supported yet"
assert (
data_hp.shape[-1] % block_size == 0
), f"the last dimension of shape {data_hp.shape} must be divisible by block_size {block_size}"
assert data_hp.is_contiguous(), "unsupported"
orig_shape = data_hp.shape
data_hp = data_hp.reshape(
*orig_shape[:-1], orig_shape[-1] // block_size, block_size
)
max_abs = torch.amax(torch.abs(data_hp), -1).unsqueeze(-1)
data_hp = data_hp.to(torch.float32)
max_abs = max_abs.to(torch.float32)
F8E4M3_MAX = torch.finfo(torch.float8_e4m3fn).max # 448.0
max_pos = F8E4M3_MAX
# RCEIL
def _to_mx_rceil(
data_hp: torch.Tensor,
max_abs: torch.Tensor,
max_pos: float,
) -> tuple[torch.Tensor, torch.Tensor]:
E8M0_EXPONENT_BIAS = 127
descale = max_abs / max_pos
exponent = torch.where(
torch.isnan(descale),
0xFF, # Handle biased exponent for nan
# NOTE: descale < (torch.finfo(torch.float32).smallest_normal / 2) is handled through clamping
(
torch.clamp(
torch.ceil(torch.log2(descale)),
min=-E8M0_EXPONENT_BIAS,
max=E8M0_EXPONENT_BIAS,
)
+ E8M0_EXPONENT_BIAS
).to(torch.uint8),
)
descale_fp = torch.where(
exponent == 0,
1.0,
torch.exp2(E8M0_EXPONENT_BIAS - exponent.to(torch.float32)),
)
# scale and saturated cast the data elements to max of target dtype
data_lp = torch.clamp(data_hp * descale_fp, min=-1 * max_pos, max=max_pos)
return exponent, data_lp
scale_e8m0_biased, data_lp = _to_mx_rceil(data_hp, max_abs, max_pos)
# cast to target dtype
data_lp = data_lp.to(torch.float8_e4m3fn)
# need to reshape at the end to help inductor fuse things
data_lp = data_lp.reshape(orig_shape)
scale_e8m0_biased = scale_e8m0_biased.view(torch.float8_e8m0fnu)
scale_e8m0_biased = scale_e8m0_biased.squeeze(-1)
return scale_e8m0_biased, data_lp
# Source: https://github.com/pytorch/ao/blob/568c1932a16ae9f30d48da214a88dc0013e98ed8/torchao/prototype/moe_training/utils.py#L310
def generate_jagged_offs(E, M, multiple_of=16, dtype=torch.int32, device="cuda"):
"""
Utility function for tests and benchmarks.
Generates a tensor of length E, containing random values divisible by `multiple_of`,
from 0 to M, in sorted order, and where the final value in the tensor is always M.
Args:
E (int): The length of the tensor.
M (int): The maximum value in the tensor.
Returns:
torch.Tensor: A tensor of length E with the specified properties.
"""
import random
# Ensure M is divisible by multiple_of
if M % multiple_of != 0:
raise ValueError(f"M must be divisible by {multiple_of}")
# Generate a list of possible values
possible_values = list(range(multiple_of, M + 1, multiple_of))
# If E is larger than the number of possible values, raise an error
if E > len(possible_values):
raise ValueError("E cannot be larger than the number of possible values")
# Randomly select E - 1 values from the possible values (excluding M)
selected_values = torch.tensor(random.sample(possible_values[:-1], E - 1))
# Append M to the selected values
selected_values = torch.cat((selected_values, torch.tensor([M])))
# Sort the selected values
selected_values, _ = torch.sort(selected_values)
return selected_values.to(dtype).to(device)