-- Mirror of https://github.com/pytorch/pytorch.git (synced 2025-10-20 21:14:14 +08:00).
-- Lua source file: 280 lines, 7.8 KiB.
-- Method tables filled in below and then copied onto the metatable of every
-- torch.<T>Storage / torch.<T>Tensor class (see the registration loops).
local Storage, Tensor = {}, {}

-- Element-type prefixes T used to build the class names
-- 'torch.' .. T .. 'Storage' and 'torch.' .. T .. 'Tensor'.
local types = {
   'Byte', 'Char', 'Short', 'Int', 'Long', 'Float', 'Double',
}
|
|
|
|
-- tostring() functions for Tensor and Storage

-- Choose a printf-style format used to display every element of a storage.
-- Returns three values:
--   format : a string.format pattern such as "%5d", "%11.4e" or "%9.4f"
--   scale  : nil, or a power of ten that callers print once as a
--            "<scale> *" header and divide every element by
--   sz     : the field width embedded in `format`
-- Integer-valued data uses %d (or %11.4e when the largest magnitude has more
-- than 9 digits). Otherwise %e is used when the digit counts of the smallest
-- and largest magnitudes differ by more than 4; else %f, with a power-of-ten
-- scale when the largest magnitude is >= 1e5 or < 0.1.
local function Storage__printformat(self)
   -- Nothing gets formatted for an empty storage; also avoids calling
   -- minall/maxall on an empty tensor below.
   if self:size() == 0 then
      return '%f', nil, 0
   end

   -- Integer mode iff every element equals its own ceiling
   -- (NaN never does, since NaN ~= NaN).
   local intMode = true
   for i = 1, self:size() do
      if self[i] ~= math.ceil(self[i]) then
         intMode = false
         break
      end
   end

   -- Work on |x| in double precision to find the magnitude range.
   local tensor = torch.DoubleTensor(torch.DoubleStorage(self:size()):copy(self), 1, self:size()):abs()

   -- Convert the extreme magnitudes to digit counts, floor(log10(x)) + 1.
   -- A magnitude of exactly 0 is left as 0.
   local expMin = tensor:minall()
   if expMin ~= 0 then
      expMin = math.floor(math.log10(expMin)) + 1
   end
   local expMax = tensor:maxall()
   if expMax ~= 0 then
      expMax = math.floor(math.log10(expMax)) + 1
   end

   local format
   local scale
   local sz
   if intMode then
      if expMax > 9 then
         format = "%11.4e"
         sz = 11
      else
         format = "%SZd"
         sz = expMax + 1
      end
   else
      if expMax-expMin > 4 then
         format = "%SZ.4e"
         sz = 11
         -- Three-digit exponents need one more character.
         if math.abs(expMax) > 99 or math.abs(expMin) > 99 then
            sz = sz + 1
         end
      else
         if expMax > 5 or expMax < 0 then
            format = "%SZ.4f"
            sz = 7
            scale = math.pow(10, expMax-1)
         else
            format = "%SZ.4f"
            if expMax == 0 then
               sz = 7
            else
               sz = expMax+6
            end
         end
      end
   end

   if sz ~= sz then
      -- NaN in the data propagates into the digit counts and therefore into
      -- `sz`. Substituting it would produce "%nan.4f" or "%-nan.4f"
      -- (platform dependent), which is not a valid format; use plain %f.
      format = '%f'
   else
      format = string.gsub(format, 'SZ', sz)
   end
   if scale == 1 then
      scale = nil
   end
   return format, scale, sz
end
|
|
|
|
-- Render a storage as text: a leading newline, then one element per line,
-- then a "[torch.XStorage of size N]" footer. When Storage__printformat picks
-- a common scale factor, it is printed once as a "<scale> *" line and every
-- element is divided by it.
function Storage.__tostring__(self)
   local strt = {'\n'}
   local format, scale = Storage__printformat(self)
   -- A NaN field width can show up as "nan" or "-nan" inside the format
   -- (platform dependent); a plain substring search catches both spellings.
   if format:find('nan', 1, true) then format = '%f' end
   if scale then
      table.insert(strt, string.format('%g', scale) .. ' *\n')
      for i = 1,self:size() do
         table.insert(strt, string.format(format, self[i]/scale) .. '\n')
      end
   else
      for i = 1,self:size() do
         table.insert(strt, string.format(format, self[i]) .. '\n')
      end
   end
   table.insert(strt, '[' .. torch.typename(self) .. ' of size ' .. self:size() .. ']\n')
   -- Previously assigned to an undeclared (global) `str`; return directly.
   return table.concat(strt)
end
|
|
|
|
-- Copy every function collected in `Storage` onto the metatable of each
-- torch.<T>Storage class. rawset bypasses any __newindex on the metatable.
for _, typeName in ipairs(types) do
   local mt = torch.getmetatable('torch.' .. typeName .. 'Storage')
   for name, fn in pairs(Storage) do
      rawset(mt, name, fn)
   end
end
|
|
|
|
-- Render a 2-D tensor as text, one row per line, with every line prefixed by
-- `indent` (default ''). The format is chosen from the tensor's whole storage
-- via Storage__printformat. Columns are wrapped into chunks that fit in about
-- 80 characters; when wrapping is needed, each chunk is preceded by a
-- "Columns a to b" header. If a common scale factor was chosen, a
-- "<scale> *" header is printed and every element is divided by it.
local function Tensor__printMatrix(self, indent)
   local format, scale, sz = Storage__printformat(self:storage())
   -- A NaN field width can show up as "nan" or "-nan" inside the format
   -- (platform dependent); a plain substring search catches both spellings.
   if format:find('nan', 1, true) then format = '%f' end
   scale = scale or 1
   indent = indent or ''
   local strt = {indent}
   local nRow = self:size(1)
   local nColumn = self:size(2)
   -- Each printed column takes sz characters plus one separating space.
   local nColumnPerLine = math.floor((80-#indent)/(sz+1))
   -- Guarantee progress when one column is wider than the line; with 0 the
   -- loop below would never advance. (A NaN width is not < 1 and keeps its
   -- previous behavior of printing all columns in one chunk.)
   if nColumnPerLine < 1 then nColumnPerLine = 1 end
   local firstColumn = 1
   while firstColumn <= nColumn do
      local lastColumn
      if firstColumn + nColumnPerLine - 1 <= nColumn then
         lastColumn = firstColumn + nColumnPerLine - 1
      else
         lastColumn = nColumn
      end
      if nColumnPerLine < nColumn then
         if firstColumn ~= 1 then
            -- Indent later headers the same way as the first one, which is
            -- preceded by the initial `indent` entry of strt.
            table.insert(strt, '\n' .. indent)
         end
         table.insert(strt, 'Columns ' .. firstColumn .. ' to ' .. lastColumn .. '\n' .. indent)
      end
      if scale ~= 1 then
         table.insert(strt, string.format('%g', scale) .. ' *\n ' .. indent)
      end
      for l = 1, nRow do
         local row = self:select(1, l)
         for c = firstColumn, lastColumn do
            table.insert(strt, string.format(format, row[c]/scale))
            if c == lastColumn then
               table.insert(strt, '\n')
               if l ~= nRow then
                  -- Scaled output is shifted one extra space to sit under
                  -- the "<scale> *" header.
                  if scale ~= 1 then
                     table.insert(strt, indent .. ' ')
                  else
                     table.insert(strt, indent)
                  end
               end
            else
               table.insert(strt, ' ')
            end
         end
      end
      firstColumn = lastColumn + 1
   end
   return table.concat(strt)
end
|
|
|
|
-- Render a tensor with more than two dimensions as a sequence of 2-D slices.
-- The indices along the leading nDimension()-2 dimensions advance like an
-- odometer, with the first dimension varying fastest. Each slice is labelled
-- "(i1,i2,...,.,.) = " and printed by Tensor__printMatrix with a one-space
-- indent. Consecutive slices are separated by an extra newline.
local function Tensor__printTensor(self)
   local nLead = self:nDimension() - 2
   local index = torch.LongStorage(nLead)
   index:fill(1)
   -- Start one step before the first slice so the first advance gives (1,1,...).
   index[1] = 0

   local parts = {''}
   local isFirst = true
   while true do
      -- Advance the odometer, carrying into later dimensions on overflow.
      local exhausted = false
      local d = 1
      while true do
         index[d] = index[d] + 1
         if index[d] <= self:size(d) then
            break
         elseif d == nLead then
            exhausted = true
            break
         end
         index[d] = 1
         d = d + 1
      end
      if exhausted then
         break
      end

      if not isFirst then
         table.insert(parts, '\n')
      end
      isFirst = false

      table.insert(parts, '(')
      local slice = self
      for k = 1, nLead do
         slice = slice:select(1, index[k])
         table.insert(parts, index[k] .. ',')
      end
      table.insert(parts, '.,.) = \n')
      table.insert(parts, Tensor__printMatrix(slice, ' '))
   end
   return table.concat(parts)
end
|
|
|
|
-- Render a tensor as text followed by a footer naming its type and sizes.
-- Values are first copied into a DoubleTensor so every element type is
-- formatted the same way:
--   0-D: only "[<type> with no dimension]"
--   1-D: one element per line (with an optional "<scale> *" header)
--   2-D: Tensor__printMatrix
--   N-D: Tensor__printTensor, one 2-D slice at a time
function Tensor.__tostring__(self)
   local strt = {''}
   if self:nDimension() == 0 then
      table.insert(strt, '[' .. torch.typename(self) .. ' with no dimension]\n')
   else
      local tensor = torch.DoubleTensor():resize(self:size()):copy(self)
      if tensor:nDimension() == 1 then
         local format, scale = Storage__printformat(tensor:storage())
         -- A NaN field width can show up as "nan" or "-nan" inside the
         -- format (platform dependent); a plain search catches both.
         if format:find('nan', 1, true) then format = '%f' end
         if scale then
            table.insert(strt, string.format('%g', scale) .. ' *\n')
            for i = 1,tensor:size(1) do
               table.insert(strt, string.format(format, tensor[i]/scale) .. '\n')
            end
         else
            for i = 1,tensor:size(1) do
               table.insert(strt, string.format(format, tensor[i]) .. '\n')
            end
         end
         table.insert(strt, '[' .. torch.typename(self) .. ' of dimension ' .. tensor:size(1) .. ']\n')
      elseif tensor:nDimension() == 2 then
         table.insert(strt, Tensor__printMatrix(tensor))
         table.insert(strt, '[' .. torch.typename(self) .. ' of dimension ' .. tensor:size(1) .. 'x' .. tensor:size(2) .. ']\n')
      else
         table.insert(strt, Tensor__printTensor(tensor))
         local dims = {}
         for i = 1, tensor:nDimension() do
            dims[i] = tensor:size(i)
         end
         table.insert(strt, '[' .. torch.typename(self) .. ' of dimension ' .. table.concat(dims, 'x') .. ']\n')
      end
   end
   return table.concat(strt)
end
|
|
|
|
-- Without an argument, return the tensor's type name (e.g. "torch.FloatTensor").
-- With a type name, return `self` if it already has that type; otherwise
-- build a new tensor of the requested type and, if `self` has any elements,
-- resize it to self's size and copy the values in.
function Tensor.type(self, typename)
   local current = torch.typename(self)
   if not typename then
      return current
   end
   if typename == current then
      return self
   end
   local converted = torch.getmetatable(typename).new()
   if self:nElement() > 0 then
      converted:resize(self:size()):copy(self)
   end
   return converted
end
|
|
|
|
-- Convert `self` to whatever type `other:type()` reports, delegating to
-- self:type().
function Tensor.typeAs(self, other)
   local target = other:type()
   return self:type(target)
end
|
|
|
|
-- Shorthand conversions: t:byte() is t:type('torch.ByteTensor'), and so on.
-- The former second parameter `type` was never used and shadowed the Lua
-- builtin; extra arguments are still accepted and ignored.

function Tensor.byte(self)
   return self:type('torch.ByteTensor')
end

function Tensor.char(self)
   return self:type('torch.CharTensor')
end

function Tensor.short(self)
   return self:type('torch.ShortTensor')
end

function Tensor.int(self)
   return self:type('torch.IntTensor')
end

function Tensor.long(self)
   return self:type('torch.LongTensor')
end

function Tensor.float(self)
   return self:type('torch.FloatTensor')
end

function Tensor.double(self)
   return self:type('torch.DoubleTensor')
end
|
|
|
|
|
|
-- Copy every function collected in `Tensor` onto the metatable of each
-- torch.<T>Tensor class. rawset bypasses any __newindex on the metatable.
for _, typeName in ipairs(types) do
   local mt = torch.getmetatable('torch.' .. typeName .. 'Tensor')
   for name, fn in pairs(Tensor) do
      rawset(mt, name, fn)
   end
end
|