diff --git a/.gitignore b/.gitignore index 324477cfac2d..a29491430690 100644 --- a/.gitignore +++ b/.gitignore @@ -40,6 +40,7 @@ test/.coverage */**/*.dylib* test/data/legacy_serialized.pt test/data/linear.pt +.mypy_cache # IPython notebook checkpoints .ipynb_checkpoints diff --git a/.travis.yml b/.travis.yml index f7e610c635b6..1c4a55907c28 100644 --- a/.travis.yml +++ b/.travis.yml @@ -12,7 +12,12 @@ sudo: false matrix: fast_finish: true include: - env: LINT_CHECK + - env: LINT_CHECK python: "2.7" install: pip install flake8 script: flake8 + # mypy will complain about various files. Limiting it to only typed files + - env: MYPY_TYPE_CHECK + python: "3.6" + install: pip install mypy mypy-extensions + script: mypy --py2 aten/src/ATen/function_wrapper.py diff --git a/aten/src/ATen/function_wrapper.py b/aten/src/ATen/function_wrapper.py index 9f213534bfad..912b87cc668e 100644 --- a/aten/src/ATen/function_wrapper.py +++ b/aten/src/ATen/function_wrapper.py @@ -5,6 +5,14 @@ import re from collections import OrderedDict from code_template import CodeTemplate +try: + from typing import Any, Dict, List, Generic, Optional, Set, Tuple, \ + Union, TypeVar + from mypy_extensions import TypedDict + TYPE_HINTS = True +except ImportError: + TYPE_HINTS = False + import sys if sys.version_info[0] == 3: string_type = str @@ -274,18 +282,168 @@ class nested_dict(object): return r return self.parent[x] +if TYPE_HINTS: + Environment = TypedDict('Environment', { + 'ScalarName': str, + 'THTensor': str, + 'THType': str, + 'THTensor': str, + 'Backend': str, + 'AccScalarName': str, + }) + + TopEnvironment = TypedDict('TopEnvironment', { + 'type_registrations': List[str], + 'type_headers': List[str], + 'type_method_declarations': List[str], + 'type_method_definitions': List[str], + 'type_method_inline_definitions': List[str], + 'tensor_method_declarations': List[str], + 'tensor_method_definitions': List[str], + 'function_declarations': List[str], + 'function_definitions': List[str], + 
'type_ids': List[str], + 'native_function_declarations': List[str], + }) + + # A Declarations.cwrap formal argument + # type can contain THTensor* types + THFormal = TypedDict('THFormal', { + 'name': str, + 'type': str, + 'dynamic_type': str, + 'kwarg_only': bool, + 'is_nullable': bool, + 'default': str, + 'default_init': str, + 'python_default_init': str, + 'output': bool, + 'size': int, + 'declared_type': str, + 'ignore_check': bool, + 'allocate': bool, + 'mask': bool, + 'if_true': bool, + 'if_false': bool, + 'wrap_dim': str, + # Broadcast is originally a str but gets unwrapped to a List or Dict in-place + 'broadcast': Any, + 'resize': str, + 'cpu_zero': bool, + 'zero': bool, + }, total=False) + + # A native_functions.yaml formal argument + # type can contain Tensor, BoolTensor, IndexTensor types + NativeFormal = TypedDict('NativeFormal', { + 'name': str, + 'type': str, + 'dynamic_type': str, + 'kwarg_only': bool, + 'is_nullable': bool, + 'default': str, + 'default_init': str, + 'python_default_init': str, + 'output': bool, + 'size': int, + }, total=False) + + # Generic ATen formal. + # type can contain Tensor& reference types. 
+ AtFormal = TypedDict('AtFormal', { + 'name': str, + 'type': str, + 'dynamic_type': str, + 'kwarg_only': bool, + 'is_nullable': bool, + 'default': str, + 'default_init': str, + 'python_default_init': str, + 'output': bool, + 'size': int, + }, total=False) + + ReturnType = TypedDict('ReturnType', { + 'name': str, + 'type': str, + 'dynamic_type': str, + }, total=False) + + ReturnDecl = TypedDict('ReturnDecl', { + 'kind': str, + 'type': str, + 'arguments': List[int], + }, total=False) + + # Represents a buffer in nn.yaml + NNBuffer = TypedDict('NNBuffer', { + 'name': str, + }) + + FunctionOption = TypedDict('FunctionOption', { + 'arguments': List[THFormal], + 'mode': str, + 'name': str, + 'return': ReturnDecl, + 'variants': str, + 'type_method_definition_dispatch': str, + 'cname': str, + 'backends': List[str], + 'api_name': str, + 'backend_type_pairs': List[Tuple[str, str]], + 'inplace': bool, + 'aten_dense_sparse': bool, + 'sparse': bool, + 'scalar_check': str, + 'aten_custom_call': str, + 'type_definition_body': List[str], + # cimpls is really a List[FunctionOption] + 'cimpls': List[Any], + 'when_spares_dispatch': str, + 'actuals': List[str], + 'buffers': List[NNBuffer], + 'zero_dim_dispatch_when_scalar': str, + 'zero_dim_tensor_only': bool, + 'when_sparse_dispatch': str, + 'formals_list': List[AtFormal], + 'condition': str, + 'auto_gpu': bool, + 'cpu_half': bool, + # options should be List[FunctionOption] + 'options': Any, + 'formals': List[str], + 'formals_with_defaults': List[str], + 'returns': List[ReturnType], + 'return_type': str, + 'return_call': str, + 'method_formals': List[str], + 'method_formals_with_defaults': List[str], + 'method_actuals': List[str], + 'const_mark': str, + 'method_prefix_derived': str, + 'broadcast_actuals': List[str], + 'broadcast_returns': List[str], + 'inferred_type': str, + 'broadcast_function': str, + 'broadcast_modified_actuals': List[str], + 'native_type_method_dispatch': str, + }) + def is_real_argument_to_wrapper(argument): + 
# type: (THFormal) -> bool return not argument.get('output', False) and\ argument['type'] != 'CONSTANT' and\ argument['type'] != 'argument' def is_mutable_formal_argument(argument, option): + # type: (THFormal, FunctionOption) -> bool return argument.get('output') or option['inplace'] and argument['name'] == 'self' def to_return_type(arg, option): + # type: (THFormal, FunctionOption) -> ReturnType t = arg['type'] rt = TYPE_RETURN.get(t, t) if rt == 'Tensor' and not arg.get('allocate'): @@ -300,8 +458,10 @@ def to_return_type(arg, option): def create_generic(top_env, declarations): + # type: (TopEnvironment, List[FunctionOption]) -> List[OrderedDict] # translates defaults from cwrap types to C++ values def translate_default(argument, type_str, default): + # type: (THFormal, str, Any) -> Any if default is None: # cause the default constructor for the object to run return '{}' @@ -320,6 +480,7 @@ def create_generic(top_env, declarations): # change from THTensor* to Tensor & so we get how it will appear # in the aten argument list... 
def translate_formal(argument, option): + # type: (THFormal, FunctionOption) -> AtFormal type_str = TYPE_FORMAL_GENERIC.get(argument['type'], argument['type']) if type_str == 'Tensor &' and not is_mutable_formal_argument(argument, option): type_str = 'const ' + type_str @@ -327,7 +488,7 @@ def create_generic(top_env, declarations): 'name': argument['name'], 'type': type_str, 'dynamic_type': DYNAMIC_TYPE.get(argument['type'], argument['type']), - } + } # type: AtFormal if 'kwarg_only' in argument: translated['kwarg_only'] = argument['kwarg_only'] if 'default' in argument: @@ -347,11 +508,13 @@ def create_generic(top_env, declarations): return translated def get_formals(option, include_constants=False): - seen = set() - pos_args = [] - kwd_args = [] + # type: (FunctionOption, bool) -> List[AtFormal] + seen = set() # type: Set[str] + pos_args = [] # type: List[THFormal] + kwd_args = [] # type: List[THFormal] def insert(argument): + # type: (THFormal) -> None if argument['name'] not in seen: seen.add(argument['name']) if argument.get('kwarg_only', False): @@ -360,6 +523,7 @@ def create_generic(top_env, declarations): pos_args.append(argument) def has_output_mask(argument): + # type: (THFormal) -> bool return argument.get('allocate', False) and argument.get('mask', False) for argument in option['arguments']: @@ -389,6 +553,7 @@ def create_generic(top_env, declarations): return [translate_formal(argument, option) for argument in result] def get_return_types(option): + # type: (FunctionOption) -> List[ReturnType] ret = option['return'] if ret['kind'] == 'arguments': argument_indices = ret['arguments'] @@ -407,11 +572,13 @@ def create_generic(top_env, declarations): raise Exception("format_return_type") def format_return_type(return_types): + # type: (List[ReturnType]) -> str if len(return_types) == 1: return return_types[0]['type'] return "std::tuple<{}>".format(','.join(r['type'] for r in return_types)) def find_dispatch_tensor(formals): + # type: (List[AtFormal]) -> 
Optional[str] # dispatch to self if it's a parameter for formal in formals: if formal['name'] == 'self' and formal['dynamic_type'] == 'Tensor': @@ -423,9 +590,11 @@ def create_generic(top_env, declarations): return None def format_formal(f): + # type: (AtFormal) -> str return '{} {}'.format(f['type'], f['name']) def formal_with_default(f): + # type: (AtFormal) -> str s = format_formal(f) v = f.get('default') if v is None: @@ -435,11 +604,15 @@ def create_generic(top_env, declarations): return '{}={}'.format(s, v) def get_broadcast_argument(option): + # type: (FunctionOption) -> Optional[THFormal] for argument in option['arguments']: if argument.get('broadcast'): return argument + return None def get_broadcast_actuals(broadcast_arg, broadcast_inplace, broadcast_dims): + # type: (THFormal, bool, bool) -> List[str] + # Note: broadcast_dims can change type... # return the actuals that will be passed to the broadcast function. # 1) in the common case, this is the broadcasted argument (e.g. "self") followed by the tensors # that it is broadcasted against (comma-separated) (e.g. "self, tensor1, tensor2"). @@ -451,14 +624,15 @@ def create_generic(top_env, declarations): else: broadcast_dims_spec = broadcast_arg['broadcast'].split()[1].split(':')[1].split(',') # generate size call for each dimension - broadcast_dims = ([x.split('.')[0] + '.size(' + x.split('.')[1].replace('dim', '') + ')' + broadcast_dims = ([x.split('.')[0] + '.size(' + x.split('.')[1].replace('dim', '') + ')' # type: ignore for x in broadcast_dims_spec]) - broadcast_dims_init_list = '{' + ','.join(broadcast_dims) + '}' + broadcast_dims_init_list = '{' + ','.join(broadcast_dims) + '}' # type: ignore broadcast_actuals = [broadcast_arg['name'], broadcast_dims_init_list] return broadcast_actuals def emit_nn_body(option): + # type: (FunctionOption) -> Union[str, List[str]] # Concrete definition on Type.cpp for NN functions. Delegates to the # xxx_forward variant variant after creating any necessary buffers. 
actuals = option['actuals'] @@ -468,7 +642,7 @@ def create_generic(top_env, declarations): if len(option['buffers']) == 0: return 'return {}({});'.format(fwd_name, ', '.join(actuals)) - body = [] + body = [] # type: List[str] if option['api_name'].endswith('_out'): # _out variants must create buffers and insert them in the # arguments list between output and input arguments @@ -482,6 +656,7 @@ def create_generic(top_env, declarations): return body def process_option(option, output_options): + # type: (FunctionOption, List[OrderedDict]) -> None option['inplace'] = re.search( '(^__i|[^_]_$)', option['api_name']) is not None @@ -587,11 +762,13 @@ def create_generic(top_env, declarations): ])) def native_get_formals(option, include_constants=False): - seen = set() + # type: (FunctionOption, bool) -> List[AtFormal] + seen = set() # type: Set[str] pos_args = [] kwd_args = [] def insert(argument): + # type: (NativeFormal) -> None if argument['name'] not in seen: seen.add(argument['name']) if argument.get('kwarg_only', False): @@ -605,6 +782,7 @@ def create_generic(top_env, declarations): # not clear we need dynamic_type translation as we can specify the correct type # directly in native functions def add_type_as_dynamic_type(argument, option): + # type: (NativeFormal, FunctionOption) -> NativeFormal argument['dynamic_type'] = argument['type'] return argument @@ -613,7 +791,9 @@ def create_generic(top_env, declarations): # ensure we get reference-type formals when appropriate def native_translate_formals(argument, option): + # type: (NativeFormal, FunctionOption) -> AtFormal def translate_map(const): + # type: (bool) -> Dict[str, str] return { 'Tensor': 'const Tensor &' if const else 'Tensor &', 'BoolTensor': 'const Tensor &' if const else 'Tensor &', @@ -632,9 +812,10 @@ def create_generic(top_env, declarations): # this can return multiple return types in a list, e.g. 
['Tensor', 'Tensor'] def native_get_return_types(option): + # type: (FunctionOption) -> List[ReturnType] ret = option['return'] - return_types = [] + return_types = [] # type: List[ReturnType] for t_raw in ret: if isinstance(t_raw, string_type): t = t_raw @@ -653,7 +834,7 @@ def create_generic(top_env, declarations): rtype = { 'type': actual_return_type, 'dynamic_type': t, - } + } # type: ReturnType if name is not None: rtype['name'] = name return_types.append(rtype) @@ -661,6 +842,7 @@ def create_generic(top_env, declarations): def process_native(option, output_options): + # type: (FunctionOption, List[OrderedDict]) -> None option['inplace'] = re.search( '(^__i|[^_]_$)', option['api_name']) is not None @@ -721,7 +903,7 @@ def create_generic(top_env, declarations): # generate the at::native function declarations (i.e. what the user will implement) if isinstance(dispatch, dict): - generated_native_functions = [] + generated_native_functions = [] # type: List[str] for key in sorted(dispatch.keys()): value = dispatch[key] if value not in generated_native_functions: @@ -761,9 +943,9 @@ def create_generic(top_env, declarations): ('abstract', abstract), ])) - output_declarations = [] + output_declarations = [] # type: List[OrderedDict] for declaration in declarations: - output_options = [] + output_options = [] # type: List[OrderedDict] for option in declaration['options']: try: if option['mode'] != 'native': @@ -777,6 +959,7 @@ def create_generic(top_env, declarations): def create_derived(backend_type_env, declarations): + # type: (Environment, List[FunctionOption]) -> Tuple[List[str], List[str]] type_object_declarations = [] type_object_definitions = [] @@ -785,21 +968,26 @@ def create_derived(backend_type_env, declarations): real_is_half = backend_type_env['ScalarName'] == 'Half' def replace_with_null(argument): + # type: (THFormal) -> bool return (argument['type'] == 'THGenerator*' and backend_type_env['Backend'] == 'CUDA') def
requires_checked_cast(argument): + # type: (THFormal) -> bool if argument['type'] == 'IntList': return 'size' in argument return argument['type'] in CHECKED_CAST def nullable_argument(argument): + # type: (THFormal) -> bool return argument.get('is_nullable', False) def bool_option_is_string(argument): + # type: (THFormal) -> bool return 'if_true' in argument and isinstance(argument['if_true'], string_type) def get_argument(argument, option): + # type: (THFormal, FunctionOption) -> str if replace_with_null(argument): return 'NULL' elif requires_checked_cast(argument): @@ -834,14 +1022,17 @@ def create_derived(backend_type_env, declarations): return argument['name'] def drop_argument(argument, option): + # type: (THFormal, FunctionOption) -> bool return 'CUDA' in backend_type_env['Backend'] and ( option['mode'] == 'TH' and argument['type'] == 'THGenerator*') def get_arguments(arguments, option): + # type: (List[THFormal], FunctionOption) -> List[str] return [get_argument(argument, option) for argument in arguments if not drop_argument(argument, option)] def is_actual_return_long(ret): + # type: (ReturnDecl) -> bool if ret['type'] == 'long': return True if ret['type'] == 'real': @@ -851,9 +1042,11 @@ def create_derived(backend_type_env, declarations): return False def get_zero_dim_dispatch_when_scalar(option): - return option.get('zero_dim_dispatch_when_scalar', False) + # type: (FunctionOption) -> str + return option.get('zero_dim_dispatch_when_scalar', False) # type: ignore def handle_zero_dim(env, option): + # type: (Environment, FunctionOption) -> List[str] zero_dim_dispatch = get_zero_dim_dispatch_when_scalar(option) if not zero_dim_dispatch: return [] @@ -863,6 +1056,7 @@ def create_derived(backend_type_env, declarations): return [ZERO_DIM_CHECK.substitute(env, check_name=zero_dim_dispatch, zero_dim_actuals=zero_dim_actuals)] def handle_only_zero_dim(env, option): + # type: (Environment, FunctionOption) -> List[str] if option.get('zero_dim_tensor_only', False): 
check_name = get_zero_dim_dispatch_when_scalar(option) return [ZERO_DIM_ONLY.substitute(env, check_name=check_name)] @@ -870,6 +1064,7 @@ def create_derived(backend_type_env, declarations): return None def handle_sparse(env, option): + # type: (Environment, FunctionOption) -> List[str] if 'when_sparse_dispatch' not in option or 'Sparse' in backend_type_env['Backend']: return [] check_name = option['when_sparse_dispatch'] @@ -879,6 +1074,7 @@ def create_derived(backend_type_env, declarations): return [SPARSE_CHECK.substitute(env, check_name=check_name, sparse_actuals=sparse_actuals)] def allocate_arg(env, arg, output_count): + # type: (Environment, THFormal, int) -> List[str] name = arg['name'] allocation = CodeTemplate(ALLOC_WRAP[arg['type']]).substitute(env, arguments=[]) tensor_arg = '{}_'.format(name) @@ -892,6 +1088,7 @@ def create_derived(backend_type_env, declarations): ] def resize_arg(arg): + # type: (THFormal) -> str resize = arg['resize'] if isinstance(resize, str): return "{}.resize_({}.sizes());".format(arg['name'], resize) @@ -904,6 +1101,7 @@ def create_derived(backend_type_env, declarations): return "{}.resize_({{ {} }});".format(arg['name'], ','.join(dims)) def handle_call(env, option, cimpl): + # type: (Environment, FunctionOption, FunctionOption) -> str is_nn = option['mode'] == 'NN' actuals = get_arguments(cimpl['arguments'], option) if is_cuda or is_nn: @@ -926,7 +1124,8 @@ def create_derived(backend_type_env, declarations): return call def emit_body(env, option): - body = [] + # type: (Environment, FunctionOption) -> List[str] + body = [] # type: List[str] body += handle_sparse(env, option) body += handle_zero_dim(env, option) only_zero_dim_check = handle_only_zero_dim(env, option) @@ -937,8 +1136,8 @@ def create_derived(backend_type_env, declarations): # arguments are potentially duplicated because of one argument # referencing another - seen_names = set() - seen_tensorlists = set() + seen_names = set() # type: Set[str] + seen_tensorlists = 
set() # type: Set[str] count = 0 output_count = 0 @@ -1064,7 +1263,7 @@ def create_derived(backend_type_env, declarations): scalar_check = 'maybe_scalar' for arg in arguments: scalar_check_arg = (scalar_check if not isinstance(scalar_check, dict) - else scalar_check.get(arg['name'])) + else scalar_check.get(arg['name'])) # type: ignore if scalar_check_arg is not None: stmt = "{}_->maybeScalar({});".format(arg['name'], scalar_check_arg) if nullable_argument(arg): @@ -1114,11 +1313,12 @@ def create_derived(backend_type_env, declarations): return body def process_option(option): + # type: (FunctionOption) -> None pair = (backend_type_env['Backend'], backend_type_env['ScalarName']) if pair in option['backend_type_pairs']: env = nested_dict(option, backend_type_env) - body = emit_body(env, option) + body = emit_body(env, option) # type: ignore option['type_definition_body'] = body type_object_declarations.append( TYPE_DERIVED_DECLARATION.substitute(env)) @@ -1126,6 +1326,7 @@ def create_derived(backend_type_env, declarations): TYPE_DERIVED_DEFINITION.substitute(env)) def process_native(option): + # type: (FunctionOption) -> None dispatch = option['type_method_definition_dispatch'] env = nested_dict(option, backend_type_env)