You cannot select more than 25 topics.
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
135 lines
4.5 KiB
135 lines
4.5 KiB
from sympy.printing.dot import (purestr, styleof, attrprint, dotnode,
|
|
dotedges, dotprint)
|
|
from sympy.core.basic import Basic
|
|
from sympy.core.expr import Expr
|
|
from sympy.core.numbers import (Float, Integer)
|
|
from sympy.core.singleton import S
|
|
from sympy.core.symbol import (Symbol, symbols)
|
|
from sympy.printing.repr import srepr
|
|
from sympy.abc import x
|
|
|
|
|
|
def test_purestr():
    """purestr gives a string form; with_args=True also returns the args' strings."""
    pair = Basic(S(1), S(2))
    pair_str = "Basic(Integer(1), Integer(2))"
    float_str = "Float('2.0', precision=53)"

    # Plain form: a single string.
    assert purestr(Symbol('x')) == "Symbol('x')"
    assert purestr(pair) == pair_str
    assert purestr(Float(2)) == float_str

    # with_args=True: (string, tuple of the argument strings).
    assert purestr(Symbol('x'), with_args=True) == ("Symbol('x')", ())
    assert purestr(pair, with_args=True) == (
        pair_str, ('Integer(1)', 'Integer(2)'))
    assert purestr(Float(2), with_args=True) == (float_str, ())
|
|
|
|
|
|
def test_styleof():
    """styleof merges every matching entry's style; later entries override."""
    basic_style = {'color': 'blue', 'shape': 'ellipse'}
    expr_style = {'color': 'black'}
    styles = [(Basic, basic_style), (Expr, expr_style)]

    # A bare Basic matches only the first entry.
    assert styleof(Basic(S(1)), styles) == {'color': 'blue', 'shape': 'ellipse'}

    # An Add matches both: 'shape' comes from Basic, 'color' from Expr.
    assert styleof(x + 1, styles) == {'color': 'black', 'shape': 'ellipse'}
|
|
|
|
|
|
def test_attrprint():
    """attrprint renders a mapping as comma-separated "key"="value" pairs."""
    attrs = {'color': 'blue', 'shape': 'ellipse'}
    expected = '"color"="blue", "shape"="ellipse"'
    assert attrprint(attrs) == expected
|
|
|
|
def test_dotnode():
    """dotnode emits one DOT node statement; repeat=True appends the position."""
    add_attrs = '["color"="black", "label"="Add", "shape"="ellipse"];'

    assert dotnode(x, repeat=False) == \
        '"Symbol(\'x\')" ["color"="black", "label"="x", "shape"="ellipse"];'

    # On failure, show the actual node string to ease debugging.
    assert dotnode(x + 2, repeat=False) == \
        '"Add(Integer(2), Symbol(\'x\'))" ' + add_attrs, \
        dotnode(x + 2, repeat=0)

    poly = x + x**2
    # Opening quote plus identifier; the suffix and closing quote differ by case.
    poly_id = '"Add(Symbol(\'x\'), Pow(Symbol(\'x\'), Integer(2)))'
    assert dotnode(poly, repeat=False) == poly_id + '" ' + add_attrs
    assert dotnode(poly, repeat=True) == poly_id + '_()" ' + add_attrs
|
|
|
|
def test_dotedges():
    """dotedges lists parent -> child edges; repeat=True tags nodes with positions."""
    # Opening quote plus parent identifier, shared by every expected edge below.
    parent = '"Add(Integer(2), Symbol(\'x\'))'

    unpositioned = sorted(dotedges(x + 2, repeat=False))
    assert unpositioned == [
        parent + '" -> "Integer(2)";',
        parent + '" -> "Symbol(\'x\')";',
    ]

    positioned = sorted(dotedges(x + 2, repeat=True))
    assert positioned == [
        parent + '_()" -> "Integer(2)_(0,)";',
        parent + '_()" -> "Symbol(\'x\')_(1,)";',
    ]
|
|
|
|
def _assert_dotprint_contains(expr, nodes, repeat):
    """Assert ``dotprint(expr, repeat=repeat)`` is a digraph holding every
    edge from ``dotedges(expr, repeat=repeat)`` and every string in *nodes*."""
    text = dotprint(expr, repeat=repeat)
    assert 'digraph' in text
    assert all(e in text for e in dotedges(expr, repeat=repeat))
    assert all(n in text for n in nodes)


def test_dotprint():
    # repeat=False: node identifiers carry no position suffix.
    _assert_dotprint_contains(
        x + 2,
        [dotnode(e, repeat=False) for e in (x, Integer(2), x + 2)],
        repeat=False)
    _assert_dotprint_contains(
        x + x**2,
        [dotnode(e, repeat=False) for e in (x, Integer(2), x**2, x + x**2)],
        repeat=False)

    # repeat=True: every occurrence gets its own positioned node.
    # Add(Symbol('x'), Pow(Symbol('x'), Integer(2))) puts x at (0,) and again
    # inside the Pow at (1, 0), so both distinct nodes must be present.
    _assert_dotprint_contains(
        x + x**2,
        [dotnode(x + x**2, pos=()),
         dotnode(x, pos=(0,)),
         dotnode(x**2, pos=(1,)),
         dotnode(x, pos=(1, 0)),
         dotnode(Integer(2), pos=(1, 1))],
        repeat=True)
    _assert_dotprint_contains(
        x**x,
        [dotnode(x, pos=(0,)), dotnode(x, pos=(1,))],
        repeat=True)
|
|
|
|
def test_dotprint_depth():
    """``maxdepth`` stops the traversal below the given depth."""
    # NOTE(review): the traversal limit keyword is ``maxdepth``. Per the
    # dotprint docs, any other keyword (e.g. ``depth``) is forwarded as a
    # graph-style attribute and does not limit anything.
    expr = 3*x + 2
    full = dotprint(expr)
    limited = dotprint(expr, maxdepth=1)

    # The root and its direct children (Integer(2), Mul(3, x)) are drawn.
    assert dotnode(expr) in limited
    assert dotnode(3*x, pos=(1,)) in limited

    # x sits at position (1, 1), inside the Mul. Check it is present in the
    # unlimited graph first so the negative check below is not vacuous.
    assert dotnode(x, pos=(1, 1)) in full
    assert dotnode(x, pos=(1, 1)) not in limited

    # With no extra keywords, no "depth" attribute leaks into the graph style.
    assert "depth" not in full
|
|
|
|
def test_Matrix_and_non_basics():
    """Full DOT output for a MatrixSymbol, whose name is held in a Str arg."""
    from sympy.matrices.expressions.matexpr import MatrixSymbol
    dim = Symbol('n')
    expected = \
"""digraph{

# Graph style
"ordering"="out"
"rankdir"="TD"

#########
# Nodes #
#########

"MatrixSymbol(Str('X'), Symbol('n'), Symbol('n'))_()" ["color"="black", "label"="MatrixSymbol", "shape"="ellipse"];
"Str('X')_(0,)" ["color"="blue", "label"="X", "shape"="ellipse"];
"Symbol('n')_(1,)" ["color"="black", "label"="n", "shape"="ellipse"];
"Symbol('n')_(2,)" ["color"="black", "label"="n", "shape"="ellipse"];

#########
# Edges #
#########

"MatrixSymbol(Str('X'), Symbol('n'), Symbol('n'))_()" -> "Str('X')_(0,)";
"MatrixSymbol(Str('X'), Symbol('n'), Symbol('n'))_()" -> "Symbol('n')_(1,)";
"MatrixSymbol(Str('X'), Symbol('n'), Symbol('n'))_()" -> "Symbol('n')_(2,)";
}"""
    assert dotprint(MatrixSymbol('X', dim, dim)) == expected
|
|
|
|
|
|
def test_labelfunc():
    """A custom ``labelfunc`` controls the labels of leaf nodes."""
    text = dotprint(x + 2, labelfunc=srepr)
    # Node identifiers always use the purestr form (e.g. "Symbol('x')_(1,)"),
    # so checking for "Symbol('x')" alone passes even without labelfunc.
    # Check the label attribute itself instead.
    assert '"label"="Symbol(\'x\')"' in text
    assert '"label"="Integer(2)"' in text

    # Contrast with the default labelfunc, which labels x simply as "x".
    default = dotprint(x + 2)
    assert '"label"="x"' in default
    assert '"label"="Symbol(\'x\')"' not in default
|
|
|
|
|
|
def test_commutative():
    """Add output ignores argument order; noncommutative Mul output does not."""
    a, b = symbols('x y', commutative=False)
    forward_sum, backward_sum = dotprint(a + b), dotprint(b + a)
    assert forward_sum == backward_sum
    assert dotprint(a*b) != dotprint(b*a)
|