You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
106 lines
3.0 KiB
106 lines
3.0 KiB
5 months ago
|
from collections import OrderedDict
|
||
|
|
||
|
|
||
|
def expand_tuples(L):
|
||
|
"""
|
||
|
>>> from sympy.multipledispatch.utils import expand_tuples
|
||
|
>>> expand_tuples([1, (2, 3)])
|
||
|
[(1, 2), (1, 3)]
|
||
|
|
||
|
>>> expand_tuples([1, 2])
|
||
|
[(1, 2)]
|
||
|
"""
|
||
|
if not L:
|
||
|
return [()]
|
||
|
elif not isinstance(L[0], tuple):
|
||
|
rest = expand_tuples(L[1:])
|
||
|
return [(L[0],) + t for t in rest]
|
||
|
else:
|
||
|
rest = expand_tuples(L[1:])
|
||
|
return [(item,) + t for t in rest for item in L[0]]
|
||
|
|
||
|
|
||
|
# Taken from theano/theano/gof/sched.py
|
||
|
# Avoids licensing issues because this was written by Matthew Rocklin
|
||
|
def _toposort(edges):
|
||
|
""" Topological sort algorithm by Kahn [1] - O(nodes + vertices)
|
||
|
|
||
|
inputs:
|
||
|
edges - a dict of the form {a: {b, c}} where b and c depend on a
|
||
|
outputs:
|
||
|
L - an ordered list of nodes that satisfy the dependencies of edges
|
||
|
|
||
|
>>> from sympy.multipledispatch.utils import _toposort
|
||
|
>>> _toposort({1: (2, 3), 2: (3, )})
|
||
|
[1, 2, 3]
|
||
|
|
||
|
Closely follows the wikipedia page [2]
|
||
|
|
||
|
[1] Kahn, Arthur B. (1962), "Topological sorting of large networks",
|
||
|
Communications of the ACM
|
||
|
[2] https://en.wikipedia.org/wiki/Toposort#Algorithms
|
||
|
"""
|
||
|
incoming_edges = reverse_dict(edges)
|
||
|
incoming_edges = {k: set(val) for k, val in incoming_edges.items()}
|
||
|
S = OrderedDict.fromkeys(v for v in edges if v not in incoming_edges)
|
||
|
L = []
|
||
|
|
||
|
while S:
|
||
|
n, _ = S.popitem()
|
||
|
L.append(n)
|
||
|
for m in edges.get(n, ()):
|
||
|
assert n in incoming_edges[m]
|
||
|
incoming_edges[m].remove(n)
|
||
|
if not incoming_edges[m]:
|
||
|
S[m] = None
|
||
|
if any(incoming_edges.get(v, None) for v in edges):
|
||
|
raise ValueError("Input has cycles")
|
||
|
return L
|
||
|
|
||
|
|
||
|
def reverse_dict(d):
|
||
|
"""Reverses direction of dependence dict
|
||
|
|
||
|
>>> d = {'a': (1, 2), 'b': (2, 3), 'c':()}
|
||
|
>>> reverse_dict(d) # doctest: +SKIP
|
||
|
{1: ('a',), 2: ('a', 'b'), 3: ('b',)}
|
||
|
|
||
|
:note: dict order are not deterministic. As we iterate on the
|
||
|
input dict, it make the output of this function depend on the
|
||
|
dict order. So this function output order should be considered
|
||
|
as undeterministic.
|
||
|
|
||
|
"""
|
||
|
result = {}
|
||
|
for key in d:
|
||
|
for val in d[key]:
|
||
|
result[val] = result.get(val, ()) + (key, )
|
||
|
return result
|
||
|
|
||
|
|
||
|
# Taken from toolz
|
||
|
# Avoids licensing issues because this version was authored by Matthew Rocklin
|
||
|
def groupby(func, seq):
|
||
|
""" Group a collection by a key function
|
||
|
|
||
|
>>> from sympy.multipledispatch.utils import groupby
|
||
|
>>> names = ['Alice', 'Bob', 'Charlie', 'Dan', 'Edith', 'Frank']
|
||
|
>>> groupby(len, names) # doctest: +SKIP
|
||
|
{3: ['Bob', 'Dan'], 5: ['Alice', 'Edith', 'Frank'], 7: ['Charlie']}
|
||
|
|
||
|
>>> iseven = lambda x: x % 2 == 0
|
||
|
>>> groupby(iseven, [1, 2, 3, 4, 5, 6, 7, 8]) # doctest: +SKIP
|
||
|
{False: [1, 3, 5, 7], True: [2, 4, 6, 8]}
|
||
|
|
||
|
See Also:
|
||
|
``countby``
|
||
|
"""
|
||
|
|
||
|
d = {}
|
||
|
for item in seq:
|
||
|
key = func(item)
|
||
|
if key not in d:
|
||
|
d[key] = []
|
||
|
d[key].append(item)
|
||
|
return d
|