mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
Allow running subset of tests
This commit is contained in:
54
Tester.lua
54
Tester.lua
@ -66,8 +66,11 @@ function Tester:pcall(f)
|
||||
return true, res, nerr == #self.errors
|
||||
end
|
||||
|
||||
function Tester:report()
|
||||
print('Completed ' .. #self.tests .. ' tests with ' .. #self.errors .. ' errors')
|
||||
function Tester:report(tests)
|
||||
if not tests then
|
||||
tests = self.tests
|
||||
end
|
||||
print('Completed ' .. #tests .. ' tests with ' .. #self.errors .. ' errors')
|
||||
print()
|
||||
print(string.rep('-',80))
|
||||
for i,v in ipairs(self.errors) do
|
||||
@ -76,13 +79,32 @@ function Tester:report()
|
||||
end
|
||||
end
|
||||
|
||||
function Tester:run()
|
||||
print('Running ' .. #self.tests .. ' tests')
|
||||
local statstr = string.rep('_',#self.tests)
|
||||
function Tester:run(run_tests)
|
||||
local tests, testnames
|
||||
tests = self.tests
|
||||
testnames = self.testnames
|
||||
if type(run_tests) == 'string' then
|
||||
run_tests = {run_tests}
|
||||
end
|
||||
if type(run_tests) == 'table' then
|
||||
tests = {}
|
||||
testnames = {}
|
||||
for i,fun in ipairs(self.tests) do
|
||||
for j,name in ipairs(run_tests) do
|
||||
if self.testnames[i] == name then
|
||||
tests[#tests+1] = self.tests[i]
|
||||
testnames[#testnames+1] = self.testnames[i]
|
||||
end
|
||||
end
|
||||
end
|
||||
end
|
||||
|
||||
print('Running ' .. #tests .. ' tests')
|
||||
local statstr = string.rep('_',#tests)
|
||||
local pstr = ''
|
||||
io.write(statstr .. '\r')
|
||||
for i,v in ipairs(self.tests) do
|
||||
self.curtestname = self.testnames[i]
|
||||
for i,v in ipairs(tests) do
|
||||
self.curtestname = testnames[i]
|
||||
|
||||
--clear
|
||||
io.write('\r' .. string.rep(' ', pstr:len()))
|
||||
@ -95,17 +117,17 @@ function Tester:run()
|
||||
local stat, message, pass = self:pcall(v)
|
||||
|
||||
if pass then
|
||||
--io.write(string.format('\b_'))
|
||||
statstr = statstr:sub(1,i-1) .. '_' .. statstr:sub(i+1)
|
||||
--io.write(string.format('\b_'))
|
||||
statstr = statstr:sub(1,i-1) .. '_' .. statstr:sub(i+1)
|
||||
else
|
||||
statstr = statstr:sub(1,i-1) .. '*' .. statstr:sub(i+1)
|
||||
--io.write(string.format('\b*'))
|
||||
statstr = statstr:sub(1,i-1) .. '*' .. statstr:sub(i+1)
|
||||
--io.write(string.format('\b*'))
|
||||
end
|
||||
|
||||
if not stat then
|
||||
print()
|
||||
print('Function call failed: Test No ' .. i .. ' ' .. self.testnames[i])
|
||||
print(message)
|
||||
print()
|
||||
print('Function call failed: Test No ' .. i .. ' ' .. testnames[i])
|
||||
print(message)
|
||||
end
|
||||
collectgarbage()
|
||||
end
|
||||
@ -118,14 +140,14 @@ function Tester:run()
|
||||
io.flush()
|
||||
print()
|
||||
print()
|
||||
self:report()
|
||||
self:report(tests)
|
||||
end
|
||||
|
||||
function Tester:add(f,name)
|
||||
name = name or 'unknown'
|
||||
if type(f) == "table" then
|
||||
for i,v in pairs(f) do
|
||||
self:add(v,i)
|
||||
self:add(v,i)
|
||||
end
|
||||
elseif type(f) == "function" then
|
||||
self.tests[#self.tests+1] = f
|
||||
|
Reference in New Issue
Block a user