OpInfo: Sample input cleanup (4/n) (#86324)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/86324
Approved by: https://github.com/mruberry
This commit is contained in:
Peter Bell
2022-10-19 17:00:52 +01:00
committed by PyTorch MergeBot
parent c141f28b64
commit 6eeeb88172
5 changed files with 64 additions and 56 deletions

View File

@ -200,6 +200,7 @@ def op_assert_equal(test_case, op, test_dtype, orig, decomp, args, kwargs):
# Exceeds tolerances on CUDA, likely due to fma
(torch.float32, torch.ops.aten.mv.default) : (1e-5, 3e-5),
(torch.float64, torch.ops.aten.upsample_bicubic2d.vec) : (1e-5, 1e-6),
(torch.complex64, torch.ops.aten.mv.default): (5e-5, 5e-5),
}
if (test_dtype, op) in tol_table:
rtol, atol = tol_table[(decomp.dtype, op)]
@ -294,8 +295,11 @@ CROSS_REF_EXCLUDE_SET = {
}
CROSS_REF_BACKWARD_EXCLUDE_SET = {
# Backward formula is not as precise as the custom CUDA kernel
# Decomposed backward formula is not as precise
("cuda", torch.float16, "nn.functional.embedding"),
("cuda", torch.bfloat16, "nn.functional.embedding"),
("cpu", torch.bfloat16, "nn.functional.hardswish"),
("cuda", torch.float16, "nn.functional.cross_entropy"),
}
all_decomposed = set()