mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 12:54:11 +08:00
test_scaled_matmul_cuda: fix infer_scale_swizzle (#165788)
Extend #165747 fix to other cases. Add parentheses to clarify operator precedence. Pull Request resolved: https://github.com/pytorch/pytorch/pull/165788 Approved by: https://github.com/jeffdaily, https://github.com/slayton58
This commit is contained in:
committed by
PyTorch MergeBot
parent
8139f33fa5
commit
8951df03de
@ -128,10 +128,10 @@ def infer_scale_swizzle(mat, scale):
|
||||
# deepgemm 1x128 / 128x1
|
||||
if len(scale.shape) > 1:
|
||||
if (
|
||||
scale.shape[0] == mat.shape[0]
|
||||
and scale.shape[1] == math.ceil(mat.shape[1] // 128)
|
||||
or scale.shape[1] == mat.shape[1]
|
||||
and scale.shape[0] == math.ceil(mat.shape[0] // 128)
|
||||
(scale.shape[0] == mat.shape[0]
|
||||
and scale.shape[1] == math.ceil(mat.shape[1] // 128))
|
||||
or (scale.shape[1] == mat.shape[1]
|
||||
and scale.shape[0] == math.ceil(mat.shape[0] // 128))
|
||||
):
|
||||
return ScalingType.BlockWise1x128, SwizzleType.NO_SWIZZLE
|
||||
|
||||
@ -143,10 +143,10 @@ def infer_scale_swizzle(mat, scale):
|
||||
|
||||
# NVFP4
|
||||
if (
|
||||
scale.numel()
|
||||
== round_up(mat.shape[0], 128) * round_up(math.ceil(2 * mat.shape[1] // 16), 4)
|
||||
or scale.numel()
|
||||
== round_up(mat.shape[1], 128) * round_up(math.ceil(2 * mat.shape[0] // 16), 4)
|
||||
(scale.numel()
|
||||
== round_up(mat.shape[0], 128) * round_up(math.ceil(2 * mat.shape[1] // 16), 4)
|
||||
or scale.numel()
|
||||
== round_up(mat.shape[1], 128) * round_up(math.ceil(2 * mat.shape[0] // 16), 4))
|
||||
and mat.dtype == torch.float4_e2m1fn_x2
|
||||
and scale.dtype == torch.float8_e4m3fn
|
||||
):
|
||||
@ -164,10 +164,10 @@ def infer_scale_swizzle(mat, scale):
|
||||
if not torch.version.hip:
|
||||
# MXFP8 w/ swizzle
|
||||
if (
|
||||
scale.numel()
|
||||
== round_up(mat.shape[0], 128) * round_up(math.ceil(mat.shape[1] // 32), 4)
|
||||
or scale.numel()
|
||||
== round_up(mat.shape[1], 128) * round_up(math.ceil(mat.shape[0] // 32), 4)
|
||||
(scale.numel()
|
||||
== round_up(mat.shape[0], 128) * round_up(math.ceil(mat.shape[1] // 32), 4)
|
||||
or scale.numel()
|
||||
== round_up(mat.shape[1], 128) * round_up(math.ceil(mat.shape[0] // 32), 4))
|
||||
and scale.dtype == torch.float8_e8m0fnu
|
||||
):
|
||||
return ScalingType.BlockWise1x32, SwizzleType.SWIZZLE_32_4_4
|
||||
@ -175,8 +175,8 @@ def infer_scale_swizzle(mat, scale):
|
||||
else:
|
||||
# MXFP8 w/o swizzle
|
||||
if (
|
||||
scale.numel() == math.ceil(mat.shape[0] // 32) * mat.shape[1]
|
||||
or scale.numel() == math.ceil(mat.shape[1] // 32) * mat.shape[0]
|
||||
(scale.numel() == math.ceil(mat.shape[0] // 32) * mat.shape[1]
|
||||
or scale.numel() == math.ceil(mat.shape[1] // 32) * mat.shape[0])
|
||||
and scale.dtype == torch.float8_e8m0fnu
|
||||
):
|
||||
return ScalingType.BlockWise1x32, SwizzleType.NO_SWIZZLE
|
||||
|
Reference in New Issue
Block a user