From 37c6087334cce3ad4bc9838ea2ef63aba89f2253 Mon Sep 17 00:00:00 2001
From: Natalia Gimelshein
Date: Wed, 8 Oct 2025 18:48:42 +0000
Subject: [PATCH] Add split-K control to cuBLAS reduced-precision settings
 (#164766)

## Summary
- Add a `CuBLASReductionOption` enum so the CUDA context can track the reduced-precision and split-K settings together.
- Extend the Python bindings, backend helpers, and docs so the fp16/bf16 matmul controls accept an optional `allow_splitk` argument.
- Update the cuBLAS/cuBLASLt call sites plus dynamo guards and tests to respect the new combinations.

## Testing
- `python test/test_cuda.py TestCuda.test_cublas_allow_fp16_reduced_precision_reduction_get_set -v` *(fails: `ModuleNotFoundError: No module named 'psutil'`)*
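## Example

A minimal sketch of the new Python-facing behavior (illustrative only; it mirrors what the updated docs describe and what the new tests in `test/test_cuda.py` assert):

```python
import torch

# A plain bool keeps the old behavior; allow_splitk defaults to True.
torch.backends.cuda.matmul.allow_fp16_reduced_precision_reduction = True

# A (allow_reduced_precision, allow_splitk) tuple also constrains split-K:
# here, forbid both reduced-precision accumulation and split-K (cuBLASLt only).
torch.backends.cuda.matmul.allow_fp16_reduced_precision_reduction = (False, False)

# The split-K half of the setting is readable through a companion attribute.
print(torch.backends.cuda.matmul.allow_fp16_reduced_precision_reduction_split_k)

# (True, False) is rejected: split-K cannot be disabled while reduced-precision
# reductions are still allowed.
try:
    torch.backends.cuda.matmul.allow_fp16_reduced_precision_reduction = (True, False)
except RuntimeError as e:
    print(e)
```

The bf16 knobs (`allow_bf16_reduced_precision_reduction` and its `_split_k` companion) behave identically.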
------
https://chatgpt.com/codex/tasks/task_e_68e404623178832f8a3e1d34e1e175da

Pull Request resolved: https://github.com/pytorch/pytorch/pull/164766
Approved by: https://github.com/malfet, https://github.com/albanD
---
 aten/src/ATen/Context.cpp             | 25 ++++++--
 aten/src/ATen/Context.h               | 24 ++++++--
 aten/src/ATen/cuda/CUDABlas.cpp       | 78 +++++++++++++++++------
 docs/source/backends.md               |  4 ++
 test/test_cuda.py                     | 54 ++++++++++++++++--
 torch/_C/__init__.pyi.in              | 14 +++--
 torch/_dynamo/graph_region_tracker.py | 13 ++++-
 torch/backends/cuda/__init__.py       | 61 ++++++++++++++++--
 torch/csrc/Module.cpp                 | 82 ++++++++++++++++++------
 torch/csrc/dynamo/guards.cpp          | 17 ++++--
 10 files changed, 300 insertions(+), 72 deletions(-)

diff --git a/aten/src/ATen/Context.cpp b/aten/src/ATen/Context.cpp
index 4418bb0d67d7..3310abfb41d5 100644
--- a/aten/src/ATen/Context.cpp
+++ b/aten/src/ATen/Context.cpp
@@ -587,20 +587,33 @@ void Context::setROCmFAPreferredBackend(at::ROCmFABackend b) {
   rocm_fa_preferred_backend = b;
 }
 
-bool Context::allowFP16ReductionCuBLAS() const {
+CuBLASReductionOption Context::allowFP16ReductionCuBLAS() const {
   return allow_fp16_reduction_cublas;
 }
 
-void Context::setAllowFP16ReductionCuBLAS(bool b) {
-  allow_fp16_reduction_cublas = b;
+CuBLASReductionOption inline get_reduction_option(bool allow_reduced_precision, bool allow_splitk) {
+  TORCH_CHECK(
+      !(allow_reduced_precision && !allow_splitk),
+      "allow_splitk=False is not supported when reduced precision reductions are enabled");
+  if (allow_reduced_precision) {
+    return CuBLASReductionOption::AllowReducedPrecisionWithSplitK;
+  } else if (allow_splitk) {
+    return CuBLASReductionOption::DisallowReducedPrecisionAllowSplitK;
+  } else {
+    return CuBLASReductionOption::DisallowReducedPrecisionDisallowSplitK;
+  }
 }
 
-bool Context::allowBF16ReductionCuBLAS() const {
+void Context::setAllowFP16ReductionCuBLAS(bool allow_reduced_precision, bool allow_splitk) {
+  allow_fp16_reduction_cublas = get_reduction_option(allow_reduced_precision, allow_splitk);
+}
+
+CuBLASReductionOption Context::allowBF16ReductionCuBLAS() const {
   return allow_bf16_reduction_cublas;
 }
 
-void Context::setAllowBF16ReductionCuBLAS(bool b) {
-  allow_bf16_reduction_cublas = b;
+void Context::setAllowBF16ReductionCuBLAS(bool allow_reduced_precision, bool allow_splitk) {
+  allow_bf16_reduction_cublas = get_reduction_option(allow_reduced_precision, allow_splitk);
 }
 
 bool Context::allowFP16AccumulationCuBLAS() const {
diff --git a/aten/src/ATen/Context.h b/aten/src/ATen/Context.h
index 42e92ab7284a..4055083cfcb2 100644
--- a/aten/src/ATen/Context.h
+++ b/aten/src/ATen/Context.h
@@ -38,6 +38,12 @@ namespace at {
 class Tensor;
 
 enum class TORCH_API Float32MatmulPrecision { HIGHEST, HIGH, MEDIUM };
+
+enum class CuBLASReductionOption : uint8_t {
+  AllowReducedPrecisionWithSplitK = 0,
+  DisallowReducedPrecisionAllowSplitK = 1,
+  DisallowReducedPrecisionDisallowSplitK = 2,
+};
 enum class TORCH_API Float32Backend { GENERIC, CUDA, MKLDNN };
 enum class TORCH_API Float32Op { ALL, CONV, RNN, MATMUL };
 enum class TORCH_API Float32Precision { NONE, IEEE, TF32, BF16 };
@@ -357,10 +363,14 @@ class TORCH_API Context {
   void setAllowTF32CuBLAS(bool);
   Float32MatmulPrecision float32MatmulPrecision() const;
   Float32Precision float32Precision(Float32Backend backend, Float32Op op) const;
-  bool allowFP16ReductionCuBLAS() const;
-  void setAllowFP16ReductionCuBLAS(bool);
-  bool allowBF16ReductionCuBLAS() const;
-  void setAllowBF16ReductionCuBLAS(bool);
+  CuBLASReductionOption allowFP16ReductionCuBLAS() const;
+  void setAllowFP16ReductionCuBLAS(
+      bool allow_reduced_precision,
+      bool allow_splitk = true);
+  CuBLASReductionOption allowBF16ReductionCuBLAS() const;
+  void setAllowBF16ReductionCuBLAS(
+      bool allow_reduced_precision,
+      bool allow_splitk = true);
   bool allowFP16AccumulationCuBLAS() const;
   void setAllowFP16AccumulationCuBLAS(bool);
 
@@ -452,8 +462,10 @@ class TORCH_API Context {
           : at::Float32MatmulPrecision::HIGHEST;
   int benchmark_limit_cudnn = 10;
   bool allow_tf32_cudnn = true;
-  bool allow_fp16_reduction_cublas = true;
-  bool allow_bf16_reduction_cublas = true;
+  CuBLASReductionOption allow_fp16_reduction_cublas =
+      CuBLASReductionOption::AllowReducedPrecisionWithSplitK;
+  CuBLASReductionOption allow_bf16_reduction_cublas =
+      CuBLASReductionOption::AllowReducedPrecisionWithSplitK;
   bool allow_fp16_accumulation_cublas = false;
   std::optional<int32_t> sm_carveout = std::nullopt;
   bool enabled_mkldnn = true;
diff --git a/aten/src/ATen/cuda/CUDABlas.cpp b/aten/src/ATen/cuda/CUDABlas.cpp
index 2d3fad27cd90..7484ec0f8863 100644
--- a/aten/src/ATen/cuda/CUDABlas.cpp
+++ b/aten/src/ATen/cuda/CUDABlas.cpp
@@ -422,18 +422,34 @@ static inline bool bgemm_internal_cublaslt(CUDABLAS_BGEMM_ARGTYPES_AND_C_DTYPE(Dtype, C_Dtype)) {
     abType = CUDA_R_16F;
     cType = (std::is_same_v<C_Dtype, float>) ? CUDA_R_32F : CUDA_R_16F;
 #ifndef USE_ROCM
-    if (!at::globalContext().allowFP16ReductionCuBLAS()) {
-      preference.setAttribute(CUBLASLT_MATMUL_PREF_REDUCTION_SCHEME_MASK,
-        CUBLASLT_REDUCTION_SCHEME_COMPUTE_TYPE | CUBLASLT_REDUCTION_SCHEME_NONE);
+    auto fp16_reduction = at::globalContext().allowFP16ReductionCuBLAS();
+    if (fp16_reduction !=
+        at::CuBLASReductionOption::AllowReducedPrecisionWithSplitK) {
+      uint32_t mask =
+          fp16_reduction ==
+                  at::CuBLASReductionOption::DisallowReducedPrecisionAllowSplitK
+              ? (CUBLASLT_REDUCTION_SCHEME_COMPUTE_TYPE |
+                 CUBLASLT_REDUCTION_SCHEME_NONE)
+              : CUBLASLT_REDUCTION_SCHEME_NONE;
+      preference.setAttribute(
+          CUBLASLT_MATMUL_PREF_REDUCTION_SCHEME_MASK, mask);
     }
 #endif
   } else if constexpr (std::is_same_v<Dtype, at::BFloat16>) {
     abType = CUDA_R_16BF;
     cType = (std::is_same_v<C_Dtype, float>) ? CUDA_R_32F : CUDA_R_16BF;
 #ifndef USE_ROCM
-    if (!at::globalContext().allowBF16ReductionCuBLAS()) {
-      preference.setAttribute(CUBLASLT_MATMUL_PREF_REDUCTION_SCHEME_MASK,
-        CUBLASLT_REDUCTION_SCHEME_COMPUTE_TYPE | CUBLASLT_REDUCTION_SCHEME_NONE);
+    auto bf16_reduction = at::globalContext().allowBF16ReductionCuBLAS();
+    if (bf16_reduction !=
+        at::CuBLASReductionOption::AllowReducedPrecisionWithSplitK) {
+      uint32_t mask =
+          bf16_reduction ==
+                  at::CuBLASReductionOption::DisallowReducedPrecisionAllowSplitK
+              ? (CUBLASLT_REDUCTION_SCHEME_COMPUTE_TYPE |
+                 CUBLASLT_REDUCTION_SCHEME_NONE)
+              : CUBLASLT_REDUCTION_SCHEME_NONE;
+      preference.setAttribute(
+          CUBLASLT_MATMUL_PREF_REDUCTION_SCHEME_MASK, mask);
     }
 #endif
   } else {
@@ -1120,8 +1136,15 @@ inline void gemm_internal_cublas_half_helper(CUDABLAS_GEMM_ARGTYPES_AND_C_DTYPE(at::Half, C_Dtype)) {
   }
   if (prop->major >= 5) {
     cublasMath_t cublas_flags = CUBLAS_DEFAULT_MATH;
-    if (!at::globalContext().allowFP16ReductionCuBLAS()) {
-      cublas_flags = static_cast<cublasMath_t>(cublas_flags | CUBLAS_MATH_DISALLOW_REDUCED_PRECISION_REDUCTION);
+    auto fp16_reduction = at::globalContext().allowFP16ReductionCuBLAS();
+    TORCH_CHECK(fp16_reduction !=
+            at::CuBLASReductionOption::DisallowReducedPrecisionDisallowSplitK,
+        "torch.backends.cuda.matmul.allow_fp16_reduced_precision_reduction("
+        "..., allow_splitk=False) requires the cuBLASLt backend");
+    if (fp16_reduction !=
+        at::CuBLASReductionOption::AllowReducedPrecisionWithSplitK) {
+      cublas_flags = static_cast<cublasMath_t>(
+          cublas_flags | CUBLAS_MATH_DISALLOW_REDUCED_PRECISION_REDUCTION);
     }
     // Disallow fp16 reductions that could lead to unexpected overflow issues.
     TORCH_CUDABLAS_CHECK(cublasSetMathMode(handle, cublas_flags));
@@ -1180,8 +1203,15 @@ inline void gemm_internal_cublas_bfloat16_helper(CUDABLAS_GEMM_ARGTYPES_AND_C_DTYPE(at::BFloat16, C_Dtype)) {
   GEMM_CHECK_ARGVALUES(at::BFloat16);
 #ifndef USE_ROCM
   cublasMath_t cublas_flags = CUBLAS_DEFAULT_MATH;
-  if (!at::globalContext().allowBF16ReductionCuBLAS()) {
-    cublas_flags = static_cast<cublasMath_t>(cublas_flags | CUBLAS_MATH_DISALLOW_REDUCED_PRECISION_REDUCTION);
+  auto bf16_reduction = at::globalContext().allowBF16ReductionCuBLAS();
+  TORCH_CHECK(bf16_reduction !=
+          at::CuBLASReductionOption::DisallowReducedPrecisionDisallowSplitK,
+      "torch.backends.cuda.matmul.allow_bf16_reduced_precision_reduction("
+      "..., allow_splitk=False) requires the cuBLASLt backend");
+  if (bf16_reduction !=
+      at::CuBLASReductionOption::AllowReducedPrecisionWithSplitK) {
+    cublas_flags = static_cast<cublasMath_t>(
+        cublas_flags | CUBLAS_MATH_DISALLOW_REDUCED_PRECISION_REDUCTION);
   }
 #endif
 #if defined(USE_ROCM)
@@ -1577,18 +1607,34 @@ bool gemm_and_bias(
     abType = CUDA_R_16F;
     cType = (std::is_same_v<C_Dtype, float>) ? CUDA_R_32F : CUDA_R_16F;
 #ifndef USE_ROCM
-    if (!at::globalContext().allowFP16ReductionCuBLAS()) {
-      preference.setAttribute(CUBLASLT_MATMUL_PREF_REDUCTION_SCHEME_MASK,
-        CUBLASLT_REDUCTION_SCHEME_COMPUTE_TYPE | CUBLASLT_REDUCTION_SCHEME_NONE);
+    auto fp16_reduction = at::globalContext().allowFP16ReductionCuBLAS();
+    if (fp16_reduction !=
+        at::CuBLASReductionOption::AllowReducedPrecisionWithSplitK) {
+      uint32_t mask =
+          fp16_reduction ==
+                  at::CuBLASReductionOption::DisallowReducedPrecisionAllowSplitK
+              ? (CUBLASLT_REDUCTION_SCHEME_COMPUTE_TYPE |
+                 CUBLASLT_REDUCTION_SCHEME_NONE)
+              : CUBLASLT_REDUCTION_SCHEME_NONE;
+      preference.setAttribute(
+          CUBLASLT_MATMUL_PREF_REDUCTION_SCHEME_MASK, mask);
     }
 #endif
   } else if constexpr (std::is_same_v<Dtype, at::BFloat16>) {
     abType = CUDA_R_16BF;
     cType = (std::is_same_v<C_Dtype, float>) ? CUDA_R_32F : CUDA_R_16BF;
 #ifndef USE_ROCM
-    if (!at::globalContext().allowBF16ReductionCuBLAS()) {
-      preference.setAttribute(CUBLASLT_MATMUL_PREF_REDUCTION_SCHEME_MASK,
-        CUBLASLT_REDUCTION_SCHEME_COMPUTE_TYPE | CUBLASLT_REDUCTION_SCHEME_NONE);
+    auto bf16_reduction = at::globalContext().allowBF16ReductionCuBLAS();
+    if (bf16_reduction !=
+        at::CuBLASReductionOption::AllowReducedPrecisionWithSplitK) {
+      uint32_t mask =
+          bf16_reduction ==
+                  at::CuBLASReductionOption::DisallowReducedPrecisionAllowSplitK
+              ? (CUBLASLT_REDUCTION_SCHEME_COMPUTE_TYPE |
+                 CUBLASLT_REDUCTION_SCHEME_NONE)
+              : CUBLASLT_REDUCTION_SCHEME_NONE;
+      preference.setAttribute(
+          CUBLASLT_MATMUL_PREF_REDUCTION_SCHEME_MASK, mask);
     }
 #endif
   }
diff --git a/docs/source/backends.md b/docs/source/backends.md
index 3e6cdc9697bf..6f8791d9a608 100644
--- a/docs/source/backends.md
+++ b/docs/source/backends.md
@@ -61,12 +61,16 @@ These backends include:
 .. attribute:: allow_fp16_reduced_precision_reduction
 
     A :class:`bool` that controls whether reduced precision reductions (e.g., with fp16 accumulation type) are allowed with fp16 GEMMs.
+    Assigning a tuple ``(allow_reduced_precision, allow_splitk)`` lets you also toggle whether
+    split-K heuristics may be used when dispatching to cuBLASLt. ``allow_splitk`` defaults to ``True``.
 ```
 
 ```{eval-rst}
 .. attribute:: allow_bf16_reduced_precision_reduction
 
     A :class:`bool` that controls whether reduced precision reductions are allowed with bf16 GEMMs.
+    Assigning a tuple ``(allow_reduced_precision, allow_splitk)`` lets you also toggle whether
+    split-K heuristics may be used when dispatching to cuBLASLt. ``allow_splitk`` defaults to ``True``.
 ```
 
 ```{eval-rst}
diff --git a/test/test_cuda.py b/test/test_cuda.py
index 74cfdec2e904..8effa6ca43ef 100644
--- a/test/test_cuda.py
+++ b/test/test_cuda.py
@@ -749,25 +749,67 @@ print(t.is_pinned())
 
     def test_cublas_allow_fp16_reduced_precision_reduction_get_set(self):
         orig = torch.backends.cuda.matmul.allow_fp16_reduced_precision_reduction
+        orig_splitk = (
+            torch.backends.cuda.matmul.allow_fp16_reduced_precision_reduction_split_k
+        )
         self.assertEqual(
-            torch._C._get_cublas_allow_fp16_reduced_precision_reduction(), orig
+            torch._C._get_cublas_allow_fp16_reduced_precision_reduction(),
+            (orig, orig_splitk),
         )
         torch.backends.cuda.matmul.allow_fp16_reduced_precision_reduction = not orig
         self.assertEqual(
-            torch._C._get_cublas_allow_fp16_reduced_precision_reduction(), not orig
+            torch._C._get_cublas_allow_fp16_reduced_precision_reduction(),
+            (not orig, True),
+        )
+        torch.backends.cuda.matmul.allow_fp16_reduced_precision_reduction = (
+            False,
+            False,
+        )
+        self.assertEqual(
+            torch._C._get_cublas_allow_fp16_reduced_precision_reduction(),
+            (False, False),
+        )
+        with self.assertRaisesRegex(RuntimeError, "allow_splitk=False"):
+            torch.backends.cuda.matmul.allow_fp16_reduced_precision_reduction = (
+                True,
+                False,
+            )
+        torch.backends.cuda.matmul.allow_fp16_reduced_precision_reduction = (
+            orig,
+            orig_splitk,
         )
-        torch.backends.cuda.matmul.allow_fp16_reduced_precision_reduction = orig
 
     def test_cublas_allow_bf16_reduced_precision_reduction_get_set(self):
         orig = torch.backends.cuda.matmul.allow_bf16_reduced_precision_reduction
+        orig_splitk = (
+            torch.backends.cuda.matmul.allow_bf16_reduced_precision_reduction_split_k
+        )
         self.assertEqual(
-            torch._C._get_cublas_allow_bf16_reduced_precision_reduction(), orig
+            torch._C._get_cublas_allow_bf16_reduced_precision_reduction(),
+            (orig, orig_splitk),
         )
         torch.backends.cuda.matmul.allow_bf16_reduced_precision_reduction = not orig
         self.assertEqual(
-            torch._C._get_cublas_allow_bf16_reduced_precision_reduction(), not orig
+            torch._C._get_cublas_allow_bf16_reduced_precision_reduction(),
+            (not orig, True),
+        )
+        torch.backends.cuda.matmul.allow_bf16_reduced_precision_reduction = (
+            False,
+            False,
+        )
+        self.assertEqual(
+            torch._C._get_cublas_allow_bf16_reduced_precision_reduction(),
+            (False, False),
+        )
+        with self.assertRaisesRegex(RuntimeError, "allow_splitk=False"):
+            torch.backends.cuda.matmul.allow_bf16_reduced_precision_reduction = (
+                True,
+                False,
+            )
+        torch.backends.cuda.matmul.allow_bf16_reduced_precision_reduction = (
+            orig,
+            orig_splitk,
         )
-        torch.backends.cuda.matmul.allow_bf16_reduced_precision_reduction = orig
 
     def test_cublas_allow_fp16_accumulation_get_set(self):
         orig = torch.backends.cuda.matmul.allow_fp16_accumulation
diff --git a/torch/_C/__init__.pyi.in b/torch/_C/__init__.pyi.in
index 7e29cf9fa218..a6885945e55e 100644
--- a/torch/_C/__init__.pyi.in
+++ b/torch/_C/__init__.pyi.in
@@ -1254,17 +1254,19 @@ def _get_float32_matmul_precision() -> str: ...  # THPModule_float32MatmulPrecision
 def _set_float32_matmul_precision(
     arg: str,
 ) -> None: ...  # THPModule_setFloat32MatmulPrecision
-def _get_cublas_allow_fp16_reduced_precision_reduction() -> (
-    _bool
-): ...  # THPModule_allowFP16ReductionCuBLAS
+def _get_cublas_allow_fp16_reduced_precision_reduction() -> tuple[
+    _bool, _bool
+]: ...  # THPModule_allowFP16ReductionCuBLAS
 def _set_cublas_allow_fp16_reduced_precision_reduction(
     arg: _bool,
+    allow_splitk: _bool = ...,
 ) -> None: ...  # THPModule_setAllowFP16ReductionCuBLAS
-def _get_cublas_allow_bf16_reduced_precision_reduction() -> (
-    _bool
-): ...  # THPModule_allowBF16ReductionCuBLAS
+def _get_cublas_allow_bf16_reduced_precision_reduction() -> tuple[
+    _bool, _bool
+]: ...  # THPModule_allowBF16ReductionCuBLAS
 def _set_cublas_allow_bf16_reduced_precision_reduction(
     arg: _bool,
+    allow_splitk: _bool = ...,
 ) -> None: ...  # THPModule_setAllowBF16ReductionCuBLAS
 def _get_cublas_allow_fp16_accumulation() -> (
     _bool
diff --git a/torch/_dynamo/graph_region_tracker.py b/torch/_dynamo/graph_region_tracker.py
index c16ce22a1ded..19211bd4491b 100644
--- a/torch/_dynamo/graph_region_tracker.py
+++ b/torch/_dynamo/graph_region_tracker.py
@@ -44,7 +44,18 @@ if TYPE_CHECKING:
 Node = torch.fx.Node
 Region = list[Node]
 IdenticalNodes = list[Node]
-GlobalStateKey = tuple[bool, bool, int, bool, bool, torch.dtype, bool, bool, bool, bool]
+GlobalStateKey = tuple[
+    bool,
+    bool,
+    int,
+    tuple[bool, bool],
+    tuple[bool, bool],
+    torch.dtype,
+    bool,
+    bool,
+    bool,
+    bool,
+]
 
 log = logging.getLogger(__name__)
 graph_expansion_log = torch._logging.getArtifactLogger(
diff --git a/torch/backends/cuda/__init__.py b/torch/backends/cuda/__init__.py
index 87327428461a..5f70c28bf2d2 100644
--- a/torch/backends/cuda/__init__.py
+++ b/torch/backends/cuda/__init__.py
@@ -1,6 +1,6 @@
 # mypy: allow-untyped-defs
 import contextlib
-from typing import Union
+from typing import Any, Union
 from typing_extensions import deprecated
 
 import torch
@@ -126,13 +126,54 @@ class cuFFTPlanCacheManager:
 
 
 class cuBLASModule:
+    @staticmethod
+    def _parse_reduction_setting(value: Any, attr_name: str) -> tuple[bool, bool]:
+        def _ensure_bool(obj: Any, which: str) -> bool:
+            if isinstance(obj, bool):
+                return obj
+            raise TypeError(
+                f"{attr_name} expects a bool for {which}, but got {type(obj)!r}"
+            )
+
+        if isinstance(value, bool):
+            return value, True
+        if isinstance(value, (list, tuple)):
+            if not value:
+                raise TypeError(f"{attr_name} expects at least one boolean argument")
+            if len(value) > 2:
+                raise TypeError(f"{attr_name} expects at most two boolean arguments")
+            allow_reduced_precision = _ensure_bool(value[0], "allow_reduced_precision")
+            if len(value) == 1:
+                return allow_reduced_precision, True
+            allow_splitk = _ensure_bool(value[1], "allow_splitk")
+            return allow_reduced_precision, allow_splitk
+        raise TypeError(
+            f"{attr_name} expects a bool or a tuple/list of bools, but got {type(value)!r}"
+        )
+
     def __getattr__(self, name):
         if name == "allow_tf32":
             return torch._C._get_cublas_allow_tf32()
         elif name == "allow_fp16_reduced_precision_reduction":
-            return torch._C._get_cublas_allow_fp16_reduced_precision_reduction()
+            allow_reduced_precision, _ = (
+                torch._C._get_cublas_allow_fp16_reduced_precision_reduction()
+            )
+            return allow_reduced_precision
+        elif name == "allow_fp16_reduced_precision_reduction_split_k":
+            _, allow_splitk = (
+                torch._C._get_cublas_allow_fp16_reduced_precision_reduction()
+            )
+            return allow_splitk
         elif name == "allow_bf16_reduced_precision_reduction":
-            return torch._C._get_cublas_allow_bf16_reduced_precision_reduction()
+            allow_reduced_precision, _ = (
+                torch._C._get_cublas_allow_bf16_reduced_precision_reduction()
+            )
+            return allow_reduced_precision
+        elif name == "allow_bf16_reduced_precision_reduction_split_k":
+            _, allow_splitk = (
+                torch._C._get_cublas_allow_bf16_reduced_precision_reduction()
+            )
+            return allow_splitk
         elif name == "allow_fp16_accumulation":
             return torch._C._get_cublas_allow_fp16_accumulation()
         elif name == "fp32_precision":
@@ -143,9 +184,19 @@ class cuBLASModule:
         if name == "allow_tf32":
             return torch._C._set_cublas_allow_tf32(value)
         elif name == "allow_fp16_reduced_precision_reduction":
-            return torch._C._set_cublas_allow_fp16_reduced_precision_reduction(value)
+            allow_reduced_precision, allow_splitk = self._parse_reduction_setting(
+                value, "allow_fp16_reduced_precision_reduction"
+            )
+            return torch._C._set_cublas_allow_fp16_reduced_precision_reduction(
+                allow_reduced_precision, allow_splitk
+            )
         elif name == "allow_bf16_reduced_precision_reduction":
-            return torch._C._set_cublas_allow_bf16_reduced_precision_reduction(value)
+            allow_reduced_precision, allow_splitk = self._parse_reduction_setting(
+                value, "allow_bf16_reduced_precision_reduction"
+            )
+            return torch._C._set_cublas_allow_bf16_reduced_precision_reduction(
+                allow_reduced_precision, allow_splitk
+            )
         elif name == "allow_fp16_accumulation":
             return torch._C._set_cublas_allow_fp16_accumulation(value)
         elif name == "fp32_precision":
diff --git a/torch/csrc/Module.cpp b/torch/csrc/Module.cpp
index 567c1264fa14..e2e5da301f2b 100644
--- a/torch/csrc/Module.cpp
+++ b/torch/csrc/Module.cpp
@@ -1224,14 +1224,30 @@ static PyObject* THPModule_allowTF32CuBLAS(
 static PyObject* THPModule_setAllowFP16ReductionCuBLAS(
     PyObject* _unused,
-    PyObject* arg) {
+    PyObject* args) {
   HANDLE_TH_ERRORS
+  PyObject* allow_reduction_obj = nullptr;
+  PyObject* allow_splitk_obj = Py_None;
+  if (!PyArg_ParseTuple(args, "O|O", &allow_reduction_obj, &allow_splitk_obj)) {
+    return nullptr;
+  }
   TORCH_CHECK(
-      PyBool_Check(arg),
-      "set_allow_fp16_reduction_cublas expects a bool, "
+      PyBool_Check(allow_reduction_obj),
+      "set_allow_fp16_reduction_cublas expects a bool for allow_reduced_precision, "
       "but got ",
-      THPUtils_typename(arg));
-  at::globalContext().setAllowFP16ReductionCuBLAS(arg == Py_True);
+      THPUtils_typename(allow_reduction_obj));
+  bool allow_reduction = allow_reduction_obj == Py_True;
+  bool allow_splitk = true;
+  if (allow_splitk_obj != Py_None) {
+    TORCH_CHECK(
+        PyBool_Check(allow_splitk_obj),
+        "set_allow_fp16_reduction_cublas expects a bool for allow_splitk, "
+        "but got ",
+        THPUtils_typename(allow_splitk_obj));
+    allow_splitk = allow_splitk_obj == Py_True;
+  }
+  at::globalContext().setAllowFP16ReductionCuBLAS(
+      allow_reduction, allow_splitk);
   Py_RETURN_NONE;
   END_HANDLE_TH_ERRORS
 }
@@ -1239,22 +1255,43 @@ static PyObject* THPModule_setAllowFP16ReductionCuBLAS(
 static PyObject* THPModule_allowFP16ReductionCuBLAS(
     PyObject* _unused,
     PyObject* noargs) {
-  if (at::globalContext().allowFP16ReductionCuBLAS()) {
-    Py_RETURN_TRUE;
-  }
-  Py_RETURN_FALSE;
+  auto option = at::globalContext().allowFP16ReductionCuBLAS();
+  bool allow_reduced_precision =
+      option == at::CuBLASReductionOption::AllowReducedPrecisionWithSplitK;
+  bool allow_splitk = option !=
+      at::CuBLASReductionOption::DisallowReducedPrecisionDisallowSplitK;
+  return PyTuple_Pack(
+      2,
+      allow_reduced_precision ? Py_True : Py_False,
+      allow_splitk ? Py_True : Py_False);
 }
 
 static PyObject* THPModule_setAllowBF16ReductionCuBLAS(
     PyObject* _unused,
-    PyObject* arg) {
+    PyObject* args) {
   HANDLE_TH_ERRORS
+  PyObject* allow_reduction_obj = nullptr;
+  PyObject* allow_splitk_obj = Py_None;
+  if (!PyArg_ParseTuple(args, "O|O", &allow_reduction_obj, &allow_splitk_obj)) {
+    return nullptr;
+  }
   TORCH_CHECK(
-      PyBool_Check(arg),
-      "set_allow_bf16_reduction_cublas expects a bool, "
+      PyBool_Check(allow_reduction_obj),
+      "set_allow_bf16_reduction_cublas expects a bool for allow_reduced_precision, "
       "but got ",
-      THPUtils_typename(arg));
-  at::globalContext().setAllowBF16ReductionCuBLAS(arg == Py_True);
+      THPUtils_typename(allow_reduction_obj));
+  bool allow_reduction = allow_reduction_obj == Py_True;
+  bool allow_splitk = true;
+  if (allow_splitk_obj != Py_None) {
+    TORCH_CHECK(
+        PyBool_Check(allow_splitk_obj),
+        "set_allow_bf16_reduction_cublas expects a bool for allow_splitk, "
+        "but got ",
+        THPUtils_typename(allow_splitk_obj));
+    allow_splitk = allow_splitk_obj == Py_True;
+  }
+  at::globalContext().setAllowBF16ReductionCuBLAS(
+      allow_reduction, allow_splitk);
   Py_RETURN_NONE;
   END_HANDLE_TH_ERRORS
 }
@@ -1262,10 +1299,15 @@ static PyObject* THPModule_setAllowBF16ReductionCuBLAS(
 static PyObject* THPModule_allowBF16ReductionCuBLAS(
     PyObject* _unused,
     PyObject* noargs) {
-  if (at::globalContext().allowBF16ReductionCuBLAS()) {
-    Py_RETURN_TRUE;
-  }
-  Py_RETURN_FALSE;
+  auto option = at::globalContext().allowBF16ReductionCuBLAS();
+  bool allow_reduced_precision =
+      option == at::CuBLASReductionOption::AllowReducedPrecisionWithSplitK;
+  bool allow_splitk = option !=
+      at::CuBLASReductionOption::DisallowReducedPrecisionDisallowSplitK;
+  return PyTuple_Pack(
+      2,
+      allow_reduced_precision ? Py_True : Py_False,
+      allow_splitk ? Py_True : Py_False);
 }
 
 static PyObject* THPModule_setAllowFP16AccumulationCuBLAS(
@@ -1736,7 +1778,7 @@ static std::initializer_list<PyMethodDef> TorchMethods = {
      nullptr},
     {"_set_cublas_allow_fp16_reduced_precision_reduction",
      THPModule_setAllowFP16ReductionCuBLAS,
-     METH_O,
+     METH_VARARGS,
      nullptr},
     {"_get_cublas_allow_bf16_reduced_precision_reduction",
      THPModule_allowBF16ReductionCuBLAS,
@@ -1744,7 +1786,7 @@ static std::initializer_list<PyMethodDef> TorchMethods = {
      nullptr},
     {"_set_cublas_allow_bf16_reduced_precision_reduction",
      THPModule_setAllowBF16ReductionCuBLAS,
-     METH_O,
+     METH_VARARGS,
      nullptr},
    {"_get_cublas_allow_fp16_accumulation",
      THPModule_allowFP16AccumulationCuBLAS,
diff --git a/torch/csrc/dynamo/guards.cpp b/torch/csrc/dynamo/guards.cpp
index 1b38c0b0acb8..bdcaf71c05d5 100644
--- a/torch/csrc/dynamo/guards.cpp
+++ b/torch/csrc/dynamo/guards.cpp
@@ -19,6 +19,7 @@
 #include
 #include
 #include
+#include
 #include
 
@@ -703,8 +704,10 @@ struct GlobalStateGuard {
     json_j["deterministic_algorithms_warn_only"] =
         json_t._deterministic_algorithms_warn_only;
     json_j["allow_tf32"] = json_t._allow_tf32;
-    json_j["allow_fp16_reduce"] = json_t._allow_fp16_reduce;
-    json_j["allow_bf16_reduce"] = json_t._allow_bf16_reduce;
+    json_j["allow_fp16_reduce"] =
+        static_cast<uint8_t>(json_t._allow_fp16_reduce);
+    json_j["allow_bf16_reduce"] =
+        static_cast<uint8_t>(json_t._allow_bf16_reduce);
     json_j["num_threads"] = json_t._num_threads;
     json_j["default_dtype"] = json_t._default_dtype.toScalarType();
   }
@@ -720,8 +723,10 @@ struct GlobalStateGuard {
     json_t._deterministic_algorithms_warn_only =
         json_j.at("deterministic_algorithms_warn_only");
     json_t._allow_tf32 = json_j.at("allow_tf32");
-    json_t._allow_fp16_reduce = json_j.at("allow_fp16_reduce");
-    json_t._allow_bf16_reduce = json_j.at("allow_bf16_reduce");
+    json_t._allow_fp16_reduce = static_cast<at::CuBLASReductionOption>(
+        static_cast<uint8_t>(json_j.at("allow_fp16_reduce")));
+    json_t._allow_bf16_reduce = static_cast<at::CuBLASReductionOption>(
+        static_cast<uint8_t>(json_j.at("allow_bf16_reduce")));
     json_t._num_threads = json_j.at("num_threads");
     json_t._default_dtype =
         caffe2::TypeMeta::fromScalarType(json_j.at("default_dtype"));
@@ -734,8 +739,8 @@ struct GlobalStateGuard {
   bool _deterministic_algorithms;
   bool _deterministic_algorithms_warn_only;
   bool _allow_tf32;
-  bool _allow_fp16_reduce;
-  bool _allow_bf16_reduce;
+  at::CuBLASReductionOption _allow_fp16_reduce;
+  at::CuBLASReductionOption _allow_bf16_reduce;
   int _num_threads;
   caffe2::TypeMeta _default_dtype;
   // TODO(jansel): we should guard on more state as inductor starts using it