Initial type hints for function_wrapper (#4947)

* Initial type hints for function_wrapper

* Don't break python 2

* Update TopEnvironment

* Add mypy check to travis

* Add .mypy_cache to .gitignore
This commit is contained in:
Richard Zou
2018-02-08 13:52:31 -05:00
committed by Edward Z. Yang
parent 696db00bcd
commit 0629785645
3 changed files with 227 additions and 20 deletions

1
.gitignore vendored
View File

@ -40,6 +40,7 @@ test/.coverage
*/**/*.dylib* */**/*.dylib*
test/data/legacy_serialized.pt test/data/legacy_serialized.pt
test/data/linear.pt test/data/linear.pt
.mypy_cache
# IPython notebook checkpoints # IPython notebook checkpoints
.ipynb_checkpoints .ipynb_checkpoints

View File

@ -12,7 +12,12 @@ sudo: false
matrix: matrix:
fast_finish: true fast_finish: true
include: include:
env: LINT_CHECK - env: LINT_CHECK
python: "2.7" python: "2.7"
install: pip install flake8 install: pip install flake8
script: flake8 script: flake8
# mypy will complain about various files. Limiting it to only typed files
- env: MYPY_TYPE_CHECK
python: "3.6"
install: pip install mypy mypy-extensions
script: mypy --py2 aten/src/ATen/function_wrapper.py

View File

@ -5,6 +5,14 @@ import re
from collections import OrderedDict from collections import OrderedDict
from code_template import CodeTemplate from code_template import CodeTemplate
try:
from typing import Any, Dict, List, Generic, Optional, Set, Tuple, \
Union, TypeVar
from mypy_extensions import TypedDict
TYPE_HINTS = True
except ImportError:
TYPE_HINTS = False
import sys import sys
if sys.version_info[0] == 3: if sys.version_info[0] == 3:
string_type = str string_type = str
@ -274,18 +282,168 @@ class nested_dict(object):
return r return r
return self.parent[x] return self.parent[x]
if TYPE_HINTS:
Environment = TypedDict('Environment', {
'ScalarName': str,
'THTensor': str,
'THType': str,
'THTensor': str,
'Backend': str,
'AccScalarName': str,
})
TopEnvironment = TypedDict('TopEnvironment', {
'type_registrations': List[str],
'type_headers': List[str],
'type_method_declarations': List[str],
'type_method_definitions': List[str],
'type_method_inline_definitions': List[str],
'tensor_method_declarations': List[str],
'tensor_method_definitions': List[str],
'function_declarations': List[str],
'function_definitions': List[str],
'type_ids': List[str],
'native_function_declarations': List[str],
})
# A Declarations.cwrap formal argument
# type can contain THTensor* types
THFormal = TypedDict('THFormal', {
'name': str,
'type': str,
'dynamic_type': str,
'kwarg_only': bool,
'is_nullable': bool,
'default': str,
'default_init': str,
'python_default_init': str,
'output': bool,
'size': int,
'declared_type': str,
'ignore_check': bool,
'allocate': bool,
'mask': bool,
'if_true': bool,
'if_false': bool,
'wrap_dim': str,
# Broadcast is originally a str but gets unwrapped to a List or Dict in-place
'broadcast': Any,
'resize': str,
'cpu_zero': bool,
'zero': bool,
}, total=False)
# A native_functions.yaml formal argument
# type can contain Tensor, BoolTensor, IndexTensor types
NativeFormal = TypedDict('NativeFormal', {
'name': str,
'type': str,
'dynamic_type': str,
'kwarg_only': bool,
'is_nullable': bool,
'default': str,
'default_init': str,
'python_default_init': str,
'output': bool,
'size': int,
}, total=False)
# Generic ATen formal.
# type can contain Tensor& reference types.
AtFormal = TypedDict('AtFormal', {
'name': str,
'type': str,
'dynamic_type': str,
'kwarg_only': bool,
'is_nullable': bool,
'default': str,
'default_init': str,
'python_default_init': str,
'output': bool,
'size': int,
}, total=False)
ReturnType = TypedDict('ReturnType', {
'name': str,
'type': str,
'dynamic_type': str,
}, total=False)
ReturnDecl = TypedDict('ReturnDecl', {
'kind': str,
'type': str,
'arguments': List[int],
}, total=False)
# Represents a buffer in nn.yaml
NNBuffer = TypedDict('NNBuffer', {
'name': str,
})
FunctionOption = TypedDict('FunctionOption', {
'arguments': List[THFormal],
'mode': str,
'name': str,
'return': ReturnDecl,
'variants': str,
'type_method_definition_dispatch': str,
'cname': str,
'backends': List[str],
'api_name': str,
'backend_type_pairs': List[Tuple[str, str]],
'inplace': bool,
'aten_dense_sparse': bool,
'sparse': bool,
'scalar_check': str,
'aten_custom_call': str,
'type_definition_body': List[str],
# cimpls is really a List[FunctionOption]
'cimpls': List[Any],
'when_spares_dispatch': str,
'actuals': List[str],
'buffers': List[NNBuffer],
'zero_dim_dispatch_when_scalar': str,
'zero_dim_tensor_only': bool,
'when_sparse_dispatch': str,
'formals_list': List[AtFormal],
'condition': str,
'auto_gpu': bool,
'cpu_half': bool,
# options should be List[FunctionOption]
'options': Any,
'formals': List[str],
'formals_with_defaults': List[str],
'returns': List[ReturnType],
'return_type': str,
'return_call': str,
'method_formals': List[str],
'method_formals_with_defaults': List[str],
'method_actuals': List[str],
'const_mark': str,
'method_prefix_derived': str,
'broadcast_actuals': List[str],
'broadcast_returns': List[str],
'inferred_type': str,
'broadcast_function': str,
'broadcast_modified_actuals': List[str],
'native_type_method_dispatch': str,
})
def is_real_argument_to_wrapper(argument): def is_real_argument_to_wrapper(argument):
# type: (THFormal) -> bool
return not argument.get('output', False) and\ return not argument.get('output', False) and\
argument['type'] != 'CONSTANT' and\ argument['type'] != 'CONSTANT' and\
argument['type'] != 'argument' argument['type'] != 'argument'
def is_mutable_formal_argument(argument, option): def is_mutable_formal_argument(argument, option):
# type: (THFormal, FunctionOption) -> bool
return argument.get('output') or option['inplace'] and argument['name'] == 'self' return argument.get('output') or option['inplace'] and argument['name'] == 'self'
def to_return_type(arg, option): def to_return_type(arg, option):
# type: (THFormal, FunctionOption) -> ReturnType
t = arg['type'] t = arg['type']
rt = TYPE_RETURN.get(t, t) rt = TYPE_RETURN.get(t, t)
if rt == 'Tensor' and not arg.get('allocate'): if rt == 'Tensor' and not arg.get('allocate'):
@ -300,8 +458,10 @@ def to_return_type(arg, option):
def create_generic(top_env, declarations): def create_generic(top_env, declarations):
# type: (TopEnvironment, List[FunctionOption]) -> List[OrderedDict]
# translates defaults from cwrap types to C++ values # translates defaults from cwrap types to C++ values
def translate_default(argument, type_str, default): def translate_default(argument, type_str, default):
# type: (THFormal, str, Any) -> Any
if default is None: if default is None:
# cause the default constructor for the object to run # cause the default constructor for the object to run
return '{}' return '{}'
@ -320,6 +480,7 @@ def create_generic(top_env, declarations):
# change from THTensor* to Tensor & so we get how it will appear # change from THTensor* to Tensor & so we get how it will appear
# in the aten argument list... # in the aten argument list...
def translate_formal(argument, option): def translate_formal(argument, option):
# type: (THFormal, FunctionOption) -> AtFormal
type_str = TYPE_FORMAL_GENERIC.get(argument['type'], argument['type']) type_str = TYPE_FORMAL_GENERIC.get(argument['type'], argument['type'])
if type_str == 'Tensor &' and not is_mutable_formal_argument(argument, option): if type_str == 'Tensor &' and not is_mutable_formal_argument(argument, option):
type_str = 'const ' + type_str type_str = 'const ' + type_str
@ -327,7 +488,7 @@ def create_generic(top_env, declarations):
'name': argument['name'], 'name': argument['name'],
'type': type_str, 'type': type_str,
'dynamic_type': DYNAMIC_TYPE.get(argument['type'], argument['type']), 'dynamic_type': DYNAMIC_TYPE.get(argument['type'], argument['type']),
} } # type: AtFormal
if 'kwarg_only' in argument: if 'kwarg_only' in argument:
translated['kwarg_only'] = argument['kwarg_only'] translated['kwarg_only'] = argument['kwarg_only']
if 'default' in argument: if 'default' in argument:
@ -347,11 +508,13 @@ def create_generic(top_env, declarations):
return translated return translated
def get_formals(option, include_constants=False): def get_formals(option, include_constants=False):
seen = set() # type: (FunctionOption, bool) -> List[AtFormal]
pos_args = [] seen = set() # type: Set[str]
kwd_args = [] pos_args = [] # type: List[THFormal]
kwd_args = [] # type: List[THFormal]
def insert(argument): def insert(argument):
# type: (THFormal) -> None
if argument['name'] not in seen: if argument['name'] not in seen:
seen.add(argument['name']) seen.add(argument['name'])
if argument.get('kwarg_only', False): if argument.get('kwarg_only', False):
@ -360,6 +523,7 @@ def create_generic(top_env, declarations):
pos_args.append(argument) pos_args.append(argument)
def has_output_mask(argument): def has_output_mask(argument):
# type: (THFormal) -> bool
return argument.get('allocate', False) and argument.get('mask', False) return argument.get('allocate', False) and argument.get('mask', False)
for argument in option['arguments']: for argument in option['arguments']:
@ -389,6 +553,7 @@ def create_generic(top_env, declarations):
return [translate_formal(argument, option) for argument in result] return [translate_formal(argument, option) for argument in result]
def get_return_types(option): def get_return_types(option):
# type: (FunctionOption) -> List[ReturnType]
ret = option['return'] ret = option['return']
if ret['kind'] == 'arguments': if ret['kind'] == 'arguments':
argument_indices = ret['arguments'] argument_indices = ret['arguments']
@ -407,11 +572,13 @@ def create_generic(top_env, declarations):
raise Exception("format_return_type") raise Exception("format_return_type")
def format_return_type(return_types): def format_return_type(return_types):
# type: (List[ReturnType]) -> str
if len(return_types) == 1: if len(return_types) == 1:
return return_types[0]['type'] return return_types[0]['type']
return "std::tuple<{}>".format(','.join(r['type'] for r in return_types)) return "std::tuple<{}>".format(','.join(r['type'] for r in return_types))
def find_dispatch_tensor(formals): def find_dispatch_tensor(formals):
# type: (List[AtFormal]) -> Optional[str]
# dispatch to self if it's a parameter # dispatch to self if it's a parameter
for formal in formals: for formal in formals:
if formal['name'] == 'self' and formal['dynamic_type'] == 'Tensor': if formal['name'] == 'self' and formal['dynamic_type'] == 'Tensor':
@ -423,9 +590,11 @@ def create_generic(top_env, declarations):
return None return None
def format_formal(f): def format_formal(f):
# type: (AtFormal) -> str
return '{} {}'.format(f['type'], f['name']) return '{} {}'.format(f['type'], f['name'])
def formal_with_default(f): def formal_with_default(f):
# type: (AtFormal) -> str
s = format_formal(f) s = format_formal(f)
v = f.get('default') v = f.get('default')
if v is None: if v is None:
@ -435,11 +604,15 @@ def create_generic(top_env, declarations):
return '{}={}'.format(s, v) return '{}={}'.format(s, v)
def get_broadcast_argument(option): def get_broadcast_argument(option):
# type: (FunctionOption) -> Optional[THFormal]
for argument in option['arguments']: for argument in option['arguments']:
if argument.get('broadcast'): if argument.get('broadcast'):
return argument return argument
return None
def get_broadcast_actuals(broadcast_arg, broadcast_inplace, broadcast_dims): def get_broadcast_actuals(broadcast_arg, broadcast_inplace, broadcast_dims):
# type: (THFormal, bool, bool) -> List[str]
# Note: broadcast_dims can change type...
# return the actuals that will be passed to the broadcast function. # return the actuals that will be passed to the broadcast function.
# 1) in the common case, this is the broadcasted argument (e.g. "self") followed by the tensors # 1) in the common case, this is the broadcasted argument (e.g. "self") followed by the tensors
# that it is broadcasted against (comma-separated) (e.g. "self, tensor1, tensor2"). # that it is broadcasted against (comma-separated) (e.g. "self, tensor1, tensor2").
@ -451,14 +624,15 @@ def create_generic(top_env, declarations):
else: else:
broadcast_dims_spec = broadcast_arg['broadcast'].split()[1].split(':')[1].split(',') broadcast_dims_spec = broadcast_arg['broadcast'].split()[1].split(':')[1].split(',')
# generate size call for each dimension # generate size call for each dimension
broadcast_dims = ([x.split('.')[0] + '.size(' + x.split('.')[1].replace('dim', '') + ')' broadcast_dims = ([x.split('.')[0] + '.size(' + x.split('.')[1].replace('dim', '') + ')' # type: ignore
for x in broadcast_dims_spec]) for x in broadcast_dims_spec])
broadcast_dims_init_list = '{' + ','.join(broadcast_dims) + '}' broadcast_dims_init_list = '{' + ','.join(broadcast_dims) + '}' # type: ignore
broadcast_actuals = [broadcast_arg['name'], broadcast_dims_init_list] broadcast_actuals = [broadcast_arg['name'], broadcast_dims_init_list]
return broadcast_actuals return broadcast_actuals
def emit_nn_body(option): def emit_nn_body(option):
# type: (FunctionOption) -> Union[str, List[str]]
# Concrete definition on Type.cpp for NN functions. Delegates to the # Concrete definition on Type.cpp for NN functions. Delegates to the
# xxx_forward variant variant after creating any necessary buffers. # xxx_forward variant variant after creating any necessary buffers.
actuals = option['actuals'] actuals = option['actuals']
@ -468,7 +642,7 @@ def create_generic(top_env, declarations):
if len(option['buffers']) == 0: if len(option['buffers']) == 0:
return 'return {}({});'.format(fwd_name, ', '.join(actuals)) return 'return {}({});'.format(fwd_name, ', '.join(actuals))
body = [] body = [] # type: List[str]
if option['api_name'].endswith('_out'): if option['api_name'].endswith('_out'):
# _out variants must create buffers and insert them in the # _out variants must create buffers and insert them in the
# arguments list between output and input arguments # arguments list between output and input arguments
@ -482,6 +656,7 @@ def create_generic(top_env, declarations):
return body return body
def process_option(option, output_options): def process_option(option, output_options):
# type: (FunctionOption, List[OrderedDict]) -> None
option['inplace'] = re.search( option['inplace'] = re.search(
'(^__i|[^_]_$)', option['api_name']) is not None '(^__i|[^_]_$)', option['api_name']) is not None
@ -587,11 +762,13 @@ def create_generic(top_env, declarations):
])) ]))
def native_get_formals(option, include_constants=False): def native_get_formals(option, include_constants=False):
seen = set() # type: (FunctionOption, bool) -> List[AtFormal]
seen = set() # type: Set[str]
pos_args = [] pos_args = []
kwd_args = [] kwd_args = []
def insert(argument): def insert(argument):
# type: (NativeFormal) -> None
if argument['name'] not in seen: if argument['name'] not in seen:
seen.add(argument['name']) seen.add(argument['name'])
if argument.get('kwarg_only', False): if argument.get('kwarg_only', False):
@ -605,6 +782,7 @@ def create_generic(top_env, declarations):
# not clear we need dynamic_type translation as we can specify the correct type # not clear we need dynamic_type translation as we can specify the correct type
# directly in native functions # directly in native functions
def add_type_as_dynamic_type(argument, option): def add_type_as_dynamic_type(argument, option):
# type: (NativeFormal, FunctionOption) -> NativeFormal
argument['dynamic_type'] = argument['type'] argument['dynamic_type'] = argument['type']
return argument return argument
@ -613,7 +791,9 @@ def create_generic(top_env, declarations):
# ensure we get reference-type formals when appropriate # ensure we get reference-type formals when appropriate
def native_translate_formals(argument, option): def native_translate_formals(argument, option):
# type: (NativeFormal, FunctionOption) -> AtFormal
def translate_map(const): def translate_map(const):
# type: (bool) -> Dict[str, str]
return { return {
'Tensor': 'const Tensor &' if const else 'Tensor &', 'Tensor': 'const Tensor &' if const else 'Tensor &',
'BoolTensor': 'const Tensor &' if const else 'Tensor &', 'BoolTensor': 'const Tensor &' if const else 'Tensor &',
@ -632,9 +812,10 @@ def create_generic(top_env, declarations):
# this can return multiple return types in a list, e.g. ['Tensor', 'Tensor'] # this can return multiple return types in a list, e.g. ['Tensor', 'Tensor']
def native_get_return_types(option): def native_get_return_types(option):
# type: (FunctionOption) -> List[ReturnType]
ret = option['return'] ret = option['return']
return_types = [] return_types = [] # List[ReturnType]
for t_raw in ret: for t_raw in ret:
if isinstance(t_raw, string_type): if isinstance(t_raw, string_type):
t = t_raw t = t_raw
@ -653,7 +834,7 @@ def create_generic(top_env, declarations):
rtype = { rtype = {
'type': actual_return_type, 'type': actual_return_type,
'dynamic_type': t, 'dynamic_type': t,
} } # type: ReturnType
if name is not None: if name is not None:
rtype['name'] = name rtype['name'] = name
return_types.append(rtype) return_types.append(rtype)
@ -661,6 +842,7 @@ def create_generic(top_env, declarations):
return return_types return return_types
def process_native(option, output_options): def process_native(option, output_options):
# type: (FunctionOption, List[OrderedDict]) -> None
option['inplace'] = re.search( option['inplace'] = re.search(
'(^__i|[^_]_$)', option['api_name']) is not None '(^__i|[^_]_$)', option['api_name']) is not None
@ -721,7 +903,7 @@ def create_generic(top_env, declarations):
# generate the at::native function declarations (i.e. what the user will implement) # generate the at::native function declarations (i.e. what the user will implement)
if isinstance(dispatch, dict): if isinstance(dispatch, dict):
generated_native_functions = [] generated_native_functions = [] # type: List[str]
for key in sorted(dispatch.keys()): for key in sorted(dispatch.keys()):
value = dispatch[key] value = dispatch[key]
if value not in generated_native_functions: if value not in generated_native_functions:
@ -761,9 +943,9 @@ def create_generic(top_env, declarations):
('abstract', abstract), ('abstract', abstract),
])) ]))
output_declarations = [] output_declarations = [] # type: List[OrderedDict]
for declaration in declarations: for declaration in declarations:
output_options = [] output_options = [] # type: List[OrderedDict]
for option in declaration['options']: for option in declaration['options']:
try: try:
if option['mode'] != 'native': if option['mode'] != 'native':
@ -777,6 +959,7 @@ def create_generic(top_env, declarations):
def create_derived(backend_type_env, declarations): def create_derived(backend_type_env, declarations):
# type: (Environment, List[FunctionOption]) -> Tuple[List[str], List[str]]
type_object_declarations = [] type_object_declarations = []
type_object_definitions = [] type_object_definitions = []
@ -785,21 +968,26 @@ def create_derived(backend_type_env, declarations):
real_is_half = backend_type_env['ScalarName'] == 'Half' real_is_half = backend_type_env['ScalarName'] == 'Half'
def replace_with_null(argument): def replace_with_null(argument):
# type: (THFormal) -> bool
return (argument['type'] == 'THGenerator*' and return (argument['type'] == 'THGenerator*' and
backend_type_env['Backend'] == 'CUDA') backend_type_env['Backend'] == 'CUDA')
def requires_checked_cast(argument): def requires_checked_cast(argument):
# type: (THFormal) -> bool
if argument['type'] == 'IntList': if argument['type'] == 'IntList':
return 'size' in argument return 'size' in argument
return argument['type'] in CHECKED_CAST return argument['type'] in CHECKED_CAST
def nullable_argument(argument): def nullable_argument(argument):
# type: (THFormal) -> bool
return argument.get('is_nullable', False) return argument.get('is_nullable', False)
def bool_option_is_string(argument): def bool_option_is_string(argument):
# type: (THFormal) -> bool
return 'if_true' in argument and isinstance(argument['if_true'], string_type) return 'if_true' in argument and isinstance(argument['if_true'], string_type)
def get_argument(argument, option): def get_argument(argument, option):
# type: (THFormal, FunctionOption) -> str
if replace_with_null(argument): if replace_with_null(argument):
return 'NULL' return 'NULL'
elif requires_checked_cast(argument): elif requires_checked_cast(argument):
@ -834,14 +1022,17 @@ def create_derived(backend_type_env, declarations):
return argument['name'] return argument['name']
def drop_argument(argument, option): def drop_argument(argument, option):
# type: (THFormal, FunctionOption) -> bool
return 'CUDA' in backend_type_env['Backend'] and ( return 'CUDA' in backend_type_env['Backend'] and (
option['mode'] == 'TH' and argument['type'] == 'THGenerator*') option['mode'] == 'TH' and argument['type'] == 'THGenerator*')
def get_arguments(arguments, option): def get_arguments(arguments, option):
# type: (List[THFormal], FunctionOption) -> List[str]
return [get_argument(argument, option) return [get_argument(argument, option)
for argument in arguments if not drop_argument(argument, option)] for argument in arguments if not drop_argument(argument, option)]
def is_actual_return_long(ret): def is_actual_return_long(ret):
# type: (ReturnDecl) -> bool
if ret['type'] == 'long': if ret['type'] == 'long':
return True return True
if ret['type'] == 'real': if ret['type'] == 'real':
@ -851,9 +1042,11 @@ def create_derived(backend_type_env, declarations):
return False return False
def get_zero_dim_dispatch_when_scalar(option): def get_zero_dim_dispatch_when_scalar(option):
return option.get('zero_dim_dispatch_when_scalar', False) # type: (FunctionOption) -> str
return option.get('zero_dim_dispatch_when_scalar', False) # type: ignore
def handle_zero_dim(env, option): def handle_zero_dim(env, option):
# type: (Environment, FunctionOption) -> List[str]
zero_dim_dispatch = get_zero_dim_dispatch_when_scalar(option) zero_dim_dispatch = get_zero_dim_dispatch_when_scalar(option)
if not zero_dim_dispatch: if not zero_dim_dispatch:
return [] return []
@ -863,6 +1056,7 @@ def create_derived(backend_type_env, declarations):
return [ZERO_DIM_CHECK.substitute(env, check_name=zero_dim_dispatch, zero_dim_actuals=zero_dim_actuals)] return [ZERO_DIM_CHECK.substitute(env, check_name=zero_dim_dispatch, zero_dim_actuals=zero_dim_actuals)]
def handle_only_zero_dim(env, option): def handle_only_zero_dim(env, option):
# type: (Environment, FunctionOption) -> List[str]
if option.get('zero_dim_tensor_only', False): if option.get('zero_dim_tensor_only', False):
check_name = get_zero_dim_dispatch_when_scalar(option) check_name = get_zero_dim_dispatch_when_scalar(option)
return [ZERO_DIM_ONLY.substitute(env, check_name=check_name)] return [ZERO_DIM_ONLY.substitute(env, check_name=check_name)]
@ -870,6 +1064,7 @@ def create_derived(backend_type_env, declarations):
return None return None
def handle_sparse(env, option): def handle_sparse(env, option):
# type: (Environment, FunctionOption) -> List[str]
if 'when_sparse_dispatch' not in option or 'Sparse' in backend_type_env['Backend']: if 'when_sparse_dispatch' not in option or 'Sparse' in backend_type_env['Backend']:
return [] return []
check_name = option['when_sparse_dispatch'] check_name = option['when_sparse_dispatch']
@ -879,6 +1074,7 @@ def create_derived(backend_type_env, declarations):
return [SPARSE_CHECK.substitute(env, check_name=check_name, sparse_actuals=sparse_actuals)] return [SPARSE_CHECK.substitute(env, check_name=check_name, sparse_actuals=sparse_actuals)]
def allocate_arg(env, arg, output_count): def allocate_arg(env, arg, output_count):
# type: (Environment, THFormal, int) -> List[str]
name = arg['name'] name = arg['name']
allocation = CodeTemplate(ALLOC_WRAP[arg['type']]).substitute(env, arguments=[]) allocation = CodeTemplate(ALLOC_WRAP[arg['type']]).substitute(env, arguments=[])
tensor_arg = '{}_'.format(name) tensor_arg = '{}_'.format(name)
@ -892,6 +1088,7 @@ def create_derived(backend_type_env, declarations):
] ]
def resize_arg(arg): def resize_arg(arg):
# type: (THFormal) -> str
resize = arg['resize'] resize = arg['resize']
if isinstance(resize, str): if isinstance(resize, str):
return "{}.resize_({}.sizes());".format(arg['name'], resize) return "{}.resize_({}.sizes());".format(arg['name'], resize)
@ -904,6 +1101,7 @@ def create_derived(backend_type_env, declarations):
return "{}.resize_({{ {} }});".format(arg['name'], ','.join(dims)) return "{}.resize_({{ {} }});".format(arg['name'], ','.join(dims))
def handle_call(env, option, cimpl): def handle_call(env, option, cimpl):
# type: (Environment, FunctionOption, FunctionOption) -> str
is_nn = option['mode'] == 'NN' is_nn = option['mode'] == 'NN'
actuals = get_arguments(cimpl['arguments'], option) actuals = get_arguments(cimpl['arguments'], option)
if is_cuda or is_nn: if is_cuda or is_nn:
@ -926,7 +1124,8 @@ def create_derived(backend_type_env, declarations):
return call return call
def emit_body(env, option): def emit_body(env, option):
body = [] # type: (Environment, FunctionOption) -> List[str]
body = [] # type: List[str]
body += handle_sparse(env, option) body += handle_sparse(env, option)
body += handle_zero_dim(env, option) body += handle_zero_dim(env, option)
only_zero_dim_check = handle_only_zero_dim(env, option) only_zero_dim_check = handle_only_zero_dim(env, option)
@ -937,8 +1136,8 @@ def create_derived(backend_type_env, declarations):
# arguments are potentially duplicated because of one argument # arguments are potentially duplicated because of one argument
# referencing another # referencing another
seen_names = set() seen_names = set() # type: Set[str]
seen_tensorlists = set() seen_tensorlists = set() # type: Set[str]
count = 0 count = 0
output_count = 0 output_count = 0
@ -1064,7 +1263,7 @@ def create_derived(backend_type_env, declarations):
scalar_check = 'maybe_scalar' scalar_check = 'maybe_scalar'
for arg in arguments: for arg in arguments:
scalar_check_arg = (scalar_check if not isinstance(scalar_check, dict) scalar_check_arg = (scalar_check if not isinstance(scalar_check, dict)
else scalar_check.get(arg['name'])) else scalar_check.get(arg['name'])) # type: ignore
if scalar_check_arg is not None: if scalar_check_arg is not None:
stmt = "{}_->maybeScalar({});".format(arg['name'], scalar_check_arg) stmt = "{}_->maybeScalar({});".format(arg['name'], scalar_check_arg)
if nullable_argument(arg): if nullable_argument(arg):
@ -1114,11 +1313,12 @@ def create_derived(backend_type_env, declarations):
return body return body
def process_option(option): def process_option(option):
# type: (FunctionOption) -> None
pair = (backend_type_env['Backend'], pair = (backend_type_env['Backend'],
backend_type_env['ScalarName']) backend_type_env['ScalarName'])
if pair in option['backend_type_pairs']: if pair in option['backend_type_pairs']:
env = nested_dict(option, backend_type_env) env = nested_dict(option, backend_type_env)
body = emit_body(env, option) body = emit_body(env, option) # type: ignore
option['type_definition_body'] = body option['type_definition_body'] = body
type_object_declarations.append( type_object_declarations.append(
TYPE_DERIVED_DECLARATION.substitute(env)) TYPE_DERIVED_DECLARATION.substitute(env))
@ -1126,6 +1326,7 @@ def create_derived(backend_type_env, declarations):
TYPE_DERIVED_DEFINITION.substitute(env)) TYPE_DERIVED_DEFINITION.substitute(env))
def process_native(option): def process_native(option):
# type: (FunctionOption) -> None
dispatch = option['type_method_definition_dispatch'] dispatch = option['type_method_definition_dispatch']
env = nested_dict(option, backend_type_env) env = nested_dict(option, backend_type_env)