Add check that envvar configs are boolean (#145454)

So we don't get unexpected behavior when higher typed values are passed in
Pull Request resolved: https://github.com/pytorch/pytorch/pull/145454
Approved by: https://github.com/c00w, https://github.com/jamesjwu
This commit is contained in:
Raymond Li
2025-02-05 19:40:10 +00:00
committed by PyTorch MergeBot
parent 9091096d6c
commit dd349207c5
2 changed files with 30 additions and 8 deletions

View File

@ -17,7 +17,7 @@ from torch.testing._internal import (
fake_config_module3 as config3,
)
from torch.testing._internal.common_utils import run_tests, TestCase
from torch.utils._config_module import _UNSET_SENTINEL, Config
from torch.utils._config_module import _ConfigEntry, _UNSET_SENTINEL, Config
class TestConfigModule(TestCase):
@ -378,7 +378,7 @@ torch.testing._internal.fake_config_module3.e_func = _warnings.warn""",
AssertionError,
msg="AssertionError: justknobs only support booleans, thisisnotvalid is not a boolean",
):
Config(default="bad", justknob="fake_knob")
_ConfigEntry(Config(default="bad", justknob="fake_knob"))
def test_alias(self):
self.assertFalse(config2.e_aliasing_bool)
@ -395,6 +395,18 @@ torch.testing._internal.fake_config_module3.e_func = _warnings.warn""",
t["a"] = "b"
self.assertFalse(config._is_default("e_dict"))
def test_invalid_config_int(self):
with self.assertRaises(AssertionError):
_ConfigEntry(
Config(default=2, env_name_default="FAKE_DISABLE", value_type=int)
)
def test_invalid_config_float(self):
with self.assertRaises(AssertionError):
_ConfigEntry(
Config(default=2, env_name_force="FAKE_DISABLE", value_type=float)
)
if __name__ == "__main__":
run_tests()

View File

@ -82,7 +82,6 @@ class _Config(Generic[T]):
justknob: Optional[str] = None
env_name_default: Optional[list[str]] = None
env_name_force: Optional[list[str]] = None
value_type: Optional[type] = None
alias: Optional[str] = None
def __init__(
@ -103,17 +102,13 @@ class _Config(Generic[T]):
self.env_name_force = _Config.string_or_list_of_string_to_list(env_name_force)
self.value_type = value_type
self.alias = alias
if self.justknob is not None:
assert isinstance(
self.default, bool
), f"justknobs only support booleans, {self.default} is not a boolean"
if self.alias is not None:
assert (
default is _UNSET_SENTINEL
and justknob is None
and env_name_default is None
and env_name_force is None
), "if alias is set, default, justknob or env var cannot be set"
), "if alias is set, none of {default, justknob and env var} can be set"
@staticmethod
def string_or_list_of_string_to_list(
@ -326,6 +321,21 @@ class _ConfigEntry:
self.env_value_force = env_value
break
# Ensure justknobs and envvars are allowlisted types
if self.justknob is not None and self.default is not None:
assert isinstance(
self.default, bool
), f"justknobs only support booleans, {self.default} is not a boolean"
if self.value_type is not None and (
config.env_name_default is not None or config.env_name_force is not None
):
assert self.value_type in (
bool,
str,
Optional[bool],
Optional[str],
), f"envvar configs only support (optional) booleans or strings, {self.value_type} is neither"
class ConfigModule(ModuleType):
# NOTE: This should be kept in sync with _config_typing.pyi.