mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
test_decomp.py: Skip tests for embedding_backward bf16 (#84554)
`embedding_backward`'s decomposition is less accurate for bf16. Currently bfloat16 is skipped in both forward and backward, but the forward decomposition matches 1-1 with the ATen implementation so this re-enables the test for the forwards decomposition. Pull Request resolved: https://github.com/pytorch/pytorch/pull/84554 Approved by: https://github.com/albanD
This commit is contained in:
committed by
PyTorch MergeBot
parent
c2e9b9ec4a
commit
29c78266c0
@ -283,9 +283,6 @@ CROSS_REF_EXCLUDE_SET = {
|
||||
(None, None, "new_empty"),
|
||||
(None, None, "empty_like"),
|
||||
(None, None, "empty"),
|
||||
# decomp has problem even with opmath
|
||||
# doesn't work
|
||||
("cuda", torch.bfloat16, "nn.functional.embedding"),
|
||||
|
||||
# CompositeAutogradImplicit
|
||||
# See https://github.com/pytorch/pytorch/issues/81669
|
||||
@ -293,6 +290,11 @@ CROSS_REF_EXCLUDE_SET = {
|
||||
(None, None, "meshgrid"),
|
||||
}
|
||||
|
||||
CROSS_REF_BACKWARD_EXCLUDE_SET = {
|
||||
# Backward formula is not as precise as the custom CUDA kernel
|
||||
("cuda", torch.bfloat16, "nn.functional.embedding"),
|
||||
}
|
||||
|
||||
all_decomposed = set()
|
||||
all_called = defaultdict(int)
|
||||
|
||||
@ -367,13 +369,15 @@ class TestDecomp(TestCase):
|
||||
|
||||
@skipIfTorchDynamo("Test does not work with TorchDynamo")
|
||||
def do_cross_ref(self, device, dtype, op, *, run_all):
|
||||
if (torch.device(device).type, dtype, op.name) in CROSS_REF_EXCLUDE_SET or (
|
||||
None,
|
||||
dtype,
|
||||
op.name,
|
||||
) in CROSS_REF_EXCLUDE_SET or (None, None, op.name) in CROSS_REF_EXCLUDE_SET:
|
||||
test_keys = [
|
||||
(torch.device(device).type, dtype, op.name),
|
||||
(None, dtype, op.name),
|
||||
(None, None, op.name),
|
||||
]
|
||||
if any(key in CROSS_REF_EXCLUDE_SET for key in test_keys):
|
||||
self.skipTest(f"{op.name} in {dtype} not supported")
|
||||
|
||||
skip_decomp_vjp = any(key in CROSS_REF_BACKWARD_EXCLUDE_SET for key in test_keys)
|
||||
test_dtype = dtype
|
||||
|
||||
# We check the correctness of each decomposition right after running it.
|
||||
@ -491,7 +495,7 @@ class TestDecomp(TestCase):
|
||||
if aten_name in decomposition_names:
|
||||
check_decomposed(aten_name)
|
||||
|
||||
if op.aten_backward_name in decomposition_names or run_all:
|
||||
if not skip_decomp_vjp and (op.aten_backward_name in decomposition_names or run_all):
|
||||
cotangents = tree_map(lambda x: torch.randn_like(x), decomp_out)
|
||||
|
||||
decomposed.clear()
|
||||
|
Reference in New Issue
Block a user