mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-21 05:34:18 +08:00
Fix torch.return_types init signature (#119284)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/119284 Approved by: https://github.com/peterbell10, https://github.com/XuehaiPan
This commit is contained in:
committed by
PyTorch MergeBot
parent
623632a401
commit
c3496d50f0
@ -882,7 +882,7 @@ def signature_from_schema(
|
||||
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ #
|
||||
|
||||
|
||||
def namedtuple_fieldnames(returns: Tuple[Return, ...]) -> List[str]:
|
||||
def structseq_fieldnames(returns: Tuple[Return, ...]) -> List[str]:
|
||||
if len(returns) <= 1 or all(r.name is None for r in returns):
|
||||
return []
|
||||
else:
|
||||
@ -894,7 +894,7 @@ def namedtuple_fieldnames(returns: Tuple[Return, ...]) -> List[str]:
|
||||
# PyStructSequence_UnnamedField
|
||||
#
|
||||
# Thus, at this point in time, we do not support unnamed
|
||||
# fields in namedtuple; you must either name all fields,
|
||||
# fields in structseq; you must either name all fields,
|
||||
# or none of them.
|
||||
raise ValueError("Unnamed field is not supported by codegen")
|
||||
|
||||
@ -993,29 +993,56 @@ def return_type_str_pyi(t: Type) -> str:
|
||||
return argument_type_str_pyi(t)
|
||||
|
||||
|
||||
def returns_named_tuple_pyi(signature: PythonSignature) -> Optional[Tuple[str, str]]:
|
||||
def returns_structseq_pyi(signature: PythonSignature) -> Optional[Tuple[str, str]]:
|
||||
python_returns = [return_type_str_pyi(r.type) for r in signature.returns.returns]
|
||||
namedtuple_name = signature.name
|
||||
field_names = namedtuple_fieldnames(signature.returns.returns)
|
||||
structseq_name = signature.name
|
||||
field_names = structseq_fieldnames(signature.returns.returns)
|
||||
if field_names:
|
||||
namedtuple_def_lines = [f"class {namedtuple_name}(NamedTuple):"]
|
||||
namedtuple_def_lines.extend(
|
||||
f" {name}: {typ}" for name, typ in zip(field_names, python_returns)
|
||||
# These types are structseq objects which act like named NamedTuples, but
|
||||
# the constructor acts like the constructor of tuple. Using typing.NamedTuple
|
||||
# does not allow us to override __init__.
|
||||
field_names_str = ", ".join(repr(name) for name in field_names)
|
||||
seq_type = f"Tuple[{', '.join(python_returns)}]"
|
||||
structseq_def_lines = [
|
||||
f"class {structseq_name}({seq_type}):",
|
||||
]
|
||||
for name, typ in zip(field_names, python_returns):
|
||||
structseq_def_lines.extend(
|
||||
[
|
||||
" @property",
|
||||
f" def {name}(self) -> {typ}: ...",
|
||||
]
|
||||
)
|
||||
structseq_def_lines.extend(
|
||||
[
|
||||
f" def __new__(cls, sequence: {seq_type}): ...",
|
||||
f" n_fields: _int = {len(field_names)}",
|
||||
f" n_sequeunce_fields: _int = {len(field_names)}",
|
||||
" n_unnamed_fields: _int = 0",
|
||||
" def __init_subclass__(cls) -> NoReturn: ... # prohibit subclassing",
|
||||
"", # add an extra newline
|
||||
]
|
||||
)
|
||||
namedtuple_def_lines.append("") # add an extra newline
|
||||
namedtuple_def = "\n".join(namedtuple_def_lines)
|
||||
structseq_def = "\n".join(structseq_def_lines)
|
||||
# Example:
|
||||
# namedtuple_def = (
|
||||
# "class max(NamedTuple):\n"
|
||||
# " values: Tensor\n"
|
||||
# " indices: Tensor\n"
|
||||
# structseq_def = (
|
||||
# "class max(Tuple[Tensor, Tensor]):\n"
|
||||
# " @property\n"
|
||||
# " def values(self) -> Tensor: ...\n"
|
||||
# " @property\n"
|
||||
# " def indices(self) -> Tensor: ...\n"
|
||||
# " def __new__(cls, sequence: Tuple[Tensor, Tensor]): ...\n"
|
||||
# " n_fields: _int = 2",
|
||||
# " n_sequeunce_fields: _int = 2",
|
||||
# " n_unnamed_fields: _int = 0",
|
||||
# " def __init_subclass__(cls) -> NoReturn: ... # prohibit subclassing",
|
||||
# )
|
||||
return namedtuple_name, namedtuple_def
|
||||
return structseq_name, structseq_def
|
||||
return None
|
||||
|
||||
|
||||
def returns_str_pyi(signature: PythonSignature) -> str:
|
||||
field_names = namedtuple_fieldnames(signature.returns.returns)
|
||||
field_names = structseq_fieldnames(signature.returns.returns)
|
||||
if field_names:
|
||||
return f"torch.return_types.{signature.name}"
|
||||
|
||||
|
Reference in New Issue
Block a user