changes to support ATen code generation inside fbcode (#8397)
* Back out "Back out "Add support for generating ATen files during fbcode build""

  Original commit changeset: 7b8de22d1613

  I'm re-sending this diff exactly as it was approved and committed. Fixes to support @mode/opt will be sent separately for ease of review.

* Enable building //caffe2:torch with @mode/opt

  In @mode/opt, python runs out of a PAR, which breaks a lot of assumptions in the code about where templates/ folders live relative to __file__. Rather than introduce hacks with parutil, I simply turn template_path into a parameter for all the relevant functions and thread it through from the top level.
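The pattern is the same in every file touched below: template files that used to be loaded at import time from a path derived from __file__ are now loaded inside the generator functions, and the template directory is passed in as an argument from the top-level build script. A minimal sketch of that shape in Python (the gen_example function, Example.cpp template, and temp-directory driver are illustrative only, not part of this commit):

import os
import tempfile


def gen_example(out_dir, declarations, template_path):
    # Hypothetical generator: it reads templates and writes outputs only at
    # paths supplied by the caller, never relative to this module's __file__.
    with open(os.path.join(template_path, 'Example.cpp')) as f:
        template = f.read()
    if not os.path.exists(out_dir):
        os.makedirs(out_dir)
    with open(os.path.join(out_dir, 'Example.cpp'), 'w') as f:
        f.write(template.replace('${declarations}', '\n'.join(declarations)))


if __name__ == '__main__':
    # Stand-in for the top-level build script (generate_code.py in this
    # commit), which decides where templates live and where outputs go
    # (for example an fbcode install_dir) and threads both paths down.
    root = tempfile.mkdtemp()
    templates = os.path.join(root, 'templates')
    os.makedirs(templates)
    with open(os.path.join(templates, 'Example.cpp'), 'w') as f:
        f.write('// generated\n${declarations}\n')
    gen_example(os.path.join(root, 'generated'),
                ['void example();'],
                template_path=templates)

With that shape, nothing in the generators depends on where their own source files sit on disk, which is what running out of a PAR requires.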
@@ -99,7 +99,6 @@ class FileManager(object):
             raise Exception("Outputs declared with 'will_write' were " +
                             "never written: {}".format(self.filenames))
 
-
 TEMPLATE_PATH = options.source_path + "/templates"
 GENERATOR_DERIVED = CodeTemplate.from_file(
     TEMPLATE_PATH + "/GeneratorDerived.h")
@@ -14,10 +14,6 @@ import yaml
 from collections import defaultdict
 from .utils import YamlLoader, split_name_params
 
-template_path = os.path.join(os.path.dirname(__file__), 'templates')
-derivatives_path = os.path.join(os.path.dirname(__file__), 'derivatives.yaml')
-deprecated_path = os.path.join(os.path.dirname(__file__), 'deprecated.yaml')
-
 VIEW_FUNCTIONS = {
     'alias', 'as_strided', 'diagonal', 'expand', 'narrow', 'permute', 'select', 'slice',
     'squeeze', 't', 'transpose', 'unfold', 'unsqueeze', 'view',
@@ -75,7 +71,7 @@ def load_aten_declarations(path):
     return declarations
 
 
-def load_deprecated_signatures(aten_decls):
+def load_deprecated_signatures(aten_decls, deprecated_path):
     def group_declarations_by_signature():
         d = defaultdict(list)
         for declaration in aten_decls:
@@ -137,29 +133,35 @@ def load_deprecated_signatures(aten_decls):
     return declarations
 
 
-def gen_autograd(aten_path, out):
+def gen_autograd(aten_path, out, autograd_dir):
     aten_decls = load_aten_declarations(aten_path)
 
     # Parse and load derivatives.yaml
     from .load_derivatives import load_derivatives
-    autograd_functions = load_derivatives(derivatives_path, aten_decls)
+    autograd_functions = load_derivatives(
+        os.path.join(autograd_dir, 'derivatives.yaml'), aten_decls)
 
     # Generate VariableType.h/cpp
     from .gen_variable_type import gen_variable_type
-    gen_variable_type(out, aten_decls)
+    gen_variable_type(out, aten_decls, os.path.join(autograd_dir, 'templates'))
 
     # Generate Functions.h/cpp
     from .gen_autograd_functions import gen_autograd_functions
-    gen_autograd_functions(out, autograd_functions)
+    gen_autograd_functions(
+        out, autograd_functions, os.path.join(autograd_dir, 'templates'))
 
     # Load deprecated signatures
-    deprecated = load_deprecated_signatures(aten_decls)
+    deprecated = load_deprecated_signatures(
+        aten_decls, os.path.join(autograd_dir, 'deprecated.yaml'))
 
     # Genereate Python bindings
     from . import gen_python_functions
-    gen_python_functions.gen_py_variable_methods(out, aten_decls + deprecated)
-    gen_python_functions.gen_py_torch_functions(out, aten_decls + deprecated)
-    gen_python_functions.gen_py_nn_functions(out, aten_decls)
+    gen_python_functions.gen_py_variable_methods(
+        out, aten_decls + deprecated, os.path.join(autograd_dir, 'templates'))
+    gen_python_functions.gen_py_torch_functions(
+        out, aten_decls + deprecated, os.path.join(autograd_dir, 'templates'))
+    gen_python_functions.gen_py_nn_functions(
+        out, aten_decls, os.path.join(autograd_dir, 'templates'))
 
 
 def main():
@@ -6,14 +6,9 @@
 #
 import re
 from .utils import nested_dict, CodeTemplate, write
-from .gen_autograd import VIEW_FUNCTIONS, template_path
+from .gen_autograd import VIEW_FUNCTIONS
 from .utils import IDENT_REGEX
 
-FUNCTIONS_H = CodeTemplate.from_file(template_path + '/Functions.h')
-FUNCTIONS_CPP = CodeTemplate.from_file(template_path + '/Functions.cpp')
-PY_FUNCTIONS_H = CodeTemplate.from_file(template_path + '/python_functions.h')
-PY_FUNCTIONS_CPP = CodeTemplate.from_file(template_path + '/python_functions.cpp')
-
 FUNCTION_DECLARATION = CodeTemplate("""\
 struct ${op} : public ${superclass} {
   using ${superclass}::${superclass};
@@ -86,12 +81,18 @@ if (should_compute_output({ ${idx_ranges} })) {
 UNTRACEABLE_FUNCTIONS = VIEW_FUNCTIONS
 
 
-def gen_autograd_functions(out, autograd_functions):
+def gen_autograd_functions(out, autograd_functions, template_path):
     """Functions.h and Functions.cpp body
 
     These contain the auto-generated subclasses of torch::autograd::Function
     for each every differentiable torch function.
     """
+
+    FUNCTIONS_H = CodeTemplate.from_file(template_path + '/Functions.h')
+    FUNCTIONS_CPP = CodeTemplate.from_file(template_path + '/Functions.cpp')
+    PY_FUNCTIONS_H = CodeTemplate.from_file(template_path + '/python_functions.h')
+    PY_FUNCTIONS_CPP = CodeTemplate.from_file(template_path + '/python_functions.cpp')
+
     function_definitions = []
     function_declarations = []
     py_function_initializers = []
@@ -6,12 +6,14 @@
 from collections import defaultdict
 import re
 from .nested_dict import nested_dict
-from tools.shared.module_loader import import_module
-from .gen_autograd import template_path
 from .gen_variable_type import should_trace
 from .utils import write
 
-CodeTemplate = import_module('code_template', 'aten/src/ATen/code_template.py').CodeTemplate
+try:
+    from src.ATen.code_template import CodeTemplate
+except ImportError:
+    from tools.shared.module_loader import import_module
+    CodeTemplate = import_module('code_template', 'aten/src/ATen/code_template.py').CodeTemplate
 
 # These functions require manual Python bindings or are not exposed to Python
 SKIP_PYTHON_BINDINGS = [
@@ -25,14 +27,6 @@ SKIP_PYTHON_BINDINGS = [
     'arange.*', 'range.*', '_gesv.*', 'slice',
 ]
 
-PY_VARIABLE_METHODS_CPP = CodeTemplate.from_file(template_path + '/python_variable_methods.cpp')
-PY_VARIABLE_DISPATCH_H = CodeTemplate.from_file(template_path + '/python_variable_methods_dispatch.h')
-PY_TORCH_FUNCTIONS_CPP = CodeTemplate.from_file(template_path + '/python_torch_functions.cpp')
-PY_TORCH_DISPATCH_H = CodeTemplate.from_file(template_path + '/python_torch_functions_dispatch.h')
-PY_NN_FUNCTIONS_CPP = CodeTemplate.from_file(template_path + '/python_nn_functions.cpp')
-PY_NN_FUNCTIONS_H = CodeTemplate.from_file(template_path + '/python_nn_functions.h')
-PY_NN_DISPATCH_H = CodeTemplate.from_file(template_path + '/python_nn_functions_dispatch.h')
-
 PY_VARIABLE_METHOD_VARARGS = CodeTemplate("""\
 static PyObject * ${pycname}(PyObject* self, PyObject* args, PyObject* kwargs)
 {
@@ -140,7 +134,10 @@ def should_generate_python_binding(declaration):
     return True
 
 
-def gen_py_variable_methods(out, declarations):
+def gen_py_variable_methods(out, declarations, template_path):
+    PY_VARIABLE_METHODS_CPP = CodeTemplate.from_file(template_path + '/python_variable_methods.cpp')
+    PY_VARIABLE_DISPATCH_H = CodeTemplate.from_file(template_path + '/python_variable_methods_dispatch.h')
+
     def should_bind(declaration):
         return (should_generate_python_binding(declaration) and
                 declaration['mode'] != 'NN' and
@@ -153,7 +150,11 @@ def gen_py_variable_methods(out, declarations):
     write(out, 'python_variable_methods_dispatch.h', PY_VARIABLE_DISPATCH_H, env)
 
 
-def gen_py_nn_functions(out, declarations):
+def gen_py_nn_functions(out, declarations, template_path):
+    PY_NN_FUNCTIONS_CPP = CodeTemplate.from_file(template_path + '/python_nn_functions.cpp')
+    PY_NN_FUNCTIONS_H = CodeTemplate.from_file(template_path + '/python_nn_functions.h')
+    PY_NN_DISPATCH_H = CodeTemplate.from_file(template_path + '/python_nn_functions_dispatch.h')
+
     def should_bind(declaration):
         return (should_generate_python_binding(declaration) and
                 declaration['mode'] == 'NN')
@@ -166,7 +167,10 @@ def gen_py_nn_functions(out, declarations):
     write(out, 'python_nn_functions_dispatch.h', PY_NN_DISPATCH_H, env)
 
 
-def gen_py_torch_functions(out, declarations):
+def gen_py_torch_functions(out, declarations, template_path):
+    PY_TORCH_FUNCTIONS_CPP = CodeTemplate.from_file(template_path + '/python_torch_functions.cpp')
+    PY_TORCH_DISPATCH_H = CodeTemplate.from_file(template_path + '/python_torch_functions_dispatch.h')
+
     def should_bind(declaration):
         return (should_generate_python_binding(declaration) and
                 declaration['mode'] != 'NN' and
@@ -26,13 +26,9 @@ from __future__ import print_function
 import os
 import sys
 from .utils import CodeTemplate, nested_dict, write, uninplace_api_name
-from .gen_autograd import VIEW_FUNCTIONS, template_path, \
-    HARDCODED_DIFFERENTIABLE_OUTPUTS
+from .gen_autograd import VIEW_FUNCTIONS, HARDCODED_DIFFERENTIABLE_OUTPUTS
 from .gen_autograd_functions import uses_single_grad
 
-VARIABLE_TYPE_H = CodeTemplate.from_file(template_path + '/VariableType.h')
-VARIABLE_TYPE_CPP = CodeTemplate.from_file(template_path + '/VariableType.cpp')
-
 # These functions are written manually in templates/VariableType.cpp
 MANUAL_IMPLEMENTATIONS = {
     'contiguous', 'resize_', 'resize_as_'
@@ -166,7 +162,7 @@ def should_trace(declaration):
     return True
 
 
-def gen_variable_type(out, aten_declarations):
+def gen_variable_type(out, aten_declarations, template_path):
    """VariableType.h and VariableType.cpp body
 
     This is the at::Type subclass for differentiable tensors. The
@@ -174,6 +170,9 @@ def gen_variable_type(out, aten_declarations):
     compute the output. The grad_fn is attached to differentiable functions.
     """
 
+    VARIABLE_TYPE_H = CodeTemplate.from_file(template_path + '/VariableType.h')
+    VARIABLE_TYPE_CPP = CodeTemplate.from_file(template_path + '/VariableType.cpp')
+
     type_declarations = []
     type_definitions = []
 
@@ -1,6 +1,5 @@
 import re
 import os
-from tools.shared.module_loader import import_module
 from .nested_dict import nested_dict
 
 
@@ -9,8 +8,11 @@ __all__ = [
     'split_name_params', 'write',
 ]
 
-CodeTemplate = import_module('code_template', 'aten/src/ATen/code_template.py').CodeTemplate
-
+try:
+    from src.ATen.code_template import CodeTemplate
+except ImportError:
+    from tools.shared.module_loader import import_module
+    CodeTemplate = import_module('code_template', 'aten/src/ATen/code_template.py').CodeTemplate
 
 try:
     # use faster C loader if available
@@ -37,7 +37,7 @@ class cwrap(object):
     DEFAULT_PLUGIN_CLASSES = [ArgcountChecker, ConstantArguments, OptionalArguments,
                               ArgumentReferences, BeforeAfterCall, ReturnArguments, GILRelease]
 
-    def __init__(self, source, destination=None, plugins=None, default_plugins=True):
+    def __init__(self, source, destination=None, plugins=None, default_plugins=True, template_path=None):
         if destination is None:
             destination = source.replace('.cwrap', '.cpp')
 
@@ -58,7 +58,7 @@ class cwrap(object):
 
         # let each plugin do any post-processing of the wrapped file
         for plugin in self.plugins:
-            wrapper = plugin.process_full_file(wrapper)
+            wrapper = plugin.process_full_file(wrapper, template_path)
 
         # See Note [Unchanging results for ninja]
         try:
@@ -175,5 +175,5 @@ static PyObject * $name(PyObject *self, PyObject *args, PyObject *kwargs)
         methods += entry
         return self.METHODS_DECLARATION.substitute(methods=methods)
 
-    def process_full_file(self, code):
+    def process_full_file(self, code, template_path):
         return code + self.declare_methods()
@@ -12,9 +12,6 @@ MODULE_HEAD = """
 #include "torch/csrc/nn/type_checks.h"
 
 """
-with open(os.path.join(os.path.dirname(__file__), 'templates', 'nn_tail.cpp'), 'r') as f:
-    MODULE_TAIL = Template(f.read())
-
 REGISTER_METHOD_TEMPLATE = Template(' {"$name", (PyCFunction)$name, METH_STATIC | METH_VARARGS, NULL},\n')
 
 MODULE_METHODS_TEMPLATE = Template("""
@@ -110,7 +107,10 @@ PyObject * $name(PyObject *_unused, PyObject *args)
         self.module_name = module_name
         self.declarations = []
 
-    def process_full_file(self, code):
+    def process_full_file(self, code, template_path):
+        with open(os.path.join(template_path, 'nn_tail.cpp'), 'r') as f:
+            MODULE_TAIL = Template(f.read())
+
         short_name = self.module_name.split('.')[-1]
         new_code = MODULE_HEAD
         new_code += code
@@ -540,7 +540,7 @@ ${cpu}
         generated = '#if !defined(TH_REAL_IS_HALF) && !IS_DISTRIBUTED\n' + generated + '\n#endif\n\n'
         return generated
 
-    def process_full_file(self, code):
+    def process_full_file(self, code, template_path):
         # We have to find a place before all undefs
         idx = code.find('// PUT DEFINITIONS IN HERE PLEASE')
         return (code[:idx] +
@@ -186,7 +186,7 @@ class CWrapPlugin(object):
         """
         pass
 
-    def process_full_file(self, code):
+    def process_full_file(self, code, template_path):
         """Used to modify the code for the entire output file.
 
         The last thing any plugin can do. Code contains the results of wrapping
@@ -5,12 +5,6 @@ from ..autograd.utils import CodeTemplate, write, uninplace_api_name
 from ..autograd.gen_autograd import load_aten_declarations
 from collections import OrderedDict
 
-template_path = os.path.join(os.path.dirname(__file__), 'templates')
-
-ATEN_DISPATCH_CPP = CodeTemplate.from_file(template_path + '/aten_dispatch.cpp')
-ATEN_INTERNED_STRINGS_H = CodeTemplate.from_file(template_path + '/aten_interned_strings.h')
-ATEN_SCHEMA_CPP = CodeTemplate.from_file(template_path + '/aten_schema.cpp')
-
 ATTR_METHOD_MAP = {
     'int64_t': 'i',
     'IntList': 'is',
@@ -127,7 +121,10 @@ def is_sized_intlist_arg(arg):
     return (arg['simple_type'] == 'IntList') and ('size' in arg)
 
 
-def gen_jit_dispatch(declarations, out):
+def gen_jit_dispatch(declarations, out, template_path):
+    ATEN_DISPATCH_CPP = CodeTemplate.from_file(template_path + '/aten_dispatch.cpp')
+    ATEN_INTERNED_STRINGS_H = CodeTemplate.from_file(template_path + '/aten_interned_strings.h')
+
     ops = {}
 
     def get_invocation(decl, args, num_dynamic_inputs):
@@ -310,7 +307,7 @@ def gen_jit_dispatch(declarations, out):
     }
     write(out, 'aten_dispatch.cpp', ATEN_DISPATCH_CPP, env)
 
-    emit_schema(jit_decls, out)
+    emit_schema(jit_decls, out, template_path)
 
     # NB: Operate on aten_decls, not jit_decls, because VariableType is
     # a client for these symbols as well
@@ -331,7 +328,8 @@ def gen_jit_dispatch(declarations, out):
     write(out, 'aten_interned_strings.h', ATEN_INTERNED_STRINGS_H, strings_env)
 
 
-def emit_schema(jit_decls, out):
+def emit_schema(jit_decls, out, template_path):
+    ATEN_SCHEMA_CPP = CodeTemplate.from_file(template_path + '/aten_schema.cpp')
 
     # see [aten_schema encoding] for how this gets translated to C++ object
 
@@ -414,8 +412,10 @@ def main():
                         help='path to Declarations.yaml')
     parser.add_argument('out', metavar='OUT',
                         help='path to output directory')
+    parser.add_argument('template-path', metavar='TEMPLATE_PATH',
+                        help='path to templates directory')
     args = parser.parse_args()
-    gen_jit_dispatch(args.declarations, args.out)
+    gen_jit_dispatch(args.declarations, args.out, args.template_path)
 
 
 if __name__ == '__main__':
@@ -1,2 +1,5 @@
-from .generate_wrappers import generate_wrappers, wrap_function, \
-    import_module
+from .generate_wrappers import generate_wrappers, wrap_function
+try:
+    from .generate_wrappers import import_module
+except ImportError:
+    pass
@@ -3,14 +3,17 @@ import sys
 from string import Template, ascii_lowercase
 from ..cwrap import cwrap
 from ..cwrap.plugins import NNExtension, NullableArguments, AutoGPU
-from ..shared import import_module
 
 BASE_PATH = os.path.realpath(os.path.join(__file__, '..', '..', '..'))
 WRAPPER_PATH = os.path.join(BASE_PATH, 'torch', 'csrc', 'nn')
 THNN_UTILS_PATH = os.path.join(BASE_PATH, 'torch', '_thnn', 'utils.py')
 
 
-thnn_utils = import_module('torch._thnn.utils', THNN_UTILS_PATH)
+try:
+    from torch._thnn import utils as thnn_utils
+except ImportError:
+    from ..shared import import_module
+    thnn_utils = import_module('torch._thnn.utils', THNN_UTILS_PATH)
 
 FUNCTION_TEMPLATE = Template("""\
 [[
@@ -95,36 +98,39 @@ def wrap_function(name, type, arguments):
     return declaration
 
 
-def generate_wrappers(nn_root=None):
-    wrap_nn(os.path.join(nn_root, 'THNN', 'generic', 'THNN.h') if nn_root else None)
-    wrap_cunn(os.path.join(nn_root, 'THCUNN', 'generic', 'THCUNN.h') if nn_root else None)
+def generate_wrappers(nn_root=None, install_dir=None, template_path=None):
+    wrap_nn(os.path.join(nn_root, 'THNN', 'generic', 'THNN.h') if nn_root else None, install_dir, template_path)
+    wrap_cunn(os.path.join(nn_root, 'THCUNN', 'generic', 'THCUNN.h') if nn_root else None, install_dir, template_path)
 
 
-def wrap_nn(thnn_h_path):
+def wrap_nn(thnn_h_path, install_dir, template_path):
     wrapper = '#include <TH/TH.h>\n\n\n'
     nn_functions = thnn_utils.parse_header(thnn_h_path or thnn_utils.THNN_H_PATH)
     for fn in nn_functions:
         for t in ['Float', 'Double']:
             wrapper += wrap_function(fn.name, t, fn.arguments)
-    with open('torch/csrc/nn/THNN.cwrap', 'w') as f:
+    install_dir = install_dir or 'torch/csrc/nn'
+    try:
+        os.makedirs(install_dir)
+    except OSError:
+        pass
+    with open(os.path.join(install_dir, 'THNN.cwrap'), 'w') as f:
         f.write(wrapper)
-    cwrap('torch/csrc/nn/THNN.cwrap', plugins=[
-        NNExtension('torch._C._THNN'),
-        NullableArguments(),
-    ])
+    cwrap(os.path.join(install_dir, 'THNN.cwrap'),
+          plugins=[NNExtension('torch._C._THNN'), NullableArguments()],
+          template_path=template_path)
 
 
-def wrap_cunn(thcunn_h_path=None):
+def wrap_cunn(thcunn_h_path, install_dir, template_path):
     wrapper = '#include <TH/TH.h>\n'
     wrapper += '#include <THC/THC.h>\n\n\n'
     cunn_functions = thnn_utils.parse_header(thcunn_h_path or thnn_utils.THCUNN_H_PATH)
     for fn in cunn_functions:
         for t in ['CudaHalf', 'Cuda', 'CudaDouble']:
            wrapper += wrap_function(fn.name, t, fn.arguments)
-    with open('torch/csrc/nn/THCUNN.cwrap', 'w') as f:
+    install_dir = install_dir or 'torch/csrc/nn'
+    with open(os.path.join(install_dir, 'THCUNN.cwrap'), 'w') as f:
         f.write(wrapper)
-    cwrap('torch/csrc/nn/THCUNN.cwrap', plugins=[
-        NNExtension('torch._C._THCUNN'),
-        NullableArguments(),
-        AutoGPU(has_self=False),
-    ])
+    cwrap(os.path.join(install_dir, 'THCUNN.cwrap'),
+          plugins=[NNExtension('torch._C._THCUNN'), NullableArguments(), AutoGPU(has_self=False)],
+          template_path=template_path)
@@ -65,7 +65,8 @@ def generate_code_ninja(w):
 
 def generate_code(ninja_global=None,
                   declarations_path=None,
-                  nn_path=None):
+                  nn_path=None,
+                  install_dir=None):
     # if ninja is enabled, we just register this file as something
     # ninja will need to call if needed
     if ninja_global is not None:
@@ -80,16 +81,16 @@ def generate_code(ninja_global=None,
 
     # Build THNN/THCUNN.cwrap and then THNN/THCUNN.cpp. These are primarily
     # used by the legacy NN bindings.
-    generate_nn_wrappers(nn_path)
+    generate_nn_wrappers(nn_path, install_dir, 'tools/cwrap/plugins/templates')
 
     # Build ATen based Variable classes
-    autograd_gen_dir = 'torch/csrc/autograd/generated'
-    jit_gen_dir = 'torch/csrc/jit/generated'
+    autograd_gen_dir = install_dir or 'torch/csrc/autograd/generated'
+    jit_gen_dir = install_dir or 'torch/csrc/jit/generated'
     for d in (autograd_gen_dir, jit_gen_dir):
         if not os.path.exists(d):
-            os.mkdir(d)
-    gen_autograd(declarations_path or DECLARATIONS_PATH, autograd_gen_dir)
-    gen_jit_dispatch(declarations_path or DECLARATIONS_PATH, jit_gen_dir)
+            os.makedirs(d)
+    gen_autograd(declarations_path or DECLARATIONS_PATH, autograd_gen_dir, 'tools/autograd')
+    gen_jit_dispatch(declarations_path or DECLARATIONS_PATH, jit_gen_dir, 'tools/jit/templates')
 
 
 def main():
@@ -97,10 +98,12 @@ def main():
     parser.add_argument('--declarations-path')
     parser.add_argument('--nn-path')
     parser.add_argument('--ninja-global')
+    parser.add_argument('--install_dir')
     options = parser.parse_args()
     generate_code(options.ninja_global,
                   options.declarations_path,
-                  options.nn_path)
+                  options.nn_path,
+                  options.install_dir)
 
 
 if __name__ == "__main__":
@@ -2,8 +2,12 @@ import os
 import itertools
 import importlib
 
-THNN_H_PATH = os.path.join(os.path.dirname(__file__), '..', 'lib', 'THNN.h')
-THCUNN_H_PATH = os.path.join(os.path.dirname(__file__), '..', 'lib', 'THCUNN.h')
+# in fbcode, this fails in some cases, but we don't need it, therefore the try-catch
+try:
+    THNN_H_PATH = os.path.join(os.path.dirname(__file__), '..', 'lib', 'THNN.h')
+    THCUNN_H_PATH = os.path.join(os.path.dirname(__file__), '..', 'lib', 'THCUNN.h')
+except Exception:
+    pass
 
 
 def _unpickle_backend(backend_name):