mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 12:54:11 +08:00
python 2 support
This commit is contained in:
1
.gitignore
vendored
1
.gitignore
vendored
@ -4,3 +4,4 @@ torch.egg-info/
|
||||
*/**/__pycache__
|
||||
torch/__init__.py
|
||||
torch/csrc/generic/TensorMethods.cpp
|
||||
*/**/*.pyc
|
20
setup.py
20
setup.py
@ -1,5 +1,7 @@
|
||||
from setuptools import setup, Extension
|
||||
from os.path import expanduser
|
||||
from tools.cwrap import cwrap
|
||||
import platform
|
||||
|
||||
################################################################################
|
||||
# Generate __init__.py from templates
|
||||
@ -49,6 +51,16 @@ for src in cwrap_src:
|
||||
################################################################################
|
||||
# Declare the package
|
||||
################################################################################
|
||||
extra_link_args = []
|
||||
|
||||
# TODO: remove and properly submodule TH in the repo itself
|
||||
th_path = expanduser("~/torch/install/")
|
||||
th_header_path = th_path + "include"
|
||||
th_lib_path = th_path + "lib"
|
||||
if platform.system() == 'Darwin':
|
||||
extra_link_args.append('-L' + th_lib_path)
|
||||
extra_link_args.append('-Wl,-rpath,' + th_lib_path)
|
||||
|
||||
sources = [
|
||||
"torch/csrc/Module.cpp",
|
||||
"torch/csrc/Tensor.cpp",
|
||||
@ -59,9 +71,13 @@ C = Extension("torch.C",
|
||||
libraries=['TH'],
|
||||
sources=sources,
|
||||
language='c++',
|
||||
include_dirs=["torch/csrc"])
|
||||
include_dirs=(["torch/csrc", th_header_path]),
|
||||
extra_link_args = extra_link_args,
|
||||
)
|
||||
|
||||
|
||||
|
||||
setup(name="torch", version="0.1",
|
||||
ext_modules=[C],
|
||||
packages=['torch'])
|
||||
packages=['torch'],
|
||||
)
|
||||
|
24
test/smoke.py
Normal file
24
test/smoke.py
Normal file
@ -0,0 +1,24 @@
|
||||
import torch
|
||||
|
||||
a = torch.FloatTensor(4, 3)
|
||||
b = torch.FloatTensor(3, 4)
|
||||
|
||||
a.add(b)
|
||||
|
||||
c = a.storage()
|
||||
|
||||
d = a.select(0, 1)
|
||||
|
||||
print(c)
|
||||
print(a)
|
||||
print(b)
|
||||
print(d)
|
||||
|
||||
|
||||
a.fill(0)
|
||||
|
||||
print(a[1])
|
||||
|
||||
print(a.ge(long(0)))
|
||||
print(a.ge(0))
|
||||
|
0
tools/__init__.py
Normal file
0
tools/__init__.py
Normal file
@ -151,7 +151,7 @@ RETURN_WRAPPER = {
|
||||
'THStorage': Template('return THPStorage_(newObject)($expr)'),
|
||||
'THLongStorage': Template('return THPLongStorage_newObject($expr)'),
|
||||
'bool': Template('return PyBool_FromLong($expr)'),
|
||||
'long': Template('return PyLong_FromLong($expr)'),
|
||||
'long': Template('return PyInt_FromLong($expr)'),
|
||||
'double': Template('return PyFloat_FromDouble($expr)'),
|
||||
'self': Template('$expr; Py_INCREF(self); return (PyObject*)self'),
|
||||
# TODO
|
||||
@ -397,16 +397,19 @@ def argfilter():
|
||||
CONSTANT arguments are literals.
|
||||
Repeated arguments do not need to be specified twice.
|
||||
"""
|
||||
provided = set()
|
||||
# use class rather than nonlocal to maintain 2.7 compat
|
||||
# see http://stackoverflow.com/questions/3190706/nonlocal-keyword-in-python-2-x
|
||||
# TODO: check this works
|
||||
class context:
|
||||
provided = set()
|
||||
def is_already_provided(arg):
|
||||
nonlocal provided
|
||||
ret = False
|
||||
ret |= arg.name == 'self'
|
||||
ret |= arg.name == '_res_new'
|
||||
ret |= arg.type == 'CONSTANT'
|
||||
ret |= arg.type == 'EXPRESSION'
|
||||
ret |= arg.name in provided
|
||||
provided.add(arg.name)
|
||||
ret |= arg.name in context.provided
|
||||
context.provided.add(arg.name)
|
||||
return ret
|
||||
return is_already_provided
|
||||
|
||||
|
@ -7,5 +7,5 @@ class RealStorage(RealStorageBase):
|
||||
return str(self)
|
||||
|
||||
def __iter__(self):
|
||||
return map(lambda i: self[i], range(self.size()))
|
||||
return iter(map(lambda i: self[i], range(self.size())))
|
||||
|
||||
|
@ -42,4 +42,4 @@ class RealTensor(RealTensorBase):
|
||||
return _printing.printTensor(self)
|
||||
|
||||
def __iter__(self):
|
||||
return map(lambda i: self.select(0, i), range(self.size(0)))
|
||||
return iter(map(lambda i: self.select(0, i), range(self.size(0))))
|
||||
|
@ -5,7 +5,11 @@
|
||||
|
||||
#include "THP.h"
|
||||
|
||||
#if PY_MAJOR_VERSION == 2
|
||||
#define ASSERT_TRUE(cmd) if (!(cmd)) {PyErr_SetString(PyExc_ImportError, "initialization error"); return;}
|
||||
#else
|
||||
#define ASSERT_TRUE(cmd) if (!(cmd)) return NULL
|
||||
#endif
|
||||
|
||||
static PyObject* module;
|
||||
static PyObject* tensor_classes;
|
||||
@ -34,21 +38,21 @@ static bool THPModule_loadClasses(PyObject *self)
|
||||
PyObject *torch_module = PyImport_ImportModule("torch");
|
||||
PyObject* module_dict = PyModule_GetDict(torch_module);
|
||||
|
||||
THPDoubleStorageClass = PyMapping_GetItemString(module_dict, "DoubleStorage");
|
||||
THPFloatStorageClass = PyMapping_GetItemString(module_dict, "FloatStorage");
|
||||
THPLongStorageClass = PyMapping_GetItemString(module_dict, "LongStorage");
|
||||
THPIntStorageClass = PyMapping_GetItemString(module_dict, "IntStorage");
|
||||
THPShortStorageClass = PyMapping_GetItemString(module_dict, "ShortStorage");
|
||||
THPCharStorageClass = PyMapping_GetItemString(module_dict, "CharStorage");
|
||||
THPByteStorageClass = PyMapping_GetItemString(module_dict, "ByteStorage");
|
||||
THPDoubleStorageClass = PyMapping_GetItemString(module_dict,(char*)"DoubleStorage");
|
||||
THPFloatStorageClass = PyMapping_GetItemString(module_dict,(char*)"FloatStorage");
|
||||
THPLongStorageClass = PyMapping_GetItemString(module_dict,(char*)"LongStorage");
|
||||
THPIntStorageClass = PyMapping_GetItemString(module_dict,(char*)"IntStorage");
|
||||
THPShortStorageClass = PyMapping_GetItemString(module_dict,(char*)"ShortStorage");
|
||||
THPCharStorageClass = PyMapping_GetItemString(module_dict,(char*)"CharStorage");
|
||||
THPByteStorageClass = PyMapping_GetItemString(module_dict,(char*)"ByteStorage");
|
||||
|
||||
THPDoubleTensorClass = PyMapping_GetItemString(module_dict, "DoubleTensor");
|
||||
THPFloatTensorClass = PyMapping_GetItemString(module_dict, "FloatTensor");
|
||||
THPLongTensorClass = PyMapping_GetItemString(module_dict, "LongTensor");
|
||||
THPIntTensorClass = PyMapping_GetItemString(module_dict, "IntTensor");
|
||||
THPShortTensorClass = PyMapping_GetItemString(module_dict, "ShortTensor");
|
||||
THPCharTensorClass = PyMapping_GetItemString(module_dict, "CharTensor");
|
||||
THPByteTensorClass = PyMapping_GetItemString(module_dict, "ByteTensor");
|
||||
THPDoubleTensorClass = PyMapping_GetItemString(module_dict,(char*)"DoubleTensor");
|
||||
THPFloatTensorClass = PyMapping_GetItemString(module_dict,(char*)"FloatTensor");
|
||||
THPLongTensorClass = PyMapping_GetItemString(module_dict,(char*)"LongTensor");
|
||||
THPIntTensorClass = PyMapping_GetItemString(module_dict,(char*)"IntTensor");
|
||||
THPShortTensorClass = PyMapping_GetItemString(module_dict,(char*)"ShortTensor");
|
||||
THPCharTensorClass = PyMapping_GetItemString(module_dict,(char*)"CharTensor");
|
||||
THPByteTensorClass = PyMapping_GetItemString(module_dict,(char*)"ByteTensor");
|
||||
PySet_Add(tensor_classes, THPDoubleTensorClass);
|
||||
PySet_Add(tensor_classes, THPFloatTensorClass);
|
||||
PySet_Add(tensor_classes, THPLongTensorClass);
|
||||
@ -314,6 +318,7 @@ static PyMethodDef TorchMethods[] = {
|
||||
{NULL, NULL, 0, NULL}
|
||||
};
|
||||
|
||||
#if PY_MAJOR_VERSION != 2
|
||||
static struct PyModuleDef torchmodule = {
|
||||
PyModuleDef_HEAD_INIT,
|
||||
"torch.C",
|
||||
@ -321,6 +326,7 @@ static struct PyModuleDef torchmodule = {
|
||||
-1,
|
||||
TorchMethods
|
||||
};
|
||||
#endif
|
||||
|
||||
static void errorHandler(const char *msg, void *data)
|
||||
{
|
||||
@ -338,10 +344,17 @@ static void updateErrorHandlers()
|
||||
THSetArgErrorHandler(errorHandlerArg, NULL);
|
||||
}
|
||||
|
||||
#if PY_MAJOR_VERSION == 2
|
||||
PyMODINIT_FUNC initC()
|
||||
#else
|
||||
PyMODINIT_FUNC PyInit_C()
|
||||
#endif
|
||||
{
|
||||
#if PY_MAJOR_VERSION == 2
|
||||
ASSERT_TRUE(module = Py_InitModule("torch.C", TorchMethods));
|
||||
#else
|
||||
ASSERT_TRUE(module = PyModule_Create(&torchmodule));
|
||||
|
||||
#endif
|
||||
ASSERT_TRUE(tensor_classes = PySet_New(NULL));
|
||||
ASSERT_TRUE(PyObject_SetAttrString(module, "_tensorclasses", tensor_classes) == 0);
|
||||
|
||||
@ -363,5 +376,10 @@ PyMODINIT_FUNC PyInit_C()
|
||||
|
||||
updateErrorHandlers();
|
||||
|
||||
#if PY_MAJOR_VERSION == 2
|
||||
#else
|
||||
return module;
|
||||
#endif
|
||||
}
|
||||
|
||||
#undef ASSERT_TRUE
|
||||
|
@ -1,6 +1,15 @@
|
||||
#include <stdbool.h>
|
||||
#include <TH/TH.h>
|
||||
|
||||
// Back-compatibility macros, Thanks to http://cx-oracle.sourceforge.net/
|
||||
// define PyInt_* macros for Python 3.x
|
||||
#ifndef PyInt_Check
|
||||
#define PyInt_Check PyLong_Check
|
||||
#define PyInt_FromLong PyLong_FromLong
|
||||
#define PyInt_AsLong PyLong_AsLong
|
||||
#define PyInt_Type PyLong_Type
|
||||
#endif
|
||||
|
||||
#include "Exceptions.h"
|
||||
#include "utils.h"
|
||||
|
||||
|
@ -9,6 +9,7 @@ PyObject * THPStorage_(newObject)(THStorage *ptr)
|
||||
// TODO: error checking
|
||||
PyObject *args = PyTuple_New(0);
|
||||
PyObject *kwargs = Py_BuildValue("{s:N}", "cdata", PyLong_FromVoidPtr(ptr));
|
||||
|
||||
PyObject *instance = PyObject_Call(THPStorageClass, args, kwargs);
|
||||
Py_DECREF(args);
|
||||
Py_DECREF(kwargs);
|
||||
@ -30,17 +31,17 @@ static PyObject * THPStorage_(pynew)(PyTypeObject *type, PyObject *args, PyObjec
|
||||
{
|
||||
HANDLE_TH_ERRORS
|
||||
static const char *keywords[] = {"cdata", NULL};
|
||||
PyObject *number_arg = NULL;
|
||||
if (!PyArg_ParseTupleAndKeywords(args, kwargs, "|O!", (char **)keywords, &PyLong_Type, &number_arg))
|
||||
void* number_arg = NULL;
|
||||
if (!PyArg_ParseTupleAndKeywords(args, kwargs, "|O&", (char **)keywords,
|
||||
THPUtils_getLong, &number_arg))
|
||||
return NULL;
|
||||
|
||||
THPStorage *self = (THPStorage *)type->tp_alloc(type, 0);
|
||||
if (self != NULL) {
|
||||
if (kwargs) {
|
||||
self->cdata = (THStorage*)PyLong_AsVoidPtr(number_arg);
|
||||
self->cdata = (THStorage*)number_arg;
|
||||
THStorage_(retain)(self->cdata);
|
||||
} else if (/* !kwargs && */ number_arg) {
|
||||
self->cdata = THStorage_(newWithSize)(PyLong_AsLong(number_arg));
|
||||
self->cdata = THStorage_(newWithSize)((long) number_arg);
|
||||
} else {
|
||||
self->cdata = THStorage_(new)();
|
||||
}
|
||||
@ -66,8 +67,9 @@ static PyObject * THPStorage_(get)(THPStorage *self, PyObject *index)
|
||||
{
|
||||
HANDLE_TH_ERRORS
|
||||
/* Integer index */
|
||||
if (PyLong_Check(index)) {
|
||||
long nindex = PyLong_AsLong(index);
|
||||
long nindex;
|
||||
if ((PyLong_Check(index) || PyInt_Check(index))
|
||||
&& THPUtils_getLong(index, &nindex) == 1 ) {
|
||||
if (nindex < 0)
|
||||
nindex += THStorage_(size)(self->cdata);
|
||||
#if defined(TH_REAL_IS_FLOAT) || defined(TH_REAL_IS_DOUBLE)
|
||||
@ -89,7 +91,11 @@ static PyObject * THPStorage_(get)(THPStorage *self, PyObject *index)
|
||||
THStorage *new_storage = THStorage_(newWithData)(new_data, slicelength);
|
||||
return THPStorage_(newObject)(new_storage);
|
||||
}
|
||||
PyErr_SetString(PyExc_RuntimeError, "Only indexing with integers and slices supported");
|
||||
char err_string[512];
|
||||
snprintf (err_string, 512,
|
||||
"%s %s", "Only indexing with integers and slices supported, but got type: ",
|
||||
index->ob_type->tp_name);
|
||||
PyErr_SetString(PyExc_RuntimeError, err_string);
|
||||
return NULL;
|
||||
END_HANDLE_TH_ERRORS
|
||||
}
|
||||
@ -101,8 +107,10 @@ static int THPStorage_(set)(THPStorage *self, PyObject *index, PyObject *value)
|
||||
if (!THPUtils_(parseReal)(value, &rvalue))
|
||||
return -1;
|
||||
|
||||
if (PyLong_Check(index)) {
|
||||
THStorage_(set)(self->cdata, PyLong_AsSize_t(index), rvalue);
|
||||
long nindex;
|
||||
if ((PyLong_Check(index) || PyInt_Check(index))
|
||||
&& THPUtils_getLong(index, &nindex) == 1) {
|
||||
THStorage_(set)(self->cdata, nindex, rvalue);
|
||||
return 0;
|
||||
} else if (PySlice_Check(index)) {
|
||||
Py_ssize_t start, stop, len;
|
||||
@ -114,7 +122,11 @@ static int THPStorage_(set)(THPStorage *self, PyObject *index, PyObject *value)
|
||||
THStorage_(set)(self->cdata, start, rvalue);
|
||||
return 0;
|
||||
}
|
||||
PyErr_SetString(PyExc_RuntimeError, "Only indexing with integers and slices supported at the moment");
|
||||
char err_string[512];
|
||||
snprintf (err_string, 512, "%s %s",
|
||||
"Only indexing with integers and slices supported, but got type: ",
|
||||
index->ob_type->tp_name);
|
||||
PyErr_SetString(PyExc_RuntimeError, err_string);
|
||||
return -1;
|
||||
END_HANDLE_TH_ERRORS_RET(-1)
|
||||
}
|
||||
|
@ -40,7 +40,7 @@ static PyObject * THPStorage_(resize)(THPStorage *self, PyObject *number_arg)
|
||||
HANDLE_TH_ERRORS
|
||||
if (!PyLong_Check(number_arg))
|
||||
return NULL;
|
||||
size_t newsize = PyLong_AsSize_t(number_arg);
|
||||
long newsize = PyLong_AsLong(number_arg);
|
||||
if (PyErr_Occurred())
|
||||
return NULL;
|
||||
THStorage_(resize)(self->cdata, newsize);
|
||||
|
@ -7,7 +7,7 @@ PyObject * THPTensor_(newObject)(THTensor *ptr)
|
||||
{
|
||||
// TODO: error checking
|
||||
PyObject *args = PyTuple_New(0);
|
||||
PyObject *kwargs = Py_BuildValue("{s:N}", "cdata", PyLong_FromVoidPtr(ptr));
|
||||
PyObject *kwargs = Py_BuildValue("{s:K}", "cdata", (unsigned long long) ptr);
|
||||
PyObject *instance = PyObject_Call(THPTensorClass, args, kwargs);
|
||||
Py_DECREF(args);
|
||||
Py_DECREF(kwargs);
|
||||
@ -41,9 +41,11 @@ static PyObject * THPTensor_(pynew)(PyTypeObject *type, PyObject *args, PyObject
|
||||
THTensor *cdata_ptr = NULL;
|
||||
// If not, try to parse integers
|
||||
#define ERRMSG ";Expected torch.LongStorage or up to 4 integers as arguments"
|
||||
// TODO: check that cdata_ptr is a keyword arg
|
||||
if (!storage_obj &&
|
||||
!PyArg_ParseTupleAndKeywords(args, kwargs, "|LLLL$k" ERRMSG, (char**)keywords,
|
||||
!PyArg_ParseTupleAndKeywords(args, kwargs, "|LLLLL" ERRMSG, (char**)keywords,
|
||||
&sizes[0], &sizes[1], &sizes[2], &sizes[3], &cdata_ptr))
|
||||
#undef ERRMSG
|
||||
return NULL;
|
||||
|
||||
THPTensor *self = (THPTensor *)type->tp_alloc(type, 0);
|
||||
@ -64,21 +66,24 @@ static PyObject * THPTensor_(pynew)(PyTypeObject *type, PyObject *args, PyObject
|
||||
END_HANDLE_TH_ERRORS
|
||||
}
|
||||
|
||||
#define INDEX_LONG(DIM, IDX_VARIABLE, TENSOR_VARIABLE, CASE_1D, CASE_MD) \
|
||||
long idx = PyLong_AsLong(IDX_VARIABLE); \
|
||||
long dimsize = THTensor_(size)(TENSOR_VARIABLE, DIM); \
|
||||
idx = (idx < 0) ? dimsize + idx : idx; \
|
||||
\
|
||||
THArgCheck(dimsize > 0, 1, "empty tensor"); \
|
||||
THArgCheck(idx >= 0 && idx < dimsize, 2, "out of range"); \
|
||||
\
|
||||
if(THTensor_(nDimension)(TENSOR_VARIABLE) == 1) { \
|
||||
CASE_1D; \
|
||||
} else { \
|
||||
CASE_MD; \
|
||||
}
|
||||
#define GET_PTR_1D(t, idx) \
|
||||
#define INDEX_LONG(DIM, IDX_VARIABLE, TENSOR_VARIABLE, CASE_1D, CASE_MD) \
|
||||
long idx; \
|
||||
THPUtils_getLong(IDX_VARIABLE, &idx); \
|
||||
long dimsize = THTensor_(size)(TENSOR_VARIABLE, DIM); \
|
||||
idx = (idx < 0) ? dimsize + idx : idx; \
|
||||
\
|
||||
THArgCheck(dimsize > 0, 1, "empty tensor"); \
|
||||
THArgCheck(idx >= 0 && idx < dimsize, 2, "out of range"); \
|
||||
\
|
||||
if(THTensor_(nDimension)(TENSOR_VARIABLE) == 1) { \
|
||||
CASE_1D; \
|
||||
} else { \
|
||||
CASE_MD; \
|
||||
}
|
||||
|
||||
#define GET_PTR_1D(t, idx) \
|
||||
t->storage->data + t->storageOffset + t->stride[0] * idx;
|
||||
|
||||
static bool THPTensor_(_index)(THPTensor *self, PyObject *index,
|
||||
THTensor * &tresult, real * &rresult)
|
||||
{
|
||||
@ -86,7 +91,7 @@ static bool THPTensor_(_index)(THPTensor *self, PyObject *index,
|
||||
rresult = NULL;
|
||||
try {
|
||||
// Indexing with an integer
|
||||
if(PyLong_Check(index)) {
|
||||
if(PyLong_Check(index) || PyInt_Check(index)) {
|
||||
THTensor *self_t = self->cdata;
|
||||
INDEX_LONG(0, index, self_t,
|
||||
// 1D tensor
|
||||
@ -97,8 +102,9 @@ static bool THPTensor_(_index)(THPTensor *self, PyObject *index,
|
||||
)
|
||||
// Indexing with a single element tuple
|
||||
} else if (PyTuple_Check(index) &&
|
||||
PyTuple_Size(index) == 1 &&
|
||||
PyLong_Check(PyTuple_GET_ITEM(index, 0))) {
|
||||
PyTuple_Size(index) == 1 &&
|
||||
(PyLong_Check(PyTuple_GET_ITEM(index, 0))
|
||||
|| PyInt_Check(PyTuple_GET_ITEM(index, 0)))) {
|
||||
PyObject *index_obj = PyTuple_GET_ITEM(index, 0);
|
||||
tresult = THTensor_(newWithTensor)(self->cdata);
|
||||
INDEX_LONG(0, index_obj, tresult,
|
||||
@ -121,7 +127,7 @@ static bool THPTensor_(_index)(THPTensor *self, PyObject *index,
|
||||
|
||||
for(int dim = 0; dim < PyTuple_Size(index); dim++) {
|
||||
PyObject *dimidx = PyTuple_GET_ITEM(index, dim);
|
||||
if(PyLong_Check(dimidx)) {
|
||||
if(PyLong_Check(dimidx) || PyInt_Check(dimidx)) {
|
||||
INDEX_LONG(t_dim, dimidx, tresult,
|
||||
// 1D tensor
|
||||
rresult = GET_PTR_1D(tresult, idx);
|
||||
@ -132,7 +138,9 @@ static bool THPTensor_(_index)(THPTensor *self, PyObject *index,
|
||||
THTensor_(select)(tresult, NULL, t_dim, idx)
|
||||
)
|
||||
} else if (PyTuple_Check(dimidx)) {
|
||||
if (PyTuple_Size(dimidx) != 1 || !PyLong_Check(PyTuple_GET_ITEM(dimidx, 0))) {
|
||||
if (PyTuple_Size(dimidx) != 1
|
||||
|| !(PyLong_Check(PyTuple_GET_ITEM(dimidx, 0))
|
||||
|| PyInt_Check(PyTuple_GET_ITEM(dimidx, 0)))) {
|
||||
PyErr_SetString(PyExc_RuntimeError, "Expected a single integer");
|
||||
return false;
|
||||
}
|
||||
@ -177,7 +185,11 @@ static PyObject * THPTensor_(get)(THPTensor *self, PyObject *index)
|
||||
return THPTensor_(newObject)(tresult);
|
||||
if (rresult)
|
||||
return THPUtils_(newReal)(*rresult);
|
||||
PyErr_SetString(PyExc_RuntimeError, "Unknown exception");
|
||||
char err_string[512];
|
||||
snprintf (err_string, 512,
|
||||
"%s %s", "Unknown exception in THPTensor_(get). Index type is: ",
|
||||
index->ob_type->tp_name);
|
||||
PyErr_SetString(PyExc_RuntimeError, err_string);
|
||||
return NULL;
|
||||
}
|
||||
END_HANDLE_TH_ERRORS
|
||||
@ -312,7 +324,7 @@ PyTypeObject THPTensorStatelessType = {
|
||||
0, /* tp_print */
|
||||
0, /* tp_getattr */
|
||||
0, /* tp_setattr */
|
||||
0, /* tp_reserved */
|
||||
0, /* tp_reserved / tp_compare */
|
||||
0, /* tp_repr */
|
||||
0, /* tp_as_number */
|
||||
0, /* tp_as_sequence */
|
||||
@ -342,6 +354,13 @@ PyTypeObject THPTensorStatelessType = {
|
||||
0, /* tp_init */
|
||||
0, /* tp_alloc */
|
||||
0, /* tp_new */
|
||||
0, /* tp_free */
|
||||
0, /* tp_is_gc */
|
||||
0, /* tp_bases */
|
||||
0, /* tp_mro */
|
||||
0, /* tp_cache */
|
||||
0, /* tp_subclasses */
|
||||
0, /* tp_weaklist */
|
||||
};
|
||||
|
||||
bool THPTensor_(init)(PyObject *module)
|
||||
@ -353,6 +372,7 @@ bool THPTensor_(init)(PyObject *module)
|
||||
THPTensorStatelessType.tp_new = PyType_GenericNew;
|
||||
if (PyType_Ready(&THPTensorStatelessType) < 0)
|
||||
return false;
|
||||
|
||||
PyModule_AddObject(module, THPTensorBaseStr, (PyObject *)&THPTensorType);
|
||||
return true;
|
||||
}
|
||||
|
@ -5,7 +5,14 @@
|
||||
bool THPUtils_(parseSlice)(PyObject *slice, Py_ssize_t len, Py_ssize_t *ostart, Py_ssize_t *ostop, Py_ssize_t *oslicelength)
|
||||
{
|
||||
Py_ssize_t start, stop, step, slicelength;
|
||||
if (PySlice_GetIndicesEx(slice, len, &start, &stop, &step, &slicelength) < 0) {
|
||||
if (PySlice_GetIndicesEx(
|
||||
// https://bugsfiles.kde.org/attachment.cgi?id=61186
|
||||
#if PY_VERSION_HEX >= 0x03020000
|
||||
slice,
|
||||
#else
|
||||
(PySliceObject *)slice,
|
||||
#endif
|
||||
len, &start, &stop, &step, &slicelength) < 0) {
|
||||
PyErr_SetString(PyExc_RuntimeError, "Got an invalid slice");
|
||||
return false;
|
||||
}
|
||||
@ -24,11 +31,16 @@ bool THPUtils_(parseReal)(PyObject *value, real *result)
|
||||
{
|
||||
if (PyLong_Check(value)) {
|
||||
*result = (real)PyLong_AsLongLong(value);
|
||||
} else if (PyInt_Check(value)) {
|
||||
*result = (real)PyInt_AsLong(value);
|
||||
} else if (PyFloat_Check(value)) {
|
||||
*result = (real)PyFloat_AsDouble(value);
|
||||
} else {
|
||||
// TODO: meaningful error
|
||||
PyErr_SetString(PyExc_RuntimeError, "Unrecognized object");
|
||||
char err_string[512];
|
||||
snprintf (err_string, 512, "%s %s",
|
||||
"parseReal expected long or float, but got type: ",
|
||||
value->ob_type->tp_name);
|
||||
PyErr_SetString(PyExc_RuntimeError, err_string);
|
||||
return false;
|
||||
}
|
||||
return true;
|
||||
@ -36,7 +48,7 @@ bool THPUtils_(parseReal)(PyObject *value, real *result)
|
||||
|
||||
bool THPUtils_(checkReal)(PyObject *value)
|
||||
{
|
||||
return PyFloat_Check(value) || PyLong_Check(value);
|
||||
return PyFloat_Check(value) || PyLong_Check(value) || PyInt_Check(value);
|
||||
}
|
||||
|
||||
PyObject * THPUtils_(newReal)(real value)
|
||||
|
@ -3,3 +3,19 @@
|
||||
|
||||
#include "generic/utils.cpp"
|
||||
#include <TH/THGenerateAllTypes.h>
|
||||
|
||||
int THPUtils_getLong(PyObject *index, long *result) {
|
||||
if (PyLong_Check(index)) {
|
||||
*result = PyLong_AsLong(index);
|
||||
} else if (PyInt_Check(index)) {
|
||||
*result = PyInt_AsLong(index);
|
||||
} else {
|
||||
char err_string[512];
|
||||
snprintf (err_string, 512, "%s %s",
|
||||
"getLong expected int or long, but got type: ",
|
||||
index->ob_type->tp_name);
|
||||
PyErr_SetString(PyExc_RuntimeError, err_string);
|
||||
return 0;
|
||||
}
|
||||
return 1;
|
||||
}
|
||||
|
@ -6,5 +6,7 @@
|
||||
#include "generic/utils.h"
|
||||
#include <TH/THGenerateAllTypes.h>
|
||||
|
||||
int THPUtils_getLong(PyObject *index, long *result);
|
||||
|
||||
#endif
|
||||
|
||||
|
Reference in New Issue
Block a user