From d14cbb44760e69b3f2871a1fc428a03ae16a9056 Mon Sep 17 00:00:00 2001
From: Simon Layton
Date: Fri, 17 Oct 2025 23:29:10 +0000
Subject: [PATCH] Add NVFP4 two-level scaling to scaled_mm (#165774)

Summary:

* Add second-level scaling dispatch to scaled_mm, tying into optional `alpha` passing
* Add two-level tests

Test Plan:

```
pytest -svv -k "nvfp4_global_scale" test/test_scaled_matmul_cuda.py
```

Reviewers:

Subscribers:

Tasks:

Tags:

Signed-off-by: Simon Layton
Pull Request resolved: https://github.com/pytorch/pytorch/pull/165774
Approved by: https://github.com/drisspg
---
 aten/src/ATen/native/cuda/Blas.cpp | 24 ++++++--
 test/test_scaled_matmul_cuda.py    | 89 ++++++++++++++++++++++++++++++
 2 files changed, 107 insertions(+), 6 deletions(-)

diff --git a/aten/src/ATen/native/cuda/Blas.cpp b/aten/src/ATen/native/cuda/Blas.cpp
index 4ee35013ab77..68a9582a09c1 100644
--- a/aten/src/ATen/native/cuda/Blas.cpp
+++ b/aten/src/ATen/native/cuda/Blas.cpp
@@ -2322,12 +2322,23 @@ _scaled_nvfp4_nvfp4(
     const Tensor& scale_b, const SwizzleType swizzle_b,
     const std::optional<Tensor>& bias,
     const c10::ScalarType out_dtype,
-    const bool single_scale,
-    Tensor& out) {
+    Tensor& out,
+    const std::optional<Tensor>& global_scale_a = std::nullopt,
+    const std::optional<Tensor>& global_scale_b = std::nullopt) {
 #ifdef USE_ROCM
   TORCH_CHECK_NOT_IMPLEMENTED(false, "NVFP4 scaling not supported on ROCM");
 #endif
-  TORCH_CHECK_VALUE(single_scale, "Only single-scaled NVFP4 currently supported");
+  std::optional<Tensor> alpha = std::nullopt;
+  // Note: "Or" here means that if only one scale is passed, we check for the other. Otherwise,
+  // if this were "And" we would silently do nothing in the case where one global scale is
+  // passed and not the other.
+  if (global_scale_a.has_value() || global_scale_b.has_value()) {
+    TORCH_CHECK_VALUE(global_scale_a.has_value(),
+        "For two-level-scaled NVFP4, global_scale_a must have a value");
+    TORCH_CHECK_VALUE(global_scale_b.has_value(),
+        "For two-level-scaled NVFP4, global_scale_b must have a value");
+    alpha = global_scale_a.value().mul(global_scale_b.value());
+  }
   // Restrictions:
   // A, B are FP4, scales are e8m0, A: shape K//32, B: K, N//32
   // Scales must be swizzled
@@ -2349,7 +2360,7 @@ _scaled_nvfp4_nvfp4(
   auto scaling_choice_a = ScalingType::BlockWise1x16;
   auto scaling_choice_b = ScalingType::BlockWise1x16;
 
-  return _scaled_gemm(mat_a, mat_b, scale_a, scale_b, scaling_choice_a, scaling_choice_b, bias, false /* use_fast_accum */, out);
+  return _scaled_gemm(mat_a, mat_b, scale_a, scale_b, scaling_choice_a, scaling_choice_b, bias, false /* use_fast_accum */, out, alpha);
 }
 
 
@@ -2555,9 +2566,10 @@ _scaled_mm_cuda_v2_out(
   } else if (gemm_impl == ScaledGemmImplementation::MXFP8_MXFP8) {
     return _scaled_mxfp8_mxfp8(mat_a, mat_b, scale_a[0], swizzle_a_enum[0], scale_b[0], swizzle_b_enum[0], bias, out_dtype_, out);
   } else if (gemm_impl == ScaledGemmImplementation::NVFP4_NVFP4) {
-    TORCH_CHECK_NOT_IMPLEMENTED(false, "Only single-scale NVFP4 currently supported");
+    return _scaled_nvfp4_nvfp4(mat_a, mat_b, scale_a[0], swizzle_a_enum[0], scale_b[0], swizzle_b_enum[0], bias, out_dtype_, out,
+        scale_a[1], scale_b[1]);
   } else if (gemm_impl == ScaledGemmImplementation::NVFP4_NVFP4_SINGLE_SCALE) {
-    return _scaled_nvfp4_nvfp4(mat_a, mat_b, scale_a[0], swizzle_a_enum[0], scale_b[0], swizzle_b_enum[0], bias, out_dtype_, true /* single_scale */, out);
+    return _scaled_nvfp4_nvfp4(mat_a, mat_b, scale_a[0], swizzle_a_enum[0], scale_b[0], swizzle_b_enum[0], bias, out_dtype_, out);
   } else if (gemm_impl == ScaledGemmImplementation::MXFP4_MXFP4) {
     return _scaled_mxfp4_mxfp4(mat_a, mat_b, scale_a[0], swizzle_a_enum[0], scale_b[0], swizzle_b_enum[0], bias, out_dtype_, out);
   } else {
diff --git a/test/test_scaled_matmul_cuda.py b/test/test_scaled_matmul_cuda.py
index d57b1535d02f..7dd6f10d3a82 100644
--- a/test/test_scaled_matmul_cuda.py
+++ b/test/test_scaled_matmul_cuda.py
@@ -413,6 +413,42 @@ def data_to_nvfp4_scale(x, block_size):
     return scale
 
 
+def data_to_nvfp4_with_global_scale(x, block_size):
+    # Simple (slow) reference implementation of NVFP4 two-level scaling
+    orig_shape = x.shape
+    x = x.reshape(-1, block_size)
+
+    # Per-block amax
+    block_max = torch.amax(torch.abs(x), 1) + 1e-12
+
+    # Per-tensor max
+    global_max = x.abs().max()
+
+    # Constants
+    # Global encoding scale for block-scales
+    S_enc = FP4_MAX_VAL * F8E4M3_MAX_VAL / global_max
+    S_dec = 1. / S_enc
+
+    # Per-block decode-scale
+    S_dec_b = block_max / FP4_MAX_VAL
+
+    # Stored scaled-e4m3 per-block decode scales
+    S_dec_b_e4m3 = (S_dec_b * S_enc).to(torch.float8_e4m3fn)
+
+    # Actual per-block encoding scale
+    S_enc_b = S_enc / S_dec_b_e4m3.float()
+
+    # scale & reshape input, reshape scales
+    x = (S_enc_b.unsqueeze(1) * x).bfloat16().reshape(orig_shape)
+    S_dec_b_e4m3 = S_dec_b_e4m3.reshape(orig_shape[0], -1)
+
+    # cast input
+    x_fp4 = _bfloat16_to_float4_e2m1fn_x2(x)
+
+    # fp4x2, fp8_e4m3, float respectively
+    return x_fp4, S_dec_b_e4m3, S_dec.float()
+
+
 def down_size(size):
     assert size[-1] % 2 == 0, f"{size} last dim not divisible by two"
     return (*size[:-1], size[-1] // 2)
@@ -1254,6 +1290,59 @@ class TestFP8Matmul(TestCase):
         lp_data_expected = torch.tensor([0b10110010], dtype=torch.uint8)
         torch.testing.assert_close(lp_data_actual, lp_data_expected, atol=0, rtol=0)
 
+
+    @onlyCUDA
+    @unittest.skipIf(not PLATFORM_SUPPORTS_MX_GEMM, mx_skip_msg)
+    @parametrize("mkn", [
+        # Nice shapes
+        (128, 128, 128),
+        (256, 256, 256),
+        (128, 256, 512),
+        (256, 512, 128),
+        (512, 128, 256),
+
+        # Very unbalanced
+        (1023, 64, 48),
+        (31, 1024, 64),
+        (45, 96, 1024),
+
+        # Mixed large and small
+        (2, 1024, 128),
+        (127, 96, 1024),
+        (1025, 128, 96)
+    ], name_fn=lambda mkn: f"{mkn[0]}_{mkn[1]}_{mkn[2]}")
+    def test_blockwise_nvfp4_with_global_scale(self, mkn) -> None:
+        device = 'cuda'
+        M, K, N = mkn
+        BLOCK_SIZE = 16
+        # Note: SQNR target from `test_blockwise_mxfp8_nvfp4_mxfp4_numerics` test
+        approx_match_sqnr_target = 15.8
+
+        A_ref = torch.randn((M, K), device=device, dtype=torch.bfloat16) * 1000
+        B_ref = torch.randn((N, K), device=device, dtype=torch.bfloat16) * 1000
+
+        A, A_scale, A_global_scale = data_to_nvfp4_with_global_scale(A_ref, BLOCK_SIZE)
+        B, B_scale, B_global_scale = data_to_nvfp4_with_global_scale(B_ref, BLOCK_SIZE)
+        A_scale = to_blocked(A_scale)
+        B_scale = to_blocked(B_scale)
+
+        C_ref = A_ref @ B_ref.t()
+
+        C = scaled_mm(
+            A,
+            B.t(),
+            scale_a=[A_scale, A_global_scale],
+            scale_recipe_a=[ScalingType.BlockWise1x16, ScalingType.TensorWise],
+            scale_b=[B_scale, B_global_scale],
+            scale_recipe_b=[ScalingType.BlockWise1x16, ScalingType.TensorWise],
+            swizzle_a=[SwizzleType.SWIZZLE_32_4_4, SwizzleType.NO_SWIZZLE],
+            swizzle_b=[SwizzleType.SWIZZLE_32_4_4, SwizzleType.NO_SWIZZLE],
+            output_dtype=torch.bfloat16,
+        )
+
+        sqnr = compute_error(C_ref, C)
+        assert sqnr.item() > approx_match_sqnr_target
+
     @unittest.skipIf(not PLATFORM_SUPPORTS_MX_GEMM, mx_skip_msg)
     @parametrize("test_case_name", [
         "a_eye_b_eye",