Add type annotations to Configs (#139833)

Summary:
Adds types to Configs, and fixes a bug in options that was caused by the lack of types.

fixes: https://github.com/pytorch/pytorch/issues/139822

Configs are used by many modules so not sure which label to put.

Types also allow https://github.com/pytorch/pytorch/pull/139736 to fuzz configs

Pull Request resolved: https://github.com/pytorch/pytorch/pull/139833
Approved by: https://github.com/c00w
This commit is contained in:
Gabriel Ferns
2024-11-07 03:49:07 +00:00
committed by PyTorch MergeBot
parent 5203138483
commit 2037ea3e15
4 changed files with 55 additions and 21 deletions

View File

@ -6,6 +6,8 @@ import pickle
os.environ["ENV_TRUE"] = "1"
os.environ["ENV_FALSE"] = "0"
from typing import Optional
from torch.testing._internal import fake_config_module as config
from torch.testing._internal.common_utils import run_tests, TestCase
from torch.utils._config_module import _UNSET_SENTINEL
@ -15,6 +17,7 @@ class TestConfigModule(TestCase):
def test_base_value_loading(self):
self.assertTrue(config.e_bool)
self.assertTrue(config.nested.e_bool)
self.assertTrue(config.e_optional)
self.assertEqual(config.e_int, 1)
self.assertEqual(config.e_float, 1.0)
self.assertEqual(config.e_string, "string")
@ -28,6 +31,10 @@ class TestConfigModule(TestCase):
):
config.does_not_exist
def test_type_loading(self):
self.assertEqual(config.get_type("e_optional"), Optional[bool])
self.assertEqual(config.get_type("e_none"), Optional[bool])
def test_overrides(self):
config.e_bool = False
self.assertFalse(config.e_bool)
@ -51,6 +58,10 @@ class TestConfigModule(TestCase):
self.assertEqual(config.e_none, "not none")
config.e_none = None
self.assertEqual(config.e_none, None)
config.e_optional = None
self.assertEqual(config.e_optional, None)
config.e_optional = False
self.assertEqual(config.e_optional, False)
with self.assertRaises(
AttributeError, msg="fake_config_module.does_not_exist does not exist"
):
@ -112,6 +123,7 @@ class TestConfigModule(TestCase):
"e_env_default": True,
"e_env_default_FALSE": False,
"e_env_force": True,
"e_optional": True,
},
)
config.e_bool = False
@ -145,6 +157,7 @@ class TestConfigModule(TestCase):
"e_env_default": True,
"e_env_default_FALSE": False,
"e_env_force": True,
"e_optional": True,
},
)
config.e_bool = False
@ -173,30 +186,22 @@ torch.testing._internal.fake_config_module._save_config_ignore = ['e_ignored']""
config._config[k].user_override = _UNSET_SENTINEL
def test_get_hash(self):
self.assertEqual(
config.get_hash(), b"U\x8bi\xc2~PY\x98\x18\x9d\xf8<\xe4\xbc%\x0c"
)
self.assertEqual(config.get_hash(), b"\xf2C\xdbo\x99qq\x12\x11\xf7\xb4\xeewVpZ")
# Test cached value
self.assertEqual(
config.get_hash(), b"U\x8bi\xc2~PY\x98\x18\x9d\xf8<\xe4\xbc%\x0c"
)
self.assertEqual(
config.get_hash(), b"U\x8bi\xc2~PY\x98\x18\x9d\xf8<\xe4\xbc%\x0c"
)
self.assertEqual(config.get_hash(), b"\xf2C\xdbo\x99qq\x12\x11\xf7\xb4\xeewVpZ")
self.assertEqual(config.get_hash(), b"\xf2C\xdbo\x99qq\x12\x11\xf7\xb4\xeewVpZ")
config._hash_digest = "fake"
self.assertEqual(config.get_hash(), "fake")
config.e_bool = False
self.assertNotEqual(
config.get_hash(), b"U\x8bi\xc2~PY\x98\x18\x9d\xf8<\xe4\xbc%\x0c"
config.get_hash(), b"\xf2C\xdbo\x99qq\x12\x11\xf7\xb4\xeewVpZ"
)
config.e_bool = True
# Test ignored values
config.e_compile_ignored = False
self.assertEqual(
config.get_hash(), b"U\x8bi\xc2~PY\x98\x18\x9d\xf8<\xe4\xbc%\x0c"
)
self.assertEqual(config.get_hash(), b"\xf2C\xdbo\x99qq\x12\x11\xf7\xb4\xeewVpZ")
for k in config._config:
config._config[k].user_override = _UNSET_SENTINEL
@ -227,6 +232,7 @@ torch.testing._internal.fake_config_module._save_config_ignore = ['e_ignored']""
"e_env_default": True,
"e_env_default_FALSE": False,
"e_env_force": True,
"e_optional": True,
},
)
p2 = config.to_dict()
@ -255,6 +261,7 @@ torch.testing._internal.fake_config_module._save_config_ignore = ['e_ignored']""
"e_env_default": True,
"e_env_default_FALSE": False,
"e_env_force": True,
"e_optional": True,
},
)
p3 = config.get_config_copy()
@ -283,6 +290,7 @@ torch.testing._internal.fake_config_module._save_config_ignore = ['e_ignored']""
"e_env_default": True,
"e_env_default_FALSE": False,
"e_env_force": True,
"e_optional": True,
},
)