Mirror of https://github.com/pytorch/pytorch.git (synced 2025-10-20 21:14:14 +08:00)

commit 07caea5c12 (parent 769ff86b91), committed by PyTorch MergeBot

[dynamo] Refactor COMPARE_OP and comparison builtins (#122043)

This removes the duplicated handling of comparison ops between symbolic_convert and builtin, refactoring it to use the binop infrastructure. The change regresses overheads a bit, but that is fixed in the next PR. The new test skips are variants of `type(e) is np.ndarray` that previously fell back to eager.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/122043
Approved by: https://github.com/anijain2305
ghstack dependencies: #122039
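
For context, a minimal sketch (not part of the commit; backend and flags chosen only for illustration) of the comparison patterns this refactor routes through the shared binop path:

    import torch

    @torch.compile(backend="eager", fullgraph=True)
    def f(x, xs):
        a = x.sum() > 0    # tensor comparison -> traced as a graph op
        b = x is not None  # "<non-None> is None" family -> constant-folded
        c = 3 in xs        # "in"/"not in" -> rewritten via operator.contains
        return a, b, c

    print(f(torch.ones(3), [1, 2, 3]))  # (tensor(True), True, True)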
Benchmark expected-accuracy CSV updates:

@@ -86,7 +86,7 @@ detectron2_fasterrcnn_r_50_fpn,eager_fail_to_run,0
-detectron2_fcos_r_50_fpn,pass,35
+detectron2_fcos_r_50_fpn,pass,94
@@ -54,47 +54,47 @@ densenet121,pass,0
-detectron2_fasterrcnn_r_101_c4,pass,51
+detectron2_fasterrcnn_r_101_c4,pass,164
-detectron2_fasterrcnn_r_101_dc5,pass,51
+detectron2_fasterrcnn_r_101_dc5,pass,163
-detectron2_fasterrcnn_r_101_fpn,pass,55
+detectron2_fasterrcnn_r_101_fpn,pass,172
-detectron2_fasterrcnn_r_50_c4,pass,51
+detectron2_fasterrcnn_r_50_c4,pass,113
-detectron2_fasterrcnn_r_50_dc5,pass,51
+detectron2_fasterrcnn_r_50_dc5,pass,112
-detectron2_fasterrcnn_r_50_fpn,pass,55
+detectron2_fasterrcnn_r_50_fpn,pass,121
-detectron2_fcos_r_50_fpn,pass,38
+detectron2_fcos_r_50_fpn,pass,97
-detectron2_maskrcnn_r_101_c4,fail_accuracy,66
+detectron2_maskrcnn_r_101_c4,pass,182
-detectron2_maskrcnn_r_101_fpn,pass,73
+detectron2_maskrcnn_r_101_fpn,pass,192
-detectron2_maskrcnn_r_50_c4,pass,66
+detectron2_maskrcnn_r_50_c4,pass,131
-detectron2_maskrcnn_r_50_fpn,pass,73
+detectron2_maskrcnn_r_50_fpn,pass,141
@@ -86,7 +86,7 @@ detectron2_fasterrcnn_r_50_fpn,eager_fail_to_run,0
-detectron2_fcos_r_50_fpn,pass,35
+detectron2_fcos_r_50_fpn,pass,94
@@ -54,7 +54,7 @@ densenet121,pass,0
-detectron2_fcos_r_50_fpn,pass,38
+detectron2_fcos_r_50_fpn,pass,97
@@ -86,7 +86,7 @@ detectron2_fasterrcnn_r_50_fpn,eager_fail_to_run,0
-detectron2_fcos_r_50_fpn,pass,36
+detectron2_fcos_r_50_fpn,pass,95
@@ -86,7 +86,7 @@ detectron2_fasterrcnn_r_50_fpn,eager_fail_to_run,0
-detectron2_fcos_r_50_fpn,pass,35
+detectron2_fcos_r_50_fpn,pass,94
@@ -86,7 +86,7 @@ detectron2_fasterrcnn_r_50_fpn,eager_fail_to_run,0
-detectron2_fcos_r_50_fpn,pass,36
+detectron2_fcos_r_50_fpn,pass,95
--- a/test/dynamo/test_functions.py
+++ b/test/dynamo/test_functions.py
@@ -1459,6 +1459,26 @@ class FunctionTests(torch._dynamo.test_case.TestCase):
         par_mul = functools.partial(udf_mul, torch.ones(10, 10))
         return par_mul(x)
 
+    @make_test
+    def test_list_add_then_mutate(x):
+        my_list = [1, x]
+        y = x / 4.0
+        my_list = my_list + [x / 2.0, 4]
+        my_list.append(y)
+        return sum(my_list)
+
+    @make_test
+    def test_list_expand_lhs(x):
+        return sum(4 * [x])
+
+    @make_test
+    def test_in_not_in(x):
+        mylist = [1, 2, 3, 4, 5, x]
+        myotherlist = [1, 2, 3, 4, 5]
+        assert 3 in mylist
+        assert 6 not in myotherlist
+        return sum(mylist)
+
     @make_test
     def test_partials_udf_kwarg(x):
         par_mul = functools.partial(udf_mul, y=torch.ones(10, 10))
new file (empty): test/dynamo_skips/TestSqueeze.test_squeeze_type
--- a/torch/_dynamo/symbolic_convert.py
+++ b/torch/_dynamo/symbolic_convert.py
@@ -99,17 +99,11 @@ from .variables.misc import (
     UnknownVariable,
 )
 from .variables.nn_module import NNModuleVariable
-from .variables.tensor import (
-    supported_comparison_ops,
-    supported_const_comparison_ops,
-    SymNodeVariable,
-    TensorVariable,
-)
+from .variables.tensor import supported_comparison_ops, SymNodeVariable, TensorVariable
 from .variables.user_defined import (
     RemovableHandleVariable,
     UserDefinedClassVariable,
     UserDefinedObjectVariable,
-    UserDefinedVariable,
 )
 
 log = logging.getLogger(__name__)
@@ -117,6 +111,17 @@ graph_break_log = torch._logging.getArtifactLogger(__name__, "graph_breaks")
 trace_call_log = torch._logging.getArtifactLogger(__name__, "trace_call")
 trace_source_log = torch._logging.getArtifactLogger(__name__, "trace_source")
 tls = threading.local()
+compare_op_handlers: Dict[str, Any] = {
+    k: BuiltinVariable(v).call_function for k, v in supported_comparison_ops.items()
+}
+handle_contains = BuiltinVariable(operator.contains).call_function
+handle_not = BuiltinVariable(operator.not_).call_function
+compare_op_handlers["in"] = lambda tx, args, _: handle_contains(
+    tx, [*reversed(args)], {}
+)
+compare_op_handlers["not in"] = lambda tx, args, _: handle_not(
+    tx, [handle_contains(tx, [*reversed(args)], {})], {}
+)
 
 
 @dataclasses.dataclass
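
A plain-Python sketch (Dynamo types stripped) of why the `in`/`not in` entries above reverse their arguments: `COMPARE_OP` pops `[left, right]` as `[item, container]`, while `operator.contains` expects `(container, item)`:

    import operator

    # "x in xs": stack order is [item, container]; operator.contains wants
    # (container, item), hence the reversal. "not in" adds operator.not_ on top.
    handle_contains = lambda args: operator.contains(*reversed(args))
    handle_not = operator.not_

    print(handle_contains([3, [1, 2, 3]]))              # True:  3 in [1, 2, 3]
    print(handle_not(handle_contains([6, [1, 2, 3]])))  # True:  6 not in [1, 2, 3]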
@@ -1200,49 +1205,7 @@ class InstructionTranslatorBase(
         unimplemented(f"FOR_ITER {typestr(it)}")
 
     def COMPARE_OP(self, inst):
-        left, right = self.popn(2)
-        op = inst.argval
-        if op == "in" or op == "not in":
-            self.push(right.call_method(self, "__contains__", [left], {}))
-            if op == "not in":
-                self.UNARY_NOT(inst)
-            return
-
-        if right.is_python_constant():
-            if left.is_python_constant():
-                # constant fold
-                return self.push(
-                    ConstantVariable(
-                        supported_comparison_ops[op](
-                            left.as_python_constant(), right.as_python_constant()
-                        ),
-                    )
-                )
-            elif (
-                op in supported_const_comparison_ops
-                and right.as_python_constant() is None
-                and isinstance(
-                    left,
-                    (
-                        TensorVariable,
-                        SymNodeVariable,
-                        NNModuleVariable,
-                        BaseListVariable,
-                        UserDefinedVariable,
-                        BaseUserFunctionVariable,
-                        ConstDictVariable,
-                    ),
-                )
-            ):
-                # <non-None> is None
-                return self.push(
-                    ConstantVariable(supported_const_comparison_ops[op](object(), None))
-                )
-        self.push(
-            BuiltinVariable(supported_comparison_ops[op]).call_function(
-                self, [left, right], {}
-            )
-        )
+        self.push(compare_op_handlers[inst.argval](self, self.popn(2), {}))
 
     def GET_ITER(self, inst):
         self.call_function(BuiltinVariable(iter), [self.pop()], {})
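
A toy model (names are stand-ins, not Dynamo's actual classes) of what the slimmed-down `COMPARE_OP` now does: pop two stack entries and dispatch purely on `inst.argval`:

    import operator

    stack = [4, 7]  # toy interpreter stack: [left, right]

    def popn(n):
        out = stack[-n:]
        del stack[-n:]
        return out

    # one prebuilt handler per source-level comparison string
    compare_op_handlers = {"<": lambda tx, args, kwargs: operator.lt(*args)}
    print(compare_op_handlers["<"](None, popn(2), {}))  # True: 4 < 7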
--- a/torch/_dynamo/variables/builtin.py
+++ b/torch/_dynamo/variables/builtin.py
@@ -38,7 +38,7 @@ from ..utils import (
     proxy_args_kwargs,
     tensortype_to_dtype,
 )
-from .base import MutableLocal, typestr, VariableTracker
+from .base import MutableLocal, VariableTracker
 from .constant import ConstantVariable
 from .ctx_manager import EventVariable, StreamVariable
 from .dicts import (
@@ -58,6 +58,7 @@ from .lists import (
 )
 from .tensor import (
     FakeItemVariable,
+    supported_comparison_ops,
     SymNodeVariable,
     TensorVariable,
     UnspecializedPythonVariable,
@@ -167,6 +168,9 @@ class BuiltinVariable(VariableTracker):
             operator.ior,
             operator.index,
         }
+        from .tensor import supported_comparison_ops
+
+        fns.update(supported_comparison_ops.values())
         fns.update(x for x in math.__dict__.values() if isinstance(x, type(math.sqrt)))
         return fns
 
@@ -262,7 +266,17 @@ class BuiltinVariable(VariableTracker):
         # Multiple dispatch mechanism defining custom binop behavior for certain type
         # combinations. Handlers are attempted in order, and will be used if the type checks
         # match. They are expected to have the signature:
-        # fn(tx, arg0: VariableTracker, arg1: VariableTracker, options) -> VariableTracker
+        # fn(tx, arg0: VariableTracker, arg1: VariableTracker) -> VariableTracker
+        from .dicts import DictKeys, SetVariable
+        from .functions import BaseUserFunctionVariable, UserFunctionVariable
+        from .nn_module import NNModuleVariable
+        from .tensor import supported_const_comparison_ops
+        from .torch import BaseTorchVariable
+        from .user_defined import (
+            UserDefinedClassVariable,
+            UserDefinedObjectVariable,
+            UserDefinedVariable,
+        )
 
         # Override table contains: op_fn -> [list of handlers]
         op_handlers = {}
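
An illustrative sketch (plain Python, hypothetical table contents) of the dispatch-table shape the comment above describes: per-op lists of `((type1, type2), fn)` pairs tried in registration order, with the new `fn(tx, a, b)` signature:

    import operator

    op_handlers = {
        operator.add: [
            ((list, list), lambda tx, a, b: [*a, *b]),
            ((tuple, tuple), lambda tx, a, b: (*a, *b)),
        ],
    }

    def dispatch(op, tx, a, b):
        # try each handler in order; the first matching type pair wins
        for (t1, t2), fn in op_handlers[op]:
            if isinstance(a, t1) and isinstance(b, t2):
                return fn(tx, a, b)
        return None

    print(dispatch(operator.add, None, [1], [2]))    # [1, 2]
    print(dispatch(operator.add, None, (1,), (2,)))  # (1, 2)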
@@ -280,7 +294,7 @@ class BuiltinVariable(VariableTracker):
             tx,
             a,
             b,
-            options,
+            *,
             forward_name=forward_name,
             reverse_name=reverse_name,
         ):
@@ -310,9 +324,7 @@ class BuiltinVariable(VariableTracker):
                 ((VariableTracker, UserDefinedVariable), user_defined_handler)
             )
 
-            def user_defined_inplace_handler(
-                tx, a, b, options, forward_name=inplace_name
-            ):
+            def user_defined_inplace_handler(tx, a, b, *, forward_name=inplace_name):
                 return a.call_method(tx, forward_name, [b], {})
 
             op_handlers[in_place_op].append(
@@ -323,7 +335,7 @@ class BuiltinVariable(VariableTracker):
             )
 
             # Dynamic shape args
-            def dynamic_handler(tx, a, b, options, fn=op):
+            def dynamic_handler(tx, a, b, *, fn=op):
                 from .builder import wrap_fx_proxy
 
                 return wrap_fx_proxy(
@@ -331,7 +343,6 @@ class BuiltinVariable(VariableTracker):
                 tx.output.create_proxy(
                     "call_function", fn, *proxy_args_kwargs([a, b], {})
                 ),
-                **options,
             )
 
         op_handlers[op].append(
@@ -352,11 +363,11 @@ class BuiltinVariable(VariableTracker):
         # Special cases - lower precedence but still prefer these over constant folding
 
         # List-like addition (e.g. [1, 2] + [3, 4])
-        def tuple_add_handler(tx, a, b, options):
-            return TupleVariable(a.items + list(b.unpack_var_sequence(tx)), **options)
+        def tuple_add_handler(tx, a, b):
+            return TupleVariable([*a.items, *b.unpack_var_sequence(tx)])
 
-        def size_add_handler(tx, a, b, options):
-            return SizeVariable(a.items + list(b.unpack_var_sequence(tx)), **options)
+        def size_add_handler(tx, a, b):
+            return SizeVariable([*a.items, *b.unpack_var_sequence(tx)])
 
         list_like_addition_handlers = [
             # NB: Prefer the tuple-specific logic over base logic because of
@@ -376,18 +387,27 @@ class BuiltinVariable(VariableTracker):
             ),
             (
                 (ConstantVariable, TupleVariable),
-                lambda tx, a, b, options: TupleVariable(
-                    list(a.unpack_var_sequence(tx)) + b.items, **options
+                lambda tx, a, b: TupleVariable(
+                    [*a.unpack_var_sequence(tx), *b.items],
+                ),
+            ),
+            (
+                (
+                    ListVariable,
+                    (BaseListVariable, ConstantVariable, ListIteratorVariable),
+                ),
+                lambda tx, a, b: ListVariable(
+                    [*a.items, *b.unpack_var_sequence(tx)], mutable_local=MutableLocal()
                 ),
             ),
             (
                 (BaseListVariable, BaseListVariable),
-                lambda tx, a, b, options: type(a)(a.items + b.items, **options),
+                lambda tx, a, b: type(a)([*a.items, *b.items]),
             ),
         ]
         op_handlers[operator.add].extend(list_like_addition_handlers)
 
-        def list_iadd_handler(tx, a, b, _):
+        def list_iadd_handler(tx, a, b):
             if not a.mutable_local or not b.has_unpack_var_sequence(tx):
                 # Handler doesn't apply
                 return None
@@ -414,30 +434,169 @@ class BuiltinVariable(VariableTracker):
         op_handlers[operator.iadd].extend(list_like_iadd_handlers)
 
         # List-like expansion (e.g. [1, 2, 3] * 3)
-        def expand_list_like(tx, lst, const, options):
+        def expand_list_like(tx, lst, const):
+            if isinstance(lst, ConstantVariable):
+                lst, const = const, lst
             return lst.__class__(
                 items=lst.items * const.as_python_constant(),
                 mutable_local=MutableLocal(),
-                **options,
             )
 
         list_like_expansion_handlers = [
             ((ListVariable, ConstantVariable), expand_list_like),
             ((TupleVariable, ConstantVariable), expand_list_like),
-            (
-                (ConstantVariable, ListVariable),
-                lambda tx, a, b, options: expand_list_like(tx, b, a, options),
-            ),
-            (
-                (ConstantVariable, TupleVariable),
-                lambda tx, a, b, options: expand_list_like(tx, b, a, options),
-            ),
+            ((ConstantVariable, ListVariable), expand_list_like),
+            ((ConstantVariable, TupleVariable), expand_list_like),
         ]
         op_handlers[operator.mul].extend(list_like_expansion_handlers)
 
-        for key in op_handlers.keys():
-            # insert a mutable cache into each entry
-            op_handlers[key] = (op_handlers[key], dict())
+        size_or_tuple = (SizeVariable, TupleVariable)
+        has_set_items = (SetVariable, DictKeys)
+
+        def create_cmp_op_handlers(op):
+            def compare_by_value(tx, a, b):
+                return ConstantVariable(op(a.value, b.value))
+
+            result = [((ConstantVariable, ConstantVariable), compare_by_value)]
+
+            if op in supported_const_comparison_ops.values():
+                # Tensor is None, List is not None, etc
+                none_result = op(object(), None)
+                if op.__name__.startswith("is_"):
+
+                    def never(tx, a, b):
+                        return ConstantVariable(none_result)
+
+                    obj_op_none = never
+                    none_op_obj = never
+                else:
+
+                    def obj_op_none(tx, a, b: ConstantVariable):
+                        if b.value is None or b.value is True or b.value is False:
+                            return ConstantVariable(none_result)
+
+                    def none_op_obj(tx, a: ConstantVariable, b):
+                        if a.value is None or a.value is True or a.value is False:
+                            return ConstantVariable(none_result)
+
+                types_that_are_never_none = (
+                    TensorVariable,
+                    SymNodeVariable,
+                    NNModuleVariable,
+                    BaseListVariable,
+                    UserDefinedVariable,
+                    BaseUserFunctionVariable,
+                    ConstDictVariable,
+                    BaseTorchVariable,
+                )
+                result.extend(
+                    [
+                        (
+                            (types_that_are_never_none, ConstantVariable),
+                            obj_op_none,
+                        ),
+                        (
+                            (ConstantVariable, types_that_are_never_none),
+                            none_op_obj,
+                        ),
+                    ]
+                )
+
+            def list_compare_nocheck(tx, left, right):
+                return BaseListVariable.list_compare(tx, op, left, right)
+
+            def list_compare_check(tx, left, right):
+                if type(left) is not type(
+                    right
+                ):  # Mismatch in BaseListVariable subclasses
+                    unimplemented(f"{op.__name__}({left}, {right})")
+                return BaseListVariable.list_compare(tx, op, left, right)
+
+            def compare_set_items(tx, left, right):
+                return ConstantVariable(op(left.set_items, right.set_items))
+
+            op_var = BuiltinVariable(op)
+            result.extend(
+                [
+                    (
+                        (
+                            (UserFunctionVariable, BuiltinVariable),
+                            (UserFunctionVariable, BuiltinVariable),
+                        ),
+                        lambda tx, a, b: ConstantVariable(op(a.fn, b.fn)),
+                    ),
+                    (
+                        (
+                            NNModuleVariable,
+                            NNModuleVariable,
+                        ),
+                        lambda tx, a, b: ConstantVariable(
+                            op(
+                                tx.output.get_submodule(a.module_key),
+                                tx.output.get_submodule(b.module_key),
+                            )
+                        ),
+                    ),
+                    ((size_or_tuple, size_or_tuple), list_compare_nocheck),
+                    (
+                        (variables.BaseListVariable, variables.BaseListVariable),
+                        list_compare_check,
+                    ),
+                    ((has_set_items, has_set_items), compare_set_items),
+                    # TODO(jansel): UserDefinedObjectVariable is wrong and could invoke user code
+                    (
+                        (UserDefinedObjectVariable, UserDefinedObjectVariable),
+                        compare_by_value,
+                    ),
+                    (
+                        (UserDefinedClassVariable, UserDefinedClassVariable),
+                        compare_by_value,
+                    ),
+                    (
+                        (
+                            (StreamVariable, EventVariable, ConstantVariable),
+                            (StreamVariable, EventVariable, ConstantVariable),
+                        ),
+                        compare_by_value,
+                    ),
+                    (
+                        (TensorVariable, VariableTracker),
+                        op_var._comparison_with_tensor,
+                    ),
+                    (
+                        (VariableTracker, TensorVariable),
+                        op_var._comparison_with_tensor,
+                    ),
+                    (
+                        (SymNodeVariable, VariableTracker),
+                        op_var._comparison_with_symnode,
+                    ),
+                    (
+                        (VariableTracker, SymNodeVariable),
+                        op_var._comparison_with_symnode,
+                    ),
+                ]
+            )
+
+            if op.__name__.startswith("is_"):
+
+                def handle_is(tx, left, right):
+                    # If the two objects are of different type, we can safely return False
+                    # and True for `is` and `is not`, respectively
+                    if type(left) is not type(right):
+                        return ConstantVariable.create(op.__name__ != "is_")
+
+                result.append(((VariableTracker, VariableTracker), handle_is))
+
+            return result
+
+        for op in supported_comparison_ops.values():
+            assert callable(op)
+            assert op not in op_handlers
+            op_handlers[op] = create_cmp_op_handlers(op)
+
+        for op in op_handlers.keys():
+            op_handlers[op] = (op_handlers[op], dict())
         return op_handlers
 
     @staticmethod
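
The `none_result = op(object(), None)` trick above works because a freshly created `object()` can never be `None`, so the folded answer depends only on the operator, not the operand; for example:

    import operator

    # "<non-None> is None" folds to False; "<non-None> is not None" folds to True.
    print(operator.is_(object(), None))     # False
    print(operator.is_not(object(), None))  # True
    print(operator.eq(object(), None))      # False (default object equality)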
@@ -452,13 +611,26 @@ class BuiltinVariable(VariableTracker):
             if hit is not False:
                 return hit
 
-        # Return first handler that matches the type checks
+        a_type = type(a)
+        b_type = type(b)
+        matches = []
         for (type1, type2), handler in handlers:
-            if isinstance(a, type1) and isinstance(b, type2):
-                cache[cache_key] = handler
-                return handler
-        cache[cache_key] = None
-        return None
+            if issubclass(a_type, type1) and issubclass(b_type, type2):
+                matches.append(handler)
+
+        if not matches:
+            result = None
+        elif len(matches) == 1:
+            result = matches[0]
+        else:
+
+            def result(*args):
+                for fn in matches:
+                    rv = fn(*args)
+                    if rv:
+                        return rv
+
+        return result
 
     def can_insert_in_graph(self):
         return self.fn in self._fx_graph_functions()
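
A standalone sketch of the new matching logic above: every `(type1, type2)` entry that passes the subclass checks is collected, and multiple matches are chained so a handler can decline by returning a falsy value:

    def combine(matches):
        if not matches:
            return None
        if len(matches) == 1:
            return matches[0]

        def result(*args):
            # fall through to the next handler when one returns None/falsy
            for fn in matches:
                rv = fn(*args)
                if rv:
                    return rv

        return result

    chained = combine([lambda x: None, lambda x: x * 2])
    print(chained(5))  # 10: the first handler declines, the second applies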
@@ -666,7 +838,7 @@ class BuiltinVariable(VariableTracker):
         # Try to find a handler for the arg types; otherwise, fall through to constant handler
         binop_handler = BuiltinVariable._find_binop_handler(fn, args[0], args[1])
         if binop_handler:
-            res = binop_handler(tx, args[0], args[1], {})
+            res = binop_handler(tx, args[0], args[1])
             if res is not None:
                 return res
 
@@ -1511,150 +1683,63 @@ class BuiltinVariable(VariableTracker):
     def call_deepcopy(self, tx, x):
         unimplemented(f"copy.deepcopy {repr(x)}")
 
-    def _comparison(self, tx, left, right):
-        """
-        Used to implement comparison operators for different types.
-        For example, list1 < list2 is implemented differently from tensor1 < tensor2
-        """
-        from . import (
-            BaseListVariable,
-            ConstantVariable,
-            NNModuleVariable,
-            TensorVariable,
-            UserDefinedObjectVariable,
-            UserFunctionVariable,
-        )
-        from .lists import SizeVariable
-        from .tensor import (
-            supported_const_comparison_op_values,
-            supported_tensor_comparison_op_values,
-        )
-
-        op = self.fn
-
-        def _unimplemented():
-            unimplemented(f"comparison {typestr(left)} {op} {typestr(right)}")
-
-        if (
-            all(
-                isinstance(x, (NNModuleVariable, ConstantVariable))
-                for x in [left, right]
-            )
-            and op in supported_const_comparison_op_values
-        ):
-            left = (
-                tx.output.get_submodule(left.module_key)
-                if isinstance(left, NNModuleVariable)
-                else left.as_python_constant()
-            )
-            right = (
-                tx.output.get_submodule(right.module_key)
-                if isinstance(right, NNModuleVariable)
-                else right.as_python_constant()
-            )
-            return ConstantVariable.create(op(left, right))
-
-        if isinstance(left, UserFunctionVariable):
-            if op not in supported_const_comparison_op_values:
-                _unimplemented()
-            if not isinstance(right, UserFunctionVariable):
-                _unimplemented()
-            return ConstantVariable.create(op(left.fn, right.fn))
-
-        # Note, we have a rare BaseListVariable subtype mismatch with valid comparison
-        # x = torch.randn([3, 3])
-        # x.size() == (3, 3) # True
-        # (3, 3) == x.size() # True
-        if isinstance(left, (SizeVariable, TupleVariable)) and isinstance(
-            right, (TupleVariable, SizeVariable)
-        ):
-            return BaseListVariable.list_compare(tx, op, left, right)
-
-        if isinstance(left, BaseListVariable):
-            if not type(left) == type(right):  # Mismatch in BaseListVariable subclasses
-                _unimplemented()
-            return BaseListVariable.list_compare(tx, op, left, right)
-
-        # If they implement set semantics (e.g. SetVariable or DictKeys)
-        if hasattr(left, "set_items") and hasattr(right, "set_items"):
-            return ConstantVariable.create(op(left.set_items, right.set_items))
-
-        if isinstance(left, TensorVariable) or isinstance(right, TensorVariable):
-            from .builder import wrap_fx_proxy_cls
-
-            if op in [operator.is_, operator.is_not]:
-                is_result = (
-                    isinstance(left, TensorVariable)
-                    and isinstance(right, TensorVariable)
-                    and id(extract_fake_example_value(left.as_proxy().node))
-                    == id(extract_fake_example_value(right.as_proxy().node))
-                )
-                if op is operator.is_:
-                    return ConstantVariable.create(is_result)
-                else:
-                    return ConstantVariable.create(not is_result)
-
-            if op not in supported_tensor_comparison_op_values:
-                _unimplemented()
-            if (
-                isinstance(left, TensorVariable)
-                and isinstance(right, TensorVariable)
-                and (left.size and right.size) is not None
-                and left.size != right.size
-            ):
-                try:
-                    torch.broadcast_shapes(left.size, right.size)
-                except RuntimeError:
-                    # not broadcastable, can't be compared
-                    _unimplemented()
-            tensor_cls = left if isinstance(left, TensorVariable) else right
-            proxy = tx.output.create_proxy(
-                "call_function", op, (left.as_proxy(), right.as_proxy()), {}
-            )
-            return wrap_fx_proxy_cls(
-                type(tensor_cls),  # handle Ndarrays and Tensors
-                tx,
-                proxy,
-            )
-
-        if isinstance(left, SymNodeVariable) or isinstance(right, SymNodeVariable):
-            if op not in supported_tensor_comparison_op_values:
-                _unimplemented()
-
-            proxy = tx.output.create_proxy(
-                "call_function", op, (left.as_proxy(), right.as_proxy()), {}
-            )
-            return SymNodeVariable.create(
-                tx,
-                proxy,
-                sym_num=None,
-            )
-
-        if isinstance(left, UserDefinedObjectVariable) and isinstance(
-            right, UserDefinedObjectVariable
-        ):
-            return ConstantVariable.create(op(left.value, right.value))
-
-        if isinstance(left, (StreamVariable, EventVariable)) or isinstance(
-            right, (StreamVariable, EventVariable)
-        ):
-            if type(left) == type(right) and op is operator.eq:
-                return ConstantVariable(op(left.value, right.value))
-
-            if isinstance(right, ConstantVariable) or isinstance(
-                left, ConstantVariable
-            ):
-                return ConstantVariable(op(left.value, right.value))
-
-        if op.__name__.startswith("is_"):
-            # If the two objects are of different type, we can safely return False and True for `is` and `is not`, respectively
-            if type(left) is not type(right):
-                return ConstantVariable.create(op.__name__ != "is_")
-
-        if isinstance(left, BuiltinVariable) and isinstance(right, BuiltinVariable):
-            return ConstantVariable.create(op(left.fn, right.fn))
-
-        _unimplemented()
+    def _comparison_with_tensor(self, tx, left, right):
+        from .builder import wrap_fx_proxy_cls
+        from .tensor import supported_tensor_comparison_op_values
+
+        op = self.fn
+
+        if op in [operator.is_, operator.is_not]:
+            is_result = (
+                isinstance(left, TensorVariable)
+                and isinstance(right, TensorVariable)
+                and id(extract_fake_example_value(left.as_proxy().node))
+                == id(extract_fake_example_value(right.as_proxy().node))
+            )
+            if op is operator.is_:
+                return ConstantVariable.create(is_result)
+            else:
+                return ConstantVariable.create(not is_result)
+
+        if op not in supported_tensor_comparison_op_values:
+            unimplemented(f"{op.__name__}({left}, {right})")
+        if (
+            isinstance(left, TensorVariable)
+            and isinstance(right, TensorVariable)
+            and (left.size and right.size) is not None
+            and left.size != right.size
+        ):
+            try:
+                torch.broadcast_shapes(left.size, right.size)
+            except RuntimeError:
+                # not broadcastable, can't be compared
+                unimplemented(f"{op.__name__}({left}, {right})")
+        tensor_cls = left if isinstance(left, TensorVariable) else right
+        proxy = tx.output.create_proxy(
+            "call_function", op, (left.as_proxy(), right.as_proxy()), {}
+        )
+        return wrap_fx_proxy_cls(
+            type(tensor_cls),  # handle Ndarrays and Tensors
+            tx,
+            proxy,
+        )
+
+    def _comparison_with_symnode(self, tx, left, right):
+        from .tensor import supported_tensor_comparison_op_values
+
+        op = self.fn
+
+        if op not in supported_tensor_comparison_op_values:
+            unimplemented(f"{op.__name__}({left}, {right})")
+
+        proxy = tx.output.create_proxy(
+            "call_function", op, (left.as_proxy(), right.as_proxy()), {}
+        )
+        return SymNodeVariable.create(
+            tx,
+            proxy,
+            sym_num=None,
+        )
 
     def call_and_(self, tx, a, b):
         # Rely on constant_handler
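
The broadcastability pre-check in `_comparison_with_tensor` mirrors plain `torch.broadcast_shapes` behavior; tensors with known but incompatible sizes cannot be compared and trigger `unimplemented`:

    import torch

    print(torch.broadcast_shapes((3, 1), (1, 4)))  # torch.Size([3, 4]): comparable
    try:
        torch.broadcast_shapes((3,), (4,))
    except RuntimeError as e:
        print("not broadcastable:", e)  # graph break in dynamo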
@@ -1711,14 +1796,8 @@ class BuiltinVariable(VariableTracker):
 
                 return None
 
-    call_eq = _comparison
-    call_gt = _comparison
-    call_lt = _comparison
-    call_ge = _comparison
-    call_le = _comparison
-    call_ne = _comparison
-    call_is_ = _comparison
-    call_is_not = _comparison
+    def call_contains(self, tx, a: VariableTracker, b: VariableTracker):
+        return a.call_method(tx, "__contains__", [b], {})
 
     call_all = _polyfill_call_impl("all")
     call_any = _polyfill_call_impl("any")