Files
pytorch/torch/fx/experimental/unification/more.py
Yuanyuan Chen 70925bdf82 [1/N] Use "is" in python type comparison (#165037)
It is generally recommended to use `is`/`is not` to compare types. Therefore, this series of changes applies that recommendation across the code base, with the aim of eventually enabling the related linter checks.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/165037
Approved by: https://github.com/mlazos
2025-10-10 12:36:50 +00:00

131 lines
3.1 KiB
Python

# mypy: allow-untyped-defs
from .core import ( # type: ignore[attr-defined]
_reify as core_reify,
_unify as core_unify,
reify,
unify,
)
from .dispatch import dispatch
# Names exported by a star-import of this module.
__all__ = ["unifiable", "reify_object", "unify_object"]
def unifiable(cls):
"""Register standard unify and reify operations on class
This uses the type and __dict__ or __slots__ attributes to define the
nature of the term
See Also:
>>> # xdoctest: +SKIP
>>> class A(object):
... def __init__(self, a, b):
... self.a = a
... self.b = b
>>> unifiable(A)
<class 'unification.more.A'>
>>> x = var("x")
>>> a = A(1, 2)
>>> b = A(1, x)
>>> unify(a, b, {})
{~x: 2}
"""
core_unify.add((cls, cls, dict), unify_object) # type: ignore[attr-defined]
core_reify.add((cls, dict), reify_object) # type: ignore[attr-defined]
return cls
#########
# Reify #
#########
def reify_object(o, s):
    """Apply the substitution ``s`` to the attributes of object ``o``.

    Objects exposing ``__slots__`` are handled by ``_reify_object_slots``;
    every other object is handled through its ``__dict__`` by
    ``_reify_object_dict``.

    >>> # xdoctest: +SKIP
    >>> class Foo(object):
    ...     def __init__(self, a, b):
    ...         self.a = a
    ...         self.b = b
    ...
    ...     def __str__(self):
    ...         return "Foo(%s, %s)" % (str(self.a), str(self.b))
    >>> x = var("x")
    >>> f = Foo(1, x)
    >>> print(f)
    Foo(1, ~x)
    >>> print(reify_object(f, {x: 2}))
    Foo(1, 2)
    """
    uses_slots = hasattr(o, "__slots__")
    handler = _reify_object_slots if uses_slots else _reify_object_dict
    return handler(o, s)
def _reify_object_dict(o, s):
    """Reify an object whose attributes live in its ``__dict__``.

    Args:
        o: Object to reify; its ``__dict__`` is read but never modified.
        s: Substitution dict passed through to ``reify``.

    Returns:
        ``o`` itself when reifying its ``__dict__`` changes nothing,
        otherwise a new instance of ``type(o)`` (created without calling
        ``__init__``) whose ``__dict__`` holds the reified values.
    """
    current = o.__dict__
    reified = reify(current, s)
    if reified == current:
        return o
    # Allocate only once a copy is really needed: this avoids a throwaway
    # object and does not fail on types ``object.__new__`` cannot create
    # (e.g. subclasses of ``int``) when nothing has to be substituted.
    clone = object.__new__(type(o))
    clone.__dict__.update(reified)
    return clone
def _reify_object_slots(o, s):
    """Reify an object whose attributes are stored in ``__slots__``.

    Slot names are gathered from every class in ``type(o).__mro__`` so that
    slots declared on base classes are carried over as well. A ``__slots__``
    given as a single string is treated as one name instead of being
    iterated character by character.

    Args:
        o: Object to reify; every collected slot must hold a value.
        s: Substitution dict passed through to ``reify``.

    Returns:
        ``o`` itself when reification changes no slot value, otherwise a new
        instance of ``type(o)`` (created without calling ``__init__``) whose
        slots hold the reified values.

    Raises:
        AttributeError: If a collected slot has no value on ``o``.
    """
    slots = []
    # Walk base classes first so each class's own slots keep declared order.
    for klass in reversed(type(o).__mro__):
        declared = klass.__dict__.get("__slots__", ())
        if isinstance(declared, str):
            declared = (declared,)
        for name in declared:
            # ``__weakref__`` is weak-reference bookkeeping, not term data,
            # and it cannot be assigned on the rebuilt object.
            if name != "__weakref__" and name not in slots:
                slots.append(name)

    attrs = [getattr(o, slot) for slot in slots]
    new_attrs = reify(attrs, s)
    if attrs == new_attrs:
        return o

    newobj = object.__new__(type(o))
    for slot, attr in zip(slots, new_attrs):
        setattr(newobj, slot, attr)
    return newobj
@dispatch(slice, dict)
def _reify(o, s):
    """Reify the ``start``, ``stop`` and ``step`` of a ``slice``.

    The three fields are reified together as one tuple and a new ``slice``
    is built from the results.
    """
    fields = (o.start, o.stop, o.step)
    start, stop, step = reify(fields, s)
    return slice(start, stop, step)
#########
# Unify #
#########
def unify_object(u, v, s):
"""Unify two Python objects
Unifies their type and ``__dict__`` attributes
>>> # xdoctest: +SKIP
>>> class Foo(object):
... def __init__(self, a, b):
... self.a = a
... self.b = b
...
... def __str__(self):
... return "Foo(%s, %s)" % (str(self.a), str(self.b))
>>> x = var("x")
>>> f = Foo(1, x)
>>> g = Foo(1, 2)
>>> unify_object(f, g, {})
{~x: 2}
"""
if type(u) is not type(v):
return False
if hasattr(u, "__slots__"):
return unify(
[getattr(u, slot) for slot in u.__slots__],
[getattr(v, slot) for slot in v.__slots__],
s,
)
else:
return unify(u.__dict__, v.__dict__, s)
@dispatch(slice, slice, dict)
def _unify(u, v, s):
"""Unify a Python ``slice`` object"""
return unify((u.start, u.stop, u.step), (v.start, v.stop, v.step), s)