[dynamo] Optimize COMPARE_OP (#122039)

Improves `benchmarks/dynamo/microbenchmarks/dynamo_microbenchmarks.py`
from 5.6 to 5.1s.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/122039
Approved by: https://github.com/Skylion007, https://github.com/anijain2305
This commit is contained in:
Jason Ansel
2024-03-18 15:56:46 -07:00
committed by PyTorch MergeBot
parent e1706bba3b
commit 769ff86b91
3 changed files with 56 additions and 57 deletions

View File

@@ -100,8 +100,8 @@ from .variables.misc import (
)
from .variables.nn_module import NNModuleVariable
from .variables.tensor import (
supported_comparison_ops,
supported_const_comparison_ops,
supported_tensor_comparison_ops,
SymNodeVariable,
TensorVariable,
)
@@ -880,8 +880,7 @@ class InstructionTranslatorBase(
return self.stack.pop()
def popn(self, n: int) -> List[VariableTracker]:
assert n >= 0
return list(reversed([self.pop() for _ in range(n)]))
return [*reversed([self.pop() for _ in range(n)])]
def LOAD_FAST(self, inst):
name = inst.argval
@@ -1203,14 +1202,26 @@ class InstructionTranslatorBase(
def COMPARE_OP(self, inst):
left, right = self.popn(2)
op = inst.argval
supported_any = dict(
itertools.chain(
supported_tensor_comparison_ops.items(),
supported_const_comparison_ops.items(),
if op == "in" or op == "not in":
self.push(right.call_method(self, "__contains__", [left], {}))
if op == "not in":
self.UNARY_NOT(inst)
return
if right.is_python_constant():
if left.is_python_constant():
# constant fold
return self.push(
ConstantVariable(
supported_comparison_ops[op](
left.as_python_constant(), right.as_python_constant()
),
)
)
if (
isinstance(
elif (
op in supported_const_comparison_ops
and right.as_python_constant() is None
and isinstance(
left,
(
TensorVariable,
@@ -1222,37 +1233,13 @@ class InstructionTranslatorBase(
ConstDictVariable,
),
)
and isinstance(right, ConstantVariable)
and right.value is None
and op in supported_const_comparison_ops
):
# <non-None> is None
return self.push(
ConstantVariable(supported_const_comparison_ops[op](object(), None))
)
self.push(
ConstantVariable.create(
supported_const_comparison_ops[op](object(), right.value)
)
)
elif (
left.is_python_constant()
and right.is_python_constant()
and op in supported_any
):
# constant fold
self.push(
ConstantVariable.create(
supported_any[op](
left.as_python_constant(), right.as_python_constant()
),
)
)
elif op in ("in", "not in"):
self.push(right.call_method(self, "__contains__", [left], {}))
if op == "not in":
self.UNARY_NOT(inst)
else:
self.push(
BuiltinVariable(supported_any[op]).call_function(
BuiltinVariable(supported_comparison_ops[op]).call_function(
self, [left, right], {}
)
)

View File

@@ -1526,8 +1526,8 @@ class BuiltinVariable(VariableTracker):
)
from .lists import SizeVariable
from .tensor import (
supported_const_comparison_ops,
supported_tensor_comparison_ops,
supported_const_comparison_op_values,
supported_tensor_comparison_op_values,
)
op = self.fn
@@ -1540,7 +1540,7 @@ class BuiltinVariable(VariableTracker):
isinstance(x, (NNModuleVariable, ConstantVariable))
for x in [left, right]
)
and op in supported_const_comparison_ops.values()
and op in supported_const_comparison_op_values
):
left = (
tx.output.get_submodule(left.module_key)
@@ -1555,7 +1555,7 @@ class BuiltinVariable(VariableTracker):
return ConstantVariable.create(op(left, right))
if isinstance(left, UserFunctionVariable):
if op not in supported_const_comparison_ops.values():
if op not in supported_const_comparison_op_values:
_unimplemented()
if not isinstance(right, UserFunctionVariable):
_unimplemented()
@@ -1594,7 +1594,7 @@ class BuiltinVariable(VariableTracker):
else:
return ConstantVariable.create(not is_result)
if op not in supported_tensor_comparison_ops.values():
if op not in supported_tensor_comparison_op_values:
_unimplemented()
if (
isinstance(left, TensorVariable)
@@ -1618,7 +1618,7 @@ class BuiltinVariable(VariableTracker):
)
if isinstance(left, SymNodeVariable) or isinstance(right, SymNodeVariable):
if op not in supported_tensor_comparison_ops.values():
if op not in supported_tensor_comparison_op_values:
_unimplemented()
proxy = tx.output.create_proxy(

View File

@@ -57,6 +57,7 @@ from .base import _is_top_level_scope, VariableTracker
from .constant import ConstantVariable
from .lists import SizeVariable
# Ops that allow tensor <op> tensor
supported_tensor_comparison_ops = {
">": operator.gt,
"<": operator.lt,
@@ -65,12 +66,23 @@ supported_tensor_comparison_ops = {
"==": operator.eq,
"!=": operator.ne,
}
# Ops that allow tensor <op> None
supported_const_comparison_ops = {
"is": operator.is_,
"is not": operator.is_not,
"==": operator.eq,
"!=": operator.ne,
}
supported_comparison_ops = {
**supported_tensor_comparison_ops,
**supported_const_comparison_ops,
}
supported_tensor_comparison_op_values = dict.fromkeys(
supported_tensor_comparison_ops.values()
)
supported_const_comparison_op_values = dict.fromkeys(
supported_const_comparison_ops.values()
)
class TensorVariable(VariableTracker):