mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 12:54:11 +08:00
config: Modify install_config_module to use a layered approach (#138758)
This modifies the config system, to use a single mapping of config -> ConfigEntry and to store the default and user values within them. We could have used multiple dicts (i.e. user_override and default), but as we add more fields (justknobs in this PR, perhaps testing and env variables later), it quickly becomes painful. There are a couple design decisions we could change. 1) All configs we save store the resolved value - not the default and user override seperately 2) All configs we load, apply the resolved value as a user override. This means that certain complexities of default behvaiour and deletion (as well as JK), will change if you save + load a config. Pull Request resolved: https://github.com/pytorch/pytorch/pull/138758 Approved by: https://github.com/ezyang
This commit is contained in:
committed by
PyTorch MergeBot
parent
46d0b635b9
commit
a0e095dd9f
@ -120,7 +120,7 @@ class TestInductorConfig(TestCase):
|
||||
def test_get_compiler_config(self):
|
||||
from torch._inductor import config as inductor_default_config
|
||||
|
||||
default_cudagraphs = inductor_default_config._default["triton.cudagraphs"]
|
||||
default_cudagraphs = inductor_default_config.triton.cudagraphs
|
||||
|
||||
# nn.Module: should update default config with a new value
|
||||
model = DummyModule()
|
||||
|
@ -3,6 +3,7 @@ import pickle
|
||||
|
||||
from torch.testing._internal import fake_config_module as config
|
||||
from torch.testing._internal.common_utils import run_tests, TestCase
|
||||
from torch.utils._config_module import _UNSET_SENTINEL
|
||||
|
||||
|
||||
class TestConfigModule(TestCase):
|
||||
@ -50,26 +51,24 @@ class TestConfigModule(TestCase):
|
||||
):
|
||||
config.does_not_exist = 0
|
||||
# Config changes get persisted between test cases
|
||||
config.e_bool = True
|
||||
config.nested.e_bool = True
|
||||
config.e_int = 1
|
||||
config.e_float = 1.0
|
||||
config.e_string = "string"
|
||||
config.e_list = [1]
|
||||
config.e_set = {1}
|
||||
config.e_tuple = (1,)
|
||||
config.e_dict = {1: 2}
|
||||
config.e_none = None
|
||||
for k in config._config:
|
||||
config._config[k].user_override = _UNSET_SENTINEL
|
||||
|
||||
def test_delete(self):
|
||||
self.assertTrue(config.e_bool)
|
||||
del config.e_bool
|
||||
with self.assertRaises(
|
||||
AttributeError, msg="fake_config_module.e_bool does not exist"
|
||||
):
|
||||
print(config.e_bool)
|
||||
# Config changes get persisted between test cases
|
||||
config.e_bool = True
|
||||
def test_none_override_semantics(self):
|
||||
config.e_bool = None
|
||||
self.assertIsNone(config.e_bool)
|
||||
for k in config._config:
|
||||
config._config[k].user_override = _UNSET_SENTINEL
|
||||
|
||||
def test_reference_semantics(self):
|
||||
config.e_list.append(2)
|
||||
self.assertEqual(config.e_list, [1, 2])
|
||||
config.e_set.add(2)
|
||||
self.assertEqual(config.e_set, {1, 2})
|
||||
config.e_dict[2] = 3
|
||||
self.assertEqual(config.e_dict, {1: 2, 2: 3})
|
||||
for k in config._config:
|
||||
config._config[k].user_override = _UNSET_SENTINEL
|
||||
|
||||
def test_save_config(self):
|
||||
p = config.save_config()
|
||||
@ -98,8 +97,8 @@ class TestConfigModule(TestCase):
|
||||
config.load_config(p)
|
||||
self.assertTrue(config.e_bool)
|
||||
self.assertFalse(config.e_ignored)
|
||||
# Config changes get persisted between test cases
|
||||
config.e_ignored = True
|
||||
for k in config._config:
|
||||
config._config[k].user_override = _UNSET_SENTINEL
|
||||
|
||||
def test_save_config_portable(self):
|
||||
p = config.save_config_portable()
|
||||
@ -126,18 +125,24 @@ class TestConfigModule(TestCase):
|
||||
self.assertTrue(config.e_bool)
|
||||
self.assertFalse(config._e_ignored)
|
||||
# Config changes get persisted between test cases
|
||||
config._e_ignored = True
|
||||
for k in config._config:
|
||||
config._config[k].user_override = _UNSET_SENTINEL
|
||||
|
||||
def test_codegen_config(self):
|
||||
config.e_bool = False
|
||||
config.e_ignored = False
|
||||
code = config.codegen_config()
|
||||
self.assertEqual(
|
||||
code, "torch.testing._internal.fake_config_module.e_bool = False"
|
||||
code,
|
||||
"""torch.testing._internal.fake_config_module.e_bool = False
|
||||
torch.testing._internal.fake_config_module.e_list = [1]
|
||||
torch.testing._internal.fake_config_module.e_set = {1}
|
||||
torch.testing._internal.fake_config_module.e_dict = {1: 2}
|
||||
torch.testing._internal.fake_config_module._save_config_ignore = ['e_ignored']""",
|
||||
)
|
||||
# Config changes get persisted between test cases
|
||||
config.e_bool = True
|
||||
config.e_ignored = True
|
||||
for k in config._config:
|
||||
config._config[k].user_override = _UNSET_SENTINEL
|
||||
|
||||
def test_get_hash(self):
|
||||
self.assertEqual(
|
||||
@ -165,7 +170,8 @@ class TestConfigModule(TestCase):
|
||||
self.assertEqual(
|
||||
config.get_hash(), b"\xcd\x96\x93\xf5(\xf8(\xa5\x1c+O\n\xd3_\x0b\xa6"
|
||||
)
|
||||
config.e_compile_ignored = True
|
||||
for k in config._config:
|
||||
config._config[k].user_override = _UNSET_SENTINEL
|
||||
|
||||
def test_dict_copy_semantics(self):
|
||||
p = config.shallow_copy_dict()
|
||||
@ -240,9 +246,11 @@ class TestConfigModule(TestCase):
|
||||
self.assertEqual(p["e_dict"], {1: 2})
|
||||
self.assertEqual(p2["e_dict"], {1: 2})
|
||||
self.assertEqual(p3["e_dict"], {1: 2})
|
||||
config.e_dict = {1: 2}
|
||||
for k in config._config:
|
||||
config._config[k].user_override = _UNSET_SENTINEL
|
||||
|
||||
def test_patch(self):
|
||||
self.assertTrue(config.e_bool)
|
||||
with config.patch("e_bool", False):
|
||||
self.assertFalse(config.e_bool)
|
||||
self.assertTrue(config.e_bool)
|
||||
|
@ -7,6 +7,7 @@ import pickle
|
||||
import tokenize
|
||||
import unittest
|
||||
import warnings
|
||||
from dataclasses import dataclass
|
||||
from types import FunctionType, ModuleType
|
||||
from typing import Any, Callable, Dict, List, NoReturn, Optional, Set, Union
|
||||
from typing_extensions import deprecated
|
||||
@ -43,8 +44,7 @@ def install_config_module(module: ModuleType) -> None:
|
||||
|
||||
name = f"{prefix}{key}"
|
||||
if isinstance(value, CONFIG_TYPES):
|
||||
config[name] = value
|
||||
default[name] = value
|
||||
config[name] = _ConfigEntry(default=value)
|
||||
if dest is module:
|
||||
delattr(module, key)
|
||||
elif isinstance(value, type):
|
||||
@ -59,15 +59,12 @@ def install_config_module(module: ModuleType) -> None:
|
||||
else:
|
||||
raise AssertionError(f"Unhandled config {key}={value} ({type(value)})")
|
||||
|
||||
config: Dict[str, Any] = {}
|
||||
default: Dict[str, Any] = {}
|
||||
config: Dict[str, _ConfigEntry] = {}
|
||||
|
||||
compile_ignored_keys = get_assignments_with_compile_ignored_comments(module)
|
||||
|
||||
visit(module, module, "")
|
||||
module._config = config # type: ignore[attr-defined]
|
||||
module._default = default # type: ignore[attr-defined]
|
||||
module._allowed_keys = set(config.keys()) # type: ignore[attr-defined]
|
||||
module._compile_ignored_keys = compile_ignored_keys # type: ignore[attr-defined]
|
||||
module.__class__ = ConfigModuleInstance
|
||||
module._is_dirty = True # type: ignore[attr-defined]
|
||||
@ -116,17 +113,25 @@ def get_assignments_with_compile_ignored_comments(module: ModuleType) -> Set[str
|
||||
return assignments
|
||||
|
||||
|
||||
_UNSET_SENTINEL = object()
|
||||
|
||||
|
||||
@dataclass
|
||||
class _ConfigEntry:
|
||||
# The default value specified in the configuration
|
||||
default: Any
|
||||
# The value specified by the user when they overrode the configuration
|
||||
# _UNSET_SENTINEL indicates the value is not set.
|
||||
user_override: Any = _UNSET_SENTINEL
|
||||
|
||||
|
||||
class ConfigModule(ModuleType):
|
||||
# NOTE: This should be kept in sync with _config_typing.pyi.
|
||||
|
||||
# The default values of the configuration settings. This can be used to
|
||||
# determine if the config has been changed or not.
|
||||
_default: Dict[str, Any]
|
||||
# The actual configuration settings. E.g., torch._dynamo.config.debug
|
||||
# would live as "debug" in the key, and torch._inductor.config.triton.cudagraphs
|
||||
# maps as "triton.cudagraphs"
|
||||
_config: Dict[str, Any]
|
||||
_allowed_keys: Set[str]
|
||||
# maps as "triton.cudagraphs". See discussion on the class for meaning of various sub items
|
||||
_config: Dict[str, _ConfigEntry]
|
||||
_bypass_keys: Set[str]
|
||||
_compile_ignored_keys: Set[str]
|
||||
_is_dirty: bool
|
||||
@ -140,15 +145,25 @@ class ConfigModule(ModuleType):
|
||||
def __setattr__(self, name: str, value: object) -> None:
|
||||
if name in self._bypass_keys:
|
||||
super().__setattr__(name, value)
|
||||
elif name not in self._allowed_keys:
|
||||
elif name not in self._config:
|
||||
raise AttributeError(f"{self.__name__}.{name} does not exist")
|
||||
else:
|
||||
self._config[name] = value
|
||||
self._config[name].user_override = value
|
||||
self._is_dirty = True
|
||||
|
||||
def __getattr__(self, name: str) -> Any:
|
||||
try:
|
||||
return self._config[name]
|
||||
config = self._config[name]
|
||||
if config.user_override is not _UNSET_SENTINEL:
|
||||
return config.user_override
|
||||
|
||||
# Note that reference types can still be modified, so we
|
||||
# copy them to user_overrides in case the user overrides
|
||||
# them
|
||||
if isinstance(config.default, (list, set, dict)):
|
||||
config.user_override = copy.deepcopy(config.default)
|
||||
return config.user_override
|
||||
return config.default
|
||||
except KeyError as e:
|
||||
# make hasattr() work properly
|
||||
raise AttributeError(f"{self.__name__}.{name} does not exist") from e
|
||||
@ -157,7 +172,10 @@ class ConfigModule(ModuleType):
|
||||
self._is_dirty = True
|
||||
# must support delete because unittest.mock.patch deletes
|
||||
# then recreate things
|
||||
del self._config[name]
|
||||
self._config[name].user_override = _UNSET_SENTINEL
|
||||
|
||||
def _is_default(self, name: str) -> bool:
|
||||
return self._config[name].user_override is _UNSET_SENTINEL
|
||||
|
||||
def _get_dict(
|
||||
self,
|
||||
@ -184,7 +202,7 @@ class ConfigModule(ModuleType):
|
||||
config: Dict[str, Any] = {}
|
||||
for key in self._config:
|
||||
if ignored_keys and key in ignored_keys:
|
||||
if skip_default and self._config[key] != self._default[key]:
|
||||
if skip_default and not self._is_default(key):
|
||||
warnings.warn(
|
||||
f"Skipping serialization of {key} value {self._config[key]}"
|
||||
)
|
||||
@ -192,22 +210,23 @@ class ConfigModule(ModuleType):
|
||||
if ignored_prefixes:
|
||||
if any(key.startswith(prefix) for prefix in ignored_prefixes):
|
||||
continue
|
||||
if skip_default and self._config[key] == self._default[key]:
|
||||
if skip_default and self._is_default(key):
|
||||
continue
|
||||
config[key] = copy.deepcopy(self._config[key])
|
||||
config[key] = copy.deepcopy(getattr(self, key))
|
||||
return config
|
||||
|
||||
def save_config(self) -> bytes:
|
||||
"""Convert config to a pickled blob"""
|
||||
ignored_keys = getattr(self, "_save_config_ignore", [])
|
||||
return pickle.dumps(
|
||||
self._get_dict(ignored_keys=self._config.get("_save_config_ignore", ())),
|
||||
self._get_dict(ignored_keys=ignored_keys),
|
||||
protocol=2,
|
||||
)
|
||||
|
||||
def save_config_portable(self) -> Dict[str, Any]:
|
||||
"""Convert config to portable format"""
|
||||
prefixes = ["_"]
|
||||
prefixes.extend(self._config["_cache_config_ignore_prefix"])
|
||||
prefixes.extend(getattr(self, "_cache_config_ignore_prefix", []))
|
||||
return self._get_dict(ignored_prefixes=prefixes)
|
||||
|
||||
def codegen_config(self) -> str:
|
||||
@ -217,7 +236,7 @@ class ConfigModule(ModuleType):
|
||||
lines = []
|
||||
mod = self.__name__
|
||||
for k, v in self._get_dict(
|
||||
ignored_keys=self._config.get("_save_config_ignore"), skip_default=True
|
||||
ignored_keys=getattr(self, "_save_config_ignore", []), skip_default=True
|
||||
).items():
|
||||
lines.append(f"{mod}.{k} = {v!r}")
|
||||
return "\n".join(lines)
|
||||
@ -255,7 +274,13 @@ class ConfigModule(ModuleType):
|
||||
config = pickle.loads(maybe_pickled_config)
|
||||
else:
|
||||
config = maybe_pickled_config
|
||||
self._config.update(config)
|
||||
for k, v in config.items():
|
||||
if k in self._config:
|
||||
setattr(self, k, v)
|
||||
else:
|
||||
warnings.warn(
|
||||
f"key {k} with value {v} is not understood by this config"
|
||||
)
|
||||
|
||||
def get_config_copy(self) -> Dict[str, Any]:
|
||||
return self._get_dict()
|
||||
@ -338,11 +363,13 @@ class ConfigModule(ModuleType):
|
||||
config = self._config
|
||||
|
||||
def change() -> Callable[[], None]:
|
||||
prior = {k: config[k] for k in changes}
|
||||
config.update(changes)
|
||||
prior = {k: config[k].user_override for k in changes}
|
||||
for k, v in changes.items():
|
||||
self._config[k].user_override = v
|
||||
|
||||
def revert() -> None:
|
||||
config.update(prior)
|
||||
for k, v in prior.items():
|
||||
self._config[k].user_override = v
|
||||
|
||||
return revert
|
||||
|
||||
|
Reference in New Issue
Block a user