mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 12:54:11 +08:00
[BE] Ensure generated stub files by gen_pyi
are properly formatted (#150730)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/150730 Approved by: https://github.com/aorenste
This commit is contained in:
committed by
PyTorch MergeBot
parent
7ebea09986
commit
27f7b65a69
@ -27,7 +27,9 @@ from __future__ import annotations
|
||||
import argparse
|
||||
import collections
|
||||
import importlib
|
||||
import inspect
|
||||
import sys
|
||||
import textwrap
|
||||
from typing import TYPE_CHECKING
|
||||
from unittest.mock import Mock, patch
|
||||
from warnings import warn
|
||||
@ -132,7 +134,7 @@ _leaf_types = (
|
||||
"_bool | _int | slice | EllipsisType | Tensor | None" # not SupportsIndex!
|
||||
)
|
||||
_index_types = f"SupportsIndex | {_leaf_types} | _NestedSequence[{_leaf_types}]"
|
||||
_index_type_def = f"_Index: TypeAlias = {_index_types}"
|
||||
_index_type_def = f"_Index: TypeAlias = {_index_types} # fmt: skip"
|
||||
INDICES = "indices: _Index | tuple[_Index, ...]"
|
||||
|
||||
blocklist = [
|
||||
@ -252,6 +254,11 @@ def sig_for_ops(opname: str) -> list[str]:
|
||||
f"def {opname}(self, other: Tensor | Number | _complex) -> Tensor: ... # type: ignore[has-type]"
|
||||
]
|
||||
elif name in arithmetic_ops:
|
||||
if name.startswith("i"):
|
||||
# In-place binary-operation dunder methods, like `__iadd__`, should return `Self`
|
||||
return [
|
||||
f"def {opname}(self, other: Tensor | Number | _complex) -> Tensor: ... # noqa: PYI034"
|
||||
]
|
||||
return [f"def {opname}(self, other: Tensor | Number | _complex) -> Tensor: ..."]
|
||||
elif name in logic_ops:
|
||||
return [f"def {opname}(self, other: Tensor | _bool) -> Tensor: ..."]
|
||||
@ -327,14 +334,29 @@ def get_max_pool_dispatch(name: str, arg_list: list[str]) -> dict[str, list[str]
|
||||
arg_list_keyword.insert(flag_pos, "*")
|
||||
return {
|
||||
name: [
|
||||
defs(name, arg_list, "Tensor").format(
|
||||
return_indices="return_indices: Literal[False] = False",
|
||||
defs(
|
||||
name,
|
||||
[
|
||||
arg.format(return_indices="return_indices: Literal[False] = False")
|
||||
for arg in arg_list
|
||||
],
|
||||
"Tensor",
|
||||
),
|
||||
defs(name, arg_list_positional, "tuple[Tensor, Tensor]").format(
|
||||
return_indices="return_indices: Literal[True]",
|
||||
defs(
|
||||
name,
|
||||
[
|
||||
arg.format(return_indices="return_indices: Literal[True]")
|
||||
for arg in arg_list_positional
|
||||
],
|
||||
"tuple[Tensor, Tensor]",
|
||||
),
|
||||
defs(name, arg_list_keyword, "tuple[Tensor, Tensor]").format(
|
||||
return_indices="return_indices: Literal[True]",
|
||||
defs(
|
||||
name,
|
||||
[
|
||||
arg.format(return_indices="return_indices: Literal[True]")
|
||||
for arg in arg_list_keyword
|
||||
],
|
||||
"tuple[Tensor, Tensor]",
|
||||
),
|
||||
]
|
||||
}
|
||||
@ -669,12 +691,16 @@ def gather_docstrs() -> dict[str, str]:
|
||||
|
||||
|
||||
def add_docstr_to_hint(docstr: str, hint: str) -> str:
|
||||
docstr = inspect.cleandoc(docstr).strip()
|
||||
if "..." in hint: # function or method
|
||||
assert hint.endswith("..."), f"Hint `{hint}` does not end with '...'"
|
||||
hint = hint[:-3] # remove "..."
|
||||
return "\n ".join([hint, 'r"""'] + docstr.split("\n") + ['"""', "..."])
|
||||
else: # attribute or property
|
||||
return f'{hint}\nr"""{docstr}"""\n'
|
||||
hint = hint.removesuffix("...").rstrip() # remove "..."
|
||||
content = hint + "\n" + textwrap.indent(f'r"""\n{docstr}\n"""', prefix=" ")
|
||||
# Remove trailing whitespace on each line
|
||||
return "\n".join(map(str.rstrip, content.splitlines())).rstrip()
|
||||
|
||||
# attribute or property
|
||||
return f'{hint}\nr"""{docstr}"""'
|
||||
|
||||
|
||||
def gen_pyi(
|
||||
@ -1557,7 +1583,7 @@ def gen_pyi(
|
||||
# Generate type signatures for legacy classes
|
||||
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
||||
|
||||
legacy_storage_base_hints = ["class StorageBase(object): ..."]
|
||||
legacy_storage_base_hints = ["class StorageBase: ..."]
|
||||
|
||||
legacy_class_hints = []
|
||||
for c in (
|
||||
|
@ -3163,7 +3163,10 @@ Args:
|
||||
Example:
|
||||
|
||||
>>> self = torch.tensor([[0, 0, 0, 0, 0], [0, 0, 0, 0, 0]])
|
||||
>>> mask = torch.tensor([[0, 0, 0, 1, 1], [1, 1, 0, 1, 1]], dtype=torch.bool)
|
||||
>>> mask = torch.tensor(
|
||||
... [[0, 0, 0, 1, 1], [1, 1, 0, 1, 1]],
|
||||
... dtype=torch.bool,
|
||||
... )
|
||||
>>> source = torch.tensor([[0, 1, 2, 3, 4], [5, 6, 7, 8, 9]])
|
||||
>>> self.masked_scatter_(mask, source)
|
||||
tensor([[0, 0, 0, 0, 1],
|
||||
@ -6561,7 +6564,10 @@ Out-of-place version of :meth:`torch.Tensor.masked_scatter_`
|
||||
Example:
|
||||
|
||||
>>> self = torch.tensor([0, 0, 0, 0, 0])
|
||||
>>> mask = torch.tensor([[0, 0, 0, 1, 1], [1, 1, 0, 1, 1]], dtype=torch.bool)
|
||||
>>> mask = torch.tensor(
|
||||
... [[0, 0, 0, 1, 1], [1, 1, 0, 1, 1]],
|
||||
... dtype=torch.bool,
|
||||
... )
|
||||
>>> source = torch.tensor([[0, 1, 2, 3, 4], [5, 6, 7, 8, 9]])
|
||||
>>> self.masked_scatter(mask, source)
|
||||
tensor([[0, 0, 0, 0, 1],
|
||||
|
@ -10550,7 +10550,8 @@ Example:
|
||||
... [[ 0.2035, 1.2959, 1.8101, -0.4644],
|
||||
... [ 1.5027, -0.3270, 0.5905, 0.6538],
|
||||
... [-1.5745, 1.3330, -0.5596, -0.6548],
|
||||
... [ 0.1264, -0.5080, 1.6420, 0.1992]])
|
||||
... [ 0.1264, -0.5080, 1.6420, 0.1992]]
|
||||
... ) # fmt: skip
|
||||
>>> torch.std(a, dim=1, keepdim=True)
|
||||
tensor([[1.0311],
|
||||
[0.7477],
|
||||
@ -10608,7 +10609,8 @@ Example:
|
||||
... [[ 0.2035, 1.2959, 1.8101, -0.4644],
|
||||
... [ 1.5027, -0.3270, 0.5905, 0.6538],
|
||||
... [-1.5745, 1.3330, -0.5596, -0.6548],
|
||||
... [ 0.1264, -0.5080, 1.6420, 0.1992]])
|
||||
... [ 0.1264, -0.5080, 1.6420, 0.1992]]
|
||||
... ) # fmt: skip
|
||||
>>> torch.std_mean(a, dim=0, keepdim=True)
|
||||
(tensor([[1.2620, 1.0028, 1.0957, 0.6038]]),
|
||||
tensor([[ 0.0645, 0.4485, 0.8707, -0.0665]]))
|
||||
@ -11896,7 +11898,8 @@ Example:
|
||||
... [[ 0.2035, 1.2959, 1.8101, -0.4644],
|
||||
... [ 1.5027, -0.3270, 0.5905, 0.6538],
|
||||
... [-1.5745, 1.3330, -0.5596, -0.6548],
|
||||
... [ 0.1264, -0.5080, 1.6420, 0.1992]])
|
||||
... [ 0.1264, -0.5080, 1.6420, 0.1992]]
|
||||
... ) # fmt: skip
|
||||
>>> torch.var(a, dim=1, keepdim=True)
|
||||
tensor([[1.0631],
|
||||
[0.5590],
|
||||
@ -11953,7 +11956,8 @@ Example:
|
||||
... [[ 0.2035, 1.2959, 1.8101, -0.4644],
|
||||
... [ 1.5027, -0.3270, 0.5905, 0.6538],
|
||||
... [-1.5745, 1.3330, -0.5596, -0.6548],
|
||||
... [ 0.1264, -0.5080, 1.6420, 0.1992]])
|
||||
... [ 0.1264, -0.5080, 1.6420, 0.1992]]
|
||||
... ) # fmt: skip
|
||||
>>> torch.var_mean(a, dim=0, keepdim=True)
|
||||
(tensor([[1.5926, 1.0056, 1.2005, 0.3646]]),
|
||||
tensor([[ 0.0645, 0.4485, 0.8707, -0.0665]]))
|
||||
|
@ -212,13 +212,17 @@ def format_function_signature(
|
||||
if len(sig) <= 80 or len(arguments) == 0 or tuple(arguments) == ("self",):
|
||||
return sig
|
||||
|
||||
arguments = [f" {arg}," for arg in arguments]
|
||||
return "\n".join(
|
||||
(
|
||||
f"def {name}(",
|
||||
*(f" {arg}," for arg in arguments),
|
||||
*(
|
||||
arg if len(arg) <= 80 else f" # fmt: off\n{arg}\n # fmt: on"
|
||||
for arg in arguments
|
||||
),
|
||||
f"){return_type}: ...",
|
||||
)
|
||||
)
|
||||
).replace(" # fmt: off\n # fmt: on\n", "")
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
@ -1029,7 +1033,7 @@ def returns_structseq_pyi(signature: PythonSignature) -> tuple[str, str] | None:
|
||||
# does not allow us to override __init__.
|
||||
seq_type = f"tuple[{', '.join(python_returns)}]"
|
||||
structseq_def_lines = [
|
||||
f"class {structseq_name}({seq_type}):",
|
||||
f"class {structseq_name}({seq_type}): # fmt: skip",
|
||||
]
|
||||
for name, ret_type in zip(field_names, python_returns):
|
||||
structseq_def_lines.extend(
|
||||
@ -1040,7 +1044,11 @@ def returns_structseq_pyi(signature: PythonSignature) -> tuple[str, str] | None:
|
||||
)
|
||||
structseq_def_lines.extend(
|
||||
[
|
||||
f" def __new__(cls, sequence: {seq_type}) -> Self: ...",
|
||||
" def __new__(",
|
||||
" cls,",
|
||||
f" sequence: {seq_type},",
|
||||
" ) -> Self: # fmt: skip",
|
||||
" ...",
|
||||
f" n_fields: Final[_int] = {len(field_names)}",
|
||||
f" n_sequence_fields: Final[_int] = {len(field_names)}",
|
||||
" n_unnamed_fields: Final[_int] = 0",
|
||||
@ -1051,12 +1059,16 @@ def returns_structseq_pyi(signature: PythonSignature) -> tuple[str, str] | None:
|
||||
structseq_def = "\n".join(structseq_def_lines)
|
||||
# Example:
|
||||
# structseq_def = (
|
||||
# "class max(tuple[Tensor, Tensor]):\n"
|
||||
# "class max(tuple[Tensor, Tensor]): # fmt: skip\n"
|
||||
# " @property\n"
|
||||
# " def values(self) -> Tensor: ...\n"
|
||||
# " @property\n"
|
||||
# " def indices(self) -> Tensor: ...\n"
|
||||
# " def __new__(cls, sequence: tuple[Tensor, Tensor]) -> Self: ...\n"
|
||||
# " def __new__(\n"
|
||||
# " cls,\n"
|
||||
# " sequence: tuple[Tensor, Tensor],\n"
|
||||
# " ) -> Self: # fmt: skip\n"
|
||||
# " ...\n"
|
||||
# " n_fields: Final[_int] = 2",
|
||||
# " n_sequence_fields: Final[_int] = 2",
|
||||
# " n_unnamed_fields: Final[_int] = 0",
|
||||
|
@ -1,6 +1,8 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import itertools
|
||||
import re
|
||||
import textwrap
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
|
||||
@ -45,9 +47,12 @@ class CodeTemplate:
|
||||
return kwargs[v] if v in kwargs else env[v]
|
||||
|
||||
def indent_lines(indent: str, v: Sequence[object]) -> str:
|
||||
return "".join(
|
||||
[indent + l + "\n" for e in v for l in str(e).splitlines()]
|
||||
).rstrip()
|
||||
content = "\n".join(
|
||||
itertools.chain.from_iterable(str(e).splitlines() for e in v)
|
||||
)
|
||||
content = textwrap.indent(content, prefix=indent)
|
||||
# Remove trailing whitespace on each line
|
||||
return "\n".join(map(str.rstrip, content.splitlines())).rstrip()
|
||||
|
||||
def replace(match: re.Match[str]) -> str:
|
||||
indent = match.group(1)
|
||||
|
Reference in New Issue
Block a user