pytorch/tools/autograd/gen_python_functions.py
Sam Gross 0509f26d41 Speed-up nn.Linear for the 3d input case (#5279)
This adds at::_unsafe_view and uses it in matmul. The _unsafe_view
function is identical to view except that the output is not treated
like a view by the automatic differentiation code. This avoids in-place
modifications triggering the more expensive CopySlices/AsStridedBackward
behavior.

The _unsafe_view function is only safe to use on temporaries that will
be immediately discarded and that do not alias other tensors. Otherwise,
in-place modifications may trigger incorrect gradients. The function is
not exposed to Python.

See #5169
2018-02-19 19:47:20 -05:00
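
A minimal sketch of the hazard being avoided (illustrative, not part of the
commit; uses the modern requires_grad factory-function API):

    import torch

    a = torch.randn(4, requires_grad=True)
    b = a.clone()      # a temporary, like the matmul intermediate
    v = b.view(2, 2)   # a real view: v aliases b
    v.mul_(2)          # in-place write through the view
    print(b.grad_fn)   # <CopySlices ...>: autograd now takes the expensive path

Because an _unsafe_view output is not treated as a view by autograd, an
in-place write to it does not rewrite the base tensor's grad_fn with
CopySlices.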

# Generates Python bindings for ATen functions
#
# The bindings are generated as methods on python_variable or functions on the
# torch._C._nn object.
#
from collections import defaultdict
import re
from .nested_dict import nested_dict
from tools.shared.module_loader import import_module
from .gen_autograd import template_path
from .utils import write
CodeTemplate = import_module('code_template', 'aten/src/ATen/code_template.py').CodeTemplate
# These functions require manual Python bindings or are not exposed to Python
SKIP_PYTHON_BINDINGS = [
    'alias', 'contiguous', 'clamp.*', 'is_cuda', 'is_sparse', 'size', 'stride',
    '.*_backward', '.*_backward_out', '.*_forward', '.*_forward_out',
    'sparse_raw_resize_', '_unsafe_view',
]

PY_VARIABLE_METHODS_CPP = CodeTemplate.from_file(template_path + '/python_variable_methods.cpp')
PY_VARIABLE_DISPATCH_H = CodeTemplate.from_file(template_path + '/python_variable_methods_dispatch.h')
PY_TORCH_FUNCTIONS_CPP = CodeTemplate.from_file(template_path + '/python_torch_functions.cpp')
PY_TORCH_DISPATCH_H = CodeTemplate.from_file(template_path + '/python_torch_functions_dispatch.h')
PY_NN_FUNCTIONS_CPP = CodeTemplate.from_file(template_path + '/python_nn_functions.cpp')
PY_NN_FUNCTIONS_H = CodeTemplate.from_file(template_path + '/python_nn_functions.h')
PY_NN_DISPATCH_H = CodeTemplate.from_file(template_path + '/python_nn_functions_dispatch.h')
PY_VARIABLE_METHOD_VARARGS = CodeTemplate("""\
static PyObject * ${pycname}(PyObject* self, PyObject* args, PyObject* kwargs)
{
  HANDLE_TH_ERRORS
  static PythonArgParser parser({
    ${signatures}
  });
  ${unpack_self}
  PyObject* parsed_args[${max_args}];
  auto r = parser.parse(args, kwargs, parsed_args);
  ${dispatch}
  Py_RETURN_NONE;
  END_HANDLE_TH_ERRORS
}
""")

PY_VARIABLE_METHOD_NOARGS = CodeTemplate("""\
static PyObject * ${pycname}(PyObject* self, PyObject* args)
{
  HANDLE_TH_ERRORS
  ${unpack_self}
  return wrap(${dispatch_name}(${actuals}));
  END_HANDLE_TH_ERRORS
}
""")

PY_VARIABLE_CASE = CodeTemplate("""\
${cond} (r.idx == ${i}) {
  ${call_dispatch}
""")

PY_VARIABLE_OUT = CodeTemplate("""\
if (r.isNone(${out_idx})) {
  ${call_dispatch}
} else {
  ${call_dispatch_out}
}
""")

PY_VARIABLE_CALL_DISPATCH = CodeTemplate("""\
${dispatch_name}(${actuals})""")
PY_VARIABLE_SET_REQUIRES_GRAD = CodeTemplate("""\
set_requires_grad(${call_dispatch}, ${requires_grad})""")
PY_VARIABLE_WRAP = CodeTemplate("""\
return wrap(${call_dispatch});""")
PY_VARIABLE_DISPATCH = CodeTemplate("""\
inline ${return_type} ${dispatch_name}(${formal_args}) {
  ${AutoNoGIL}
  ${AutoGPU}
  return ${dispatch_call}(${dispatch_args});
}
""")

PY_VARIABLE_METHOD_DEF = CodeTemplate("""\
{"${name}", (PyCFunction)${pycname}, ${flags}, NULL},""")
UNPACK_SELF = "auto& self_ = reinterpret_cast<THPVariable*>(self)->cdata;"
PYTHON_FUNCTION_SIGNATURE = CodeTemplate("""\
${name}(${typed_args})""")
# XXX: if you got here because of an assertion failure, it doesn't mean
# it's enough to just extend the list here. Before you do this, make sure
# to add an appropriate wrap() overload in torch/csrc/autograd/utils/wrap_outputs.h.
SUPPORTED_RETURN_TYPES = {
    'Tensor', 'std::tuple<Tensor,Tensor>',
    'std::tuple<Tensor,Tensor,Tensor>',
    'std::tuple<Tensor,Tensor,Tensor,Tensor>',
    'std::tuple<Tensor,Tensor,Tensor,Tensor,Tensor>',
    'std::vector<Tensor>',
    'Scalar', 'bool', 'int64_t', 'void*'
}

def should_generate_python_binding(declaration):
    name = declaration['name']
    for pattern in SKIP_PYTHON_BINDINGS:
        if re.match('^' + pattern + '$', name):
            return False

    # TODO: fix handling of SparseTensor. We don't want to generate Python
    # bindings to SparseTensor overloads, such as add(Tensor, SparseTensor),
    # since the Tensor-based signature already dynamically dispatches correctly.
    # However, _sparse_mask only has a SparseTensor signature so we need to bind
    # that function.
    for arg in declaration['arguments']:
        if arg['type'] == 'SparseTensor' and declaration['name'] != '_sparse_mask':
            return False

    return True

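# For example (illustrative): the skip pattern 'clamp.*' suppresses bindings
# for 'clamp', 'clamp_' and 'clamp_out', because the match is anchored:
#
#   re.match('^clamp.*$', 'clamp_out')  # -> match, so no binding is generated
#   re.match('^clamp.*$', 'unclamped')  # -> None, so a binding is generated
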
def gen_py_variable_methods(out, declarations):
    def should_bind(declaration):
        return (should_generate_python_binding(declaration) and
                declaration['mode'] != 'NN' and
                'Tensor' in declaration['method_of'])

    py_variable_methods = group_declarations_by_name(declarations, should_bind)

    env = create_python_bindings(py_variable_methods, True)
    write(out, 'python_variable_methods.cpp', PY_VARIABLE_METHODS_CPP, env)
    write(out, 'python_variable_methods_dispatch.h', PY_VARIABLE_DISPATCH_H, env)

def gen_py_nn_functions(out, declarations):
    def should_bind(declaration):
        return (should_generate_python_binding(declaration) and
                declaration['mode'] == 'NN')

    py_nn_functions = group_declarations_by_name(declarations, should_bind)

    env = create_python_bindings(py_nn_functions, has_self=False, is_module=True)
    write(out, 'python_nn_functions.cpp', PY_NN_FUNCTIONS_CPP, env)
    write(out, 'python_nn_functions.h', PY_NN_FUNCTIONS_H, env)
    write(out, 'python_nn_functions_dispatch.h', PY_NN_DISPATCH_H, env)

def gen_py_torch_functions(out, declarations):
    def should_bind(declaration):
        return (should_generate_python_binding(declaration) and
                declaration['mode'] != 'NN' and
                ('namespace' in declaration['method_of'] or
                 'Type' in declaration['method_of']))

    py_torch_functions = group_declarations_by_name(declarations, should_bind)

    env = create_python_bindings(py_torch_functions, has_self=False)
    write(out, 'python_torch_functions.cpp', PY_TORCH_FUNCTIONS_CPP, env)
    write(out, 'python_torch_functions_dispatch.h', PY_TORCH_DISPATCH_H, env)

def group_declarations_by_name(declarations, should_bind_fn):
    """Group declarations by name ignoring _out suffix"""
    groups = defaultdict(list)
    for declaration in declarations:
        name = declaration['name']
        if should_bind_fn(declaration):
            if name.endswith('_out'):
                groups[name[:-4]].append(declaration)
            else:
                groups[name].append(declaration)
    return groups

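# For example (illustrative): declarations for 'add' and 'add_out' both land
# under the key 'add', so the base variant and its out= variant are generated
# as a single Python binding:
#
#   groups == {'add': [<add decl>, <add_out decl>], ...}
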
def create_python_bindings(python_functions, has_self, is_module=False):
    """Generates Python bindings to ATen functions"""
    py_methods = []
    py_method_defs = []
    py_method_dispatch = []

    unpack_methods = {
        'const Tensor &': 'tensor',
        'SparseTensor': 'tensor',
        'Tensor &': 'tensor',
        'Generator *': 'generator',
        'Storage &': 'storage',
        'int64_t': 'toInt64',
        'bool': 'toBool',
        'double': 'toDouble',
    }

    unpack_with_default_methods = {
        'IntList': 'setDefaultIntlist',
        'Scalar': 'scalarWithDefault',
        'int64_t': 'toInt64WithDefault',
        'bool': 'setDefaultBool',
        'double': 'setDefaultDouble',
    }

    def first_tensor_arg(arguments):
        for arg in arguments:
            if arg['simple_type'] in {'Tensor', 'TensorList'}:
                return arg['name']
        return None

    def auto_gpu(option):
        tensor_arg = first_tensor_arg(option['arguments'])
        if tensor_arg is None:
            return ''
        return 'AutoGPU auto_gpu({});'.format(tensor_arg)

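    # For example (illustrative): a declaration whose first tensor argument is
    # named 'self' yields the C++ statement 'AutoGPU auto_gpu(self);', which
    # switches to self's device for the duration of the dispatch.
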
    def emit_single_dispatch(declaration, out_idx, base_env):
        env = {}
        simple_return_type = declaration['return_type'].replace(' &', '')
        assert simple_return_type in SUPPORTED_RETURN_TYPES, \
            declaration['name'] + ' returns unsupported type: ' + simple_return_type

        body = []
        actuals = []
        formal_args = []
        arg_idx = 0

        def is_output(arg):
            return arg.get('output', False)

        inputs = [arg for arg in declaration['arguments'] if not is_output(arg)]
        outputs = [arg for arg in declaration['arguments'] if is_output(arg)]

        def parse_arg(arg, arg_index, unpack_args=False):
            name = arg['name']
            typename = arg['type']
            if typename.startswith('IntList['):
                typename = 'IntList'
            if typename.startswith('LongTensor'):
                typename = 'Tensor'

            if arg.get('python_default_init'):
                assert typename in unpack_with_default_methods, \
                    '`{}` type is not supported in python_default_init'.format(typename)
                unpack_with_default = unpack_with_default_methods.get(typename)
                default_expr = arg.get('python_default_init')
                expr = 'r.{}({}, {})'.format(unpack_with_default, arg_index, default_expr)
            else:
                unpack = unpack_methods.get(typename, typename.lower())
                expr = 'r.{}({})'.format(unpack, arg_index)

            if unpack_args:
                body.append('auto {} = {};'.format(name, expr))
                expr = name

            if typename == 'Storage &':
                expr = '*' + expr
            if typename == 'SparseTensor':
                expr = 'SparseTensor({})'.format(expr)

            dispatch_type = typename
            if dispatch_type == 'Tensor':
                dispatch_type = 'const Tensor &'
            elif dispatch_type == 'Tensor &':
                dispatch_type = 'Tensor'
            formal = '{} {}'.format(dispatch_type, name)
            return expr, formal

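        # For example (illustrative): a 'const Tensor &' argument named
        # 'other' at parser index 1 parses to:
        #
        #   expr   == 'r.tensor(1)'
        #   formal == 'const Tensor & other'
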
        def append_actuals_formals(actual, formal):
            actuals.append(actual)
            formal_args.append(formal)

        unpack = any(arg.get('python_default_init') for arg in inputs)
        for arg in inputs:
            if has_self and arg['name'] == 'self':
                formal_args.append('Tensor & self')
                actuals.append('self_')
                continue
            append_actuals_formals(*parse_arg(arg, arg_idx, unpack))
            arg_idx += 1

        if len(outputs) == 1:
            append_actuals_formals(*parse_arg(outputs[0], arg_idx))
        elif len(outputs) > 1:
            N = len(outputs)
            body.append('auto results = r.tensorlist_n<{}>({});'.format(N, arg_idx))
            for i, arg in enumerate(outputs):
                formal_args.append('Tensor & {}'.format(arg['name']))
                actuals.append('results[{}]'.format(i))

        env['unpack_args'] = []
        env['formal_args'] = formal_args
        env['actuals'] = actuals
        if 'call_args' in declaration:
            env['dispatch_args'] = declaration['call_args']
        else:
            env['dispatch_args'] = [arg['name'] for arg in declaration['arguments']]
        if 'Tensor' in declaration['method_of']:
            env['dispatch_args'] = [arg for arg in env['dispatch_args'] if arg != 'self']
            env['dispatch_call'] = 'self.{}'.format(declaration['name'])
        elif 'namespace' in declaration['method_of']:
            env['dispatch_call'] = 'at::{}'.format(declaration['name'])
        else:
            env['dispatch_call'] = 'default_type().{}'.format(declaration['name'])
        env['AutoNoGIL'] = 'AutoNoGIL no_gil;'
        env['AutoGPU'] = auto_gpu(declaration)

        requires_grad = None

        if len(declaration.get('python_binding_arguments', [])) > 1:
            raise RuntimeError("found more than 1 entry in python_binding_arguments")
        for arg in declaration.get('python_binding_arguments', []):
            if not arg['name'] == 'requires_grad' or not arg['type'] == 'bool':
                raise RuntimeError(("found {} in python_binding_arguments but only "
                                    "bool requires_grad is supported").format(arg))
            # we have to use out_idx if there is an out variant because the base variant
            # won't have the full arg_idx count
            requires_grad_idx = arg_idx if out_idx is None else out_idx + 1
            requires_grad = parse_arg(arg, requires_grad_idx)[0]

        env = nested_dict(env, nested_dict(base_env, declaration))
        call_dispatch = PY_VARIABLE_CALL_DISPATCH.substitute(env)
        if requires_grad:
            call_dispatch = PY_VARIABLE_SET_REQUIRES_GRAD.substitute(
                env, call_dispatch=call_dispatch, requires_grad=requires_grad)
        body.append(PY_VARIABLE_WRAP.substitute(env, call_dispatch=call_dispatch))
        py_method_dispatch.append(PY_VARIABLE_DISPATCH.substitute(env))
        return body

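    # For example (illustrative): for a method declaration like
    # add(Tensor self, Tensor other), PY_VARIABLE_DISPATCH expands to roughly:
    #
    #   inline Tensor dispatch_add(Tensor & self, const Tensor & other) {
    #     AutoNoGIL no_gil;
    #     AutoGPU auto_gpu(self);
    #     return self.add(other);
    #   }
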
    def emit_dispatch(i, dictionary, base_env):
        if 'out' in dictionary:
            out_idx = len([arg for arg in dictionary['out']['arguments']
                           if not arg.get('output', False)])
            env = {}
            env['call_dispatch_out'] = emit_single_dispatch(dictionary['out'], out_idx, base_env)
            env['call_dispatch'] = emit_single_dispatch(dictionary['base'], out_idx, base_env)
            body = PY_VARIABLE_OUT.substitute(env, out_idx=out_idx).split('\n')
        else:
            body = emit_single_dispatch(dictionary['base'], None, base_env)
        cond = 'if' if i == 0 else '} else if'
        return PY_VARIABLE_CASE.substitute(i=i, cond=cond, call_dispatch=body)

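    # For example (illustrative): three overloads produce a chain of the form
    #
    #   if (r.idx == 0) { ... } else if (r.idx == 1) { ... } else if (r.idx == 2) { ...
    #
    # where r.idx is the index of the signature the argument parser matched;
    # the caller appends the final closing '}'.
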
    def get_requires_grad_argument(declaration):
        name = declaration['name']
        requires_grad_arg = []
        has_tensor_input_arg = False

        for arg in declaration['arguments']:
            if arg.get('output', False):
                continue
            typename = arg['simple_type']
            if typename in ['Tensor', 'TensorList']:
                has_tensor_input_arg = True
            if arg['name'] == 'requires_grad':
                raise ValueError("argument named requires_grad not supported")

        has_tensor_return = False
        for ret in declaration['returns']:
            if ret['dynamic_type'] in ['Tensor', 'TensorList']:
                # this probably won't work if one of the returns is not a tensor, but it will
                # produce a compile-time error that is obvious
                has_tensor_return = True

        if (not has_tensor_input_arg or name.endswith('_like')) and has_tensor_return:
            arg = {
                'default': False,
                'default_init': False,
                'dynamic_type': 'bool',
                'kwarg_only': True,
                'name': 'requires_grad',
                'type': 'bool',
                'simple_type': 'bool',
            }
            requires_grad_arg.append(arg)

        return requires_grad_arg

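    # For example (illustrative): a factory function such as ones(IntList size)
    # has a tensor return but no tensor inputs, so it gains a kwarg-only
    # requires_grad argument; ones_like(Tensor self) gains it too via the
    # _like suffix check, even though it does take a tensor input.
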
    def process_function(name, declarations):
        for declaration in declarations:
            declaration['python_binding_arguments'] = get_requires_grad_argument(declaration)

        env = {
            'name': name,
            'dispatch_name': 'dispatch_{}'.format(name),
            'pycname': 'THPVariable_{}'.format(name),
            'signatures': [],
            'max_args': max(len(o['arguments']) + len(o['python_binding_arguments']) for o in declarations),
            'unpack_self': [],
            'dispatch': [],
        }

        if has_self:
            env['unpack_self'] = [UNPACK_SELF]

        grouped = group_declarations(declarations)
        for i, dictionary in enumerate(grouped):
            signature = dictionary['signature']
            if has_self:
                signature = signature.replace('Tensor self, ', '')
                signature = signature.replace('Tensor self', '')
            if not has_self:
                # Use 'input' instead of 'self' for NN functions
                signature = signature.replace('Tensor self', 'Tensor input')
            signature = signature.replace('SparseTensor', 'Tensor')
            if dictionary['base'].get('deprecated', False):
                signature += '|deprecated'
            env['signatures'].append('"{}",'.format(signature))
            env['dispatch'].append(emit_dispatch(i, dictionary, env))
        env['dispatch'].append('}')

        if len(declarations) == 1 and len(declarations[0]['args']) == 1 and has_self:
            tmpl = PY_VARIABLE_METHOD_NOARGS
            env['actuals'] = ['self_']
            env['flags'] = 'METH_NOARGS'
        else:
            tmpl = PY_VARIABLE_METHOD_VARARGS
            env['flags'] = 'METH_VARARGS | METH_KEYWORDS'

        if not is_module and not has_self:
            env['flags'] += ' | METH_STATIC'

        py_methods.append(tmpl.substitute(env))
        py_method_defs.append(PY_VARIABLE_METHOD_DEF.substitute(env))

    for name in sorted(python_functions.keys()):
        process_function(name, python_functions[name])

    return {
        'py_methods': py_methods,
        'py_method_defs': py_method_defs,
        'py_method_dispatch': py_method_dispatch,
    }

def group_declarations(declarations):
    """Returns a list of dictionaries containing the optional keys:

    "base": the regular ATen declaration (e.g. conv2d)
    "out": the out variant (e.g. conv2d_out)
    "signature": the signature used for Python argument parsing
    """
    grouped = defaultdict(dict)

    # first group by signature ignoring out arguments
    for declaration in declarations:
        signature = get_python_signature(declaration, False)
        v = grouped[signature]
        if declaration['name'].endswith('_out'):
            v['out'] = declaration
            # prefer the signature with optional out=... arguments
            v['signature'] = get_python_signature(declaration, True)
        else:
            v['base'] = declaration
            if 'signature' not in v:
                v['signature'] = signature

    result = []
    for _, dictionary in sorted(grouped.items()):
        assert 'base' in dictionary
        result.append(dictionary)
    return result

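# For example (illustrative): 'add' and 'add_out' share a base signature, so
# they group into a single entry {'base': ..., 'out': ..., 'signature': ...},
# where 'signature' is the out-variant signature including 'out=None'.
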
def get_python_signature(declaration, include_out):
    # Compute the Python function signature for argument parsing
    typed_args = []
    output_args = []
    positional = True

    def get_typed_arg(arg):
        typename = arg['simple_type']
        if arg.get('is_nullable'):
            typename = '{}?'.format(typename)
        if arg.get('size') is not None:
            typename = '{}[{}]'.format(typename, arg['size'])
        param = typename + ' ' + arg['name']
        default = None
        if arg.get('default') is not None:
            default = arg['default']
            if default == 'nullptr' or default == '{}':
                default = 'None'
        if arg.get('python_default_init') is not None:
            default = 'None'
        if default is not None:
            param += '=' + str(default)
        return param

    for arg in declaration['arguments']:
        if arg.get('output', False):
            output_args.append(arg)
            continue
        if arg.get('kwarg_only', False) and positional:
            typed_args.append('*')
            positional = False
        param = get_typed_arg(arg)
        typed_args.append(param)

    # add output arguments
    name = declaration['name']
    if name.endswith('_out'):
        name = name[:-4]
    if len(output_args) > 0 and include_out:
        assert declaration['name'].endswith('_out')
        if positional:
            typed_args.append('*')
            positional = False
        typenames = [arg['simple_type'] for arg in output_args]
        if len(typenames) > 1:
            typename = 'TensorList[{}]'.format(len(typenames))
        else:
            typename = typenames[0]
        typed_args.append(typename + ' out=None')

    # we could put this in the loop above, but we want to ensure it comes
    # after the out argument
    if len(declaration['python_binding_arguments']) > 0:
        for arg in declaration['python_binding_arguments']:
            if arg.get('kwarg_only', False) and positional:
                typed_args.append('*')
                positional = False
            typed_args.append(get_typed_arg(arg))

    # Produce the Python function signature string. This string is parsed
    # (via FunctionParameter) into the structure used for actual argument
    # parsing.
    return PYTHON_FUNCTION_SIGNATURE.substitute(name=name, typed_args=typed_args)
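
# For example (illustrative): an add_out declaration with include_out=True
# might yield the signature string
#
#   'add(Tensor self, Tensor other, *, Tensor out=None)'
#
# while the same declaration with include_out=False yields
#
#   'add(Tensor self, Tensor other)'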