Add NVFP4 two-level scaling to scaled_mm (#165774)
Summary:

* Add second-level scaling dispatch to scaled_mm, tying into optional `alpha` passing
* Add two-level tests

Test Plan:

```
pytest -svv -k "nvfp4_global_scale" test/test_scaled_matmul_cuda.py
```

Signed-off-by: Simon Layton <simonlayton@meta.com>

Pull Request resolved: https://github.com/pytorch/pytorch/pull/165774
Approved by: https://github.com/drisspg
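To make the intent concrete before the diff: in NVFP4's two-level scheme each FP4 element is decoded through a per-block scale plus a per-tensor global scale, and this patch folds the two operands' global scales into the single `alpha` applied once to the GEMM result. A minimal numeric sketch of that folding, with plain floats and illustrative names only (not kernel code):

```cpp
// Why the two per-tensor global scales can fold into one GEMM-level alpha.
// Assumes the usual two-level decoding: value = fp4_code * block_scale * global_scale.
#include <cstdio>

int main() {
  // One "element" from each operand, already decoded to its block-scaled value.
  float a_block_scaled = 0.5f;   // fp4 code * per-block scale for A
  float b_block_scaled = 2.0f;   // fp4 code * per-block scale for B
  float global_scale_a = 3.0f;   // second-level (per-tensor) scale for A
  float global_scale_b = 0.25f;  // second-level (per-tensor) scale for B

  // Applying each global scale to its own operand before multiplying...
  float full = (a_block_scaled * global_scale_a) * (b_block_scaled * global_scale_b);

  // ...equals one block-scaled multiply rescaled once by
  // alpha = global_scale_a * global_scale_b (what the patch computes).
  float alpha  = global_scale_a * global_scale_b;
  float folded = (a_block_scaled * b_block_scaled) * alpha;

  printf("full=%f folded=%f\n", full, folded);  // identical
  return 0;
}
```

This is why the helper below only needs `alpha = global_scale_a * global_scale_b` once per GEMM, rather than rescaling either operand tensor.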
Committed by: PyTorch MergeBot
Commit: d14cbb4476 (parent: f510d0dbc0)
```diff
@@ -2322,12 +2322,23 @@ _scaled_nvfp4_nvfp4(
     const Tensor& scale_b, const SwizzleType swizzle_b,
     const std::optional<Tensor>& bias,
     const c10::ScalarType out_dtype,
-    const bool single_scale,
-    Tensor& out) {
+    Tensor& out,
+    const std::optional<Tensor>& global_scale_a = std::nullopt,
+    const std::optional<Tensor>& global_scale_b = std::nullopt) {
 #ifdef USE_ROCM
   TORCH_CHECK_NOT_IMPLEMENTED(false, "NVFP4 scaling not supported on ROCM");
 #endif
-  TORCH_CHECK_VALUE(single_scale, "Only single-scaled NVFP4 currently supported");
+  std::optional<Tensor> alpha = std::nullopt;
+  // Note: "Or" here means that if only one scale is passed, we check for the other. Otherwise,
+  // if this is "And" we would silently do nothing in the case where one global scale is
+  // passed and not the other.
+  if (global_scale_a.has_value() || global_scale_b.has_value()) {
+    TORCH_CHECK_VALUE(global_scale_a.has_value(),
+        "For two-level-scaled NVFP4, global_scale_a must have a value");
+    TORCH_CHECK_VALUE(global_scale_b.has_value(),
+        "For two-level-scaled NVFP4, global_scale_b must have a value");
+    alpha = global_scale_a.value().mul(global_scale_b.value());
+  }
   // Restrictions:
   // A, B are FP4, scales are e8m0, A: shape K//32, B: K, N//32
   // Scales must be swizzled
```
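The `||`-then-check-both pattern above is easy to misread, so here is a standalone distillation of the same all-or-nothing validation, with `std::optional<float>` standing in for `Tensor` (function and argument names are hypothetical):

```cpp
// Entering the branch on "a || b" and then asserting both lets a caller who
// passes only one global scale fail loudly, instead of silently falling back
// to single-level behavior (which is what an "&&" guard would do).
#include <optional>
#include <stdexcept>

std::optional<float> fold_global_scales(std::optional<float> gs_a,
                                        std::optional<float> gs_b) {
  if (gs_a.has_value() || gs_b.has_value()) {
    if (!gs_a.has_value() || !gs_b.has_value()) {
      throw std::invalid_argument("two-level NVFP4 needs both global scales");
    }
    return gs_a.value() * gs_b.value();  // alpha
  }
  return std::nullopt;  // single-level: no alpha
}
```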
```diff
@@ -2349,7 +2360,7 @@ _scaled_nvfp4_nvfp4(
 
   auto scaling_choice_a = ScalingType::BlockWise1x16;
   auto scaling_choice_b = ScalingType::BlockWise1x16;
-  return _scaled_gemm(mat_a, mat_b, scale_a, scale_b, scaling_choice_a, scaling_choice_b, bias, false /* use_fast_accum */, out);
+  return _scaled_gemm(mat_a, mat_b, scale_a, scale_b, scaling_choice_a, scaling_choice_b, bias, false /* use_fast_accum */, out, alpha);
 }
```
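The only change in this hunk is threading the computed `alpha` through to `_scaled_gemm` as a trailing optional argument. Assuming it carries the conventional BLAS-style meaning D = alpha * (A @ B) + bias — an assumption about `_scaled_gemm`'s internals, not something this diff shows — a one-element sketch of the epilogue that implies:

```cpp
// Illustrative only: the conventional BLAS-style epilogue an optional alpha
// suggests. Not _scaled_gemm's actual body.
#include <optional>

float gemm_1x1(float a, float b,
               std::optional<float> bias,
               std::optional<float> alpha) {
  float acc = a * b;                 // the (block-scaled) product
  acc *= alpha.value_or(1.0f);       // absent alpha behaves as a scale of 1
  return acc + bias.value_or(0.0f);  // absent bias adds nothing
}
```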
```diff
@@ -2555,9 +2566,10 @@ _scaled_mm_cuda_v2_out(
   } else if (gemm_impl == ScaledGemmImplementation::MXFP8_MXFP8) {
     return _scaled_mxfp8_mxfp8(mat_a, mat_b, scale_a[0], swizzle_a_enum[0], scale_b[0], swizzle_b_enum[0], bias, out_dtype_, out);
   } else if (gemm_impl == ScaledGemmImplementation::NVFP4_NVFP4) {
-    TORCH_CHECK_NOT_IMPLEMENTED(false, "Only single-scale NVFP4 currently supported");
+    return _scaled_nvfp4_nvfp4(mat_a, mat_b, scale_a[0], swizzle_a_enum[0], scale_b[0], swizzle_b_enum[0], bias, out_dtype_, out,
+                               scale_a[1], scale_b[1]);
   } else if (gemm_impl == ScaledGemmImplementation::NVFP4_NVFP4_SINGLE_SCALE) {
-    return _scaled_nvfp4_nvfp4(mat_a, mat_b, scale_a[0], swizzle_a_enum[0], scale_b[0], swizzle_b_enum[0], bias, out_dtype_, true /* single_scale */, out);
+    return _scaled_nvfp4_nvfp4(mat_a, mat_b, scale_a[0], swizzle_a_enum[0], scale_b[0], swizzle_b_enum[0], bias, out_dtype_, out);
   } else if (gemm_impl == ScaledGemmImplementation::MXFP4_MXFP4) {
     return _scaled_mxfp4_mxfp4(mat_a, mat_b, scale_a[0], swizzle_a_enum[0], scale_b[0], swizzle_b_enum[0], bias, out_dtype_, out);
   } else {
```
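A hedged reading of the dispatch above: the v2 entry point appears to pack each operand's scales as a small list, with slot `[0]` holding the swizzled block scales and slot `[1]` the per-tensor global scale, so `NVFP4_NVFP4` forwards both slots while `NVFP4_NVFP4_SINGLE_SCALE` forwards only the first. A float-based sketch of that unpacking convention (illustrative, not the ATen code):

```cpp
// Sketch of the assumed caller-side convention: scale_x[0] = block scales,
// scale_x[1] = optional per-tensor global scale for two-level NVFP4.
#include <optional>
#include <vector>

struct ScaleArgs {
  float block_scale;                  // scale_x[0]
  std::optional<float> global_scale;  // scale_x[1], two-level only
};

ScaleArgs unpack(const std::vector<float>& scale_x) {
  if (scale_x.size() >= 2) {
    return {scale_x[0], scale_x[1]};  // NVFP4_NVFP4 (two-level)
  }
  return {scale_x[0], std::nullopt};  // NVFP4_NVFP4_SINGLE_SCALE
}
```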