config: Add env_name_default and env_name_force to Config (#138956)

This allows Configs to handle setting their defaults (or overriding
themselves) via environment variables.

The environment variables are resolved at install time (which is usually
import time). This is done 1) to avoid any race conditions between
threads etc..., but 2) to help encourage people to just go modify the
configs directly, vs overriding environment variables to change
pytorch behaviour.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/138956
Approved by: https://github.com/ezyang
ghstack dependencies: #138766
This commit is contained in:
Colin L. Rice
2024-11-01 10:26:02 -06:00
committed by PyTorch MergeBot
parent 1270c78268
commit 2a857e940d
4 changed files with 88 additions and 8 deletions

View File

@ -1,6 +1,11 @@
# Owner(s): ["module: unknown"] # Owner(s): ["module: unknown"]
import os
import pickle import pickle
os.environ["ENV_TRUE"] = "1"
os.environ["ENV_FALSE"] = "0"
from torch.testing._internal import fake_config_module as config from torch.testing._internal import fake_config_module as config
from torch.testing._internal.common_utils import run_tests, TestCase from torch.testing._internal.common_utils import run_tests, TestCase
from torch.utils._config_module import _UNSET_SENTINEL from torch.utils._config_module import _UNSET_SENTINEL
@ -70,6 +75,17 @@ class TestConfigModule(TestCase):
for k in config._config: for k in config._config:
config._config[k].user_override = _UNSET_SENTINEL config._config[k].user_override = _UNSET_SENTINEL
def test_env_name_semantics(self):
self.assertTrue(config.e_env_default)
self.assertFalse(config.e_env_default_FALSE)
self.assertTrue(config.e_env_force)
config.e_env_default = False
self.assertFalse(config.e_env_default)
config.e_env_force = False
self.assertTrue(config.e_env_force)
for k in config._config:
config._config[k].user_override = _UNSET_SENTINEL
def test_save_config(self): def test_save_config(self):
p = config.save_config() p = config.save_config()
self.assertEqual( self.assertEqual(
@ -93,6 +109,9 @@ class TestConfigModule(TestCase):
"e_config": True, "e_config": True,
"e_jk": True, "e_jk": True,
"e_jk_false": False, "e_jk_false": False,
"e_env_default": True,
"e_env_default_FALSE": False,
"e_env_force": True,
}, },
) )
config.e_bool = False config.e_bool = False
@ -123,6 +142,9 @@ class TestConfigModule(TestCase):
"e_config": True, "e_config": True,
"e_jk": True, "e_jk": True,
"e_jk_false": False, "e_jk_false": False,
"e_env_default": True,
"e_env_default_FALSE": False,
"e_env_force": True,
}, },
) )
config.e_bool = False config.e_bool = False
@ -152,35 +174,35 @@ torch.testing._internal.fake_config_module._save_config_ignore = ['e_ignored']""
def test_get_hash(self): def test_get_hash(self):
self.assertEqual( self.assertEqual(
config.get_hash(), b"\xa8\xe0\x9b\xfc*\xc4P\xb5g\x1e_\x03 \x7fA\x05" config.get_hash(), b"U\x8bi\xc2~PY\x98\x18\x9d\xf8<\xe4\xbc%\x0c"
) )
# Test cached value # Test cached value
self.assertEqual( self.assertEqual(
config.get_hash(), b"\xa8\xe0\x9b\xfc*\xc4P\xb5g\x1e_\x03 \x7fA\x05" config.get_hash(), b"U\x8bi\xc2~PY\x98\x18\x9d\xf8<\xe4\xbc%\x0c"
) )
self.assertEqual( self.assertEqual(
config.get_hash(), b"\xa8\xe0\x9b\xfc*\xc4P\xb5g\x1e_\x03 \x7fA\x05" config.get_hash(), b"U\x8bi\xc2~PY\x98\x18\x9d\xf8<\xe4\xbc%\x0c"
) )
config._hash_digest = "fake" config._hash_digest = "fake"
self.assertEqual(config.get_hash(), "fake") self.assertEqual(config.get_hash(), "fake")
config.e_bool = False config.e_bool = False
self.assertNotEqual( self.assertNotEqual(
config.get_hash(), b"\xa8\xe0\x9b\xfc*\xc4P\xb5g\x1e_\x03 \x7fA\x05" config.get_hash(), b"U\x8bi\xc2~PY\x98\x18\x9d\xf8<\xe4\xbc%\x0c"
) )
config.e_bool = True config.e_bool = True
# Test ignored values # Test ignored values
config.e_compile_ignored = False config.e_compile_ignored = False
self.assertEqual( self.assertEqual(
config.get_hash(), b"\xa8\xe0\x9b\xfc*\xc4P\xb5g\x1e_\x03 \x7fA\x05" config.get_hash(), b"U\x8bi\xc2~PY\x98\x18\x9d\xf8<\xe4\xbc%\x0c"
) )
for k in config._config: for k in config._config:
config._config[k].user_override = _UNSET_SENTINEL config._config[k].user_override = _UNSET_SENTINEL
def test_dict_copy_semantics(self): def test_dict_copy_semantics(self):
p = config.shallow_copy_dict() p = config.shallow_copy_dict()
self.assertEqual( self.assertDictEqual(
p, p,
{ {
"e_bool": True, "e_bool": True,
@ -202,6 +224,9 @@ torch.testing._internal.fake_config_module._save_config_ignore = ['e_ignored']""
"e_config": True, "e_config": True,
"e_jk": True, "e_jk": True,
"e_jk_false": False, "e_jk_false": False,
"e_env_default": True,
"e_env_default_FALSE": False,
"e_env_force": True,
}, },
) )
p2 = config.to_dict() p2 = config.to_dict()
@ -227,6 +252,9 @@ torch.testing._internal.fake_config_module._save_config_ignore = ['e_ignored']""
"e_config": True, "e_config": True,
"e_jk": True, "e_jk": True,
"e_jk_false": False, "e_jk_false": False,
"e_env_default": True,
"e_env_default_FALSE": False,
"e_env_force": True,
}, },
) )
p3 = config.get_config_copy() p3 = config.get_config_copy()
@ -252,6 +280,9 @@ torch.testing._internal.fake_config_module._save_config_ignore = ['e_ignored']""
"e_config": True, "e_config": True,
"e_jk": True, "e_jk": True,
"e_jk_false": False, "e_jk_false": False,
"e_env_default": True,
"e_env_default_FALSE": False,
"e_env_force": True,
}, },
) )

View File

@ -21,6 +21,9 @@ e_compile_ignored = True
e_config = Config(default=True) e_config = Config(default=True)
e_jk = Config(justknob="does_not_exist") e_jk = Config(justknob="does_not_exist")
e_jk_false = Config(justknob="does_not_exist", default=False) e_jk_false = Config(justknob="does_not_exist", default=False)
e_env_default = Config(env_name_default="ENV_TRUE", default=False)
e_env_default_FALSE = Config(env_name_default="ENV_FALSE", default=True)
e_env_force = Config(env_name_force="ENV_TRUE", default=False)
class nested: class nested:

View File

@ -28,31 +28,60 @@ class Config:
This configs must be installed with install_config_module to be used This configs must be installed with install_config_module to be used
Precedence Order: Precedence Order:
env_name_force: If set, this environment variable overrides everything
user_override: If a user sets a value (i.e. foo.bar=True), that user_override: If a user sets a value (i.e. foo.bar=True), that
has the highest precendance and is always respected has precedence over everything after this.
env_name_default: If set, this environment variable will override everything
after this.
justknob: If this pytorch installation supports justknobs, that will justknob: If this pytorch installation supports justknobs, that will
override defaults, but will not override the user_override precendence. override defaults, but will not override the user_override precendence.
default: This value is the lowest precendance, and will be used if nothing is default: This value is the lowest precendance, and will be used if nothing is
set. set.
Environment Variables:
These are interpreted to be either "0" or "1" to represent true and false.
Arguments: Arguments:
justknob: the name of the feature / JK. In OSS this is unused. justknob: the name of the feature / JK. In OSS this is unused.
default: is the value to default this knob to in OSS. default: is the value to default this knob to in OSS.
env_name_force: The environment variable to read that is a FORCE
environment variable. I.e. it overrides everything
env_name_default: The environment variable to read that changes the
default behaviour. I.e. user overrides take preference.
""" """
default: Any = True default: Any = True
justknob: Optional[str] = None justknob: Optional[str] = None
env_name_default: Optional[str] = None
env_name_force: Optional[str] = None
def __init__(self, default: Any = True, justknob: Optional[str] = None): def __init__(
self,
default: Any = True,
justknob: Optional[str] = None,
env_name_default: Optional[str] = None,
env_name_force: Optional[str] = None,
):
# python 3.9 does not support kw_only on the dataclass :(. # python 3.9 does not support kw_only on the dataclass :(.
self.default = default self.default = default
self.justknob = justknob self.justknob = justknob
self.env_name_default = env_name_default
self.env_name_force = env_name_force
# Types saved/loaded in configs # Types saved/loaded in configs
CONFIG_TYPES = (int, float, bool, type(None), str, list, set, tuple, dict) CONFIG_TYPES = (int, float, bool, type(None), str, list, set, tuple, dict)
def _read_env_variable(name: str) -> Optional[bool]:
value = os.environ.get(name)
if value == "1":
return True
if value == "0":
return False
return None
def install_config_module(module: ModuleType) -> None: def install_config_module(module: ModuleType) -> None:
""" """
Converts a module-level config into a `ConfigModule()`. Converts a module-level config into a `ConfigModule()`.
@ -87,6 +116,7 @@ def install_config_module(module: ModuleType) -> None:
delattr(module, key) delattr(module, key)
elif isinstance(value, Config): elif isinstance(value, Config):
config[name] = _ConfigEntry(value) config[name] = _ConfigEntry(value)
if dest is module: if dest is module:
delattr(module, key) delattr(module, key)
elif isinstance(value, type): elif isinstance(value, type):
@ -167,10 +197,19 @@ class _ConfigEntry:
user_override: Any = _UNSET_SENTINEL user_override: Any = _UNSET_SENTINEL
# The justknob to check for this config # The justknob to check for this config
justknob: Optional[str] = None justknob: Optional[str] = None
# environment variables are read at install time
env_value_force: Any = _UNSET_SENTINEL
env_value_default: Any = _UNSET_SENTINEL
def __init__(self, config: Config): def __init__(self, config: Config):
self.default = config.default self.default = config.default
self.justknob = config.justknob self.justknob = config.justknob
if config.env_name_default is not None:
if (env_value := _read_env_variable(config.env_name_default)) is not None:
self.env_value_default = env_value
if config.env_name_force is not None:
if (env_value := _read_env_variable(config.env_name_force)) is not None:
self.env_value_force = env_value
class ConfigModule(ModuleType): class ConfigModule(ModuleType):
@ -202,9 +241,16 @@ class ConfigModule(ModuleType):
def __getattr__(self, name: str) -> Any: def __getattr__(self, name: str) -> Any:
try: try:
config = self._config[name] config = self._config[name]
if config.env_value_force is not _UNSET_SENTINEL:
return config.env_value_force
if config.user_override is not _UNSET_SENTINEL: if config.user_override is not _UNSET_SENTINEL:
return config.user_override return config.user_override
if config.env_value_default is not _UNSET_SENTINEL:
return config.env_value_default
if config.justknob is not None: if config.justknob is not None:
# JK only supports bools and ints # JK only supports bools and ints
return justknobs_check(name=config.justknob, default=config.default) return justknobs_check(name=config.justknob, default=config.default)