Daniel Vega-Myhre b6d0a9ea90 MXFP8 grouped GEMM support for torch._scaled_grouped_mm + submodule bump (#162209)
## Summary
- We just landed 2D-2D support for mxfp8 grouped gemm in FBGEMM: https://github.com/pytorch/FBGEMM/pull/4816
- This is needed for the backward pass of mxfp8 MoE training with grouped gemms
- Changes:
    - Add dispatching + input validation for mxfp8 grouped gemm in `torch._scaled_grouped_mm`
    - Add input validation to the meta registration for mxfp8 grouped gemm, for composability with `torch.compile`
    - Add unit tests exercising `torch._scaled_grouped_mm` with mxfp8 inputs
    - Bump the FBGEMM third-party submodule to include:
        - https://github.com/pytorch/FBGEMM/pull/4816
        - https://github.com/pytorch/FBGEMM/pull/4820
        - https://github.com/pytorch/FBGEMM/pull/4821
        - https://github.com/pytorch/FBGEMM/pull/4823

#### How fbgemm dependency was bumped
Documenting this since I haven't found it documented elsewhere:
- `cd ~/pytorch/third_party/fbgemm`
- `git fetch`
- `git checkout <hash>`
- `cd ~/pytorch`
- `git add third_party/fbgemm`
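
The steps above can be collected into a small helper. This is only a sketch: `bump_submodule` is a hypothetical name used for illustration (it is not a script in the PyTorch repo), and the final `git diff` call is an added convenience for reviewing the staged change, not one of the original steps.

```shell
# Sketch: pin a git submodule to a given commit and stage the new pointer.
# bump_submodule is a hypothetical helper, not part of PyTorch tooling.
#   $1: path to the parent repo        (e.g. ~/pytorch)
#   $2: submodule path inside the repo (e.g. third_party/fbgemm)
#   $3: commit hash to pin the submodule to
bump_submodule() {
    # Fetch new commits and check out the target inside the submodule.
    git -C "$1/$2" fetch &&
    git -C "$1/$2" checkout "$3" &&
    # Stage the updated submodule pointer in the parent repo.
    git -C "$1" add "$2" &&
    # Print the staged old/new submodule commits for review.
    git -C "$1" diff --cached --submodule=short -- "$2"
}
```

It would be called as `bump_submodule ~/pytorch third_party/fbgemm <hash>`. The staged pointer still has to be committed in the parent repo like any other change.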

## Test plan

#### Test build
```
USE_FBGEMM_GENAI=1 python -m pip install --no-build-isolation -v -e .
...
Successfully installed torch-2.9.0a0+gitf5070f3
```
[full build log](https://www.internalfb.com/phabricator/paste/view/P1933787581)

#### Unit tests
```
pytest test/test_matmul_cuda.py -k test_mxfp8_scaled_grouped_mm_
...

test/test_matmul_cuda.py .........                                                                                                                        [100%]

============================================================== 9 passed, 1668 deselected in 5.34s ===============================================================
```

Pull Request resolved: https://github.com/pytorch/pytorch/pull/162209
Approved by: https://github.com/ngimel
2025-09-06 15:25:30 +00:00