Change output_declarations in function_wrapper.py to be a NamedTuple (#5312)

* Add python typing module as build dependency

* Change output_declarations to be a NamedTuple

* Add mypy configuration files

mypy-files.txt includes a list of all files that should be typed checked
with mypy. Run mypy with `mypy @mypyfiles.txt`.

mypy.ini includes mypy options. Unfortunately this can't be merged with
mypy-files.txt.

Update .travis.yml so that one doesn't have to specify what files to
type check inside it.

* Add RuntimeError on missing `typing` module

Alerts users to the new build dependency.
This commit is contained in:
Richard Zou
2018-02-23 13:33:59 -05:00
committed by Soumith Chintala
parent 2130070785
commit dcbbf346c2
8 changed files with 218 additions and 176 deletions

View File

@ -16,8 +16,7 @@ matrix:
python: "2.7"
install: pip install flake8
script: flake8
# mypy will complain about various files. Limiting it to only typed files
- env: MYPY_TYPE_CHECK
python: "3.6"
install: pip install mypy mypy-extensions
script: mypy --py2 aten/src/ATen/function_wrapper.py
script: mypy @mypy-files.txt

View File

@ -171,7 +171,7 @@ On Linux
export CMAKE_PREFIX_PATH="$(dirname $(which conda))/../" # [anaconda root directory]
# Install basic dependencies
conda install numpy pyyaml mkl setuptools cmake cffi
conda install numpy pyyaml mkl setuptools cmake cffi typing
# Add LAPACK support for the GPU
conda install -c pytorch magma-cuda80 # or magma-cuda90 if CUDA 9

View File

@ -6,12 +6,22 @@ from collections import OrderedDict
from code_template import CodeTemplate
try:
from typing import Any, Dict, List, Generic, Optional, Set, Tuple, \
Union, TypeVar
from mypy_extensions import TypedDict
TYPE_HINTS = True
import typing
except ImportError:
TYPE_HINTS = False
raise RuntimeError(
'Missing build dependency: Unable to import the `typing` module. '
'Please install it via `conda install typing` or `pip install typing`')
from typing import Any, Dict, List, Generic, Optional, Set, Tuple, \
Union, TypeVar, NamedTuple
try:
from mypy_extensions import TypedDict
except ImportError:
# Avoid the dependency on the mypy_extensions package.
# It is required, however, for type checking.
def TypedDict(name, attrs, total=True): # type: ignore
return Dict[Any, Any]
import sys
if sys.version_info[0] == 3:
@ -282,152 +292,163 @@ class nested_dict(object):
return r
return self.parent[x]
if TYPE_HINTS:
Environment = TypedDict('Environment', {
'ScalarName': str,
'THTensor': str,
'THType': str,
'THTensor': str,
'Backend': str,
'AccScalarName': str,
})
Environment = TypedDict('Environment', {
'ScalarName': str,
'THTensor': str,
'THType': str,
'THTensor': str,
'Backend': str,
'AccScalarName': str,
})
TopEnvironment = TypedDict('TopEnvironment', {
'type_registrations': List[str],
'type_headers': List[str],
'type_method_declarations': List[str],
'type_method_definitions': List[str],
'type_method_inline_definitions': List[str],
'tensor_method_declarations': List[str],
'tensor_method_definitions': List[str],
'function_declarations': List[str],
'function_definitions': List[str],
'type_ids': List[str],
'native_function_declarations': List[str],
})
TopEnvironment = TypedDict('TopEnvironment', {
'type_registrations': List[str],
'type_headers': List[str],
'type_method_declarations': List[str],
'type_method_definitions': List[str],
'type_method_inline_definitions': List[str],
'tensor_method_declarations': List[str],
'tensor_method_definitions': List[str],
'function_declarations': List[str],
'function_definitions': List[str],
'type_ids': List[str],
'native_function_declarations': List[str],
})
# A Declarations.cwrap formal argument
# type can contain THTensor* types
THFormal = TypedDict('THFormal', {
'name': str,
'type': str,
'dynamic_type': str,
'kwarg_only': bool,
'is_nullable': bool,
'default': str,
'default_init': str,
'python_default_init': str,
'output': bool,
'size': int,
'declared_type': str,
'ignore_check': bool,
'allocate': bool,
'mask': bool,
'if_true': bool,
'if_false': bool,
'wrap_dim': str,
# Broadcast is originally a str but gets unwrapped to a List or Dict in-place
'broadcast': Any,
'resize': str,
'cpu_zero': bool,
'zero': bool,
}, total=False)
# A Declarations.cwrap formal argument
# type can contain THTensor* types
THFormal = TypedDict('THFormal', {
'name': str,
'type': str,
'dynamic_type': str,
'kwarg_only': bool,
'is_nullable': bool,
'default': str,
'default_init': str,
'python_default_init': str,
'output': bool,
'size': int,
'declared_type': str,
'ignore_check': bool,
'allocate': bool,
'mask': bool,
'if_true': bool,
'if_false': bool,
'wrap_dim': str,
# Broadcast is originally a str but gets unwrapped to a List or Dict in-place
'broadcast': Any,
'resize': str,
'cpu_zero': bool,
'zero': bool,
}, total=False)
# A native_functions.yaml formal argument
# type can contain Tensor, BoolTensor, IndexTensor types
NativeFormal = TypedDict('NativeFormal', {
'name': str,
'type': str,
'dynamic_type': str,
'kwarg_only': bool,
'is_nullable': bool,
'default': str,
'default_init': str,
'python_default_init': str,
'output': bool,
'size': int,
}, total=False)
# A native_functions.yaml formal argument
# type can contain Tensor, BoolTensor, IndexTensor types
NativeFormal = TypedDict('NativeFormal', {
'name': str,
'type': str,
'dynamic_type': str,
'kwarg_only': bool,
'is_nullable': bool,
'default': str,
'default_init': str,
'python_default_init': str,
'output': bool,
'size': int,
}, total=False)
# Generic ATen formal.
# type can contain Tensor& reference types.
AtFormal = TypedDict('AtFormal', {
'name': str,
'type': str,
'dynamic_type': str,
'kwarg_only': bool,
'is_nullable': bool,
'default': str,
'default_init': str,
'python_default_init': str,
'output': bool,
'size': int,
}, total=False)
# Generic ATen formal.
# type can contain Tensor& reference types.
AtFormal = TypedDict('AtFormal', {
'name': str,
'type': str,
'dynamic_type': str,
'kwarg_only': bool,
'is_nullable': bool,
'default': str,
'default_init': str,
'python_default_init': str,
'output': bool,
'size': int,
}, total=False)
ReturnType = TypedDict('ReturnType', {
'name': str,
'type': str,
'dynamic_type': str,
}, total=False)
ReturnType = TypedDict('ReturnType', {
'name': str,
'type': str,
'dynamic_type': str,
}, total=False)
ReturnDecl = TypedDict('ReturnDecl', {
'kind': str,
'type': str,
'arguments': List[int],
}, total=False)
ReturnDecl = TypedDict('ReturnDecl', {
'kind': str,
'type': str,
'arguments': List[int],
}, total=False)
# Represents a buffer in nn.yaml
NNBuffer = TypedDict('NNBuffer', {
'name': str,
})
# Represents a buffer in nn.yaml
NNBuffer = TypedDict('NNBuffer', {
'name': str,
})
FunctionOption = TypedDict('FunctionOption', {
'arguments': List[THFormal],
'mode': str,
'name': str,
'return': ReturnDecl,
'variants': str,
'type_method_definition_dispatch': str,
'cname': str,
'backends': List[str],
'api_name': str,
'backend_type_pairs': List[Tuple[str, str]],
'inplace': bool,
'aten_dense_sparse': bool,
'sparse': bool,
'scalar_check': str,
'aten_custom_call': str,
'type_definition_body': List[str],
# cimpls is really a List[FunctionOption]
'cimpls': List[Any],
'when_spares_dispatch': str,
'actuals': List[str],
'buffers': List[NNBuffer],
'zero_dim_dispatch_when_scalar': str,
'zero_dim_tensor_only': bool,
'when_sparse_dispatch': str,
'formals_list': List[AtFormal],
'condition': str,
'auto_gpu': bool,
'cpu_half': bool,
# options should be List[FunctionOption]
'options': Any,
'formals': List[str],
'formals_with_defaults': List[str],
'returns': List[ReturnType],
'return_type': str,
'return_call': str,
'method_formals': List[str],
'method_formals_with_defaults': List[str],
'method_actuals': List[str],
'const_mark': str,
'method_prefix_derived': str,
'broadcast_actuals': List[str],
'broadcast_returns': List[str],
'inferred_type': str,
'broadcast_function': str,
'broadcast_modified_actuals': List[str],
'native_type_method_dispatch': str,
})
FunctionOption = TypedDict('FunctionOption', {
'arguments': List[THFormal],
'mode': str,
'name': str,
'return': ReturnDecl,
'variants': str,
'type_method_definition_dispatch': str,
'cname': str,
'backends': List[str],
'api_name': str,
'backend_type_pairs': List[Tuple[str, str]],
'inplace': bool,
'aten_dense_sparse': bool,
'sparse': bool,
'scalar_check': str,
'aten_custom_call': str,
'type_definition_body': List[str],
# cimpls is really a List[FunctionOption]
'cimpls': List[Any],
'when_spares_dispatch': str,
'actuals': List[str],
'buffers': List[NNBuffer],
'zero_dim_dispatch_when_scalar': str,
'zero_dim_tensor_only': bool,
'when_sparse_dispatch': str,
'formals_list': List[AtFormal],
'condition': str,
'auto_gpu': bool,
'cpu_half': bool,
# options should be List[FunctionOption]
'options': Any,
'formals': List[str],
'formals_with_defaults': List[str],
'returns': List[ReturnType],
'return_type': str,
'return_call': str,
'method_formals': List[str],
'method_formals_with_defaults': List[str],
'method_actuals': List[str],
'const_mark': str,
'method_prefix_derived': str,
'broadcast_actuals': List[str],
'broadcast_returns': List[str],
'inferred_type': str,
'broadcast_function': str,
'broadcast_modified_actuals': List[str],
'native_type_method_dispatch': str,
})
OutputDeclaration = NamedTuple('OutputDeclaration', [
('name', str),
('method_prefix_derived', str),
('arguments', List[AtFormal]),
('method_of', List[str]),
('mode', str),
('buffers', Optional[List[str]]),
('returns', List[ReturnType]),
('inplace', bool),
('abstract', bool),
])
def is_real_argument_to_wrapper(argument):
@ -458,7 +479,7 @@ def to_return_type(arg, option):
def create_generic(top_env, declarations):
# type: (TopEnvironment, List[FunctionOption]) -> List[OrderedDict]
# type: (TopEnvironment, List[FunctionOption]) -> List[OutputDeclaration]
# translates defaults from cwrap types to C++ values
def translate_default(argument, type_str, default):
# type: (THFormal, str, Any) -> Any
@ -656,7 +677,7 @@ def create_generic(top_env, declarations):
return body
def process_option(option, output_options):
# type: (FunctionOption, List[OrderedDict]) -> None
# type: (FunctionOption, List[OutputDeclaration]) -> None
option['inplace'] = re.search(
'(^__i|[^_]_$)', option['api_name']) is not None
@ -748,18 +769,18 @@ def create_generic(top_env, declarations):
buffer_names = [buffer['name'] for buffer in option.get('buffers', [])]
output_options.append(OrderedDict([
('name', option['api_name']),
('method_prefix_derived', option['method_prefix_derived']),
('arguments', formals),
('method_of', method_of),
('mode', mode),
('buffers', buffer_names),
('returns', option['returns']),
('inplace', option['inplace']),
output_options.append(OutputDeclaration(
name=option['api_name'],
method_prefix_derived=option['method_prefix_derived'],
arguments=formals,
method_of=method_of,
mode=mode,
buffers=buffer_names,
returns=option['returns'],
inplace=option['inplace'],
# See Note [Abstract ATen methods]
('abstract', abstract),
]))
abstract=abstract,
))
def native_get_formals(option, include_constants=False):
# type: (FunctionOption, bool) -> List[AtFormal]
@ -843,7 +864,7 @@ def create_generic(top_env, declarations):
return return_types
def process_native(option, output_options):
# type: (FunctionOption, List[OrderedDict]) -> None
# type: (FunctionOption, List[OutputDeclaration]) -> None
option['inplace'] = re.search(
'(^__i|[^_]_$)', option['api_name']) is not None
@ -932,21 +953,22 @@ def create_generic(top_env, declarations):
FUNCTION_DEFINITION.substitute(env))
method_of.append('namespace')
output_options.append(OrderedDict([
('name', option['api_name']),
('method_prefix_derived', option['method_prefix_derived']),
('arguments', formals),
('method_of', method_of),
('mode', option['mode']),
('returns', option['returns']),
('inplace', option['inplace']),
output_options.append(OutputDeclaration(
name=option['api_name'],
method_prefix_derived=option['method_prefix_derived'],
arguments=formals,
method_of=method_of,
mode=option['mode'],
buffers=None,
returns=option['returns'],
inplace=option['inplace'],
# See Note [Abstract ATen methods]
('abstract', abstract),
]))
abstract=abstract,
))
output_declarations = [] # type: List[OrderedDict]
output_declarations = [] # type: List[OutputDeclaration]
for declaration in declarations:
output_options = [] # type: List[OrderedDict]
output_options = [] # type: List[OutputDeclaration]
for option in declaration['options']:
try:
if option['mode'] != 'native':

View File

@ -165,19 +165,25 @@ def postprocess_output_declarations(output_declarations):
# ensure each return has a name associated with it
for decl in output_declarations:
has_named_ret = False
for n, ret in enumerate(decl['returns']):
for n, ret in enumerate(decl.returns):
if 'name' not in ret:
assert not has_named_ret
if decl['inplace']:
if decl.inplace:
ret['name'] = 'self'
elif len(decl['returns']) == 1:
elif len(decl.returns) == 1:
ret['name'] = 'result'
else:
ret['name'] = 'result' + str(n)
else:
has_named_ret = True
return output_declarations
def remove_key_if_none(dictionary, key):
if key in dictionary.keys() and dictionary[key] is None:
del dictionary[key]
return dictionary
return [remove_key_if_none(decl._asdict(), 'buffers')
for decl in output_declarations]
def format_yaml(data):

11
mypy-README.md Normal file
View File

@ -0,0 +1,11 @@
### Optional type checking with mypy
mypy is an optional static typechecker that works with Python 3.
To use it, install the following dependencies:
```bash
# Install dependencies
pip install mypy mypy-extensions
# Run type checker in the pytorch/ directory
mypy @mypy-files.txt
```

1
mypy-files.txt Normal file
View File

@ -0,0 +1 @@
aten/src/ATen/function_wrapper.py

2
mypy.ini Normal file
View File

@ -0,0 +1,2 @@
[mypy]
python_version = 2.7

View File

@ -1 +1,2 @@
pyyaml
typing