[BE] Ensure generated stub files by gen_pyi are properly formatted (#150730)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/150730
Approved by: https://github.com/aorenste
commit 27f7b65a69, parent 7ebea09986, committed by PyTorch MergeBot
@@ -27,7 +27,9 @@ from __future__ import annotations
 import argparse
 import collections
 import importlib
+import inspect
 import sys
+import textwrap
 from typing import TYPE_CHECKING
 from unittest.mock import Mock, patch
 from warnings import warn
@@ -132,7 +134,7 @@ _leaf_types = (
     "_bool | _int | slice | EllipsisType | Tensor | None"  # not SupportsIndex!
 )
 _index_types = f"SupportsIndex | {_leaf_types} | _NestedSequence[{_leaf_types}]"
-_index_type_def = f"_Index: TypeAlias = {_index_types}"
+_index_type_def = f"_Index: TypeAlias = {_index_types}  # fmt: skip"
 INDICES = "indices: _Index | tuple[_Index, ...]"
 
 blocklist = [
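Note: `# fmt: skip` is the standard Black/Ruff pragma that exempts a single line from formatting. Expanding the f-strings above, the generated alias line that the formatter must now leave alone reads:

_Index: TypeAlias = SupportsIndex | _bool | _int | slice | EllipsisType | Tensor | None | _NestedSequence[_bool | _int | slice | EllipsisType | Tensor | None]  # fmt: skip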
@@ -252,6 +254,11 @@ def sig_for_ops(opname: str) -> list[str]:
             f"def {opname}(self, other: Tensor | Number | _complex) -> Tensor: ...  # type: ignore[has-type]"
         ]
     elif name in arithmetic_ops:
+        if name.startswith("i"):
+            # In-place binary-operation dunder methods, like `__iadd__`, should return `Self`
+            return [
+                f"def {opname}(self, other: Tensor | Number | _complex) -> Tensor: ...  # noqa: PYI034"
+            ]
         return [f"def {opname}(self, other: Tensor | Number | _complex) -> Tensor: ..."]
     elif name in logic_ops:
         return [f"def {opname}(self, other: Tensor | _bool) -> Tensor: ..."]
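Note: flake8-pyi's PYI034 expects in-place dunders such as `__iadd__` to return `Self` in stubs; the generator keeps the concrete `Tensor` return type and suppresses the rule on the emitted line instead. A sketch of the two stub lines this branch now produces for the `add` pair (class body context omitted):

    def __iadd__(self, other: Tensor | Number | _complex) -> Tensor: ...  # noqa: PYI034
    def __add__(self, other: Tensor | Number | _complex) -> Tensor: ...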
@@ -327,14 +334,29 @@ def get_max_pool_dispatch(name: str, arg_list: list[str]) -> dict[str, list[str]
     arg_list_keyword.insert(flag_pos, "*")
     return {
         name: [
-            defs(name, arg_list, "Tensor").format(
-                return_indices="return_indices: Literal[False] = False",
+            defs(
+                name,
+                [
+                    arg.format(return_indices="return_indices: Literal[False] = False")
+                    for arg in arg_list
+                ],
+                "Tensor",
             ),
-            defs(name, arg_list_positional, "tuple[Tensor, Tensor]").format(
-                return_indices="return_indices: Literal[True]",
+            defs(
+                name,
+                [
+                    arg.format(return_indices="return_indices: Literal[True]")
+                    for arg in arg_list_positional
+                ],
+                "tuple[Tensor, Tensor]",
             ),
-            defs(name, arg_list_keyword, "tuple[Tensor, Tensor]").format(
-                return_indices="return_indices: Literal[True]",
+            defs(
+                name,
+                [
+                    arg.format(return_indices="return_indices: Literal[True]")
+                    for arg in arg_list_keyword
+                ],
+                "tuple[Tensor, Tensor]",
             ),
         ]
     }
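Note: the `return_indices` placeholder is now substituted into each argument string individually (per-arg `str.format`) before `defs(...)` assembles the signature, instead of formatting the joined signature afterwards, so every emitted line can be wrapped by the formatter. The result is the usual `Literal`-discriminated overload set. A minimal runnable sketch of the pattern, with an invented, heavily simplified `max_pool2d` signature (the real dispatch also emits a positional `Literal[True]` variant):

from typing import Literal, overload

class Tensor: ...

@overload
def max_pool2d(input: Tensor, kernel_size: int, return_indices: Literal[False] = False) -> Tensor: ...
@overload
def max_pool2d(input: Tensor, kernel_size: int, *, return_indices: Literal[True]) -> tuple[Tensor, Tensor]: ...
def max_pool2d(input, kernel_size, return_indices=False):
    ...  # implementation elided; only the overload shapes matter here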
@@ -669,12 +691,16 @@ def gather_docstrs() -> dict[str, str]:
 
 
 def add_docstr_to_hint(docstr: str, hint: str) -> str:
+    docstr = inspect.cleandoc(docstr).strip()
     if "..." in hint:  # function or method
         assert hint.endswith("..."), f"Hint `{hint}` does not end with '...'"
-        hint = hint[:-3]  # remove "..."
-        return "\n    ".join([hint, 'r"""'] + docstr.split("\n") + ['"""', "..."])
-    else:  # attribute or property
-        return f'{hint}\nr"""{docstr}"""\n'
+        hint = hint.removesuffix("...").rstrip()  # remove "..."
+        content = hint + "\n" + textwrap.indent(f'r"""\n{docstr}\n"""', prefix="    ")
+        # Remove trailing whitespace on each line
+        return "\n".join(map(str.rstrip, content.splitlines())).rstrip()
+
+    # attribute or property
+    return f'{hint}\nr"""{docstr}"""'
 
 
 def gen_pyi(
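For illustration: with the reworked helper, a method hint plus docstring renders as an indented raw-string body, and the trailing `...` placeholder is no longer appended because the docstring itself forms a valid function body. Invented example input:

>>> print(add_docstr_to_hint("Returns the sum.", "def sum(self) -> Tensor: ..."))
def sum(self) -> Tensor:
    r"""
    Returns the sum.
    """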
@@ -1557,7 +1583,7 @@ def gen_pyi(
     # Generate type signatures for legacy classes
     # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
 
-    legacy_storage_base_hints = ["class StorageBase(object): ..."]
+    legacy_storage_base_hints = ["class StorageBase: ..."]
 
     legacy_class_hints = []
     for c in (
@@ -3163,7 +3163,10 @@ Args:
 Example:
 
     >>> self = torch.tensor([[0, 0, 0, 0, 0], [0, 0, 0, 0, 0]])
-    >>> mask = torch.tensor([[0, 0, 0, 1, 1], [1, 1, 0, 1, 1]], dtype=torch.bool)
+    >>> mask = torch.tensor(
+    ...     [[0, 0, 0, 1, 1], [1, 1, 0, 1, 1]],
+    ...     dtype=torch.bool,
+    ... )
     >>> source = torch.tensor([[0, 1, 2, 3, 4], [5, 6, 7, 8, 9]])
     >>> self.masked_scatter_(mask, source)
     tensor([[0, 0, 0, 0, 1],
@@ -3645,7 +3648,7 @@ Example:
     # Example 1: Padding
     >>> input_tensor = torch.tensor([[1, 0], [3, 2]])
     >>> static_size = 4
-    >>> t = torch.nonzero_static(input_tensor, size = static_size)
+    >>> t = torch.nonzero_static(input_tensor, size=static_size)
     tensor([[ 0, 0],
             [ 1, 0],
             [ 1, 1],
@@ -3654,20 +3657,20 @@ Example:
     # Example 2: Truncating
     >>> input_tensor = torch.tensor([[1, 0], [3, 2]])
     >>> static_size = 2
-    >>> t = torch.nonzero_static(input_tensor, size = static_size)
+    >>> t = torch.nonzero_static(input_tensor, size=static_size)
     tensor([[ 0, 0],
             [ 1, 0]], dtype=torch.int64)
 
     # Example 3: 0 size
     >>> input_tensor = torch.tensor([10])
     >>> static_size = 0
-    >>> t = torch.nonzero_static(input_tensor, size = static_size)
+    >>> t = torch.nonzero_static(input_tensor, size=static_size)
     tensor([], size=(0, 1), dtype=torch.int64)
 
     # Example 4: 0 rank input
     >>> input_tensor = torch.tensor(10)
     >>> static_size = 2
-    >>> t = torch.nonzero_static(input_tensor, size = static_size)
+    >>> t = torch.nonzero_static(input_tensor, size=static_size)
     tensor([], size=(2, 0), dtype=torch.int64)
     """,
 )
@@ -6561,7 +6564,10 @@ Out-of-place version of :meth:`torch.Tensor.masked_scatter_`
 Example:
 
     >>> self = torch.tensor([0, 0, 0, 0, 0])
-    >>> mask = torch.tensor([[0, 0, 0, 1, 1], [1, 1, 0, 1, 1]], dtype=torch.bool)
+    >>> mask = torch.tensor(
+    ...     [[0, 0, 0, 1, 1], [1, 1, 0, 1, 1]],
+    ...     dtype=torch.bool,
+    ... )
     >>> source = torch.tensor([[0, 1, 2, 3, 4], [5, 6, 7, 8, 9]])
     >>> self.masked_scatter(mask, source)
     tensor([[0, 0, 0, 0, 1],
@@ -10550,7 +10550,8 @@ Example:
     ...     [[ 0.2035,  1.2959,  1.8101, -0.4644],
     ...      [ 1.5027, -0.3270,  0.5905,  0.6538],
     ...      [-1.5745,  1.3330, -0.5596, -0.6548],
-    ...      [ 0.1264, -0.5080,  1.6420,  0.1992]])
+    ...      [ 0.1264, -0.5080,  1.6420,  0.1992]]
+    ... )  # fmt: skip
     >>> torch.std(a, dim=1, keepdim=True)
     tensor([[1.0311],
             [0.7477],
@@ -10608,7 +10609,8 @@ Example:
     ...     [[ 0.2035,  1.2959,  1.8101, -0.4644],
     ...      [ 1.5027, -0.3270,  0.5905,  0.6538],
     ...      [-1.5745,  1.3330, -0.5596, -0.6548],
-    ...      [ 0.1264, -0.5080,  1.6420,  0.1992]])
+    ...      [ 0.1264, -0.5080,  1.6420,  0.1992]]
+    ... )  # fmt: skip
     >>> torch.std_mean(a, dim=0, keepdim=True)
     (tensor([[1.2620, 1.0028, 1.0957, 0.6038]]),
      tensor([[ 0.0645,  0.4485,  0.8707, -0.0665]]))
@@ -11896,7 +11898,8 @@ Example:
     ...     [[ 0.2035,  1.2959,  1.8101, -0.4644],
     ...      [ 1.5027, -0.3270,  0.5905,  0.6538],
     ...      [-1.5745,  1.3330, -0.5596, -0.6548],
-    ...      [ 0.1264, -0.5080,  1.6420,  0.1992]])
+    ...      [ 0.1264, -0.5080,  1.6420,  0.1992]]
+    ... )  # fmt: skip
     >>> torch.var(a, dim=1, keepdim=True)
     tensor([[1.0631],
             [0.5590],
@@ -11953,7 +11956,8 @@ Example:
     ...     [[ 0.2035,  1.2959,  1.8101, -0.4644],
     ...      [ 1.5027, -0.3270,  0.5905,  0.6538],
     ...      [-1.5745,  1.3330, -0.5596, -0.6548],
-    ...      [ 0.1264, -0.5080,  1.6420,  0.1992]])
+    ...      [ 0.1264, -0.5080,  1.6420,  0.1992]]
+    ... )  # fmt: skip
     >>> torch.var_mean(a, dim=0, keepdim=True)
     (tensor([[1.5926, 1.0056, 1.2005, 0.3646]]),
      tensor([[ 0.0645,  0.4485,  0.8707, -0.0665]]))
@@ -212,13 +212,17 @@ def format_function_signature(
     if len(sig) <= 80 or len(arguments) == 0 or tuple(arguments) == ("self",):
         return sig
 
+    arguments = [f"    {arg}," for arg in arguments]
     return "\n".join(
         (
             f"def {name}(",
-            *(f"    {arg}," for arg in arguments),
+            *(
+                arg if len(arg) <= 80 else f"    # fmt: off\n{arg}\n    # fmt: on"
+                for arg in arguments
+            ),
             f"){return_type}: ...",
         )
-    )
+    ).replace("    # fmt: off\n    # fmt: on\n", "")
 
 
 @dataclass(frozen=True)
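Note: argument lines longer than 80 columns are now fenced with `# fmt: off`/`# fmt: on` so the formatter cannot rewrap them, and the trailing `.replace(...)` collapses fence pairs that end up adjacent after joining. A runnable sketch of just the wrapping rule (assumed simplified helper; the real function takes more parameters):

def wrap_signature(name: str, arguments: list[str], return_type: str = " -> None") -> str:
    # Indent each argument one level, then fence only the over-long ones.
    arguments = [f"    {arg}," for arg in arguments]
    return "\n".join(
        (
            f"def {name}(",
            *(
                arg if len(arg) <= 80 else f"    # fmt: off\n{arg}\n    # fmt: on"
                for arg in arguments
            ),
            f"){return_type}: ...",
        )
    )

print(wrap_signature("f", ["x: int", "y: " + " | ".join(30 * ["int"])]))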
@@ -1029,7 +1033,7 @@ def returns_structseq_pyi(signature: PythonSignature) -> tuple[str, str] | None:
     # does not allow us to override __init__.
     seq_type = f"tuple[{', '.join(python_returns)}]"
     structseq_def_lines = [
-        f"class {structseq_name}({seq_type}):",
+        f"class {structseq_name}({seq_type}):  # fmt: skip",
     ]
     for name, ret_type in zip(field_names, python_returns):
         structseq_def_lines.extend(
@@ -1040,7 +1044,11 @@ def returns_structseq_pyi(signature: PythonSignature) -> tuple[str, str] | None:
     )
     structseq_def_lines.extend(
         [
-            f"    def __new__(cls, sequence: {seq_type}) -> Self: ...",
+            "    def __new__(",
+            "        cls,",
+            f"        sequence: {seq_type},",
+            "    ) -> Self:  # fmt: skip",
+            "        ...",
             f"    n_fields: Final[_int] = {len(field_names)}",
             f"    n_sequence_fields: Final[_int] = {len(field_names)}",
             "    n_unnamed_fields: Final[_int] = 0",
@@ -1051,12 +1059,16 @@ def returns_structseq_pyi(signature: PythonSignature) -> tuple[str, str] | None:
     structseq_def = "\n".join(structseq_def_lines)
     # Example:
     # structseq_def = (
-    #     "class max(tuple[Tensor, Tensor]):\n"
+    #     "class max(tuple[Tensor, Tensor]):  # fmt: skip\n"
     #     "    @property\n"
     #     "    def values(self) -> Tensor: ...\n"
     #     "    @property\n"
     #     "    def indices(self) -> Tensor: ...\n"
-    #     "    def __new__(cls, sequence: tuple[Tensor, Tensor]) -> Self: ...\n"
+    #     "    def __new__(\n"
+    #     "        cls,\n"
+    #     "        sequence: tuple[Tensor, Tensor],\n"
+    #     "    ) -> Self:  # fmt: skip\n"
+    #     "        ...\n"
     #     "    n_fields: Final[_int] = 2",
     #     "    n_sequence_fields: Final[_int] = 2",
     #     "    n_unnamed_fields: Final[_int] = 0",
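Rendered, the example above now produces a stub body like this (assembled from the comment itself; `Tensor`, `Self`, `Final`, and `_int` come from the generated stub's own imports):

class max(tuple[Tensor, Tensor]):  # fmt: skip
    @property
    def values(self) -> Tensor: ...
    @property
    def indices(self) -> Tensor: ...
    def __new__(
        cls,
        sequence: tuple[Tensor, Tensor],
    ) -> Self:  # fmt: skip
        ...
    n_fields: Final[_int] = 2
    n_sequence_fields: Final[_int] = 2
    n_unnamed_fields: Final[_int] = 0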
@@ -1,6 +1,8 @@
 from __future__ import annotations
 
+import itertools
 import re
+import textwrap
 from typing import TYPE_CHECKING
 
 
@@ -45,9 +47,12 @@ class CodeTemplate:
             return kwargs[v] if v in kwargs else env[v]
 
         def indent_lines(indent: str, v: Sequence[object]) -> str:
-            return "".join(
-                [indent + l + "\n" for e in v for l in str(e).splitlines()]
-            ).rstrip()
+            content = "\n".join(
+                itertools.chain.from_iterable(str(e).splitlines() for e in v)
+            )
+            content = textwrap.indent(content, prefix=indent)
+            # Remove trailing whitespace on each line
+            return "\n".join(map(str.rstrip, content.splitlines())).rstrip()
 
         def replace(match: re.Match[str]) -> str:
             indent = match.group(1)
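Note: `textwrap.indent` leaves whitespace-only lines untouched by default, and the final `str.rstrip` pass strips trailing spaces, so indented template expansions no longer carry the trailing whitespace the formatter would flag. A standalone sketch of the new helper's behavior:

import itertools
import textwrap

def indent_lines(indent: str, v: list[object]) -> str:
    # Same logic as the nested helper above, lifted out for demonstration.
    content = "\n".join(
        itertools.chain.from_iterable(str(e).splitlines() for e in v)
    )
    content = textwrap.indent(content, prefix=indent)
    return "\n".join(map(str.rstrip, content.splitlines())).rstrip()

print(repr(indent_lines("    ", ["a\n\nb"])))  # '    a\n\n    b' (blank line stays empty)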