[dynamo] Properly account for non-list instances in list comparison (#148470)

As title; this patch also removes an unused `list_compare` method.

Fixes #148179.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/148470
Approved by: https://github.com/anijain2305
This commit is contained in:
Ryan Guo
2025-03-04 11:33:46 -08:00
committed by PyTorch MergeBot
parent a7fe685be8
commit c8cd8f68bd
4 changed files with 61 additions and 14 deletions

View File

@ -928,19 +928,38 @@ class FunctionTests(torch._dynamo.test_case.TestCase):
[(1, -1, 3), (1, 2, 3), 13.33],
]:
if a != b:
x += 1 * c
x = x + 1 * c
if a == b:
x += 2 * c
x = x + 2 * c
if a < b:
x += 4 * c
x = x + 4 * c
if a > b:
x += 8 * c
x = x + 8 * c
if a <= b:
x += 16 * c
x = x + 16 * c
if a >= b:
x += 32 * c
x = x + 32 * c
return x
@make_test
def test_list_compare_polyfill_non_lists(x):
conds = []
# Non-list instances only work for eq and ne
for a, b, c in [
[(1, 2, 3), "(1, 2, 3)", 7.77],
[143, (143,), 3.33],
]:
conds.append(a != b)
if conds[-1]:
x = x + 1 * c
conds.append(a == b)
if conds[-1]:
x = x + 2 * c
return x, conds
@make_test
def test_promote_types(x):
if x.dtype == torch.promote_types(torch.int32, torch.float32):

View File

@ -84,11 +84,16 @@ def accumulate_grad(x, new_grad):
x.grad.add_(new_grad)
# This mirrors
# https://github.com/python/cpython/blob/a1c52d1265c65bcf0d9edf87e143843ad54f9b8f/Objects/listobject.c#L3352-L3413
def list_cmp(op: Callable[[Any, Any], bool], left: Sequence[Any], right: Sequence[Any]):
"""emulate `(1,2,3) > (1,2)` etc"""
# Apply `op` to the first pair that differ
for a, b in zip(left, right):
if a != b:
return op(a, b)
# No more pairs to compare, so compare sizes.
return op(len(left), len(right))

View File

@ -1034,6 +1034,16 @@ cmp_name_to_op_mapping = {
}
cmp_name_to_op_str_mapping = {
"__eq__": "==",
"__ne__": "!=",
"__lt__": "<",
"__le__": "<=",
"__gt__": ">",
"__ge__": ">=",
}
def is_wrapper_or_member_descriptor(value):
return isinstance(
value,

View File

@ -30,6 +30,7 @@ from ..exc import raise_observed_exception, unimplemented, unimplemented_v2
from ..source import AttrSource
from ..utils import (
cmp_name_to_op_mapping,
cmp_name_to_op_str_mapping,
get_fake_value,
guard_if_dyn,
iter_contains,
@ -153,10 +154,28 @@ class BaseListVariable(VariableTracker):
elif name in cmp_name_to_op_mapping:
left = self
right = args[0]
if not isinstance(left, BaseListVariable) and not isinstance(
# TODO this type check logic mirrors the following
# https://github.com/python/cpython/blob/a1c52d1265c65bcf0d9edf87e143843ad54f9b8f/Objects/object.c#L991-L1007
# But we should probably move it up the stack to so that we don't
# need to duplicate it for different VTs.
if not isinstance(left, BaseListVariable) or not isinstance(
right, BaseListVariable
):
return variables.ConstantVariable.create(NotImplemented)
if name == "__eq__":
return variables.BuiltinVariable(operator.is_).call_function(
tx, (left, right), {}
)
elif name == "__ne__":
return variables.BuiltinVariable(operator.is_not).call_function(
tx, (left, right), {}
)
else:
op_str = cmp_name_to_op_str_mapping[name]
left_ty = left.python_type_name()
right_ty = right.python_type_name()
msg = f"{op_str} not supported between instances of '{left_ty}' and '{right_ty}'"
raise_observed_exception(TypeError, tx, args=[msg])
return variables.UserFunctionVariable(polyfills.list_cmp).call_function(
tx,
[variables.BuiltinVariable(cmp_name_to_op_mapping[name]), left, right],
@ -165,12 +184,6 @@ class BaseListVariable(VariableTracker):
return super().call_method(tx, name, args, kwargs)
@staticmethod
def list_compare(tx: "InstructionTranslator", op, left, right):
return variables.UserFunctionVariable(polyfills.list_cmp).call_function(
tx, [variables.BuiltinVariable(op), left, right], {}
)
class RangeVariable(BaseListVariable):
def __init__(self, items, **kwargs) -> None: