create type hint stub files for module torch (#12500)

Summary:
We have:

- This is an initial stab at creating a type stub `torch/__init__.pyi`.
- This is only tested on Python 3, since that's the only Python version mypy
  works on.
- So far, we only aim at doing this for torch functions and torch.Tensor.
- Quite a few methods and functions have to be typed manually; these are
  done in `torch/__init__.pyi.in`.

For me, PyCharm (the free Community Edition) didn't flag errors in the .pyi when opening it and was able to show type hints for the few functions I tried, but I don't use PyCharm for my usual PyTorch work, so I didn't test this extensively.

An example of a generated PYI is at [this gist](https://gist.github.com/ezyang/bf9b6a5fa8827c52152858169bcb61b1).
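For a flavor of the output, here is a hand-transcribed excerpt in the generated style (illustrative only; the real signatures come from Declarations.yaml):

    @overload
    def arange(start: Number, end: Number, step: Number, *, out: Optional[Tensor]=None,
               dtype: Optional[_dtype]=None, device: Union[_device, str, None]=None,
               requires_grad: bool=False) -> Tensor: ...
    @overload
    def arange(end: Number, *, out: Optional[Tensor]=None,
               dtype: Optional[_dtype]=None, device: Union[_device, str, None]=None,
               requires_grad: bool=False) -> Tensor: ...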
Pull Request resolved: https://github.com/pytorch/pytorch/pull/12500

Differential Revision: D13695553

Pulled By: ezyang

fbshipit-source-id: 4566c71913ede4e4c23ebc4a72c17151f94e8e21
This commit is contained in:
Thomas Viehmann
2019-01-29 11:19:51 -08:00
committed by Facebook Github Bot
parent 3b337e7892
commit 6a6983ed7f
18 changed files with 910 additions and 21 deletions

2
.gitignore vendored
View File

@ -35,11 +35,13 @@ test/data/gpu_tensors.pt
test/data/legacy_modules.t7
test/data/legacy_serialized.pt
test/data/linear.pt
test/generated_type_hints_smoketest.py
test/htmlcov
test/cpp_extensions/install/
third_party/build/
tools/shared/_utils_internal.py
torch.egg-info/
torch/__init__.pyi
torch/csrc/autograd/generated/*
torch/csrc/cudnn/cuDNN.cpp
torch/csrc/generated

View File

@ -34,6 +34,10 @@ if [[ "$BUILD_ENVIRONMENT" != *ppc64le* ]]; then
  # TODO: move this to Docker
  pip install -q hypothesis --user
  # mypy will fail to install on Python <3.4. In that case,
  # we just won't run these tests.
  pip install mypy --user || true
fi
# DANGER WILL ROBINSON. The LD_PRELOAD here could cause you problems

View File

@ -731,6 +731,7 @@ if __name__ == '__main__':
    entry_points=entry_points,
    package_data={
        'torch': [
            '__init__.pyi',
            'lib/*.so*',
            'lib/*.dylib*',
            'lib/*.dll',

View File

@ -42,6 +42,7 @@ TESTS = [
    'thd_distributed',
    'torch',
    'type_info',
    'type_hints',
    'utils',
]

167
test/test_type_hints.py Normal file
View File

@ -0,0 +1,167 @@
from __future__ import print_function
import unittest
from common_utils import TestCase, run_tests, download_file
import tempfile
import torch
import re
import os
import sys
import subprocess
import inspect
try:
    import mypy
    HAVE_MYPY = True
except ImportError:
    HAVE_MYPY = False


def get_examples_from_docstring(docstr):
    """
    Extracts all runnable python code from the examples
    in docstrings; returns a list of lines.
    """
    # TODO: Figure out if there's a way to use doctest directly to
    # implement this
    example_file_lines = []
    # The detection is a bit hacky because there isn't a nice way of
    # detecting where multiline commands end. Thus we accumulate the
    # statement seen so far in `beginning` and keep appending lines
    # until we have a compilable Python statement.
    exampleline_re = re.compile(r"^\s+(?:>>>|\.\.\.) (.*)$")
    beginning = ""
    for l in docstr.split('\n'):
        if beginning:
            m = exampleline_re.match(l)
            if m:
                beginning += m.group(1)
            else:
                beginning += l
        else:
            m = exampleline_re.match(l)
            if m:
                beginning += m.group(1)
        if beginning:
            complete = True
            try:
                compile(beginning, "", "exec")
            except SyntaxError:
                complete = False
            if complete:
                # found one
                example_file_lines += beginning.split('\n')
                beginning = ""
            else:
                beginning += "\n"
    return ['    ' + l for l in example_file_lines]
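
# A worked illustration of the function above (the docstring is assumed,
# not taken from torch): for
#   docstr = 'Example::\n\n    >>> torch.zeros(2)\n    tensor([0., 0.])'
# get_examples_from_docstring(docstr) returns ['    torch.zeros(2)']:
# the '>>>' line matches exampleline_re and compiles on its own, while
# the expected-output line matches nothing and is dropped.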
def get_all_examples():
    """get_all_examples() -> str

    This function grabs (hopefully all) examples from the torch documentation
    strings and puts them in one nonsensical module returned as a string.
    """
    blacklist = {"_np"}
    allexamples = ""

    example_file_lines = [
        "import torch",
        "import torch.nn.functional as F",
        "import math  # type: ignore",  # mypy complains about floats where SupportsFloat is expected
        "import numpy  # type: ignore",
        "import io  # type: ignore",
        "import itertools  # type: ignore",
        "",
        # for requires_grad_ example
        # NB: We are parsing this file as Python 2, so we must use
        # Python 2 type annotation syntax
        "def preprocess(inp):",
        "    # type: (torch.Tensor) -> torch.Tensor",
        "    return inp",
    ]

    for fname in dir(torch):
        fn = getattr(torch, fname)
        docstr = inspect.getdoc(fn)
        if docstr and fname not in blacklist:
            e = get_examples_from_docstring(docstr)
            if e:
                example_file_lines.append("\n\ndef example_torch_{}():".format(fname))
                example_file_lines += e

    for fname in dir(torch.Tensor):
        fn = getattr(torch.Tensor, fname)
        docstr = inspect.getdoc(fn)
        if docstr and fname not in blacklist:
            e = get_examples_from_docstring(docstr)
            if e:
                example_file_lines.append("\n\ndef example_torch_tensor_{}():".format(fname))
                example_file_lines += e

    return "\n".join(example_file_lines)
class TestTypeHints(TestCase):
    @unittest.skipIf(sys.version_info[0] == 2, "no type hints for Python 2")
    @unittest.skipIf(not HAVE_MYPY, "need mypy")
    def test_doc_examples(self):
        """
        Run documentation examples through mypy.
        """
        fn = os.path.join(os.path.dirname(__file__), 'generated_type_hints_smoketest.py')
        with open(fn, "w") as f:
            print(get_all_examples(), file=f)

        # OK, so here's the deal. mypy treats installed packages
        # and local modules differently: if a package is installed,
        # mypy will refuse to use modules from that package for type
        # checking unless the module explicitly says that it supports
        # type checking. (Reference:
        # https://mypy.readthedocs.io/en/latest/running_mypy.html#missing-imports
        # )
        #
        # Now, PyTorch doesn't support typechecking, and we shouldn't
        # claim that it supports typechecking (it doesn't.) However, not
        # claiming we support typechecking is bad for this test, which
        # wants to use the partial information we get from the bits of
        # PyTorch which are typed to check if it typechecks. And
        # although mypy will work directly if you are working in source,
        # some of our tests involve installing PyTorch and then running
        # its tests.
        #
        # The guidance we got from Michael Sullivan and Joshua Oreman,
        # and also independently developed by Thomas Viehmann,
        # is that we should create a fake directory and add symlinks for
        # the packages that should typecheck. So that is what we do
        # here.
        #
        # If you want to run mypy by hand, and you run from PyTorch
        # root directory, it should work fine to skip this step (since
        # mypy will preferentially pick up the local files first). The
        # temporary directory here is purely needed for CI. For this
        # reason, we also still drop the generated file in the test
        # source folder, for ease of inspection when there are failures.
        with tempfile.TemporaryDirectory() as tmp_dir:
            try:
                os.symlink(
                    os.path.dirname(torch.__file__),
                    os.path.join(tmp_dir, 'torch'),
                    target_is_directory=True
                )
            except OSError:
                raise unittest.SkipTest('cannot symlink')
            try:
                subprocess.run([
                    sys.executable,
                    '-mmypy',
                    '--follow-imports', 'silent',
                    '--check-untyped-defs',
                    os.path.abspath(fn)],
                    cwd=tmp_dir,
                    check=True)
            except subprocess.CalledProcessError as e:
                raise AssertionError("mypy failed. Look above this error for mypy's output.")


if __name__ == '__main__':
    run_tests()

View File

@ -182,33 +182,49 @@ def should_generate_python_binding(declaration):
return True
def gen_py_variable_methods(out, declarations, template_path):
PY_VARIABLE_METHODS_CPP = CodeTemplate.from_file(template_path + '/python_variable_methods.cpp')
PY_VARIABLE_DISPATCH_H = CodeTemplate.from_file(template_path + '/python_variable_methods_dispatch.h')
def get_py_variable_methods(declarations):
"""
Get declarations (grouped by name) which should be generated
as methods on Tensor.
"""
def should_bind(declaration):
return (should_generate_python_binding(declaration) and
declaration['mode'] != 'NN' and
declaration.get('python_module') != 'nn' and
'Tensor' in declaration['method_of'])
py_variable_methods = group_declarations_by_name(declarations, should_bind)
return group_declarations_by_name(declarations, should_bind)
def gen_py_variable_methods(out, declarations, template_path):
PY_VARIABLE_METHODS_CPP = CodeTemplate.from_file(template_path + '/python_variable_methods.cpp')
PY_VARIABLE_DISPATCH_H = CodeTemplate.from_file(template_path + '/python_variable_methods_dispatch.h')
py_variable_methods = get_py_variable_methods(declarations)
env = create_python_bindings(py_variable_methods, True)
write(out, 'python_variable_methods.cpp', PY_VARIABLE_METHODS_CPP, env)
write(out, 'python_variable_methods_dispatch.h', PY_VARIABLE_DISPATCH_H, env)
def get_py_nn_functions(declarations):
"""
Get declarations (grouped by name) which should be generated
as functions in the "nn" module.
"""
def should_bind(declaration):
return (should_generate_python_binding(declaration) and
(declaration['mode'] == 'NN' or declaration.get('python_module') == 'nn'))
return group_declarations_by_name(declarations, should_bind)
def gen_py_nn_functions(out, declarations, template_path):
PY_NN_FUNCTIONS_CPP = CodeTemplate.from_file(template_path + '/python_nn_functions.cpp')
PY_NN_FUNCTIONS_H = CodeTemplate.from_file(template_path + '/python_nn_functions.h')
PY_NN_DISPATCH_H = CodeTemplate.from_file(template_path + '/python_nn_functions_dispatch.h')
def should_bind(declaration):
return (should_generate_python_binding(declaration) and
(declaration['mode'] == 'NN' or declaration.get('python_module') == 'nn'))
py_nn_functions = group_declarations_by_name(declarations, should_bind)
py_nn_functions = get_py_nn_functions(declarations)
env = create_python_bindings(py_nn_functions, has_self=False, is_module=True)
write(out, 'python_nn_functions.cpp', PY_NN_FUNCTIONS_CPP, env)
@ -216,17 +232,25 @@ def gen_py_nn_functions(out, declarations, template_path):
write(out, 'python_nn_functions_dispatch.h', PY_NN_DISPATCH_H, env)
def gen_py_torch_functions(out, declarations, template_path):
PY_TORCH_FUNCTIONS_CPP = CodeTemplate.from_file(template_path + '/python_torch_functions.cpp')
PY_TORCH_DISPATCH_H = CodeTemplate.from_file(template_path + '/python_torch_functions_dispatch.h')
def get_py_torch_functions(declarations):
"""
Get declarations (grouped by name) which should be generated
as functions in the "torch" module.
"""
def should_bind(declaration):
return (should_generate_python_binding(declaration) and
declaration['mode'] != 'NN' and
declaration.get('python_module') != 'nn' and
'namespace' in declaration['method_of'])
py_torch_functions = group_declarations_by_name(declarations, should_bind)
return group_declarations_by_name(declarations, should_bind)
def gen_py_torch_functions(out, declarations, template_path):
PY_TORCH_FUNCTIONS_CPP = CodeTemplate.from_file(template_path + '/python_torch_functions.cpp')
PY_TORCH_DISPATCH_H = CodeTemplate.from_file(template_path + '/python_torch_functions_dispatch.h')
py_torch_functions = get_py_torch_functions(declarations)
env = create_python_bindings(py_torch_functions, has_self=False)
write(out, 'python_torch_functions.cpp', PY_TORCH_FUNCTIONS_CPP, env)
@ -800,7 +824,16 @@ def sort_declarations(grouped_decls):
def get_python_signature(declaration, include_out):
# Compute the Python function signature for argument parsing
# Compute the Python function signature for argument parsing,
# as specified in torch/csrc/utils/python_arg_parser.h. WARNING:
# this is NOT the same type signature as specified by PEP 484
# as understood by mypy; our format was independently developed
# and has some quirks to make it more suitable specifically
# for error parsing.
#
# For a translation to mypy-valid type signatures, see
# tools/gen_pyi.py. If you change any logic here, please
# check that file too.
py_formal_args = []
output_args = []
type_args = []

View File

@ -14,6 +14,12 @@ except ImportError:
from tools.shared.module_loader import import_module
CodeTemplate = import_module('code_template', 'aten/src/ATen/code_template.py').CodeTemplate
# You should use these lines, rather than doing it manually.
# Especially if you see this error!
#
# File "/usr/local/lib/python2.7/dist-packages/yaml/__init__.py", line 69, in load
# loader = Loader(stream)
# TypeError: 'module' object is not callable
try:
# use faster C loader if available
from yaml import CLoader as YamlLoader

0
tools/pyi/__init__.py Normal file
View File

529
tools/pyi/gen_pyi.py Normal file
View File

@ -0,0 +1,529 @@
from __future__ import print_function
import multiprocessing
import sys
import os
import inspect
import collections
import yaml
import types
import re
import argparse
from ..autograd.utils import YamlLoader, CodeTemplate, write
from ..autograd.gen_python_functions import get_py_torch_functions, get_py_variable_methods
from ..autograd.gen_autograd import load_aten_declarations
"""
This module implements generation of type stubs for PyTorch,
enabling use of autocomplete in IDEs like PyCharm, which otherwise
don't understand C extension modules.
At the moment, this module only handles type stubs for torch and
torch.Tensor. It should eventually be expanded to cover all functions
which come are autogenerated.
Here's our general strategy:
- We start off with a hand-written __init__.pyi.in file. This
file contains type definitions for everything we cannot automatically
generate, including pure Python definitions directly in __init__.py
(the latter case should be pretty rare).
- We go through automatically bound functions based on the
type information recorded in Declarations.yaml and
generate type hints for them (generate_type_hints)
There are a number of type hints which we've special-cased;
read gen_pyi for the gory details.
"""
# TODO: Consider defining some aliases for our Union[...] types, to make
# the stubs easier on the human eye.

needed_modules = set()

FACTORY_PARAMS = "dtype: Optional[_dtype]=None, device: Union[_device, str, None]=None, requires_grad: bool=False"

# this could be more precise w.r.t list contents etc. How to do Ellipsis?
INDICES = "indices: Union[None, _int, slice, Tensor, List, Tuple]"
blacklist = [
    '__init_subclass__',
    '__new__',
    '__subclasshook__',
    'clamp',
    'clamp_',
    'device',
    'grad',
    'requires_grad',
    'range',
    # defined in functional
    'einsum',
    # reduction argument; these bindings don't make sense
    'ctc_loss',
    'cosine_embedding_loss',
    'hinge_embedding_loss',
    'kl_div',
    'margin_ranking_loss',
    'triplet_margin_loss',
    # Somehow, these are defined in both _C and in functional. Ick!
    'broadcast_tensors',
    'meshgrid',
    'cartesian_prod',
    'norm',
    'chain_matmul',
    'stft',
    'tensordot',
    'norm',
    'split',
    # These are handled specially by python_arg_parser.cpp
    'add',
    'add_',
    'add_out',
    'sub',
    'sub_',
    'sub_out',
    'mul',
    'mul_',
    'mul_out',
    'div',
    'div_',
    'div_out',
]
def type_to_python(typename, size=None):
    """type_to_python(typename: str, size: str) -> str

    Transforms a Declarations.yaml type name into a Python type specification
    as used for type hints.
    """
    typename = typename.replace(' ', '')  # normalize spaces, e.g., 'Generator *'

    # Disambiguate explicitly sized int/tensor lists from implicitly
    # sized ones. These permit non-list inputs too. (IntList[] and
    # TensorList[] are not real types; this is just for convenience.)
    if typename in {'IntList', 'TensorList'} and size is not None:
        typename += '[]'

    typename = {
        'Device': 'Union[_device, str, None]',
        'Generator*': 'Generator',
        'IntegerTensor': 'Tensor',
        'Scalar': 'Number',
        'ScalarType': '_dtype',
        'Storage': 'Storage',
        'BoolTensor': 'Tensor',
        'IndexTensor': 'Tensor',
        'SparseTensorRef': 'Tensor',
        'Tensor': 'Tensor',
        'IntList': '_size',
        'IntList[]': 'Union[_int, _size]',
        'TensorList': 'Union[Tuple[Tensor, ...], List[Tensor]]',
        'TensorList[]': 'Union[Tensor, Tuple[Tensor, ...], List[Tensor]]',
        'bool': 'bool',
        'double': '_float',
        'int64_t': '_int',
        'accreal': 'Number',
        'real': 'Number',
        'void*': '_int',  # data_ptr
        'void': 'None',
        'std::string': 'str',
    }[typename]

    return typename
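
# For example (illustrative): type_to_python('IntList') returns '_size',
# while type_to_python('IntList', size='3') returns 'Union[_int, _size]',
# since an explicitly sized list argument also accepts a bare int.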
def arg_to_type_hint(arg):
    """arg_to_type_hint(arg) -> str

    This takes one argument in a Declarations and returns a string
    representing this argument in a type hint signature.
    """
    name = arg['name']
    if name == 'from':  # from is a Python keyword...
        name += '_'
    typename = type_to_python(arg['dynamic_type'], arg.get('size'))
    if arg.get('is_nullable'):
        typename = 'Optional[' + typename + ']'
    if 'default' in arg:
        default = arg['default']
        if default == 'nullptr':
            default = None
        elif default == 'c10::nullopt':
            default = None
        elif isinstance(default, str) and default.startswith('{') and default.endswith('}'):
            if arg['dynamic_type'] == 'Tensor' and default == '{}':
                default = None
            elif arg['dynamic_type'] == 'IntList':
                default = '(' + default[1:-1] + ')'
            else:
                raise Exception("Unexpected default constructor argument of type {}".format(arg['dynamic_type']))
        default = '={}'.format(default)
    else:
        default = ''
    return name + ': ' + typename + default
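
# An illustrative mapping (the dicts below are assumed Declarations
# entries, not real data):
#   arg_to_type_hint({'name': 'dim', 'dynamic_type': 'int64_t', 'default': 0})
# returns 'dim: _int=0', and a nullable Tensor argument such as
#   {'name': 'other', 'dynamic_type': 'Tensor', 'is_nullable': True}
# renders as 'other: Optional[Tensor]'.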
binary_ops = ('add', 'sub', 'mul', 'div', 'pow', 'lshift', 'rshift', 'mod', 'truediv',
              'matmul',
              'radd', 'rmul',  # reverse arithmetic
              'and', 'or', 'xor',  # logic
              'iadd', 'iand', 'idiv', 'ilshift', 'imul',
              'ior', 'irshift', 'isub', 'itruediv', 'ixor',  # inplace ops
              )
comparison_ops = ('eq', 'ne', 'ge', 'gt', 'lt', 'le')
unary_ops = ('neg', 'abs', 'invert')
to_py_type_ops = ('bool', 'float', 'long', 'index', 'int', 'nonzero')
all_ops = binary_ops + comparison_ops + unary_ops + to_py_type_ops
def sig_for_ops(opname):
    """sig_for_ops(opname : str) -> List[str]

    Returns signatures for operator special functions (__add__ etc.)"""
    # we have to do this by hand, because they are hand-bound in Python
    assert opname.endswith('__') and opname.startswith('__'), "Unexpected op {}".format(opname)
    name = opname[2:-2]
    if name in binary_ops:
        return ['def {}(self, other: Any) -> Tensor: ...'.format(opname)]
    elif name in comparison_ops:
        # unsafe override https://github.com/python/mypy/issues/5704
        return ['def {}(self, other: Any) -> Tensor: ...  # type: ignore'.format(opname)]
    elif name in unary_ops:
        return ['def {}(self) -> Tensor: ...'.format(opname)]
    elif name in to_py_type_ops:
        if name in {'bool', 'float'}:
            tname = name
        elif name == 'nonzero':
            tname = 'bool'
        else:
            tname = 'int'
        if tname in {'float', 'int'}:
            tname = 'builtins.' + tname
        return ['def {}(self) -> {}: ...'.format(opname, tname)]
    else:
        raise Exception("unknown op", opname)
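
# For instance (illustrative): sig_for_ops('__add__') returns
#   ['def __add__(self, other: Any) -> Tensor: ...']
# while sig_for_ops('__float__') returns
#   ['def __float__(self) -> builtins.float: ...']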
def generate_type_hints(fname, decls, is_tensor=False):
    """generate_type_hints(fname, decls, is_tensor=False)

    Generates type hints for the declarations pertaining to the function
    :attr:`fname`. :attr:`decls` are the declarations from the parsed
    Declarations.yaml.

    The :attr:`is_tensor` flag indicates whether we are parsing
    members of the Tensor class (true) or functions in the
    `torch` namespace (default, false).

    This function currently encodes quite a bit about the semantics of
    the translation C++ -> Python.
    """
    if fname in blacklist:
        return []

    type_hints = []
    dnames = ([d['name'] for d in decls])
    has_out = fname + '_out' in dnames

    if has_out:
        decls = [d for d in decls if d['name'] != fname + '_out']

    for decl in decls:
        render_kw_only_separator = True  # whether we add a '*' if we see a keyword only argument
        python_args = []

        has_tensor_options = 'TensorOptions' in [a['dynamic_type'] for a in decl['arguments']]

        for a in decl['arguments']:
            if a['dynamic_type'] != 'TensorOptions':
                if a.get('kwarg_only', False) and render_kw_only_separator:
                    python_args.append('*')
                    render_kw_only_separator = False
                python_args.append(arg_to_type_hint(a))

        if is_tensor:
            if 'self: Tensor' in python_args:
                python_args.remove('self: Tensor')
                python_args = ['self'] + python_args
            else:
                raise Exception("method without self is unexpected")

        if has_out:
            if render_kw_only_separator:
                python_args.append('*')
                render_kw_only_separator = False
            python_args.append('out: Optional[Tensor]=None')

        if has_tensor_options:
            if render_kw_only_separator:
                python_args.append('*')
                render_kw_only_separator = False
            python_args += ["dtype: _dtype=None",
                            "layout: layout=strided",
                            "device: Union[_device, str, None]=None",
                            "requires_grad:bool=False"]

        python_args_s = ', '.join(python_args)
        python_returns = [type_to_python(r['dynamic_type']) for r in decl['returns']]

        if len(python_returns) > 1:
            python_returns_s = 'Tuple[' + ', '.join(python_returns) + ']'
        else:
            python_returns_s = python_returns[0]

        type_hint = "def {}({}) -> {}: ...".format(fname, python_args_s, python_returns_s)
        numargs = len(decl['arguments'])
        vararg_pos = int(is_tensor)
        have_vararg_version = (numargs > vararg_pos and
                               decl['arguments'][vararg_pos]['dynamic_type'] in {'IntList', 'TensorList'} and
                               (numargs == vararg_pos + 1 or python_args[vararg_pos + 1] == '*') and
                               (not is_tensor or decl['arguments'][0]['name'] == 'self'))

        type_hints.append(type_hint)

        if have_vararg_version:
            # Two things come into play here: PyTorch has the "magic" that if the first and only positional argument
            # is an IntList or TensorList, it will be used as a vararg variant.
            # The following outputs the vararg variant, the "pass a list variant" is output above.
            # The other thing is that in Python, the varargs are annotated with the element type, not the list type.
            typelist = decl['arguments'][vararg_pos]['dynamic_type']
            if typelist == 'IntList':
                vararg_type = '_int'
            else:
                vararg_type = 'Tensor'
            # replace first argument and eliminate '*' if present
            python_args = ((['self'] if is_tensor else []) + ['*' + decl['arguments'][vararg_pos]['name'] +
                           ': ' + vararg_type] + python_args[vararg_pos + 2:])
            python_args_s = ', '.join(python_args)
            type_hint = "def {}({}) -> {}: ...".format(fname, python_args_s, python_returns_s)
            type_hints.append(type_hint)

    return type_hints
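
# A rough illustration of the vararg rule above (signatures abridged):
# a torch function whose only positional argument is an IntList, e.g.
# torch.ones, gets both a list-style and a vararg-style hint,
#   def ones(size: _size, *, ...) -> Tensor: ...
#   def ones(*size: _int, ...) -> Tensor: ...
# so that torch.ones((2, 3)) and torch.ones(2, 3) both typecheck.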
def gen_pyi(declarations_path, out):
    """gen_pyi()

    This function generates a pyi file for torch.
    """

    # Some of this logic overlaps with generate_python_signature in
    # tools/autograd/gen_python_functions.py; however, this
    # function is all about generating mypy type signatures, whereas
    # the other function generates a custom format for argument
    # checking. If you update this, consider whether your change
    # also needs to update the other file.

    # Load information from YAML
    declarations = load_aten_declarations(declarations_path)

    # Generate type signatures for top-level functions
    # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

    unsorted_function_hints = collections.defaultdict(list)
    unsorted_function_hints.update({
        'set_flush_denormal': ['def set_flush_denormal(mode: bool) -> bool: ...'],
        'get_default_dtype': ['def get_default_dtype() -> _dtype: ...'],
        'from_numpy': ['def from_numpy(ndarray) -> Tensor: ...'],
        'clamp': ["def clamp(self, min: _float=-inf, max: _float=inf,"
                  " *, out: Optional[Tensor]=None) -> Tensor: ..."],
        'as_tensor': ["def as_tensor(data: Any, dtype: _dtype=None, device: Optional[_device]=None) -> Tensor: ..."],
        'get_num_threads': ['def get_num_threads() -> _int: ...'],
        'set_num_threads': ['def set_num_threads(num: _int) -> None: ...'],
        # These functions are explicitly disabled by
        # SKIP_PYTHON_BINDINGS because they are hand bound.
        # Correspondingly, we must hand-write their signatures.
        'tensor': ["def tensor(data: Any, {}) -> Tensor: ...".format(FACTORY_PARAMS)],
        'sparse_coo_tensor': ['def sparse_coo_tensor(indices: Tensor, values: Union[Tensor,List],'
                              ' size: Optional[_size]=None, *, dtype: Optional[_dtype]=None,'
                              ' device: Union[_device, str, None]=None, requires_grad:bool=False) -> Tensor: ...'],
        'range': ['def range(start: Number, end: Number,'
                  ' step: Number=1, *, out: Optional[Tensor]=None, {}) -> Tensor: ...'
                  .format(FACTORY_PARAMS)],
        'arange': ['def arange(start: Number, end: Number, step: Number, *,'
                   ' out: Optional[Tensor]=None, {}) -> Tensor: ...'
                   .format(FACTORY_PARAMS),
                   'def arange(start: Number, end: Number, *, out: Optional[Tensor]=None, {}) -> Tensor: ...'
                   .format(FACTORY_PARAMS),
                   'def arange(end: Number, *, out: Optional[Tensor]=None, {}) -> Tensor: ...'
                   .format(FACTORY_PARAMS)],
        'randint': ['def randint(low: _int, high: _int, size: _size, *, {}) -> Tensor: ...'
                    .format(FACTORY_PARAMS),
                    'def randint(high: _int, size: _size, *, {}) -> Tensor: ...'
                    .format(FACTORY_PARAMS)],
    })

    for binop in ['add', 'sub', 'mul', 'div']:
        unsorted_function_hints[binop].append(
            'def {}(input: Union[Tensor, Number],'
            ' other: Union[Tensor, Number],'
            ' *, out: Optional[Tensor]=None) -> Tensor: ...'.format(binop))
        unsorted_function_hints[binop].append(
            'def {}(input: Union[Tensor, Number],'
            ' value: Number,'
            ' other: Union[Tensor, Number],'
            ' *, out: Optional[Tensor]=None) -> Tensor: ...'.format(binop))

    function_declarations = get_py_torch_functions(declarations)
    for name in sorted(function_declarations.keys()):
        unsorted_function_hints[name] += generate_type_hints(name, function_declarations[name])

    # Generate type signatures for deprecated functions

    # TODO: Maybe we shouldn't generate type hints for deprecated
    # functions :) However, examples like those for addcdiv rely on these.
    with open('tools/autograd/deprecated.yaml', 'r') as f:
        deprecated = yaml.load(f, Loader=YamlLoader)
    for d in deprecated:
        name, sig = re.match(r"^([^\(]+)\(([^\)]*)", d['name']).groups()
        sig = ['*' if p.strip() == '*' else p.split() for p in sig.split(',')]
        sig = ['*' if p == '*' else (p[1] + ': ' + type_to_python(p[0])) for p in sig]
        unsorted_function_hints[name].append("def {}({}) -> Tensor: ...".format(name, ', '.join(sig)))
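
    # An illustrative example of the deprecated-signature rewrite above:
    # an (assumed) deprecated.yaml entry whose name field is
    #   'addcdiv(Scalar value, Tensor self, Tensor tensor1, Tensor tensor2)'
    # turns into the hint
    #   'def addcdiv(value: Number, self: Tensor, tensor1: Tensor, tensor2: Tensor) -> Tensor: ...'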
    function_hints = []
    for name, hints in sorted(unsorted_function_hints.items()):
        if len(hints) > 1:
            hints = ['@overload\n' + h for h in hints]
        function_hints += hints

    # Generate type signatures for Tensor methods
    # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

    unsorted_tensor_method_hints = collections.defaultdict(list)
    unsorted_tensor_method_hints.update({
        'size': ['def size(self) -> Size: ...',
                 'def size(self, _int) -> _int: ...'],
        'stride': ['def stride(self) -> Tuple[_int]: ...',
                   'def stride(self, _int) -> _int: ...'],
        'new_empty': ['def new_empty(self, size: {}, {}) -> Tensor: ...'.
                      format(type_to_python('IntList'), FACTORY_PARAMS)],
        'new_ones': ['def new_ones(self, size: {}, {}) -> Tensor: ...'.
                     format(type_to_python('IntList'), FACTORY_PARAMS)],
        'new_zeros': ['def new_zeros(self, size: {}, {}) -> Tensor: ...'.
                      format(type_to_python('IntList'), FACTORY_PARAMS)],
        'new_full': ['def new_full(self, size: {}, value: {}, {}) -> Tensor: ...'.
                     format(type_to_python('IntList'), type_to_python('Scalar'), FACTORY_PARAMS)],
        'new_tensor': ["def new_tensor(self, data: Any, {}) -> Tensor: ...".format(FACTORY_PARAMS)],
        # clamp has no default values in the Declarations
        'clamp': ["def clamp(self, min: _float=-inf, max: _float=inf,"
                  " *, out: Optional[Tensor]=None) -> Tensor: ..."],
        'clamp_': ["def clamp_(self, min: _float=-inf, max: _float=inf) -> Tensor: ..."],
        '__getitem__': ["def __getitem__(self, {}) -> Tensor: ...".format(INDICES)],
        '__setitem__': ["def __setitem__(self, {}, val: Union[Tensor, Number])"
                        " -> None: ...".format(INDICES)],
        'tolist': ['def tolist(self) -> List: ...'],
        'requires_grad_': ['def requires_grad_(self, mode: bool=True) -> Tensor: ...'],
        'element_size': ['def element_size(self) -> _int: ...'],
        'dim': ['def dim(self) -> _int: ...'],
        'ndimension': ['def ndimension(self) -> _int: ...'],
        'nelement': ['def nelement(self) -> _int: ...'],
        'cuda': ['def cuda(self, device: Optional[_device]=None, non_blocking: bool=False) -> Tensor: ...'],
        'numpy': ['def numpy(self) -> Any: ...'],
        'apply_': ['def apply_(self, callable: Callable) -> Tensor: ...'],
        'map_': ['def map_(tensor: Tensor, callable: Callable) -> Tensor: ...'],
        'copy_': ['def copy_(self, src: Tensor, non_blocking: bool=False) -> Tensor: ...'],
        'storage': ['def storage(self) -> Storage: ...'],
        'type': ['def type(self, dtype: Union[None, str, _dtype]=None, non_blocking: bool=False)'
                 ' -> Union[str, Tensor]: ...'],
        'get_device': ['def get_device(self) -> _int: ...'],
        'is_contiguous': ['def is_contiguous(self) -> bool: ...'],
        'is_cuda': ['def is_cuda(self) -> bool: ...'],
        'is_leaf': ['def is_leaf(self) -> bool: ...'],
        'storage_offset': ['def storage_offset(self) -> _int: ...'],
        'to': ['def to(self, dtype: _dtype, non_blocking: bool=False, copy: bool=False) -> Tensor: ...',
               'def to(self, device: Optional[Union[_device, str]]=None, dtype: Optional[_dtype]=None, '
               'non_blocking: bool=False, copy: bool=False) -> Tensor: ...',
               'def to(self, other: Tensor, non_blocking: bool=False, copy: bool=False) -> Tensor: ...',
               ],
        'item': ["def item(self) -> Number: ..."],
    })
    for binop in ['add', 'sub', 'mul', 'div']:
        for inplace in [True, False]:
            name = binop
            out_suffix = ', *, out: Optional[Tensor]=None'
            if inplace:
                name += '_'
                out_suffix = ''
            unsorted_tensor_method_hints[name].append(
                'def {}(self, other: Union[Tensor, Number]{})'
                ' -> Tensor: ...'.format(name, out_suffix))
            unsorted_tensor_method_hints[name].append(
                'def {}(self, value: Number,'
                ' other: Union[Tensor, Number]{})'
                ' -> Tensor: ...'.format(name, out_suffix))

    simple_conversions = ['byte', 'char', 'cpu', 'double', 'float', 'half', 'int', 'long', 'short']
    for name in simple_conversions:
        unsorted_tensor_method_hints[name].append('def {}(self) -> Tensor: ...'.format(name))

    tensor_method_declarations = get_py_variable_methods(declarations)
    for name in sorted(tensor_method_declarations.keys()):
        unsorted_tensor_method_hints[name] += \
            generate_type_hints(name, tensor_method_declarations[name], is_tensor=True)

    for op in all_ops:
        name = '__{}__'.format(op)
        unsorted_tensor_method_hints[name] += sig_for_ops(name)

    tensor_method_hints = []
    for name, hints in sorted(unsorted_tensor_method_hints.items()):
        if len(hints) > 1:
            hints = ['@overload\n' + h for h in hints]
        tensor_method_hints += hints
    # TODO: Missing type hints for nn

    # Generate type signatures for legacy classes
    # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

    # TODO: These are deprecated, maybe we shouldn't type hint them
    legacy_class_hints = []
    for c in ('DoubleStorage', 'FloatStorage', 'LongStorage', 'IntStorage',
              'ShortStorage', 'CharStorage', 'ByteStorage'):
        legacy_class_hints.append('class {}(Storage): ...'.format(c))

    for c in ('DoubleTensor', 'FloatTensor', 'LongTensor', 'IntTensor',
              'ShortTensor', 'CharTensor', 'ByteTensor'):
        legacy_class_hints.append('class {}(Tensor): ...'.format(c))

    # Generate type signatures for dtype classes
    # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

    # TODO: don't explicitly list dtypes here; get it from canonical
    # source
    dtype_class_hints = ['{}: dtype = ...'.format(n)
                         for n in
                         ['float32', 'float', 'float64', 'double', 'float16', 'half',
                          'uint8', 'int8', 'int16', 'short', 'int32', 'int', 'int64', 'long',
                          'complex32', 'complex64', 'complex128']]

    # Write out the stub
    # ~~~~~~~~~~~~~~~~~~

    env = {
        'function_hints': function_hints,
        'tensor_method_hints': tensor_method_hints,
        'legacy_class_hints': legacy_class_hints,
        'dtype_class_hints': dtype_class_hints,
    }
    TORCH_TYPE_STUBS = CodeTemplate.from_file(os.path.join('torch', '__init__.pyi.in'))
    write(out, 'torch/__init__.pyi', TORCH_TYPE_STUBS, env)


def main():
    parser = argparse.ArgumentParser(
        description='Generate type stubs for PyTorch')
    parser.add_argument('--declarations-path', metavar='DECL',
                        default='torch/share/ATen/Declarations.yaml',
                        help='path to Declarations.yaml')
    parser.add_argument('--out', metavar='OUT',
                        default='.',
                        help='path to output directory')
    args = parser.parse_args()
    gen_pyi(args.declarations_path, args.out)


if __name__ == '__main__':
    main()

View File

@ -713,8 +713,26 @@ if (BUILD_PYTHON)
endif()
endif()
add_custom_target(torch_python_stubs DEPENDS "${TORCH_SRC_DIR}/__init__.pyi")
# For Declarations.yaml dependency
add_dependencies(torch_python_stubs ATEN_CPU_FILES_GEN_TARGET)
  add_custom_command(
      OUTPUT
      "${TORCH_SRC_DIR}/__init__.pyi"
      COMMAND
      ${PYCMD} -mtools.pyi.gen_pyi
        --declarations-path "${CMAKE_BINARY_DIR}/aten/src/ATen/Declarations.yaml"
      DEPENDS
      "${CMAKE_BINARY_DIR}/aten/src/ATen/Declarations.yaml"
      "${TORCH_SRC_DIR}/__init__.pyi.in"
      WORKING_DIRECTORY
      "${TORCH_ROOT}"
      )
add_library(torch_python SHARED ${TORCH_PYTHON_SRCS})
add_dependencies(torch_python torch_python_stubs)
target_link_libraries(torch_python ${TORCH_PYTHON_LINK_LIBRARIES})
target_compile_definitions(torch_python PRIVATE ${TORCH_PYTHON_COMPILE_DEFINITIONS})

View File

@ -179,6 +179,7 @@ def set_default_dtype(d):
"""
_C._set_default_dtype(d)
# If you edit these imports, please update torch/__init__.py.in as well
from .random import set_rng_state, get_rng_state, manual_seed, initial_seed
from .serialization import save, load
from ._tensor_str import set_printoptions

106
torch/__init__.pyi.in Normal file
View File

@ -0,0 +1,106 @@
# ${generated_comment}
from typing import List, Tuple, Optional, Union, Any, ContextManager, Callable, overload
from torch._six import inf
import builtins
# These identifiers are reexported from other modules. These modules
# are not mypy-clean yet, so in order to use this stub file usefully
# from mypy you will need to specify --follow-imports=silent.
# Not all is lost: these imports still enable IDEs like PyCharm to offer
# autocomplete.
#
# Note: Why does the syntax here look so strange? Import visibility
# rules in stubs are different from normal Python files! You must use
# 'from ... import ... as ...' syntax to cause an identifier to be
# exposed (or use a wildcard); regular syntax is not exposed.
from .random import set_rng_state as set_rng_state, get_rng_state as get_rng_state, \
manual_seed as manual_seed, initial_seed as initial_seed
from ._tensor_str import set_printoptions as set_printoptions
from .functional import *
from .serialization import save as save, load as load
from .autograd import no_grad as no_grad, enable_grad as enable_grad, \
set_grad_enabled as set_grad_enabled
class dtype: ...

class layout: ...

strided : layout = ...

# See https://github.com/python/mypy/issues/4146 for why these workarounds
# are necessary
_int = builtins.int
_float = builtins.float

class device:
    def __init__(self, device: Union[_int, str, None]=None) -> None: ...

class Generator: ...

class Size(tuple): ...

class Storage: ...

# See https://github.com/python/mypy/issues/4146 for why these workarounds
# are necessary
_dtype = dtype
_device = device
_size = Union[Size, List[_int], Tuple[_int, ...]]

# Meta-type for "numeric" things; matches our docs
Number = Union[builtins.int, builtins.float]

# TODO: One downside of doing it this way, is direct use of
# torch.tensor.Tensor doesn't get type annotations. Nobody
# should really do that, so maybe this is not so bad.
class Tensor:
    dtype: _dtype = ...
    shape: Size = ...
    device: _device = ...
    requires_grad: bool = ...
    grad: Optional[Tensor] = ...

    ${tensor_method_hints}

    # Manually defined methods from torch/tensor.py
    def backward(self, gradient: Optional[Tensor]=None, retain_graph: Optional[bool]=None, create_graph: bool=False) -> None: ...
    def register_hook(self, hook: Callable) -> Any: ...
    def retain_grad(self) -> None: ...
    def is_pinned(self) -> bool: ...
    def is_shared(self) -> bool: ...
    def share_memory_(self) -> None: ...
    # TODO: fill in the types for these, or otherwise figure out some
    # way to not have to write these out again...
    def argmax(self, dim=None, keepdim=False): ...
    def argmin(self, dim=None, keepdim=False): ...
    def argsort(self, dim=None, descending=False): ...
    def norm(self, p="fro", dim=None, keepdim=False): ...
    def stft(self, n_fft, hop_length=None, win_length=None, window=None,
             center=True, pad_mode='reflect', normalized=False, onesided=True): ...
    def split(self, split_size, dim=0): ...
    def index_add(self, dim, index, tensor): ...
    def index_copy(self, dim, index, tensor): ...
    def index_fill(self, dim, index, value): ...
    def scatter(self, dim, index, source): ...
    def scatter_add(self, dim, index, source): ...
    def masked_scatter(self, mask, tensor): ...
    def masked_fill(self, mask, value): ...
    def unique(self, sorted=True, return_inverse=False, dim=None): ...
${function_hints}
${legacy_class_hints}
${dtype_class_hints}
# Pure Python functions defined in torch/__init__.py
def typename(obj) -> str: ...
def is_tensor(obj) -> bool: ...
def is_storage(obj) -> bool: ...
def set_default_tensor_type(type) -> None: ... # ick, what a bad legacy API
def set_default_dtype(d : _dtype) -> None: ...
def manager_path() -> str: ...
def compiled_with_cxx11_abi() -> bool: ...
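
# With this stub in place, a small (hypothetical) snippet such as
#
#   t = torch.tensor([1.0, 2.0])
#   r: Tensor = t.float()
#
# can be checked with `mypy --follow-imports silent`, as noted at the
# top of this file.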

View File

@ -20,6 +20,7 @@
import itertools
import sys
import builtins
PY2 = sys.version_info[0] == 2
@ -48,7 +49,7 @@ else:
if PY2:
FileNotFoundError = IOError
else:
FileNotFoundError = FileNotFoundError
FileNotFoundError = builtins.FileNotFoundError
if PY2:
@ -71,11 +72,10 @@ def with_metaclass(meta, *bases):
# A portable way of referring to the generator version of map
# in both Python 2 and Python 3.
# TODO: Move this into an appropriate utility library.
if hasattr(itertools, 'imap'):
imap = itertools.imap
imap = itertools.imap # type: ignore
else:
imap = map
imap = map # type: ignore
if PY3:

View File

@ -309,3 +309,13 @@ def _take_tensors(tensors, size_limit):
for buf, _ in buf_dict.values():
if len(buf) > 0:
yield buf
# annotation decorator to get annotations in a way that is compatible
# with both Python 2 and 3
def annotate(ret, **kwargs):
    def dec(fun):
        fun.__annotations__ = dict(kwargs)
        fun.__annotations__['return'] = ret
        return fun
    return dec
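
# Example usage (with a hypothetical function); the decorator fills in
# __annotations__ the same way Python 3 annotation syntax would:
#
#   @annotate(torch.Tensor, inp=torch.Tensor)
#   def preprocess(inp):
#       return inp
#
# preprocess.__annotations__ == {'inp': torch.Tensor, 'return': torch.Tensor}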

View File

@ -36,6 +36,10 @@ static std::unordered_map<std::string, ParameterType> type_map = {
// numbers to bind to Tensors. Some binary ops have separate Tensor and Scalar
// overloads and binding to the Tensor overload with a number of a different
// type will trigger a type error.
//
// If you modify this, you will need to adjust the blacklist in
// tools/pyi/gen_pyi.py (and add hardcoded signatures for these
// functions.)
static bool should_allow_numbers_as_tensors(const std::string& name) {
  static std::unordered_set<std::string> allowed = {
    "add", "add_", "add_out",

View File

@ -4,8 +4,11 @@ from torch._six import inf
from torch._C import _add_docstr
from operator import mul
from functools import reduce
from collections import Iterable
from torch._utils import annotate
from itertools import product
import math
from typing import Optional, Tuple, List, Union
import warnings
__all__ = [

View File

@ -366,7 +366,7 @@ def load(f, map_location=None, pickle_module=pickle, **pickle_load_args):
# Map tensors from GPU 1 to GPU 0
>>> torch.load('tensors.pt', map_location={'cuda:1':'cuda:0'})
# Load tensor from io.BytesIO object
>>> with open('tensor.pt') as f:
>>> with open('tensor.pt', 'rb') as f:
buffer = io.BytesIO(f.read())
>>> torch.load(buffer)
"""

View File

@ -12,6 +12,10 @@ from torch._C import _add_docstr
# NB: If you subclass Tensor, and want to share the subclassed class
# across processes, you must also update torch/multiprocessing/reductions.py
# to define a ForkingPickler serialization mode for the class.
#
# NB: If you add a new method to Tensor, you must update
# torch/__init__.pyi.in to add a type annotation for your method;
# otherwise, it will not show up in autocomplete.
class Tensor(torch._C._TensorBase):
    def __deepcopy__(self, memo):
        if not self.is_leaf: