Mirror of https://github.com/pytorch/pytorch.git, synced 2025-10-20 21:14:14 +08:00
MXFP8 grouped GEMM support for torch._scaled_grouped_mm + submodule bump (#162209)
## Summary
- We just landed 2d-2d support for mxfp8 grouped gemm in FBGEMM: https://github.com/pytorch/FBGEMM/pull/4816
- This is needed for the backward pass of mxfp8 MoE training with grouped gemms.
- Changes:
  - Add dispatching + input validation for mxfp8 grouped gemm in `torch._scaled_grouped_mm`
  - Add meta registration input validation for mxfp8 grouped gemm, for composability with compile
  - Add unit tests exercising `torch._scaled_grouped_mm` with mxfp8 inputs
  - Bump FBGEMM third party submodule to include:
    - https://github.com/pytorch/FBGEMM/pull/4816
    - https://github.com/pytorch/FBGEMM/pull/4820
    - https://github.com/pytorch/FBGEMM/pull/4821
    - https://github.com/pytorch/FBGEMM/pull/4823

#### How fbgemm dependency was bumped
Documenting this since I haven't found it documented elsewhere:
- `cd ~/pytorch/third_party/fbgemm`
- `git fetch`
- `git checkout <hash>`
- `cd ~/pytorch`
- `git add third_party/fbgemm`

## Test plan
#### Test build
```
USE_FBGEMM_GENAI=1 python -m pip install --no-build-isolation -v -e .
...
Successfully installed torch-2.9.0a0+gitf5070f3
```
[full build log](https://www.internalfb.com/phabricator/paste/view/P1933787581)

#### Unit tests
```
pytest test/test_matmul_cuda.py -k test_mxfp8_scaled_grouped_mm_
...
test/test_matmul_cuda.py ......... [100%]
======================== 9 passed, 1668 deselected in 5.34s ========================
```

Pull Request resolved: https://github.com/pytorch/pytorch/pull/162209
Approved by: https://github.com/ngimel
Committed by: PyTorch MergeBot
Parent: 5985e28912
Commit: b6d0a9ea90
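The diff below is in the meta registration for `torch._scaled_grouped_mm`. The MXFP8 path is recognized purely by scale dtype: `float32` scales keep the existing rowwise FP8 recipe, while `float8_e8m0fnu` block scales select MXFP8. A minimal standalone sketch of that selection rule (the helper name is mine, not part of the PR; it only mirrors the check added in `_meta_grouped_mm_common`):

```python
import torch

def grouped_mm_recipe(scale_a: torch.Tensor, scale_b: torch.Tensor) -> str:
    """Mirrors the scale-dtype validation added in _meta_grouped_mm_common."""
    if scale_a.dtype == torch.float32 and scale_b.dtype == torch.float32:
        return "fp8_rowwise"  # existing rowwise-scaled FP8 path
    if scale_a.dtype == torch.float8_e8m0fnu and scale_b.dtype == torch.float8_e8m0fnu:
        return "mxfp8"  # new MXFP8 path with 32-element block scales
    raise ValueError(
        "For FP8 both scales must be float32, or for MXFP8 both scales must be "
        f"float8_e8m0fnu; got {scale_a.dtype} and {scale_b.dtype}"
    )

# e8m0 scales select the MXFP8 recipe (values are uninitialized; only the dtype matters here).
sa = torch.empty(16, dtype=torch.float8_e8m0fnu)
sb = torch.empty(16, dtype=torch.float8_e8m0fnu)
print(grouped_mm_recipe(sa, sb))  # mxfp8
```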
```
@@ -7424,17 +7424,17 @@ def _meta_grouped_mm_common(
fp8_dtype = torch.float8_e4m3fnuz if torch.version.hip else torch.float8_e4m3fn
torch._check(
mat_a.dtype == fp8_dtype and mat_b.dtype == fp8_dtype,
lambda: f"Expected inputs of E4M3 FP8 type but got mat_a.dtype={mat_a.dtype} and mat_b.dtype={mat_b.dtype}.",
lambda: f"Expected inputs of E4M3 FP8 type but got mat_a.dtype={mat_a.dtype} and mat_b.dtype={mat_b.dtype}.", # noqa: B950
)
else:
torch._check(
mat_a.dtype == torch.bfloat16 and mat_b.dtype == torch.bfloat16,
lambda: f"Expected inputs of BF16 type but got mat_a.dtype={mat_a.dtype} and mat_b.dtype={mat_b.dtype}.",
lambda: f"Expected inputs of BF16 type but got mat_a.dtype={mat_a.dtype} and mat_b.dtype={mat_b.dtype}.", # noqa: B950
)

torch._check(
mat_a.dim() in [2, 3] and mat_b.dim() in [2, 3],
lambda: f"Multiplicands must be 2D or 3D but got mat_a.dim()={mat_a.dim()} and mat_b.dim()={mat_b.dim()}",
lambda: f"Multiplicands must be 2D or 3D but got mat_a.dim()={mat_a.dim()} and mat_b.dim()={mat_b.dim()}", # noqa: B950
)

mat_a_is_2d = mat_a.dim() == 2
```
```
@@ -7458,11 +7458,11 @@ def _meta_grouped_mm_common(

torch._check(
is_row_major(mat_a),
lambda: f"Expected mat_a tensor to be row major in the last two dimensions, got strides {mat_a.stride()[-2:]}",
lambda: f"Expected mat_a tensor to be row major in the last two dimensions, got strides {mat_a.stride()[-2:]}", # noqa: B950
)
torch._check(
is_col_major(mat_b),
lambda: f"Expected mat_b tensor to be column major in the last two dimensions, got strides {mat_b.stride()[-2:]}",
lambda: f"Expected mat_b tensor to be column major in the last two dimensions, got strides {mat_b.stride()[-2:]}", # noqa: B950
)

def check_valid_strides(mat_name, mat):
```
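The two checks above encode the layout the underlying grouped GEMM kernels expect: `mat_a` row-major and `mat_b` column-major in the last two dimensions. One common way to satisfy the `mat_b` requirement is to allocate the transposed shape contiguously and take a transposed view (illustrative shapes, not taken from the PR's tests):

```python
import torch

G, K, N = 4, 256, 512
# Allocate (G, N, K) contiguously, then view it as (G, K, N): the stride along
# dim -2 becomes 1, i.e. column-major in the last two dimensions.
mat_b = torch.empty(G, N, K, dtype=torch.float8_e4m3fn).transpose(-2, -1)
print(mat_b.shape, mat_b.stride()[-2:])  # torch.Size([4, 256, 512]) (1, 256)
```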
```
@@ -7474,7 +7474,7 @@ def _meta_grouped_mm_common(
):
torch._check(
mat_stride[end_dim] % alignment == 0,
lambda: f"Expected {mat_name} stride along {end_dim} dim to be multiple of 16 bytes, got {mat_stride[end_dim]}.",
lambda: f"Expected {mat_name} stride along {end_dim} dim to be multiple of 16 bytes, got {mat_stride[end_dim]}.", # noqa: B950
)
elif mat_stride[end_dim] == 1 and mat_stride[end_dim - 1] >= max(
1, mat.shape[end_dim]
)
```
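The alignment check above is stated in bytes; for 1-byte FP8 elements a 16-byte requirement means strides must be multiples of 16 elements. A sketch of that arithmetic (the surrounding code that computes `alignment` is not part of this hunk, so the conversion shown here is an assumption):

```python
import torch

# 16-byte alignment expressed in elements for the fp8 element type used here.
itemsize = torch.finfo(torch.float8_e4m3fn).bits // 8  # 1 byte
alignment = 16 // itemsize                              # 16 elements
stride = 256
print(stride % alignment == 0)  # True: a 256-element stride is 16-byte aligned
```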
```
@@ -7494,41 +7494,81 @@ def _meta_grouped_mm_common(

if scale_a is not None and scale_b is not None:
torch._check(
scale_a.dtype == torch.float32 and scale_b.dtype == torch.float32,
lambda: "Both scale_a and scale_b must be float (fp32) tensors, but got scale_a.dtype={scale_a.dtype} and scale_b.dtype={scale_b.dtype}.", # noqa: B950
(scale_a.dtype == torch.float32 and scale_b.dtype == torch.float32)
or (
scale_a.dtype == torch.float8_e8m0fnu
and scale_b.dtype == torch.float8_e8m0fnu
),
lambda: f"For FP8 scales must both be float32, or for MXFP8 both scales must be float8_e8m0fnu. Got scale_a.dtype={scale_a.dtype} and scale_b.dtype={scale_b.dtype}.", # noqa: B950
)
is_mxfp8 = (
scale_a.dtype == torch.float8_e8m0fnu
and scale_b.dtype == torch.float8_e8m0fnu
)

def round_up(x, y):
"""Rounds up x to nearest multiple of y"""
return ((x + y - 1) // y) * y

def check_scale(scale_name, scale, mat, scaled_dim, scale_multiplier=1):
if mat.dim() == 2:
torch._check(
scale.dim() == 1,
lambda: f"Expected {scale_name} to be 1D tensor, but got {scale.dim()}D tensor.",
)
torch._check(
scale.is_contiguous(),
lambda: f"Expected {scale_name} to be contiguous.",
)
torch._check(
scale.shape[0] == mat.shape[scaled_dim] * scale_multiplier,
lambda: f"Expected {scale_name} to have {mat.shape[scaled_dim] * scale_multiplier} elements, got {scale.shape[0]} elements.", # noqa: B950
)
# For MXFP8, 2d tensors have variable size groups represented as subtensors,
# that are converted to blocked padded format individually. At compile time we don't know
# the group sizes yet, so we don't know the expect size of the blocked format scale.
# This limits what we can check here.
if is_mxfp8:
torch._check(
scale.dim() == mat.dim(),
lambda: f"For MXFP8, scale must have same number of dimensions as target tensor, but {scale_name} has mat.ndim={mat.ndim} and scale.ndim={scale.ndim}", # noqa: B950
)
else:
torch._check(
scale.dim() == 1,
lambda: f"Expected {scale_name} to be 1D tensor, but got {scale.dim()}D tensor.",
)
torch._check(
scale.shape[0] == mat.shape[scaled_dim] * scale_multiplier,
lambda: f"Expected {scale_name} to have {mat.shape[scaled_dim] * scale_multiplier} elements, got {scale.shape[0]} elements.", # noqa: B950
)
else:
torch._check(
scale.dim() == 2,
lambda: f"Expected {scale_name} to be 2D tensor, but got {scale.dim()}D tensor.",
)
torch._check(
scale.stride(1) == 1,
scale.stride(-1) == 1,
lambda: f"Expected {scale_name} to be contiguous in the last dimension.",
)
torch._check(
scale.shape[0] == mat.shape[0],
lambda: f"Expected {scale_name} batch dimension to be {mat.shape[0]}, got {scale.shape[0]}.",
)
torch._check(
scale.shape[1] == mat.shape[1 + scaled_dim],
lambda: f"Expected {scale_name} non-batch dimension to be {mat.shape[1 + scaled_dim]}, got {scale.shape[1]}.",
)
# For MXFP8, 3d tensors have static 'groups' (stack of 2d tensors) so we can know the expected blocked
# scale sizes at compile time.
if is_mxfp8:
torch._check(
mat.ndim == scale.ndim,
lambda: f"For MXFP8, scale should have same number of dimensions as target tensor, but {scale_name} has mat.ndim={mat.ndim} and scale.ndim={scale.ndim}", # noqa: B950
)
# TODO: This logic only holds for RHS tensor in 2d-3d case.
# We'll need to update it to handle LHS 3d tensor in 3d-2d and 3d-3d cases.
G, K, N = scale.shape
block_size = 32
blocked_K = round_up(K / block_size, 4)
blocked_N = round_up(N, 128)
torch._check(
mat.shape[-2] == blocked_K and mat.shape[-1] == blocked_N,
lambda: f"For MXFP8, expected mat.shape={mat.shape} to have scale shape of ({G},{blocked_K},{blocked_N}), but got {scale.shape}", # noqa: B950
)
else:
torch._check(
scale.dim() == 2,
lambda: f"Expected {scale_name} to be 2D tensor, but got {scale.dim()}D tensor.",
)
torch._check(
scale.shape[1] == mat.shape[1 + scaled_dim],
lambda: f"Expected {scale_name} non-batch dimension to be {mat.shape[1 + scaled_dim]}, got {scale.shape[1]}.", # noqa: B950
)

scale_multiplier = (
offs.shape[0] if offs is not None and mat_a_is_2d and mat_b_is_2d else 1
```
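For the static 3d ("stack of 2d tensors") case above, the MXFP8 layout carries one `e8m0` scale per 32 elements along the scaled dimension, with the per-block scale count padded to a multiple of 4 and the other dimension padded to a multiple of 128, matching the blocked/swizzled scale tile granularity the kernels consume. A small worked example of that padding arithmetic, reusing the `round_up` helper from the diff (the concrete K and N values are illustrative):

```python
def round_up(x, y):
    """Round x up to the nearest multiple of y (same helper as in the diff)."""
    return ((x + y - 1) // y) * y

block_size = 32   # one e8m0 scale per 32 elements along the scaled dim
K, N = 1000, 300  # illustrative per-group dims

blocked_K = round_up(K / block_size, 4)  # 1000/32 = 31.25 scale blocks -> padded to 32.0
blocked_N = round_up(N, 128)             # 300 -> padded to 384

print(blocked_K, blocked_N)  # 32.0 384
```

Note that `K / block_size` is true division, as in the diff, so the intermediate block count can be fractional before padding.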