Files
pytorch/torch/testing/_internal/logging_utils.py
Michael Lazos a1c46e5f8f component-level configurable logging for dynamo, inductor, aot (#94858)
Summary:

Adds NNC-like logging that is configured through an env var `TORCH_LOGS`
Examples:
`TORCH_LOGS="dynamo,guards" python script.py` - prints dynamo logs at level INFO with guards of all functions that are compiled

`TORCH_LOGS="+dynamo,guards,graph" python script.py` - prints dynamo logs at level DEBUG with guards and graphs (in tabular format) of all graphs that are compiled

[More examples with full output](https://gist.github.com/mlazos/b17f474457308ce15e88c91721ac1cce)

Implementation:
The implementation parses the log settings from the environment, finds any components (aot, dynamo, inductor) or other loggable objects (guards, graph, etc.) and generates a log_state object. This object contains all of the enabled artifacts, and a qualified log name -> level mapping. _init_logs then adds handlers to the highest level logs (the registered logs), and sets any artifact loggers to level DEBUG if the artifact is enabled.

Note: set_logs is an alternative for manipulating the log_state, but if the environment contains TORCH_LOGS, the environment settings will be prioritized.

Adding a new log:
To add a new log, a dev should add their log name to torch._logging._registrations (there are examples there already).

Adding a new artifact:
To add a new artifact, a dev should add their artifact name to torch._logging._registrations as well.
Additionally, wherever the artifact is logged, `torch._logging.getArtifactLogger(__name__, <artifact_name>)` should be used instead of the standard logging implementation.

[design doc](https://docs.google.com/document/d/1ZRfTWKa8eaPq1AxaiHrq4ASTPouzzlPiuquSBEJYwS8/edit#)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/94858
Approved by: https://github.com/ezyang
2023-03-18 04:17:31 +00:00

145 lines
4.5 KiB
Python

import torch._dynamo.test_case
import unittest.mock
import os
import contextlib
import torch._logging
import torch._logging._internal
import logging
@contextlib.contextmanager
def preserve_log_state():
prev_state = torch._logging._internal._get_log_state()
torch._logging._internal._set_log_state(torch._logging._internal.LogState())
try:
yield
finally:
torch._logging._internal._set_log_state(prev_state)
torch._logging._internal._init_logs()
def log_settings(settings):
    """Temporarily apply *settings* as if it were the ``TORCH_LOGS`` env var.

    Returns an ``ExitStack``; closing it restores the previous logging state.
    """
    stack = contextlib.ExitStack()
    stack.enter_context(preserve_log_state())
    stack.enter_context(
        unittest.mock.patch.dict(os.environ, {"TORCH_LOGS": settings})
    )
    torch._logging._internal._init_logs()
    return stack
def log_api(**kwargs):
    """Temporarily apply log settings through the ``torch._logging.set_logs``
    API (rather than the env var).

    Returns an ``ExitStack``; closing it restores the previous logging state.
    """
    stack = contextlib.ExitStack()
    stack.enter_context(preserve_log_state())
    torch._logging.set_logs(**kwargs)
    return stack
def kwargs_to_settings(**kwargs):
    """Convert ``set_logs``-style keyword args into a ``TORCH_LOGS`` string.

    A bool value maps to the bare name (e.g. ``dynamo``); the int levels
    ``logging.DEBUG``/``INFO``/``ERROR`` (10/20/40) map to
    ``+name``/``name``/``-name`` respectively.

    Raises:
        ValueError: if a value is an int outside the supported levels or
            any other unsupported type (previously such values were
            silently dropped, which would make a test a silent no-op).
    """
    INT_TO_VERBOSITY = {10: "+", 20: "", 40: "-"}

    settings = []
    for name, val in kwargs.items():
        # bool must be checked before int: bool is a subclass of int.
        # NOTE(review): a False value still emits the bare name, which
        # enables the artifact — assumes callers only pass True; confirm.
        if isinstance(val, bool):
            settings.append(name)
        elif isinstance(val, int) and val in INT_TO_VERBOSITY:
            settings.append(INT_TO_VERBOSITY[val] + name)
        else:
            raise ValueError(f"Invalid value {val!r} for setting {name}")

    return ",".join(settings)
# Note on testing strategy:
# This class does three things:
# 1. Runs two versions of a test:
# 1a. patches the env var log settings to some specific value
# 1b. calls torch._logging.set_logs(..)
# 2. patches the emit method of each setup handler to gather records
# that are emitted to each console stream
# 3. passes a ref to the gathered records to each test case for checking
#
# The goal of this testing in general is to ensure that given some settings env var
# that the logs are setup correctly and capturing the correct records.
def make_logging_test(**kwargs):
    """Decorator factory: run the wrapped test body twice — once with the
    equivalent ``TORCH_LOGS`` env setting and once through the
    ``torch._logging.set_logs`` API — passing the list of captured log
    records to the test body each time.
    """

    def wrapper(fn):
        def test_fn(self):
            gathered = []

            # First pass: configure logging via the TORCH_LOGS env var.
            torch._dynamo.reset()
            with log_settings(kwargs_to_settings(**kwargs)), self._handler_watcher(gathered):
                fn(self, gathered)

            # Second pass: configure logging via the set_logs API.
            torch._dynamo.reset()
            gathered.clear()
            with log_api(**kwargs), self._handler_watcher(gathered):
                fn(self, gathered)

        return test_fn

    return wrapper
def make_settings_test(settings):
    """Decorator factory: run the wrapped test body once with *settings*
    applied as the raw ``TORCH_LOGS`` string, passing the captured log
    records to the test body.
    """

    def wrapper(fn):
        def test_fn(self):
            torch._dynamo.reset()
            gathered = []
            with log_settings(settings), self._handler_watcher(gathered):
                fn(self, gathered)

        return test_fn

    return wrapper
class LoggingTestCase(torch._dynamo.test_case.TestCase):
    """Base class for log-capture tests.

    Sets up a clean class-level logging environment and provides
    ``_handler_watcher``, which patches every registered pt2 log handler
    so that emitted records are appended to a caller-supplied list.
    """

    @classmethod
    def setUpClass(cls):
        super().setUpClass()
        # NOTE(review): presumably a marker so logging code can detect it
        # is running under test — confirm against torch._logging.
        cls._exit_stack.enter_context(
            unittest.mock.patch.dict(os.environ, {"___LOG_TESTING": ""})
        )
        cls._exit_stack.enter_context(
            unittest.mock.patch("torch._dynamo.config.suppress_errors", True)
        )

    @classmethod
    def tearDownClass(cls):
        cls._exit_stack.close()
        # Drop any log state the tests accumulated and rebuild handlers.
        torch._logging._internal.log_state.clear()
        torch._logging._init_logs()

    # This patches the emit method of each handler to gather records
    # as they are emitted
    def _handler_watcher(self, record_list):
        """Return an ``ExitStack`` that, while open, appends every record
        emitted by the registered pt2 loggers to ``record_list``."""
        exit_stack = contextlib.ExitStack()

        # registered logs are the only ones with handlers, so patch those
        for log_qname in torch._logging._internal.log_registry.get_log_qnames():
            logger = logging.getLogger(log_qname)
            num_handlers = len(logger.handlers)
            self.assertLessEqual(
                num_handlers,
                2,
                "All pt2 loggers should only have at most two handlers (debug artifacts and messages above debug level).",
            )
            self.assertGreater(num_handlers, 0, "All pt2 loggers should have more than zero handlers")

            for handler in logger.handlers:
                # BUGFIX: bind the original emit as a default argument.
                # A plain closure over the loop variable late-binds, so
                # every patched handler would delegate to the LAST
                # handler's original emit instead of its own.
                def new_emit(record, old_emit=handler.emit):
                    old_emit(record)
                    record_list.append(record)

                exit_stack.enter_context(
                    unittest.mock.patch.object(handler, "emit", new_emit)
                )

        return exit_stack