mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 12:54:11 +08:00
Add check that envvar configs are boolean (#145454)
So we don't get unexpected behavior when higher typed values are passed in Pull Request resolved: https://github.com/pytorch/pytorch/pull/145454 Approved by: https://github.com/c00w, https://github.com/jamesjwu
This commit is contained in:
committed by
PyTorch MergeBot
parent
9091096d6c
commit
dd349207c5
@ -17,7 +17,7 @@ from torch.testing._internal import (
|
||||
fake_config_module3 as config3,
|
||||
)
|
||||
from torch.testing._internal.common_utils import run_tests, TestCase
|
||||
from torch.utils._config_module import _UNSET_SENTINEL, Config
|
||||
from torch.utils._config_module import _ConfigEntry, _UNSET_SENTINEL, Config
|
||||
|
||||
|
||||
class TestConfigModule(TestCase):
|
||||
@ -378,7 +378,7 @@ torch.testing._internal.fake_config_module3.e_func = _warnings.warn""",
|
||||
AssertionError,
|
||||
msg="AssertionError: justknobs only support booleans, thisisnotvalid is not a boolean",
|
||||
):
|
||||
Config(default="bad", justknob="fake_knob")
|
||||
_ConfigEntry(Config(default="bad", justknob="fake_knob"))
|
||||
|
||||
def test_alias(self):
|
||||
self.assertFalse(config2.e_aliasing_bool)
|
||||
@ -395,6 +395,18 @@ torch.testing._internal.fake_config_module3.e_func = _warnings.warn""",
|
||||
t["a"] = "b"
|
||||
self.assertFalse(config._is_default("e_dict"))
|
||||
|
||||
def test_invalid_config_int(self):
|
||||
with self.assertRaises(AssertionError):
|
||||
_ConfigEntry(
|
||||
Config(default=2, env_name_default="FAKE_DISABLE", value_type=int)
|
||||
)
|
||||
|
||||
def test_invalid_config_float(self):
|
||||
with self.assertRaises(AssertionError):
|
||||
_ConfigEntry(
|
||||
Config(default=2, env_name_force="FAKE_DISABLE", value_type=float)
|
||||
)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
run_tests()
|
||||
|
@ -82,7 +82,6 @@ class _Config(Generic[T]):
|
||||
justknob: Optional[str] = None
|
||||
env_name_default: Optional[list[str]] = None
|
||||
env_name_force: Optional[list[str]] = None
|
||||
value_type: Optional[type] = None
|
||||
alias: Optional[str] = None
|
||||
|
||||
def __init__(
|
||||
@ -103,17 +102,13 @@ class _Config(Generic[T]):
|
||||
self.env_name_force = _Config.string_or_list_of_string_to_list(env_name_force)
|
||||
self.value_type = value_type
|
||||
self.alias = alias
|
||||
if self.justknob is not None:
|
||||
assert isinstance(
|
||||
self.default, bool
|
||||
), f"justknobs only support booleans, {self.default} is not a boolean"
|
||||
if self.alias is not None:
|
||||
assert (
|
||||
default is _UNSET_SENTINEL
|
||||
and justknob is None
|
||||
and env_name_default is None
|
||||
and env_name_force is None
|
||||
), "if alias is set, default, justknob or env var cannot be set"
|
||||
), "if alias is set, none of {default, justknob and env var} can be set"
|
||||
|
||||
@staticmethod
|
||||
def string_or_list_of_string_to_list(
|
||||
@ -326,6 +321,21 @@ class _ConfigEntry:
|
||||
self.env_value_force = env_value
|
||||
break
|
||||
|
||||
# Ensure justknobs and envvars are allowlisted types
|
||||
if self.justknob is not None and self.default is not None:
|
||||
assert isinstance(
|
||||
self.default, bool
|
||||
), f"justknobs only support booleans, {self.default} is not a boolean"
|
||||
if self.value_type is not None and (
|
||||
config.env_name_default is not None or config.env_name_force is not None
|
||||
):
|
||||
assert self.value_type in (
|
||||
bool,
|
||||
str,
|
||||
Optional[bool],
|
||||
Optional[str],
|
||||
), f"envvar configs only support (optional) booleans or strings, {self.value_type} is neither"
|
||||
|
||||
|
||||
class ConfigModule(ModuleType):
|
||||
# NOTE: This should be kept in sync with _config_typing.pyi.
|
||||
|
Reference in New Issue
Block a user