mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
[BE] Ensure generated stub files by gen_pyi
are properly formatted (#150730)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/150730 Approved by: https://github.com/aorenste
This commit is contained in:
committed by
PyTorch MergeBot
parent
7ebea09986
commit
27f7b65a69
@ -212,13 +212,17 @@ def format_function_signature(
|
||||
if len(sig) <= 80 or len(arguments) == 0 or tuple(arguments) == ("self",):
|
||||
return sig
|
||||
|
||||
arguments = [f" {arg}," for arg in arguments]
|
||||
return "\n".join(
|
||||
(
|
||||
f"def {name}(",
|
||||
*(f" {arg}," for arg in arguments),
|
||||
*(
|
||||
arg if len(arg) <= 80 else f" # fmt: off\n{arg}\n # fmt: on"
|
||||
for arg in arguments
|
||||
),
|
||||
f"){return_type}: ...",
|
||||
)
|
||||
)
|
||||
).replace(" # fmt: off\n # fmt: on\n", "")
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
@ -1029,7 +1033,7 @@ def returns_structseq_pyi(signature: PythonSignature) -> tuple[str, str] | None:
|
||||
# does not allow us to override __init__.
|
||||
seq_type = f"tuple[{', '.join(python_returns)}]"
|
||||
structseq_def_lines = [
|
||||
f"class {structseq_name}({seq_type}):",
|
||||
f"class {structseq_name}({seq_type}): # fmt: skip",
|
||||
]
|
||||
for name, ret_type in zip(field_names, python_returns):
|
||||
structseq_def_lines.extend(
|
||||
@ -1040,7 +1044,11 @@ def returns_structseq_pyi(signature: PythonSignature) -> tuple[str, str] | None:
|
||||
)
|
||||
structseq_def_lines.extend(
|
||||
[
|
||||
f" def __new__(cls, sequence: {seq_type}) -> Self: ...",
|
||||
" def __new__(",
|
||||
" cls,",
|
||||
f" sequence: {seq_type},",
|
||||
" ) -> Self: # fmt: skip",
|
||||
" ...",
|
||||
f" n_fields: Final[_int] = {len(field_names)}",
|
||||
f" n_sequence_fields: Final[_int] = {len(field_names)}",
|
||||
" n_unnamed_fields: Final[_int] = 0",
|
||||
@ -1051,12 +1059,16 @@ def returns_structseq_pyi(signature: PythonSignature) -> tuple[str, str] | None:
|
||||
structseq_def = "\n".join(structseq_def_lines)
|
||||
# Example:
|
||||
# structseq_def = (
|
||||
# "class max(tuple[Tensor, Tensor]):\n"
|
||||
# "class max(tuple[Tensor, Tensor]): # fmt: skip\n"
|
||||
# " @property\n"
|
||||
# " def values(self) -> Tensor: ...\n"
|
||||
# " @property\n"
|
||||
# " def indices(self) -> Tensor: ...\n"
|
||||
# " def __new__(cls, sequence: tuple[Tensor, Tensor]) -> Self: ...\n"
|
||||
# " def __new__(\n"
|
||||
# " cls,\n"
|
||||
# " sequence: tuple[Tensor, Tensor],\n"
|
||||
# " ) -> Self: # fmt: skip\n"
|
||||
# " ...\n"
|
||||
# " n_fields: Final[_int] = 2",
|
||||
# " n_sequence_fields: Final[_int] = 2",
|
||||
# " n_unnamed_fields: Final[_int] = 0",
|
||||
|
@ -1,6 +1,8 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import itertools
|
||||
import re
|
||||
import textwrap
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
|
||||
@ -45,9 +47,12 @@ class CodeTemplate:
|
||||
return kwargs[v] if v in kwargs else env[v]
|
||||
|
||||
def indent_lines(indent: str, v: Sequence[object]) -> str:
|
||||
return "".join(
|
||||
[indent + l + "\n" for e in v for l in str(e).splitlines()]
|
||||
).rstrip()
|
||||
content = "\n".join(
|
||||
itertools.chain.from_iterable(str(e).splitlines() for e in v)
|
||||
)
|
||||
content = textwrap.indent(content, prefix=indent)
|
||||
# Remove trailing whitespace on each line
|
||||
return "\n".join(map(str.rstrip, content.splitlines())).rstrip()
|
||||
|
||||
def replace(match: re.Match[str]) -> str:
|
||||
indent = match.group(1)
|
||||
|
Reference in New Issue
Block a user