Files
pytorch/aten/src/ATen/cuda/detail/BLASConstants.cu
Simon Layton cb6e4d7d82 User-passed alpha to scaled_gemm (#165563)
Summary:

Add optional user-passed `alpha` argument to
`at::cuda::blas::scaled_gemm`, necessary for two-level-scaled NVFP4 gemm
calls (where the global de-scales are folded into the `alpha` argument).

Global de-scales are naturally device tensors, but using cublas'
device-pointer mode for `alpha`/`beta` has an interesting lifetime
implication - the `alpha` tensor must be valid & correct until the end
of the matmul call, *not* just the launch (as for host values). To
enable this, I added device-constant memory for `one` and `zero`, along
with a statically-held single-fp32-value tensor, which is valid from the
first passed-`alpha` invocation of `scaled_gemm` to the end of the
program. User-passed values are copied into this perpetual buffer to
ensure lifetime requirements are met.

Test Plan:

Reviewers:

Subscribers:

Tasks:

Tags:
Signed-off-by: Simon Layton <simonlayton@meta.com>
Pull Request resolved: https://github.com/pytorch/pytorch/pull/165563
Approved by: https://github.com/drisspg, https://github.com/eqy
2025-10-17 09:42:33 +00:00

55 lines
1.2 KiB
Plaintext

#include <ATen/Functions.h>
#include <ATen/Tensor.h>
#include <ATen/cuda/Exceptions.h>
#include <mutex>
#include <vector>
namespace at {
namespace cuda {
namespace detail {

// Device-resident scalar constants 1.0f and 0.0f.
//
// A `__constant__` variable has one distinct object per device. The values
// are therefore supplied as static initializers, so that every device's copy
// holds the right value as soon as the module is loaded there. Filling them
// lazily with a single process-wide cudaMemcpyToSymbol would only write the
// copy that belongs to whichever device happened to be current at that time.
__device__ __constant__ float cublas_one_device = 1.f;
__device__ __constant__ float cublas_zero_device = 0.f;

// Returns the device address of the constant 1.0f on the current device.
//
// The address is looked up on every call because cudaGetSymbolAddress
// resolves the symbol for the device that is current at call time.
// Raises (via AT_CUDA_CHECK) if the CUDA runtime reports an error.
float *get_cublas_device_one() {
  float *ptr = nullptr;
  AT_CUDA_CHECK(cudaGetSymbolAddress(reinterpret_cast<void**>(&ptr), cublas_one_device));
  return ptr;
}

// Returns the device address of the constant 0.0f on the current device.
//
// See get_cublas_device_one() for why the lookup is repeated per call.
// Raises (via AT_CUDA_CHECK) if the CUDA runtime reports an error.
float *get_cublas_device_zero() {
  float *ptr = nullptr;
  AT_CUDA_CHECK(cudaGetSymbolAddress(reinterpret_cast<void**>(&ptr), cublas_zero_device));
  return ptr;
}

// Returns a pointer to a single-float device buffer that lives on the
// current device and is never freed.
//
// One buffer is allocated lazily per device ordinal, the first time this
// function is called while that device is current; later calls on the same
// device return the same pointer. A buffer allocated on one device is never
// handed out while a different device is current.
//
// NOTE(review): the buffer for a given device is shared by every caller on
// that device; any ordering between writers and readers of the buffer has to
// be provided by the caller.
//
// Raises (via AT_CUDA_CHECK) if the CUDA runtime reports an error.
float *get_user_alpha_ptr() {
  int device = -1;
  AT_CUDA_CHECK(cudaGetDevice(&device));

  // Guards the per-device table; allocation happens at most once per device.
  static std::mutex alloc_mutex;
  static std::vector<float*> alpha_ptrs;

  std::lock_guard<std::mutex> guard(alloc_mutex);
  const auto slot = static_cast<std::vector<float*>::size_type>(device);
  if (alpha_ptrs.size() <= slot) {
    alpha_ptrs.resize(slot + 1, nullptr);
  }
  if (alpha_ptrs[slot] == nullptr) {
    // Allocate into a temporary so the table entry stays null if the
    // allocation fails and AT_CUDA_CHECK throws.
    float *fresh = nullptr;
    AT_CUDA_CHECK(cudaMalloc(&fresh, sizeof(float)));
    alpha_ptrs[slot] = fresh;
  }
  return alpha_ptrs[slot];
}

} // namespace detail
} // namespace cuda
} // namespace at