mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
Summary: We have: - This is an initial stab at creating a type stub `torch/__init__.pyi` . - This is only tested on Python 3, since that's the only Python version mypy works on. - So far, we only aim at doing this for torch functions and torch.Tensor. - Quite a few methods and functions have to be typed manually. These are done in `torch/__init__.pyi.in` For me, PyCharm (the non-paid one) didn't seem to indicate errors in the .pyi when opening and seemed to be able to get the type hint for the few functions I tried, but I don't use PyCharm for my usual PyTorch activities, so I didn't extensively try this out. An example of a generated PYI is at [this gist](https://gist.github.com/ezyang/bf9b6a5fa8827c52152858169bcb61b1). Pull Request resolved: https://github.com/pytorch/pytorch/pull/12500 Differential Revision: D13695553 Pulled By: ezyang fbshipit-source-id: 4566c71913ede4e4c23ebc4a72c17151f94e8e21
919 lines
37 KiB
Python
919 lines
37 KiB
Python
# Generates Python bindings for ATen functions
|
|
#
|
|
# The bindings are generated as methods on python_variable or functions on the
|
|
# torch._C._nn object.
|
|
#
|
|
from collections import defaultdict
|
|
import re
|
|
from .nested_dict import nested_dict
|
|
from .gen_variable_type import should_trace
|
|
from .utils import write
|
|
|
|
try:
|
|
from src.ATen.code_template import CodeTemplate
|
|
except ImportError:
|
|
from tools.shared.module_loader import import_module
|
|
CodeTemplate = import_module('code_template', 'aten/src/ATen/code_template.py').CodeTemplate
|
|
|
|
# These functions require manual Python bindings or are not exposed to Python.
#
# Each entry is a regular expression. should_generate_python_binding() anchors
# it as '^' + pattern + '$' and matches it against the declaration name, so an
# entry must describe the *whole* name (use '.*' for prefixes/suffixes).
SKIP_PYTHON_BINDINGS = [
    'alias', 'contiguous', 'is_cuda', 'is_sparse', 'size', 'stride',
    '.*_backward', '.*_backward_(out|input|weight|bias)', '.*_forward',
    '.*_forward_out', '_unsafe_view', 'tensor', '_?sparse_coo_tensor.*',
    '_arange.*', '_range.*', '_linspace.*', '_logspace.*',
    '_sparse_add_out', '_sparse_div.*', '_sparse_mul.*', '_sparse_sub.*',
    'index',
    '_indexCopy_', 'max_values', 'min_values', 'argmax', 'argmin',
    '_cumsum.*', '_cumprod.*', '_sum.*', '_prod.*',
    '_th_.*', '_thnn_.*',
    'arange.*', 'range.*', '_gesv.*', '_getri.*', '_inverse.*',
    '_potrs.*', '_cholesky.*',
    'slice', 'randint(_out)?',
    'item', '_local_scalar_dense',
    'max_pool1d', 'max_pool2d', 'max_pool3d', 'linear', 'to',
    'copy_sparse_to_sparse_',
]
|
|
|
|
# These function signatures are not exposed to Python. Note that this signature
# list does not support regex.
#
# Entries are compared for exact string equality against
# "<name>(<simple_type>, <simple_type>, ...)" as built by
# should_generate_python_binding(), so the ', ' separator must match exactly.
SKIP_PYTHON_BINDINGS_SIGNATURES = [
    'add(Tensor, Scalar, Scalar)', 'add_(Tensor, Scalar, Scalar)',
    'sub(Tensor, Scalar, Scalar)', 'sub_(Tensor, Scalar, Scalar)',
    'mul(Tensor, Scalar)', 'mul_(Tensor, Scalar)',
    'div(Tensor, Scalar)', 'div_(Tensor, Scalar)',
]
|
|
|
|
# C++ source templates used by create_python_bindings(). `${...}` placeholders
# are filled in by CodeTemplate.substitute().
#
# NOTE(review): the indentation inside the template bodies below follows the
# conventional 2-space C++ layout; confirm against the upstream file.

# Python-callable wrapper for a function taking positional and keyword
# arguments: parses them against `${signatures}` and runs the `${dispatch}`
# if/else-if chain keyed on the matched signature index.
PY_VARIABLE_METHOD_VARARGS = CodeTemplate("""\
static PyObject * ${pycname}(PyObject* self_, PyObject* args, PyObject* kwargs)
{
  HANDLE_TH_ERRORS
  static PythonArgParser parser({
    ${signatures}
  }, /*traceable=*/${traceable});
  ${unpack_self}
  ParsedArgs<${max_args}> parsed_args;
  auto r = parser.parse(args, kwargs, parsed_args);
  ${declare_namedtuple_return_types}
  ${dispatch}
  Py_RETURN_NONE;
  END_HANDLE_TH_ERRORS
}
""")

# Python-callable wrapper for a method whose only argument is `self`
# (registered with METH_NOARGS by process_function); no argument parsing.
PY_VARIABLE_METHOD_NOARGS = CodeTemplate("""\
static PyObject * ${pycname}(PyObject* self_, PyObject* args)
{
  HANDLE_TH_ERRORS
  ${declare_namedtuple_return_types}
  ${unpack_self}
  return wrap(${namedtuple_return_type}${dispatch_name}(${actuals}));
  END_HANDLE_TH_ERRORS
}
""")

# One branch of the overload dispatch chain. The opening brace is deliberately
# left unclosed: it is closed by the next branch's '} else if' (see `cond` in
# emit_dispatch) or by the final '}' appended in process_function.
PY_VARIABLE_CASE = CodeTemplate("""\
${cond} (r.idx == ${i}) {
  ${call_dispatch}
""")

# Chooses between the base overload and the out= overload depending on whether
# the out argument was passed as None.
PY_VARIABLE_OUT = CodeTemplate("""\
if (r.isNone(${out_idx})) {
  ${call_dispatch}
} else {
  ${call_dispatch_out}
}
""")

# Same as PY_VARIABLE_OUT, but additionally checks the supplied out tensor
# against the dtype/layout/device arguments before dispatching.
PY_VARIABLE_OUT_CHECK_TYPE = CodeTemplate("""\
if (r.isNone(${out_idx})) {
  ${call_dispatch}
} else {
  check_out_type_matches(r.tensor(${out_idx}), r.scalartype(${type_idx}), r.isNone(${type_idx}),
                         r.layout(${layout_idx}), r.isNone(${layout_idx}),
                         r.device(${device_idx}), r.isNone(${device_idx}));
  ${call_dispatch_out}
}
""")

# Call expression for the generated dispatch_* helper.
PY_VARIABLE_CALL_DISPATCH = CodeTemplate("""\
${dispatch_name}(${actuals})""")

# Wraps a dispatch call so the result gets `requires_grad` applied.
PY_VARIABLE_SET_REQUIRES_GRAD = CodeTemplate("""\
${call_dispatch}.set_requires_grad(${requires_grad})""")

# Converts the C++ result into a Python object via wrap().
PY_VARIABLE_WRAP = CodeTemplate("""\
return wrap(${namedtuple_return_type}${call_dispatch});""")

# Definition of the inline dispatch_* helper that performs the actual call.
PY_VARIABLE_DISPATCH = CodeTemplate("""\
inline ${simple_return_type} ${dispatch_name}(${formal_args}) {
  ${initialize_cuda}
  ${AutoNoGIL}
  return ${dispatch_call}(${dispatch_args});
}
""")

# One row of the PyMethodDef table.
PY_VARIABLE_METHOD_DEF = CodeTemplate("""\
{"${name}", (PyCFunction)${pycname}, ${flags}, NULL},""")

# Declares and lazily initializes (guarded by a static flag) a struct-sequence
# type used for a named-tuple return value.
PY_RETURN_NAMEDTUPLE_DEF = CodeTemplate("""\
static PyStructSequence_Field fields${namedtuple_type_index}[] = {
  ${namedtuple_fields} {nullptr}
};
static PyStructSequence_Desc desc${namedtuple_type_index} = {
  "torch.return_types.${name}", nullptr,
  fields${namedtuple_type_index}, ${namedtuple_size}
};
static PyTypeObject type${namedtuple_type_index};
static bool namedtuple_type_initialized${namedtuple_type_index} = false;
if (!namedtuple_type_initialized${namedtuple_type_index}) {
  PyStructSequence_InitType(&type${namedtuple_type_index}, &desc${namedtuple_type_index});
  namedtuple_type_initialized${namedtuple_type_index} = true;
}
""")

# C++ statement that extracts the underlying variable from `self_`; emitted
# into the wrappers only when bindings are generated with has_self=True.
UNPACK_SELF = "auto& self = reinterpret_cast<THPVariable*>(self_)->cdata;"

# Shape of a signature string handed to PythonArgParser
# (see get_python_signature).
PYTHON_FUNCTION_SIGNATURE = CodeTemplate("""\
${name}(${py_formal_args})""")
|
|
|
|
# XXX: if you got here because of an assertion failure, it doesn't mean
# it's enough to just extend the list here. Before you do this, make sure
# to add an appropriate wrap() overload in torch/csrc/autograd/utils/wrap_outputs.h.
#
# emit_single_dispatch() asserts that a declaration's return type (with any
# ' &' stripped) is a member of this set.
SUPPORTED_RETURN_TYPES = {
    'Tensor', 'std::tuple<Tensor,Tensor>',
    'std::tuple<Tensor,Tensor,double,int64_t>',
    'std::tuple<Tensor,Tensor,Tensor>',
    'std::tuple<Tensor,Tensor,Tensor,Tensor>',
    'std::tuple<Tensor,Tensor,Tensor,Tensor,Tensor>',
    'std::vector<Tensor>',
    'Scalar', 'bool', 'int64_t', 'void*', 'void'
}
|
|
|
|
TENSOR_OPTIONS = CodeTemplate("""\
|
|
const auto options = TensorOptions()
|
|
.dtype(${dtype})
|
|
.device(${device})
|
|
.layout(${layout}.layout)
|
|
.requires_grad(${requires_grad});
|
|
""")
|
|
|
|
|
|
def should_generate_python_binding(declaration):
    """Tell whether `declaration` should get an auto-generated Python binding.

    The answer is False when the declaration name matches (in full) one of the
    SKIP_PYTHON_BINDINGS regexes, when its "name(simple_type, ...)" signature is
    listed verbatim in SKIP_PYTHON_BINDINGS_SIGNATURES, or when it takes a
    SparseTensorRef argument and is not `sparse_mask`. Otherwise it is True.
    """
    name = declaration['name']
    if any(re.match('^' + skip_regex + '$', name) for skip_regex in SKIP_PYTHON_BINDINGS):
        return False

    arg_types = ', '.join(arg['simple_type'] for arg in declaration['arguments'])
    if '{}({})'.format(name, arg_types) in SKIP_PYTHON_BINDINGS_SIGNATURES:
        return False

    # TODO: fix handling of SparseTensor. Overloads taking SparseTensorRef (for
    # example add(Tensor, SparseTensorRef)) are not bound because the
    # Tensor-based signature already dispatches dynamically to the right
    # implementation. sparse_mask is the exception: it only exists with a
    # SparseTensor signature, so it has to be bound.
    takes_sparse_ref = any(arg['type'] == 'SparseTensorRef' for arg in declaration['arguments'])
    return not (takes_sparse_ref and name != 'sparse_mask')
|
|
|
|
|
|
def get_py_variable_methods(declarations):
    """
    Select the declarations that become methods on Tensor and return them
    grouped by name (see group_declarations_by_name).
    """
    def is_tensor_method(decl):
        if not should_generate_python_binding(decl):
            return False
        # Anything belonging to the nn module is handled by get_py_nn_functions.
        if decl['mode'] == 'NN' or decl.get('python_module') == 'nn':
            return False
        return 'Tensor' in decl['method_of']

    return group_declarations_by_name(declarations, is_tensor_method)
|
|
|
|
|
|
def gen_py_variable_methods(out, declarations, template_path):
    """Generate the Tensor-method binding files.

    Loads the two templates from `template_path`, builds the binding
    environment for the Tensor methods found in `declarations`, and passes each
    (file name, template, environment) to `write` together with `out`.
    """
    methods_cpp_template = CodeTemplate.from_file(template_path + '/python_variable_methods.cpp')
    dispatch_h_template = CodeTemplate.from_file(template_path + '/python_variable_methods_dispatch.h')

    bindings_env = create_python_bindings(get_py_variable_methods(declarations), True)

    outputs = (
        ('python_variable_methods.cpp', methods_cpp_template),
        ('python_variable_methods_dispatch.h', dispatch_h_template),
    )
    for filename, template in outputs:
        write(out, filename, template, bindings_env)
|
|
|
|
|
|
def get_py_nn_functions(declarations):
    """
    Select the declarations that become functions in the "nn" module and
    return them grouped by name (see group_declarations_by_name).
    """
    def is_nn_function(decl):
        if not should_generate_python_binding(decl):
            return False
        return decl['mode'] == 'NN' or decl.get('python_module') == 'nn'

    return group_declarations_by_name(declarations, is_nn_function)
|
|
|
|
|
|
def gen_py_nn_functions(out, declarations, template_path):
    """Generate the "nn" module binding files.

    Loads the three templates from `template_path`, builds the binding
    environment for the nn functions found in `declarations` (no `self`,
    module-level functions), and passes each (file name, template, environment)
    to `write` together with `out`.
    """
    functions_cpp_template = CodeTemplate.from_file(template_path + '/python_nn_functions.cpp')
    functions_h_template = CodeTemplate.from_file(template_path + '/python_nn_functions.h')
    dispatch_h_template = CodeTemplate.from_file(template_path + '/python_nn_functions_dispatch.h')

    bindings_env = create_python_bindings(
        get_py_nn_functions(declarations), has_self=False, is_module=True)

    outputs = (
        ('python_nn_functions.cpp', functions_cpp_template),
        ('python_nn_functions.h', functions_h_template),
        ('python_nn_functions_dispatch.h', dispatch_h_template),
    )
    for filename, template in outputs:
        write(out, filename, template, bindings_env)
|
|
|
|
|
|
def get_py_torch_functions(declarations):
    """
    Select the declarations that become functions in the "torch" module and
    return them grouped by name (see group_declarations_by_name).
    """
    def is_torch_function(decl):
        if not should_generate_python_binding(decl):
            return False
        # Anything belonging to the nn module is handled by get_py_nn_functions.
        if decl['mode'] == 'NN' or decl.get('python_module') == 'nn':
            return False
        return 'namespace' in decl['method_of']

    return group_declarations_by_name(declarations, is_torch_function)
|
|
|
|
|
|
def gen_py_torch_functions(out, declarations, template_path):
    """Generate the "torch" module binding files.

    Loads the two templates from `template_path`, builds the binding
    environment for the torch functions found in `declarations` (no `self`),
    and passes each (file name, template, environment) to `write` together
    with `out`.
    """
    functions_cpp_template = CodeTemplate.from_file(template_path + '/python_torch_functions.cpp')
    dispatch_h_template = CodeTemplate.from_file(template_path + '/python_torch_functions_dispatch.h')

    bindings_env = create_python_bindings(get_py_torch_functions(declarations), has_self=False)

    outputs = (
        ('python_torch_functions.cpp', functions_cpp_template),
        ('python_torch_functions_dispatch.h', dispatch_h_template),
    )
    for filename, template in outputs:
        write(out, filename, template, bindings_env)
|
|
|
|
|
|
def group_declarations_by_name(declarations, should_bind_fn):
    """Bucket the bindable declarations by name, folding `foo_out` into `foo`.

    Declarations rejected by `should_bind_fn` are dropped. Returns a
    defaultdict mapping each (suffix-stripped) name to the list of its
    declarations, in input order.
    """
    by_name = defaultdict(list)
    for decl in declarations:
        full_name = decl['name']
        if not should_bind_fn(decl):
            continue
        # The out= variant shares a Python entry point with its base function.
        key = full_name[:-4] if full_name.endswith('_out') else full_name
        by_name[key].append(decl)
    return by_name
|
|
|
|
|
|
def get_type_default(declaration):
    """Return, as a string, the default used for a declaration's `dtype` binding argument.

    'torch.int64' for names starting with `randperm` and for exactly
    `tril_indices` / `triu_indices`; the string 'None' for everything else.
    """
    name = declaration['name']
    defaults_to_int64 = name.startswith('randperm') or name in ('tril_indices', 'triu_indices')
    return 'torch.int64' if defaults_to_int64 else 'None'
|
|
|
|
|
|
def create_python_bindings(python_functions, has_self, is_module=False):
    """Generates Python bindings to ATen functions

    Args:
        python_functions: mapping from Python-visible name to the list of
            declarations (overloads, including `_out` variants) bound under
            that name. Names are processed in sorted order.
        has_self: True when generating methods; the C++ wrappers then unpack
            `self_` and the `Tensor self` parameter is removed from the
            parser signatures.
        is_module: when False and `has_self` is False, the method-table
            entries additionally get the METH_STATIC flag.

    Returns:
        A dict with the keys 'py_methods' (wrapper function definitions),
        'py_method_defs' (method-table rows) and 'py_method_dispatch'
        (inline dispatch helper definitions); each value is a list of strings.

    Note: this function mutates the declarations it is given (it adds
    'python_binding_arguments' and the 'namedtuple_*' keys).
    """
    py_methods = []
    py_method_defs = []
    py_method_dispatch = []

    # C++ argument type -> name of the accessor called on the parsed-arguments
    # object `r` (emitted as `r.<accessor>(<index>)` by parse_arg). Types that
    # are missing here fall back to the lower-cased type name.
    unpack_methods = {
        'const Tensor &': 'tensor',
        'SparseTensorRef': 'tensor',
        'Tensor &': 'tensor',
        'Generator *': 'generator',
        'Storage &': 'storage',
        'const Type &': 'scalartype',
        'const THPLayout &': 'layout',
        'const Device &': 'device',
        'c10::optional<ScalarType>': 'scalartypeOptional',
        'c10::optional<Scalar>': 'scalarOptional',
        'c10::optional<int64_t>': 'toInt64Optional',
        'int64_t': 'toInt64',
        'bool': 'toBool',
        'double': 'toDouble',
        'std::string': 'string',
    }

    # Same idea, for arguments that carry a 'python_default_init' expression:
    # emitted as `r.<accessor>(<index>, <default expression>)`.
    unpack_with_default_methods = {
        'IntList': 'setDefaultIntlist',
        'Scalar': 'scalarWithDefault',
        'int64_t': 'toInt64WithDefault',
        'bool': 'setDefaultBool',
        'double': 'setDefaultDouble',
        'const Type &': 'scalartypeWithDefault',
        'const THPLayout &': 'layoutWithDefault',
        'const Device &': 'deviceWithDefault',
        'ScalarType': 'scalartypeWithDefault',
    }

    def emit_single_dispatch(declaration, out_idx, base_env):
        """Emit the C++ for calling one declaration.

        Appends the definition of the inline dispatch helper to
        `py_method_dispatch` and returns the list of C++ statements that
        unpack the parsed arguments, call that helper and wrap the result.

        `out_idx` is the parser index of the `out` argument when the
        declaration belongs to a base/out pair, or None when there is no out
        variant. `base_env` is the per-function environment created in
        process_function.
        """
        env = {}
        simple_return_type = declaration['return_type'].replace(' &', '')
        assert simple_return_type in SUPPORTED_RETURN_TYPES, \
            declaration['name'] + ' returns unsupported type: ' + simple_return_type

        body = []
        actuals = []
        formal_args = []
        arg_idx = 0

        def is_output(arg):
            return arg.get('output', False)

        inputs = [arg for arg in declaration['arguments'] if not is_output(arg)]
        outputs = [arg for arg in declaration['arguments'] if is_output(arg)]

        has_tensor_options = any(arg['simple_type'] == 'TensorOptions' for arg in declaration['arguments'])

        def get_type_args(args):
            return [arg for arg in args if arg['simple_type'] == 'Type']
        type_actual_args = get_type_args(declaration['arguments'])
        type_binding_args = get_type_args(declaration['python_binding_arguments'])
        assert len(type_actual_args + type_binding_args) <= 1
        if type_binding_args and len(outputs) == 0:
            # out(s) determines the dtype if it is present, so only use this if there are no outputs.
            type_args = type_binding_args
        else:
            type_args = type_actual_args

        if type_args and len(outputs) > 1:
            raise RuntimeError("Not supported: type dispatched parameter with multiple outputs")

        def parse_arg(arg, arg_index, unpack_args=False):
            """Return (expr, formal) for one argument.

            `expr` is the C++ expression that reads the argument at parser
            index `arg_index`; `formal` is the parameter declaration for the
            dispatch helper. With `unpack_args`, the value is first stored in
            a local variable (a statement is appended to `body`) and `expr`
            is that variable's name.
            """
            name = arg['name']
            typename = arg['type']
            if typename.startswith('IntList['):
                typename = 'IntList'
            if typename.startswith('LongTensor'):
                typename = 'Tensor'

            if arg.get('python_default_init'):
                assert typename in unpack_with_default_methods, \
                    '`{}` type is not supported in python_default_init'.format(typename)
                unpack_with_default = unpack_with_default_methods.get(typename)
                default_expr = arg.get('python_default_init')
                # TODO: Type currently maps to ScalarType, figure out a cleaner solution
                if typename == 'const Type &':
                    default_expr += '.scalarType()'
                expr = 'r.{}({}, {})'.format(unpack_with_default, arg_index, default_expr)
            else:
                unpack = unpack_methods.get(typename, typename.lower())
                expr = 'r.{}({})'.format(unpack, arg_index)

            if unpack_args:
                body.append('auto {} = {};'.format(name, expr))
                expr = name

            if typename == 'SparseTensorRef':
                expr = 'SparseTensorRef({})'.format(expr)

            # Parameter type used in the dispatch helper's signature.
            dispatch_type = typename
            if dispatch_type == 'Tensor':
                dispatch_type = 'const Tensor &'
            elif dispatch_type == 'Tensor &':
                dispatch_type = 'Tensor'
            elif dispatch_type == 'const Device &':
                dispatch_type = 'c10::optional<int32_t>'
            formal = '{} {}'.format(dispatch_type, name)
            return expr, formal

        def append_actuals_formals(actual, formal):
            actuals.append(actual)
            formal_args.append(formal)

        # We always want to unpack when we have TensorOptions.
        unpack = has_tensor_options
        for arg in inputs:
            # Type / TensorOptions arguments are handled separately below and
            # do not consume a parser index here.
            if arg['simple_type'] in ['Type', 'TensorOptions']:
                continue
            if has_self and arg['name'] == 'self':
                # `self` comes from the receiver, not from the parsed arguments.
                formal_args.append('Tensor & self')
                actuals.append('self')
                continue
            append_actuals_formals(*parse_arg(arg, arg_idx, unpack))
            arg_idx += 1

        if len(outputs) == 1:
            append_actuals_formals(*parse_arg(outputs[0], arg_idx))
        elif len(outputs) > 1:
            # Multiple outputs are passed from Python as one `out` list.
            N = len(outputs)
            body.append('auto results = r.tensorlist_n<{}>({});'.format(N, arg_idx))
            for i, arg in enumerate(outputs):
                formal_args.append('Tensor & {}'.format(arg['name']))
                actuals.append('results[{}]'.format(i))

        layout = None
        parsed_type_args = None
        # type args go after the outputs to match the signature generation.
        arg_idx = arg_idx if out_idx is None else out_idx + 1
        for arg in type_args:
            parsed_type_args = parse_arg(arg, arg_idx, unpack)
            arg_idx += 1

        # check python_binding_arguments
        # NOTE(review): has_device_bind is assigned but never read in this function.
        has_device_bind = False
        requires_grad = None
        python_binding_arguments = declaration.get('python_binding_arguments', [])
        if 'dtype' in (a['name'] for a in python_binding_arguments):
            if not has_tensor_options:
                arg_idx += 1

        # Parser indices of the remaining binding arguments, in signature order.
        if 'layout' in (a['name'] for a in python_binding_arguments):
            layout_idx, device_idx, requires_grad_idx = (arg_idx, arg_idx + 1, arg_idx + 2)
        else:
            device_idx, requires_grad_idx = (arg_idx, arg_idx + 1)

        device = None
        for arg in python_binding_arguments:
            if arg['name'] == 'dtype' and arg['simple_type'] == 'Type':
                pass  # already handled by type_dispatched_args
            elif arg['name'] == 'layout' and arg['simple_type'] == 'Layout':
                # out(s) determines the type and layout if it is present, so only use this if there are no outputs.
                if len(outputs) == 0:
                    layout = parse_arg(arg, layout_idx)[0]
            elif arg['name'] == 'device' and arg['simple_type'] == 'Device':
                if len(outputs) == 0:
                    assert parsed_type_args
                    assert layout
                    device, device_type = parse_arg(arg, device_idx, True)

                    if not has_tensor_options:
                        # add type, device formals and corresponding actuals.
                        # The type actual is the ATen type mapped from (ScalarType, Layout, Device)
                        # The device actual is the corresponding AutoGPU index for the Device.
                        formal_args.append(parsed_type_args[1])
                        formal_args.append(device_type)
                        actuals.append("torch::getVariableType({}, {}, {})".format(parsed_type_args[0], layout, device))
                        actuals.append('{}.index()'.format(device))

                    has_device_bind = True
            elif arg['name'] == 'requires_grad' and arg['simple_type'] == 'bool':
                requires_grad = parse_arg(arg, requires_grad_idx)[0]
            else:
                raise RuntimeError(("found {} in python_binding_arguments but only "
                                    "\"bool requires_grad\", \"ScalarType dtype\", \"Layout layout\", "
                                    "\"Device device\" are supported".format(arg)))

        dtype = parsed_type_args[0] if parsed_type_args else None
        if has_tensor_options and all([dtype, device, layout, requires_grad]):
            # Bundle the four expressions into a single TensorOptions argument.
            body.append(TENSOR_OPTIONS.substitute({
                'dtype': dtype,
                'layout': layout,
                'device': device,
                'requires_grad': requires_grad
            }))
            formal_args.append('const TensorOptions & options')
            actuals.append('options')

        env['unpack_args'] = []
        env['formal_args'] = formal_args
        env['actuals'] = actuals

        if has_tensor_options:
            env['initialize_cuda'] = 'maybe_initialize_cuda(options);'
        else:
            env['initialize_cuda'] = ''

        if 'call_args' in declaration:
            env['dispatch_args'] = declaration['call_args']
        else:
            env['dispatch_args'] = [arg['name'] for arg in declaration['arguments']]

        if 'Tensor' in declaration['method_of']:
            # Called as a method on `self`, so `self` is not passed as an argument.
            env['dispatch_args'] = [arg for arg in env['dispatch_args'] if arg != 'self']
            env['dispatch_call'] = 'self.{}'.format(declaration['name'])
        elif 'namespace' in declaration['method_of']:
            namespace = 'torch' if (has_tensor_options or declaration['name'].endswith('_like')) else 'at'
            env['dispatch_call'] = '{}::{}'.format(namespace, declaration['name'])
        else:
            raise RuntimeError('could not dispatch, neither namespace function nor Tensor method')

        env['AutoNoGIL'] = 'AutoNoGIL no_gil;' if not declaration['with_gil'] else ''

        # Use the simple_return_type (Tensor) rather than the fancy return type
        # (Tensor &).  This is important because the dispatch functions take
        # mutable arguments *by value*, not by reference.  If you then return
        # a reference to such an argument, you will now have a pointer to a
        # dangling stack entry.  Not good.
        #
        # You want:
        #
        #   Tensor dispatch_selu_(Tensor self) { return at::selu_(self); }
        #
        # *not*
        #
        #   Tensor& dispatch_selu_(Tensor self) { return at::selu_(self); }
        #
        # (NB: We can't make dispatch_selu_ take Tensor&, because the enclosing
        # codegen looks like dispatch_selu_(wrap(tensor)), and you can't take a
        # mutable reference to temporary.  Maybe we could assign it to a
        # variable itself.)
        env['simple_return_type'] = simple_return_type

        # Layer the environments: values set here, then the per-function
        # base_env, then the declaration itself.
        env = nested_dict(env, nested_dict(base_env, declaration))
        call_dispatch = PY_VARIABLE_CALL_DISPATCH.substitute(env)
        if requires_grad and not has_tensor_options:
            call_dispatch = PY_VARIABLE_SET_REQUIRES_GRAD.substitute(env, call_dispatch=call_dispatch,
                                                                     requires_grad=requires_grad)
        if simple_return_type == 'void':
            body.append('{call_dispatch};'.format(call_dispatch=call_dispatch))
            body.append('Py_RETURN_NONE;')
        else:
            body.append(PY_VARIABLE_WRAP.substitute(env, call_dispatch=call_dispatch))
        py_method_dispatch.append(PY_VARIABLE_DISPATCH.substitute(env))
        return body

    def emit_dispatch(i, dictionary, base_env):
        """Emit the `i`-th branch of the overload dispatch chain.

        `dictionary` is one entry produced by group_declarations: it always
        has 'base' and may have 'out'. When an out variant exists, the branch
        selects between the two at runtime depending on whether `out` is None.
        """
        if 'out' in dictionary:
            # The out argument is parsed right after all non-output arguments.
            out_idx = len([arg for arg in dictionary['out']['arguments']
                           if not arg.get('output', False)])
            env = {}
            env['call_dispatch_out'] = emit_single_dispatch(dictionary['out'], out_idx, base_env)
            env['call_dispatch'] = emit_single_dispatch(dictionary['base'], out_idx, base_env)

            has_dtype_bind = 'dtype' in [d['name'] for d in dictionary['out'].get('python_binding_arguments', [])]
            if has_dtype_bind:
                body = PY_VARIABLE_OUT_CHECK_TYPE.substitute(env, out_idx=out_idx, type_idx=out_idx + 1,
                                                             layout_idx=out_idx + 2, device_idx=out_idx + 3).split('\n')
            else:
                body = PY_VARIABLE_OUT.substitute(env, out_idx=out_idx).split('\n')
        else:
            body = emit_single_dispatch(dictionary['base'], None, base_env)

        # Every branch after the first also closes the previous branch's brace.
        cond = 'if' if i == 0 else '} else if'
        return PY_VARIABLE_CASE.substitute(i=i, cond=cond, call_dispatch=body)

    def get_python_binding_arguments(declaration):
        """Build the extra Python-only keyword arguments for `declaration`.

        Depending on whether the declaration looks like a factory function
        (returns a tensor and takes no tensor input), a `*_like` function, or
        takes TensorOptions, this synthesizes `dtype`, `layout`, `device` and
        `requires_grad` argument descriptions and returns them as a list.
        """
        python_binding_arguments = []
        has_tensor_input_arg = False
        has_type_input_arg = False
        has_options_arg = False
        for arg in declaration['arguments']:
            if arg.get('output', False):
                continue
            typename = arg['simple_type']
            if typename in ['Tensor', 'TensorList']:
                has_tensor_input_arg = True
            if arg['simple_type'] == 'Type':
                has_type_input_arg = True
            elif arg['simple_type'] == 'TensorOptions':
                has_options_arg = True
            if arg['name'] == 'requires_grad':
                raise ValueError("argument named requires_grad not supported")

        has_tensor_return = False
        for ret in declaration['returns']:
            if ret['dynamic_type'] in ['Tensor', 'TensorList']:
                # this probably won't work if one of the returns is not a tensor, but it will
                # produce a compile-time error that is obvious
                has_tensor_return = True

        # NOTE(review): `name` is not a local here; it appears to resolve to the
        # loop variable of the `for name in sorted(...)` loop at the bottom of
        # create_python_bindings (the grouped name, `_out` suffix removed).
        is_like_function = name.endswith('_like')
        is_like_function_with_options = is_like_function and has_options_arg
        is_factory_function = has_tensor_return and not has_tensor_input_arg
        is_factory_or_like_function = has_tensor_return and (not has_tensor_input_arg or is_like_function)

        if (is_factory_function and not has_type_input_arg) or has_options_arg:
            default_type = get_type_default(declaration)
            py_default_dtype = 'self.type()' if is_like_function_with_options else None
            dtype_arg = {
                'default': default_type,
                'dynamic_type': 'Type',
                'kwarg_only': True,
                'name': 'dtype',
                'type': 'const Type &',
                'simple_type': 'Type',
                'python_default_init': py_default_dtype,
            }
            python_binding_arguments.append(dtype_arg)
        if is_factory_function or is_like_function_with_options:
            py_default_layout = '*torch::getLayout(self.type().backend())' if is_like_function_with_options else None
            layout_arg = {
                'default': 'torch.strided',
                'dynamic_type': 'Layout',
                'kwarg_only': True,
                'name': 'layout',
                'type': 'const THPLayout &',
                'simple_type': 'Layout',
                'python_default_init': py_default_layout,
            }
            python_binding_arguments.append(layout_arg)
            py_default_device = 'self.device()' if is_like_function_with_options else None
            device_arg = {
                'default': 'None',
                'default_init': 'None',
                'dynamic_type': 'Device',
                'kwarg_only': True,
                'name': 'device',
                'type': 'const Device &',
                'simple_type': 'Device',
                'python_default_init': py_default_device
            }
            python_binding_arguments.append(device_arg)
        if is_factory_or_like_function:
            requires_grad_arg = {
                'default': False,
                'dynamic_type': 'bool',
                'kwarg_only': True,
                'name': 'requires_grad',
                'type': 'bool',
                'simple_type': 'bool',
            }
            python_binding_arguments.append(requires_grad_arg)
        return python_binding_arguments

    def emit_namedtuple_return_type_def(declaration, next_index):
        """Emit the struct-sequence type definition for a named-tuple return.

        Returns (typedef_code, next_free_index). Declarations with at most one
        return value, or with no named return fields, produce no code and
        leave the index unchanged. Stores the 'namedtuple_*' keys consumed by
        the templates on `declaration`.

        Raises:
            ValueError: if only some of the returns carry a 'field_name'.
        """
        returns = declaration['returns']
        if len(returns) <= 1 or all(['field_name' not in x for x in returns]):
            declaration['namedtuple_return_type'] = ''
            return '', next_index
        declaration['namedtuple_type_index'] = next_index
        declaration['namedtuple_fields'] = ''
        for x in returns:
            # See Note [field_name versus name]
            if 'field_name' not in x:
                # When building on Windows, `PyStructSequence_UnnamedField` could not be
                # resolved by the linker for some reason, which causes a build error:
                #
                # python_nn_functions.cpp.obj : error LNK2001: unresolved external symbol
                # PyStructSequence_UnnamedField
                #
                # Thus, at this point in time, we do not support unnamed
                # fields in namedtuple; you must either name all fields,
                # or none of them.
                raise ValueError("Unnamed field is not supported by codegen")
            else:
                declaration['namedtuple_fields'] += '{"' + x['field_name'] + '", ""}, '
        declaration['namedtuple_size'] = len(returns)
        declaration['namedtuple_return_type'] = '&type{}, '.format(next_index)
        return PY_RETURN_NAMEDTUPLE_DEF.substitute(declaration), next_index + 1

    def process_function(name, declarations):
        """Generate the wrapper and method-table row for one Python name.

        `declarations` holds every overload bound under `name`. The results
        are appended to `py_methods` and `py_method_defs`.
        """
        for declaration in declarations:
            declaration['python_binding_arguments'] = get_python_binding_arguments(declaration)

        env = {
            'name': name,
            'dispatch_name': 'dispatch_{}'.format(name),
            'pycname': 'THPVariable_{}'.format(name),
            'signatures': [],
            'max_args': max(len(o['arguments']) + len(o['python_binding_arguments']) for o in declarations),
            'unpack_self': [],
            'dispatch': [],
            'declare_namedtuple_return_types': '',
        }

        if has_self:
            env['unpack_self'] = [UNPACK_SELF]

        # generate namedtuple type declare
        next_index = 0
        for declaration in declarations:
            typedef, next_index = emit_namedtuple_return_type_def(declaration, next_index)
            env['declare_namedtuple_return_types'] += typedef

        # emit dispatch
        grouped = group_declarations(declarations)
        for i, dictionary in enumerate(grouped):
            signature = dictionary['signature']
            if has_self:
                # Methods receive `self` implicitly, so drop it from the signature.
                signature = signature.replace('Tensor self, ', '')
                signature = signature.replace('Tensor self', '')
            if not has_self:
                # Use 'input' instead of 'self' for NN functions
                signature = signature.replace('Tensor self', 'Tensor input')
            signature = signature.replace('SparseTensorRef', 'Tensor')
            if dictionary['base'].get('deprecated', False):
                signature += '|deprecated'
            env['signatures'].append('"{}",'.format(signature))
            env['dispatch'].append(emit_dispatch(i, dictionary, env))

        # Close the brace left open by the last PY_VARIABLE_CASE branch.
        env['dispatch'].append('}')

        env['traceable'] = 'true' if all(should_trace(d) for d in declarations) else 'false'

        # A single overload whose only argument is `self` needs no argument parsing.
        if len(declarations) == 1 and len(declarations[0]['args']) == 1 and has_self:
            tmpl = PY_VARIABLE_METHOD_NOARGS
            env['actuals'] = ['self']
            env['flags'] = 'METH_NOARGS'
            env['namedtuple_return_type'] = declarations[0]['namedtuple_return_type']
        else:
            tmpl = PY_VARIABLE_METHOD_VARARGS
            env['flags'] = 'METH_VARARGS | METH_KEYWORDS'

        if not is_module and not has_self:
            env['flags'] += ' | METH_STATIC'

        py_methods.append(tmpl.substitute(env))
        py_method_defs.append(PY_VARIABLE_METHOD_DEF.substitute(env))

    # Sorted iteration keeps the generated output order deterministic.
    for name in sorted(python_functions.keys()):
        process_function(name, python_functions[name])

    return {
        'py_methods': py_methods,
        'py_method_defs': py_method_defs,
        'py_method_dispatch': py_method_dispatch,
    }
|
|
|
|
|
|
def group_declarations(declarations):
    """Pair each base declaration with its out variant, keyed by signature.

    Returns a list of dictionaries with these keys:

      "base": the regular ATen declaration (e.g. conv2d); always present
      "out": the out variant (e.g. conv2d_out); optional
      "signature": the signature used for Python argument parsing

    The list is ordered by sort_declarations. A RuntimeError is raised when an
    out variant has no matching base declaration.
    """
    by_signature = defaultdict(dict)

    # Base and out variants share the same signature once out arguments are
    # ignored, which is what groups them together.
    for decl in declarations:
        plain_signature = get_python_signature(decl, False)
        entry = by_signature[plain_signature]
        if not decl['name'].endswith('_out'):
            entry['base'] = decl
            entry.setdefault('signature', plain_signature)
            continue
        entry['out'] = decl
        # The signature carrying the optional out=... argument takes priority.
        entry['signature'] = get_python_signature(decl, True)

    paired = []
    for plain_signature in sorted(by_signature):
        entry = by_signature[plain_signature]
        if 'base' not in entry:
            raise RuntimeError("'base' not in dictionary", entry)
        paired.append(entry)
    return sort_declarations(paired)
|
|
|
|
|
|
# Overload resolution is order-sensitive, so the grouped declarations are
# arranged along a linear extension of a partial order: a declaration is
# "smaller" than another when they have the same number of arguments and differ
# only in positions where the first has a Scalar and the second a Tensor.
# Larger declarations are emitted first.
#
# See Note[Order of overloads matters]
def sort_declarations(grouped_decls):
    """Order grouped declarations so Tensor overloads precede Scalar ones.

    `grouped_decls` is a list of dictionaries each holding a 'base'
    declaration. When no pair of entries is related by the partial order the
    input list itself is returned; otherwise a new list is returned.
    """

    # TODO: This is a hack!
    #
    # A Scalar argument declared in a native function shows up in
    # Declarations.yaml with `dynamic_type: Scalar`, whereas a 'real' argument
    # coming from TH Declarations.cwrap shows up with `dynamic_type: real`
    # (and `type: Scalar`). Both are treated as 'Scalar' here. Ideally this
    # would be fixed at the source.
    def canonical_type(arg):
        dynamic_type = arg['dynamic_type']
        return 'Scalar' if dynamic_type == 'real' else dynamic_type

    def coord_precedes(lhs, rhs):
        return canonical_type(lhs) == 'Scalar' and rhs['dynamic_type'] == 'Tensor'

    def precedes(lower, upper):
        """True when `lower` < `upper` in the partial order."""
        lower_args = lower['base']['arguments']
        upper_args = upper['base']['arguments']
        if len(lower_args) != len(upper_args):
            return False
        pairs = list(zip(lower_args, upper_args))
        if not any(coord_precedes(lhs, rhs) for lhs, rhs in pairs):
            return False
        return all(canonical_type(lhs) == canonical_type(rhs) or coord_precedes(lhs, rhs)
                   for lhs, rhs in pairs)

    # dominators[k] holds the indices of all declarations larger than k.
    dominators = defaultdict(set)
    for lower_idx, lower in enumerate(grouped_decls):
        for upper_idx, upper in enumerate(grouped_decls):
            if precedes(lower, upper):
                dominators[lower_idx].add(upper_idx)

    if not dominators:
        return grouped_decls

    # Topological sort: start from the entries nothing is larger than, and
    # release an entry once everything larger than it has been placed.
    ordered = [idx for idx in range(len(grouped_decls)) if idx not in dominators]
    cursor = 0
    while cursor < len(ordered):
        placed = ordered[cursor]
        cursor += 1
        for pending in sorted(dominators.keys()):
            remaining = dominators[pending]
            remaining.discard(placed)
            if not remaining:
                del dominators[pending]
                ordered.append(pending)

    return [grouped_decls[idx] for idx in ordered]
|
|
|
|
|
|
def get_python_signature(declaration, include_out):
    """Build the signature string used for Python argument parsing.

    The format is the one specified in torch/csrc/utils/python_arg_parser.h.
    WARNING: it is NOT a PEP 484 signature as understood by mypy; it was
    developed independently and has its own quirks. The translation to
    mypy-valid signatures lives in tools/gen_pyi.py; if you change any logic
    here, please check that file too.

    With `include_out`, an `out=None` parameter is added for declarations that
    have output arguments (which must be `_out` variants).
    """
    def render_param(arg):
        typename = arg['simple_type']
        if typename == 'Type':
            typename = 'ScalarType'

        # TODO: remove this and make optional types in simple_type to be consistent across
        # tensor and other types after make Tensor? be optional instead of undefined
        if arg.get('is_nullable') and '?' not in typename:
            typename = typename + '?'

        size = arg.get('size')
        if size is not None:
            typename = '{}[{}]'.format(typename, size)

        rendered = typename + ' ' + arg['name']
        default = arg.get('default')
        if default is None:
            return rendered
        if default in ('nullptr', 'nullopt', '{}'):
            default = 'None'
        return rendered + '=' + str(default)

    formals = []
    outputs = []
    dtype_args = []
    # Becomes True once the '*' marker separating keyword-only parameters has
    # been emitted.
    kwargs_started = False

    for arg in declaration['arguments']:
        if arg.get('output', False):
            outputs.append(arg)
        elif arg['simple_type'] == 'Type':
            dtype_args.append(arg)
        elif arg['simple_type'] == 'TensorOptions':
            # `TensorOptions` is only used on the C++ side; skip it in Python.
            pass
        else:
            if arg.get('kwarg_only', False) and not kwargs_started:
                formals.append('*')
                kwargs_started = True
            formals.append(render_param(arg))

    # The out variant is exposed under the base function's name.
    py_name = declaration['name']
    if py_name.endswith('_out'):
        py_name = py_name[:-4]

    if include_out and len(outputs) > 0:
        assert declaration['name'].endswith('_out')
        if not kwargs_started:
            formals.append('*')
            kwargs_started = True
        out_types = [arg['simple_type'] for arg in outputs]
        if len(out_types) == 1:
            out_type = out_types[0]
        else:
            out_type = 'TensorList[{}]'.format(len(out_types))
        formals.append(out_type + ' out=None')

    # Type-dispatched arguments and Python binding arguments are placed after
    # the out argument rather than in the main loop. This matches the case of a
    # dtype binding argument and keeps the signatures of the out and non-out
    # variants aligned.
    assert len(dtype_args) <= 1
    for arg in dtype_args:
        if not kwargs_started:  # assume type_args should be kwarg_only.
            formals.append('*')
            kwargs_started = True
        formals.append(render_param(arg))

    for arg in declaration['python_binding_arguments']:
        if arg.get('kwarg_only', False) and not kwargs_started:
            formals.append('*')
            kwargs_started = True
        formals.append(render_param(arg))

    # This string is what gets handed to FunctionParameter, which parses it
    # into the structure used for the actual argument parsing.
    return PYTHON_FUNCTION_SIGNATURE.substitute(name=py_name, py_formal_args=formals)
|