Assorted decomposition fixes (#87183)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/87183
Approved by: https://github.com/ngimel
Author: lezcano
Date: 2023-01-16 13:34:02 +00:00
Committed by: PyTorch MergeBot
commit d162c8f92b (parent da58f9eb8f)
3 changed files with 23 additions and 22 deletions

@@ -294,15 +294,26 @@ CROSS_REF_EXCLUDE_SET = {
# (int)num_batches, CUDA_R_32F, CUBLAS_GEMM_DEFAULT_TENSOR_OP)`
("cuda", torch.bfloat16, "nn.functional.bilinear"),
# randomness
("cuda", torch.float16, "nn.functional.dropout"),
("cuda", torch.bfloat16, "nn.functional.dropout"),
("cuda", torch.float64, "nn.functional.dropout"),
("cuda", torch.float32, "nn.functional.dropout"),
(None, None, "special.ndtr"), # aten.special_ndtr was not decomposed
(None, None, "new_empty"),
(None, None, "empty_like"),
(None, None, "empty"),
# It's the only in-place op without an out-of-place equivalent in the Python API
# Its OpInfo wrongly registers it as `torch.zero_(x.clone())`.
(None, None, "zero_"),
# No idea what's going on here
# In the recursive test logsumexp.default fails with args = (torch.tensor(-math.inf), [])
# in the test, but it seems to pass when tested locally and in the logsumexp test
(None, torch.float32, "masked.logsumexp"),
(None, torch.float64, "masked.logsumexp"),
# exp_vml_cpu not implemented for Half
(torch.cpu, torch.float16, "signal.windows.exponential"),
(torch.cpu, torch.float16, "signal.windows.gaussian"),
# sin_vml_cpu not implemented for Half
(torch.cpu, torch.float16, "signal.windows.cosine"),
# CompositeAutogradImplicit
# See https://github.com/pytorch/pytorch/issues/81669
(None, None, "nn.functional.relu6"),
@@ -448,7 +459,11 @@ class TestDecomp(TestCase):
# non-deterministic ops
torch.ops.aten.empty.memory_format,
torch.ops.aten.empty_like.default,
torch.ops.aten.new_empty.default
torch.ops.aten.new_empty.default,
torch.ops.aten.empty_strided.default,
torch.ops.aten.new_empty_strided.default,
torch.ops.aten.randn.default,
torch.ops.aten.native_dropout.default,
] or any_unsupported(args, kwargs):
return func(*args, **kwargs)
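
For context on the early return above: ops that hand back uninitialized memory or consume fresh randomness cannot be cross-referenced value-for-value against their decompositions, since two runs need not agree. A standalone illustration, not part of the test:

import torch

# torch.empty returns uninitialized memory, so an eager run and a
# decomposed run have no reason to match element-wise.
print(torch.empty(3))

# dropout draws a fresh random mask per call; two calls on the same
# input diverge unless they consume identical RNG state.
x = torch.randn(1000)
y1 = torch.nn.functional.dropout(x, p=0.5)
y2 = torch.nn.functional.dropout(x, p=0.5)
print(torch.equal(y1, y2))  # almost surely False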