Initial type hints for function_wrapper (#4947)

* Initial type hints for function_wrapper

* Don't break python 2

* Update TopEnvironment

* Add mypy check to travis

* Add .mypy_cache to .gitignore
This commit is contained in:
Richard Zou
2018-02-08 13:52:31 -05:00
committed by Edward Z. Yang
parent 696db00bcd
commit 0629785645
3 changed files with 227 additions and 20 deletions

1
.gitignore vendored
View File

@ -40,6 +40,7 @@ test/.coverage
*/**/*.dylib*
test/data/legacy_serialized.pt
test/data/linear.pt
.mypy_cache
# IPython notebook checkpoints
.ipynb_checkpoints

View File

@ -12,7 +12,12 @@ sudo: false
matrix:
fast_finish: true
include:
env: LINT_CHECK
- env: LINT_CHECK
python: "2.7"
install: pip install flake8
script: flake8
# mypy will complain about various files. Limiting it to only typed files
- env: MYPY_TYPE_CHECK
python: "3.6"
install: pip install mypy mypy-extensions
script: mypy --py2 aten/src/ATen/function_wrapper.py

View File

@ -5,6 +5,14 @@ import re
from collections import OrderedDict
from code_template import CodeTemplate
try:
from typing import Any, Dict, List, Generic, Optional, Set, Tuple, \
Union, TypeVar
from mypy_extensions import TypedDict
TYPE_HINTS = True
except ImportError:
TYPE_HINTS = False
import sys
if sys.version_info[0] == 3:
string_type = str
@ -274,18 +282,168 @@ class nested_dict(object):
return r
return self.parent[x]
if TYPE_HINTS:
    # All TypedDict declarations live under the TYPE_HINTS guard so the
    # generator keeps working on interpreters where typing/mypy_extensions
    # are unavailable (the import is wrapped in try/except above).

    # Substitution environment for one concrete (Backend, ScalarType) pair
    # used when generating the derived Type implementations.
    # NOTE: a duplicate 'THTensor' entry was removed here; a dict literal
    # with a repeated key silently keeps only the last one anyway.
    Environment = TypedDict('Environment', {
        'ScalarName': str,
        'THTensor': str,
        'THType': str,
        'Backend': str,
        'AccScalarName': str,
    })

    # Accumulators for top-level generated code (declaration/definition
    # snippets that get spliced into the ATen code templates).
    TopEnvironment = TypedDict('TopEnvironment', {
        'type_registrations': List[str],
        'type_headers': List[str],
        'type_method_declarations': List[str],
        'type_method_definitions': List[str],
        'type_method_inline_definitions': List[str],
        'tensor_method_declarations': List[str],
        'tensor_method_definitions': List[str],
        'function_declarations': List[str],
        'function_definitions': List[str],
        'type_ids': List[str],
        'native_function_declarations': List[str],
    })

    # A Declarations.cwrap formal argument
    # type can contain THTensor* types
    THFormal = TypedDict('THFormal', {
        'name': str,
        'type': str,
        'dynamic_type': str,
        'kwarg_only': bool,
        'is_nullable': bool,
        'default': str,
        'default_init': str,
        'python_default_init': str,
        'output': bool,
        'size': int,
        'declared_type': str,
        'ignore_check': bool,
        'allocate': bool,
        'mask': bool,
        'if_true': bool,
        'if_false': bool,
        'wrap_dim': str,
        # Broadcast is originally a str but gets unwrapped to a List or Dict in-place
        'broadcast': Any,
        'resize': str,
        'cpu_zero': bool,
        'zero': bool,
    }, total=False)

    # A native_functions.yaml formal argument
    # type can contain Tensor, BoolTensor, IndexTensor types
    NativeFormal = TypedDict('NativeFormal', {
        'name': str,
        'type': str,
        'dynamic_type': str,
        'kwarg_only': bool,
        'is_nullable': bool,
        'default': str,
        'default_init': str,
        'python_default_init': str,
        'output': bool,
        'size': int,
    }, total=False)

    # Generic ATen formal.
    # type can contain Tensor& reference types.
    AtFormal = TypedDict('AtFormal', {
        'name': str,
        'type': str,
        'dynamic_type': str,
        'kwarg_only': bool,
        'is_nullable': bool,
        'default': str,
        'default_init': str,
        'python_default_init': str,
        'output': bool,
        'size': int,
    }, total=False)

    # One entry of a function's return-type list.
    ReturnType = TypedDict('ReturnType', {
        'name': str,
        'type': str,
        'dynamic_type': str,
    }, total=False)

    # The raw 'return' declaration of a cwrap option.
    ReturnDecl = TypedDict('ReturnDecl', {
        'kind': str,
        'type': str,
        'arguments': List[int],
    }, total=False)

    # Represents a buffer in nn.yaml
    NNBuffer = TypedDict('NNBuffer', {
        'name': str,
    })

    # A single overload/option of a declared function, plus all the derived
    # fields the generator attaches to it while processing.
    # NOTE: the misspelled 'when_spares_dispatch' key was removed — the
    # correctly spelled 'when_sparse_dispatch' below is the one the
    # generator actually reads, and total=False makes the removal safe.
    FunctionOption = TypedDict('FunctionOption', {
        'arguments': List[THFormal],
        'mode': str,
        'name': str,
        'return': ReturnDecl,
        'variants': str,
        'type_method_definition_dispatch': str,
        'cname': str,
        'backends': List[str],
        'api_name': str,
        'backend_type_pairs': List[Tuple[str, str]],
        'inplace': bool,
        'aten_dense_sparse': bool,
        'sparse': bool,
        'scalar_check': str,
        'aten_custom_call': str,
        'type_definition_body': List[str],
        # cimpls is really a List[FunctionOption]
        'cimpls': List[Any],
        'actuals': List[str],
        'buffers': List[NNBuffer],
        'zero_dim_dispatch_when_scalar': str,
        'zero_dim_tensor_only': bool,
        'when_sparse_dispatch': str,
        'formals_list': List[AtFormal],
        'condition': str,
        'auto_gpu': bool,
        'cpu_half': bool,
        # options should be List[FunctionOption]
        'options': Any,
        'formals': List[str],
        'formals_with_defaults': List[str],
        'returns': List[ReturnType],
        'return_type': str,
        'return_call': str,
        'method_formals': List[str],
        'method_formals_with_defaults': List[str],
        'method_actuals': List[str],
        'const_mark': str,
        'method_prefix_derived': str,
        'broadcast_actuals': List[str],
        'broadcast_returns': List[str],
        'inferred_type': str,
        'broadcast_function': str,
        'broadcast_modified_actuals': List[str],
        'native_type_method_dispatch': str,
    })
def is_real_argument_to_wrapper(argument):
    # type: (THFormal) -> bool
    """An argument the generated wrapper actually accepts from the caller:
    not an output, and not a synthetic 'CONSTANT'/'argument' cwrap entry."""
    if argument.get('output', False):
        return False
    return argument['type'] not in ('CONSTANT', 'argument')
def is_mutable_formal_argument(argument, option):
    # type: (THFormal, FunctionOption) -> bool
    """A formal the call may write through: any output argument, or `self`
    when the option is an in-place variant."""
    out = argument.get('output')
    if out:
        return out
    return option['inplace'] and argument['name'] == 'self'
def to_return_type(arg, option):
# type: (THFormal, FunctionOption) -> ReturnType
t = arg['type']
rt = TYPE_RETURN.get(t, t)
if rt == 'Tensor' and not arg.get('allocate'):
@ -300,8 +458,10 @@ def to_return_type(arg, option):
def create_generic(top_env, declarations):
# type: (TopEnvironment, List[FunctionOption]) -> List[OrderedDict]
# translates defaults from cwrap types to C++ values
def translate_default(argument, type_str, default):
# type: (THFormal, str, Any) -> Any
if default is None:
# cause the default constructor for the object to run
return '{}'
@ -320,6 +480,7 @@ def create_generic(top_env, declarations):
# change from THTensor* to Tensor & so we get how it will appear
# in the aten argument list...
def translate_formal(argument, option):
# type: (THFormal, FunctionOption) -> AtFormal
type_str = TYPE_FORMAL_GENERIC.get(argument['type'], argument['type'])
if type_str == 'Tensor &' and not is_mutable_formal_argument(argument, option):
type_str = 'const ' + type_str
@ -327,7 +488,7 @@ def create_generic(top_env, declarations):
'name': argument['name'],
'type': type_str,
'dynamic_type': DYNAMIC_TYPE.get(argument['type'], argument['type']),
}
} # type: AtFormal
if 'kwarg_only' in argument:
translated['kwarg_only'] = argument['kwarg_only']
if 'default' in argument:
@ -347,11 +508,13 @@ def create_generic(top_env, declarations):
return translated
def get_formals(option, include_constants=False):
seen = set()
pos_args = []
kwd_args = []
# type: (FunctionOption, bool) -> List[AtFormal]
seen = set() # type: Set[str]
pos_args = [] # type: List[THFormal]
kwd_args = [] # type: List[THFormal]
def insert(argument):
# type: (THFormal) -> None
if argument['name'] not in seen:
seen.add(argument['name'])
if argument.get('kwarg_only', False):
@ -360,6 +523,7 @@ def create_generic(top_env, declarations):
pos_args.append(argument)
def has_output_mask(argument):
    # type: (THFormal) -> bool
    """True when the argument is both auto-allocated and masked."""
    allocate = argument.get('allocate', False)
    mask = argument.get('mask', False)
    return allocate and mask
for argument in option['arguments']:
@ -389,6 +553,7 @@ def create_generic(top_env, declarations):
return [translate_formal(argument, option) for argument in result]
def get_return_types(option):
# type: (FunctionOption) -> List[ReturnType]
ret = option['return']
if ret['kind'] == 'arguments':
argument_indices = ret['arguments']
@ -407,11 +572,13 @@ def create_generic(top_env, declarations):
raise Exception("format_return_type")
def format_return_type(return_types):
    # type: (List[ReturnType]) -> str
    """Render the C++ return type: the bare type for a single return,
    otherwise a std::tuple of all return types."""
    type_names = [r['type'] for r in return_types]
    if len(type_names) == 1:
        return type_names[0]
    return "std::tuple<{}>".format(','.join(type_names))
def find_dispatch_tensor(formals):
# type: (List[AtFormal]) -> Optional[str]
# dispatch to self if it's a parameter
for formal in formals:
if formal['name'] == 'self' and formal['dynamic_type'] == 'Tensor':
@ -423,9 +590,11 @@ def create_generic(top_env, declarations):
return None
def format_formal(f):
    # type: (AtFormal) -> str
    """Render a formal as a C++ parameter: '<type> <name>'."""
    return '%s %s' % (f['type'], f['name'])
def formal_with_default(f):
# type: (AtFormal) -> str
s = format_formal(f)
v = f.get('default')
if v is None:
@ -435,11 +604,15 @@ def create_generic(top_env, declarations):
return '{}={}'.format(s, v)
def get_broadcast_argument(option):
    # type: (FunctionOption) -> Optional[THFormal]
    """Return the first formal carrying a 'broadcast' spec, or None when the
    option does not broadcast."""
    return next(
        (arg for arg in option['arguments'] if arg.get('broadcast')),
        None)
def get_broadcast_actuals(broadcast_arg, broadcast_inplace, broadcast_dims):
# type: (THFormal, bool, bool) -> List[str]
# Note: broadcast_dims can change type...
# return the actuals that will be passed to the broadcast function.
# 1) in the common case, this is the broadcasted argument (e.g. "self") followed by the tensors
# that it is broadcasted against (comma-separated) (e.g. "self, tensor1, tensor2").
@ -451,14 +624,15 @@ def create_generic(top_env, declarations):
else:
broadcast_dims_spec = broadcast_arg['broadcast'].split()[1].split(':')[1].split(',')
# generate size call for each dimension
broadcast_dims = ([x.split('.')[0] + '.size(' + x.split('.')[1].replace('dim', '') + ')'
broadcast_dims = ([x.split('.')[0] + '.size(' + x.split('.')[1].replace('dim', '') + ')' # type: ignore
for x in broadcast_dims_spec])
broadcast_dims_init_list = '{' + ','.join(broadcast_dims) + '}'
broadcast_dims_init_list = '{' + ','.join(broadcast_dims) + '}' # type: ignore
broadcast_actuals = [broadcast_arg['name'], broadcast_dims_init_list]
return broadcast_actuals
def emit_nn_body(option):
# type: (FunctionOption) -> Union[str, List[str]]
# Concrete definition on Type.cpp for NN functions. Delegates to the
# xxx_forward variant after creating any necessary buffers.
actuals = option['actuals']
@ -468,7 +642,7 @@ def create_generic(top_env, declarations):
if len(option['buffers']) == 0:
return 'return {}({});'.format(fwd_name, ', '.join(actuals))
body = []
body = [] # type: List[str]
if option['api_name'].endswith('_out'):
# _out variants must create buffers and insert them in the
# arguments list between output and input arguments
@ -482,6 +656,7 @@ def create_generic(top_env, declarations):
return body
def process_option(option, output_options):
# type: (FunctionOption, List[OrderedDict]) -> None
option['inplace'] = re.search(
'(^__i|[^_]_$)', option['api_name']) is not None
@ -587,11 +762,13 @@ def create_generic(top_env, declarations):
]))
def native_get_formals(option, include_constants=False):
seen = set()
# type: (FunctionOption, bool) -> List[AtFormal]
seen = set() # type: Set[str]
pos_args = []
kwd_args = []
def insert(argument):
# type: (NativeFormal) -> None
if argument['name'] not in seen:
seen.add(argument['name'])
if argument.get('kwarg_only', False):
@ -605,6 +782,7 @@ def create_generic(top_env, declarations):
# not clear we need dynamic_type translation as we can specify the correct type
# directly in native functions
def add_type_as_dynamic_type(argument, option):
    # type: (NativeFormal, FunctionOption) -> NativeFormal
    """Mirror the declared 'type' into 'dynamic_type' in place and return
    the same argument dict."""
    argument.update(dynamic_type=argument['type'])
    return argument
@ -613,7 +791,9 @@ def create_generic(top_env, declarations):
# ensure we get reference-type formals when appropriate
def native_translate_formals(argument, option):
# type: (NativeFormal, FunctionOption) -> AtFormal
def translate_map(const):
# type: (bool) -> Dict[str, str]
return {
'Tensor': 'const Tensor &' if const else 'Tensor &',
'BoolTensor': 'const Tensor &' if const else 'Tensor &',
@ -632,9 +812,10 @@ def create_generic(top_env, declarations):
# this can return multiple return types in a list, e.g. ['Tensor', 'Tensor']
def native_get_return_types(option):
# type: (FunctionOption) -> List[ReturnType]
ret = option['return']
return_types = []
return_types = [] # type: List[ReturnType]
for t_raw in ret:
if isinstance(t_raw, string_type):
t = t_raw
@ -653,7 +834,7 @@ def create_generic(top_env, declarations):
rtype = {
'type': actual_return_type,
'dynamic_type': t,
}
} # type: ReturnType
if name is not None:
rtype['name'] = name
return_types.append(rtype)
@ -661,6 +842,7 @@ def create_generic(top_env, declarations):
return return_types
def process_native(option, output_options):
# type: (FunctionOption, List[OrderedDict]) -> None
option['inplace'] = re.search(
'(^__i|[^_]_$)', option['api_name']) is not None
@ -721,7 +903,7 @@ def create_generic(top_env, declarations):
# generate the at::native function declarations (i.e. what the user will implement)
if isinstance(dispatch, dict):
generated_native_functions = []
generated_native_functions = [] # type: List[str]
for key in sorted(dispatch.keys()):
value = dispatch[key]
if value not in generated_native_functions:
@ -761,9 +943,9 @@ def create_generic(top_env, declarations):
('abstract', abstract),
]))
output_declarations = []
output_declarations = [] # type: List[OrderedDict]
for declaration in declarations:
output_options = []
output_options = [] # type: List[OrderedDict]
for option in declaration['options']:
try:
if option['mode'] != 'native':
@ -777,6 +959,7 @@ def create_generic(top_env, declarations):
def create_derived(backend_type_env, declarations):
# type: (Environment, List[FunctionOption]) -> Tuple[List[str], List[str]]
type_object_declarations = []
type_object_definitions = []
@ -785,21 +968,26 @@ def create_derived(backend_type_env, declarations):
real_is_half = backend_type_env['ScalarName'] == 'Half'
def replace_with_null(argument):
    # type: (THFormal) -> bool
    """Whether the argument should be passed as NULL: THGenerator* formals
    on the CUDA backend."""
    if argument['type'] != 'THGenerator*':
        return False
    return backend_type_env['Backend'] == 'CUDA'
def requires_checked_cast(argument):
    # type: (THFormal) -> bool
    """Whether the formal needs a checked cast before the TH call.
    IntList only needs one when it carries a fixed 'size'."""
    arg_type = argument['type']
    if arg_type == 'IntList':
        return 'size' in argument
    return arg_type in CHECKED_CAST
def nullable_argument(argument):
    # type: (THFormal) -> bool
    """Whether this formal is declared nullable in the cwrap metadata."""
    is_nullable = argument.get('is_nullable', False)
    return is_nullable
def bool_option_is_string(argument):
    # type: (THFormal) -> bool
    """True when the bool option's 'if_true' value is present and is a
    string (rather than a C boolean)."""
    if 'if_true' not in argument:
        return False
    return isinstance(argument['if_true'], string_type)
def get_argument(argument, option):
# type: (THFormal, FunctionOption) -> str
if replace_with_null(argument):
return 'NULL'
elif requires_checked_cast(argument):
@ -834,14 +1022,17 @@ def create_derived(backend_type_env, declarations):
return argument['name']
def drop_argument(argument, option):
    # type: (THFormal, FunctionOption) -> bool
    """Whether to omit this formal from the backend call: TH-mode
    THGenerator* arguments are dropped on CUDA backends."""
    if 'CUDA' not in backend_type_env['Backend']:
        return False
    return option['mode'] == 'TH' and argument['type'] == 'THGenerator*'
def get_arguments(arguments, option):
    # type: (List[THFormal], FunctionOption) -> List[str]
    """Translate every formal that is not dropped for this backend into
    its C argument expression."""
    actuals = []
    for argument in arguments:
        if not drop_argument(argument, option):
            actuals.append(get_argument(argument, option))
    return actuals
def is_actual_return_long(ret):
# type: (ReturnDecl) -> bool
if ret['type'] == 'long':
return True
if ret['type'] == 'real':
@ -851,9 +1042,11 @@ def create_derived(backend_type_env, declarations):
return False
def get_zero_dim_dispatch_when_scalar(option):
return option.get('zero_dim_dispatch_when_scalar', False)
# type: (FunctionOption) -> str
return option.get('zero_dim_dispatch_when_scalar', False) # type: ignore
def handle_zero_dim(env, option):
# type: (Environment, FunctionOption) -> List[str]
zero_dim_dispatch = get_zero_dim_dispatch_when_scalar(option)
if not zero_dim_dispatch:
return []
@ -863,6 +1056,7 @@ def create_derived(backend_type_env, declarations):
return [ZERO_DIM_CHECK.substitute(env, check_name=zero_dim_dispatch, zero_dim_actuals=zero_dim_actuals)]
def handle_only_zero_dim(env, option):
# type: (Environment, FunctionOption) -> List[str]
if option.get('zero_dim_tensor_only', False):
check_name = get_zero_dim_dispatch_when_scalar(option)
return [ZERO_DIM_ONLY.substitute(env, check_name=check_name)]
@ -870,6 +1064,7 @@ def create_derived(backend_type_env, declarations):
return None
def handle_sparse(env, option):
# type: (Environment, FunctionOption) -> List[str]
if 'when_sparse_dispatch' not in option or 'Sparse' in backend_type_env['Backend']:
return []
check_name = option['when_sparse_dispatch']
@ -879,6 +1074,7 @@ def create_derived(backend_type_env, declarations):
return [SPARSE_CHECK.substitute(env, check_name=check_name, sparse_actuals=sparse_actuals)]
def allocate_arg(env, arg, output_count):
# type: (Environment, THFormal, int) -> List[str]
name = arg['name']
allocation = CodeTemplate(ALLOC_WRAP[arg['type']]).substitute(env, arguments=[])
tensor_arg = '{}_'.format(name)
@ -892,6 +1088,7 @@ def create_derived(backend_type_env, declarations):
]
def resize_arg(arg):
# type: (THFormal) -> str
resize = arg['resize']
if isinstance(resize, str):
return "{}.resize_({}.sizes());".format(arg['name'], resize)
@ -904,6 +1101,7 @@ def create_derived(backend_type_env, declarations):
return "{}.resize_({{ {} }});".format(arg['name'], ','.join(dims))
def handle_call(env, option, cimpl):
# type: (Environment, FunctionOption, FunctionOption) -> str
is_nn = option['mode'] == 'NN'
actuals = get_arguments(cimpl['arguments'], option)
if is_cuda or is_nn:
@ -926,7 +1124,8 @@ def create_derived(backend_type_env, declarations):
return call
def emit_body(env, option):
body = []
# type: (Environment, FunctionOption) -> List[str]
body = [] # type: List[str]
body += handle_sparse(env, option)
body += handle_zero_dim(env, option)
only_zero_dim_check = handle_only_zero_dim(env, option)
@ -937,8 +1136,8 @@ def create_derived(backend_type_env, declarations):
# arguments are potentially duplicated because of one argument
# referencing another
seen_names = set()
seen_tensorlists = set()
seen_names = set() # type: Set[str]
seen_tensorlists = set() # type: Set[str]
count = 0
output_count = 0
@ -1064,7 +1263,7 @@ def create_derived(backend_type_env, declarations):
scalar_check = 'maybe_scalar'
for arg in arguments:
scalar_check_arg = (scalar_check if not isinstance(scalar_check, dict)
else scalar_check.get(arg['name']))
else scalar_check.get(arg['name'])) # type: ignore
if scalar_check_arg is not None:
stmt = "{}_->maybeScalar({});".format(arg['name'], scalar_check_arg)
if nullable_argument(arg):
@ -1114,11 +1313,12 @@ def create_derived(backend_type_env, declarations):
return body
def process_option(option):
# type: (FunctionOption) -> None
pair = (backend_type_env['Backend'],
backend_type_env['ScalarName'])
if pair in option['backend_type_pairs']:
env = nested_dict(option, backend_type_env)
body = emit_body(env, option)
body = emit_body(env, option) # type: ignore
option['type_definition_body'] = body
type_object_declarations.append(
TYPE_DERIVED_DECLARATION.substitute(env))
@ -1126,6 +1326,7 @@ def create_derived(backend_type_env, declarations):
TYPE_DERIVED_DEFINITION.substitute(env))
def process_native(option):
# type: (FunctionOption) -> None
dispatch = option['type_method_definition_dispatch']
env = nested_dict(option, backend_type_env)