mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
[dynamo] add SymNode bitwise and/or (#138777)
Fixes [T203472723](https://www.internalfb.com/intern/tasks/?t=203472723) Pull Request resolved: https://github.com/pytorch/pytorch/pull/138777 Approved by: https://github.com/ezyang
This commit is contained in:
committed by
PyTorch MergeBot
parent
a8c90e5140
commit
ee7eaad5c3
@ -6387,6 +6387,34 @@ def forward(self, s0 : torch.SymInt, s1 : torch.SymInt, L_x_ : torch.Tensor):
|
||||
res = f(t, [1, 2])
|
||||
self.assertEqual(t + 1, res)
|
||||
|
||||
def test_symint_bitwise(self):
|
||||
def fn(x):
|
||||
z = x.shape[0]
|
||||
z |= z >> 1
|
||||
z |= z << 1
|
||||
z &= z | (z > 1)
|
||||
y = (z > 1) | (z <= 1)
|
||||
# test composition with non-bitwise ops
|
||||
z = (z | z) % 6
|
||||
return y, z
|
||||
|
||||
opt_fn = torch.compile(fn, backend="eager", dynamic=True, fullgraph=True)
|
||||
inp = torch.randn(3, 3)
|
||||
self.assertEqual(fn(inp), opt_fn(inp))
|
||||
|
||||
def test_bitwise_op_guard(self):
|
||||
# attempt evaluating a guard with BitwiseFn_bitwise_[and/or]
|
||||
def fn(x):
|
||||
if x.shape[0] | x.shape[1] > 4:
|
||||
x = x + 1
|
||||
if x.shape[0] & x.shape[1] > 2:
|
||||
return x + 1
|
||||
return x - 1
|
||||
|
||||
opt_fn = torch.compile(fn, backend="eager", dynamic=True, fullgraph=True)
|
||||
inp = torch.randn(3, 3)
|
||||
self.assertEqual(fn(inp), opt_fn(inp))
|
||||
|
||||
|
||||
instantiate_parametrized_tests(ReproTests)
|
||||
|
||||
|
||||
@ -371,6 +371,39 @@ class TestPySymInt(TestCase):
|
||||
z = y.expand((y.shape[1],))
|
||||
z = y.expand(y.shape[1])
|
||||
|
||||
def test_symint_bitwise_and(self):
|
||||
shape_env = ShapeEnv()
|
||||
a0 = create_symint(shape_env, 0b1100)
|
||||
b0 = create_symint(shape_env, 0b1010)
|
||||
res_and = a0 & b0
|
||||
self.assertEqual(res_and, 0b1000)
|
||||
self.assertIsInstance(res_and, torch.SymInt, msg=type(res_and))
|
||||
self.assertExpectedInline(
|
||||
str(shape_env.guards[0][0]), """Eq(BitwiseFn_bitwise_and(s0, s1), 8)"""
|
||||
)
|
||||
|
||||
a1 = create_symint(shape_env, 3)
|
||||
b1 = create_symbool(shape_env, True)
|
||||
self.assertEqual(a1 & b1, 1)
|
||||
|
||||
a2 = create_symint(shape_env, 0b1100)
|
||||
self.assertEqual(a2 & 0b1010, 0b1000)
|
||||
|
||||
a3 = create_symbool(shape_env, True)
|
||||
b3 = create_symbool(shape_env, True)
|
||||
self.assertEqual(a3 & b3, True)
|
||||
|
||||
def test_symint_bitwise_or(self):
|
||||
shape_env = ShapeEnv()
|
||||
a0 = create_symint(shape_env, 0b1100)
|
||||
b0 = create_symint(shape_env, 0b1010)
|
||||
res_or = a0 | b0
|
||||
self.assertEqual(res_or, 0b1110)
|
||||
self.assertIsInstance(res_or, torch.SymInt, msg=type(res_or))
|
||||
self.assertExpectedInline(
|
||||
str(shape_env.guards[0][0]), """Eq(BitwiseFn_bitwise_or(s0, s1), 14)"""
|
||||
)
|
||||
|
||||
def test_stride(self):
|
||||
shape_env = ShapeEnv()
|
||||
x = create_symbolic_tensor("x", torch.randn(5, 5), shape_env)
|
||||
@ -1405,6 +1438,9 @@ class TestSymNumberMagicMethods(TestCase):
|
||||
if second_type == "float" and fn in ["mod"]:
|
||||
self.skipTest(f"{fn} only handles int")
|
||||
|
||||
if fn in sym_node.bitwise_ops and (first_type != "int" or second_type != "int"):
|
||||
self.skipTest(f"{fn} is a bitwise op, only handles int")
|
||||
|
||||
is_unary_fn = fn in sym_node.unary_methods or fn == "round"
|
||||
# Second argument is ignored for unary function. So only run for one type
|
||||
if is_unary_fn and second_type == "float":
|
||||
|
||||
@ -1747,7 +1747,7 @@ if TEST_Z3:
|
||||
import torch._dynamo.config
|
||||
|
||||
from torch.fx.experimental.validator import SympyToZ3, TranslationValidator, ValidationException, z3str
|
||||
from torch.utils._sympy.functions import FloorDiv, Mod
|
||||
from torch.utils._sympy.functions import FloorDiv, Mod, BitwiseFn_bitwise_and
|
||||
|
||||
class TestTranslationValidation(TestCase):
|
||||
def _prepare_for_translation_validation(self):
|
||||
@ -1801,6 +1801,8 @@ if TEST_Z3:
|
||||
(sympy.Ge, operator.ge),
|
||||
)
|
||||
],
|
||||
# Bitwise operations.
|
||||
(BitwiseFn_bitwise_and(s0, s1), z3.BV2Int(z3.Int2BV(z0, 64) & z3.Int2BV(z1, 64))),
|
||||
# Other operations.
|
||||
(
|
||||
s0 - s1,
|
||||
@ -1847,6 +1849,18 @@ if TEST_Z3:
|
||||
|
||||
validator.validate()
|
||||
|
||||
def test_sat_bitwise(self):
|
||||
(
|
||||
(s0, s1, s2),
|
||||
(z0, z1, z2),
|
||||
validator,
|
||||
) = self._prepare_for_translation_validation()
|
||||
|
||||
validator.add_source_expr(z3.BV2Int(z3.Int2BV(z0, 64) & z3.Int2BV(z1, 64)) == 5)
|
||||
validator.add_source_expr(z0 == 0b110101)
|
||||
|
||||
validator.validate()
|
||||
|
||||
def test_unsat(self):
|
||||
(
|
||||
(s0, s1, s2),
|
||||
|
||||
@ -3,9 +3,9 @@
|
||||
import functools
|
||||
import itertools
|
||||
import math
|
||||
import pickle
|
||||
import sys
|
||||
from typing import Callable, List, Tuple, Type
|
||||
import pickle
|
||||
|
||||
import sympy
|
||||
|
||||
@ -20,7 +20,11 @@ from torch.testing._internal.common_utils import (
|
||||
TEST_Z3,
|
||||
TestCase,
|
||||
)
|
||||
from torch.utils._sympy.functions import FloorDiv, simple_floordiv_gcd
|
||||
from torch.utils._sympy.functions import (
|
||||
FloorDiv,
|
||||
OpaqueUnaryFn_cos,
|
||||
simple_floordiv_gcd,
|
||||
)
|
||||
from torch.utils._sympy.interp import sympy_interp
|
||||
from torch.utils._sympy.numbers import int_oo, IntInfinity, NegativeIntInfinity
|
||||
from torch.utils._sympy.reference import (
|
||||
@ -31,7 +35,6 @@ from torch.utils._sympy.reference import (
|
||||
from torch.utils._sympy.singleton_int import SingletonInt
|
||||
from torch.utils._sympy.solve import INEQUALITY_TYPES, mirror_rel_op, try_solve
|
||||
from torch.utils._sympy.value_ranges import ValueRangeAnalysis, ValueRanges
|
||||
from torch.utils._sympy.functions import OpaqueUnaryFn_cos
|
||||
|
||||
|
||||
UNARY_OPS = [
|
||||
@ -58,6 +61,12 @@ BINARY_OPS = [
|
||||
"minimum",
|
||||
"maximum",
|
||||
"mod",
|
||||
"bitwise_and",
|
||||
"bitwise_or",
|
||||
]
|
||||
BITWISE_OPS = [
|
||||
"bitwise_and",
|
||||
"bitwise_or",
|
||||
]
|
||||
|
||||
UNARY_BOOL_OPS = ["not_"]
|
||||
@ -231,6 +240,10 @@ class TestValueRanges(TestCase):
|
||||
@parametrize("dtype", ("int", "float"))
|
||||
def test_binary_ref(self, fn, dtype):
|
||||
to_dtype = {"int": sympy.Integer, "float": sympy.Float}
|
||||
# Don't test bitwise methods since value range analysis on a singleton
|
||||
# range may not return a singleton result.
|
||||
if fn in BITWISE_OPS:
|
||||
return
|
||||
# Don't test float on int only methods
|
||||
if dtype == "float" and fn in ["pow_by_natural", "mod"]:
|
||||
return
|
||||
@ -280,7 +293,7 @@ class TestValueRanges(TestCase):
|
||||
else:
|
||||
self.assertEqual(len(unique), 2)
|
||||
|
||||
@parametrize("fn", BINARY_BOOL_OPS)
|
||||
@parametrize("fn", BINARY_BOOL_OPS + BITWISE_OPS)
|
||||
def test_binary_bool_ref_range(self, fn):
|
||||
vals = [sympy.false, sympy.true]
|
||||
for a, b in itertools.product(generate_range(vals), repeat=2):
|
||||
@ -338,6 +351,38 @@ class TestValueRanges(TestCase):
|
||||
if r.is_finite:
|
||||
self.assertIn(r, ref_r)
|
||||
|
||||
# stronger test specially for bitwise ops
|
||||
@parametrize("fn", BITWISE_OPS)
|
||||
def test_bitwise_ref_range(self, fn):
|
||||
# N^4 complexity
|
||||
vals = range(-4, 5)
|
||||
for a, b in itertools.product(generate_range(vals), repeat=2):
|
||||
with self.subTest(a=a, b=b):
|
||||
for a0, b0 in itertools.product(vals, repeat=2):
|
||||
if a0 not in a or b0 not in b:
|
||||
continue
|
||||
with self.subTest(a0=a0, b0=b0):
|
||||
ref_r = getattr(ValueRangeAnalysis, fn)(a, b)
|
||||
r = getattr(ReferenceAnalysis, fn)(a0, b0)
|
||||
self.assertIn(r, ref_r)
|
||||
|
||||
# test that bitwise ops can take bool arguments
|
||||
bool_vals = [
|
||||
(3, sympy.true),
|
||||
(3, sympy.false),
|
||||
(sympy.true, 3),
|
||||
(sympy.false, 3),
|
||||
(sympy.true, sympy.true),
|
||||
(sympy.true, sympy.false),
|
||||
(sympy.false, sympy.true),
|
||||
(sympy.false, sympy.false),
|
||||
]
|
||||
for a, b in bool_vals:
|
||||
with self.subTest(a=a, b=b):
|
||||
ref_r = getattr(ValueRangeAnalysis, fn)(a, b)
|
||||
r = getattr(ReferenceAnalysis, fn)(a, b)
|
||||
self.assertIn(r, ref_r)
|
||||
|
||||
|
||||
class TestSympyInterp(TestCase):
|
||||
@parametrize(
|
||||
@ -358,6 +403,8 @@ class TestSympyInterp(TestCase):
|
||||
vals = CONSTANTS
|
||||
if fn in {*UNARY_BOOL_OPS, *BINARY_BOOL_OPS}:
|
||||
vals = [True, False]
|
||||
elif fn in BITWISE_OPS:
|
||||
vals = vals + [True, False]
|
||||
arity = 1
|
||||
if fn in {*BINARY_OPS, *BINARY_BOOL_OPS, *COMPARE_OPS}:
|
||||
arity = 2
|
||||
@ -395,6 +442,8 @@ class TestSympyInterp(TestCase):
|
||||
vals = CONSTANTS
|
||||
if fn in {*UNARY_BOOL_OPS, *BINARY_BOOL_OPS}:
|
||||
vals = [True, False]
|
||||
elif fn in BITWISE_OPS:
|
||||
vals = vals + [True, False]
|
||||
|
||||
arity = 1
|
||||
if fn in {*BINARY_OPS, *BINARY_BOOL_OPS, *COMPARE_OPS}:
|
||||
@ -472,6 +521,8 @@ class TestSympyInterp(TestCase):
|
||||
vals = CONSTANTS
|
||||
if fn in {*UNARY_BOOL_OPS, *BINARY_BOOL_OPS}:
|
||||
vals = [True, False]
|
||||
elif fn in BITWISE_OPS:
|
||||
vals = vals + [True, False]
|
||||
|
||||
arity = 1
|
||||
if fn in {*BINARY_OPS, *BINARY_BOOL_OPS, *COMPARE_OPS}:
|
||||
@ -815,7 +866,7 @@ class TestSympySolve(TestCase):
|
||||
|
||||
class TestSympyFunctions(TestCase):
|
||||
def test_pickle(self):
|
||||
x = OpaqueUnaryFn_cos(sympy.Symbol('a'))
|
||||
x = OpaqueUnaryFn_cos(sympy.Symbol("a"))
|
||||
r = pickle.loads(pickle.dumps(x))
|
||||
self.assertEqual(x, r)
|
||||
|
||||
|
||||
@ -536,6 +536,12 @@ class SymInt:
|
||||
def __rsub__(self, other: "IntLikeType") -> "SymInt":
|
||||
raise TypeError("type stub not overridden")
|
||||
|
||||
def __and__(self, other) -> "SymInt":
|
||||
raise TypeError("type stub not overridden")
|
||||
|
||||
def __or__(self, other) -> "SymInt":
|
||||
raise TypeError("type stub not overridden")
|
||||
|
||||
def __repr__(self):
|
||||
return self.node._graph_repr()
|
||||
|
||||
@ -922,6 +928,7 @@ for __name in (
|
||||
__fn.__qualname__ = __fn.__name__ = __sym_name
|
||||
globals()[__sym_name] = __fn
|
||||
|
||||
|
||||
del __fn, __name, __sym_name, _get_sym_math_fn
|
||||
|
||||
# Adding temporary shortcut
|
||||
|
||||
@ -2031,6 +2031,8 @@ class BuiltinVariable(VariableTracker):
|
||||
return SetVariable(list(a.set_items & b.set_items))
|
||||
# None no-ops this handler and lets the driving function proceed
|
||||
|
||||
call_iand = call_and_
|
||||
|
||||
def call_or_(self, tx: "InstructionTranslator", a, b):
|
||||
# Rely on constant_handler
|
||||
if isinstance(a, ConstantVariable) and isinstance(b, ConstantVariable):
|
||||
@ -2050,6 +2052,8 @@ class BuiltinVariable(VariableTracker):
|
||||
# None no-ops this handler and lets the driving function proceed
|
||||
return None
|
||||
|
||||
call_ior = call_or_
|
||||
|
||||
def call_not_(self, tx: "InstructionTranslator", a):
|
||||
if isinstance(a, SymNodeVariable):
|
||||
return SymNodeVariable.create(
|
||||
|
||||
@ -420,6 +420,13 @@ class SymNode:
|
||||
def sym_and(self, other):
|
||||
return self.and_(other)
|
||||
|
||||
# Integer bitwise ops
|
||||
def bitwise_and(self, other):
|
||||
return self._bitwise_and(other) # type: ignore[attr-defined]
|
||||
|
||||
def bitwise_or(self, other):
|
||||
return self._bitwise_or(other) # type: ignore[attr-defined]
|
||||
|
||||
# There is no int_truediv available from C++
|
||||
def truediv(self, other):
|
||||
return self.float_truediv(other)
|
||||
@ -582,6 +589,7 @@ METHOD_TO_OPERATOR = {
|
||||
"abs": operator.abs,
|
||||
"add": operator.add,
|
||||
"and": operator.and_,
|
||||
"bitwise_and": operator.and_,
|
||||
"ceil": math.ceil,
|
||||
"eq": operator.eq,
|
||||
"floor": math.floor,
|
||||
@ -598,6 +606,7 @@ METHOD_TO_OPERATOR = {
|
||||
"ne": operator.ne,
|
||||
"neg": operator.neg,
|
||||
"or": operator.or_,
|
||||
"bitwise_or": operator.or_,
|
||||
"float_pow": operator.pow,
|
||||
"pow_by_natural": operator.pow,
|
||||
"round": builtins.round,
|
||||
@ -676,6 +685,11 @@ only_float_magic_methods = {"is_integer", "round", "sym_int", "sym_log2"}
|
||||
|
||||
|
||||
magic_methods_on_operator_with_trailing_underscore = {"and", "or"}
|
||||
# remap necessary because an op name can have a bitwise and boolean implementation
|
||||
bitwise_ops = {
|
||||
"bitwise_and": "and",
|
||||
"bitwise_or": "or",
|
||||
}
|
||||
|
||||
|
||||
always_float_magic_methods = {"int_truediv", "float_truediv", "sym_float", "float_pow"}
|
||||
@ -849,6 +863,18 @@ def _optimized_add(
|
||||
return (_is_symbols_binary_summation(result), result)
|
||||
|
||||
|
||||
def _bitwise_and(a, b):
|
||||
from torch.utils._sympy.functions import BitwiseFn_bitwise_and
|
||||
|
||||
return BitwiseFn_bitwise_and(a, b)
|
||||
|
||||
|
||||
def _bitwise_or(a, b):
|
||||
from torch.utils._sympy.functions import BitwiseFn_bitwise_or
|
||||
|
||||
return BitwiseFn_bitwise_or(a, b)
|
||||
|
||||
|
||||
reflectable_magic_methods = {
|
||||
"add": _optimized_add,
|
||||
"sub": operator.sub,
|
||||
@ -857,7 +883,9 @@ reflectable_magic_methods = {
|
||||
"pow_by_natural": _sympy_pow_by_natural,
|
||||
"float_pow": _sympy_float_pow,
|
||||
"and": _sympy_and,
|
||||
"bitwise_and": _bitwise_and,
|
||||
"or": _sympy_or,
|
||||
"bitwise_or": _bitwise_or,
|
||||
"float_truediv": _sympy_float_truediv,
|
||||
"int_truediv": _sympy_int_truediv,
|
||||
"int_floordiv": _sympy_floordiv,
|
||||
@ -1682,9 +1710,12 @@ def _make_user_magic(method, user_type):
|
||||
|
||||
setattr(user_type, f"__{method}__", round_magic_impl)
|
||||
else:
|
||||
setattr(user_type, f"__{method}__", binary_magic_impl)
|
||||
method_name = method
|
||||
if method in bitwise_ops:
|
||||
method_name = bitwise_ops[method]
|
||||
setattr(user_type, f"__{method_name}__", binary_magic_impl)
|
||||
if method in reflectable_magic_methods:
|
||||
setattr(user_type, f"__r{method}__", rbinary_magic_impl)
|
||||
setattr(user_type, f"__r{method_name}__", rbinary_magic_impl)
|
||||
|
||||
|
||||
for method, func in magic_methods.items(): # type: ignore[assignment]
|
||||
@ -1697,7 +1728,8 @@ for method, func in magic_methods.items(): # type: ignore[assignment]
|
||||
if method in also_bool_magic_methods or method in bool_becomes_int_magic_methods:
|
||||
_make_user_magic(method, SymBool)
|
||||
_make_user_magic(method, SymInt)
|
||||
_make_user_magic(method, SymFloat)
|
||||
if method not in bitwise_ops:
|
||||
_make_user_magic(method, SymFloat)
|
||||
|
||||
del method
|
||||
del func
|
||||
|
||||
@ -2025,6 +2025,8 @@ SYMPY_INTERP = {
|
||||
"OpaqueUnaryFn_tanh": math.tanh,
|
||||
"OpaqueUnaryFn_atan": math.atan,
|
||||
"OpaqueUnaryFn_sqrt": math.sqrt,
|
||||
"BitwiseFn_bitwise_and": operator.and_,
|
||||
"BitwiseFn_bitwise_or": operator.or_,
|
||||
}
|
||||
|
||||
|
||||
|
||||
@ -139,6 +139,22 @@ try:
|
||||
string = op + " " + " ".join(args)
|
||||
return f"({string.rstrip()})"
|
||||
|
||||
# We need to convert to/from BitVec in order to use z3 bitwise ops.
|
||||
# We assume that integers are 64 bit.
|
||||
# If all args are boolean, then use the boolean bitwise op implementation instead, if provided.
|
||||
def _bitwise_op(bitwise_func, bool_func):
|
||||
@functools.wraps(bitwise_func)
|
||||
def wrapper(self, *args):
|
||||
if bool_func is not None and all(
|
||||
isinstance(arg, z3.BoolRef) for arg in args
|
||||
):
|
||||
return bool_func(*args)
|
||||
|
||||
wrapped_args = tuple(z3.Int2BV(a, 64) for a in args)
|
||||
return z3.BV2Int(bitwise_func(*wrapped_args))
|
||||
|
||||
return wrapper
|
||||
|
||||
# Implementation of Python semantics as Z3 expressions.
|
||||
#
|
||||
# Z3 Real-Int theory has operators with semantics that differ that of
|
||||
@ -237,6 +253,11 @@ try:
|
||||
self.floor(number + 0.5),
|
||||
)
|
||||
|
||||
bitwise_and = _bitwise_op(operator.and_, z3.And)
|
||||
bitwise_or = _bitwise_op(operator.or_, z3.Or)
|
||||
lshift = _bitwise_op(operator.lshift, None)
|
||||
rshift = _bitwise_op(operator.rshift, None)
|
||||
|
||||
# Lifts a callable to be used in Z3.
|
||||
#
|
||||
# This function replaces the given 'op' by a function that:
|
||||
@ -250,7 +271,7 @@ try:
|
||||
# This is needed because the argument of some FX nodes were
|
||||
# literal integers, instead of booleans. So, whenever this flag
|
||||
# is set, we also convert ints to booleans.
|
||||
boolean_ops = {operator.not_, operator.and_, operator.or_}
|
||||
boolean_ops = {operator.not_}
|
||||
as_bool = op in boolean_ops
|
||||
|
||||
# Lifts the function into 'z3.ExprRef' domain.
|
||||
@ -284,8 +305,10 @@ try:
|
||||
replacement_map = {
|
||||
# Operator module.
|
||||
operator.not_: lift(z3.Not),
|
||||
operator.and_: lift(z3.And),
|
||||
operator.or_: lift(z3.Or),
|
||||
operator.and_: lift(ops.bitwise_and),
|
||||
operator.or_: lift(ops.bitwise_or),
|
||||
operator.lshift: lift(ops.lshift),
|
||||
operator.rshift: lift(ops.rshift),
|
||||
operator.floordiv: lift(ops.floordiv),
|
||||
operator.truediv: lift(ops.div),
|
||||
operator.mod: lift(ops.mod),
|
||||
@ -420,6 +443,10 @@ try:
|
||||
"and_": z3.And,
|
||||
"or_": z3.Or,
|
||||
"not_": z3.Not,
|
||||
"bitwise_and": self._ops.bitwise_and,
|
||||
"bitwise_or": self._ops.bitwise_or,
|
||||
"lshift": self._ops.lshift,
|
||||
"rshift": self._ops.rshift,
|
||||
"floor": self._ops.floor,
|
||||
"ceil": self._ops.ceil,
|
||||
"minimum": self._ops.min,
|
||||
|
||||
@ -944,6 +944,7 @@ class Max(MinMaxBase, Application): # type: ignore[misc]
|
||||
r"""
|
||||
Return, if possible, the maximum value of the list.
|
||||
"""
|
||||
|
||||
zero = S.Infinity
|
||||
identity = S.NegativeInfinity
|
||||
|
||||
@ -1323,3 +1324,29 @@ OpaqueUnaryFn_exp = make_opaque_unary_fn("exp")
|
||||
OpaqueUnaryFn_log = make_opaque_unary_fn("log")
|
||||
OpaqueUnaryFn_asinh = make_opaque_unary_fn("asinh")
|
||||
OpaqueUnaryFn_log2 = make_opaque_unary_fn("log2")
|
||||
|
||||
|
||||
def make_opaque_bitwise_fn(name, real_op_name):
|
||||
class BitwiseFn(sympy.Function):
|
||||
_torch_handler_name = name
|
||||
|
||||
@classmethod
|
||||
def eval(cls, a, b):
|
||||
if a.is_Boolean and b.is_Boolean:
|
||||
return getattr(operator, real_op_name)(a, b)
|
||||
if a.is_Boolean:
|
||||
a = sympy.Integer(1 if a else 0)
|
||||
if b.is_Boolean:
|
||||
b = sympy.Integer(1 if b else 0)
|
||||
if isinstance(a, (sympy.Integer, int)) and isinstance(
|
||||
b, (sympy.Integer, int)
|
||||
):
|
||||
return sympy.Integer(getattr(operator, real_op_name)(int(a), int(b)))
|
||||
return None
|
||||
|
||||
BitwiseFn.__name__ = "BitwiseFn_" + name
|
||||
return BitwiseFn
|
||||
|
||||
|
||||
BitwiseFn_bitwise_and = make_opaque_bitwise_fn("bitwise_and", "and_")
|
||||
BitwiseFn_bitwise_or = make_opaque_bitwise_fn("bitwise_or", "or_")
|
||||
|
||||
@ -18,6 +18,8 @@ from sympy.logic.boolalg import Boolean as SympyBoolean, BooleanAtom
|
||||
import torch
|
||||
|
||||
from .functions import (
|
||||
BitwiseFn_bitwise_and,
|
||||
BitwiseFn_bitwise_or,
|
||||
CeilToInt,
|
||||
CleanDiv,
|
||||
FloatPow,
|
||||
@ -104,6 +106,8 @@ def handlers():
|
||||
RoundDecimal: "round_decimal",
|
||||
# TODO: do the rest of the opaque unary functions...
|
||||
OpaqueUnaryFn_log2: "log2",
|
||||
BitwiseFn_bitwise_and: "bitwise_and",
|
||||
BitwiseFn_bitwise_or: "bitwise_or",
|
||||
}
|
||||
# TODO: This is kind of pointless, we shouldn't be generating sympy.sin
|
||||
# for these functions, they should be Opaque instead
|
||||
|
||||
@ -8,6 +8,8 @@ import sympy
|
||||
import torch
|
||||
from torch.utils._sympy.functions import (
|
||||
_keep_float,
|
||||
BitwiseFn_bitwise_and,
|
||||
BitwiseFn_bitwise_or,
|
||||
FloatPow,
|
||||
FloatTrueDiv,
|
||||
FloorDiv,
|
||||
@ -195,6 +197,14 @@ class ReferenceAnalysis:
|
||||
def round_decimal(a, b):
|
||||
return RoundDecimal(a, b)
|
||||
|
||||
@staticmethod
|
||||
def bitwise_and(a, b):
|
||||
return BitwiseFn_bitwise_and(a, b)
|
||||
|
||||
@staticmethod
|
||||
def bitwise_or(a, b):
|
||||
return BitwiseFn_bitwise_or(a, b)
|
||||
|
||||
|
||||
# Unlike ReferenceAnalysis, does NOT sympyify, instead, works with plain
|
||||
# Python types and is FX traceable. Inheritance here is purely for code
|
||||
@ -307,6 +317,14 @@ class PythonReferenceAnalysis(ReferenceAnalysis):
|
||||
def round_decimal(a, b):
|
||||
return round(a, ndigits=b)
|
||||
|
||||
@staticmethod
|
||||
def bitwise_and(a, b):
|
||||
return a & b
|
||||
|
||||
@staticmethod
|
||||
def bitwise_or(a, b):
|
||||
return a | b
|
||||
|
||||
|
||||
# Like PythonReferenceAnalysis, but some export-unfriendly choices of
|
||||
# operators to make things faster
|
||||
@ -358,6 +376,14 @@ class TensorReferenceAnalysis:
|
||||
def and_(a, b):
|
||||
return torch.ops.aten.logical_and.default(a, b)
|
||||
|
||||
@staticmethod
|
||||
def bitwise_and(a, b):
|
||||
return torch.ops.aten.bitwise_and(a, b)
|
||||
|
||||
@staticmethod
|
||||
def bitwise_or(a, b):
|
||||
return torch.ops.aten.bitwise_or(a, b)
|
||||
|
||||
@staticmethod
|
||||
def eq(a, b):
|
||||
return torch.ops.aten.eq.Tensor(a, b)
|
||||
|
||||
@ -500,6 +500,53 @@ class SymPyValueRangeAnalysis:
|
||||
def and_(a, b):
|
||||
return ValueRanges.coordinatewise_increasing_map(a, b, sympy.And)
|
||||
|
||||
@staticmethod
|
||||
def _bool_to_int(x):
|
||||
if x.is_singleton():
|
||||
return ValueRanges.wrap(sympy.Integer(1 if x.lower else 0))
|
||||
else:
|
||||
return ValueRanges(sympy.Integer(0), sympy.Integer(1))
|
||||
|
||||
@classmethod
|
||||
def bitwise_and(cls, a, b):
|
||||
a, b = ValueRanges.wrap(a), ValueRanges.wrap(b)
|
||||
if a.is_bool and b.is_bool:
|
||||
return cls.and_(a, b)
|
||||
if a.is_bool:
|
||||
a = cls._bool_to_int(a)
|
||||
if b.is_bool:
|
||||
b = cls._bool_to_int(b)
|
||||
lower = min(a.lower, b.lower)
|
||||
if lower < 0 and lower != -int_oo:
|
||||
# If both lower bounds are negative, then bits start like
|
||||
# 1...10..., so the smallest possible value is 1...101...1.
|
||||
# Thus, we need to find the next smallest power of 2 (inclusive).
|
||||
lower = -(1 << int(-lower - 1).bit_length())
|
||||
else:
|
||||
lower = 0
|
||||
return ValueRanges(lower, max(a.upper, b.upper))
|
||||
|
||||
@classmethod
|
||||
def bitwise_or(cls, a, b):
|
||||
a, b = ValueRanges.wrap(a), ValueRanges.wrap(b)
|
||||
if a.is_bool and b.is_bool:
|
||||
return cls.or_(a, b)
|
||||
if a.is_bool:
|
||||
a = cls._bool_to_int(a)
|
||||
if b.is_bool:
|
||||
b = cls._bool_to_int(b)
|
||||
upper = max(a.upper, b.upper)
|
||||
if upper == 0:
|
||||
upper = 0
|
||||
elif upper > 0 and upper != int_oo:
|
||||
# If both upper bounds are positive, then the largest
|
||||
# possible value is 01...1, so we need to find
|
||||
# next largest power of 2 (exclusive), minus 1
|
||||
upper = (1 << int(upper).bit_length()) - 1
|
||||
elif upper < 0:
|
||||
upper = -1
|
||||
return ValueRanges(min(a.lower, b.lower), upper)
|
||||
|
||||
@staticmethod
|
||||
def eq(a, b):
|
||||
a = ValueRanges.wrap(a)
|
||||
@ -1061,12 +1108,14 @@ def bound_sympy(
|
||||
"bound_sympy(%s)%s",
|
||||
expr,
|
||||
LazyString(
|
||||
lambda: "\n"
|
||||
+ "\n".join(
|
||||
f" {k}: {r}" for k, r in ranges.items() if k in expr.free_symbols
|
||||
lambda: (
|
||||
"\n"
|
||||
+ "\n".join(
|
||||
f" {k}: {r}" for k, r in ranges.items() if k in expr.free_symbols
|
||||
)
|
||||
if ranges
|
||||
else ""
|
||||
)
|
||||
if ranges
|
||||
else ""
|
||||
),
|
||||
)
|
||||
if isinstance(expr, sympy.Number):
|
||||
|
||||
Reference in New Issue
Block a user