mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
Revert "[CUDA][cuBLAS] Add fp16 accumulate option to cuBLAS/cuBLASLt (#144441)"
This reverts commit abf28982a8cb43342e7669d859de9543fd804cc9. Reverted https://github.com/pytorch/pytorch/pull/144441 on behalf of https://github.com/ZainRizvi due to Sorry but this is failing internally. @Chillee can you please help change get remerged? See D68720562 ([comment](https://github.com/pytorch/pytorch/pull/144441#issuecomment-2616726406))
This commit is contained in:
@ -394,6 +394,7 @@ void Context::setROCmFAPreferredBackend(at::ROCmFABackend b) {
|
||||
rocm_fa_preferred_backend = b;
|
||||
}
|
||||
|
||||
|
||||
bool Context::allowFP16ReductionCuBLAS() const {
|
||||
return allow_fp16_reduction_cublas;
|
||||
}
|
||||
@ -410,14 +411,6 @@ void Context::setAllowBF16ReductionCuBLAS(bool b) {
|
||||
allow_bf16_reduction_cublas = b;
|
||||
}
|
||||
|
||||
bool Context::allowFP16AccumulationCuBLAS() const {
|
||||
return allow_fp16_accumulation_cublas;
|
||||
}
|
||||
|
||||
void Context::setAllowFP16AccumulationCuBLAS(bool b) {
|
||||
allow_fp16_accumulation_cublas = b;
|
||||
}
|
||||
|
||||
|
||||
bool Context::hasMKL() {
|
||||
#if AT_MKL_ENABLED()
|
||||
|
@ -337,8 +337,6 @@ class TORCH_API Context {
|
||||
void setAllowFP16ReductionCuBLAS(bool);
|
||||
bool allowBF16ReductionCuBLAS() const;
|
||||
void setAllowBF16ReductionCuBLAS(bool);
|
||||
bool allowFP16AccumulationCuBLAS() const;
|
||||
void setAllowFP16AccumulationCuBLAS(bool);
|
||||
at::QEngine qEngine() const;
|
||||
void setQEngine(at::QEngine e);
|
||||
static const std::vector<at::QEngine>& supportedQEngines();
|
||||
@ -420,7 +418,6 @@ class TORCH_API Context {
|
||||
bool allow_tf32_cudnn = true;
|
||||
bool allow_fp16_reduction_cublas = true;
|
||||
bool allow_bf16_reduction_cublas = true;
|
||||
bool allow_fp16_accumulation_cublas = false;
|
||||
bool enabled_mkldnn = true;
|
||||
bool enabled_nnpack = true;
|
||||
at::LinalgBackend linalg_preferred_backend =
|
||||
|
@ -332,12 +332,6 @@ inline void bgemm_internal_cublaslt(CUDABLAS_BGEMM_ARGTYPES(Dtype)) {
|
||||
cudaDataType_t abcType = CUDA_R_32F;
|
||||
cublasComputeType_t computeType = CUBLAS_COMPUTE_32F;
|
||||
cudaDataType_t scaleType = CUDA_R_32F;
|
||||
#ifndef USE_ROCM
|
||||
at::Half halpha;
|
||||
at::Half hbeta;
|
||||
#endif
|
||||
void * alpha_ptr = α
|
||||
void * beta_ptr = β
|
||||
if constexpr (std::is_same_v<Dtype, double>) {
|
||||
abcType = CUDA_R_64F;
|
||||
computeType = CUBLAS_COMPUTE_64F;
|
||||
@ -354,16 +348,6 @@ inline void bgemm_internal_cublaslt(CUDABLAS_BGEMM_ARGTYPES(Dtype)) {
|
||||
abcType = CUDA_C_32F;
|
||||
scaleType = CUDA_C_32F;
|
||||
} else if constexpr (std::is_same_v<Dtype, at::Half>) {
|
||||
#ifndef USE_ROCM
|
||||
cudaDeviceProp* prop = at::cuda::getCurrentDeviceProperties();
|
||||
if (prop->major >= 7 && at::globalContext().allowFP16AccumulationCuBLAS()) {
|
||||
computeType = CUBLAS_COMPUTE_16F;
|
||||
halpha = alpha;
|
||||
hbeta = beta;
|
||||
alpha_ptr = &halpha;
|
||||
beta_ptr = &hbeta;
|
||||
}
|
||||
#endif
|
||||
abcType = CUDA_R_16F;
|
||||
} else if constexpr (std::is_same_v<Dtype, at::BFloat16>) {
|
||||
abcType = CUDA_R_16BF;
|
||||
@ -431,12 +415,12 @@ inline void bgemm_internal_cublaslt(CUDABLAS_BGEMM_ARGTYPES(Dtype)) {
|
||||
cublasStatus_t cublasStatus = cublasLtMatmul(
|
||||
ltHandle,
|
||||
computeDesc.descriptor(),
|
||||
alpha_ptr,
|
||||
&alpha,
|
||||
a,
|
||||
Adesc.descriptor(),
|
||||
b,
|
||||
Bdesc.descriptor(),
|
||||
beta_ptr,
|
||||
&beta,
|
||||
c,
|
||||
Cdesc.descriptor(),
|
||||
c,
|
||||
@ -546,13 +530,6 @@ void bgemm_internal_cublas<at::Half>(CUDABLAS_BGEMM_ARGTYPES(at::Half)) {
|
||||
BGEMM_CHECK_ARGVALUES(at::Half);
|
||||
float falpha = alpha;
|
||||
float fbeta = beta;
|
||||
#ifndef USE_ROCM
|
||||
at::Half halpha;
|
||||
at::Half hbeta;
|
||||
#endif
|
||||
void * alpha_ptr = &falpha;
|
||||
void * beta_ptr = &fbeta;
|
||||
auto compute_type = CUDA_R_32F;
|
||||
#ifdef USE_ROCM
|
||||
int flag = 0;
|
||||
#if USE_GEMM_FLAGS_FP16_ALT_IMPL
|
||||
@ -561,28 +538,21 @@ void bgemm_internal_cublas<at::Half>(CUDABLAS_BGEMM_ARGTYPES(at::Half)) {
|
||||
TORCH_CUDABLAS_CHECK(rocBLASStatusToHIPStatus(rocblas_gemm_strided_batched_ex((rocblas_handle)handle,
|
||||
hipOperationToRocOperation(opa),
|
||||
hipOperationToRocOperation(opb), (int)m, (int)n, (int)k,
|
||||
(void*)alpha_ptr, a, rocblas_datatype_f16_r, (int)lda, stridea,
|
||||
(void*)&falpha, a, rocblas_datatype_f16_r, (int)lda, stridea,
|
||||
b, rocblas_datatype_f16_r, (int)ldb, strideb,
|
||||
(void*)beta_ptr, c, rocblas_datatype_f16_r, (int)ldc, stridec,
|
||||
(void*)&fbeta, c, rocblas_datatype_f16_r, (int)ldc, stridec,
|
||||
c, rocblas_datatype_f16_r, (int)ldc, stridec,
|
||||
(int) num_batches, rocblas_datatype_f32_r, rocblas_gemm_algo_standard,
|
||||
0, flag)));
|
||||
#else
|
||||
cudaDeviceProp* prop = at::cuda::getCurrentDeviceProperties();
|
||||
if (prop->major >= 7 && at::globalContext().allowFP16AccumulationCuBLAS()) {
|
||||
halpha = alpha;
|
||||
hbeta = beta;
|
||||
compute_type = CUDA_R_16F;
|
||||
alpha_ptr = &halpha;
|
||||
beta_ptr = &hbeta;
|
||||
}
|
||||
if (prop->major >= 5){
|
||||
TORCH_CUDABLAS_CHECK(cublasGemmStridedBatchedEx(
|
||||
handle, opa, opb, m, n, k,
|
||||
alpha_ptr, a, CUDA_R_16F, lda, stridea,
|
||||
b, CUDA_R_16F, ldb, strideb, beta_ptr,
|
||||
(void*)(&falpha), a, CUDA_R_16F, lda, stridea,
|
||||
b, CUDA_R_16F, ldb, strideb, (void*)(&fbeta),
|
||||
c, CUDA_R_16F, ldc, stridec,
|
||||
num_batches, compute_type, CUBLAS_GEMM_DEFAULT_TENSOR_OP));
|
||||
num_batches, CUDA_R_32F, CUBLAS_GEMM_DEFAULT_TENSOR_OP));
|
||||
} else {
|
||||
for (const auto i : c10::irange(num_batches)) {
|
||||
at::cuda::blas::gemm<at::Half>(
|
||||
@ -897,15 +867,8 @@ void gemm_internal_cublas<at::Half>(CUDABLAS_GEMM_ARGTYPES(at::Half)) {
|
||||
cublasOperation_t opb = _cublasOpFromChar(transb);
|
||||
float falpha = alpha;
|
||||
float fbeta = beta;
|
||||
#ifndef USE_ROCM
|
||||
at::Half halpha;
|
||||
at::Half hbeta;
|
||||
#endif
|
||||
void * alpha_ptr = &falpha;
|
||||
void * beta_ptr = &fbeta;
|
||||
_cublasAdjustLdLevel3(transa, transb, m, n, k, &lda, &ldb, &ldc);
|
||||
GEMM_CHECK_ARGVALUES(at::Half);
|
||||
auto compute_type = CUDA_R_32F;
|
||||
#ifdef USE_ROCM
|
||||
int flag = 0;
|
||||
#if USE_GEMM_FLAGS_FP16_ALT_IMPL
|
||||
@ -918,14 +881,14 @@ void gemm_internal_cublas<at::Half>(CUDABLAS_GEMM_ARGTYPES(at::Half)) {
|
||||
m,
|
||||
n,
|
||||
k,
|
||||
alpha_ptr,
|
||||
&falpha,
|
||||
a,
|
||||
rocblas_datatype_f16_r,
|
||||
lda,
|
||||
b,
|
||||
rocblas_datatype_f16_r,
|
||||
ldb,
|
||||
beta_ptr,
|
||||
&fbeta,
|
||||
c,
|
||||
rocblas_datatype_f16_r,
|
||||
ldc,
|
||||
@ -938,13 +901,6 @@ void gemm_internal_cublas<at::Half>(CUDABLAS_GEMM_ARGTYPES(at::Half)) {
|
||||
flag)));
|
||||
#else
|
||||
cudaDeviceProp* prop = at::cuda::getCurrentDeviceProperties();
|
||||
if (prop->major >= 7 && at::globalContext().allowFP16AccumulationCuBLAS()) {
|
||||
compute_type = CUDA_R_16F;
|
||||
halpha = alpha;
|
||||
hbeta = beta;
|
||||
alpha_ptr = &halpha;
|
||||
beta_ptr = &hbeta;
|
||||
}
|
||||
if (prop->major >= 5) {
|
||||
cublasMath_t cublas_flags = CUBLAS_DEFAULT_MATH;
|
||||
if (!at::globalContext().allowFP16ReductionCuBLAS()) {
|
||||
@ -959,18 +915,18 @@ void gemm_internal_cublas<at::Half>(CUDABLAS_GEMM_ARGTYPES(at::Half)) {
|
||||
m,
|
||||
n,
|
||||
k,
|
||||
alpha_ptr,
|
||||
&falpha,
|
||||
a,
|
||||
CUDA_R_16F,
|
||||
lda,
|
||||
b,
|
||||
CUDA_R_16F,
|
||||
ldb,
|
||||
beta_ptr,
|
||||
&fbeta,
|
||||
c,
|
||||
CUDA_R_16F,
|
||||
ldc,
|
||||
compute_type,
|
||||
CUDA_R_32F,
|
||||
CUBLAS_GEMM_DEFAULT_TENSOR_OP));
|
||||
TORCH_CUDABLAS_CHECK(cublasSetMathMode(handle, CUBLAS_DEFAULT_MATH));
|
||||
} else {
|
||||
@ -1272,12 +1228,6 @@ void gemm_and_bias(
|
||||
cudaDataType_t abcType = CUDA_R_32F;
|
||||
cublasComputeType_t computeType = CUBLAS_COMPUTE_32F;
|
||||
cudaDataType_t scaleType = CUDA_R_32F;
|
||||
void * alpha_ptr = &alpha_val;
|
||||
void * beta_ptr = &beta_val;
|
||||
#ifndef USE_ROCM
|
||||
at::Half halpha_val;
|
||||
at::Half hbeta_val;
|
||||
#endif
|
||||
if constexpr (std::is_same_v<Dtype, double>) {
|
||||
abcType = CUDA_R_64F;
|
||||
computeType = CUBLAS_COMPUTE_64F;
|
||||
@ -1288,17 +1238,6 @@ void gemm_and_bias(
|
||||
}
|
||||
abcType = CUDA_R_32F;
|
||||
} else if constexpr (std::is_same_v<Dtype, at::Half>) {
|
||||
#ifndef USE_ROCM
|
||||
cudaDeviceProp* prop = at::cuda::getCurrentDeviceProperties();
|
||||
if (prop->major >= 7 && at::globalContext().allowFP16AccumulationCuBLAS()) {
|
||||
computeType = CUBLAS_COMPUTE_16F;
|
||||
scaleType = CUDA_R_16F;
|
||||
halpha_val = alpha_val;
|
||||
hbeta_val = beta_val;
|
||||
alpha_ptr = &halpha_val;
|
||||
beta_ptr = &hbeta_val;
|
||||
}
|
||||
#endif
|
||||
abcType = CUDA_R_16F;
|
||||
} else if constexpr (std::is_same_v<Dtype, at::BFloat16>) {
|
||||
abcType = CUDA_R_16BF;
|
||||
@ -1367,12 +1306,12 @@ void gemm_and_bias(
|
||||
cublasStatus_t cublasStatus = cublasLtMatmul(
|
||||
ltHandle,
|
||||
computeDesc.descriptor(),
|
||||
alpha_ptr,
|
||||
&alpha_val,
|
||||
mat1_ptr,
|
||||
Adesc.descriptor(),
|
||||
mat2_ptr,
|
||||
Bdesc.descriptor(),
|
||||
beta_ptr,
|
||||
&beta_val,
|
||||
result_ptr,
|
||||
Cdesc.descriptor(),
|
||||
result_ptr,
|
||||
|
@ -148,9 +148,6 @@ For more information about TF32, see:
|
||||
Reduced Precision Reduction in FP16 GEMMs
|
||||
-----------------------------------------
|
||||
|
||||
(Distinct from full FP16 accumulation that is intended for hardware that has higher throughput
|
||||
with FP16 accumulation than FP32 accumulation, see :ref:`Full FP16 accumulation<fp16accumulation>`)
|
||||
|
||||
fp16 GEMMs are potentially done with some intermediate reduced precision reductions (e.g., in fp16 rather than fp32). These selective reductions in precision can allow for higher performance on certain workloads (particularly those with a large `k` dimension) and GPU architectures at the cost of numerical precision and potential for overflow.
|
||||
|
||||
Some example benchmark data on V100:
|
||||
@ -209,28 +206,6 @@ To toggle the reduced precision reduction flags in C++, one can do
|
||||
|
||||
at::globalContext().setAllowBF16ReductionCuBLAS(true);
|
||||
|
||||
.. _fp16accumulation:
|
||||
|
||||
Full FP16 Accmumulation in FP16 GEMMs
|
||||
-------------------------------------
|
||||
|
||||
Certain GPUs have increased performance when doing _all_ FP16 GEMM accumulation
|
||||
in FP16, at the cost of numerical precision and greater likelihood of overflow.
|
||||
Note that this setting only has an effect on GPUs of compute capability 7.0 (Volta)
|
||||
or newer.
|
||||
|
||||
This behavior can be enabled via:
|
||||
|
||||
.. code:: python
|
||||
|
||||
torch.backends.cuda.matmul.allow_fp16_accumulation = True
|
||||
|
||||
To toggle the reduced precision reduction flags in C++, one can do
|
||||
|
||||
.. code:: C++
|
||||
|
||||
at::globalContext().setAllowFP16AccumulationCuBLAS(true);
|
||||
|
||||
Asynchronous execution
|
||||
----------------------
|
||||
|
||||
|
@ -576,13 +576,6 @@ class TestCuda(TestCase):
|
||||
)
|
||||
torch.backends.cuda.matmul.allow_bf16_reduced_precision_reduction = orig
|
||||
|
||||
def test_cublas_allow_fp16_accumulation_get_set(self):
|
||||
orig = torch.backends.cuda.matmul.allow_fp16_accumulation
|
||||
self.assertEqual(torch._C._get_cublas_allow_fp16_accumulation(), orig)
|
||||
torch.backends.cuda.matmul.allow_fp16_accumulation = not orig
|
||||
self.assertEqual(torch._C._get_cublas_allow_fp16_accumulation(), not orig)
|
||||
torch.backends.cuda.matmul.allow_fp16_accumulation = orig
|
||||
|
||||
def test_cudnn_allow_tf32_get_set(self):
|
||||
with torch.backends.cudnn.flags(
|
||||
enabled=None, benchmark=None, deterministic=None, allow_tf32=False
|
||||
|
@ -34,10 +34,10 @@ from torch.testing._internal.common_utils import (
|
||||
IS_WINDOWS,
|
||||
parametrize,
|
||||
run_tests,
|
||||
skipIfRocm,
|
||||
skipIfRocmVersionLessThan,
|
||||
TEST_CUDA,
|
||||
TEST_WITH_ROCM,
|
||||
skipIfRocm,
|
||||
TestCase,
|
||||
)
|
||||
|
||||
@ -61,7 +61,7 @@ class TestMatmulCuda(TestCase):
|
||||
torch.backends.cuda.matmul.allow_tf32 = True
|
||||
super(self.__class__, self).tearDown()
|
||||
|
||||
def cublas_addmm(self, size: int, dtype: torch.dtype, reduced_precision: bool = False, fp16_accumulate: bool = False):
|
||||
def cublas_addmm(self, size: int, dtype: torch.dtype, reduced_precision: bool = False):
|
||||
#
|
||||
# Check for catastrophic cuBLAS inaccuracy by measuring the deviation between
|
||||
# results from the CUDA invocation of torch.addmm and the CPU invocation
|
||||
@ -73,10 +73,8 @@ class TestMatmulCuda(TestCase):
|
||||
# which fail the threshold check
|
||||
orig_bf16 = torch.backends.cuda.matmul.allow_bf16_reduced_precision_reduction
|
||||
orig_fp16 = torch.backends.cuda.matmul.allow_fp16_reduced_precision_reduction
|
||||
orig_fp16_accumulate = torch.backends.cuda.matmul.allow_fp16_accumulation
|
||||
torch.backends.cuda.matmul.allow_bf16_reduced_precision_reduction = reduced_precision
|
||||
torch.backends.cuda.matmul.allow_fp16_reduced_precision_reduction = reduced_precision
|
||||
torch.backends.cuda.matmul.allow_fp16_accumulation = fp16_accumulate
|
||||
# Make random tensors on CPU (seed set on common_utils.py import)
|
||||
# (Not using numpy because it does not support bfloat16)
|
||||
make_arg = partial(make_tensor, dtype=dtype, device="cpu")
|
||||
@ -84,10 +82,6 @@ class TestMatmulCuda(TestCase):
|
||||
m_input = make_arg((n, p))
|
||||
m_1 = make_arg((n, m))
|
||||
m_2 = make_arg((m, p))
|
||||
# scale to abate overflows in fp16 accum
|
||||
if fp16_accumulate:
|
||||
m_1 = m_1 / 100
|
||||
m_2 = m_2 / 100
|
||||
# *(B)FLOAT16 Special Handling*
|
||||
# Backend does not tensorize float16 on CPU,
|
||||
# and bloat16 may present accuracy issues,
|
||||
@ -121,7 +115,6 @@ class TestMatmulCuda(TestCase):
|
||||
self.assertEqual(res_cpu, res_cuda)
|
||||
torch.backends.cuda.matmul.allow_bf16_reduced_precision_reduction = orig_bf16
|
||||
torch.backends.cuda.matmul.allow_fp16_reduced_precision_reduction = orig_fp16
|
||||
torch.backends.cuda.matmul.allow_fp16_accumulation = orig_fp16_accumulate
|
||||
|
||||
@onlyCUDA
|
||||
@skipIfRocmVersionLessThan((5, 2))
|
||||
@ -144,36 +137,6 @@ class TestMatmulCuda(TestCase):
|
||||
def test_cublas_addmm_reduced_precision(self, size: int, dtype: torch.dtype):
|
||||
self.cublas_addmm(size, dtype, True)
|
||||
|
||||
@onlyCUDA
|
||||
@skipIfRocmVersionLessThan((5, 2))
|
||||
# imported 'tol' as 'xtol' to avoid aliasing in code above
|
||||
@toleranceOverride({torch.float16: xtol(atol=7e-1, rtol=2e-1),
|
||||
torch.bfloat16: xtol(atol=1e1, rtol=2e-1)})
|
||||
@dtypes(torch.float16, torch.bfloat16)
|
||||
@parametrize("size", [100, 1000, 10000])
|
||||
def test_cublas_addmm_reduced_precision_fp16_accumulate(self, size: int, dtype: torch.dtype):
|
||||
self.cublas_addmm(size, dtype, False, True)
|
||||
|
||||
@onlyCUDA
|
||||
@skipIfRocm
|
||||
def test_cublas_and_lt_reduced_precision_fp16_accumulate(self):
|
||||
orig_fp16_accumulate = torch.backends.cuda.matmul.allow_fp16_accumulation
|
||||
torch.backends.cuda.matmul.allow_fp16_accumulation = True
|
||||
x = torch.rand(32, 512, 512, device='cuda', dtype=torch.half)
|
||||
w = torch.rand(512, 512, device='cuda', dtype=torch.half)
|
||||
b = torch.rand(512, device='cuda', dtype=torch.half)
|
||||
out = torch.nn.functional.linear(x, w, b)
|
||||
out_cpu = torch.nn.functional.linear(x.cpu(), w.cpu(), b.cpu())
|
||||
self.assertEqual(out, out_cpu, atol=5e-3, rtol=8e-3)
|
||||
|
||||
a = torch.rand(16, 128, 128, device='cuda', dtype=torch.half)
|
||||
b = torch.rand(16, 128, 128, device='cuda', dtype=torch.half)
|
||||
c = torch.rand(16, 128, 128, device='cuda', dtype=torch.half)
|
||||
out = torch.baddbmm(a, b, c)
|
||||
out_cpu = torch.baddbmm(a.cpu(), b.cpu(), c.cpu())
|
||||
self.assertEqual(out, out_cpu, atol=1e-3, rtol=5e-3)
|
||||
torch.backends.cuda.matmul.allow_fp16_accumulation = orig_fp16_accumulate
|
||||
|
||||
@onlyCUDA
|
||||
@toleranceOverride({torch.float16: xtol(atol=1e-3, rtol=2e-3)})
|
||||
@dtypes(torch.float16)
|
||||
|
@ -1210,10 +1210,6 @@ def _get_cublas_allow_bf16_reduced_precision_reduction() -> _bool: ... # THPMod
|
||||
def _set_cublas_allow_bf16_reduced_precision_reduction(
|
||||
arg: _bool,
|
||||
) -> None: ... # THPModule_setAllowBF16ReductionCuBLAS
|
||||
def _get_cublas_allow_fp16_accumulation() -> _bool: ... # THPModule_allowFP16AccumulationCuBLAS
|
||||
def _set_cublas_allow_fp16_accumulation(
|
||||
arg: _bool,
|
||||
) -> None: ... # THPModule_setAllowFP16AccumulationCuBLAS
|
||||
def _set_conj(x: Tensor, conj: _bool) -> None: ...
|
||||
def _set_neg(x: Tensor, neg: _bool) -> None: ...
|
||||
def _set_meta_in_tls_dispatch_include(meta_in_tls: _bool) -> None: ...
|
||||
|
@ -133,8 +133,6 @@ class cuBLASModule:
|
||||
return torch._C._get_cublas_allow_fp16_reduced_precision_reduction()
|
||||
elif name == "allow_bf16_reduced_precision_reduction":
|
||||
return torch._C._get_cublas_allow_bf16_reduced_precision_reduction()
|
||||
elif name == "allow_fp16_accumulation":
|
||||
return torch._C._get_cublas_allow_fp16_accumulation()
|
||||
raise AttributeError("Unknown attribute " + name)
|
||||
|
||||
def __setattr__(self, name, value):
|
||||
@ -144,8 +142,6 @@ class cuBLASModule:
|
||||
return torch._C._set_cublas_allow_fp16_reduced_precision_reduction(value)
|
||||
elif name == "allow_bf16_reduced_precision_reduction":
|
||||
return torch._C._set_cublas_allow_bf16_reduced_precision_reduction(value)
|
||||
elif name == "allow_fp16_accumulation":
|
||||
return torch._C._set_cublas_allow_fp16_accumulation(value)
|
||||
raise AttributeError("Unknown attribute " + name)
|
||||
|
||||
|
||||
|
@ -1133,29 +1133,6 @@ static PyObject* THPModule_allowBF16ReductionCuBLAS(
|
||||
Py_RETURN_FALSE;
|
||||
}
|
||||
|
||||
static PyObject* THPModule_setAllowFP16AccumulationCuBLAS(
|
||||
PyObject* _unused,
|
||||
PyObject* arg) {
|
||||
HANDLE_TH_ERRORS
|
||||
TORCH_CHECK(
|
||||
PyBool_Check(arg),
|
||||
"set_allow_fp16_accumulation_cublas expects a bool, "
|
||||
"but got ",
|
||||
THPUtils_typename(arg));
|
||||
at::globalContext().setAllowFP16AccumulationCuBLAS(arg == Py_True);
|
||||
Py_RETURN_NONE;
|
||||
END_HANDLE_TH_ERRORS
|
||||
}
|
||||
|
||||
static PyObject* THPModule_allowFP16AccumulationCuBLAS(
|
||||
PyObject* _unused,
|
||||
PyObject* noargs) {
|
||||
if (at::globalContext().allowFP16AccumulationCuBLAS()) {
|
||||
Py_RETURN_TRUE;
|
||||
}
|
||||
Py_RETURN_FALSE;
|
||||
}
|
||||
|
||||
static PyObject* THPModule_setAllowFP16ReductionCPU(
|
||||
PyObject* _unused,
|
||||
PyObject* arg) {
|
||||
@ -1597,14 +1574,6 @@ static std::initializer_list<PyMethodDef> TorchMethods = {
|
||||
THPModule_setAllowBF16ReductionCuBLAS,
|
||||
METH_O,
|
||||
nullptr},
|
||||
{"_get_cublas_allow_fp16_accumulation",
|
||||
THPModule_allowFP16AccumulationCuBLAS,
|
||||
METH_NOARGS,
|
||||
nullptr},
|
||||
{"_set_cublas_allow_fp16_accumulation",
|
||||
THPModule_setAllowFP16AccumulationCuBLAS,
|
||||
METH_O,
|
||||
nullptr},
|
||||
{"_get_cpu_allow_fp16_reduced_precision_reduction",
|
||||
THPModule_allowFP16ReductionCPU,
|
||||
METH_NOARGS,
|
||||
|
Reference in New Issue
Block a user