"""Tests for the DOT-format printing helpers in ``sympy.printing.dot``."""

from sympy.printing.dot import (purestr, styleof, attrprint, dotnode,
    dotedges, dotprint)
from sympy.core.basic import Basic
from sympy.core.expr import Expr
from sympy.core.numbers import (Float, Integer)
from sympy.core.singleton import S
from sympy.core.symbol import (Symbol, symbols)
from sympy.printing.repr import srepr
from sympy.abc import x


def test_purestr():
    """Check purestr output, with and without ``with_args=True``."""
    assert purestr(Symbol('x')) == "Symbol('x')"
    assert purestr(Basic(S(1), S(2))) == "Basic(Integer(1), Integer(2))"
    assert purestr(Float(2)) == "Float('2.0', precision=53)"

    # With with_args=True the result is a (string, tuple-of-arg-strings) pair.
    assert purestr(Symbol('x'), with_args=True) == ("Symbol('x')", ())
    assert purestr(Basic(S(1), S(2)), with_args=True) == \
        ('Basic(Integer(1), Integer(2))', ('Integer(1)', 'Integer(2)'))
    assert purestr(Float(2), with_args=True) == \
        ("Float('2.0', precision=53)", ())


def test_styleof():
    """Check the style dict styleof builds from a list of (class, style)."""
    styles = [(Basic, {'color': 'blue', 'shape': 'ellipse'}),
              (Expr, {'color': 'black'})]
    assert styleof(Basic(S(1)), styles) == {'color': 'blue', 'shape': 'ellipse'}

    # x + 1 matches both entries: colour from Expr, shape from Basic.
    assert styleof(x + 1, styles) == {'color': 'black', 'shape': 'ellipse'}


def test_attrprint():
    """Check attrprint renders a dict as quoted ``"key"="value"`` pairs."""
    assert attrprint({'color': 'blue', 'shape': 'ellipse'}) == \
        '"color"="blue", "shape"="ellipse"'


def test_dotnode():
    """Check the node line produced by dotnode, with and without repeat."""
    assert dotnode(x, repeat=False) == \
        '"Symbol(\'x\')" ["color"="black", "label"="x", "shape"="ellipse"];'
    # The trailing expression is the assertion message shown on failure.
    assert dotnode(x+2, repeat=False) == \
        '"Add(Integer(2), Symbol(\'x\'))" ' \
        '["color"="black", "label"="Add", "shape"="ellipse"];', \
        dotnode(x+2, repeat=0)

    assert dotnode(x + x**2, repeat=False) == \
        '"Add(Symbol(\'x\'), Pow(Symbol(\'x\'), Integer(2)))" ' \
        '["color"="black", "label"="Add", "shape"="ellipse"];'
    # With repeat=True the node name gains a "_()" suffix.
    assert dotnode(x + x**2, repeat=True) == \
        '"Add(Symbol(\'x\'), Pow(Symbol(\'x\'), Integer(2)))_()" ' \
        '["color"="black", "label"="Add", "shape"="ellipse"];'


def test_dotedges():
    """Check the edge lines produced by dotedges, with and without repeat."""
    assert sorted(dotedges(x+2, repeat=False)) == [
        '"Add(Integer(2), Symbol(\'x\'))" -> "Integer(2)";',
        '"Add(Integer(2), Symbol(\'x\'))" -> "Symbol(\'x\')";',
    ]
    # With repeat=True each endpoint carries a position suffix.
    assert sorted(dotedges(x + 2, repeat=True)) == [
        '"Add(Integer(2), Symbol(\'x\'))_()" -> "Integer(2)_(0,)";',
        '"Add(Integer(2), Symbol(\'x\'))_()" -> "Symbol(\'x\')_(1,)";',
    ]


def test_dotprint():
    """Check dotprint output contains the expected node and edge lines."""
    text = dotprint(x+2, repeat=False)
    assert all(e in text for e in dotedges(x+2, repeat=False))
    assert all(
        n in text for n in [dotnode(expr, repeat=False)
        for expr in (x, Integer(2), x+2)])
    assert 'digraph' in text

    text = dotprint(x+x**2, repeat=False)
    assert all(e in text for e in dotedges(x+x**2, repeat=False))
    assert all(
        n in text for n in [dotnode(expr, repeat=False)
        for expr in (x, Integer(2), x**2)])
    assert 'digraph' in text

    text = dotprint(x+x**2, repeat=True)
    assert all(e in text for e in dotedges(x+x**2, repeat=True))
    assert all(
        n in text for n in [dotnode(expr, pos=())
        for expr in [x + x**2]])

    # x**x has the same symbol at two positions; both nodes must appear.
    text = dotprint(x**x, repeat=True)
    assert all(e in text for e in dotedges(x**x, repeat=True))
    assert all(
        n in text for n in [dotnode(x, pos=(0,)), dotnode(x, pos=(1,))])
    assert 'digraph' in text


def test_dotprint_depth():
    """Check that ``depth=1`` keeps the root node but omits the ``x`` leaf."""
    text = dotprint(3*x+2, depth=1)
    assert dotnode(3*x+2) in text
    assert dotnode(x) not in text
    # The depth option itself must not appear in the rendered graph.
    text = dotprint(3*x+2)
    assert "depth" not in text


def test_Matrix_and_non_basics():
    """Check the complete dotprint output for a MatrixSymbol."""
    from sympy.matrices.expressions.matexpr import MatrixSymbol
    n = Symbol('n')
    # The expected text is compared verbatim, so its line breaks matter.
    # NOTE(review): blank-line placement below follows the layout of
    # dotprint's output template -- confirm against sympy.printing.dot.
    assert dotprint(MatrixSymbol('X', n, n)) == \
"""digraph{

# Graph style
"ordering"="out"
"rankdir"="TD"

#########
# Nodes #
#########

"MatrixSymbol(Str('X'), Symbol('n'), Symbol('n'))_()" ["color"="black", "label"="MatrixSymbol", "shape"="ellipse"];
"Str('X')_(0,)" ["color"="blue", "label"="X", "shape"="ellipse"];
"Symbol('n')_(1,)" ["color"="black", "label"="n", "shape"="ellipse"];
"Symbol('n')_(2,)" ["color"="black", "label"="n", "shape"="ellipse"];

#########
# Edges #
#########

"MatrixSymbol(Str('X'), Symbol('n'), Symbol('n'))_()" -> "Str('X')_(0,)";
"MatrixSymbol(Str('X'), Symbol('n'), Symbol('n'))_()" -> "Symbol('n')_(1,)";
"MatrixSymbol(Str('X'), Symbol('n'), Symbol('n'))_()" -> "Symbol('n')_(2,)";
}"""


def test_labelfunc():
    """Check that ``labelfunc=srepr`` puts srepr strings in the output."""
    text = dotprint(x + 2, labelfunc=srepr)
    assert "Symbol('x')" in text
    assert "Integer(2)" in text


def test_commutative():
    """Check operand order matters for Mul but not Add (non-commutative)."""
    # Local non-commutative symbols; this x shadows the module-level one.
    x, y = symbols('x y', commutative=False)
    assert dotprint(x + y) == dotprint(y + x)
    assert dotprint(x*y) != dotprint(y*x)