Add NVFP4 two-level scaling to scaled_mm (#165774)
Summary:

* Add second-level scaling dispatch to scaled_mm, tying into optional `alpha` passing
* Add two-level tests

Test Plan:

```
pytest -svv -k "nvfp4_global_scale" test/test_scaled_matmul_cuda.py
```

Signed-off-by: Simon Layton <simonlayton@meta.com>
Pull Request resolved: https://github.com/pytorch/pytorch/pull/165774
Approved by: https://github.com/drisspg
Commit d14cbb4476 (parent f510d0dbc0), committed by PyTorch MergeBot.
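For orientation (not part of the diff): in NVFP4 two-level scaling each operand carries per-block FP8-e4m3 decode scales plus one per-tensor global decode scale. Because the global scales are per-tensor, they factor out of every dot product, so the kernel only needs a single output multiplier `alpha = global_scale_a * global_scale_b`. A minimal sketch of that identity, using plain float tensors as stand-ins for the block-scaled FP4 operands:

```
import torch

# Stand-ins for FP4 data that has already been multiplied by its per-block decode scales.
A_blockscaled = torch.randn(32, 64)
B_blockscaled = torch.randn(16, 64)

# Per-tensor (global) decode scales; the values here are illustrative.
global_scale_a = torch.tensor(0.125)
global_scale_b = torch.tensor(0.5)

# Fully dequantized GEMM ...
full = (global_scale_a * A_blockscaled) @ (global_scale_b * B_blockscaled).t()

# ... equals the block-scaled GEMM scaled once by alpha, which is what the kernel applies.
alpha = global_scale_a * global_scale_b
fused = alpha * (A_blockscaled @ B_blockscaled.t())

torch.testing.assert_close(full, fused)
```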
@@ -2322,12 +2322,23 @@ _scaled_nvfp4_nvfp4(
     const Tensor& scale_b, const SwizzleType swizzle_b,
     const std::optional<Tensor>& bias,
     const c10::ScalarType out_dtype,
-    const bool single_scale,
-    Tensor& out) {
+    Tensor& out,
+    const std::optional<Tensor>& global_scale_a = std::nullopt,
+    const std::optional<Tensor>& global_scale_b = std::nullopt) {
 #ifdef USE_ROCM
   TORCH_CHECK_NOT_IMPLEMENTED(false, "NVFP4 scaling not supported on ROCM");
 #endif
-  TORCH_CHECK_VALUE(single_scale, "Only single-scaled NVFP4 currently supported");
+  std::optional<Tensor> alpha = std::nullopt;
+  // Note: "Or" here means that if only one scale is passed, we check for the other. Otherwise,
+  // if this is "And" we would silently do nothing in the case where one global scale is
+  // passed and not the other.
+  if (global_scale_a.has_value() || global_scale_b.has_value()) {
+    TORCH_CHECK_VALUE(global_scale_a.has_value(),
+                      "For two-level-scaled NVFP4, global_scale_a must have a value");
+    TORCH_CHECK_VALUE(global_scale_b.has_value(),
+                      "For two-level-scaled NVFP4, global_scale_b must have a value");
+    alpha = global_scale_a.value().mul(global_scale_b.value());
+  }
   // Restrictions:
   // A, B are FP4, scales are e8m0, A: shape K//32, B: K, N//32
   // Scales must be swizzled
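A small Python restatement (illustrative only, not part of the diff) of the validation rule spelled out in the "Or" comment above: passing neither global scale keeps the single-level path (no alpha), passing both combines them, and passing only one is an error rather than being silently ignored.

```
from typing import Optional

def combine_global_scales(global_scale_a: Optional[float],
                          global_scale_b: Optional[float]) -> Optional[float]:
    # Mirrors the C++ check: "or" so that a lone global scale is reported,
    # instead of "and", which would silently drop it.
    if global_scale_a is not None or global_scale_b is not None:
        if global_scale_a is None or global_scale_b is None:
            raise ValueError("For two-level-scaled NVFP4, both global scales must be provided")
        return global_scale_a * global_scale_b   # -> alpha
    return None                                  # single-level: no alpha applied

assert combine_global_scales(None, None) is None
assert combine_global_scales(0.25, 0.5) == 0.125
```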
@@ -2349,7 +2360,7 @@ _scaled_nvfp4_nvfp4(
 
   auto scaling_choice_a = ScalingType::BlockWise1x16;
   auto scaling_choice_b = ScalingType::BlockWise1x16;
 
-  return _scaled_gemm(mat_a, mat_b, scale_a, scale_b, scaling_choice_a, scaling_choice_b, bias, false /* use_fast_accum */, out);
+  return _scaled_gemm(mat_a, mat_b, scale_a, scale_b, scaling_choice_a, scaling_choice_b, bias, false /* use_fast_accum */, out, alpha);
 }
 
@@ -2555,9 +2566,10 @@ _scaled_mm_cuda_v2_out(
   } else if (gemm_impl == ScaledGemmImplementation::MXFP8_MXFP8) {
     return _scaled_mxfp8_mxfp8(mat_a, mat_b, scale_a[0], swizzle_a_enum[0], scale_b[0], swizzle_b_enum[0], bias, out_dtype_, out);
   } else if (gemm_impl == ScaledGemmImplementation::NVFP4_NVFP4) {
-    TORCH_CHECK_NOT_IMPLEMENTED(false, "Only single-scale NVFP4 currently supported");
+    return _scaled_nvfp4_nvfp4(mat_a, mat_b, scale_a[0], swizzle_a_enum[0], scale_b[0], swizzle_b_enum[0], bias, out_dtype_, out,
+                               scale_a[1], scale_b[1]);
   } else if (gemm_impl == ScaledGemmImplementation::NVFP4_NVFP4_SINGLE_SCALE) {
-    return _scaled_nvfp4_nvfp4(mat_a, mat_b, scale_a[0], swizzle_a_enum[0], scale_b[0], swizzle_b_enum[0], bias, out_dtype_, true /* single_scale */, out);
+    return _scaled_nvfp4_nvfp4(mat_a, mat_b, scale_a[0], swizzle_a_enum[0], scale_b[0], swizzle_b_enum[0], bias, out_dtype_, out);
   } else if (gemm_impl == ScaledGemmImplementation::MXFP4_MXFP4) {
     return _scaled_mxfp4_mxfp4(mat_a, mat_b, scale_a[0], swizzle_a_enum[0], scale_b[0], swizzle_b_enum[0], bias, out_dtype_, out);
   } else {
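The dispatcher above distinguishes the two NVFP4 paths by whether a second, tensor-wise scale was supplied for each operand; `NVFP4_NVFP4` forwards `scale_a[1]` / `scale_b[1]` as the global scales, while `NVFP4_NVFP4_SINGLE_SCALE` does not. A hedged sketch of that mapping, inferred from the test added later in this diff rather than from the dispatcher's actual recipe-parsing code (the `ScalingType` enum below is an illustrative stand-in):

```
from enum import Enum, auto

class ScalingType(Enum):  # stand-in for the real recipe enum
    BlockWise1x16 = auto()
    TensorWise = auto()

def pick_nvfp4_impl(scale_recipe_a, scale_recipe_b):
    # Two recipes per operand (block-wise + tensor-wise) -> two-level path,
    # where the second scale of each operand becomes the global scale above.
    two_level = (scale_recipe_a == [ScalingType.BlockWise1x16, ScalingType.TensorWise]
                 and scale_recipe_b == [ScalingType.BlockWise1x16, ScalingType.TensorWise])
    return "NVFP4_NVFP4" if two_level else "NVFP4_NVFP4_SINGLE_SCALE"

assert pick_nvfp4_impl([ScalingType.BlockWise1x16],
                       [ScalingType.BlockWise1x16]) == "NVFP4_NVFP4_SINGLE_SCALE"
```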
@@ -413,6 +413,42 @@ def data_to_nvfp4_scale(x, block_size):
     return scale
 
 
+def data_to_nvfp4_with_global_scale(x, block_size):
+    # Simple (slow) reference implementation of NVFP4 two-level scaling
+    orig_shape = x.shape
+    x = x.reshape(-1, block_size)
+
+    # Per-block amax
+    block_max = torch.amax(torch.abs(x), 1) + 1e-12
+
+    # Per-tensor max
+    global_max = x.abs().max()
+
+    # Constants
+    # Global encoding scale for block-scales
+    S_enc = FP4_MAX_VAL * F8E4M3_MAX_VAL / global_max
+    S_dec = 1. / S_enc
+
+    # Per-block decode-scale
+    S_dec_b = block_max / FP4_MAX_VAL
+
+    # Stored scaled-e4m3 per-block decode scales
+    S_dec_b_e4m3 = (S_dec_b * S_enc).to(torch.float8_e4m3fn)
+
+    # Actual per-block encoding scale
+    S_enc_b = S_enc / S_dec_b_e4m3.float()
+
+    # scale & reshape input, reshape scales
+    x = (S_enc_b.unsqueeze(1) * x).bfloat16().reshape(orig_shape)
+    S_dec_b_e4m3 = S_dec_b_e4m3.reshape(orig_shape[0], -1)
+
+    # cast input
+    x_fp4 = _bfloat16_to_float4_e2m1fn_x2(x)
+
+    # fp4x2, fp8_e4m3, float respectively
+    return x_fp4, S_dec_b_e4m3, S_dec.float()
+
+
 def down_size(size):
     assert size[-1] % 2 == 0, f"{size} last dim not divisible by two"
     return (*size[:-1], size[-1] // 2)
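To make the scale algebra above concrete: because `S_enc_b` is defined against the already-rounded e4m3 block scales, multiplying the encoded values back by `S_dec_b_e4m3` and the global `S_dec` reconstructs the input exactly up to float rounding; the FP4 cast is the only lossy step. A self-contained check of that round trip (ignoring the FP4 cast), assuming a PyTorch build with `torch.float8_e4m3fn` and using the standard format maxima (FP4 e2m1 max 6.0, FP8 e4m3fn max 448.0):

```
import torch

FP4_MAX_VAL = 6.0        # largest magnitude representable in FP4 (e2m1)
F8E4M3_MAX_VAL = 448.0   # largest magnitude representable in float8_e4m3fn

block_size = 16
x = torch.randn(128, block_size, dtype=torch.float32)

block_max = torch.amax(torch.abs(x), 1) + 1e-12
global_max = x.abs().max()

S_enc = FP4_MAX_VAL * F8E4M3_MAX_VAL / global_max       # global encode scale
S_dec = 1.0 / S_enc                                      # global decode scale
S_dec_b = block_max / FP4_MAX_VAL                        # per-block decode scale
S_dec_b_e4m3 = (S_dec_b * S_enc).to(torch.float8_e4m3fn)
S_enc_b = S_enc / S_dec_b_e4m3.float()                   # per-block encode scale

encoded = S_enc_b.unsqueeze(1) * x                       # what would be cast to FP4
decoded = encoded * S_dec_b_e4m3.float().unsqueeze(1) * S_dec

# Without the FP4 cast, two-level scaling round-trips the input.
torch.testing.assert_close(decoded, x, rtol=1e-5, atol=1e-6)
```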
@@ -1254,6 +1290,59 @@ class TestFP8Matmul(TestCase):
         lp_data_expected = torch.tensor([0b10110010], dtype=torch.uint8)
         torch.testing.assert_close(lp_data_actual, lp_data_expected, atol=0, rtol=0)
 
+    @onlyCUDA
+    @unittest.skipIf(not PLATFORM_SUPPORTS_MX_GEMM, mx_skip_msg)
+    @parametrize("mkn", [
+        # Nice shapes
+        (128, 128, 128),
+        (256, 256, 256),
+        (128, 256, 512),
+        (256, 512, 128),
+        (512, 128, 256),
+
+        # Very unbalanced
+        (1023, 64, 48),
+        (31, 1024, 64),
+        (45, 96, 1024),
+
+        # Mixed large and small
+        (2, 1024, 128),
+        (127, 96, 1024),
+        (1025, 128, 96)
+    ], name_fn=lambda mkn: f"{mkn[0]}_{mkn[1]}_{mkn[2]}")
+    def test_blockwise_nvfp4_with_global_scale(self, mkn) -> None:
+        device = 'cuda'
+        M, K, N = mkn
+        BLOCK_SIZE = 16
+        # Note: SQNR target from `test_blockwise_mxfp8_nvfp4_mxfp4_numerics` test
+        approx_match_sqnr_target = 15.8
+
+        A_ref = torch.randn((M, K), device=device, dtype=torch.bfloat16) * 1000
+        B_ref = torch.randn((N, K), device=device, dtype=torch.bfloat16) * 1000
+
+        A, A_scale, A_global_scale = data_to_nvfp4_with_global_scale(A_ref, BLOCK_SIZE)
+        B, B_scale, B_global_scale = data_to_nvfp4_with_global_scale(B_ref, BLOCK_SIZE)
+        A_scale = to_blocked(A_scale)
+        B_scale = to_blocked(B_scale)
+
+        C_ref = A_ref @ B_ref.t()
+
+        C = scaled_mm(
+            A,
+            B.t(),
+            scale_a=[A_scale, A_global_scale],
+            scale_recipe_a=[ScalingType.BlockWise1x16, ScalingType.TensorWise],
+            scale_b=[B_scale, B_global_scale],
+            scale_recipe_b=[ScalingType.BlockWise1x16, ScalingType.TensorWise],
+            swizzle_a=[SwizzleType.SWIZZLE_32_4_4, SwizzleType.NO_SWIZZLE],
+            swizzle_b=[SwizzleType.SWIZZLE_32_4_4, SwizzleType.NO_SWIZZLE],
+            output_dtype=torch.bfloat16,
+        )
+
+        sqnr = compute_error(C_ref, C)
+        assert sqnr.item() > approx_match_sqnr_target
+
     @unittest.skipIf(not PLATFORM_SUPPORTS_MX_GEMM, mx_skip_msg)
     @parametrize("test_case_name", [
         "a_eye_b_eye",
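The pass criterion is a signal-to-quantization-noise ratio of at least 15.8 dB against the bfloat16 reference, the same target used by the existing `test_blockwise_mxfp8_nvfp4_mxfp4_numerics` test. Assuming `compute_error` follows the usual SQNR definition, it amounts to the following (a sketch, not the helper's actual source):

```
import torch

def sqnr_db(ref: torch.Tensor, actual: torch.Tensor) -> torch.Tensor:
    # Usual SQNR definition: signal power over quantization-error power, in dB.
    signal = torch.linalg.norm(ref.float())
    noise = torch.linalg.norm(ref.float() - actual.float())
    return 20 * torch.log10(signal / noise)

# ~15.8 dB means the error norm is roughly 16% of the signal norm, a loose
# "approximately matches" bound appropriate for 4-bit quantized operands.
```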