Added nll_loss_forward decomposition + some other minor decomps (#83235)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/83235
Approved by: https://github.com/ngimel
This commit is contained in:
Horace He
2022-08-13 07:36:50 +00:00
committed by PyTorch MergeBot
parent 097951a967
commit f02f304657
5 changed files with 88 additions and 9 deletions

View File

@ -193,7 +193,6 @@ def op_assert_equal(test_case, op, test_dtype, orig, decomp, args, kwargs):
rtol, atol = tol_table[(decomp.dtype, op)]
else:
rtol, atol = _getDefaultRtolAndAtol(orig.dtype, decomp.dtype)
test_case.assertEqual(orig, decomp, rtol=rtol, atol=atol, msg=f"{op.__name__}\nargs = {args}\nkwargs = {kwargs}")
@ -424,14 +423,16 @@ class TestDecomp(TestCase):
func(*tree_map(upcast, args), **tree_map(upcast, kwargs))
)
for i, orig, decomp, ref in zip(range(len(real_out)), real_out, decomp_out, real_out_double):
if orig is None:
assert decomp is None
if not isinstance(orig, torch.Tensor):
assert type(orig) == type(decomp)
assert orig == decomp
continue
op_assert_ref(self, func, test_dtype, i, orig, decomp, ref, args, kwargs)
else:
for orig, decomp in zip(real_out, decomp_out):
if orig is None:
assert decomp is None
if not isinstance(orig, torch.Tensor):
assert type(orig) == type(decomp)
assert orig == decomp
continue
op_assert_equal(self, func, test_dtype, orig, decomp, args, kwargs)