[dynamo] add SymNode bitwise and/or (#138777)

Fixes [T203472723](https://www.internalfb.com/intern/tasks/?t=203472723)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/138777
Approved by: https://github.com/ezyang
This commit is contained in:
William Wen
2024-11-22 11:00:51 -08:00
committed by PyTorch MergeBot
parent a8c90e5140
commit ee7eaad5c3
13 changed files with 324 additions and 17 deletions

View File

@ -6387,6 +6387,34 @@ def forward(self, s0 : torch.SymInt, s1 : torch.SymInt, L_x_ : torch.Tensor):
res = f(t, [1, 2])
self.assertEqual(t + 1, res)
def test_symint_bitwise(self):
def fn(x):
z = x.shape[0]
z |= z >> 1
z |= z << 1
z &= z | (z > 1)
y = (z > 1) | (z <= 1)
# test composition with non-bitwise ops
z = (z | z) % 6
return y, z
opt_fn = torch.compile(fn, backend="eager", dynamic=True, fullgraph=True)
inp = torch.randn(3, 3)
self.assertEqual(fn(inp), opt_fn(inp))
def test_bitwise_op_guard(self):
# attempt evaluating a guard with BitwiseFn_bitwise_[and/or]
def fn(x):
if x.shape[0] | x.shape[1] > 4:
x = x + 1
if x.shape[0] & x.shape[1] > 2:
return x + 1
return x - 1
opt_fn = torch.compile(fn, backend="eager", dynamic=True, fullgraph=True)
inp = torch.randn(3, 3)
self.assertEqual(fn(inp), opt_fn(inp))
instantiate_parametrized_tests(ReproTests)

View File

@ -371,6 +371,39 @@ class TestPySymInt(TestCase):
z = y.expand((y.shape[1],))
z = y.expand(y.shape[1])
def test_symint_bitwise_and(self):
shape_env = ShapeEnv()
a0 = create_symint(shape_env, 0b1100)
b0 = create_symint(shape_env, 0b1010)
res_and = a0 & b0
self.assertEqual(res_and, 0b1000)
self.assertIsInstance(res_and, torch.SymInt, msg=type(res_and))
self.assertExpectedInline(
str(shape_env.guards[0][0]), """Eq(BitwiseFn_bitwise_and(s0, s1), 8)"""
)
a1 = create_symint(shape_env, 3)
b1 = create_symbool(shape_env, True)
self.assertEqual(a1 & b1, 1)
a2 = create_symint(shape_env, 0b1100)
self.assertEqual(a2 & 0b1010, 0b1000)
a3 = create_symbool(shape_env, True)
b3 = create_symbool(shape_env, True)
self.assertEqual(a3 & b3, True)
def test_symint_bitwise_or(self):
shape_env = ShapeEnv()
a0 = create_symint(shape_env, 0b1100)
b0 = create_symint(shape_env, 0b1010)
res_or = a0 | b0
self.assertEqual(res_or, 0b1110)
self.assertIsInstance(res_or, torch.SymInt, msg=type(res_or))
self.assertExpectedInline(
str(shape_env.guards[0][0]), """Eq(BitwiseFn_bitwise_or(s0, s1), 14)"""
)
def test_stride(self):
shape_env = ShapeEnv()
x = create_symbolic_tensor("x", torch.randn(5, 5), shape_env)
@ -1405,6 +1438,9 @@ class TestSymNumberMagicMethods(TestCase):
if second_type == "float" and fn in ["mod"]:
self.skipTest(f"{fn} only handles int")
if fn in sym_node.bitwise_ops and (first_type != "int" or second_type != "int"):
self.skipTest(f"{fn} is a bitwise op, only handles int")
is_unary_fn = fn in sym_node.unary_methods or fn == "round"
# Second argument is ignored for unary function. So only run for one type
if is_unary_fn and second_type == "float":

View File

@ -1747,7 +1747,7 @@ if TEST_Z3:
import torch._dynamo.config
from torch.fx.experimental.validator import SympyToZ3, TranslationValidator, ValidationException, z3str
from torch.utils._sympy.functions import FloorDiv, Mod
from torch.utils._sympy.functions import FloorDiv, Mod, BitwiseFn_bitwise_and
class TestTranslationValidation(TestCase):
def _prepare_for_translation_validation(self):
@ -1801,6 +1801,8 @@ if TEST_Z3:
(sympy.Ge, operator.ge),
)
],
# Bitwise operations.
(BitwiseFn_bitwise_and(s0, s1), z3.BV2Int(z3.Int2BV(z0, 64) & z3.Int2BV(z1, 64))),
# Other operations.
(
s0 - s1,
@ -1847,6 +1849,18 @@ if TEST_Z3:
validator.validate()
def test_sat_bitwise(self):
(
(s0, s1, s2),
(z0, z1, z2),
validator,
) = self._prepare_for_translation_validation()
validator.add_source_expr(z3.BV2Int(z3.Int2BV(z0, 64) & z3.Int2BV(z1, 64)) == 5)
validator.add_source_expr(z0 == 0b110101)
validator.validate()
def test_unsat(self):
(
(s0, s1, s2),

View File

@ -3,9 +3,9 @@
import functools
import itertools
import math
import pickle
import sys
from typing import Callable, List, Tuple, Type
import pickle
import sympy
@ -20,7 +20,11 @@ from torch.testing._internal.common_utils import (
TEST_Z3,
TestCase,
)
from torch.utils._sympy.functions import FloorDiv, simple_floordiv_gcd
from torch.utils._sympy.functions import (
FloorDiv,
OpaqueUnaryFn_cos,
simple_floordiv_gcd,
)
from torch.utils._sympy.interp import sympy_interp
from torch.utils._sympy.numbers import int_oo, IntInfinity, NegativeIntInfinity
from torch.utils._sympy.reference import (
@ -31,7 +35,6 @@ from torch.utils._sympy.reference import (
from torch.utils._sympy.singleton_int import SingletonInt
from torch.utils._sympy.solve import INEQUALITY_TYPES, mirror_rel_op, try_solve
from torch.utils._sympy.value_ranges import ValueRangeAnalysis, ValueRanges
from torch.utils._sympy.functions import OpaqueUnaryFn_cos
UNARY_OPS = [
@ -58,6 +61,12 @@ BINARY_OPS = [
"minimum",
"maximum",
"mod",
"bitwise_and",
"bitwise_or",
]
BITWISE_OPS = [
"bitwise_and",
"bitwise_or",
]
UNARY_BOOL_OPS = ["not_"]
@ -231,6 +240,10 @@ class TestValueRanges(TestCase):
@parametrize("dtype", ("int", "float"))
def test_binary_ref(self, fn, dtype):
to_dtype = {"int": sympy.Integer, "float": sympy.Float}
# Don't test bitwise methods since value range analysis on a singleton
# range may not return a singleton result.
if fn in BITWISE_OPS:
return
# Don't test float on int only methods
if dtype == "float" and fn in ["pow_by_natural", "mod"]:
return
@ -280,7 +293,7 @@ class TestValueRanges(TestCase):
else:
self.assertEqual(len(unique), 2)
@parametrize("fn", BINARY_BOOL_OPS)
@parametrize("fn", BINARY_BOOL_OPS + BITWISE_OPS)
def test_binary_bool_ref_range(self, fn):
vals = [sympy.false, sympy.true]
for a, b in itertools.product(generate_range(vals), repeat=2):
@ -338,6 +351,38 @@ class TestValueRanges(TestCase):
if r.is_finite:
self.assertIn(r, ref_r)
# stronger test specially for bitwise ops
@parametrize("fn", BITWISE_OPS)
def test_bitwise_ref_range(self, fn):
# N^4 complexity
vals = range(-4, 5)
for a, b in itertools.product(generate_range(vals), repeat=2):
with self.subTest(a=a, b=b):
for a0, b0 in itertools.product(vals, repeat=2):
if a0 not in a or b0 not in b:
continue
with self.subTest(a0=a0, b0=b0):
ref_r = getattr(ValueRangeAnalysis, fn)(a, b)
r = getattr(ReferenceAnalysis, fn)(a0, b0)
self.assertIn(r, ref_r)
# test that bitwise ops can take bool arguments
bool_vals = [
(3, sympy.true),
(3, sympy.false),
(sympy.true, 3),
(sympy.false, 3),
(sympy.true, sympy.true),
(sympy.true, sympy.false),
(sympy.false, sympy.true),
(sympy.false, sympy.false),
]
for a, b in bool_vals:
with self.subTest(a=a, b=b):
ref_r = getattr(ValueRangeAnalysis, fn)(a, b)
r = getattr(ReferenceAnalysis, fn)(a, b)
self.assertIn(r, ref_r)
class TestSympyInterp(TestCase):
@parametrize(
@ -358,6 +403,8 @@ class TestSympyInterp(TestCase):
vals = CONSTANTS
if fn in {*UNARY_BOOL_OPS, *BINARY_BOOL_OPS}:
vals = [True, False]
elif fn in BITWISE_OPS:
vals = vals + [True, False]
arity = 1
if fn in {*BINARY_OPS, *BINARY_BOOL_OPS, *COMPARE_OPS}:
arity = 2
@ -395,6 +442,8 @@ class TestSympyInterp(TestCase):
vals = CONSTANTS
if fn in {*UNARY_BOOL_OPS, *BINARY_BOOL_OPS}:
vals = [True, False]
elif fn in BITWISE_OPS:
vals = vals + [True, False]
arity = 1
if fn in {*BINARY_OPS, *BINARY_BOOL_OPS, *COMPARE_OPS}:
@ -472,6 +521,8 @@ class TestSympyInterp(TestCase):
vals = CONSTANTS
if fn in {*UNARY_BOOL_OPS, *BINARY_BOOL_OPS}:
vals = [True, False]
elif fn in BITWISE_OPS:
vals = vals + [True, False]
arity = 1
if fn in {*BINARY_OPS, *BINARY_BOOL_OPS, *COMPARE_OPS}:
@ -815,7 +866,7 @@ class TestSympySolve(TestCase):
class TestSympyFunctions(TestCase):
def test_pickle(self):
x = OpaqueUnaryFn_cos(sympy.Symbol('a'))
x = OpaqueUnaryFn_cos(sympy.Symbol("a"))
r = pickle.loads(pickle.dumps(x))
self.assertEqual(x, r)

View File

@ -536,6 +536,12 @@ class SymInt:
def __rsub__(self, other: "IntLikeType") -> "SymInt":
raise TypeError("type stub not overridden")
def __and__(self, other) -> "SymInt":
raise TypeError("type stub not overridden")
def __or__(self, other) -> "SymInt":
raise TypeError("type stub not overridden")
def __repr__(self):
return self.node._graph_repr()
@ -922,6 +928,7 @@ for __name in (
__fn.__qualname__ = __fn.__name__ = __sym_name
globals()[__sym_name] = __fn
del __fn, __name, __sym_name, _get_sym_math_fn
# Adding temporary shortcut

View File

@ -2031,6 +2031,8 @@ class BuiltinVariable(VariableTracker):
return SetVariable(list(a.set_items & b.set_items))
# None no-ops this handler and lets the driving function proceed
call_iand = call_and_
def call_or_(self, tx: "InstructionTranslator", a, b):
# Rely on constant_handler
if isinstance(a, ConstantVariable) and isinstance(b, ConstantVariable):
@ -2050,6 +2052,8 @@ class BuiltinVariable(VariableTracker):
# None no-ops this handler and lets the driving function proceed
return None
call_ior = call_or_
def call_not_(self, tx: "InstructionTranslator", a):
if isinstance(a, SymNodeVariable):
return SymNodeVariable.create(

View File

@ -420,6 +420,13 @@ class SymNode:
def sym_and(self, other):
return self.and_(other)
# Integer bitwise ops
def bitwise_and(self, other):
return self._bitwise_and(other) # type: ignore[attr-defined]
def bitwise_or(self, other):
return self._bitwise_or(other) # type: ignore[attr-defined]
# There is no int_truediv available from C++
def truediv(self, other):
return self.float_truediv(other)
@ -582,6 +589,7 @@ METHOD_TO_OPERATOR = {
"abs": operator.abs,
"add": operator.add,
"and": operator.and_,
"bitwise_and": operator.and_,
"ceil": math.ceil,
"eq": operator.eq,
"floor": math.floor,
@ -598,6 +606,7 @@ METHOD_TO_OPERATOR = {
"ne": operator.ne,
"neg": operator.neg,
"or": operator.or_,
"bitwise_or": operator.or_,
"float_pow": operator.pow,
"pow_by_natural": operator.pow,
"round": builtins.round,
@ -676,6 +685,11 @@ only_float_magic_methods = {"is_integer", "round", "sym_int", "sym_log2"}
magic_methods_on_operator_with_trailing_underscore = {"and", "or"}
# remap necessary because an op name can have a bitwise and boolean implementation
bitwise_ops = {
"bitwise_and": "and",
"bitwise_or": "or",
}
always_float_magic_methods = {"int_truediv", "float_truediv", "sym_float", "float_pow"}
@ -849,6 +863,18 @@ def _optimized_add(
return (_is_symbols_binary_summation(result), result)
def _bitwise_and(a, b):
from torch.utils._sympy.functions import BitwiseFn_bitwise_and
return BitwiseFn_bitwise_and(a, b)
def _bitwise_or(a, b):
from torch.utils._sympy.functions import BitwiseFn_bitwise_or
return BitwiseFn_bitwise_or(a, b)
reflectable_magic_methods = {
"add": _optimized_add,
"sub": operator.sub,
@ -857,7 +883,9 @@ reflectable_magic_methods = {
"pow_by_natural": _sympy_pow_by_natural,
"float_pow": _sympy_float_pow,
"and": _sympy_and,
"bitwise_and": _bitwise_and,
"or": _sympy_or,
"bitwise_or": _bitwise_or,
"float_truediv": _sympy_float_truediv,
"int_truediv": _sympy_int_truediv,
"int_floordiv": _sympy_floordiv,
@ -1682,9 +1710,12 @@ def _make_user_magic(method, user_type):
setattr(user_type, f"__{method}__", round_magic_impl)
else:
setattr(user_type, f"__{method}__", binary_magic_impl)
method_name = method
if method in bitwise_ops:
method_name = bitwise_ops[method]
setattr(user_type, f"__{method_name}__", binary_magic_impl)
if method in reflectable_magic_methods:
setattr(user_type, f"__r{method}__", rbinary_magic_impl)
setattr(user_type, f"__r{method_name}__", rbinary_magic_impl)
for method, func in magic_methods.items(): # type: ignore[assignment]
@ -1697,7 +1728,8 @@ for method, func in magic_methods.items(): # type: ignore[assignment]
if method in also_bool_magic_methods or method in bool_becomes_int_magic_methods:
_make_user_magic(method, SymBool)
_make_user_magic(method, SymInt)
_make_user_magic(method, SymFloat)
if method not in bitwise_ops:
_make_user_magic(method, SymFloat)
del method
del func

View File

@ -2025,6 +2025,8 @@ SYMPY_INTERP = {
"OpaqueUnaryFn_tanh": math.tanh,
"OpaqueUnaryFn_atan": math.atan,
"OpaqueUnaryFn_sqrt": math.sqrt,
"BitwiseFn_bitwise_and": operator.and_,
"BitwiseFn_bitwise_or": operator.or_,
}

View File

@ -139,6 +139,22 @@ try:
string = op + " " + " ".join(args)
return f"({string.rstrip()})"
# We need to convert to/from BitVec in order to use z3 bitwise ops.
# We assume that integers are 64 bit.
# If all args are boolean, then use the boolean bitwise op implementation instead, if provided.
def _bitwise_op(bitwise_func, bool_func):
@functools.wraps(bitwise_func)
def wrapper(self, *args):
if bool_func is not None and all(
isinstance(arg, z3.BoolRef) for arg in args
):
return bool_func(*args)
wrapped_args = tuple(z3.Int2BV(a, 64) for a in args)
return z3.BV2Int(bitwise_func(*wrapped_args))
return wrapper
# Implementation of Python semantics as Z3 expressions.
#
# Z3 Real-Int theory has operators with semantics that differ that of
@ -237,6 +253,11 @@ try:
self.floor(number + 0.5),
)
bitwise_and = _bitwise_op(operator.and_, z3.And)
bitwise_or = _bitwise_op(operator.or_, z3.Or)
lshift = _bitwise_op(operator.lshift, None)
rshift = _bitwise_op(operator.rshift, None)
# Lifts a callable to be used in Z3.
#
# This function replaces the given 'op' by a function that:
@ -250,7 +271,7 @@ try:
# This is needed because the argument of some FX nodes were
# literal integers, instead of booleans. So, whenever this flag
# is set, we also convert ints to booleans.
boolean_ops = {operator.not_, operator.and_, operator.or_}
boolean_ops = {operator.not_}
as_bool = op in boolean_ops
# Lifts the function into 'z3.ExprRef' domain.
@ -284,8 +305,10 @@ try:
replacement_map = {
# Operator module.
operator.not_: lift(z3.Not),
operator.and_: lift(z3.And),
operator.or_: lift(z3.Or),
operator.and_: lift(ops.bitwise_and),
operator.or_: lift(ops.bitwise_or),
operator.lshift: lift(ops.lshift),
operator.rshift: lift(ops.rshift),
operator.floordiv: lift(ops.floordiv),
operator.truediv: lift(ops.div),
operator.mod: lift(ops.mod),
@ -420,6 +443,10 @@ try:
"and_": z3.And,
"or_": z3.Or,
"not_": z3.Not,
"bitwise_and": self._ops.bitwise_and,
"bitwise_or": self._ops.bitwise_or,
"lshift": self._ops.lshift,
"rshift": self._ops.rshift,
"floor": self._ops.floor,
"ceil": self._ops.ceil,
"minimum": self._ops.min,

View File

@ -944,6 +944,7 @@ class Max(MinMaxBase, Application): # type: ignore[misc]
r"""
Return, if possible, the maximum value of the list.
"""
zero = S.Infinity
identity = S.NegativeInfinity
@ -1323,3 +1324,29 @@ OpaqueUnaryFn_exp = make_opaque_unary_fn("exp")
OpaqueUnaryFn_log = make_opaque_unary_fn("log")
OpaqueUnaryFn_asinh = make_opaque_unary_fn("asinh")
OpaqueUnaryFn_log2 = make_opaque_unary_fn("log2")
def make_opaque_bitwise_fn(name, real_op_name):
class BitwiseFn(sympy.Function):
_torch_handler_name = name
@classmethod
def eval(cls, a, b):
if a.is_Boolean and b.is_Boolean:
return getattr(operator, real_op_name)(a, b)
if a.is_Boolean:
a = sympy.Integer(1 if a else 0)
if b.is_Boolean:
b = sympy.Integer(1 if b else 0)
if isinstance(a, (sympy.Integer, int)) and isinstance(
b, (sympy.Integer, int)
):
return sympy.Integer(getattr(operator, real_op_name)(int(a), int(b)))
return None
BitwiseFn.__name__ = "BitwiseFn_" + name
return BitwiseFn
BitwiseFn_bitwise_and = make_opaque_bitwise_fn("bitwise_and", "and_")
BitwiseFn_bitwise_or = make_opaque_bitwise_fn("bitwise_or", "or_")

View File

@ -18,6 +18,8 @@ from sympy.logic.boolalg import Boolean as SympyBoolean, BooleanAtom
import torch
from .functions import (
BitwiseFn_bitwise_and,
BitwiseFn_bitwise_or,
CeilToInt,
CleanDiv,
FloatPow,
@ -104,6 +106,8 @@ def handlers():
RoundDecimal: "round_decimal",
# TODO: do the rest of the opaque unary functions...
OpaqueUnaryFn_log2: "log2",
BitwiseFn_bitwise_and: "bitwise_and",
BitwiseFn_bitwise_or: "bitwise_or",
}
# TODO: This is kind of pointless, we shouldn't be generating sympy.sin
# for these functions, they should be Opaque instead

View File

@ -8,6 +8,8 @@ import sympy
import torch
from torch.utils._sympy.functions import (
_keep_float,
BitwiseFn_bitwise_and,
BitwiseFn_bitwise_or,
FloatPow,
FloatTrueDiv,
FloorDiv,
@ -195,6 +197,14 @@ class ReferenceAnalysis:
def round_decimal(a, b):
return RoundDecimal(a, b)
@staticmethod
def bitwise_and(a, b):
return BitwiseFn_bitwise_and(a, b)
@staticmethod
def bitwise_or(a, b):
return BitwiseFn_bitwise_or(a, b)
# Unlike ReferenceAnalysis, does NOT sympyify, instead, works with plain
# Python types and is FX traceable. Inheritance here is purely for code
@ -307,6 +317,14 @@ class PythonReferenceAnalysis(ReferenceAnalysis):
def round_decimal(a, b):
return round(a, ndigits=b)
@staticmethod
def bitwise_and(a, b):
return a & b
@staticmethod
def bitwise_or(a, b):
return a | b
# Like PythonReferenceAnalysis, but some export-unfriendly choices of
# operators to make things faster
@ -358,6 +376,14 @@ class TensorReferenceAnalysis:
def and_(a, b):
return torch.ops.aten.logical_and.default(a, b)
@staticmethod
def bitwise_and(a, b):
return torch.ops.aten.bitwise_and(a, b)
@staticmethod
def bitwise_or(a, b):
return torch.ops.aten.bitwise_or(a, b)
@staticmethod
def eq(a, b):
return torch.ops.aten.eq.Tensor(a, b)

View File

@ -500,6 +500,53 @@ class SymPyValueRangeAnalysis:
def and_(a, b):
return ValueRanges.coordinatewise_increasing_map(a, b, sympy.And)
@staticmethod
def _bool_to_int(x):
if x.is_singleton():
return ValueRanges.wrap(sympy.Integer(1 if x.lower else 0))
else:
return ValueRanges(sympy.Integer(0), sympy.Integer(1))
@classmethod
def bitwise_and(cls, a, b):
a, b = ValueRanges.wrap(a), ValueRanges.wrap(b)
if a.is_bool and b.is_bool:
return cls.and_(a, b)
if a.is_bool:
a = cls._bool_to_int(a)
if b.is_bool:
b = cls._bool_to_int(b)
lower = min(a.lower, b.lower)
if lower < 0 and lower != -int_oo:
# If both lower bounds are negative, then bits start like
# 1...10..., so the smallest possible value is 1...101...1.
# Thus, we need to find the next smallest power of 2 (inclusive).
lower = -(1 << int(-lower - 1).bit_length())
else:
lower = 0
return ValueRanges(lower, max(a.upper, b.upper))
@classmethod
def bitwise_or(cls, a, b):
a, b = ValueRanges.wrap(a), ValueRanges.wrap(b)
if a.is_bool and b.is_bool:
return cls.or_(a, b)
if a.is_bool:
a = cls._bool_to_int(a)
if b.is_bool:
b = cls._bool_to_int(b)
upper = max(a.upper, b.upper)
if upper == 0:
upper = 0
elif upper > 0 and upper != int_oo:
# If both upper bounds are positive, then the largest
# possible value is 01...1, so we need to find
# next largest power of 2 (exclusive), minus 1
upper = (1 << int(upper).bit_length()) - 1
elif upper < 0:
upper = -1
return ValueRanges(min(a.lower, b.lower), upper)
@staticmethod
def eq(a, b):
a = ValueRanges.wrap(a)
@ -1061,12 +1108,14 @@ def bound_sympy(
"bound_sympy(%s)%s",
expr,
LazyString(
lambda: "\n"
+ "\n".join(
f" {k}: {r}" for k, r in ranges.items() if k in expr.free_symbols
lambda: (
"\n"
+ "\n".join(
f" {k}: {r}" for k, r in ranges.items() if k in expr.free_symbols
)
if ranges
else ""
)
if ranges
else ""
),
)
if isinstance(expr, sympy.Number):