You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

177 lines
4.3 KiB

5 months ago
""" Generic Rules for SymPy
This file assumes knowledge of Basic and little else.
"""
from sympy.utilities.iterables import sift
from .util import new
# Functions that create rules
def rm_id(isid, new=new):
    """ Build a rule that strips identity elements from an expression's args.

    Parameters
    ==========

    isid : callable
        Predicate ``x -> Bool`` telling whether an argument is an identity.
    new : callable, optional
        Constructor helper, called as ``new(cls, *args)`` to build the result.

    Examples
    ========

    >>> from sympy.strategies import rm_id
    >>> from sympy import Basic, S
    >>> remove_zeros = rm_id(lambda x: x==0)
    >>> remove_zeros(Basic(S(1), S(0), S(2)))
    Basic(1, 2)
    >>> remove_zeros(Basic(S(0), S(0))) # If only identities then we keep one
    Basic(0)

    See Also:
        unpack
    """
    def ident_remove(expr):
        """ Drop args flagged by ``isid``; keep the first one if all are flagged """
        flags = [isid(arg) for arg in expr.args]
        n_ids = sum(flags)

        # Common case: nothing to remove, hand back the very same object.
        if n_ids == 0:
            return expr

        # Every arg is an identity: keep a single representative.
        if n_ids == len(flags):
            return new(expr.__class__, expr.args[0])

        # Mixed case: rebuild from the non-identity args only.
        kept = [arg for arg, flagged in zip(expr.args, flags) if not flagged]
        return new(expr.__class__, *kept)

    return ident_remove
def glom(key, count, combine):
    """ Build a rule that merges args sharing the same key into one arg.

    Parameters
    ==========

    key : callable
        Maps an arg to the value it is grouped by.
    count : callable
        Maps an arg to the amount it contributes to its group's total.
    combine : callable
        Called as ``combine(total, key_value)`` to produce the merged arg.

    Examples
    ========

    >>> from sympy.strategies import glom
    >>> from sympy import Add
    >>> from sympy.abc import x

    >>> key     = lambda x: x.as_coeff_Mul()[1]
    >>> count   = lambda x: x.as_coeff_Mul()[0]
    >>> combine = lambda cnt, arg: cnt * arg
    >>> rl = glom(key, count, combine)

    >>> rl(Add(x, -x, 3*x, 2, 3, evaluate=False))
    3*x + 5

    How the three callables fit together:

    >>> key(2*x)
    x
    >>> count(2*x)
    2
    >>> combine(2, x)
    2*x
    """
    def conglomerate(expr):
        """ Merge args with equal keys, e.g. x + x -> 2x """
        # First pass: total up the counts of each group of like args.
        totals = {}
        for group_key, members in sift(expr.args, key).items():
            totals[group_key] = sum(count(member) for member in members)

        # Second pass: turn each (key, total) pair back into a single arg.
        newargs = [combine(total, group_key)
                   for group_key, total in totals.items()]

        # Nothing was merged: return the original object untouched.
        if set(newargs) == set(expr.args):
            return expr
        return new(type(expr), *newargs)

    return conglomerate
def sort(key, new=new):
    """ Build a rule that reorders an expression's args by a key function.

    Parameters
    ==========

    key : callable
        Sort key applied to each arg.
    new : callable, optional
        Constructor helper, called as ``new(cls, *args)`` to build the result.

    Examples
    ========

    >>> from sympy.strategies import sort
    >>> from sympy import Basic, S
    >>> sort_rl = sort(str)
    >>> sort_rl(Basic(S(3), S(1), S(2)))
    Basic(1, 2, 3)
    """
    def sort_rl(expr):
        """ Rebuild ``expr`` with its args in ascending ``key`` order """
        ordered = list(expr.args)
        ordered.sort(key=key)
        return new(expr.__class__, *ordered)

    return sort_rl
def distribute(A, B):
    """ Build a rule turning an A that contains Bs into a B of As.

    ``A`` and ``B`` are container types.  Only the first ``B`` found among
    the args is distributed over; an expression with no ``B`` arg is
    returned unchanged.

    >>> from sympy.strategies import distribute
    >>> from sympy import Add, Mul, symbols
    >>> x, y = symbols('x,y')
    >>> dist = distribute(Mul, Add)
    >>> expr = Mul(2, x+y, evaluate=False)
    >>> expr
    2*(x + y)
    >>> dist(expr)
    2*x + 2*y
    """
    def distribute_rl(expr):
        """ Distribute the surrounding args over the first ``B`` arg """
        args = expr.args
        for pos, candidate in enumerate(args):
            if not isinstance(candidate, B):
                continue
            before, after = args[:pos], args[pos + 1:]
            # One A per term of the B, each with that term in the B's old slot.
            return B(*[A(*(before + (term,) + after))
                       for term in candidate.args])
        return expr

    return distribute_rl
def subs(a, b):
    """ Build a rule that swaps an expression equal to ``a`` for ``b``.

    The match is exact: the rule compares the whole expression with ``a``
    using ``==`` and returns anything else unchanged.
    """
    def subs_rl(expr):
        """ Return ``b`` when ``expr == a``, otherwise ``expr`` itself """
        return b if expr == a else expr

    return subs_rl
# Functions that are rules
def unpack(expr):
    """ Rule that replaces a single-arg expression by that lone arg.

    Expressions with zero args, or with more than one, are returned as is.

    >>> from sympy.strategies import unpack
    >>> from sympy import Basic, S
    >>> unpack(Basic(S(2)))
    2
    """
    args = expr.args
    return args[0] if len(args) == 1 else expr
def flatten(expr, new=new):
    """ Flatten T(a, b, T(c, d), T2(e)) to T(a, b, c, d, T2(e))

    Only one level is spliced in: args whose class equals the class of
    ``expr`` contribute their own args in place, every other arg is kept
    as is.  The result is built with ``new(cls, *args)``.
    """
    cls = expr.__class__
    flat = []
    for arg in expr.args:
        # Same-class children are replaced by their own args, in order.
        flat.extend(arg.args if arg.__class__ == cls else (arg,))
    return new(cls, *flat)
def rebuild(expr):
    """ Rebuild a SymPy tree.

    Explanation
    ===========

    Walks the expression tree recursively and re-invokes the constructor
    (``expr.func``) of every non-atomic node on its rebuilt args; atoms are
    returned unchanged.  This forces canonicalization and removes ugliness
    introduced by the use of Basic.__new__
    """
    if not expr.is_Atom:
        return expr.func(*[rebuild(arg) for arg in expr.args])
    return expr