mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
Improve error messages in storage and tensor C functions
This commit is contained in:
69
test/error_messages/storage.py
Normal file
69
test/error_messages/storage.py
Normal file
@ -0,0 +1,69 @@
|
||||
import torch
|
||||
|
||||
def check_error(desc, fn, *required_substrings):
|
||||
try:
|
||||
fn()
|
||||
except Exception as e:
|
||||
error_message = e.args[0]
|
||||
print('=' * 80)
|
||||
print(desc)
|
||||
print('-' * 80)
|
||||
print(error_message)
|
||||
print('')
|
||||
for sub in required_substrings:
|
||||
assert sub in error_message
|
||||
return
|
||||
assert False, "given function ({}) didn't raise an error".format(desc)
|
||||
|
||||
check_error(
|
||||
'Wrong argument types',
|
||||
lambda: torch.FloatStorage(object()),
|
||||
'object')
|
||||
|
||||
check_error('Unknown keyword argument',
|
||||
lambda: torch.FloatStorage(content=1234.),
|
||||
'keyword')
|
||||
|
||||
check_error('Invalid types inside a sequence',
|
||||
lambda: torch.FloatStorage(['a', 'b']),
|
||||
'list', 'str')
|
||||
|
||||
check_error('Invalid size type',
|
||||
lambda: torch.FloatStorage(1.5),
|
||||
'float')
|
||||
|
||||
check_error('Invalid offset',
|
||||
lambda: torch.FloatStorage(torch.FloatStorage(2), 4),
|
||||
'2', '4')
|
||||
|
||||
check_error('Negative offset',
|
||||
lambda: torch.FloatStorage(torch.FloatStorage(2), -1),
|
||||
'2', '-1')
|
||||
|
||||
check_error('Invalid size',
|
||||
lambda: torch.FloatStorage(torch.FloatStorage(3), 1, 5),
|
||||
'2', '1', '5')
|
||||
|
||||
check_error('Negative size',
|
||||
lambda: torch.FloatStorage(torch.FloatStorage(3), 1, -5),
|
||||
'2', '1', '-5')
|
||||
|
||||
check_error('Invalid index type',
|
||||
lambda: torch.FloatStorage(10)['first item'],
|
||||
'str')
|
||||
|
||||
def assign():
|
||||
torch.FloatStorage(10)[1:-1] = '1'
|
||||
check_error('Invalid value type',
|
||||
assign,
|
||||
'str')
|
||||
|
||||
check_error('resize_ with invalid type',
|
||||
lambda: torch.FloatStorage(10).resize_(1.5),
|
||||
'float')
|
||||
|
||||
check_error('fill_ with invalid type',
|
||||
lambda: torch.IntStorage(10).fill_('asdf'),
|
||||
'str')
|
||||
|
||||
# TODO: frombuffer
|
@ -893,7 +893,7 @@ class TestNN(NNTestCase):
|
||||
gradInputConcat = concat.backward(input, gradOutput)
|
||||
# the spatial dims are the largest, the nFilters is the sum
|
||||
output = torch.Tensor(2, int(outputSize.sum()), 12, 12).zero_() # zero for padding
|
||||
narrows = ( (slice(None), (0, 5), slice(None), slice(None)), (slice(None), (5, 11), (1, 11), (1, 11)), (slice(None), (11, 18), (1, 10), (1, 10)), (slice(None), (18, 26), (2, 10), (2, 10)) )
|
||||
narrows = ( (slice(None), slice(0, 5), slice(None), slice(None)), (slice(None), slice(5, 11), slice(1, 11), slice(1, 11)), (slice(None), slice(11, 18), slice(1, 10), slice(1, 10)), (slice(None), slice(18, 26), slice(2, 10), slice(2, 10)) )
|
||||
gradInput = input.clone().zero_()
|
||||
for i in range(4):
|
||||
conv = concat.get(i)
|
||||
|
@ -106,15 +106,18 @@ class TestMultiprocessing(TestCase):
|
||||
def _test_preserve_sharing(self):
|
||||
def do_test():
|
||||
x = torch.randn(5, 5)
|
||||
data = [x.storage(), x, x[2], x[:,1]]
|
||||
data = [x.storage(), x.storage()[1:4], x, x[2], x[:,1]]
|
||||
q = mp.Queue()
|
||||
q.put(data)
|
||||
new_data = q.get()
|
||||
self.assertEqual(new_data, data, 0)
|
||||
storage_cdata = data[0]._cdata
|
||||
self.assertEqual(new_data[0]._cdata, storage_cdata)
|
||||
for t in new_data[1:]:
|
||||
for t in new_data[2:]:
|
||||
self.assertEqual(t.storage()._cdata, storage_cdata)
|
||||
# TODO: enable after fixing #46
|
||||
# new_data[0].fill_(10)
|
||||
# self.assertEqual(new_data[1], new_data[0][1:4], 0)
|
||||
|
||||
with leak_checker(self):
|
||||
do_test()
|
||||
|
@ -135,7 +135,7 @@ class TestTorch(TestCase):
|
||||
# with indices
|
||||
m1 = torch.randn(100,100)
|
||||
res1val, res1ind = torchfn(m1, 1)
|
||||
res2val = m1[:,(0,)].clone()
|
||||
res2val = m1[:,0:1].clone()
|
||||
res2ind = res1ind.clone().fill_(0)
|
||||
for i, j in iter_indices(m1):
|
||||
if mathfn(res2val[i,0], m1[i,j]) != res2val[i,0]:
|
||||
@ -1690,9 +1690,6 @@ class TestTorch(TestCase):
|
||||
self.assertEqual(reference[0], self._consecutive((3, 3)), 0)
|
||||
self.assertEqual(reference[1], self._consecutive((3, 3), 10), 0)
|
||||
self.assertEqual(reference[2], self._consecutive((3, 3), 19), 0)
|
||||
self.assertEqual(reference[(0,)], self._consecutive((1, 3, 3)), 0)
|
||||
self.assertEqual(reference[(1,)], self._consecutive((1, 3, 3), 10), 0)
|
||||
self.assertEqual(reference[(2,)], self._consecutive((1, 3, 3), 19), 0)
|
||||
self.assertEqual(reference[0, 1], self._consecutive((3,), 4), 0)
|
||||
self.assertEqual(reference[0:2], self._consecutive((2, 3, 3)), 0)
|
||||
self.assertEqual(reference[2, 2, 2], 27, 0)
|
||||
@ -1772,7 +1769,7 @@ class TestTorch(TestCase):
|
||||
for j in range(1 if dim == 1 else n):
|
||||
for k in range(1 if dim == 2 else o):
|
||||
ii = [i, j, k]
|
||||
ii[dim] = (0, idx.size(dim))
|
||||
ii[dim] = slice(0, idx.size(dim)+1)
|
||||
idx[tuple(ii)] = torch.randperm(dim_size)[0:elems_per_row]
|
||||
|
||||
def test_gather(self):
|
||||
@ -2221,7 +2218,9 @@ class TestTorch(TestCase):
|
||||
|
||||
def test_serialization(self):
|
||||
a = [torch.randn(5, 5).float() for i in range(2)]
|
||||
b = [a[i % 2] for i in range(4)] + [a[0].storage()]
|
||||
b = [a[i % 2] for i in range(4)]
|
||||
b += [a[0].storage()]
|
||||
b += [a[0].storage()[1:4]]
|
||||
with tempfile.NamedTemporaryFile() as f:
|
||||
torch.save(b, f)
|
||||
f.seek(0)
|
||||
@ -2237,6 +2236,8 @@ class TestTorch(TestCase):
|
||||
self.assertEqual(c[4], torch.FloatStorage(25).fill_(10), 0)
|
||||
c[1].fill_(20)
|
||||
self.assertEqual(c[1], c[3], 0)
|
||||
# TODO: enable after fixing #46
|
||||
# self.assertEqual(c[4], c[5][1:4], 0)
|
||||
|
||||
def test_from_buffer(self):
|
||||
a = bytearray([1, 2, 3, 4])
|
||||
|
@ -3,7 +3,12 @@ from . import CWrapPlugin
|
||||
|
||||
class THPLongArgsPlugin(CWrapPlugin):
|
||||
PARSE_LONG_ARGS = Template("""\
|
||||
THLongStoragePtr __long_args_guard = THPUtils_getLongStorage(args, $num_checked);
|
||||
THLongStoragePtr __long_args_guard;
|
||||
try {
|
||||
__long_args_guard = THPUtils_getLongStorage(args, $num_checked);
|
||||
} catch (std::exception &e) {
|
||||
goto invalid_arguments;
|
||||
}
|
||||
THLongStorage* __long_args = __long_args_guard.get();
|
||||
""")
|
||||
|
||||
@ -30,6 +35,13 @@ class THPLongArgsPlugin(CWrapPlugin):
|
||||
code = code.replace('__argcount ==', '__argcount >')
|
||||
return code
|
||||
|
||||
def process_wrapper(self, code, declaration):
|
||||
if any(map(lambda opt: opt.get('long_args'), declaration['options'])):
|
||||
invalid_arguments_idx = code.find('THPUtils_invalidArguments')
|
||||
newline_idx = code.rfind('\n', 0, invalid_arguments_idx)
|
||||
code = code[:newline_idx] + '\ninvalid_arguments:' + code[newline_idx:]
|
||||
return code
|
||||
|
||||
def process_option_code(self, code, option):
|
||||
if 'long_args' in option and option['long_args']:
|
||||
lines = code.split('\n')
|
||||
|
@ -73,10 +73,10 @@ PyObject * $name(PyObject *self, PyObject *args)
|
||||
HANDLE_TH_ERRORS
|
||||
int __argcount = args ? PyTuple_Size(args) : 0;
|
||||
$options
|
||||
} else {
|
||||
THPUtils_invalidArguments(args, "$readable_name", $num_options, $expected_args);
|
||||
return NULL;
|
||||
}
|
||||
|
||||
THPUtils_invalidArguments(args, "$readable_name", $num_options, $expected_args);
|
||||
return NULL;
|
||||
END_HANDLE_TH_ERRORS
|
||||
}
|
||||
""")
|
||||
@ -164,6 +164,9 @@ PyObject * $name(PyObject *self, PyObject *args)
|
||||
option_desc = [self.TYPE_NAMES[arg['type']] + ' ' + arg['name']
|
||||
for arg in option['arguments']
|
||||
if not arg.get('ignore_check', False)]
|
||||
# TODO: this should probably go to THPLongArgsPlugin
|
||||
if option.get('long_args'):
|
||||
option_desc.append('int ...')
|
||||
if option_desc:
|
||||
arg_desc.append('({})'.format(', '.join(option_desc)))
|
||||
else:
|
||||
|
@ -74,14 +74,16 @@ bool TH_CONCAT_3(THPModule_,name,Copy)(PyObject *dst, PyObject *src) \
|
||||
static PyObject * TH_CONCAT_3(THPModule_,name,CopyWrapper)(PyObject *unused, PyObject *args)\
|
||||
{ \
|
||||
HANDLE_TH_ERRORS \
|
||||
/* TODO: check args */ \
|
||||
Py_ssize_t num_args = args ? PyTuple_Size(args) : 0; \
|
||||
THPUtils_assert(num_args == 2, #name "Copy expected exactly two arguments, " \
|
||||
"but got %ld", (long)num_args); \
|
||||
PyObject *dst = PyTuple_GET_ITEM(args, 0); \
|
||||
PyObject *src = PyTuple_GET_ITEM(args, 1); \
|
||||
if (!TH_CONCAT_3(THPModule_,name,Copy)(dst, src)) { \
|
||||
return NULL; \
|
||||
} \
|
||||
/* TODO: return dst? */ \
|
||||
Py_RETURN_NONE; \
|
||||
Py_INCREF(dst); \
|
||||
return dst; \
|
||||
END_HANDLE_TH_ERRORS \
|
||||
}
|
||||
|
||||
@ -145,81 +147,55 @@ static PyObject * THPModule_getNumThreads(PyObject *module)
|
||||
|
||||
static PyObject * THPModule_setNumThreads(PyObject *module, PyObject *arg)
|
||||
{
|
||||
if (!THPUtils_checkLong(arg)) {
|
||||
THPUtils_setError("set_num_threads expects a single int as argument");
|
||||
return NULL;
|
||||
}
|
||||
// TODO: maybe throw an error to let people know it's a noop? or a warning?
|
||||
THPUtils_assert(THPUtils_checkLong(arg), "set_num_threads expects an int, "
|
||||
"but got %s", THPUtils_typename(arg));
|
||||
#ifdef _OPENMP
|
||||
omp_set_num_threads(THPUtils_unpackLong(arg));
|
||||
#else
|
||||
PyErr_WarnEx(PyExc_RuntimeWarning, "set_num_threads is a no-op - torch was "
|
||||
"compiled without OpenMP support", 1);
|
||||
#endif
|
||||
return 0;
|
||||
}
|
||||
|
||||
static PyObject * THPModule_getRNGState(PyObject *module, PyObject *args)
|
||||
static PyObject * THPModule_getRNGState(PyObject *module)
|
||||
{
|
||||
THGenerator *generator = THPDefaultGenerator->cdata;
|
||||
if (args && PyTuple_Size(args) == 1 && THPGenerator_Check(PyTuple_GET_ITEM(args, 0))) {
|
||||
generator = ((THPGenerator*)PyTuple_GET_ITEM(args, 0))->cdata;
|
||||
} else if (args && PyTuple_Size(args) > 0) {
|
||||
// TODO: better error message
|
||||
THPUtils_setError("invalid arguments");
|
||||
return NULL;
|
||||
}
|
||||
HANDLE_TH_ERRORS
|
||||
THPGenerator *self = THPDefaultGenerator;
|
||||
THGenerator *generator = self->cdata;
|
||||
THByteTensorPtr _t = THByteTensor_new();
|
||||
THByteTensor_getRNGState(generator, _t.get());
|
||||
PyObject *_ret = THPByteTensor_New(_t);
|
||||
_t.release();
|
||||
return _ret;
|
||||
END_HANDLE_TH_ERRORS
|
||||
}
|
||||
|
||||
static PyObject * THPModule_setRNGState(PyObject *module, PyObject *args)
|
||||
static PyObject * THPModule_setRNGState(PyObject *_unused, PyObject *_new_state)
|
||||
{
|
||||
THGenerator *generator = THPDefaultGenerator->cdata;
|
||||
THByteTensor *new_state = NULL;
|
||||
bool args_ok = false;
|
||||
if (args && PyTuple_Size(args) > 0) {
|
||||
PyObject *first_arg = PyTuple_GET_ITEM(args, 0);
|
||||
|
||||
if (THPGenerator_Check(first_arg)) {
|
||||
PyObject *second_arg = PyTuple_GET_ITEM(args, 1);
|
||||
if (THPByteTensor_Check(second_arg)) {
|
||||
new_state = ((THPByteTensor*)second_arg)->cdata;
|
||||
args_ok = PyTuple_Size(args) == 2;
|
||||
}
|
||||
} else if (THPByteTensor_Check(first_arg)) {
|
||||
new_state = ((THPByteTensor*)first_arg)->cdata;
|
||||
args_ok = PyTuple_Size(args) == 1;
|
||||
}
|
||||
}
|
||||
if (!args_ok) {
|
||||
THPUtils_setError("invalid arguments");
|
||||
return NULL;
|
||||
}
|
||||
HANDLE_TH_ERRORS
|
||||
THPGenerator *self = THPDefaultGenerator;
|
||||
THGenerator *generator = self->cdata;
|
||||
THPUtils_assert(THPByteTensor_Check(_new_state), "set_rng_state expects a "
|
||||
"torch.ByteTensor, but got %s", THPUtils_typename(_new_state));
|
||||
THByteTensor *new_state = ((THPByteTensor*)_new_state)->cdata;;
|
||||
THByteTensor_setRNGState(generator, new_state);
|
||||
Py_RETURN_NONE;
|
||||
Py_INCREF(self);
|
||||
return (PyObject*)self;
|
||||
END_HANDLE_TH_ERRORS
|
||||
}
|
||||
|
||||
static PyObject * THPModule_manualSeed(PyObject *module, PyObject *args)
|
||||
static PyObject * THPModule_manualSeed(PyObject *_unused, PyObject *seed)
|
||||
{
|
||||
THGenerator *generator = THPDefaultGenerator->cdata;
|
||||
long new_seed;
|
||||
bool args_ok = false;
|
||||
if (args && PyTuple_Size(args) > 0) {
|
||||
PyObject *first_arg = PyTuple_GET_ITEM(args, 0);
|
||||
if (THPUtils_checkLong(first_arg)) {
|
||||
new_seed = THPUtils_unpackLong(first_arg);
|
||||
args_ok = PyTuple_Size(args) == 1;
|
||||
}
|
||||
}
|
||||
|
||||
if (!args_ok) {
|
||||
// TODO: better error message
|
||||
THPUtils_setError("invalid arguments");
|
||||
return NULL;
|
||||
}
|
||||
THRandom_manualSeed(generator, new_seed);
|
||||
Py_RETURN_NONE;
|
||||
HANDLE_TH_ERRORS
|
||||
THPGenerator *self = THPDefaultGenerator;
|
||||
THGenerator *generator = self->cdata;
|
||||
THPUtils_assert(THPUtils_checkLong(seed), "manual_seed expected a long, "
|
||||
"but got %s", THPUtils_typename(seed));
|
||||
THRandom_manualSeed(generator, THPUtils_unpackLong(seed));
|
||||
Py_INCREF(self);
|
||||
return (PyObject*)self;
|
||||
END_HANDLE_TH_ERRORS
|
||||
}
|
||||
|
||||
bool THPModule_isTensor(PyObject *obj)
|
||||
@ -572,9 +548,9 @@ static PyMethodDef TorchMethods[] = {
|
||||
{"_storageCopy", (PyCFunction)THPModule_storageCopyWrapper, METH_VARARGS, NULL},
|
||||
{"get_num_threads", (PyCFunction)THPModule_getNumThreads, METH_NOARGS, NULL},
|
||||
{"set_num_threads", (PyCFunction)THPModule_setNumThreads, METH_O, NULL},
|
||||
{"get_rng_state", (PyCFunction)THPModule_getRNGState, METH_VARARGS, NULL},
|
||||
{"set_rng_state", (PyCFunction)THPModule_setRNGState, METH_VARARGS, NULL},
|
||||
{"manual_seed", (PyCFunction)THPModule_manualSeed, METH_VARARGS, NULL},
|
||||
{"get_rng_state", (PyCFunction)THPModule_getRNGState, METH_NOARGS, NULL},
|
||||
{"set_rng_state", (PyCFunction)THPModule_setRNGState, METH_O, NULL},
|
||||
{"manual_seed", (PyCFunction)THPModule_manualSeed, METH_O, NULL},
|
||||
|
||||
{"sigmoid", (PyCFunction)THPModule_sigmoid, METH_VARARGS, NULL},
|
||||
{"log", (PyCFunction)THPModule_log, METH_VARARGS, NULL},
|
||||
|
@ -6,25 +6,16 @@ PyObject *THPStorageClass = NULL;
|
||||
|
||||
PyObject * THPStorage_(New)(THStorage *ptr)
|
||||
{
|
||||
PyObject *args = PyTuple_New(0);
|
||||
PyObject *kwargs = NULL;
|
||||
if (!args) {
|
||||
PyErr_SetString(PyExc_RuntimeError, "Could not create a new storage object - "
|
||||
"failed to allocate argument tuple");
|
||||
return NULL;
|
||||
}
|
||||
THPObjectPtr args = PyTuple_New(0);
|
||||
THPObjectPtr kwargs;
|
||||
THPUtils_assert(args, "Could not create a new storage object - failed to"
|
||||
"allocate argument tuple");
|
||||
if (ptr) {
|
||||
kwargs = Py_BuildValue("{s:N}", "cdata", PyLong_FromVoidPtr(ptr));
|
||||
if (!kwargs) {
|
||||
PyErr_SetString(PyExc_RuntimeError, "Could not create a new storage object - "
|
||||
THPUtils_assert(kwargs, "Could not create a new storage object - "
|
||||
"failed to allocate keyword argument dictionary");
|
||||
Py_DECREF(args);
|
||||
return NULL;
|
||||
}
|
||||
}
|
||||
PyObject *result = PyObject_Call(THPStorageClass, args, kwargs);
|
||||
Py_DECREF(args);
|
||||
Py_XDECREF(kwargs);
|
||||
return result;
|
||||
}
|
||||
|
||||
@ -56,116 +47,116 @@ static void THPStorage_(dealloc)(THPStorage* self)
|
||||
static PyObject * THPStorage_(pynew)(PyTypeObject *type, PyObject *args, PyObject *kwargs)
|
||||
{
|
||||
HANDLE_TH_ERRORS
|
||||
PyObject *cdata_ptr = NULL; // keyword-only arg - cdata pointer value
|
||||
THPStorage *storage_arg = NULL; // storage to be viewed on
|
||||
long storage_arg_offset = 0; // offset for storage view
|
||||
long storage_arg_size = -1; // size for storage view
|
||||
THPObjectPtr iterator; // not null iff got a single iterable
|
||||
long size = -1; // non-negative iff got a number - new storage size
|
||||
bool args_ok = true;
|
||||
|
||||
if (kwargs != NULL && PyDict_Size(kwargs) == 1) {
|
||||
cdata_ptr = PyDict_GetItemString(kwargs, "cdata");
|
||||
args_ok = cdata_ptr != NULL;
|
||||
} else if (args != NULL && PyTuple_Size(args) == 1) {
|
||||
PyObject *arg = PyTuple_GET_ITEM(args, 0);
|
||||
if (THPUtils_checkLong(arg)) {
|
||||
size = THPUtils_unpackLong(arg);
|
||||
} else {
|
||||
iterator = PyObject_GetIter(arg);
|
||||
args_ok = iterator != nullptr;
|
||||
if (args_ok) {
|
||||
size = PyObject_Length(arg);
|
||||
args_ok = size != -1;
|
||||
}
|
||||
}
|
||||
// Storage view
|
||||
} else if (args != NULL && PyTuple_Size(args) >= 1 && THPStorage_(Check)(PyTuple_GET_ITEM(args, 0))) {
|
||||
storage_arg = (THPStorage *)PyTuple_GET_ITEM(args, 0);
|
||||
if (PyTuple_Size(args) >= 2) {
|
||||
PyObject *second_arg = PyTuple_GET_ITEM(args, 1);
|
||||
THPUtils_assert(THPUtils_checkLong(second_arg), "Invalid arguments");
|
||||
storage_arg_offset = THPUtils_unpackLong(second_arg);
|
||||
}
|
||||
storage_arg_size = storage_arg->cdata->size - storage_arg_offset;
|
||||
if (PyTuple_Size(args) >= 3) {
|
||||
PyObject *third_arg = PyTuple_GET_ITEM(args, 2);
|
||||
THPUtils_assert(THPUtils_checkLong(third_arg), "Invalid arguments");
|
||||
storage_arg_size = THPUtils_unpackLong(third_arg);
|
||||
}
|
||||
if (storage_arg_offset < 0 || storage_arg_offset >= storage_arg->cdata->size) {
|
||||
THPUtils_setError("Invalid storage offset (%ld)!\n", storage_arg_offset);
|
||||
return NULL;
|
||||
}
|
||||
if (storage_arg_size < 1 || storage_arg_size > storage_arg->cdata->size - storage_arg_offset) {
|
||||
THPUtils_setError("Invalid storage size (got %ld, but should be between 0 and %ld)!\n",
|
||||
storage_arg_size);
|
||||
return NULL;
|
||||
}
|
||||
if (PyTuple_Size(args) >= 4)
|
||||
args_ok = false;
|
||||
} else if (args && PyTuple_Size(args) != 0) {
|
||||
args_ok = false;
|
||||
}
|
||||
if (!args_ok) {
|
||||
// TODO: better error message
|
||||
THPUtils_setError("invalid arguments");
|
||||
return NULL;
|
||||
}
|
||||
Py_ssize_t num_args = args ? PyTuple_Size(args) : 0;
|
||||
|
||||
THPStoragePtr self = (THPStorage *)type->tp_alloc(type, 0);
|
||||
if (self != nullptr) {
|
||||
if (cdata_ptr) {
|
||||
THStorage *ptr = (THStorage*)PyLong_AsVoidPtr(cdata_ptr);
|
||||
self->cdata = ptr;
|
||||
} else if (storage_arg) {
|
||||
real *data_ptr = storage_arg->cdata->data + storage_arg_offset;
|
||||
THStoragePtr storage = THStorage_(newWithData)(LIBRARY_STATE data_ptr, storage_arg_size);
|
||||
storage->flag = TH_STORAGE_REFCOUNTED | TH_STORAGE_VIEW;
|
||||
storage->view = storage_arg->cdata;
|
||||
THStorage_(retain)(LIBRARY_STATE storage_arg->cdata);
|
||||
self->cdata = storage.release();
|
||||
} else if (iterator == nullptr && size >= 0) {
|
||||
self->cdata = THStorage_(newWithSize)(LIBRARY_STATE size);
|
||||
} else if (iterator != nullptr) {
|
||||
self->cdata = THStorage_(newWithSize)(LIBRARY_STATE size);
|
||||
long items_processed = 0;
|
||||
THPObjectPtr item;
|
||||
real v;
|
||||
while ((item = PyIter_Next(iterator))) {
|
||||
if (!THPUtils_(checkReal)(item)) {
|
||||
THPUtils_setError("expected a numeric type, but got %s", Py_TYPE(item)->tp_name);
|
||||
return NULL;
|
||||
}
|
||||
v = THPUtils_(unpackReal)(item);
|
||||
if (items_processed == size) {
|
||||
// TODO: error - iterator has too many items
|
||||
return NULL;
|
||||
}
|
||||
#ifndef THC_GENERIC_FILE
|
||||
self->cdata->data[items_processed++] = v;
|
||||
#else
|
||||
// TODO: this might be slow - consider batched updates?
|
||||
THCStorage_(set)(LIBRARY_STATE self->cdata, items_processed++, v);
|
||||
#endif
|
||||
THPUtils_assert(self, "failed to allocate a " THPStorageStr " object");
|
||||
|
||||
// Internally we allow constructing with a keywoard only argument cdata
|
||||
if (kwargs != NULL) {
|
||||
Py_ssize_t num_kwargs = PyDict_Size(kwargs);
|
||||
if (num_args == 0) {
|
||||
PyObject *cdata_ptr = PyDict_GetItemString(kwargs, "cdata");
|
||||
if (num_kwargs == 1 && cdata_ptr && THPUtils_checkLong(cdata_ptr)) {
|
||||
THStorage *ptr = (THStorage*)PyLong_AsVoidPtr(cdata_ptr);
|
||||
self->cdata = ptr;
|
||||
return (PyObject*)self.release();
|
||||
}
|
||||
// Iterator raised an exception
|
||||
if (PyErr_Occurred()) {
|
||||
return NULL;
|
||||
}
|
||||
// Iterator was too short
|
||||
if (items_processed < size) {
|
||||
// TODO; error message
|
||||
return NULL;
|
||||
}
|
||||
} else {
|
||||
self->cdata = THStorage_(new)(LIBRARY_STATE_NOARGS);
|
||||
}
|
||||
// This is an internal option, so we don't want to advertise it.
|
||||
THPUtils_assert(num_kwargs == 0, THPStorageStr " constructor doesn't "
|
||||
"accept any keyword arguments");
|
||||
}
|
||||
|
||||
// torch.Storage()
|
||||
if (num_args == 0) {
|
||||
self->cdata = THStorage_(new)(LIBRARY_STATE_NOARGS);
|
||||
return (PyObject*)self.release();
|
||||
}
|
||||
|
||||
PyObject *first_arg = PyTuple_GET_ITEM(args, 0);
|
||||
|
||||
// torch.Storage(size)
|
||||
if (num_args == 1 && THPUtils_checkLong(first_arg)) {
|
||||
long size = THPUtils_unpackLong(first_arg);
|
||||
self->cdata = THStorage_(newWithSize)(LIBRARY_STATE size);
|
||||
return (PyObject*)self.release();
|
||||
}
|
||||
|
||||
// torch.Storage(view_source, [offset, [size]])
|
||||
if (num_args < 4 && THPStorage_(Check)(first_arg)) {
|
||||
THPStorage *storage_arg = (THPStorage *)first_arg;
|
||||
long numel = storage_arg->cdata->size;
|
||||
long offset = 0;
|
||||
|
||||
if (num_args >= 2) {
|
||||
PyObject *second_arg = PyTuple_GET_ITEM(args, 1);
|
||||
if (!THPUtils_checkLong(second_arg))
|
||||
goto invalid_arguments;
|
||||
offset = THPUtils_unpackLong(second_arg);
|
||||
}
|
||||
|
||||
if (self->cdata == NULL)
|
||||
return NULL;
|
||||
long size = numel - offset;
|
||||
if (num_args >= 3) {
|
||||
PyObject *third_arg = PyTuple_GET_ITEM(args, 2);
|
||||
if (!THPUtils_checkLong(third_arg))
|
||||
goto invalid_arguments;
|
||||
size = THPUtils_unpackLong(third_arg);
|
||||
}
|
||||
|
||||
THPUtils_assert(offset >= 0 && offset <= numel, "specified an offset of "
|
||||
"%ld, but the viewed storage has only %ld element(s)", offset, numel);
|
||||
THPUtils_assert(size >= 1 && size <= numel - offset, "specified a size of "
|
||||
"%d, but the viewed storage has only %ld element(s) after offset %ld",
|
||||
size, numel - offset, offset);
|
||||
|
||||
real *data_ptr = storage_arg->cdata->data + offset;
|
||||
THStoragePtr storage = THStorage_(newWithData)(LIBRARY_STATE data_ptr, size);
|
||||
storage->flag = TH_STORAGE_REFCOUNTED | TH_STORAGE_VIEW;
|
||||
storage->view = storage_arg->cdata;
|
||||
THStorage_(retain)(LIBRARY_STATE storage_arg->cdata);
|
||||
self->cdata = storage.release();
|
||||
return (PyObject*)self.release();
|
||||
}
|
||||
return (PyObject *)self.release();
|
||||
|
||||
// torch.Storage(sequence)
|
||||
if (num_args == 1 && PySequence_Check(first_arg)) {
|
||||
Py_ssize_t length = PySequence_Length(first_arg);
|
||||
THPUtils_assert(length >= 0, "couldn't obtain the length of %s",
|
||||
THPUtils_typename(first_arg));
|
||||
self->cdata = THStorage_(newWithSize)(LIBRARY_STATE length);
|
||||
THPObjectPtr item;
|
||||
try {
|
||||
for (Py_ssize_t i = 0; i < length; i++) {
|
||||
item = PySequence_GetItem(first_arg, i);
|
||||
real value = THPUtils_(unpackReal)(item.get());
|
||||
#ifndef THC_GENERIC_FILE
|
||||
self->cdata->data[i] = value;
|
||||
#else
|
||||
// TODO: this might be slow - consider batched updates?
|
||||
THCStorage_(set)(LIBRARY_STATE self->cdata, i, value);
|
||||
#endif
|
||||
}
|
||||
} catch (std::runtime_error &e) {
|
||||
THPUtils_setError("tried to construct a storage from a sequence (%s), "
|
||||
"but one of the items was of type %s instead of %s",
|
||||
THPUtils_typename(first_arg),
|
||||
THPUtils_typename(item.get()),
|
||||
THPUtils_typeTraits<real>::python_type_str);
|
||||
return NULL;
|
||||
}
|
||||
return (PyObject*)self.release();
|
||||
}
|
||||
|
||||
invalid_arguments:
|
||||
THPUtils_invalidArguments(args, THPStorageStr " constructor", 6,
|
||||
"no arguments",
|
||||
"(int size)",
|
||||
"(Sequence data)",
|
||||
"(" THPStorageStr " view_source)",
|
||||
"(" THPStorageStr " view_source, int offset)",
|
||||
"(" THPStorageStr " view_source, int offset, int size)");
|
||||
return NULL;
|
||||
END_HANDLE_TH_ERRORS
|
||||
}
|
||||
|
||||
@ -180,47 +171,31 @@ static PyObject * THPStorage_(get)(THPStorage *self, PyObject *index)
|
||||
{
|
||||
HANDLE_TH_ERRORS
|
||||
/* Integer index */
|
||||
long nindex;
|
||||
if (THPUtils_checkLong(index)) {
|
||||
nindex = THPUtils_unpackLong(index);
|
||||
long nindex = THPUtils_unpackLong(index);
|
||||
if (nindex < 0)
|
||||
nindex += THStorage_(size)(LIBRARY_STATE self->cdata);
|
||||
#if defined(TH_REAL_IS_FLOAT) || defined(TH_REAL_IS_DOUBLE) || \
|
||||
defined(THC_REAL_IS_FLOAT) || defined(THC_REAL_IS_DOUBLE)
|
||||
return PyFloat_FromDouble(THStorage_(get)(LIBRARY_STATE self->cdata, nindex));
|
||||
#elif defined(THC_REAL_IS_HALF)
|
||||
return PyFloat_FromDouble(THC_half2float(THStorage_(get)(LIBRARY_STATE self->cdata, nindex)));
|
||||
#else
|
||||
return PyLong_FromLong(THStorage_(get)(LIBRARY_STATE self->cdata, nindex));
|
||||
#endif
|
||||
real value = THStorage_(get)(LIBRARY_STATE self->cdata, nindex);
|
||||
return THPUtils_(newReal)(value);
|
||||
/* Slice index */
|
||||
} else if (PySlice_Check(index)) {
|
||||
Py_ssize_t start, stop, slicelength, len;
|
||||
len = THStorage_(size)(LIBRARY_STATE self->cdata);
|
||||
Py_ssize_t start, stop, slicelength;
|
||||
long len = THStorage_(size)(LIBRARY_STATE self->cdata);
|
||||
if (!THPUtils_parseSlice(index, len, &start, &stop, &slicelength))
|
||||
return NULL;
|
||||
|
||||
real *data = THStorage_(data)(LIBRARY_STATE self->cdata);
|
||||
#ifndef THC_GENERIC_FILE
|
||||
// TODO: this can leak memory if newWithData fails
|
||||
real *new_data = (real*)THAlloc(slicelength * sizeof(real));
|
||||
memcpy(new_data, data + start, slicelength * sizeof(real));
|
||||
THStoragePtr new_storage = THStorage_(newWithData)(LIBRARY_STATE new_data, slicelength);
|
||||
#else
|
||||
THStoragePtr new_storage = THStorage_(newWithSize)(LIBRARY_STATE slicelength);
|
||||
THStoragePtr view = THStorage_(newWithData)(LIBRARY_STATE data + start, slicelength);
|
||||
THStorage_(clearFlag)(LIBRARY_STATE view, TH_STORAGE_FREEMEM);
|
||||
THStorage_(copy)(LIBRARY_STATE new_storage, view);
|
||||
#endif
|
||||
THStoragePtr new_storage = THStorage_(newWithData)(LIBRARY_STATE data + start, slicelength);
|
||||
new_storage->flag = TH_STORAGE_REFCOUNTED | TH_STORAGE_VIEW;
|
||||
new_storage->view = self->cdata;
|
||||
THStorage_(retain)(LIBRARY_STATE self->cdata);
|
||||
|
||||
PyObject *_ret = THPStorage_(New)(new_storage);
|
||||
new_storage.release();
|
||||
return _ret;
|
||||
}
|
||||
char err_string[512];
|
||||
snprintf (err_string, 512,
|
||||
"%s %s", "Only indexing with integers and slices supported, but got type: ",
|
||||
index->ob_type->tp_name);
|
||||
PyErr_SetString(PyExc_RuntimeError, err_string);
|
||||
THPUtils_setError("can't index a " THPStorageStr " with %s",
|
||||
THPUtils_typename(index));
|
||||
return NULL;
|
||||
END_HANDLE_TH_ERRORS
|
||||
}
|
||||
@ -229,32 +204,30 @@ static int THPStorage_(set)(THPStorage *self, PyObject *index, PyObject *value)
|
||||
{
|
||||
HANDLE_TH_ERRORS
|
||||
if (!THPUtils_(checkReal)(value)) {
|
||||
// TODO: error
|
||||
THPUtils_setError("TODO");
|
||||
THPUtils_setError("can only set storage content with a %s, but got "
|
||||
"%s instead", THPUtils_typeTraits<real>::python_type_str,
|
||||
THPUtils_typename(value));
|
||||
return -1;
|
||||
}
|
||||
real rvalue = THPUtils_(unpackReal)(value);
|
||||
|
||||
long nindex;
|
||||
real rvalue = THPUtils_(unpackReal)(value);
|
||||
if (THPUtils_checkLong(index)) {
|
||||
nindex = THPUtils_unpackLong(index);
|
||||
long nindex = THPUtils_unpackLong(index);
|
||||
THStorage_(set)(LIBRARY_STATE self->cdata, nindex, rvalue);
|
||||
return 0;
|
||||
} else if (PySlice_Check(index)) {
|
||||
Py_ssize_t start, stop, len;
|
||||
len = THStorage_(size)(LIBRARY_STATE self->cdata);
|
||||
Py_ssize_t start, stop;
|
||||
long len = THStorage_(size)(LIBRARY_STATE self->cdata);
|
||||
if (!THPUtils_parseSlice(index, len, &start, &stop, NULL))
|
||||
return -1;
|
||||
// TODO: check the bounds only once
|
||||
// TODO: fill?
|
||||
for (;start < stop; start++)
|
||||
THStorage_(set)(LIBRARY_STATE self->cdata, start, rvalue);
|
||||
return 0;
|
||||
}
|
||||
char err_string[512];
|
||||
snprintf (err_string, 512, "%s %s",
|
||||
"Only indexing with integers and slices supported, but got type: ",
|
||||
index->ob_type->tp_name);
|
||||
PyErr_SetString(PyExc_RuntimeError, err_string);
|
||||
THPUtils_setError("can't index a " THPStorageStr " with %s",
|
||||
THPUtils_typename(index));
|
||||
return -1;
|
||||
END_HANDLE_TH_ERRORS_RET(-1)
|
||||
}
|
||||
|
@ -44,7 +44,8 @@ static PyObject * THPStorage_(new)(THPStorage *self)
|
||||
static PyObject * THPStorage_(resize_)(THPStorage *self, PyObject *number_arg)
|
||||
{
|
||||
HANDLE_TH_ERRORS
|
||||
THPUtils_assert(THPUtils_checkLong(number_arg), "invalid arguments");
|
||||
THPUtils_assert(THPUtils_checkLong(number_arg), "resize_ expects an int, "
|
||||
"but got %s", THPUtils_typename(number_arg));
|
||||
long newsize = THPUtils_unpackLong(number_arg);
|
||||
THStorage_(resize)(LIBRARY_STATE self->cdata, newsize);
|
||||
Py_INCREF(self);
|
||||
@ -55,11 +56,9 @@ static PyObject * THPStorage_(resize_)(THPStorage *self, PyObject *number_arg)
|
||||
static PyObject * THPStorage_(fill_)(THPStorage *self, PyObject *number_arg)
|
||||
{
|
||||
HANDLE_TH_ERRORS
|
||||
if (!THPUtils_(checkReal)(number_arg)) {
|
||||
// TODO: better error message
|
||||
THPUtils_setError("TODO");
|
||||
return NULL;
|
||||
}
|
||||
THPUtils_assert(THPUtils_(checkReal)(number_arg), "fill_ expects %s, "
|
||||
"but got %s", THPUtils_typeTraits<real>::python_type_str,
|
||||
THPUtils_typename(number_arg));
|
||||
THStorage_(fill)(LIBRARY_STATE self->cdata, THPUtils_(unpackReal)(number_arg));
|
||||
Py_INCREF(self);
|
||||
return (PyObject*)self;
|
||||
@ -108,15 +107,16 @@ static PyObject * THPStorage_(fromBuffer)(PyObject *_unused, PyObject *args, PyO
|
||||
|
||||
if (offset < 0 || offset > buffer.len) {
|
||||
PyErr_Format(PyExc_ValueError,
|
||||
"offset must be non-negative and no greater than buffer length (%ld)",
|
||||
(long) buffer.len);
|
||||
"offset must be non-negative and no greater than buffer length (%ld), "
|
||||
"but got %ld", (long)offset, (long)buffer.len);
|
||||
PyBuffer_Release(&buffer);
|
||||
return NULL;
|
||||
}
|
||||
|
||||
if (count < 0) {
|
||||
if ((buffer.len - offset) % sizeof(real) != 0) {
|
||||
PyErr_Format(PyExc_ValueError, "buffer size must be a multiple of element size");
|
||||
PyErr_Format(PyExc_ValueError, "buffer size (%ld) must be a multiple "
|
||||
"of element size (%ld)", (long)buffer.len, (long)sizeof(real));
|
||||
PyBuffer_Release(&buffer);
|
||||
return NULL;
|
||||
}
|
||||
@ -124,7 +124,9 @@ static PyObject * THPStorage_(fromBuffer)(PyObject *_unused, PyObject *args, PyO
|
||||
}
|
||||
|
||||
if (offset + (count * (Py_ssize_t)sizeof(real)) > buffer.len) {
|
||||
PyErr_Format(PyExc_ValueError, "buffer is smaller than requested size");
|
||||
PyErr_Format(PyExc_ValueError, "buffer has only %ld elements after offset "
|
||||
"%ld, but specified a size of %ld", (long)(buffer.len - offset),
|
||||
(long)offset, (long)count);
|
||||
PyBuffer_Release(&buffer);
|
||||
return NULL;
|
||||
}
|
||||
@ -159,10 +161,8 @@ PyObject * THPStorage_(writeFile)(THPStorage *self, PyObject *file)
|
||||
{
|
||||
HANDLE_TH_ERRORS
|
||||
int fd = PyObject_AsFileDescriptor(file);
|
||||
if (fd == -1) {
|
||||
THPUtils_setError("_write_file couln't retrieve file descriptor from given object");
|
||||
return NULL;
|
||||
}
|
||||
THPUtils_assert(fd != -1, "_write_file couldn't retrieve a file descriptor "
|
||||
"from given object");
|
||||
THPStorage_(writeFileRaw)(self->cdata, fd);
|
||||
Py_RETURN_NONE;
|
||||
END_HANDLE_TH_ERRORS
|
||||
@ -172,10 +172,8 @@ PyObject * THPStorage_(newWithFile)(PyObject *_unused, PyObject *file)
|
||||
{
|
||||
HANDLE_TH_ERRORS
|
||||
int fd = PyObject_AsFileDescriptor(file);
|
||||
if (fd == -1) {
|
||||
THPUtils_setError("_new_with_file couln't retrieve file descriptor from given object");
|
||||
return NULL;
|
||||
}
|
||||
THPUtils_assert(fd != -1, "_new_with_file couldn't retrieve a file "
|
||||
"descriptor from given object");
|
||||
THStoragePtr storage = THPStorage_(readFileRaw)(fd);
|
||||
PyObject *result = THPStorage_(New)(storage);
|
||||
storage.release();
|
||||
@ -398,10 +396,7 @@ PyObject * THPStorage_(_sharedFd)(THPStorage *self)
|
||||
}
|
||||
}
|
||||
|
||||
if (!ctx) {
|
||||
THPUtils_setError("can't retrieve shared file descriptor");
|
||||
return NULL;
|
||||
}
|
||||
THPUtils_assert(ctx, "couldn't retrieve a shared file descriptor");
|
||||
return PyLong_FromLong(THMapAllocatorContext_fd(ctx));
|
||||
END_HANDLE_TH_ERRORS
|
||||
}
|
||||
@ -419,10 +414,9 @@ PyObject * THPStorage_(getDevice)(THPStorage *self)
|
||||
PyObject * THPStorage_(_setCdata)(THPStorage *self, PyObject *new_cdata)
|
||||
{
|
||||
HANDLE_TH_ERRORS
|
||||
if (!THPUtils_checkLong(new_cdata)) {
|
||||
THPUtils_setError("invalid argument to _set_cdata - expected an int or long");
|
||||
return NULL;
|
||||
}
|
||||
THPUtils_assert(THPUtils_checkLong(new_cdata), "given an invalid argument to "
|
||||
"_set_cdata - expected an int or long, but got %s",
|
||||
THPUtils_typename(new_cdata));
|
||||
THStorage *ptr = (THStorage*)PyLong_AsVoidPtr(new_cdata);
|
||||
THStorage_(retain)(LIBRARY_STATE ptr);
|
||||
THStorage_(free)(LIBRARY_STATE self->cdata);
|
||||
|
@ -56,197 +56,239 @@ static void THPTensor_(dealloc)(THPTensor* self)
|
||||
Py_TYPE(self)->tp_free((PyObject*)self);
|
||||
}
|
||||
|
||||
static std::string THPTensor_(indicesToString)(std::vector<size_t> &indices,
|
||||
size_t depth)
|
||||
{
|
||||
std::string index = "(";
|
||||
for (size_t i = 0; i <= depth; ++i) {
|
||||
index += std::to_string(indices[i]);
|
||||
index += ", ";
|
||||
}
|
||||
index.erase(index.length()-2); // Remove trailing ", "
|
||||
index += ")";
|
||||
return index;
|
||||
}
|
||||
|
||||
static void THPTensor_(setInconsistentDepthError)(std::vector<size_t> &sizes,
|
||||
std::vector<size_t> &indices, size_t depth, size_t length)
|
||||
{
|
||||
std::string error = "inconsistent sequence length at index ";
|
||||
error += THPTensor_(indicesToString)(indices, depth);
|
||||
error += " - expected ";
|
||||
error += std::to_string(sizes[depth]);
|
||||
error += " but got ";
|
||||
error += std::to_string(length);
|
||||
THPUtils_setError(error.c_str());
|
||||
}
|
||||
|
||||
static PyObject * THPTensor_(pynew)(PyTypeObject *type, PyObject *args, PyObject *kwargs)
|
||||
{
|
||||
HANDLE_TH_ERRORS
|
||||
PyObject *cdata_arg = NULL; // keyword-only arg - cdata pointer value
|
||||
THLongStorage *sizes_arg = NULL; // a storage with sizes for a new tensor
|
||||
THTensor *tensor_arg = NULL; // a tensor to be viewed on
|
||||
// TODO: constructor from storage
|
||||
PyObject *iterable_arg = NULL; // an iterable, with new tensor contents
|
||||
std::vector<size_t> iterator_lengths; // a queue storing lengths of iterables at each depth
|
||||
bool args_ok = true;
|
||||
#ifdef NUMPY_TYPE_ENUM
|
||||
THPObjectPtr numpy_array = NULL;
|
||||
#endif
|
||||
|
||||
if (kwargs && PyDict_Size(kwargs) == 1) {
|
||||
cdata_arg = PyDict_GetItemString(kwargs, "cdata");
|
||||
args_ok = cdata_arg != NULL;
|
||||
} else if (args && PyTuple_Size(args) == 1) {
|
||||
PyObject *arg = PyTuple_GET_ITEM(args, 0);
|
||||
if (THPTensor_(Check)(arg)) {
|
||||
tensor_arg = ((THPTensor*)arg)->cdata;
|
||||
} else if (THPLongStorage_Check(arg)) {
|
||||
sizes_arg = ((THPLongStorage*)arg)->cdata;
|
||||
} else if (THPUtils_checkLong(arg)) {
|
||||
sizes_arg = THPUtils_getLongStorage(args);
|
||||
args_ok = sizes_arg != nullptr;
|
||||
#ifdef NUMPY_TYPE_ENUM
|
||||
} else if (PyArray_Check(arg) && PyArray_TYPE((PyArrayObject*)arg) == NUMPY_TYPE_ENUM) {
|
||||
numpy_array = PyArray_FromArray((PyArrayObject*)arg, nullptr, NPY_ARRAY_BEHAVED);
|
||||
args_ok = numpy_array != nullptr;
|
||||
#endif
|
||||
} else {
|
||||
iterable_arg = arg;
|
||||
Py_INCREF(arg);
|
||||
THPObjectPtr item = arg;
|
||||
THPObjectPtr iter;
|
||||
while ((iter = PyObject_GetIter(item)) != nullptr) {
|
||||
Py_ssize_t length = PyObject_Length(item);
|
||||
iterator_lengths.push_back(length);
|
||||
if (iterator_lengths.size() > 1000000) {
|
||||
THPUtils_setError("Counted more than 1,000,000 dimensions in a given iterable. "
|
||||
"Most likely your items are also iterable, and there's no "
|
||||
"way to infer how many dimensions should the tensor have.");
|
||||
return NULL;
|
||||
}
|
||||
// TODO length == 0 is an error too
|
||||
if (length == -1) {
|
||||
// TODO: error
|
||||
return NULL;
|
||||
}
|
||||
if (length > 0) {
|
||||
item = PyIter_Next(iter);
|
||||
if (item == nullptr) {
|
||||
// TODO: set error
|
||||
return NULL;
|
||||
}
|
||||
} else {
|
||||
break;
|
||||
}
|
||||
}
|
||||
if (iterator_lengths.size() > 1) {
|
||||
for (auto length: iterator_lengths) {
|
||||
if (length <= 0) {
|
||||
// TODO: error message
|
||||
THPUtils_setError("invalid size");
|
||||
return NULL;
|
||||
}
|
||||
}
|
||||
}
|
||||
args_ok = iterator_lengths.size() > 0;
|
||||
// We have accumulated some errors along the way.
|
||||
// Since we did all checking and ignored only the non-important
|
||||
// ones it's safe to clear them here.
|
||||
PyErr_Clear();
|
||||
}
|
||||
} else if (args && PyTuple_Size(args) > 0) {
|
||||
sizes_arg = THPUtils_getLongStorage(args);
|
||||
args_ok = sizes_arg != nullptr;
|
||||
}
|
||||
|
||||
if (!args_ok) {
|
||||
// TODO: nice error mossage
|
||||
THPUtils_setError("invalid arguments");
|
||||
return NULL;
|
||||
}
|
||||
Py_ssize_t num_args = args ? PyTuple_Size(args) : 0;
|
||||
|
||||
THPTensorPtr self = (THPTensor *)type->tp_alloc(type, 0);
|
||||
if (self != nullptr) {
|
||||
if (cdata_arg) {
|
||||
self->cdata = (THTensor*)PyLong_AsVoidPtr(cdata_arg);
|
||||
} else if (sizes_arg) {
|
||||
self->cdata = THTensor_(newWithSize)(LIBRARY_STATE sizes_arg, nullptr);
|
||||
} else if (tensor_arg) {
|
||||
self->cdata = THTensor_(newWithTensor)(LIBRARY_STATE tensor_arg);
|
||||
THPUtils_assert(self, "failed to allocate a " THPTensorStr " object");
|
||||
self->cdata = NULL;
|
||||
|
||||
// Internally we allow constructing with a keywoard only argument cdata
|
||||
if (kwargs != NULL) {
|
||||
Py_ssize_t num_kwargs = PyDict_Size(kwargs);
|
||||
if (num_args == 0) {
|
||||
PyObject *cdata_ptr = PyDict_GetItemString(kwargs, "cdata");
|
||||
if (num_kwargs == 1 && cdata_ptr && THPUtils_checkLong(cdata_ptr)) {
|
||||
THTensor *ptr = (THTensor*)PyLong_AsVoidPtr(cdata_ptr);
|
||||
self->cdata = ptr;
|
||||
return (PyObject*)self.release();
|
||||
}
|
||||
}
|
||||
// This is an internal option, so we don't want to advertise it.
|
||||
THPUtils_assert(num_kwargs == 0, THPTensorStr " constructor doesn't "
|
||||
"accept any keyword arguments");
|
||||
}
|
||||
|
||||
// torch.Tensor()
|
||||
if (num_args == 0) {
|
||||
self->cdata = THTensor_(new)(LIBRARY_STATE_NOARGS);
|
||||
return (PyObject*)self.release();
|
||||
}
|
||||
|
||||
PyObject *first_arg = PyTuple_GET_ITEM(args, 0);
|
||||
|
||||
// torch.Tensor(torch.Tensor tensor)
|
||||
if (num_args == 1 && THPTensor_(Check)(first_arg)) {
|
||||
THTensor *tensor = ((THPTensor*)first_arg)->cdata;
|
||||
self->cdata = THTensor_(newWithTensor)(LIBRARY_STATE tensor);
|
||||
return (PyObject*)self.release();
|
||||
}
|
||||
|
||||
// torch.Tensor(torch.LongStorage sizes)
|
||||
if (num_args == 1 && THPLongStorage_Check(first_arg)) {
|
||||
THLongStorage *sizes = ((THPLongStorage*)first_arg)->cdata;
|
||||
self->cdata = THTensor_(newWithSize)(LIBRARY_STATE sizes, nullptr);
|
||||
return (PyObject *)self.release();
|
||||
}
|
||||
|
||||
// TODO: implement storageOffset, sizes and strides
|
||||
// torch.Tensor(torch.Storage data)
|
||||
if (num_args == 1 && THPStorage_(Check)(first_arg)) {
|
||||
THStorage *storage = ((THPStorage*)first_arg)->cdata;
|
||||
self->cdata = THTensor_(newWithStorage1d)(LIBRARY_STATE storage, 0, storage->size, -1);
|
||||
return (PyObject *)self.release();
|
||||
}
|
||||
|
||||
#ifdef NUMPY_TYPE_ENUM
|
||||
} else if (numpy_array) {
|
||||
self->cdata = THPTensor_(fromNumpy)(numpy_array.get());
|
||||
// torch.Tensor(np.ndarray array)
|
||||
if (num_args == 1 && PyArray_Check(first_arg) &&
|
||||
PyArray_TYPE((PyArrayObject*)first_arg) == NUMPY_TYPE_ENUM) {
|
||||
THPObjectPtr numpy_array =
|
||||
PyArray_FromArray((PyArrayObject*)first_arg, nullptr, NPY_ARRAY_BEHAVED);
|
||||
self->cdata = THPTensor_(fromNumpy)(numpy_array.get());
|
||||
return (PyObject*)self.release();
|
||||
}
|
||||
#endif
|
||||
} else if (iterable_arg && iterator_lengths.size() == 1 && iterator_lengths[0] == 0) {
|
||||
self->cdata = THTensor_(new)(LIBRARY_STATE_NOARGS);
|
||||
} else if (iterable_arg) {
|
||||
size_t iter_depth = iterator_lengths.size();
|
||||
std::stack<THPObjectPtr> iterator_stack;
|
||||
std::vector<size_t> items_processed(iter_depth);
|
||||
Py_INCREF(iterable_arg);
|
||||
THPObjectPtr item = iterable_arg;
|
||||
PyObject *iter;
|
||||
while (iterator_stack.size() != iter_depth) {
|
||||
iter = PyObject_GetIter(item);
|
||||
if (!iter) {
|
||||
THPUtils_setError("inconsistent iterator depth");
|
||||
return NULL;
|
||||
}
|
||||
iterator_stack.emplace(iter);
|
||||
item = PyIter_Next(iter);
|
||||
if (item == nullptr) {
|
||||
THPUtils_setError("error or empty iter");
|
||||
return NULL;
|
||||
}
|
||||
}
|
||||
THLongStoragePtr sizes = THLongStorage_newWithSize(iter_depth);
|
||||
long *sizes_data = sizes->data;
|
||||
for (size_t s: iterator_lengths) {
|
||||
*sizes_data++ = s;
|
||||
}
|
||||
THTensorPtr tensor = THTensor_(newWithSize)(LIBRARY_STATE sizes, NULL);
|
||||
|
||||
// TODO: do cuda in one transfer
|
||||
#ifndef THC_GENERIC_FILE
|
||||
#define SET_ITEM *data++ = THPUtils_(unpackReal)(item)
|
||||
real *data = tensor->storage->data;
|
||||
#else
|
||||
#define SET_ITEM item_value = THPUtils_(unpackReal)(item); THStorage_(set)(LIBRARY_STATE storage, item_nr++, item_value)
|
||||
real item_value;
|
||||
size_t item_nr = 0;
|
||||
THStorage *storage = tensor->storage;
|
||||
#endif
|
||||
try {
|
||||
SET_ITEM;
|
||||
items_processed[iter_depth-1]++;
|
||||
|
||||
while (!iterator_stack.empty()) {
|
||||
PyObject *iter = iterator_stack.top().get();
|
||||
// Parse items
|
||||
if (iterator_stack.size() == iter_depth) {
|
||||
while ((item = PyIter_Next(iter))) {
|
||||
SET_ITEM;
|
||||
items_processed[iter_depth-1]++;
|
||||
}
|
||||
if (items_processed[iter_depth-1] != iterator_lengths[iter_depth-1]) {
|
||||
THPUtils_setError("inconsistent size");
|
||||
return NULL;
|
||||
}
|
||||
iterator_stack.pop(); // this deallocates the iter
|
||||
// Iterate on lower depths
|
||||
} else {
|
||||
item = PyIter_Next(iter);
|
||||
if (item == nullptr) {
|
||||
if (PyErr_Occurred())
|
||||
return NULL;
|
||||
if (items_processed[iterator_stack.size()-1]) {
|
||||
THPUtils_setError("inconsistent size");
|
||||
return NULL;
|
||||
}
|
||||
iterator_stack.pop(); // this deallocates the iter
|
||||
} else {
|
||||
PyObject *new_iter = PyObject_GetIter(item);
|
||||
if (!new_iter) {
|
||||
THPUtils_setError("non-iterable item");
|
||||
return NULL;
|
||||
}
|
||||
items_processed[iterator_stack.size()] = 0;
|
||||
iterator_stack.emplace(new_iter);
|
||||
}
|
||||
}
|
||||
}
|
||||
} catch (std::exception &e) {
|
||||
if (std::string(e.what()).find("Could not parse real") != std::string::npos) {
|
||||
// TODO: error message
|
||||
THPUtils_setError("Expected an iterable of numbers!");
|
||||
}
|
||||
}
|
||||
self->cdata = tensor.release();
|
||||
} else {
|
||||
// torch.Tensor(Sequence data)
|
||||
if (num_args == 1 && PySequence_Check(first_arg)) {
|
||||
Py_ssize_t length = PySequence_Length(first_arg);
|
||||
THPUtils_assert(length >= 0, "couldn't obtain the length of %s",
|
||||
THPUtils_typename(first_arg));
|
||||
if (length == 0) {
|
||||
self->cdata = THTensor_(new)(LIBRARY_STATE_NOARGS);
|
||||
return (PyObject*)self.release();
|
||||
}
|
||||
|
||||
if (self->cdata == NULL)
|
||||
return NULL;
|
||||
Py_INCREF(first_arg);
|
||||
THPObjectPtr item = first_arg;
|
||||
std::vector<size_t> sizes;
|
||||
while ((length = PySequence_Length(item)) >= 0) {
|
||||
sizes.push_back(length);
|
||||
// TODO: check for string in this case
|
||||
THPUtils_assert(sizes.size() < 1000000, "already counted a million "
|
||||
"dimensions in a given sequence. Most likely your items are also "
|
||||
"sequences and there's no way to infer how many dimension should "
|
||||
"the tensor have");
|
||||
THPUtils_assert(length > 0, "given sequence has an invalid size of "
|
||||
"dimension %ld: %ld", (long)sizes.size(), (long)length);
|
||||
item = PySequence_GetItem(item, 0);
|
||||
if (!item)
|
||||
return NULL;
|
||||
}
|
||||
// Last length check has set an error flag, so we need to clear it.
|
||||
PyErr_Clear();
|
||||
|
||||
THLongStoragePtr sizes_storage = THLongStorage_newWithSize(sizes.size());
|
||||
long *sizes_data = sizes_storage->data;
|
||||
for (auto size: sizes)
|
||||
*sizes_data++ = size;
|
||||
THTensorPtr tensor = THTensor_(newWithSize)(LIBRARY_STATE sizes_storage, NULL);
|
||||
|
||||
int ndims = sizes.size();
|
||||
std::vector<size_t> indices(ndims);
|
||||
std::vector<THPObjectPtr> sequences(ndims);
|
||||
Py_INCREF(first_arg);
|
||||
item = first_arg;
|
||||
for (size_t i = 0; i < sequences.size(); i++) {
|
||||
PyObject *item_ptr = item.get();
|
||||
sequences[i] = std::move(item);
|
||||
if (i < sequences.size()-1) {
|
||||
item = PySequence_ITEM(item_ptr, 0);
|
||||
if (!item)
|
||||
return NULL;
|
||||
}
|
||||
}
|
||||
|
||||
#ifndef THC_GENERIC_FILE
|
||||
#define SET_ITEM *data++ = THPUtils_(unpackReal)(item)
|
||||
real *data = tensor->storage->data;
|
||||
#else
|
||||
#define SET_ITEM item_value = THPUtils_(unpackReal)(item); THStorage_(set)(LIBRARY_STATE storage, item_nr++, item_value)
|
||||
real item_value;
|
||||
size_t item_nr = 0;
|
||||
THStorage *storage = tensor->storage;
|
||||
#endif
|
||||
THPObjectPtr final_sequence;
|
||||
while (true) {
|
||||
final_sequence = std::move(sequences[ndims-1]);
|
||||
try {
|
||||
// We're taking a fast-track over the last dimension
|
||||
for (size_t i = 0; i < sizes[ndims-1]; i++) {
|
||||
indices[ndims-1] = i;
|
||||
item = PySequence_ITEM(final_sequence, i);
|
||||
// We've checked the length earlier, so it must have been an error
|
||||
if (!item)
|
||||
return NULL;
|
||||
SET_ITEM;
|
||||
}
|
||||
} catch(std::runtime_error &e) {
|
||||
std::string index = THPTensor_(indicesToString)(indices, ndims-1);
|
||||
THPUtils_setError("tried to construct a tensor from a %s%s sequence, "
|
||||
"but found an item of type %s at index %s",
|
||||
(ndims > 1 ? "nested " : ""),
|
||||
THPUtils_typeTraits<real>::python_type_str,
|
||||
THPUtils_typename(item.get()),
|
||||
index.c_str());
|
||||
return NULL;
|
||||
}
|
||||
#undef SET_ITEM
|
||||
|
||||
// Update the counters
|
||||
int dim = ndims-2;
|
||||
size_t last_updated_dim = dim;
|
||||
while (dim >= 0) {
|
||||
last_updated_dim = dim;
|
||||
if (++indices[dim] == sizes[dim])
|
||||
indices[dim--] = 0;
|
||||
else
|
||||
break;
|
||||
}
|
||||
// Check if we've just made a full cycle
|
||||
if ((last_updated_dim == 0 && indices[0] == 0) || ndims == 1)
|
||||
break;
|
||||
// Update sequences
|
||||
for (int i = last_updated_dim+1; i < ndims; i++) {
|
||||
sequences[i] = PySequence_ITEM(sequences[i-1], indices[i-1]);
|
||||
if (!sequences[i]) {
|
||||
THPTensor_(setInconsistentDepthError)(sizes, indices, i, indices[i]);
|
||||
return NULL;
|
||||
}
|
||||
if (!PySequence_Check(sequences[i])) {
|
||||
std::string index_str = THPTensor_(indicesToString)(indices, i);
|
||||
THPUtils_setError("an item of time %s at index %s doesn't implement "
|
||||
"a sequence protocol");
|
||||
return NULL;
|
||||
}
|
||||
Py_ssize_t length = PySequence_Length(sequences[i]);
|
||||
if (length < 0) {
|
||||
std::string index_str = THPTensor_(indicesToString)(indices, i);
|
||||
THPUtils_setError("could not obtain a length of %s at index %s",
|
||||
THPUtils_typename(sequences[i].get()), index_str.c_str());
|
||||
return NULL;
|
||||
}
|
||||
if ((size_t)length != sizes[i]) {
|
||||
THPTensor_(setInconsistentDepthError)(sizes, indices, i, length);
|
||||
return NULL;
|
||||
}
|
||||
}
|
||||
}
|
||||
self->cdata = tensor.release();
|
||||
return (PyObject *)self.release();
|
||||
}
|
||||
return (PyObject *)self.release();
|
||||
|
||||
// torch.Tensor(int ...)
|
||||
try {
|
||||
THLongStoragePtr sizes = THPUtils_getLongStorage(args);
|
||||
self->cdata = THTensor_(newWithSize)(LIBRARY_STATE sizes, nullptr);
|
||||
return (PyObject *)self.release();
|
||||
} catch(std::exception &e) {};
|
||||
|
||||
THPUtils_invalidArguments(args, THPTensorStr " constructor", 6,
|
||||
"no arguments",
|
||||
"(int ...)",
|
||||
"(" THPTensorStr " viewed_tensor)",
|
||||
"(torch.LongStorage sizes)",
|
||||
"(" THPStorageStr " data)",
|
||||
"(Sequence data)");
|
||||
return NULL;
|
||||
END_HANDLE_TH_ERRORS
|
||||
}
|
||||
|
||||
@ -255,8 +297,9 @@ static PyObject * THPTensor_(pynew)(PyTypeObject *type, PyObject *args, PyObject
|
||||
long dimsize = THTensor_(size)(LIBRARY_STATE TENSOR_VARIABLE, DIM); \
|
||||
idx = (idx < 0) ? dimsize + idx : idx; \
|
||||
\
|
||||
THArgCheck(dimsize > 0, 1, "empty tensor"); \
|
||||
THArgCheck(idx >= 0 && idx < dimsize, 2, "out of range"); \
|
||||
THPUtils_assert(dimsize > 0, "indexing an empty tensor"); \
|
||||
THPUtils_assert(idx >= 0 && idx < dimsize, "index %ld is out of range for " \
|
||||
"dimension %ld (of size %ld)", idx, DIM, dimsize); \
|
||||
\
|
||||
if(THTensor_(nDimension)(LIBRARY_STATE TENSOR_VARIABLE) == 1) { \
|
||||
CASE_1D; \
|
||||
@ -284,17 +327,7 @@ static bool THPTensor_(_index)(THPTensor *self, PyObject *index,
|
||||
tresult = THTensor_(newWithTensor)(LIBRARY_STATE self_t);
|
||||
THTensor_(select)(LIBRARY_STATE tresult, NULL, 0, idx)
|
||||
)
|
||||
// Indexing with a single element tuple
|
||||
} else if (PyTuple_Check(index) &&
|
||||
PyTuple_Size(index) == 1 &&
|
||||
(PyLong_Check(PyTuple_GET_ITEM(index, 0))
|
||||
|| PyInt_Check(PyTuple_GET_ITEM(index, 0)))) {
|
||||
PyObject *index_obj = PyTuple_GET_ITEM(index, 0);
|
||||
tresult = THTensor_(newWithTensor)(LIBRARY_STATE self->cdata);
|
||||
INDEX_LONG(0, index_obj, tresult,
|
||||
THTensor_(narrow)(LIBRARY_STATE tresult, NULL, 0, idx, 1),
|
||||
THTensor_(narrow)(LIBRARY_STATE tresult, NULL, 0, idx, 1)
|
||||
)
|
||||
return true;
|
||||
// Indexing with a slice
|
||||
} else if (PySlice_Check(index)) {
|
||||
tresult = THTensor_(newWithTensor)(LIBRARY_STATE self->cdata);
|
||||
@ -302,14 +335,18 @@ static bool THPTensor_(_index)(THPTensor *self, PyObject *index,
|
||||
if (!THPUtils_parseSlice(index, THTensor_(size)(LIBRARY_STATE tresult, 0), &start, &end, &length))
|
||||
return false;
|
||||
THTensor_(narrow)(LIBRARY_STATE tresult, NULL, 0, start, length);
|
||||
return true;
|
||||
// Indexing multiple dimensions
|
||||
} else if(PyTuple_Check(index)) {
|
||||
THArgCheck(PyTuple_Size(index) <= THTensor_(nDimension)(LIBRARY_STATE self->cdata), 2,
|
||||
"Indexing too many dimensions");
|
||||
long num_index_dim = (long)PyTuple_Size(index);
|
||||
long num_tensor_dim = THTensor_(nDimension)(LIBRARY_STATE self->cdata);
|
||||
THPUtils_assert(num_index_dim <= num_tensor_dim, "trying to index %ld "
|
||||
"dimensions of a %ld dimensional tensor", num_index_dim,
|
||||
num_tensor_dim);
|
||||
|
||||
tresult = THTensor_(newWithTensor)(LIBRARY_STATE self->cdata);
|
||||
int t_dim = 0;
|
||||
|
||||
for(int dim = 0; dim < PyTuple_Size(index); dim++) {
|
||||
for(int dim = 0; dim < num_index_dim; dim++) {
|
||||
PyObject *dimidx = PyTuple_GET_ITEM(index, dim);
|
||||
if(THPUtils_checkLong(dimidx)) {
|
||||
INDEX_LONG(t_dim, dimidx, tresult,
|
||||
@ -322,46 +359,39 @@ static bool THPTensor_(_index)(THPTensor *self, PyObject *index,
|
||||
// >1D tensor
|
||||
THTensor_(select)(LIBRARY_STATE tresult, NULL, t_dim, idx)
|
||||
)
|
||||
} else if (PyTuple_Check(dimidx)) {
|
||||
long length = 1;
|
||||
if (PyTuple_Size(dimidx) == 0 || PyTuple_Size(dimidx) > 2 || !THPUtils_checkLong(PyTuple_GET_ITEM(dimidx, 0))) {
|
||||
PyErr_SetString(PyExc_RuntimeError, "Expected one or two integers");
|
||||
return false;
|
||||
}
|
||||
PyObject *index_obj = PyTuple_GET_ITEM(dimidx, 0);
|
||||
if (PyTuple_Size(dimidx) == 2) {
|
||||
long idx;
|
||||
if (!THPUtils_checkLong(PyTuple_GET_ITEM(dimidx, 1))) {
|
||||
THPUtils_setError("Expected one or two intetegers");
|
||||
return false;
|
||||
}
|
||||
idx = THPUtils_unpackLong(index_obj);
|
||||
length = THPUtils_unpackLong(PyTuple_GET_ITEM(dimidx, 1));
|
||||
length -= idx;
|
||||
}
|
||||
INDEX_LONG(t_dim, index_obj, tresult,
|
||||
THTensor_(narrow)(LIBRARY_STATE tresult, NULL, t_dim++, idx, length),
|
||||
THTensor_(narrow)(LIBRARY_STATE tresult, NULL, t_dim++, idx, length)
|
||||
)
|
||||
} else if (PySlice_Check(dimidx)) {
|
||||
Py_ssize_t start, end, length;
|
||||
if (!THPUtils_parseSlice(dimidx, THTensor_(size)(LIBRARY_STATE tresult, t_dim), &start, &end, &length))
|
||||
return false;
|
||||
THTensor_(narrow)(LIBRARY_STATE tresult, NULL, t_dim++, start, length);
|
||||
} else {
|
||||
PyErr_SetString(PyExc_RuntimeError, "Slicing with an unsupported type");
|
||||
return false;
|
||||
THTensor_(free)(LIBRARY_STATE tresult);
|
||||
goto invalid_index_type;
|
||||
}
|
||||
}
|
||||
return true;
|
||||
}
|
||||
return true;
|
||||
} catch(...) {
|
||||
THTensor_(free)(LIBRARY_STATE tresult);
|
||||
if (tresult) {
|
||||
THTensor_(free)(LIBRARY_STATE tresult);
|
||||
tresult = NULL;
|
||||
}
|
||||
throw;
|
||||
}
|
||||
|
||||
invalid_index_type:
|
||||
THPUtils_setError("indexing a tensor with an object of type %s. The only "
|
||||
"supported types are integers, slices and "
|
||||
#ifndef THC_GENERIC_FILE
|
||||
"torch.ByteTensor.",
|
||||
#else
|
||||
"torch.cuda.ByteTensor.",
|
||||
#endif
|
||||
THPUtils_typename(index));
|
||||
return false;
|
||||
}
|
||||
#undef INDEX_LONG
|
||||
#undef GET_PTR_1D
|
||||
#undef GET_OFFSET
|
||||
|
||||
static PyObject * THPTensor_(getValue)(THPTensor *self, PyObject *index)
|
||||
{
|
||||
@ -371,99 +401,96 @@ static PyObject * THPTensor_(getValue)(THPTensor *self, PyObject *index)
|
||||
THTensor *t = THTensor_(new)(LIBRARY_STATE_NOARGS);
|
||||
THTensor_(maskedSelect)(LIBRARY_STATE t, self->cdata, ((THPByteTensor*)index)->cdata);
|
||||
return THPTensor_(New)(t);
|
||||
#elif defined(THC_REAL_IS_FLOAT)
|
||||
}
|
||||
#else
|
||||
if(THCPByteTensor_Check(index)) {
|
||||
THTensor *t = THTensor_(new)(LIBRARY_STATE_NOARGS);
|
||||
THTensor_(maskedSelect)(LIBRARY_STATE t, self->cdata, ((THCPByteTensor*)index)->cdata);
|
||||
return THPTensor_(New)(t);
|
||||
#else
|
||||
if (false) {
|
||||
}
|
||||
#endif
|
||||
} else {
|
||||
THTensor *tresult; // TODO: free on error
|
||||
THStorage *sresult;
|
||||
long storage_offset;
|
||||
if (!THPTensor_(_index)(self, index, tresult, sresult, storage_offset))
|
||||
return NULL;
|
||||
|
||||
THTensor *tresult;
|
||||
THStorage *sresult;
|
||||
long storage_offset;
|
||||
if (!THPTensor_(_index)(self, index, tresult, sresult, storage_offset))
|
||||
return NULL;
|
||||
try {
|
||||
if (tresult)
|
||||
return THPTensor_(New)(tresult);
|
||||
if (sresult)
|
||||
return THPUtils_(newReal)(THStorage_(get)(LIBRARY_STATE sresult, storage_offset));
|
||||
char err_string[512];
|
||||
snprintf (err_string, 512,
|
||||
"%s %s", "Unknown exception in THPTensor_(getValue). Index type is: ",
|
||||
index->ob_type->tp_name);
|
||||
PyErr_SetString(PyExc_RuntimeError, err_string);
|
||||
return NULL;
|
||||
} catch (...) {
|
||||
if (tresult) {
|
||||
THTensor_(free)(LIBRARY_STATE tresult);
|
||||
tresult = NULL;
|
||||
}
|
||||
throw;
|
||||
}
|
||||
THPUtils_setError("An unknown error has occured when indexing a tensor "
|
||||
"in THPTensor_(getValue). Please report this in a github issue at: "
|
||||
"https://github.com/pytorch/pytorch");
|
||||
return NULL;
|
||||
END_HANDLE_TH_ERRORS
|
||||
}
|
||||
|
||||
//extern PyObject * THPTensor_(copy)(THPTensor *self, PyObject *other);
|
||||
int THPTensor_(setValue)(THPTensor *self, PyObject *index, PyObject *value)
|
||||
{
|
||||
HANDLE_TH_ERRORS
|
||||
#if !defined(THC_GENERIC_FILE) || defined(THC_REAL_IS_FLOAT)
|
||||
#ifdef THC_REAL_IS_FLOAT
|
||||
if (THCPByteTensor_Check(index)) {
|
||||
THCPByteTensor *mask = (THCPByteTensor*)index;
|
||||
#else
|
||||
#ifndef THC_GENERIC_FILE
|
||||
if (THPByteTensor_Check(index)) {
|
||||
THPByteTensor *mask = (THPByteTensor*)index;
|
||||
#else
|
||||
if (THCPByteTensor_Check(index)) {
|
||||
THCPByteTensor *mask = (THCPByteTensor*)index;
|
||||
#endif
|
||||
if (THPUtils_(checkReal)(value)) {
|
||||
if (!THPUtils_(checkReal)(value)) {
|
||||
// TODO: better error message
|
||||
THPUtils_setError("TODO");
|
||||
return -1;
|
||||
}
|
||||
real v = THPUtils_(unpackReal)(value);
|
||||
THTensor_(maskedFill)(LIBRARY_STATE self->cdata, mask->cdata, v);
|
||||
} else if (THPTensor_(Check)(value)) {
|
||||
THTensor_(maskedCopy)(LIBRARY_STATE self->cdata, mask->cdata, ((THPTensor*)value)->cdata);
|
||||
} else {
|
||||
THError("number or Tensor expected");
|
||||
}
|
||||
#else
|
||||
if (false) {
|
||||
#endif
|
||||
} else {
|
||||
THTensor *tresult;
|
||||
THStorage *sresult;
|
||||
long storage_offset;
|
||||
real v;
|
||||
if (!THPTensor_(_index)(self, index, tresult, sresult, storage_offset))
|
||||
return -1;
|
||||
|
||||
THTensorPtr tresult_ptr = tresult;
|
||||
if (sresult) {
|
||||
if (!THPUtils_(checkReal)(value)) {
|
||||
// TODO: better error message
|
||||
THPUtils_setError("TODO");
|
||||
return -1;
|
||||
}
|
||||
v = THPUtils_(unpackReal)(value);
|
||||
THStorage_(set)(LIBRARY_STATE sresult, storage_offset, v);
|
||||
} else if (tresult) {
|
||||
if (THPUtils_(checkReal)(value)) {
|
||||
v = THPUtils_(unpackReal)(value);
|
||||
THTensor_(fill)(LIBRARY_STATE tresult, v);
|
||||
} else {
|
||||
// TODO: try to do this without creating a temporary object
|
||||
THPTensorPtr tmp = (THPTensor*)THPTensor_(New)(tresult_ptr.get());
|
||||
if (!tmp)
|
||||
return -1;
|
||||
tresult_ptr.release();
|
||||
if (!THPModule_tensorCopy((PyObject*)tmp.get(), value))
|
||||
return -1;
|
||||
}
|
||||
} else {
|
||||
// TODO: error message
|
||||
THPUtils_setError("error");
|
||||
return -1;
|
||||
THPUtils_setError("can't assign %s to a " THPTensorStr " using a mask "
|
||||
"(only " THPTensorStr " or %s are supported)",
|
||||
THPUtils_typename(value), THPUtils_typeTraits<real>::python_type_str);
|
||||
// TODO
|
||||
}
|
||||
return 0;
|
||||
}
|
||||
return 0;
|
||||
|
||||
THTensor *tresult;
|
||||
THStorage *sresult;
|
||||
long storage_offset;
|
||||
if (!THPTensor_(_index)(self, index, tresult, sresult, storage_offset))
|
||||
return -1;
|
||||
|
||||
THTensorPtr tresult_ptr = tresult;
|
||||
if (sresult) {
|
||||
if (!THPUtils_(checkReal)(value)) {
|
||||
THPUtils_setError("can't assign a %s to a scalar value of type %s",
|
||||
THPUtils_typename(value), THPUtils_typeTraits<real>::python_type_str);
|
||||
return -1;
|
||||
}
|
||||
THStorage_(set)(LIBRARY_STATE sresult, storage_offset, THPUtils_(unpackReal)(value));
|
||||
return 0;
|
||||
} else if (tresult) {
|
||||
if (THPUtils_(checkReal)(value)) {
|
||||
THTensor_(fill)(LIBRARY_STATE tresult, THPUtils_(unpackReal)(value));
|
||||
} else {
|
||||
// TODO: try to do this without creating a temporary object
|
||||
THPTensorPtr tmp = (THPTensor*)THPTensor_(New)(tresult_ptr.get());
|
||||
if (!tmp)
|
||||
return -1;
|
||||
tresult_ptr.release();
|
||||
if (!THPModule_tensorCopy((PyObject*)tmp.get(), value))
|
||||
return -1;
|
||||
}
|
||||
return 0;
|
||||
}
|
||||
THPUtils_setError("An unknown error has occured when indexing a tensor "
|
||||
"in THPTensor_(setValue). Please report this in a github issue at: "
|
||||
"https://github.com/pytorch/pytorch");
|
||||
return -1;
|
||||
END_HANDLE_TH_ERRORS_RET(-1)
|
||||
}
|
||||
|
||||
|
@ -10,4 +10,17 @@ typedef class THPPointer<THTensor> THTensorPtr;
|
||||
typedef class THPPointer<THPStorage> THPStoragePtr;
|
||||
typedef class THPPointer<THPTensor> THPTensorPtr;
|
||||
|
||||
#if !defined(THC_GENERIC_FILE) || defined(THC_REAL_IS_HALF)
|
||||
template<>
|
||||
struct THPUtils_typeTraits<real> {
|
||||
#if defined(TH_REAL_IS_FLOAT) || defined(TH_REAL_IS_DOUBLE) || \
|
||||
defined(THC_REAL_IS_FLOAT) || defined(THC_REAL_IS_DOUBLE) || \
|
||||
defined(THC_REAL_IS_HALF)
|
||||
static constexpr char *python_type_str = "float";
|
||||
#else
|
||||
static constexpr char *python_type_str = "int";
|
||||
#endif
|
||||
};
|
||||
#endif
|
||||
|
||||
#endif
|
||||
|
@ -16,12 +16,10 @@ int THPUtils_getCallable(PyObject *arg, PyObject **result) {
|
||||
|
||||
|
||||
THLongStorage * THPUtils_getLongStorage(PyObject *args, int ignore_first) {
|
||||
// TODO: error messages
|
||||
long value;
|
||||
|
||||
Py_ssize_t length = PyTuple_Size(args);
|
||||
if (length < ignore_first+1)
|
||||
throw std::logic_error("Provided too few arguments");
|
||||
throw std::runtime_error("Provided " + std::to_string(length) +
|
||||
" arguments, but expected at least " + std::to_string(ignore_first+1));
|
||||
|
||||
// Maybe there's a LongStorage
|
||||
PyObject *first_arg = PyTuple_GET_ITEM(args, ignore_first);
|
||||
@ -36,9 +34,10 @@ THLongStorage * THPUtils_getLongStorage(PyObject *args, int ignore_first) {
|
||||
for (Py_ssize_t i = ignore_first; i < length; ++i) {
|
||||
PyObject *arg = PyTuple_GET_ITEM(args, i);
|
||||
if (!THPUtils_checkLong(arg))
|
||||
throw std::invalid_argument("Expected a numeric argument, but got " + std::string(Py_TYPE(arg)->tp_name));
|
||||
value = THPUtils_unpackLong(arg);
|
||||
result->data[i-ignore_first] = value;
|
||||
throw std::invalid_argument("Expected an int argument, but got " +
|
||||
std::string(THPUtils_typename(arg)) + "at position " +
|
||||
std::to_string(i));
|
||||
result->data[i-ignore_first] = THPUtils_unpackLong(arg);
|
||||
}
|
||||
return result.release();
|
||||
}
|
||||
@ -134,11 +133,11 @@ bool THPUtils_parseSlice(PyObject *slice, Py_ssize_t len, Py_ssize_t *ostart, Py
|
||||
(PySliceObject *)slice,
|
||||
#endif
|
||||
len, &start, &stop, &step, &slicelength) < 0) {
|
||||
PyErr_SetString(PyExc_RuntimeError, "Got an invalid slice");
|
||||
return false;
|
||||
}
|
||||
if (step != 1) {
|
||||
PyErr_SetString(PyExc_RuntimeError, "Only step == 1 supported");
|
||||
THPUtils_setError("Trying to slice with a step of %ld, but only a step of "
|
||||
"1 is supported", (long)step);
|
||||
return false;
|
||||
}
|
||||
*ostart = start;
|
||||
|
@ -6,6 +6,8 @@
|
||||
|
||||
#define THPUtils_(NAME) TH_CONCAT_4(THP,Real,Utils_,NAME)
|
||||
|
||||
#define THPUtils_typename(obj) (Py_TYPE(obj)->tp_name)
|
||||
|
||||
|
||||
#if PY_MAJOR_VERSION == 2
|
||||
#define THPUtils_checkLong(obj) (PyLong_Check(obj) || PyInt_Check(obj))
|
||||
@ -151,6 +153,7 @@ public:
|
||||
T * release() { T *tmp = ptr; ptr = NULL; return tmp; }
|
||||
operator T*() { return ptr; }
|
||||
THPPointer& operator =(T *new_ptr) { free(); ptr = new_ptr; return *this; }
|
||||
THPPointer& operator =(THPPointer &&p) { free(); ptr = p.ptr; p.ptr = nullptr; return *this; }
|
||||
T * operator ->() { return ptr; }
|
||||
operator bool() { return ptr != nullptr; }
|
||||
|
||||
@ -162,6 +165,9 @@ private:
|
||||
typedef THPPointer<PyObject> THPObjectPtr;
|
||||
typedef THPPointer<THPGenerator> THPGeneratorPtr;
|
||||
|
||||
template <typename T>
|
||||
struct THPUtils_typeTraits {};
|
||||
|
||||
#include "generic/utils.h"
|
||||
#include <TH/THGenerateAllTypes.h>
|
||||
|
||||
|
@ -51,11 +51,11 @@ class ClassSimplexCriterion(MSECriterion):
|
||||
if k == 0:
|
||||
a[k][k] = 1
|
||||
else:
|
||||
a[k][k] = math.sqrt(1 - a[(k,), (0, k)].norm()**2)
|
||||
a[k][k] = math.sqrt(1 - a[k:k+1, 0:k+1].norm()**2)
|
||||
|
||||
# fill_ the k-th coordinates for the vectors of the remaining vertices
|
||||
c = (a[k][k]**2 - 1 - 1/n) / a[k][k]
|
||||
a[(k+1, n+1), (k,)].fill_(c)
|
||||
a[k+1:n+2, k:k+1].fill_(c)
|
||||
|
||||
return a
|
||||
|
||||
|
Reference in New Issue
Block a user