Mirror of https://github.com/pytorch/pytorch.git (synced 2025-10-20 12:54:11 +08:00)
[reland][ROCm] TunableOp for gemm_and_bias (#128919)
Reland of #128143, with `alpha` and `bias` initialization added to `launchTunableGemmAndBias`. Thus far, TunableOp was implemented for gemm, bgemm, and scaled_mm; gemm_and_bias was notably missing. This PR closes that gap.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/128919
Approved by: https://github.com/malfet
Committed by: PyTorch MergeBot
Parent: 978c5a80a0
Commit: 0eb9c870fd
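As a usage sketch (not part of the diff): on a ROCm build, the new gemm_and_bias tuning path is reached through the ordinary TunableOp controls. The Python below mirrors the `test_addmm_relu_tunableop_rocm` test added by this PR; the tensor shapes, dtype, and result filename are illustrative assumptions.

```python
import torch

# Turn on TunableOp; a biased addmm/linear on ROCm may now be tuned through
# GemmAndBiasTunableOp instead of only the plain gemm path.
torch.cuda.tunable.enable(True)
ordinal = torch.cuda.current_device()
torch.cuda.tunable.set_filename(f"tunableop_results{ordinal}.csv")
prev_iters = torch.cuda.tunable.get_max_tuning_iterations()
torch.cuda.tunable.set_max_tuning_iterations(10)  # keep tuning time short

x = torch.randn(64, 128, device="cuda", dtype=torch.half)
w = torch.randn(256, 128, device="cuda", dtype=torch.half)
b = torch.randn(256, device="cuda", dtype=torch.half)
y = torch.nn.functional.linear(x, w, b)  # bias epilogue -> may hit the gemm_and_bias path

# restore prior settings when done
torch.cuda.tunable.set_max_tuning_iterations(prev_iters)
torch.cuda.tunable.enable(False)
```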
@@ -155,6 +155,72 @@ private:
   bool duplicate_inputs_;
 };

+template <typename T>
+struct GemmAndBiasParams : OpParams {
+  std::string Signature() const override {
+    return c10::str(transa, transb, "_", m, "_", n, "_", k);
+  }
+
+  size_t GetSize(bool duplicate_inputs) const {
+    size_t size = sizeof(T) * ldc * n;
+    if (duplicate_inputs) {
+      size += sizeof(T) * lda * ((transa == 'n' || transa == 'N') ? k : m);
+      size += sizeof(T) * ldb * ((transb == 'n' || transb == 'N') ? n : k);
+    }
+    return size;
+  }
+
+  GemmAndBiasParams* DeepCopy(bool duplicate_inputs) const {
+    GemmAndBiasParams* copy = new GemmAndBiasParams;
+    *copy = *this;
+    c10::DeviceIndex device = 0;
+    AT_CUDA_CHECK(c10::cuda::GetDevice(&device));
+    size_t c_size = ldc * n * sizeof(T);
+    copy->c = static_cast<T*>(c10::cuda::CUDACachingAllocator::raw_alloc(c_size));
+    AT_CUDA_CHECK(c10::cuda::CUDACachingAllocator::memcpyAsync(
+        copy->c, device, c, device, c_size, getCurrentCUDAStream(device), true));
+    if (duplicate_inputs) {
+      size_t a_size = sizeof(T) * lda * ((transa == 'n' || transa == 'N') ? k : m);
+      size_t b_size = sizeof(T) * ldb * ((transb == 'n' || transb == 'N') ? n : k);
+      copy->a = static_cast<const T*>(c10::cuda::CUDACachingAllocator::raw_alloc(a_size));
+      copy->b = static_cast<const T*>(c10::cuda::CUDACachingAllocator::raw_alloc(b_size));
+      copy->duplicate_inputs_ = true;
+    }
+    return copy;
+  }
+
+  // only call on object returned by DeepCopy
+  void Delete() {
+    c10::cuda::CUDACachingAllocator::raw_delete(c);
+    if (duplicate_inputs_) {
+      c10::cuda::CUDACachingAllocator::raw_delete(const_cast<T*>(a));
+      c10::cuda::CUDACachingAllocator::raw_delete(const_cast<T*>(b));
+    }
+  }
+
+  TuningStatus NumericalCheck(GemmAndBiasParams<T> *other) {
+    auto c_dtype = c10::CppTypeToScalarType<T>::value;
+    return detail::NumericalCheck(c_dtype, c, other->c, ldc*n) ? OK : FAIL;
+  }
+
+  char transa;
+  char transb;
+  int64_t m;
+  int64_t n;
+  int64_t k;
+  at::opmath_type<T> alpha;
+  const T* a;
+  int64_t lda;
+  const T* b;
+  int64_t ldb;
+  T* c;
+  int64_t ldc;
+  const T* bias;
+  at::cuda::blas::GEMMAndBiasActivationEpilogue activation;
+private:
+  bool duplicate_inputs_;
+};
+
 template <typename T>
 struct GemmStridedBatchedParams : OpParams {
   GemmStridedBatchedParams() {
@@ -25,35 +25,35 @@
 namespace at::cuda::tunable {

 template <typename T>
-constexpr hipblasDatatype_t HipBlasDataTypeFor();
+constexpr hipblasDatatype_t HipDataTypeFor();

 template <>
-constexpr hipblasDatatype_t HipBlasDataTypeFor<float>() {
-  return HIPBLAS_R_32F;
+constexpr hipblasDatatype_t HipDataTypeFor<float>() {
+  return HIP_R_32F;
 }

 template <>
-constexpr hipblasDatatype_t HipBlasDataTypeFor<Half>() {
-  return HIPBLAS_R_16F;
+constexpr hipblasDatatype_t HipDataTypeFor<Half>() {
+  return HIP_R_16F;
 }

 template <>
-constexpr hipblasDatatype_t HipBlasDataTypeFor<BFloat16>() {
-  return HIPBLAS_R_16B;
+constexpr hipblasDatatype_t HipDataTypeFor<BFloat16>() {
+  return HIP_R_16BF;
 }

 template <>
-constexpr hipblasDatatype_t HipBlasDataTypeFor<double>() {
-  return HIPBLAS_R_64F;
+constexpr hipblasDatatype_t HipDataTypeFor<double>() {
+  return HIP_R_64F;
 }

 template <>
-constexpr hipblasDatatype_t HipBlasDataTypeFor<c10::Float8_e4m3fnuz>() {
+constexpr hipblasDatatype_t HipDataTypeFor<c10::Float8_e4m3fnuz>() {
   return HIP_R_8F_E4M3_FNUZ;
 }

 template <>
-constexpr hipblasDatatype_t HipBlasDataTypeFor<c10::Float8_e5m2fnuz>() {
+constexpr hipblasDatatype_t HipDataTypeFor<c10::Float8_e5m2fnuz>() {
   return HIP_R_8F_E5M2_FNUZ;
 }

@@ -62,6 +62,11 @@ int GetBatchFromParams(const GemmParams<T>* params) {
   return 1;
 }

+template <typename T>
+int GetBatchFromParams(const GemmAndBiasParams<T>* params) {
+  return 1;
+}
+
 template <typename T>
 int GetBatchFromParams(const GemmStridedBatchedParams<T>* params) {
   return params->batch;
@@ -77,6 +82,11 @@ int GetStrideAFromParams(const GemmParams<T>* params) {
   return 1;
 }

+template <typename T>
+int GetStrideAFromParams(const GemmAndBiasParams<T>* params) {
+  return 1;
+}
+
 template <typename T>
 int GetStrideAFromParams(const GemmStridedBatchedParams<T>* params) {
   return params->stride_a;
@@ -92,6 +102,11 @@ int GetStrideBFromParams(const GemmParams<T>* params) {
   return 1;
 }

+template <typename T>
+int GetStrideBFromParams(const GemmAndBiasParams<T>* params) {
+  return 1;
+}
+
 template <typename T>
 int GetStrideBFromParams(const GemmStridedBatchedParams<T>* params) {
   return params->stride_b;
@@ -107,6 +122,11 @@ int GetStrideCFromParams(const GemmParams<T>* params) {
   return 1;
 }

+template <typename T>
+int GetStrideCFromParams(const GemmAndBiasParams<T>* params) {
+  return 1;
+}
+
 template <typename T>
 int GetStrideCFromParams(const GemmStridedBatchedParams<T>* params) {
   return params->stride_c;
@@ -122,6 +142,11 @@ float GetAlphaFromParams(const GemmParams<T>* params) {
   return params->alpha;
 }

+template <typename T>
+float GetAlphaFromParams(const GemmAndBiasParams<T>* params) {
+  return params->alpha;
+}
+
 template <typename T>
 float GetAlphaFromParams(const GemmStridedBatchedParams<T>* params) {
   return params->alpha;
@@ -137,6 +162,11 @@ float GetBetaFromParams(const GemmParams<T>* params) {
   return params->beta;
 }

+template <typename T>
+float GetBetaFromParams(const GemmAndBiasParams<T>* params) {
+  return 0.0;
+}
+
 template <typename T>
 float GetBetaFromParams(const GemmStridedBatchedParams<T>* params) {
   return params->beta;
@@ -152,6 +182,11 @@ const void* GetAScalePointerFromParams(const GemmParams<T>* params) {
   return nullptr;
 }

+template <typename T>
+const void* GetAScalePointerFromParams(const GemmAndBiasParams<T>* params) {
+  return nullptr;
+}
+
 template <typename T>
 const void* GetAScalePointerFromParams(const GemmStridedBatchedParams<T>* params) {
   return nullptr;
@@ -167,6 +202,11 @@ const void* GetBScalePointerFromParams(const GemmParams<T>* params) {
   return nullptr;
 }

+template <typename T>
+const void* GetBScalePointerFromParams(const GemmAndBiasParams<T>* params) {
+  return nullptr;
+}
+
 template <typename T>
 const void* GetBScalePointerFromParams(const GemmStridedBatchedParams<T>* params) {
   return nullptr;
@@ -182,6 +222,11 @@ const void* GetDScalePointerFromParams(const GemmParams<T>* params) {
   return nullptr;
 }

+template <typename T>
+const void* GetDScalePointerFromParams(const GemmAndBiasParams<T>* params) {
+  return nullptr;
+}
+
 template <typename T>
 const void* GetDScalePointerFromParams(const GemmStridedBatchedParams<T>* params) {
   return nullptr;
@@ -197,6 +242,11 @@ const void* GetBiasPointerFromParams(const GemmParams<T>* params) {
   return nullptr;
 }

+template <typename T>
+const void* GetBiasPointerFromParams(const GemmAndBiasParams<T>* params) {
+  return params->bias;
+}
+
 template <typename T>
 const void* GetBiasPointerFromParams(const GemmStridedBatchedParams<T>* params) {
   return nullptr;
@@ -212,6 +262,11 @@ hipDataType GetBiasTypeFromParams(const GemmParams<T>* params) {
   return HIP_R_32F;
 }

+template <typename T>
+hipDataType GetBiasTypeFromParams(const GemmAndBiasParams<T>* params) {
+  return HipDataTypeFor<T>();
+}
+
 template <typename T>
 hipDataType GetBiasTypeFromParams(const GemmStridedBatchedParams<T>* params) {
   return HIP_R_32F;
@@ -222,6 +277,26 @@ hipDataType GetBiasTypeFromParams(const ScaledGemmParams<T>* params) {
   return at::cuda::ScalarTypeToCudaDataType(params->bias_dtype);
 }

+template <typename T>
+at::cuda::blas::GEMMAndBiasActivationEpilogue GetActivationFromParams(const GemmParams<T>* params) {
+  return at::cuda::blas::GEMMAndBiasActivationEpilogue::None;
+}
+
+template <typename T>
+at::cuda::blas::GEMMAndBiasActivationEpilogue GetActivationFromParams(const GemmAndBiasParams<T>* params) {
+  return params->activation;
+}
+
+template <typename T>
+at::cuda::blas::GEMMAndBiasActivationEpilogue GetActivationFromParams(const GemmStridedBatchedParams<T>* params) {
+  return at::cuda::blas::GEMMAndBiasActivationEpilogue::None;
+}
+
+template <typename T>
+at::cuda::blas::GEMMAndBiasActivationEpilogue GetActivationFromParams(const ScaledGemmParams<T>* params) {
+  return at::cuda::blas::GEMMAndBiasActivationEpilogue::None;
+}
+
 static hipblasOperation_t _hipblasOpFromChar(char op) {
   switch (op) {
     case 'n':
@@ -327,9 +402,9 @@ class HipblasltGemmOp : public Callable<ParamsT> {
   TuningStatus Call(const ParamsT* params) override {
     hipblasOperation_t transa_outer = MapLayoutToHipBlasLt(ALayout);
     hipblasOperation_t transb_outer = MapLayoutToHipBlasLt(BLayout);
-    auto a_datatype = HipBlasDataTypeFor<AT>();
-    auto b_datatype = HipBlasDataTypeFor<BT>();
-    auto in_out_datatype = HipBlasDataTypeFor<CT>();
+    auto a_datatype = HipDataTypeFor<AT>();
+    auto b_datatype = HipDataTypeFor<BT>();
+    auto in_out_datatype = HipDataTypeFor<CT>();
     auto opa = _hipblasOpFromChar(params->transa);
     auto opb = _hipblasOpFromChar(params->transb);

@@ -381,17 +456,28 @@ class HipblasltGemmOp : public Callable<ParamsT> {
       const void* mat1_scale_ptr = GetAScalePointerFromParams<CT>(params);
       const void* mat2_scale_ptr = GetBScalePointerFromParams<CT>(params);
       const void* result_scale_ptr = GetDScalePointerFromParams<CT>(params);
-      if (mat1_scale_ptr && mat2_scale_ptr && result_scale_ptr) {
+      if (mat1_scale_ptr && mat2_scale_ptr) {
         matmul.setAttribute(HIPBLASLT_MATMUL_DESC_A_SCALE_POINTER, mat1_scale_ptr);
         matmul.setAttribute(HIPBLASLT_MATMUL_DESC_B_SCALE_POINTER, mat2_scale_ptr);
+      }
+      if (result_scale_ptr) {
         matmul.setAttribute(HIPBLASLT_MATMUL_DESC_D_SCALE_POINTER, result_scale_ptr);
+      }

-        const void* bias_ptr = GetBiasPointerFromParams<CT>(params);
-        auto bias_datatype = GetBiasTypeFromParams<CT>(params);
-        if (bias_ptr) {
-          matmul.setAttribute(HIPBLASLT_MATMUL_DESC_BIAS_POINTER, bias_ptr);
+      const void* bias_ptr = GetBiasPointerFromParams<CT>(params);
+      auto bias_datatype = GetBiasTypeFromParams<CT>(params);
+      if (bias_ptr) {
+        matmul.setAttribute(HIPBLASLT_MATMUL_DESC_BIAS_POINTER, bias_ptr);
+        matmul.setAttribute(HIPBLASLT_MATMUL_DESC_BIAS_DATA_TYPE, bias_datatype);
+        auto activation = GetActivationFromParams<CT>(params);
+        if (activation == at::cuda::blas::GEMMAndBiasActivationEpilogue::RELU) {
+          matmul.setAttribute(HIPBLASLT_MATMUL_DESC_EPILOGUE, HIPBLASLT_EPILOGUE_RELU_BIAS);
+        }
+        else if (activation == at::cuda::blas::GEMMAndBiasActivationEpilogue::GELU) {
+          matmul.setAttribute(HIPBLASLT_MATMUL_DESC_EPILOGUE, HIPBLASLT_EPILOGUE_GELU_BIAS);
+        }
+        else {
           matmul.setAttribute(HIPBLASLT_MATMUL_DESC_EPILOGUE, HIPBLASLT_EPILOGUE_BIAS);
-          matmul.setAttribute(HIPBLASLT_MATMUL_DESC_BIAS_DATA_TYPE, bias_datatype);
         }
       }

@@ -460,9 +546,9 @@ template <typename AT, typename BT, typename CT, BlasOp ALayout, BlasOp BLayout,
 auto GetHipBlasLtTypeStringAndOps() {
   hipblasOperation_t transa_outer = MapLayoutToHipBlasLt(ALayout);
   hipblasOperation_t transb_outer = MapLayoutToHipBlasLt(BLayout);
-  auto a_datatype = HipBlasDataTypeFor<AT>();
-  auto b_datatype = HipBlasDataTypeFor<BT>();
-  auto in_out_datatype = HipBlasDataTypeFor<CT>();
+  auto a_datatype = HipDataTypeFor<AT>();
+  auto b_datatype = HipDataTypeFor<BT>();
+  auto in_out_datatype = HipDataTypeFor<CT>();
   std::vector<hipblasLtMatmulHeuristicResult_t> heuristic_result;

   hipblasLtHandle_t handle;
@@ -505,6 +591,11 @@ auto GetHipBlasLtGemmTypeStringAndOps() {
   return GetHipBlasLtTypeStringAndOps<T, T, T, ALayout, BLayout, GemmParams<T>>();
 }

+template <typename T, BlasOp ALayout, BlasOp BLayout>
+auto GetHipBlasLtGemmAndBiasTypeStringAndOps() {
+  return GetHipBlasLtTypeStringAndOps<T, T, T, ALayout, BLayout, GemmAndBiasParams<T>>();
+}
+
 template <typename T, BlasOp ALayout, BlasOp BLayout>
 auto GetHipBlasLtGemmStridedBatchedTypeStringAndOps() {
   return GetHipBlasLtTypeStringAndOps<T, T, T, ALayout, BLayout, GemmStridedBatchedParams<T>>();
@@ -41,6 +41,28 @@ class DefaultGemmOp : public Callable<GemmParams<T>> {
     }
 };

+static bool _transposeBoolFromChar(char op) {
+  return op == 't' || op == 'T';
+}
+
+template <typename T>
+class DefaultGemmAndBiasOp : public Callable<GemmAndBiasParams<T>> {
+  public:
+    TuningStatus Call(const GemmAndBiasParams<T>* params) override {
+      at::cuda::blas::gemm_and_bias<T>(
+          _transposeBoolFromChar(params->transa),
+          _transposeBoolFromChar(params->transb),
+          params->m, params->n, params->k,
+          params->alpha,
+          params->a, params->lda,
+          params->b, params->ldb,
+          params->bias,
+          params->c, params->ldc,
+          params->activation);
+      return OK;
+    }
+};
+
 template <typename T>
 class DefaultGemmStridedBatchedOp : public Callable<GemmStridedBatchedParams<T>> {
   public:
@@ -201,6 +223,32 @@ class GemmTunableOp : public TunableOp<GemmParams<T>, StreamTimer> {
   }
 };

+template <typename T, BlasOp ALayout, BlasOp BLayout>
+class GemmAndBiasTunableOp : public TunableOp<GemmAndBiasParams<T>, StreamTimer> {
+ public:
+  GemmAndBiasTunableOp() {
+    this->RegisterOp(std::string("Default"), std::make_unique<DefaultGemmAndBiasOp<T>>());
+
+#ifdef USE_ROCM
+    static const char *env_hipblaslt = std::getenv("PYTORCH_TUNABLEOP_HIPBLASLT_ENABLED");
+    if (env_hipblaslt == nullptr || strcmp(env_hipblaslt, "1") == 0) {
+      // disallow tuning of hipblaslt with c10::complex
+      if constexpr (
+          !std::is_same_v<T, c10::complex<float>> &&
+          !std::is_same_v<T, c10::complex<double>>) {
+        for (auto&& [name, op] : GetHipBlasLtGemmAndBiasTypeStringAndOps<T, ALayout, BLayout>()) {
+          this->RegisterOp(std::move(name), std::move(op));
+        }
+      }
+    }
+#endif
+  }
+
+  std::string Signature() override {
+    return c10::str("GemmAndBiasTunableOp_", TypeName<T>(T{}), "_", BlasOpToString(ALayout), BlasOpToString(BLayout));
+  }
+};
+
 template <typename T, BlasOp ALayout, BlasOp BLayout>
 class GemmStridedBatchedTunableOp : public TunableOp<GemmStridedBatchedParams<T>, StreamTimer> {
  public:
@@ -240,7 +288,7 @@ class ScaledGemmTunableOp : public TunableOp<ScaledGemmParams<CT>, StreamTimer>
   ScaledGemmTunableOp() {
     this->RegisterOp(std::string("Default"), std::make_unique<DefaultScaledGemmOp<CT>>());

-#if defined(USE_ROCM)
+#ifdef USE_ROCM
     for (auto&& [name, op] : GetHipBlasLtScaledGemmTypeStringAndOps<AT, BT, CT, ALayout, BLayout>()) {
       this->RegisterOp(std::move(name), std::move(op));
     }
@@ -180,12 +180,6 @@ cuda::blas::GEMMAndBiasActivationEpilogue activation_to_gemm_and_blas_arg(Activa
 static bool getDisableAddmmCudaLt() {
     static const char* env_value = std::getenv("DISABLE_ADDMM_CUDA_LT");
 #ifdef USE_ROCM
-    // if we enable tunable op, it'll take priority over just hipblaslt (heuristics)
-    // note the current tunable op is not the hipblaslt path (gemm_and_bias)
-    auto tuning_ctx = at::cuda::tunable::getTuningContext();
-    if (tuning_ctx->IsTunableOpEnabled()) {
-      return true;
-    }
     // allow both CUDA and HIP env var names for ROCm builds
     // also, current default for ROCm builds is disable by default
     if (env_value == nullptr) {
@@ -219,6 +213,46 @@ static bool isSupportedHipLtROCmArch(int index) {
 }
 #endif

+template <typename scalar_t>
+static void launchTunableGemmAndBias(cublasCommonArgs &args, const Scalar& alpha, const scalar_t* bias, cuda::blas::GEMMAndBiasActivationEpilogue activation) {
+  bool transa_ = ((args.transa != 'n') && (args.transa != 'N'));
+  bool transb_ = ((args.transb != 'n') && (args.transb != 'N'));
+  at::cuda::tunable::GemmAndBiasParams<scalar_t> params;
+  params.transa = args.transa;
+  params.transb = args.transb;
+  params.m = args.m;
+  params.n = args.n;
+  params.k = args.k;
+  params.alpha = alpha.to<at::opmath_type<scalar_t>>();
+  params.a = args.mata->const_data_ptr<scalar_t>();
+  params.lda = args.lda;
+  params.b = args.matb->const_data_ptr<scalar_t>();
+  params.ldb = args.ldb;
+  params.c = args.result->data_ptr<scalar_t>();
+  params.ldc = args.result_ld;
+  params.bias = bias;
+  params.activation = activation;
+  if (transa_ && transb_) {
+    static at::cuda::tunable::GemmAndBiasTunableOp<scalar_t, at::cuda::tunable::BlasOp::T, at::cuda::tunable::BlasOp::T> gemm{};
+    gemm(&params);
+  }
+  else if (transa_ && !transb_) {
+    static at::cuda::tunable::GemmAndBiasTunableOp<scalar_t, at::cuda::tunable::BlasOp::T, at::cuda::tunable::BlasOp::N> gemm{};
+    gemm(&params);
+  }
+  else if (!transa_ && transb_) {
+    static at::cuda::tunable::GemmAndBiasTunableOp<scalar_t, at::cuda::tunable::BlasOp::N, at::cuda::tunable::BlasOp::T> gemm{};
+    gemm(&params);
+  }
+  else if (!transa_ && !transb_) {
+    static at::cuda::tunable::GemmAndBiasTunableOp<scalar_t, at::cuda::tunable::BlasOp::N, at::cuda::tunable::BlasOp::N> gemm{};
+    gemm(&params);
+  }
+  else {
+    TORCH_CHECK(false, "unreachable");
+  }
+}
+
 Tensor& addmm_out_cuda_impl(Tensor& result, const Tensor& self, const Tensor& mat1, const Tensor& mat2, const Scalar& beta, const Scalar& alpha, Activation activation=Activation::None) {
   // Make sure to keep addmm_cuda below in sync with this code; it
   // preflights a check to try to avoid actually needing to call
@@ -346,6 +380,15 @@ Tensor& addmm_out_cuda_impl(Tensor& result, const Tensor& self, const Tensor& ma
         scalar_type,
         "addmm_cuda_lt",
         [&] {
+        auto tuning_ctx = at::cuda::tunable::getTuningContext();
+        if (tuning_ctx->IsTunableOpEnabled()) {
+          launchTunableGemmAndBias<scalar_t>(
+              args,
+              alpha,
+              (&result != &self) ? self.const_data_ptr<scalar_t>() : nullptr,
+              activation_to_gemm_and_blas_arg(activation));
+        }
+        else {
           at::cuda::blas::gemm_and_bias<scalar_t>(
               args.transa == 't',
               args.transb == 't',
@@ -364,7 +407,7 @@ Tensor& addmm_out_cuda_impl(Tensor& result, const Tensor& self, const Tensor& ma
               args.result_ld,
               activation_to_gemm_and_blas_arg(activation)
           );
-        });
+        }});
 #else
     auto activation_epilogue = activation_to_gemm_and_blas_arg(activation);
 #if (defined(CUDA_VERSION) && (CUDA_VERSION < 11080))
@@ -382,6 +425,15 @@ Tensor& addmm_out_cuda_impl(Tensor& result, const Tensor& self, const Tensor& ma
         scalar_type,
         "addmm_cuda_lt",
         [&] {
+        auto tuning_ctx = at::cuda::tunable::getTuningContext();
+        if (tuning_ctx->IsTunableOpEnabled()) {
+          launchTunableGemmAndBias<scalar_t>(
+              args,
+              alpha,
+              self.const_data_ptr<scalar_t>(),
+              activation_epilogue);
+        }
+        else {
           at::cuda::blas::gemm_and_bias<scalar_t>(
               args.transa == 't',
               args.transb == 't',
@@ -398,7 +450,7 @@ Tensor& addmm_out_cuda_impl(Tensor& result, const Tensor& self, const Tensor& ma
               args.result_ld,
               activation_epilogue
           );
-        });
+        }});
 #endif
     } else
     {
@@ -6037,6 +6037,33 @@ scipy_lobpcg | {eq_err_scipy:10.2e} | {eq_err_general_scipy:10.2e} | {iters2:
     def test_addmm_relu(self, device, dtype):
         self._test_addmm_impl(torch._addmm_activation, "relu", device, dtype)

+    @onlyCUDA
+    @skipCUDAIfNotRocm
+    @precisionOverride({torch.double: 1e-8, torch.float: 1e-4, torch.bfloat16: 5e-2,
+                        torch.half: 5e-2, torch.cfloat: 1e-4, torch.cdouble: 1e-8})
+    @dtypesIfCUDA(*floating_types_and(
+                  *[torch.bfloat16, torch.half] if TEST_WITH_ROCM or SM53OrLater else []))
+    @dtypes(*floating_types_and(torch.bfloat16))
+    @tf32_on_and_off(0.05)
+    @bf32_on_and_off(0.05)
+    def test_addmm_relu_tunableop_rocm(self, device, dtype):
+        torch.cuda.tunable.enable(True)
+        ordinal = torch.cuda.current_device()
+        filename = f"tunableop_results{ordinal}.csv"
+        torch.cuda.tunable.set_filename(filename)
+        iterations = torch.cuda.tunable.get_max_tuning_iterations()
+        torch.cuda.tunable.set_max_tuning_iterations(10)
+        self._test_addmm_impl(torch._addmm_activation, "relu", device, dtype)
+        # clean up, remove any file that was generated
+        try:
+            import os
+            os.remove(filename)
+        except FileNotFoundError:
+            pass
+        # reset back to prior settings
+        torch.cuda.tunable.set_max_tuning_iterations(iterations)
+        torch.cuda.tunable.enable(False)
+
     @precisionOverride({torch.double: 1e-8, torch.float: 1e-4, torch.bfloat16: 5e-2,
                         torch.half: 5e-2, torch.cfloat: 1e-4, torch.cdouble: 1e-8})
     @dtypesIfCUDA(*floating_types_and(