mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
[dynamo] Use polyfill to implement comparison operators (#144485)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/144485 Approved by: https://github.com/jansel
This commit is contained in:
committed by
PyTorch MergeBot
parent
1090e58687
commit
e2e265e27b
@ -74,7 +74,7 @@ detectron2_fasterrcnn_r_50_fpn,fail_accuracy,46
|
||||
|
||||
|
||||
|
||||
detectron2_fcos_r_50_fpn,pass,24
|
||||
detectron2_fcos_r_50_fpn,pass,22
|
||||
|
||||
|
||||
|
||||
|
|
@ -74,7 +74,7 @@ detectron2_fasterrcnn_r_50_fpn,pass,46
|
||||
|
||||
|
||||
|
||||
detectron2_fcos_r_50_fpn,pass,24
|
||||
detectron2_fcos_r_50_fpn,pass,22
|
||||
|
||||
|
||||
|
||||
|
|
@ -74,7 +74,7 @@ detectron2_fasterrcnn_r_50_fpn,pass,46
|
||||
|
||||
|
||||
|
||||
detectron2_fcos_r_50_fpn,pass,24
|
||||
detectron2_fcos_r_50_fpn,pass,22
|
||||
|
||||
|
||||
|
||||
|
|
@ -74,7 +74,7 @@ detectron2_fasterrcnn_r_50_fpn,pass,46
|
||||
|
||||
|
||||
|
||||
detectron2_fcos_r_50_fpn,pass,24
|
||||
detectron2_fcos_r_50_fpn,pass,22
|
||||
|
||||
|
||||
|
||||
|
|
@ -82,7 +82,7 @@ detectron2_fasterrcnn_r_50_fpn,eager_fail_to_run,0
|
||||
|
||||
|
||||
|
||||
detectron2_fcos_r_50_fpn,pass,22
|
||||
detectron2_fcos_r_50_fpn,pass,20
|
||||
|
||||
|
||||
|
||||
|
|
@ -11861,6 +11861,33 @@ fn
|
||||
_, ne = run(torch.ones(1))
|
||||
self.assertFalse(ne)
|
||||
|
||||
def test_ne_operator_with_custom_ne(self):
|
||||
class Foo:
|
||||
def __init__(self, x):
|
||||
self.x = x
|
||||
self.ne_called = False
|
||||
|
||||
def __ne__(self, other):
|
||||
# ne_called attr is later checked to ensure that overrideen
|
||||
# `__ne__` is traced
|
||||
self.ne_called = True
|
||||
return not self.__eq__(other)
|
||||
|
||||
def __eq__(self, other):
|
||||
return self.x == other.x
|
||||
|
||||
f1 = Foo(0)
|
||||
f2 = Foo(0)
|
||||
|
||||
@torch.compile(fullgraph=True, backend="eager")
|
||||
def run(x):
|
||||
# `x + 1` prevents Dynamo from skipping this frame.
|
||||
return x + 1, f1 != f2
|
||||
|
||||
_, ne = run(torch.ones(1))
|
||||
self.assertFalse(ne)
|
||||
self.assertTrue(f1.ne_called)
|
||||
|
||||
def test_ne_operator_with_custom_graphbreak_eq(self):
|
||||
counters.clear()
|
||||
|
||||
|
@ -6019,6 +6019,19 @@ def forward(self, s0 : torch.SymInt, s1 : torch.SymInt, L_x_ : torch.Tensor):
|
||||
self.assertEqual(fn(config, x), opt_fn(config, x))
|
||||
self.assertEqual(cloned_config.baz, 4)
|
||||
|
||||
@unittest.skipIf(not HAS_OMEGACONG, "missing omegaconf package")
|
||||
def test_omegaconf_listconfig_contains(self):
|
||||
def fn(cfg, x):
|
||||
if 1 in cfg:
|
||||
return torch.sin(x)
|
||||
return torch.cos(x)
|
||||
|
||||
config = OmegaConf.create([1, 2, 3, {"key": "value"}])
|
||||
|
||||
x = torch.randn(4)
|
||||
opt_fn = torch.compile(fn, backend="eager", fullgraph=True)
|
||||
self.assertEqual(fn(config, x), opt_fn(config, x))
|
||||
|
||||
# https://github.com/pytorch/pytorch/issues/136257
|
||||
def test_overwriting_params(self):
|
||||
class M(torch.nn.Module):
|
||||
|
@ -8,12 +8,15 @@ Python polyfills for common builtins.
|
||||
|
||||
# mypy: allow-untyped-defs
|
||||
|
||||
import types
|
||||
from collections.abc import MutableMapping, Sequence
|
||||
from itertools import repeat as _repeat
|
||||
from typing import Any, Callable, List, TYPE_CHECKING
|
||||
|
||||
import torch
|
||||
|
||||
from ..utils import dict_keys
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
# Load by torch._dynamo.polyfills.loader
|
||||
@ -219,14 +222,52 @@ def predicate(obj: Any) -> bool:
|
||||
return False
|
||||
|
||||
|
||||
def object_eq(self, other):
|
||||
# Mirrors CPython implementation:
|
||||
# https://github.com/python/cpython/blob/a1c52d1265c65bcf0d9edf87e143843ad54f9b8f/Objects/typeobject.c#L6228-L6233
|
||||
return self is other
|
||||
def cmp_eq(a, b):
|
||||
# Note that the commented `is` check should ideally be removed. This is a
|
||||
# CPython optimization that skips the __eq__ checks it the obj id's are
|
||||
# same. But, these lines adds many `is` nodes in the Fx graph for
|
||||
# SymNodeVariable. For now, we can just skip this check. This is STILL
|
||||
# correct because one of the __eq__ checks will pass later, just could be
|
||||
# slow in some corner cases.
|
||||
# if a is b:
|
||||
# return True
|
||||
result = a.__eq__(b)
|
||||
if result is NotImplemented:
|
||||
result = b.__eq__(a)
|
||||
return result is not NotImplemented and result
|
||||
|
||||
|
||||
def object_ne(self, other):
|
||||
# Mirrors CPython implementation:
|
||||
# https://github.com/python/cpython/blob/a1c52d1265c65bcf0d9edf87e143843ad54f9b8f/Objects/typeobject.c#L6235-L6255
|
||||
# Using `==` is important because `self` might have a user-defined `__eq__`.
|
||||
return not (self == other)
|
||||
def cmp_ne(a, b):
|
||||
# Check if __ne__ is overridden
|
||||
if isinstance(type(a).__ne__, types.FunctionType):
|
||||
return a.__ne__(b)
|
||||
return not cmp_eq(a, b)
|
||||
|
||||
|
||||
def cmp_lt(a, b):
|
||||
result = a.__lt__(b)
|
||||
if result is NotImplemented:
|
||||
raise TypeError(f"{type(a)} does not support the < operator")
|
||||
return result
|
||||
|
||||
|
||||
def cmp_le(a, b):
|
||||
# Check if __le__ is overridden
|
||||
if isinstance(type(a).__le__, types.FunctionType):
|
||||
return a.__le__(b)
|
||||
return cmp_eq(a, b) or cmp_lt(a, b)
|
||||
|
||||
|
||||
def cmp_gt(a, b):
|
||||
# Check if __gt__ is overridden
|
||||
if isinstance(type(a).__gt__, types.FunctionType):
|
||||
return a.__gt__(b)
|
||||
# a > b is equivalent to b < a
|
||||
return cmp_lt(b, a)
|
||||
|
||||
|
||||
def cmp_ge(a, b):
|
||||
# Check if __ge__ is overridden
|
||||
if isinstance(type(a).__ge__, types.FunctionType):
|
||||
return a.__ge__(b)
|
||||
return cmp_eq(a, b) or cmp_gt(a, b)
|
||||
|
@ -1008,6 +1008,16 @@ def is_function(value):
|
||||
)
|
||||
|
||||
|
||||
cmp_name_to_op_mapping = {
|
||||
"__eq__": operator.eq,
|
||||
"__ne__": operator.ne,
|
||||
"__lt__": operator.lt,
|
||||
"__le__": operator.le,
|
||||
"__gt__": operator.gt,
|
||||
"__ge__": operator.ge,
|
||||
}
|
||||
|
||||
|
||||
def is_wrapper_or_member_descriptor(value):
|
||||
return isinstance(
|
||||
value,
|
||||
|
@ -9,7 +9,7 @@ from ..current_scope_id import current_scope_id
|
||||
from ..exc import unimplemented
|
||||
from ..guards import GuardBuilder, install_guard
|
||||
from ..source import AttrSource, Source
|
||||
from ..utils import istype
|
||||
from ..utils import cmp_name_to_op_mapping, istype
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
@ -412,6 +412,29 @@ class VariableTracker(metaclass=VariableTrackerMeta):
|
||||
and not kwargs
|
||||
):
|
||||
return self.var_getattr(tx, args[0].as_python_constant())
|
||||
elif (
|
||||
name in cmp_name_to_op_mapping
|
||||
and len(args) == 1
|
||||
and self.is_python_constant()
|
||||
and not tx.output.side_effects.has_pending_mutation(self)
|
||||
and not kwargs
|
||||
):
|
||||
# NB : Checking for mutation is necessary because we compare
|
||||
# constant values
|
||||
other = args[0]
|
||||
if not isinstance(self, type(other)):
|
||||
return variables.ConstantVariable.create(NotImplemented)
|
||||
if (
|
||||
not other.is_python_constant()
|
||||
or tx.output.side_effects.has_pending_mutation(other)
|
||||
):
|
||||
unimplemented(f"call_method {self} {name} {args} {kwargs}")
|
||||
|
||||
return variables.ConstantVariable.create(
|
||||
cmp_name_to_op_mapping[name](
|
||||
self.as_python_constant(), other.as_python_constant()
|
||||
)
|
||||
)
|
||||
unimplemented(f"call_method {self} {name} {args} {kwargs}")
|
||||
|
||||
def set_name_hint(self, name):
|
||||
|
@ -40,6 +40,7 @@ from ..utils import (
|
||||
check_numpy_ndarray_args,
|
||||
check_unspec_or_constant_args,
|
||||
check_unspec_python_args,
|
||||
cmp_name_to_op_mapping,
|
||||
dict_methods,
|
||||
extract_fake_example_value,
|
||||
get_fake_value,
|
||||
@ -107,6 +108,14 @@ _HandlerCallback = Callable[
|
||||
["InstructionTranslator", typing.Any, typing.Any], VariableTracker
|
||||
]
|
||||
_TrackersType = Union[type[VariableTracker], tuple[type[VariableTracker], ...]]
|
||||
polyfill_fn_mapping = {
|
||||
operator.eq: polyfills.cmp_eq,
|
||||
operator.ne: polyfills.cmp_ne,
|
||||
operator.lt: polyfills.cmp_lt,
|
||||
operator.le: polyfills.cmp_le,
|
||||
operator.gt: polyfills.cmp_gt,
|
||||
operator.ge: polyfills.cmp_ge,
|
||||
}
|
||||
|
||||
|
||||
class BuiltinVariable(VariableTracker):
|
||||
@ -288,7 +297,6 @@ class BuiltinVariable(VariableTracker):
|
||||
# combinations. Handlers are attempted in order, and will be used if the type checks
|
||||
# match. They are expected to have the signature:
|
||||
# fn(tx, arg0: VariableTracker, arg1: VariableTracker) -> VariableTracker
|
||||
from .dicts import DictKeysVariable, SetVariable
|
||||
from .functions import BaseUserFunctionVariable, UserFunctionVariable
|
||||
from .nn_module import NNModuleVariable
|
||||
from .tensor import supported_const_comparison_ops
|
||||
@ -511,9 +519,6 @@ class BuiltinVariable(VariableTracker):
|
||||
]
|
||||
op_handlers[operator.mul].extend(list_like_expansion_handlers)
|
||||
|
||||
size_or_tuple = (SizeVariable, TupleVariable)
|
||||
has_set_items = (SetVariable, DictKeysVariable)
|
||||
|
||||
def create_cmp_op_handlers(op):
|
||||
def compare_by_value(tx: "InstructionTranslator", a, b):
|
||||
return ConstantVariable(op(a.value, b.value))
|
||||
@ -528,29 +533,48 @@ class BuiltinVariable(VariableTracker):
|
||||
]
|
||||
] = [((ConstantVariable, ConstantVariable), compare_by_value)]
|
||||
|
||||
if op in supported_const_comparison_ops.values():
|
||||
if op in polyfill_fn_mapping:
|
||||
# For constants, speedup the comparison instead of using
|
||||
# polyfill. Removing this line causes major regression for pr
|
||||
# time benchmark - add_loop_eager.
|
||||
result = [((ConstantVariable, ConstantVariable), compare_by_value)]
|
||||
|
||||
op_var = BuiltinVariable(op)
|
||||
# Special handling of SymNode variable
|
||||
result.extend(
|
||||
[
|
||||
(
|
||||
(SymNodeVariable, VariableTracker),
|
||||
op_var._comparison_with_symnode,
|
||||
),
|
||||
(
|
||||
(VariableTracker, SymNodeVariable),
|
||||
op_var._comparison_with_symnode,
|
||||
),
|
||||
]
|
||||
)
|
||||
|
||||
def handler(tx, a, b):
|
||||
return tx.inline_user_function_return(
|
||||
VariableTracker.build(tx, polyfill_fn_mapping[op]), [a, b], {}
|
||||
)
|
||||
|
||||
result.append(((VariableTracker, VariableTracker), handler))
|
||||
return result
|
||||
|
||||
result = [((ConstantVariable, ConstantVariable), compare_by_value)]
|
||||
|
||||
if op in supported_const_comparison_ops.values() and op.__name__.startswith(
|
||||
"is_"
|
||||
):
|
||||
# Tensor is None, List is not None, etc
|
||||
none_result = op(object(), None)
|
||||
if op.__name__.startswith("is_"):
|
||||
|
||||
def never(tx: "InstructionTranslator", a, b):
|
||||
return ConstantVariable(none_result)
|
||||
def never(tx: "InstructionTranslator", a, b):
|
||||
return ConstantVariable(none_result)
|
||||
|
||||
obj_op_none = never
|
||||
none_op_obj = never
|
||||
else:
|
||||
|
||||
def obj_op_none(
|
||||
tx: "InstructionTranslator", a, b: ConstantVariable
|
||||
):
|
||||
if b.value is None or b.value is True or b.value is False:
|
||||
return ConstantVariable(none_result)
|
||||
|
||||
def none_op_obj(
|
||||
tx: "InstructionTranslator", a: ConstantVariable, b
|
||||
):
|
||||
if a.value is None or a.value is True or a.value is False:
|
||||
return ConstantVariable(none_result)
|
||||
obj_op_none = never
|
||||
none_op_obj = never
|
||||
|
||||
types_that_are_never_none = (
|
||||
TensorVariable,
|
||||
@ -575,96 +599,61 @@ class BuiltinVariable(VariableTracker):
|
||||
]
|
||||
)
|
||||
|
||||
def list_compare_nocheck(tx: "InstructionTranslator", left, right):
|
||||
return BaseListVariable.list_compare(tx, op, left, right)
|
||||
|
||||
def list_compare_check(tx: "InstructionTranslator", left, right):
|
||||
if type(left) is not type(
|
||||
right
|
||||
): # Mismatch in BaseListVariable subclasses
|
||||
unimplemented(f"{op.__name__}({left}, {right})")
|
||||
return BaseListVariable.list_compare(tx, op, left, right)
|
||||
|
||||
def compare_set_items(tx: "InstructionTranslator", left, right):
|
||||
return ConstantVariable(op(left.set_items, right.set_items))
|
||||
|
||||
def compare_via_method(tx: "InstructionTranslator", left, right):
|
||||
return left.call_method(tx, f"__{op.__name__}__", [right], {})
|
||||
|
||||
compare_user_defined: Callable[..., object]
|
||||
if op.__name__.startswith("is_"):
|
||||
compare_user_defined = compare_by_value
|
||||
else:
|
||||
compare_user_defined = compare_via_method
|
||||
|
||||
op_var = BuiltinVariable(op)
|
||||
result.extend(
|
||||
[
|
||||
(
|
||||
op_var = BuiltinVariable(op)
|
||||
result.extend(
|
||||
[
|
||||
(
|
||||
(UserFunctionVariable, BuiltinVariable),
|
||||
(UserFunctionVariable, BuiltinVariable),
|
||||
(
|
||||
(UserFunctionVariable, BuiltinVariable),
|
||||
(UserFunctionVariable, BuiltinVariable),
|
||||
),
|
||||
lambda tx, a, b: ConstantVariable(op(a.fn, b.fn)),
|
||||
),
|
||||
lambda tx, a, b: ConstantVariable(
|
||||
op(
|
||||
a.fn,
|
||||
b.fn,
|
||||
)
|
||||
),
|
||||
),
|
||||
(
|
||||
(
|
||||
NNModuleVariable,
|
||||
NNModuleVariable,
|
||||
(
|
||||
NNModuleVariable,
|
||||
NNModuleVariable,
|
||||
),
|
||||
lambda tx, a, b: ConstantVariable(
|
||||
op(
|
||||
tx.output.get_submodule(a.module_key),
|
||||
tx.output.get_submodule(b.module_key),
|
||||
)
|
||||
),
|
||||
),
|
||||
lambda tx, a, b: ConstantVariable(
|
||||
op(
|
||||
tx.output.get_submodule(a.module_key),
|
||||
tx.output.get_submodule(b.module_key),
|
||||
)
|
||||
),
|
||||
),
|
||||
((size_or_tuple, size_or_tuple), list_compare_nocheck),
|
||||
(
|
||||
(variables.BaseListVariable, variables.BaseListVariable),
|
||||
list_compare_check,
|
||||
),
|
||||
((has_set_items, has_set_items), compare_set_items),
|
||||
(
|
||||
(UserDefinedObjectVariable, UserDefinedObjectVariable),
|
||||
compare_user_defined,
|
||||
),
|
||||
(
|
||||
(UserDefinedClassVariable, UserDefinedClassVariable),
|
||||
compare_user_defined,
|
||||
),
|
||||
(
|
||||
(
|
||||
(StreamVariable, EventVariable, ConstantVariable),
|
||||
(StreamVariable, EventVariable, ConstantVariable),
|
||||
(UserDefinedObjectVariable, UserDefinedObjectVariable),
|
||||
compare_by_value,
|
||||
),
|
||||
compare_by_value,
|
||||
),
|
||||
(
|
||||
(TensorVariable, VariableTracker),
|
||||
op_var._comparison_with_tensor,
|
||||
),
|
||||
(
|
||||
(VariableTracker, TensorVariable),
|
||||
op_var._comparison_with_tensor,
|
||||
),
|
||||
(
|
||||
(SymNodeVariable, VariableTracker),
|
||||
op_var._comparison_with_symnode,
|
||||
),
|
||||
(
|
||||
(VariableTracker, SymNodeVariable),
|
||||
op_var._comparison_with_symnode,
|
||||
),
|
||||
]
|
||||
)
|
||||
|
||||
if op.__name__.startswith("is_"):
|
||||
(
|
||||
(UserDefinedClassVariable, UserDefinedClassVariable),
|
||||
compare_by_value,
|
||||
),
|
||||
(
|
||||
(
|
||||
(StreamVariable, EventVariable, ConstantVariable),
|
||||
(StreamVariable, EventVariable, ConstantVariable),
|
||||
),
|
||||
compare_by_value,
|
||||
),
|
||||
(
|
||||
(TensorVariable, VariableTracker),
|
||||
op_var._comparison_with_tensor,
|
||||
),
|
||||
(
|
||||
(VariableTracker, TensorVariable),
|
||||
op_var._comparison_with_tensor,
|
||||
),
|
||||
(
|
||||
(SymNodeVariable, VariableTracker),
|
||||
op_var._comparison_with_symnode,
|
||||
),
|
||||
(
|
||||
(VariableTracker, SymNodeVariable),
|
||||
op_var._comparison_with_symnode,
|
||||
),
|
||||
]
|
||||
)
|
||||
|
||||
def handle_is(tx: "InstructionTranslator", left, right):
|
||||
# If the two objects are of different type, we can safely return False
|
||||
@ -1789,6 +1778,8 @@ class BuiltinVariable(VariableTracker):
|
||||
member, (torch._ops.OpOverloadPacket, torch._ops.OpOverload)
|
||||
) and torch._dynamo.trace_rules.is_aten_op_or_tensor_method(member):
|
||||
return variables.TorchInGraphFunctionVariable(member, source=source)
|
||||
elif name in cmp_name_to_op_mapping:
|
||||
return variables.GetAttrVariable(obj, name, source=source)
|
||||
elif isinstance(obj, DummyModule):
|
||||
# TODO(mlazos) - Do we need this?
|
||||
if obj.is_torch or name not in obj.value.__dict__:
|
||||
|
@ -8,8 +8,8 @@ from torch._dynamo.source import AttrSource, GetItemSource
|
||||
|
||||
from .. import variables
|
||||
from ..exc import raise_observed_exception, unimplemented
|
||||
from ..utils import common_constant_types, istype, np
|
||||
from .base import typestr, VariableTracker
|
||||
from ..utils import cmp_name_to_op_mapping, common_constant_types, istype, np
|
||||
from .base import VariableTracker
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
@ -192,8 +192,7 @@ its type to `common_constant_types`.
|
||||
search = args[0].as_python_constant()
|
||||
result = search in self.value
|
||||
return ConstantVariable.create(result)
|
||||
|
||||
unimplemented(f"const method call {typestr(self.value)}.{name}")
|
||||
return super().call_method(tx, name, args, kwargs)
|
||||
|
||||
def call_obj_hasattr(
|
||||
self, tx: "InstructionTranslator", name: str
|
||||
@ -229,6 +228,8 @@ class EnumVariable(VariableTracker):
|
||||
def var_getattr(self, tx: "InstructionTranslator", name):
|
||||
if not hasattr(self.value, name):
|
||||
raise NotImplementedError
|
||||
if name in cmp_name_to_op_mapping:
|
||||
return variables.GetAttrVariable(self, name)
|
||||
member = getattr(self.value, name)
|
||||
source = self.source and AttrSource(self.source, name)
|
||||
return VariableTracker.build(tx, member, source=source)
|
||||
|
@ -1171,6 +1171,9 @@ class StreamVariable(VariableTracker):
|
||||
self.value = value
|
||||
self.device = device
|
||||
|
||||
def python_type(self):
|
||||
return torch.Stream
|
||||
|
||||
def call_method(
|
||||
self,
|
||||
tx,
|
||||
@ -1179,15 +1182,8 @@ class StreamVariable(VariableTracker):
|
||||
kwargs: "dict[str, VariableTracker]",
|
||||
) -> "VariableTracker":
|
||||
assert hasattr(self.value, name), f"no stream method found named {name}"
|
||||
assert name in [
|
||||
"wait_stream",
|
||||
"synchronize",
|
||||
"query",
|
||||
"record_event",
|
||||
"wait_event",
|
||||
], f" unsupported stream method {name}"
|
||||
|
||||
from ..utils import proxy_args_kwargs
|
||||
from ..utils import cmp_name_to_op_mapping, proxy_args_kwargs
|
||||
from .builder import wrap_fx_proxy_cls
|
||||
|
||||
if name in ("wait_stream", "synchronize", "wait_event"):
|
||||
@ -1211,8 +1207,17 @@ class StreamVariable(VariableTracker):
|
||||
"call_method", name, *proxy_args_kwargs([self] + args, kwargs)
|
||||
),
|
||||
)
|
||||
else:
|
||||
unimplemented(self.device + " stream method " + name + " unsupported")
|
||||
elif name in cmp_name_to_op_mapping and len(args) == 1 and not kwargs:
|
||||
# NB : Checking for mutation is necessary because we compare
|
||||
# constant values
|
||||
other = args[0]
|
||||
if not isinstance(other, StreamVariable):
|
||||
return variables.ConstantVariable.create(NotImplemented)
|
||||
return variables.ConstantVariable.create(
|
||||
cmp_name_to_op_mapping[name](self.value, other.value)
|
||||
)
|
||||
|
||||
return super().call_method(tx, name, args, kwargs)
|
||||
|
||||
def as_proxy(self):
|
||||
return self.proxy
|
||||
|
@ -12,7 +12,7 @@ from ..bytecode_transformation import create_call_function, create_instruction
|
||||
from ..exc import raise_observed_exception, unimplemented
|
||||
from ..guards import GuardBuilder, install_guard
|
||||
from ..source import is_from_local_source
|
||||
from ..utils import dict_keys, dict_values, specialize_symnode
|
||||
from ..utils import cmp_name_to_op_mapping, dict_keys, dict_values, specialize_symnode
|
||||
from .base import ValueMutationNew, VariableTracker
|
||||
from .constant import ConstantVariable
|
||||
|
||||
@ -796,7 +796,9 @@ class DictKeySetVariable(SetVariable):
|
||||
return dict_keys
|
||||
|
||||
def as_python_constant(self):
|
||||
unimplemented("DictKeySetVariable.as_python_constant")
|
||||
return dict.fromkeys(
|
||||
{k.vt.as_python_constant() for k in self.set_items}, None
|
||||
).keys()
|
||||
|
||||
def call_method(
|
||||
self,
|
||||
@ -882,6 +884,12 @@ class DictKeysVariable(DictViewVariable):
|
||||
) -> "VariableTracker":
|
||||
if name == "__contains__":
|
||||
return self.dv_dict.call_method(tx, name, args, kwargs)
|
||||
if name in cmp_name_to_op_mapping:
|
||||
if not isinstance(args[0], (SetVariable, DictKeysVariable)):
|
||||
return ConstantVariable.create(NotImplemented)
|
||||
return ConstantVariable.create(
|
||||
cmp_name_to_op_mapping[name](self.set_items, args[0].set_items)
|
||||
)
|
||||
return super().call_method(tx, name, args, kwargs)
|
||||
|
||||
|
||||
|
@ -20,6 +20,7 @@ from ..source import AttrSource, ConstantSource, DefaultsSource, GetItemSource
|
||||
from ..utils import (
|
||||
check_constant_args,
|
||||
check_unspec_or_constant_args,
|
||||
cmp_name_to_op_mapping,
|
||||
counters,
|
||||
identity,
|
||||
is_function,
|
||||
@ -309,6 +310,8 @@ class UserFunctionVariable(BaseUserFunctionVariable):
|
||||
return result
|
||||
|
||||
def var_getattr(self, tx: "InstructionTranslator", name: str):
|
||||
if name in cmp_name_to_op_mapping:
|
||||
return variables.GetAttrVariable(self, name)
|
||||
return fn_var_getattr(tx, self.fn, self.source, name)
|
||||
|
||||
def call_obj_hasattr(
|
||||
@ -793,6 +796,9 @@ class SkipFunctionVariable(VariableTracker):
|
||||
return variables.ConstantVariable.create(hasattr(self.value, name))
|
||||
|
||||
def var_getattr(self, tx: "InstructionTranslator", name: str):
|
||||
if name in cmp_name_to_op_mapping:
|
||||
return variables.GetAttrVariable(self, name)
|
||||
|
||||
return fn_var_getattr(tx, self.value, self.source, name)
|
||||
|
||||
|
||||
@ -949,6 +955,9 @@ class FunctoolsWrapsVariable(UserFunctionVariable):
|
||||
|
||||
|
||||
class CollectionsNamedTupleFunction(UserFunctionVariable):
|
||||
def as_python_constant(self):
|
||||
return self.fn
|
||||
|
||||
def call_function(
|
||||
self,
|
||||
tx: "InstructionTranslator",
|
||||
|
@ -15,6 +15,7 @@ from ..bytecode_transformation import create_call_function, create_instruction
|
||||
from ..exc import raise_observed_exception, unimplemented
|
||||
from ..source import AttrSource
|
||||
from ..utils import (
|
||||
cmp_name_to_op_mapping,
|
||||
get_fake_value,
|
||||
guard_if_dyn,
|
||||
istype,
|
||||
@ -136,6 +137,18 @@ class BaseListVariable(VariableTracker):
|
||||
[self] + list(args),
|
||||
kwargs,
|
||||
)
|
||||
elif name in cmp_name_to_op_mapping:
|
||||
left = self
|
||||
right = args[0]
|
||||
if not isinstance(left, BaseListVariable) and not isinstance(
|
||||
right, BaseListVariable
|
||||
):
|
||||
return variables.ConstantVariable.create(NotImplemented)
|
||||
return variables.UserFunctionVariable(polyfills.list_cmp).call_function(
|
||||
tx,
|
||||
[variables.BuiltinVariable(cmp_name_to_op_mapping[name]), left, right],
|
||||
{},
|
||||
)
|
||||
|
||||
return super().call_method(tx, name, args, kwargs)
|
||||
|
||||
|
@ -23,6 +23,7 @@ from ..mutation_guard import unpatched_nn_module_init
|
||||
from ..source import AttrSource, GetItemSource, TypeSource, WeakRefCallSource
|
||||
from ..utils import (
|
||||
check_unspec_or_constant_args,
|
||||
cmp_name_to_op_mapping,
|
||||
identity,
|
||||
is_tensor_base_attr_getter,
|
||||
proxy_args_kwargs,
|
||||
@ -952,6 +953,9 @@ class TypingVariable(VariableTracker):
|
||||
def var_getattr(self, tx: "InstructionTranslator", name: str):
|
||||
from .builder import SourcelessBuilder, VariableBuilder
|
||||
|
||||
if name in cmp_name_to_op_mapping:
|
||||
return variables.GetAttrVariable(self, name)
|
||||
|
||||
if tx.output.side_effects.has_pending_mutation_of_attr(self, name):
|
||||
return tx.side_effects.load_attr(self, name)
|
||||
|
||||
@ -960,7 +964,7 @@ class TypingVariable(VariableTracker):
|
||||
attr_source = AttrSource(self.source, name)
|
||||
return VariableBuilder(tx, attr_source)(value)
|
||||
else:
|
||||
return SourcelessBuilder(tx, value)
|
||||
return SourcelessBuilder.create(tx, value)
|
||||
|
||||
def as_python_constant(self):
|
||||
return self.value
|
||||
|
@ -41,6 +41,7 @@ from ..utils import (
|
||||
build_checkpoint_variable,
|
||||
build_invoke_subgraph_variable,
|
||||
check_constant_args,
|
||||
cmp_name_to_op_mapping,
|
||||
dict_methods,
|
||||
get_custom_getattr,
|
||||
has_torch_function,
|
||||
@ -185,6 +186,9 @@ class UserDefinedClassVariable(UserDefinedVariable):
|
||||
except AttributeError:
|
||||
obj = None
|
||||
|
||||
if name in cmp_name_to_op_mapping and not isinstance(obj, types.FunctionType):
|
||||
return variables.GetAttrVariable(self, name, source=source)
|
||||
|
||||
if isinstance(obj, staticmethod):
|
||||
return VariableTracker.build(tx, obj.__get__(self.value), source)
|
||||
elif isinstance(obj, classmethod):
|
||||
@ -795,14 +799,14 @@ class UserDefinedObjectVariable(UserDefinedVariable):
|
||||
if is_standard_setattr(method) or isinstance(self.value, threading.local):
|
||||
return self.method_setattr_standard(tx, *args, **kwargs)
|
||||
|
||||
if len(args) == 1 and not kwargs:
|
||||
if method is object.__eq__:
|
||||
func_var = VariableTracker.build(tx, polyfills.object_eq)
|
||||
return func_var.call_function(tx, [self, *args], kwargs)
|
||||
if method is object.__eq__ and len(args) == 1 and not kwargs:
|
||||
other = args[0]
|
||||
if not isinstance(other, UserDefinedObjectVariable):
|
||||
return variables.ConstantVariable.create(NotImplemented)
|
||||
|
||||
if method is object.__ne__:
|
||||
func_var = VariableTracker.build(tx, polyfills.object_ne)
|
||||
return func_var.call_function(tx, [self, *args], kwargs)
|
||||
# TODO(anijain2305) - Identity checking should already be a part
|
||||
# of the cmp_eq polyfill function.
|
||||
return ConstantVariable.create(self.value is other.value)
|
||||
|
||||
# check for methods implemented in C++
|
||||
if isinstance(method, types.FunctionType):
|
||||
|
Reference in New Issue
Block a user