Fixed type promotion semantics for native_batch_norm and native_layer_norm (#77407)

Originally, when these decompositions were written, they simply used the naive strategy of "upcast all inputs to float, do the math, and downcast the outputs back". In addition to not being quite what the kernels actually do, that also missed some additional semantics: namely, that the norms (except for layer norm on CPU! cc: @ngimel) return fp32 for the mean and rstd values even when the inputs are in a lower-precision dtype.
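
To make that concrete, here's a small illustrative snippet (not part of the PR, and it assumes a CUDA device) showing the dtype contract the decomposition now has to match:

```python
import torch

# Illustrative only: shows the kernel behavior described above on CUDA,
# where a reduced-precision input still yields fp32 mean/rstd.
if torch.cuda.is_available():
    x = torch.randn(8, 4, device="cuda", dtype=torch.float16)
    w = torch.randn(4, device="cuda", dtype=torch.float16)
    b = torch.randn(4, device="cuda", dtype=torch.float16)
    out, mean, rstd = torch.ops.aten.native_layer_norm(x, [4], w, b, 1e-5)
    # out keeps the input dtype; mean/rstd should come back as torch.float32
    print(out.dtype, mean.dtype, rstd.dtype)
```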

Also, folks didn't like that I wrote `native_layer_norm` in terms of `native_batch_norm`, which is fair, so I refactored the common logic into a `normalize` function.
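
For a rough idea of the shape of that refactor, here's a hedged sketch (not the actual decomposition in this PR; `_normalize` and `layer_norm_like` are made-up names for illustration):

```python
import torch
from typing import List, Tuple

def _normalize(x: torch.Tensor, dims: List[int], eps: float) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    # Shared core: do the math in fp32 ("opmath"), keep mean/rstd in fp32,
    # and let each norm's wrapper cast the normalized output back to x.dtype.
    x_acc = x.to(torch.float32)
    mean = x_acc.mean(dims, keepdim=True)
    var = ((x_acc - mean) ** 2).mean(dims, keepdim=True)
    rstd = torch.rsqrt(var + eps)
    return (x_acc - mean) * rstd, mean, rstd

def layer_norm_like(x, normalized_shape, weight=None, bias=None, eps=1e-5):
    # Layer norm reduces over the trailing `normalized_shape` dims.
    dims = list(range(x.dim() - len(normalized_shape), x.dim()))
    out, mean, rstd = _normalize(x, dims, eps)
    if weight is not None:
        out = out * weight
    if bias is not None:
        out = out + bias
    # Output goes back to the input dtype; mean/rstd stay fp32
    # (shape/stride details of the real op are glossed over here).
    return out.to(x.dtype), mean, rstd
```

Batch norm would go through the same helper, just reducing over the batch/spatial dims and handling running stats on top.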

cc: @jansel / @bertmaher, who've been looking at lowering layer norm/batch norm.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/77407
Approved by: https://github.com/bertmaher
Horace He
2022-05-19 17:11:46 +00:00
committed by PyTorch MergeBot
parent 00a187c373
commit 70d80fb424
4 changed files with 98 additions and 94 deletions


@@ -144,19 +144,27 @@ def _getDefaultRtolAndAtol(dtype0, dtype1):
    return rtol, atol
def op_assert_ref(test_case, op, orig, decomp, ref, args, kwargs):
def op_assert_ref(test_case, op, test_dtype, orig, decomp, ref, args, kwargs):
    assert orig.dtype == decomp.dtype, f"Operation: {op}"
    if orig.numel() == 0 or decomp.numel() == 0:
        assert orig.numel() == decomp.numel()
        return
    assert orig.shape == decomp.shape, f"Operation: {op}"
    tol_table = {
        (torch.bfloat16, torch.ops.aten.native_layer_norm.default): 1e-5,
        (torch.float16, torch.ops.aten.native_layer_norm.default): 1e-5,
        (torch.bfloat16, torch.ops.aten.native_batch_norm.default): 1e-5,
        (torch.float16, torch.ops.aten.native_batch_norm.default): 1e-5,
    }
    if ref.is_floating_point():
        orig_diff = (orig - ref).abs().max()
        decomp_diff = (decomp - ref).abs().max()
        atol = 1e-10
        atol = tol_table.get((test_dtype, op), 1e-7)
        if decomp_diff > orig_diff + atol:
            raise RuntimeError(
                f"Difference from float64 is larger with decomposition {op.__name__}"
                f" than original. Original max diff: {orig_diff}, Decomp max diff: {decomp_diff}\n"
                f"atol = {atol}\n"
                f"args = {args}\n"
                f"kwargs = {kwargs}"
            )
@@ -166,7 +174,7 @@ def op_assert_ref(test_case, op, orig, decomp, ref, args, kwargs):
            )
def op_assert_equal(test_case, op, orig, decomp, args, kwargs):
def op_assert_equal(test_case, op, test_dtype, orig, decomp, args, kwargs):
    test_case.assertEqual(
        orig.dtype, decomp.dtype, f"Operation: {op}, orig.dtype: {orig.dtype}, decomp.dtype: {decomp.dtype}, {args}, {kwargs}")
    # Before adding an entry to this table, make sure your decomposition is right :)
@@ -178,7 +186,7 @@ def op_assert_equal(test_case, op, orig, decomp, args, kwargs):
            1e-3,
        ),
    }
    if (decomp.dtype, op) in tol_table:
    if (test_dtype, op) in tol_table:
        rtol, atol = tol_table[(test_dtype, op)]
    else:
        rtol, atol = _getDefaultRtolAndAtol(orig.dtype, decomp.dtype)
@@ -273,12 +281,6 @@ CROSS_REF_EXCLUDE_SET = {
    ("cuda", torch.float64, "nn.functional.dropout"),
    ("cuda", torch.float32, "nn.functional.dropout"),
    # decomp has problem even with opmath
    ("cuda", torch.bfloat16, "nn.functional.layer_norm"),
    ("cuda", torch.float16, "nn.functional.layer_norm"),
    ("cuda", torch.bfloat16, "nn.functional.batch_norm"),
    ("cuda", torch.float16, "nn.functional.batch_norm"),
    ("cuda", torch.bfloat16, "nn.functional.instance_norm"),
    ("cuda", torch.float16, "nn.functional.instance_norm"),
    # doesn't work
    ("cuda", torch.bfloat16, "nn.functional.embedding"),
@@ -425,13 +427,13 @@ class TestDecomp(TestCase):
                        if orig is None:
                            assert decomp is None
                            continue
                        op_assert_ref(self, func, orig, decomp, ref, args, kwargs)
                        op_assert_ref(self, func, test_dtype, orig, decomp, ref, args, kwargs)
                else:
                    for orig, decomp in zip(real_out, decomp_out):
                        if orig is None:
                            assert decomp is None
                            continue
                        op_assert_equal(self, func, orig, decomp, args, kwargs)
                        op_assert_equal(self, func, test_dtype, orig, decomp, args, kwargs)
                return real_out_unflat
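
One note on the test change above: the tolerance tables are now keyed on the dtype the test ran under (`test_dtype`) rather than the output's dtype, since with these semantics the mean/rstd outputs can be fp32 even for an fp16/bf16 test. Roughly (simplified sketch of the lookup; `atol_for` is a made-up helper):

```python
import torch

tol_table = {
    (torch.bfloat16, torch.ops.aten.native_layer_norm.default): 1e-5,
    (torch.float16, torch.ops.aten.native_layer_norm.default): 1e-5,
}

def atol_for(test_dtype, op):
    # Fall back to the default atol when there is no per-(dtype, op) override.
    return tol_table.get((test_dtype, op), 1e-7)

print(atol_for(torch.float16, torch.ops.aten.native_layer_norm.default))  # 1e-05
print(atol_for(torch.float32, torch.ops.aten.native_layer_norm.default))  # 1e-07
```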