python 2 support

This commit is contained in:
Soumith Chintala
2016-05-12 17:49:16 -04:00
parent 6954783d9d
commit 5ee3358a92
15 changed files with 196 additions and 63 deletions

1
.gitignore vendored
View File

@ -4,3 +4,4 @@ torch.egg-info/
*/**/__pycache__
torch/__init__.py
torch/csrc/generic/TensorMethods.cpp
*/**/*.pyc

View File

@ -1,5 +1,7 @@
from setuptools import setup, Extension
from os.path import expanduser
from tools.cwrap import cwrap
import platform
################################################################################
# Generate __init__.py from templates
@ -49,6 +51,16 @@ for src in cwrap_src:
################################################################################
# Declare the package
################################################################################
extra_link_args = []
# TODO: remove and properly submodule TH in the repo itself
th_path = expanduser("~/torch/install/")
th_header_path = th_path + "include"
th_lib_path = th_path + "lib"
if platform.system() == 'Darwin':
extra_link_args.append('-L' + th_lib_path)
extra_link_args.append('-Wl,-rpath,' + th_lib_path)
sources = [
"torch/csrc/Module.cpp",
"torch/csrc/Tensor.cpp",
@ -59,9 +71,13 @@ C = Extension("torch.C",
libraries=['TH'],
sources=sources,
language='c++',
include_dirs=["torch/csrc"])
include_dirs=(["torch/csrc", th_header_path]),
extra_link_args = extra_link_args,
)
setup(name="torch", version="0.1",
ext_modules=[C],
packages=['torch'])
packages=['torch'],
)

24
test/smoke.py Normal file
View File

@ -0,0 +1,24 @@
import torch
a = torch.FloatTensor(4, 3)
b = torch.FloatTensor(3, 4)
a.add(b)
c = a.storage()
d = a.select(0, 1)
print(c)
print(a)
print(b)
print(d)
a.fill(0)
print(a[1])
print(a.ge(long(0)))
print(a.ge(0))

0
tools/__init__.py Normal file
View File

View File

@ -151,7 +151,7 @@ RETURN_WRAPPER = {
'THStorage': Template('return THPStorage_(newObject)($expr)'),
'THLongStorage': Template('return THPLongStorage_newObject($expr)'),
'bool': Template('return PyBool_FromLong($expr)'),
'long': Template('return PyLong_FromLong($expr)'),
'long': Template('return PyInt_FromLong($expr)'),
'double': Template('return PyFloat_FromDouble($expr)'),
'self': Template('$expr; Py_INCREF(self); return (PyObject*)self'),
# TODO
@ -397,16 +397,19 @@ def argfilter():
CONSTANT arguments are literals.
Repeated arguments do not need to be specified twice.
"""
provided = set()
# use class rather than nonlocal to maintain 2.7 compat
# see http://stackoverflow.com/questions/3190706/nonlocal-keyword-in-python-2-x
# TODO: check this works
class context:
provided = set()
def is_already_provided(arg):
nonlocal provided
ret = False
ret |= arg.name == 'self'
ret |= arg.name == '_res_new'
ret |= arg.type == 'CONSTANT'
ret |= arg.type == 'EXPRESSION'
ret |= arg.name in provided
provided.add(arg.name)
ret |= arg.name in context.provided
context.provided.add(arg.name)
return ret
return is_already_provided

View File

@ -7,5 +7,5 @@ class RealStorage(RealStorageBase):
return str(self)
def __iter__(self):
return map(lambda i: self[i], range(self.size()))
return iter(map(lambda i: self[i], range(self.size())))

View File

@ -42,4 +42,4 @@ class RealTensor(RealTensorBase):
return _printing.printTensor(self)
def __iter__(self):
return map(lambda i: self.select(0, i), range(self.size(0)))
return iter(map(lambda i: self.select(0, i), range(self.size(0))))

View File

@ -5,7 +5,11 @@
#include "THP.h"
#if PY_MAJOR_VERSION == 2
#define ASSERT_TRUE(cmd) if (!(cmd)) {PyErr_SetString(PyExc_ImportError, "initialization error"); return;}
#else
#define ASSERT_TRUE(cmd) if (!(cmd)) return NULL
#endif
static PyObject* module;
static PyObject* tensor_classes;
@ -34,21 +38,21 @@ static bool THPModule_loadClasses(PyObject *self)
PyObject *torch_module = PyImport_ImportModule("torch");
PyObject* module_dict = PyModule_GetDict(torch_module);
THPDoubleStorageClass = PyMapping_GetItemString(module_dict, "DoubleStorage");
THPFloatStorageClass = PyMapping_GetItemString(module_dict, "FloatStorage");
THPLongStorageClass = PyMapping_GetItemString(module_dict, "LongStorage");
THPIntStorageClass = PyMapping_GetItemString(module_dict, "IntStorage");
THPShortStorageClass = PyMapping_GetItemString(module_dict, "ShortStorage");
THPCharStorageClass = PyMapping_GetItemString(module_dict, "CharStorage");
THPByteStorageClass = PyMapping_GetItemString(module_dict, "ByteStorage");
THPDoubleStorageClass = PyMapping_GetItemString(module_dict,(char*)"DoubleStorage");
THPFloatStorageClass = PyMapping_GetItemString(module_dict,(char*)"FloatStorage");
THPLongStorageClass = PyMapping_GetItemString(module_dict,(char*)"LongStorage");
THPIntStorageClass = PyMapping_GetItemString(module_dict,(char*)"IntStorage");
THPShortStorageClass = PyMapping_GetItemString(module_dict,(char*)"ShortStorage");
THPCharStorageClass = PyMapping_GetItemString(module_dict,(char*)"CharStorage");
THPByteStorageClass = PyMapping_GetItemString(module_dict,(char*)"ByteStorage");
THPDoubleTensorClass = PyMapping_GetItemString(module_dict, "DoubleTensor");
THPFloatTensorClass = PyMapping_GetItemString(module_dict, "FloatTensor");
THPLongTensorClass = PyMapping_GetItemString(module_dict, "LongTensor");
THPIntTensorClass = PyMapping_GetItemString(module_dict, "IntTensor");
THPShortTensorClass = PyMapping_GetItemString(module_dict, "ShortTensor");
THPCharTensorClass = PyMapping_GetItemString(module_dict, "CharTensor");
THPByteTensorClass = PyMapping_GetItemString(module_dict, "ByteTensor");
THPDoubleTensorClass = PyMapping_GetItemString(module_dict,(char*)"DoubleTensor");
THPFloatTensorClass = PyMapping_GetItemString(module_dict,(char*)"FloatTensor");
THPLongTensorClass = PyMapping_GetItemString(module_dict,(char*)"LongTensor");
THPIntTensorClass = PyMapping_GetItemString(module_dict,(char*)"IntTensor");
THPShortTensorClass = PyMapping_GetItemString(module_dict,(char*)"ShortTensor");
THPCharTensorClass = PyMapping_GetItemString(module_dict,(char*)"CharTensor");
THPByteTensorClass = PyMapping_GetItemString(module_dict,(char*)"ByteTensor");
PySet_Add(tensor_classes, THPDoubleTensorClass);
PySet_Add(tensor_classes, THPFloatTensorClass);
PySet_Add(tensor_classes, THPLongTensorClass);
@ -314,6 +318,7 @@ static PyMethodDef TorchMethods[] = {
{NULL, NULL, 0, NULL}
};
#if PY_MAJOR_VERSION != 2
static struct PyModuleDef torchmodule = {
PyModuleDef_HEAD_INIT,
"torch.C",
@ -321,6 +326,7 @@ static struct PyModuleDef torchmodule = {
-1,
TorchMethods
};
#endif
static void errorHandler(const char *msg, void *data)
{
@ -338,10 +344,17 @@ static void updateErrorHandlers()
THSetArgErrorHandler(errorHandlerArg, NULL);
}
#if PY_MAJOR_VERSION == 2
PyMODINIT_FUNC initC()
#else
PyMODINIT_FUNC PyInit_C()
#endif
{
#if PY_MAJOR_VERSION == 2
ASSERT_TRUE(module = Py_InitModule("torch.C", TorchMethods));
#else
ASSERT_TRUE(module = PyModule_Create(&torchmodule));
#endif
ASSERT_TRUE(tensor_classes = PySet_New(NULL));
ASSERT_TRUE(PyObject_SetAttrString(module, "_tensorclasses", tensor_classes) == 0);
@ -363,5 +376,10 @@ PyMODINIT_FUNC PyInit_C()
updateErrorHandlers();
#if PY_MAJOR_VERSION == 2
#else
return module;
#endif
}
#undef ASSERT_TRUE

View File

@ -1,6 +1,15 @@
#include <stdbool.h>
#include <TH/TH.h>
// Back-compatibility macros, Thanks to http://cx-oracle.sourceforge.net/
// define PyInt_* macros for Python 3.x
#ifndef PyInt_Check
#define PyInt_Check PyLong_Check
#define PyInt_FromLong PyLong_FromLong
#define PyInt_AsLong PyLong_AsLong
#define PyInt_Type PyLong_Type
#endif
#include "Exceptions.h"
#include "utils.h"

View File

@ -9,6 +9,7 @@ PyObject * THPStorage_(newObject)(THStorage *ptr)
// TODO: error checking
PyObject *args = PyTuple_New(0);
PyObject *kwargs = Py_BuildValue("{s:N}", "cdata", PyLong_FromVoidPtr(ptr));
PyObject *instance = PyObject_Call(THPStorageClass, args, kwargs);
Py_DECREF(args);
Py_DECREF(kwargs);
@ -30,17 +31,17 @@ static PyObject * THPStorage_(pynew)(PyTypeObject *type, PyObject *args, PyObjec
{
HANDLE_TH_ERRORS
static const char *keywords[] = {"cdata", NULL};
PyObject *number_arg = NULL;
if (!PyArg_ParseTupleAndKeywords(args, kwargs, "|O!", (char **)keywords, &PyLong_Type, &number_arg))
void* number_arg = NULL;
if (!PyArg_ParseTupleAndKeywords(args, kwargs, "|O&", (char **)keywords,
THPUtils_getLong, &number_arg))
return NULL;
THPStorage *self = (THPStorage *)type->tp_alloc(type, 0);
if (self != NULL) {
if (kwargs) {
self->cdata = (THStorage*)PyLong_AsVoidPtr(number_arg);
self->cdata = (THStorage*)number_arg;
THStorage_(retain)(self->cdata);
} else if (/* !kwargs && */ number_arg) {
self->cdata = THStorage_(newWithSize)(PyLong_AsLong(number_arg));
self->cdata = THStorage_(newWithSize)((long) number_arg);
} else {
self->cdata = THStorage_(new)();
}
@ -66,8 +67,9 @@ static PyObject * THPStorage_(get)(THPStorage *self, PyObject *index)
{
HANDLE_TH_ERRORS
/* Integer index */
if (PyLong_Check(index)) {
long nindex = PyLong_AsLong(index);
long nindex;
if ((PyLong_Check(index) || PyInt_Check(index))
&& THPUtils_getLong(index, &nindex) == 1 ) {
if (nindex < 0)
nindex += THStorage_(size)(self->cdata);
#if defined(TH_REAL_IS_FLOAT) || defined(TH_REAL_IS_DOUBLE)
@ -89,7 +91,11 @@ static PyObject * THPStorage_(get)(THPStorage *self, PyObject *index)
THStorage *new_storage = THStorage_(newWithData)(new_data, slicelength);
return THPStorage_(newObject)(new_storage);
}
PyErr_SetString(PyExc_RuntimeError, "Only indexing with integers and slices supported");
char err_string[512];
snprintf (err_string, 512,
"%s %s", "Only indexing with integers and slices supported, but got type: ",
index->ob_type->tp_name);
PyErr_SetString(PyExc_RuntimeError, err_string);
return NULL;
END_HANDLE_TH_ERRORS
}
@ -101,8 +107,10 @@ static int THPStorage_(set)(THPStorage *self, PyObject *index, PyObject *value)
if (!THPUtils_(parseReal)(value, &rvalue))
return -1;
if (PyLong_Check(index)) {
THStorage_(set)(self->cdata, PyLong_AsSize_t(index), rvalue);
long nindex;
if ((PyLong_Check(index) || PyInt_Check(index))
&& THPUtils_getLong(index, &nindex) == 1) {
THStorage_(set)(self->cdata, nindex, rvalue);
return 0;
} else if (PySlice_Check(index)) {
Py_ssize_t start, stop, len;
@ -114,7 +122,11 @@ static int THPStorage_(set)(THPStorage *self, PyObject *index, PyObject *value)
THStorage_(set)(self->cdata, start, rvalue);
return 0;
}
PyErr_SetString(PyExc_RuntimeError, "Only indexing with integers and slices supported at the moment");
char err_string[512];
snprintf (err_string, 512, "%s %s",
"Only indexing with integers and slices supported, but got type: ",
index->ob_type->tp_name);
PyErr_SetString(PyExc_RuntimeError, err_string);
return -1;
END_HANDLE_TH_ERRORS_RET(-1)
}

View File

@ -40,7 +40,7 @@ static PyObject * THPStorage_(resize)(THPStorage *self, PyObject *number_arg)
HANDLE_TH_ERRORS
if (!PyLong_Check(number_arg))
return NULL;
size_t newsize = PyLong_AsSize_t(number_arg);
long newsize = PyLong_AsLong(number_arg);
if (PyErr_Occurred())
return NULL;
THStorage_(resize)(self->cdata, newsize);

View File

@ -7,7 +7,7 @@ PyObject * THPTensor_(newObject)(THTensor *ptr)
{
// TODO: error checking
PyObject *args = PyTuple_New(0);
PyObject *kwargs = Py_BuildValue("{s:N}", "cdata", PyLong_FromVoidPtr(ptr));
PyObject *kwargs = Py_BuildValue("{s:K}", "cdata", (unsigned long long) ptr);
PyObject *instance = PyObject_Call(THPTensorClass, args, kwargs);
Py_DECREF(args);
Py_DECREF(kwargs);
@ -41,9 +41,11 @@ static PyObject * THPTensor_(pynew)(PyTypeObject *type, PyObject *args, PyObject
THTensor *cdata_ptr = NULL;
// If not, try to parse integers
#define ERRMSG ";Expected torch.LongStorage or up to 4 integers as arguments"
// TODO: check that cdata_ptr is a keyword arg
if (!storage_obj &&
!PyArg_ParseTupleAndKeywords(args, kwargs, "|LLLL$k" ERRMSG, (char**)keywords,
!PyArg_ParseTupleAndKeywords(args, kwargs, "|LLLLL" ERRMSG, (char**)keywords,
&sizes[0], &sizes[1], &sizes[2], &sizes[3], &cdata_ptr))
#undef ERRMSG
return NULL;
THPTensor *self = (THPTensor *)type->tp_alloc(type, 0);
@ -64,21 +66,24 @@ static PyObject * THPTensor_(pynew)(PyTypeObject *type, PyObject *args, PyObject
END_HANDLE_TH_ERRORS
}
#define INDEX_LONG(DIM, IDX_VARIABLE, TENSOR_VARIABLE, CASE_1D, CASE_MD) \
long idx = PyLong_AsLong(IDX_VARIABLE); \
long dimsize = THTensor_(size)(TENSOR_VARIABLE, DIM); \
idx = (idx < 0) ? dimsize + idx : idx; \
\
THArgCheck(dimsize > 0, 1, "empty tensor"); \
THArgCheck(idx >= 0 && idx < dimsize, 2, "out of range"); \
\
if(THTensor_(nDimension)(TENSOR_VARIABLE) == 1) { \
CASE_1D; \
} else { \
CASE_MD; \
}
#define GET_PTR_1D(t, idx) \
#define INDEX_LONG(DIM, IDX_VARIABLE, TENSOR_VARIABLE, CASE_1D, CASE_MD) \
long idx; \
THPUtils_getLong(IDX_VARIABLE, &idx); \
long dimsize = THTensor_(size)(TENSOR_VARIABLE, DIM); \
idx = (idx < 0) ? dimsize + idx : idx; \
\
THArgCheck(dimsize > 0, 1, "empty tensor"); \
THArgCheck(idx >= 0 && idx < dimsize, 2, "out of range"); \
\
if(THTensor_(nDimension)(TENSOR_VARIABLE) == 1) { \
CASE_1D; \
} else { \
CASE_MD; \
}
#define GET_PTR_1D(t, idx) \
t->storage->data + t->storageOffset + t->stride[0] * idx;
static bool THPTensor_(_index)(THPTensor *self, PyObject *index,
THTensor * &tresult, real * &rresult)
{
@ -86,7 +91,7 @@ static bool THPTensor_(_index)(THPTensor *self, PyObject *index,
rresult = NULL;
try {
// Indexing with an integer
if(PyLong_Check(index)) {
if(PyLong_Check(index) || PyInt_Check(index)) {
THTensor *self_t = self->cdata;
INDEX_LONG(0, index, self_t,
// 1D tensor
@ -97,8 +102,9 @@ static bool THPTensor_(_index)(THPTensor *self, PyObject *index,
)
// Indexing with a single element tuple
} else if (PyTuple_Check(index) &&
PyTuple_Size(index) == 1 &&
PyLong_Check(PyTuple_GET_ITEM(index, 0))) {
PyTuple_Size(index) == 1 &&
(PyLong_Check(PyTuple_GET_ITEM(index, 0))
|| PyInt_Check(PyTuple_GET_ITEM(index, 0)))) {
PyObject *index_obj = PyTuple_GET_ITEM(index, 0);
tresult = THTensor_(newWithTensor)(self->cdata);
INDEX_LONG(0, index_obj, tresult,
@ -121,7 +127,7 @@ static bool THPTensor_(_index)(THPTensor *self, PyObject *index,
for(int dim = 0; dim < PyTuple_Size(index); dim++) {
PyObject *dimidx = PyTuple_GET_ITEM(index, dim);
if(PyLong_Check(dimidx)) {
if(PyLong_Check(dimidx) || PyInt_Check(dimidx)) {
INDEX_LONG(t_dim, dimidx, tresult,
// 1D tensor
rresult = GET_PTR_1D(tresult, idx);
@ -132,7 +138,9 @@ static bool THPTensor_(_index)(THPTensor *self, PyObject *index,
THTensor_(select)(tresult, NULL, t_dim, idx)
)
} else if (PyTuple_Check(dimidx)) {
if (PyTuple_Size(dimidx) != 1 || !PyLong_Check(PyTuple_GET_ITEM(dimidx, 0))) {
if (PyTuple_Size(dimidx) != 1
|| !(PyLong_Check(PyTuple_GET_ITEM(dimidx, 0))
|| PyInt_Check(PyTuple_GET_ITEM(dimidx, 0)))) {
PyErr_SetString(PyExc_RuntimeError, "Expected a single integer");
return false;
}
@ -177,7 +185,11 @@ static PyObject * THPTensor_(get)(THPTensor *self, PyObject *index)
return THPTensor_(newObject)(tresult);
if (rresult)
return THPUtils_(newReal)(*rresult);
PyErr_SetString(PyExc_RuntimeError, "Unknown exception");
char err_string[512];
snprintf (err_string, 512,
"%s %s", "Unknown exception in THPTensor_(get). Index type is: ",
index->ob_type->tp_name);
PyErr_SetString(PyExc_RuntimeError, err_string);
return NULL;
}
END_HANDLE_TH_ERRORS
@ -312,7 +324,7 @@ PyTypeObject THPTensorStatelessType = {
0, /* tp_print */
0, /* tp_getattr */
0, /* tp_setattr */
0, /* tp_reserved */
0, /* tp_reserved / tp_compare */
0, /* tp_repr */
0, /* tp_as_number */
0, /* tp_as_sequence */
@ -342,6 +354,13 @@ PyTypeObject THPTensorStatelessType = {
0, /* tp_init */
0, /* tp_alloc */
0, /* tp_new */
0, /* tp_free */
0, /* tp_is_gc */
0, /* tp_bases */
0, /* tp_mro */
0, /* tp_cache */
0, /* tp_subclasses */
0, /* tp_weaklist */
};
bool THPTensor_(init)(PyObject *module)
@ -353,6 +372,7 @@ bool THPTensor_(init)(PyObject *module)
THPTensorStatelessType.tp_new = PyType_GenericNew;
if (PyType_Ready(&THPTensorStatelessType) < 0)
return false;
PyModule_AddObject(module, THPTensorBaseStr, (PyObject *)&THPTensorType);
return true;
}

View File

@ -5,7 +5,14 @@
bool THPUtils_(parseSlice)(PyObject *slice, Py_ssize_t len, Py_ssize_t *ostart, Py_ssize_t *ostop, Py_ssize_t *oslicelength)
{
Py_ssize_t start, stop, step, slicelength;
if (PySlice_GetIndicesEx(slice, len, &start, &stop, &step, &slicelength) < 0) {
if (PySlice_GetIndicesEx(
// https://bugsfiles.kde.org/attachment.cgi?id=61186
#if PY_VERSION_HEX >= 0x03020000
slice,
#else
(PySliceObject *)slice,
#endif
len, &start, &stop, &step, &slicelength) < 0) {
PyErr_SetString(PyExc_RuntimeError, "Got an invalid slice");
return false;
}
@ -24,11 +31,16 @@ bool THPUtils_(parseReal)(PyObject *value, real *result)
{
if (PyLong_Check(value)) {
*result = (real)PyLong_AsLongLong(value);
} else if (PyInt_Check(value)) {
*result = (real)PyInt_AsLong(value);
} else if (PyFloat_Check(value)) {
*result = (real)PyFloat_AsDouble(value);
} else {
// TODO: meaningful error
PyErr_SetString(PyExc_RuntimeError, "Unrecognized object");
char err_string[512];
snprintf (err_string, 512, "%s %s",
"parseReal expected long or float, but got type: ",
value->ob_type->tp_name);
PyErr_SetString(PyExc_RuntimeError, err_string);
return false;
}
return true;
@ -36,7 +48,7 @@ bool THPUtils_(parseReal)(PyObject *value, real *result)
bool THPUtils_(checkReal)(PyObject *value)
{
return PyFloat_Check(value) || PyLong_Check(value);
return PyFloat_Check(value) || PyLong_Check(value) || PyInt_Check(value);
}
PyObject * THPUtils_(newReal)(real value)

View File

@ -3,3 +3,19 @@
#include "generic/utils.cpp"
#include <TH/THGenerateAllTypes.h>
int THPUtils_getLong(PyObject *index, long *result) {
if (PyLong_Check(index)) {
*result = PyLong_AsLong(index);
} else if (PyInt_Check(index)) {
*result = PyInt_AsLong(index);
} else {
char err_string[512];
snprintf (err_string, 512, "%s %s",
"getLong expected int or long, but got type: ",
index->ob_type->tp_name);
PyErr_SetString(PyExc_RuntimeError, err_string);
return 0;
}
return 1;
}

View File

@ -6,5 +6,7 @@
#include "generic/utils.h"
#include <TH/THGenerateAllTypes.h>
int THPUtils_getLong(PyObject *index, long *result);
#endif