Compare commits


2 Commits

SHA1 Message Date
ab40c4fadc cleanup 2025-06-03 18:06:41 +00:00
9d57568b10 check in 2025-06-03 18:06:38 +00:00
5 changed files with 63 additions and 29 deletions

View File

@@ -1545,6 +1545,7 @@ bool gemm_and_bias(
int64_t n,
int64_t k,
at::opmath_type<Dtype> alpha_val,
at::opmath_type<Dtype> beta_val,
const Dtype* mat1_ptr,
int64_t mat1_ld,
const Dtype* mat2_ptr,
@@ -1552,7 +1553,8 @@ bool gemm_and_bias(
const Dtype* bias,
C_Dtype* result_ptr,
int64_t result_ld,
GEMMAndBiasActivationEpilogue activation) {
GEMMAndBiasActivationEpilogue activation,
bool bias2d) {
if (std::is_same_v<C_Dtype, float> && std::is_same_v<Dtype, at::BFloat16>) {
#ifdef USE_ROCM
@@ -1567,7 +1569,6 @@ bool gemm_and_bias(
}
using opmath_t = at::opmath_type<Dtype>;
opmath_t beta_val = 0; // bias is added in epilogue
cudaDataType_t abType = CUDA_R_32F;
cudaDataType_t cType = CUDA_R_32F;
@@ -1633,23 +1634,23 @@ bool gemm_and_bias(
at::globalContext()._SMCarveout_EXPERIMENTAL().value());
}
#endif
cublasLtEpilogue_t epilogue = CUBLASLT_EPILOGUE_BIAS;
cublasLtEpilogue_t epilogue;
if (activation == GEMMAndBiasActivationEpilogue::RELU) {
epilogue = CUBLASLT_EPILOGUE_RELU_BIAS;
epilogue = CUBLASLT_EPILOGUE_RELU;
} else if (activation == GEMMAndBiasActivationEpilogue::GELU) {
#if CUDA_VERSION >= 11040 || defined(USE_ROCM)
epilogue = CUBLASLT_EPILOGUE_GELU_BIAS;
epilogue = CUBLASLT_EPILOGUE_GELU;
#endif
}
if (bias != nullptr) {
if (activation != GEMMAndBiasActivationEpilogue::None) {
computeDesc.setAttribute(CUBLASLT_MATMUL_DESC_EPILOGUE, epilogue);
computeDesc.setAttribute(CUBLASLT_MATMUL_DESC_BIAS_POINTER, bias);
}
CuBlasLtMatrixLayout Adesc(abType, m, k, mat1_ld, transpose_mat1);
CuBlasLtMatrixLayout Bdesc(abType, k, n, mat2_ld, transpose_mat2);
CuBlasLtMatrixLayout Cdesc(cType, m, n, result_ld);
CuBlasLtMatrixLayout Cdesc(cType, m, n, bias2d ? result_ld : 0);
CuBlasLtMatrixLayout Ddesc(cType, m, n, result_ld);
auto ltworkspace = CublasLtWorkspace();
preference.setAttribute(CUBLASLT_MATMUL_PREF_MAX_WORKSPACE_BYTES, ltworkspace.size);
@@ -1657,8 +1658,8 @@ bool gemm_and_bias(
#ifndef USE_ROCM
uint32_t a_alignment = _getAlignment(reinterpret_cast<uintptr_t>(mat1_ptr));
uint32_t b_alignment = _getAlignment(reinterpret_cast<uintptr_t>(mat2_ptr));
uint32_t c_alignment = _getAlignment(reinterpret_cast<uintptr_t>(result_ptr));
uint32_t d_alignment = _getAlignment(reinterpret_cast<uintptr_t>(bias));
uint32_t c_alignment = _getAlignment(reinterpret_cast<uintptr_t>(bias));
uint32_t d_alignment = _getAlignment(reinterpret_cast<uintptr_t>(result_ptr));
preference.setAttribute(CUBLASLT_MATMUL_PREF_MIN_ALIGNMENT_A_BYTES, a_alignment);
preference.setAttribute(CUBLASLT_MATMUL_PREF_MIN_ALIGNMENT_B_BYTES, b_alignment);
preference.setAttribute(CUBLASLT_MATMUL_PREF_MIN_ALIGNMENT_C_BYTES, c_alignment);
@@ -1674,7 +1675,7 @@ bool gemm_and_bias(
Adesc.descriptor(),
Bdesc.descriptor(),
Cdesc.descriptor(),
Cdesc.descriptor(),
Ddesc.descriptor(),
preference.descriptor(),
1,
&heuristicResult,
@@ -1693,10 +1694,10 @@ bool gemm_and_bias(
mat2_ptr,
Bdesc.descriptor(),
beta_ptr,
result_ptr,
bias,
Cdesc.descriptor(),
result_ptr,
Cdesc.descriptor(),
Ddesc.descriptor(),
&heuristicResult.algo,
ltworkspace.ptr,
ltworkspace.size,
@@ -1743,6 +1744,7 @@ template bool gemm_and_bias(
int64_t n,
int64_t k,
at::opmath_type<double> alpha_val,
at::opmath_type<double> beta_val,
const double* mat1_ptr,
int64_t mat1_ld,
const double* mat2_ptr,
@@ -1750,7 +1752,8 @@ template bool gemm_and_bias(
const double* bias,
double* result_ptr,
int64_t result_ld,
GEMMAndBiasActivationEpilogue activation);
GEMMAndBiasActivationEpilogue activation,
bool bias2d);
template bool gemm_and_bias(
bool transpose_mat1,
@@ -1759,6 +1762,7 @@ template bool gemm_and_bias(
int64_t n,
int64_t k,
at::opmath_type<float> alpha_val,
at::opmath_type<float> beta_val,
const float* mat1_ptr,
int64_t mat1_ld,
const float* mat2_ptr,
@@ -1766,7 +1770,8 @@ template bool gemm_and_bias(
const float* bias,
float* result_ptr,
int64_t result_ld,
GEMMAndBiasActivationEpilogue activation);
GEMMAndBiasActivationEpilogue activation,
bool bias2d);
template bool gemm_and_bias(
bool transpose_mat1,
@@ -1775,6 +1780,7 @@ template bool gemm_and_bias(
int64_t n,
int64_t k,
at::opmath_type<at::Half> alpha_val,
at::opmath_type<at::Half> beta_val,
const at::Half* mat1_ptr,
int64_t mat1_ld,
const at::Half* mat2_ptr,
@@ -1782,7 +1788,8 @@ template bool gemm_and_bias(
const at::Half* bias,
at::Half* result_ptr,
int64_t result_ld,
GEMMAndBiasActivationEpilogue activation);
GEMMAndBiasActivationEpilogue activation,
bool bias2d);
template bool gemm_and_bias(
bool transpose_mat1,
@@ -1791,6 +1798,7 @@ template bool gemm_and_bias(
int64_t n,
int64_t k,
at::opmath_type<at::Half> alpha_val,
at::opmath_type<at::Half> beta_val,
const at::Half* mat1_ptr,
int64_t mat1_ld,
const at::Half* mat2_ptr,
@@ -1798,7 +1806,8 @@ template bool gemm_and_bias(
const at::Half* bias,
float* result_ptr,
int64_t result_ld,
GEMMAndBiasActivationEpilogue activation);
GEMMAndBiasActivationEpilogue activation,
bool bias2d);
template bool gemm_and_bias(
bool transpose_mat1,
@@ -1807,6 +1816,7 @@ template bool gemm_and_bias(
int64_t n,
int64_t k,
at::opmath_type<at::BFloat16> alpha_val,
at::opmath_type<at::BFloat16> beta_val,
const at::BFloat16* mat1_ptr,
int64_t mat1_ld,
const at::BFloat16* mat2_ptr,
@@ -1814,7 +1824,8 @@ template bool gemm_and_bias(
const at::BFloat16* bias,
at::BFloat16* result_ptr,
int64_t result_ld,
GEMMAndBiasActivationEpilogue activation);
GEMMAndBiasActivationEpilogue activation,
bool bias2d);
template bool gemm_and_bias(
bool transpose_mat1,
@@ -1823,6 +1834,7 @@ template bool gemm_and_bias(
int64_t n,
int64_t k,
at::opmath_type<at::BFloat16> alpha_val,
at::opmath_type<at::BFloat16> beta_val,
const at::BFloat16* mat1_ptr,
int64_t mat1_ld,
const at::BFloat16* mat2_ptr,
@@ -1830,7 +1842,8 @@ template bool gemm_and_bias(
const at::BFloat16* bias,
float* result_ptr,
int64_t result_ld,
GEMMAndBiasActivationEpilogue activation);
GEMMAndBiasActivationEpilogue activation,
bool bias2d);
void scaled_gemm(
char transa,
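
Note on the hunks above: rather than feeding the bias through the cuBLASLt *_BIAS epilogues with a hard-coded beta of 0, the rewritten gemm_and_bias passes the bias as the C operand of D = alpha * op(A) * op(B) + beta * C, keeps the epilogue only for the RELU/GELU activations, and chooses the C leading dimension as result_ld for a full 2-D bias or 0 to broadcast a 1-D bias across output columns. The following is a minimal column-major reference sketch of that leading-dimension trick, not the cuBLASLt path itself; gemm_bias_reference and the buffer names are hypothetical.

// Reference for D = alpha * A * B + beta * C with a caller-chosen ldc.
// ldc == 0 collapses the address i + j * ldc to i, so every output column
// reads the same length-m vector (1-D bias broadcast); ldc == m reads a
// full m x n bias, mirroring the `bias2d ? result_ld : 0` choice above.
#include <cstdint>
#include <cstdio>
#include <vector>

static void gemm_bias_reference(
    int64_t m, int64_t n, int64_t k,
    float alpha, const float* A, int64_t lda,
    const float* B, int64_t ldb,
    float beta, const float* C, int64_t ldc,
    float* D, int64_t ldd) {
  for (int64_t j = 0; j < n; ++j) {
    for (int64_t i = 0; i < m; ++i) {
      float acc = 0.f;
      for (int64_t p = 0; p < k; ++p) {
        acc += A[i + p * lda] * B[p + j * ldb];
      }
      D[i + j * ldd] = alpha * acc + beta * C[i + j * ldc];
    }
  }
}

int main() {
  const int64_t m = 2, n = 3, k = 2;
  std::vector<float> A = {1, 2, 3, 4};        // 2x2, column-major
  std::vector<float> B = {1, 0, 0, 1, 1, 1};  // 2x3, column-major
  std::vector<float> bias1d = {10, 20};       // length m, broadcast via ldc = 0
  std::vector<float> bias2d(m * n, 5.f);      // full m x n bias, ldc = m
  std::vector<float> D(m * n);

  gemm_bias_reference(m, n, k, 1.f, A.data(), m, B.data(), k,
                      1.f, bias1d.data(), /*ldc=*/0, D.data(), m);
  std::printf("broadcast bias, D[0][0] = %g\n", D[0]);

  gemm_bias_reference(m, n, k, 1.f, A.data(), m, B.data(), k,
                      1.f, bias2d.data(), /*ldc=*/m, D.data(), m);
  std::printf("2-D bias,       D[0][0] = %g\n", D[0]);
  return 0;
}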

View File

@@ -114,6 +114,7 @@ bool gemm_and_bias(
int64_t n,
int64_t k,
at::opmath_type<Dtype> alpha_val,
at::opmath_type<Dtype> beta_val,
const Dtype* mat1_ptr,
int64_t mat1_ld,
const Dtype* mat2_ptr,
@@ -121,7 +122,8 @@ bool gemm_and_bias(
const Dtype* bias,
C_Dtype* result_ptr,
int64_t result_ld,
GEMMAndBiasActivationEpilogue activation = GEMMAndBiasActivationEpilogue::None);
GEMMAndBiasActivationEpilogue activation = GEMMAndBiasActivationEpilogue::None,
bool bias2d = false);
void int8_gemm(
bool transpose_mat1,
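
As declared here, beta_val becomes a new mandatory argument, while activation and bias2d keep defaults, so existing epilogue-style callers only need to thread beta through. Below is a hedged call-site sketch against this declaration; the mat2 leading-dimension parameter falls outside the hunk and is assumed to be named mat2_ld, and calling this internal API directly is shown purely for illustration.

// Hedged sketch: calling the updated gemm_and_bias directly (float, 1-D bias).
// Assumes a CUDA libtorch build that contains these changes.
#include <ATen/ATen.h>
#include <ATen/cuda/CUDABlas.h>
#include <cstdint>

int main() {
  const int64_t m = 4, n = 3, k = 2;
  // cuBLAS is column-major: a contiguous {cols, rows} tensor doubles as a
  // {rows, cols} column-major matrix with leading dimension `rows`.
  auto opts = at::dtype(at::kFloat).device(at::kCUDA);
  auto mat1   = at::randn({k, m}, opts);   // m x k, ld = m
  auto mat2   = at::randn({n, k}, opts);   // k x n, ld = k
  auto bias   = at::randn({m}, opts);      // length-m vector (bias2d = false)
  auto result = at::empty({n, m}, opts);   // m x n, ld = m

  bool ok = at::cuda::blas::gemm_and_bias<float>(
      /*transpose_mat1=*/false,
      /*transpose_mat2=*/false,
      m, n, k,
      /*alpha_val=*/1.0f,
      /*beta_val=*/1.0f,                            // new: scales the bias/self operand
      mat1.const_data_ptr<float>(), /*mat1_ld=*/m,
      mat2.const_data_ptr<float>(), /*mat2_ld=*/k,  // parameter name assumed, see note above
      bias.const_data_ptr<float>(),
      result.data_ptr<float>(), /*result_ld=*/m,
      at::cuda::blas::GEMMAndBiasActivationEpilogue::None,  // default
      /*bias2d=*/false);                                    // default
  return ok ? 0 : 1;
}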

View File

@@ -457,6 +457,9 @@ struct GemmAndBiasParams : OpParams {
int64_t n{};
int64_t k{};
at::opmath_type<T> alpha{};
#if !defined(USE_ROCM)
at::opmath_type<T> beta{};
#endif
const T* a{};
int64_t lda{};
const T* b{};
@@ -465,6 +468,7 @@ struct GemmAndBiasParams : OpParams {
int64_t ldc{};
const T* bias{};
at::cuda::blas::GEMMAndBiasActivationEpilogue activation{};
bool bias2d{};
private:
bool duplicate_inputs_{false};
};

View File

@@ -55,11 +55,13 @@ class DefaultGemmAndBiasOp : public Callable<GemmAndBiasParams<T>> {
_transposeBoolFromChar(params->transb),
params->m, params->n, params->k,
params->alpha,
params->beta,
params->a, params->lda,
params->b, params->ldb,
params->bias,
params->c, params->ldc,
params->activation);
params->activation,
params->bias2d);
return OK;
}
};

View File

@@ -280,7 +280,7 @@ static bool isSupportedHipLtROCmArch(int index) {
#endif
template <typename scalar_t>
static void launchTunableGemmAndBias(cublasCommonArgs &args, const Scalar& alpha, const scalar_t* bias, cuda::blas::GEMMAndBiasActivationEpilogue activation) {
static void launchTunableGemmAndBias(cublasCommonArgs &args, const Scalar& alpha, const Scalar& beta, const scalar_t* bias, cuda::blas::GEMMAndBiasActivationEpilogue activation, bool bias2d) {
bool transa_ = ((args.transa != 'n') && (args.transa != 'N'));
bool transb_ = ((args.transb != 'n') && (args.transb != 'N'));
at::cuda::tunable::GemmAndBiasParams<scalar_t> params;
@@ -298,6 +298,7 @@ static void launchTunableGemmAndBias(cublasCommonArgs &args, const Scalar& alpha
params.ldc = args.result_ld;
params.bias = bias;
params.activation = activation;
params.bias2d = bias2d;
if (transa_ && transb_) {
static at::cuda::tunable::GemmAndBiasTunableOp<scalar_t, at::cuda::tunable::BlasOp::T, at::cuda::tunable::BlasOp::T> gemm{};
gemm(&params);
@@ -369,9 +370,12 @@ Tensor& addmm_out_cuda_impl(Tensor& result, const Tensor& self, const Tensor& ma
// leading dim >> rows when they are sliced from a large tensor
// see fbcode/caffe2/test/test_linalg.py:test_corner_cases_of_cublasltmatmul
if (!disable_addmm_cuda_lt_final) {
useLtInterface = beta.toComplexDouble() == 1.0 && self.dim() == 1 &&
result.dim() == 2 && self.sizes()[0] == mat2_sizes[1] &&
useLtInterface = ((self.dim() == 1 && // row broadcast case
result.dim() == 2 && self.size(0) == mat2_sizes[1]) ||
(self.dim() == 2 && self.size(0) == mat1_sizes[0] && // 2d case
self.size(1) == mat2_sizes[1])) &&
self.is_contiguous() && result.is_contiguous() &&
self.scalar_type() == result.scalar_type() &&
#ifdef USE_ROCM
(scalar_type == at::ScalarType::Float ||
scalar_type == at::ScalarType::Half ||
@@ -426,6 +430,7 @@ Tensor& addmm_out_cuda_impl(Tensor& result, const Tensor& self, const Tensor& ma
}
cublasCommonArgs args(mat1, mat2, result);
bool bias2d = self.dim() == 2;
if (mat1.numel() == 0) {
// By definition, when beta==0, values in self should be ignored. nans and infs
@@ -475,6 +480,7 @@ Tensor& addmm_out_cuda_impl(Tensor& result, const Tensor& self, const Tensor& ma
args.n,
args.k,
alpha.to<at::opmath_type<scalar_t>>(),
beta.to<at::opmath_type<scalar_t>>(),
args.mata->const_data_ptr<scalar_t>(),
args.lda,
args.matb->const_data_ptr<scalar_t>(),
@@ -484,7 +490,8 @@ Tensor& addmm_out_cuda_impl(Tensor& result, const Tensor& self, const Tensor& ma
(&result != &self) ? self.const_data_ptr<scalar_t>() : nullptr,
args.result->data_ptr<scalar_t>(),
args.result_ld,
activation_to_gemm_and_blas_arg(activation)
activation_to_gemm_and_blas_arg(activation),
bias2d
);
}
});
@@ -513,6 +520,7 @@ Tensor& addmm_out_cuda_impl(Tensor& result, const Tensor& self, const Tensor& ma
args.n,
args.k,
alpha.to<at::opmath_type<scalar_t>>(),
beta.to<at::opmath_type<scalar_t>>(),
args.mata->const_data_ptr<scalar_t>(),
args.lda,
args.matb->const_data_ptr<scalar_t>(),
@@ -520,7 +528,8 @@ Tensor& addmm_out_cuda_impl(Tensor& result, const Tensor& self, const Tensor& ma
self.const_data_ptr<scalar_t>(),
args.result->data_ptr<float>(),
args.result_ld,
activation_epilogue
activation_epilogue,
bias2d
);
}});
} else {
@@ -535,8 +544,10 @@ Tensor& addmm_out_cuda_impl(Tensor& result, const Tensor& self, const Tensor& ma
launchTunableGemmAndBias<scalar_t>(
args,
alpha,
beta,
self.const_data_ptr<scalar_t>(),
activation_epilogue);
activation_epilogue,
bias2d);
}
else {
okay = at::cuda::blas::gemm_and_bias<scalar_t>(
@@ -546,6 +557,7 @@ Tensor& addmm_out_cuda_impl(Tensor& result, const Tensor& self, const Tensor& ma
args.n,
args.k,
alpha.to<at::opmath_type<scalar_t>>(),
beta.to<at::opmath_type<scalar_t>>(),
args.mata->const_data_ptr<scalar_t>(),
args.lda,
args.matb->const_data_ptr<scalar_t>(),
@@ -553,7 +565,8 @@ Tensor& addmm_out_cuda_impl(Tensor& result, const Tensor& self, const Tensor& ma
self.const_data_ptr<scalar_t>(),
args.result->data_ptr<scalar_t>(),
args.result_ld,
activation_epilogue
activation_epilogue,
bias2d
);
}});
}
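
Taken together, the Blas.cpp hunks widen the cuBLASLt fast path: useLtInterface now accepts either the classic 1-D bias whose length matches the output's column count or a contiguous 2-D self with the full output shape, and beta is forwarded to gemm_and_bias instead of being required to equal 1. A hedged libtorch sketch of the two shapes that become eligible (which kernel actually runs still depends on dtype, device, contiguity, and the remaining heuristics in addmm_out_cuda_impl):

// Assumes a CUDA libtorch build containing these changes; shapes are arbitrary.
#include <ATen/ATen.h>

int main() {
  auto mat1 = at::randn({64, 128}, at::kCUDA);
  auto mat2 = at::randn({128, 32}, at::kCUDA);

  // 1-D bias: self.size(0) == mat2.size(1), broadcast over the 64 output rows.
  auto bias1d = at::randn({32}, at::kCUDA);
  auto out1 = at::addmm(bias1d, mat1, mat2, /*beta=*/2.0, /*alpha=*/1.0);

  // 2-D self matching the {64, 32} output: previously excluded from the Lt
  // path, now handled by passing self as the beta-scaled C operand (bias2d).
  auto bias2d = at::randn({64, 32}, at::kCUDA);
  auto out2 = at::addmm(bias2d, mat1, mat2, /*beta=*/2.0, /*alpha=*/1.0);

  (void)out1;
  (void)out2;
  return 0;
}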