Mirror of https://github.com/pytorch/pytorch.git
Add generic bindings to THNN and THCUNN (#645)
Adds bindings using thpp::Tensor to THNN and THCUNN. This allows calling into those APIs without knowing the concrete types of the tensor arguments.
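To make the effect concrete, here is a sketch of the kind of type-dispatching wrapper these bindings produce, assembled by hand from the GenericNN templates below; the function Abs_updateOutput and its argument list are illustrative, not actual generated output:

// Hypothetical generated binding (sketch): tensors arrive type-erased as
// thpp::Tensor*, and the wrapper dispatches on their runtime type to the
// concrete THNN entry points.
void Abs_updateOutput(thpp::Tensor* input, thpp::Tensor* output)
{
  bool is_cuda = input->isCuda();
  auto type = input->type();
  checkTypes(is_cuda, type, "input", input, "output", output, NULL);
  if (!is_cuda) {
    if (type == thpp::Type::FLOAT) {
      THNN_FloatAbs_updateOutput(
          NULL,
          (THFloatTensor*)input->cdata(),
          (THFloatTensor*)output->cdata());
    } else if (type == thpp::Type::DOUBLE) {
      THNN_DoubleAbs_updateOutput(
          NULL,
          (THDoubleTensor*)input->cdata(),
          (THDoubleTensor*)output->cdata());
    } else {
      throw std::runtime_error("unsupported tensor type");
    }
  } else {
    throw std::runtime_error("invalid arguments");
  }
}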
.gitignore (3 lines added)

@@ -15,6 +15,9 @@ torch/csrc/nn/THNN.cwrap
 torch/csrc/nn/THNN.cpp
 torch/csrc/nn/THCUNN.cwrap
 torch/csrc/nn/THCUNN.cpp
+torch/csrc/nn/THNN_generic.cwrap
+torch/csrc/nn/THNN_generic.cpp
+torch/csrc/nn/THNN_generic.h
 docs/src/**/*
 test/data/legacy_modules.t7
 test/htmlcov
setup.py (7 lines changed)

@@ -185,6 +185,7 @@ include_dirs += [
     tmp_install_path + "/include",
     tmp_install_path + "/include/TH",
     tmp_install_path + "/include/THPP",
+    tmp_install_path + "/include/THNN",
 ]

 extra_link_args.append('-L' + lib_path)
@@ -210,7 +211,7 @@ if platform.system() == 'Darwin':

 main_compile_args = ['-D_THP_CORE']
 main_libraries = ['shm']
-main_link_args = [TH_LIB, THS_LIB, THPP_LIB]
+main_link_args = [TH_LIB, THS_LIB, THPP_LIB, THNN_LIB]
 main_sources = [
     "torch/csrc/PtrWrapper.cpp",
     "torch/csrc/Module.cpp",
@@ -227,6 +228,7 @@ main_sources = [
     "torch/csrc/autograd/variable.cpp",
     "torch/csrc/autograd/function.cpp",
     "torch/csrc/autograd/engine.cpp",
+    "torch/csrc/nn/THNN_generic.cpp",
 ]

 try:
@@ -259,11 +261,12 @@ if WITH_CUDA:
         if os.path.exists(cuda_lib_path):
             break
     include_dirs.append(cuda_include_path)
+    include_dirs.append(tmp_install_path + "/include/THCUNN")
     extra_link_args.append('-L' + cuda_lib_path)
     extra_link_args.append('-Wl,-rpath,' + cuda_lib_path)
     extra_compile_args += ['-DWITH_CUDA']
     extra_compile_args += ['-DCUDA_LIB_PATH=' + cuda_lib_path]
-    main_link_args += [THC_LIB, THCS_LIB]
+    main_link_args += [THC_LIB, THCS_LIB, THCUNN_LIB]
     main_sources += [
         "torch/csrc/cuda/Module.cpp",
         "torch/csrc/cuda/Storage.cpp",
tools/cwrap/cwrap.py (2 lines added)

@@ -229,6 +229,8 @@ class cwrap(object):
             depth -= line.count('}') * 2
             code += ' ' * depth + line + '\n'
             depth += line.count('{') * 2
+            depth += line.count('(') * 4
+            depth -= line.count(')') * 4

         # Put everything together
         return self.OPTION_TEMPLATE.substitute(
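The two added counters make cwrap's code formatter track parenthesis nesting in addition to braces, so the multi-line backend calls that GenericNN emits come out consistently indented. A sketch of the intended output shape (argument values are illustrative):

// Continuation lines inside an unclosed '(' are indented 4 extra spaces;
// the matching ')' restores the previous depth.
THNN_FloatAbs_updateOutput(
    NULL,
    (THFloatTensor*)input->cdata(),
    (THFloatTensor*)output->cdata());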
tools/cwrap/plugins/GenericNN.py (new file, 211 lines)

@@ -0,0 +1,211 @@
import copy
from string import Template
from . import CWrapPlugin


class GenericNN(CWrapPlugin):

    INPUT_TYPE_CHECK = Template("checkTypes(is_cuda, $type, $tensor_args);")

    HEADER_TEMPLATE = Template("void $name($args);")

    WRAPPER_TEMPLATE = Template("""\
void $name($args)
{
  bool is_cuda = $input->isCuda();
  auto type = $input->type();
  $type_check
  $options
  } else {
    throw std::runtime_error("invalid arguments");
  }
}
""")

    THNN_TEMPLATE = Template("""\
    if (type == thpp::Type::FLOAT) {
      THNN_Float$name(
        NULL,
        $float_args);
    } else if (type == thpp::Type::DOUBLE) {
      THNN_Double$name(
        NULL,
        $double_args);
    } else {
      throw std::runtime_error("unsupported tensor type");
    }""")

    THCUNN_TEMPLATE = Template("""\
#ifdef WITH_CUDA
    if (type == thpp::Type::FLOAT) {
      THNN_Cuda$name(
        state,
        $float_args);
    } else if (type == thpp::Type::DOUBLE) {
      THNN_CudaDouble$name(
        state,
        $double_args);
    } else if (type == thpp::Type::HALF) {
      THNN_CudaHalf$name(
        state,
        $half_args);
    } else {
      throw std::runtime_error("unsupported tensor type");
    }
#endif
""")

    INDEX_TENSOR_TYPES = {'THIndexTensor*', 'THCIndexTensor*'}

    REAL_TENSOR_TYPES = {'THTensor*', 'THCTensor*'}

    # Map concrete TH/THC argument types to the type-erased formal types
    # used in the public signatures.
    INPUT_ARGUMENT_MAP = {
        'THNNState*': 'void*',
        'THCState*': 'void*',
        'THTensor*': 'thpp::Tensor*',
        'THCTensor*': 'thpp::Tensor*',
        'THIndexTensor*': 'thpp::Tensor*',
        'THIndex_t': 'long',
        'real': 'double',
    }

    def __init__(self, header=False):
        self.header = header
        self.declarations = []

    def process_full_file(self, base_wrapper):
        if self.header:
            wrapper = '#pragma once\n\n'
            wrapper += '#include <THPP/Tensor.hpp>\n\n'
        else:
            wrapper = '#include "THNN_generic.h"\n'
            wrapper += '#include "THNN_generic.inc.h"\n\n'
        wrapper += 'namespace torch { namespace nn {\n\n'
        wrapper += base_wrapper
        wrapper += '}} // namespace torch::nn\n'
        return wrapper

    def process_declarations(self, declarations):
        for declaration in declarations:
            # All options share the formal names/types of the first option;
            # only the argument at index 1 (the first tensor after the state
            # pointer) keeps its type check for dispatch.
            base_args = declaration['options'][0]['arguments']
            for option in declaration['options']:
                for idx, arg in enumerate(option['arguments']):
                    arg['formal_name'] = base_args[idx]['name']
                    arg['formal_type'] = base_args[idx]['type']
                    if idx != 1:
                        arg['ignore_check'] = True
        return declarations

    def get_arg_accessor(self, arg, option):
        return self.get_type_unpack(arg, option)

    def process_option_code_template(self, template, option):
        code = '// fill me in'  # placeholder; replaced for known backends

        def base_cast(arg, CReal, real):
            name = arg['formal_name']
            type = arg['type']
            if type in self.REAL_TENSOR_TYPES:
                return ('(TH{CReal}Tensor*){name}->cdata()'
                        .format(CReal=CReal, name=name))
            elif type in self.INDEX_TENSOR_TYPES:
                return '({type}){name}->cdata()'.format(type=type, name=name)
            elif type == 'THCState*':
                return '({}){}'.format(type, name)
            elif type == 'real':
                if real == 'half':
                    return 'THC_float2half({})'.format(name)
                return '({real}){name}'.format(real=real, name=name)
            return name

        def cast(arg, CReal, real):
            expr = base_cast(arg, CReal, real)
            if arg.get('optional', False):
                name = arg['formal_name']
                return '{name} ? {expr} : NULL'.format(name=name, expr=expr)
            return expr

        if option['backend'] == 'nn':
            float_args = []
            double_args = []
            for idx, arg in enumerate(option['arguments']):
                float_args.append(cast(arg, 'Float', 'float'))
                double_args.append(cast(arg, 'Double', 'double'))

            code = self.THNN_TEMPLATE.substitute(
                name=option['cname'],
                float_args=',\n'.join(float_args),
                double_args=',\n'.join(double_args))

        elif option['backend'] == 'cunn':
            float_args = []
            double_args = []
            half_args = []
            for idx, arg in enumerate(option['arguments']):
                float_args.append(cast(arg, 'Cuda', 'float'))
                double_args.append(cast(arg, 'CudaDouble', 'double'))
                half_args.append(cast(arg, 'CudaHalf', 'half'))

            code = self.THCUNN_TEMPLATE.substitute(
                name=option['cname'],
                float_args=',\n'.join(float_args),
                double_args=',\n'.join(double_args),
                half_args=',\n'.join(half_args))

        return [code, '']

    def get_type_unpack(self, arg, option):
        return Template(arg['name'])

    def get_type_check(self, arg, option):
        if option['backend'] == 'cunn':
            return Template('is_cuda')
        else:
            return Template('!is_cuda')

    def get_formal_args(self, arguments):
        formal_args = []
        for arg in arguments:
            arg = copy.copy(arg)
            new_type = self.INPUT_ARGUMENT_MAP.get(arg['type'])
            if new_type is not None:
                arg['type'] = new_type
            formal_args.append(arg)
        return formal_args

    def get_wrapper_template(self, declaration):
        # get formal arguments string
        base_arguments = declaration['options'][0]['arguments']
        args = self.get_formal_args(base_arguments)
        arg_str = ', '.join([arg['type'] + ' ' + arg['name'] for arg in args])

        if self.header:
            return Template(self.HEADER_TEMPLATE.safe_substitute(args=arg_str))

        def get_checked_args(tensor_types):
            # Build ("name", pointer) pairs for checkTypes; optional tensors
            # get a '?' prefix and the list is NULL-terminated.
            checked_args = []
            for arg in base_arguments:
                if arg['type'] in tensor_types:
                    name = arg.get('formal_name', arg['name'])
                    name_str = name
                    if arg.get('optional', False):
                        name_str = '?' + name_str
                    checked_args += ['"' + name_str + '"', name]
            checked_args += ['NULL']
            return checked_args

        real_args = get_checked_args(self.REAL_TENSOR_TYPES)
        long_args = get_checked_args(self.INDEX_TENSOR_TYPES)

        # check input types
        types_checks = []
        if len(real_args) > 1:
            types_checks.append(self.INPUT_TYPE_CHECK.substitute(
                type='type', tensor_args=', '.join(real_args)))
        if len(long_args) > 1:
            types_checks.append(self.INPUT_TYPE_CHECK.substitute(
                type='thpp::Type::LONG', tensor_args=', '.join(long_args)))

        return Template(self.WRAPPER_TEMPLATE.safe_substitute(
            input=args[0]['name'],
            args=arg_str,
            type_check='\n  '.join(types_checks)))
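With header=True the plugin emits only prototypes via HEADER_TEMPLATE; a sketch of the resulting THNN_generic.h, with one illustrative entry:

// Hypothetical generated header (sketch):
#pragma once

#include <THPP/Tensor.hpp>

namespace torch { namespace nn {

void Abs_updateOutput(thpp::Tensor* input, thpp::Tensor* output);

}} // namespace torch::nn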
tools/cwrap/plugins/__init__.py (1 line added)

@@ -58,3 +58,4 @@ from .ReturnArguments import ReturnArguments
 from .GILRelease import GILRelease
 from .AutoGPU import AutoGPU
 from .CuDNNPlugin import CuDNNPlugin
+from .GenericNN import GenericNN
tools/nnwrap/generate_wrappers.py

@@ -2,7 +2,7 @@ import os
 import sys
 from string import Template, ascii_lowercase
 from ..cwrap import cwrap
-from ..cwrap.plugins import StandaloneExtension, NullableArguments, AutoGPU
+from ..cwrap.plugins import StandaloneExtension, GenericNN, NullableArguments, AutoGPU

 BASE_PATH = os.path.realpath(os.path.join(__file__, '..', '..', '..'))
 WRAPPER_PATH = os.path.join(BASE_PATH, 'torch', 'csrc', 'nn')
@@ -103,6 +103,7 @@ def wrap_function(name, type, arguments):
 def generate_wrappers():
     wrap_nn()
     wrap_cunn()
+    wrap_generic()


 def wrap_nn():
@@ -133,3 +134,66 @@ def wrap_cunn():
         NullableArguments(),
         AutoGPU(has_self=False),
     ])
+
+GENERIC_FUNCTION_TEMPLATE = Template("""\
+[[
+  name: $name
+  return: void
+  options:
+""")
+
+
+def wrap_generic_function(name, backends):
+    declaration = ''
+    declaration += GENERIC_FUNCTION_TEMPLATE.substitute(name=name)
+    for backend in backends:
+        declaration += '    - cname: ' + name + '\n'
+        declaration += '      backend: ' + backend['name'] + '\n'
+        declaration += '      arguments:\n'
+        for arg in backend['arguments']:
+            declaration += '        - arg: ' + arg.type + ' ' + arg.name + '\n'
+            if arg.is_optional:
+                declaration += '          optional: True\n'
+    declaration += ']]\n\n\n'
+    return declaration
+
+
+def wrap_generic():
+    from collections import OrderedDict
+    defs = OrderedDict()
+
+    def should_wrap_function(name):
+        # LookupTable is excluded from the generic wrappers.
+        if name.startswith('LookupTable'):
+            return False
+        return (name.endswith('updateOutput') or
+                name.endswith('updateGradInput') or
+                name.endswith('accGradParameters') or
+                name.endswith('backward'))
+
+    def add_functions(name, functions):
+        for fn in functions:
+            if not should_wrap_function(fn.name):
+                continue
+            if fn.name not in defs:
+                defs[fn.name] = []
+            # Drop the leading THNNState*/THCState* argument.
+            defs[fn.name] += [{
+                'name': name,
+                'arguments': fn.arguments[1:],
+            }]
+
+    add_functions('nn', thnn_utils.parse_header(thnn_utils.THNN_H_PATH))
+    add_functions('cunn', thnn_utils.parse_header(thnn_utils.THCUNN_H_PATH))
+
+    wrapper = ''
+    for name, backends in defs.items():
+        wrapper += wrap_generic_function(name, backends)
+    with open('torch/csrc/nn/THNN_generic.cwrap', 'w') as f:
+        f.write(wrapper)
+
+    cwrap('torch/csrc/nn/THNN_generic.cwrap', plugins=[
+        GenericNN(header=True),
+    ], default_plugins=False, destination='torch/csrc/nn/THNN_generic.h')
+
+    cwrap('torch/csrc/nn/THNN_generic.cwrap', plugins=[
+        GenericNN(),
+    ], default_plugins=False)
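For reference, a hand-written sketch of the .cwrap declaration wrap_generic_function would emit for a function present in both headers; the name and arguments are illustrative, and the exact indentation is an assumption:

[[
  name: Abs_updateOutput
  return: void
  options:
    - cname: Abs_updateOutput
      backend: nn
      arguments:
        - arg: THTensor* input
        - arg: THTensor* output
    - cname: Abs_updateOutput
      backend: cunn
      arguments:
        - arg: THCTensor* input
        - arg: THCTensor* output
]]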
torch/csrc/nn/THNN_generic.inc.h (new file, 56 lines)

@@ -0,0 +1,56 @@
#include "THNN_generic.h"

#include <sstream>
#include <stdarg.h>

#include <TH/TH.h>
#include <THNN/THNN.h>
#ifdef THNN_
#undef THNN_
#endif

#ifdef WITH_CUDA
#include <THC/THC.h>
#include <THCUNN/THCUNN.h>
#ifdef THNN_
#undef THNN_
#endif
#endif

#ifdef WITH_CUDA
extern THCState* state;
#endif

namespace {

static std::runtime_error invalid_tensor(const char* expected, const char* got) {
  std::stringstream ss;
  ss << "expected " << expected << " tensor (got " << got << " tensor)";
  return std::runtime_error(ss.str());
}

// Validates a NULL-terminated vararg list of ("name", tensor) pairs against
// the backend selected by isCuda. A '?' prefix on the name marks the tensor
// as optional, so a NULL pointer is only an error for required arguments.
void checkTypes(bool isCuda, thpp::Type type, ...) {
  va_list args;
  va_start(args, type);

  const char* name;
  while ((name = va_arg(args, const char*))) {
    bool optional = false;
    if (name[0] == '?') {
      name++;
      optional = true;
    }
    thpp::Tensor* tensor = va_arg(args, thpp::Tensor*);
    if (!tensor) {
      if (optional) {
        continue;
      }
      throw std::runtime_error(std::string("missing required argument '") + name + "'");
    }
    if (tensor->isCuda() != isCuda) {
      throw invalid_tensor(isCuda ? "CUDA" : "CPU", tensor->isCuda() ? "CUDA" : "CPU");
    }
  }
  va_end(args);
}

} // anonymous namespace
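Generated wrappers reach checkTypes through GenericNN's INPUT_TYPE_CHECK template, passing ("name", tensor) pairs terminated by NULL, with a '?' prefix for optional tensors. A sketch with illustrative argument names:

// Hypothetical calls as a generated wrapper would emit them:
checkTypes(is_cuda, type, "input", input, "?weight", weight, NULL);
checkTypes(is_cuda, thpp::Type::LONG, "indices", indices, NULL);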