config: Don't spam warnings about reference type configs (#145800)

Summary:
https://github.com/pytorch/pytorch/issues/145755

The is_dynamic check for reference types was subtly broken, causing log spam
after it was accessed

Added an explicit type for is_default for reference types to make sure this
behaviour is correct
Pull Request resolved: https://github.com/pytorch/pytorch/pull/145800
Approved by: https://github.com/eellison
This commit is contained in:
clr
2025-01-30 10:39:27 -08:00
committed by PyTorch MergeBot
parent 5a527fa5ee
commit f746bb6311
2 changed files with 17 additions and 17 deletions

View File

@ -194,8 +194,7 @@ torch.testing._internal.fake_config_module.e_env_default = True
torch.testing._internal.fake_config_module.e_env_default_FALSE = False
torch.testing._internal.fake_config_module.e_env_default_str = '1234'
torch.testing._internal.fake_config_module.e_env_default_str_empty = ''
torch.testing._internal.fake_config_module.e_env_force = True
torch.testing._internal.fake_config_module._save_config_ignore = ['e_ignored']""",
torch.testing._internal.fake_config_module.e_env_force = True""",
)
def test_codegen_config_function(self):
@ -390,6 +389,12 @@ torch.testing._internal.fake_config_module3.e_func = _warnings.warn""",
with config.patch(e_aliased_bool=True):
self.assertTrue(config2.e_aliasing_bool)
def test_reference_is_default(self):
t = config.e_dict
self.assertTrue(config._is_default("e_dict"))
t["a"] = "b"
self.assertFalse(config._is_default("e_dict"))
if __name__ == "__main__":
run_tests()

View File

@ -9,7 +9,6 @@ import pickle
import sys
import tokenize
import unittest
import warnings
from dataclasses import dataclass
from types import FunctionType, ModuleType
from typing import (
@ -444,11 +443,12 @@ class ConfigModule(ModuleType):
config_val.env_value_force is _UNSET_SENTINEL
or config_val.env_value_force == config_val.default
)
return (
config_val.user_override is _UNSET_SENTINEL
and not_set_env_default
and not_set_env_force
)
unset = config_val.user_override is _UNSET_SENTINEL
# Handle reference types specially to avoid spammy warnings
if isinstance(config_val.default, (list, set, dict)):
unset = unset or config_val.user_override == config_val.default
return unset and not_set_env_default and not_set_env_force
def _get_dict(
self,
@ -470,16 +470,11 @@ class ConfigModule(ModuleType):
ignored_prefixes are prefixes that if a key matches should
not be exported
skip_default does two things. One if a key has not been modified
it skips it. The other is it modified the logging behaviour
to match what codegen already did for modified skipped keys
it skips it.
"""
config: dict[str, Any] = {}
for key in self._config:
if ignored_keys and key in ignored_keys:
if skip_default and not self._is_default(key):
warnings.warn(
f"Skipping serialization of {key} value {getattr(self, key)}"
)
continue
if ignored_prefixes:
if any(key.startswith(prefix) for prefix in ignored_prefixes):
@ -611,9 +606,9 @@ class ConfigModule(ModuleType):
if k in self._config:
setattr(self, k, v)
else:
warnings.warn(
f"key {k} with value {v} is not understood by this config"
)
from torch._dynamo.utils import warn_once
warn_once(f"key {k} with value {v} is not understood by this config")
def get_config_copy(self) -> dict[str, Any]:
return self._get_dict()