mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 12:54:11 +08:00
create type hint stub files for module torch (#12500)
Summary: We have: - This is an initial stab at creating a type stub `torch/__init__.pyi` . - This is only tested on Python 3, since that's the only Python version mypy works on. - So far, we only aim at doing this for torch functions and torch.Tensor. - Quite a few methods and functions have to be typed manually. These are done in `torch/__init__.pyi.in` For me, PyCharm (the non-paid one) didn't seem to indicate errors in the .pyi when opening and seemed to be able to get the type hint for the few functions I tried, but I don't use PyCharm for my usual PyTorch activities, so I didn't extensively try this out. An example of a generated PYI is at [this gist](https://gist.github.com/ezyang/bf9b6a5fa8827c52152858169bcb61b1). Pull Request resolved: https://github.com/pytorch/pytorch/pull/12500 Differential Revision: D13695553 Pulled By: ezyang fbshipit-source-id: 4566c71913ede4e4c23ebc4a72c17151f94e8e21
This commit is contained in:
committed by
Facebook Github Bot
parent
3b337e7892
commit
6a6983ed7f
2
.gitignore
vendored
2
.gitignore
vendored
@ -35,11 +35,13 @@ test/data/gpu_tensors.pt
|
||||
test/data/legacy_modules.t7
|
||||
test/data/legacy_serialized.pt
|
||||
test/data/linear.pt
|
||||
test/generated_type_hints_smoketest.py
|
||||
test/htmlcov
|
||||
test/cpp_extensions/install/
|
||||
third_party/build/
|
||||
tools/shared/_utils_internal.py
|
||||
torch.egg-info/
|
||||
torch/__init__.pyi
|
||||
torch/csrc/autograd/generated/*
|
||||
torch/csrc/cudnn/cuDNN.cpp
|
||||
torch/csrc/generated
|
||||
|
@ -34,6 +34,10 @@ if [[ "$BUILD_ENVIRONMENT" != *ppc64le* ]]; then
|
||||
|
||||
# TODO: move this to Docker
|
||||
pip install -q hypothesis --user
|
||||
|
||||
# mypy will fail to install on Python <3.4. In that case,
|
||||
# we just won't run these tests.
|
||||
pip install mypy --user || true
|
||||
fi
|
||||
|
||||
# DANGER WILL ROBINSON. The LD_PRELOAD here could cause you problems
|
||||
|
1
setup.py
1
setup.py
@ -731,6 +731,7 @@ if __name__ == '__main__':
|
||||
entry_points=entry_points,
|
||||
package_data={
|
||||
'torch': [
|
||||
'__init__.pyi',
|
||||
'lib/*.so*',
|
||||
'lib/*.dylib*',
|
||||
'lib/*.dll',
|
||||
|
@ -42,6 +42,7 @@ TESTS = [
|
||||
'thd_distributed',
|
||||
'torch',
|
||||
'type_info',
|
||||
'type_hints',
|
||||
'utils',
|
||||
]
|
||||
|
||||
|
167
test/test_type_hints.py
Normal file
167
test/test_type_hints.py
Normal file
@ -0,0 +1,167 @@
|
||||
from __future__ import print_function
|
||||
import unittest
|
||||
from common_utils import TestCase, run_tests, download_file
|
||||
import tempfile
|
||||
import torch
|
||||
import re
|
||||
import os
|
||||
import sys
|
||||
import subprocess
|
||||
import inspect
|
||||
|
||||
try:
|
||||
import mypy
|
||||
HAVE_MYPY = True
|
||||
except ImportError:
|
||||
HAVE_MYPY = False
|
||||
|
||||
|
||||
def get_examples_from_docstring(docstr):
|
||||
"""
|
||||
Extracts all runnable python code from the examples
|
||||
in docstrings; returns a list of lines.
|
||||
"""
|
||||
# TODO: Figure out if there's a way to use doctest directly to
|
||||
# implement this
|
||||
example_file_lines = []
|
||||
# the detection is a bit hacky because there isn't a nice way of detecting
|
||||
# where multiline commands end. Thus we keep track of how far we got in beginning
|
||||
# and continue to add lines until we have a compileable Python statement.
|
||||
exampleline_re = re.compile(r"^\s+(?:>>>|\.\.\.) (.*)$")
|
||||
beginning = ""
|
||||
for l in docstr.split('\n'):
|
||||
if beginning:
|
||||
m = exampleline_re.match(l)
|
||||
if m:
|
||||
beginning += m.group(1)
|
||||
else:
|
||||
beginning += l
|
||||
else:
|
||||
m = exampleline_re.match(l)
|
||||
if m:
|
||||
beginning += m.group(1)
|
||||
if beginning:
|
||||
complete = True
|
||||
try:
|
||||
compile(beginning, "", "exec")
|
||||
except SyntaxError:
|
||||
complete = False
|
||||
if complete:
|
||||
# found one
|
||||
example_file_lines += beginning.split('\n')
|
||||
beginning = ""
|
||||
else:
|
||||
beginning += "\n"
|
||||
return [' ' + l for l in example_file_lines]
|
||||
|
||||
|
||||
def get_all_examples():
|
||||
"""get_all_examples() -> str
|
||||
|
||||
This function grabs (hopefully all) examples from the torch documentation
|
||||
strings and puts them in one nonsensical module returned as a string.
|
||||
"""
|
||||
blacklist = {"_np"}
|
||||
allexamples = ""
|
||||
|
||||
example_file_lines = [
|
||||
"import torch",
|
||||
"import torch.nn.functional as F",
|
||||
"import math # type: ignore", # mypy complains about floats where SupportFloat is expected
|
||||
"import numpy # type: ignore",
|
||||
"import io # type: ignore",
|
||||
"import itertools # type: ignore",
|
||||
"",
|
||||
# for requires_grad_ example
|
||||
# NB: We are parsing this file as Python 2, so we must use
|
||||
# Python 2 type annotation syntax
|
||||
"def preprocess(inp):",
|
||||
" # type: (torch.Tensor) -> torch.Tensor",
|
||||
" return inp",
|
||||
]
|
||||
|
||||
for fname in dir(torch):
|
||||
fn = getattr(torch, fname)
|
||||
docstr = inspect.getdoc(fn)
|
||||
if docstr and fname not in blacklist:
|
||||
e = get_examples_from_docstring(docstr)
|
||||
if e:
|
||||
example_file_lines.append("\n\ndef example_torch_{}():".format(fname))
|
||||
example_file_lines += e
|
||||
|
||||
for fname in dir(torch.Tensor):
|
||||
fn = getattr(torch.Tensor, fname)
|
||||
docstr = inspect.getdoc(fn)
|
||||
if docstr and fname not in blacklist:
|
||||
e = get_examples_from_docstring(docstr)
|
||||
if e:
|
||||
example_file_lines.append("\n\ndef example_torch_tensor_{}():".format(fname))
|
||||
example_file_lines += e
|
||||
|
||||
return "\n".join(example_file_lines)
|
||||
|
||||
|
||||
class TestTypeHints(TestCase):
|
||||
@unittest.skipIf(sys.version_info[0] == 2, "no type hints for Python 2")
|
||||
@unittest.skipIf(not HAVE_MYPY, "need mypy")
|
||||
def test_doc_examples(self):
|
||||
"""
|
||||
Run documentation examples through mypy.
|
||||
"""
|
||||
fn = os.path.join(os.path.dirname(__file__), 'generated_type_hints_smoketest.py')
|
||||
with open(fn, "w") as f:
|
||||
print(get_all_examples(), file=f)
|
||||
|
||||
# OK, so here's the deal. mypy treats installed packages
|
||||
# and local modules differently: if a package is installed,
|
||||
# mypy will refuse to use modules from that package for type
|
||||
# checking unless the module explicitly says that it supports
|
||||
# type checking. (Reference:
|
||||
# https://mypy.readthedocs.io/en/latest/running_mypy.html#missing-imports
|
||||
# )
|
||||
#
|
||||
# Now, PyTorch doesn't support typechecking, and we shouldn't
|
||||
# claim that it supports typechecking (it doesn't.) However, not
|
||||
# claiming we support typechecking is bad for this test, which
|
||||
# wants to use the partial information we get from the bits of
|
||||
# PyTorch which are typed to check if it typechecks. And
|
||||
# although mypy will work directly if you are working in source,
|
||||
# some of our tests involve installing PyTorch and then running
|
||||
# its tests.
|
||||
#
|
||||
# The guidance we got from Michael Sullivan and Joshua Oreman,
|
||||
# and also independently developed by Thomas Viehmann,
|
||||
# is that we should create a fake directory and add symlinks for
|
||||
# the packages that should typecheck. So that is what we do
|
||||
# here.
|
||||
#
|
||||
# If you want to run mypy by hand, and you run from PyTorch
|
||||
# root directory, it should work fine to skip this step (since
|
||||
# mypy will preferentially pick up the local files first). The
|
||||
# temporary directory here is purely needed for CI. For this
|
||||
# reason, we also still drop the generated file in the test
|
||||
# source folder, for ease of inspection when there are failures.
|
||||
with tempfile.TemporaryDirectory() as tmp_dir:
|
||||
try:
|
||||
os.symlink(
|
||||
os.path.dirname(torch.__file__),
|
||||
os.path.join(tmp_dir, 'torch'),
|
||||
target_is_directory=True
|
||||
)
|
||||
except OSError:
|
||||
raise unittest.SkipTest('cannot symlink')
|
||||
try:
|
||||
subprocess.run([
|
||||
sys.executable,
|
||||
'-mmypy',
|
||||
'--follow-imports', 'silent',
|
||||
'--check-untyped-defs',
|
||||
os.path.abspath(fn)],
|
||||
cwd=tmp_dir,
|
||||
check=True)
|
||||
except subprocess.CalledProcessError as e:
|
||||
raise AssertionError("mypy failed. Look above this error for mypy's output.")
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
run_tests()
|
@ -182,33 +182,49 @@ def should_generate_python_binding(declaration):
|
||||
return True
|
||||
|
||||
|
||||
def gen_py_variable_methods(out, declarations, template_path):
|
||||
PY_VARIABLE_METHODS_CPP = CodeTemplate.from_file(template_path + '/python_variable_methods.cpp')
|
||||
PY_VARIABLE_DISPATCH_H = CodeTemplate.from_file(template_path + '/python_variable_methods_dispatch.h')
|
||||
|
||||
def get_py_variable_methods(declarations):
|
||||
"""
|
||||
Get declarations (grouped by name) which should be generated
|
||||
as methods on Tensor.
|
||||
"""
|
||||
def should_bind(declaration):
|
||||
return (should_generate_python_binding(declaration) and
|
||||
declaration['mode'] != 'NN' and
|
||||
declaration.get('python_module') != 'nn' and
|
||||
'Tensor' in declaration['method_of'])
|
||||
|
||||
py_variable_methods = group_declarations_by_name(declarations, should_bind)
|
||||
return group_declarations_by_name(declarations, should_bind)
|
||||
|
||||
|
||||
def gen_py_variable_methods(out, declarations, template_path):
|
||||
PY_VARIABLE_METHODS_CPP = CodeTemplate.from_file(template_path + '/python_variable_methods.cpp')
|
||||
PY_VARIABLE_DISPATCH_H = CodeTemplate.from_file(template_path + '/python_variable_methods_dispatch.h')
|
||||
|
||||
py_variable_methods = get_py_variable_methods(declarations)
|
||||
|
||||
env = create_python_bindings(py_variable_methods, True)
|
||||
write(out, 'python_variable_methods.cpp', PY_VARIABLE_METHODS_CPP, env)
|
||||
write(out, 'python_variable_methods_dispatch.h', PY_VARIABLE_DISPATCH_H, env)
|
||||
|
||||
|
||||
def get_py_nn_functions(declarations):
|
||||
"""
|
||||
Get declarations (grouped by name) which should be generated
|
||||
as functions in the "nn" module.
|
||||
"""
|
||||
def should_bind(declaration):
|
||||
return (should_generate_python_binding(declaration) and
|
||||
(declaration['mode'] == 'NN' or declaration.get('python_module') == 'nn'))
|
||||
|
||||
return group_declarations_by_name(declarations, should_bind)
|
||||
|
||||
|
||||
def gen_py_nn_functions(out, declarations, template_path):
|
||||
PY_NN_FUNCTIONS_CPP = CodeTemplate.from_file(template_path + '/python_nn_functions.cpp')
|
||||
PY_NN_FUNCTIONS_H = CodeTemplate.from_file(template_path + '/python_nn_functions.h')
|
||||
PY_NN_DISPATCH_H = CodeTemplate.from_file(template_path + '/python_nn_functions_dispatch.h')
|
||||
|
||||
def should_bind(declaration):
|
||||
return (should_generate_python_binding(declaration) and
|
||||
(declaration['mode'] == 'NN' or declaration.get('python_module') == 'nn'))
|
||||
|
||||
py_nn_functions = group_declarations_by_name(declarations, should_bind)
|
||||
py_nn_functions = get_py_nn_functions(declarations)
|
||||
|
||||
env = create_python_bindings(py_nn_functions, has_self=False, is_module=True)
|
||||
write(out, 'python_nn_functions.cpp', PY_NN_FUNCTIONS_CPP, env)
|
||||
@ -216,17 +232,25 @@ def gen_py_nn_functions(out, declarations, template_path):
|
||||
write(out, 'python_nn_functions_dispatch.h', PY_NN_DISPATCH_H, env)
|
||||
|
||||
|
||||
def gen_py_torch_functions(out, declarations, template_path):
|
||||
PY_TORCH_FUNCTIONS_CPP = CodeTemplate.from_file(template_path + '/python_torch_functions.cpp')
|
||||
PY_TORCH_DISPATCH_H = CodeTemplate.from_file(template_path + '/python_torch_functions_dispatch.h')
|
||||
|
||||
def get_py_torch_functions(declarations):
|
||||
"""
|
||||
Get declarations (grouped by name) which should be generated
|
||||
as functions in the "torch" module.
|
||||
"""
|
||||
def should_bind(declaration):
|
||||
return (should_generate_python_binding(declaration) and
|
||||
declaration['mode'] != 'NN' and
|
||||
declaration.get('python_module') != 'nn' and
|
||||
'namespace' in declaration['method_of'])
|
||||
|
||||
py_torch_functions = group_declarations_by_name(declarations, should_bind)
|
||||
return group_declarations_by_name(declarations, should_bind)
|
||||
|
||||
|
||||
def gen_py_torch_functions(out, declarations, template_path):
|
||||
PY_TORCH_FUNCTIONS_CPP = CodeTemplate.from_file(template_path + '/python_torch_functions.cpp')
|
||||
PY_TORCH_DISPATCH_H = CodeTemplate.from_file(template_path + '/python_torch_functions_dispatch.h')
|
||||
|
||||
py_torch_functions = get_py_torch_functions(declarations)
|
||||
|
||||
env = create_python_bindings(py_torch_functions, has_self=False)
|
||||
write(out, 'python_torch_functions.cpp', PY_TORCH_FUNCTIONS_CPP, env)
|
||||
@ -800,7 +824,16 @@ def sort_declarations(grouped_decls):
|
||||
|
||||
|
||||
def get_python_signature(declaration, include_out):
|
||||
# Compute the Python function signature for argument parsing
|
||||
# Compute the Python function signature for argument parsing,
|
||||
# as specified in torch/csrc/utils/python_arg_parser.h. WARNING:
|
||||
# this is NOT the same type signature as specified by PEP 484
|
||||
# as understood by mypy; our format was independently developed
|
||||
# and has some quirks to make it more suitable specifically
|
||||
# for error parsing.
|
||||
#
|
||||
# For a translation to mypy-valid type signatures, see
|
||||
# tools/gen_pyi.py. If you change any logic here, please
|
||||
# check that file too.
|
||||
py_formal_args = []
|
||||
output_args = []
|
||||
type_args = []
|
||||
|
@ -14,6 +14,12 @@ except ImportError:
|
||||
from tools.shared.module_loader import import_module
|
||||
CodeTemplate = import_module('code_template', 'aten/src/ATen/code_template.py').CodeTemplate
|
||||
|
||||
# You should use these lines, rather than doing it manually.
|
||||
# Especially if you see this error!
|
||||
#
|
||||
# File "/usr/local/lib/python2.7/dist-packages/yaml/__init__.py", line 69, in load
|
||||
# loader = Loader(stream)
|
||||
# TypeError: 'module' object is not callable
|
||||
try:
|
||||
# use faster C loader if available
|
||||
from yaml import CLoader as YamlLoader
|
||||
|
0
tools/pyi/__init__.py
Normal file
0
tools/pyi/__init__.py
Normal file
529
tools/pyi/gen_pyi.py
Normal file
529
tools/pyi/gen_pyi.py
Normal file
@ -0,0 +1,529 @@
|
||||
from __future__ import print_function
|
||||
import multiprocessing
|
||||
import sys
|
||||
import os
|
||||
import inspect
|
||||
import collections
|
||||
import yaml
|
||||
import types
|
||||
import re
|
||||
import argparse
|
||||
|
||||
from ..autograd.utils import YamlLoader, CodeTemplate, write
|
||||
from ..autograd.gen_python_functions import get_py_torch_functions, get_py_variable_methods
|
||||
from ..autograd.gen_autograd import load_aten_declarations
|
||||
|
||||
"""
|
||||
This module implements generation of type stubs for PyTorch,
|
||||
enabling use of autocomplete in IDEs like PyCharm, which otherwise
|
||||
don't understand C extension modules.
|
||||
|
||||
At the moment, this module only handles type stubs for torch and
|
||||
torch.Tensor. It should eventually be expanded to cover all functions
|
||||
which come are autogenerated.
|
||||
|
||||
Here's our general strategy:
|
||||
|
||||
- We start off with a hand-written __init__.pyi.in file. This
|
||||
file contains type definitions for everything we cannot automatically
|
||||
generate, including pure Python definitions directly in __init__.py
|
||||
(the latter case should be pretty rare).
|
||||
|
||||
- We go through automatically bound functions based on the
|
||||
type information recorded in Declarations.yaml and
|
||||
generate type hints for them (generate_type_hints)
|
||||
|
||||
There are a number of type hints which we've special-cased;
|
||||
read gen_pyi for the gory details.
|
||||
"""
|
||||
|
||||
# TODO: Consider defining some aliases for our Union[...] types, to make
|
||||
# the stubs to read on the human eye.
|
||||
|
||||
needed_modules = set()
|
||||
|
||||
FACTORY_PARAMS = "dtype: Optional[_dtype]=None, device: Union[_device, str, None]=None, requires_grad: bool=False"
|
||||
|
||||
# this could be more precise w.r.t list contents etc. How to do Ellipsis?
|
||||
INDICES = "indices: Union[None, _int, slice, Tensor, List, Tuple]"
|
||||
|
||||
blacklist = [
|
||||
'__init_subclass__',
|
||||
'__new__',
|
||||
'__subclasshook__',
|
||||
'clamp',
|
||||
'clamp_',
|
||||
'device',
|
||||
'grad',
|
||||
'requires_grad',
|
||||
'range',
|
||||
# defined in functional
|
||||
'einsum',
|
||||
# reduction argument; these bindings don't make sense
|
||||
'ctc_loss',
|
||||
'cosine_embedding_loss',
|
||||
'hinge_embedding_loss',
|
||||
'kl_div',
|
||||
'margin_ranking_loss',
|
||||
'triplet_margin_loss',
|
||||
# Somehow, these are defined in both _C and in functional. Ick!
|
||||
'broadcast_tensors',
|
||||
'meshgrid',
|
||||
'cartesian_prod',
|
||||
'norm',
|
||||
'chain_matmul',
|
||||
'stft',
|
||||
'tensordot',
|
||||
'norm',
|
||||
'split',
|
||||
# These are handled specially by python_arg_parser.cpp
|
||||
'add',
|
||||
'add_',
|
||||
'add_out',
|
||||
'sub',
|
||||
'sub_',
|
||||
'sub_out',
|
||||
'mul',
|
||||
'mul_',
|
||||
'mul_out',
|
||||
'div',
|
||||
'div_',
|
||||
'div_out',
|
||||
]
|
||||
|
||||
|
||||
def type_to_python(typename, size=None):
|
||||
"""type_to_python(typename: str, size: str) -> str
|
||||
|
||||
Transforms a Declarations.yaml type name into a Python type specification
|
||||
as used for type hints.
|
||||
"""
|
||||
typename = typename.replace(' ', '') # normalize spaces, e.g., 'Generator *'
|
||||
|
||||
# Disambiguate explicitly sized int/tensor lists from implicitly
|
||||
# sized ones. These permit non-list inputs too. (IntList[] and
|
||||
# TensorList[] are not real types; this is just for convenience.)
|
||||
if typename in {'IntList', 'TensorList'} and size is not None:
|
||||
typename += '[]'
|
||||
|
||||
typename = {
|
||||
'Device': 'Union[_device, str, None]',
|
||||
'Generator*': 'Generator',
|
||||
'IntegerTensor': 'Tensor',
|
||||
'Scalar': 'Number',
|
||||
'ScalarType': '_dtype',
|
||||
'Storage': 'Storage',
|
||||
'BoolTensor': 'Tensor',
|
||||
'IndexTensor': 'Tensor',
|
||||
'SparseTensorRef': 'Tensor',
|
||||
'Tensor': 'Tensor',
|
||||
'IntList': '_size',
|
||||
'IntList[]': 'Union[_int, _size]',
|
||||
'TensorList': 'Union[Tuple[Tensor, ...], List[Tensor]]',
|
||||
'TensorList[]': 'Union[Tensor, Tuple[Tensor, ...], List[Tensor]]',
|
||||
'bool': 'bool',
|
||||
'double': '_float',
|
||||
'int64_t': '_int',
|
||||
'accreal': 'Number',
|
||||
'real': 'Number',
|
||||
'void*': '_int', # data_ptr
|
||||
'void': 'None',
|
||||
'std::string': 'str',
|
||||
}[typename]
|
||||
|
||||
return typename
|
||||
|
||||
|
||||
def arg_to_type_hint(arg):
|
||||
"""arg_to_type_hint(arg) -> str
|
||||
|
||||
This takes one argument in a Declarations and returns a string
|
||||
representing this argument in a type hint signature.
|
||||
"""
|
||||
name = arg['name']
|
||||
if name == 'from': # from is a Python keyword...
|
||||
name += '_'
|
||||
typename = type_to_python(arg['dynamic_type'], arg.get('size'))
|
||||
if arg.get('is_nullable'):
|
||||
typename = 'Optional[' + typename + ']'
|
||||
if 'default' in arg:
|
||||
default = arg['default']
|
||||
if default == 'nullptr':
|
||||
default = None
|
||||
elif default == 'c10::nullopt':
|
||||
default = None
|
||||
elif isinstance(default, str) and default.startswith('{') and default.endswith('}'):
|
||||
if arg['dynamic_type'] == 'Tensor' and default == '{}':
|
||||
default = None
|
||||
elif arg['dynamic_type'] == 'IntList':
|
||||
default = '(' + default[1:-1] + ')'
|
||||
else:
|
||||
raise Exception("Unexpected default constructor argument of type {}".format(arg['dynamic_type']))
|
||||
default = '={}'.format(default)
|
||||
else:
|
||||
default = ''
|
||||
return name + ': ' + typename + default
|
||||
|
||||
|
||||
binary_ops = ('add', 'sub', 'mul', 'div', 'pow', 'lshift', 'rshift', 'mod', 'truediv',
|
||||
'matmul',
|
||||
'radd', 'rmul', # reverse arithmetic
|
||||
'and', 'or', 'xor', # logic
|
||||
'iadd', 'iand', 'idiv', 'ilshift', 'imul',
|
||||
'ior', 'irshift', 'isub', 'itruediv', 'ixor', # inplace ops
|
||||
)
|
||||
comparison_ops = ('eq', 'ne', 'ge', 'gt', 'lt', 'le')
|
||||
unary_ops = ('neg', 'abs', 'invert')
|
||||
to_py_type_ops = ('bool', 'float', 'long', 'index', 'int', 'nonzero')
|
||||
all_ops = binary_ops + comparison_ops + unary_ops + to_py_type_ops
|
||||
|
||||
|
||||
def sig_for_ops(opname):
|
||||
"""sig_for_ops(opname : str) -> List[str]
|
||||
|
||||
Returns signatures for operator special functions (__add__ etc.)"""
|
||||
|
||||
# we have to do this by hand, because they are hand-bound in Python
|
||||
|
||||
assert opname.endswith('__') and opname.startswith('__'), "Unexpected op {}".format(opname)
|
||||
|
||||
name = opname[2:-2]
|
||||
if name in binary_ops:
|
||||
return ['def {}(self, other: Any) -> Tensor: ...'.format(opname)]
|
||||
elif name in comparison_ops:
|
||||
# unsafe override https://github.com/python/mypy/issues/5704
|
||||
return ['def {}(self, other: Any) -> Tensor: ... # type: ignore'.format(opname)]
|
||||
elif name in unary_ops:
|
||||
return ['def {}(self) -> Tensor: ...'.format(opname)]
|
||||
elif name in to_py_type_ops:
|
||||
if name in {'bool', 'float'}:
|
||||
tname = name
|
||||
elif name == 'nonzero':
|
||||
tname = 'bool'
|
||||
else:
|
||||
tname = 'int'
|
||||
if tname in {'float', 'int'}:
|
||||
tname = 'builtins.' + tname
|
||||
return ['def {}(self) -> {}: ...'.format(opname, tname)]
|
||||
else:
|
||||
raise Exception("unknown op", opname)
|
||||
|
||||
|
||||
def generate_type_hints(fname, decls, is_tensor=False):
|
||||
"""generate_type_hints(fname, decls, is_tensor=False)
|
||||
|
||||
Generates type hints for the declarations pertaining to the function
|
||||
:attr:`fname`. attr:`decls` are the declarations from the parsed
|
||||
Declarations.yaml.
|
||||
The :attr:`is_tensor` flag indicates whether we are parsing
|
||||
members of the Tensor class (true) or functions in the
|
||||
`torch` namespace (default, false).
|
||||
|
||||
This function currently encodes quite a bit about the semantics of
|
||||
the translation C++ -> Python.
|
||||
"""
|
||||
if fname in blacklist:
|
||||
return []
|
||||
|
||||
type_hints = []
|
||||
dnames = ([d['name'] for d in decls])
|
||||
has_out = fname + '_out' in dnames
|
||||
|
||||
if has_out:
|
||||
decls = [d for d in decls if d['name'] != fname + '_out']
|
||||
|
||||
for decl in decls:
|
||||
render_kw_only_separator = True # whether we add a '*' if we see a keyword only argument
|
||||
python_args = []
|
||||
|
||||
has_tensor_options = 'TensorOptions' in [a['dynamic_type'] for a in decl['arguments']]
|
||||
|
||||
for a in decl['arguments']:
|
||||
if a['dynamic_type'] != 'TensorOptions':
|
||||
if a.get('kwarg_only', False) and render_kw_only_separator:
|
||||
python_args.append('*')
|
||||
render_kw_only_separator = False
|
||||
python_args.append(arg_to_type_hint(a))
|
||||
|
||||
if is_tensor:
|
||||
if 'self: Tensor' in python_args:
|
||||
python_args.remove('self: Tensor')
|
||||
python_args = ['self'] + python_args
|
||||
else:
|
||||
raise Exception("method without self is unexpected")
|
||||
|
||||
if has_out:
|
||||
if render_kw_only_separator:
|
||||
python_args.append('*')
|
||||
render_kw_only_separator = False
|
||||
python_args.append('out: Optional[Tensor]=None')
|
||||
|
||||
if has_tensor_options:
|
||||
if render_kw_only_separator:
|
||||
python_args.append('*')
|
||||
render_kw_only_separator = False
|
||||
python_args += ["dtype: _dtype=None",
|
||||
"layout: layout=strided",
|
||||
"device: Union[_device, str, None]=None",
|
||||
"requires_grad:bool=False"]
|
||||
|
||||
python_args_s = ', '.join(python_args)
|
||||
python_returns = [type_to_python(r['dynamic_type']) for r in decl['returns']]
|
||||
|
||||
if len(python_returns) > 1:
|
||||
python_returns_s = 'Tuple[' + ', '.join(python_returns) + ']'
|
||||
else:
|
||||
python_returns_s = python_returns[0]
|
||||
|
||||
type_hint = "def {}({}) -> {}: ...".format(fname, python_args_s, python_returns_s)
|
||||
numargs = len(decl['arguments'])
|
||||
vararg_pos = int(is_tensor)
|
||||
have_vararg_version = (numargs > vararg_pos and
|
||||
decl['arguments'][vararg_pos]['dynamic_type'] in {'IntList', 'TensorList'} and
|
||||
(numargs == vararg_pos + 1 or python_args[vararg_pos + 1] == '*') and
|
||||
(not is_tensor or decl['arguments'][0]['name'] == 'self'))
|
||||
|
||||
type_hints.append(type_hint)
|
||||
|
||||
if have_vararg_version:
|
||||
# Two things come into play here: PyTorch has the "magic" that if the first and only positional argument
|
||||
# is an IntList or TensorList, it will be used as a vararg variant.
|
||||
# The following outputs the vararg variant, the "pass a list variant" is output above.
|
||||
# The other thing is that in Python, the varargs are annotated with the element type, not the list type.
|
||||
typelist = decl['arguments'][vararg_pos]['dynamic_type']
|
||||
if typelist == 'IntList':
|
||||
vararg_type = '_int'
|
||||
else:
|
||||
vararg_type = 'Tensor'
|
||||
# replace first argument and eliminate '*' if present
|
||||
python_args = ((['self'] if is_tensor else []) + ['*' + decl['arguments'][vararg_pos]['name'] +
|
||||
': ' + vararg_type] + python_args[vararg_pos + 2:])
|
||||
python_args_s = ', '.join(python_args)
|
||||
type_hint = "def {}({}) -> {}: ...".format(fname, python_args_s, python_returns_s)
|
||||
type_hints.append(type_hint)
|
||||
|
||||
return type_hints
|
||||
|
||||
|
||||
def gen_pyi(declarations_path, out):
|
||||
"""gen_pyi()
|
||||
|
||||
This function generates a pyi file for torch.
|
||||
"""
|
||||
|
||||
# Some of this logic overlaps with generate_python_signature in
|
||||
# tools/autograd/gen_python_functions.py; however, this
|
||||
# function is all about generating mypy type signatures, whereas
|
||||
# the other function generates are custom format for argument
|
||||
# checking. If you are update this, consider if your change
|
||||
# also needs to update the other file.
|
||||
|
||||
# Load information from YAML
|
||||
declarations = load_aten_declarations(declarations_path)
|
||||
|
||||
# Generate type signatures for top-level functions
|
||||
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
||||
|
||||
unsorted_function_hints = collections.defaultdict(list)
|
||||
unsorted_function_hints.update({
|
||||
'set_flush_denormal': ['def set_flush_denormal(mode: bool) -> bool: ...'],
|
||||
'get_default_dtype': ['def get_default_dtype() -> _dtype: ...'],
|
||||
'from_numpy': ['def from_numpy(ndarray) -> Tensor: ...'],
|
||||
'clamp': ["def clamp(self, min: _float=-inf, max: _float=inf,"
|
||||
" *, out: Optional[Tensor]=None) -> Tensor: ..."],
|
||||
'as_tensor': ["def as_tensor(data: Any, dtype: _dtype=None, device: Optional[_device]=None) -> Tensor: ..."],
|
||||
'get_num_threads': ['def get_num_threads() -> _int: ...'],
|
||||
'set_num_threads': ['def set_num_threads(num: _int) -> None: ...'],
|
||||
# These functions are explicitly disabled by
|
||||
# SKIP_PYTHON_BINDINGS because they are hand bound.
|
||||
# Correspondingly, we must hand-write their signatures.
|
||||
'tensor': ["def tensor(data: Any, {}) -> Tensor: ...".format(FACTORY_PARAMS)],
|
||||
'sparse_coo_tensor': ['def sparse_coo_tensor(indices: Tensor, values: Union[Tensor,List],'
|
||||
' size: Optional[_size]=None, *, dtype: Optional[_dtype]=None,'
|
||||
' device: Union[_device, str, None]=None, requires_grad:bool=False) -> Tensor: ...'],
|
||||
'range': ['def range(start: Number, end: Number,'
|
||||
' step: Number=1, *, out: Optional[Tensor]=None, {}) -> Tensor: ...'
|
||||
.format(FACTORY_PARAMS)],
|
||||
'arange': ['def arange(start: Number, end: Number, step: Number, *,'
|
||||
' out: Optional[Tensor]=None, {}) -> Tensor: ...'
|
||||
.format(FACTORY_PARAMS),
|
||||
'def arange(start: Number, end: Number, *, out: Optional[Tensor]=None, {}) -> Tensor: ...'
|
||||
.format(FACTORY_PARAMS),
|
||||
'def arange(end: Number, *, out: Optional[Tensor]=None, {}) -> Tensor: ...'
|
||||
.format(FACTORY_PARAMS)],
|
||||
'randint': ['def randint(low: _int, high: _int, size: _size, *, {}) -> Tensor: ...'
|
||||
.format(FACTORY_PARAMS),
|
||||
'def randint(high: _int, size: _size, *, {}) -> Tensor: ...'
|
||||
.format(FACTORY_PARAMS)],
|
||||
})
|
||||
for binop in ['add', 'sub', 'mul', 'div']:
|
||||
unsorted_function_hints[binop].append(
|
||||
'def {}(input: Union[Tensor, Number],'
|
||||
' other: Union[Tensor, Number],'
|
||||
' *, out: Optional[Tensor]=None) -> Tensor: ...'.format(binop))
|
||||
unsorted_function_hints[binop].append(
|
||||
'def {}(input: Union[Tensor, Number],'
|
||||
' value: Number,'
|
||||
' other: Union[Tensor, Number],'
|
||||
' *, out: Optional[Tensor]=None) -> Tensor: ...'.format(binop))
|
||||
|
||||
function_declarations = get_py_torch_functions(declarations)
|
||||
for name in sorted(function_declarations.keys()):
|
||||
unsorted_function_hints[name] += generate_type_hints(name, function_declarations[name])
|
||||
|
||||
# Generate type signatures for deprecated functions
|
||||
|
||||
# TODO: Maybe we shouldn't generate type hints for deprecated
|
||||
# functions :) However, examples like those addcdiv rely on these.
|
||||
with open('tools/autograd/deprecated.yaml', 'r') as f:
|
||||
deprecated = yaml.load(f, Loader=YamlLoader)
|
||||
for d in deprecated:
|
||||
name, sig = re.match(r"^([^\(]+)\(([^\)]*)", d['name']).groups()
|
||||
sig = ['*' if p.strip() == '*' else p.split() for p in sig.split(',')]
|
||||
sig = ['*' if p == '*' else (p[1] + ': ' + type_to_python(p[0])) for p in sig]
|
||||
unsorted_function_hints[name].append("def {}({}) -> Tensor: ...".format(name, ', '.join(sig)))
|
||||
|
||||
function_hints = []
|
||||
for name, hints in sorted(unsorted_function_hints.items()):
|
||||
if len(hints) > 1:
|
||||
hints = ['@overload\n' + h for h in hints]
|
||||
function_hints += hints
|
||||
|
||||
# Generate type signatures for Tensor methods
|
||||
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
||||
|
||||
unsorted_tensor_method_hints = collections.defaultdict(list)
|
||||
unsorted_tensor_method_hints.update({
|
||||
'size': ['def size(self) -> Size: ...',
|
||||
'def size(self, _int) -> _int: ...'],
|
||||
'stride': ['def stride(self) -> Tuple[_int]: ...',
|
||||
'def stride(self, _int) -> _int: ...'],
|
||||
'new_empty': ['def new_empty(self, size: {}, {}) -> Tensor: ...'.
|
||||
format(type_to_python('IntList'), FACTORY_PARAMS)],
|
||||
'new_ones': ['def new_ones(self, size: {}, {}) -> Tensor: ...'.
|
||||
format(type_to_python('IntList'), FACTORY_PARAMS)],
|
||||
'new_zeros': ['def new_zeros(self, size: {}, {}) -> Tensor: ...'.
|
||||
format(type_to_python('IntList'), FACTORY_PARAMS)],
|
||||
'new_full': ['def new_full(self, size: {}, value: {}, {}) -> Tensor: ...'.
|
||||
format(type_to_python('IntList'), type_to_python('Scalar'), FACTORY_PARAMS)],
|
||||
'new_tensor': ["def new_tensor(self, data: Any, {}) -> Tensor: ...".format(FACTORY_PARAMS)],
|
||||
# clamp has no default values in the Declarations
|
||||
'clamp': ["def clamp(self, min: _float=-inf, max: _float=inf,"
|
||||
" *, out: Optional[Tensor]=None) -> Tensor: ..."],
|
||||
'clamp_': ["def clamp_(self, min: _float=-inf, max: _float=inf) -> Tensor: ..."],
|
||||
'__getitem__': ["def __getitem__(self, {}) -> Tensor: ...".format(INDICES)],
|
||||
'__setitem__': ["def __setitem__(self, {}, val: Union[Tensor, Number])"
|
||||
" -> None: ...".format(INDICES)],
|
||||
'tolist': ['def tolist(self) -> List: ...'],
|
||||
'requires_grad_': ['def requires_grad_(self, mode: bool=True) -> Tensor: ...'],
|
||||
'element_size': ['def element_size(self) -> _int: ...'],
|
||||
'dim': ['def dim(self) -> _int: ...'],
|
||||
'ndimension': ['def ndimension(self) -> _int: ...'],
|
||||
'nelement': ['def nelement(self) -> _int: ...'],
|
||||
'cuda': ['def cuda(self, device: Optional[_device]=None, non_blocking: bool=False) -> Tensor: ...'],
|
||||
'numpy': ['def numpy(self) -> Any: ...'],
|
||||
'apply_': ['def apply_(self, callable: Callable) -> Tensor: ...'],
|
||||
'map_': ['def map_(tensor: Tensor, callable: Callable) -> Tensor: ...'],
|
||||
'copy_': ['def copy_(self, src: Tensor, non_blocking: bool=False) -> Tensor: ...'],
|
||||
'storage': ['def storage(self) -> Storage: ...'],
|
||||
'type': ['def type(self, dtype: Union[None, str, _dtype]=None, non_blocking: bool=False)'
|
||||
' -> Union[str, Tensor]: ...'],
|
||||
'get_device': ['def get_device(self) -> _int: ...'],
|
||||
'is_contiguous': ['def is_contiguous(self) -> bool: ...'],
|
||||
'is_cuda': ['def is_cuda(self) -> bool: ...'],
|
||||
'is_leaf': ['def is_leaf(self) -> bool: ...'],
|
||||
'storage_offset': ['def storage_offset(self) -> _int: ...'],
|
||||
'to': ['def to(self, dtype: _dtype, non_blocking: bool=False, copy: bool=False) -> Tensor: ...',
|
||||
'def to(self, device: Optional[Union[_device, str]]=None, dtype: Optional[_dtype]=None, '
|
||||
'non_blocking: bool=False, copy: bool=False) -> Tensor: ...',
|
||||
'def to(self, other: Tensor, non_blocking: bool=False, copy: bool=False) -> Tensor: ...',
|
||||
],
|
||||
'item': ["def item(self) -> Number: ..."],
|
||||
})
|
||||
for binop in ['add', 'sub', 'mul', 'div']:
|
||||
for inplace in [True, False]:
|
||||
out_suffix = ', *, out: Optional[Tensor]=None'
|
||||
if inplace:
|
||||
name += '_'
|
||||
out_suffix = ''
|
||||
unsorted_tensor_method_hints[name].append(
|
||||
'def {}(self, other: Union[Tensor, Number]{})'
|
||||
' -> Tensor: ...'.format(name, out_suffix))
|
||||
unsorted_tensor_method_hints[name].append(
|
||||
'def {}(self, value: Number,'
|
||||
' other: Union[Tensor, Number]{})'
|
||||
' -> Tensor: ...'.format(name, out_suffix))
|
||||
simple_conversions = ['byte', 'char', 'cpu', 'double', 'float', 'half', 'int', 'long', 'short']
|
||||
for name in simple_conversions:
|
||||
unsorted_tensor_method_hints[name].append('def {}(self) -> Tensor: ...'.format(name))
|
||||
|
||||
tensor_method_declarations = get_py_variable_methods(declarations)
|
||||
for name in sorted(tensor_method_declarations.keys()):
|
||||
unsorted_tensor_method_hints[name] += \
|
||||
generate_type_hints(name, tensor_method_declarations[name], is_tensor=True)
|
||||
|
||||
for op in all_ops:
|
||||
name = '__{}__'.format(op)
|
||||
unsorted_tensor_method_hints[name] += sig_for_ops(name)
|
||||
|
||||
tensor_method_hints = []
|
||||
for name, hints in sorted(unsorted_tensor_method_hints.items()):
|
||||
if len(hints) > 1:
|
||||
hints = ['@overload\n' + h for h in hints]
|
||||
tensor_method_hints += hints
|
||||
|
||||
# TODO: Missing type hints for nn
|
||||
|
||||
# Generate type signatures for legacy classes
|
||||
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
||||
|
||||
# TODO: These are deprecated, maybe we shouldn't type hint them
|
||||
legacy_class_hints = []
|
||||
for c in ('DoubleStorage', 'FloatStorage', 'LongStorage', 'IntStorage',
|
||||
'ShortStorage', 'CharStorage', 'ByteStorage'):
|
||||
legacy_class_hints.append('class {}(Storage): ...'.format(c))
|
||||
|
||||
for c in ('DoubleTensor', 'FloatTensor', 'LongTensor', 'IntTensor',
|
||||
'ShortTensor', 'CharTensor', 'ByteTensor'):
|
||||
legacy_class_hints.append('class {}(Tensor): ...'.format(c))
|
||||
|
||||
# Generate type signatures for dtype classes
|
||||
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
||||
|
||||
# TODO: don't explicitly list dtypes here; get it from canonical
|
||||
# source
|
||||
dtype_class_hints = ['{}: dtype = ...'.format(n)
|
||||
for n in
|
||||
['float32', 'float', 'float64', 'double', 'float16', 'half',
|
||||
'uint8', 'int8', 'int16', 'short', 'int32', 'int', 'int64', 'long',
|
||||
'complex32', 'complex64', 'complex128']]
|
||||
|
||||
# Write out the stub
|
||||
# ~~~~~~~~~~~~~~~~~~
|
||||
|
||||
env = {
|
||||
'function_hints': function_hints,
|
||||
'tensor_method_hints': tensor_method_hints,
|
||||
'legacy_class_hints': legacy_class_hints,
|
||||
'dtype_class_hints': dtype_class_hints,
|
||||
}
|
||||
TORCH_TYPE_STUBS = CodeTemplate.from_file(os.path.join('torch', '__init__.pyi.in'))
|
||||
|
||||
write(out, 'torch/__init__.pyi', TORCH_TYPE_STUBS, env)
|
||||
|
||||
|
||||
def main():
|
||||
parser = argparse.ArgumentParser(
|
||||
description='Generate type stubs for PyTorch')
|
||||
parser.add_argument('--declarations-path', metavar='DECL',
|
||||
default='torch/share/ATen/Declarations.yaml',
|
||||
help='path to Declarations.yaml')
|
||||
parser.add_argument('--out', metavar='OUT',
|
||||
default='.',
|
||||
help='path to output directory')
|
||||
args = parser.parse_args()
|
||||
gen_pyi(args.declarations_path, args.out)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
main()
|
@ -713,8 +713,26 @@ if (BUILD_PYTHON)
|
||||
endif()
|
||||
endif()
|
||||
|
||||
add_custom_target(torch_python_stubs DEPENDS "${TORCH_SRC_DIR}/__init__.pyi")
|
||||
# For Declarations.yaml dependency
|
||||
add_dependencies(torch_python_stubs ATEN_CPU_FILES_GEN_TARGET)
|
||||
add_custom_command(
|
||||
OUTPUT
|
||||
"${TORCH_SRC_DIR}/__init__.pyi"
|
||||
COMMAND
|
||||
${PYCMD} -mtools.pyi.gen_pyi
|
||||
--declarations-path "${CMAKE_BINARY_DIR}/aten/src/ATen/Declarations.yaml"
|
||||
DEPENDS
|
||||
"${CMAKE_BINARY_DIR}/aten/src/ATen/Declarations.yaml"
|
||||
"${TORCH_SRC_DIR}/__init__.pyi.in"
|
||||
WORKING_DIRECTORY
|
||||
"${TORCH_ROOT}"
|
||||
)
|
||||
|
||||
add_library(torch_python SHARED ${TORCH_PYTHON_SRCS})
|
||||
|
||||
add_dependencies(torch_python torch_python_stubs)
|
||||
|
||||
target_link_libraries(torch_python ${TORCH_PYTHON_LINK_LIBRARIES})
|
||||
|
||||
target_compile_definitions(torch_python PRIVATE ${TORCH_PYTHON_COMPILE_DEFINITIONS})
|
||||
|
@ -179,6 +179,7 @@ def set_default_dtype(d):
|
||||
"""
|
||||
_C._set_default_dtype(d)
|
||||
|
||||
# If you edit these imports, please update torch/__init__.py.in as well
|
||||
from .random import set_rng_state, get_rng_state, manual_seed, initial_seed
|
||||
from .serialization import save, load
|
||||
from ._tensor_str import set_printoptions
|
||||
|
106
torch/__init__.pyi.in
Normal file
106
torch/__init__.pyi.in
Normal file
@ -0,0 +1,106 @@
|
||||
# ${generated_comment}
|
||||
|
||||
from typing import List, Tuple, Optional, Union, Any, ContextManager, Callable, overload
|
||||
from torch._six import inf
|
||||
|
||||
import builtins
|
||||
|
||||
# These identifiers are reexported from other modules. These modules
|
||||
# are not mypy-clean yet, so in order to use this stub file usefully
|
||||
# from mypy you will need to specify --follow-imports=silent.
|
||||
# Not all is lost: these imports still enable IDEs like PyCharm to offer
|
||||
# autocomplete.
|
||||
#
|
||||
# Note: Why does the syntax here look so strange? Import visibility
|
||||
# rules in stubs are different from normal Python files! You must use
|
||||
# 'from ... import ... as ...' syntax to cause an identifier to be
|
||||
# exposed (or use a wildcard); regular syntax is not exposed.
|
||||
from .random import set_rng_state as set_rng_state, get_rng_state as get_rng_state, \
|
||||
manual_seed as manual_seed, initial_seed as initial_seed
|
||||
from ._tensor_str import set_printoptions as set_printoptions
|
||||
from .functional import *
|
||||
from .serialization import save as save, load as load
|
||||
from .autograd import no_grad as no_grad, enable_grad as enable_grad, \
|
||||
set_grad_enabled as set_grad_enabled
|
||||
|
||||
class dtype: ...
|
||||
|
||||
class layout: ...
|
||||
|
||||
strided : layout = ...
|
||||
|
||||
# See https://github.com/python/mypy/issues/4146 for why these workarounds
|
||||
# is necessary
|
||||
_int = builtins.int
|
||||
_float = builtins.float
|
||||
|
||||
class device:
|
||||
def __init__(self, device: Union[_int, str, None]=None) -> None: ...
|
||||
|
||||
class Generator: ...
|
||||
|
||||
class Size(tuple): ...
|
||||
|
||||
class Storage: ...
|
||||
|
||||
# See https://github.com/python/mypy/issues/4146 for why these workarounds
|
||||
# is necessary
|
||||
_dtype = dtype
|
||||
_device = device
|
||||
_size = Union[Size, List[_int], Tuple[_int, ...]]
|
||||
|
||||
# Meta-type for "numeric" things; matches our docs
|
||||
Number = Union[builtins.int, builtins.float]
|
||||
|
||||
# TODO: One downside of doing it this way, is direct use of
|
||||
# torch.tensor.Tensor doesn't get type annotations. Nobody
|
||||
# should really do that, so maybe this is not so bad.
|
||||
class Tensor:
|
||||
dtype: _dtype = ...
|
||||
shape: Size = ...
|
||||
device: _device = ...
|
||||
requires_grad: bool = ...
|
||||
grad: Optional[Tensor] = ...
|
||||
|
||||
${tensor_method_hints}
|
||||
|
||||
# Manually defined methods from torch/tensor.py
|
||||
def backward(self, gradient: Optional[Tensor]=None, retain_graph: Optional[bool]=None, create_graph: bool=False) -> None: ...
|
||||
def register_hook(self, hook: Callable) -> Any: ...
|
||||
def retain_grad(self) -> None: ...
|
||||
def is_pinned(self) -> bool: ...
|
||||
def is_shared(self) -> bool: ...
|
||||
def share_memory_(self) -> None: ...
|
||||
# TODO: fill in the types for these, or otherwise figure out some
|
||||
# way to not have to write these out again...
|
||||
def argmax(self, dim=None, keepdim=False): ...
|
||||
def argmin(self, dim=None, keepdim=False): ...
|
||||
def argsort(self, dim=None, descending=False): ...
|
||||
def norm(self, p="fro", dim=None, keepdim=False): ...
|
||||
def stft(self, n_fft, hop_length=None, win_length=None, window=None,
|
||||
center=True, pad_mode='reflect', normalized=False, onesided=True): ...
|
||||
def split(self, split_size, dim=0): ...
|
||||
def index_add(self, dim, index, tensor): ...
|
||||
def index_copy(self, dim, index, tensor): ...
|
||||
def index_fill(self, dim, index, value): ...
|
||||
def scatter(self, dim, index, source): ...
|
||||
def scatter_add(self, dim, index, source): ...
|
||||
def masked_scatter(self, mask, tensor): ...
|
||||
def masked_fill(self, mask, value): ...
|
||||
def unique(self, sorted=True, return_inverse=False, dim=None): ...
|
||||
|
||||
${function_hints}
|
||||
|
||||
${legacy_class_hints}
|
||||
|
||||
${dtype_class_hints}
|
||||
|
||||
# Pure Python functions defined in torch/__init__.py
|
||||
|
||||
def typename(obj) -> str: ...
|
||||
def is_tensor(obj) -> bool: ...
|
||||
def is_storage(obj) -> bool: ...
|
||||
def set_default_tensor_type(type) -> None: ... # ick, what a bad legacy API
|
||||
def set_default_dtype(d : _dtype) -> None: ...
|
||||
def manager_path() -> str: ...
|
||||
def compiled_with_cxx11_abi() -> bool: ...
|
@ -20,6 +20,7 @@
|
||||
|
||||
import itertools
|
||||
import sys
|
||||
import builtins
|
||||
|
||||
|
||||
PY2 = sys.version_info[0] == 2
|
||||
@ -48,7 +49,7 @@ else:
|
||||
if PY2:
|
||||
FileNotFoundError = IOError
|
||||
else:
|
||||
FileNotFoundError = FileNotFoundError
|
||||
FileNotFoundError = builtins.FileNotFoundError
|
||||
|
||||
|
||||
if PY2:
|
||||
@ -71,11 +72,10 @@ def with_metaclass(meta, *bases):
|
||||
|
||||
# A portable way of referring to the generator version of map
|
||||
# in both Python 2 and Python 3.
|
||||
# TODO: Move this into an appropriate utility library.
|
||||
if hasattr(itertools, 'imap'):
|
||||
imap = itertools.imap
|
||||
imap = itertools.imap # type: ignore
|
||||
else:
|
||||
imap = map
|
||||
imap = map # type: ignore
|
||||
|
||||
|
||||
if PY3:
|
||||
|
@ -309,3 +309,13 @@ def _take_tensors(tensors, size_limit):
|
||||
for buf, _ in buf_dict.values():
|
||||
if len(buf) > 0:
|
||||
yield buf
|
||||
|
||||
|
||||
# annotation decorator to get annotations in a way that is compatible
|
||||
# with both Python 2 and 3
|
||||
def annotate(ret, **kwargs):
|
||||
def dec(fun):
|
||||
fun.__annotations__ = dict(kwargs)
|
||||
fun.__annotations__['return'] = ret
|
||||
return fun
|
||||
return dec
|
||||
|
@ -36,6 +36,10 @@ static std::unordered_map<std::string, ParameterType> type_map = {
|
||||
// numbers to bind to Tensors. Some binary ops have separate Tensor and Scalar
|
||||
// overloads and binding to the Tensor overload with a number of a different
|
||||
// type will trigger a type error.
|
||||
//
|
||||
// If you modify this, you will need to adjust the blacklist in
|
||||
// tools/pyi/gen_pyi.py (and add hardcoded signatures for these
|
||||
// functions.)
|
||||
static bool should_allow_numbers_as_tensors(const std::string& name) {
|
||||
static std::unordered_set<std::string> allowed = {
|
||||
"add", "add_", "add_out",
|
||||
|
@ -4,8 +4,11 @@ from torch._six import inf
|
||||
from torch._C import _add_docstr
|
||||
from operator import mul
|
||||
from functools import reduce
|
||||
from collections import Iterable
|
||||
from torch._utils import annotate
|
||||
from itertools import product
|
||||
import math
|
||||
from typing import Optional, Tuple, List, Union
|
||||
import warnings
|
||||
|
||||
__all__ = [
|
||||
|
@ -366,7 +366,7 @@ def load(f, map_location=None, pickle_module=pickle, **pickle_load_args):
|
||||
# Map tensors from GPU 1 to GPU 0
|
||||
>>> torch.load('tensors.pt', map_location={'cuda:1':'cuda:0'})
|
||||
# Load tensor from io.BytesIO object
|
||||
>>> with open('tensor.pt') as f:
|
||||
>>> with open('tensor.pt', 'rb') as f:
|
||||
buffer = io.BytesIO(f.read())
|
||||
>>> torch.load(buffer)
|
||||
"""
|
||||
|
@ -12,6 +12,10 @@ from torch._C import _add_docstr
|
||||
# NB: If you subclass Tensor, and want to share the subclassed class
|
||||
# across processes, you must also update torch/multiprocessing/reductions.py
|
||||
# to define a ForkingPickler serialization mode for the class.
|
||||
#
|
||||
# NB: If you add a new method to Tensor, you must update
|
||||
# torch/__init__.py.in to add a type annotation for your method;
|
||||
# otherwise, it will not show up in autocomplete.
|
||||
class Tensor(torch._C._TensorBase):
|
||||
def __deepcopy__(self, memo):
|
||||
if not self.is_leaf:
|
||||
|
Reference in New Issue
Block a user