Add NVFP4 two-level scaling to scaled_mm (#165774)
Summary:
* Add second-level scaling dispatch to scaled_mm, tying into optional `alpha` passing
* Add two-level tests

Test Plan:
```
pytest -svv -k "nvfp4_global_scale" test/test_scaled_matmul_cuda.py
```

Reviewers:

Subscribers:

Tasks:

Tags:

Signed-off-by: Simon Layton <simonlayton@meta.com>

Pull Request resolved: https://github.com/pytorch/pytorch/pull/165774
Approved by: https://github.com/drisspg
committed by PyTorch MergeBot
parent f510d0dbc0
commit d14cbb4476
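For context on what the two levels buy: the per-block e4m3 scales decode local dynamic range, while the two per-tensor global scales collapse into a single multiplier on the GEMM output, which is the optional `alpha` the summary refers to. A minimal emulation of that composition, using stand-in tensors rather than real fp4 data (all names here are illustrative, not the kernel's API):

```python
import torch

# Stand-in "decoded" fp4 values and scales; BS is the 1x16 block size.
M, K, N, BS = 32, 64, 16, 16
A_codes = torch.randn(M, K)
B_codes = torch.randn(N, K)
S_a = torch.rand(M, K // BS) + 0.5   # per-block decode scales for A
S_b = torch.rand(N, K // BS) + 0.5   # per-block decode scales for B
a_g, b_g = 0.01, 0.02                # per-tensor global decode scales

# Block scales decode each 1x16 block; the two global scales fold into
# one output alpha, which is how the second level ties into `alpha`.
A = A_codes * S_a.repeat_interleave(BS, dim=1)
B = B_codes * S_b.repeat_interleave(BS, dim=1)
C = (A @ B.t()) * (a_g * b_g)
```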
@@ -413,6 +413,42 @@ def data_to_nvfp4_scale(x, block_size):
    return scale


def data_to_nvfp4_with_global_scale(x, block_size):
    # Simple (slow) reference implementation of NVFP4 two-level scaling
    orig_shape = x.shape
    x = x.reshape(-1, block_size)

    # Per-block amax
    block_max = torch.amax(torch.abs(x), 1) + 1e-12

    # Per-tensor max
    global_max = x.abs().max()

    # Constants
    # Global encoding scale for block-scales
    S_enc = FP4_MAX_VAL * F8E4M3_MAX_VAL / global_max
    S_dec = 1. / S_enc

    # Per-block decode scale
    S_dec_b = block_max / FP4_MAX_VAL

    # Stored scaled-e4m3 per-block decode scales
    S_dec_b_e4m3 = (S_dec_b * S_enc).to(torch.float8_e4m3fn)

    # Actual per-block encoding scale
    S_enc_b = S_enc / S_dec_b_e4m3.float()

    # Scale & reshape input, reshape scales
    x = (S_enc_b.unsqueeze(1) * x).bfloat16().reshape(orig_shape)
    S_dec_b_e4m3 = S_dec_b_e4m3.reshape(orig_shape[0], -1)

    # Cast input
    x_fp4 = _bfloat16_to_float4_e2m1fn_x2(x)

    # fp4x2, fp8_e4m3, float respectively
    return x_fp4, S_dec_b_e4m3, S_dec.float()

def down_size(size):
    assert size[-1] % 2 == 0, f"{size} last dim not divisible by two"
    return (*size[:-1], size[-1] // 2)
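A quick standalone check of the scale algebra in `data_to_nvfp4_with_global_scale`: the stored e4m3 block scale times the global decode scale should recover the true per-block decode scale, up to e4m3 rounding. This sketch assumes the standard format maxima FP4_MAX_VAL = 6.0 (e2m1) and F8E4M3_MAX_VAL = 448.0 (float8_e4m3fn); the test file defines these constants elsewhere:

```python
import torch

FP4_MAX_VAL = 6.0       # assumed: max magnitude of fp4 e2m1
F8E4M3_MAX_VAL = 448.0  # assumed: max magnitude of float8_e4m3fn

x = torch.randn(8, 16) * 1000
block_max = torch.amax(torch.abs(x), 1) + 1e-12
global_max = x.abs().max()

S_enc = FP4_MAX_VAL * F8E4M3_MAX_VAL / global_max
S_dec_b = block_max / FP4_MAX_VAL
S_dec_b_e4m3 = (S_dec_b * S_enc).to(torch.float8_e4m3fn)

# Decode path: e4m3 block scale * global decode scale ~= true block scale
recovered = S_dec_b_e4m3.float() * (1.0 / S_enc)
assert torch.allclose(recovered, S_dec_b, rtol=0.1)
```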
@@ -1254,6 +1290,59 @@ class TestFP8Matmul(TestCase):
        lp_data_expected = torch.tensor([0b10110010], dtype=torch.uint8)
        torch.testing.assert_close(lp_data_actual, lp_data_expected, atol=0, rtol=0)

    @onlyCUDA
    @unittest.skipIf(not PLATFORM_SUPPORTS_MX_GEMM, mx_skip_msg)
    @parametrize("mkn", [
        # Nice shapes
        (128, 128, 128),
        (256, 256, 256),
        (128, 256, 512),
        (256, 512, 128),
        (512, 128, 256),

        # Very unbalanced
        (1023, 64, 48),
        (31, 1024, 64),
        (45, 96, 1024),

        # Mixed large and small
        (2, 1024, 128),
        (127, 96, 1024),
        (1025, 128, 96)
    ], name_fn=lambda mkn: f"{mkn[0]}_{mkn[1]}_{mkn[2]}")
    def test_blockwise_nvfp4_with_global_scale(self, mkn) -> None:
        device = 'cuda'
        M, K, N = mkn
        BLOCK_SIZE = 16
        # Note: SQNR target from `test_blockwise_mxfp8_nvfp4_mxfp4_numerics` test
        approx_match_sqnr_target = 15.8

        A_ref = torch.randn((M, K), device=device, dtype=torch.bfloat16) * 1000
        B_ref = torch.randn((N, K), device=device, dtype=torch.bfloat16) * 1000

        A, A_scale, A_global_scale = data_to_nvfp4_with_global_scale(A_ref, BLOCK_SIZE)
        B, B_scale, B_global_scale = data_to_nvfp4_with_global_scale(B_ref, BLOCK_SIZE)
        A_scale = to_blocked(A_scale)
        B_scale = to_blocked(B_scale)

        C_ref = A_ref @ B_ref.t()

        C = scaled_mm(
            A,
            B.t(),
            scale_a=[A_scale, A_global_scale],
            scale_recipe_a=[ScalingType.BlockWise1x16, ScalingType.TensorWise],
            scale_b=[B_scale, B_global_scale],
            scale_recipe_b=[ScalingType.BlockWise1x16, ScalingType.TensorWise],
            swizzle_a=[SwizzleType.SWIZZLE_32_4_4, SwizzleType.NO_SWIZZLE],
            swizzle_b=[SwizzleType.SWIZZLE_32_4_4, SwizzleType.NO_SWIZZLE],
            output_dtype=torch.bfloat16,
        )

        sqnr = compute_error(C_ref, C)
        assert sqnr.item() > approx_match_sqnr_target

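The pass criterion above depends on `compute_error` reporting an SQNR in decibels; a minimal sketch under the conventional definition (the actual helper in the test utilities may differ in detail):

```python
import torch

def sqnr_db(ref: torch.Tensor, actual: torch.Tensor) -> torch.Tensor:
    # Signal-to-quantization-noise ratio in dB: higher means closer.
    return 20 * torch.log10(ref.norm() / (ref - actual).norm())

ref = torch.randn(256, 256)
approx = ref + 0.12 * torch.randn_like(ref)
print(sqnr_db(ref, approx))  # roughly 18 dB; the test requires > 15.8
```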
    @unittest.skipIf(not PLATFORM_SUPPORTS_MX_GEMM, mx_skip_msg)
    @parametrize("test_case_name", [
        "a_eye_b_eye",