test_decomp.py: Skip tests for embedding_backward bf16 (#84554)

`embedding_backward`'s decomposition is less accurate than the ATen
implementation for bf16. Currently bfloat16 is skipped in both the forward
and backward tests, but the forward decomposition matches the ATen
implementation 1:1, so this re-enables the test for the forward
decomposition.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/84554
Approved by: https://github.com/albanD
Peter Bell
2022-09-27 23:54:50 +01:00
committed by PyTorch MergeBot
parent c2e9b9ec4a
commit 29c78266c0
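
For context, a minimal sketch of the kind of zero-tolerance comparison the re-enabled forward test amounts to. The decomposition_table lookup and the exact tolerances are illustrative assumptions, not the test's literal code; it assumes aten.embedding.default has a registered decomposition in torch._decomp:

import torch
from torch._decomp import decomposition_table

# Look up the Python decomposition registered for aten.embedding
# (assumes the default overload is present in the table).
embedding_decomp = decomposition_table[torch.ops.aten.embedding.default]

weight = torch.randn(10, 4, dtype=torch.bfloat16)
indices = torch.randint(0, 10, (3,))

ref = torch.ops.aten.embedding.default(weight, indices)
out = embedding_decomp(weight, indices)

# Per the commit message, the forward decomposition matches the ATen
# implementation 1:1, so an exact comparison is expected to pass even in
# bfloat16; only the backward formula is less precise than the custom
# CUDA kernel.
torch.testing.assert_close(out, ref, rtol=0, atol=0)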


@@ -283,9 +283,6 @@ CROSS_REF_EXCLUDE_SET = {
     (None, None, "new_empty"),
     (None, None, "empty_like"),
     (None, None, "empty"),
-    # decomp has problem even with opmath
-    # doesn't work
-    ("cuda", torch.bfloat16, "nn.functional.embedding"),
     # CompositeAutogradImplicit
     # See https://github.com/pytorch/pytorch/issues/81669
@@ -293,6 +290,11 @@ CROSS_REF_EXCLUDE_SET = {
     (None, None, "meshgrid"),
 }
 
+CROSS_REF_BACKWARD_EXCLUDE_SET = {
+    # Backward formula is not as precise as the custom CUDA kernel
+    ("cuda", torch.bfloat16, "nn.functional.embedding"),
+}
+
 all_decomposed = set()
 all_called = defaultdict(int)
@@ -367,13 +369,15 @@ class TestDecomp(TestCase):
     @skipIfTorchDynamo("Test does not work with TorchDynamo")
     def do_cross_ref(self, device, dtype, op, *, run_all):
-        if (torch.device(device).type, dtype, op.name) in CROSS_REF_EXCLUDE_SET or (
-            None,
-            dtype,
-            op.name,
-        ) in CROSS_REF_EXCLUDE_SET or (None, None, op.name) in CROSS_REF_EXCLUDE_SET:
+        test_keys = [
+            (torch.device(device).type, dtype, op.name),
+            (None, dtype, op.name),
+            (None, None, op.name),
+        ]
+        if any(key in CROSS_REF_EXCLUDE_SET for key in test_keys):
             self.skipTest(f"{op.name} in {dtype} not supported")
+        skip_decomp_vjp = any(key in CROSS_REF_BACKWARD_EXCLUDE_SET for key in test_keys)
         test_dtype = dtype
 
         # We check the correctness of each decomposition right after running it.
@@ -491,7 +495,7 @@
         if aten_name in decomposition_names:
             check_decomposed(aten_name)
 
-        if op.aten_backward_name in decomposition_names or run_all:
+        if not skip_decomp_vjp and (op.aten_backward_name in decomposition_names or run_all):
             cotangents = tree_map(lambda x: torch.randn_like(x), decomp_out)
             decomposed.clear()
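
The exclude sets match on (device_type, dtype, op_name) triples, with None acting as a wildcard for a field. A self-contained sketch of the lookup logic introduced above; the helper name should_skip_vjp is hypothetical, mirroring the test_keys check in do_cross_ref:

import torch

# Same shape as the sets in test_decomp.py: (device_type, dtype, op_name),
# where None in a field means "match any value".
CROSS_REF_BACKWARD_EXCLUDE_SET = {
    ("cuda", torch.bfloat16, "nn.functional.embedding"),
}

def should_skip_vjp(device_type, dtype, op_name):
    # Hypothetical helper: build the exact key plus the two wildcard keys,
    # and skip if any of them is excluded.
    test_keys = [
        (device_type, dtype, op_name),
        (None, dtype, op_name),
        (None, None, op_name),
    ]
    return any(key in CROSS_REF_BACKWARD_EXCLUDE_SET for key in test_keys)

# The bf16 CUDA backward cross-ref is skipped, but the forward test (and
# other device/dtype combinations) still run:
assert should_skip_vjp("cuda", torch.bfloat16, "nn.functional.embedding")
assert not should_skip_vjp("cpu", torch.bfloat16, "nn.functional.embedding")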