Files
pytorch/torch/testing/_internal/logging_tensor.py

169 lines
6.5 KiB
Python

# mypy: ignore-errors
import torch
from torch.utils._pytree import tree_map
from typing import Optional
from collections.abc import Iterator
import logging
import contextlib
import itertools
from torch.utils._dtype_abbrs import dtype_abbrs as _dtype_abbrs
from torch.utils._python_dispatch import TorchDispatchMode
from torch.utils.weak import WeakTensorKeyDictionary
import functools
from torch._C._profiler import gather_traceback, symbolize_tracebacks
logger = logging.getLogger("LoggingTensor")
# How the chain of calls works for LoggingTensor:
# 1. Call torch.sin
# 2. Attempt __torch_function__. In LoggingTensor, __torch_function__ is disabled, so we bypass it entirely
# 3. Enter dispatcher, wind your way through Autograd
# 4. Hit Python dispatch key, call __torch_dispatch__
# This Tensor can work with autograd in two ways:
# - The wrapped Tensor does not require gradients. In that case, the LoggingTensor
# can require gradients if the user asks for it as a constructor kwarg.
# - The wrapped Tensor can require gradients. In that case autograd will be tracked
# for the wrapped Tensor and the LoggingTensor itself cannot require gradients.
# WARNING: We allow these two possibilities for testing purposes. You should NEVER use both in a single
# test or you might get surprising behavior.
# TODO: TensorBase should work
class LoggingTensor(torch.Tensor):
    """Wrapper tensor subclass that logs every op reaching ``__torch_dispatch__``.

    The instance itself is created with ``_make_wrapper_subclass`` using the
    size, strides, storage offset, dtype, layout and device of ``elem``; the
    real tensor is kept in the ``elem`` slot and is what ops are run on.

    Each dispatched op is reported on the ``"LoggingTensor"`` logger at INFO
    level. The log message is ``"<func.__module__>.<func.__name__>"`` and the
    record args are ``(args, kwargs, results)``.
    """

    # The wrapped tensor that ops are actually executed on.
    elem: torch.Tensor

    __slots__ = ['elem']

    # Context manager factory entered around the inner op call in
    # __torch_dispatch__; subclasses can override it.
    context = contextlib.nullcontext

    @staticmethod
    def __new__(cls, elem, *args, **kwargs):
        """Create a wrapper around ``elem``.

        Args:
            elem: the tensor to wrap.
            **kwargs: only ``requires_grad`` (default ``False``) is read; it
                sets ``requires_grad`` on the wrapper.
        """
        # The wrapping tensor (LoggingTensor) shouldn't hold any
        # memory for the class in question, but it should still
        # advertise the same device as before
        r = torch.Tensor._make_wrapper_subclass(
            cls, elem.size(),
            strides=elem.stride(), storage_offset=elem.storage_offset(),
            # TODO: clone storage aliasing
            dtype=elem.dtype, layout=elem.layout,
            device=elem.device, requires_grad=kwargs.get("requires_grad", False)
        )
        # ...the real tensor is held as an element on the tensor.
        # When the wrapper requires grad, the wrapped tensor is detached.
        r.elem = elem.detach() if r.requires_grad else elem
        return r

    def __repr__(self):
        """Return the default tensor repr with the wrapped tensor as contents."""
        return super().__repr__(tensor_contents=f"{self.elem}")

    @classmethod
    def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
        """Run ``func`` on the unwrapped tensors, re-wrap the outputs and log the call.

        Args:
            func: the op being dispatched.
            types: unused here.
            args: positional arguments, possibly containing ``cls`` instances.
            kwargs: keyword arguments; ``None`` is treated as no keyword arguments.

        Returns:
            The result of ``func`` with every ``torch.Tensor`` wrapped in ``cls``.
        """
        # The signature defaults kwargs to None, but both the ``**`` expansion
        # below and the log handler (which iterates the kwargs mapping) need a
        # real dict.
        if kwargs is None:
            kwargs = {}

        def unwrap(e):
            return e.elem if isinstance(e, cls) else e

        def wrap(e):
            return cls(e) if isinstance(e, torch.Tensor) else e

        with cls.context():
            rs = tree_map(wrap, func(*tree_map(unwrap, args), **tree_map(unwrap, kwargs)))
        # The still-wrapped args/kwargs are logged, not the unwrapped ones.
        logging.getLogger("LoggingTensor").info(f"{func.__module__}.{func.__name__}", args, kwargs, rs)  # noqa: G004
        return rs
class LoggingTensorMode(TorchDispatchMode):
    """Dispatch mode that runs each op unchanged and logs the call.

    Every op is reported on the ``"LoggingTensor"`` logger at INFO level with
    ``"<func.__module__>.<func.__name__>"`` as the message and
    ``(args, kwargs, result)`` as the record args.
    """

    def __torch_dispatch__(self, func, types, args=(), kwargs=None):
        """Call ``func(*args, **kwargs)``, log it, and return its result."""
        call_kwargs = {} if kwargs is None else kwargs
        result = func(*args, **call_kwargs)
        op_name = f"{func.__module__}.{func.__name__}"
        op_logger = logging.getLogger("LoggingTensor")
        op_logger.info(op_name, args, call_kwargs, result)
        return result
class LoggingTensorReentrant(LoggingTensor):
    """LoggingTensor variant whose ``context`` is ``torch.overrides.enable_reentrant_dispatch``."""

    # Replaces the inherited ``context`` class attribute (``contextlib.nullcontext``
    # on LoggingTensor), which ``__torch_dispatch__`` enters around the inner op call.
    context = torch.overrides.enable_reentrant_dispatch
# https://stackoverflow.com/questions/36408496/python-logging-handler-to-append-to-list
class LoggingTensorHandler(logging.Handler):
def __init__(
self, log_list: list[str], use_shortid_for_all_tensors: bool,
with_type: bool, tracebacks_list: Optional[list]) -> None:
logging.Handler.__init__(self)
self.log_list = log_list
self.use_shortid_for_all_tensors = use_shortid_for_all_tensors
self.tracebacks_list = tracebacks_list
self.memo = WeakTensorKeyDictionary()
self.next_id = 0
self.with_type = with_type
def _shortid(self, t: torch.Tensor) -> int:
if t not in self.memo:
self.memo[t] = self.next_id
self.next_id += 1
return self.memo[t]
def _fmt(self, a: object, with_type: bool = False) -> str:
cond_cls = torch.Tensor if self.use_shortid_for_all_tensors else LoggingTensor
if isinstance(a, cond_cls):
maybe_type = ""
if with_type and self.with_type:
maybe_type = f": {_dtype_abbrs[a.dtype]}[{', '.join(map(str, a.shape))}]"
x = f"${self._shortid(a)}{maybe_type}"
return x
else:
return repr(a)
def emit(self, record):
fmt_args = ", ".join(
itertools.chain(
(str(tree_map(self._fmt, a)) for a in record.args[0]),
(f"{k}={str(tree_map(self._fmt, v))}" for k, v in record.args[1].items()),
)
)
fmt_rets = tree_map(functools.partial(self._fmt, with_type=True), record.args[2])
self.log_list.append(f'{fmt_rets} = {record.msg}({fmt_args})')
if self.tracebacks_list is not None:
self.tracebacks_list.append(record.traceback)
def log_input(name: str, var: object) -> None:
    """Log ``var`` as an ``"input"`` record, with ``name`` as its only positional argument."""
    positional = (name,)
    keyword: dict = {}
    # Same (args, kwargs, results) record layout as the op-call log records.
    logger.info("input", positional, keyword, var)  # noqa: PLE1205
class GatherTraceback(logging.Filter):
def __init__(self, python=True, script=True, cpp=False):
self.python = python
self.script = script
self.cpp = cpp
def filter(self, record):
record.traceback = gather_traceback(python=self.python, script=self.script, cpp=self.cpp)
return True
@contextlib.contextmanager
def capture_logs(is_mode=False, python_tb=False, script_tb=False, cpp_tb=False) -> Iterator[list[str]]:
    """Collect formatted op logs (and optionally tracebacks) for the duration of the ``with`` block.

    Args:
        is_mode: passed to the handler as ``use_shortid_for_all_tensors``.
        python_tb: request Python tracebacks for each logged record.
        script_tb: request script tracebacks for each logged record.
        cpp_tb: request C++ tracebacks for each logged record.

    Yields:
        ``log_list`` when no traceback kind is requested, otherwise the tuple
        ``(log_list, tracebacks_list)``. On exit, the collected tracebacks are
        replaced in place by the output of ``symbolize_tracebacks``.

    Note:
        The module logger is set to INFO with ``propagate = False``; these two
        settings are not restored on exit.
    """
    collect_traceback = python_tb or script_tb or cpp_tb
    log_list: list[str] = []
    tracebacks_list: list[str] = []
    handler = LoggingTensorHandler(
        log_list,
        with_type=True,
        use_shortid_for_all_tensors=is_mode,
        tracebacks_list=tracebacks_list if collect_traceback else None
    )
    # Keep a reference to the filter so it can be removed again on exit.
    traceback_filter = None
    logger.addHandler(handler)
    logger.setLevel(logging.INFO)
    logger.propagate = False
    if collect_traceback:
        traceback_filter = GatherTraceback(python=python_tb, script=script_tb, cpp=cpp_tb)
        logger.addFilter(traceback_filter)
    try:
        if collect_traceback:
            yield log_list, tracebacks_list
        else:
            yield log_list
    finally:
        # Detach from the shared logger first, so that neither the handler nor
        # the traceback filter outlives this context, even if symbolization
        # below raises.
        logger.removeHandler(handler)
        if traceback_filter is not None:
            logger.removeFilter(traceback_filter)
        # Mutate the list in place: the caller holds a reference to it.
        symbolized_tracebacks = symbolize_tracebacks(tracebacks_list)
        tracebacks_list.clear()
        tracebacks_list.extend(symbolized_tracebacks)
@contextlib.contextmanager
def capture_logs_with_logging_tensor_mode(python_tb=False, script_tb=False, cpp_tb=False):
    """Enable ``LoggingTensorMode`` and capture its logs for the duration of the ``with`` block.

    Yields whatever ``capture_logs(True, python_tb, script_tb, cpp_tb)`` yields.
    """
    # The mode is entered first and therefore exited last, as in a single
    # combined ``with`` statement.
    with LoggingTensorMode():
        with capture_logs(True, python_tb, script_tb, cpp_tb) as captured:
            yield captured