MXFP8 grouped GEMM support for torch._scaled_grouped_mm + submodule bump (#162209)

## Summary
- We just landed 2d-2d support for mxfp8 grouped gemm in FBGEMM: https://github.com/pytorch/FBGEMM/pull/4816
- This is needed for backward pass of mxfp8 MoE training with grouped gemms
- Changes:
    - Add dispatching + input validation for mxfp8 grouped gemm in `torch._scaled_grouped_mm`
    - Add meta registration input validation for mxfp8 grouped gemm, for composability with compile
    - Add unit tests exercising torch._scaled_grouped_mm with mxfp8 inputs
    - Bump FBGEMM third party submodule to include:
          - https://github.com/pytorch/FBGEMM/pull/4816
          - https://github.com/pytorch/FBGEMM/pull/4820
          - https://github.com/pytorch/FBGEMM/pull/4821
          - https://github.com/pytorch/FBGEMM/pull/4823

#### How fbgemm dependency was bumped
Documenting this since I haven't found it documented elsewhere:
- `cd ~/pytorch/third_party/fbgemm`
- `git fetch`
- `git checkout <hash>`
- `cd ~/pytorch`
- `git add third_party/fbgemm`

## Test plan

#### Test build
```
USE_FBGEMM_GENAI=1 python -m pip install --no-build-isolation -v -e .
...
Successfully installed torch-2.9.0a0+gitf5070f3
```
[full build log](https://www.internalfb.com/phabricator/paste/view/P1933787581)

#### Unit tests
```
pytest test/test_matmul_cuda.py -k test_mxfp8_scaled_grouped_mm_
...

test/test_matmul_cuda.py .........                                                                                                                        [100%]

============================================================== 9 passed, 1668 deselected in 5.34s ===============================================================
```

Pull Request resolved: https://github.com/pytorch/pytorch/pull/162209
Approved by: https://github.com/ngimel
This commit is contained in:
Daniel Vega-Myhre
2025-09-06 15:25:30 +00:00
committed by PyTorch MergeBot
parent 5985e28912
commit b6d0a9ea90
9 changed files with 531 additions and 70 deletions

View File

@ -889,6 +889,12 @@ IF(USE_FBGEMM_GENAI AND USE_ROCM AND NOT "gfx942" IN_LIST PYTORCH_ROCM_ARCH)
set(USE_FBGEMM_GENAI off)
endif()
# Set USE_FBGEMM_GENAI to ON for CUDA build on SM100
if(USE_CUDA AND "$ENV{TORCH_CUDA_ARCH_LIST}" MATCHES "10.0a")
message(WARNING "Setting USE_FBGEMM_GENAI to ON for CUDA build on SM100")
set(USE_FBGEMM_GENAI ON)
endif()
# CAVEAT: Again, Flash Attention2 will error while building for sm52 while Mem
# Eff Attention won't
cmake_dependent_option(