[dynamo] Refactor COMPARE_OP and comparison builtins (#122043)

This removes the duplicate handling of comparison ops between symbolic_convert and builtin, and refactors the handling to use the binop infrastructure. This change regresses overhead a bit, but that is fixed in the next PR.

The new test skips are variants of `type(e) is np.ndarray` that previously fell back to eager.
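For context, the affected pattern looks roughly like the following (illustrative only; `f` and the exact check are stand-ins, not code taken from the benchmarks):

    import numpy as np
    import torch

    def f(e):
        # A type-identity check like this previously made Dynamo fall back
        # to eager; the updated expected-result CSVs below mark such test
        # variants as skipped instead.
        if type(e) is np.ndarray:
            e = torch.from_numpy(e)
        return e * 2

    out = torch.compile(f)(np.ones(3))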

Pull Request resolved: https://github.com/pytorch/pytorch/pull/122043
Approved by: https://github.com/anijain2305
ghstack dependencies: #122039
Jason Ansel
2024-03-18 15:56:47 -07:00
committed by PyTorch MergeBot
parent 769ff86b91
commit 07caea5c12
13 changed files with 305 additions and 243 deletions

View File

@@ -86,7 +86,7 @@ detectron2_fasterrcnn_r_50_fpn,eager_fail_to_run,0
-detectron2_fcos_r_50_fpn,pass,35
+detectron2_fcos_r_50_fpn,pass,94


View File

@@ -54,47 +54,47 @@ densenet121,pass,0
-detectron2_fasterrcnn_r_101_c4,pass,51
+detectron2_fasterrcnn_r_101_c4,pass,164
-detectron2_fasterrcnn_r_101_dc5,pass,51
+detectron2_fasterrcnn_r_101_dc5,pass,163
-detectron2_fasterrcnn_r_101_fpn,pass,55
+detectron2_fasterrcnn_r_101_fpn,pass,172
-detectron2_fasterrcnn_r_50_c4,pass,51
+detectron2_fasterrcnn_r_50_c4,pass,113
-detectron2_fasterrcnn_r_50_dc5,pass,51
+detectron2_fasterrcnn_r_50_dc5,pass,112
-detectron2_fasterrcnn_r_50_fpn,pass,55
+detectron2_fasterrcnn_r_50_fpn,pass,121
-detectron2_fcos_r_50_fpn,pass,38
+detectron2_fcos_r_50_fpn,pass,97
-detectron2_maskrcnn_r_101_c4,fail_accuracy,66
+detectron2_maskrcnn_r_101_c4,pass,182
-detectron2_maskrcnn_r_101_fpn,pass,73
+detectron2_maskrcnn_r_101_fpn,pass,192
-detectron2_maskrcnn_r_50_c4,pass,66
+detectron2_maskrcnn_r_50_c4,pass,131
-detectron2_maskrcnn_r_50_fpn,pass,73
+detectron2_maskrcnn_r_50_fpn,pass,141


View File

@@ -86,7 +86,7 @@ detectron2_fasterrcnn_r_50_fpn,eager_fail_to_run,0
-detectron2_fcos_r_50_fpn,pass,35
+detectron2_fcos_r_50_fpn,pass,94


View File

@@ -54,7 +54,7 @@ densenet121,pass,0
-detectron2_fcos_r_50_fpn,pass,38
+detectron2_fcos_r_50_fpn,pass,97


View File

@@ -86,7 +86,7 @@ detectron2_fasterrcnn_r_50_fpn,eager_fail_to_run,0
-detectron2_fcos_r_50_fpn,pass,36
+detectron2_fcos_r_50_fpn,pass,95


View File

@@ -86,7 +86,7 @@ detectron2_fasterrcnn_r_50_fpn,eager_fail_to_run,0
-detectron2_fcos_r_50_fpn,pass,35
+detectron2_fcos_r_50_fpn,pass,94


View File

@@ -86,7 +86,7 @@ detectron2_fasterrcnn_r_50_fpn,eager_fail_to_run,0
-detectron2_fcos_r_50_fpn,pass,36
+detectron2_fcos_r_50_fpn,pass,95


View File

@@ -1459,6 +1459,26 @@ class FunctionTests(torch._dynamo.test_case.TestCase):
         par_mul = functools.partial(udf_mul, torch.ones(10, 10))
         return par_mul(x)
 
+    @make_test
+    def test_list_add_then_mutate(x):
+        my_list = [1, x]
+        y = x / 4.0
+        my_list = my_list + [x / 2.0, 4]
+        my_list.append(y)
+        return sum(my_list)
+
+    @make_test
+    def test_list_expand_lhs(x):
+        return sum(4 * [x])
+
+    @make_test
+    def test_in_not_in(x):
+        mylist = [1, 2, 3, 4, 5, x]
+        myotherlist = [1, 2, 3, 4, 5]
+        assert 3 in mylist
+        assert 6 not in myotherlist
+        return sum(mylist)
+
     @make_test
     def test_partials_udf_kwarg(x):
         par_mul = functools.partial(udf_mul, y=torch.ones(10, 10))
View File

@@ -99,17 +99,11 @@ from .variables.misc import (
     UnknownVariable,
 )
 from .variables.nn_module import NNModuleVariable
-from .variables.tensor import (
-    supported_comparison_ops,
-    supported_const_comparison_ops,
-    SymNodeVariable,
-    TensorVariable,
-)
+from .variables.tensor import supported_comparison_ops, SymNodeVariable, TensorVariable
 from .variables.user_defined import (
     RemovableHandleVariable,
     UserDefinedClassVariable,
     UserDefinedObjectVariable,
-    UserDefinedVariable,
 )
 
 log = logging.getLogger(__name__)
@@ -117,6 +111,17 @@ graph_break_log = torch._logging.getArtifactLogger(__name__, "graph_breaks")
 trace_call_log = torch._logging.getArtifactLogger(__name__, "trace_call")
 trace_source_log = torch._logging.getArtifactLogger(__name__, "trace_source")
 
 tls = threading.local()
 
+compare_op_handlers: Dict[str, Any] = {
+    k: BuiltinVariable(v).call_function for k, v in supported_comparison_ops.items()
+}
+handle_contains = BuiltinVariable(operator.contains).call_function
+handle_not = BuiltinVariable(operator.not_).call_function
+compare_op_handlers["in"] = lambda tx, args, _: handle_contains(
+    tx, [*reversed(args)], {}
+)
+compare_op_handlers["not in"] = lambda tx, args, _: handle_not(
+    tx, [handle_contains(tx, [*reversed(args)], {})], {}
+)
 
 
 @dataclasses.dataclass
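The argument reversal in the two lambdas above exists because comparison operands arrive in stack order `[left, right]` = `[item, container]`, while `operator.contains` expects the container first. A minimal sketch of the same table-plus-reversal pattern in plain Python (without the `tx`/VariableTracker plumbing):

    import operator

    compare_op_handlers = {
        "==": operator.eq,
        "<": operator.lt,
        # reverse (item, container) into contains(container, item)
        "in": lambda *args: operator.contains(*reversed(args)),
        "not in": lambda *args: operator.not_(operator.contains(*reversed(args))),
    }

    left, right = 3, [1, 2, 3]          # stack order: item, container
    assert compare_op_handlers["in"](left, right) is True
    assert compare_op_handlers["not in"](left, right) is False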
@@ -1200,49 +1205,7 @@ class InstructionTranslatorBase(
             unimplemented(f"FOR_ITER {typestr(it)}")
 
     def COMPARE_OP(self, inst):
-        left, right = self.popn(2)
-        op = inst.argval
-        if op == "in" or op == "not in":
-            self.push(right.call_method(self, "__contains__", [left], {}))
-            if op == "not in":
-                self.UNARY_NOT(inst)
-            return
-        if right.is_python_constant():
-            if left.is_python_constant():
-                # constant fold
-                return self.push(
-                    ConstantVariable(
-                        supported_comparison_ops[op](
-                            left.as_python_constant(), right.as_python_constant()
-                        ),
-                    )
-                )
-            elif (
-                op in supported_const_comparison_ops
-                and right.as_python_constant() is None
-                and isinstance(
-                    left,
-                    (
-                        TensorVariable,
-                        SymNodeVariable,
-                        NNModuleVariable,
-                        BaseListVariable,
-                        UserDefinedVariable,
-                        BaseUserFunctionVariable,
-                        ConstDictVariable,
-                    ),
-                )
-            ):
-                # <non-None> is None
-                return self.push(
-                    ConstantVariable(supported_const_comparison_ops[op](object(), None))
-                )
-        self.push(
-            BuiltinVariable(supported_comparison_ops[op]).call_function(
-                self, [left, right], {}
-            )
-        )
+        self.push(compare_op_handlers[inst.argval](self, self.popn(2), {}))
 
     def GET_ITER(self, inst):
         self.call_function(BuiltinVariable(iter), [self.pop()], {})

View File

@@ -38,7 +38,7 @@ from ..utils import (
     proxy_args_kwargs,
     tensortype_to_dtype,
 )
-from .base import MutableLocal, typestr, VariableTracker
+from .base import MutableLocal, VariableTracker
 from .constant import ConstantVariable
 from .ctx_manager import EventVariable, StreamVariable
 from .dicts import (
@@ -58,6 +58,7 @@ from .lists import (
 )
 from .tensor import (
     FakeItemVariable,
+    supported_comparison_ops,
     SymNodeVariable,
     TensorVariable,
     UnspecializedPythonVariable,
@@ -167,6 +168,9 @@ class BuiltinVariable(VariableTracker):
             operator.ior,
             operator.index,
         }
+        from .tensor import supported_comparison_ops
+
+        fns.update(supported_comparison_ops.values())
         fns.update(x for x in math.__dict__.values() if isinstance(x, type(math.sqrt)))
 
         return fns
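The `isinstance(x, type(math.sqrt))` filter above collects every C-level function in `math` without hardcoding a list, since `type(math.sqrt)` is `builtin_function_or_method`:

    import math

    # type(math.sqrt) is <class 'builtin_function_or_method'>
    builtin_fn_type = type(math.sqrt)
    math_fns = {x for x in math.__dict__.values() if isinstance(x, builtin_fn_type)}
    assert math.cos in math_fns and math.pi not in math_fns  # constants filtered out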
@@ -262,7 +266,17 @@ class BuiltinVariable(VariableTracker):
         # Multiple dispatch mechanism defining custom binop behavior for certain type
         # combinations. Handlers are attempted in order, and will be used if the type checks
         # match. They are expected to have the signature:
-        # fn(tx, arg0: VariableTracker, arg1: VariableTracker, options) -> VariableTracker
+        # fn(tx, arg0: VariableTracker, arg1: VariableTracker) -> VariableTracker
+        from .dicts import DictKeys, SetVariable
+        from .functions import BaseUserFunctionVariable, UserFunctionVariable
+        from .nn_module import NNModuleVariable
+        from .tensor import supported_const_comparison_ops
+        from .torch import BaseTorchVariable
+        from .user_defined import (
+            UserDefinedClassVariable,
+            UserDefinedObjectVariable,
+            UserDefinedVariable,
+        )
 
         # Override table contains: op_fn -> [list of handlers]
         op_handlers = {}
@@ -280,7 +294,7 @@ class BuiltinVariable(VariableTracker):
                 tx,
                 a,
                 b,
-                options,
+                *,
                 forward_name=forward_name,
                 reverse_name=reverse_name,
             ):
@@ -310,9 +324,7 @@ class BuiltinVariable(VariableTracker):
                 ((VariableTracker, UserDefinedVariable), user_defined_handler)
             )
 
-            def user_defined_inplace_handler(
-                tx, a, b, options, forward_name=inplace_name
-            ):
+            def user_defined_inplace_handler(tx, a, b, *, forward_name=inplace_name):
                 return a.call_method(tx, forward_name, [b], {})
 
             op_handlers[in_place_op].append(
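Replacing the positional `options` dict with keyword-only parameters (everything after `*`) keeps every handler callable uniformly as `fn(tx, a, b)`, while defaults like `forward_name=forward_name` freeze each loop iteration's value at `def` time, sidestepping Python's late-binding closure pitfall. A standalone sketch of the idiom:

    handlers = []
    for name in ("__add__", "__sub__", "__mul__"):
        # Without forward_name=name, every closure would see the loop's
        # final value of `name`; the default freezes it per iteration.
        def handler(a, b, *, forward_name=name):
            return getattr(a, forward_name)(b)

        handlers.append(handler)

    assert [h(7, 2) for h in handlers] == [9, 5, 14]
    # `*` makes forward_name keyword-only, so positional callers can
    # never clobber it by accident:
    assert handlers[0](7, 2) == 9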
@@ -323,7 +335,7 @@ class BuiltinVariable(VariableTracker):
             )
 
             # Dynamic shape args
-            def dynamic_handler(tx, a, b, options, fn=op):
+            def dynamic_handler(tx, a, b, *, fn=op):
                 from .builder import wrap_fx_proxy
 
                 return wrap_fx_proxy(
@@ -331,7 +343,6 @@ class BuiltinVariable(VariableTracker):
                     tx.output.create_proxy(
                         "call_function", fn, *proxy_args_kwargs([a, b], {})
                     ),
-                    **options,
                 )
 
             op_handlers[op].append(
@@ -352,11 +363,11 @@ class BuiltinVariable(VariableTracker):
        # Special cases - lower precedence but still prefer these over constant folding
 
        # List-like addition (e.g. [1, 2] + [3, 4])
-        def tuple_add_handler(tx, a, b, options):
-            return TupleVariable(a.items + list(b.unpack_var_sequence(tx)), **options)
+        def tuple_add_handler(tx, a, b):
+            return TupleVariable([*a.items, *b.unpack_var_sequence(tx)])
 
-        def size_add_handler(tx, a, b, options):
-            return SizeVariable(a.items + list(b.unpack_var_sequence(tx)), **options)
+        def size_add_handler(tx, a, b):
+            return SizeVariable([*a.items, *b.unpack_var_sequence(tx)])
 
         list_like_addition_handlers = [
             # NB: Prefer the tuple-specific logic over base logic because of
@@ -376,18 +387,27 @@ class BuiltinVariable(VariableTracker):
             ),
             (
                 (ConstantVariable, TupleVariable),
-                lambda tx, a, b, options: TupleVariable(
-                    list(a.unpack_var_sequence(tx)) + b.items, **options
+                lambda tx, a, b: TupleVariable(
+                    [*a.unpack_var_sequence(tx), *b.items],
                 ),
             ),
+            (
+                (
+                    ListVariable,
+                    (BaseListVariable, ConstantVariable, ListIteratorVariable),
+                ),
+                lambda tx, a, b: ListVariable(
+                    [*a.items, *b.unpack_var_sequence(tx)], mutable_local=MutableLocal()
+                ),
+            ),
             (
                 (BaseListVariable, BaseListVariable),
-                lambda tx, a, b, options: type(a)(a.items + b.items, **options),
+                lambda tx, a, b: type(a)([*a.items, *b.items]),
             ),
         ]
         op_handlers[operator.add].extend(list_like_addition_handlers)
 
-        def list_iadd_handler(tx, a, b, _):
+        def list_iadd_handler(tx, a, b):
             if not a.mutable_local or not b.has_unpack_var_sequence(tx):
                 # Handler doesn't apply
                 return None
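Returning `None` from `list_iadd_handler` is the handlers' convention for "does not apply, try the next candidate"; `_find_binop_handler` further down and the binop call site both treat `None` as a miss. A small chain-of-responsibility sketch of the convention:

    def inplace_concat(a, b):
        if not isinstance(a, list):
            return None          # handler doesn't apply; fall through
        a.extend(b)
        return a

    def fallback(a, b):
        return a + b

    def dispatch(a, b):
        for handler in (inplace_concat, fallback):
            result = handler(a, b)
            if result is not None:
                return result

    xs = [1, 2]
    assert dispatch(xs, [3]) is xs               # first handler hit, mutated in place
    assert dispatch((1, 2), (3,)) == (1, 2, 3)   # fell through to the fallback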
@@ -414,30 +434,169 @@ class BuiltinVariable(VariableTracker):
         op_handlers[operator.iadd].extend(list_like_iadd_handlers)
 
         # List-like expansion (e.g. [1, 2, 3] * 3)
-        def expand_list_like(tx, lst, const, options):
+        def expand_list_like(tx, lst, const):
+            if isinstance(lst, ConstantVariable):
+                lst, const = const, lst
             return lst.__class__(
                 items=lst.items * const.as_python_constant(),
                 mutable_local=MutableLocal(),
-                **options,
             )
 
         list_like_expansion_handlers = [
             ((ListVariable, ConstantVariable), expand_list_like),
             ((TupleVariable, ConstantVariable), expand_list_like),
-            (
-                (ConstantVariable, ListVariable),
-                lambda tx, a, b, options: expand_list_like(tx, b, a, options),
-            ),
-            (
-                (ConstantVariable, TupleVariable),
-                lambda tx, a, b, options: expand_list_like(tx, b, a, options),
-            ),
+            ((ConstantVariable, ListVariable), expand_list_like),
+            ((ConstantVariable, TupleVariable), expand_list_like),
         ]
         op_handlers[operator.mul].extend(list_like_expansion_handlers)
 
-        for key in op_handlers.keys():
-            # insert a mutable cache into each entry
-            op_handlers[key] = (op_handlers[key], dict())
+        size_or_tuple = (SizeVariable, TupleVariable)
+        has_set_items = (SetVariable, DictKeys)
+
+        def create_cmp_op_handlers(op):
+            def compare_by_value(tx, a, b):
+                return ConstantVariable(op(a.value, b.value))
+
+            result = [((ConstantVariable, ConstantVariable), compare_by_value)]
+
+            if op in supported_const_comparison_ops.values():
+                # Tensor is None, List is not None, etc
+                none_result = op(object(), None)
+                if op.__name__.startswith("is_"):
+
+                    def never(tx, a, b):
+                        return ConstantVariable(none_result)
+
+                    obj_op_none = never
+                    none_op_obj = never
+                else:
+
+                    def obj_op_none(tx, a, b: ConstantVariable):
+                        if b.value is None or b.value is True or b.value is False:
+                            return ConstantVariable(none_result)
+
+                    def none_op_obj(tx, a: ConstantVariable, b):
+                        if a.value is None or a.value is True or a.value is False:
+                            return ConstantVariable(none_result)
+
+                types_that_are_never_none = (
+                    TensorVariable,
+                    SymNodeVariable,
+                    NNModuleVariable,
+                    BaseListVariable,
+                    UserDefinedVariable,
+                    BaseUserFunctionVariable,
+                    ConstDictVariable,
+                    BaseTorchVariable,
+                )
+                result.extend(
+                    [
+                        (
+                            (types_that_are_never_none, ConstantVariable),
+                            obj_op_none,
+                        ),
+                        (
+                            (ConstantVariable, types_that_are_never_none),
+                            none_op_obj,
+                        ),
+                    ]
+                )
+
+            def list_compare_nocheck(tx, left, right):
+                return BaseListVariable.list_compare(tx, op, left, right)
+
+            def list_compare_check(tx, left, right):
+                if type(left) is not type(
+                    right
+                ):  # Mismatch in BaseListVariable subclasses
+                    unimplemented(f"{op.__name__}({left}, {right})")
+                return BaseListVariable.list_compare(tx, op, left, right)
+
+            def compare_set_items(tx, left, right):
+                return ConstantVariable(op(left.set_items, right.set_items))
+
+            op_var = BuiltinVariable(op)
+            result.extend(
+                [
+                    (
+                        (
+                            (UserFunctionVariable, BuiltinVariable),
+                            (UserFunctionVariable, BuiltinVariable),
+                        ),
+                        lambda tx, a, b: ConstantVariable(op(a.fn, b.fn)),
+                    ),
+                    (
+                        (
+                            NNModuleVariable,
+                            NNModuleVariable,
+                        ),
+                        lambda tx, a, b: ConstantVariable(
+                            op(
+                                tx.output.get_submodule(a.module_key),
+                                tx.output.get_submodule(b.module_key),
+                            )
+                        ),
+                    ),
+                    ((size_or_tuple, size_or_tuple), list_compare_nocheck),
+                    (
+                        (variables.BaseListVariable, variables.BaseListVariable),
+                        list_compare_check,
+                    ),
+                    ((has_set_items, has_set_items), compare_set_items),
+                    # TODO(jansel): UserDefinedObjectVariable is wrong and could invoke user code
+                    (
+                        (UserDefinedObjectVariable, UserDefinedObjectVariable),
+                        compare_by_value,
+                    ),
+                    (
+                        (UserDefinedClassVariable, UserDefinedClassVariable),
+                        compare_by_value,
+                    ),
+                    (
+                        (
+                            (StreamVariable, EventVariable, ConstantVariable),
+                            (StreamVariable, EventVariable, ConstantVariable),
+                        ),
+                        compare_by_value,
+                    ),
+                    (
+                        (TensorVariable, VariableTracker),
+                        op_var._comparison_with_tensor,
+                    ),
+                    (
+                        (VariableTracker, TensorVariable),
+                        op_var._comparison_with_tensor,
+                    ),
+                    (
+                        (SymNodeVariable, VariableTracker),
+                        op_var._comparison_with_symnode,
+                    ),
+                    (
+                        (VariableTracker, SymNodeVariable),
+                        op_var._comparison_with_symnode,
+                    ),
+                ]
+            )
+
+            if op.__name__.startswith("is_"):
+
+                def handle_is(tx, left, right):
+                    # If the two objects are of different type, we can safely return False
+                    # and True for `is` and `is not`, respectively
+                    if type(left) is not type(right):
+                        return ConstantVariable.create(op.__name__ != "is_")
+
+                result.append(((VariableTracker, VariableTracker), handle_is))
+
+            return result
+
+        for op in supported_comparison_ops.values():
+            assert callable(op)
+            assert op not in op_handlers
+            op_handlers[op] = create_cmp_op_handlers(op)
+
+        for op in op_handlers.keys():
+            op_handlers[op] = (op_handlers[op], dict())
 
         return op_handlers
 
     @staticmethod
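A plain-Python model of the constant folds these comparison handlers perform (same semantics, none of the VariableTracker machinery):

    import operator

    # none_result = op(object(), None): a fresh object() stands in for any
    # value that can never be None, so the answer depends only on the op.
    assert operator.is_(object(), None) is False
    assert operator.is_not(object(), None) is True

    # handle_is: operands of different types can be decided immediately,
    # False for `is`, True for `is not`.
    left, right = [1, 2], (1, 2)
    assert (type(left) is not type(right)) and (left is not right)

    # compare_set_items: dict keys / sets compare by set semantics.
    assert {1, 2} == {2, 1}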
@@ -452,13 +611,26 @@ class BuiltinVariable(VariableTracker):
         if hit is not False:
             return hit
 
         # Return first handler that matches the type checks
+        a_type = type(a)
+        b_type = type(b)
+        matches = []
         for (type1, type2), handler in handlers:
-            if isinstance(a, type1) and isinstance(b, type2):
-                cache[cache_key] = handler
-                return handler
+            if issubclass(a_type, type1) and issubclass(b_type, type2):
+                matches.append(handler)
 
-        cache[cache_key] = None
-        return None
+        if not matches:
+            result = None
+        elif len(matches) == 1:
+            result = matches[0]
+        else:
+
+            def result(*args):
+                for fn in matches:
+                    rv = fn(*args)
+                    if rv:
+                        return rv
+
+        cache[cache_key] = result
+        return result
 
     def can_insert_in_graph(self):
         return self.fn in self._fx_graph_functions()
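A self-contained sketch of the lookup strategy above: handlers are registered against type pairs, multiple matches get chained first-truthy-result-wins, and the outcome (even a miss) is memoized per concrete type pair:

    handlers = [
        ((int, int), lambda a, b: a + b),
        ((object, object), lambda a, b: None),   # applies, but may decline
    ]
    cache = {}

    def find_handler(a, b):
        key = (type(a), type(b))
        if key in cache:
            return cache[key]
        matches = [h for (t1, t2), h in handlers
                   if issubclass(type(a), t1) and issubclass(type(b), t2)]
        if not matches:
            result = None
        elif len(matches) == 1:
            result = matches[0]
        else:
            def result(*args):                   # try in order, first hit wins
                for fn in matches:
                    rv = fn(*args)
                    if rv is not None:
                        return rv
        cache[key] = result                      # misses are cached too
        return result

    assert find_handler(1, 2)(1, 2) == 3         # chained: int handler wins
    assert find_handler("a", "b")("a", "b") is None
    assert (int, int) in cache and (str, str) in cache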
@@ -666,7 +838,7 @@ class BuiltinVariable(VariableTracker):
             # Try to find a handler for the arg types; otherwise, fall through to constant handler
             binop_handler = BuiltinVariable._find_binop_handler(fn, args[0], args[1])
             if binop_handler:
-                res = binop_handler(tx, args[0], args[1], {})
+                res = binop_handler(tx, args[0], args[1])
                 if res is not None:
                     return res
@@ -1511,150 +1683,63 @@ class BuiltinVariable(VariableTracker):
     def call_deepcopy(self, tx, x):
         unimplemented(f"copy.deepcopy {repr(x)}")
 
-    def _comparison(self, tx, left, right):
-        """
-        Used to implement comparison operators for different types.
-        For example, list1 < list2 is implemented differently from tensor1 < tensor2
-        """
-        from . import (
-            BaseListVariable,
-            ConstantVariable,
-            NNModuleVariable,
-            TensorVariable,
-            UserDefinedObjectVariable,
-            UserFunctionVariable,
-        )
-        from .lists import SizeVariable
-        from .tensor import (
-            supported_const_comparison_op_values,
-            supported_tensor_comparison_op_values,
-        )
-
-        op = self.fn
-
-        def _unimplemented():
-            unimplemented(f"comparison {typestr(left)} {op} {typestr(right)}")
-
-        if (
-            all(
-                isinstance(x, (NNModuleVariable, ConstantVariable))
-                for x in [left, right]
-            )
-            and op in supported_const_comparison_op_values
-        ):
-            left = (
-                tx.output.get_submodule(left.module_key)
-                if isinstance(left, NNModuleVariable)
-                else left.as_python_constant()
-            )
-            right = (
-                tx.output.get_submodule(right.module_key)
-                if isinstance(right, NNModuleVariable)
-                else right.as_python_constant()
-            )
-            return ConstantVariable.create(op(left, right))
-
-        if isinstance(left, UserFunctionVariable):
-            if op not in supported_const_comparison_op_values:
-                _unimplemented()
-            if not isinstance(right, UserFunctionVariable):
-                _unimplemented()
-            return ConstantVariable.create(op(left.fn, right.fn))
-
-        # Note, we have a rare BaseListVariable subtype mismatch with valid comparison
-        # x = torch.randn([3, 3])
-        # x.size() == (3, 3) # True
-        # (3, 3) == x.size() # True
-        if isinstance(left, (SizeVariable, TupleVariable)) and isinstance(
-            right, (TupleVariable, SizeVariable)
-        ):
-            return BaseListVariable.list_compare(tx, op, left, right)
-
-        if isinstance(left, BaseListVariable):
-            if not type(left) == type(right):  # Mismatch in BaseListVariable subclasses
-                _unimplemented()
-            return BaseListVariable.list_compare(tx, op, left, right)
-
-        # If they implement set semantics (e.g. SetVariable or DictKeys)
-        if hasattr(left, "set_items") and hasattr(right, "set_items"):
-            return ConstantVariable.create(op(left.set_items, right.set_items))
-
-        if isinstance(left, TensorVariable) or isinstance(right, TensorVariable):
-            from .builder import wrap_fx_proxy_cls
-
-            if op in [operator.is_, operator.is_not]:
-                is_result = (
-                    isinstance(left, TensorVariable)
-                    and isinstance(right, TensorVariable)
-                    and id(extract_fake_example_value(left.as_proxy().node))
-                    == id(extract_fake_example_value(right.as_proxy().node))
-                )
-                if op is operator.is_:
-                    return ConstantVariable.create(is_result)
-                else:
-                    return ConstantVariable.create(not is_result)
-
-            if op not in supported_tensor_comparison_op_values:
-                _unimplemented()
-            if (
-                isinstance(left, TensorVariable)
-                and isinstance(right, TensorVariable)
-                and (left.size and right.size) is not None
-                and left.size != right.size
-            ):
-                try:
-                    torch.broadcast_shapes(left.size, right.size)
-                except RuntimeError:
-                    # not broadcastable, can't be compared
-                    _unimplemented()
-            tensor_cls = left if isinstance(left, TensorVariable) else right
-            proxy = tx.output.create_proxy(
-                "call_function", op, (left.as_proxy(), right.as_proxy()), {}
-            )
-            return wrap_fx_proxy_cls(
-                type(tensor_cls),  # handle Ndarrays and Tensors
-                tx,
-                proxy,
-            )
-
-        if isinstance(left, SymNodeVariable) or isinstance(right, SymNodeVariable):
-            if op not in supported_tensor_comparison_op_values:
-                _unimplemented()
-
-            proxy = tx.output.create_proxy(
-                "call_function", op, (left.as_proxy(), right.as_proxy()), {}
-            )
-            return SymNodeVariable.create(
-                tx,
-                proxy,
-                sym_num=None,
-            )
-
-        if isinstance(left, UserDefinedObjectVariable) and isinstance(
-            right, UserDefinedObjectVariable
-        ):
-            return ConstantVariable.create(op(left.value, right.value))
-
-        if isinstance(left, (StreamVariable, EventVariable)) or isinstance(
-            right, (StreamVariable, EventVariable)
-        ):
-            if type(left) == type(right) and op is operator.eq:
-                return ConstantVariable(op(left.value, right.value))
-
-            if isinstance(right, ConstantVariable) or isinstance(
-                left, ConstantVariable
-            ):
-                return ConstantVariable(op(left.value, right.value))
-
-        if op.__name__.startswith("is_"):
-            # If the two objects are of different type, we can safely return False and True for `is` and `is not`, respectively
-            if type(left) is not type(right):
-                return ConstantVariable.create(op.__name__ != "is_")
-
-        if isinstance(left, BuiltinVariable) and isinstance(right, BuiltinVariable):
-            return ConstantVariable.create(op(left.fn, right.fn))
-
-        _unimplemented()
+    def _comparison_with_tensor(self, tx, left, right):
+        from .builder import wrap_fx_proxy_cls
+        from .tensor import supported_tensor_comparison_op_values
+
+        op = self.fn
+
+        if op in [operator.is_, operator.is_not]:
+            is_result = (
+                isinstance(left, TensorVariable)
+                and isinstance(right, TensorVariable)
+                and id(extract_fake_example_value(left.as_proxy().node))
+                == id(extract_fake_example_value(right.as_proxy().node))
+            )
+            if op is operator.is_:
+                return ConstantVariable.create(is_result)
+            else:
+                return ConstantVariable.create(not is_result)
+
+        if op not in supported_tensor_comparison_op_values:
+            unimplemented(f"{op.__name__}({left}, {right})")
+        if (
+            isinstance(left, TensorVariable)
+            and isinstance(right, TensorVariable)
+            and (left.size and right.size) is not None
+            and left.size != right.size
+        ):
+            try:
+                torch.broadcast_shapes(left.size, right.size)
+            except RuntimeError:
+                # not broadcastable, can't be compared
+                unimplemented(f"{op.__name__}({left}, {right})")
+        tensor_cls = left if isinstance(left, TensorVariable) else right
+        proxy = tx.output.create_proxy(
+            "call_function", op, (left.as_proxy(), right.as_proxy()), {}
+        )
+        return wrap_fx_proxy_cls(
+            type(tensor_cls),  # handle Ndarrays and Tensors
+            tx,
+            proxy,
+        )
+
+    def _comparison_with_symnode(self, tx, left, right):
+        from .tensor import supported_tensor_comparison_op_values
+
+        op = self.fn
+
+        if op not in supported_tensor_comparison_op_values:
+            unimplemented(f"{op.__name__}({left}, {right})")
+
+        proxy = tx.output.create_proxy(
+            "call_function", op, (left.as_proxy(), right.as_proxy()), {}
+        )
+        return SymNodeVariable.create(
+            tx,
+            proxy,
+            sym_num=None,
+        )
 
     def call_and_(self, tx, a, b):
         # Rely on constant_handler
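The `torch.broadcast_shapes` probe above exists because an elementwise comparison is only meaningful when the operand shapes broadcast; a non-broadcastable pair raises `RuntimeError`, which the handler converts into `unimplemented(...)`, i.e. a graph break:

    import torch

    # Broadcastable: (3, 1) vs (3, 4) -> result shape (3, 4)
    assert torch.broadcast_shapes((3, 1), (3, 4)) == torch.Size([3, 4])
    assert (torch.zeros(3, 1) == torch.zeros(3, 4)).shape == (3, 4)

    # Not broadcastable: (3,) vs (4,) raises RuntimeError, which the
    # handler above turns into unimplemented(...), i.e. a graph break.
    try:
        torch.broadcast_shapes((3,), (4,))
    except RuntimeError as e:
        print("not broadcastable:", e)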
@@ -1711,14 +1796,8 @@ class BuiltinVariable(VariableTracker):
             return None
 
-    call_eq = _comparison
-    call_gt = _comparison
-    call_lt = _comparison
-    call_ge = _comparison
-    call_le = _comparison
-    call_ne = _comparison
-    call_is_ = _comparison
-    call_is_not = _comparison
+    def call_contains(self, tx, a: VariableTracker, b: VariableTracker):
+        return a.call_method(tx, "__contains__", [b], {})
 
     call_all = _polyfill_call_impl("all")
     call_any = _polyfill_call_impl("any")