[ROCm] add flag torch.backends.miopen.immediate (#158951)

The MIOpen integration has changed over the years. In the past, the MIOpen benchmark default was True, and setting it to False would make MIOpen use Immediate Mode. With #145294 the MIOpen benchmark default changed to False, and Immediate Mode was instead activated by setting the deterministic flag to True. This has proved too restrictive, because the benchmark and deterministic flags are independent of Immediate Mode; thus, Immediate Mode needs its own flag. Though MIOpen still masquerades behind torch.backends.cudnn and its flags, it seemed inappropriate to add an MIOpen-exclusive flag to the set of cudnn flags. This PR adds the first MIOpen-only flag, which controls Immediate Mode.
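As an illustration only (not part of the commit), here is a minimal sketch of how the new flag is intended to be used on a ROCm build, where the cudnn flags continue to steer MIOpen; the `torch.backends.miopen.immediate` attribute is the one added by this PR, everything else is ordinary flag handling:

```python
import torch

# The three settings are now independent: benchmark and deterministic keep
# their existing meanings, while Immediate Mode is opted into explicitly.
torch.backends.cudnn.benchmark = False
torch.backends.cudnn.deterministic = False
torch.backends.miopen.immediate = True  # new flag introduced by this PR

print(torch.backends.miopen.immediate)  # True (defaults to False)
```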

Pull Request resolved: https://github.com/pytorch/pytorch/pull/158951
Approved by: https://github.com/jeffdaily

Co-authored-by: Jeff Daily <jeff.daily@amd.com>
Jeff Daily
2025-07-25 04:01:51 +00:00
committed by PyTorch MergeBot
parent 1fced0c7d5
commit 9b29166f57
9 changed files with 112 additions and 12 deletions

View File

@@ -334,6 +334,14 @@ void Context::setBenchmarkLimitCuDNN(int b) {
benchmark_limit_cudnn = b;
}
bool Context::immediateMiopen() const {
return immediate_miopen;
}
void Context::setImmediateMiopen(bool b) {
immediate_miopen = b;
}
bool Context::allowTF32CuBLAS() const {
#ifdef USE_ROCM
const auto allow_tf32 = c10::utils::check_env(hipblaslt_allow_tf32);

View File

@@ -205,6 +205,8 @@ class TORCH_API Context {
void setBenchmarkCuDNN(bool);
int benchmarkLimitCuDNN() const;
void setBenchmarkLimitCuDNN(int);
bool immediateMiopen() const;
void setImmediateMiopen(bool);
bool deterministicCuDNN() const;
void setDeterministicCuDNN(bool);
bool deterministicMkldnn() const;
@@ -440,6 +442,7 @@ class TORCH_API Context {
bool enabled_overrideable = true;
bool allow_fp16_bf16_reduction_mathSDP = false;
bool benchmark_cudnn = false;
bool immediate_miopen = false;
Float32MatmulPrecision float32_matmul_precision =
c10::utils::check_env("TORCH_ALLOW_TF32_CUBLAS_OVERRIDE") == true
? at::Float32MatmulPrecision::HIGH

View File

@@ -724,8 +724,7 @@ void raw_miopen_convolution_forward_out(
args.odesc.set(output);
args.cdesc.set(dataType, c_mode, input.dim() - 2, args.params.padding, args.params.stride, args.params.dilation, args.params.groups, benchmark, deterministic);
if (deterministic && !benchmark) {
// immediate mode is triggered for the specific combination of benchmark=off deterministic=on
if (at::globalContext().immediateMiopen()) {
uint64_t solution_id;
Workspace workspace = chooseSolution<miopenConvFwdAlgorithm_t>(args, &solution_id);
@@ -833,8 +832,7 @@ void raw_miopen_depthwise_convolution_forward_out(
args.odesc.set(output);
args.cdesc.set(dataType, c_mode, input.dim() - 2, args.params.padding, args.params.stride, args.params.dilation, args.params.groups, benchmark, deterministic);
if (deterministic && !benchmark) {
// immediate mode is triggered for the specific combination of benchmark=off deterministic=on
if (at::globalContext().immediateMiopen()) {
uint64_t solution_id;
Workspace workspace = chooseSolution<miopenConvFwdAlgorithm_t>(args, &solution_id);
@@ -989,8 +987,7 @@ void raw_miopen_convolution_backward_weight_out(
args.odesc.set(grad_output);
args.cdesc.set(dataType, c_mode, input.dim() - 2, args.params.padding, args.params.stride, args.params.dilation, args.params.groups, benchmark, deterministic);
if (deterministic && !benchmark) {
// immediate mode is triggered for the specific combination of benchmark=off deterministic=on
if (at::globalContext().immediateMiopen()) {
uint64_t solution_id;
Workspace workspace = chooseSolution<miopenConvBwdWeightsAlgorithm_t>(args, &solution_id);
@@ -1034,8 +1031,7 @@ void raw_miopen_depthwise_convolution_backward_weight_out(
args.odesc.set(grad_output);
args.cdesc.set(dataType, c_mode, input.dim() - 2, args.params.padding, args.params.stride, args.params.dilation, args.params.groups, benchmark, deterministic);
if (deterministic && !benchmark) {
// immediate mode is triggered for the specific combination of benchmark=off deterministic=on
if (at::globalContext().immediateMiopen()) {
uint64_t solution_id;
Workspace workspace = chooseSolution<miopenConvBwdWeightsAlgorithm_t>(args, &solution_id);
@@ -1240,8 +1236,7 @@ void raw_miopen_convolution_backward_input_out(
args.odesc.set(grad_output);
args.cdesc.set(dataType, c_mode, grad_output.dim() - 2, args.params.padding, args.params.stride, args.params.dilation, args.params.groups, benchmark, deterministic);
if (deterministic && !benchmark) {
// immediate mode is triggered for the specific combination of benchmark=off deterministic=on
if (at::globalContext().immediateMiopen()) {
uint64_t solution_id;
Workspace workspace = chooseSolution<miopenConvBwdDataAlgorithm_t>(args, &solution_id);
@@ -1350,8 +1345,7 @@ void raw_miopen_depthwise_convolution_backward_input_out(
args.odesc.set(grad_output);
args.cdesc.set(dataType, c_mode, grad_output.dim() - 2, args.params.padding, args.params.stride, args.params.dilation, args.params.groups, benchmark, deterministic);
if (deterministic && !benchmark) {
// immediate mode is triggered for the specific combination of benchmark=off deterministic=on
if (at::globalContext().immediateMiopen()) {
uint64_t solution_id;
Workspace workspace = chooseSolution<miopenConvBwdDataAlgorithm_t>(args, &solution_id);
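To make the gating change above concrete, here is a hedged sketch (again not part of the diff) of a convolution that would exercise these raw_miopen_* paths on a ROCm build; the shapes and the "cuda" device string (which maps to HIP on ROCm) are illustrative:

```python
import torch
import torch.nn.functional as F

# With the flag set, the forward convolution picks its solution via
# chooseSolution (MIOpen Immediate Mode) instead of the benchmark/find path,
# regardless of the benchmark and deterministic settings.
torch.backends.miopen.immediate = True

x = torch.randn(8, 3, 32, 32, device="cuda")
w = torch.randn(16, 3, 3, 3, device="cuda")
y = F.conv2d(x, w, padding=1)
```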

View File

@@ -253,6 +253,19 @@ These backends include:
```
## torch.backends.miopen
```{eval-rst}
.. automodule:: torch.backends.miopen
```
```{eval-rst}
.. attribute:: immediate

    A :class:`bool` that, if True, causes MIOpen to use Immediate Mode
    (https://rocm.docs.amd.com/projects/MIOpen/en/latest/how-to/find-and-immediate.html).
```
## torch.backends.mps
```{eval-rst}

View File

@@ -1213,6 +1213,8 @@ def _get_mkldnn_enabled() -> _bool: ... # THPModule_userEnabledMkldnn
def _set_mkldnn_enabled(arg: _bool) -> None: ... # THPModule_setUserEnabledMkldnn
def _get_cudnn_benchmark() -> _bool: ... # THPModule_benchmarkCuDNN
def _set_cudnn_benchmark(arg: _bool) -> None: ... # THPModule_setBenchmarkCuDNN
def _get_miopen_immediate() -> _bool: ... # THPModule_userImmediateMiopen
def _set_miopen_immediate(arg: _bool) -> None: ... # THPModule_setUserImmediateMiopen
def _get_cudnn_deterministic() -> _bool: ... # THPModule_deterministicCuDNN
def _set_cudnn_deterministic(arg: _bool) -> None: ... # THPModule_setDeterministicCuDNN
def _get_mkldnn_deterministic() -> _bool: ... # THPModule_deterministicMkldnn

View File

@@ -659,6 +659,7 @@ torch_c_binding_in_graph_functions = dict.fromkeys(
"torch._C._get_cublas_allow_tf32",
"torch._C._get_cudnn_allow_tf32",
"torch._C._get_cudnn_benchmark",
"torch._C._get_miopen_immediate",
"torch._C._get_cudnn_deterministic",
"torch._C._get_cudnn_enabled",
"torch._C._get_custom_class_python_wrapper",

View File

@@ -131,6 +131,7 @@ from torch.backends import (
cusparselt as cusparselt,
kleidiai as kleidiai,
mha as mha,
miopen as miopen,
mkl as mkl,
mkldnn as mkldnn,
mps as mps,

View File

@@ -0,0 +1,53 @@
# mypy: allow-untyped-defs
import sys
from contextlib import contextmanager

import torch
from torch.backends import __allow_nonbracketed_mutation, ContextProp, PropModule


def set_flags(
    _immediate=None,
):
    orig_flags = (torch._C._get_miopen_immediate(),)
    if _immediate is not None:
        torch._C._set_miopen_immediate(_immediate)
    return orig_flags


@contextmanager
def flags(
    immediate=False,
):
    with __allow_nonbracketed_mutation():
        orig_flags = set_flags(
            immediate,
        )
    try:
        yield
    finally:
        # recover the previous values
        with __allow_nonbracketed_mutation():
            set_flags(*orig_flags)


# The magic here is to allow us to intercept code like this:
#
# torch.backends.<miopen|mkldnn>.immediate = True
class MiopenModule(PropModule):
    def __init__(self, m, name):
        super().__init__(m, name)

    immediate = ContextProp(
        torch._C._get_miopen_immediate, torch._C._set_miopen_immediate
    )


# This is the sys.modules replacement trick, see
# https://stackoverflow.com/questions/2447353/getattr-on-a-module/7668273#7668273
sys.modules[__name__] = MiopenModule(sys.modules[__name__], __name__)

# Add type annotation for the replaced module
immediate: bool
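As a usage note rather than part of the diff: attribute assignments such as `torch.backends.miopen.immediate = True` are intercepted by the `ContextProp` descriptor above, and the `flags()` context manager gives a scoped way to toggle the setting. A small sketch, assuming the flag starts at its default of False:

```python
import torch

assert not torch.backends.miopen.immediate  # default is False

# Temporarily enable Immediate Mode; set_flags() records the previous value
# and the finally block restores it on exit.
with torch.backends.miopen.flags(immediate=True):
    assert torch.backends.miopen.immediate

assert not torch.backends.miopen.immediate  # restored after the context exits
```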

View File

@@ -1172,6 +1172,29 @@ static PyObject* THPModule_benchmarkCuDNN(PyObject* _unused, PyObject* noargs) {
Py_RETURN_FALSE;
}
static PyObject* THPModule_setImmediateMiopen(
PyObject* _unused,
PyObject* arg) {
HANDLE_TH_ERRORS
TORCH_CHECK(
PyBool_Check(arg),
"set_immediate_miopen expects a bool, "
"but got ",
THPUtils_typename(arg));
at::globalContext().setImmediateMiopen(arg == Py_True);
Py_RETURN_NONE;
END_HANDLE_TH_ERRORS
}
static PyObject* THPModule_immediateMiopen(
PyObject* _unused,
PyObject* noargs) {
if (at::globalContext().immediateMiopen()) {
Py_RETURN_TRUE;
}
Py_RETURN_FALSE;
}
static PyObject* THPModule_setAllowTF32CuBLAS(
PyObject* _unused,
PyObject* arg) {
@@ -1642,6 +1665,8 @@ static std::initializer_list<PyMethodDef> TorchMethods = {
{"_set_onednn_allow_tf32", THPModule_setAllowTF32OneDNN, METH_O, nullptr},
{"_get_cudnn_benchmark", THPModule_benchmarkCuDNN, METH_NOARGS, nullptr},
{"_set_cudnn_benchmark", THPModule_setBenchmarkCuDNN, METH_O, nullptr},
{"_get_miopen_immediate", THPModule_immediateMiopen, METH_NOARGS, nullptr},
{"_set_miopen_immediate", THPModule_setImmediateMiopen, METH_O, nullptr},
{"_get_cudnn_deterministic",
THPModule_deterministicCuDNN,
METH_NOARGS,