mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
Summary: Discovered breakages by enabling codecache by default and doing a CI run. I'll commit these fixes first and eventually enabling caching by default will (hopefully) be a one-liner. Pull Request resolved: https://github.com/pytorch/pytorch/pull/125258 Approved by: https://github.com/eellison
214 lines
6.8 KiB
Python
214 lines
6.8 KiB
Python
# mypy: ignore-errors
|
|
|
|
import torch._dynamo.test_case
|
|
import unittest.mock
|
|
import os
|
|
import contextlib
|
|
import torch._logging
|
|
import torch._logging._internal
|
|
from torch._dynamo.utils import LazyString
|
|
from torch._inductor import config as inductor_config
|
|
import logging
|
|
import io
|
|
|
|
@contextlib.contextmanager
|
|
def preserve_log_state():
|
|
prev_state = torch._logging._internal._get_log_state()
|
|
torch._logging._internal._set_log_state(torch._logging._internal.LogState())
|
|
try:
|
|
yield
|
|
finally:
|
|
torch._logging._internal._set_log_state(prev_state)
|
|
torch._logging._internal._init_logs()
|
|
|
|
def log_settings(settings):
    """Apply ``settings`` through the ``TORCH_LOGS`` env var and re-init logging.

    Args:
        settings: value to place in ``os.environ["TORCH_LOGS"]``.

    Returns:
        A ``contextlib.ExitStack`` holding the pending cleanups. Closing it
        (or leaving a ``with`` block on it) removes the environment patch and
        unwinds ``preserve_log_state()``.

    Raises:
        Whatever ``_init_logs()`` raises. In that case everything set up here
        is rolled back before the exception propagates, instead of leaking
        the patched environment and the replaced log state.
    """
    with contextlib.ExitStack() as exit_stack:
        exit_stack.enter_context(preserve_log_state())
        exit_stack.enter_context(
            unittest.mock.patch.dict(os.environ, {"TORCH_LOGS": settings})
        )
        torch._logging._internal._init_logs()
        # Setup succeeded: move the cleanups to a new stack owned by the
        # caller so that leaving this ``with`` block does not run them.
        return exit_stack.pop_all()
|
|
|
|
def log_api(**kwargs):
    """Apply log settings through ``torch._logging.set_logs``.

    Args:
        **kwargs: forwarded unchanged to ``torch._logging.set_logs``.

    Returns:
        A ``contextlib.ExitStack`` holding the pending cleanup. Closing it
        (or leaving a ``with`` block on it) unwinds ``preserve_log_state()``.

    Raises:
        Whatever ``set_logs`` raises. In that case the log state swapped in
        by ``preserve_log_state()`` is rolled back before the exception
        propagates, instead of being left in place.
    """
    with contextlib.ExitStack() as exit_stack:
        exit_stack.enter_context(preserve_log_state())
        torch._logging.set_logs(**kwargs)
        # Setup succeeded: hand the cleanup over to the caller.
        return exit_stack.pop_all()
|
|
|
|
|
|
def kwargs_to_settings(**kwargs):
    """Convert ``set_logs``-style keyword arguments into a settings string.

    Each keyword is translated as follows:

    * ``bool``: ``True`` contributes the bare name; ``False`` contributes
      nothing.
    * ``int``: must be 10, 20 or 40 (the numeric values of ``logging.DEBUG``,
      ``logging.INFO`` and ``logging.ERROR``) and contributes the name
      prefixed with ``"+"``, ``""`` or ``"-"`` respectively.
    * ``dict`` under the keyword ``modules``: every ``{qualified_name: level}``
      entry is translated like an ``int`` setting.

    Returns:
        The translated settings joined with ``","`` in keyword order; an
        empty string when nothing is enabled.

    Raises:
        ValueError: for an unsupported value type, an unsupported level, or
            a non-string name.
    """
    INT_TO_VERBOSITY = {10: "+", 20: "", 40: "-"}

    def format_setting(name, level):
        valid = (
            isinstance(name, str)
            and isinstance(level, int)
            and level in INT_TO_VERBOSITY
        )
        if not valid:
            raise ValueError("Invalid value for setting")
        return INT_TO_VERBOSITY[level] + name

    settings = []
    for name, val in kwargs.items():
        # bool is a subclass of int, so it has to be tested before int.
        if isinstance(val, bool):
            # NOTE(review): assumes that naming an option in the settings
            # string turns it on, so a False flag has to be left out to keep
            # the env-var form in agreement with set_logs(name=False).
            if val:
                settings.append(name)
        elif isinstance(val, int):
            settings.append(format_setting(name, val))
        elif isinstance(val, dict) and name == "modules":
            settings.extend(
                format_setting(module_qname, level)
                for module_qname, level in val.items()
            )
        else:
            raise ValueError("Invalid value for setting")

    return ",".join(settings)
|
|
|
|
|
|
# Note on testing strategy:
|
|
# The make_logging_test decorator below does three things:
|
|
# 1. Runs two versions of a test:
|
|
# 1a. patches the env var log settings to some specific value
|
|
# 1b. calls torch._logging.set_logs(..)
|
|
# 2. patches the emit method of each setup handler to gather records
|
|
# that are emitted to each console stream
|
|
# 3. passes a ref to the gathered records to each test case for checking
|
|
#
|
|
# The goal of this testing in general is to ensure that given some settings env var
|
|
# the logs are set up correctly and capture the correct records.
|
|
def make_logging_test(**kwargs):
    """Build a decorator that runs a logging test under two configurations.

    The decorated function ``fn(self, records)`` is called twice: first with
    ``kwargs`` converted to a settings string and applied via
    ``log_settings`` (skipped when ``kwargs`` is empty), then with ``kwargs``
    applied via ``log_api``. Both calls run inside ``self._handler_watcher``
    so ``records`` receives the emitted log records, and the whole test runs
    with the inductor config option ``fx_graph_cache`` patched to ``False``.
    """

    def wrapper(fn):
        @inductor_config.patch({"fx_graph_cache": False})
        def test_fn(self):
            records = []

            # run with env var
            torch._dynamo.reset()
            if kwargs:
                env_ctx = log_settings(kwargs_to_settings(**kwargs))
            else:
                # No settings requested: leave the environment alone.
                env_ctx = contextlib.nullcontext()
            with env_ctx:
                with self._handler_watcher(records):
                    fn(self, records)

            # run with API
            torch._dynamo.reset()
            records.clear()
            with log_api(**kwargs):
                with self._handler_watcher(records):
                    fn(self, records)

        return test_fn

    return wrapper
|
|
|
|
def make_settings_test(settings):
    """Build a decorator that runs a test with ``settings`` applied via the env var.

    The decorated function ``fn(self, records)`` is called once, after
    ``torch._dynamo.reset()``, inside ``log_settings(settings)`` and
    ``self._handler_watcher`` so ``records`` receives the emitted log records.
    """

    def wrapper(fn):
        def test_fn(self):
            torch._dynamo.reset()
            captured = []
            # run with env var
            with log_settings(settings):
                with self._handler_watcher(captured):
                    fn(self, captured)

        return test_fn

    return wrapper
|
|
|
|
class LoggingTestCase(torch._dynamo.test_case.TestCase):
|
|
@classmethod
|
|
def setUpClass(cls):
|
|
super().setUpClass()
|
|
cls._exit_stack.enter_context(
|
|
unittest.mock.patch.dict(os.environ, {"___LOG_TESTING": ""})
|
|
)
|
|
cls._exit_stack.enter_context(
|
|
unittest.mock.patch("torch._dynamo.config.suppress_errors", True)
|
|
)
|
|
cls._exit_stack.enter_context(
|
|
unittest.mock.patch("torch._dynamo.config.verbose", False)
|
|
)
|
|
|
|
@classmethod
|
|
def tearDownClass(cls):
|
|
cls._exit_stack.close()
|
|
torch._logging._internal.log_state.clear()
|
|
torch._logging._init_logs()
|
|
|
|
def hasRecord(self, records, m):
|
|
return any(m in r.getMessage() for r in records)
|
|
|
|
def getRecord(self, records, m):
|
|
record = None
|
|
for r in records:
|
|
# NB: not r.msg because it looks like 3.11 changed how they
|
|
# structure log records
|
|
if m in r.getMessage():
|
|
self.assertIsNone(
|
|
record,
|
|
msg=LazyString(
|
|
lambda: f"multiple matching records: {record} and {r} among {records}"
|
|
),
|
|
)
|
|
record = r
|
|
if record is None:
|
|
self.fail(f"did not find record with {m} among {records}")
|
|
return record
|
|
|
|
# This patches the emit method of each handler to gather records
|
|
# as they are emitted
|
|
def _handler_watcher(self, record_list):
|
|
exit_stack = contextlib.ExitStack()
|
|
|
|
def emit_post_hook(record):
|
|
nonlocal record_list
|
|
record_list.append(record)
|
|
|
|
# registered logs are the only ones with handlers, so patch those
|
|
for log_qname in torch._logging._internal.log_registry.get_log_qnames():
|
|
logger = logging.getLogger(log_qname)
|
|
num_handlers = len(logger.handlers)
|
|
self.assertLessEqual(
|
|
num_handlers,
|
|
2,
|
|
"All pt2 loggers should only have at most two handlers (debug artifacts and messages above debug level).",
|
|
)
|
|
|
|
self.assertGreater(num_handlers, 0, "All pt2 loggers should have more than zero handlers")
|
|
|
|
for handler in logger.handlers:
|
|
old_emit = handler.emit
|
|
|
|
def new_emit(record):
|
|
old_emit(record)
|
|
emit_post_hook(record)
|
|
|
|
exit_stack.enter_context(
|
|
unittest.mock.patch.object(handler, "emit", new_emit)
|
|
)
|
|
|
|
return exit_stack
|
|
|
|
|
|
def logs_to_string(module, log_option):
|
|
"""Example:
|
|
logs_to_string("torch._inductor.compile_fx", "post_grad_graphs")
|
|
returns the output of TORCH_LOGS="post_grad_graphs" from the
|
|
torch._inductor.compile_fx module.
|
|
"""
|
|
log_stream = io.StringIO()
|
|
handler = logging.StreamHandler(stream=log_stream)
|
|
|
|
@contextlib.contextmanager
|
|
def tmp_redirect_logs():
|
|
try:
|
|
logger = torch._logging.getArtifactLogger(module, log_option)
|
|
logger.addHandler(handler)
|
|
yield
|
|
finally:
|
|
logger.removeHandler(handler)
|
|
|
|
def ctx_manager():
|
|
exit_stack = log_settings(log_option)
|
|
exit_stack.enter_context(tmp_redirect_logs())
|
|
return exit_stack
|
|
|
|
return log_stream, ctx_manager
|