Revert "[CUDA][cuBLAS] Add fp16 accumulate option to cuBLAS/cuBLASLt (#144441)"

This reverts commit abf28982a8cb43342e7669d859de9543fd804cc9.

Reverted https://github.com/pytorch/pytorch/pull/144441 on behalf of https://github.com/ZainRizvi due to Sorry but this is failing internally. @Chillee can you please help get this change remerged? See D68720562 ([comment](https://github.com/pytorch/pytorch/pull/144441#issuecomment-2616726406))
Author: PyTorch MergeBot
Date:   2025-01-27 19:38:24 +00:00
Parent: 9728e900dc
Commit: c986eba560

9 changed files with 17 additions and 196 deletions

View File

@@ -394,6 +394,7 @@ void Context::setROCmFAPreferredBackend(at::ROCmFABackend b) {
rocm_fa_preferred_backend = b;
}
bool Context::allowFP16ReductionCuBLAS() const {
return allow_fp16_reduction_cublas;
}
@@ -410,14 +411,6 @@ void Context::setAllowBF16ReductionCuBLAS(bool b) {
allow_bf16_reduction_cublas = b;
}
bool Context::allowFP16AccumulationCuBLAS() const {
return allow_fp16_accumulation_cublas;
}
void Context::setAllowFP16AccumulationCuBLAS(bool b) {
allow_fp16_accumulation_cublas = b;
}
bool Context::hasMKL() {
#if AT_MKL_ENABLED()

View File

@@ -337,8 +337,6 @@ class TORCH_API Context {
void setAllowFP16ReductionCuBLAS(bool);
bool allowBF16ReductionCuBLAS() const;
void setAllowBF16ReductionCuBLAS(bool);
bool allowFP16AccumulationCuBLAS() const;
void setAllowFP16AccumulationCuBLAS(bool);
at::QEngine qEngine() const;
void setQEngine(at::QEngine e);
static const std::vector<at::QEngine>& supportedQEngines();
@@ -420,7 +418,6 @@ class TORCH_API Context {
bool allow_tf32_cudnn = true;
bool allow_fp16_reduction_cublas = true;
bool allow_bf16_reduction_cublas = true;
bool allow_fp16_accumulation_cublas = false;
bool enabled_mkldnn = true;
bool enabled_nnpack = true;
at::LinalgBackend linalg_preferred_backend =

View File

@@ -332,12 +332,6 @@ inline void bgemm_internal_cublaslt(CUDABLAS_BGEMM_ARGTYPES(Dtype)) {
cudaDataType_t abcType = CUDA_R_32F;
cublasComputeType_t computeType = CUBLAS_COMPUTE_32F;
cudaDataType_t scaleType = CUDA_R_32F;
#ifndef USE_ROCM
at::Half halpha;
at::Half hbeta;
#endif
void * alpha_ptr = &alpha;
void * beta_ptr = &beta;
if constexpr (std::is_same_v<Dtype, double>) {
abcType = CUDA_R_64F;
computeType = CUBLAS_COMPUTE_64F;
@@ -354,16 +348,6 @@ inline void bgemm_internal_cublaslt(CUDABLAS_BGEMM_ARGTYPES(Dtype)) {
abcType = CUDA_C_32F;
scaleType = CUDA_C_32F;
} else if constexpr (std::is_same_v<Dtype, at::Half>) {
#ifndef USE_ROCM
cudaDeviceProp* prop = at::cuda::getCurrentDeviceProperties();
if (prop->major >= 7 && at::globalContext().allowFP16AccumulationCuBLAS()) {
computeType = CUBLAS_COMPUTE_16F;
halpha = alpha;
hbeta = beta;
alpha_ptr = &halpha;
beta_ptr = &hbeta;
}
#endif
abcType = CUDA_R_16F;
} else if constexpr (std::is_same_v<Dtype, at::BFloat16>) {
abcType = CUDA_R_16BF;
@@ -431,12 +415,12 @@ inline void bgemm_internal_cublaslt(CUDABLAS_BGEMM_ARGTYPES(Dtype)) {
cublasStatus_t cublasStatus = cublasLtMatmul(
ltHandle,
computeDesc.descriptor(),
alpha_ptr,
&alpha,
a,
Adesc.descriptor(),
b,
Bdesc.descriptor(),
beta_ptr,
&beta,
c,
Cdesc.descriptor(),
c,
@@ -546,13 +530,6 @@ void bgemm_internal_cublas<at::Half>(CUDABLAS_BGEMM_ARGTYPES(at::Half)) {
BGEMM_CHECK_ARGVALUES(at::Half);
float falpha = alpha;
float fbeta = beta;
#ifndef USE_ROCM
at::Half halpha;
at::Half hbeta;
#endif
void * alpha_ptr = &falpha;
void * beta_ptr = &fbeta;
auto compute_type = CUDA_R_32F;
#ifdef USE_ROCM
int flag = 0;
#if USE_GEMM_FLAGS_FP16_ALT_IMPL
@@ -561,28 +538,21 @@ void bgemm_internal_cublas<at::Half>(CUDABLAS_BGEMM_ARGTYPES(at::Half)) {
TORCH_CUDABLAS_CHECK(rocBLASStatusToHIPStatus(rocblas_gemm_strided_batched_ex((rocblas_handle)handle,
hipOperationToRocOperation(opa),
hipOperationToRocOperation(opb), (int)m, (int)n, (int)k,
(void*)alpha_ptr, a, rocblas_datatype_f16_r, (int)lda, stridea,
(void*)&falpha, a, rocblas_datatype_f16_r, (int)lda, stridea,
b, rocblas_datatype_f16_r, (int)ldb, strideb,
(void*)beta_ptr, c, rocblas_datatype_f16_r, (int)ldc, stridec,
(void*)&fbeta, c, rocblas_datatype_f16_r, (int)ldc, stridec,
c, rocblas_datatype_f16_r, (int)ldc, stridec,
(int) num_batches, rocblas_datatype_f32_r, rocblas_gemm_algo_standard,
0, flag)));
#else
cudaDeviceProp* prop = at::cuda::getCurrentDeviceProperties();
if (prop->major >= 7 && at::globalContext().allowFP16AccumulationCuBLAS()) {
halpha = alpha;
hbeta = beta;
compute_type = CUDA_R_16F;
alpha_ptr = &halpha;
beta_ptr = &hbeta;
}
if (prop->major >= 5){
TORCH_CUDABLAS_CHECK(cublasGemmStridedBatchedEx(
handle, opa, opb, m, n, k,
alpha_ptr, a, CUDA_R_16F, lda, stridea,
b, CUDA_R_16F, ldb, strideb, beta_ptr,
(void*)(&falpha), a, CUDA_R_16F, lda, stridea,
b, CUDA_R_16F, ldb, strideb, (void*)(&fbeta),
c, CUDA_R_16F, ldc, stridec,
num_batches, compute_type, CUBLAS_GEMM_DEFAULT_TENSOR_OP));
num_batches, CUDA_R_32F, CUBLAS_GEMM_DEFAULT_TENSOR_OP));
} else {
for (const auto i : c10::irange(num_batches)) {
at::cuda::blas::gemm<at::Half>(
@@ -897,15 +867,8 @@ void gemm_internal_cublas<at::Half>(CUDABLAS_GEMM_ARGTYPES(at::Half)) {
cublasOperation_t opb = _cublasOpFromChar(transb);
float falpha = alpha;
float fbeta = beta;
#ifndef USE_ROCM
at::Half halpha;
at::Half hbeta;
#endif
void * alpha_ptr = &falpha;
void * beta_ptr = &fbeta;
_cublasAdjustLdLevel3(transa, transb, m, n, k, &lda, &ldb, &ldc);
GEMM_CHECK_ARGVALUES(at::Half);
auto compute_type = CUDA_R_32F;
#ifdef USE_ROCM
int flag = 0;
#if USE_GEMM_FLAGS_FP16_ALT_IMPL
@@ -918,14 +881,14 @@ void gemm_internal_cublas<at::Half>(CUDABLAS_GEMM_ARGTYPES(at::Half)) {
m,
n,
k,
alpha_ptr,
&falpha,
a,
rocblas_datatype_f16_r,
lda,
b,
rocblas_datatype_f16_r,
ldb,
beta_ptr,
&fbeta,
c,
rocblas_datatype_f16_r,
ldc,
@@ -938,13 +901,6 @@ void gemm_internal_cublas<at::Half>(CUDABLAS_GEMM_ARGTYPES(at::Half)) {
flag)));
#else
cudaDeviceProp* prop = at::cuda::getCurrentDeviceProperties();
if (prop->major >= 7 && at::globalContext().allowFP16AccumulationCuBLAS()) {
compute_type = CUDA_R_16F;
halpha = alpha;
hbeta = beta;
alpha_ptr = &halpha;
beta_ptr = &hbeta;
}
if (prop->major >= 5) {
cublasMath_t cublas_flags = CUBLAS_DEFAULT_MATH;
if (!at::globalContext().allowFP16ReductionCuBLAS()) {
@@ -959,18 +915,18 @@ void gemm_internal_cublas<at::Half>(CUDABLAS_GEMM_ARGTYPES(at::Half)) {
m,
n,
k,
alpha_ptr,
&falpha,
a,
CUDA_R_16F,
lda,
b,
CUDA_R_16F,
ldb,
beta_ptr,
&fbeta,
c,
CUDA_R_16F,
ldc,
compute_type,
CUDA_R_32F,
CUBLAS_GEMM_DEFAULT_TENSOR_OP));
TORCH_CUDABLAS_CHECK(cublasSetMathMode(handle, CUBLAS_DEFAULT_MATH));
} else {
@@ -1272,12 +1228,6 @@ void gemm_and_bias(
cudaDataType_t abcType = CUDA_R_32F;
cublasComputeType_t computeType = CUBLAS_COMPUTE_32F;
cudaDataType_t scaleType = CUDA_R_32F;
void * alpha_ptr = &alpha_val;
void * beta_ptr = &beta_val;
#ifndef USE_ROCM
at::Half halpha_val;
at::Half hbeta_val;
#endif
if constexpr (std::is_same_v<Dtype, double>) {
abcType = CUDA_R_64F;
computeType = CUBLAS_COMPUTE_64F;
@@ -1288,17 +1238,6 @@ void gemm_and_bias(
}
abcType = CUDA_R_32F;
} else if constexpr (std::is_same_v<Dtype, at::Half>) {
#ifndef USE_ROCM
cudaDeviceProp* prop = at::cuda::getCurrentDeviceProperties();
if (prop->major >= 7 && at::globalContext().allowFP16AccumulationCuBLAS()) {
computeType = CUBLAS_COMPUTE_16F;
scaleType = CUDA_R_16F;
halpha_val = alpha_val;
hbeta_val = beta_val;
alpha_ptr = &halpha_val;
beta_ptr = &hbeta_val;
}
#endif
abcType = CUDA_R_16F;
} else if constexpr (std::is_same_v<Dtype, at::BFloat16>) {
abcType = CUDA_R_16BF;
@@ -1367,12 +1306,12 @@ void gemm_and_bias(
cublasStatus_t cublasStatus = cublasLtMatmul(
ltHandle,
computeDesc.descriptor(),
alpha_ptr,
&alpha_val,
mat1_ptr,
Adesc.descriptor(),
mat2_ptr,
Bdesc.descriptor(),
beta_ptr,
&beta_val,
result_ptr,
Cdesc.descriptor(),
result_ptr,

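The hunks above remove the path that, on non-ROCm builds with compute capability 7.0 or newer and the flag enabled, switched Half GEMMs to CUBLAS_COMPUTE_16F with fp16 alpha/beta; with the revert applied, Half GEMMs always accumulate in fp32 via CUBLAS_COMPUTE_32F and float scalars. A minimal Python sketch, assuming a pre-revert build, of gating the flag on the same compute-capability check the deleted C++ code performed:

```python
import torch

# Gate the (reverted) flag on the same compute-capability >= 7.0 check the
# deleted C++ code used. On post-revert builds the attribute does not exist,
# so hasattr() returns False and the default fp32 accumulation path is used.
if torch.cuda.is_available():
    major, _ = torch.cuda.get_device_capability()
    if major >= 7 and hasattr(torch.backends.cuda.matmul, "allow_fp16_accumulation"):
        torch.backends.cuda.matmul.allow_fp16_accumulation = True

    a = torch.randn(256, 256, device="cuda", dtype=torch.half)
    b = torch.randn(256, 256, device="cuda", dtype=torch.half)
    c = a @ b  # with the revert applied, this accumulates in fp32 regardless
```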
View File

@@ -148,9 +148,6 @@ For more information about TF32, see:
Reduced Precision Reduction in FP16 GEMMs
-----------------------------------------
(Distinct from full FP16 accumulation that is intended for hardware that has higher throughput
with FP16 accumulation than FP32 accumulation, see :ref:`Full FP16 accumulation<fp16accumulation>`)
fp16 GEMMs are potentially done with some intermediate reduced precision reductions (e.g., in fp16 rather than fp32). These selective reductions in precision can allow for higher performance on certain workloads (particularly those with a large `k` dimension) and GPU architectures at the cost of numerical precision and potential for overflow.
Some example benchmark data on V100:
@@ -209,28 +206,6 @@ To toggle the reduced precision reduction flags in C++, one can do
at::globalContext().setAllowBF16ReductionCuBLAS(true);
.. _fp16accumulation:
Full FP16 Accumulation in FP16 GEMMs
-------------------------------------
Certain GPUs have increased performance when doing _all_ FP16 GEMM accumulation
in FP16, at the cost of numerical precision and greater likelihood of overflow.
Note that this setting only has an effect on GPUs of compute capability 7.0 (Volta)
or newer.
This behavior can be enabled via:
.. code:: python
torch.backends.cuda.matmul.allow_fp16_accumulation = True
To toggle the FP16 accumulation flag in C++, one can do
.. code:: C++
at::globalContext().setAllowFP16AccumulationCuBLAS(true);
Asynchronous execution
----------------------

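The removed documentation distinguished reduced-precision reductions (which remain configurable) from full fp16 accumulation (whose toggle this revert deletes). A minimal sketch of what is still available from Python after the revert; the accumulation attribute named in the removed section is gone:

```python
import torch

# Reduced-precision-reduction flags are unaffected by this revert and can
# still be toggled (both default to True):
torch.backends.cuda.matmul.allow_fp16_reduced_precision_reduction = True
torch.backends.cuda.matmul.allow_bf16_reduced_precision_reduction = True

# The flag documented in the removed section is gone after the revert;
# reading or assigning torch.backends.cuda.matmul.allow_fp16_accumulation
# now raises AttributeError("Unknown attribute ...").
```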
View File

@@ -576,13 +576,6 @@ class TestCuda(TestCase):
)
torch.backends.cuda.matmul.allow_bf16_reduced_precision_reduction = orig
def test_cublas_allow_fp16_accumulation_get_set(self):
orig = torch.backends.cuda.matmul.allow_fp16_accumulation
self.assertEqual(torch._C._get_cublas_allow_fp16_accumulation(), orig)
torch.backends.cuda.matmul.allow_fp16_accumulation = not orig
self.assertEqual(torch._C._get_cublas_allow_fp16_accumulation(), not orig)
torch.backends.cuda.matmul.allow_fp16_accumulation = orig
def test_cudnn_allow_tf32_get_set(self):
with torch.backends.cudnn.flags(
enabled=None, benchmark=None, deterministic=None, allow_tf32=False

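The deleted test exercised a simple get/set round trip between the Python frontend flag and its torch._C accessor. The same pattern, sketched against a flag that survives the revert (bf16 reduced-precision reduction), for reference:

```python
import torch

# Round-trip check: the frontend attribute and the private accessor must agree.
orig = torch.backends.cuda.matmul.allow_bf16_reduced_precision_reduction
assert torch._C._get_cublas_allow_bf16_reduced_precision_reduction() == orig

torch.backends.cuda.matmul.allow_bf16_reduced_precision_reduction = not orig
assert torch._C._get_cublas_allow_bf16_reduced_precision_reduction() == (not orig)

# Restore the original value so later code sees the default behavior.
torch.backends.cuda.matmul.allow_bf16_reduced_precision_reduction = orig
```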
View File

@@ -34,10 +34,10 @@ from torch.testing._internal.common_utils import (
IS_WINDOWS,
parametrize,
run_tests,
skipIfRocm,
skipIfRocmVersionLessThan,
TEST_CUDA,
TEST_WITH_ROCM,
skipIfRocm,
TestCase,
)
@@ -61,7 +61,7 @@ class TestMatmulCuda(TestCase):
torch.backends.cuda.matmul.allow_tf32 = True
super(self.__class__, self).tearDown()
def cublas_addmm(self, size: int, dtype: torch.dtype, reduced_precision: bool = False, fp16_accumulate: bool = False):
def cublas_addmm(self, size: int, dtype: torch.dtype, reduced_precision: bool = False):
#
# Check for catastrophic cuBLAS inaccuracy by measuring the deviation between
# results from the CUDA invocation of torch.addmm and the CPU invocation
@@ -73,10 +73,8 @@ class TestMatmulCuda(TestCase):
# which fail the threshold check
orig_bf16 = torch.backends.cuda.matmul.allow_bf16_reduced_precision_reduction
orig_fp16 = torch.backends.cuda.matmul.allow_fp16_reduced_precision_reduction
orig_fp16_accumulate = torch.backends.cuda.matmul.allow_fp16_accumulation
torch.backends.cuda.matmul.allow_bf16_reduced_precision_reduction = reduced_precision
torch.backends.cuda.matmul.allow_fp16_reduced_precision_reduction = reduced_precision
torch.backends.cuda.matmul.allow_fp16_accumulation = fp16_accumulate
# Make random tensors on CPU (seed set on common_utils.py import)
# (Not using numpy because it does not support bfloat16)
make_arg = partial(make_tensor, dtype=dtype, device="cpu")
@@ -84,10 +82,6 @@ class TestMatmulCuda(TestCase):
m_input = make_arg((n, p))
m_1 = make_arg((n, m))
m_2 = make_arg((m, p))
# scale to abate overflows in fp16 accum
if fp16_accumulate:
m_1 = m_1 / 100
m_2 = m_2 / 100
# *(B)FLOAT16 Special Handling*
# Backend does not tensorize float16 on CPU,
# and bfloat16 may present accuracy issues,
@@ -121,7 +115,6 @@ class TestMatmulCuda(TestCase):
self.assertEqual(res_cpu, res_cuda)
torch.backends.cuda.matmul.allow_bf16_reduced_precision_reduction = orig_bf16
torch.backends.cuda.matmul.allow_fp16_reduced_precision_reduction = orig_fp16
torch.backends.cuda.matmul.allow_fp16_accumulation = orig_fp16_accumulate
@onlyCUDA
@skipIfRocmVersionLessThan((5, 2))
@@ -144,36 +137,6 @@ class TestMatmulCuda(TestCase):
def test_cublas_addmm_reduced_precision(self, size: int, dtype: torch.dtype):
self.cublas_addmm(size, dtype, True)
@onlyCUDA
@skipIfRocmVersionLessThan((5, 2))
# imported 'tol' as 'xtol' to avoid aliasing in code above
@toleranceOverride({torch.float16: xtol(atol=7e-1, rtol=2e-1),
torch.bfloat16: xtol(atol=1e1, rtol=2e-1)})
@dtypes(torch.float16, torch.bfloat16)
@parametrize("size", [100, 1000, 10000])
def test_cublas_addmm_reduced_precision_fp16_accumulate(self, size: int, dtype: torch.dtype):
self.cublas_addmm(size, dtype, False, True)
@onlyCUDA
@skipIfRocm
def test_cublas_and_lt_reduced_precision_fp16_accumulate(self):
orig_fp16_accumulate = torch.backends.cuda.matmul.allow_fp16_accumulation
torch.backends.cuda.matmul.allow_fp16_accumulation = True
x = torch.rand(32, 512, 512, device='cuda', dtype=torch.half)
w = torch.rand(512, 512, device='cuda', dtype=torch.half)
b = torch.rand(512, device='cuda', dtype=torch.half)
out = torch.nn.functional.linear(x, w, b)
out_cpu = torch.nn.functional.linear(x.cpu(), w.cpu(), b.cpu())
self.assertEqual(out, out_cpu, atol=5e-3, rtol=8e-3)
a = torch.rand(16, 128, 128, device='cuda', dtype=torch.half)
b = torch.rand(16, 128, 128, device='cuda', dtype=torch.half)
c = torch.rand(16, 128, 128, device='cuda', dtype=torch.half)
out = torch.baddbmm(a, b, c)
out_cpu = torch.baddbmm(a.cpu(), b.cpu(), c.cpu())
self.assertEqual(out, out_cpu, atol=1e-3, rtol=5e-3)
torch.backends.cuda.matmul.allow_fp16_accumulation = orig_fp16_accumulate
@onlyCUDA
@toleranceOverride({torch.float16: xtol(atol=1e-3, rtol=2e-3)})
@dtypes(torch.float16)

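The deleted tests compared half-precision GEMMs run with fp16 accumulation against CPU references under loosened tolerances (scaling inputs by 1/100 in the addmm case to keep fp16 accumulation from overflowing). A minimal sketch of the same accuracy check with the flag removed, so it now exercises the default fp32-accumulation path:

```python
import torch

# Compare a half-precision batched GEMM on the GPU against the CPU result,
# using the loosened tolerances from the removed test (atol=1e-3, rtol=5e-3).
if torch.cuda.is_available():
    a = torch.rand(16, 128, 128, device="cuda", dtype=torch.half)
    b = torch.rand(16, 128, 128, device="cuda", dtype=torch.half)
    c = torch.rand(16, 128, 128, device="cuda", dtype=torch.half)
    out = torch.baddbmm(a, b, c)
    out_cpu = torch.baddbmm(a.cpu(), b.cpu(), c.cpu())
    torch.testing.assert_close(out.cpu(), out_cpu, atol=1e-3, rtol=5e-3)
```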
View File

@@ -1210,10 +1210,6 @@ def _get_cublas_allow_bf16_reduced_precision_reduction() -> _bool: ... # THPMod
def _set_cublas_allow_bf16_reduced_precision_reduction(
arg: _bool,
) -> None: ... # THPModule_setAllowBF16ReductionCuBLAS
def _get_cublas_allow_fp16_accumulation() -> _bool: ... # THPModule_allowFP16AccumulationCuBLAS
def _set_cublas_allow_fp16_accumulation(
arg: _bool,
) -> None: ... # THPModule_setAllowFP16AccumulationCuBLAS
def _set_conj(x: Tensor, conj: _bool) -> None: ...
def _set_neg(x: Tensor, neg: _bool) -> None: ...
def _set_meta_in_tls_dispatch_include(meta_in_tls: _bool) -> None: ...

View File

@@ -133,8 +133,6 @@ class cuBLASModule:
return torch._C._get_cublas_allow_fp16_reduced_precision_reduction()
elif name == "allow_bf16_reduced_precision_reduction":
return torch._C._get_cublas_allow_bf16_reduced_precision_reduction()
elif name == "allow_fp16_accumulation":
return torch._C._get_cublas_allow_fp16_accumulation()
raise AttributeError("Unknown attribute " + name)
def __setattr__(self, name, value):
@@ -144,8 +142,6 @@ class cuBLASModule:
return torch._C._set_cublas_allow_fp16_reduced_precision_reduction(value)
elif name == "allow_bf16_reduced_precision_reduction":
return torch._C._set_cublas_allow_bf16_reduced_precision_reduction(value)
elif name == "allow_fp16_accumulation":
return torch._C._set_cublas_allow_fp16_accumulation(value)
raise AttributeError("Unknown attribute " + name)

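For context, cuBLASModule resolves attribute reads and writes by name and forwards them to the matching torch._C accessor pair, raising AttributeError for anything unknown, which is why removing the allow_fp16_accumulation branches makes the attribute disappear entirely. An illustrative sketch of that dispatch pattern (not the actual implementation, which uses explicit if/elif branches):

```python
import torch

# Illustrative dispatch pattern: each public flag name maps to a
# (getter, setter) pair in torch._C; unknown names raise AttributeError.
class _MatmulFlags:
    _accessors = {
        "allow_fp16_reduced_precision_reduction": (
            torch._C._get_cublas_allow_fp16_reduced_precision_reduction,
            torch._C._set_cublas_allow_fp16_reduced_precision_reduction,
        ),
    }

    def __getattr__(self, name):
        if name in self._accessors:
            return self._accessors[name][0]()
        raise AttributeError("Unknown attribute " + name)

    def __setattr__(self, name, value):
        if name in self._accessors:
            return self._accessors[name][1](value)
        raise AttributeError("Unknown attribute " + name)


flags = _MatmulFlags()
flags.allow_fp16_reduced_precision_reduction = True
print(flags.allow_fp16_reduced_precision_reduction)  # True
```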
View File

@@ -1133,29 +1133,6 @@ static PyObject* THPModule_allowBF16ReductionCuBLAS(
Py_RETURN_FALSE;
}
static PyObject* THPModule_setAllowFP16AccumulationCuBLAS(
PyObject* _unused,
PyObject* arg) {
HANDLE_TH_ERRORS
TORCH_CHECK(
PyBool_Check(arg),
"set_allow_fp16_accumulation_cublas expects a bool, "
"but got ",
THPUtils_typename(arg));
at::globalContext().setAllowFP16AccumulationCuBLAS(arg == Py_True);
Py_RETURN_NONE;
END_HANDLE_TH_ERRORS
}
static PyObject* THPModule_allowFP16AccumulationCuBLAS(
PyObject* _unused,
PyObject* noargs) {
if (at::globalContext().allowFP16AccumulationCuBLAS()) {
Py_RETURN_TRUE;
}
Py_RETURN_FALSE;
}
static PyObject* THPModule_setAllowFP16ReductionCPU(
PyObject* _unused,
PyObject* arg) {
@@ -1597,14 +1574,6 @@ static std::initializer_list<PyMethodDef> TorchMethods = {
THPModule_setAllowBF16ReductionCuBLAS,
METH_O,
nullptr},
{"_get_cublas_allow_fp16_accumulation",
THPModule_allowFP16AccumulationCuBLAS,
METH_NOARGS,
nullptr},
{"_set_cublas_allow_fp16_accumulation",
THPModule_setAllowFP16AccumulationCuBLAS,
METH_O,
nullptr},
{"_get_cpu_allow_fp16_reduced_precision_reduction",
THPModule_allowFP16ReductionCPU,
METH_NOARGS,