Fixed type promotion semantics for native_batch_norm and native_layer_norm (#77407)
Originally, when these were written, they simply used the naive strategy of "upcast all inputs to float, and downcast all outputs back". In addition to being not quite what the kernels actually do, that strategy also missed some additional semantics. Namely, the norms (except for layer norm on CPU! cc: @ngimel) return fp32 for the mean and rstd values, regardless of the input dtype.

Also, folks didn't like that I wrote `native_layer_norm` in terms of `native_batch_norm`. Which is fair, so I refactored the common logic into a `normalize` function. cc: @jansel / @bertmaher, who've been looking at lowering layer norm/batch norm.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/77407
Approved by: https://github.com/bertmaher
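To make the described semantics concrete, here is a minimal, hypothetical sketch of the pattern (the actual `normalize` helper added by this PR may differ; `_normalize_sketch` and `_layer_norm_sketch` are illustrative names, not PyTorch APIs): low-precision inputs are upcast to fp32 for the reduction, the normalized output is downcast back to the input dtype, and mean/rstd are returned in fp32.

import torch


def _normalize_sketch(x, dims, eps):
    # Upcast fp16/bf16 inputs to fp32 ("opmath") before reducing, so the
    # mean/variance are computed in full precision.
    acc = x.to(torch.float32) if x.dtype in (torch.float16, torch.bfloat16) else x
    mean = acc.mean(dim=dims, keepdim=True)
    var = acc.var(dim=dims, unbiased=False, keepdim=True)
    rstd = torch.rsqrt(var + eps)
    return (acc - mean) * rstd, mean, rstd


def _layer_norm_sketch(x, normalized_shape, weight=None, bias=None, eps=1e-5):
    # Normalize over the trailing `normalized_shape` dimensions.
    dims = tuple(range(x.dim() - len(normalized_shape), x.dim()))
    out, mean, rstd = _normalize_sketch(x, dims, eps)
    if weight is not None:
        out = out * weight
    if bias is not None:
        out = out + bias
    # The output goes back to the input dtype; mean and rstd stay fp32,
    # which is the semantic described above (layer norm on CPU being the
    # exception noted in the commit message).
    return out.to(x.dtype), mean, rstd

For example, `_layer_norm_sketch(torch.randn(2, 8, dtype=torch.bfloat16), (8,))` returns a bfloat16 output together with fp32 mean and rstd.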
commit 70d80fb424
parent 00a187c373
committed by PyTorch MergeBot
@@ -144,19 +144,27 @@ def _getDefaultRtolAndAtol(dtype0, dtype1):
     return rtol, atol
 
 
-def op_assert_ref(test_case, op, orig, decomp, ref, args, kwargs):
+def op_assert_ref(test_case, op, test_dtype, orig, decomp, ref, args, kwargs):
     assert orig.dtype == decomp.dtype, f"Operation: {op}"
     if orig.numel() == 0 or decomp.numel() == 0:
         assert orig.numel() == decomp.numel()
         return
     assert orig.shape == decomp.shape, f"Operation: {op}"
+    tol_table = {
+        (torch.bfloat16, torch.ops.aten.native_layer_norm.default): 1e-5,
+        (torch.float16, torch.ops.aten.native_layer_norm.default): 1e-5,
+        (torch.bfloat16, torch.ops.aten.native_batch_norm.default): 1e-5,
+        (torch.float16, torch.ops.aten.native_batch_norm.default): 1e-5,
+    }
     if ref.is_floating_point():
         orig_diff = (orig - ref).abs().max()
         decomp_diff = (decomp - ref).abs().max()
-        atol = 1e-10
+        atol = tol_table.get((test_dtype, op), 1e-7)
         if decomp_diff > orig_diff + atol:
             raise RuntimeError(
                 f"Difference from float64 is larger with decomposition {op.__name__}"
                 f" than original. Original max diff: {orig_diff}, Decomp max diff: {decomp_diff}\n"
+                f"atol = {atol}\n"
                 f"args = {args}\n"
                 f"kwargs = {kwargs}"
             )
@@ -166,7 +174,7 @@ def op_assert_ref(test_case, op, orig, decomp, ref, args, kwargs):
             )
 
 
-def op_assert_equal(test_case, op, orig, decomp, args, kwargs):
+def op_assert_equal(test_case, op, test_dtype, orig, decomp, args, kwargs):
     test_case.assertEqual(
         orig.dtype, decomp.dtype, f"Operation: {op}, orig.dtype: {orig.dtype}, decomp.dtype: {decomp.dtype}, {args}, {kwargs}")
     # Before adding an entry to this table, make sure your decomposition is right :)
@@ -178,7 +186,7 @@ def op_assert_equal(test_case, op, orig, decomp, args, kwargs):
             1e-3,
         ),
     }
-    if (decomp.dtype, op) in tol_table:
+    if (test_dtype, op) in tol_table:
         rtol, atol = tol_table[(decomp.dtype, op)]
     else:
         rtol, atol = _getDefaultRtolAndAtol(orig.dtype, decomp.dtype)
@@ -273,12 +281,6 @@ CROSS_REF_EXCLUDE_SET = {
     ("cuda", torch.float64, "nn.functional.dropout"),
     ("cuda", torch.float32, "nn.functional.dropout"),
     # decomp has problem even with opmath
-    ("cuda", torch.bfloat16, "nn.functional.layer_norm"),
-    ("cuda", torch.float16, "nn.functional.layer_norm"),
-    ("cuda", torch.bfloat16, "nn.functional.batch_norm"),
-    ("cuda", torch.float16, "nn.functional.batch_norm"),
-    ("cuda", torch.bfloat16, "nn.functional.instance_norm"),
-    ("cuda", torch.float16, "nn.functional.instance_norm"),
     # doesn't work
     ("cuda", torch.bfloat16, "nn.functional.embedding"),
 
@@ -425,13 +427,13 @@ class TestDecomp(TestCase):
                     if orig is None:
                         assert decomp is None
                         continue
-                    op_assert_ref(self, func, orig, decomp, ref, args, kwargs)
+                    op_assert_ref(self, func, test_dtype, orig, decomp, ref, args, kwargs)
             else:
                 for orig, decomp in zip(real_out, decomp_out):
                     if orig is None:
                         assert decomp is None
                         continue
-                    op_assert_equal(self, func, orig, decomp, args, kwargs)
+                    op_assert_equal(self, func, test_dtype, orig, decomp, args, kwargs)
 
             return real_out_unflat