pytorch/torch/backends/miopen/__init__.py
Jeff Daily 9b29166f57 [ROCm] add flag torch.backends.miopen.immediate (#158951)
The MIOpen integration has changed over the years. In the past, the MIOpen default for benchmark was True, and setting it to False made MIOpen use Immediate Mode. But with #145294 the MIOpen benchmark default changed to False, and to activate immediate mode you had to set the deterministic flag to True. This has proved too restrictive because the benchmark and deterministic flags are independent of immediate mode. Thus, immediate mode needs its own flag. Though MIOpen still masquerades behind torch.backends.cudnn and its flags, it seemed inappropriate to add an MIOpen-exclusive flag to the set of cudnn flags. This PR adds the first MIOpen-only flag, torch.backends.miopen.immediate, to control its immediate mode.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/158951
Approved by: https://github.com/jeffdaily

Co-authored-by: Jeff Daily <jeff.daily@amd.com>
2025-07-25 04:01:51 +00:00


# mypy: allow-untyped-defs
import sys
from contextlib import contextmanager

import torch
from torch.backends import __allow_nonbracketed_mutation, ContextProp, PropModule


def set_flags(
    _immediate=None,
):
    orig_flags = (torch._C._get_miopen_immediate(),)
    if _immediate is not None:
        torch._C._set_miopen_immediate(_immediate)
    return orig_flags


@contextmanager
def flags(
    immediate=False,
):
    with __allow_nonbracketed_mutation():
        orig_flags = set_flags(
            immediate,
        )
    try:
        yield
    finally:
        # recover the previous values
        with __allow_nonbracketed_mutation():
            set_flags(*orig_flags)


# The magic here is to allow us to intercept code like this:
#
#   torch.backends.miopen.immediate = True


class MiopenModule(PropModule):
    def __init__(self, m, name):
        super().__init__(m, name)

    immediate = ContextProp(
        torch._C._get_miopen_immediate, torch._C._set_miopen_immediate
    )


# This is the sys.modules replacement trick, see
# https://stackoverflow.com/questions/2447353/getattr-on-a-module/7668273#7668273
sys.modules[__name__] = MiopenModule(sys.modules[__name__], __name__)

# Add type annotation for the replaced module
immediate: bool