Files
pytorch/test/test_utils_config_module.py
Raymond Li dd349207c5 Add check that envvar configs are boolean (#145454)
So we don't get unexpected behavior when higher typed values are passed in
Pull Request resolved: https://github.com/pytorch/pytorch/pull/145454
Approved by: https://github.com/c00w, https://github.com/jamesjwu
2025-02-05 19:40:10 +00:00

413 lines
14 KiB
Python

# Owner(s): ["module: unknown"]
import os
import pickle
from unittest.mock import patch
os.environ["ENV_TRUE"] = "1"
os.environ["ENV_FALSE"] = "0"
os.environ["ENV_STR"] = "1234"
os.environ["ENV_STR_EMPTY"] = ""
from typing import Optional
from torch.testing._internal import (
fake_config_module as config,
fake_config_module2 as config2,
fake_config_module3 as config3,
)
from torch.testing._internal.common_utils import run_tests, TestCase
from torch.utils._config_module import _ConfigEntry, _UNSET_SENTINEL, Config
class TestConfigModule(TestCase):
def tearDown(self):
# Config changes get persisted between test cases
for k in config._config:
config._config[k].user_override = _UNSET_SENTINEL
def test_base_value_loading(self):
self.assertTrue(config.e_bool)
self.assertTrue(config.nested.e_bool)
self.assertTrue(config.e_optional)
self.assertEqual(config.e_int, 1)
self.assertEqual(config.e_float, 1.0)
self.assertEqual(config.e_string, "string")
self.assertEqual(config.e_list, [1])
self.assertEqual(config.e_set, {1})
self.assertEqual(config.e_tuple, (1,))
self.assertEqual(config.e_dict, {1: 2})
self.assertEqual(config.e_none, None)
with self.assertRaises(
AttributeError, msg="fake_config_module.does_not_exist does not exist"
):
config.does_not_exist
def test_type_loading(self):
self.assertEqual(config.get_type("e_optional"), Optional[bool])
self.assertEqual(config.get_type("e_none"), Optional[bool])
def test_overrides(self):
config.e_bool = False
self.assertFalse(config.e_bool)
config.nested.e_bool = False
self.assertFalse(config.nested.e_bool)
config.e_int = 2
self.assertEqual(config.e_int, 2)
config.e_float = 2.0
self.assertEqual(config.e_float, 2.0)
config.e_string = "string2"
self.assertEqual(config.e_string, "string2")
config.e_list = [2]
self.assertEqual(config.e_list, [2])
config.e_set = {2}
self.assertEqual(config.e_set, {2})
config.e_tuple = (2,)
self.assertEqual(config.e_tuple, (2,))
config.e_dict = {2: 3}
self.assertEqual(config.e_dict, {2: 3})
config.e_none = "not none"
self.assertEqual(config.e_none, "not none")
config.e_none = None
self.assertEqual(config.e_none, None)
config.e_optional = None
self.assertEqual(config.e_optional, None)
config.e_optional = False
self.assertEqual(config.e_optional, False)
with self.assertRaises(
AttributeError, msg="fake_config_module.does_not_exist does not exist"
):
config.does_not_exist = 0
def test_none_override_semantics(self):
config.e_bool = None
self.assertIsNone(config.e_bool)
for k in config._config:
config._config[k].user_override = _UNSET_SENTINEL
def test_reference_semantics(self):
config.e_list.append(2)
self.assertEqual(config.e_list, [1, 2])
config.e_set.add(2)
self.assertEqual(config.e_set, {1, 2})
config.e_dict[2] = 3
self.assertEqual(config.e_dict, {1: 2, 2: 3})
def test_env_name_semantics(self):
self.assertTrue(config.e_env_default)
self.assertFalse(config.e_env_default_FALSE)
self.assertTrue(config.e_env_force)
config.e_env_default = False
self.assertFalse(config.e_env_default)
config.e_env_force = False
self.assertTrue(config.e_env_force)
def test_env_name_string_semantics(self):
self.assertEqual(config.e_env_default_str, "1234")
self.assertEqual(config.e_env_default_str_empty, "")
config.e_env_default_str = "override"
self.assertEqual(config.e_env_default_str, "override")
def test_multi_env(self):
self.assertTrue(config2.e_env_default_multi)
self.assertTrue(config2.e_env_force_multi)
def test_save_config(self):
p = config.save_config()
self.assertDictEqual(
pickle.loads(p),
{
"_cache_config_ignore_prefix": ["magic_cache_config"],
"e_bool": True,
"e_dict": {1: 2},
"e_float": 1.0,
"e_int": 1,
"e_list": [1],
"e_none": None,
"e_set": {1},
"e_string": "string",
"e_tuple": (1,),
"nested.e_bool": True,
"_e_ignored": True,
"e_compile_ignored": True,
"magic_cache_config_ignored": True,
"_save_config_ignore": ["e_ignored"],
"e_config": True,
"e_jk": True,
"e_jk_false": False,
"e_env_default": True,
"e_env_default_FALSE": False,
"e_env_default_str": "1234",
"e_env_default_str_empty": "",
"e_env_force": True,
"e_optional": True,
},
)
config.e_bool = False
config.e_ignored = False
config.load_config(p)
self.assertTrue(config.e_bool)
self.assertFalse(config.e_ignored)
def test_save_config_portable(self):
p = config.save_config_portable()
self.assertDictEqual(
p,
{
"e_bool": True,
"e_dict": {1: 2},
"e_float": 1.0,
"e_int": 1,
"e_list": [1],
"e_none": None,
"e_set": {1},
"e_string": "string",
"e_tuple": (1,),
"nested.e_bool": True,
"e_ignored": True,
"e_compile_ignored": True,
"e_config": True,
"e_jk": True,
"e_jk_false": False,
"e_env_default": True,
"e_env_default_FALSE": False,
"e_env_default_str": "1234",
"e_env_default_str_empty": "",
"e_env_force": True,
"e_optional": True,
},
)
config.e_bool = False
config._e_ignored = False
config.load_config(p)
self.assertTrue(config.e_bool)
self.assertFalse(config._e_ignored)
def test_codegen_config(self):
config.e_bool = False
config.e_ignored = False
code = config.codegen_config()
self.assertEqual(
code,
"""torch.testing._internal.fake_config_module.e_bool = False
torch.testing._internal.fake_config_module.e_env_default = True
torch.testing._internal.fake_config_module.e_env_default_FALSE = False
torch.testing._internal.fake_config_module.e_env_default_str = '1234'
torch.testing._internal.fake_config_module.e_env_default_str_empty = ''
torch.testing._internal.fake_config_module.e_env_force = True""",
)
def test_codegen_config_function(self):
import logging
import warnings
config3.e_list = [print, warnings.warn, logging.warn]
config3.e_set = {print}
config3.e_func = warnings.warn
code = config3.codegen_config()
self.assertIn("import _warnings", code)
self.assertIn("import logging", code)
self.assertIn(
"""torch.testing._internal.fake_config_module3.e_list = ['print', '_warnings.warn', 'logging.warn']
torch.testing._internal.fake_config_module3.e_set = { print }
torch.testing._internal.fake_config_module3.e_func = _warnings.warn""",
code,
)
def test_get_hash(self):
hash_value = b"\x87\xf7\xc6\x1di\x7f\x96-\x85\xdc\x04\xd5\xd0\xf6\x1c\x87"
self.assertEqual(
config.get_hash(),
hash_value,
)
# Test cached value
self.assertEqual(
config.get_hash(),
hash_value,
)
self.assertEqual(
config.get_hash(),
hash_value,
)
config._hash_digest = "fake"
self.assertEqual(config.get_hash(), "fake")
config.e_bool = False
self.assertNotEqual(
config.get_hash(),
hash_value,
)
config.e_bool = True
# Test ignored values
config.e_compile_ignored = False
self.assertEqual(
config.get_hash(),
hash_value,
)
def test_dict_copy_semantics(self):
p = config.shallow_copy_dict()
self.assertDictEqual(
p,
{
"e_bool": True,
"e_dict": {1: 2},
"e_float": 1.0,
"e_int": 1,
"e_list": [1],
"e_none": None,
"e_set": {1},
"e_string": "string",
"e_tuple": (1,),
"nested.e_bool": True,
"e_ignored": True,
"_e_ignored": True,
"e_compile_ignored": True,
"_cache_config_ignore_prefix": ["magic_cache_config"],
"_save_config_ignore": ["e_ignored"],
"magic_cache_config_ignored": True,
"e_config": True,
"e_jk": True,
"e_jk_false": False,
"e_env_default": True,
"e_env_default_FALSE": False,
"e_env_default_str": "1234",
"e_env_default_str_empty": "",
"e_env_force": True,
"e_optional": True,
},
)
p2 = config.to_dict()
self.assertEqual(
p2,
{
"e_bool": True,
"e_dict": {1: 2},
"e_float": 1.0,
"e_int": 1,
"e_list": [1],
"e_none": None,
"e_set": {1},
"e_string": "string",
"e_tuple": (1,),
"nested.e_bool": True,
"e_ignored": True,
"_e_ignored": True,
"e_compile_ignored": True,
"_cache_config_ignore_prefix": ["magic_cache_config"],
"_save_config_ignore": ["e_ignored"],
"magic_cache_config_ignored": True,
"e_config": True,
"e_jk": True,
"e_jk_false": False,
"e_env_default": True,
"e_env_default_FALSE": False,
"e_env_default_str": "1234",
"e_env_default_str_empty": "",
"e_env_force": True,
"e_optional": True,
},
)
p3 = config.get_config_copy()
self.assertEqual(
p3,
{
"e_bool": True,
"e_dict": {1: 2},
"e_float": 1.0,
"e_int": 1,
"e_list": [1],
"e_none": None,
"e_set": {1},
"e_string": "string",
"e_tuple": (1,),
"nested.e_bool": True,
"e_ignored": True,
"_e_ignored": True,
"e_compile_ignored": True,
"_cache_config_ignore_prefix": ["magic_cache_config"],
"_save_config_ignore": ["e_ignored"],
"magic_cache_config_ignored": True,
"e_config": True,
"e_jk": True,
"e_jk_false": False,
"e_env_default": True,
"e_env_default_FALSE": False,
"e_env_default_str": "1234",
"e_env_default_str_empty": "",
"e_env_force": True,
"e_optional": True,
},
)
# Shallow + deep copy semantics
config.e_dict[2] = 3
self.assertEqual(p["e_dict"], {1: 2})
self.assertEqual(p2["e_dict"], {1: 2})
self.assertEqual(p3["e_dict"], {1: 2})
def test_patch(self):
self.assertTrue(config.e_bool)
with config.patch("e_bool", False):
self.assertFalse(config.e_bool)
self.assertTrue(config.e_bool)
with config.patch(e_bool=False):
self.assertFalse(config.e_bool)
self.assertTrue(config.e_bool)
with self.assertRaises(AssertionError):
with config.patch("does_not_exist"):
pass
def test_make_closur_patcher(self):
revert = config._make_closure_patcher(e_bool=False)()
self.assertFalse(config.e_bool)
revert()
self.assertTrue(config.e_bool)
def test_unittest_patch(self):
with patch("torch.testing._internal.fake_config_module.e_bool", False):
with patch("torch.testing._internal.fake_config_module.e_bool", False):
self.assertFalse(config.e_bool)
# unittest.mock has some very weird semantics around deletion of attributes when undoing patches
self.assertFalse(config.e_bool)
self.assertTrue(config.e_bool)
def test_bad_jk_type(self):
with self.assertRaises(
AssertionError,
msg="AssertionError: justknobs only support booleans, thisisnotvalid is not a boolean",
):
_ConfigEntry(Config(default="bad", justknob="fake_knob"))
def test_alias(self):
self.assertFalse(config2.e_aliasing_bool)
self.assertFalse(config.e_aliased_bool)
with config2.patch(e_aliasing_bool=True):
self.assertTrue(config2.e_aliasing_bool)
self.assertTrue(config.e_aliased_bool)
with config.patch(e_aliased_bool=True):
self.assertTrue(config2.e_aliasing_bool)
def test_reference_is_default(self):
t = config.e_dict
self.assertTrue(config._is_default("e_dict"))
t["a"] = "b"
self.assertFalse(config._is_default("e_dict"))
def test_invalid_config_int(self):
with self.assertRaises(AssertionError):
_ConfigEntry(
Config(default=2, env_name_default="FAKE_DISABLE", value_type=int)
)
def test_invalid_config_float(self):
with self.assertRaises(AssertionError):
_ConfigEntry(
Config(default=2, env_name_force="FAKE_DISABLE", value_type=float)
)
if __name__ == "__main__":
run_tests()