Register std_mean ref as a decomposition

Signed-off-by: Edward Z. Yang <ezyangfb.com>

Pull Request resolved: https://github.com/pytorch/pytorch/pull/78468

Approved by: https://github.com/ngimel
This commit is contained in:
Edward Z. Yang
2022-05-31 07:26:03 -07:00
committed by PyTorch MergeBot
parent 523c9c2ac2
commit eee2aa14a6
4 changed files with 14 additions and 11 deletions

View File

@ -144,12 +144,12 @@ def _getDefaultRtolAndAtol(dtype0, dtype1):
return rtol, atol
def op_assert_ref(test_case, op, test_dtype, orig, decomp, ref, args, kwargs):
assert orig.dtype == decomp.dtype, f"Operation: {op}"
def op_assert_ref(test_case, op, test_dtype, i, orig, decomp, ref, args, kwargs):
assert orig.dtype == decomp.dtype, f"{i} Operation: {op}"
if orig.numel() == 0 or decomp.numel() == 0:
assert orig.numel() == decomp.numel()
return
assert orig.shape == decomp.shape, f"Operation: {op}"
assert orig.shape == decomp.shape, f"{i} Operation: {op}"
tol_table = {
(torch.bfloat16, torch.ops.aten.native_layer_norm.default): 1e-5,
(torch.float16, torch.ops.aten.native_layer_norm.default): 1e-5,
@ -163,7 +163,7 @@ def op_assert_ref(test_case, op, test_dtype, orig, decomp, ref, args, kwargs):
if decomp_diff > orig_diff + atol:
raise RuntimeError(
f"Difference from float64 is larger with decomposition {op.__name__}"
f" than original. Original max diff: {orig_diff}, Decomp max diff: {decomp_diff}\n"
f" than original on output {i}. Original max diff: {orig_diff}, Decomp max diff: {decomp_diff}\n"
f"atol = {atol}\n"
f"args = {args}\n"
f"kwargs = {kwargs}"
@ -414,11 +414,11 @@ class TestDecomp(TestCase):
real_out_double, _ = tree_flatten(
func(*tree_map(upcast, args), **tree_map(upcast, kwargs))
)
for orig, decomp, ref in zip(real_out, decomp_out, real_out_double):
for i, orig, decomp, ref in zip(range(len(real_out)), real_out, decomp_out, real_out_double):
if orig is None:
assert decomp is None
continue
op_assert_ref(self, func, test_dtype, orig, decomp, ref, args, kwargs)
op_assert_ref(self, func, test_dtype, i, orig, decomp, ref, args, kwargs)
else:
for orig, decomp in zip(real_out, decomp_out):
if orig is None: