Revert "Fixed type promotion semantics for native_batch_norm and native_layer_norm (#77407)"

This reverts commit 70d80fb42480f4df7bd369f8f9f1500b58c5c603.

Reverted https://github.com/pytorch/pytorch/pull/77407 on behalf of https://github.com/malfet due to as it broke meta tests ( I guess due to landrace), see 70d80fb424
This commit is contained in:
PyTorch MergeBot
2022-05-20 02:31:57 +00:00
parent cecb2ad95e
commit 03546e9c07
4 changed files with 95 additions and 99 deletions

View File

@ -144,27 +144,19 @@ def _getDefaultRtolAndAtol(dtype0, dtype1):
return rtol, atol
def op_assert_ref(test_case, op, test_dtype, orig, decomp, ref, args, kwargs):
def op_assert_ref(test_case, op, orig, decomp, ref, args, kwargs):
assert orig.dtype == decomp.dtype, f"Operation: {op}"
if orig.numel() == 0 or decomp.numel() == 0:
assert orig.numel() == decomp.numel()
return
assert orig.shape == decomp.shape, f"Operation: {op}"
tol_table = {
(torch.bfloat16, torch.ops.aten.native_layer_norm.default): 1e-5,
(torch.float16, torch.ops.aten.native_layer_norm.default): 1e-5,
(torch.bfloat16, torch.ops.aten.native_batch_norm.default): 1e-5,
(torch.float16, torch.ops.aten.native_batch_norm.default): 1e-5,
}
if ref.is_floating_point():
orig_diff = (orig - ref).abs().max()
decomp_diff = (decomp - ref).abs().max()
atol = tol_table.get((test_dtype, op), 1e-7)
atol = 1e-10
if decomp_diff > orig_diff + atol:
raise RuntimeError(
f"Difference from float64 is larger with decomposition {op.__name__}"
f" than original. Original max diff: {orig_diff}, Decomp max diff: {decomp_diff}\n"
f"atol = {atol}\n"
f"args = {args}\n"
f"kwargs = {kwargs}"
)
@ -174,7 +166,7 @@ def op_assert_ref(test_case, op, test_dtype, orig, decomp, ref, args, kwargs):
)
def op_assert_equal(test_case, op, test_dtype, orig, decomp, args, kwargs):
def op_assert_equal(test_case, op, orig, decomp, args, kwargs):
test_case.assertEqual(
orig.dtype, decomp.dtype, f"Operation: {op}, orig.dtype: {orig.dtype}, decomp.dtype: {decomp.dtype}, {args}, {kwargs}")
# Before adding an entry to this table, make sure your decomposition is right :)
@ -186,7 +178,7 @@ def op_assert_equal(test_case, op, test_dtype, orig, decomp, args, kwargs):
1e-3,
),
}
if (test_dtype, op) in tol_table:
if (decomp.dtype, op) in tol_table:
rtol, atol = tol_table[(decomp.dtype, op)]
else:
rtol, atol = _getDefaultRtolAndAtol(orig.dtype, decomp.dtype)
@ -281,6 +273,12 @@ CROSS_REF_EXCLUDE_SET = {
("cuda", torch.float64, "nn.functional.dropout"),
("cuda", torch.float32, "nn.functional.dropout"),
# decomp has problem even with opmath
("cuda", torch.bfloat16, "nn.functional.layer_norm"),
("cuda", torch.float16, "nn.functional.layer_norm"),
("cuda", torch.bfloat16, "nn.functional.batch_norm"),
("cuda", torch.float16, "nn.functional.batch_norm"),
("cuda", torch.bfloat16, "nn.functional.instance_norm"),
("cuda", torch.float16, "nn.functional.instance_norm"),
# doesn't work
("cuda", torch.bfloat16, "nn.functional.embedding"),
@ -427,13 +425,13 @@ class TestDecomp(TestCase):
if orig is None:
assert decomp is None
continue
op_assert_ref(self, func, test_dtype, orig, decomp, ref, args, kwargs)
op_assert_ref(self, func, orig, decomp, ref, args, kwargs)
else:
for orig, decomp in zip(real_out, decomp_out):
if orig is None:
assert decomp is None
continue
op_assert_equal(self, func, test_dtype, orig, decomp, args, kwargs)
op_assert_equal(self, func, orig, decomp, args, kwargs)
return real_out_unflat