mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
User-passed alpha to scaled_gemm (#165563)
Summary: Add optional user-passed `alpha` argument to `at::cuda::blas::scaled_gemm`, necessary for two-level-scaled NVFP4 gemm calls (where the global de-scales are folded into the `alpha` argument. Global de-scales are naturally device tensors, but using cublas' device-pointer mode for `alpha`/`beta` has an interesting lifetime implication - the `alpha` tensor must be valid & correct until the end of the matmul call, *not* just the launch (as for host values). To enable this, I added device-constant memory for `one` and `zero`, along with a statically-held single-fp32-value tensor, which is valid from the first passed-`alpha` invocation of `scaled_gemm` to the end of the program. User-passed values are copied into this perpetual buffer to ensure lifetime requirements are met. Test Plan: Reviewers: Subscribers: Tasks: Tags: Signed-off-by: Simon Layton <simonlayton@meta.com> Pull Request resolved: https://github.com/pytorch/pytorch/pull/165563 Approved by: https://github.com/drisspg, https://github.com/eqy
This commit is contained in:
committed by
PyTorch MergeBot
parent
202f83dc4e
commit
cb6e4d7d82
@ -16,6 +16,8 @@
|
||||
#include <c10/util/irange.h>
|
||||
#include <c10/core/ScalarType.h>
|
||||
|
||||
#include <ATen/cuda/detail/BLASConstants.h>
|
||||
|
||||
#ifdef USE_ROCM
|
||||
#include <c10/cuda/CUDAStream.h>
|
||||
#include <hipblaslt/hipblaslt-ext.hpp>
|
||||
@ -1954,13 +1956,15 @@ void scaled_gemm(
|
||||
const void *result_scale_ptr,
|
||||
int64_t result_ld,
|
||||
ScalarType result_dtype,
|
||||
bool use_fast_accum) {
|
||||
bool use_fast_accum,
|
||||
const std::optional<Tensor>& alpha) {
|
||||
// Note: see `cublasCommonArgs` for various non-intuitive manupulations
|
||||
// of input arguments to this function.
|
||||
const auto computeType = CUBLAS_COMPUTE_32F;
|
||||
const auto scaleType = CUDA_R_32F;
|
||||
const float alpha_val = 1.0;
|
||||
const float beta_val = 0.0;
|
||||
// Note: alpha_val may change later depending on user-passed argument
|
||||
float alpha_val = 1.0;
|
||||
float beta_val = 0.0;
|
||||
CuBlasLtMatmulDescriptor computeDesc(computeType, scaleType);
|
||||
computeDesc.setAttribute(CUBLASLT_MATMUL_DESC_TRANSA, _cublasOpFromChar(transa));
|
||||
computeDesc.setAttribute(CUBLASLT_MATMUL_DESC_TRANSB, _cublasOpFromChar(transb));
|
||||
@ -2031,6 +2035,33 @@ void scaled_gemm(
|
||||
computeDesc.setAttribute(CUBLASLT_MATMUL_DESC_EPILOGUE, CUBLASLT_EPILOGUE_BIAS);
|
||||
computeDesc.setAttribute(CUBLASLT_MATMUL_DESC_BIAS_DATA_TYPE, ScalarTypeToCudaDataType(bias_dtype));
|
||||
}
|
||||
|
||||
// Handle user-passed alpha
|
||||
float *alpha_ptr = &alpha_val;
|
||||
float *beta_ptr = &beta_val;
|
||||
|
||||
if (alpha.has_value()) {
|
||||
auto& a = alpha.value();
|
||||
|
||||
// if device-tensor
|
||||
if (a.is_cuda()) {
|
||||
// NOTE: there are lifetime requirements on device-side pointers for alpha/beta -- the value must be
|
||||
// valid & correct until the cublas call finishes (not is scheduled like host-side values). Thus
|
||||
// we need to use allocations for alpha/beta that have some guarantees on lifetime - a statically
|
||||
// managed 4B buffer for alpha that we'll copy the passed alpha value into, and constant memory
|
||||
// for beta respectively.
|
||||
float *user_alpha_ptr = at::cuda::detail::get_user_alpha_ptr();
|
||||
at::Tensor user_alpha = at::from_blob(user_alpha_ptr, {1}, TensorOptions().device(kCUDA).dtype(kFloat));
|
||||
user_alpha.copy_(a);
|
||||
// Tell cublasLt we're using device-side pointers for alpha/beta
|
||||
auto pointer_mode = CUBLASLT_POINTER_MODE_DEVICE;
|
||||
computeDesc.setAttribute(CUBLASLT_MATMUL_DESC_POINTER_MODE, pointer_mode);
|
||||
alpha_ptr = user_alpha.data_ptr<float>();
|
||||
beta_ptr = at::cuda::detail::get_cublas_device_zero();
|
||||
} else {
|
||||
alpha_val = a.item<float>();
|
||||
}
|
||||
}
|
||||
// For other data types, use the get_scale_mode function based on scaling type
|
||||
// The SCALE_MODE attrs only exist in cuBLAS 12.8+/ROCm 7.0 or in recent hipblaslt,
|
||||
// but we must invoke get_scale_mode anyways to trigger the version checks.
|
||||
@ -2048,6 +2079,7 @@ void scaled_gemm(
|
||||
cublasLtMatmulHeuristicResult_t heuristicResult = {};
|
||||
int returnedResult = 0;
|
||||
cublasLtHandle_t ltHandle = at::cuda::getCurrentCUDABlasLtHandle();
|
||||
|
||||
TORCH_CUDABLAS_CHECK(cublasLtMatmulAlgoGetHeuristic(
|
||||
ltHandle,
|
||||
computeDesc.descriptor(),
|
||||
@ -2088,10 +2120,10 @@ void scaled_gemm(
|
||||
auto is_valid_status = hipblaslt_ext::matmulIsAlgoSupported(
|
||||
ltHandle,
|
||||
computeDesc.descriptor(),
|
||||
&alpha_val,
|
||||
alpha_ptr,
|
||||
Adesc.descriptor(),
|
||||
Bdesc.descriptor(),
|
||||
&beta_val,
|
||||
beta_ptr,
|
||||
Cdesc.descriptor(),
|
||||
Ddesc.descriptor(),
|
||||
all_algos[i].algo,
|
||||
@ -2110,17 +2142,14 @@ void scaled_gemm(
|
||||
cublasStatus_t cublasStatus = cublasLtMatmul(
|
||||
ltHandle,
|
||||
computeDesc.descriptor(),
|
||||
&alpha_val,
|
||||
alpha_ptr,
|
||||
mat1_ptr,
|
||||
Adesc.descriptor(),
|
||||
mat2_ptr,
|
||||
Bdesc.descriptor(),
|
||||
&beta_val,
|
||||
#ifdef USE_ROCM
|
||||
beta_ptr,
|
||||
// NOTE: always use result_ptr here, because cuBLASLt w/device beta=0 can't handle nullptr either
|
||||
result_ptr, // unused, since beta_val is 0, but hipblaslt can't handle nullptr
|
||||
#else
|
||||
nullptr,
|
||||
#endif // ifdef USE_ROCM
|
||||
Cdesc.descriptor(),
|
||||
result_ptr,
|
||||
Ddesc.descriptor(),
|
||||
|
@ -161,7 +161,8 @@ void scaled_gemm(
|
||||
const void* result_scale_ptr,
|
||||
int64_t result_ld,
|
||||
ScalarType result_dtype,
|
||||
bool use_fast_accum);
|
||||
bool use_fast_accum,
|
||||
const std::optional<Tensor>& alpha);
|
||||
|
||||
#define CUDABLAS_BGEMM_ARGTYPES(Dtype) CUDABLAS_BGEMM_ARGTYPES_AND_C_DTYPE(Dtype, Dtype)
|
||||
|
||||
|
54
aten/src/ATen/cuda/detail/BLASConstants.cu
Normal file
54
aten/src/ATen/cuda/detail/BLASConstants.cu
Normal file
@ -0,0 +1,54 @@
|
||||
#include <ATen/Functions.h>
|
||||
#include <ATen/Tensor.h>
|
||||
#include <ATen/cuda/Exceptions.h>
|
||||
|
||||
#include <mutex>
|
||||
|
||||
namespace at {
|
||||
namespace cuda {
|
||||
namespace detail {
|
||||
|
||||
__device__ __constant__ float cublas_one_device;
|
||||
__device__ __constant__ float cublas_zero_device;
|
||||
|
||||
float *get_cublas_device_one() {
|
||||
static c10::once_flag init_flag;
|
||||
|
||||
c10::call_once(init_flag, []() {
|
||||
const float one = 1.f;
|
||||
AT_CUDA_CHECK(cudaMemcpyToSymbol(cublas_one_device, &one, sizeof(float)));
|
||||
});
|
||||
|
||||
float *ptr;
|
||||
AT_CUDA_CHECK(cudaGetSymbolAddress(reinterpret_cast<void**>(&ptr), cublas_one_device));
|
||||
return ptr;
|
||||
}
|
||||
|
||||
float *get_cublas_device_zero() {
|
||||
static c10::once_flag init_flag;
|
||||
|
||||
c10::call_once(init_flag, []() {
|
||||
const float zero = 0.f;
|
||||
AT_CUDA_CHECK(cudaMemcpyToSymbol(cublas_zero_device, &zero, sizeof(float)));
|
||||
});
|
||||
|
||||
float *ptr;
|
||||
AT_CUDA_CHECK(cudaGetSymbolAddress(reinterpret_cast<void**>(&ptr), cublas_zero_device));
|
||||
return ptr;
|
||||
}
|
||||
|
||||
float *get_user_alpha_ptr() {
|
||||
static float *alpha_ptr;
|
||||
|
||||
static c10::once_flag init_flag;
|
||||
|
||||
c10::call_once(init_flag, []() {
|
||||
AT_CUDA_CHECK(cudaMalloc(&alpha_ptr, sizeof(float)));
|
||||
});
|
||||
|
||||
return alpha_ptr;
|
||||
}
|
||||
|
||||
} // namespace detail
|
||||
} // namespace cuda
|
||||
} // namespace at
|
11
aten/src/ATen/cuda/detail/BLASConstants.h
Normal file
11
aten/src/ATen/cuda/detail/BLASConstants.h
Normal file
@ -0,0 +1,11 @@
|
||||
#pragma once
|
||||
|
||||
#include <ATen/core/TensorBase.h>
|
||||
|
||||
namespace at::cuda::detail {
|
||||
|
||||
float *get_cublas_device_one();
|
||||
float *get_cublas_device_zero();
|
||||
float *get_user_alpha_ptr();
|
||||
|
||||
} // namespace at::cuda::detail
|
@ -109,7 +109,8 @@ class DefaultScaledGemmOp : public Callable<ScaledGemmParams<T>> {
|
||||
params->c_scale_ptr,
|
||||
params->ldc,
|
||||
params->c_dtype,
|
||||
params->use_fast_accum);
|
||||
params->use_fast_accum,
|
||||
std::nullopt /* alpha */);
|
||||
return OK;
|
||||
}
|
||||
};
|
||||
|
@ -1359,7 +1359,8 @@ _scaled_gemm(
|
||||
const ScalingType scaling_choice_a, const ScalingType scaling_choice_b,
|
||||
const std::optional<Tensor>& bias,
|
||||
const bool use_fast_accum,
|
||||
Tensor& out) {
|
||||
Tensor& out,
|
||||
const std::optional<Tensor>& alpha = std::nullopt) {
|
||||
cublasCommonArgs args(mat1, mat2, out, scale_a, scale_b, std::nullopt, scaling_choice_a, scaling_choice_b);
|
||||
const auto out_dtype_ = args.result->scalar_type();
|
||||
TORCH_CHECK(args.transa == 't' && args.transb == 'n', "Only multiplication of row-major and column-major matrices is supported by cuBLASLt");
|
||||
@ -1410,7 +1411,8 @@ _scaled_gemm(
|
||||
args.scale_result_ptr,
|
||||
args.result_ld,
|
||||
out_dtype_,
|
||||
use_fast_accum);
|
||||
use_fast_accum,
|
||||
alpha);
|
||||
return out;
|
||||
}
|
||||
}
|
||||
|
Reference in New Issue
Block a user