config: Add env_name_default and env_name_force to Config (#138956)

This allows Configs to handle setting their defaults (or overriding
themselves) via environment variables.

The environment variables are resolved at install time (which is usually
import time). This is done 1) to avoid any race conditions between
threads etc..., but 2) to help encourage people to just go modify the
configs directly, vs overriding environment variables to change
pytorch behaviour.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/138956
Approved by: https://github.com/ezyang
ghstack dependencies: #138766
This commit is contained in:
Colin L. Rice
2024-11-01 10:26:02 -06:00
committed by PyTorch MergeBot
parent 1270c78268
commit 2a857e940d
4 changed files with 88 additions and 8 deletions

View File

@ -1,6 +1,11 @@
# Owner(s): ["module: unknown"]
import os
import pickle
os.environ["ENV_TRUE"] = "1"
os.environ["ENV_FALSE"] = "0"
from torch.testing._internal import fake_config_module as config
from torch.testing._internal.common_utils import run_tests, TestCase
from torch.utils._config_module import _UNSET_SENTINEL
@ -70,6 +75,17 @@ class TestConfigModule(TestCase):
for k in config._config:
config._config[k].user_override = _UNSET_SENTINEL
def test_env_name_semantics(self):
self.assertTrue(config.e_env_default)
self.assertFalse(config.e_env_default_FALSE)
self.assertTrue(config.e_env_force)
config.e_env_default = False
self.assertFalse(config.e_env_default)
config.e_env_force = False
self.assertTrue(config.e_env_force)
for k in config._config:
config._config[k].user_override = _UNSET_SENTINEL
def test_save_config(self):
p = config.save_config()
self.assertEqual(
@ -93,6 +109,9 @@ class TestConfigModule(TestCase):
"e_config": True,
"e_jk": True,
"e_jk_false": False,
"e_env_default": True,
"e_env_default_FALSE": False,
"e_env_force": True,
},
)
config.e_bool = False
@ -123,6 +142,9 @@ class TestConfigModule(TestCase):
"e_config": True,
"e_jk": True,
"e_jk_false": False,
"e_env_default": True,
"e_env_default_FALSE": False,
"e_env_force": True,
},
)
config.e_bool = False
@ -152,35 +174,35 @@ torch.testing._internal.fake_config_module._save_config_ignore = ['e_ignored']""
def test_get_hash(self):
self.assertEqual(
config.get_hash(), b"\xa8\xe0\x9b\xfc*\xc4P\xb5g\x1e_\x03 \x7fA\x05"
config.get_hash(), b"U\x8bi\xc2~PY\x98\x18\x9d\xf8<\xe4\xbc%\x0c"
)
# Test cached value
self.assertEqual(
config.get_hash(), b"\xa8\xe0\x9b\xfc*\xc4P\xb5g\x1e_\x03 \x7fA\x05"
config.get_hash(), b"U\x8bi\xc2~PY\x98\x18\x9d\xf8<\xe4\xbc%\x0c"
)
self.assertEqual(
config.get_hash(), b"\xa8\xe0\x9b\xfc*\xc4P\xb5g\x1e_\x03 \x7fA\x05"
config.get_hash(), b"U\x8bi\xc2~PY\x98\x18\x9d\xf8<\xe4\xbc%\x0c"
)
config._hash_digest = "fake"
self.assertEqual(config.get_hash(), "fake")
config.e_bool = False
self.assertNotEqual(
config.get_hash(), b"\xa8\xe0\x9b\xfc*\xc4P\xb5g\x1e_\x03 \x7fA\x05"
config.get_hash(), b"U\x8bi\xc2~PY\x98\x18\x9d\xf8<\xe4\xbc%\x0c"
)
config.e_bool = True
# Test ignored values
config.e_compile_ignored = False
self.assertEqual(
config.get_hash(), b"\xa8\xe0\x9b\xfc*\xc4P\xb5g\x1e_\x03 \x7fA\x05"
config.get_hash(), b"U\x8bi\xc2~PY\x98\x18\x9d\xf8<\xe4\xbc%\x0c"
)
for k in config._config:
config._config[k].user_override = _UNSET_SENTINEL
def test_dict_copy_semantics(self):
p = config.shallow_copy_dict()
self.assertEqual(
self.assertDictEqual(
p,
{
"e_bool": True,
@ -202,6 +224,9 @@ torch.testing._internal.fake_config_module._save_config_ignore = ['e_ignored']""
"e_config": True,
"e_jk": True,
"e_jk_false": False,
"e_env_default": True,
"e_env_default_FALSE": False,
"e_env_force": True,
},
)
p2 = config.to_dict()
@ -227,6 +252,9 @@ torch.testing._internal.fake_config_module._save_config_ignore = ['e_ignored']""
"e_config": True,
"e_jk": True,
"e_jk_false": False,
"e_env_default": True,
"e_env_default_FALSE": False,
"e_env_force": True,
},
)
p3 = config.get_config_copy()
@ -252,6 +280,9 @@ torch.testing._internal.fake_config_module._save_config_ignore = ['e_ignored']""
"e_config": True,
"e_jk": True,
"e_jk_false": False,
"e_env_default": True,
"e_env_default_FALSE": False,
"e_env_force": True,
},
)

View File

@ -21,6 +21,9 @@ e_compile_ignored = True
e_config = Config(default=True)
e_jk = Config(justknob="does_not_exist")
e_jk_false = Config(justknob="does_not_exist", default=False)
e_env_default = Config(env_name_default="ENV_TRUE", default=False)
e_env_default_FALSE = Config(env_name_default="ENV_FALSE", default=True)
e_env_force = Config(env_name_force="ENV_TRUE", default=False)
class nested:

View File

@ -28,31 +28,60 @@ class Config:
This configs must be installed with install_config_module to be used
Precedence Order:
env_name_force: If set, this environment variable overrides everything
user_override: If a user sets a value (i.e. foo.bar=True), that
has the highest precendance and is always respected
has precedence over everything after this.
env_name_default: If set, this environment variable will override everything
after this.
justknob: If this pytorch installation supports justknobs, that will
override defaults, but will not override the user_override precendence.
default: This value is the lowest precendance, and will be used if nothing is
set.
Environment Variables:
These are interpreted to be either "0" or "1" to represent true and false.
Arguments:
justknob: the name of the feature / JK. In OSS this is unused.
default: is the value to default this knob to in OSS.
env_name_force: The environment variable to read that is a FORCE
environment variable. I.e. it overrides everything
env_name_default: The environment variable to read that changes the
default behaviour. I.e. user overrides take preference.
"""
default: Any = True
justknob: Optional[str] = None
env_name_default: Optional[str] = None
env_name_force: Optional[str] = None
def __init__(self, default: Any = True, justknob: Optional[str] = None):
def __init__(
self,
default: Any = True,
justknob: Optional[str] = None,
env_name_default: Optional[str] = None,
env_name_force: Optional[str] = None,
):
# python 3.9 does not support kw_only on the dataclass :(.
self.default = default
self.justknob = justknob
self.env_name_default = env_name_default
self.env_name_force = env_name_force
# Types saved/loaded in configs
CONFIG_TYPES = (int, float, bool, type(None), str, list, set, tuple, dict)
def _read_env_variable(name: str) -> Optional[bool]:
value = os.environ.get(name)
if value == "1":
return True
if value == "0":
return False
return None
def install_config_module(module: ModuleType) -> None:
"""
Converts a module-level config into a `ConfigModule()`.
@ -87,6 +116,7 @@ def install_config_module(module: ModuleType) -> None:
delattr(module, key)
elif isinstance(value, Config):
config[name] = _ConfigEntry(value)
if dest is module:
delattr(module, key)
elif isinstance(value, type):
@ -167,10 +197,19 @@ class _ConfigEntry:
user_override: Any = _UNSET_SENTINEL
# The justknob to check for this config
justknob: Optional[str] = None
# environment variables are read at install time
env_value_force: Any = _UNSET_SENTINEL
env_value_default: Any = _UNSET_SENTINEL
def __init__(self, config: Config):
self.default = config.default
self.justknob = config.justknob
if config.env_name_default is not None:
if (env_value := _read_env_variable(config.env_name_default)) is not None:
self.env_value_default = env_value
if config.env_name_force is not None:
if (env_value := _read_env_variable(config.env_name_force)) is not None:
self.env_value_force = env_value
class ConfigModule(ModuleType):
@ -202,9 +241,16 @@ class ConfigModule(ModuleType):
def __getattr__(self, name: str) -> Any:
try:
config = self._config[name]
if config.env_value_force is not _UNSET_SENTINEL:
return config.env_value_force
if config.user_override is not _UNSET_SENTINEL:
return config.user_override
if config.env_value_default is not _UNSET_SENTINEL:
return config.env_value_default
if config.justknob is not None:
# JK only supports bools and ints
return justknobs_check(name=config.justknob, default=config.default)