mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
Reproduce command: ```bash ghstack checkout https://github.com/pytorch/pytorch/pull/138443 git checkout HEAD~1 torch/ lintrunner -a --take "PYFMT" --all-files ``` Pull Request resolved: https://github.com/pytorch/pytorch/pull/138443 Approved by: https://github.com/ezyang
420 lines
10 KiB
Python
420 lines
10 KiB
Python
# mypy: allow-untyped-defs
|
|
import collections
|
|
import operator
|
|
from collections.abc import Mapping
|
|
from functools import reduce
|
|
|
|
|
|
# Names exported by a star-import of this module. ``getter``, ``groupby`` and
# ``first`` are also defined in this module but are not listed here.
__all__ = [
    "merge",
    "merge_with",
    "valmap",
    "keymap",
    "itemmap",
    "valfilter",
    "keyfilter",
    "itemfilter",
    "assoc",
    "dissoc",
    "assoc_in",
    "update_in",
    "get_in",
]
|
|
|
|
|
|
def _get_factory(f, kwargs):
    """Extract the ``factory`` option from a ``**kwargs`` dictionary.

    ``kwargs`` is modified in place: the ``"factory"`` entry (if present) is
    popped and returned, defaulting to ``dict``.  Any other entry left in
    ``kwargs`` is reported the same way Python reports an unknown keyword
    argument, using ``f.__name__`` as the function name in the message.

    Raises:
        TypeError: if ``kwargs`` holds anything other than ``factory``.
    """
    factory = kwargs.pop("factory", dict)
    if not kwargs:
        return factory
    raise TypeError(
        f"{f.__name__}() got an unexpected keyword argument '{kwargs.popitem()[0]}'"
    )
|
|
|
|
|
|
def merge(*dicts, **kwargs):
    """Combine several dictionaries into a single new one.

    The dictionaries may be passed as separate positional arguments or as a
    single iterable of dictionaries.  When the same key appears more than
    once, the value from the right-most dictionary wins.

    The only accepted keyword argument is ``factory`` (default ``dict``),
    the callable used to create the returned container.

    >>> merge({1: "one"}, {2: "two"})
    {1: 'one', 2: 'two'}

    >>> merge({1: 2, 3: 4}, {3: 3, 4: 4})
    {1: 2, 3: 3, 4: 4}

    See Also:
        merge_with
    """
    # A lone argument that is not itself a mapping is taken to be an
    # iterable of mappings.
    if len(dicts) == 1 and not isinstance(dicts[0], Mapping):
        (dicts,) = dicts
    factory = _get_factory(merge, kwargs)

    merged = factory()
    for mapping in dicts:
        merged.update(mapping)
    return merged
|
|
|
|
|
|
def merge_with(func, *dicts, **kwargs):
    """Merge dictionaries, reducing the values of each key with ``func``.

    For every key, the values found across all dictionaries are gathered
    into a list (in the order the dictionaries were given) and the result
    maps that key to ``func(list_of_values)``.

    The dictionaries may be passed as separate positional arguments or as a
    single iterable of dictionaries.  The only accepted keyword argument is
    ``factory`` (default ``dict``).

    >>> merge_with(sum, {1: 1, 2: 2}, {1: 10, 2: 20})
    {1: 11, 2: 22}

    >>> merge_with(first, {1: 1, 2: 2}, {2: 20, 3: 30})  # doctest: +SKIP
    {1: 1, 2: 2, 3: 30}

    See Also:
        merge
    """
    if len(dicts) == 1 and not isinstance(dicts[0], Mapping):
        (dicts,) = dicts
    factory = _get_factory(merge_with, kwargs)

    # First pass: key -> list of every value seen for that key.
    collected = factory()
    for mapping in dicts:
        for key, value in mapping.items():
            if key in collected:
                collected[key].append(value)
            else:
                collected[key] = [value]
    # Second pass: collapse each list with the user-supplied function.
    return valmap(func, collected, factory)
|
|
|
|
|
|
def valmap(func, d, factory=dict):
    """Return a new dictionary with ``func`` applied to every value of ``d``.

    Keys are carried over unchanged; ``d`` itself is not modified.  The
    result container is created by calling ``factory()``.

    >>> bills = {"Alice": [20, 15, 30], "Bob": [10, 35]}
    >>> valmap(sum, bills)  # doctest: +SKIP
    {'Alice': 65, 'Bob': 45}

    See Also:
        keymap
        itemmap
    """
    transformed = factory()
    pairs = zip(d.keys(), d.values())
    transformed.update((key, func(value)) for key, value in pairs)
    return transformed
|
|
|
|
|
|
def keymap(func, d, factory=dict):
    """Return a new dictionary with ``func`` applied to every key of ``d``.

    Values are carried over unchanged; ``d`` itself is not modified.  The
    result container is created by calling ``factory()``.

    >>> bills = {"Alice": [20, 15, 30], "Bob": [10, 35]}
    >>> keymap(str.lower, bills)  # doctest: +SKIP
    {'alice': [20, 15, 30], 'bob': [10, 35]}

    See Also:
        valmap
        itemmap
    """
    transformed = factory()
    pairs = zip(d.keys(), d.values())
    transformed.update((func(key), value) for key, value in pairs)
    return transformed
|
|
|
|
|
|
def itemmap(func, d, factory=dict):
    """Return a new dictionary built by applying ``func`` to every item.

    ``func`` receives each ``(key, value)`` pair of ``d`` and must return an
    iterable of two elements, which become the key and value of the result.
    ``d`` itself is not modified.

    >>> accountids = {"Alice": 10, "Bob": 20}
    >>> itemmap(reversed, accountids)  # doctest: +SKIP
    {10: "Alice", 20: "Bob"}

    See Also:
        keymap
        valmap
    """
    transformed = factory()
    transformed.update(func(pair) for pair in d.items())
    return transformed
|
|
|
|
|
|
def valfilter(predicate, d, factory=dict):
    """Keep only the items of ``d`` whose value satisfies ``predicate``.

    Returns a new container created by ``factory()``; ``d`` is not modified.

    >>> iseven = lambda x: x % 2 == 0
    >>> d = {1: 2, 2: 3, 3: 4, 4: 5}
    >>> valfilter(iseven, d)
    {1: 2, 3: 4}

    See Also:
        keyfilter
        itemfilter
        valmap
    """
    kept = factory()
    for key, value in d.items():
        if not predicate(value):
            continue
        kept[key] = value
    return kept
|
|
|
|
|
|
def keyfilter(predicate, d, factory=dict):
    """Keep only the items of ``d`` whose key satisfies ``predicate``.

    Returns a new container created by ``factory()``; ``d`` is not modified.

    >>> iseven = lambda x: x % 2 == 0
    >>> d = {1: 2, 2: 3, 3: 4, 4: 5}
    >>> keyfilter(iseven, d)
    {2: 3, 4: 5}

    See Also:
        valfilter
        itemfilter
        keymap
    """
    kept = factory()
    for key, value in d.items():
        if not predicate(key):
            continue
        kept[key] = value
    return kept
|
|
|
|
|
|
def itemfilter(predicate, d, factory=dict):
    """Keep only the items of ``d`` for which ``predicate(item)`` is true.

    ``predicate`` is called with each ``(key, value)`` pair.  Returns a new
    container created by ``factory()``; ``d`` is not modified.

    >>> def isvalid(item):
    ...     k, v = item
    ...     return k % 2 == 0 and v < 4

    >>> d = {1: 2, 2: 3, 3: 4, 4: 5}
    >>> itemfilter(isvalid, d)
    {2: 3}

    See Also:
        keyfilter
        valfilter
        itemmap
    """
    kept = factory()
    for pair in d.items():
        if not predicate(pair):
            continue
        key, value = pair
        kept[key] = value
    return kept
|
|
|
|
|
|
def assoc(d, key, value, factory=dict):
    """Return a copy of ``d`` in which ``key`` maps to ``value``.

    The copy is a new container created by ``factory()`` and filled from
    ``d``; the input dictionary is left untouched.

    >>> assoc({"x": 1}, "x", 2)
    {'x': 2}
    >>> assoc({"x": 1}, "y", 3)  # doctest: +SKIP
    {'x': 1, 'y': 3}
    """
    updated = factory()
    updated.update(d)
    updated[key] = value
    return updated
|
|
|
|
|
|
def dissoc(d, *keys, **kwargs):
    """Return a new dict with the given key(s) removed.

    The result is a new container created by ``factory()``; the input
    dictionary is not modified.  Keys that are not present in ``d`` are
    ignored.  The remaining items keep the order they have in ``d``.

    Args:
        d: the source mapping.
        *keys: keys to leave out of the result.
        **kwargs: only ``factory`` (default ``dict``) is accepted.

    Raises:
        TypeError: if a keyword argument other than ``factory`` is given.

    >>> dissoc({"x": 1, "y": 2}, "y")
    {'x': 1}
    >>> dissoc({"x": 1, "y": 2}, "y", "x")
    {}
    >>> dissoc({"x": 1}, "y")  # Ignores missing keys
    {'x': 1}
    """
    factory = _get_factory(dissoc, kwargs)
    d2 = factory()

    if len(keys) < len(d) * 0.6:
        # Few removals relative to the size of d: copy everything, then
        # delete the unwanted keys.
        d2.update(d)
        for key in keys:
            if key in d2:
                del d2[key]
    else:
        # Many removals: copy only the survivors.  Walk ``d`` itself (rather
        # than a set of its keys) so the result keeps d's iteration order
        # instead of an arbitrary, hash-dependent one.
        dropped = set(keys)
        for k in d:
            if k not in dropped:
                d2[k] = d[k]
    return d2
|
|
|
|
|
|
def assoc_in(d, keys, value, factory=dict):
    """Return a copy of ``d`` with ``value`` stored at the nested path ``keys``.

    This delegates to ``update_in`` with a function that ignores the old
    value and always yields ``value``.

    >>> purchase = {
    ...     "name": "Alice",
    ...     "order": {"items": ["Apple", "Orange"], "costs": [0.50, 1.25]},
    ...     "credit card": "5555-1234-1234-1234",
    ... }
    >>> assoc_in(purchase, ["order", "costs"], [0.25, 1.00])  # doctest: +SKIP
    {'credit card': '5555-1234-1234-1234',
     'name': 'Alice',
     'order': {'costs': [0.25, 1.00], 'items': ['Apple', 'Orange']}}
    """

    def _always_value(_previous):
        return value

    return update_in(d, keys, _always_value, value, factory)
|
|
|
|
|
|
def update_in(d, keys, func, default=None, factory=dict):
    """Apply ``func`` to a value inside a (possibly nested) dictionary.

    Args:
        d: dictionary on which to operate.
        keys: iterable giving the path to the value to change; it must
            yield at least one key (the first key is fetched with ``next``).
        func: function applied to the value found at that path.
        default: value passed to ``func`` when the final key is absent.
        factory: callable creating each new dictionary level.

    If ``keys == [k0, .., kX]`` and ``d[k0]..[kX] == v``, the result is a copy
    of ``d`` with ``v`` replaced by ``func(v)``.  Every dictionary along the
    path is copied, so the original is not mutated.

    When a key on the path is missing, new nested dictionaries are created
    down to the depth given by ``keys`` and the innermost value is set to
    ``func(default)``.

    >>> inc = lambda x: x + 1
    >>> update_in({"a": 0}, ["a"], inc)
    {'a': 1}

    >>> transaction = {
    ...     "name": "Alice",
    ...     "purchase": {"items": ["Apple", "Orange"], "costs": [0.50, 1.25]},
    ...     "credit card": "5555-1234-1234-1234",
    ... }
    >>> update_in(transaction, ["purchase", "costs"], sum)  # doctest: +SKIP
    {'credit card': '5555-1234-1234-1234',
     'name': 'Alice',
     'purchase': {'costs': 1.75, 'items': ['Apple', 'Orange']}}

    >>> # updating a value when k0 is not in d
    >>> update_in({}, [1, 2, 3], str, default="bar")
    {1: {2: {3: 'bar'}}}
    >>> update_in({1: "foo"}, [2, 3, 4], inc, 0)
    {1: 'foo', 2: {3: {4: 1}}}
    """
    path = iter(keys)
    current_key = next(path)

    # ``result`` is the top-level copy; ``cursor`` always points at the
    # deepest copy built so far, while ``d`` tracks the matching level of
    # the source structure.
    result = factory()
    result.update(d)
    cursor = result

    for upcoming_key in path:
        level_copy = factory()
        if current_key in d:
            d = d[current_key]
            level_copy.update(d)
        else:
            # Path does not exist in the source: continue with a fresh
            # empty level for both the source view and the copy.
            d = level_copy
        cursor[current_key] = level_copy
        cursor = level_copy
        current_key = upcoming_key

    previous = d[current_key] if current_key in d else default
    cursor[current_key] = func(previous)
    return result
|
|
|
|
|
|
def get_in(keys, coll, default=None, no_default=False):
    """Return ``coll[i0][i1]...[iX]`` where ``[i0, i1, ..., iX] == keys``.

    The lookup descends one key at a time using item access, so it works on
    nested dictionaries, lists, and other indexable containers.

    If any step raises ``KeyError``, ``IndexError`` or ``TypeError``, the
    function returns ``default``; when ``no_default`` is true the exception
    is re-raised instead.

    >>> transaction = {
    ...     "name": "Alice",
    ...     "purchase": {"items": ["Apple", "Orange"], "costs": [0.50, 1.25]},
    ...     "credit card": "5555-1234-1234-1234",
    ... }
    >>> get_in(["purchase", "items", 0], transaction)
    'Apple'
    >>> get_in(["name"], transaction)
    'Alice'
    >>> get_in(["purchase", "total"], transaction)
    >>> get_in(["purchase", "items", "apple"], transaction)
    >>> get_in(["purchase", "items", 10], transaction)
    >>> get_in(["purchase", "total"], transaction, 0)
    0
    >>> get_in(["y"], {}, no_default=True)
    Traceback (most recent call last):
        ...
    KeyError: 'y'

    See Also:
        itertoolz.get
        operator.getitem
    """
    try:
        node = coll
        for key in keys:
            node = node[key]
    except (KeyError, IndexError, TypeError):
        if no_default:
            raise
        return default
    return node
|
|
|
|
|
|
def getter(index):
    """Build a callable that extracts ``index`` from its argument.

    * A non-list ``index`` yields ``operator.itemgetter(index)``.
    * A list ``index`` yields a callable returning a tuple with one entry
      per listed index: an empty tuple for ``[]``, a 1-tuple for a single
      index, and ``operator.itemgetter(*index)`` otherwise.
    """
    if not isinstance(index, list):
        return operator.itemgetter(index)
    if len(index) == 1:
        # itemgetter with one argument returns a bare value, so wrap the
        # single lookup in a tuple by hand.
        only = index[0]
        return lambda x: (x[only],)
    if index:
        return operator.itemgetter(*index)
    return lambda x: ()
|
|
|
|
|
|
def groupby(key, seq):
    """Group the elements of ``seq`` by the result of a key function.

    Returns a plain ``dict`` mapping each key to the list of elements that
    produced it, in the order they were encountered.

    >>> names = ["Alice", "Bob", "Charlie", "Dan", "Edith", "Frank"]
    >>> groupby(len, names)  # doctest: +SKIP
    {3: ['Bob', 'Dan'], 5: ['Alice', 'Edith', 'Frank'], 7: ['Charlie']}

    >>> iseven = lambda x: x % 2 == 0
    >>> groupby(iseven, [1, 2, 3, 4, 5, 6, 7, 8])  # doctest: +SKIP
    {False: [1, 3, 5, 7], True: [2, 4, 6, 8]}

    A ``key`` that is not callable is turned into an item lookup via
    ``getter``, so elements are grouped on a member.

    >>> groupby(
    ...     "gender",
    ...     [
    ...         {"name": "Alice", "gender": "F"},
    ...         {"name": "Bob", "gender": "M"},
    ...         {"name": "Charlie", "gender": "M"},
    ...     ],
    ... )  # doctest:+SKIP
    {'F': [{'gender': 'F', 'name': 'Alice'}],
     'M': [{'gender': 'M', 'name': 'Bob'},
           {'gender': 'M', 'name': 'Charlie'}]}

    Not to be confused with ``itertools.groupby``

    See Also:
        countby
    """
    if not callable(key):
        key = getter(key)
    buckets = collections.defaultdict(list)  # type: ignore[var-annotated]
    for element in seq:
        buckets[key(element)].append(element)
    # Hand back an ordinary dict so callers do not get defaultdict semantics.
    return dict(buckets)
|
|
|
|
|
|
def first(seq):
    """Return the first element produced by iterating over ``seq``.

    Raises ``StopIteration`` when ``seq`` yields nothing.

    >>> first("ABC")
    'A'
    """
    iterator = iter(seq)
    return next(iterator)
|