From 5ee3358a92d423c67dfeed37bc1398daa15c866c Mon Sep 17 00:00:00 2001 From: Soumith Chintala Date: Thu, 12 May 2016 17:49:16 -0400 Subject: [PATCH] python 2 support --- .gitignore | 1 + setup.py | 20 +++++++- test/smoke.py | 24 ++++++++++ tools/__init__.py | 0 tools/cwrap.py | 13 ++++-- torch/Storage.py | 2 +- torch/Tensor.py | 2 +- torch/csrc/Module.cpp | 48 +++++++++++++------ torch/csrc/THP.h | 9 ++++ torch/csrc/generic/Storage.cpp | 34 +++++++++----- torch/csrc/generic/StorageMethods.cpp | 2 +- torch/csrc/generic/Tensor.cpp | 66 +++++++++++++++++---------- torch/csrc/generic/utils.cpp | 20 ++++++-- torch/csrc/utils.cpp | 16 +++++++ torch/csrc/utils.h | 2 + 15 files changed, 196 insertions(+), 63 deletions(-) create mode 100644 test/smoke.py create mode 100644 tools/__init__.py diff --git a/.gitignore b/.gitignore index affe5affe058..7246d25523cd 100644 --- a/.gitignore +++ b/.gitignore @@ -4,3 +4,4 @@ torch.egg-info/ */**/__pycache__ torch/__init__.py torch/csrc/generic/TensorMethods.cpp +*/**/*.pyc \ No newline at end of file diff --git a/setup.py b/setup.py index 4ecf365f47f4..dc8a298b6a22 100644 --- a/setup.py +++ b/setup.py @@ -1,5 +1,7 @@ from setuptools import setup, Extension +from os.path import expanduser from tools.cwrap import cwrap +import platform ################################################################################ # Generate __init__.py from templates @@ -49,6 +51,16 @@ for src in cwrap_src: ################################################################################ # Declare the package ################################################################################ +extra_link_args = [] + +# TODO: remove and properly submodule TH in the repo itself +th_path = expanduser("~/torch/install/") +th_header_path = th_path + "include" +th_lib_path = th_path + "lib" +if platform.system() == 'Darwin': + extra_link_args.append('-L' + th_lib_path) + extra_link_args.append('-Wl,-rpath,' + th_lib_path) + sources = [ "torch/csrc/Module.cpp", "torch/csrc/Tensor.cpp", @@ -59,9 +71,13 @@ C = Extension("torch.C", libraries=['TH'], sources=sources, language='c++', - include_dirs=["torch/csrc"]) + include_dirs=(["torch/csrc", th_header_path]), + extra_link_args = extra_link_args, +) + setup(name="torch", version="0.1", ext_modules=[C], - packages=['torch']) + packages=['torch'], +) diff --git a/test/smoke.py b/test/smoke.py new file mode 100644 index 000000000000..ca3005209a79 --- /dev/null +++ b/test/smoke.py @@ -0,0 +1,24 @@ +import torch + +a = torch.FloatTensor(4, 3) +b = torch.FloatTensor(3, 4) + +a.add(b) + +c = a.storage() + +d = a.select(0, 1) + +print(c) +print(a) +print(b) +print(d) + + +a.fill(0) + +print(a[1]) + +print(a.ge(long(0))) +print(a.ge(0)) + diff --git a/tools/__init__.py b/tools/__init__.py new file mode 100644 index 000000000000..e69de29bb2d1 diff --git a/tools/cwrap.py b/tools/cwrap.py index 49b1e7d9aae1..2ad10a821769 100644 --- a/tools/cwrap.py +++ b/tools/cwrap.py @@ -151,7 +151,7 @@ RETURN_WRAPPER = { 'THStorage': Template('return THPStorage_(newObject)($expr)'), 'THLongStorage': Template('return THPLongStorage_newObject($expr)'), 'bool': Template('return PyBool_FromLong($expr)'), - 'long': Template('return PyLong_FromLong($expr)'), + 'long': Template('return PyInt_FromLong($expr)'), 'double': Template('return PyFloat_FromDouble($expr)'), 'self': Template('$expr; Py_INCREF(self); return (PyObject*)self'), # TODO @@ -397,16 +397,19 @@ def argfilter(): CONSTANT arguments are literals. Repeated arguments do not need to be specified twice. """ - provided = set() + # use class rather than nonlocal to maintain 2.7 compat + # see http://stackoverflow.com/questions/3190706/nonlocal-keyword-in-python-2-x + # TODO: check this works + class context: + provided = set() def is_already_provided(arg): - nonlocal provided ret = False ret |= arg.name == 'self' ret |= arg.name == '_res_new' ret |= arg.type == 'CONSTANT' ret |= arg.type == 'EXPRESSION' - ret |= arg.name in provided - provided.add(arg.name) + ret |= arg.name in context.provided + context.provided.add(arg.name) return ret return is_already_provided diff --git a/torch/Storage.py b/torch/Storage.py index c3d028559d91..6a868d349872 100644 --- a/torch/Storage.py +++ b/torch/Storage.py @@ -7,5 +7,5 @@ class RealStorage(RealStorageBase): return str(self) def __iter__(self): - return map(lambda i: self[i], range(self.size())) + return iter(map(lambda i: self[i], range(self.size()))) diff --git a/torch/Tensor.py b/torch/Tensor.py index 14364ee8745a..7d4dfca1fc97 100644 --- a/torch/Tensor.py +++ b/torch/Tensor.py @@ -42,4 +42,4 @@ class RealTensor(RealTensorBase): return _printing.printTensor(self) def __iter__(self): - return map(lambda i: self.select(0, i), range(self.size(0))) + return iter(map(lambda i: self.select(0, i), range(self.size(0)))) diff --git a/torch/csrc/Module.cpp b/torch/csrc/Module.cpp index a4e583f4133c..461a144b48b1 100644 --- a/torch/csrc/Module.cpp +++ b/torch/csrc/Module.cpp @@ -5,7 +5,11 @@ #include "THP.h" +#if PY_MAJOR_VERSION == 2 +#define ASSERT_TRUE(cmd) if (!(cmd)) {PyErr_SetString(PyExc_ImportError, "initialization error"); return;} +#else #define ASSERT_TRUE(cmd) if (!(cmd)) return NULL +#endif static PyObject* module; static PyObject* tensor_classes; @@ -34,21 +38,21 @@ static bool THPModule_loadClasses(PyObject *self) PyObject *torch_module = PyImport_ImportModule("torch"); PyObject* module_dict = PyModule_GetDict(torch_module); - THPDoubleStorageClass = PyMapping_GetItemString(module_dict, "DoubleStorage"); - THPFloatStorageClass = PyMapping_GetItemString(module_dict, "FloatStorage"); - THPLongStorageClass = PyMapping_GetItemString(module_dict, "LongStorage"); - THPIntStorageClass = PyMapping_GetItemString(module_dict, "IntStorage"); - THPShortStorageClass = PyMapping_GetItemString(module_dict, "ShortStorage"); - THPCharStorageClass = PyMapping_GetItemString(module_dict, "CharStorage"); - THPByteStorageClass = PyMapping_GetItemString(module_dict, "ByteStorage"); + THPDoubleStorageClass = PyMapping_GetItemString(module_dict,(char*)"DoubleStorage"); + THPFloatStorageClass = PyMapping_GetItemString(module_dict,(char*)"FloatStorage"); + THPLongStorageClass = PyMapping_GetItemString(module_dict,(char*)"LongStorage"); + THPIntStorageClass = PyMapping_GetItemString(module_dict,(char*)"IntStorage"); + THPShortStorageClass = PyMapping_GetItemString(module_dict,(char*)"ShortStorage"); + THPCharStorageClass = PyMapping_GetItemString(module_dict,(char*)"CharStorage"); + THPByteStorageClass = PyMapping_GetItemString(module_dict,(char*)"ByteStorage"); - THPDoubleTensorClass = PyMapping_GetItemString(module_dict, "DoubleTensor"); - THPFloatTensorClass = PyMapping_GetItemString(module_dict, "FloatTensor"); - THPLongTensorClass = PyMapping_GetItemString(module_dict, "LongTensor"); - THPIntTensorClass = PyMapping_GetItemString(module_dict, "IntTensor"); - THPShortTensorClass = PyMapping_GetItemString(module_dict, "ShortTensor"); - THPCharTensorClass = PyMapping_GetItemString(module_dict, "CharTensor"); - THPByteTensorClass = PyMapping_GetItemString(module_dict, "ByteTensor"); + THPDoubleTensorClass = PyMapping_GetItemString(module_dict,(char*)"DoubleTensor"); + THPFloatTensorClass = PyMapping_GetItemString(module_dict,(char*)"FloatTensor"); + THPLongTensorClass = PyMapping_GetItemString(module_dict,(char*)"LongTensor"); + THPIntTensorClass = PyMapping_GetItemString(module_dict,(char*)"IntTensor"); + THPShortTensorClass = PyMapping_GetItemString(module_dict,(char*)"ShortTensor"); + THPCharTensorClass = PyMapping_GetItemString(module_dict,(char*)"CharTensor"); + THPByteTensorClass = PyMapping_GetItemString(module_dict,(char*)"ByteTensor"); PySet_Add(tensor_classes, THPDoubleTensorClass); PySet_Add(tensor_classes, THPFloatTensorClass); PySet_Add(tensor_classes, THPLongTensorClass); @@ -314,6 +318,7 @@ static PyMethodDef TorchMethods[] = { {NULL, NULL, 0, NULL} }; +#if PY_MAJOR_VERSION != 2 static struct PyModuleDef torchmodule = { PyModuleDef_HEAD_INIT, "torch.C", @@ -321,6 +326,7 @@ static struct PyModuleDef torchmodule = { -1, TorchMethods }; +#endif static void errorHandler(const char *msg, void *data) { @@ -338,10 +344,17 @@ static void updateErrorHandlers() THSetArgErrorHandler(errorHandlerArg, NULL); } +#if PY_MAJOR_VERSION == 2 +PyMODINIT_FUNC initC() +#else PyMODINIT_FUNC PyInit_C() +#endif { +#if PY_MAJOR_VERSION == 2 + ASSERT_TRUE(module = Py_InitModule("torch.C", TorchMethods)); +#else ASSERT_TRUE(module = PyModule_Create(&torchmodule)); - +#endif ASSERT_TRUE(tensor_classes = PySet_New(NULL)); ASSERT_TRUE(PyObject_SetAttrString(module, "_tensorclasses", tensor_classes) == 0); @@ -363,5 +376,10 @@ PyMODINIT_FUNC PyInit_C() updateErrorHandlers(); +#if PY_MAJOR_VERSION == 2 +#else return module; +#endif } + +#undef ASSERT_TRUE diff --git a/torch/csrc/THP.h b/torch/csrc/THP.h index 02c914108ae8..9fac0f267983 100644 --- a/torch/csrc/THP.h +++ b/torch/csrc/THP.h @@ -1,6 +1,15 @@ #include #include +// Back-compatibility macros, Thanks to http://cx-oracle.sourceforge.net/ +// define PyInt_* macros for Python 3.x +#ifndef PyInt_Check +#define PyInt_Check PyLong_Check +#define PyInt_FromLong PyLong_FromLong +#define PyInt_AsLong PyLong_AsLong +#define PyInt_Type PyLong_Type +#endif + #include "Exceptions.h" #include "utils.h" diff --git a/torch/csrc/generic/Storage.cpp b/torch/csrc/generic/Storage.cpp index 77967adc6d43..ef1f77be3a56 100644 --- a/torch/csrc/generic/Storage.cpp +++ b/torch/csrc/generic/Storage.cpp @@ -9,6 +9,7 @@ PyObject * THPStorage_(newObject)(THStorage *ptr) // TODO: error checking PyObject *args = PyTuple_New(0); PyObject *kwargs = Py_BuildValue("{s:N}", "cdata", PyLong_FromVoidPtr(ptr)); + PyObject *instance = PyObject_Call(THPStorageClass, args, kwargs); Py_DECREF(args); Py_DECREF(kwargs); @@ -30,17 +31,17 @@ static PyObject * THPStorage_(pynew)(PyTypeObject *type, PyObject *args, PyObjec { HANDLE_TH_ERRORS static const char *keywords[] = {"cdata", NULL}; - PyObject *number_arg = NULL; - if (!PyArg_ParseTupleAndKeywords(args, kwargs, "|O!", (char **)keywords, &PyLong_Type, &number_arg)) + void* number_arg = NULL; + if (!PyArg_ParseTupleAndKeywords(args, kwargs, "|O&", (char **)keywords, + THPUtils_getLong, &number_arg)) return NULL; - THPStorage *self = (THPStorage *)type->tp_alloc(type, 0); if (self != NULL) { if (kwargs) { - self->cdata = (THStorage*)PyLong_AsVoidPtr(number_arg); + self->cdata = (THStorage*)number_arg; THStorage_(retain)(self->cdata); } else if (/* !kwargs && */ number_arg) { - self->cdata = THStorage_(newWithSize)(PyLong_AsLong(number_arg)); + self->cdata = THStorage_(newWithSize)((long) number_arg); } else { self->cdata = THStorage_(new)(); } @@ -66,8 +67,9 @@ static PyObject * THPStorage_(get)(THPStorage *self, PyObject *index) { HANDLE_TH_ERRORS /* Integer index */ - if (PyLong_Check(index)) { - long nindex = PyLong_AsLong(index); + long nindex; + if ((PyLong_Check(index) || PyInt_Check(index)) + && THPUtils_getLong(index, &nindex) == 1 ) { if (nindex < 0) nindex += THStorage_(size)(self->cdata); #if defined(TH_REAL_IS_FLOAT) || defined(TH_REAL_IS_DOUBLE) @@ -89,7 +91,11 @@ static PyObject * THPStorage_(get)(THPStorage *self, PyObject *index) THStorage *new_storage = THStorage_(newWithData)(new_data, slicelength); return THPStorage_(newObject)(new_storage); } - PyErr_SetString(PyExc_RuntimeError, "Only indexing with integers and slices supported"); + char err_string[512]; + snprintf (err_string, 512, + "%s %s", "Only indexing with integers and slices supported, but got type: ", + index->ob_type->tp_name); + PyErr_SetString(PyExc_RuntimeError, err_string); return NULL; END_HANDLE_TH_ERRORS } @@ -101,8 +107,10 @@ static int THPStorage_(set)(THPStorage *self, PyObject *index, PyObject *value) if (!THPUtils_(parseReal)(value, &rvalue)) return -1; - if (PyLong_Check(index)) { - THStorage_(set)(self->cdata, PyLong_AsSize_t(index), rvalue); + long nindex; + if ((PyLong_Check(index) || PyInt_Check(index)) + && THPUtils_getLong(index, &nindex) == 1) { + THStorage_(set)(self->cdata, nindex, rvalue); return 0; } else if (PySlice_Check(index)) { Py_ssize_t start, stop, len; @@ -114,7 +122,11 @@ static int THPStorage_(set)(THPStorage *self, PyObject *index, PyObject *value) THStorage_(set)(self->cdata, start, rvalue); return 0; } - PyErr_SetString(PyExc_RuntimeError, "Only indexing with integers and slices supported at the moment"); + char err_string[512]; + snprintf (err_string, 512, "%s %s", + "Only indexing with integers and slices supported, but got type: ", + index->ob_type->tp_name); + PyErr_SetString(PyExc_RuntimeError, err_string); return -1; END_HANDLE_TH_ERRORS_RET(-1) } diff --git a/torch/csrc/generic/StorageMethods.cpp b/torch/csrc/generic/StorageMethods.cpp index 9a6733e53dec..d356ea306ff4 100644 --- a/torch/csrc/generic/StorageMethods.cpp +++ b/torch/csrc/generic/StorageMethods.cpp @@ -40,7 +40,7 @@ static PyObject * THPStorage_(resize)(THPStorage *self, PyObject *number_arg) HANDLE_TH_ERRORS if (!PyLong_Check(number_arg)) return NULL; - size_t newsize = PyLong_AsSize_t(number_arg); + long newsize = PyLong_AsLong(number_arg); if (PyErr_Occurred()) return NULL; THStorage_(resize)(self->cdata, newsize); diff --git a/torch/csrc/generic/Tensor.cpp b/torch/csrc/generic/Tensor.cpp index 3fa87b6b0cd9..4b6a8cf8273a 100644 --- a/torch/csrc/generic/Tensor.cpp +++ b/torch/csrc/generic/Tensor.cpp @@ -7,7 +7,7 @@ PyObject * THPTensor_(newObject)(THTensor *ptr) { // TODO: error checking PyObject *args = PyTuple_New(0); - PyObject *kwargs = Py_BuildValue("{s:N}", "cdata", PyLong_FromVoidPtr(ptr)); + PyObject *kwargs = Py_BuildValue("{s:K}", "cdata", (unsigned long long) ptr); PyObject *instance = PyObject_Call(THPTensorClass, args, kwargs); Py_DECREF(args); Py_DECREF(kwargs); @@ -41,9 +41,11 @@ static PyObject * THPTensor_(pynew)(PyTypeObject *type, PyObject *args, PyObject THTensor *cdata_ptr = NULL; // If not, try to parse integers #define ERRMSG ";Expected torch.LongStorage or up to 4 integers as arguments" + // TODO: check that cdata_ptr is a keyword arg if (!storage_obj && - !PyArg_ParseTupleAndKeywords(args, kwargs, "|LLLL$k" ERRMSG, (char**)keywords, + !PyArg_ParseTupleAndKeywords(args, kwargs, "|LLLLL" ERRMSG, (char**)keywords, &sizes[0], &sizes[1], &sizes[2], &sizes[3], &cdata_ptr)) +#undef ERRMSG return NULL; THPTensor *self = (THPTensor *)type->tp_alloc(type, 0); @@ -64,21 +66,24 @@ static PyObject * THPTensor_(pynew)(PyTypeObject *type, PyObject *args, PyObject END_HANDLE_TH_ERRORS } -#define INDEX_LONG(DIM, IDX_VARIABLE, TENSOR_VARIABLE, CASE_1D, CASE_MD) \ -long idx = PyLong_AsLong(IDX_VARIABLE); \ -long dimsize = THTensor_(size)(TENSOR_VARIABLE, DIM); \ -idx = (idx < 0) ? dimsize + idx : idx; \ - \ -THArgCheck(dimsize > 0, 1, "empty tensor"); \ -THArgCheck(idx >= 0 && idx < dimsize, 2, "out of range"); \ - \ -if(THTensor_(nDimension)(TENSOR_VARIABLE) == 1) { \ - CASE_1D; \ -} else { \ - CASE_MD; \ -} -#define GET_PTR_1D(t, idx) \ +#define INDEX_LONG(DIM, IDX_VARIABLE, TENSOR_VARIABLE, CASE_1D, CASE_MD) \ + long idx; \ + THPUtils_getLong(IDX_VARIABLE, &idx); \ + long dimsize = THTensor_(size)(TENSOR_VARIABLE, DIM); \ + idx = (idx < 0) ? dimsize + idx : idx; \ + \ + THArgCheck(dimsize > 0, 1, "empty tensor"); \ + THArgCheck(idx >= 0 && idx < dimsize, 2, "out of range"); \ + \ + if(THTensor_(nDimension)(TENSOR_VARIABLE) == 1) { \ + CASE_1D; \ + } else { \ + CASE_MD; \ + } + +#define GET_PTR_1D(t, idx) \ t->storage->data + t->storageOffset + t->stride[0] * idx; + static bool THPTensor_(_index)(THPTensor *self, PyObject *index, THTensor * &tresult, real * &rresult) { @@ -86,7 +91,7 @@ static bool THPTensor_(_index)(THPTensor *self, PyObject *index, rresult = NULL; try { // Indexing with an integer - if(PyLong_Check(index)) { + if(PyLong_Check(index) || PyInt_Check(index)) { THTensor *self_t = self->cdata; INDEX_LONG(0, index, self_t, // 1D tensor @@ -97,8 +102,9 @@ static bool THPTensor_(_index)(THPTensor *self, PyObject *index, ) // Indexing with a single element tuple } else if (PyTuple_Check(index) && - PyTuple_Size(index) == 1 && - PyLong_Check(PyTuple_GET_ITEM(index, 0))) { + PyTuple_Size(index) == 1 && + (PyLong_Check(PyTuple_GET_ITEM(index, 0)) + || PyInt_Check(PyTuple_GET_ITEM(index, 0)))) { PyObject *index_obj = PyTuple_GET_ITEM(index, 0); tresult = THTensor_(newWithTensor)(self->cdata); INDEX_LONG(0, index_obj, tresult, @@ -121,7 +127,7 @@ static bool THPTensor_(_index)(THPTensor *self, PyObject *index, for(int dim = 0; dim < PyTuple_Size(index); dim++) { PyObject *dimidx = PyTuple_GET_ITEM(index, dim); - if(PyLong_Check(dimidx)) { + if(PyLong_Check(dimidx) || PyInt_Check(dimidx)) { INDEX_LONG(t_dim, dimidx, tresult, // 1D tensor rresult = GET_PTR_1D(tresult, idx); @@ -132,7 +138,9 @@ static bool THPTensor_(_index)(THPTensor *self, PyObject *index, THTensor_(select)(tresult, NULL, t_dim, idx) ) } else if (PyTuple_Check(dimidx)) { - if (PyTuple_Size(dimidx) != 1 || !PyLong_Check(PyTuple_GET_ITEM(dimidx, 0))) { + if (PyTuple_Size(dimidx) != 1 + || !(PyLong_Check(PyTuple_GET_ITEM(dimidx, 0)) + || PyInt_Check(PyTuple_GET_ITEM(dimidx, 0)))) { PyErr_SetString(PyExc_RuntimeError, "Expected a single integer"); return false; } @@ -177,7 +185,11 @@ static PyObject * THPTensor_(get)(THPTensor *self, PyObject *index) return THPTensor_(newObject)(tresult); if (rresult) return THPUtils_(newReal)(*rresult); - PyErr_SetString(PyExc_RuntimeError, "Unknown exception"); + char err_string[512]; + snprintf (err_string, 512, + "%s %s", "Unknown exception in THPTensor_(get). Index type is: ", + index->ob_type->tp_name); + PyErr_SetString(PyExc_RuntimeError, err_string); return NULL; } END_HANDLE_TH_ERRORS @@ -312,7 +324,7 @@ PyTypeObject THPTensorStatelessType = { 0, /* tp_print */ 0, /* tp_getattr */ 0, /* tp_setattr */ - 0, /* tp_reserved */ + 0, /* tp_reserved / tp_compare */ 0, /* tp_repr */ 0, /* tp_as_number */ 0, /* tp_as_sequence */ @@ -342,6 +354,13 @@ PyTypeObject THPTensorStatelessType = { 0, /* tp_init */ 0, /* tp_alloc */ 0, /* tp_new */ + 0, /* tp_free */ + 0, /* tp_is_gc */ + 0, /* tp_bases */ + 0, /* tp_mro */ + 0, /* tp_cache */ + 0, /* tp_subclasses */ + 0, /* tp_weaklist */ }; bool THPTensor_(init)(PyObject *module) @@ -353,6 +372,7 @@ bool THPTensor_(init)(PyObject *module) THPTensorStatelessType.tp_new = PyType_GenericNew; if (PyType_Ready(&THPTensorStatelessType) < 0) return false; + PyModule_AddObject(module, THPTensorBaseStr, (PyObject *)&THPTensorType); return true; } diff --git a/torch/csrc/generic/utils.cpp b/torch/csrc/generic/utils.cpp index 04923587d94d..39d731b8c80c 100644 --- a/torch/csrc/generic/utils.cpp +++ b/torch/csrc/generic/utils.cpp @@ -5,7 +5,14 @@ bool THPUtils_(parseSlice)(PyObject *slice, Py_ssize_t len, Py_ssize_t *ostart, Py_ssize_t *ostop, Py_ssize_t *oslicelength) { Py_ssize_t start, stop, step, slicelength; - if (PySlice_GetIndicesEx(slice, len, &start, &stop, &step, &slicelength) < 0) { + if (PySlice_GetIndicesEx( +// https://bugsfiles.kde.org/attachment.cgi?id=61186 +#if PY_VERSION_HEX >= 0x03020000 + slice, +#else + (PySliceObject *)slice, +#endif + len, &start, &stop, &step, &slicelength) < 0) { PyErr_SetString(PyExc_RuntimeError, "Got an invalid slice"); return false; } @@ -24,11 +31,16 @@ bool THPUtils_(parseReal)(PyObject *value, real *result) { if (PyLong_Check(value)) { *result = (real)PyLong_AsLongLong(value); + } else if (PyInt_Check(value)) { + *result = (real)PyInt_AsLong(value); } else if (PyFloat_Check(value)) { *result = (real)PyFloat_AsDouble(value); } else { - // TODO: meaningful error - PyErr_SetString(PyExc_RuntimeError, "Unrecognized object"); + char err_string[512]; + snprintf (err_string, 512, "%s %s", + "parseReal expected long or float, but got type: ", + value->ob_type->tp_name); + PyErr_SetString(PyExc_RuntimeError, err_string); return false; } return true; @@ -36,7 +48,7 @@ bool THPUtils_(parseReal)(PyObject *value, real *result) bool THPUtils_(checkReal)(PyObject *value) { - return PyFloat_Check(value) || PyLong_Check(value); + return PyFloat_Check(value) || PyLong_Check(value) || PyInt_Check(value); } PyObject * THPUtils_(newReal)(real value) diff --git a/torch/csrc/utils.cpp b/torch/csrc/utils.cpp index 0f57bd4a6ef4..625bfb3e57d2 100644 --- a/torch/csrc/utils.cpp +++ b/torch/csrc/utils.cpp @@ -3,3 +3,19 @@ #include "generic/utils.cpp" #include + + int THPUtils_getLong(PyObject *index, long *result) { + if (PyLong_Check(index)) { + *result = PyLong_AsLong(index); + } else if (PyInt_Check(index)) { + *result = PyInt_AsLong(index); + } else { + char err_string[512]; + snprintf (err_string, 512, "%s %s", + "getLong expected int or long, but got type: ", + index->ob_type->tp_name); + PyErr_SetString(PyExc_RuntimeError, err_string); + return 0; + } + return 1; +} diff --git a/torch/csrc/utils.h b/torch/csrc/utils.h index cebed4c75a4f..85d8c4eb7e5e 100644 --- a/torch/csrc/utils.h +++ b/torch/csrc/utils.h @@ -6,5 +6,7 @@ #include "generic/utils.h" #include +int THPUtils_getLong(PyObject *index, long *result); + #endif