Fix torch.return_types init signature (#119284)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/119284
Approved by: https://github.com/peterbell10, https://github.com/XuehaiPan
This commit is contained in:
Isuru Fernando
2024-02-23 18:06:20 +00:00
committed by PyTorch MergeBot
parent 623632a401
commit c3496d50f0
5 changed files with 93 additions and 67 deletions

View File

@ -882,7 +882,7 @@ def signature_from_schema(
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ #
def namedtuple_fieldnames(returns: Tuple[Return, ...]) -> List[str]:
def structseq_fieldnames(returns: Tuple[Return, ...]) -> List[str]:
if len(returns) <= 1 or all(r.name is None for r in returns):
return []
else:
@ -894,7 +894,7 @@ def namedtuple_fieldnames(returns: Tuple[Return, ...]) -> List[str]:
# PyStructSequence_UnnamedField
#
# Thus, at this point in time, we do not support unnamed
# fields in namedtuple; you must either name all fields,
# fields in structseq; you must either name all fields,
# or none of them.
raise ValueError("Unnamed field is not supported by codegen")
@ -993,29 +993,56 @@ def return_type_str_pyi(t: Type) -> str:
return argument_type_str_pyi(t)
def returns_named_tuple_pyi(signature: PythonSignature) -> Optional[Tuple[str, str]]:
def returns_structseq_pyi(signature: PythonSignature) -> Optional[Tuple[str, str]]:
python_returns = [return_type_str_pyi(r.type) for r in signature.returns.returns]
namedtuple_name = signature.name
field_names = namedtuple_fieldnames(signature.returns.returns)
structseq_name = signature.name
field_names = structseq_fieldnames(signature.returns.returns)
if field_names:
namedtuple_def_lines = [f"class {namedtuple_name}(NamedTuple):"]
namedtuple_def_lines.extend(
f" {name}: {typ}" for name, typ in zip(field_names, python_returns)
# These types are structseq objects which act like named NamedTuples, but
# the constructor acts like the constructor of tuple. Using typing.NamedTuple
# does not allow us to override __init__.
field_names_str = ", ".join(repr(name) for name in field_names)
seq_type = f"Tuple[{', '.join(python_returns)}]"
structseq_def_lines = [
f"class {structseq_name}({seq_type}):",
]
for name, typ in zip(field_names, python_returns):
structseq_def_lines.extend(
[
" @property",
f" def {name}(self) -> {typ}: ...",
]
)
structseq_def_lines.extend(
[
f" def __new__(cls, sequence: {seq_type}): ...",
f" n_fields: _int = {len(field_names)}",
f" n_sequeunce_fields: _int = {len(field_names)}",
" n_unnamed_fields: _int = 0",
" def __init_subclass__(cls) -> NoReturn: ... # prohibit subclassing",
"", # add an extra newline
]
)
namedtuple_def_lines.append("") # add an extra newline
namedtuple_def = "\n".join(namedtuple_def_lines)
structseq_def = "\n".join(structseq_def_lines)
# Example:
# namedtuple_def = (
# "class max(NamedTuple):\n"
# " values: Tensor\n"
# " indices: Tensor\n"
# structseq_def = (
# "class max(Tuple[Tensor, Tensor]):\n"
# " @property\n"
# " def values(self) -> Tensor: ...\n"
# " @property\n"
# " def indices(self) -> Tensor: ...\n"
# " def __new__(cls, sequence: Tuple[Tensor, Tensor]): ...\n"
# " n_fields: _int = 2",
# " n_sequeunce_fields: _int = 2",
# " n_unnamed_fields: _int = 0",
# " def __init_subclass__(cls) -> NoReturn: ... # prohibit subclassing",
# )
return namedtuple_name, namedtuple_def
return structseq_name, structseq_def
return None
def returns_str_pyi(signature: PythonSignature) -> str:
field_names = namedtuple_fieldnames(signature.returns.returns)
field_names = structseq_fieldnames(signature.returns.returns)
if field_names:
return f"torch.return_types.{signature.name}"