mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-21 05:34:18 +08:00
config: Add env_name_default and env_name_force to Config (#138956)
This allows Configs to handle setting their defaults (or overriding themselves) via environment variables. The environment variables are resolved at install time (which is usually import time). This is done 1) to avoid any race conditions between threads etc..., but 2) to help encourage people to just go modify the configs directly, vs overriding environment variables to change pytorch behaviour. Pull Request resolved: https://github.com/pytorch/pytorch/pull/138956 Approved by: https://github.com/ezyang ghstack dependencies: #138766
This commit is contained in:
committed by
PyTorch MergeBot
parent
1270c78268
commit
2a857e940d
@ -1,6 +1,11 @@
|
|||||||
# Owner(s): ["module: unknown"]
|
# Owner(s): ["module: unknown"]
|
||||||
|
import os
|
||||||
import pickle
|
import pickle
|
||||||
|
|
||||||
|
|
||||||
|
os.environ["ENV_TRUE"] = "1"
|
||||||
|
os.environ["ENV_FALSE"] = "0"
|
||||||
|
|
||||||
from torch.testing._internal import fake_config_module as config
|
from torch.testing._internal import fake_config_module as config
|
||||||
from torch.testing._internal.common_utils import run_tests, TestCase
|
from torch.testing._internal.common_utils import run_tests, TestCase
|
||||||
from torch.utils._config_module import _UNSET_SENTINEL
|
from torch.utils._config_module import _UNSET_SENTINEL
|
||||||
@ -70,6 +75,17 @@ class TestConfigModule(TestCase):
|
|||||||
for k in config._config:
|
for k in config._config:
|
||||||
config._config[k].user_override = _UNSET_SENTINEL
|
config._config[k].user_override = _UNSET_SENTINEL
|
||||||
|
|
||||||
|
def test_env_name_semantics(self):
|
||||||
|
self.assertTrue(config.e_env_default)
|
||||||
|
self.assertFalse(config.e_env_default_FALSE)
|
||||||
|
self.assertTrue(config.e_env_force)
|
||||||
|
config.e_env_default = False
|
||||||
|
self.assertFalse(config.e_env_default)
|
||||||
|
config.e_env_force = False
|
||||||
|
self.assertTrue(config.e_env_force)
|
||||||
|
for k in config._config:
|
||||||
|
config._config[k].user_override = _UNSET_SENTINEL
|
||||||
|
|
||||||
def test_save_config(self):
|
def test_save_config(self):
|
||||||
p = config.save_config()
|
p = config.save_config()
|
||||||
self.assertEqual(
|
self.assertEqual(
|
||||||
@ -93,6 +109,9 @@ class TestConfigModule(TestCase):
|
|||||||
"e_config": True,
|
"e_config": True,
|
||||||
"e_jk": True,
|
"e_jk": True,
|
||||||
"e_jk_false": False,
|
"e_jk_false": False,
|
||||||
|
"e_env_default": True,
|
||||||
|
"e_env_default_FALSE": False,
|
||||||
|
"e_env_force": True,
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
config.e_bool = False
|
config.e_bool = False
|
||||||
@ -123,6 +142,9 @@ class TestConfigModule(TestCase):
|
|||||||
"e_config": True,
|
"e_config": True,
|
||||||
"e_jk": True,
|
"e_jk": True,
|
||||||
"e_jk_false": False,
|
"e_jk_false": False,
|
||||||
|
"e_env_default": True,
|
||||||
|
"e_env_default_FALSE": False,
|
||||||
|
"e_env_force": True,
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
config.e_bool = False
|
config.e_bool = False
|
||||||
@ -152,35 +174,35 @@ torch.testing._internal.fake_config_module._save_config_ignore = ['e_ignored']""
|
|||||||
|
|
||||||
def test_get_hash(self):
|
def test_get_hash(self):
|
||||||
self.assertEqual(
|
self.assertEqual(
|
||||||
config.get_hash(), b"\xa8\xe0\x9b\xfc*\xc4P\xb5g\x1e_\x03 \x7fA\x05"
|
config.get_hash(), b"U\x8bi\xc2~PY\x98\x18\x9d\xf8<\xe4\xbc%\x0c"
|
||||||
)
|
)
|
||||||
# Test cached value
|
# Test cached value
|
||||||
self.assertEqual(
|
self.assertEqual(
|
||||||
config.get_hash(), b"\xa8\xe0\x9b\xfc*\xc4P\xb5g\x1e_\x03 \x7fA\x05"
|
config.get_hash(), b"U\x8bi\xc2~PY\x98\x18\x9d\xf8<\xe4\xbc%\x0c"
|
||||||
)
|
)
|
||||||
self.assertEqual(
|
self.assertEqual(
|
||||||
config.get_hash(), b"\xa8\xe0\x9b\xfc*\xc4P\xb5g\x1e_\x03 \x7fA\x05"
|
config.get_hash(), b"U\x8bi\xc2~PY\x98\x18\x9d\xf8<\xe4\xbc%\x0c"
|
||||||
)
|
)
|
||||||
config._hash_digest = "fake"
|
config._hash_digest = "fake"
|
||||||
self.assertEqual(config.get_hash(), "fake")
|
self.assertEqual(config.get_hash(), "fake")
|
||||||
|
|
||||||
config.e_bool = False
|
config.e_bool = False
|
||||||
self.assertNotEqual(
|
self.assertNotEqual(
|
||||||
config.get_hash(), b"\xa8\xe0\x9b\xfc*\xc4P\xb5g\x1e_\x03 \x7fA\x05"
|
config.get_hash(), b"U\x8bi\xc2~PY\x98\x18\x9d\xf8<\xe4\xbc%\x0c"
|
||||||
)
|
)
|
||||||
config.e_bool = True
|
config.e_bool = True
|
||||||
|
|
||||||
# Test ignored values
|
# Test ignored values
|
||||||
config.e_compile_ignored = False
|
config.e_compile_ignored = False
|
||||||
self.assertEqual(
|
self.assertEqual(
|
||||||
config.get_hash(), b"\xa8\xe0\x9b\xfc*\xc4P\xb5g\x1e_\x03 \x7fA\x05"
|
config.get_hash(), b"U\x8bi\xc2~PY\x98\x18\x9d\xf8<\xe4\xbc%\x0c"
|
||||||
)
|
)
|
||||||
for k in config._config:
|
for k in config._config:
|
||||||
config._config[k].user_override = _UNSET_SENTINEL
|
config._config[k].user_override = _UNSET_SENTINEL
|
||||||
|
|
||||||
def test_dict_copy_semantics(self):
|
def test_dict_copy_semantics(self):
|
||||||
p = config.shallow_copy_dict()
|
p = config.shallow_copy_dict()
|
||||||
self.assertEqual(
|
self.assertDictEqual(
|
||||||
p,
|
p,
|
||||||
{
|
{
|
||||||
"e_bool": True,
|
"e_bool": True,
|
||||||
@ -202,6 +224,9 @@ torch.testing._internal.fake_config_module._save_config_ignore = ['e_ignored']""
|
|||||||
"e_config": True,
|
"e_config": True,
|
||||||
"e_jk": True,
|
"e_jk": True,
|
||||||
"e_jk_false": False,
|
"e_jk_false": False,
|
||||||
|
"e_env_default": True,
|
||||||
|
"e_env_default_FALSE": False,
|
||||||
|
"e_env_force": True,
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
p2 = config.to_dict()
|
p2 = config.to_dict()
|
||||||
@ -227,6 +252,9 @@ torch.testing._internal.fake_config_module._save_config_ignore = ['e_ignored']""
|
|||||||
"e_config": True,
|
"e_config": True,
|
||||||
"e_jk": True,
|
"e_jk": True,
|
||||||
"e_jk_false": False,
|
"e_jk_false": False,
|
||||||
|
"e_env_default": True,
|
||||||
|
"e_env_default_FALSE": False,
|
||||||
|
"e_env_force": True,
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
p3 = config.get_config_copy()
|
p3 = config.get_config_copy()
|
||||||
@ -252,6 +280,9 @@ torch.testing._internal.fake_config_module._save_config_ignore = ['e_ignored']""
|
|||||||
"e_config": True,
|
"e_config": True,
|
||||||
"e_jk": True,
|
"e_jk": True,
|
||||||
"e_jk_false": False,
|
"e_jk_false": False,
|
||||||
|
"e_env_default": True,
|
||||||
|
"e_env_default_FALSE": False,
|
||||||
|
"e_env_force": True,
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -21,6 +21,9 @@ e_compile_ignored = True
|
|||||||
e_config = Config(default=True)
|
e_config = Config(default=True)
|
||||||
e_jk = Config(justknob="does_not_exist")
|
e_jk = Config(justknob="does_not_exist")
|
||||||
e_jk_false = Config(justknob="does_not_exist", default=False)
|
e_jk_false = Config(justknob="does_not_exist", default=False)
|
||||||
|
e_env_default = Config(env_name_default="ENV_TRUE", default=False)
|
||||||
|
e_env_default_FALSE = Config(env_name_default="ENV_FALSE", default=True)
|
||||||
|
e_env_force = Config(env_name_force="ENV_TRUE", default=False)
|
||||||
|
|
||||||
|
|
||||||
class nested:
|
class nested:
|
||||||
|
@ -28,31 +28,60 @@ class Config:
|
|||||||
This configs must be installed with install_config_module to be used
|
This configs must be installed with install_config_module to be used
|
||||||
|
|
||||||
Precedence Order:
|
Precedence Order:
|
||||||
|
env_name_force: If set, this environment variable overrides everything
|
||||||
user_override: If a user sets a value (i.e. foo.bar=True), that
|
user_override: If a user sets a value (i.e. foo.bar=True), that
|
||||||
has the highest precendance and is always respected
|
has precedence over everything after this.
|
||||||
|
env_name_default: If set, this environment variable will override everything
|
||||||
|
after this.
|
||||||
justknob: If this pytorch installation supports justknobs, that will
|
justknob: If this pytorch installation supports justknobs, that will
|
||||||
override defaults, but will not override the user_override precendence.
|
override defaults, but will not override the user_override precendence.
|
||||||
default: This value is the lowest precendance, and will be used if nothing is
|
default: This value is the lowest precendance, and will be used if nothing is
|
||||||
set.
|
set.
|
||||||
|
|
||||||
|
Environment Variables:
|
||||||
|
These are interpreted to be either "0" or "1" to represent true and false.
|
||||||
|
|
||||||
Arguments:
|
Arguments:
|
||||||
justknob: the name of the feature / JK. In OSS this is unused.
|
justknob: the name of the feature / JK. In OSS this is unused.
|
||||||
default: is the value to default this knob to in OSS.
|
default: is the value to default this knob to in OSS.
|
||||||
|
env_name_force: The environment variable to read that is a FORCE
|
||||||
|
environment variable. I.e. it overrides everything
|
||||||
|
env_name_default: The environment variable to read that changes the
|
||||||
|
default behaviour. I.e. user overrides take preference.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
default: Any = True
|
default: Any = True
|
||||||
justknob: Optional[str] = None
|
justknob: Optional[str] = None
|
||||||
|
env_name_default: Optional[str] = None
|
||||||
|
env_name_force: Optional[str] = None
|
||||||
|
|
||||||
def __init__(self, default: Any = True, justknob: Optional[str] = None):
|
def __init__(
|
||||||
|
self,
|
||||||
|
default: Any = True,
|
||||||
|
justknob: Optional[str] = None,
|
||||||
|
env_name_default: Optional[str] = None,
|
||||||
|
env_name_force: Optional[str] = None,
|
||||||
|
):
|
||||||
# python 3.9 does not support kw_only on the dataclass :(.
|
# python 3.9 does not support kw_only on the dataclass :(.
|
||||||
self.default = default
|
self.default = default
|
||||||
self.justknob = justknob
|
self.justknob = justknob
|
||||||
|
self.env_name_default = env_name_default
|
||||||
|
self.env_name_force = env_name_force
|
||||||
|
|
||||||
|
|
||||||
# Types saved/loaded in configs
|
# Types saved/loaded in configs
|
||||||
CONFIG_TYPES = (int, float, bool, type(None), str, list, set, tuple, dict)
|
CONFIG_TYPES = (int, float, bool, type(None), str, list, set, tuple, dict)
|
||||||
|
|
||||||
|
|
||||||
|
def _read_env_variable(name: str) -> Optional[bool]:
|
||||||
|
value = os.environ.get(name)
|
||||||
|
if value == "1":
|
||||||
|
return True
|
||||||
|
if value == "0":
|
||||||
|
return False
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
def install_config_module(module: ModuleType) -> None:
|
def install_config_module(module: ModuleType) -> None:
|
||||||
"""
|
"""
|
||||||
Converts a module-level config into a `ConfigModule()`.
|
Converts a module-level config into a `ConfigModule()`.
|
||||||
@ -87,6 +116,7 @@ def install_config_module(module: ModuleType) -> None:
|
|||||||
delattr(module, key)
|
delattr(module, key)
|
||||||
elif isinstance(value, Config):
|
elif isinstance(value, Config):
|
||||||
config[name] = _ConfigEntry(value)
|
config[name] = _ConfigEntry(value)
|
||||||
|
|
||||||
if dest is module:
|
if dest is module:
|
||||||
delattr(module, key)
|
delattr(module, key)
|
||||||
elif isinstance(value, type):
|
elif isinstance(value, type):
|
||||||
@ -167,10 +197,19 @@ class _ConfigEntry:
|
|||||||
user_override: Any = _UNSET_SENTINEL
|
user_override: Any = _UNSET_SENTINEL
|
||||||
# The justknob to check for this config
|
# The justknob to check for this config
|
||||||
justknob: Optional[str] = None
|
justknob: Optional[str] = None
|
||||||
|
# environment variables are read at install time
|
||||||
|
env_value_force: Any = _UNSET_SENTINEL
|
||||||
|
env_value_default: Any = _UNSET_SENTINEL
|
||||||
|
|
||||||
def __init__(self, config: Config):
|
def __init__(self, config: Config):
|
||||||
self.default = config.default
|
self.default = config.default
|
||||||
self.justknob = config.justknob
|
self.justknob = config.justknob
|
||||||
|
if config.env_name_default is not None:
|
||||||
|
if (env_value := _read_env_variable(config.env_name_default)) is not None:
|
||||||
|
self.env_value_default = env_value
|
||||||
|
if config.env_name_force is not None:
|
||||||
|
if (env_value := _read_env_variable(config.env_name_force)) is not None:
|
||||||
|
self.env_value_force = env_value
|
||||||
|
|
||||||
|
|
||||||
class ConfigModule(ModuleType):
|
class ConfigModule(ModuleType):
|
||||||
@ -202,9 +241,16 @@ class ConfigModule(ModuleType):
|
|||||||
def __getattr__(self, name: str) -> Any:
|
def __getattr__(self, name: str) -> Any:
|
||||||
try:
|
try:
|
||||||
config = self._config[name]
|
config = self._config[name]
|
||||||
|
|
||||||
|
if config.env_value_force is not _UNSET_SENTINEL:
|
||||||
|
return config.env_value_force
|
||||||
|
|
||||||
if config.user_override is not _UNSET_SENTINEL:
|
if config.user_override is not _UNSET_SENTINEL:
|
||||||
return config.user_override
|
return config.user_override
|
||||||
|
|
||||||
|
if config.env_value_default is not _UNSET_SENTINEL:
|
||||||
|
return config.env_value_default
|
||||||
|
|
||||||
if config.justknob is not None:
|
if config.justknob is not None:
|
||||||
# JK only supports bools and ints
|
# JK only supports bools and ints
|
||||||
return justknobs_check(name=config.justknob, default=config.default)
|
return justknobs_check(name=config.justknob, default=config.default)
|
||||||
|
Reference in New Issue
Block a user