mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
136 lines
3.1 KiB
C
136 lines
3.1 KiB
C
#include "general.h"
|
|
#include "utils.h"
|
|
|
|
#include <sys/time.h>
|
|
|
|
static const void* torch_LongStorage_id = NULL;
|
|
static const void* torch_default_tensor_id = NULL;
|
|
|
|
THLongStorage* torch_checklongargs(lua_State *L, int index)
|
|
{
|
|
THLongStorage *storage;
|
|
int i;
|
|
int narg = lua_gettop(L)-index+1;
|
|
|
|
if(narg == 1 && luaT_toudata(L, index, torch_LongStorage_id))
|
|
{
|
|
THLongStorage *storagesrc = luaT_toudata(L, index, torch_LongStorage_id);
|
|
storage = THLongStorage_newWithSize(storagesrc->size);
|
|
THLongStorage_copy(storage, storagesrc);
|
|
}
|
|
else
|
|
{
|
|
storage = THLongStorage_newWithSize(narg);
|
|
for(i = index; i < index+narg; i++)
|
|
{
|
|
if(!lua_isnumber(L, i))
|
|
{
|
|
THLongStorage_free(storage);
|
|
luaL_argerror(L, i, "number expected");
|
|
}
|
|
THLongStorage_set(storage, i-index, lua_tonumber(L, i));
|
|
}
|
|
}
|
|
return storage;
|
|
}
|
|
|
|
int torch_islongargs(lua_State *L, int index)
|
|
{
|
|
int narg = lua_gettop(L)-index+1;
|
|
|
|
if(narg == 1 && luaT_toudata(L, index, torch_LongStorage_id))
|
|
{
|
|
return 1;
|
|
}
|
|
else
|
|
{
|
|
int i;
|
|
|
|
for(i = index; i < index+narg; i++)
|
|
{
|
|
if(!lua_isnumber(L, i))
|
|
return 0;
|
|
}
|
|
return 1;
|
|
}
|
|
return 0;
|
|
}
|
|
|
|
|
|
|
|
static int torch_lua_tic(lua_State* L)
|
|
{
|
|
struct timeval tv;
|
|
gettimeofday(&tv,NULL);
|
|
double ttime = (double)tv.tv_sec + (double)(tv.tv_usec)/1000000.0;
|
|
lua_pushnumber(L,ttime);
|
|
return 1;
|
|
}
|
|
|
|
static int torch_lua_toc(lua_State* L)
|
|
{
|
|
struct timeval tv;
|
|
gettimeofday(&tv,NULL);
|
|
double toctime = (double)tv.tv_sec + (double)(tv.tv_usec)/1000000.0;
|
|
lua_Number tictime = luaL_checknumber(L,1);
|
|
lua_pushnumber(L,toctime-tictime);
|
|
return 1;
|
|
}
|
|
|
|
static int torch_lua_setdefaulttensortype(lua_State *L)
|
|
{
|
|
const void *id;
|
|
|
|
luaL_checkstring(L, 1);
|
|
|
|
if(!(id = luaT_typename2id(L, lua_tostring(L, 1)))) \
|
|
return luaL_error(L, "<%s> is not a string describing a torch object", lua_tostring(L, 1)); \
|
|
|
|
torch_default_tensor_id = id;
|
|
|
|
return 0;
|
|
}
|
|
|
|
static int torch_lua_getdefaulttensortype(lua_State *L)
|
|
{
|
|
lua_pushstring(L, luaT_id2typename(L, torch_default_tensor_id));
|
|
return 1;
|
|
}
|
|
|
|
void torch_setdefaulttensorid(const void* id)
|
|
{
|
|
torch_default_tensor_id = id;
|
|
}
|
|
|
|
const void* torch_getdefaulttensorid()
|
|
{
|
|
return torch_default_tensor_id;
|
|
}
|
|
|
|
static const struct luaL_Reg torch_utils__ [] = {
|
|
{"__setdefaulttensortype", torch_lua_setdefaulttensortype},
|
|
{"getdefaulttensortype", torch_lua_getdefaulttensortype},
|
|
{"tic", torch_lua_tic},
|
|
{"toc", torch_lua_toc},
|
|
{"factory", luaT_lua_factory},
|
|
{"getconstructortable", luaT_lua_getconstructortable},
|
|
{"id", luaT_lua_id},
|
|
{"typename", luaT_lua_typename},
|
|
{"typename2id", luaT_lua_typename2id},
|
|
{"isequal", luaT_lua_isequal},
|
|
{"getenv", luaT_lua_getenv},
|
|
{"setenv", luaT_lua_setenv},
|
|
{"newmetatable", luaT_lua_newmetatable},
|
|
{"setmetatable", luaT_lua_setmetatable},
|
|
{"getmetatable", luaT_lua_getmetatable},
|
|
{"version", luaT_lua_version},
|
|
{"pointer", luaT_lua_pointer},
|
|
{NULL, NULL}
|
|
};
|
|
|
|
void torch_utils_init(lua_State *L)
|
|
{
|
|
torch_LongStorage_id = luaT_checktypename2id(L, "torch.LongStorage");
|
|
luaL_register(L, NULL, torch_utils__);
|
|
}
|