mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
* Back out "Back out "Add support for generating ATen files during fbcode build"". Original commit changeset: 7b8de22d1613. I'm re-sending this diff exactly as it was approved and committed; fixes to support @mode/opt will be sent separately for ease of review.
* Enable building //caffe2:torch with @mode/opt. In @mode/opt, Python runs out of a PAR, which breaks many assumptions in the code about where the templates/ folders live relative to __file__. Rather than introduce hacks with parutil, I simply turn template_path into a parameter for all the relevant functions and thread it through from the top level.
157 lines
5.8 KiB
Python
157 lines
5.8 KiB
Python
import os
|
|
from string import Template
|
|
from . import CWrapPlugin
|
|
|
|
|
|
# C++ prologue that NNExtension.process_full_file places at the very start of
# the generated extension source: the Python, THP and NN type-check headers.
MODULE_HEAD = """
#include <Python.h>
#include <exception>

#include "THP.h"
#include "torch/csrc/utils/auto_gpu.h"
#include "torch/csrc/nn/type_checks.h"

"""

# One PyMethodDef row per generated wrapper; $name is used both as the
# Python-visible function name and as the C function symbol.
# NOTE(review): the CPython C-API docs state that METH_STATIC "may not be used
# for functions defined for modules". Confirm how nn_tail.cpp registers
# `module_methods` before relying on this flag combination.
REGISTER_METHOD_TEMPLATE = Template('  {"$name", (PyCFunction)$name, METH_STATIC | METH_VARARGS, NULL},\n')

# NULL-terminated PyMethodDef table; $METHODS receives the concatenated
# REGISTER_METHOD_TEMPLATE rows (see NNExtension.declare_module_methods).
MODULE_METHODS_TEMPLATE = Template("""
static PyMethodDef module_methods[] = {
$METHODS
  {NULL, NULL, 0, NULL}
};
""")
|
|
|
|
|
|
class NNExtension(CWrapPlugin):
    """cwrap plugin that assembles a Python extension module for NN bindings.

    Every declaration passed through ``process_wrapper`` is remembered so that
    ``process_full_file`` can append a ``module_methods`` table registering one
    generated wrapper per declaration, followed by the module tail read from
    ``nn_tail.cpp``.
    """

    # C expression converting the PyObject* ``$arg`` into the keyed C type.
    TYPE_UNPACK = {
        'THFloatTensor*': Template('THNN_FloatTensor_Unpack($arg)'),
        'THDoubleTensor*': Template('THNN_DoubleTensor_Unpack($arg)'),
        'THLongTensor*': Template('THNN_LongTensor_Unpack($arg)'),
        'THIntTensor*': Template('THNN_IntTensor_Unpack($arg)'),
        'THCudaHalfTensor*': Template('THNN_CudaHalfTensor_Unpack($arg)'),
        'THCudaTensor*': Template('THNN_CudaFloatTensor_Unpack($arg)'),
        'THCudaDoubleTensor*': Template('THNN_CudaDoubleTensor_Unpack($arg)'),
        'THCudaLongTensor*': Template('THNN_CudaLongTensor_Unpack($arg)'),
        'half': Template('THPHalfUtils_unpackReal($arg)'),
        'float': Template('THPFloatUtils_unpackReal($arg)'),
        'double': Template('THPDoubleUtils_unpackReal($arg)'),
        'bool': Template('($arg == Py_True ? true : false)'),
        'int': Template('THPUtils_unpackLong($arg)'),
        'long': Template('THPUtils_unpackLong($arg)'),
        'int64_t': Template('THPUtils_unpackLong($arg)'),
        'void*': Template('(void*)THPUtils_unpackLong($arg)'),
        'THGenerator*': Template('THPGenerator_TH_CData((THPGenerator*)$arg)'),
    }

    # C boolean expression testing whether ``$arg`` can be unpacked as the
    # keyed C type. The keys are the same as TYPE_UNPACK's.
    TYPE_CHECK = {
        'THFloatTensor*': Template('THNN_FloatTensor_Check($arg)'),
        'THDoubleTensor*': Template('THNN_DoubleTensor_Check($arg)'),
        'THLongTensor*': Template('THNN_LongTensor_Check($arg)'),
        'THIntTensor*': Template('THNN_IntTensor_Check($arg)'),
        'THCudaHalfTensor*': Template('THNN_CudaHalfTensor_Check($arg)'),
        'THCudaTensor*': Template('THNN_CudaFloatTensor_Check($arg)'),
        'THCudaDoubleTensor*': Template('THNN_CudaDoubleTensor_Check($arg)'),
        'THCudaLongTensor*': Template('THNN_CudaLongTensor_Check($arg)'),
        'half': Template('THPHalfUtils_checkReal($arg)'),
        'float': Template('THPFloatUtils_checkReal($arg)'),
        'double': Template('THPDoubleUtils_checkReal($arg)'),
        'bool': Template('PyBool_Check($arg)'),
        'int': Template('THPUtils_checkLong($arg)'),
        'long': Template('THPUtils_checkLong($arg)'),
        'int64_t': Template('THPUtils_checkLong($arg)'),
        'void*': Template('THPUtils_checkLong($arg)'),
        'THGenerator*': Template('(PyObject*)Py_TYPE($arg) == THPGeneratorClass'),
    }

    # Skeleton of every generated wrapper function. The `} else {` that follows
    # $options implies the $options expansion ends inside an open if-branch;
    # the fallback reports the accepted signatures via $expected_args.
    # get_wrapper_template fills $expected_args only; $name and $options are
    # left for a later substitution.
    WRAPPER_TEMPLATE = Template("""
PyObject * $name(PyObject *_unused, PyObject *args)
{
  HANDLE_TH_ERRORS
  int __argcount = args ? PyTuple_Size(args) : 0;
  $options
  } else {
    THPUtils_invalidArguments(args, NULL, "$name", 1, $expected_args);
    return NULL;
  }
  END_HANDLE_TH_ERRORS
}
    """)

    # Python-facing type names used to describe the accepted signatures in the
    # invalid-arguments error message built by get_wrapper_template.
    TYPE_NAMES = {
        'THGenerator*': 'Generator',
        'THCudaHalfTensor*': 'torch.cuda.HalfTensor',
        'THCudaTensor*': 'torch.cuda.FloatTensor',
        'THCudaDoubleTensor*': 'torch.cuda.DoubleTensor',
        'THCudaLongTensor*': 'torch.cuda.LongTensor',
        'THDoubleTensor*': 'torch.DoubleTensor',
        'THFloatTensor*': 'torch.FloatTensor',
        'THBoolTensor*': 'torch.ByteTensor',
        'THLongTensor*': 'torch.LongTensor',
        'THIndexTensor*': 'torch.LongTensor',
        'THIntTensor*': 'torch.IntTensor',
        'THLongStorage*': 'torch.LongStorage',
        'long': 'int',
        'int64_t': 'int',
        'int': 'int',
        'real': 'float',
        'half': 'float',
        'double': 'float',
        'float': 'float',
        'accreal': 'float',
        'bool': 'bool',
        'void*': 'int',
    }

    def __init__(self, module_name):
        """Create the plugin for the dotted extension module ``module_name``.

        ``self.declarations`` starts empty and is filled by process_wrapper.
        """
        self.module_name = module_name
        self.declarations = []

    def process_full_file(self, code, template_path):
        """Return the complete extension source built around ``code``.

        Args:
            code: the generated wrapper functions.
            template_path: directory that contains ``nn_tail.cpp``. It is
                passed in explicitly rather than derived from ``__file__`` so
                that generation also works when running from a PAR.

        Returns:
            MODULE_HEAD + ``code`` + the method table + the substituted tail.

        Raises:
            Whatever ``open`` raises if ``nn_tail.cpp`` cannot be read, and
            KeyError if the tail references placeholders other than
            ``$full_name`` / ``$short_name``.
        """
        with open(os.path.join(template_path, 'nn_tail.cpp'), 'r') as f:
            module_tail = Template(f.read())

        # Last dotted component, e.g. 'a.b.c' -> 'c'.
        short_name = self.module_name.split('.')[-1]
        return ''.join([
            MODULE_HEAD,
            code,
            self.declare_module_methods(),
            module_tail.substitute(full_name=self.module_name,
                                   short_name=short_name),
        ])

    def process_wrapper(self, code, declaration):
        """Record ``declaration`` for the method table; return ``code`` unchanged."""
        self.declarations.append(declaration)
        return code

    def declare_module_methods(self):
        """Return the ``module_methods`` PyMethodDef table for all recorded declarations."""
        module_methods = ''.join(
            REGISTER_METHOD_TEMPLATE.substitute(name=declaration['name'])
            for declaration in self.declarations)
        return MODULE_METHODS_TEMPLATE.substitute(METHODS=module_methods)

    def get_type_unpack(self, arg, option):
        """Return the unpack Template for ``arg['type']``, or None if unsupported.

        ``option`` is unused here but kept for the plugin hook signature.
        """
        return self.TYPE_UNPACK.get(arg['type'])

    def get_type_check(self, arg, option):
        """Return the type-check Template for ``arg['type']``, or None if unsupported.

        ``option`` is unused here but kept for the plugin hook signature.
        """
        return self.TYPE_CHECK.get(arg['type'])

    def get_wrapper_template(self, declaration):
        """Return WRAPPER_TEMPLATE with ``$expected_args`` filled in.

        ``$expected_args`` becomes a comma-separated list of C string literals.
        There is one literal per option in ``declaration['options']``, ordered
        shortest first. Arguments flagged ``ignore_check`` are omitted.
        Nullable arguments render as ``[<type> <name> or None]``. An option with
        no described arguments renders as ``no arguments``.

        Raises:
            KeyError: if an argument's type is missing from TYPE_NAMES.
        """
        def describe_arg(arg):
            desc = self.TYPE_NAMES[arg['type']] + ' ' + arg['name']
            if arg.get('nullable'):
                return '[{} or None]'.format(desc)
            return desc

        arg_desc = []
        for option in declaration['options']:
            option_desc = [describe_arg(arg)
                           for arg in option['arguments']
                           if not arg.get('ignore_check', False)]
            if option_desc:
                arg_desc.append('({})'.format(', '.join(option_desc)))
            else:
                arg_desc.append('no arguments')
        # Shortest signatures first. The sort is stable, so equal-length
        # entries keep their declaration order.
        arg_desc.sort(key=len)
        arg_str = ', '.join('"' + desc + '"' for desc in arg_desc)
        # safe_substitute leaves $name and $options intact for the caller.
        return Template(self.WRAPPER_TEMPLATE.safe_substitute(expected_args=arg_str))
|