mirror of
https://github.com/pytorch/pytorch.git
synced 2025-11-11 22:34:53 +08:00
Add "all" option to logging (#100664)
Adds the long-promised "all" option to logging. Pull Request resolved: https://github.com/pytorch/pytorch/pull/100664 Approved by: https://github.com/lezcano
This commit is contained in:
committed by
PyTorch MergeBot
parent
f1b2e00700
commit
850556ed6e
@ -39,6 +39,9 @@ The following components and artifacts are configurable through the ``TORCH_LOGS
|
||||
variable (see torch._logging.set_logs for the python API):
|
||||
|
||||
Components:
|
||||
``all``
|
||||
Special component which configures the default log level of all components. Default: ``logging.WARN``
|
||||
|
||||
``dynamo``
|
||||
The log level for the TorchDynamo component. Default: ``logging.WARN``
|
||||
|
||||
|
||||
@ -200,6 +200,20 @@ class LoggingTests(LoggingTestCase):
|
||||
logger.info("hi")
|
||||
self.assertEqual(len(records), 1)
|
||||
|
||||
@make_logging_test(all=logging.DEBUG, dynamo=logging.INFO)
|
||||
def test_all(self, _):
|
||||
registry = torch._logging._internal.log_registry
|
||||
state = torch._logging._internal.log_state
|
||||
|
||||
dynamo_qname = registry.log_alias_to_log_qname["dynamo"]
|
||||
for logger_qname in torch._logging._internal.log_registry.get_log_qnames():
|
||||
logger = logging.getLogger(logger_qname)
|
||||
|
||||
if logger_qname == dynamo_qname:
|
||||
self.assertEqual(logger.level, logging.INFO)
|
||||
else:
|
||||
self.assertEqual(logger.level, logging.DEBUG)
|
||||
|
||||
|
||||
# single record tests
|
||||
exclusions = {
|
||||
|
||||
@ -118,9 +118,11 @@ log_state = LogState()
|
||||
|
||||
def set_logs(
|
||||
*,
|
||||
dynamo: int = DEFAULT_LOG_LEVEL,
|
||||
aot: int = DEFAULT_LOG_LEVEL,
|
||||
inductor: int = DEFAULT_LOG_LEVEL,
|
||||
all: Optional[int] = None,
|
||||
dynamo: Optional[int] = None,
|
||||
aot: Optional[int] = None,
|
||||
dynamic: int = None,
|
||||
inductor: int = None,
|
||||
bytecode: bool = False,
|
||||
aot_graphs: bool = False,
|
||||
aot_joint_graph: bool = False,
|
||||
@ -169,15 +171,21 @@ def set_logs(
|
||||
is set to a log level less than or equal to the log level of the artifact.
|
||||
|
||||
Keyword args:
|
||||
dynamo (:class:`int`):
|
||||
all (:class:`Optional[int]`):
|
||||
The default log level for all components. Default: ``logging.WARN``
|
||||
|
||||
dynamo (:class:`Optional[int]`):
|
||||
The log level for the TorchDynamo component. Default: ``logging.WARN``
|
||||
|
||||
aot (:class:`int`):
|
||||
aot (:class:`Optional[int]`):
|
||||
The log level for the AOTAutograd component. Default: ``logging.WARN``
|
||||
|
||||
inductor (:class:`int`):
|
||||
inductor (:class:`Optional[int]`):
|
||||
The log level for the TorchInductor component. Default: ``logging.WARN``
|
||||
|
||||
dynamic (:class:`Optional[int]`):
|
||||
The log level for dynamic shapes. Default: ``logging.WARN``
|
||||
|
||||
bytecode (:class:`bool`):
|
||||
Whether to emit the original and generated bytecode from TorchDynamo.
|
||||
Default: ``False``
|
||||
@ -250,7 +258,25 @@ def set_logs(
|
||||
modules = modules or {}
|
||||
|
||||
def _set_logs(**kwargs):
|
||||
default_level = kwargs.pop("all", None)
|
||||
if default_level:
|
||||
if default_level not in logging._levelToName:
|
||||
raise ValueError(
|
||||
f"Unrecognized log level for kwarg all: {default_level}, valid level values "
|
||||
f"are: {','.join([str(k) for k in logging._levelToName.keys()])}"
|
||||
)
|
||||
|
||||
# add any missing aliases to kwargs
|
||||
for alias in log_registry.log_alias_to_log_qname.keys():
|
||||
if alias not in kwargs:
|
||||
kwargs[alias] = default_level
|
||||
else:
|
||||
default_level = DEFAULT_LOG_LEVEL
|
||||
|
||||
for alias, val in itertools.chain(kwargs.items(), modules.items()): # type: ignore[union-attr]
|
||||
if val is None:
|
||||
val = default_level
|
||||
|
||||
if log_registry.is_artifact(alias):
|
||||
if val:
|
||||
log_state.enable_artifact(alias)
|
||||
@ -260,10 +286,10 @@ def set_logs(
|
||||
f"Unrecognized log level for log {alias}: {val}, valid level values "
|
||||
f"are: {','.join([str(k) for k in logging._levelToName.keys()])}"
|
||||
)
|
||||
if val != DEFAULT_LOG_LEVEL:
|
||||
log_state.enable_log(
|
||||
log_registry.log_alias_to_log_qname[alias], val
|
||||
)
|
||||
|
||||
log_state.enable_log(log_registry.log_alias_to_log_qname[alias], val)
|
||||
elif alias == "all":
|
||||
continue
|
||||
else:
|
||||
raise ValueError(
|
||||
f"Unrecognized log or artifact name passed to set_logs: {alias}"
|
||||
@ -272,9 +298,11 @@ def set_logs(
|
||||
_init_logs()
|
||||
|
||||
_set_logs(
|
||||
all=all,
|
||||
dynamo=dynamo,
|
||||
aot=aot,
|
||||
inductor=inductor,
|
||||
dynamic=dynamic,
|
||||
bytecode=bytecode,
|
||||
aot_graphs=aot_graphs,
|
||||
aot_joint_graph=aot_joint_graph,
|
||||
@ -357,7 +385,9 @@ def _validate_settings(settings):
|
||||
def _invalid_settings_err_msg(settings):
|
||||
entities = "\n " + "\n ".join(
|
||||
itertools.chain(
|
||||
log_registry.log_alias_to_log_qname.keys(), log_registry.artifact_names
|
||||
["all"],
|
||||
log_registry.log_alias_to_log_qname.keys(),
|
||||
log_registry.artifact_names,
|
||||
)
|
||||
)
|
||||
msg = (
|
||||
@ -392,14 +422,24 @@ def _parse_log_settings(settings):
|
||||
return clean_name, level
|
||||
|
||||
log_state = LogState()
|
||||
|
||||
for name in log_names:
|
||||
name, level = get_name_level_pair(name)
|
||||
if name == "all":
|
||||
for log_qname in log_registry.get_log_qnames():
|
||||
log_state.enable_log(log_qname, level)
|
||||
|
||||
for name in log_names:
|
||||
name, level = get_name_level_pair(name)
|
||||
|
||||
if log_registry.is_log(name):
|
||||
assert level is not None
|
||||
log_qname = log_registry.log_alias_to_log_qname[name]
|
||||
log_state.enable_log(log_qname, level)
|
||||
elif log_registry.is_artifact(name):
|
||||
log_state.enable_artifact(name)
|
||||
elif name == "all":
|
||||
continue
|
||||
elif _is_valid_module(name):
|
||||
if not _has_registered_parent(name):
|
||||
log_registry.register_log(name, name)
|
||||
|
||||
Reference in New Issue
Block a user