mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
Register std_mean ref as a decomposition
Signed-off-by: Edward Z. Yang <ezyangfb.com> Pull Request resolved: https://github.com/pytorch/pytorch/pull/78468 Approved by: https://github.com/ngimel
This commit is contained in:
committed by
PyTorch MergeBot
parent
523c9c2ac2
commit
eee2aa14a6
@ -144,12 +144,12 @@ def _getDefaultRtolAndAtol(dtype0, dtype1):
|
||||
return rtol, atol
|
||||
|
||||
|
||||
def op_assert_ref(test_case, op, test_dtype, orig, decomp, ref, args, kwargs):
|
||||
assert orig.dtype == decomp.dtype, f"Operation: {op}"
|
||||
def op_assert_ref(test_case, op, test_dtype, i, orig, decomp, ref, args, kwargs):
|
||||
assert orig.dtype == decomp.dtype, f"{i} Operation: {op}"
|
||||
if orig.numel() == 0 or decomp.numel() == 0:
|
||||
assert orig.numel() == decomp.numel()
|
||||
return
|
||||
assert orig.shape == decomp.shape, f"Operation: {op}"
|
||||
assert orig.shape == decomp.shape, f"{i} Operation: {op}"
|
||||
tol_table = {
|
||||
(torch.bfloat16, torch.ops.aten.native_layer_norm.default): 1e-5,
|
||||
(torch.float16, torch.ops.aten.native_layer_norm.default): 1e-5,
|
||||
@ -163,7 +163,7 @@ def op_assert_ref(test_case, op, test_dtype, orig, decomp, ref, args, kwargs):
|
||||
if decomp_diff > orig_diff + atol:
|
||||
raise RuntimeError(
|
||||
f"Difference from float64 is larger with decomposition {op.__name__}"
|
||||
f" than original. Original max diff: {orig_diff}, Decomp max diff: {decomp_diff}\n"
|
||||
f" than original on output {i}. Original max diff: {orig_diff}, Decomp max diff: {decomp_diff}\n"
|
||||
f"atol = {atol}\n"
|
||||
f"args = {args}\n"
|
||||
f"kwargs = {kwargs}"
|
||||
@ -414,11 +414,11 @@ class TestDecomp(TestCase):
|
||||
real_out_double, _ = tree_flatten(
|
||||
func(*tree_map(upcast, args), **tree_map(upcast, kwargs))
|
||||
)
|
||||
for orig, decomp, ref in zip(real_out, decomp_out, real_out_double):
|
||||
for i, orig, decomp, ref in zip(range(len(real_out)), real_out, decomp_out, real_out_double):
|
||||
if orig is None:
|
||||
assert decomp is None
|
||||
continue
|
||||
op_assert_ref(self, func, test_dtype, orig, decomp, ref, args, kwargs)
|
||||
op_assert_ref(self, func, test_dtype, i, orig, decomp, ref, args, kwargs)
|
||||
else:
|
||||
for orig, decomp in zip(real_out, decomp_out):
|
||||
if orig is None:
|
||||
|
||||
Reference in New Issue
Block a user