[BE] Ensure generated stub files by gen_pyi are properly formatted (#150730)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/150730
Approved by: https://github.com/aorenste
This commit is contained in:
Xuehai Pan
2025-05-17 14:29:21 +08:00
committed by PyTorch MergeBot
parent 7ebea09986
commit 27f7b65a69
5 changed files with 84 additions and 31 deletions

View File

@@ -27,7 +27,9 @@ from __future__ import annotations
import argparse
import collections
import importlib
import inspect
import sys
import textwrap
from typing import TYPE_CHECKING
from unittest.mock import Mock, patch
from warnings import warn
@@ -132,7 +134,7 @@ _leaf_types = (
"_bool | _int | slice | EllipsisType | Tensor | None" # not SupportsIndex!
)
_index_types = f"SupportsIndex | {_leaf_types} | _NestedSequence[{_leaf_types}]"
_index_type_def = f"_Index: TypeAlias = {_index_types}"
_index_type_def = f"_Index: TypeAlias = {_index_types} # fmt: skip"
INDICES = "indices: _Index | tuple[_Index, ...]"
blocklist = [
@@ -252,6 +254,11 @@ def sig_for_ops(opname: str) -> list[str]:
f"def {opname}(self, other: Tensor | Number | _complex) -> Tensor: ... # type: ignore[has-type]"
]
elif name in arithmetic_ops:
if name.startswith("i"):
# In-place binary-operation dunder methods, like `__iadd__`, should return `Self`
return [
f"def {opname}(self, other: Tensor | Number | _complex) -> Tensor: ... # noqa: PYI034"
]
return [f"def {opname}(self, other: Tensor | Number | _complex) -> Tensor: ..."]
elif name in logic_ops:
return [f"def {opname}(self, other: Tensor | _bool) -> Tensor: ..."]
@@ -327,14 +334,29 @@ def get_max_pool_dispatch(name: str, arg_list: list[str]) -> dict[str, list[str]
arg_list_keyword.insert(flag_pos, "*")
return {
name: [
defs(name, arg_list, "Tensor").format(
return_indices="return_indices: Literal[False] = False",
defs(
name,
[
arg.format(return_indices="return_indices: Literal[False] = False")
for arg in arg_list
],
"Tensor",
),
defs(name, arg_list_positional, "tuple[Tensor, Tensor]").format(
return_indices="return_indices: Literal[True]",
defs(
name,
[
arg.format(return_indices="return_indices: Literal[True]")
for arg in arg_list_positional
],
"tuple[Tensor, Tensor]",
),
defs(name, arg_list_keyword, "tuple[Tensor, Tensor]").format(
return_indices="return_indices: Literal[True]",
defs(
name,
[
arg.format(return_indices="return_indices: Literal[True]")
for arg in arg_list_keyword
],
"tuple[Tensor, Tensor]",
),
]
}
@@ -669,12 +691,16 @@ def gather_docstrs() -> dict[str, str]:
def add_docstr_to_hint(docstr: str, hint: str) -> str:
docstr = inspect.cleandoc(docstr).strip()
if "..." in hint: # function or method
assert hint.endswith("..."), f"Hint `{hint}` does not end with '...'"
hint = hint[:-3] # remove "..."
return "\n ".join([hint, 'r"""'] + docstr.split("\n") + ['"""', "..."])
else: # attribute or property
return f'{hint}\nr"""{docstr}"""\n'
hint = hint.removesuffix("...").rstrip() # remove "..."
content = hint + "\n" + textwrap.indent(f'r"""\n{docstr}\n"""', prefix=" ")
# Remove trailing whitespace on each line
return "\n".join(map(str.rstrip, content.splitlines())).rstrip()
# attribute or property
return f'{hint}\nr"""{docstr}"""'
def gen_pyi(
@@ -1557,7 +1583,7 @@ def gen_pyi(
# Generate type signatures for legacy classes
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
legacy_storage_base_hints = ["class StorageBase(object): ..."]
legacy_storage_base_hints = ["class StorageBase: ..."]
legacy_class_hints = []
for c in (

View File

@@ -3163,7 +3163,10 @@ Args:
Example:
>>> self = torch.tensor([[0, 0, 0, 0, 0], [0, 0, 0, 0, 0]])
>>> mask = torch.tensor([[0, 0, 0, 1, 1], [1, 1, 0, 1, 1]], dtype=torch.bool)
>>> mask = torch.tensor(
... [[0, 0, 0, 1, 1], [1, 1, 0, 1, 1]],
... dtype=torch.bool,
... )
>>> source = torch.tensor([[0, 1, 2, 3, 4], [5, 6, 7, 8, 9]])
>>> self.masked_scatter_(mask, source)
tensor([[0, 0, 0, 0, 1],
@@ -3645,7 +3648,7 @@ Example:
# Example 1: Padding
>>> input_tensor = torch.tensor([[1, 0], [3, 2]])
>>> static_size = 4
>>> t = torch.nonzero_static(input_tensor, size = static_size)
>>> t = torch.nonzero_static(input_tensor, size=static_size)
tensor([[ 0, 0],
[ 1, 0],
[ 1, 1],
@@ -3654,20 +3657,20 @@ Example:
# Example 2: Truncating
>>> input_tensor = torch.tensor([[1, 0], [3, 2]])
>>> static_size = 2
>>> t = torch.nonzero_static(input_tensor, size = static_size)
>>> t = torch.nonzero_static(input_tensor, size=static_size)
tensor([[ 0, 0],
[ 1, 0]], dtype=torch.int64)
# Example 3: 0 size
>>> input_tensor = torch.tensor([10])
>>> static_size = 0
>>> t = torch.nonzero_static(input_tensor, size = static_size)
>>> t = torch.nonzero_static(input_tensor, size=static_size)
tensor([], size=(0, 1), dtype=torch.int64)
# Example 4: 0 rank input
>>> input_tensor = torch.tensor(10)
>>> static_size = 2
>>> t = torch.nonzero_static(input_tensor, size = static_size)
>>> t = torch.nonzero_static(input_tensor, size=static_size)
tensor([], size=(2, 0), dtype=torch.int64)
""",
)
@@ -6561,7 +6564,10 @@ Out-of-place version of :meth:`torch.Tensor.masked_scatter_`
Example:
>>> self = torch.tensor([0, 0, 0, 0, 0])
>>> mask = torch.tensor([[0, 0, 0, 1, 1], [1, 1, 0, 1, 1]], dtype=torch.bool)
>>> mask = torch.tensor(
... [[0, 0, 0, 1, 1], [1, 1, 0, 1, 1]],
... dtype=torch.bool,
... )
>>> source = torch.tensor([[0, 1, 2, 3, 4], [5, 6, 7, 8, 9]])
>>> self.masked_scatter(mask, source)
tensor([[0, 0, 0, 0, 1],

View File

@@ -10550,7 +10550,8 @@ Example:
... [[ 0.2035, 1.2959, 1.8101, -0.4644],
... [ 1.5027, -0.3270, 0.5905, 0.6538],
... [-1.5745, 1.3330, -0.5596, -0.6548],
... [ 0.1264, -0.5080, 1.6420, 0.1992]])
... [ 0.1264, -0.5080, 1.6420, 0.1992]]
... ) # fmt: skip
>>> torch.std(a, dim=1, keepdim=True)
tensor([[1.0311],
[0.7477],
@@ -10608,7 +10609,8 @@ Example:
... [[ 0.2035, 1.2959, 1.8101, -0.4644],
... [ 1.5027, -0.3270, 0.5905, 0.6538],
... [-1.5745, 1.3330, -0.5596, -0.6548],
... [ 0.1264, -0.5080, 1.6420, 0.1992]])
... [ 0.1264, -0.5080, 1.6420, 0.1992]]
... ) # fmt: skip
>>> torch.std_mean(a, dim=0, keepdim=True)
(tensor([[1.2620, 1.0028, 1.0957, 0.6038]]),
tensor([[ 0.0645, 0.4485, 0.8707, -0.0665]]))
@@ -11896,7 +11898,8 @@ Example:
... [[ 0.2035, 1.2959, 1.8101, -0.4644],
... [ 1.5027, -0.3270, 0.5905, 0.6538],
... [-1.5745, 1.3330, -0.5596, -0.6548],
... [ 0.1264, -0.5080, 1.6420, 0.1992]])
... [ 0.1264, -0.5080, 1.6420, 0.1992]]
... ) # fmt: skip
>>> torch.var(a, dim=1, keepdim=True)
tensor([[1.0631],
[0.5590],
@@ -11953,7 +11956,8 @@ Example:
... [[ 0.2035, 1.2959, 1.8101, -0.4644],
... [ 1.5027, -0.3270, 0.5905, 0.6538],
... [-1.5745, 1.3330, -0.5596, -0.6548],
... [ 0.1264, -0.5080, 1.6420, 0.1992]])
... [ 0.1264, -0.5080, 1.6420, 0.1992]]
... ) # fmt: skip
>>> torch.var_mean(a, dim=0, keepdim=True)
(tensor([[1.5926, 1.0056, 1.2005, 0.3646]]),
tensor([[ 0.0645, 0.4485, 0.8707, -0.0665]]))

View File

@@ -212,13 +212,17 @@ def format_function_signature(
if len(sig) <= 80 or len(arguments) == 0 or tuple(arguments) == ("self",):
return sig
arguments = [f" {arg}," for arg in arguments]
return "\n".join(
(
f"def {name}(",
*(f" {arg}," for arg in arguments),
*(
arg if len(arg) <= 80 else f" # fmt: off\n{arg}\n # fmt: on"
for arg in arguments
),
f"){return_type}: ...",
)
)
).replace(" # fmt: off\n # fmt: on\n", "")
@dataclass(frozen=True)
@@ -1029,7 +1033,7 @@ def returns_structseq_pyi(signature: PythonSignature) -> tuple[str, str] | None:
# does not allow us to override __init__.
seq_type = f"tuple[{', '.join(python_returns)}]"
structseq_def_lines = [
f"class {structseq_name}({seq_type}):",
f"class {structseq_name}({seq_type}): # fmt: skip",
]
for name, ret_type in zip(field_names, python_returns):
structseq_def_lines.extend(
@@ -1040,7 +1044,11 @@ def returns_structseq_pyi(signature: PythonSignature) -> tuple[str, str] | None:
)
structseq_def_lines.extend(
[
f" def __new__(cls, sequence: {seq_type}) -> Self: ...",
" def __new__(",
" cls,",
f" sequence: {seq_type},",
" ) -> Self: # fmt: skip",
" ...",
f" n_fields: Final[_int] = {len(field_names)}",
f" n_sequence_fields: Final[_int] = {len(field_names)}",
" n_unnamed_fields: Final[_int] = 0",
@@ -1051,12 +1059,16 @@ def returns_structseq_pyi(signature: PythonSignature) -> tuple[str, str] | None:
structseq_def = "\n".join(structseq_def_lines)
# Example:
# structseq_def = (
# "class max(tuple[Tensor, Tensor]):\n"
# "class max(tuple[Tensor, Tensor]): # fmt: skip\n"
# " @property\n"
# " def values(self) -> Tensor: ...\n"
# " @property\n"
# " def indices(self) -> Tensor: ...\n"
# " def __new__(cls, sequence: tuple[Tensor, Tensor]) -> Self: ...\n"
# " def __new__(\n"
# " cls,\n"
# " sequence: tuple[Tensor, Tensor],\n"
# " ) -> Self: # fmt: skip\n"
# " ...\n"
# " n_fields: Final[_int] = 2",
# " n_sequence_fields: Final[_int] = 2",
# " n_unnamed_fields: Final[_int] = 0",

View File

@@ -1,6 +1,8 @@
from __future__ import annotations
import itertools
import re
import textwrap
from typing import TYPE_CHECKING
@@ -45,9 +47,12 @@ class CodeTemplate:
return kwargs[v] if v in kwargs else env[v]
def indent_lines(indent: str, v: Sequence[object]) -> str:
return "".join(
[indent + l + "\n" for e in v for l in str(e).splitlines()]
).rstrip()
content = "\n".join(
itertools.chain.from_iterable(str(e).splitlines() for e in v)
)
content = textwrap.indent(content, prefix=indent)
# Remove trailing whitespace on each line
return "\n".join(map(str.rstrip, content.splitlines())).rstrip()
def replace(match: re.Match[str]) -> str:
indent = match.group(1)