mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
Added nll_loss_forward decomposition + some other minor decomps (#83235)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/83235 Approved by: https://github.com/ngimel
This commit is contained in:
committed by
PyTorch MergeBot
parent
097951a967
commit
f02f304657
@ -193,7 +193,6 @@ def op_assert_equal(test_case, op, test_dtype, orig, decomp, args, kwargs):
|
||||
rtol, atol = tol_table[(decomp.dtype, op)]
|
||||
else:
|
||||
rtol, atol = _getDefaultRtolAndAtol(orig.dtype, decomp.dtype)
|
||||
|
||||
test_case.assertEqual(orig, decomp, rtol=rtol, atol=atol, msg=f"{op.__name__}\nargs = {args}\nkwargs = {kwargs}")
|
||||
|
||||
|
||||
@ -424,14 +423,16 @@ class TestDecomp(TestCase):
|
||||
func(*tree_map(upcast, args), **tree_map(upcast, kwargs))
|
||||
)
|
||||
for i, orig, decomp, ref in zip(range(len(real_out)), real_out, decomp_out, real_out_double):
|
||||
if orig is None:
|
||||
assert decomp is None
|
||||
if not isinstance(orig, torch.Tensor):
|
||||
assert type(orig) == type(decomp)
|
||||
assert orig == decomp
|
||||
continue
|
||||
op_assert_ref(self, func, test_dtype, i, orig, decomp, ref, args, kwargs)
|
||||
else:
|
||||
for orig, decomp in zip(real_out, decomp_out):
|
||||
if orig is None:
|
||||
assert decomp is None
|
||||
if not isinstance(orig, torch.Tensor):
|
||||
assert type(orig) == type(decomp)
|
||||
assert orig == decomp
|
||||
continue
|
||||
op_assert_equal(self, func, test_dtype, orig, decomp, args, kwargs)
|
||||
|
||||
|
Reference in New Issue
Block a user