From 27f7b65a6903e960481ce6e44bcd19620165759b Mon Sep 17 00:00:00 2001 From: Xuehai Pan Date: Sat, 17 May 2025 14:29:21 +0800 Subject: [PATCH] [BE] Ensure generated stub files by `gen_pyi` are properly formatted (#150730) Pull Request resolved: https://github.com/pytorch/pytorch/pull/150730 Approved by: https://github.com/aorenste --- tools/pyi/gen_pyi.py | 50 +++++++++++++++++++++++++++++---------- torch/_tensor_docs.py | 18 +++++++++----- torch/_torch_docs.py | 12 ++++++---- torchgen/api/python.py | 24 ++++++++++++++----- torchgen/code_template.py | 11 ++++++--- 5 files changed, 84 insertions(+), 31 deletions(-) diff --git a/tools/pyi/gen_pyi.py b/tools/pyi/gen_pyi.py index ed224be7eecd..e292012aab4f 100644 --- a/tools/pyi/gen_pyi.py +++ b/tools/pyi/gen_pyi.py @@ -27,7 +27,9 @@ from __future__ import annotations import argparse import collections import importlib +import inspect import sys +import textwrap from typing import TYPE_CHECKING from unittest.mock import Mock, patch from warnings import warn @@ -132,7 +134,7 @@ _leaf_types = ( "_bool | _int | slice | EllipsisType | Tensor | None" # not SupportsIndex! ) _index_types = f"SupportsIndex | {_leaf_types} | _NestedSequence[{_leaf_types}]" -_index_type_def = f"_Index: TypeAlias = {_index_types}" +_index_type_def = f"_Index: TypeAlias = {_index_types} # fmt: skip" INDICES = "indices: _Index | tuple[_Index, ...]" blocklist = [ @@ -252,6 +254,11 @@ def sig_for_ops(opname: str) -> list[str]: f"def {opname}(self, other: Tensor | Number | _complex) -> Tensor: ... # type: ignore[has-type]" ] elif name in arithmetic_ops: + if name.startswith("i"): + # In-place binary-operation dunder methods, like `__iadd__`, should return `Self` + return [ + f"def {opname}(self, other: Tensor | Number | _complex) -> Tensor: ... # noqa: PYI034" + ] return [f"def {opname}(self, other: Tensor | Number | _complex) -> Tensor: ..."] elif name in logic_ops: return [f"def {opname}(self, other: Tensor | _bool) -> Tensor: ..."] @@ -327,14 +334,29 @@ def get_max_pool_dispatch(name: str, arg_list: list[str]) -> dict[str, list[str] arg_list_keyword.insert(flag_pos, "*") return { name: [ - defs(name, arg_list, "Tensor").format( - return_indices="return_indices: Literal[False] = False", + defs( + name, + [ + arg.format(return_indices="return_indices: Literal[False] = False") + for arg in arg_list + ], + "Tensor", ), - defs(name, arg_list_positional, "tuple[Tensor, Tensor]").format( - return_indices="return_indices: Literal[True]", + defs( + name, + [ + arg.format(return_indices="return_indices: Literal[True]") + for arg in arg_list_positional + ], + "tuple[Tensor, Tensor]", ), - defs(name, arg_list_keyword, "tuple[Tensor, Tensor]").format( - return_indices="return_indices: Literal[True]", + defs( + name, + [ + arg.format(return_indices="return_indices: Literal[True]") + for arg in arg_list_keyword + ], + "tuple[Tensor, Tensor]", ), ] } @@ -669,12 +691,16 @@ def gather_docstrs() -> dict[str, str]: def add_docstr_to_hint(docstr: str, hint: str) -> str: + docstr = inspect.cleandoc(docstr).strip() if "..." in hint: # function or method assert hint.endswith("..."), f"Hint `{hint}` does not end with '...'" - hint = hint[:-3] # remove "..." - return "\n ".join([hint, 'r"""'] + docstr.split("\n") + ['"""', "..."]) - else: # attribute or property - return f'{hint}\nr"""{docstr}"""\n' + hint = hint.removesuffix("...").rstrip() # remove "..." + content = hint + "\n" + textwrap.indent(f'r"""\n{docstr}\n"""', prefix=" ") + # Remove trailing whitespace on each line + return "\n".join(map(str.rstrip, content.splitlines())).rstrip() + + # attribute or property + return f'{hint}\nr"""{docstr}"""' def gen_pyi( @@ -1557,7 +1583,7 @@ def gen_pyi( # Generate type signatures for legacy classes # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ - legacy_storage_base_hints = ["class StorageBase(object): ..."] + legacy_storage_base_hints = ["class StorageBase: ..."] legacy_class_hints = [] for c in ( diff --git a/torch/_tensor_docs.py b/torch/_tensor_docs.py index 2a4d684ba858..bee7f7385fb0 100644 --- a/torch/_tensor_docs.py +++ b/torch/_tensor_docs.py @@ -3163,7 +3163,10 @@ Args: Example: >>> self = torch.tensor([[0, 0, 0, 0, 0], [0, 0, 0, 0, 0]]) - >>> mask = torch.tensor([[0, 0, 0, 1, 1], [1, 1, 0, 1, 1]], dtype=torch.bool) + >>> mask = torch.tensor( + ... [[0, 0, 0, 1, 1], [1, 1, 0, 1, 1]], + ... dtype=torch.bool, + ... ) >>> source = torch.tensor([[0, 1, 2, 3, 4], [5, 6, 7, 8, 9]]) >>> self.masked_scatter_(mask, source) tensor([[0, 0, 0, 0, 1], @@ -3645,7 +3648,7 @@ Example: # Example 1: Padding >>> input_tensor = torch.tensor([[1, 0], [3, 2]]) >>> static_size = 4 - >>> t = torch.nonzero_static(input_tensor, size = static_size) + >>> t = torch.nonzero_static(input_tensor, size=static_size) tensor([[ 0, 0], [ 1, 0], [ 1, 1], @@ -3654,20 +3657,20 @@ Example: # Example 2: Truncating >>> input_tensor = torch.tensor([[1, 0], [3, 2]]) >>> static_size = 2 - >>> t = torch.nonzero_static(input_tensor, size = static_size) + >>> t = torch.nonzero_static(input_tensor, size=static_size) tensor([[ 0, 0], [ 1, 0]], dtype=torch.int64) # Example 3: 0 size >>> input_tensor = torch.tensor([10]) >>> static_size = 0 - >>> t = torch.nonzero_static(input_tensor, size = static_size) + >>> t = torch.nonzero_static(input_tensor, size=static_size) tensor([], size=(0, 1), dtype=torch.int64) # Example 4: 0 rank input >>> input_tensor = torch.tensor(10) >>> static_size = 2 - >>> t = torch.nonzero_static(input_tensor, size = static_size) + >>> t = torch.nonzero_static(input_tensor, size=static_size) tensor([], size=(2, 0), dtype=torch.int64) """, ) @@ -6561,7 +6564,10 @@ Out-of-place version of :meth:`torch.Tensor.masked_scatter_` Example: >>> self = torch.tensor([0, 0, 0, 0, 0]) - >>> mask = torch.tensor([[0, 0, 0, 1, 1], [1, 1, 0, 1, 1]], dtype=torch.bool) + >>> mask = torch.tensor( + ... [[0, 0, 0, 1, 1], [1, 1, 0, 1, 1]], + ... dtype=torch.bool, + ... ) >>> source = torch.tensor([[0, 1, 2, 3, 4], [5, 6, 7, 8, 9]]) >>> self.masked_scatter(mask, source) tensor([[0, 0, 0, 0, 1], diff --git a/torch/_torch_docs.py b/torch/_torch_docs.py index 50ca81a782a9..417f2ea0b16e 100644 --- a/torch/_torch_docs.py +++ b/torch/_torch_docs.py @@ -10550,7 +10550,8 @@ Example: ... [[ 0.2035, 1.2959, 1.8101, -0.4644], ... [ 1.5027, -0.3270, 0.5905, 0.6538], ... [-1.5745, 1.3330, -0.5596, -0.6548], - ... [ 0.1264, -0.5080, 1.6420, 0.1992]]) + ... [ 0.1264, -0.5080, 1.6420, 0.1992]] + ... ) # fmt: skip >>> torch.std(a, dim=1, keepdim=True) tensor([[1.0311], [0.7477], @@ -10608,7 +10609,8 @@ Example: ... [[ 0.2035, 1.2959, 1.8101, -0.4644], ... [ 1.5027, -0.3270, 0.5905, 0.6538], ... [-1.5745, 1.3330, -0.5596, -0.6548], - ... [ 0.1264, -0.5080, 1.6420, 0.1992]]) + ... [ 0.1264, -0.5080, 1.6420, 0.1992]] + ... ) # fmt: skip >>> torch.std_mean(a, dim=0, keepdim=True) (tensor([[1.2620, 1.0028, 1.0957, 0.6038]]), tensor([[ 0.0645, 0.4485, 0.8707, -0.0665]])) @@ -11896,7 +11898,8 @@ Example: ... [[ 0.2035, 1.2959, 1.8101, -0.4644], ... [ 1.5027, -0.3270, 0.5905, 0.6538], ... [-1.5745, 1.3330, -0.5596, -0.6548], - ... [ 0.1264, -0.5080, 1.6420, 0.1992]]) + ... [ 0.1264, -0.5080, 1.6420, 0.1992]] + ... ) # fmt: skip >>> torch.var(a, dim=1, keepdim=True) tensor([[1.0631], [0.5590], @@ -11953,7 +11956,8 @@ Example: ... [[ 0.2035, 1.2959, 1.8101, -0.4644], ... [ 1.5027, -0.3270, 0.5905, 0.6538], ... [-1.5745, 1.3330, -0.5596, -0.6548], - ... [ 0.1264, -0.5080, 1.6420, 0.1992]]) + ... [ 0.1264, -0.5080, 1.6420, 0.1992]] + ... ) # fmt: skip >>> torch.var_mean(a, dim=0, keepdim=True) (tensor([[1.5926, 1.0056, 1.2005, 0.3646]]), tensor([[ 0.0645, 0.4485, 0.8707, -0.0665]])) diff --git a/torchgen/api/python.py b/torchgen/api/python.py index 29cc52d3a3c9..de6134bee6b8 100644 --- a/torchgen/api/python.py +++ b/torchgen/api/python.py @@ -212,13 +212,17 @@ def format_function_signature( if len(sig) <= 80 or len(arguments) == 0 or tuple(arguments) == ("self",): return sig + arguments = [f" {arg}," for arg in arguments] return "\n".join( ( f"def {name}(", - *(f" {arg}," for arg in arguments), + *( + arg if len(arg) <= 80 else f" # fmt: off\n{arg}\n # fmt: on" + for arg in arguments + ), f"){return_type}: ...", ) - ) + ).replace(" # fmt: off\n # fmt: on\n", "") @dataclass(frozen=True) @@ -1029,7 +1033,7 @@ def returns_structseq_pyi(signature: PythonSignature) -> tuple[str, str] | None: # does not allow us to override __init__. seq_type = f"tuple[{', '.join(python_returns)}]" structseq_def_lines = [ - f"class {structseq_name}({seq_type}):", + f"class {structseq_name}({seq_type}): # fmt: skip", ] for name, ret_type in zip(field_names, python_returns): structseq_def_lines.extend( @@ -1040,7 +1044,11 @@ def returns_structseq_pyi(signature: PythonSignature) -> tuple[str, str] | None: ) structseq_def_lines.extend( [ - f" def __new__(cls, sequence: {seq_type}) -> Self: ...", + " def __new__(", + " cls,", + f" sequence: {seq_type},", + " ) -> Self: # fmt: skip", + " ...", f" n_fields: Final[_int] = {len(field_names)}", f" n_sequence_fields: Final[_int] = {len(field_names)}", " n_unnamed_fields: Final[_int] = 0", @@ -1051,12 +1059,16 @@ def returns_structseq_pyi(signature: PythonSignature) -> tuple[str, str] | None: structseq_def = "\n".join(structseq_def_lines) # Example: # structseq_def = ( - # "class max(tuple[Tensor, Tensor]):\n" + # "class max(tuple[Tensor, Tensor]): # fmt: skip\n" # " @property\n" # " def values(self) -> Tensor: ...\n" # " @property\n" # " def indices(self) -> Tensor: ...\n" - # " def __new__(cls, sequence: tuple[Tensor, Tensor]) -> Self: ...\n" + # " def __new__(\n" + # " cls,\n" + # " sequence: tuple[Tensor, Tensor],\n" + # " ) -> Self: # fmt: skip\n" + # " ...\n" # " n_fields: Final[_int] = 2", # " n_sequence_fields: Final[_int] = 2", # " n_unnamed_fields: Final[_int] = 0", diff --git a/torchgen/code_template.py b/torchgen/code_template.py index 8c33aad126f8..bafe1fa7568e 100644 --- a/torchgen/code_template.py +++ b/torchgen/code_template.py @@ -1,6 +1,8 @@ from __future__ import annotations +import itertools import re +import textwrap from typing import TYPE_CHECKING @@ -45,9 +47,12 @@ class CodeTemplate: return kwargs[v] if v in kwargs else env[v] def indent_lines(indent: str, v: Sequence[object]) -> str: - return "".join( - [indent + l + "\n" for e in v for l in str(e).splitlines()] - ).rstrip() + content = "\n".join( + itertools.chain.from_iterable(str(e).splitlines() for e in v) + ) + content = textwrap.indent(content, prefix=indent) + # Remove trailing whitespace on each line + return "\n".join(map(str.rstrip, content.splitlines())).rstrip() def replace(match: re.Match[str]) -> str: indent = match.group(1)