mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
Updates assertEqual to require atol and rtol, removes positional atol (#38872)
Summary: This updates assertEqual and assertEqual-like functions to either require both or neither of atol and rtol be specified. This should improve clarity around handling precision in the test suite, and it allows us to remove the legacy positional atol argument from assertEqual. In addition, the "message" kwarg is replace with a kwarg-only "msg" argument whose name is consistent with unittest's assertEqual argument. In the future we could make "msg" an optional third positional argument to be more consistent with unittest's assertEqual, but requiring it be specified should be clear, and we can easily update the signature to make "msg" an optional positional argument in the future, too. Pull Request resolved: https://github.com/pytorch/pytorch/pull/38872 Differential Revision: D21717199 Pulled By: mruberry fbshipit-source-id: 9feb856f94eee911b44f6c7140a1d07c1b026d3a
This commit is contained in:
committed by
Facebook GitHub Bot
parent
341fd63ff6
commit
6ddca30b2d
@ -102,7 +102,7 @@ def test_forward(unit_test_class, test_params):
|
||||
# Check that forward outputs are equal
|
||||
unit_test_class.assertEqual(
|
||||
python_output, cpp_output,
|
||||
message=generate_error_msg("forward output", cpp_output, python_output))
|
||||
msg=generate_error_msg("forward output", cpp_output, python_output))
|
||||
|
||||
run_cpp_test_fn_and_check_output()
|
||||
|
||||
@ -131,7 +131,7 @@ def compute_cpp_function_call(test_params_dict, arg_dict, functional_name):
|
||||
if 'cpp_function_call' in test_params_dict:
|
||||
return test_params_dict['cpp_function_call']
|
||||
elif 'cpp_options_args' in test_params_dict:
|
||||
cpp_forward_args_symbols = [arg_name for arg_name, _ in
|
||||
cpp_forward_args_symbols = [arg_name for arg_name, _ in
|
||||
arg_dict['input'] + arg_dict['target'] + arg_dict['extra_args']]
|
||||
return 'F::{}({}, {})'.format(
|
||||
functional_name, ", ".join(cpp_forward_args_symbols), test_params_dict['cpp_options_args'])
|
||||
|
||||
@ -152,12 +152,12 @@ def test_forward_backward(unit_test_class, test_params):
|
||||
|
||||
# Check that forward outputs are equal
|
||||
unit_test_class.assertEqual(python_output, cpp_output,
|
||||
message=generate_error_msg("forward output", cpp_output, python_output))
|
||||
msg=generate_error_msg("forward output", cpp_output, python_output))
|
||||
|
||||
# Check that module parameter gradients are equal after backward pass
|
||||
unit_test_class.assertEqual(
|
||||
len(python_grad_dict), len(cpp_grad_dict),
|
||||
message=generate_error_msg("# of parameters", len(cpp_grad_dict), len(python_grad_dict)))
|
||||
msg=generate_error_msg("# of parameters", len(cpp_grad_dict), len(python_grad_dict)))
|
||||
for key in python_grad_dict:
|
||||
param_name = None
|
||||
for suffix in ['_grad', '_grad_indices', '_grad_values']:
|
||||
@ -174,7 +174,7 @@ def test_forward_backward(unit_test_class, test_params):
|
||||
False, True))
|
||||
unit_test_class.assertEqual(
|
||||
python_grad_dict[key], cpp_grad_dict[key],
|
||||
message=generate_error_msg(
|
||||
msg=generate_error_msg(
|
||||
"`{}`'s {} gradient (`{}`)".format(param_name, sparsity_str, key),
|
||||
cpp_grad_dict[key], python_grad_dict[key]))
|
||||
|
||||
|
||||
Reference in New Issue
Block a user