mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
241 lines
7.4 KiB
Lua
241 lines
7.4 KiB
Lua
-- Metatable registered under the name 'torch.File'; the methods defined
-- below are attached to it.
local File = torch.getmetatable('torch.File')
|
|
|
|
-- Write a boolean as an integer: 1 for any truthy value, 0 for false/nil.
function File:writeBool(value)
   self:writeInt(value and 1 or 0)
end
|
|
|
|
-- Read one integer and interpret it as a boolean: true only when it equals 1.
function File:readBool()
   local flag = self:readInt()
   return flag == 1
end
|
|
|
|
-- Integer type tags. writeObject emits one in front of every serialized
-- value and readObject dispatches on it, so the numeric values are part of
-- the on-disk format.
local TYPE_NIL      = 0
local TYPE_NUMBER   = 1
local TYPE_STRING   = 2
local TYPE_TABLE    = 3
local TYPE_TORCH    = 4
local TYPE_BOOLEAN  = 5
local TYPE_FUNCTION = 6
|
|
|
|
-- Classify `object` for serialization.
-- Returns the TYPE_* tag that writeObject would use for it, or nil when the
-- value cannot be written.
function File:isWritableObject(object)
   -- nil is the only non-boolean falsy value
   if object == nil then
      return TYPE_NIL
   end

   -- registered Torch classes take precedence over plain Lua types
   local className = torch.typename(object)
   if className and torch.factory(className) then
      return TYPE_TORCH
   end

   local kind = type(object)
   local plainTags = {
      table = TYPE_TABLE,
      number = TYPE_NUMBER,
      string = TYPE_STRING,
      boolean = TYPE_BOOLEAN,
   }
   local tag = plainTags[kind]
   if tag then
      return tag
   end

   -- a function is writable only if string.dump accepts it
   if kind == 'function' and pcall(string.dump, object) then
      return TYPE_FUNCTION
   end

   return nil
end
|
|
|
|
-- Serialize `object` to this file.
--
-- The value is preceded by its integer type tag. Numbers are written as
-- doubles, booleans through writeBool, strings as length + chars, and
-- functions as their string.dump bytecode followed by a table of upvalues.
-- Tables and Torch objects are written once per file: each gets an index the
-- first time it is seen, and later occurrences write only that index.
--
-- Raises an error when `object` is of a type that cannot be written.
function File:writeObject(object)
   -- we use an environment to keep a record of written objects
   if not torch.getenv(self).writeObjects then
      torch.setenv(self, {writeObjects={}, writeObjectsRef={}, readObjects={}})
   end

   -- if nil object, only write the type and return
   if object == nil then
      self:writeInt(TYPE_NIL)
      return
   end

   -- check the type we are dealing with
   local typeidx = self:isWritableObject(object)
   if not typeidx then
      error(string.format('Unwritable object <%s>', type(object)))
   end
   self:writeInt(typeidx)

   if typeidx == TYPE_NUMBER then
      self:writeDouble(object)
   elseif typeidx == TYPE_BOOLEAN then
      self:writeBool(object)
   elseif typeidx == TYPE_STRING then
      local stringStorage = torch.CharStorage():string(object)
      self:writeInt(#stringStorage)
      self:writeChar(stringStorage)
   elseif typeidx == TYPE_FUNCTION then
      -- Collect upvalues by explicit slot number. Deriving the slot from
      -- #upvalues would never advance past an upvalue whose value is nil
      -- (storing nil does not grow the table), looping forever.
      local upvalues = {}
      local slot = 1
      while true do
         local name, value = debug.getupvalue(object, slot)
         if not name then break end
         upvalues[slot] = value
         slot = slot + 1
      end
      local dumped = string.dump(object)
      local stringStorage = torch.CharStorage():string(dumped)
      self:writeInt(#stringStorage)
      self:writeChar(stringStorage)
      self:writeObject(upvalues)
   elseif typeidx == TYPE_TORCH or typeidx == TYPE_TABLE then
      -- check it exists already (we look at the pointer!)
      local objects = torch.getenv(self).writeObjects
      local objectsRef = torch.getenv(self).writeObjectsRef
      local index = objects[torch.pointer(object)]

      if index then
         -- if already exists, write only its index
         self:writeInt(index)
      else
         -- else write the object itself
         index = (objects.nWriteObject or 0) + 1
         objects[torch.pointer(object)] = index
         objectsRef[object] = index -- we make sure the object is not going to disappear
         self:writeInt(index)
         objects.nWriteObject = index

         if typeidx == TYPE_TORCH then
            local version = torch.CharStorage():string('V ' .. torch.version(object))
            local className = torch.CharStorage():string(torch.typename(object))
            self:writeInt(#version)
            self:writeChar(version)
            self:writeInt(#className)
            self:writeChar(className)
            if object.write then
               -- the class provides its own serializer
               object:write(self)
            elseif type(object) == 'table' then
               -- table-based class: write every field that is writable
               local var = {}
               for k,v in pairs(object) do
                  if self:isWritableObject(v) then
                     var[k] = v
                  else
                     -- tostring: keys may be booleans/tables, which '%s'
                     -- rejects on Lua 5.1
                     print(string.format('$ Warning: cannot write object field <%s>', tostring(k)))
                  end
               end
               self:writeObject(var)
            else
               error(string.format('<%s> is a non-serializable Torch object', torch.typename(object)))
            end
         else -- it is a table
            local size = 0
            for _ in pairs(object) do
               size = size + 1
            end
            self:writeInt(size)
            for k,v in pairs(object) do
               self:writeObject(k)
               self:writeObject(v)
            end
         end
      end
   else
      error('Unwritable object')
   end
end
|
|
|
|
-- Read back one value previously written with writeObject.
--
-- Reads the integer type tag and dispatches on it. Tables and Torch objects
-- are cached by their stored index, so a repeated reference returns the same
-- Lua value that was built the first time.
--
-- Returns the deserialized value (nil for a TYPE_NIL tag).
-- Raises an error on an unknown type tag, an unknown Torch class name, or a
-- Torch object that is neither readable nor a table.
function File:readObject()
   -- we use an environment to keep a record of read objects
   if not torch.getenv(self).writeObjects then
      torch.setenv(self, {writeObjects={}, writeObjectsRef={}, readObjects={}})
   end

   -- read the typeidx
   local typeidx = self:readInt()

   -- is it nil?
   if typeidx == TYPE_NIL then
      return nil
   end

   if typeidx == TYPE_NUMBER then
      return self:readDouble()
   elseif typeidx == TYPE_BOOLEAN then
      return self:readBool()
   elseif typeidx == TYPE_STRING then
      local size = self:readInt()
      return self:readChar(size):string()
   elseif typeidx == TYPE_FUNCTION then
      local size = self:readInt()
      local dumped = self:readChar(size):string()
      -- NOTE(security): this loads a chunk taken straight from the file;
      -- only read files from a trusted source.
      local func = loadstring(dumped)
      local upvalues = self:readObject()
      -- pairs rather than ipairs: a nil upvalue leaves a hole in the table,
      -- and ipairs would stop there and skip the remaining upvalues
      for index,upvalue in pairs(upvalues) do
         debug.setupvalue(func, index, upvalue)
      end
      return func
   elseif typeidx == TYPE_TABLE or typeidx == TYPE_TORCH then
      -- read the index
      local index = self:readInt()

      -- check it is loaded already
      local objects = torch.getenv(self).readObjects
      if objects[index] then
         return objects[index]
      end

      -- otherwise read it
      if typeidx == TYPE_TORCH then
         local className
         local version = self:readChar(self:readInt()):string()
         local versionNumber = tonumber(string.match(version, '^V (.*)$'))
         if not versionNumber then
            -- no 'V <n>' header: the string just read is the class name
            className = version
            versionNumber = 0 -- file created before existence of versioning system
         else
            className = self:readChar(self:readInt()):string()
         end

         local factory = torch.factory(className)
         if not factory then
            error(string.format('unknown Torch class <%s>', className))
         end
         local object = factory()
         -- register before reading the body so self-references resolve
         objects[index] = object

         if object.read then
            object:read(self, versionNumber)
         elseif type(object) == 'table' then
            local var = self:readObject()
            for k,v in pairs(var) do
               object[k] = v
            end
         else
            error(string.format('Cannot load object class <%s>', className))
         end
         return object
      else -- it is a table
         local size = self:readInt()
         local object = {}
         -- register before reading the entries so self-references resolve
         objects[index] = object
         for i = 1,size do
            local k = self:readObject()
            local v = self:readObject()
            object[k] = v
         end
         return object
      end
   else
      error('unknown object')
   end
end
|
|
|
|
-- simple helpers to save/load arbitrary objects/tables
|
|
-- Serialize `object` into the file at `filename`.
--   filename: path passed to torch.DiskFile, opened with mode 'w'
--   object:   value handed to File:writeObject
--   mode:     name of the file method called to select the encoding before
--             writing; defaults to 'binary'
-- The file is closed even when serialization raises; the original error is
-- then re-raised unchanged.
function torch.save(filename, object, mode)
   mode = mode or 'binary'
   local file = torch.DiskFile(filename, 'w')
   local ok, err = pcall(function()
      file[mode](file)
      file:writeObject(object)
   end)
   file:close()
   if not ok then
      -- level 0: do not prepend another position to the message
      error(err, 0)
   end
end
|
|
|
|
-- Read back one object from the file at `filename`.
--   filename: path passed to torch.DiskFile, opened with mode 'r'
--   mode:     name of the file method called to select the encoding before
--             reading; defaults to 'binary'
-- Returns the value produced by File:readObject.
-- The file is closed even when reading raises; the original error is then
-- re-raised unchanged.
function torch.load(filename, mode)
   mode = mode or 'binary'
   local file = torch.DiskFile(filename, 'r')
   local ok, result = pcall(function()
      file[mode](file)
      return file:readObject()
   end)
   file:close()
   if not ok then
      -- level 0: do not prepend another position to the message
      error(result, 0)
   end
   return result
end
|
|
|
|
-- public API (saveobj/loadobj are safe for global import)
-- Both names are plain aliases of torch.save and torch.load.
torch.saveobj, torch.loadobj = torch.save, torch.load
|