mirror of
				https://github.com/pytorch/pytorch.git
				synced 2025-11-04 08:00:58 +08:00 
			
		
		
		
	Summary: This adds 88 matches. Pull Request resolved: https://github.com/pytorch/pytorch/pull/17386 Differential Revision: D14179139 Pulled By: cpuhrsch fbshipit-source-id: 2c3263b8e4d084db84791e53290e8c8b1b7aecd5
		
			
				
	
	
		
			926 lines
		
	
	
		
			37 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
			
		
		
	
	
			926 lines
		
	
	
		
			37 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
# Generates Python bindings for ATen functions
 | 
						|
#
 | 
						|
# The bindings are generated as methods on python_variable, as functions on the
# torch._C._nn object, or as functions in the "torch" module.
 | 
						|
#
 | 
						|
from collections import defaultdict
 | 
						|
import re
 | 
						|
from .nested_dict import nested_dict
 | 
						|
from .gen_variable_type import should_trace
 | 
						|
from .utils import write
 | 
						|
 | 
						|
try:
 | 
						|
    from src.ATen.code_template import CodeTemplate
 | 
						|
except ImportError:
 | 
						|
    from tools.shared.module_loader import import_module
 | 
						|
    CodeTemplate = import_module('code_template', 'aten/src/ATen/code_template.py').CodeTemplate
 | 
						|
 | 
						|
# These functions require manual Python bindings or are not exposed to Python.
# Each entry is a regular expression. should_generate_python_binding() wraps it
# as '^' + pattern + '$', so a pattern must match the whole declaration name
# (e.g. '.*_backward' does not match 'foo_backward_out', which is why the
# '_backward_(out|input|weight|bias)' variants are listed separately).
SKIP_PYTHON_BINDINGS = [
    'alias', 'contiguous', 'is_cuda', 'is_sparse', 'size', 'stride',
    '.*_backward', '.*_backward_(out|input|weight|bias)', '.*_forward',
    '.*_forward_out', '_unsafe_view', 'tensor', '_?sparse_coo_tensor.*',
    '_arange.*', '_range.*', '_linspace.*', '_logspace.*',
    '_sparse_add_out', '_sparse_div.*', '_sparse_mul.*', '_sparse_sub.*',
    'index',
    '_indexCopy_', 'max_values', 'min_values', 'argmax', 'argmin',
    '_cumsum.*', '_cumprod.*', '_sum.*', '_prod.*',
    '_th_.*', '_thnn_.*',
    'arange.*', 'range.*', '_gesv.*', '_getri.*', '_inverse.*',
    '_potrs.*', '_cholesky.*',
    'slice', 'randint(_out)?',
    'item', '_local_scalar_dense',
    'max_pool1d', 'max_pool2d', 'max_pool3d', 'linear', 'to',
    'copy_sparse_to_sparse_',
]

# These function signatures are not exposed to Python. Note that this signature
# list does not support regex: each entry is compared by exact string equality
# against '<name>(<simple_type>, <simple_type>, ...)', the signature string that
# should_generate_python_binding() builds from a declaration's arguments.
SKIP_PYTHON_BINDINGS_SIGNATURES = [
    'add(Tensor, Scalar, Scalar)', 'add_(Tensor, Scalar, Scalar)',
    'sub(Tensor, Scalar, Scalar)', 'sub_(Tensor, Scalar, Scalar)',
    'mul(Tensor, Scalar)', 'mul_(Tensor, Scalar)',
    'div(Tensor, Scalar)', 'div_(Tensor, Scalar)',
]
 | 
						|
 | 
						|
# C++ wrapper for a binding that accepts positional and keyword arguments.
# The arguments are parsed against ${signatures} with PythonArgParser, and the
# generated ${dispatch} code (one branch per overload, keyed on r.idx) runs.
# If no branch returns, the function falls through to Py_RETURN_NONE.
PY_VARIABLE_METHOD_VARARGS = CodeTemplate("""\
static PyObject * ${pycname}(PyObject* self_, PyObject* args, PyObject* kwargs)
{
  HANDLE_TH_ERRORS
  static PythonArgParser parser({
    ${signatures}
  }, /*traceable=*/${traceable});
  ${unpack_self}
  ParsedArgs<${max_args}> parsed_args;
  auto r = parser.parse(args, kwargs, parsed_args);
  ${declare_namedtuple_return_types}
  ${dispatch}
  Py_RETURN_NONE;
  END_HANDLE_TH_ERRORS
}
""")

# C++ wrapper with the (self, args) signature and no argument parser. It calls
# the dispatch helper directly with ${actuals} and wraps the result. Presumably
# this is used for bindings that take no Python arguments beyond self; confirm
# against the caller.
PY_VARIABLE_METHOD_NOARGS = CodeTemplate("""\
static PyObject * ${pycname}(PyObject* self_, PyObject* args)
{
  HANDLE_TH_ERRORS
  ${declare_namedtuple_return_types}
  ${unpack_self}
  return wrap(${namedtuple_return_type}${dispatch_name}(${actuals}));
  END_HANDLE_TH_ERRORS
}
""")

# One branch of the overload switch, selected by the index of the matched
# signature (r.idx). ${cond} is 'if' for the first overload and '} else if'
# afterwards (see emit_dispatch), so the template leaves its brace open on
# purpose.
PY_VARIABLE_CASE = CodeTemplate("""\
${cond} (r.idx == ${i}) {
  ${call_dispatch}
""")

# Chooses between the functional overload and the out= overload, depending on
# whether the out argument was passed as None.
PY_VARIABLE_OUT = CodeTemplate("""\
if (r.isNone(${out_idx})) {
  ${call_dispatch}
} else {
  ${call_dispatch_out}
}
""")

# Same as PY_VARIABLE_OUT, but before the out= call it passes the out tensor
# and the dtype/layout/device arguments (with their None-ness) to
# check_out_type_matches. Presumably that helper validates consistency and
# throws on mismatch; confirm in its C++ definition.
PY_VARIABLE_OUT_CHECK_TYPE = CodeTemplate("""\
if (r.isNone(${out_idx})) {
  ${call_dispatch}
} else {
  check_out_type_matches(r.tensor(${out_idx}), r.scalartype(${type_idx}), r.isNone(${type_idx}),
                         r.layout(${layout_idx}), r.isNone(${layout_idx}),
                         r.device(${device_idx}), r.isNone(${device_idx}));
  ${call_dispatch_out}
}
""")

# Call expression for the generated inline dispatch_* helper.
PY_VARIABLE_CALL_DISPATCH = CodeTemplate("""\
${dispatch_name}(${actuals})""")

# Wraps a dispatch call so that requires_grad is set on its result.
# emit_single_dispatch uses this only when the op has no TensorOptions argument.
PY_VARIABLE_SET_REQUIRES_GRAD = CodeTemplate("""\
${call_dispatch}.set_requires_grad(${requires_grad})""")

# Return statement that converts the C++ result into a Python object via
# wrap(). ${namedtuple_return_type} is an optional prefix for ops that return
# named tuples.
PY_VARIABLE_WRAP = CodeTemplate("""\
return wrap(${namedtuple_return_type}${call_dispatch});""")

# Inline C++ helper that performs the actual ATen call. ${initialize_cuda}
# and ${AutoNoGIL} are either empty or a single statement, chosen in
# emit_single_dispatch.
PY_VARIABLE_DISPATCH = CodeTemplate("""\
inline ${simple_return_type} ${dispatch_name}(${formal_args}) {
  ${initialize_cuda}
  ${AutoNoGIL}
  return ${dispatch_call}(${dispatch_args});
}
""")

# One entry of a PyMethodDef table.
PY_VARIABLE_METHOD_DEF = CodeTemplate("""\
{"${name}", (PyCFunction)${pycname}, ${flags}, NULL},""")

# Declares a PyStructSequence type named torch.return_types.<name> and
# initializes it on first use (guarded by a static flag). Ops that return named
# tuples use this type.
PY_RETURN_NAMEDTUPLE_DEF = CodeTemplate("""\
static PyStructSequence_Field fields${namedtuple_type_index}[] = {
  ${namedtuple_fields} {nullptr}
};
static PyStructSequence_Desc desc${namedtuple_type_index} = {
  "torch.return_types.${name}", nullptr,
  fields${namedtuple_type_index}, ${namedtuple_size}
};
static PyTypeObject type${namedtuple_type_index};
static bool namedtuple_type_initialized${namedtuple_type_index} = false;
if (!namedtuple_type_initialized${namedtuple_type_index}) {
  PyStructSequence_InitType(&type${namedtuple_type_index}, &desc${namedtuple_type_index});
  namedtuple_type_initialized${namedtuple_type_index} = true;
}
""")

# Binds `self` to the C++ variable stored in the Python THPVariable object.
UNPACK_SELF = "auto& self = reinterpret_cast<THPVariable*>(self_)->cdata;"
 | 
						|
 | 
						|
# Python-visible signature string, e.g. "add(Tensor input, Tensor other)".
PYTHON_FUNCTION_SIGNATURE = CodeTemplate("""\
${name}(${py_formal_args})""")

# Return types (after ' &' is stripped) that emit_single_dispatch accepts; any
# other return type trips an assertion there.
# XXX: if you got here because of an assertion failure, it doesn't mean
# it's enough to just extend the list here. Before you do this, make sure
# to add an appropriate wrap() overload in torch/csrc/autograd/utils/wrap_outputs.h.
SUPPORTED_RETURN_TYPES = {
    'Tensor',
    'std::tuple<Tensor,Tensor>',
    'std::tuple<Tensor,Tensor,Tensor>',
    'std::tuple<Tensor,Tensor,Tensor,Tensor>',
    'std::tuple<Tensor,Tensor,Tensor,Tensor,Tensor>',
    'std::tuple<Tensor,Tensor,Tensor,int64_t>',
    'std::tuple<Tensor,Tensor,double,int64_t>',
    'std::vector<Tensor>',
    'Scalar', 'bool', 'int64_t', 'void*', 'void'
}

# Builds a local `options` TensorOptions from the unpacked dtype, device,
# layout and requires_grad arguments. The ${layout} expression is a
# THPLayout, hence `.layout` to extract the C++ layout value.
TENSOR_OPTIONS = CodeTemplate("""\
const auto options = TensorOptions()
    .dtype(${dtype})
    .device(${device})
    .layout(${layout}.layout)
    .requires_grad(${requires_grad});
""")
 | 
						|
 | 
						|
 | 
						|
def should_generate_python_binding(declaration):
    """Decide whether an auto-generated Python binding is emitted for `declaration`.

    A declaration is rejected when:
      * its name fully matches one of the regexes in SKIP_PYTHON_BINDINGS;
      * its '<name>(<simple_type>, ...)' signature is listed verbatim in
        SKIP_PYTHON_BINDINGS_SIGNATURES; or
      * it takes a SparseTensorRef argument (sparse_mask is the one exception).

    Returns True otherwise.
    """
    name = declaration['name']
    if any(re.match('^' + pattern + '$', name) for pattern in SKIP_PYTHON_BINDINGS):
        return False

    arguments = declaration['arguments']
    signature = '{}({})'.format(name, ', '.join(arg['simple_type'] for arg in arguments))
    if signature in SKIP_PYTHON_BINDINGS_SIGNATURES:
        return False

    # TODO: fix handling of SparseTensor. We don't want to generate Python
    # bindings to SparseTensor overloads, such as add(Tensor, SparseTensorRef),
    # since the Tensor-based signature already dynamically dispatches correctly.
    # However, sparse_mask only has a SparseTensor signature so we need to bind
    # that function.
    takes_sparse_ref = any(arg['type'] == 'SparseTensorRef' and name != 'sparse_mask'
                           for arg in arguments)
    return not takes_sparse_ref
 | 
						|
 | 
						|
 | 
						|
def get_py_variable_methods(declarations):
    """
    Return the declarations, grouped by name (ignoring an `_out` suffix), that
    should become methods on Tensor.

    Declarations routed to the nn module (mode 'NN' or python_module 'nn') are
    excluded, as are declarations that are not methods of 'Tensor'.
    """
    def is_tensor_method(declaration):
        if not should_generate_python_binding(declaration):
            return False
        if declaration['mode'] == 'NN' or declaration.get('python_module') == 'nn':
            return False
        return 'Tensor' in declaration['method_of']

    return group_declarations_by_name(declarations, is_tensor_method)
 | 
						|
 | 
						|
 | 
						|
def gen_py_variable_methods(out, declarations, template_path):
    """Render the Tensor-method binding sources into the directory `out`.

    Loads python_variable_methods.cpp and python_variable_methods_dispatch.h
    from `template_path`, fills them from the bindings generated for
    get_py_variable_methods(declarations), and writes files of the same names.
    """
    filenames = ('python_variable_methods.cpp', 'python_variable_methods_dispatch.h')
    templates = [CodeTemplate.from_file(template_path + '/' + filename) for filename in filenames]

    env = create_python_bindings(get_py_variable_methods(declarations), has_self=True)

    for filename, template in zip(filenames, templates):
        write(out, filename, template, env)
 | 
						|
 | 
						|
 | 
						|
def get_py_nn_functions(declarations):
    """
    Return the declarations, grouped by name (ignoring an `_out` suffix), that
    should become functions in the "nn" module.

    A declaration qualifies when it has mode 'NN' or python_module 'nn'.
    """
    def is_nn_function(declaration):
        if not should_generate_python_binding(declaration):
            return False
        return declaration['mode'] == 'NN' or declaration.get('python_module') == 'nn'

    return group_declarations_by_name(declarations, is_nn_function)
 | 
						|
 | 
						|
 | 
						|
def gen_py_nn_functions(out, declarations, template_path):
    """Render the nn-module binding sources into the directory `out`.

    Loads python_nn_functions.cpp, python_nn_functions.h and
    python_nn_functions_dispatch.h from `template_path`, fills them from the
    module-level (no self) bindings for get_py_nn_functions(declarations), and
    writes files of the same names.
    """
    filenames = ('python_nn_functions.cpp', 'python_nn_functions.h', 'python_nn_functions_dispatch.h')
    templates = [CodeTemplate.from_file(template_path + '/' + filename) for filename in filenames]

    env = create_python_bindings(get_py_nn_functions(declarations), has_self=False, is_module=True)

    for filename, template in zip(filenames, templates):
        write(out, filename, template, env)
 | 
						|
 | 
						|
 | 
						|
def get_py_torch_functions(declarations):
    """
    Return the declarations, grouped by name (ignoring an `_out` suffix), that
    should become functions in the "torch" module.

    Declarations routed to the nn module (mode 'NN' or python_module 'nn') are
    excluded, as are declarations that are not 'namespace' functions.
    """
    def is_torch_function(declaration):
        if not should_generate_python_binding(declaration):
            return False
        if declaration['mode'] == 'NN' or declaration.get('python_module') == 'nn':
            return False
        return 'namespace' in declaration['method_of']

    return group_declarations_by_name(declarations, is_torch_function)
 | 
						|
 | 
						|
 | 
						|
def gen_py_torch_functions(out, declarations, template_path):
    """Render the torch-module binding sources into the directory `out`.

    Loads python_torch_functions.cpp and python_torch_functions_dispatch.h from
    `template_path`, fills them from the no-self bindings for
    get_py_torch_functions(declarations), and writes files of the same names.
    """
    filenames = ('python_torch_functions.cpp', 'python_torch_functions_dispatch.h')
    templates = [CodeTemplate.from_file(template_path + '/' + filename) for filename in filenames]

    env = create_python_bindings(get_py_torch_functions(declarations), has_self=False)

    for filename, template in zip(filenames, templates):
        write(out, filename, template, env)
 | 
						|
 | 
						|
 | 
						|
def group_declarations_by_name(declarations, should_bind_fn):
    """Group declarations by name, ignoring a trailing '_out' suffix.

    Only declarations for which `should_bind_fn(declaration)` is truthy are
    kept. 'foo' and 'foo_out' land in the same 'foo' bucket, in input order.

    Returns:
        A defaultdict(list) mapping each group name to its declarations.
    """
    suffix = '_out'
    groups = defaultdict(list)
    for decl in declarations:
        decl_name = decl['name']
        if not should_bind_fn(decl):
            continue
        key = decl_name[:-len(suffix)] if decl_name.endswith(suffix) else decl_name
        groups[key].append(decl)
    return groups
 | 
						|
 | 
						|
 | 
						|
def get_type_default(declaration):
    """Return the Python-side default for a synthesized `dtype` argument.

    randperm* and tril_indices/triu_indices default to 'torch.int64'; every
    other declaration gets 'None'. The result is a Python source string.
    """
    name = declaration['name']
    defaults_to_int64 = name.startswith('randperm') or name in ('tril_indices', 'triu_indices')
    return 'torch.int64' if defaults_to_int64 else 'None'
 | 
						|
 | 
						|
 | 
						|
def create_python_bindings(python_functions, has_self, is_module=False):
 | 
						|
    """Generates Python bindings to ATen functions"""
 | 
						|
    py_methods = []
 | 
						|
    py_method_defs = []
 | 
						|
    py_method_dispatch = []
 | 
						|
 | 
						|
    unpack_methods = {
 | 
						|
        'const Tensor &': 'tensor',
 | 
						|
        'SparseTensorRef': 'tensor',
 | 
						|
        'Tensor &': 'tensor',
 | 
						|
        'Generator *': 'generator',
 | 
						|
        'Storage &': 'storage',
 | 
						|
        'const Type &': 'scalartype',
 | 
						|
        'const THPLayout &': 'layout',
 | 
						|
        'const Device &': 'device',
 | 
						|
        'c10::optional<ScalarType>': 'scalartypeOptional',
 | 
						|
        'c10::optional<Scalar>': 'scalarOptional',
 | 
						|
        'c10::optional<int64_t>': 'toInt64Optional',
 | 
						|
        'IntArrayRef': 'intlist',
 | 
						|
        'int64_t': 'toInt64',
 | 
						|
        'bool': 'toBool',
 | 
						|
        'double': 'toDouble',
 | 
						|
        'std::string': 'string',
 | 
						|
    }
 | 
						|
 | 
						|
    unpack_with_default_methods = {
 | 
						|
        'IntArrayRef': 'setDefaultIntlist',
 | 
						|
        'Scalar': 'scalarWithDefault',
 | 
						|
        'int64_t': 'toInt64WithDefault',
 | 
						|
        'bool': 'setDefaultBool',
 | 
						|
        'double': 'setDefaultDouble',
 | 
						|
        'const Type &': 'scalartypeWithDefault',
 | 
						|
        'const THPLayout &': 'layoutWithDefault',
 | 
						|
        'const Device &': 'deviceWithDefault',
 | 
						|
        'ScalarType': 'scalartypeWithDefault',
 | 
						|
    }
 | 
						|
 | 
						|
    def emit_single_dispatch(declaration, out_idx, base_env):
        """Generate the C++ statements for one overload of a Python binding.

        Works out how each Python argument is unpacked from the parsed
        arguments `r`, the actual expressions passed to the inline dispatch_*
        helper, and that helper's formal parameters. As a side effect, the
        rendered PY_VARIABLE_DISPATCH helper is appended to the enclosing
        `py_method_dispatch` list.

        Args:
            declaration: the declaration dict for this overload.
            out_idx: position of the `out` argument in the parsed Python
                arguments, or None. When set, type arguments are parsed
                starting at out_idx + 1.
            base_env: extra template environment, combined via nested_dict
                with the local env and `declaration` for substitution.

        Returns:
            A list of C++ code strings forming the body of this overload's
            branch.

        Raises:
            AssertionError: if the return type (with ' &' stripped) is not in
                SUPPORTED_RETURN_TYPES, or if more than one Type argument is
                present.
            RuntimeError: for a type-dispatched parameter with multiple
                outputs, an unrecognized python_binding_argument, or a
                declaration that is neither a namespace function nor a
                Tensor method.
        """
        env = {}

        # Drop references from the return type; see the longer note above
        # env['simple_return_type'] below for why this matters.
        simple_return_type = declaration['return_type'].replace(' &', '')
        assert simple_return_type in SUPPORTED_RETURN_TYPES, \
            declaration['name'] + ' returns unsupported type: ' + simple_return_type

        body = []  # C++ statements emitted into this overload's branch
        actuals = []  # argument expressions passed to the dispatch_* helper
        formal_args = []  # parameter declarations of the dispatch_* helper
        arg_idx = 0  # next position in the parsed Python arguments

        def is_output(arg):
            return arg.get('output', False)

        inputs = [arg for arg in declaration['arguments'] if not is_output(arg)]
        outputs = [arg for arg in declaration['arguments'] if is_output(arg)]

        has_tensor_options = any(arg['simple_type'] == 'TensorOptions' for arg in declaration['arguments'])

        def get_type_args(args):
            return [arg for arg in args if arg['simple_type'] == 'Type']
        type_actual_args = get_type_args(declaration['arguments'])
        type_binding_args = get_type_args(declaration['python_binding_arguments'])
        assert len(type_actual_args + type_binding_args) <= 1
        if type_binding_args and len(outputs) == 0:
            # out(s) determines the dtype if it is present, so only use this if there are no outputs.
            type_args = type_binding_args
        else:
            type_args = type_actual_args

        if type_args and len(outputs) > 1:
            raise RuntimeError("Not supported: type dispatched parameter with multiple outputs")

        def parse_arg(arg, arg_index, unpack_args=False):
            # Returns (expr, formal): the C++ expression that reads `arg` from
            # parsed argument `arg_index`, and the matching parameter
            # declaration for the dispatch_* helper. With unpack_args, the
            # value is first bound to a local named after the argument (the
            # declaration is appended to `body`), and expr becomes that name.
            name = arg['name']
            typename = arg['type']
            if typename.startswith('IntArrayRef['):
                typename = 'IntArrayRef'
            if typename.startswith('LongTensor'):
                typename = 'Tensor'

            if arg.get('python_default_init'):
                assert typename in unpack_with_default_methods, \
                    '`{}` type is not supported in python_default_init'.format(typename)
                unpack_with_default = unpack_with_default_methods.get(typename)
                default_expr = arg.get('python_default_init')
                expr = 'r.{}({}, {})'.format(unpack_with_default, arg_index, default_expr)
            else:
                # Types without an explicit unpack_methods entry fall back to
                # the lower-cased type name as the accessor.
                unpack = unpack_methods.get(typename, typename.lower())
                expr = 'r.{}({})'.format(unpack, arg_index)

            if unpack_args:
                body.append('auto {} = {};'.format(name, expr))
                expr = name

            if typename == 'SparseTensorRef':
                expr = 'SparseTensorRef({})'.format(expr)

            # Map the declared C++ type to the dispatch helper's parameter type.
            dispatch_type = typename
            if dispatch_type == 'Tensor':
                dispatch_type = 'const Tensor &'
            elif dispatch_type == 'Tensor &':
                dispatch_type = 'Tensor'
            elif dispatch_type == 'const Device &':
                dispatch_type = 'c10::optional<int32_t>'
            formal = '{} {}'.format(dispatch_type, name)
            return expr, formal

        def append_actuals_formals(actual, formal):
            actuals.append(actual)
            formal_args.append(formal)

        # We always want to unpack when we have TensorOptions.
        unpack = has_tensor_options
        for arg in inputs:
            # Type and TensorOptions inputs are not parsed positionally here;
            # they are handled below via type_args and the dtype/layout/device
            # binding arguments.
            if arg['simple_type'] in ['Type', 'TensorOptions']:
                continue
            if has_self and arg['name'] == 'self':
                # For methods, self comes from the Python object rather than
                # from `r`, so it does not consume a parsed-argument slot.
                formal_args.append('Tensor & self')
                actuals.append('self')
                continue
            append_actuals_formals(*parse_arg(arg, arg_idx, unpack))
            arg_idx += 1

        if len(outputs) == 1:
            append_actuals_formals(*parse_arg(outputs[0], arg_idx))
        elif len(outputs) > 1:
            # Multiple outputs arrive as one tensor list at position arg_idx
            # and are passed to the helper element by element.
            N = len(outputs)
            body.append('auto results = r.tensorlist_n<{}>({});'.format(N, arg_idx))
            for i, arg in enumerate(outputs):
                formal_args.append('Tensor & {}'.format(arg['name']))
                actuals.append('results[{}]'.format(i))

        layout = None
        parsed_type_args = None
        # type args go after the outputs to match the signature generation.
        arg_idx = arg_idx if out_idx is None else out_idx + 1
        for arg in type_args:
            parsed_type_args = parse_arg(arg, arg_idx, unpack)
            arg_idx += 1

        # check python_binding_arguments
        # NOTE(review): has_device_bind is assigned below but never read in this function.
        has_device_bind = False
        requires_grad = None
        python_binding_arguments = declaration.get('python_binding_arguments', [])
        if 'dtype' in (a['name'] for a in python_binding_arguments):
            # Without TensorOptions, advance past one more slot for dtype
            # before computing the layout/device/requires_grad positions.
            if not has_tensor_options:
                arg_idx += 1

        # Positions of the keyword-only binding arguments, which follow dtype.
        if 'layout' in (a['name'] for a in python_binding_arguments):
            layout_idx, device_idx, requires_grad_idx = (arg_idx, arg_idx + 1, arg_idx + 2)
        else:
            device_idx, requires_grad_idx = (arg_idx, arg_idx + 1)

        device = None
        for arg in python_binding_arguments:
            if arg['name'] == 'dtype' and arg['simple_type'] == 'Type':
                pass  # already handled via type_args above
            elif arg['name'] == 'layout' and arg['simple_type'] == 'Layout':
                # out(s) determines the type and layout if it is present, so only use this if there are no outputs.
                if len(outputs) == 0:
                    layout = parse_arg(arg, layout_idx)[0]
            elif arg['name'] == 'device' and arg['simple_type'] == 'Device':
                if len(outputs) == 0:
                    assert parsed_type_args
                    assert layout
                    device, device_type = parse_arg(arg, device_idx, True)

                    if not has_tensor_options:
                        # add type, device formals and corresponding actuals.
                        # The type actual is the ATen type mapped from (ScalarType, Layout, Device)
                        # The device actual is the corresponding AutoGPU index for the Device.
                        formal_args.append(parsed_type_args[1])
                        formal_args.append(device_type)
                        actuals.append("torch::getVariableType({}, {}, {})".format(parsed_type_args[0], layout, device))
                        actuals.append('{}.index()'.format(device))

                    has_device_bind = True
            elif arg['name'] == 'requires_grad' and arg['simple_type'] == 'bool':
                requires_grad = parse_arg(arg, requires_grad_idx)[0]
            else:
                raise RuntimeError(("found {} in python_binding_arguments but only "
                                    "\"bool requires_grad\", \"ScalarType dtype\", \"Layout layout\", "
                                    "\"Device device\" are supported".format(arg)))

        # With TensorOptions, fold dtype/layout/device/requires_grad into a
        # single `options` local and pass it as the last helper argument.
        dtype = parsed_type_args[0] if parsed_type_args else None
        if has_tensor_options and all([dtype, device, layout, requires_grad]):
            body.append(TENSOR_OPTIONS.substitute({
                'dtype': dtype,
                'layout': layout,
                'device': device,
                'requires_grad': requires_grad
            }))
            formal_args.append('const TensorOptions & options')
            actuals.append('options')

        env['unpack_args'] = []
        env['formal_args'] = formal_args
        env['actuals'] = actuals

        if has_tensor_options:
            env['initialize_cuda'] = 'maybe_initialize_cuda(options);'
        else:
            env['initialize_cuda'] = ''

        # Arguments forwarded from the dispatch helper to the ATen call.
        if 'call_args' in declaration:
            env['dispatch_args'] = declaration['call_args']
        else:
            env['dispatch_args'] = [arg['name'] for arg in declaration['arguments']]

        # Tensor methods are called on `self`; namespace functions go through
        # torch:: (TensorOptions or *_like ops) or at::.
        if 'Tensor' in declaration['method_of']:
            env['dispatch_args'] = [arg for arg in env['dispatch_args'] if arg != 'self']
            env['dispatch_call'] = 'self.{}'.format(declaration['name'])
        elif 'namespace' in declaration['method_of']:
            namespace = 'torch' if (has_tensor_options or declaration['name'].endswith('_like')) else 'at'
            env['dispatch_call'] = '{}::{}'.format(namespace, declaration['name'])
        else:
            raise RuntimeError('could not dispatch, neither namespace function nor Tensor method')

        # Release the GIL around the ATen call unless the declaration asks to keep it.
        env['AutoNoGIL'] = 'AutoNoGIL no_gil;' if not declaration['with_gil'] else ''

        # Use the simple_return_type (Tensor) rather than the fancy return type
        # (Tensor &).  This is important because the dispatch functions take
        # mutable arguments *by value*, not by reference.  If you then return
        # a reference to such an argument, you will now have a pointer to a
        # dangling stack entry.  Not good.
        #
        # You want:
        #
        #   Tensor dispatch_selu_(Tensor self) { return at::selu_(self); }
        #
        # *not*
        #
        #   Tensor& dispatch_selu_(Tensor self) { return at::selu_(self); }
        #
        # (NB: We can't make dispatch_selu_ take Tensor&, because the enclosing
        # codegen looks like dispatch_selu_(wrap(tensor)), and you can't take a
        # mutable reference to temporary.  Maybe we could assign it to a
        # variable itself.)
        env['simple_return_type'] = simple_return_type

        env = nested_dict(env, nested_dict(base_env, declaration))
        call_dispatch = PY_VARIABLE_CALL_DISPATCH.substitute(env)
        # Without TensorOptions, requires_grad is applied to the call's result.
        if requires_grad and not has_tensor_options:
            call_dispatch = PY_VARIABLE_SET_REQUIRES_GRAD.substitute(env, call_dispatch=call_dispatch,
                                                                     requires_grad=requires_grad)
        if simple_return_type == 'void':
            body.append('{call_dispatch};'.format(call_dispatch=call_dispatch))
            body.append('Py_RETURN_NONE;')
        else:
            body.append(PY_VARIABLE_WRAP.substitute(env, call_dispatch=call_dispatch))
        py_method_dispatch.append(PY_VARIABLE_DISPATCH.substitute(env))
        return body
 | 
						|
 | 
						|
    def emit_dispatch(i, dictionary, base_env):
        """Render the `r.idx == i` branch for one overload group.

        `dictionary` holds a 'base' declaration and optionally its 'out'
        variant. Without an out variant, the branch is the base dispatch alone.
        With one, the branch switches at runtime on whether the out argument is
        None. If the out variant has a 'dtype' binding argument, the out
        tensor is checked against dtype/layout/device first.
        """
        if 'out' not in dictionary:
            body = emit_single_dispatch(dictionary['base'], None, base_env)
        else:
            out_decl = dictionary['out']
            # out_idx is the number of non-output arguments of the out variant.
            out_idx = sum(1 for arg in out_decl['arguments'] if not arg.get('output', False))
            # Order matters: emit_single_dispatch appends to py_method_dispatch,
            # so the out variant's helper must be generated before the base one.
            call_dispatch_out = emit_single_dispatch(out_decl, out_idx, base_env)
            call_dispatch = emit_single_dispatch(dictionary['base'], out_idx, base_env)
            env = {'call_dispatch_out': call_dispatch_out, 'call_dispatch': call_dispatch}

            binding_names = [d['name'] for d in out_decl.get('python_binding_arguments', [])]
            if 'dtype' in binding_names:
                rendered = PY_VARIABLE_OUT_CHECK_TYPE.substitute(
                    env, out_idx=out_idx, type_idx=out_idx + 1,
                    layout_idx=out_idx + 2, device_idx=out_idx + 3)
            else:
                rendered = PY_VARIABLE_OUT.substitute(env, out_idx=out_idx)
            body = rendered.split('\n')

        cond = '} else if' if i != 0 else 'if'
        return PY_VARIABLE_CASE.substitute(i=i, cond=cond, call_dispatch=body)
 | 
						|
 | 
						|
    def get_python_binding_arguments(declaration):
        """Synthesize extra keyword-only Python arguments for a declaration.

        Returns a list of argument dicts (some of dtype, layout, device,
        requires_grad) exposed to Python in addition to
        declaration['arguments']. Which ones are added depends on:
          * whether it is a "factory" function: it returns a Tensor/TensorList
            and takes no Tensor/TensorList input;
          * whether it is a *_like function;
          * whether it takes a Type or TensorOptions argument.

        Raises:
            ValueError: if a non-output argument is already named requires_grad.
        """
        python_binding_arguments = []
        has_tensor_input_arg = False
        has_type_input_arg = False
        has_options_arg = False
        for arg in declaration['arguments']:
            if arg.get('output', False):
                continue
            typename = arg['simple_type']
            if typename in ['Tensor', 'TensorList']:
                has_tensor_input_arg = True
            if arg['simple_type'] == 'Type':
                has_type_input_arg = True
            elif arg['simple_type'] == 'TensorOptions':
                has_options_arg = True
            if arg['name'] == 'requires_grad':
                # requires_grad may be synthesized below, so a real argument of
                # the same name is rejected (presumably to avoid a clash).
                raise ValueError("argument named requires_grad not supported")

        has_tensor_return = False
        for ret in declaration['returns']:
            if ret['dynamic_type'] in ['Tensor', 'TensorList']:
                # this probably won't work if one of the returns is not a tensor, but it will
                # produce a compile-time error that is obvious
                has_tensor_return = True

        # NOTE(review): `name` is not assigned in this function, so it resolves
        # to a variable of an enclosing scope (not visible here) rather than to
        # declaration['name']. Confirm it is the intended name to test for the
        # '_like' suffix, e.g. for grouped '_out' declarations.
        is_like_function = name.endswith('_like')
        is_like_function_with_options = is_like_function and has_options_arg
        is_factory_function = has_tensor_return and not has_tensor_input_arg
        is_factory_or_like_function = has_tensor_return and (not has_tensor_input_arg or is_like_function)

        # dtype: for factories without an explicit Type input, or any op taking TensorOptions.
        if (is_factory_function and not has_type_input_arg) or has_options_arg:
            default_type = get_type_default(declaration)
            # *_like functions with TensorOptions default the dtype to self's scalar type.
            py_default_dtype = 'self.scalar_type()' if is_like_function_with_options else None
            dtype_arg = {
                'default': default_type,
                'dynamic_type': 'Type',
                'kwarg_only': True,
                'name': 'dtype',
                'type': 'const Type &',
                'simple_type': 'Type',
                'python_default_init': py_default_dtype,
            }
            python_binding_arguments.append(dtype_arg)
        # layout and device: for factories and *_like functions with TensorOptions.
        if is_factory_function or is_like_function_with_options:
            py_default_layout = '*torch::getLayout(self.type().backend())' if is_like_function_with_options else None
            layout_arg = {
                'default': 'torch.strided',
                'dynamic_type': 'Layout',
                'kwarg_only': True,
                'name': 'layout',
                'type': 'const THPLayout &',
                'simple_type': 'Layout',
                'python_default_init': py_default_layout,
            }
            python_binding_arguments.append(layout_arg)
            py_default_device = 'self.device()' if is_like_function_with_options else None
            device_arg = {
                'default': 'None',
                'default_init': 'None',
                'dynamic_type': 'Device',
                'kwarg_only': True,
                'name': 'device',
                'type': 'const Device &',
                'simple_type': 'Device',
                'python_default_init': py_default_device
            }
            python_binding_arguments.append(device_arg)
        # requires_grad: for factories and any *_like function returning a tensor.
        if is_factory_or_like_function:
            requires_grad_arg = {
                'default': False,
                'dynamic_type': 'bool',
                'kwarg_only': True,
                'name': 'requires_grad',
                'type': 'bool',
                'simple_type': 'bool',
            }
            python_binding_arguments.append(requires_grad_arg)
        return python_binding_arguments
 | 
						|
 | 
						|
    def emit_namedtuple_return_type_def(declaration, next_index):
        """Emit the namedtuple (PyStructSequence) type definition for a declaration.

        A declaration gets a namedtuple return type only when it has more than
        one return and at least one of those returns carries a ``field_name``.
        In that case *every* return must be named (see the comment below for
        why unnamed fields are unsupported).

        Keys written into ``declaration``:
          * ``namedtuple_return_type``: always set. It is ``''`` when no
            namedtuple is emitted, otherwise ``'&type<index>, '``.
          * ``namedtuple_type_index``, ``namedtuple_fields``,
            ``namedtuple_size``: set only when a namedtuple is emitted. They
            are consumed by ``PY_RETURN_NAMEDTUPLE_DEF``.

        Args:
            declaration: ATen declaration dict; its ``returns`` list is read.
            next_index: index to give the next emitted namedtuple type.

        Returns:
            ``(code, next_index)``. ``code`` is the generated definition
            (``''`` if none); ``next_index`` is the index to pass for the
            following declaration.

        Raises:
            ValueError: if some, but not all, returns have a ``field_name``.
                ``declaration`` is left untouched in that case.
        """
        returns = declaration['returns']
        if len(returns) <= 1 or all('field_name' not in ret for ret in returns):
            declaration['namedtuple_return_type'] = ''
            return '', next_index

        # Validate before writing any namedtuple_* key, so that a rejected
        # declaration is not left half-populated.
        # See Note [field_name versus name]
        if any('field_name' not in ret for ret in returns):
            # When building on Windows, `PyStructSequence_UnnamedField` could not be
            # resolved by the linker for some reason, which cause error in building:
            #
            # python_nn_functions.cpp.obj : error LNK2001: unresolved external symbol
            # PyStructSequence_UnnamedField
            #
            # Thus, at this point in time, we do not support unnamed
            # fields in namedtuple; you must either name all fields,
            # or none of them.
            raise ValueError("Unnamed field is not supported by codegen")

        declaration['namedtuple_type_index'] = next_index
        # One C initializer entry per field: {"<field_name>", ""},
        declaration['namedtuple_fields'] = ''.join(
            '{"' + ret['field_name'] + '", ""}, ' for ret in returns)
        declaration['namedtuple_size'] = len(returns)
        declaration['namedtuple_return_type'] = '&type{}, '.format(next_index)
        return PY_RETURN_NAMEDTUPLE_DEF.substitute(declaration), next_index + 1
 | 
						|
 | 
						|
    def process_function(name, declarations):
        """Generate the Python binding for one (possibly overloaded) function.

        ``declarations`` are all ATen overloads exposed under ``name``.  The
        generated method body is appended to ``py_methods`` and its method
        table entry to ``py_method_defs`` (both lists of the enclosing scope).
        Each declaration is annotated in place with
        ``python_binding_arguments`` and with the namedtuple_* keys written by
        ``emit_namedtuple_return_type_def``.
        """
        # Attach the Python-only binding arguments first: the max_args count
        # below and the signatures built by group_declarations both read them.
        for decl in declarations:
            decl['python_binding_arguments'] = get_python_binding_arguments(decl)

        max_args = max(len(decl['arguments']) + len(decl['python_binding_arguments'])
                       for decl in declarations)
        env = {
            'name': name,
            'dispatch_name': 'dispatch_{}'.format(name),
            'pycname': 'THPVariable_{}'.format(name),
            'signatures': [],
            'max_args': max_args,
            'unpack_self': [UNPACK_SELF] if has_self else [],
            'dispatch': [],
            'declare_namedtuple_return_types': '',
        }

        # Namedtuple return type definitions; every emitted type consumes the
        # next free index.
        type_index = 0
        typedefs = []
        for decl in declarations:
            typedef, type_index = emit_namedtuple_return_type_def(decl, type_index)
            typedefs.append(typedef)
        env['declare_namedtuple_return_types'] = ''.join(typedefs)

        # One parser signature and one dispatch snippet per overload group.
        for group_idx, group in enumerate(group_declarations(declarations)):
            sig = group['signature']
            if has_self:
                # `self` is unpacked separately (UNPACK_SELF), so it is removed
                # from the parsed signature.
                sig = sig.replace('Tensor self, ', '').replace('Tensor self', '')
            else:
                # Use 'input' instead of 'self' for NN functions
                sig = sig.replace('Tensor self', 'Tensor input')
            sig = sig.replace('SparseTensorRef', 'Tensor')
            if group['base'].get('deprecated', False):
                sig += '|deprecated'
            env['signatures'].append('"{}",'.format(sig))
            env['dispatch'].append(emit_dispatch(group_idx, group, env))
        env['dispatch'].append('}')

        env['traceable'] = 'true' if all(should_trace(d) for d in declarations) else 'false'

        if len(declarations) == 1 and len(declarations[0]['args']) == 1 and has_self:
            # A single overload whose only arg (presumably `self`) is bound:
            # emit a METH_NOARGS method.
            template = PY_VARIABLE_METHOD_NOARGS
            env['actuals'] = ['self']
            env['flags'] = 'METH_NOARGS'
            env['namedtuple_return_type'] = declarations[0]['namedtuple_return_type']
        else:
            template = PY_VARIABLE_METHOD_VARARGS
            env['flags'] = 'METH_VARARGS | METH_KEYWORDS'

        if not (is_module or has_self):
            env['flags'] += ' | METH_STATIC'

        py_methods.append(template.substitute(env))
        py_method_defs.append(PY_VARIABLE_METHOD_DEF.substitute(env))
 | 
						|
 | 
						|
    for name in sorted(python_functions.keys()):
 | 
						|
        process_function(name, python_functions[name])
 | 
						|
 | 
						|
    return {
 | 
						|
        'py_methods': py_methods,
 | 
						|
        'py_method_defs': py_method_defs,
 | 
						|
        'py_method_dispatch': py_method_dispatch,
 | 
						|
    }
 | 
						|
 | 
						|
 | 
						|
def group_declarations(declarations):
    """Pair each ATen declaration with its ``_out`` variant, if it has one.

    Declarations are keyed by their Python signature computed *without* the
    out arguments, so ``foo`` and ``foo_out`` fall into the same group.  Each
    returned dictionary has the keys:

       "base": the regular ATen declaration (e.g. conv2d); always present
       "out": the out variant (e.g. conv2d_out); present only if one exists
       "signature": the signature used for Python argument parsing; the out
                    variant's signature (with its optional out=... argument)
                    when there is one, otherwise the base signature

    Groups are ordered by signature and then re-ordered by
    ``sort_declarations``.

    Raises:
        RuntimeError: if a group has an ``_out`` variant but no base
            declaration.
    """
    groups = defaultdict(dict)

    for decl in declarations:
        key = get_python_signature(decl, False)
        group = groups[key]
        if not decl['name'].endswith('_out'):
            group['base'] = decl
            # Keep the out variant's signature if it was already recorded.
            group.setdefault('signature', key)
            continue
        group['out'] = decl
        # prefer the signature with optional out=... arguments
        group['signature'] = get_python_signature(decl, True)

    ordered = []
    for key in sorted(groups):
        group = groups[key]
        if 'base' not in group:
            raise RuntimeError("'base' not in dictionary", group)
        ordered.append(group)
    return sort_declarations(ordered)
 | 
						|
 | 
						|
 | 
						|
# Overload ordering.
#
# Some calls match more than one overload, so the order in which overloads are
# listed matters.  sort_declarations defines a partial order on overload groups
# and returns a linear extension of it (a topological sort).
#
# See Note[Order of overloads matters]
def sort_declarations(grouped_decls):
    """Return ``grouped_decls`` reordered by a Scalar-vs-Tensor partial order.

    Group ``a`` is *smaller* than group ``b`` when all of these hold:

    * their base declarations take the same number of arguments;
    * at every position the (normalized) dynamic types are either equal, or
      ``a`` has a Scalar where ``b`` has a Tensor;
    * at least one position is Scalar-vs-Tensor.

    Larger groups (the Tensor-taking overloads) are emitted before the
    smaller ones.  Groups with nothing larger than them come first, in
    input order.  If no two groups are related at all, ``grouped_decls``
    itself is returned.
    """

    # TODO: This is a hack!
    #
    # A Scalar argument of a native function shows up in Declarations.yaml as
    # `dynamic_type: Scalar` (with `type: Scalar`).  A 'real' argument from
    # TH Declarations.cwrap is instead translated into `dynamic_type: real`
    # (also with `type: Scalar`).  Both are treated as 'Scalar' here.  Ideally
    # this would be fixed at the source, but the intended meaning of
    # dynamic_type is unclear.
    def canonical_type(arg):
        kind = arg['dynamic_type']
        return 'Scalar' if kind == 'real' else kind

    def scalar_where_tensor(lhs_arg, rhs_arg):
        return canonical_type(lhs_arg) == 'Scalar' and rhs_arg['dynamic_type'] == 'Tensor'

    def less_than(lhs, rhs):
        """True if group lhs < group rhs in the partial order."""
        lhs_args = lhs['base']['arguments']
        rhs_args = rhs['base']['arguments']
        if len(lhs_args) != len(rhs_args):
            return False
        pairs = list(zip(lhs_args, rhs_args))
        strictly_smaller = any(scalar_where_tensor(a, b) for a, b in pairs)
        never_larger = all(canonical_type(a) == canonical_type(b) or scalar_where_tensor(a, b)
                           for a, b in pairs)
        return strictly_smaller and never_larger

    # blockers[i] holds the indices of every group larger than group i; those
    # must all be emitted before group i may be.
    blockers = defaultdict(set)
    for lo, lo_decl in enumerate(grouped_decls):
        for hi, hi_decl in enumerate(grouped_decls):
            if less_than(lo_decl, hi_decl):
                blockers[lo].add(hi)

    if not blockers:
        return grouped_decls

    # Kahn-style topological sort: start from the unblocked groups and release
    # each blocked group once its last blocker has been emitted.
    order = [idx for idx in range(len(grouped_decls)) if idx not in blockers]
    cursor = 0
    while cursor < len(order):
        emitted = order[cursor]
        cursor += 1
        for waiting in sorted(blockers):
            remaining = blockers[waiting]
            remaining.discard(emitted)
            if not remaining:
                del blockers[waiting]
                order.append(waiting)

    return [grouped_decls[idx] for idx in order]
 | 
						|
 | 
						|
 | 
						|
def get_python_signature(declaration, include_out):
    """Build the argument-parser signature string for ``declaration``.

    The format is the one specified in torch/csrc/utils/python_arg_parser.h.
    WARNING: this is NOT the same type signature as specified by PEP 484 as
    understood by mypy.  Our format was independently developed and has some
    quirks that make it more suitable specifically for error parsing.  For a
    translation to mypy-valid type signatures, see tools/gen_pyi.py; if you
    change any logic here, please check that file too.

    Arguments appear in this order:
      1. ordinary arguments, with a ``*`` before the first keyword-only one;
      2. the output argument(s) as an ``...=None`` keyword, only when
         ``include_out`` is true (output arguments are dropped otherwise);
      3. the (at most one) ``Type`` argument, rendered as ``ScalarType``;
      4. ``declaration['python_binding_arguments']``.
    ``TensorOptions`` arguments are always skipped, and a trailing ``_out`` is
    stripped from the function name.
    """
    def render(arg):
        # Format one formal parameter as "<type> <name>[=<default>]".
        simple = arg['simple_type']
        type_name = 'ScalarType' if simple == 'Type' else simple

        # TODO: remove this and make optional types in simple_type to be consistent across
        # tensor and other types after make Tensor? be optional instead of undefined
        if arg.get('is_nullable') and '?' not in type_name:
            type_name = '{}?'.format(type_name)

        size = arg.get('size')
        if size is not None:
            type_name = '{}[{}]'.format(type_name, size)

        rendered = type_name + ' ' + arg['name']
        default = arg.get('default')
        if default is not None:
            # C++-side "empty" defaults all map to Python None.
            if default in ('nullptr', 'nullopt', '{}'):
                default = 'None'
            rendered += '=' + str(default)
        return rendered

    rendered_args = []
    outputs = []
    type_args = []
    keyword_only = False  # becomes True once '*' has been emitted

    for arg in declaration['arguments']:
        if arg.get('output', False):
            outputs.append(arg)
        elif arg['simple_type'] == 'Type':
            type_args.append(arg)
        elif arg['simple_type'] != 'TensorOptions':
            # (`TensorOptions` is skipped: it is only used on the C++ side.)
            if arg.get('kwarg_only', False) and not keyword_only:
                rendered_args.append('*')
                keyword_only = True
            rendered_args.append(render(arg))

    func_name = declaration['name']
    if func_name.endswith('_out'):
        func_name = func_name[:-4]

    if outputs and include_out:
        assert declaration['name'].endswith('_out')
        if not keyword_only:
            rendered_args.append('*')
            keyword_only = True
        out_types = [arg['simple_type'] for arg in outputs]
        if len(outputs) == 1:
            # A single output argument keeps its own declared name as the keyword.
            # The nn module bindings are often not exposed to the user directly
            # but via torch.nn modules and functionals.
            rendered_args.append(out_types[0] + ' ' + outputs[0]['name'] + '=None')
        else:
            # NB: For more than 1 output args the type name is a TensorList
            # and as such we don't (yet) need to consider the naming.
            rendered_args.append('TensorList[{}]'.format(len(out_types)) + ' out=None')

    # Type args and python binding args are deliberately placed after the out
    # argument (rather than in the loop above).  This keeps the signatures of
    # the out and non-out variants matching when a python binding argument
    # such as dtype is present.
    assert len(type_args) <= 1
    for arg in type_args:
        if not keyword_only:  # assume type_args should be kwarg_only.
            rendered_args.append('*')
            keyword_only = True
        rendered_args.append(render(arg))

    for arg in declaration['python_binding_arguments']:
        if arg.get('kwarg_only', False) and not keyword_only:
            rendered_args.append('*')
            keyword_only = True
        rendered_args.append(render(arg))

    # This string is handed to FunctionParameter, which parses it into the
    # structure used for actual argument parsing.
    return PYTHON_FUNCTION_SIGNATURE.substitute(name=func_name, py_formal_args=rendered_args)
 |