[dynamo] add SymNode bitwise and/or (#138777)

Fixes [T203472723](https://www.internalfb.com/intern/tasks/?t=203472723)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/138777
Approved by: https://github.com/ezyang
This commit is contained in:
William Wen
2024-11-22 11:00:51 -08:00
committed by PyTorch MergeBot
parent a8c90e5140
commit ee7eaad5c3
13 changed files with 324 additions and 17 deletions

View File

@ -944,6 +944,7 @@ class Max(MinMaxBase, Application): # type: ignore[misc]
r"""
Return, if possible, the maximum value of the list.
"""
zero = S.Infinity
identity = S.NegativeInfinity
@ -1323,3 +1324,29 @@ OpaqueUnaryFn_exp = make_opaque_unary_fn("exp")
OpaqueUnaryFn_log = make_opaque_unary_fn("log")
OpaqueUnaryFn_asinh = make_opaque_unary_fn("asinh")
OpaqueUnaryFn_log2 = make_opaque_unary_fn("log2")
def make_opaque_bitwise_fn(name, real_op_name):
class BitwiseFn(sympy.Function):
_torch_handler_name = name
@classmethod
def eval(cls, a, b):
if a.is_Boolean and b.is_Boolean:
return getattr(operator, real_op_name)(a, b)
if a.is_Boolean:
a = sympy.Integer(1 if a else 0)
if b.is_Boolean:
b = sympy.Integer(1 if b else 0)
if isinstance(a, (sympy.Integer, int)) and isinstance(
b, (sympy.Integer, int)
):
return sympy.Integer(getattr(operator, real_op_name)(int(a), int(b)))
return None
BitwiseFn.__name__ = "BitwiseFn_" + name
return BitwiseFn
BitwiseFn_bitwise_and = make_opaque_bitwise_fn("bitwise_and", "and_")
BitwiseFn_bitwise_or = make_opaque_bitwise_fn("bitwise_or", "or_")