You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
53 lines
1.6 KiB
53 lines
1.6 KiB
""" Optimizations of the expression tree representation for better CSE
opportunities.
"""
|
|
from sympy.core import Add, Basic, Mul
|
|
from sympy.core.singleton import S
|
|
from sympy.core.sorting import default_sort_key
|
|
from sympy.core.traversal import preorder_traversal
|
|
|
|
|
|
def sub_pre(e):
    """ Rewrite Adds from which -1 can be extracted as explicit products.

    First pass: every Add ``A`` in *e* with ``A.could_extract_minus_sign()``
    (e.g. ``y - x``) is replaced by the unevaluated product ``-1*(-A)``
    (e.g. ``-1*(x - y)``). An Add whose negation ``-A`` is itself a Mul
    (the original note cites MatExpr) is skipped, both here and in the
    second pass.

    Second pass (only when the rewritten expression is a ``Basic``): the
    Adds still present are visited in ``default_sort_key`` order.
    - An Add already handled in the first pass gets the same ``-1*(-A)``
      replacement again.
    - Any other Add from which -1 can be extracted is tagged with a
      leading ``1, -1`` pair, e.g. ``y - x -> 1*-1*(x - y)``.
    ``sub_post`` replaces such ``1*-1*x`` products with ``-x``.

    Parameters
    ==========

    e : expression
        Must provide ``atoms`` and ``xreplace``; normally a SymPy ``Basic``.

    Returns
    =======

    The rewritten expression.
    """
    # Pass 1: build the Add -> -1*(-Add) substitution.
    candidates = [x for x in e.atoms(Add) if x.could_extract_minus_sign()]
    first_pass = {}
    skipped = set()
    for add in candidates:
        negated = -add
        if not negated.is_Mul:
            first_pass[add] = Mul._from_args([S.NegativeOne, negated])
        else:
            # e.g. MatExpr: negation already yields a Mul; leave it alone.
            skipped.add(add)

    e = e.xreplace(first_pass)

    # NOTE(review): the guard presumably exists for inputs that provide
    # atoms/xreplace without being Basic (e.g. mutable matrices) -- confirm.
    if not isinstance(e, Basic):
        return e

    # Pass 2: revisit the Adds that remain after the first substitution.
    second_pass = {}
    for add in sorted(e.atoms(Add), key=default_sort_key):
        if add in skipped:
            continue
        replacement = first_pass.get(add)
        if replacement is None and add.could_extract_minus_sign():
            # Leading 1, -1 marks this product for sub_post.
            replacement = Mul._from_args([S.One, S.NegativeOne, -add])
        if replacement is not None:
            second_pass[add] = replacement
    return e.xreplace(second_pass)
|
|
|
|
|
|
def sub_post(e):
    """ Replace marked products ``1*-1*x`` with ``-x``.

    Every Mul node of *e* whose first argument is ``S.One`` and whose second
    argument is ``S.NegativeOne`` is replaced by the negation of the product
    of its remaining arguments. This removes the leading ``1, -1`` marker
    that ``sub_pre`` attaches.

    All marked nodes are collected in pre-order before any rewriting takes
    place. Their replacements are then applied one node at a time, in that
    same order.

    Parameters
    ==========

    e : expression
        The expression to clean up; must support ``xreplace``.

    Returns
    =======

    The expression with every marked product rewritten.
    """
    rewrites = [
        (node, -Mul._from_args(node.args[2:]))
        for node in preorder_traversal(e)
        if isinstance(node, Mul)
        and node.args[0] is S.One
        and node.args[1] is S.NegativeOne
    ]
    for old, new in rewrites:
        e = e.xreplace({old: new})
    return e
|