mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-21 05:34:18 +08:00
Summary: This allows us to use environment variables to set string values. We've added tests for the specific functionality implemented here. Note that we already accidentally started setting up configs to use this, so we're just adding the feature. Additionally, we're not fully validating the underlying type when we set the value (and in general, it's more difficult than we would like to do this). Let me know if people feel strongly, and we can add a PR to do this. Pull Request resolved: https://github.com/pytorch/pytorch/pull/145980 Approved by: https://github.com/yushangdi, https://github.com/oulgen
396 lines
14 KiB
Python
396 lines
14 KiB
Python
# Owner(s): ["module: unknown"]
|
|
import os
|
|
import pickle
|
|
from unittest.mock import patch
|
|
|
|
|
|
os.environ["ENV_TRUE"] = "1"
|
|
os.environ["ENV_FALSE"] = "0"
|
|
os.environ["ENV_STR"] = "1234"
|
|
os.environ["ENV_STR_EMPTY"] = ""
|
|
|
|
from typing import Optional
|
|
|
|
from torch.testing._internal import (
|
|
fake_config_module as config,
|
|
fake_config_module2 as config2,
|
|
fake_config_module3 as config3,
|
|
)
|
|
from torch.testing._internal.common_utils import run_tests, TestCase
|
|
from torch.utils._config_module import _UNSET_SENTINEL, Config
|
|
|
|
|
|
class TestConfigModule(TestCase):
|
|
def tearDown(self):
|
|
# Config changes get persisted between test cases
|
|
for k in config._config:
|
|
config._config[k].user_override = _UNSET_SENTINEL
|
|
|
|
def test_base_value_loading(self):
|
|
self.assertTrue(config.e_bool)
|
|
self.assertTrue(config.nested.e_bool)
|
|
self.assertTrue(config.e_optional)
|
|
self.assertEqual(config.e_int, 1)
|
|
self.assertEqual(config.e_float, 1.0)
|
|
self.assertEqual(config.e_string, "string")
|
|
self.assertEqual(config.e_list, [1])
|
|
self.assertEqual(config.e_set, {1})
|
|
self.assertEqual(config.e_tuple, (1,))
|
|
self.assertEqual(config.e_dict, {1: 2})
|
|
self.assertEqual(config.e_none, None)
|
|
with self.assertRaises(
|
|
AttributeError, msg="fake_config_module.does_not_exist does not exist"
|
|
):
|
|
config.does_not_exist
|
|
|
|
def test_type_loading(self):
|
|
self.assertEqual(config.get_type("e_optional"), Optional[bool])
|
|
self.assertEqual(config.get_type("e_none"), Optional[bool])
|
|
|
|
def test_overrides(self):
|
|
config.e_bool = False
|
|
self.assertFalse(config.e_bool)
|
|
config.nested.e_bool = False
|
|
self.assertFalse(config.nested.e_bool)
|
|
config.e_int = 2
|
|
self.assertEqual(config.e_int, 2)
|
|
config.e_float = 2.0
|
|
self.assertEqual(config.e_float, 2.0)
|
|
config.e_string = "string2"
|
|
self.assertEqual(config.e_string, "string2")
|
|
config.e_list = [2]
|
|
self.assertEqual(config.e_list, [2])
|
|
config.e_set = {2}
|
|
self.assertEqual(config.e_set, {2})
|
|
config.e_tuple = (2,)
|
|
self.assertEqual(config.e_tuple, (2,))
|
|
config.e_dict = {2: 3}
|
|
self.assertEqual(config.e_dict, {2: 3})
|
|
config.e_none = "not none"
|
|
self.assertEqual(config.e_none, "not none")
|
|
config.e_none = None
|
|
self.assertEqual(config.e_none, None)
|
|
config.e_optional = None
|
|
self.assertEqual(config.e_optional, None)
|
|
config.e_optional = False
|
|
self.assertEqual(config.e_optional, False)
|
|
with self.assertRaises(
|
|
AttributeError, msg="fake_config_module.does_not_exist does not exist"
|
|
):
|
|
config.does_not_exist = 0
|
|
|
|
def test_none_override_semantics(self):
|
|
config.e_bool = None
|
|
self.assertIsNone(config.e_bool)
|
|
for k in config._config:
|
|
config._config[k].user_override = _UNSET_SENTINEL
|
|
|
|
def test_reference_semantics(self):
|
|
config.e_list.append(2)
|
|
self.assertEqual(config.e_list, [1, 2])
|
|
config.e_set.add(2)
|
|
self.assertEqual(config.e_set, {1, 2})
|
|
config.e_dict[2] = 3
|
|
self.assertEqual(config.e_dict, {1: 2, 2: 3})
|
|
|
|
def test_env_name_semantics(self):
|
|
self.assertTrue(config.e_env_default)
|
|
self.assertFalse(config.e_env_default_FALSE)
|
|
self.assertTrue(config.e_env_force)
|
|
config.e_env_default = False
|
|
self.assertFalse(config.e_env_default)
|
|
config.e_env_force = False
|
|
self.assertTrue(config.e_env_force)
|
|
|
|
def test_env_name_string_semantics(self):
|
|
self.assertEqual(config.e_env_default_str, "1234")
|
|
self.assertEqual(config.e_env_default_str_empty, "")
|
|
config.e_env_default_str = "override"
|
|
self.assertEqual(config.e_env_default_str, "override")
|
|
|
|
def test_multi_env(self):
|
|
self.assertTrue(config2.e_env_default_multi)
|
|
self.assertTrue(config2.e_env_force_multi)
|
|
|
|
def test_save_config(self):
|
|
p = config.save_config()
|
|
self.assertDictEqual(
|
|
pickle.loads(p),
|
|
{
|
|
"_cache_config_ignore_prefix": ["magic_cache_config"],
|
|
"e_bool": True,
|
|
"e_dict": {1: 2},
|
|
"e_float": 1.0,
|
|
"e_int": 1,
|
|
"e_list": [1],
|
|
"e_none": None,
|
|
"e_set": {1},
|
|
"e_string": "string",
|
|
"e_tuple": (1,),
|
|
"nested.e_bool": True,
|
|
"_e_ignored": True,
|
|
"e_compile_ignored": True,
|
|
"magic_cache_config_ignored": True,
|
|
"_save_config_ignore": ["e_ignored"],
|
|
"e_config": True,
|
|
"e_jk": True,
|
|
"e_jk_false": False,
|
|
"e_env_default": True,
|
|
"e_env_default_FALSE": False,
|
|
"e_env_default_str": "1234",
|
|
"e_env_default_str_empty": "",
|
|
"e_env_force": True,
|
|
"e_optional": True,
|
|
},
|
|
)
|
|
config.e_bool = False
|
|
config.e_ignored = False
|
|
config.load_config(p)
|
|
self.assertTrue(config.e_bool)
|
|
self.assertFalse(config.e_ignored)
|
|
|
|
def test_save_config_portable(self):
|
|
p = config.save_config_portable()
|
|
self.assertDictEqual(
|
|
p,
|
|
{
|
|
"e_bool": True,
|
|
"e_dict": {1: 2},
|
|
"e_float": 1.0,
|
|
"e_int": 1,
|
|
"e_list": [1],
|
|
"e_none": None,
|
|
"e_set": {1},
|
|
"e_string": "string",
|
|
"e_tuple": (1,),
|
|
"nested.e_bool": True,
|
|
"e_ignored": True,
|
|
"e_compile_ignored": True,
|
|
"e_config": True,
|
|
"e_jk": True,
|
|
"e_jk_false": False,
|
|
"e_env_default": True,
|
|
"e_env_default_FALSE": False,
|
|
"e_env_default_str": "1234",
|
|
"e_env_default_str_empty": "",
|
|
"e_env_force": True,
|
|
"e_optional": True,
|
|
},
|
|
)
|
|
config.e_bool = False
|
|
config._e_ignored = False
|
|
config.load_config(p)
|
|
self.assertTrue(config.e_bool)
|
|
self.assertFalse(config._e_ignored)
|
|
|
|
def test_codegen_config(self):
|
|
config.e_bool = False
|
|
config.e_ignored = False
|
|
code = config.codegen_config()
|
|
self.assertEqual(
|
|
code,
|
|
"""torch.testing._internal.fake_config_module.e_bool = False
|
|
torch.testing._internal.fake_config_module.e_env_default = True
|
|
torch.testing._internal.fake_config_module.e_env_default_FALSE = False
|
|
torch.testing._internal.fake_config_module.e_env_default_str = '1234'
|
|
torch.testing._internal.fake_config_module.e_env_default_str_empty = ''
|
|
torch.testing._internal.fake_config_module.e_env_force = True
|
|
torch.testing._internal.fake_config_module._save_config_ignore = ['e_ignored']""",
|
|
)
|
|
|
|
def test_codegen_config_function(self):
|
|
import logging
|
|
import warnings
|
|
|
|
config3.e_list = [print, warnings.warn, logging.warn]
|
|
config3.e_set = {print}
|
|
config3.e_func = warnings.warn
|
|
code = config3.codegen_config()
|
|
self.assertIn("import _warnings", code)
|
|
self.assertIn("import logging", code)
|
|
self.assertIn(
|
|
"""torch.testing._internal.fake_config_module3.e_list = ['print', '_warnings.warn', 'logging.warn']
|
|
torch.testing._internal.fake_config_module3.e_set = { print }
|
|
torch.testing._internal.fake_config_module3.e_func = _warnings.warn""",
|
|
code,
|
|
)
|
|
|
|
def test_get_hash(self):
|
|
hash_value = b"\x87\xf7\xc6\x1di\x7f\x96-\x85\xdc\x04\xd5\xd0\xf6\x1c\x87"
|
|
self.assertEqual(
|
|
config.get_hash(),
|
|
hash_value,
|
|
)
|
|
# Test cached value
|
|
self.assertEqual(
|
|
config.get_hash(),
|
|
hash_value,
|
|
)
|
|
self.assertEqual(
|
|
config.get_hash(),
|
|
hash_value,
|
|
)
|
|
config._hash_digest = "fake"
|
|
self.assertEqual(config.get_hash(), "fake")
|
|
|
|
config.e_bool = False
|
|
self.assertNotEqual(
|
|
config.get_hash(),
|
|
hash_value,
|
|
)
|
|
config.e_bool = True
|
|
|
|
# Test ignored values
|
|
config.e_compile_ignored = False
|
|
self.assertEqual(
|
|
config.get_hash(),
|
|
hash_value,
|
|
)
|
|
|
|
def test_dict_copy_semantics(self):
|
|
p = config.shallow_copy_dict()
|
|
self.assertDictEqual(
|
|
p,
|
|
{
|
|
"e_bool": True,
|
|
"e_dict": {1: 2},
|
|
"e_float": 1.0,
|
|
"e_int": 1,
|
|
"e_list": [1],
|
|
"e_none": None,
|
|
"e_set": {1},
|
|
"e_string": "string",
|
|
"e_tuple": (1,),
|
|
"nested.e_bool": True,
|
|
"e_ignored": True,
|
|
"_e_ignored": True,
|
|
"e_compile_ignored": True,
|
|
"_cache_config_ignore_prefix": ["magic_cache_config"],
|
|
"_save_config_ignore": ["e_ignored"],
|
|
"magic_cache_config_ignored": True,
|
|
"e_config": True,
|
|
"e_jk": True,
|
|
"e_jk_false": False,
|
|
"e_env_default": True,
|
|
"e_env_default_FALSE": False,
|
|
"e_env_default_str": "1234",
|
|
"e_env_default_str_empty": "",
|
|
"e_env_force": True,
|
|
"e_optional": True,
|
|
},
|
|
)
|
|
p2 = config.to_dict()
|
|
self.assertEqual(
|
|
p2,
|
|
{
|
|
"e_bool": True,
|
|
"e_dict": {1: 2},
|
|
"e_float": 1.0,
|
|
"e_int": 1,
|
|
"e_list": [1],
|
|
"e_none": None,
|
|
"e_set": {1},
|
|
"e_string": "string",
|
|
"e_tuple": (1,),
|
|
"nested.e_bool": True,
|
|
"e_ignored": True,
|
|
"_e_ignored": True,
|
|
"e_compile_ignored": True,
|
|
"_cache_config_ignore_prefix": ["magic_cache_config"],
|
|
"_save_config_ignore": ["e_ignored"],
|
|
"magic_cache_config_ignored": True,
|
|
"e_config": True,
|
|
"e_jk": True,
|
|
"e_jk_false": False,
|
|
"e_env_default": True,
|
|
"e_env_default_FALSE": False,
|
|
"e_env_default_str": "1234",
|
|
"e_env_default_str_empty": "",
|
|
"e_env_force": True,
|
|
"e_optional": True,
|
|
},
|
|
)
|
|
p3 = config.get_config_copy()
|
|
self.assertEqual(
|
|
p3,
|
|
{
|
|
"e_bool": True,
|
|
"e_dict": {1: 2},
|
|
"e_float": 1.0,
|
|
"e_int": 1,
|
|
"e_list": [1],
|
|
"e_none": None,
|
|
"e_set": {1},
|
|
"e_string": "string",
|
|
"e_tuple": (1,),
|
|
"nested.e_bool": True,
|
|
"e_ignored": True,
|
|
"_e_ignored": True,
|
|
"e_compile_ignored": True,
|
|
"_cache_config_ignore_prefix": ["magic_cache_config"],
|
|
"_save_config_ignore": ["e_ignored"],
|
|
"magic_cache_config_ignored": True,
|
|
"e_config": True,
|
|
"e_jk": True,
|
|
"e_jk_false": False,
|
|
"e_env_default": True,
|
|
"e_env_default_FALSE": False,
|
|
"e_env_default_str": "1234",
|
|
"e_env_default_str_empty": "",
|
|
"e_env_force": True,
|
|
"e_optional": True,
|
|
},
|
|
)
|
|
|
|
# Shallow + deep copy semantics
|
|
config.e_dict[2] = 3
|
|
self.assertEqual(p["e_dict"], {1: 2})
|
|
self.assertEqual(p2["e_dict"], {1: 2})
|
|
self.assertEqual(p3["e_dict"], {1: 2})
|
|
|
|
def test_patch(self):
|
|
self.assertTrue(config.e_bool)
|
|
with config.patch("e_bool", False):
|
|
self.assertFalse(config.e_bool)
|
|
self.assertTrue(config.e_bool)
|
|
with config.patch(e_bool=False):
|
|
self.assertFalse(config.e_bool)
|
|
self.assertTrue(config.e_bool)
|
|
with self.assertRaises(AssertionError):
|
|
with config.patch("does_not_exist"):
|
|
pass
|
|
|
|
def test_make_closur_patcher(self):
|
|
revert = config._make_closure_patcher(e_bool=False)()
|
|
self.assertFalse(config.e_bool)
|
|
revert()
|
|
self.assertTrue(config.e_bool)
|
|
|
|
def test_unittest_patch(self):
|
|
with patch("torch.testing._internal.fake_config_module.e_bool", False):
|
|
with patch("torch.testing._internal.fake_config_module.e_bool", False):
|
|
self.assertFalse(config.e_bool)
|
|
# unittest.mock has some very weird semantics around deletion of attributes when undoing patches
|
|
self.assertFalse(config.e_bool)
|
|
self.assertTrue(config.e_bool)
|
|
|
|
def test_bad_jk_type(self):
|
|
with self.assertRaises(
|
|
AssertionError,
|
|
msg="AssertionError: justknobs only support booleans, thisisnotvalid is not a boolean",
|
|
):
|
|
Config(default="bad", justknob="fake_knob")
|
|
|
|
def test_alias(self):
|
|
self.assertFalse(config2.e_aliasing_bool)
|
|
self.assertFalse(config.e_aliased_bool)
|
|
with config2.patch(e_aliasing_bool=True):
|
|
self.assertTrue(config2.e_aliasing_bool)
|
|
self.assertTrue(config.e_aliased_bool)
|
|
with config.patch(e_aliased_bool=True):
|
|
self.assertTrue(config2.e_aliasing_bool)
|
|
|
|
|
|
if __name__ == "__main__":
|
|
run_tests()
|