Add split-K control to cuBLAS reduced-precision settings (#164766)

## Summary
- add a CuBLASReductionOption enum so the CUDA context can track reduced-precision and split-K options
- extend the Python bindings, backend helpers, and docs to accept an optional allow_splitk argument for fp16/bf16 matmul controls
- update cuBLAS/cuBLASLt call sites plus dynamo guards and tests to respect the new combinations

## Testing
- python test/test_cuda.py TestCuda.test_cublas_allow_fp16_reduced_precision_reduction_get_set -v *(fails: ModuleNotFoundError: No module named 'psutil')*

------
https://chatgpt.com/codex/tasks/task_e_68e404623178832f8a3e1d34e1e175da

Pull Request resolved: https://github.com/pytorch/pytorch/pull/164766
Approved by: https://github.com/malfet, https://github.com/albanD
This commit is contained in:
Natalia Gimelshein
2025-10-08 18:48:42 +00:00
committed by PyTorch MergeBot
parent 0b85236477
commit 37c6087334
10 changed files with 300 additions and 72 deletions

View File

@ -587,20 +587,33 @@ void Context::setROCmFAPreferredBackend(at::ROCmFABackend b) {
rocm_fa_preferred_backend = b;
}
bool Context::allowFP16ReductionCuBLAS() const {
CuBLASReductionOption Context::allowFP16ReductionCuBLAS() const {
return allow_fp16_reduction_cublas;
}
void Context::setAllowFP16ReductionCuBLAS(bool b) {
allow_fp16_reduction_cublas = b;
// Translates the two user-facing switches into a single CuBLASReductionOption.
//
// allow_reduced_precision: whether reduced-precision reductions are permitted.
// allow_splitk:            whether split-K is permitted.
//
// The pair (allow_reduced_precision=true, allow_splitk=false) has no matching
// enumerator and is rejected via TORCH_CHECK.
inline CuBLASReductionOption get_reduction_option(
    bool allow_reduced_precision,
    bool allow_splitk) {
  // De Morgan form of "not (reduced precision requested without split-K)".
  TORCH_CHECK(
      !allow_reduced_precision || allow_splitk,
      "allow_splitk=False is not supported when reduced precision reductions are enabled");
  if (allow_reduced_precision) {
    return CuBLASReductionOption::AllowReducedPrecisionWithSplitK;
  }
  // Reduced precision is off from here on; only the split-K choice remains.
  return allow_splitk
      ? CuBLASReductionOption::DisallowReducedPrecisionAllowSplitK
      : CuBLASReductionOption::DisallowReducedPrecisionDisallowSplitK;
}
bool Context::allowBF16ReductionCuBLAS() const {
void Context::setAllowFP16ReductionCuBLAS(bool allow_reduced_precision, bool allow_splitk) {
allow_fp16_reduction_cublas = get_reduction_option(allow_reduced_precision, allow_splitk);
}
CuBLASReductionOption Context::allowBF16ReductionCuBLAS() const {
return allow_bf16_reduction_cublas;
}
void Context::setAllowBF16ReductionCuBLAS(bool b) {
allow_bf16_reduction_cublas = b;
void Context::setAllowBF16ReductionCuBLAS(bool allow_reduced_precision, bool allow_splitk) {
allow_bf16_reduction_cublas = get_reduction_option(allow_reduced_precision, allow_splitk);
}
bool Context::allowFP16AccumulationCuBLAS() const {

View File

@ -38,6 +38,12 @@ namespace at {
class Tensor;
enum class TORCH_API Float32MatmulPrecision { HIGHEST, HIGH, MEDIUM };
// Combined setting for fp16/bf16 GEMM reductions: whether reduced-precision
// reductions are allowed and whether split-K is allowed.
// The numeric values are written to and read back from JSON as integers by the
// dynamo GlobalStateGuard, so keep them stable.
enum class CuBLASReductionOption : uint8_t {
// Reduced-precision reductions allowed; this is the default stored in Context.
AllowReducedPrecisionWithSplitK = 0,
// Reduced-precision reductions disallowed, split-K still allowed.
DisallowReducedPrecisionAllowSplitK = 1,
// Reduced-precision reductions and split-K both disallowed.
DisallowReducedPrecisionDisallowSplitK = 2,
};
enum class TORCH_API Float32Backend { GENERIC, CUDA, MKLDNN };
enum class TORCH_API Float32Op { ALL, CONV, RNN, MATMUL };
enum class TORCH_API Float32Precision { NONE, IEEE, TF32, BF16 };
@ -357,10 +363,14 @@ class TORCH_API Context {
void setAllowTF32CuBLAS(bool);
Float32MatmulPrecision float32MatmulPrecision() const;
Float32Precision float32Precision(Float32Backend backend, Float32Op op) const;
bool allowFP16ReductionCuBLAS() const;
void setAllowFP16ReductionCuBLAS(bool);
bool allowBF16ReductionCuBLAS() const;
void setAllowBF16ReductionCuBLAS(bool);
CuBLASReductionOption allowFP16ReductionCuBLAS() const;
void setAllowFP16ReductionCuBLAS(
bool allow_reduced_precision,
bool allow_splitk = true);
CuBLASReductionOption allowBF16ReductionCuBLAS() const;
void setAllowBF16ReductionCuBLAS(
bool allow_reduced_precision,
bool allow_splitk = true);
bool allowFP16AccumulationCuBLAS() const;
void setAllowFP16AccumulationCuBLAS(bool);
@ -452,8 +462,10 @@ class TORCH_API Context {
: at::Float32MatmulPrecision::HIGHEST;
int benchmark_limit_cudnn = 10;
bool allow_tf32_cudnn = true;
bool allow_fp16_reduction_cublas = true;
bool allow_bf16_reduction_cublas = true;
CuBLASReductionOption allow_fp16_reduction_cublas =
CuBLASReductionOption::AllowReducedPrecisionWithSplitK;
CuBLASReductionOption allow_bf16_reduction_cublas =
CuBLASReductionOption::AllowReducedPrecisionWithSplitK;
bool allow_fp16_accumulation_cublas = false;
std::optional<int32_t> sm_carveout = std::nullopt;
bool enabled_mkldnn = true;

View File

@ -422,18 +422,34 @@ static inline bool bgemm_internal_cublaslt(CUDABLAS_BGEMM_ARGTYPES_AND_C_DTYPE(D
abType = CUDA_R_16F;
cType = (std::is_same_v<C_Dtype, float>) ? CUDA_R_32F : CUDA_R_16F;
#ifndef USE_ROCM
if (!at::globalContext().allowFP16ReductionCuBLAS()) {
preference.setAttribute(CUBLASLT_MATMUL_PREF_REDUCTION_SCHEME_MASK,
CUBLASLT_REDUCTION_SCHEME_COMPUTE_TYPE | CUBLASLT_REDUCTION_SCHEME_NONE);
auto fp16_reduction = at::globalContext().allowFP16ReductionCuBLAS();
if (fp16_reduction !=
at::CuBLASReductionOption::AllowReducedPrecisionWithSplitK) {
uint32_t mask =
fp16_reduction ==
at::CuBLASReductionOption::DisallowReducedPrecisionAllowSplitK
? (CUBLASLT_REDUCTION_SCHEME_COMPUTE_TYPE |
CUBLASLT_REDUCTION_SCHEME_NONE)
: CUBLASLT_REDUCTION_SCHEME_NONE;
preference.setAttribute(
CUBLASLT_MATMUL_PREF_REDUCTION_SCHEME_MASK, mask);
}
#endif
} else if constexpr (std::is_same_v<Dtype, at::BFloat16>) {
abType = CUDA_R_16BF;
cType = (std::is_same_v<C_Dtype, float>) ? CUDA_R_32F : CUDA_R_16BF;
#ifndef USE_ROCM
if (!at::globalContext().allowBF16ReductionCuBLAS()) {
preference.setAttribute(CUBLASLT_MATMUL_PREF_REDUCTION_SCHEME_MASK,
CUBLASLT_REDUCTION_SCHEME_COMPUTE_TYPE | CUBLASLT_REDUCTION_SCHEME_NONE);
auto bf16_reduction = at::globalContext().allowBF16ReductionCuBLAS();
if (bf16_reduction !=
at::CuBLASReductionOption::AllowReducedPrecisionWithSplitK) {
uint32_t mask =
bf16_reduction ==
at::CuBLASReductionOption::DisallowReducedPrecisionAllowSplitK
? (CUBLASLT_REDUCTION_SCHEME_COMPUTE_TYPE |
CUBLASLT_REDUCTION_SCHEME_NONE)
: CUBLASLT_REDUCTION_SCHEME_NONE;
preference.setAttribute(
CUBLASLT_MATMUL_PREF_REDUCTION_SCHEME_MASK, mask);
}
#endif
} else {
@ -1120,8 +1136,15 @@ inline void gemm_internal_cublas_half_helper(CUDABLAS_GEMM_ARGTYPES_AND_C_DTYPE(
}
if (prop->major >= 5) {
cublasMath_t cublas_flags = CUBLAS_DEFAULT_MATH;
if (!at::globalContext().allowFP16ReductionCuBLAS()) {
cublas_flags = static_cast<cublasMath_t>(cublas_flags | CUBLAS_MATH_DISALLOW_REDUCED_PRECISION_REDUCTION);
auto fp16_reduction = at::globalContext().allowFP16ReductionCuBLAS();
TORCH_CHECK(fp16_reduction !=
at::CuBLASReductionOption::DisallowReducedPrecisionDisallowSplitK,
"torch.backends.cuda.matmul.allow_fp16_reduced_precision_reduction("
"..., allow_splitk=False) requires the cuBLASLt backend");
if (fp16_reduction !=
at::CuBLASReductionOption::AllowReducedPrecisionWithSplitK) {
cublas_flags = static_cast<cublasMath_t>(
cublas_flags | CUBLAS_MATH_DISALLOW_REDUCED_PRECISION_REDUCTION);
}
// Disallow fp16 reductions that could lead to unexpected overflow issues.
TORCH_CUDABLAS_CHECK(cublasSetMathMode(handle, cublas_flags));
@ -1180,8 +1203,15 @@ inline void gemm_internal_cublas_bfloat16_helper(CUDABLAS_GEMM_ARGTYPES_AND_C_DT
GEMM_CHECK_ARGVALUES(at::BFloat16);
#ifndef USE_ROCM
cublasMath_t cublas_flags = CUBLAS_DEFAULT_MATH;
if (!at::globalContext().allowBF16ReductionCuBLAS()) {
cublas_flags = static_cast<cublasMath_t>(cublas_flags | CUBLAS_MATH_DISALLOW_REDUCED_PRECISION_REDUCTION);
auto bf16_reduction = at::globalContext().allowBF16ReductionCuBLAS();
// The split-K opt-out is only implemented on the cuBLASLt path (through the
// reduction-scheme mask), so reject that setting here rather than silently
// ignoring it. The message must name the bf16 attribute so users are pointed
// at the setting that actually triggered the failure.
TORCH_CHECK(bf16_reduction !=
    at::CuBLASReductionOption::DisallowReducedPrecisionDisallowSplitK,
    "torch.backends.cuda.matmul.allow_bf16_reduced_precision_reduction("
    "..., allow_splitk=False) requires the cuBLASLt backend");
// Any setting other than "allow reduced precision" disables reduced-precision
// reductions through the cuBLAS math-mode flag.
if (bf16_reduction !=
    at::CuBLASReductionOption::AllowReducedPrecisionWithSplitK) {
  cublas_flags = static_cast<cublasMath_t>(
      cublas_flags | CUBLAS_MATH_DISALLOW_REDUCED_PRECISION_REDUCTION);
}
#endif
#if defined(USE_ROCM)
@ -1577,18 +1607,34 @@ bool gemm_and_bias(
abType = CUDA_R_16F;
cType = (std::is_same_v<C_Dtype, float>) ? CUDA_R_32F : CUDA_R_16F;
#ifndef USE_ROCM
if (!at::globalContext().allowFP16ReductionCuBLAS()) {
preference.setAttribute(CUBLASLT_MATMUL_PREF_REDUCTION_SCHEME_MASK,
CUBLASLT_REDUCTION_SCHEME_COMPUTE_TYPE | CUBLASLT_REDUCTION_SCHEME_NONE);
auto fp16_reduction = at::globalContext().allowFP16ReductionCuBLAS();
if (fp16_reduction !=
at::CuBLASReductionOption::AllowReducedPrecisionWithSplitK) {
uint32_t mask =
fp16_reduction ==
at::CuBLASReductionOption::DisallowReducedPrecisionAllowSplitK
? (CUBLASLT_REDUCTION_SCHEME_COMPUTE_TYPE |
CUBLASLT_REDUCTION_SCHEME_NONE)
: CUBLASLT_REDUCTION_SCHEME_NONE;
preference.setAttribute(
CUBLASLT_MATMUL_PREF_REDUCTION_SCHEME_MASK, mask);
}
#endif
} else if constexpr (std::is_same_v<Dtype, at::BFloat16>) {
abType = CUDA_R_16BF;
cType = (std::is_same_v<C_Dtype, float>) ? CUDA_R_32F : CUDA_R_16BF;
#ifndef USE_ROCM
if (!at::globalContext().allowBF16ReductionCuBLAS()) {
preference.setAttribute(CUBLASLT_MATMUL_PREF_REDUCTION_SCHEME_MASK,
CUBLASLT_REDUCTION_SCHEME_COMPUTE_TYPE | CUBLASLT_REDUCTION_SCHEME_NONE);
auto bf16_reduction = at::globalContext().allowBF16ReductionCuBLAS();
if (bf16_reduction !=
at::CuBLASReductionOption::AllowReducedPrecisionWithSplitK) {
uint32_t mask =
bf16_reduction ==
at::CuBLASReductionOption::DisallowReducedPrecisionAllowSplitK
? (CUBLASLT_REDUCTION_SCHEME_COMPUTE_TYPE |
CUBLASLT_REDUCTION_SCHEME_NONE)
: CUBLASLT_REDUCTION_SCHEME_NONE;
preference.setAttribute(
CUBLASLT_MATMUL_PREF_REDUCTION_SCHEME_MASK, mask);
}
#endif
}

View File

@ -61,12 +61,16 @@ These backends include:
.. attribute:: allow_fp16_reduced_precision_reduction
A :class:`bool` that controls whether reduced precision reductions (e.g., with fp16 accumulation type) are allowed with fp16 GEMMs.
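Assigning a tuple ``(allow_reduced_precision, allow_splitk)`` additionally controls whether
split-K heuristics may be used when dispatching to cuBLASLt. ``allow_splitk`` defaults to ``True``.
``allow_splitk=False`` is only accepted together with ``allow_reduced_precision=False``; assigning ``(True, False)`` raises a :class:`RuntimeError`.
Reading this attribute still returns only the ``allow_reduced_precision`` flag; the split-K flag is available as ``allow_fp16_reduced_precision_reduction_split_k``.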
```
```{eval-rst}
.. attribute:: allow_bf16_reduced_precision_reduction
A :class:`bool` that controls whether reduced precision reductions are allowed with bf16 GEMMs.
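Assigning a tuple ``(allow_reduced_precision, allow_splitk)`` additionally controls whether
split-K heuristics may be used when dispatching to cuBLASLt. ``allow_splitk`` defaults to ``True``.
``allow_splitk=False`` is only accepted together with ``allow_reduced_precision=False``; assigning ``(True, False)`` raises a :class:`RuntimeError`.
Reading this attribute still returns only the ``allow_reduced_precision`` flag; the split-K flag is available as ``allow_bf16_reduced_precision_reduction_split_k``.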
```
```{eval-rst}

View File

@ -749,25 +749,67 @@ print(t.is_pinned())
def test_cublas_allow_fp16_reduced_precision_reduction_get_set(self):
orig = torch.backends.cuda.matmul.allow_fp16_reduced_precision_reduction
orig_splitk = (
torch.backends.cuda.matmul.allow_fp16_reduced_precision_reduction_split_k
)
self.assertEqual(
torch._C._get_cublas_allow_fp16_reduced_precision_reduction(), orig
torch._C._get_cublas_allow_fp16_reduced_precision_reduction(),
(orig, orig_splitk),
)
torch.backends.cuda.matmul.allow_fp16_reduced_precision_reduction = not orig
self.assertEqual(
torch._C._get_cublas_allow_fp16_reduced_precision_reduction(), not orig
torch._C._get_cublas_allow_fp16_reduced_precision_reduction(),
(not orig, True),
)
torch.backends.cuda.matmul.allow_fp16_reduced_precision_reduction = (
False,
False,
)
self.assertEqual(
torch._C._get_cublas_allow_fp16_reduced_precision_reduction(),
(False, False),
)
with self.assertRaisesRegex(RuntimeError, "allow_splitk=False"):
torch.backends.cuda.matmul.allow_fp16_reduced_precision_reduction = (
True,
False,
)
torch.backends.cuda.matmul.allow_fp16_reduced_precision_reduction = (
orig,
orig_splitk,
)
torch.backends.cuda.matmul.allow_fp16_reduced_precision_reduction = orig
def test_cublas_allow_bf16_reduced_precision_reduction_get_set(self):
orig = torch.backends.cuda.matmul.allow_bf16_reduced_precision_reduction
orig_splitk = (
torch.backends.cuda.matmul.allow_bf16_reduced_precision_reduction_split_k
)
self.assertEqual(
torch._C._get_cublas_allow_bf16_reduced_precision_reduction(), orig
torch._C._get_cublas_allow_bf16_reduced_precision_reduction(),
(orig, orig_splitk),
)
torch.backends.cuda.matmul.allow_bf16_reduced_precision_reduction = not orig
self.assertEqual(
torch._C._get_cublas_allow_bf16_reduced_precision_reduction(), not orig
torch._C._get_cublas_allow_bf16_reduced_precision_reduction(),
(not orig, True),
)
torch.backends.cuda.matmul.allow_bf16_reduced_precision_reduction = (
False,
False,
)
self.assertEqual(
torch._C._get_cublas_allow_bf16_reduced_precision_reduction(),
(False, False),
)
with self.assertRaisesRegex(RuntimeError, "allow_splitk=False"):
torch.backends.cuda.matmul.allow_bf16_reduced_precision_reduction = (
True,
False,
)
torch.backends.cuda.matmul.allow_bf16_reduced_precision_reduction = (
orig,
orig_splitk,
)
torch.backends.cuda.matmul.allow_bf16_reduced_precision_reduction = orig
def test_cublas_allow_fp16_accumulation_get_set(self):
orig = torch.backends.cuda.matmul.allow_fp16_accumulation

View File

@ -1254,17 +1254,19 @@ def _get_float32_matmul_precision() -> str: ... # THPModule_float32MatmulPrecis
def _set_float32_matmul_precision(
arg: str,
) -> None: ... # THPModule_setFloat32MatmulPrecision
def _get_cublas_allow_fp16_reduced_precision_reduction() -> (
_bool
): ... # THPModule_allowFP16ReductionCuBLAS
def _get_cublas_allow_fp16_reduced_precision_reduction() -> tuple[
_bool, _bool
]: ... # THPModule_allowFP16ReductionCuBLAS
def _set_cublas_allow_fp16_reduced_precision_reduction(
arg: _bool,
allow_splitk: _bool = ...,
) -> None: ... # THPModule_setAllowFP16ReductionCuBLAS
def _get_cublas_allow_bf16_reduced_precision_reduction() -> (
_bool
): ... # THPModule_allowBF16ReductionCuBLAS
def _get_cublas_allow_bf16_reduced_precision_reduction() -> tuple[
_bool, _bool
]: ... # THPModule_allowBF16ReductionCuBLAS
def _set_cublas_allow_bf16_reduced_precision_reduction(
arg: _bool,
allow_splitk: _bool = ...,
) -> None: ... # THPModule_setAllowBF16ReductionCuBLAS
def _get_cublas_allow_fp16_accumulation() -> (
_bool

View File

@ -44,7 +44,18 @@ if TYPE_CHECKING:
Node = torch.fx.Node
Region = list[Node]
IdenticalNodes = list[Node]
GlobalStateKey = tuple[bool, bool, int, bool, bool, torch.dtype, bool, bool, bool, bool]
GlobalStateKey = tuple[
bool,
bool,
int,
tuple[bool, bool],
tuple[bool, bool],
torch.dtype,
bool,
bool,
bool,
bool,
]
log = logging.getLogger(__name__)
graph_expansion_log = torch._logging.getArtifactLogger(

View File

@ -1,6 +1,6 @@
# mypy: allow-untyped-defs
import contextlib
from typing import Union
from typing import Any, Union
from typing_extensions import deprecated
import torch
@ -126,13 +126,54 @@ class cuFFTPlanCacheManager:
class cuBLASModule:
@staticmethod
def _parse_reduction_setting(value: Any, attr_name: str) -> tuple[bool, bool]:
def _ensure_bool(obj: Any, which: str) -> bool:
if isinstance(obj, bool):
return obj
raise TypeError(
f"{attr_name} expects a bool for {which}, but got {type(obj)!r}"
)
if isinstance(value, bool):
return value, True
if isinstance(value, (list, tuple)):
if not value:
raise TypeError(f"{attr_name} expects at least one boolean argument")
if len(value) > 2:
raise TypeError(f"{attr_name} expects at most two boolean arguments")
allow_reduced_precision = _ensure_bool(value[0], "allow_reduced_precision")
if len(value) == 1:
return allow_reduced_precision, True
allow_splitk = _ensure_bool(value[1], "allow_splitk")
return allow_reduced_precision, allow_splitk
raise TypeError(
f"{attr_name} expects a bool or a tuple/list of bools, but got {type(value)!r}"
)
def __getattr__(self, name):
if name == "allow_tf32":
return torch._C._get_cublas_allow_tf32()
elif name == "allow_fp16_reduced_precision_reduction":
return torch._C._get_cublas_allow_fp16_reduced_precision_reduction()
allow_reduced_precision, _ = (
torch._C._get_cublas_allow_fp16_reduced_precision_reduction()
)
return allow_reduced_precision
elif name == "allow_fp16_reduced_precision_reduction_split_k":
_, allow_splitk = (
torch._C._get_cublas_allow_fp16_reduced_precision_reduction()
)
return allow_splitk
elif name == "allow_bf16_reduced_precision_reduction":
return torch._C._get_cublas_allow_bf16_reduced_precision_reduction()
allow_reduced_precision, _ = (
torch._C._get_cublas_allow_bf16_reduced_precision_reduction()
)
return allow_reduced_precision
elif name == "allow_bf16_reduced_precision_reduction_split_k":
_, allow_splitk = (
torch._C._get_cublas_allow_bf16_reduced_precision_reduction()
)
return allow_splitk
elif name == "allow_fp16_accumulation":
return torch._C._get_cublas_allow_fp16_accumulation()
elif name == "fp32_precision":
@ -143,9 +184,19 @@ class cuBLASModule:
if name == "allow_tf32":
return torch._C._set_cublas_allow_tf32(value)
elif name == "allow_fp16_reduced_precision_reduction":
return torch._C._set_cublas_allow_fp16_reduced_precision_reduction(value)
allow_reduced_precision, allow_splitk = self._parse_reduction_setting(
value, "allow_fp16_reduced_precision_reduction"
)
return torch._C._set_cublas_allow_fp16_reduced_precision_reduction(
allow_reduced_precision, allow_splitk
)
elif name == "allow_bf16_reduced_precision_reduction":
return torch._C._set_cublas_allow_bf16_reduced_precision_reduction(value)
allow_reduced_precision, allow_splitk = self._parse_reduction_setting(
value, "allow_bf16_reduced_precision_reduction"
)
return torch._C._set_cublas_allow_bf16_reduced_precision_reduction(
allow_reduced_precision, allow_splitk
)
elif name == "allow_fp16_accumulation":
return torch._C._set_cublas_allow_fp16_accumulation(value)
elif name == "fp32_precision":

View File

@ -1224,14 +1224,30 @@ static PyObject* THPModule_allowTF32CuBLAS(
static PyObject* THPModule_setAllowFP16ReductionCuBLAS(
PyObject* _unused,
PyObject* arg) {
PyObject* args) {
HANDLE_TH_ERRORS
PyObject* allow_reduction_obj = nullptr;
PyObject* allow_splitk_obj = Py_None;
if (!PyArg_ParseTuple(args, "O|O", &allow_reduction_obj, &allow_splitk_obj)) {
return nullptr;
}
TORCH_CHECK(
PyBool_Check(arg),
"set_allow_fp16_reduction_cublas expects a bool, "
PyBool_Check(allow_reduction_obj),
"set_allow_fp16_reduction_cublas expects a bool for allow_reduced_precision, "
"but got ",
THPUtils_typename(arg));
at::globalContext().setAllowFP16ReductionCuBLAS(arg == Py_True);
THPUtils_typename(allow_reduction_obj));
bool allow_reduction = allow_reduction_obj == Py_True;
bool allow_splitk = true;
if (allow_splitk_obj != Py_None) {
TORCH_CHECK(
PyBool_Check(allow_splitk_obj),
"set_allow_fp16_reduction_cublas expects a bool for allow_splitk, "
"but got ",
THPUtils_typename(allow_splitk_obj));
allow_splitk = allow_splitk_obj == Py_True;
}
at::globalContext().setAllowFP16ReductionCuBLAS(
allow_reduction, allow_splitk);
Py_RETURN_NONE;
END_HANDLE_TH_ERRORS
}
@ -1239,22 +1255,43 @@ static PyObject* THPModule_setAllowFP16ReductionCuBLAS(
static PyObject* THPModule_allowFP16ReductionCuBLAS(
PyObject* _unused,
PyObject* noargs) {
if (at::globalContext().allowFP16ReductionCuBLAS()) {
Py_RETURN_TRUE;
}
Py_RETURN_FALSE;
auto option = at::globalContext().allowFP16ReductionCuBLAS();
bool allow_reduced_precision =
option == at::CuBLASReductionOption::AllowReducedPrecisionWithSplitK;
bool allow_splitk = option !=
at::CuBLASReductionOption::DisallowReducedPrecisionDisallowSplitK;
return PyTuple_Pack(
2,
allow_reduced_precision ? Py_True : Py_False,
allow_splitk ? Py_True : Py_False);
}
static PyObject* THPModule_setAllowBF16ReductionCuBLAS(
PyObject* _unused,
PyObject* arg) {
PyObject* args) {
HANDLE_TH_ERRORS
PyObject* allow_reduction_obj = nullptr;
PyObject* allow_splitk_obj = Py_None;
if (!PyArg_ParseTuple(args, "O|O", &allow_reduction_obj, &allow_splitk_obj)) {
return nullptr;
}
TORCH_CHECK(
PyBool_Check(arg),
"set_allow_bf16_reduction_cublas expects a bool, "
PyBool_Check(allow_reduction_obj),
"set_allow_bf16_reduction_cublas expects a bool for allow_reduced_precision, "
"but got ",
THPUtils_typename(arg));
at::globalContext().setAllowBF16ReductionCuBLAS(arg == Py_True);
THPUtils_typename(allow_reduction_obj));
bool allow_reduction = allow_reduction_obj == Py_True;
bool allow_splitk = true;
if (allow_splitk_obj != Py_None) {
TORCH_CHECK(
PyBool_Check(allow_splitk_obj),
"set_allow_bf16_reduction_cublas expects a bool for allow_splitk, "
"but got ",
THPUtils_typename(allow_splitk_obj));
allow_splitk = allow_splitk_obj == Py_True;
}
at::globalContext().setAllowBF16ReductionCuBLAS(
allow_reduction, allow_splitk);
Py_RETURN_NONE;
END_HANDLE_TH_ERRORS
}
@ -1262,10 +1299,15 @@ static PyObject* THPModule_setAllowBF16ReductionCuBLAS(
static PyObject* THPModule_allowBF16ReductionCuBLAS(
PyObject* _unused,
PyObject* noargs) {
if (at::globalContext().allowBF16ReductionCuBLAS()) {
Py_RETURN_TRUE;
}
Py_RETURN_FALSE;
auto option = at::globalContext().allowBF16ReductionCuBLAS();
bool allow_reduced_precision =
option == at::CuBLASReductionOption::AllowReducedPrecisionWithSplitK;
bool allow_splitk = option !=
at::CuBLASReductionOption::DisallowReducedPrecisionDisallowSplitK;
return PyTuple_Pack(
2,
allow_reduced_precision ? Py_True : Py_False,
allow_splitk ? Py_True : Py_False);
}
static PyObject* THPModule_setAllowFP16AccumulationCuBLAS(
@ -1736,7 +1778,7 @@ static std::initializer_list<PyMethodDef> TorchMethods = {
nullptr},
{"_set_cublas_allow_fp16_reduced_precision_reduction",
THPModule_setAllowFP16ReductionCuBLAS,
METH_O,
METH_VARARGS,
nullptr},
{"_get_cublas_allow_bf16_reduced_precision_reduction",
THPModule_allowBF16ReductionCuBLAS,
@ -1744,7 +1786,7 @@ static std::initializer_list<PyMethodDef> TorchMethods = {
nullptr},
{"_set_cublas_allow_bf16_reduced_precision_reduction",
THPModule_setAllowBF16ReductionCuBLAS,
METH_O,
METH_VARARGS,
nullptr},
{"_get_cublas_allow_fp16_accumulation",
THPModule_allowFP16AccumulationCuBLAS,

View File

@ -19,6 +19,7 @@
#include <torch/csrc/utils/python_symnode.h>
#include <torch/csrc/utils/pythoncapi_compat.h>
#include <torch/extension.h>
#include <cstdint>
#include <torch/csrc/dynamo/debug_macros.h>
@ -703,8 +704,10 @@ struct GlobalStateGuard {
json_j["deterministic_algorithms_warn_only"] =
json_t._deterministic_algorithms_warn_only;
json_j["allow_tf32"] = json_t._allow_tf32;
json_j["allow_fp16_reduce"] = json_t._allow_fp16_reduce;
json_j["allow_bf16_reduce"] = json_t._allow_bf16_reduce;
json_j["allow_fp16_reduce"] =
static_cast<int64_t>(json_t._allow_fp16_reduce);
json_j["allow_bf16_reduce"] =
static_cast<int64_t>(json_t._allow_bf16_reduce);
json_j["num_threads"] = json_t._num_threads;
json_j["default_dtype"] = json_t._default_dtype.toScalarType();
}
@ -720,8 +723,10 @@ struct GlobalStateGuard {
json_t._deterministic_algorithms_warn_only =
json_j.at("deterministic_algorithms_warn_only");
json_t._allow_tf32 = json_j.at("allow_tf32");
json_t._allow_fp16_reduce = json_j.at("allow_fp16_reduce");
json_t._allow_bf16_reduce = json_j.at("allow_bf16_reduce");
json_t._allow_fp16_reduce = static_cast<at::CuBLASReductionOption>(
static_cast<int64_t>(json_j.at("allow_fp16_reduce")));
json_t._allow_bf16_reduce = static_cast<at::CuBLASReductionOption>(
static_cast<int64_t>(json_j.at("allow_bf16_reduce")));
json_t._num_threads = json_j.at("num_threads");
json_t._default_dtype =
caffe2::TypeMeta::fromScalarType(json_j.at("default_dtype"));
@ -734,8 +739,8 @@ struct GlobalStateGuard {
bool _deterministic_algorithms;
bool _deterministic_algorithms_warn_only;
bool _allow_tf32;
bool _allow_fp16_reduce;
bool _allow_bf16_reduce;
at::CuBLASReductionOption _allow_fp16_reduce;
at::CuBLASReductionOption _allow_bf16_reduce;
int _num_threads;
caffe2::TypeMeta _default_dtype;
// TODO(jansel): we should guard on more state as inductor starts using it