from __future__ import annotations

import argparse
import collections
import importlib
import sys
from pprint import pformat
from typing import Sequence
from unittest.mock import Mock, patch
from warnings import warn

from tools.autograd.gen_python_functions import (
    group_overloads,
    load_signatures,
    should_generate_py_binding,
)

from torchgen.api.python import (
    PythonSignatureGroup,
    PythonSignatureNativeFunctionPair,
    returns_structseq_pyi,
)
from torchgen.gen import parse_native_yaml, parse_tags_yaml
from torchgen.model import _TorchDispatchModeKey, DispatchKey, Variant
from torchgen.utils import FileManager

"""
|
|
This module implements generation of type stubs for PyTorch,
|
|
enabling use of autocomplete in IDEs like PyCharm, which otherwise
|
|
don't understand C extension modules.
|
|
|
|
At the moment, this module only handles type stubs for torch and
|
|
torch.Tensor. It should eventually be expanded to cover all functions
|
|
which come are autogenerated.
|
|
|
|
Here's our general strategy:
|
|
|
|
- We start off with a hand-written __init__.pyi.in file. This
|
|
file contains type definitions for everything we cannot automatically
|
|
generate, including pure Python definitions directly in __init__.py
|
|
(the latter case should be pretty rare).
|
|
|
|
- We go through automatically bound functions based on the
|
|
type information recorded in native_functions.yaml and
|
|
generate type hints for them (generate_type_hints)
|
|
|
|
There are a number of type hints which we've special-cased;
|
|
read gen_pyi for the gory details.
|
|
"""
|
|
|
|
|
|
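
# A rough sketch of the pipeline implemented below, for orientation only; see
# gen_pyi() further down for the authoritative flow. Signatures are loaded from
# native_functions.yaml / deprecated.yaml, grouped per name, rendered as
# ".pyi"-style strings, and written out through FileManager templates.
# Illustrative only (not executed here):
#
#     native_functions = parse_native_yaml(native_yaml_path, tags_yaml_path).native_functions
#     signatures = load_signatures(native_functions, deprecated_yaml_path, method=False, pyi=True)
#     for group in get_py_torch_functions(signatures):
#         hints = generate_type_hints(group)  # e.g. "def abs(input: Tensor, ...) -> Tensor: ..."
#     fm.write_with_template("torch/_C/_VariableFunctions.pyi", ...)
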

def get_py_torch_functions(
    python_funcs: Sequence[PythonSignatureNativeFunctionPair],
    method: bool = False,
) -> Sequence[PythonSignatureGroup]:
    """
    Get declarations (grouped by name) which should be generated
    as either functions in the "torch" module or methods on Tensor.
    """

    def should_bind_function(python_func: PythonSignatureNativeFunctionPair) -> bool:
        return (
            should_generate_py_binding(python_func.function)
            and not python_func.function.python_module
            and Variant.function in python_func.function.variants
        )

    def should_bind_method(python_func: PythonSignatureNativeFunctionPair) -> bool:
        return (
            should_generate_py_binding(python_func.function)
            and not python_func.function.python_module
            and Variant.method in python_func.function.variants
        )

    should_bind = should_bind_method if method else should_bind_function
    return group_overloads([f for f in python_funcs if should_bind(f)])


# TODO: Consider defining some aliases for our Union[...] types, to make
# the stubs easier to read.

DEVICE_PARAM = "device: Optional[DeviceLikeType] = None"
FACTORY_PARAMS = f"dtype: Optional[_dtype] = None, {DEVICE_PARAM}, requires_grad: _bool = False, pin_memory: _bool = False"

# NOTE: specifying indices for Tensor.__getitem__
# We can imitate numpy's definition of ndarray.__getitem__ found in numpy/__init__.pyi:
#
# key: (
#     None
#     | slice
#     | ellipsis
#     | SupportsIndex
#     | _ArrayLikeInt_co
#     | tuple[None | slice | ellipsis | _ArrayLikeInt_co | SupportsIndex, ...]
# )
#
# where:
#
# _ArrayLikeInt_co = _DualArrayLike[
#     dtype[Union[bool_, integer[Any]]],
#     Union[bool, int],
# ]
#
# and
#
# _DualArrayLike = Union[
#     _SupportsArray[_DType],
#     _NestedSequence[_SupportsArray[_DType]],
#     _T,
#     _NestedSequence[_T],
# ]
#
# Moreover, _NestedSequence is a Protocol that matches arbitrary nesting of list/tuple.
# We can substitute and simplify:
# _SupportsArray -> Tensor
# _ArrayLikeInt_co -> [bool | int | Tensor | NestedSequence[bool | int] | NestedSequence[Tensor]]
# which leaves us with key: T | tuple[T, ...], where T is:
# T = (
#     None | bool | int | slice | ellipsis | SupportsIndex
#     | Tensor | _NestedSequence[Tensor] | _NestedSequence[bool | int]
# )

# NOTE: ellipsis is equal to type[Ellipsis] in stub files.
_leaf_types = "Union[None, _bool, _int, slice, ellipsis, Tensor]"  # not SupportsIndex!
_index = f"Union[SupportsIndex, {_leaf_types}, _NestedSequence[{_leaf_types}]]"
INDICES = f"indices: Union[{_index}, tuple[{_index}, ...]]"
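
# For illustration (a sketch, not generated code): the INDICES union above is
# meant to accept the indexing forms Tensor.__getitem__ sees in practice, e.g.
#
#     t[0]                     # _int
#     t[1:3, None, ...]        # slice / None / ellipsis inside a tuple
#     t[torch.tensor([0, 2])]  # Tensor index
#     t[[0, 2], [1, 3]]        # nested sequences of ints
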
blocklist = [
    "__init_subclass__",
    "__new__",
    "__subclasshook__",
    "cdist",
    "device",
    "grad",
    "requires_grad",
    "range",
    # defined in functional
    "cumsum",
    "einsum",
    # Somehow, these are defined in both _C and in functional. Ick!
    "broadcast_tensors",
    # Manually define named tensor type stubs in __init__.pyi.in
    "align_tensors",
    "meshgrid",
    "cartesian_prod",
    "block_diag",
    "norm",
    "chain_matmul",
    "stft",
    "tensordot",
    "split",
    "unique_consecutive",
    "atleast_1d",
    "atleast_2d",
    "atleast_3d",
    # These are handled specially by python_arg_parser.cpp
    "add",
    "add_",
    "add_out",
    "sub",
    "sub_",
    "sub_out",
    "mul",
    "mul_",
    "mul_out",
    "div",
    "div_",
    "div_out",
    "true_divide",
    "true_divide_",
    "true_divide_out",
    "floor_divide",
    "floor_divide_",
    "floor_divide_out",
    "to",
    "_to_copy",
    "copy_",
]

binary_ops = (
    "add",
    "sub",
    "mul",
    "div",
    "pow",
    "lshift",
    "rshift",
    "mod",
    "truediv",
    "matmul",
    "floordiv",
    "radd",
    "rsub",
    "rmul",
    "rtruediv",
    "rfloordiv",
    "rpow",  # reverse arithmetic
    "and",
    "or",
    "xor",
    "rand",
    "ror",
    "rxor",  # logic
    "iadd",
    "iand",
    "idiv",
    "ilshift",
    "imul",
    "ior",
    "irshift",
    "isub",
    "ixor",
    "ifloordiv",
    "imod",  # inplace ops
)
symmetric_comparison_ops = ("eq", "ne")
asymmetric_comparison_ops = ("ge", "gt", "lt", "le")
comparison_ops = symmetric_comparison_ops + asymmetric_comparison_ops

unary_ops = ("neg", "abs", "invert")
to_py_type_ops = ("bool", "float", "complex", "long", "index", "int", "nonzero")
all_ops = binary_ops + comparison_ops + unary_ops + to_py_type_ops


def sig_for_ops(opname: str) -> list[str]:
    """sig_for_ops(opname : str) -> List[str]

    Returns signatures for operator special functions (__add__ etc.)"""

    # we have to do this by hand, because they are hand-bound in Python

    assert opname.endswith("__") and opname.startswith("__"), f"Unexpected op {opname}"

    name = opname[2:-2]
    if name in binary_ops:
        return [f"def {opname}(self, other: Any) -> Tensor: ..."]
    elif name in comparison_ops:
        sig = f"def {opname}(self, other: Any) -> Tensor: ..."
        if name in symmetric_comparison_ops:
            # unsafe override https://github.com/python/mypy/issues/5704
            sig += " # type: ignore[override]"
        return [sig]
    elif name in unary_ops:
        return [f"def {opname}(self) -> Tensor: ..."]
    elif name in to_py_type_ops:
        if name in {"bool", "float", "complex"}:
            tname = name
        elif name == "nonzero":
            tname = "bool"
        else:
            tname = "int"
        if tname in {"float", "int", "bool", "complex"}:
            tname = "builtins." + tname
        return [f"def {opname}(self) -> {tname}: ..."]
    else:
        raise Exception("unknown op", opname)  # noqa: TRY002
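
# For example (illustrative only), sig_for_ops produces hints such as:
#
#     sig_for_ops("__add__")  -> ['def __add__(self, other: Any) -> Tensor: ...']
#     sig_for_ops("__eq__")   -> ['def __eq__(self, other: Any) -> Tensor: ... # type: ignore[override]']
#     sig_for_ops("__int__")  -> ['def __int__(self) -> builtins.int: ...']
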

def generate_type_hints(sig_group: PythonSignatureGroup) -> list[str]:
    type_hints: list[str] = []

    # Some deprecated ops that are on the blocklist are still included in pyi
    if sig_group.signature.name in blocklist and not sig_group.signature.deprecated:
        return type_hints

    # deprecated signatures have separate entries for their functional and out variants
    # (as opposed to the native ops, which fuse the two into a single signature).
    # generate the functional variant here, if an out variant exists.
    if sig_group.signature.deprecated and sig_group.outplace is not None:
        type_hint = sig_group.signature.signature_str_pyi(skip_outputs=True)
        type_hints.append(type_hint)

    # PythonSignatureGroups that have both a functional + out variant get a single signature, with an optional out argument
    # Generates the out variant if one exists. Otherwise, generate the functional variant
    type_hint = sig_group.signature.signature_str_pyi(
        skip_outputs=sig_group.outplace is None
    )
    type_hints.append(type_hint)

    # Some operators also additionally have a vararg variant of their signature
    type_hint_vararg = sig_group.signature.signature_str_pyi_vararg(
        skip_outputs=sig_group.outplace is None
    )
    if type_hint_vararg:
        type_hints.append(type_hint_vararg)

    return type_hints
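
# A hedged illustration of the shape of the generated hints (the exact strings
# come from signature_str_pyi and may differ): an op with both a functional and
# an out variant yields a single entry roughly like
#
#     "def clamp(input: Tensor, min: Optional[Number] = None, ..., *, out: Optional[Tensor] = None) -> Tensor: ..."
#
# and ops that also accept a vararg form (e.g. a bare size) get an extra vararg overload.
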

def get_max_pool_dispatch(name: str, arg_list: list[str]) -> dict[str, list[str]]:
    flag_pos = arg_list.index("{return_indices}")
    # If return_indices is a positional arg, everything before it should have no default
    arg_list_positional = (
        [
            ", ".join(single_arg.split(" = ")[0] for single_arg in arg.split(", "))
            for arg in arg_list[: flag_pos + 1]
        ]
        + ["/"]
        + arg_list[flag_pos + 1 :]
    )
    # Otherwise force return_indices to be kwarg
    arg_list_keyword = arg_list.copy()
    arg_list_keyword.insert(flag_pos, "*")
    tmpl = "def {name}({args}) -> {{return_type}}: ..."
    return {
        name: [
            tmpl.format(name=name, args=", ".join(arg_list)).format(
                return_indices="return_indices: Literal[False] = False",
                return_type="Tensor",
            ),
            tmpl.format(name=name, args=", ".join(arg_list_positional)).format(
                return_indices="return_indices: Literal[True]",
                return_type="Tuple[Tensor, Tensor]",
            ),
            tmpl.format(name=name, args=", ".join(arg_list_keyword)).format(
                return_indices="return_indices: Literal[True]",
                return_type="Tuple[Tensor, Tensor]",
            ),
        ]
    }
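
# Illustrative sketch of the three overloads this helper emits (the real strings
# are produced by the template above; "..." stands for the remaining pooling args):
#
#     def max_pool1d(..., return_indices: Literal[False] = False) -> Tensor: ...
#     def max_pool1d(..., return_indices: Literal[True], /) -> Tuple[Tensor, Tensor]: ...
#         # positional form: defaults stripped so return_indices can be passed positionally
#     def max_pool1d(..., *, return_indices: Literal[True]) -> Tuple[Tensor, Tensor]: ...
#         # keyword-only form
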
def gen_nn_functional(fm: FileManager) -> None:
|
|
INPUT = "input: Tensor"
|
|
KERNEL_SIZE = "kernel_size: Union[_int, _size]"
|
|
STRIDE_PADDING = ", ".join(
|
|
[
|
|
"stride: Optional[Union[_int, _size]] = None",
|
|
"padding: Union[_int, _size] = 0",
|
|
]
|
|
)
|
|
|
|
# TODO the list for `torch._C._nn` is nonexhaustive
|
|
unsorted_c_nn_function_hints: dict[str, list[str]] = {}
|
|
|
|
for d in (2, 3):
|
|
unsorted_c_nn_function_hints.update(
|
|
{
|
|
f"avg_pool{d}d": [
|
|
f"def avg_pool{d}d({{}}) -> Tensor: ...".format(
|
|
", ".join(
|
|
[
|
|
f"{INPUT}",
|
|
f"{KERNEL_SIZE}",
|
|
f"{STRIDE_PADDING}",
|
|
"ceil_mode: bool = False",
|
|
"count_include_pad: bool = True",
|
|
"divisor_override: Optional[int] = None",
|
|
]
|
|
)
|
|
)
|
|
],
|
|
f"fractional_max_pool{d}d": [
|
|
f"def fractional_max_pool{d}d({{}}) -> {{}}: ...".format(
|
|
", ".join(
|
|
[
|
|
f"{INPUT}",
|
|
f"{KERNEL_SIZE}",
|
|
"output_size: Union[_int, _size]",
|
|
"_random_samples: Tensor",
|
|
]
|
|
),
|
|
"Tuple[Tensor, Tensor]",
|
|
)
|
|
],
|
|
f"adaptive_max_pool{d}d": [
|
|
f"def adaptive_max_pool{d}d({{}}) -> {{}}: ...".format(
|
|
", ".join([f"{INPUT}", "output_size: Union[_int, _size]"]),
|
|
"Tuple[Tensor, Tensor]",
|
|
)
|
|
],
|
|
}
|
|
)
|
|
|
|
unsorted_c_nn_function_hints.update(
|
|
{
|
|
"hardtanh": [
|
|
"def hardtanh({}) -> Tensor: ...".format(
|
|
", ".join(
|
|
[
|
|
"input: Tensor",
|
|
"min_val: float = ...",
|
|
"max_val: float = ...",
|
|
"*",
|
|
"out: Optional[Tensor] = None",
|
|
]
|
|
)
|
|
)
|
|
],
|
|
"hardtanh_": [
|
|
"def hardtanh_({}) -> Tensor: ...".format(
|
|
", ".join(
|
|
[
|
|
"input: Tensor",
|
|
"min_val: float = ...",
|
|
"max_val: float = ...",
|
|
]
|
|
)
|
|
)
|
|
],
|
|
"elu_": ["def elu_(input: Tensor, alpha: float = ...) -> Tensor: ..."],
|
|
"leaky_relu": [
|
|
"def leaky_relu({}) -> Tensor: ...".format(
|
|
", ".join(
|
|
[
|
|
"input: Tensor",
|
|
"negative_slope: float = ...",
|
|
"*",
|
|
"out: Optional[Tensor] = None",
|
|
]
|
|
)
|
|
)
|
|
],
|
|
"leaky_relu_": [
|
|
f"def leaky_relu_({', '.join(['input: Tensor', 'negative_slope: float = ...'])}) -> Tensor: ..."
|
|
],
|
|
"log_sigmoid": ["def log_sigmoid(input: Tensor) -> Tensor: ..."],
|
|
"gelu": ["def gelu(input: Tensor, approximate: str = ...) -> Tensor: ..."],
|
|
"softplus": [
|
|
"def softplus({}) -> Tensor: ...".format(
|
|
", ".join(
|
|
["input: Tensor", "beta: float = ...", "threshold: float = ..."]
|
|
)
|
|
)
|
|
],
|
|
"softshrink": [
|
|
"def softshrink(input: Tensor, lambd: float = ...) -> Tensor: ..."
|
|
],
|
|
"hardsigmoid": [
|
|
f"def hardsigmoid({', '.join(['input: Tensor', '*', 'out: Optional[Tensor] = None'])}) -> Tensor: ..."
|
|
],
|
|
"linear": [
|
|
"def linear({}) -> Tensor: ...".format(
|
|
", ".join(
|
|
[
|
|
"input: Tensor",
|
|
"weight: Tensor",
|
|
"bias: Optional[Tensor] = None",
|
|
]
|
|
)
|
|
)
|
|
],
|
|
"pad": [
|
|
"def pad({}) -> Tensor: ...".format(
|
|
", ".join(
|
|
[
|
|
"input: Tensor",
|
|
"pad: Sequence[int]",
|
|
"mode: str = ...",
|
|
"value: Optional[float] = None",
|
|
]
|
|
)
|
|
)
|
|
],
|
|
"one_hot": [
|
|
"def one_hot(tensor: Tensor, num_classes: int = ...) -> Tensor: ..."
|
|
],
|
|
"scaled_dot_product_attention": [
|
|
"def scaled_dot_product_attention({}) -> Tensor: ...".format(
|
|
", ".join(
|
|
[
|
|
"query: Tensor",
|
|
"key: Tensor",
|
|
"value: Tensor",
|
|
"attn_mask: Optional[Tensor] = None",
|
|
"dropout_p: float = 0.0",
|
|
"is_causal: bool = False",
|
|
"scale: Optional[float] = None",
|
|
"enable_gqa: bool = False",
|
|
]
|
|
)
|
|
)
|
|
],
|
|
}
|
|
)
|
|
|
|
c_nn_function_hints: list[str] = []
|
|
for _, hints in sorted(unsorted_c_nn_function_hints.items()):
|
|
if len(hints) > 1:
|
|
hints = ["@overload\n" + h for h in hints]
|
|
c_nn_function_hints += hints
|
|
|
|
# Functions imported into `torch.nn.functional` from `torch`, perhaps being filtered
|
|
# through an `_add_docstr` call
|
|
torch_imports = [
|
|
"conv1d",
|
|
"conv2d",
|
|
"conv3d",
|
|
"conv_transpose1d",
|
|
"conv_transpose2d",
|
|
"conv_transpose3d",
|
|
"conv_tbc",
|
|
"avg_pool1d",
|
|
"adaptive_avg_pool1d",
|
|
"relu_",
|
|
"selu_",
|
|
"celu_",
|
|
"prelu",
|
|
"rrelu_",
|
|
"hardshrink",
|
|
"bilinear",
|
|
"pixel_shuffle",
|
|
"pixel_unshuffle",
|
|
"channel_shuffle",
|
|
"native_channel_shuffle",
|
|
"pairwise_distance",
|
|
"pdist",
|
|
"cosine_similarity",
|
|
]
|
|
imported_hints = [f"from torch import {_} as {_}" for _ in torch_imports]
|
|
|
|
# Functions imported into `torch.nn.functional` from `torch._C._nn`
|
|
c_nn_imports = [
|
|
"avg_pool2d",
|
|
"avg_pool3d",
|
|
"hardtanh_",
|
|
"elu_",
|
|
"leaky_relu_",
|
|
"gelu",
|
|
"softplus",
|
|
"softshrink",
|
|
"linear",
|
|
"pad",
|
|
"one_hot",
|
|
"scaled_dot_product_attention",
|
|
]
|
|
imported_hints += [f"from torch._C._nn import {_} as {_}" for _ in c_nn_imports]
|
|
# This is from `torch._C._nn` but renamed
|
|
imported_hints.append(
|
|
"from torch._C._nn import log_sigmoid\nlogsigmoid = log_sigmoid"
|
|
)
|
|
|
|
# Functions generated by `torch._jit_internal.boolean_dispatch` in `nn.functional`
|
|
unsorted_dispatched_hints: dict[str, list[str]] = {}
|
|
|
|
for d in (1, 2, 3):
|
|
unsorted_dispatched_hints.update(
|
|
**get_max_pool_dispatch(
|
|
f"max_pool{d}d",
|
|
[
|
|
f"{INPUT}",
|
|
f"{KERNEL_SIZE}",
|
|
f"{STRIDE_PADDING}",
|
|
"dilation: Union[_int, _size] = 1",
|
|
"ceil_mode: bool = False",
|
|
"{return_indices}",
|
|
],
|
|
),
|
|
**get_max_pool_dispatch(
|
|
f"fractional_max_pool{d}d",
|
|
[
|
|
f"{INPUT}",
|
|
f"{KERNEL_SIZE}",
|
|
"output_size: Optional[Union[_int, _size]] = None",
|
|
"output_ratio: Optional[_ratio_any_t] = None",
|
|
"{return_indices}",
|
|
"_random_samples: Optional[Tensor] = None",
|
|
],
|
|
),
|
|
**get_max_pool_dispatch(
|
|
f"adaptive_max_pool{d}d",
|
|
[f"{INPUT}", "output_size: Union[_int, _size]", "{return_indices}"],
|
|
),
|
|
)
|
|
|
|
# There's no fractional_max_pool1d
|
|
del unsorted_dispatched_hints["fractional_max_pool1d"]
|
|
|
|
dispatched_hints: list[str] = []
|
|
for _, hints in sorted(unsorted_dispatched_hints.items()):
|
|
if len(hints) > 1:
|
|
hints = ["@overload\n" + h for h in hints]
|
|
dispatched_hints += hints
|
|
|
|
fm.write_with_template(
|
|
"torch/nn/functional.pyi",
|
|
"torch/nn/functional.pyi.in",
|
|
lambda: {
|
|
"imported_hints": imported_hints,
|
|
"dispatched_hints": dispatched_hints,
|
|
},
|
|
)
|
|
fm.write_with_template(
|
|
"torch/_C/_nn.pyi",
|
|
"torch/_C/_nn.pyi.in",
|
|
lambda: {
|
|
"c_nn_function_hints": c_nn_function_hints,
|
|
},
|
|
)
|
|
|
|
|
|
"""
|
|
We gather the docstrings for torch with the following steps:
|
|
1. Mock torch and torch._C, which are the only dependencies of the docs files
|
|
2. Mock the _add_docstr function to save the docstrings
|
|
3. Import the docs files to trigger mocked _add_docstr and collect docstrings
|
|
"""
|
|
|
|
|
|

def gather_docstrs() -> dict[str, str]:
    docstrs = {}

    def mock_add_docstr(func: Mock, docstr: str) -> None:
        docstrs[func._extract_mock_name()] = docstr.strip()

    # sys.modules and sys.path are restored after the context manager exits
    with patch.dict(sys.modules), patch.object(sys, "path", sys.path + ["torch"]):
        # mock the torch module and torch._C._add_docstr
        sys.modules["torch"] = Mock(name="torch")
        sys.modules["torch._C"] = Mock(_add_docstr=mock_add_docstr)

        try:
            # manually import torch._torch_docs and torch._tensor_docs to trigger
            # the mocked _add_docstr and collect docstrings
            sys.modules["torch._torch_docs"] = importlib.import_module("_torch_docs")
            sys.modules["torch._tensor_docs"] = importlib.import_module("_tensor_docs")
        except ModuleNotFoundError:
            # Gracefully fail if these modules are not importable
            warn(
                "Failed to import _torch_docs/_tensor_docs, skipping docstring in pyi files."
            )

    return docstrs
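
# For example (a hedged sketch): once torch._C._add_docstr is mocked, an import
# of torch._torch_docs that executes something like
#
#     _add_docstr(torch.abs, "abs(input) -> Tensor ...")
#
# records {"torch.abs": "abs(input) -> Tensor ..."} in the dict above, keyed by
# the child mock's name ("torch.abs").
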

def add_docstr_to_hint(docstr: str, hint: str) -> str:
    if "..." in hint:  # function or method
        assert hint.endswith("..."), f"Hint `{hint}` does not end with '...'"
        hint = hint[:-3]  # remove "..."
        return "\n    ".join([hint, 'r"""'] + docstr.split("\n") + ['"""', "..."])
    else:  # attribute or property
        return f'{hint}\nr"""{docstr}"""\n'
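
# Example (illustrative): given hint "def abs(input: Tensor) -> Tensor: ..." and
# the one-line docstring "Computes the absolute value.", the helper rewrites the
# hint so the docstring is embedded in the stub body:
#
#     def abs(input: Tensor) -> Tensor:
#         r"""
#         Computes the absolute value.
#         """
#         ...
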

def gen_pyi(
    native_yaml_path: str,
    tags_yaml_path: str,
    deprecated_yaml_path: str,
    fm: FileManager,
) -> None:
    """gen_pyi()

    This function generates a pyi file for torch.
    """

    # Some of this logic overlaps with generate_python_signature in
    # tools/autograd/gen_python_functions.py; however, this
    # function is all about generating mypy type signatures, whereas
    # the other function generates a custom format for argument
    # checking. If you update this, consider whether your change
    # also needs to update the other file.

    # Dictionary for NamedTuple definitions
    structseqs: dict[str, str] = {}

    # Generate type signatures for top-level functions
    # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

    unsorted_function_hints: dict[str, list[str]] = collections.defaultdict(list)
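
    # Entries in unsorted_function_hints map a torch-level name to one or more
    # ready-made stub strings, e.g. (illustrative):
    #
    #     unsorted_function_hints["numel"] == ["def numel(self: Tensor) -> _int: ..."]
    #
    # The hand-written entries below cover functions whose bindings are skipped
    # or special-cased; everything else is filled in from native_functions.yaml.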
for n, n1, n2 in [
|
|
("csr", "crow", "col"),
|
|
("csc", "ccol", "row"),
|
|
("bsr", "crow", "col"),
|
|
("bsc", "ccol", "row"),
|
|
]:
|
|
unsorted_function_hints.update(
|
|
{
|
|
f"sparse_{n}_tensor": [
|
|
f"def sparse_{n}_tensor({{}}) -> Tensor: ...".format(
|
|
", ".join(
|
|
[
|
|
f"{n1}_indices: Union[Tensor, List]",
|
|
f"{n2}_indices: Union[Tensor, List]",
|
|
"values: Union[Tensor, List]",
|
|
"size: Optional[_size] = None",
|
|
"*",
|
|
"dtype: Optional[_dtype] = None",
|
|
"device: Optional[DeviceLikeType] = None",
|
|
"requires_grad: _bool = False",
|
|
"check_invariants: Optional[_bool] = None",
|
|
]
|
|
),
|
|
)
|
|
],
|
|
}
|
|
)
|
|
|
|
unsorted_function_hints.update(
|
|
{
|
|
"set_flush_denormal": ["def set_flush_denormal(mode: _bool) -> _bool: ..."],
|
|
"get_default_dtype": ["def get_default_dtype() -> _dtype: ..."],
|
|
"asarray": [
|
|
"def asarray({}) -> Tensor: ...".format(
|
|
", ".join(
|
|
[
|
|
"obj: Any",
|
|
"*",
|
|
"dtype: Optional[_dtype] = None",
|
|
"device: Optional[DeviceLikeType] = None",
|
|
"copy: Optional[_bool] = None",
|
|
"requires_grad: _bool = False",
|
|
]
|
|
)
|
|
)
|
|
],
|
|
"from_numpy": ["def from_numpy(ndarray) -> Tensor: ..."],
|
|
"frombuffer": [
|
|
"def frombuffer({}) -> Tensor: ...".format(
|
|
", ".join(
|
|
[
|
|
"buffer: Any",
|
|
"*",
|
|
"dtype: _dtype",
|
|
"count: int = -1",
|
|
"offset: int = 0",
|
|
"requires_grad: _bool = False",
|
|
]
|
|
)
|
|
)
|
|
],
|
|
"numel": ["def numel(self: Tensor) -> _int: ..."],
|
|
"as_tensor": [
|
|
"def as_tensor({}) -> Tensor: ...".format(
|
|
", ".join(
|
|
[
|
|
"data: Any",
|
|
"dtype: Optional[_dtype] = None",
|
|
DEVICE_PARAM,
|
|
]
|
|
)
|
|
)
|
|
],
|
|
"get_num_threads": ["def get_num_threads() -> _int: ..."],
|
|
"set_num_threads": ["def set_num_threads(num: _int) -> None: ..."],
|
|
"init_num_threads": ["def init_num_threads() -> None: ..."],
|
|
"get_num_interop_threads": ["def get_num_interop_threads() -> _int: ..."],
|
|
"set_num_interop_threads": [
|
|
"def set_num_interop_threads(num: _int) -> None: ..."
|
|
],
|
|
# These functions are explicitly disabled by
|
|
# SKIP_PYTHON_BINDINGS because they are hand bound.
|
|
# Correspondingly, we must hand-write their signatures.
|
|
"tensor": [f"def tensor(data: Any, {FACTORY_PARAMS}) -> Tensor: ..."],
|
|
"sparse_coo_tensor": [
|
|
"def sparse_coo_tensor({}) -> Tensor: ...".format(
|
|
", ".join(
|
|
[
|
|
"indices: Tensor",
|
|
"values: Union[Tensor, List]",
|
|
"size: Optional[_size] = None",
|
|
"*",
|
|
"dtype: Optional[_dtype] = None",
|
|
"device: Optional[DeviceLikeType] = None",
|
|
"requires_grad: _bool = False",
|
|
"check_invariants: Optional[_bool] = None",
|
|
"is_coalesced: Optional[_bool] = None",
|
|
]
|
|
)
|
|
)
|
|
],
|
|
"sparse_compressed_tensor": [
|
|
"def sparse_compressed_tensor({}) -> Tensor: ...".format(
|
|
", ".join(
|
|
[
|
|
"compressed_indices: Union[Tensor, List]",
|
|
"plain_indices: Union[Tensor, List]",
|
|
"values: Union[Tensor, List]",
|
|
"size: Optional[_size] = None",
|
|
"*",
|
|
"dtype: Optional[_dtype] = None",
|
|
"layout: Optional[_layout] = None",
|
|
"device: Optional[DeviceLikeType] = None",
|
|
"requires_grad: _bool = False",
|
|
"check_invariants: Optional[_bool] = None",
|
|
]
|
|
)
|
|
)
|
|
],
|
|
"_sync": ["def _sync(t: Tensor) -> None: ..."],
|
|
"_is_functional_tensor": [
|
|
"def _is_functional_tensor(t: Tensor) -> _bool: ..."
|
|
],
|
|
"_is_functional_tensor_base": [
|
|
"def _is_functional_tensor_base(t: Tensor) -> _bool: ..."
|
|
],
|
|
"_from_functional_tensor": [
|
|
"def _from_functional_tensor(t: Tensor) -> Tensor: ..."
|
|
],
|
|
"_to_functional_tensor": [
|
|
"def _to_functional_tensor(t: Tensor) -> Tensor: ..."
|
|
],
|
|
"_functionalize_replace": [
|
|
"def _functionalize_replace(self_: Tensor, other: Tensor) -> None: ..."
|
|
],
|
|
"_functionalize_commit_update": [
|
|
"def _functionalize_commit_update(t: Tensor) -> None: ..."
|
|
],
|
|
"_functionalize_unsafe_set": [
|
|
"def _functionalize_unsafe_set(dst: Tensor, src: Tensor) -> None: ..."
|
|
],
|
|
"_functionalize_mark_mutation_hidden_from_autograd": [
|
|
"def _functionalize_mark_mutation_hidden_from_autograd(t: Tensor) -> None: ..."
|
|
],
|
|
"_functionalize_are_all_mutations_hidden_from_autograd": [
|
|
"def _functionalize_are_all_mutations_hidden_from_autograd(t: Tensor) -> _bool: ..."
|
|
],
|
|
"_functionalize_are_all_mutations_under_no_grad_or_inference_mode": [
|
|
"def _functionalize_are_all_mutations_under_no_grad_or_inference_mode(t: Tensor) -> _bool: ..."
|
|
],
|
|
"_functionalize_was_inductor_storage_resized": [
|
|
"def _functionalize_was_inductor_storage_resized(t: Tensor) -> _bool: ..."
|
|
],
|
|
"_functionalize_sync": ["def _functionalize_sync(t: Tensor) -> None: ..."],
|
|
"_functionalize_was_storage_changed": [
|
|
"def _functionalize_was_storage_changed(tensor: Tensor) -> _bool: ..."
|
|
],
|
|
"_functionalize_set_storage_changed": [
|
|
"def _functionalize_set_storage_changed(tensor: Tensor) -> _bool: ..."
|
|
],
|
|
"_functionalize_has_metadata_mutation": [
|
|
"def _functionalize_has_metadata_mutation(tensor: Tensor) -> _bool: ..."
|
|
],
|
|
"_functionalize_apply_view_metas": [
|
|
"def _functionalize_apply_view_metas(tensor: Tensor, base: Tensor) -> Tensor: ..."
|
|
],
|
|
"_functionalize_is_symbolic": [
|
|
"def _functionalize_is_symbolic(tensor: Tensor) -> _bool: ..."
|
|
],
|
|
"_enable_functionalization": [
|
|
"def _enable_functionalization(*, reapply_views: _bool = False): ..."
|
|
],
|
|
"_disable_functionalization": ["def _disable_functionalization(): ..."],
|
|
"range": [
|
|
"def range({}) -> Tensor: ...".format(
|
|
", ".join(
|
|
[
|
|
"start: Number",
|
|
"end: Number",
|
|
"step: Number = 1",
|
|
"*",
|
|
"out: Optional[Tensor] = None",
|
|
FACTORY_PARAMS,
|
|
]
|
|
)
|
|
)
|
|
],
|
|
"arange": [
|
|
"def arange({}) -> Tensor: ...".format(
|
|
", ".join(
|
|
[
|
|
"start: Number",
|
|
"end: Number",
|
|
"step: Number",
|
|
"*",
|
|
"out: Optional[Tensor] = None",
|
|
FACTORY_PARAMS,
|
|
]
|
|
)
|
|
),
|
|
"def arange({}) -> Tensor: ...".format(
|
|
", ".join(
|
|
[
|
|
"start: Number",
|
|
"end: Number",
|
|
"*",
|
|
"out: Optional[Tensor] = None",
|
|
FACTORY_PARAMS,
|
|
]
|
|
)
|
|
),
|
|
"def arange({}) -> Tensor: ...".format(
|
|
", ".join(
|
|
[
|
|
"end: Number",
|
|
"*",
|
|
"out: Optional[Tensor] = None",
|
|
FACTORY_PARAMS,
|
|
]
|
|
)
|
|
),
|
|
],
|
|
"linspace": [
|
|
"def linspace({}) -> Tensor: ...".format(
|
|
", ".join(
|
|
[
|
|
"start: Number",
|
|
"end: Number",
|
|
"steps: Optional[_int] = None",
|
|
"*",
|
|
"out: Optional[Tensor] = None",
|
|
FACTORY_PARAMS,
|
|
]
|
|
)
|
|
)
|
|
],
|
|
"logspace": [
|
|
"def logspace({}) -> Tensor: ...".format(
|
|
", ".join(
|
|
[
|
|
"start: Number",
|
|
"end: Number",
|
|
"steps: Optional[_int] = None",
|
|
"base: _float = 10.0",
|
|
"*",
|
|
"out: Optional[Tensor] = None",
|
|
FACTORY_PARAMS,
|
|
]
|
|
)
|
|
)
|
|
],
|
|
"randint": [
|
|
"def randint({}) -> Tensor: ...".format(
|
|
", ".join(
|
|
[
|
|
"low: _int",
|
|
"high: _int",
|
|
"size: _size",
|
|
"*",
|
|
"generator: Optional[Generator] = None",
|
|
FACTORY_PARAMS,
|
|
]
|
|
)
|
|
),
|
|
"def randint({}) -> Tensor: ...".format(
|
|
", ".join(
|
|
[
|
|
"high: _int",
|
|
"size: _size",
|
|
"*",
|
|
"generator: Optional[Generator] = None",
|
|
FACTORY_PARAMS,
|
|
]
|
|
)
|
|
),
|
|
],
|
|
"full": [
|
|
"def full({}) -> Tensor: ...".format(
|
|
", ".join(
|
|
[
|
|
"size: _size",
|
|
"fill_value: Union[Number, _complex]",
|
|
"*",
|
|
"out: Optional[Tensor] = None",
|
|
"layout: _layout = strided",
|
|
FACTORY_PARAMS,
|
|
]
|
|
)
|
|
),
|
|
"def full({}) -> Tensor: ...".format(
|
|
", ".join(
|
|
[
|
|
"size: _size",
|
|
"fill_value: Union[Number, _complex]",
|
|
"*",
|
|
"names: List[Union[str, None]]",
|
|
"layout: _layout = strided",
|
|
FACTORY_PARAMS,
|
|
]
|
|
)
|
|
),
|
|
],
|
|
"is_grad_enabled": ["def is_grad_enabled() -> _bool: ..."],
|
|
"is_inference_mode_enabled": [
|
|
"def is_inference_mode_enabled() -> _bool: ..."
|
|
],
|
|
"nonzero": [
|
|
"def nonzero(input: Tensor, *, as_tuple: Literal[False] = False, out: Optional[Tensor] = None) -> Tensor: ...",
|
|
"def nonzero(input: Tensor, *, as_tuple: Literal[True]) -> Tuple[Tensor, ...]: ...",
|
|
],
|
|
"dsmm": ["def dsmm(input: Tensor, mat2: Tensor) -> Tensor: ..."],
|
|
"hsmm": ["def hsmm(input: Tensor, mat2: Tensor) -> Tensor: ..."],
|
|
"saddmm": [
|
|
"def saddmm({}) -> Tensor: ...".format(
|
|
", ".join(
|
|
[
|
|
"input: Tensor",
|
|
"mat1: Tensor",
|
|
"mat2: Tensor",
|
|
"*",
|
|
"beta: Number = 1",
|
|
"alpha: Number = 1",
|
|
"out: Optional[Tensor] = None",
|
|
]
|
|
)
|
|
)
|
|
],
|
|
"spmm": ["def spmm(input: Tensor, mat2: Tensor) -> Tensor: ..."],
|
|
"div": [
|
|
"def div({}) -> Tensor: ...".format(
|
|
", ".join(
|
|
[
|
|
"input: Union[Tensor, Number]",
|
|
"other: Union[Tensor, Number]",
|
|
"*",
|
|
"rounding_mode: Optional[str] = None",
|
|
"out: Optional[Tensor] = None",
|
|
]
|
|
)
|
|
)
|
|
],
|
|
}
|
|
)
|
|
for binop in ["true_divide", "floor_divide"]:
|
|
unsorted_function_hints[binop].append(
|
|
f"def {binop}(input: Union[Tensor, Number], other: Union[Tensor, Number], "
|
|
"*, out: Optional[Tensor] = None) -> Tensor: ..."
|
|
)
|
|
for binop in ["mul"]:
|
|
unsorted_function_hints[binop].append(
|
|
f"def {binop}(input: Union[Tensor, Number, _complex], other: Union[Tensor, Number, _complex], "
|
|
"*, out: Optional[Tensor] = None) -> Tensor: ..."
|
|
)
|
|
for binop in ["add", "sub"]:
|
|
unsorted_function_hints[binop].append(
|
|
f"def {binop}(input: Union[Tensor, Number, _complex], other: Union[Tensor, Number, _complex], "
|
|
"*, alpha: Optional[Union[Number, _complex]] = 1, out: Optional[Tensor] = None) -> Tensor: ..."
|
|
)
|
|
|
|
native_functions = parse_native_yaml(
|
|
native_yaml_path, tags_yaml_path
|
|
).native_functions
|
|
native_functions = list(filter(should_generate_py_binding, native_functions))
|
|
|
|
function_signatures = load_signatures(
|
|
native_functions, deprecated_yaml_path, method=False, pyi=True
|
|
)
|
|
sig_groups = get_py_torch_functions(function_signatures)
|
|
for group in sorted(sig_groups, key=lambda g: g.signature.name):
|
|
name = group.signature.name
|
|
unsorted_function_hints[name] += generate_type_hints(group)
|
|
|
|
structseq = returns_structseq_pyi(group.signature)
|
|
if structseq is not None and not group.signature.deprecated:
|
|
# deprecated structseqs are currently not included for torch functions
|
|
tuple_name, tuple_def = structseq
|
|
if tuple_name in structseqs:
|
|
assert structseqs[tuple_name] == tuple_def
|
|
else:
|
|
structseqs[tuple_name] = tuple_def
|
|
|
|
def replace_special_case(hint: str) -> str:
|
|
# NB: Keep this in sync with enum in aten/src/ATen/core/Reduction.h
|
|
hint = hint.replace("at::Reduction::Mean", "1")
|
|
hint = hint.replace(": Tensor = None", ": Optional[Tensor] = None")
|
|
# Match both:
|
|
# ": Union[Tensor, Tuple[Tensor, ...], List[Tensor]] = None"
|
|
# ": Union[Tuple[Tensor, ...], List[Tensor]] = None"
|
|
hint = hint.replace(
|
|
"Tuple[Tensor, ...], List[Tensor]] = None",
|
|
"Tuple[Tensor, ...], List[Tensor], None] = None",
|
|
)
|
|
return hint
|
|
|
|
docstrs = gather_docstrs()
|
|
function_hints = []
|
|
for name, hints in sorted(unsorted_function_hints.items()):
|
|
hints = [replace_special_case(h) for h in hints]
|
|
if len(hints) > 1:
|
|
hints = ["@overload\n" + h for h in hints]
|
|
docstr = docstrs.get(f"torch.{name}")
|
|
if docstr is not None:
|
|
hints = [add_docstr_to_hint(docstr, h) for h in hints]
|
|
function_hints += hints
|
|
|
|
# Generate type signatures for Tensor methods
|
|
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
|
|
|
unsorted_tensor_method_hints: dict[str, list[str]] = collections.defaultdict(list)
|
|
unsorted_tensor_method_hints.update(
|
|
{
|
|
"size": [
|
|
"def size(self, dim: None = None) -> Size: ...",
|
|
"def size(self, dim: _int) -> _int: ...",
|
|
],
|
|
"stride": [
|
|
"def stride(self, dim: None = None) -> Tuple[_int, ...]: ...",
|
|
"def stride(self, dim: _int) -> _int: ...",
|
|
],
|
|
"new_ones": [
|
|
f"def new_ones(self, size: _size, {FACTORY_PARAMS}) -> Tensor: ..."
|
|
],
|
|
"new_tensor": [
|
|
f"def new_tensor(self, data: Any, {FACTORY_PARAMS}) -> Tensor: ..."
|
|
],
|
|
"__new__": ["def __new__(cls, *args, **kwargs) -> Self: ..."],
|
|
# new and __init__ have the same signatures and differ only in return type
# Adapted from legacy_tensor_ctor and legacy_tensor_new
|
|
"new": [
|
|
f"def new(cls, *args: Any, {DEVICE_PARAM}) -> Self: ...",
|
|
"def new(cls, storage: Storage) -> Self: ...",
|
|
"def new(cls, other: Tensor) -> Self: ...",
|
|
f"def new(cls, size: _size, *, {DEVICE_PARAM}) -> Self: ...",
|
|
],
|
|
"__init__": [
|
|
f"def __init__(self, *args: Any, {DEVICE_PARAM}) -> None: ...",
|
|
"def __init__(self, storage: Storage) -> None: ...",
|
|
"def __init__(self, other: Tensor) -> None: ...",
|
|
f"def __init__(self, size: _size, *, {DEVICE_PARAM}) -> None: ...",
|
|
],
|
|
"as_subclass": ["def as_subclass(self, cls: _Type[S]) -> S: ..."],
|
|
"_make_subclass": [
|
|
"@staticmethod \ndef _make_subclass({}) -> S: ...".format(
|
|
", ".join(
|
|
[
|
|
"cls: _Type[S]",
|
|
"data: Tensor",
|
|
"require_grad: _bool = False",
|
|
"dispatch_strides: _bool = False",
|
|
"dispatch_device: _bool = False",
|
|
"device_for_backend_keys: Optional[_device] = None",
|
|
]
|
|
)
|
|
)
|
|
],
|
|
"__contains__": ["def __contains__(self, other: Any, /) -> _bool: ..."],
|
|
"__getitem__": [f"def __getitem__(self, {INDICES}) -> Tensor: ..."],
|
|
"__setitem__": [
|
|
f"def __setitem__(self, {INDICES}, val: Union[Tensor, Number]) -> None: ..."
|
|
],
|
|
"tolist": ["def tolist(self) -> List: ..."],
|
|
"requires_grad_": [
|
|
"def requires_grad_(self, mode: _bool = True) -> Tensor: ..."
|
|
],
|
|
"element_size": ["def element_size(self) -> _int: ..."],
|
|
"data_ptr": ["def data_ptr(self) -> _int: ..."],
|
|
"dim": ["def dim(self) -> _int: ..."],
|
|
"nonzero": [
|
|
"def nonzero(self, *, as_tuple: Literal[False] = False) -> Tensor: ...",
|
|
"def nonzero(self, *, as_tuple: Literal[True]) -> Tuple[Tensor, ...]: ...",
|
|
],
|
|
"numel": ["def numel(self) -> _int: ..."],
|
|
"ndimension": ["def ndimension(self) -> _int: ..."],
|
|
"nelement": ["def nelement(self) -> _int: ..."],
|
|
"cuda": [
|
|
"def cuda({}) -> Tensor: ...".format(
|
|
", ".join(
|
|
[
|
|
"self",
|
|
"device: Optional[Union[_device, _int, str]] = None",
|
|
"non_blocking: _bool = False",
|
|
"memory_format: torch.memory_format = torch.preserve_format",
|
|
]
|
|
)
|
|
)
|
|
],
|
|
"xpu": [
|
|
"def xpu({}) -> Tensor: ...".format(
|
|
", ".join(
|
|
[
|
|
"self",
|
|
"device: Optional[Union[_device, _int, str]] = None",
|
|
"non_blocking: _bool = False",
|
|
"memory_format: torch.memory_format = torch.preserve_format",
|
|
]
|
|
)
|
|
)
|
|
],
|
|
"cpu": [
|
|
"def cpu(self, memory_format: torch.memory_format = torch.preserve_format) -> Tensor: ..."
|
|
],
|
|
"numpy": ["def numpy(self, *, force: _bool = False) -> numpy.ndarray: ..."],
|
|
"apply_": ["def apply_(self, callable: Callable) -> Tensor: ..."],
|
|
"map_": [
|
|
"def map_(self, tensor: Tensor, callable: Callable) -> Tensor: ..."
|
|
],
|
|
"map2_": [
|
|
"def map2_(self, x: Tensor, y: Tensor, callable: Callable) -> Tensor: ..."
|
|
],
|
|
"storage": ["def untyped_storage(self) -> UntypedStorage: ..."],
|
|
"storage_type": ["def storage_type(self) -> Storage: ..."],
|
|
"type": [
|
|
"def type(self, dtype: None = None, non_blocking: _bool = False) -> str: ...",
|
|
"def type(self, dtype: Union[str, _dtype], non_blocking: _bool = False) -> Tensor: ...",
|
|
],
|
|
"get_device": ["def get_device(self) -> _int: ..."],
|
|
"contiguous": [
|
|
"def contiguous(self, memory_format=torch.contiguous_format) -> Tensor: ..."
|
|
],
|
|
"has_names": ["def has_names(self) -> _bool: ..."],
|
|
"is_contiguous": [
|
|
"def is_contiguous(self, memory_format=torch.contiguous_format) -> _bool: ..."
|
|
],
|
|
"_is_view": ["def _is_view(self) -> _bool: ..."],
|
|
"is_cpu": ["is_cpu: _bool"],
|
|
"is_cuda": ["is_cuda: _bool"],
|
|
"is_leaf": ["is_leaf: _bool"],
|
|
"is_nested": ["is_nested: _bool"],
|
|
"is_sparse": ["is_sparse: _bool"],
|
|
"is_sparse_csr": ["is_sparse_csr: _bool"],
|
|
"is_quantized": ["is_quantized: _bool"],
|
|
"is_meta": ["is_meta: _bool"],
|
|
"is_mps": ["is_mps: _bool"],
|
|
"is_mtia": ["is_mtia: _bool"],
|
|
"is_maia": ["is_maia: _bool"],
|
|
"is_mkldnn": ["is_mkldnn: _bool"],
|
|
"is_vulkan": ["is_vulkan: _bool"],
|
|
"is_ipu": ["is_ipu: _bool"],
|
|
"storage_offset": ["def storage_offset(self) -> Union[_int, SymInt]: ..."],
|
|
"to": [
|
|
(
|
|
f"def to(self, {args}, non_blocking: _bool = False, copy: _bool = False, *, "
|
|
"memory_format: Optional[torch.memory_format] = None) -> Tensor: ..."
|
|
)
|
|
for args in [
|
|
"dtype: _dtype",
|
|
"device: Optional[DeviceLikeType] = None, dtype: Optional[_dtype] = None",
|
|
"other: Tensor",
|
|
]
|
|
],
|
|
"item": ["def item(self) -> Number: ..."],
|
|
"copy_": [
|
|
"def copy_(self, src: Tensor, non_blocking: _bool = False) -> Tensor: ..."
|
|
],
|
|
"set_": [
|
|
"def set_(self, storage: Union[Storage, TypedStorage, UntypedStorage], "
|
|
"offset: IntLikeType, size: _symsize, stride: _symsize) -> Tensor: ...",
|
|
"def set_(self, storage: Union[Storage, TypedStorage, UntypedStorage]) -> Tensor: ...",
|
|
],
|
|
"split": [
|
|
"def split(self, split_size: _int, dim: _int = 0) -> Sequence[Tensor]: ...",
|
|
"def split(self, split_size: Tuple[_int, ...], dim: _int = 0) -> Sequence[Tensor]: ...",
|
|
],
|
|
"div": [
|
|
"def div(self, other: Union[Tensor, Number], *, rounding_mode: Optional[str] = None) -> Tensor: ..."
|
|
],
|
|
"div_": [
|
|
"def div_(self, other: Union[Tensor, Number], *, rounding_mode: Optional[str] = None) -> Tensor: ..."
|
|
],
|
|
}
|
|
)
|
|
for binop in ["true_divide", "floor_divide"]:
|
|
for inplace in [False, True]:
|
|
out_suffix = ", *, out: Optional[Tensor] = None"
|
|
if inplace:
|
|
binop += "_"
|
|
out_suffix = ""
|
|
unsorted_tensor_method_hints[binop].append(
|
|
f"def {binop}(self, other: Union[Tensor, Number, torch.SymInt, torch.SymFloat]{out_suffix})"
|
|
" -> Tensor: ..."
|
|
)
|
|
for binop in ["mul"]:
|
|
for inplace in [False, True]:
|
|
out_suffix = ", *, out: Optional[Tensor] = None"
|
|
if inplace:
|
|
binop += "_"
|
|
out_suffix = ""
|
|
unsorted_tensor_method_hints[binop].append(
|
|
f"def {binop}(self, other: Union[Tensor, Number, _complex, torch.SymInt, torch.SymFloat]{out_suffix})"
|
|
" -> Tensor: ..."
|
|
)
|
|
for binop in ["add", "sub"]:
|
|
for inplace in [False, True]:
|
|
out_suffix = ", out: Optional[Tensor] = None"
|
|
if inplace:
|
|
binop += "_"
|
|
out_suffix = ""
|
|
unsorted_tensor_method_hints[binop].append(
|
|
f"def {binop}(self, other: Union[Tensor, Number, _complex, torch.SymInt, torch.SymFloat], "
|
|
f"*, alpha: Optional[Union[Number, _complex]] = 1{out_suffix})"
|
|
" -> Tensor: ..."
|
|
)
|
|
simple_conversions = [
|
|
"byte",
|
|
"char",
|
|
"double",
|
|
"float",
|
|
"half",
|
|
"int",
|
|
"long",
|
|
"short",
|
|
"bool",
|
|
"bfloat16",
|
|
]
|
|
for name in simple_conversions:
|
|
unsorted_tensor_method_hints[name].append(f"def {name}(self) -> Tensor: ...")
|
|
|
|
# pyi tensor methods don't currently include deprecated signatures for some reason
|
|
# TODO: we should probably add them in
|
|
tensor_method_signatures = load_signatures(
|
|
native_functions,
|
|
deprecated_yaml_path,
|
|
method=True,
|
|
skip_deprecated=True,
|
|
pyi=True,
|
|
)
|
|
tensor_method_sig_groups = get_py_torch_functions(
|
|
tensor_method_signatures, method=True
|
|
)
|
|
|
|
for group in sorted(tensor_method_sig_groups, key=lambda g: g.signature.name):
|
|
name = group.signature.name
|
|
unsorted_tensor_method_hints[name] += generate_type_hints(group)
|
|
|
|
structseq = returns_structseq_pyi(group.signature)
|
|
if structseq is not None and not group.signature.deprecated:
|
|
# deprecated structseqs are currently not included for torch functions
|
|
tuple_name, tuple_def = structseq
|
|
if tuple_name in structseqs:
|
|
assert structseqs[tuple_name] == tuple_def
|
|
else:
|
|
structseqs[tuple_name] = tuple_def
|
|
|
|
for op in all_ops:
|
|
name = f"__{op}__"
|
|
unsorted_tensor_method_hints[name] += sig_for_ops(name)
|
|
|
|
tensor_method_hints = []
|
|
for name, hints in sorted(unsorted_tensor_method_hints.items()):
|
|
if len(hints) > 1:
|
|
hints = ["@overload\n" + h for h in hints]
|
|
docstr = docstrs.get(f"torch._C.TensorBase.{name}")
|
|
if docstr is not None:
|
|
hints = [add_docstr_to_hint(docstr, h) for h in hints]
|
|
tensor_method_hints += hints
|
|
|
|
# TODO: Missing type hints for nn
|
|
|
|
# Generate structseq definitions
|
|
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
|
|
|
structseq_defs = [f"{defn}\n" for defn in structseqs.values()]
|
|
|
|
# Generate type signatures for legacy classes
|
|
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
|
|
|
legacy_storage_base_hints = ["class StorageBase(object): ..."]
|
|
|
|
legacy_class_hints = []
|
|
for c in (
|
|
"DoubleTensor",
|
|
"FloatTensor",
|
|
"BFloat16Tensor",
|
|
"LongTensor",
|
|
"IntTensor",
|
|
"ShortTensor",
|
|
"HalfTensor",
|
|
"CharTensor",
|
|
"ByteTensor",
|
|
"BoolTensor",
|
|
):
|
|
legacy_class_hints.append(f"class {c}(Tensor): ...")
|
|
|
|
# Generate type signatures for dtype classes
|
|
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
|
|
|
# TODO: don't explicitly list dtypes here; get it from canonical
|
|
# source
|
|
dtype_class_hints = [
|
|
f"{n}: dtype = ..."
|
|
for n in [
|
|
"float32",
|
|
"float",
|
|
"float64",
|
|
"double",
|
|
"float16",
|
|
"bfloat16",
|
|
"float8_e4m3fn",
|
|
"float8_e4m3fnuz",
|
|
"float8_e5m2",
|
|
"float8_e5m2fnuz",
|
|
"half",
|
|
"uint8",
|
|
"uint16",
|
|
"uint32",
|
|
"uint64",
|
|
"int8",
|
|
"int16",
|
|
"short",
|
|
"int32",
|
|
"int",
|
|
"int64",
|
|
"long",
|
|
"complex32",
|
|
"complex64",
|
|
"chalf",
|
|
"cfloat",
|
|
"complex128",
|
|
"cdouble",
|
|
"quint8",
|
|
"qint8",
|
|
"qint32",
|
|
"bool",
|
|
"quint4x2",
|
|
"quint2x4",
|
|
"bits1x8",
|
|
"bits2x4",
|
|
"bits4x2",
|
|
"bits8",
|
|
"bits16",
|
|
]
|
|
]
|
|
|
|
    # Generate __all__ directive
    # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

    # Include only the functions that have hints, to prevent undefined
    # symbols from being included in the `__all__` directive.
    hinted_function_names = [
        name for name, hint in unsorted_function_hints.items() if hint
    ]
    all_symbols = sorted(list(structseqs.keys()) + hinted_function_names)
    all_directive = pformat(all_symbols, width=100, compact=True).split("\n")
    all_directive[0] = f"__all__ = {all_directive[0]}"

    # Dispatch key hints
    # ~~~~~~~~~~~~~~~~~~
    dispatch_key_hints = [f"{d.name}: DispatchKey = ..." for d in DispatchKey]
    torch_dispatch_mode_key_hints = [
        f"{k.name}: _TorchDispatchModeKey = ..." for k in _TorchDispatchModeKey
    ]

    # Tags Enum type hints
    # ~~~~~~~~~~~~~~~~~~~~

    tag_names = sorted(parse_tags_yaml(tags_yaml_path))
    tag_attributes = "\n".join(
        f"{name}: _int = {index}" for index, name in enumerate(tag_names)
    )

    # Write out the stub
    # ~~~~~~~~~~~~~~~~~~

    env = {
        "structseq_defs": structseq_defs,
        "function_hints": function_hints,
        "tensor_method_hints": tensor_method_hints,
        "legacy_class_hints": legacy_class_hints,
        "legacy_storage_base_hints": legacy_storage_base_hints,
        "dtype_class_hints": dtype_class_hints,
        "dispatch_key_hints": dispatch_key_hints,
        "torch_dispatch_mode_key_hints": torch_dispatch_mode_key_hints,
        "all_directive": all_directive,
        "tag_attributes": tag_attributes,
    }
    fm.write_with_template(
        "torch/_C/__init__.pyi",
        "torch/_C/__init__.pyi.in",
        lambda: env,
    )
    fm.write_with_template(
        "torch/_C/_VariableFunctions.pyi",
        "torch/_C/_VariableFunctions.pyi.in",
        lambda: env,
    )
    fm.write_with_template(
        "torch/_VF.pyi",
        "torch/_C/_VariableFunctions.pyi.in",
        lambda: env,
    )
    fm.write_with_template(
        "torch/return_types.pyi",
        "torch/_C/return_types.pyi.in",
        lambda: env,
    )
    gen_nn_functional(fm)


def main() -> None:
    parser = argparse.ArgumentParser(description="Generate type stubs for PyTorch")
    parser.add_argument(
        "--native-functions-path",
        metavar="NATIVE",
        default="aten/src/ATen/native/native_functions.yaml",
        help="path to native_functions.yaml",
    )
    parser.add_argument(
        "--tags-path",
        metavar="TAGS",
        default="aten/src/ATen/native/tags.yaml",
        help="path to tags.yaml",
    )
    parser.add_argument(
        "--deprecated-functions-path",
        metavar="DEPRECATED",
        default="tools/autograd/deprecated.yaml",
        help="path to deprecated.yaml",
    )
    parser.add_argument(
        "--out", metavar="OUT", default=".", help="path to output directory"
    )
    args = parser.parse_args()
    fm = FileManager(install_dir=args.out, template_dir=".", dry_run=False)
    gen_pyi(
        args.native_functions_path, args.tags_path, args.deprecated_functions_path, fm
    )


if __name__ == "__main__":
    main()
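
# Typical invocation (from the repository root), using the defaults defined in main():
#
#     python -m tools.pyi.gen_pyi \
#         --native-functions-path aten/src/ATen/native/native_functions.yaml \
#         --tags-path aten/src/ATen/native/tags.yaml \
#         --deprecated-functions-path tools/autograd/deprecated.yaml
#
# The module path shown above is an assumption based on this file's package-style
# imports (tools.autograd...); the flags and their defaults come from main().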