Pretty-print dataclasses (#76810)

Unfortunately the built-in pprint module supports pretty-printing of dataclasses only from Python 3.10. The code that I wrote in the `__str__` method of OpInfo does the same job and should work for any dataclass. For now I've put it there, but we could extract it into a function and put it somewhere accessible to other dataclasses as well. Also, the max width (80) is currently hardcoded, but ideally it would be a parameter of the function.
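
For comparison, this is roughly what the standard library already gives you on Python >= 3.10 (the `Point`/`Segment` dataclasses here are toy examples, just for illustration):

```
from dataclasses import dataclass
from pprint import pformat

@dataclass
class Point:
    x: int
    y: int

@dataclass
class Segment:
    start: Point
    end: Point

# Only on Python >= 3.10 does pformat expand dataclasses field by field;
# earlier versions fall back to the plain one-line repr.
print(pformat(Segment(Point(0, 0), Point(10, 10)), width=30))
```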

When you call `print` on an `OpInfo`, you get:
```
OpInfo(name = '__getitem__',
       ref = None,
       aliases = (),
       variant_test_name = '',
       op = <slot wrapper '__getitem__' of 'torch._C._TensorBase' objects>,
       method_variant = <slot wrapper '__getitem__' of 'torch._C._TensorBase' objects>,
       inplace_variant = None,
       skips = (<torch.testing._internal.common_methods_invocations.DecorateInfo object at 0x7f463acbca90>,
                <torch.testing._internal.common_methods_invocations.DecorateInfo object at 0x7f463acbcae0>),
       decorators = (<torch.testing._internal.common_methods_invocations.DecorateInfo object at 0x7f463acbca90>,
                     <torch.testing._internal.common_methods_invocations.DecorateInfo object at 0x7f463acbcae0>),
       sample_inputs_func = <function sample_inputs_getitem at 0x7f463acc6af0>,
       reference_inputs_func = None,
       error_inputs_func = None,
       sample_inputs_sparse_coo_func = <function _DecoratorContextManager.__call__.<locals>.decorate_context at 0x7f463acc6b80>,
       sample_inputs_sparse_csr_func = <function _DecoratorContextManager.__call__.<locals>.decorate_context at 0x7f463acc6c10>,
       dtypes = {torch.int16,
                 torch.float64,
                 torch.int32,
                 torch.int64,
                 torch.complex64,
                 torch.float16,
                 torch.bfloat16,
                 torch.uint8,
                 torch.complex128,
                 torch.bool,
                 torch.float32,
                 torch.int8},
       dtypesIfCUDA = {torch.int16,
                       torch.float64,
                       torch.int32,
                       torch.int64,
                       torch.complex64,
                       torch.float16,
                       torch.bfloat16,
                       torch.uint8,
                       torch.complex128,
                       torch.bool,
                       torch.float32,
                       torch.int8},
       dtypesIfROCM = {torch.int16,
                       torch.float64,
                       torch.int32,
                       torch.int64,
                       torch.complex64,
                       torch.float16,
                       torch.bfloat16,
                       torch.uint8,
                       torch.complex128,
                       torch.bool,
                       torch.float32,
                       torch.int8},
       backward_dtypes = {torch.int16,
                          torch.float64,
                          torch.int32,
                          torch.int64,
                          torch.complex64,
                          torch.float16,
                          torch.bfloat16,
                          torch.uint8,
                          torch.complex128,
                          torch.bool,
                          torch.float32,
                          torch.int8},
       backward_dtypesIfCUDA = {torch.int16,
                                torch.float64,
                                torch.int32,
                                torch.int64,
                                torch.complex64,
                                torch.float16,
                                torch.bfloat16,
                                torch.uint8,
                                torch.complex128,
                                torch.bool,
                                torch.float32,
                                torch.int8},
       backward_dtypesIfROCM = {torch.int16,
                                torch.float64,
                                torch.int32,
                                torch.int64,
                                torch.complex64,
                                torch.float16,
                                torch.bfloat16,
                                torch.uint8,
                                torch.complex128,
                                torch.bool,
                                torch.float32,
                                torch.int8},
       supports_out = False,
       supports_autograd = True,
       supports_gradgrad = True,
       supports_fwgrad_bwgrad = True,
       supports_inplace_autograd = False,
       supports_forward_ad = True,
       gradcheck_wrapper = <function OpInfo.<lambda> at 0x7f463a7a40d0>,
       check_batched_grad = True,
       check_batched_gradgrad = True,
       check_batched_forward_grad = True,
       check_inplace_batched_forward_grad = True,
       gradcheck_nondet_tol = 0.0,
       gradcheck_fast_mode = None,
       aten_name = '__getitem__',
       decomp_aten_name = None,
       aten_backward_name = None,
       assert_autodiffed = False,
       autodiff_nonfusible_nodes = ['aten::__getitem__'],
       autodiff_fusible_nodes = [],
       supports_sparse = False,
       supports_scripting = False,
       supports_sparse_csr = False,
       test_conjugated_samples = True,
       test_neg_view = True,
       assert_jit_shape_analysis = False,
       supports_expanded_weight = False)
```
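
The wiring on the `OpInfo` side is just delegation to the helper; this is a sketch of the idea, not the exact code from the PR:

```
class OpInfo:
    ...
    def __str__(self) -> str:
        # delegate pretty-printing to the shared helper
        # (80 is the hardcoded width mentioned above)
        return dataclass_repr(self, width=80)
```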

cc @ezyang

Pull Request resolved: https://github.com/pytorch/pytorch/pull/76810
Approved by: https://github.com/ezyang

@@ -4,7 +4,12 @@ import hashlib
import os
import re
import textwrap
import sys
from argparse import Namespace
from dataclasses import (
    fields,
    is_dataclass,
)
from typing import (
    Tuple,
    List,
@@ -287,3 +292,107 @@ def make_file_manager(
    return FileManager(
        install_dir=install_dir, template_dir=template_dir, dry_run=options.dry_run
    )


# Helper function to create a pretty representation for dataclasses
def dataclass_repr(
    obj: Any,
    indent: int = 0,
    width: int = 80,
) -> str:
    # built-in pprint module supports dataclasses from Python 3.10
    if sys.version_info >= (3, 10):
        from pprint import pformat

        return pformat(obj, indent, width)

    return _pformat(obj, indent=indent, width=width)


def _pformat(
    obj: Any,
    indent: int,
    width: int,
    curr_indent: int = 0,
) -> str:
    assert is_dataclass(obj), f"obj should be a dataclass, received: {type(obj)}"
    class_name = obj.__class__.__name__
    # update current indentation level with class name
    curr_indent += len(class_name) + 1

    fields_list = [(f.name, getattr(obj, f.name)) for f in fields(obj) if f.repr]

    fields_str = []
    for name, attr in fields_list:
        # update the current indent level with the field name
        # dict, list, set and tuple also add indent as done in pprint
        _curr_indent = curr_indent + len(name) + 1
        if is_dataclass(attr):
            str_repr = _pformat(attr, indent, width, _curr_indent)
        elif isinstance(attr, dict):
            str_repr = _format_dict(attr, indent, width, _curr_indent)
        elif isinstance(attr, (list, set, tuple)):
            str_repr = _format_list(attr, indent, width, _curr_indent)
        else:
            str_repr = repr(attr)

        fields_str.append(f"{name}={str_repr}")

    indent_str = curr_indent * " "
    body = f",\n{indent_str}".join(fields_str)
    return f"{class_name}({body})"


def _format_dict(
    attr: Dict[Any, Any],
    indent: int,
    width: int,
    curr_indent: int,
) -> str:
    curr_indent += indent + 3
    dict_repr = []
    for k, v in attr.items():
        k_repr = repr(k)
        v_str = (
            _pformat(v, indent, width, curr_indent + len(k_repr))
            if is_dataclass(v)
            else repr(v)
        )
        dict_repr.append(f"{k_repr}: {v_str}")

    return _format(dict_repr, indent, width, curr_indent, "{", "}")


def _format_list(
    attr: Union[List[Any], Set[Any], Tuple[Any, ...]],
    indent: int,
    width: int,
    curr_indent: int,
) -> str:
    curr_indent += indent + 1
    list_repr = [
        _pformat(l, indent, width, curr_indent) if is_dataclass(l) else repr(l)
        for l in attr
    ]
    start, end = ("[", "]") if isinstance(attr, list) else ("(", ")")
    return _format(list_repr, indent, width, curr_indent, start, end)


def _format(
    fields_str: List[str],
    indent: int,
    width: int,
    curr_indent: int,
    start: str,
    end: str,
) -> str:
    delimiter, curr_indent_str = "", ""
    # if it exceeds the max width then we place one element per line
    if len(repr(fields_str)) >= width:
        delimiter = "\n"
        curr_indent_str = " " * curr_indent

    indent_str = " " * indent
    body = f", {delimiter}{curr_indent_str}".join(fields_str)
    return f"{start}{indent_str}{body}{end}"