Mirror of https://github.com/pytorch/pytorch.git (synced 2025-10-20 21:14:14 +08:00)
component-level configurable logging for dynamo, inductor, aot (#94858)
Summary: Adds NNC-like logging that is configured through the env var `TORCH_LOGS`.

Examples:
`TORCH_LOGS="dynamo,guards" python script.py` - prints dynamo logs at level INFO, plus the guards of all functions that are compiled.
`TORCH_LOGS="+dynamo,guards,graph" python script.py` - prints dynamo logs at level DEBUG, plus the guards and graphs (in tabular format) of all graphs that are compiled.
[More examples with full output](https://gist.github.com/mlazos/b17f474457308ce15e88c91721ac1cce)

Implementation: The implementation parses the log settings from the environment, finds any components (aot, dynamo, inductor) or other loggable objects (guards, graph, etc.), and generates a log_state object. This object contains all of the enabled artifacts and a qualified log name -> level mapping. _init_logs then adds handlers to the highest-level logs (the registered logs) and sets any artifact loggers to level DEBUG if the artifact is enabled.

Note: set_logs is an alternative for manipulating the log_state, but if the environment contains TORCH_LOGS, the environment settings are prioritized.

Adding a new log: a dev should add their log name to torch._logging._registrations (there are examples there already).

Adding a new artifact: a dev should add their artifact name to torch._logging._registrations as well. Additionally, wherever the artifact is logged, `torch._logging.getArtifactLogger(__name__, <artifact_name>)` should be used instead of the standard logging implementation.

[design doc](https://docs.google.com/document/d/1ZRfTWKa8eaPq1AxaiHrq4ASTPouzzlPiuquSBEJYwS8/edit#)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/94858
Approved by: https://github.com/ezyang
Committed by PyTorch MergeBot
parent 086ce765a5
commit a1c46e5f8f
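For orientation, a minimal usage sketch of the two configuration paths described above (the `TORCH_LOGS` environment variable and the `torch._logging.set_logs` API); the script name and the function being compiled are illustrative only, not part of this PR.

# Shell side: enable dynamo plus the "guards" and "graph" artifacts for a run
# (a "+" prefix bumps the component to DEBUG, per the commit message):
#   TORCH_LOGS="+dynamo,guards,graph" python my_script.py
#
# Python side: the equivalent configuration through the user API added here.
# Note that set_logs is ignored if TORCH_LOGS is already set in the environment.
import logging

import torch
import torch._dynamo

torch._logging.set_logs(dynamo=logging.DEBUG, guards=True, graph=True)


def fn(x):
    return (x * 2).relu()


# compiling the function emits the dynamo/guards/graph records configured above
torch._dynamo.optimize("eager")(fn)(torch.randn(8))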
@@ -862,6 +862,7 @@ include_patterns = [
    'test/test_value_ranges.py',
    'torch/utils/_sympy/interp.py',
    'torch/utils/_sympy/reference.py',
    'torch/_logging/**/*.py',
    'torch/nn/parallel/distributed.py',
]
command = [
test/dynamo/test_logging.py (new file, 153 lines)
@@ -0,0 +1,153 @@
# Owner(s): ["module: dynamo"]
import contextlib
import functools
import logging
import unittest.mock

import torch
import torch._dynamo.test_case
import torch._dynamo.testing

from torch.testing._internal.inductor_utils import HAS_CUDA
from torch.testing._internal.logging_utils import (
    LoggingTestCase,
    make_logging_test,
    make_settings_test,
)

requires_cuda = functools.partial(unittest.skipIf, not HAS_CUDA, "requires cuda")


def example_fn(a):
    output = a.mul(torch.ones(1000, 1000))
    output = output.add(torch.ones(1000, 1000))
    return output


def dynamo_error_fn(a):
    output = a.mul(torch.ones(1000, 1000))
    output = output.add(torch.ones(10, 10))
    return output


def inductor_error_fn(a):
    output = torch.round(a)
    return output


def inductor_schedule_fn(a):
    output = a.add(torch.ones(1000, 1000, device="cuda"))
    return output


ARGS = (torch.ones(1000, 1000, requires_grad=True),)


def multi_record_test(num_records, **kwargs):
    @make_logging_test(**kwargs)
    def fn(self, records):
        fn_opt = torch._dynamo.optimize("inductor")(example_fn)
        fn_opt(*ARGS)
        self.assertEqual(len(records), num_records)

    return fn


def within_range_record_test(num_records_lower, num_records_higher, **kwargs):
    @make_logging_test(**kwargs)
    def fn(self, records):
        fn_opt = torch._dynamo.optimize("inductor")(example_fn)
        fn_opt(*ARGS)
        self.assertGreaterEqual(len(records), num_records_lower)
        self.assertLessEqual(len(records), num_records_higher)

    return fn


def single_record_test(**kwargs):
    return multi_record_test(1, **kwargs)


class LoggingTests(LoggingTestCase):
    test_bytecode = multi_record_test(2, bytecode=True)
    test_output_code = multi_record_test(1, output_code=True)

    @requires_cuda()
    @make_logging_test(schedule=True)
    def test_schedule(self, records):
        fn_opt = torch._dynamo.optimize("inductor")(inductor_schedule_fn)
        fn_opt(torch.ones(1000, 1000, device="cuda"))
        self.assertGreater(len(records), 0)
        self.assertLess(len(records), 5)

    test_dynamo_debug = within_range_record_test(30, 50, dynamo=logging.DEBUG)
    test_dynamo_info = within_range_record_test(2, 10, dynamo=logging.INFO)

    @make_logging_test(dynamo=logging.ERROR)
    def test_dynamo_error(self, records):
        try:
            fn_opt = torch._dynamo.optimize("inductor")(dynamo_error_fn)
            fn_opt(*ARGS)
        except Exception:
            pass
        self.assertEqual(len(records), 1)

    test_aot = within_range_record_test(2, 6, aot=logging.INFO)
    test_inductor_debug = within_range_record_test(3, 15, inductor=logging.DEBUG)
    test_inductor_info = within_range_record_test(2, 4, inductor=logging.INFO)

    @make_logging_test(dynamo=logging.ERROR)
    def test_inductor_error(self, records):
        exitstack = contextlib.ExitStack()
        import torch._inductor.lowering

        def throw(x):
            raise AssertionError()

        # inject an error in the lowerings
        dict_entries = {}
        for x in list(torch._inductor.lowering.lowerings.keys()):
            if "round" in x.__name__:
                dict_entries[x] = throw

        exitstack.enter_context(
            unittest.mock.patch.dict(torch._inductor.lowering.lowerings, dict_entries)
        )

        try:
            fn_opt = torch._dynamo.optimize("inductor")(inductor_error_fn)
            fn_opt(*ARGS)
        except Exception:
            pass
        self.assertEqual(len(records), 1)
        self.assertIsInstance(records[0].msg, str)

        exitstack.close()

    # check that logging to a child log of a registered logger
    # does not register it and result in duplicated records
    @make_settings_test("torch._dynamo.output_graph")
    def test_open_registration_with_registered_parent(self, records):
        logger = logging.getLogger("torch._dynamo.output_graph")
        logger.info("hi")
        self.assertEqual(len(records), 1)

    # check logging to a random log that is not a child log of a registered
    # logger registers it and sets handlers properly
    @make_settings_test("torch.utils")
    def test_open_registration(self, records):
        logger = logging.getLogger("torch.utils")
        logger.info("hi")
        self.assertEqual(len(records), 1)


# single record tests
exclusions = {"bytecode", "output_code", "schedule"}
for name in torch._logging._internal.log_registry.artifact_names:
    if name not in exclusions:
        setattr(LoggingTests, f"test_{name}", single_record_test(**{name: True}))

if __name__ == "__main__":
    from torch._dynamo.test_case import run_tests

    run_tests()
test/functorch/test_logging.py (new file, 20 lines)
@@ -0,0 +1,20 @@
# Owner(s): ["module: dynamo"]
import torch
from torch.testing._internal.logging_utils import LoggingTestCase, make_logging_test
from torch._functorch.aot_autograd import aot_function
from torch._functorch.compilers import nop
import logging


class TestAOTLogging(LoggingTestCase):

    @make_logging_test(aot=logging.DEBUG)
    def test_logging(self, records):
        def f(x):
            return torch.sin(x)
        compiled_f = aot_function(
            f,
            fw_compiler=nop,
            bw_compiler=nop
        )
        compiled_f(torch.randn(3))
        self.assertGreater(len(records), 0)
@@ -1639,3 +1639,7 @@ def _sparse_coo_tensor_unsafe(*args, **kwargs):
            'use torch.sparse_coo_tensor(..., check_invariants=False) instead.')
    kwargs['check_invariants'] = False
    return torch.sparse_coo_tensor(*args, **kwargs)


from . import _logging
_logging._init_logs()
@@ -9,28 +9,24 @@ from . import external_utils

from .logging import get_loggers_level, set_loggers_level

# log level (levels print what it says + all levels listed below it)
# logging.DEBUG print full traces <-- lowest level + print tracing of every instruction
# logging.INFO print the steps that dynamo is running and optionally, compiled functions + graphs
# logging.WARN print warnings (including graph breaks)
# logging.ERROR print exceptions (and what user code was being processed when it occurred)

# Note (mlazos): This is deprecated and will be removed very soon
# to configure logging for dynamo, aot, and inductor
# use the following API in the torch._logging module
# torch._logging.set_logs(dynamo=<level>, aot=<level>, inductor=<level>)
# or use the environment variable TORCH_LOGS="dynamo,aot,inductor" (use a prefix + to indicate higher verbosity)
# see this design doc for more detailed info
# Design doc: https://docs.google.com/document/d/1ZRfTWKa8eaPq1AxaiHrq4ASTPouzzlPiuquSBEJYwS8/edit#
log_level = property(
    lambda _: get_loggers_level(), lambda _, lvl: set_loggers_level(lvl)
)

# log compiled function + graphs at level INFO
output_code = False

# the name of a file to write the logs to
log_file_name = None

# Verbose will print full stack traces on warnings and errors
verbose = os.environ.get("TORCHDYNAMO_VERBOSE", "0") == "1"

# If true, traced graph outputs will be outputted as Python GraphModule code.
# If false, traced graph outputs will be outputted in tabular form.
output_graph_code = False

# verify the correctness of optimized backend
verify_correctness = False

@@ -59,6 +55,9 @@ constant_functions = {
    torch._utils.is_compiling: True,
}

# Here for bw compat, will be removed (mlazos)
# see above notes for log_level on how to configure the new logging system
output_code = None

# don't specialize on shapes and strides and put shape ops in graph
dynamic_shapes = os.environ.get("TORCHDYNAMO_DYNAMIC_SHAPES") == "1"
@@ -7,6 +7,7 @@ import weakref
from typing import Dict, Optional, Set

import torch
import torch._logging
from torch._guards import tracing
from torch.fx.graph_module import _forward_from_src as original_forward_from_src

@@ -38,15 +39,18 @@ from .utils import (
    gen_record_file_name,
    guard_failures,
    increment_frame,
    init_logging,
    is_namedtuple,
    istype,
    orig_code_map,
    reset_graph_break_dup_checker,
    setup_compile_debug,
    troubleshooting_url,
    write_record_to_file,
)

log = logging.getLogger(__name__)
guards_log = torch._logging.getArtifactLogger(__name__, "guards")
bytecode_log = torch._logging.getArtifactLogger(__name__, "bytecode")


class Tracker:

@@ -101,9 +105,11 @@ def wrap_convert_context(fn):
        cuda_rng_state = torch.cuda.get_rng_state()
        prior_fwd_from_src = torch.fx.graph_module._forward_from_src
        torch.fx.graph_module._forward_from_src = fx_forward_from_src_skip_result
        cleanup = setup_compile_debug()
        try:
            return fn(*args, **kwargs)
        finally:
            cleanup.close()
            torch._C._set_grad_enabled(prior_grad_mode)
            torch.random.set_rng_state(rng_state)
            if torch.cuda.is_available():

@@ -195,7 +201,7 @@ def convert_frame_assert(
    export: bool = False,
):
    """Fully convert a frame into an FX graph"""
    init_logging()
    reset_graph_break_dup_checker()

    def _convert_frame_assert(frame: types.FrameType, cache_size: int, hooks: Hooks):
        increment_frame()

@@ -339,25 +345,26 @@ def _compile(
            return None
    output_codes.add(out_code)

    if config.output_code:
        log.info(
            format_bytecode(
                "ORIGINAL BYTECODE",
                code.co_name,
                code.co_filename,
                code.co_firstlineno,
                code,
            ),
        )
        log.info(
            format_bytecode(
                "MODIFIED BYTECODE",
                code.co_name,
                code.co_filename,
                code.co_firstlineno,
                out_code,
            ),
        )
    def log_bytecode(prefix, name, filename, line_no, code):
        if bytecode_log.isEnabledFor(logging.DEBUG):
            bytecode_log.debug(
                format_bytecode(prefix, name, filename, line_no, code)
            )

    log_bytecode(
        "ORIGINAL BYTECODE",
        code.co_name,
        code.co_filename,
        code.co_firstlineno,
        code,
    )
    log_bytecode(
        "MODIFIED BYTECODE",
        code.co_name,
        code.co_filename,
        code.co_firstlineno,
        out_code,
    )

    assert output is not None
    assert output.guards is not None

@@ -371,12 +378,12 @@ def _compile(
    guarded_code = GuardedCode(out_code, check_fn.check_fn)

    if config.output_code:
    if guards_log.isEnabledFor(logging.DEBUG):
        guard_str = "GUARDS:\n"
        guard_str += "\n".join(
            [f" - {str(guard)}" for guard in sorted(output.guards)]
        )
        log.info(guard_str)
        guards_log.debug(guard_str)

    if hooks.guard_export_fn is not None:
        hooks.guard_export_fn(output.guards)

@@ -423,7 +430,6 @@ def replay(filename):
    original_replay_val = config.replay_record_enabled
    config.replay_record_enabled = False
    init_logging()
    with open(filename, "rb") as in_file:
        record = ExecutionRecord.load(in_file)
        record.globals = {
@@ -159,8 +159,14 @@ def format_error_msg(exc, code, record_filename=None, frame=None):
    msg = os.linesep * 2

    if config.verbose:
        msg = format_bytecode(
            "WON'T CONVERT", code.co_name, code.co_filename, code.co_firstlineno, code
        msg = str(
            format_bytecode(
                "WON'T CONVERT",
                code.co_name,
                code.co_filename,
                code.co_firstlineno,
                code,
            )
        )
        msg += "=" * 10 + " TorchDynamo Stack Trace " + "=" * 10 + "\n"
        msg += format_exc()
@@ -1,13 +1,8 @@
import itertools
import logging
import os

from torch.hub import _Faketqdm, tqdm

# logging level for dynamo generated graphs/bytecode/guards
logging.CODE = 15
logging.addLevelName(logging.CODE, "CODE")

# Disable progress bar by default, not in dynamo config because otherwise get a circular import
disable_progress = True

@@ -15,9 +10,9 @@ disable_progress = True
# Return all loggers that torchdynamo/torchinductor is responsible for
def get_loggers():
    return [
        logging.getLogger("torch.fx.experimental.symbolic_shapes"),
        logging.getLogger("torch._dynamo"),
        logging.getLogger("torch._inductor"),
        logging.getLogger("torch.fx.experimental.symbolic_shapes"),
    ]

@@ -33,68 +28,6 @@ def get_loggers_level():
    return get_loggers()[0].level


LOGGING_CONFIG = {
    "version": 1,
    "formatters": {
        "torchdynamo_format": {
            "format": "[%(asctime)s] %(name)s: [%(levelname)s] %(message)s"
        },
    },
    "handlers": {
        "torchdynamo_console": {
            "class": "logging.StreamHandler",
            "level": "DEBUG",
            "formatter": "torchdynamo_format",
            "stream": "ext://sys.stderr",
        },
    },
    "loggers": {
        "torch._dynamo": {
            "level": "DEBUG",
            "handlers": ["torchdynamo_console"],
            "propagate": False,
        },
        "torch._inductor": {
            "level": "DEBUG",
            "handlers": ["torchdynamo_console"],
            "propagate": False,
        },
        "torch.fx.experimental.symbolic_shapes": {
            "level": "DEBUG",
            "handlers": ["torchdynamo_console"],
            "propagate": False,
        },
    },
    "disable_existing_loggers": False,
}


# initialize torchdynamo loggers
def init_logging(log_level, log_file_name=None):
    if "PYTEST_CURRENT_TEST" not in os.environ:
        logging.config.dictConfig(LOGGING_CONFIG)
        if log_file_name is not None:
            log_file = logging.FileHandler(log_file_name)
            log_file.setLevel(log_level)
            for logger in get_loggers():
                logger.addHandler(log_file)

        if bool(os.environ.get("TORCH_COMPILE_DEBUG", False)):
            from .utils import get_debug_dir

            log_level = logging.DEBUG
            log_path = os.path.join(get_debug_dir(), "torchdynamo")
            if not os.path.exists(log_path):
                os.makedirs(log_path)

            log_file = logging.FileHandler(os.path.join(log_path, "debug.log"))
            log_file.setLevel(logging.DEBUG)
            logger = logging.getLogger("torch._dynamo")
            logger.addHandler(log_file)

        set_loggers_level(log_level)


# Creates a logging function that logs a message with a step # prepended.
# get_step_logger should be lazily called (i.e. at runtime, not at module-load time)
# so that step numbers are initialized properly. e.g.:
@@ -9,6 +9,8 @@ import traceback
from dataclasses import dataclass
from typing import Any, Dict, List, NamedTuple, Optional, OrderedDict, Set, Union

import torch._logging

import torch.nn
from torch import fx
from torch._guards import Checkpointable, Guard, GuardsCheckpointState, TracingContext

@@ -42,6 +44,7 @@ from .utils import (
    count_calls,
    counters,
    dynamo_timed,
    format_graph_code,
    format_graph_tabular,
    same,
)

@@ -55,6 +58,8 @@ from .variables.tensor import (
)

log = logging.getLogger(__name__)
graph_tabular_log = torch._logging.getArtifactLogger(__name__, "graph")
graph_code_log = torch._logging.getArtifactLogger(__name__, "graph_code")


class OutputGraphState(NamedTuple):

@@ -618,24 +623,8 @@ class OutputGraph(fx.Tracer, Checkpointable[OutputGraphState]):
        counters["stats"]["unique_graphs"] += 1
        self.install_global(name, compiled_fn)

        try:
            # the call to tabulate can cause a lot of memory to be allocated
            if config.log_level <= logging.INFO and config.output_code:
                graph_str = (
                    gm.print_readable()
                    if config.output_graph_code
                    else format_graph_tabular(gm.graph)
                )
                log.log(
                    logging.INFO,
                    f"TRACED GRAPH\n {name} {gm.forward.__code__.co_filename} {graph_str}\n",
                )
        except ImportError:
            log.warning(
                "Unable to print graph: `format_graph_tabular` relies on the library `tabulate`, "
                "which could not be found on this machine. Run `pip "
                "install tabulate` to install the library."
            )
        graph_code_log.debug(format_graph_code(name, gm))
        graph_tabular_log.debug(format_graph_tabular(name, gm))

        cg = PyCodegen(tx)
        cg.make_call_generated_code(name)
@@ -10,7 +10,7 @@ import functools
import gc
import inspect
import itertools
import logging.config
import logging
import math
import operator
import os

@@ -25,6 +25,9 @@ from contextlib import contextmanager
from functools import lru_cache, wraps
from typing import Any, Dict, List, Optional, Tuple, Union

import torch._logging
from . import config

try:
    import numpy as np

@@ -43,8 +46,6 @@ from torch._subclasses.fake_tensor import FakeTensor
from torch.nn.modules.lazy import LazyModuleMixin
from torch.utils._pytree import tree_flatten, tree_map

from . import config, logging as torchdynamo_logging

counters = collections.defaultdict(collections.Counter)
troubleshooting_url = "https://pytorch.org/docs/master/dynamo/troubleshooting.html"

@@ -255,21 +256,38 @@ class DuplicateWarningChecker:
graph_break_dup_warning_checker = DuplicateWarningChecker()


def init_logging():
    torchdynamo_logging.init_logging(
        config.log_level, log_file_name=config.log_file_name
    )
def setup_compile_debug():
    compile_debug = bool(os.environ.get("TORCH_COMPILE_DEBUG", False))
    exitstack = contextlib.ExitStack()

    if compile_debug:
        torch._logging.set_logs(
            dynamo=logging.DEBUG,
            aot=logging.DEBUG,
            inductor=logging.DEBUG,
            output_code=True,  # this is off by default
        )

        debug_file_handler = add_file_handler()
        exitstack.callback(lambda: log.removeHandler(debug_file_handler))

    return exitstack


def reset_graph_break_dup_checker():
    graph_break_dup_warning_checker.reset()


def format_graph_tabular(graph):
    node_specs = [[n.op, n.name, n.target, n.args, n.kwargs] for n in graph.nodes]
    return tabulate(node_specs, headers=["opcode", "name", "target", "args", "kwargs"])
def add_file_handler():
    log_path = os.path.join(get_debug_dir(), "torchdynamo")
    if not os.path.exists(log_path):
        os.makedirs(log_path)


def format_bytecode(prefix, name, filename, line_no, code):
    return f"{prefix} {name} {filename}\
 line {line_no} \n{dis.Bytecode(code).dis()}\n "
    log_file = logging.FileHandler(os.path.join(log_path, "debug.log"))
    log_file.setLevel(logging.DEBUG)
    logger = logging.getLogger("torch._dynamo")
    logger.addHandler(log_file)
    return log_file


def gen_record_file_name(exc, code):

@@ -1379,3 +1397,33 @@ def tensor_always_has_static_shape(
    if not is_tensor:
        return True, TensorStaticReason.NOT_TENSOR
    return False, None


def format_graph_code(name, gm):
    return _format_graph_code(
        name, gm.forward.__code__.co_filename, gm.print_readable(print_output=False)
    )


def _format_graph_code(name, filename, graph_str):
    return f"TRACED GRAPH\n {name} {filename} {graph_str}\n"


def format_graph_tabular(fn_name, gm):
    try:
        from tabulate import tabulate  # TODO: Check that this is installed
    except ImportError:
        return (
            "Tabulate module missing, please install tabulate to log the graph in tabular format, logging code instead:\n"
            + format_graph_code(fn_name, gm)
        )

    node_specs = [[n.op, n.name, n.target, n.args, n.kwargs] for n in gm.graph.nodes]
    graph_str = tabulate(
        node_specs, headers=["opcode", "name", "target", "args", "kwargs"]
    )
    return _format_graph_code(fn_name, gm.forward.__code__.co_filename, graph_str)


def format_bytecode(prefix, name, filename, line_no, code):
    return f"{prefix} {name} {filename} line {line_no} \n{dis.Bytecode(code).dis()}\n"
@@ -18,7 +18,8 @@ import torch.utils._pytree as pytree
import torch.utils.dlpack
from torch import Tensor
from torch._dispatch.python import enable_python_dispatcher
from torch._dynamo.utils import dynamo_timed
from torch._dynamo.utils import dynamo_timed, format_graph_code
from torch._logging import getArtifactLogger
from torch._subclasses import CrossRefFakeMode, FakeTensor, FakeTensorMode
from torch.fx import immutable_collections, Interpreter
from torch.fx.experimental.proxy_tensor import is_sym_node, py_sym_types

@@ -30,6 +31,9 @@ from .partitioners import default_partition
from torch._guards import TracingContext, DuplicateInputs

log = logging.getLogger(__name__)
aot_forward_log = getArtifactLogger(__name__, "aot_forward_graph")
aot_joint_log = getArtifactLogger(__name__, "aot_joint_graph")
aot_backward_log = getArtifactLogger(__name__, "aot_backward_graph")

MutationType = Enum(
    "MutationType", ("none", "metadata_only", "data", "data_and_metadata")

@@ -1260,9 +1264,7 @@ def aot_dispatch_base(flat_fn, flat_args: List[Tensor], aot_config: AOTConfig, *
    fw_module.graph.eliminate_dead_code()
    fw_module.recompile()

    if config.debug_graphs:
        log.debug(f"====== Forward (only) graph {aot_config.aot_id} ======")
        log.debug(fw_module.print_readable(print_output=False))
    aot_forward_log.info(format_graph_code(f"====== Forward graph {aot_config.aot_id} ======\n", fw_module))

    disable_amp = torch._C._is_any_autocast_enabled()
    context = disable_autocast_manager if disable_amp else nullcontext

@@ -2269,9 +2271,7 @@ def aot_dispatch_autograd(flat_fn, flat_args: List[Any], aot_config: AOTConfig,
        "Graph partitioning without functionalization is not sound, we may introduce errors"
    )

    if config.debug_joint:
        log.debug(f"====== Joint graph {aot_config.aot_id} ======")
        log.debug(fx_g.print_readable(print_output=False))
    aot_joint_log.info(format_graph_code(f"====== Joint graph {aot_config.aot_id} =====\n", fx_g))

    with torch.no_grad():
        with track_graph_compiling(aot_config, "joint"):

@@ -2288,11 +2288,8 @@ def aot_dispatch_autograd(flat_fn, flat_args: List[Any], aot_config: AOTConfig,
            ]
            _num_symints_saved_for_bw = len(symint_outs_saved_for_bw)

            if config.debug_graphs:
                log.debug(f"====== Forward graph {aot_config.aot_id} ======")
                log.debug(fw_module.print_readable(print_output=False))
                log.debug(f"====== Backward graph {aot_config.aot_id} ======")
                log.debug(bw_module.print_readable(print_output=False))
            aot_forward_log.info(format_graph_code(f"====== Forward graph {aot_config.aot_id} ======\n", fw_module))
            aot_backward_log.info(format_graph_code(f"====== Backward graph {aot_config.aot_id} ======\n", bw_module))

            with track_graph_compiling(aot_config, "forward"):
                compiled_fw_func = aot_config.fw_compiler(

@@ -2604,8 +2601,6 @@ def create_aot_dispatcher_function(
        **aot_config.decompositions,
    }

    log.setLevel(config.log_level)

    # NB: don't bother setting allow_fallback_kernels; this should not actually
    # be configurable in fake tensor, we should automatically do the right
    # thing
@@ -12,6 +12,7 @@ import sympy
import torch

import torch._logging
from ..._dynamo import config as dynamo_config
from .. import config, ir, scheduler
from ..codecache import get_code_path

@@ -42,6 +43,7 @@ from .common import (
)

log = logging.getLogger(__name__)
schedule_log = torch._logging.getArtifactLogger(__name__, "schedule")


def signature_of(arg):

@@ -1589,8 +1591,8 @@ class TritonScheduling:
                f"unexpected group: ({numel}, {rnumel}) != {node.group[1]}"
            )

        if dynamo_config.output_code:
            log.info("schedule: %s", node_schedule)
        if schedule_log.isEnabledFor(logging.DEBUG):
            schedule_log.debug(f"Schedule:\n {node_schedule}")
        return self.codegen_node_schedule(node_schedule, numel, rnumel)

    @staticmethod
@@ -24,7 +24,7 @@ from torch import fx as fx
from torch._dynamo import config as dynamo_config
from torch._dynamo.debug_utils import save_graph_repro, wrap_compiler_debug
from torch._dynamo.utils import get_debug_dir, init_logging
from torch._dynamo.utils import get_debug_dir
from torch.fx.graph_module import GraphModule
from torch.fx.passes.shape_prop import TensorMetadata
from torch.fx.passes.tools_common import legalize_graph

@@ -291,8 +291,6 @@ class DebugContext:
    def __enter__(self):
        log = logging.getLogger("torch._inductor")
        if not log.handlers:
            init_logging()

        if config.debug:
@@ -9,6 +9,7 @@ from typing import Dict, List, Optional, Set
import sympy

import torch
import torch._logging
import torch.fx
from torch._decomp import get_decompositions
from torch._dynamo.utils import dynamo_timed

@@ -47,6 +48,7 @@ from .utils import (
from .virtualized import V

log = logging.getLogger(__name__)
output_code_log = torch._logging.getArtifactLogger(__name__, "output_code")


def supported_dtype_of_cpp_wrapper(dtype):

@@ -629,6 +631,8 @@ class GraphLowering(torch.fx.Interpreter):
        for name, value in self.constants.items():
            setattr(mod, name, value)

        log.debug(f"Output code written to: {mod.__file__}")
        output_code_log.debug(f"Output code: \n{code}")
        if config.benchmark_kernel:
            print(f"Compiled module path: {mod.__file__}", file=sys.stderr)
        V.debug.output_code(mod.__file__)
torch/_logging/__init__.py (new file, 8 lines)
@@ -0,0 +1,8 @@
# Top level logging module for torch logging
# Design doc: https://docs.google.com/document/d/1ZRfTWKa8eaPq1AxaiHrq4ASTPouzzlPiuquSBEJYwS8/edit#
# Simple setup for onboarding (see above doc for more detail):
# 1. register any top-level log qualified name for your module in torch._logging._registrations (see there for examples)
# 2. register any artifacts (<artifact_name> below) in torch._logging._registrations
#   a. call getArtifactLogger(__name__, <artifact_name>) at your logging site instead of the standard logger to log your artifact
import torch._logging._registrations
from ._internal import _init_logs, getArtifactLogger, set_logs
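As a concrete reading of the onboarding comment above, a minimal sketch of what a new logging site could look like; the "my_pass" artifact name and the module it lives in are hypothetical, not part of this PR.

# step 1/2 (in torch/_logging/_registrations.py): register the name once.
# "my_pass" is a hypothetical artifact used only for illustration.
from torch._logging._internal import register_artifact

register_artifact("my_pass")

# step 2a (at the logging site): fetch an artifact logger instead of logging.getLogger.
import torch._logging

my_pass_log = torch._logging.getArtifactLogger(__name__, "my_pass")


def run_pass(graph):
    # only emitted when the "my_pass" artifact is enabled, e.g. via TORCH_LOGS="my_pass"
    # (or through set_logs, once the new artifact is wired into its keyword list)
    my_pass_log.debug("running my_pass on %d nodes", len(list(graph.nodes)))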
torch/_logging/_internal.py (new file, 426 lines)
@@ -0,0 +1,426 @@
import functools
import itertools
import logging
import os
import re
from dataclasses import dataclass, field
from importlib import __import__
from typing import Dict, Set
from weakref import WeakSet

log = logging.getLogger(__name__)

DEFAULT_LOG_LEVEL = logging.WARN
DEFAULT_FORMATTER = logging.Formatter(
    "[%(asctime)s] %(name)s: [%(levelname)s] %(message)s"
)
LOG_ENV_VAR = "TORCH_LOGS"


@dataclass
class LogRegistry:
    # shorthand name to log qualified name
    # Note: this only contains loggers registered
    # from register_log
    # e.g. "dynamo" -> "torch._dynamo"
    log_alias_to_log_qname: Dict[str, str] = field(default_factory=dict)

    # artifact logger qualified names,
    # this is populated lazily, as calls to getArtifactLogger
    # currently formatted as <module>.__<artifact_name>
    # e.g. "torch._dynamo.convert_frame.__guards"
    artifact_log_qnames: Set[str] = field(default_factory=set)

    # child logs of registered logs if specified via open
    # registration by the user (ie placing "torch._dynamo.output_graph" in the env var)
    # these need to be tracked so their levels can be reset properly
    # e.g. "torch._dynamo.output_graph"
    child_log_qnames: Set[str] = field(default_factory=set)

    # artifact names, populated by register_artifact
    # e.g. "guards"
    artifact_names: Set[str] = field(default_factory=set)

    # artifacts which are not displayed unless explicitly named in the
    # settings. Ex. output_code is NOT displayed even if the inductor
    # log level is set to DEBUG. It must be explicitly named in the settings
    off_by_default_artifact_names: Set[str] = field(default_factory=set)

    def is_artifact(self, name):
        return name in self.artifact_names

    def is_log(self, alias):
        return alias in self.log_alias_to_log_qname

    # register a log with an alias
    def register_log(self, alias, log_qname):
        self.log_alias_to_log_qname[alias] = log_qname

    # register an artifact name
    def register_artifact_name(self, name, off_by_default):
        self.artifact_names.add(name)

        # if off by default, don't enable it
        # when log_name's log_level is set to DEBUG
        if off_by_default:
            self.off_by_default_artifact_names.add(name)

    # register the qualified name of an artifact log
    # this is needed to know which logs need to be reset
    # whenever the log_state is changed
    def register_artifact_log(self, artifact_log_qname):
        self.artifact_log_qnames.add(artifact_log_qname)

    def register_child_log(self, log_qname):
        self.child_log_qnames.add(log_qname)

    def get_log_qnames(self):
        return set(self.log_alias_to_log_qname.values())

    def get_artifact_log_qnames(self):
        return set(self.artifact_log_qnames)

    def get_child_log_qnames(self):
        return set(self.child_log_qnames)

    def is_off_by_default(self, artifact_qname):
        return artifact_qname in self.off_by_default_artifact_names


@dataclass
class LogState:
    # qualified log names -> currently set log level
    log_qname_to_level: Dict[str, str] = field(default_factory=dict)

    # the set of currently enabled artifacts
    artifact_names: Set[str] = field(default_factory=set)

    def enable_artifact(self, artifact_name):
        self.artifact_names.add(artifact_name)

    def is_artifact_enabled(self, name):
        return name in self.artifact_names

    def enable_log(self, log_qname, log_level):
        self.log_qname_to_level[log_qname] = log_level

    def get_log_level_pairs(self):
        return self.log_qname_to_level.items()

    def clear(self):
        self.log_qname_to_level.clear()
        self.artifact_names.clear()


log_registry = LogRegistry()
log_state = LogState()


# User API for setting log properties
# ex. format set_logs(LOG_NAME=LEVEL, ARTIFACT_NAME=bool)
# ex. set_logs(dynamo=logging.DEBUG, graph_code=True)
def set_logs(
    dynamo=DEFAULT_LOG_LEVEL,
    aot=DEFAULT_LOG_LEVEL,
    inductor=DEFAULT_LOG_LEVEL,
    bytecode=False,
    aot_forward_graph=False,
    aot_backward_graph=False,
    aot_joint_graph=False,
    graph=False,
    graph_code=False,
    guards=False,
    output_code=False,
    schedule=False,
):
    """
    Enable setting the log level of individual components through kwargs.
    Args are set using the following format:
        set_logs(<log_name>=<log_level>,...<artifact_name>=<True or False>)
    """
    # ignore if env var is set
    if LOG_ENV_VAR in os.environ:
        log.warning(
            "Using TORCH_LOGS environment variable for log settings, ignoring call to set_logs"
        )
        return

    log_state.clear()

    def _set_logs(**kwargs):
        for alias, val in kwargs.items():
            if log_registry.is_artifact(alias):
                if val:
                    log_state.enable_artifact(alias)
            elif log_registry.is_log(alias):
                if val not in logging._levelToName:
                    raise ValueError(
                        f"Unrecognized log level for log {alias}: {val}, valid level values "
                        f"are: {','.join([str(k) for k in logging._levelToName.keys()])}"
                    )
                if val != DEFAULT_LOG_LEVEL:
                    log_state.enable_log(
                        log_registry.log_alias_to_log_qname[alias], val
                    )
            else:
                raise ValueError(
                    f"Unrecognized log or artifact name passed to set_logs: {alias}"
                )

        _init_logs()

    _set_logs(
        dynamo=dynamo,
        aot=aot,
        inductor=inductor,
        bytecode=bytecode,
        aot_forward_graph=aot_forward_graph,
        aot_backward_graph=aot_backward_graph,
        aot_joint_graph=aot_joint_graph,
        graph=graph,
        graph_code=graph_code,
        guards=guards,
        output_code=output_code,
        schedule=schedule,
    )


def register_log(setting_name, log_name):
    """
    Enables a log to be controlled by the env var and user API with the setting_name
    Args:
        setting_name:  the shorthand name used in the env var and user API
        log_name:  the log name that the setting_name is associated with
    """
    log_registry.register_log(setting_name, log_name)


def register_artifact(setting_name, off_by_default=False):
    """
    Enables an artifact to be controlled by the env var and user API with name
    Args:
        setting_name: the shorthand name used in the env var and user API
        off_by_default: whether this artifact should be logged when the ancestor loggers
            are enabled at level DEBUG
    """
    log_registry.register_artifact_name(setting_name, off_by_default)


def getArtifactLogger(module_qname, artifact_name):
    if artifact_name not in log_registry.artifact_names:
        raise ValueError(
            f"Artifact name: {repr(artifact_name)} not registered,"
            f"please call register_artifact({repr(artifact_name)}) in torch._logging.registrations."
        )
    qname = module_qname + f".__{artifact_name}"
    log = logging.getLogger(module_qname + f".__{artifact_name}")
    log.artifact_name = artifact_name  # type: ignore[attr-defined]
    log_registry.register_artifact_log(qname)
    configure_artifact_log(log)
    return log


INCR_VERBOSITY_CHAR = "+"
DECR_VERBOSITY_CHAR = "-"
VERBOSITY_REGEX = (
    "("
    + "|".join([re.escape(INCR_VERBOSITY_CHAR), re.escape(DECR_VERBOSITY_CHAR)])
    + "?)"
)


def configure_artifact_log(log):
    # if parent log is set to debug, but this artifact is off by default
    # set propagate to False so that this artifact is not propagated
    # to its ancestor logger
    # this artifact is only logged when explicitly enabled (occurs below)
    if (
        log_registry.is_off_by_default(log.artifact_name)
        and log.getEffectiveLevel() == logging.DEBUG
    ):
        log.propagate = False

    # enable artifact logging when explicitly enabled
    if log_state.is_artifact_enabled(log.artifact_name):
        log.setLevel(logging.DEBUG)
        log.propagate = True


# match a comma separated list of loggable names (whitespace allowed after commas)
def _gen_settings_regex():
    return re.compile(r"((\+|-)?[\w\.]+,\s*)*(\+|-)?[\w\.]+?")


def _validate_settings(settings):
    return re.fullmatch(_gen_settings_regex(), settings) is not None


@functools.lru_cache()
def _parse_log_settings(settings):
    if settings == "":
        return dict()

    if not _validate_settings(settings):
        raise ValueError(
            f"Invalid log settings: {settings}, must be a comma separated list of registered log or artifact names."
        )

    settings = re.sub(r"\s+", "", settings)
    log_names = settings.split(",")

    def get_name_level_pair(name):
        clean_name = name.replace(INCR_VERBOSITY_CHAR, "")
        clean_name = clean_name.replace(DECR_VERBOSITY_CHAR, "")

        if name[0] == INCR_VERBOSITY_CHAR:
            level = logging.DEBUG
        elif name[0] == DECR_VERBOSITY_CHAR:
            level = logging.ERROR
        else:
            level = logging.INFO

        return clean_name, level

    log_state = LogState()
    for name in log_names:
        name, level = get_name_level_pair(name)
        if log_registry.is_log(name):
            assert level is not None
            log_qname = log_registry.log_alias_to_log_qname[name]
            log_state.enable_log(log_qname, level)
        elif log_registry.is_artifact(name):
            log_state.enable_artifact(name)
        elif _is_valid_module(name):
            if not _has_registered_parent(name):
                log_registry.register_log(name, name)
            else:
                log_registry.register_child_log(name)
            log_state.enable_log(name, level)
        else:
            raise ValueError(
                f"Invalid log settings: '{settings}', must be a comma separated list of log or artifact names."
            )

    return log_state


def _is_valid_module(qname):
    try:
        __import__(qname)
        return True
    except ImportError:
        return False


def _update_log_state_from_env():
    global log_state
    log_setting = os.environ.get(LOG_ENV_VAR, None)
    if log_setting is not None:
        log_state = _parse_log_settings(log_setting)


def _has_registered_parent(log_qname):
    cur_log = logging.getLogger(log_qname)

    registered_log_qnames = log_registry.get_log_qnames()

    while cur_log.parent:
        if cur_log.name in registered_log_qnames:
            return True
        cur_log = cur_log.parent

    return False


def _setup_handlers(create_handler_fn, log):
    debug_handler = _track_handler(create_handler_fn())
    debug_handler.setFormatter(DEFAULT_FORMATTER)
    debug_handler.setLevel(logging.DEBUG)
    log.addHandler(debug_handler)


handlers = WeakSet()  # type: ignore[var-annotated]


# mark handlers that we've created
# so we don't modify user handlers
def _track_handler(handler):
    handlers.add(handler)
    return handler


def _is_torch_handler(handler):
    return handler in handlers


# clears all torch handlers on specified loggers
def _clear_handlers(log):
    to_remove = [handler for handler in log.handlers if _is_torch_handler(handler)]
    for handler in to_remove:
        log.removeHandler(handler)


def _reset_logs():
    # reset all registered logs
    for log_qname in log_registry.get_log_qnames():
        log = logging.getLogger(log_qname)
        log.setLevel(logging.WARNING)
        log.propagate = False
        _clear_handlers(log)

    # reset all artifact and child logs
    for artifact_log_qname in itertools.chain(
        log_registry.get_artifact_log_qnames(), log_registry.get_child_log_qnames()
    ):
        log = logging.getLogger(artifact_log_qname)
        log.setLevel(logging.NOTSET)
        log.propagate = True


def _get_log_state():
    return log_state


def _set_log_state(state):
    global log_state
    log_state = state


def _init_logs(log_file_name=None):
    _reset_logs()
    _update_log_state_from_env()

    for log_qname, level in log_state.get_log_level_pairs():
        log = logging.getLogger(log_qname)
        log.setLevel(level)

    # setup handlers for all registered loggers
    for log_qname in log_registry.get_log_qnames():
        log = logging.getLogger(log_qname)
        _setup_handlers(
            logging.StreamHandler,
            log,
        )

        if log_file_name is not None:
            _setup_handlers(
                lambda: logging.FileHandler(log_file_name),
                log,
            )

    # configure artifact loggers, note: this must happen last
    # since the levels of ancestor loggers are taken into account
    for artifact_log_qname in log_registry.get_artifact_log_qnames():
        log = logging.getLogger(artifact_log_qname)
        configure_artifact_log(log)


@functools.lru_cache(None)
def warning_once(logger_obj, *args, **kwargs):
    """
    This function is similar to `logger.warning()`, but will emit the warning with the same message only once
    Note: The cache is for the function arguments, so 2 different callers using the same arguments will hit the cache.
    The assumption here is that all warning messages are unique across the code. If they aren't then need to switch to
    another type of cache that includes the caller frame information in the hashing function.
    """
    logger_obj.warning(*args, **kwargs)
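To make the parsing rules above concrete, a small sketch of what `_parse_log_settings` produces for one settings string, relying on the registrations in `_registrations.py` below; the particular string and the use of the private module are for illustration only.

import logging

from torch._logging import _internal

# "+dynamo"   -> registered log alias, "+" prefix selects DEBUG
# "guards"    -> registered artifact, enabled as an artifact
# "-inductor" -> registered log alias, "-" prefix selects ERROR
state = _internal._parse_log_settings("+dynamo,guards,-inductor")

assert state.log_qname_to_level["torch._dynamo"] == logging.DEBUG
assert state.log_qname_to_level["torch._inductor"] == logging.ERROR
assert state.is_artifact_enabled("guards")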
torch/_logging/_registrations.py (new file, 16 lines)
@@ -0,0 +1,16 @@
from ._internal import register_artifact, register_log

register_log("dynamo", "torch._dynamo")
register_log("aot", "torch._functorch.aot_autograd")
register_log("inductor", "torch._inductor")
register_log("sym_shapes", "torch.fx.experimental.symbolic_shapes")

register_artifact("guards")
register_artifact("bytecode")
register_artifact("graph")
register_artifact("graph_code")
register_artifact("aot_forward_graph")
register_artifact("aot_backward_graph")
register_artifact("aot_joint_graph")
register_artifact("output_code", off_by_default=True)
register_artifact("schedule", off_by_default=True)
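If a new component or artifact needed to be onboarded later, the registration would follow the same pattern as the file above; the names here are hypothetical and not part of this PR.

# Hypothetical additions, continuing the pattern of _registrations.py:
# a new top-level component log, plus an artifact that stays quiet unless
# explicitly requested even when its parent log is at DEBUG.
register_log("my_backend", "torch._my_backend")
register_artifact("fusion_decisions", off_by_default=True)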
torch/testing/_internal/logging_utils.py (new file, 144 lines)
@@ -0,0 +1,144 @@
import torch._dynamo.test_case
import unittest.mock
import os
import contextlib
import torch._logging
import torch._logging._internal
import logging

@contextlib.contextmanager
def preserve_log_state():
    prev_state = torch._logging._internal._get_log_state()
    torch._logging._internal._set_log_state(torch._logging._internal.LogState())
    try:
        yield
    finally:
        torch._logging._internal._set_log_state(prev_state)
        torch._logging._internal._init_logs()

def log_settings(settings):
    exit_stack = contextlib.ExitStack()
    settings_patch = unittest.mock.patch.dict(os.environ, {"TORCH_LOGS": settings})
    exit_stack.enter_context(preserve_log_state())
    exit_stack.enter_context(settings_patch)
    torch._logging._internal._init_logs()
    return exit_stack

def log_api(**kwargs):
    exit_stack = contextlib.ExitStack()
    exit_stack.enter_context(preserve_log_state())
    torch._logging.set_logs(**kwargs)
    return exit_stack


def kwargs_to_settings(**kwargs):
    INT_TO_VERBOSITY = {10: "+", 20: "", 40: "-"}

    settings = []
    for name, val in kwargs.items():
        if isinstance(val, bool):
            settings.append(name)
        elif isinstance(val, int):
            if val in INT_TO_VERBOSITY:
                settings.append(INT_TO_VERBOSITY[val] + name)
            else:
                raise ValueError("Invalid value for setting")

    return ",".join(settings)


# Note on testing strategy:
# This class does three things:
# 1. Runs two versions of a test:
#    1a. patches the env var log settings to some specific value
#    1b. calls torch._logging.set_logs(..)
# 2. patches the emit method of each setup handler to gather records
#    that are emitted to each console stream
# 3. passes a ref to the gathered records to each test case for checking
#
# The goal of this testing in general is to ensure that given some settings env var
# the logs are set up correctly and capture the correct records.
def make_logging_test(**kwargs):
    def wrapper(fn):
        def test_fn(self):

            torch._dynamo.reset()
            records = []
            # run with env var
            with log_settings(kwargs_to_settings(**kwargs)), self._handler_watcher(records):
                fn(self, records)

            # run with API
            torch._dynamo.reset()
            records.clear()
            with log_api(**kwargs), self._handler_watcher(records):
                fn(self, records)


        return test_fn

    return wrapper

def make_settings_test(settings):
    def wrapper(fn):
        def test_fn(self):
            torch._dynamo.reset()
            records = []
            # run with env var
            with log_settings(settings), self._handler_watcher(records):
                fn(self, records)

        return test_fn

    return wrapper

class LoggingTestCase(torch._dynamo.test_case.TestCase):
    @classmethod
    def setUpClass(cls):
        super().setUpClass()
        cls._exit_stack.enter_context(
            unittest.mock.patch.dict(os.environ, {"___LOG_TESTING": ""})
        )
        cls._exit_stack.enter_context(
            unittest.mock.patch("torch._dynamo.config.suppress_errors", True)
        )

    @classmethod
    def tearDownClass(cls):
        cls._exit_stack.close()
        torch._logging._internal.log_state.clear()
        torch._logging._init_logs()

    # This patches the emit method of each handler to gather records
    # as they are emitted
    def _handler_watcher(self, record_list):
        exit_stack = contextlib.ExitStack()

        def emit_post_hook(record):
            nonlocal record_list
            record_list.append(record)

        # registered logs are the only ones with handlers, so patch those
        for log_qname in torch._logging._internal.log_registry.get_log_qnames():
            logger = logging.getLogger(log_qname)
            num_handlers = len(logger.handlers)
            self.assertLessEqual(
                num_handlers,
                2,
                "All pt2 loggers should only have at most two handlers (debug artifacts and messages above debug level).",
            )

            self.assertGreater(num_handlers, 0, "All pt2 loggers should have more than zero handlers")

            for handler in logger.handlers:
                old_emit = handler.emit

                def new_emit(record):
                    old_emit(record)
                    emit_post_hook(record)

                exit_stack.enter_context(
                    unittest.mock.patch.object(handler, "emit", new_emit)
                )

        return exit_stack
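As a usage note for the harness above, a sketch of how a test module would exercise it, mirroring the tests added earlier in this commit; the test class and the function under compilation are illustrative only.

import logging

import torch
import torch._dynamo
from torch.testing._internal.logging_utils import LoggingTestCase, make_logging_test


class MyLoggingTests(LoggingTestCase):
    # runs the body twice: once with TORCH_LOGS="+dynamo,guards" patched into the
    # environment, once via torch._logging.set_logs(dynamo=DEBUG, guards=True)
    @make_logging_test(dynamo=logging.DEBUG, guards=True)
    def test_guards_emitted(self, records):
        torch._dynamo.optimize("eager")(lambda x: x + 1)(torch.ones(4))
        self.assertGreater(len(records), 0)


if __name__ == "__main__":
    from torch._dynamo.test_case import run_tests

    run_tests()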