mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
[inductor] Simplify mode options, only apply CompilerBisector changes once (#145232)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/145232 Approved by: https://github.com/yanboliang
This commit is contained in:
committed by
PyTorch MergeBot
parent
85811631d7
commit
505ade7471
@ -2279,10 +2279,13 @@ class _TorchCompileInductorWrapper:
|
||||
compiler_name = "inductor"
|
||||
|
||||
def __init__(self, mode, options, dynamic):
|
||||
from torch._inductor.compiler_bisector import CompilerBisector
|
||||
|
||||
self.config: _Dict[str, _Any] = {}
|
||||
self.dynamic = dynamic
|
||||
self.apply_mode(mode)
|
||||
self.apply_options(options)
|
||||
self.apply_options(CompilerBisector.get_config_change("inductor"))
|
||||
|
||||
if self.config.get("triton.cudagraphs", False):
|
||||
os.environ["DISABLE_CUPTI_LAZY_REINIT"] = "1"
|
||||
@ -2300,26 +2303,12 @@ class _TorchCompileInductorWrapper:
|
||||
)
|
||||
|
||||
def apply_mode(self, mode: _Optional[str]):
|
||||
if mode is None or mode == "default":
|
||||
pass
|
||||
elif mode in {"reduce-overhead", "max-autotune", "max-autotune-no-cudagraphs"}:
|
||||
if mode and mode != "default":
|
||||
from torch._inductor import list_mode_options
|
||||
|
||||
self.apply_options(list_mode_options(mode, self.dynamic))
|
||||
else:
|
||||
raise RuntimeError(
|
||||
f"Unrecognized mode={mode}, should be one of: default, reduce-overhead, max-autotune, max-autotune-no-cudagraphs"
|
||||
)
|
||||
|
||||
def apply_options(self, options: _Optional[_Dict[str, _Any]]):
|
||||
from torch._inductor.compiler_bisector import CompilerBisector
|
||||
|
||||
if bisect_changes := CompilerBisector.get_config_change("inductor"):
|
||||
options = {} if options is None else options
|
||||
options = (
|
||||
{**bisect_changes} if options is None else {**options, **bisect_changes} # type: ignore[dict-item]
|
||||
)
|
||||
|
||||
if not options:
|
||||
return
|
||||
|
||||
|
@ -264,7 +264,12 @@ def list_mode_options(
|
||||
"coordinate_descent_tuning": True,
|
||||
},
|
||||
}
|
||||
return mode_options[mode] if mode else mode_options # type: ignore[return-value]
|
||||
try:
|
||||
return mode_options[mode] if mode else mode_options
|
||||
except KeyError as e:
|
||||
raise RuntimeError(
|
||||
f"Unrecognized mode={mode}, should be one of: {', '.join(mode_options.keys())}"
|
||||
) from e
|
||||
|
||||
|
||||
def list_options() -> list[str]:
|
||||
|
Reference in New Issue
Block a user