[inductor] Simplify mode options, only apply CompilerBisector changes once (#145232)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/145232
Approved by: https://github.com/yanboliang
This commit is contained in:
Jason Ansel
2025-01-20 16:33:16 -08:00
committed by PyTorch MergeBot
parent 85811631d7
commit 505ade7471
2 changed files with 10 additions and 16 deletions

View File

@ -2279,10 +2279,13 @@ class _TorchCompileInductorWrapper:
compiler_name = "inductor"
def __init__(self, mode, options, dynamic):
from torch._inductor.compiler_bisector import CompilerBisector
self.config: _Dict[str, _Any] = {}
self.dynamic = dynamic
self.apply_mode(mode)
self.apply_options(options)
self.apply_options(CompilerBisector.get_config_change("inductor"))
if self.config.get("triton.cudagraphs", False):
os.environ["DISABLE_CUPTI_LAZY_REINIT"] = "1"
@ -2300,26 +2303,12 @@ class _TorchCompileInductorWrapper:
)
def apply_mode(self, mode: _Optional[str]):
if mode is None or mode == "default":
pass
elif mode in {"reduce-overhead", "max-autotune", "max-autotune-no-cudagraphs"}:
if mode and mode != "default":
from torch._inductor import list_mode_options
self.apply_options(list_mode_options(mode, self.dynamic))
else:
raise RuntimeError(
f"Unrecognized mode={mode}, should be one of: default, reduce-overhead, max-autotune, max-autotune-no-cudagraphs"
)
def apply_options(self, options: _Optional[_Dict[str, _Any]]):
from torch._inductor.compiler_bisector import CompilerBisector
if bisect_changes := CompilerBisector.get_config_change("inductor"):
options = {} if options is None else options
options = (
{**bisect_changes} if options is None else {**options, **bisect_changes} # type: ignore[dict-item]
)
if not options:
return

View File

@ -264,7 +264,12 @@ def list_mode_options(
"coordinate_descent_tuning": True,
},
}
return mode_options[mode] if mode else mode_options # type: ignore[return-value]
try:
return mode_options[mode] if mode else mode_options
except KeyError as e:
raise RuntimeError(
f"Unrecognized mode={mode}, should be one of: {', '.join(mode_options.keys())}"
) from e
def list_options() -> list[str]: