mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-21 05:34:18 +08:00
Use Declarations.yaml to generate python bindings
This commit is contained in:
committed by
Soumith Chintala
parent
558d26a69e
commit
f29bcab67e
@ -50,38 +50,41 @@ UNPACK_SELF = "auto& self_ = reinterpret_cast<THPVariable*>(self)->cdata;"
|
||||
|
||||
|
||||
def create_python_bindings(
|
||||
python_functions, py_methods, py_method_defs, py_method_dispatch,
|
||||
is_static):
|
||||
python_functions, py_methods, py_method_defs, py_method_dispatch):
|
||||
"""python_variable_methods.cpp
|
||||
|
||||
Generates Python bindings to Variable methods
|
||||
"""
|
||||
|
||||
unpack_methods = {
|
||||
'const Tensor &': 'tensor',
|
||||
'Generator *': 'generator',
|
||||
'Storage &': 'storage',
|
||||
'int64_t': 'toInt64',
|
||||
'bool': 'toBool'
|
||||
'bool': 'toBool',
|
||||
'double': 'toDouble',
|
||||
}
|
||||
|
||||
def first_tensor_arg(arguments):
|
||||
for arg in arguments:
|
||||
if arg['type'] in {'Tensor', 'TensorList'}:
|
||||
if arg['simple_type'] in {'Tensor', 'TensorList'}:
|
||||
return arg['name']
|
||||
return None
|
||||
|
||||
def auto_gpu(option):
|
||||
tensor_arg = first_tensor_arg(option['python_arguments'])
|
||||
tensor_arg = first_tensor_arg(option['arguments'])
|
||||
if tensor_arg is None:
|
||||
return ''
|
||||
return 'AutoGPU auto_gpu({});'.format(tensor_arg)
|
||||
|
||||
def emit_dispatch(i, function, has_self):
|
||||
def emit_dispatch(i, function):
|
||||
env = {}
|
||||
|
||||
actuals = []
|
||||
formal_args = []
|
||||
arg_idx = 0
|
||||
for arg in function['python_arguments']:
|
||||
if arg['name'] == 'self' and not is_static:
|
||||
for arg in function['arguments']:
|
||||
if 'Tensor' in function['method_of'] and arg['name'] == 'self':
|
||||
formal_args.append('Tensor & {}'.format(arg['name']))
|
||||
actuals.append('self_')
|
||||
continue
|
||||
@ -102,8 +105,11 @@ def create_python_bindings(
|
||||
env['i'] = i
|
||||
env['actuals'] = actuals
|
||||
env['formal_args'] = formal_args
|
||||
env['dispatch_args'] = function['call_args']
|
||||
if not is_static and has_self:
|
||||
if 'call_args' in function:
|
||||
env['dispatch_args'] = function['call_args']
|
||||
else:
|
||||
env['dispatch_args'] = [arg['name'] for arg in function['arguments']]
|
||||
if 'Tensor' in function['method_of']:
|
||||
env['dispatch_args'] = [arg for arg in env['dispatch_args'] if arg != 'self']
|
||||
env['dispatch_call'] = 'self.{}'.format(function['name'])
|
||||
else:
|
||||
@ -121,20 +127,18 @@ def create_python_bindings(
|
||||
'dispatch_name': 'dispatch_{}'.format(name),
|
||||
'pycname': 'THPVariable_{}'.format(name),
|
||||
'prototypes': [],
|
||||
'max_args': max(len(o['python_arguments']) for o in functions),
|
||||
'max_args': max(len(o['arguments']) for o in functions),
|
||||
'unpack_self': [],
|
||||
'dispatch': [],
|
||||
}
|
||||
|
||||
has_self = 'self' in functions[0]['args']
|
||||
if has_self:
|
||||
is_method = 'Tensor' in functions[0]['method_of']
|
||||
if is_method:
|
||||
env['unpack_self'] = [UNPACK_SELF]
|
||||
|
||||
for o in functions:
|
||||
prototype = o['prototype']
|
||||
if o['inplace']:
|
||||
prototype = prototype.replace('(', '_(')
|
||||
if not is_static:
|
||||
if is_method:
|
||||
prototype = prototype.replace('Tensor self, ', '')
|
||||
prototype = prototype.replace('Tensor self', '')
|
||||
if 'deprecated' in o:
|
||||
@ -142,10 +146,10 @@ def create_python_bindings(
|
||||
env['prototypes'].append('"{}",'.format(prototype))
|
||||
|
||||
for i, option in enumerate(functions):
|
||||
env['dispatch'].append(emit_dispatch(i, nested_dict(env, option), has_self))
|
||||
env['dispatch'].append(emit_dispatch(i, nested_dict(env, option)))
|
||||
env['dispatch'].append('}')
|
||||
|
||||
if len(functions) == 1 and len(functions[0]['args']) == 1 and not is_static:
|
||||
if len(functions) == 1 and len(functions[0]['args']) == 1 and is_method:
|
||||
tmpl = PY_VARIABLE_METHOD_NOARGS
|
||||
env['actuals'] = ['self_']
|
||||
env['flags'] = 'METH_NOARGS'
|
||||
@ -153,6 +157,9 @@ def create_python_bindings(
|
||||
tmpl = PY_VARIABLE_METHOD_VARARGS
|
||||
env['flags'] = 'METH_VARARGS | METH_KEYWORDS'
|
||||
|
||||
if not is_method:
|
||||
env['flags'] += ' | METH_STATIC'
|
||||
|
||||
py_methods.append(tmpl.substitute(env))
|
||||
py_method_defs.append(PY_VARIABLE_METHOD_DEF.substitute(env))
|
||||
|
||||
|
Reference in New Issue
Block a user