mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
[dynamo] add SymNode bitwise and/or (#138777)
Fixes [T203472723](https://www.internalfb.com/intern/tasks/?t=203472723) Pull Request resolved: https://github.com/pytorch/pytorch/pull/138777 Approved by: https://github.com/ezyang
This commit is contained in:
committed by
PyTorch MergeBot
parent
a8c90e5140
commit
ee7eaad5c3
@ -944,6 +944,7 @@ class Max(MinMaxBase, Application): # type: ignore[misc]
|
||||
r"""
|
||||
Return, if possible, the maximum value of the list.
|
||||
"""
|
||||
|
||||
zero = S.Infinity
|
||||
identity = S.NegativeInfinity
|
||||
|
||||
@ -1323,3 +1324,29 @@ OpaqueUnaryFn_exp = make_opaque_unary_fn("exp")
|
||||
OpaqueUnaryFn_log = make_opaque_unary_fn("log")
|
||||
OpaqueUnaryFn_asinh = make_opaque_unary_fn("asinh")
|
||||
OpaqueUnaryFn_log2 = make_opaque_unary_fn("log2")
|
||||
|
||||
|
||||
def make_opaque_bitwise_fn(name, real_op_name):
|
||||
class BitwiseFn(sympy.Function):
|
||||
_torch_handler_name = name
|
||||
|
||||
@classmethod
|
||||
def eval(cls, a, b):
|
||||
if a.is_Boolean and b.is_Boolean:
|
||||
return getattr(operator, real_op_name)(a, b)
|
||||
if a.is_Boolean:
|
||||
a = sympy.Integer(1 if a else 0)
|
||||
if b.is_Boolean:
|
||||
b = sympy.Integer(1 if b else 0)
|
||||
if isinstance(a, (sympy.Integer, int)) and isinstance(
|
||||
b, (sympy.Integer, int)
|
||||
):
|
||||
return sympy.Integer(getattr(operator, real_op_name)(int(a), int(b)))
|
||||
return None
|
||||
|
||||
BitwiseFn.__name__ = "BitwiseFn_" + name
|
||||
return BitwiseFn
|
||||
|
||||
|
||||
BitwiseFn_bitwise_and = make_opaque_bitwise_fn("bitwise_and", "and_")
|
||||
BitwiseFn_bitwise_or = make_opaque_bitwise_fn("bitwise_or", "or_")
|
||||
|
Reference in New Issue
Block a user