mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
Summary: We have: - This is an initial stab at creating a type stub `torch/__init__.pyi` . - This is only tested on Python 3, since that's the only Python version mypy works on. - So far, we only aim at doing this for torch functions and torch.Tensor. - Quite a few methods and functions have to be typed manually. These are done in `torch/__init__.pyi.in` For me, PyCharm (the non-paid one) didn't seem to indicate errors in the .pyi when opening and seemed to be able to get the type hint for the few functions I tried, but I don't use PyCharm for my usual PyTorch activities, so I didn't extensively try this out. An example of a generated PYI is at [this gist](https://gist.github.com/ezyang/bf9b6a5fa8827c52152858169bcb61b1). Pull Request resolved: https://github.com/pytorch/pytorch/pull/12500 Differential Revision: D13695553 Pulled By: ezyang fbshipit-source-id: 4566c71913ede4e4c23ebc4a72c17151f94e8e21
74 lines
2.3 KiB
Python
74 lines
2.3 KiB
Python
import re
|
|
import os
|
|
from .nested_dict import nested_dict
|
|
|
|
|
|
# Public names of this module, i.e. what `from <this module> import *` exposes.
# NOTE(review): uninplace_api_name is defined below but not listed here, so a
# star-import will not pick it up; confirm that is intentional.
__all__ = [
    'CodeTemplate',
    'IDENT_REGEX',
    'YamlLoader',
    'nested_dict',
    'split_name_params',
    'write',
]
|
|
|
|
try:
|
|
from src.ATen.code_template import CodeTemplate
|
|
except ImportError:
|
|
from tools.shared.module_loader import import_module
|
|
CodeTemplate = import_module('code_template', 'aten/src/ATen/code_template.py').CodeTemplate
|
|
|
|
# Import YamlLoader from here instead of picking a yaml Loader class manually,
# especially if you see the following error:
|
|
#
|
|
# File "/usr/local/lib/python2.7/dist-packages/yaml/__init__.py", line 69, in load
|
|
# loader = Loader(stream)
|
|
# TypeError: 'module' object is not callable
|
|
try:
|
|
# use faster C loader if available
|
|
from yaml import CLoader as YamlLoader
|
|
except ImportError:
|
|
from yaml import Loader as YamlLoader
|
|
|
|
|
|
# Header template rendered into each generated file; ${filename} is filled in
# by write() below. The marker word is assembled from two string pieces,
# presumably so that this source file does not itself contain the literal
# marker and get treated as a generated file by tooling that scans for it.
GENERATED_COMMENT = CodeTemplate(
    '@' + 'generated from tools/autograd/templates/${filename}'
)
|
|
|
|
# Regex template that matches an identifier only as a whole word: the pattern
# built for "foo" matches in "foo, bar" but not in "foobar". The `{}`
# placeholder is meant to be filled in with str.format(); it is used to search
# for the occurrence of a parameter name in a derivative formula.
# NOTE(review): the substituted name is not escaped here; callers should
# re.escape() it if it could ever contain regex metacharacters.
IDENT_REGEX = '(^|\\W){}($|\\W)'
|
|
|
|
|
|
def split_name_params(prototype):
    """Split a C++-style prototype into its name and its list of parameters.

    ``'add(Tensor self, Scalar other)'`` becomes
    ``('add', ['Tensor self', 'Scalar other'])``.

    Parameters are separated at ``', '`` only where it occurs outside any
    ``<>``, ``()``, ``[]`` or ``{}`` nesting. A parameter such as
    ``std::array<bool, 2> mask`` is therefore kept in one piece instead of
    being cut at its inner comma. An empty parameter list yields ``['']``,
    which is the same result ``''.split(', ')`` gives.

    Raises:
        ValueError: if ``prototype`` does not start with ``name(...)``.
    """
    match = re.match(r'(\w+)\((.*)\)', prototype)
    if match is None:
        # Previously this surfaced as an opaque AttributeError on None.groups().
        raise ValueError(
            "expected a prototype of the form 'name(params)', got {!r}".format(prototype))
    name, params = match.groups()
    return name, _split_top_level_params(params)


def _split_top_level_params(params, sep=', '):
    """Split ``params`` on occurrences of ``sep`` that are not nested in brackets."""
    parts = []
    depth = 0
    start = 0
    i = 0
    while i < len(params):
        if depth == 0 and params.startswith(sep, i):
            parts.append(params[start:i])
            i += len(sep)
            start = i
            continue
        ch = params[i]
        if ch in '<([{':
            depth += 1
        elif ch in '>)]}':
            # Clamp at zero so a stray closing character cannot leave the
            # depth negative and make later separators unsplittable.
            depth = max(depth - 1, 0)
        i += 1
    parts.append(params[start:])
    return parts
|
|
|
|
|
|
def uninplace_api_name(api_name):
    """Map an in-place or ``_out`` API name to its plain out-of-place name.

    When tracing, in-place operations are recorded as out-of-place ones,
    because the IR has no story for side effects yet. Undoing the in-place
    naming is delicate:
      * A single trailing underscore marks an in-place op (``add_`` -> ``add``).
      * A double trailing underscore does not. ``__and__`` is NOT in-place, so
        it is left alone.
      * A trailing ``_out`` is then removed too (``add_out`` -> ``add``).

    TODO: Do something more robust.
    """
    looks_inplace = api_name.endswith('_') and not api_name.endswith('__')
    base = api_name[:-1] if looks_inplace else api_name
    out_suffix = '_out'
    if base.endswith(out_suffix):
        base = base[:-len(out_suffix)]
    return base
|
|
|
|
|
|
def write(dirname, name, template, env):
    """Render ``template`` with ``env`` into the file ``dirname/name``.

    Before rendering, sets ``env['generated_comment']`` (mutating the caller's
    dict) to the generated-file header for ``name``. The file is only
    rewritten when the rendered text differs from what is already on disk.
    Otherwise it is left untouched and a "Skipped writing" line is printed.
    """
    env['generated_comment'] = GENERATED_COMMENT.substitute(filename=name)
    path = os.path.join(dirname, name)

    # See Note [Unchanging results for ninja]: compare against the existing
    # contents so that an identical result does not touch the file.
    try:
        with open(path, 'r') as existing:
            previous = existing.read()
    except IOError:
        # The file could not be read (e.g. it does not exist yet).
        previous = None

    rendered = template.substitute(env)
    if rendered == previous:
        print("Skipped writing {}".format(path))
        return

    with open(path, 'w') as out:
        print("Writing {}".format(path))
        out.write(rendered)
|