[BE] Ensure generated stub files by gen_pyi are properly formatted (#150730)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/150730
Approved by: https://github.com/aorenste
This commit is contained in:
Xuehai Pan
2025-05-17 14:29:21 +08:00
committed by PyTorch MergeBot
parent 7ebea09986
commit 27f7b65a69
5 changed files with 84 additions and 31 deletions

View File

@ -212,13 +212,17 @@ def format_function_signature(
if len(sig) <= 80 or len(arguments) == 0 or tuple(arguments) == ("self",):
return sig
arguments = [f" {arg}," for arg in arguments]
return "\n".join(
(
f"def {name}(",
*(f" {arg}," for arg in arguments),
*(
arg if len(arg) <= 80 else f" # fmt: off\n{arg}\n # fmt: on"
for arg in arguments
),
f"){return_type}: ...",
)
)
).replace(" # fmt: off\n # fmt: on\n", "")
@dataclass(frozen=True)
@ -1029,7 +1033,7 @@ def returns_structseq_pyi(signature: PythonSignature) -> tuple[str, str] | None:
# does not allow us to override __init__.
seq_type = f"tuple[{', '.join(python_returns)}]"
structseq_def_lines = [
f"class {structseq_name}({seq_type}):",
f"class {structseq_name}({seq_type}): # fmt: skip",
]
for name, ret_type in zip(field_names, python_returns):
structseq_def_lines.extend(
@ -1040,7 +1044,11 @@ def returns_structseq_pyi(signature: PythonSignature) -> tuple[str, str] | None:
)
structseq_def_lines.extend(
[
f" def __new__(cls, sequence: {seq_type}) -> Self: ...",
" def __new__(",
" cls,",
f" sequence: {seq_type},",
" ) -> Self: # fmt: skip",
" ...",
f" n_fields: Final[_int] = {len(field_names)}",
f" n_sequence_fields: Final[_int] = {len(field_names)}",
" n_unnamed_fields: Final[_int] = 0",
@ -1051,12 +1059,16 @@ def returns_structseq_pyi(signature: PythonSignature) -> tuple[str, str] | None:
structseq_def = "\n".join(structseq_def_lines)
# Example:
# structseq_def = (
# "class max(tuple[Tensor, Tensor]):\n"
# "class max(tuple[Tensor, Tensor]): # fmt: skip\n"
# " @property\n"
# " def values(self) -> Tensor: ...\n"
# " @property\n"
# " def indices(self) -> Tensor: ...\n"
# " def __new__(cls, sequence: tuple[Tensor, Tensor]) -> Self: ...\n"
# " def __new__(\n"
# " cls,\n"
# " sequence: tuple[Tensor, Tensor],\n"
# " ) -> Self: # fmt: skip\n"
# " ...\n"
# " n_fields: Final[_int] = 2",
# " n_sequence_fields: Final[_int] = 2",
# " n_unnamed_fields: Final[_int] = 0",

View File

@ -1,6 +1,8 @@
from __future__ import annotations
import itertools
import re
import textwrap
from typing import TYPE_CHECKING
@ -45,9 +47,12 @@ class CodeTemplate:
return kwargs[v] if v in kwargs else env[v]
def indent_lines(indent: str, v: Sequence[object]) -> str:
return "".join(
[indent + l + "\n" for e in v for l in str(e).splitlines()]
).rstrip()
content = "\n".join(
itertools.chain.from_iterable(str(e).splitlines() for e in v)
)
content = textwrap.indent(content, prefix=indent)
# Remove trailing whitespace on each line
return "\n".join(map(str.rstrip, content.splitlines())).rstrip()
def replace(match: re.Match[str]) -> str:
indent = match.group(1)