mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
940 lines
28 KiB
C
940 lines
28 KiB
C
#ifndef TH_GENERIC_FILE
|
|
#define TH_GENERIC_FILE "generic/Tensor.c"
|
|
#else
|
|
|
|
static void torch_Tensor_(c_readTensorStorageSizeStride)(lua_State *L, int index, int allowNone, int allowTensor, int allowStorage, int allowStride,
|
|
THStorage **storage_, long *storageOffset_, THLongStorage **size_, THLongStorage **stride_);
|
|
|
|
static void torch_Tensor_(c_readSizeStride)(lua_State *L, int index, int allowStride, THLongStorage **size_, THLongStorage **stride_);
|
|
|
|
/* Lua: tensor:size([dim])
 * With a 1-based dimension argument, pushes that dimension's size.
 * Without one, pushes a new torch.LongStorage holding every size. */
static int torch_Tensor_(size)(lua_State *L)
{
  THTensor *tensor = luaT_checkudata(L, 1, torch_Tensor_id);

  if(!lua_isnumber(L, 2))
  {
    /* no dimension given: hand back a copy of the full size vector */
    THLongStorage *sizes = THLongStorage_newWithSize(tensor->nDimension);
    memmove(sizes->data, tensor->size, sizeof(long)*tensor->nDimension);
    luaT_pushudata(L, sizes, torch_LongStorage_id);
    return 1;
  }

  {
    int dim = luaL_checkint(L, 2)-1; /* Lua is 1-based */
    luaL_argcheck(L, dim >= 0 && dim < tensor->nDimension, 2, "out of range");
    lua_pushnumber(L, tensor->size[dim]);
    return 1;
  }
}
|
|
|
|
/* Lua: tensor:stride([dim])
 * With a 1-based dimension argument, pushes that dimension's stride.
 * Without one, pushes a new torch.LongStorage holding every stride. */
static int torch_Tensor_(stride)(lua_State *L)
{
  THTensor *tensor = luaT_checkudata(L, 1, torch_Tensor_id);

  if(!lua_isnumber(L, 2))
  {
    /* no dimension given: hand back a copy of the full stride vector */
    THLongStorage *strides = THLongStorage_newWithSize(tensor->nDimension);
    memmove(strides->data, tensor->stride, sizeof(long)*tensor->nDimension);
    luaT_pushudata(L, strides, torch_LongStorage_id);
    return 1;
  }

  {
    int dim = luaL_checkint(L, 2)-1; /* Lua is 1-based */
    luaL_argcheck(L, dim >= 0 && dim < tensor->nDimension, 2, "out of range");
    lua_pushnumber(L, tensor->stride[dim]);
    return 1;
  }
}
|
|
|
|
/* Lua: tensor:dim() / tensor:nDimension() -- number of dimensions. */
static int torch_Tensor_(nDimension)(lua_State *L)
{
  THTensor *self = luaT_checkudata(L, 1, torch_Tensor_id);
  lua_pushnumber(L, self->nDimension);
  return 1;
}
|
|
|
|
/* Lua: tensor:storage() -- the underlying storage, or nil if the tensor
 * has none.  The storage is retained so Lua owns its own reference. */
static int torch_Tensor_(storage)(lua_State *L)
{
  THTensor *self = luaT_checkudata(L, 1, torch_Tensor_id);

  if(!self->storage)
  {
    lua_pushnil(L);
    return 1;
  }

  THStorage_(retain)(self->storage); /* Lua-side reference */
  luaT_pushudata(L, self->storage, torch_Storage_id);
  return 1;
}
|
|
|
|
/* Lua: tensor:storageOffset() -- offset into the storage, converted to
 * Lua's 1-based convention (internal offset is 0-based). */
static int torch_Tensor_(storageOffset)(lua_State *L)
{
  THTensor *self = luaT_checkudata(L, 1, torch_Tensor_id);
  lua_pushnumber(L, self->storageOffset + 1);
  return 1;
}
|
|
|
|
/* Constructor: torch.Tensor(...)
 *
 * Two argument forms:
 *   - a (possibly nested) Lua table of numbers: builds a tensor whose sizes
 *     are inferred by walking element [1] at each nesting level, then fills
 *     it by iterating the tables in row-major order;
 *   - anything else: delegated to c_readTensorStorageSizeStride (none /
 *     tensor / storage [+offset] / sizes[+strides]).
 *
 * Pushes the new tensor and returns 1.
 */
static int torch_Tensor_(new)(lua_State *L)
{
  THTensor *tensor;
  long storageOffset;
  THLongStorage *size, *stride;

  if(lua_type(L, 1) == LUA_TTABLE)
  {
    long i, j;
    THLongStorage *counter;   /* per-dimension cursor while walking tables */
    long si = 0;              /* linear write position into the storage */
    int dimension = 0;
    int is_finished = 0;

    lua_settop(L, 1);
    size = THLongStorage_new();

    /* Infer the shape: descend through element [1] of each nested table,
       recording lua_objlen at each level as that dimension's size. */
    while( (lua_type(L, -1) == LUA_TTABLE) && (lua_objlen(L, -1) > 0) )
    {
      THLongStorage_resize(size, dimension+1);
      size->data[dimension] = lua_objlen(L, -1);
      dimension++;
      lua_rawgeti(L, -1, 1);
    }
    /* pop the first non-table (or empty-table) probe value */
    lua_pop(L, 1);

    counter = THLongStorage_newWithSize(size->size);
    THLongStorage_fill(counter, 0);

    tensor = THTensor_(newWithSize)(size, NULL);

    /* empty table => 0-dimensional tensor, nothing to fill */
    if(size->size == 0)
      is_finished = 1;

    /* Fill loop: the stack top is always the innermost table currently
       being read; counter tracks the index taken at each outer level. */
    while(!is_finished)
    {
      if(!lua_istable(L, -1))
      {
        THLongStorage_free(size);
        THLongStorage_free(counter);
        THTensor_(free)(tensor);
        luaL_error(L, "invalid tensor definition");
      }

      /* every innermost table must match the inferred last-dim size */
      if(lua_objlen(L, -1) != size->data[size->size-1])
      {
        THLongStorage_free(size);
        THLongStorage_free(counter);
        THTensor_(free)(tensor);
        luaL_error(L, "invalid tensor sizes");
      }

      /* copy one innermost row into the storage */
      for(i = 0; i < size->data[size->size-1]; i++)
      {
        lua_rawgeti(L, -1, i+1);
        if(!lua_isnumber(L, -1))
        {
          THLongStorage_free(size);
          THLongStorage_free(counter);
          THTensor_(free)(tensor);
          luaL_error(L, "invalid element (not a number)");
        }
        THStorage_(set)(THTensor_(storage)(tensor), si++, (real)lua_tonumber(L, -1));
        lua_pop(L, 1);
      }

      /* 1-d tensor: a single row was the whole thing */
      if(size->size == 1)
        break;

      /* Advance the multi-dimensional counter (odometer-style, from the
         second-to-last dimension inward), re-descending into the table
         hierarchy for the next innermost row. */
      for(i = size->size-2; i >= 0; i--)
      {
        if(++counter->data[i] == size->data[i])
        {
          if(i == 0)
          {
            /* outermost dimension exhausted: done */
            is_finished = 1;
            break;
          }
          else
          {
            /* carry: reset this level and pop its table */
            counter->data[i] = 0;
            lua_pop(L, 1);
          }
        }
        else
        {
          /* no carry: pop the finished subtree and descend to the next one,
             validating each level's type and length on the way down */
          lua_pop(L, 1);
          for(j = i; j < size->size-1; j++)
          {
            if(!lua_istable(L, -1))
            {
              THLongStorage_free(size);
              THLongStorage_free(counter);
              THTensor_(free)(tensor);
              luaL_error(L, "invalid tensor definition");
            }
            if(lua_objlen(L, -1) != size->data[j])
            {
              THLongStorage_free(size);
              THLongStorage_free(counter);
              THTensor_(free)(tensor);
              luaL_error(L, "invalid tensor sizes");
            }
            lua_rawgeti(L, -1, counter->data[j]+1);
          }
          break;
        }
      }
    }

    THLongStorage_free(size);
    THLongStorage_free(counter);
  }
  else
  {
    THStorage *storage;

    torch_Tensor_(c_readTensorStorageSizeStride)(L, 1, 1, 1, 1, 1,
                                                 &storage, &storageOffset, &size, &stride);

    tensor = THTensor_(newWithStorage)(storage, storageOffset, size, stride);

    /* newWithStorage copied what it needed; release our references */
    THLongStorage_free(size);
    THLongStorage_free(stride);
  }

  luaT_pushudata(L, tensor, torch_Tensor_id);
  return 1;
}
|
|
|
|
/* Lua: tensor:set(...) -- repoint this tensor at the given storage /
 * geometry (or at another tensor's).  Returns the tensor itself. */
static int torch_Tensor_(set)(lua_State *L)
{
  THTensor *self = luaT_checkudata(L, 1, torch_Tensor_id);
  THStorage *srcStorage;
  long srcOffset;
  THLongStorage *srcSize, *srcStride;

  /* parse: nothing / tensor / storage [+offset] / sizes [+strides] */
  torch_Tensor_(c_readTensorStorageSizeStride)(L, 2, 1, 1, 1, 1,
                                               &srcStorage, &srcOffset, &srcSize, &srcStride);

  THTensor_(setStorage)(self, srcStorage, srcOffset, srcSize, srcStride);

  /* setStorage took what it needed; drop our local references */
  THLongStorage_free(srcSize);
  THLongStorage_free(srcStride);

  lua_settop(L, 1); /* leave only the tensor, which we return */
  return 1;
}
|
|
|
|
/* Lua: tensor:clone() -- deep copy with fresh storage, pushed as a new tensor. */
static int torch_Tensor_(clone)(lua_State *L)
{
  THTensor *self = luaT_checkudata(L, 1, torch_Tensor_id);
  luaT_pushudata(L, THTensor_(newClone)(self), torch_Tensor_id);
  return 1;
}
|
|
|
|
/* Lua: tensor:contiguous() -- a contiguous tensor with the same contents
 * (newContiguous decides whether a copy is actually needed). */
static int torch_Tensor_(contiguous)(lua_State *L)
{
  THTensor *self = luaT_checkudata(L, 1, torch_Tensor_id);
  luaT_pushudata(L, THTensor_(newContiguous)(self), torch_Tensor_id);
  return 1;
}
|
|
|
|
/* Resize */
|
|
/* Lua: tensor:resizeAs(src) -- resize tensor to src's geometry; returns tensor. */
static int torch_Tensor_(resizeAs)(lua_State *L)
{
  THTensor *self = luaT_checkudata(L, 1, torch_Tensor_id);
  THTensor *model = luaT_checkudata(L, 2, torch_Tensor_id);

  THTensor_(resizeAs)(self, model);

  lua_settop(L, 1); /* return the (resized) tensor itself */
  return 1;
}
|
|
|
|
/* Lua: tensor:resize(sizes...) or tensor:resize(longStorage)
 * Resizes in place (no strides accepted here); returns the tensor. */
static int torch_Tensor_(resize)(lua_State *L)
{
  THTensor *self = luaT_checkudata(L, 1, torch_Tensor_id);
  THLongStorage *newSize, *newStride;

  torch_Tensor_(c_readSizeStride)(L, 2, 0, &newSize, &newStride);

  THTensor_(resize)(self, newSize, newStride);

  THLongStorage_free(newSize);
  THLongStorage_free(newStride);

  lua_settop(L, 1);
  return 1;
}
|
|
|
|
/* Lua: tensor:narrow(dim, firstIndex, size) -- view of `size` slices along
 * `dim` starting at `firstIndex` (both 1-based on the Lua side).
 * Range validation is delegated to THTensor_(narrow) itself. */
static int torch_Tensor_(narrow)(lua_State *L)
{
  THTensor *self = luaT_checkudata(L, 1, torch_Tensor_id);
  int dim = luaL_checkint(L, 2)-1;
  long first = luaL_checklong(L, 3)-1;
  long len = luaL_checklong(L, 4);
  THTensor *view;

  view = THTensor_(newWithTensor)(self);
  THTensor_(narrow)(view, NULL, dim, first, len);
  luaT_pushudata(L, view, torch_Tensor_id);
  return 1;
}
|
|
|
|
static int torch_Tensor_(sub)(lua_State *L)
|
|
{
|
|
THTensor *tensor = luaT_checkudata(L, 1, torch_Tensor_id);
|
|
long d0s = -1, d0e = -1, d1s = -1, d1e = -1, d2s = -1, d2e = -1, d3s = -1, d3e = -1;
|
|
|
|
d0s = luaL_checklong(L, 2)-1;
|
|
d0e = luaL_checklong(L, 3)-1;
|
|
if(d0s < 0)
|
|
d0s += tensor->size[0]+1;
|
|
if(d0e < 0)
|
|
d0e += tensor->size[0]+1;
|
|
luaL_argcheck(L, tensor->nDimension > 0, 2, "invalid dimension");
|
|
luaL_argcheck(L, d0s >= 0 && d0s < tensor->size[0], 2, "out of range");
|
|
luaL_argcheck(L, d0e >= 0 && d0e < tensor->size[0], 3, "out of range");
|
|
luaL_argcheck(L, d0e >= d0s, 3, "end smaller than beginning");
|
|
|
|
if(!lua_isnone(L, 4))
|
|
{
|
|
d1s = luaL_checklong(L, 4)-1;
|
|
d1e = luaL_checklong(L, 5)-1;
|
|
if(d1s < 0)
|
|
d1s += tensor->size[1]+1;
|
|
if(d1e < 0)
|
|
d1e += tensor->size[1]+1;
|
|
luaL_argcheck(L, tensor->nDimension > 1, 4, "invalid dimension");
|
|
luaL_argcheck(L, d1s >= 0 && d1s < tensor->size[1], 4, "out of range");
|
|
luaL_argcheck(L, d1e >= 0 && d1e < tensor->size[1], 5, "out of range");
|
|
luaL_argcheck(L, d1e >= d1s, 5, "end smaller than beginning");
|
|
|
|
if(!lua_isnone(L, 6))
|
|
{
|
|
d2s = luaL_checklong(L, 6)-1;
|
|
d2e = luaL_checklong(L, 7)-1;
|
|
if(d2s < 0)
|
|
d2s += tensor->size[2]+1;
|
|
if(d2e < 0)
|
|
d2e += tensor->size[2]+1;
|
|
luaL_argcheck(L, tensor->nDimension > 2, 6, "invalid dimension");
|
|
luaL_argcheck(L, d2s >= 0 && d2s < tensor->size[2], 6, "out of range");
|
|
luaL_argcheck(L, d2e >= 0 && d2e < tensor->size[2], 7, "out of range");
|
|
luaL_argcheck(L, d2e >= d2s, 7, "end smaller than beginning");
|
|
|
|
if(!lua_isnone(L, 8))
|
|
{
|
|
d3s = luaL_checklong(L, 8)-1;
|
|
d3e = luaL_checklong(L, 9)-1;
|
|
if(d3s < 0)
|
|
d3s += tensor->size[3]+1;
|
|
if(d3e < 0)
|
|
d3e += tensor->size[3]+1;
|
|
luaL_argcheck(L, tensor->nDimension > 3, 8, "invalid dimension");
|
|
luaL_argcheck(L, d3s >= 0 && d3s < tensor->size[3], 8, "out of range");
|
|
luaL_argcheck(L, d3e >= 0 && d3e < tensor->size[3], 9, "out of range");
|
|
luaL_argcheck(L, d3e >= d3s, 9, "end smaller than beginning");
|
|
}
|
|
}
|
|
}
|
|
|
|
tensor = THTensor_(newWithTensor)(tensor);
|
|
THTensor_(narrow)(tensor, NULL, 0, d0s, d0e-d0s+1);
|
|
if(d1s >= 0)
|
|
THTensor_(narrow)(tensor, NULL, 1, d1s, d1e-d1s+1);
|
|
if(d2s >= 0)
|
|
THTensor_(narrow)(tensor, NULL, 2, d2s, d2e-d2s+1);
|
|
if(d3s >= 0)
|
|
THTensor_(narrow)(tensor, NULL, 3, d3s, d3e-d3s+1);
|
|
luaT_pushudata(L, tensor, torch_Tensor_id);
|
|
return 1;
|
|
}
|
|
|
|
/* Lua: tensor:select(dim, index) -- for tensors with >1 dimension, pushes a
 * view with dimension `dim` removed, fixed at `index` (both 1-based in Lua).
 * For a 1-d tensor the selected element is pushed as a plain number instead.
 * Errors on a 0-dimensional tensor. */
static int torch_Tensor_(select)(lua_State *L)
{
  THTensor *tensor = luaT_checkudata(L, 1, torch_Tensor_id);
  int dimension = luaL_checkint(L, 2)-1;
  long sliceIndex = luaL_checklong(L, 3)-1;

/* THArgCheck(src->nDimension > 1, 1, "cannot select on a vector");
   THArgCheck((dimension >= 0) && (dimension < src->nDimension), 2, "out of range");
   THArgCheck((sliceIndex >= 0) && (sliceIndex < src->size[dimension]), 3, "out of range");
*/

  if(tensor->nDimension > 1)
  {
    /* range checks are delegated to THTensor_(select) */
    tensor = THTensor_(newWithTensor)(tensor);
    THTensor_(select)(tensor, NULL, dimension, sliceIndex);
    luaT_pushudata(L, tensor, torch_Tensor_id);
  }
  else
  {
    /* only rejects the 0-dimensional case; `dimension` is ignored here */
    THArgCheck(tensor->nDimension == 1, 1, "empty Tensor");
    lua_pushnumber(L, THTensor_(get1d)(tensor, sliceIndex));
  }

  return 1;
}
|
|
|
|
|
|
/* Lua: tensor:transpose(d1, d2) -- view with dimensions d1 and d2 swapped
 * (1-based on the Lua side).  Range checks are left to THTensor_(transpose). */
static int torch_Tensor_(transpose)(lua_State *L)
{
  THTensor *self = luaT_checkudata(L, 1, torch_Tensor_id);
  int da = luaL_checkint(L, 2)-1;
  int db = luaL_checkint(L, 3)-1;
  THTensor *view;

  view = THTensor_(newWithTensor)(self);
  THTensor_(transpose)(view, NULL, da, db);
  luaT_pushudata(L, view, torch_Tensor_id);
  return 1;
}
|
|
|
|
/* Lua: tensor:t() -- matrix transpose view; only valid for 2-d tensors. */
static int torch_Tensor_(t)(lua_State *L)
{
  THTensor *self = luaT_checkudata(L, 1, torch_Tensor_id);
  THTensor *view;

  luaL_argcheck(L, self->nDimension == 2, 1, "Tensor must have 2 dimensions");

  view = THTensor_(newWithTensor)(self);
  THTensor_(transpose)(view, NULL, 0, 1);
  luaT_pushudata(L, view, torch_Tensor_id);
  return 1;
}
|
|
|
|
/* Lua: tensor:unfold(dim, size, step) -- view containing all slices of
 * length `size` taken every `step` elements along `dim` (1-based in Lua).
 * Validation is delegated to THTensor_(unfold). */
static int torch_Tensor_(unfold)(lua_State *L)
{
  THTensor *self = luaT_checkudata(L, 1, torch_Tensor_id);
  int dim = luaL_checkint(L, 2)-1;
  long windowSize = luaL_checklong(L, 3);
  long windowStep = luaL_checklong(L, 4);
  THTensor *view;

  view = THTensor_(newWithTensor)(self);
  THTensor_(unfold)(view, NULL, dim, windowSize, windowStep);
  luaT_pushudata(L, view, torch_Tensor_id);
  return 1;
}
|
|
|
|
/* is contiguous? [a bit like in TnXIterator] */
|
|
/* Lua: tensor:isContiguous() -- true iff elements are laid out densely
 * in row-major order in the storage. */
static int torch_Tensor_(isContiguous)(lua_State *L)
{
  THTensor *self = luaT_checkudata(L, 1, torch_Tensor_id);
  lua_pushboolean(L, THTensor_(isContiguous)(self));
  return 1;
}
|
|
|
|
/* Lua: tensor:nElement() -- total number of elements. */
static int torch_Tensor_(nElement)(lua_State *L)
{
  THTensor *self = luaT_checkudata(L, 1, torch_Tensor_id);
  lua_pushnumber(L, THTensor_(nElement)(self));
  return 1;
}
|
|
|
|
/* Lua: tensor:copy(src) -- elementwise copy from any torch tensor type,
 * converting element types as needed.  Returns the destination tensor. */
static int torch_Tensor_(copy)(lua_State *L)
{
  THTensor *self = luaT_checkudata(L, 1, torch_Tensor_id);
  void *other;

  /* dispatch on the source's concrete tensor type */
  if( (other = luaT_toudata(L, 2, torch_Tensor_id)) ) {
    THTensor_(copy)(self, other);
  } else if( (other = luaT_toudata(L, 2, torch_ByteTensor_id)) ) {
    THTensor_(copyByte)(self, other);
  } else if( (other = luaT_toudata(L, 2, torch_CharTensor_id)) ) {
    THTensor_(copyChar)(self, other);
  } else if( (other = luaT_toudata(L, 2, torch_ShortTensor_id)) ) {
    THTensor_(copyShort)(self, other);
  } else if( (other = luaT_toudata(L, 2, torch_IntTensor_id)) ) {
    THTensor_(copyInt)(self, other);
  } else if( (other = luaT_toudata(L, 2, torch_LongTensor_id)) ) {
    THTensor_(copyLong)(self, other);
  } else if( (other = luaT_toudata(L, 2, torch_FloatTensor_id)) ) {
    THTensor_(copyFloat)(self, other);
  } else if( (other = luaT_toudata(L, 2, torch_DoubleTensor_id)) ) {
    THTensor_(copyDouble)(self, other);
  } else {
    luaL_typerror(L, 2, "torch.*Tensor");
  }

  lua_settop(L, 1); /* return the destination */
  return 1;
}
|
|
|
|
/* __newindex__ metamethod: tensor[k] = v.
 *
 * Supported keys:
 *   number            -- if v is a number: set element of a 1-d tensor;
 *                        if v is a tensor: copy v into slice k along dim 0;
 *   torch.LongStorage -- full index vector, v must be a number;
 *   table of numbers  -- full index vector, v must be a number.
 * Pushes true when the assignment was handled, false otherwise (so the
 * Lua-side dispatcher can fall back to a raw set).
 *
 * Fixes vs. the previous version:
 *   - the temporary narrowed view used for tensor-valued assignment is now
 *     freed (it previously leaked one THTensor per assignment);
 *   - DoubleTensor sources are accepted, consistent with torch_Tensor_(copy).
 */
static int torch_Tensor_(__newindex__)(lua_State *L)
{
  THTensor *tensor = luaT_checkudata(L, 1, torch_Tensor_id);
  THLongStorage *idx = NULL;

  if(lua_isnumber(L, 2))
  {
    long index = luaL_checklong(L,2)-1; /* Lua is 1-based */
    void *src;

    if (lua_isnumber(L,3))
    {
      real value = (real)luaL_checknumber(L,3);
      luaL_argcheck(L, tensor->nDimension == 1, 1, "must be a one dimensional tensor");
      luaL_argcheck(L, index >= 0 && index < tensor->size[0], 2, "out of range");
      THStorage_(set)(tensor->storage, tensor->storageOffset+index*tensor->stride[0], value);
    }
    else
    {
      /* tensor[i] = srcTensor: copy src into the i-th slice along dim 0 */
      THTensor *dst = THTensor_(newWithTensor)(tensor);
      THTensor_(narrow)(dst, NULL, 0, index, 1);

      if( (src = luaT_toudata(L, 3, torch_Tensor_id)) )
        THTensor_(copy)(dst, src);
      else if( (src = luaT_toudata(L, 3, torch_ByteTensor_id)) )
        THTensor_(copyByte)(dst, src);
      else if( (src = luaT_toudata(L, 3, torch_CharTensor_id)) )
        THTensor_(copyChar)(dst, src);
      else if( (src = luaT_toudata(L, 3, torch_ShortTensor_id)) )
        THTensor_(copyShort)(dst, src);
      else if( (src = luaT_toudata(L, 3, torch_IntTensor_id)) )
        THTensor_(copyInt)(dst, src);
      else if( (src = luaT_toudata(L, 3, torch_LongTensor_id)) )
        THTensor_(copyLong)(dst, src);
      else if( (src = luaT_toudata(L, 3, torch_FloatTensor_id)) )
        THTensor_(copyFloat)(dst, src);
      else if( (src = luaT_toudata(L, 3, torch_DoubleTensor_id)) )
        THTensor_(copyDouble)(dst, src);
      else
      {
        THTensor_(free)(dst); /* release before luaL_typerror longjmps */
        luaL_typerror(L, 3, "torch.*Tensor");
      }

      THTensor_(free)(dst); /* view no longer needed (was leaked before) */
    }
    lua_pushboolean(L, 1);
  }
  else if((idx = luaT_toudata(L, 2, torch_LongStorage_id)))
  {
    long index = THTensor_(storageOffset)(tensor);
    real value = (real)luaL_checknumber(L,3);
    int dim;

    luaL_argcheck(L, idx->size == tensor->nDimension, 2, "invalid size");

    /* fold the per-dimension indices into a linear storage offset */
    for(dim = 0; dim < idx->size; dim++)
    {
      long z = idx->data[dim]-1;
      luaL_argcheck(L, (z >= 0) && (z < tensor->size[dim]), 2, "index out of bound");
      index += z*tensor->stride[dim];
    }

    THStorage_(set)(tensor->storage, index, value);
    lua_pushboolean(L, 1);
  }
  else if(lua_istable(L, 2))
  {
    long index = THTensor_(storageOffset)(tensor);
    real value = (real)luaL_checknumber(L,3);
    int dim;

    luaL_argcheck(L, lua_objlen(L,2) == tensor->nDimension, 2, "invalid size");

    /* same linearization as above, reading indices out of the table */
    for(dim = 0; dim < tensor->nDimension; dim++)
    {
      long z;

      lua_rawgeti(L, 2, dim+1);
      if(!lua_isnumber(L, -1))
        luaL_error(L, "number expected for each dimension");

      z = lua_tonumber(L, -1)-1;
      lua_pop(L, 1);

      luaL_argcheck(L, (z >= 0) && (z < tensor->size[dim]), 2, "index out of bound");
      index += z*tensor->stride[dim];
    }
    THStorage_(set)(tensor->storage, index, value);
    lua_pushboolean(L, 1);
  }
  else
    lua_pushboolean(L, 0); /* key type not handled here */

  return 1;
}
|
|
|
|
/* __index__ metamethod: tensor[k].
 *
 * Supported keys:
 *   number            -- element (1-d tensor) or dim-0 slice (>1-d tensor);
 *   torch.LongStorage -- full index vector, pushes the element as a number;
 *   table of numbers  -- full index vector, pushes the element as a number.
 * Protocol: on a handled key, pushes the result AND true, returning 2; on an
 * unhandled key, pushes false and returns 1 so the Lua dispatcher falls back
 * to a normal metatable lookup.
 */
static int torch_Tensor_(__index__)(lua_State *L)
{
  THTensor *tensor = luaT_checkudata(L, 1, torch_Tensor_id);
  THLongStorage *idx = NULL;

  if(lua_isnumber(L, 2))
  {
    long index = luaL_checklong(L,2)-1; /* Lua is 1-based */

    luaL_argcheck(L, tensor->nDimension > 0, 1, "empty tensor");
    luaL_argcheck(L, index >= 0 && index < tensor->size[0], 2, "out of range");

    if(tensor->nDimension == 1)
    {
      /* single element: read straight from the storage */
      lua_pushnumber(L, THStorage_(get)(tensor->storage, tensor->storageOffset+index*tensor->stride[0]));
    }
    else
    {
      /* slice along dimension 0, returned as a new view tensor */
      tensor = THTensor_(newWithTensor)(tensor);
      THTensor_(select)(tensor, NULL, 0, index);
      luaT_pushudata(L, tensor, torch_Tensor_id);
    }
    lua_pushboolean(L, 1);
    return 2;
  }
  else if((idx = luaT_toudata(L, 2, torch_LongStorage_id)))
  {
    long index = THTensor_(storageOffset)(tensor);
    int dim;

    luaL_argcheck(L, idx->size == tensor->nDimension, 2, "invalid size");

    /* fold the per-dimension indices into a linear storage offset */
    for(dim = 0; dim < idx->size; dim++)
    {
      long z = idx->data[dim]-1;
      luaL_argcheck(L, (z >= 0) && (z < tensor->size[dim]), 2, "index out of bound");
      index += z*tensor->stride[dim];
    }
    lua_pushnumber(L, (double)THStorage_(get)(THTensor_(storage)(tensor), index));
    lua_pushboolean(L, 1);
    return 2;
  }
  else if(lua_istable(L, 2))
  {
    long index = THTensor_(storageOffset)(tensor);
    int dim;

    luaL_argcheck(L, lua_objlen(L,2) == tensor->nDimension, 2, "invalid size");

    /* same linearization, reading indices out of the table */
    for(dim = 0; dim < tensor->nDimension; dim++)
    {
      long z;

      lua_rawgeti(L, 2, dim+1);
      if(!lua_isnumber(L, -1))
        luaL_error(L, "number expected for each dimension");

      z = lua_tonumber(L, -1)-1;
      lua_pop(L, 1);

      luaL_argcheck(L, (z >= 0) && (z < tensor->size[dim]), 2, "index out of bound");
      index += z*tensor->stride[dim];
    }
    lua_pushnumber(L, (double)THStorage_(get)(THTensor_(storage)(tensor), index));
    lua_pushboolean(L, 1);
    return 2;
  }
  else
  {
    /* not handled: signal fallback to the regular metatable lookup */
    lua_pushboolean(L, 0);
    return 1;
  }
}
|
|
|
|
/* __gc metamethod: drop Lua's reference to the tensor. */
static int torch_Tensor_(free)(lua_State *L)
{
  THTensor_(free)(luaT_checkudata(L, 1, torch_Tensor_id));
  return 0;
}
|
|
|
|
/* helpful functions */
|
|
/* Read a size (and optionally stride) specification from the Lua stack,
 * starting at argument `index`.
 *
 * Accepted forms:
 *   LongStorage [, LongStorage]          -- explicit size (and stride);
 *   n0 [, s0, n1, s1, ...]               -- interleaved numbers (allowStride);
 *   n0 [, n1, ...]                       -- plain numbers (!allowStride).
 * The numeric forms cap out at 8 dimensions; unset entries are left at -1,
 * which the TH resize/creation routines treat as "not specified".
 * On return *size_ and *stride_ are owned by the caller (must be freed);
 * *stride_ is NULL only in the LongStorage form without a stride argument. */
static void torch_Tensor_(c_readSizeStride)(lua_State *L, int index, int allowStride, THLongStorage **size_, THLongStorage **stride_)
{
  THLongStorage *size = NULL;
  THLongStorage *stride = NULL;

  if( (size = luaT_toudata(L, index, torch_LongStorage_id)) )
  {
    if(!lua_isnoneornil(L, index+1))
    {
      if( (stride = luaT_toudata(L, index+1, torch_LongStorage_id)) )
        luaL_argcheck(L, stride->size == size->size, index+1, "provided stride and size are inconsistent");
      else
        luaL_argcheck(L, 0, index+1, "torch.LongStorage expected");
    }
    /* caller receives ownership, so take references here */
    THLongStorage_retain(size);
    if(stride)
      THLongStorage_retain(stride);
  }
  else
  {
    int i;

    /* numeric form: at most 8 dimensions; -1 marks "unspecified" */
    size = THLongStorage_newWithSize(8);
    stride = THLongStorage_newWithSize(8);
    THLongStorage_fill(size, -1);
    THLongStorage_fill(stride, -1);

    if(allowStride)
    {
      /* arguments alternate size, stride, size, stride, ... */
      for(i = 0; i < 8; i++)
      {
        if(lua_isnone(L, index+2*i))
          break;
        size->data[i] = luaL_checklong(L, index+2*i);

        if(lua_isnone(L, index+2*i+1))
          break;
        stride->data[i] = luaL_checklong(L, index+2*i+1);
      }
    }
    else
    {
      /* arguments are sizes only */
      for(i = 0; i < 8; i++)
      {
        if(lua_isnone(L, index+i))
          break;
        size->data[i] = luaL_checklong(L, index+i);
      }
    }
  }

  *size_ = size;
  *stride_ = stride;
}
|
|
|
|
static void torch_Tensor_(c_readTensorStorageSizeStride)(lua_State *L, int index, int allowNone, int allowTensor, int allowStorage, int allowStride,
|
|
THStorage **storage_, long *storageOffset_, THLongStorage **size_, THLongStorage **stride_)
|
|
{
|
|
static char errMsg[64];
|
|
THTensor *src = NULL;
|
|
THStorage *storage = NULL;
|
|
|
|
int arg1Type = lua_type(L, index);
|
|
|
|
if( allowNone && (arg1Type == LUA_TNONE) )
|
|
{
|
|
*storage_ = NULL;
|
|
*storageOffset_ = 0;
|
|
*size_ = NULL;
|
|
*stride_ = NULL;
|
|
return;
|
|
}
|
|
else if( allowTensor && (arg1Type == LUA_TUSERDATA) && (src = luaT_toudata(L, index, torch_Tensor_id)) )
|
|
{
|
|
*storage_ = src->storage;
|
|
*storageOffset_ = src->storageOffset;
|
|
*size_ = THTensor_(newSizeOf)(src);
|
|
*stride_ = THTensor_(newStrideOf)(src);
|
|
return;
|
|
}
|
|
else if( allowStorage && (arg1Type == LUA_TUSERDATA) && (storage = luaT_toudata(L, index, torch_Storage_id)) )
|
|
{
|
|
*storage_ = storage;
|
|
if(lua_isnone(L, index+1))
|
|
{
|
|
*storageOffset_ = 0;
|
|
*size_ = THLongStorage_newWithSize1(storage->size);
|
|
*stride_ = THLongStorage_newWithSize1(1);
|
|
}
|
|
else
|
|
{
|
|
*storageOffset_ = luaL_checklong(L, index+1)-1;
|
|
torch_Tensor_(c_readSizeStride)(L, index+2, allowStride, size_, stride_);
|
|
}
|
|
return;
|
|
}
|
|
else if( (arg1Type == LUA_TNUMBER) || (luaT_toudata(L, index, torch_LongStorage_id)) )
|
|
{
|
|
*storage_ = NULL;
|
|
*storageOffset_ = 0;
|
|
torch_Tensor_(c_readSizeStride)(L, index, 0, size_, stride_);
|
|
|
|
return;
|
|
}
|
|
|
|
*storage_ = NULL;
|
|
*storageOffset_ = 0;
|
|
|
|
sprintf(errMsg, "expecting number%s%s", (allowTensor ? " or Tensor" : ""), (allowStorage ? " or Storage" : ""));
|
|
luaL_argcheck(L, 0, index, errMsg);
|
|
}
|
|
|
|
/* Lua: tensor:apply(f) -- call f(x) on every element, in place.
 * If f returns a number the element is overwritten with it; if f returns
 * nil the element is left unchanged; any other return raises an error.
 * The statements below are expanded inside the TH_TENSOR_APPLY iteration
 * macro, with `tensor_data` pointing at the current element.
 * Returns the tensor. */
static int torch_Tensor_(apply)(lua_State *L)
{
  THTensor *tensor = luaT_checkudata(L, 1, torch_Tensor_id);
  luaL_checktype(L, 2, LUA_TFUNCTION);
  lua_settop(L, 2); /* stack: tensor, f -- f's result will land at index 3 */

  TH_TENSOR_APPLY(real, tensor,
                  lua_pushvalue(L, 2);             /* the function */
                  lua_pushnumber(L, *tensor_data); /* current element */
                  lua_call(L, 1, 1);
                  if(lua_isnumber(L, 3))
                  {
                    *tensor_data = (real)lua_tonumber(L, 3);
                    lua_pop(L, 1);
                  }
                  else if(lua_isnil(L, 3))
                    lua_pop(L, 1);                  /* nil: keep element */
                  else
                    luaL_error(L, "given function should return a number or nil"););

  lua_settop(L, 1); /* return the tensor itself */
  return 1;
}
|
|
|
|
/* Lua: tensor:map(src, f) -- for each position, call f(tensor_elem, src_elem).
 * A numeric return overwrites the tensor element; nil keeps it; anything
 * else raises.  Statements are expanded inside TH_TENSOR_APPLY2, with
 * `tensor_data` / `src_data` pointing at the paired elements.
 * Returns the tensor. */
static int torch_Tensor_(map)(lua_State *L)
{
  THTensor *tensor = luaT_checkudata(L, 1, torch_Tensor_id);
  THTensor *src = luaT_checkudata(L, 2, torch_Tensor_id);
  luaL_checktype(L, 3, LUA_TFUNCTION);
  lua_settop(L, 3); /* stack: tensor, src, f -- f's result lands at index 4 */

  TH_TENSOR_APPLY2(real, tensor, real, src,
                  lua_pushvalue(L, 3);             /* the function */
                  lua_pushnumber(L, *tensor_data); /* destination element */
                  lua_pushnumber(L, *src_data);    /* source element */
                  lua_call(L, 2, 1);
                  if(lua_isnumber(L, 4))
                  {
                    *tensor_data = (real)lua_tonumber(L, 4);
                    lua_pop(L, 1);
                  }
                  else if(lua_isnil(L, 4))
                    lua_pop(L, 1);                  /* nil: keep element */
                  else
                    luaL_error(L, "given function should return a number or nil"););

  lua_settop(L, 1); /* return the tensor itself */
  return 1;
}
|
|
|
|
/* Lua: tensor:map2(src1, src2, f) -- for each position, call
 * f(tensor_elem, src1_elem, src2_elem).  A numeric return overwrites the
 * tensor element; nil keeps it; anything else raises.  Statements are
 * expanded inside TH_TENSOR_APPLY3 with the three *_data pointers aimed
 * at the aligned elements.  Returns the tensor. */
static int torch_Tensor_(map2)(lua_State *L)
{
  THTensor *tensor = luaT_checkudata(L, 1, torch_Tensor_id);
  THTensor *src1 = luaT_checkudata(L, 2, torch_Tensor_id);
  THTensor *src2 = luaT_checkudata(L, 3, torch_Tensor_id);
  luaL_checktype(L, 4, LUA_TFUNCTION);
  lua_settop(L, 4); /* stack: tensor, src1, src2, f -- result lands at 5 */

  TH_TENSOR_APPLY3(real, tensor, real, src1, real, src2,
                  lua_pushvalue(L, 4);             /* the function */
                  lua_pushnumber(L, *tensor_data); /* destination element */
                  lua_pushnumber(L, *src1_data);   /* first source */
                  lua_pushnumber(L, *src2_data);   /* second source */
                  lua_call(L, 3, 1);
                  if(lua_isnumber(L, 5))
                  {
                    *tensor_data = (real)lua_tonumber(L, 5);
                    lua_pop(L, 1);
                  }
                  else if(lua_isnil(L, 5))
                    lua_pop(L, 1);                  /* nil: keep element */
                  else
                    luaL_error(L, "given function should return a number or nothing"););

  lua_settop(L, 1); /* return the tensor itself */
  return 1;
}
|
|
|
|
/* luaT factory hook: create and push a fresh, empty tensor
 * (used e.g. during deserialization). */
static int torch_Tensor_(factory)(lua_State *L)
{
  luaT_pushudata(L, THTensor_(new)(), torch_Tensor_id);
  return 1;
}
|
|
|
|
/* Serialization hook: tensor:write(file).
 * Writes nDimension, the raw size and stride arrays, the 1-based storage
 * offset, then delegates the storage itself to file:writeObject (so shared
 * storages are handled by the object-serialization layer). */
static int torch_Tensor_(write)(lua_State *L)
{
  THTensor *tensor = luaT_checkudata(L, 1, torch_Tensor_id);
  THFile *file = luaT_checkudata(L, 2, torch_File_id);

  THFile_writeIntScalar(file, tensor->nDimension);
  THFile_writeLongRaw(file, tensor->size, tensor->nDimension);
  THFile_writeLongRaw(file, tensor->stride, tensor->nDimension);
  THFile_writeLongScalar(file, tensor->storageOffset+1); /* to respect Lua convention */

  /* prepare the file:writeObject(storage) call */
  lua_getfield(L, 2, "writeObject"); /* the method */
  lua_pushvalue(L, 2); /* the file */
  /* the storage */
  if(tensor->storage)
  {
    /* retain: writeObject's argument gets its own Lua-side reference */
    THStorage_(retain)(tensor->storage);
    luaT_pushudata(L, tensor->storage, torch_Storage_id);
  }
  else
    lua_pushnil(L);

  lua_call(L, 2, 0); /* call the method */

  return 0;
}
|
|
|
|
/* Deserialization hook: tensor:read(file), the inverse of write.
 * Reads nDimension, allocates and fills size/stride, restores the
 * 0-based storage offset, then fetches the storage via file:readObject.
 *
 * NOTE(review): size/stride are assigned without freeing any previous
 * arrays — this assumes `tensor` is a fresh factory-made instance with no
 * allocated geometry; calling read on a used tensor would leak. Confirm
 * against the serialization caller. */
static int torch_Tensor_(read)(lua_State *L)
{
  THTensor *tensor = luaT_checkudata(L, 1, torch_Tensor_id);
  THFile *file = luaT_checkudata(L, 2, torch_File_id);

  tensor->nDimension = THFile_readIntScalar(file);
  tensor->size = THAlloc(sizeof(long)*tensor->nDimension);
  tensor->stride = THAlloc(sizeof(long)*tensor->nDimension);
  THFile_readLongRaw(file, tensor->size, tensor->nDimension);
  THFile_readLongRaw(file, tensor->stride, tensor->nDimension);
  tensor->storageOffset = THFile_readLongScalar(file);
  tensor->storageOffset--; /* to respect Lua convention */

  /* file:readObject() returns the storage (or nil) on the stack */
  lua_getfield(L, 2, "readObject"); /* the method */
  lua_pushvalue(L, 2); /* the file */
  lua_call(L, 1, 1); /* call the method */

  tensor->storage = luaT_toudata(L, -1, torch_Storage_id);
  if(tensor->storage)
    THStorage_(retain)(tensor->storage); /* tensor keeps its own reference */

  return 0;
}
|
|
|
|
static const struct luaL_Reg torch_Tensor_(_) [] = {
|
|
{"contiguous", torch_Tensor_(contiguous)},
|
|
{"size", torch_Tensor_(size)},
|
|
{"__len__", torch_Tensor_(size)},
|
|
{"stride", torch_Tensor_(stride)},
|
|
{"dim", torch_Tensor_(nDimension)},
|
|
{"nDimension", torch_Tensor_(nDimension)},
|
|
{"set", torch_Tensor_(set)},
|
|
{"storage", torch_Tensor_(storage)},
|
|
{"storageOffset", torch_Tensor_(storageOffset)},
|
|
{"clone", torch_Tensor_(clone)},
|
|
{"contiguous", torch_Tensor_(contiguous)},
|
|
{"resizeAs", torch_Tensor_(resizeAs)},
|
|
{"resize", torch_Tensor_(resize)},
|
|
{"narrow", torch_Tensor_(narrow)},
|
|
{"sub", torch_Tensor_(sub)},
|
|
{"select", torch_Tensor_(select)},
|
|
{"transpose", torch_Tensor_(transpose)},
|
|
{"t", torch_Tensor_(t)},
|
|
{"unfold", torch_Tensor_(unfold)},
|
|
{"isContiguous", torch_Tensor_(isContiguous)},
|
|
{"nElement", torch_Tensor_(nElement)},
|
|
{"copy", torch_Tensor_(copy)},
|
|
{"apply", torch_Tensor_(apply)},
|
|
{"map", torch_Tensor_(map)},
|
|
{"map2", torch_Tensor_(map2)},
|
|
{"read", torch_Tensor_(read)},
|
|
{"write", torch_Tensor_(write)},
|
|
{"__index__", torch_Tensor_(__index__)},
|
|
{"__newindex__", torch_Tensor_(__newindex__)},
|
|
{NULL, NULL}
|
|
};
|
|
|
|
/* Module init for this tensor type: resolve the type ids we interoperate
 * with, create the metatable (wired to constructor, destructor and
 * factory), and register the method table on it. */
void torch_Tensor_(init)(lua_State *L)
{
  /* ids of collaborating types */
  torch_File_id = luaT_checktypename2id(L, "torch.File");
  torch_LongStorage_id = luaT_checktypename2id(L, "torch.LongStorage");
  torch_Storage_id = luaT_checktypename2id(L, STRING_torchStorage);

  /* metatable with lifecycle hooks, then the methods */
  torch_Tensor_id = luaT_newmetatable(L, STRING_torchTensor, NULL,
                                      torch_Tensor_(new), torch_Tensor_(free), torch_Tensor_(factory));
  luaL_register(L, NULL, torch_Tensor_(_));
  lua_pop(L, 1); /* pop the metatable pushed by luaT_newmetatable */
}
|
|
|
|
#endif
|