Mirror of https://github.com/pytorch/pytorch.git, synced 2025-10-20 21:14:14 +08:00
Revert "Fixed type promotion semantics for native_batch_norm and native_layer_norm (#77407)"
This reverts commit 70d80fb42480f4df7bd369f8f9f1500b58c5c603.
Reverted https://github.com/pytorch/pytorch/pull/77407 on behalf of https://github.com/malfet because it broke meta tests (likely due to a land race); see 70d80fb424
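For context, a minimal standalone sketch, not part of this commit and with shapes chosen arbitrarily, of why the reverted change had introduced looser per-dtype tolerances: running a norm op in bfloat16 drifts measurably from a float64 reference, and the decomposition-vs-reference check below compares exactly this kind of maximum absolute difference.

# Standalone sketch, not part of the commit: measure how far a low-precision
# layer_norm drifts from a float64 reference. Assumes bfloat16 layer_norm is
# available on the current build/device.
import torch

x = torch.randn(8, 128, dtype=torch.float64)
ref = torch.nn.functional.layer_norm(x, (128,))                      # float64 reference
low = torch.nn.functional.layer_norm(x.to(torch.bfloat16), (128,))   # bfloat16 compute

# Max absolute error relative to the float64 reference; the reverted PR had
# budgeted roughly 1e-5 of extra atol for cases like this.
print((low.to(torch.float64) - ref).abs().max().item())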
@@ -144,27 +144,19 @@ def _getDefaultRtolAndAtol(dtype0, dtype1):
     return rtol, atol


-def op_assert_ref(test_case, op, test_dtype, orig, decomp, ref, args, kwargs):
+def op_assert_ref(test_case, op, orig, decomp, ref, args, kwargs):
     assert orig.dtype == decomp.dtype, f"Operation: {op}"
     if orig.numel() == 0 or decomp.numel() == 0:
         assert orig.numel() == decomp.numel()
         return
     assert orig.shape == decomp.shape, f"Operation: {op}"
-    tol_table = {
-        (torch.bfloat16, torch.ops.aten.native_layer_norm.default): 1e-5,
-        (torch.float16, torch.ops.aten.native_layer_norm.default): 1e-5,
-        (torch.bfloat16, torch.ops.aten.native_batch_norm.default): 1e-5,
-        (torch.float16, torch.ops.aten.native_batch_norm.default): 1e-5,
-    }
     if ref.is_floating_point():
         orig_diff = (orig - ref).abs().max()
         decomp_diff = (decomp - ref).abs().max()
-        atol = tol_table.get((test_dtype, op), 1e-7)
+        atol = 1e-10
         if decomp_diff > orig_diff + atol:
             raise RuntimeError(
                 f"Difference from float64 is larger with decomposition {op.__name__}"
                 f" than original. Original max diff: {orig_diff}, Decomp max diff: {decomp_diff}\n"
                 f"atol = {atol}\n"
                 f"args = {args}\n"
                 f"kwargs = {kwargs}"
             )
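A self-contained sketch (illustrative ops, not taken from the test file) of the comparison op_assert_ref performs after this revert: the decomposition's maximum deviation from a float64 reference may exceed the original kernel's deviation by at most a fixed atol of 1e-10, otherwise the test raises.

# Illustrative sketch of the op_assert_ref accuracy check; softmax and a
# hand-written decomposition stand in for an ATen op and its registered decomp.
import torch

x = torch.randn(4, 8)
ref = torch.softmax(x.double(), dim=-1)                  # float64 reference
orig = torch.softmax(x, dim=-1)                          # fused float32 kernel
d = (x - x.amax(dim=-1, keepdim=True)).exp()
decomp = d / d.sum(dim=-1, keepdim=True)                 # "decomposed" float32 path

orig_diff = (orig.double() - ref).abs().max()
decomp_diff = (decomp.double() - ref).abs().max()
atol = 1e-10  # fixed budget restored by this revert
# op_assert_ref raises when the decomposition is worse than the original by more than atol:
print("would fail:", bool(decomp_diff > orig_diff + atol))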
@@ -174,7 +166,7 @@ def op_assert_ref(test_case, op, test_dtype, orig, decomp, ref, args, kwargs):
         )


-def op_assert_equal(test_case, op, test_dtype, orig, decomp, args, kwargs):
+def op_assert_equal(test_case, op, orig, decomp, args, kwargs):
     test_case.assertEqual(
         orig.dtype, decomp.dtype, f"Operation: {op}, orig.dtype: {orig.dtype}, decomp.dtype: {decomp.dtype}, {args}, {kwargs}")
     # Before adding an entry to this table, make sure your decomposition is right :)
@@ -186,7 +178,7 @@ def op_assert_equal(test_case, op, test_dtype, orig, decomp, args, kwargs):
             1e-3,
         ),
     }
-    if (test_dtype, op) in tol_table:
+    if (decomp.dtype, op) in tol_table:
         rtol, atol = tol_table[(decomp.dtype, op)]
     else:
         rtol, atol = _getDefaultRtolAndAtol(orig.dtype, decomp.dtype)
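A hedged sketch of the tolerance selection restored above: op_assert_equal keys its tol_table on the output dtype (decomp.dtype) rather than on a separate test_dtype argument, and otherwise falls back to defaults derived from the dtype pair. The table entry and fallback values here are illustrative, not the test's real numbers.

# Illustrative sketch of dtype-keyed tolerance lookup; the entry and fallback
# values below are made up, only the lookup pattern mirrors the test.
import torch

tol_table = {
    (torch.float32, "aten.native_layer_norm.default"): (1e-3, 1e-3),  # (rtol, atol)
}

def pick_tolerances(orig: torch.Tensor, decomp: torch.Tensor, op_name: str):
    if (decomp.dtype, op_name) in tol_table:
        return tol_table[(decomp.dtype, op_name)]
    # Stand-in for _getDefaultRtolAndAtol(orig.dtype, decomp.dtype).
    return (1.3e-6, 1e-5)

print(pick_tolerances(torch.zeros(1), torch.zeros(1), "aten.native_layer_norm.default"))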
@@ -281,6 +273,12 @@ CROSS_REF_EXCLUDE_SET = {
     ("cuda", torch.float64, "nn.functional.dropout"),
     ("cuda", torch.float32, "nn.functional.dropout"),
     # decomp has problem even with opmath
+    ("cuda", torch.bfloat16, "nn.functional.layer_norm"),
+    ("cuda", torch.float16, "nn.functional.layer_norm"),
+    ("cuda", torch.bfloat16, "nn.functional.batch_norm"),
+    ("cuda", torch.float16, "nn.functional.batch_norm"),
+    ("cuda", torch.bfloat16, "nn.functional.instance_norm"),
+    ("cuda", torch.float16, "nn.functional.instance_norm"),
     # doesn't work
     ("cuda", torch.bfloat16, "nn.functional.embedding"),
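A hedged sketch of how a (device, dtype, op name) exclusion set like the one above is typically consulted; the guard function and entries below are illustrative, not the actual test plumbing.

# Illustrative sketch: run the decomposition cross-ref check only for
# combinations that are not excluded. Entries mirror two restored above.
import torch

CROSS_REF_EXCLUDE_SET = {
    ("cuda", torch.bfloat16, "nn.functional.layer_norm"),
    ("cuda", torch.float16, "nn.functional.batch_norm"),
}

def should_cross_ref(device: str, dtype: torch.dtype, op_name: str) -> bool:
    # Membership test against the exclusion tuples.
    return (device, dtype, op_name) not in CROSS_REF_EXCLUDE_SET

print(should_cross_ref("cuda", torch.bfloat16, "nn.functional.layer_norm"))  # False
print(should_cross_ref("cpu", torch.float32, "nn.functional.layer_norm"))    # True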
@@ -427,13 +425,13 @@ class TestDecomp(TestCase):
                    if orig is None:
                        assert decomp is None
                        continue
-                    op_assert_ref(self, func, test_dtype, orig, decomp, ref, args, kwargs)
+                    op_assert_ref(self, func, orig, decomp, ref, args, kwargs)
            else:
                for orig, decomp in zip(real_out, decomp_out):
                    if orig is None:
                        assert decomp is None
                        continue
-                    op_assert_equal(self, func, test_dtype, orig, decomp, args, kwargs)
+                    op_assert_equal(self, func, orig, decomp, args, kwargs)

            return real_out_unflat