[CUDA][cuBLAS] Add fp16 accumulate option to cuBLAS/cuBLASLt (#144441)

A test for `cublasGemmEx` is added; still need to figure out the best way to exercise the other APIs...

Pull Request resolved: https://github.com/pytorch/pytorch/pull/144441
Approved by: https://github.com/Chillee
Author: eqy
Date: 2025-01-11 15:30:38 +00:00
Committed by: PyTorch MergeBot
Commit: 388b75edec (parent: 2e3b051154)
9 changed files with 188 additions and 13 deletions
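
The new flag lets half-precision cuBLAS/cuBLASLt GEMMs accumulate in fp16 rather than the default fp32, which can be faster on some GPUs at the cost of accuracy. As a minimal sketch of why the accumulator precision matters (emulated in pure PyTorch on the CPU, not the cuBLAS code path this PR touches):

```python
import torch

# Summing many small fp16 products in an fp16 accumulator loses low-order
# bits that an fp32 accumulator retains.
x = torch.full((4096,), 0.01, dtype=torch.half)
prods = x * x

acc16 = torch.zeros((), dtype=torch.half)
for p in prods:  # emulate an fp16-accumulated reduction
    acc16 = acc16 + p
acc32 = prods.float().sum()  # fp32 accumulation of the same products

print(acc16.item(), acc32.item())  # the fp16 accumulator drifts visibly
```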


@@ -1133,6 +1133,29 @@ static PyObject* THPModule_allowBF16ReductionCuBLAS(
  Py_RETURN_FALSE;
}

static PyObject* THPModule_setAllowFP16AccumulationCuBLAS(
    PyObject* _unused,
    PyObject* arg) {
  HANDLE_TH_ERRORS
  TORCH_CHECK(
      PyBool_Check(arg),
      "set_allow_fp16_accumulation_cublas expects a bool, "
      "but got ",
      THPUtils_typename(arg));
  at::globalContext().setAllowFP16AccumulationCuBLAS(arg == Py_True);
  Py_RETURN_NONE;
  END_HANDLE_TH_ERRORS
}

static PyObject* THPModule_allowFP16AccumulationCuBLAS(
    PyObject* _unused,
    PyObject* noargs) {
  if (at::globalContext().allowFP16AccumulationCuBLAS()) {
    Py_RETURN_TRUE;
  }
  Py_RETURN_FALSE;
}

static PyObject* THPModule_setAllowFP16ReductionCPU(
    PyObject* _unused,
    PyObject* arg) {
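
From Python, the new pair behaves as a plain getter/setter on the global context; the bindings are registered on `torch._C` in the method table below. A quick round-trip, assuming a build that contains this commit:

```python
import torch

# The getter returns the bool singletons (Py_RETURN_TRUE/Py_RETURN_FALSE),
# so identity comparison is safe.
prev = torch._C._get_cublas_allow_fp16_accumulation()
torch._C._set_cublas_allow_fp16_accumulation(True)
assert torch._C._get_cublas_allow_fp16_accumulation() is True

# The TORCH_CHECK rejects non-bool arguments (surfaced as RuntimeError).
try:
    torch._C._set_cublas_allow_fp16_accumulation(1)
except RuntimeError as e:
    print(e)  # expects a bool, but got int

torch._C._set_cublas_allow_fp16_accumulation(prev)  # restore prior state
```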
@@ -1574,6 +1597,14 @@ static std::initializer_list<PyMethodDef> TorchMethods = {
     THPModule_setAllowBF16ReductionCuBLAS,
     METH_O,
     nullptr},
    {"_get_cublas_allow_fp16_accumulation",
     THPModule_allowFP16AccumulationCuBLAS,
     METH_NOARGS,
     nullptr},
    {"_set_cublas_allow_fp16_accumulation",
     THPModule_setAllowFP16AccumulationCuBLAS,
     METH_O,
     nullptr},
    {"_get_cpu_allow_fp16_reduced_precision_reduction",
     THPModule_allowFP16ReductionCPU,
     METH_NOARGS,
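
Putting it together, a sketch of toggling the flag around a half-precision matmul (assumes a CUDA device; the actual speed and accuracy impact depends on the GPU and cuBLAS version):

```python
import torch

a = torch.randn(1024, 1024, device="cuda", dtype=torch.half)
b = torch.randn(1024, 1024, device="cuda", dtype=torch.half)
ref = (a.float() @ b.float()).half()  # fp32-accumulated reference

torch._C._set_cublas_allow_fp16_accumulation(False)  # default behavior
err_fp32_acc = (a @ b - ref).abs().max().item()

torch._C._set_cublas_allow_fp16_accumulation(True)  # allow fp16 accumulate
err_fp16_acc = (a @ b - ref).abs().max().item()

# fp16 accumulation typically shows larger error against the reference.
print(err_fp32_acc, err_fp16_acc)
```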