mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-21 05:34:18 +08:00
Pretty-print dataclasses (#76810)
Unfortunately the built-in pprint module support pretty-print of dataclasses only from python 3.10. The code that I wrote in method `__str__` of OpInfo should do the same job and should also work for any dataclass. For now I've put it there but we can create a function and put it somewhere where is accessible also for other dataclasses. Also the max width (80) is now hardcode but it would ideally be the parameter of the function. when you call print on an OpInfo you get: ``` OpInfo(name = '__getitem__', ref = None, aliases = (), variant_test_name = '', op = <slot wrapper '__getitem__' of 'torch._C._TensorBase' objects>, method_variant = <slot wrapper '__getitem__' of 'torch._C._TensorBase' objects>, inplace_variant = None, skips = (<torch.testing._internal.common_methods_invocations.DecorateInfo object at 0x7f463acbca90>, <torch.testing._internal.common_methods_invocations.DecorateInfo object at 0x7f463acbcae0>), decorators = (<torch.testing._internal.common_methods_invocations.DecorateInfo object at 0x7f463acbca90>, <torch.testing._internal.common_methods_invocations.DecorateInfo object at 0x7f463acbcae0>), sample_inputs_func = <function sample_inputs_getitem at 0x7f463acc6af0>, reference_inputs_func = None, error_inputs_func = None, sample_inputs_sparse_coo_func = <function _DecoratorContextManager.__call__.<locals>.decorate_context at 0x7f463acc6b80>, sample_inputs_sparse_csr_func = <function _DecoratorContextManager.__call__.<locals>.decorate_context at 0x7f463acc6c10>, dtypes = {torch.int16, torch.float64, torch.int32, torch.int64, torch.complex64, torch.float16, torch.bfloat16, torch.uint8, torch.complex128, torch.bool, torch.float32, torch.int8}, dtypesIfCUDA = {torch.int16, torch.float64, torch.int32, torch.int64, torch.complex64, torch.float16, torch.bfloat16, torch.uint8, torch.complex128, torch.bool, torch.float32, torch.int8}, dtypesIfROCM = {torch.int16, torch.float64, torch.int32, torch.int64, torch.complex64, torch.float16, torch.bfloat16, torch.uint8, torch.complex128, torch.bool, torch.float32, torch.int8}, backward_dtypes = {torch.int16, torch.float64, torch.int32, torch.int64, torch.complex64, torch.float16, torch.bfloat16, torch.uint8, torch.complex128, torch.bool, torch.float32, torch.int8}, backward_dtypesIfCUDA = {torch.int16, torch.float64, torch.int32, torch.int64, torch.complex64, torch.float16, torch.bfloat16, torch.uint8, torch.complex128, torch.bool, torch.float32, torch.int8}, backward_dtypesIfROCM = {torch.int16, torch.float64, torch.int32, torch.int64, torch.complex64, torch.float16, torch.bfloat16, torch.uint8, torch.complex128, torch.bool, torch.float32, torch.int8}, supports_out = False, supports_autograd = True, supports_gradgrad = True, supports_fwgrad_bwgrad = True, supports_inplace_autograd = False, supports_forward_ad = True, gradcheck_wrapper = <function OpInfo.<lambda> at 0x7f463a7a40d0>, check_batched_grad = True, check_batched_gradgrad = True, check_batched_forward_grad = True, check_inplace_batched_forward_grad = True, gradcheck_nondet_tol = 0.0, gradcheck_fast_mode = None, aten_name = '__getitem__', decomp_aten_name = None, aten_backward_name = None, assert_autodiffed = False, autodiff_nonfusible_nodes = ['aten::__getitem__'], autodiff_fusible_nodes = [], supports_sparse = False, supports_scripting = False, supports_sparse_csr = False, test_conjugated_samples = True, test_neg_view = True, assert_jit_shape_analysis = False, supports_expanded_weight = False) ``` cc @ezyang Pull Request resolved: https://github.com/pytorch/pytorch/pull/76810 Approved by: https://github.com/ezyang
This commit is contained in:
committed by
PyTorch MergeBot
parent
6fa20bdfe8
commit
dca416b578
@ -4,7 +4,12 @@ import hashlib
|
||||
import os
|
||||
import re
|
||||
import textwrap
|
||||
import sys
|
||||
from argparse import Namespace
|
||||
from dataclasses import (
|
||||
fields,
|
||||
is_dataclass,
|
||||
)
|
||||
from typing import (
|
||||
Tuple,
|
||||
List,
|
||||
@ -287,3 +292,107 @@ def make_file_manager(
|
||||
return FileManager(
|
||||
install_dir=install_dir, template_dir=template_dir, dry_run=options.dry_run
|
||||
)
|
||||
|
||||
|
||||
# Helper function to create a pretty representation for dataclasses
|
||||
def dataclass_repr(
|
||||
obj: Any,
|
||||
indent: int = 0,
|
||||
width: int = 80,
|
||||
) -> str:
|
||||
# built-in pprint module support dataclasses from python 3.10
|
||||
if sys.version_info >= (3, 10):
|
||||
from pprint import pformat
|
||||
|
||||
return pformat(obj, indent, width)
|
||||
|
||||
return _pformat(obj, indent=indent, width=width)
|
||||
|
||||
|
||||
def _pformat(
|
||||
obj: Any,
|
||||
indent: int,
|
||||
width: int,
|
||||
curr_indent: int = 0,
|
||||
) -> str:
|
||||
assert is_dataclass(obj), f"obj should be a dataclass, received: {type(obj)}"
|
||||
|
||||
class_name = obj.__class__.__name__
|
||||
# update current indentation level with class name
|
||||
curr_indent += len(class_name) + 1
|
||||
|
||||
fields_list = [(f.name, getattr(obj, f.name)) for f in fields(obj) if f.repr]
|
||||
|
||||
fields_str = []
|
||||
for name, attr in fields_list:
|
||||
# update the current indent level with the field name
|
||||
# dict, list, set and tuple also add indent as done in pprint
|
||||
_curr_indent = curr_indent + len(name) + 1
|
||||
if is_dataclass(attr):
|
||||
str_repr = _pformat(attr, indent, width, _curr_indent)
|
||||
elif isinstance(attr, dict):
|
||||
str_repr = _format_dict(attr, indent, width, _curr_indent)
|
||||
elif isinstance(attr, (list, set, tuple)):
|
||||
str_repr = _format_list(attr, indent, width, _curr_indent)
|
||||
else:
|
||||
str_repr = repr(attr)
|
||||
|
||||
fields_str.append(f"{name}={str_repr}")
|
||||
|
||||
indent_str = curr_indent * " "
|
||||
body = f",\n{indent_str}".join(fields_str)
|
||||
return f"{class_name}({body})"
|
||||
|
||||
|
||||
def _format_dict(
|
||||
attr: Dict[Any, Any],
|
||||
indent: int,
|
||||
width: int,
|
||||
curr_indent: int,
|
||||
) -> str:
|
||||
curr_indent += indent + 3
|
||||
dict_repr = []
|
||||
for k, v in attr.items():
|
||||
k_repr = repr(k)
|
||||
v_str = (
|
||||
_pformat(v, indent, width, curr_indent + len(k_repr))
|
||||
if is_dataclass(v)
|
||||
else repr(v)
|
||||
)
|
||||
dict_repr.append(f"{k_repr}: {v_str}")
|
||||
|
||||
return _format(dict_repr, indent, width, curr_indent, "{", "}")
|
||||
|
||||
|
||||
def _format_list(
|
||||
attr: Union[List[Any], Set[Any], Tuple[Any, ...]],
|
||||
indent: int,
|
||||
width: int,
|
||||
curr_indent: int,
|
||||
) -> str:
|
||||
curr_indent += indent + 1
|
||||
list_repr = [
|
||||
_pformat(l, indent, width, curr_indent) if is_dataclass(l) else repr(l)
|
||||
for l in attr
|
||||
]
|
||||
start, end = ("[", "]") if isinstance(attr, list) else ("(", ")")
|
||||
return _format(list_repr, indent, width, curr_indent, start, end)
|
||||
|
||||
|
||||
def _format(
|
||||
fields_str: List[str],
|
||||
indent: int,
|
||||
width: int,
|
||||
curr_indent: int,
|
||||
start: str,
|
||||
end: str,
|
||||
) -> str:
|
||||
delimiter, curr_indent_str = "", ""
|
||||
# if it exceed the max width then we place one element per line
|
||||
if len(repr(fields_str)) >= width:
|
||||
delimiter = "\n"
|
||||
curr_indent_str = " " * curr_indent
|
||||
|
||||
indent_str = " " * indent
|
||||
body = f", {delimiter}{curr_indent_str}".join(fields_str)
|
||||
return f"{start}{indent_str}{body}{end}"
|
||||
|
Reference in New Issue
Block a user