config: Modify install_config_module to use a layered approach (#138758)

This modifies the config system to use a single mapping of config ->
ConfigEntry, and to store the default and user values within each entry.

We could have used multiple dicts (i.e. user_override and default), but
as we add more fields (justknobs in this PR, perhaps testing and env
variables later), it quickly becomes painful.

There are a couple of design decisions we could change:
1) All configs we save store the resolved value - not the default and
   user override separately.
2) All configs we load apply the resolved value as a user override.

This means that certain complexities of default behaviour and deletion
(as well as JK, i.e. justknobs) will change if you save + load a config.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/138758
Approved by: https://github.com/ezyang
This commit is contained in:
Colin L. Rice
2024-10-29 14:38:06 -06:00
committed by PyTorch MergeBot
parent 46d0b635b9
commit a0e095dd9f
3 changed files with 89 additions and 54 deletions

View File

@ -120,7 +120,7 @@ class TestInductorConfig(TestCase):
def test_get_compiler_config(self):
from torch._inductor import config as inductor_default_config
default_cudagraphs = inductor_default_config._default["triton.cudagraphs"]
default_cudagraphs = inductor_default_config.triton.cudagraphs
# nn.Module: should update default config with a new value
model = DummyModule()

View File

@ -3,6 +3,7 @@ import pickle
from torch.testing._internal import fake_config_module as config
from torch.testing._internal.common_utils import run_tests, TestCase
from torch.utils._config_module import _UNSET_SENTINEL
class TestConfigModule(TestCase):
@ -50,26 +51,24 @@ class TestConfigModule(TestCase):
):
config.does_not_exist = 0
# Config changes get persisted between test cases
config.e_bool = True
config.nested.e_bool = True
config.e_int = 1
config.e_float = 1.0
config.e_string = "string"
config.e_list = [1]
config.e_set = {1}
config.e_tuple = (1,)
config.e_dict = {1: 2}
config.e_none = None
for k in config._config:
config._config[k].user_override = _UNSET_SENTINEL
def test_delete(self):
self.assertTrue(config.e_bool)
del config.e_bool
with self.assertRaises(
AttributeError, msg="fake_config_module.e_bool does not exist"
):
print(config.e_bool)
# Config changes get persisted between test cases
config.e_bool = True
def test_none_override_semantics(self):
config.e_bool = None
self.assertIsNone(config.e_bool)
for k in config._config:
config._config[k].user_override = _UNSET_SENTINEL
def test_reference_semantics(self):
config.e_list.append(2)
self.assertEqual(config.e_list, [1, 2])
config.e_set.add(2)
self.assertEqual(config.e_set, {1, 2})
config.e_dict[2] = 3
self.assertEqual(config.e_dict, {1: 2, 2: 3})
for k in config._config:
config._config[k].user_override = _UNSET_SENTINEL
def test_save_config(self):
p = config.save_config()
@ -98,8 +97,8 @@ class TestConfigModule(TestCase):
config.load_config(p)
self.assertTrue(config.e_bool)
self.assertFalse(config.e_ignored)
# Config changes get persisted between test cases
config.e_ignored = True
for k in config._config:
config._config[k].user_override = _UNSET_SENTINEL
def test_save_config_portable(self):
p = config.save_config_portable()
@ -126,18 +125,24 @@ class TestConfigModule(TestCase):
self.assertTrue(config.e_bool)
self.assertFalse(config._e_ignored)
# Config changes get persisted between test cases
config._e_ignored = True
for k in config._config:
config._config[k].user_override = _UNSET_SENTINEL
def test_codegen_config(self):
config.e_bool = False
config.e_ignored = False
code = config.codegen_config()
self.assertEqual(
code, "torch.testing._internal.fake_config_module.e_bool = False"
code,
"""torch.testing._internal.fake_config_module.e_bool = False
torch.testing._internal.fake_config_module.e_list = [1]
torch.testing._internal.fake_config_module.e_set = {1}
torch.testing._internal.fake_config_module.e_dict = {1: 2}
torch.testing._internal.fake_config_module._save_config_ignore = ['e_ignored']""",
)
# Config changes get persisted between test cases
config.e_bool = True
config.e_ignored = True
for k in config._config:
config._config[k].user_override = _UNSET_SENTINEL
def test_get_hash(self):
self.assertEqual(
@ -165,7 +170,8 @@ class TestConfigModule(TestCase):
self.assertEqual(
config.get_hash(), b"\xcd\x96\x93\xf5(\xf8(\xa5\x1c+O\n\xd3_\x0b\xa6"
)
config.e_compile_ignored = True
for k in config._config:
config._config[k].user_override = _UNSET_SENTINEL
def test_dict_copy_semantics(self):
p = config.shallow_copy_dict()
@ -240,9 +246,11 @@ class TestConfigModule(TestCase):
self.assertEqual(p["e_dict"], {1: 2})
self.assertEqual(p2["e_dict"], {1: 2})
self.assertEqual(p3["e_dict"], {1: 2})
config.e_dict = {1: 2}
for k in config._config:
config._config[k].user_override = _UNSET_SENTINEL
def test_patch(self):
self.assertTrue(config.e_bool)
with config.patch("e_bool", False):
self.assertFalse(config.e_bool)
self.assertTrue(config.e_bool)

View File

@ -7,6 +7,7 @@ import pickle
import tokenize
import unittest
import warnings
from dataclasses import dataclass
from types import FunctionType, ModuleType
from typing import Any, Callable, Dict, List, NoReturn, Optional, Set, Union
from typing_extensions import deprecated
@ -43,8 +44,7 @@ def install_config_module(module: ModuleType) -> None:
name = f"{prefix}{key}"
if isinstance(value, CONFIG_TYPES):
config[name] = value
default[name] = value
config[name] = _ConfigEntry(default=value)
if dest is module:
delattr(module, key)
elif isinstance(value, type):
@ -59,15 +59,12 @@ def install_config_module(module: ModuleType) -> None:
else:
raise AssertionError(f"Unhandled config {key}={value} ({type(value)})")
config: Dict[str, Any] = {}
default: Dict[str, Any] = {}
config: Dict[str, _ConfigEntry] = {}
compile_ignored_keys = get_assignments_with_compile_ignored_comments(module)
visit(module, module, "")
module._config = config # type: ignore[attr-defined]
module._default = default # type: ignore[attr-defined]
module._allowed_keys = set(config.keys()) # type: ignore[attr-defined]
module._compile_ignored_keys = compile_ignored_keys # type: ignore[attr-defined]
module.__class__ = ConfigModuleInstance
module._is_dirty = True # type: ignore[attr-defined]
@ -116,17 +113,25 @@ def get_assignments_with_compile_ignored_comments(module: ModuleType) -> Set[str
return assignments
_UNSET_SENTINEL = object()
@dataclass
class _ConfigEntry:
# The default value specified in the configuration
default: Any
# The value specified by the user when they overrode the configuration
# _UNSET_SENTINEL indicates the value is not set.
user_override: Any = _UNSET_SENTINEL
class ConfigModule(ModuleType):
# NOTE: This should be kept in sync with _config_typing.pyi.
# The default values of the configuration settings. This can be used to
# determine if the config has been changed or not.
_default: Dict[str, Any]
# The actual configuration settings. E.g., torch._dynamo.config.debug
# would live as "debug" in the key, and torch._inductor.config.triton.cudagraphs
# maps as "triton.cudagraphs"
_config: Dict[str, Any]
_allowed_keys: Set[str]
# maps as "triton.cudagraphs". See discussion on the class for meaning of various sub items
_config: Dict[str, _ConfigEntry]
_bypass_keys: Set[str]
_compile_ignored_keys: Set[str]
_is_dirty: bool
@ -140,15 +145,25 @@ class ConfigModule(ModuleType):
def __setattr__(self, name: str, value: object) -> None:
if name in self._bypass_keys:
super().__setattr__(name, value)
elif name not in self._allowed_keys:
elif name not in self._config:
raise AttributeError(f"{self.__name__}.{name} does not exist")
else:
self._config[name] = value
self._config[name].user_override = value
self._is_dirty = True
def __getattr__(self, name: str) -> Any:
try:
return self._config[name]
config = self._config[name]
if config.user_override is not _UNSET_SENTINEL:
return config.user_override
# Note that reference types can still be modified, so we
# copy them to user_overrides in case the user overrides
# them
if isinstance(config.default, (list, set, dict)):
config.user_override = copy.deepcopy(config.default)
return config.user_override
return config.default
except KeyError as e:
# make hasattr() work properly
raise AttributeError(f"{self.__name__}.{name} does not exist") from e
@ -157,7 +172,10 @@ class ConfigModule(ModuleType):
self._is_dirty = True
# must support delete because unittest.mock.patch deletes
# then recreate things
del self._config[name]
self._config[name].user_override = _UNSET_SENTINEL
def _is_default(self, name: str) -> bool:
return self._config[name].user_override is _UNSET_SENTINEL
def _get_dict(
self,
@ -184,7 +202,7 @@ class ConfigModule(ModuleType):
config: Dict[str, Any] = {}
for key in self._config:
if ignored_keys and key in ignored_keys:
if skip_default and self._config[key] != self._default[key]:
if skip_default and not self._is_default(key):
warnings.warn(
f"Skipping serialization of {key} value {self._config[key]}"
)
@ -192,22 +210,23 @@ class ConfigModule(ModuleType):
if ignored_prefixes:
if any(key.startswith(prefix) for prefix in ignored_prefixes):
continue
if skip_default and self._config[key] == self._default[key]:
if skip_default and self._is_default(key):
continue
config[key] = copy.deepcopy(self._config[key])
config[key] = copy.deepcopy(getattr(self, key))
return config
def save_config(self) -> bytes:
"""Convert config to a pickled blob"""
ignored_keys = getattr(self, "_save_config_ignore", [])
return pickle.dumps(
self._get_dict(ignored_keys=self._config.get("_save_config_ignore", ())),
self._get_dict(ignored_keys=ignored_keys),
protocol=2,
)
def save_config_portable(self) -> Dict[str, Any]:
"""Convert config to portable format"""
prefixes = ["_"]
prefixes.extend(self._config["_cache_config_ignore_prefix"])
prefixes.extend(getattr(self, "_cache_config_ignore_prefix", []))
return self._get_dict(ignored_prefixes=prefixes)
def codegen_config(self) -> str:
@ -217,7 +236,7 @@ class ConfigModule(ModuleType):
lines = []
mod = self.__name__
for k, v in self._get_dict(
ignored_keys=self._config.get("_save_config_ignore"), skip_default=True
ignored_keys=getattr(self, "_save_config_ignore", []), skip_default=True
).items():
lines.append(f"{mod}.{k} = {v!r}")
return "\n".join(lines)
@ -255,7 +274,13 @@ class ConfigModule(ModuleType):
config = pickle.loads(maybe_pickled_config)
else:
config = maybe_pickled_config
self._config.update(config)
for k, v in config.items():
if k in self._config:
setattr(self, k, v)
else:
warnings.warn(
f"key {k} with value {v} is not understood by this config"
)
def get_config_copy(self) -> Dict[str, Any]:
return self._get_dict()
@ -338,11 +363,13 @@ class ConfigModule(ModuleType):
config = self._config
def change() -> Callable[[], None]:
prior = {k: config[k] for k in changes}
config.update(changes)
prior = {k: config[k].user_override for k in changes}
for k, v in changes.items():
self._config[k].user_override = v
def revert() -> None:
config.update(prior)
for k, v in prior.items():
self._config[k].user_override = v
return revert