component-level configurable logging for dynamo, inductor, aot (#94858)

Summary:

Adds NNC-like logging that is configured through the env var `TORCH_LOGS`
Examples:
`TORCH_LOGS="dynamo,guards" python script.py` - prints dynamo logs at level INFO with guards of all functions that are compiled

`TORCH_LOGS="+dynamo,guards,graph" python script.py` - prints dynamo logs at level DEBUG with guards and graphs (in tabular) format of all graphs that are compiled

[More examples with full output](https://gist.github.com/mlazos/b17f474457308ce15e88c91721ac1cce)

Implementation:
The implementation parses the log settings from the environment, finds any components (aot, dynamo, inductor) or other loggable objects (guards, graph, etc.), and generates a log_state object. This object contains all of the enabled artifacts and a qualified log name -> level mapping. _init_logs then adds handlers to the top-level registered logs and sets any artifact logger to level DEBUG if its artifact is enabled.
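For example, `TORCH_LOGS="+dynamo,guards"` roughly yields the following state (an illustrative sketch exercising the internal helper added in this PR, not a public API):

```python
import logging
from torch._logging._internal import _parse_log_settings

# "+dynamo" maps the registered alias "dynamo" -> "torch._dynamo" at DEBUG;
# "guards" is an artifact name and lands in the enabled-artifact set
state = _parse_log_settings("+dynamo,guards")
assert state.log_qname_to_level == {"torch._dynamo": logging.DEBUG}
assert state.artifact_names == {"guards"}
```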

Note: `set_logs` is an alternative API for manipulating the log_state, but if `TORCH_LOGS` is set in the environment, the environment settings take priority.
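For illustration, a rough Python-API equivalent of `TORCH_LOGS="+dynamo,guards,graph"` (ignored with a warning if `TORCH_LOGS` is already set in the environment):

```python
import logging
import torch._logging

# dynamo at DEBUG, plus the guards and tabular-graph artifacts;
# unspecified components stay at the default level (WARNING)
torch._logging.set_logs(dynamo=logging.DEBUG, guards=True, graph=True)
```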

Adding a new log:
To add a new log, a dev should add their log name to torch._logging._registrations (there are examples there already).
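For reference, the registrations added in this PR look like the snippet below; a new top-level log would follow the same pattern (the `mylib` entry is hypothetical):

```python
# torch/_logging/_registrations.py
from ._internal import register_log

register_log("dynamo", "torch._dynamo")  # existing: shorthand alias -> qualified logger name
register_log("mylib", "torch._mylib")    # hypothetical new registration
```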

Adding a new artifact:
To add a new artifact, a dev should add their artifact name to torch._logging._registrations as well.
Additionally, wherever the artifact is logged, `torch._logging.getArtifactLogger(__name__, <artifact_name>)` should be used in place of the standard `logging.getLogger(__name__)` logger.
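A minimal sketch of both pieces, modeled on the existing `guards` artifact (the `my_artifact` name is hypothetical):

```python
# in torch/_logging/_registrations.py (hypothetical artifact name)
from ._internal import register_artifact

register_artifact("my_artifact")

# at the logging site, instead of logging.getLogger(__name__)
import torch._logging

my_artifact_log = torch._logging.getArtifactLogger(__name__, "my_artifact")
my_artifact_log.debug("only emitted when 'my_artifact' is enabled via TORCH_LOGS or set_logs")
```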

[design doc](https://docs.google.com/document/d/1ZRfTWKa8eaPq1AxaiHrq4ASTPouzzlPiuquSBEJYwS8/edit#)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/94858
Approved by: https://github.com/ezyang
Author: Michael Lazos
Date: 2023-03-18 04:17:31 +00:00
Committed by: PyTorch MergeBot
Parent: 086ce765a5
Commit: a1c46e5f8f
18 changed files with 909 additions and 157 deletions


@ -862,6 +862,7 @@ include_patterns = [
'test/test_value_ranges.py',
'torch/utils/_sympy/interp.py',
'torch/utils/_sympy/reference.py',
'torch/_logging/**/*.py',
'torch/nn/parallel/distributed.py',
]
command = [

test/dynamo/test_logging.py

@ -0,0 +1,153 @@
# Owner(s): ["module: dynamo"]
import contextlib
import functools
import logging
import unittest.mock
import torch
import torch._dynamo.test_case
import torch._dynamo.testing
from torch.testing._internal.inductor_utils import HAS_CUDA
from torch.testing._internal.logging_utils import (
LoggingTestCase,
make_logging_test,
make_settings_test,
)
requires_cuda = functools.partial(unittest.skipIf, not HAS_CUDA, "requires cuda")
def example_fn(a):
output = a.mul(torch.ones(1000, 1000))
output = output.add(torch.ones(1000, 1000))
return output
def dynamo_error_fn(a):
output = a.mul(torch.ones(1000, 1000))
output = output.add(torch.ones(10, 10))
return output
def inductor_error_fn(a):
output = torch.round(a)
return output
def inductor_schedule_fn(a):
output = a.add(torch.ones(1000, 1000, device="cuda"))
return output
ARGS = (torch.ones(1000, 1000, requires_grad=True),)
def multi_record_test(num_records, **kwargs):
@make_logging_test(**kwargs)
def fn(self, records):
fn_opt = torch._dynamo.optimize("inductor")(example_fn)
fn_opt(*ARGS)
self.assertEqual(len(records), num_records)
return fn
def within_range_record_test(num_records_lower, num_records_higher, **kwargs):
@make_logging_test(**kwargs)
def fn(self, records):
fn_opt = torch._dynamo.optimize("inductor")(example_fn)
fn_opt(*ARGS)
self.assertGreaterEqual(len(records), num_records_lower)
self.assertLessEqual(len(records), num_records_higher)
return fn
def single_record_test(**kwargs):
return multi_record_test(1, **kwargs)
class LoggingTests(LoggingTestCase):
test_bytecode = multi_record_test(2, bytecode=True)
test_output_code = multi_record_test(1, output_code=True)
@requires_cuda()
@make_logging_test(schedule=True)
def test_schedule(self, records):
fn_opt = torch._dynamo.optimize("inductor")(inductor_schedule_fn)
fn_opt(torch.ones(1000, 1000, device="cuda"))
self.assertGreater(len(records), 0)
self.assertLess(len(records), 5)
test_dynamo_debug = within_range_record_test(30, 50, dynamo=logging.DEBUG)
test_dynamo_info = within_range_record_test(2, 10, dynamo=logging.INFO)
@make_logging_test(dynamo=logging.ERROR)
def test_dynamo_error(self, records):
try:
fn_opt = torch._dynamo.optimize("inductor")(dynamo_error_fn)
fn_opt(*ARGS)
except Exception:
pass
self.assertEqual(len(records), 1)
test_aot = within_range_record_test(2, 6, aot=logging.INFO)
test_inductor_debug = within_range_record_test(3, 15, inductor=logging.DEBUG)
test_inductor_info = within_range_record_test(2, 4, inductor=logging.INFO)
@make_logging_test(dynamo=logging.ERROR)
def test_inductor_error(self, records):
exitstack = contextlib.ExitStack()
import torch._inductor.lowering
def throw(x):
raise AssertionError()
# inject an error in the lowerings
dict_entries = {}
for x in list(torch._inductor.lowering.lowerings.keys()):
if "round" in x.__name__:
dict_entries[x] = throw
exitstack.enter_context(
unittest.mock.patch.dict(torch._inductor.lowering.lowerings, dict_entries)
)
try:
fn_opt = torch._dynamo.optimize("inductor")(inductor_error_fn)
fn_opt(*ARGS)
except Exception:
pass
self.assertEqual(len(records), 1)
self.assertIsInstance(records[0].msg, str)
exitstack.close()
# check that logging to a child log of a registered logger
# does not register it and result in duplicated records
@make_settings_test("torch._dynamo.output_graph")
def test_open_registration_with_registered_parent(self, records):
logger = logging.getLogger("torch._dynamo.output_graph")
logger.info("hi")
self.assertEqual(len(records), 1)
# check logging to a random log that is not a child log of a registered
# logger registers it and sets handlers properly
@make_settings_test("torch.utils")
def test_open_registration(self, records):
logger = logging.getLogger("torch.utils")
logger.info("hi")
self.assertEqual(len(records), 1)
# single record tests
exclusions = {"bytecode", "output_code", "schedule"}
for name in torch._logging._internal.log_registry.artifact_names:
if name not in exclusions:
setattr(LoggingTests, f"test_{name}", single_record_test(**{name: True}))
if __name__ == "__main__":
from torch._dynamo.test_case import run_tests
run_tests()


@ -0,0 +1,20 @@
# Owner(s): ["module: dynamo"]
import torch
from torch.testing._internal.logging_utils import LoggingTestCase, make_logging_test
from torch._functorch.aot_autograd import aot_function
from torch._functorch.compilers import nop
import logging
class TestAOTLogging(LoggingTestCase):
@make_logging_test(aot=logging.DEBUG)
def test_logging(self, records):
def f(x):
return torch.sin(x)
compiled_f = aot_function(
f,
fw_compiler=nop,
bw_compiler=nop
)
compiled_f(torch.randn(3))
self.assertGreater(len(records), 0)


@ -1639,3 +1639,7 @@ def _sparse_coo_tensor_unsafe(*args, **kwargs):
'use torch.sparse_coo_tensor(..., check_invariants=False) instead.')
kwargs['check_invariants'] = False
return torch.sparse_coo_tensor(*args, **kwargs)
from . import _logging
_logging._init_logs()


@ -9,28 +9,24 @@ from . import external_utils
from .logging import get_loggers_level, set_loggers_level
# log level (levels print what it says + all levels listed below it)
# logging.DEBUG print full traces <-- lowest level + print tracing of every instruction
# logging.INFO print the steps that dynamo is running and optionally, compiled functions + graphs
# logging.WARN print warnings (including graph breaks)
# logging.ERROR print exceptions (and what user code was being processed when it occurred)
# Note (mlazos): This is deprecated and will be removed very soon
# to configure logging for dynamo, aot, and inductor
# use the following API in the torch._logging module
# torch._logging.set_logs(dynamo=<level>, aot=<level>, inductor=<level>)
# or use the environment variable TORCH_LOGS="dynamo,aot,inductor" (use a prefix + to indicate higher verbosity)
# see this design doc for more detailed info
# Design doc: https://docs.google.com/document/d/1ZRfTWKa8eaPq1AxaiHrq4ASTPouzzlPiuquSBEJYwS8/edit#
log_level = property(
lambda _: get_loggers_level(), lambda _, lvl: set_loggers_level(lvl)
)
# log compiled function + graphs at level INFO
output_code = False
# the name of a file to write the logs to
log_file_name = None
# Verbose will print full stack traces on warnings and errors
verbose = os.environ.get("TORCHDYNAMO_VERBOSE", "0") == "1"
# If true, traced graph outputs will be outputted as Python GraphModule code.
# If false, traced graph outputs will be outputted in tabular form.
output_graph_code = False
# verify the correctness of optimized backend
verify_correctness = False
@ -59,6 +55,9 @@ constant_functions = {
torch._utils.is_compiling: True,
}
# Here for bw compat, will be removed (mlazos)
# see above notes for log_level on how to configure the new logging system
output_code = None
# don't specialize on shapes and strides and put shape ops in graph
dynamic_shapes = os.environ.get("TORCHDYNAMO_DYNAMIC_SHAPES") == "1"


@ -7,6 +7,7 @@ import weakref
from typing import Dict, Optional, Set
import torch
import torch._logging
from torch._guards import tracing
from torch.fx.graph_module import _forward_from_src as original_forward_from_src
@ -38,15 +39,18 @@ from .utils import (
gen_record_file_name,
guard_failures,
increment_frame,
init_logging,
is_namedtuple,
istype,
orig_code_map,
reset_graph_break_dup_checker,
setup_compile_debug,
troubleshooting_url,
write_record_to_file,
)
log = logging.getLogger(__name__)
guards_log = torch._logging.getArtifactLogger(__name__, "guards")
bytecode_log = torch._logging.getArtifactLogger(__name__, "bytecode")
class Tracker:
@ -101,9 +105,11 @@ def wrap_convert_context(fn):
cuda_rng_state = torch.cuda.get_rng_state()
prior_fwd_from_src = torch.fx.graph_module._forward_from_src
torch.fx.graph_module._forward_from_src = fx_forward_from_src_skip_result
cleanup = setup_compile_debug()
try:
return fn(*args, **kwargs)
finally:
cleanup.close()
torch._C._set_grad_enabled(prior_grad_mode)
torch.random.set_rng_state(rng_state)
if torch.cuda.is_available():
@ -195,7 +201,7 @@ def convert_frame_assert(
export: bool = False,
):
"""Fully convert a frame into an FX graph"""
init_logging()
reset_graph_break_dup_checker()
def _convert_frame_assert(frame: types.FrameType, cache_size: int, hooks: Hooks):
increment_frame()
@ -339,25 +345,26 @@ def _compile(
return None
output_codes.add(out_code)
if config.output_code:
log.info(
format_bytecode(
"ORIGINAL BYTECODE",
code.co_name,
code.co_filename,
code.co_firstlineno,
code,
),
)
log.info(
format_bytecode(
"MODIFIED BYTECODE",
code.co_name,
code.co_filename,
code.co_firstlineno,
out_code,
),
)
def log_bytecode(prefix, name, filename, line_no, code):
if bytecode_log.isEnabledFor(logging.DEBUG):
bytecode_log.debug(
format_bytecode(prefix, name, filename, line_no, code)
)
log_bytecode(
"ORIGINAL BYTECODE",
code.co_name,
code.co_filename,
code.co_firstlineno,
code,
)
log_bytecode(
"MODIFIED BYTECODE",
code.co_name,
code.co_filename,
code.co_firstlineno,
out_code,
)
assert output is not None
assert output.guards is not None
@ -371,12 +378,12 @@ def _compile(
guarded_code = GuardedCode(out_code, check_fn.check_fn)
if config.output_code:
if guards_log.isEnabledFor(logging.DEBUG):
guard_str = "GUARDS:\n"
guard_str += "\n".join(
[f" - {str(guard)}" for guard in sorted(output.guards)]
)
log.info(guard_str)
guards_log.debug(guard_str)
if hooks.guard_export_fn is not None:
hooks.guard_export_fn(output.guards)
@ -423,7 +430,6 @@ def replay(filename):
original_replay_val = config.replay_record_enabled
config.replay_record_enabled = False
init_logging()
with open(filename, "rb") as in_file:
record = ExecutionRecord.load(in_file)
record.globals = {


@ -159,8 +159,14 @@ def format_error_msg(exc, code, record_filename=None, frame=None):
msg = os.linesep * 2
if config.verbose:
msg = format_bytecode(
"WON'T CONVERT", code.co_name, code.co_filename, code.co_firstlineno, code
msg = str(
format_bytecode(
"WON'T CONVERT",
code.co_name,
code.co_filename,
code.co_firstlineno,
code,
)
)
msg += "=" * 10 + " TorchDynamo Stack Trace " + "=" * 10 + "\n"
msg += format_exc()


@ -1,13 +1,8 @@
import itertools
import logging
import os
from torch.hub import _Faketqdm, tqdm
# logging level for dynamo generated graphs/bytecode/guards
logging.CODE = 15
logging.addLevelName(logging.CODE, "CODE")
# Disable progress bar by default, not in dynamo config because otherwise get a circular import
disable_progress = True
@ -15,9 +10,9 @@ disable_progress = True
# Return all loggers that torchdynamo/torchinductor is responsible for
def get_loggers():
return [
logging.getLogger("torch.fx.experimental.symbolic_shapes"),
logging.getLogger("torch._dynamo"),
logging.getLogger("torch._inductor"),
logging.getLogger("torch.fx.experimental.symbolic_shapes"),
]
@ -33,68 +28,6 @@ def get_loggers_level():
return get_loggers()[0].level
LOGGING_CONFIG = {
"version": 1,
"formatters": {
"torchdynamo_format": {
"format": "[%(asctime)s] %(name)s: [%(levelname)s] %(message)s"
},
},
"handlers": {
"torchdynamo_console": {
"class": "logging.StreamHandler",
"level": "DEBUG",
"formatter": "torchdynamo_format",
"stream": "ext://sys.stderr",
},
},
"loggers": {
"torch._dynamo": {
"level": "DEBUG",
"handlers": ["torchdynamo_console"],
"propagate": False,
},
"torch._inductor": {
"level": "DEBUG",
"handlers": ["torchdynamo_console"],
"propagate": False,
},
"torch.fx.experimental.symbolic_shapes": {
"level": "DEBUG",
"handlers": ["torchdynamo_console"],
"propagate": False,
},
},
"disable_existing_loggers": False,
}
# initialize torchdynamo loggers
def init_logging(log_level, log_file_name=None):
if "PYTEST_CURRENT_TEST" not in os.environ:
logging.config.dictConfig(LOGGING_CONFIG)
if log_file_name is not None:
log_file = logging.FileHandler(log_file_name)
log_file.setLevel(log_level)
for logger in get_loggers():
logger.addHandler(log_file)
if bool(os.environ.get("TORCH_COMPILE_DEBUG", False)):
from .utils import get_debug_dir
log_level = logging.DEBUG
log_path = os.path.join(get_debug_dir(), "torchdynamo")
if not os.path.exists(log_path):
os.makedirs(log_path)
log_file = logging.FileHandler(os.path.join(log_path, "debug.log"))
log_file.setLevel(logging.DEBUG)
logger = logging.getLogger("torch._dynamo")
logger.addHandler(log_file)
set_loggers_level(log_level)
# Creates a logging function that logs a message with a step # prepended.
# get_step_logger should be lazily called (i.e. at runtime, not at module-load time)
# so that step numbers are initialized properly. e.g.:


@ -9,6 +9,8 @@ import traceback
from dataclasses import dataclass
from typing import Any, Dict, List, NamedTuple, Optional, OrderedDict, Set, Union
import torch._logging
import torch.nn
from torch import fx
from torch._guards import Checkpointable, Guard, GuardsCheckpointState, TracingContext
@ -42,6 +44,7 @@ from .utils import (
count_calls,
counters,
dynamo_timed,
format_graph_code,
format_graph_tabular,
same,
)
@ -55,6 +58,8 @@ from .variables.tensor import (
)
log = logging.getLogger(__name__)
graph_tabular_log = torch._logging.getArtifactLogger(__name__, "graph")
graph_code_log = torch._logging.getArtifactLogger(__name__, "graph_code")
class OutputGraphState(NamedTuple):
@ -618,24 +623,8 @@ class OutputGraph(fx.Tracer, Checkpointable[OutputGraphState]):
counters["stats"]["unique_graphs"] += 1
self.install_global(name, compiled_fn)
try:
# the call to tabulate can cause a lot of memory to be allocated
if config.log_level <= logging.INFO and config.output_code:
graph_str = (
gm.print_readable()
if config.output_graph_code
else format_graph_tabular(gm.graph)
)
log.log(
logging.INFO,
f"TRACED GRAPH\n {name} {gm.forward.__code__.co_filename} {graph_str}\n",
)
except ImportError:
log.warning(
"Unable to print graph: `format_graph_tabular` relies on the library `tabulate`, "
"which could not be found on this machine. Run `pip "
"install tabulate` to install the library."
)
graph_code_log.debug(format_graph_code(name, gm))
graph_tabular_log.debug(format_graph_tabular(name, gm))
cg = PyCodegen(tx)
cg.make_call_generated_code(name)


@ -10,7 +10,7 @@ import functools
import gc
import inspect
import itertools
import logging.config
import logging
import math
import operator
import os
@ -25,6 +25,9 @@ from contextlib import contextmanager
from functools import lru_cache, wraps
from typing import Any, Dict, List, Optional, Tuple, Union
import torch._logging
from . import config
try:
import numpy as np
@ -43,8 +46,6 @@ from torch._subclasses.fake_tensor import FakeTensor
from torch.nn.modules.lazy import LazyModuleMixin
from torch.utils._pytree import tree_flatten, tree_map
from . import config, logging as torchdynamo_logging
counters = collections.defaultdict(collections.Counter)
troubleshooting_url = "https://pytorch.org/docs/master/dynamo/troubleshooting.html"
@ -255,21 +256,38 @@ class DuplicateWarningChecker:
graph_break_dup_warning_checker = DuplicateWarningChecker()
def init_logging():
torchdynamo_logging.init_logging(
config.log_level, log_file_name=config.log_file_name
)
def setup_compile_debug():
compile_debug = bool(os.environ.get("TORCH_COMPILE_DEBUG", False))
exitstack = contextlib.ExitStack()
if compile_debug:
torch._logging.set_logs(
dynamo=logging.DEBUG,
aot=logging.DEBUG,
inductor=logging.DEBUG,
output_code=True, # this is off by default
)
debug_file_handler = add_file_handler()
exitstack.callback(lambda: log.removeHandler(debug_file_handler))
return exitstack
def reset_graph_break_dup_checker():
graph_break_dup_warning_checker.reset()
def format_graph_tabular(graph):
node_specs = [[n.op, n.name, n.target, n.args, n.kwargs] for n in graph.nodes]
return tabulate(node_specs, headers=["opcode", "name", "target", "args", "kwargs"])
def add_file_handler():
log_path = os.path.join(get_debug_dir(), "torchdynamo")
if not os.path.exists(log_path):
os.makedirs(log_path)
def format_bytecode(prefix, name, filename, line_no, code):
return f"{prefix} {name} {filename}\
line {line_no} \n{dis.Bytecode(code).dis()}\n "
log_file = logging.FileHandler(os.path.join(log_path, "debug.log"))
log_file.setLevel(logging.DEBUG)
logger = logging.getLogger("torch._dynamo")
logger.addHandler(log_file)
return log_file
def gen_record_file_name(exc, code):
@ -1379,3 +1397,33 @@ def tensor_always_has_static_shape(
if not is_tensor:
return True, TensorStaticReason.NOT_TENSOR
return False, None
def format_graph_code(name, gm):
return _format_graph_code(
name, gm.forward.__code__.co_filename, gm.print_readable(print_output=False)
)
def _format_graph_code(name, filename, graph_str):
return f"TRACED GRAPH\n {name} {filename} {graph_str}\n"
def format_graph_tabular(fn_name, gm):
try:
from tabulate import tabulate # TODO: Check that this is installed
except ImportError:
return (
"Tabulate module missing, please install tabulate to log the graph in tabular format, logging code instead:\n"
+ format_graph_code(fn_name, gm)
)
node_specs = [[n.op, n.name, n.target, n.args, n.kwargs] for n in gm.graph.nodes]
graph_str = tabulate(
node_specs, headers=["opcode", "name", "target", "args", "kwargs"]
)
return _format_graph_code(fn_name, gm.forward.__code__.co_filename, graph_str)
def format_bytecode(prefix, name, filename, line_no, code):
return f"{prefix} {name} {filename} line {line_no} \n{dis.Bytecode(code).dis()}\n"


@ -18,7 +18,8 @@ import torch.utils._pytree as pytree
import torch.utils.dlpack
from torch import Tensor
from torch._dispatch.python import enable_python_dispatcher
from torch._dynamo.utils import dynamo_timed
from torch._dynamo.utils import dynamo_timed, format_graph_code
from torch._logging import getArtifactLogger
from torch._subclasses import CrossRefFakeMode, FakeTensor, FakeTensorMode
from torch.fx import immutable_collections, Interpreter
from torch.fx.experimental.proxy_tensor import is_sym_node, py_sym_types
@ -30,6 +31,9 @@ from .partitioners import default_partition
from torch._guards import TracingContext, DuplicateInputs
log = logging.getLogger(__name__)
aot_forward_log = getArtifactLogger(__name__, "aot_forward_graph")
aot_joint_log = getArtifactLogger(__name__, "aot_joint_graph")
aot_backward_log = getArtifactLogger(__name__, "aot_backward_graph")
MutationType = Enum(
"MutationType", ("none", "metadata_only", "data", "data_and_metadata")
@ -1260,9 +1264,7 @@ def aot_dispatch_base(flat_fn, flat_args: List[Tensor], aot_config: AOTConfig, *
fw_module.graph.eliminate_dead_code()
fw_module.recompile()
if config.debug_graphs:
log.debug(f"====== Forward (only) graph {aot_config.aot_id} ======")
log.debug(fw_module.print_readable(print_output=False))
aot_forward_log.info(format_graph_code(f"====== Forward graph {aot_config.aot_id} ======\n", fw_module))
disable_amp = torch._C._is_any_autocast_enabled()
context = disable_autocast_manager if disable_amp else nullcontext
@ -2269,9 +2271,7 @@ def aot_dispatch_autograd(flat_fn, flat_args: List[Any], aot_config: AOTConfig,
"Graph partitioning without functionalization is not sound, we may introduce errors"
)
if config.debug_joint:
log.debug(f"====== Joint graph {aot_config.aot_id} ======")
log.debug(fx_g.print_readable(print_output=False))
aot_joint_log.info(format_graph_code(f"====== Joint graph {aot_config.aot_id} =====\n", fx_g))
with torch.no_grad():
with track_graph_compiling(aot_config, "joint"):
@ -2288,11 +2288,8 @@ def aot_dispatch_autograd(flat_fn, flat_args: List[Any], aot_config: AOTConfig,
]
_num_symints_saved_for_bw = len(symint_outs_saved_for_bw)
if config.debug_graphs:
log.debug(f"====== Forward graph {aot_config.aot_id} ======")
log.debug(fw_module.print_readable(print_output=False))
log.debug(f"====== Backward graph {aot_config.aot_id} ======")
log.debug(bw_module.print_readable(print_output=False))
aot_forward_log.info(format_graph_code(f"====== Forward graph {aot_config.aot_id} ======\n", fw_module))
aot_backward_log.info(format_graph_code(f"====== Backward graph {aot_config.aot_id} ======\n", bw_module))
with track_graph_compiling(aot_config, "forward"):
compiled_fw_func = aot_config.fw_compiler(
@ -2604,8 +2601,6 @@ def create_aot_dispatcher_function(
**aot_config.decompositions,
}
log.setLevel(config.log_level)
# NB: don't bother setting allow_fallback_kernels; this should not actually
# be configurable in fake tensor, we should automatically do the right
# thing


@ -12,6 +12,7 @@ import sympy
import torch
import torch._logging
from ..._dynamo import config as dynamo_config
from .. import config, ir, scheduler
from ..codecache import get_code_path
@ -42,6 +43,7 @@ from .common import (
)
log = logging.getLogger(__name__)
schedule_log = torch._logging.getArtifactLogger(__name__, "schedule")
def signature_of(arg):
@ -1589,8 +1591,8 @@ class TritonScheduling:
f"unexpected group: ({numel}, {rnumel}) != {node.group[1]}"
)
if dynamo_config.output_code:
log.info("schedule: %s", node_schedule)
if schedule_log.isEnabledFor(logging.DEBUG):
schedule_log.debug(f"Schedule:\n {node_schedule}")
return self.codegen_node_schedule(node_schedule, numel, rnumel)
@staticmethod


@ -24,7 +24,7 @@ from torch import fx as fx
from torch._dynamo import config as dynamo_config
from torch._dynamo.debug_utils import save_graph_repro, wrap_compiler_debug
from torch._dynamo.utils import get_debug_dir, init_logging
from torch._dynamo.utils import get_debug_dir
from torch.fx.graph_module import GraphModule
from torch.fx.passes.shape_prop import TensorMetadata
from torch.fx.passes.tools_common import legalize_graph
@ -291,8 +291,6 @@ class DebugContext:
def __enter__(self):
log = logging.getLogger("torch._inductor")
if not log.handlers:
init_logging()
if config.debug:


@ -9,6 +9,7 @@ from typing import Dict, List, Optional, Set
import sympy
import torch
import torch._logging
import torch.fx
from torch._decomp import get_decompositions
from torch._dynamo.utils import dynamo_timed
@ -47,6 +48,7 @@ from .utils import (
from .virtualized import V
log = logging.getLogger(__name__)
output_code_log = torch._logging.getArtifactLogger(__name__, "output_code")
def supported_dtype_of_cpp_wrapper(dtype):
@ -629,6 +631,8 @@ class GraphLowering(torch.fx.Interpreter):
for name, value in self.constants.items():
setattr(mod, name, value)
log.debug(f"Output code written to: {mod.__file__}")
output_code_log.debug(f"Output code: \n{code}")
if config.benchmark_kernel:
print(f"Compiled module path: {mod.__file__}", file=sys.stderr)
V.debug.output_code(mod.__file__)


@ -0,0 +1,8 @@
# Top level logging module for torch logging
# Design doc: https://docs.google.com/document/d/1ZRfTWKa8eaPq1AxaiHrq4ASTPouzzlPiuquSBEJYwS8/edit#
# Simple setup for onboarding (see above doc for more detail):
# 1. register any top-level log qualified name for your module in torch._logging._registrations (see there for examples)
# 2. register any artifacts (<artifact_name> below) in torch._logging._registrations
# a. call getArtifactLogger(__name__, <artifact_name>) at your logging site instead of the standard logger to log your artifact
import torch._logging._registrations
from ._internal import _init_logs, getArtifactLogger, set_logs

torch/_logging/_internal.py

@ -0,0 +1,426 @@
import functools
import itertools
import logging
import os
import re
from dataclasses import dataclass, field
from importlib import __import__
from typing import Dict, Set
from weakref import WeakSet
log = logging.getLogger(__name__)
DEFAULT_LOG_LEVEL = logging.WARN
DEFAULT_FORMATTER = logging.Formatter(
"[%(asctime)s] %(name)s: [%(levelname)s] %(message)s"
)
LOG_ENV_VAR = "TORCH_LOGS"
@dataclass
class LogRegistry:
# shorthand name to log qualified name
# Note: this only contains loggers registered
# from register_log
# e.g. "dynamo" -> "torch._dynamo"
log_alias_to_log_qname: Dict[str, str] = field(default_factory=dict)
# artifact logger qualified names,
# this is populated lazily, as calls to getArtifactLogger
# currently formatted as <module>.__<artifact_name>
# e.g. "torch._dynamo.convert_frame.__guards"
artifact_log_qnames: Set[str] = field(default_factory=set)
# child logs of registered logs if specified via open
# registration by the user (ie placing "torch._dynamo.output_graph" in the env var)
# these need to be tracked so their levels can be reset properly
# e.g. "torch._dynamo.output_graph"
child_log_qnames: Set[str] = field(default_factory=set)
# artifact names, populated by register_artifact
# e.g. "guards"
artifact_names: Set[str] = field(default_factory=set)
# artifacts which are not displayed unless explicitly named in the
# settings. Ex. output_code is NOT displayed even if the inductor
# log level is set to DEBUG. It must be explicitly named in the settings
off_by_default_artifact_names: Set[str] = field(default_factory=set)
def is_artifact(self, name):
return name in self.artifact_names
def is_log(self, alias):
return alias in self.log_alias_to_log_qname
# register a log with an alias
def register_log(self, alias, log_qname):
self.log_alias_to_log_qname[alias] = log_qname
# register an artifact name
def register_artifact_name(self, name, off_by_default):
self.artifact_names.add(name)
# if off by default, don't enable it
# when log_name's log_level is set to DEBUG
if off_by_default:
self.off_by_default_artifact_names.add(name)
# register the qualified name of an artifact log
# this is needed to know which logs need to be reset
# whenever the log_state is changed
def register_artifact_log(self, artifact_log_qname):
self.artifact_log_qnames.add(artifact_log_qname)
def register_child_log(self, log_qname):
self.child_log_qnames.add(log_qname)
def get_log_qnames(self):
return set(self.log_alias_to_log_qname.values())
def get_artifact_log_qnames(self):
return set(self.artifact_log_qnames)
def get_child_log_qnames(self):
return set(self.child_log_qnames)
def is_off_by_default(self, artifact_qname):
return artifact_qname in self.off_by_default_artifact_names
@dataclass
class LogState:
# qualified log names -> currently set log level
log_qname_to_level: Dict[str, str] = field(default_factory=dict)
# the set of currently enabled artifacts
artifact_names: Set[str] = field(default_factory=set)
def enable_artifact(self, artifact_name):
self.artifact_names.add(artifact_name)
def is_artifact_enabled(self, name):
return name in self.artifact_names
def enable_log(self, log_qname, log_level):
self.log_qname_to_level[log_qname] = log_level
def get_log_level_pairs(self):
return self.log_qname_to_level.items()
def clear(self):
self.log_qname_to_level.clear()
self.artifact_names.clear()
log_registry = LogRegistry()
log_state = LogState()
# User API for setting log properties
# ex. format set_logs(LOG_NAME=LEVEL, ARTIFACT_NAME=bool)
# ex. set_logs(dynamo=logging.DEBUG, graph_code=True)
def set_logs(
dynamo=DEFAULT_LOG_LEVEL,
aot=DEFAULT_LOG_LEVEL,
inductor=DEFAULT_LOG_LEVEL,
bytecode=False,
aot_forward_graph=False,
aot_backward_graph=False,
aot_joint_graph=False,
graph=False,
graph_code=False,
guards=False,
output_code=False,
schedule=False,
):
"""
Enable setting the log level of individual components through kwargs.
Args are set using the following format:
set_logs(<log_name>=<log_level>,...<artifact_name>=<True or False>)
"""
# ignore if env var is set
if LOG_ENV_VAR in os.environ:
log.warning(
"Using TORCH_LOGS environment variable for log settings, ignoring call to set_logs"
)
return
log_state.clear()
def _set_logs(**kwargs):
for alias, val in kwargs.items():
if log_registry.is_artifact(alias):
if val:
log_state.enable_artifact(alias)
elif log_registry.is_log(alias):
if val not in logging._levelToName:
raise ValueError(
f"Unrecognized log level for log {alias}: {val}, valid level values "
f"are: {','.join([str(k) for k in logging._levelToName.keys()])}"
)
if val != DEFAULT_LOG_LEVEL:
log_state.enable_log(
log_registry.log_alias_to_log_qname[alias], val
)
else:
raise ValueError(
f"Unrecognized log or artifact name passed to set_logs: {alias}"
)
_init_logs()
_set_logs(
dynamo=dynamo,
aot=aot,
inductor=inductor,
bytecode=bytecode,
aot_forward_graph=aot_forward_graph,
aot_backward_graph=aot_backward_graph,
aot_joint_graph=aot_joint_graph,
graph=graph,
graph_code=graph_code,
guards=guards,
output_code=output_code,
schedule=schedule,
)
def register_log(setting_name, log_name):
"""
Enables a log to be controlled by the env var and user API with the setting_name
Args:
setting_name: the shorthand name used in the env var and user API
log_name: the log name that the setting_name is associated with
"""
log_registry.register_log(setting_name, log_name)
def register_artifact(setting_name, off_by_default=False):
"""
Enables an artifact to be controlled by the env var and user API with name
Args:
setting_name: the shorthand name used in the env var and user API
off_by_default: whether this artifact should be logged when the ancestor loggers
are enabled at level DEBUG
"""
log_registry.register_artifact_name(setting_name, off_by_default)
def getArtifactLogger(module_qname, artifact_name):
if artifact_name not in log_registry.artifact_names:
raise ValueError(
f"Artifact name: {repr(artifact_name)} not registered,"
f"please call register_artifact({repr(artifact_name)}) in torch._logging.registrations."
)
qname = module_qname + f".__{artifact_name}"
log = logging.getLogger(module_qname + f".__{artifact_name}")
log.artifact_name = artifact_name # type: ignore[attr-defined]
log_registry.register_artifact_log(qname)
configure_artifact_log(log)
return log
INCR_VERBOSITY_CHAR = "+"
DECR_VERBOSITY_CHAR = "-"
VERBOSITY_REGEX = (
"("
+ "|".join([re.escape(INCR_VERBOSITY_CHAR), re.escape(DECR_VERBOSITY_CHAR)])
+ "?)"
)
def configure_artifact_log(log):
# if parent log is set to debug, but this artifact is off by default
# set propagate to False so that this artifact is not propagated
# to its ancestor logger
# this artifact is only logged when explicitly enabled (occurs below)
if (
log_registry.is_off_by_default(log.artifact_name)
and log.getEffectiveLevel() == logging.DEBUG
):
log.propagate = False
# enable artifact logging when explicitly enabled
if log_state.is_artifact_enabled(log.artifact_name):
log.setLevel(logging.DEBUG)
log.propagate = True
# match a comma separated list of loggable names (whitespace allowed after commas)
def _gen_settings_regex():
return re.compile(r"((\+|-)?[\w\.]+,\\s*)*(\+|-)?[\w\.]+?")
def _validate_settings(settings):
return re.fullmatch(_gen_settings_regex(), settings) is not None
@functools.lru_cache()
def _parse_log_settings(settings):
if settings == "":
return dict()
if not _validate_settings(settings):
raise ValueError(
f"Invalid log settings: {settings}, must be a comma separated list of registerered log or artifact names."
)
settings = re.sub(r"\s+", "", settings)
log_names = settings.split(",")
def get_name_level_pair(name):
clean_name = name.replace(INCR_VERBOSITY_CHAR, "")
clean_name = clean_name.replace(DECR_VERBOSITY_CHAR, "")
if name[0] == INCR_VERBOSITY_CHAR:
level = logging.DEBUG
elif name[0] == DECR_VERBOSITY_CHAR:
level = logging.ERROR
else:
level = logging.INFO
return clean_name, level
log_state = LogState()
for name in log_names:
name, level = get_name_level_pair(name)
if log_registry.is_log(name):
assert level is not None
log_qname = log_registry.log_alias_to_log_qname[name]
log_state.enable_log(log_qname, level)
elif log_registry.is_artifact(name):
log_state.enable_artifact(name)
elif _is_valid_module(name):
if not _has_registered_parent(name):
log_registry.register_log(name, name)
else:
log_registry.register_child_log(name)
log_state.enable_log(name, level)
else:
raise ValueError(
f"Invalid log settings: '{settings}', must be a comma separated list of log or artifact names."
)
return log_state
def _is_valid_module(qname):
try:
__import__(qname)
return True
except ImportError:
return False
def _update_log_state_from_env():
global log_state
log_setting = os.environ.get(LOG_ENV_VAR, None)
if log_setting is not None:
log_state = _parse_log_settings(log_setting)
def _has_registered_parent(log_qname):
cur_log = logging.getLogger(log_qname)
registered_log_qnames = log_registry.get_log_qnames()
while cur_log.parent:
if cur_log.name in registered_log_qnames:
return True
cur_log = cur_log.parent
return False
def _setup_handlers(create_handler_fn, log):
debug_handler = _track_handler(create_handler_fn())
debug_handler.setFormatter(DEFAULT_FORMATTER)
debug_handler.setLevel(logging.DEBUG)
log.addHandler(debug_handler)
handlers = WeakSet() # type: ignore[var-annotated]
# mark handlers that we've created
# so we don't modify user handlers
def _track_handler(handler):
handlers.add(handler)
return handler
def _is_torch_handler(handler):
return handler in handlers
# clears all torch handlers on specified loggers
def _clear_handlers(log):
to_remove = [handler for handler in log.handlers if _is_torch_handler(handler)]
for handler in to_remove:
log.removeHandler(handler)
def _reset_logs():
# reset all registered logs
for log_qname in log_registry.get_log_qnames():
log = logging.getLogger(log_qname)
log.setLevel(logging.WARNING)
log.propagate = False
_clear_handlers(log)
# reset all artifact and child logs
for artifact_log_qname in itertools.chain(
log_registry.get_artifact_log_qnames(), log_registry.get_child_log_qnames()
):
log = logging.getLogger(artifact_log_qname)
log.setLevel(logging.NOTSET)
log.propagate = True
def _get_log_state():
return log_state
def _set_log_state(state):
global log_state
log_state = state
def _init_logs(log_file_name=None):
_reset_logs()
_update_log_state_from_env()
for log_qname, level in log_state.get_log_level_pairs():
log = logging.getLogger(log_qname)
log.setLevel(level)
# setup handlers for all registered loggers
for log_qname in log_registry.get_log_qnames():
log = logging.getLogger(log_qname)
_setup_handlers(
logging.StreamHandler,
log,
)
if log_file_name is not None:
_setup_handlers(
lambda: logging.FileHandler(log_file_name),
log,
)
# configure artifact loggers, note: this must happen last
# since the levels of ancestor loggers are taken into account
for artifact_log_qname in log_registry.get_artifact_log_qnames():
log = logging.getLogger(artifact_log_qname)
configure_artifact_log(log)
@functools.lru_cache(None)
def warning_once(logger_obj, *args, **kwargs):
"""
This function is similar to `logger.warning()`, but will emit the warning with the same message only once
Note: The cache is for the function arguments, so 2 different callers using the same arguments will hit the cache.
The assumption here is that all warning messages are unique across the code. If they aren't then need to switch to
another type of cache that includes the caller frame information in the hashing function.
"""
logger_obj.warning(*args, **kwargs)


@ -0,0 +1,16 @@
from ._internal import register_artifact, register_log
register_log("dynamo", "torch._dynamo")
register_log("aot", "torch._functorch.aot_autograd")
register_log("inductor", "torch._inductor")
register_log("sym_shapes", "torch.fx.experimental.symbolic_shapes")
register_artifact("guards")
register_artifact("bytecode")
register_artifact("graph")
register_artifact("graph_code")
register_artifact("aot_forward_graph")
register_artifact("aot_backward_graph")
register_artifact("aot_joint_graph")
register_artifact("output_code", off_by_default=True)
register_artifact("schedule", off_by_default=True)


@ -0,0 +1,144 @@
import torch._dynamo.test_case
import unittest.mock
import os
import contextlib
import torch._logging
import torch._logging._internal
import logging
@contextlib.contextmanager
def preserve_log_state():
prev_state = torch._logging._internal._get_log_state()
torch._logging._internal._set_log_state(torch._logging._internal.LogState())
try:
yield
finally:
torch._logging._internal._set_log_state(prev_state)
torch._logging._internal._init_logs()
def log_settings(settings):
exit_stack = contextlib.ExitStack()
settings_patch = unittest.mock.patch.dict(os.environ, {"TORCH_LOGS": settings})
exit_stack.enter_context(preserve_log_state())
exit_stack.enter_context(settings_patch)
torch._logging._internal._init_logs()
return exit_stack
def log_api(**kwargs):
exit_stack = contextlib.ExitStack()
exit_stack.enter_context(preserve_log_state())
torch._logging.set_logs(**kwargs)
return exit_stack
def kwargs_to_settings(**kwargs):
INT_TO_VERBOSITY = {10: "+", 20: "", 40: "-"}
settings = []
for name, val in kwargs.items():
if isinstance(val, bool):
settings.append(name)
elif isinstance(val, int):
if val in INT_TO_VERBOSITY:
settings.append(INT_TO_VERBOSITY[val] + name)
else:
raise ValueError("Invalid value for setting")
return ",".join(settings)
# Note on testing strategy:
# This class does two things:
# 1. Runs two versions of a test:
# 1a. patches the env var log settings to some specific value
# 1b. calls torch._logging.set_logs(..)
# 2. patches the emit method of each setup handler to gather records
# that are emitted to each console stream
# 3. passes a ref to the gathered records to each test case for checking
#
# The goal of this testing in general is to ensure that given some settings env var
# that the logs are setup correctly and capturing the correct records.
def make_logging_test(**kwargs):
def wrapper(fn):
def test_fn(self):
torch._dynamo.reset()
records = []
# run with env var
with log_settings(kwargs_to_settings(**kwargs)), self._handler_watcher(records):
fn(self, records)
# run with API
torch._dynamo.reset()
records.clear()
with log_api(**kwargs), self._handler_watcher(records):
fn(self, records)
return test_fn
return wrapper
def make_settings_test(settings):
def wrapper(fn):
def test_fn(self):
torch._dynamo.reset()
records = []
# run with env var
with log_settings(settings), self._handler_watcher(records):
fn(self, records)
return test_fn
return wrapper
class LoggingTestCase(torch._dynamo.test_case.TestCase):
@classmethod
def setUpClass(cls):
super().setUpClass()
cls._exit_stack.enter_context(
unittest.mock.patch.dict(os.environ, {"___LOG_TESTING": ""})
)
cls._exit_stack.enter_context(
unittest.mock.patch("torch._dynamo.config.suppress_errors", True)
)
@classmethod
def tearDownClass(cls):
cls._exit_stack.close()
torch._logging._internal.log_state.clear()
torch._logging._init_logs()
# This patches the emit method of each handler to gather records
# as they are emitted
def _handler_watcher(self, record_list):
exit_stack = contextlib.ExitStack()
def emit_post_hook(record):
nonlocal record_list
record_list.append(record)
# registered logs are the only ones with handlers, so patch those
for log_qname in torch._logging._internal.log_registry.get_log_qnames():
logger = logging.getLogger(log_qname)
num_handlers = len(logger.handlers)
self.assertLessEqual(
num_handlers,
2,
"All pt2 loggers should only have at most two handlers (debug artifacts and messages above debug level).",
)
self.assertGreater(num_handlers, 0, "All pt2 loggers should have more than zero handlers")
for handler in logger.handlers:
old_emit = handler.emit
def new_emit(record):
old_emit(record)
emit_post_hook(record)
exit_stack.enter_context(
unittest.mock.patch.object(handler, "emit", new_emit)
)
return exit_stack