mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-23 14:59:34 +08:00
Summary: Adds NNC-like logging that is configured through an env var `TORCH_LOGS`. Examples: `TORCH_LOGS="dynamo,guards" python script.py` - prints dynamo logs at level INFO with guards of all functions that are compiled. `TORCH_LOGS="+dynamo,guards,graph" python script.py` - prints dynamo logs at level DEBUG with guards and graphs (in tabular format) of all graphs that are compiled. [More examples with full output](https://gist.github.com/mlazos/b17f474457308ce15e88c91721ac1cce) Implementation: The implementation parses the log settings from the environment, finds any components (aot, dynamo, inductor) or other loggable objects (guards, graph, etc.) and generates a log_state object. This object contains all of the enabled artifacts, and a qualified log name -> level mapping. _init_logs then adds handlers to the highest level logs (the registered logs), and sets any artifact loggers to level DEBUG if the artifact is enabled. Note: set_logs is an alternative for manipulating the log_state, but if the environment contains TORCH_LOGS, the environment settings will be prioritized. Adding a new log: To add a new log, a dev should add their log name to torch._logging._registrations (there are examples there already). Adding a new artifact: To add a new artifact, a dev should add their artifact name to torch._logging._registrations as well. Additionally, wherever the artifact is logged, `torch._logging.getArtifactLogger(__name__, <artifact_name>)` should be used instead of the standard logging implementation. [design doc](https://docs.google.com/document/d/1ZRfTWKa8eaPq1AxaiHrq4ASTPouzzlPiuquSBEJYwS8/edit#) Pull Request resolved: https://github.com/pytorch/pytorch/pull/94858 Approved by: https://github.com/ezyang
145 lines
4.5 KiB
Python
145 lines
4.5 KiB
Python
import torch._dynamo.test_case
|
|
import unittest.mock
|
|
import os
|
|
import contextlib
|
|
import torch._logging
|
|
import torch._logging._internal
|
|
import logging
|
|
|
|
@contextlib.contextmanager
|
|
def preserve_log_state():
|
|
prev_state = torch._logging._internal._get_log_state()
|
|
torch._logging._internal._set_log_state(torch._logging._internal.LogState())
|
|
try:
|
|
yield
|
|
finally:
|
|
torch._logging._internal._set_log_state(prev_state)
|
|
torch._logging._internal._init_logs()
|
|
|
|
def log_settings(settings):
    """Return an ExitStack that applies ``settings`` through the TORCH_LOGS env var.

    The log state is swapped for a fresh one (``preserve_log_state``),
    ``TORCH_LOGS`` is patched to ``settings``, and ``_init_logs()`` is called
    so the loggers pick the setting up from the environment. Closing the
    returned stack (typically by using it in a ``with`` statement) un-patches
    the environment and then restores the previous log state.

    If anything fails while setting up, every context entered so far is
    unwound before the exception propagates. The patched env var and the
    reset log state therefore cannot leak into later tests.
    """
    with contextlib.ExitStack() as stack:
        # Entered first so it is exited last: the environment is restored
        # before preserve_log_state re-initializes the loggers.
        stack.enter_context(preserve_log_state())
        stack.enter_context(
            unittest.mock.patch.dict(os.environ, {"TORCH_LOGS": settings})
        )
        torch._logging._internal._init_logs()
        # Setup succeeded: transfer ownership of the entered contexts to the
        # caller instead of closing them here.
        return stack.pop_all()
|
|
|
|
def log_api(**kwargs):
    """Return an ExitStack that applies ``kwargs`` through ``torch._logging.set_logs``.

    The log state is swapped for a fresh one (``preserve_log_state``) and then
    ``set_logs(**kwargs)`` is called. Closing the returned stack restores the
    previous log state and re-initializes the loggers.

    If ``set_logs`` raises (e.g. for an unknown keyword), the previous log
    state is restored before the exception propagates instead of being left
    reset for the rest of the test run.
    """
    with contextlib.ExitStack() as stack:
        stack.enter_context(preserve_log_state())
        torch._logging.set_logs(**kwargs)
        # Setup succeeded: hand the entered context over to the caller.
        return stack.pop_all()
|
|
|
|
|
|
def kwargs_to_settings(**kwargs):
    """Translate ``set_logs``-style keyword arguments into a TORCH_LOGS string.

    Each keyword is handled by the type of its value:

    * ``bool``: ``True`` adds the bare name (enables it); ``False`` leaves it
      out, so the env-var run does not enable something the API run disabled.
    * ``int``: a log level, mapped to a name prefix:
      10 (``logging.DEBUG``) -> ``"+name"``, 20 (``logging.INFO``) -> ``"name"``,
      40 (``logging.ERROR``) -> ``"-name"``.
    * anything else (e.g. ``None``): skipped.

    Returns:
        The enabled entries joined with commas, in keyword order
        (``""`` when nothing is enabled).

    Raises:
        ValueError: if an int level has no entry in the prefix mapping.
    """
    INT_TO_VERBOSITY = {10: "+", 20: "", 40: "-"}

    settings = []
    for name, val in kwargs.items():
        # bool must be tested before int because bool is a subclass of int.
        if isinstance(val, bool):
            if val:
                settings.append(name)
        elif isinstance(val, int):
            if val not in INT_TO_VERBOSITY:
                raise ValueError("Invalid value for setting")
            settings.append(INT_TO_VERBOSITY[val] + name)

    return ",".join(settings)
|
|
|
|
|
|
# Note on testing strategy:
# make_logging_test does three things:
# 1. Runs two versions of a test:
#    1a. patches the env var log settings to some specific value
#    1b. calls torch._logging.set_logs(..)
# 2. patches the emit method of each handler that has been set up, to gather
#    the records that are emitted to each console stream
# 3. passes a ref to the gathered records to each test case for checking
#
# The goal of this testing in general is to ensure that, given some settings
# env var, the logs are set up correctly and capture the correct records.
|
|
def make_logging_test(**kwargs):
    """Build a decorator that runs a logging test under both configuration paths.

    The decorated function is called as ``fn(self, records)`` twice. The first
    call has ``kwargs`` translated into a ``TORCH_LOGS`` env var
    (``log_settings``). The second has ``kwargs`` passed straight to
    ``torch._logging.set_logs`` (``log_api``). Before each run dynamo is reset
    and ``records`` is emptied. During each run, ``records`` receives whatever
    the handlers patched by ``self._handler_watcher`` emit.
    """

    def wrapper(fn):
        def test_fn(self):
            captured = []
            configurations = (
                # env-var path
                lambda: log_settings(kwargs_to_settings(**kwargs)),
                # set_logs API path
                lambda: log_api(**kwargs),
            )
            for make_config in configurations:
                torch._dynamo.reset()
                captured.clear()
                with make_config(), self._handler_watcher(captured):
                    fn(self, captured)

        return test_fn

    return wrapper
|
|
|
|
def make_settings_test(settings):
    """Build a decorator that runs a logging test with a raw TORCH_LOGS string.

    The decorated ``fn(self, records)`` is run once, after a dynamo reset,
    with ``settings`` applied through ``log_settings``. ``records`` gathers
    what the handlers patched by ``self._handler_watcher`` emit.
    """

    def wrapper(fn):
        def test_fn(self):
            torch._dynamo.reset()
            captured = []
            env_config = log_settings(settings)
            with env_config, self._handler_watcher(captured):
                fn(self, captured)

        return test_fn

    return wrapper
|
|
|
|
class LoggingTestCase(torch._dynamo.test_case.TestCase):
|
|
@classmethod
|
|
def setUpClass(cls):
|
|
super().setUpClass()
|
|
cls._exit_stack.enter_context(
|
|
unittest.mock.patch.dict(os.environ, {"___LOG_TESTING": ""})
|
|
)
|
|
cls._exit_stack.enter_context(
|
|
unittest.mock.patch("torch._dynamo.config.suppress_errors", True)
|
|
)
|
|
|
|
@classmethod
|
|
def tearDownClass(cls):
|
|
cls._exit_stack.close()
|
|
torch._logging._internal.log_state.clear()
|
|
torch._logging._init_logs()
|
|
|
|
# This patches the emit method of each handler to gather records
|
|
# as they are emitted
|
|
def _handler_watcher(self, record_list):
|
|
exit_stack = contextlib.ExitStack()
|
|
|
|
def emit_post_hook(record):
|
|
nonlocal record_list
|
|
record_list.append(record)
|
|
|
|
# registered logs are the only ones with handlers, so patch those
|
|
for log_qname in torch._logging._internal.log_registry.get_log_qnames():
|
|
logger = logging.getLogger(log_qname)
|
|
num_handlers = len(logger.handlers)
|
|
self.assertLessEqual(
|
|
num_handlers,
|
|
2,
|
|
"All pt2 loggers should only have at most two handlers (debug artifacts and messages above debug level).",
|
|
)
|
|
|
|
self.assertGreater(num_handlers, 0, "All pt2 loggers should have more than zero handlers")
|
|
|
|
for handler in logger.handlers:
|
|
old_emit = handler.emit
|
|
|
|
def new_emit(record):
|
|
old_emit(record)
|
|
emit_post_hook(record)
|
|
|
|
exit_stack.enter_context(
|
|
unittest.mock.patch.object(handler, "emit", new_emit)
|
|
)
|
|
|
|
return exit_stack
|