mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
Change output_declarations in function_wrapper.py to be a NamedTuple (#5312)
* Add python typing module as build dependency * Change output_declarations to be a NamedTuple * Add mypy configuration files mypy-files.txt includes a list of all files that should be typed checked with mypy. Run mypy with `mypy @mypyfiles.txt`. mypy.ini includes mypy options. Unfortunately this can't be merged with mypy-files.txt. Update .travis.yml so that one doesn't have to specify what files to type check inside it. * Add RuntimeError on missing `typing` module Alerts users to the new build dependency.
This commit is contained in:
committed by
Soumith Chintala
parent
2130070785
commit
dcbbf346c2
@ -16,8 +16,7 @@ matrix:
|
||||
python: "2.7"
|
||||
install: pip install flake8
|
||||
script: flake8
|
||||
# mypy will complain about various files. Limiting it to only typed files
|
||||
- env: MYPY_TYPE_CHECK
|
||||
python: "3.6"
|
||||
install: pip install mypy mypy-extensions
|
||||
script: mypy --py2 aten/src/ATen/function_wrapper.py
|
||||
script: mypy @mypy-files.txt
|
||||
|
@ -171,7 +171,7 @@ On Linux
|
||||
export CMAKE_PREFIX_PATH="$(dirname $(which conda))/../" # [anaconda root directory]
|
||||
|
||||
# Install basic dependencies
|
||||
conda install numpy pyyaml mkl setuptools cmake cffi
|
||||
conda install numpy pyyaml mkl setuptools cmake cffi typing
|
||||
|
||||
# Add LAPACK support for the GPU
|
||||
conda install -c pytorch magma-cuda80 # or magma-cuda90 if CUDA 9
|
||||
|
@ -6,12 +6,22 @@ from collections import OrderedDict
|
||||
from code_template import CodeTemplate
|
||||
|
||||
try:
|
||||
from typing import Any, Dict, List, Generic, Optional, Set, Tuple, \
|
||||
Union, TypeVar
|
||||
from mypy_extensions import TypedDict
|
||||
TYPE_HINTS = True
|
||||
import typing
|
||||
except ImportError:
|
||||
TYPE_HINTS = False
|
||||
raise RuntimeError(
|
||||
'Missing build dependency: Unable to import the `typing` module. '
|
||||
'Please install it via `conda install typing` or `pip install typing`')
|
||||
|
||||
from typing import Any, Dict, List, Generic, Optional, Set, Tuple, \
|
||||
Union, TypeVar, NamedTuple
|
||||
|
||||
try:
|
||||
from mypy_extensions import TypedDict
|
||||
except ImportError:
|
||||
# Avoid the dependency on the mypy_extensions package.
|
||||
# It is required, however, for type checking.
|
||||
def TypedDict(name, attrs, total=True): # type: ignore
|
||||
return Dict[Any, Any]
|
||||
|
||||
import sys
|
||||
if sys.version_info[0] == 3:
|
||||
@ -282,152 +292,163 @@ class nested_dict(object):
|
||||
return r
|
||||
return self.parent[x]
|
||||
|
||||
if TYPE_HINTS:
|
||||
Environment = TypedDict('Environment', {
|
||||
'ScalarName': str,
|
||||
'THTensor': str,
|
||||
'THType': str,
|
||||
'THTensor': str,
|
||||
'Backend': str,
|
||||
'AccScalarName': str,
|
||||
})
|
||||
Environment = TypedDict('Environment', {
|
||||
'ScalarName': str,
|
||||
'THTensor': str,
|
||||
'THType': str,
|
||||
'THTensor': str,
|
||||
'Backend': str,
|
||||
'AccScalarName': str,
|
||||
})
|
||||
|
||||
TopEnvironment = TypedDict('TopEnvironment', {
|
||||
'type_registrations': List[str],
|
||||
'type_headers': List[str],
|
||||
'type_method_declarations': List[str],
|
||||
'type_method_definitions': List[str],
|
||||
'type_method_inline_definitions': List[str],
|
||||
'tensor_method_declarations': List[str],
|
||||
'tensor_method_definitions': List[str],
|
||||
'function_declarations': List[str],
|
||||
'function_definitions': List[str],
|
||||
'type_ids': List[str],
|
||||
'native_function_declarations': List[str],
|
||||
})
|
||||
TopEnvironment = TypedDict('TopEnvironment', {
|
||||
'type_registrations': List[str],
|
||||
'type_headers': List[str],
|
||||
'type_method_declarations': List[str],
|
||||
'type_method_definitions': List[str],
|
||||
'type_method_inline_definitions': List[str],
|
||||
'tensor_method_declarations': List[str],
|
||||
'tensor_method_definitions': List[str],
|
||||
'function_declarations': List[str],
|
||||
'function_definitions': List[str],
|
||||
'type_ids': List[str],
|
||||
'native_function_declarations': List[str],
|
||||
})
|
||||
|
||||
# A Declarations.cwrap formal argument
|
||||
# type can contain THTensor* types
|
||||
THFormal = TypedDict('THFormal', {
|
||||
'name': str,
|
||||
'type': str,
|
||||
'dynamic_type': str,
|
||||
'kwarg_only': bool,
|
||||
'is_nullable': bool,
|
||||
'default': str,
|
||||
'default_init': str,
|
||||
'python_default_init': str,
|
||||
'output': bool,
|
||||
'size': int,
|
||||
'declared_type': str,
|
||||
'ignore_check': bool,
|
||||
'allocate': bool,
|
||||
'mask': bool,
|
||||
'if_true': bool,
|
||||
'if_false': bool,
|
||||
'wrap_dim': str,
|
||||
# Broadcast is originally a str but gets unwrapped to a List or Dict in-place
|
||||
'broadcast': Any,
|
||||
'resize': str,
|
||||
'cpu_zero': bool,
|
||||
'zero': bool,
|
||||
}, total=False)
|
||||
# A Declarations.cwrap formal argument
|
||||
# type can contain THTensor* types
|
||||
THFormal = TypedDict('THFormal', {
|
||||
'name': str,
|
||||
'type': str,
|
||||
'dynamic_type': str,
|
||||
'kwarg_only': bool,
|
||||
'is_nullable': bool,
|
||||
'default': str,
|
||||
'default_init': str,
|
||||
'python_default_init': str,
|
||||
'output': bool,
|
||||
'size': int,
|
||||
'declared_type': str,
|
||||
'ignore_check': bool,
|
||||
'allocate': bool,
|
||||
'mask': bool,
|
||||
'if_true': bool,
|
||||
'if_false': bool,
|
||||
'wrap_dim': str,
|
||||
# Broadcast is originally a str but gets unwrapped to a List or Dict in-place
|
||||
'broadcast': Any,
|
||||
'resize': str,
|
||||
'cpu_zero': bool,
|
||||
'zero': bool,
|
||||
}, total=False)
|
||||
|
||||
# A native_functions.yaml formal argument
|
||||
# type can contain Tensor, BoolTensor, IndexTensor types
|
||||
NativeFormal = TypedDict('NativeFormal', {
|
||||
'name': str,
|
||||
'type': str,
|
||||
'dynamic_type': str,
|
||||
'kwarg_only': bool,
|
||||
'is_nullable': bool,
|
||||
'default': str,
|
||||
'default_init': str,
|
||||
'python_default_init': str,
|
||||
'output': bool,
|
||||
'size': int,
|
||||
}, total=False)
|
||||
# A native_functions.yaml formal argument
|
||||
# type can contain Tensor, BoolTensor, IndexTensor types
|
||||
NativeFormal = TypedDict('NativeFormal', {
|
||||
'name': str,
|
||||
'type': str,
|
||||
'dynamic_type': str,
|
||||
'kwarg_only': bool,
|
||||
'is_nullable': bool,
|
||||
'default': str,
|
||||
'default_init': str,
|
||||
'python_default_init': str,
|
||||
'output': bool,
|
||||
'size': int,
|
||||
}, total=False)
|
||||
|
||||
# Generic ATen formal.
|
||||
# type can contain Tensor& reference types.
|
||||
AtFormal = TypedDict('AtFormal', {
|
||||
'name': str,
|
||||
'type': str,
|
||||
'dynamic_type': str,
|
||||
'kwarg_only': bool,
|
||||
'is_nullable': bool,
|
||||
'default': str,
|
||||
'default_init': str,
|
||||
'python_default_init': str,
|
||||
'output': bool,
|
||||
'size': int,
|
||||
}, total=False)
|
||||
# Generic ATen formal.
|
||||
# type can contain Tensor& reference types.
|
||||
AtFormal = TypedDict('AtFormal', {
|
||||
'name': str,
|
||||
'type': str,
|
||||
'dynamic_type': str,
|
||||
'kwarg_only': bool,
|
||||
'is_nullable': bool,
|
||||
'default': str,
|
||||
'default_init': str,
|
||||
'python_default_init': str,
|
||||
'output': bool,
|
||||
'size': int,
|
||||
}, total=False)
|
||||
|
||||
ReturnType = TypedDict('ReturnType', {
|
||||
'name': str,
|
||||
'type': str,
|
||||
'dynamic_type': str,
|
||||
}, total=False)
|
||||
ReturnType = TypedDict('ReturnType', {
|
||||
'name': str,
|
||||
'type': str,
|
||||
'dynamic_type': str,
|
||||
}, total=False)
|
||||
|
||||
ReturnDecl = TypedDict('ReturnDecl', {
|
||||
'kind': str,
|
||||
'type': str,
|
||||
'arguments': List[int],
|
||||
}, total=False)
|
||||
ReturnDecl = TypedDict('ReturnDecl', {
|
||||
'kind': str,
|
||||
'type': str,
|
||||
'arguments': List[int],
|
||||
}, total=False)
|
||||
|
||||
# Represents a buffer in nn.yaml
|
||||
NNBuffer = TypedDict('NNBuffer', {
|
||||
'name': str,
|
||||
})
|
||||
# Represents a buffer in nn.yaml
|
||||
NNBuffer = TypedDict('NNBuffer', {
|
||||
'name': str,
|
||||
})
|
||||
|
||||
FunctionOption = TypedDict('FunctionOption', {
|
||||
'arguments': List[THFormal],
|
||||
'mode': str,
|
||||
'name': str,
|
||||
'return': ReturnDecl,
|
||||
'variants': str,
|
||||
'type_method_definition_dispatch': str,
|
||||
'cname': str,
|
||||
'backends': List[str],
|
||||
'api_name': str,
|
||||
'backend_type_pairs': List[Tuple[str, str]],
|
||||
'inplace': bool,
|
||||
'aten_dense_sparse': bool,
|
||||
'sparse': bool,
|
||||
'scalar_check': str,
|
||||
'aten_custom_call': str,
|
||||
'type_definition_body': List[str],
|
||||
# cimpls is really a List[FunctionOption]
|
||||
'cimpls': List[Any],
|
||||
'when_spares_dispatch': str,
|
||||
'actuals': List[str],
|
||||
'buffers': List[NNBuffer],
|
||||
'zero_dim_dispatch_when_scalar': str,
|
||||
'zero_dim_tensor_only': bool,
|
||||
'when_sparse_dispatch': str,
|
||||
'formals_list': List[AtFormal],
|
||||
'condition': str,
|
||||
'auto_gpu': bool,
|
||||
'cpu_half': bool,
|
||||
# options should be List[FunctionOption]
|
||||
'options': Any,
|
||||
'formals': List[str],
|
||||
'formals_with_defaults': List[str],
|
||||
'returns': List[ReturnType],
|
||||
'return_type': str,
|
||||
'return_call': str,
|
||||
'method_formals': List[str],
|
||||
'method_formals_with_defaults': List[str],
|
||||
'method_actuals': List[str],
|
||||
'const_mark': str,
|
||||
'method_prefix_derived': str,
|
||||
'broadcast_actuals': List[str],
|
||||
'broadcast_returns': List[str],
|
||||
'inferred_type': str,
|
||||
'broadcast_function': str,
|
||||
'broadcast_modified_actuals': List[str],
|
||||
'native_type_method_dispatch': str,
|
||||
})
|
||||
FunctionOption = TypedDict('FunctionOption', {
|
||||
'arguments': List[THFormal],
|
||||
'mode': str,
|
||||
'name': str,
|
||||
'return': ReturnDecl,
|
||||
'variants': str,
|
||||
'type_method_definition_dispatch': str,
|
||||
'cname': str,
|
||||
'backends': List[str],
|
||||
'api_name': str,
|
||||
'backend_type_pairs': List[Tuple[str, str]],
|
||||
'inplace': bool,
|
||||
'aten_dense_sparse': bool,
|
||||
'sparse': bool,
|
||||
'scalar_check': str,
|
||||
'aten_custom_call': str,
|
||||
'type_definition_body': List[str],
|
||||
# cimpls is really a List[FunctionOption]
|
||||
'cimpls': List[Any],
|
||||
'when_spares_dispatch': str,
|
||||
'actuals': List[str],
|
||||
'buffers': List[NNBuffer],
|
||||
'zero_dim_dispatch_when_scalar': str,
|
||||
'zero_dim_tensor_only': bool,
|
||||
'when_sparse_dispatch': str,
|
||||
'formals_list': List[AtFormal],
|
||||
'condition': str,
|
||||
'auto_gpu': bool,
|
||||
'cpu_half': bool,
|
||||
# options should be List[FunctionOption]
|
||||
'options': Any,
|
||||
'formals': List[str],
|
||||
'formals_with_defaults': List[str],
|
||||
'returns': List[ReturnType],
|
||||
'return_type': str,
|
||||
'return_call': str,
|
||||
'method_formals': List[str],
|
||||
'method_formals_with_defaults': List[str],
|
||||
'method_actuals': List[str],
|
||||
'const_mark': str,
|
||||
'method_prefix_derived': str,
|
||||
'broadcast_actuals': List[str],
|
||||
'broadcast_returns': List[str],
|
||||
'inferred_type': str,
|
||||
'broadcast_function': str,
|
||||
'broadcast_modified_actuals': List[str],
|
||||
'native_type_method_dispatch': str,
|
||||
})
|
||||
|
||||
OutputDeclaration = NamedTuple('OutputDeclaration', [
|
||||
('name', str),
|
||||
('method_prefix_derived', str),
|
||||
('arguments', List[AtFormal]),
|
||||
('method_of', List[str]),
|
||||
('mode', str),
|
||||
('buffers', Optional[List[str]]),
|
||||
('returns', List[ReturnType]),
|
||||
('inplace', bool),
|
||||
('abstract', bool),
|
||||
])
|
||||
|
||||
|
||||
def is_real_argument_to_wrapper(argument):
|
||||
@ -458,7 +479,7 @@ def to_return_type(arg, option):
|
||||
|
||||
|
||||
def create_generic(top_env, declarations):
|
||||
# type: (TopEnvironment, List[FunctionOption]) -> List[OrderedDict]
|
||||
# type: (TopEnvironment, List[FunctionOption]) -> List[OutputDeclaration]
|
||||
# translates defaults from cwrap types to C++ values
|
||||
def translate_default(argument, type_str, default):
|
||||
# type: (THFormal, str, Any) -> Any
|
||||
@ -656,7 +677,7 @@ def create_generic(top_env, declarations):
|
||||
return body
|
||||
|
||||
def process_option(option, output_options):
|
||||
# type: (FunctionOption, List[OrderedDict]) -> None
|
||||
# type: (FunctionOption, List[OutputDeclaration]) -> None
|
||||
option['inplace'] = re.search(
|
||||
'(^__i|[^_]_$)', option['api_name']) is not None
|
||||
|
||||
@ -748,18 +769,18 @@ def create_generic(top_env, declarations):
|
||||
|
||||
buffer_names = [buffer['name'] for buffer in option.get('buffers', [])]
|
||||
|
||||
output_options.append(OrderedDict([
|
||||
('name', option['api_name']),
|
||||
('method_prefix_derived', option['method_prefix_derived']),
|
||||
('arguments', formals),
|
||||
('method_of', method_of),
|
||||
('mode', mode),
|
||||
('buffers', buffer_names),
|
||||
('returns', option['returns']),
|
||||
('inplace', option['inplace']),
|
||||
output_options.append(OutputDeclaration(
|
||||
name=option['api_name'],
|
||||
method_prefix_derived=option['method_prefix_derived'],
|
||||
arguments=formals,
|
||||
method_of=method_of,
|
||||
mode=mode,
|
||||
buffers=buffer_names,
|
||||
returns=option['returns'],
|
||||
inplace=option['inplace'],
|
||||
# See Note [Abstract ATen methods]
|
||||
('abstract', abstract),
|
||||
]))
|
||||
abstract=abstract,
|
||||
))
|
||||
|
||||
def native_get_formals(option, include_constants=False):
|
||||
# type: (FunctionOption, bool) -> List[AtFormal]
|
||||
@ -843,7 +864,7 @@ def create_generic(top_env, declarations):
|
||||
return return_types
|
||||
|
||||
def process_native(option, output_options):
|
||||
# type: (FunctionOption, List[OrderedDict]) -> None
|
||||
# type: (FunctionOption, List[OutputDeclaration]) -> None
|
||||
option['inplace'] = re.search(
|
||||
'(^__i|[^_]_$)', option['api_name']) is not None
|
||||
|
||||
@ -932,21 +953,22 @@ def create_generic(top_env, declarations):
|
||||
FUNCTION_DEFINITION.substitute(env))
|
||||
method_of.append('namespace')
|
||||
|
||||
output_options.append(OrderedDict([
|
||||
('name', option['api_name']),
|
||||
('method_prefix_derived', option['method_prefix_derived']),
|
||||
('arguments', formals),
|
||||
('method_of', method_of),
|
||||
('mode', option['mode']),
|
||||
('returns', option['returns']),
|
||||
('inplace', option['inplace']),
|
||||
output_options.append(OutputDeclaration(
|
||||
name=option['api_name'],
|
||||
method_prefix_derived=option['method_prefix_derived'],
|
||||
arguments=formals,
|
||||
method_of=method_of,
|
||||
mode=option['mode'],
|
||||
buffers=None,
|
||||
returns=option['returns'],
|
||||
inplace=option['inplace'],
|
||||
# See Note [Abstract ATen methods]
|
||||
('abstract', abstract),
|
||||
]))
|
||||
abstract=abstract,
|
||||
))
|
||||
|
||||
output_declarations = [] # type: List[OrderedDict]
|
||||
output_declarations = [] # type: List[OutputDeclaration]
|
||||
for declaration in declarations:
|
||||
output_options = [] # type: List[OrderedDict]
|
||||
output_options = [] # type: List[OutputDeclaration]
|
||||
for option in declaration['options']:
|
||||
try:
|
||||
if option['mode'] != 'native':
|
||||
|
@ -165,19 +165,25 @@ def postprocess_output_declarations(output_declarations):
|
||||
# ensure each return has a name associated with it
|
||||
for decl in output_declarations:
|
||||
has_named_ret = False
|
||||
for n, ret in enumerate(decl['returns']):
|
||||
for n, ret in enumerate(decl.returns):
|
||||
if 'name' not in ret:
|
||||
assert not has_named_ret
|
||||
if decl['inplace']:
|
||||
if decl.inplace:
|
||||
ret['name'] = 'self'
|
||||
elif len(decl['returns']) == 1:
|
||||
elif len(decl.returns) == 1:
|
||||
ret['name'] = 'result'
|
||||
else:
|
||||
ret['name'] = 'result' + str(n)
|
||||
else:
|
||||
has_named_ret = True
|
||||
|
||||
return output_declarations
|
||||
def remove_key_if_none(dictionary, key):
|
||||
if key in dictionary.keys() and dictionary[key] is None:
|
||||
del dictionary[key]
|
||||
return dictionary
|
||||
|
||||
return [remove_key_if_none(decl._asdict(), 'buffers')
|
||||
for decl in output_declarations]
|
||||
|
||||
|
||||
def format_yaml(data):
|
||||
|
11
mypy-README.md
Normal file
11
mypy-README.md
Normal file
@ -0,0 +1,11 @@
|
||||
### Optional type checking with mypy
|
||||
|
||||
mypy is an optional static typechecker that works with Python 3.
|
||||
To use it, install the following dependencies:
|
||||
```bash
|
||||
# Install dependencies
|
||||
pip install mypy mypy-extensions
|
||||
|
||||
# Run type checker in the pytorch/ directory
|
||||
mypy @mypy-files.txt
|
||||
```
|
1
mypy-files.txt
Normal file
1
mypy-files.txt
Normal file
@ -0,0 +1 @@
|
||||
aten/src/ATen/function_wrapper.py
|
@ -1 +1,2 @@
|
||||
pyyaml
|
||||
typing
|
||||
|
Reference in New Issue
Block a user