mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 12:54:11 +08:00
Initial type hints for function_wrapper (#4947)
* Initial type hints for function_wrapper * Don't break python 2 * Update TopEnvironment * Add mypy check to travis * Add .mypy_cache to .gitignore
This commit is contained in:
committed by
Edward Z. Yang
parent
696db00bcd
commit
0629785645
1
.gitignore
vendored
1
.gitignore
vendored
@ -40,6 +40,7 @@ test/.coverage
|
||||
*/**/*.dylib*
|
||||
test/data/legacy_serialized.pt
|
||||
test/data/linear.pt
|
||||
.mypy_cache
|
||||
|
||||
# IPython notebook checkpoints
|
||||
.ipynb_checkpoints
|
||||
|
@ -12,7 +12,12 @@ sudo: false
|
||||
matrix:
|
||||
fast_finish: true
|
||||
include:
|
||||
env: LINT_CHECK
|
||||
- env: LINT_CHECK
|
||||
python: "2.7"
|
||||
install: pip install flake8
|
||||
script: flake8
|
||||
# mypy will complain about various files. Limiting it to only typed files
|
||||
- env: MYPY_TYPE_CHECK
|
||||
python: "3.6"
|
||||
install: pip install mypy mypy-extensions
|
||||
script: mypy --py2 aten/src/ATen/function_wrapper.py
|
||||
|
@ -5,6 +5,14 @@ import re
|
||||
from collections import OrderedDict
|
||||
from code_template import CodeTemplate
|
||||
|
||||
try:
|
||||
from typing import Any, Dict, List, Generic, Optional, Set, Tuple, \
|
||||
Union, TypeVar
|
||||
from mypy_extensions import TypedDict
|
||||
TYPE_HINTS = True
|
||||
except ImportError:
|
||||
TYPE_HINTS = False
|
||||
|
||||
import sys
|
||||
if sys.version_info[0] == 3:
|
||||
string_type = str
|
||||
@ -274,18 +282,168 @@ class nested_dict(object):
|
||||
return r
|
||||
return self.parent[x]
|
||||
|
||||
if TYPE_HINTS:
|
||||
Environment = TypedDict('Environment', {
|
||||
'ScalarName': str,
|
||||
'THTensor': str,
|
||||
'THType': str,
|
||||
'THTensor': str,
|
||||
'Backend': str,
|
||||
'AccScalarName': str,
|
||||
})
|
||||
|
||||
TopEnvironment = TypedDict('TopEnvironment', {
|
||||
'type_registrations': List[str],
|
||||
'type_headers': List[str],
|
||||
'type_method_declarations': List[str],
|
||||
'type_method_definitions': List[str],
|
||||
'type_method_inline_definitions': List[str],
|
||||
'tensor_method_declarations': List[str],
|
||||
'tensor_method_definitions': List[str],
|
||||
'function_declarations': List[str],
|
||||
'function_definitions': List[str],
|
||||
'type_ids': List[str],
|
||||
'native_function_declarations': List[str],
|
||||
})
|
||||
|
||||
# A Declarations.cwrap formal argument
|
||||
# type can contain THTensor* types
|
||||
THFormal = TypedDict('THFormal', {
|
||||
'name': str,
|
||||
'type': str,
|
||||
'dynamic_type': str,
|
||||
'kwarg_only': bool,
|
||||
'is_nullable': bool,
|
||||
'default': str,
|
||||
'default_init': str,
|
||||
'python_default_init': str,
|
||||
'output': bool,
|
||||
'size': int,
|
||||
'declared_type': str,
|
||||
'ignore_check': bool,
|
||||
'allocate': bool,
|
||||
'mask': bool,
|
||||
'if_true': bool,
|
||||
'if_false': bool,
|
||||
'wrap_dim': str,
|
||||
# Broadcast is originally a str but gets unwrapped to a List or Dict in-place
|
||||
'broadcast': Any,
|
||||
'resize': str,
|
||||
'cpu_zero': bool,
|
||||
'zero': bool,
|
||||
}, total=False)
|
||||
|
||||
# A native_functions.yaml formal argument
|
||||
# type can contain Tensor, BoolTensor, IndexTensor types
|
||||
NativeFormal = TypedDict('NativeFormal', {
|
||||
'name': str,
|
||||
'type': str,
|
||||
'dynamic_type': str,
|
||||
'kwarg_only': bool,
|
||||
'is_nullable': bool,
|
||||
'default': str,
|
||||
'default_init': str,
|
||||
'python_default_init': str,
|
||||
'output': bool,
|
||||
'size': int,
|
||||
}, total=False)
|
||||
|
||||
# Generic ATen formal.
|
||||
# type can contain Tensor& reference types.
|
||||
AtFormal = TypedDict('AtFormal', {
|
||||
'name': str,
|
||||
'type': str,
|
||||
'dynamic_type': str,
|
||||
'kwarg_only': bool,
|
||||
'is_nullable': bool,
|
||||
'default': str,
|
||||
'default_init': str,
|
||||
'python_default_init': str,
|
||||
'output': bool,
|
||||
'size': int,
|
||||
}, total=False)
|
||||
|
||||
ReturnType = TypedDict('ReturnType', {
|
||||
'name': str,
|
||||
'type': str,
|
||||
'dynamic_type': str,
|
||||
}, total=False)
|
||||
|
||||
ReturnDecl = TypedDict('ReturnDecl', {
|
||||
'kind': str,
|
||||
'type': str,
|
||||
'arguments': List[int],
|
||||
}, total=False)
|
||||
|
||||
# Represents a buffer in nn.yaml
|
||||
NNBuffer = TypedDict('NNBuffer', {
|
||||
'name': str,
|
||||
})
|
||||
|
||||
FunctionOption = TypedDict('FunctionOption', {
|
||||
'arguments': List[THFormal],
|
||||
'mode': str,
|
||||
'name': str,
|
||||
'return': ReturnDecl,
|
||||
'variants': str,
|
||||
'type_method_definition_dispatch': str,
|
||||
'cname': str,
|
||||
'backends': List[str],
|
||||
'api_name': str,
|
||||
'backend_type_pairs': List[Tuple[str, str]],
|
||||
'inplace': bool,
|
||||
'aten_dense_sparse': bool,
|
||||
'sparse': bool,
|
||||
'scalar_check': str,
|
||||
'aten_custom_call': str,
|
||||
'type_definition_body': List[str],
|
||||
# cimpls is really a List[FunctionOption]
|
||||
'cimpls': List[Any],
|
||||
'when_spares_dispatch': str,
|
||||
'actuals': List[str],
|
||||
'buffers': List[NNBuffer],
|
||||
'zero_dim_dispatch_when_scalar': str,
|
||||
'zero_dim_tensor_only': bool,
|
||||
'when_sparse_dispatch': str,
|
||||
'formals_list': List[AtFormal],
|
||||
'condition': str,
|
||||
'auto_gpu': bool,
|
||||
'cpu_half': bool,
|
||||
# options should be List[FunctionOption]
|
||||
'options': Any,
|
||||
'formals': List[str],
|
||||
'formals_with_defaults': List[str],
|
||||
'returns': List[ReturnType],
|
||||
'return_type': str,
|
||||
'return_call': str,
|
||||
'method_formals': List[str],
|
||||
'method_formals_with_defaults': List[str],
|
||||
'method_actuals': List[str],
|
||||
'const_mark': str,
|
||||
'method_prefix_derived': str,
|
||||
'broadcast_actuals': List[str],
|
||||
'broadcast_returns': List[str],
|
||||
'inferred_type': str,
|
||||
'broadcast_function': str,
|
||||
'broadcast_modified_actuals': List[str],
|
||||
'native_type_method_dispatch': str,
|
||||
})
|
||||
|
||||
|
||||
def is_real_argument_to_wrapper(argument):
|
||||
# type: (THFormal) -> bool
|
||||
return not argument.get('output', False) and\
|
||||
argument['type'] != 'CONSTANT' and\
|
||||
argument['type'] != 'argument'
|
||||
|
||||
|
||||
def is_mutable_formal_argument(argument, option):
|
||||
# type: (THFormal, FunctionOption) -> bool
|
||||
return argument.get('output') or option['inplace'] and argument['name'] == 'self'
|
||||
|
||||
|
||||
def to_return_type(arg, option):
|
||||
# type: (THFormal, FunctionOption) -> ReturnType
|
||||
t = arg['type']
|
||||
rt = TYPE_RETURN.get(t, t)
|
||||
if rt == 'Tensor' and not arg.get('allocate'):
|
||||
@ -300,8 +458,10 @@ def to_return_type(arg, option):
|
||||
|
||||
|
||||
def create_generic(top_env, declarations):
|
||||
# type: (TopEnvironment, List[FunctionOption]) -> List[OrderedDict]
|
||||
# translates defaults from cwrap types to C++ values
|
||||
def translate_default(argument, type_str, default):
|
||||
# type: (THFormal, str, Any) -> Any
|
||||
if default is None:
|
||||
# cause the default constructor for the object to run
|
||||
return '{}'
|
||||
@ -320,6 +480,7 @@ def create_generic(top_env, declarations):
|
||||
# change from THTensor* to Tensor & so we get how it will appear
|
||||
# in the aten argument list...
|
||||
def translate_formal(argument, option):
|
||||
# type: (THFormal, FunctionOption) -> AtFormal
|
||||
type_str = TYPE_FORMAL_GENERIC.get(argument['type'], argument['type'])
|
||||
if type_str == 'Tensor &' and not is_mutable_formal_argument(argument, option):
|
||||
type_str = 'const ' + type_str
|
||||
@ -327,7 +488,7 @@ def create_generic(top_env, declarations):
|
||||
'name': argument['name'],
|
||||
'type': type_str,
|
||||
'dynamic_type': DYNAMIC_TYPE.get(argument['type'], argument['type']),
|
||||
}
|
||||
} # type: AtFormal
|
||||
if 'kwarg_only' in argument:
|
||||
translated['kwarg_only'] = argument['kwarg_only']
|
||||
if 'default' in argument:
|
||||
@ -347,11 +508,13 @@ def create_generic(top_env, declarations):
|
||||
return translated
|
||||
|
||||
def get_formals(option, include_constants=False):
|
||||
seen = set()
|
||||
pos_args = []
|
||||
kwd_args = []
|
||||
# type: (FunctionOption, bool) -> List[AtFormal]
|
||||
seen = set() # type: Set[str]
|
||||
pos_args = [] # type: List[THFormal]
|
||||
kwd_args = [] # type: List[THFormal]
|
||||
|
||||
def insert(argument):
|
||||
# type: (THFormal) -> None
|
||||
if argument['name'] not in seen:
|
||||
seen.add(argument['name'])
|
||||
if argument.get('kwarg_only', False):
|
||||
@ -360,6 +523,7 @@ def create_generic(top_env, declarations):
|
||||
pos_args.append(argument)
|
||||
|
||||
def has_output_mask(argument):
|
||||
# type: (THFormal) -> bool
|
||||
return argument.get('allocate', False) and argument.get('mask', False)
|
||||
|
||||
for argument in option['arguments']:
|
||||
@ -389,6 +553,7 @@ def create_generic(top_env, declarations):
|
||||
return [translate_formal(argument, option) for argument in result]
|
||||
|
||||
def get_return_types(option):
|
||||
# type: (FunctionOption) -> List[ReturnType]
|
||||
ret = option['return']
|
||||
if ret['kind'] == 'arguments':
|
||||
argument_indices = ret['arguments']
|
||||
@ -407,11 +572,13 @@ def create_generic(top_env, declarations):
|
||||
raise Exception("format_return_type")
|
||||
|
||||
def format_return_type(return_types):
|
||||
# type: (List[ReturnType]) -> str
|
||||
if len(return_types) == 1:
|
||||
return return_types[0]['type']
|
||||
return "std::tuple<{}>".format(','.join(r['type'] for r in return_types))
|
||||
|
||||
def find_dispatch_tensor(formals):
|
||||
# type: (List[AtFormal]) -> Optional[str]
|
||||
# dispatch to self if it's a parameter
|
||||
for formal in formals:
|
||||
if formal['name'] == 'self' and formal['dynamic_type'] == 'Tensor':
|
||||
@ -423,9 +590,11 @@ def create_generic(top_env, declarations):
|
||||
return None
|
||||
|
||||
def format_formal(f):
|
||||
# type: (AtFormal) -> str
|
||||
return '{} {}'.format(f['type'], f['name'])
|
||||
|
||||
def formal_with_default(f):
|
||||
# type: (AtFormal) -> str
|
||||
s = format_formal(f)
|
||||
v = f.get('default')
|
||||
if v is None:
|
||||
@ -435,11 +604,15 @@ def create_generic(top_env, declarations):
|
||||
return '{}={}'.format(s, v)
|
||||
|
||||
def get_broadcast_argument(option):
|
||||
# type: (FunctionOption) -> Optional[THFormal]
|
||||
for argument in option['arguments']:
|
||||
if argument.get('broadcast'):
|
||||
return argument
|
||||
return None
|
||||
|
||||
def get_broadcast_actuals(broadcast_arg, broadcast_inplace, broadcast_dims):
|
||||
# type: (THFormal, bool, bool) -> List[str]
|
||||
# Note: broadcast_dims can change type...
|
||||
# return the actuals that will be passed to the broadcast function.
|
||||
# 1) in the common case, this is the broadcasted argument (e.g. "self") followed by the tensors
|
||||
# that it is broadcasted against (comma-separated) (e.g. "self, tensor1, tensor2").
|
||||
@ -451,14 +624,15 @@ def create_generic(top_env, declarations):
|
||||
else:
|
||||
broadcast_dims_spec = broadcast_arg['broadcast'].split()[1].split(':')[1].split(',')
|
||||
# generate size call for each dimension
|
||||
broadcast_dims = ([x.split('.')[0] + '.size(' + x.split('.')[1].replace('dim', '') + ')'
|
||||
broadcast_dims = ([x.split('.')[0] + '.size(' + x.split('.')[1].replace('dim', '') + ')' # type: ignore
|
||||
for x in broadcast_dims_spec])
|
||||
broadcast_dims_init_list = '{' + ','.join(broadcast_dims) + '}'
|
||||
broadcast_dims_init_list = '{' + ','.join(broadcast_dims) + '}' # type: ignore
|
||||
broadcast_actuals = [broadcast_arg['name'], broadcast_dims_init_list]
|
||||
|
||||
return broadcast_actuals
|
||||
|
||||
def emit_nn_body(option):
|
||||
# type: (FunctionOption) -> Union[str, List[str]]
|
||||
# Concrete definition on Type.cpp for NN functions. Delegates to the
|
||||
# xxx_forward variant variant after creating any necessary buffers.
|
||||
actuals = option['actuals']
|
||||
@ -468,7 +642,7 @@ def create_generic(top_env, declarations):
|
||||
if len(option['buffers']) == 0:
|
||||
return 'return {}({});'.format(fwd_name, ', '.join(actuals))
|
||||
|
||||
body = []
|
||||
body = [] # type: List[str]
|
||||
if option['api_name'].endswith('_out'):
|
||||
# _out variants must create buffers and insert them in the
|
||||
# arguments list between output and input arguments
|
||||
@ -482,6 +656,7 @@ def create_generic(top_env, declarations):
|
||||
return body
|
||||
|
||||
def process_option(option, output_options):
|
||||
# type: (FunctionOption, List[OrderedDict]) -> None
|
||||
option['inplace'] = re.search(
|
||||
'(^__i|[^_]_$)', option['api_name']) is not None
|
||||
|
||||
@ -587,11 +762,13 @@ def create_generic(top_env, declarations):
|
||||
]))
|
||||
|
||||
def native_get_formals(option, include_constants=False):
|
||||
seen = set()
|
||||
# type: (FunctionOption, bool) -> List[AtFormal]
|
||||
seen = set() # type: Set[str]
|
||||
pos_args = []
|
||||
kwd_args = []
|
||||
|
||||
def insert(argument):
|
||||
# type: (NativeFormal) -> None
|
||||
if argument['name'] not in seen:
|
||||
seen.add(argument['name'])
|
||||
if argument.get('kwarg_only', False):
|
||||
@ -605,6 +782,7 @@ def create_generic(top_env, declarations):
|
||||
# not clear we need dynamic_type translation as we can specify the correct type
|
||||
# directly in native functions
|
||||
def add_type_as_dynamic_type(argument, option):
|
||||
# type: (NativeFormal, FunctionOption) -> NativeFormal
|
||||
argument['dynamic_type'] = argument['type']
|
||||
return argument
|
||||
|
||||
@ -613,7 +791,9 @@ def create_generic(top_env, declarations):
|
||||
|
||||
# ensure we get reference-type formals when appropriate
|
||||
def native_translate_formals(argument, option):
|
||||
# type: (NativeFormal, FunctionOption) -> AtFormal
|
||||
def translate_map(const):
|
||||
# type: (bool) -> Dict[str, str]
|
||||
return {
|
||||
'Tensor': 'const Tensor &' if const else 'Tensor &',
|
||||
'BoolTensor': 'const Tensor &' if const else 'Tensor &',
|
||||
@ -632,9 +812,10 @@ def create_generic(top_env, declarations):
|
||||
|
||||
# this can return multiple return types in a list, e.g. ['Tensor', 'Tensor']
|
||||
def native_get_return_types(option):
|
||||
# type: (FunctionOption) -> List[ReturnType]
|
||||
ret = option['return']
|
||||
|
||||
return_types = []
|
||||
return_types = [] # List[ReturnType]
|
||||
for t_raw in ret:
|
||||
if isinstance(t_raw, string_type):
|
||||
t = t_raw
|
||||
@ -653,7 +834,7 @@ def create_generic(top_env, declarations):
|
||||
rtype = {
|
||||
'type': actual_return_type,
|
||||
'dynamic_type': t,
|
||||
}
|
||||
} # type: ReturnType
|
||||
if name is not None:
|
||||
rtype['name'] = name
|
||||
return_types.append(rtype)
|
||||
@ -661,6 +842,7 @@ def create_generic(top_env, declarations):
|
||||
return return_types
|
||||
|
||||
def process_native(option, output_options):
|
||||
# type: (FunctionOption, List[OrderedDict]) -> None
|
||||
option['inplace'] = re.search(
|
||||
'(^__i|[^_]_$)', option['api_name']) is not None
|
||||
|
||||
@ -721,7 +903,7 @@ def create_generic(top_env, declarations):
|
||||
|
||||
# generate the at::native function declarations (i.e. what the user will implement)
|
||||
if isinstance(dispatch, dict):
|
||||
generated_native_functions = []
|
||||
generated_native_functions = [] # type: List[str]
|
||||
for key in sorted(dispatch.keys()):
|
||||
value = dispatch[key]
|
||||
if value not in generated_native_functions:
|
||||
@ -761,9 +943,9 @@ def create_generic(top_env, declarations):
|
||||
('abstract', abstract),
|
||||
]))
|
||||
|
||||
output_declarations = []
|
||||
output_declarations = [] # type: List[OrderedDict]
|
||||
for declaration in declarations:
|
||||
output_options = []
|
||||
output_options = [] # type: List[OrderedDict]
|
||||
for option in declaration['options']:
|
||||
try:
|
||||
if option['mode'] != 'native':
|
||||
@ -777,6 +959,7 @@ def create_generic(top_env, declarations):
|
||||
|
||||
|
||||
def create_derived(backend_type_env, declarations):
|
||||
# type: (Environment, List[FunctionOption]) -> Tuple[List[str], List[str]]
|
||||
type_object_declarations = []
|
||||
type_object_definitions = []
|
||||
|
||||
@ -785,21 +968,26 @@ def create_derived(backend_type_env, declarations):
|
||||
real_is_half = backend_type_env['ScalarName'] == 'Half'
|
||||
|
||||
def replace_with_null(argument):
|
||||
# type: (THFormal) -> bool
|
||||
return (argument['type'] == 'THGenerator*' and
|
||||
backend_type_env['Backend'] == 'CUDA')
|
||||
|
||||
def requires_checked_cast(argument):
|
||||
# type: (THFormal) -> bool
|
||||
if argument['type'] == 'IntList':
|
||||
return 'size' in argument
|
||||
return argument['type'] in CHECKED_CAST
|
||||
|
||||
def nullable_argument(argument):
|
||||
# type: (THFormal) -> bool
|
||||
return argument.get('is_nullable', False)
|
||||
|
||||
def bool_option_is_string(argument):
|
||||
# type: (THFormal) -> bool
|
||||
return 'if_true' in argument and isinstance(argument['if_true'], string_type)
|
||||
|
||||
def get_argument(argument, option):
|
||||
# type: (THFormal, FunctionOption) -> str
|
||||
if replace_with_null(argument):
|
||||
return 'NULL'
|
||||
elif requires_checked_cast(argument):
|
||||
@ -834,14 +1022,17 @@ def create_derived(backend_type_env, declarations):
|
||||
return argument['name']
|
||||
|
||||
def drop_argument(argument, option):
|
||||
# type: (THFormal, FunctionOption) -> bool
|
||||
return 'CUDA' in backend_type_env['Backend'] and (
|
||||
option['mode'] == 'TH' and argument['type'] == 'THGenerator*')
|
||||
|
||||
def get_arguments(arguments, option):
|
||||
# type: (List[THFormal], FunctionOption) -> List[str]
|
||||
return [get_argument(argument, option)
|
||||
for argument in arguments if not drop_argument(argument, option)]
|
||||
|
||||
def is_actual_return_long(ret):
|
||||
# type: (ReturnDecl) -> bool
|
||||
if ret['type'] == 'long':
|
||||
return True
|
||||
if ret['type'] == 'real':
|
||||
@ -851,9 +1042,11 @@ def create_derived(backend_type_env, declarations):
|
||||
return False
|
||||
|
||||
def get_zero_dim_dispatch_when_scalar(option):
|
||||
return option.get('zero_dim_dispatch_when_scalar', False)
|
||||
# type: (FunctionOption) -> str
|
||||
return option.get('zero_dim_dispatch_when_scalar', False) # type: ignore
|
||||
|
||||
def handle_zero_dim(env, option):
|
||||
# type: (Environment, FunctionOption) -> List[str]
|
||||
zero_dim_dispatch = get_zero_dim_dispatch_when_scalar(option)
|
||||
if not zero_dim_dispatch:
|
||||
return []
|
||||
@ -863,6 +1056,7 @@ def create_derived(backend_type_env, declarations):
|
||||
return [ZERO_DIM_CHECK.substitute(env, check_name=zero_dim_dispatch, zero_dim_actuals=zero_dim_actuals)]
|
||||
|
||||
def handle_only_zero_dim(env, option):
|
||||
# type: (Environment, FunctionOption) -> List[str]
|
||||
if option.get('zero_dim_tensor_only', False):
|
||||
check_name = get_zero_dim_dispatch_when_scalar(option)
|
||||
return [ZERO_DIM_ONLY.substitute(env, check_name=check_name)]
|
||||
@ -870,6 +1064,7 @@ def create_derived(backend_type_env, declarations):
|
||||
return None
|
||||
|
||||
def handle_sparse(env, option):
|
||||
# type: (Environment, FunctionOption) -> List[str]
|
||||
if 'when_sparse_dispatch' not in option or 'Sparse' in backend_type_env['Backend']:
|
||||
return []
|
||||
check_name = option['when_sparse_dispatch']
|
||||
@ -879,6 +1074,7 @@ def create_derived(backend_type_env, declarations):
|
||||
return [SPARSE_CHECK.substitute(env, check_name=check_name, sparse_actuals=sparse_actuals)]
|
||||
|
||||
def allocate_arg(env, arg, output_count):
|
||||
# type: (Environment, THFormal, int) -> List[str]
|
||||
name = arg['name']
|
||||
allocation = CodeTemplate(ALLOC_WRAP[arg['type']]).substitute(env, arguments=[])
|
||||
tensor_arg = '{}_'.format(name)
|
||||
@ -892,6 +1088,7 @@ def create_derived(backend_type_env, declarations):
|
||||
]
|
||||
|
||||
def resize_arg(arg):
|
||||
# type: (THFormal) -> str
|
||||
resize = arg['resize']
|
||||
if isinstance(resize, str):
|
||||
return "{}.resize_({}.sizes());".format(arg['name'], resize)
|
||||
@ -904,6 +1101,7 @@ def create_derived(backend_type_env, declarations):
|
||||
return "{}.resize_({{ {} }});".format(arg['name'], ','.join(dims))
|
||||
|
||||
def handle_call(env, option, cimpl):
|
||||
# type: (Environment, FunctionOption, FunctionOption) -> str
|
||||
is_nn = option['mode'] == 'NN'
|
||||
actuals = get_arguments(cimpl['arguments'], option)
|
||||
if is_cuda or is_nn:
|
||||
@ -926,7 +1124,8 @@ def create_derived(backend_type_env, declarations):
|
||||
return call
|
||||
|
||||
def emit_body(env, option):
|
||||
body = []
|
||||
# type: (Environment, FunctionOption) -> List[str]
|
||||
body = [] # type: List[str]
|
||||
body += handle_sparse(env, option)
|
||||
body += handle_zero_dim(env, option)
|
||||
only_zero_dim_check = handle_only_zero_dim(env, option)
|
||||
@ -937,8 +1136,8 @@ def create_derived(backend_type_env, declarations):
|
||||
|
||||
# arguments are potentially duplicated because of one argument
|
||||
# referencing another
|
||||
seen_names = set()
|
||||
seen_tensorlists = set()
|
||||
seen_names = set() # type: Set[str]
|
||||
seen_tensorlists = set() # type: Set[str]
|
||||
count = 0
|
||||
output_count = 0
|
||||
|
||||
@ -1064,7 +1263,7 @@ def create_derived(backend_type_env, declarations):
|
||||
scalar_check = 'maybe_scalar'
|
||||
for arg in arguments:
|
||||
scalar_check_arg = (scalar_check if not isinstance(scalar_check, dict)
|
||||
else scalar_check.get(arg['name']))
|
||||
else scalar_check.get(arg['name'])) # type: ignore
|
||||
if scalar_check_arg is not None:
|
||||
stmt = "{}_->maybeScalar({});".format(arg['name'], scalar_check_arg)
|
||||
if nullable_argument(arg):
|
||||
@ -1114,11 +1313,12 @@ def create_derived(backend_type_env, declarations):
|
||||
return body
|
||||
|
||||
def process_option(option):
|
||||
# type: (FunctionOption) -> None
|
||||
pair = (backend_type_env['Backend'],
|
||||
backend_type_env['ScalarName'])
|
||||
if pair in option['backend_type_pairs']:
|
||||
env = nested_dict(option, backend_type_env)
|
||||
body = emit_body(env, option)
|
||||
body = emit_body(env, option) # type: ignore
|
||||
option['type_definition_body'] = body
|
||||
type_object_declarations.append(
|
||||
TYPE_DERIVED_DECLARATION.substitute(env))
|
||||
@ -1126,6 +1326,7 @@ def create_derived(backend_type_env, declarations):
|
||||
TYPE_DERIVED_DEFINITION.substitute(env))
|
||||
|
||||
def process_native(option):
|
||||
# type: (FunctionOption) -> None
|
||||
dispatch = option['type_method_definition_dispatch']
|
||||
env = nested_dict(option, backend_type_env)
|
||||
|
||||
|
Reference in New Issue
Block a user