Updates assertEqual to require atol and rtol, removes positional atol (#38872)

Summary:
This updates assertEqual and assertEqual-like functions to either require both or neither of atol and rtol be specified. This should improve clarity around handling precision in the test suite, and it allows us to remove the legacy positional atol argument from assertEqual. In addition, the "message" kwarg is replace with a kwarg-only "msg" argument whose name is consistent with unittest's assertEqual argument.

In the future we could make "msg" an optional third positional argument to be more consistent with unittest's assertEqual, but requiring it be specified should be clear, and we can easily update the signature to make "msg" an optional positional argument in the future, too.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/38872

Differential Revision: D21717199

Pulled By: mruberry

fbshipit-source-id: 9feb856f94eee911b44f6c7140a1d07c1b026d3a
This commit is contained in:
Mike Ruberry
2020-05-25 21:27:31 -07:00
committed by Facebook GitHub Bot
parent 341fd63ff6
commit 6ddca30b2d
29 changed files with 843 additions and 831 deletions

View File

@ -102,7 +102,7 @@ def test_forward(unit_test_class, test_params):
# Check that forward outputs are equal
unit_test_class.assertEqual(
python_output, cpp_output,
message=generate_error_msg("forward output", cpp_output, python_output))
msg=generate_error_msg("forward output", cpp_output, python_output))
run_cpp_test_fn_and_check_output()
@ -131,7 +131,7 @@ def compute_cpp_function_call(test_params_dict, arg_dict, functional_name):
if 'cpp_function_call' in test_params_dict:
return test_params_dict['cpp_function_call']
elif 'cpp_options_args' in test_params_dict:
cpp_forward_args_symbols = [arg_name for arg_name, _ in
cpp_forward_args_symbols = [arg_name for arg_name, _ in
arg_dict['input'] + arg_dict['target'] + arg_dict['extra_args']]
return 'F::{}({}, {})'.format(
functional_name, ", ".join(cpp_forward_args_symbols), test_params_dict['cpp_options_args'])

View File

@ -152,12 +152,12 @@ def test_forward_backward(unit_test_class, test_params):
# Check that forward outputs are equal
unit_test_class.assertEqual(python_output, cpp_output,
message=generate_error_msg("forward output", cpp_output, python_output))
msg=generate_error_msg("forward output", cpp_output, python_output))
# Check that module parameter gradients are equal after backward pass
unit_test_class.assertEqual(
len(python_grad_dict), len(cpp_grad_dict),
message=generate_error_msg("# of parameters", len(cpp_grad_dict), len(python_grad_dict)))
msg=generate_error_msg("# of parameters", len(cpp_grad_dict), len(python_grad_dict)))
for key in python_grad_dict:
param_name = None
for suffix in ['_grad', '_grad_indices', '_grad_values']:
@ -174,7 +174,7 @@ def test_forward_backward(unit_test_class, test_params):
False, True))
unit_test_class.assertEqual(
python_grad_dict[key], cpp_grad_dict[key],
message=generate_error_msg(
msg=generate_error_msg(
"`{}`'s {} gradient (`{}`)".format(param_name, sparsity_str, key),
cpp_grad_dict[key], python_grad_dict[key]))