mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 12:54:11 +08:00
[dynamo] Refactor COMPARE_OP and comparison builtins (#122043)
This removes the duplicate handling of comparison ops between symbolic_convert and bultin and refactors the handling to use the binop infrastructure. This change regresses overheads a bit, but this is fixed in the next PR. New test skips are variants of `type(e) is np.ndarray` previously falling back to eager. Pull Request resolved: https://github.com/pytorch/pytorch/pull/122043 Approved by: https://github.com/anijain2305 ghstack dependencies: #122039
This commit is contained in:
committed by
PyTorch MergeBot
parent
769ff86b91
commit
07caea5c12
@ -86,7 +86,7 @@ detectron2_fasterrcnn_r_50_fpn,eager_fail_to_run,0
|
||||
|
||||
|
||||
|
||||
detectron2_fcos_r_50_fpn,pass,35
|
||||
detectron2_fcos_r_50_fpn,pass,94
|
||||
|
||||
|
||||
|
||||
|
|
@ -54,47 +54,47 @@ densenet121,pass,0
|
||||
|
||||
|
||||
|
||||
detectron2_fasterrcnn_r_101_c4,pass,51
|
||||
detectron2_fasterrcnn_r_101_c4,pass,164
|
||||
|
||||
|
||||
|
||||
detectron2_fasterrcnn_r_101_dc5,pass,51
|
||||
detectron2_fasterrcnn_r_101_dc5,pass,163
|
||||
|
||||
|
||||
|
||||
detectron2_fasterrcnn_r_101_fpn,pass,55
|
||||
detectron2_fasterrcnn_r_101_fpn,pass,172
|
||||
|
||||
|
||||
|
||||
detectron2_fasterrcnn_r_50_c4,pass,51
|
||||
detectron2_fasterrcnn_r_50_c4,pass,113
|
||||
|
||||
|
||||
|
||||
detectron2_fasterrcnn_r_50_dc5,pass,51
|
||||
detectron2_fasterrcnn_r_50_dc5,pass,112
|
||||
|
||||
|
||||
|
||||
detectron2_fasterrcnn_r_50_fpn,pass,55
|
||||
detectron2_fasterrcnn_r_50_fpn,pass,121
|
||||
|
||||
|
||||
|
||||
detectron2_fcos_r_50_fpn,pass,38
|
||||
detectron2_fcos_r_50_fpn,pass,97
|
||||
|
||||
|
||||
|
||||
detectron2_maskrcnn_r_101_c4,fail_accuracy,66
|
||||
detectron2_maskrcnn_r_101_c4,pass,182
|
||||
|
||||
|
||||
|
||||
detectron2_maskrcnn_r_101_fpn,pass,73
|
||||
detectron2_maskrcnn_r_101_fpn,pass,192
|
||||
|
||||
|
||||
|
||||
detectron2_maskrcnn_r_50_c4,pass,66
|
||||
detectron2_maskrcnn_r_50_c4,pass,131
|
||||
|
||||
|
||||
|
||||
detectron2_maskrcnn_r_50_fpn,pass,73
|
||||
detectron2_maskrcnn_r_50_fpn,pass,141
|
||||
|
||||
|
||||
|
||||
|
|
@ -86,7 +86,7 @@ detectron2_fasterrcnn_r_50_fpn,eager_fail_to_run,0
|
||||
|
||||
|
||||
|
||||
detectron2_fcos_r_50_fpn,pass,35
|
||||
detectron2_fcos_r_50_fpn,pass,94
|
||||
|
||||
|
||||
|
||||
|
|
@ -54,7 +54,7 @@ densenet121,pass,0
|
||||
|
||||
|
||||
|
||||
detectron2_fcos_r_50_fpn,pass,38
|
||||
detectron2_fcos_r_50_fpn,pass,97
|
||||
|
||||
|
||||
|
||||
|
|
@ -86,7 +86,7 @@ detectron2_fasterrcnn_r_50_fpn,eager_fail_to_run,0
|
||||
|
||||
|
||||
|
||||
detectron2_fcos_r_50_fpn,pass,36
|
||||
detectron2_fcos_r_50_fpn,pass,95
|
||||
|
||||
|
||||
|
||||
|
|
@ -86,7 +86,7 @@ detectron2_fasterrcnn_r_50_fpn,eager_fail_to_run,0
|
||||
|
||||
|
||||
|
||||
detectron2_fcos_r_50_fpn,pass,35
|
||||
detectron2_fcos_r_50_fpn,pass,94
|
||||
|
||||
|
||||
|
||||
|
|
@ -86,7 +86,7 @@ detectron2_fasterrcnn_r_50_fpn,eager_fail_to_run,0
|
||||
|
||||
|
||||
|
||||
detectron2_fcos_r_50_fpn,pass,36
|
||||
detectron2_fcos_r_50_fpn,pass,95
|
||||
|
||||
|
||||
|
||||
|
|
@ -1459,6 +1459,26 @@ class FunctionTests(torch._dynamo.test_case.TestCase):
|
||||
par_mul = functools.partial(udf_mul, torch.ones(10, 10))
|
||||
return par_mul(x)
|
||||
|
||||
@make_test
|
||||
def test_list_add_then_mutate(x):
|
||||
my_list = [1, x]
|
||||
y = x / 4.0
|
||||
my_list = my_list + [x / 2.0, 4]
|
||||
my_list.append(y)
|
||||
return sum(my_list)
|
||||
|
||||
@make_test
|
||||
def test_list_expand_lhs(x):
|
||||
return sum(4 * [x])
|
||||
|
||||
@make_test
|
||||
def test_in_not_in(x):
|
||||
mylist = [1, 2, 3, 4, 5, x]
|
||||
myotherlist = [1, 2, 3, 4, 5]
|
||||
assert 3 in mylist
|
||||
assert 6 not in myotherlist
|
||||
return sum(mylist)
|
||||
|
||||
@make_test
|
||||
def test_partials_udf_kwarg(x):
|
||||
par_mul = functools.partial(udf_mul, y=torch.ones(10, 10))
|
||||
|
0
test/dynamo_skips/TestSqueeze.test_squeeze_type
Normal file
0
test/dynamo_skips/TestSqueeze.test_squeeze_type
Normal file
@ -99,17 +99,11 @@ from .variables.misc import (
|
||||
UnknownVariable,
|
||||
)
|
||||
from .variables.nn_module import NNModuleVariable
|
||||
from .variables.tensor import (
|
||||
supported_comparison_ops,
|
||||
supported_const_comparison_ops,
|
||||
SymNodeVariable,
|
||||
TensorVariable,
|
||||
)
|
||||
from .variables.tensor import supported_comparison_ops, SymNodeVariable, TensorVariable
|
||||
from .variables.user_defined import (
|
||||
RemovableHandleVariable,
|
||||
UserDefinedClassVariable,
|
||||
UserDefinedObjectVariable,
|
||||
UserDefinedVariable,
|
||||
)
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
@ -117,6 +111,17 @@ graph_break_log = torch._logging.getArtifactLogger(__name__, "graph_breaks")
|
||||
trace_call_log = torch._logging.getArtifactLogger(__name__, "trace_call")
|
||||
trace_source_log = torch._logging.getArtifactLogger(__name__, "trace_source")
|
||||
tls = threading.local()
|
||||
compare_op_handlers: Dict[str, Any] = {
|
||||
k: BuiltinVariable(v).call_function for k, v in supported_comparison_ops.items()
|
||||
}
|
||||
handle_contains = BuiltinVariable(operator.contains).call_function
|
||||
handle_not = BuiltinVariable(operator.not_).call_function
|
||||
compare_op_handlers["in"] = lambda tx, args, _: handle_contains(
|
||||
tx, [*reversed(args)], {}
|
||||
)
|
||||
compare_op_handlers["not in"] = lambda tx, args, _: handle_not(
|
||||
tx, [handle_contains(tx, [*reversed(args)], {})], {}
|
||||
)
|
||||
|
||||
|
||||
@dataclasses.dataclass
|
||||
@ -1200,49 +1205,7 @@ class InstructionTranslatorBase(
|
||||
unimplemented(f"FOR_ITER {typestr(it)}")
|
||||
|
||||
def COMPARE_OP(self, inst):
|
||||
left, right = self.popn(2)
|
||||
op = inst.argval
|
||||
if op == "in" or op == "not in":
|
||||
self.push(right.call_method(self, "__contains__", [left], {}))
|
||||
if op == "not in":
|
||||
self.UNARY_NOT(inst)
|
||||
return
|
||||
|
||||
if right.is_python_constant():
|
||||
if left.is_python_constant():
|
||||
# constant fold
|
||||
return self.push(
|
||||
ConstantVariable(
|
||||
supported_comparison_ops[op](
|
||||
left.as_python_constant(), right.as_python_constant()
|
||||
),
|
||||
)
|
||||
)
|
||||
elif (
|
||||
op in supported_const_comparison_ops
|
||||
and right.as_python_constant() is None
|
||||
and isinstance(
|
||||
left,
|
||||
(
|
||||
TensorVariable,
|
||||
SymNodeVariable,
|
||||
NNModuleVariable,
|
||||
BaseListVariable,
|
||||
UserDefinedVariable,
|
||||
BaseUserFunctionVariable,
|
||||
ConstDictVariable,
|
||||
),
|
||||
)
|
||||
):
|
||||
# <non-None> is None
|
||||
return self.push(
|
||||
ConstantVariable(supported_const_comparison_ops[op](object(), None))
|
||||
)
|
||||
self.push(
|
||||
BuiltinVariable(supported_comparison_ops[op]).call_function(
|
||||
self, [left, right], {}
|
||||
)
|
||||
)
|
||||
self.push(compare_op_handlers[inst.argval](self, self.popn(2), {}))
|
||||
|
||||
def GET_ITER(self, inst):
|
||||
self.call_function(BuiltinVariable(iter), [self.pop()], {})
|
||||
|
@ -38,7 +38,7 @@ from ..utils import (
|
||||
proxy_args_kwargs,
|
||||
tensortype_to_dtype,
|
||||
)
|
||||
from .base import MutableLocal, typestr, VariableTracker
|
||||
from .base import MutableLocal, VariableTracker
|
||||
from .constant import ConstantVariable
|
||||
from .ctx_manager import EventVariable, StreamVariable
|
||||
from .dicts import (
|
||||
@ -58,6 +58,7 @@ from .lists import (
|
||||
)
|
||||
from .tensor import (
|
||||
FakeItemVariable,
|
||||
supported_comparison_ops,
|
||||
SymNodeVariable,
|
||||
TensorVariable,
|
||||
UnspecializedPythonVariable,
|
||||
@ -167,6 +168,9 @@ class BuiltinVariable(VariableTracker):
|
||||
operator.ior,
|
||||
operator.index,
|
||||
}
|
||||
from .tensor import supported_comparison_ops
|
||||
|
||||
fns.update(supported_comparison_ops.values())
|
||||
fns.update(x for x in math.__dict__.values() if isinstance(x, type(math.sqrt)))
|
||||
return fns
|
||||
|
||||
@ -262,7 +266,17 @@ class BuiltinVariable(VariableTracker):
|
||||
# Multiple dispatch mechanism defining custom binop behavior for certain type
|
||||
# combinations. Handlers are attempted in order, and will be used if the type checks
|
||||
# match. They are expected to have the signature:
|
||||
# fn(tx, arg0: VariableTracker, arg1: VariableTracker, options) -> VariableTracker
|
||||
# fn(tx, arg0: VariableTracker, arg1: VariableTracker) -> VariableTracker
|
||||
from .dicts import DictKeys, SetVariable
|
||||
from .functions import BaseUserFunctionVariable, UserFunctionVariable
|
||||
from .nn_module import NNModuleVariable
|
||||
from .tensor import supported_const_comparison_ops
|
||||
from .torch import BaseTorchVariable
|
||||
from .user_defined import (
|
||||
UserDefinedClassVariable,
|
||||
UserDefinedObjectVariable,
|
||||
UserDefinedVariable,
|
||||
)
|
||||
|
||||
# Override table contains: op_fn -> [list of handlers]
|
||||
op_handlers = {}
|
||||
@ -280,7 +294,7 @@ class BuiltinVariable(VariableTracker):
|
||||
tx,
|
||||
a,
|
||||
b,
|
||||
options,
|
||||
*,
|
||||
forward_name=forward_name,
|
||||
reverse_name=reverse_name,
|
||||
):
|
||||
@ -310,9 +324,7 @@ class BuiltinVariable(VariableTracker):
|
||||
((VariableTracker, UserDefinedVariable), user_defined_handler)
|
||||
)
|
||||
|
||||
def user_defined_inplace_handler(
|
||||
tx, a, b, options, forward_name=inplace_name
|
||||
):
|
||||
def user_defined_inplace_handler(tx, a, b, *, forward_name=inplace_name):
|
||||
return a.call_method(tx, forward_name, [b], {})
|
||||
|
||||
op_handlers[in_place_op].append(
|
||||
@ -323,7 +335,7 @@ class BuiltinVariable(VariableTracker):
|
||||
)
|
||||
|
||||
# Dynamic shape args
|
||||
def dynamic_handler(tx, a, b, options, fn=op):
|
||||
def dynamic_handler(tx, a, b, *, fn=op):
|
||||
from .builder import wrap_fx_proxy
|
||||
|
||||
return wrap_fx_proxy(
|
||||
@ -331,7 +343,6 @@ class BuiltinVariable(VariableTracker):
|
||||
tx.output.create_proxy(
|
||||
"call_function", fn, *proxy_args_kwargs([a, b], {})
|
||||
),
|
||||
**options,
|
||||
)
|
||||
|
||||
op_handlers[op].append(
|
||||
@ -352,11 +363,11 @@ class BuiltinVariable(VariableTracker):
|
||||
# Special cases - lower precedence but still prefer these over constant folding
|
||||
|
||||
# List-like addition (e.g. [1, 2] + [3, 4])
|
||||
def tuple_add_handler(tx, a, b, options):
|
||||
return TupleVariable(a.items + list(b.unpack_var_sequence(tx)), **options)
|
||||
def tuple_add_handler(tx, a, b):
|
||||
return TupleVariable([*a.items, *b.unpack_var_sequence(tx)])
|
||||
|
||||
def size_add_handler(tx, a, b, options):
|
||||
return SizeVariable(a.items + list(b.unpack_var_sequence(tx)), **options)
|
||||
def size_add_handler(tx, a, b):
|
||||
return SizeVariable([*a.items, *b.unpack_var_sequence(tx)])
|
||||
|
||||
list_like_addition_handlers = [
|
||||
# NB: Prefer the tuple-specific logic over base logic because of
|
||||
@ -376,18 +387,27 @@ class BuiltinVariable(VariableTracker):
|
||||
),
|
||||
(
|
||||
(ConstantVariable, TupleVariable),
|
||||
lambda tx, a, b, options: TupleVariable(
|
||||
list(a.unpack_var_sequence(tx)) + b.items, **options
|
||||
lambda tx, a, b: TupleVariable(
|
||||
[*a.unpack_var_sequence(tx), *b.items],
|
||||
),
|
||||
),
|
||||
(
|
||||
(
|
||||
ListVariable,
|
||||
(BaseListVariable, ConstantVariable, ListIteratorVariable),
|
||||
),
|
||||
lambda tx, a, b: ListVariable(
|
||||
[*a.items, *b.unpack_var_sequence(tx)], mutable_local=MutableLocal()
|
||||
),
|
||||
),
|
||||
(
|
||||
(BaseListVariable, BaseListVariable),
|
||||
lambda tx, a, b, options: type(a)(a.items + b.items, **options),
|
||||
lambda tx, a, b: type(a)([*a.items, *b.items]),
|
||||
),
|
||||
]
|
||||
op_handlers[operator.add].extend(list_like_addition_handlers)
|
||||
|
||||
def list_iadd_handler(tx, a, b, _):
|
||||
def list_iadd_handler(tx, a, b):
|
||||
if not a.mutable_local or not b.has_unpack_var_sequence(tx):
|
||||
# Handler doesn't apply
|
||||
return None
|
||||
@ -414,30 +434,169 @@ class BuiltinVariable(VariableTracker):
|
||||
op_handlers[operator.iadd].extend(list_like_iadd_handlers)
|
||||
|
||||
# List-like expansion (e.g. [1, 2, 3] * 3)
|
||||
def expand_list_like(tx, lst, const, options):
|
||||
def expand_list_like(tx, lst, const):
|
||||
if isinstance(lst, ConstantVariable):
|
||||
lst, const = const, lst
|
||||
return lst.__class__(
|
||||
items=lst.items * const.as_python_constant(),
|
||||
mutable_local=MutableLocal(),
|
||||
**options,
|
||||
)
|
||||
|
||||
list_like_expansion_handlers = [
|
||||
((ListVariable, ConstantVariable), expand_list_like),
|
||||
((TupleVariable, ConstantVariable), expand_list_like),
|
||||
(
|
||||
(ConstantVariable, ListVariable),
|
||||
lambda tx, a, b, options: expand_list_like(tx, b, a, options),
|
||||
),
|
||||
(
|
||||
(ConstantVariable, TupleVariable),
|
||||
lambda tx, a, b, options: expand_list_like(tx, b, a, options),
|
||||
),
|
||||
((ConstantVariable, ListVariable), expand_list_like),
|
||||
((ConstantVariable, TupleVariable), expand_list_like),
|
||||
]
|
||||
op_handlers[operator.mul].extend(list_like_expansion_handlers)
|
||||
|
||||
for key in op_handlers.keys():
|
||||
# insert a mutable cache into each entry
|
||||
op_handlers[key] = (op_handlers[key], dict())
|
||||
size_or_tuple = (SizeVariable, TupleVariable)
|
||||
has_set_items = (SetVariable, DictKeys)
|
||||
|
||||
def create_cmp_op_handlers(op):
|
||||
def compare_by_value(tx, a, b):
|
||||
return ConstantVariable(op(a.value, b.value))
|
||||
|
||||
result = [((ConstantVariable, ConstantVariable), compare_by_value)]
|
||||
|
||||
if op in supported_const_comparison_ops.values():
|
||||
# Tensor is None, List is not None, etc
|
||||
none_result = op(object(), None)
|
||||
if op.__name__.startswith("is_"):
|
||||
|
||||
def never(tx, a, b):
|
||||
return ConstantVariable(none_result)
|
||||
|
||||
obj_op_none = never
|
||||
none_op_obj = never
|
||||
else:
|
||||
|
||||
def obj_op_none(tx, a, b: ConstantVariable):
|
||||
if b.value is None or b.value is True or b.value is False:
|
||||
return ConstantVariable(none_result)
|
||||
|
||||
def none_op_obj(tx, a: ConstantVariable, b):
|
||||
if a.value is None or a.value is True or a.value is False:
|
||||
return ConstantVariable(none_result)
|
||||
|
||||
types_that_are_never_none = (
|
||||
TensorVariable,
|
||||
SymNodeVariable,
|
||||
NNModuleVariable,
|
||||
BaseListVariable,
|
||||
UserDefinedVariable,
|
||||
BaseUserFunctionVariable,
|
||||
ConstDictVariable,
|
||||
BaseTorchVariable,
|
||||
)
|
||||
result.extend(
|
||||
[
|
||||
(
|
||||
(types_that_are_never_none, ConstantVariable),
|
||||
obj_op_none,
|
||||
),
|
||||
(
|
||||
(ConstantVariable, types_that_are_never_none),
|
||||
none_op_obj,
|
||||
),
|
||||
]
|
||||
)
|
||||
|
||||
def list_compare_nocheck(tx, left, right):
|
||||
return BaseListVariable.list_compare(tx, op, left, right)
|
||||
|
||||
def list_compare_check(tx, left, right):
|
||||
if type(left) is not type(
|
||||
right
|
||||
): # Mismatch in BaseListVariable subclasses
|
||||
unimplemented(f"{op.__name__}({left}, {right})")
|
||||
return BaseListVariable.list_compare(tx, op, left, right)
|
||||
|
||||
def compare_set_items(tx, left, right):
|
||||
return ConstantVariable(op(left.set_items, right.set_items))
|
||||
|
||||
op_var = BuiltinVariable(op)
|
||||
result.extend(
|
||||
[
|
||||
(
|
||||
(
|
||||
(UserFunctionVariable, BuiltinVariable),
|
||||
(UserFunctionVariable, BuiltinVariable),
|
||||
),
|
||||
lambda tx, a, b: ConstantVariable(op(a.fn, b.fn)),
|
||||
),
|
||||
(
|
||||
(
|
||||
NNModuleVariable,
|
||||
NNModuleVariable,
|
||||
),
|
||||
lambda tx, a, b: ConstantVariable(
|
||||
op(
|
||||
tx.output.get_submodule(a.module_key),
|
||||
tx.output.get_submodule(b.module_key),
|
||||
)
|
||||
),
|
||||
),
|
||||
((size_or_tuple, size_or_tuple), list_compare_nocheck),
|
||||
(
|
||||
(variables.BaseListVariable, variables.BaseListVariable),
|
||||
list_compare_check,
|
||||
),
|
||||
((has_set_items, has_set_items), compare_set_items),
|
||||
# TODO(jansel): UserDefinedObjectVariable is wrong and could invoke user code
|
||||
(
|
||||
(UserDefinedObjectVariable, UserDefinedObjectVariable),
|
||||
compare_by_value,
|
||||
),
|
||||
(
|
||||
(UserDefinedClassVariable, UserDefinedClassVariable),
|
||||
compare_by_value,
|
||||
),
|
||||
(
|
||||
(
|
||||
(StreamVariable, EventVariable, ConstantVariable),
|
||||
(StreamVariable, EventVariable, ConstantVariable),
|
||||
),
|
||||
compare_by_value,
|
||||
),
|
||||
(
|
||||
(TensorVariable, VariableTracker),
|
||||
op_var._comparison_with_tensor,
|
||||
),
|
||||
(
|
||||
(VariableTracker, TensorVariable),
|
||||
op_var._comparison_with_tensor,
|
||||
),
|
||||
(
|
||||
(SymNodeVariable, VariableTracker),
|
||||
op_var._comparison_with_symnode,
|
||||
),
|
||||
(
|
||||
(VariableTracker, SymNodeVariable),
|
||||
op_var._comparison_with_symnode,
|
||||
),
|
||||
]
|
||||
)
|
||||
|
||||
if op.__name__.startswith("is_"):
|
||||
|
||||
def handle_is(tx, left, right):
|
||||
# If the two objects are of different type, we can safely return False
|
||||
# and True for `is` and `is not`, respectively
|
||||
if type(left) is not type(right):
|
||||
return ConstantVariable.create(op.__name__ != "is_")
|
||||
|
||||
result.append(((VariableTracker, VariableTracker), handle_is))
|
||||
|
||||
return result
|
||||
|
||||
for op in supported_comparison_ops.values():
|
||||
assert callable(op)
|
||||
assert op not in op_handlers
|
||||
op_handlers[op] = create_cmp_op_handlers(op)
|
||||
|
||||
for op in op_handlers.keys():
|
||||
op_handlers[op] = (op_handlers[op], dict())
|
||||
return op_handlers
|
||||
|
||||
@staticmethod
|
||||
@ -452,13 +611,26 @@ class BuiltinVariable(VariableTracker):
|
||||
if hit is not False:
|
||||
return hit
|
||||
|
||||
# Return first handler that matches the type checks
|
||||
a_type = type(a)
|
||||
b_type = type(b)
|
||||
matches = []
|
||||
for (type1, type2), handler in handlers:
|
||||
if isinstance(a, type1) and isinstance(b, type2):
|
||||
cache[cache_key] = handler
|
||||
return handler
|
||||
cache[cache_key] = None
|
||||
return None
|
||||
if issubclass(a_type, type1) and issubclass(b_type, type2):
|
||||
matches.append(handler)
|
||||
|
||||
if not matches:
|
||||
result = None
|
||||
elif len(matches) == 1:
|
||||
result = matches[0]
|
||||
else:
|
||||
|
||||
def result(*args):
|
||||
for fn in matches:
|
||||
rv = fn(*args)
|
||||
if rv:
|
||||
return rv
|
||||
|
||||
return result
|
||||
|
||||
def can_insert_in_graph(self):
|
||||
return self.fn in self._fx_graph_functions()
|
||||
@ -666,7 +838,7 @@ class BuiltinVariable(VariableTracker):
|
||||
# Try to find a handler for the arg types; otherwise, fall through to constant handler
|
||||
binop_handler = BuiltinVariable._find_binop_handler(fn, args[0], args[1])
|
||||
if binop_handler:
|
||||
res = binop_handler(tx, args[0], args[1], {})
|
||||
res = binop_handler(tx, args[0], args[1])
|
||||
if res is not None:
|
||||
return res
|
||||
|
||||
@ -1511,150 +1683,63 @@ class BuiltinVariable(VariableTracker):
|
||||
def call_deepcopy(self, tx, x):
|
||||
unimplemented(f"copy.deepcopy {repr(x)}")
|
||||
|
||||
def _comparison(self, tx, left, right):
|
||||
"""
|
||||
Used to implement comparison operators for different types.
|
||||
For example, list1 < list2 is implemented differently from tensor1 < tensor2
|
||||
"""
|
||||
from . import (
|
||||
BaseListVariable,
|
||||
ConstantVariable,
|
||||
NNModuleVariable,
|
||||
TensorVariable,
|
||||
UserDefinedObjectVariable,
|
||||
UserFunctionVariable,
|
||||
)
|
||||
from .lists import SizeVariable
|
||||
from .tensor import (
|
||||
supported_const_comparison_op_values,
|
||||
supported_tensor_comparison_op_values,
|
||||
)
|
||||
def _comparison_with_tensor(self, tx, left, right):
|
||||
from .builder import wrap_fx_proxy_cls
|
||||
from .tensor import supported_tensor_comparison_op_values
|
||||
|
||||
op = self.fn
|
||||
|
||||
def _unimplemented():
|
||||
unimplemented(f"comparison {typestr(left)} {op} {typestr(right)}")
|
||||
|
||||
if (
|
||||
all(
|
||||
isinstance(x, (NNModuleVariable, ConstantVariable))
|
||||
for x in [left, right]
|
||||
)
|
||||
and op in supported_const_comparison_op_values
|
||||
):
|
||||
left = (
|
||||
tx.output.get_submodule(left.module_key)
|
||||
if isinstance(left, NNModuleVariable)
|
||||
else left.as_python_constant()
|
||||
)
|
||||
right = (
|
||||
tx.output.get_submodule(right.module_key)
|
||||
if isinstance(right, NNModuleVariable)
|
||||
else right.as_python_constant()
|
||||
)
|
||||
return ConstantVariable.create(op(left, right))
|
||||
|
||||
if isinstance(left, UserFunctionVariable):
|
||||
if op not in supported_const_comparison_op_values:
|
||||
_unimplemented()
|
||||
if not isinstance(right, UserFunctionVariable):
|
||||
_unimplemented()
|
||||
return ConstantVariable.create(op(left.fn, right.fn))
|
||||
|
||||
# Note, we have a rare BaseListVariable subtype mismatch with valid comparison
|
||||
# x = torch.randn([3, 3])
|
||||
# x.size() == (3, 3) # True
|
||||
# (3, 3) == x.size() # True
|
||||
if isinstance(left, (SizeVariable, TupleVariable)) and isinstance(
|
||||
right, (TupleVariable, SizeVariable)
|
||||
):
|
||||
return BaseListVariable.list_compare(tx, op, left, right)
|
||||
|
||||
if isinstance(left, BaseListVariable):
|
||||
if not type(left) == type(right): # Mismatch in BaseListVariable subclasses
|
||||
_unimplemented()
|
||||
return BaseListVariable.list_compare(tx, op, left, right)
|
||||
|
||||
# If they implement set semantics (e.g. SetVariable or DictKeys)
|
||||
if hasattr(left, "set_items") and hasattr(right, "set_items"):
|
||||
return ConstantVariable.create(op(left.set_items, right.set_items))
|
||||
|
||||
if isinstance(left, TensorVariable) or isinstance(right, TensorVariable):
|
||||
from .builder import wrap_fx_proxy_cls
|
||||
|
||||
if op in [operator.is_, operator.is_not]:
|
||||
is_result = (
|
||||
isinstance(left, TensorVariable)
|
||||
and isinstance(right, TensorVariable)
|
||||
and id(extract_fake_example_value(left.as_proxy().node))
|
||||
== id(extract_fake_example_value(right.as_proxy().node))
|
||||
)
|
||||
if op is operator.is_:
|
||||
return ConstantVariable.create(is_result)
|
||||
else:
|
||||
return ConstantVariable.create(not is_result)
|
||||
|
||||
if op not in supported_tensor_comparison_op_values:
|
||||
_unimplemented()
|
||||
if (
|
||||
if op in [operator.is_, operator.is_not]:
|
||||
is_result = (
|
||||
isinstance(left, TensorVariable)
|
||||
and isinstance(right, TensorVariable)
|
||||
and (left.size and right.size) is not None
|
||||
and left.size != right.size
|
||||
):
|
||||
try:
|
||||
torch.broadcast_shapes(left.size, right.size)
|
||||
except RuntimeError:
|
||||
# not broadcastable, can't be compared
|
||||
_unimplemented()
|
||||
tensor_cls = left if isinstance(left, TensorVariable) else right
|
||||
proxy = tx.output.create_proxy(
|
||||
"call_function", op, (left.as_proxy(), right.as_proxy()), {}
|
||||
)
|
||||
return wrap_fx_proxy_cls(
|
||||
type(tensor_cls), # handle Ndarrays and Tensors
|
||||
tx,
|
||||
proxy,
|
||||
and id(extract_fake_example_value(left.as_proxy().node))
|
||||
== id(extract_fake_example_value(right.as_proxy().node))
|
||||
)
|
||||
if op is operator.is_:
|
||||
return ConstantVariable.create(is_result)
|
||||
else:
|
||||
return ConstantVariable.create(not is_result)
|
||||
|
||||
if isinstance(left, SymNodeVariable) or isinstance(right, SymNodeVariable):
|
||||
if op not in supported_tensor_comparison_op_values:
|
||||
_unimplemented()
|
||||
|
||||
proxy = tx.output.create_proxy(
|
||||
"call_function", op, (left.as_proxy(), right.as_proxy()), {}
|
||||
)
|
||||
return SymNodeVariable.create(
|
||||
tx,
|
||||
proxy,
|
||||
sym_num=None,
|
||||
)
|
||||
|
||||
if isinstance(left, UserDefinedObjectVariable) and isinstance(
|
||||
right, UserDefinedObjectVariable
|
||||
if op not in supported_tensor_comparison_op_values:
|
||||
unimplemented(f"{op.__name__}({left}, {right})")
|
||||
if (
|
||||
isinstance(left, TensorVariable)
|
||||
and isinstance(right, TensorVariable)
|
||||
and (left.size and right.size) is not None
|
||||
and left.size != right.size
|
||||
):
|
||||
return ConstantVariable.create(op(left.value, right.value))
|
||||
try:
|
||||
torch.broadcast_shapes(left.size, right.size)
|
||||
except RuntimeError:
|
||||
# not broadcastable, can't be compared
|
||||
unimplemented(f"{op.__name__}({left}, {right})")
|
||||
tensor_cls = left if isinstance(left, TensorVariable) else right
|
||||
proxy = tx.output.create_proxy(
|
||||
"call_function", op, (left.as_proxy(), right.as_proxy()), {}
|
||||
)
|
||||
return wrap_fx_proxy_cls(
|
||||
type(tensor_cls), # handle Ndarrays and Tensors
|
||||
tx,
|
||||
proxy,
|
||||
)
|
||||
|
||||
if isinstance(left, (StreamVariable, EventVariable)) or isinstance(
|
||||
right, (StreamVariable, EventVariable)
|
||||
):
|
||||
if type(left) == type(right) and op is operator.eq:
|
||||
return ConstantVariable(op(left.value, right.value))
|
||||
def _comparison_with_symnode(self, tx, left, right):
|
||||
from .tensor import supported_tensor_comparison_op_values
|
||||
|
||||
if isinstance(right, ConstantVariable) or isinstance(
|
||||
left, ConstantVariable
|
||||
):
|
||||
return ConstantVariable(op(left.value, right.value))
|
||||
op = self.fn
|
||||
|
||||
if op.__name__.startswith("is_"):
|
||||
# If the two objects are of different type, we can safely return False and True for `is` and `is not`, respectively
|
||||
if type(left) is not type(right):
|
||||
return ConstantVariable.create(op.__name__ != "is_")
|
||||
if op not in supported_tensor_comparison_op_values:
|
||||
unimplemented(f"{op.__name__}({left}, {right})")
|
||||
|
||||
if isinstance(left, BuiltinVariable) and isinstance(right, BuiltinVariable):
|
||||
return ConstantVariable.create(op(left.fn, right.fn))
|
||||
|
||||
_unimplemented()
|
||||
proxy = tx.output.create_proxy(
|
||||
"call_function", op, (left.as_proxy(), right.as_proxy()), {}
|
||||
)
|
||||
return SymNodeVariable.create(
|
||||
tx,
|
||||
proxy,
|
||||
sym_num=None,
|
||||
)
|
||||
|
||||
def call_and_(self, tx, a, b):
|
||||
# Rely on constant_handler
|
||||
@ -1711,14 +1796,8 @@ class BuiltinVariable(VariableTracker):
|
||||
|
||||
return None
|
||||
|
||||
call_eq = _comparison
|
||||
call_gt = _comparison
|
||||
call_lt = _comparison
|
||||
call_ge = _comparison
|
||||
call_le = _comparison
|
||||
call_ne = _comparison
|
||||
call_is_ = _comparison
|
||||
call_is_not = _comparison
|
||||
def call_contains(self, tx, a: VariableTracker, b: VariableTracker):
|
||||
return a.call_method(tx, "__contains__", [b], {})
|
||||
|
||||
call_all = _polyfill_call_impl("all")
|
||||
call_any = _polyfill_call_impl("any")
|
||||
|
Reference in New Issue
Block a user