mirror of
				https://github.com/pytorch/pytorch.git
				synced 2025-10-20 21:14:14 +08:00 
			
		
		
		
	Initial type hints for function_wrapper (#4947)
* Initial type hints for function_wrapper * Don't break python 2 * Update TopEnvironment * Add mypy check to travis * Add .mypy_cache to .gitignore
This commit is contained in:
		
				
					committed by
					
						 Edward Z. Yang
						Edward Z. Yang
					
				
			
			
				
	
			
			
			
						parent
						
							696db00bcd
						
					
				
				
					commit
					0629785645
				
			
							
								
								
									
										1
									
								
								.gitignore
									
									
									
									
										vendored
									
									
								
							
							
						
						
									
										1
									
								
								.gitignore
									
									
									
									
										vendored
									
									
								
							| @ -40,6 +40,7 @@ test/.coverage | |||||||
| */**/*.dylib* | */**/*.dylib* | ||||||
| test/data/legacy_serialized.pt | test/data/legacy_serialized.pt | ||||||
| test/data/linear.pt | test/data/linear.pt | ||||||
|  | .mypy_cache | ||||||
|  |  | ||||||
| # IPython notebook checkpoints | # IPython notebook checkpoints | ||||||
| .ipynb_checkpoints | .ipynb_checkpoints | ||||||
|  | |||||||
| @ -12,7 +12,12 @@ sudo: false | |||||||
| matrix: | matrix: | ||||||
|     fast_finish: true |     fast_finish: true | ||||||
|     include: |     include: | ||||||
|         env: LINT_CHECK |       - env: LINT_CHECK | ||||||
|         python: "2.7" |         python: "2.7" | ||||||
|         install: pip install flake8 |         install: pip install flake8 | ||||||
|         script: flake8 |         script: flake8 | ||||||
|  |       # mypy will complain about various files. Limiting it to only typed files | ||||||
|  |       - env: MYPY_TYPE_CHECK | ||||||
|  |         python: "3.6" | ||||||
|  |         install: pip install mypy mypy-extensions | ||||||
|  |         script: mypy --py2 aten/src/ATen/function_wrapper.py | ||||||
|  | |||||||
| @ -5,6 +5,14 @@ import re | |||||||
| from collections import OrderedDict | from collections import OrderedDict | ||||||
| from code_template import CodeTemplate | from code_template import CodeTemplate | ||||||
|  |  | ||||||
|  | try: | ||||||
|  |     from typing import Any, Dict, List, Generic, Optional, Set, Tuple, \ | ||||||
|  |         Union, TypeVar | ||||||
|  |     from mypy_extensions import TypedDict | ||||||
|  |     TYPE_HINTS = True | ||||||
|  | except ImportError: | ||||||
|  |     TYPE_HINTS = False | ||||||
|  |  | ||||||
| import sys | import sys | ||||||
| if sys.version_info[0] == 3: | if sys.version_info[0] == 3: | ||||||
|     string_type = str |     string_type = str | ||||||
| @ -274,18 +282,168 @@ class nested_dict(object): | |||||||
|             return r |             return r | ||||||
|         return self.parent[x] |         return self.parent[x] | ||||||
|  |  | ||||||
|  | if TYPE_HINTS: | ||||||
|  |     Environment = TypedDict('Environment', { | ||||||
|  |         'ScalarName': str, | ||||||
|  |         'THTensor': str, | ||||||
|  |         'THType': str, | ||||||
|  |         'THTensor': str, | ||||||
|  |         'Backend': str, | ||||||
|  |         'AccScalarName': str, | ||||||
|  |     }) | ||||||
|  |  | ||||||
|  |     TopEnvironment = TypedDict('TopEnvironment', { | ||||||
|  |         'type_registrations': List[str], | ||||||
|  |         'type_headers': List[str], | ||||||
|  |         'type_method_declarations': List[str], | ||||||
|  |         'type_method_definitions': List[str], | ||||||
|  |         'type_method_inline_definitions': List[str], | ||||||
|  |         'tensor_method_declarations': List[str], | ||||||
|  |         'tensor_method_definitions': List[str], | ||||||
|  |         'function_declarations': List[str], | ||||||
|  |         'function_definitions': List[str], | ||||||
|  |         'type_ids': List[str], | ||||||
|  |         'native_function_declarations': List[str], | ||||||
|  |     }) | ||||||
|  |  | ||||||
|  |     # A Declarations.cwrap formal argument | ||||||
|  |     # type can contain THTensor* types | ||||||
|  |     THFormal = TypedDict('THFormal', { | ||||||
|  |         'name': str, | ||||||
|  |         'type': str, | ||||||
|  |         'dynamic_type': str, | ||||||
|  |         'kwarg_only': bool, | ||||||
|  |         'is_nullable': bool, | ||||||
|  |         'default': str, | ||||||
|  |         'default_init': str, | ||||||
|  |         'python_default_init': str, | ||||||
|  |         'output': bool, | ||||||
|  |         'size': int, | ||||||
|  |         'declared_type': str, | ||||||
|  |         'ignore_check': bool, | ||||||
|  |         'allocate': bool, | ||||||
|  |         'mask': bool, | ||||||
|  |         'if_true': bool, | ||||||
|  |         'if_false': bool, | ||||||
|  |         'wrap_dim': str, | ||||||
|  |         # Broadcast is originally a str but gets unwrapped to a List or Dict in-place | ||||||
|  |         'broadcast': Any, | ||||||
|  |         'resize': str, | ||||||
|  |         'cpu_zero': bool, | ||||||
|  |         'zero': bool, | ||||||
|  |     }, total=False) | ||||||
|  |  | ||||||
|  |     # A native_functions.yaml formal argument | ||||||
|  |     # type can contain Tensor, BoolTensor, IndexTensor types | ||||||
|  |     NativeFormal = TypedDict('NativeFormal', { | ||||||
|  |         'name': str, | ||||||
|  |         'type': str, | ||||||
|  |         'dynamic_type': str, | ||||||
|  |         'kwarg_only': bool, | ||||||
|  |         'is_nullable': bool, | ||||||
|  |         'default': str, | ||||||
|  |         'default_init': str, | ||||||
|  |         'python_default_init': str, | ||||||
|  |         'output': bool, | ||||||
|  |         'size': int, | ||||||
|  |     }, total=False) | ||||||
|  |  | ||||||
|  |     # Generic ATen formal. | ||||||
|  |     # type can contain Tensor& reference types. | ||||||
|  |     AtFormal = TypedDict('AtFormal', { | ||||||
|  |         'name': str, | ||||||
|  |         'type': str, | ||||||
|  |         'dynamic_type': str, | ||||||
|  |         'kwarg_only': bool, | ||||||
|  |         'is_nullable': bool, | ||||||
|  |         'default': str, | ||||||
|  |         'default_init': str, | ||||||
|  |         'python_default_init': str, | ||||||
|  |         'output': bool, | ||||||
|  |         'size': int, | ||||||
|  |     }, total=False) | ||||||
|  |  | ||||||
|  |     ReturnType = TypedDict('ReturnType', { | ||||||
|  |         'name': str, | ||||||
|  |         'type': str, | ||||||
|  |         'dynamic_type': str, | ||||||
|  |     }, total=False) | ||||||
|  |  | ||||||
|  |     ReturnDecl = TypedDict('ReturnDecl', { | ||||||
|  |         'kind': str, | ||||||
|  |         'type': str, | ||||||
|  |         'arguments': List[int], | ||||||
|  |     }, total=False) | ||||||
|  |  | ||||||
|  |     # Represents a buffer in nn.yaml | ||||||
|  |     NNBuffer = TypedDict('NNBuffer', { | ||||||
|  |         'name': str, | ||||||
|  |     }) | ||||||
|  |  | ||||||
|  |     FunctionOption = TypedDict('FunctionOption', { | ||||||
|  |         'arguments': List[THFormal], | ||||||
|  |         'mode': str, | ||||||
|  |         'name': str, | ||||||
|  |         'return': ReturnDecl, | ||||||
|  |         'variants': str, | ||||||
|  |         'type_method_definition_dispatch': str, | ||||||
|  |         'cname': str, | ||||||
|  |         'backends': List[str], | ||||||
|  |         'api_name': str, | ||||||
|  |         'backend_type_pairs': List[Tuple[str, str]], | ||||||
|  |         'inplace': bool, | ||||||
|  |         'aten_dense_sparse': bool, | ||||||
|  |         'sparse': bool, | ||||||
|  |         'scalar_check': str, | ||||||
|  |         'aten_custom_call': str, | ||||||
|  |         'type_definition_body': List[str], | ||||||
|  |         # cimpls is really a List[FunctionOption] | ||||||
|  |         'cimpls': List[Any], | ||||||
|  |         'when_spares_dispatch': str, | ||||||
|  |         'actuals': List[str], | ||||||
|  |         'buffers': List[NNBuffer], | ||||||
|  |         'zero_dim_dispatch_when_scalar': str, | ||||||
|  |         'zero_dim_tensor_only': bool, | ||||||
|  |         'when_sparse_dispatch': str, | ||||||
|  |         'formals_list': List[AtFormal], | ||||||
|  |         'condition': str, | ||||||
|  |         'auto_gpu': bool, | ||||||
|  |         'cpu_half': bool, | ||||||
|  |         # options should be List[FunctionOption] | ||||||
|  |         'options': Any, | ||||||
|  |         'formals': List[str], | ||||||
|  |         'formals_with_defaults': List[str], | ||||||
|  |         'returns': List[ReturnType], | ||||||
|  |         'return_type': str, | ||||||
|  |         'return_call': str, | ||||||
|  |         'method_formals': List[str], | ||||||
|  |         'method_formals_with_defaults': List[str], | ||||||
|  |         'method_actuals': List[str], | ||||||
|  |         'const_mark': str, | ||||||
|  |         'method_prefix_derived': str, | ||||||
|  |         'broadcast_actuals': List[str], | ||||||
|  |         'broadcast_returns': List[str], | ||||||
|  |         'inferred_type': str, | ||||||
|  |         'broadcast_function': str, | ||||||
|  |         'broadcast_modified_actuals': List[str], | ||||||
|  |         'native_type_method_dispatch': str, | ||||||
|  |     }) | ||||||
|  |  | ||||||
|  |  | ||||||
| def is_real_argument_to_wrapper(argument): | def is_real_argument_to_wrapper(argument): | ||||||
|  |     # type: (THFormal) -> bool | ||||||
|     return not argument.get('output', False) and\ |     return not argument.get('output', False) and\ | ||||||
|         argument['type'] != 'CONSTANT' and\ |         argument['type'] != 'CONSTANT' and\ | ||||||
|         argument['type'] != 'argument' |         argument['type'] != 'argument' | ||||||
|  |  | ||||||
|  |  | ||||||
| def is_mutable_formal_argument(argument, option): | def is_mutable_formal_argument(argument, option): | ||||||
|  |     # type: (THFormal, FunctionOption) -> bool | ||||||
|     return argument.get('output') or option['inplace'] and argument['name'] == 'self' |     return argument.get('output') or option['inplace'] and argument['name'] == 'self' | ||||||
|  |  | ||||||
|  |  | ||||||
| def to_return_type(arg, option): | def to_return_type(arg, option): | ||||||
|  |     # type: (THFormal, FunctionOption) -> ReturnType | ||||||
|     t = arg['type'] |     t = arg['type'] | ||||||
|     rt = TYPE_RETURN.get(t, t) |     rt = TYPE_RETURN.get(t, t) | ||||||
|     if rt == 'Tensor' and not arg.get('allocate'): |     if rt == 'Tensor' and not arg.get('allocate'): | ||||||
| @ -300,8 +458,10 @@ def to_return_type(arg, option): | |||||||
|  |  | ||||||
|  |  | ||||||
| def create_generic(top_env, declarations): | def create_generic(top_env, declarations): | ||||||
|  |     # type: (TopEnvironment, List[FunctionOption]) -> List[OrderedDict] | ||||||
|     # translates defaults from cwrap types to C++ values |     # translates defaults from cwrap types to C++ values | ||||||
|     def translate_default(argument, type_str, default): |     def translate_default(argument, type_str, default): | ||||||
|  |         # type: (THFormal, str, Any) -> Any | ||||||
|         if default is None: |         if default is None: | ||||||
|             # cause the default constructor for the object to run |             # cause the default constructor for the object to run | ||||||
|             return '{}' |             return '{}' | ||||||
| @ -320,6 +480,7 @@ def create_generic(top_env, declarations): | |||||||
|     # change from THTensor* to Tensor & so we get how it will appear |     # change from THTensor* to Tensor & so we get how it will appear | ||||||
|     # in the aten argument list... |     # in the aten argument list... | ||||||
|     def translate_formal(argument, option): |     def translate_formal(argument, option): | ||||||
|  |         # type: (THFormal, FunctionOption) -> AtFormal | ||||||
|         type_str = TYPE_FORMAL_GENERIC.get(argument['type'], argument['type']) |         type_str = TYPE_FORMAL_GENERIC.get(argument['type'], argument['type']) | ||||||
|         if type_str == 'Tensor &' and not is_mutable_formal_argument(argument, option): |         if type_str == 'Tensor &' and not is_mutable_formal_argument(argument, option): | ||||||
|             type_str = 'const ' + type_str |             type_str = 'const ' + type_str | ||||||
| @ -327,7 +488,7 @@ def create_generic(top_env, declarations): | |||||||
|             'name': argument['name'], |             'name': argument['name'], | ||||||
|             'type': type_str, |             'type': type_str, | ||||||
|             'dynamic_type': DYNAMIC_TYPE.get(argument['type'], argument['type']), |             'dynamic_type': DYNAMIC_TYPE.get(argument['type'], argument['type']), | ||||||
|         } |         }  # type: AtFormal | ||||||
|         if 'kwarg_only' in argument: |         if 'kwarg_only' in argument: | ||||||
|             translated['kwarg_only'] = argument['kwarg_only'] |             translated['kwarg_only'] = argument['kwarg_only'] | ||||||
|         if 'default' in argument: |         if 'default' in argument: | ||||||
| @ -347,11 +508,13 @@ def create_generic(top_env, declarations): | |||||||
|         return translated |         return translated | ||||||
|  |  | ||||||
|     def get_formals(option, include_constants=False): |     def get_formals(option, include_constants=False): | ||||||
|         seen = set() |         # type: (FunctionOption, bool) -> List[AtFormal] | ||||||
|         pos_args = [] |         seen = set()  # type: Set[str] | ||||||
|         kwd_args = [] |         pos_args = []  # type: List[THFormal] | ||||||
|  |         kwd_args = []  # type: List[THFormal] | ||||||
|  |  | ||||||
|         def insert(argument): |         def insert(argument): | ||||||
|  |             # type: (THFormal) -> None | ||||||
|             if argument['name'] not in seen: |             if argument['name'] not in seen: | ||||||
|                 seen.add(argument['name']) |                 seen.add(argument['name']) | ||||||
|                 if argument.get('kwarg_only', False): |                 if argument.get('kwarg_only', False): | ||||||
| @ -360,6 +523,7 @@ def create_generic(top_env, declarations): | |||||||
|                     pos_args.append(argument) |                     pos_args.append(argument) | ||||||
|  |  | ||||||
|         def has_output_mask(argument): |         def has_output_mask(argument): | ||||||
|  |             # type: (THFormal) -> bool | ||||||
|             return argument.get('allocate', False) and argument.get('mask', False) |             return argument.get('allocate', False) and argument.get('mask', False) | ||||||
|  |  | ||||||
|         for argument in option['arguments']: |         for argument in option['arguments']: | ||||||
| @ -389,6 +553,7 @@ def create_generic(top_env, declarations): | |||||||
|         return [translate_formal(argument, option) for argument in result] |         return [translate_formal(argument, option) for argument in result] | ||||||
|  |  | ||||||
|     def get_return_types(option): |     def get_return_types(option): | ||||||
|  |         # type: (FunctionOption) -> List[ReturnType] | ||||||
|         ret = option['return'] |         ret = option['return'] | ||||||
|         if ret['kind'] == 'arguments': |         if ret['kind'] == 'arguments': | ||||||
|             argument_indices = ret['arguments'] |             argument_indices = ret['arguments'] | ||||||
| @ -407,11 +572,13 @@ def create_generic(top_env, declarations): | |||||||
|             raise Exception("format_return_type") |             raise Exception("format_return_type") | ||||||
|  |  | ||||||
|     def format_return_type(return_types): |     def format_return_type(return_types): | ||||||
|  |         # type: (List[ReturnType]) -> str | ||||||
|         if len(return_types) == 1: |         if len(return_types) == 1: | ||||||
|             return return_types[0]['type'] |             return return_types[0]['type'] | ||||||
|         return "std::tuple<{}>".format(','.join(r['type'] for r in return_types)) |         return "std::tuple<{}>".format(','.join(r['type'] for r in return_types)) | ||||||
|  |  | ||||||
|     def find_dispatch_tensor(formals): |     def find_dispatch_tensor(formals): | ||||||
|  |         # type: (List[AtFormal]) -> Optional[str] | ||||||
|         # dispatch to self if it's a parameter |         # dispatch to self if it's a parameter | ||||||
|         for formal in formals: |         for formal in formals: | ||||||
|             if formal['name'] == 'self' and formal['dynamic_type'] == 'Tensor': |             if formal['name'] == 'self' and formal['dynamic_type'] == 'Tensor': | ||||||
| @ -423,9 +590,11 @@ def create_generic(top_env, declarations): | |||||||
|         return None |         return None | ||||||
|  |  | ||||||
|     def format_formal(f): |     def format_formal(f): | ||||||
|  |         # type: (AtFormal) -> str | ||||||
|         return '{} {}'.format(f['type'], f['name']) |         return '{} {}'.format(f['type'], f['name']) | ||||||
|  |  | ||||||
|     def formal_with_default(f): |     def formal_with_default(f): | ||||||
|  |         # type: (AtFormal) -> str | ||||||
|         s = format_formal(f) |         s = format_formal(f) | ||||||
|         v = f.get('default') |         v = f.get('default') | ||||||
|         if v is None: |         if v is None: | ||||||
| @ -435,11 +604,15 @@ def create_generic(top_env, declarations): | |||||||
|         return '{}={}'.format(s, v) |         return '{}={}'.format(s, v) | ||||||
|  |  | ||||||
|     def get_broadcast_argument(option): |     def get_broadcast_argument(option): | ||||||
|  |         # type: (FunctionOption) -> Optional[THFormal] | ||||||
|         for argument in option['arguments']: |         for argument in option['arguments']: | ||||||
|             if argument.get('broadcast'): |             if argument.get('broadcast'): | ||||||
|                 return argument |                 return argument | ||||||
|  |         return None | ||||||
|  |  | ||||||
|     def get_broadcast_actuals(broadcast_arg, broadcast_inplace, broadcast_dims): |     def get_broadcast_actuals(broadcast_arg, broadcast_inplace, broadcast_dims): | ||||||
|  |         # type: (THFormal, bool, bool) -> List[str] | ||||||
|  |         # Note: broadcast_dims can change type... | ||||||
|         # return the actuals that will be passed to the broadcast function. |         # return the actuals that will be passed to the broadcast function. | ||||||
|         # 1) in the common case, this is the broadcasted argument (e.g. "self") followed by the tensors |         # 1) in the common case, this is the broadcasted argument (e.g. "self") followed by the tensors | ||||||
|         #    that it is broadcasted against (comma-separated) (e.g. "self, tensor1, tensor2"). |         #    that it is broadcasted against (comma-separated) (e.g. "self, tensor1, tensor2"). | ||||||
| @ -451,14 +624,15 @@ def create_generic(top_env, declarations): | |||||||
|         else: |         else: | ||||||
|             broadcast_dims_spec = broadcast_arg['broadcast'].split()[1].split(':')[1].split(',') |             broadcast_dims_spec = broadcast_arg['broadcast'].split()[1].split(':')[1].split(',') | ||||||
|             # generate size call for each dimension |             # generate size call for each dimension | ||||||
|             broadcast_dims = ([x.split('.')[0] + '.size(' + x.split('.')[1].replace('dim', '') + ')' |             broadcast_dims = ([x.split('.')[0] + '.size(' + x.split('.')[1].replace('dim', '') + ')'  # type: ignore | ||||||
|                               for x in broadcast_dims_spec]) |                               for x in broadcast_dims_spec]) | ||||||
|             broadcast_dims_init_list = '{' + ','.join(broadcast_dims) + '}' |             broadcast_dims_init_list = '{' + ','.join(broadcast_dims) + '}'  # type: ignore | ||||||
|             broadcast_actuals = [broadcast_arg['name'], broadcast_dims_init_list] |             broadcast_actuals = [broadcast_arg['name'], broadcast_dims_init_list] | ||||||
|  |  | ||||||
|         return broadcast_actuals |         return broadcast_actuals | ||||||
|  |  | ||||||
|     def emit_nn_body(option): |     def emit_nn_body(option): | ||||||
|  |         # type: (FunctionOption) -> Union[str, List[str]] | ||||||
|         # Concrete definition on Type.cpp for NN functions. Delegates to the |         # Concrete definition on Type.cpp for NN functions. Delegates to the | ||||||
|         # xxx_forward variant variant after creating any necessary buffers. |         # xxx_forward variant variant after creating any necessary buffers. | ||||||
|         actuals = option['actuals'] |         actuals = option['actuals'] | ||||||
| @ -468,7 +642,7 @@ def create_generic(top_env, declarations): | |||||||
|         if len(option['buffers']) == 0: |         if len(option['buffers']) == 0: | ||||||
|             return 'return {}({});'.format(fwd_name, ', '.join(actuals)) |             return 'return {}({});'.format(fwd_name, ', '.join(actuals)) | ||||||
|  |  | ||||||
|         body = [] |         body = []  # type: List[str] | ||||||
|         if option['api_name'].endswith('_out'): |         if option['api_name'].endswith('_out'): | ||||||
|             # _out variants must create buffers and insert them in the |             # _out variants must create buffers and insert them in the | ||||||
|             # arguments list between output and input arguments |             # arguments list between output and input arguments | ||||||
| @ -482,6 +656,7 @@ def create_generic(top_env, declarations): | |||||||
|         return body |         return body | ||||||
|  |  | ||||||
|     def process_option(option, output_options): |     def process_option(option, output_options): | ||||||
|  |         # type: (FunctionOption, List[OrderedDict]) -> None | ||||||
|         option['inplace'] = re.search( |         option['inplace'] = re.search( | ||||||
|             '(^__i|[^_]_$)', option['api_name']) is not None |             '(^__i|[^_]_$)', option['api_name']) is not None | ||||||
|  |  | ||||||
| @ -587,11 +762,13 @@ def create_generic(top_env, declarations): | |||||||
|         ])) |         ])) | ||||||
|  |  | ||||||
|     def native_get_formals(option, include_constants=False): |     def native_get_formals(option, include_constants=False): | ||||||
|         seen = set() |         # type: (FunctionOption, bool) -> List[AtFormal] | ||||||
|  |         seen = set()  # type: Set[str] | ||||||
|         pos_args = [] |         pos_args = [] | ||||||
|         kwd_args = [] |         kwd_args = [] | ||||||
|  |  | ||||||
|         def insert(argument): |         def insert(argument): | ||||||
|  |             # type: (NativeFormal) -> None | ||||||
|             if argument['name'] not in seen: |             if argument['name'] not in seen: | ||||||
|                 seen.add(argument['name']) |                 seen.add(argument['name']) | ||||||
|                 if argument.get('kwarg_only', False): |                 if argument.get('kwarg_only', False): | ||||||
| @ -605,6 +782,7 @@ def create_generic(top_env, declarations): | |||||||
|         # not clear we need dynamic_type translation as we can specify the correct type |         # not clear we need dynamic_type translation as we can specify the correct type | ||||||
|         # directly in native functions |         # directly in native functions | ||||||
|         def add_type_as_dynamic_type(argument, option): |         def add_type_as_dynamic_type(argument, option): | ||||||
|  |             # type: (NativeFormal, FunctionOption) -> NativeFormal | ||||||
|             argument['dynamic_type'] = argument['type'] |             argument['dynamic_type'] = argument['type'] | ||||||
|             return argument |             return argument | ||||||
|  |  | ||||||
| @ -613,7 +791,9 @@ def create_generic(top_env, declarations): | |||||||
|  |  | ||||||
|         # ensure we get reference-type formals when appropriate |         # ensure we get reference-type formals when appropriate | ||||||
|         def native_translate_formals(argument, option): |         def native_translate_formals(argument, option): | ||||||
|  |             # type: (NativeFormal, FunctionOption) -> AtFormal | ||||||
|             def translate_map(const): |             def translate_map(const): | ||||||
|  |                 # type: (bool) -> Dict[str, str] | ||||||
|                 return { |                 return { | ||||||
|                     'Tensor': 'const Tensor &' if const else 'Tensor &', |                     'Tensor': 'const Tensor &' if const else 'Tensor &', | ||||||
|                     'BoolTensor': 'const Tensor &' if const else 'Tensor &', |                     'BoolTensor': 'const Tensor &' if const else 'Tensor &', | ||||||
| @ -632,9 +812,10 @@ def create_generic(top_env, declarations): | |||||||
|  |  | ||||||
|     # this can return multiple return types in a list, e.g. ['Tensor', 'Tensor'] |     # this can return multiple return types in a list, e.g. ['Tensor', 'Tensor'] | ||||||
|     def native_get_return_types(option): |     def native_get_return_types(option): | ||||||
|  |         # type: (FunctionOption) -> List[ReturnType] | ||||||
|         ret = option['return'] |         ret = option['return'] | ||||||
|  |  | ||||||
|         return_types = [] |         return_types = []  # List[ReturnType] | ||||||
|         for t_raw in ret: |         for t_raw in ret: | ||||||
|             if isinstance(t_raw, string_type): |             if isinstance(t_raw, string_type): | ||||||
|                 t = t_raw |                 t = t_raw | ||||||
| @ -653,7 +834,7 @@ def create_generic(top_env, declarations): | |||||||
|             rtype = { |             rtype = { | ||||||
|                 'type': actual_return_type, |                 'type': actual_return_type, | ||||||
|                 'dynamic_type': t, |                 'dynamic_type': t, | ||||||
|             } |             }  # type: ReturnType | ||||||
|             if name is not None: |             if name is not None: | ||||||
|                 rtype['name'] = name |                 rtype['name'] = name | ||||||
|             return_types.append(rtype) |             return_types.append(rtype) | ||||||
| @ -661,6 +842,7 @@ def create_generic(top_env, declarations): | |||||||
|         return return_types |         return return_types | ||||||
|  |  | ||||||
|     def process_native(option, output_options): |     def process_native(option, output_options): | ||||||
|  |         # type: (FunctionOption, List[OrderedDict]) -> None | ||||||
|         option['inplace'] = re.search( |         option['inplace'] = re.search( | ||||||
|             '(^__i|[^_]_$)', option['api_name']) is not None |             '(^__i|[^_]_$)', option['api_name']) is not None | ||||||
|  |  | ||||||
| @ -721,7 +903,7 @@ def create_generic(top_env, declarations): | |||||||
|  |  | ||||||
|         # generate the at::native function declarations (i.e. what the user will implement) |         # generate the at::native function declarations (i.e. what the user will implement) | ||||||
|         if isinstance(dispatch, dict): |         if isinstance(dispatch, dict): | ||||||
|             generated_native_functions = [] |             generated_native_functions = []  # type: List[str] | ||||||
|             for key in sorted(dispatch.keys()): |             for key in sorted(dispatch.keys()): | ||||||
|                 value = dispatch[key] |                 value = dispatch[key] | ||||||
|                 if value not in generated_native_functions: |                 if value not in generated_native_functions: | ||||||
| @ -761,9 +943,9 @@ def create_generic(top_env, declarations): | |||||||
|             ('abstract', abstract), |             ('abstract', abstract), | ||||||
|         ])) |         ])) | ||||||
|  |  | ||||||
|     output_declarations = [] |     output_declarations = []  # type: List[OrderedDict] | ||||||
|     for declaration in declarations: |     for declaration in declarations: | ||||||
|         output_options = [] |         output_options = []  # type: List[OrderedDict] | ||||||
|         for option in declaration['options']: |         for option in declaration['options']: | ||||||
|             try: |             try: | ||||||
|                 if option['mode'] != 'native': |                 if option['mode'] != 'native': | ||||||
| @ -777,6 +959,7 @@ def create_generic(top_env, declarations): | |||||||
|  |  | ||||||
|  |  | ||||||
| def create_derived(backend_type_env, declarations): | def create_derived(backend_type_env, declarations): | ||||||
|  |     # type: (Environment, List[FunctionOption]) -> Tuple[List[str], List[str]] | ||||||
|     type_object_declarations = [] |     type_object_declarations = [] | ||||||
|     type_object_definitions = [] |     type_object_definitions = [] | ||||||
|  |  | ||||||
| @ -785,21 +968,26 @@ def create_derived(backend_type_env, declarations): | |||||||
|     real_is_half = backend_type_env['ScalarName'] == 'Half' |     real_is_half = backend_type_env['ScalarName'] == 'Half' | ||||||
|  |  | ||||||
|     def replace_with_null(argument): |     def replace_with_null(argument): | ||||||
|  |         # type: (THFormal) -> bool | ||||||
|         return (argument['type'] == 'THGenerator*' and |         return (argument['type'] == 'THGenerator*' and | ||||||
|                 backend_type_env['Backend'] == 'CUDA') |                 backend_type_env['Backend'] == 'CUDA') | ||||||
|  |  | ||||||
|     def requires_checked_cast(argument): |     def requires_checked_cast(argument): | ||||||
|  |         # type: (THFormal) -> bool | ||||||
|         if argument['type'] == 'IntList': |         if argument['type'] == 'IntList': | ||||||
|             return 'size' in argument |             return 'size' in argument | ||||||
|         return argument['type'] in CHECKED_CAST |         return argument['type'] in CHECKED_CAST | ||||||
|  |  | ||||||
|     def nullable_argument(argument): |     def nullable_argument(argument): | ||||||
|  |         # type: (THFormal) -> bool | ||||||
|         return argument.get('is_nullable', False) |         return argument.get('is_nullable', False) | ||||||
|  |  | ||||||
|     def bool_option_is_string(argument): |     def bool_option_is_string(argument): | ||||||
|  |         # type: (THFormal) -> bool | ||||||
|         return 'if_true' in argument and isinstance(argument['if_true'], string_type) |         return 'if_true' in argument and isinstance(argument['if_true'], string_type) | ||||||
|  |  | ||||||
|     def get_argument(argument, option): |     def get_argument(argument, option): | ||||||
|  |         # type: (THFormal, FunctionOption) -> str | ||||||
|         if replace_with_null(argument): |         if replace_with_null(argument): | ||||||
|             return 'NULL' |             return 'NULL' | ||||||
|         elif requires_checked_cast(argument): |         elif requires_checked_cast(argument): | ||||||
| @ -834,14 +1022,17 @@ def create_derived(backend_type_env, declarations): | |||||||
|             return argument['name'] |             return argument['name'] | ||||||
|  |  | ||||||
|     def drop_argument(argument, option): |     def drop_argument(argument, option): | ||||||
|  |         # type: (THFormal, FunctionOption) -> bool | ||||||
|         return 'CUDA' in backend_type_env['Backend'] and ( |         return 'CUDA' in backend_type_env['Backend'] and ( | ||||||
|             option['mode'] == 'TH' and argument['type'] == 'THGenerator*') |             option['mode'] == 'TH' and argument['type'] == 'THGenerator*') | ||||||
|  |  | ||||||
|     def get_arguments(arguments, option): |     def get_arguments(arguments, option): | ||||||
|  |         # type: (List[THFormal], FunctionOption) -> List[str] | ||||||
|         return [get_argument(argument, option) |         return [get_argument(argument, option) | ||||||
|                 for argument in arguments if not drop_argument(argument, option)] |                 for argument in arguments if not drop_argument(argument, option)] | ||||||
|  |  | ||||||
|     def is_actual_return_long(ret): |     def is_actual_return_long(ret): | ||||||
|  |         # type: (ReturnDecl) -> bool | ||||||
|         if ret['type'] == 'long': |         if ret['type'] == 'long': | ||||||
|             return True |             return True | ||||||
|         if ret['type'] == 'real': |         if ret['type'] == 'real': | ||||||
| @ -851,9 +1042,11 @@ def create_derived(backend_type_env, declarations): | |||||||
|         return False |         return False | ||||||
|  |  | ||||||
|     def get_zero_dim_dispatch_when_scalar(option): |     def get_zero_dim_dispatch_when_scalar(option): | ||||||
|         return option.get('zero_dim_dispatch_when_scalar', False) |         # type: (FunctionOption) -> str | ||||||
|  |         return option.get('zero_dim_dispatch_when_scalar', False)  # type: ignore | ||||||
|  |  | ||||||
|     def handle_zero_dim(env, option): |     def handle_zero_dim(env, option): | ||||||
|  |         # type: (Environment, FunctionOption) -> List[str] | ||||||
|         zero_dim_dispatch = get_zero_dim_dispatch_when_scalar(option) |         zero_dim_dispatch = get_zero_dim_dispatch_when_scalar(option) | ||||||
|         if not zero_dim_dispatch: |         if not zero_dim_dispatch: | ||||||
|             return [] |             return [] | ||||||
| @ -863,6 +1056,7 @@ def create_derived(backend_type_env, declarations): | |||||||
|         return [ZERO_DIM_CHECK.substitute(env, check_name=zero_dim_dispatch, zero_dim_actuals=zero_dim_actuals)] |         return [ZERO_DIM_CHECK.substitute(env, check_name=zero_dim_dispatch, zero_dim_actuals=zero_dim_actuals)] | ||||||
|  |  | ||||||
|     def handle_only_zero_dim(env, option): |     def handle_only_zero_dim(env, option): | ||||||
|  |         # type: (Environment, FunctionOption) -> List[str] | ||||||
|         if option.get('zero_dim_tensor_only', False): |         if option.get('zero_dim_tensor_only', False): | ||||||
|             check_name = get_zero_dim_dispatch_when_scalar(option) |             check_name = get_zero_dim_dispatch_when_scalar(option) | ||||||
|             return [ZERO_DIM_ONLY.substitute(env, check_name=check_name)] |             return [ZERO_DIM_ONLY.substitute(env, check_name=check_name)] | ||||||
| @ -870,6 +1064,7 @@ def create_derived(backend_type_env, declarations): | |||||||
|             return None |             return None | ||||||
|  |  | ||||||
|     def handle_sparse(env, option): |     def handle_sparse(env, option): | ||||||
|  |         # type: (Environment, FunctionOption) -> List[str] | ||||||
|         if 'when_sparse_dispatch' not in option or 'Sparse' in backend_type_env['Backend']: |         if 'when_sparse_dispatch' not in option or 'Sparse' in backend_type_env['Backend']: | ||||||
|             return [] |             return [] | ||||||
|         check_name = option['when_sparse_dispatch'] |         check_name = option['when_sparse_dispatch'] | ||||||
| @ -879,6 +1074,7 @@ def create_derived(backend_type_env, declarations): | |||||||
|         return [SPARSE_CHECK.substitute(env, check_name=check_name, sparse_actuals=sparse_actuals)] |         return [SPARSE_CHECK.substitute(env, check_name=check_name, sparse_actuals=sparse_actuals)] | ||||||
|  |  | ||||||
|     def allocate_arg(env, arg, output_count): |     def allocate_arg(env, arg, output_count): | ||||||
|  |         # type: (Environment, THFormal, int) -> List[str] | ||||||
|         name = arg['name'] |         name = arg['name'] | ||||||
|         allocation = CodeTemplate(ALLOC_WRAP[arg['type']]).substitute(env, arguments=[]) |         allocation = CodeTemplate(ALLOC_WRAP[arg['type']]).substitute(env, arguments=[]) | ||||||
|         tensor_arg = '{}_'.format(name) |         tensor_arg = '{}_'.format(name) | ||||||
| @ -892,6 +1088,7 @@ def create_derived(backend_type_env, declarations): | |||||||
|         ] |         ] | ||||||
|  |  | ||||||
|     def resize_arg(arg): |     def resize_arg(arg): | ||||||
|  |         # type: (THFormal) -> str | ||||||
|         resize = arg['resize'] |         resize = arg['resize'] | ||||||
|         if isinstance(resize, str): |         if isinstance(resize, str): | ||||||
|             return "{}.resize_({}.sizes());".format(arg['name'], resize) |             return "{}.resize_({}.sizes());".format(arg['name'], resize) | ||||||
| @ -904,6 +1101,7 @@ def create_derived(backend_type_env, declarations): | |||||||
|             return "{}.resize_({{ {} }});".format(arg['name'], ','.join(dims)) |             return "{}.resize_({{ {} }});".format(arg['name'], ','.join(dims)) | ||||||
|  |  | ||||||
|     def handle_call(env, option, cimpl): |     def handle_call(env, option, cimpl): | ||||||
|  |         # type: (Environment, FunctionOption, FunctionOption) -> str | ||||||
|         is_nn = option['mode'] == 'NN' |         is_nn = option['mode'] == 'NN' | ||||||
|         actuals = get_arguments(cimpl['arguments'], option) |         actuals = get_arguments(cimpl['arguments'], option) | ||||||
|         if is_cuda or is_nn: |         if is_cuda or is_nn: | ||||||
| @ -926,7 +1124,8 @@ def create_derived(backend_type_env, declarations): | |||||||
|         return call |         return call | ||||||
|  |  | ||||||
|     def emit_body(env, option): |     def emit_body(env, option): | ||||||
|         body = [] |         # type: (Environment, FunctionOption) -> List[str] | ||||||
|  |         body = []  # type: List[str] | ||||||
|         body += handle_sparse(env, option) |         body += handle_sparse(env, option) | ||||||
|         body += handle_zero_dim(env, option) |         body += handle_zero_dim(env, option) | ||||||
|         only_zero_dim_check = handle_only_zero_dim(env, option) |         only_zero_dim_check = handle_only_zero_dim(env, option) | ||||||
| @ -937,8 +1136,8 @@ def create_derived(backend_type_env, declarations): | |||||||
|  |  | ||||||
|         # arguments are potentially duplicated because of one argument |         # arguments are potentially duplicated because of one argument | ||||||
|         # referencing another |         # referencing another | ||||||
|         seen_names = set() |         seen_names = set()  # type: Set[str] | ||||||
|         seen_tensorlists = set() |         seen_tensorlists = set()  # type: Set[str] | ||||||
|         count = 0 |         count = 0 | ||||||
|         output_count = 0 |         output_count = 0 | ||||||
|  |  | ||||||
| @ -1064,7 +1263,7 @@ def create_derived(backend_type_env, declarations): | |||||||
|                         scalar_check = 'maybe_scalar' |                         scalar_check = 'maybe_scalar' | ||||||
|                 for arg in arguments: |                 for arg in arguments: | ||||||
|                     scalar_check_arg = (scalar_check if not isinstance(scalar_check, dict) |                     scalar_check_arg = (scalar_check if not isinstance(scalar_check, dict) | ||||||
|                                         else scalar_check.get(arg['name'])) |                                         else scalar_check.get(arg['name']))  # type: ignore | ||||||
|                     if scalar_check_arg is not None: |                     if scalar_check_arg is not None: | ||||||
|                         stmt = "{}_->maybeScalar({});".format(arg['name'], scalar_check_arg) |                         stmt = "{}_->maybeScalar({});".format(arg['name'], scalar_check_arg) | ||||||
|                         if nullable_argument(arg): |                         if nullable_argument(arg): | ||||||
| @ -1114,11 +1313,12 @@ def create_derived(backend_type_env, declarations): | |||||||
|         return body |         return body | ||||||
|  |  | ||||||
|     def process_option(option): |     def process_option(option): | ||||||
|  |         # type: (FunctionOption) -> None | ||||||
|         pair = (backend_type_env['Backend'], |         pair = (backend_type_env['Backend'], | ||||||
|                 backend_type_env['ScalarName']) |                 backend_type_env['ScalarName']) | ||||||
|         if pair in option['backend_type_pairs']: |         if pair in option['backend_type_pairs']: | ||||||
|             env = nested_dict(option, backend_type_env) |             env = nested_dict(option, backend_type_env) | ||||||
|             body = emit_body(env, option) |             body = emit_body(env, option)  # type: ignore | ||||||
|             option['type_definition_body'] = body |             option['type_definition_body'] = body | ||||||
|             type_object_declarations.append( |             type_object_declarations.append( | ||||||
|                 TYPE_DERIVED_DECLARATION.substitute(env)) |                 TYPE_DERIVED_DECLARATION.substitute(env)) | ||||||
| @ -1126,6 +1326,7 @@ def create_derived(backend_type_env, declarations): | |||||||
|                 TYPE_DERIVED_DEFINITION.substitute(env)) |                 TYPE_DERIVED_DEFINITION.substitute(env)) | ||||||
|  |  | ||||||
|     def process_native(option): |     def process_native(option): | ||||||
|  |         # type: (FunctionOption) -> None | ||||||
|         dispatch = option['type_method_definition_dispatch'] |         dispatch = option['type_method_definition_dispatch'] | ||||||
|         env = nested_dict(option, backend_type_env) |         env = nested_dict(option, backend_type_env) | ||||||
|  |  | ||||||
|  | |||||||
		Reference in New Issue
	
	Block a user