mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
[dynamo] Properly account for non-list instances in list comparison (#148470)
As title; this patch also removes an unused `list_compare` method. Fixes #148179. Pull Request resolved: https://github.com/pytorch/pytorch/pull/148470 Approved by: https://github.com/anijain2305
This commit is contained in:
committed by
PyTorch MergeBot
parent
a7fe685be8
commit
c8cd8f68bd
@ -928,19 +928,38 @@ class FunctionTests(torch._dynamo.test_case.TestCase):
|
||||
[(1, -1, 3), (1, 2, 3), 13.33],
|
||||
]:
|
||||
if a != b:
|
||||
x += 1 * c
|
||||
x = x + 1 * c
|
||||
if a == b:
|
||||
x += 2 * c
|
||||
x = x + 2 * c
|
||||
if a < b:
|
||||
x += 4 * c
|
||||
x = x + 4 * c
|
||||
if a > b:
|
||||
x += 8 * c
|
||||
x = x + 8 * c
|
||||
if a <= b:
|
||||
x += 16 * c
|
||||
x = x + 16 * c
|
||||
if a >= b:
|
||||
x += 32 * c
|
||||
x = x + 32 * c
|
||||
return x
|
||||
|
||||
@make_test
|
||||
def test_list_compare_polyfill_non_lists(x):
|
||||
conds = []
|
||||
|
||||
# Non-list instances only work for eq and ne
|
||||
for a, b, c in [
|
||||
[(1, 2, 3), "(1, 2, 3)", 7.77],
|
||||
[143, (143,), 3.33],
|
||||
]:
|
||||
conds.append(a != b)
|
||||
if conds[-1]:
|
||||
x = x + 1 * c
|
||||
|
||||
conds.append(a == b)
|
||||
if conds[-1]:
|
||||
x = x + 2 * c
|
||||
|
||||
return x, conds
|
||||
|
||||
@make_test
|
||||
def test_promote_types(x):
|
||||
if x.dtype == torch.promote_types(torch.int32, torch.float32):
|
||||
|
@ -84,11 +84,16 @@ def accumulate_grad(x, new_grad):
|
||||
x.grad.add_(new_grad)
|
||||
|
||||
|
||||
# This mirrors
|
||||
# https://github.com/python/cpython/blob/a1c52d1265c65bcf0d9edf87e143843ad54f9b8f/Objects/listobject.c#L3352-L3413
|
||||
def list_cmp(op: Callable[[Any, Any], bool], left: Sequence[Any], right: Sequence[Any]):
|
||||
"""emulate `(1,2,3) > (1,2)` etc"""
|
||||
# Apply `op` to the first pair that differ
|
||||
for a, b in zip(left, right):
|
||||
if a != b:
|
||||
return op(a, b)
|
||||
|
||||
# No more pairs to compare, so compare sizes.
|
||||
return op(len(left), len(right))
|
||||
|
||||
|
||||
|
@ -1034,6 +1034,16 @@ cmp_name_to_op_mapping = {
|
||||
}
|
||||
|
||||
|
||||
cmp_name_to_op_str_mapping = {
|
||||
"__eq__": "==",
|
||||
"__ne__": "!=",
|
||||
"__lt__": "<",
|
||||
"__le__": "<=",
|
||||
"__gt__": ">",
|
||||
"__ge__": ">=",
|
||||
}
|
||||
|
||||
|
||||
def is_wrapper_or_member_descriptor(value):
|
||||
return isinstance(
|
||||
value,
|
||||
|
@ -30,6 +30,7 @@ from ..exc import raise_observed_exception, unimplemented, unimplemented_v2
|
||||
from ..source import AttrSource
|
||||
from ..utils import (
|
||||
cmp_name_to_op_mapping,
|
||||
cmp_name_to_op_str_mapping,
|
||||
get_fake_value,
|
||||
guard_if_dyn,
|
||||
iter_contains,
|
||||
@ -153,10 +154,28 @@ class BaseListVariable(VariableTracker):
|
||||
elif name in cmp_name_to_op_mapping:
|
||||
left = self
|
||||
right = args[0]
|
||||
if not isinstance(left, BaseListVariable) and not isinstance(
|
||||
# TODO this type check logic mirrors the following
|
||||
# https://github.com/python/cpython/blob/a1c52d1265c65bcf0d9edf87e143843ad54f9b8f/Objects/object.c#L991-L1007
|
||||
# But we should probably move it up the stack to so that we don't
|
||||
# need to duplicate it for different VTs.
|
||||
if not isinstance(left, BaseListVariable) or not isinstance(
|
||||
right, BaseListVariable
|
||||
):
|
||||
return variables.ConstantVariable.create(NotImplemented)
|
||||
if name == "__eq__":
|
||||
return variables.BuiltinVariable(operator.is_).call_function(
|
||||
tx, (left, right), {}
|
||||
)
|
||||
elif name == "__ne__":
|
||||
return variables.BuiltinVariable(operator.is_not).call_function(
|
||||
tx, (left, right), {}
|
||||
)
|
||||
else:
|
||||
op_str = cmp_name_to_op_str_mapping[name]
|
||||
left_ty = left.python_type_name()
|
||||
right_ty = right.python_type_name()
|
||||
msg = f"{op_str} not supported between instances of '{left_ty}' and '{right_ty}'"
|
||||
raise_observed_exception(TypeError, tx, args=[msg])
|
||||
|
||||
return variables.UserFunctionVariable(polyfills.list_cmp).call_function(
|
||||
tx,
|
||||
[variables.BuiltinVariable(cmp_name_to_op_mapping[name]), left, right],
|
||||
@ -165,12 +184,6 @@ class BaseListVariable(VariableTracker):
|
||||
|
||||
return super().call_method(tx, name, args, kwargs)
|
||||
|
||||
@staticmethod
|
||||
def list_compare(tx: "InstructionTranslator", op, left, right):
|
||||
return variables.UserFunctionVariable(polyfills.list_cmp).call_function(
|
||||
tx, [variables.BuiltinVariable(op), left, right], {}
|
||||
)
|
||||
|
||||
|
||||
class RangeVariable(BaseListVariable):
|
||||
def __init__(self, items, **kwargs) -> None:
|
||||
|
Reference in New Issue
Block a user