mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 12:54:11 +08:00
MXFP8 grouped GEMM support for torch._scaled_grouped_mm + submodule bump (#162209)
## Summary - We just landed 2d-2d support for mxfp8 grouped gemm in FBGEMM: https://github.com/pytorch/FBGEMM/pull/4816 - This is needed for backward pass of mxfp8 MoE training with grouped gemms - Changes: - Add dispatching + input validation for mxfp8 grouped gemm in `torch._scaled_grouped_mm` - Add meta registration input validation for mxfp8 grouped gemm, for composability with compile - Add unit tests exercising torch._scaled_grouped_mm with mxfp8 inputs - Bump FBGEMM third party submodule to include: - https://github.com/pytorch/FBGEMM/pull/4816 - https://github.com/pytorch/FBGEMM/pull/4820 - https://github.com/pytorch/FBGEMM/pull/4821 - https://github.com/pytorch/FBGEMM/pull/4823 #### How fbgemm dependency was bumped Documenting this since I haven't found it documented elsewhere: - `cd ~/pytorch/third_party/fbgemm` - `git fetch` - `git checkout <hash>` - `cd ~/pytorch` - `git add third_party/fbgemm` ## Test plan #### Test build ``` USE_FBGEMM_GENAI=1 python -m pip install --no-build-isolation -v -e . ... Successfully installed torch-2.9.0a0+gitf5070f3 ``` [full build log](https://www.internalfb.com/phabricator/paste/view/P1933787581) #### Unit tests ``` pytest test/test_matmul_cuda.py -k test_mxfp8_scaled_grouped_mm_ ... test/test_matmul_cuda.py ......... [100%] ============================================================== 9 passed, 1668 deselected in 5.34s =============================================================== ``` Pull Request resolved: https://github.com/pytorch/pytorch/pull/162209 Approved by: https://github.com/ngimel
This commit is contained in:
committed by
PyTorch MergeBot
parent
5985e28912
commit
b6d0a9ea90
@ -889,6 +889,12 @@ IF(USE_FBGEMM_GENAI AND USE_ROCM AND NOT "gfx942" IN_LIST PYTORCH_ROCM_ARCH)
|
||||
set(USE_FBGEMM_GENAI off)
|
||||
endif()
|
||||
|
||||
# Set USE_FBGEMM_GENAI to ON for CUDA build on SM100
|
||||
if(USE_CUDA AND "$ENV{TORCH_CUDA_ARCH_LIST}" MATCHES "10.0a")
|
||||
message(WARNING "Setting USE_FBGEMM_GENAI to ON for CUDA build on SM100")
|
||||
set(USE_FBGEMM_GENAI ON)
|
||||
endif()
|
||||
|
||||
# CAVEAT: Again, Flash Attention2 will error while building for sm52 while Mem
|
||||
# Eff Attention won't
|
||||
cmake_dependent_option(
|
||||
|
Reference in New Issue
Block a user