diff --git a/torch/__init__.py b/torch/__init__.py index da1e2374bb89..faedf38c6673 100644 --- a/torch/__init__.py +++ b/torch/__init__.py @@ -2279,10 +2279,13 @@ class _TorchCompileInductorWrapper: compiler_name = "inductor" def __init__(self, mode, options, dynamic): + from torch._inductor.compiler_bisector import CompilerBisector + self.config: _Dict[str, _Any] = {} self.dynamic = dynamic self.apply_mode(mode) self.apply_options(options) + self.apply_options(CompilerBisector.get_config_change("inductor")) if self.config.get("triton.cudagraphs", False): os.environ["DISABLE_CUPTI_LAZY_REINIT"] = "1" @@ -2300,26 +2303,12 @@ class _TorchCompileInductorWrapper: ) def apply_mode(self, mode: _Optional[str]): - if mode is None or mode == "default": - pass - elif mode in {"reduce-overhead", "max-autotune", "max-autotune-no-cudagraphs"}: + if mode and mode != "default": from torch._inductor import list_mode_options self.apply_options(list_mode_options(mode, self.dynamic)) - else: - raise RuntimeError( - f"Unrecognized mode={mode}, should be one of: default, reduce-overhead, max-autotune, max-autotune-no-cudagraphs" - ) def apply_options(self, options: _Optional[_Dict[str, _Any]]): - from torch._inductor.compiler_bisector import CompilerBisector - - if bisect_changes := CompilerBisector.get_config_change("inductor"): - options = {} if options is None else options - options = ( - {**bisect_changes} if options is None else {**options, **bisect_changes} # type: ignore[dict-item] - ) - if not options: return diff --git a/torch/_inductor/__init__.py b/torch/_inductor/__init__.py index c84b689ed7a5..f41fde8177db 100644 --- a/torch/_inductor/__init__.py +++ b/torch/_inductor/__init__.py @@ -264,7 +264,12 @@ def list_mode_options( "coordinate_descent_tuning": True, }, } - return mode_options[mode] if mode else mode_options # type: ignore[return-value] + try: + return mode_options[mode] if mode else mode_options + except KeyError as e: + raise RuntimeError( + f"Unrecognized mode={mode}, should be one of: {', '.join(mode_options.keys())}" + ) from e def list_options() -> list[str]: