MXFP8 grouped GEMM support for torch._scaled_grouped_mm + submodule bump (#162209)

## Summary
- We just landed 2d-2d support for MXFP8 grouped GEMM in FBGEMM: https://github.com/pytorch/FBGEMM/pull/4816
- This is needed for the backward pass of MXFP8 MoE training with grouped GEMMs.
- Changes:
    - Add dispatching + input validation for MXFP8 grouped GEMM in `torch._scaled_grouped_mm`
    - Add meta registration input validation for MXFP8 grouped GEMM, for composability with `torch.compile`
    - Add unit tests exercising `torch._scaled_grouped_mm` with MXFP8 inputs (see the sketch after this list)
    - Bump the FBGEMM third-party submodule to include:
        - https://github.com/pytorch/FBGEMM/pull/4816
        - https://github.com/pytorch/FBGEMM/pull/4820
        - https://github.com/pytorch/FBGEMM/pull/4821
        - https://github.com/pytorch/FBGEMM/pull/4823
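
For context, a minimal sketch of what "MXFP8 inputs" means here: `float8_e4m3fn` element data plus one `float8_e8m0fnu` (power-of-two) scale per 32-element block. This is an illustrative quantizer written for this description, not FBGEMM's kernel, and it omits the blocked/padded scale layout the grouped GEMM ultimately consumes; the helper name `to_mxfp8_blocks` is hypothetical.

```
import torch

def to_mxfp8_blocks(x: torch.Tensor, block_size: int = 32):
    """Illustrative MXFP8 quantization: e4m3 data + one e8m0 scale per
    32-element block along the last dim (simplified, not FBGEMM's kernel)."""
    M, K = x.shape
    assert K % block_size == 0
    blocks = x.reshape(M, K // block_size, block_size).float()
    # Pick a power-of-two scale per block so the block fits e4m3's max (448).
    amax = blocks.abs().amax(dim=-1, keepdim=True)
    exp = torch.clamp(torch.ceil(torch.log2(amax / 448.0)), min=-127, max=127)
    data = (blocks / torch.exp2(exp)).clamp(-448, 448).to(torch.float8_e4m3fn)
    # e8m0 stores only a biased exponent; reinterpret the uint8 bits.
    scales = (exp.squeeze(-1) + 127).to(torch.uint8).view(torch.float8_e8m0fnu)
    return data.reshape(M, K), scales  # scales shape: (M, K // block_size)

a_fp8, a_scales = to_mxfp8_blocks(torch.randn(128, 256, dtype=torch.bfloat16))
print(a_fp8.dtype, a_scales.dtype)  # torch.float8_e4m3fn torch.float8_e8m0fnu
```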

#### How the fbgemm dependency was bumped
Documenting this since I haven't found it documented elsewhere:
- `cd ~/pytorch/third_party/fbgemm`
- `git fetch`
- `git checkout <hash>`
- `cd ~/pytorch`
- `git add third_party/fbgemm`

## Test plan

#### Test build
```
USE_FBGEMM_GENAI=1 python -m pip install --no-build-isolation -v -e .
...
Successfully installed torch-2.9.0a0+gitf5070f3
```
[full build log](https://www.internalfb.com/phabricator/paste/view/P1933787581)

#### Unit tests
```
pytest test/test_matmul_cuda.py -k test_mxfp8_scaled_grouped_mm_
...

test/test_matmul_cuda.py .........                                                                                                                        [100%]

============================================================== 9 passed, 1668 deselected in 5.34s ===============================================================
```

Pull Request resolved: https://github.com/pytorch/pytorch/pull/162209
Approved by: https://github.com/ngimel
Author: Daniel Vega-Myhre
Date: 2025-09-06 15:25:30 +00:00
Committed by: PyTorch MergeBot
Commit: b6d0a9ea90 (parent 5985e28912)
9 changed files with 531 additions and 70 deletions

Diff excerpt (meta registration for grouped GEMM, `_meta_grouped_mm_common`):

```
@@ -7424,17 +7424,17 @@ def _meta_grouped_mm_common(
         fp8_dtype = torch.float8_e4m3fnuz if torch.version.hip else torch.float8_e4m3fn
         torch._check(
             mat_a.dtype == fp8_dtype and mat_b.dtype == fp8_dtype,
-            lambda: f"Expected inputs of E4M3 FP8 type but got mat_a.dtype={mat_a.dtype} and mat_b.dtype={mat_b.dtype}.",
+            lambda: f"Expected inputs of E4M3 FP8 type but got mat_a.dtype={mat_a.dtype} and mat_b.dtype={mat_b.dtype}.",  # noqa: B950
         )
     else:
         torch._check(
             mat_a.dtype == torch.bfloat16 and mat_b.dtype == torch.bfloat16,
-            lambda: f"Expected inputs of BF16 type but got mat_a.dtype={mat_a.dtype} and mat_b.dtype={mat_b.dtype}.",
+            lambda: f"Expected inputs of BF16 type but got mat_a.dtype={mat_a.dtype} and mat_b.dtype={mat_b.dtype}.",  # noqa: B950
         )

     torch._check(
         mat_a.dim() in [2, 3] and mat_b.dim() in [2, 3],
-        lambda: f"Multiplicands must be 2D or 3D but got mat_a.dim()={mat_a.dim()} and mat_b.dim()={mat_b.dim()}",
+        lambda: f"Multiplicands must be 2D or 3D but got mat_a.dim()={mat_a.dim()} and mat_b.dim()={mat_b.dim()}",  # noqa: B950
     )

     mat_a_is_2d = mat_a.dim() == 2
@@ -7458,11 +7458,11 @@ def _meta_grouped_mm_common(
     torch._check(
         is_row_major(mat_a),
-        lambda: f"Expected mat_a tensor to be row major in the last two dimensions, got strides {mat_a.stride()[-2:]}",
+        lambda: f"Expected mat_a tensor to be row major in the last two dimensions, got strides {mat_a.stride()[-2:]}",  # noqa: B950
     )
     torch._check(
         is_col_major(mat_b),
-        lambda: f"Expected mat_b tensor to be column major in the last two dimensions, got strides {mat_b.stride()[-2:]}",
+        lambda: f"Expected mat_b tensor to be column major in the last two dimensions, got strides {mat_b.stride()[-2:]}",  # noqa: B950
     )

     def check_valid_strides(mat_name, mat):
@@ -7474,7 +7474,7 @@ def _meta_grouped_mm_common(
         ):
             torch._check(
                 mat_stride[end_dim] % alignment == 0,
-                lambda: f"Expected {mat_name} stride along {end_dim} dim to be multiple of 16 bytes, got {mat_stride[end_dim]}.",
+                lambda: f"Expected {mat_name} stride along {end_dim} dim to be multiple of 16 bytes, got {mat_stride[end_dim]}.",  # noqa: B950
             )
         elif mat_stride[end_dim] == 1 and mat_stride[end_dim - 1] >= max(
             1, mat.shape[end_dim]
@@ -7494,41 +7494,81 @@ def _meta_grouped_mm_common(
     if scale_a is not None and scale_b is not None:
         torch._check(
-            scale_a.dtype == torch.float32 and scale_b.dtype == torch.float32,
-            lambda: "Both scale_a and scale_b must be float (fp32) tensors, but got scale_a.dtype={scale_a.dtype} and scale_b.dtype={scale_b.dtype}.",  # noqa: B950
+            (scale_a.dtype == torch.float32 and scale_b.dtype == torch.float32)
+            or (
+                scale_a.dtype == torch.float8_e8m0fnu
+                and scale_b.dtype == torch.float8_e8m0fnu
+            ),
+            lambda: f"For FP8 scales must both be float32, or for MXFP8 both scales must be float8_e8m0fnu. Got scale_a.dtype={scale_a.dtype} and scale_b.dtype={scale_b.dtype}.",  # noqa: B950
         )
+        is_mxfp8 = (
+            scale_a.dtype == torch.float8_e8m0fnu
+            and scale_b.dtype == torch.float8_e8m0fnu
+        )
+
+        def round_up(x, y):
+            """Rounds up x to nearest multiple of y"""
+            return ((x + y - 1) // y) * y

         def check_scale(scale_name, scale, mat, scaled_dim, scale_multiplier=1):
             if mat.dim() == 2:
-                torch._check(
-                    scale.dim() == 1,
-                    lambda: f"Expected {scale_name} to be 1D tensor, but got {scale.dim()}D tensor.",
-                )
                 torch._check(
                     scale.is_contiguous(),
                     lambda: f"Expected {scale_name} to be contiguous.",
                 )
-                torch._check(
-                    scale.shape[0] == mat.shape[scaled_dim] * scale_multiplier,
-                    lambda: f"Expected {scale_name} to have {mat.shape[scaled_dim] * scale_multiplier} elements, got {scale.shape[0]} elements.",  # noqa: B950
-                )
+                # For MXFP8, 2d tensors have variable size groups represented as subtensors,
+                # that are converted to blocked padded format individually. At compile time we don't know
+                # the group sizes yet, so we don't know the expect size of the blocked format scale.
+                # This limits what we can check here.
+                if is_mxfp8:
+                    torch._check(
+                        scale.dim() == mat.dim(),
+                        lambda: f"For MXFP8, scale must have same number of dimensions as target tensor, but {scale_name} has mat.ndim={mat.ndim} and scale.ndim={scale.ndim}",  # noqa: B950
+                    )
+                else:
+                    torch._check(
+                        scale.dim() == 1,
+                        lambda: f"Expected {scale_name} to be 1D tensor, but got {scale.dim()}D tensor.",
+                    )
+                    torch._check(
+                        scale.shape[0] == mat.shape[scaled_dim] * scale_multiplier,
+                        lambda: f"Expected {scale_name} to have {mat.shape[scaled_dim] * scale_multiplier} elements, got {scale.shape[0]} elements.",  # noqa: B950
+                    )
             else:
-                torch._check(
-                    scale.dim() == 2,
-                    lambda: f"Expected {scale_name} to be 2D tensor, but got {scale.dim()}D tensor.",
-                )
                 torch._check(
-                    scale.stride(1) == 1,
+                    scale.stride(-1) == 1,
                     lambda: f"Expected {scale_name} to be contiguous in the last dimension.",
                 )
                 torch._check(
                     scale.shape[0] == mat.shape[0],
                     lambda: f"Expected {scale_name} batch dimension to be {mat.shape[0]}, got {scale.shape[0]}.",
                 )
-                torch._check(
-                    scale.shape[1] == mat.shape[1 + scaled_dim],
-                    lambda: f"Expected {scale_name} non-batch dimension to be {mat.shape[1 + scaled_dim]}, got {scale.shape[1]}.",
-                )
+                # For MXFP8, 3d tensors have static 'groups' (stack of 2d tensors) so we can know the expected blocked
+                # scale sizes at compile time.
+                if is_mxfp8:
+                    torch._check(
+                        mat.ndim == scale.ndim,
+                        lambda: f"For MXFP8, scale should have same number of dimensions as target tensor, but {scale_name} has mat.ndim={mat.ndim} and scale.ndim={scale.ndim}",  # noqa: B950
+                    )
+                    # TODO: This logic only holds for RHS tensor in 2d-3d case.
+                    # We'll need to update it to handle LHS 3d tensor in 3d-2d and 3d-3d cases.
+                    G, K, N = scale.shape
+                    block_size = 32
+                    blocked_K = round_up(K / block_size, 4)
+                    blocked_N = round_up(N, 128)
+                    torch._check(
+                        mat.shape[-2] == blocked_K and mat.shape[-1] == blocked_N,
+                        lambda: f"For MXFP8, expected mat.shape={mat.shape} to have scale shape of ({G},{blocked_K},{blocked_N}), but got {scale.shape}",  # noqa: B950
+                    )
+                else:
+                    torch._check(
+                        scale.dim() == 2,
+                        lambda: f"Expected {scale_name} to be 2D tensor, but got {scale.dim()}D tensor.",
+                    )
+                    torch._check(
+                        scale.shape[1] == mat.shape[1 + scaled_dim],
+                        lambda: f"Expected {scale_name} non-batch dimension to be {mat.shape[1 + scaled_dim]}, got {scale.shape[1]}.",  # noqa: B950
+                    )

         scale_multiplier = (
             offs.shape[0] if offs is not None and mat_a_is_2d and mat_b_is_2d else 1
```
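
To make the `round_up` math in the new MXFP8 branch above concrete, here is a tiny worked example. The sizes are made up, and the reading of which dimensions are padded to multiples of 4 and 128 follows the `round_up(K / block_size, 4)` / `round_up(N, 128)` calls in the check.

```
def round_up(x, y):
    """Round x up to the nearest multiple of y (same helper as in the diff)."""
    return ((x + y - 1) // y) * y

# Illustrative sizes only: with a 32-element MXFP8 block size, a reduction dim
# of K = 2048 has 2048 // 32 = 64 scale blocks, padded to a multiple of 4,
# while the other scale dim is padded to a multiple of 128.
K, N, block_size = 2048, 192, 32
print(round_up(K // block_size, 4))  # 64  (already a multiple of 4)
print(round_up(N, 128))              # 256 (192 padded up to 128 * 2)
```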