mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
Adding assertTableEq function to test table equality.
* also added some minor spacing to help readability.
This commit is contained in:
12
Tester.lua
12
Tester.lua
@ -16,33 +16,45 @@ function Tester:assert_sub (condition, message)
|
||||
self.errors[#self.errors+1] = self.curtestname .. '\n' .. message .. '\n' .. ss .. '\n'
|
||||
end
|
||||
end
|
||||
|
||||
function Tester:assert (condition, message)
|
||||
self:assert_sub(condition,string.format('%s\n%s condition=%s',message,' BOOL violation ', tostring(condition)))
|
||||
end
|
||||
|
||||
function Tester:assertlt (val, condition, message)
|
||||
self:assert_sub(val<condition,string.format('%s\n%s val=%s, condition=%s',message,' LT(<) violation ', tostring(val), tostring(condition)))
|
||||
end
|
||||
|
||||
function Tester:assertgt (val, condition, message)
|
||||
self:assert_sub(val>condition,string.format('%s\n%s val=%s, condition=%s',message,' GT(>) violation ', tostring(val), tostring(condition)))
|
||||
end
|
||||
|
||||
function Tester:assertle (val, condition, message)
|
||||
self:assert_sub(val<=condition,string.format('%s\n%s val=%s, condition=%s',message,' LE(<=) violation ', tostring(val), tostring(condition)))
|
||||
end
|
||||
|
||||
function Tester:assertge (val, condition, message)
|
||||
self:assert_sub(val>=condition,string.format('%s\n%s val=%s, condition=%s',message,' GE(>=) violation ', tostring(val), tostring(condition)))
|
||||
end
|
||||
|
||||
function Tester:asserteq (val, condition, message)
|
||||
self:assert_sub(val==condition,string.format('%s\n%s val=%s, condition=%s',message,' EQ(==) violation ', tostring(val), tostring(condition)))
|
||||
end
|
||||
|
||||
function Tester:assertne (val, condition, message)
|
||||
self:assert_sub(val~=condition,string.format('%s\n%s val=%s, condition=%s',message,' NE(~=) violation ', tostring(val), tostring(condition)))
|
||||
end
|
||||
|
||||
function Tester:assertTensorEq(ta, tb, condition, message)
|
||||
local diff = ta-tb
|
||||
local err = diff:abs():max()
|
||||
self:assert_sub(err<condition,string.format('%s\n%s val=%s, condition=%s',message,' TensorEQ(~=) violation ', tostring(err), tostring(condition)))
|
||||
end
|
||||
|
||||
function Tester:assertTableEq(ta, condition, message)
|
||||
self:assert_sub(unpack(ta) == unpack(condition), string.format('%s\n%s val=%s, condition=%s',message,' TensorEQ(~=) violation ', tostring(err), tostring(condition)))
|
||||
end
|
||||
|
||||
function Tester:pcall(f)
|
||||
local nerr = #self.errors
|
||||
local res = f()
|
||||
|
Reference in New Issue
Block a user