Files
pytorch/tools/pyi/gen_pyi.py
orangeH25 20eeb54814 Add api info for torch._C._nn.pyi (#162936)
Fix part of #148404

APIs involved are as follows:

- silu
- silu_
- smooth_l1_loss
- soft_margin_loss
Pull Request resolved: https://github.com/pytorch/pytorch/pull/162936
Approved by: https://github.com/FFFrog, https://github.com/ezyang
2025-09-24 04:55:57 +00:00

2123 lines
69 KiB
Python

"""
This module implements generation of type stubs for PyTorch,
enabling use of autocomplete in IDEs like PyCharm, which otherwise
don't understand C extension modules.
At the moment, this module only handles type stubs for torch and
torch.Tensor. It should eventually be expanded to cover all functions
which are autogenerated.
Here's our general strategy:
- We start off with a hand-written __init__.pyi.in file. This
file contains type definitions for everything we cannot automatically
generate, including pure Python definitions directly in __init__.py
(the latter case should be pretty rare).
- We go through automatically bound functions based on the
type information recorded in native_functions.yaml and
generate type hints for them (generate_type_hints)
There are a number of type hints which we've special-cased;
read gen_pyi for the gory details.
"""
from __future__ import annotations
import argparse
import collections
import importlib
import inspect
import sys
import textwrap
from typing import TYPE_CHECKING
from unittest.mock import Mock, patch
from warnings import warn
from tools.autograd.gen_python_functions import (
group_overloads,
load_signatures,
should_generate_py_binding,
)
from torchgen.api.python import (
format_function_signature as defs,
PythonSignatureGroup,
PythonSignatureNativeFunctionPair,
returns_structseq_pyi,
)
from torchgen.gen import parse_native_yaml, parse_tags_yaml
from torchgen.model import _TorchDispatchModeKey, DispatchKey, Variant
from torchgen.utils import FileManager
if TYPE_CHECKING:
from collections.abc import Sequence
def get_py_torch_functions(
    python_funcs: Sequence[PythonSignatureNativeFunctionPair],
    method: bool = False,
) -> Sequence[PythonSignatureGroup]:
    """
    Select the signature pairs that are bound at the top level, either as
    functions of the ``torch`` module (``method=False``) or as ``Tensor``
    methods (``method=True``), and return them grouped by name through
    ``group_overloads``.
    """
    wanted_variant = Variant.method if method else Variant.function

    def is_bound(pair: PythonSignatureNativeFunctionPair) -> bool:
        native = pair.function
        # Functions assigned to a python_module are not bound here, and the
        # requested variant (function vs. method) must be declared.
        return (
            should_generate_py_binding(native)
            and not native.python_module
            and wanted_variant in native.variants
        )

    selected = [pair for pair in python_funcs if is_bound(pair)]
    return group_overloads(selected)
# TODO: Consider defining some aliases for our Union[...] types, to make
# the stubs easier for humans to read.
# Parameter text for the optional `device` argument of hand-written signatures.
DEVICE_PARAM = "device: DeviceLikeType | None = None"
# Trailing parameters shared by hand-written tensor factory signatures
# (e.g. `tensor`, `arange`, `full`): dtype, device, and two boolean flags.
FACTORY_PARAMS = [
    "dtype: _dtype | None = None",
    DEVICE_PARAM,
    *(f"{flag}: _bool = False" for flag in ("requires_grad", "pin_memory")),
]
# NOTE: specifying indices for Tensor.__getitem__
# We can imitate numpy's definition of ndarray.__getitem__ found in numpy/__init__.pyi:
#
# key: (
# slice
# | EllipsisType
# | None
# | _ArrayLikeInt_co
# | SupportsIndex
# | tuple[slice | EllipsisType | None | _ArrayLikeInt_co | SupportsIndex, ...]
# )
#
# where:
#
# _ArrayLikeInt_co = _DualArrayLike[
# dtype[bool_ | integer[Any]],
# bool | int,
# ]
#
# and
#
# _DualArrayLike = (
# _SupportsArray[_DType]
# | _NestedSequence[_SupportsArray[_DType]]
# | _T
# | _NestedSequence[_T]
# )
#
# Moreover, _NestedSequence is a Protocol that matches arbitrary nesting of list/tuple.
# We can substitute and simplify:
# _SupportsArray -> Tensor
# _ArrayLikeInt_co -> [bool | int | Tensor | _NestedSequence[bool | int] | _NestedSequence[Tensor]]
# which leaves us with key: T | tuple[T, ...], where T is:
# T = (
# SupportsIndex | bool | int | slice | EllipsisType | None
# | Tensor | _NestedSequence[Tensor] | _NestedSequence[bool | int]
# )
_leaf_types = (
"_bool | _int | slice | EllipsisType | Tensor | None" # not SupportsIndex!
)
_index_types = f"SupportsIndex | {_leaf_types} | _NestedSequence[{_leaf_types}]"
_index_type_def = f"_Index: TypeAlias = {_index_types} # fmt: skip"
INDICES = "indices: _Index | tuple[_Index, ...]"
# Names for which generate_type_hints emits no stub unless the signature is
# deprecated; their hints are written by hand or come from elsewhere.
blocklist = [
    "__init_subclass__",
    "__new__",
    "__subclasshook__",
    "cdist",
    "device",
    "grad",
    "requires_grad",
    "range",
    # defined in functional
    "einsum",
    # Somehow, these are defined in both _C and in functional. Ick!
    "broadcast_tensors",
    # Manually define named tensor type stubs in __init__.pyi.in
    "align_tensors",
    "meshgrid",
    "cartesian_prod",
    "block_diag",
    "norm",
    "chain_matmul",
    "stft",
    "tensordot",
    "split",
    "unique_consecutive",
    "atleast_1d",
    "atleast_2d",
    "atleast_3d",
    # These are handled specially by python_arg_parser.cpp: each arithmetic
    # op is listed with its plain, in-place ("_") and "_out" spellings, ...
    *(
        f"{op}{suffix}"
        for op in ("add", "sub", "mul", "div", "true_divide", "floor_divide")
        for suffix in ("", "_", "_out")
    ),
    # ... followed by the conversion/copy entry points.
    "to",
    "_to_copy",
    "copy_",
]
# Operator dunder names (without the surrounding "__"), grouped by the stub
# shape sig_for_ops emits for them. Each binary group lists forward ops, then
# reflected ("r") ops, then in-place ("i") ops.
shift_ops = ("lshift", "rshift") + ("ilshift", "irshift")
arithmetic_ops = (
    ("add", "sub", "mul", "div", "pow", "mod", "truediv", "matmul", "floordiv")
    # reverse arithmetic
    + ("radd", "rsub", "rmul", "rtruediv", "rfloordiv", "rpow")
    # inplace ops
    + ("iadd", "idiv", "imul", "isub", "ifloordiv", "imod")
)
logic_ops = ("and", "or", "xor") + ("rand", "ror", "rxor") + ("iand", "ior", "ixor")
binary_ops = shift_ops + arithmetic_ops + logic_ops
symmetric_comparison_ops = ("eq", "ne")
asymmetric_comparison_ops = ("ge", "gt", "lt", "le")
comparison_ops = symmetric_comparison_ops + asymmetric_comparison_ops
unary_ops = ("neg", "abs", "invert")
to_py_type_ops = ("bool", "float", "complex", "long", "index", "int", "nonzero")
all_ops = binary_ops + comparison_ops + unary_ops + to_py_type_ops


def sig_for_ops(opname: str) -> list[str]:
    """Return the hand-written stub signature(s) for operator dunder ``opname``.

    These dunders are bound by hand in Python, so their stubs are hand-written
    too. ``opname`` must be of the form ``__name__``; ``name`` is classified
    against the op tables above. A name found in none of them raises
    ``ValueError``.
    """
    assert opname.endswith("__") and opname.startswith("__"), f"Unexpected op {opname}"
    name = opname[2:-2]

    # "rpow" is also in arithmetic_ops, so it has to be handled first.
    if name == "rpow":
        # somehow required to make mypy ci happy?
        return [
            f"def {opname}(self, other: Tensor | Number | _complex) -> Tensor: ... # type: ignore[has-type]"
        ]
    if name in arithmetic_ops:
        if not name.startswith("i"):
            return [f"def {opname}(self, other: Tensor | Number | _complex) -> Tensor: ..."]
        # In-place dunders such as `__iadd__` would ideally return `Self`;
        # the stub keeps `Tensor` and silences PYI034 instead.
        return [
            f"def {opname}(self, other: Tensor | Number | _complex) -> Tensor: ... # noqa: PYI034"
        ]
    if name in logic_ops or name in shift_ops:
        return [f"def {opname}(self, other: Tensor | _int) -> Tensor: ..."]
    if name in symmetric_comparison_ops:
        # unsafe override https://github.com/python/mypy/issues/5704
        return [
            f"def {opname}(self, other: Tensor | Number | _complex) -> Tensor: ... # type: ignore[overload-overlap]",
            f"def {opname}(self, other: object) -> _bool: ...",
        ]
    if name in asymmetric_comparison_ops:
        return [f"def {opname}(self, other: Tensor | Number | _complex) -> Tensor: ..."]
    if name in unary_ops:
        return [f"def {opname}(self) -> Tensor: ..."]
    if name in to_py_type_ops:
        # nonzero -> bool; bool/float/complex map to themselves; the integral
        # conversions (long, index, int) all map to int.
        if name == "nonzero":
            return_type = "_bool"
        elif name in {"bool", "float", "complex"}:
            return_type = "_" + name
        else:
            return_type = "_int"
        return [f"def {opname}(self) -> {return_type}: ..."]
    raise ValueError(f"unknown op {opname!r}")
def generate_type_hints(sig_group: PythonSignatureGroup) -> list[str]:
    """Return the stub lines for one signature group.

    Blocklisted names yield nothing unless the signature is deprecated. A
    group that has an out variant gets one signature with an optional ``out``
    argument; otherwise the functional signature is emitted. A vararg variant
    is appended when the signature provides one.
    """
    signature = sig_group.signature
    # Some deprecated ops that are on the blocklist are still included in pyi.
    if signature.name in blocklist and not signature.deprecated:
        return []

    no_out_variant = sig_group.outplace is None
    hints: list[str] = []
    if signature.deprecated and not no_out_variant:
        # Deprecated signatures keep functional and out variants separate
        # (native ops fuse them), so emit the functional one explicitly here.
        hints.append(signature.signature_str_pyi(skip_outputs=True))
    hints.append(signature.signature_str_pyi(skip_outputs=no_out_variant))

    vararg_hint = signature.signature_str_pyi_vararg(skip_outputs=no_out_variant)
    if vararg_hint:
        hints.append(vararg_hint)
    return hints
def get_max_pool_dispatch(name: str, arg_list: list[str]) -> dict[str, list[str]]:
    """Build the three overloads of a ``return_indices``-dispatched pooling function.

    ``arg_list`` holds parameter strings and must contain the placeholder
    ``"{return_indices}"``. The result maps ``name`` to three signatures:

    1. ``return_indices: Literal[False] = False`` returning ``Tensor``;
    2. ``return_indices: Literal[True]`` passed positionally (every parameter
       up to and including it loses its default and becomes positional-only),
       returning ``tuple[Tensor, Tensor]``;
    3. ``return_indices: Literal[True]`` as keyword-only, returning
       ``tuple[Tensor, Tensor]``.
    """
    flag_pos = arg_list.index("{return_indices}")
    leading = arg_list[: flag_pos + 1]
    trailing = arg_list[flag_pos + 1 :]

    def strip_defaults(arg: str) -> str:
        # An entry may hold several comma-separated parameters; drop " = default" from each.
        return ", ".join(piece.split(" = ")[0] for piece in arg.split(", "))

    positional_form = [strip_defaults(arg) for arg in leading] + ["/"] + trailing
    keyword_form = arg_list[:flag_pos] + ["*"] + arg_list[flag_pos:]

    def render(args: list[str], flag: str, returns: str) -> str:
        return defs(name, [arg.format(return_indices=flag) for arg in args], returns)

    return {
        name: [
            render(arg_list, "return_indices: Literal[False] = False", "Tensor"),
            render(
                positional_form, "return_indices: Literal[True]", "tuple[Tensor, Tensor]"
            ),
            render(
                keyword_form, "return_indices: Literal[True]", "tuple[Tensor, Tensor]"
            ),
        ]
    }
def gen_nn_functional(fm: FileManager) -> None:
    """
    Render ``torch/nn/functional.pyi`` and ``torch/_C/_nn.pyi`` from their
    ``.pyi.in`` templates through ``fm``.

    ``torch/_C/_nn.pyi`` receives ``c_nn_function_hints``: hand-written
    signatures for a (non-exhaustive) subset of ``torch._C._nn``.

    ``torch/nn/functional.pyi`` receives:

    * ``imported_hints``: ``from ... import (name as name, ...)`` re-exports
      of functions that ``torch.nn.functional`` takes from ``torch`` and
      ``torch._C._nn``;
    * ``dispatched_hints``: ``@overload``-ed signatures for the pooling
      functions generated by ``torch._jit_internal.boolean_dispatch``;
    * ``extra_nn_functional___all__``: an ``__all__ += [...]`` statement
      listing every name above.
    """
    INPUT = "input: Tensor"
    KERNEL_SIZE = "kernel_size: _int | _size"
    STRIDE_PADDING = [
        "stride: _int | _size | None = None",
        "padding: _int | _size = 0",
    ]

    def with_overloads(unsorted_hints: dict[str, list[str]]) -> list[str]:
        # Flatten in name order; when a name has several signatures, each of
        # them is prefixed with an @overload decorator.
        flat: list[str] = []
        for _, hints in sorted(unsorted_hints.items()):
            if len(hints) > 1:
                hints = ["@overload\n" + h for h in hints]
            flat += hints
        return flat

    # ---- torch/_C/_nn.pyi ------------------------------------------------
    # TODO the list for `torch._C._nn` is nonexhaustive
    unsorted_c_nn_function_hints: dict[str, list[str]] = {}
    for d in (2, 3):
        unsorted_c_nn_function_hints.update(
            {
                f"avg_pool{d}d": [
                    defs(
                        f"avg_pool{d}d",
                        [
                            INPUT,
                            KERNEL_SIZE,
                            *STRIDE_PADDING,
                            "ceil_mode: bool = False",
                            "count_include_pad: bool = True",
                            "divisor_override: int | None = None",
                        ],
                        "Tensor",
                    )
                ],
                f"fractional_max_pool{d}d": [
                    defs(
                        f"fractional_max_pool{d}d",
                        [
                            INPUT,
                            KERNEL_SIZE,
                            "output_size: _int | _size",
                            "_random_samples: Tensor",
                        ],
                        "tuple[Tensor, Tensor]",
                    )
                ],
                f"adaptive_max_pool{d}d": [
                    defs(
                        f"adaptive_max_pool{d}d",
                        [INPUT, "output_size: _int | _size"],
                        "tuple[Tensor, Tensor]",
                    )
                ],
                f"adaptive_avg_pool{d}d": [
                    defs(
                        f"adaptive_avg_pool{d}d",
                        [INPUT, "output_size: _int | _size"],
                        "Tensor",
                    )
                ],
                f"max_pool{d}d_with_indices": [
                    defs(
                        f"max_pool{d}d_with_indices",
                        [
                            INPUT,
                            KERNEL_SIZE,
                            *STRIDE_PADDING,
                            "dilation: _int | _size = 1",
                            "ceil_mode: bool = False",
                        ],
                        "tuple[Tensor, Tensor]",
                    )
                ],
            }
        )

    # Functions whose only parameter is the input tensor.
    for fn_name in (
        "hardsigmoid_",
        "hardswish",
        "hardswish_",
        "log_sigmoid",
        "mish",
        "mish_",
        "relu6",
        "relu6_",
        "silu",
        "silu_",
    ):
        unsorted_c_nn_function_hints[fn_name] = [defs(fn_name, [INPUT], "Tensor")]

    # Losses whose parameters are exactly (input, target, reduction).
    for fn_name in (
        "l1_loss",
        "mse_loss",
        "multilabel_margin_loss",
        "soft_margin_loss",
    ):
        unsorted_c_nn_function_hints[fn_name] = [
            defs(fn_name, [INPUT, "target: Tensor", "reduction: str = ..."], "Tensor")
        ]

    unsorted_c_nn_function_hints.update(
        {
            "hardtanh": [
                defs(
                    "hardtanh",
                    [
                        INPUT,
                        "min_val: float = ...",
                        "max_val: float = ...",
                        "*",
                        "out: Tensor | None = None",
                    ],
                    "Tensor",
                )
            ],
            "hardtanh_": [
                defs(
                    "hardtanh_",
                    [INPUT, "min_val: float = ...", "max_val: float = ..."],
                    "Tensor",
                )
            ],
            "elu_": [defs("elu_", [INPUT, "alpha: float = ..."], "Tensor")],
            "leaky_relu": [
                defs(
                    "leaky_relu",
                    [
                        INPUT,
                        "negative_slope: float = ...",
                        "*",
                        "out: Tensor | None = None",
                    ],
                    "Tensor",
                )
            ],
            "leaky_relu_": [
                defs("leaky_relu_", [INPUT, "negative_slope: float = ..."], "Tensor")
            ],
            "gelu": [defs("gelu", [INPUT, "approximate: str = ..."], "Tensor")],
            "softplus": [
                defs(
                    "softplus",
                    [INPUT, "beta: float = ...", "threshold: float = ..."],
                    "Tensor",
                )
            ],
            "softshrink": [
                defs("softshrink", [INPUT, "lambd: float = ..."], "Tensor")
            ],
            "hardsigmoid": [
                defs(
                    "hardsigmoid",
                    [INPUT, "*", "out: Tensor | None = None"],
                    "Tensor",
                )
            ],
            "linear": [
                defs(
                    "linear",
                    [INPUT, "weight: Tensor", "bias: Tensor | None = None"],
                    "Tensor",
                )
            ],
            "pad": [
                defs(
                    "pad",
                    [
                        INPUT,
                        "pad: Sequence[int]",
                        "mode: str = ...",
                        "value: float | None = None",
                    ],
                    "Tensor",
                )
            ],
            "one_hot": [
                defs("one_hot", ["tensor: Tensor", "num_classes: int = ..."], "Tensor")
            ],
            "scaled_dot_product_attention": [
                defs(
                    "scaled_dot_product_attention",
                    [
                        "query: Tensor",
                        "key: Tensor",
                        "value: Tensor",
                        "attn_mask: Tensor | None = None",
                        "dropout_p: float = 0.0",
                        "is_causal: bool = False",
                        "scale: float | None = None",
                        "enable_gqa: bool = False",
                    ],
                    "Tensor",
                )
            ],
            "binary_cross_entropy": [
                defs(
                    "binary_cross_entropy",
                    [
                        INPUT,
                        "target: Tensor",
                        "weight: Tensor | None = None",
                        "reduction: str = ...",
                    ],
                    "Tensor",
                )
            ],
            "col2im": [
                defs(
                    "col2im",
                    [
                        INPUT,
                        "output_size: _int | _size",
                        KERNEL_SIZE,
                        # Same trailing order as im2col (dilation, padding,
                        # stride), all required. STRIDE_PADDING would put
                        # stride before padding and add defaults the op lacks.
                        "dilation: _int | _size",
                        "padding: _int | _size",
                        "stride: _int | _size",
                    ],
                    "Tensor",
                )
            ],
            "elu": [
                defs(
                    "elu",
                    [
                        INPUT,
                        "alpha: float = 1.0",
                        "scale: float = 1.0",
                        "input_scale: float = 1.0",
                    ],
                    "Tensor",
                )
            ],
            "glu": [defs("glu", [INPUT, "dim: int = -1"], "Tensor")],
            "max_unpool2d": [
                defs(
                    "max_unpool2d",
                    [
                        INPUT,
                        "indices: Tensor",
                        "output_size: Sequence[int] | None",
                    ],
                    "Tensor",
                )
            ],
            "max_unpool3d": [
                defs(
                    "max_unpool3d",
                    [
                        INPUT,
                        "indices: Tensor",
                        "output_size: Sequence[int] | None",
                        "stride: _int | _size",
                        "padding: _int | _size",
                    ],
                    "Tensor",
                )
            ],
            "cross_entropy_loss": [
                defs(
                    "cross_entropy_loss",
                    [
                        INPUT,
                        "target: Tensor",
                        "weight: Tensor | None = None",
                        "reduction: str = ...",
                        "ignore_index: int = -100",
                        "label_smoothing: float = 0.0",
                    ],
                    "Tensor",
                )
            ],
            "huber_loss": [
                defs(
                    "huber_loss",
                    [
                        INPUT,
                        "target: Tensor",
                        "reduction: str = ...",
                        "delta: float = 1.0",
                    ],
                    "Tensor",
                )
            ],
            "im2col": [
                defs(
                    "im2col",
                    [
                        INPUT,
                        KERNEL_SIZE,
                        "dilation: _int | _size",
                        "padding: _int | _size",
                        "stride: _int | _size",
                    ],
                    "Tensor",
                )
            ],
            "multi_margin_loss": [
                defs(
                    "multi_margin_loss",
                    [
                        INPUT,
                        "target: Tensor",
                        "p: float = 1.0",
                        "margin: float = 1.0",
                        "weight: Tensor | None = None",
                        "reduction: str = ...",
                    ],
                    "Tensor",
                )
            ],
            "nll_loss_nd": [
                defs(
                    "nll_loss_nd",
                    [
                        INPUT,
                        "target: Tensor",
                        "weight: Tensor | None = None",
                        "reduction: str = ...",
                        "ignore_index: int = -100",
                    ],
                    "Tensor",
                )
            ],
            "smooth_l1_loss": [
                defs(
                    "smooth_l1_loss",
                    [
                        INPUT,
                        "target: Tensor",
                        "reduction: str = ...",
                        "beta: float = 1.0",
                    ],
                    "Tensor",
                )
            ],
        }
    )
    c_nn_function_hints = with_overloads(unsorted_c_nn_function_hints)

    # ---- torch/nn/functional.pyi -------------------------------------------
    # Names appended to `__all__`, in the order they are discovered below.
    nn_functional_all_names: list[str] = []

    # Functions imported into `torch.nn.functional` from `torch`, perhaps being filtered
    # through an `_add_docstr` call
    torch_imports = [
        "adaptive_avg_pool1d",
        "avg_pool1d",
        "bilinear",
        "celu_",
        "channel_shuffle",
        "conv_tbc",
        "conv_transpose1d",
        "conv_transpose2d",
        "conv_transpose3d",
        "conv1d",
        "conv2d",
        "conv3d",
        "cosine_similarity",
        "hardshrink",
        "native_channel_shuffle",
        "pairwise_distance",
        "pdist",
        "pixel_shuffle",
        "pixel_unshuffle",
        "prelu",
        "relu_",
        "rrelu_",
        "selu_",
    ]
    imported_hints = [
        "from torch import (",
        *sorted(f" {name} as {name}," for name in torch_imports),
        ")",
    ]
    nn_functional_all_names.extend(torch_imports)

    # Functions imported into `torch.nn.functional` from `torch._C._nn`
    c_nn_imports = [
        "avg_pool2d",
        "avg_pool3d",
        "elu_",
        "gelu",
        "hardtanh_",
        "leaky_relu_",
        "linear",
        "log_sigmoid",
        "one_hot",
        "pad",
        "scaled_dot_product_attention",
        "softplus",
        "softshrink",
    ]
    # Some functions are exposed in nn.functional under a different name.
    renamed = {"log_sigmoid": "logsigmoid"}
    imported_hints += [
        "from torch._C._nn import (",
        *sorted(f" {name} as {renamed.get(name, name)}," for name in c_nn_imports),
        ")",
    ]
    nn_functional_all_names.extend(renamed.get(name, name) for name in c_nn_imports)

    # Functions generated by `torch._jit_internal.boolean_dispatch` in `nn.functional`
    unsorted_dispatched_hints: dict[str, list[str]] = {}
    for d in (1, 2, 3):
        unsorted_dispatched_hints.update(
            **get_max_pool_dispatch(
                f"max_pool{d}d",
                [
                    INPUT,
                    KERNEL_SIZE,
                    *STRIDE_PADDING,
                    "dilation: _int | _size = 1",
                    "ceil_mode: bool = False",
                    "{return_indices}",
                ],
            ),
            **get_max_pool_dispatch(
                f"fractional_max_pool{d}d",
                [
                    INPUT,
                    KERNEL_SIZE,
                    "output_size: _int | _size | None = None",
                    "output_ratio: _ratio_any_t | None = None",
                    "{return_indices}",
                    "_random_samples: Tensor | None = None",
                ],
            ),
            **get_max_pool_dispatch(
                f"adaptive_max_pool{d}d",
                [
                    INPUT,
                    "output_size: _int | _size",
                    "{return_indices}",
                ],
            ),
        )
    # There's no fractional_max_pool1d
    del unsorted_dispatched_hints["fractional_max_pool1d"]
    # Dict insertion order (not sorted order) determines the __all__ order here.
    nn_functional_all_names.extend(unsorted_dispatched_hints)
    dispatched_hints = with_overloads(unsorted_dispatched_hints)

    extra_nn_functional___all__ = [
        "__all__ += [",
        *(f' "{name}",' for name in nn_functional_all_names),
        "]",
    ]

    fm.write_with_template(
        "torch/nn/functional.pyi",
        "torch/nn/functional.pyi.in",
        lambda: {
            "imported_hints": imported_hints,
            "dispatched_hints": dispatched_hints,
            "extra_nn_functional___all__": extra_nn_functional___all__,
        },
    )
    fm.write_with_template(
        "torch/_C/_nn.pyi",
        "torch/_C/_nn.pyi.in",
        lambda: {
            "c_nn_function_hints": c_nn_function_hints,
        },
    )
"""
We gather the docstrings for torch with the following steps:
1. Mock torch and torch._C, which are the only dependencies of the docs files
2. Mock the _add_docstr function to save the docstrings
3. Import the docs files to trigger mocked _add_docstr and collect docstrings
"""
def gather_docstrs() -> dict[str, str]:
    """Collect torch docstrings keyed by the documented object's dotted name.

    ``_torch_docs`` and ``_tensor_docs`` are imported (from the relative
    ``torch`` directory) against mocked ``torch`` / ``torch._C`` modules. The
    mocked ``_add_docstr`` records each docstring under the mock's name, e.g.
    ``"torch.abs"``. If either module cannot be found, a warning is issued and
    whatever was collected so far (possibly nothing) is returned.
    """
    collected: dict[str, str] = {}

    def record_docstr(func: Mock, docstr: str) -> None:
        collected[func._extract_mock_name()] = docstr.strip()

    docs_modules = (
        ("torch._torch_docs", "_torch_docs"),
        ("torch._tensor_docs", "_tensor_docs"),
    )
    # patch.dict / patch.object restore sys.modules and sys.path on exit.
    with patch.dict(sys.modules), patch.object(sys, "path", [*sys.path, "torch"]):
        sys.modules["torch"] = Mock(name="torch")
        sys.modules["torch._C"] = Mock(_add_docstr=record_docstr)
        try:
            # Importing the docs modules runs their add_docstr calls, which
            # now land in record_docstr.
            for qualified, module_name in docs_modules:
                sys.modules[qualified] = importlib.import_module(module_name)
        except ModuleNotFoundError:
            # Gracefully fail if these modules are not importable
            warn(
                "Failed to import _torch_docs/_tensor_docs, skipping docstring in pyi files."
            )
    return collected
def add_docstr_to_hint(docstr: str, hint: str) -> str:
    """Attach ``docstr`` to a stub ``hint`` as a raw triple-quoted string.

    A hint containing ``...`` is treated as a function/method stub and must end
    with ``...``. That ellipsis body is replaced by an indented docstring, and
    trailing whitespace is stripped from every line. Any other hint is treated
    as an attribute/property and gets the docstring on the following line.
    """
    cleaned = inspect.cleandoc(docstr).strip()
    if "..." not in hint:
        # attribute or property
        return f'{hint}\nr"""{cleaned}"""'

    assert hint.endswith("..."), f"Hint `{hint}` does not end with '...'"
    signature = hint.removesuffix("...").rstrip()
    body = textwrap.indent(f'r"""\n{cleaned}\n"""', prefix=" ")
    lines = (signature + "\n" + body).splitlines()
    return "\n".join(line.rstrip() for line in lines).rstrip()
def gen_pyi(
native_yaml_path: str,
tags_yaml_path: str,
deprecated_yaml_path: str,
fm: FileManager,
) -> None:
"""gen_pyi()
This function generates a pyi file for torch.
"""
# Some of this logic overlaps with generate_python_signature in
# tools/autograd/gen_python_functions.py; however, this
# function is all about generating mypy type signatures, whereas
# the other function generates are custom format for argument
# checking. If you are update this, consider if your change
# also needs to update the other file.
# Dictionary for NamedTuple definitions
structseqs: dict[str, str] = {}
# Generate type signatures for top-level functions
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
unsorted_function_hints: dict[str, list[str]] = collections.defaultdict(list)
for n, n1, n2 in [
("csr", "crow", "col"),
("csc", "ccol", "row"),
("bsr", "crow", "col"),
("bsc", "ccol", "row"),
]:
unsorted_function_hints.update(
{
f"sparse_{n}_tensor": [
defs(
f"sparse_{n}_tensor",
[
f"{n1}_indices: Tensor | list",
f"{n2}_indices: Tensor | list",
"values: Tensor | list",
"size: _size | None = None",
"*",
"dtype: _dtype | None = None",
"device: DeviceLikeType | None = None",
"requires_grad: _bool = False",
"check_invariants: _bool | None = None",
],
"Tensor",
)
],
}
)
unsorted_function_hints.update(
{
"set_flush_denormal": [
defs("set_flush_denormal", ["mode: _bool"], "_bool")
],
"get_default_dtype": [defs("get_default_dtype", [], "_dtype")],
"asarray": [
defs(
"asarray",
[
"obj: Any",
"*",
"dtype: _dtype | None = None",
"device: DeviceLikeType | None = None",
"copy: _bool | None = None",
"requires_grad: _bool = False",
],
"Tensor",
)
],
"from_numpy": [defs("from_numpy", ["ndarray"], "Tensor")],
"frombuffer": [
defs(
"frombuffer",
[
"buffer: Any",
"*",
"dtype: _dtype",
"count: int = -1",
"offset: int = 0",
"requires_grad: _bool = False",
],
"Tensor",
)
],
"numel": [defs("numel", ["self: Tensor"], "_int")],
"as_tensor": [
defs(
"as_tensor",
["data: Any", "dtype: _dtype | None = None", DEVICE_PARAM],
"Tensor",
)
],
"get_num_threads": [defs("get_num_threads", [], "_int")],
"set_num_threads": [defs("set_num_threads", ["num: _int"], "None")],
"init_num_threads": [defs("init_num_threads", [], "None")],
"get_num_interop_threads": [defs("get_num_interop_threads", [], "_int")],
"set_num_interop_threads": [
defs("set_num_interop_threads", ["num: _int"], "None")
],
# These functions are explicitly disabled by
# SKIP_PYTHON_BINDINGS because they are hand bound.
# Correspondingly, we must hand-write their signatures.
"tensor": [defs("tensor", ["data: Any", *FACTORY_PARAMS], "Tensor")],
"sparse_coo_tensor": [
defs(
"sparse_coo_tensor",
[
"indices: Tensor",
"values: Tensor | list",
"size: _size | None = None",
"*",
"dtype: _dtype | None = None",
"device: DeviceLikeType | None = None",
"requires_grad: _bool = False",
"check_invariants: _bool | None = None",
"is_coalesced: _bool | None = None",
],
"Tensor",
)
],
"sparse_compressed_tensor": [
defs(
"sparse_compressed_tensor",
[
"compressed_indices: Tensor | list",
"plain_indices: Tensor | list",
"values: Tensor | list",
"size: _size | None = None",
"*",
"dtype: _dtype | None = None",
"layout: _layout | None = None",
"device: DeviceLikeType | None = None",
"requires_grad: _bool = False",
"check_invariants: _bool | None = None",
],
"Tensor",
)
],
"_sync": [defs("_sync", ["t: Tensor"], "None")],
"_is_functional_tensor": [
defs("_is_functional_tensor", ["t: Tensor"], "_bool")
],
"_is_functional_tensor_base": [
"def _is_functional_tensor_base(t: Tensor) -> _bool: ..."
],
"_from_functional_tensor": [
defs("_from_functional_tensor", ["t: Tensor"], "Tensor")
],
"_to_functional_tensor": [
defs("_to_functional_tensor", ["t: Tensor"], "Tensor")
],
"_functionalize_replace": [
defs(
"_functionalize_replace", ["self_: Tensor", "other: Tensor"], "None"
)
],
"_functionalize_commit_update": [
defs("_functionalize_commit_update", ["t: Tensor"], "None")
],
"_functionalize_unsafe_set": [
"def _functionalize_unsafe_set(dst: Tensor, src: Tensor) -> None: ..."
],
"_functionalize_mark_mutation_hidden_from_autograd": [
defs(
"_functionalize_mark_mutation_hidden_from_autograd",
["t: Tensor"],
"None",
)
],
"_functionalize_mutation_counter": [
defs(
"_functionalize_mutation_counter",
["t: Tensor"],
"_int",
)
],
"_functionalize_storage_changed_counter": [
defs(
"_functionalize_storage_changed_counter",
["t: Tensor"],
"_int",
)
],
"_functionalize_inductor_storage_resized_counter": [
defs(
"_functionalize_inductor_storage_resized_counter",
["t: Tensor"],
"_int",
)
],
"_functionalize_are_all_mutations_hidden_from_autograd": [
defs(
"_functionalize_are_all_mutations_hidden_from_autograd",
["t: Tensor"],
"_bool",
)
],
"_functionalize_are_all_mutations_under_no_grad_or_inference_mode": [
defs(
"_functionalize_are_all_mutations_under_no_grad_or_inference_mode",
["t: Tensor"],
"_bool",
)
],
"_functionalize_was_inductor_storage_resized": [
defs(
"_functionalize_was_inductor_storage_resized",
["t: Tensor"],
"_bool",
)
],
"_functionalize_sync": [defs("_functionalize_sync", ["t: Tensor"], "None")],
"_functionalize_was_storage_changed": [
defs("_functionalize_was_storage_changed", ["tensor: Tensor"], "_bool")
],
"_functionalize_mark_storage_changed": [
"def _functionalize_mark_storage_changed(tensor: Tensor) -> _bool: ..."
],
"_functionalize_has_metadata_mutation": [
defs(
"_functionalize_has_metadata_mutation", ["tensor: Tensor"], "_bool"
)
],
"_functionalize_apply_view_metas": [
defs(
"_functionalize_apply_view_metas",
["tensor: Tensor", "base: Tensor"],
"Tensor",
)
],
"_functionalize_is_symbolic": [
defs("_functionalize_is_symbolic", ["tensor: Tensor"], "_bool")
],
"_enable_functionalization": [
defs(
"_enable_functionalization",
["*", "reapply_views: _bool = False"],
"None",
)
],
"_disable_functionalization": [defs("_disable_functionalization")],
"range": [
defs(
"range",
[
"start: Number",
"end: Number",
"step: Number = 1",
"*",
"out: Tensor | None = None",
*FACTORY_PARAMS,
],
"Tensor",
)
],
"arange": [
defs(
"arange",
[
"start: Number",
"end: Number",
"step: Number",
"*",
"out: Tensor | None = None",
*FACTORY_PARAMS,
],
"Tensor",
),
defs(
"arange",
[
"start: Number",
"end: Number",
"*",
"out: Tensor | None = None",
*FACTORY_PARAMS,
],
"Tensor",
),
defs(
"arange",
["end: Number", "*", "out: Tensor | None = None", *FACTORY_PARAMS],
"Tensor",
),
],
"linspace": [
defs(
"linspace",
[
"start: Number",
"end: Number",
"steps: _int | None = None",
"*",
"out: Tensor | None = None",
*FACTORY_PARAMS,
],
"Tensor",
)
],
"logspace": [
defs(
"logspace",
[
"start: Number",
"end: Number",
"steps: _int | None = None",
"base: _float = 10.0",
"*",
"out: Tensor | None = None",
*FACTORY_PARAMS,
],
"Tensor",
)
],
"randint": [
defs(
"randint",
[
"low: _int",
"high: _int",
"size: _size",
"*",
"generator: Generator | None = None",
*FACTORY_PARAMS,
],
"Tensor",
),
defs(
"randint",
[
"high: _int",
"size: _size",
"*",
"generator: Generator | None = None",
*FACTORY_PARAMS,
],
"Tensor",
),
],
"full": [
defs(
"full",
[
"size: _size",
"fill_value: Number | _complex",
"*",
"out: Tensor | None = None",
"layout: _layout = strided",
*FACTORY_PARAMS,
],
"Tensor",
),
defs(
"full",
[
"size: _size",
"fill_value: Number | _complex",
"*",
"names: list[str | None]",
"layout: _layout = strided",
*FACTORY_PARAMS,
],
"Tensor",
),
],
"is_grad_enabled": [defs("is_grad_enabled", [], "_bool")],
"is_inference_mode_enabled": [
defs("is_inference_mode_enabled", [], "_bool")
],
"nonzero": [
defs(
"nonzero",
[
"input: Tensor",
"*",
"as_tuple: Literal[False] = False",
"out: Tensor | None = None",
],
"Tensor",
),
defs(
"nonzero",
["input: Tensor", "*", "as_tuple: Literal[True]"],
"tuple[Tensor, ...]",
),
],
"dsmm": [defs("dsmm", ["input: Tensor", "mat2: Tensor"], "Tensor")],
"hsmm": [defs("hsmm", ["input: Tensor", "mat2: Tensor"], "Tensor")],
"saddmm": [
defs(
"saddmm",
[
"input: Tensor",
"mat1: Tensor",
"mat2: Tensor",
"*",
"beta: Number = 1",
"alpha: Number = 1",
"out: Tensor | None = None",
],
"Tensor",
)
],
"spmm": [defs("spmm", ["input: Tensor", "mat2: Tensor"], "Tensor")],
"div": [
defs(
"div",
[
"input: Tensor | Number",
"other: Tensor | Number",
"*",
"rounding_mode: str | None = None",
"out: Tensor | None = None",
],
"Tensor",
)
],
}
)
for binop in ["true_divide", "floor_divide"]:
unsorted_function_hints[binop].append(
defs(
binop,
[
"input: Tensor | Number",
"other: Tensor | Number",
"*",
"out: Tensor | None = None",
],
"Tensor",
)
)
for binop in ["mul"]:
unsorted_function_hints[binop].append(
defs(
binop,
[
"input: Tensor | Number | _complex",
"other: Tensor | Number | _complex",
"*",
"out: Tensor | None = None",
],
"Tensor",
)
)
for binop in ["add", "sub"]:
unsorted_function_hints[binop].append(
defs(
binop,
[
"input: Tensor | Number | _complex",
"other: Tensor | Number | _complex",
"*",
"alpha: Number | _complex | None = 1",
"out: Tensor | None = None",
],
"Tensor",
)
)
native_functions = parse_native_yaml(
native_yaml_path, tags_yaml_path
).native_functions
native_functions = list(filter(should_generate_py_binding, native_functions))
function_signatures = load_signatures(
native_functions, deprecated_yaml_path, method=False, pyi=True
)
sig_groups = get_py_torch_functions(function_signatures)
for group in sorted(sig_groups, key=lambda g: g.signature.name):
name = group.signature.name
unsorted_function_hints[name] += generate_type_hints(group)
structseq = returns_structseq_pyi(group.signature)
if structseq is not None and not group.signature.deprecated:
# deprecated structseqs are currently not included for torch functions
tuple_name, tuple_def = structseq
if tuple_name in structseqs:
assert structseqs[tuple_name] == tuple_def
else:
structseqs[tuple_name] = tuple_def
def replace_special_case(hint: str) -> str:
# NB: Keep this in sync with enum in aten/src/ATen/core/Reduction.h
hint = hint.replace("at::Reduction::Mean", "1")
hint = hint.replace(": Tensor = None", ": Tensor | None = None")
return hint
docstrs = gather_docstrs()
function_hints = []
for name, hints in sorted(unsorted_function_hints.items()):
hints = [replace_special_case(h) for h in hints]
if len(hints) > 1:
hints = ["@overload\n" + h for h in hints]
docstr = docstrs.get(f"torch.{name}")
if docstr is not None:
hints = [add_docstr_to_hint(docstr, h) for h in hints]
function_hints += hints
# Generate type signatures for Tensor methods
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
index_type_def = [_index_type_def]
unsorted_tensor_method_hints: dict[str, list[str]] = collections.defaultdict(list)
unsorted_tensor_method_hints.update(
{
"size": [
defs("size", ["self", "dim: None = None"], "Size"),
defs("size", ["self", "dim: _int"], "_int"),
],
"stride": [
defs("stride", ["self", "dim: None = None"], "tuple[_int, ...]"),
defs("stride", ["self", "dim: _int"], "_int"),
],
"new_ones": [
defs("new_ones", ["self", "size: _size", *FACTORY_PARAMS], "Tensor")
],
"new_tensor": [
defs("new_tensor", ["self", "data: Any", *FACTORY_PARAMS], "Tensor")
],
"__new__": [defs("__new__", ["cls", "*args", "**kwargs"], "Self")],
# new and __init__ have the same signatures differ only in return type
# Adapted from legacy_tensor_ctor and legacy_tensor_new
"new": [
defs("new", ["cls", "*args: Any", DEVICE_PARAM], "Self"),
defs("new", ["cls", "storage: Storage"], "Self"),
defs("new", ["cls", "other: Tensor"], "Self"),
defs("new", ["cls", "size: _size", "*", DEVICE_PARAM], "Self"),
],
"__init__": [
defs("__init__", ["self", "*args: Any", DEVICE_PARAM], "None"),
defs("__init__", ["self", "storage: Storage"], "None"),
defs("__init__", ["self", "other: Tensor"], "None"),
defs("__init__", ["self", "size: _size", "*", DEVICE_PARAM], "None"),
],
"as_subclass": [defs("as_subclass", ["self", "cls: type[S]"], "S")],
"_make_subclass": [
"@staticmethod\n"
+ defs(
"_make_subclass",
[
"cls: type[S]",
"data: Tensor",
"require_grad: _bool = False",
"dispatch_strides: _bool = False",
"dispatch_device: _bool = False",
"device_for_backend_keys: _device | None = None",
],
"S",
)
],
"_make_wrapper_subclass": [
"@staticmethod\n"
+ defs(
"_make_wrapper_subclass",
[
"cls: type[S]",
"size: Sequence[_int | SymInt]",
"strides: Sequence[_int | SymInt] | None = None",
"storage_offset: _int | SymInt | None = None",
"memory_format: torch.memory_format | None = None",
"dtype: _dtype | None = None",
"layout: _layout = strided",
"device: _device | None = None",
"pin_memory: _bool = False",
"requires_grad: _bool = False",
"dispatch_sizes_strides_policy: str | None = None",
"dispatch_device: _bool = False",
"dispatch_layout: _bool = False",
"_extra_dispatch_keys: torch.DispatchKeySet | None = None",
"storage_size: _int | SymInt | None = None",
],
"S",
)
],
"_dtensor__new__": [
"@staticmethod\n"
+ defs(
"_dtensor__new__",
[
"cls: type[S]",
"local_tensor: Tensor",
"spec: torch.distributed.tensor._dtensor_spec.DTensorSpec",
"requires_grad: _bool",
],
"S",
)
],
"__contains__": [defs("__contains__", ["self", "item: Any", "/"], "_bool")],
"__getitem__": [defs("__getitem__", ["self", INDICES, "/"], "Tensor")],
"__setitem__": [
defs(
"__setitem__",
["self", INDICES, "value: Tensor | Number", "/"],
"None",
)
],
"tolist": [defs("tolist", ["self"], "list")],
"requires_grad_": [
defs("requires_grad_", ["self", "mode: _bool = True"], "Tensor")
],
"element_size": [defs("element_size", ["self"], "_int")],
"data_ptr": [defs("data_ptr", ["self"], "_int")],
"dim": [defs("dim", ["self"], "_int")],
"nonzero": [
defs(
"nonzero",
["self", "*", "as_tuple: Literal[False] = False"],
"Tensor",
),
defs(
"nonzero",
["self", "*", "as_tuple: Literal[True]"],
"tuple[Tensor, ...]",
),
],
"numel": [defs("numel", ["self"], "_int")],
"ndimension": [defs("ndimension", ["self"], "_int")],
"nelement": [defs("nelement", ["self"], "_int")],
"cuda": [
defs(
"cuda",
[
"self",
"device: _device | _int | str | None = None",
"non_blocking: _bool = False",
"memory_format: torch.memory_format = torch.preserve_format",
],
"Tensor",
)
],
"xpu": [
defs(
"xpu",
[
"self",
"device: _device | _int | str | None = None",
"non_blocking: _bool = False",
"memory_format: torch.memory_format = torch.preserve_format",
],
"Tensor",
)
],
"cpu": [
defs(
"cpu",
[
"self",
"memory_format: torch.memory_format = torch.preserve_format",
],
"Tensor",
)
],
"numpy": [
defs("numpy", ["self", "*", "force: _bool = False"], "numpy.ndarray")
],
"apply_": [defs("apply_", ["self", "callable: Callable"], "Tensor")],
"map_": [
defs("map_", ["self", "other: Tensor", "callable: Callable"], "Tensor")
],
"map2_": [
defs(
"map2_",
["self", "x: Tensor", "y: Tensor", "callable: Callable"],
"Tensor",
)
],
"storage": [defs("untyped_storage", ["self"], "UntypedStorage")],
"storage_type": [defs("storage_type", ["self"], "Storage")],
"type": [
defs(
"type",
["self", "dtype: None = None", "non_blocking: _bool = False"],
"str",
),
defs(
"type",
["self", "dtype: str | _dtype", "non_blocking: _bool = False"],
"Tensor",
),
],
"get_device": [defs("get_device", ["self"], "_int")],
"contiguous": [
defs(
"contiguous",
[
"self",
"memory_format: torch.memory_format = torch.contiguous_format",
],
"Tensor",
)
],
"has_names": [defs("has_names", ["self"], "_bool")],
"is_contiguous": [
defs(
"is_contiguous",
[
"self",
"memory_format: torch.memory_format = torch.contiguous_format",
],
"_bool",
)
],
"_is_view": [defs("_is_view", ["self"], "_bool")],
"is_cpu": ["is_cpu: _bool"],
"is_cuda": ["is_cuda: _bool"],
"is_xpu": ["is_xpu: _bool"],
"is_leaf": ["is_leaf: _bool"],
"is_nested": ["is_nested: _bool"],
"is_sparse": ["is_sparse: _bool"],
"is_sparse_csr": ["is_sparse_csr: _bool"],
"is_quantized": ["is_quantized: _bool"],
"is_meta": ["is_meta: _bool"],
"is_mps": ["is_mps: _bool"],
"is_mtia": ["is_mtia: _bool"],
"is_maia": ["is_maia: _bool"],
"is_mkldnn": ["is_mkldnn: _bool"],
"is_vulkan": ["is_vulkan: _bool"],
"is_ipu": ["is_ipu: _bool"],
"storage_offset": [defs("storage_offset", ["self"], "_int | SymInt")],
"to": [
(
defs(
"to",
[
"self",
*to_args,
"non_blocking: _bool = False",
"copy: _bool = False",
"*",
"memory_format: torch.memory_format | None = None",
],
"Tensor",
)
)
for to_args in [
["dtype: _dtype"],
[
"device: DeviceLikeType | None = None",
"dtype: _dtype | None = None",
],
["other: Tensor"],
]
],
"item": [defs("item", ["self"], "Number")],
"copy_": [
defs(
"copy_",
["self", "other: Tensor", "non_blocking: _bool = False"],
"Tensor",
)
],
"set_": [
defs(
"set_",
[
"self",
"source: Storage | TypedStorage | UntypedStorage",
"storage_offset: IntLikeType",
"size: _symsize",
"stride: _symsize",
],
"Tensor",
),
defs(
"set_",
["self", "source: Storage | TypedStorage | UntypedStorage"],
"Tensor",
),
],
"split": [
defs(
"split",
["self", "split_size: _int", "dim: _int = 0"],
"Sequence[Tensor]",
),
defs(
"split",
["self", "split_size: tuple[_int, ...]", "dim: _int = 0"],
"Sequence[Tensor]",
),
],
"div": [
defs(
"div",
[
"self",
"other: Tensor | Number",
"*",
"rounding_mode: str | None = None",
],
"Tensor",
)
],
"div_": [
defs(
"div_",
[
"self",
"other: Tensor | Number",
"*",
"rounding_mode: str | None = None",
],
"Tensor",
)
],
}
)
for binop in ["true_divide", "floor_divide"]:
for inplace in [False, True]:
out_args = ["*", "out: Tensor | None = None"]
if inplace:
binop += "_"
out_args = []
unsorted_tensor_method_hints[binop].append(
defs(
binop,
[
"self",
"other: Tensor | Number | torch.SymInt | torch.SymFloat",
*out_args,
],
"Tensor",
)
)
for binop in ["mul"]:
for inplace in [False, True]:
out_args = ["*", "out: Tensor | None = None"]
if inplace:
binop += "_"
out_args = []
unsorted_tensor_method_hints[binop].append(
defs(
binop,
[
"self",
"other: Tensor | Number | _complex | torch.SymInt | torch.SymFloat",
*out_args,
],
"Tensor",
)
)
for binop in ["add", "sub"]:
for inplace in [False, True]:
out_args = ["out: Tensor | None = None"]
if inplace:
binop += "_"
out_args = []
unsorted_tensor_method_hints[binop].append(
defs(
binop,
[
"self",
"other: Tensor | Number | _complex | torch.SymInt | torch.SymFloat",
"*",
"alpha: Number | _complex | None = 1",
*out_args,
],
"Tensor",
)
)
simple_conversions = [
"bfloat16",
"bool",
"byte",
"char",
"double",
"float",
"half",
"int",
"long",
"short",
]
for name in simple_conversions:
unsorted_tensor_method_hints[name].append(f"def {name}(self) -> Tensor: ...")
# pyi tensor methods don't currently include deprecated signatures for some reason
# TODO: we should probably add them in
tensor_method_signatures = load_signatures(
native_functions,
deprecated_yaml_path,
method=True,
skip_deprecated=True,
pyi=True,
)
tensor_method_sig_groups = get_py_torch_functions(
tensor_method_signatures, method=True
)
for group in sorted(tensor_method_sig_groups, key=lambda g: g.signature.name):
name = group.signature.name
unsorted_tensor_method_hints[name] += generate_type_hints(group)
structseq = returns_structseq_pyi(group.signature)
if structseq is not None and not group.signature.deprecated:
# deprecated structseqs are currently not included for torch functions
tuple_name, tuple_def = structseq
if tuple_name in structseqs:
assert structseqs[tuple_name] == tuple_def
else:
structseqs[tuple_name] = tuple_def
for op in all_ops:
name = f"__{op}__"
unsorted_tensor_method_hints[name] += sig_for_ops(name)
tensor_method_hints = []
for name, hints in sorted(unsorted_tensor_method_hints.items()):
if len(hints) > 1:
hints = ["@overload\n" + h for h in hints]
docstr = docstrs.get(f"torch._C.TensorBase.{name}")
if docstr is not None:
hints = [add_docstr_to_hint(docstr, h) for h in hints]
tensor_method_hints += hints
# TODO: Missing type hints for nn
# Generate structseq definitions
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
structseqs = dict(sorted(structseqs.items()))
structseq_defs = [f"{defn}\n" for defn in structseqs.values()]
return_types___all__ = [
"__all__ = [",
' "pytree_register_structseq",',
' "all_return_types",',
*(f' "{name}",' for name in structseqs),
"]",
]
# Generate type signatures for legacy classes
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
legacy_storage_base_hints = ["class StorageBase: ..."]
legacy_class_hints = []
for c in (
"DoubleTensor",
"FloatTensor",
"BFloat16Tensor",
"LongTensor",
"IntTensor",
"ShortTensor",
"HalfTensor",
"CharTensor",
"ByteTensor",
"BoolTensor",
):
legacy_class_hints.append(f"class {c}(Tensor): ...")
# Generate type signatures for dtype classes
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
# TODO(#146647): don't explicitly list dtypes here; get it from canonical
# source
dtype_class_hints = [
f"{n}: dtype = ..."
for n in [
"float32",
"float",
"float64",
"double",
"float16",
"bfloat16",
"float8_e4m3fn",
"float8_e4m3fnuz",
"float8_e5m2",
"float8_e5m2fnuz",
"float8_e8m0fnu",
"float4_e2m1fn_x2",
"half",
"uint8",
"uint16",
"uint32",
"uint64",
"int8",
"int16",
"short",
"int32",
"int",
"int64",
"long",
"complex32",
"complex64",
"chalf",
"cfloat",
"complex128",
"cdouble",
"quint8",
"qint8",
"qint32",
"bool",
"quint4x2",
"quint2x4",
"bits1x8",
"bits2x4",
"bits4x2",
"bits8",
"bits16",
]
]
# Generate __all__ directive
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
# Include only the functions that contain hints, to prevent undefined
# symbols to be included in the `__all__` directive.
hinted_function_names = {
name for name, hint in unsorted_function_hints.items() if hint
}
all_symbols = sorted(hinted_function_names.union(structseqs))
all_directive = [
"__all__ = [",
*(f' "{name}",' for name in all_symbols),
"]",
]
# Dispatch key hints
# ~~~~~~~~~~~~~~~~~~
dispatch_key_hints = [f"{d.name} = ..." for d in DispatchKey]
torch_dispatch_mode_key_hints = [f"{k.name} = ..." for k in _TorchDispatchModeKey]
# Tags Enum type hints
# ~~~~~~~~~~~~~~~~~~~~
tag_names = sorted(parse_tags_yaml(tags_yaml_path))
tag_attributes = "\n".join(
f"{name} = {index}" for index, name in enumerate(tag_names)
)
# Write out the stub
# ~~~~~~~~~~~~~~~~~~
env = {
"structseq_defs": structseq_defs,
"return_types___all__": return_types___all__,
"function_hints": function_hints,
"index_type_def": index_type_def,
"tensor_method_hints": tensor_method_hints,
"legacy_class_hints": legacy_class_hints,
"legacy_storage_base_hints": legacy_storage_base_hints,
"dtype_class_hints": dtype_class_hints,
"dispatch_key_hints": dispatch_key_hints,
"torch_dispatch_mode_key_hints": torch_dispatch_mode_key_hints,
"all_directive": all_directive,
"tag_attributes": tag_attributes,
}
fm.write_with_template(
"torch/_C/__init__.pyi",
"torch/_C/__init__.pyi.in",
lambda: env,
)
fm.write_with_template(
"torch/_C/_VariableFunctions.pyi",
"torch/_C/_VariableFunctions.pyi.in",
lambda: env,
)
fm.write_with_template(
"torch/_VF.pyi",
"torch/_C/_VariableFunctions.pyi.in",
lambda: env,
)
fm.write_with_template(
"torch/return_types.pyi",
"torch/_C/return_types.pyi.in",
lambda: env,
)
gen_nn_functional(fm)
def main() -> None:
    """Command-line entry point: parse options and generate the ``.pyi`` stubs.

    Every option has a default. Paths are resolved relative to the current
    working directory, and output is written under ``--out``, which is
    ``.`` by default.
    """
    arg_parser = argparse.ArgumentParser(
        description="Generate type stubs for PyTorch"
    )

    # (flag, metavar, default, help) for each option, in registration order.
    # A metavar of None means the option is registered without one.
    option_specs = (
        (
            "--native-functions-path",
            "NATIVE",
            "aten/src/ATen/native/native_functions.yaml",
            "path to native_functions.yaml",
        ),
        (
            "--tags-path",
            "TAGS",
            "aten/src/ATen/native/tags.yaml",
            "path to tags.yaml",
        ),
        (
            "--deprecated-functions-path",
            "DEPRECATED",
            "tools/autograd/deprecated.yaml",
            "path to deprecated.yaml",
        ),
        ("--out", "OUT", ".", "path to output directory"),
        ("--template-dir", None, ".", "path to template directory"),
    )
    for flag, metavar, default, help_text in option_specs:
        extra = {} if metavar is None else {"metavar": metavar}
        arg_parser.add_argument(flag, default=default, help=help_text, **extra)

    options = arg_parser.parse_args()

    # dry_run=False: the FileManager actually writes the generated files.
    file_manager = FileManager(
        install_dir=options.out,
        template_dir=options.template_dir,
        dry_run=False,
    )
    gen_pyi(
        options.native_functions_path,
        options.tags_path,
        options.deprecated_functions_path,
        file_manager,
    )
# Run the stub generator only when this file is executed directly as a
# script, not when it is imported as a module.
if __name__ == "__main__":
    main()