mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
[dynamo] Optimize COMPARE_OP (#122039)
Improves `benchmarks/dynamo/microbenchmarks/dynamo_microbenchmarks.py` from 5.6 to 5.1s. Pull Request resolved: https://github.com/pytorch/pytorch/pull/122039 Approved by: https://github.com/Skylion007, https://github.com/anijain2305
This commit is contained in:
committed by
PyTorch MergeBot
parent
e1706bba3b
commit
769ff86b91
@ -100,8 +100,8 @@ from .variables.misc import (
|
||||
)
|
||||
from .variables.nn_module import NNModuleVariable
|
||||
from .variables.tensor import (
|
||||
supported_comparison_ops,
|
||||
supported_const_comparison_ops,
|
||||
supported_tensor_comparison_ops,
|
||||
SymNodeVariable,
|
||||
TensorVariable,
|
||||
)
|
||||
@ -880,8 +880,7 @@ class InstructionTranslatorBase(
|
||||
return self.stack.pop()
|
||||
|
||||
def popn(self, n: int) -> List[VariableTracker]:
|
||||
assert n >= 0
|
||||
return list(reversed([self.pop() for _ in range(n)]))
|
||||
return [*reversed([self.pop() for _ in range(n)])]
|
||||
|
||||
def LOAD_FAST(self, inst):
|
||||
name = inst.argval
|
||||
@ -1203,14 +1202,26 @@ class InstructionTranslatorBase(
|
||||
def COMPARE_OP(self, inst):
|
||||
left, right = self.popn(2)
|
||||
op = inst.argval
|
||||
supported_any = dict(
|
||||
itertools.chain(
|
||||
supported_tensor_comparison_ops.items(),
|
||||
supported_const_comparison_ops.items(),
|
||||
if op == "in" or op == "not in":
|
||||
self.push(right.call_method(self, "__contains__", [left], {}))
|
||||
if op == "not in":
|
||||
self.UNARY_NOT(inst)
|
||||
return
|
||||
|
||||
if right.is_python_constant():
|
||||
if left.is_python_constant():
|
||||
# constant fold
|
||||
return self.push(
|
||||
ConstantVariable(
|
||||
supported_comparison_ops[op](
|
||||
left.as_python_constant(), right.as_python_constant()
|
||||
),
|
||||
)
|
||||
)
|
||||
if (
|
||||
isinstance(
|
||||
elif (
|
||||
op in supported_const_comparison_ops
|
||||
and right.as_python_constant() is None
|
||||
and isinstance(
|
||||
left,
|
||||
(
|
||||
TensorVariable,
|
||||
@ -1222,37 +1233,13 @@ class InstructionTranslatorBase(
|
||||
ConstDictVariable,
|
||||
),
|
||||
)
|
||||
and isinstance(right, ConstantVariable)
|
||||
and right.value is None
|
||||
and op in supported_const_comparison_ops
|
||||
):
|
||||
# <non-None> is None
|
||||
return self.push(
|
||||
ConstantVariable(supported_const_comparison_ops[op](object(), None))
|
||||
)
|
||||
self.push(
|
||||
ConstantVariable.create(
|
||||
supported_const_comparison_ops[op](object(), right.value)
|
||||
)
|
||||
)
|
||||
|
||||
elif (
|
||||
left.is_python_constant()
|
||||
and right.is_python_constant()
|
||||
and op in supported_any
|
||||
):
|
||||
# constant fold
|
||||
self.push(
|
||||
ConstantVariable.create(
|
||||
supported_any[op](
|
||||
left.as_python_constant(), right.as_python_constant()
|
||||
),
|
||||
)
|
||||
)
|
||||
elif op in ("in", "not in"):
|
||||
self.push(right.call_method(self, "__contains__", [left], {}))
|
||||
if op == "not in":
|
||||
self.UNARY_NOT(inst)
|
||||
else:
|
||||
self.push(
|
||||
BuiltinVariable(supported_any[op]).call_function(
|
||||
BuiltinVariable(supported_comparison_ops[op]).call_function(
|
||||
self, [left, right], {}
|
||||
)
|
||||
)
|
||||
|
@ -1526,8 +1526,8 @@ class BuiltinVariable(VariableTracker):
|
||||
)
|
||||
from .lists import SizeVariable
|
||||
from .tensor import (
|
||||
supported_const_comparison_ops,
|
||||
supported_tensor_comparison_ops,
|
||||
supported_const_comparison_op_values,
|
||||
supported_tensor_comparison_op_values,
|
||||
)
|
||||
|
||||
op = self.fn
|
||||
@ -1540,7 +1540,7 @@ class BuiltinVariable(VariableTracker):
|
||||
isinstance(x, (NNModuleVariable, ConstantVariable))
|
||||
for x in [left, right]
|
||||
)
|
||||
and op in supported_const_comparison_ops.values()
|
||||
and op in supported_const_comparison_op_values
|
||||
):
|
||||
left = (
|
||||
tx.output.get_submodule(left.module_key)
|
||||
@ -1555,7 +1555,7 @@ class BuiltinVariable(VariableTracker):
|
||||
return ConstantVariable.create(op(left, right))
|
||||
|
||||
if isinstance(left, UserFunctionVariable):
|
||||
if op not in supported_const_comparison_ops.values():
|
||||
if op not in supported_const_comparison_op_values:
|
||||
_unimplemented()
|
||||
if not isinstance(right, UserFunctionVariable):
|
||||
_unimplemented()
|
||||
@ -1594,7 +1594,7 @@ class BuiltinVariable(VariableTracker):
|
||||
else:
|
||||
return ConstantVariable.create(not is_result)
|
||||
|
||||
if op not in supported_tensor_comparison_ops.values():
|
||||
if op not in supported_tensor_comparison_op_values:
|
||||
_unimplemented()
|
||||
if (
|
||||
isinstance(left, TensorVariable)
|
||||
@ -1618,7 +1618,7 @@ class BuiltinVariable(VariableTracker):
|
||||
)
|
||||
|
||||
if isinstance(left, SymNodeVariable) or isinstance(right, SymNodeVariable):
|
||||
if op not in supported_tensor_comparison_ops.values():
|
||||
if op not in supported_tensor_comparison_op_values:
|
||||
_unimplemented()
|
||||
|
||||
proxy = tx.output.create_proxy(
|
||||
|
@ -57,6 +57,7 @@ from .base import _is_top_level_scope, VariableTracker
|
||||
from .constant import ConstantVariable
|
||||
from .lists import SizeVariable
|
||||
|
||||
# Ops that allow tensor <op> tensor
|
||||
supported_tensor_comparison_ops = {
|
||||
">": operator.gt,
|
||||
"<": operator.lt,
|
||||
@ -65,12 +66,23 @@ supported_tensor_comparison_ops = {
|
||||
"==": operator.eq,
|
||||
"!=": operator.ne,
|
||||
}
|
||||
# Ops that allow tensor <op> None
|
||||
supported_const_comparison_ops = {
|
||||
"is": operator.is_,
|
||||
"is not": operator.is_not,
|
||||
"==": operator.eq,
|
||||
"!=": operator.ne,
|
||||
}
|
||||
supported_comparison_ops = {
|
||||
**supported_tensor_comparison_ops,
|
||||
**supported_const_comparison_ops,
|
||||
}
|
||||
supported_tensor_comparison_op_values = dict.fromkeys(
|
||||
supported_tensor_comparison_ops.values()
|
||||
)
|
||||
supported_const_comparison_op_values = dict.fromkeys(
|
||||
supported_const_comparison_ops.values()
|
||||
)
|
||||
|
||||
|
||||
class TensorVariable(VariableTracker):
|
||||
|
Reference in New Issue
Block a user