Add NVFP4 two-level scaling to scaled_mm (#165774)

Summary:

* Add second-level scaling dispatch to scaled_mm, tying into optional `alpha` passing (sketched below)
* Add two-level scaling tests
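
With two-level ("global scale") NVFP4, each operand carries per-block FP8-e4m3 decode scales plus one per-tensor decode scale; the blockwise GEMM applies the block scales, and the two per-tensor scales fold into the epilogue as a single multiplier, `alpha = global_scale_a * global_scale_b`. A minimal sketch of that folding, with illustrative names and values (not the kernel code):

```
# Illustrative only: folding the two per-tensor "global" decode scales into one
# epilogue alpha matches scaling each operand up front.
import torch

a_blk = torch.randn(64, 32, dtype=torch.float64)  # stands in for the block-decoded A operand
b_blk = torch.randn(48, 32, dtype=torch.float64)  # stands in for the block-decoded B operand
gs_a = torch.tensor(0.37, dtype=torch.float64)    # hypothetical per-tensor decode scale for A
gs_b = torch.tensor(2.40, dtype=torch.float64)    # hypothetical per-tensor decode scale for B

alpha = gs_a.mul(gs_b)  # the product _scaled_nvfp4_nvfp4 now forwards to _scaled_gemm
out_alpha = alpha * (a_blk @ b_blk.t())
out_ref = (gs_a * a_blk) @ (gs_b * b_blk).t()
torch.testing.assert_close(out_alpha, out_ref)
```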

Test Plan:

```
pytest -svv -k "nvfp4_global_scale" test/test_scaled_matmul_cuda.py
```
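
For reference, the two-level path is driven through the v2 interface by passing a `[block_scale, global_scale]` pair per operand. Condensed from the test added below (`data_to_nvfp4_with_global_scale`, `to_blocked`, `scaled_mm`, `ScalingType`, and `SwizzleType` are the names used in `test/test_scaled_matmul_cuda.py`; run inside that module's context):

```
M, K, N, BLOCK_SIZE = 256, 256, 256, 16
A_ref = torch.randn((M, K), device="cuda", dtype=torch.bfloat16)
B_ref = torch.randn((N, K), device="cuda", dtype=torch.bfloat16)

A, A_scale, A_global_scale = data_to_nvfp4_with_global_scale(A_ref, BLOCK_SIZE)
B, B_scale, B_global_scale = data_to_nvfp4_with_global_scale(B_ref, BLOCK_SIZE)

C = scaled_mm(
    A, B.t(),
    scale_a=[to_blocked(A_scale), A_global_scale],
    scale_recipe_a=[ScalingType.BlockWise1x16, ScalingType.TensorWise],
    scale_b=[to_blocked(B_scale), B_global_scale],
    scale_recipe_b=[ScalingType.BlockWise1x16, ScalingType.TensorWise],
    swizzle_a=[SwizzleType.SWIZZLE_32_4_4, SwizzleType.NO_SWIZZLE],
    swizzle_b=[SwizzleType.SWIZZLE_32_4_4, SwizzleType.NO_SWIZZLE],
    output_dtype=torch.bfloat16,
)
```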

Reviewers:

Subscribers:

Tasks:

Tags:
Signed-off-by: Simon Layton <simonlayton@meta.com>
Pull Request resolved: https://github.com/pytorch/pytorch/pull/165774
Approved by: https://github.com/drisspg
Author: Simon Layton <simonlayton@meta.com>
Date: 2025-10-17 23:29:10 +00:00
Committed by: PyTorch MergeBot
Commit: d14cbb4476 (parent: f510d0dbc0)
2 changed files with 107 additions and 6 deletions


@@ -2322,12 +2322,23 @@ _scaled_nvfp4_nvfp4(
     const Tensor& scale_b, const SwizzleType swizzle_b,
     const std::optional<Tensor>& bias,
     const c10::ScalarType out_dtype,
-    const bool single_scale,
-    Tensor& out) {
+    Tensor& out,
+    const std::optional<Tensor>& global_scale_a = std::nullopt,
+    const std::optional<Tensor>& global_scale_b = std::nullopt) {
 #ifdef USE_ROCM
   TORCH_CHECK_NOT_IMPLEMENTED(false, "NVFP4 scaling not supported on ROCM");
 #endif
-  TORCH_CHECK_VALUE(single_scale, "Only single-scaled NVFP4 currently supported");
+  std::optional<Tensor> alpha = std::nullopt;
+  // Note: we use "or" here so that passing only one global scale still trips the checks
+  // below; with "and" we would silently do nothing when one global scale is passed
+  // and the other is not.
+  if (global_scale_a.has_value() || global_scale_b.has_value()) {
+    TORCH_CHECK_VALUE(global_scale_a.has_value(),
+        "For two-level-scaled NVFP4, global_scale_a must have a value");
+    TORCH_CHECK_VALUE(global_scale_b.has_value(),
+        "For two-level-scaled NVFP4, global_scale_b must have a value");
+    alpha = global_scale_a.value().mul(global_scale_b.value());
+  }
   // Restrictions:
   // A, B are FP4, scales are e8m0, A: shape K//32, B: K, N//32
   // Scales must be swizzled
@@ -2349,7 +2360,7 @@ _scaled_nvfp4_nvfp4(
   auto scaling_choice_a = ScalingType::BlockWise1x16;
   auto scaling_choice_b = ScalingType::BlockWise1x16;
-  return _scaled_gemm(mat_a, mat_b, scale_a, scale_b, scaling_choice_a, scaling_choice_b, bias, false /* use_fast_accum */, out);
+  return _scaled_gemm(mat_a, mat_b, scale_a, scale_b, scaling_choice_a, scaling_choice_b, bias, false /* use_fast_accum */, out, alpha);
 }
@@ -2555,9 +2566,10 @@ _scaled_mm_cuda_v2_out(
   } else if (gemm_impl == ScaledGemmImplementation::MXFP8_MXFP8) {
     return _scaled_mxfp8_mxfp8(mat_a, mat_b, scale_a[0], swizzle_a_enum[0], scale_b[0], swizzle_b_enum[0], bias, out_dtype_, out);
   } else if (gemm_impl == ScaledGemmImplementation::NVFP4_NVFP4) {
-    TORCH_CHECK_NOT_IMPLEMENTED(false, "Only single-scale NVFP4 currently supported");
+    return _scaled_nvfp4_nvfp4(mat_a, mat_b, scale_a[0], swizzle_a_enum[0], scale_b[0], swizzle_b_enum[0], bias, out_dtype_, out,
+                               scale_a[1], scale_b[1]);
   } else if (gemm_impl == ScaledGemmImplementation::NVFP4_NVFP4_SINGLE_SCALE) {
-    return _scaled_nvfp4_nvfp4(mat_a, mat_b, scale_a[0], swizzle_a_enum[0], scale_b[0], swizzle_b_enum[0], bias, out_dtype_, true /* single_scale */, out);
+    return _scaled_nvfp4_nvfp4(mat_a, mat_b, scale_a[0], swizzle_a_enum[0], scale_b[0], swizzle_b_enum[0], bias, out_dtype_, out);
   } else if (gemm_impl == ScaledGemmImplementation::MXFP4_MXFP4) {
     return _scaled_mxfp4_mxfp4(mat_a, mat_b, scale_a[0], swizzle_a_enum[0], scale_b[0], swizzle_b_enum[0], bias, out_dtype_, out);
   } else {


@@ -413,6 +413,42 @@ def data_to_nvfp4_scale(x, block_size):
     return scale
 
+
+def data_to_nvfp4_with_global_scale(x, block_size):
+    # Simple (slow) reference implementation of NVFP4 two-level scaling
+    orig_shape = x.shape
+    x = x.reshape(-1, block_size)
+    # Per-block amax
+    block_max = torch.amax(torch.abs(x), 1) + 1e-12
+    # Per-tensor max
+    global_max = x.abs().max()
+    # Constants
+    # Global encoding scale for block-scales
+    S_enc = FP4_MAX_VAL * F8E4M3_MAX_VAL / global_max
+    S_dec = 1. / S_enc
+    # Per-block decode-scale
+    S_dec_b = block_max / FP4_MAX_VAL
+    # Stored scaled-e4m3 per-block decode scales
+    S_dec_b_e4m3 = (S_dec_b * S_enc).to(torch.float8_e4m3fn)
+    # Actual per-block encoding scale
+    S_enc_b = S_enc / S_dec_b_e4m3.float()
+    # Scale & reshape input, reshape scales
+    x = (S_enc_b.unsqueeze(1) * x).bfloat16().reshape(orig_shape)
+    S_dec_b_e4m3 = S_dec_b_e4m3.reshape(orig_shape[0], -1)
+    # Cast input
+    x_fp4 = _bfloat16_to_float4_e2m1fn_x2(x)
+    # fp4x2, fp8_e4m3, float respectively
+    return x_fp4, S_dec_b_e4m3, S_dec.float()
+
+
 def down_size(size):
     assert size[-1] % 2 == 0, f"{size} last dim not divisible by two"
     return (*size[:-1], size[-1] // 2)
@@ -1254,6 +1290,59 @@ class TestFP8Matmul(TestCase):
         lp_data_expected = torch.tensor([0b10110010], dtype=torch.uint8)
         torch.testing.assert_close(lp_data_actual, lp_data_expected, atol=0, rtol=0)
 
+    @onlyCUDA
+    @unittest.skipIf(not PLATFORM_SUPPORTS_MX_GEMM, mx_skip_msg)
+    @parametrize("mkn", [
+        # Nice shapes
+        (128, 128, 128),
+        (256, 256, 256),
+        (128, 256, 512),
+        (256, 512, 128),
+        (512, 128, 256),
+
+        # Very unbalanced
+        (1023, 64, 48),
+        (31, 1024, 64),
+        (45, 96, 1024),
+
+        # Mixed large and small
+        (2, 1024, 128),
+        (127, 96, 1024),
+        (1025, 128, 96)
+    ], name_fn=lambda mkn: f"{mkn[0]}_{mkn[1]}_{mkn[2]}")
+    def test_blockwise_nvfp4_with_global_scale(self, mkn) -> None:
+        device = 'cuda'
+        M, K, N = mkn
+        BLOCK_SIZE = 16
+        # Note: SQNR target from `test_blockwise_mxfp8_nvfp4_mxfp4_numerics` test
+        approx_match_sqnr_target = 15.8
+
+        A_ref = torch.randn((M, K), device=device, dtype=torch.bfloat16) * 1000
+        B_ref = torch.randn((N, K), device=device, dtype=torch.bfloat16) * 1000
+
+        A, A_scale, A_global_scale = data_to_nvfp4_with_global_scale(A_ref, BLOCK_SIZE)
+        B, B_scale, B_global_scale = data_to_nvfp4_with_global_scale(B_ref, BLOCK_SIZE)
+        A_scale = to_blocked(A_scale)
+        B_scale = to_blocked(B_scale)
+
+        C_ref = A_ref @ B_ref.t()
+
+        C = scaled_mm(
+            A,
+            B.t(),
+            scale_a=[A_scale, A_global_scale],
+            scale_recipe_a=[ScalingType.BlockWise1x16, ScalingType.TensorWise],
+            scale_b=[B_scale, B_global_scale],
+            scale_recipe_b=[ScalingType.BlockWise1x16, ScalingType.TensorWise],
+            swizzle_a=[SwizzleType.SWIZZLE_32_4_4, SwizzleType.NO_SWIZZLE],
+            swizzle_b=[SwizzleType.SWIZZLE_32_4_4, SwizzleType.NO_SWIZZLE],
+            output_dtype=torch.bfloat16,
+        )
+
+        sqnr = compute_error(C_ref, C)
+        assert sqnr.item() > approx_match_sqnr_target
+
     @unittest.skipIf(not PLATFORM_SUPPORTS_MX_GEMM, mx_skip_msg)
     @parametrize("test_case_name", [
         "a_eye_b_eye",