[dynamo] Fix list comparison ops (#122559)

Fixes #122376

Pull Request resolved: https://github.com/pytorch/pytorch/pull/122559
Approved by: https://github.com/Skylion007
This commit is contained in:
Jason Ansel
2024-03-24 10:41:43 -07:00
committed by PyTorch MergeBot
parent 5891c5b3a6
commit 069270db60
12 changed files with 51 additions and 54 deletions

View File

@ -558,6 +558,29 @@ class FunctionTests(torch._dynamo.test_case.TestCase):
dtype = torch.get_autocast_gpu_dtype()
return x.type(dtype)
@make_test
def test_list_compare_polyfill(x):
for a, b, c in [
[(1, 2, 3), (1, 2, 3), 7.77],
[(1, 4, 3), (1, 2, 3), 3.33],
[(1, 2), (1, 2, 3), 5.55],
[(1, 2, 3), (1, 2), 11.11],
[(1, -1, 3), (1, 2, 3), 13.33],
]:
if a != b:
x += 1 * c
if a == b:
x += 2 * c
if a < b:
x += 4 * c
if a > b:
x += 8 * c
if a <= b:
x += 16 * c
if a >= b:
x += 32 * c
return x
@make_test
def test_promote_types(x):
if x.dtype == torch.promote_types(torch.int32, torch.float32):