mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
The MIOpen integration has changed over the years. In the past, the MIOpen default for benchmark was True, and setting it to False made MIOpen use Immediate Mode. But with #145294 the MIOpen benchmark default changed to False, and to activate immediate mode you had to set the deterministic flag to True. This has proved too restrictive because the benchmark and deterministic flags are independent of immediate mode. Thus, immediate mode needs its own flag. Though MIOpen still masquerades behind torch.backends.cudnn and its flags, it seemed inappropriate to add a MIOpen-exclusive flag to the set of cudnn flags. This PR adds the first MIOpen-only flag to control its immediate mode. Pull Request resolved: https://github.com/pytorch/pytorch/pull/158951 Approved by: https://github.com/jeffdaily Co-authored-by: Jeff Daily <jeff.daily@amd.com>
54 lines
1.2 KiB
Python
54 lines
1.2 KiB
Python
# mypy: allow-untyped-defs
|
|
import sys
|
|
from contextlib import contextmanager
|
|
|
|
import torch
|
|
from torch.backends import __allow_nonbracketed_mutation, ContextProp, PropModule
|
|
|
|
|
|
def set_flags(
|
|
_immediate=None,
|
|
):
|
|
orig_flags = (torch._C._get_miopen_immediate(),)
|
|
if _immediate is not None:
|
|
torch._C._set_miopen_immediate(_immediate)
|
|
return orig_flags
|
|
|
|
|
|
@contextmanager
def flags(
    immediate=False,
):
    """Temporarily set the MIOpen immediate-mode flag for a ``with`` block.

    On entry the flag is set to ``immediate`` via ``set_flags``. On exit the
    value that was active on entry is written back, whether the block finished
    normally or raised.

    Args:
        immediate: value to apply inside the block. Defaults to ``False``, so a
            bare ``with flags():`` turns immediate mode off for its duration.
    """
    # Flag changes are wrapped in torch.backends' __allow_nonbracketed_mutation()
    # — presumably so they are permitted even when global flags are frozen.
    with __allow_nonbracketed_mutation():
        (entry_immediate,) = set_flags(_immediate=immediate)
    try:
        yield
    finally:
        # Restore the entry value even if the body raised.
        with __allow_nonbracketed_mutation():
            set_flags(_immediate=entry_immediate)
|
|
|
|
|
|
# The magic here is to allow us to intercept code like this:
|
|
#
|
|
#   torch.backends.miopen.immediate = True
|
|
|
|
|
|
class MiopenModule(PropModule):
|
|
def __init__(self, m, name):
|
|
super().__init__(m, name)
|
|
|
|
immediate = ContextProp(
|
|
torch._C._get_miopen_immediate, torch._C._set_miopen_immediate
|
|
)
|
|
|
|
|
|
# This is the sys.modules replacement trick, see
# https://stackoverflow.com/questions/2447353/getattr-on-a-module/7668273#7668273
# The entry for this module in sys.modules is swapped for a MiopenModule that
# wraps the original module object. Subsequent imports/lookups of this module
# name therefore get the wrapper, whose ``immediate`` attribute is a ContextProp
# backed by torch._C. (Forwarding of other attributes to the wrapped module is
# presumably handled by PropModule — not visible in this file.)
# Must run after MiopenModule is defined; code below this line still executes
# in the original module's namespace.
sys.modules[__name__] = MiopenModule(sys.modules[__name__], __name__)
|
|
|
|
# Add type annotation for the replaced module.
# This is a bare annotation with no assignment: it binds no runtime value. It
# only tells static type checkers that this module's ``immediate`` attribute is
# a bool. At runtime the attribute is served by MiopenModule.immediate on the
# object installed into sys.modules above.
immediate: bool
|