mirror of
				https://github.com/pytorch/pytorch.git
				synced 2025-10-20 21:14:14 +08:00 
			
		
		
		
	initial revamp of torch7 tree
This commit is contained in:
		
							
								
								
									
										18
									
								
								CMakeLists.txt
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										18
									
								
								CMakeLists.txt
									
									
									
									
									
										Normal file
									
								
							| @ -0,0 +1,18 @@ | ||||
| SET(src DiskFile.c File.c MemoryFile.c PipeFile.c Storage.c Tensor.c Timer.c utils.c init.c TensorOperator.c TensorMathWrap.c random.c) | ||||
| SET(luasrc init.lua File.lua Tensor.lua TensorMath.lua CmdLine.lua Tester.lua torch.lua test/test.lua) | ||||
|    | ||||
| # Necessary do generate wrapper | ||||
| ADD_TORCH_WRAP(tensormathwrap TensorMathWrap.lua) | ||||
| ADD_TORCH_WRAP(randomwrap random.lua) | ||||
|  | ||||
| ADD_TORCH_PACKAGE(torch "${src}" "${luasrc}" "Basics") | ||||
| ADD_TORCH_DOK(dok torch "Fundamentals" "Torch package" 1.1) | ||||
|  | ||||
| TARGET_LINK_LIBRARIES(torch luaT TH) | ||||
|  | ||||
| CONFIGURE_FILE(torch.in "${Torch_BINARY_DIR}/torch") | ||||
| INSTALL(FILES "${Torch_BINARY_DIR}/torch" | ||||
|         DESTINATION "${Torch_INSTALL_BIN_SUBDIR}" | ||||
|         PERMISSIONS OWNER_EXECUTE OWNER_WRITE OWNER_READ | ||||
|                     GROUP_EXECUTE GROUP_READ | ||||
|                     WORLD_EXECUTE WORLD_READ) | ||||
							
								
								
									
										244
									
								
								CmdLine.lua
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										244
									
								
								CmdLine.lua
									
									
									
									
									
										Normal file
									
								
							| @ -0,0 +1,244 @@ | ||||
| local CmdLine = torch.class('torch.CmdLine') | ||||
|  | ||||
| local function strip(str) | ||||
|    return string.match(str, '%-*(.*)') | ||||
| end | ||||
|  | ||||
| local function pad(str, sz) | ||||
|    return str .. string.rep(' ', sz-#str) | ||||
| end | ||||
|  | ||||
| function CmdLine:error(msg) | ||||
|    print('') | ||||
|    print(msg) | ||||
|    print('') | ||||
|    self:help() | ||||
|    os.exit(0) | ||||
| end | ||||
|  | ||||
| function CmdLine:__readArgument__(params, arg, i, nArgument) | ||||
|    local argument = self.arguments[nArgument] | ||||
|    local value = arg[i] | ||||
|  | ||||
|    if nArgument > #self.arguments then | ||||
|       self:error('invalid argument: ' .. value) | ||||
|    end | ||||
|    if argument.type and type(value) ~= argument.type then | ||||
|       self:error('invalid argument type for argument ' .. argument.key .. ' (should be ' .. argument.type .. ')') | ||||
|    end | ||||
|    params[strip(argument.key)] = value | ||||
|    return 1 | ||||
| end | ||||
|  | ||||
| function CmdLine:__readOption__(params, arg, i) | ||||
|    local key = arg[i] | ||||
|    local option = self.options[key] | ||||
|    if not option then | ||||
|       self:error('unknown option ' .. key) | ||||
|    end | ||||
|  | ||||
|    if option.type and option.type == 'boolean' then | ||||
|       params[strip(key)] = not option.default | ||||
|       return 1 | ||||
|    else | ||||
|       local value = arg[i+1] | ||||
|       if not value then | ||||
|          self:error('missing argument for option ' .. key) | ||||
|       end | ||||
|       if not option.type or option.type == 'string' then | ||||
|       elseif option.type == 'number' then | ||||
|          value = tonumber(value) | ||||
|       else | ||||
|          self:error('unknown required option type ' .. option.type) | ||||
|       end | ||||
|       if not value then | ||||
|          self:error('invalid type for option ' .. key .. ' (should be ' .. option.type .. ')') | ||||
|       end | ||||
|       params[strip(key)] = value | ||||
|       return 2 | ||||
|    end | ||||
| end | ||||
|  | ||||
| function CmdLine:__init(argseparator_,keyseparator_) | ||||
|    self.argseparator = argseparator_ or ',' | ||||
|    self.keyseparator = keyseparator_ or '=' | ||||
|    self.options = {} | ||||
|    self.arguments = {} | ||||
|    self.helplines = {} | ||||
| end | ||||
|  | ||||
| function CmdLine:argument(key, help, _type_) | ||||
|    table.insert(self.arguments, {key=key, help=help, type=_type_}) | ||||
|    table.insert(self.helplines, self.arguments[#self.arguments]) | ||||
| end | ||||
|  | ||||
| function CmdLine:option(key, default, help, _type_) | ||||
|    if default == nil then | ||||
|       error('option ' .. key .. ' has no default value') | ||||
|    end | ||||
|    _type_ = _type_ or type(default) | ||||
|    if type(default) ~= _type_ then | ||||
|       error('option ' .. key .. ' has wrong default type value') | ||||
|    end | ||||
|    self.options[key] = {key=key, default=default, help=help, type=_type_} | ||||
|    table.insert(self.helplines, self.options[key]) | ||||
| end | ||||
|  | ||||
| function CmdLine:default() | ||||
|    local params = {} | ||||
|    for option,v in pairs(self.options) do | ||||
|       params[strip(option)] = v.default | ||||
|    end | ||||
|    return params | ||||
| end | ||||
|  | ||||
| function CmdLine:parse(arg) | ||||
|    local i = 1 | ||||
|    local params = self:default() | ||||
|  | ||||
|    local nArgument = 0 | ||||
|  | ||||
|    while i <= #arg do | ||||
|       if arg[i] == '-help' or arg[i] == '-h' or arg[i] == '--help' then | ||||
|          self:help(arg) | ||||
|          os.exit(0) | ||||
|       end | ||||
|  | ||||
|       if self.options[arg[i]] then | ||||
|          i = i + self:__readOption__(params, arg, i) | ||||
|       else | ||||
|          nArgument = nArgument + 1 | ||||
|          i = i + self:__readArgument__(params, arg, i, nArgument) | ||||
|       end | ||||
|    end | ||||
|  | ||||
|    if nArgument ~= #self.arguments then | ||||
|       self:error('not enough arguments') | ||||
|    end | ||||
|  | ||||
|    return params | ||||
| end | ||||
|  | ||||
| function CmdLine:string(prefix, params, ignore) | ||||
|    local arguments = {} | ||||
|    local options = {} | ||||
|    prefix = prefix or '' | ||||
|  | ||||
|    for k,v in pairs(params) do | ||||
|       if ignore[k] then | ||||
|          print('-- ignore option ' .. k) | ||||
|       elseif self.options['-' .. k] then | ||||
|          if v ~= self.options['-' .. k].default then | ||||
|             if type(v) == 'boolean' then | ||||
|                if v then | ||||
|                   v = 't' | ||||
|                else | ||||
|                   v = 'f' | ||||
|                end | ||||
|             end | ||||
|             table.insert(options, k .. self.keyseparator .. v) | ||||
|             print(k,v,self.options['-' .. k].default) | ||||
|         end | ||||
|        else | ||||
|          local narg | ||||
|          for i=1,#self.arguments do | ||||
|             if strip(self.arguments[i].key) == k then | ||||
|                narg = i | ||||
|             end | ||||
|          end | ||||
|          if narg then | ||||
|             arguments[narg] = k .. self.keyseparator .. v | ||||
|          else | ||||
|             print('WARNING: unknown option/argument: ' .. k .. ' IGNORING for DIRECTORY NAME') | ||||
|          end | ||||
|       end | ||||
|    end | ||||
|    table.sort(options) | ||||
|    local str = table.concat(arguments, self.argseparator) | ||||
|    if str == '' then | ||||
|       str = table.concat(options, self.argseparator) | ||||
|    else | ||||
|       str = str .. self.argseparator .. table.concat(options, self.argseparator) | ||||
|    end | ||||
|    if str == '' then | ||||
|       return prefix | ||||
|    else | ||||
|       return prefix .. self.argseparator .. str | ||||
|    end | ||||
| end | ||||
|  | ||||
| function CmdLine:log(file, params)    | ||||
|    local f = io.open(file, 'w') | ||||
|    local oprint = print | ||||
|    function print(...) | ||||
|       local n = select("#", ...) | ||||
|       oprint(...) | ||||
|       for i=1,n do | ||||
|          f:write(tostring(select(i, ...))) | ||||
|          if i ~= n then | ||||
|             f:write(' ') | ||||
|          else | ||||
|             f:write('\n') | ||||
|          end | ||||
|       end | ||||
|       f:flush() | ||||
|    end | ||||
|    print('[program started on ' .. os.date() .. ']') | ||||
|    print('[command line arguments]') | ||||
|    if params then | ||||
|       for k,v in pairs(params) do | ||||
|          print(k,v) | ||||
|       end | ||||
|    end | ||||
|    print('[----------------------]')    | ||||
| end | ||||
|  | ||||
| function CmdLine:text(txt) | ||||
|    txt = txt or '' | ||||
|    assert(type(txt) == 'string') | ||||
|    table.insert(self.helplines, txt) | ||||
| end | ||||
|  | ||||
| function CmdLine:help(arg) | ||||
|    io.write('Usage: ') | ||||
|    if arg then io.write(arg[0] .. ' ') end | ||||
|    io.write('[options] ') | ||||
|    for i=1,#self.arguments do | ||||
|       io.write('<' .. strip(self.arguments[i].key) .. '>') | ||||
|    end | ||||
|    io.write('\n') | ||||
|  | ||||
|    -- first pass to compute max length | ||||
|    local optsz = 0 | ||||
|    for _,option in ipairs(self.helplines) do | ||||
|       if type(option) == 'table' then | ||||
|          if option.default ~= nil then -- it is an option | ||||
|             if #option.key > optsz then | ||||
|                optsz = #option.key | ||||
|             end | ||||
|          else -- it is an argument | ||||
|             if #strip(option.key)+2 > optsz then | ||||
|                optsz = #strip(option.key)+2 | ||||
|             end | ||||
|          end | ||||
|       end | ||||
|    end | ||||
|  | ||||
|    -- second pass to print | ||||
|    for _,option in ipairs(self.helplines) do | ||||
|       if type(option) == 'table' then | ||||
|          io.write('  ') | ||||
|          if option.default ~= nil then -- it is an option | ||||
|             io.write(pad(option.key, optsz)) | ||||
|             if option.help then io.write(' ' .. option.help) end | ||||
|             io.write(' [' .. tostring(option.default) .. ']') | ||||
|          else -- it is an argument | ||||
|             io.write(pad('<' .. strip(option.key) .. '>', optsz)) | ||||
|             if option.help then io.write(' ' .. option.help) end | ||||
|          end | ||||
|       else | ||||
|          io.write(option) -- just some additional help | ||||
|       end | ||||
|       io.write('\n') | ||||
|    end | ||||
| end | ||||
							
								
								
									
										87
									
								
								DiskFile.c
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										87
									
								
								DiskFile.c
									
									
									
									
									
										Normal file
									
								
							| @ -0,0 +1,87 @@ | ||||
| #include "general.h" | ||||
|  | ||||
| static const void* torch_DiskFile_id = NULL; | ||||
|  | ||||
| static int torch_DiskFile_new(lua_State *L) | ||||
| { | ||||
|   const char *name = luaL_checkstring(L, 1); | ||||
|   const char *mode = luaL_optstring(L, 2, "r"); | ||||
|   int isQuiet = luaT_optboolean(L, 3, 0); | ||||
|   THFile *self = THDiskFile_new(name, mode, isQuiet); | ||||
|  | ||||
|   luaT_pushudata(L, self, torch_DiskFile_id); | ||||
|   return 1; | ||||
| } | ||||
|  | ||||
| static int torch_DiskFile_free(lua_State *L) | ||||
| { | ||||
|   THFile *self = luaT_checkudata(L, 1, torch_DiskFile_id); | ||||
|   THFile_free(self); | ||||
|   return 0; | ||||
| } | ||||
|  | ||||
| static int torch_DiskFile_isLittleEndianCPU(lua_State *L) | ||||
| { | ||||
|   lua_pushboolean(L, THDiskFile_isLittleEndianCPU()); | ||||
|   return 1; | ||||
| } | ||||
|  | ||||
| static int torch_DiskFile_isBigEndianCPU(lua_State *L) | ||||
| { | ||||
|   lua_pushboolean(L, !THDiskFile_isLittleEndianCPU()); | ||||
|   return 1; | ||||
| } | ||||
|  | ||||
| static int torch_DiskFile_nativeEndianEncoding(lua_State *L) | ||||
| { | ||||
|   THFile *self = luaT_checkudata(L, 1, torch_DiskFile_id); | ||||
|   THDiskFile_nativeEndianEncoding(self); | ||||
|   lua_settop(L, 1); | ||||
|   return 1; | ||||
| } | ||||
|  | ||||
| static int torch_DiskFile_littleEndianEncoding(lua_State *L) | ||||
| { | ||||
|   THFile *self = luaT_checkudata(L, 1, torch_DiskFile_id); | ||||
|   THDiskFile_littleEndianEncoding(self); | ||||
|   lua_settop(L, 1); | ||||
|   return 1; | ||||
| } | ||||
|  | ||||
| static int torch_DiskFile_bigEndianEncoding(lua_State *L) | ||||
| { | ||||
|   THFile *self = luaT_checkudata(L, 1, torch_DiskFile_id); | ||||
|   THDiskFile_bigEndianEncoding(self); | ||||
|   lua_settop(L, 1); | ||||
|   return 1; | ||||
| } | ||||
|  | ||||
| static int torch_DiskFile___tostring__(lua_State *L) | ||||
| { | ||||
|   THFile *self = luaT_checkudata(L, 1, torch_DiskFile_id); | ||||
|   lua_pushfstring(L, "torch.DiskFile on <%s> [status: %s -- mode %c%c]",  | ||||
|                   THDiskFile_name(self), | ||||
|                   (THFile_isOpened(self) ? "open" : "closed"), | ||||
|                   (THFile_isReadable(self) ? 'r' : ' '), | ||||
|                   (THFile_isWritable(self) ? 'w' : ' ')); | ||||
|  | ||||
|   return 1; | ||||
| } | ||||
| static const struct luaL_Reg torch_DiskFile__ [] = { | ||||
|   {"isLittleEndianCPU", torch_DiskFile_isLittleEndianCPU}, | ||||
|   {"isBigEndianCPU", torch_DiskFile_isBigEndianCPU}, | ||||
|   {"nativeEndianEncoding", torch_DiskFile_nativeEndianEncoding}, | ||||
|   {"littleEndianEncoding", torch_DiskFile_littleEndianEncoding}, | ||||
|   {"bigEndianEncoding", torch_DiskFile_bigEndianEncoding}, | ||||
|   {"__tostring__", torch_DiskFile___tostring__}, | ||||
|   {NULL, NULL} | ||||
| }; | ||||
|  | ||||
| void torch_DiskFile_init(lua_State *L) | ||||
| { | ||||
|   torch_DiskFile_id = luaT_newmetatable(L, "torch.DiskFile", "torch.File", | ||||
|                                         torch_DiskFile_new, torch_DiskFile_free, NULL); | ||||
|    | ||||
|   luaL_register(L, NULL, torch_DiskFile__); | ||||
|   lua_pop(L, 1); | ||||
| } | ||||
							
								
								
									
										225
									
								
								File.c
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										225
									
								
								File.c
									
									
									
									
									
										Normal file
									
								
							| @ -0,0 +1,225 @@ | ||||
| #include "THFile.h" | ||||
| #include "luaT.h" | ||||
|  | ||||
| static const void *torch_File_id = NULL; | ||||
| static const void *torch_ByteStorage_id = NULL; | ||||
| static const void *torch_CharStorage_id = NULL; | ||||
| static const void *torch_ShortStorage_id = NULL; | ||||
| static const void *torch_IntStorage_id = NULL; | ||||
| static const void *torch_LongStorage_id = NULL; | ||||
| static const void *torch_FloatStorage_id = NULL; | ||||
| static const void *torch_DoubleStorage_id = NULL; | ||||
|  | ||||
| #define IMPLEMENT_TORCH_FILE_FLAG(NAME)                   \ | ||||
|   static int torch_File_##NAME(lua_State *L)              \ | ||||
|   {                                                       \ | ||||
|     THFile *self = luaT_checkudata(L, 1, torch_File_id);  \ | ||||
|     lua_pushboolean(L, THFile_##NAME(self));              \ | ||||
|     return 1;                                             \ | ||||
|   } | ||||
|  | ||||
| IMPLEMENT_TORCH_FILE_FLAG(isQuiet) | ||||
| IMPLEMENT_TORCH_FILE_FLAG(isReadable) | ||||
| IMPLEMENT_TORCH_FILE_FLAG(isWritable) | ||||
| IMPLEMENT_TORCH_FILE_FLAG(isBinary) | ||||
| IMPLEMENT_TORCH_FILE_FLAG(isAutoSpacing) | ||||
| IMPLEMENT_TORCH_FILE_FLAG(hasError) | ||||
|  | ||||
| #define IMPLEMENT_TORCH_FILE_FUNC(NAME)                   \ | ||||
|   static int torch_File_##NAME(lua_State *L)              \ | ||||
|   {                                                       \ | ||||
|     THFile *self = luaT_checkudata(L, 1, torch_File_id);  \ | ||||
|     THFile_##NAME(self);                                  \ | ||||
|     lua_settop(L, 1);                                     \ | ||||
|     return 1;                                             \ | ||||
|   } | ||||
|  | ||||
| IMPLEMENT_TORCH_FILE_FUNC(binary) | ||||
| IMPLEMENT_TORCH_FILE_FUNC(ascii) | ||||
| IMPLEMENT_TORCH_FILE_FUNC(autoSpacing) | ||||
| IMPLEMENT_TORCH_FILE_FUNC(noAutoSpacing) | ||||
| IMPLEMENT_TORCH_FILE_FUNC(quiet) | ||||
| IMPLEMENT_TORCH_FILE_FUNC(pedantic) | ||||
| IMPLEMENT_TORCH_FILE_FUNC(clearError) | ||||
|  | ||||
| IMPLEMENT_TORCH_FILE_FUNC(synchronize) | ||||
|  | ||||
| static int torch_File_seek(lua_State *L) | ||||
| { | ||||
|   THFile *self = luaT_checkudata(L, 1, torch_File_id); | ||||
|   long position = luaL_checklong(L, 2)-1; | ||||
|   THFile_seek(self, position); | ||||
|   lua_settop(L, 1); | ||||
|   return 1; | ||||
| } | ||||
|  | ||||
| IMPLEMENT_TORCH_FILE_FUNC(seekEnd) | ||||
|  | ||||
| static int torch_File_position(lua_State *L) | ||||
| { | ||||
|   THFile *self = luaT_checkudata(L, 1, torch_File_id); | ||||
|   lua_pushnumber(L, THFile_position(self)+1); | ||||
|   return 1; | ||||
| } | ||||
|  | ||||
| IMPLEMENT_TORCH_FILE_FUNC(close) | ||||
|  | ||||
| #define IMPLEMENT_TORCH_FILE_RW(TYPEC, TYPE)                            \ | ||||
|   static int torch_File_read##TYPEC(lua_State *L)                       \ | ||||
|   {                                                                     \ | ||||
|     THFile *self = luaT_checkudata(L, 1, torch_File_id);                \ | ||||
|     int narg = lua_gettop(L);                                           \ | ||||
|                                                                         \ | ||||
|     if(narg == 1)                                                       \ | ||||
|     {                                                                   \ | ||||
|       lua_pushnumber(L, THFile_read##TYPEC##Scalar(self));              \ | ||||
|       return 1;                                                         \ | ||||
|     }                                                                   \ | ||||
|     else if(narg == 2)                                                  \ | ||||
|     {                                                                   \ | ||||
|       if(lua_isnumber(L, 2))                                            \ | ||||
|       {                                                                 \ | ||||
|         long size = lua_tonumber(L, 2);                                 \ | ||||
|         long nread;                                                     \ | ||||
|                                                                         \ | ||||
|         TH##TYPEC##Storage *storage = TH##TYPEC##Storage_newWithSize(size); \ | ||||
|         luaT_pushudata(L, storage, torch_##TYPEC##Storage_id);          \ | ||||
|         nread = THFile_read##TYPEC(self, storage);                      \ | ||||
|         if(nread != size)                                               \ | ||||
|           TH##TYPEC##Storage_resize(storage, size);                     \ | ||||
|         return 1;                                                       \ | ||||
|       }                                                                 \ | ||||
|       else if(luaT_toudata(L, 2, torch_##TYPEC##Storage_id))            \ | ||||
|       {                                                                 \ | ||||
|         TH##TYPEC##Storage *storage = luaT_toudata(L, 2, torch_##TYPEC##Storage_id); \ | ||||
|         lua_pushnumber(L, THFile_read##TYPEC(self, storage));           \ | ||||
|         return 1;                                                       \ | ||||
|       }                                                                 \ | ||||
|     }                                                                   \ | ||||
|                                                                         \ | ||||
|     luaL_error(L, "nothing, number, or Storage expected");              \ | ||||
|     return 0;                                                           \ | ||||
|   }                                                                     \ | ||||
|                                                                         \ | ||||
|   static int torch_File_write##TYPEC(lua_State *L)                      \ | ||||
|   {                                                                     \ | ||||
|     THFile *self = luaT_checkudata(L, 1, torch_File_id);                \ | ||||
|     int narg = lua_gettop(L);                                           \ | ||||
|                                                                         \ | ||||
|     if(narg == 2)                                                       \ | ||||
|     {                                                                   \ | ||||
|       if(lua_isnumber(L, 2))                                            \ | ||||
|       {                                                                 \ | ||||
|         TYPE value = lua_tonumber(L, 2);                                \ | ||||
|         THFile_write##TYPEC##Scalar(self, (TYPE)value);                 \ | ||||
|         return 0;                                                       \ | ||||
|       }                                                                 \ | ||||
|       else if(luaT_toudata(L, 2, torch_##TYPEC##Storage_id))            \ | ||||
|       {                                                                 \ | ||||
|         TH##TYPEC##Storage *storage = luaT_toudata(L, 2, torch_##TYPEC##Storage_id); \ | ||||
|         lua_pushnumber(L, THFile_write##TYPEC(self, storage));          \ | ||||
|         return 1;                                                       \ | ||||
|       }                                                                 \ | ||||
|     }                                                                   \ | ||||
|                                                                         \ | ||||
|     luaL_error(L, "number, or Storage expected");                       \ | ||||
|     return 0;                                                           \ | ||||
|   } | ||||
|  | ||||
|  | ||||
| IMPLEMENT_TORCH_FILE_RW(Byte, unsigned char) | ||||
| IMPLEMENT_TORCH_FILE_RW(Char, char) | ||||
| IMPLEMENT_TORCH_FILE_RW(Short, short) | ||||
| IMPLEMENT_TORCH_FILE_RW(Int, int) | ||||
| IMPLEMENT_TORCH_FILE_RW(Long, long) | ||||
| IMPLEMENT_TORCH_FILE_RW(Float, float) | ||||
| IMPLEMENT_TORCH_FILE_RW(Double, double) | ||||
|  | ||||
| static int torch_File_readString(lua_State *L) | ||||
| { | ||||
|   THFile *self = luaT_checkudata(L, 1, torch_File_id); | ||||
|   const char *format = luaL_checkstring(L, 2); | ||||
|   char *str; | ||||
|   long size; | ||||
|  | ||||
|   size = THFile_readStringRaw(self, format, &str); | ||||
|   lua_pushlstring(L, str, size); | ||||
|   THFree(str); | ||||
|  | ||||
|   return 1; | ||||
| } | ||||
|  | ||||
| static int torch_File_writeString(lua_State *L) | ||||
| { | ||||
|   THFile *self = luaT_checkudata(L, 1, torch_File_id); | ||||
|   const char *str = NULL; | ||||
|   size_t size; | ||||
|   long nwrite; | ||||
|  | ||||
|   luaL_checktype(L, 2, LUA_TSTRING); | ||||
|   str = lua_tolstring(L, 2, &size); | ||||
|   lua_pushnumber(L, THFile_writeStringRaw(self, str, (long)size)); | ||||
|   return 1; | ||||
| } | ||||
|  | ||||
| static const struct luaL_Reg torch_File__ [] = { | ||||
|   {"isQuiet", torch_File_isQuiet}, | ||||
|   {"isReadable", torch_File_isReadable}, | ||||
|   {"isWritable", torch_File_isWritable}, | ||||
|   {"isBinary", torch_File_isBinary}, | ||||
|   {"isAutoSpacing", torch_File_isAutoSpacing}, | ||||
|   {"hasError", torch_File_hasError}, | ||||
|   {"binary", torch_File_binary}, | ||||
|   {"ascii", torch_File_ascii}, | ||||
|   {"autoSpacing", torch_File_autoSpacing}, | ||||
|   {"noAutoSpacing", torch_File_noAutoSpacing}, | ||||
|   {"quiet", torch_File_quiet}, | ||||
|   {"pedantic", torch_File_pedantic}, | ||||
|   {"clearError", torch_File_clearError}, | ||||
|  | ||||
|   /* DEBUG: CHECK DISK FREE & READ/WRITE STRING*/ | ||||
|  | ||||
|   {"readByte", torch_File_readByte}, | ||||
|   {"readChar", torch_File_readChar}, | ||||
|   {"readShort", torch_File_readShort}, | ||||
|   {"readInt", torch_File_readInt}, | ||||
|   {"readLong", torch_File_readLong}, | ||||
|   {"readFloat", torch_File_readFloat}, | ||||
|   {"readDouble", torch_File_readDouble}, | ||||
|   {"readString", torch_File_readString}, | ||||
|  | ||||
|   {"writeByte", torch_File_writeByte}, | ||||
|   {"writeChar", torch_File_writeChar}, | ||||
|   {"writeShort", torch_File_writeShort}, | ||||
|   {"writeInt", torch_File_writeInt}, | ||||
|   {"writeLong", torch_File_writeLong}, | ||||
|   {"writeFloat", torch_File_writeFloat}, | ||||
|   {"writeDouble", torch_File_writeDouble}, | ||||
|   {"writeString", torch_File_writeString}, | ||||
|    | ||||
|   {"synchronize", torch_File_synchronize}, | ||||
|   {"seek", torch_File_seek}, | ||||
|   {"seekEnd", torch_File_seekEnd}, | ||||
|   {"position", torch_File_position}, | ||||
|   {"close", torch_File_close}, | ||||
|  | ||||
|   {NULL, NULL} | ||||
| }; | ||||
|  | ||||
| void torch_File_init(lua_State *L) | ||||
| { | ||||
|   torch_File_id = luaT_newmetatable(L, "torch.File", NULL, NULL, NULL, NULL); | ||||
|   luaL_register(L, NULL, torch_File__); | ||||
|   lua_pop(L, 1); | ||||
| } | ||||
|  | ||||
| void torch_File_init_storage_id(lua_State *L) | ||||
| { | ||||
|   torch_ByteStorage_id = luaT_checktypename2id(L, "torch.ByteStorage"); | ||||
|   torch_CharStorage_id = luaT_checktypename2id(L, "torch.CharStorage"); | ||||
|   torch_ShortStorage_id = luaT_checktypename2id(L, "torch.ShortStorage"); | ||||
|   torch_IntStorage_id = luaT_checktypename2id(L, "torch.IntStorage"); | ||||
|   torch_LongStorage_id = luaT_checktypename2id(L, "torch.LongStorage"); | ||||
|   torch_FloatStorage_id = luaT_checktypename2id(L, "torch.FloatStorage"); | ||||
|   torch_DoubleStorage_id = luaT_checktypename2id(L, "torch.DoubleStorage"); | ||||
| } | ||||
							
								
								
									
										240
									
								
								File.lua
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										240
									
								
								File.lua
									
									
									
									
									
										Normal file
									
								
							| @ -0,0 +1,240 @@ | ||||
| local File = torch.getmetatable('torch.File') | ||||
|  | ||||
| function File:writeBool(value) | ||||
|    if value then | ||||
|       self:writeInt(1) | ||||
|    else | ||||
|       self:writeInt(0) | ||||
|    end | ||||
| end | ||||
|  | ||||
| function File:readBool() | ||||
|    return (self:readInt() == 1) | ||||
| end | ||||
|  | ||||
| local TYPE_NIL      = 0 | ||||
| local TYPE_NUMBER   = 1 | ||||
| local TYPE_STRING   = 2 | ||||
| local TYPE_TABLE    = 3 | ||||
| local TYPE_TORCH    = 4 | ||||
| local TYPE_BOOLEAN  = 5 | ||||
| local TYPE_FUNCTION = 6 | ||||
|  | ||||
| function File:isWritableObject(object) | ||||
|    local typename = type(object) | ||||
|    local typeidx | ||||
|    if type(object) ~= 'boolean' and not object then | ||||
|       typeidx = TYPE_NIL | ||||
|    elseif torch.typename(object) and torch.factory(torch.typename(object)) then | ||||
|       typeidx = TYPE_TORCH | ||||
|    elseif typename == 'table' then | ||||
|       typeidx = TYPE_TABLE | ||||
|    elseif typename == 'number' then | ||||
|       typeidx = TYPE_NUMBER | ||||
|    elseif typename == 'string' then | ||||
|       typeidx = TYPE_STRING | ||||
|    elseif typename == 'boolean' then | ||||
|       typeidx = TYPE_BOOLEAN | ||||
|    elseif typename == 'function' and pcall(string.dump, object) then | ||||
|       typeidx = TYPE_FUNCTION | ||||
|    end | ||||
|    return typeidx | ||||
| end | ||||
|  | ||||
| function File:writeObject(object) | ||||
|    -- we use an environment to keep a record of written objects | ||||
|    if not torch.getenv(self).writeObjects then | ||||
|       torch.setenv(self, {writeObjects={}, writeObjectsRef={}, readObjects={}}) | ||||
|    end | ||||
|  | ||||
|    -- if nil object, only write the type and return | ||||
|    if type(object) ~= 'boolean' and not object then | ||||
|       self:writeInt(TYPE_NIL) | ||||
|       return | ||||
|    end | ||||
|  | ||||
|    -- check the type we are dealing with | ||||
|    local typeidx = self:isWritableObject(object) | ||||
|    if not typeidx then | ||||
|       error(string.format('Unwritable object <%s>', type(object))) | ||||
|    end | ||||
|    self:writeInt(typeidx) | ||||
|  | ||||
|    if typeidx == TYPE_NUMBER then | ||||
|       self:writeDouble(object) | ||||
|    elseif typeidx == TYPE_BOOLEAN then | ||||
|       self:writeBool(object) | ||||
|    elseif typeidx == TYPE_STRING then | ||||
|       local stringStorage = torch.CharStorage():string(object) | ||||
|       self:writeInt(#stringStorage) | ||||
|       self:writeChar(stringStorage) | ||||
|    elseif typeidx == TYPE_FUNCTION then | ||||
|       local upvalues = {} | ||||
|       while true do | ||||
|          local name,value = debug.getupvalue(object, #upvalues+1) | ||||
|          if not name then break end | ||||
|          table.insert(upvalues, value) | ||||
|       end | ||||
|       local dumped = string.dump(object) | ||||
|       local stringStorage = torch.CharStorage():string(dumped) | ||||
|       self:writeInt(#stringStorage) | ||||
|       self:writeChar(stringStorage) | ||||
|       self:writeObject(upvalues) | ||||
|    elseif typeidx == TYPE_TORCH or typeidx == TYPE_TABLE then | ||||
|       -- check it exists already (we look at the pointer!) | ||||
|       local objects = torch.getenv(self).writeObjects | ||||
|       local objectsRef = torch.getenv(self).writeObjectsRef | ||||
|       local index = objects[torch.pointer(object)] | ||||
|  | ||||
|       if index then | ||||
|          -- if already exists, write only its index | ||||
|          self:writeInt(index) | ||||
|       else | ||||
|          -- else write the object itself | ||||
|          index = objects.nWriteObject or 0 | ||||
|          index = index + 1 | ||||
|          objects[torch.pointer(object)] = index | ||||
|          objectsRef[object] = index -- we make sure the object is not going to disappear | ||||
|          self:writeInt(index) | ||||
|          objects.nWriteObject = index | ||||
|  | ||||
|          if typeidx == TYPE_TORCH then | ||||
|             local version   = torch.CharStorage():string('V ' .. torch.version(object)) | ||||
|             local className = torch.CharStorage():string(torch.typename(object)) | ||||
|             self:writeInt(#version) | ||||
|             self:writeChar(version) | ||||
|             self:writeInt(#className) | ||||
|             self:writeChar(className) | ||||
|             if object.write then | ||||
|                object:write(self) | ||||
|             elseif type(object) == 'table' then | ||||
|                local var = {} | ||||
|                for k,v in pairs(object) do | ||||
|                   if self:isWritableObject(v) then | ||||
|                      var[k] = v | ||||
|                   else | ||||
|                      print(string.format('$ Warning: cannot write object field <%s>', k)) | ||||
|                   end | ||||
|                end | ||||
|                self:writeObject(var) | ||||
|             else | ||||
|                error(string.format('<%s> is a non-serializable Torch object', torch.typename(object))) | ||||
|             end | ||||
|          else -- it is a table | ||||
|             local size = 0; for k,v in pairs(object) do size = size + 1 end | ||||
|             self:writeInt(size) | ||||
|             for k,v in pairs(object) do | ||||
|                self:writeObject(k) | ||||
|                self:writeObject(v) | ||||
|             end | ||||
|          end | ||||
|       end | ||||
|    else | ||||
|       error('Unwritable object') | ||||
|    end | ||||
| end | ||||
|  | ||||
| function File:readObject() | ||||
|    -- we use an environment to keep a record of read objects | ||||
|    if not torch.getenv(self).writeObjects then | ||||
|       torch.setenv(self, {writeObjects={}, writeObjectsRef={}, readObjects={}}) | ||||
|    end | ||||
|  | ||||
|    -- read the typeidx | ||||
|    local typeidx = self:readInt() | ||||
|  | ||||
|    -- is it nil? | ||||
|    if typeidx == TYPE_NIL then | ||||
|       return nil | ||||
|    end | ||||
|  | ||||
|    if typeidx == TYPE_NUMBER then | ||||
|       return self:readDouble() | ||||
|    elseif typeidx == TYPE_BOOLEAN then | ||||
|       return self:readBool() | ||||
|    elseif typeidx == TYPE_STRING then | ||||
|       local size = self:readInt() | ||||
|       return self:readChar(size):string() | ||||
|    elseif typeidx == TYPE_FUNCTION then | ||||
|       local size = self:readInt() | ||||
|       local dumped = self:readChar(size):string() | ||||
|       local func = loadstring(dumped) | ||||
|       local upvalues = self:readObject() | ||||
|       for index,upvalue in ipairs(upvalues) do | ||||
|          debug.setupvalue(func, index, upvalue) | ||||
|       end | ||||
|       return func | ||||
|    elseif typeidx == TYPE_TABLE or typeidx == TYPE_TORCH then | ||||
|       -- read the index | ||||
|       local index = self:readInt() | ||||
|  | ||||
|       -- check it is loaded already | ||||
|       local objects = torch.getenv(self).readObjects | ||||
|       if objects[index] then | ||||
|          return objects[index] | ||||
|       end | ||||
|  | ||||
|       -- otherwise read it | ||||
|       if typeidx == TYPE_TORCH then | ||||
|          local version, className, versionNumber | ||||
|          version = self:readChar(self:readInt()):string() | ||||
|          versionNumber = tonumber(string.match(version, '^V (.*)$')) | ||||
|          if not versionNumber then | ||||
|             className = version | ||||
|             versionNumber = 0 -- file created before existence of versioning system | ||||
|          else | ||||
|             className = self:readChar(self:readInt()):string() | ||||
|          end | ||||
|          if not torch.factory(className) then | ||||
|             error(string.format('unknown Torch class <%s>' .. className)) | ||||
|          end | ||||
|          local object = torch.factory(className)() | ||||
|          objects[index] = object | ||||
|          if object.read then | ||||
|             object:read(self, versionNumber) | ||||
|          elseif type(object) == 'table' then | ||||
|             local var = self:readObject(var) | ||||
|             for k,v in pairs(var) do | ||||
|                object[k] = v | ||||
|             end | ||||
|          else | ||||
|             error(string.format('Cannot load object class <%s>', className)) | ||||
|          end | ||||
|          return object | ||||
|       else -- it is a table | ||||
|          local size = self:readInt() | ||||
|          local object = {} | ||||
|          objects[index] = object | ||||
|          for i = 1,size do | ||||
|             local k = self:readObject() | ||||
|             local v = self:readObject() | ||||
|             object[k] = v | ||||
|          end | ||||
|          return object | ||||
|       end | ||||
|    else | ||||
|       error('unknown object') | ||||
|    end | ||||
| end | ||||
|  | ||||
| -- simple helpers to save/load arbitrary objects/tables | ||||
| function torch.save(filename, object, mode) | ||||
|    mode = mode or 'binary' | ||||
|    local file = torch.DiskFile(filename, 'w') | ||||
|    file[mode](file) | ||||
|    file:writeObject(object) | ||||
|    file:close() | ||||
| end | ||||
|  | ||||
| function torch.load(filename, mode) | ||||
|    mode = mode or 'binary' | ||||
|    local file = torch.DiskFile(filename, 'r') | ||||
|    file[mode](file) | ||||
|    local object = file:readObject() | ||||
|    file:close() | ||||
|    return object | ||||
| end | ||||
|  | ||||
| -- public API (saveobj/loadobj are safe for global import) | ||||
| torch.saveobj = torch.save | ||||
| torch.loadobj = torch.load | ||||
							
								
								
									
										67
									
								
								MemoryFile.c
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										67
									
								
								MemoryFile.c
									
									
									
									
									
										Normal file
									
								
							| @ -0,0 +1,67 @@ | ||||
| #include "general.h" | ||||
|  | ||||
| static const void* torch_MemoryFile_id; | ||||
| static const void* torch_CharStorage_id; | ||||
|  | ||||
| static int torch_MemoryFile_new(lua_State *L) | ||||
| { | ||||
|   const char *mode; | ||||
|   THCharStorage *storage = luaT_toudata(L, 1, torch_CharStorage_id); | ||||
|   THFile *self; | ||||
|  | ||||
|   if(storage) | ||||
|   { | ||||
|     mode = luaL_optstring(L, 2, "rw"); | ||||
|     self = THMemoryFile_newWithStorage(storage, mode); | ||||
|   } | ||||
|   else | ||||
|   { | ||||
|     mode = luaL_optstring(L, 1, "rw");     | ||||
|     self = THMemoryFile_new(mode); | ||||
|   } | ||||
|  | ||||
|   luaT_pushudata(L, self, torch_MemoryFile_id); | ||||
|   return 1; | ||||
| } | ||||
|  | ||||
| static int torch_MemoryFile_storage(lua_State *L) | ||||
| { | ||||
|   THFile *self = luaT_checkudata(L, 1, torch_MemoryFile_id); | ||||
|   THCharStorage_retain(THMemoryFile_storage(self)); | ||||
|   luaT_pushudata(L, THMemoryFile_storage(self), torch_CharStorage_id); | ||||
|   return 1; | ||||
| } | ||||
|  | ||||
| static int torch_MemoryFile_free(lua_State *L) | ||||
| { | ||||
|   THFile *self = luaT_checkudata(L, 1, torch_MemoryFile_id); | ||||
|   THFile_free(self); | ||||
|   return 0; | ||||
| } | ||||
|  | ||||
| static int torch_MemoryFile___tostring__(lua_State *L) | ||||
| { | ||||
|   THFile *self = luaT_checkudata(L, 1, torch_MemoryFile_id); | ||||
|   lua_pushfstring(L, "torch.MemoryFile [status: %s -- mode: %c%c]", | ||||
|                   (THFile_isOpened(self) ? "open" : "closed"), | ||||
|                   (THFile_isReadable(self) ? 'r' : ' '), | ||||
|                   (THFile_isWritable(self) ? 'w' : ' ')); | ||||
|   return 1; | ||||
| } | ||||
|  | ||||
| static const struct luaL_Reg torch_MemoryFile__ [] = { | ||||
|   {"storage", torch_MemoryFile_storage}, | ||||
|   {"__tostring__", torch_MemoryFile___tostring__}, | ||||
|   {NULL, NULL} | ||||
| }; | ||||
|  | ||||
| void torch_MemoryFile_init(lua_State *L) | ||||
| { | ||||
|   torch_CharStorage_id = luaT_checktypename2id(L, "torch.CharStorage"); | ||||
|  | ||||
|   torch_MemoryFile_id = luaT_newmetatable(L, "torch.MemoryFile", "torch.File", | ||||
|                                           torch_MemoryFile_new, torch_MemoryFile_free, NULL); | ||||
|  | ||||
|   luaL_register(L, NULL, torch_MemoryFile__); | ||||
|   lua_pop(L, 1); | ||||
| } | ||||
							
								
								
									
										46
									
								
								PipeFile.c
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										46
									
								
								PipeFile.c
									
									
									
									
									
										Normal file
									
								
							| @ -0,0 +1,46 @@ | ||||
| #include "general.h" | ||||
|  | ||||
| static const void* torch_PipeFile_id = NULL; | ||||
|  | ||||
| static int torch_PipeFile_new(lua_State *L) | ||||
| { | ||||
|   const char *name = luaL_checkstring(L, 1); | ||||
|   const char *mode = luaL_optstring(L, 2, "r"); | ||||
|   int isQuiet = luaT_optboolean(L, 3, 0); | ||||
|   THFile *self = THPipeFile_new(name, mode, isQuiet); | ||||
|  | ||||
|   luaT_pushudata(L, self, torch_PipeFile_id); | ||||
|   return 1; | ||||
| } | ||||
|  | ||||
| static int torch_PipeFile_free(lua_State *L) | ||||
| { | ||||
|   THFile *self = luaT_checkudata(L, 1, torch_PipeFile_id); | ||||
|   THFile_free(self); | ||||
|   return 0; | ||||
| } | ||||
|  | ||||
| static int torch_PipeFile___tostring__(lua_State *L) | ||||
| { | ||||
|   THFile *self = luaT_checkudata(L, 1, torch_PipeFile_id); | ||||
|   lua_pushfstring(L, "torch.PipeFile on <%s> [status: %s -- mode: %c%c]", | ||||
|                   THDiskFile_name(self), | ||||
|                   (THFile_isOpened(self) ? "open" : "closed"), | ||||
|                   (THFile_isReadable(self) ? 'r' : ' '), | ||||
|                   (THFile_isWritable(self) ? 'w' : ' ')); | ||||
|   return 1; | ||||
| } | ||||
|  | ||||
| static const struct luaL_Reg torch_PipeFile__ [] = { | ||||
|   {"__tostring__", torch_PipeFile___tostring__}, | ||||
|   {NULL, NULL} | ||||
| }; | ||||
|  | ||||
| void torch_PipeFile_init(lua_State *L) | ||||
| { | ||||
|   torch_PipeFile_id = luaT_newmetatable(L, "torch.PipeFile", "torch.DiskFile", | ||||
|                                         torch_PipeFile_new, torch_PipeFile_free, NULL); | ||||
|  | ||||
|   luaL_register(L, NULL, torch_PipeFile__); | ||||
|   lua_pop(L, 1); | ||||
| } | ||||
							
								
								
									
										19
									
								
								Storage.c
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										19
									
								
								Storage.c
									
									
									
									
									
										Normal file
									
								
							| @ -0,0 +1,19 @@ | ||||
| #include "general.h" | ||||
|  | ||||
| static const void *torch_File_id = NULL; | ||||
| static const void *torch_ByteStorage_id = NULL; | ||||
| static const void *torch_CharStorage_id = NULL; | ||||
| static const void *torch_ShortStorage_id = NULL; | ||||
| static const void *torch_IntStorage_id = NULL; | ||||
| static const void *torch_LongStorage_id = NULL; | ||||
| static const void *torch_FloatStorage_id = NULL; | ||||
| static const void *torch_DoubleStorage_id = NULL; | ||||
|  | ||||
| #define torch_Storage_(NAME) TH_CONCAT_4(torch_,Real,Storage_,NAME) | ||||
| #define torch_Storage_id TH_CONCAT_3(torch_,Real,Storage_id) | ||||
| #define THFile_readRealRaw TH_CONCAT_3(THFile_read, Real, Raw) | ||||
| #define THFile_writeRealRaw TH_CONCAT_3(THFile_write, Real, Raw) | ||||
| #define STRING_torchStorage TH_CONCAT_STRING_3(torch.,Real,Storage) | ||||
|  | ||||
| #include "generic/Storage.c" | ||||
| #include "THGenerateAllTypes.h" | ||||
							
								
								
									
										29
									
								
								Tensor.c
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										29
									
								
								Tensor.c
									
									
									
									
									
										Normal file
									
								
							| @ -0,0 +1,29 @@ | ||||
| #include "general.h" | ||||
|  | ||||
| static const void *torch_File_id = NULL; | ||||
|  | ||||
| static const void *torch_ByteStorage_id = NULL; | ||||
| static const void *torch_CharStorage_id = NULL; | ||||
| static const void *torch_ShortStorage_id = NULL; | ||||
| static const void *torch_IntStorage_id = NULL; | ||||
| static const void *torch_LongStorage_id = NULL; | ||||
| static const void *torch_FloatStorage_id = NULL; | ||||
| static const void *torch_DoubleStorage_id = NULL; | ||||
|  | ||||
| static const void *torch_ByteTensor_id = NULL; | ||||
| static const void *torch_CharTensor_id = NULL; | ||||
| static const void *torch_ShortTensor_id = NULL; | ||||
| static const void *torch_IntTensor_id = NULL; | ||||
| static const void *torch_LongTensor_id = NULL; | ||||
| static const void *torch_FloatTensor_id = NULL; | ||||
| static const void *torch_DoubleTensor_id = NULL; | ||||
|  | ||||
| #define torch_Storage_(NAME) TH_CONCAT_4(torch_,Real,Storage_,NAME) | ||||
| #define torch_Storage_id TH_CONCAT_3(torch_,Real,Storage_id) | ||||
| #define STRING_torchStorage TH_CONCAT_STRING_3(torch.,Real,Storage) | ||||
| #define torch_Tensor_(NAME) TH_CONCAT_4(torch_,Real,Tensor_,NAME) | ||||
| #define torch_Tensor_id TH_CONCAT_3(torch_,Real,Tensor_id) | ||||
| #define STRING_torchTensor TH_CONCAT_STRING_3(torch.,Real,Tensor) | ||||
|  | ||||
| #include "generic/Tensor.c" | ||||
| #include "THGenerateAllTypes.h" | ||||
							
								
								
									
										279
									
								
								Tensor.lua
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										279
									
								
								Tensor.lua
									
									
									
									
									
										Normal file
									
								
							| @ -0,0 +1,279 @@ | ||||
| -- additional methods for Storage | ||||
| local Storage = {} | ||||
|  | ||||
| -- additional methods for Tensor | ||||
| local Tensor = {} | ||||
|  | ||||
| -- types | ||||
| local types = {'Byte', 'Char', 'Short', 'Int', 'Long', 'Float', 'Double'} | ||||
|  | ||||
| -- tostring() functions for Tensor and Storage | ||||
| local function Storage__printformat(self) | ||||
|    local intMode = true | ||||
|    local type = torch.typename(self) | ||||
| --   if type == 'torch.FloatStorage' or type == 'torch.DoubleStorage' then | ||||
|       for i=1,self:size() do | ||||
|          if self[i] ~= math.ceil(self[i]) then | ||||
|             intMode = false | ||||
|             break | ||||
|          end | ||||
|       end | ||||
| --   end | ||||
|    local tensor = torch.DoubleTensor(torch.DoubleStorage(self:size()):copy(self), 1, self:size()):abs() | ||||
|    local expMin = tensor:minall() | ||||
|    if expMin ~= 0 then | ||||
|       expMin = math.floor(math.log10(expMin)) + 1 | ||||
|    end | ||||
|    local expMax = tensor:maxall() | ||||
|    if expMax ~= 0 then | ||||
|       expMax = math.floor(math.log10(expMax)) + 1 | ||||
|    end | ||||
|  | ||||
|    local format | ||||
|    local scale | ||||
|    local sz | ||||
|    if intMode then | ||||
|       if expMax > 9 then | ||||
|          format = "%11.4e" | ||||
|          sz = 11 | ||||
|       else | ||||
|          format = "%SZd" | ||||
|          sz = expMax + 1 | ||||
|       end | ||||
|    else | ||||
|       if expMax-expMin > 4 then | ||||
|          format = "%SZ.4e" | ||||
|          sz = 11 | ||||
|          if math.abs(expMax) > 99 or math.abs(expMin) > 99 then | ||||
|             sz = sz + 1 | ||||
|          end | ||||
|       else | ||||
|          if expMax > 5 or expMax < 0 then | ||||
|             format = "%SZ.4f" | ||||
|             sz = 7 | ||||
|             scale = math.pow(10, expMax-1) | ||||
|          else | ||||
|             format = "%SZ.4f" | ||||
|             if expMax == 0 then | ||||
|                sz = 7 | ||||
|             else | ||||
|                sz = expMax+6 | ||||
|             end | ||||
|          end | ||||
|       end | ||||
|    end | ||||
|    format = string.gsub(format, 'SZ', sz) | ||||
|    if scale == 1 then | ||||
|       scale = nil | ||||
|    end | ||||
|    return format, scale, sz | ||||
| end | ||||
|  | ||||
| function Storage.__tostring__(self) | ||||
|    local strt = {'\n'} | ||||
|    local format,scale = Storage__printformat(self) | ||||
|    if format:sub(2,4) == 'nan' then format = '%f' end | ||||
|    if scale then | ||||
|       table.insert(strt, string.format('%g', scale) .. ' *\n') | ||||
|       for i = 1,self:size() do | ||||
|          table.insert(strt, string.format(format, self[i]/scale) .. '\n') | ||||
|       end | ||||
|    else | ||||
|       for i = 1,self:size() do | ||||
|          table.insert(strt, string.format(format, self[i]) .. '\n') | ||||
|       end | ||||
|    end | ||||
|    table.insert(strt, '[' .. torch.typename(self) .. ' of size ' .. self:size() .. ']\n') | ||||
|    str = table.concat(strt) | ||||
|    return str | ||||
| end | ||||
|  | ||||
| for _,type in ipairs(types) do | ||||
|    local metatable = torch.getmetatable('torch.' .. type .. 'Storage') | ||||
|    for funcname, func in pairs(Storage) do | ||||
|       rawset(metatable, funcname, func) | ||||
|    end | ||||
| end | ||||
|  | ||||
| local function Tensor__printMatrix(self, indent) | ||||
|    local format,scale,sz = Storage__printformat(self:storage()) | ||||
|    if format:sub(2,4) == 'nan' then format = '%f' end | ||||
| --   print('format = ' .. format) | ||||
|    scale = scale or 1 | ||||
|    indent = indent or '' | ||||
|    local strt = {indent} | ||||
|    local nColumnPerLine = math.floor((80-#indent)/(sz+1)) | ||||
| --   print('sz = ' .. sz .. ' and nColumnPerLine = ' .. nColumnPerLine) | ||||
|    local firstColumn = 1 | ||||
|    local lastColumn = -1 | ||||
|    while firstColumn <= self:size(2) do | ||||
|       if firstColumn + nColumnPerLine - 1 <= self:size(2) then | ||||
|          lastColumn = firstColumn + nColumnPerLine - 1 | ||||
|       else | ||||
|          lastColumn = self:size(2) | ||||
|       end | ||||
|       if nColumnPerLine < self:size(2) then | ||||
|          if firstColumn ~= 1 then | ||||
|             table.insert(strt, '\n') | ||||
|          end | ||||
|          table.insert(strt, 'Columns ' .. firstColumn .. ' to ' .. lastColumn .. '\n' .. indent) | ||||
|       end | ||||
|       if scale ~= 1 then | ||||
|          table.insert(strt, string.format('%g', scale) .. ' *\n ' .. indent) | ||||
|       end | ||||
|       for l=1,self:size(1) do | ||||
|          local row = self:select(1, l) | ||||
|          for c=firstColumn,lastColumn do | ||||
|             table.insert(strt, string.format(format, row[c]/scale)) | ||||
|             if c == lastColumn then | ||||
|                table.insert(strt, '\n') | ||||
|                if l~=self:size(1) then | ||||
|                   if scale ~= 1 then | ||||
|                      table.insert(strt, indent .. ' ') | ||||
|                   else | ||||
|                      table.insert(strt, indent) | ||||
|                   end | ||||
|                end | ||||
|             else | ||||
|                table.insert(strt, ' ') | ||||
|             end | ||||
|          end | ||||
|       end | ||||
|       firstColumn = lastColumn + 1 | ||||
|    end | ||||
|    local str = table.concat(strt) | ||||
|    return str | ||||
| end | ||||
|  | ||||
| local function Tensor__printTensor(self) | ||||
|    local counter = torch.LongStorage(self:nDimension()-2) | ||||
|    local strt = {''} | ||||
|    local finished | ||||
|    counter:fill(1) | ||||
|    counter[1] = 0 | ||||
|    while true do | ||||
|       for i=1,self:nDimension()-2 do | ||||
|          counter[i] = counter[i] + 1 | ||||
|          if counter[i] > self:size(i) then | ||||
|             if i == self:nDimension()-2 then | ||||
|                finished = true | ||||
|                break | ||||
|             end | ||||
|             counter[i] = 1 | ||||
|          else | ||||
|             break | ||||
|          end | ||||
|       end | ||||
|       if finished then | ||||
|          break | ||||
|       end | ||||
| --      print(counter) | ||||
|       if #strt > 1 then | ||||
|          table.insert(strt, '\n') | ||||
|       end | ||||
|       table.insert(strt, '(') | ||||
|       local tensor = self | ||||
|       for i=1,self:nDimension()-2 do | ||||
|          tensor = tensor:select(1, counter[i]) | ||||
|          table.insert(strt, counter[i] .. ',') | ||||
|       end | ||||
|       table.insert(strt, '.,.) = \n') | ||||
|       table.insert(strt, Tensor__printMatrix(tensor, ' ')) | ||||
|    end | ||||
|    local str = table.concat(strt) | ||||
|    return str | ||||
| end | ||||
|  | ||||
| function Tensor.__tostring__(self) | ||||
|    local str = '\n' | ||||
|    local strt = {''} | ||||
|    if self:nDimension() == 0 then | ||||
|       table.insert(strt, '[' .. torch.typename(self) .. ' with no dimension]\n') | ||||
|    else | ||||
|       local tensor = torch.DoubleTensor():resize(self:size()):copy(self) | ||||
|       if tensor:nDimension() == 1 then | ||||
|          local format,scale,sz = Storage__printformat(tensor:storage()) | ||||
| 	 if format:sub(2,4) == 'nan' then format = '%f' end | ||||
|          if scale then | ||||
|             table.insert(strt, string.format('%g', scale) .. ' *\n') | ||||
|             for i = 1,tensor:size(1) do | ||||
|                table.insert(strt, string.format(format, tensor[i]/scale) .. '\n') | ||||
|             end | ||||
|          else | ||||
|             for i = 1,tensor:size(1) do | ||||
|                table.insert(strt, string.format(format, tensor[i]) .. '\n') | ||||
|             end | ||||
|          end | ||||
|          table.insert(strt, '[' .. torch.typename(self) .. ' of dimension ' .. tensor:size(1) .. ']\n') | ||||
|       elseif tensor:nDimension() == 2 then | ||||
|          table.insert(strt, Tensor__printMatrix(tensor)) | ||||
|          table.insert(strt, '[' .. torch.typename(self) .. ' of dimension ' .. tensor:size(1) .. 'x' .. tensor:size(2) .. ']\n') | ||||
|       else | ||||
|          table.insert(strt, Tensor__printTensor(tensor)) | ||||
|          table.insert(strt, '[' .. torch.typename(self) .. ' of dimension ') | ||||
|          for i=1,tensor:nDimension() do | ||||
|             table.insert(strt, tensor:size(i)) | ||||
|             if i ~= tensor:nDimension() then | ||||
|                table.insert(strt, 'x') | ||||
|             end | ||||
|          end | ||||
|          table.insert(strt, ']\n') | ||||
|       end | ||||
|    end | ||||
|    local str = table.concat(strt) | ||||
|    return str | ||||
| end | ||||
|  | ||||
| function Tensor.type(self,type) | ||||
|    local current = torch.typename(self) | ||||
|    if not type then return current end | ||||
|    if type ~= current then | ||||
|       local new = torch.getmetatable(type).new() | ||||
|       if self:nElement() > 0 then | ||||
|          new:resize(self:size()):copy(self) | ||||
|       end | ||||
|       return new | ||||
|    else | ||||
|       return self | ||||
|    end | ||||
| end | ||||
|  | ||||
| function Tensor.typeAs(self,tensor) | ||||
|    return self:type(tensor:type()) | ||||
| end | ||||
|  | ||||
| function Tensor.byte(self,type) | ||||
|    return self:type('torch.ByteTensor') | ||||
| end | ||||
|  | ||||
| function Tensor.char(self,type) | ||||
|    return self:type('torch.CharTensor') | ||||
| end | ||||
|  | ||||
| function Tensor.short(self,type) | ||||
|    return self:type('torch.ShortTensor') | ||||
| end | ||||
|  | ||||
| function Tensor.int(self,type) | ||||
|    return self:type('torch.IntTensor') | ||||
| end | ||||
|  | ||||
| function Tensor.long(self,type) | ||||
|    return self:type('torch.LongTensor') | ||||
| end | ||||
|  | ||||
| function Tensor.float(self,type) | ||||
|    return self:type('torch.FloatTensor') | ||||
| end | ||||
|  | ||||
| function Tensor.double(self,type) | ||||
|    return self:type('torch.DoubleTensor') | ||||
| end | ||||
|  | ||||
|  | ||||
| for _,type in ipairs(types) do | ||||
|    local metatable = torch.getmetatable('torch.' .. type .. 'Tensor') | ||||
|    for funcname, func in pairs(Tensor) do | ||||
|       rawset(metatable, funcname, func) | ||||
|    end | ||||
| end | ||||
							
								
								
									
										127
									
								
								TensorConvWrap.lua
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										127
									
								
								TensorConvWrap.lua
									
									
									
									
									
										Normal file
									
								
							| @ -0,0 +1,127 @@ | ||||
| --  | ||||
| -- require 'wrap' | ||||
| --- | ||||
|  | ||||
| interface = wrap.CInterface.new() | ||||
|  | ||||
|  | ||||
| interface.dispatchregistry = {} | ||||
| function interface:wrap(name, ...) | ||||
|    -- usual stuff | ||||
|    --wrap.CInterface.wrap(self, name, ...) | ||||
|  | ||||
|    -- dispatch function | ||||
|    if not interface.dispatchregistry[name] then | ||||
|       interface.dispatchregistry[name] = true | ||||
|       table.insert(interface.dispatchregistry, {name=name, wrapname=string.format("torch_%s", name)}) | ||||
|  | ||||
|       interface:print(string.gsub([[ | ||||
| static int torch_NAME(lua_State *L) | ||||
| { | ||||
|   int narg = lua_gettop(L); | ||||
|   const void *id; | ||||
|  | ||||
|   if(narg < 1 || !(id = torch_istensorid(L, luaT_id(L, 1)))) /* first argument is tensor? */ | ||||
|   { | ||||
|     if(narg < 2 || !(id = torch_istensorid(L, luaT_id(L, 2)))) /* second? */ | ||||
|     { | ||||
|       if(lua_isstring(L, -1) && (id = torch_istensorid(L, luaT_typename2id(L, lua_tostring(L, -1))))) /* do we have a valid string then? */ | ||||
|         lua_pop(L, 1); | ||||
|       else if(!(id = torch_istensorid(L, torch_getdefaulttensorid()))) | ||||
|         luaL_error(L, "internal error: the default tensor type does not seem to be an actual tensor"); | ||||
|     } | ||||
|   } | ||||
|    | ||||
|   lua_pushstring(L, "NAME"); | ||||
|   lua_rawget(L, -2); | ||||
|   if(lua_isfunction(L, -1)) | ||||
|   { | ||||
|     lua_insert(L, 1); | ||||
|     lua_pop(L, 2); /* the two tables we put on the stack above */ | ||||
|     lua_call(L, lua_gettop(L)-1, LUA_MULTRET); | ||||
|   } | ||||
|   else | ||||
|     return luaL_error(L, "%s does not implement the torch.NAME() function", luaT_id2typename(L, id)); | ||||
|  | ||||
|   return lua_gettop(L); | ||||
| } | ||||
| ]], 'NAME', name)) | ||||
|   end | ||||
| end | ||||
|  | ||||
| function interface:dispatchregister(name) | ||||
|    local txt = self.txt | ||||
|    table.insert(txt, string.format('static const struct luaL_Reg %s [] = {', name)) | ||||
|    for _,reg in ipairs(self.dispatchregistry) do | ||||
|       table.insert(txt, string.format('{"%s", %s},', reg.name, reg.wrapname)) | ||||
|    end | ||||
|    table.insert(txt, '{NULL, NULL}') | ||||
|    table.insert(txt, '};') | ||||
|    table.insert(txt, '')    | ||||
|    self.dispatchregistry = {} | ||||
| end | ||||
|  | ||||
| interface:print('/* WARNING: autogenerated file */') | ||||
| interface:print('') | ||||
|  | ||||
| local reals = {ByteTensor='byte', | ||||
|                CharTensor='char', | ||||
|                ShortTensor='short', | ||||
|                IntTensor='int', | ||||
|                LongTensor='long', | ||||
|                FloatTensor='float', | ||||
|                DoubleTensor='double'} | ||||
|  | ||||
| for _,Tensor in ipairs({"FloatTensor", "DoubleTensor", "IntTensor", "LongTensor", "ByteTensor", "CharTensor","ShortTensor"}) do | ||||
|  | ||||
|    local real = reals[Tensor] | ||||
|  | ||||
|    function interface.luaname2wrapname(self, name) | ||||
|       return string.format('torch_%s_%s', Tensor, name) | ||||
|    end | ||||
|  | ||||
|    local function cname(name) | ||||
|       return string.format('TH%s_%s', Tensor, name) | ||||
|    end | ||||
|  | ||||
|    local function lastdim(argn) | ||||
|       return function(arg) | ||||
|                 return string.format("TH%s_nDimension(%s)", Tensor, arg.args[argn]:carg()) | ||||
|              end | ||||
|    end | ||||
|     | ||||
|  | ||||
|    for _,name in ipairs({"conv2","xcorr2","conv3","xcorr3"}) do | ||||
|       interface:wrap(name, | ||||
| 		     cname(name), | ||||
| 		     {{name=Tensor, default=true, returned=true}, | ||||
| 		      {name=Tensor, default=true, returned=true}, | ||||
| 		      {name=Tensor}, | ||||
| 		      {name=Tensor}} | ||||
| 		  ) | ||||
|    end | ||||
|     | ||||
|  | ||||
|    --interface:register(string.format("torch_%sLapack__", Tensor)) | ||||
|  | ||||
| --  interface:print(string.gsub([[ | ||||
| -- static void torch_TensorLapack_init(lua_State *L) | ||||
| -- { | ||||
| --   torch_Tensor_id = luaT_checktypename2id(L, "torch.Tensor"); | ||||
| --   torch_LongStorage_id = luaT_checktypename2id(L, "torch.LongStorage"); | ||||
|  | ||||
| --   luaT_pushmetaclass(L, torch_Tensor_id); | ||||
| --   lua_getfield(L,-1,"torch"); | ||||
| --   luaL_register(L, NULL, torch_TensorLapack__); | ||||
| --   lua_pop(L, 2); | ||||
| -- } | ||||
| -- ]], 'Tensor', Tensor)) | ||||
| end | ||||
|  | ||||
| interface:dispatchregister("torch_TensorConv__") | ||||
|  | ||||
| if arg[1] then | ||||
|    interface:tofile(arg[1]) | ||||
| else | ||||
|    interface:tostdio() | ||||
| end | ||||
							
								
								
									
										132
									
								
								TensorLapackWrap.lua
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										132
									
								
								TensorLapackWrap.lua
									
									
									
									
									
										Normal file
									
								
							| @ -0,0 +1,132 @@ | ||||
| --  | ||||
| -- require 'wrap' | ||||
| --- | ||||
|  | ||||
| interface = wrap.CInterface.new() | ||||
|  | ||||
|  | ||||
| interface.dispatchregistry = {} | ||||
| function interface:wrap(name, ...) | ||||
|    -- usual stuff | ||||
|    wrap.CInterface.wrap(self, name, ...) | ||||
|  | ||||
|    -- dispatch function | ||||
|    if not interface.dispatchregistry[name] then | ||||
|       interface.dispatchregistry[name] = true | ||||
|       table.insert(interface.dispatchregistry, {name=name, wrapname=string.format("torch_%s", name)}) | ||||
|  | ||||
|       interface:print(string.gsub([[ | ||||
| static int torch_NAME(lua_State *L) | ||||
| { | ||||
|   int narg = lua_gettop(L); | ||||
|   const void *id; | ||||
|  | ||||
|   if(narg < 1 || !(id = torch_istensorid(L, luaT_id(L, 1)))) /* first argument is tensor? */ | ||||
|   { | ||||
|     if(narg < 2 || !(id = torch_istensorid(L, luaT_id(L, 2)))) /* second? */ | ||||
|     { | ||||
|       if(lua_isstring(L, -1) && (id = torch_istensorid(L, luaT_typename2id(L, lua_tostring(L, -1))))) /* do we have a valid string then? */ | ||||
|         lua_pop(L, 1); | ||||
|       else if(!(id = torch_istensorid(L, torch_getdefaulttensorid()))) | ||||
|         luaL_error(L, "internal error: the default tensor type does not seem to be an actual tensor"); | ||||
|     } | ||||
|   } | ||||
|    | ||||
|   lua_pushstring(L, "NAME"); | ||||
|   lua_rawget(L, -2); | ||||
|   if(lua_isfunction(L, -1)) | ||||
|   { | ||||
|     lua_insert(L, 1); | ||||
|     lua_pop(L, 2); /* the two tables we put on the stack above */ | ||||
|     lua_call(L, lua_gettop(L)-1, LUA_MULTRET); | ||||
|   } | ||||
|   else | ||||
|     return luaL_error(L, "%s does not implement the torch.NAME() function", luaT_id2typename(L, id)); | ||||
|  | ||||
|   return lua_gettop(L); | ||||
| } | ||||
| ]], 'NAME', name)) | ||||
|   end | ||||
| end | ||||
|  | ||||
| function interface:dispatchregister(name) | ||||
|    local txt = self.txt | ||||
|    table.insert(txt, string.format('static const struct luaL_Reg %s [] = {', name)) | ||||
|    for _,reg in ipairs(self.dispatchregistry) do | ||||
|       table.insert(txt, string.format('{"%s", %s},', reg.name, reg.wrapname)) | ||||
|    end | ||||
|    table.insert(txt, '{NULL, NULL}') | ||||
|    table.insert(txt, '};') | ||||
|    table.insert(txt, '')    | ||||
|    self.dispatchregistry = {} | ||||
| end | ||||
|  | ||||
| interface:print('/* WARNING: autogenerated file */') | ||||
| interface:print('') | ||||
|  | ||||
| local reals = {ByteTensor='byte', | ||||
|                CharTensor='char', | ||||
|                ShortTensor='short', | ||||
|                IntTensor='int', | ||||
|                LongTensor='long', | ||||
|                FloatTensor='float', | ||||
|                DoubleTensor='double'} | ||||
|  | ||||
| for _,Tensor in ipairs({"FloatTensor", "DoubleTensor"}) do | ||||
|  | ||||
|    local real = reals[Tensor] | ||||
|  | ||||
|    function interface.luaname2wrapname(self, name) | ||||
|       return string.format('torch_%s_%s', Tensor, name) | ||||
|    end | ||||
|  | ||||
|    local function cname(name) | ||||
|       return string.format('TH%s_%s', Tensor, name) | ||||
|    end | ||||
|  | ||||
|    local function lastdim(argn) | ||||
|       return function(arg) | ||||
|                 return string.format("TH%s_nDimension(%s)", Tensor, arg.args[argn]:carg()) | ||||
|              end | ||||
|    end | ||||
|     | ||||
|  | ||||
| --    for _,name in ipairs({"gesv","gels","eig","svd"}) do | ||||
| --       interface:wrap(name, | ||||
| -- 		     cname(name), | ||||
| -- 		     {{name=Tensor, returned=true}, | ||||
| -- 		      {name=Tensor, returned=true}, | ||||
| -- 		      {name=Tensor}, | ||||
| -- 		      {name=Tensor}}, | ||||
| -- 		     cname(name), | ||||
| -- 		     {{name=Tensor, default=true, returned=true, invisible=true}, | ||||
| -- 		      {name=Tensor, default=true, returned=true, invisible=true}, | ||||
| -- 		      {name=Tensor}, | ||||
| -- 		      {name=Tensor}} | ||||
| -- 		  ) | ||||
| --    end | ||||
|     | ||||
|  | ||||
|    --interface:register(string.format("torch_%sLapack__", Tensor)) | ||||
|  | ||||
| --  interface:print(string.gsub([[ | ||||
| -- static void torch_TensorLapack_init(lua_State *L) | ||||
| -- { | ||||
| --   torch_Tensor_id = luaT_checktypename2id(L, "torch.Tensor"); | ||||
| --   torch_LongStorage_id = luaT_checktypename2id(L, "torch.LongStorage"); | ||||
|  | ||||
| --   luaT_pushmetaclass(L, torch_Tensor_id); | ||||
| --   lua_getfield(L,-1,"torch"); | ||||
| --   luaL_register(L, NULL, torch_TensorLapack__); | ||||
| --   lua_pop(L, 2); | ||||
| -- } | ||||
| -- ]], 'Tensor', Tensor)) | ||||
| end | ||||
|  | ||||
| interface:dispatchregister("torch_TensorLapack__") | ||||
|  | ||||
| if arg[1] then | ||||
|    interface:tofile(arg[1]) | ||||
| else | ||||
|    interface:tostdio() | ||||
| end | ||||
							
								
								
									
										54
									
								
								TensorMath.c
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										54
									
								
								TensorMath.c
									
									
									
									
									
										Normal file
									
								
							| @ -0,0 +1,54 @@ | ||||
| #include "TH.h" | ||||
| #include "luaT.h" | ||||
| #include "utils.h" | ||||
|  | ||||
| #include "sys/time.h" | ||||
|  | ||||
| #define torch_(NAME) TH_CONCAT_3(torch_, Real, NAME) | ||||
| #define torch_string_(NAME) TH_CONCAT_STRING_3(torch., Real, NAME) | ||||
|  | ||||
| static const void* torch_ByteTensor_id; | ||||
| static const void* torch_CharTensor_id; | ||||
| static const void* torch_ShortTensor_id; | ||||
| static const void* torch_IntTensor_id; | ||||
| static const void* torch_LongTensor_id; | ||||
| static const void* torch_FloatTensor_id; | ||||
| static const void* torch_DoubleTensor_id; | ||||
|  | ||||
| static const void* torch_LongStorage_id; | ||||
|  | ||||
|  | ||||
| #include "TensorMathWrap.c" | ||||
| //#include "TensorLapackWrap.c" | ||||
| //#include "TensorConvWrap.c" | ||||
|  | ||||
| //#include "generic/TensorLapack.c" | ||||
| //#include "THGenerateFloatTypes.h" | ||||
|  | ||||
| //#include "generic/TensorConv.c" | ||||
| //#include "THGenerateAllTypes.h" | ||||
|  | ||||
| void torch_TensorMath_init(lua_State *L) | ||||
| { | ||||
|   torch_ByteTensorMath_init(L); | ||||
|   torch_CharTensorMath_init(L); | ||||
|   torch_ShortTensorMath_init(L); | ||||
|   torch_IntTensorMath_init(L); | ||||
|   torch_LongTensorMath_init(L); | ||||
|   torch_FloatTensorMath_init(L); | ||||
|   torch_DoubleTensorMath_init(L); | ||||
|   luaL_register(L, NULL, torch_TensorMath__); | ||||
|  | ||||
| /*   torch_FloatLapack_init(L); */ | ||||
| /*   torch_DoubleLapack_init(L); */ | ||||
| /*   luaL_register(L, NULL, torch_TensorLapack__); */ | ||||
|  | ||||
| /*   torch_ByteConv_init(L); */ | ||||
| /*   torch_CharConv_init(L); */ | ||||
| /*   torch_ShortConv_init(L); */ | ||||
| /*   torch_IntConv_init(L); */ | ||||
| /*   torch_LongConv_init(L); */ | ||||
| /*   torch_FloatConv_init(L); */ | ||||
| /*   torch_DoubleConv_init(L); */ | ||||
| /*   luaL_register(L, NULL, torch_TensorConv__); */ | ||||
| } | ||||
							
								
								
									
										110
									
								
								TensorMath.lua
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										110
									
								
								TensorMath.lua
									
									
									
									
									
										Normal file
									
								
							| @ -0,0 +1,110 @@ | ||||
| for _,tensortype in ipairs({'ByteTensor', | ||||
|                       'CharTensor', | ||||
|                       'ShortTensor', | ||||
|                       'IntTensor', | ||||
|                       'LongTensor', | ||||
|                       'FloatTensor', | ||||
|                       'DoubleTensor'}) do | ||||
|  | ||||
|    for _,func in ipairs({'add', | ||||
|                          'mul', | ||||
|                          'div', | ||||
|                          'cmul', | ||||
|                          'cdiv', | ||||
|                          'addcmul', | ||||
|                          'addcdiv', | ||||
|                          'log', | ||||
|                          'log1p', | ||||
|                          'exp', | ||||
|                          'cos', | ||||
|                          'acos', | ||||
|                          'cosh', | ||||
|                          'sin', | ||||
|                          'asin', | ||||
|                          'sinh', | ||||
|                          'tan', | ||||
|                          'atan', | ||||
|                          'tanh', | ||||
|                          'pow', | ||||
|                          'sqrt', | ||||
|                          'ceil', | ||||
|                          'floor', | ||||
|                          'abs', | ||||
| 			 'sign' | ||||
|                       }) do | ||||
|  | ||||
|       local torchfunc = torch[tensortype].torch[func] | ||||
|       torch[tensortype][func] = function(self, ...) | ||||
|                              return torchfunc(self, self, ...) | ||||
|                           end       | ||||
|    end | ||||
|  | ||||
|    for _,func in ipairs({'addmv', | ||||
|                          'addmm', | ||||
|                          'addr'}) do | ||||
|        | ||||
|       local torchfunc = torch[tensortype].torch[func] | ||||
|       torch[tensortype][func] = function(self, next1, next2, ...) | ||||
|                                    if type(next1) == 'number' and type(next2) == 'number' then | ||||
|                                       return torchfunc(self, next1, self, next2, ...) | ||||
|                                    elseif type(next1) == 'number' then | ||||
|                                       return torchfunc(self, self, next1, next2, ...)                                       | ||||
|                                    else | ||||
|                                       return torchfunc(self, self, next1, next2, ...) | ||||
|                                    end | ||||
|                           end       | ||||
|    end | ||||
|  | ||||
|    for _,func in ipairs({'zero', | ||||
|                          'fill', | ||||
|                          'dot', | ||||
|                          'minall', | ||||
|                          'maxall', | ||||
|                          'sumall',                          | ||||
|                          'numel', | ||||
|                          'max', | ||||
|                          'min', | ||||
|                          'sum', | ||||
|                          'prod', | ||||
|                          'cumsum', | ||||
|                          'cumprod', | ||||
|                          'trace', | ||||
|                          'cross', | ||||
|                          'zeros', | ||||
|                          'ones', | ||||
|                          'diag', | ||||
|                          'eye', | ||||
|                          'range', | ||||
|                          'randperm', | ||||
|                          'reshape', | ||||
|                          'sort', | ||||
|                          'tril', | ||||
|                          'triu', | ||||
|                          '_histc', | ||||
|                          'cat', | ||||
|                          'mean', | ||||
|                          'std', | ||||
|                          'var', | ||||
|                          'norm', | ||||
|                          'dist', | ||||
|                          'meanall', | ||||
|                          'varall', | ||||
|                          'stdall', | ||||
|                          'linspace', | ||||
|                          'logspace', | ||||
|                          'rand', | ||||
|                          'randn', | ||||
|                          'random', | ||||
|                          'uniform', | ||||
|                          'normal', | ||||
|                          'cauchy', | ||||
|                          'logNormal', | ||||
|                          'exponential', | ||||
|                          'geometric', | ||||
|                          'bernoulli', | ||||
|                          'squeeze' | ||||
|                       }) do | ||||
|  | ||||
|       torch[tensortype][func] = torch[tensortype].torch[func] | ||||
|    end | ||||
| end | ||||
							
								
								
									
										870
									
								
								TensorMathWrap.lua
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										870
									
								
								TensorMathWrap.lua
									
									
									
									
									
										Normal file
									
								
							| @ -0,0 +1,870 @@ | ||||
| --  | ||||
| -- require 'wrap' | ||||
| --- | ||||
|  | ||||
| local interface = wrap.CInterface.new() | ||||
|  | ||||
| interface:print([[ | ||||
| #include "TH.h" | ||||
| #include "luaT.h" | ||||
| #include "utils.h" | ||||
|  | ||||
| static const void* torch_ByteTensor_id; | ||||
| static const void* torch_CharTensor_id; | ||||
| static const void* torch_ShortTensor_id; | ||||
| static const void* torch_IntTensor_id; | ||||
| static const void* torch_LongTensor_id; | ||||
| static const void* torch_FloatTensor_id; | ||||
| static const void* torch_DoubleTensor_id; | ||||
|  | ||||
| static const void* torch_LongStorage_id; | ||||
|                 ]]) | ||||
|  | ||||
| -- special argument specific to torch package | ||||
| interface.argtypes.LongArg = { | ||||
|  | ||||
|    vararg = true, | ||||
|  | ||||
|    helpname = function(arg) | ||||
|                return "(LongStorage | dim1 [dim2...])" | ||||
|             end, | ||||
|  | ||||
|    declare = function(arg) | ||||
|               return string.format("THLongStorage *arg%d = NULL;", arg.i) | ||||
|            end, | ||||
|  | ||||
|    init = function(arg) | ||||
|              if arg.default then | ||||
|                 error('LongArg cannot have a default value') | ||||
|              end | ||||
|           end, | ||||
|     | ||||
|    check = function(arg, idx) | ||||
|             return string.format("torch_islongargs(L, %d)", idx) | ||||
|          end, | ||||
|  | ||||
|    read = function(arg, idx) | ||||
|              return string.format("arg%d = torch_checklongargs(L, %d);", arg.i, idx) | ||||
|           end, | ||||
|     | ||||
|    carg = function(arg, idx) | ||||
|              return string.format('arg%d', arg.i) | ||||
|           end, | ||||
|  | ||||
|    creturn = function(arg, idx) | ||||
|                 return string.format('arg%d', arg.i) | ||||
|              end, | ||||
|     | ||||
|    precall = function(arg) | ||||
|                 local txt = {} | ||||
|                 if arg.returned then | ||||
|                    table.insert(txt, string.format('luaT_pushudata(L, arg%d, torch_LongStorage_id);', arg.i)) | ||||
|                 end | ||||
|                 return table.concat(txt, '\n') | ||||
|              end, | ||||
|  | ||||
|    postcall = function(arg) | ||||
|                  local txt = {} | ||||
|                  if arg.creturned then | ||||
|                     -- this next line is actually debatable | ||||
|                     table.insert(txt, string.format('THLongStorage_retain(arg%d);', arg.i)) | ||||
|                     table.insert(txt, string.format('luaT_pushudata(L, arg%d, torch_LongStorage_id);', arg.i)) | ||||
|                  end | ||||
|                  if not arg.returned and not arg.creturned then | ||||
|                     table.insert(txt, string.format('THLongStorage_free(arg%d);', arg.i)) | ||||
|                  end | ||||
|                  return table.concat(txt, '\n') | ||||
|               end    | ||||
| } | ||||
|  | ||||
| interface.argtypes.charoption = { | ||||
|  | ||||
|    helpname = function(arg) | ||||
|                  if arg.values then | ||||
|                     return "(" .. table.concat(arg.values, '|') .. ")" | ||||
|                  end | ||||
|               end, | ||||
|  | ||||
|    declare = function(arg) | ||||
|                 local txt = {} | ||||
|                 table.insert(txt, string.format("const char *arg%d = NULL;", arg.i)) | ||||
|                 if arg.default then | ||||
|                    table.insert(txt, string.format("char arg%d_default = '%s';", arg.i, arg.default)) | ||||
|                 end | ||||
|                 return table.concat(txt, '\n') | ||||
|            end, | ||||
|  | ||||
|    init = function(arg) | ||||
|              return string.format("arg%d = &arg%d_default;", arg.i, arg.i) | ||||
|           end, | ||||
|     | ||||
|    check = function(arg, idx) | ||||
|               local txt = {} | ||||
|               local txtv = {} | ||||
|               table.insert(txt, string.format('(arg%d = lua_tostring(L, %d)) && (', arg.i, idx)) | ||||
|               for _,value in ipairs(arg.values) do | ||||
|                  table.insert(txtv, string.format("*arg%d == '%s'", arg.i, value)) | ||||
|               end | ||||
|               table.insert(txt, table.concat(txtv, ' || ')) | ||||
|               table.insert(txt, ')')               | ||||
|               return table.concat(txt, '') | ||||
|          end, | ||||
|  | ||||
|    read = function(arg, idx) | ||||
|           end, | ||||
|     | ||||
|    carg = function(arg, idx) | ||||
|              return string.format('arg%d', arg.i) | ||||
|           end, | ||||
|  | ||||
|    creturn = function(arg, idx) | ||||
|              end, | ||||
|     | ||||
|    precall = function(arg) | ||||
|              end, | ||||
|  | ||||
|    postcall = function(arg) | ||||
|               end    | ||||
| } | ||||
|  | ||||
| -- also specific to torch: we generate a 'dispatch' function | ||||
| -- first we create a helper function | ||||
| interface:print([[ | ||||
| static const void* torch_istensorid(lua_State *L, const void *id) | ||||
| { | ||||
|   if(!id) | ||||
|     return NULL; | ||||
|  | ||||
|   luaT_pushmetaclass(L, id); | ||||
|   lua_pushstring(L, "torch"); | ||||
|   lua_rawget(L, -2); | ||||
|   if(lua_istable(L, -1)) | ||||
|     return id; | ||||
|   else | ||||
|   { | ||||
|     lua_pop(L, 2); | ||||
|     return NULL; | ||||
|   } | ||||
|  | ||||
|   return NULL; | ||||
| } | ||||
| ]]) | ||||
|  | ||||
| interface.dispatchregistry = {} | ||||
| function interface:wrap(name, ...) | ||||
|    -- usual stuff | ||||
|    wrap.CInterface.wrap(self, name, ...) | ||||
|  | ||||
|    -- dispatch function | ||||
|    if not interface.dispatchregistry[name] then | ||||
|       interface.dispatchregistry[name] = true | ||||
|       table.insert(interface.dispatchregistry, {name=name, wrapname=string.format("torch_%s", name)}) | ||||
|  | ||||
|       interface:print(string.gsub([[ | ||||
| static int torch_NAME(lua_State *L) | ||||
| { | ||||
|   int narg = lua_gettop(L); | ||||
|   const void *id; | ||||
|  | ||||
|   if(narg < 1 || !(id = torch_istensorid(L, luaT_id(L, 1)))) /* first argument is tensor? */ | ||||
|   { | ||||
|     if(narg < 2 || !(id = torch_istensorid(L, luaT_id(L, 2)))) /* second? */ | ||||
|     { | ||||
|       if(lua_isstring(L, -1) && (id = torch_istensorid(L, luaT_typename2id(L, lua_tostring(L, -1))))) /* do we have a valid string then? */ | ||||
|         lua_pop(L, 1); | ||||
|       else if(!(id = torch_istensorid(L, torch_getdefaulttensorid()))) | ||||
|         luaL_error(L, "internal error: the default tensor type does not seem to be an actual tensor"); | ||||
|     } | ||||
|   } | ||||
|    | ||||
|   lua_pushstring(L, "NAME"); | ||||
|   lua_rawget(L, -2); | ||||
|   if(lua_isfunction(L, -1)) | ||||
|   { | ||||
|     lua_insert(L, 1); | ||||
|     lua_pop(L, 2); /* the two tables we put on the stack above */ | ||||
|     lua_call(L, lua_gettop(L)-1, LUA_MULTRET); | ||||
|   } | ||||
|   else | ||||
|     return luaL_error(L, "%s does not implement the torch.NAME() function", luaT_id2typename(L, id)); | ||||
|  | ||||
|   return lua_gettop(L); | ||||
| } | ||||
| ]], 'NAME', name)) | ||||
|   end | ||||
| end | ||||
|  | ||||
| function interface:dispatchregister(name) | ||||
|    local txt = self.txt | ||||
|    table.insert(txt, string.format('static const struct luaL_Reg %s [] = {', name)) | ||||
|    for _,reg in ipairs(self.dispatchregistry) do | ||||
|       table.insert(txt, string.format('{"%s", %s},', reg.name, reg.wrapname)) | ||||
|    end | ||||
|    table.insert(txt, '{NULL, NULL}') | ||||
|    table.insert(txt, '};') | ||||
|    table.insert(txt, '')    | ||||
|    self.dispatchregistry = {} | ||||
| end | ||||
|  | ||||
| interface:print('/* WARNING: autogenerated file */') | ||||
| interface:print('') | ||||
|  | ||||
| local reals = {ByteTensor='unsigned char', | ||||
|                CharTensor='char', | ||||
|                ShortTensor='short', | ||||
|                IntTensor='int', | ||||
|                LongTensor='long', | ||||
|                FloatTensor='float', | ||||
|                DoubleTensor='double'} | ||||
|  | ||||
| for _,Tensor in ipairs({"ByteTensor", "CharTensor", | ||||
|                         "ShortTensor", "IntTensor", "LongTensor", | ||||
|                         "FloatTensor", "DoubleTensor"}) do | ||||
|  | ||||
|    local real = reals[Tensor] | ||||
|  | ||||
|    function interface.luaname2wrapname(self, name) | ||||
|       return string.format('torch_%s_%s', Tensor, name) | ||||
|    end | ||||
|  | ||||
|    local function cname(name) | ||||
|       return string.format('TH%s_%s', Tensor, name) | ||||
|    end | ||||
|  | ||||
|    local function lastdim(argn) | ||||
|       return function(arg) | ||||
|                 return string.format("TH%s_nDimension(%s)", Tensor, arg.args[argn]:carg()) | ||||
|              end | ||||
|    end | ||||
|     | ||||
|    interface:wrap("zero", | ||||
|                   cname("zero"), | ||||
|                   {{name=Tensor, returned=true}}) | ||||
|  | ||||
|    interface:wrap("fill", | ||||
|                   cname("fill"), | ||||
|                   {{name=Tensor, returned=true}, | ||||
|                    {name=real}}) | ||||
|  | ||||
|    interface:wrap("zeros", | ||||
|                   cname("zeros"), | ||||
|                   {{name=Tensor, default=true, returned=true}, | ||||
|                    {name="LongArg"}}) | ||||
|  | ||||
|    interface:wrap("ones", | ||||
|                   cname("ones"), | ||||
|                   {{name=Tensor, default=true, returned=true}, | ||||
|                    {name="LongArg"}}) | ||||
|  | ||||
|    interface:wrap("reshape", | ||||
|                   cname("reshape"), | ||||
|                   {{name=Tensor, default=true, returned=true}, | ||||
|                    {name=Tensor}, | ||||
|                    {name="LongArg"}}) | ||||
|  | ||||
|    interface:wrap("dot", | ||||
|                   cname("dot"), | ||||
|                   {{name=Tensor}, | ||||
|                    {name=Tensor}, | ||||
|                    {name=real, creturned=true}}) | ||||
|  | ||||
|    for _,name in ipairs({"minall", "maxall", "sumall"}) do | ||||
|       interface:wrap(name, | ||||
|                      cname(name), | ||||
|                      {{name=Tensor},             | ||||
|                       {name=real, creturned=true}}) | ||||
|    end | ||||
|  | ||||
|    interface:wrap("add", | ||||
|                   cname("add"), | ||||
|                   {{name=Tensor, default=true, returned=true}, | ||||
|                    {name=Tensor}, | ||||
|                    {name=real}}, | ||||
|                   cname("cadd"), | ||||
|                   {{name=Tensor, default=true, returned=true}, | ||||
|                    {name=Tensor}, | ||||
|                    {name=real, default=1}, | ||||
|                    {name=Tensor}}) | ||||
|  | ||||
|    interface:wrap("mul", | ||||
|                   cname("mul"), | ||||
|                   {{name=Tensor, default=true, returned=true}, | ||||
|                    {name=Tensor}, | ||||
|                    {name=real}}) | ||||
|  | ||||
|    interface:wrap("div", | ||||
|                   cname("div"), | ||||
|                   {{name=Tensor, default=true, returned=true}, | ||||
|                    {name=Tensor}, | ||||
|                    {name=real}}) | ||||
|  | ||||
|    interface:wrap("cmul", | ||||
|                   cname("cmul"), | ||||
|                   {{name=Tensor, default=true, returned=true}, | ||||
|                    {name=Tensor}, | ||||
|                    {name=Tensor}}) | ||||
|  | ||||
|    interface:wrap("cdiv", | ||||
|                   cname("cdiv"), | ||||
|                   {{name=Tensor, default=true, returned=true}, | ||||
|                    {name=Tensor}, | ||||
|                    {name=Tensor}}) | ||||
|  | ||||
|    interface:wrap("addcmul", | ||||
|                   cname("addcmul"), | ||||
|                   {{name=Tensor, default=true, returned=true}, | ||||
|                    {name=Tensor}, | ||||
|                    {name=real, default=1}, | ||||
|                    {name=Tensor}, | ||||
|                    {name=Tensor}}) | ||||
|  | ||||
|    interface:wrap("addcdiv", | ||||
|                   cname("addcdiv"), | ||||
|                   {{name=Tensor, default=true, returned=true}, | ||||
|                    {name=Tensor}, | ||||
|                    {name=real, default=1}, | ||||
|                    {name=Tensor}, | ||||
|                    {name=Tensor}}) | ||||
|  | ||||
|    for _,name in ipairs({"addmv", "addmm", "addr"}) do | ||||
|       interface:wrap(name, | ||||
|                      cname(name), | ||||
|                      {{name=Tensor, default=true, returned=true}, | ||||
|                       {name=real, default=1}, | ||||
|                       {name=Tensor}, | ||||
|                       {name=real, default=1}, | ||||
|                       {name=Tensor}, | ||||
|                       {name=Tensor}}) | ||||
|    end | ||||
|  | ||||
|    interface:wrap("numel", | ||||
|                   cname("numel"), | ||||
|                   {{name=Tensor}, | ||||
|                    {name=real, creturned=true}}) | ||||
|  | ||||
|    for _,name in ipairs({"sum", "prod", "cumsum", "cumprod"}) do | ||||
|       interface:wrap(name, | ||||
|                      cname(name), | ||||
|                      {{name=Tensor, default=true, returned=true}, | ||||
|                       {name=Tensor}, | ||||
|                       {name="index", default=lastdim(2)}}) | ||||
|    end | ||||
|  | ||||
|    interface:wrap("min", | ||||
|                   cname("min"), | ||||
|                   {{name=Tensor, default=true, returned=true}, | ||||
|                    {name="IndexTensor", default=true, returned=true}, | ||||
|                    {name=Tensor}, | ||||
|                    {name="index", default=lastdim(3)}}) | ||||
|  | ||||
|    interface:wrap("max", | ||||
|                   cname("max"), | ||||
|                   {{name=Tensor, default=true, returned=true}, | ||||
|                    {name="IndexTensor", default=true, returned=true}, | ||||
|                    {name=Tensor}, | ||||
|                    {name="index", default=lastdim(3)}}) | ||||
|  | ||||
|    interface:wrap("trace", | ||||
|                   cname("trace"), | ||||
|                   {{name=Tensor}, | ||||
|                    {name=real, creturned=true}}) | ||||
|  | ||||
|    interface:wrap("cross", | ||||
|                   cname("cross"), | ||||
|                   {{name=Tensor, default=true, returned=true}, | ||||
|                    {name=Tensor}, | ||||
|                    {name=Tensor}, | ||||
|                    {name="index", default=0}}) | ||||
|  | ||||
|    interface:wrap("diag", | ||||
|                   cname("diag"), | ||||
|                   {{name=Tensor, default=true, returned=true}, | ||||
|                    {name=Tensor}, | ||||
|                    {name="long", default=0}}) | ||||
|  | ||||
|    interface:wrap("eye", | ||||
|                   cname("eye"), | ||||
|                   {{name=Tensor, default=true, returned=true}, | ||||
|                    {name="long"}, | ||||
|                    {name="long", default=0}}) | ||||
|  | ||||
|    interface:wrap("range", | ||||
|                   cname("range"), | ||||
|                   {{name=Tensor, default=true, returned=true}, | ||||
|                    {name=real}, | ||||
|                    {name=real}, | ||||
|                    {name=real, default=1}}) | ||||
|  | ||||
|    interface:wrap("randperm", | ||||
|                   cname("randperm"), | ||||
|                   {{name=Tensor, default=true, returned=true, userpostcall=function(arg) | ||||
|                                                                               return string.format("TH%s_add(%s, %s, 1);", Tensor, arg:carg(), arg:carg()) | ||||
|                                                                            end}, | ||||
|                    {name="long"}}) | ||||
|  | ||||
|    interface:wrap("sort", | ||||
|                   cname("sort"), | ||||
|                   {{name=Tensor, default=true, returned=true}, | ||||
|                    {name="IndexTensor", default=true, returned=true}, | ||||
|                    {name=Tensor}, | ||||
|                    {name="index", default=lastdim(3)}, | ||||
|                    {name="boolean", default=0}}) | ||||
|  | ||||
|  | ||||
|    interface:wrap("tril", | ||||
|                   cname("tril"), | ||||
|                   {{name=Tensor, default=true, returned=true}, | ||||
|                    {name=Tensor}, | ||||
|                    {name="int", default=0}}) | ||||
|  | ||||
|    interface:wrap("triu", | ||||
|                   cname("triu"), | ||||
|                   {{name=Tensor, default=true, returned=true}, | ||||
|                    {name=Tensor}, | ||||
|                    {name="int", default=0}}) | ||||
|  | ||||
|    interface:wrap("cat", | ||||
|                   cname("cat"), | ||||
|                   {{name=Tensor, default=true, returned=true}, | ||||
|                    {name=Tensor}, | ||||
|                    {name=Tensor}, | ||||
|                    {name="index", default=lastdim(2)}}) | ||||
|  | ||||
|    if Tensor == 'ByteTensor' then -- we declare this only once | ||||
|       interface:print( | ||||
|          [[ | ||||
| static int THRandom_random2__(long a, long b) | ||||
| { | ||||
|   THArgCheck(b >= a, 2, "upper bound must be larger than lower bound"); | ||||
|   return((THRandom_random() % (b+1-a)) + a); | ||||
| } | ||||
|           | ||||
| static int THRandom_random1__(long b) | ||||
| { | ||||
|   THArgCheck(b > 0, 1, "upper bound must be strictly positive"); | ||||
|   return(THRandom_random() % b + 1); | ||||
| } | ||||
|          ]]) | ||||
|    end | ||||
|  | ||||
|    interface:print(string.gsub( | ||||
|                       [[ | ||||
| static void THTensor_random2__(THTensor *self, long a, long b) | ||||
| { | ||||
|   THArgCheck(b >= a, 2, "upper bound must be larger than lower bound"); | ||||
|   TH_TENSOR_APPLY(real, self, *self_data = ((THRandom_random() % (b+1-a)) + a);) | ||||
| } | ||||
|  | ||||
| static void THTensor_random1__(THTensor *self, long b) | ||||
| { | ||||
|   THArgCheck(b > 0, 1, "upper bound must be strictly positive"); | ||||
|   TH_TENSOR_APPLY(real, self, *self_data = (THRandom_random() % b + 1);) | ||||
| } | ||||
| ]], 'Tensor', Tensor):gsub('real', real)) | ||||
|  | ||||
|    interface:wrap('random', | ||||
|                   'THRandom_random2__', | ||||
|                   {{name='long'}, | ||||
|                    {name='long'}, | ||||
|                    {name='long', creturned=true}}, | ||||
|                   'THRandom_random1__', | ||||
|                   {{name='long'}, | ||||
|                    {name='long', creturned=true}}, | ||||
|                   'THRandom_random', | ||||
|                   {{name='long', creturned=true}}, | ||||
|                   cname("random2__"), | ||||
|                   {{name=Tensor}, | ||||
|                    {name='long'}, | ||||
|                    {name='long'}}, | ||||
|                   cname("random1__"), | ||||
|                   {{name=Tensor}, | ||||
|                    {name='long'}}, | ||||
|                   cname("random"), | ||||
|                   {{name=Tensor}}) | ||||
|  | ||||
|    for _,f in ipairs({{name='geometric'}, | ||||
|                       {name='bernoulli', a=0.5}}) do | ||||
|        | ||||
|       interface:wrap(f.name, | ||||
|                      string.format("THRandom_%s", f.name), | ||||
|                      {{name="double", default=f.a}, | ||||
|                       {name="double", creturned=true}}, | ||||
|                      cname(f.name), | ||||
|                      {{name=Tensor, returned=true}, | ||||
|                       {name=real, default=f.a}}) | ||||
|    end | ||||
|  | ||||
|    interface:wrap("squeeze", | ||||
|                   cname("squeeze"), | ||||
|                   {{name=Tensor, default=true, returned=true, postcall=function(arg) | ||||
|                                                                          local txt = {} | ||||
|                                                                          if arg.returned then | ||||
|                                                                             table.insert(txt, string.format('if(arg%d->nDimension == 1 && arg%d->size[0] == 1)', arg.i, arg.i)) -- number | ||||
|                                                                             table.insert(txt, string.format('lua_pushnumber(L, (lua_Number)(*TH%s_data(arg%d)));', Tensor, arg.i)) | ||||
|                                                                          end | ||||
|                                                                          return table.concat(txt, '\n') | ||||
|                                                                       end}, | ||||
|                    {name=Tensor}}, | ||||
|                   cname("squeeze1d"), | ||||
|                   {{name=Tensor, default=true, returned=true, postcall=function(arg) | ||||
|                                                                           local txt = {} | ||||
|                                                                           if arg.returned then | ||||
|                                                                              table.insert(txt, string.format('if(arg%d->nDimension == 1 && arg%d->size[0] == 1)', arg.i, arg.i)) -- number | ||||
|                                                                             table.insert(txt, string.format('lua_pushnumber(L, (lua_Number)(*TH%s_data(arg%d)));', Tensor, arg.i)) | ||||
|                                                                          end | ||||
|                                                                          return table.concat(txt, '\n') | ||||
|                                                                       end}, | ||||
|                    {name=Tensor}, | ||||
|                    {name="index"}}) | ||||
|  | ||||
|    interface:wrap("sign", | ||||
| 		  cname("sign"), | ||||
| 		  {{name=Tensor, default=true, returned=true}, | ||||
| 		   {name=Tensor}}) | ||||
|  | ||||
|    interface:wrap("conv2", | ||||
| 		  cname("conv2Dmul"), | ||||
| 		  {{name=Tensor, default=true, returned=true}, | ||||
|                    {name=real, default=0, invisible=true}, | ||||
|                    {name=real, default=1, invisible=true}, | ||||
|                    {name=Tensor, dim=2}, | ||||
|                    {name=Tensor, dim=2}, | ||||
|                    {name=real, default=1, invisible=true}, | ||||
|                    {name=real, default=1, invisible=true}, | ||||
| 		   {name='charoption', values={'V', 'F'}, default='V'}, | ||||
|                    {name='charoption', default="C", invisible=true}}, | ||||
| 		  cname("conv2Dcmul"), | ||||
| 		  {{name=Tensor, default=true, returned=true}, | ||||
|                    {name=real, default=0, invisible=true}, | ||||
|                    {name=real, default=1, invisible=true}, | ||||
|                    {name=Tensor, dim=3}, | ||||
|                    {name=Tensor, dim=3}, | ||||
|                    {name=real, default=1, invisible=true}, | ||||
|                    {name=real, default=1, invisible=true}, | ||||
| 		   {name='charoption', values={'V', 'F'}, default='V'}, | ||||
|                    {name='charoption', default="C", invisible=true}}, | ||||
| 		  cname("conv2Dmv"), | ||||
| 		  {{name=Tensor, default=true, returned=true}, | ||||
|                    {name=real, default=0, invisible=true}, | ||||
|                    {name=real, default=1, invisible=true}, | ||||
|                    {name=Tensor, dim=3}, | ||||
|                    {name=Tensor, dim=4}, | ||||
|                    {name=real, default=1, invisible=true}, | ||||
|                    {name=real, default=1, invisible=true}, | ||||
| 		   {name='charoption', values={'V', 'F'}, default='V'}, | ||||
|                    {name='charoption', default="C", invisible=true}} | ||||
|                ) | ||||
|  | ||||
|    interface:wrap("xcorr2", | ||||
| 		  cname("conv2Dmul"), | ||||
| 		  {{name=Tensor, default=true, returned=true}, | ||||
|                    {name=real, default=0, invisible=true}, | ||||
|                    {name=real, default=1, invisible=true}, | ||||
|                    {name=Tensor, dim=2}, | ||||
|                    {name=Tensor, dim=2}, | ||||
|                    {name=real, default=1, invisible=true}, | ||||
|                    {name=real, default=1, invisible=true}, | ||||
|                    {name='charoption', values={'V', 'F'}, default='V'}, | ||||
| 		   {name='charoption', default="X", invisible=true}}, | ||||
| 		  cname("conv2Dcmul"), | ||||
| 		  {{name=Tensor, default=true, returned=true}, | ||||
|                    {name=real, default=0, invisible=true}, | ||||
|                    {name=real, default=1, invisible=true}, | ||||
|                    {name=Tensor, dim=3}, | ||||
|                    {name=Tensor, dim=3}, | ||||
|                    {name=real, default=1, invisible=true}, | ||||
|                    {name=real, default=1, invisible=true}, | ||||
| 		   {name='charoption', values={'V', 'F'}, default='V'}, | ||||
|                    {name='charoption', default="X", invisible=true}}, | ||||
| 		  cname("conv2Dmv"), | ||||
| 		  {{name=Tensor, default=true, returned=true}, | ||||
|                    {name=real, default=0, invisible=true}, | ||||
|                    {name=real, default=1, invisible=true}, | ||||
|                    {name=Tensor, dim=3}, | ||||
|                    {name=Tensor, dim=4}, | ||||
|                    {name=real, default=1, invisible=true}, | ||||
|                    {name=real, default=1, invisible=true}, | ||||
| 		   {name='charoption', values={'V', 'F'}, default='V'}, | ||||
|                    {name='charoption', default="X", invisible=true}} | ||||
| 		 ) | ||||
|  | ||||
|    interface:wrap("conv3", | ||||
| 		  cname("conv3Dmul"), | ||||
| 		  {{name=Tensor, default=true, returned=true}, | ||||
|                    {name=real, default=0, invisible=true}, | ||||
|                    {name=real, default=1, invisible=true}, | ||||
|                    {name=Tensor, dim=3}, | ||||
|                    {name=Tensor, dim=3}, | ||||
|                    {name=real, default=1, invisible=true}, | ||||
|                    {name=real, default=1, invisible=true}, | ||||
|                    {name=real, default=1, invisible=true}, | ||||
| 		   {name='charoption', values={'V', 'F'}, default='V'}, | ||||
| 		   {name='charoption', default="C", invisible=true}}, | ||||
| 		  cname("conv3Dcmul"), | ||||
| 		  {{name=Tensor, default=true, returned=true}, | ||||
|                    {name=real, default=0, invisible=true}, | ||||
|                    {name=real, default=1, invisible=true}, | ||||
|                    {name=Tensor, dim=4}, | ||||
|                    {name=Tensor, dim=4}, | ||||
|                    {name=real, default=1, invisible=true}, | ||||
|                    {name=real, default=1, invisible=true}, | ||||
|                    {name=real, default=1, invisible=true}, | ||||
| 		   {name='charoption', values={'V', 'F'}, default='V'}, | ||||
| 		   {name='charoption', default="C", invisible=true}}, | ||||
| 		  cname("conv3Dmv"), | ||||
| 		  {{name=Tensor, default=true, returned=true}, | ||||
|                    {name=real, default=0, invisible=true}, | ||||
|                    {name=real, default=1, invisible=true}, | ||||
|                    {name=Tensor, dim=4}, | ||||
|                    {name=Tensor, dim=5}, | ||||
|                    {name=real, default=1, invisible=true}, | ||||
|                    {name=real, default=1, invisible=true}, | ||||
|                    {name=real, default=1, invisible=true}, | ||||
| 		   {name='charoption', values={'V', 'F'}, default='V'}, | ||||
|                    {name='charoption', default="C", invisible=true}} | ||||
| 		 ) | ||||
|  | ||||
|    interface:wrap("xcorr3", | ||||
| 		  cname("conv3Dmul"), | ||||
| 		  {{name=Tensor, default=true, returned=true}, | ||||
|                    {name=real, default=0, invisible=true}, | ||||
|                    {name=real, default=1, invisible=true}, | ||||
|                    {name=Tensor, dim=3}, | ||||
|                    {name=Tensor, dim=3}, | ||||
|                    {name=real, default=1, invisible=true}, | ||||
|                    {name=real, default=1, invisible=true}, | ||||
|                    {name=real, default=1, invisible=true}, | ||||
| 		   {name='charoption', values={'V', 'F'}, default='V'}, | ||||
|                    {name='charoption', default="X", invisible=true}}, | ||||
| 		  cname("conv3Dcmul"), | ||||
| 		  {{name=Tensor, default=true, returned=true}, | ||||
|                    {name=real, default=0, invisible=true}, | ||||
|                    {name=real, default=1, invisible=true}, | ||||
|                    {name=Tensor, dim=4}, | ||||
|                    {name=Tensor, dim=4}, | ||||
|                    {name=real, default=1, invisible=true}, | ||||
|                    {name=real, default=1, invisible=true}, | ||||
|                    {name=real, default=1, invisible=true}, | ||||
| 		   {name='charoption', values={'V', 'F'}, default='V'}, | ||||
| 		   {name='charoption', default="X", invisible=true}}, | ||||
| 		  cname("conv3Dmv"), | ||||
| 		  {{name=Tensor, default=true, returned=true}, | ||||
|                    {name=real, default=0, invisible=true}, | ||||
|                    {name=real, default=1, invisible=true}, | ||||
|                    {name=Tensor, dim=4}, | ||||
|                    {name=Tensor, dim=5}, | ||||
|                    {name=real, default=1, invisible=true}, | ||||
|                    {name=real, default=1, invisible=true}, | ||||
|                    {name=real, default=1, invisible=true}, | ||||
| 		   {name='charoption', values={'V', 'F'}, default='V'}, | ||||
| 		   {name='charoption', default="X", invisible=true}} | ||||
| 		 ) | ||||
|  | ||||
|    if Tensor == 'FloatTensor' or Tensor == 'DoubleTensor' then | ||||
|  | ||||
|       interface:wrap("mean", | ||||
|                      cname("mean"), | ||||
|                      {{name=Tensor, default=true, returned=true}, | ||||
|                       {name=Tensor}, | ||||
|                       {name="index", default=lastdim(2)}}) | ||||
|  | ||||
|       interface:wrap("std", | ||||
|                      cname("std"), | ||||
|                      {{name=Tensor, default=true, returned=true}, | ||||
|                       {name=Tensor}, | ||||
|                       {name="index", default=lastdim(2)}, | ||||
|                       {name="boolean", default=false}}) | ||||
|  | ||||
|       interface:wrap("var", | ||||
|                      cname("var"), | ||||
|                      {{name=Tensor, default=true, returned=true}, | ||||
|                       {name=Tensor}, | ||||
|                       {name="index", default=lastdim(2)}, | ||||
|                       {name="boolean", default=false}}) | ||||
|  | ||||
|       interface:wrap("norm", | ||||
|                      cname("norm"), | ||||
|                      {{name=Tensor}, | ||||
|                       {name=real, default=2}, | ||||
|                       {name=real, creturned=true}}) | ||||
|  | ||||
|       interface:wrap("dist", | ||||
|                      cname("dist"), | ||||
|                      {{name=Tensor}, | ||||
|                       {name=Tensor}, | ||||
|                       {name=real, default=2}, | ||||
|                       {name=real, creturned=true}}) | ||||
|  | ||||
|       for _,name in ipairs({"meanall", "varall", "stdall"}) do | ||||
|          interface:wrap(name, | ||||
|                         cname(name), | ||||
|                         {{name=Tensor}, | ||||
|                          {name=real, creturned=true}}) | ||||
|       end | ||||
|  | ||||
|       interface:wrap("linspace", | ||||
|                      cname("linspace"), | ||||
|                      {{name=Tensor, default=true, returned=true}, | ||||
|                       {name=real}, | ||||
|                       {name=real}, | ||||
|                       {name="long", default=100}}) | ||||
|  | ||||
|       interface:wrap("logspace", | ||||
|                      cname("logspace"), | ||||
|                      {{name=Tensor, default=true, returned=true}, | ||||
|                       {name=real}, | ||||
|                       {name=real}, | ||||
|                       {name="long", default=100}}) | ||||
|  | ||||
|       for _,name in ipairs({"log", "log1p", "exp", | ||||
|                             "cos", "acos", "cosh", | ||||
|                             "sin", "asin", "sinh", | ||||
|                             "tan", "atan", "tanh", | ||||
|                             "sqrt", | ||||
|                             "ceil", "floor", | ||||
|                             "abs"}) do | ||||
|  | ||||
|          interface:wrap(name, | ||||
|                         cname(name), | ||||
|                         {{name=Tensor, default=true, returned=true}, | ||||
|                          {name=Tensor}}, | ||||
|                         name, | ||||
|                         {{name=real}, | ||||
|                          {name=real, creturned=true}}) | ||||
|           | ||||
|       end | ||||
|  | ||||
|       interface:wrap("pow", | ||||
|                      cname("pow"), | ||||
|                      {{name=Tensor, default=true, returned=true}, | ||||
|                       {name=Tensor}, | ||||
|                       {name=real}}, | ||||
|                      "pow", | ||||
|                      {{name=real}, | ||||
|                       {name=real}, | ||||
|                       {name=real, creturned=true}}) | ||||
|  | ||||
|       interface:wrap("rand", | ||||
|                      cname("rand"), | ||||
|                      {{name=Tensor, default=true, returned=true}, | ||||
|                       {name="LongArg"}}) | ||||
|  | ||||
|       interface:wrap("randn", | ||||
|                      cname("randn"), | ||||
|                      {{name=Tensor, default=true, returned=true}, | ||||
|                       {name="LongArg"}}) | ||||
|  | ||||
|       for _,f in ipairs({{name='uniform', a=0, b=1}, | ||||
|                          {name='normal', a=0, b=1}, | ||||
|                          {name='cauchy', a=0, b=1}, | ||||
|                          {name='logNormal', a=1, b=2}}) do | ||||
|           | ||||
|          interface:wrap(f.name, | ||||
|                         string.format("THRandom_%s", f.name), | ||||
|                         {{name="double", default=f.a}, | ||||
|                          {name="double", default=f.b}, | ||||
|                          {name="double", creturned=true}}, | ||||
|                         cname(f.name), | ||||
|                         {{name=Tensor, returned=true}, | ||||
|                          {name=real, default=f.a}, | ||||
|                          {name=real, default=f.b}}) | ||||
|       end | ||||
|  | ||||
|       for _,f in ipairs({{name='exponential'}}) do | ||||
|           | ||||
|          interface:wrap(f.name, | ||||
|                         string.format("THRandom_%s", f.name), | ||||
|                         {{name="double", default=f.a}, | ||||
|                          {name="double", creturned=true}}, | ||||
|                         cname(f.name), | ||||
|                         {{name=Tensor, returned=true}, | ||||
|                          {name=real, default=f.a}}) | ||||
|       end | ||||
|        | ||||
|       for _,name in ipairs({"gesv","gels"}) do | ||||
|          interface:wrap(name, | ||||
|                         cname(name), | ||||
|                         {{name=Tensor, returned=true}, | ||||
|                          {name=Tensor, returned=true}, | ||||
|                          {name=Tensor}, | ||||
|                          {name=Tensor}}, | ||||
|                         cname(name), | ||||
|                         {{name=Tensor, default=true, returned=true, invisible=true}, | ||||
|                          {name=Tensor, default=true, returned=true, invisible=true}, | ||||
|                          {name=Tensor}, | ||||
|                          {name=Tensor}} | ||||
|                      ) | ||||
|       end | ||||
|  | ||||
|       interface:wrap("eig", | ||||
|                      cname("syev"), | ||||
|                      {{name=Tensor, returned=true}, | ||||
|                       {name=Tensor, returned=true}, | ||||
|                       {name=Tensor}, | ||||
|                       {name='charoption', values={'N', 'V'}, default='N'}, | ||||
|                       {name='charoption', values={'U', 'L'}, default='U'}}, | ||||
|                      cname("syev"), | ||||
|                      {{name=Tensor, default=true, returned=true, invisible=true}, | ||||
|                       {name=Tensor, default=true, returned=true, invisible=true}, | ||||
|                       {name=Tensor}, | ||||
|                       {name='charoption', values={'N', 'V'}, default='N'}, | ||||
|                       {name='charoption', values={'U', 'L'}, default='U'}} | ||||
|                   ) | ||||
|  | ||||
|       interface:wrap("svd", | ||||
|                      cname("gesvd"), | ||||
|                      {{name=Tensor, returned=true}, | ||||
|                       {name=Tensor, returned=true}, | ||||
|                       {name=Tensor, returned=true}, | ||||
|                       {name=Tensor}, | ||||
|                       {name='charoption', values={'A', 'S'}, default='S'}}, | ||||
|                      cname("gesvd"), | ||||
|                      {{name=Tensor, default=true, returned=true, invisible=true}, | ||||
|                       {name=Tensor, default=true, returned=true, invisible=true}, | ||||
|                       {name=Tensor, default=true, returned=true, invisible=true}, | ||||
|                       {name=Tensor}, | ||||
|                       {name='charoption', values={'A', 'S'}, default='S'}} | ||||
|                   ) | ||||
|        | ||||
|    end | ||||
|  | ||||
|    interface:register(string.format("torch_%sMath__", Tensor)) | ||||
|  | ||||
|    interface:print(string.gsub([[ | ||||
| static void torch_TensorMath_init(lua_State *L) | ||||
| { | ||||
|   torch_Tensor_id = luaT_checktypename2id(L, "torch.Tensor"); | ||||
|   torch_LongStorage_id = luaT_checktypename2id(L, "torch.LongStorage"); | ||||
|  | ||||
|   /* register everything into the "torch" field of the tensor metaclass */ | ||||
|   luaT_pushmetaclass(L, torch_Tensor_id); | ||||
|   lua_pushstring(L, "torch"); | ||||
|   lua_newtable(L); | ||||
|   luaL_register(L, NULL, torch_TensorMath__); | ||||
|   lua_rawset(L, -3); | ||||
|   lua_pop(L, 1); | ||||
| } | ||||
| ]], 'Tensor', Tensor)) | ||||
| end | ||||
|  | ||||
| interface:dispatchregister("torch_TensorMath__") | ||||
|  | ||||
| interface:print([[ | ||||
| void torch_TensorMath_init(lua_State *L) | ||||
| { | ||||
|   torch_ByteTensorMath_init(L); | ||||
|   torch_CharTensorMath_init(L); | ||||
|   torch_ShortTensorMath_init(L); | ||||
|   torch_IntTensorMath_init(L); | ||||
|   torch_LongTensorMath_init(L); | ||||
|   torch_FloatTensorMath_init(L); | ||||
|   torch_DoubleTensorMath_init(L); | ||||
|   luaL_register(L, NULL, torch_TensorMath__); | ||||
| } | ||||
| ]]) | ||||
|  | ||||
| if arg[1] then | ||||
|    interface:tofile(arg[1]) | ||||
| else | ||||
|    interface:tostdio() | ||||
| end | ||||
							
								
								
									
										8
									
								
								TensorOperator.c
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										8
									
								
								TensorOperator.c
									
									
									
									
									
										Normal file
									
								
							| @ -0,0 +1,8 @@ | ||||
| #include "general.h" | ||||
|  | ||||
| #define torch_TensorOperator_(NAME) TH_CONCAT_4(torch_,Real,TensorOperator_,NAME) | ||||
| #define torch_Tensor_id TH_CONCAT_3(torch_,Real,Tensor_id) | ||||
| #define STRING_torchTensor TH_CONCAT_STRING_3(torch.,Real,Tensor) | ||||
|  | ||||
| #include "generic/TensorOperator.c" | ||||
| #include "THGenerateAllTypes.h" | ||||
							
								
								
									
										124
									
								
								Tester.lua
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										124
									
								
								Tester.lua
									
									
									
									
									
										Normal file
									
								
							| @ -0,0 +1,124 @@ | ||||
| local Tester = torch.class('torch.Tester') | ||||
|  | ||||
| function Tester:__init() | ||||
|    self.errors = {} | ||||
|    self.tests = {} | ||||
|    self.testnames = {} | ||||
|    self.curtestname = '' | ||||
| end | ||||
|  | ||||
|  | ||||
| function Tester:assert_sub (condition, message) | ||||
|    if not condition then | ||||
|       local ss = debug.traceback('tester',2) | ||||
|       --print(ss) | ||||
|       ss = ss:match('[^\n]+\n[^\n]+\n([^\n]+\n[^\n]+)\n') | ||||
|       self.errors[#self.errors+1] = self.curtestname .. '\n' .. message .. '\n' .. ss .. '\n' | ||||
|    end | ||||
| end | ||||
| function Tester:assert (condition, message) | ||||
|    self:assert_sub(condition,string.format('%s\n%s  condition=%s',message,' BOOL violation ', tostring(condition))) | ||||
| end | ||||
| function Tester:assertlt (val, condition, message) | ||||
|    self:assert_sub(val<condition,string.format('%s\n%s  val=%s, condition=%s',message,' LT(<) violation ', tostring(val), tostring(condition))) | ||||
| end | ||||
| function Tester:assertgt (val, condition, message) | ||||
|    self:assert_sub(val>condition,string.format('%s\n%s  val=%s, condition=%s',message,' GT(>) violation ', tostring(val), tostring(condition))) | ||||
| end | ||||
| function Tester:assertle (val, condition, message) | ||||
|    self:assert_sub(val<=condition,string.format('%s\n%s  val=%s, condition=%s',message,' LE(<=) violation ', tostring(val), tostring(condition))) | ||||
| end | ||||
| function Tester:assertge (val, condition, message) | ||||
|    self:assert_sub(val>=condition,string.format('%s\n%s  val=%s, condition=%s',message,' GE(>=) violation ', tostring(val), tostring(condition))) | ||||
| end | ||||
| function Tester:asserteq (val, condition, message) | ||||
|    self:assert_sub(val==condition,string.format('%s\n%s  val=%s, condition=%s',message,' EQ(==) violation ', tostring(val), tostring(condition))) | ||||
| end | ||||
| function Tester:assertne (val, condition, message) | ||||
|    self:assert_sub(val~=condition,string.format('%s\n%s  val=%s, condition=%s',message,' NE(~=) violation ', tostring(val), tostring(condition))) | ||||
| end | ||||
| function Tester:assertTensorEq(ta, tb, condition, message) | ||||
|    local diff = ta-tb | ||||
|    local err = diff:abs():maxall() | ||||
|    self:assert_sub(err<condition,string.format('%s\n%s  val=%s, condition=%s',message,' TensorEQ(~=) violation ', tostring(err), tostring(condition))) | ||||
| end | ||||
|  | ||||
| function Tester:pcall(f) | ||||
|    local nerr = #self.errors | ||||
|    local res = f() | ||||
| --    local stat, result = pcall(f) | ||||
| --    if not stat then | ||||
| --       result = result .. debug.traceback() | ||||
| --    end | ||||
| --    return stat, result, stat and (nerr == #self.errors) | ||||
|    return true, res, nerr == #self.errors | ||||
| end | ||||
|  | ||||
| function Tester:report() | ||||
|    print('Completed ' .. #self.tests .. ' tests with ' .. #self.errors .. ' errors') | ||||
|    print() | ||||
|    print(string.rep('-',80)) | ||||
|    for i,v in ipairs(self.errors) do | ||||
|       print(v) | ||||
|       print(string.rep('-',80)) | ||||
|    end | ||||
| end | ||||
|  | ||||
| function Tester:run() | ||||
|    print('Running ' .. #self.tests .. ' tests') | ||||
|    local statstr = string.rep('_',#self.tests) | ||||
|    local pstr = '' | ||||
|    io.write(statstr .. '\r') | ||||
|    for i,v in ipairs(self.tests) do | ||||
|       self.curtestname = self.testnames[i] | ||||
|        | ||||
|       --clear | ||||
|       io.write('\r' .. string.rep(' ', pstr:len())) | ||||
|       io.flush() | ||||
|       --write | ||||
|       pstr = statstr:sub(1,i-1) .. '|' .. statstr:sub(i+1) .. '  ==> ' .. self.curtestname | ||||
|       io.write('\r' .. pstr) | ||||
|       io.flush() | ||||
|        | ||||
|       local stat, message, pass = self:pcall(v) | ||||
|        | ||||
|       if pass then | ||||
| 	 --io.write(string.format('\b_')) | ||||
| 	 statstr = statstr:sub(1,i-1) .. '_' .. statstr:sub(i+1) | ||||
|       else | ||||
| 	 statstr = statstr:sub(1,i-1) .. '*' .. statstr:sub(i+1) | ||||
| 	 --io.write(string.format('\b*')) | ||||
|       end | ||||
|        | ||||
|       if not stat then | ||||
| 	 print() | ||||
| 	 print('Function call failed: Test No ' .. i .. ' ' .. self.testnames[i]) | ||||
| 	 print(message) | ||||
|       end | ||||
|       collectgarbage() | ||||
|    end | ||||
|    --clear | ||||
|    io.write('\r' .. string.rep(' ', pstr:len())) | ||||
|    io.flush() | ||||
|    -- write finish | ||||
|    pstr = statstr .. '  ==> Done ' | ||||
|    io.write('\r' .. pstr) | ||||
|    io.flush() | ||||
|    print() | ||||
|    print() | ||||
|    self:report() | ||||
| end | ||||
|  | ||||
| function Tester:add(f,name) | ||||
|    name = name or 'unknown' | ||||
|    if type(f) == "table" then | ||||
|       for i,v in pairs(f) do | ||||
| 	 self:add(v,i) | ||||
|       end | ||||
|    elseif type(f) == "function" then | ||||
|       self.tests[#self.tests+1] = f | ||||
|       self.testnames[#self.tests] = name | ||||
|    else | ||||
|       error('Tester:add(f) expects a function or a table of functions') | ||||
|    end | ||||
| end | ||||
							
								
								
									
										157
									
								
								Timer.c
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										157
									
								
								Timer.c
									
									
									
									
									
										Normal file
									
								
							| @ -0,0 +1,157 @@ | ||||
| #include "general.h" | ||||
|  | ||||
| #ifdef _MSC_VER | ||||
| #include <time.h> | ||||
| #else | ||||
| #include <sys/time.h> | ||||
| #include <sys/resource.h> | ||||
| #endif | ||||
|  | ||||
| #ifdef _MSC_VER | ||||
| static time_t base_time = 0; | ||||
| #endif     | ||||
|  | ||||
| static const void* torch_Timer_id = NULL; | ||||
|  | ||||
| typedef struct _Timer | ||||
| { | ||||
|     int isRunning; | ||||
|  | ||||
|     double totalrealtime; | ||||
|     double totalusertime; | ||||
|     double totalsystime; | ||||
|  | ||||
|     double startrealtime; | ||||
|     double startusertime; | ||||
|     double startsystime; | ||||
|  | ||||
| } Timer; | ||||
|  | ||||
| static double torch_Timer_realtime() | ||||
| { | ||||
|   struct timeval current; | ||||
|   gettimeofday(¤t, NULL); | ||||
|   return (current.tv_sec + current.tv_usec/1000000.0); | ||||
| } | ||||
|  | ||||
| static double torch_Timer_usertime() | ||||
| { | ||||
|   struct rusage current; | ||||
|   getrusage(RUSAGE_SELF, ¤t); | ||||
|   return (current.ru_utime.tv_sec + current.ru_utime.tv_usec/1000000.0); | ||||
| } | ||||
|  | ||||
| static double torch_Timer_systime() | ||||
| { | ||||
|   struct rusage current; | ||||
|   getrusage(RUSAGE_SELF, ¤t); | ||||
|   return (current.ru_stime.tv_sec + current.ru_stime.tv_usec/1000000.0); | ||||
| } | ||||
|  | ||||
| static int torch_Timer_new(lua_State *L) | ||||
| { | ||||
|   Timer *timer = luaT_alloc(L, sizeof(Timer)); | ||||
| #ifdef _MSC_VER | ||||
|   while(!base_time) | ||||
|     time(&base_time); | ||||
| #endif | ||||
|   timer->isRunning = 1; | ||||
|   timer->totalrealtime = 0; | ||||
|   timer->totalusertime = 0; | ||||
|   timer->totalsystime = 0; | ||||
|   timer->startrealtime = torch_Timer_realtime(); | ||||
|   timer->startusertime = torch_Timer_usertime(); | ||||
|   timer->startsystime = torch_Timer_systime(); | ||||
|   luaT_pushudata(L, timer, torch_Timer_id); | ||||
|   return 1; | ||||
| } | ||||
|  | ||||
| static int torch_Timer_reset(lua_State *L) | ||||
| { | ||||
|   Timer *timer = luaT_checkudata(L, 1, torch_Timer_id); | ||||
|   timer->totalrealtime = 0; | ||||
|   timer->totalusertime = 0; | ||||
|   timer->totalsystime = 0; | ||||
|   timer->startrealtime = torch_Timer_realtime(); | ||||
|   timer->startusertime = torch_Timer_usertime(); | ||||
|   timer->startsystime = torch_Timer_systime(); | ||||
|   lua_settop(L, 1); | ||||
|   return 1; | ||||
| } | ||||
|  | ||||
| static int torch_Timer_free(lua_State *L) | ||||
| { | ||||
|   Timer *timer = luaT_checkudata(L, 1, torch_Timer_id); | ||||
|   luaT_free(L, timer); | ||||
|   return 0; | ||||
| } | ||||
|  | ||||
| static int torch_Timer_stop(lua_State *L) | ||||
| { | ||||
|   Timer *timer = luaT_checkudata(L, 1, torch_Timer_id); | ||||
|   if(timer->isRunning)   | ||||
|   { | ||||
|     double realtime = torch_Timer_realtime() - timer->startrealtime; | ||||
|     double usertime = torch_Timer_usertime() - timer->startusertime; | ||||
|     double systime = torch_Timer_systime() - timer->startsystime; | ||||
|     timer->totalrealtime += realtime; | ||||
|     timer->totalusertime += usertime; | ||||
|     timer->totalsystime += systime; | ||||
|     timer->isRunning = 0; | ||||
|   } | ||||
|   lua_settop(L, 1); | ||||
|   return 1;   | ||||
| } | ||||
|  | ||||
| static int torch_Timer_resume(lua_State *L) | ||||
| { | ||||
|   Timer *timer = luaT_checkudata(L, 1, torch_Timer_id); | ||||
|   if(!timer->isRunning) | ||||
|   { | ||||
|     timer->isRunning = 1; | ||||
|     timer->startrealtime = torch_Timer_realtime(); | ||||
|     timer->startusertime = torch_Timer_usertime(); | ||||
|     timer->startsystime = torch_Timer_systime(); | ||||
|   } | ||||
|   lua_settop(L, 1); | ||||
|   return 1;   | ||||
| } | ||||
|  | ||||
| static int torch_Timer_time(lua_State *L) | ||||
| { | ||||
|   Timer *timer = luaT_checkudata(L, 1, torch_Timer_id); | ||||
|   double realtime = (timer->isRunning ? (timer->totalrealtime + torch_Timer_realtime() - timer->startrealtime) : timer->totalrealtime); | ||||
|   double usertime = (timer->isRunning ? (timer->totalusertime + torch_Timer_usertime() - timer->startusertime) : timer->totalusertime); | ||||
|   double systime = (timer->isRunning ? (timer->totalsystime + torch_Timer_systime() - timer->startsystime) : timer->totalsystime); | ||||
|   lua_createtable(L, 0, 3); | ||||
|   lua_pushnumber(L, realtime); | ||||
|   lua_setfield(L, -2, "real"); | ||||
|   lua_pushnumber(L, usertime); | ||||
|   lua_setfield(L, -2, "user"); | ||||
|   lua_pushnumber(L, systime); | ||||
|   lua_setfield(L, -2, "sys"); | ||||
|   return 1; | ||||
| } | ||||
|  | ||||
| static int torch_Timer___tostring__(lua_State *L) | ||||
| { | ||||
|   Timer *timer = luaT_checkudata(L, 1, torch_Timer_id); | ||||
|   lua_pushfstring(L, "torch.Timer [status: %s]", (timer->isRunning ? "running" : "stopped")); | ||||
|   return 1; | ||||
| } | ||||
|  | ||||
| static const struct luaL_Reg torch_Timer__ [] = { | ||||
|   {"reset", torch_Timer_reset}, | ||||
|   {"stop", torch_Timer_stop}, | ||||
|   {"resume", torch_Timer_resume}, | ||||
|   {"time", torch_Timer_time}, | ||||
|   {"__tostring__", torch_Timer___tostring__}, | ||||
|   {NULL, NULL} | ||||
| }; | ||||
|  | ||||
| void torch_Timer_init(lua_State *L) | ||||
| { | ||||
|   torch_Timer_id = luaT_newmetatable(L, "torch.Timer", NULL, torch_Timer_new, torch_Timer_free, NULL); | ||||
|   luaL_register(L, NULL, torch_Timer__); | ||||
|   lua_pop(L, 1); | ||||
| } | ||||
							
								
								
									
										115
									
								
								dok/cmdline.dok
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										115
									
								
								dok/cmdline.dok
									
									
									
									
									
										Normal file
									
								
							| @ -0,0 +1,115 @@ | ||||
| ======  CmdLine ====== | ||||
| {{anchor:torch.CmdLine.dok}} | ||||
|  | ||||
| This class provides a parameter parsing framework which is very | ||||
| usefull when one needs to run several experiments that rely on | ||||
| different parameter settings that are passed in the command line. | ||||
| This class will also override the default print function to direct | ||||
| all the output to a log file as well as screen at the same time. | ||||
|  | ||||
| A sample ''lua'' file is given below that makes use of ''CmdLine'' | ||||
| class. | ||||
|  | ||||
| <file lua> | ||||
|  | ||||
| cmd = torch.CmdLine() | ||||
| cmd:text() | ||||
| cmd:text() | ||||
| cmd:text('Training a simple network') | ||||
| cmd:text() | ||||
| cmd:text('Options') | ||||
| cmd:option('-seed',123,'initial random seed') | ||||
| cmd:option('-booloption',false,'boolean option') | ||||
| cmd:option('-stroption','mystring','string option') | ||||
| cmd:text() | ||||
|  | ||||
| -- parse input params | ||||
| params = cmd:pard(arg) | ||||
|  | ||||
| params.rundir = cmd:string('experiment', params, {dir=true}) | ||||
|  | ||||
| -- create log file | ||||
| cmd:log(params.rundir .. '/log', params) | ||||
|  | ||||
| </file> | ||||
|  | ||||
| When this file is run on the lua commandline as follows | ||||
| <file shell> | ||||
| # lua myscript.lua | ||||
| </file> | ||||
|  | ||||
| It will produce the following output: | ||||
|  | ||||
| <file> | ||||
| [program started on Tue Jan 10 15:33:49 2012] | ||||
| [command line arguments] | ||||
| booloption	false | ||||
| seed	123 | ||||
| rundir	experiment | ||||
| stroption	mystring | ||||
| [----------------------] | ||||
| booloption	false | ||||
| seed	123 | ||||
| rundir	experiment | ||||
| stroption	mystring | ||||
| </file> | ||||
|  | ||||
| The same output will also be written to file | ||||
| ''experiment/log''. Whenever one of the options are passed on the | ||||
| command line and is different than the default value, the ''rundir'' | ||||
| is name is produced to reflect the parameter setting. | ||||
|  | ||||
| <file shell> | ||||
| # lua myscript.lua -seed 456 -stroption mycustomstring | ||||
| </file> | ||||
|  | ||||
| This will produce the following output: | ||||
|  | ||||
| <file> | ||||
| [program started on Tue Jan 10 15:36:55 2012] | ||||
| [command line arguments] | ||||
| booloption	false | ||||
| seed	456 | ||||
| rundir	experiment,seed=456,stroption=mycustomstring | ||||
| stroption	mycustomstring | ||||
| [----------------------] | ||||
| booloption	false | ||||
| seed	456 | ||||
| rundir	experiment,seed=456,stroption=mycustomstring | ||||
| stroption	mycustomstring | ||||
| </file> | ||||
|  | ||||
| and the output will be logged in | ||||
| ''experiment,seed=456,stroption=mycustomstring/log'' | ||||
|  | ||||
| ==== text(string) ==== | ||||
| {{anchor:torch.CmdLine.text}} | ||||
| Logs a custom text message. | ||||
|  | ||||
| ==== option(name, default, help) ==== | ||||
| {{anchor:torch.CmdLine.option}} | ||||
| Stores an option argument. The name should always start with '-'. | ||||
|  | ||||
| ==== [table] parse(arg) ==== | ||||
| {{anchor:torch.CmdLine.parse}} | ||||
| Parses a given table, ''arg'' is by default the argument table that  | ||||
| is created by ''lua'' using the command line arguments passed to the  | ||||
| executable. Returns a table of option values. | ||||
|  | ||||
| ==== [string] string(prefix, params, ignore) ==== | ||||
| {{anchor:torch.CmdLine.string}} | ||||
|  | ||||
| Returns a string representation of the options by concatenating the | ||||
| non-default options. ''ignore'' is a table ''{dir=true}'', which will | ||||
| ensure that option named ''dir'' will be ignored while creating the | ||||
| string representation. | ||||
|  | ||||
| This function is usefull for creating unique experiment directories that | ||||
| depend on the parameter settings. | ||||
|  | ||||
| ==== log(filename, parameter_table) ==== | ||||
| {{anchor:torch.CmdLine.log}} | ||||
|  | ||||
| It set the log filename to ''filename'' and prints the values of | ||||
| parameters in the ''parameter_table''. | ||||
|  | ||||
							
								
								
									
										64
									
								
								dok/diskfile.dok
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										64
									
								
								dok/diskfile.dok
									
									
									
									
									
										Normal file
									
								
							| @ -0,0 +1,64 @@ | ||||
| ====== DiskFile ====== | ||||
| {{anchor:torch.DiskFile.dok}} | ||||
|  | ||||
| Parent classes: [[File|File]] | ||||
|  | ||||
| A ''DiskFile'' is a particular ''File'' which is able to perform basic read/write operations | ||||
| on a file stored on disk. It implements all methods described in [[File|File]], and | ||||
| some additional methods relative to //endian// encoding. | ||||
|  | ||||
| By default, a ''DiskFile'' is in [[File#torch.File.binary|ASCII]] mode. If changed to | ||||
| the [[File#torch.File.binary|binary]] mode, the default endian encoding is the native | ||||
| computer one. | ||||
|  | ||||
| The file might be open in read, write, or read-write mode, depending on the parameter | ||||
| ''mode'' (which can take the value ''"r"'', ''"w"'' or ''"rw"'' respectively)  | ||||
| given to the [[#torch.DiskFile|torch.DiskFile(fileName, mode)]]. | ||||
|  | ||||
| =====  torch.DiskFile(fileName, [mode], [quiet]) ===== | ||||
| {{anchor:torch.DiskFile}} | ||||
|  | ||||
| //Constructor// which opens ''fileName'' on disk, using the given ''mode''. Valid ''mode'' are | ||||
| ''"r"'' (read), ''"w"'' (write) or ''"rw"'' (read-write). Default is read mode. | ||||
|  | ||||
| If read-write mode, the file //will be created// if it does not exists. If it | ||||
| exists, it will be positionned at the beginning of the file after opening. | ||||
|  | ||||
| If (and only if) ''quiet'' is ''true'', no error will be raised in case of | ||||
| problem opening the file: instead ''nil'' will be returned. | ||||
|  | ||||
| The file is opened in [[File#torch.File.ascii|ASCII]] mode by default. | ||||
|  | ||||
| =====  bigEndianEncoding() ===== | ||||
| {{anchor:torch.DiskFile.bigEndianEncoding}} | ||||
|  | ||||
| In [[file#torch.File.binary|binary]] mode, force encoding in //big endian//.  | ||||
| (//big end first//: decreasing numeric significance with increasing memory | ||||
| addresses) | ||||
|  | ||||
| =====  [boolean] isBigEndianCPU() ===== | ||||
| {{anchor:torch.DiskFile.isBigEndianCPU}} | ||||
|  | ||||
| Returns ''true'' if, and only if, the computer CPU operates in //big endian//. | ||||
| //Big end first//: decreasing numeric significance with increasing | ||||
| memory addresses. | ||||
|  | ||||
| =====  [boolean] isLittleEndianCPU() ===== | ||||
| {{anchor:torch.DiskFile.isLittleEndianCPU}} | ||||
|  | ||||
| Returns ''true'' if, and only if, the computer CPU operates in //little endian//. | ||||
| //Little end first//: increasing numeric significance with increasing | ||||
| memory addresses. | ||||
|  | ||||
| =====  littleEndianEncoding() ===== | ||||
| {{anchor:torch.DiskFile.littleEndianEncoding}} | ||||
|  | ||||
| In [[file#torch.File.binary|binary]] mode, force encoding in //little endian//. | ||||
| (//little end first//: increasing numeric significance with increasing memory | ||||
| addresses) | ||||
|  | ||||
| =====  nativeEndianEncoding() ===== | ||||
| {{anchor:torch.DiskFile.nativeEndianEncoding}} | ||||
|  | ||||
| In [[file#torch.File.binary|binary]] mode, force encoding in //native endian//. | ||||
|  | ||||
							
								
								
									
										333
									
								
								dok/file.dok
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										333
									
								
								dok/file.dok
									
									
									
									
									
										Normal file
									
								
							| @ -0,0 +1,333 @@ | ||||
| ====== File ====== | ||||
| {{anchor:torch.File.dok}} | ||||
|  | ||||
| This is an //abstract// class. It defines most methods implemented by its | ||||
| child classes, like [[DiskFile|DiskFile]], | ||||
| [[MemoryFile|MemoryFile]] and [[PipeFile|PipeFile]]. | ||||
|  | ||||
| Methods defined here are intended for basic read/write functionalities. | ||||
| Read/write methods might write in [[#torch.File.ascii|ASCII]] mode or | ||||
| [[#torch.File.binary|binary]] mode. | ||||
|   | ||||
| In [[#torch.File.ascii|ASCII]] mode, numbers are converted in human readable | ||||
| format (characters). Booleans are converted into ''0'' (false) or ''1'' (true). | ||||
| In [[#torch.File.binary|binary]] mode, numbers and boolean are directly encoded | ||||
| as represented in a register of the computer. While not being human | ||||
| readable and less portable, the binary mode is obviously faster. | ||||
|  | ||||
| In [[#torch.File.ascii|ASCII]] mode, if the default option | ||||
| [[#torch.File.autoSpacing|autoSpacing()]] is chosen, a space will be generated | ||||
| after each written number or boolean. A carriage return will also be added | ||||
| after each call to a write method. With this option, the spaces are | ||||
| supposed to exist while reading. This option can be deactivated with | ||||
| [[#torch.File.noAutoSpacing|noAutoSpacing()]]. | ||||
|  | ||||
| A ''Lua'' error might or might be not generated in case of read/write error | ||||
| or problem in the file. This depends on the choice made between | ||||
| [[#torch.File.quiet|quiet()]] and [[#torch.File.pedantic|pedantic()]] options. It | ||||
| is possible to query if an error occured in the last operation by calling | ||||
| [[#torch.File.hasError|hasError()]]. | ||||
|  | ||||
| =====  Read methods ===== | ||||
| {{anchor:torch.File.read}} | ||||
| {{anchor:torch.File.readBool}} | ||||
| {{anchor:torch.File.readByte}} | ||||
| {{anchor:torch.File.readChar}} | ||||
| {{anchor:torch.File.readShort}} | ||||
| {{anchor:torch.File.readInt}} | ||||
| {{anchor:torch.File.readLong}} | ||||
| {{anchor:torch.File.readFloat}} | ||||
| {{anchor:torch.File.readDouble}} | ||||
|  | ||||
| They are three types of reading methods: | ||||
|   - ''[number] readTYPE()'' | ||||
|   - ''[TYPEStorage] readTYPE(n)'' | ||||
|   - ''[number] readTYPE(TYPEStorage)'' | ||||
|  | ||||
| where ''TYPE'' can be either ''Byte'', ''Char'', ''Short'', ''Int'', ''Long'', ''Float'' or ''Double''. | ||||
|  | ||||
| A convenience method also exist for boolean types: ''[boolean] readBool()''. It reads | ||||
| a value on the file with ''readInt()'' and returns ''true'' if and only if this value is ''1''. It is not possible | ||||
| to read storages of booleans. | ||||
|  | ||||
| All these methods depends on the encoding choice: [[#torch.File.ascii|ASCII]] | ||||
| or [[#torch.File.binary|binary]] mode.  In [[#torch.File.ascii|ASCII]] mode, the | ||||
| option [[#torch.File.autoSpacing|autoSpacing()]] and | ||||
| [[#torch.File.noAutoSpacing|noAutoSpacing()]] have also an effect on these | ||||
| methods. | ||||
|  | ||||
| If no parameter is given, one element is returned. This element is | ||||
| converted to a ''Lua'' number when reading. | ||||
|  | ||||
| If ''n'' is given, ''n'' values of the specified type are read | ||||
| and returned in a new [[Storage|Storage]] of that particular type. | ||||
| The storage size corresponds to the number of elements actually read. | ||||
|  | ||||
| If a ''Storage'' is given, the method will attempt to read a number of elements | ||||
| equals to the size of the given storage, and fill up the storage with these elements. | ||||
| The number of elements actually read is returned. | ||||
|  | ||||
| In case of read error, these methods will call the ''Lua'' error function using the default | ||||
| [[#torch.File.pedantic|pedantic]] option, or stay quiet with the [[#torch.File.quiet|quiet]] | ||||
| option. In the latter case, one can check if an error occurred with | ||||
| [[#torch.File.hasError|hasError()]]. | ||||
|  | ||||
| =====  Write methods ===== | ||||
| {{anchor:torch.File.write}} | ||||
| {{anchor:torch.File.writeBool}} | ||||
| {{anchor:torch.File.writeByte}} | ||||
| {{anchor:torch.File.writeChar}} | ||||
| {{anchor:torch.File.writeShort}} | ||||
| {{anchor:torch.File.writeInt}} | ||||
| {{anchor:torch.File.writeLong}} | ||||
| {{anchor:torch.File.writeFloat}} | ||||
| {{anchor:torch.File.writeDouble}} | ||||
|  | ||||
| They are two types of reading methods: | ||||
|   - ''[number] writeTYPE(number)'' | ||||
|   - ''[number] writeTYPE(TYPEStorage)'' | ||||
|  | ||||
| where ''TYPE'' can be either ''Byte'', ''Char'', ''Short'', ''Int'', ''Long'', ''Float'' or ''Double''. | ||||
|  | ||||
| A convenience method also exist for boolean types: ''writeBool(value)''. If ''value'' is ''nil'' or | ||||
| not ''true'' a it is equivalent to a ''writeInt(0)'' call, else to ''writeInt(1)''. It is not possible | ||||
| to write storages of booleans. | ||||
|  | ||||
| All these methods depends on the encoding choice: [[#torch.File.ascii|ASCII]] | ||||
| or [[#torch.File.ascii|binary]] mode.  In [[#torch.File.ascii|ASCII]] mode, the | ||||
| option [[#torch.File.autoSpacing|autoSpacing()]] and | ||||
| [[#torch.File.noAutoSpacing|noAutoSpacing()]] have also an effect on these | ||||
| methods. | ||||
|  | ||||
| If one ''Lua'' number is given, this number is converted according to the | ||||
| name of the method when writing (e.g. ''writeInt(3.14)'' will write ''3''). | ||||
|  | ||||
| If a ''Storage'' is given, the method will attempt to write all the elements contained | ||||
| in the storage. | ||||
|  | ||||
| These methods return the number of elements actually written. | ||||
|  | ||||
| In case of read error, these methods will call the ''Lua'' error function using the default | ||||
| [[#torch.File.pedantic|pedantic]] option, or stay quiet with the [[#torch.File.quiet|quiet]] | ||||
| option. In the latter case, one can check if an error occurred with | ||||
| [[#torch.File.hasError|hasError()]]. | ||||
|  | ||||
| =====  Serialization methods ===== | ||||
| {{anchor:torch.File.serialization}} | ||||
|  | ||||
| These methods allow the user to save any serializable objects on disk and | ||||
| reload it later in its original state. In other words, it can perform a | ||||
| //deep// copy of an object into a given ''File''. | ||||
|  | ||||
| Serializable objects are ''Torch'' objects having a ''read()'' and | ||||
| ''write()'' method. ''Lua'' objects such as ''table'', ''number'' or | ||||
| ''string'' or //pure Lua// functions are also serializable. | ||||
|  | ||||
| If the object to save contains several other objects (let say it is a tree | ||||
| of objects), then objects appearing several times in this tree will be | ||||
| //saved only once//. This saves disk space, speedup loading/saving and | ||||
| respect the dependencies between objects. | ||||
|  | ||||
| Interestingly, if the ''File'' is a [[MemoryFile|MemoryFile]], it allows | ||||
| the user to easily make a //clone// of any serializable object: | ||||
| <file lua> | ||||
| file = torch.MemoryFile() -- creates a file in memory | ||||
| file:writeObject(object) -- writes the object into file | ||||
| file:seek(1) -- comes back at the beginning of the file | ||||
| objectClone = file:readObject() -- gets a clone of object | ||||
| </file> | ||||
|  | ||||
| ====  readObject() ==== | ||||
| {{anchor:torch.File.readObject}} | ||||
|  | ||||
| Returns the next [[#torch.File.serialization|serializable]] object saved beforehand | ||||
| in the file with [[#torch.File.writeObject|writeObject()]]. | ||||
|  | ||||
| Note that objects which were [[#torch.File.writeObject|written]] with the same | ||||
| reference have still the same reference after loading. | ||||
|  | ||||
| Example: | ||||
| <file lua> | ||||
| -- creates an array which contains twice the same tensor   | ||||
| array = {} | ||||
| x = torch.Tensor(1) | ||||
| table.insert(array, x) | ||||
| table.insert(array, x) | ||||
|  | ||||
| -- array[1] and array[2] refer to the same address | ||||
| -- x[1] == array[1][1] == array[2][1] == 3.14 | ||||
| array[1][1] = 3.14 | ||||
|  | ||||
| -- write the array on disk | ||||
| file = torch.DiskFile('foo.asc', 'w') | ||||
| file:writeObject(array) | ||||
| file:close() -- make sure the data is written | ||||
|  | ||||
| -- reload the array | ||||
| file = torch.DiskFile('foo.asc', 'r') | ||||
| arrayNew = file:readObject() | ||||
|  | ||||
| -- arrayNew[1] and arrayNew[2] refer to the same address! | ||||
| -- arrayNew[1][1] == arrayNew[2][1] == 3.14 | ||||
| -- so if we do now: | ||||
| arrayNew[1][1] = 2.72 | ||||
| -- arrayNew[1][1] == arrayNew[2][1] == 2.72 ! | ||||
| </file> | ||||
|  | ||||
| ====  writeObject(object) ==== | ||||
| {{anchor:torch.File.writeObject}} | ||||
|  | ||||
| Writes ''object'' into the file. This object can be read later using | ||||
| [[#torch.File.readObject|readObject()]]. Serializable objects are ''Torch'' | ||||
| objects having a ''read()'' and ''write()'' method. ''Lua'' objects such as | ||||
| ''table'', ''number'' or ''string'' or pure Lua functions are also serializable. | ||||
|  | ||||
| If the object has been already written in the file, only a //reference// to | ||||
| this already saved object will be written: this saves space an speed-up | ||||
| writing; it also allows to keep the dependencies between objects intact. | ||||
|  | ||||
| In returns, if one writes an object, modify its member, and write the | ||||
| object again in the same file, the modifications will not be recorded | ||||
| in the file, as only a reference to the original will be written. See | ||||
| [[#torch.File.readObject|readObject()]] for an example. | ||||
|  | ||||
| ====  [string] readString(format) ==== | ||||
| {{anchor:torch.File.readString}} | ||||
|  | ||||
| If ''format'' starts with ''"*l"'' then returns the next line in the ''File''. The end-of-line character is skipped. | ||||
|  | ||||
| If ''format'' starts with ''"*a"'' then returns all the remaining contents of the ''File''. | ||||
|  | ||||
| If no data is available, then an error is raised, except if ''File'' is in [[#torch.File.quiet|quiet()]] mode where | ||||
| it then returns ''nil''. | ||||
|  | ||||
| Because Torch is more precised on number typing, the ''Lua'' format ''"*n"'' is not supported: | ||||
| instead use one of the [[#torch.File.read|number read methods]]. | ||||
|  | ||||
| ====  [number] writeString(str) ==== | ||||
| {{anchor:torch.File.writeString}} | ||||
|  | ||||
| Writes the string ''str'' in the ''File''. If the string cannot be written completely an error is raised, except | ||||
| if ''File'' is in [[#torch.File.quiet|quiet()]] mode where it returns the number of character actually written. | ||||
|  | ||||
| =====  ascii() [default] ===== | ||||
| {{anchor:torch.File.ascii}} | ||||
|  | ||||
| The data read or written will be in ''ASCII'' mode: all numbers are converted | ||||
| to characters (human readable format) and boolean are converted to ''0'' | ||||
| (false) or ''1'' (true). The input-output format in this mode depends on the | ||||
| options [[#torch.File.autoSpacing|autoSpacing()]] and | ||||
| [[#torch.File.noAutoSpacing|noAutoSpacing()]]. | ||||
|  | ||||
| =====  autoSpacing() [default] ===== | ||||
| {{anchor:torch.File.autoSpacing}} | ||||
|  | ||||
| In [[#torch.File.ascii|ASCII]] mode, write additional spaces around the elements | ||||
| written on disk: if writing a [[Storage|Storage]], a space will be | ||||
| generated between each //element// and a //return line// after the last | ||||
| element. If only writing one element, a //return line// will be generated | ||||
| after this element. | ||||
|  | ||||
| Those spaces are supposed to exist while reading in this mode. | ||||
|  | ||||
| This is the default behavior. You can de-activate this option with the | ||||
| [[#torch.File.noAutoSpacing|noAutoSpacing()]] method. | ||||
|  | ||||
| =====  binary() ===== | ||||
| {{anchor:torch.File.binary}} | ||||
|  | ||||
| The data read or written will be in binary mode: the representation in the | ||||
| ''File'' is the same that the one in the computer memory/register (not human | ||||
| readable).  This mode is faster than [[#torch.File.ascii|ASCII]] but less | ||||
| portable. | ||||
|  | ||||
| =====  clearError() ===== | ||||
| {{anchor:torch.File.clearError}} | ||||
|  | ||||
| Clear the error.flag returned by [[#torch.File.hasError|hasError()]]. | ||||
|  | ||||
| =====  close() ===== | ||||
| {{anchor:torch.File.close}} | ||||
|  | ||||
| Close the file. Any subsequent operation will generate a ''Lua'' error. | ||||
|  | ||||
| =====  noAutoSpacing() ===== | ||||
| {{anchor:torch.File.noAutoSpacing}} | ||||
|  | ||||
| In [[#torch.File.ascii|ASCII]] mode, do not put extra spaces between element | ||||
| written on disk. This is the contrary of the option | ||||
| [[#torch.File.autoSpacing|autoSpacing()]]. | ||||
|  | ||||
| =====  synchronize() ===== | ||||
| {{anchor:torch.File.synchronize}} | ||||
|  | ||||
| If the child class bufferize the data while writing, ensure that the data | ||||
| is actually written. | ||||
|  | ||||
|  | ||||
| =====  pedantic() [default] ===== | ||||
| {{anchor:torch.File.pedantic}} | ||||
|  | ||||
| If this mode is chosen (which is the default), a ''Lua'' error will be | ||||
| generated in case of error (which will cause the program to stop). | ||||
|  | ||||
| It is possible to use [[#torch.File.quiet|quiet()]] to avoid ''Lua'' error generation | ||||
| and set a flag instead. | ||||
|  | ||||
| =====  [number] position() ===== | ||||
| {{anchor:torch.File.position}} | ||||
|  | ||||
| Returns the current position (in bytes) in the file. | ||||
| The first position is ''1'' (following Lua standard indexing). | ||||
|  | ||||
| =====  quiet() ===== | ||||
| {{anchor:torch.File.quiet}} | ||||
|  | ||||
| If this mode is chosen instead of [[#torch.File.pedantic|pedantic()]], no ''Lua'' | ||||
| error will be generated in case of read/write error. Instead, a flag will | ||||
| be raised, readable through [[#torch.File.hasError|hasError()]]. This flag can | ||||
| be cleared with [[#torch.File.clearError|clearError()]] | ||||
|  | ||||
| Checking if a file is quiet can be performed using [[#torch.File.isQuiet|isQuiet()]]. | ||||
|  | ||||
| =====  seek(position) ===== | ||||
| {{anchor:torch.File.seek}} | ||||
|  | ||||
| Jump into the file at the given ''position'' (in byte). Might generate/raise | ||||
| an error in case of problem. The first position is ''1'' (following Lua standard indexing). | ||||
|  | ||||
| =====  seekEnd() ===== | ||||
| {{anchor:torch.File.seekEnd}} | ||||
|  | ||||
| Jump at the end of the file. Might generate/raise an error in case of | ||||
| problem. | ||||
|  | ||||
| ===== File state query ===== | ||||
|  | ||||
| These methods allow the user to query the state of the given ''File''. | ||||
|  | ||||
| ====  [boolean] hasError() ==== | ||||
| {{anchor:torch.File.hasError}} | ||||
|  | ||||
| Returns if an error occurred since the last [[#torch.File.clearError|clearError()]] call, or since | ||||
| the opening of the file if ''clearError()'' has never been called. | ||||
|  | ||||
| ====  [boolean] isQuiet() ==== | ||||
| {{anchor:torch.File.isQuiet}} | ||||
|  | ||||
| Returns a boolean which tells if the file is in [[#torch.File.quiet|quiet]] mode or not. | ||||
|  | ||||
| ====  [boolean] isReadable() ==== | ||||
| {{anchor:torch.File.isReadable}} | ||||
|  | ||||
| Tells if one can read the file or not. | ||||
|  | ||||
| ====  [boolean] isWritable() ==== | ||||
| {{anchor:torch.File.isWritable}} | ||||
|  | ||||
| Tells if one can write in the file or not. | ||||
|  | ||||
| ==== [boolean] isAutoSpacing() ==== | ||||
| {{anchor:torch.File.isAutoSpacing}} | ||||
|  | ||||
| Return ''true'' if [[#torch.File.autoSpacing|autoSpacing]] has been chosen. | ||||
							
								
								
									
										39
									
								
								dok/index.dok
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										39
									
								
								dok/index.dok
									
									
									
									
									
										Normal file
									
								
							| @ -0,0 +1,39 @@ | ||||
| ====== Torch Package Reference Manual ====== | ||||
| {{anchor:torch.reference.dok}} | ||||
|  | ||||
| The **''torch''** package contains basic classes used everywhere in ''Torch7''. | ||||
|  | ||||
| //Input-output management// is provided with [[File|File]] (abstract class), [[DiskFile|DiskFile]] (file on disk), | ||||
| [[MemoryFile|MemoryFile]] (file in ''RAM'') and [[PipeFile|PipeFile]] (file from a piped command). These | ||||
| classes also handle //serialization//. | ||||
|  | ||||
| [[Storage|Storage]] and [[Tensor|Tensor]] are the basic bricks for //powerful numeric operations//. Tensors support | ||||
| a wide variety of fundamental [[maths|math operations]]. | ||||
|  | ||||
| [[Timer|Timer]] is provided for //measuring time//. | ||||
|  | ||||
| [[Tester|Tester]] is provided as a generic testing framework and it is also used by [[..:nn:index|nn]] package. | ||||
|  | ||||
| [[CmdLine|CmdLine]] is provided as a command line argument parsing utility. | ||||
|  | ||||
| Finally, ''Torch'' provides some [[Utility|utility functions]] for creating and handling ''Torch'' //classes//, | ||||
| as well as support for [[random|random number generation]]. | ||||
|  | ||||
| ===== Torch Packages ===== | ||||
| {{anchor:torch.reference.dok}} | ||||
|  | ||||
|   * File I/O Interface Library | ||||
|     * [[File|File]] is an abstract interface for common file operations. | ||||
|     * [[DiskFile|Disk File]] defines operations on files stored on disk. | ||||
|     * [[MemoryFile|Memory File]] defines operations on stored in RAM. | ||||
|     * [[PipeFile|Pipe File]] defines operations for using piped commands. | ||||
|   * Tensor Library | ||||
|     * [[Storage|Storage]] defines a simple storage interface that controls the underlying storage for any tensor object. | ||||
|     * [[Tensor|Tensor]] defines the //all powerful// tensor object that defines multi-dimensional numerical arrays with type templating. | ||||
|     * [[maths|Mathemetical operations]] are defined for the tensor object types. | ||||
|   * Useful Utilities | ||||
|     * [[Timer|Timer]] provides functionality for //measuring time//. | ||||
|     * [[Tester|Tester]] is a generic tester framework. | ||||
|     * [[CmdLine|CmdLine]] is a command line argument parsing utility. | ||||
|     * [[Random|Random]] defines a random number generator package with various distributions. | ||||
|     * Finally useful [[Utility|utility] functions are provided for easy handling of torch tensor types and class inheritance. | ||||
							
								
								
									
										804
									
								
								dok/maths.dok
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										804
									
								
								dok/maths.dok
									
									
									
									
									
										Normal file
									
								
							| @ -0,0 +1,804 @@ | ||||
| ====== Math Functions ====== | ||||
| {{anchor:torch.maths.dok}} | ||||
|  | ||||
| Torch provides Matlab-like functions for manipulating | ||||
| [[index#Tensor|Tensor]] objects.  Functions fall into several types of | ||||
| categories: | ||||
|   * [[#torch.construction.dok|constructors]] like [[#torch.zeros|zeros]], [[#torch.ones|ones]] | ||||
|   * extractors like  [[#torch.diag|diag]]  and [[#torch.triu|triu]], | ||||
|   * [[#torch.elementwise.dok|element-wise]] operations like [[#torch.abs|abs]] and [[#torch.pow|pow]], | ||||
|   * [[#torch.columnwise.dok|column or row-wise operations]] like [[#torch.sum|sum]] and [[#torch.max|max]], | ||||
|   * [[#torch.matrixwide.dok|matrix-wide operations]] like [[#torch.trace|trace]] and [[#torch.norm|norm]]. | ||||
|   * [[#torch.conv.dok|Convolution and cross-correlation]] operations like [[#torch.conv2|conv2]]. | ||||
|   * [[#torch.linalg.dok|Basic linear algebra operations]] like [[#torch.eig|eigen value/vector calculation]], [[#torch.svd|singular value decomposition (svd)]] and [[#torch.gesv|linear system solution]]. | ||||
|  | ||||
| By default, all operations allocate a new tensor to return the | ||||
| result. However, all functions also support passing the resulting(s) | ||||
| tensor(s) as the first argument(s), in which case the resulting tensor(s) | ||||
| will be resized accordingly and filled with result. | ||||
|  | ||||
| For example, ''torch.conv2'' function can be used in the following manner. | ||||
| <file lua> | ||||
|  | ||||
| x = torch.rand(100,100) | ||||
| k = torch.rand(10,10) | ||||
| res1 = torch.conv2(x,k) | ||||
|  | ||||
| res2 = torch.Tensor() | ||||
| torch.conv2(res2,x,k) | ||||
|  | ||||
| =res2:dist(res1) | ||||
| 0 | ||||
|  | ||||
| </file> | ||||
|  | ||||
| The advantage of second case is, same ''res2'' tensor can be used successively in a loop without any new allocation. | ||||
|  | ||||
| <file lua> | ||||
| -- no new memory allocations... | ||||
| for i=1,100 do | ||||
|     torch.conv2(res2,x,k) | ||||
| end | ||||
| =res2:dist(res1) | ||||
| 0 | ||||
| </file> | ||||
|  | ||||
| ======  Construction or extraction functions ====== | ||||
| {{anchor:torch.construction.dok}} | ||||
|  | ||||
| =====  torch.cat( [res,] x_1, x_2, [dimension] )       ===== | ||||
| {{anchor:torch.cat}} | ||||
|  | ||||
| ''x=torch.cat(x_1,x_2,[dimension])'' returns a tensor ''x'' which is the concatenation of tensors x_1 and x_2 along dimension ''dimension''.  | ||||
|  | ||||
| If ''dimension'' is not specified it is 1. | ||||
|  | ||||
| The other dimensions of x_1 and x_2 have to be equal. | ||||
|  | ||||
| Examples: | ||||
| <file lua> | ||||
| > print(torch.cat(torch.ones(3),torch.zeros(2))) | ||||
|  | ||||
|  1 | ||||
|  1 | ||||
|  1 | ||||
|  0 | ||||
|  0 | ||||
| [torch.Tensor of dimension 5] | ||||
|  | ||||
|  | ||||
| > print(torch.cat(torch.ones(3,2),torch.zeros(2,2))) | ||||
|  | ||||
|  1  1 | ||||
|  1  1 | ||||
|  1  1 | ||||
|  0  0 | ||||
|  0  0 | ||||
| [torch.DoubleTensor of dimension 5x2] | ||||
|  | ||||
|  | ||||
| > print(torch.cat(torch.ones(2,2),torch.zeros(2,2))) | ||||
|  1  1 | ||||
|  1  1 | ||||
|  0  0 | ||||
|  0  0 | ||||
| [torch.DoubleTensor of dimension 4x2] | ||||
|  | ||||
| > print(torch.cat(torch.ones(2,2),torch.zeros(2,2),2)) | ||||
|  1  1  0  0 | ||||
|  1  1  0  0 | ||||
| [torch.DoubleTensor of dimension 2x4] | ||||
|  | ||||
|  | ||||
| > print(torch.cat(torch.cat(torch.ones(2,2),torch.zeros(2,2)),torch.rand(3,2))) | ||||
|  | ||||
|  1.0000  1.0000 | ||||
|  1.0000  1.0000 | ||||
|  0.0000  0.0000 | ||||
|  0.0000  0.0000 | ||||
|  0.3227  0.0493 | ||||
|  0.9161  0.1086 | ||||
|  0.2206  0.7449 | ||||
| [torch.DoubleTensor of dimension 7x2] | ||||
|  | ||||
| </file> | ||||
|  | ||||
|  | ||||
| =====  torch.diag( [res,] x)       ===== | ||||
| {{anchor:torch.diag}} | ||||
|  | ||||
| ''y=torch.diag(x)'' when x is of dimension 1 returns a diagonal matrix with diagonal elements constructed from x. | ||||
|  | ||||
| ''y=torch.diag(x)'' when x is of dimension 2 returns a tensor of dimension 1 | ||||
| with elements constructed from the diagonal of x. | ||||
|  | ||||
| ''y=torch.diag(x,k)'' returns the k-th diagonal of x, | ||||
| wher k=0 is the main diagonal, k>0 is above the main diagonal and k<0  | ||||
| is below the main diagonal. | ||||
|  | ||||
| =====  torch.eye( [res,] n)         ===== | ||||
| {{anchor:torch.eye}} | ||||
|  | ||||
| ''y=torch.eye(n)'' returns the n-by-n identity matrix. | ||||
|  | ||||
| ''y=torch.eye(m,n)'' returns an m-by-n identity matrix with ones on the diagonal and zeros elsewhere. | ||||
|  | ||||
|  | ||||
| =====  torch.linspace( [res,] x1,x2)     ===== | ||||
| {{anchor:torch.linspace}} | ||||
|  | ||||
| ''y=torch.linspace(x1,x2)'' returns a one-dimensional tensor of size 100 equally spaced points between x1 and x2. | ||||
|  | ||||
| ''y=torch.linspace(x1,x2,n)'' returns a one-dimensional tensor of n equally spaced points between x1 and x2. | ||||
|  | ||||
|  | ||||
| =====  torch.logspace( [res,] x1, x2)    ===== | ||||
| {{anchor:torch.logspace}} | ||||
|  | ||||
| ''y=torch.logspace(x1,x2)'' returns a one-dimensional tensor of 50 logarithmically eqally spaced points between x1 and x2. | ||||
|  | ||||
| ''y=torch.logspace(x1,x2,n)'' returns a one-dimensional tensor of n logarithmically equally spaced points between x1 and x2. | ||||
|  | ||||
| =====  torch.ones( [res,] m)  ===== | ||||
| {{anchor:torch.ones}} | ||||
|  | ||||
| ''y=torch.ones(n)'' returns a one-dimensional tensor of size n filled with ones. | ||||
|  | ||||
| ''y=torch.ones(m,n)'' returns a mxn tensor filled with ones. | ||||
|  | ||||
| ''y=torch.ones(m,n,k)'' returns a mxnxk tensor filled with ones. | ||||
|  | ||||
| ''y=torch.ones(d1,...,d_n)'' returns an n-dimensional tensor with sizes d1, ..., d_n filled with ones. | ||||
|  | ||||
| =====  torch.rand( [res,] m [, n, k, ...])        ===== | ||||
| {{anchor:torch.rand}} | ||||
|  | ||||
| ''y=torch.rand(n)'' returns a one-dimensional tensor of size n filled with random numbers from a uniform distribution on the interval (0,1). | ||||
|  | ||||
| ''y=torch.rand(m,n)'' returns a mxn tensor of random numbers from a uniform distribution on the interval (0,1). | ||||
|  | ||||
| =====  torch.randn( [res,] m [, n, k, ...])       ===== | ||||
| {{anchor:torch.randn}} | ||||
|  | ||||
| ''y=torch.randn(n)'' returns a one-dimensional tensor of size n filled with random numbers from a normal distribution with mean zero and variance one. | ||||
|  | ||||
| ''y=torch.randn(m,n)'' returns a mxn tensor of random numbers from a normal distribution with mean zero and variance one. | ||||
|  | ||||
| =====  torch.range([res,] n,m)       ===== | ||||
| {{anchor:torch.range}} | ||||
|  | ||||
| ''y=torch.range(n,m)'' returns a tensor of size m-n+1x1 with integer  | ||||
| values n to m. | ||||
|  | ||||
| <file lua> | ||||
| > print(torch.range(2,5)) | ||||
|  | ||||
|  2 | ||||
|  3 | ||||
|  4 | ||||
|  5 | ||||
| [torch.Tensor of dimension 4] | ||||
| </file> | ||||
|  | ||||
| ''y=torch.range(n,m,incr)'' returns a tensor filled in range n to m with incr increments. | ||||
| <file lua> | ||||
| print(torch.range(2,5,1.2)) | ||||
|  2.0000 | ||||
|  3.2000 | ||||
|  4.4000 | ||||
| [torch.DoubleTensor of dimension 3] | ||||
| </file> | ||||
|  | ||||
| =====  torch.randperm([res,] n)      ===== | ||||
| {{anchor:torch.randperm}} | ||||
|  | ||||
| ''y=torch.randperm(n)'' returns a randomly ordered nx1 tensor of the integers from 1 to n. | ||||
|  | ||||
| =====  torch.reshape([res,] x,m,n)     ===== | ||||
| {{anchor:torch.reshape}} | ||||
|  | ||||
| ''y=torch.reshape(x,m,n)'' returns a new mxn tensor y whose elements | ||||
| are taken rowwise from x, which must have m*n elements. The elements are copied into the new tensor. | ||||
|  | ||||
| =====  torch.tril([res,] x) ===== | ||||
| {{anchor:torch.tril}} | ||||
|  | ||||
| ''y=torch.tril(x)'' returns the lower triangular part of x, the other elements of y are set to 0. | ||||
|  | ||||
| ''torch.tril(x,k)'' returns the elements on and below the k-th diagonal of x as non-zero.   k=0 is the main diagonal, k>0 is above the main diagonal and k<0  | ||||
| is below the main diagonal. | ||||
|  | ||||
| =====  torch.triu([res,] x) ===== | ||||
| {{anchor:torch.triu}} | ||||
|  | ||||
| ''y=torch.triu(x)'' returns the upper triangular part of x, | ||||
| the other elements of y are set to 0. | ||||
|  | ||||
| ''torch.triu(x,k)'' returns the elements on and above the k-th diagonal of x as non-zero.   k=0 is the main diagonal, k>0 is above the main diagonal and k<0  | ||||
| is below the main diagonal. | ||||
|  | ||||
| =====  torch.zeros([res,] x) ===== | ||||
| {{anchor:torch.zeros}} | ||||
|  | ||||
| ''y=torch.zeros(n)'' returns a one-dimensional tensor of size n filled with zeros. | ||||
|  | ||||
| ''y=torch.zeros(m,n)'' returns a mxn tensor filled with zeros. | ||||
|  | ||||
|  | ||||
| ======  Element-wise operations  ====== | ||||
| {{anchor:torch.elementwise.dok}} | ||||
|  | ||||
| =====  torch.abs([res,] x) ===== | ||||
| {{anchor:torch.abs}} | ||||
|  | ||||
| ''y=torch.abs(x)'' returns the absolute values of the elements of x. | ||||
|  | ||||
| =====  torch.acos([res,] x) ===== | ||||
| {{anchor:torch.acos}} | ||||
|  | ||||
| ''y=torch.acos(x)'' returns the arcosine of the elements of x. | ||||
|  | ||||
| =====  torch.asin([res,] x)       ===== | ||||
| {{anchor:torch.asin}} | ||||
|  | ||||
| ''y=torch.asin(x)'' returns the arcsine  of the elements of x. | ||||
|  | ||||
| =====  torch.atan([res,] x)       ===== | ||||
| {{anchor:torch.atan}} | ||||
|  | ||||
| ''y=torch.atan(x)'' returns the arctangent of the elements of x. | ||||
|  | ||||
| =====  torch.ceil([res,] x)       ===== | ||||
| {{anchor:torch.ceil}} | ||||
|  | ||||
| ''y=torch.ceil(x)'' returns the values of the elements of x rounded up to the nearest integers. | ||||
|  | ||||
| =====  torch.cos([res,] x)        ===== | ||||
| {{anchor:torch.cos}} | ||||
|  | ||||
| ''y=torch.cos(x)'' returns the cosine of the elements of x. | ||||
|  | ||||
| =====  torch.cosh([res,] x)       ===== | ||||
| {{anchor:torch.cosh}} | ||||
|  | ||||
| ''y=torch.cosh(x)'' returns the hyberbolic cosine of the elements of x. | ||||
|  | ||||
| =====  torch.exp[res,] (x) ===== | ||||
| {{anchor:torch.exp}} | ||||
|  | ||||
| ''y=torch.exp(x)'' returns, for each element in x,  e (the base of natural logarithms) raised to the power of the element in x. | ||||
|  | ||||
| =====  torch.floor([res,] x) ===== | ||||
| {{anchor:torch.floor}} | ||||
|  | ||||
| ''y=torch.floor(x)'' returns the values of the elements of x rounded down to the nearest integers. | ||||
|  | ||||
| =====  torch.log[res,] (x)         ===== | ||||
| {{anchor:torch.log}} | ||||
|  | ||||
| ''y=torch.log(x)'' returns the natural logarithm of the elements of x. | ||||
|  | ||||
| =====  torch.pow([res,] x)         ===== | ||||
| {{anchor:torch.pow}} | ||||
|  | ||||
| ''y=torch.pow(x,n)'' returns the elements of x to the power of n. | ||||
|  | ||||
| =====  torch.sin([res,] x)         ===== | ||||
| {{anchor:torch.sin}} | ||||
|  | ||||
| ''y=torch.sin(x)'' returns the sine  of the elements of x. | ||||
|  | ||||
| =====  torch.sinh([res,] x)        ===== | ||||
| {{anchor:torch.sinh}} | ||||
|  | ||||
| ''y=torch.sinh(x)'' returns the hyperbolic sine of the elements of x. | ||||
|  | ||||
| =====  torch.sqrt([res,] x) ===== | ||||
| {{anchor:torch.sqrt}} | ||||
|  | ||||
| ''y=torch.sqrt(x)'' returns the square root of the elements of x. | ||||
|  | ||||
| =====  torch.tan([res,] x) ===== | ||||
| {{anchor:torch.tan}} | ||||
|  | ||||
| ''y=torch.abs(x)'' returns the tangent of the elements of x. | ||||
|  | ||||
| =====  torch.tanh([res,] x) ===== | ||||
| {{anchor:torch.tanh}} | ||||
|  | ||||
| ''y=torch.tanh(x)'' returns the hyperbolic tangent of the elements of x. | ||||
|  | ||||
| ======  Column or row-wise operations  (dimension-wise operations) ====== | ||||
| {{anchor:torch.columnwise.dok}} | ||||
|  | ||||
| =====  torch.cross([res,] a,b)      ===== | ||||
| {{anchor:torch.cross}} | ||||
|  | ||||
| ''y=torch.cross(a,b)'' returns the cross product of the tensors a and b. | ||||
| a and b must be 3 element vectors.  | ||||
|  | ||||
| ''y=cross(a,b)'' returns the cross product of a and b along the first dimension of length 3. | ||||
|  | ||||
| ''y=cross(a,b,n)'', where a and b returns the cross | ||||
| product of vectors in dimension n of a and b.  | ||||
| a and b must have the same size,  | ||||
| and both a:size(n) and b:size(n) must be 3. | ||||
|  | ||||
|  | ||||
| =====  torch.cumprod([res,] x)    ===== | ||||
| {{anchor:torch.cumprod}} | ||||
|  | ||||
| ''y=torch.cumprod(x)'' returns the cumulative product of the elements of x, performing the operation over the last dimension. | ||||
|  | ||||
| ''y=torch.cumprod(x,n)'' returns the cumulative product of the elements of x, performing the operation over dimension n. | ||||
|  | ||||
| =====  torch.cumsum([res,] x)     ===== | ||||
| {{anchor:torch.cumsum}} | ||||
|  | ||||
| ''y=torch.cumsum(x)'' returns the cumulative product of the elements of x, performing the operation over the first dimension. | ||||
|  | ||||
| ''y=torch.cumsum(x,n)'' returns the cumulative product of the elements of x, performing the operation over dimension n. | ||||
|  | ||||
| =====  torch.max([resval, resind, ] x) ===== | ||||
| {{anchor:torch.max}} | ||||
|  | ||||
| ''y,i=torch.max(x)'' returns a tensor y of the largest element in  | ||||
| each row of x, and a tensor i of  their corresponding indices in x. | ||||
|  | ||||
| ''y,i=torch.max(x,1)'' performs the max operation for each row and | ||||
| ''y,i=torch.max(x,n)'' performs the max operation over the dimension n. | ||||
|  | ||||
|  | ||||
| =====  torch.mean([res,] x) ===== | ||||
| {{anchor:torch.mean}} | ||||
|  | ||||
| ''y=torch.mean(x)'' returns a tensor y of the mean of the elements in  | ||||
| each row of x. | ||||
|  | ||||
| ''y=torch.mean(x,2)'' performs the mean operation for each row and | ||||
| ''y=torch.mean(x,n)'' performs the mean operation over the dimension n. | ||||
|  | ||||
| =====  torch.min([resval, resind, ] x) ===== | ||||
| {{anchor:torch.min}} | ||||
|  | ||||
| ''y,i=torch.min(x)'' returns a tensor y of the smallest element in  | ||||
| each row of x, and a tensor i of  their corresponding indices in x. | ||||
|  | ||||
| ''y,i=torch.min(x,2)'' performs the min operation for each row and | ||||
| ''y,i=torch.min(x,n)'' performs the min operation over the dimension n. | ||||
|  | ||||
|  | ||||
| =====  torch.prod([res,] x)        ===== | ||||
| {{anchor:torch.prod}} | ||||
|  | ||||
| ''y=torch.prod(x)'' returns a tensor y of the product of the elements in  | ||||
| each row of x.  | ||||
|  | ||||
| ''y=torch.prod(x,2)'' performs the prod operation for each row and | ||||
| ''y=torch.prod(x,n)'' performs the prod operation over the dimension n. | ||||
|  | ||||
| =====  torch.sort([resval, resind, ] x) ===== | ||||
| {{anchor:torch.sort}} | ||||
|  | ||||
| ''y,i=torch.sort(x)'' returns a tensor y of the sorted  | ||||
| rows of x, and a tensor i of the corresponding indices from x. | ||||
|  | ||||
| ''y,i=torch.sort(x,2)'' performs the sort operation for each row and | ||||
| ''y,i=torch.sort(x,n)'' performs the sort operation over the dimension n. | ||||
|  | ||||
| =====  torch.std([res,] x) ===== | ||||
| {{anchor:torch.std}} | ||||
|  | ||||
| ''y=torch.std(x)'' returns a tensor y of the standard deviation of the elements in  | ||||
| each row of x. | ||||
|  | ||||
| ''torch.std(x)'' normalizes by (n-1) where n is the number of elements.  This | ||||
| makes torch.sum(torch.pow(torch.std(x),2))  | ||||
| the best unbiased estimate of the variance if x | ||||
| is a sample from a normal distribution. | ||||
|  | ||||
| ''y=torch.std(x,true)'' performs the std operation normalizing by n instead of n-1. | ||||
|  | ||||
| ''y=torch.std(x,false)'' performs the std operation normalizing by n-1. | ||||
|  | ||||
| ''y=torch.std(x,flag,n)'' performs the std operation over the dimension n. | ||||
|  | ||||
|  | ||||
| =====  torch.sum([res,] x) ===== | ||||
| {{anchor:torch.sum}} | ||||
|  | ||||
| ''y=torch.sum(x)'' returns a tensor y of the sum of the elements in  | ||||
| each row of x. | ||||
|  | ||||
| ''y=torch.sum(x,2)'' performs the sum operation for each row and | ||||
| ''y=torch.sum(x,n)'' performs the sum operation over the dimension n. | ||||
|  | ||||
| =====  torch.var([res,] x) ===== | ||||
| {{anchor:torch.var}} | ||||
|  | ||||
| ''y=torch.var(x)'' returns a tensor y of the standard deviation of the elements in  | ||||
| each row of x. | ||||
|  | ||||
| ''torch.var(x)'' normalizes by (n-1) where n is the number of elements.  This | ||||
| makes torch.sum(torch.var(x))  | ||||
| the best unbiased estimate of the variance if x | ||||
| is a sample from a normal distribution. | ||||
|  | ||||
| ''y=torch.var(x,true)'' performs the var operation normalizing by n instead of n-1. | ||||
|  | ||||
| ''y=torch.var(x,false)'' performs the var operation normalizing by n-1. | ||||
|  | ||||
| ''y=torch.var(x,flag,n)'' performs the var operation over the dimension n. | ||||
|  | ||||
| ======  Matrix-wide operations  (tensor-wide operations) ====== | ||||
| {{anchor:torch.matrixwide.dok}} | ||||
|  | ||||
| =====  torch.norm(x)        ===== | ||||
| {{anchor:torch.norm}} | ||||
|  | ||||
| ''y=torch.norm(x)'' returns the 2-norm of the tensor x.  | ||||
|  | ||||
| ''y=torch.norm(x,p)'' returns the p-norm of the tensor x.  | ||||
|  | ||||
|  | ||||
| =====  torch.dist(x,y)        ===== | ||||
| {{anchor:torch.dist}} | ||||
|  | ||||
| ''y=torch.dist(x,y)'' returns the 2-norm of (x-y).  | ||||
|  | ||||
| ''y=torch.dist(x,y,p)'' returns the p-norm of (x-y).  | ||||
|  | ||||
| =====  torch.numel(x)      ===== | ||||
| {{anchor:torch.numel}} | ||||
|  | ||||
| ''y=torch.numel(x)'' returns the count of the number of elements in the matrix x. | ||||
|  | ||||
| =====  torch.trace(x) ===== | ||||
| {{anchor:torch.trace}} | ||||
|  | ||||
| ''y=torch.trace(x)'' returns the trace (sum of the diagonal elements)  | ||||
| of a matrix x. This is  equal  to the sum of the eigenvalues of x. | ||||
| The returned value ''y'' is a number, not a tensor. | ||||
|  | ||||
| ====== Convolution Operations ====== | ||||
| {{anchor:torch.conv.dok}} | ||||
|  | ||||
| These function implement convolution or cross-correlation of an input | ||||
| image (or set of input images) with a kernel (or set of kernels). The | ||||
| convolution function in Torch can handle different types of | ||||
| input/kernel dimensions and produces corresponding outputs. The | ||||
| general form of operations always remain the same. | ||||
|  | ||||
| ===== torch.conv2([res,] x, k, ['f' or 'v']) ===== | ||||
| {{anchor:torch.conv2}} | ||||
|  | ||||
| This function computes 2 dimensional convolutions between '' x '' and '' k ''. These operations are similar to BLAS operations when number of dimensions of input and kernel are reduced by 2. | ||||
|  | ||||
|   * '' x ''  and '' k '' are 2D : convolution of a single image with a single kernel (2D output). This operation is similar to multiplication of two scalars. | ||||
|   * '' x ''  and '' k '' are 3D : convolution of each input slice with corresponding kernel (3D output). | ||||
|   * '' x (p x m x n) '' 3D, '' k (q x p x ki x kj)'' 4D : convolution of all input slices with the corresponding slice of kernel. Output is 3D '' (q x m x n) ''. This operation is similar to matrix vector product of matrix '' k '' and vector '' x ''. | ||||
|  | ||||
| The last argument controls if the convolution is a full ('f') or valid ('v') convolution. The default is 'valid' convolution. | ||||
|  | ||||
| <file lua> | ||||
| x=torch.rand(100,100) | ||||
| k=torch.rand(10,10) | ||||
| c = torch.conv2(x,k) | ||||
| =c:size() | ||||
|  | ||||
|  91 | ||||
|  91 | ||||
| [torch.LongStorage of size 2] | ||||
|  | ||||
| c = torch.conv2(x,k,'f') | ||||
| =c:size() | ||||
|  | ||||
|  109 | ||||
|  109 | ||||
| [torch.LongStorage of size 2] | ||||
|  | ||||
| </file> | ||||
|  | ||||
| ===== torch.xcorr2([res,] x, k, ['f' or 'v']) ===== | ||||
| {{anchor:torch.xcorr2}} | ||||
|  | ||||
| This function operates with same options and input/output | ||||
| configurations as [[#torch.conv2|torch.conv2]], but performs | ||||
| cross-correlation of the input with the kernel '' k ''. | ||||
|  | ||||
| ===== torch.conv3([res,] x, k, ['f' or 'v']) ===== | ||||
| {{anchor:torch.conv3}} | ||||
|  | ||||
| This function computes 3 dimensional convolutions between '' x '' and '' k ''. These operations are similar to BLAS operations when number of dimensions of input and kernel are reduced by 3. | ||||
|  | ||||
|   * '' x ''  and '' k '' are 3D : convolution of a single image with a single kernel (3D output). This operation is similar to multiplication of two scalars. | ||||
|   * '' x ''  and '' k '' are 4D : convolution of each input slice with corresponding kernel (4D output). | ||||
|   * '' x (p x m x n x o) '' 4D, '' k (q x p x ki x kj x kk)'' 5D : convolution of all input slices with the corresponding slice of kernel. Output is 4D '' (q x m x n x o) ''. This operation is similar to matrix vector product of matrix '' k '' and vector '' x ''. | ||||
|  | ||||
| The last argument controls if the convolution is a full ('f') or valid ('v') convolution. The default is 'valid' convolution. | ||||
|  | ||||
| <file lua> | ||||
| x=torch.rand(100,100,100) | ||||
| k=torch.rand(10,10,10) | ||||
| c = torch.conv3(x,k) | ||||
| =c:size() | ||||
|  | ||||
|  91 | ||||
|  91 | ||||
|  91 | ||||
| [torch.LongStorage of size 3] | ||||
|  | ||||
| c = torch.conv3(x,k,'f') | ||||
| =c:size() | ||||
|  | ||||
|  109 | ||||
|  109 | ||||
|  109 | ||||
| [torch.LongStorage of size 3] | ||||
|  | ||||
| </file> | ||||
|  | ||||
| ===== torch.xcorr3([res,] x, k, ['f' or 'v']) ===== | ||||
| {{anchor:torch.xcorr3}} | ||||
|  | ||||
| This function operates with same options and input/output | ||||
| configurations as [[#torch.conv3|torch.conv3]], but performs | ||||
| cross-correlation of the input with the kernel '' k ''. | ||||
|  | ||||
| ====== Eigenvalues, SVD, Linear System Solution ====== | ||||
| {{anchor:torch.linalg.dok}} | ||||
|  | ||||
| Functions in this section are implemented with an interface to LAPACK | ||||
| libraries. If LAPACK libraries are not found during compilation step, | ||||
| then these functions will not be available. | ||||
|  | ||||
| ===== torch.gesv([resb, resa,] b,a [, true]) ===== | ||||
| {{anchor:torch.gesv}} | ||||
|  | ||||
| Solution of '' AX=B '' and ''A'' has to be square and non-singular. '' | ||||
| A '' is '' m x m '', '' X '' is '' m x k '', '' B '' is '' m x k ''. | ||||
|  | ||||
| If ''resb'' and ''resa'' are given, then they will be used for | ||||
| temporary storage and returning the result. | ||||
|  | ||||
|   * ''resa'' will contain L and U factors for ''LU'' factorization of ''A''. | ||||
|   * ''resb'' will contain the solution. | ||||
|  | ||||
| If ''gesv'' is called with 3 parameters with last parameters ''true'', | ||||
| then ''b'' and ''a'' will destroyed and their output values will be | ||||
| same as ''resa'' and ''resb''. | ||||
|  | ||||
| <file lua> | ||||
| a=torch.Tensor({{6.80, -2.11,  5.66,  5.97,  8.23}, | ||||
|                 {-6.05, -3.30,  5.36, -4.44,  1.08}, | ||||
|                 {-0.45,  2.58, -2.70,  0.27,  9.04}, | ||||
|                 {8.32,  2.71,  4.35,  -7.17,  2.14}, | ||||
|                 {-9.67, -5.14, -7.26,  6.08, -6.87}}):t() | ||||
|  | ||||
| b=torch.Tensor({{4.02,  6.19, -8.22, -7.57, -3.03}, | ||||
|                 {-1.56,  4.00, -8.67,  1.75,  2.86}, | ||||
|                 {9.81, -4.09, -4.57, -8.61,  8.99}}):t() | ||||
|  | ||||
|  =b | ||||
|  4.0200 -1.5600  9.8100 | ||||
|  6.1900  4.0000 -4.0900 | ||||
| -8.2200 -8.6700 -4.5700 | ||||
| -7.5700  1.7500 -8.6100 | ||||
| -3.0300  2.8600  8.9900 | ||||
| [torch.DoubleTensor of dimension 5x3] | ||||
|  | ||||
| =a | ||||
|  6.8000 -6.0500 -0.4500  8.3200 -9.6700 | ||||
| -2.1100 -3.3000  2.5800  2.7100 -5.1400 | ||||
|  5.6600  5.3600 -2.7000  4.3500 -7.2600 | ||||
|  5.9700 -4.4400  0.2700 -7.1700  6.0800 | ||||
|  8.2300  1.0800  9.0400  2.1400 -6.8700 | ||||
| [torch.DoubleTensor of dimension 5x5] | ||||
|  | ||||
|  | ||||
| x=torch.gesv(b,a) | ||||
|  =x | ||||
| -0.8007 -0.3896  0.9555 | ||||
| -0.6952 -0.5544  0.2207 | ||||
|  0.5939  0.8422  1.9006 | ||||
|  1.3217 -0.1038  5.3577 | ||||
|  0.5658  0.1057  4.0406 | ||||
| [torch.DoubleTensor of dimension 5x3] | ||||
|  | ||||
| =b:dist(a*x) | ||||
| 1.1682163181673e-14 | ||||
|  | ||||
| </file> | ||||
|  | ||||
| ===== torch.gels([resb, resa,] b,a) ===== | ||||
| {{anchor:torch.gels}} | ||||
|  | ||||
| Solution of least squares and least norm  problems for a full rank '' A '' that is '' m x n''. | ||||
|   * If '' n %%<=%% m '', then solve '' ||AX-B||_F ''. | ||||
|   * If '' n > m '' , then solve '' min ||X||_F s.t. AX=B ''. | ||||
|  | ||||
| On return, first '' n '' rows of '' X '' matrix contains the solution | ||||
| and the rest contains residual information. Square root of sum squares | ||||
| of elements of each column of '' X '' starting at row '' n + 1 '' is | ||||
| the residual for corresponding column. | ||||
|  | ||||
| <file lua> | ||||
|  | ||||
| a=torch.Tensor({{ 1.44, -9.96, -7.55,  8.34,  7.08, -5.45}, | ||||
|                 {-7.84, -0.28,  3.24,  8.09,  2.52, -5.70}, | ||||
|                 {-4.39, -3.24,  6.27,  5.28,  0.74, -1.19}, | ||||
|                 {4.53,  3.83, -6.64,  2.06, -2.47,  4.70}}):t() | ||||
|  | ||||
| b=torch.Tensor({{8.58,  8.26,  8.48, -5.28,  5.72,  8.93}, | ||||
|                 {9.35, -4.43, -0.70, -0.26, -7.36, -2.52}}):t() | ||||
|  | ||||
| =a | ||||
|  1.4400 -7.8400 -4.3900  4.5300 | ||||
| -9.9600 -0.2800 -3.2400  3.8300 | ||||
| -7.5500  3.2400  6.2700 -6.6400 | ||||
|  8.3400  8.0900  5.2800  2.0600 | ||||
|  7.0800  2.5200  0.7400 -2.4700 | ||||
| -5.4500 -5.7000 -1.1900  4.7000 | ||||
| [torch.DoubleTensor of dimension 6x4] | ||||
|  | ||||
| =b | ||||
|  8.5800  9.3500 | ||||
|  8.2600 -4.4300 | ||||
|  8.4800 -0.7000 | ||||
| -5.2800 -0.2600 | ||||
|  5.7200 -7.3600 | ||||
|  8.9300 -2.5200 | ||||
| [torch.DoubleTensor of dimension 6x2] | ||||
|  | ||||
| x = torch.gels(a,b) | ||||
| =x  | ||||
|  -0.4506   0.2497  | ||||
|  -0.8492  -0.9020 | ||||
|   0.7066   0.6323 | ||||
|   0.1289   0.1351 | ||||
|  13.1193  -7.4922 | ||||
|  -4.8214  -7.1361 | ||||
| [torch.DoubleTensor of dimension 6x2] | ||||
|  | ||||
| =b:dist(a*x:narrow(1,1,4)) | ||||
| 17.390200628863 | ||||
|  | ||||
| =math.sqrt(x:narrow(1,5,2):pow(2):sumall()) | ||||
| 17.390200628863 | ||||
|  | ||||
| </file> | ||||
|  | ||||
| ===== torch.eig([rese, resv,] a, [, 'n' or 'v']) ===== | ||||
| {{anchor:torch.eig}} | ||||
|  | ||||
| Eigen values and eigen vectors of a symmetric real matrix '' A '' of | ||||
| size '' m x m ''. This function calculates all eigenvalues (and | ||||
| vectors) of '' A '' such that '' A = V' diag(e) V ''. Since the input | ||||
| matrix '' A '' is supposed to be symmetric, only upper triangular | ||||
| portion is used. | ||||
|  | ||||
| Last argument defines computation of eigenvectors or eigenvalues | ||||
| only. If '' n '', only eignevalues are computed. If '' v '', both | ||||
| eigenvalues and eigenvectors are computed. | ||||
|  | ||||
| <file lua> | ||||
|  | ||||
| a=torch.Tensor({{ 1.96,  0.00,  0.00,  0.00,  0.00}, | ||||
|                 {-6.49,  3.80,  0.00,  0.00,  0.00}, | ||||
|                 {-0.47, -6.39,  4.17,  0.00,  0.00}, | ||||
| 		{-7.20,  1.50, -1.51,  5.70,  0.00}, | ||||
| 		{-0.65, -6.34,  2.67,  1.80, -7.10}}):t() | ||||
|  | ||||
| =a | ||||
|  1.9600 -6.4900 -0.4700 -7.2000 -0.6500 | ||||
|  0.0000  3.8000 -6.3900  1.5000 -6.3400 | ||||
|  0.0000  0.0000  4.1700 -1.5100  2.6700 | ||||
|  0.0000  0.0000  0.0000  5.7000  1.8000 | ||||
|  0.0000  0.0000  0.0000  0.0000 -7.1000 | ||||
| [torch.DoubleTensor of dimension 5x5] | ||||
|  | ||||
| e = torch.eig(a) | ||||
| =e | ||||
| -11.0656 | ||||
|  -6.2287 | ||||
|   0.8640 | ||||
|   8.8655 | ||||
|  16.0948 | ||||
| [torch.DoubleTensor of dimension 5] | ||||
|  | ||||
| e,v = torch.eig(a,'v') | ||||
| =e | ||||
| -11.0656 | ||||
|  -6.2287 | ||||
|   0.8640 | ||||
|   8.8655 | ||||
|  16.0948 | ||||
| [torch.DoubleTensor of dimension 5] | ||||
|  | ||||
| =v | ||||
| -0.2981 -0.6075  0.4026 -0.3745  0.4896 | ||||
| -0.5078 -0.2880 -0.4066 -0.3572 -0.6053 | ||||
| -0.0816 -0.3843 -0.6600  0.5008  0.3991 | ||||
| -0.0036 -0.4467  0.4553  0.6204 -0.4564 | ||||
| -0.8041  0.4480  0.1725  0.3108  0.1622 | ||||
| [torch.DoubleTensor of dimension 5x5] | ||||
|  | ||||
| =v*torch.diag(e)*v:t() | ||||
|  1.9600 -6.4900 -0.4700 -7.2000 -0.6500 | ||||
| -6.4900  3.8000 -6.3900  1.5000 -6.3400 | ||||
| -0.4700 -6.3900  4.1700 -1.5100  2.6700 | ||||
| -7.2000  1.5000 -1.5100  5.7000  1.8000 | ||||
| -0.6500 -6.3400  2.6700  1.8000 -7.1000 | ||||
| [torch.DoubleTensor of dimension 5x5] | ||||
|  | ||||
| =a:dist(torch.triu(v*torch.diag(e)*v:t())) | ||||
| 1.0219480822443e-14 | ||||
|  | ||||
| </file> | ||||
|  | ||||
| ===== torch.svd([resu, ress, resv] a, [, 's' or 'a']) ===== | ||||
| {{anchor:torch.svd}} | ||||
|  | ||||
| Singular value decomposition of a real matrix '' A '' of size '' n x m | ||||
| '' such that '' A = USV**T ''. The call to ''svd'' returns ''U,S,VT''. | ||||
|  | ||||
| The last argument, if it is string, represents the number of singular | ||||
| values to be computed. 's' stands for 'some' and 'a' stands for 'all'. | ||||
|  | ||||
|  | ||||
| <file lua> | ||||
|  | ||||
| a=torch.Tensor({{8.79,  6.11, -9.15,  9.57, -3.49,  9.84}, | ||||
| 		{9.93,  6.91, -7.93,  1.64,  4.02,  0.15}, | ||||
| 		{9.83,  5.04,  4.86,  8.83,  9.80, -8.99}, | ||||
| 		{5.45, -0.27,  4.85,  0.74, 10.00, -6.02}, | ||||
| 		{3.16,  7.98,  3.01,  5.80,  4.27, -5.31}}):t() | ||||
| =a | ||||
|   8.7900   9.9300   9.8300   5.4500   3.1600 | ||||
|   6.1100   6.9100   5.0400  -0.2700   7.9800 | ||||
|  -9.1500  -7.9300   4.8600   4.8500   3.0100 | ||||
|   9.5700   1.6400   8.8300   0.7400   5.8000 | ||||
|  -3.4900   4.0200   9.8000  10.0000   4.2700 | ||||
|   9.8400   0.1500  -8.9900  -6.0200  -5.3100 | ||||
|  | ||||
| u,s,v = torch.svd(a) | ||||
|  | ||||
| =u | ||||
| -0.5911  0.2632  0.3554  0.3143  0.2299 | ||||
| -0.3976  0.2438 -0.2224 -0.7535 -0.3636 | ||||
| -0.0335 -0.6003 -0.4508  0.2334 -0.3055 | ||||
| -0.4297  0.2362 -0.6859  0.3319  0.1649 | ||||
| -0.4697 -0.3509  0.3874  0.1587 -0.5183 | ||||
|  0.2934  0.5763 -0.0209  0.3791 -0.6526 | ||||
| [torch.DoubleTensor of dimension 6x5] | ||||
|  | ||||
| =s | ||||
|  27.4687 | ||||
|  22.6432 | ||||
|   8.5584 | ||||
|   5.9857 | ||||
|   2.0149 | ||||
| [torch.DoubleTensor of dimension 5] | ||||
|  | ||||
| =v | ||||
| -0.2514 -0.3968 -0.6922 -0.3662 -0.4076 | ||||
|  0.8148  0.3587 -0.2489 -0.3686 -0.0980 | ||||
| -0.2606  0.7008 -0.2208  0.3859 -0.4933 | ||||
|  0.3967 -0.4507  0.2513  0.4342 -0.6227 | ||||
| -0.2180  0.1402  0.5891 -0.6265 -0.4396 | ||||
| [torch.DoubleTensor of dimension 5x5] | ||||
|  | ||||
| =u*torch.diag(s)*v | ||||
|   8.7900   9.9300   9.8300   5.4500   3.1600 | ||||
|   6.1100   6.9100   5.0400  -0.2700   7.9800 | ||||
|  -9.1500  -7.9300   4.8600   4.8500   3.0100 | ||||
|   9.5700   1.6400   8.8300   0.7400   5.8000 | ||||
|  -3.4900   4.0200   9.8000  10.0000   4.2700 | ||||
|   9.8400   0.1500  -8.9900  -6.0200  -5.3100 | ||||
| [torch.DoubleTensor of dimension 6x5] | ||||
|  | ||||
|  =a:dist(u*torch.diag(s)*v) | ||||
| 2.8923773593204e-14 | ||||
|  | ||||
| </file> | ||||
|  | ||||
							
								
								
									
										36
									
								
								dok/memoryfile.dok
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										36
									
								
								dok/memoryfile.dok
									
									
									
									
									
										Normal file
									
								
							| @ -0,0 +1,36 @@ | ||||
| ======  MemoryFile ====== | ||||
| {{anchor:torch.MemoryFile.dok}} | ||||
|  | ||||
| Parent classes: [[File|File]] | ||||
|  | ||||
| A ''MemoryFile'' is a particular ''File'' which is able to perform basic | ||||
| read/write operations on a buffer in ''RAM''. It implements all methods | ||||
| described in [[File|File]]. | ||||
|  | ||||
| The data of the this ''File'' is contained into a ''NULL'' terminated | ||||
| [[Storage|CharStorage]]. | ||||
|  | ||||
| =====  torch.MemoryFile([mode]) ===== | ||||
| {{anchor:torch.MemoryFile}} | ||||
|  | ||||
| //Constructor// which returns a new ''MemoryFile'' object using ''mode''. Valid | ||||
| ''mode'' are ''"r"'' (read), ''"w"'' (write) or ''"rw"'' (read-write). Default is ''"rw"''. | ||||
|  | ||||
|  | ||||
| =====  torch.MemoryFile(storage, mode) ===== | ||||
| {{anchor:torch.MemoryFile}} | ||||
|  | ||||
| //Constructor// which returns a new ''MemoryFile'' object, using the given | ||||
| [[Storage|storage]] (which must be a ''CharStorage'') and ''mode''. Valid | ||||
| ''mode'' are ''"r"'' (read), ''"w"'' (write) or ''"rw"'' (read-write). The last character | ||||
| in this storage //must// be ''NULL'' or an error will be generated. This allow | ||||
| to read existing memory. If used for writing, not that the ''storage'' might | ||||
| be resized by this class if needed.  | ||||
|  | ||||
| =====  [CharStorage] storage() ===== | ||||
| {{anchor:torch.MemoryFile.storage}} | ||||
|  | ||||
| Returns the [[Storage|storage]] which contains all the data of the | ||||
| ''File'' (note: this is //not// a copy, but a //reference// on this storage). The | ||||
| size of the storage is the size of the data in the ''File'', plus one, the | ||||
| last character being ''NULL''. | ||||
							
								
								
									
										21
									
								
								dok/pipefile.dok
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										21
									
								
								dok/pipefile.dok
									
									
									
									
									
										Normal file
									
								
							| @ -0,0 +1,21 @@ | ||||
| ======  PipeFile ====== | ||||
| {{anchor:torch.PipeFile.dok}} | ||||
|  | ||||
| Parent classes: [[DiskFile|DiskFile]] | ||||
|  | ||||
| A ''PipeFile'' is a particular ''File'' which is able to perform basic read/write operations | ||||
| on a command pipe. It implements all methods described in [[DiskFile|DiskFile]] and [[File|File]]. | ||||
|  | ||||
| The file might be open in read or write mode, depending on the parameter | ||||
| ''mode'' (which can take the value ''"r"'' or ''"w"'')  | ||||
| given to the [[#torch.PipeFile|torch.PipeFile(fileName, mode)]]. Read-write mode is not allowed. | ||||
|  | ||||
| =====  torch.PipeFile(command, [mode], [quiet]) ===== | ||||
| {{anchor:torch.PipeFile}} | ||||
|  | ||||
| //Constructor// which execute ''command'' by opening a pipe in read or write | ||||
| ''mode''. Valid ''mode'' are ''"r"'' (read) or ''"w"'' (write). Default is read | ||||
| mode. | ||||
|  | ||||
| If (and only if) ''quiet'' is ''true'', no error will be raised in case of | ||||
| problem opening the file: instead ''nil'' will be returned. | ||||
							
								
								
									
										105
									
								
								dok/random.dok
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										105
									
								
								dok/random.dok
									
									
									
									
									
										Normal file
									
								
							| @ -0,0 +1,105 @@ | ||||
| ====== Random Numbers ====== | ||||
| {{anchor:torch.random.dok}} | ||||
|  | ||||
| Torch provides accurate mathematical random generation, based on | ||||
| [[http://www.math.sci.hiroshima-u.ac.jp/~m-mat/MT/emt.html|Mersenne Twister]] | ||||
| random number generator. | ||||
|  | ||||
| ======  Seed Handling ====== | ||||
| {{anchor::torch.seed.dok}} | ||||
|  | ||||
| If no seed is provided to the random generator (using | ||||
| [[#torch.seed|seed()]] or [[#torch.manualSeed|manualSeed()]]), a | ||||
| random seed will be set according to [[#torch.seed|seed()]] the first | ||||
| time a random number is generated. | ||||
|  | ||||
| Initial seed can be obtained using [[#torch.initialSeed|initialSeed()]]. | ||||
|  | ||||
| Setting a particular seed allows the user to (re)-generate a particular serie of | ||||
| random numbers. Example: | ||||
| <file> | ||||
| > torch.manualSeed(123) | ||||
| > = torch.uniform() | ||||
| 0.69646918727085 | ||||
| > return  torch.uniform() | ||||
| 0.71295532141812 | ||||
| > return  torch.uniform() | ||||
| 0.28613933874294 | ||||
| > torch.manualSeed(123) | ||||
| > return  torch.uniform() | ||||
| 0.69646918727085 | ||||
| > return  torch.uniform() | ||||
| 0.71295532141812 | ||||
| > return  torch.uniform() | ||||
| 0.28613933874294 | ||||
| > torch.manualSeed(torch.initialSeed()) | ||||
| > return  torch.uniform() | ||||
| 0.69646918727085 | ||||
| > return  torch.uniform() | ||||
| 0.71295532141812 | ||||
| > return  torch.uniform() | ||||
| 0.28613933874294 | ||||
| </file> | ||||
|  | ||||
| =====  [number] seed() ===== | ||||
| {{anchor:torch.seed}} | ||||
|  | ||||
| Set the seed of the random number generator according to the time of the | ||||
| computer. Granularity is seconds. Returns the seed obtained. | ||||
|  | ||||
| =====  manualSeed(number) ===== | ||||
| {{anchor:torch.manualSeed}} | ||||
|  | ||||
| Set the seed of the random number generator to the given ''number''. | ||||
|  | ||||
| =====  initialSeed() ===== | ||||
| {{anchor:torch.initialSeed}} | ||||
|  | ||||
| Returns the initial seed used to initialize the random generator. | ||||
|  | ||||
| ======  [number] random() ====== | ||||
| {{anchor:torch.random}} | ||||
|  | ||||
| Returns a 32 bit integer random number. | ||||
|  | ||||
| ======  [number] uniform([a],[b]) ====== | ||||
| {{anchor:torch.uniform}} | ||||
|  | ||||
| Returns a random real number according to uniform distribution on [a,b[. By default ''a'' is 0 and ''b'' is 1. | ||||
|  | ||||
| ======  [number] normal([mean],[stdv]) ====== | ||||
| {{anchor:torch.normal}} | ||||
|  | ||||
| Returns a random real number according to a normal distribution with the given ''mean'' and standard deviation ''stdv''. | ||||
| ''stdv'' must be positive. | ||||
|  | ||||
| ======  [number] exponential(lambda) ====== | ||||
| {{anchor:torch.exponential}} | ||||
|  | ||||
| Returns a random real number according to the exponential distribution | ||||
| ''p(x) = lambda * exp(-lambda * x)'' | ||||
|  | ||||
| ======  [number] cauchy(median, sigma) ====== | ||||
| {{anchor:torch.cauchy}} | ||||
|  | ||||
| Returns a random real number according to the Cauchy distribution | ||||
| ''p(x) = sigma/(pi*(sigma^2 + (x-median)^2))'' | ||||
|  | ||||
| ======  [number] logNormal(mean, stdv) ====== | ||||
| {{anchor:torch.logNormal}} | ||||
|  | ||||
| Returns a random real number according to the log-normal distribution, with | ||||
| the given ''mean'' and standard deviation ''stdv''. | ||||
| ''stdv'' must be positive. | ||||
|  | ||||
| ======  [number] geometric(p) ====== | ||||
| {{anchor:torch.geometric}} | ||||
|  | ||||
| Returns a random integer number according to a geometric distribution | ||||
| ''p(i) = (1-p) * p^(i-1)''. ''p'' must satisfy ''0 < p < 1''. | ||||
|  | ||||
| ======  [number] bernouilli([p]) ====== | ||||
| {{anchor:torch.bernoulli}} | ||||
|  | ||||
| Returns ''1'' with probability ''p'' and ''0'' with probability ''1-p''. ''p'' must satisfy ''0 < p < 1''. | ||||
| By default ''p'' is equal to ''0.5''. | ||||
							
								
								
									
										222
									
								
								dok/storage.dok
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										222
									
								
								dok/storage.dok
									
									
									
									
									
										Normal file
									
								
							| @ -0,0 +1,222 @@ | ||||
| ======  Storage ====== | ||||
| {{anchor:torch.Storage.dok}} | ||||
| {{anchor:torch.ByteStorage.dok}} | ||||
| {{anchor:torch.CharStorage.dok}} | ||||
| {{anchor:torch.ShortStorage.dok}} | ||||
| {{anchor:torch.IntStorage.dok}} | ||||
| {{anchor:torch.LongStorage.dok}} | ||||
| {{anchor:torch.FloatStorage.dok}} | ||||
| {{anchor:torch.DoubleStorage.dok}} | ||||
|  | ||||
| //Storages// are basically a way for ''Lua'' to access memory of a ''C'' pointer | ||||
| or array. //Storages// can also [[#__torch.StorageMap|map the contents of a file to memory]]. | ||||
| A ''Storage'' is an array of //basic// ''C'' types. For arrays of ''Torch'' objects, | ||||
| use the ''Lua'' tables. | ||||
|  | ||||
| Several ''Storage'' classes for all the basic ''C'' types exist and have the | ||||
| following self-explanatory names: ''ByteStorage'', ''CharStorage'', ''ShortStorage'', | ||||
| ''IntStorage'', ''LongStorage'', ''FloatStorage'', ''DoubleStorage''. | ||||
|  | ||||
| Note that ''ByteStorage'' and ''CharStorage'' represent both arrays of bytes. ''ByteStorage'' represents an array of | ||||
| //unsigned// chars, while ''CharStorage'' represents an array of //signed// chars. | ||||
|  | ||||
| Conversions between two ''Storage'' type might be done using ''copy'': | ||||
| <file lua> | ||||
| x = torch.IntStorage(10):fill(1) | ||||
| y = torch.DoubleStorage(10):copy(x) | ||||
| </file> | ||||
|  | ||||
| [[#torch.Storage|Classical storages]] are [[File#torch.File.serialization|serializable]]. | ||||
| [[#__torch.StorageMap|Storages mapping a file]] are also [[#FileSerialization|serializable]], | ||||
| but //will be saved as a normal storage//. | ||||
|  | ||||
| An alias ''torch.Storage()'' is made over your preferred Storage type, | ||||
| controlled by the | ||||
| [[utility#torch.setdefaulttensortype|torch.setdefaulttensortype]] | ||||
| function. By default, this "points" on ''torch.DoubleStorage''. | ||||
|  | ||||
| =====  torch.TYPEStorage([size]) ===== | ||||
| {{anchor:torch.Storage}} | ||||
|  | ||||
| Returns a new ''Storage'' of type ''TYPE''. Valid ''TYPE'' are ''Byte'', ''Char'', ''Short'', | ||||
| ''Int'', ''Long'', ''Float'', and ''Double''. If ''size'' is given, resize the | ||||
| ''Storage'' accordingly, else create an empty ''Storage''. | ||||
|  | ||||
| Example: | ||||
| <file lua> | ||||
| -- Creates a Storage of 10 double: | ||||
| x = torch.DoubleStorage(10) | ||||
| </file> | ||||
|  | ||||
| The data in the ''Storage'' is //uninitialized//. | ||||
|  | ||||
| ===== torch.TYPEStorage(table) ===== | ||||
| {{anchor:torch.Storage}} | ||||
|  | ||||
| The argument is assumed to be a Lua array of numbers. The constructor returns a new storage of the specified 'TYPE',  | ||||
| of the size of the table, containing all the table elements converted | ||||
|  | ||||
| Example: | ||||
| <file lua> | ||||
| > = torch.IntStorage({1,2,3,4}) | ||||
|  | ||||
|  1 | ||||
|  2 | ||||
|  3 | ||||
|  4 | ||||
| [torch.IntStorage of size 4] | ||||
| </file> | ||||
|  | ||||
| =====  torch.TYPEStorage(filename [, shared]) ===== | ||||
| {{anchor:torch.Storage}} | ||||
| {{anchor:__torch.StorageMap}} | ||||
|  | ||||
| Returns a new kind of ''Storage'' which maps the contents of the given | ||||
| ''filename'' to memory. Valid ''TYPE'' are ''Byte'', ''Char'', ''Short'', ''Int'', ''Long'', | ||||
| ''Float'', and ''Double''. If the optional boolean argument ''shared'' is ''true'', | ||||
| the mapped memory is shared amongst all processes on the computer. | ||||
|  | ||||
| When ''shared'' is ''true'', the file must be accessible in read-write mode. Any | ||||
| changes on the storage will be written in the file. The changes might be written | ||||
| only after destruction of the storage. | ||||
|  | ||||
| When ''shared'' is ''false'' (or not provided), the file must be at least | ||||
| readable. Any changes on the storage will not affect the file. Note: | ||||
| changes made on the file after creation of the storage have an unspecified | ||||
| effect on the storage contents. | ||||
|  | ||||
| The [[#torch.Storage.size|size]] of the returned ''Storage'' will be | ||||
| <file lua> | ||||
| (size of file in byte)/(size of TYPE). | ||||
| </file> | ||||
|  | ||||
| Example: | ||||
| <file lua> | ||||
| $ echo "Hello World" > hello.txt | ||||
| $ lua | ||||
| Lua 5.1.3  Copyright (C) 1994-2008 Lua.org, PUC-Rio | ||||
| > require 'torch' | ||||
| > x = torch.CharStorage('hello.txt') | ||||
| > = x | ||||
|   72 | ||||
|  101 | ||||
|  108 | ||||
|  108 | ||||
|  111 | ||||
|   32 | ||||
|   87 | ||||
|  111 | ||||
|  114 | ||||
|  108 | ||||
|  100 | ||||
|   10 | ||||
| [torch.CharStorage of size 12] | ||||
|  | ||||
| > = x:string() | ||||
| Hello World | ||||
|  | ||||
| > = x:fill(42):string() | ||||
| ************ | ||||
| >  | ||||
| $ cat hello.txt  | ||||
| Hello World | ||||
| $ lua | ||||
| Lua 5.1.3  Copyright (C) 1994-2008 Lua.org, PUC-Rio | ||||
| > require 'torch' | ||||
| > x = torch.CharStorage('hello.txt', true) | ||||
| > = x:string() | ||||
| Hello World | ||||
|  | ||||
| > x:fill(42) | ||||
| > | ||||
| $ cat hello.txt  | ||||
| ************ | ||||
| </file> | ||||
|  | ||||
| =====  [number] #self ===== | ||||
| {{anchor:__torch.StorageSharp}} | ||||
|  | ||||
| Returns the number of elements in the storage. Equivalent to [[#torch.Storage.size|size()]]. | ||||
|  | ||||
| =====  [number] self[index] ===== | ||||
| {{anchor:torch.Storage.__index__}} | ||||
|  | ||||
| Returns or set the element at position ''index'' in the storage. Valid range | ||||
| of ''index'' is 1 to [[#torch.Storage.size|size()]]. | ||||
|  | ||||
| Example: | ||||
| <file lua> | ||||
| x = torch.DoubleStorage(10) | ||||
| print(x[5]) | ||||
| </file> | ||||
|  | ||||
| =====  [self] copy(storage) ===== | ||||
| {{anchor:torch.Storage.copy}} | ||||
|  | ||||
| Copy another ''storage''. The types of the two storages might be different: in that case | ||||
| a conversion of types occur (which might result, of course, in loss of precision or rounding). | ||||
| This method returns self, allowing things like: | ||||
| <file lua> | ||||
| x = torch.IntStorage(10):fill(1) | ||||
| y = torch.DoubleStorage(10):copy(x) -- y won't be nil! | ||||
| </file> | ||||
|  | ||||
| =====  [self] fill(value) ===== | ||||
| {{anchor:torch.Storage.fill}} | ||||
|  | ||||
| Fill the ''Storage'' with the given value. This method returns self, allowing things like: | ||||
| <file lua> | ||||
| x = torch.IntStorage(10):fill(0) -- x won't be nil! | ||||
| </file> | ||||
|  | ||||
| =====  [self] resize(size) ===== | ||||
| {{anchor:torch.Storage.resize}} | ||||
|  | ||||
| Resize the storage to the provide ''size''. //The new contents are undertermined//. | ||||
|  | ||||
| This function returns self, allowing things like: | ||||
| <file lua> | ||||
| x = torch.DoubleStorage(10):fill(1) | ||||
| y = torch.DoubleStorage():resize(x:size()):copy(x) -- y won't be nil! | ||||
| </file> | ||||
|  | ||||
| =====  [number] size() ===== | ||||
| {{anchor:torch.Storage.size}} | ||||
|  | ||||
| Returns the number of elements in the storage. Equivalent to [[#__torch.StorageSharp|#]]. | ||||
|  | ||||
| =====  [self] string(str) ===== | ||||
| {{anchor:torch.Storage.string}} | ||||
|  | ||||
| This function is available only on ''ByteStorage'' and ''CharStorage''. | ||||
|  | ||||
| This method resizes the storage to the length of the provided | ||||
| string ''str'', and copy the contents of ''str'' into the storage. The ''NULL'' terminating character is not copied, | ||||
| but ''str'' might contain ''NULL'' characters. The method returns the ''Storage''. | ||||
| <file lua> | ||||
| > x = torch.CharStorage():string("blah blah") | ||||
| > print(x) | ||||
|   98 | ||||
|  108 | ||||
|   97 | ||||
|  104 | ||||
|   32 | ||||
|   98 | ||||
|  108 | ||||
|   97 | ||||
|  104 | ||||
| [torch.CharStorage of size 9] | ||||
| </file> | ||||
|  | ||||
| =====  [string] string() ===== | ||||
| {{anchor:torch.Storage.string}} | ||||
|  | ||||
| This function is available only on ''ByteStorage'' and ''CharStorage''. | ||||
|  | ||||
| The contents of the storage viewed as a string are returned. The string might contain | ||||
| ''NULL'' characters. | ||||
| <file lua> | ||||
| > x = torch.CharStorage():string("blah blah") | ||||
| > print(x:string()) | ||||
| blah blah | ||||
| </file> | ||||
							
								
								
									
										1794
									
								
								dok/tensor.dok
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										1794
									
								
								dok/tensor.dok
									
									
									
									
									
										Normal file
									
								
							
										
											
												File diff suppressed because it is too large
												Load Diff
											
										
									
								
							
							
								
								
									
										130
									
								
								dok/tester.dok
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										130
									
								
								dok/tester.dok
									
									
									
									
									
										Normal file
									
								
							| @ -0,0 +1,130 @@ | ||||
| ======  Tester ====== | ||||
| {{anchor:torch.Tester.dok}} | ||||
|  | ||||
| This class provides a generic unit testing framework. It is already  | ||||
| being used in [[..:nn:index|nn]] package to verify the correctness of classes. | ||||
|  | ||||
| The framework is generally used as follows. | ||||
|  | ||||
| <file lua> | ||||
| mytest = {} | ||||
|  | ||||
| tester = torch.Tester() | ||||
|  | ||||
| function mytest.TestA() | ||||
| 	local a = 10 | ||||
| 	local b = 10 | ||||
| 	tester:asserteq(a,b,'a == b') | ||||
| 	tester:assertne(a,b,'a ~= b') | ||||
| end | ||||
|  | ||||
| function mytest.TestB() | ||||
| 	local a = 10 | ||||
| 	local b = 9 | ||||
| 	tester:assertlt(a,b,'a < b') | ||||
| 	tester:assertgt(a,b,'a > b') | ||||
| end | ||||
|  | ||||
| tester:add(mytest) | ||||
| tester:run() | ||||
|  | ||||
| </file> | ||||
|  | ||||
| Running this code will report 2 errors in 2 test functions. Generally it is  | ||||
| better to put single test cases in each test function unless several very related | ||||
| test cases exit. The error report includes the message and line number of the error. | ||||
|  | ||||
| <file> | ||||
|  | ||||
| Running 2 tests | ||||
| **  ==> Done  | ||||
|  | ||||
| Completed 2 tests with 2 errors | ||||
|  | ||||
| -------------------------------------------------------------------------------- | ||||
| TestB | ||||
| a < b | ||||
|  LT(<) violation   val=10, condition=9 | ||||
| 	...y/usr.t7/local.master/share/lua/5.1/torch/Tester.lua:23: in function 'assertlt' | ||||
| 	[string "function mytest.TestB()..."]:4: in function 'f' | ||||
|  | ||||
| -------------------------------------------------------------------------------- | ||||
| TestA | ||||
| a ~= b | ||||
|  NE(~=) violation   val=10, condition=10 | ||||
| 	...y/usr.t7/local.master/share/lua/5.1/torch/Tester.lua:38: in function 'assertne' | ||||
| 	[string "function mytest.TestA()..."]:5: in function 'f' | ||||
|  | ||||
| -------------------------------------------------------------------------------- | ||||
|  | ||||
| </file> | ||||
|  | ||||
|  | ||||
| ==== torch.Tester() ==== | ||||
| {{anchor:torch.Tester}} | ||||
|  | ||||
| Returns a new instance of ''torch.Tester'' class. | ||||
|  | ||||
| ==== add(f, 'name') ==== | ||||
| {{anchor:torch.Tester.add}} | ||||
|  | ||||
| Adds a new test function with name ''name''. The test function is stored in ''f''. | ||||
| The function is supposed to run without any arguments and not return any values. | ||||
|  | ||||
| ==== add(ftable) ==== | ||||
| {{anchor:torch.Tester.add}} | ||||
|  | ||||
| Recursively adds all function entries of the table ''ftable'' as tests. This table  | ||||
| can only have functions or nested tables of functions. | ||||
|  | ||||
| ==== assert(condition [, message]) ==== | ||||
| {{anchor:torch.Tester.assert}} | ||||
|  | ||||
| Saves an error if condition is not true with the optional message. | ||||
|  | ||||
| ==== assertlt(val, condition [, message]) ==== | ||||
| {{anchor:torch.Tester.assertlt}} | ||||
|  | ||||
| Saves an error if ''val < condition'' is not true with the optional message. | ||||
|  | ||||
| ==== assertgt(val, condition [, message]) ==== | ||||
| {{anchor:torch.Tester.assertgt}} | ||||
|  | ||||
| Saves an error if ''val > condition'' is not true with the optional message. | ||||
|  | ||||
| ==== assertle(val, condition [, message]) ==== | ||||
| {{anchor:torch.Tester.assertle}} | ||||
|  | ||||
| Saves an error if ''val <= condition'' is not true with the optional message. | ||||
|  | ||||
| ==== assertge(val, condition [, message]) ==== | ||||
| {{anchor:torch.Tester.assertge}} | ||||
|  | ||||
| Saves an error if ''val >= condition'' is not true with the optional message. | ||||
|  | ||||
| ==== asserteq(val, condition [, message]) ==== | ||||
| {{anchor:torch.Tester.asserteq}} | ||||
|  | ||||
| Saves an error if ''val == condition'' is not true with the optional message. | ||||
|  | ||||
| ==== assertne(val, condition [, message]) ==== | ||||
| {{anchor:torch.Tester.assertne}} | ||||
|  | ||||
| Saves an error if ''val ~= condition'' is not true with the optional message. | ||||
|  | ||||
| ==== assertTensorEq(ta, tb, condition [, message]) ==== | ||||
| {{anchor:torch.Tester.assertTensorEq}} | ||||
|  | ||||
| Saves an error if ''max(abs(ta-tb)) < condition'' is not true with the optional message. | ||||
|  | ||||
| ==== run() ==== | ||||
| {{anchor:torch.Tester.run}} | ||||
|  | ||||
| Runs all the test functions that are stored using [[#torch.Tester.add|add()]] function.  | ||||
| While running it reports progress and at the end gives a summary of all errors. | ||||
|  | ||||
|  | ||||
|  | ||||
|  | ||||
|  | ||||
|  | ||||
							
								
								
									
										43
									
								
								dok/timer.dok
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										43
									
								
								dok/timer.dok
									
									
									
									
									
										Normal file
									
								
							| @ -0,0 +1,43 @@ | ||||
| ======  Timer ====== | ||||
| {{anchor:torch.Timer.dok}} | ||||
|  | ||||
| This class is able to measure time (in seconds) elapsed in a particular period. Example: | ||||
| <file lua> | ||||
|   timer = torch.Timer() -- the Timer starts to count now | ||||
|   x = 0 | ||||
|   for i=1,1000000 do | ||||
|     x = x + math.sin(x) | ||||
|   end | ||||
|   print('Time elapsed for 1,000,000 sin: ' .. timer:time().real .. ' seconds') | ||||
| </file> | ||||
|  | ||||
| =====  torch.Timer() ===== | ||||
| {{anchor:torch.Timer}} | ||||
|  | ||||
| Returns a new ''Timer''. The timer starts to count the time now. | ||||
|  | ||||
| =====  [self] reset() ===== | ||||
| {{anchor:torch.Timer.reset}} | ||||
|  | ||||
| Reset the timer accumulated time to ''0''. If the timer was running, the timer | ||||
| restarts to count the time now. If the timer was stopped, it stays stopped. | ||||
|  | ||||
| =====  [self] resume() ===== | ||||
| {{anchor:torch.Timer.resume}} | ||||
|  | ||||
| Resume a stopped timer. The timer restarts to count the time, and addition | ||||
| the accumulated time with the time already counted before being stopped. | ||||
|  | ||||
| =====  [self] stop() ===== | ||||
| {{anchor:torch.Timer.stop}} | ||||
|  | ||||
| Stop the timer. The accumulated time counted until now is stored. | ||||
|  | ||||
| =====   [table] time() ===== | ||||
| {{anchor:torch.Timer.time}} | ||||
|  | ||||
| Returns a table reporting the accumulated time elapsed until now. Following the UNIX shell ''time'' command, | ||||
| there are three fields in the table: | ||||
|   * ''real'': the wall-clock elapsed time. | ||||
|   * ''user'': the elapsed CPU time. Note that the CPU time of a threaded program sums time spent in all threads. | ||||
|   * ''sys'': the time spent in system usage. | ||||
							
								
								
									
										234
									
								
								dok/utility.dok
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										234
									
								
								dok/utility.dok
									
									
									
									
									
										Normal file
									
								
							| @ -0,0 +1,234 @@ | ||||
| ======  Torch utility functions ====== | ||||
| {{anchor:torch.utility.dok}} | ||||
|  | ||||
| This functions are used in all Torch package for creating and handling classes. | ||||
| The most interesting function is probably [[#torch.class|torch.class()]] which allows | ||||
| the user to create easily new classes. [[#torch.typename|torch.typename()]] might | ||||
| also be interesting to check what is the class of a given Torch object. | ||||
|  | ||||
| The other functions are more for advanced users. | ||||
|  | ||||
| =====  [metatable] torch.class(name, [parentName]) ===== | ||||
| {{anchor:torch.class}} | ||||
|  | ||||
| Creates a new ''Torch'' class called ''name''. If ''parentName'' is provided, the class will inherit | ||||
| ''parentName'' methods. A class is a table which has a particular metatable. | ||||
|  | ||||
| If ''name'' is of the form ''package.className'' then the class ''className'' will be added to the specified ''package''. | ||||
| In that case, ''package'' has to be a valid (and already loaded) package. If ''name'' does not contain any ''"."'', | ||||
| then the class will be defined in the global environment. | ||||
|  | ||||
| One [or two] (meta)tables are returned. These tables contain all the method | ||||
| provided by the class [and its parent class if it has been provided]. After | ||||
| a call to ''torch.class()'' you have to fill-up properly the metatable. | ||||
|  | ||||
| After the class definition is complete, constructing a new class //name// will be achieved by a call to ''//name//()''. | ||||
| This call will first call the method <file lua>__init()</file> if it exists, passing all arguments of ''//name//()''. | ||||
|  | ||||
| <file lua> | ||||
|  require "torch" | ||||
|  | ||||
|  -- for naming convenience | ||||
|  do | ||||
|    --- creates a class "Foo" | ||||
|    local Foo = torch.class('Foo') | ||||
|   | ||||
|    --- the initializer | ||||
|    function Foo:__init() | ||||
|      self.contents = "this is some text" | ||||
|    end | ||||
|  | ||||
|    --- a method | ||||
|    function Foo:print() | ||||
|      print(self.contents) | ||||
|    end | ||||
|  | ||||
|    --- another one | ||||
|    function Foo:bip() | ||||
|      print('bip') | ||||
|    end | ||||
|  | ||||
|  end | ||||
|  | ||||
|  --- now create an instance of Foo | ||||
|  foo = Foo() | ||||
|  | ||||
|  --- try it out | ||||
|  foo:print() | ||||
|  | ||||
|  --- create a class torch.Bar which | ||||
|  --- inherits from Foo | ||||
|  do | ||||
|    local Bar, parent = torch.class('torch.Bar', 'Foo') | ||||
|  | ||||
|    --- the initializer | ||||
|    function Bar:__init(stuff) | ||||
|      --- call the parent initializer on ourself | ||||
|      parent.__init(self) | ||||
|   | ||||
|      --- do some stuff | ||||
|      self.stuff = stuff | ||||
|    end | ||||
|  | ||||
|    --- a new method | ||||
|    function Bar:boing() | ||||
|      print('boing!') | ||||
|    end | ||||
|  | ||||
|    --- override parent's method | ||||
|    function Bar:print() | ||||
|      print(self.contents) | ||||
|      print(self.stuff) | ||||
|    end | ||||
|  end | ||||
|  | ||||
|  --- create a new instance and use it | ||||
|  bar = torch.Bar("ha ha!") | ||||
|  bar:print() -- overrided method | ||||
|  bar:boing() -- child method | ||||
|  bar:bip()   -- parent's method | ||||
|  | ||||
| </file> | ||||
|  | ||||
| For advanced users, it is worth mentionning that ''torch.class()'' actually | ||||
| calls [[#torch.newmetatable|torch.newmetatable()]].  with a particular | ||||
| constructor. The constructor creates a Lua table and set the right | ||||
| metatable on it, and then calls <file lua>__init()</file> if it exists in the | ||||
| metatable. It also sets a [[#torch.factory|factory]] field <file lua>__factory</file> such that it | ||||
| is possible to create an empty object of this class. | ||||
|  | ||||
| =====  [string] torch.typename(object) ===== | ||||
| {{anchor:torch.typename}} | ||||
|  | ||||
| Checks if ''object'' has a metatable. If it does, and if it corresponds to a | ||||
| ''Torch'' class, then returns a string containing the name of the | ||||
| class. Returns ''nil'' in any other cases. | ||||
|  | ||||
| A Torch class is a class created with [[#torch.class|torch.class()]] or | ||||
| [[#torch.newmetatable|torch.newmetatable()]]. | ||||
|  | ||||
| ===== [userdata] torch.typename2id(string) ===== | ||||
| {{anchor:torch.typename2id}} | ||||
|  | ||||
| Given a Torch class name specified by ''string'', returns a unique | ||||
| corresponding id (defined by a ''lightuserdata'' pointing on the internal | ||||
| structure of the class). This might be useful to do a //fast// check of the | ||||
| class of an object (if used with [[#torch.id|torch.id()]]), avoiding string | ||||
| comparisons. | ||||
|  | ||||
| Returns ''nil'' if ''string'' does not specify a Torch object. | ||||
|  | ||||
| ===== [userdata] torch.id(object) ===== | ||||
| {{anchor:torch.id}} | ||||
|  | ||||
| Returns a unique id corresponding to the //class// of the given Torch object. | ||||
| The id is defined by a ''lightuserdata'' pointing on the internal structure | ||||
| of the class. | ||||
|  | ||||
| Returns ''nil'' if ''object'' is not a Torch object. | ||||
|  | ||||
| This is different from the //object// id returned by [[#torch.pointer|torch.pointer()]]. | ||||
|  | ||||
| =====  [table] torch.newmetatable(name, parentName, constructor) ===== | ||||
| {{anchor:torch.newmetatable}} | ||||
|  | ||||
| Register a new metatable as a Torch type with the given string ''name''. The new metatable is returned. | ||||
|  | ||||
| If the string ''parentName'' is not ''nil'' and is a valid Torch type (previously created | ||||
| by ''torch.newmetatable()'') then set the corresponding metatable as a metatable to the returned new | ||||
| metatable.  | ||||
|  | ||||
| If the given ''constructor'' function is not ''nil'', then assign to the variable ''name'' the given constructor. | ||||
| The given ''name'' might be of the form ''package.className'', in which case the ''className'' will be local to the | ||||
| specified ''package''. In that case, ''package'' must be a valid and already loaded package. | ||||
|  | ||||
| =====  [function] torch.factory(name) ===== | ||||
| {{anchor:torch.factory}} | ||||
|  | ||||
| Returns the factory function of the Torch class ''name''. If the class name is invalid or if the class | ||||
| has no factory, then returns ''nil''. | ||||
|  | ||||
| A Torch class is a class created with [[#torch.class|torch.class()]] or | ||||
| [[#torch.newmetatable|torch.newmetatable()]]. | ||||
|  | ||||
| A factory function is able to return a new (empty) object of its corresponding class. This is helpful for | ||||
| [[File#torch.File.serialization|object serialization]]. | ||||
|  | ||||
| =====  [table] torch.getmetatable(string) ===== | ||||
| {{anchor:torch.getmetatable}} | ||||
|  | ||||
| Given a ''string'', returns a metatable corresponding to the Torch class described | ||||
| by ''string''. Returns ''nil'' if the class does not exist. | ||||
|  | ||||
| A Torch class is a class created with [[#torch.class|torch.class()]] or | ||||
| [[#torch.newmetatable|torch.newmetatable()]]. | ||||
|  | ||||
| Example: | ||||
| <file lua> | ||||
| > for k,v in pairs(torch.getmetatable("torch.CharStorage")) do print(k,v) end | ||||
| __index__       function: 0x1a4ba80 | ||||
| __typename      torch.CharStorage | ||||
| write   function: 0x1a49cc0 | ||||
| __tostring__    function: 0x1a586e0 | ||||
| __newindex__    function: 0x1a4ba40 | ||||
| string  function: 0x1a4d860 | ||||
| __version       1 | ||||
| copy    function: 0x1a49c80 | ||||
| read    function: 0x1a4d840 | ||||
| __len__ function: 0x1a37440 | ||||
| fill    function: 0x1a375c0 | ||||
| resize  function: 0x1a37580 | ||||
| __index table: 0x1a4a080 | ||||
| size    function: 0x1a4ba20 | ||||
| </file> | ||||
|  | ||||
| =====  [boolean] torch.isequal(object1, object2) ===== | ||||
| {{anchor:torch.isequal}} | ||||
|  | ||||
| If the two objects given as arguments are ''Lua'' tables (or Torch objects), then returns ''true'' if and only if the | ||||
| tables (or Torch objects) have the same address in memory. Returns ''false'' in any other cases. | ||||
|  | ||||
| A Torch class is a class created with [[#TorchClass|torch.class()]] or | ||||
| [[#torch.newmetatable|torch.newmetatable()]]. | ||||
|  | ||||
| =====  torch.setenv(function or userdata, table) ===== | ||||
| {{anchor:torch.setenv}} | ||||
|  | ||||
| Assign ''table'' as the Lua environment of the given ''function'' or the given | ||||
| ''userdata''.  To know more about environments, please read the documentation | ||||
| of [[http://www.lua.org/manual/5.1/manual.html#lua_setfenv|lua_setfenv()]] | ||||
| and [[http://www.lua.org/manual/5.1/manual.html#lua_getfenv|lua_getfenv()]]. | ||||
|  | ||||
| =====  [table] torch.getenv(function or userdata) ===== | ||||
| {{anchor:torch.getenv}} | ||||
|  | ||||
| Returns the Lua ''table'' environment of the given ''function'' or the given | ||||
| ''userdata''.  To know more about environments, please read the documentation | ||||
| of [[http://www.lua.org/manual/5.1/manual.html#lua_setfenv|lua_setfenv()]] | ||||
| and [[http://www.lua.org/manual/5.1/manual.html#lua_getfenv|lua_getfenv()]]. | ||||
|  | ||||
| ===== [number] torch.version(object) ===== | ||||
| {{anchor:torch.version}} | ||||
|  | ||||
| Returns the field <file lua>__version</file> of a given object. This might | ||||
| be helpful to handle variations in a class over time. | ||||
|  | ||||
| =====  [number] torch.pointer(object) ===== | ||||
| {{anchor:torch.pointer}} | ||||
|  | ||||
| Returns a unique id (pointer) of the given ''object'', which can be a Torch | ||||
| object, a table, a thread or a function. | ||||
|  | ||||
| This is different from the //class// id returned by [[#torch.id|torch.id()]]. | ||||
|  | ||||
| ===== [object] torch.setmetatable(table, classname) ===== | ||||
| {{anchor:torch.setmetatable}} | ||||
|  | ||||
| Set the metatable of the given ''table'' to the metatable of the Torch object named ''classname''. | ||||
| This function has to be used with a lot of care. | ||||
|  | ||||
| ===== [table] torch.getconstructortable(string) ===== | ||||
| {{anchor:torch.getconstructortable}} | ||||
|  | ||||
| BUGGY | ||||
| Return the constructor table of the Torch class specified by ''string'. | ||||
							
								
								
									
										18
									
								
								general.h
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										18
									
								
								general.h
									
									
									
									
									
										Normal file
									
								
							| @ -0,0 +1,18 @@ | ||||
| #ifndef TORCH_GENERAL_INC | ||||
| #define TORCH_GENERAL_INC | ||||
|  | ||||
| #include <stdlib.h> | ||||
| #include <string.h> | ||||
|  | ||||
| #include "luaT.h" | ||||
| #include "TH.h" | ||||
|  | ||||
| #ifdef _MSC_VER | ||||
|  | ||||
| #define snprintf _snprintf | ||||
| #define popen _popen | ||||
| #define pclose _pclose | ||||
|  | ||||
| #endif | ||||
|  | ||||
| #endif | ||||
							
								
								
									
										221
									
								
								generic/Storage.c
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										221
									
								
								generic/Storage.c
									
									
									
									
									
										Normal file
									
								
							| @ -0,0 +1,221 @@ | ||||
| #ifndef TH_GENERIC_FILE | ||||
| #define TH_GENERIC_FILE "generic/Storage.c" | ||||
| #else | ||||
|  | ||||
| static int torch_Storage_(new)(lua_State *L) | ||||
| { | ||||
|   THStorage *storage; | ||||
|   if(lua_type(L, 1) == LUA_TSTRING) | ||||
|   { | ||||
|     const char *fileName = luaL_checkstring(L, 1); | ||||
|     int isShared = luaT_optboolean(L, 2, 0); | ||||
|     storage = THStorage_(newWithMapping)(fileName, isShared);  } | ||||
|   else if(lua_type(L, 1) == LUA_TTABLE) | ||||
|   { | ||||
|     long size = lua_objlen(L, 1); | ||||
|     long i; | ||||
|     storage = THStorage_(newWithSize)(size); | ||||
|     for(i = 1; i <= size; i++) | ||||
|     { | ||||
|       lua_rawgeti(L, 1, i); | ||||
|       if(!lua_isnumber(L, -1)) | ||||
|       { | ||||
|         THStorage_(free)(storage); | ||||
|         luaL_error(L, "element at index %d is not a number", i); | ||||
|       } | ||||
|       THStorage_(set)(storage, i-1, (real)lua_tonumber(L, -1)); | ||||
|       lua_pop(L, 1); | ||||
|     } | ||||
|   } | ||||
|   else | ||||
|   { | ||||
|     long size = luaL_optlong(L, 1, 0); | ||||
|     storage = THStorage_(newWithSize)(size); | ||||
|   } | ||||
|   luaT_pushudata(L, storage, torch_Storage_id); | ||||
|   return 1; | ||||
| } | ||||
|  | ||||
| static int torch_Storage_(free)(lua_State *L) | ||||
| { | ||||
|   THStorage *storage = luaT_checkudata(L, 1, torch_Storage_id); | ||||
|   THStorage_(free)(storage); | ||||
|   return 0; | ||||
| } | ||||
|  | ||||
| static int torch_Storage_(resize)(lua_State *L) | ||||
| { | ||||
|   THStorage *storage = luaT_checkudata(L, 1, torch_Storage_id); | ||||
|   long size = luaL_checklong(L, 2); | ||||
| /*  int keepContent = luaT_optboolean(L, 3, 0); */ | ||||
|   THStorage_(resize)(storage, size);/*, keepContent); */ | ||||
|   lua_settop(L, 1); | ||||
|   return 1; | ||||
| } | ||||
|  | ||||
| static int torch_Storage_(copy)(lua_State *L) | ||||
| { | ||||
|   THStorage *storage = luaT_checkudata(L, 1, torch_Storage_id); | ||||
|   void *src; | ||||
|   if( (src = luaT_toudata(L, 2, torch_Storage_id)) ) | ||||
|     THStorage_(copy)(storage, src); | ||||
|   else if( (src = luaT_toudata(L, 2, torch_ByteStorage_id)) ) | ||||
|     THStorage_(copyByte)(storage, src); | ||||
|   else if( (src = luaT_toudata(L, 2, torch_CharStorage_id)) ) | ||||
|     THStorage_(copyChar)(storage, src); | ||||
|   else if( (src = luaT_toudata(L, 2, torch_ShortStorage_id)) ) | ||||
|     THStorage_(copyShort)(storage, src); | ||||
|   else if( (src = luaT_toudata(L, 2, torch_IntStorage_id)) ) | ||||
|     THStorage_(copyInt)(storage, src); | ||||
|   else if( (src = luaT_toudata(L, 2, torch_LongStorage_id)) ) | ||||
|     THStorage_(copyLong)(storage, src); | ||||
|   else if( (src = luaT_toudata(L, 2, torch_FloatStorage_id)) ) | ||||
|     THStorage_(copyFloat)(storage, src); | ||||
|   else if( (src = luaT_toudata(L, 2, torch_DoubleStorage_id)) ) | ||||
|     THStorage_(copyDouble)(storage, src); | ||||
|   else | ||||
|     luaL_typerror(L, 2, "torch.*Storage"); | ||||
|   lua_settop(L, 1); | ||||
|   return 1; | ||||
| } | ||||
|  | ||||
| static int torch_Storage_(fill)(lua_State *L) | ||||
| { | ||||
|   THStorage *storage = luaT_checkudata(L, 1, torch_Storage_id); | ||||
|   double value = luaL_checknumber(L, 2); | ||||
|   THStorage_(fill)(storage, (real)value); | ||||
|   lua_settop(L, 1); | ||||
|   return 1; | ||||
| } | ||||
|  | ||||
| static int torch_Storage_(__len__)(lua_State *L) | ||||
| { | ||||
|   THStorage *storage = luaT_checkudata(L, 1, torch_Storage_id); | ||||
|   lua_pushnumber(L, storage->size); | ||||
|   return 1; | ||||
| } | ||||
|  | ||||
| static int torch_Storage_(__newindex__)(lua_State *L) | ||||
| { | ||||
|   if(lua_isnumber(L, 2)) | ||||
|   { | ||||
|     THStorage *storage = luaT_checkudata(L, 1, torch_Storage_id); | ||||
|     long index = luaL_checklong(L, 2) - 1; | ||||
|     double number = luaL_checknumber(L, 3); | ||||
|     THStorage_(set)(storage, index, (real)number); | ||||
|     lua_pushboolean(L, 1); | ||||
|   } | ||||
|   else | ||||
|     lua_pushboolean(L, 0); | ||||
|  | ||||
|   return 1; | ||||
| } | ||||
|  | ||||
| static int torch_Storage_(__index__)(lua_State *L) | ||||
| { | ||||
|   if(lua_isnumber(L, 2)) | ||||
|   { | ||||
|     THStorage *storage = luaT_checkudata(L, 1, torch_Storage_id); | ||||
|     long index = luaL_checklong(L, 2) - 1; | ||||
|     lua_pushnumber(L, THStorage_(get)(storage, index)); | ||||
|     lua_pushboolean(L, 1); | ||||
|     return 2; | ||||
|   } | ||||
|   else | ||||
|   { | ||||
|     lua_pushboolean(L, 0); | ||||
|     return 1; | ||||
|   } | ||||
| } | ||||
|  | ||||
| #if defined(TH_REAL_IS_CHAR) || defined(TH_REAL_IS_BYTE) | ||||
| static int torch_Storage_(string)(lua_State *L) | ||||
| { | ||||
|   THStorage *storage = luaT_checkudata(L, 1, torch_Storage_id); | ||||
|   if(lua_isstring(L, -1)) | ||||
|   { | ||||
|     size_t len = 0; | ||||
|     const char *str = lua_tolstring(L, -1, &len); | ||||
|     THStorage_(resize)(storage, len); | ||||
|     memmove(storage->data, str, len); | ||||
|     lua_settop(L, 1); | ||||
|   } | ||||
|   else | ||||
|     lua_pushlstring(L, (char*)storage->data, storage->size); | ||||
|  | ||||
|   return 1; /* either storage or string */ | ||||
| } | ||||
| #endif | ||||
|  | ||||
| static int torch_Storage_(totable)(lua_State *L) | ||||
| { | ||||
|   THStorage *storage = luaT_checkudata(L, 1, torch_Storage_id); | ||||
|   long i; | ||||
|  | ||||
|   lua_newtable(L); | ||||
|   for(i = 0; i < storage->size; i++) | ||||
|   { | ||||
|     lua_pushnumber(L, (lua_Number)storage->data[i]); | ||||
|     lua_rawseti(L, -2, i+1); | ||||
|   } | ||||
|   return 1; | ||||
| } | ||||
|  | ||||
| static int torch_Storage_(factory)(lua_State *L) | ||||
| { | ||||
|   THStorage *storage = THStorage_(new)(); | ||||
|   luaT_pushudata(L, storage, torch_Storage_id); | ||||
|   return 1; | ||||
| } | ||||
|  | ||||
| static int torch_Storage_(write)(lua_State *L) | ||||
| { | ||||
|   THStorage *storage = luaT_checkudata(L, 1, torch_Storage_id); | ||||
|   THFile *file = luaT_checkudata(L, 2, torch_File_id); | ||||
|   | ||||
|   THFile_writeLongScalar(file, storage->size); | ||||
|   THFile_writeRealRaw(file, storage->data, storage->size); | ||||
|  | ||||
|   return 0; | ||||
| } | ||||
|  | ||||
| static int torch_Storage_(read)(lua_State *L) | ||||
| { | ||||
|   THStorage *storage = luaT_checkudata(L, 1, torch_Storage_id); | ||||
|   THFile *file = luaT_checkudata(L, 2, torch_File_id); | ||||
|   long size = THFile_readLongScalar(file); | ||||
|  | ||||
|   THStorage_(resize)(storage, size); | ||||
|   THFile_readRealRaw(file, storage->data, storage->size); | ||||
|  | ||||
|   return 0; | ||||
| } | ||||
|  | ||||
| static const struct luaL_Reg torch_Storage_(_) [] = { | ||||
|   {"size", torch_Storage_(__len__)}, | ||||
|   {"__len__", torch_Storage_(__len__)}, | ||||
|   {"__newindex__", torch_Storage_(__newindex__)}, | ||||
|   {"__index__", torch_Storage_(__index__)}, | ||||
|   {"resize", torch_Storage_(resize)}, | ||||
|   {"fill", torch_Storage_(fill)}, | ||||
|   {"copy", torch_Storage_(copy)}, | ||||
|   {"totable", torch_Storage_(totable)}, | ||||
|   {"write", torch_Storage_(write)}, | ||||
|   {"read", torch_Storage_(read)}, | ||||
| #if defined(TH_REAL_IS_CHAR) || defined(TH_REAL_IS_BYTE) | ||||
|   {"string", torch_Storage_(string)}, | ||||
| #endif | ||||
|   {NULL, NULL} | ||||
| }; | ||||
|  | ||||
| void torch_Storage_(init)(lua_State *L) | ||||
| { | ||||
|   torch_File_id = luaT_checktypename2id(L, "torch.File"); | ||||
|  | ||||
|   torch_Storage_id = luaT_newmetatable(L, STRING_torchStorage, NULL, | ||||
|                                   torch_Storage_(new), torch_Storage_(free), torch_Storage_(factory)); | ||||
|   luaL_register(L, NULL, torch_Storage_(_)); | ||||
|   lua_pop(L, 1); | ||||
| } | ||||
|  | ||||
| #endif | ||||
							
								
								
									
										939
									
								
								generic/Tensor.c
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										939
									
								
								generic/Tensor.c
									
									
									
									
									
										Normal file
									
								
							| @ -0,0 +1,939 @@ | ||||
| #ifndef TH_GENERIC_FILE | ||||
| #define TH_GENERIC_FILE "generic/Tensor.c" | ||||
| #else | ||||
|  | ||||
| static void torch_Tensor_(c_readTensorStorageSizeStride)(lua_State *L, int index, int allowNone, int allowTensor, int allowStorage, int allowStride, | ||||
|                                                          THStorage **storage_, long *storageOffset_, THLongStorage **size_, THLongStorage **stride_); | ||||
|  | ||||
| static void torch_Tensor_(c_readSizeStride)(lua_State *L, int index, int allowStride, THLongStorage **size_, THLongStorage **stride_); | ||||
|  | ||||
| static int torch_Tensor_(size)(lua_State *L) | ||||
| { | ||||
|   THTensor *tensor = luaT_checkudata(L, 1, torch_Tensor_id); | ||||
|   if(lua_isnumber(L,2)) | ||||
|   { | ||||
|     int dim = luaL_checkint(L, 2)-1; | ||||
|     luaL_argcheck(L, dim >= 0 && dim < tensor->nDimension, 2, "out of range"); | ||||
|     lua_pushnumber(L, tensor->size[dim]); | ||||
|   } | ||||
|   else | ||||
|   { | ||||
|     THLongStorage *storage = THLongStorage_newWithSize(tensor->nDimension); | ||||
|     memmove(storage->data, tensor->size, sizeof(long)*tensor->nDimension); | ||||
|     luaT_pushudata(L, storage, torch_LongStorage_id); | ||||
|   } | ||||
|   return 1; | ||||
| } | ||||
|  | ||||
| static int torch_Tensor_(stride)(lua_State *L) | ||||
| { | ||||
|   THTensor *tensor = luaT_checkudata(L, 1, torch_Tensor_id); | ||||
|   if(lua_isnumber(L,2)) | ||||
|   { | ||||
|     int dim = luaL_checkint(L, 2)-1; | ||||
|     luaL_argcheck(L, dim >= 0 && dim < tensor->nDimension, 2, "out of range"); | ||||
|     lua_pushnumber(L, tensor->stride[dim]); | ||||
|   } | ||||
|   else | ||||
|   { | ||||
|     THLongStorage *storage = THLongStorage_newWithSize(tensor->nDimension); | ||||
|     memmove(storage->data, tensor->stride, sizeof(long)*tensor->nDimension); | ||||
|     luaT_pushudata(L, storage, torch_LongStorage_id); | ||||
|   } | ||||
|   return 1; | ||||
| } | ||||
|  | ||||
| static int torch_Tensor_(nDimension)(lua_State *L) | ||||
| { | ||||
|   THTensor *tensor = luaT_checkudata(L, 1, torch_Tensor_id); | ||||
|   lua_pushnumber(L, tensor->nDimension); | ||||
|   return 1; | ||||
| } | ||||
|  | ||||
| static int torch_Tensor_(storage)(lua_State *L) | ||||
| { | ||||
|   THTensor *tensor = luaT_checkudata(L, 1, torch_Tensor_id); | ||||
|  | ||||
|   if(tensor->storage) | ||||
|   { | ||||
|     THStorage_(retain)(tensor->storage); | ||||
|     luaT_pushudata(L, tensor->storage, torch_Storage_id); | ||||
|   } | ||||
|   else | ||||
|     lua_pushnil(L); | ||||
|  | ||||
|   return 1; | ||||
| } | ||||
|  | ||||
| static int torch_Tensor_(storageOffset)(lua_State *L) | ||||
| { | ||||
|   THTensor *tensor = luaT_checkudata(L, 1, torch_Tensor_id); | ||||
|   lua_pushnumber(L, tensor->storageOffset+1); | ||||
|   return 1; | ||||
| } | ||||
|  | ||||
| static int torch_Tensor_(new)(lua_State *L) | ||||
| { | ||||
|   THTensor *tensor; | ||||
|   long storageOffset; | ||||
|   THLongStorage *size, *stride; | ||||
|  | ||||
|   if(lua_type(L, 1) == LUA_TTABLE) | ||||
|   { | ||||
|     long i, j; | ||||
|     THLongStorage *counter; | ||||
|     long si = 0; | ||||
|     int dimension = 0; | ||||
|     int is_finished = 0; | ||||
|  | ||||
|     lua_settop(L, 1); | ||||
|     size = THLongStorage_new(); | ||||
|  | ||||
|     while( (lua_type(L, -1) == LUA_TTABLE) && (lua_objlen(L, -1) > 0) ) | ||||
|     { | ||||
|       THLongStorage_resize(size, dimension+1); | ||||
|       size->data[dimension] = lua_objlen(L, -1); | ||||
|       dimension++; | ||||
|       lua_rawgeti(L, -1, 1); | ||||
|     } | ||||
|     lua_pop(L, 1); | ||||
|              | ||||
|     counter = THLongStorage_newWithSize(size->size); | ||||
|     THLongStorage_fill(counter, 0); | ||||
|  | ||||
|     tensor = THTensor_(newWithSize)(size, NULL); | ||||
|      | ||||
|     if(size->size == 0) | ||||
|       is_finished = 1; | ||||
|      | ||||
|     while(!is_finished) | ||||
|     { | ||||
|       if(!lua_istable(L, -1)) | ||||
|       { | ||||
|         THLongStorage_free(size); | ||||
|         THLongStorage_free(counter); | ||||
|         THTensor_(free)(tensor); | ||||
|         luaL_error(L, "invalid tensor definition"); | ||||
|       } | ||||
|  | ||||
|       if(lua_objlen(L, -1) != size->data[size->size-1]) | ||||
|       { | ||||
|         THLongStorage_free(size); | ||||
|         THLongStorage_free(counter); | ||||
|         THTensor_(free)(tensor); | ||||
|         luaL_error(L, "invalid tensor sizes"); | ||||
|       } | ||||
|  | ||||
|       for(i = 0; i < size->data[size->size-1]; i++) | ||||
|       { | ||||
|         lua_rawgeti(L, -1, i+1); | ||||
|         if(!lua_isnumber(L, -1)) | ||||
|         { | ||||
|           THLongStorage_free(size); | ||||
|           THLongStorage_free(counter); | ||||
|           THTensor_(free)(tensor); | ||||
|           luaL_error(L, "invalid element (not a number)"); | ||||
|         } | ||||
|         THStorage_(set)(THTensor_(storage)(tensor), si++, (real)lua_tonumber(L, -1)); | ||||
|         lua_pop(L, 1); | ||||
|       } | ||||
|      | ||||
|       if(size->size == 1) | ||||
|         break; | ||||
|  | ||||
|       for(i = size->size-2; i >= 0; i--) | ||||
|       { | ||||
|         if(++counter->data[i] == size->data[i]) | ||||
|         { | ||||
|           if(i == 0) | ||||
|           { | ||||
|             is_finished = 1; | ||||
|             break; | ||||
|           } | ||||
|           else | ||||
|           { | ||||
|             counter->data[i] = 0; | ||||
|             lua_pop(L, 1); | ||||
|           } | ||||
|         } | ||||
|         else | ||||
|         { | ||||
|           lua_pop(L, 1); | ||||
|           for(j = i; j < size->size-1; j++) | ||||
|           { | ||||
|             if(!lua_istable(L, -1)) | ||||
|             { | ||||
|               THLongStorage_free(size); | ||||
|               THLongStorage_free(counter); | ||||
|               THTensor_(free)(tensor); | ||||
|               luaL_error(L, "invalid tensor definition"); | ||||
|             } | ||||
|             if(lua_objlen(L, -1) != size->data[j]) | ||||
|             { | ||||
|               THLongStorage_free(size); | ||||
|               THLongStorage_free(counter); | ||||
|               THTensor_(free)(tensor); | ||||
|               luaL_error(L, "invalid tensor sizes"); | ||||
|             } | ||||
|             lua_rawgeti(L, -1, counter->data[j]+1); | ||||
|           } | ||||
|           break; | ||||
|         } | ||||
|       } | ||||
|     } | ||||
|  | ||||
|     THLongStorage_free(size); | ||||
|     THLongStorage_free(counter); | ||||
|   } | ||||
|   else | ||||
|   { | ||||
|     THStorage *storage; | ||||
|  | ||||
|     torch_Tensor_(c_readTensorStorageSizeStride)(L, 1, 1, 1, 1, 1, | ||||
|                                                  &storage, &storageOffset, &size, &stride); | ||||
|      | ||||
|     tensor = THTensor_(newWithStorage)(storage, storageOffset, size, stride); | ||||
|  | ||||
|     THLongStorage_free(size); | ||||
|     THLongStorage_free(stride); | ||||
|   } | ||||
|  | ||||
|   luaT_pushudata(L, tensor, torch_Tensor_id); | ||||
|   return 1; | ||||
| } | ||||
|  | ||||
| static int torch_Tensor_(set)(lua_State *L) | ||||
| { | ||||
|   THTensor *self = luaT_checkudata(L, 1, torch_Tensor_id); | ||||
|   THStorage *storage; | ||||
|   long storageOffset; | ||||
|   THLongStorage *size, *stride; | ||||
|  | ||||
|   torch_Tensor_(c_readTensorStorageSizeStride)(L, 2, 1, 1, 1, 1, | ||||
|                                                &storage, &storageOffset, &size, &stride); | ||||
|  | ||||
|   THTensor_(setStorage)(self, storage, storageOffset, size, stride); | ||||
|  | ||||
|   THLongStorage_free(size); | ||||
|   THLongStorage_free(stride); | ||||
|  | ||||
|   lua_settop(L, 1); | ||||
|   return 1; | ||||
| } | ||||
|  | ||||
| static int torch_Tensor_(clone)(lua_State *L) | ||||
| { | ||||
|   THTensor *self = luaT_checkudata(L, 1, torch_Tensor_id); | ||||
|   self = THTensor_(newClone)(self); | ||||
|   luaT_pushudata(L, self, torch_Tensor_id); | ||||
|   return 1; | ||||
| } | ||||
|  | ||||
| static int torch_Tensor_(contiguous)(lua_State *L) | ||||
| { | ||||
|   THTensor *self = luaT_checkudata(L, 1, torch_Tensor_id); | ||||
|   self = THTensor_(newContiguous)(self); | ||||
|   luaT_pushudata(L, self, torch_Tensor_id); | ||||
|   return 1; | ||||
| } | ||||
|  | ||||
| /* Resize */ | ||||
| static int torch_Tensor_(resizeAs)(lua_State *L) | ||||
| { | ||||
|   THTensor *tensor = luaT_checkudata(L, 1, torch_Tensor_id); | ||||
|   THTensor *src = luaT_checkudata(L, 2, torch_Tensor_id); | ||||
|   THTensor_(resizeAs)(tensor, src); | ||||
|   lua_settop(L, 1); | ||||
|   return 1; | ||||
| } | ||||
|  | ||||
| static int torch_Tensor_(resize)(lua_State *L) | ||||
| { | ||||
|   THTensor *tensor = luaT_checkudata(L, 1, torch_Tensor_id); | ||||
|   THLongStorage *size, *stride; | ||||
|  | ||||
|   torch_Tensor_(c_readSizeStride)(L, 2, 0, &size, &stride); | ||||
|  | ||||
|   THTensor_(resize)(tensor, size, stride); | ||||
|  | ||||
|   THLongStorage_free(size); | ||||
|   THLongStorage_free(stride); | ||||
|  | ||||
|   lua_settop(L, 1); | ||||
|   return 1; | ||||
| } | ||||
|  | ||||
| static int torch_Tensor_(narrow)(lua_State *L) | ||||
| { | ||||
|   THTensor *tensor = luaT_checkudata(L, 1, torch_Tensor_id); | ||||
|   int dimension = luaL_checkint(L, 2)-1; | ||||
|   long firstIndex = luaL_checklong(L, 3)-1; | ||||
|   long size = luaL_checklong(L, 4); | ||||
|  | ||||
| /*  THArgCheck( (dimension >= 0) && (dimension < tensor->nDimension), 2, "out of range"); | ||||
|   THArgCheck( (firstIndex >= 0) && (firstIndex < tensor->size[dimension]), 3, "out of range"); | ||||
|   THArgCheck( (size > 0) && (firstIndex+size <= tensor->size[dimension]), 4, "out of range"); | ||||
| */ | ||||
|   tensor = THTensor_(newWithTensor)(tensor); | ||||
|   THTensor_(narrow)(tensor, NULL, dimension, firstIndex, size); | ||||
|   luaT_pushudata(L, tensor, torch_Tensor_id); | ||||
|   return 1; | ||||
| } | ||||
|  | ||||
| static int torch_Tensor_(sub)(lua_State *L) | ||||
| { | ||||
|   THTensor *tensor = luaT_checkudata(L, 1, torch_Tensor_id); | ||||
|   long d0s = -1, d0e = -1, d1s = -1, d1e = -1, d2s = -1, d2e = -1, d3s = -1, d3e = -1; | ||||
|  | ||||
|   d0s = luaL_checklong(L, 2)-1; | ||||
|   d0e = luaL_checklong(L, 3)-1; | ||||
|   if(d0s < 0) | ||||
|     d0s += tensor->size[0]+1; | ||||
|   if(d0e < 0) | ||||
|     d0e += tensor->size[0]+1; | ||||
|   luaL_argcheck(L, tensor->nDimension > 0, 2, "invalid dimension"); | ||||
|   luaL_argcheck(L, d0s >= 0 && d0s < tensor->size[0], 2, "out of range"); | ||||
|   luaL_argcheck(L, d0e >= 0 && d0e < tensor->size[0], 3, "out of range"); | ||||
|   luaL_argcheck(L, d0e >= d0s, 3, "end smaller than beginning"); | ||||
|  | ||||
|   if(!lua_isnone(L, 4)) | ||||
|   { | ||||
|     d1s = luaL_checklong(L, 4)-1; | ||||
|     d1e = luaL_checklong(L, 5)-1; | ||||
|     if(d1s < 0) | ||||
|       d1s += tensor->size[1]+1; | ||||
|     if(d1e < 0) | ||||
|       d1e += tensor->size[1]+1; | ||||
|     luaL_argcheck(L, tensor->nDimension > 1, 4, "invalid dimension"); | ||||
|     luaL_argcheck(L, d1s >= 0 && d1s < tensor->size[1], 4, "out of range"); | ||||
|     luaL_argcheck(L, d1e >= 0 && d1e < tensor->size[1], 5, "out of range");     | ||||
|     luaL_argcheck(L, d1e >= d1s, 5, "end smaller than beginning"); | ||||
|  | ||||
|     if(!lua_isnone(L, 6)) | ||||
|     { | ||||
|       d2s = luaL_checklong(L, 6)-1; | ||||
|       d2e = luaL_checklong(L, 7)-1; | ||||
|       if(d2s < 0) | ||||
|         d2s += tensor->size[2]+1; | ||||
|       if(d2e < 0) | ||||
|         d2e += tensor->size[2]+1; | ||||
|       luaL_argcheck(L, tensor->nDimension > 2, 6, "invalid dimension"); | ||||
|       luaL_argcheck(L, d2s >= 0 && d2s < tensor->size[2], 6, "out of range"); | ||||
|       luaL_argcheck(L, d2e >= 0 && d2e < tensor->size[2], 7, "out of range");     | ||||
|       luaL_argcheck(L, d2e >= d2s, 7, "end smaller than beginning"); | ||||
|  | ||||
|       if(!lua_isnone(L, 8)) | ||||
|       { | ||||
|         d3s = luaL_checklong(L, 8)-1; | ||||
|         d3e = luaL_checklong(L, 9)-1; | ||||
|         if(d3s < 0) | ||||
|           d3s += tensor->size[3]+1; | ||||
|         if(d3e < 0) | ||||
|           d3e += tensor->size[3]+1; | ||||
|         luaL_argcheck(L, tensor->nDimension > 3, 8, "invalid dimension"); | ||||
|         luaL_argcheck(L, d3s >= 0 && d3s < tensor->size[3], 8, "out of range"); | ||||
|         luaL_argcheck(L, d3e >= 0 && d3e < tensor->size[3], 9, "out of range");     | ||||
|         luaL_argcheck(L, d3e >= d3s, 9, "end smaller than beginning"); | ||||
|       } | ||||
|     } | ||||
|   } | ||||
|  | ||||
|   tensor = THTensor_(newWithTensor)(tensor); | ||||
|   THTensor_(narrow)(tensor, NULL, 0, d0s, d0e-d0s+1); | ||||
|   if(d1s >= 0) | ||||
|     THTensor_(narrow)(tensor, NULL, 1, d1s, d1e-d1s+1); | ||||
|   if(d2s >= 0) | ||||
|     THTensor_(narrow)(tensor, NULL, 2, d2s, d2e-d2s+1); | ||||
|   if(d3s >= 0) | ||||
|     THTensor_(narrow)(tensor, NULL, 3, d3s, d3e-d3s+1); | ||||
|   luaT_pushudata(L, tensor, torch_Tensor_id); | ||||
|   return 1; | ||||
| } | ||||
|  | ||||
| static int torch_Tensor_(select)(lua_State *L) | ||||
| { | ||||
|   THTensor *tensor = luaT_checkudata(L, 1, torch_Tensor_id); | ||||
|   int dimension = luaL_checkint(L, 2)-1; | ||||
|   long sliceIndex = luaL_checklong(L, 3)-1; | ||||
|  | ||||
| /*   THArgCheck(src->nDimension > 1, 1, "cannot select on a vector"); | ||||
|   THArgCheck((dimension >= 0) && (dimension < src->nDimension), 2, "out of range"); | ||||
|   THArgCheck((sliceIndex >= 0) && (sliceIndex < src->size[dimension]), 3, "out of range"); | ||||
| */ | ||||
|  | ||||
|   if(tensor->nDimension > 1) | ||||
|   { | ||||
|     tensor = THTensor_(newWithTensor)(tensor); | ||||
|     THTensor_(select)(tensor, NULL, dimension, sliceIndex); | ||||
|     luaT_pushudata(L, tensor, torch_Tensor_id); | ||||
|   } | ||||
|   else | ||||
|   { | ||||
|     THArgCheck(tensor->nDimension == 1, 1, "empty Tensor"); | ||||
|     lua_pushnumber(L, THTensor_(get1d)(tensor, sliceIndex)); | ||||
|   } | ||||
|  | ||||
|   return 1; | ||||
| } | ||||
|  | ||||
|  | ||||
| static int torch_Tensor_(transpose)(lua_State *L) | ||||
| { | ||||
|   THTensor *tensor = luaT_checkudata(L, 1, torch_Tensor_id); | ||||
|   int dimension1 = luaL_checkint(L, 2)-1; | ||||
|   int dimension2 = luaL_checkint(L, 3)-1; | ||||
|  | ||||
| /* | ||||
|   THArgCheck( (dimension1 >= 0) && (dimension1 < src->nDimension), 2, "out of range"); | ||||
|   THArgCheck( (dimension2 >= 0) && (dimension2 < src->nDimension), 3, "out of range"); | ||||
| */ | ||||
|  | ||||
|   tensor = THTensor_(newWithTensor)(tensor); | ||||
|   THTensor_(transpose)(tensor, NULL, dimension1, dimension2); | ||||
|   luaT_pushudata(L, tensor, torch_Tensor_id); | ||||
|   return 1; | ||||
| } | ||||
|  | ||||
| static int torch_Tensor_(t)(lua_State *L) | ||||
| { | ||||
|   THTensor *tensor = luaT_checkudata(L, 1, torch_Tensor_id); | ||||
|  | ||||
|   luaL_argcheck(L, tensor->nDimension == 2, 1, "Tensor must have 2 dimensions"); | ||||
|  | ||||
|   tensor = THTensor_(newWithTensor)(tensor); | ||||
|   THTensor_(transpose)(tensor, NULL, 0, 1); | ||||
|   luaT_pushudata(L, tensor, torch_Tensor_id); | ||||
|   return 1; | ||||
| } | ||||
|  | ||||
| static int torch_Tensor_(unfold)(lua_State *L) | ||||
| { | ||||
|   THTensor *tensor = luaT_checkudata(L, 1, torch_Tensor_id); | ||||
|   int dimension = luaL_checkint(L, 2)-1; | ||||
|   long size = luaL_checklong(L, 3); | ||||
|   long step = luaL_checklong(L, 4); | ||||
|  | ||||
| /* | ||||
|   THArgCheck( (src->nDimension > 0), 1, "cannot unfold an empty tensor"); | ||||
|   THArgCheck(dimension < src->nDimension, 2, "out of range"); | ||||
|   THArgCheck(size <= src->size[dimension], 3, "out of range"); | ||||
| */ | ||||
|  | ||||
|   tensor = THTensor_(newWithTensor)(tensor); | ||||
|   THTensor_(unfold)(tensor, NULL, dimension, size, step); | ||||
|   luaT_pushudata(L, tensor, torch_Tensor_id); | ||||
|   return 1; | ||||
| } | ||||
|  | ||||
| /* is contiguous? [a bit like in TnXIterator] */ | ||||
| static int torch_Tensor_(isContiguous)(lua_State *L) | ||||
| { | ||||
|   THTensor *tensor = luaT_checkudata(L, 1, torch_Tensor_id); | ||||
|   lua_pushboolean(L, THTensor_(isContiguous)(tensor)); | ||||
|   return 1; | ||||
| } | ||||
|  | ||||
| static int torch_Tensor_(nElement)(lua_State *L) | ||||
| { | ||||
|   THTensor *tensor = luaT_checkudata(L, 1, torch_Tensor_id); | ||||
|   lua_pushnumber(L, THTensor_(nElement)(tensor)); | ||||
|   return 1; | ||||
| } | ||||
|  | ||||
| static int torch_Tensor_(copy)(lua_State *L) | ||||
| { | ||||
|   THTensor *tensor = luaT_checkudata(L, 1, torch_Tensor_id); | ||||
|   void *src; | ||||
|   if( (src = luaT_toudata(L, 2, torch_Tensor_id)) ) | ||||
|     THTensor_(copy)(tensor, src); | ||||
|   else if( (src = luaT_toudata(L, 2, torch_ByteTensor_id)) ) | ||||
|     THTensor_(copyByte)(tensor, src); | ||||
|   else if( (src = luaT_toudata(L, 2, torch_CharTensor_id)) ) | ||||
|     THTensor_(copyChar)(tensor, src); | ||||
|   else if( (src = luaT_toudata(L, 2, torch_ShortTensor_id)) ) | ||||
|     THTensor_(copyShort)(tensor, src); | ||||
|   else if( (src = luaT_toudata(L, 2, torch_IntTensor_id)) ) | ||||
|     THTensor_(copyInt)(tensor, src); | ||||
|   else if( (src = luaT_toudata(L, 2, torch_LongTensor_id)) ) | ||||
|     THTensor_(copyLong)(tensor, src); | ||||
|   else if( (src = luaT_toudata(L, 2, torch_FloatTensor_id)) ) | ||||
|     THTensor_(copyFloat)(tensor, src); | ||||
|   else if( (src = luaT_toudata(L, 2, torch_DoubleTensor_id)) ) | ||||
|     THTensor_(copyDouble)(tensor, src); | ||||
|   else | ||||
|     luaL_typerror(L, 2, "torch.*Tensor"); | ||||
|   lua_settop(L, 1); | ||||
|   return 1; | ||||
| } | ||||
|  | ||||
| static int torch_Tensor_(__newindex__)(lua_State *L) | ||||
| { | ||||
|   THTensor *tensor = luaT_checkudata(L, 1, torch_Tensor_id); | ||||
|   THLongStorage *idx = NULL; | ||||
|  | ||||
|   if(lua_isnumber(L, 2)) | ||||
|   { | ||||
|     long index = luaL_checklong(L,2)-1; | ||||
|     void *src; | ||||
|     if (lua_isnumber(L,3)) { | ||||
|       real value = (real)luaL_checknumber(L,3); | ||||
|       luaL_argcheck(L, tensor->nDimension == 1, 1, "must be a one dimensional tensor"); | ||||
|       luaL_argcheck(L, index >= 0 && index < tensor->size[0], 2, "out of range"); | ||||
|       THStorage_(set)(tensor->storage, tensor->storageOffset+index*tensor->stride[0], value); | ||||
|     } else if( (src = luaT_toudata(L, 3, torch_Tensor_id)) ) { | ||||
|       tensor = THTensor_(newWithTensor)(tensor); | ||||
|       THTensor_(narrow)(tensor, NULL, 0, index, 1); | ||||
|       THTensor_(copy)(tensor, src); | ||||
|     } else if( (src = luaT_toudata(L, 3, torch_ByteTensor_id)) ) { | ||||
|       tensor = THTensor_(newWithTensor)(tensor); | ||||
|       THTensor_(narrow)(tensor, NULL, 0, index, 1); | ||||
|       THTensor_(copyByte)(tensor, src); | ||||
|     } else if( (src = luaT_toudata(L, 3, torch_CharTensor_id)) ) { | ||||
|       tensor = THTensor_(newWithTensor)(tensor); | ||||
|       THTensor_(narrow)(tensor, NULL, 0, index, 1); | ||||
|       THTensor_(copyChar)(tensor, src); | ||||
|     } else if( (src = luaT_toudata(L, 3, torch_ShortTensor_id)) ) { | ||||
|       tensor = THTensor_(newWithTensor)(tensor); | ||||
|       THTensor_(narrow)(tensor, NULL, 0, index, 1); | ||||
|       THTensor_(copyShort)(tensor, src); | ||||
|     } else if( (src = luaT_toudata(L, 3, torch_IntTensor_id)) ) { | ||||
|       tensor = THTensor_(newWithTensor)(tensor); | ||||
|       THTensor_(narrow)(tensor, NULL, 0, index, 1); | ||||
|       THTensor_(copyInt)(tensor, src); | ||||
|     } else if( (src = luaT_toudata(L, 3, torch_LongTensor_id)) ) { | ||||
|       tensor = THTensor_(newWithTensor)(tensor); | ||||
|       THTensor_(narrow)(tensor, NULL, 0, index, 1); | ||||
|       THTensor_(copyLong)(tensor, src); | ||||
|     } else if( (src = luaT_toudata(L, 3, torch_FloatTensor_id)) ) { | ||||
|       tensor = THTensor_(newWithTensor)(tensor); | ||||
|       THTensor_(narrow)(tensor, NULL, 0, index, 1); | ||||
|       THTensor_(copyFloat)(tensor, src); | ||||
|     } else { | ||||
|       luaL_typerror(L, 3, "torch.*Tensor"); | ||||
|     } | ||||
|     lua_pushboolean(L, 1); | ||||
|   } | ||||
|   else if((idx = luaT_toudata(L, 2, torch_LongStorage_id))) | ||||
|   { | ||||
|     long index = THTensor_(storageOffset)(tensor); | ||||
|     real value = (real)luaL_checknumber(L,3); | ||||
|     int dim; | ||||
|  | ||||
|     luaL_argcheck(L, idx->size == tensor->nDimension, 2, "invalid size"); | ||||
|  | ||||
|     for(dim = 0; dim < idx->size; dim++) | ||||
|     { | ||||
|       long z = idx->data[dim]-1; | ||||
|       luaL_argcheck(L, (z >= 0) && (z < tensor->size[dim]), 2, "index out of bound"); | ||||
|       index += z*tensor->stride[dim]; | ||||
|     } | ||||
|  | ||||
|     THStorage_(set)(tensor->storage, index, value); | ||||
|     lua_pushboolean(L, 1); | ||||
|   } | ||||
|   else if(lua_istable(L, 2)) | ||||
|   { | ||||
|     long index = THTensor_(storageOffset)(tensor); | ||||
|     real value = (real)luaL_checknumber(L,3); | ||||
|     int dim; | ||||
|  | ||||
|     luaL_argcheck(L, lua_objlen(L,2) == tensor->nDimension, 2, "invalid size"); | ||||
|  | ||||
|     for(dim = 0; dim < tensor->nDimension; dim++) | ||||
|     { | ||||
|       long z; | ||||
|  | ||||
|       lua_rawgeti(L, 2, dim+1); | ||||
|       if(!lua_isnumber(L, -1)) | ||||
|         luaL_error(L, "number expected for each dimension"); | ||||
|  | ||||
|       z = lua_tonumber(L, -1)-1; | ||||
|       lua_pop(L, 1); | ||||
|  | ||||
|       luaL_argcheck(L, (z >= 0) && (z < tensor->size[dim]), 2, "index out of bound"); | ||||
|       index += z*tensor->stride[dim]; | ||||
|     } | ||||
|     THStorage_(set)(tensor->storage, index, value); | ||||
|     lua_pushboolean(L, 1); | ||||
|   } | ||||
|   else | ||||
|     lua_pushboolean(L, 0); | ||||
|  | ||||
|   return 1; | ||||
| } | ||||
|  | ||||
| static int torch_Tensor_(__index__)(lua_State *L) | ||||
| { | ||||
|   THTensor *tensor = luaT_checkudata(L, 1, torch_Tensor_id); | ||||
|   THLongStorage *idx = NULL; | ||||
|  | ||||
|   if(lua_isnumber(L, 2)) | ||||
|   { | ||||
|     long index = luaL_checklong(L,2)-1; | ||||
|      | ||||
|     luaL_argcheck(L, tensor->nDimension > 0, 1, "empty tensor"); | ||||
|     luaL_argcheck(L, index >= 0 && index < tensor->size[0], 2, "out of range"); | ||||
|  | ||||
|     if(tensor->nDimension == 1) | ||||
|     { | ||||
|       lua_pushnumber(L, THStorage_(get)(tensor->storage, tensor->storageOffset+index*tensor->stride[0])); | ||||
|     } | ||||
|     else | ||||
|     { | ||||
|       tensor = THTensor_(newWithTensor)(tensor); | ||||
|       THTensor_(select)(tensor, NULL, 0, index); | ||||
|       luaT_pushudata(L, tensor, torch_Tensor_id); | ||||
|     } | ||||
|     lua_pushboolean(L, 1); | ||||
|     return 2; | ||||
|   } | ||||
|   else if((idx = luaT_toudata(L, 2, torch_LongStorage_id))) | ||||
|   { | ||||
|     long index = THTensor_(storageOffset)(tensor); | ||||
|     int dim; | ||||
|  | ||||
|     luaL_argcheck(L, idx->size == tensor->nDimension, 2, "invalid size"); | ||||
|      | ||||
|     for(dim = 0; dim < idx->size; dim++) | ||||
|     { | ||||
|       long z = idx->data[dim]-1; | ||||
|       luaL_argcheck(L, (z >= 0) && (z < tensor->size[dim]), 2, "index out of bound"); | ||||
|       index += z*tensor->stride[dim]; | ||||
|     } | ||||
|     lua_pushnumber(L, (double)THStorage_(get)(THTensor_(storage)(tensor), index)); | ||||
|     lua_pushboolean(L, 1); | ||||
|     return 2; | ||||
|   } | ||||
|   else if(lua_istable(L, 2)) | ||||
|   { | ||||
|     long index = THTensor_(storageOffset)(tensor); | ||||
|     int dim; | ||||
|  | ||||
|     luaL_argcheck(L, lua_objlen(L,2) == tensor->nDimension, 2, "invalid size"); | ||||
|      | ||||
|     for(dim = 0; dim < tensor->nDimension; dim++) | ||||
|     { | ||||
|       long z; | ||||
|  | ||||
|       lua_rawgeti(L, 2, dim+1); | ||||
|       if(!lua_isnumber(L, -1)) | ||||
|         luaL_error(L, "number expected for each dimension"); | ||||
|  | ||||
|       z = lua_tonumber(L, -1)-1; | ||||
|       lua_pop(L, 1); | ||||
|  | ||||
|       luaL_argcheck(L, (z >= 0) && (z < tensor->size[dim]), 2, "index out of bound"); | ||||
|       index += z*tensor->stride[dim]; | ||||
|     } | ||||
|     lua_pushnumber(L, (double)THStorage_(get)(THTensor_(storage)(tensor), index)); | ||||
|     lua_pushboolean(L, 1); | ||||
|     return 2; | ||||
|   } | ||||
|   else | ||||
|   { | ||||
|     lua_pushboolean(L, 0); | ||||
|     return 1; | ||||
|   } | ||||
| } | ||||
|  | ||||
| static int torch_Tensor_(free)(lua_State *L) | ||||
| { | ||||
|   THTensor *tensor = luaT_checkudata(L, 1, torch_Tensor_id); | ||||
|   THTensor_(free)(tensor); | ||||
|   return 0; | ||||
| } | ||||
|  | ||||
| /* helpful functions */ | ||||
| static void torch_Tensor_(c_readSizeStride)(lua_State *L, int index, int allowStride, THLongStorage **size_, THLongStorage **stride_) | ||||
| { | ||||
|   THLongStorage *size = NULL; | ||||
|   THLongStorage *stride = NULL; | ||||
|    | ||||
|   if( (size = luaT_toudata(L, index, torch_LongStorage_id)) ) | ||||
|   { | ||||
|     if(!lua_isnoneornil(L, index+1)) | ||||
|     { | ||||
|       if( (stride = luaT_toudata(L, index+1, torch_LongStorage_id)) ) | ||||
|         luaL_argcheck(L, stride->size == size->size, index+1, "provided stride and size are inconsistent"); | ||||
|       else | ||||
|         luaL_argcheck(L, 0, index+1, "torch.LongStorage expected"); | ||||
|     } | ||||
|     THLongStorage_retain(size); | ||||
|     if(stride) | ||||
|       THLongStorage_retain(stride); | ||||
|   } | ||||
|   else | ||||
|   { | ||||
|     int i; | ||||
|  | ||||
|     size = THLongStorage_newWithSize(8); | ||||
|     stride = THLongStorage_newWithSize(8); | ||||
|     THLongStorage_fill(size, -1); | ||||
|     THLongStorage_fill(stride, -1); | ||||
|  | ||||
|     if(allowStride) | ||||
|     { | ||||
|       for(i = 0; i < 8; i++) | ||||
|       { | ||||
|         if(lua_isnone(L, index+2*i)) | ||||
|           break; | ||||
|         size->data[i] = luaL_checklong(L, index+2*i); | ||||
|          | ||||
|         if(lua_isnone(L, index+2*i+1)) | ||||
|           break; | ||||
|         stride->data[i] = luaL_checklong(L, index+2*i+1); | ||||
|       } | ||||
|     } | ||||
|     else | ||||
|     { | ||||
|       for(i = 0; i < 8; i++) | ||||
|       { | ||||
|         if(lua_isnone(L, index+i)) | ||||
|           break; | ||||
|         size->data[i] = luaL_checklong(L, index+i); | ||||
|       } | ||||
|     } | ||||
|   } | ||||
|  | ||||
|   *size_ = size; | ||||
|   *stride_ = stride; | ||||
| } | ||||
|  | ||||
| static void torch_Tensor_(c_readTensorStorageSizeStride)(lua_State *L, int index, int allowNone, int allowTensor, int allowStorage, int allowStride, | ||||
|                                                          THStorage **storage_, long *storageOffset_, THLongStorage **size_, THLongStorage **stride_) | ||||
| { | ||||
|   static char errMsg[64]; | ||||
|   THTensor *src = NULL; | ||||
|   THStorage *storage = NULL; | ||||
|  | ||||
|   int arg1Type = lua_type(L, index); | ||||
|  | ||||
|   if( allowNone && (arg1Type == LUA_TNONE) ) | ||||
|   { | ||||
|     *storage_ = NULL; | ||||
|     *storageOffset_ = 0; | ||||
|     *size_ = NULL; | ||||
|     *stride_ = NULL; | ||||
|     return; | ||||
|   } | ||||
|   else if( allowTensor && (arg1Type == LUA_TUSERDATA) && (src = luaT_toudata(L, index, torch_Tensor_id)) ) | ||||
|   { | ||||
|     *storage_ = src->storage; | ||||
|     *storageOffset_ = src->storageOffset; | ||||
|     *size_ = THTensor_(newSizeOf)(src); | ||||
|     *stride_ = THTensor_(newStrideOf)(src); | ||||
|     return; | ||||
|   } | ||||
|   else if( allowStorage && (arg1Type == LUA_TUSERDATA) && (storage = luaT_toudata(L, index, torch_Storage_id)) ) | ||||
|   { | ||||
|     *storage_ = storage; | ||||
|     if(lua_isnone(L, index+1)) | ||||
|     { | ||||
|       *storageOffset_ = 0; | ||||
|       *size_ = THLongStorage_newWithSize1(storage->size); | ||||
|       *stride_ = THLongStorage_newWithSize1(1); | ||||
|     } | ||||
|     else | ||||
|     { | ||||
|       *storageOffset_ = luaL_checklong(L, index+1)-1; | ||||
|       torch_Tensor_(c_readSizeStride)(L, index+2, allowStride, size_, stride_); | ||||
|     } | ||||
|     return; | ||||
|   } | ||||
|   else if( (arg1Type == LUA_TNUMBER) || (luaT_toudata(L, index, torch_LongStorage_id)) ) | ||||
|   { | ||||
|     *storage_ = NULL; | ||||
|     *storageOffset_ = 0; | ||||
|     torch_Tensor_(c_readSizeStride)(L, index, 0, size_, stride_); | ||||
|  | ||||
|     return; | ||||
|   } | ||||
|  | ||||
|   *storage_ = NULL; | ||||
|   *storageOffset_ = 0; | ||||
|  | ||||
|   sprintf(errMsg, "expecting number%s%s", (allowTensor ? " or Tensor" : ""), (allowStorage ? " or Storage" : "")); | ||||
|   luaL_argcheck(L, 0, index, errMsg); | ||||
| } | ||||
|  | ||||
| static int torch_Tensor_(apply)(lua_State *L) | ||||
| { | ||||
|   THTensor *tensor = luaT_checkudata(L, 1, torch_Tensor_id); | ||||
|   luaL_checktype(L, 2, LUA_TFUNCTION); | ||||
|   lua_settop(L, 2); | ||||
|  | ||||
|   TH_TENSOR_APPLY(real, tensor, | ||||
|                   lua_pushvalue(L, 2); | ||||
|                   lua_pushnumber(L, *tensor_data); | ||||
|                   lua_call(L, 1, 1); | ||||
|                   if(lua_isnumber(L, 3)) | ||||
|                   { | ||||
|                     *tensor_data = (real)lua_tonumber(L, 3); | ||||
|                     lua_pop(L, 1); | ||||
|                   } | ||||
|                   else if(lua_isnil(L, 3)) | ||||
|                     lua_pop(L, 1); | ||||
|                   else | ||||
|                     luaL_error(L, "given function should return a number or nil");); | ||||
|  | ||||
|   lua_settop(L, 1); | ||||
|   return 1; | ||||
| } | ||||
|  | ||||
| static int torch_Tensor_(map)(lua_State *L) | ||||
| { | ||||
|   THTensor *tensor = luaT_checkudata(L, 1, torch_Tensor_id); | ||||
|   THTensor *src = luaT_checkudata(L, 2, torch_Tensor_id); | ||||
|   luaL_checktype(L, 3, LUA_TFUNCTION); | ||||
|   lua_settop(L, 3); | ||||
|  | ||||
|   TH_TENSOR_APPLY2(real, tensor, real, src, | ||||
|                   lua_pushvalue(L, 3); | ||||
|                   lua_pushnumber(L, *tensor_data); | ||||
|                   lua_pushnumber(L, *src_data); | ||||
|                   lua_call(L, 2, 1); | ||||
|                   if(lua_isnumber(L, 4)) | ||||
|                   { | ||||
|                     *tensor_data = (real)lua_tonumber(L, 4); | ||||
|                     lua_pop(L, 1); | ||||
|                   } | ||||
|                   else if(lua_isnil(L, 4)) | ||||
|                     lua_pop(L, 1); | ||||
|                   else | ||||
|                     luaL_error(L, "given function should return a number or nil");); | ||||
|  | ||||
|   lua_settop(L, 1); | ||||
|   return 1; | ||||
| } | ||||
|  | ||||
| static int torch_Tensor_(map2)(lua_State *L) | ||||
| { | ||||
|   THTensor *tensor = luaT_checkudata(L, 1, torch_Tensor_id); | ||||
|   THTensor *src1 = luaT_checkudata(L, 2, torch_Tensor_id); | ||||
|   THTensor *src2 = luaT_checkudata(L, 3, torch_Tensor_id); | ||||
|   luaL_checktype(L, 4, LUA_TFUNCTION); | ||||
|   lua_settop(L, 4); | ||||
|  | ||||
|   TH_TENSOR_APPLY3(real, tensor, real, src1, real, src2, | ||||
|                   lua_pushvalue(L, 4); | ||||
|                   lua_pushnumber(L, *tensor_data); | ||||
|                   lua_pushnumber(L, *src1_data); | ||||
|                   lua_pushnumber(L, *src2_data); | ||||
|                   lua_call(L, 3, 1); | ||||
|                   if(lua_isnumber(L, 5)) | ||||
|                   { | ||||
|                     *tensor_data = (real)lua_tonumber(L, 5); | ||||
|                     lua_pop(L, 1); | ||||
|                   } | ||||
|                   else if(lua_isnil(L, 5)) | ||||
|                     lua_pop(L, 1); | ||||
|                   else | ||||
|                     luaL_error(L, "given function should return a number or nothing");); | ||||
|  | ||||
|   lua_settop(L, 1); | ||||
|   return 1; | ||||
| } | ||||
|  | ||||
| static int torch_Tensor_(factory)(lua_State *L) | ||||
| { | ||||
|   THTensor *tensor = THTensor_(new)(); | ||||
|   luaT_pushudata(L, tensor, torch_Tensor_id); | ||||
|   return 1; | ||||
| } | ||||
|  | ||||
| static int torch_Tensor_(write)(lua_State *L) | ||||
| { | ||||
|   THTensor *tensor = luaT_checkudata(L, 1, torch_Tensor_id); | ||||
|   THFile *file = luaT_checkudata(L, 2, torch_File_id); | ||||
|  | ||||
|   THFile_writeIntScalar(file, tensor->nDimension); | ||||
|   THFile_writeLongRaw(file, tensor->size, tensor->nDimension); | ||||
|   THFile_writeLongRaw(file, tensor->stride, tensor->nDimension); | ||||
|   THFile_writeLongScalar(file, tensor->storageOffset+1); /* to respect Lua convention */ | ||||
|  | ||||
|   lua_getfield(L, 2, "writeObject"); /* the method */ | ||||
|   lua_pushvalue(L, 2); /* the file */ | ||||
|   /* the storage */ | ||||
|   if(tensor->storage) | ||||
|   { | ||||
|     THStorage_(retain)(tensor->storage); | ||||
|     luaT_pushudata(L, tensor->storage, torch_Storage_id); | ||||
|   } | ||||
|   else | ||||
|     lua_pushnil(L); | ||||
|  | ||||
|   lua_call(L, 2, 0); /* call the method */ | ||||
|  | ||||
|   return 0; | ||||
| } | ||||
|  | ||||
| static int torch_Tensor_(read)(lua_State *L) | ||||
| { | ||||
|   THTensor *tensor = luaT_checkudata(L, 1, torch_Tensor_id); | ||||
|   THFile *file = luaT_checkudata(L, 2, torch_File_id); | ||||
|  | ||||
|   tensor->nDimension = THFile_readIntScalar(file); | ||||
|   tensor->size = THAlloc(sizeof(long)*tensor->nDimension); | ||||
|   tensor->stride = THAlloc(sizeof(long)*tensor->nDimension); | ||||
|   THFile_readLongRaw(file, tensor->size, tensor->nDimension); | ||||
|   THFile_readLongRaw(file, tensor->stride, tensor->nDimension); | ||||
|   tensor->storageOffset = THFile_readLongScalar(file); | ||||
|   tensor->storageOffset--;  /* to respect Lua convention */ | ||||
|  | ||||
|   lua_getfield(L, 2, "readObject"); /* the method */ | ||||
|   lua_pushvalue(L, 2); /* the file */ | ||||
|   lua_call(L, 1, 1); /* call the method */ | ||||
|  | ||||
|   tensor->storage = luaT_toudata(L, -1, torch_Storage_id); | ||||
|   if(tensor->storage) | ||||
|     THStorage_(retain)(tensor->storage); | ||||
|  | ||||
|   return 0; | ||||
| } | ||||
|  | ||||
| static const struct luaL_Reg torch_Tensor_(_) [] = { | ||||
|   {"contiguous", torch_Tensor_(contiguous)}, | ||||
|   {"size", torch_Tensor_(size)}, | ||||
|   {"__len__", torch_Tensor_(size)}, | ||||
|   {"stride", torch_Tensor_(stride)}, | ||||
|   {"dim", torch_Tensor_(nDimension)}, | ||||
|   {"nDimension", torch_Tensor_(nDimension)}, | ||||
|   {"set", torch_Tensor_(set)}, | ||||
|   {"storage", torch_Tensor_(storage)}, | ||||
|   {"storageOffset", torch_Tensor_(storageOffset)}, | ||||
|   {"clone", torch_Tensor_(clone)}, | ||||
|   {"contiguous", torch_Tensor_(contiguous)}, | ||||
|   {"resizeAs", torch_Tensor_(resizeAs)}, | ||||
|   {"resize", torch_Tensor_(resize)}, | ||||
|   {"narrow", torch_Tensor_(narrow)}, | ||||
|   {"sub", torch_Tensor_(sub)}, | ||||
|   {"select", torch_Tensor_(select)}, | ||||
|   {"transpose", torch_Tensor_(transpose)}, | ||||
|   {"t", torch_Tensor_(t)}, | ||||
|   {"unfold", torch_Tensor_(unfold)}, | ||||
|   {"isContiguous", torch_Tensor_(isContiguous)}, | ||||
|   {"nElement", torch_Tensor_(nElement)}, | ||||
|   {"copy", torch_Tensor_(copy)}, | ||||
|   {"apply", torch_Tensor_(apply)}, | ||||
|   {"map", torch_Tensor_(map)}, | ||||
|   {"map2", torch_Tensor_(map2)}, | ||||
|   {"read", torch_Tensor_(read)}, | ||||
|   {"write", torch_Tensor_(write)}, | ||||
|   {"__index__", torch_Tensor_(__index__)}, | ||||
|   {"__newindex__", torch_Tensor_(__newindex__)}, | ||||
|   {NULL, NULL} | ||||
| }; | ||||
|  | ||||
| void torch_Tensor_(init)(lua_State *L) | ||||
| { | ||||
|   torch_File_id = luaT_checktypename2id(L, "torch.File"); | ||||
|   torch_LongStorage_id = luaT_checktypename2id(L, "torch.LongStorage"); | ||||
|   torch_Storage_id = luaT_checktypename2id(L, STRING_torchStorage); | ||||
|  | ||||
|   torch_Tensor_id = luaT_newmetatable(L, STRING_torchTensor, NULL, | ||||
|                                  torch_Tensor_(new), torch_Tensor_(free), torch_Tensor_(factory)); | ||||
|   luaL_register(L, NULL, torch_Tensor_(_)); | ||||
|   lua_pop(L, 1); | ||||
| } | ||||
|  | ||||
| #endif | ||||
							
								
								
									
										175
									
								
								generic/TensorConv.c
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										175
									
								
								generic/TensorConv.c
									
									
									
									
									
										Normal file
									
								
							| @ -0,0 +1,175 @@ | ||||
| #ifndef TH_GENERIC_FILE | ||||
| #define TH_GENERIC_FILE "generic/TensorConv.c" | ||||
| #else | ||||
|  | ||||
| static int torch_(convxcorr2)(lua_State *L, const char* ktype) | ||||
| { | ||||
|   int narg = lua_gettop(L); | ||||
|   THTensor *r_ = NULL; | ||||
|   THTensor *im = NULL; | ||||
|   THTensor *ker = NULL; | ||||
|   char type[2]; | ||||
|   int rgiven = 0; | ||||
|  | ||||
|   type[0] = 'v'; | ||||
|   type[1] = ktype[0]; | ||||
|  | ||||
|   if (narg == 2 | ||||
|       && (ker = luaT_toudata(L,2,torch_(Tensor_id))) | ||||
|       && (im  = luaT_toudata(L,1,torch_(Tensor_id)))) | ||||
|   { | ||||
|   } | ||||
|   else if (narg == 3 | ||||
| 	   && (lua_type(L,3) == LUA_TSTRING) | ||||
| 	   && (ker = luaT_toudata(L,2,torch_(Tensor_id))) | ||||
| 	   && (im = luaT_toudata(L,1,torch_(Tensor_id)))) | ||||
|   { | ||||
|     type[0] = *(luaL_checkstring(L,3)); | ||||
|     luaL_argcheck(L, (type[0] == 'v' || type[0] == 'V' || type[0] == 'f' || type[0] == 'F'), | ||||
| 		  3, "[Tensor, ] Tensor, Tensor [, x or c]"); | ||||
|     if (type[0] == 'V') type[0] = 'v'; | ||||
|     if (type[0] == 'F') type[0] = 'f'; | ||||
|   } | ||||
|   else if (narg == 4 | ||||
| 	   && (type[0] = *(luaL_checkstring(L,4))) | ||||
| 	   && (ker = luaT_toudata(L,3,torch_(Tensor_id))) | ||||
| 	   && (im = luaT_toudata(L,2,torch_(Tensor_id))) | ||||
| 	   && (r_ = luaT_toudata(L,1,torch_(Tensor_id)))) | ||||
|   { | ||||
|     rgiven = 1; | ||||
|   } | ||||
|   else | ||||
|   { | ||||
|     luaL_error(L,"[Tensor, ] Tensor, Tensor [, x or c]"); | ||||
|   } | ||||
|    | ||||
|   if (!r_) r_ = THTensor_(new)(); | ||||
|  | ||||
|   if (im->nDimension == 2 && ker->nDimension == 2) | ||||
|   { | ||||
|     THTensor_(conv2Dmul)(r_,0.0,1.0,im,ker,1,1,type); | ||||
|   } | ||||
|   else if (im->nDimension == 3 && ker->nDimension == 3) | ||||
|   { | ||||
|     THTensor_(conv2Dcmul)(r_,0.0,1.0,im,ker,1,1,type); | ||||
|   } | ||||
|   else if (im->nDimension == 3 && ker->nDimension == 4) | ||||
|   { | ||||
|     THTensor_(conv2Dmv)(r_,0.0,1.0,im,ker,1,1,type); | ||||
|   } | ||||
|   else | ||||
|   { | ||||
|     luaL_error(L," (2D,2D) or (3D,3D) or (3D,4D) "); | ||||
|   } | ||||
|  | ||||
|   pushreturn(rgiven, r_, torch_(Tensor_id)); | ||||
|  | ||||
|   return 1; | ||||
| } | ||||
|  | ||||
| static int torch_(convxcorr3)(lua_State *L, char* ktype) | ||||
| { | ||||
|   int narg = lua_gettop(L); | ||||
|   THTensor *r_ = NULL; | ||||
|   THTensor *im = NULL; | ||||
|   THTensor *ker = NULL; | ||||
|   char type[2]; | ||||
|   int rgiven = 0; | ||||
|    | ||||
|   type[0] = 'v'; | ||||
|   type[1] = ktype[0]; | ||||
|  | ||||
|   if (narg == 2 | ||||
|       && (ker = luaT_toudata(L,2,torch_(Tensor_id))) | ||||
|       && (im  = luaT_toudata(L,1,torch_(Tensor_id)))) | ||||
|   { | ||||
|   } | ||||
|   else if (narg == 3 | ||||
| 	   && (lua_type(L,3) == LUA_TSTRING) | ||||
| 	   && (ker = luaT_toudata(L,2,torch_(Tensor_id))) | ||||
| 	   && (im = luaT_toudata(L,1,torch_(Tensor_id)))) | ||||
|   { | ||||
|     type[0] = *(luaL_checkstring(L,3)); | ||||
|     luaL_argcheck(L, (type[0] == 'v' || type[0] == 'V' || type[0] == 'f' || type[0] == 'F'), | ||||
| 		  3, "[Tensor, ] Tensor, Tensor [, x or c]"); | ||||
|     if (type[0] == 'V') type[0] = 'v'; | ||||
|     if (type[0] == 'F') type[0] = 'f'; | ||||
|   } | ||||
|   else if (narg == 4 | ||||
| 	   && (type[0] = *(luaL_checkstring(L,4))) | ||||
| 	   && (ker = luaT_toudata(L,3,torch_(Tensor_id))) | ||||
| 	   && (im = luaT_toudata(L,2,torch_(Tensor_id))) | ||||
| 	   && (r_ = luaT_toudata(L,1,torch_(Tensor_id)))) | ||||
|   { | ||||
|     rgiven = 1; | ||||
|   } | ||||
|   else | ||||
|   { | ||||
|     luaL_error(L,"[Tensor, ] Tensor, Tensor [, x or c]"); | ||||
|   } | ||||
|    | ||||
|   if (!r_) r_ = THTensor_(new)(); | ||||
|  | ||||
|   if (im->nDimension == 3 && ker->nDimension == 3) | ||||
|   { | ||||
|     THTensor_(conv3Dmul)(r_,0.0,1.0,im,ker,1,1,1,type); | ||||
|   } | ||||
|   else if (im->nDimension == 4 && ker->nDimension == 4) | ||||
|   { | ||||
|     THTensor_(conv3Dcmul)(r_,0.0,1.0,im,ker,1,1,1,type); | ||||
|   } | ||||
|   else if (im->nDimension == 4 && ker->nDimension == 5) | ||||
|   { | ||||
|     THTensor_(conv3Dmv)(r_,0.0,1.0,im,ker,1,1,1,type); | ||||
|   } | ||||
|   else | ||||
|   { | ||||
|     luaL_error(L," (3D,3D) or (4D,4D) or (4D,5D) "); | ||||
|   } | ||||
|  | ||||
|   pushreturn(rgiven, r_, torch_(Tensor_id)); | ||||
|  | ||||
|   return 1; | ||||
| } | ||||
|  | ||||
| static int torch_(conv2)(lua_State *L) | ||||
| { | ||||
|   return torch_(convxcorr2)(L,"convolution"); | ||||
| } | ||||
| static int torch_(xcorr2)(lua_State *L) | ||||
| { | ||||
|   return torch_(convxcorr2)(L,"xcorrelation"); | ||||
| } | ||||
|  | ||||
|  | ||||
| static int torch_(conv3)(lua_State *L) | ||||
| { | ||||
|   return torch_(convxcorr3)(L,"convolution"); | ||||
| } | ||||
| static int torch_(xcorr3)(lua_State *L) | ||||
| { | ||||
|   return torch_(convxcorr3)(L,"xcorrelation"); | ||||
| } | ||||
|  | ||||
| static const struct luaL_Reg torch_(Conv__) [] = { | ||||
|   {"conv2", torch_(conv2)}, | ||||
|   {"xcorr2", torch_(xcorr2)}, | ||||
|   {"conv3", torch_(conv3)}, | ||||
|   {"xcorr3", torch_(xcorr3)}, | ||||
|   {NULL, NULL} | ||||
| }; | ||||
|  | ||||
| void torch_(Conv_init)(lua_State *L) | ||||
| { | ||||
|   torch_(Tensor_id) = luaT_checktypename2id(L, torch_string_(Tensor)); | ||||
|  | ||||
|   /* register everything into the "torch" field of the tensor metaclass */ | ||||
|   luaT_pushmetaclass(L, torch_(Tensor_id)); | ||||
|   lua_pushstring(L, "torch"); | ||||
|   lua_rawget(L, -2); | ||||
|   luaL_register(L, NULL, torch_(Conv__)); | ||||
|   lua_pop(L, 2); | ||||
| } | ||||
|  | ||||
| #endif | ||||
|  | ||||
							
								
								
									
										274
									
								
								generic/TensorLapack.c
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										274
									
								
								generic/TensorLapack.c
									
									
									
									
									
										Normal file
									
								
							| @ -0,0 +1,274 @@ | ||||
| #ifndef TH_GENERIC_FILE | ||||
| #define TH_GENERIC_FILE "generic/TensorLapack.c" | ||||
| #else | ||||
|  | ||||
| #define pushreturn(i,t,tid) \ | ||||
|   if (!i)					\ | ||||
|     luaT_pushudata(L, t, tid);			\ | ||||
|   else						\ | ||||
|     lua_pushvalue(L,i)			 | ||||
|  | ||||
| static int torch_(gesv)(lua_State *L) | ||||
| { | ||||
|   int narg = lua_gettop(L); | ||||
|   THTensor *ra_ = NULL; | ||||
|   THTensor *rb_ = NULL; | ||||
|   THTensor *a_ = NULL; | ||||
|   THTensor *b_ = NULL; | ||||
|   int ragiven = 0; | ||||
|   int rbgiven = 0; | ||||
|  | ||||
|   if (narg == 2 | ||||
|       && (a_ = luaT_toudata(L,2,torch_(Tensor_id))) | ||||
|       && (b_ = luaT_toudata(L,1,torch_(Tensor_id)))) | ||||
|   { | ||||
|   } | ||||
|   else if (narg == 3 | ||||
| 	   && (a_ = luaT_toudata(L,2,torch_(Tensor_id))) | ||||
| 	   && (b_ = luaT_toudata(L,1,torch_(Tensor_id)))) | ||||
|   { | ||||
|     if(lua_toboolean(L,3)) | ||||
|     { | ||||
|       ra_ = a_; | ||||
|       rb_ = b_; | ||||
|       a_ = NULL; | ||||
|       b_ = NULL; | ||||
|       ragiven = 2; | ||||
|       rbgiven = 1; | ||||
|     } | ||||
|     else | ||||
|     { | ||||
|       luaL_error(L,"[Tensor, Tensor], Tensor, Tensor, [,true]");       | ||||
|     } | ||||
|   } | ||||
|   else if (narg == 4 | ||||
| 	   && (a_ = luaT_toudata(L,4,torch_(Tensor_id))) | ||||
| 	   && (b_ = luaT_toudata(L,3,torch_(Tensor_id))) | ||||
| 	   && (ra_ = luaT_toudata(L,2,torch_(Tensor_id))) | ||||
| 	   && (rb_ = luaT_toudata(L,1,torch_(Tensor_id)))) | ||||
|   { | ||||
|     ragiven = 2; | ||||
|     rbgiven = 1; | ||||
|   } | ||||
|   else | ||||
|   { | ||||
|     luaL_error(L,"[Tensor, Tensor], Tensor, Tensor, [,true]");       | ||||
|   } | ||||
|  | ||||
|   if (!ra_) ra_ = THTensor_(new)(); | ||||
|   if (!rb_) rb_ = THTensor_(new)(); | ||||
|    | ||||
|   THTensor_(gesv)(rb_,ra_,b_,a_); | ||||
|  | ||||
|   pushreturn(rbgiven,rb_,torch_(Tensor_id)); | ||||
|   pushreturn(ragiven,ra_,torch_(Tensor_id)); | ||||
|  | ||||
|   return 2; | ||||
| } | ||||
|  | ||||
| static int torch_(gels)(lua_State *L) | ||||
| { | ||||
|   int narg = lua_gettop(L); | ||||
|   THTensor *ra_ = NULL; | ||||
|   THTensor *rb_ = NULL; | ||||
|   THTensor *a_ = NULL; | ||||
|   THTensor *b_ = NULL; | ||||
|   int ragiven = 0; | ||||
|   int rbgiven = 0; | ||||
|  | ||||
|   if (narg == 2 | ||||
|       && (a_ = luaT_toudata(L,2,torch_(Tensor_id))) | ||||
|       && (b_ = luaT_toudata(L,1,torch_(Tensor_id)))) | ||||
|   { | ||||
|   } | ||||
|   else if (narg == 3 | ||||
| 	   && (a_ = luaT_toudata(L,2,torch_(Tensor_id))) | ||||
| 	   && (b_ = luaT_toudata(L,1,torch_(Tensor_id)))) | ||||
|   { | ||||
|     if (lua_toboolean(L,3)) | ||||
|     { | ||||
|       ra_ = a_; | ||||
|       rb_ = b_; | ||||
|       a_ = NULL; | ||||
|       b_ = NULL; | ||||
|       ragiven = 2; | ||||
|       rbgiven = 1; | ||||
|     } | ||||
|     else | ||||
|     { | ||||
|       luaL_error(L,"[Tensor, Tensor], Tensor, Tensor, [,true]");       | ||||
|     } | ||||
|   } | ||||
|   else if (narg == 4 | ||||
| 	   && (a_ = luaT_toudata(L,4,torch_(Tensor_id))) | ||||
| 	   && (b_ = luaT_toudata(L,3,torch_(Tensor_id))) | ||||
| 	   && (ra_ = luaT_toudata(L,2,torch_(Tensor_id))) | ||||
| 	   && (rb_ = luaT_toudata(L,1,torch_(Tensor_id)))) | ||||
|   { | ||||
|     ragiven = 2; | ||||
|     rbgiven = 1; | ||||
|   } | ||||
|   else | ||||
|   { | ||||
|     luaL_error(L,"[Tensor, Tensor], Tensor, Tensor, [,true]"); | ||||
|   } | ||||
|  | ||||
|   if (!ra_) ra_ = THTensor_(new)(); | ||||
|   if (!rb_) rb_ = THTensor_(new)(); | ||||
|  | ||||
|   THTensor_(gels)(rb_,ra_,b_,a_); | ||||
|  | ||||
|   pushreturn(rbgiven,rb_,torch_(Tensor_id)); | ||||
|   pushreturn(ragiven,ra_,torch_(Tensor_id)); | ||||
|  | ||||
|   return 2; | ||||
| } | ||||
|  | ||||
| static int torch_(eig)(lua_State *L) | ||||
| { | ||||
|   int narg = lua_gettop(L); | ||||
|   THTensor *re_ = NULL; | ||||
|   THTensor *rv_ = NULL; | ||||
|   THTensor *a_ = NULL; | ||||
|   char type = 'N'; | ||||
|   char uplo = 'U'; | ||||
|   int regiven = 0; | ||||
|   int rvgiven = 0; | ||||
|  | ||||
|   if (narg == 1 | ||||
|       && (a_ = luaT_toudata(L,1,torch_(Tensor_id)))) | ||||
|   { | ||||
|   } | ||||
|   else if (narg == 2 | ||||
| 	   && (lua_type(L,2) == LUA_TSTRING) | ||||
| 	   && (a_ = luaT_toudata(L,1,torch_(Tensor_id)))) | ||||
|   { | ||||
|     type = *(luaL_checkstring(L,2)); | ||||
|     luaL_argcheck(L, (type == 'v' || type == 'V' || type == 'n' || type == 'N'), | ||||
| 		  2, "[Tensor, ] [Tensor, ] Tensor [, N or V]"); | ||||
|     if (type == 'v') type = 'V'; | ||||
|     if (type == 'n') type = 'N'; | ||||
|   } | ||||
|   else if (narg == 2 | ||||
| 	   && (a_  = luaT_toudata(L,2,torch_(Tensor_id))) | ||||
| 	   && (re_ = luaT_toudata(L,1,torch_(Tensor_id)))) | ||||
|   { | ||||
|     regiven = 1; | ||||
|   } | ||||
|   else if (narg == 3 | ||||
| 	   && (a_  = luaT_toudata(L,3,torch_(Tensor_id))) | ||||
| 	   && (rv_ = luaT_toudata(L,2,torch_(Tensor_id))) | ||||
| 	   && (re_ = luaT_toudata(L,1,torch_(Tensor_id)))) | ||||
|   { | ||||
|     regiven = 1; | ||||
|     rvgiven = 2; | ||||
|   } | ||||
|   else if (narg == 4 | ||||
| 	   && (type = *(luaL_checkstring(L,4))) | ||||
| 	   && (a_  = luaT_toudata(L,3,torch_(Tensor_id))) | ||||
| 	   && (rv_ = luaT_toudata(L,2,torch_(Tensor_id))) | ||||
| 	   && (re_ = luaT_toudata(L,1,torch_(Tensor_id)))) | ||||
|   { | ||||
|     regiven = 1; | ||||
|     rvgiven = 2; | ||||
|   } | ||||
|   else | ||||
|   { | ||||
|     luaL_error(L,"[Tensor, ] [Tensor, ] Tensor [, N or V]"); | ||||
|   } | ||||
|   if (!re_) re_ = THTensor_(new)(); | ||||
|   if (!rv_) rv_ = THTensor_(new)(); | ||||
|  | ||||
|   THTensor_(syev)(re_,rv_,a_,&type,&uplo); | ||||
|  | ||||
|   pushreturn(regiven, re_, torch_(Tensor_id)); | ||||
|   pushreturn(rvgiven, rv_, torch_(Tensor_id)); | ||||
|  | ||||
|   return 2; | ||||
| } | ||||
|  | ||||
| static int torch_(svd)(lua_State *L) | ||||
| { | ||||
|   int narg = lua_gettop(L); | ||||
|   THTensor *ru_ = NULL; | ||||
|   THTensor *rs_ = NULL; | ||||
|   THTensor *rv_ = NULL; | ||||
|   THTensor *a_ = NULL; | ||||
|   char type = 'S'; | ||||
|   int rugiven = 0; | ||||
|   int rsgiven = 0; | ||||
|   int rvgiven = 0; | ||||
|  | ||||
|   if (narg == 1 | ||||
|       && (a_ = luaT_toudata(L,1,torch_(Tensor_id)))) | ||||
|   { | ||||
|   } | ||||
|   else if (narg ==2  | ||||
| 	   && (type = *(luaL_checkstring(L,2))) | ||||
| 	   && (a_ = luaT_toudata(L,1,torch_(Tensor_id)))) | ||||
|   { | ||||
|     luaL_argcheck(L, (type == 's' || type == 'S' || type == 'a' || type == 'A'), | ||||
| 		  2, "[Tensor, ] [Tensor, ] [Tensor, ] Tensor [, A or S]"); | ||||
|     if (type == 's') type = 'S'; | ||||
|     if (type == 'a') type = 'A'; | ||||
|   } | ||||
|   else if (narg == 4 | ||||
| 	   && (a_  = luaT_toudata(L,4,torch_(Tensor_id))) | ||||
| 	   && (rv_ = luaT_toudata(L,3,torch_(Tensor_id))) | ||||
| 	   && (rs_ = luaT_toudata(L,2,torch_(Tensor_id))) | ||||
| 	   && (ru_ = luaT_toudata(L,1,torch_(Tensor_id)))) | ||||
|   { | ||||
|     rugiven = 1; | ||||
|     rsgiven = 2; | ||||
|     rvgiven = 3; | ||||
|   } | ||||
|   else if (narg == 5 | ||||
| 	   && (type = *(luaL_checkstring(L,5))) | ||||
| 	   && (a_  = luaT_toudata(L,4,torch_(Tensor_id))) | ||||
| 	   && (rv_ = luaT_toudata(L,3,torch_(Tensor_id))) | ||||
| 	   && (rs_ = luaT_toudata(L,2,torch_(Tensor_id))) | ||||
| 	   && (ru_ = luaT_toudata(L,1,torch_(Tensor_id)))) | ||||
|   { | ||||
|     rugiven = 1; | ||||
|     rsgiven = 2; | ||||
|     rvgiven = 3; | ||||
|   } | ||||
|   else | ||||
|   { | ||||
|     luaL_error(L,"[Tensor, Tensor, Tensor], Tensor, [, 'A' or 'S' ]"); | ||||
|   } | ||||
|  | ||||
|   if (!ru_) ru_ = THTensor_(new)(); | ||||
|   if (!rs_) rs_ = THTensor_(new)(); | ||||
|   if (!rv_) rv_ = THTensor_(new)(); | ||||
|  | ||||
|   THTensor_(gesvd)(ru_,rs_,rv_,a_,&type); | ||||
|  | ||||
|   pushreturn(rugiven,ru_,torch_(Tensor_id)); | ||||
|   pushreturn(rsgiven,rs_,torch_(Tensor_id)); | ||||
|   pushreturn(rvgiven,rv_,torch_(Tensor_id)); | ||||
|  | ||||
|   return 3; | ||||
| } | ||||
|  | ||||
| static const struct luaL_Reg torch_(lapack__) [] = { | ||||
|   {"gesv", torch_(gesv)}, | ||||
|   {"gels", torch_(gels)}, | ||||
|   {"eig", torch_(eig)}, | ||||
|   {"svd", torch_(svd)}, | ||||
|   {NULL, NULL} | ||||
| }; | ||||
|  | ||||
| void torch_(Lapack_init)(lua_State *L) | ||||
| { | ||||
|   torch_(Tensor_id) = luaT_checktypename2id(L, torch_string_(Tensor)); | ||||
|  | ||||
|   /* register everything into the "torch" field of the tensor metaclass */ | ||||
|   luaT_pushmetaclass(L, torch_(Tensor_id)); | ||||
|   lua_pushstring(L, "torch"); | ||||
|   lua_rawget(L, -2); | ||||
|   luaL_register(L, NULL, torch_(lapack__)); | ||||
|   lua_pop(L, 2); | ||||
| } | ||||
|  | ||||
| #endif | ||||
							
								
								
									
										177
									
								
								generic/TensorOperator.c
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										177
									
								
								generic/TensorOperator.c
									
									
									
									
									
										Normal file
									
								
							| @ -0,0 +1,177 @@ | ||||
| #ifndef TH_GENERIC_FILE | ||||
| #define TH_GENERIC_FILE "generic/TensorOperator.c" | ||||
| #else | ||||
|  | ||||
| static const void* torch_Tensor_id; | ||||
|  | ||||
| static int torch_TensorOperator_(__add__)(lua_State *L) | ||||
| { | ||||
|   THTensor *tensor1 = luaT_toudata(L, 1, torch_Tensor_id); | ||||
|   THTensor *tensor2 = luaT_toudata(L, 2, torch_Tensor_id); | ||||
|   THTensor *r; | ||||
|  | ||||
|   if(!tensor1 && !tensor2) | ||||
|     luaL_error(L, "expecting two Tensors or one Tensor and one number"); | ||||
|   else | ||||
|   { | ||||
|     r = THTensor_(new)(); | ||||
|     luaT_pushudata(L, r, torch_Tensor_id); | ||||
|      | ||||
|     if(!tensor1 && tensor2) | ||||
|     { | ||||
|       THTensor_(resizeAs)(r, tensor2); | ||||
|       THTensor_(copy)(r, tensor2); | ||||
|       THTensor_(add)(r, r, luaL_checknumber(L, 1)); | ||||
|     } | ||||
|     else if(tensor1 && !tensor2) | ||||
|     { | ||||
|       THTensor_(resizeAs)(r, tensor1); | ||||
|       THTensor_(copy)(r, tensor1); | ||||
|       THTensor_(add)(r, r, luaL_checknumber(L, 2)); | ||||
|     } | ||||
|     else | ||||
|     { | ||||
|       THTensor_(resizeAs)(r, tensor1); | ||||
|       THTensor_(copy)(r, tensor1); | ||||
|       THTensor_(cadd)(r, r, 1, tensor2); | ||||
|     } | ||||
|   } | ||||
|   return 1; | ||||
| } | ||||
|  | ||||
| static int torch_TensorOperator_(__sub__)(lua_State *L) | ||||
| { | ||||
|   THTensor *tensor1 = luaT_toudata(L, 1, torch_Tensor_id); | ||||
|   THTensor *tensor2 = luaT_toudata(L, 2, torch_Tensor_id); | ||||
|   THTensor *r; | ||||
|  | ||||
|   if(!tensor1 && !tensor2) | ||||
|     luaL_error(L, "expecting two Tensors or one Tensor and one number"); | ||||
|   else | ||||
|   { | ||||
|     r = THTensor_(new)(); | ||||
|     luaT_pushudata(L, r, torch_Tensor_id); | ||||
|      | ||||
|     if(!tensor1 && tensor2) | ||||
|     { | ||||
|       THTensor_(resizeAs)(r, tensor2); | ||||
|       THTensor_(fill)(r, luaL_checknumber(L, 1)); | ||||
|       THTensor_(cadd)(r, r, -1, tensor2); | ||||
|     } | ||||
|     else if(tensor1 && !tensor2) | ||||
|     { | ||||
|       THTensor_(resizeAs)(r, tensor1); | ||||
|       THTensor_(copy)(r, tensor1); | ||||
|       THTensor_(add)(r, r, -luaL_checknumber(L, 2)); | ||||
|     } | ||||
|     else | ||||
|     { | ||||
|       THTensor_(resizeAs)(r, tensor1); | ||||
|       THTensor_(copy)(r, tensor1); | ||||
|       THTensor_(cadd)(r, r, -1, tensor2); | ||||
|     } | ||||
|   } | ||||
|   return 1; | ||||
| } | ||||
|  | ||||
| static int torch_TensorOperator_(__unm__)(lua_State *L) | ||||
| { | ||||
|   THTensor *tensor = luaT_checkudata(L, 1, torch_Tensor_id); | ||||
|   THTensor *r; | ||||
|  | ||||
|   r = THTensor_(new)(); | ||||
|   luaT_pushudata(L, r, torch_Tensor_id); | ||||
|   THTensor_(resizeAs)(r, tensor); | ||||
|   THTensor_(copy)(r, tensor); | ||||
|   THTensor_(mul)(r, r, -1); | ||||
|  | ||||
|   return 1; | ||||
| } | ||||
|  | ||||
| static int torch_TensorOperator_(__mul__)(lua_State *L) | ||||
| { | ||||
|   THTensor *tensor1 = luaT_toudata(L, 1, torch_Tensor_id); | ||||
|   THTensor *tensor2 = luaT_toudata(L, 2, torch_Tensor_id); | ||||
|   THTensor *r; | ||||
|  | ||||
|   if(!tensor1 && !tensor2) | ||||
|     luaL_error(L, "expecting two Tensors or one Tensor and one number"); | ||||
|   else | ||||
|   { | ||||
|     r = THTensor_(new)(); | ||||
|     luaT_pushudata(L, r, torch_Tensor_id); | ||||
|      | ||||
|     if(!tensor1 && tensor2) | ||||
|     { | ||||
|       THTensor_(resizeAs)(r, tensor2); | ||||
|       THTensor_(copy)(r, tensor2); | ||||
|       THTensor_(mul)(r, r, luaL_checknumber(L, 1)); | ||||
|     } | ||||
|     else if(tensor1 && !tensor2) | ||||
|     { | ||||
|       THTensor_(resizeAs)(r, tensor1); | ||||
|       THTensor_(copy)(r, tensor1); | ||||
|       THTensor_(mul)(r, r, luaL_checknumber(L, 2)); | ||||
|     } | ||||
|     else | ||||
|     { | ||||
|       int dimt = tensor1->nDimension; | ||||
|       int dims = tensor2->nDimension; | ||||
|        | ||||
|       if(dimt == 1 && dims == 1) | ||||
|         lua_pushnumber(L, THTensor_(dot)(tensor1, tensor2)); /* ok, we wasted r, but who cares */ | ||||
|       else if(dimt == 2 && dims == 1) | ||||
|       { | ||||
|         THTensor_(resize1d)(r, tensor1->size[0]); | ||||
|         THTensor_(zero)(r); | ||||
|         THTensor_(addmv)(r, 1, r, 1, tensor1, tensor2); | ||||
|       } | ||||
|       else if(dimt == 2 && dims == 2) | ||||
|       { | ||||
|         THTensor_(resize2d)(r, tensor1->size[0], tensor2->size[1]); | ||||
|         THTensor_(zero)(r); | ||||
|         THTensor_(addmm)(r, 1, r, 1, tensor1, tensor2); | ||||
|       } | ||||
|       else | ||||
|         luaL_error(L, "multiplication between %dD and %dD tensors not yet supported", tensor1->nDimension, tensor2->nDimension);  | ||||
|     } | ||||
|   } | ||||
|   return 1; | ||||
| } | ||||
|  | ||||
| static int torch_TensorOperator_(__div__)(lua_State *L) | ||||
| { | ||||
|   THTensor *tensor = luaT_checkudata(L, 1, torch_Tensor_id); | ||||
|   THTensor *r; | ||||
|  | ||||
|   luaL_argcheck(L, lua_isnumber(L,2), 2, "number expected"); | ||||
|  | ||||
|   r = THTensor_(new)(); | ||||
|   luaT_pushudata(L, r, torch_Tensor_id); | ||||
|  | ||||
|   THTensor_(resizeAs)(r, tensor); | ||||
|   THTensor_(copy)(r, tensor); | ||||
|   THTensor_(mul)(r, r, 1/lua_tonumber(L, 2)); | ||||
|  | ||||
|   return 1; | ||||
| } | ||||
|  | ||||
| static const struct luaL_Reg torch_TensorOperator_(_) [] = { | ||||
|   {"__add__", torch_TensorOperator_(__add__)}, | ||||
|   {"__sub__", torch_TensorOperator_(__sub__)}, | ||||
|   {"__unm__", torch_TensorOperator_(__unm__)}, | ||||
|   {"__mul__", torch_TensorOperator_(__mul__)}, | ||||
|   {"__div__", torch_TensorOperator_(__div__)}, | ||||
|   {NULL, NULL} | ||||
| }; | ||||
|  | ||||
| void torch_TensorOperator_(init)(lua_State *L) | ||||
| { | ||||
|   torch_Tensor_id = luaT_checktypename2id(L, STRING_torchTensor); | ||||
|  | ||||
|   luaT_pushmetaclass(L, torch_Tensor_id); | ||||
|   luaL_register(L, NULL, torch_TensorOperator_(_)); | ||||
|   lua_pop(L, 1); | ||||
| } | ||||
|  | ||||
| #endif | ||||
							
								
								
									
										44
									
								
								generic/hist.c
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										44
									
								
								generic/hist.c
									
									
									
									
									
										Normal file
									
								
							| @ -0,0 +1,44 @@ | ||||
| #ifndef TH_GENERIC_FILE | ||||
| #define TH_GENERIC_FILE "generic/lab.c" | ||||
| #else | ||||
|  | ||||
| #include "interfaces.c" | ||||
|  | ||||
| static int lab_(histc)(lua_State *L) | ||||
| { | ||||
|   THTensor *r = luaT_checkudata(L, 1, torch_(Tensor_id)); | ||||
|   THTensor *h = luaT_checkudata(L, 2, torch_(Tensor_id)); | ||||
|   int nbins = luaL_checknumber(L, 3); | ||||
|   real *h_data = THTensor_(data)(h); | ||||
|  | ||||
|   TH_TENSOR_APPLY(real, r,                                      \ | ||||
|                   if ((*r_data <= nbins) && (*r_data >= 1)) {   \ | ||||
|                     *(h_data + (int)(*r_data) - 1) += 1;        \ | ||||
|                   }) | ||||
|   return 0; | ||||
| } | ||||
|  | ||||
| static const struct luaL_Reg lab_(stuff__) [] = { | ||||
|   {"_histc", lab_(histc)}, | ||||
| #endif | ||||
|   {NULL, NULL} | ||||
| }; | ||||
|  | ||||
| void lab_(init)(lua_State *L) | ||||
| { | ||||
|   torch_(Tensor_id) = luaT_checktypename2id(L, torch_string_(Tensor)); | ||||
|   torch_LongStorage_id = luaT_checktypename2id(L, "torch.LongStorage"); | ||||
|  | ||||
|   /* register everything into the "lab" field of the tensor metaclass */ | ||||
|   luaT_pushmetaclass(L, torch_(Tensor_id)); | ||||
|   lua_pushstring(L, "lab"); | ||||
|   lua_newtable(L); | ||||
|   luaL_register(L, NULL, lab_(stuff__)); | ||||
|   lua_rawset(L, -3); | ||||
|   lua_pop(L, 1); | ||||
|  | ||||
| /*  luaT_registeratid(L, lab_(stuff__), torch_(Tensor_id)); */ | ||||
| /*  luaL_register(L, NULL, lab_(stuff__)); */   | ||||
| } | ||||
|  | ||||
| #endif | ||||
							
								
								
									
										123
									
								
								hist.lua
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										123
									
								
								hist.lua
									
									
									
									
									
										Normal file
									
								
							| @ -0,0 +1,123 @@ | ||||
| --  | ||||
| -- rudimentary histogram diplay on the command line. | ||||
| -- | ||||
| -- Author: Marco Scoffier | ||||
| -- Date  :  | ||||
| -- Mod   : Oct 21, 2011 | ||||
| --  + made 80 columns default | ||||
| --  + save index of max bin in h.max not pointer to bin | ||||
| -- | ||||
| function torch.histc__tostring(h, barHeight) | ||||
|    barHeight = barHeight or 10 | ||||
|    local lastm = h[h.max].nb | ||||
|    local incr  = lastm/(barHeight+1) | ||||
|    local m     = lastm - incr | ||||
|    local tl    = torch.Tensor(#h):fill(0) | ||||
|    local toph  = '|' | ||||
|    local topm  = ':' | ||||
|    local topl  = '.' | ||||
|    local bar   = '|' | ||||
|    local blank = ' ' | ||||
|    local yaxis = '--------:'  | ||||
|    local str = 'nsamples:' | ||||
|    str = str ..  | ||||
|       string.format('  min:(bin:%d/#%d/cntr:%2.2f)  max:(bin:%d/#%d/cntr:%2.2f)\n', | ||||
|                     h.min,h[h.min].nb,h[h.min].val,  | ||||
|                     h.max,h[h.max].nb,h[h.max].val) | ||||
|     | ||||
|    str = str .. yaxis | ||||
|    for j = 1,#h do  | ||||
|       str = str .. '-' | ||||
|    end | ||||
|    str = str .. '\n' | ||||
|  | ||||
|    for i = 1,barHeight do | ||||
|       -- y axis | ||||
|       if i%1==0 then | ||||
|          str = str .. string.format('%1.2e:',m) | ||||
|       end | ||||
|       for j = 1,#h do | ||||
|          if tl[j] == 1 then | ||||
|             str = str .. bar | ||||
|          elseif h[j].nb < m then | ||||
|             str = str .. blank | ||||
|          else | ||||
|             -- in the bracket | ||||
|             tl[j] = 1 | ||||
|             -- find 1/3rds | ||||
|             local p = (lastm - h[j].nb) / incr | ||||
|             if p > 0.66 then | ||||
|                str = str .. toph | ||||
|             elseif p > 0.33 then | ||||
|                str = str .. topm | ||||
|             else | ||||
|                str = str .. topl | ||||
|             end | ||||
|          end | ||||
|       end | ||||
|       str = str .. '\n' | ||||
|       lastm = m  | ||||
|       m     = m - incr | ||||
|    end | ||||
|    -- x axis | ||||
|    str = str .. yaxis  | ||||
|    for j = 1,#h do | ||||
|       if ((j - 2) % 6 == 0)then | ||||
|          str = str .. '^' | ||||
|       else | ||||
|          str = str .. '-' | ||||
|       end | ||||
|    end | ||||
|    str = str .. '\ncenters ' | ||||
|    for j = 1,#h do | ||||
|       if ((j - 2) % 6 == 0)then | ||||
|          if h[j].val < 0 then | ||||
|             str = str .. '-' | ||||
|          else | ||||
|             str = str .. '+' | ||||
|          end | ||||
|          str = str .. string.format('%1.2f ',math.abs(h[j].val)) | ||||
|       end | ||||
|    end | ||||
|    return str | ||||
| end | ||||
|  | ||||
| -- a simple function that computes the histogram of a tensor | ||||
| function torch.histc(...) | ||||
|    -- get args | ||||
|    local args = {...} | ||||
|    local tensor = args[1] or error('usage: torch.histc (tensor [, nBins] [, min] [, max]') | ||||
|    local bins = args[2] or 80 - 8 | ||||
|    local min = args[3] or tensor:min() | ||||
|    local max = args[4] or tensor:max() | ||||
|    local raw = args[5] or false | ||||
|  | ||||
|    -- compute histogram | ||||
|    local hist = torch.zeros(bins) | ||||
|    local ten = torch.Tensor(tensor:nElement()):copy(tensor) | ||||
|    ten:add(-min):div(max-min):mul(bins - 1e-6):floor():add(1) | ||||
|    ten.torch._histc(ten, hist, bins) | ||||
|  | ||||
|    -- return raw histogram (no extra info) | ||||
|    if raw then return hist end | ||||
|  | ||||
|    -- cleanup hist | ||||
|    local cleanhist = {} | ||||
|    cleanhist.raw = hist | ||||
|    local _,mx = torch.max(cleanhist.raw) | ||||
|    local _,mn = torch.min(cleanhist.raw) | ||||
|    cleanhist.bins = bins | ||||
|    cleanhist.binwidth = (max-min)/bins | ||||
|    for i = 1,bins do | ||||
|       cleanhist[i] = {} | ||||
|       cleanhist[i].val = min + (i-0.5)*cleanhist.binwidth | ||||
|       cleanhist[i].nb = hist[i] | ||||
|    end | ||||
|    cleanhist.max = mx[1] | ||||
|    cleanhist.min = mn[1] | ||||
|  | ||||
|    -- print function | ||||
|    setmetatable(cleanhist, {__tostring=torch.histc__tostring}) | ||||
|    return cleanhist | ||||
| end | ||||
|  | ||||
							
								
								
									
										99
									
								
								init.c
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										99
									
								
								init.c
									
									
									
									
									
										Normal file
									
								
							| @ -0,0 +1,99 @@ | ||||
| #include "general.h" | ||||
| #include "utils.h" | ||||
|  | ||||
| extern void torch_utils_init(lua_State *L); | ||||
| extern void torch_random_init(lua_State *L); | ||||
| extern void torch_File_init(lua_State *L); | ||||
| extern void torch_File_init_storage_id(lua_State *L); | ||||
| extern void torch_DiskFile_init(lua_State *L); | ||||
| extern void torch_MemoryFile_init(lua_State *L); | ||||
| extern void torch_PipeFile_init(lua_State *L); | ||||
| extern void torch_Timer_init(lua_State *L); | ||||
|  | ||||
| extern void torch_ByteStorage_init(lua_State *L); | ||||
| extern void torch_CharStorage_init(lua_State *L); | ||||
| extern void torch_ShortStorage_init(lua_State *L); | ||||
| extern void torch_IntStorage_init(lua_State *L); | ||||
| extern void torch_LongStorage_init(lua_State *L); | ||||
| extern void torch_FloatStorage_init(lua_State *L); | ||||
| extern void torch_DoubleStorage_init(lua_State *L); | ||||
|  | ||||
| extern void torch_ByteTensor_init(lua_State *L); | ||||
| extern void torch_CharTensor_init(lua_State *L); | ||||
| extern void torch_ShortTensor_init(lua_State *L); | ||||
| extern void torch_IntTensor_init(lua_State *L); | ||||
| extern void torch_LongTensor_init(lua_State *L); | ||||
| extern void torch_FloatTensor_init(lua_State *L); | ||||
| extern void torch_DoubleTensor_init(lua_State *L); | ||||
|  | ||||
| extern void torch_ByteTensorOperator_init(lua_State *L); | ||||
| extern void torch_CharTensorOperator_init(lua_State *L); | ||||
| extern void torch_ShortTensorOperator_init(lua_State *L); | ||||
| extern void torch_IntTensorOperator_init(lua_State *L); | ||||
| extern void torch_LongTensorOperator_init(lua_State *L); | ||||
| extern void torch_FloatTensorOperator_init(lua_State *L); | ||||
| extern void torch_DoubleTensorOperator_init(lua_State *L); | ||||
|  | ||||
| extern void torch_TensorMath_init(lua_State *L); | ||||
|  | ||||
| static lua_State *globalL; | ||||
| static void luaTorchErrorHandlerFunction(const char *msg) | ||||
| { | ||||
|   luaL_error(globalL, msg); | ||||
| } | ||||
|  | ||||
| static void luaTorchArgCheckHandlerFunction(int condition, int argNumber, const char *msg) | ||||
| { | ||||
|   luaL_argcheck(globalL, condition, argNumber, msg); | ||||
| } | ||||
|  | ||||
| DLL_EXPORT int luaopen_libtorch(lua_State *L) | ||||
| { | ||||
|   globalL = L; | ||||
|   THSetErrorHandler(luaTorchErrorHandlerFunction); | ||||
|   THSetArgCheckHandler(luaTorchArgCheckHandlerFunction); | ||||
|  | ||||
|   lua_newtable(L); | ||||
|   lua_pushvalue(L, -1); | ||||
|   lua_setfield(L, LUA_GLOBALSINDEX, "torch"); | ||||
|  | ||||
|   torch_File_init(L); | ||||
|  | ||||
|   torch_ByteStorage_init(L); | ||||
|   torch_CharStorage_init(L); | ||||
|   torch_ShortStorage_init(L); | ||||
|   torch_IntStorage_init(L); | ||||
|   torch_LongStorage_init(L); | ||||
|   torch_FloatStorage_init(L); | ||||
|   torch_DoubleStorage_init(L); | ||||
|  | ||||
|   torch_ByteTensor_init(L); | ||||
|   torch_CharTensor_init(L); | ||||
|   torch_ShortTensor_init(L); | ||||
|   torch_IntTensor_init(L); | ||||
|   torch_LongTensor_init(L); | ||||
|   torch_FloatTensor_init(L); | ||||
|   torch_DoubleTensor_init(L); | ||||
|  | ||||
|   torch_File_init_storage_id(L); | ||||
|  | ||||
|   torch_ByteTensorOperator_init(L); | ||||
|   torch_CharTensorOperator_init(L); | ||||
|   torch_ShortTensorOperator_init(L); | ||||
|   torch_IntTensorOperator_init(L); | ||||
|   torch_LongTensorOperator_init(L); | ||||
|   torch_FloatTensorOperator_init(L); | ||||
|   torch_DoubleTensorOperator_init(L); | ||||
|  | ||||
|   torch_Timer_init(L); | ||||
|   torch_DiskFile_init(L); | ||||
|   torch_PipeFile_init(L); | ||||
|   torch_MemoryFile_init(L); | ||||
|  | ||||
|   torch_TensorMath_init(L); | ||||
|  | ||||
|   torch_utils_init(L); | ||||
|   torch_random_init(L); | ||||
|  | ||||
|   return 1; | ||||
| } | ||||
							
								
								
									
										78
									
								
								init.lua
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										78
									
								
								init.lua
									
									
									
									
									
										Normal file
									
								
							| @ -0,0 +1,78 @@ | ||||
|  | ||||
| -- We are using paths.require to appease mkl | ||||
| require "paths" | ||||
| paths.require "libtorch" | ||||
| require "libtorch" | ||||
|  | ||||
| --- package stuff | ||||
| function torch.packageLuaPath(name) | ||||
|    if not name then | ||||
|       local ret = string.match(torch.packageLuaPath('torch'), '(.*)/') | ||||
|        if not ret then --windows? | ||||
|            ret = string.match(torch.packageLuaPath('torch'), '(.*)\\') | ||||
|        end | ||||
|        return ret  | ||||
|    end | ||||
|    for path in string.gmatch(package.path, "(.-);") do | ||||
|       path = string.gsub(path, "%?", name) | ||||
|       local f = io.open(path) | ||||
|       if f then | ||||
|          f:close() | ||||
|          local ret = string.match(path, "(.*)/") | ||||
|          if not ret then --windows? | ||||
|              ret = string.match(path, "(.*)\\") | ||||
|          end | ||||
|          return ret | ||||
|       end | ||||
|    end | ||||
| end | ||||
|  | ||||
| function torch.include(package, file) | ||||
|    dofile(torch.packageLuaPath(package) .. '/' .. file)  | ||||
| end | ||||
|  | ||||
| function torch.class(tname, parenttname) | ||||
|  | ||||
|    local function constructor(...) | ||||
|       local self = {} | ||||
|       torch.setmetatable(self, tname) | ||||
|       if self.__init then | ||||
|          self:__init(...) | ||||
|       end | ||||
|       return self | ||||
|    end | ||||
|     | ||||
|    local function factory() | ||||
|       local self = {} | ||||
|       torch.setmetatable(self, tname) | ||||
|       return self | ||||
|    end | ||||
|  | ||||
|    local mt = torch.newmetatable(tname, parenttname, constructor, nil, factory) | ||||
|    local mpt | ||||
|    if parenttname then | ||||
|       mpt = torch.getmetatable(parenttname) | ||||
|    end | ||||
|    return mt, mpt | ||||
| end | ||||
|  | ||||
| function torch.setdefaulttensortype(typename) | ||||
|    assert(type(typename) == 'string', 'string expected') | ||||
|    if torch.getconstructortable(typename) then | ||||
|       torch.Tensor = torch.getconstructortable(typename) | ||||
|       torch.Storage = torch.getconstructortable(torch.typename(torch.Tensor(1):storage())) | ||||
|       torch.__setdefaulttensortype(typename) | ||||
|    else | ||||
|       error(string.format("<%s> is not a string describing a torch object", typename)) | ||||
|    end | ||||
| end | ||||
|  | ||||
| torch.setdefaulttensortype('torch.DoubleTensor') | ||||
|  | ||||
| torch.include('torch', 'Tensor.lua') | ||||
| torch.include('torch', 'File.lua') | ||||
| torch.include('torch', 'CmdLine.lua') | ||||
| torch.include('torch', 'Tester.lua') | ||||
| torch.include('torch', 'TensorMath.lua') | ||||
| torch.include('torch', 'test.lua') | ||||
| return torch | ||||
							
								
								
									
										2
									
								
								lib/CMakeLists.txt
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										2
									
								
								lib/CMakeLists.txt
									
									
									
									
									
										Normal file
									
								
							| @ -0,0 +1,2 @@ | ||||
| ADD_SUBDIRECTORY(TH) | ||||
| ADD_SUBDIRECTORY(luaT) | ||||
							
								
								
									
										117
									
								
								lib/TH/CMakeLists.txt
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										117
									
								
								lib/TH/CMakeLists.txt
									
									
									
									
									
										Normal file
									
								
							| @ -0,0 +1,117 @@ | ||||
| # -*- cmake -*- | ||||
|  | ||||
| SET(CMAKE_MODULE_PATH ${CMAKE_CURRENT_SOURCE_DIR}/cmake ${CMAKE_MODULE_PATH}) | ||||
|  | ||||
| SET(hdr  | ||||
|   THGeneral.h THStorage.h THTensor.h THTensorApply.h | ||||
|   THBlas.h THLapack.h THLogAdd.h THRandom.h THVector.h) | ||||
| SET(src  | ||||
|   THGeneral.c THStorage.c THTensor.c THBlas.c THLapack.c | ||||
|   THLogAdd.c THRandom.c | ||||
|   THFile.c THDiskFile.c THMemoryFile.c) | ||||
|  | ||||
| SET(src ${src} ${hdr}) | ||||
|  | ||||
| IF(UNIX) | ||||
|   INCLUDE(CheckFunctionExists) | ||||
|   SET(CMAKE_EXTRA_INCLUDE_FILES "sys/mman.h") | ||||
|   CHECK_FUNCTION_EXISTS(mmap HAVE_MMAP) | ||||
|   IF(HAVE_MMAP) | ||||
|     ADD_DEFINITIONS(-DHAVE_MMAP=1) | ||||
|   ENDIF(HAVE_MMAP) | ||||
| ENDIF(UNIX) | ||||
|  | ||||
| ADD_LIBRARY(TH SHARED ${src}) | ||||
|  | ||||
| FIND_PACKAGE(BLAS) | ||||
| FIND_PACKAGE(LAPACK) | ||||
|  | ||||
| IF (LAPACK_FOUND) | ||||
|     SET(CMAKE_C_FLAGS "-D__LAPACK__ ${CMAKE_C_FLAGS}") | ||||
| ENDIF(LAPACK_FOUND) | ||||
|  | ||||
| FIND_PACKAGE(SSE) | ||||
|  | ||||
| IF (SSE2_FOUND) | ||||
|   SET(CMAKE_C_FLAGS "-msse2 -D__SSE2__ ${CMAKE_C_FLAGS}") | ||||
| ENDIF (SSE2_FOUND) | ||||
| IF (SSE3_FOUND) | ||||
|   SET(CMAKE_C_FLAGS "-msse3 -D__SSE3__ ${CMAKE_C_FLAGS}") | ||||
| ENDIF (SSE3_FOUND) | ||||
| IF (SSSE3_FOUND) | ||||
|   SET(CMAKE_C_FLAGS "-mssse3 -D__SSSE3__ ${CMAKE_C_FLAGS}") | ||||
| ENDIF (SSSE3_FOUND) | ||||
| IF (SSE4.1_FOUND) | ||||
|   SET(CMAKE_C_FLAGS "-msse4.1 -D__SSE4_1__ ${CMAKE_C_FLAGS}") | ||||
| ENDIF (SSE4.1_FOUND) | ||||
|  | ||||
| IF(BLAS_FOUND) | ||||
|   ADD_DEFINITIONS(-DUSE_LAPACK) | ||||
| #  INCLUDE_DIRECTORIES(${CBLAS_INCLUDE_DIR}) | ||||
|   TARGET_LINK_LIBRARIES(TH ${BLAS_LIBRARIES}) | ||||
| ENDIF(BLAS_FOUND) | ||||
|  | ||||
| #CONFIGURE_FILE("THCBlas.h.in" "${CMAKE_CURRENT_BINARY_DIR}/THCBlas.h") | ||||
| #INCLUDE_DIRECTORIES("${CMAKE_CURRENT_BINARY_DIR}") | ||||
| #INSTALL(FILES "${CMAKE_CURRENT_BINARY_DIR}/THCBlas.h"  | ||||
| #  DESTINATION "${Torch_INSTALL_INCLUDE_SUBDIR}/TH") | ||||
|  | ||||
| INSTALL(TARGETS TH | ||||
|           RUNTIME DESTINATION "${Torch_INSTALL_BIN_SUBDIR}" | ||||
|           LIBRARY DESTINATION "${Torch_INSTALL_LIB_SUBDIR}" | ||||
|           ARCHIVE DESTINATION "${Torch_INSTALL_LIB_SUBDIR}") | ||||
|  | ||||
| INSTALL(FILES | ||||
|   TH.h | ||||
|   THBlas.h | ||||
|   THDiskFile.h | ||||
|   THFile.h | ||||
|   THFilePrivate.h | ||||
|   THGeneral.h | ||||
|   THGenerateAllTypes.h | ||||
|   THGenerateFloatTypes.h | ||||
|   THGenerateIntTypes.h | ||||
|   THLapack.h | ||||
|   THLogAdd.h | ||||
|   THMemoryFile.h | ||||
|   THRandom.h | ||||
|   THStorage.h | ||||
|   THTensor.h | ||||
|   THTensorApply.h | ||||
|   THTensorDimApply.h | ||||
|   THTensorMacros.h | ||||
|   THVector.h | ||||
|   DESTINATION "${Torch_INSTALL_INCLUDE_SUBDIR}/TH") | ||||
|  | ||||
| INSTALL(FILES | ||||
|   generic/THBlas.c | ||||
|   generic/THBlas.h | ||||
|   generic/THLapack.c | ||||
|   generic/THLapack.h | ||||
|   generic/THStorage.c | ||||
|   generic/THStorage.h | ||||
|   generic/THStorageCopy.c | ||||
|   generic/THStorageCopy.h | ||||
|   generic/THTensor.c | ||||
|   generic/THTensor.h | ||||
|   generic/THTensorConv.c | ||||
|   generic/THTensorConv.h | ||||
|   generic/THTensorCopy.c | ||||
|   generic/THTensorCopy.h | ||||
|   generic/THTensorLapack.c | ||||
|   generic/THTensorLapack.h | ||||
|   generic/THTensorMath.c | ||||
|   generic/THTensorMath.h | ||||
|   generic/THTensorRandom.c | ||||
|   generic/THTensorRandom.h | ||||
|   generic/THVector.c | ||||
|   DESTINATION "${Torch_INSTALL_INCLUDE_SUBDIR}/TH/generic") | ||||
|  | ||||
| # Create THConfig.cmake | ||||
| GET_TARGET_PROPERTY(TH_OUTPUT_NAME TH LOCATION) | ||||
| GET_FILENAME_COMPONENT(TH_OUTPUT_NAME ${TH_OUTPUT_NAME} NAME) | ||||
| SET(TH_LIBRARIES "${Torch_INSTALL_LIB}/${TH_OUTPUT_NAME}") | ||||
| SET(TH_INCLUDE_DIR "${Torch_INSTALL_INCLUDE}/TH") | ||||
| CONFIGURE_FILE(THConfig.cmake.in "${Torch_BINARY_DIR}/cmake-external/THConfig.cmake") | ||||
| INSTALL(FILES "${Torch_BINARY_DIR}/cmake-external/THConfig.cmake"  | ||||
|   DESTINATION "${Torch_INSTALL_CMAKE_SUBDIR}") | ||||
							
								
								
									
										23
									
								
								lib/TH/TH.h
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										23
									
								
								lib/TH/TH.h
									
									
									
									
									
										Normal file
									
								
							| @ -0,0 +1,23 @@ | ||||
| #ifndef TH_INC | ||||
| #define TH_INC | ||||
|  | ||||
| #include "THBlas.h" | ||||
|  | ||||
| #ifdef __LAPACK__ | ||||
| #include "THLapack.h" | ||||
| #endif | ||||
|  | ||||
| #include "THVector.h" | ||||
| #include "THGeneral.h" | ||||
| #include "THLogAdd.h" | ||||
| #include "THRandom.h" | ||||
| #include "THStorage.h" | ||||
| #include "THTensor.h" | ||||
| #include "THTensorApply.h" | ||||
| #include "THTensorDimApply.h" | ||||
|  | ||||
| #include "THFile.h" | ||||
| #include "THDiskFile.h" | ||||
| #include "THMemoryFile.h" | ||||
|  | ||||
| #endif | ||||
							
								
								
									
										5
									
								
								lib/TH/THBlas.c
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										5
									
								
								lib/TH/THBlas.c
									
									
									
									
									
										Normal file
									
								
							| @ -0,0 +1,5 @@ | ||||
| #include "THBlas.h" | ||||
|  | ||||
| /* #include "THCBlas.h" */ | ||||
| #include "generic/THBlas.c" | ||||
| #include "THGenerateAllTypes.h" | ||||
							
								
								
									
										11
									
								
								lib/TH/THBlas.h
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										11
									
								
								lib/TH/THBlas.h
									
									
									
									
									
										Normal file
									
								
							| @ -0,0 +1,11 @@ | ||||
| #ifndef TH_BLAS_INC | ||||
| #define TH_BLAS_INC | ||||
|  | ||||
| #include "THGeneral.h" | ||||
|  | ||||
| #define THBlas_(NAME) TH_CONCAT_4(TH,Real,Blas_,NAME) | ||||
|  | ||||
| #include "generic/THBlas.h" | ||||
| #include "THGenerateAllTypes.h" | ||||
|  | ||||
| #endif | ||||
							
								
								
									
										8
									
								
								lib/TH/THCBlas.h.in
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										8
									
								
								lib/TH/THCBlas.h.in
									
									
									
									
									
										Normal file
									
								
							| @ -0,0 +1,8 @@ | ||||
| /* -*- C -*- */ | ||||
|  | ||||
| #cmakedefine USE_CBLAS @USE_CBLAS@ | ||||
|  | ||||
| #if USE_CBLAS | ||||
| # include "@CBLAS_INCLUDE_FILE@" | ||||
| #endif | ||||
|  | ||||
							
								
								
									
										9
									
								
								lib/TH/THConfig.cmake.in
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										9
									
								
								lib/TH/THConfig.cmake.in
									
									
									
									
									
										Normal file
									
								
							| @ -0,0 +1,9 @@ | ||||
| # Find the TH includes and library | ||||
| # | ||||
| # TH_INCLUDE_DIR -- where to find the includes | ||||
| # TH_LIBRARIES -- list of libraries to link against | ||||
| # TH_FOUND -- set to 1 if found | ||||
|  | ||||
| SET(TH_FOUND 1) | ||||
| SET(TH_INCLUDE_DIR "@TH_INCLUDE_DIR@") | ||||
| SET(TH_LIBRARIES "@TH_LIBRARIES@") | ||||
							
								
								
									
										592
									
								
								lib/TH/THDiskFile.c
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										592
									
								
								lib/TH/THDiskFile.c
									
									
									
									
									
										Normal file
									
								
							| @ -0,0 +1,592 @@ | ||||
| #include "THGeneral.h" | ||||
| #include "THDiskFile.h" | ||||
| #include "THFilePrivate.h" | ||||
|  | ||||
| typedef struct THDiskFile__ | ||||
| { | ||||
|     THFile file; | ||||
|  | ||||
|     FILE *handle; | ||||
|     char *name; | ||||
|     int isNativeEncoding; | ||||
|  | ||||
| } THDiskFile; | ||||
|  | ||||
| static int THDiskFile_isOpened(THFile *self) | ||||
| { | ||||
|   THDiskFile *dfself = (THDiskFile*)self; | ||||
|   return (dfself->handle != NULL); | ||||
| } | ||||
|  | ||||
| const char *THDiskFile_name(THFile *self) | ||||
| { | ||||
|   THDiskFile *dfself = (THDiskFile*)self; | ||||
|   return dfself->name; | ||||
| } | ||||
|  | ||||
|  | ||||
| #define READ_WRITE_METHODS(TYPE, TYPEC, ASCII_READ_ELEM, ASCII_WRITE_ELEM) \ | ||||
|   static long THDiskFile_read##TYPEC(THFile *self, TYPE *data, long n)  \ | ||||
|   {                                                                     \ | ||||
|     THDiskFile *dfself = (THDiskFile*)(self);                           \ | ||||
|     long nread = 0L;                                                    \ | ||||
|                                                                         \ | ||||
|     THArgCheck(dfself->handle != NULL, 1, "attempt to use a closed file"); \ | ||||
|     THArgCheck(dfself->file.isReadable, 1, "attempt to read in a write-only file"); \ | ||||
|                                                                         \ | ||||
|     if(dfself->file.isBinary)                                           \ | ||||
|     {                                                                   \ | ||||
|       nread = fread(data, sizeof(TYPE), n, dfself->handle);        \ | ||||
|       if(!dfself->isNativeEncoding && (sizeof(TYPE) > 1) && (nread > 0)) \ | ||||
|         THDiskFile_reverseMemory(data, data, sizeof(TYPE), nread);      \ | ||||
|     }                                                                   \ | ||||
|     else                                                                \ | ||||
|     {                                                                   \ | ||||
|       long i;                                                           \ | ||||
|       for(i = 0; i < n; i++)                                            \ | ||||
|       {                                                                 \ | ||||
|         ASCII_READ_ELEM; /* increment here result and break if wrong */ \ | ||||
|       }                                                                 \ | ||||
|       if(dfself->file.isAutoSpacing && (n > 0))                         \ | ||||
|       {                                                                 \ | ||||
|         int c = fgetc(dfself->handle);                                  \ | ||||
|         if( (c != '\n') && (c != EOF) )                                 \ | ||||
|           ungetc(c, dfself->handle);                                    \ | ||||
|       }                                                                 \ | ||||
|     }                                                                   \ | ||||
|                                                                         \ | ||||
|     if(nread != n)                                                      \ | ||||
|     {                                                                   \ | ||||
|       dfself->file.hasError = 1; /* shouldn't we put hasError to 0 all the time ? */ \ | ||||
|       if(!dfself->file.isQuiet)                                         \ | ||||
|         THError("read error: read %d blocks instead of %d", nread, n);  \ | ||||
|     }                                                                   \ | ||||
|                                                                         \ | ||||
|     return nread;                                                       \ | ||||
|   }                                                                     \ | ||||
|                                                                         \ | ||||
|   static long THDiskFile_write##TYPEC(THFile *self, TYPE *data, long n) \ | ||||
|   {                                                                     \ | ||||
|     THDiskFile *dfself = (THDiskFile*)(self);                           \ | ||||
|     long nwrite = 0L;                                                   \ | ||||
|                                                                         \ | ||||
|     THArgCheck(dfself->handle != NULL, 1, "attempt to use a closed file"); \ | ||||
|     THArgCheck(dfself->file.isWritable, 1, "attempt to write in a read-only file"); \ | ||||
|                                                                         \ | ||||
|     if(dfself->file.isBinary)                                           \ | ||||
|     {                                                                   \ | ||||
|       if(dfself->isNativeEncoding)                                      \ | ||||
|       {                                                                 \ | ||||
|         nwrite = fwrite(data, sizeof(TYPE), n, dfself->handle);         \ | ||||
|       }                                                                 \ | ||||
|       else                                                              \ | ||||
|       {                                                                 \ | ||||
|         if(sizeof(TYPE) > 1)                                            \ | ||||
|         {                                                               \ | ||||
|           char *buffer = THAlloc(sizeof(TYPE)*n);                       \ | ||||
|           THDiskFile_reverseMemory(buffer, data, sizeof(TYPE), n);      \ | ||||
|           nwrite = fwrite(buffer, sizeof(TYPE), n, dfself->handle);     \ | ||||
|           THFree(buffer);                                               \ | ||||
|         }                                                               \ | ||||
|         else                                                            \ | ||||
|           nwrite = fwrite(data, sizeof(TYPE), n, dfself->handle);       \ | ||||
|       }                                                                 \ | ||||
|     }                                                                   \ | ||||
|     else                                                                \ | ||||
|     {                                                                   \ | ||||
|       long i;                                                           \ | ||||
|       for(i = 0; i < n; i++)                                            \ | ||||
|       {                                                                 \ | ||||
|         ASCII_WRITE_ELEM;                                               \ | ||||
|         if( dfself->file.isAutoSpacing && (i < n-1) )                   \ | ||||
|           fprintf(dfself->handle, " ");                                 \ | ||||
|       }                                                                 \ | ||||
|       if(dfself->file.isAutoSpacing && (n > 0))                         \ | ||||
|         fprintf(dfself->handle, "\n");                                  \ | ||||
|     }                                                                   \ | ||||
|                                                                         \ | ||||
|     if(nwrite != n)                                                     \ | ||||
|     {                                                                   \ | ||||
|       dfself->file.hasError = 1;                                        \ | ||||
|       if(!dfself->file.isQuiet)                                         \ | ||||
|         THError("write error: wrote %d blocks instead of %d", nwrite, n); \ | ||||
|     }                                                                   \ | ||||
|                                                                         \ | ||||
|     return nwrite;                                                      \ | ||||
| } | ||||
|  | ||||
| static int THDiskFile_mode(const char *mode, int *isReadable, int *isWritable) | ||||
| { | ||||
|   *isReadable = 0; | ||||
|   *isWritable = 0; | ||||
|   if(strlen(mode) == 1) | ||||
|   { | ||||
|     if(*mode == 'r') | ||||
|     { | ||||
|       *isReadable = 1; | ||||
|       return 1; | ||||
|     } | ||||
|     else if(*mode == 'w') | ||||
|     { | ||||
|       *isWritable = 1; | ||||
|       return 1; | ||||
|     } | ||||
|   } | ||||
|   else if(strlen(mode) == 2) | ||||
|   { | ||||
|     if(mode[0] == 'r' && mode[1] == 'w') | ||||
|     { | ||||
|       *isReadable = 1; | ||||
|       *isWritable = 1; | ||||
|       return 1; | ||||
|     } | ||||
|   } | ||||
|   return 0; | ||||
| } | ||||
|  | ||||
| static void THDiskFile_synchronize(THFile *self) | ||||
| { | ||||
|   THDiskFile *dfself = (THDiskFile*)(self); | ||||
|   THArgCheck(dfself->handle != NULL, 1, "attempt to use a closed file"); | ||||
|   fflush(dfself->handle); | ||||
| } | ||||
|  | ||||
| static void THDiskFile_seek(THFile *self, long position) | ||||
| { | ||||
|   THDiskFile *dfself = (THDiskFile*)(self); | ||||
|  | ||||
|   THArgCheck(dfself->handle != NULL, 1, "attempt to use a closed file"); | ||||
|   THArgCheck(position >= 0, 2, "position must be positive"); | ||||
|  | ||||
|   if(fseek(dfself->handle, position, SEEK_SET) < 0) | ||||
|   { | ||||
|     dfself->file.hasError = 1; | ||||
|     if(!dfself->file.isQuiet) | ||||
|       THError("unable to seek at position %d", position); | ||||
|   } | ||||
| } | ||||
|  | ||||
| static void THDiskFile_seekEnd(THFile *self) | ||||
| { | ||||
|   THDiskFile *dfself = (THDiskFile*)(self); | ||||
|  | ||||
|   THArgCheck(dfself->handle != NULL, 1, "attempt to use a closed file"); | ||||
|  | ||||
|   if(fseek(dfself->handle, 0L, SEEK_END) < 0) | ||||
|   { | ||||
|     dfself->file.hasError = 1; | ||||
|     if(!dfself->file.isQuiet) | ||||
|       THError("unable to seek at end of file"); | ||||
|   } | ||||
| } | ||||
|  | ||||
| static long THDiskFile_position(THFile *self) | ||||
| { | ||||
|   THDiskFile *dfself = (THDiskFile*)(self); | ||||
|   THArgCheck(dfself->handle != NULL, 1, "attempt to use a closed file"); | ||||
|   return ftell(dfself->handle); | ||||
| } | ||||
|  | ||||
| static void THDiskFile_close(THFile *self) | ||||
| { | ||||
|   THDiskFile *dfself = (THDiskFile*)(self); | ||||
|   THArgCheck(dfself->handle != NULL, 1, "attempt to use a closed file"); | ||||
|   fclose(dfself->handle); | ||||
|   dfself->handle = NULL; | ||||
| } | ||||
|  | ||||
| /* Little and Big Endian */ | ||||
|  | ||||
| static void THDiskFile_reverseMemory(void *dst, const void *src, long blockSize, long numBlocks) | ||||
| { | ||||
|   if(blockSize != 1) | ||||
|   { | ||||
|     long halfBlockSize = blockSize/2; | ||||
|     char *charSrc = (char*)src; | ||||
|     char *charDst = (char*)dst; | ||||
|     long b, i; | ||||
|     for(b = 0; b < numBlocks; b++) | ||||
|     { | ||||
|       for(i = 0; i < halfBlockSize; i++) | ||||
|       { | ||||
|         char z = charSrc[i]; | ||||
|         charDst[i] = charSrc[blockSize-1-i]; | ||||
|         charDst[blockSize-1-i] = z; | ||||
|       } | ||||
|       charSrc += blockSize; | ||||
|       charDst += blockSize; | ||||
|     } | ||||
|   } | ||||
| } | ||||
|  | ||||
| int THDiskFile_isLittleEndianCPU(void) | ||||
| { | ||||
|   int x = 7; | ||||
|   char *ptr = (char *)&x; | ||||
|  | ||||
|   if(ptr[0] == 0) | ||||
|     return 0; | ||||
|   else | ||||
|     return 1; | ||||
| } | ||||
|  | ||||
| int THDiskFile_isBigEndianCPU(void) | ||||
| { | ||||
|   return(!THDiskFile_isLittleEndianCPU()); | ||||
| } | ||||
|  | ||||
| void THDiskFile_nativeEndianEncoding(THFile *self) | ||||
| { | ||||
|   THDiskFile *dfself = (THDiskFile*)(self); | ||||
|   THArgCheck(dfself->handle != NULL, 1, "attempt to use a closed file"); | ||||
|   dfself->isNativeEncoding = 1; | ||||
| } | ||||
|  | ||||
| void THDiskFile_littleEndianEncoding(THFile *self) | ||||
| { | ||||
|   THDiskFile *dfself = (THDiskFile*)(self); | ||||
|   THArgCheck(dfself->handle != NULL, 1, "attempt to use a closed file"); | ||||
|   dfself->isNativeEncoding = THDiskFile_isLittleEndianCPU(); | ||||
| } | ||||
|  | ||||
| void THDiskFile_bigEndianEncoding(THFile *self) | ||||
| { | ||||
|   THDiskFile *dfself = (THDiskFile*)(self); | ||||
|   THArgCheck(dfself->handle != NULL, 1, "attempt to use a closed file"); | ||||
|   dfself->isNativeEncoding = !THDiskFile_isLittleEndianCPU(); | ||||
| } | ||||
|  | ||||
| /* End of Little and Big Endian Stuff */ | ||||
|  | ||||
| static void THDiskFile_free(THFile *self) | ||||
| { | ||||
|   THDiskFile *dfself = (THDiskFile*)(self); | ||||
|   if(dfself->handle) | ||||
|     fclose(dfself->handle); | ||||
|   THFree(dfself->name); | ||||
|   THFree(dfself); | ||||
| } | ||||
|  | ||||
| /* READ_WRITE_METHODS(int, Bool, */ | ||||
| /*                    int value = 0; int ret = fscanf(file->handle, "%d", &value); array[i] = (value ? 1 : 0); if(ret <= 0) break; else result++, */ | ||||
| /*                    int value = (array[i] ? 1 : 0); nElemWritten = fprintf(file->handle, "%d", value), */ | ||||
| /*                    true) */ | ||||
|  | ||||
| /* Note that we do a trick */ | ||||
| READ_WRITE_METHODS(unsigned char, Byte, | ||||
|                    nread = fread(data, 1, n, dfself->handle); break, | ||||
|                    nwrite = fwrite(data, 1, n, dfself->handle); break) | ||||
|  | ||||
| READ_WRITE_METHODS(char, Char, | ||||
|                    nread = fread(data, 1, n, dfself->handle); break, | ||||
|                    nwrite = fwrite(data, 1, n, dfself->handle); break) | ||||
|  | ||||
| READ_WRITE_METHODS(short, Short, | ||||
|                    int ret = fscanf(dfself->handle, "%hd", &data[i]); if(ret <= 0) break; else nread++, | ||||
|                    int ret = fprintf(dfself->handle, "%hd", data[i]); if(ret <= 0) break; else nwrite++) | ||||
|  | ||||
| READ_WRITE_METHODS(int, Int, | ||||
|                    int ret = fscanf(dfself->handle, "%d", &data[i]); if(ret <= 0) break; else nread++, | ||||
|                    int ret = fprintf(dfself->handle, "%d", data[i]); if(ret <= 0) break; else nwrite++) | ||||
|  | ||||
| READ_WRITE_METHODS(long, Long, | ||||
|                    int ret = fscanf(dfself->handle, "%ld", &data[i]); if(ret <= 0) break; else nread++, | ||||
|                    int ret = fprintf(dfself->handle, "%ld", data[i]); if(ret <= 0) break; else nwrite++) | ||||
|  | ||||
| READ_WRITE_METHODS(float, Float, | ||||
|                    int ret = fscanf(dfself->handle, "%g", &data[i]); if(ret <= 0) break; else nread++, | ||||
|                    int ret = fprintf(dfself->handle, "%g", data[i]); if(ret <= 0) break; else nwrite++) | ||||
|  | ||||
| READ_WRITE_METHODS(double, Double, | ||||
|                    int ret = fscanf(dfself->handle, "%lg", &data[i]); if(ret <= 0) break; else nread++, | ||||
|                    int ret = fprintf(dfself->handle, "%lg", data[i]); if(ret <= 0) break; else nwrite++) | ||||
|  | ||||
| static long THDiskFile_readString(THFile *self, const char *format, char **str_) | ||||
| { | ||||
|   THDiskFile *dfself = (THDiskFile*)(self); | ||||
|   THArgCheck(dfself->handle != NULL, 1, "attempt to use a closed file"); | ||||
|   THArgCheck(dfself->file.isReadable, 1, "attempt to read in a write-only file"); | ||||
|   THArgCheck((strlen(format) >= 2 ? (format[0] == '*') && (format[1] == 'a' || format[1] == 'l') : 0), 2, "format must be '*a' or '*l'"); | ||||
|  | ||||
| /* note: the string won't survive long, as it is copied into lua */ | ||||
| /* so 1024 is not that big... */ | ||||
| #define TBRS_BSZ 1024L | ||||
|  | ||||
|   if(format[1] == 'a') | ||||
|   { | ||||
|     char *p = THAlloc(TBRS_BSZ); | ||||
|     long total = TBRS_BSZ; | ||||
|     long pos = 0L; | ||||
|      | ||||
|     for (;;) | ||||
|     { | ||||
|       if(total-pos == 0) /* we need more space! */ | ||||
|       { | ||||
|         total += TBRS_BSZ; | ||||
|         p = THRealloc(p, total); | ||||
|       } | ||||
|       pos += fread(p+pos, 1, total-pos, dfself->handle); | ||||
|       if (pos < total) /* eof? */ | ||||
|       { | ||||
|         if(pos == 0L) | ||||
|         { | ||||
|           THFree(p); | ||||
|           dfself->file.hasError = 1; | ||||
|           if(!dfself->file.isQuiet) | ||||
|             THError("read error: read 0 blocks instead of 1"); | ||||
|  | ||||
|           *str_ = NULL; | ||||
|           return 0; | ||||
|         } | ||||
|         *str_ = p; | ||||
|         return pos; | ||||
|       } | ||||
|     }     | ||||
|   } | ||||
|   else | ||||
|   { | ||||
|     char *p = THAlloc(TBRS_BSZ); | ||||
|     long total = TBRS_BSZ; | ||||
|     long pos = 0L; | ||||
|     long size; | ||||
|  | ||||
|     for (;;) | ||||
|     { | ||||
|       if(total-pos <= 1) /* we can only write '\0' in there! */ | ||||
|       { | ||||
|         total += TBRS_BSZ; | ||||
|         p = THRealloc(p, total); | ||||
|       } | ||||
|       if (fgets(p+pos, total-pos, dfself->handle) == NULL) /* eof? */ | ||||
|       { | ||||
|         if(pos == 0L) | ||||
|         { | ||||
|           THFree(p); | ||||
|           dfself->file.hasError = 1; | ||||
|           if(!dfself->file.isQuiet) | ||||
|             THError("read error: read 0 blocks instead of 1"); | ||||
|  | ||||
|           *str_ = NULL; | ||||
|           return 0; | ||||
|         } | ||||
|         *str_ = p; | ||||
|         return pos; | ||||
|       } | ||||
|       size = strlen(p+pos); | ||||
|       if (size == 0L || (p+pos)[size-1] != '\n') | ||||
|       { | ||||
|         pos += size; | ||||
|       } | ||||
|       else | ||||
|       { | ||||
|         pos += size-1L; /* do not include `eol' */ | ||||
|         *str_ = p; | ||||
|         return pos; | ||||
|       } | ||||
|     } | ||||
|   } | ||||
|  | ||||
|   *str_ = NULL; | ||||
|   return 0; | ||||
| } | ||||
|  | ||||
|  | ||||
| static long THDiskFile_writeString(THFile *self, const char *str, long size) | ||||
| { | ||||
|   THDiskFile *dfself = (THDiskFile*)(self); | ||||
|   long nwrite; | ||||
|  | ||||
|   THArgCheck(dfself->handle != NULL, 1, "attempt to use a closed file"); | ||||
|   THArgCheck(dfself->file.isWritable, 1, "attempt to write in a read-only file"); | ||||
|  | ||||
|   nwrite = fwrite(str, 1, size, dfself->handle); | ||||
|   if(nwrite != size) | ||||
|   { | ||||
|     dfself->file.hasError = 1; | ||||
|     if(!dfself->file.isQuiet) | ||||
|       THError("write error: wrote %ld blocks instead of %ld", nwrite, size); | ||||
|   } | ||||
|  | ||||
|   return nwrite; | ||||
| } | ||||
|  | ||||
| THFile *THDiskFile_new(const char *name, const char *mode, int isQuiet) | ||||
| { | ||||
|   static struct THFileVTable vtable = { | ||||
|     THDiskFile_isOpened, | ||||
|  | ||||
|     THDiskFile_readByte, | ||||
|     THDiskFile_readChar, | ||||
|     THDiskFile_readShort, | ||||
|     THDiskFile_readInt, | ||||
|     THDiskFile_readLong, | ||||
|     THDiskFile_readFloat, | ||||
|     THDiskFile_readDouble, | ||||
|     THDiskFile_readString, | ||||
|  | ||||
|     THDiskFile_writeByte, | ||||
|     THDiskFile_writeChar, | ||||
|     THDiskFile_writeShort, | ||||
|     THDiskFile_writeInt, | ||||
|     THDiskFile_writeLong, | ||||
|     THDiskFile_writeFloat, | ||||
|     THDiskFile_writeDouble, | ||||
|     THDiskFile_writeString, | ||||
|  | ||||
|     THDiskFile_synchronize, | ||||
|     THDiskFile_seek, | ||||
|     THDiskFile_seekEnd, | ||||
|     THDiskFile_position, | ||||
|     THDiskFile_close, | ||||
|     THDiskFile_free | ||||
|   }; | ||||
|  | ||||
|   int isReadable; | ||||
|   int isWritable; | ||||
|   FILE *handle; | ||||
|   THDiskFile *self; | ||||
|  | ||||
|   THArgCheck(THDiskFile_mode(mode, &isReadable, &isWritable), 2, "file mode should be 'r','w' or 'rw'"); | ||||
|  | ||||
|   if( isReadable && isWritable ) | ||||
|   { | ||||
|     handle = fopen(name, "r+b"); | ||||
|     if(!handle) | ||||
|     { | ||||
|       handle = fopen(name, "wb"); | ||||
|       if(handle) | ||||
|       { | ||||
|         fclose(handle); | ||||
|         handle = fopen(name, "r+b"); | ||||
|       } | ||||
|     } | ||||
|   } | ||||
|   else | ||||
|     handle = fopen(name, (isReadable ? "rb" : "wb")); | ||||
|  | ||||
|   if(!handle) | ||||
|   { | ||||
|     if(isQuiet) | ||||
|       return 0; | ||||
|     else | ||||
|       THError("cannot open <%s> in mode %c%c", name, (isReadable ? 'r' : ' '), (isWritable ? 'w' : ' ')); | ||||
|   } | ||||
|  | ||||
|   self = THAlloc(sizeof(THDiskFile)); | ||||
|  | ||||
|   self->handle = handle; | ||||
|   self->name = THAlloc(strlen(name)+1); | ||||
|   strcpy(self->name, name); | ||||
|   self->isNativeEncoding = 1; | ||||
|  | ||||
|   self->file.vtable = &vtable; | ||||
|   self->file.isQuiet = isQuiet; | ||||
|   self->file.isReadable = isReadable; | ||||
|   self->file.isWritable = isWritable; | ||||
|   self->file.isBinary = 0; | ||||
|   self->file.isAutoSpacing = 1; | ||||
|   self->file.hasError = 0; | ||||
|  | ||||
|   return (THFile*)self; | ||||
| } | ||||
|  | ||||
| /* PipeFile */ | ||||
|  | ||||
| static int THPipeFile_mode(const char *mode, int *isReadable, int *isWritable) | ||||
| { | ||||
|   *isReadable = 0; | ||||
|   *isWritable = 0; | ||||
|   if(strlen(mode) == 1) | ||||
|   { | ||||
|     if(*mode == 'r') | ||||
|     { | ||||
|       *isReadable = 1; | ||||
|       return 1; | ||||
|     } | ||||
|     else if(*mode == 'w') | ||||
|     { | ||||
|       *isWritable = 1; | ||||
|       return 1; | ||||
|     } | ||||
|   } | ||||
|   return 0; | ||||
| } | ||||
|  | ||||
| static void THPipeFile_free(THFile *self) | ||||
| { | ||||
|   THDiskFile *dfself = (THDiskFile*)(self); | ||||
|   if(dfself->handle) | ||||
|     pclose(dfself->handle); | ||||
|   THFree(dfself->name); | ||||
|   THFree(dfself); | ||||
| } | ||||
|  | ||||
| THFile *THPipeFile_new(const char *name, const char *mode, int isQuiet) | ||||
| { | ||||
|   static struct THFileVTable vtable = { | ||||
|     THDiskFile_isOpened, | ||||
|  | ||||
|     THDiskFile_readByte, | ||||
|     THDiskFile_readChar, | ||||
|     THDiskFile_readShort, | ||||
|     THDiskFile_readInt, | ||||
|     THDiskFile_readLong, | ||||
|     THDiskFile_readFloat, | ||||
|     THDiskFile_readDouble, | ||||
|     THDiskFile_readString, | ||||
|  | ||||
|     THDiskFile_writeByte, | ||||
|     THDiskFile_writeChar, | ||||
|     THDiskFile_writeShort, | ||||
|     THDiskFile_writeInt, | ||||
|     THDiskFile_writeLong, | ||||
|     THDiskFile_writeFloat, | ||||
|     THDiskFile_writeDouble, | ||||
|     THDiskFile_writeString, | ||||
|  | ||||
|     THDiskFile_synchronize, | ||||
|     THDiskFile_seek, | ||||
|     THDiskFile_seekEnd, | ||||
|     THDiskFile_position, | ||||
|     THDiskFile_close, | ||||
|     THPipeFile_free | ||||
|   }; | ||||
|  | ||||
|   int isReadable; | ||||
|   int isWritable; | ||||
|   FILE *handle; | ||||
|   THDiskFile *self; | ||||
|  | ||||
|   THArgCheck(THPipeFile_mode(mode, &isReadable, &isWritable), 2, "file mode should be 'r','w'"); | ||||
|  | ||||
| #ifdef _WIN32 | ||||
|   handle = popen(name, (isReadable ? "rb" : "wb")); | ||||
| #else | ||||
|   handle = popen(name, (isReadable ? "r" : "w")); | ||||
| #endif | ||||
|  | ||||
|   if(!handle) | ||||
|   { | ||||
|     if(isQuiet) | ||||
|       return 0; | ||||
|     else | ||||
|       THError("cannot open <%s> in mode %c%c", name, (isReadable ? 'r' : ' '), (isWritable ? 'w' : ' ')); | ||||
|   } | ||||
|  | ||||
|   self = THAlloc(sizeof(THDiskFile)); | ||||
|  | ||||
|   self->handle = handle; | ||||
|   self->name = THAlloc(strlen(name)+1); | ||||
|   strcpy(self->name, name); | ||||
|   self->isNativeEncoding = 1; | ||||
|  | ||||
|   self->file.vtable = &vtable; | ||||
|   self->file.isQuiet = isQuiet; | ||||
|   self->file.isReadable = isReadable; | ||||
|   self->file.isWritable = isWritable; | ||||
|   self->file.isBinary = 0; | ||||
|   self->file.isAutoSpacing = 1; | ||||
|   self->file.hasError = 0; | ||||
|  | ||||
|   return (THFile*)self; | ||||
| } | ||||
							
								
								
									
										17
									
								
								lib/TH/THDiskFile.h
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										17
									
								
								lib/TH/THDiskFile.h
									
									
									
									
									
										Normal file
									
								
							| @ -0,0 +1,17 @@ | ||||
| #ifndef TH_DISK_FILE_INC | ||||
| #define TH_DISK_FILE_INC | ||||
|  | ||||
| #include "THFile.h" | ||||
|  | ||||
| THFile *THDiskFile_new(const char *name, const char *mode, int isQuiet); | ||||
| THFile *THPipeFile_new(const char *name, const char *mode, int isQuiet); | ||||
|  | ||||
| const char *THDiskFile_name(THFile *self); | ||||
|  | ||||
| int THDiskFile_isLittleEndianCPU(void); | ||||
| int THDiskFile_isBigEndianCPU(void); | ||||
| void THDiskFile_nativeEndianEncoding(THFile *self); | ||||
| void THDiskFile_littleEndianEncoding(THFile *self); | ||||
| void THDiskFile_bigEndianEncoding(THFile *self); | ||||
|  | ||||
| #endif | ||||
							
								
								
									
										154
									
								
								lib/TH/THFile.c
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										154
									
								
								lib/TH/THFile.c
									
									
									
									
									
										Normal file
									
								
							| @ -0,0 +1,154 @@ | ||||
| #include "THFile.h" | ||||
| #include "THFilePrivate.h" | ||||
|  | ||||
| #define IMPLEMENT_THFILE_RW(TYPEC, TYPE)                          \ | ||||
|   long THFile_read##TYPEC##Raw(THFile *self, TYPE *data, long n)  \ | ||||
|   {                                                               \ | ||||
|     return (*self->vtable->read##TYPEC)(self, data, n);           \ | ||||
|   }                                                               \ | ||||
|                                                                   \ | ||||
|   long THFile_write##TYPEC##Raw(THFile *self, TYPE *data, long n) \ | ||||
|   {                                                               \ | ||||
|     return (*self->vtable->write##TYPEC)(self, data, n);          \ | ||||
|   } | ||||
|    | ||||
| IMPLEMENT_THFILE_RW(Byte, unsigned char) | ||||
| IMPLEMENT_THFILE_RW(Char, char) | ||||
| IMPLEMENT_THFILE_RW(Short, short) | ||||
| IMPLEMENT_THFILE_RW(Int, int) | ||||
| IMPLEMENT_THFILE_RW(Long, long) | ||||
| IMPLEMENT_THFILE_RW(Float, float) | ||||
| IMPLEMENT_THFILE_RW(Double, double) | ||||
|  | ||||
| long THFile_readStringRaw(THFile *self, const char *format, char **str_) | ||||
| { | ||||
|   return self->vtable->readString(self, format, str_); | ||||
| } | ||||
|  | ||||
| long THFile_writeStringRaw(THFile *self, const char *str, long size) | ||||
| { | ||||
|   return self->vtable->writeString(self, str, size); | ||||
| } | ||||
|  | ||||
| void THFile_synchronize(THFile *self) | ||||
| { | ||||
|   self->vtable->synchronize(self); | ||||
| } | ||||
|  | ||||
| void THFile_seek(THFile *self, long position) | ||||
| { | ||||
|   self->vtable->seek(self, position); | ||||
| } | ||||
|  | ||||
| void THFile_seekEnd(THFile *self) | ||||
| { | ||||
|   self->vtable->seekEnd(self); | ||||
| } | ||||
|  | ||||
| long THFile_position(THFile *self) | ||||
| { | ||||
|   return self->vtable->position(self); | ||||
| } | ||||
|  | ||||
| void THFile_close(THFile *self) | ||||
| { | ||||
|   self->vtable->close(self); | ||||
| } | ||||
|  | ||||
| void THFile_free(THFile *self) | ||||
| { | ||||
|   self->vtable->free(self); | ||||
| } | ||||
|  | ||||
| int THFile_isOpened(THFile *self) | ||||
| { | ||||
|   return self->vtable->isOpened(self); | ||||
| } | ||||
|  | ||||
| #define IMPLEMENT_THFILE_FLAGS(FLAG) \ | ||||
|   int THFile_##FLAG(THFile *self)    \ | ||||
|   {                                  \ | ||||
|     return self->FLAG;               \ | ||||
|   } | ||||
|  | ||||
| IMPLEMENT_THFILE_FLAGS(isQuiet) | ||||
| IMPLEMENT_THFILE_FLAGS(isReadable) | ||||
| IMPLEMENT_THFILE_FLAGS(isWritable) | ||||
| IMPLEMENT_THFILE_FLAGS(isBinary) | ||||
| IMPLEMENT_THFILE_FLAGS(isAutoSpacing) | ||||
| IMPLEMENT_THFILE_FLAGS(hasError) | ||||
|  | ||||
| void THFile_binary(THFile *self) | ||||
| { | ||||
|   self->isBinary = 1; | ||||
| } | ||||
|  | ||||
| void THFile_ascii(THFile *self) | ||||
| { | ||||
|   self->isBinary = 0; | ||||
| } | ||||
|  | ||||
| void THFile_autoSpacing(THFile *self) | ||||
| { | ||||
|   self->isAutoSpacing = 1; | ||||
| } | ||||
|  | ||||
| void THFile_noAutoSpacing(THFile *self) | ||||
| { | ||||
|   self->isAutoSpacing = 0; | ||||
| } | ||||
|  | ||||
| void THFile_quiet(THFile *self) | ||||
| { | ||||
|   self->isQuiet = 1; | ||||
| } | ||||
|  | ||||
| void THFile_pedantic(THFile *self) | ||||
| { | ||||
|   self->isQuiet = 0; | ||||
| } | ||||
|  | ||||
| void THFile_clearError(THFile *self) | ||||
| { | ||||
|   self->hasError = 0; | ||||
| } | ||||
|  | ||||
| #define IMPLEMENT_THFILE_SCALAR(TYPEC, TYPE)                  \ | ||||
|   TYPE THFile_read##TYPEC##Scalar(THFile *self)               \ | ||||
|   {                                                           \ | ||||
|     TYPE scalar;                                              \ | ||||
|     THFile_read##TYPEC##Raw(self, &scalar, 1);                \ | ||||
|     return scalar;                                            \ | ||||
|   }                                                           \ | ||||
|                                                               \ | ||||
|   void THFile_write##TYPEC##Scalar(THFile *self, TYPE scalar) \ | ||||
|   {                                                           \ | ||||
|     THFile_write##TYPEC##Raw(self, &scalar, 1);               \ | ||||
|   } | ||||
|  | ||||
| IMPLEMENT_THFILE_SCALAR(Byte, unsigned char) | ||||
| IMPLEMENT_THFILE_SCALAR(Char, char) | ||||
| IMPLEMENT_THFILE_SCALAR(Short, short) | ||||
| IMPLEMENT_THFILE_SCALAR(Int, int) | ||||
| IMPLEMENT_THFILE_SCALAR(Long, long) | ||||
| IMPLEMENT_THFILE_SCALAR(Float, float) | ||||
| IMPLEMENT_THFILE_SCALAR(Double, double) | ||||
|  | ||||
| #define IMPLEMENT_THFILE_STORAGE(TYPEC, TYPE)                           \ | ||||
|   long THFile_read##TYPEC(THFile *self, TH##TYPEC##Storage *storage)    \ | ||||
|   {                                                                     \ | ||||
|     return THFile_read##TYPEC##Raw(self, storage->data, storage->size); \ | ||||
|   }                                                                     \ | ||||
|                                                                         \ | ||||
|   long THFile_write##TYPEC(THFile *self, TH##TYPEC##Storage *storage)   \ | ||||
|   {                                                                     \ | ||||
|     return THFile_write##TYPEC##Raw(self, storage->data, storage->size); \ | ||||
|   } | ||||
|  | ||||
| IMPLEMENT_THFILE_STORAGE(Byte, unsigned char) | ||||
| IMPLEMENT_THFILE_STORAGE(Char, char) | ||||
| IMPLEMENT_THFILE_STORAGE(Short, short) | ||||
| IMPLEMENT_THFILE_STORAGE(Int, int) | ||||
| IMPLEMENT_THFILE_STORAGE(Long, long) | ||||
| IMPLEMENT_THFILE_STORAGE(Float, float) | ||||
| IMPLEMENT_THFILE_STORAGE(Double, double) | ||||
							
								
								
									
										84
									
								
								lib/TH/THFile.h
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										84
									
								
								lib/TH/THFile.h
									
									
									
									
									
										Normal file
									
								
							| @ -0,0 +1,84 @@ | ||||
| #ifndef TH_FILE_INC | ||||
| #define TH_FILE_INC | ||||
|  | ||||
| #include "THStorage.h" | ||||
|  | ||||
| typedef struct THFile__ THFile; | ||||
|  | ||||
| int THFile_isOpened(THFile *self); | ||||
| int THFile_isQuiet(THFile *self); | ||||
| int THFile_isReadable(THFile *self); | ||||
| int THFile_isWritable(THFile *self); | ||||
| int THFile_isBinary(THFile *self); | ||||
| int THFile_isAutoSpacing(THFile *self); | ||||
| int THFile_hasError(THFile *self); | ||||
|  | ||||
| void THFile_binary(THFile *self); | ||||
| void THFile_ascii(THFile *self); | ||||
| void THFile_autoSpacing(THFile *self); | ||||
| void THFile_noAutoSpacing(THFile *self); | ||||
| void THFile_quiet(THFile *self); | ||||
| void THFile_pedantic(THFile *self); | ||||
| void THFile_clearError(THFile *self); | ||||
|  | ||||
| /* scalar */ | ||||
| unsigned char THFile_readByteScalar(THFile *self); | ||||
| char THFile_readCharScalar(THFile *self); | ||||
| short THFile_readShortScalar(THFile *self); | ||||
| int THFile_readIntScalar(THFile *self); | ||||
| long THFile_readLongScalar(THFile *self); | ||||
| float THFile_readFloatScalar(THFile *self); | ||||
| double THFile_readDoubleScalar(THFile *self); | ||||
|  | ||||
| void THFile_writeByteScalar(THFile *self, unsigned char scalar); | ||||
| void THFile_writeCharScalar(THFile *self, char scalar); | ||||
| void THFile_writeShortScalar(THFile *self, short scalar); | ||||
| void THFile_writeIntScalar(THFile *self, int scalar); | ||||
| void THFile_writeLongScalar(THFile *self, long scalar); | ||||
| void THFile_writeFloatScalar(THFile *self, float scalar); | ||||
| void THFile_writeDoubleScalar(THFile *self, double scalar); | ||||
|  | ||||
| /* storage */ | ||||
| long THFile_readByte(THFile *self, THByteStorage *storage); | ||||
| long THFile_readChar(THFile *self, THCharStorage *storage); | ||||
| long THFile_readShort(THFile *self, THShortStorage *storage); | ||||
| long THFile_readInt(THFile *self, THIntStorage *storage); | ||||
| long THFile_readLong(THFile *self, THLongStorage *storage); | ||||
| long THFile_readFloat(THFile *self, THFloatStorage *storage); | ||||
| long THFile_readDouble(THFile *self, THDoubleStorage *storage); | ||||
|  | ||||
| long THFile_writeByte(THFile *self, THByteStorage *storage); | ||||
| long THFile_writeChar(THFile *self, THCharStorage *storage); | ||||
| long THFile_writeShort(THFile *self, THShortStorage *storage); | ||||
| long THFile_writeInt(THFile *self, THIntStorage *storage); | ||||
| long THFile_writeLong(THFile *self, THLongStorage *storage); | ||||
| long THFile_writeFloat(THFile *self, THFloatStorage *storage); | ||||
| long THFile_writeDouble(THFile *self, THDoubleStorage *storage); | ||||
|  | ||||
| /* raw */ | ||||
| long THFile_readByteRaw(THFile *self, unsigned char *data, long n); | ||||
| long THFile_readCharRaw(THFile *self, char *data, long n); | ||||
| long THFile_readShortRaw(THFile *self, short *data, long n); | ||||
| long THFile_readIntRaw(THFile *self, int *data, long n); | ||||
| long THFile_readLongRaw(THFile *self, long *data, long n); | ||||
| long THFile_readFloatRaw(THFile *self, float *data, long n); | ||||
| long THFile_readDoubleRaw(THFile *self, double *data, long n); | ||||
| long THFile_readStringRaw(THFile *self, const char *format, char **str_); /* you must deallocate str_ */ | ||||
|  | ||||
| long THFile_writeByteRaw(THFile *self, unsigned char *data, long n); | ||||
| long THFile_writeCharRaw(THFile *self, char *data, long n); | ||||
| long THFile_writeShortRaw(THFile *self, short *data, long n); | ||||
| long THFile_writeIntRaw(THFile *self, int *data, long n); | ||||
| long THFile_writeLongRaw(THFile *self, long *data, long n); | ||||
| long THFile_writeFloatRaw(THFile *self, float *data, long n); | ||||
| long THFile_writeDoubleRaw(THFile *self, double *data, long n); | ||||
| long THFile_writeStringRaw(THFile *self, const char *str, long size); | ||||
|  | ||||
| void THFile_synchronize(THFile *self); | ||||
| void THFile_seek(THFile *self, long position); | ||||
| void THFile_seekEnd(THFile *self); | ||||
| long THFile_position(THFile *self); | ||||
| void THFile_close(THFile *self); | ||||
| void THFile_free(THFile *self); | ||||
|  | ||||
| #endif | ||||
							
								
								
									
										43
									
								
								lib/TH/THFilePrivate.h
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										43
									
								
								lib/TH/THFilePrivate.h
									
									
									
									
									
										Normal file
									
								
							| @ -0,0 +1,43 @@ | ||||
| struct THFile__ | ||||
| { | ||||
|     struct THFileVTable *vtable; | ||||
|  | ||||
|     int isQuiet; | ||||
|     int isReadable; | ||||
|     int isWritable; | ||||
|     int isBinary; | ||||
|     int isAutoSpacing; | ||||
|     int hasError; | ||||
| }; | ||||
|  | ||||
| /* virtual table definition */ | ||||
|  | ||||
| struct THFileVTable | ||||
| { | ||||
|     int (*isOpened)(THFile *self); | ||||
|  | ||||
|     long (*readByte)(THFile *self, unsigned char *data, long n); | ||||
|     long (*readChar)(THFile *self, char *data, long n); | ||||
|     long (*readShort)(THFile *self, short *data, long n); | ||||
|     long (*readInt)(THFile *self, int *data, long n); | ||||
|     long (*readLong)(THFile *self, long *data, long n); | ||||
|     long (*readFloat)(THFile *self, float *data, long n); | ||||
|     long (*readDouble)(THFile *self, double *data, long n); | ||||
|     long (*readString)(THFile *self, const char *format, char **str_); | ||||
|  | ||||
|     long (*writeByte)(THFile *self, unsigned char *data, long n); | ||||
|     long (*writeChar)(THFile *self, char *data, long n); | ||||
|     long (*writeShort)(THFile *self, short *data, long n); | ||||
|     long (*writeInt)(THFile *self, int *data, long n); | ||||
|     long (*writeLong)(THFile *self, long *data, long n); | ||||
|     long (*writeFloat)(THFile *self, float *data, long n); | ||||
|     long (*writeDouble)(THFile *self, double *data, long n); | ||||
|     long (*writeString)(THFile *self, const char *str, long size); | ||||
|  | ||||
|     void (*synchronize)(THFile *self); | ||||
|     void (*seek)(THFile *self, long position); | ||||
|     void (*seekEnd)(THFile *self); | ||||
|     long (*position)(THFile *self); | ||||
|     void (*close)(THFile *self); | ||||
|     void (*free)(THFile *self); | ||||
| }; | ||||
							
								
								
									
										110
									
								
								lib/TH/THGeneral.c
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										110
									
								
								lib/TH/THGeneral.c
									
									
									
									
									
										Normal file
									
								
							| @ -0,0 +1,110 @@ | ||||
| #include "THGeneral.h" | ||||
|  | ||||
| /* Torch Error Handling */ | ||||
| static void defaultTorchErrorHandlerFunction(const char *msg) | ||||
| { | ||||
|   printf("$ Error: %s\n", msg); | ||||
|   exit(-1); | ||||
| } | ||||
|  | ||||
| static void (*torchErrorHandlerFunction)(const char *msg) = defaultTorchErrorHandlerFunction; | ||||
|  | ||||
| void THError(const char *fmt, ...) | ||||
| { | ||||
|   static char msg[1024]; | ||||
|   va_list args; | ||||
|  | ||||
|   /* vasprintf not standard */ | ||||
|   /* vsnprintf: how to handle if does not exists? */ | ||||
|   va_start(args, fmt); | ||||
|   vsnprintf(msg, 1024, fmt, args); | ||||
|   va_end(args); | ||||
|  | ||||
|   (*torchErrorHandlerFunction)(msg); | ||||
| } | ||||
|  | ||||
| void THSetErrorHandler( void (*torchErrorHandlerFunction_)(const char *msg) ) | ||||
| { | ||||
|   if(torchErrorHandlerFunction_) | ||||
|     torchErrorHandlerFunction = torchErrorHandlerFunction_; | ||||
|   else | ||||
|     torchErrorHandlerFunction = defaultTorchErrorHandlerFunction; | ||||
| } | ||||
|  | ||||
| /* Torch Arg Checking Handling */ | ||||
| static void defaultTorchArgCheckHandlerFunction(int condition, int argNumber, const char *msg) | ||||
| { | ||||
|   if(!condition) | ||||
|   { | ||||
|     if(msg) | ||||
|       printf("$ Invalid argument %d: %s\n", argNumber, msg); | ||||
|     else | ||||
|       printf("$ Invalid argument %d\n", argNumber); | ||||
|     exit(-1); | ||||
|   } | ||||
| } | ||||
| static void (*torchArgCheckHandlerFunction)(int condition, int argNumber, const char *msg) = defaultTorchArgCheckHandlerFunction; | ||||
|  | ||||
| void THArgCheck(int condition, int argNumber, const char *msg) | ||||
| { | ||||
|   (*torchArgCheckHandlerFunction)(condition, argNumber, msg); | ||||
| } | ||||
|  | ||||
| void THSetArgCheckHandler( void (*torchArgCheckHandlerFunction_)(int condition, int argNumber, const char *msg) ) | ||||
| { | ||||
|   if(torchArgCheckHandlerFunction_) | ||||
|     torchArgCheckHandlerFunction = torchArgCheckHandlerFunction_; | ||||
|   else | ||||
|     torchArgCheckHandlerFunction = defaultTorchArgCheckHandlerFunction; | ||||
| } | ||||
|  | ||||
| void* THAlloc(long size) | ||||
| { | ||||
|   void *ptr; | ||||
|  | ||||
|   if(size < 0) | ||||
|     THError("$ Torch: invalid memory size -- maybe an overflow?"); | ||||
|  | ||||
|   if(size == 0) | ||||
|     return NULL; | ||||
|  | ||||
|   ptr = malloc(size); | ||||
|   if(!ptr) | ||||
|     THError("$ Torch: not enough memory: you tried to allocate %dGB. Buy new RAM!", size/1073741824); | ||||
|  | ||||
|   return ptr; | ||||
| } | ||||
|  | ||||
| void* THRealloc(void *ptr, long size) | ||||
| { | ||||
|   if(!ptr) | ||||
|     return(THAlloc(size)); | ||||
|    | ||||
|   if(size == 0) | ||||
|   { | ||||
|     THFree(ptr); | ||||
|     return NULL; | ||||
|   } | ||||
|  | ||||
|   if(size < 0) | ||||
|     THError("$ Torch: invalid memory size -- maybe an overflow?"); | ||||
|  | ||||
|   ptr = realloc(ptr, size); | ||||
|   if(!ptr) | ||||
|     THError("$ Torch: not enough memory: you tried to reallocate %dGB. Buy new RAM!", size/1073741824); | ||||
|   return ptr; | ||||
| } | ||||
|  | ||||
| void THFree(void *ptr) | ||||
| { | ||||
|   free(ptr); | ||||
| } | ||||
|  | ||||
| #ifdef _MSC_VER | ||||
| double log1p(const double x) | ||||
| { | ||||
|   volatile double y; | ||||
|   y = 1 + x; | ||||
|   return log(y) - ((y-1)-x)/y ;  /* cancels errors with IEEE arithmetic */ | ||||
| } | ||||
| #endif | ||||
							
								
								
									
										72
									
								
								lib/TH/THGeneral.h
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										72
									
								
								lib/TH/THGeneral.h
									
									
									
									
									
										Normal file
									
								
							| @ -0,0 +1,72 @@ | ||||
| #ifndef TH_GENERAL_INC | ||||
| #define TH_GENERAL_INC | ||||
|  | ||||
| #include <stdlib.h> | ||||
| #include <stdio.h> | ||||
| #include <stdarg.h> | ||||
| #include <math.h> | ||||
| #include <limits.h> | ||||
| #include <float.h> | ||||
| #include <time.h> | ||||
| #include <string.h> | ||||
|  | ||||
| #ifdef __cplusplus | ||||
| # define TH_EXTERNC extern "C" | ||||
| #else | ||||
| # define TH_EXTERNC extern | ||||
| #endif | ||||
|  | ||||
| #ifdef WIN32 | ||||
| # ifdef TH_EXPORTS | ||||
| #  define TH_API TH_EXTERNC __declspec(dllexport) | ||||
| # else | ||||
| #  define TH_API TH_EXTERNC __declspec(dllimport) | ||||
| # endif | ||||
| #else | ||||
| # define TH_API TH_EXTERNC | ||||
| #endif | ||||
|  | ||||
| #define THInf DBL_MAX | ||||
|  | ||||
| #if !defined(inline) | ||||
| # define inline | ||||
| #endif | ||||
|  | ||||
| #ifndef M_PI | ||||
| # define M_PI 3.14159265358979323846 | ||||
| #endif | ||||
|  | ||||
| #ifdef _MSC_VER | ||||
| TH_API double log1p(const double x); | ||||
| #endif | ||||
|  | ||||
| TH_API void THError(const char *fmt, ...); | ||||
| TH_API void THSetErrorHandler( void (*torchErrorHandlerFunction)(const char *msg) ); | ||||
| TH_API void THArgCheck(int condition, int argNumber, const char *msg); | ||||
| TH_API void THSetArgCheckHandler( void (*torchArgCheckHandlerFunction)(int condition, int argNumber, const char *msg) ); | ||||
| TH_API void* THAlloc(long size); | ||||
| TH_API void* THRealloc(void *ptr, long size); | ||||
| TH_API void THFree(void *ptr); | ||||
|  | ||||
| #define TH_CONCAT_STRING_2(x,y) TH_CONCAT_STRING_2_EXPAND(x,y) | ||||
| #define TH_CONCAT_STRING_2_EXPAND(x,y) #x #y | ||||
|  | ||||
| #define TH_CONCAT_STRING_3(x,y,z) TH_CONCAT_STRING_3_EXPAND(x,y,z) | ||||
| #define TH_CONCAT_STRING_3_EXPAND(x,y,z) #x #y #z | ||||
|  | ||||
| #define TH_CONCAT_STRING_4(x,y,z,w) TH_CONCAT_STRING_4_EXPAND(x,y,z,w) | ||||
| #define TH_CONCAT_STRING_4_EXPAND(x,y,z,w) #x #y #z #w | ||||
|  | ||||
| #define TH_CONCAT_2(x,y) TH_CONCAT_2_EXPAND(x,y) | ||||
| #define TH_CONCAT_2_EXPAND(x,y) x ## y | ||||
|  | ||||
| #define TH_CONCAT_3(x,y,z) TH_CONCAT_3_EXPAND(x,y,z) | ||||
| #define TH_CONCAT_3_EXPAND(x,y,z) x ## y ## z | ||||
|  | ||||
| #define TH_CONCAT_4_EXPAND(x,y,z,w) x ## y ## z ## w | ||||
| #define TH_CONCAT_4(x,y,z,w) TH_CONCAT_4_EXPAND(x,y,z,w) | ||||
|  | ||||
| #define THMin(X, Y)  ((X) < (Y) ? (X) : (Y)) | ||||
| #define THMax(X, Y)  ((X) > (Y) ? (X) : (Y)) | ||||
|  | ||||
| #endif | ||||
							
								
								
									
										83
									
								
								lib/TH/THGenerateAllTypes.h
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										83
									
								
								lib/TH/THGenerateAllTypes.h
									
									
									
									
									
										Normal file
									
								
							| @ -0,0 +1,83 @@ | ||||
| #ifndef TH_GENERIC_FILE | ||||
| #error "You must define TH_GENERIC_FILE before including THGenerateAllTypes.h" | ||||
| #endif | ||||
|  | ||||
| #define real unsigned char | ||||
| #define accreal long | ||||
| #define Real Byte | ||||
| #define TH_REAL_IS_BYTE | ||||
| #line 1 TH_GENERIC_FILE | ||||
| /*#line 1 "THByteStorage.h"*/ | ||||
| #include TH_GENERIC_FILE | ||||
| #undef real | ||||
| #undef accreal | ||||
| #undef Real | ||||
| #undef TH_REAL_IS_BYTE | ||||
|  | ||||
| #define real char | ||||
| #define accreal long | ||||
| #define Real Char | ||||
| #define TH_REAL_IS_CHAR | ||||
| #line 1 TH_GENERIC_FILE | ||||
| #include TH_GENERIC_FILE | ||||
| #undef real | ||||
| #undef accreal | ||||
| #undef Real | ||||
| #undef TH_REAL_IS_CHAR | ||||
|  | ||||
| #define real short | ||||
| #define accreal long | ||||
| #define Real Short | ||||
| #define TH_REAL_IS_SHORT | ||||
| #line 1 TH_GENERIC_FILE | ||||
| #include TH_GENERIC_FILE | ||||
| #undef real | ||||
| #undef accreal | ||||
| #undef Real | ||||
| #undef TH_REAL_IS_SHORT | ||||
|  | ||||
| #define real int | ||||
| #define accreal long | ||||
| #define Real Int | ||||
| #define TH_REAL_IS_INT | ||||
| #line 1 TH_GENERIC_FILE | ||||
| #include TH_GENERIC_FILE | ||||
| #undef real | ||||
| #undef accreal | ||||
| #undef Real | ||||
| #undef TH_REAL_IS_INT | ||||
|  | ||||
| #define real long | ||||
| #define accreal long | ||||
| #define Real Long | ||||
| #define TH_REAL_IS_LONG | ||||
| #line 1 TH_GENERIC_FILE | ||||
| #include TH_GENERIC_FILE | ||||
| #undef real | ||||
| #undef accreal | ||||
| #undef Real | ||||
| #undef TH_REAL_IS_LONG | ||||
|  | ||||
| #define real float | ||||
| #define accreal double | ||||
| #define Real Float | ||||
| #define TH_REAL_IS_FLOAT | ||||
| #line 1 TH_GENERIC_FILE | ||||
| #include TH_GENERIC_FILE | ||||
| #undef real | ||||
| #undef accreal | ||||
| #undef Real | ||||
| #undef TH_REAL_IS_FLOAT | ||||
|  | ||||
| #define real double | ||||
| #define accreal double | ||||
| #define Real Double | ||||
| #define TH_REAL_IS_DOUBLE | ||||
| #line 1 TH_GENERIC_FILE | ||||
| #include TH_GENERIC_FILE | ||||
| #undef real | ||||
| #undef accreal | ||||
| #undef Real | ||||
| #undef TH_REAL_IS_DOUBLE | ||||
|  | ||||
| #undef TH_GENERIC_FILE | ||||
							
								
								
									
										27
									
								
								lib/TH/THGenerateFloatTypes.h
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										27
									
								
								lib/TH/THGenerateFloatTypes.h
									
									
									
									
									
										Normal file
									
								
							| @ -0,0 +1,27 @@ | ||||
| #ifndef TH_GENERIC_FILE | ||||
| #error "You must define TH_GENERIC_FILE before including THGenerateAllTypes.h" | ||||
| #endif | ||||
|  | ||||
| #define real float | ||||
| #define accreal double | ||||
| #define Real Float | ||||
| #define TH_REAL_IS_FLOAT | ||||
| #line 1 TH_GENERIC_FILE | ||||
| #include TH_GENERIC_FILE | ||||
| #undef accreal | ||||
| #undef real | ||||
| #undef Real | ||||
| #undef TH_REAL_IS_FLOAT | ||||
|  | ||||
| #define real double | ||||
| #define accreal double | ||||
| #define Real Double | ||||
| #define TH_REAL_IS_DOUBLE | ||||
| #line 1 TH_GENERIC_FILE | ||||
| #include TH_GENERIC_FILE | ||||
| #undef accreal | ||||
| #undef real | ||||
| #undef Real | ||||
| #undef TH_REAL_IS_DOUBLE | ||||
|  | ||||
| #undef TH_GENERIC_FILE | ||||
							
								
								
									
										60
									
								
								lib/TH/THGenerateIntTypes.h
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										60
									
								
								lib/TH/THGenerateIntTypes.h
									
									
									
									
									
										Normal file
									
								
							| @ -0,0 +1,60 @@ | ||||
| #ifndef TH_GENERIC_FILE | ||||
| #error "You must define TH_GENERIC_FILE before including THGenerateIntTypes.h" | ||||
| #endif | ||||
|  | ||||
| #define real unsigned char | ||||
| #define accreal long | ||||
| #define Real Byte | ||||
| #define TH_REAL_IS_BYTE | ||||
| #line 1 TH_GENERIC_FILE | ||||
| #include TH_GENERIC_FILE | ||||
| #undef real | ||||
| #undef accreal | ||||
| #undef Real | ||||
| #undef TH_REAL_IS_BYTE | ||||
|  | ||||
| #define real char | ||||
| #define accreal long | ||||
| #define Real Char | ||||
| #define TH_REAL_IS_CHAR | ||||
| #line 1 TH_GENERIC_FILE | ||||
| #include TH_GENERIC_FILE | ||||
| #undef real | ||||
| #undef accreal | ||||
| #undef Real | ||||
| #undef TH_REAL_IS_CHAR | ||||
|  | ||||
| #define real short | ||||
| #define accreal long | ||||
| #define Real Short | ||||
| #define TH_REAL_IS_SHORT | ||||
| #line 1 TH_GENERIC_FILE | ||||
| #include TH_GENERIC_FILE | ||||
| #undef real | ||||
| #undef accreal | ||||
| #undef Real | ||||
| #undef TH_REAL_IS_SHORT | ||||
|  | ||||
| #define real int | ||||
| #define accreal long | ||||
| #define Real Int | ||||
| #define TH_REAL_IS_INT | ||||
| #line 1 TH_GENERIC_FILE | ||||
| #include TH_GENERIC_FILE | ||||
| #undef real | ||||
| #undef accreal | ||||
| #undef Real | ||||
| #undef TH_REAL_IS_INT | ||||
|  | ||||
| #define real long | ||||
| #define accreal long | ||||
| #define Real Long | ||||
| #define TH_REAL_IS_LONG | ||||
| #line 1 TH_GENERIC_FILE | ||||
| #include TH_GENERIC_FILE | ||||
| #undef real | ||||
| #undef accreal | ||||
| #undef Real | ||||
| #undef TH_REAL_IS_LONG | ||||
|  | ||||
| #undef TH_GENERIC_FILE | ||||
							
								
								
									
										5
									
								
								lib/TH/THLapack.c
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										5
									
								
								lib/TH/THLapack.c
									
									
									
									
									
										Normal file
									
								
							| @ -0,0 +1,5 @@ | ||||
| #include "THLapack.h" | ||||
|  | ||||
| /* #include "THCBlas.h" */ | ||||
| #include "generic/THLapack.c" | ||||
| #include "THGenerateFloatTypes.h" | ||||
							
								
								
									
										11
									
								
								lib/TH/THLapack.h
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										11
									
								
								lib/TH/THLapack.h
									
									
									
									
									
										Normal file
									
								
							| @ -0,0 +1,11 @@ | ||||
| #ifndef TH_LAPACK_INC | ||||
| #define TH_LAPACK_INC | ||||
|  | ||||
| #include "THGeneral.h" | ||||
|  | ||||
| #define THLapack_(NAME) TH_CONCAT_4(TH,Real,Lapack_,NAME) | ||||
|  | ||||
| #include "generic/THLapack.h" | ||||
| #include "THGenerateAllTypes.h" | ||||
|  | ||||
| #endif | ||||
							
								
								
									
										86
									
								
								lib/TH/THLogAdd.c
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										86
									
								
								lib/TH/THLogAdd.c
									
									
									
									
									
										Normal file
									
								
							| @ -0,0 +1,86 @@ | ||||
| #include "THLogAdd.h" | ||||
|  | ||||
| #ifdef USE_DOUBLE | ||||
| #define MINUS_LOG_THRESHOLD -39.14 | ||||
| #else | ||||
| #define MINUS_LOG_THRESHOLD -18.42 | ||||
| #endif | ||||
|  | ||||
| const double THLog2Pi=1.83787706640934548355; | ||||
| const double THLogZero=-THInf; | ||||
| const double THLogOne=0; | ||||
|  | ||||
| double THLogAdd(double log_a, double log_b) | ||||
| { | ||||
|   double minusdif; | ||||
|  | ||||
|   if (log_a < log_b) | ||||
|   { | ||||
|     double tmp = log_a; | ||||
|     log_a = log_b; | ||||
|     log_b = tmp; | ||||
|   } | ||||
|  | ||||
|   minusdif = log_b - log_a; | ||||
| #ifdef DEBUG | ||||
|   if (isnan(minusdif)) | ||||
|     THError("THLogAdd: minusdif (%f) log_b (%f) or log_a (%f) is nan", minusdif, log_b, log_a); | ||||
| #endif | ||||
|   if (minusdif < MINUS_LOG_THRESHOLD) | ||||
|     return log_a; | ||||
|   else | ||||
|     return log_a + log1p(exp(minusdif)); | ||||
| } | ||||
|  | ||||
| double THLogSub(double log_a, double log_b) | ||||
| { | ||||
|   double minusdif; | ||||
|  | ||||
|   if (log_a < log_b) | ||||
|     THError("LogSub: log_a (%f) should be greater than log_b (%f)", log_a, log_b); | ||||
|  | ||||
|   minusdif = log_b - log_a; | ||||
| #ifdef DEBUG | ||||
|   if (isnan(minusdif)) | ||||
|     THError("LogSub: minusdif (%f) log_b (%f) or log_a (%f) is nan", minusdif, log_b, log_a); | ||||
| #endif | ||||
|   if (log_a == log_b) | ||||
|     return THLogZero; | ||||
|   else if (minusdif < MINUS_LOG_THRESHOLD) | ||||
|     return log_a; | ||||
|   else | ||||
|     return log_a + log1p(-exp(minusdif)); | ||||
| } | ||||
|  | ||||
| /* Credits to Leon Bottou */ | ||||
| double THExpMinusApprox(double x) | ||||
| { | ||||
| #define EXACT_EXPONENTIAL 0 | ||||
| #if EXACT_EXPONENTIAL | ||||
|   return exp(-x); | ||||
| #else | ||||
|   /* fast approximation of exp(-x) for x positive */ | ||||
| # define A0   (1.0) | ||||
| # define A1   (0.125) | ||||
| # define A2   (0.0078125) | ||||
| # define A3   (0.00032552083) | ||||
| # define A4   (1.0172526e-5) | ||||
|   if (x < 13.0) | ||||
|   { | ||||
| /*    assert(x>=0); */ | ||||
|     double y; | ||||
|     y = A0+x*(A1+x*(A2+x*(A3+x*A4))); | ||||
|     y *= y; | ||||
|     y *= y; | ||||
|     y *= y; | ||||
|     y = 1/y; | ||||
|     return y; | ||||
|   } | ||||
|   return 0; | ||||
| # undef A0 | ||||
| # undef A1 | ||||
| # undef A2 | ||||
| # undef A3 | ||||
| # undef A4 | ||||
| #endif | ||||
| } | ||||
							
								
								
									
										14
									
								
								lib/TH/THLogAdd.h
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										14
									
								
								lib/TH/THLogAdd.h
									
									
									
									
									
										Normal file
									
								
							| @ -0,0 +1,14 @@ | ||||
| #ifndef TH_LOG_ADD_INC | ||||
| #define TH_LOG_ADD_INC | ||||
|  | ||||
| #include "THGeneral.h" | ||||
|  | ||||
| TH_API const double THLog2Pi; | ||||
| TH_API const double THLogZero; | ||||
| TH_API const double THLogOne; | ||||
|  | ||||
| TH_API double THLogAdd(double log_a, double log_b); | ||||
| TH_API double THLogSub(double log_a, double log_b); | ||||
| TH_API double THExpMinusApprox(const double x); | ||||
|  | ||||
| #endif | ||||
							
								
								
									
										492
									
								
								lib/TH/THMemoryFile.c
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										492
									
								
								lib/TH/THMemoryFile.c
									
									
									
									
									
										Normal file
									
								
							| @ -0,0 +1,492 @@ | ||||
| #include "THMemoryFile.h" | ||||
| #include "THFilePrivate.h" | ||||
|  | ||||
| typedef struct THMemoryFile__ | ||||
| { | ||||
|     THFile file; | ||||
|     THCharStorage *storage; | ||||
|     long size; | ||||
|     long position; | ||||
|  | ||||
| } THMemoryFile; | ||||
|  | ||||
| static int THMemoryFile_isOpened(THFile *self) | ||||
| { | ||||
|   THMemoryFile *mfself = (THMemoryFile*)self; | ||||
|   return (mfself->storage != NULL); | ||||
| } | ||||
|  | ||||
| static char *THMemoryFile_strnextspace(char *str_, char *c_) | ||||
| { | ||||
|   char c; | ||||
|  | ||||
|   while( (c = *str_) ) | ||||
|   { | ||||
|     if( (c != ' ') && (c != '\n') && (c != ':') && (c != ';') ) | ||||
|       break; | ||||
|     str_++; | ||||
|   } | ||||
|  | ||||
|   while( (c = *str_) ) | ||||
|   { | ||||
|     if( (c == ' ') || (c == '\n') || (c == ':') || (c == ';') ) | ||||
|     { | ||||
|       *c_ = c; | ||||
|       *str_ = '\0'; | ||||
|       return(str_); | ||||
|     } | ||||
|     str_++; | ||||
|   } | ||||
|   return NULL; | ||||
| } | ||||
|  | ||||
| static void THMemoryFile_grow(THMemoryFile *self, long size) | ||||
| { | ||||
|   long missingSpace; | ||||
|  | ||||
|   if(size <= self->size) | ||||
|     return; | ||||
|   else | ||||
|   { | ||||
|     if(size < self->storage->size) /* note the "<" and not "<=" */ | ||||
|     { | ||||
|       self->size = size; | ||||
|       self->storage->data[self->size] = '\0'; | ||||
|       return; | ||||
|     } | ||||
|   } | ||||
|  | ||||
|   missingSpace = size-self->storage->size+1; /* +1 for the '\0' */ | ||||
|   THCharStorage_resize(self->storage, (self->storage->size/2 > missingSpace ? | ||||
|                                        self->storage->size + (self->storage->size/2) | ||||
|                                        : self->storage->size + missingSpace)); | ||||
| } | ||||
|  | ||||
| static int THMemoryFile_mode(const char *mode, int *isReadable, int *isWritable) | ||||
| { | ||||
|   *isReadable = 0; | ||||
|   *isWritable = 0; | ||||
|   if(strlen(mode) == 1) | ||||
|   { | ||||
|     if(*mode == 'r') | ||||
|     { | ||||
|       *isReadable = 1; | ||||
|       return 1; | ||||
|     } | ||||
|     else if(*mode == 'w') | ||||
|     { | ||||
|       *isWritable = 1; | ||||
|       return 1; | ||||
|     } | ||||
|   } | ||||
|   else if(strlen(mode) == 2) | ||||
|   { | ||||
|     if(mode[0] == 'r' && mode[1] == 'w') | ||||
|     { | ||||
|       *isReadable = 1; | ||||
|       *isWritable = 1; | ||||
|       return 1; | ||||
|     } | ||||
|   } | ||||
|   return 0; | ||||
| } | ||||
|  | ||||
| /********************************************************/ | ||||
|  | ||||
| #define READ_WRITE_METHODS(TYPE, TYPEC, ASCII_READ_ELEM, ASCII_WRITE_ELEM, INSIDE_SPACING) \ | ||||
|   static long THMemoryFile_read##TYPEC(THFile *self, TYPE *data, long n) \ | ||||
|   {                                                                     \ | ||||
|     THMemoryFile *mfself = (THMemoryFile*)self;                         \ | ||||
|     long nread = 0L;                                                    \ | ||||
|                                                                         \ | ||||
|     THArgCheck(mfself->storage != NULL, 1, "attempt to use a closed file");     \ | ||||
|     THArgCheck(mfself->file.isReadable, 1, "attempt to read in a write-only file"); \ | ||||
|                                                                         \ | ||||
|     if(mfself->file.isBinary)                                           \ | ||||
|     {                                                                   \ | ||||
|       long nByte = sizeof(TYPE)*n;                                      \ | ||||
|       long nByteRemaining = (mfself->position + nByte <= mfself->size ? nByte : mfself->size-mfself->position); \ | ||||
|       nread = nByteRemaining/sizeof(TYPE);                              \ | ||||
|       memmove(data, mfself->storage->data+mfself->position, nread*sizeof(TYPE)); \ | ||||
|       mfself->position += nread*sizeof(TYPE);                           \ | ||||
|     }                                                                   \ | ||||
|     else                                                                \ | ||||
|     {                                                                   \ | ||||
|       long i;                                                           \ | ||||
|       for(i = 0; i < n; i++)                                            \ | ||||
|       {                                                                 \ | ||||
|         long nByteRead = 0;                                             \ | ||||
|         char spaceChar = 0;                                             \ | ||||
|         char *spacePtr = THMemoryFile_strnextspace(mfself->storage->data+mfself->position, &spaceChar); \ | ||||
|         ASCII_READ_ELEM;                                                \ | ||||
|         if(ret == EOF)                                                  \ | ||||
|         {                                                               \ | ||||
|           while(mfself->storage->data[mfself->position])                \ | ||||
|             mfself->position++;                                         \ | ||||
|         }                                                               \ | ||||
|         else                                                            \ | ||||
|           mfself->position += nByteRead;                                \ | ||||
|         if(spacePtr)                                                    \ | ||||
|           *spacePtr = spaceChar;                                        \ | ||||
|       }                                                                 \ | ||||
|       if(mfself->file.isAutoSpacing && (n > 0))                         \ | ||||
|       {                                                                 \ | ||||
|         if( (mfself->position < mfself->size) && (mfself->storage->data[mfself->position] == '\n') ) \ | ||||
|           mfself->position++;                                           \ | ||||
|       }                                                                 \ | ||||
|     }                                                                   \ | ||||
|                                                                         \ | ||||
|     if(nread != n)                                                      \ | ||||
|     {                                                                   \ | ||||
|       mfself->file.hasError = 1; /* shouldn't we put hasError to 0 all the time ? */ \ | ||||
|       if(!mfself->file.isQuiet)                                         \ | ||||
|         THError("read error: read %d blocks instead of %d", nread, n);  \ | ||||
|     }                                                                   \ | ||||
|                                                                         \ | ||||
|     return nread;                                                       \ | ||||
|   }                                                                     \ | ||||
|                                                                         \ | ||||
|   static long THMemoryFile_write##TYPEC(THFile *self, TYPE *data, long n) \ | ||||
|   {                                                                     \ | ||||
|     THMemoryFile *mfself = (THMemoryFile*)self;                         \ | ||||
|     long nread = 0L;                                                    \ | ||||
|                                                                         \ | ||||
|     THArgCheck(mfself->storage != NULL, 1, "attempt to use a closed file");     \ | ||||
|     THArgCheck(mfself->file.isWritable, 1, "attempt to write in a read-only file"); \ | ||||
|                                                                         \ | ||||
|     if(mfself->file.isBinary)                                           \ | ||||
|     {                                                                   \ | ||||
|       long nByte = sizeof(TYPE)*n;                                      \ | ||||
|       THMemoryFile_grow(mfself, mfself->position+nByte);                \ | ||||
|       memmove(mfself->storage->data+mfself->position, data, nByte);     \ | ||||
|       mfself->position += nByte;                                        \ | ||||
|       if(mfself->position > mfself->size)                               \ | ||||
|       {                                                                 \ | ||||
|         mfself->size = mfself->position;                                \ | ||||
|         mfself->storage->data[mfself->size] = '\0';                     \ | ||||
|       }                                                                 \ | ||||
|     }                                                                   \ | ||||
|     else                                                                \ | ||||
|     {                                                                   \ | ||||
|       long i;                                                           \ | ||||
|       for(i = 0; i < n; i++)                                            \ | ||||
|       {                                                                 \ | ||||
|         long nByteWritten;                                              \ | ||||
|         while (1)                                                       \ | ||||
|         {                                                               \ | ||||
|           ASCII_WRITE_ELEM;                                             \ | ||||
|           if( (nByteWritten > -1) && (nByteWritten < mfself->storage->size-mfself->position) ) \ | ||||
|           {                                                             \ | ||||
|             mfself->position += nByteWritten;                           \ | ||||
|             break;                                                      \ | ||||
|           }                                                             \ | ||||
|           THMemoryFile_grow(mfself, mfself->storage->size + (mfself->storage->size/2) + 2); \ | ||||
|         }                                                               \ | ||||
|         if(mfself->file.isAutoSpacing)                                  \ | ||||
|         {                                                               \ | ||||
|           if(i < n-1)                                                   \ | ||||
|           {                                                             \ | ||||
|             THMemoryFile_grow(mfself, mfself->position+1);              \ | ||||
|             sprintf(mfself->storage->data+mfself->position, " ");       \ | ||||
|             mfself->position++;                                         \ | ||||
|           }                                                             \ | ||||
|           if(i == n-1)                                                  \ | ||||
|           {                                                             \ | ||||
|             THMemoryFile_grow(mfself, mfself->position+1);              \ | ||||
|             sprintf(mfself->storage->data+mfself->position, "\n");      \ | ||||
|             mfself->position++;                                         \ | ||||
|           }                                                             \ | ||||
|         }                                                               \ | ||||
|       }                                                                 \ | ||||
|       if(mfself->position > mfself->size)                               \ | ||||
|       {                                                                 \ | ||||
|         mfself->size = mfself->position;                                \ | ||||
|         mfself->storage->data[mfself->size] = '\0';                     \ | ||||
|       }                                                                 \ | ||||
|     }                                                                   \ | ||||
|                                                                         \ | ||||
|     return n;                                                           \ | ||||
|   } | ||||
|  | ||||
|  | ||||
| THCharStorage *THMemoryFile_storage(THFile *self) | ||||
| { | ||||
|   THMemoryFile *mfself = (THMemoryFile*)self; | ||||
|   THArgCheck(mfself->storage != NULL, 1, "attempt to use a closed file"); | ||||
|  | ||||
|   THCharStorage_resize(mfself->storage, mfself->size+1); | ||||
|  | ||||
|   return mfself->storage; | ||||
| } | ||||
|  | ||||
| static void THMemoryFile_synchronize(THFile *self) | ||||
| { | ||||
|   THMemoryFile *mfself = (THMemoryFile*)self; | ||||
|   THArgCheck(mfself->storage != NULL, 1, "attempt to use a closed file"); | ||||
| } | ||||
|  | ||||
| static void THMemoryFile_seek(THFile *self, long position) | ||||
| { | ||||
|   THMemoryFile *mfself = (THMemoryFile*)self; | ||||
|  | ||||
|   THArgCheck(mfself->storage != NULL, 1, "attempt to use a closed file"); | ||||
|   THArgCheck(position >= 0, 2, "position must be positive"); | ||||
|  | ||||
|   if(position <= mfself->size) | ||||
|     mfself->position = position; | ||||
|   else | ||||
|   { | ||||
|     mfself->file.hasError = 1; | ||||
|     if(!mfself->file.isQuiet) | ||||
|       THError("unable to seek at position %d", position); | ||||
|   } | ||||
| } | ||||
|  | ||||
| static void THMemoryFile_seekEnd(THFile *self) | ||||
| { | ||||
|   THMemoryFile *mfself = (THMemoryFile*)self; | ||||
|   THArgCheck(mfself->storage != NULL, 1, "attempt to use a closed file"); | ||||
|  | ||||
|   mfself->position = mfself->size; | ||||
| } | ||||
|  | ||||
| static long THMemoryFile_position(THFile *self) | ||||
| { | ||||
|   THMemoryFile *mfself = (THMemoryFile*)self; | ||||
|   THArgCheck(mfself->storage != NULL, 1, "attempt to use a closed file"); | ||||
|   return mfself->position; | ||||
| } | ||||
|  | ||||
| static void THMemoryFile_close(THFile *self) | ||||
| { | ||||
|   THMemoryFile *mfself = (THMemoryFile*)self; | ||||
|   THArgCheck(mfself->storage != NULL, 1, "attempt to use a closed file"); | ||||
|   THCharStorage_free(mfself->storage); | ||||
|   mfself->storage = NULL; | ||||
| } | ||||
|  | ||||
| static void THMemoryFile_free(THFile *self) | ||||
| { | ||||
|   THMemoryFile *mfself = (THMemoryFile*)self; | ||||
|  | ||||
|   if(mfself->storage) | ||||
|     THCharStorage_free(mfself->storage); | ||||
|  | ||||
|   THFree(mfself); | ||||
| } | ||||
|  | ||||
| /* READ_WRITE_METHODS(bool, Bool, */ | ||||
| /*                    int value = 0; int ret = sscanf(mfself->storage->data+mfself->position, "%d%n", &value, &nByteRead); data[i] = (value ? 1 : 0), */ | ||||
| /*                    int value = (data[i] ? 1 : 0); nByteWritten = snprintf(mfself->storage->data+mfself->position, mfself->storage->size-mfself->position, "%d", value), */ | ||||
| /*                    1) */ | ||||
|  | ||||
| READ_WRITE_METHODS(unsigned char, Byte, | ||||
|                    long ret = (mfself->position + n <= mfself->size ? n : mfself->size-mfself->position);  \ | ||||
|                    if(spacePtr) *spacePtr = spaceChar; \ | ||||
|                    nByteRead = ret; \ | ||||
|                    nread = ret; \ | ||||
|                    i = n-1; \ | ||||
|                    memmove(data, mfself->storage->data+mfself->position, nByteRead), | ||||
|                    nByteWritten = (n < mfself->storage->size-mfself->position ? n : -1); \ | ||||
|                    i = n-1; \ | ||||
|                    if(nByteWritten > -1) | ||||
|                      memmove(mfself->storage->data+mfself->position, data, nByteWritten), | ||||
|                    0) | ||||
|  | ||||
| /* DEBUG: we should check if %n is count or not as a element (so ret might need to be ret-- on some systems) */ | ||||
| /* Note that we do a trick for char */ | ||||
| READ_WRITE_METHODS(char, Char, | ||||
|                    long ret = (mfself->position + n <= mfself->size ? n : mfself->size-mfself->position);  \ | ||||
|                    if(spacePtr) *spacePtr = spaceChar; \ | ||||
|                    nByteRead = ret; \ | ||||
|                    nread = ret; \ | ||||
|                    i = n-1; \ | ||||
|                    memmove(data, mfself->storage->data+mfself->position, nByteRead), | ||||
|                    nByteWritten = (n < mfself->storage->size-mfself->position ? n : -1); \ | ||||
|                    i = n-1; \ | ||||
|                    if(nByteWritten > -1) | ||||
|                      memmove(mfself->storage->data+mfself->position, data, nByteWritten), | ||||
|                    0) | ||||
|  | ||||
| READ_WRITE_METHODS(short, Short, | ||||
|                    int nByteRead_; int ret = sscanf(mfself->storage->data+mfself->position, "%hd%n", &data[i], &nByteRead_); nByteRead = nByteRead_; if(ret <= 0) break; else nread++, | ||||
|                    nByteWritten = snprintf(mfself->storage->data+mfself->position, mfself->storage->size-mfself->position, "%hd", data[i]), | ||||
|                    1) | ||||
|  | ||||
| READ_WRITE_METHODS(int, Int, | ||||
|                    int nByteRead_; int ret = sscanf(mfself->storage->data+mfself->position, "%d%n", &data[i], &nByteRead_); nByteRead = nByteRead_; if(ret <= 0) break; else nread++, | ||||
|                    nByteWritten = snprintf(mfself->storage->data+mfself->position, mfself->storage->size-mfself->position, "%d", data[i]), | ||||
|                    1) | ||||
|  | ||||
| READ_WRITE_METHODS(long, Long, | ||||
|                    int nByteRead_; int ret = sscanf(mfself->storage->data+mfself->position, "%ld%n", &data[i], &nByteRead_); nByteRead = nByteRead_; if(ret <= 0) break; else nread++, | ||||
|                    nByteWritten = snprintf(mfself->storage->data+mfself->position, mfself->storage->size-mfself->position, "%ld", data[i]), | ||||
|                    1) | ||||
|  | ||||
| READ_WRITE_METHODS(float, Float, | ||||
|                    int nByteRead_; int ret = sscanf(mfself->storage->data+mfself->position, "%g%n", &data[i], &nByteRead_); nByteRead = nByteRead_; if(ret <= 0) break; else nread++, | ||||
|                    nByteWritten = snprintf(mfself->storage->data+mfself->position, mfself->storage->size-mfself->position, "%g", data[i]), | ||||
|                    1) | ||||
|  | ||||
| READ_WRITE_METHODS(double, Double, | ||||
|                    int nByteRead_; int ret = sscanf(mfself->storage->data+mfself->position, "%lg%n", &data[i], &nByteRead_); nByteRead = nByteRead_; if(ret <= 0) break; else nread++, | ||||
|                    nByteWritten = snprintf(mfself->storage->data+mfself->position, mfself->storage->size-mfself->position, "%lg", data[i]), | ||||
|                    1) | ||||
|  | ||||
| static char* THMemoryFile_cloneString(const char *str, long size) | ||||
| { | ||||
|   char *cstr = THAlloc(size); | ||||
|   memcpy(cstr, str, size); | ||||
|   return cstr; | ||||
| } | ||||
|  | ||||
| static long THMemoryFile_readString(THFile *self, const char *format, char **str_) | ||||
| { | ||||
|   THMemoryFile *mfself = (THMemoryFile*)self; | ||||
|  | ||||
|   THArgCheck(mfself->storage != NULL, 1, "attempt to use a closed file"); | ||||
|   THArgCheck(mfself->file.isReadable, 1, "attempt to read in a write-only file"); | ||||
|   THArgCheck((strlen(format) >= 2 ? (format[0] == '*') && (format[1] == 'a' || format[1] == 'l') : 0), 2, "format must be '*a' or '*l'"); | ||||
|  | ||||
|   if(mfself->position == mfself->size) /* eof ? */ | ||||
|   { | ||||
|     mfself->file.hasError = 1; | ||||
|     if(!mfself->file.isQuiet) | ||||
|       THError("read error: read 0 blocks instead of 1"); | ||||
|  | ||||
|     *str_ = NULL; | ||||
|     return 0; | ||||
|   } | ||||
|    | ||||
|   if(format[1] == 'a') | ||||
|   { | ||||
|     long str_size = mfself->size-mfself->position; | ||||
|  | ||||
|     *str_ = THMemoryFile_cloneString(mfself->storage->data+mfself->position, str_size); | ||||
|     mfself->position = mfself->size; | ||||
|  | ||||
|     return str_size; | ||||
|   } | ||||
|   else | ||||
|   { | ||||
|     char *p = mfself->storage->data+mfself->position; | ||||
|     long posEol = -1; | ||||
|     long i; | ||||
|     for(i = 0L; i < mfself->size-mfself->position; i++) | ||||
|     { | ||||
|       if(p[i] == '\n') | ||||
|       { | ||||
|         posEol = i; | ||||
|         break; | ||||
|       } | ||||
|     } | ||||
|  | ||||
|     if(posEol >= 0) | ||||
|     { | ||||
|       *str_ = THMemoryFile_cloneString(mfself->storage->data+mfself->position, posEol); | ||||
|       mfself->position += posEol+1; | ||||
|       return posEol; | ||||
|     } | ||||
|     else /* well, we read all! */ | ||||
|     { | ||||
|       long str_size = mfself->size-mfself->position; | ||||
|  | ||||
|       *str_ = THMemoryFile_cloneString(mfself->storage->data+mfself->position, str_size); | ||||
|       mfself->position = mfself->size; | ||||
|  | ||||
|       return str_size; | ||||
|     } | ||||
|   } | ||||
|  | ||||
|   *str_ = NULL; | ||||
|   return 0; | ||||
| } | ||||
|  | ||||
| static long THMemoryFile_writeString(THFile *self, const char *str, long size) | ||||
| { | ||||
|   THMemoryFile *mfself = (THMemoryFile*)self; | ||||
|  | ||||
|   THArgCheck(mfself->storage != NULL, 1, "attempt to use a closed file"); | ||||
|   THArgCheck(mfself->file.isWritable, 1, "attempt to write in a read-only file"); | ||||
|  | ||||
|   THMemoryFile_grow(mfself, mfself->position+size); | ||||
|   memmove(mfself->storage->data+mfself->position, str, size); | ||||
|   mfself->position += size; | ||||
|   if(mfself->position > mfself->size) | ||||
|   { | ||||
|     mfself->size = mfself->position; | ||||
|     mfself->storage->data[mfself->size] = '\0'; | ||||
|   } | ||||
|  | ||||
|   return size; | ||||
| } | ||||
|  | ||||
| THFile *THMemoryFile_newWithStorage(THCharStorage *storage, const char *mode) | ||||
| { | ||||
|   static struct THFileVTable vtable = { | ||||
|     THMemoryFile_isOpened, | ||||
|  | ||||
|     THMemoryFile_readByte, | ||||
|     THMemoryFile_readChar, | ||||
|     THMemoryFile_readShort, | ||||
|     THMemoryFile_readInt, | ||||
|     THMemoryFile_readLong, | ||||
|     THMemoryFile_readFloat, | ||||
|     THMemoryFile_readDouble, | ||||
|     THMemoryFile_readString, | ||||
|  | ||||
|     THMemoryFile_writeByte, | ||||
|     THMemoryFile_writeChar, | ||||
|     THMemoryFile_writeShort, | ||||
|     THMemoryFile_writeInt, | ||||
|     THMemoryFile_writeLong, | ||||
|     THMemoryFile_writeFloat, | ||||
|     THMemoryFile_writeDouble, | ||||
|     THMemoryFile_writeString, | ||||
|  | ||||
|     THMemoryFile_synchronize, | ||||
|     THMemoryFile_seek, | ||||
|     THMemoryFile_seekEnd, | ||||
|     THMemoryFile_position, | ||||
|     THMemoryFile_close, | ||||
|     THMemoryFile_free | ||||
|   }; | ||||
|  | ||||
|   THMemoryFile *mfself; | ||||
|   int isReadable; | ||||
|   int isWritable; | ||||
|  | ||||
|   if(storage) | ||||
|   { | ||||
|     THArgCheck(storage->data[storage->size-1] == '\0', 1, "provided CharStorage must be terminated by 0"); | ||||
|     THArgCheck(THMemoryFile_mode(mode, &isReadable, &isWritable), 2, "file mode should be 'r','w' or 'rw'"); | ||||
|     THCharStorage_retain(storage); | ||||
|   } | ||||
|   else | ||||
|   { | ||||
|     THArgCheck(THMemoryFile_mode(mode, &isReadable, &isWritable), 2, "file mode should be 'r','w' or 'rw'"); | ||||
|     storage = THCharStorage_newWithSize(1); | ||||
|     storage->data[0] = '\0'; | ||||
|   } | ||||
|  | ||||
|   mfself = THAlloc(sizeof(THMemoryFile)); | ||||
|  | ||||
|   mfself->storage = storage; | ||||
|   mfself->size = (storage ? storage->size-1 : 0); | ||||
|   mfself->position = 0; | ||||
|  | ||||
|   mfself->file.vtable = &vtable; | ||||
|   mfself->file.isQuiet = 0; | ||||
|   mfself->file.isReadable = isReadable; | ||||
|   mfself->file.isWritable = isWritable; | ||||
|   mfself->file.isBinary = 0; | ||||
|   mfself->file.isAutoSpacing = 1; | ||||
|   mfself->file.hasError = 0; | ||||
|  | ||||
|   return (THFile*)mfself; | ||||
| } | ||||
|  | ||||
| THFile *THMemoryFile_new(const char *mode) | ||||
| { | ||||
|   return THMemoryFile_newWithStorage(NULL, mode); | ||||
| } | ||||
							
								
								
									
										12
									
								
								lib/TH/THMemoryFile.h
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										12
									
								
								lib/TH/THMemoryFile.h
									
									
									
									
									
										Normal file
									
								
							| @ -0,0 +1,12 @@ | ||||
| #ifndef TH_MEMORY_FILE_INC | ||||
| #define TH_MEMORY_FILE_INC | ||||
|  | ||||
| #include "THFile.h" | ||||
| #include "THStorage.h" | ||||
|  | ||||
| THFile *THMemoryFile_newWithStorage(THCharStorage *storage, const char *mode); | ||||
| THFile *THMemoryFile_new(const char *mode); | ||||
|  | ||||
| THCharStorage *THMemoryFile_storage(THFile *self); | ||||
|  | ||||
| #endif | ||||
							
								
								
									
										238
									
								
								lib/TH/THRandom.c
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										238
									
								
								lib/TH/THRandom.c
									
									
									
									
									
										Normal file
									
								
							| @ -0,0 +1,238 @@ | ||||
| #include "THGeneral.h" | ||||
| #include "THRandom.h" | ||||
|  | ||||
| /* The initial seed. */ | ||||
| static unsigned long the_initial_seed; | ||||
|  | ||||
| /* Code for the Mersenne Twister random generator.... */ | ||||
| #define n 624 | ||||
| #define m 397 | ||||
| static int left = 1; | ||||
| static int initf = 0; | ||||
| static unsigned long *next; | ||||
| static unsigned long state[n]; /* the array for the state vector  */ | ||||
| /********************************/ | ||||
|  | ||||
| /* For normal distribution */ | ||||
| static double normal_x; | ||||
| static double normal_y; | ||||
| static double normal_rho; | ||||
| static int normal_is_valid = 0; | ||||
|  | ||||
| unsigned long THRandom_seed() | ||||
| { | ||||
|   unsigned long s = (unsigned long)time(0); | ||||
|   THRandom_manualSeed(s); | ||||
|   return s; | ||||
| } | ||||
|  | ||||
| /* The next 4 methods are taken from http:www.math.keio.ac.jpmatumotoemt.html | ||||
|    Here is the copyright: | ||||
|    Some minor modifications have been made to adapt to "my" C... */ | ||||
|  | ||||
| /* | ||||
|    A C-program for MT19937, with initialization improved 2002/2/10. | ||||
|    Coded by Takuji Nishimura and Makoto Matsumoto. | ||||
|    This is a faster version by taking Shawn Cokus's optimization, | ||||
|    Matthe Bellew's simplification, Isaku Wada's double version. | ||||
|  | ||||
|    Before using, initialize the state by using init_genrand(seed) | ||||
|    or init_by_array(init_key, key_length). | ||||
|  | ||||
|    Copyright (C) 1997 - 2002, Makoto Matsumoto and Takuji Nishimura, | ||||
|    All rights reserved. | ||||
|  | ||||
|    Redistribution and use in source and binary forms, with or without | ||||
|    modification, are permitted provided that the following conditions | ||||
|    are met: | ||||
|  | ||||
|      1. Redistributions of source code must retain the above copyright | ||||
|         notice, this list of conditions and the following disclaimer. | ||||
|  | ||||
|      2. Redistributions in binary form must reproduce the above copyright | ||||
|         notice, this list of conditions and the following disclaimer in the | ||||
|         documentation and/or other materials provided with the distribution. | ||||
|  | ||||
|      3. The names of its contributors may not be used to endorse or promote | ||||
|         products derived from this software without specific prior written | ||||
|         permission. | ||||
|  | ||||
|    THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS | ||||
|    "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT | ||||
|    LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR | ||||
|    A PARTICULAR PURPOSE ARE DISCLAIMED.  IN NO EVENT SHALL THE COPYRIGHT OWNER OR | ||||
|    CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, | ||||
|    EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, | ||||
|    PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR | ||||
|    PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF | ||||
|    LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING | ||||
|    NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS | ||||
|    SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. | ||||
|  | ||||
|  | ||||
|    Any feedback is very welcome. | ||||
|    http://www.math.keio.ac.jp/matumoto/emt.html | ||||
|    email: matumoto@math.keio.ac.jp | ||||
| */ | ||||
|  | ||||
| /* Macros for the Mersenne Twister random generator... */ | ||||
| /* Period parameters */   | ||||
| /* #define n 624 */ | ||||
| /* #define m 397 */ | ||||
| #define MATRIX_A 0x9908b0dfUL   /* constant vector a */ | ||||
| #define UMASK 0x80000000UL /* most significant w-r bits */ | ||||
| #define LMASK 0x7fffffffUL /* least significant r bits */ | ||||
| #define MIXBITS(u,v) ( ((u) & UMASK) | ((v) & LMASK) ) | ||||
| #define TWIST(u,v) ((MIXBITS(u,v) >> 1) ^ ((v)&1UL ? MATRIX_A : 0UL)) | ||||
| /*********************************************************** That's it. */ | ||||
|  | ||||
| void THRandom_manualSeed(unsigned long the_seed_) | ||||
| { | ||||
|   int j; | ||||
|   the_initial_seed = the_seed_; | ||||
|   state[0]= the_initial_seed & 0xffffffffUL; | ||||
|   for(j = 1; j < n; j++) | ||||
|   { | ||||
|     state[j] = (1812433253UL * (state[j-1] ^ (state[j-1] >> 30)) + j);  | ||||
|     /* See Knuth TAOCP Vol2. 3rd Ed. P.106 for multiplier. */ | ||||
|     /* In the previous versions, mSBs of the seed affect   */ | ||||
|     /* only mSBs of the array state[].                        */ | ||||
|     /* 2002/01/09 modified by makoto matsumoto             */ | ||||
|     state[j] &= 0xffffffffUL;  /* for >32 bit machines */ | ||||
|   } | ||||
|   left = 1; | ||||
|   initf = 1; | ||||
| } | ||||
|  | ||||
| unsigned long THRandom_initialSeed() | ||||
| { | ||||
|   if(initf == 0) | ||||
|   { | ||||
|     THRandom_seed(); | ||||
|   } | ||||
|  | ||||
|   return the_initial_seed; | ||||
| } | ||||
|  | ||||
| void THRandom_nextState() | ||||
| { | ||||
|   unsigned long *p=state; | ||||
|   int j; | ||||
|  | ||||
|   /* if init_genrand() has not been called, */ | ||||
|   /* a default initial seed is used         */ | ||||
|   if(initf == 0) | ||||
|     THRandom_seed(); | ||||
|  | ||||
|   left = n; | ||||
|   next = state; | ||||
|      | ||||
|   for(j = n-m+1; --j; p++)  | ||||
|     *p = p[m] ^ TWIST(p[0], p[1]); | ||||
|  | ||||
|   for(j = m; --j; p++)  | ||||
|     *p = p[m-n] ^ TWIST(p[0], p[1]); | ||||
|  | ||||
|   *p = p[m-n] ^ TWIST(p[0], state[0]); | ||||
| } | ||||
|  | ||||
| unsigned long THRandom_random() | ||||
| { | ||||
|   unsigned long y; | ||||
|  | ||||
|   if (--left == 0) | ||||
|     THRandom_nextState(); | ||||
|   y = *next++; | ||||
|    | ||||
|   /* Tempering */ | ||||
|   y ^= (y >> 11); | ||||
|   y ^= (y << 7) & 0x9d2c5680UL; | ||||
|   y ^= (y << 15) & 0xefc60000UL; | ||||
|   y ^= (y >> 18); | ||||
|  | ||||
|   return y; | ||||
| } | ||||
|  | ||||
| /* generates a random number on [0,1)-double-interval */ | ||||
| static double __uniform__() | ||||
| { | ||||
|   unsigned long y; | ||||
|  | ||||
|   if(--left == 0) | ||||
|     THRandom_nextState(); | ||||
|   y = *next++; | ||||
|  | ||||
|   /* Tempering */ | ||||
|   y ^= (y >> 11); | ||||
|   y ^= (y << 7) & 0x9d2c5680UL; | ||||
|   y ^= (y << 15) & 0xefc60000UL; | ||||
|   y ^= (y >> 18); | ||||
|    | ||||
|   return (double)y * (1.0/4294967296.0);  | ||||
|   /* divided by 2^32 */ | ||||
| } | ||||
|  | ||||
| /********************************************************* | ||||
|  | ||||
|  Thanks *a lot* Takuji Nishimura and Makoto Matsumoto! | ||||
|  | ||||
|  Now my own code... | ||||
|  | ||||
| *********************************************************/ | ||||
|  | ||||
| double THRandom_uniform(double a, double b) | ||||
| { | ||||
|   return(__uniform__() * (b - a) + a); | ||||
| } | ||||
|  | ||||
| double THRandom_normal(double mean, double stdv) | ||||
| { | ||||
|   THArgCheck(stdv > 0, 2, "standard deviation must be strictly positive"); | ||||
|  | ||||
|   if(!normal_is_valid) | ||||
|   { | ||||
|     normal_x = __uniform__(); | ||||
|     normal_y = __uniform__(); | ||||
|     normal_rho = sqrt(-2. * log(1.0-normal_y)); | ||||
|     normal_is_valid = 1; | ||||
|   } | ||||
|   else | ||||
|     normal_is_valid = 0; | ||||
|    | ||||
|   if(normal_is_valid) | ||||
|     return normal_rho*cos(2.*M_PI*normal_x)*stdv+mean; | ||||
|   else | ||||
|     return normal_rho*sin(2.*M_PI*normal_x)*stdv+mean; | ||||
| } | ||||
|  | ||||
| double THRandom_exponential(double lambda) | ||||
| { | ||||
|   return(-1. / lambda * log(1-__uniform__())); | ||||
| } | ||||
|  | ||||
| double THRandom_cauchy(double median, double sigma) | ||||
| { | ||||
|   return(median + sigma * tan(M_PI*(__uniform__()-0.5))); | ||||
| } | ||||
|  | ||||
| /* Faut etre malade pour utiliser ca. | ||||
|    M'enfin. */ | ||||
| double THRandom_logNormal(double mean, double stdv) | ||||
| { | ||||
|   double zm = mean*mean; | ||||
|   double zs = stdv*stdv; | ||||
|   THArgCheck(stdv > 0, 2, "standard deviation must be strictly positive"); | ||||
|   return(exp(THRandom_normal(log(zm/sqrt(zs + zm)), sqrt(log(zs/zm+1)) ))); | ||||
| } | ||||
|  | ||||
| int THRandom_geometric(double p) | ||||
| { | ||||
|   THArgCheck(p > 0 && p < 1, 1, "must be > 0 and < 1"); | ||||
|   return((int)(log(1-__uniform__()) / log(p)) + 1); | ||||
| } | ||||
|  | ||||
| int THRandom_bernoulli(double p) | ||||
| { | ||||
|   THArgCheck(p > 0 && p < 1, 1, "must be > 0 and < 1"); | ||||
|   return(__uniform__() <= p); | ||||
| } | ||||
							
								
								
									
										52
									
								
								lib/TH/THRandom.h
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										52
									
								
								lib/TH/THRandom.h
									
									
									
									
									
										Normal file
									
								
							| @ -0,0 +1,52 @@ | ||||
| #ifndef TH_RANDOM_INC | ||||
| #define TH_RANDOM_INC | ||||
|  | ||||
| #include "THGeneral.h" | ||||
|  | ||||
| /* Initializes the random number generator with the current time (granularity: seconds) and returns the seed. */ | ||||
| TH_API unsigned long THRandom_seed(); | ||||
|  | ||||
| /* Initializes the random number generator with the given long "the_seed_". */ | ||||
| TH_API void THRandom_manualSeed(unsigned long the_seed_); | ||||
|  | ||||
| /* Returns the starting seed used. */ | ||||
| TH_API unsigned long THRandom_initialSeed(); | ||||
|  | ||||
| /* Generates a uniform 32 bits integer. */ | ||||
| TH_API unsigned long THRandom_random(); | ||||
|  | ||||
| /* Generates a uniform random number on [0,1[. */ | ||||
| TH_API double THRandom_uniform(double a, double b); | ||||
|  | ||||
| /** Generates a random number from a normal distribution. | ||||
|     (With mean #mean# and standard deviation #stdv >= 0#). | ||||
| */ | ||||
| TH_API double THRandom_normal(double mean, double stdv); | ||||
|  | ||||
| /** Generates a random number from an exponential distribution. | ||||
|     The density is $p(x) = lambda * exp(-lambda * x)$, where | ||||
|     lambda is a positive number. | ||||
| */ | ||||
| TH_API double THRandom_exponential(double lambda); | ||||
|  | ||||
| /** Returns a random number from a Cauchy distribution. | ||||
|     The Cauchy density is $p(x) = sigma/(pi*(sigma^2 + (x-median)^2))$ | ||||
| */ | ||||
| TH_API double THRandom_cauchy(double median, double sigma); | ||||
|  | ||||
| /** Generates a random number from a log-normal distribution. | ||||
|     (#mean > 0# is the mean of the log-normal distribution | ||||
|     and #stdv# is its standard deviation). | ||||
| */ | ||||
| TH_API double THRandom_logNormal(double mean, double stdv); | ||||
|  | ||||
| /** Generates a random number from a geometric distribution. | ||||
|     It returns an integer #i#, where $p(i) = (1-p) * p^(i-1)$. | ||||
|     p must satisfy $0 < p < 1$. | ||||
| */ | ||||
| TH_API int THRandom_geometric(double p); | ||||
|  | ||||
| /* Returns true with probability $p$ and false with probability $1-p$ (p > 0). */ | ||||
| TH_API int THRandom_bernoulli(double p); | ||||
|  | ||||
| #endif | ||||
							
								
								
									
										7
									
								
								lib/TH/THStorage.c
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										7
									
								
								lib/TH/THStorage.c
									
									
									
									
									
										Normal file
									
								
							| @ -0,0 +1,7 @@ | ||||
| #include "THStorage.h" | ||||
|  | ||||
| #include "generic/THStorage.c" | ||||
| #include "THGenerateAllTypes.h" | ||||
|  | ||||
| #include "generic/THStorageCopy.c" | ||||
| #include "THGenerateAllTypes.h" | ||||
							
								
								
									
										33
									
								
								lib/TH/THStorage.h
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										33
									
								
								lib/TH/THStorage.h
									
									
									
									
									
										Normal file
									
								
							| @ -0,0 +1,33 @@ | ||||
| #ifndef TH_STORAGE_INC | ||||
| #define TH_STORAGE_INC | ||||
|  | ||||
| #include "THGeneral.h" | ||||
|  | ||||
| /* stuff for mapped files */ | ||||
| #ifdef _WIN32 | ||||
| #include <windows.h> | ||||
| #endif | ||||
|  | ||||
| #if HAVE_MMAP | ||||
| #include <sys/types.h> | ||||
| #include <sys/mman.h> | ||||
| #include <sys/stat.h> | ||||
| #include <fcntl.h> | ||||
| #include <unistd.h> | ||||
| #endif | ||||
| /* end of stuff for mapped files */ | ||||
|  | ||||
| #define THStorage        TH_CONCAT_3(TH,Real,Storage) | ||||
| #define THStorage_(NAME) TH_CONCAT_4(TH,Real,Storage_,NAME) | ||||
|  | ||||
| /* fast access methods */ | ||||
| #define TH_STORAGE_GET(storage, idx) ((storage)->data[(idx)]) | ||||
| #define TH_STORAGE_SET(storage, idx, value) ((storage)->data[(idx)] = (value)) | ||||
|  | ||||
| #include "generic/THStorage.h" | ||||
| #include "THGenerateAllTypes.h" | ||||
|  | ||||
| #include "generic/THStorageCopy.h" | ||||
| #include "THGenerateAllTypes.h" | ||||
|  | ||||
| #endif | ||||
							
								
								
									
										24
									
								
								lib/TH/THTensor.c
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										24
									
								
								lib/TH/THTensor.c
									
									
									
									
									
										Normal file
									
								
							| @ -0,0 +1,24 @@ | ||||
| #include "THTensor.h" | ||||
| #include "THVector.h" | ||||
| #include "THBlas.h" | ||||
| #include "THLapack.h" | ||||
| #include "THRandom.h" | ||||
| #include "THTensorDimApply.h" | ||||
|  | ||||
| #include "generic/THTensor.c" | ||||
| #include "THGenerateAllTypes.h" | ||||
|  | ||||
| #include "generic/THTensorCopy.c" | ||||
| #include "THGenerateAllTypes.h" | ||||
|  | ||||
| #include "generic/THTensorRandom.c" | ||||
| #include "THGenerateAllTypes.h" | ||||
|  | ||||
| #include "generic/THTensorMath.c" | ||||
| #include "THGenerateAllTypes.h" | ||||
|  | ||||
| #include "generic/THTensorConv.c" | ||||
| #include "THGenerateAllTypes.h" | ||||
|  | ||||
| #include "generic/THTensorLapack.c" | ||||
| #include "THGenerateFloatTypes.h" | ||||
							
								
								
									
										35
									
								
								lib/TH/THTensor.h
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										35
									
								
								lib/TH/THTensor.h
									
									
									
									
									
										Normal file
									
								
							| @ -0,0 +1,35 @@ | ||||
| #ifndef TH_TENSOR_INC | ||||
| #define TH_TENSOR_INC | ||||
|  | ||||
| #include "THStorage.h" | ||||
| #include "THTensorApply.h" | ||||
|  | ||||
| #define THTensor          TH_CONCAT_3(TH,Real,Tensor) | ||||
| #define THTensor_(NAME)   TH_CONCAT_4(TH,Real,Tensor_,NAME) | ||||
|  | ||||
| /* basics */ | ||||
| #include "generic/THTensor.h" | ||||
| #include "THGenerateAllTypes.h" | ||||
|  | ||||
| #include "generic/THTensorCopy.h" | ||||
| #include "THGenerateAllTypes.h" | ||||
|  | ||||
| #include "THTensorMacros.h" | ||||
|  | ||||
| /* random numbers */ | ||||
| #include "generic/THTensorRandom.h" | ||||
| #include "THGenerateAllTypes.h" | ||||
|  | ||||
| /* maths */ | ||||
| #include "generic/THTensorMath.h" | ||||
| #include "THGenerateAllTypes.h" | ||||
|  | ||||
| /* convolutions */ | ||||
| #include "generic/THTensorConv.h" | ||||
| #include "THGenerateAllTypes.h" | ||||
|  | ||||
| /* lapack support */ | ||||
| #include "generic/THTensorLapack.h" | ||||
| #include "THGenerateFloatTypes.h" | ||||
|  | ||||
| #endif | ||||
							
								
								
									
										428
									
								
								lib/TH/THTensorApply.h
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										428
									
								
								lib/TH/THTensorApply.h
									
									
									
									
									
										Normal file
									
								
							| @ -0,0 +1,428 @@ | ||||
| #ifndef TH_TENSOR_APPLY_INC | ||||
| #define TH_TENSOR_APPLY_INC | ||||
|  | ||||
| #define TH_TENSOR_APPLY3(TYPE1, TENSOR1, TYPE2, TENSOR2, TYPE3, TENSOR3, CODE) \ | ||||
| { \ | ||||
|   TYPE1 *TENSOR1##_data = NULL; \ | ||||
|   long *TENSOR1##_counter = NULL; \ | ||||
|   long TENSOR1##_stride = 0, TENSOR1##_size = 0, TENSOR1##_dim = 0, TENSOR1##_i, TENSOR1##_n; \ | ||||
|   TYPE2 *TENSOR2##_data = NULL; \ | ||||
|   long *TENSOR2##_counter = NULL; \ | ||||
|   long TENSOR2##_stride = 0, TENSOR2##_size = 0, TENSOR2##_dim = 0, TENSOR2##_i, TENSOR2##_n; \ | ||||
|   TYPE2 *TENSOR3##_data = NULL; \ | ||||
|   long *TENSOR3##_counter = NULL; \ | ||||
|   long TENSOR3##_stride = 0, TENSOR3##_size = 0, TENSOR3##_dim = 0, TENSOR3##_i, TENSOR3##_n; \ | ||||
|   int TH_TENSOR_APPLY_hasFinished = 0; \ | ||||
| \ | ||||
|   TENSOR1##_n = (TENSOR1->nDimension ? 1 : 0); \ | ||||
|   for(TENSOR1##_i = 0; TENSOR1##_i < TENSOR1->nDimension; TENSOR1##_i++) \ | ||||
|     TENSOR1##_n *= TENSOR1->size[TENSOR1##_i]; \ | ||||
| \ | ||||
|   TENSOR2##_n = (TENSOR2->nDimension ? 1 : 0); \ | ||||
|   for(TENSOR2##_i = 0; TENSOR2##_i < TENSOR2->nDimension; TENSOR2##_i++) \ | ||||
|     TENSOR2##_n *= TENSOR2->size[TENSOR2##_i]; \ | ||||
| \ | ||||
|   TENSOR3##_n = (TENSOR3->nDimension ? 1 : 0); \ | ||||
|   for(TENSOR3##_i = 0; TENSOR3##_i < TENSOR3->nDimension; TENSOR3##_i++) \ | ||||
|     TENSOR3##_n *= TENSOR3->size[TENSOR3##_i]; \ | ||||
| \ | ||||
|   if(TENSOR1##_n != TENSOR2##_n || TENSOR1##_n != TENSOR3##_n) /* should we do the check in the function instead? i think so */ \ | ||||
|     THError("inconsistent tensor size"); \ | ||||
| \ | ||||
|   if(TENSOR1->nDimension == 0) \ | ||||
|     TH_TENSOR_APPLY_hasFinished = 1; \ | ||||
|   else \ | ||||
|   { \ | ||||
|     TENSOR1##_data = TENSOR1->storage->data+TENSOR1->storageOffset; \ | ||||
|     for(TENSOR1##_dim = TENSOR1->nDimension-1; TENSOR1##_dim >= 0; TENSOR1##_dim--) \ | ||||
|     { \ | ||||
|       if(TENSOR1->size[TENSOR1##_dim] != 1) \ | ||||
|         break; \ | ||||
|     } \ | ||||
|     TENSOR1##_stride = (TENSOR1##_dim == -1 ? 0 : TENSOR1->stride[TENSOR1##_dim]); \ | ||||
|     TENSOR1##_size = 1; \ | ||||
|     for(TENSOR1##_dim = TENSOR1->nDimension-1; TENSOR1##_dim >= 0; TENSOR1##_dim--) \ | ||||
|     { \ | ||||
|       if(TENSOR1->size[TENSOR1##_dim] != 1) \ | ||||
|       { \ | ||||
|         if(TENSOR1->stride[TENSOR1##_dim] == TENSOR1##_size) \ | ||||
|           TENSOR1##_size *= TENSOR1->size[TENSOR1##_dim]; \ | ||||
|         else \ | ||||
|           break; \ | ||||
|       } \ | ||||
|     } \ | ||||
|     TENSOR1##_counter = (long*)THAlloc(sizeof(long)*(TENSOR1##_dim+1)); \ | ||||
|     for(TENSOR1##_i = 0; TENSOR1##_i <= TENSOR1##_dim; TENSOR1##_i++) \ | ||||
|       TENSOR1##_counter[TENSOR1##_i] = 0; \ | ||||
| \ | ||||
|     TENSOR2##_data = TENSOR2->storage->data+TENSOR2->storageOffset; \ | ||||
|     for(TENSOR2##_dim = TENSOR2->nDimension-1; TENSOR2##_dim >= 0; TENSOR2##_dim--) \ | ||||
|     { \ | ||||
|       if(TENSOR2->size[TENSOR2##_dim] != 1) \ | ||||
|         break; \ | ||||
|     } \ | ||||
|     TENSOR2##_stride = (TENSOR2##_dim == -1 ? 0 : TENSOR2->stride[TENSOR2##_dim]); \ | ||||
|     TENSOR2##_size = 1; \ | ||||
|     for(TENSOR2##_dim = TENSOR2->nDimension-1; TENSOR2##_dim >= 0; TENSOR2##_dim--) \ | ||||
|     { \ | ||||
|       if(TENSOR2->size[TENSOR2##_dim] != 1) \ | ||||
|       { \ | ||||
|         if(TENSOR2->stride[TENSOR2##_dim] == TENSOR2##_size) \ | ||||
|           TENSOR2##_size *= TENSOR2->size[TENSOR2##_dim]; \ | ||||
|         else \ | ||||
|           break; \ | ||||
|       } \ | ||||
|     } \ | ||||
|     TENSOR2##_counter = (long*)THAlloc(sizeof(long)*(TENSOR2##_dim+1)); \ | ||||
|     for(TENSOR2##_i = 0; TENSOR2##_i <= TENSOR2##_dim; TENSOR2##_i++) \ | ||||
|       TENSOR2##_counter[TENSOR2##_i] = 0; \ | ||||
| \ | ||||
|     TENSOR3##_data = TENSOR3->storage->data+TENSOR3->storageOffset; \ | ||||
|     for(TENSOR3##_dim = TENSOR3->nDimension-1; TENSOR3##_dim >= 0; TENSOR3##_dim--) \ | ||||
|     { \ | ||||
|       if(TENSOR3->size[TENSOR3##_dim] != 1) \ | ||||
|         break; \ | ||||
|     } \ | ||||
|     TENSOR3##_stride = (TENSOR3##_dim == -1 ? 0 : TENSOR3->stride[TENSOR3##_dim]); \ | ||||
|     TENSOR3##_size = 1; \ | ||||
|     for(TENSOR3##_dim = TENSOR3->nDimension-1; TENSOR3##_dim >= 0; TENSOR3##_dim--) \ | ||||
|     { \ | ||||
|       if(TENSOR3->size[TENSOR3##_dim] != 1) \ | ||||
|       { \ | ||||
|         if(TENSOR3->stride[TENSOR3##_dim] == TENSOR3##_size) \ | ||||
|           TENSOR3##_size *= TENSOR3->size[TENSOR3##_dim]; \ | ||||
|         else \ | ||||
|           break; \ | ||||
|       } \ | ||||
|     } \ | ||||
|     TENSOR3##_counter = (long*)THAlloc(sizeof(long)*(TENSOR3##_dim+1)); \ | ||||
|     for(TENSOR3##_i = 0; TENSOR3##_i <= TENSOR3##_dim; TENSOR3##_i++) \ | ||||
|       TENSOR3##_counter[TENSOR3##_i] = 0; \ | ||||
|   } \ | ||||
| \ | ||||
|   TENSOR1##_i = 0; \ | ||||
|   TENSOR2##_i = 0; \ | ||||
|   TENSOR3##_i = 0; \ | ||||
|   while(!TH_TENSOR_APPLY_hasFinished) \ | ||||
|   { \ | ||||
|     for(; TENSOR1##_i < TENSOR1##_size && TENSOR2##_i < TENSOR2##_size && TENSOR3##_i < TENSOR3##_size; TENSOR1##_i++, TENSOR2##_i++, TENSOR3##_i++, TENSOR1##_data += TENSOR1##_stride, TENSOR2##_data += TENSOR2##_stride, TENSOR3##_data += TENSOR3##_stride) /* 0 et pas TENSOR##_dim! */ \ | ||||
|     { \ | ||||
|       CODE \ | ||||
|     } \ | ||||
| \ | ||||
|     if(TENSOR1##_i == TENSOR1##_size) \ | ||||
|     { \ | ||||
|       if(TENSOR1##_dim == -1) \ | ||||
|          break; \ | ||||
| \ | ||||
|       TENSOR1##_data -= TENSOR1##_size*TENSOR1##_stride; \ | ||||
|       for(TENSOR1##_i = TENSOR1##_dim; TENSOR1##_i >= 0; TENSOR1##_i--) \ | ||||
|       { \ | ||||
|         TENSOR1##_counter[TENSOR1##_i]++; \ | ||||
|         TENSOR1##_data += TENSOR1->stride[TENSOR1##_i]; \ | ||||
| \ | ||||
|         if(TENSOR1##_counter[TENSOR1##_i]  == TENSOR1->size[TENSOR1##_i]) \ | ||||
|         { \ | ||||
|           if(TENSOR1##_i == 0) \ | ||||
|           { \ | ||||
|             TH_TENSOR_APPLY_hasFinished = 1; \ | ||||
|             break; \ | ||||
|           } \ | ||||
|             else \ | ||||
|           { \ | ||||
|             TENSOR1##_data -= TENSOR1##_counter[TENSOR1##_i]*TENSOR1->stride[TENSOR1##_i]; \ | ||||
|             TENSOR1##_counter[TENSOR1##_i] = 0; \ | ||||
|           } \ | ||||
|         } \ | ||||
|         else \ | ||||
|           break; \ | ||||
|       } \ | ||||
|       TENSOR1##_i = 0; \ | ||||
|     } \ | ||||
| \ | ||||
|     if(TENSOR2##_i == TENSOR2##_size) \ | ||||
|     { \ | ||||
|       if(TENSOR2##_dim == -1) \ | ||||
|          break; \ | ||||
| \ | ||||
|       TENSOR2##_data -= TENSOR2##_size*TENSOR2##_stride; \ | ||||
|       for(TENSOR2##_i = TENSOR2##_dim; TENSOR2##_i >= 0; TENSOR2##_i--) \ | ||||
|       { \ | ||||
|         TENSOR2##_counter[TENSOR2##_i]++; \ | ||||
|         TENSOR2##_data += TENSOR2->stride[TENSOR2##_i]; \ | ||||
| \ | ||||
|         if(TENSOR2##_counter[TENSOR2##_i]  == TENSOR2->size[TENSOR2##_i]) \ | ||||
|         { \ | ||||
|           if(TENSOR2##_i == 0) \ | ||||
|           { \ | ||||
|             TH_TENSOR_APPLY_hasFinished = 1; \ | ||||
|             break; \ | ||||
|           } \ | ||||
|             else \ | ||||
|           { \ | ||||
|             TENSOR2##_data -= TENSOR2##_counter[TENSOR2##_i]*TENSOR2->stride[TENSOR2##_i]; \ | ||||
|             TENSOR2##_counter[TENSOR2##_i] = 0; \ | ||||
|           } \ | ||||
|         } \ | ||||
|         else \ | ||||
|           break; \ | ||||
|       } \ | ||||
|       TENSOR2##_i = 0; \ | ||||
|     } \ | ||||
| \ | ||||
|     if(TENSOR3##_i == TENSOR3##_size) \ | ||||
|     { \ | ||||
|       if(TENSOR3##_dim == -1) \ | ||||
|          break; \ | ||||
| \ | ||||
|       TENSOR3##_data -= TENSOR3##_size*TENSOR3##_stride; \ | ||||
|       for(TENSOR3##_i = TENSOR3##_dim; TENSOR3##_i >= 0; TENSOR3##_i--) \ | ||||
|       { \ | ||||
|         TENSOR3##_counter[TENSOR3##_i]++; \ | ||||
|         TENSOR3##_data += TENSOR3->stride[TENSOR3##_i]; \ | ||||
| \ | ||||
|         if(TENSOR3##_counter[TENSOR3##_i]  == TENSOR3->size[TENSOR3##_i]) \ | ||||
|         { \ | ||||
|           if(TENSOR3##_i == 0) \ | ||||
|           { \ | ||||
|             TH_TENSOR_APPLY_hasFinished = 1; \ | ||||
|             break; \ | ||||
|           } \ | ||||
|             else \ | ||||
|           { \ | ||||
|             TENSOR3##_data -= TENSOR3##_counter[TENSOR3##_i]*TENSOR3->stride[TENSOR3##_i]; \ | ||||
|             TENSOR3##_counter[TENSOR3##_i] = 0; \ | ||||
|           } \ | ||||
|         } \ | ||||
|         else \ | ||||
|           break; \ | ||||
|       } \ | ||||
|       TENSOR3##_i = 0; \ | ||||
|     } \ | ||||
|   } \ | ||||
|   THFree(TENSOR1##_counter); \ | ||||
|   THFree(TENSOR2##_counter); \ | ||||
|   THFree(TENSOR3##_counter); \ | ||||
| } | ||||
|  | ||||
| #define TH_TENSOR_APPLY2(TYPE1, TENSOR1, TYPE2, TENSOR2, CODE) \ | ||||
| { \ | ||||
|   TYPE1 *TENSOR1##_data = NULL; \ | ||||
|   long *TENSOR1##_counter = NULL; \ | ||||
|   long TENSOR1##_stride = 0, TENSOR1##_size = 0, TENSOR1##_dim = 0, TENSOR1##_i, TENSOR1##_n; \ | ||||
|   TYPE2 *TENSOR2##_data = NULL; \ | ||||
|   long *TENSOR2##_counter = NULL; \ | ||||
|   long TENSOR2##_stride = 0, TENSOR2##_size = 0, TENSOR2##_dim = 0, TENSOR2##_i, TENSOR2##_n; \ | ||||
|   int TH_TENSOR_APPLY_hasFinished = 0; \ | ||||
| \ | ||||
|   TENSOR1##_n = (TENSOR1->nDimension ? 1 : 0); \ | ||||
|   for(TENSOR1##_i = 0; TENSOR1##_i < TENSOR1->nDimension; TENSOR1##_i++) \ | ||||
|     TENSOR1##_n *= TENSOR1->size[TENSOR1##_i]; \ | ||||
| \ | ||||
|   TENSOR2##_n = (TENSOR2->nDimension ? 1 : 0); \ | ||||
|   for(TENSOR2##_i = 0; TENSOR2##_i < TENSOR2->nDimension; TENSOR2##_i++) \ | ||||
|     TENSOR2##_n *= TENSOR2->size[TENSOR2##_i]; \ | ||||
| \ | ||||
|   if(TENSOR1##_n != TENSOR2##_n) /* should we do the check in the function instead? i think so */ \ | ||||
|     THError("inconsistent tensor size"); \ | ||||
| \ | ||||
|   if(TENSOR1->nDimension == 0) \ | ||||
|     TH_TENSOR_APPLY_hasFinished = 1; \ | ||||
|   else \ | ||||
|   { \ | ||||
|     TENSOR1##_data = TENSOR1->storage->data+TENSOR1->storageOffset; \ | ||||
|     for(TENSOR1##_dim = TENSOR1->nDimension-1; TENSOR1##_dim >= 0; TENSOR1##_dim--) \ | ||||
|     { \ | ||||
|       if(TENSOR1->size[TENSOR1##_dim] != 1) \ | ||||
|         break; \ | ||||
|     } \ | ||||
|     TENSOR1##_stride = (TENSOR1##_dim == -1 ? 0 : TENSOR1->stride[TENSOR1##_dim]); \ | ||||
|     TENSOR1##_size = 1; \ | ||||
|     for(TENSOR1##_dim = TENSOR1->nDimension-1; TENSOR1##_dim >= 0; TENSOR1##_dim--) \ | ||||
|     { \ | ||||
|       if(TENSOR1->size[TENSOR1##_dim] != 1) \ | ||||
|       { \ | ||||
|         if(TENSOR1->stride[TENSOR1##_dim] == TENSOR1##_size) \ | ||||
|           TENSOR1##_size *= TENSOR1->size[TENSOR1##_dim]; \ | ||||
|         else \ | ||||
|           break; \ | ||||
|       } \ | ||||
|     } \ | ||||
|     TENSOR1##_counter = (long*)THAlloc(sizeof(long)*(TENSOR1##_dim+1)); \ | ||||
|     for(TENSOR1##_i = 0; TENSOR1##_i <= TENSOR1##_dim; TENSOR1##_i++) \ | ||||
|       TENSOR1##_counter[TENSOR1##_i] = 0; \ | ||||
| \ | ||||
|     TENSOR2##_data = TENSOR2->storage->data+TENSOR2->storageOffset; \ | ||||
|     for(TENSOR2##_dim = TENSOR2->nDimension-1; TENSOR2##_dim >= 0; TENSOR2##_dim--) \ | ||||
|     { \ | ||||
|       if(TENSOR2->size[TENSOR2##_dim] != 1) \ | ||||
|         break; \ | ||||
|     } \ | ||||
|     TENSOR2##_stride = (TENSOR2##_dim == -1 ? 0 : TENSOR2->stride[TENSOR2##_dim]); \ | ||||
|     TENSOR2##_size = 1; \ | ||||
|     for(TENSOR2##_dim = TENSOR2->nDimension-1; TENSOR2##_dim >= 0; TENSOR2##_dim--) \ | ||||
|     { \ | ||||
|       if(TENSOR2->size[TENSOR2##_dim] != 1) \ | ||||
|       { \ | ||||
|         if(TENSOR2->stride[TENSOR2##_dim] == TENSOR2##_size) \ | ||||
|           TENSOR2##_size *= TENSOR2->size[TENSOR2##_dim]; \ | ||||
|         else \ | ||||
|           break; \ | ||||
|       } \ | ||||
|     } \ | ||||
|     TENSOR2##_counter = (long*)THAlloc(sizeof(long)*(TENSOR2##_dim+1)); \ | ||||
|     for(TENSOR2##_i = 0; TENSOR2##_i <= TENSOR2##_dim; TENSOR2##_i++) \ | ||||
|       TENSOR2##_counter[TENSOR2##_i] = 0; \ | ||||
|   } \ | ||||
| \ | ||||
|   TENSOR1##_i = 0; \ | ||||
|   TENSOR2##_i = 0; \ | ||||
|   while(!TH_TENSOR_APPLY_hasFinished) \ | ||||
|   { \ | ||||
|     for(; TENSOR1##_i < TENSOR1##_size && TENSOR2##_i < TENSOR2##_size; TENSOR1##_i++, TENSOR2##_i++, TENSOR1##_data += TENSOR1##_stride, TENSOR2##_data += TENSOR2##_stride) /* 0 et pas TENSOR##_dim! */ \ | ||||
|     { \ | ||||
|       CODE \ | ||||
|     } \ | ||||
| \ | ||||
|     if(TENSOR1##_i == TENSOR1##_size) \ | ||||
|     { \ | ||||
|       if(TENSOR1##_dim == -1) \ | ||||
|          break; \ | ||||
| \ | ||||
|       TENSOR1##_data -= TENSOR1##_size*TENSOR1##_stride; \ | ||||
|       for(TENSOR1##_i = TENSOR1##_dim; TENSOR1##_i >= 0; TENSOR1##_i--) \ | ||||
|       { \ | ||||
|         TENSOR1##_counter[TENSOR1##_i]++; \ | ||||
|         TENSOR1##_data += TENSOR1->stride[TENSOR1##_i]; \ | ||||
| \ | ||||
|         if(TENSOR1##_counter[TENSOR1##_i]  == TENSOR1->size[TENSOR1##_i]) \ | ||||
|         { \ | ||||
|           if(TENSOR1##_i == 0) \ | ||||
|           { \ | ||||
|             TH_TENSOR_APPLY_hasFinished = 1; \ | ||||
|             break; \ | ||||
|           } \ | ||||
|             else \ | ||||
|           { \ | ||||
|             TENSOR1##_data -= TENSOR1##_counter[TENSOR1##_i]*TENSOR1->stride[TENSOR1##_i]; \ | ||||
|             TENSOR1##_counter[TENSOR1##_i] = 0; \ | ||||
|           } \ | ||||
|         } \ | ||||
|         else \ | ||||
|           break; \ | ||||
|       } \ | ||||
|       TENSOR1##_i = 0; \ | ||||
|     } \ | ||||
| \ | ||||
|     if(TENSOR2##_i == TENSOR2##_size) \ | ||||
|     { \ | ||||
|       if(TENSOR2##_dim == -1) \ | ||||
|          break; \ | ||||
| \ | ||||
|       TENSOR2##_data -= TENSOR2##_size*TENSOR2##_stride; \ | ||||
|       for(TENSOR2##_i = TENSOR2##_dim; TENSOR2##_i >= 0; TENSOR2##_i--) \ | ||||
|       { \ | ||||
|         TENSOR2##_counter[TENSOR2##_i]++; \ | ||||
|         TENSOR2##_data += TENSOR2->stride[TENSOR2##_i]; \ | ||||
| \ | ||||
|         if(TENSOR2##_counter[TENSOR2##_i]  == TENSOR2->size[TENSOR2##_i]) \ | ||||
|         { \ | ||||
|           if(TENSOR2##_i == 0) \ | ||||
|           { \ | ||||
|             TH_TENSOR_APPLY_hasFinished = 1; \ | ||||
|             break; \ | ||||
|           } \ | ||||
|             else \ | ||||
|           { \ | ||||
|             TENSOR2##_data -= TENSOR2##_counter[TENSOR2##_i]*TENSOR2->stride[TENSOR2##_i]; \ | ||||
|             TENSOR2##_counter[TENSOR2##_i] = 0; \ | ||||
|           } \ | ||||
|         } \ | ||||
|         else \ | ||||
|           break; \ | ||||
|       } \ | ||||
|       TENSOR2##_i = 0; \ | ||||
|     } \ | ||||
|   } \ | ||||
|   THFree(TENSOR1##_counter); \ | ||||
|   THFree(TENSOR2##_counter); \ | ||||
| } | ||||
|  | ||||
| #define TH_TENSOR_APPLY(TYPE, TENSOR, CODE) \ | ||||
| { \ | ||||
|   TYPE *TENSOR##_data = NULL; \ | ||||
|   long *TENSOR##_counter = NULL; \ | ||||
|   long TENSOR##_stride = 0, TENSOR##_size = 0, TENSOR##_dim = 0, TENSOR##_i; \ | ||||
|   int TH_TENSOR_APPLY_hasFinished = 0; \ | ||||
| \ | ||||
|   if(TENSOR->nDimension == 0) \ | ||||
|     TH_TENSOR_APPLY_hasFinished = 1; \ | ||||
|   else \ | ||||
|   { \ | ||||
|     TENSOR##_data = TENSOR->storage->data+TENSOR->storageOffset; \ | ||||
| \ | ||||
|     /* what is the first stride (ignore first dims=1)? */ \ | ||||
|     /* it will be used for the whole largest contiguous section */ \ | ||||
|     for(TENSOR##_dim = TENSOR->nDimension-1; TENSOR##_dim >= 0; TENSOR##_dim--) \ | ||||
|     { \ | ||||
|       if(TENSOR->size[TENSOR##_dim] != 1) \ | ||||
|         break; \ | ||||
|     } \ | ||||
|     TENSOR##_stride = (TENSOR##_dim == -1 ? 0 : TENSOR->stride[TENSOR##_dim]); \ | ||||
| \ | ||||
|     /* what is the largest contiguous section? */ \ | ||||
|     TENSOR##_size = 1; \ | ||||
|     for(TENSOR##_dim = TENSOR->nDimension-1; TENSOR##_dim >= 0; TENSOR##_dim--) \ | ||||
|     { \ | ||||
|       if(TENSOR->size[TENSOR##_dim] != 1) \ | ||||
|       { \ | ||||
|         if(TENSOR->stride[TENSOR##_dim] == TENSOR##_size) \ | ||||
|           TENSOR##_size *= TENSOR->size[TENSOR##_dim]; \ | ||||
|         else \ | ||||
|           break; \ | ||||
|       } \ | ||||
|     } \ | ||||
| \ | ||||
|     /* counter over found dimensions */ \ | ||||
|     TENSOR##_counter = (long*)THAlloc(sizeof(long)*(TENSOR##_dim+1)); \ | ||||
|     for(TENSOR##_i = 0; TENSOR##_i <= TENSOR##_dim; TENSOR##_i++) \ | ||||
|       TENSOR##_counter[TENSOR##_i] = 0; \ | ||||
|   } \ | ||||
| \ | ||||
|   while(!TH_TENSOR_APPLY_hasFinished) \ | ||||
|   { \ | ||||
|     for(TENSOR##_i = 0; TENSOR##_i < TENSOR##_size; TENSOR##_i++, TENSOR##_data += TENSOR##_stride) /* 0 et pas TENSOR##_dim! */ \ | ||||
|     { \ | ||||
|       CODE \ | ||||
|     } \ | ||||
| \ | ||||
|     if(TENSOR##_dim == -1) \ | ||||
|        break; \ | ||||
|  \ | ||||
|     TENSOR##_data -= TENSOR##_i*TENSOR##_stride; \ | ||||
|     for(TENSOR##_i = TENSOR##_dim; TENSOR##_i >= 0; TENSOR##_i--) \ | ||||
|     { \ | ||||
|       TENSOR##_counter[TENSOR##_i]++; \ | ||||
|       TENSOR##_data += TENSOR->stride[TENSOR##_i]; \ | ||||
| \ | ||||
|       if(TENSOR##_counter[TENSOR##_i]  == TENSOR->size[TENSOR##_i]) \ | ||||
|       { \ | ||||
|         if(TENSOR##_i == 0) \ | ||||
|         { \ | ||||
|           TH_TENSOR_APPLY_hasFinished = 1; \ | ||||
|           break; \ | ||||
|         } \ | ||||
|         else \ | ||||
|         { \ | ||||
|           TENSOR##_data -= TENSOR##_counter[TENSOR##_i]*TENSOR->stride[TENSOR##_i]; \ | ||||
|           TENSOR##_counter[TENSOR##_i] = 0; \ | ||||
|         } \ | ||||
|       } \ | ||||
|       else \ | ||||
|         break; \ | ||||
|     } \ | ||||
|   } \ | ||||
|   THFree(TENSOR##_counter); \ | ||||
| } | ||||
|  | ||||
| #endif | ||||
							
								
								
									
										232
									
								
								lib/TH/THTensorDimApply.h
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										232
									
								
								lib/TH/THTensorDimApply.h
									
									
									
									
									
										Normal file
									
								
							| @ -0,0 +1,232 @@ | ||||
| #ifndef TH_TENSOR_DIM_APPLY_INC | ||||
| #define TH_TENSOR_DIM_APPLY_INC | ||||
|  | ||||
| #define TH_TENSOR_DIM_APPLY3(TYPE1, TENSOR1, TYPE2, TENSOR2, TYPE3, TENSOR3, DIMENSION, CODE) \ | ||||
| { \ | ||||
|   TYPE1 *TENSOR1##_data = NULL; \ | ||||
|   long TENSOR1##_stride = 0, TENSOR1##_size = 0; \ | ||||
|   TYPE2 *TENSOR2##_data = NULL; \ | ||||
|   long TENSOR2##_stride = 0, TENSOR2##_size = 0; \ | ||||
|   TYPE3 *TENSOR3##_data = NULL; \ | ||||
|   long TENSOR3##_stride = 0, TENSOR3##_size = 0; \ | ||||
|   long *TH_TENSOR_DIM_APPLY_counter = NULL; \ | ||||
|   int TH_TENSOR_DIM_APPLY_hasFinished = 0; \ | ||||
|   int TH_TENSOR_DIM_APPLY_i; \ | ||||
| \ | ||||
|   if( (DIMENSION < 0) || (DIMENSION >= TENSOR1->nDimension) ) \ | ||||
|     THError("invalid dimension"); \ | ||||
|   if( TENSOR1->nDimension != TENSOR2->nDimension ) \ | ||||
|     THError("inconsistent tensor sizes"); \ | ||||
|   if( TENSOR1->nDimension != TENSOR3->nDimension ) \ | ||||
|     THError("inconsistent tensor sizes"); \ | ||||
|   for(TH_TENSOR_DIM_APPLY_i = 0; TH_TENSOR_DIM_APPLY_i < TENSOR1->nDimension; TH_TENSOR_DIM_APPLY_i++) \ | ||||
|   { \ | ||||
|     if(TH_TENSOR_DIM_APPLY_i == DIMENSION) \ | ||||
|       continue; \ | ||||
|     if(TENSOR1->size[TH_TENSOR_DIM_APPLY_i] != TENSOR2->size[TH_TENSOR_DIM_APPLY_i]) \ | ||||
|       THError("inconsistent tensor sizes"); \ | ||||
|     if(TENSOR1->size[TH_TENSOR_DIM_APPLY_i] != TENSOR3->size[TH_TENSOR_DIM_APPLY_i]) \ | ||||
|       THError("inconsistent tensor sizes"); \ | ||||
|   } \ | ||||
| \ | ||||
|   TH_TENSOR_DIM_APPLY_counter = (long*)THAlloc(sizeof(long)*(TENSOR1->nDimension)); \ | ||||
|   for(TH_TENSOR_DIM_APPLY_i = 0; TH_TENSOR_DIM_APPLY_i < TENSOR1->nDimension; TH_TENSOR_DIM_APPLY_i++) \ | ||||
|     TH_TENSOR_DIM_APPLY_counter[TH_TENSOR_DIM_APPLY_i] = 0; \ | ||||
| \ | ||||
|   TENSOR1##_data = (TENSOR1)->storage->data+(TENSOR1)->storageOffset; \ | ||||
|   TENSOR1##_stride = (TENSOR1)->stride[DIMENSION]; \ | ||||
|   TENSOR1##_size = TENSOR1->size[DIMENSION]; \ | ||||
| \ | ||||
|   TENSOR2##_data = (TENSOR2)->storage->data+(TENSOR2)->storageOffset; \ | ||||
|   TENSOR2##_stride = (TENSOR2)->stride[DIMENSION]; \ | ||||
|   TENSOR2##_size = TENSOR2->size[DIMENSION]; \ | ||||
| \ | ||||
|   TENSOR3##_data = (TENSOR3)->storage->data+(TENSOR3)->storageOffset; \ | ||||
|   TENSOR3##_stride = (TENSOR3)->stride[DIMENSION]; \ | ||||
|   TENSOR3##_size = TENSOR3->size[DIMENSION]; \ | ||||
| \ | ||||
|   while(!TH_TENSOR_DIM_APPLY_hasFinished) \ | ||||
|   { \ | ||||
|     CODE \ | ||||
| \ | ||||
|     if(TENSOR1->nDimension == 1) \ | ||||
|        break; \ | ||||
|  \ | ||||
|     for(TH_TENSOR_DIM_APPLY_i = 0; TH_TENSOR_DIM_APPLY_i < TENSOR1->nDimension; TH_TENSOR_DIM_APPLY_i++) \ | ||||
|     { \ | ||||
|       if(TH_TENSOR_DIM_APPLY_i == DIMENSION) \ | ||||
|       { \ | ||||
|         if(TH_TENSOR_DIM_APPLY_i == TENSOR1->nDimension-1) \ | ||||
|         { \ | ||||
|           TH_TENSOR_DIM_APPLY_hasFinished = 1; \ | ||||
|           break; \ | ||||
|         } \ | ||||
|         continue; \ | ||||
|       } \ | ||||
| \ | ||||
|       TH_TENSOR_DIM_APPLY_counter[TH_TENSOR_DIM_APPLY_i]++; \ | ||||
|       TENSOR1##_data += TENSOR1->stride[TH_TENSOR_DIM_APPLY_i]; \ | ||||
|       TENSOR2##_data += TENSOR2->stride[TH_TENSOR_DIM_APPLY_i]; \ | ||||
|       TENSOR3##_data += TENSOR3->stride[TH_TENSOR_DIM_APPLY_i]; \ | ||||
| \ | ||||
|       if(TH_TENSOR_DIM_APPLY_counter[TH_TENSOR_DIM_APPLY_i] == TENSOR1->size[TH_TENSOR_DIM_APPLY_i]) \ | ||||
|       { \ | ||||
|         if(TH_TENSOR_DIM_APPLY_i == TENSOR1->nDimension-1) \ | ||||
|         { \ | ||||
|           TH_TENSOR_DIM_APPLY_hasFinished = 1; \ | ||||
|           break; \ | ||||
|         } \ | ||||
|         else \ | ||||
|         { \ | ||||
|           TENSOR1##_data -= TH_TENSOR_DIM_APPLY_counter[TH_TENSOR_DIM_APPLY_i]*TENSOR1->stride[TH_TENSOR_DIM_APPLY_i]; \ | ||||
|           TENSOR2##_data -= TH_TENSOR_DIM_APPLY_counter[TH_TENSOR_DIM_APPLY_i]*TENSOR2->stride[TH_TENSOR_DIM_APPLY_i]; \ | ||||
|           TENSOR3##_data -= TH_TENSOR_DIM_APPLY_counter[TH_TENSOR_DIM_APPLY_i]*TENSOR3->stride[TH_TENSOR_DIM_APPLY_i]; \ | ||||
|           TH_TENSOR_DIM_APPLY_counter[TH_TENSOR_DIM_APPLY_i] = 0; \ | ||||
|         } \ | ||||
|       } \ | ||||
|       else \ | ||||
|         break; \ | ||||
|     } \ | ||||
|   } \ | ||||
|   THFree(TH_TENSOR_DIM_APPLY_counter); \ | ||||
| } | ||||
|  | ||||
| #define TH_TENSOR_DIM_APPLY2(TYPE1, TENSOR1, TYPE2, TENSOR2, DIMENSION, CODE) \ | ||||
| { \ | ||||
|   TYPE1 *TENSOR1##_data = NULL; \ | ||||
|   long TENSOR1##_stride = 0, TENSOR1##_size = 0; \ | ||||
|   TYPE2 *TENSOR2##_data = NULL; \ | ||||
|   long TENSOR2##_stride = 0, TENSOR2##_size = 0; \ | ||||
|   long *TH_TENSOR_DIM_APPLY_counter = NULL; \ | ||||
|   int TH_TENSOR_DIM_APPLY_hasFinished = 0; \ | ||||
|   int TH_TENSOR_DIM_APPLY_i; \ | ||||
| \ | ||||
|   if( (DIMENSION < 0) || (DIMENSION >= TENSOR1->nDimension) ) \ | ||||
|     THError("invalid dimension"); \ | ||||
|   if( TENSOR1->nDimension != TENSOR2->nDimension ) \ | ||||
|     THError("inconsistent tensor sizes"); \ | ||||
|   for(TH_TENSOR_DIM_APPLY_i = 0; TH_TENSOR_DIM_APPLY_i < TENSOR1->nDimension; TH_TENSOR_DIM_APPLY_i++) \ | ||||
|   { \ | ||||
|     if(TH_TENSOR_DIM_APPLY_i == DIMENSION) \ | ||||
|       continue; \ | ||||
|     if(TENSOR1->size[TH_TENSOR_DIM_APPLY_i] != TENSOR2->size[TH_TENSOR_DIM_APPLY_i]) \ | ||||
|       THError("inconsistent tensor sizes"); \ | ||||
|   } \ | ||||
| \ | ||||
|   TH_TENSOR_DIM_APPLY_counter = (long*)THAlloc(sizeof(long)*(TENSOR1->nDimension)); \ | ||||
|   for(TH_TENSOR_DIM_APPLY_i = 0; TH_TENSOR_DIM_APPLY_i < TENSOR1->nDimension; TH_TENSOR_DIM_APPLY_i++) \ | ||||
|     TH_TENSOR_DIM_APPLY_counter[TH_TENSOR_DIM_APPLY_i] = 0; \ | ||||
| \ | ||||
|   TENSOR1##_data = (TENSOR1)->storage->data+(TENSOR1)->storageOffset; \ | ||||
|   TENSOR1##_stride = (TENSOR1)->stride[DIMENSION]; \ | ||||
|   TENSOR1##_size = TENSOR1->size[DIMENSION]; \ | ||||
| \ | ||||
|   TENSOR2##_data = (TENSOR2)->storage->data+(TENSOR2)->storageOffset; \ | ||||
|   TENSOR2##_stride = (TENSOR2)->stride[DIMENSION]; \ | ||||
|   TENSOR2##_size = TENSOR2->size[DIMENSION]; \ | ||||
| \ | ||||
|   while(!TH_TENSOR_DIM_APPLY_hasFinished) \ | ||||
|   { \ | ||||
|     CODE \ | ||||
| \ | ||||
|     if(TENSOR1->nDimension == 1) \ | ||||
|        break; \ | ||||
|  \ | ||||
|     for(TH_TENSOR_DIM_APPLY_i = 0; TH_TENSOR_DIM_APPLY_i < TENSOR1->nDimension; TH_TENSOR_DIM_APPLY_i++) \ | ||||
|     { \ | ||||
|       if(TH_TENSOR_DIM_APPLY_i == DIMENSION) \ | ||||
|       { \ | ||||
|         if(TH_TENSOR_DIM_APPLY_i == TENSOR1->nDimension-1) \ | ||||
|         { \ | ||||
|           TH_TENSOR_DIM_APPLY_hasFinished = 1; \ | ||||
|           break; \ | ||||
|         } \ | ||||
|         continue; \ | ||||
|       } \ | ||||
| \ | ||||
|       TH_TENSOR_DIM_APPLY_counter[TH_TENSOR_DIM_APPLY_i]++; \ | ||||
|       TENSOR1##_data += TENSOR1->stride[TH_TENSOR_DIM_APPLY_i]; \ | ||||
|       TENSOR2##_data += TENSOR2->stride[TH_TENSOR_DIM_APPLY_i]; \ | ||||
| \ | ||||
|       if(TH_TENSOR_DIM_APPLY_counter[TH_TENSOR_DIM_APPLY_i] == TENSOR1->size[TH_TENSOR_DIM_APPLY_i]) \ | ||||
|       { \ | ||||
|         if(TH_TENSOR_DIM_APPLY_i == TENSOR1->nDimension-1) \ | ||||
|         { \ | ||||
|           TH_TENSOR_DIM_APPLY_hasFinished = 1; \ | ||||
|           break; \ | ||||
|         } \ | ||||
|         else \ | ||||
|         { \ | ||||
|           TENSOR1##_data -= TH_TENSOR_DIM_APPLY_counter[TH_TENSOR_DIM_APPLY_i]*TENSOR1->stride[TH_TENSOR_DIM_APPLY_i]; \ | ||||
|           TENSOR2##_data -= TH_TENSOR_DIM_APPLY_counter[TH_TENSOR_DIM_APPLY_i]*TENSOR2->stride[TH_TENSOR_DIM_APPLY_i]; \ | ||||
|           TH_TENSOR_DIM_APPLY_counter[TH_TENSOR_DIM_APPLY_i] = 0; \ | ||||
|         } \ | ||||
|       } \ | ||||
|       else \ | ||||
|         break; \ | ||||
|     } \ | ||||
|   } \ | ||||
|   THFree(TH_TENSOR_DIM_APPLY_counter); \ | ||||
| } | ||||
|  | ||||
| #define TH_TENSOR_DIM_APPLY(TYPE, TENSOR, DIMENSION, CODE) \ | ||||
| { \ | ||||
|   TYPE *TENSOR##_data = NULL; \ | ||||
|   long TENSOR##_stride = 0, TENSOR##_size = 0; \ | ||||
|   long *TH_TENSOR_DIM_APPLY_counter = NULL; \ | ||||
|   int TH_TENSOR_DIM_APPLY_hasFinished = 0; \ | ||||
|   int TH_TENSOR_DIM_APPLY_i; \ | ||||
| \ | ||||
|   if( (DIMENSION < 0) || (DIMENSION >= TENSOR->nDimension) ) \ | ||||
|     THError("invalid dimension"); \ | ||||
| \ | ||||
|   TENSOR##_data = (TENSOR)->storage->data+(TENSOR)->storageOffset; \ | ||||
|   TENSOR##_stride = (TENSOR)->stride[DIMENSION]; \ | ||||
|   TENSOR##_size = TENSOR->size[DIMENSION]; \ | ||||
|   TH_TENSOR_DIM_APPLY_counter = (long*)THAlloc(sizeof(long)*(TENSOR->nDimension)); \ | ||||
|   for(TH_TENSOR_DIM_APPLY_i = 0; TH_TENSOR_DIM_APPLY_i < TENSOR->nDimension; TH_TENSOR_DIM_APPLY_i++) \ | ||||
|     TH_TENSOR_DIM_APPLY_counter[TH_TENSOR_DIM_APPLY_i] = 0; \ | ||||
| \ | ||||
|   while(!TH_TENSOR_DIM_APPLY_hasFinished) \ | ||||
|   { \ | ||||
|     CODE \ | ||||
| \ | ||||
|     if(TENSOR->nDimension == 1) \ | ||||
|        break; \ | ||||
|  \ | ||||
|     for(TH_TENSOR_DIM_APPLY_i = 0; TH_TENSOR_DIM_APPLY_i < TENSOR->nDimension; TH_TENSOR_DIM_APPLY_i++) \ | ||||
|     { \ | ||||
|       if(TH_TENSOR_DIM_APPLY_i == DIMENSION) \ | ||||
|       { \ | ||||
|         if(TH_TENSOR_DIM_APPLY_i == TENSOR->nDimension-1) \ | ||||
|         { \ | ||||
|           TH_TENSOR_DIM_APPLY_hasFinished = 1; \ | ||||
|           break; \ | ||||
|         } \ | ||||
|         continue; \ | ||||
|       } \ | ||||
| \ | ||||
|       TH_TENSOR_DIM_APPLY_counter[TH_TENSOR_DIM_APPLY_i]++; \ | ||||
|       TENSOR##_data += TENSOR->stride[TH_TENSOR_DIM_APPLY_i]; \ | ||||
| \ | ||||
|       if(TH_TENSOR_DIM_APPLY_counter[TH_TENSOR_DIM_APPLY_i] == TENSOR->size[TH_TENSOR_DIM_APPLY_i]) \ | ||||
|       { \ | ||||
|         if(TH_TENSOR_DIM_APPLY_i == TENSOR->nDimension-1) \ | ||||
|         { \ | ||||
|           TH_TENSOR_DIM_APPLY_hasFinished = 1; \ | ||||
|           break; \ | ||||
|         } \ | ||||
|         else \ | ||||
|         { \ | ||||
|           TENSOR##_data -= TH_TENSOR_DIM_APPLY_counter[TH_TENSOR_DIM_APPLY_i]*TENSOR->stride[TH_TENSOR_DIM_APPLY_i]; \ | ||||
|           TH_TENSOR_DIM_APPLY_counter[TH_TENSOR_DIM_APPLY_i] = 0; \ | ||||
|         } \ | ||||
|       } \ | ||||
|       else \ | ||||
|         break; \ | ||||
|     } \ | ||||
|   } \ | ||||
|   THFree(TH_TENSOR_DIM_APPLY_counter); \ | ||||
| } | ||||
|  | ||||
| #endif | ||||
							
								
								
									
										30
									
								
								lib/TH/THTensorMacros.h
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										30
									
								
								lib/TH/THTensorMacros.h
									
									
									
									
									
										Normal file
									
								
							| @ -0,0 +1,30 @@ | ||||
| #ifndef TH_TENSOR_MACROS_INC | ||||
| #define TH_TENSOR_MACROS_INC | ||||
|  | ||||
| /* fast method to access to tensor data */ | ||||
|  | ||||
| #define THTensor_fastGet1d(self, x0)                                    \ | ||||
|   (((self)->storage->data+(self)->storageOffset)[(x0)*(self)->stride[0]]) | ||||
|  | ||||
| #define THTensor_fastGet2d(self, x0, x1)                                \ | ||||
|   (((self)->storage->data+(self)->storageOffset)[(x0)*(self)->stride[0]+(x1)*(self)->stride[1]]) | ||||
|  | ||||
| #define THTensor_fastGet3d(self, x0, x1, x2)                            \ | ||||
|   (((self)->storage->data+(self)->storageOffset)[(x0)*(self)->stride[0]+(x1)*(self)->stride[1]+(x2)*(self)->stride[2]]) | ||||
|  | ||||
| #define THTensor_fastGet4d(self, x0, x1, x2, x3)                        \ | ||||
|   (((self)->storage->data+(self)->storageOffset)[(x0)*(self)->stride[0]+(x1)*(self)->stride[1]+(x2)*(self)->stride[2]+(x3)*(self)->stride[3]]) | ||||
|  | ||||
| #define THTensor_fastSet1d(self, x0, value)                             \ | ||||
|   (((self)->storage->data+(self)->storageOffset)[(x0)*(self)->stride[0]] = value) | ||||
|  | ||||
| #define THTensor_fastSet2d(self, x0, x1, value)                         \ | ||||
|   (((self)->storage->data+(self)->storageOffset)[(x0)*(self)->stride[0]+(x1)*(self)->stride[1]] = value) | ||||
|  | ||||
| #define THTensor_fastSet3d(self, x0, x1, x2, value)                     \ | ||||
|   (((self)->storage->data+(self)->storageOffset)[(x0)*(self)->stride[0]+(x1)*(self)->stride[1]+(x2)*(self)->stride[2]] = value) | ||||
|  | ||||
| #define THTensor_fastSet4d(self, x0, x1, x2, x3, value)                 \ | ||||
|   (((self)->storage->data+(self)->storageOffset)[(x0)*(self)->stride[0]+(x1)*(self)->stride[1]+(x2)*(self)->stride[2]+(x3)*(self)->stride[3]] = value) | ||||
|  | ||||
| #endif | ||||
							
								
								
									
										240
									
								
								lib/TH/THVector.h
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										240
									
								
								lib/TH/THVector.h
									
									
									
									
									
										Normal file
									
								
							| @ -0,0 +1,240 @@ | ||||
| #ifndef TH_VECTOR_INC | ||||
| #define TH_VECTOR_INC | ||||
|  | ||||
| #include "THGeneral.h" | ||||
|  | ||||
| #define THVector_(NAME) TH_CONCAT_4(TH,Real,Vector_,NAME) | ||||
|  | ||||
| #if defined __SSE2__ || defined __SSE3__ || defined __SSSE3__ \ | ||||
|   || defined __SSE4_1__ || defined __SSE4_2__ | ||||
|  | ||||
| #ifdef __SSE2__ | ||||
| #include <emmintrin.h> | ||||
| #endif | ||||
|   | ||||
| #ifdef __SSE3__ | ||||
| #include <pmmintrin.h> | ||||
| #endif | ||||
|   | ||||
| #ifdef __SSSE3__ | ||||
| #include <tmmintrin.h> | ||||
| #endif | ||||
|   | ||||
| #if defined (__SSE4_2__) || defined (__SSE4_1__) | ||||
| #include <smmintrin.h> | ||||
| #endif | ||||
|  | ||||
| #define THDoubleVector_fill(x, c, n) {          \ | ||||
|     long i;                                     \ | ||||
|     __m128d XMM0 = _mm_set1_pd(c);              \ | ||||
|     for (i=0; i<=((n)-8); i+=8) {               \ | ||||
|       _mm_storeu_pd((x)+i  , XMM0);             \ | ||||
|       _mm_storeu_pd((x)+i+2, XMM0);             \ | ||||
|       _mm_storeu_pd((x)+i+4, XMM0);             \ | ||||
|       _mm_storeu_pd((x)+i+6, XMM0);             \ | ||||
|     }                                           \ | ||||
|     long off = (n) - ((n)%8);                   \ | ||||
|     for (i=0; i<((n)%8); i++) {                 \ | ||||
|       x[off+i] = c;                             \ | ||||
|     }                                           \ | ||||
|   } | ||||
|  | ||||
|  | ||||
| #define THDoubleVector_add(y, x, c, n) {        \ | ||||
|     long i = 0;                                 \ | ||||
|     __m128d XMM7 = _mm_set1_pd(c);              \ | ||||
|     __m128d XMM0,XMM2;                          \ | ||||
|     for (; i<=((n)-2); i+=2) {                  \ | ||||
|       XMM0 = _mm_loadu_pd((x)+i);               \ | ||||
|       XMM2 = _mm_loadu_pd((y)+i);               \ | ||||
|       XMM0 = _mm_mul_pd(XMM0, XMM7);            \ | ||||
|       XMM2 = _mm_add_pd(XMM2, XMM0);            \ | ||||
|       _mm_storeu_pd((y)+i  , XMM2);             \ | ||||
|     }                                           \ | ||||
|     for (; i<(n); i++) {                        \ | ||||
|       y[i] += c * x[i];                         \ | ||||
|     }                                           \ | ||||
|   } | ||||
|  | ||||
| #define THDoubleVector_diff(z, x, y, n) {       \ | ||||
|     long i;                                     \ | ||||
|     for (i=0; i<=((n)-8); i+=8) {               \ | ||||
|       __m128d XMM0 = _mm_loadu_pd((x)+i  );     \ | ||||
|       __m128d XMM1 = _mm_loadu_pd((x)+i+2);     \ | ||||
|       __m128d XMM2 = _mm_loadu_pd((x)+i+4);     \ | ||||
|       __m128d XMM3 = _mm_loadu_pd((x)+i+6);     \ | ||||
|       __m128d XMM4 = _mm_loadu_pd((y)+i  );     \ | ||||
|       __m128d XMM5 = _mm_loadu_pd((y)+i+2);     \ | ||||
|       __m128d XMM6 = _mm_loadu_pd((y)+i+4);     \ | ||||
|       __m128d XMM7 = _mm_loadu_pd((y)+i+6);     \ | ||||
|       XMM0 = _mm_sub_pd(XMM0, XMM4);            \ | ||||
|       XMM1 = _mm_sub_pd(XMM1, XMM5);            \ | ||||
|       XMM2 = _mm_sub_pd(XMM2, XMM6);            \ | ||||
|       XMM3 = _mm_sub_pd(XMM3, XMM7);            \ | ||||
|       _mm_storeu_pd((z)+i  , XMM0);             \ | ||||
|       _mm_storeu_pd((z)+i+2, XMM1);             \ | ||||
|       _mm_storeu_pd((z)+i+4, XMM2);             \ | ||||
|       _mm_storeu_pd((z)+i+6, XMM3);             \ | ||||
|     }                                           \ | ||||
|     long off = (n) - ((n)%8);                   \ | ||||
|     for (i=0; i<((n)%8); i++) {                 \ | ||||
|       z[off+i] = x[off+i] - y[off+i];           \ | ||||
|     }                                           \ | ||||
|   } | ||||
|  | ||||
| #define THDoubleVector_scale(y, c, n) {         \ | ||||
|     long i;                                     \ | ||||
|     __m128d XMM7 = _mm_set1_pd(c);              \ | ||||
|     for (i=0; i<=((n)-4); i+=4) {               \ | ||||
|       __m128d XMM0 = _mm_loadu_pd((y)+i  );     \ | ||||
|       __m128d XMM1 = _mm_loadu_pd((y)+i+2);     \ | ||||
|       XMM0 = _mm_mul_pd(XMM0, XMM7);            \ | ||||
|       XMM1 = _mm_mul_pd(XMM1, XMM7);            \ | ||||
|       _mm_storeu_pd((y)+i  , XMM0);             \ | ||||
|       _mm_storeu_pd((y)+i+2, XMM1);             \ | ||||
|     }                                           \ | ||||
|     long off = (n) - ((n)%4);                   \ | ||||
|     for (i=0; i<((n)%4); i++) {                 \ | ||||
|       y[off+i] *= c;                            \ | ||||
|     }                                           \ | ||||
|   } | ||||
|  | ||||
| #define THDoubleVector_mul(y, x, n) {           \ | ||||
|     long i;                                     \ | ||||
|     for (i=0; i<=((n)-8); i+=8) {               \ | ||||
|       __m128d XMM0 = _mm_loadu_pd((x)+i  );     \ | ||||
|       __m128d XMM1 = _mm_loadu_pd((x)+i+2);     \ | ||||
|       __m128d XMM2 = _mm_loadu_pd((x)+i+4);     \ | ||||
|       __m128d XMM3 = _mm_loadu_pd((x)+i+6);     \ | ||||
|       __m128d XMM4 = _mm_loadu_pd((y)+i  );     \ | ||||
|       __m128d XMM5 = _mm_loadu_pd((y)+i+2);     \ | ||||
|       __m128d XMM6 = _mm_loadu_pd((y)+i+4);     \ | ||||
|       __m128d XMM7 = _mm_loadu_pd((y)+i+6);     \ | ||||
|       XMM4 = _mm_mul_pd(XMM4, XMM0);            \ | ||||
|       XMM5 = _mm_mul_pd(XMM5, XMM1);            \ | ||||
|       XMM6 = _mm_mul_pd(XMM6, XMM2);            \ | ||||
|       XMM7 = _mm_mul_pd(XMM7, XMM3);            \ | ||||
|       _mm_storeu_pd((y)+i  , XMM4);             \ | ||||
|       _mm_storeu_pd((y)+i+2, XMM5);             \ | ||||
|       _mm_storeu_pd((y)+i+4, XMM6);             \ | ||||
|       _mm_storeu_pd((y)+i+6, XMM7);             \ | ||||
|     }                                           \ | ||||
|     long off = (n) - ((n)%8);                   \ | ||||
|     for (i=0; i<((n)%8); i++) {                 \ | ||||
|       y[off+i] *= x[off+i];                     \ | ||||
|     }                                           \ | ||||
|   } | ||||
|  | ||||
| #define THFloatVector_fill(x, c, n) {           \ | ||||
|     long i;                                     \ | ||||
|     __m128 XMM0 = _mm_set_ps1(c);               \ | ||||
|     for (i=0; i<=((n)-16); i+=16) {             \ | ||||
|       _mm_storeu_ps((x)+i  ,  XMM0);            \ | ||||
|       _mm_storeu_ps((x)+i+4,  XMM0);            \ | ||||
|       _mm_storeu_ps((x)+i+8,  XMM0);            \ | ||||
|       _mm_storeu_ps((x)+i+12, XMM0);            \ | ||||
|     }                                           \ | ||||
|     long off = (n) - ((n)%16);                  \ | ||||
|     for (i=0; i<((n)%16); i++) {                \ | ||||
|       x[off+i] = c;                             \ | ||||
|     }                                           \ | ||||
|   } | ||||
|  | ||||
| #define THFloatVector_add(y, x, c, n) {         \ | ||||
|     long i = 0;                                 \ | ||||
|     __m128 XMM7 = _mm_set_ps1(c);               \ | ||||
|     __m128 XMM0,XMM2;                           \ | ||||
|     for (; i<=((n)-4); i+=4) {                  \ | ||||
|       XMM0 = _mm_loadu_ps((x)+i);               \ | ||||
|       XMM2 = _mm_loadu_ps((y)+i);               \ | ||||
|       XMM0 = _mm_mul_ps(XMM0, XMM7);            \ | ||||
|       XMM2 = _mm_add_ps(XMM2, XMM0);            \ | ||||
|       _mm_storeu_ps((y)+i  , XMM2);             \ | ||||
|     }                                           \ | ||||
|     for (; i<(n); i++) {                        \ | ||||
|       y[i] += c * x[i];                         \ | ||||
|     }                                           \ | ||||
|   } | ||||
|  | ||||
| #define THFloatVector_diff(z, x, y, n) {        \ | ||||
|     long i;                                     \ | ||||
|     for (i=0; i<=((n)-16); i+=16) {             \ | ||||
|       __m128 XMM0 = _mm_loadu_ps((x)+i   );     \ | ||||
|       __m128 XMM1 = _mm_loadu_ps((x)+i+ 4);     \ | ||||
|       __m128 XMM2 = _mm_loadu_ps((x)+i+ 8);     \ | ||||
|       __m128 XMM3 = _mm_loadu_ps((x)+i+12);     \ | ||||
|       __m128 XMM4 = _mm_loadu_ps((y)+i   );     \ | ||||
|       __m128 XMM5 = _mm_loadu_ps((y)+i+ 4);     \ | ||||
|       __m128 XMM6 = _mm_loadu_ps((y)+i+ 8);     \ | ||||
|       __m128 XMM7 = _mm_loadu_ps((y)+i+12);     \ | ||||
|       XMM0 = _mm_sub_ps(XMM0, XMM4);            \ | ||||
|       XMM1 = _mm_sub_ps(XMM1, XMM5);            \ | ||||
|       XMM2 = _mm_sub_ps(XMM2, XMM6);            \ | ||||
|       XMM3 = _mm_sub_ps(XMM3, XMM7);            \ | ||||
|       _mm_storeu_ps((z)+i   , XMM0);            \ | ||||
|       _mm_storeu_ps((z)+i+ 4, XMM1);            \ | ||||
|       _mm_storeu_ps((z)+i+ 8, XMM2);            \ | ||||
|       _mm_storeu_ps((z)+i+12, XMM3);            \ | ||||
|     }                                           \ | ||||
|     long off = (n) - ((n)%16);                  \ | ||||
|     for (i=0; i<((n)%16); i++) {                \ | ||||
|       z[off+i] = x[off+i] - y[off+i];           \ | ||||
|     }                                           \ | ||||
|   } | ||||
|  | ||||
| #define THFloatVector_scale(y, c, n) {          \ | ||||
|     long i;                                     \ | ||||
|     __m128 XMM7 = _mm_set_ps1(c);               \ | ||||
|     for (i=0; i<=((n)-8); i+=8) {               \ | ||||
|       __m128 XMM0 = _mm_loadu_ps((y)+i  );      \ | ||||
|       __m128 XMM1 = _mm_loadu_ps((y)+i+4);      \ | ||||
|       XMM0 = _mm_mul_ps(XMM0, XMM7);            \ | ||||
|       XMM1 = _mm_mul_ps(XMM1, XMM7);            \ | ||||
|       _mm_storeu_ps((y)+i  , XMM0);             \ | ||||
|       _mm_storeu_ps((y)+i+4, XMM1);             \ | ||||
|     }                                           \ | ||||
|     long off = (n) - ((n)%8);                   \ | ||||
|     for (i=0; i<((n)%8); i++) {                 \ | ||||
|       y[off+i] *= c;                            \ | ||||
|     }                                           \ | ||||
|   } | ||||
|  | ||||
| #define THFloatVector_mul(y, x, n) {            \ | ||||
|     long i;                                     \ | ||||
|     for (i=0; i<=((n)-16); i+=16) {             \ | ||||
|       __m128 XMM0 = _mm_loadu_ps((x)+i   );     \ | ||||
|       __m128 XMM1 = _mm_loadu_ps((x)+i+ 4);     \ | ||||
|       __m128 XMM2 = _mm_loadu_ps((x)+i+ 8);     \ | ||||
|       __m128 XMM3 = _mm_loadu_ps((x)+i+12);     \ | ||||
|       __m128 XMM4 = _mm_loadu_ps((y)+i   );     \ | ||||
|       __m128 XMM5 = _mm_loadu_ps((y)+i+ 4);     \ | ||||
|       __m128 XMM6 = _mm_loadu_ps((y)+i+ 8);     \ | ||||
|       __m128 XMM7 = _mm_loadu_ps((y)+i+12);     \ | ||||
|       XMM4 = _mm_mul_ps(XMM4, XMM0);            \ | ||||
|       XMM5 = _mm_mul_ps(XMM5, XMM1);            \ | ||||
|       XMM6 = _mm_mul_ps(XMM6, XMM2);            \ | ||||
|       XMM7 = _mm_mul_ps(XMM7, XMM3);            \ | ||||
|       _mm_storeu_ps((y)+i   , XMM4);            \ | ||||
|       _mm_storeu_ps((y)+i+ 4, XMM5);            \ | ||||
|       _mm_storeu_ps((y)+i+ 8, XMM6);            \ | ||||
|       _mm_storeu_ps((y)+i+12, XMM7);            \ | ||||
|     }                                           \ | ||||
|     long off = (n) - ((n)%16);                  \ | ||||
|     for (i=0; i<((n)%16); i++) {                \ | ||||
|       y[off+i] *= x[off+i];                     \ | ||||
|     }                                           \ | ||||
|   } | ||||
|  | ||||
| #else | ||||
|  | ||||
| /* If SSE2 not defined, then generate plain C operators */ | ||||
| #include "generic/THVector.c" | ||||
| #include "THGenerateFloatTypes.h" | ||||
|  | ||||
| #endif | ||||
|  | ||||
| /* For non-float types, generate plain C operators */ | ||||
| #include "generic/THVector.c" | ||||
| #include "THGenerateIntTypes.h" | ||||
|  | ||||
| #endif | ||||
							
								
								
									
										212
									
								
								lib/TH/cmake/FindBLAS.cmake
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										212
									
								
								lib/TH/cmake/FindBLAS.cmake
									
									
									
									
									
										Normal file
									
								
							| @ -0,0 +1,212 @@ | ||||
| # - Find BLAS library | ||||
| # This module finds an installed fortran library that implements the BLAS  | ||||
| # linear-algebra interface (see http://www.netlib.org/blas/).   | ||||
| # The list of libraries searched for is taken | ||||
| # from the autoconf macro file, acx_blas.m4 (distributed at | ||||
| # http://ac-archive.sourceforge.net/ac-archive/acx_blas.html). | ||||
| # | ||||
| # This module sets the following variables: | ||||
| #  BLAS_FOUND - set to true if a library implementing the BLAS interface is found. | ||||
| #  BLAS_INFO - name of the detected BLAS library. | ||||
| #  BLAS_F2C - set to true if following the f2c return convention | ||||
| #  BLAS_LIBRARIES - list of libraries to link against to use BLAS | ||||
| #  BLAS_INCLUDE_DIR - include directory | ||||
|  | ||||
| SET(BLAS_LIBRARIES) | ||||
| SET(BLAS_INCLUDE_DIR) | ||||
| SET(BLAS_INFO) | ||||
| SET(BLAS_F2C) | ||||
|  | ||||
| # CBLAS in Intel mkl | ||||
| FIND_PACKAGE(MKL) | ||||
| IF (MKL_FOUND AND NOT BLAS_LIBRARIES) | ||||
|   SET(BLAS_INFO imkl) | ||||
|   SET(BLAS_LIBRARIES ${MKL_LIBRARIES}) | ||||
|   SET(BLAS_INCLUDE_DIR ${MKL_INCLUDE_DIR}) | ||||
|   SET(BLAS_VERSION ${MKL_VERSION}) | ||||
| ENDIF (MKL_FOUND AND NOT BLAS_LIBRARIES) | ||||
|  | ||||
| # Old FindBlas | ||||
| INCLUDE(CheckCSourceRuns) | ||||
| INCLUDE(CheckFortranFunctionExists) | ||||
| SET(_verbose TRUE) | ||||
|  | ||||
| MACRO(Check_Fortran_Libraries LIBRARIES _prefix _name _flags _list) | ||||
|   # This macro checks for the existence of the combination of fortran libraries | ||||
|   # given by _list.  If the combination is found, this macro checks (using the  | ||||
|   # Check_Fortran_Function_Exists macro) whether can link against that library | ||||
|   # combination using the name of a routine given by _name using the linker | ||||
|   # flags given by _flags.  If the combination of libraries is found and passes | ||||
|   # the link test, LIBRARIES is set to the list of complete library paths that | ||||
|   # have been found.  Otherwise, LIBRARIES is set to NOTFOUND. | ||||
|   # N.B. _prefix is the prefix applied to the names of all cached variables that | ||||
|   # are generated internally and marked advanced by this macro. | ||||
|    | ||||
|   set(__list) | ||||
|   foreach(_elem ${_list}) | ||||
|     if(__list) | ||||
|       set(__list "${__list} - ${_elem}") | ||||
|     else(__list) | ||||
|       set(__list "${_elem}") | ||||
|     endif(__list) | ||||
|   endforeach(_elem) | ||||
|   if(_verbose) | ||||
|     message(STATUS "Checking for [${__list}]") | ||||
|   endif(_verbose) | ||||
|  | ||||
|   set(_libraries_work TRUE) | ||||
|   set(${LIBRARIES}) | ||||
|   set(_combined_name) | ||||
|   foreach(_library ${_list}) | ||||
|     set(_combined_name ${_combined_name}_${_library}) | ||||
|     if(_libraries_work) | ||||
|       if ( WIN32 ) | ||||
|         find_library(${_prefix}_${_library}_LIBRARY | ||||
|           NAMES ${_library} | ||||
|           PATHS ENV LIB  | ||||
|           PATHS ENV PATH ) | ||||
|       endif ( WIN32 ) | ||||
|       if ( APPLE )  | ||||
|         find_library(${_prefix}_${_library}_LIBRARY | ||||
|           NAMES ${_library} | ||||
|           PATHS /usr/local/lib /usr/lib /usr/local/lib64 /usr/lib64  | ||||
|           ENV DYLD_LIBRARY_PATH ) | ||||
|       else ( APPLE ) | ||||
|         find_library(${_prefix}_${_library}_LIBRARY | ||||
|           NAMES ${_library} | ||||
|           PATHS /usr/local/lib /usr/lib /usr/local/lib64 /usr/lib64  | ||||
|           ENV LD_LIBRARY_PATH ) | ||||
|       endif( APPLE ) | ||||
|       mark_as_advanced(${_prefix}_${_library}_LIBRARY) | ||||
|       set(${LIBRARIES} ${${LIBRARIES}} ${${_prefix}_${_library}_LIBRARY}) | ||||
|       set(_libraries_work ${${_prefix}_${_library}_LIBRARY}) | ||||
|     endif(_libraries_work) | ||||
|   endforeach(_library ${_list}) | ||||
|   if(_libraries_work) | ||||
|     # Test this combination of libraries. | ||||
|     set(CMAKE_REQUIRED_LIBRARIES ${_flags} ${${LIBRARIES}}) | ||||
|     if (CMAKE_Fortran_COMPILER_WORKS) | ||||
|       check_fortran_function_exists(${_name} ${_prefix}${_combined_name}_WORKS) | ||||
|     else (CMAKE_Fortran_COMPILER_WORKS) | ||||
|       check_function_exists("${_name}_" ${_prefix}${_combined_name}_WORKS) | ||||
|     endif (CMAKE_Fortran_COMPILER_WORKS) | ||||
|     set(CMAKE_REQUIRED_LIBRARIES) | ||||
|     mark_as_advanced(${_prefix}${_combined_name}_WORKS) | ||||
|     set(_libraries_work ${${_prefix}${_combined_name}_WORKS}) | ||||
|   endif(_libraries_work) | ||||
|   if(NOT _libraries_work) | ||||
|     set(${LIBRARIES} NOTFOUND) | ||||
|   endif(NOT _libraries_work) | ||||
| endmacro(Check_Fortran_Libraries) | ||||
|  | ||||
|  | ||||
| # Apple BLAS library? | ||||
| if(NOT BLAS_LIBRARIES) | ||||
|   check_fortran_libraries( | ||||
|   BLAS_LIBRARIES | ||||
|   BLAS | ||||
|   sgemm | ||||
|   "" | ||||
|   "Accelerate") | ||||
|   if (BLAS_LIBRARIES) | ||||
|     set(BLAS_INFO "accelerate") | ||||
|   endif (BLAS_LIBRARIES) | ||||
| endif(NOT BLAS_LIBRARIES) | ||||
| if ( NOT BLAS_LIBRARIES ) | ||||
|   check_fortran_libraries( | ||||
|     BLAS_LIBRARIES | ||||
|     BLAS | ||||
|     sgemm | ||||
|     "" | ||||
|     "vecLib") | ||||
|   if (BLAS_LIBRARIES) | ||||
|     set(BLAS_INFO "veclib") | ||||
|   endif (BLAS_LIBRARIES) | ||||
| endif ( NOT BLAS_LIBRARIES ) | ||||
|  | ||||
| # BLAS in ATLAS library? (http://math-atlas.sourceforge.net/) | ||||
| if(NOT BLAS_LIBRARIES) | ||||
|   check_fortran_libraries( | ||||
|   BLAS_LIBRARIES | ||||
|   BLAS | ||||
|   sgemm | ||||
|   "" | ||||
|   "cblas;f77blas;atlas") | ||||
|   if (BLAS_LIBRARIES) | ||||
|     set(BLAS_INFO "atlas") | ||||
|   endif (BLAS_LIBRARIES) | ||||
| endif(NOT BLAS_LIBRARIES) | ||||
|  | ||||
| # Generic BLAS library? | ||||
| if(NOT BLAS_LIBRARIES) | ||||
|   check_fortran_libraries( | ||||
|   BLAS_LIBRARIES | ||||
|   BLAS | ||||
|   sgemm | ||||
|   "" | ||||
|   "blas") | ||||
|   if (BLAS_LIBRARIES) | ||||
|     set(BLAS_INFO "generic") | ||||
|   endif (BLAS_LIBRARIES) | ||||
| endif(NOT BLAS_LIBRARIES) | ||||
|  | ||||
| # Determine if blas was compiled with the f2c conventions | ||||
| IF (BLAS_LIBRARIES) | ||||
|   SET(CMAKE_REQUIRED_LIBRARIES ${BLAS_LIBRARIES}) | ||||
|   CHECK_C_SOURCE_RUNS(" | ||||
| #include <stdio.h> | ||||
| float x[4] = { 1, 2, 3, 4 }; | ||||
| float y[4] = { .1, .01, .001, .0001 }; | ||||
| int four = 4; | ||||
| int one = 1; | ||||
| extern double sdot_(); | ||||
| int main() { | ||||
|   int i; | ||||
|   double r = sdot_(&four, x, &one, y, &one); | ||||
|   exit((float)r != (float).1234); | ||||
| }" BLAS_F2C_DOUBLE_WORKS ) | ||||
|   CHECK_C_SOURCE_RUNS(" | ||||
| #include <stdio.h> | ||||
| float x[4] = { 1, 2, 3, 4 }; | ||||
| float y[4] = { .1, .01, .001, .0001 }; | ||||
| int four = 4; | ||||
| int one = 1; | ||||
| extern float sdot_(); | ||||
| int main() { | ||||
|   int i; | ||||
|   double r = sdot_(&four, x, &one, y, &one); | ||||
|   exit((float)r != (float).1234); | ||||
| }" BLAS_F2C_FLOAT_WORKS ) | ||||
|   IF (BLAS_F2C_DOUBLE_WORKS AND NOT BLAS_F2C_FLOAT_WORKS) | ||||
|     IF (_verbose) | ||||
|       MESSAGE(STATUS "This BLAS uses the F2C return conventions") | ||||
|     ENDIF(_verbose) | ||||
|     SET(BLAS_F2C TRUE) | ||||
|   ELSE (BLAS_F2C_DOUBLE_WORKS AND NOT BLAS_F2C_FLOAT_WORKS) | ||||
|     SET(BLAS_F2C FALSE) | ||||
|   ENDIF (BLAS_F2C_DOUBLE_WORKS AND NOT BLAS_F2C_FLOAT_WORKS) | ||||
| ENDIF(BLAS_LIBRARIES) | ||||
|  | ||||
| # epilogue | ||||
|  | ||||
| if(BLAS_LIBRARIES) | ||||
|   set(BLAS_FOUND TRUE) | ||||
| else(BLAS_LIBRARIES) | ||||
|   set(BLAS_FOUND FALSE) | ||||
| endif(BLAS_LIBRARIES) | ||||
|  | ||||
| IF (NOT BLAS_FOUND AND BLAS_FIND_REQUIRED) | ||||
|   message(FATAL_ERROR "Cannot find a library with BLAS API. Please specify library location.") | ||||
| ENDIF (NOT BLAS_FOUND AND BLAS_FIND_REQUIRED) | ||||
| IF(NOT BLAS_FIND_QUIETLY) | ||||
|   IF(BLAS_FOUND) | ||||
|     MESSAGE(STATUS "Found a library with BLAS API (${BLAS_INFO}).") | ||||
|   ELSE(BLAS_FOUND) | ||||
|     MESSAGE(STATUS "Cannot find a library with BLAS API. Not using BLAS.") | ||||
|   ENDIF(BLAS_FOUND) | ||||
| ENDIF(NOT BLAS_FIND_QUIETLY) | ||||
|  | ||||
|  | ||||
|  | ||||
|  | ||||
|  | ||||
							
								
								
									
										166
									
								
								lib/TH/cmake/FindLAPACK.cmake
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										166
									
								
								lib/TH/cmake/FindLAPACK.cmake
									
									
									
									
									
										Normal file
									
								
							| @ -0,0 +1,166 @@ | ||||
| # - Find LAPACK library | ||||
| # This module finds an installed fortran library that implements the LAPACK | ||||
| # linear-algebra interface (see http://www.netlib.org/lapack/). | ||||
| # | ||||
| # The approach follows that taken for the autoconf macro file, acx_lapack.m4 | ||||
| # (distributed at http://ac-archive.sourceforge.net/ac-archive/acx_lapack.html). | ||||
| # | ||||
| # This module sets the following variables: | ||||
| #  LAPACK_FOUND - set to true if a library implementing the LAPACK interface is found | ||||
| #  LAPACK_LIBRARIES - list of libraries (using full path name) for LAPACK | ||||
|  | ||||
| SET(LAPACK_LIBRARIES) | ||||
| SET(LAPACK_INFO) | ||||
|  | ||||
| IF(LAPACK_FIND_QUIETLY OR NOT LAPACK_FIND_REQUIRED) | ||||
|   FIND_PACKAGE(BLAS) | ||||
| ELSE(LAPACK_FIND_QUIETLY OR NOT LAPACK_FIND_REQUIRED) | ||||
|   FIND_PACKAGE(BLAS REQUIRED) | ||||
| ENDIF(LAPACK_FIND_QUIETLY OR NOT LAPACK_FIND_REQUIRED) | ||||
|  | ||||
| # LAPACK in Intel mkl | ||||
| IF (MKL_FOUND AND NOT LAPACK_LIBRARIES) | ||||
|   SET(LAPACK_LIBRARIES ${MKL_LAPACK_LIBRARIES} ${MKL_LIBRARIES}) | ||||
|   SET(LAPACK_INCLUDE_DIR ${MKL_INCLUDE_DIR}) | ||||
|   SET(LAPACK_INFO "mkl") | ||||
| ENDIF (MKL_FOUND AND NOT LAPACK_LIBRARIES) | ||||
|  | ||||
| # Old search lapack script | ||||
| include(CheckFortranFunctionExists) | ||||
|  | ||||
| macro(Check_Lapack_Libraries LIBRARIES _prefix _name _flags _list _blas) | ||||
|   # This macro checks for the existence of the combination of fortran libraries | ||||
|   # given by _list.  If the combination is found, this macro checks (using the  | ||||
|   # Check_Fortran_Function_Exists macro) whether can link against that library | ||||
|   # combination using the name of a routine given by _name using the linker | ||||
|   # flags given by _flags.  If the combination of libraries is found and passes | ||||
|   # the link test, LIBRARIES is set to the list of complete library paths that | ||||
|   # have been found.  Otherwise, LIBRARIES is set to FALSE. | ||||
|   # N.B. _prefix is the prefix applied to the names of all cached variables that | ||||
|   # are generated internally and marked advanced by this macro. | ||||
|   set(_libraries_work TRUE) | ||||
|   set(${LIBRARIES}) | ||||
|   set(_combined_name) | ||||
|   foreach(_library ${_list}) | ||||
|     set(_combined_name ${_combined_name}_${_library}) | ||||
|     if(_libraries_work) | ||||
|       if (WIN32) | ||||
|         find_library(${_prefix}_${_library}_LIBRARY | ||||
|           NAMES ${_library} PATHS ENV LIB PATHS ENV PATH) | ||||
|       else (WIN32) | ||||
|         if(APPLE) | ||||
|           find_library(${_prefix}_${_library}_LIBRARY | ||||
|             NAMES ${_library} | ||||
|             PATHS /usr/local/lib /usr/lib /usr/local/lib64 /usr/lib64  | ||||
|             ENV DYLD_LIBRARY_PATH) | ||||
|         else(APPLE) | ||||
|           find_library(${_prefix}_${_library}_LIBRARY | ||||
|             NAMES ${_library} | ||||
|             PATHS /usr/local/lib /usr/lib /usr/local/lib64 /usr/lib64  | ||||
|             ENV LD_LIBRARY_PATH) | ||||
|         endif(APPLE) | ||||
|       endif(WIN32) | ||||
|       mark_as_advanced(${_prefix}_${_library}_LIBRARY) | ||||
|       set(${LIBRARIES} ${${LIBRARIES}} ${${_prefix}_${_library}_LIBRARY}) | ||||
|       set(_libraries_work ${${_prefix}_${_library}_LIBRARY}) | ||||
|     endif(_libraries_work) | ||||
|   endforeach(_library ${_list}) | ||||
|   if(_libraries_work) | ||||
|     # Test this combination of libraries. | ||||
|     set(CMAKE_REQUIRED_LIBRARIES ${_flags} ${${LIBRARIES}} ${_blas}) | ||||
|     if (CMAKE_Fortran_COMPILER_WORKS) | ||||
|       check_fortran_function_exists(${_name} ${_prefix}${_combined_name}_WORKS) | ||||
|     else (CMAKE_Fortran_COMPILER_WORKS) | ||||
|       check_function_exists("${_name}_" ${_prefix}${_combined_name}_WORKS) | ||||
|     endif (CMAKE_Fortran_COMPILER_WORKS) | ||||
|     set(CMAKE_REQUIRED_LIBRARIES) | ||||
|     mark_as_advanced(${_prefix}${_combined_name}_WORKS) | ||||
|     set(_libraries_work ${${_prefix}${_combined_name}_WORKS}) | ||||
|   endif(_libraries_work) | ||||
|   if(NOT _libraries_work) | ||||
|     set(${LIBRARIES} FALSE) | ||||
|   endif(NOT _libraries_work) | ||||
| endmacro(Check_Lapack_Libraries) | ||||
|  | ||||
|  | ||||
| if(BLAS_FOUND) | ||||
|    | ||||
|   #acml lapack | ||||
|   if(NOT LAPACK_LIBRARIES) | ||||
|     check_lapack_libraries( | ||||
|       LAPACK_LIBRARIES | ||||
|       LAPACK | ||||
|       cheev | ||||
|       "" | ||||
|       "acml" | ||||
|       "${BLAS_LIBRARIES}" | ||||
|       ) | ||||
|     if(LAPACK_LIBRARIES) | ||||
|       SET(LAPACK_INFO "acml") | ||||
|     endif(LAPACK_LIBRARIES) | ||||
|   endif(NOT LAPACK_LIBRARIES) | ||||
|  | ||||
|   # Apple LAPACK library? | ||||
|   if(NOT LAPACK_LIBRARIES) | ||||
|     check_lapack_libraries( | ||||
|       LAPACK_LIBRARIES | ||||
|       LAPACK | ||||
|       cheev | ||||
|       "" | ||||
|       "Accelerate" | ||||
|       "${BLAS_LIBRARIES}" | ||||
|       ) | ||||
|     if(LAPACK_LIBRARIES) | ||||
|       SET(LAPACK_INFO "Accelerate") | ||||
|     endif(LAPACK_LIBRARIES) | ||||
|   endif(NOT LAPACK_LIBRARIES) | ||||
|  | ||||
|   if ( NOT LAPACK_LIBRARIES ) | ||||
|     check_lapack_libraries( | ||||
|       LAPACK_LIBRARIES | ||||
|       LAPACK | ||||
|       cheev | ||||
|       "" | ||||
|       "vecLib" | ||||
|       "${BLAS_LIBRARIES}" | ||||
|       ) | ||||
|     if(LAPACK_LIBRARIES) | ||||
|       SET(LAPACK_INFO "veclib") | ||||
|     endif(LAPACK_LIBRARIES) | ||||
|   endif ( NOT LAPACK_LIBRARIES ) | ||||
|  | ||||
|   # Generic LAPACK library? | ||||
|   if ( NOT LAPACK_LIBRARIES ) | ||||
|     check_lapack_libraries( | ||||
|       LAPACK_LIBRARIES | ||||
|       LAPACK | ||||
|       cheev | ||||
|       "" | ||||
|       "lapack" | ||||
|       "${BLAS_LIBRARIES}" | ||||
|       ) | ||||
|     if(LAPACK_LIBRARIES) | ||||
|       SET(LAPACK_INFO "generic") | ||||
|     endif(LAPACK_LIBRARIES) | ||||
|   endif ( NOT LAPACK_LIBRARIES ) | ||||
|  | ||||
| else(BLAS_FOUND) | ||||
|   message(STATUS "LAPACK requires BLAS") | ||||
| endif(BLAS_FOUND) | ||||
|  | ||||
| if(LAPACK_LIBRARIES) | ||||
|   set(LAPACK_FOUND TRUE) | ||||
| else(LAPACK_LIBRARIES) | ||||
|   set(LAPACK_FOUND FALSE) | ||||
| endif(LAPACK_LIBRARIES) | ||||
|  | ||||
| IF (NOT LAPACK_FOUND AND LAPACK_FIND_REQUIRED) | ||||
|   message(FATAL_ERROR "Cannot find a library with LAPACK API. Please specify library location.") | ||||
| ENDIF (NOT LAPACK_FOUND AND LAPACK_FIND_REQUIRED) | ||||
| IF(NOT LAPACK_FIND_QUIETLY) | ||||
|   IF(LAPACK_FOUND) | ||||
|     MESSAGE(STATUS "Found a library with LAPACK API. (${LAPACK_INFO})") | ||||
|   ELSE(LAPACK_FOUND) | ||||
|     MESSAGE(STATUS "Cannot find a library with LAPACK API. Not using LAPACK.") | ||||
|   ENDIF(LAPACK_FOUND) | ||||
| ENDIF(NOT LAPACK_FIND_QUIETLY) | ||||
							
								
								
									
										274
									
								
								lib/TH/cmake/FindMKL.cmake
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										274
									
								
								lib/TH/cmake/FindMKL.cmake
									
									
									
									
									
										Normal file
									
								
							| @ -0,0 +1,274 @@ | ||||
| # - Find INTEL MKL library | ||||
| # | ||||
| # This module finds the Intel Mkl libraries. | ||||
| # | ||||
| # This module sets the following variables: | ||||
| #  MKL_FOUND - set to true if a library implementing the CBLAS interface is found | ||||
| #  MKL_VERSION - best guess | ||||
| #  MKL_INCLUDE_DIR - path to include dir. | ||||
| #  MKL_LIBRARIES - list of libraries for base mkl | ||||
| #  MKL_LAPACK_LIBRARIES - list of libraries to add for lapack | ||||
| #  MKL_SCALAPACK_LIBRARIES - list of libraries to add for scalapack | ||||
| #  MKL_SOLVER_LIBRARIES - list of libraries to add for the solvers | ||||
| #  MKL_CDFT_LIBRARIES - list of libraries to add for the solvers | ||||
|  | ||||
|  | ||||
| # Do nothing if MKL_FOUND was set before! | ||||
| IF (NOT MKL_FOUND) | ||||
|  | ||||
| SET(MKL_VERSION) | ||||
| SET(MKL_INCLUDE_DIR) | ||||
| SET(MKL_LIBRARIES) | ||||
| SET(MKL_LAPACK_LIBRARIES) | ||||
| SET(MKL_SCALAPACK_LIBRARIES) | ||||
| SET(MKL_SOLVER_LIBRARIES) | ||||
| SET(MKL_CDFT_LIBRARIES) | ||||
|  | ||||
| # Includes | ||||
| INCLUDE(CheckTypeSize) | ||||
| INCLUDE(CheckFunctionExists) | ||||
|  | ||||
| # Prints diagnostic | ||||
| # SET(_verbose TRUE) | ||||
|  | ||||
| # Intel Compiler Suite | ||||
| SET(INTEL_COMPILER_DIR CACHE STRING | ||||
|   "Root directory of the Intel Compiler Suite (contains ipp, mkl, etc.)") | ||||
| SET(INTEL_MKL_DIR CACHE STRING | ||||
|   "Root directory of the Intel MKL (standalone)") | ||||
| SET(INTEL_MKL_SEQUENTIAL OFF CACHE BOOL | ||||
|   "Force using the sequential (non threaded) libraries") | ||||
|  | ||||
| # Checks | ||||
| CHECK_TYPE_SIZE("void*" SIZE_OF_VOIDP) | ||||
| IF ("${SIZE_OF_VOIDP}" EQUAL 8) | ||||
|   SET(mklvers "em64t") | ||||
|   SET(iccvers "intel64") | ||||
|   SET(mkl64s "_lp64") | ||||
| ELSE ("${SIZE_OF_VOIDP}" EQUAL 8) | ||||
|   SET(mklvers "32") | ||||
|   SET(iccvers "ia32") | ||||
|   SET(mkl64s) | ||||
| ENDIF ("${SIZE_OF_VOIDP}" EQUAL 8) | ||||
| IF (CMAKE_COMPILER_IS_GNUCC) | ||||
|   SET(mklthreads "mkl_gnu_thread" "mkl_intel_thread") | ||||
|   SET(mklifaces  "gf" "intel") | ||||
| ELSE (CMAKE_COMPILER_IS_GNUCC) | ||||
|   SET(mklthreads "mkl_intel_thread") | ||||
|   SET(mklifaces  "intel") | ||||
| ENDIF (CMAKE_COMPILER_IS_GNUCC) | ||||
| SET(mklrtls "iomp5" "guide") | ||||
|  | ||||
| # Kernel libraries dynamically loaded | ||||
| SET(mklkerlibs "mc" "mc3" "nc" "p4n" "p4m" "p4m3" "p4p" "def") | ||||
| SET(mklseq) | ||||
|  | ||||
|  | ||||
|  | ||||
| # Paths | ||||
| SET(saved_CMAKE_LIBRARY_PATH ${CMAKE_LIBRARY_PATH}) | ||||
| SET(saved_CMAKE_INCLUDE_PATH ${CMAKE_INCLUDE_PATH}) | ||||
| IF (INTEL_COMPILER_DIR) | ||||
|   # TODO: diagnostic if dir does not exist | ||||
|   SET(CMAKE_LIBRARY_PATH ${CMAKE_LIBRARY_PATH} | ||||
|     "${INTEL_COMPILER_DIR}/lib/${iccvers}") | ||||
|   IF (NOT INTEL_MKL_DIR) | ||||
|     SET(INTEL_MKL_DIR "${INTEL_COMPILER_DIR}/mkl") | ||||
|   ENDIF (NOT INTEL_MKL_DIR) | ||||
| ENDIF (INTEL_COMPILER_DIR) | ||||
| IF (INTEL_MKL_DIR) | ||||
|   # TODO: diagnostic if dir does not exist | ||||
|   SET(CMAKE_INCLUDE_PATH ${CMAKE_INCLUDE_PATH} | ||||
|     "${INTEL_MKL_DIR}/include") | ||||
|   SET(CMAKE_LIBRARY_PATH ${CMAKE_LIBRARY_PATH} | ||||
|     "${INTEL_MKL_DIR}/lib/${mklvers}") | ||||
| ENDIF (INTEL_MKL_DIR) | ||||
|  | ||||
| # Try linking multiple libs | ||||
| MACRO(CHECK_ALL_LIBRARIES LIBRARIES _name _list _flags) | ||||
|   # This macro checks for the existence of the combination of libraries given by _list. | ||||
|   # If the combination is found, this macro whether we can link against that library | ||||
|   # combination using the name of a routine given by _name using the linker | ||||
|   # flags given by _flags.  If the combination of libraries is found and passes | ||||
|   # the link test, LIBRARIES is set to the list of complete library paths that | ||||
|   # have been found.  Otherwise, LIBRARIES is set to FALSE. | ||||
|   # N.B. _prefix is the prefix applied to the names of all cached variables that | ||||
|   # are generated internally and marked advanced by this macro. | ||||
|   SET(_prefix "${LIBRARIES}") | ||||
|   IF (_verbose) | ||||
|     SET(__list) | ||||
|     FOREACH(_elem ${_list}) | ||||
|       IF(__list) | ||||
|         SET(__list "${__list} - ${_elem}") | ||||
|       ELSE(__list) | ||||
|         SET(__list "${_elem}") | ||||
|       ENDIF(__list) | ||||
|     ENDFOREACH(_elem) | ||||
|   ENDIF(_verbose) | ||||
|   # start checking | ||||
|   SET(_libraries_work TRUE) | ||||
|   SET(${LIBRARIES}) | ||||
|   SET(_combined_name) | ||||
|   SET(_paths) | ||||
|   FOREACH(_library ${_list}) | ||||
|     SET(_combined_name ${_combined_name}_${_library}) | ||||
|     IF(_libraries_work)       | ||||
|       FIND_LIBRARY(${_prefix}_${_library}_LIBRARY NAMES ${_library}) | ||||
|       MARK_AS_ADVANCED(${_prefix}_${_library}_LIBRARY) | ||||
|       SET(${LIBRARIES} ${${LIBRARIES}} ${${_prefix}_${_library}_LIBRARY}) | ||||
|       SET(_libraries_work ${${_prefix}_${_library}_LIBRARY}) | ||||
|     ENDIF(_libraries_work) | ||||
|   ENDFOREACH(_library ${_list}) | ||||
|   # Test this combination of libraries. | ||||
|   IF(_libraries_work) | ||||
|     SET(CMAKE_REQUIRED_LIBRARIES ${_flags} ${${LIBRARIES}}) | ||||
|     CHECK_FUNCTION_EXISTS(${_name} ${_prefix}${_combined_name}_WORKS) | ||||
|     SET(CMAKE_REQUIRED_LIBRARIES) | ||||
|     MARK_AS_ADVANCED(${_prefix}${_combined_name}_WORKS) | ||||
|     SET(_libraries_work ${${_prefix}${_combined_name}_WORKS}) | ||||
|   ENDIF(_libraries_work) | ||||
|   # Fin | ||||
|   IF(_libraries_work) | ||||
|     IF (_verbose) | ||||
|       MESSAGE(STATUS "FindMKL: ${__list} : ok") | ||||
|     ENDIF (_verbose) | ||||
|   ELSE (_libraries_work) | ||||
|     SET(${LIBRARIES}) | ||||
|     MARK_AS_ADVANCED(${LIBRARIES}) | ||||
|     IF (_verbose) | ||||
|       MESSAGE(STATUS "FindMKL: ${__list} : no") | ||||
|     ENDIF (_verbose) | ||||
|   ENDIF(_libraries_work) | ||||
| ENDMACRO(CHECK_ALL_LIBRARIES) | ||||
|  | ||||
|  | ||||
| # Check for version 10/11 | ||||
| IF (NOT MKL_LIBRARIES) | ||||
|   SET(MKL_VERSION 1011) | ||||
| ENDIF (NOT MKL_LIBRARIES) | ||||
| FOREACH(mklrtl ${mklrtls}) | ||||
|   FOREACH(mkliface ${mklifaces}) | ||||
|     FOREACH(mkl64 ${mkl64s} "") | ||||
|       FOREACH(mklthread ${mklthreads}) | ||||
|         IF (NOT MKL_LIBRARIES AND NOT INTEL_MKL_SEQUENTIAL) | ||||
|           CHECK_ALL_LIBRARIES(MKL_LIBRARIES cblas_sgemm | ||||
|             "mkl_${mkliface}${mkl64};${mklthread};mkl_core;${mklrtl};pthread;m" "") | ||||
|         ENDIF (NOT MKL_LIBRARIES AND NOT INTEL_MKL_SEQUENTIAL)           | ||||
|       ENDFOREACH(mklthread) | ||||
|     ENDFOREACH(mkl64) | ||||
|   ENDFOREACH(mkliface) | ||||
| ENDFOREACH(mklrtl) | ||||
| FOREACH(mklrtl ${mklrtls}) | ||||
|   FOREACH(mkliface ${mklifaces}) | ||||
|     FOREACH(mkl64 ${mkl64s} "") | ||||
|       IF (NOT MKL_LIBRARIES) | ||||
|         CHECK_ALL_LIBRARIES(MKL_LIBRARIES cblas_sgemm | ||||
|           "mkl_${mkliface}${mkl64};mkl_sequential;mkl_core;m" "") | ||||
|         IF (MKL_LIBRARIES) | ||||
|           SET(mklseq "_sequential") | ||||
|         ENDIF (MKL_LIBRARIES) | ||||
|       ENDIF (NOT MKL_LIBRARIES) | ||||
|     ENDFOREACH(mkl64) | ||||
|   ENDFOREACH(mkliface) | ||||
| ENDFOREACH(mklrtl) | ||||
| FOREACH(mklrtl ${mklrtls}) | ||||
|   FOREACH(mkliface ${mklifaces}) | ||||
|     FOREACH(mkl64 ${mkl64s} "") | ||||
|       FOREACH(mklthread ${mklthreads}) | ||||
|         IF (NOT MKL_LIBRARIES) | ||||
|           CHECK_ALL_LIBRARIES(MKL_LIBRARIES cblas_sgemm | ||||
|             "mkl_${mkliface}${mkl64};${mklthread};mkl_core;${mklrtl};pthread;m" "") | ||||
|         ENDIF (NOT MKL_LIBRARIES)           | ||||
|       ENDFOREACH(mklthread) | ||||
|     ENDFOREACH(mkl64) | ||||
|   ENDFOREACH(mkliface) | ||||
| ENDFOREACH(mklrtl) | ||||
|  | ||||
| # Check for older versions | ||||
| IF (NOT MKL_LIBRARIES) | ||||
|   SET(MKL_VERSION 900) | ||||
|   CHECK_ALL_LIBRARIES(MKL_LIBRARIES cblas_sgemm | ||||
|     "mkl;guide;pthread;m" "") | ||||
| ENDIF (NOT MKL_LIBRARIES)           | ||||
|  | ||||
| # Include files | ||||
| IF (MKL_LIBRARIES) | ||||
|   FIND_PATH(MKL_INCLUDE_DIR "mkl_cblas.h") | ||||
|   MARK_AS_ADVANCED(MKL_INCLUDE_DIR) | ||||
| ENDIF (MKL_LIBRARIES) | ||||
|  | ||||
| # Other libraries | ||||
| IF (MKL_LIBRARIES) | ||||
|   FOREACH(mkl64 ${mkl64s} "_core" "") | ||||
|     FOREACH(mkls ${mklseq} "") | ||||
|       IF (NOT MKL_LAPACK_LIBRARIES) | ||||
|         FIND_LIBRARY(MKL_LAPACK_LIBRARIES NAMES "mkl_lapack${mkl64}${mkls}") | ||||
|         MARK_AS_ADVANCED(MKL_LAPACK_LIBRARIES) | ||||
|       ENDIF (NOT MKL_LAPACK_LIBRARIES) | ||||
|       IF (NOT MKL_SCALAPACK_LIBRARIES) | ||||
|         FIND_LIBRARY(MKL_SCALAPACK_LIBRARIES NAMES "mkl_scalapack${mkl64}${mkls}")  | ||||
|         MARK_AS_ADVANCED(MKL_SCALAPACK_LIBRARIES) | ||||
|       ENDIF (NOT MKL_SCALAPACK_LIBRARIES) | ||||
|       IF (NOT MKL_SOLVER_LIBRARIES) | ||||
|         FIND_LIBRARY(MKL_SOLVER_LIBRARIES NAMES "mkl_solver${mkl64}${mkls}") | ||||
|         MARK_AS_ADVANCED(MKL_SOLVER_LIBRARIES) | ||||
|       ENDIF (NOT MKL_SOLVER_LIBRARIES) | ||||
|       IF (NOT MKL_CDFT_LIBRARIES) | ||||
|         FIND_LIBRARY(MKL_CDFT_LIBRARIES NAMES "mkl_cdft${mkl64}${mkls}") | ||||
|         MARK_AS_ADVANCED(MKL_CDFT_LIBRARIES) | ||||
|       ENDIF (NOT MKL_CDFT_LIBRARIES) | ||||
|     ENDFOREACH(mkls) | ||||
|   ENDFOREACH(mkl64) | ||||
| ENDIF (MKL_LIBRARIES) | ||||
|  | ||||
| # LibIRC: intel compiler always links this;  | ||||
| # gcc does not; but mkl kernels sometimes need it. | ||||
| IF (MKL_LIBRARIES) | ||||
|   IF (CMAKE_COMPILER_IS_GNUCC) | ||||
|     FIND_LIBRARY(MKL_KERNEL_libirc "irc") | ||||
|   ELSEIF (CMAKE_C_COMPILER_ID AND NOT CMAKE_C_COMPILER_ID STREQUAL "Intel") | ||||
|     FIND_LIBRARY(MKL_KERNEL_libirc "irc") | ||||
|   ENDIF (CMAKE_COMPILER_IS_GNUCC) | ||||
|   MARK_AS_ADVANCED(MKL_KERNEL_libirc) | ||||
|   IF (MKL_KERNEL_libirc) | ||||
|     SET(MKL_LIBRARIES ${MKL_LIBRARIES} ${MKL_KERNEL_libirc}) | ||||
|   ENDIF (MKL_KERNEL_libirc) | ||||
| ENDIF (MKL_LIBRARIES) | ||||
|  | ||||
| # Final | ||||
| SET(CMAKE_LIBRARY_PATH ${saved_CMAKE_LIBRARY_PATH}) | ||||
| SET(CMAKE_INCLUDE_PATH ${saved_CMAKE_INCLUDE_PATH}) | ||||
| IF (MKL_LIBRARIES) | ||||
|   SET(MKL_FOUND TRUE) | ||||
| ELSE (MKL_LIBRARIES) | ||||
|   SET(MKL_FOUND FALSE) | ||||
|   SET(MKL_VERSION) | ||||
| ENDIF (MKL_LIBRARIES) | ||||
|  | ||||
| # Results | ||||
| IF (_verbose) | ||||
|   MESSAGE(STATUS "*** MKL_FOUND = ${MKL_FOUND}") | ||||
|   MESSAGE(STATUS "*** MKL_INCLUDE_DIR = ${MKL_INCLUDE_DIR}") | ||||
|   MESSAGE(STATUS "*** MKL_LIBRARIES = ${MKL_LIBRARIES}") | ||||
|   MESSAGE(STATUS "*** MKL_LAPACK_LIBRARIES = ${MKL_LAPACK_LIBRARIES}") | ||||
|   MESSAGE(STATUS "*** MKL_SCALAPACK_LIBRARIES = ${MKL_SCALAPACK_LIBRARIES}") | ||||
|   MESSAGE(STATUS "*** MKL_SOLVER_LIBRARIES = ${MKL_SOLVER_LIBRARIES}") | ||||
|   MESSAGE(STATUS "*** MKL_CDFT_LIBRARIES = ${MKL_CDFT_LIBRARIES}") | ||||
| ENDIF(_verbose) | ||||
|  | ||||
| # Standard termination | ||||
| IF(NOT MKL_FOUND AND MKL_FIND_REQUIRED) | ||||
|   MESSAGE(FATAL_ERROR "MKL library not found. Please specify library  location") | ||||
| ENDIF(NOT MKL_FOUND AND MKL_FIND_REQUIRED) | ||||
| IF(NOT MKL_FIND_QUIETLY) | ||||
|   IF(MKL_FOUND) | ||||
|     MESSAGE(STATUS "MKL library found") | ||||
|   ELSE(MKL_FOUND) | ||||
|     MESSAGE(STATUS "MKL library not found") | ||||
|   ENDIF(MKL_FOUND) | ||||
| ENDIF(NOT MKL_FIND_QUIETLY) | ||||
|  | ||||
| # Do nothing if MKL_FOUND was set before! | ||||
| ENDIF (NOT MKL_FOUND) | ||||
|  | ||||
|  | ||||
							
								
								
									
										104
									
								
								lib/TH/cmake/FindSSE.cmake
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										104
									
								
								lib/TH/cmake/FindSSE.cmake
									
									
									
									
									
										Normal file
									
								
							| @ -0,0 +1,104 @@ | ||||
| # Check if SSE instructions are available on the machine where  | ||||
| # the project is compiled. | ||||
|  | ||||
| IF(CMAKE_SYSTEM_NAME MATCHES "Linux") | ||||
|    EXEC_PROGRAM(cat ARGS "/proc/cpuinfo" OUTPUT_VARIABLE CPUINFO) | ||||
|  | ||||
|    STRING(REGEX REPLACE "^.*(sse2).*$" "\\1" SSE_THERE ${CPUINFO}) | ||||
|    STRING(COMPARE EQUAL "sse2" "${SSE_THERE}" SSE2_TRUE) | ||||
|    IF (SSE2_TRUE) | ||||
|       set(SSE2_FOUND true CACHE BOOL "SSE2 available on host") | ||||
|    ELSE (SSE2_TRUE) | ||||
|       set(SSE2_FOUND false CACHE BOOL "SSE2 available on host") | ||||
|    ENDIF (SSE2_TRUE) | ||||
|  | ||||
|    # /proc/cpuinfo apparently omits sse3 :( | ||||
|    STRING(REGEX REPLACE "^.*[^s](sse3).*$" "\\1" SSE_THERE ${CPUINFO}) | ||||
|    STRING(COMPARE EQUAL "sse3" "${SSE_THERE}" SSE3_TRUE) | ||||
|    IF (NOT SSE3_TRUE) | ||||
|       STRING(REGEX REPLACE "^.*(T2300).*$" "\\1" SSE_THERE ${CPUINFO}) | ||||
|       STRING(COMPARE EQUAL "T2300" "${SSE_THERE}" SSE3_TRUE) | ||||
|    ENDIF (NOT SSE3_TRUE) | ||||
|  | ||||
|    STRING(REGEX REPLACE "^.*(ssse3).*$" "\\1" SSE_THERE ${CPUINFO}) | ||||
|    STRING(COMPARE EQUAL "ssse3" "${SSE_THERE}" SSSE3_TRUE) | ||||
|    IF (SSE3_TRUE OR SSSE3_TRUE) | ||||
|       set(SSE3_FOUND true CACHE BOOL "SSE3 available on host") | ||||
|    ELSE (SSE3_TRUE OR SSSE3_TRUE) | ||||
|       set(SSE3_FOUND false CACHE BOOL "SSE3 available on host") | ||||
|    ENDIF (SSE3_TRUE OR SSSE3_TRUE) | ||||
|    IF (SSSE3_TRUE) | ||||
|       set(SSSE3_FOUND true CACHE BOOL "SSSE3 available on host") | ||||
|    ELSE (SSSE3_TRUE) | ||||
|       set(SSSE3_FOUND false CACHE BOOL "SSSE3 available on host") | ||||
|    ENDIF (SSSE3_TRUE) | ||||
|  | ||||
|    STRING(REGEX REPLACE "^.*(sse4_1).*$" "\\1" SSE_THERE ${CPUINFO}) | ||||
|    STRING(COMPARE EQUAL "sse4_1" "${SSE_THERE}" SSE41_TRUE) | ||||
|    IF (SSE41_TRUE) | ||||
|       set(SSE4_1_FOUND true CACHE BOOL "SSE4.1 available on host") | ||||
|    ELSE (SSE41_TRUE) | ||||
|       set(SSE4_1_FOUND false CACHE BOOL "SSE4.1 available on host") | ||||
|    ENDIF (SSE41_TRUE) | ||||
| ELSEIF(CMAKE_SYSTEM_NAME MATCHES "Darwin") | ||||
|    EXEC_PROGRAM("/usr/sbin/sysctl -n machdep.cpu.features" OUTPUT_VARIABLE | ||||
|       CPUINFO) | ||||
|  | ||||
|    STRING(REGEX REPLACE "^.*[^S](SSE2).*$" "\\1" SSE_THERE ${CPUINFO}) | ||||
|    STRING(COMPARE EQUAL "SSE2" "${SSE_THERE}" SSE2_TRUE) | ||||
|    IF (SSE2_TRUE) | ||||
|       set(SSE2_FOUND true CACHE BOOL "SSE2 available on host") | ||||
|    ELSE (SSE2_TRUE) | ||||
|       set(SSE2_FOUND false CACHE BOOL "SSE2 available on host") | ||||
|    ENDIF (SSE2_TRUE) | ||||
|  | ||||
|    STRING(REGEX REPLACE "^.*[^S](SSE3).*$" "\\1" SSE_THERE ${CPUINFO}) | ||||
|    STRING(COMPARE EQUAL "SSE3" "${SSE_THERE}" SSE3_TRUE) | ||||
|    IF (SSE3_TRUE) | ||||
|       set(SSE3_FOUND true CACHE BOOL "SSE3 available on host") | ||||
|    ELSE (SSE3_TRUE) | ||||
|       set(SSE3_FOUND false CACHE BOOL "SSE3 available on host") | ||||
|    ENDIF (SSE3_TRUE) | ||||
|  | ||||
|    STRING(REGEX REPLACE "^.*(SSSE3).*$" "\\1" SSE_THERE ${CPUINFO}) | ||||
|    STRING(COMPARE EQUAL "SSSE3" "${SSE_THERE}" SSSE3_TRUE) | ||||
|    IF (SSSE3_TRUE) | ||||
|       set(SSSE3_FOUND true CACHE BOOL "SSSE3 available on host") | ||||
|    ELSE (SSSE3_TRUE) | ||||
|       set(SSSE3_FOUND false CACHE BOOL "SSSE3 available on host") | ||||
|    ENDIF (SSSE3_TRUE) | ||||
|  | ||||
|    STRING(REGEX REPLACE "^.*(SSE4.1).*$" "\\1" SSE_THERE ${CPUINFO}) | ||||
|    STRING(COMPARE EQUAL "SSE4.1" "${SSE_THERE}" SSE41_TRUE) | ||||
|    IF (SSE41_TRUE) | ||||
|       set(SSE4_1_FOUND true CACHE BOOL "SSE4.1 available on host") | ||||
|    ELSE (SSE41_TRUE) | ||||
|       set(SSE4_1_FOUND false CACHE BOOL "SSE4.1 available on host") | ||||
|    ENDIF (SSE41_TRUE) | ||||
| ELSEIF(CMAKE_SYSTEM_NAME MATCHES "Windows") | ||||
|    # TODO | ||||
|    set(SSE2_FOUND   true  CACHE BOOL "SSE2 available on host") | ||||
|    set(SSE3_FOUND   false CACHE BOOL "SSE3 available on host") | ||||
|    set(SSSE3_FOUND  false CACHE BOOL "SSSE3 available on host") | ||||
|    set(SSE4_1_FOUND false CACHE BOOL "SSE4.1 available on host") | ||||
| ELSE(CMAKE_SYSTEM_NAME MATCHES "Linux") | ||||
|    set(SSE2_FOUND   true  CACHE BOOL "SSE2 available on host") | ||||
|    set(SSE3_FOUND   false CACHE BOOL "SSE3 available on host") | ||||
|    set(SSSE3_FOUND  false CACHE BOOL "SSSE3 available on host") | ||||
|    set(SSE4_1_FOUND false CACHE BOOL "SSE4.1 available on host") | ||||
| ENDIF(CMAKE_SYSTEM_NAME MATCHES "Linux") | ||||
|  | ||||
| if(NOT SSE2_FOUND) | ||||
|       MESSAGE(STATUS "Could not find hardware support for SSE2 on this machine.") | ||||
| endif(NOT SSE2_FOUND) | ||||
| if(NOT SSE3_FOUND) | ||||
|       MESSAGE(STATUS "Could not find hardware support for SSE3 on this machine.") | ||||
| endif(NOT SSE3_FOUND) | ||||
| if(NOT SSSE3_FOUND) | ||||
|       MESSAGE(STATUS "Could not find hardware support for SSSE3 on this machine.") | ||||
| endif(NOT SSSE3_FOUND) | ||||
| if(NOT SSE4_1_FOUND) | ||||
|       MESSAGE(STATUS "Could not find hardware support for SSE4.1 on this machine.") | ||||
| endif(NOT SSE4_1_FOUND) | ||||
|  | ||||
| mark_as_advanced(SSE2_FOUND SSE3_FOUND SSSE3_FOUND SSE4_1_FOUND) | ||||
							
								
								
									
										382
									
								
								lib/TH/generic/THBlas.c
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										382
									
								
								lib/TH/generic/THBlas.c
									
									
									
									
									
										Normal file
									
								
							| @ -0,0 +1,382 @@ | ||||
| #ifndef TH_GENERIC_FILE | ||||
| #define TH_GENERIC_FILE "generic/THBlas.c" | ||||
| #else | ||||
|  | ||||
| void THBlas_(swap)(long n, real *x, long incx, real *y, long incy) | ||||
| { | ||||
|   if(n == 1) | ||||
|   { | ||||
|     incx = 1; | ||||
|     incy = 1; | ||||
|   } | ||||
|  | ||||
| #if defined(USE_LAPACK) && (defined(TH_REAL_IS_DOUBLE) || defined(TH_REAL_IS_FLOAT)) | ||||
|   if( (n <= INT_MAX) && (incx <= INT_MAX) && (incy <= INT_MAX) ) | ||||
|   { | ||||
|     int i_n = (int)n; | ||||
|     int i_incx = (int)incx; | ||||
|     int i_incy = (int)incy; | ||||
|  | ||||
| #if defined(TH_REAL_IS_DOUBLE) | ||||
|     extern void dswap_(int *n, double *x, int *incx, double *y, int *incy); | ||||
|     dswap_(&i_n, x, &i_incx, y, &i_incy); | ||||
| #else | ||||
|     extern void sswap_(int *n, float *x, int *incx, float *y, int *incy); | ||||
|     sswap_(&i_n, x, &i_incx, y, &i_incy); | ||||
| #endif | ||||
|     return; | ||||
|   } | ||||
| #endif | ||||
|   { | ||||
|     long i; | ||||
|     for(i = 0; i < n; i++) | ||||
|     { | ||||
|       real z = x[i*incx]; | ||||
|       x[i*incx] = y[i*incy]; | ||||
|       y[i*incy] = z; | ||||
|     } | ||||
|   } | ||||
| } | ||||
|  | ||||
| void THBlas_(scal)(long n, real a, real *x, long incx) | ||||
| { | ||||
|   if(n == 1) | ||||
|     incx = 1; | ||||
|  | ||||
| #if defined(USE_LAPACK) && (defined(TH_REAL_IS_DOUBLE) || defined(TH_REAL_IS_FLOAT)) | ||||
|   if( (n <= INT_MAX) && (incx <= INT_MAX) ) | ||||
|   { | ||||
|     int i_n = (int)n; | ||||
|     int i_incx = (int)incx; | ||||
|  | ||||
| #if defined(TH_REAL_IS_DOUBLE) | ||||
|     extern void dscal_(int *n, double *a, double *x, int *incx); | ||||
|     dscal_(&i_n, &a, x, &i_incx); | ||||
| #else | ||||
|     extern void sscal_(int *n, float *a, float *x, int *incx); | ||||
|     sscal_(&i_n, &a, x, &i_incx); | ||||
| #endif | ||||
|     return; | ||||
|   } | ||||
| #endif | ||||
|   { | ||||
|     long i; | ||||
|     for(i = 0; i < n; i++) | ||||
|       x[i*incx] *= a; | ||||
|   } | ||||
| } | ||||
|  | ||||
| void THBlas_(copy)(long n, real *x, long incx, real *y, long incy) | ||||
| { | ||||
|   if(n == 1) | ||||
|   { | ||||
|     incx = 1; | ||||
|     incy = 1; | ||||
|   } | ||||
|  | ||||
| #if defined(USE_LAPACK) && (defined(TH_REAL_IS_DOUBLE) || defined(TH_REAL_IS_FLOAT)) | ||||
|   if( (n <= INT_MAX) && (incx <= INT_MAX) && (incy <= INT_MAX) ) | ||||
|   { | ||||
|     int i_n = (int)n; | ||||
|     int i_incx = (int)incx; | ||||
|     int i_incy = (int)incy; | ||||
|  | ||||
| #if defined(TH_REAL_IS_DOUBLE) | ||||
|     extern void dcopy_(int *n, double *x, int *incx, double *y, int *incy); | ||||
|     dcopy_(&i_n, x, &i_incx, y, &i_incy); | ||||
| #else | ||||
|     extern void scopy_(int *n, float *x, int *incx, float *y, int *incy); | ||||
|     scopy_(&i_n, x, &i_incx, y, &i_incy); | ||||
| #endif | ||||
|     return; | ||||
|   } | ||||
| #endif | ||||
|   { | ||||
|     long i; | ||||
|     for(i = 0; i < n; i++) | ||||
|       y[i*incy] = x[i*incx]; | ||||
|   } | ||||
| } | ||||
|  | ||||
| void THBlas_(axpy)(long n, real a, real *x, long incx, real *y, long incy) | ||||
| { | ||||
|   if(n == 1) | ||||
|   { | ||||
|     incx = 1; | ||||
|     incy = 1; | ||||
|   } | ||||
|  | ||||
| #if defined(USE_LAPACK) && (defined(TH_REAL_IS_DOUBLE) || defined(TH_REAL_IS_FLOAT)) | ||||
|   if( (n <= INT_MAX) && (incx <= INT_MAX) && (incy <= INT_MAX) ) | ||||
|   { | ||||
|     int i_n = (int)n; | ||||
|     int i_incx = (int)incx; | ||||
|     int i_incy = (int)incy; | ||||
|  | ||||
| #if defined(TH_REAL_IS_DOUBLE) | ||||
|     extern void daxpy_(int *n, double *a, double *x, int *incx, double *y, int *incy); | ||||
|     daxpy_(&i_n, &a, x, &i_incx, y, &i_incy); | ||||
| #else | ||||
|     extern void saxpy_(int *n, float *a, float *x, int *incx, float *y, int *incy); | ||||
|     saxpy_(&i_n, &a, x, &i_incx, y, &i_incy); | ||||
| #endif | ||||
|     return; | ||||
|   } | ||||
| #endif | ||||
|   { | ||||
|     long i; | ||||
|     for(i = 0; i < n; i++) | ||||
|       y[i*incy] += a*x[i*incx]; | ||||
|   } | ||||
| } | ||||
|  | ||||
| real THBlas_(dot)(long n, real *x, long incx, real *y, long incy) | ||||
| { | ||||
|   if(n == 1) | ||||
|   { | ||||
|     incx = 1; | ||||
|     incy = 1; | ||||
|   } | ||||
|  | ||||
| #if defined(USE_LAPACK) && (defined(TH_REAL_IS_DOUBLE) || defined(TH_REAL_IS_FLOAT)) | ||||
|   if( (n <= INT_MAX) && (incx <= INT_MAX) && (incy <= INT_MAX) ) | ||||
|   { | ||||
|     int i_n = (int)n; | ||||
|     int i_incx = (int)incx; | ||||
|     int i_incy = (int)incy; | ||||
|  | ||||
| #if defined(TH_REAL_IS_DOUBLE) | ||||
|     extern double ddot_(int *n, double *x, int *incx, double *y, int *incy); | ||||
|     return ddot_(&i_n, x, &i_incx, y, &i_incy); | ||||
| #else | ||||
|     extern float sdot_(int *n, float *x, int *incx, float *y, int *incy); | ||||
|     return sdot_(&i_n, x, &i_incx, y, &i_incy); | ||||
| #endif | ||||
|   } | ||||
| #endif | ||||
|   { | ||||
|     long i; | ||||
|     real sum = 0; | ||||
|     for(i = 0; i < n; i++) | ||||
|     sum += x[i*incx]*y[i*incy]; | ||||
|     return sum; | ||||
|   } | ||||
| } | ||||
|  | ||||
| void THBlas_(gemv)(char trans, long m, long n, real alpha, real *a, long lda, real *x, long incx, real beta, real *y, long incy) | ||||
| { | ||||
|   if(n == 1) | ||||
|     lda = m; | ||||
|    | ||||
| #if defined(USE_LAPACK) && (defined(TH_REAL_IS_DOUBLE) || defined(TH_REAL_IS_FLOAT)) | ||||
|   if( (m <= INT_MAX) && (n <= INT_MAX) &&  | ||||
|       (lda > 0) && (lda <= INT_MAX) && | ||||
|       (incx > 0) && (incx <= INT_MAX) && | ||||
|       (incy > 0) && (incy <= INT_MAX) ) | ||||
|   { | ||||
|     int i_m = (int)m; | ||||
|     int i_n = (int)n; | ||||
|     int i_lda = (int)lda; | ||||
|     int i_incx = (int)incx; | ||||
|     int i_incy = (int)incy; | ||||
|  | ||||
| #if defined(TH_REAL_IS_DOUBLE) | ||||
|     extern void dgemv_(char *trans, int *m, int *n, double *alpha, double *a, int *lda, double *x, int *incx, double *beta, double *y, int *incy); | ||||
|     dgemv_(&trans, &i_m, &i_n, &alpha, a, &i_lda, x, &i_incx, &beta, y, &i_incy); | ||||
| #else | ||||
|     extern void sgemv_(char *trans, int *m, int *n, float *alpha, float *a, int *lda, float *x, int *incx, float *beta, float *y, int *incy); | ||||
|     sgemv_(&trans, &i_m, &i_n, &alpha, a, &i_lda, x, &i_incx, &beta, y, &i_incy); | ||||
| #endif | ||||
|     return; | ||||
|   } | ||||
| #endif | ||||
|   { | ||||
|     long i, j; | ||||
|  | ||||
|     if( (trans == 'T') || (trans == 't') ) | ||||
|     { | ||||
|       for(i = 0; i < n; i++) | ||||
|       { | ||||
|         real sum = 0; | ||||
|         real *row_ = a+lda*i; | ||||
|         for(j = 0; j < m; j++) | ||||
|           sum += x[j*incx]*row_[j]; | ||||
|         y[i*incy] = beta*y[i*incy] + alpha*sum; | ||||
|       } | ||||
|     } | ||||
|     else | ||||
|     { | ||||
|       if(beta != 1) | ||||
|         THBlas_(scal)(m, beta, y, incy); | ||||
|        | ||||
|       for(j = 0; j < n; j++) | ||||
|       { | ||||
|         real *column_ = a+lda*j; | ||||
|         real z = alpha*x[j*incx]; | ||||
|         for(i = 0; i < m; i++) | ||||
|           y[i*incy] += z*column_[i]; | ||||
|       } | ||||
|     } | ||||
|   } | ||||
| } | ||||
|  | ||||
| void THBlas_(ger)(long m, long n, real alpha, real *x, long incx, real *y, long incy, real *a, long lda) | ||||
| { | ||||
|   if(n == 1) | ||||
|     lda = m; | ||||
|  | ||||
| #if defined(USE_LAPACK) && (defined(TH_REAL_IS_DOUBLE) || defined(TH_REAL_IS_FLOAT)) | ||||
|   if( (m <= INT_MAX) && (n <= INT_MAX) && (lda <= INT_MAX)  && (incx <= INT_MAX) && (incy <= INT_MAX) ) | ||||
|   { | ||||
|     int i_m = (int)m; | ||||
|     int i_n = (int)n; | ||||
|     int i_lda = (int)lda; | ||||
|     int i_incx = (int)incx; | ||||
|     int i_incy = (int)incy; | ||||
|  | ||||
| #if defined(TH_REAL_IS_DOUBLE) | ||||
|     extern void dger_(int *m, int *n, double *alpha, double *x, int *incx, real *y, int *incy, double *a, int *lda); | ||||
|     dger_(&i_m, &i_n, &alpha, x, &i_incx, y, &i_incy, a, &i_lda); | ||||
| #else | ||||
|     extern void sger_(int *m, int *n, float *alpha, float *x, int *incx, real *y, int *incy, float *a, int *lda); | ||||
|     sger_(&i_m, &i_n, &alpha, x, &i_incx, y, &i_incy, a, &i_lda); | ||||
| #endif | ||||
|     return; | ||||
|   } | ||||
| #endif | ||||
|   { | ||||
|     long i, j; | ||||
|     for(j = 0; j < n; j++) | ||||
|     { | ||||
|       real *column_ = a+j*lda; | ||||
|       real z = alpha*y[j*incy]; | ||||
|       for(i = 0; i < m; i++) | ||||
|         column_[i] += z*x[i*incx] ; | ||||
|     } | ||||
|   } | ||||
| } | ||||
|  | ||||
| void THBlas_(gemm)(char transa, char transb, long m, long n, long k, real alpha, real *a, long lda, real *b, long ldb, real beta, real *c, long ldc) | ||||
| { | ||||
|   int transa_ = ((transa == 't') || (transa == 'T')); | ||||
|   int transb_ = ((transb == 't') || (transb == 'T')); | ||||
|  | ||||
|   if(n == 1) | ||||
|     ldc = m; | ||||
|  | ||||
|   if(transa_) | ||||
|   { | ||||
|     if(m == 1) | ||||
|       lda = k; | ||||
|   } | ||||
|   else | ||||
|   { | ||||
|     if(k == 1) | ||||
|       lda = m; | ||||
|   } | ||||
|  | ||||
|   if(transb_) | ||||
|   { | ||||
|     if(k == 1) | ||||
|       ldb = n; | ||||
|   } | ||||
|   else | ||||
|   { | ||||
|     if(n == 1) | ||||
|       ldb = k; | ||||
|   } | ||||
|  | ||||
| #if defined(USE_LAPACK) && (defined(TH_REAL_IS_DOUBLE) || defined(TH_REAL_IS_FLOAT)) | ||||
|   if( (m <= INT_MAX) && (n <= INT_MAX) && (k <= INT_MAX) && (lda <= INT_MAX)  && (ldb <= INT_MAX) && (ldc <= INT_MAX) ) | ||||
|   { | ||||
|     int i_m = (int)m; | ||||
|     int i_n = (int)n; | ||||
|     int i_k = (int)k; | ||||
|     int i_lda = (int)lda; | ||||
|     int i_ldb = (int)ldb; | ||||
|     int i_ldc = (int)ldc; | ||||
|  | ||||
| #if defined(TH_REAL_IS_DOUBLE) | ||||
|     extern void dgemm_(char *transa, char *transb, int *m, int *n, int *k, double *alpha, double *a, int *lda, double *b, int *ldb, double *beta, double *c, int *ldc); | ||||
|     dgemm_(&transa, &transb, &i_m, &i_n, &i_k, &alpha, a, &i_lda, b, &i_ldb, &beta, c, &i_ldc); | ||||
| #else | ||||
|     extern void sgemm_(char *transa, char *transb, int *m, int *n, int *k, float *alpha, float *a, int *lda, float *b, int *ldb, float *beta, float *c, int *ldc); | ||||
|     sgemm_(&transa, &transb, &i_m, &i_n, &i_k, &alpha, a, &i_lda, b, &i_ldb, &beta, c, &i_ldc); | ||||
| #endif | ||||
|     return; | ||||
|   } | ||||
| #endif | ||||
|   { | ||||
|     long i, j, l; | ||||
|     if(!transa_ && !transb_) | ||||
|     { | ||||
|       real *a_ = a; | ||||
|       for(i = 0; i < m; i++) | ||||
|       { | ||||
|         real *b_ = b; | ||||
|         for(j = 0; j < n; j++) | ||||
|         { | ||||
|           real sum = 0; | ||||
|           for(l = 0; l < k; l++) | ||||
|             sum += a_[l*lda]*b_[l]; | ||||
|           b_ += ldb; | ||||
|           c[j*ldc+i] = beta*c[j*ldc+i]+alpha*sum; | ||||
|         } | ||||
|         a_++; | ||||
|       } | ||||
|     } | ||||
|     else if(transa_ && !transb_) | ||||
|     { | ||||
|       real *a_ = a; | ||||
|       for(i = 0; i < m; i++) | ||||
|       { | ||||
|         real *b_ = b; | ||||
|         for(j = 0; j < n; j++) | ||||
|         { | ||||
|           real sum = 0; | ||||
|           for(l = 0; l < k; l++) | ||||
|             sum += a_[l]*b_[l]; | ||||
|           b_ += ldb; | ||||
|           c[j*ldc+i] = beta*c[j*ldc+i]+alpha*sum; | ||||
|         } | ||||
|         a_ += lda; | ||||
|       } | ||||
|     } | ||||
|     else if(!transa_ && transb_) | ||||
|     { | ||||
|       real *a_ = a; | ||||
|       for(i = 0; i < m; i++) | ||||
|       { | ||||
|         real *b_ = b; | ||||
|         for(j = 0; j < n; j++) | ||||
|         { | ||||
|           real sum = 0; | ||||
|           for(l = 0; l < k; l++) | ||||
|             sum += a_[l*lda]*b_[l*ldb]; | ||||
|           b_++; | ||||
|           c[j*ldc+i] = beta*c[j*ldc+i]+alpha*sum; | ||||
|         } | ||||
|         a_++; | ||||
|       } | ||||
|     } | ||||
|     else | ||||
|     { | ||||
|       real *a_ = a; | ||||
|       for(i = 0; i < m; i++) | ||||
|       { | ||||
|         real *b_ = b; | ||||
|         for(j = 0; j < n; j++) | ||||
|         { | ||||
|           real sum = 0; | ||||
|           for(l = 0; l < k; l++) | ||||
|             sum += a_[l]*b_[l*ldb]; | ||||
|           b_++; | ||||
|           c[j*ldc+i] = beta*c[j*ldc+i]+alpha*sum; | ||||
|         } | ||||
|         a_ += lda; | ||||
|       } | ||||
|     } | ||||
|   } | ||||
| } | ||||
|  | ||||
| #endif | ||||
							
								
								
									
										19
									
								
								lib/TH/generic/THBlas.h
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										19
									
								
								lib/TH/generic/THBlas.h
									
									
									
									
									
										Normal file
									
								
							| @ -0,0 +1,19 @@ | ||||
| #ifndef TH_GENERIC_FILE | ||||
| #define TH_GENERIC_FILE "generic/THBlas.h" | ||||
| #else | ||||
|  | ||||
| /* Level 1 */ | ||||
| void THBlas_(swap)(long n, real *x, long incx, real *y, long incy); | ||||
| void THBlas_(scal)(long n, real a, real *x, long incx); | ||||
| void THBlas_(copy)(long n, real *x, long incx, real *y, long incy); | ||||
| void THBlas_(axpy)(long n, real a, real *x, long incx, real *y, long incy); | ||||
| real THBlas_(dot)(long n, real *x, long incx, real *y, long incy); | ||||
|  | ||||
| /* Level 2 */ | ||||
| void THBlas_(gemv)(char trans, long m, long n, real alpha, real *a, long lda, real *x, long incx, real beta, real *y, long incy); | ||||
| void THBlas_(ger)(long m, long n, real alpha, real *x, long incx, real *y, long incy, real *a, long lda); | ||||
|  | ||||
| /* Level 3 */ | ||||
| void THBlas_(gemm)(char transa, char transb, long m, long n, long k, real alpha, real *a, long lda, real *b, long ldb, real beta, real *c, long ldc); | ||||
|  | ||||
| #endif | ||||
							
								
								
									
										66
									
								
								lib/TH/generic/THLapack.c
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										66
									
								
								lib/TH/generic/THLapack.c
									
									
									
									
									
										Normal file
									
								
							| @ -0,0 +1,66 @@ | ||||
| #ifndef TH_GENERIC_FILE | ||||
| #define TH_GENERIC_FILE "generic/THLapack.c" | ||||
| #else | ||||
|  | ||||
| void THLapack_(gesv)(int n, int nrhs, real *a, int lda, int *ipiv, real *b, int ldb, int* info) | ||||
| { | ||||
| #ifdef __LAPACK__ | ||||
| #if defined(TH_REAL_IS_DOUBLE) | ||||
|   extern void dgesv_(int *n, int *nrhs, double *a, int *lda, int *ipiv, double *b, int *ldb, int *info); | ||||
|   dgesv_(&n, &nrhs, a, &lda, ipiv, b, &ldb, info); | ||||
| #else | ||||
|   extern void sgesv_(int *n, int *nrhs, float *a, int *lda, int *ipiv, float *b, int *ldb, int *info); | ||||
|   sgesv_(&n, &nrhs, a, &lda, ipiv, b, &ldb, info); | ||||
| #endif | ||||
| #else | ||||
|   THError("gesv : Lapack library not found in compile time\n"); | ||||
| #endif | ||||
|   return; | ||||
| } | ||||
|  | ||||
| void THLapack_(gels)(char trans, int m, int n, int nrhs, real *a, int lda, real *b, int ldb, real *work, int lwork, int *info) | ||||
| { | ||||
| #ifdef __LAPACK__ | ||||
| #if defined(TH_REAL_IS_DOUBLE) | ||||
|   extern void dgels_(char *trans, int *m, int *n, int *nrhs, double *a, int *lda, double *b, int *ldb, double *work, int *lwork, int *info); | ||||
|   dgels_(&trans, &m, &n, &nrhs, a, &lda, b, &ldb, work, &lwork, info); | ||||
| #else | ||||
|   extern void sgels_(char *trans, int *m, int *n, int *nrhs, float *a, int *lda, float *b, int *ldb, float *work, int *lwork, int *info); | ||||
|   sgels_(&trans, &m, &n, &nrhs, a, &lda, b, &ldb, work, &lwork, info); | ||||
| #endif | ||||
| #else | ||||
|   THError("gels : Lapack library not found in compile time\n"); | ||||
| #endif | ||||
| } | ||||
|  | ||||
| void THLapack_(syev)(char jobz, char uplo, int n, real *a, int lda, real *w, real *work, int lwork, int *info) | ||||
| { | ||||
| #ifdef __LAPACK__ | ||||
| #if defined(TH_REAL_IS_DOUBLE) | ||||
|   extern void dsyev_(char *jobz, char *uplo, int *n, double *a, int *lda, double *w, double *work, int *lwork, int *info); | ||||
|   dsyev_(&jobz, &uplo, &n, a, &lda, w, work, &lwork, info); | ||||
| #else | ||||
|   extern void ssyev_(char *jobz, char *uplo, int *n, float *a, int *lda, float *w, float *work, int *lwork, int *info); | ||||
|   ssyev_(&jobz, &uplo, &n, a, &lda, w, work, &lwork, info); | ||||
| #endif | ||||
| #else | ||||
|   THError("syev : Lapack library not found in compile time\n"); | ||||
| #endif | ||||
| } | ||||
|  | ||||
| void THLapack_(gesvd)(char jobu, char jobvt, int m, int n, real *a, int lda, real *s, real *u, int ldu, real *vt, int ldvt, real *work, int lwork, int *info) | ||||
| { | ||||
| #ifdef __LAPACK__ | ||||
| #if defined(TH_REAL_IS_DOUBLE) | ||||
|   extern void dgesvd_(char *jobu, char *jobvt, int *m, int *n, double *a, int *lda, double *s, double *u, int *ldu, double *vt, int *ldvt, double *work, int *lwork, int *info); | ||||
|   dgesvd_( &jobu,  &jobvt,  &m,  &n,  a,  &lda,  s,  u,  &ldu,  vt,  &ldvt,  work,  &lwork,  info); | ||||
| #else | ||||
|   extern void sgesvd_(char *jobu, char *jobvt, int *m, int *n, float *a, int *lda, float *s, float *u, int *ldu, float *vt, int *ldvt, float *work, int *lwork, int *info); | ||||
|   sgesvd_( &jobu,  &jobvt,  &m,  &n,  a,  &lda,  s,  u,  &ldu,  vt,  &ldvt,  work,  &lwork,  info); | ||||
| #endif | ||||
| #else | ||||
|   THError("gesvd : Lapack library not found in compile time\n"); | ||||
| #endif | ||||
| } | ||||
|  | ||||
| #endif | ||||
							
								
								
									
										15
									
								
								lib/TH/generic/THLapack.h
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										15
									
								
								lib/TH/generic/THLapack.h
									
									
									
									
									
										Normal file
									
								
							| @ -0,0 +1,15 @@ | ||||
| #ifndef TH_GENERIC_FILE | ||||
| #define TH_GENERIC_FILE "generic/THLapack.h" | ||||
| #else | ||||
|  | ||||
|  | ||||
|  | ||||
| /* AX=B */ | ||||
| void THLapack_(gesv)(int n, int nrhs, real *a, int lda, int *ipiv, real *b, int ldb, int* info); | ||||
| /* ||AX-B|| */ | ||||
| void THLapack_(gels)(char trans, int m, int n, int nrhs, real *a, int lda, real *b, int ldb, real *work, int lwork, int *info); | ||||
| /* Eigenvals */ | ||||
| void THLapack_(syev)(char jobz, char uplo, int n, real *a, int lda, real *w, real *work, int lwork, int *info); | ||||
| /* svd */ | ||||
| void THLapack_(gesvd)(char jobu, char jobvt, int m, int n, real *a, int lda, real *s, real *u, int ldu, real *vt, int ldvt, real *work, int lwork, int *info); | ||||
| #endif | ||||
							
								
								
									
										259
									
								
								lib/TH/generic/THStorage.c
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										259
									
								
								lib/TH/generic/THStorage.c
									
									
									
									
									
										Normal file
									
								
							| @ -0,0 +1,259 @@ | ||||
| #ifndef TH_GENERIC_FILE | ||||
| #define TH_GENERIC_FILE "generic/THStorage.c" | ||||
| #else | ||||
|  | ||||
| THStorage* THStorage_(new)(void) | ||||
| { | ||||
|   return THStorage_(newWithSize)(0); | ||||
| } | ||||
|  | ||||
| THStorage* THStorage_(newWithSize)(long size) | ||||
| { | ||||
|   THStorage *storage = THAlloc(sizeof(THStorage)); | ||||
|   storage->data = THAlloc(sizeof(real)*size); | ||||
|   storage->size = size; | ||||
|   storage->refcount = 1; | ||||
|   storage->flag = TH_STORAGE_REFCOUNTED | TH_STORAGE_RESIZABLE | TH_STORAGE_FREEMEM; | ||||
|   return storage; | ||||
| } | ||||
|  | ||||
| THStorage* THStorage_(newWithSize1)(real data0) | ||||
| { | ||||
|   THStorage *self = THStorage_(newWithSize)(1); | ||||
|   self->data[0] = data0; | ||||
|   return self; | ||||
| } | ||||
|  | ||||
| THStorage* THStorage_(newWithSize2)(real data0, real data1) | ||||
| { | ||||
|   THStorage *self = THStorage_(newWithSize)(2); | ||||
|   self->data[0] = data0; | ||||
|   self->data[1] = data1; | ||||
|   return self; | ||||
| } | ||||
|  | ||||
| THStorage* THStorage_(newWithSize3)(real data0, real data1, real data2) | ||||
| { | ||||
|   THStorage *self = THStorage_(newWithSize)(3); | ||||
|   self->data[0] = data0; | ||||
|   self->data[1] = data1; | ||||
|   self->data[2] = data2; | ||||
|   return self; | ||||
| } | ||||
|  | ||||
| THStorage* THStorage_(newWithSize4)(real data0, real data1, real data2, real data3) | ||||
| { | ||||
|   THStorage *self = THStorage_(newWithSize)(4); | ||||
|   self->data[0] = data0; | ||||
|   self->data[1] = data1; | ||||
|   self->data[2] = data2; | ||||
|   self->data[3] = data3; | ||||
|   return self; | ||||
| } | ||||
|  | ||||
| #if defined(_WIN32) || defined(HAVE_MMAP) | ||||
|  | ||||
| THStorage* THStorage_(newWithMapping)(const char *fileName, int isShared) | ||||
| { | ||||
|   THStorage *storage = THAlloc(sizeof(THStorage)); | ||||
|   long size; | ||||
|  | ||||
|   /* check size */ | ||||
|   FILE *f = fopen(fileName, "rb"); | ||||
|   if(f == NULL) | ||||
|     THError("unable to open file <%s> for mapping (read-only mode)", fileName); | ||||
|   fseek(f, 0, SEEK_END); | ||||
|   size = ftell(f); | ||||
|   fclose(f); | ||||
|   size /= sizeof(real); | ||||
|  | ||||
| #ifdef _WIN32 | ||||
|   { | ||||
|     HANDLE hfile; | ||||
|     HANDLE hmfile; | ||||
|     DWORD size_hi, size_lo; | ||||
|  | ||||
|     /* open file */ | ||||
|     if(isShared) | ||||
|     { | ||||
|       hfile = CreateFileA(fileName, GENERIC_READ|GENERIC_WRITE, FILE_SHARE_WRITE|FILE_SHARE_READ, 0, OPEN_EXISTING, FILE_ATTRIBUTE_NORMAL, 0); | ||||
|       if (hfile == INVALID_HANDLE_VALUE) | ||||
|         THError("could not open file <%s> in read-write mode", fileName); | ||||
|     } | ||||
|     else | ||||
|     { | ||||
|       hfile = CreateFileA(fileName, GENERIC_READ, FILE_SHARE_WRITE|FILE_SHARE_READ, 0, OPEN_EXISTING, FILE_ATTRIBUTE_NORMAL, 0); | ||||
|       if (hfile == INVALID_HANDLE_VALUE) | ||||
|         THError("could not open file <%s> in read-only mode", fileName); | ||||
|     } | ||||
|  | ||||
| #if SIZEOF_SIZE_T > 4 | ||||
|     size_hi = (DWORD)((size*sizeof(real)) >> 32); | ||||
|     size_lo = (DWORD)((size*sizeof(real)) & 0xFFFFFFFF); | ||||
| #else | ||||
|     size_hi = 0; | ||||
|     size_lo = (DWORD)(size*sizeof(real)); | ||||
| #endif | ||||
|  | ||||
|     /* get map handle */ | ||||
|     if(isShared) | ||||
|     { | ||||
|       if( (hmfile = CreateFileMapping(hfile, NULL, PAGE_READWRITE, size_hi, size_lo, NULL)) == NULL ) | ||||
|         THError("could not create a map on file <%s>", fileName); | ||||
|     } | ||||
|     else | ||||
|     { | ||||
|       if( (hmfile = CreateFileMapping(hfile, NULL, PAGE_WRITECOPY, size_hi, size_lo, NULL)) == NULL ) | ||||
|         THError("could not create a map on file <%s>", fileName); | ||||
|     } | ||||
|  | ||||
|     /* map the stuff */ | ||||
|     storage = THStorage_(new)(); | ||||
|     if(isShared) | ||||
|       storage->data = MapViewOfFile(hmfile, FILE_MAP_ALL_ACCESS, 0, 0, 0); | ||||
|     else | ||||
|       storage->data = MapViewOfFile(hmfile, FILE_MAP_COPY, 0, 0, 0); | ||||
|        | ||||
|     storage->size = size; | ||||
|     if(storage->data == NULL) | ||||
|     { | ||||
|       THStorage_(free)(storage); | ||||
|       THError("memory map failed on file <%s>", fileName); | ||||
|     } | ||||
|     CloseHandle(hfile);  | ||||
|     CloseHandle(hmfile);  | ||||
|   } | ||||
| #else | ||||
|   { | ||||
|     /* open file */ | ||||
|     int fd; | ||||
|     if(isShared) | ||||
|     { | ||||
|       fd = open(fileName, O_RDWR); | ||||
|       if(fd == -1) | ||||
|         THError("unable to open file <%s> in read-write mode", fileName); | ||||
|     } | ||||
|     else | ||||
|     { | ||||
|       fd = open(fileName, O_RDONLY); | ||||
|       if(fd == -1) | ||||
|         THError("unable to open file <%s> in read-only mode", fileName); | ||||
|     } | ||||
|      | ||||
|     /* map it */ | ||||
|     storage = THStorage_(new)(); | ||||
|     if(isShared) | ||||
|       storage->data = mmap(NULL, size*sizeof(real), PROT_READ|PROT_WRITE, MAP_SHARED, fd, 0); | ||||
|     else | ||||
|       storage->data = mmap(NULL, size*sizeof(real), PROT_READ|PROT_WRITE, MAP_PRIVATE, fd, 0); | ||||
|  | ||||
|     storage->size = size; | ||||
|     if(storage->data == MAP_FAILED) | ||||
|     { | ||||
|       storage->data = NULL; /* let's be sure it is NULL before calling free() */ | ||||
|       THStorage_(free)(storage); | ||||
|       THError("memory map failed on file <%s>", fileName); | ||||
|     } | ||||
|     close (fd); | ||||
|   } | ||||
| #endif | ||||
|  | ||||
|   storage->refcount = 1; | ||||
|   storage->flag = TH_STORAGE_REFCOUNTED | TH_STORAGE_MAPPED | TH_STORAGE_FREEMEM;; | ||||
|   return storage; | ||||
| } | ||||
|  | ||||
| #else | ||||
|  | ||||
| THStorage* THStorage_(newWithMapping)(const char *fileName, int isShared) | ||||
| { | ||||
|   THError("Mapped file Storages are not supported on your system"); | ||||
| } | ||||
|  | ||||
| #endif | ||||
|  | ||||
| void THStorage_(setFlag)(THStorage *storage, const char flag) | ||||
| { | ||||
|   storage->flag |= flag; | ||||
| } | ||||
|  | ||||
| void THStorage_(clearFlag)(THStorage *storage, const char flag) | ||||
| { | ||||
|   storage->flag &= ~flag; | ||||
| } | ||||
|  | ||||
| void THStorage_(retain)(THStorage *storage) | ||||
| { | ||||
|   if(storage && (storage->flag & TH_STORAGE_REFCOUNTED)) | ||||
|     ++storage->refcount; | ||||
| } | ||||
|  | ||||
| void THStorage_(free)(THStorage *storage) | ||||
| { | ||||
|   if(!storage) | ||||
|     return; | ||||
|  | ||||
|   if((storage->flag & TH_STORAGE_REFCOUNTED) && (storage->refcount > 0)) | ||||
|   { | ||||
|     if(--storage->refcount == 0) | ||||
|     { | ||||
|       if(storage->flag & TH_STORAGE_FREEMEM) | ||||
|       { | ||||
| #if defined(_WIN32) || defined(HAVE_MMAP) | ||||
|         if(storage->flag & TH_STORAGE_MAPPED) | ||||
|         { | ||||
| #ifdef _WIN32 | ||||
|           if(!UnmapViewOfFile((LPINT)storage->data)) | ||||
| #else | ||||
|             if (munmap(storage->data, storage->size*sizeof(real))) | ||||
| #endif | ||||
|               THError("could not unmap the shared memory file"); | ||||
|         } | ||||
|         else | ||||
| #endif | ||||
|           THFree(storage->data); | ||||
|       } | ||||
|       THFree(storage); | ||||
|     } | ||||
|   } | ||||
| } | ||||
|  | ||||
| THStorage* THStorage_(newWithData)(real *data, long size) | ||||
| { | ||||
|   THStorage *storage = THAlloc(sizeof(THStorage)); | ||||
|   storage->data = data; | ||||
|   storage->size = size; | ||||
|   storage->refcount = 1; | ||||
|   storage->flag = TH_STORAGE_REFCOUNTED | TH_STORAGE_RESIZABLE | TH_STORAGE_FREEMEM; | ||||
|   return storage; | ||||
| } | ||||
|  | ||||
| void THStorage_(resize)(THStorage *storage, long size) | ||||
| { | ||||
|   if(storage->flag & TH_STORAGE_RESIZABLE) | ||||
|   { | ||||
|     storage->data = THRealloc(storage->data, sizeof(real)*size); | ||||
|     storage->size = size; | ||||
|   } | ||||
| } | ||||
|  | ||||
| void THStorage_(fill)(THStorage *storage, real value) | ||||
| { | ||||
|   long i; | ||||
|   for(i = 0; i < storage->size; i++) | ||||
|     storage->data[i] = value; | ||||
| } | ||||
|  | ||||
| void THStorage_(set)(THStorage *self, long idx, real value) | ||||
| { | ||||
|   THArgCheck((idx >= 0) && (idx < self->size), 2, "out of bounds"); | ||||
|   self->data[idx] = value; | ||||
| } | ||||
|  | ||||
| real THStorage_(get)(THStorage *self, long idx) | ||||
| { | ||||
|   THArgCheck((idx >= 0) && (idx < self->size), 2, "out of bounds"); | ||||
|   return self->data[idx]; | ||||
| } | ||||
|  | ||||
| #endif | ||||
							
								
								
									
										59
									
								
								lib/TH/generic/THStorage.h
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										59
									
								
								lib/TH/generic/THStorage.h
									
									
									
									
									
										Normal file
									
								
							| @ -0,0 +1,59 @@ | ||||
| #ifndef TH_GENERIC_FILE | ||||
| #define TH_GENERIC_FILE "generic/THStorage.h" | ||||
| #else | ||||
|  | ||||
| /* on pourrait avoir un liste chainee | ||||
|    qui initialise math, lab structures (or more). | ||||
|    mouais -- complique. | ||||
|  | ||||
|    Pb: THMapStorage is kind of a class | ||||
|    THLab_()... comment je m'en sors? | ||||
|     | ||||
|    en template, faudrait que je les instancie toutes!!! oh boy! | ||||
|    Et comment je sais que c'est pour Cuda? Le type float est le meme dans les <> | ||||
|     | ||||
|    au bout du compte, ca serait sur des pointeurs float/double... etc... = facile. | ||||
|    primitives?? | ||||
|  */ | ||||
|  | ||||
| #define TH_STORAGE_REFCOUNTED 1 | ||||
| #define TH_STORAGE_RESIZABLE  2 | ||||
| #define TH_STORAGE_MAPPED     4 | ||||
| #define TH_STORAGE_FREEMEM    8 | ||||
|  | ||||
| typedef struct THStorage | ||||
| { | ||||
|     real *data; | ||||
|     long size; | ||||
|     int refcount; | ||||
|     char flag; | ||||
|  | ||||
| } THStorage; | ||||
|  | ||||
| TH_API real* THStorage_(data)(THStorage*); | ||||
| TH_API long THStorage_(size)(THStorage*); | ||||
|  | ||||
| /* slow access -- checks everything */ | ||||
| TH_API void THStorage_(set)(THStorage*, long, real); | ||||
| TH_API real THStorage_(get)(THStorage*, long); | ||||
|  | ||||
| TH_API THStorage* THStorage_(new)(void); | ||||
| TH_API THStorage* THStorage_(newWithSize)(long size); | ||||
| TH_API THStorage* THStorage_(newWithSize1)(real); | ||||
| TH_API THStorage* THStorage_(newWithSize2)(real, real); | ||||
| TH_API THStorage* THStorage_(newWithSize3)(real, real, real); | ||||
| TH_API THStorage* THStorage_(newWithSize4)(real, real, real, real); | ||||
| TH_API THStorage* THStorage_(newWithMapping)(const char *fileName, int isShared); | ||||
| TH_API THStorage* THStorage_(newWithData)(real *data, long size); | ||||
|  | ||||
| /* should not differ with API */ | ||||
| TH_API void THStorage_(setFlag)(THStorage *storage, const char flag); | ||||
| TH_API void THStorage_(clearFlag)(THStorage *storage, const char flag); | ||||
| TH_API void THStorage_(retain)(THStorage *storage); | ||||
|  | ||||
| /* might differ with other API (like CUDA) */ | ||||
| TH_API void THStorage_(free)(THStorage *storage); | ||||
| TH_API void THStorage_(resize)(THStorage *storage, long size); | ||||
| TH_API void THStorage_(fill)(THStorage *storage, real value); | ||||
|  | ||||
| #endif | ||||
							
								
								
									
										36
									
								
								lib/TH/generic/THStorageCopy.c
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										36
									
								
								lib/TH/generic/THStorageCopy.c
									
									
									
									
									
										Normal file
									
								
							| @ -0,0 +1,36 @@ | ||||
| #ifndef TH_GENERIC_FILE | ||||
| #define TH_GENERIC_FILE "generic/THStorageCopy.c" | ||||
| #else | ||||
|  | ||||
| void THStorage_(rawCopy)(THStorage *storage, real *src) | ||||
| { | ||||
|   long i; | ||||
|   for(i = 0; i < storage->size; i++) | ||||
|     storage->data[i] = src[i]; | ||||
| } | ||||
|  | ||||
| void THStorage_(copy)(THStorage *storage, THStorage *src) | ||||
| { | ||||
|   THArgCheck(storage->size == src->size, 2, "size mismatch"); | ||||
|   THStorage_(rawCopy)(storage, src->data); | ||||
| } | ||||
|  | ||||
|  | ||||
| #define IMPLEMENT_THStorage_COPY(TYPENAMESRC) \ | ||||
| void THStorage_(copy##TYPENAMESRC)(THStorage *storage, TH##TYPENAMESRC##Storage *src) \ | ||||
| { \ | ||||
|   long i; \ | ||||
|   THArgCheck(storage->size == src->size, 2, "size mismatch"); \ | ||||
|   for(i = 0; i < storage->size; i++) \ | ||||
|     storage->data[i] = (real)src->data[i]; \ | ||||
| } | ||||
|  | ||||
| IMPLEMENT_THStorage_COPY(Byte) | ||||
| IMPLEMENT_THStorage_COPY(Char) | ||||
| IMPLEMENT_THStorage_COPY(Short) | ||||
| IMPLEMENT_THStorage_COPY(Int) | ||||
| IMPLEMENT_THStorage_COPY(Long) | ||||
| IMPLEMENT_THStorage_COPY(Float) | ||||
| IMPLEMENT_THStorage_COPY(Double) | ||||
|  | ||||
| #endif | ||||
							
								
								
									
										17
									
								
								lib/TH/generic/THStorageCopy.h
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										17
									
								
								lib/TH/generic/THStorageCopy.h
									
									
									
									
									
										Normal file
									
								
							| @ -0,0 +1,17 @@ | ||||
| #ifndef TH_GENERIC_FILE | ||||
| #define TH_GENERIC_FILE "generic/THStorageCopy.h" | ||||
| #else | ||||
|  | ||||
| /* Support for copy between different Storage types */ | ||||
|  | ||||
| TH_API void THStorage_(rawCopy)(THStorage *storage, real *src); | ||||
| TH_API void THStorage_(copy)(THStorage *storage, THStorage *src); | ||||
| TH_API void THStorage_(copyByte)(THStorage *storage, struct THByteStorage *src); | ||||
| TH_API void THStorage_(copyChar)(THStorage *storage, struct THCharStorage *src); | ||||
| TH_API void THStorage_(copyShort)(THStorage *storage, struct THShortStorage *src); | ||||
| TH_API void THStorage_(copyInt)(THStorage *storage, struct THIntStorage *src); | ||||
| TH_API void THStorage_(copyLong)(THStorage *storage, struct THLongStorage *src); | ||||
| TH_API void THStorage_(copyFloat)(THStorage *storage, struct THFloatStorage *src); | ||||
| TH_API void THStorage_(copyDouble)(THStorage *storage, struct THDoubleStorage *src); | ||||
|  | ||||
| #endif | ||||
							
								
								
									
										728
									
								
								lib/TH/generic/THTensor.c
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										728
									
								
								lib/TH/generic/THTensor.c
									
									
									
									
									
										Normal file
									
								
							| @ -0,0 +1,728 @@ | ||||
| #ifndef TH_GENERIC_FILE | ||||
| #define TH_GENERIC_FILE "generic/THTensor.c" | ||||
| #else | ||||
|  | ||||
| /**** access methods ****/ | ||||
| THStorage *THTensor_(storage)(THTensor *self) | ||||
| { | ||||
|   return self->storage; | ||||
| } | ||||
|  | ||||
| long THTensor_(storageOffset)(THTensor *self) | ||||
| { | ||||
|   return self->storageOffset; | ||||
| } | ||||
|  | ||||
| int THTensor_(nDimension)(THTensor *self) | ||||
| { | ||||
|   return self->nDimension; | ||||
| } | ||||
|  | ||||
| long THTensor_(size)(THTensor *self, int dim) | ||||
| { | ||||
|   THArgCheck((dim >= 0) && (dim < self->nDimension), 2, "out of range"); | ||||
|   return self->size[dim]; | ||||
| } | ||||
|  | ||||
| long THTensor_(stride)(THTensor *self, int dim) | ||||
| { | ||||
|   THArgCheck((dim >= 0) && (dim < self->nDimension), 2, "out of range"); | ||||
|   return self->stride[dim]; | ||||
| } | ||||
|  | ||||
| THLongStorage *THTensor_(newSizeOf)(THTensor *self) | ||||
| { | ||||
|   THLongStorage *size = THLongStorage_newWithSize(self->nDimension); | ||||
|   THLongStorage_rawCopy(size, self->size); | ||||
|   return size; | ||||
| } | ||||
|  | ||||
| THLongStorage *THTensor_(newStrideOf)(THTensor *self) | ||||
| { | ||||
|   THLongStorage *stride = THLongStorage_newWithSize(self->nDimension); | ||||
|   THLongStorage_rawCopy(stride, self->stride); | ||||
|   return stride; | ||||
| } | ||||
|  | ||||
| real *THTensor_(data)(THTensor *self) | ||||
| { | ||||
|   if(self->storage) | ||||
|     return (self->storage->data+self->storageOffset); | ||||
|   else | ||||
|     return NULL; | ||||
| } | ||||
|  | ||||
| void THTensor_(setFlag)(THTensor *self, const char flag) | ||||
| { | ||||
|   self->flag |= flag; | ||||
| } | ||||
|  | ||||
| void THTensor_(clearFlag)(THTensor *self, const char flag) | ||||
| { | ||||
|   self->flag &= ~flag; | ||||
| } | ||||
|  | ||||
| /**** creation methods ****/ | ||||
|  | ||||
| static void THTensor_(rawInit)(THTensor *self); | ||||
| static void THTensor_(rawSet)(THTensor *self, THStorage *storage, long storageOffset, int nDimension, long *size, long *stride); | ||||
| static void THTensor_(rawResize)(THTensor *self, int nDimension, long *size, long *stride); | ||||
|  | ||||
|  | ||||
| /* Empty init */ | ||||
| THTensor *THTensor_(new)(void) | ||||
| { | ||||
|   THTensor *self = THAlloc(sizeof(THTensor)); | ||||
|   THTensor_(rawInit)(self); | ||||
|   return self; | ||||
| } | ||||
|  | ||||
| /* Pointer-copy init */ | ||||
| THTensor *THTensor_(newWithTensor)(THTensor *tensor) | ||||
| { | ||||
|   THTensor *self = THAlloc(sizeof(THTensor)); | ||||
|   THTensor_(rawInit)(self); | ||||
|   THTensor_(rawSet)(self, | ||||
|                     tensor->storage, | ||||
|                     tensor->storageOffset, | ||||
|                     tensor->nDimension, | ||||
|                     tensor->size, | ||||
|                     tensor->stride); | ||||
|   return self; | ||||
| } | ||||
|  | ||||
| /* Storage init */ | ||||
| THTensor *THTensor_(newWithStorage)(THStorage *storage, long storageOffset, THLongStorage *size, THLongStorage *stride) | ||||
| {   | ||||
|   THTensor *self = THAlloc(sizeof(THTensor)); | ||||
|   if(size && stride) | ||||
|     THArgCheck(size->size == stride->size, 4, "inconsistent size"); | ||||
|  | ||||
|   THTensor_(rawInit)(self);   | ||||
|   THTensor_(rawSet)(self, | ||||
|                     storage, | ||||
|                     storageOffset, | ||||
|                     (size ? size->size : (stride ? stride->size : 0)), | ||||
|                     (size ? size->data : NULL), | ||||
|                     (stride ? stride->data : NULL)); | ||||
|  | ||||
|   return self; | ||||
| } | ||||
| THTensor *THTensor_(newWithStorage1d)(THStorage *storage, long storageOffset, | ||||
|                                long size0, long stride0) | ||||
| { | ||||
|   return THTensor_(newWithStorage4d)(storage, storageOffset, size0, stride0, -1, -1,  -1, -1,  -1, -1); | ||||
| } | ||||
|  | ||||
| THTensor *THTensor_(newWithStorage2d)(THStorage *storage, long storageOffset, | ||||
|                                long size0, long stride0, | ||||
|                                long size1, long stride1) | ||||
| { | ||||
|   return THTensor_(newWithStorage4d)(storage, storageOffset, size0, stride0, size1, stride1,  -1, -1,  -1, -1); | ||||
| } | ||||
|  | ||||
| THTensor *THTensor_(newWithStorage3d)(THStorage *storage, long storageOffset, | ||||
|                                long size0, long stride0, | ||||
|                                long size1, long stride1, | ||||
|                                long size2, long stride2) | ||||
| { | ||||
|   return THTensor_(newWithStorage4d)(storage, storageOffset, size0, stride0, size1, stride1,  size2, stride2,  -1, -1); | ||||
| } | ||||
|  | ||||
| THTensor *THTensor_(newWithStorage4d)(THStorage *storage, long storageOffset, | ||||
|                                long size0, long stride0, | ||||
|                                long size1, long stride1, | ||||
|                                long size2, long stride2, | ||||
|                                long size3, long stride3) | ||||
| { | ||||
|   long size[4] = {size0, size1, size2, size3}; | ||||
|   long stride[4] = {stride0, stride1, stride2, stride3}; | ||||
|  | ||||
|   THTensor *self = THAlloc(sizeof(THTensor)); | ||||
|   THTensor_(rawInit)(self);   | ||||
|   THTensor_(rawSet)(self, storage, storageOffset, 4, size, stride); | ||||
|  | ||||
|   return self; | ||||
| } | ||||
|  | ||||
| THTensor *THTensor_(newWithSize)(THLongStorage *size, THLongStorage *stride) | ||||
| { | ||||
|   return THTensor_(newWithStorage)(NULL, 0, size, stride); | ||||
| } | ||||
|  | ||||
| THTensor *THTensor_(newWithSize1d)(long size0) | ||||
| { | ||||
|   return THTensor_(newWithSize4d)(size0, -1, -1, -1); | ||||
| } | ||||
|  | ||||
| THTensor *THTensor_(newWithSize2d)(long size0, long size1) | ||||
| { | ||||
|   return THTensor_(newWithSize4d)(size0, size1, -1, -1); | ||||
| } | ||||
|  | ||||
| THTensor *THTensor_(newWithSize3d)(long size0, long size1, long size2) | ||||
| { | ||||
|   return THTensor_(newWithSize4d)(size0, size1, size2, -1); | ||||
| } | ||||
|  | ||||
| THTensor *THTensor_(newWithSize4d)(long size0, long size1, long size2, long size3) | ||||
| { | ||||
|   long size[4] = {size0, size1, size2, size3}; | ||||
|  | ||||
|   THTensor *self = THAlloc(sizeof(THTensor)); | ||||
|   THTensor_(rawInit)(self);   | ||||
|   THTensor_(rawResize)(self, 4, size, NULL); | ||||
|  | ||||
|   return self; | ||||
| } | ||||
|  | ||||
| THTensor *THTensor_(newClone)(THTensor *self) | ||||
| { | ||||
|   THTensor *tensor = THTensor_(new)(); | ||||
|   THTensor_(resizeAs)(tensor, self); | ||||
|   THTensor_(copy)(tensor, self); | ||||
|   return tensor; | ||||
| } | ||||
|  | ||||
| THTensor *THTensor_(newContiguous)(THTensor *self) | ||||
| { | ||||
|   if(!THTensor_(isContiguous)(self)) | ||||
|     return THTensor_(newClone)(self); | ||||
|   else | ||||
|   { | ||||
|     THTensor_(retain)(self); | ||||
|     return self; | ||||
|   } | ||||
| } | ||||
|  | ||||
| THTensor *THTensor_(newSelect)(THTensor *tensor, int dimension_, long sliceIndex_) | ||||
| { | ||||
|   THTensor *self = THTensor_(newWithTensor)(tensor); | ||||
|   THTensor_(select)(self, NULL, dimension_, sliceIndex_); | ||||
|   return self; | ||||
| } | ||||
|  | ||||
| THTensor *THTensor_(newNarrow)(THTensor *tensor, int dimension_, long firstIndex_, long size_) | ||||
| { | ||||
|   THTensor *self = THTensor_(newWithTensor)(tensor); | ||||
|   THTensor_(narrow)(self, NULL, dimension_, firstIndex_, size_); | ||||
|   return self; | ||||
| } | ||||
|  | ||||
| THTensor *THTensor_(newTranspose)(THTensor *tensor, int dimension1_, int dimension2_) | ||||
| { | ||||
|   THTensor *self = THTensor_(newWithTensor)(tensor); | ||||
|   THTensor_(transpose)(self, NULL, dimension1_, dimension2_); | ||||
|   return self; | ||||
| } | ||||
|  | ||||
| THTensor *THTensor_(newUnfold)(THTensor *tensor, int dimension_, long size_, long step_) | ||||
| { | ||||
|   THTensor *self = THTensor_(newWithTensor)(tensor); | ||||
|   THTensor_(unfold)(self, NULL, dimension_, size_, step_); | ||||
|   return self; | ||||
| } | ||||
|  | ||||
| /* Resize */ | ||||
| void THTensor_(resize)(THTensor *self, THLongStorage *size, THLongStorage *stride) | ||||
| { | ||||
|   THArgCheck(size != NULL, 2, "invalid size"); | ||||
|   if(stride) | ||||
|     THArgCheck(stride->size == size->size, 3, "invalid stride"); | ||||
|  | ||||
|   THTensor_(rawResize)(self, size->size, size->data, (stride ? stride->data : NULL)); | ||||
| } | ||||
|  | ||||
| void THTensor_(resizeAs)(THTensor *self, THTensor *src) | ||||
| { | ||||
|   int isSame = 0; | ||||
|   int d; | ||||
|   if(self->nDimension == src->nDimension) | ||||
|   { | ||||
|     isSame = 1; | ||||
|     for(d = 0; d < self->nDimension; d++) | ||||
|     { | ||||
|       if(self->size[d] != src->size[d]) | ||||
|       { | ||||
|         isSame = 0; | ||||
|         break; | ||||
|       } | ||||
|     } | ||||
|   } | ||||
|  | ||||
|   if(!isSame) | ||||
|     THTensor_(rawResize)(self, src->nDimension, src->size, NULL); | ||||
| } | ||||
|  | ||||
| void THTensor_(resize1d)(THTensor *tensor, long size0) | ||||
| { | ||||
|   THTensor_(resize4d)(tensor, size0, -1, -1, -1); | ||||
| } | ||||
|  | ||||
| void THTensor_(resize2d)(THTensor *tensor, long size0, long size1) | ||||
| { | ||||
|   THTensor_(resize4d)(tensor, size0, size1, -1, -1); | ||||
| } | ||||
|  | ||||
| void THTensor_(resize3d)(THTensor *tensor, long size0, long size1, long size2) | ||||
| { | ||||
|   THTensor_(resize4d)(tensor, size0, size1, size2, -1); | ||||
| } | ||||
|  | ||||
| void THTensor_(resize4d)(THTensor *self, long size0, long size1, long size2, long size3) | ||||
| { | ||||
|   long size[4] = {size0, size1, size2, size3}; | ||||
|  | ||||
|   THTensor_(rawResize)(self, 4, size, NULL); | ||||
| } | ||||
|  | ||||
| void THTensor_(resize5d)(THTensor *self, long size0, long size1, long size2, long size3, long size4) | ||||
| { | ||||
|     long size[5] = {size0, size1, size2, size3, size4}; | ||||
|  | ||||
|   THTensor_(rawResize)(self, 5, size, NULL); | ||||
| } | ||||
|  | ||||
| void THTensor_(set)(THTensor *self, THTensor *src) | ||||
| { | ||||
|   if(self != src) | ||||
|     THTensor_(rawSet)(self, | ||||
|                       src->storage, | ||||
|                       src->storageOffset, | ||||
|                       src->nDimension, | ||||
|                       src->size, | ||||
|                       src->stride); | ||||
| } | ||||
|  | ||||
| void THTensor_(setStorage)(THTensor *self, THStorage *storage_, long storageOffset_, THLongStorage *size_, THLongStorage *stride_) | ||||
| { | ||||
|   if(size_ && stride_) | ||||
|     THArgCheck(size_->size == stride_->size, 5, "inconsistent size/stride sizes"); | ||||
|    | ||||
|   THTensor_(rawSet)(self,  | ||||
|                     storage_, | ||||
|                     storageOffset_, | ||||
|                     (size_ ? size_->size : (stride_ ? stride_->size : 0)), | ||||
|                     (size_ ? size_->data : NULL), | ||||
|                     (stride_ ? stride_->data : NULL)); | ||||
| } | ||||
|  | ||||
| void THTensor_(setStorage1d)(THTensor *self, THStorage *storage_, long storageOffset_, | ||||
|                              long size0_, long stride0_) | ||||
| { | ||||
|   THTensor_(setStorage4d)(self, storage_, storageOffset_, | ||||
|                           size0_, stride0_, | ||||
|                           -1, -1, | ||||
|                           -1, -1, | ||||
|                           -1, -1); | ||||
| } | ||||
|  | ||||
| void THTensor_(setStorage2d)(THTensor *self, THStorage *storage_, long storageOffset_, | ||||
|                              long size0_, long stride0_, | ||||
|                              long size1_, long stride1_) | ||||
| { | ||||
|   THTensor_(setStorage4d)(self, storage_, storageOffset_, | ||||
|                           size0_, stride0_, | ||||
|                           size1_, stride1_, | ||||
|                           -1, -1, | ||||
|                           -1, -1); | ||||
| } | ||||
|  | ||||
| void THTensor_(setStorage3d)(THTensor *self, THStorage *storage_, long storageOffset_, | ||||
|                              long size0_, long stride0_, | ||||
|                              long size1_, long stride1_, | ||||
|                              long size2_, long stride2_) | ||||
| { | ||||
|   THTensor_(setStorage4d)(self, storage_, storageOffset_, | ||||
|                           size0_, stride0_, | ||||
|                           size1_, stride1_, | ||||
|                           size2_, stride2_, | ||||
|                           -1, -1); | ||||
| } | ||||
|  | ||||
| void THTensor_(setStorage4d)(THTensor *self, THStorage *storage_, long storageOffset_, | ||||
|                              long size0_, long stride0_, | ||||
|                              long size1_, long stride1_, | ||||
|                              long size2_, long stride2_, | ||||
|                              long size3_, long stride3_) | ||||
| { | ||||
|  | ||||
|   long size[4] = {size0_, size1_, size2_, size3_}; | ||||
|   long stride[4] = {stride0_, stride1_, stride2_, stride3_}; | ||||
|  | ||||
|   THTensor_(rawSet)(self, storage_, storageOffset_, 4, size, stride);   | ||||
| } | ||||
|  | ||||
|  | ||||
| void THTensor_(narrow)(THTensor *self, THTensor *src, int dimension, long firstIndex, long size) | ||||
| { | ||||
|   if(!src) | ||||
|     src = self; | ||||
|  | ||||
|   THArgCheck( (dimension >= 0) && (dimension < src->nDimension), 3, "out of range"); | ||||
|   THArgCheck( (firstIndex >= 0) && (firstIndex < src->size[dimension]), 4, "out of range"); | ||||
|   THArgCheck( (size > 0) && (firstIndex+size <= src->size[dimension]), 5, "out of range"); | ||||
|    | ||||
|   THTensor_(set)(self, src); | ||||
|  | ||||
|   if(firstIndex > 0) | ||||
|     self->storageOffset += firstIndex*self->stride[dimension]; | ||||
|    | ||||
|   self->size[dimension] = size; | ||||
| } | ||||
|  | ||||
| void THTensor_(select)(THTensor *self, THTensor *src, int dimension, long sliceIndex) | ||||
| { | ||||
|   int d; | ||||
|  | ||||
|   if(!src) | ||||
|     src = self; | ||||
|  | ||||
|   THArgCheck(src->nDimension > 1, 1, "cannot select on a vector"); | ||||
|   THArgCheck((dimension >= 0) && (dimension < src->nDimension), 3, "out of range"); | ||||
|   THArgCheck((sliceIndex >= 0) && (sliceIndex < src->size[dimension]), 4, "out of range"); | ||||
|  | ||||
|   THTensor_(set)(self, src); | ||||
|   THTensor_(narrow)(self, NULL, dimension, sliceIndex, 1); | ||||
|   for(d = dimension; d < self->nDimension-1; d++) | ||||
|   { | ||||
|     self->size[d] = self->size[d+1]; | ||||
|     self->stride[d] = self->stride[d+1]; | ||||
|   } | ||||
|   self->nDimension--; | ||||
| } | ||||
|  | ||||
| void THTensor_(transpose)(THTensor *self, THTensor *src, int dimension1, int dimension2) | ||||
| { | ||||
|   long z; | ||||
|  | ||||
|   if(!src) | ||||
|     src = self; | ||||
|  | ||||
|   THArgCheck( (dimension1 >= 0) && (dimension1 < src->nDimension), 1, "out of range"); | ||||
|   THArgCheck( (dimension2 >= 0) && (dimension2 < src->nDimension), 2, "out of range"); | ||||
|  | ||||
|   THTensor_(set)(self, src); | ||||
|  | ||||
|   if(dimension1 == dimension2) | ||||
| 	  return; | ||||
|   | ||||
|   z = self->stride[dimension1]; | ||||
|   self->stride[dimension1] = self->stride[dimension2]; | ||||
|   self->stride[dimension2] = z; | ||||
|   z = self->size[dimension1]; | ||||
|   self->size[dimension1] = self->size[dimension2]; | ||||
|   self->size[dimension2] = z; | ||||
| } | ||||
|  | ||||
| void THTensor_(unfold)(THTensor *self, THTensor *src, int dimension, long size, long step) | ||||
| { | ||||
|   long *newSize; | ||||
|   long *newStride; | ||||
|   int d; | ||||
|  | ||||
|   if(!src) | ||||
|     src = self; | ||||
|  | ||||
|   THArgCheck( (src->nDimension > 0), 1, "cannot unfold an empty tensor"); | ||||
|   THArgCheck(dimension < src->nDimension, 2, "out of range"); | ||||
|   THArgCheck(size <= src->size[dimension], 3, "out of range"); | ||||
|   THArgCheck(step > 0, 4, "invalid step"); | ||||
|  | ||||
|   THTensor_(set)(self, src); | ||||
|  | ||||
|   newSize = THAlloc(sizeof(long)*(self->nDimension+1)); | ||||
|   newStride = THAlloc(sizeof(long)*(self->nDimension+1)); | ||||
|  | ||||
|   newSize[self->nDimension] = size; | ||||
|   newStride[self->nDimension] = self->stride[dimension]; | ||||
|   for(d = 0; d < self->nDimension; d++) | ||||
|   { | ||||
|     if(d == dimension) | ||||
|     { | ||||
|       newSize[d] = (self->size[d] - size) / step + 1; | ||||
|       newStride[d] = step*self->stride[d]; | ||||
|     } | ||||
|     else | ||||
|     { | ||||
|       newSize[d] = self->size[d]; | ||||
|       newStride[d] = self->stride[d]; | ||||
|     } | ||||
|   } | ||||
|  | ||||
|   THFree(self->size); | ||||
|   THFree(self->stride); | ||||
|  | ||||
|   self->size = newSize; | ||||
|   self->stride = newStride; | ||||
|   self->nDimension++; | ||||
| } | ||||
|  | ||||
| /* we have to handle the case where the result is a number */ | ||||
| void THTensor_(squeeze)(THTensor *self, THTensor *src) | ||||
| { | ||||
|   int ndim = 0; | ||||
|   int d; | ||||
|  | ||||
|   if(!src) | ||||
|     src = self; | ||||
|  | ||||
|   THTensor_(set)(self, src); | ||||
|  | ||||
|   for(d = 0; d < src->nDimension; d++) | ||||
|   { | ||||
|     if(src->size[d] != 1) | ||||
|     { | ||||
|       if(d != ndim) | ||||
|       { | ||||
|         self->size[ndim] = src->size[d]; | ||||
|         self->stride[ndim] = src->stride[d]; | ||||
|       } | ||||
|       ndim++; | ||||
|     } | ||||
|   } | ||||
|  | ||||
|   /* right now, we do not handle 0-dimension tensors */ | ||||
|   if(ndim == 0 && src->nDimension > 0) | ||||
|   { | ||||
|     self->size[0] = 1; | ||||
|     self->stride[0] = 1; | ||||
|     ndim = 1; | ||||
|   } | ||||
|   self->nDimension = ndim; | ||||
| } | ||||
|  | ||||
| void THTensor_(squeeze1d)(THTensor *self, THTensor *src, int dimension) | ||||
| { | ||||
|   int d; | ||||
|  | ||||
|   if(!src) | ||||
|     src = self; | ||||
|  | ||||
|   THArgCheck(dimension < src->nDimension, 3, "dimension out of range"); | ||||
|  | ||||
|   THTensor_(set)(self, src); | ||||
|  | ||||
|   if(src->size[dimension] == 1 && src->nDimension > 1) | ||||
|   { | ||||
|     for(d = dimension; d < self->nDimension-1; d++) | ||||
|     { | ||||
|       self->size[d] = self->size[d+1]; | ||||
|       self->stride[d] = self->stride[d+1]; | ||||
|     } | ||||
|     self->nDimension--; | ||||
|   } | ||||
| } | ||||
|  | ||||
| int THTensor_(isContiguous)(THTensor *self) | ||||
| { | ||||
|   long z = 1; | ||||
|   int d; | ||||
|   for(d = self->nDimension-1; d >= 0; d--) | ||||
|   { | ||||
|     if(self->size[d] != 1) | ||||
|     { | ||||
|       if(self->stride[d] == z) | ||||
|         z *= self->size[d]; | ||||
|       else | ||||
|         return 0; | ||||
|     } | ||||
|   } | ||||
|   return 1; | ||||
| } | ||||
|  | ||||
| long THTensor_(nElement)(THTensor *self) | ||||
| { | ||||
|   if(self->nDimension == 0) | ||||
|     return 0; | ||||
|   else | ||||
|   { | ||||
|     long nElement = 1; | ||||
|     int d; | ||||
|     for(d = 0; d < self->nDimension; d++) | ||||
|       nElement *= self->size[d]; | ||||
|     return nElement; | ||||
|   } | ||||
| } | ||||
|  | ||||
| void THTensor_(retain)(THTensor *self) | ||||
| { | ||||
|   if(self->flag & TH_TENSOR_REFCOUNTED) | ||||
|     ++self->refcount; | ||||
| } | ||||
|  | ||||
| void THTensor_(free)(THTensor *self) | ||||
| { | ||||
|   if(!self) | ||||
|     return; | ||||
|  | ||||
|   if(self->flag & TH_TENSOR_REFCOUNTED) | ||||
|   { | ||||
|     if(--self->refcount == 0) | ||||
|     { | ||||
|       THFree(self->size); | ||||
|       THFree(self->stride); | ||||
|       if(self->storage) | ||||
|         THStorage_(free)(self->storage); | ||||
|       THFree(self); | ||||
|     } | ||||
|   } | ||||
| } | ||||
|  | ||||
| void THTensor_(freeCopyTo)(THTensor *self, THTensor *dst) | ||||
| { | ||||
|   if(self != dst) | ||||
|     THTensor_(copy)(dst, self); | ||||
|  | ||||
|   THTensor_(free)(self); | ||||
| } | ||||
|  | ||||
| /*******************************************************************************/ | ||||
|  | ||||
| static void THTensor_(rawInit)(THTensor *self) | ||||
| { | ||||
|   self->refcount = 1; | ||||
|   self->storage = NULL; | ||||
|   self->storageOffset = 0; | ||||
|   self->size = NULL; | ||||
|   self->stride = NULL; | ||||
|   self->nDimension = 0;     | ||||
|   self->flag = TH_TENSOR_REFCOUNTED; | ||||
| } | ||||
|  | ||||
| static void THTensor_(rawSet)(THTensor *self, THStorage *storage, long storageOffset, int nDimension, long *size, long *stride) | ||||
| { | ||||
|   /* storage */ | ||||
|   if(self->storage != storage) | ||||
|   { | ||||
|     if(self->storage) | ||||
|       THStorage_(free)(self->storage); | ||||
|  | ||||
|     if(storage) | ||||
|     { | ||||
|       self->storage = storage; | ||||
|       THStorage_(retain)(self->storage); | ||||
|     } | ||||
|     else | ||||
|       self->storage = NULL; | ||||
|   } | ||||
|  | ||||
|   /* storageOffset */ | ||||
|   if(storageOffset < 0) | ||||
|     THError("Tensor: invalid storage offset"); | ||||
|   self->storageOffset = storageOffset; | ||||
|  | ||||
|   /* size and stride */ | ||||
|   THTensor_(rawResize)(self, nDimension, size, stride); | ||||
| } | ||||
|  | ||||
| static void THTensor_(rawResize)(THTensor *self, int nDimension, long *size, long *stride) | ||||
| { | ||||
|   int d; | ||||
|   int nDimension_; | ||||
|   long totalSize; | ||||
|  | ||||
|   nDimension_ = 0; | ||||
|   for(d = 0; d < nDimension; d++) | ||||
|   { | ||||
|     if(size[d] > 0) | ||||
|       nDimension_++; | ||||
|     else | ||||
|       break; | ||||
|   } | ||||
|   nDimension = nDimension_; | ||||
|  | ||||
|   if(nDimension > 0) | ||||
|   { | ||||
|     if(nDimension != self->nDimension) | ||||
|     { | ||||
|       self->size = THRealloc(self->size, sizeof(long)*nDimension); | ||||
|       self->stride = THRealloc(self->stride, sizeof(long)*nDimension); | ||||
|       self->nDimension = nDimension; | ||||
|     } | ||||
|    | ||||
|     totalSize = 1; | ||||
|     for(d = self->nDimension-1; d >= 0; d--) | ||||
|     { | ||||
|       self->size[d] = size[d]; | ||||
|       if(stride && (stride[d] >= 0) ) | ||||
|         self->stride[d] = stride[d]; | ||||
|       else | ||||
|       { | ||||
|         if(d == self->nDimension-1) | ||||
|           self->stride[d] = 1; | ||||
|         else | ||||
|           self->stride[d] = self->size[d+1]*self->stride[d+1]; | ||||
|       } | ||||
|       totalSize += (self->size[d]-1)*self->stride[d];       | ||||
|     } | ||||
|  | ||||
|     if(totalSize+self->storageOffset > 0) | ||||
|     { | ||||
|       if(!self->storage) | ||||
|         self->storage = THStorage_(new)();     | ||||
|       if(totalSize+self->storageOffset > self->storage->size) | ||||
|         THStorage_(resize)(self->storage, totalSize+self->storageOffset); | ||||
|     } | ||||
|   } | ||||
|   else | ||||
|     self->nDimension = 0; | ||||
| } | ||||
|  | ||||
| void THTensor_(set1d)(THTensor *tensor, long x0, real value) | ||||
| { | ||||
|   THArgCheck(tensor->nDimension == 1, 1, "tensor must have one dimension"); | ||||
|   THArgCheck( (x0 >= 0) && (x0 < tensor->size[0]), 2, "out of range"); | ||||
|   THStorage_(set)(tensor->storage, tensor->storageOffset+x0*tensor->stride[0], value); | ||||
| } | ||||
|  | ||||
| real THTensor_(get1d)(THTensor *tensor, long x0) | ||||
| { | ||||
|   THArgCheck(tensor->nDimension == 1, 1, "tensor must have one dimension"); | ||||
|   THArgCheck( (x0 >= 0) && (x0 < tensor->size[0]), 2, "out of range"); | ||||
|   return THStorage_(get)(tensor->storage, tensor->storageOffset+x0*tensor->stride[0]); | ||||
| } | ||||
|  | ||||
| void THTensor_(set2d)(THTensor *tensor, long x0, long x1, real value) | ||||
| { | ||||
|   THArgCheck(tensor->nDimension == 2, 1, "tensor must have two dimensions"); | ||||
|   THArgCheck((x0 >= 0) && (x0 < tensor->size[0]) && (x1 >= 0) && (x1 < tensor->size[1]), 2, "out of range"); | ||||
|   THStorage_(set)(tensor->storage, tensor->storageOffset+x0*tensor->stride[0]+x1*tensor->stride[1], value); | ||||
| } | ||||
|  | ||||
| real THTensor_(get2d)(THTensor *tensor, long x0, long x1) | ||||
| { | ||||
|   THArgCheck(tensor->nDimension == 2, 1, "tensor must have two dimensions"); | ||||
|   THArgCheck((x0 >= 0) && (x0 < tensor->size[0]) && (x1 >= 0) && (x1 < tensor->size[1]), 2, "out of range"); | ||||
|   return THStorage_(get)(tensor->storage, tensor->storageOffset+x0*tensor->stride[0]+x1*tensor->stride[1]); | ||||
| } | ||||
|  | ||||
| void THTensor_(set3d)(THTensor *tensor, long x0, long x1, long x2, real value) | ||||
| { | ||||
|   THArgCheck(tensor->nDimension == 3, 1, "tensor must have three dimensions"); | ||||
|   THArgCheck( (x0 >= 0) && (x0 < tensor->size[0]) && (x1 >= 0) && (x1 < tensor->size[1]) && (x2 >= 0) && (x2 < tensor->size[2]), 2, "out of range"); | ||||
|   THStorage_(set)(tensor->storage, tensor->storageOffset+x0*tensor->stride[0]+x1*tensor->stride[1]+x2*tensor->stride[2], value); | ||||
| } | ||||
|  | ||||
| real THTensor_(get3d)(THTensor *tensor, long x0, long x1, long x2) | ||||
| { | ||||
|   THArgCheck(tensor->nDimension == 3, 1, "tensor must have three dimensions"); | ||||
|   THArgCheck( (x0 >= 0) && (x0 < tensor->size[0]) && (x1 >= 0) && (x1 < tensor->size[1]) && (x2 >= 0) && (x2 < tensor->size[2]), 2, "out of range"); | ||||
|   return THStorage_(get)(tensor->storage, tensor->storageOffset+x0*tensor->stride[0]+x1*tensor->stride[1]+x2*tensor->stride[2]); | ||||
| } | ||||
|  | ||||
| void THTensor_(set4d)(THTensor *tensor, long x0, long x1, long x2, long x3, real value) | ||||
| { | ||||
|   THArgCheck(tensor->nDimension == 4, 1, "tensor must have four dimensions"); | ||||
|   THArgCheck((x0 >= 0) && (x0 < tensor->size[0]) && (x1 >= 0) && (x1 < tensor->size[1]) && (x2 >= 0) && (x2 < tensor->size[2]) && (x3 >= 0) && (x3 < tensor->size[3]), 2, "out of range"); | ||||
|   THStorage_(set)(tensor->storage, tensor->storageOffset+x0*tensor->stride[0]+x1*tensor->stride[1]+x2*tensor->stride[2]+x3*tensor->stride[3], value); | ||||
| } | ||||
|  | ||||
| real THTensor_(get4d)(THTensor *tensor, long x0, long x1, long x2, long x3) | ||||
| { | ||||
|   THArgCheck(tensor->nDimension == 4, 1, "tensor must have four dimensions"); | ||||
|   THArgCheck((x0 >= 0) && (x0 < tensor->size[0]) && (x1 >= 0) && (x1 < tensor->size[1]) && (x2 >= 0) && (x2 < tensor->size[2]) && (x3 >= 0) && (x3 < tensor->size[3]), 2, "out of range"); | ||||
|   return THStorage_(get)(tensor->storage, tensor->storageOffset+x0*tensor->stride[0]+x1*tensor->stride[1]+x2*tensor->stride[2]+x3*tensor->stride[3]); | ||||
| } | ||||
|  | ||||
| #endif | ||||
							
								
								
									
										123
									
								
								lib/TH/generic/THTensor.h
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										123
									
								
								lib/TH/generic/THTensor.h
									
									
									
									
									
										Normal file
									
								
							| @ -0,0 +1,123 @@ | ||||
| #ifndef TH_GENERIC_FILE | ||||
| #define TH_GENERIC_FILE "generic/THTensor.h" | ||||
| #else | ||||
|  | ||||
| /* a la lua? dim, storageoffset, ...  et les methodes ? */ | ||||
|  | ||||
| #define TH_TENSOR_REFCOUNTED 1 | ||||
|  | ||||
| typedef struct THTensor | ||||
| { | ||||
|     long *size; | ||||
|     long *stride; | ||||
|     int nDimension; | ||||
|      | ||||
|     THStorage *storage; | ||||
|     long storageOffset; | ||||
|     int refcount; | ||||
|  | ||||
|     char flag; | ||||
|  | ||||
| } THTensor; | ||||
|  | ||||
|  | ||||
| /**** access methods ****/ | ||||
| TH_API THStorage* THTensor_(storage)(THTensor *self); | ||||
| TH_API long THTensor_(storageOffset)(THTensor *self); | ||||
| TH_API int THTensor_(nDimension)(THTensor *self); | ||||
| TH_API long THTensor_(size)(THTensor *self, int dim); | ||||
| TH_API long THTensor_(stride)(THTensor *self, int dim); | ||||
| TH_API THLongStorage *THTensor_(newSizeOf)(THTensor *self); | ||||
| TH_API THLongStorage *THTensor_(newStrideOf)(THTensor *self); | ||||
| TH_API real *THTensor_(data)(THTensor *self); | ||||
|  | ||||
| TH_API void THTensor_(setFlag)(THTensor *self, const char flag); | ||||
| TH_API void THTensor_(clearFlag)(THTensor *self, const char flag); | ||||
|  | ||||
|  | ||||
| /**** creation methods ****/ | ||||
| TH_API THTensor *THTensor_(new)(void); | ||||
| TH_API THTensor *THTensor_(newWithTensor)(THTensor *tensor); | ||||
| /* stride might be NULL */ | ||||
| TH_API THTensor *THTensor_(newWithStorage)(THStorage *storage_, long storageOffset_, THLongStorage *size_, THLongStorage *stride_); | ||||
| TH_API THTensor *THTensor_(newWithStorage1d)(THStorage *storage_, long storageOffset_, | ||||
|                                 long size0_, long stride0_); | ||||
| TH_API THTensor *THTensor_(newWithStorage2d)(THStorage *storage_, long storageOffset_, | ||||
|                                 long size0_, long stride0_, | ||||
|                                 long size1_, long stride1_); | ||||
| TH_API THTensor *THTensor_(newWithStorage3d)(THStorage *storage_, long storageOffset_, | ||||
|                                 long size0_, long stride0_, | ||||
|                                 long size1_, long stride1_, | ||||
|                                 long size2_, long stride2_); | ||||
| TH_API THTensor *THTensor_(newWithStorage4d)(THStorage *storage_, long storageOffset_, | ||||
|                                 long size0_, long stride0_, | ||||
|                                 long size1_, long stride1_, | ||||
|                                 long size2_, long stride2_, | ||||
|                                 long size3_, long stride3_); | ||||
|  | ||||
| /* stride might be NULL */ | ||||
| TH_API THTensor *THTensor_(newWithSize)(THLongStorage *size_, THLongStorage *stride_); | ||||
| TH_API THTensor *THTensor_(newWithSize1d)(long size0_); | ||||
| TH_API THTensor *THTensor_(newWithSize2d)(long size0_, long size1_); | ||||
| TH_API THTensor *THTensor_(newWithSize3d)(long size0_, long size1_, long size2_); | ||||
| TH_API THTensor *THTensor_(newWithSize4d)(long size0_, long size1_, long size2_, long size3_); | ||||
|  | ||||
| TH_API THTensor *THTensor_(newClone)(THTensor *self); | ||||
| TH_API THTensor *THTensor_(newContiguous)(THTensor *tensor); | ||||
| TH_API THTensor *THTensor_(newSelect)(THTensor *tensor, int dimension_, long sliceIndex_); | ||||
| TH_API THTensor *THTensor_(newNarrow)(THTensor *tensor, int dimension_, long firstIndex_, long size_); | ||||
| TH_API THTensor *THTensor_(newTranspose)(THTensor *tensor, int dimension1_, int dimension2_); | ||||
| TH_API THTensor *THTensor_(newUnfold)(THTensor *tensor, int dimension_, long size_, long step_); | ||||
|    | ||||
| TH_API void THTensor_(resize)(THTensor *tensor, THLongStorage *size, THLongStorage *stride); | ||||
| TH_API void THTensor_(resizeAs)(THTensor *tensor, THTensor *src); | ||||
| TH_API void THTensor_(resize1d)(THTensor *tensor, long size0_); | ||||
| TH_API void THTensor_(resize2d)(THTensor *tensor, long size0_, long size1_); | ||||
| TH_API void THTensor_(resize3d)(THTensor *tensor, long size0_, long size1_, long size2_); | ||||
| TH_API void THTensor_(resize4d)(THTensor *tensor, long size0_, long size1_, long size2_, long size3_); | ||||
| TH_API void THTensor_(resize5d)(THTensor *tensor, long size0_, long size1_, long size2_, long size3_, long size4_); | ||||
|  | ||||
| TH_API void THTensor_(set)(THTensor *self, THTensor *src); | ||||
| TH_API void THTensor_(setStorage)(THTensor *self, THStorage *storage_, long storageOffset_, THLongStorage *size_, THLongStorage *stride_); | ||||
| TH_API void THTensor_(setStorage1d)(THTensor *self, THStorage *storage_, long storageOffset_, | ||||
|                                     long size0_, long stride0_); | ||||
| TH_API void THTensor_(setStorage2d)(THTensor *self, THStorage *storage_, long storageOffset_, | ||||
|                                     long size0_, long stride0_, | ||||
|                                     long size1_, long stride1_); | ||||
| TH_API void THTensor_(setStorage3d)(THTensor *self, THStorage *storage_, long storageOffset_, | ||||
|                                     long size0_, long stride0_, | ||||
|                                     long size1_, long stride1_, | ||||
|                                     long size2_, long stride2_); | ||||
| TH_API void THTensor_(setStorage4d)(THTensor *self, THStorage *storage_, long storageOffset_, | ||||
|                                     long size0_, long stride0_, | ||||
|                                     long size1_, long stride1_, | ||||
|                                     long size2_, long stride2_, | ||||
|                                     long size3_, long stride3_); | ||||
|  | ||||
| TH_API void THTensor_(narrow)(THTensor *self, THTensor *src, int dimension_, long firstIndex_, long size_); | ||||
| TH_API void THTensor_(select)(THTensor *self, THTensor *src, int dimension_, long sliceIndex_); | ||||
| TH_API void THTensor_(transpose)(THTensor *self, THTensor *src, int dimension1_, int dimension2_); | ||||
| TH_API void THTensor_(unfold)(THTensor *self, THTensor *src, int dimension_, long size_, long step_); | ||||
|  | ||||
| TH_API void THTensor_(squeeze)(THTensor *self, THTensor *src); | ||||
| TH_API void THTensor_(squeeze1d)(THTensor *self, THTensor *src, int dimension_); | ||||
|      | ||||
| TH_API int THTensor_(isContiguous)(THTensor *self); | ||||
| TH_API long THTensor_(nElement)(THTensor *self); | ||||
|  | ||||
| TH_API void THTensor_(retain)(THTensor *self); | ||||
| TH_API void THTensor_(free)(THTensor *self); | ||||
| TH_API void THTensor_(freeCopyTo)(THTensor *self, THTensor *dst); | ||||
|  | ||||
| /* Slow access methods [check everything] */ | ||||
| TH_API void THTensor_(set1d)(THTensor *tensor, long x0, real value); | ||||
| TH_API void THTensor_(set2d)(THTensor *tensor, long x0, long x1, real value); | ||||
| TH_API void THTensor_(set3d)(THTensor *tensor, long x0, long x1, long x2, real value); | ||||
| TH_API void THTensor_(set4d)(THTensor *tensor, long x0, long x1, long x2, long x3, real value); | ||||
|  | ||||
| TH_API real THTensor_(get1d)(THTensor *tensor, long x0); | ||||
| TH_API real THTensor_(get2d)(THTensor *tensor, long x0, long x1); | ||||
| TH_API real THTensor_(get3d)(THTensor *tensor, long x0, long x1, long x2); | ||||
| TH_API real THTensor_(get4d)(THTensor *tensor, long x0, long x1, long x2, long x3); | ||||
|  | ||||
| #endif | ||||
							
								
								
									
										1489
									
								
								lib/TH/generic/THTensorConv.c
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										1489
									
								
								lib/TH/generic/THTensorConv.c
									
									
									
									
									
										Normal file
									
								
							
										
											
												File diff suppressed because it is too large
												Load Diff
											
										
									
								
							
							
								
								
									
										78
									
								
								lib/TH/generic/THTensorConv.h
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										78
									
								
								lib/TH/generic/THTensorConv.h
									
									
									
									
									
										Normal file
									
								
							| @ -0,0 +1,78 @@ | ||||
| #ifndef TH_GENERIC_FILE | ||||
| #define TH_GENERIC_FILE "generic/THTensorConv.h" | ||||
| #else | ||||
|  | ||||
|  | ||||
| TH_API void THTensor_(validXCorr2Dptr)(real *r_, | ||||
|                                     real alpha, | ||||
|                                     real *t_, long ir, long ic, | ||||
|                                     real *k_, long kr, long kc, | ||||
|                                     long sr, long sc); | ||||
|  | ||||
| TH_API void THTensor_(validConv2Dptr)(real *r_, | ||||
|                                    real alpha, | ||||
|                                    real *t_, long ir, long ic, | ||||
|                                    real *k_, long kr, long kc, | ||||
|                                    long sr, long sc); | ||||
|  | ||||
| TH_API void THTensor_(fullXCorr2Dptr)(real *r_, | ||||
|                                    real alpha, | ||||
|                                    real *t_, long ir, long ic, | ||||
|                                    real *k_, long kr, long kc, | ||||
|                                    long sr, long sc); | ||||
|  | ||||
| TH_API void THTensor_(fullConv2Dptr)(real *r_, | ||||
|                                   real alpha, | ||||
|                                   real *t_, long ir, long ic, | ||||
|                                   real *k_, long kr, long kc, | ||||
|                                   long sr, long sc); | ||||
|  | ||||
| TH_API void THTensor_(validXCorr2DRevptr)(real *r_, | ||||
|                                        real alpha, | ||||
|                                        real *t_, long ir, long ic, | ||||
|                                        real *k_, long kr, long kc, | ||||
|                                        long sr, long sc); | ||||
|  | ||||
| TH_API void THTensor_(conv2DRevger)(THTensor *r_, real beta, real alpha, THTensor *t_, THTensor *k_, long srow, long scol); | ||||
| TH_API void THTensor_(conv2Dger)(THTensor *r_, real beta, real alpha, THTensor *t_, THTensor *k_, long srow, long scol, const char *vf, const char *xc); | ||||
| TH_API void THTensor_(conv2Dmv)(THTensor *r_, real beta, real alpha, THTensor *t_, THTensor *k_, long srow, long scol, const char *vf, const char *xc); | ||||
| TH_API void THTensor_(conv2Dmul)(THTensor *r_, real beta, real alpha, THTensor *t_, THTensor *k_, long srow, long scol, const char *vf, const char *xc); | ||||
| TH_API void THTensor_(conv2Dcmul)(THTensor *r_, real beta, real alpha, THTensor *t_, THTensor *k_, long srow, long scol, const char *vf, const char *xc); | ||||
|  | ||||
| TH_API void THTensor_(validXCorr3Dptr)(real *r_, | ||||
|                                     real alpha, | ||||
|                                     real *t_, long it, long ir, long ic, | ||||
|                                     real *k_, long kt, long kr, long kc, | ||||
|                                     long st, long sr, long sc); | ||||
|  | ||||
| TH_API void THTensor_(validConv3Dptr)(real *r_, | ||||
|                                    real alpha, | ||||
|                                    real *t_, long it, long ir, long ic, | ||||
|                                    real *k_, long kt, long kr, long kc, | ||||
|                                    long st, long sr, long sc); | ||||
|  | ||||
| TH_API void THTensor_(fullXCorr3Dptr)(real *r_, | ||||
|                                    real alpha, | ||||
|                                    real *t_, long it, long ir, long ic, | ||||
|                                    real *k_, long kt, long kr, long kc, | ||||
|                                    long st, long sr, long sc); | ||||
|  | ||||
| TH_API void THTensor_(fullConv3Dptr)(real *r_, | ||||
|                                   real alpha, | ||||
|                                   real *t_, long it, long ir, long ic, | ||||
|                                   real *k_, long kt, long kr, long kc, | ||||
|                                   long st, long sr, long sc); | ||||
|  | ||||
| TH_API void THTensor_(validXCorr3DRevptr)(real *r_, | ||||
|                                        real alpha,  | ||||
|                                        real *t_, long it, long ir, long ic, | ||||
|                                        real *k_, long kt, long kr, long kc, | ||||
|                                        long st, long sr, long sc); | ||||
|  | ||||
| TH_API void THTensor_(conv3DRevger)(THTensor *r_, real beta, real alpha, THTensor *t_, THTensor *k_, long sdepth, long srow, long scol); | ||||
| TH_API void THTensor_(conv3Dger)(THTensor *r_, real beta, real alpha, THTensor *t_, THTensor *k_, long sdepth, long srow, long scol, const char *vf, const char *xc); | ||||
| TH_API void THTensor_(conv3Dmv)(THTensor *r_, real beta, real alpha, THTensor *t_, THTensor *k_, long sdepth, long srow, long scol, const char *vf, const char *xc); | ||||
| TH_API void THTensor_(conv3Dmul)(THTensor *r_, real beta, real alpha, THTensor *t_, THTensor *k_, long sdepth, long srow, long scol, const char *vf, const char *xc); | ||||
| TH_API void THTensor_(conv3Dcmul)(THTensor *r_, real beta, real alpha, THTensor *t_, THTensor *k_, long sdepth, long srow, long scol, const char *vf, const char *xc); | ||||
|  | ||||
| #endif | ||||
							
								
								
									
										21
									
								
								lib/TH/generic/THTensorCopy.c
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										21
									
								
								lib/TH/generic/THTensorCopy.c
									
									
									
									
									
										Normal file
									
								
							| @ -0,0 +1,21 @@ | ||||
| #ifndef TH_GENERIC_FILE | ||||
| #define TH_GENERIC_FILE "generic/THTensorCopy.c" | ||||
| #else | ||||
|  | ||||
| #define IMPLEMENT_THTensor_COPY(TYPENAMESRC, TYPE_SRC) \ | ||||
| void THTensor_(copy##TYPENAMESRC)(THTensor *tensor, TH##TYPENAMESRC##Tensor *src) \ | ||||
| { \ | ||||
|   TH_TENSOR_APPLY2(real, tensor, TYPE_SRC, src, *tensor_data = (real)(*src_data);) \ | ||||
| } | ||||
|  | ||||
| IMPLEMENT_THTensor_COPY(, real) | ||||
|  | ||||
| IMPLEMENT_THTensor_COPY(Byte, unsigned char) | ||||
| IMPLEMENT_THTensor_COPY(Char, char) | ||||
| IMPLEMENT_THTensor_COPY(Short, short) | ||||
| IMPLEMENT_THTensor_COPY(Int, int) | ||||
| IMPLEMENT_THTensor_COPY(Long, long) | ||||
| IMPLEMENT_THTensor_COPY(Float, float) | ||||
| IMPLEMENT_THTensor_COPY(Double, double) | ||||
|  | ||||
| #endif | ||||
							
								
								
									
										16
									
								
								lib/TH/generic/THTensorCopy.h
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										16
									
								
								lib/TH/generic/THTensorCopy.h
									
									
									
									
									
										Normal file
									
								
							| @ -0,0 +1,16 @@ | ||||
| #ifndef TH_GENERIC_FILE | ||||
| #define TH_GENERIC_FILE "generic/THTensorCopy.h" | ||||
| #else | ||||
|  | ||||
| /* Support for copy between different Tensor types */ | ||||
|  | ||||
| TH_API void THTensor_(copy)(THTensor *tensor, THTensor *src); | ||||
| TH_API void THTensor_(copyByte)(THTensor *tensor, struct THByteTensor *src); | ||||
| TH_API void THTensor_(copyChar)(THTensor *tensor, struct THCharTensor *src); | ||||
| TH_API void THTensor_(copyShort)(THTensor *tensor, struct THShortTensor *src); | ||||
| TH_API void THTensor_(copyInt)(THTensor *tensor, struct THIntTensor *src); | ||||
| TH_API void THTensor_(copyLong)(THTensor *tensor, struct THLongTensor *src); | ||||
| TH_API void THTensor_(copyFloat)(THTensor *tensor, struct THFloatTensor *src); | ||||
| TH_API void THTensor_(copyDouble)(THTensor *tensor, struct THDoubleTensor *src); | ||||
|  | ||||
| #endif | ||||
							
								
								
									
										343
									
								
								lib/TH/generic/THTensorLapack.c
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										343
									
								
								lib/TH/generic/THTensorLapack.c
									
									
									
									
									
										Normal file
									
								
							| @ -0,0 +1,343 @@ | ||||
| #ifndef TH_GENERIC_FILE | ||||
| #define TH_GENERIC_FILE "generic/THTensorLapack.c" | ||||
| #else | ||||
|  | ||||
| static int THTensor_(lapackClone)(THTensor *r_, THTensor *m, int forced) | ||||
| { | ||||
|   int clone; | ||||
|  | ||||
|   if (!forced && m->stride[0] == 1 && m->stride[1] == m->size[0]) | ||||
|   { | ||||
|     clone = 0; | ||||
|     THTensor_(set)(r_,m); | ||||
|   } | ||||
|   else | ||||
|   { | ||||
|     clone = 1; | ||||
|     /* we need to copy */ | ||||
|     THTensor_(resize2d)(r_,m->size[1],m->size[0]); | ||||
|     THTensor_(transpose)(r_,NULL,0,1); | ||||
|     THTensor_(copy)(r_,m); | ||||
|   } | ||||
|   return clone; | ||||
| } | ||||
|  | ||||
| TH_API void THTensor_(gesv)(THTensor *rb_, THTensor *ra_, THTensor *b, THTensor *a) | ||||
| { | ||||
|   int n, nrhs, lda, ldb, info; | ||||
|   THIntTensor *ipiv; | ||||
|   THTensor *ra__; | ||||
|   THTensor *rb__; | ||||
|  | ||||
|   int clonea; | ||||
|   int cloneb; | ||||
|   int destroya; | ||||
|   int destroyb; | ||||
|  | ||||
|    | ||||
|   if (a == NULL || ra_ == a) /* possibly destroy the inputs  */ | ||||
|   { | ||||
|     ra__ = THTensor_(new)(); | ||||
|     clonea = THTensor_(lapackClone)(ra__,ra_,0); | ||||
|     destroya = 1; | ||||
|   } | ||||
|   else /*we want to definitely clone and use ra_ and rb_ as computational space*/ | ||||
|   { | ||||
|     clonea = THTensor_(lapackClone)(ra_,a,1); | ||||
|     ra__ = ra_; | ||||
|     destroya = 0; | ||||
|   } | ||||
|   if (b == NULL || rb_ == b) /* possibly destroy the inputs  */ | ||||
|   { | ||||
|     rb__ = THTensor_(new)(); | ||||
|     cloneb = THTensor_(lapackClone)(rb__,rb_,0); | ||||
|     destroyb = 1; | ||||
|   } | ||||
|   else /*we want to definitely clone and use ra_ and rb_ as computational space*/ | ||||
|   { | ||||
|     cloneb = THTensor_(lapackClone)(rb_,b,1); | ||||
|     rb__ = rb_; | ||||
|     destroyb = 0; | ||||
|   } | ||||
|  | ||||
|   THArgCheck(ra__->nDimension == 2, 1, "A should be 2 dimensional"); | ||||
|   THArgCheck(rb__->nDimension == 2, 2, "b should be 2 dimensional"); | ||||
|   THArgCheck(ra__->size[0] == ra__->size[1], 1, "A should be square"); | ||||
|   THArgCheck(rb__->size[0] == ra__->size[0], 2, "A,b size incomptable"); | ||||
|  | ||||
|   n    = (int)ra__->size[0]; | ||||
|   nrhs = (int)rb__->size[1]; | ||||
|   lda  = n; | ||||
|   ldb  = n; | ||||
|  | ||||
|   ipiv = THIntTensor_newWithSize1d((long)n); | ||||
|   THLapack_(gesv)(n, nrhs,  | ||||
| 		  THTensor_(data)(ra__), lda, THIntTensor_data(ipiv), | ||||
| 		  THTensor_(data)(rb__), ldb, &info); | ||||
|  | ||||
|   /* clean up */ | ||||
|   if (destroya) | ||||
|   { | ||||
|     if (clonea) | ||||
|     { | ||||
|       THTensor_(copy)(ra_,ra__); | ||||
|     } | ||||
|     THTensor_(free)(ra__); | ||||
|   } | ||||
|   if (destroyb) | ||||
|   { | ||||
|     if (cloneb) | ||||
|     { | ||||
|       THTensor_(copy)(rb_,rb__); | ||||
|     } | ||||
|     THTensor_(free)(rb__); | ||||
|   } | ||||
|  | ||||
|   if (info < 0) | ||||
|   { | ||||
|     THError("Lapack gesv : Argument %d : illegal value", -info); | ||||
|   } | ||||
|   else if (info > 0) | ||||
|   { | ||||
|     THError("Lapack gesv : U(%d,%d) is zero, singular U.", info,info); | ||||
|   } | ||||
|  | ||||
|   THIntTensor_free(ipiv); | ||||
| } | ||||
|  | ||||
| TH_API void THTensor_(gels)(THTensor *rb_, THTensor *ra_, THTensor *b, THTensor *a) | ||||
| { | ||||
|   int m, n, nrhs, lda, ldb, info, lwork; | ||||
|   char transpose; | ||||
|   THTensor *work = NULL; | ||||
|   real wkopt = 0; | ||||
|  | ||||
|   THTensor *ra__; | ||||
|   THTensor *rb__; | ||||
|  | ||||
|   int clonea; | ||||
|   int cloneb; | ||||
|   int destroya; | ||||
|   int destroyb; | ||||
|  | ||||
|    | ||||
|   if (a == NULL || ra_ == a) /* possibly destroy the inputs  */ | ||||
|   { | ||||
|     ra__ = THTensor_(new)(); | ||||
|     clonea = THTensor_(lapackClone)(ra__,ra_,0); | ||||
|     destroya = 1; | ||||
|   } | ||||
|   else /*we want to definitely clone and use ra_ and rb_ as computational space*/ | ||||
|   { | ||||
|     clonea = THTensor_(lapackClone)(ra_,a,1); | ||||
|     ra__ = ra_; | ||||
|     destroya = 0; | ||||
|   } | ||||
|   if (b == NULL || rb_ == b) /* possibly destroy the inputs  */ | ||||
|   { | ||||
|     rb__ = THTensor_(new)(); | ||||
|     cloneb = THTensor_(lapackClone)(rb__,rb_,0); | ||||
|     destroyb = 1; | ||||
|   } | ||||
|   else /*we want to definitely clone and use ra_ and rb_ as computational space*/ | ||||
|   { | ||||
|     cloneb = THTensor_(lapackClone)(rb_,b,1); | ||||
|     rb__ = rb_; | ||||
|     destroyb = 0; | ||||
|   } | ||||
|    | ||||
|   THArgCheck(ra__->nDimension == 2, 1, "A should be 2 dimensional"); | ||||
|   THArgCheck(ra_->size[0] == rb__->size[0], 2, "size incompatible A,b"); | ||||
|  | ||||
|   m = ra__->size[0]; | ||||
|   n = ra__->size[1]; | ||||
|   nrhs = rb__->size[1]; | ||||
|   lda = m; | ||||
|   ldb = m; | ||||
|   info = 0; | ||||
|  | ||||
|   // get optimal workspace size | ||||
|   THLapack_(gels)('N', m, n, nrhs, THTensor_(data)(ra__), lda,  | ||||
| 		  THTensor_(data)(rb__), ldb,  | ||||
| 		  &wkopt, -1, &info); | ||||
|   lwork = (int)wkopt; | ||||
|   work = THTensor_(newWithSize1d)(lwork); | ||||
|   THLapack_(gels)('N', m, n, nrhs, THTensor_(data)(ra__), lda,  | ||||
| 		  THTensor_(data)(rb__), ldb,  | ||||
| 		  THTensor_(data)(work), lwork, &info); | ||||
|  | ||||
|   //printf("lwork = %d,%g\n",lwork,THTensor_(data)(work)[0]); | ||||
|   if (info != 0) | ||||
|   { | ||||
|     THError("Lapack gels : Argument %d : illegal value", -info); | ||||
|   } | ||||
|   /* clean up */ | ||||
|   if (destroya) | ||||
|   { | ||||
|     if (clonea) | ||||
|     { | ||||
|       THTensor_(copy)(ra_,ra__); | ||||
|     } | ||||
|     THTensor_(free)(ra__); | ||||
|   } | ||||
|   if (destroyb) | ||||
|   { | ||||
|     if (cloneb) | ||||
|     { | ||||
|       THTensor_(copy)(rb_,rb__); | ||||
|     } | ||||
|     THTensor_(free)(rb__); | ||||
|   } | ||||
|   THTensor_(free)(work); | ||||
| } | ||||
|  | ||||
| TH_API void THTensor_(syev)(THTensor *re_, THTensor *rv_, THTensor *a, const char *jobz, const char *uplo) | ||||
| { | ||||
|   int n, lda, lwork, info; | ||||
|   THTensor *work; | ||||
|   real wkopt; | ||||
|  | ||||
|   THTensor *rv__; | ||||
|  | ||||
|   int clonea; | ||||
|   int destroy; | ||||
|    | ||||
|   if (a == NULL) /* possibly destroy the inputs  */ | ||||
|   { | ||||
|     rv__ = THTensor_(new)(); | ||||
|     clonea = THTensor_(lapackClone)(rv__,rv_,0); | ||||
|     destroy = 1; | ||||
|   } | ||||
|   else /*we want to definitely clone and use ra_ and rb_ as computational space*/ | ||||
|   { | ||||
|     clonea = THTensor_(lapackClone)(rv_,a,1); | ||||
|     rv__ = rv_; | ||||
|     destroy = 0; | ||||
|   } | ||||
|  | ||||
|   THArgCheck(rv__->nDimension == 2, 2, "A should be 2 dimensional"); | ||||
|  | ||||
|   n = rv__->size[0]; | ||||
|   lda = n; | ||||
|  | ||||
|   THTensor_(resize1d)(re_,n); | ||||
|  | ||||
|   // get optimal workspace size | ||||
|   THLapack_(syev)(jobz[0], uplo[0], n, THTensor_(data)(rv__), lda, | ||||
| 		  THTensor_(data)(re_), &wkopt, -1, &info); | ||||
|   lwork = (int)wkopt; | ||||
|   work = THTensor_(newWithSize1d)(lwork); | ||||
|   THLapack_(syev)(jobz[0], uplo[0], n, THTensor_(data)(rv__), lda, | ||||
| 		  THTensor_(data)(re_), THTensor_(data)(work), lwork, &info); | ||||
|  | ||||
|   if (info > 0) | ||||
|   { | ||||
|     THError(" Lapack syev : Failed to converge. %d off-diagonal elements of an didn't converge to zero",info); | ||||
|   } | ||||
|   else if (info < 0) | ||||
|   { | ||||
|     THError("Lapack syev : Argument %d : illegal value", -info); | ||||
|   } | ||||
|   /* clean up */ | ||||
|   if (destroy) | ||||
|   { | ||||
|     if (clonea) | ||||
|     { | ||||
|       THTensor_(copy)(rv_,rv__); | ||||
|     } | ||||
|     THTensor_(free)(rv__); | ||||
|   } | ||||
|   THTensor_(free)(work); | ||||
| } | ||||
|  | ||||
| TH_API void THTensor_(gesvd)(THTensor *ru_, THTensor *rs_, THTensor *rv_, THTensor *a, const char* jobu) | ||||
| { | ||||
|   THTensor *ra_ = THTensor_(new)(); | ||||
|   THTensor_(gesvd2)(ru_, rs_, rv_,  ra_, a, jobu); | ||||
|   THTensor_(free)(ra_); | ||||
| } | ||||
|  | ||||
| TH_API void THTensor_(gesvd2)(THTensor *ru_, THTensor *rs_, THTensor *rv_, THTensor *ra_, THTensor *a, const char* jobu) | ||||
| { | ||||
|   int k,m, n, lda, ldu, ldvt, lwork, info; | ||||
|   THTensor *work; | ||||
|   real wkopt; | ||||
|  | ||||
|   THTensor *ra__; | ||||
|  | ||||
|   int clonea; | ||||
|   int destroy; | ||||
|  | ||||
|   if (a == NULL) /* possibly destroy the inputs  */ | ||||
|   { | ||||
|     ra__ = THTensor_(new)(); | ||||
|     clonea = THTensor_(lapackClone)(ra__,ra_,0); | ||||
|     destroy = 1; | ||||
|   } | ||||
|   else /*we want to definitely clone */ | ||||
|   { | ||||
|     clonea = THTensor_(lapackClone)(ra_,a,1); | ||||
|     ra__ = ra_; | ||||
|     destroy = 0; | ||||
|   } | ||||
|    | ||||
|   THArgCheck(ra__->nDimension == 2, 2, "A should be 2 dimensional"); | ||||
|  | ||||
|   m = ra__->size[0]; | ||||
|   n = ra__->size[1]; | ||||
|   k = (m < n ? m : n); | ||||
|  | ||||
|   lda = m; | ||||
|   ldu = m; | ||||
|   ldvt = n; | ||||
|   THTensor_(resize1d)(rs_,k); | ||||
|   THTensor_(resize2d)(rv_,ldvt,n); | ||||
|   if (*jobu == 'A') | ||||
|   { | ||||
|     THTensor_(resize2d)(ru_,m,ldu); | ||||
|   } | ||||
|   else | ||||
|   { | ||||
|     THTensor_(resize2d)(ru_,k,ldu); | ||||
|   } | ||||
|   THTensor_(transpose)(ru_,NULL,0,1); | ||||
|   THTensor_(transpose)(rv_,NULL,0,1); | ||||
|  | ||||
|   THLapack_(gesvd)(jobu[0],jobu[0], | ||||
| 		   m,n,THTensor_(data)(ra__),lda, | ||||
| 		   THTensor_(data)(rs_), | ||||
| 		   THTensor_(data)(ru_), | ||||
| 		   ldu, | ||||
| 		   THTensor_(data)(rv_), ldvt, | ||||
| 		   &wkopt, -1, &info); | ||||
|   lwork = (int)wkopt; | ||||
|   work = THTensor_(newWithSize1d)(lwork); | ||||
|   THLapack_(gesvd)(jobu[0],jobu[0], | ||||
| 		   m,n,THTensor_(data)(ra__),lda, | ||||
| 		   THTensor_(data)(rs_), | ||||
| 		   THTensor_(data)(ru_), | ||||
| 		   ldu, | ||||
| 		   THTensor_(data)(rv_), ldvt, | ||||
| 		   THTensor_(data)(work),lwork, &info); | ||||
|   if (info > 0) | ||||
|   { | ||||
|     THError(" Lapack gesvd : %d superdiagonals failed to converge.",info); | ||||
|   } | ||||
|   else if (info < 0) | ||||
|   { | ||||
|     THError("Lapack gesvd : Argument %d : illegal value", -info); | ||||
|   } | ||||
|  | ||||
|   /* clean up */ | ||||
|   if (destroy) | ||||
|   { | ||||
|     if (clonea) | ||||
|     { | ||||
|       THTensor_(copy)(ra_,ra__); | ||||
|     } | ||||
|     THTensor_(free)(ra__); | ||||
|   } | ||||
|   THTensor_(free)(work); | ||||
| } | ||||
|  | ||||
| #endif | ||||
							
								
								
									
										11
									
								
								lib/TH/generic/THTensorLapack.h
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										11
									
								
								lib/TH/generic/THTensorLapack.h
									
									
									
									
									
										Normal file
									
								
							| @ -0,0 +1,11 @@ | ||||
| #ifndef TH_GENERIC_FILE | ||||
| #define TH_GENERIC_FILE "generic/THTensorLapack.h" | ||||
| #else | ||||
|  | ||||
| TH_API void THTensor_(gesv)(THTensor *rb_, THTensor *ra_, THTensor *b_, THTensor *a_); | ||||
| TH_API void THTensor_(gels)(THTensor *rb_, THTensor *ra_, THTensor *b_, THTensor *a_); | ||||
| TH_API void THTensor_(syev)(THTensor *re_, THTensor *rv_, THTensor *a_, const char *jobz, const char *uplo); | ||||
| TH_API void THTensor_(gesvd)(THTensor *ru_, THTensor *rs_, THTensor *rv_, THTensor *a, const char *jobu); | ||||
| TH_API void THTensor_(gesvd2)(THTensor *ru_, THTensor *rs_, THTensor *rv_, THTensor *ra_, THTensor *a, const char *jobu); | ||||
|  | ||||
| #endif | ||||
							
								
								
									
										1063
									
								
								lib/TH/generic/THTensorMath.c
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										1063
									
								
								lib/TH/generic/THTensorMath.c
									
									
									
									
									
										Normal file
									
								
							
										
											
												File diff suppressed because it is too large
												Load Diff
											
										
									
								
							
							
								
								
									
										90
									
								
								lib/TH/generic/THTensorMath.h
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										90
									
								
								lib/TH/generic/THTensorMath.h
									
									
									
									
									
										Normal file
									
								
							| @ -0,0 +1,90 @@ | ||||
| #ifndef TH_GENERIC_FILE | ||||
| #define TH_GENERIC_FILE "generic/THTensorMath.h" | ||||
| #else | ||||
|  | ||||
| TH_API void THTensor_(fill)(THTensor *r_, real value); | ||||
| TH_API void THTensor_(zero)(THTensor *r_); | ||||
|  | ||||
| TH_API accreal THTensor_(dot)(THTensor *t, THTensor *src); | ||||
|    | ||||
| TH_API real THTensor_(minall)(THTensor *t); | ||||
| TH_API real THTensor_(maxall)(THTensor *t); | ||||
| TH_API accreal THTensor_(sumall)(THTensor *t); | ||||
|  | ||||
| TH_API void THTensor_(add)(THTensor *r_, THTensor *t, real value); | ||||
| TH_API void THTensor_(mul)(THTensor *r_, THTensor *t, real value); | ||||
| TH_API void THTensor_(div)(THTensor *r_, THTensor *t, real value); | ||||
|  | ||||
| TH_API void THTensor_(cadd)(THTensor *r_, THTensor *t, real value, THTensor *src);   | ||||
| TH_API void THTensor_(cmul)(THTensor *r_, THTensor *t, THTensor *src); | ||||
| TH_API void THTensor_(cdiv)(THTensor *r_, THTensor *t, THTensor *src); | ||||
|  | ||||
| TH_API void THTensor_(addcmul)(THTensor *r_, THTensor *t, real value, THTensor *src1, THTensor *src2); | ||||
| TH_API void THTensor_(addcdiv)(THTensor *r_, THTensor *t, real value, THTensor *src1, THTensor *src2); | ||||
|  | ||||
| TH_API void THTensor_(addmv)(THTensor *r_, real beta, THTensor *t, real alpha, THTensor *mat,  THTensor *vec); | ||||
| TH_API void THTensor_(addmm)(THTensor *r_, real beta, THTensor *t, real alpha, THTensor *mat1, THTensor *mat2); | ||||
| TH_API void THTensor_(addr)(THTensor *r_,  real beta, THTensor *t, real alpha, THTensor *vec1, THTensor *vec2); | ||||
|  | ||||
| TH_API long THTensor_(numel)(THTensor *t); | ||||
| TH_API void THTensor_(max)(THTensor *values_, THLongTensor *indices_, THTensor *t, int dimension); | ||||
| TH_API void THTensor_(min)(THTensor *values_, THLongTensor *indices_, THTensor *t, int dimension); | ||||
| TH_API void THTensor_(sum)(THTensor *r_, THTensor *t, int dimension); | ||||
| TH_API void THTensor_(prod)(THTensor *r_, THTensor *t, int dimension); | ||||
| TH_API void THTensor_(cumsum)(THTensor *r_, THTensor *t, int dimension); | ||||
| TH_API void THTensor_(cumprod)(THTensor *r_, THTensor *t, int dimension); | ||||
| TH_API void THTensor_(sign)(THTensor *r_, THTensor *t); | ||||
| TH_API accreal THTensor_(trace)(THTensor *t); | ||||
| TH_API void THTensor_(cross)(THTensor *r_, THTensor *a, THTensor *b, int dimension); | ||||
|  | ||||
| TH_API void THTensor_(zeros)(THTensor *r_, THLongStorage *size); | ||||
| TH_API void THTensor_(ones)(THTensor *r_, THLongStorage *size); | ||||
| TH_API void THTensor_(diag)(THTensor *r_, THTensor *t, int k); | ||||
| TH_API void THTensor_(eye)(THTensor *r_, long n, long m); | ||||
| TH_API void THTensor_(range)(THTensor *r_, real xmin, real xmax, real step); | ||||
| TH_API void THTensor_(randperm)(THTensor *r_, long n); | ||||
|  | ||||
| TH_API void THTensor_(reshape)(THTensor *r_, THTensor *t, THLongStorage *size); | ||||
| TH_API void THTensor_(sort)(THTensor *rt_, THLongTensor *ri_, THTensor *t, int dimension, int descendingOrder); | ||||
| TH_API void THTensor_(tril)(THTensor *r_, THTensor *t, long k); | ||||
| TH_API void THTensor_(triu)(THTensor *r_, THTensor *t, long k); | ||||
| TH_API void THTensor_(cat)(THTensor *r_, THTensor *ta, THTensor *tb, int dimension); | ||||
|  | ||||
| #if defined(TH_REAL_IS_FLOAT) || defined(TH_REAL_IS_DOUBLE) | ||||
|  | ||||
| TH_API void THTensor_(log)(THTensor *r_, THTensor *t); | ||||
| TH_API void THTensor_(log1p)(THTensor *r_, THTensor *t); | ||||
| TH_API void THTensor_(exp)(THTensor *r_, THTensor *t); | ||||
| TH_API void THTensor_(cos)(THTensor *r_, THTensor *t); | ||||
| TH_API void THTensor_(acos)(THTensor *r_, THTensor *t); | ||||
| TH_API void THTensor_(cosh)(THTensor *r_, THTensor *t); | ||||
| TH_API void THTensor_(sin)(THTensor *r_, THTensor *t); | ||||
| TH_API void THTensor_(asin)(THTensor *r_, THTensor *t); | ||||
| TH_API void THTensor_(sinh)(THTensor *r_, THTensor *t); | ||||
| TH_API void THTensor_(tan)(THTensor *r_, THTensor *t); | ||||
| TH_API void THTensor_(atan)(THTensor *r_, THTensor *t); | ||||
| TH_API void THTensor_(tanh)(THTensor *r_, THTensor *t); | ||||
| TH_API void THTensor_(pow)(THTensor *r_, THTensor *t, real value); | ||||
| TH_API void THTensor_(sqrt)(THTensor *r_, THTensor *t); | ||||
| TH_API void THTensor_(ceil)(THTensor *r_, THTensor *t); | ||||
| TH_API void THTensor_(floor)(THTensor *r_, THTensor *t); | ||||
| TH_API void THTensor_(abs)(THTensor *r_, THTensor *t); | ||||
|  | ||||
| TH_API void THTensor_(mean)(THTensor *r_, THTensor *t, int dimension); | ||||
| TH_API void THTensor_(std)(THTensor *r_, THTensor *t, int dimension, int flag); | ||||
| TH_API void THTensor_(var)(THTensor *r_, THTensor *t, int dimension, int flag); | ||||
| TH_API accreal THTensor_(norm)(THTensor *t, real value); | ||||
| TH_API accreal THTensor_(dist)(THTensor *a, THTensor *b, real value); | ||||
|  | ||||
| TH_API accreal THTensor_(meanall)(THTensor *self); | ||||
| TH_API accreal THTensor_(varall)(THTensor *self); | ||||
| TH_API accreal THTensor_(stdall)(THTensor *self); | ||||
|  | ||||
| TH_API void THTensor_(linspace)(THTensor *r_, real a, real b, long n); | ||||
| TH_API void THTensor_(logspace)(THTensor *r_, real a, real b, long n); | ||||
| TH_API void THTensor_(rand)(THTensor *r_, THLongStorage *size); | ||||
| TH_API void THTensor_(randn)(THTensor *r_, THLongStorage *size); | ||||
|  | ||||
| #endif | ||||
|  | ||||
| #endif | ||||
							
								
								
									
										65
									
								
								lib/TH/generic/THTensorRandom.c
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										65
									
								
								lib/TH/generic/THTensorRandom.c
									
									
									
									
									
										Normal file
									
								
							| @ -0,0 +1,65 @@ | ||||
| #ifndef TH_GENERIC_FILE | ||||
| #define TH_GENERIC_FILE "generic/THTensorRandom.c" | ||||
| #else | ||||
|  | ||||
| TH_API void THTensor_(random)(THTensor *self) | ||||
| { | ||||
| #if defined(TH_REAL_IS_BYTE) | ||||
|   TH_TENSOR_APPLY(real, self, *self_data = (unsigned char)(THRandom_random() % (UCHAR_MAX+1));); | ||||
| #elif defined(TH_REAL_IS_CHAR) | ||||
|   TH_TENSOR_APPLY(real, self, *self_data = (char)(THRandom_random() % (CHAR_MAX+1));); | ||||
| #elif defined(TH_REAL_IS_SHORT) | ||||
|   TH_TENSOR_APPLY(real, self, *self_data = (short)(THRandom_random() % (SHRT_MAX+1));); | ||||
| #elif defined(TH_REAL_IS_INT) | ||||
|   TH_TENSOR_APPLY(real, self, *self_data = (int)(THRandom_random() % (INT_MAX+1UL));); | ||||
| #elif defined(TH_REAL_IS_LONG) | ||||
|   TH_TENSOR_APPLY(real, self, *self_data = (long)(THRandom_random() % (LONG_MAX+1UL));); | ||||
| #elif defined(TH_REAL_IS_FLOAT) | ||||
|   TH_TENSOR_APPLY(real, self, *self_data = (float)(THRandom_random() % ((1UL << FLT_MANT_DIG)+1));); | ||||
| #elif defined(TH_REAL_IS_DOUBLE) | ||||
|   TH_TENSOR_APPLY(real, self, *self_data = (float)(THRandom_random() % ((1UL << DBL_MANT_DIG)+1));); | ||||
| #else | ||||
| #error "Unknown type" | ||||
| #endif | ||||
| } | ||||
|  | ||||
| TH_API void THTensor_(geometric)(THTensor *self, double p) | ||||
| { | ||||
|   TH_TENSOR_APPLY(real, self, *self_data = (real)THRandom_geometric(p);); | ||||
| } | ||||
|  | ||||
| TH_API void THTensor_(bernoulli)(THTensor *self, double p) | ||||
| { | ||||
|   TH_TENSOR_APPLY(real, self, *self_data = (real)THRandom_bernoulli(p);); | ||||
| } | ||||
|  | ||||
| #if defined(TH_REAL_IS_FLOAT) || defined(TH_REAL_IS_DOUBLE) | ||||
|  | ||||
| TH_API void THTensor_(uniform)(THTensor *self, double a, double b) | ||||
| { | ||||
|   TH_TENSOR_APPLY(real, self, *self_data = (real)THRandom_uniform(a, b);); | ||||
| } | ||||
|  | ||||
| TH_API void THTensor_(normal)(THTensor *self, double mean, double stdv) | ||||
| { | ||||
|   TH_TENSOR_APPLY(real, self, *self_data = (real)THRandom_normal(mean, stdv);); | ||||
| } | ||||
|  | ||||
| TH_API void THTensor_(exponential)(THTensor *self, double lambda) | ||||
| { | ||||
|   TH_TENSOR_APPLY(real, self, *self_data = (real)THRandom_exponential(lambda);); | ||||
| } | ||||
|  | ||||
| TH_API void THTensor_(cauchy)(THTensor *self, double median, double sigma) | ||||
| { | ||||
|   TH_TENSOR_APPLY(real, self, *self_data = (real)THRandom_cauchy(median, sigma);); | ||||
| } | ||||
|  | ||||
| TH_API void THTensor_(logNormal)(THTensor *self, double mean, double stdv) | ||||
| { | ||||
|   TH_TENSOR_APPLY(real, self, *self_data = (real)THRandom_logNormal(mean, stdv);); | ||||
| } | ||||
|  | ||||
| #endif | ||||
|  | ||||
| #endif | ||||
							
								
								
									
										17
									
								
								lib/TH/generic/THTensorRandom.h
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										17
									
								
								lib/TH/generic/THTensorRandom.h
									
									
									
									
									
										Normal file
									
								
							| @ -0,0 +1,17 @@ | ||||
| #ifndef TH_GENERIC_FILE | ||||
| #define TH_GENERIC_FILE "generic/THTensorRandom.h" | ||||
| #else | ||||
|  | ||||
| TH_API void THTensor_(random)(THTensor *self); | ||||
| TH_API void THTensor_(geometric)(THTensor *self, double p); | ||||
| TH_API void THTensor_(bernoulli)(THTensor *self, double p); | ||||
|  | ||||
| #if defined(TH_REAL_IS_FLOAT) || defined(TH_REAL_IS_DOUBLE) | ||||
| TH_API void THTensor_(uniform)(THTensor *self, double a, double b); | ||||
| TH_API void THTensor_(normal)(THTensor *self, double mean, double stdv); | ||||
| TH_API void THTensor_(exponential)(THTensor *self, double lambda); | ||||
| TH_API void THTensor_(cauchy)(THTensor *self, double median, double sigma); | ||||
| TH_API void THTensor_(logNormal)(THTensor *self, double mean, double stdv); | ||||
| #endif | ||||
|  | ||||
| #endif | ||||
							
								
								
									
										84
									
								
								lib/TH/generic/THVector.c
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										84
									
								
								lib/TH/generic/THVector.c
									
									
									
									
									
										Normal file
									
								
							| @ -0,0 +1,84 @@ | ||||
| #ifndef TH_GENERIC_FILE | ||||
| #define TH_GENERIC_FILE "generic/THVector.c" | ||||
| #else | ||||
|  | ||||
| static inline void THVector_(fill)(real *x, const real c, const long n) { | ||||
|   long i = 0; | ||||
|  | ||||
|   for(; i < n-4; i += 4) | ||||
|   { | ||||
|     x[i] = c; | ||||
|     x[i+1] = c; | ||||
|     x[i+2] = c; | ||||
|     x[i+3] = c; | ||||
|   } | ||||
|  | ||||
|   for(; i < n; i++) | ||||
|     x[i] = c; | ||||
| } | ||||
|  | ||||
| static inline void THVector_(add)(real *y, const real *x, const real c, const long n) | ||||
| { | ||||
|   long i = 0; | ||||
|  | ||||
|   for(;i < n-4; i += 4) | ||||
|   { | ||||
|     y[i] += c * x[i]; | ||||
|     y[i+1] += c * x[i+1]; | ||||
|     y[i+2] += c * x[i+2]; | ||||
|     y[i+3] += c * x[i+3]; | ||||
|   } | ||||
|  | ||||
|   for(; i < n; i++) | ||||
|     y[i] += c * x[i]; | ||||
| } | ||||
|  | ||||
| static inline void THVector_(diff)(real *z, const real *x, const real *y, const long n) | ||||
| { | ||||
|   long i = 0; | ||||
|  | ||||
|   for(; i < n-4; i += 4) | ||||
|   { | ||||
|     z[i] = x[i] - y[i]; | ||||
|     z[i+1] + x[i+1] - y[i+1]; | ||||
|     z[i+2] = x[i+2] - y[i+2]; | ||||
|     z[i+3] = x[i+3] - y[i+3]; | ||||
|   } | ||||
|  | ||||
|   for(; i < n; i++) | ||||
|     z[i] = x[i] - y[i]; | ||||
| } | ||||
|  | ||||
| static inline void THVector_(scale)(real *y, const real c, const long n) | ||||
| { | ||||
|   long i = 0; | ||||
|  | ||||
|   for(; i < n-4; i +=4) | ||||
|   { | ||||
|     y[i] *= c; | ||||
|     y[i+1] *= c; | ||||
|     y[i+2] *= c; | ||||
|     y[i+3] *= c; | ||||
|   } | ||||
|  | ||||
|   for(; i < n; i++) | ||||
|     y[i] *= c; | ||||
| } | ||||
|  | ||||
| static inline void THVector_(mul)(real *y, const real *x, const long n) | ||||
| { | ||||
|   long i = 0; | ||||
|  | ||||
|   for(; i < n-4; i += 4) | ||||
|   { | ||||
|     y[i] *= x[i]; | ||||
|     y[i+1] *= x[i+1]; | ||||
|     y[i+2] *= x[i+2]; | ||||
|     y[i+3] *= x[i+3]; | ||||
|   } | ||||
|  | ||||
|   for(; i < n; i++) | ||||
|     y[i] *= x[i]; | ||||
| } | ||||
|  | ||||
| #endif | ||||
							
								
								
									
										28
									
								
								lib/luaT/CMakeLists.txt
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										28
									
								
								lib/luaT/CMakeLists.txt
									
									
									
									
									
										Normal file
									
								
							| @ -0,0 +1,28 @@ | ||||
| # -*- cmake -*- | ||||
|  | ||||
| FIND_PACKAGE(Lua REQUIRED) | ||||
|  | ||||
| INCLUDE_DIRECTORIES(${LUA_INCLUDE_DIR}) | ||||
|  | ||||
| ADD_LIBRARY(luaT SHARED luaT.h luaT.c) | ||||
| TARGET_LINK_LIBRARIES(luaT ${LUA_LIBRARIES}) | ||||
|  | ||||
| INSTALL(TARGETS luaT | ||||
|           RUNTIME DESTINATION "${Torch_INSTALL_BIN_SUBDIR}" | ||||
|           LIBRARY DESTINATION "${Torch_INSTALL_LIB_SUBDIR}" | ||||
|           ARCHIVE DESTINATION "${Torch_INSTALL_LIB_SUBDIR}") | ||||
|  | ||||
| INSTALL(FILES luaT.h | ||||
|           DESTINATION "${Torch_INSTALL_INCLUDE_SUBDIR}") | ||||
|  | ||||
| # Create luaT.cmake | ||||
| GET_TARGET_PROPERTY(LUAT_OUTPUT_NAME luaT LOCATION) | ||||
| GET_FILENAME_COMPONENT(LUAT_OUTPUT_NAME ${LUAT_OUTPUT_NAME} NAME) | ||||
| SET(LUAT_LIBRARIES "${Torch_INSTALL_LIB}/${LUAT_OUTPUT_NAME}") | ||||
| SET(LUAT_INCLUDE_DIR "${Torch_INSTALL_INCLUDE}") | ||||
| CONFIGURE_FILE(luaTConfig.cmake.in "${Torch_BINARY_DIR}/cmake-external/luaTConfig.cmake") | ||||
| INSTALL(FILES "${Torch_BINARY_DIR}/cmake-external/luaTConfig.cmake"  | ||||
|   DESTINATION "${Torch_INSTALL_CMAKE_SUBDIR}") | ||||
|  | ||||
| # luaT help | ||||
| ADD_TORCH_DOK(dok luaT "Torch C Libraries" "luaT" 5.1) | ||||
Some files were not shown because too many files have changed in this diff Show More
		Reference in New Issue
	
	Block a user