"""
|
|
This module implements generation of type stubs for PyTorch,
|
|
enabling use of autocomplete in IDEs like PyCharm, which otherwise
|
|
don't understand C extension modules.
|
|
|
|
At the moment, this module only handles type stubs for torch and
|
|
torch.Tensor. It should eventually be expanded to cover all functions
|
|
which come are autogenerated.
|
|
|
|
Here's our general strategy:
|
|
|
|
- We start off with a hand-written __init__.pyi.in file. This
|
|
file contains type definitions for everything we cannot automatically
|
|
generate, including pure Python definitions directly in __init__.py
|
|
(the latter case should be pretty rare).
|
|
|
|
- We go through automatically bound functions based on the
|
|
type information recorded in native_functions.yaml and
|
|
generate type hints for them (generate_type_hints)
|
|
|
|
There are a number of type hints which we've special-cased;
|
|
read gen_pyi for the gory details.
|
|
"""
|
|
|
|
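# A typical invocation, from the repository root, looks roughly like this
# (a sketch: the paths shown are just the defaults declared in main() at the
# bottom of this file, and the module path assumes the usual
# tools/pyi/gen_pyi.py location):
#
#   python -m tools.pyi.gen_pyi \
#       --native-functions-path aten/src/ATen/native/native_functions.yaml \
#       --tags-path aten/src/ATen/native/tags.yaml \
#       --deprecated-functions-path tools/autograd/deprecated.yaml
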
from __future__ import annotations

import argparse
import collections
import importlib
import inspect
import sys
import textwrap
from typing import TYPE_CHECKING
from unittest.mock import Mock, patch
from warnings import warn

from tools.autograd.gen_python_functions import (
    group_overloads,
    load_signatures,
    should_generate_py_binding,
)
from torchgen.api.python import (
    format_function_signature as defs,
    PythonSignatureGroup,
    PythonSignatureNativeFunctionPair,
    returns_structseq_pyi,
)
from torchgen.gen import parse_native_yaml, parse_tags_yaml
from torchgen.model import _TorchDispatchModeKey, DispatchKey, Variant
from torchgen.utils import FileManager


if TYPE_CHECKING:
    from collections.abc import Sequence


def get_py_torch_functions(
    python_funcs: Sequence[PythonSignatureNativeFunctionPair],
    method: bool = False,
) -> Sequence[PythonSignatureGroup]:
    """
    Get declarations (grouped by name) which should be generated
    as either functions in the "torch" module or methods on Tensor.
    """

    def should_bind_function(python_func: PythonSignatureNativeFunctionPair) -> bool:
        return (
            should_generate_py_binding(python_func.function)
            and not python_func.function.python_module
            and Variant.function in python_func.function.variants
        )

    def should_bind_method(python_func: PythonSignatureNativeFunctionPair) -> bool:
        return (
            should_generate_py_binding(python_func.function)
            and not python_func.function.python_module
            and Variant.method in python_func.function.variants
        )

    should_bind = should_bind_method if method else should_bind_function
    return group_overloads([f for f in python_funcs if should_bind(f)])


# TODO: Consider defining some aliases for our Union[...] types, to make
# the stubs easier on the human eye.

DEVICE_PARAM = "device: DeviceLikeType | None = None"
FACTORY_PARAMS = [
    "dtype: _dtype | None = None",
    DEVICE_PARAM,
    "requires_grad: _bool = False",
    "pin_memory: _bool = False",
]

# NOTE: specifying indices for Tensor.__getitem__
# We can imitate numpy's definition of ndarray.__getitem__ found in numpy/__init__.pyi:
#
# key: (
#     slice
#     | EllipsisType
#     | None
#     | _ArrayLikeInt_co
#     | SupportsIndex
#     | tuple[slice | EllipsisType | None | _ArrayLikeInt_co | SupportsIndex, ...]
# )
#
# where:
#
# _ArrayLikeInt_co = _DualArrayLike[
#     dtype[bool_ | integer[Any]],
#     bool | int,
# ]
#
# and
#
# _DualArrayLike = (
#     _SupportsArray[_DType]
#     | _NestedSequence[_SupportsArray[_DType]]
#     | _T
#     | _NestedSequence[_T]
# )
#
# Moreover, _NestedSequence is a Protocol that matches arbitrary nesting of list/tuple.
# We can substitute and simplify:
#     _SupportsArray -> Tensor
#     _ArrayLikeInt_co -> [bool | int | Tensor | NestedSequence[bool | int] | NestedSequence[Tensor]]
# which leaves us with key: T | tuple[T, ...], where T is:
#     T = (
#         SupportsIndex | bool | int | slice | EllipsisType | None
#         | Tensor | _NestedSequence[Tensor] | _NestedSequence[bool | int]
#     )
_leaf_types = (
    "_bool | _int | slice | EllipsisType | Tensor | None"  # not SupportsIndex!
)
_index_types = f"SupportsIndex | {_leaf_types} | _NestedSequence[{_leaf_types}]"
_index_type_def = f"_Index: TypeAlias = {_index_types}  # fmt: skip"
INDICES = "indices: _Index | tuple[_Index, ...]"

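# Illustrative examples (a sketch, not an exhaustive list) of indexing
# expressions that the `_Index | tuple[_Index, ...]` annotation above is
# intended to accept, for a tensor `t` and a boolean tensor `mask`:
#
#   t[0]            # _int / SupportsIndex
#   t[1:3, ...]     # slice and EllipsisType
#   t[None]         # None inserts a new axis
#   t[mask]         # Tensor, e.g. a boolean mask
#   t[[0, 2], 1:]   # _NestedSequence[_int] combined with a slice
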
blocklist = [
    "__init_subclass__",
    "__new__",
    "__subclasshook__",
    "cdist",
    "device",
    "grad",
    "requires_grad",
    "range",
    # defined in functional
    "einsum",
    # Somehow, these are defined in both _C and in functional. Ick!
    "broadcast_tensors",
    # Manually define named tensor type stubs in __init__.pyi.in
    "align_tensors",
    "meshgrid",
    "cartesian_prod",
    "block_diag",
    "norm",
    "chain_matmul",
    "stft",
    "tensordot",
    "split",
    "unique_consecutive",
    "atleast_1d",
    "atleast_2d",
    "atleast_3d",
    # These are handled specially by python_arg_parser.cpp
    "add",
    "add_",
    "add_out",
    "sub",
    "sub_",
    "sub_out",
    "mul",
    "mul_",
    "mul_out",
    "div",
    "div_",
    "div_out",
    "true_divide",
    "true_divide_",
    "true_divide_out",
    "floor_divide",
    "floor_divide_",
    "floor_divide_out",
    "to",
    "_to_copy",
    "copy_",
]

shift_ops = (
    "lshift",
    "rshift",
    "ilshift",
    "irshift",  # inplace ops
)
arithmetic_ops = (
    "add",
    "sub",
    "mul",
    "div",
    "pow",
    "mod",
    "truediv",
    "matmul",
    "floordiv",
    "radd",
    "rsub",
    "rmul",
    "rtruediv",
    "rfloordiv",
    "rpow",  # reverse arithmetic
    "iadd",
    "idiv",
    "imul",
    "isub",
    "ifloordiv",
    "imod",  # inplace ops
)
logic_ops = (
    "and",
    "or",
    "xor",
    "rand",
    "ror",
    "rxor",  # reverse logic
    "iand",
    "ior",
    "ixor",  # inplace ops
)
binary_ops = shift_ops + arithmetic_ops + logic_ops

symmetric_comparison_ops = ("eq", "ne")
asymmetric_comparison_ops = ("ge", "gt", "lt", "le")
comparison_ops = symmetric_comparison_ops + asymmetric_comparison_ops

unary_ops = ("neg", "abs", "invert")
to_py_type_ops = ("bool", "float", "complex", "long", "index", "int", "nonzero")
all_ops = binary_ops + comparison_ops + unary_ops + to_py_type_ops


def sig_for_ops(opname: str) -> list[str]:
    """sig_for_ops(opname : str) -> list[str]

    Returns signatures for operator special functions (__add__ etc.)"""

    # we have to do this by hand, because they are hand-bound in Python
    assert opname.endswith("__") and opname.startswith("__"), f"Unexpected op {opname}"

    name = opname[2:-2]
    if name == "rpow":
        return [  # somehow required to make mypy ci happy?
            f"def {opname}(self, other: Tensor | Number | _complex) -> Tensor: ...  # type: ignore[has-type]"
        ]
    elif name in arithmetic_ops:
        if name.startswith("i"):
            # In-place binary-operation dunder methods, like `__iadd__`, should return `Self`
            return [
                f"def {opname}(self, other: Tensor | Number | _complex) -> Tensor: ...  # noqa: PYI034"
            ]
        return [f"def {opname}(self, other: Tensor | Number | _complex) -> Tensor: ..."]
    elif name in logic_ops:
        return [f"def {opname}(self, other: Tensor | _int) -> Tensor: ..."]
    elif name in shift_ops:
        return [f"def {opname}(self, other: Tensor | _int) -> Tensor: ..."]
    elif name in symmetric_comparison_ops:
        return [
            # unsafe override https://github.com/python/mypy/issues/5704
            f"def {opname}(self, other: Tensor | Number | _complex) -> Tensor: ...  # type: ignore[overload-overlap]",
            f"def {opname}(self, other: object) -> _bool: ...",
        ]
    elif name in asymmetric_comparison_ops:
        return [f"def {opname}(self, other: Tensor | Number | _complex) -> Tensor: ..."]
    elif name in unary_ops:
        return [f"def {opname}(self) -> Tensor: ..."]
    if name in to_py_type_ops:
        if name in {"bool", "float", "complex"}:
            tname = name
        elif name == "nonzero":
            tname = "bool"
        else:
            tname = "int"
        if tname in {"float", "int", "bool", "complex"}:
            tname = "_" + tname
        return [f"def {opname}(self) -> {tname}: ..."]
    raise ValueError(f"unknown op {opname!r}")


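# For example (a sketch; the exact strings come from sig_for_ops above):
#
#   sig_for_ops("__add__")
#   -> ["def __add__(self, other: Tensor | Number | _complex) -> Tensor: ..."]
#
# while "__eq__" yields both the Tensor overload and the
# `(self, other: object) -> _bool` fallback.

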
def generate_type_hints(sig_group: PythonSignatureGroup) -> list[str]:
    type_hints: list[str] = []

    # Some deprecated ops that are on the blocklist are still included in pyi
    if sig_group.signature.name in blocklist and not sig_group.signature.deprecated:
        return type_hints

    # deprecated signatures have separate entries for their functional and out variants
    # (as opposed to the native ops, which fuse the two into a single signature).
    # generate the functional variant here, if an out variant exists.
    if sig_group.signature.deprecated and sig_group.outplace is not None:
        type_hint = sig_group.signature.signature_str_pyi(skip_outputs=True)
        type_hints.append(type_hint)

    # PythonSignatureGroups that have both a functional + out variant get a single signature, with an optional out argument
    # Generate the out variant if one exists; otherwise, generate the functional variant
    type_hint = sig_group.signature.signature_str_pyi(
        skip_outputs=sig_group.outplace is None
    )
    type_hints.append(type_hint)

    # Some operators additionally have a vararg variant of their signature
    type_hint_vararg = sig_group.signature.signature_str_pyi_vararg(
        skip_outputs=sig_group.outplace is None
    )
    if type_hint_vararg:
        type_hints.append(type_hint_vararg)

    return type_hints


def get_max_pool_dispatch(name: str, arg_list: list[str]) -> dict[str, list[str]]:
    flag_pos = arg_list.index("{return_indices}")
    # If return_indices is a positional arg, everything before it should have no default
    arg_list_positional = (
        [
            ", ".join(single_arg.split(" = ")[0] for single_arg in arg.split(", "))
            for arg in arg_list[: flag_pos + 1]
        ]
        + ["/"]
        + arg_list[flag_pos + 1 :]
    )
    # Otherwise force return_indices to be a kwarg
    arg_list_keyword = arg_list.copy()
    arg_list_keyword.insert(flag_pos, "*")
    return {
        name: [
            defs(
                name,
                [
                    arg.format(return_indices="return_indices: Literal[False] = False")
                    for arg in arg_list
                ],
                "Tensor",
            ),
            defs(
                name,
                [
                    arg.format(return_indices="return_indices: Literal[True]")
                    for arg in arg_list_positional
                ],
                "tuple[Tensor, Tensor]",
            ),
            defs(
                name,
                [
                    arg.format(return_indices="return_indices: Literal[True]")
                    for arg in arg_list_keyword
                ],
                "tuple[Tensor, Tensor]",
            ),
        ]
    }


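# As a sketch, get_max_pool_dispatch("max_pool2d", [..., "{return_indices}"])
# above yields three overloads (modulo the exact formatting applied by `defs`):
#
#   def max_pool2d(..., return_indices: Literal[False] = False, ...) -> Tensor: ...
#   def max_pool2d(..., return_indices: Literal[True], /, ...) -> tuple[Tensor, Tensor]: ...
#   def max_pool2d(..., *, return_indices: Literal[True], ...) -> tuple[Tensor, Tensor]: ...

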
def gen_nn_functional(fm: FileManager) -> None:
    INPUT = "input: Tensor"
    KERNEL_SIZE = "kernel_size: _int | _size"
    STRIDE_PADDING = [
        "stride: _int | _size | None = None",
        "padding: _int | _size = 0",
    ]

    # TODO the list for `torch._C._nn` is nonexhaustive
    unsorted_c_nn_function_hints: dict[str, list[str]] = {}

    for d in (2, 3):
        unsorted_c_nn_function_hints.update(
            {
                f"avg_pool{d}d": [
                    defs(
                        f"avg_pool{d}d",
                        [
                            INPUT,
                            KERNEL_SIZE,
                            *STRIDE_PADDING,
                            "ceil_mode: bool = False",
                            "count_include_pad: bool = True",
                            "divisor_override: int | None = None",
                        ],
                        "Tensor",
                    )
                ],
                f"fractional_max_pool{d}d": [
                    defs(
                        f"fractional_max_pool{d}d",
                        [
                            INPUT,
                            KERNEL_SIZE,
                            "output_size: _int | _size",
                            "_random_samples: Tensor",
                        ],
                        "tuple[Tensor, Tensor]",
                    )
                ],
                f"adaptive_max_pool{d}d": [
                    defs(
                        f"adaptive_max_pool{d}d",
                        [
                            INPUT,
                            "output_size: _int | _size",
                        ],
                        "tuple[Tensor, Tensor]",
                    )
                ],
                f"adaptive_avg_pool{d}d": [
                    defs(
                        f"adaptive_avg_pool{d}d",
                        [
                            INPUT,
                            "output_size: _int | _size",
                        ],
                        "Tensor",
                    )
                ],
                f"max_pool{d}d_with_indices": [
                    defs(
                        f"max_pool{d}d_with_indices",
                        [
                            INPUT,
                            KERNEL_SIZE,
                            *STRIDE_PADDING,
                            "dilation: _int | _size = 1",
                            "ceil_mode: bool = False",
                        ],
                        "tuple[Tensor, Tensor]",
                    )
                ],
            }
        )

    unsorted_c_nn_function_hints.update(
        {
            "hardtanh": [
                defs(
                    "hardtanh",
                    [
                        "input: Tensor",
                        "min_val: float = ...",
                        "max_val: float = ...",
                        "*",
                        "out: Tensor | None = None",
                    ],
                    "Tensor",
                )
            ],
            "hardtanh_": [
                defs(
                    "hardtanh_",
                    ["input: Tensor", "min_val: float = ...", "max_val: float = ..."],
                    "Tensor",
                ),
            ],
            "elu_": [defs("elu_", ["input: Tensor", "alpha: float = ..."], "Tensor")],
            "leaky_relu": [
                defs(
                    "leaky_relu",
                    [
                        "input: Tensor",
                        "negative_slope: float = ...",
                        "*",
                        "out: Tensor | None = None",
                    ],
                    "Tensor",
                )
            ],
            "leaky_relu_": [
                defs(
                    "leaky_relu_",
                    ["input: Tensor", "negative_slope: float = ..."],
                    "Tensor",
                )
            ],
            "log_sigmoid": [defs("log_sigmoid", ["input: Tensor"], "Tensor")],
            "gelu": [
                defs("gelu", ["input: Tensor", "approximate: str = ..."], "Tensor")
            ],
            "softplus": [
                defs(
                    "softplus",
                    ["input: Tensor", "beta: float = ...", "threshold: float = ..."],
                    "Tensor",
                )
            ],
            "softshrink": [
                defs("softshrink", ["input: Tensor", "lambd: float = ..."], "Tensor")
            ],
            "hardsigmoid": [
                defs(
                    "hardsigmoid",
                    ["input: Tensor", "*", "out: Tensor | None = None"],
                    "Tensor",
                )
            ],
            "linear": [
                defs(
                    "linear",
                    ["input: Tensor", "weight: Tensor", "bias: Tensor | None = None"],
                    "Tensor",
                )
            ],
            "pad": [
                defs(
                    "pad",
                    [
                        "input: Tensor",
                        "pad: Sequence[int]",
                        "mode: str = ...",
                        "value: float | None = None",
                    ],
                    "Tensor",
                )
            ],
            "one_hot": [
                defs("one_hot", ["tensor: Tensor", "num_classes: int = ..."], "Tensor")
            ],
            "scaled_dot_product_attention": [
                defs(
                    "scaled_dot_product_attention",
                    [
                        "query: Tensor",
                        "key: Tensor",
                        "value: Tensor",
                        "attn_mask: Tensor | None = None",
                        "dropout_p: float = 0.0",
                        "is_causal: bool = False",
                        "scale: float | None = None",
                        "enable_gqa: bool = False",
                    ],
                    "Tensor",
                )
            ],
            "binary_cross_entropy": [
                defs(
                    "binary_cross_entropy",
                    [
                        INPUT,
                        "target: Tensor",
                        "weight: Tensor | None = None",
                        "reduction: str = ...",
                    ],
                    "Tensor",
                )
            ],
            "col2im": [
                defs(
                    "col2im",
                    [
                        INPUT,
                        "output_size: _int | _size",
                        KERNEL_SIZE,
                        "dilation: _int | _size",
                        *STRIDE_PADDING,
                    ],
                    "Tensor",
                )
            ],
            "elu": [
                defs(
                    "elu",
                    [
                        INPUT,
                        "alpha: float = 1.0",
                        "scale: float = 1.0",
                        "input_scale: float = 1.0",
                    ],
                    "Tensor",
                )
            ],
            "glu": [
                defs(
                    "glu",
                    [
                        INPUT,
                        "dim: int = -1",
                    ],
                    "Tensor",
                )
            ],
            "max_unpool2d": [
                defs(
                    "max_unpool2d",
                    [
                        INPUT,
                        "indices: Tensor",
                        "output_size: Sequence[int] | None",
                    ],
                    "Tensor",
                )
            ],
            "max_unpool3d": [
                defs(
                    "max_unpool3d",
                    [
                        INPUT,
                        "indices: Tensor",
                        "output_size: Sequence[int] | None",
                        "stride: _int | _size",
                        "padding: _int | _size",
                    ],
                    "Tensor",
                )
            ],
        }
    )

    c_nn_function_hints: list[str] = []
    for _, hints in sorted(unsorted_c_nn_function_hints.items()):
        if len(hints) > 1:
            hints = ["@overload\n" + h for h in hints]
        c_nn_function_hints += hints

    extra_nn_functional___all__: list[str] = []

    # Functions imported into `torch.nn.functional` from `torch`, perhaps being filtered
    # through an `_add_docstr` call
    torch_imports = [
        "adaptive_avg_pool1d",
        "avg_pool1d",
        "bilinear",
        "celu_",
        "channel_shuffle",
        "conv_tbc",
        "conv_transpose1d",
        "conv_transpose2d",
        "conv_transpose3d",
        "conv1d",
        "conv2d",
        "conv3d",
        "cosine_similarity",
        "hardshrink",
        "native_channel_shuffle",
        "pairwise_distance",
        "pdist",
        "pixel_shuffle",
        "pixel_unshuffle",
        "prelu",
        "relu_",
        "rrelu_",
        "selu_",
    ]
    imported_hints = [
        "from torch import (",
        *sorted(f"    {name} as {name}," for name in torch_imports),
        ")",
    ]
    extra_nn_functional___all__.extend(torch_imports)

    # Functions imported into `torch.nn.functional` from `torch._C._nn`
    c_nn_imports = [
        "avg_pool2d",
        "avg_pool3d",
        "elu_",
        "gelu",
        "hardtanh_",
        "leaky_relu_",
        "linear",
        "log_sigmoid",
        "one_hot",
        "pad",
        "scaled_dot_product_attention",
        "softplus",
        "softshrink",
    ]
    renamed = {"log_sigmoid": "logsigmoid"}
    imported_hints += [
        "from torch._C._nn import (",
        *sorted(f"    {name} as {renamed.get(name, name)}," for name in c_nn_imports),
        ")",
    ]
    extra_nn_functional___all__.extend(renamed.get(name, name) for name in c_nn_imports)

    # Functions generated by `torch._jit_internal.boolean_dispatch` in `nn.functional`
    unsorted_dispatched_hints: dict[str, list[str]] = {}

    for d in (1, 2, 3):
        unsorted_dispatched_hints.update(
            **get_max_pool_dispatch(
                f"max_pool{d}d",
                [
                    INPUT,
                    KERNEL_SIZE,
                    *STRIDE_PADDING,
                    "dilation: _int | _size = 1",
                    "ceil_mode: bool = False",
                    "{return_indices}",
                ],
            ),
            **get_max_pool_dispatch(
                f"fractional_max_pool{d}d",
                [
                    INPUT,
                    KERNEL_SIZE,
                    "output_size: _int | _size | None = None",
                    "output_ratio: _ratio_any_t | None = None",
                    "{return_indices}",
                    "_random_samples: Tensor | None = None",
                ],
            ),
            **get_max_pool_dispatch(
                f"adaptive_max_pool{d}d",
                [
                    INPUT,
                    "output_size: _int | _size",
                    "{return_indices}",
                ],
            ),
        )

    # There's no fractional_max_pool1d
    del unsorted_dispatched_hints["fractional_max_pool1d"]
    extra_nn_functional___all__.extend(unsorted_dispatched_hints)

    dispatched_hints: list[str] = []
    for _, hints in sorted(unsorted_dispatched_hints.items()):
        if len(hints) > 1:
            hints = ["@overload\n" + h for h in hints]
        dispatched_hints += hints

    extra_nn_functional___all__ = [
        "__all__ += [",
        *(f'    "{name}",' for name in extra_nn_functional___all__),
        "]",
    ]

    fm.write_with_template(
        "torch/nn/functional.pyi",
        "torch/nn/functional.pyi.in",
        lambda: {
            "imported_hints": imported_hints,
            "dispatched_hints": dispatched_hints,
            "extra_nn_functional___all__": extra_nn_functional___all__,
        },
    )
    fm.write_with_template(
        "torch/_C/_nn.pyi",
        "torch/_C/_nn.pyi.in",
        lambda: {
            "c_nn_function_hints": c_nn_function_hints,
        },
    )


"""
|
|
We gather the docstrings for torch with the following steps:
|
|
1. Mock torch and torch._C, which are the only dependencies of the docs files
|
|
2. Mock the _add_docstr function to save the docstrings
|
|
3. Import the docs files to trigger mocked _add_docstr and collect docstrings
|
|
"""
|
|
|
|
|
|
def gather_docstrs() -> dict[str, str]:
    docstrs = {}

    def mock_add_docstr(func: Mock, docstr: str) -> None:
        docstrs[func._extract_mock_name()] = docstr.strip()

    # sys.modules and sys.path are restored after the context manager exits
    with patch.dict(sys.modules), patch.object(sys, "path", sys.path + ["torch"]):
        # mock the torch module and torch._C._add_docstr
        sys.modules["torch"] = Mock(name="torch")
        sys.modules["torch._C"] = Mock(_add_docstr=mock_add_docstr)

        try:
            # manually import torch._torch_docs and torch._tensor_docs to trigger
            # the mocked _add_docstr and collect docstrings
            sys.modules["torch._torch_docs"] = importlib.import_module("_torch_docs")
            sys.modules["torch._tensor_docs"] = importlib.import_module("_tensor_docs")
        except ModuleNotFoundError:
            # Gracefully fail if these modules are not importable
            warn(
                "Failed to import _torch_docs/_tensor_docs, skipping docstrings in pyi files."
            )

    return docstrs


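# The resulting mapping is keyed by the mock attribute path that was passed to
# `_add_docstr`, e.g. (a sketch, not literal contents):
#
#   {"torch.abs": "abs(input) -> Tensor ...",
#    "torch._C.TensorBase.abs": "..."}
#
# The `docstrs.get(f"torch.{name}")` and `docstrs.get(f"torch._C.TensorBase.{name}")`
# lookups in gen_pyi below rely on exactly these key shapes.

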
def add_docstr_to_hint(docstr: str, hint: str) -> str:
    docstr = inspect.cleandoc(docstr).strip()
    if "..." in hint:  # function or method
        assert hint.endswith("..."), f"Hint `{hint}` does not end with '...'"
        hint = hint.removesuffix("...").rstrip()  # remove "..."
        content = hint + "\n" + textwrap.indent(f'r"""\n{docstr}\n"""', prefix="    ")
        # Remove trailing whitespace on each line
        return "\n".join(map(str.rstrip, content.splitlines())).rstrip()

    # attribute or property
    return f'{hint}\nr"""{docstr}"""'


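# For instance (a sketch of the transformation above):
#
#   add_docstr_to_hint("Returns |x|.", "def abs(self) -> Tensor: ...")
#
# produces a stub body with the docstring inlined:
#
#   def abs(self) -> Tensor:
#       r"""
#       Returns |x|.
#       """

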
def gen_pyi(
    native_yaml_path: str,
    tags_yaml_path: str,
    deprecated_yaml_path: str,
    fm: FileManager,
) -> None:
    """gen_pyi()

    This function generates a pyi file for torch.
    """

    # Some of this logic overlaps with generate_python_signature in
    # tools/autograd/gen_python_functions.py; however, this
    # function is all about generating mypy type signatures, whereas
    # the other function generates a custom format for argument
    # checking. If you update this, consider whether your change
    # also needs to be made to the other file.

    # Dictionary for NamedTuple definitions
    structseqs: dict[str, str] = {}

    # Generate type signatures for top-level functions
    # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

    unsorted_function_hints: dict[str, list[str]] = collections.defaultdict(list)

    for n, n1, n2 in [
        ("csr", "crow", "col"),
        ("csc", "ccol", "row"),
        ("bsr", "crow", "col"),
        ("bsc", "ccol", "row"),
    ]:
        unsorted_function_hints.update(
            {
                f"sparse_{n}_tensor": [
                    defs(
                        f"sparse_{n}_tensor",
                        [
                            f"{n1}_indices: Tensor | list",
                            f"{n2}_indices: Tensor | list",
                            "values: Tensor | list",
                            "size: _size | None = None",
                            "*",
                            "dtype: _dtype | None = None",
                            "device: DeviceLikeType | None = None",
                            "requires_grad: _bool = False",
                            "check_invariants: _bool | None = None",
                        ],
                        "Tensor",
                    )
                ],
            }
        )

    unsorted_function_hints.update(
        {
            "set_flush_denormal": [
                defs("set_flush_denormal", ["mode: _bool"], "_bool")
            ],
            "get_default_dtype": [defs("get_default_dtype", [], "_dtype")],
            "asarray": [
                defs(
                    "asarray",
                    [
                        "obj: Any",
                        "*",
                        "dtype: _dtype | None = None",
                        "device: DeviceLikeType | None = None",
                        "copy: _bool | None = None",
                        "requires_grad: _bool = False",
                    ],
                    "Tensor",
                )
            ],
            "from_numpy": [defs("from_numpy", ["ndarray"], "Tensor")],
            "frombuffer": [
                defs(
                    "frombuffer",
                    [
                        "buffer: Any",
                        "*",
                        "dtype: _dtype",
                        "count: int = -1",
                        "offset: int = 0",
                        "requires_grad: _bool = False",
                    ],
                    "Tensor",
                )
            ],
            "numel": [defs("numel", ["self: Tensor"], "_int")],
            "as_tensor": [
                defs(
                    "as_tensor",
                    ["data: Any", "dtype: _dtype | None = None", DEVICE_PARAM],
                    "Tensor",
                )
            ],
            "get_num_threads": [defs("get_num_threads", [], "_int")],
            "set_num_threads": [defs("set_num_threads", ["num: _int"], "None")],
            "init_num_threads": [defs("init_num_threads", [], "None")],
            "get_num_interop_threads": [defs("get_num_interop_threads", [], "_int")],
            "set_num_interop_threads": [
                defs("set_num_interop_threads", ["num: _int"], "None")
            ],
            # These functions are explicitly disabled by
            # SKIP_PYTHON_BINDINGS because they are hand bound.
            # Correspondingly, we must hand-write their signatures.
            "tensor": [defs("tensor", ["data: Any", *FACTORY_PARAMS], "Tensor")],
            "sparse_coo_tensor": [
                defs(
                    "sparse_coo_tensor",
                    [
                        "indices: Tensor",
                        "values: Tensor | list",
                        "size: _size | None = None",
                        "*",
                        "dtype: _dtype | None = None",
                        "device: DeviceLikeType | None = None",
                        "requires_grad: _bool = False",
                        "check_invariants: _bool | None = None",
                        "is_coalesced: _bool | None = None",
                    ],
                    "Tensor",
                )
            ],
            "sparse_compressed_tensor": [
                defs(
                    "sparse_compressed_tensor",
                    [
                        "compressed_indices: Tensor | list",
                        "plain_indices: Tensor | list",
                        "values: Tensor | list",
                        "size: _size | None = None",
                        "*",
                        "dtype: _dtype | None = None",
                        "layout: _layout | None = None",
                        "device: DeviceLikeType | None = None",
                        "requires_grad: _bool = False",
                        "check_invariants: _bool | None = None",
                    ],
                    "Tensor",
                )
            ],
            "_sync": [defs("_sync", ["t: Tensor"], "None")],
            "_is_functional_tensor": [
                defs("_is_functional_tensor", ["t: Tensor"], "_bool")
            ],
            "_is_functional_tensor_base": [
                "def _is_functional_tensor_base(t: Tensor) -> _bool: ..."
            ],
            "_from_functional_tensor": [
                defs("_from_functional_tensor", ["t: Tensor"], "Tensor")
            ],
            "_to_functional_tensor": [
                defs("_to_functional_tensor", ["t: Tensor"], "Tensor")
            ],
            "_functionalize_replace": [
                defs(
                    "_functionalize_replace", ["self_: Tensor", "other: Tensor"], "None"
                )
            ],
            "_functionalize_commit_update": [
                defs("_functionalize_commit_update", ["t: Tensor"], "None")
            ],
            "_functionalize_unsafe_set": [
                "def _functionalize_unsafe_set(dst: Tensor, src: Tensor) -> None: ..."
            ],
            "_functionalize_mark_mutation_hidden_from_autograd": [
                defs(
                    "_functionalize_mark_mutation_hidden_from_autograd",
                    ["t: Tensor"],
                    "None",
                )
            ],
            "_functionalize_mutation_counter": [
                defs(
                    "_functionalize_mutation_counter",
                    ["t: Tensor"],
                    "_int",
                )
            ],
            "_functionalize_storage_changed_counter": [
                defs(
                    "_functionalize_storage_changed_counter",
                    ["t: Tensor"],
                    "_int",
                )
            ],
            "_functionalize_inductor_storage_resized_counter": [
                defs(
                    "_functionalize_inductor_storage_resized_counter",
                    ["t: Tensor"],
                    "_int",
                )
            ],
            "_functionalize_are_all_mutations_hidden_from_autograd": [
                defs(
                    "_functionalize_are_all_mutations_hidden_from_autograd",
                    ["t: Tensor"],
                    "_bool",
                )
            ],
            "_functionalize_are_all_mutations_under_no_grad_or_inference_mode": [
                defs(
                    "_functionalize_are_all_mutations_under_no_grad_or_inference_mode",
                    ["t: Tensor"],
                    "_bool",
                )
            ],
            "_functionalize_was_inductor_storage_resized": [
                defs(
                    "_functionalize_was_inductor_storage_resized",
                    ["t: Tensor"],
                    "_bool",
                )
            ],
            "_functionalize_sync": [defs("_functionalize_sync", ["t: Tensor"], "None")],
            "_functionalize_was_storage_changed": [
                defs("_functionalize_was_storage_changed", ["tensor: Tensor"], "_bool")
            ],
            "_functionalize_mark_storage_changed": [
                "def _functionalize_mark_storage_changed(tensor: Tensor) -> _bool: ..."
            ],
            "_functionalize_has_metadata_mutation": [
                defs(
                    "_functionalize_has_metadata_mutation", ["tensor: Tensor"], "_bool"
                )
            ],
            "_functionalize_apply_view_metas": [
                defs(
                    "_functionalize_apply_view_metas",
                    ["tensor: Tensor", "base: Tensor"],
                    "Tensor",
                )
            ],
            "_functionalize_is_symbolic": [
                defs("_functionalize_is_symbolic", ["tensor: Tensor"], "_bool")
            ],
            "_enable_functionalization": [
                defs(
                    "_enable_functionalization",
                    ["*", "reapply_views: _bool = False"],
                    "None",
                )
            ],
            "_disable_functionalization": [defs("_disable_functionalization")],
            "range": [
                defs(
                    "range",
                    [
                        "start: Number",
                        "end: Number",
                        "step: Number = 1",
                        "*",
                        "out: Tensor | None = None",
                        *FACTORY_PARAMS,
                    ],
                    "Tensor",
                )
            ],
            "arange": [
                defs(
                    "arange",
                    [
                        "start: Number",
                        "end: Number",
                        "step: Number",
                        "*",
                        "out: Tensor | None = None",
                        *FACTORY_PARAMS,
                    ],
                    "Tensor",
                ),
                defs(
                    "arange",
                    [
                        "start: Number",
                        "end: Number",
                        "*",
                        "out: Tensor | None = None",
                        *FACTORY_PARAMS,
                    ],
                    "Tensor",
                ),
                defs(
                    "arange",
                    ["end: Number", "*", "out: Tensor | None = None", *FACTORY_PARAMS],
                    "Tensor",
                ),
            ],
            "linspace": [
                defs(
                    "linspace",
                    [
                        "start: Number",
                        "end: Number",
                        "steps: _int | None = None",
                        "*",
                        "out: Tensor | None = None",
                        *FACTORY_PARAMS,
                    ],
                    "Tensor",
                )
            ],
            "logspace": [
                defs(
                    "logspace",
                    [
                        "start: Number",
                        "end: Number",
                        "steps: _int | None = None",
                        "base: _float = 10.0",
                        "*",
                        "out: Tensor | None = None",
                        *FACTORY_PARAMS,
                    ],
                    "Tensor",
                )
            ],
            "randint": [
                defs(
                    "randint",
                    [
                        "low: _int",
                        "high: _int",
                        "size: _size",
                        "*",
                        "generator: Generator | None = None",
                        *FACTORY_PARAMS,
                    ],
                    "Tensor",
                ),
                defs(
                    "randint",
                    [
                        "high: _int",
                        "size: _size",
                        "*",
                        "generator: Generator | None = None",
                        *FACTORY_PARAMS,
                    ],
                    "Tensor",
                ),
            ],
            "full": [
                defs(
                    "full",
                    [
                        "size: _size",
                        "fill_value: Number | _complex",
                        "*",
                        "out: Tensor | None = None",
                        "layout: _layout = strided",
                        *FACTORY_PARAMS,
                    ],
                    "Tensor",
                ),
                defs(
                    "full",
                    [
                        "size: _size",
                        "fill_value: Number | _complex",
                        "*",
                        "names: list[str | None]",
                        "layout: _layout = strided",
                        *FACTORY_PARAMS,
                    ],
                    "Tensor",
                ),
            ],
            "is_grad_enabled": [defs("is_grad_enabled", [], "_bool")],
            "is_inference_mode_enabled": [
                defs("is_inference_mode_enabled", [], "_bool")
            ],
            "nonzero": [
                defs(
                    "nonzero",
                    [
                        "input: Tensor",
                        "*",
                        "as_tuple: Literal[False] = False",
                        "out: Tensor | None = None",
                    ],
                    "Tensor",
                ),
                defs(
                    "nonzero",
                    ["input: Tensor", "*", "as_tuple: Literal[True]"],
                    "tuple[Tensor, ...]",
                ),
            ],
            "dsmm": [defs("dsmm", ["input: Tensor", "mat2: Tensor"], "Tensor")],
            "hsmm": [defs("hsmm", ["input: Tensor", "mat2: Tensor"], "Tensor")],
            "saddmm": [
                defs(
                    "saddmm",
                    [
                        "input: Tensor",
                        "mat1: Tensor",
                        "mat2: Tensor",
                        "*",
                        "beta: Number = 1",
                        "alpha: Number = 1",
                        "out: Tensor | None = None",
                    ],
                    "Tensor",
                )
            ],
            "spmm": [defs("spmm", ["input: Tensor", "mat2: Tensor"], "Tensor")],
            "div": [
                defs(
                    "div",
                    [
                        "input: Tensor | Number",
                        "other: Tensor | Number",
                        "*",
                        "rounding_mode: str | None = None",
                        "out: Tensor | None = None",
                    ],
                    "Tensor",
                )
            ],
        }
    )
    for binop in ["true_divide", "floor_divide"]:
        unsorted_function_hints[binop].append(
            defs(
                binop,
                [
                    "input: Tensor | Number",
                    "other: Tensor | Number",
                    "*",
                    "out: Tensor | None = None",
                ],
                "Tensor",
            )
        )
    for binop in ["mul"]:
        unsorted_function_hints[binop].append(
            defs(
                binop,
                [
                    "input: Tensor | Number | _complex",
                    "other: Tensor | Number | _complex",
                    "*",
                    "out: Tensor | None = None",
                ],
                "Tensor",
            )
        )
    for binop in ["add", "sub"]:
        unsorted_function_hints[binop].append(
            defs(
                binop,
                [
                    "input: Tensor | Number | _complex",
                    "other: Tensor | Number | _complex",
                    "*",
                    "alpha: Number | _complex | None = 1",
                    "out: Tensor | None = None",
                ],
                "Tensor",
            )
        )

    native_functions = parse_native_yaml(
        native_yaml_path, tags_yaml_path
    ).native_functions
    native_functions = list(filter(should_generate_py_binding, native_functions))

    function_signatures = load_signatures(
        native_functions, deprecated_yaml_path, method=False, pyi=True
    )
    sig_groups = get_py_torch_functions(function_signatures)
    for group in sorted(sig_groups, key=lambda g: g.signature.name):
        name = group.signature.name
        unsorted_function_hints[name] += generate_type_hints(group)

        structseq = returns_structseq_pyi(group.signature)
        if structseq is not None and not group.signature.deprecated:
            # deprecated structseqs are currently not included for torch functions
            tuple_name, tuple_def = structseq
            if tuple_name in structseqs:
                assert structseqs[tuple_name] == tuple_def
            else:
                structseqs[tuple_name] = tuple_def

    def replace_special_case(hint: str) -> str:
        # NB: Keep this in sync with enum in aten/src/ATen/core/Reduction.h
        hint = hint.replace("at::Reduction::Mean", "1")
        hint = hint.replace(": Tensor = None", ": Tensor | None = None")
        return hint

    docstrs = gather_docstrs()
    function_hints = []
    for name, hints in sorted(unsorted_function_hints.items()):
        hints = [replace_special_case(h) for h in hints]
        if len(hints) > 1:
            hints = ["@overload\n" + h for h in hints]
        docstr = docstrs.get(f"torch.{name}")
        if docstr is not None:
            hints = [add_docstr_to_hint(docstr, h) for h in hints]
        function_hints += hints

    # Generate type signatures for Tensor methods
    # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

    index_type_def = [_index_type_def]
    unsorted_tensor_method_hints: dict[str, list[str]] = collections.defaultdict(list)
    unsorted_tensor_method_hints.update(
        {
            "size": [
                defs("size", ["self", "dim: None = None"], "Size"),
                defs("size", ["self", "dim: _int"], "_int"),
            ],
            "stride": [
                defs("stride", ["self", "dim: None = None"], "tuple[_int, ...]"),
                defs("stride", ["self", "dim: _int"], "_int"),
            ],
            "new_ones": [
                defs("new_ones", ["self", "size: _size", *FACTORY_PARAMS], "Tensor")
            ],
            "new_tensor": [
                defs("new_tensor", ["self", "data: Any", *FACTORY_PARAMS], "Tensor")
            ],
            "__new__": [defs("__new__", ["cls", "*args", "**kwargs"], "Self")],
            # new and __init__ have the same signatures and differ only in return type
            # Adapted from legacy_tensor_ctor and legacy_tensor_new
            "new": [
                defs("new", ["cls", "*args: Any", DEVICE_PARAM], "Self"),
                defs("new", ["cls", "storage: Storage"], "Self"),
                defs("new", ["cls", "other: Tensor"], "Self"),
                defs("new", ["cls", "size: _size", "*", DEVICE_PARAM], "Self"),
            ],
            "__init__": [
                defs("__init__", ["self", "*args: Any", DEVICE_PARAM], "None"),
                defs("__init__", ["self", "storage: Storage"], "None"),
                defs("__init__", ["self", "other: Tensor"], "None"),
                defs("__init__", ["self", "size: _size", "*", DEVICE_PARAM], "None"),
            ],
            "as_subclass": [defs("as_subclass", ["self", "cls: type[S]"], "S")],
            "_make_subclass": [
                "@staticmethod\n"
                + defs(
                    "_make_subclass",
                    [
                        "cls: type[S]",
                        "data: Tensor",
                        "require_grad: _bool = False",
                        "dispatch_strides: _bool = False",
                        "dispatch_device: _bool = False",
                        "device_for_backend_keys: _device | None = None",
                    ],
                    "S",
                )
            ],
            "_make_wrapper_subclass": [
                "@staticmethod\n"
                + defs(
                    "_make_wrapper_subclass",
                    [
                        "cls: type[S]",
                        "size: Sequence[_int | SymInt]",
                        "strides: Sequence[_int | SymInt] | None = None",
                        "storage_offset: _int | SymInt | None = None",
                        "memory_format: torch.memory_format | None = None",
                        "dtype: _dtype | None = None",
                        "layout: _layout = strided",
                        "device: _device | None = None",
                        "pin_memory: _bool = False",
                        "requires_grad: _bool = False",
                        "dispatch_sizes_strides_policy: str | None = None",
                        "dispatch_device: _bool = False",
                        "dispatch_layout: _bool = False",
                        "_extra_dispatch_keys: torch.DispatchKeySet | None = None",
                        "storage_size: _int | SymInt | None = None",
                    ],
                    "S",
                )
            ],
            "__contains__": [defs("__contains__", ["self", "item: Any", "/"], "_bool")],
            "__getitem__": [defs("__getitem__", ["self", INDICES, "/"], "Tensor")],
            "__setitem__": [
                defs(
                    "__setitem__",
                    ["self", INDICES, "value: Tensor | Number", "/"],
                    "None",
                )
            ],
            "tolist": [defs("tolist", ["self"], "list")],
            "requires_grad_": [
                defs("requires_grad_", ["self", "mode: _bool = True"], "Tensor")
            ],
            "element_size": [defs("element_size", ["self"], "_int")],
            "data_ptr": [defs("data_ptr", ["self"], "_int")],
            "dim": [defs("dim", ["self"], "_int")],
            "nonzero": [
                defs(
                    "nonzero",
                    ["self", "*", "as_tuple: Literal[False] = False"],
                    "Tensor",
                ),
                defs(
                    "nonzero",
                    ["self", "*", "as_tuple: Literal[True]"],
                    "tuple[Tensor, ...]",
                ),
            ],
            "numel": [defs("numel", ["self"], "_int")],
            "ndimension": [defs("ndimension", ["self"], "_int")],
            "nelement": [defs("nelement", ["self"], "_int")],
            "cuda": [
                defs(
                    "cuda",
                    [
                        "self",
                        "device: _device | _int | str | None = None",
                        "non_blocking: _bool = False",
                        "memory_format: torch.memory_format = torch.preserve_format",
                    ],
                    "Tensor",
                )
            ],
            "xpu": [
                defs(
                    "xpu",
                    [
                        "self",
                        "device: _device | _int | str | None = None",
                        "non_blocking: _bool = False",
                        "memory_format: torch.memory_format = torch.preserve_format",
                    ],
                    "Tensor",
                )
            ],
            "cpu": [
                defs(
                    "cpu",
                    [
                        "self",
                        "memory_format: torch.memory_format = torch.preserve_format",
                    ],
                    "Tensor",
                )
            ],
            "numpy": [
                defs("numpy", ["self", "*", "force: _bool = False"], "numpy.ndarray")
            ],
            "apply_": [defs("apply_", ["self", "callable: Callable"], "Tensor")],
            "map_": [
                defs("map_", ["self", "other: Tensor", "callable: Callable"], "Tensor")
            ],
            "map2_": [
                defs(
                    "map2_",
                    ["self", "x: Tensor", "y: Tensor", "callable: Callable"],
                    "Tensor",
                )
            ],
            "storage": [defs("untyped_storage", ["self"], "UntypedStorage")],
            "storage_type": [defs("storage_type", ["self"], "Storage")],
            "type": [
                defs(
                    "type",
                    ["self", "dtype: None = None", "non_blocking: _bool = False"],
                    "str",
                ),
                defs(
                    "type",
                    ["self", "dtype: str | _dtype", "non_blocking: _bool = False"],
                    "Tensor",
                ),
            ],
            "get_device": [defs("get_device", ["self"], "_int")],
            "contiguous": [
                defs(
                    "contiguous",
                    [
                        "self",
                        "memory_format: torch.memory_format = torch.contiguous_format",
                    ],
                    "Tensor",
                )
            ],
            "has_names": [defs("has_names", ["self"], "_bool")],
            "is_contiguous": [
                defs(
                    "is_contiguous",
                    [
                        "self",
                        "memory_format: torch.memory_format = torch.contiguous_format",
                    ],
                    "_bool",
                )
            ],
            "_is_view": [defs("_is_view", ["self"], "_bool")],
            "is_cpu": ["is_cpu: _bool"],
            "is_cuda": ["is_cuda: _bool"],
            "is_xpu": ["is_xpu: _bool"],
            "is_leaf": ["is_leaf: _bool"],
            "is_nested": ["is_nested: _bool"],
            "is_sparse": ["is_sparse: _bool"],
            "is_sparse_csr": ["is_sparse_csr: _bool"],
            "is_quantized": ["is_quantized: _bool"],
            "is_meta": ["is_meta: _bool"],
            "is_mps": ["is_mps: _bool"],
            "is_mtia": ["is_mtia: _bool"],
            "is_maia": ["is_maia: _bool"],
            "is_mkldnn": ["is_mkldnn: _bool"],
            "is_vulkan": ["is_vulkan: _bool"],
            "is_ipu": ["is_ipu: _bool"],
            "storage_offset": [defs("storage_offset", ["self"], "_int | SymInt")],
            "to": [
                (
                    defs(
                        "to",
                        [
                            "self",
                            *to_args,
                            "non_blocking: _bool = False",
                            "copy: _bool = False",
                            "*",
                            "memory_format: torch.memory_format | None = None",
                        ],
                        "Tensor",
                    )
                )
                for to_args in [
                    ["dtype: _dtype"],
                    [
                        "device: DeviceLikeType | None = None",
                        "dtype: _dtype | None = None",
                    ],
                    ["other: Tensor"],
                ]
            ],
            "item": [defs("item", ["self"], "Number")],
            "copy_": [
                defs(
                    "copy_",
                    ["self", "other: Tensor", "non_blocking: _bool = False"],
                    "Tensor",
                )
            ],
            "set_": [
                defs(
                    "set_",
                    [
                        "self",
                        "source: Storage | TypedStorage | UntypedStorage",
                        "storage_offset: IntLikeType",
                        "size: _symsize",
                        "stride: _symsize",
                    ],
                    "Tensor",
                ),
                defs(
                    "set_",
                    ["self", "source: Storage | TypedStorage | UntypedStorage"],
                    "Tensor",
                ),
            ],
            "split": [
                defs(
                    "split",
                    ["self", "split_size: _int", "dim: _int = 0"],
                    "Sequence[Tensor]",
                ),
                defs(
                    "split",
                    ["self", "split_size: tuple[_int, ...]", "dim: _int = 0"],
                    "Sequence[Tensor]",
                ),
            ],
            "div": [
                defs(
                    "div",
                    [
                        "self",
                        "other: Tensor | Number",
                        "*",
                        "rounding_mode: str | None = None",
                    ],
                    "Tensor",
                )
            ],
            "div_": [
                defs(
                    "div_",
                    [
                        "self",
                        "other: Tensor | Number",
                        "*",
                        "rounding_mode: str | None = None",
                    ],
                    "Tensor",
                )
            ],
        }
    )
    for binop in ["true_divide", "floor_divide"]:
        for inplace in [False, True]:
            out_args = ["*", "out: Tensor | None = None"]
            if inplace:
                binop += "_"
                out_args = []
            unsorted_tensor_method_hints[binop].append(
                defs(
                    binop,
                    [
                        "self",
                        "other: Tensor | Number | torch.SymInt | torch.SymFloat",
                        *out_args,
                    ],
                    "Tensor",
                )
            )
    for binop in ["mul"]:
        for inplace in [False, True]:
            out_args = ["*", "out: Tensor | None = None"]
            if inplace:
                binop += "_"
                out_args = []
            unsorted_tensor_method_hints[binop].append(
                defs(
                    binop,
                    [
                        "self",
                        "other: Tensor | Number | _complex | torch.SymInt | torch.SymFloat",
                        *out_args,
                    ],
                    "Tensor",
                )
            )
    for binop in ["add", "sub"]:
        for inplace in [False, True]:
            out_args = ["out: Tensor | None = None"]
            if inplace:
                binop += "_"
                out_args = []
            unsorted_tensor_method_hints[binop].append(
                defs(
                    binop,
                    [
                        "self",
                        "other: Tensor | Number | _complex | torch.SymInt | torch.SymFloat",
                        "*",
                        "alpha: Number | _complex | None = 1",
                        *out_args,
                    ],
                    "Tensor",
                )
            )
    simple_conversions = [
        "bfloat16",
        "bool",
        "byte",
        "char",
        "double",
        "float",
        "half",
        "int",
        "long",
        "short",
    ]
    for name in simple_conversions:
        unsorted_tensor_method_hints[name].append(f"def {name}(self) -> Tensor: ...")

    # pyi tensor methods don't currently include deprecated signatures for some reason
    # TODO: we should probably add them in
    tensor_method_signatures = load_signatures(
        native_functions,
        deprecated_yaml_path,
        method=True,
        skip_deprecated=True,
        pyi=True,
    )
    tensor_method_sig_groups = get_py_torch_functions(
        tensor_method_signatures, method=True
    )

    for group in sorted(tensor_method_sig_groups, key=lambda g: g.signature.name):
        name = group.signature.name
        unsorted_tensor_method_hints[name] += generate_type_hints(group)

        structseq = returns_structseq_pyi(group.signature)
        if structseq is not None and not group.signature.deprecated:
            # deprecated structseqs are currently not included for torch functions
            tuple_name, tuple_def = structseq
            if tuple_name in structseqs:
                assert structseqs[tuple_name] == tuple_def
            else:
                structseqs[tuple_name] = tuple_def

    for op in all_ops:
        name = f"__{op}__"
        unsorted_tensor_method_hints[name] += sig_for_ops(name)

    tensor_method_hints = []
    for name, hints in sorted(unsorted_tensor_method_hints.items()):
        if len(hints) > 1:
            hints = ["@overload\n" + h for h in hints]
        docstr = docstrs.get(f"torch._C.TensorBase.{name}")
        if docstr is not None:
            hints = [add_docstr_to_hint(docstr, h) for h in hints]
        tensor_method_hints += hints

    # TODO: Missing type hints for nn

    # Generate structseq definitions
    # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

    structseqs = dict(sorted(structseqs.items()))
    structseq_defs = [f"{defn}\n" for defn in structseqs.values()]
    return_types___all__ = [
        "__all__ = [",
        '    "pytree_register_structseq",',
        '    "all_return_types",',
        *(f'    "{name}",' for name in structseqs),
        "]",
    ]

    # Generate type signatures for legacy classes
    # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

    legacy_storage_base_hints = ["class StorageBase: ..."]

    legacy_class_hints = []
    for c in (
        "DoubleTensor",
        "FloatTensor",
        "BFloat16Tensor",
        "LongTensor",
        "IntTensor",
        "ShortTensor",
        "HalfTensor",
        "CharTensor",
        "ByteTensor",
        "BoolTensor",
    ):
        legacy_class_hints.append(f"class {c}(Tensor): ...")

    # Generate type signatures for dtype classes
    # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

    # TODO(#146647): don't explicitly list dtypes here; get it from canonical
    # source
    dtype_class_hints = [
        f"{n}: dtype = ..."
        for n in [
            "float32",
            "float",
            "float64",
            "double",
            "float16",
            "bfloat16",
            "float8_e4m3fn",
            "float8_e4m3fnuz",
            "float8_e5m2",
            "float8_e5m2fnuz",
            "float8_e8m0fnu",
            "float4_e2m1fn_x2",
            "half",
            "uint8",
            "uint16",
            "uint32",
            "uint64",
            "int8",
            "int16",
            "short",
            "int32",
            "int",
            "int64",
            "long",
            "complex32",
            "complex64",
            "chalf",
            "cfloat",
            "complex128",
            "cdouble",
            "quint8",
            "qint8",
            "qint32",
            "bool",
            "quint4x2",
            "quint2x4",
            "bits1x8",
            "bits2x4",
            "bits4x2",
            "bits8",
            "bits16",
        ]
    ]

    # Generate __all__ directive
    # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

    # Include only the functions that have hints, to prevent undefined
    # symbols from being included in the `__all__` directive.
    hinted_function_names = {
        name for name, hint in unsorted_function_hints.items() if hint
    }
    all_symbols = sorted(hinted_function_names.union(structseqs))
    all_directive = [
        "__all__ = [",
        *(f'    "{name}",' for name in all_symbols),
        "]",
    ]

    # Dispatch key hints
    # ~~~~~~~~~~~~~~~~~~
    dispatch_key_hints = [f"{d.name} = ..." for d in DispatchKey]
    torch_dispatch_mode_key_hints = [f"{k.name} = ..." for k in _TorchDispatchModeKey]

    # Tags Enum type hints
    # ~~~~~~~~~~~~~~~~~~~~

    tag_names = sorted(parse_tags_yaml(tags_yaml_path))
    tag_attributes = "\n".join(
        f"{name} = {index}" for index, name in enumerate(tag_names)
    )

    # Write out the stub
    # ~~~~~~~~~~~~~~~~~~

    env = {
        "structseq_defs": structseq_defs,
        "return_types___all__": return_types___all__,
        "function_hints": function_hints,
        "index_type_def": index_type_def,
        "tensor_method_hints": tensor_method_hints,
        "legacy_class_hints": legacy_class_hints,
        "legacy_storage_base_hints": legacy_storage_base_hints,
        "dtype_class_hints": dtype_class_hints,
        "dispatch_key_hints": dispatch_key_hints,
        "torch_dispatch_mode_key_hints": torch_dispatch_mode_key_hints,
        "all_directive": all_directive,
        "tag_attributes": tag_attributes,
    }
    fm.write_with_template(
        "torch/_C/__init__.pyi",
        "torch/_C/__init__.pyi.in",
        lambda: env,
    )
    fm.write_with_template(
        "torch/_C/_VariableFunctions.pyi",
        "torch/_C/_VariableFunctions.pyi.in",
        lambda: env,
    )
    fm.write_with_template(
        "torch/_VF.pyi",
        "torch/_C/_VariableFunctions.pyi.in",
        lambda: env,
    )
    fm.write_with_template(
        "torch/return_types.pyi",
        "torch/_C/return_types.pyi.in",
        lambda: env,
    )
    gen_nn_functional(fm)


def main() -> None:
    parser = argparse.ArgumentParser(description="Generate type stubs for PyTorch")
    parser.add_argument(
        "--native-functions-path",
        metavar="NATIVE",
        default="aten/src/ATen/native/native_functions.yaml",
        help="path to native_functions.yaml",
    )
    parser.add_argument(
        "--tags-path",
        metavar="TAGS",
        default="aten/src/ATen/native/tags.yaml",
        help="path to tags.yaml",
    )
    parser.add_argument(
        "--deprecated-functions-path",
        metavar="DEPRECATED",
        default="tools/autograd/deprecated.yaml",
        help="path to deprecated.yaml",
    )
    parser.add_argument(
        "--out",
        metavar="OUT",
        default=".",
        help="path to output directory",
    )
    parser.add_argument(
        "--template-dir",
        default=".",
        help="path to template directory",
    )
    args = parser.parse_args()
    fm = FileManager(
        install_dir=args.out, template_dir=args.template_dir, dry_run=False
    )
    gen_pyi(
        args.native_functions_path,
        args.tags_path,
        args.deprecated_functions_path,
        fm,
    )


if __name__ == "__main__":
    main()