mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-21 13:44:15 +08:00
python 2 support
This commit is contained in:
1
.gitignore
vendored
1
.gitignore
vendored
@ -4,3 +4,4 @@ torch.egg-info/
|
|||||||
*/**/__pycache__
|
*/**/__pycache__
|
||||||
torch/__init__.py
|
torch/__init__.py
|
||||||
torch/csrc/generic/TensorMethods.cpp
|
torch/csrc/generic/TensorMethods.cpp
|
||||||
|
*/**/*.pyc
|
20
setup.py
20
setup.py
@ -1,5 +1,7 @@
|
|||||||
from setuptools import setup, Extension
|
from setuptools import setup, Extension
|
||||||
|
from os.path import expanduser
|
||||||
from tools.cwrap import cwrap
|
from tools.cwrap import cwrap
|
||||||
|
import platform
|
||||||
|
|
||||||
################################################################################
|
################################################################################
|
||||||
# Generate __init__.py from templates
|
# Generate __init__.py from templates
|
||||||
@ -49,6 +51,16 @@ for src in cwrap_src:
|
|||||||
################################################################################
|
################################################################################
|
||||||
# Declare the package
|
# Declare the package
|
||||||
################################################################################
|
################################################################################
|
||||||
|
extra_link_args = []
|
||||||
|
|
||||||
|
# TODO: remove and properly submodule TH in the repo itself
|
||||||
|
th_path = expanduser("~/torch/install/")
|
||||||
|
th_header_path = th_path + "include"
|
||||||
|
th_lib_path = th_path + "lib"
|
||||||
|
if platform.system() == 'Darwin':
|
||||||
|
extra_link_args.append('-L' + th_lib_path)
|
||||||
|
extra_link_args.append('-Wl,-rpath,' + th_lib_path)
|
||||||
|
|
||||||
sources = [
|
sources = [
|
||||||
"torch/csrc/Module.cpp",
|
"torch/csrc/Module.cpp",
|
||||||
"torch/csrc/Tensor.cpp",
|
"torch/csrc/Tensor.cpp",
|
||||||
@ -59,9 +71,13 @@ C = Extension("torch.C",
|
|||||||
libraries=['TH'],
|
libraries=['TH'],
|
||||||
sources=sources,
|
sources=sources,
|
||||||
language='c++',
|
language='c++',
|
||||||
include_dirs=["torch/csrc"])
|
include_dirs=(["torch/csrc", th_header_path]),
|
||||||
|
extra_link_args = extra_link_args,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
setup(name="torch", version="0.1",
|
setup(name="torch", version="0.1",
|
||||||
ext_modules=[C],
|
ext_modules=[C],
|
||||||
packages=['torch'])
|
packages=['torch'],
|
||||||
|
)
|
||||||
|
24
test/smoke.py
Normal file
24
test/smoke.py
Normal file
@ -0,0 +1,24 @@
|
|||||||
|
import torch
|
||||||
|
|
||||||
|
a = torch.FloatTensor(4, 3)
|
||||||
|
b = torch.FloatTensor(3, 4)
|
||||||
|
|
||||||
|
a.add(b)
|
||||||
|
|
||||||
|
c = a.storage()
|
||||||
|
|
||||||
|
d = a.select(0, 1)
|
||||||
|
|
||||||
|
print(c)
|
||||||
|
print(a)
|
||||||
|
print(b)
|
||||||
|
print(d)
|
||||||
|
|
||||||
|
|
||||||
|
a.fill(0)
|
||||||
|
|
||||||
|
print(a[1])
|
||||||
|
|
||||||
|
print(a.ge(long(0)))
|
||||||
|
print(a.ge(0))
|
||||||
|
|
0
tools/__init__.py
Normal file
0
tools/__init__.py
Normal file
@ -151,7 +151,7 @@ RETURN_WRAPPER = {
|
|||||||
'THStorage': Template('return THPStorage_(newObject)($expr)'),
|
'THStorage': Template('return THPStorage_(newObject)($expr)'),
|
||||||
'THLongStorage': Template('return THPLongStorage_newObject($expr)'),
|
'THLongStorage': Template('return THPLongStorage_newObject($expr)'),
|
||||||
'bool': Template('return PyBool_FromLong($expr)'),
|
'bool': Template('return PyBool_FromLong($expr)'),
|
||||||
'long': Template('return PyLong_FromLong($expr)'),
|
'long': Template('return PyInt_FromLong($expr)'),
|
||||||
'double': Template('return PyFloat_FromDouble($expr)'),
|
'double': Template('return PyFloat_FromDouble($expr)'),
|
||||||
'self': Template('$expr; Py_INCREF(self); return (PyObject*)self'),
|
'self': Template('$expr; Py_INCREF(self); return (PyObject*)self'),
|
||||||
# TODO
|
# TODO
|
||||||
@ -397,16 +397,19 @@ def argfilter():
|
|||||||
CONSTANT arguments are literals.
|
CONSTANT arguments are literals.
|
||||||
Repeated arguments do not need to be specified twice.
|
Repeated arguments do not need to be specified twice.
|
||||||
"""
|
"""
|
||||||
provided = set()
|
# use class rather than nonlocal to maintain 2.7 compat
|
||||||
|
# see http://stackoverflow.com/questions/3190706/nonlocal-keyword-in-python-2-x
|
||||||
|
# TODO: check this works
|
||||||
|
class context:
|
||||||
|
provided = set()
|
||||||
def is_already_provided(arg):
|
def is_already_provided(arg):
|
||||||
nonlocal provided
|
|
||||||
ret = False
|
ret = False
|
||||||
ret |= arg.name == 'self'
|
ret |= arg.name == 'self'
|
||||||
ret |= arg.name == '_res_new'
|
ret |= arg.name == '_res_new'
|
||||||
ret |= arg.type == 'CONSTANT'
|
ret |= arg.type == 'CONSTANT'
|
||||||
ret |= arg.type == 'EXPRESSION'
|
ret |= arg.type == 'EXPRESSION'
|
||||||
ret |= arg.name in provided
|
ret |= arg.name in context.provided
|
||||||
provided.add(arg.name)
|
context.provided.add(arg.name)
|
||||||
return ret
|
return ret
|
||||||
return is_already_provided
|
return is_already_provided
|
||||||
|
|
||||||
|
@ -7,5 +7,5 @@ class RealStorage(RealStorageBase):
|
|||||||
return str(self)
|
return str(self)
|
||||||
|
|
||||||
def __iter__(self):
|
def __iter__(self):
|
||||||
return map(lambda i: self[i], range(self.size()))
|
return iter(map(lambda i: self[i], range(self.size())))
|
||||||
|
|
||||||
|
@ -42,4 +42,4 @@ class RealTensor(RealTensorBase):
|
|||||||
return _printing.printTensor(self)
|
return _printing.printTensor(self)
|
||||||
|
|
||||||
def __iter__(self):
|
def __iter__(self):
|
||||||
return map(lambda i: self.select(0, i), range(self.size(0)))
|
return iter(map(lambda i: self.select(0, i), range(self.size(0))))
|
||||||
|
@ -5,7 +5,11 @@
|
|||||||
|
|
||||||
#include "THP.h"
|
#include "THP.h"
|
||||||
|
|
||||||
|
#if PY_MAJOR_VERSION == 2
|
||||||
|
#define ASSERT_TRUE(cmd) if (!(cmd)) {PyErr_SetString(PyExc_ImportError, "initialization error"); return;}
|
||||||
|
#else
|
||||||
#define ASSERT_TRUE(cmd) if (!(cmd)) return NULL
|
#define ASSERT_TRUE(cmd) if (!(cmd)) return NULL
|
||||||
|
#endif
|
||||||
|
|
||||||
static PyObject* module;
|
static PyObject* module;
|
||||||
static PyObject* tensor_classes;
|
static PyObject* tensor_classes;
|
||||||
@ -34,21 +38,21 @@ static bool THPModule_loadClasses(PyObject *self)
|
|||||||
PyObject *torch_module = PyImport_ImportModule("torch");
|
PyObject *torch_module = PyImport_ImportModule("torch");
|
||||||
PyObject* module_dict = PyModule_GetDict(torch_module);
|
PyObject* module_dict = PyModule_GetDict(torch_module);
|
||||||
|
|
||||||
THPDoubleStorageClass = PyMapping_GetItemString(module_dict, "DoubleStorage");
|
THPDoubleStorageClass = PyMapping_GetItemString(module_dict,(char*)"DoubleStorage");
|
||||||
THPFloatStorageClass = PyMapping_GetItemString(module_dict, "FloatStorage");
|
THPFloatStorageClass = PyMapping_GetItemString(module_dict,(char*)"FloatStorage");
|
||||||
THPLongStorageClass = PyMapping_GetItemString(module_dict, "LongStorage");
|
THPLongStorageClass = PyMapping_GetItemString(module_dict,(char*)"LongStorage");
|
||||||
THPIntStorageClass = PyMapping_GetItemString(module_dict, "IntStorage");
|
THPIntStorageClass = PyMapping_GetItemString(module_dict,(char*)"IntStorage");
|
||||||
THPShortStorageClass = PyMapping_GetItemString(module_dict, "ShortStorage");
|
THPShortStorageClass = PyMapping_GetItemString(module_dict,(char*)"ShortStorage");
|
||||||
THPCharStorageClass = PyMapping_GetItemString(module_dict, "CharStorage");
|
THPCharStorageClass = PyMapping_GetItemString(module_dict,(char*)"CharStorage");
|
||||||
THPByteStorageClass = PyMapping_GetItemString(module_dict, "ByteStorage");
|
THPByteStorageClass = PyMapping_GetItemString(module_dict,(char*)"ByteStorage");
|
||||||
|
|
||||||
THPDoubleTensorClass = PyMapping_GetItemString(module_dict, "DoubleTensor");
|
THPDoubleTensorClass = PyMapping_GetItemString(module_dict,(char*)"DoubleTensor");
|
||||||
THPFloatTensorClass = PyMapping_GetItemString(module_dict, "FloatTensor");
|
THPFloatTensorClass = PyMapping_GetItemString(module_dict,(char*)"FloatTensor");
|
||||||
THPLongTensorClass = PyMapping_GetItemString(module_dict, "LongTensor");
|
THPLongTensorClass = PyMapping_GetItemString(module_dict,(char*)"LongTensor");
|
||||||
THPIntTensorClass = PyMapping_GetItemString(module_dict, "IntTensor");
|
THPIntTensorClass = PyMapping_GetItemString(module_dict,(char*)"IntTensor");
|
||||||
THPShortTensorClass = PyMapping_GetItemString(module_dict, "ShortTensor");
|
THPShortTensorClass = PyMapping_GetItemString(module_dict,(char*)"ShortTensor");
|
||||||
THPCharTensorClass = PyMapping_GetItemString(module_dict, "CharTensor");
|
THPCharTensorClass = PyMapping_GetItemString(module_dict,(char*)"CharTensor");
|
||||||
THPByteTensorClass = PyMapping_GetItemString(module_dict, "ByteTensor");
|
THPByteTensorClass = PyMapping_GetItemString(module_dict,(char*)"ByteTensor");
|
||||||
PySet_Add(tensor_classes, THPDoubleTensorClass);
|
PySet_Add(tensor_classes, THPDoubleTensorClass);
|
||||||
PySet_Add(tensor_classes, THPFloatTensorClass);
|
PySet_Add(tensor_classes, THPFloatTensorClass);
|
||||||
PySet_Add(tensor_classes, THPLongTensorClass);
|
PySet_Add(tensor_classes, THPLongTensorClass);
|
||||||
@ -314,6 +318,7 @@ static PyMethodDef TorchMethods[] = {
|
|||||||
{NULL, NULL, 0, NULL}
|
{NULL, NULL, 0, NULL}
|
||||||
};
|
};
|
||||||
|
|
||||||
|
#if PY_MAJOR_VERSION != 2
|
||||||
static struct PyModuleDef torchmodule = {
|
static struct PyModuleDef torchmodule = {
|
||||||
PyModuleDef_HEAD_INIT,
|
PyModuleDef_HEAD_INIT,
|
||||||
"torch.C",
|
"torch.C",
|
||||||
@ -321,6 +326,7 @@ static struct PyModuleDef torchmodule = {
|
|||||||
-1,
|
-1,
|
||||||
TorchMethods
|
TorchMethods
|
||||||
};
|
};
|
||||||
|
#endif
|
||||||
|
|
||||||
static void errorHandler(const char *msg, void *data)
|
static void errorHandler(const char *msg, void *data)
|
||||||
{
|
{
|
||||||
@ -338,10 +344,17 @@ static void updateErrorHandlers()
|
|||||||
THSetArgErrorHandler(errorHandlerArg, NULL);
|
THSetArgErrorHandler(errorHandlerArg, NULL);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#if PY_MAJOR_VERSION == 2
|
||||||
|
PyMODINIT_FUNC initC()
|
||||||
|
#else
|
||||||
PyMODINIT_FUNC PyInit_C()
|
PyMODINIT_FUNC PyInit_C()
|
||||||
|
#endif
|
||||||
{
|
{
|
||||||
|
#if PY_MAJOR_VERSION == 2
|
||||||
|
ASSERT_TRUE(module = Py_InitModule("torch.C", TorchMethods));
|
||||||
|
#else
|
||||||
ASSERT_TRUE(module = PyModule_Create(&torchmodule));
|
ASSERT_TRUE(module = PyModule_Create(&torchmodule));
|
||||||
|
#endif
|
||||||
ASSERT_TRUE(tensor_classes = PySet_New(NULL));
|
ASSERT_TRUE(tensor_classes = PySet_New(NULL));
|
||||||
ASSERT_TRUE(PyObject_SetAttrString(module, "_tensorclasses", tensor_classes) == 0);
|
ASSERT_TRUE(PyObject_SetAttrString(module, "_tensorclasses", tensor_classes) == 0);
|
||||||
|
|
||||||
@ -363,5 +376,10 @@ PyMODINIT_FUNC PyInit_C()
|
|||||||
|
|
||||||
updateErrorHandlers();
|
updateErrorHandlers();
|
||||||
|
|
||||||
|
#if PY_MAJOR_VERSION == 2
|
||||||
|
#else
|
||||||
return module;
|
return module;
|
||||||
|
#endif
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#undef ASSERT_TRUE
|
||||||
|
@ -1,6 +1,15 @@
|
|||||||
#include <stdbool.h>
|
#include <stdbool.h>
|
||||||
#include <TH/TH.h>
|
#include <TH/TH.h>
|
||||||
|
|
||||||
|
// Back-compatibility macros, Thanks to http://cx-oracle.sourceforge.net/
|
||||||
|
// define PyInt_* macros for Python 3.x
|
||||||
|
#ifndef PyInt_Check
|
||||||
|
#define PyInt_Check PyLong_Check
|
||||||
|
#define PyInt_FromLong PyLong_FromLong
|
||||||
|
#define PyInt_AsLong PyLong_AsLong
|
||||||
|
#define PyInt_Type PyLong_Type
|
||||||
|
#endif
|
||||||
|
|
||||||
#include "Exceptions.h"
|
#include "Exceptions.h"
|
||||||
#include "utils.h"
|
#include "utils.h"
|
||||||
|
|
||||||
|
@ -9,6 +9,7 @@ PyObject * THPStorage_(newObject)(THStorage *ptr)
|
|||||||
// TODO: error checking
|
// TODO: error checking
|
||||||
PyObject *args = PyTuple_New(0);
|
PyObject *args = PyTuple_New(0);
|
||||||
PyObject *kwargs = Py_BuildValue("{s:N}", "cdata", PyLong_FromVoidPtr(ptr));
|
PyObject *kwargs = Py_BuildValue("{s:N}", "cdata", PyLong_FromVoidPtr(ptr));
|
||||||
|
|
||||||
PyObject *instance = PyObject_Call(THPStorageClass, args, kwargs);
|
PyObject *instance = PyObject_Call(THPStorageClass, args, kwargs);
|
||||||
Py_DECREF(args);
|
Py_DECREF(args);
|
||||||
Py_DECREF(kwargs);
|
Py_DECREF(kwargs);
|
||||||
@ -30,17 +31,17 @@ static PyObject * THPStorage_(pynew)(PyTypeObject *type, PyObject *args, PyObjec
|
|||||||
{
|
{
|
||||||
HANDLE_TH_ERRORS
|
HANDLE_TH_ERRORS
|
||||||
static const char *keywords[] = {"cdata", NULL};
|
static const char *keywords[] = {"cdata", NULL};
|
||||||
PyObject *number_arg = NULL;
|
void* number_arg = NULL;
|
||||||
if (!PyArg_ParseTupleAndKeywords(args, kwargs, "|O!", (char **)keywords, &PyLong_Type, &number_arg))
|
if (!PyArg_ParseTupleAndKeywords(args, kwargs, "|O&", (char **)keywords,
|
||||||
|
THPUtils_getLong, &number_arg))
|
||||||
return NULL;
|
return NULL;
|
||||||
|
|
||||||
THPStorage *self = (THPStorage *)type->tp_alloc(type, 0);
|
THPStorage *self = (THPStorage *)type->tp_alloc(type, 0);
|
||||||
if (self != NULL) {
|
if (self != NULL) {
|
||||||
if (kwargs) {
|
if (kwargs) {
|
||||||
self->cdata = (THStorage*)PyLong_AsVoidPtr(number_arg);
|
self->cdata = (THStorage*)number_arg;
|
||||||
THStorage_(retain)(self->cdata);
|
THStorage_(retain)(self->cdata);
|
||||||
} else if (/* !kwargs && */ number_arg) {
|
} else if (/* !kwargs && */ number_arg) {
|
||||||
self->cdata = THStorage_(newWithSize)(PyLong_AsLong(number_arg));
|
self->cdata = THStorage_(newWithSize)((long) number_arg);
|
||||||
} else {
|
} else {
|
||||||
self->cdata = THStorage_(new)();
|
self->cdata = THStorage_(new)();
|
||||||
}
|
}
|
||||||
@ -66,8 +67,9 @@ static PyObject * THPStorage_(get)(THPStorage *self, PyObject *index)
|
|||||||
{
|
{
|
||||||
HANDLE_TH_ERRORS
|
HANDLE_TH_ERRORS
|
||||||
/* Integer index */
|
/* Integer index */
|
||||||
if (PyLong_Check(index)) {
|
long nindex;
|
||||||
long nindex = PyLong_AsLong(index);
|
if ((PyLong_Check(index) || PyInt_Check(index))
|
||||||
|
&& THPUtils_getLong(index, &nindex) == 1 ) {
|
||||||
if (nindex < 0)
|
if (nindex < 0)
|
||||||
nindex += THStorage_(size)(self->cdata);
|
nindex += THStorage_(size)(self->cdata);
|
||||||
#if defined(TH_REAL_IS_FLOAT) || defined(TH_REAL_IS_DOUBLE)
|
#if defined(TH_REAL_IS_FLOAT) || defined(TH_REAL_IS_DOUBLE)
|
||||||
@ -89,7 +91,11 @@ static PyObject * THPStorage_(get)(THPStorage *self, PyObject *index)
|
|||||||
THStorage *new_storage = THStorage_(newWithData)(new_data, slicelength);
|
THStorage *new_storage = THStorage_(newWithData)(new_data, slicelength);
|
||||||
return THPStorage_(newObject)(new_storage);
|
return THPStorage_(newObject)(new_storage);
|
||||||
}
|
}
|
||||||
PyErr_SetString(PyExc_RuntimeError, "Only indexing with integers and slices supported");
|
char err_string[512];
|
||||||
|
snprintf (err_string, 512,
|
||||||
|
"%s %s", "Only indexing with integers and slices supported, but got type: ",
|
||||||
|
index->ob_type->tp_name);
|
||||||
|
PyErr_SetString(PyExc_RuntimeError, err_string);
|
||||||
return NULL;
|
return NULL;
|
||||||
END_HANDLE_TH_ERRORS
|
END_HANDLE_TH_ERRORS
|
||||||
}
|
}
|
||||||
@ -101,8 +107,10 @@ static int THPStorage_(set)(THPStorage *self, PyObject *index, PyObject *value)
|
|||||||
if (!THPUtils_(parseReal)(value, &rvalue))
|
if (!THPUtils_(parseReal)(value, &rvalue))
|
||||||
return -1;
|
return -1;
|
||||||
|
|
||||||
if (PyLong_Check(index)) {
|
long nindex;
|
||||||
THStorage_(set)(self->cdata, PyLong_AsSize_t(index), rvalue);
|
if ((PyLong_Check(index) || PyInt_Check(index))
|
||||||
|
&& THPUtils_getLong(index, &nindex) == 1) {
|
||||||
|
THStorage_(set)(self->cdata, nindex, rvalue);
|
||||||
return 0;
|
return 0;
|
||||||
} else if (PySlice_Check(index)) {
|
} else if (PySlice_Check(index)) {
|
||||||
Py_ssize_t start, stop, len;
|
Py_ssize_t start, stop, len;
|
||||||
@ -114,7 +122,11 @@ static int THPStorage_(set)(THPStorage *self, PyObject *index, PyObject *value)
|
|||||||
THStorage_(set)(self->cdata, start, rvalue);
|
THStorage_(set)(self->cdata, start, rvalue);
|
||||||
return 0;
|
return 0;
|
||||||
}
|
}
|
||||||
PyErr_SetString(PyExc_RuntimeError, "Only indexing with integers and slices supported at the moment");
|
char err_string[512];
|
||||||
|
snprintf (err_string, 512, "%s %s",
|
||||||
|
"Only indexing with integers and slices supported, but got type: ",
|
||||||
|
index->ob_type->tp_name);
|
||||||
|
PyErr_SetString(PyExc_RuntimeError, err_string);
|
||||||
return -1;
|
return -1;
|
||||||
END_HANDLE_TH_ERRORS_RET(-1)
|
END_HANDLE_TH_ERRORS_RET(-1)
|
||||||
}
|
}
|
||||||
|
@ -40,7 +40,7 @@ static PyObject * THPStorage_(resize)(THPStorage *self, PyObject *number_arg)
|
|||||||
HANDLE_TH_ERRORS
|
HANDLE_TH_ERRORS
|
||||||
if (!PyLong_Check(number_arg))
|
if (!PyLong_Check(number_arg))
|
||||||
return NULL;
|
return NULL;
|
||||||
size_t newsize = PyLong_AsSize_t(number_arg);
|
long newsize = PyLong_AsLong(number_arg);
|
||||||
if (PyErr_Occurred())
|
if (PyErr_Occurred())
|
||||||
return NULL;
|
return NULL;
|
||||||
THStorage_(resize)(self->cdata, newsize);
|
THStorage_(resize)(self->cdata, newsize);
|
||||||
|
@ -7,7 +7,7 @@ PyObject * THPTensor_(newObject)(THTensor *ptr)
|
|||||||
{
|
{
|
||||||
// TODO: error checking
|
// TODO: error checking
|
||||||
PyObject *args = PyTuple_New(0);
|
PyObject *args = PyTuple_New(0);
|
||||||
PyObject *kwargs = Py_BuildValue("{s:N}", "cdata", PyLong_FromVoidPtr(ptr));
|
PyObject *kwargs = Py_BuildValue("{s:K}", "cdata", (unsigned long long) ptr);
|
||||||
PyObject *instance = PyObject_Call(THPTensorClass, args, kwargs);
|
PyObject *instance = PyObject_Call(THPTensorClass, args, kwargs);
|
||||||
Py_DECREF(args);
|
Py_DECREF(args);
|
||||||
Py_DECREF(kwargs);
|
Py_DECREF(kwargs);
|
||||||
@ -41,9 +41,11 @@ static PyObject * THPTensor_(pynew)(PyTypeObject *type, PyObject *args, PyObject
|
|||||||
THTensor *cdata_ptr = NULL;
|
THTensor *cdata_ptr = NULL;
|
||||||
// If not, try to parse integers
|
// If not, try to parse integers
|
||||||
#define ERRMSG ";Expected torch.LongStorage or up to 4 integers as arguments"
|
#define ERRMSG ";Expected torch.LongStorage or up to 4 integers as arguments"
|
||||||
|
// TODO: check that cdata_ptr is a keyword arg
|
||||||
if (!storage_obj &&
|
if (!storage_obj &&
|
||||||
!PyArg_ParseTupleAndKeywords(args, kwargs, "|LLLL$k" ERRMSG, (char**)keywords,
|
!PyArg_ParseTupleAndKeywords(args, kwargs, "|LLLLL" ERRMSG, (char**)keywords,
|
||||||
&sizes[0], &sizes[1], &sizes[2], &sizes[3], &cdata_ptr))
|
&sizes[0], &sizes[1], &sizes[2], &sizes[3], &cdata_ptr))
|
||||||
|
#undef ERRMSG
|
||||||
return NULL;
|
return NULL;
|
||||||
|
|
||||||
THPTensor *self = (THPTensor *)type->tp_alloc(type, 0);
|
THPTensor *self = (THPTensor *)type->tp_alloc(type, 0);
|
||||||
@ -64,21 +66,24 @@ static PyObject * THPTensor_(pynew)(PyTypeObject *type, PyObject *args, PyObject
|
|||||||
END_HANDLE_TH_ERRORS
|
END_HANDLE_TH_ERRORS
|
||||||
}
|
}
|
||||||
|
|
||||||
#define INDEX_LONG(DIM, IDX_VARIABLE, TENSOR_VARIABLE, CASE_1D, CASE_MD) \
|
#define INDEX_LONG(DIM, IDX_VARIABLE, TENSOR_VARIABLE, CASE_1D, CASE_MD) \
|
||||||
long idx = PyLong_AsLong(IDX_VARIABLE); \
|
long idx; \
|
||||||
long dimsize = THTensor_(size)(TENSOR_VARIABLE, DIM); \
|
THPUtils_getLong(IDX_VARIABLE, &idx); \
|
||||||
idx = (idx < 0) ? dimsize + idx : idx; \
|
long dimsize = THTensor_(size)(TENSOR_VARIABLE, DIM); \
|
||||||
\
|
idx = (idx < 0) ? dimsize + idx : idx; \
|
||||||
THArgCheck(dimsize > 0, 1, "empty tensor"); \
|
\
|
||||||
THArgCheck(idx >= 0 && idx < dimsize, 2, "out of range"); \
|
THArgCheck(dimsize > 0, 1, "empty tensor"); \
|
||||||
\
|
THArgCheck(idx >= 0 && idx < dimsize, 2, "out of range"); \
|
||||||
if(THTensor_(nDimension)(TENSOR_VARIABLE) == 1) { \
|
\
|
||||||
CASE_1D; \
|
if(THTensor_(nDimension)(TENSOR_VARIABLE) == 1) { \
|
||||||
} else { \
|
CASE_1D; \
|
||||||
CASE_MD; \
|
} else { \
|
||||||
}
|
CASE_MD; \
|
||||||
#define GET_PTR_1D(t, idx) \
|
}
|
||||||
|
|
||||||
|
#define GET_PTR_1D(t, idx) \
|
||||||
t->storage->data + t->storageOffset + t->stride[0] * idx;
|
t->storage->data + t->storageOffset + t->stride[0] * idx;
|
||||||
|
|
||||||
static bool THPTensor_(_index)(THPTensor *self, PyObject *index,
|
static bool THPTensor_(_index)(THPTensor *self, PyObject *index,
|
||||||
THTensor * &tresult, real * &rresult)
|
THTensor * &tresult, real * &rresult)
|
||||||
{
|
{
|
||||||
@ -86,7 +91,7 @@ static bool THPTensor_(_index)(THPTensor *self, PyObject *index,
|
|||||||
rresult = NULL;
|
rresult = NULL;
|
||||||
try {
|
try {
|
||||||
// Indexing with an integer
|
// Indexing with an integer
|
||||||
if(PyLong_Check(index)) {
|
if(PyLong_Check(index) || PyInt_Check(index)) {
|
||||||
THTensor *self_t = self->cdata;
|
THTensor *self_t = self->cdata;
|
||||||
INDEX_LONG(0, index, self_t,
|
INDEX_LONG(0, index, self_t,
|
||||||
// 1D tensor
|
// 1D tensor
|
||||||
@ -97,8 +102,9 @@ static bool THPTensor_(_index)(THPTensor *self, PyObject *index,
|
|||||||
)
|
)
|
||||||
// Indexing with a single element tuple
|
// Indexing with a single element tuple
|
||||||
} else if (PyTuple_Check(index) &&
|
} else if (PyTuple_Check(index) &&
|
||||||
PyTuple_Size(index) == 1 &&
|
PyTuple_Size(index) == 1 &&
|
||||||
PyLong_Check(PyTuple_GET_ITEM(index, 0))) {
|
(PyLong_Check(PyTuple_GET_ITEM(index, 0))
|
||||||
|
|| PyInt_Check(PyTuple_GET_ITEM(index, 0)))) {
|
||||||
PyObject *index_obj = PyTuple_GET_ITEM(index, 0);
|
PyObject *index_obj = PyTuple_GET_ITEM(index, 0);
|
||||||
tresult = THTensor_(newWithTensor)(self->cdata);
|
tresult = THTensor_(newWithTensor)(self->cdata);
|
||||||
INDEX_LONG(0, index_obj, tresult,
|
INDEX_LONG(0, index_obj, tresult,
|
||||||
@ -121,7 +127,7 @@ static bool THPTensor_(_index)(THPTensor *self, PyObject *index,
|
|||||||
|
|
||||||
for(int dim = 0; dim < PyTuple_Size(index); dim++) {
|
for(int dim = 0; dim < PyTuple_Size(index); dim++) {
|
||||||
PyObject *dimidx = PyTuple_GET_ITEM(index, dim);
|
PyObject *dimidx = PyTuple_GET_ITEM(index, dim);
|
||||||
if(PyLong_Check(dimidx)) {
|
if(PyLong_Check(dimidx) || PyInt_Check(dimidx)) {
|
||||||
INDEX_LONG(t_dim, dimidx, tresult,
|
INDEX_LONG(t_dim, dimidx, tresult,
|
||||||
// 1D tensor
|
// 1D tensor
|
||||||
rresult = GET_PTR_1D(tresult, idx);
|
rresult = GET_PTR_1D(tresult, idx);
|
||||||
@ -132,7 +138,9 @@ static bool THPTensor_(_index)(THPTensor *self, PyObject *index,
|
|||||||
THTensor_(select)(tresult, NULL, t_dim, idx)
|
THTensor_(select)(tresult, NULL, t_dim, idx)
|
||||||
)
|
)
|
||||||
} else if (PyTuple_Check(dimidx)) {
|
} else if (PyTuple_Check(dimidx)) {
|
||||||
if (PyTuple_Size(dimidx) != 1 || !PyLong_Check(PyTuple_GET_ITEM(dimidx, 0))) {
|
if (PyTuple_Size(dimidx) != 1
|
||||||
|
|| !(PyLong_Check(PyTuple_GET_ITEM(dimidx, 0))
|
||||||
|
|| PyInt_Check(PyTuple_GET_ITEM(dimidx, 0)))) {
|
||||||
PyErr_SetString(PyExc_RuntimeError, "Expected a single integer");
|
PyErr_SetString(PyExc_RuntimeError, "Expected a single integer");
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
@ -177,7 +185,11 @@ static PyObject * THPTensor_(get)(THPTensor *self, PyObject *index)
|
|||||||
return THPTensor_(newObject)(tresult);
|
return THPTensor_(newObject)(tresult);
|
||||||
if (rresult)
|
if (rresult)
|
||||||
return THPUtils_(newReal)(*rresult);
|
return THPUtils_(newReal)(*rresult);
|
||||||
PyErr_SetString(PyExc_RuntimeError, "Unknown exception");
|
char err_string[512];
|
||||||
|
snprintf (err_string, 512,
|
||||||
|
"%s %s", "Unknown exception in THPTensor_(get). Index type is: ",
|
||||||
|
index->ob_type->tp_name);
|
||||||
|
PyErr_SetString(PyExc_RuntimeError, err_string);
|
||||||
return NULL;
|
return NULL;
|
||||||
}
|
}
|
||||||
END_HANDLE_TH_ERRORS
|
END_HANDLE_TH_ERRORS
|
||||||
@ -312,7 +324,7 @@ PyTypeObject THPTensorStatelessType = {
|
|||||||
0, /* tp_print */
|
0, /* tp_print */
|
||||||
0, /* tp_getattr */
|
0, /* tp_getattr */
|
||||||
0, /* tp_setattr */
|
0, /* tp_setattr */
|
||||||
0, /* tp_reserved */
|
0, /* tp_reserved / tp_compare */
|
||||||
0, /* tp_repr */
|
0, /* tp_repr */
|
||||||
0, /* tp_as_number */
|
0, /* tp_as_number */
|
||||||
0, /* tp_as_sequence */
|
0, /* tp_as_sequence */
|
||||||
@ -342,6 +354,13 @@ PyTypeObject THPTensorStatelessType = {
|
|||||||
0, /* tp_init */
|
0, /* tp_init */
|
||||||
0, /* tp_alloc */
|
0, /* tp_alloc */
|
||||||
0, /* tp_new */
|
0, /* tp_new */
|
||||||
|
0, /* tp_free */
|
||||||
|
0, /* tp_is_gc */
|
||||||
|
0, /* tp_bases */
|
||||||
|
0, /* tp_mro */
|
||||||
|
0, /* tp_cache */
|
||||||
|
0, /* tp_subclasses */
|
||||||
|
0, /* tp_weaklist */
|
||||||
};
|
};
|
||||||
|
|
||||||
bool THPTensor_(init)(PyObject *module)
|
bool THPTensor_(init)(PyObject *module)
|
||||||
@ -353,6 +372,7 @@ bool THPTensor_(init)(PyObject *module)
|
|||||||
THPTensorStatelessType.tp_new = PyType_GenericNew;
|
THPTensorStatelessType.tp_new = PyType_GenericNew;
|
||||||
if (PyType_Ready(&THPTensorStatelessType) < 0)
|
if (PyType_Ready(&THPTensorStatelessType) < 0)
|
||||||
return false;
|
return false;
|
||||||
|
|
||||||
PyModule_AddObject(module, THPTensorBaseStr, (PyObject *)&THPTensorType);
|
PyModule_AddObject(module, THPTensorBaseStr, (PyObject *)&THPTensorType);
|
||||||
return true;
|
return true;
|
||||||
}
|
}
|
||||||
|
@ -5,7 +5,14 @@
|
|||||||
bool THPUtils_(parseSlice)(PyObject *slice, Py_ssize_t len, Py_ssize_t *ostart, Py_ssize_t *ostop, Py_ssize_t *oslicelength)
|
bool THPUtils_(parseSlice)(PyObject *slice, Py_ssize_t len, Py_ssize_t *ostart, Py_ssize_t *ostop, Py_ssize_t *oslicelength)
|
||||||
{
|
{
|
||||||
Py_ssize_t start, stop, step, slicelength;
|
Py_ssize_t start, stop, step, slicelength;
|
||||||
if (PySlice_GetIndicesEx(slice, len, &start, &stop, &step, &slicelength) < 0) {
|
if (PySlice_GetIndicesEx(
|
||||||
|
// https://bugsfiles.kde.org/attachment.cgi?id=61186
|
||||||
|
#if PY_VERSION_HEX >= 0x03020000
|
||||||
|
slice,
|
||||||
|
#else
|
||||||
|
(PySliceObject *)slice,
|
||||||
|
#endif
|
||||||
|
len, &start, &stop, &step, &slicelength) < 0) {
|
||||||
PyErr_SetString(PyExc_RuntimeError, "Got an invalid slice");
|
PyErr_SetString(PyExc_RuntimeError, "Got an invalid slice");
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
@ -24,11 +31,16 @@ bool THPUtils_(parseReal)(PyObject *value, real *result)
|
|||||||
{
|
{
|
||||||
if (PyLong_Check(value)) {
|
if (PyLong_Check(value)) {
|
||||||
*result = (real)PyLong_AsLongLong(value);
|
*result = (real)PyLong_AsLongLong(value);
|
||||||
|
} else if (PyInt_Check(value)) {
|
||||||
|
*result = (real)PyInt_AsLong(value);
|
||||||
} else if (PyFloat_Check(value)) {
|
} else if (PyFloat_Check(value)) {
|
||||||
*result = (real)PyFloat_AsDouble(value);
|
*result = (real)PyFloat_AsDouble(value);
|
||||||
} else {
|
} else {
|
||||||
// TODO: meaningful error
|
char err_string[512];
|
||||||
PyErr_SetString(PyExc_RuntimeError, "Unrecognized object");
|
snprintf (err_string, 512, "%s %s",
|
||||||
|
"parseReal expected long or float, but got type: ",
|
||||||
|
value->ob_type->tp_name);
|
||||||
|
PyErr_SetString(PyExc_RuntimeError, err_string);
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
return true;
|
return true;
|
||||||
@ -36,7 +48,7 @@ bool THPUtils_(parseReal)(PyObject *value, real *result)
|
|||||||
|
|
||||||
bool THPUtils_(checkReal)(PyObject *value)
|
bool THPUtils_(checkReal)(PyObject *value)
|
||||||
{
|
{
|
||||||
return PyFloat_Check(value) || PyLong_Check(value);
|
return PyFloat_Check(value) || PyLong_Check(value) || PyInt_Check(value);
|
||||||
}
|
}
|
||||||
|
|
||||||
PyObject * THPUtils_(newReal)(real value)
|
PyObject * THPUtils_(newReal)(real value)
|
||||||
|
@ -3,3 +3,19 @@
|
|||||||
|
|
||||||
#include "generic/utils.cpp"
|
#include "generic/utils.cpp"
|
||||||
#include <TH/THGenerateAllTypes.h>
|
#include <TH/THGenerateAllTypes.h>
|
||||||
|
|
||||||
|
int THPUtils_getLong(PyObject *index, long *result) {
|
||||||
|
if (PyLong_Check(index)) {
|
||||||
|
*result = PyLong_AsLong(index);
|
||||||
|
} else if (PyInt_Check(index)) {
|
||||||
|
*result = PyInt_AsLong(index);
|
||||||
|
} else {
|
||||||
|
char err_string[512];
|
||||||
|
snprintf (err_string, 512, "%s %s",
|
||||||
|
"getLong expected int or long, but got type: ",
|
||||||
|
index->ob_type->tp_name);
|
||||||
|
PyErr_SetString(PyExc_RuntimeError, err_string);
|
||||||
|
return 0;
|
||||||
|
}
|
||||||
|
return 1;
|
||||||
|
}
|
||||||
|
@ -6,5 +6,7 @@
|
|||||||
#include "generic/utils.h"
|
#include "generic/utils.h"
|
||||||
#include <TH/THGenerateAllTypes.h>
|
#include <TH/THGenerateAllTypes.h>
|
||||||
|
|
||||||
|
int THPUtils_getLong(PyObject *index, long *result);
|
||||||
|
|
||||||
#endif
|
#endif
|
||||||
|
|
||||||
|
Reference in New Issue
Block a user