-- Mirror of https://github.com/pytorch/pytorch.git
-- Synced 2025-10-21 05:34:18 +08:00
-- 81 lines, 2.0 KiB, Lua
-- We are using paths.require to appease mkl
|
|
require "paths"
|
|
paths.require "libtorch"
|
|
|
|
--- package stuff
|
|
--- Locate the directory that contains a package's Lua entry file.
-- Each template in `package.path` is tried in order with every `?` replaced
-- by `name`; the directory part of the first file that can be opened is
-- returned.
-- @param name package name. When nil/false, the 'torch' package is looked up
--   and the parent of its directory is returned instead.
-- @return the directory (without trailing separator), or nil when no
--   template resolves to a readable file.
function torch.packageLuaPath(name)
   if not name then
      -- Resolve the torch directory once, then strip its last component.
      local torchdir = torch.packageLuaPath('torch')
      if not torchdir then
         return nil
      end
      return string.match(torchdir, '(.*)/')
         or string.match(torchdir, '(.*)\\') -- windows?
   end

   -- "[^;]+" yields every non-empty entry, including the last one when
   -- package.path has no trailing ';' (a "(.-);" pattern would skip it).
   for template in string.gmatch(package.path, "[^;]+") do
      local path = string.gsub(template, "%?", name)
      local f = io.open(path)
      if f then
         f:close()
         -- Keep everything before the last separator.
         return string.match(path, "(.*)/")
            or string.match(path, "(.*)\\") -- windows?
      end
   end
end
|
|
|
|
--- Global helper that runs `file` through paths.dofile.
-- @param file name of the Lua file to execute
-- @param depth optional extra offset added to the base level of 3
--   (presumably a call-stack level used by paths.dofile to find the
--   including file -- confirm against the paths package documentation)
function include(file, depth)
   local extra = depth or 0
   paths.dofile(file, 3 + extra)
end
|
|
|
|
--- Execute a Lua file located in a package's directory.
-- The directory is resolved with torch.packageLuaPath and joined to `file`
-- with a '/'.
-- @param package package name passed to torch.packageLuaPath (note: this
--   parameter shadows the global `package` table inside the function)
-- @param file file name relative to the package directory
-- Raises an error naming the package when its directory cannot be found,
-- instead of failing on a nil concatenation.
function torch.include(package, file)
   local dir = torch.packageLuaPath(package)
   if not dir then
      -- Level 2 attributes the error to the caller of torch.include.
      error(string.format("could not locate package <%s> in package.path",
                          tostring(package)), 2)
   end
   dofile(dir .. '/' .. file)
end
|
|
|
|
--- Register a new class named `tname`, optionally deriving from `parenttname`.
-- Two closures are handed to torch.newmetatable: a factory that creates a
-- bare instance and a constructor that additionally runs `__init`.
-- @param tname name of the class to create
-- @param parenttname optional name of the parent class
-- @return the value returned by torch.newmetatable, followed by the result
--   of torch.getmetatable(parenttname) (nil when no parent was given)
function torch.class(tname, parenttname)

   -- Bare instance: an empty table tagged with the class metatable.
   local function factory()
      local object = {}
      torch.setmetatable(object, tname)
      return object
   end

   -- Full instance: bare instance plus the optional __init hook, which
   -- receives the constructor arguments.
   local function constructor(...)
      local object = factory()
      if object.__init then
         object:__init(...)
      end
      return object
   end

   local mt = torch.newmetatable(tname, parenttname, constructor, nil, factory)

   if not parenttname then
      return mt, nil
   end
   -- Parentheses truncate the lookup to a single value.
   return mt, (torch.getmetatable(parenttname))
end
|
|
|
|
--- Point torch.Tensor and torch.Storage at the given tensor type.
-- torch.Tensor becomes the constructor table registered under `typename`;
-- torch.Storage becomes the constructor table of the storage type reported
-- by a one-element tensor of that type.
-- @param typename string naming a torch type, e.g. 'torch.DoubleTensor'
-- Raises an error when `typename` is not a string or has no constructor table.
function torch.setdefaulttensortype(typename)
   assert(type(typename) == 'string', 'string expected')

   -- Reject unknown type names before touching any global state.
   if not torch.getconstructortable(typename) then
      error(string.format("<%s> is not a string describing a torch object", typename))
   end

   torch.Tensor = torch.getconstructortable(typename)

   -- Ask a throwaway tensor for its storage to learn the matching storage type.
   local storage = torch.Tensor(1):storage()
   torch.Storage = torch.getconstructortable(torch.typename(storage))
end
|
|
|
|
-- Start with double precision tensors as the default type.
torch.setdefaulttensortype('torch.DoubleTensor')

-- Pull in the Lua-side parts of the package, in this order.
-- NOTE(review): include() forwards a fixed level to paths.dofile, so these
-- calls presumably need to stay directly in this chunk rather than inside a
-- wrapper function -- confirm before refactoring further.
for _, luafile in ipairs({'Tensor.lua', 'File.lua', 'CmdLine.lua', 'Tester.lua', 'test.lua'}) do
   include(luafile)
end

return torch
|