You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

53 lines
1.6 KiB

""" Optimizations of the expression tree representation for better CSE
opportunities.
"""
from sympy.core import Add, Basic, Mul
from sympy.core.singleton import S
from sympy.core.sorting import default_sort_key
from sympy.core.traversal import preorder_traversal
def sub_pre(e):
    """ Rewrite each Add in ``e`` from which -1 can be extracted as an
    unevaluated product with an explicit -1 factor, e.g. y - x becomes
    -1*(x - y).

    Adds that still appear in the expression after that first pass are
    marked with a leading ``1, -1`` pair instead (y - x -> 1*-1*(x - y)).
    """
    # First pass: map every sign-extractable Add, A, to the unevaluated
    # product -1*(-A).
    first_pass = {}
    skipped = set()
    for term in e.atoms(Add):
        if not term.could_extract_minus_sign():
            continue
        negated = -term
        if negated.is_Mul:  # e.g. MatExpr
            # Negation produced a Mul; leave this Add alone in both passes.
            skipped.add(term)
        else:
            first_pass[term] = Mul._from_args([S.NegativeOne, negated])

    e = e.xreplace(first_pass)

    # The second pass calls .atoms(), so it only runs on Basic instances.
    if not isinstance(e, Basic):
        return e

    # Second pass: handle the Adds that persist after the first replacement.
    # Visit them in default_sort_key order.
    second_pass = {}
    for term in sorted(e.atoms(Add), key=default_sort_key):
        if term in skipped:
            continue
        if term in first_pass:
            # Already has a first-pass replacement; reuse it.
            second_pass[term] = first_pass[term]
        elif term.could_extract_minus_sign():
            # Tag with a leading 1, -1: y - x -> 1*-1*(x - y)
            second_pass[term] = Mul._from_args(
                [S.One, S.NegativeOne, -term])
    return e.xreplace(second_pass)
def sub_post(e):
    """ Collapse every Mul whose first two args are 1 and -1 (i.e.
    1*-1*x) into -x.
    """
    # Gather all (marked node, negated remainder) pairs up front, in
    # preorder, before touching the expression.
    pending = [
        (node, -Mul._from_args(node.args[2:]))
        for node in preorder_traversal(e)
        if isinstance(node, Mul)
        and node.args[0] is S.One
        and node.args[1] is S.NegativeOne
    ]
    # Apply the substitutions one at a time, in traversal order, each on the
    # result of the previous one.
    for marked, collapsed in pending:
        e = e.xreplace({marked: collapsed})
    return e