mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
Summary: Add optional user-passed `alpha` argument to `at::cuda::blas::scaled_gemm`, necessary for two-level-scaled NVFP4 gemm calls (where the global de-scales are folded into the `alpha` argument). Global de-scales are naturally device tensors, but using cublas' device-pointer mode for `alpha`/`beta` has an interesting lifetime implication - the `alpha` tensor must be valid & correct until the end of the matmul call, *not* just the launch (as for host values). To enable this, I added device-constant memory for `one` and `zero`, along with a statically-held single-fp32-value tensor, which is valid from the first passed-`alpha` invocation of `scaled_gemm` to the end of the program. User-passed values are copied into this perpetual buffer to ensure lifetime requirements are met. Test Plan: Reviewers: Subscribers: Tasks: Tags: Signed-off-by: Simon Layton <simonlayton@meta.com> Pull Request resolved: https://github.com/pytorch/pytorch/pull/165563 Approved by: https://github.com/drisspg, https://github.com/eqy
55 lines
1.2 KiB
Plaintext
55 lines
1.2 KiB
Plaintext
#include <ATen/Functions.h>
|
|
#include <ATen/Tensor.h>
|
|
#include <ATen/cuda/Exceptions.h>
|
|
|
|
#include <mutex>
|
|
|
|
namespace at {
|
|
namespace cuda {
|
|
namespace detail {
|
|
|
|
// Device-resident float constants used as cuBLAS alpha/beta values when a
// handle operates in CUBLAS_POINTER_MODE_DEVICE. Because they live in
// __constant__ memory, their device addresses remain valid for the lifetime
// of the program, satisfying cuBLAS's requirement that device-mode
// alpha/beta stay readable until the matmul completes (not just the launch).
__device__ __constant__ float cublas_one_device;
__device__ __constant__ float cublas_zero_device;
// Returns a device pointer to a float holding 1.0f, suitable as a cuBLAS
// alpha/beta argument when the handle is in device pointer mode. The value
// lives in __constant__ memory (cublas_one_device), so the returned pointer
// stays valid for the lifetime of the program. Thread-safe: the symbol is
// written exactly once, and its address is cached so repeat calls avoid a
// redundant cudaGetSymbolAddress lookup.
float *get_cublas_device_one() {
  static float *one_ptr = nullptr;
  static c10::once_flag init_flag;

  c10::call_once(init_flag, []() {
    const float one = 1.f;
    // Populate the __constant__ symbol, then cache its device address; the
    // address never changes, so there is no need to re-query it per call.
    AT_CUDA_CHECK(cudaMemcpyToSymbol(cublas_one_device, &one, sizeof(float)));
    AT_CUDA_CHECK(cudaGetSymbolAddress(
        reinterpret_cast<void**>(&one_ptr), cublas_one_device));
  });

  return one_ptr;
}
// Returns a device pointer to a float holding 0.0f, suitable as a cuBLAS
// alpha/beta argument when the handle is in device pointer mode. The value
// lives in __constant__ memory (cublas_zero_device), so the returned pointer
// stays valid for the lifetime of the program. Thread-safe: the symbol is
// written exactly once, and its address is cached so repeat calls avoid a
// redundant cudaGetSymbolAddress lookup.
float *get_cublas_device_zero() {
  static float *zero_ptr = nullptr;
  static c10::once_flag init_flag;

  c10::call_once(init_flag, []() {
    const float zero = 0.f;
    // Populate the __constant__ symbol, then cache its device address; the
    // address never changes, so there is no need to re-query it per call.
    AT_CUDA_CHECK(cudaMemcpyToSymbol(cublas_zero_device, &zero, sizeof(float)));
    AT_CUDA_CHECK(cudaGetSymbolAddress(
        reinterpret_cast<void**>(&zero_ptr), cublas_zero_device));
  });

  return zero_ptr;
}
// Returns a pointer to a single device-resident float used to stage a
// user-supplied alpha value for scaled_gemm. The buffer is allocated lazily
// on first use and deliberately never freed: cuBLAS in device pointer mode
// reads alpha during the matmul itself, so the storage must remain valid
// from the first passed-alpha invocation until program exit. Thread-safe
// via c10::call_once.
float *get_user_alpha_ptr() {
  static float *user_alpha = nullptr;

  static c10::once_flag alloc_flag;

  c10::call_once(alloc_flag, []() {
    AT_CUDA_CHECK(cudaMalloc(&user_alpha, sizeof(float)));
  });

  return user_alpha;
}
} // namespace detail
|
|
} // namespace cuda
|
|
} // namespace at
|