mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
User-passed alpha to scaled_gemm (#165563)
Summary: Add optional user-passed `alpha` argument to `at::cuda::blas::scaled_gemm`, necessary for two-level-scaled NVFP4 gemm calls (where the global de-scales are folded into the `alpha` argument. Global de-scales are naturally device tensors, but using cublas' device-pointer mode for `alpha`/`beta` has an interesting lifetime implication - the `alpha` tensor must be valid & correct until the end of the matmul call, *not* just the launch (as for host values). To enable this, I added device-constant memory for `one` and `zero`, along with a statically-held single-fp32-value tensor, which is valid from the first passed-`alpha` invocation of `scaled_gemm` to the end of the program. User-passed values are copied into this perpetual buffer to ensure lifetime requirements are met. Test Plan: Reviewers: Subscribers: Tasks: Tags: Signed-off-by: Simon Layton <simonlayton@meta.com> Pull Request resolved: https://github.com/pytorch/pytorch/pull/165563 Approved by: https://github.com/drisspg, https://github.com/eqy
This commit is contained in:
committed by
PyTorch MergeBot
parent
202f83dc4e
commit
cb6e4d7d82
@ -1359,7 +1359,8 @@ _scaled_gemm(
|
||||
const ScalingType scaling_choice_a, const ScalingType scaling_choice_b,
|
||||
const std::optional<Tensor>& bias,
|
||||
const bool use_fast_accum,
|
||||
Tensor& out) {
|
||||
Tensor& out,
|
||||
const std::optional<Tensor>& alpha = std::nullopt) {
|
||||
cublasCommonArgs args(mat1, mat2, out, scale_a, scale_b, std::nullopt, scaling_choice_a, scaling_choice_b);
|
||||
const auto out_dtype_ = args.result->scalar_type();
|
||||
TORCH_CHECK(args.transa == 't' && args.transb == 'n', "Only multiplication of row-major and column-major matrices is supported by cuBLASLt");
|
||||
@ -1410,7 +1411,8 @@ _scaled_gemm(
|
||||
args.scale_result_ptr,
|
||||
args.result_ld,
|
||||
out_dtype_,
|
||||
use_fast_accum);
|
||||
use_fast_accum,
|
||||
alpha);
|
||||
return out;
|
||||
}
|
||||
}
|
||||
|
Reference in New Issue
Block a user