From 17b71e5d6a8a45c33e01231e38056e7da5857c88 Mon Sep 17 00:00:00 2001 From: Oguz Ulgen Date: Sat, 14 Dec 2024 09:24:12 -0800 Subject: [PATCH] Add config alias (#142088) Pull Request resolved: https://github.com/pytorch/pytorch/pull/142088 Approved by: https://github.com/c00w --- test/test_utils_config_module.py | 42 ++++++++-- torch/testing/_internal/fake_config_module.py | 15 ++-- .../testing/_internal/fake_config_module2.py | 8 ++ torch/utils/_config_module.py | 78 ++++++++++++++++--- 4 files changed, 119 insertions(+), 24 deletions(-) create mode 100644 torch/testing/_internal/fake_config_module2.py diff --git a/test/test_utils_config_module.py b/test/test_utils_config_module.py index e027fc1934b2..85a8be600e50 100644 --- a/test/test_utils_config_module.py +++ b/test/test_utils_config_module.py @@ -9,7 +9,10 @@ os.environ["ENV_FALSE"] = "0" from typing import Optional -from torch.testing._internal import fake_config_module as config +from torch.testing._internal import ( + fake_config_module as config, + fake_config_module2 as config2, +) from torch.testing._internal.common_utils import run_tests, TestCase from torch.utils._config_module import _UNSET_SENTINEL, Config @@ -98,7 +101,7 @@ class TestConfigModule(TestCase): def test_save_config(self): p = config.save_config() - self.assertEqual( + self.assertDictEqual( pickle.loads(p), { "_cache_config_ignore_prefix": ["magic_cache_config"], @@ -133,7 +136,7 @@ class TestConfigModule(TestCase): def test_save_config_portable(self): p = config.save_config_portable() - self.assertEqual( + self.assertDictEqual( p, { "e_bool": True, @@ -174,22 +177,36 @@ torch.testing._internal.fake_config_module._save_config_ignore = ['e_ignored']"" ) def test_get_hash(self): - self.assertEqual(config.get_hash(), b"\xf2C\xdbo\x99qq\x12\x11\xf7\xb4\xeewVpZ") + hash_value = b"\xf2C\xdbo\x99qq\x12\x11\xf7\xb4\xeewVpZ" + self.assertEqual( + config.get_hash(), + hash_value, + ) # Test cached value - self.assertEqual(config.get_hash(), b"\xf2C\xdbo\x99qq\x12\x11\xf7\xb4\xeewVpZ") - self.assertEqual(config.get_hash(), b"\xf2C\xdbo\x99qq\x12\x11\xf7\xb4\xeewVpZ") + self.assertEqual( + config.get_hash(), + hash_value, + ) + self.assertEqual( + config.get_hash(), + hash_value, + ) config._hash_digest = "fake" self.assertEqual(config.get_hash(), "fake") config.e_bool = False self.assertNotEqual( - config.get_hash(), b"\xf2C\xdbo\x99qq\x12\x11\xf7\xb4\xeewVpZ" + config.get_hash(), + hash_value, ) config.e_bool = True # Test ignored values config.e_compile_ignored = False - self.assertEqual(config.get_hash(), b"\xf2C\xdbo\x99qq\x12\x11\xf7\xb4\xeewVpZ") + self.assertEqual( + config.get_hash(), + hash_value, + ) def test_dict_copy_semantics(self): p = config.shallow_copy_dict() @@ -319,6 +336,15 @@ torch.testing._internal.fake_config_module._save_config_ignore = ['e_ignored']"" ): Config(default="bad", justknob="fake_knob") + def test_alias(self): + self.assertFalse(config2.e_aliasing_bool) + self.assertFalse(config.e_aliased_bool) + with config2.patch(e_aliasing_bool=True): + self.assertTrue(config2.e_aliasing_bool) + self.assertTrue(config.e_aliased_bool) + with config.patch(e_aliased_bool=True): + self.assertTrue(config2.e_aliasing_bool) + if __name__ == "__main__": run_tests() diff --git a/torch/testing/_internal/fake_config_module.py b/torch/testing/_internal/fake_config_module.py index 84951650d071..cf891a37711b 100644 --- a/torch/testing/_internal/fake_config_module.py +++ b/torch/testing/_internal/fake_config_module.py @@ -19,12 +19,15 @@ _e_ignored = True magic_cache_config_ignored = True # [@compile_ignored: debug] e_compile_ignored = True -e_config = Config(default=True) -e_jk = Config(justknob="does_not_exist", default=True) -e_jk_false = Config(justknob="does_not_exist", default=False) -e_env_default = Config(env_name_default="ENV_TRUE", default=False) -e_env_default_FALSE = Config(env_name_default="ENV_FALSE", default=True) -e_env_force = Config(env_name_force="ENV_TRUE", default=False) +e_config: bool = Config(default=True) +e_jk: bool = Config(justknob="does_not_exist", default=True) +e_jk_false: bool = Config(justknob="does_not_exist", default=False) +e_env_default: bool = Config(env_name_default="ENV_TRUE", default=False) +e_env_default_FALSE: bool = Config(env_name_default="ENV_FALSE", default=True) +e_env_force: bool = Config(env_name_force="ENV_TRUE", default=False) +e_aliased_bool: bool = Config( + alias="torch.testing._internal.fake_config_module2.e_aliasing_bool" +) class nested: diff --git a/torch/testing/_internal/fake_config_module2.py b/torch/testing/_internal/fake_config_module2.py new file mode 100644 index 000000000000..cf17b1b9d369 --- /dev/null +++ b/torch/testing/_internal/fake_config_module2.py @@ -0,0 +1,8 @@ +import sys + +from torch.utils._config_module import install_config_module + + +e_aliasing_bool = False + +install_config_module(sys.modules[__name__]) diff --git a/torch/utils/_config_module.py b/torch/utils/_config_module.py index cfbf39ad20b1..501b35a686c1 100644 --- a/torch/utils/_config_module.py +++ b/torch/utils/_config_module.py @@ -1,6 +1,7 @@ import contextlib import copy import hashlib +import importlib import inspect import io import os @@ -20,6 +21,7 @@ from typing import ( NoReturn, Optional, Set, + Tuple, TYPE_CHECKING, TypeVar, Union, @@ -38,6 +40,9 @@ CONFIG_TYPES = (int, float, bool, type(None), str, list, set, tuple, dict) T = TypeVar("T", bound=Union[int, float, bool, None, str, list, set, tuple, dict]) +_UNSET_SENTINEL = object() + + @dataclass class _Config(Generic[T]): """Represents a config with richer behaviour than just a default value. @@ -49,7 +54,9 @@ class _Config(Generic[T]): This configs must be installed with install_config_module to be used Precedence Order: - env_name_force: If set, this environment variable overrides everything + alias: If set, the directly use the value of the alias. + env_name_force: If set, this environment variable has precedence over + everything after this. user_override: If a user sets a value (i.e. foo.bar=True), that has precedence over everything after this. env_name_default: If set, this environment variable will override everything @@ -65,25 +72,28 @@ class _Config(Generic[T]): Arguments: justknob: the name of the feature / JK. In OSS this is unused. default: is the value to default this knob to in OSS. + alias: The alias config to read instead. env_name_force: The environment variable to read that is a FORCE - environment variable. I.e. it overrides everything + environment variable. I.e. it overrides everything except for alias. env_name_default: The environment variable to read that changes the default behaviour. I.e. user overrides take preference. """ - default: T + default: Union[T, object] justknob: Optional[str] = None env_name_default: Optional[str] = None env_name_force: Optional[str] = None value_type: Optional[type] = None + alias: Optional[str] = None def __init__( self, - default: T, + default: Union[T, object] = _UNSET_SENTINEL, justknob: Optional[str] = None, env_name_default: Optional[str] = None, env_name_force: Optional[str] = None, value_type: Optional[type] = None, + alias: Optional[str] = None, ): # python 3.9 does not support kw_only on the dataclass :(. self.default = default @@ -91,10 +101,18 @@ class _Config(Generic[T]): self.env_name_default = env_name_default self.env_name_force = env_name_force self.value_type = value_type + self.alias = alias if self.justknob is not None: assert isinstance( self.default, bool ), f"justknobs only support booleans, {self.default} is not a boolean" + if self.alias is not None: + assert ( + default is _UNSET_SENTINEL + and justknob is None + and env_name_default is None + and env_name_force is None + ), "if alias is set, default, justknob or env var cannot be set" # In runtime, we unbox the Config[T] to a T, but typechecker cannot see this, @@ -104,24 +122,28 @@ class _Config(Generic[T]): if TYPE_CHECKING: def Config( - default: T, + default: Union[T, object] = _UNSET_SENTINEL, justknob: Optional[str] = None, env_name_default: Optional[str] = None, env_name_force: Optional[str] = None, value_type: Optional[type] = None, + alias: Optional[str] = None, ) -> T: ... else: def Config( - default: T, + default: Union[T, object] = _UNSET_SENTINEL, justknob: Optional[str] = None, env_name_default: Optional[str] = None, env_name_force: Optional[str] = None, value_type: Optional[type] = None, + alias: Optional[str] = None, ) -> _Config[T]: - return _Config(default, justknob, env_name_default, env_name_force, value_type) + return _Config( + default, justknob, env_name_default, env_name_force, value_type, alias + ) def _read_env_variable(name: str) -> Optional[bool]: @@ -243,9 +265,6 @@ def get_assignments_with_compile_ignored_comments(module: ModuleType) -> Set[str return assignments -_UNSET_SENTINEL = object() - - @dataclass class _ConfigEntry: # The default value specified in the configuration @@ -272,6 +291,7 @@ class _ConfigEntry: # call so the final state is correct. It's just very unintuitive. # upstream bug - python/cpython#126886 hide: bool = False + alias: Optional[str] = None def __init__(self, config: _Config): self.default = config.default @@ -279,6 +299,7 @@ class _ConfigEntry: config.value_type if config.value_type is not None else type(self.default) ) self.justknob = config.justknob + self.alias = config.alias if config.env_name_default is not None: if (env_value := _read_env_variable(config.env_name_default)) is not None: self.env_value_default = env_value @@ -309,6 +330,8 @@ class ConfigModule(ModuleType): super().__setattr__(name, value) elif name not in self._config: raise AttributeError(f"{self.__name__}.{name} does not exist") + elif self._config[name].alias is not None: + self._set_alias_val(self._config[name], value) else: self._config[name].user_override = value self._is_dirty = True @@ -321,6 +344,10 @@ class ConfigModule(ModuleType): if config.hide: raise AttributeError(f"{self.__name__}.{name} does not exist") + alias_val = self._get_alias_val(config) + if alias_val is not _UNSET_SENTINEL: + return alias_val + if config.env_value_force is not _UNSET_SENTINEL: return config.env_value_force @@ -353,6 +380,33 @@ class ConfigModule(ModuleType): self._config[name].user_override = _UNSET_SENTINEL self._config[name].hide = True + def _get_alias_module_and_name( + self, entry: _ConfigEntry + ) -> Optional[Tuple[ModuleType, str]]: + alias = entry.alias + if alias is None: + return None + module_name, constant_name = alias.rsplit(".", 1) + try: + module = importlib.import_module(module_name) + except ImportError as e: + raise AttributeError("config alias {alias} does not exist") from e + return module, constant_name + + def _get_alias_val(self, entry: _ConfigEntry) -> Any: + data = self._get_alias_module_and_name(entry) + if data is None: + return _UNSET_SENTINEL + module, constant_name = data + constant_value = getattr(module, constant_name) + return constant_value + + def _set_alias_val(self, entry: _ConfigEntry, val: Any) -> None: + data = self._get_alias_module_and_name(entry) + assert data is not None + module, constant_name = data + setattr(module, constant_name, val) + def _is_default(self, name: str) -> bool: return self._config[name].user_override is _UNSET_SENTINEL @@ -369,6 +423,7 @@ class ConfigModule(ModuleType): This is used by a number of different user facing export methods which all have slightly different semantics re: how and what to skip. + If a config is aliased, it skips this config. Arguments: ignored_keys are keys that should not be exported. @@ -391,7 +446,10 @@ class ConfigModule(ModuleType): continue if skip_default and self._is_default(key): continue + if self._config[key].alias is not None: + continue config[key] = copy.deepcopy(getattr(self, key)) + return config def get_type(self, config_name: str) -> type: