mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 12:54:11 +08:00
Add split-K control to cuBLAS reduced-precision settings (#164766)
## Summary - add a CuBLASReductionOption enum so the CUDA context can track reduced-precision and split-K options - extend the Python bindings, backend helpers, and docs to accept an optional allow_splitk argument for fp16/bf16 matmul controls - update cuBLAS/cuBLASLt call sites plus dynamo guards and tests to respect the new combinations ## Testing - python test/test_cuda.py TestCuda.test_cublas_allow_fp16_reduced_precision_reduction_get_set -v *(fails: ModuleNotFoundError: No module named 'psutil')* ------ https://chatgpt.com/codex/tasks/task_e_68e404623178832f8a3e1d34e1e175da Pull Request resolved: https://github.com/pytorch/pytorch/pull/164766 Approved by: https://github.com/malfet, https://github.com/albanD
This commit is contained in:
committed by
PyTorch MergeBot
parent
0b85236477
commit
37c6087334
@ -587,20 +587,33 @@ void Context::setROCmFAPreferredBackend(at::ROCmFABackend b) {
|
||||
rocm_fa_preferred_backend = b;
|
||||
}
|
||||
|
||||
bool Context::allowFP16ReductionCuBLAS() const {
|
||||
// Returns the fp16 cuBLAS reduction policy currently stored on this Context.
// The result is the tri-state CuBLASReductionOption rather than a plain bool,
// so callers must compare against the enum values explicitly.
CuBLASReductionOption Context::allowFP16ReductionCuBLAS() const {
  const CuBLASReductionOption current = this->allow_fp16_reduction_cublas;
  return current;
}
|
||||
|
||||
void Context::setAllowFP16ReductionCuBLAS(bool b) {
|
||||
allow_fp16_reduction_cublas = b;
|
||||
// Collapses the (allow_reduced_precision, allow_splitk) flag pair into a
// single CuBLASReductionOption value.
//
// The combination allow_reduced_precision=true / allow_splitk=false has no
// matching enum value; the TORCH_CHECK below fails for it.
CuBLASReductionOption inline get_reduction_option(bool allow_reduced_precision, bool allow_splitk) {
  TORCH_CHECK(
      !(allow_reduced_precision && !allow_splitk),
      "allow_splitk=False is not supported when reduced precision reductions are enabled");

  // Past the check, reduced precision implies split-K is allowed as well.
  if (allow_reduced_precision) {
    return CuBLASReductionOption::AllowReducedPrecisionWithSplitK;
  }

  // Reduced precision is off; only the split-K flag distinguishes the rest.
  return allow_splitk
      ? CuBLASReductionOption::DisallowReducedPrecisionAllowSplitK
      : CuBLASReductionOption::DisallowReducedPrecisionDisallowSplitK;
}
|
||||
|
||||
bool Context::allowBF16ReductionCuBLAS() const {
|
||||
void Context::setAllowFP16ReductionCuBLAS(bool allow_reduced_precision, bool allow_splitk) {
|
||||
allow_fp16_reduction_cublas = get_reduction_option(allow_reduced_precision, allow_splitk);
|
||||
}
|
||||
|
||||
// Returns the bf16 cuBLAS reduction policy currently stored on this Context.
// The result is the tri-state CuBLASReductionOption rather than a plain bool,
// so callers must compare against the enum values explicitly.
CuBLASReductionOption Context::allowBF16ReductionCuBLAS() const {
  const CuBLASReductionOption current = this->allow_bf16_reduction_cublas;
  return current;
}
|
||||
|
||||
void Context::setAllowBF16ReductionCuBLAS(bool b) {
|
||||
allow_bf16_reduction_cublas = b;
|
||||
void Context::setAllowBF16ReductionCuBLAS(bool allow_reduced_precision, bool allow_splitk) {
|
||||
allow_bf16_reduction_cublas = get_reduction_option(allow_reduced_precision, allow_splitk);
|
||||
}
|
||||
|
||||
bool Context::allowFP16AccumulationCuBLAS() const {
|
||||
|
@ -38,6 +38,12 @@ namespace at {
|
||||
class Tensor;
|
||||
|
||||
enum class TORCH_API Float32MatmulPrecision { HIGHEST, HIGH, MEDIUM };
|
||||
|
||||
// Combined policy for fp16/bf16 cuBLAS GEMM reductions: whether reduced
// precision reductions are allowed and whether split-K is allowed.
// There is deliberately no "reduced precision without split-K" value.
// NOTE(review): the numeric values appear to be serialized as integers by the
// dynamo GlobalStateGuard JSON conversion, so they should presumably stay
// stable -- confirm before renumbering.
enum class CuBLASReductionOption : uint8_t {
|
||||
// Reduced precision reductions allowed; split-K allowed.
AllowReducedPrecisionWithSplitK = 0,
|
||||
// Reduced precision reductions disallowed; split-K still allowed.
DisallowReducedPrecisionAllowSplitK = 1,
|
||||
// Reduced precision reductions disallowed and split-K disallowed.
DisallowReducedPrecisionDisallowSplitK = 2,
|
||||
};
|
||||
enum class TORCH_API Float32Backend { GENERIC, CUDA, MKLDNN };
|
||||
enum class TORCH_API Float32Op { ALL, CONV, RNN, MATMUL };
|
||||
enum class TORCH_API Float32Precision { NONE, IEEE, TF32, BF16 };
|
||||
@ -357,10 +363,14 @@ class TORCH_API Context {
|
||||
void setAllowTF32CuBLAS(bool);
|
||||
Float32MatmulPrecision float32MatmulPrecision() const;
|
||||
Float32Precision float32Precision(Float32Backend backend, Float32Op op) const;
|
||||
bool allowFP16ReductionCuBLAS() const;
|
||||
void setAllowFP16ReductionCuBLAS(bool);
|
||||
bool allowBF16ReductionCuBLAS() const;
|
||||
void setAllowBF16ReductionCuBLAS(bool);
|
||||
CuBLASReductionOption allowFP16ReductionCuBLAS() const;
|
||||
void setAllowFP16ReductionCuBLAS(
|
||||
bool allow_reduced_precision,
|
||||
bool allow_splitk = true);
|
||||
CuBLASReductionOption allowBF16ReductionCuBLAS() const;
|
||||
void setAllowBF16ReductionCuBLAS(
|
||||
bool allow_reduced_precision,
|
||||
bool allow_splitk = true);
|
||||
bool allowFP16AccumulationCuBLAS() const;
|
||||
void setAllowFP16AccumulationCuBLAS(bool);
|
||||
|
||||
@ -452,8 +462,10 @@ class TORCH_API Context {
|
||||
: at::Float32MatmulPrecision::HIGHEST;
|
||||
int benchmark_limit_cudnn = 10;
|
||||
bool allow_tf32_cudnn = true;
|
||||
bool allow_fp16_reduction_cublas = true;
|
||||
bool allow_bf16_reduction_cublas = true;
|
||||
CuBLASReductionOption allow_fp16_reduction_cublas =
|
||||
CuBLASReductionOption::AllowReducedPrecisionWithSplitK;
|
||||
CuBLASReductionOption allow_bf16_reduction_cublas =
|
||||
CuBLASReductionOption::AllowReducedPrecisionWithSplitK;
|
||||
bool allow_fp16_accumulation_cublas = false;
|
||||
std::optional<int32_t> sm_carveout = std::nullopt;
|
||||
bool enabled_mkldnn = true;
|
||||
|
@ -422,18 +422,34 @@ static inline bool bgemm_internal_cublaslt(CUDABLAS_BGEMM_ARGTYPES_AND_C_DTYPE(D
|
||||
abType = CUDA_R_16F;
|
||||
cType = (std::is_same_v<C_Dtype, float>) ? CUDA_R_32F : CUDA_R_16F;
|
||||
#ifndef USE_ROCM
|
||||
if (!at::globalContext().allowFP16ReductionCuBLAS()) {
|
||||
preference.setAttribute(CUBLASLT_MATMUL_PREF_REDUCTION_SCHEME_MASK,
|
||||
CUBLASLT_REDUCTION_SCHEME_COMPUTE_TYPE | CUBLASLT_REDUCTION_SCHEME_NONE);
|
||||
auto fp16_reduction = at::globalContext().allowFP16ReductionCuBLAS();
|
||||
if (fp16_reduction !=
|
||||
at::CuBLASReductionOption::AllowReducedPrecisionWithSplitK) {
|
||||
uint32_t mask =
|
||||
fp16_reduction ==
|
||||
at::CuBLASReductionOption::DisallowReducedPrecisionAllowSplitK
|
||||
? (CUBLASLT_REDUCTION_SCHEME_COMPUTE_TYPE |
|
||||
CUBLASLT_REDUCTION_SCHEME_NONE)
|
||||
: CUBLASLT_REDUCTION_SCHEME_NONE;
|
||||
preference.setAttribute(
|
||||
CUBLASLT_MATMUL_PREF_REDUCTION_SCHEME_MASK, mask);
|
||||
}
|
||||
#endif
|
||||
} else if constexpr (std::is_same_v<Dtype, at::BFloat16>) {
|
||||
abType = CUDA_R_16BF;
|
||||
cType = (std::is_same_v<C_Dtype, float>) ? CUDA_R_32F : CUDA_R_16BF;
|
||||
#ifndef USE_ROCM
|
||||
if (!at::globalContext().allowBF16ReductionCuBLAS()) {
|
||||
preference.setAttribute(CUBLASLT_MATMUL_PREF_REDUCTION_SCHEME_MASK,
|
||||
CUBLASLT_REDUCTION_SCHEME_COMPUTE_TYPE | CUBLASLT_REDUCTION_SCHEME_NONE);
|
||||
auto bf16_reduction = at::globalContext().allowBF16ReductionCuBLAS();
|
||||
if (bf16_reduction !=
|
||||
at::CuBLASReductionOption::AllowReducedPrecisionWithSplitK) {
|
||||
uint32_t mask =
|
||||
bf16_reduction ==
|
||||
at::CuBLASReductionOption::DisallowReducedPrecisionAllowSplitK
|
||||
? (CUBLASLT_REDUCTION_SCHEME_COMPUTE_TYPE |
|
||||
CUBLASLT_REDUCTION_SCHEME_NONE)
|
||||
: CUBLASLT_REDUCTION_SCHEME_NONE;
|
||||
preference.setAttribute(
|
||||
CUBLASLT_MATMUL_PREF_REDUCTION_SCHEME_MASK, mask);
|
||||
}
|
||||
#endif
|
||||
} else {
|
||||
@ -1120,8 +1136,15 @@ inline void gemm_internal_cublas_half_helper(CUDABLAS_GEMM_ARGTYPES_AND_C_DTYPE(
|
||||
}
|
||||
if (prop->major >= 5) {
|
||||
cublasMath_t cublas_flags = CUBLAS_DEFAULT_MATH;
|
||||
if (!at::globalContext().allowFP16ReductionCuBLAS()) {
|
||||
cublas_flags = static_cast<cublasMath_t>(cublas_flags | CUBLAS_MATH_DISALLOW_REDUCED_PRECISION_REDUCTION);
|
||||
auto fp16_reduction = at::globalContext().allowFP16ReductionCuBLAS();
|
||||
TORCH_CHECK(fp16_reduction !=
|
||||
at::CuBLASReductionOption::DisallowReducedPrecisionDisallowSplitK,
|
||||
"torch.backends.cuda.matmul.allow_fp16_reduced_precision_reduction("
|
||||
"..., allow_splitk=False) requires the cuBLASLt backend");
|
||||
if (fp16_reduction !=
|
||||
at::CuBLASReductionOption::AllowReducedPrecisionWithSplitK) {
|
||||
cublas_flags = static_cast<cublasMath_t>(
|
||||
cublas_flags | CUBLAS_MATH_DISALLOW_REDUCED_PRECISION_REDUCTION);
|
||||
}
|
||||
// Disallow fp16 reductions that could lead to unexpected overflow issues.
|
||||
TORCH_CUDABLAS_CHECK(cublasSetMathMode(handle, cublas_flags));
|
||||
@ -1180,8 +1203,15 @@ inline void gemm_internal_cublas_bfloat16_helper(CUDABLAS_GEMM_ARGTYPES_AND_C_DT
|
||||
GEMM_CHECK_ARGVALUES(at::BFloat16);
|
||||
#ifndef USE_ROCM
|
||||
cublasMath_t cublas_flags = CUBLAS_DEFAULT_MATH;
|
||||
if (!at::globalContext().allowBF16ReductionCuBLAS()) {
|
||||
cublas_flags = static_cast<cublasMath_t>(cublas_flags | CUBLAS_MATH_DISALLOW_REDUCED_PRECISION_REDUCTION);
|
||||
auto bf16_reduction = at::globalContext().allowBF16ReductionCuBLAS();
|
||||
TORCH_CHECK(bf16_reduction !=
|
||||
at::CuBLASReductionOption::DisallowReducedPrecisionDisallowSplitK,
|
||||
"torch.backends.cuda.matmul.allow_fp16_reduced_precision_reduction("
|
||||
"..., allow_splitk=False) requires the cuBLASLt backend");
|
||||
if (bf16_reduction !=
|
||||
at::CuBLASReductionOption::AllowReducedPrecisionWithSplitK) {
|
||||
cublas_flags = static_cast<cublasMath_t>(
|
||||
cublas_flags | CUBLAS_MATH_DISALLOW_REDUCED_PRECISION_REDUCTION);
|
||||
}
|
||||
#endif
|
||||
#if defined(USE_ROCM)
|
||||
@ -1577,18 +1607,34 @@ bool gemm_and_bias(
|
||||
abType = CUDA_R_16F;
|
||||
cType = (std::is_same_v<C_Dtype, float>) ? CUDA_R_32F : CUDA_R_16F;
|
||||
#ifndef USE_ROCM
|
||||
if (!at::globalContext().allowFP16ReductionCuBLAS()) {
|
||||
preference.setAttribute(CUBLASLT_MATMUL_PREF_REDUCTION_SCHEME_MASK,
|
||||
CUBLASLT_REDUCTION_SCHEME_COMPUTE_TYPE | CUBLASLT_REDUCTION_SCHEME_NONE);
|
||||
auto fp16_reduction = at::globalContext().allowFP16ReductionCuBLAS();
|
||||
if (fp16_reduction !=
|
||||
at::CuBLASReductionOption::AllowReducedPrecisionWithSplitK) {
|
||||
uint32_t mask =
|
||||
fp16_reduction ==
|
||||
at::CuBLASReductionOption::DisallowReducedPrecisionAllowSplitK
|
||||
? (CUBLASLT_REDUCTION_SCHEME_COMPUTE_TYPE |
|
||||
CUBLASLT_REDUCTION_SCHEME_NONE)
|
||||
: CUBLASLT_REDUCTION_SCHEME_NONE;
|
||||
preference.setAttribute(
|
||||
CUBLASLT_MATMUL_PREF_REDUCTION_SCHEME_MASK, mask);
|
||||
}
|
||||
#endif
|
||||
} else if constexpr (std::is_same_v<Dtype, at::BFloat16>) {
|
||||
abType = CUDA_R_16BF;
|
||||
cType = (std::is_same_v<C_Dtype, float>) ? CUDA_R_32F : CUDA_R_16BF;
|
||||
#ifndef USE_ROCM
|
||||
if (!at::globalContext().allowBF16ReductionCuBLAS()) {
|
||||
preference.setAttribute(CUBLASLT_MATMUL_PREF_REDUCTION_SCHEME_MASK,
|
||||
CUBLASLT_REDUCTION_SCHEME_COMPUTE_TYPE | CUBLASLT_REDUCTION_SCHEME_NONE);
|
||||
auto bf16_reduction = at::globalContext().allowBF16ReductionCuBLAS();
|
||||
if (bf16_reduction !=
|
||||
at::CuBLASReductionOption::AllowReducedPrecisionWithSplitK) {
|
||||
uint32_t mask =
|
||||
bf16_reduction ==
|
||||
at::CuBLASReductionOption::DisallowReducedPrecisionAllowSplitK
|
||||
? (CUBLASLT_REDUCTION_SCHEME_COMPUTE_TYPE |
|
||||
CUBLASLT_REDUCTION_SCHEME_NONE)
|
||||
: CUBLASLT_REDUCTION_SCHEME_NONE;
|
||||
preference.setAttribute(
|
||||
CUBLASLT_MATMUL_PREF_REDUCTION_SCHEME_MASK, mask);
|
||||
}
|
||||
#endif
|
||||
}
|
||||
|
@ -61,12 +61,16 @@ These backends include:
|
||||
.. attribute:: allow_fp16_reduced_precision_reduction
|
||||
|
||||
A :class:`bool` that controls whether reduced precision reductions (e.g., with fp16 accumulation type) are allowed with fp16 GEMMs.
|
||||
Assigning a tuple ``(allow_reduced_precision, allow_splitk)`` lets you also toggle whether
|
||||
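split-K heuristics may be used when dispatching to cuBLASLt. ``allow_splitk`` defaults to ``True``. The combination ``(True, False)`` is rejected: split-K can only be disabled when reduced precision reductions are disabled. Reading this attribute still returns only the :class:`bool`; the split-K flag can be read from ``allow_fp16_reduced_precision_reduction_split_k``.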
|
||||
```
|
||||
|
||||
```{eval-rst}
|
||||
.. attribute:: allow_bf16_reduced_precision_reduction
|
||||
|
||||
A :class:`bool` that controls whether reduced precision reductions are allowed with bf16 GEMMs.
|
||||
Assigning a tuple ``(allow_reduced_precision, allow_splitk)`` lets you also toggle whether
|
||||
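split-K heuristics may be used when dispatching to cuBLASLt. ``allow_splitk`` defaults to ``True``. The combination ``(True, False)`` is rejected: split-K can only be disabled when reduced precision reductions are disabled. Reading this attribute still returns only the :class:`bool`; the split-K flag can be read from ``allow_bf16_reduced_precision_reduction_split_k``.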
|
||||
```
|
||||
|
||||
```{eval-rst}
|
||||
|
@ -749,25 +749,67 @@ print(t.is_pinned())
|
||||
|
||||
def test_cublas_allow_fp16_reduced_precision_reduction_get_set(self):
|
||||
orig = torch.backends.cuda.matmul.allow_fp16_reduced_precision_reduction
|
||||
orig_splitk = (
|
||||
torch.backends.cuda.matmul.allow_fp16_reduced_precision_reduction_split_k
|
||||
)
|
||||
self.assertEqual(
|
||||
torch._C._get_cublas_allow_fp16_reduced_precision_reduction(), orig
|
||||
torch._C._get_cublas_allow_fp16_reduced_precision_reduction(),
|
||||
(orig, orig_splitk),
|
||||
)
|
||||
torch.backends.cuda.matmul.allow_fp16_reduced_precision_reduction = not orig
|
||||
self.assertEqual(
|
||||
torch._C._get_cublas_allow_fp16_reduced_precision_reduction(), not orig
|
||||
torch._C._get_cublas_allow_fp16_reduced_precision_reduction(),
|
||||
(not orig, True),
|
||||
)
|
||||
torch.backends.cuda.matmul.allow_fp16_reduced_precision_reduction = (
|
||||
False,
|
||||
False,
|
||||
)
|
||||
self.assertEqual(
|
||||
torch._C._get_cublas_allow_fp16_reduced_precision_reduction(),
|
||||
(False, False),
|
||||
)
|
||||
with self.assertRaisesRegex(RuntimeError, "allow_splitk=False"):
|
||||
torch.backends.cuda.matmul.allow_fp16_reduced_precision_reduction = (
|
||||
True,
|
||||
False,
|
||||
)
|
||||
torch.backends.cuda.matmul.allow_fp16_reduced_precision_reduction = (
|
||||
orig,
|
||||
orig_splitk,
|
||||
)
|
||||
torch.backends.cuda.matmul.allow_fp16_reduced_precision_reduction = orig
|
||||
|
||||
def test_cublas_allow_bf16_reduced_precision_reduction_get_set(self):
|
||||
orig = torch.backends.cuda.matmul.allow_bf16_reduced_precision_reduction
|
||||
orig_splitk = (
|
||||
torch.backends.cuda.matmul.allow_bf16_reduced_precision_reduction_split_k
|
||||
)
|
||||
self.assertEqual(
|
||||
torch._C._get_cublas_allow_bf16_reduced_precision_reduction(), orig
|
||||
torch._C._get_cublas_allow_bf16_reduced_precision_reduction(),
|
||||
(orig, orig_splitk),
|
||||
)
|
||||
torch.backends.cuda.matmul.allow_bf16_reduced_precision_reduction = not orig
|
||||
self.assertEqual(
|
||||
torch._C._get_cublas_allow_bf16_reduced_precision_reduction(), not orig
|
||||
torch._C._get_cublas_allow_bf16_reduced_precision_reduction(),
|
||||
(not orig, True),
|
||||
)
|
||||
torch.backends.cuda.matmul.allow_bf16_reduced_precision_reduction = (
|
||||
False,
|
||||
False,
|
||||
)
|
||||
self.assertEqual(
|
||||
torch._C._get_cublas_allow_bf16_reduced_precision_reduction(),
|
||||
(False, False),
|
||||
)
|
||||
with self.assertRaisesRegex(RuntimeError, "allow_splitk=False"):
|
||||
torch.backends.cuda.matmul.allow_bf16_reduced_precision_reduction = (
|
||||
True,
|
||||
False,
|
||||
)
|
||||
torch.backends.cuda.matmul.allow_bf16_reduced_precision_reduction = (
|
||||
orig,
|
||||
orig_splitk,
|
||||
)
|
||||
torch.backends.cuda.matmul.allow_bf16_reduced_precision_reduction = orig
|
||||
|
||||
def test_cublas_allow_fp16_accumulation_get_set(self):
|
||||
orig = torch.backends.cuda.matmul.allow_fp16_accumulation
|
||||
|
@ -1254,17 +1254,19 @@ def _get_float32_matmul_precision() -> str: ... # THPModule_float32MatmulPrecis
|
||||
def _set_float32_matmul_precision(
|
||||
arg: str,
|
||||
) -> None: ... # THPModule_setFloat32MatmulPrecision
|
||||
def _get_cublas_allow_fp16_reduced_precision_reduction() -> (
|
||||
_bool
|
||||
): ... # THPModule_allowFP16ReductionCuBLAS
|
||||
def _get_cublas_allow_fp16_reduced_precision_reduction() -> tuple[
|
||||
_bool, _bool
|
||||
]: ... # THPModule_allowFP16ReductionCuBLAS
|
||||
def _set_cublas_allow_fp16_reduced_precision_reduction(
|
||||
arg: _bool,
|
||||
allow_splitk: _bool = ...,
|
||||
) -> None: ... # THPModule_setAllowFP16ReductionCuBLAS
|
||||
def _get_cublas_allow_bf16_reduced_precision_reduction() -> (
|
||||
_bool
|
||||
): ... # THPModule_allowBF16ReductionCuBLAS
|
||||
def _get_cublas_allow_bf16_reduced_precision_reduction() -> tuple[
|
||||
_bool, _bool
|
||||
]: ... # THPModule_allowBF16ReductionCuBLAS
|
||||
def _set_cublas_allow_bf16_reduced_precision_reduction(
|
||||
arg: _bool,
|
||||
allow_splitk: _bool = ...,
|
||||
) -> None: ... # THPModule_setAllowBF16ReductionCuBLAS
|
||||
def _get_cublas_allow_fp16_accumulation() -> (
|
||||
_bool
|
||||
|
@ -44,7 +44,18 @@ if TYPE_CHECKING:
|
||||
Node = torch.fx.Node
|
||||
Region = list[Node]
|
||||
IdenticalNodes = list[Node]
|
||||
GlobalStateKey = tuple[bool, bool, int, bool, bool, torch.dtype, bool, bool, bool, bool]
|
||||
GlobalStateKey = tuple[
|
||||
bool,
|
||||
bool,
|
||||
int,
|
||||
tuple[bool, bool],
|
||||
tuple[bool, bool],
|
||||
torch.dtype,
|
||||
bool,
|
||||
bool,
|
||||
bool,
|
||||
bool,
|
||||
]
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
graph_expansion_log = torch._logging.getArtifactLogger(
|
||||
|
@ -1,6 +1,6 @@
|
||||
# mypy: allow-untyped-defs
|
||||
import contextlib
|
||||
from typing import Union
|
||||
from typing import Any, Union
|
||||
from typing_extensions import deprecated
|
||||
|
||||
import torch
|
||||
@ -126,13 +126,54 @@ class cuFFTPlanCacheManager:
|
||||
|
||||
|
||||
class cuBLASModule:
|
||||
@staticmethod
|
||||
def _parse_reduction_setting(value: Any, attr_name: str) -> tuple[bool, bool]:
|
||||
def _ensure_bool(obj: Any, which: str) -> bool:
|
||||
if isinstance(obj, bool):
|
||||
return obj
|
||||
raise TypeError(
|
||||
f"{attr_name} expects a bool for {which}, but got {type(obj)!r}"
|
||||
)
|
||||
|
||||
if isinstance(value, bool):
|
||||
return value, True
|
||||
if isinstance(value, (list, tuple)):
|
||||
if not value:
|
||||
raise TypeError(f"{attr_name} expects at least one boolean argument")
|
||||
if len(value) > 2:
|
||||
raise TypeError(f"{attr_name} expects at most two boolean arguments")
|
||||
allow_reduced_precision = _ensure_bool(value[0], "allow_reduced_precision")
|
||||
if len(value) == 1:
|
||||
return allow_reduced_precision, True
|
||||
allow_splitk = _ensure_bool(value[1], "allow_splitk")
|
||||
return allow_reduced_precision, allow_splitk
|
||||
raise TypeError(
|
||||
f"{attr_name} expects a bool or a tuple/list of bools, but got {type(value)!r}"
|
||||
)
|
||||
|
||||
def __getattr__(self, name):
|
||||
if name == "allow_tf32":
|
||||
return torch._C._get_cublas_allow_tf32()
|
||||
elif name == "allow_fp16_reduced_precision_reduction":
|
||||
return torch._C._get_cublas_allow_fp16_reduced_precision_reduction()
|
||||
allow_reduced_precision, _ = (
|
||||
torch._C._get_cublas_allow_fp16_reduced_precision_reduction()
|
||||
)
|
||||
return allow_reduced_precision
|
||||
elif name == "allow_fp16_reduced_precision_reduction_split_k":
|
||||
_, allow_splitk = (
|
||||
torch._C._get_cublas_allow_fp16_reduced_precision_reduction()
|
||||
)
|
||||
return allow_splitk
|
||||
elif name == "allow_bf16_reduced_precision_reduction":
|
||||
return torch._C._get_cublas_allow_bf16_reduced_precision_reduction()
|
||||
allow_reduced_precision, _ = (
|
||||
torch._C._get_cublas_allow_bf16_reduced_precision_reduction()
|
||||
)
|
||||
return allow_reduced_precision
|
||||
elif name == "allow_bf16_reduced_precision_reduction_split_k":
|
||||
_, allow_splitk = (
|
||||
torch._C._get_cublas_allow_bf16_reduced_precision_reduction()
|
||||
)
|
||||
return allow_splitk
|
||||
elif name == "allow_fp16_accumulation":
|
||||
return torch._C._get_cublas_allow_fp16_accumulation()
|
||||
elif name == "fp32_precision":
|
||||
@ -143,9 +184,19 @@ class cuBLASModule:
|
||||
if name == "allow_tf32":
|
||||
return torch._C._set_cublas_allow_tf32(value)
|
||||
elif name == "allow_fp16_reduced_precision_reduction":
|
||||
return torch._C._set_cublas_allow_fp16_reduced_precision_reduction(value)
|
||||
allow_reduced_precision, allow_splitk = self._parse_reduction_setting(
|
||||
value, "allow_fp16_reduced_precision_reduction"
|
||||
)
|
||||
return torch._C._set_cublas_allow_fp16_reduced_precision_reduction(
|
||||
allow_reduced_precision, allow_splitk
|
||||
)
|
||||
elif name == "allow_bf16_reduced_precision_reduction":
|
||||
return torch._C._set_cublas_allow_bf16_reduced_precision_reduction(value)
|
||||
allow_reduced_precision, allow_splitk = self._parse_reduction_setting(
|
||||
value, "allow_bf16_reduced_precision_reduction"
|
||||
)
|
||||
return torch._C._set_cublas_allow_bf16_reduced_precision_reduction(
|
||||
allow_reduced_precision, allow_splitk
|
||||
)
|
||||
elif name == "allow_fp16_accumulation":
|
||||
return torch._C._set_cublas_allow_fp16_accumulation(value)
|
||||
elif name == "fp32_precision":
|
||||
|
@ -1224,14 +1224,30 @@ static PyObject* THPModule_allowTF32CuBLAS(
|
||||
|
||||
static PyObject* THPModule_setAllowFP16ReductionCuBLAS(
|
||||
PyObject* _unused,
|
||||
PyObject* arg) {
|
||||
PyObject* args) {
|
||||
HANDLE_TH_ERRORS
|
||||
PyObject* allow_reduction_obj = nullptr;
|
||||
PyObject* allow_splitk_obj = Py_None;
|
||||
if (!PyArg_ParseTuple(args, "O|O", &allow_reduction_obj, &allow_splitk_obj)) {
|
||||
return nullptr;
|
||||
}
|
||||
TORCH_CHECK(
|
||||
PyBool_Check(arg),
|
||||
"set_allow_fp16_reduction_cublas expects a bool, "
|
||||
PyBool_Check(allow_reduction_obj),
|
||||
"set_allow_fp16_reduction_cublas expects a bool for allow_reduced_precision, "
|
||||
"but got ",
|
||||
THPUtils_typename(arg));
|
||||
at::globalContext().setAllowFP16ReductionCuBLAS(arg == Py_True);
|
||||
THPUtils_typename(allow_reduction_obj));
|
||||
bool allow_reduction = allow_reduction_obj == Py_True;
|
||||
bool allow_splitk = true;
|
||||
if (allow_splitk_obj != Py_None) {
|
||||
TORCH_CHECK(
|
||||
PyBool_Check(allow_splitk_obj),
|
||||
"set_allow_fp16_reduction_cublas expects a bool for allow_splitk, "
|
||||
"but got ",
|
||||
THPUtils_typename(allow_splitk_obj));
|
||||
allow_splitk = allow_splitk_obj == Py_True;
|
||||
}
|
||||
at::globalContext().setAllowFP16ReductionCuBLAS(
|
||||
allow_reduction, allow_splitk);
|
||||
Py_RETURN_NONE;
|
||||
END_HANDLE_TH_ERRORS
|
||||
}
|
||||
@ -1239,22 +1255,43 @@ static PyObject* THPModule_setAllowFP16ReductionCuBLAS(
|
||||
static PyObject* THPModule_allowFP16ReductionCuBLAS(
|
||||
PyObject* _unused,
|
||||
PyObject* noargs) {
|
||||
if (at::globalContext().allowFP16ReductionCuBLAS()) {
|
||||
Py_RETURN_TRUE;
|
||||
}
|
||||
Py_RETURN_FALSE;
|
||||
auto option = at::globalContext().allowFP16ReductionCuBLAS();
|
||||
bool allow_reduced_precision =
|
||||
option == at::CuBLASReductionOption::AllowReducedPrecisionWithSplitK;
|
||||
bool allow_splitk = option !=
|
||||
at::CuBLASReductionOption::DisallowReducedPrecisionDisallowSplitK;
|
||||
return PyTuple_Pack(
|
||||
2,
|
||||
allow_reduced_precision ? Py_True : Py_False,
|
||||
allow_splitk ? Py_True : Py_False);
|
||||
}
|
||||
|
||||
static PyObject* THPModule_setAllowBF16ReductionCuBLAS(
|
||||
PyObject* _unused,
|
||||
PyObject* arg) {
|
||||
PyObject* args) {
|
||||
HANDLE_TH_ERRORS
|
||||
PyObject* allow_reduction_obj = nullptr;
|
||||
PyObject* allow_splitk_obj = Py_None;
|
||||
if (!PyArg_ParseTuple(args, "O|O", &allow_reduction_obj, &allow_splitk_obj)) {
|
||||
return nullptr;
|
||||
}
|
||||
TORCH_CHECK(
|
||||
PyBool_Check(arg),
|
||||
"set_allow_bf16_reduction_cublas expects a bool, "
|
||||
PyBool_Check(allow_reduction_obj),
|
||||
"set_allow_bf16_reduction_cublas expects a bool for allow_reduced_precision, "
|
||||
"but got ",
|
||||
THPUtils_typename(arg));
|
||||
at::globalContext().setAllowBF16ReductionCuBLAS(arg == Py_True);
|
||||
THPUtils_typename(allow_reduction_obj));
|
||||
bool allow_reduction = allow_reduction_obj == Py_True;
|
||||
bool allow_splitk = true;
|
||||
if (allow_splitk_obj != Py_None) {
|
||||
TORCH_CHECK(
|
||||
PyBool_Check(allow_splitk_obj),
|
||||
"set_allow_bf16_reduction_cublas expects a bool for allow_splitk, "
|
||||
"but got ",
|
||||
THPUtils_typename(allow_splitk_obj));
|
||||
allow_splitk = allow_splitk_obj == Py_True;
|
||||
}
|
||||
at::globalContext().setAllowBF16ReductionCuBLAS(
|
||||
allow_reduction, allow_splitk);
|
||||
Py_RETURN_NONE;
|
||||
END_HANDLE_TH_ERRORS
|
||||
}
|
||||
@ -1262,10 +1299,15 @@ static PyObject* THPModule_setAllowBF16ReductionCuBLAS(
|
||||
static PyObject* THPModule_allowBF16ReductionCuBLAS(
|
||||
PyObject* _unused,
|
||||
PyObject* noargs) {
|
||||
if (at::globalContext().allowBF16ReductionCuBLAS()) {
|
||||
Py_RETURN_TRUE;
|
||||
}
|
||||
Py_RETURN_FALSE;
|
||||
auto option = at::globalContext().allowBF16ReductionCuBLAS();
|
||||
bool allow_reduced_precision =
|
||||
option == at::CuBLASReductionOption::AllowReducedPrecisionWithSplitK;
|
||||
bool allow_splitk = option !=
|
||||
at::CuBLASReductionOption::DisallowReducedPrecisionDisallowSplitK;
|
||||
return PyTuple_Pack(
|
||||
2,
|
||||
allow_reduced_precision ? Py_True : Py_False,
|
||||
allow_splitk ? Py_True : Py_False);
|
||||
}
|
||||
|
||||
static PyObject* THPModule_setAllowFP16AccumulationCuBLAS(
|
||||
@ -1736,7 +1778,7 @@ static std::initializer_list<PyMethodDef> TorchMethods = {
|
||||
nullptr},
|
||||
{"_set_cublas_allow_fp16_reduced_precision_reduction",
|
||||
THPModule_setAllowFP16ReductionCuBLAS,
|
||||
METH_O,
|
||||
METH_VARARGS,
|
||||
nullptr},
|
||||
{"_get_cublas_allow_bf16_reduced_precision_reduction",
|
||||
THPModule_allowBF16ReductionCuBLAS,
|
||||
@ -1744,7 +1786,7 @@ static std::initializer_list<PyMethodDef> TorchMethods = {
|
||||
nullptr},
|
||||
{"_set_cublas_allow_bf16_reduced_precision_reduction",
|
||||
THPModule_setAllowBF16ReductionCuBLAS,
|
||||
METH_O,
|
||||
METH_VARARGS,
|
||||
nullptr},
|
||||
{"_get_cublas_allow_fp16_accumulation",
|
||||
THPModule_allowFP16AccumulationCuBLAS,
|
||||
|
@ -19,6 +19,7 @@
|
||||
#include <torch/csrc/utils/python_symnode.h>
|
||||
#include <torch/csrc/utils/pythoncapi_compat.h>
|
||||
#include <torch/extension.h>
|
||||
#include <cstdint>
|
||||
|
||||
#include <torch/csrc/dynamo/debug_macros.h>
|
||||
|
||||
@ -703,8 +704,10 @@ struct GlobalStateGuard {
|
||||
json_j["deterministic_algorithms_warn_only"] =
|
||||
json_t._deterministic_algorithms_warn_only;
|
||||
json_j["allow_tf32"] = json_t._allow_tf32;
|
||||
json_j["allow_fp16_reduce"] = json_t._allow_fp16_reduce;
|
||||
json_j["allow_bf16_reduce"] = json_t._allow_bf16_reduce;
|
||||
json_j["allow_fp16_reduce"] =
|
||||
static_cast<int64_t>(json_t._allow_fp16_reduce);
|
||||
json_j["allow_bf16_reduce"] =
|
||||
static_cast<int64_t>(json_t._allow_bf16_reduce);
|
||||
json_j["num_threads"] = json_t._num_threads;
|
||||
json_j["default_dtype"] = json_t._default_dtype.toScalarType();
|
||||
}
|
||||
@ -720,8 +723,10 @@ struct GlobalStateGuard {
|
||||
json_t._deterministic_algorithms_warn_only =
|
||||
json_j.at("deterministic_algorithms_warn_only");
|
||||
json_t._allow_tf32 = json_j.at("allow_tf32");
|
||||
json_t._allow_fp16_reduce = json_j.at("allow_fp16_reduce");
|
||||
json_t._allow_bf16_reduce = json_j.at("allow_bf16_reduce");
|
||||
json_t._allow_fp16_reduce = static_cast<at::CuBLASReductionOption>(
|
||||
static_cast<int64_t>(json_j.at("allow_fp16_reduce")));
|
||||
json_t._allow_bf16_reduce = static_cast<at::CuBLASReductionOption>(
|
||||
static_cast<int64_t>(json_j.at("allow_bf16_reduce")));
|
||||
json_t._num_threads = json_j.at("num_threads");
|
||||
json_t._default_dtype =
|
||||
caffe2::TypeMeta::fromScalarType(json_j.at("default_dtype"));
|
||||
@ -734,8 +739,8 @@ struct GlobalStateGuard {
|
||||
bool _deterministic_algorithms;
|
||||
bool _deterministic_algorithms_warn_only;
|
||||
bool _allow_tf32;
|
||||
bool _allow_fp16_reduce;
|
||||
bool _allow_bf16_reduce;
|
||||
at::CuBLASReductionOption _allow_fp16_reduce;
|
||||
at::CuBLASReductionOption _allow_bf16_reduce;
|
||||
int _num_threads;
|
||||
caffe2::TypeMeta _default_dtype;
|
||||
// TODO(jansel): we should guard on more state as inductor starts using it
|
||||
|
Reference in New Issue
Block a user