Mirror of https://github.com/pytorch/pytorch.git, synced 2025-10-20 21:14:14 +08:00
Fix B200 test fails in scaled_mm (#165747)
Summary:

PR #165528 changed some scale/swizzle inference behavior in the scaled_mm tests; mxfp8 tests on Blackwell can be incorrectly classified, resulting in failures. Fix the scale/swizzle inference code to prevent this.

Fixes https://github.com/pytorch/pytorch/issues/165743

Test Plan:

```
pytest -svv test/test_scaled_matmul_cuda.py
```

Reviewers: @jagadish-amd @jeffdaily @drisspg

Subscribers: @Aidyn-A

Signed-off-by: Simon Layton <simonlayton@meta.com>

Pull Request resolved: https://github.com/pytorch/pytorch/pull/165747

Approved by: https://github.com/eqy, https://github.com/drisspg, https://github.com/jeffdaily
Commit 39e0a832c9, parent dd3b48e85d, committed by PyTorch MergeBot.
```diff
@@ -154,8 +154,8 @@ def infer_scale_swizzle(mat, scale):
         # MXFP4 w/o swizzle
         if (
-            scale.numel() == 2 * math.ceil(mat.shape[0] // 32) * mat.shape[1]
-            or scale.numel() == 2 * math.ceil(mat.shape[1] // 32) * mat.shape[0]
+            (scale.numel() == 2 * math.ceil(mat.shape[0] // 32) * mat.shape[1]
+            or scale.numel() == 2 * math.ceil(mat.shape[1] // 32) * mat.shape[0])
             and mat.dtype == torch.float4_e2m1fn_x2
             and scale.dtype == torch.float8_e8m0fnu
         ):
```
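The fix hinges on Python operator precedence: `and` binds tighter than `or`, so without the added parentheses the dtype checks only guarded the second `scale.numel()` branch, and any tensor matching the first size check fell into the MXFP4 path regardless of dtype. A minimal standalone sketch of that precedence difference (function names and boolean inputs are hypothetical, not from the PR):

```python
# Hypothetical reduction of the condition to three booleans:
# size_a / size_b stand in for the two scale.numel() size checks,
# dtype_ok stands in for the mat.dtype and scale.dtype checks.

def classify_buggy(size_a, size_b, dtype_ok):
    # Parses as: size_a or (size_b and dtype_ok) -- dtype_ok does
    # not guard the first size check.
    return size_a or size_b and dtype_ok

def classify_fixed(size_a, size_b, dtype_ok):
    # The commit's fix: group the size checks so the dtype checks
    # guard both branches.
    return (size_a or size_b) and dtype_ok

# A non-MXFP4 tensor whose scale happens to satisfy the first size
# check is misclassified by the buggy form but rejected by the fix:
assert classify_buggy(True, False, False) is True
assert classify_fixed(True, False, False) is False
```

This is why only the grouping changed in the diff: the individual size and dtype expressions were already correct, but the `or` needed explicit parentheses for the `and`-chained dtype checks to apply to both size branches.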