[reland][ROCm] TunableOp for gemm_and_bias (#128919)

Reland of #128143, with `alpha` and `bias` initialization added to `launchTunableGemmAndBias`.

Thus far, TunableOp had been implemented for gemm, bgemm, and scaled_mm; gemm_and_bias was notably missing. This PR closes that gap.
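
For context (not part of the diff), here is a minimal usage sketch of the path this PR adds, mirroring the ROCm test introduced below. The shapes, dtype, and filename are illustrative only; whether a given call actually reaches the tunable gemm_and_bias path still depends on the usual `addmm_cuda_lt` eligibility checks (dtype, sizes, and ROCm architecture).

```python
import torch

# Assumes a ROCm build of PyTorch with TunableOp support and a GPU present.
torch.cuda.tunable.enable(True)
torch.cuda.tunable.set_filename("tunableop_results0.csv")  # illustrative filename
torch.cuda.tunable.set_max_tuning_iterations(10)

x = torch.randn(64, 128, device="cuda", dtype=torch.half)
w = torch.randn(256, 128, device="cuda", dtype=torch.half)
b = torch.randn(256, device="cuda", dtype=torch.half)

# linear with a bias lowers to addmm; when the hipBLASLt (Lt) path is eligible
# and TunableOp is enabled, addmm now dispatches to GemmAndBiasTunableOp
# instead of calling at::cuda::blas::gemm_and_bias directly.
y = torch.nn.functional.linear(x, w, b)

# Tuned results can be flushed explicitly (by default they are also written at exit).
torch.cuda.tunable.write_file()
torch.cuda.tunable.enable(False)
```

The new `test_addmm_relu_tunableop_rocm` test added below drives the same flow through `torch._addmm_activation`.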

Pull Request resolved: https://github.com/pytorch/pytorch/pull/128919
Approved by: https://github.com/malfet
Jeff Daily authored on 2024-08-22 18:27:50 +00:00; committed by PyTorch MergeBot
parent 978c5a80a0
commit 0eb9c870fd
5 changed files with 316 additions and 32 deletions

View File

@@ -155,6 +155,72 @@ private:
bool duplicate_inputs_;
};
template <typename T>
struct GemmAndBiasParams : OpParams {
std::string Signature() const override {
return c10::str(transa, transb, "_", m, "_", n, "_", k);
}
size_t GetSize(bool duplicate_inputs) const {
size_t size = sizeof(T) * ldc * n;
if (duplicate_inputs) {
size += sizeof(T) * lda * ((transa == 'n' || transa == 'N') ? k : m);
size += sizeof(T) * ldb * ((transb == 'n' || transb == 'N') ? n : k);
}
return size;
}
GemmAndBiasParams* DeepCopy(bool duplicate_inputs) const {
GemmAndBiasParams* copy = new GemmAndBiasParams;
*copy = *this;
c10::DeviceIndex device = 0;
AT_CUDA_CHECK(c10::cuda::GetDevice(&device));
size_t c_size = ldc * n * sizeof(T);
copy->c = static_cast<T*>(c10::cuda::CUDACachingAllocator::raw_alloc(c_size));
AT_CUDA_CHECK(c10::cuda::CUDACachingAllocator::memcpyAsync(
copy->c, device, c, device, c_size, getCurrentCUDAStream(device), true));
if (duplicate_inputs) {
size_t a_size = sizeof(T) * lda * ((transa == 'n' || transa == 'N') ? k : m);
size_t b_size = sizeof(T) * ldb * ((transb == 'n' || transb == 'N') ? n : k);
copy->a = static_cast<const T*>(c10::cuda::CUDACachingAllocator::raw_alloc(a_size));
copy->b = static_cast<const T*>(c10::cuda::CUDACachingAllocator::raw_alloc(b_size));
copy->duplicate_inputs_ = true;
}
return copy;
}
// only call on object returned by DeepCopy
void Delete() {
c10::cuda::CUDACachingAllocator::raw_delete(c);
if (duplicate_inputs_) {
c10::cuda::CUDACachingAllocator::raw_delete(const_cast<T*>(a));
c10::cuda::CUDACachingAllocator::raw_delete(const_cast<T*>(b));
}
}
TuningStatus NumericalCheck(GemmAndBiasParams<T> *other) {
auto c_dtype = c10::CppTypeToScalarType<T>::value;
return detail::NumericalCheck(c_dtype, c, other->c, ldc*n) ? OK : FAIL;
}
char transa;
char transb;
int64_t m;
int64_t n;
int64_t k;
at::opmath_type<T> alpha;
const T* a;
int64_t lda;
const T* b;
int64_t ldb;
T* c;
int64_t ldc;
const T* bias;
at::cuda::blas::GEMMAndBiasActivationEpilogue activation;
private:
bool duplicate_inputs_;
};
template <typename T>
struct GemmStridedBatchedParams : OpParams {
GemmStridedBatchedParams() {

View File

@@ -25,35 +25,35 @@
namespace at::cuda::tunable {
template <typename T>
constexpr hipblasDatatype_t HipBlasDataTypeFor();
constexpr hipblasDatatype_t HipDataTypeFor();
template <>
constexpr hipblasDatatype_t HipBlasDataTypeFor<float>() {
return HIPBLAS_R_32F;
constexpr hipblasDatatype_t HipDataTypeFor<float>() {
return HIP_R_32F;
}
template <>
constexpr hipblasDatatype_t HipBlasDataTypeFor<Half>() {
return HIPBLAS_R_16F;
constexpr hipblasDatatype_t HipDataTypeFor<Half>() {
return HIP_R_16F;
}
template <>
constexpr hipblasDatatype_t HipBlasDataTypeFor<BFloat16>() {
return HIPBLAS_R_16B;
constexpr hipblasDatatype_t HipDataTypeFor<BFloat16>() {
return HIP_R_16BF;
}
template <>
constexpr hipblasDatatype_t HipBlasDataTypeFor<double>() {
return HIPBLAS_R_64F;
constexpr hipblasDatatype_t HipDataTypeFor<double>() {
return HIP_R_64F;
}
template <>
constexpr hipblasDatatype_t HipBlasDataTypeFor<c10::Float8_e4m3fnuz>() {
constexpr hipblasDatatype_t HipDataTypeFor<c10::Float8_e4m3fnuz>() {
return HIP_R_8F_E4M3_FNUZ;
}
template <>
constexpr hipblasDatatype_t HipBlasDataTypeFor<c10::Float8_e5m2fnuz>() {
constexpr hipblasDatatype_t HipDataTypeFor<c10::Float8_e5m2fnuz>() {
return HIP_R_8F_E5M2_FNUZ;
}
@@ -62,6 +62,11 @@ int GetBatchFromParams(const GemmParams<T>* params) {
return 1;
}
template <typename T>
int GetBatchFromParams(const GemmAndBiasParams<T>* params) {
return 1;
}
template <typename T>
int GetBatchFromParams(const GemmStridedBatchedParams<T>* params) {
return params->batch;
@@ -77,6 +82,11 @@ int GetStrideAFromParams(const GemmParams<T>* params) {
return 1;
}
template <typename T>
int GetStrideAFromParams(const GemmAndBiasParams<T>* params) {
return 1;
}
template <typename T>
int GetStrideAFromParams(const GemmStridedBatchedParams<T>* params) {
return params->stride_a;
@@ -92,6 +102,11 @@ int GetStrideBFromParams(const GemmParams<T>* params) {
return 1;
}
template <typename T>
int GetStrideBFromParams(const GemmAndBiasParams<T>* params) {
return 1;
}
template <typename T>
int GetStrideBFromParams(const GemmStridedBatchedParams<T>* params) {
return params->stride_b;
@@ -107,6 +122,11 @@ int GetStrideCFromParams(const GemmParams<T>* params) {
return 1;
}
template <typename T>
int GetStrideCFromParams(const GemmAndBiasParams<T>* params) {
return 1;
}
template <typename T>
int GetStrideCFromParams(const GemmStridedBatchedParams<T>* params) {
return params->stride_c;
@@ -122,6 +142,11 @@ float GetAlphaFromParams(const GemmParams<T>* params) {
return params->alpha;
}
template <typename T>
float GetAlphaFromParams(const GemmAndBiasParams<T>* params) {
return params->alpha;
}
template <typename T>
float GetAlphaFromParams(const GemmStridedBatchedParams<T>* params) {
return params->alpha;
@@ -137,6 +162,11 @@ float GetBetaFromParams(const GemmParams<T>* params) {
return params->beta;
}
template <typename T>
float GetBetaFromParams(const GemmAndBiasParams<T>* params) {
return 0.0;
}
template <typename T>
float GetBetaFromParams(const GemmStridedBatchedParams<T>* params) {
return params->beta;
@@ -152,6 +182,11 @@ const void* GetAScalePointerFromParams(const GemmParams<T>* params) {
return nullptr;
}
template <typename T>
const void* GetAScalePointerFromParams(const GemmAndBiasParams<T>* params) {
return nullptr;
}
template <typename T>
const void* GetAScalePointerFromParams(const GemmStridedBatchedParams<T>* params) {
return nullptr;
@@ -167,6 +202,11 @@ const void* GetBScalePointerFromParams(const GemmParams<T>* params) {
return nullptr;
}
template <typename T>
const void* GetBScalePointerFromParams(const GemmAndBiasParams<T>* params) {
return nullptr;
}
template <typename T>
const void* GetBScalePointerFromParams(const GemmStridedBatchedParams<T>* params) {
return nullptr;
@@ -182,6 +222,11 @@ const void* GetDScalePointerFromParams(const GemmParams<T>* params) {
return nullptr;
}
template <typename T>
const void* GetDScalePointerFromParams(const GemmAndBiasParams<T>* params) {
return nullptr;
}
template <typename T>
const void* GetDScalePointerFromParams(const GemmStridedBatchedParams<T>* params) {
return nullptr;
@@ -197,6 +242,11 @@ const void* GetBiasPointerFromParams(const GemmParams<T>* params) {
return nullptr;
}
template <typename T>
const void* GetBiasPointerFromParams(const GemmAndBiasParams<T>* params) {
return params->bias;
}
template <typename T>
const void* GetBiasPointerFromParams(const GemmStridedBatchedParams<T>* params) {
return nullptr;
@@ -212,6 +262,11 @@ hipDataType GetBiasTypeFromParams(const GemmParams<T>* params) {
return HIP_R_32F;
}
template <typename T>
hipDataType GetBiasTypeFromParams(const GemmAndBiasParams<T>* params) {
return HipDataTypeFor<T>();
}
template <typename T>
hipDataType GetBiasTypeFromParams(const GemmStridedBatchedParams<T>* params) {
return HIP_R_32F;
@@ -222,6 +277,26 @@ hipDataType GetBiasTypeFromParams(const ScaledGemmParams<T>* params) {
return at::cuda::ScalarTypeToCudaDataType(params->bias_dtype);
}
template <typename T>
at::cuda::blas::GEMMAndBiasActivationEpilogue GetActivationFromParams(const GemmParams<T>* params) {
return at::cuda::blas::GEMMAndBiasActivationEpilogue::None;
}
template <typename T>
at::cuda::blas::GEMMAndBiasActivationEpilogue GetActivationFromParams(const GemmAndBiasParams<T>* params) {
return params->activation;
}
template <typename T>
at::cuda::blas::GEMMAndBiasActivationEpilogue GetActivationFromParams(const GemmStridedBatchedParams<T>* params) {
return at::cuda::blas::GEMMAndBiasActivationEpilogue::None;
}
template <typename T>
at::cuda::blas::GEMMAndBiasActivationEpilogue GetActivationFromParams(const ScaledGemmParams<T>* params) {
return at::cuda::blas::GEMMAndBiasActivationEpilogue::None;
}
static hipblasOperation_t _hipblasOpFromChar(char op) {
switch (op) {
case 'n':
@@ -327,9 +402,9 @@ class HipblasltGemmOp : public Callable<ParamsT> {
TuningStatus Call(const ParamsT* params) override {
hipblasOperation_t transa_outer = MapLayoutToHipBlasLt(ALayout);
hipblasOperation_t transb_outer = MapLayoutToHipBlasLt(BLayout);
auto a_datatype = HipBlasDataTypeFor<AT>();
auto b_datatype = HipBlasDataTypeFor<BT>();
auto in_out_datatype = HipBlasDataTypeFor<CT>();
auto a_datatype = HipDataTypeFor<AT>();
auto b_datatype = HipDataTypeFor<BT>();
auto in_out_datatype = HipDataTypeFor<CT>();
auto opa = _hipblasOpFromChar(params->transa);
auto opb = _hipblasOpFromChar(params->transb);
@@ -381,17 +456,28 @@ class HipblasltGemmOp : public Callable<ParamsT> {
const void* mat1_scale_ptr = GetAScalePointerFromParams<CT>(params);
const void* mat2_scale_ptr = GetBScalePointerFromParams<CT>(params);
const void* result_scale_ptr = GetDScalePointerFromParams<CT>(params);
if (mat1_scale_ptr && mat2_scale_ptr && result_scale_ptr) {
if (mat1_scale_ptr && mat2_scale_ptr) {
matmul.setAttribute(HIPBLASLT_MATMUL_DESC_A_SCALE_POINTER, mat1_scale_ptr);
matmul.setAttribute(HIPBLASLT_MATMUL_DESC_B_SCALE_POINTER, mat2_scale_ptr);
}
if (result_scale_ptr) {
matmul.setAttribute(HIPBLASLT_MATMUL_DESC_D_SCALE_POINTER, result_scale_ptr);
}
const void* bias_ptr = GetBiasPointerFromParams<CT>(params);
auto bias_datatype = GetBiasTypeFromParams<CT>(params);
if (bias_ptr) {
matmul.setAttribute(HIPBLASLT_MATMUL_DESC_BIAS_POINTER, bias_ptr);
const void* bias_ptr = GetBiasPointerFromParams<CT>(params);
auto bias_datatype = GetBiasTypeFromParams<CT>(params);
if (bias_ptr) {
matmul.setAttribute(HIPBLASLT_MATMUL_DESC_BIAS_POINTER, bias_ptr);
matmul.setAttribute(HIPBLASLT_MATMUL_DESC_BIAS_DATA_TYPE, bias_datatype);
auto activation = GetActivationFromParams<CT>(params);
if (activation == at::cuda::blas::GEMMAndBiasActivationEpilogue::RELU) {
matmul.setAttribute(HIPBLASLT_MATMUL_DESC_EPILOGUE, HIPBLASLT_EPILOGUE_RELU_BIAS);
}
else if (activation == at::cuda::blas::GEMMAndBiasActivationEpilogue::GELU) {
matmul.setAttribute(HIPBLASLT_MATMUL_DESC_EPILOGUE, HIPBLASLT_EPILOGUE_GELU_BIAS);
}
else {
matmul.setAttribute(HIPBLASLT_MATMUL_DESC_EPILOGUE, HIPBLASLT_EPILOGUE_BIAS);
matmul.setAttribute(HIPBLASLT_MATMUL_DESC_BIAS_DATA_TYPE, bias_datatype);
}
}
@@ -460,9 +546,9 @@ template <typename AT, typename BT, typename CT, BlasOp ALayout, BlasOp BLayout,
auto GetHipBlasLtTypeStringAndOps() {
hipblasOperation_t transa_outer = MapLayoutToHipBlasLt(ALayout);
hipblasOperation_t transb_outer = MapLayoutToHipBlasLt(BLayout);
auto a_datatype = HipBlasDataTypeFor<AT>();
auto b_datatype = HipBlasDataTypeFor<BT>();
auto in_out_datatype = HipBlasDataTypeFor<CT>();
auto a_datatype = HipDataTypeFor<AT>();
auto b_datatype = HipDataTypeFor<BT>();
auto in_out_datatype = HipDataTypeFor<CT>();
std::vector<hipblasLtMatmulHeuristicResult_t> heuristic_result;
hipblasLtHandle_t handle;
@@ -505,6 +591,11 @@ auto GetHipBlasLtGemmTypeStringAndOps() {
return GetHipBlasLtTypeStringAndOps<T, T, T, ALayout, BLayout, GemmParams<T>>();
}
template <typename T, BlasOp ALayout, BlasOp BLayout>
auto GetHipBlasLtGemmAndBiasTypeStringAndOps() {
return GetHipBlasLtTypeStringAndOps<T, T, T, ALayout, BLayout, GemmAndBiasParams<T>>();
}
template <typename T, BlasOp ALayout, BlasOp BLayout>
auto GetHipBlasLtGemmStridedBatchedTypeStringAndOps() {
return GetHipBlasLtTypeStringAndOps<T, T, T, ALayout, BLayout, GemmStridedBatchedParams<T>>();

View File

@@ -41,6 +41,28 @@ class DefaultGemmOp : public Callable<GemmParams<T>> {
}
};
static bool _transposeBoolFromChar(char op) {
return op == 't' || op == 'T';
}
template <typename T>
class DefaultGemmAndBiasOp : public Callable<GemmAndBiasParams<T>> {
public:
TuningStatus Call(const GemmAndBiasParams<T>* params) override {
at::cuda::blas::gemm_and_bias<T>(
_transposeBoolFromChar(params->transa),
_transposeBoolFromChar(params->transb),
params->m, params->n, params->k,
params->alpha,
params->a, params->lda,
params->b, params->ldb,
params->bias,
params->c, params->ldc,
params->activation);
return OK;
}
};
template <typename T>
class DefaultGemmStridedBatchedOp : public Callable<GemmStridedBatchedParams<T>> {
public:
@@ -201,6 +223,32 @@ class GemmTunableOp : public TunableOp<GemmParams<T>, StreamTimer> {
}
};
template <typename T, BlasOp ALayout, BlasOp BLayout>
class GemmAndBiasTunableOp : public TunableOp<GemmAndBiasParams<T>, StreamTimer> {
public:
GemmAndBiasTunableOp() {
this->RegisterOp(std::string("Default"), std::make_unique<DefaultGemmAndBiasOp<T>>());
#ifdef USE_ROCM
static const char *env_hipblaslt = std::getenv("PYTORCH_TUNABLEOP_HIPBLASLT_ENABLED");
if (env_hipblaslt == nullptr || strcmp(env_hipblaslt, "1") == 0) {
// disallow tuning of hipblaslt with c10::complex
if constexpr (
!std::is_same_v<T, c10::complex<float>> &&
!std::is_same_v<T, c10::complex<double>>) {
for (auto&& [name, op] : GetHipBlasLtGemmAndBiasTypeStringAndOps<T, ALayout, BLayout>()) {
this->RegisterOp(std::move(name), std::move(op));
}
}
}
#endif
}
std::string Signature() override {
return c10::str("GemmAndBiasTunableOp_", TypeName<T>(T{}), "_", BlasOpToString(ALayout), BlasOpToString(BLayout));
}
};
template <typename T, BlasOp ALayout, BlasOp BLayout>
class GemmStridedBatchedTunableOp : public TunableOp<GemmStridedBatchedParams<T>, StreamTimer> {
public:
@@ -240,7 +288,7 @@ class ScaledGemmTunableOp : public TunableOp<ScaledGemmParams<CT>, StreamTimer>
ScaledGemmTunableOp() {
this->RegisterOp(std::string("Default"), std::make_unique<DefaultScaledGemmOp<CT>>());
#if defined(USE_ROCM)
#ifdef USE_ROCM
for (auto&& [name, op] : GetHipBlasLtScaledGemmTypeStringAndOps<AT, BT, CT, ALayout, BLayout>()) {
this->RegisterOp(std::move(name), std::move(op));
}

View File

@@ -180,12 +180,6 @@ cuda::blas::GEMMAndBiasActivationEpilogue activation_to_gemm_and_blas_arg(Activa
static bool getDisableAddmmCudaLt() {
static const char* env_value = std::getenv("DISABLE_ADDMM_CUDA_LT");
#ifdef USE_ROCM
// if we enable tunable op, it'll take priority over just hipblaslt (heuristics)
// note the current tunable op is not the hipblaslt path (gemm_and_bias)
auto tuning_ctx = at::cuda::tunable::getTuningContext();
if (tuning_ctx->IsTunableOpEnabled()) {
return true;
}
// allow both CUDA and HIP env var names for ROCm builds
// also, current default for ROCm builds is disable by default
if (env_value == nullptr) {
@@ -219,6 +213,46 @@ static bool isSupportedHipLtROCmArch(int index) {
}
#endif
template <typename scalar_t>
static void launchTunableGemmAndBias(cublasCommonArgs &args, const Scalar& alpha, const scalar_t* bias, cuda::blas::GEMMAndBiasActivationEpilogue activation) {
bool transa_ = ((args.transa != 'n') && (args.transa != 'N'));
bool transb_ = ((args.transb != 'n') && (args.transb != 'N'));
at::cuda::tunable::GemmAndBiasParams<scalar_t> params;
params.transa = args.transa;
params.transb = args.transb;
params.m = args.m;
params.n = args.n;
params.k = args.k;
params.alpha = alpha.to<at::opmath_type<scalar_t>>();
params.a = args.mata->const_data_ptr<scalar_t>();
params.lda = args.lda;
params.b = args.matb->const_data_ptr<scalar_t>();
params.ldb = args.ldb;
params.c = args.result->data_ptr<scalar_t>();
params.ldc = args.result_ld;
params.bias = bias;
params.activation = activation;
if (transa_ && transb_) {
static at::cuda::tunable::GemmAndBiasTunableOp<scalar_t, at::cuda::tunable::BlasOp::T, at::cuda::tunable::BlasOp::T> gemm{};
gemm(&params);
}
else if (transa_ && !transb_) {
static at::cuda::tunable::GemmAndBiasTunableOp<scalar_t, at::cuda::tunable::BlasOp::T, at::cuda::tunable::BlasOp::N> gemm{};
gemm(&params);
}
else if (!transa_ && transb_) {
static at::cuda::tunable::GemmAndBiasTunableOp<scalar_t, at::cuda::tunable::BlasOp::N, at::cuda::tunable::BlasOp::T> gemm{};
gemm(&params);
}
else if (!transa_ && !transb_) {
static at::cuda::tunable::GemmAndBiasTunableOp<scalar_t, at::cuda::tunable::BlasOp::N, at::cuda::tunable::BlasOp::N> gemm{};
gemm(&params);
}
else {
TORCH_CHECK(false, "unreachable");
}
}
Tensor& addmm_out_cuda_impl(Tensor& result, const Tensor& self, const Tensor& mat1, const Tensor& mat2, const Scalar& beta, const Scalar& alpha, Activation activation=Activation::None) {
// Make sure to keep addmm_cuda below in sync with this code; it
// preflights a check to try to avoid actually needing to call
@@ -346,6 +380,15 @@ Tensor& addmm_out_cuda_impl(Tensor& result, const Tensor& self, const Tensor& ma
scalar_type,
"addmm_cuda_lt",
[&] {
auto tuning_ctx = at::cuda::tunable::getTuningContext();
if (tuning_ctx->IsTunableOpEnabled()) {
launchTunableGemmAndBias<scalar_t>(
args,
alpha,
(&result != &self) ? self.const_data_ptr<scalar_t>() : nullptr,
activation_to_gemm_and_blas_arg(activation));
}
else {
at::cuda::blas::gemm_and_bias<scalar_t>(
args.transa == 't',
args.transb == 't',
@@ -364,7 +407,7 @@ Tensor& addmm_out_cuda_impl(Tensor& result, const Tensor& self, const Tensor& ma
args.result_ld,
activation_to_gemm_and_blas_arg(activation)
);
});
}});
#else
auto activation_epilogue = activation_to_gemm_and_blas_arg(activation);
#if (defined(CUDA_VERSION) && (CUDA_VERSION < 11080))
@@ -382,6 +425,15 @@ Tensor& addmm_out_cuda_impl(Tensor& result, const Tensor& self, const Tensor& ma
scalar_type,
"addmm_cuda_lt",
[&] {
auto tuning_ctx = at::cuda::tunable::getTuningContext();
if (tuning_ctx->IsTunableOpEnabled()) {
launchTunableGemmAndBias<scalar_t>(
args,
alpha,
self.const_data_ptr<scalar_t>(),
activation_epilogue);
}
else {
at::cuda::blas::gemm_and_bias<scalar_t>(
args.transa == 't',
args.transb == 't',
@@ -398,7 +450,7 @@ Tensor& addmm_out_cuda_impl(Tensor& result, const Tensor& self, const Tensor& ma
args.result_ld,
activation_epilogue
);
});
}});
#endif
} else
{

View File

@@ -6037,6 +6037,33 @@ scipy_lobpcg | {eq_err_scipy:10.2e} | {eq_err_general_scipy:10.2e} | {iters2:
def test_addmm_relu(self, device, dtype):
self._test_addmm_impl(torch._addmm_activation, "relu", device, dtype)
@onlyCUDA
@skipCUDAIfNotRocm
@precisionOverride({torch.double: 1e-8, torch.float: 1e-4, torch.bfloat16: 5e-2,
torch.half: 5e-2, torch.cfloat: 1e-4, torch.cdouble: 1e-8})
@dtypesIfCUDA(*floating_types_and(
*[torch.bfloat16, torch.half] if TEST_WITH_ROCM or SM53OrLater else []))
@dtypes(*floating_types_and(torch.bfloat16))
@tf32_on_and_off(0.05)
@bf32_on_and_off(0.05)
def test_addmm_relu_tunableop_rocm(self, device, dtype):
torch.cuda.tunable.enable(True)
ordinal = torch.cuda.current_device()
filename = f"tunableop_results{ordinal}.csv"
torch.cuda.tunable.set_filename(filename)
iterations = torch.cuda.tunable.get_max_tuning_iterations()
torch.cuda.tunable.set_max_tuning_iterations(10)
self._test_addmm_impl(torch._addmm_activation, "relu", device, dtype)
# clean up, remove any file that was generated
try:
import os
os.remove(filename)
except FileNotFoundError:
pass
# reset back to prior settings
torch.cuda.tunable.set_max_tuning_iterations(iterations)
torch.cuda.tunable.enable(False)
@precisionOverride({torch.double: 1e-8, torch.float: 1e-4, torch.bfloat16: 5e-2,
torch.half: 5e-2, torch.cfloat: 1e-4, torch.cdouble: 1e-8})
@dtypesIfCUDA(*floating_types_and(