mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
Add type annotations to Configs (#139833)
Summary: Adds types to Configs, and fixes a bug in options that was caused by the lack of types. fixes: https://github.com/pytorch/pytorch/issues/139822 Configs are used by many modules so not sure which label to put. Types also allow https://github.com/pytorch/pytorch/pull/139736 to fuzz configs Pull Request resolved: https://github.com/pytorch/pytorch/pull/139833 Approved by: https://github.com/c00w
This commit is contained in:
committed by
PyTorch MergeBot
parent
5203138483
commit
2037ea3e15
@ -6,6 +6,8 @@ import pickle
|
||||
os.environ["ENV_TRUE"] = "1"
|
||||
os.environ["ENV_FALSE"] = "0"
|
||||
|
||||
from typing import Optional
|
||||
|
||||
from torch.testing._internal import fake_config_module as config
|
||||
from torch.testing._internal.common_utils import run_tests, TestCase
|
||||
from torch.utils._config_module import _UNSET_SENTINEL
|
||||
@ -15,6 +17,7 @@ class TestConfigModule(TestCase):
|
||||
def test_base_value_loading(self):
|
||||
self.assertTrue(config.e_bool)
|
||||
self.assertTrue(config.nested.e_bool)
|
||||
self.assertTrue(config.e_optional)
|
||||
self.assertEqual(config.e_int, 1)
|
||||
self.assertEqual(config.e_float, 1.0)
|
||||
self.assertEqual(config.e_string, "string")
|
||||
@ -28,6 +31,10 @@ class TestConfigModule(TestCase):
|
||||
):
|
||||
config.does_not_exist
|
||||
|
||||
def test_type_loading(self):
|
||||
self.assertEqual(config.get_type("e_optional"), Optional[bool])
|
||||
self.assertEqual(config.get_type("e_none"), Optional[bool])
|
||||
|
||||
def test_overrides(self):
|
||||
config.e_bool = False
|
||||
self.assertFalse(config.e_bool)
|
||||
@ -51,6 +58,10 @@ class TestConfigModule(TestCase):
|
||||
self.assertEqual(config.e_none, "not none")
|
||||
config.e_none = None
|
||||
self.assertEqual(config.e_none, None)
|
||||
config.e_optional = None
|
||||
self.assertEqual(config.e_optional, None)
|
||||
config.e_optional = False
|
||||
self.assertEqual(config.e_optional, False)
|
||||
with self.assertRaises(
|
||||
AttributeError, msg="fake_config_module.does_not_exist does not exist"
|
||||
):
|
||||
@ -112,6 +123,7 @@ class TestConfigModule(TestCase):
|
||||
"e_env_default": True,
|
||||
"e_env_default_FALSE": False,
|
||||
"e_env_force": True,
|
||||
"e_optional": True,
|
||||
},
|
||||
)
|
||||
config.e_bool = False
|
||||
@ -145,6 +157,7 @@ class TestConfigModule(TestCase):
|
||||
"e_env_default": True,
|
||||
"e_env_default_FALSE": False,
|
||||
"e_env_force": True,
|
||||
"e_optional": True,
|
||||
},
|
||||
)
|
||||
config.e_bool = False
|
||||
@ -173,30 +186,22 @@ torch.testing._internal.fake_config_module._save_config_ignore = ['e_ignored']""
|
||||
config._config[k].user_override = _UNSET_SENTINEL
|
||||
|
||||
def test_get_hash(self):
|
||||
self.assertEqual(
|
||||
config.get_hash(), b"U\x8bi\xc2~PY\x98\x18\x9d\xf8<\xe4\xbc%\x0c"
|
||||
)
|
||||
self.assertEqual(config.get_hash(), b"\xf2C\xdbo\x99qq\x12\x11\xf7\xb4\xeewVpZ")
|
||||
# Test cached value
|
||||
self.assertEqual(
|
||||
config.get_hash(), b"U\x8bi\xc2~PY\x98\x18\x9d\xf8<\xe4\xbc%\x0c"
|
||||
)
|
||||
self.assertEqual(
|
||||
config.get_hash(), b"U\x8bi\xc2~PY\x98\x18\x9d\xf8<\xe4\xbc%\x0c"
|
||||
)
|
||||
self.assertEqual(config.get_hash(), b"\xf2C\xdbo\x99qq\x12\x11\xf7\xb4\xeewVpZ")
|
||||
self.assertEqual(config.get_hash(), b"\xf2C\xdbo\x99qq\x12\x11\xf7\xb4\xeewVpZ")
|
||||
config._hash_digest = "fake"
|
||||
self.assertEqual(config.get_hash(), "fake")
|
||||
|
||||
config.e_bool = False
|
||||
self.assertNotEqual(
|
||||
config.get_hash(), b"U\x8bi\xc2~PY\x98\x18\x9d\xf8<\xe4\xbc%\x0c"
|
||||
config.get_hash(), b"\xf2C\xdbo\x99qq\x12\x11\xf7\xb4\xeewVpZ"
|
||||
)
|
||||
config.e_bool = True
|
||||
|
||||
# Test ignored values
|
||||
config.e_compile_ignored = False
|
||||
self.assertEqual(
|
||||
config.get_hash(), b"U\x8bi\xc2~PY\x98\x18\x9d\xf8<\xe4\xbc%\x0c"
|
||||
)
|
||||
self.assertEqual(config.get_hash(), b"\xf2C\xdbo\x99qq\x12\x11\xf7\xb4\xeewVpZ")
|
||||
for k in config._config:
|
||||
config._config[k].user_override = _UNSET_SENTINEL
|
||||
|
||||
@ -227,6 +232,7 @@ torch.testing._internal.fake_config_module._save_config_ignore = ['e_ignored']""
|
||||
"e_env_default": True,
|
||||
"e_env_default_FALSE": False,
|
||||
"e_env_force": True,
|
||||
"e_optional": True,
|
||||
},
|
||||
)
|
||||
p2 = config.to_dict()
|
||||
@ -255,6 +261,7 @@ torch.testing._internal.fake_config_module._save_config_ignore = ['e_ignored']""
|
||||
"e_env_default": True,
|
||||
"e_env_default_FALSE": False,
|
||||
"e_env_force": True,
|
||||
"e_optional": True,
|
||||
},
|
||||
)
|
||||
p3 = config.get_config_copy()
|
||||
@ -283,6 +290,7 @@ torch.testing._internal.fake_config_module._save_config_ignore = ['e_ignored']""
|
||||
"e_env_default": True,
|
||||
"e_env_default_FALSE": False,
|
||||
"e_env_force": True,
|
||||
"e_optional": True,
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user