From 2a857e940de5b605068c86c0cfd4171e8a1651a6 Mon Sep 17 00:00:00 2001 From: "Colin L. Rice" Date: Fri, 1 Nov 2024 10:26:02 -0600 Subject: [PATCH] config: Add env_name_default and env_name_force to Config (#138956) This allows Configs to handle setting their defaults (or overriding themselves) via environment variables. The environment variables are resolved at install time (which is usually import time). This is done 1) to avoid any race conditions between threads etc..., but 2) to help encourage people to just go modify the configs directly, vs overriding environment variables to change pytorch behaviour. Pull Request resolved: https://github.com/pytorch/pytorch/pull/138956 Approved by: https://github.com/ezyang ghstack dependencies: #138766 --- .../TestConfigModule.test_env_name_semantics | 0 test/test_utils_config_module.py | 43 +++++++++++++--- torch/testing/_internal/fake_config_module.py | 3 ++ torch/utils/_config_module.py | 50 ++++++++++++++++++- 4 files changed, 88 insertions(+), 8 deletions(-) create mode 100644 test/dynamo_skips/TestConfigModule.test_env_name_semantics diff --git a/test/dynamo_skips/TestConfigModule.test_env_name_semantics b/test/dynamo_skips/TestConfigModule.test_env_name_semantics new file mode 100644 index 000000000000..e69de29bb2d1 diff --git a/test/test_utils_config_module.py b/test/test_utils_config_module.py index 84f6f2a2b0e2..c8e56381abdb 100644 --- a/test/test_utils_config_module.py +++ b/test/test_utils_config_module.py @@ -1,6 +1,11 @@ # Owner(s): ["module: unknown"] +import os import pickle + +os.environ["ENV_TRUE"] = "1" +os.environ["ENV_FALSE"] = "0" + from torch.testing._internal import fake_config_module as config from torch.testing._internal.common_utils import run_tests, TestCase from torch.utils._config_module import _UNSET_SENTINEL @@ -70,6 +75,17 @@ class TestConfigModule(TestCase): for k in config._config: config._config[k].user_override = _UNSET_SENTINEL + def test_env_name_semantics(self): + self.assertTrue(config.e_env_default) + self.assertFalse(config.e_env_default_FALSE) + self.assertTrue(config.e_env_force) + config.e_env_default = False + self.assertFalse(config.e_env_default) + config.e_env_force = False + self.assertTrue(config.e_env_force) + for k in config._config: + config._config[k].user_override = _UNSET_SENTINEL + def test_save_config(self): p = config.save_config() self.assertEqual( @@ -93,6 +109,9 @@ class TestConfigModule(TestCase): "e_config": True, "e_jk": True, "e_jk_false": False, + "e_env_default": True, + "e_env_default_FALSE": False, + "e_env_force": True, }, ) config.e_bool = False @@ -123,6 +142,9 @@ class TestConfigModule(TestCase): "e_config": True, "e_jk": True, "e_jk_false": False, + "e_env_default": True, + "e_env_default_FALSE": False, + "e_env_force": True, }, ) config.e_bool = False @@ -152,35 +174,35 @@ torch.testing._internal.fake_config_module._save_config_ignore = ['e_ignored']"" def test_get_hash(self): self.assertEqual( - config.get_hash(), b"\xa8\xe0\x9b\xfc*\xc4P\xb5g\x1e_\x03 \x7fA\x05" + config.get_hash(), b"U\x8bi\xc2~PY\x98\x18\x9d\xf8<\xe4\xbc%\x0c" ) # Test cached value self.assertEqual( - config.get_hash(), b"\xa8\xe0\x9b\xfc*\xc4P\xb5g\x1e_\x03 \x7fA\x05" + config.get_hash(), b"U\x8bi\xc2~PY\x98\x18\x9d\xf8<\xe4\xbc%\x0c" ) self.assertEqual( - config.get_hash(), b"\xa8\xe0\x9b\xfc*\xc4P\xb5g\x1e_\x03 \x7fA\x05" + config.get_hash(), b"U\x8bi\xc2~PY\x98\x18\x9d\xf8<\xe4\xbc%\x0c" ) config._hash_digest = "fake" self.assertEqual(config.get_hash(), "fake") config.e_bool = False self.assertNotEqual( - config.get_hash(), b"\xa8\xe0\x9b\xfc*\xc4P\xb5g\x1e_\x03 \x7fA\x05" + config.get_hash(), b"U\x8bi\xc2~PY\x98\x18\x9d\xf8<\xe4\xbc%\x0c" ) config.e_bool = True # Test ignored values config.e_compile_ignored = False self.assertEqual( - config.get_hash(), b"\xa8\xe0\x9b\xfc*\xc4P\xb5g\x1e_\x03 \x7fA\x05" + config.get_hash(), b"U\x8bi\xc2~PY\x98\x18\x9d\xf8<\xe4\xbc%\x0c" ) for k in config._config: config._config[k].user_override = _UNSET_SENTINEL def test_dict_copy_semantics(self): p = config.shallow_copy_dict() - self.assertEqual( + self.assertDictEqual( p, { "e_bool": True, @@ -202,6 +224,9 @@ torch.testing._internal.fake_config_module._save_config_ignore = ['e_ignored']"" "e_config": True, "e_jk": True, "e_jk_false": False, + "e_env_default": True, + "e_env_default_FALSE": False, + "e_env_force": True, }, ) p2 = config.to_dict() @@ -227,6 +252,9 @@ torch.testing._internal.fake_config_module._save_config_ignore = ['e_ignored']"" "e_config": True, "e_jk": True, "e_jk_false": False, + "e_env_default": True, + "e_env_default_FALSE": False, + "e_env_force": True, }, ) p3 = config.get_config_copy() @@ -252,6 +280,9 @@ torch.testing._internal.fake_config_module._save_config_ignore = ['e_ignored']"" "e_config": True, "e_jk": True, "e_jk_false": False, + "e_env_default": True, + "e_env_default_FALSE": False, + "e_env_force": True, }, ) diff --git a/torch/testing/_internal/fake_config_module.py b/torch/testing/_internal/fake_config_module.py index 999839f3f38d..1d5bed8fe0ee 100644 --- a/torch/testing/_internal/fake_config_module.py +++ b/torch/testing/_internal/fake_config_module.py @@ -21,6 +21,9 @@ e_compile_ignored = True e_config = Config(default=True) e_jk = Config(justknob="does_not_exist") e_jk_false = Config(justknob="does_not_exist", default=False) +e_env_default = Config(env_name_default="ENV_TRUE", default=False) +e_env_default_FALSE = Config(env_name_default="ENV_FALSE", default=True) +e_env_force = Config(env_name_force="ENV_TRUE", default=False) class nested: diff --git a/torch/utils/_config_module.py b/torch/utils/_config_module.py index c2b33017f6b2..41569465d252 100644 --- a/torch/utils/_config_module.py +++ b/torch/utils/_config_module.py @@ -28,31 +28,60 @@ class Config: This configs must be installed with install_config_module to be used Precedence Order: + env_name_force: If set, this environment variable overrides everything user_override: If a user sets a value (i.e. foo.bar=True), that - has the highest precendance and is always respected + has precedence over everything after this. + env_name_default: If set, this environment variable will override everything + after this. justknob: If this pytorch installation supports justknobs, that will override defaults, but will not override the user_override precendence. default: This value is the lowest precendance, and will be used if nothing is set. + Environment Variables: + These are interpreted to be either "0" or "1" to represent true and false. + Arguments: justknob: the name of the feature / JK. In OSS this is unused. default: is the value to default this knob to in OSS. + env_name_force: The environment variable to read that is a FORCE + environment variable. I.e. it overrides everything + env_name_default: The environment variable to read that changes the + default behaviour. I.e. user overrides take preference. """ default: Any = True justknob: Optional[str] = None + env_name_default: Optional[str] = None + env_name_force: Optional[str] = None - def __init__(self, default: Any = True, justknob: Optional[str] = None): + def __init__( + self, + default: Any = True, + justknob: Optional[str] = None, + env_name_default: Optional[str] = None, + env_name_force: Optional[str] = None, + ): # python 3.9 does not support kw_only on the dataclass :(. self.default = default self.justknob = justknob + self.env_name_default = env_name_default + self.env_name_force = env_name_force # Types saved/loaded in configs CONFIG_TYPES = (int, float, bool, type(None), str, list, set, tuple, dict) +def _read_env_variable(name: str) -> Optional[bool]: + value = os.environ.get(name) + if value == "1": + return True + if value == "0": + return False + return None + + def install_config_module(module: ModuleType) -> None: """ Converts a module-level config into a `ConfigModule()`. @@ -87,6 +116,7 @@ def install_config_module(module: ModuleType) -> None: delattr(module, key) elif isinstance(value, Config): config[name] = _ConfigEntry(value) + if dest is module: delattr(module, key) elif isinstance(value, type): @@ -167,10 +197,19 @@ class _ConfigEntry: user_override: Any = _UNSET_SENTINEL # The justknob to check for this config justknob: Optional[str] = None + # environment variables are read at install time + env_value_force: Any = _UNSET_SENTINEL + env_value_default: Any = _UNSET_SENTINEL def __init__(self, config: Config): self.default = config.default self.justknob = config.justknob + if config.env_name_default is not None: + if (env_value := _read_env_variable(config.env_name_default)) is not None: + self.env_value_default = env_value + if config.env_name_force is not None: + if (env_value := _read_env_variable(config.env_name_force)) is not None: + self.env_value_force = env_value class ConfigModule(ModuleType): @@ -202,9 +241,16 @@ class ConfigModule(ModuleType): def __getattr__(self, name: str) -> Any: try: config = self._config[name] + + if config.env_value_force is not _UNSET_SENTINEL: + return config.env_value_force + if config.user_override is not _UNSET_SENTINEL: return config.user_override + if config.env_value_default is not _UNSET_SENTINEL: + return config.env_value_default + if config.justknob is not None: # JK only supports bools and ints return justknobs_check(name=config.justknob, default=config.default)