mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 12:54:11 +08:00
Summary:

Add an optional user-passed `alpha` argument to `at::cuda::blas::scaled_gemm`. This is necessary for two-level-scaled NVFP4 gemm calls, where the global de-scales are folded into the `alpha` argument.

Global de-scales are naturally device tensors, but using cuBLAS's device-pointer mode for `alpha`/`beta` has a lifetime implication: the `alpha` tensor must remain valid and correct until the matmul has finished executing, *not* just until it has been launched (as is the case for host values).

To enable this, I added device-constant memory for `one` and `zero`, along with a statically held single-fp32-value tensor, which is valid from the first invocation of `scaled_gemm` with a user-passed `alpha` to the end of the program. User-passed values are copied into this persistent buffer to ensure the lifetime requirements are met.

Signed-off-by: Simon Layton <simonlayton@meta.com>
Pull Request resolved: https://github.com/pytorch/pytorch/pull/165563
Approved by: https://github.com/drisspg, https://github.com/eqy
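
The lifetime requirement described in the summary can be illustrated with a small host-only C++ sketch. This is an analogy under stated assumptions, not the code from the PR: `FakeStream`, `persistent_alpha`, and `scaled_gemm_sketch` are hypothetical names, the "stream" is a plain FIFO of deferred tasks standing in for a CUDA stream, and scalar multiplication stands in for the matmul. It shows why a deferred operation that reads `alpha` through a pointer needs storage that outlives the caller's scope, and why a single statically held buffer suffices when the copy into it is ordered on the same queue as the work that reads it.

```cpp
#include <cassert>
#include <functional>
#include <optional>
#include <queue>
#include <utility>

// Hypothetical stand-in for a CUDA stream: work is queued now and run later.
struct FakeStream {
    std::queue<std::function<void()>> tasks;

    void enqueue(std::function<void()> task) { tasks.push(std::move(task)); }

    // Run all queued work in FIFO order (analogous to a stream sync).
    void synchronize() {
        while (!tasks.empty()) {
            tasks.front()();
            tasks.pop();
        }
    }
};

// Stand-in for the constant `one` used when no alpha is passed.
static const float kOne = 1.0f;

// Function-local static: created on first use and valid until program exit.
// This mimics the statically held single-value buffer from the summary.
float* persistent_alpha() {
    static float buffer = 1.0f;
    return &buffer;
}

// Assumed, simplified shape: computes out = alpha * a * b, but only when the
// stream is drained. The deferred task reads alpha through a pointer, like
// device-pointer mode, so the pointee must stay valid until the task runs.
void scaled_gemm_sketch(FakeStream& stream, float a, float b,
                        std::optional<float> user_alpha, float* out) {
    assert(out != nullptr);
    const float* alpha_ptr = &kOne;
    if (user_alpha.has_value()) {
        float* buf = persistent_alpha();
        const float value = *user_alpha;
        // Queue the copy into the persistent buffer. Capturing `value` here
        // is a simplification of a stream-ordered device-to-device copy.
        stream.enqueue([buf, value] { *buf = value; });
        alpha_ptr = buf;
    }
    // The "matmul" is queued after the copy, so it sees the right alpha even
    // though later calls reuse the same buffer.
    stream.enqueue([alpha_ptr, a, b, out] { *out = (*alpha_ptr) * a * b; });
}
```

In this sketch, a caller can pass an `alpha` that goes out of scope right after the call, and several calls with different `alpha` values can be queued before `synchronize()` runs. Each queued multiply still sees its own value, because every copy into the shared buffer is ordered immediately before the task that reads it.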