mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
Updates assertEqual to require atol and rtol, removes positional atol (#38872)
Summary: This updates assertEqual and assertEqual-like functions to either require both or neither of atol and rtol be specified. This should improve clarity around handling precision in the test suite, and it allows us to remove the legacy positional atol argument from assertEqual. In addition, the "message" kwarg is replace with a kwarg-only "msg" argument whose name is consistent with unittest's assertEqual argument. In the future we could make "msg" an optional third positional argument to be more consistent with unittest's assertEqual, but requiring it be specified should be clear, and we can easily update the signature to make "msg" an optional positional argument in the future, too. Pull Request resolved: https://github.com/pytorch/pytorch/pull/38872 Differential Revision: D21740237 Pulled By: mruberry fbshipit-source-id: acbc027aa1d7877a49664d94db9a5fff91a07042
This commit is contained in:
committed by
Facebook GitHub Bot
parent
9b95f757af
commit
13120bf677
@ -89,7 +89,7 @@ class TestNamedTensor(TestCase):
|
||||
return
|
||||
result = op(*args)
|
||||
self.assertEqual(result.names, expected_names,
|
||||
message='Name inference for {} on device {} failed'.format(
|
||||
msg='Name inference for {} on device {} failed'.format(
|
||||
op.__name__, device))
|
||||
|
||||
# TODO(rzou): Some form of this check should be added to self.assertEqual.
|
||||
@ -147,8 +147,8 @@ class TestNamedTensor(TestCase):
|
||||
prev_none_refcnt = sys.getrefcount(None)
|
||||
scope()
|
||||
self.assertEqual(sys.getrefcount(None), prev_none_refcnt,
|
||||
message='Using tensor.names should not change '
|
||||
'the refcount of Py_None')
|
||||
msg='Using tensor.names should not change '
|
||||
'the refcount of Py_None')
|
||||
|
||||
def test_has_names(self):
|
||||
unnamed = torch.empty(2, 3)
|
||||
@ -835,7 +835,7 @@ class TestNamedTensor(TestCase):
|
||||
# Get a better error message by catching the error and asserting.
|
||||
raise RuntimeError('{}: {}'.format(testcase.name, err))
|
||||
self.assertEqual(out.names, tensor.names,
|
||||
message=testcase.name)
|
||||
msg=testcase.name)
|
||||
|
||||
def fn(name, *args, **kwargs):
|
||||
return [Function(name, lambda t: getattr(torch, name)(t, *args, **kwargs))]
|
||||
|
Reference in New Issue
Block a user