mirror of
https://github.com/pytorch/pytorch.git
synced 2025-11-01 22:14:53 +08:00
[dynamo] Fix list comparison ops (#122559)
Fixes #122376 Pull Request resolved: https://github.com/pytorch/pytorch/pull/122559 Approved by: https://github.com/Skylion007
This commit is contained in:
committed by
PyTorch MergeBot
parent
5891c5b3a6
commit
069270db60
@ -558,6 +558,29 @@ class FunctionTests(torch._dynamo.test_case.TestCase):
|
||||
dtype = torch.get_autocast_gpu_dtype()
|
||||
return x.type(dtype)
|
||||
|
||||
@make_test
|
||||
def test_list_compare_polyfill(x):
|
||||
for a, b, c in [
|
||||
[(1, 2, 3), (1, 2, 3), 7.77],
|
||||
[(1, 4, 3), (1, 2, 3), 3.33],
|
||||
[(1, 2), (1, 2, 3), 5.55],
|
||||
[(1, 2, 3), (1, 2), 11.11],
|
||||
[(1, -1, 3), (1, 2, 3), 13.33],
|
||||
]:
|
||||
if a != b:
|
||||
x += 1 * c
|
||||
if a == b:
|
||||
x += 2 * c
|
||||
if a < b:
|
||||
x += 4 * c
|
||||
if a > b:
|
||||
x += 8 * c
|
||||
if a <= b:
|
||||
x += 16 * c
|
||||
if a >= b:
|
||||
x += 32 * c
|
||||
return x
|
||||
|
||||
@make_test
|
||||
def test_promote_types(x):
|
||||
if x.dtype == torch.promote_types(torch.int32, torch.float32):
|
||||
|
||||
Reference in New Issue
Block a user