python 2 support

This commit is contained in:
Soumith Chintala
2016-05-12 17:49:16 -04:00
parent 6954783d9d
commit 5ee3358a92
15 changed files with 196 additions and 63 deletions

1
.gitignore vendored
View File

@ -4,3 +4,4 @@ torch.egg-info/
*/**/__pycache__ */**/__pycache__
torch/__init__.py torch/__init__.py
torch/csrc/generic/TensorMethods.cpp torch/csrc/generic/TensorMethods.cpp
*/**/*.pyc

View File

@ -1,5 +1,7 @@
from setuptools import setup, Extension from setuptools import setup, Extension
from os.path import expanduser
from tools.cwrap import cwrap from tools.cwrap import cwrap
import platform
################################################################################ ################################################################################
# Generate __init__.py from templates # Generate __init__.py from templates
@ -49,6 +51,16 @@ for src in cwrap_src:
################################################################################ ################################################################################
# Declare the package # Declare the package
################################################################################ ################################################################################
extra_link_args = []
# TODO: remove and properly submodule TH in the repo itself
th_path = expanduser("~/torch/install/")
th_header_path = th_path + "include"
th_lib_path = th_path + "lib"
if platform.system() == 'Darwin':
extra_link_args.append('-L' + th_lib_path)
extra_link_args.append('-Wl,-rpath,' + th_lib_path)
sources = [ sources = [
"torch/csrc/Module.cpp", "torch/csrc/Module.cpp",
"torch/csrc/Tensor.cpp", "torch/csrc/Tensor.cpp",
@ -59,9 +71,13 @@ C = Extension("torch.C",
libraries=['TH'], libraries=['TH'],
sources=sources, sources=sources,
language='c++', language='c++',
include_dirs=["torch/csrc"]) include_dirs=(["torch/csrc", th_header_path]),
extra_link_args = extra_link_args,
)
setup(name="torch", version="0.1", setup(name="torch", version="0.1",
ext_modules=[C], ext_modules=[C],
packages=['torch']) packages=['torch'],
)

24
test/smoke.py Normal file
View File

@ -0,0 +1,24 @@
import torch
a = torch.FloatTensor(4, 3)
b = torch.FloatTensor(3, 4)
a.add(b)
c = a.storage()
d = a.select(0, 1)
print(c)
print(a)
print(b)
print(d)
a.fill(0)
print(a[1])
print(a.ge(long(0)))
print(a.ge(0))

0
tools/__init__.py Normal file
View File

View File

@ -151,7 +151,7 @@ RETURN_WRAPPER = {
'THStorage': Template('return THPStorage_(newObject)($expr)'), 'THStorage': Template('return THPStorage_(newObject)($expr)'),
'THLongStorage': Template('return THPLongStorage_newObject($expr)'), 'THLongStorage': Template('return THPLongStorage_newObject($expr)'),
'bool': Template('return PyBool_FromLong($expr)'), 'bool': Template('return PyBool_FromLong($expr)'),
'long': Template('return PyLong_FromLong($expr)'), 'long': Template('return PyInt_FromLong($expr)'),
'double': Template('return PyFloat_FromDouble($expr)'), 'double': Template('return PyFloat_FromDouble($expr)'),
'self': Template('$expr; Py_INCREF(self); return (PyObject*)self'), 'self': Template('$expr; Py_INCREF(self); return (PyObject*)self'),
# TODO # TODO
@ -397,16 +397,19 @@ def argfilter():
CONSTANT arguments are literals. CONSTANT arguments are literals.
Repeated arguments do not need to be specified twice. Repeated arguments do not need to be specified twice.
""" """
provided = set() # use class rather than nonlocal to maintain 2.7 compat
# see http://stackoverflow.com/questions/3190706/nonlocal-keyword-in-python-2-x
# TODO: check this works
class context:
provided = set()
def is_already_provided(arg): def is_already_provided(arg):
nonlocal provided
ret = False ret = False
ret |= arg.name == 'self' ret |= arg.name == 'self'
ret |= arg.name == '_res_new' ret |= arg.name == '_res_new'
ret |= arg.type == 'CONSTANT' ret |= arg.type == 'CONSTANT'
ret |= arg.type == 'EXPRESSION' ret |= arg.type == 'EXPRESSION'
ret |= arg.name in provided ret |= arg.name in context.provided
provided.add(arg.name) context.provided.add(arg.name)
return ret return ret
return is_already_provided return is_already_provided

View File

@ -7,5 +7,5 @@ class RealStorage(RealStorageBase):
return str(self) return str(self)
def __iter__(self): def __iter__(self):
return map(lambda i: self[i], range(self.size())) return iter(map(lambda i: self[i], range(self.size())))

View File

@ -42,4 +42,4 @@ class RealTensor(RealTensorBase):
return _printing.printTensor(self) return _printing.printTensor(self)
def __iter__(self): def __iter__(self):
return map(lambda i: self.select(0, i), range(self.size(0))) return iter(map(lambda i: self.select(0, i), range(self.size(0))))

View File

@ -5,7 +5,11 @@
#include "THP.h" #include "THP.h"
#if PY_MAJOR_VERSION == 2
#define ASSERT_TRUE(cmd) if (!(cmd)) {PyErr_SetString(PyExc_ImportError, "initialization error"); return;}
#else
#define ASSERT_TRUE(cmd) if (!(cmd)) return NULL #define ASSERT_TRUE(cmd) if (!(cmd)) return NULL
#endif
static PyObject* module; static PyObject* module;
static PyObject* tensor_classes; static PyObject* tensor_classes;
@ -34,21 +38,21 @@ static bool THPModule_loadClasses(PyObject *self)
PyObject *torch_module = PyImport_ImportModule("torch"); PyObject *torch_module = PyImport_ImportModule("torch");
PyObject* module_dict = PyModule_GetDict(torch_module); PyObject* module_dict = PyModule_GetDict(torch_module);
THPDoubleStorageClass = PyMapping_GetItemString(module_dict, "DoubleStorage"); THPDoubleStorageClass = PyMapping_GetItemString(module_dict,(char*)"DoubleStorage");
THPFloatStorageClass = PyMapping_GetItemString(module_dict, "FloatStorage"); THPFloatStorageClass = PyMapping_GetItemString(module_dict,(char*)"FloatStorage");
THPLongStorageClass = PyMapping_GetItemString(module_dict, "LongStorage"); THPLongStorageClass = PyMapping_GetItemString(module_dict,(char*)"LongStorage");
THPIntStorageClass = PyMapping_GetItemString(module_dict, "IntStorage"); THPIntStorageClass = PyMapping_GetItemString(module_dict,(char*)"IntStorage");
THPShortStorageClass = PyMapping_GetItemString(module_dict, "ShortStorage"); THPShortStorageClass = PyMapping_GetItemString(module_dict,(char*)"ShortStorage");
THPCharStorageClass = PyMapping_GetItemString(module_dict, "CharStorage"); THPCharStorageClass = PyMapping_GetItemString(module_dict,(char*)"CharStorage");
THPByteStorageClass = PyMapping_GetItemString(module_dict, "ByteStorage"); THPByteStorageClass = PyMapping_GetItemString(module_dict,(char*)"ByteStorage");
THPDoubleTensorClass = PyMapping_GetItemString(module_dict, "DoubleTensor"); THPDoubleTensorClass = PyMapping_GetItemString(module_dict,(char*)"DoubleTensor");
THPFloatTensorClass = PyMapping_GetItemString(module_dict, "FloatTensor"); THPFloatTensorClass = PyMapping_GetItemString(module_dict,(char*)"FloatTensor");
THPLongTensorClass = PyMapping_GetItemString(module_dict, "LongTensor"); THPLongTensorClass = PyMapping_GetItemString(module_dict,(char*)"LongTensor");
THPIntTensorClass = PyMapping_GetItemString(module_dict, "IntTensor"); THPIntTensorClass = PyMapping_GetItemString(module_dict,(char*)"IntTensor");
THPShortTensorClass = PyMapping_GetItemString(module_dict, "ShortTensor"); THPShortTensorClass = PyMapping_GetItemString(module_dict,(char*)"ShortTensor");
THPCharTensorClass = PyMapping_GetItemString(module_dict, "CharTensor"); THPCharTensorClass = PyMapping_GetItemString(module_dict,(char*)"CharTensor");
THPByteTensorClass = PyMapping_GetItemString(module_dict, "ByteTensor"); THPByteTensorClass = PyMapping_GetItemString(module_dict,(char*)"ByteTensor");
PySet_Add(tensor_classes, THPDoubleTensorClass); PySet_Add(tensor_classes, THPDoubleTensorClass);
PySet_Add(tensor_classes, THPFloatTensorClass); PySet_Add(tensor_classes, THPFloatTensorClass);
PySet_Add(tensor_classes, THPLongTensorClass); PySet_Add(tensor_classes, THPLongTensorClass);
@ -314,6 +318,7 @@ static PyMethodDef TorchMethods[] = {
{NULL, NULL, 0, NULL} {NULL, NULL, 0, NULL}
}; };
#if PY_MAJOR_VERSION != 2
static struct PyModuleDef torchmodule = { static struct PyModuleDef torchmodule = {
PyModuleDef_HEAD_INIT, PyModuleDef_HEAD_INIT,
"torch.C", "torch.C",
@ -321,6 +326,7 @@ static struct PyModuleDef torchmodule = {
-1, -1,
TorchMethods TorchMethods
}; };
#endif
static void errorHandler(const char *msg, void *data) static void errorHandler(const char *msg, void *data)
{ {
@ -338,10 +344,17 @@ static void updateErrorHandlers()
THSetArgErrorHandler(errorHandlerArg, NULL); THSetArgErrorHandler(errorHandlerArg, NULL);
} }
#if PY_MAJOR_VERSION == 2
PyMODINIT_FUNC initC()
#else
PyMODINIT_FUNC PyInit_C() PyMODINIT_FUNC PyInit_C()
#endif
{ {
#if PY_MAJOR_VERSION == 2
ASSERT_TRUE(module = Py_InitModule("torch.C", TorchMethods));
#else
ASSERT_TRUE(module = PyModule_Create(&torchmodule)); ASSERT_TRUE(module = PyModule_Create(&torchmodule));
#endif
ASSERT_TRUE(tensor_classes = PySet_New(NULL)); ASSERT_TRUE(tensor_classes = PySet_New(NULL));
ASSERT_TRUE(PyObject_SetAttrString(module, "_tensorclasses", tensor_classes) == 0); ASSERT_TRUE(PyObject_SetAttrString(module, "_tensorclasses", tensor_classes) == 0);
@ -363,5 +376,10 @@ PyMODINIT_FUNC PyInit_C()
updateErrorHandlers(); updateErrorHandlers();
#if PY_MAJOR_VERSION == 2
#else
return module; return module;
#endif
} }
#undef ASSERT_TRUE

View File

@ -1,6 +1,15 @@
#include <stdbool.h> #include <stdbool.h>
#include <TH/TH.h> #include <TH/TH.h>
// Back-compatibility macros, Thanks to http://cx-oracle.sourceforge.net/
// define PyInt_* macros for Python 3.x
#ifndef PyInt_Check
#define PyInt_Check PyLong_Check
#define PyInt_FromLong PyLong_FromLong
#define PyInt_AsLong PyLong_AsLong
#define PyInt_Type PyLong_Type
#endif
#include "Exceptions.h" #include "Exceptions.h"
#include "utils.h" #include "utils.h"

View File

@ -9,6 +9,7 @@ PyObject * THPStorage_(newObject)(THStorage *ptr)
// TODO: error checking // TODO: error checking
PyObject *args = PyTuple_New(0); PyObject *args = PyTuple_New(0);
PyObject *kwargs = Py_BuildValue("{s:N}", "cdata", PyLong_FromVoidPtr(ptr)); PyObject *kwargs = Py_BuildValue("{s:N}", "cdata", PyLong_FromVoidPtr(ptr));
PyObject *instance = PyObject_Call(THPStorageClass, args, kwargs); PyObject *instance = PyObject_Call(THPStorageClass, args, kwargs);
Py_DECREF(args); Py_DECREF(args);
Py_DECREF(kwargs); Py_DECREF(kwargs);
@ -30,17 +31,17 @@ static PyObject * THPStorage_(pynew)(PyTypeObject *type, PyObject *args, PyObjec
{ {
HANDLE_TH_ERRORS HANDLE_TH_ERRORS
static const char *keywords[] = {"cdata", NULL}; static const char *keywords[] = {"cdata", NULL};
PyObject *number_arg = NULL; void* number_arg = NULL;
if (!PyArg_ParseTupleAndKeywords(args, kwargs, "|O!", (char **)keywords, &PyLong_Type, &number_arg)) if (!PyArg_ParseTupleAndKeywords(args, kwargs, "|O&", (char **)keywords,
THPUtils_getLong, &number_arg))
return NULL; return NULL;
THPStorage *self = (THPStorage *)type->tp_alloc(type, 0); THPStorage *self = (THPStorage *)type->tp_alloc(type, 0);
if (self != NULL) { if (self != NULL) {
if (kwargs) { if (kwargs) {
self->cdata = (THStorage*)PyLong_AsVoidPtr(number_arg); self->cdata = (THStorage*)number_arg;
THStorage_(retain)(self->cdata); THStorage_(retain)(self->cdata);
} else if (/* !kwargs && */ number_arg) { } else if (/* !kwargs && */ number_arg) {
self->cdata = THStorage_(newWithSize)(PyLong_AsLong(number_arg)); self->cdata = THStorage_(newWithSize)((long) number_arg);
} else { } else {
self->cdata = THStorage_(new)(); self->cdata = THStorage_(new)();
} }
@ -66,8 +67,9 @@ static PyObject * THPStorage_(get)(THPStorage *self, PyObject *index)
{ {
HANDLE_TH_ERRORS HANDLE_TH_ERRORS
/* Integer index */ /* Integer index */
if (PyLong_Check(index)) { long nindex;
long nindex = PyLong_AsLong(index); if ((PyLong_Check(index) || PyInt_Check(index))
&& THPUtils_getLong(index, &nindex) == 1 ) {
if (nindex < 0) if (nindex < 0)
nindex += THStorage_(size)(self->cdata); nindex += THStorage_(size)(self->cdata);
#if defined(TH_REAL_IS_FLOAT) || defined(TH_REAL_IS_DOUBLE) #if defined(TH_REAL_IS_FLOAT) || defined(TH_REAL_IS_DOUBLE)
@ -89,7 +91,11 @@ static PyObject * THPStorage_(get)(THPStorage *self, PyObject *index)
THStorage *new_storage = THStorage_(newWithData)(new_data, slicelength); THStorage *new_storage = THStorage_(newWithData)(new_data, slicelength);
return THPStorage_(newObject)(new_storage); return THPStorage_(newObject)(new_storage);
} }
PyErr_SetString(PyExc_RuntimeError, "Only indexing with integers and slices supported"); char err_string[512];
snprintf (err_string, 512,
"%s %s", "Only indexing with integers and slices supported, but got type: ",
index->ob_type->tp_name);
PyErr_SetString(PyExc_RuntimeError, err_string);
return NULL; return NULL;
END_HANDLE_TH_ERRORS END_HANDLE_TH_ERRORS
} }
@ -101,8 +107,10 @@ static int THPStorage_(set)(THPStorage *self, PyObject *index, PyObject *value)
if (!THPUtils_(parseReal)(value, &rvalue)) if (!THPUtils_(parseReal)(value, &rvalue))
return -1; return -1;
if (PyLong_Check(index)) { long nindex;
THStorage_(set)(self->cdata, PyLong_AsSize_t(index), rvalue); if ((PyLong_Check(index) || PyInt_Check(index))
&& THPUtils_getLong(index, &nindex) == 1) {
THStorage_(set)(self->cdata, nindex, rvalue);
return 0; return 0;
} else if (PySlice_Check(index)) { } else if (PySlice_Check(index)) {
Py_ssize_t start, stop, len; Py_ssize_t start, stop, len;
@ -114,7 +122,11 @@ static int THPStorage_(set)(THPStorage *self, PyObject *index, PyObject *value)
THStorage_(set)(self->cdata, start, rvalue); THStorage_(set)(self->cdata, start, rvalue);
return 0; return 0;
} }
PyErr_SetString(PyExc_RuntimeError, "Only indexing with integers and slices supported at the moment"); char err_string[512];
snprintf (err_string, 512, "%s %s",
"Only indexing with integers and slices supported, but got type: ",
index->ob_type->tp_name);
PyErr_SetString(PyExc_RuntimeError, err_string);
return -1; return -1;
END_HANDLE_TH_ERRORS_RET(-1) END_HANDLE_TH_ERRORS_RET(-1)
} }

View File

@ -40,7 +40,7 @@ static PyObject * THPStorage_(resize)(THPStorage *self, PyObject *number_arg)
HANDLE_TH_ERRORS HANDLE_TH_ERRORS
if (!PyLong_Check(number_arg)) if (!PyLong_Check(number_arg))
return NULL; return NULL;
size_t newsize = PyLong_AsSize_t(number_arg); long newsize = PyLong_AsLong(number_arg);
if (PyErr_Occurred()) if (PyErr_Occurred())
return NULL; return NULL;
THStorage_(resize)(self->cdata, newsize); THStorage_(resize)(self->cdata, newsize);

View File

@ -7,7 +7,7 @@ PyObject * THPTensor_(newObject)(THTensor *ptr)
{ {
// TODO: error checking // TODO: error checking
PyObject *args = PyTuple_New(0); PyObject *args = PyTuple_New(0);
PyObject *kwargs = Py_BuildValue("{s:N}", "cdata", PyLong_FromVoidPtr(ptr)); PyObject *kwargs = Py_BuildValue("{s:K}", "cdata", (unsigned long long) ptr);
PyObject *instance = PyObject_Call(THPTensorClass, args, kwargs); PyObject *instance = PyObject_Call(THPTensorClass, args, kwargs);
Py_DECREF(args); Py_DECREF(args);
Py_DECREF(kwargs); Py_DECREF(kwargs);
@ -41,9 +41,11 @@ static PyObject * THPTensor_(pynew)(PyTypeObject *type, PyObject *args, PyObject
THTensor *cdata_ptr = NULL; THTensor *cdata_ptr = NULL;
// If not, try to parse integers // If not, try to parse integers
#define ERRMSG ";Expected torch.LongStorage or up to 4 integers as arguments" #define ERRMSG ";Expected torch.LongStorage or up to 4 integers as arguments"
// TODO: check that cdata_ptr is a keyword arg
if (!storage_obj && if (!storage_obj &&
!PyArg_ParseTupleAndKeywords(args, kwargs, "|LLLL$k" ERRMSG, (char**)keywords, !PyArg_ParseTupleAndKeywords(args, kwargs, "|LLLLL" ERRMSG, (char**)keywords,
&sizes[0], &sizes[1], &sizes[2], &sizes[3], &cdata_ptr)) &sizes[0], &sizes[1], &sizes[2], &sizes[3], &cdata_ptr))
#undef ERRMSG
return NULL; return NULL;
THPTensor *self = (THPTensor *)type->tp_alloc(type, 0); THPTensor *self = (THPTensor *)type->tp_alloc(type, 0);
@ -64,21 +66,24 @@ static PyObject * THPTensor_(pynew)(PyTypeObject *type, PyObject *args, PyObject
END_HANDLE_TH_ERRORS END_HANDLE_TH_ERRORS
} }
#define INDEX_LONG(DIM, IDX_VARIABLE, TENSOR_VARIABLE, CASE_1D, CASE_MD) \ #define INDEX_LONG(DIM, IDX_VARIABLE, TENSOR_VARIABLE, CASE_1D, CASE_MD) \
long idx = PyLong_AsLong(IDX_VARIABLE); \ long idx; \
long dimsize = THTensor_(size)(TENSOR_VARIABLE, DIM); \ THPUtils_getLong(IDX_VARIABLE, &idx); \
idx = (idx < 0) ? dimsize + idx : idx; \ long dimsize = THTensor_(size)(TENSOR_VARIABLE, DIM); \
\ idx = (idx < 0) ? dimsize + idx : idx; \
THArgCheck(dimsize > 0, 1, "empty tensor"); \ \
THArgCheck(idx >= 0 && idx < dimsize, 2, "out of range"); \ THArgCheck(dimsize > 0, 1, "empty tensor"); \
\ THArgCheck(idx >= 0 && idx < dimsize, 2, "out of range"); \
if(THTensor_(nDimension)(TENSOR_VARIABLE) == 1) { \ \
CASE_1D; \ if(THTensor_(nDimension)(TENSOR_VARIABLE) == 1) { \
} else { \ CASE_1D; \
CASE_MD; \ } else { \
} CASE_MD; \
#define GET_PTR_1D(t, idx) \ }
#define GET_PTR_1D(t, idx) \
t->storage->data + t->storageOffset + t->stride[0] * idx; t->storage->data + t->storageOffset + t->stride[0] * idx;
static bool THPTensor_(_index)(THPTensor *self, PyObject *index, static bool THPTensor_(_index)(THPTensor *self, PyObject *index,
THTensor * &tresult, real * &rresult) THTensor * &tresult, real * &rresult)
{ {
@ -86,7 +91,7 @@ static bool THPTensor_(_index)(THPTensor *self, PyObject *index,
rresult = NULL; rresult = NULL;
try { try {
// Indexing with an integer // Indexing with an integer
if(PyLong_Check(index)) { if(PyLong_Check(index) || PyInt_Check(index)) {
THTensor *self_t = self->cdata; THTensor *self_t = self->cdata;
INDEX_LONG(0, index, self_t, INDEX_LONG(0, index, self_t,
// 1D tensor // 1D tensor
@ -97,8 +102,9 @@ static bool THPTensor_(_index)(THPTensor *self, PyObject *index,
) )
// Indexing with a single element tuple // Indexing with a single element tuple
} else if (PyTuple_Check(index) && } else if (PyTuple_Check(index) &&
PyTuple_Size(index) == 1 && PyTuple_Size(index) == 1 &&
PyLong_Check(PyTuple_GET_ITEM(index, 0))) { (PyLong_Check(PyTuple_GET_ITEM(index, 0))
|| PyInt_Check(PyTuple_GET_ITEM(index, 0)))) {
PyObject *index_obj = PyTuple_GET_ITEM(index, 0); PyObject *index_obj = PyTuple_GET_ITEM(index, 0);
tresult = THTensor_(newWithTensor)(self->cdata); tresult = THTensor_(newWithTensor)(self->cdata);
INDEX_LONG(0, index_obj, tresult, INDEX_LONG(0, index_obj, tresult,
@ -121,7 +127,7 @@ static bool THPTensor_(_index)(THPTensor *self, PyObject *index,
for(int dim = 0; dim < PyTuple_Size(index); dim++) { for(int dim = 0; dim < PyTuple_Size(index); dim++) {
PyObject *dimidx = PyTuple_GET_ITEM(index, dim); PyObject *dimidx = PyTuple_GET_ITEM(index, dim);
if(PyLong_Check(dimidx)) { if(PyLong_Check(dimidx) || PyInt_Check(dimidx)) {
INDEX_LONG(t_dim, dimidx, tresult, INDEX_LONG(t_dim, dimidx, tresult,
// 1D tensor // 1D tensor
rresult = GET_PTR_1D(tresult, idx); rresult = GET_PTR_1D(tresult, idx);
@ -132,7 +138,9 @@ static bool THPTensor_(_index)(THPTensor *self, PyObject *index,
THTensor_(select)(tresult, NULL, t_dim, idx) THTensor_(select)(tresult, NULL, t_dim, idx)
) )
} else if (PyTuple_Check(dimidx)) { } else if (PyTuple_Check(dimidx)) {
if (PyTuple_Size(dimidx) != 1 || !PyLong_Check(PyTuple_GET_ITEM(dimidx, 0))) { if (PyTuple_Size(dimidx) != 1
|| !(PyLong_Check(PyTuple_GET_ITEM(dimidx, 0))
|| PyInt_Check(PyTuple_GET_ITEM(dimidx, 0)))) {
PyErr_SetString(PyExc_RuntimeError, "Expected a single integer"); PyErr_SetString(PyExc_RuntimeError, "Expected a single integer");
return false; return false;
} }
@ -177,7 +185,11 @@ static PyObject * THPTensor_(get)(THPTensor *self, PyObject *index)
return THPTensor_(newObject)(tresult); return THPTensor_(newObject)(tresult);
if (rresult) if (rresult)
return THPUtils_(newReal)(*rresult); return THPUtils_(newReal)(*rresult);
PyErr_SetString(PyExc_RuntimeError, "Unknown exception"); char err_string[512];
snprintf (err_string, 512,
"%s %s", "Unknown exception in THPTensor_(get). Index type is: ",
index->ob_type->tp_name);
PyErr_SetString(PyExc_RuntimeError, err_string);
return NULL; return NULL;
} }
END_HANDLE_TH_ERRORS END_HANDLE_TH_ERRORS
@ -312,7 +324,7 @@ PyTypeObject THPTensorStatelessType = {
0, /* tp_print */ 0, /* tp_print */
0, /* tp_getattr */ 0, /* tp_getattr */
0, /* tp_setattr */ 0, /* tp_setattr */
0, /* tp_reserved */ 0, /* tp_reserved / tp_compare */
0, /* tp_repr */ 0, /* tp_repr */
0, /* tp_as_number */ 0, /* tp_as_number */
0, /* tp_as_sequence */ 0, /* tp_as_sequence */
@ -342,6 +354,13 @@ PyTypeObject THPTensorStatelessType = {
0, /* tp_init */ 0, /* tp_init */
0, /* tp_alloc */ 0, /* tp_alloc */
0, /* tp_new */ 0, /* tp_new */
0, /* tp_free */
0, /* tp_is_gc */
0, /* tp_bases */
0, /* tp_mro */
0, /* tp_cache */
0, /* tp_subclasses */
0, /* tp_weaklist */
}; };
bool THPTensor_(init)(PyObject *module) bool THPTensor_(init)(PyObject *module)
@ -353,6 +372,7 @@ bool THPTensor_(init)(PyObject *module)
THPTensorStatelessType.tp_new = PyType_GenericNew; THPTensorStatelessType.tp_new = PyType_GenericNew;
if (PyType_Ready(&THPTensorStatelessType) < 0) if (PyType_Ready(&THPTensorStatelessType) < 0)
return false; return false;
PyModule_AddObject(module, THPTensorBaseStr, (PyObject *)&THPTensorType); PyModule_AddObject(module, THPTensorBaseStr, (PyObject *)&THPTensorType);
return true; return true;
} }

View File

@ -5,7 +5,14 @@
bool THPUtils_(parseSlice)(PyObject *slice, Py_ssize_t len, Py_ssize_t *ostart, Py_ssize_t *ostop, Py_ssize_t *oslicelength) bool THPUtils_(parseSlice)(PyObject *slice, Py_ssize_t len, Py_ssize_t *ostart, Py_ssize_t *ostop, Py_ssize_t *oslicelength)
{ {
Py_ssize_t start, stop, step, slicelength; Py_ssize_t start, stop, step, slicelength;
if (PySlice_GetIndicesEx(slice, len, &start, &stop, &step, &slicelength) < 0) { if (PySlice_GetIndicesEx(
// https://bugsfiles.kde.org/attachment.cgi?id=61186
#if PY_VERSION_HEX >= 0x03020000
slice,
#else
(PySliceObject *)slice,
#endif
len, &start, &stop, &step, &slicelength) < 0) {
PyErr_SetString(PyExc_RuntimeError, "Got an invalid slice"); PyErr_SetString(PyExc_RuntimeError, "Got an invalid slice");
return false; return false;
} }
@ -24,11 +31,16 @@ bool THPUtils_(parseReal)(PyObject *value, real *result)
{ {
if (PyLong_Check(value)) { if (PyLong_Check(value)) {
*result = (real)PyLong_AsLongLong(value); *result = (real)PyLong_AsLongLong(value);
} else if (PyInt_Check(value)) {
*result = (real)PyInt_AsLong(value);
} else if (PyFloat_Check(value)) { } else if (PyFloat_Check(value)) {
*result = (real)PyFloat_AsDouble(value); *result = (real)PyFloat_AsDouble(value);
} else { } else {
// TODO: meaningful error char err_string[512];
PyErr_SetString(PyExc_RuntimeError, "Unrecognized object"); snprintf (err_string, 512, "%s %s",
"parseReal expected long or float, but got type: ",
value->ob_type->tp_name);
PyErr_SetString(PyExc_RuntimeError, err_string);
return false; return false;
} }
return true; return true;
@ -36,7 +48,7 @@ bool THPUtils_(parseReal)(PyObject *value, real *result)
bool THPUtils_(checkReal)(PyObject *value) bool THPUtils_(checkReal)(PyObject *value)
{ {
return PyFloat_Check(value) || PyLong_Check(value); return PyFloat_Check(value) || PyLong_Check(value) || PyInt_Check(value);
} }
PyObject * THPUtils_(newReal)(real value) PyObject * THPUtils_(newReal)(real value)

View File

@ -3,3 +3,19 @@
#include "generic/utils.cpp" #include "generic/utils.cpp"
#include <TH/THGenerateAllTypes.h> #include <TH/THGenerateAllTypes.h>
int THPUtils_getLong(PyObject *index, long *result) {
if (PyLong_Check(index)) {
*result = PyLong_AsLong(index);
} else if (PyInt_Check(index)) {
*result = PyInt_AsLong(index);
} else {
char err_string[512];
snprintf (err_string, 512, "%s %s",
"getLong expected int or long, but got type: ",
index->ob_type->tp_name);
PyErr_SetString(PyExc_RuntimeError, err_string);
return 0;
}
return 1;
}

View File

@ -6,5 +6,7 @@
#include "generic/utils.h" #include "generic/utils.h"
#include <TH/THGenerateAllTypes.h> #include <TH/THGenerateAllTypes.h>
int THPUtils_getLong(PyObject *index, long *result);
#endif #endif