LoggingTensor fails consistently when root logger level is INFO or lower.

By default, the root logger should be at WARNING, but Triton driver initialization overwrites the root logger level to INFO, which causes flakiness: https://github.com/pytorch/pytorch/issues/126143

Pull Request resolved: https://github.com/pytorch/pytorch/pull/126144
Approved by: https://github.com/jansel
# mypy: ignore-errors

import contextlib
import functools
import itertools
import logging
from typing import Iterator, List, Optional

import torch
from torch._C._profiler import gather_traceback, symbolize_tracebacks
from torch.utils._python_dispatch import TorchDispatchMode
from torch.utils._pytree import tree_map
from torch.utils.weak import WeakTensorKeyDictionary


logger = logging.getLogger("LoggingTensor")

_dtype_abbrs = {
    torch.bfloat16: "bf16",
    torch.float64: "f64",
    torch.float32: "f32",
    torch.float16: "f16",
    torch.complex32: "c32",
    torch.complex64: "c64",
    torch.complex128: "c128",
    torch.int8: "i8",
    torch.int16: "i16",
    torch.int32: "i32",
    torch.int64: "i64",
    torch.bool: "b8",
    torch.uint8: "u8",
}

# How the chain of calls works for LoggingTensor:
# 1. Call torch.sin
# 2. Attempt __torch_function__. In LoggingTensor torch function is disabled so we bypass it entirely
# 3. Enter dispatcher, wind your way through Autograd
# 4. Hit Python dispatch key, call __torch_dispatch__
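#
# A rough sketch of what that looks like from user code (illustrative only; the
# exact operator names captured depend on the op and overload that is hit):
#
#   x = LoggingTensor(torch.tensor([1.0, 2.0]))
#   with capture_logs() as logs:
#       torch.sin(x)
#   # logs then holds lines roughly of the form:
#   #   $1: f32[2] = torch._ops.aten.sin.default($0)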

# This Tensor can work with autograd in two ways:
# - The wrapped Tensor does not require gradients. In that case, the LoggingTensor
#   can require gradients if the user asks for it as a constructor kwarg.
# - The wrapped Tensor can require gradients. In that case autograd will be tracked
#   for the wrapped Tensor and the LoggingTensor itself cannot require gradients.
# WARNING: We allow these two possibilities for testing purposes. You should NEVER use both in a single
# test or you might get surprising behavior.
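#
# Concretely, a sketch of the two modes (not meant to be mixed):
#
#   LoggingTensor(torch.randn(2), requires_grad=True)   # grad tracked on the wrapper
#   LoggingTensor(torch.randn(2, requires_grad=True))   # grad tracked on the wrapped tensor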

# TODO: TensorBase should work
class LoggingTensor(torch.Tensor):
    elem: torch.Tensor

    __slots__ = ['elem']

    context = contextlib.nullcontext

    @staticmethod
    def __new__(cls, elem, *args, **kwargs):
        # The wrapping tensor (LoggingTensor) shouldn't hold any
        # memory for the class in question, but it should still
        # advertise the same device as before
        r = torch.Tensor._make_wrapper_subclass(  # type: ignore[attr-defined]
            cls, elem.size(),
            strides=elem.stride(), storage_offset=elem.storage_offset(),
            # TODO: clone storage aliasing
            dtype=elem.dtype, layout=elem.layout,
            device=elem.device, requires_grad=kwargs.get("requires_grad", False)
        )
        # ...the real tensor is held as an element on the tensor.
        r.elem = elem.detach() if r.requires_grad else elem
        return r

    def __repr__(self):
        return super().__repr__(tensor_contents=f"{self.elem}")

    @classmethod
    def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
        def unwrap(e):
            return e.elem if isinstance(e, cls) else e

        def wrap(e):
            return cls(e) if isinstance(e, torch.Tensor) else e

        with cls.context():
            rs = tree_map(wrap, func(*tree_map(unwrap, args), **tree_map(unwrap, kwargs)))
        logging.getLogger("LoggingTensor").info(f"{func.__module__}.{func.__name__}", args, kwargs, rs)  # noqa: G004
        return rs


class LoggingTensorMode(TorchDispatchMode):
    def __torch_dispatch__(self, func, types, args=(), kwargs=None):
        if kwargs is None:
            kwargs = {}
        rs = func(*args, **kwargs)
        logging.getLogger("LoggingTensor").info(f"{func.__module__}.{func.__name__}", args, kwargs, rs)  # noqa: G004
        return rs
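
# A sketch of using the mode to log ops on plain tensors (illustrative only;
# exactly which aten ops get recorded depends on how the calls decompose):
#
#   with LoggingTensorMode(), capture_logs(is_mode=True) as logs:
#       torch.ones(2) + torch.ones(2)
#   # every dispatched op, including the factory calls, ends up in logs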


class LoggingTensorReentrant(LoggingTensor):
    context = torch.overrides.enable_reentrant_dispatch


# https://stackoverflow.com/questions/36408496/python-logging-handler-to-append-to-list
class LoggingTensorHandler(logging.Handler):
    def __init__(
            self, log_list: List[str], use_shortid_for_all_tensors: bool,
            with_type: bool, tracebacks_list: Optional[List]) -> None:
        logging.Handler.__init__(self)
        self.log_list = log_list
        self.use_shortid_for_all_tensors = use_shortid_for_all_tensors
        self.tracebacks_list = tracebacks_list
        self.memo = WeakTensorKeyDictionary()
        self.next_id = 0
        self.with_type = with_type

    def _shortid(self, t: torch.Tensor) -> int:
        if t not in self.memo:
            self.memo[t] = self.next_id
            self.next_id += 1
        return self.memo[t]

    def _fmt(self, a: object, with_type: bool = False) -> str:
        cond_cls = torch.Tensor if self.use_shortid_for_all_tensors else LoggingTensor
        if isinstance(a, cond_cls):
            maybe_type = ""
            if with_type and self.with_type:
                maybe_type = f": {_dtype_abbrs[a.dtype]}[{', '.join(map(str, a.shape))}]"
            x = f"${self._shortid(a)}{maybe_type}"
            return x
        else:
            return repr(a)

    def emit(self, record):
        fmt_args = ", ".join(
            itertools.chain(
                (str(tree_map(self._fmt, a)) for a in record.args[0]),
                (f"{k}={str(tree_map(self._fmt, v))}" for k, v in record.args[1].items()),
            )
        )
        fmt_rets = tree_map(functools.partial(self._fmt, with_type=True), record.args[2])
        self.log_list.append(f'{fmt_rets} = {record.msg}({fmt_args})')
        if self.tracebacks_list is not None:
            self.tracebacks_list.append(record.traceback)
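
# Each record is rendered into a single line of the form "<rets> = <op>(<args>)",
# for example roughly (illustrative):
#   $2: f32[2, 2] = torch._ops.aten.mul.Tensor($0, $1)
# where "$n" is the handler's short id for a tensor and "f32[2, 2]" is the
# dtype/shape abbreviation built from _dtype_abbrs.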


def log_input(name: str, var: object) -> None:
    logger.info("input", (name,), {}, var)  # noqa: PLE1205


class GatherTraceback(logging.Filter):
    def __init__(self, python=True, script=True, cpp=False):
        self.python = python
        self.script = script
        self.cpp = cpp

    def filter(self, record):
        record.traceback = gather_traceback(python=self.python, script=self.script, cpp=self.cpp)
        return True


@contextlib.contextmanager
def capture_logs(is_mode=False, python_tb=False, script_tb=False, cpp_tb=False) -> Iterator[List[str]]:
    collect_traceback = python_tb or script_tb or cpp_tb
    log_list: List[str] = []
    tracebacks_list: List[str] = []
    handler = LoggingTensorHandler(
        log_list,
        with_type=True,
        use_shortid_for_all_tensors=is_mode,
        tracebacks_list=tracebacks_list if collect_traceback else None
    )
    logger.addHandler(handler)
    logger.setLevel(logging.INFO)
    logger.propagate = False
    if collect_traceback:
        logger.addFilter(GatherTraceback(python=python_tb, script=script_tb, cpp=cpp_tb))
    try:
        if collect_traceback:
            yield log_list, tracebacks_list
        else:
            yield log_list
    finally:
        symbolized_tracebacks = symbolize_tracebacks(tracebacks_list)
        tracebacks_list.clear()
        tracebacks_list.extend(symbolized_tracebacks)
        logger.removeHandler(handler)
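
# Note that when any of python_tb/script_tb/cpp_tb is set, the context manager
# yields a (log_list, tracebacks_list) pair rather than just the log list.
# Sketch (illustrative):
#
#   with capture_logs(python_tb=True) as (logs, tbs):
#       LoggingTensor(torch.ones(1)) + 1
#   # tbs[i] is the symbolized traceback captured for logs[i]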


@contextlib.contextmanager
def capture_logs_with_logging_tensor_mode(python_tb=False, script_tb=False, cpp_tb=False):
    with LoggingTensorMode(), capture_logs(True, python_tb, script_tb, cpp_tb) as logs:
        yield logs
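

# Minimal end-to-end sketch (illustrative; assumes torch is importable and the
# ops run on the default device). Running this module directly prints the
# captured log lines for a couple of plain-tensor ops.
if __name__ == "__main__":
    with capture_logs_with_logging_tensor_mode() as demo_logs:
        x = torch.ones(2, 2)
        log_input("x", x)
        (x + x).sum()
    print("\n".join(demo_logs))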