Updates assertEqual to require atol and rtol, removes positional atol (#38872)

Summary:
This updates assertEqual and assertEqual-like functions to either require both or neither of atol and rtol be specified. This should improve clarity around handling precision in the test suite, and it allows us to remove the legacy positional atol argument from assertEqual. In addition, the "message" kwarg is replace with a kwarg-only "msg" argument whose name is consistent with unittest's assertEqual argument.

In the future we could make "msg" an optional third positional argument to be more consistent with unittest's assertEqual, but requiring it be specified should be clear, and we can easily update the signature to make "msg" an optional positional argument in the future, too.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/38872

Differential Revision: D21740237

Pulled By: mruberry

fbshipit-source-id: acbc027aa1d7877a49664d94db9a5fff91a07042
This commit is contained in:
Mike Ruberry
2020-05-27 06:28:05 -07:00
committed by Facebook GitHub Bot
parent 9b95f757af
commit 13120bf677
29 changed files with 847 additions and 835 deletions

View File

@ -89,7 +89,7 @@ class TestNamedTensor(TestCase):
return
result = op(*args)
self.assertEqual(result.names, expected_names,
message='Name inference for {} on device {} failed'.format(
msg='Name inference for {} on device {} failed'.format(
op.__name__, device))
# TODO(rzou): Some form of this check should be added to self.assertEqual.
@ -147,8 +147,8 @@ class TestNamedTensor(TestCase):
prev_none_refcnt = sys.getrefcount(None)
scope()
self.assertEqual(sys.getrefcount(None), prev_none_refcnt,
message='Using tensor.names should not change '
'the refcount of Py_None')
msg='Using tensor.names should not change '
'the refcount of Py_None')
def test_has_names(self):
unnamed = torch.empty(2, 3)
@ -835,7 +835,7 @@ class TestNamedTensor(TestCase):
# Get a better error message by catching the error and asserting.
raise RuntimeError('{}: {}'.format(testcase.name, err))
self.assertEqual(out.names, tensor.names,
message=testcase.name)
msg=testcase.name)
def fn(name, *args, **kwargs):
return [Function(name, lambda t: getattr(torch, name)(t, *args, **kwargs))]