Files
pytorch/torch/fx/experimental/unification/utils.py
Yuanyuan Chen a029675f6f More ruff SIM fixes (#164695)
This PR applies ruff `SIM` rules to more files. Most changes are about simplifying `dict.get` because `None` is already the default value.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/164695
Approved by: https://github.com/ezyang
2025-10-09 03:24:50 +00:00

109 lines
2.9 KiB
Python

# mypy: allow-untyped-defs
# Names exported by a star-import of this module; ``_toposort`` is not listed.
__all__ = ["hashable", "transitive_get", "raises", "reverse_dict", "xfail", "freeze"]
def hashable(x):
    """Report whether ``x`` can be hashed.

    Returns ``True`` when ``hash(x)`` succeeds and ``False`` when it raises
    ``TypeError``.  Any other exception raised while hashing propagates to
    the caller.
    """
    try:
        hash(x)
    except TypeError:
        return False
    else:
        return True
def transitive_get(key, d):
    """Follow ``key`` through ``d`` until it no longer maps to anything.

    >>> d = {1: 2, 2: 3, 3: 4}
    >>> d.get(1)
    2
    >>> transitive_get(1, d)
    4

    The walk stops at the first value that is either unhashable or not a key
    of ``d``, and that value is returned.  A ``key`` that is not in ``d`` is
    returned unchanged.

    Note: no cycle detection is performed, so a mapping that loops back on
    itself (for example ``{1: 2, 2: 1}``) never terminates.
    """
    current = key
    while True:
        # Unhashable values cannot be looked up, so they end the chain too.
        if not hashable(current) or current not in d:
            return current
        current = d[current]
def raises(err, lamda):  # codespell:ignore lamda
    """Tell whether calling ``lamda`` with no arguments raises ``err``.

    Returns ``True`` if the call raises ``err`` (an exception class or tuple
    of classes, as accepted by ``except``) and ``False`` if it completes
    normally.  Exceptions that do not match ``err`` are not caught.
    """
    try:
        lamda()  # codespell:ignore lamda
    except err:
        return True
    else:
        return False
# Taken from theano/theano/gof/sched.py
# Avoids licensing issues because this was written by Matthew Rocklin
def _toposort(edges):
"""Topological sort algorithm by Kahn [1] - O(nodes + vertices)
inputs:
edges - a dict of the form {a: {b, c}} where b and c depend on a
outputs:
L - an ordered list of nodes that satisfy the dependencies of edges
>>> # xdoctest: +SKIP
>>> _toposort({1: (2, 3), 2: (3,)})
[1, 2, 3]
Closely follows the wikipedia page [2]
[1] Kahn, Arthur B. (1962), "Topological sorting of large networks",
Communications of the ACM
[2] http://en.wikipedia.org/wiki/Toposort#Algorithms
"""
incoming_edges = reverse_dict(edges)
incoming_edges = {k: set(val) for k, val in incoming_edges.items()}
S = {v for v in edges if v not in incoming_edges}
L = []
while S:
n = S.pop()
L.append(n)
for m in edges.get(n, ()):
assert n in incoming_edges[m]
incoming_edges[m].remove(n)
if not incoming_edges[m]:
S.add(m)
if any(incoming_edges.get(v) for v in edges):
raise ValueError("Input has cycles")
return L
def reverse_dict(d):
    """Invert a dependency dict so that each value points back at its keys.

    >>> d = {"a": (1, 2), "b": (2, 3), "c": ()}
    >>> reverse_dict(d)  # doctest: +SKIP
    {1: ('a',), 2: ('a', 'b'), 3: ('b',)}

    Keys of ``d`` whose value is empty do not appear in the result.

    :note: the key order of the result and the order inside each tuple both
        follow the iteration order of ``d``; do not rely on them unless that
        order is itself well defined.
    """
    collected = {}  # type: ignore[var-annotated]
    for source in d:
        for target in d[source]:
            collected.setdefault(target, []).append(source)
    # Hand back immutable tuples, one per target, in first-seen order.
    return {target: tuple(sources) for target, sources in collected.items()}
def xfail(func):
    """Run ``func`` and require that it fails.

    Returns ``None`` when ``func()`` raises any ``Exception``.  If ``func()``
    completes without raising, an ``Exception`` with the message
    ``"XFailed test passed"`` is raised.
    """
    try:
        func()
    except Exception:
        # Failing is the expected outcome.
        return
    # Raised outside the ``try`` so the handler above cannot swallow it.
    raise Exception("XFailed test passed")  # pragma:nocover # noqa: TRY002
def freeze(d):
    """Recursively convert containers into a hashable form.

    >>> freeze(1)
    1
    >>> freeze([1, 2])
    (1, 2)
    >>> freeze({1: 2})  # doctest: +SKIP
    frozenset([(1, 2)])

    A dict becomes a frozenset of its frozen ``(key, value)`` pairs, a set
    becomes a frozenset of frozen members, and a tuple or list becomes a tuple
    of frozen elements.  Any other object is returned unchanged.
    """
    if isinstance(d, dict):
        frozen = frozenset(freeze(pair) for pair in d.items())
    elif isinstance(d, set):
        frozen = frozenset(freeze(member) for member in d)
    elif isinstance(d, (tuple, list)):
        frozen = tuple(freeze(element) for element in d)
    else:
        frozen = d
    return frozen