test_scaled_matmul_cuda: fix infer_scale_swizzle (#165788)

Extend the #165747 fix to the other cases.
Add parentheses to clarify operator precedence.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/165788
Approved by: https://github.com/jeffdaily, https://github.com/slayton58
This commit is contained in:
Jagadish Krishnamoorthy
2025-10-19 21:41:58 +00:00
committed by PyTorch MergeBot
parent 8139f33fa5
commit 8951df03de

View File

@ -128,10 +128,10 @@ def infer_scale_swizzle(mat, scale):
# deepgemm 1x128 / 128x1
if len(scale.shape) > 1:
if (
scale.shape[0] == mat.shape[0]
and scale.shape[1] == math.ceil(mat.shape[1] // 128)
or scale.shape[1] == mat.shape[1]
and scale.shape[0] == math.ceil(mat.shape[0] // 128)
(scale.shape[0] == mat.shape[0]
and scale.shape[1] == math.ceil(mat.shape[1] // 128))
or (scale.shape[1] == mat.shape[1]
and scale.shape[0] == math.ceil(mat.shape[0] // 128))
):
return ScalingType.BlockWise1x128, SwizzleType.NO_SWIZZLE
@ -143,10 +143,10 @@ def infer_scale_swizzle(mat, scale):
# NVFP4
if (
scale.numel()
(scale.numel()
== round_up(mat.shape[0], 128) * round_up(math.ceil(2 * mat.shape[1] // 16), 4)
or scale.numel()
== round_up(mat.shape[1], 128) * round_up(math.ceil(2 * mat.shape[0] // 16), 4)
== round_up(mat.shape[1], 128) * round_up(math.ceil(2 * mat.shape[0] // 16), 4))
and mat.dtype == torch.float4_e2m1fn_x2
and scale.dtype == torch.float8_e4m3fn
):
@ -164,10 +164,10 @@ def infer_scale_swizzle(mat, scale):
if not torch.version.hip:
# MXFP8 w/ swizzle
if (
scale.numel()
(scale.numel()
== round_up(mat.shape[0], 128) * round_up(math.ceil(mat.shape[1] // 32), 4)
or scale.numel()
== round_up(mat.shape[1], 128) * round_up(math.ceil(mat.shape[0] // 32), 4)
== round_up(mat.shape[1], 128) * round_up(math.ceil(mat.shape[0] // 32), 4))
and scale.dtype == torch.float8_e8m0fnu
):
return ScalingType.BlockWise1x32, SwizzleType.SWIZZLE_32_4_4
@ -175,8 +175,8 @@ def infer_scale_swizzle(mat, scale):
else:
# MXFP8 w/o swizzle
if (
scale.numel() == math.ceil(mat.shape[0] // 32) * mat.shape[1]
or scale.numel() == math.ceil(mat.shape[1] // 32) * mat.shape[0]
(scale.numel() == math.ceil(mat.shape[0] // 32) * mat.shape[1]
or scale.numel() == math.ceil(mat.shape[1] // 32) * mat.shape[0])
and scale.dtype == torch.float8_e8m0fnu
):
return ScalingType.BlockWise1x32, SwizzleType.NO_SWIZZLE