You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
526 lines
20 KiB
526 lines
20 KiB
5 months ago
|
import contextlib
|
||
|
import itertools
|
||
|
import re
|
||
|
import typing
|
||
|
from enum import Enum
|
||
|
from typing import Callable
|
||
|
|
||
|
import sympy
|
||
|
from sympy import Add, Implies, sqrt
|
||
|
from sympy.core import Mul, Pow
|
||
|
from sympy.core import (S, pi, symbols, Function, Rational, Integer,
|
||
|
Symbol, Eq, Ne, Le, Lt, Gt, Ge)
|
||
|
from sympy.functions import Piecewise, exp, sin, cos
|
||
|
from sympy.printing.smtlib import smtlib_code
|
||
|
from sympy.testing.pytest import raises, Failed
|
||
|
|
||
|
x, y, z = symbols('x,y,z')
|
||
|
|
||
|
|
||
|
class _W(Enum):
|
||
|
DEFAULTING_TO_FLOAT = re.compile("Could not infer type of `.+`. Defaulting to float.", re.I)
|
||
|
WILL_NOT_DECLARE = re.compile("Non-Symbol/Function `.+` will not be declared.", re.I)
|
||
|
WILL_NOT_ASSERT = re.compile("Non-Boolean expression `.+` will not be asserted. Converting to SMTLib verbatim.", re.I)
|
||
|
|
||
|
|
||
|
@contextlib.contextmanager
|
||
|
def _check_warns(expected: typing.Iterable[_W]):
|
||
|
warns: typing.List[str] = []
|
||
|
log_warn = warns.append
|
||
|
yield log_warn
|
||
|
|
||
|
errors = []
|
||
|
for i, (w, e) in enumerate(itertools.zip_longest(warns, expected)):
|
||
|
if not e:
|
||
|
errors += [f"[{i}] Received unexpected warning `{w}`."]
|
||
|
elif not w:
|
||
|
errors += [f"[{i}] Did not receive expected warning `{e.name}`."]
|
||
|
elif not e.value.match(w):
|
||
|
errors += [f"[{i}] Warning `{w}` does not match expected {e.name}."]
|
||
|
|
||
|
if errors: raise Failed('\n'.join(errors))
|
||
|
|
||
|
|
||
|
def test_Integer():
|
||
|
with _check_warns([_W.WILL_NOT_ASSERT] * 2) as w:
|
||
|
assert smtlib_code(Integer(67), log_warn=w) == "67"
|
||
|
assert smtlib_code(Integer(-1), log_warn=w) == "-1"
|
||
|
with _check_warns([]) as w:
|
||
|
assert smtlib_code(Integer(67)) == "67"
|
||
|
assert smtlib_code(Integer(-1)) == "-1"
|
||
|
|
||
|
|
||
|
def test_Rational():
|
||
|
with _check_warns([_W.WILL_NOT_ASSERT] * 4) as w:
|
||
|
assert smtlib_code(Rational(3, 7), log_warn=w) == "(/ 3 7)"
|
||
|
assert smtlib_code(Rational(18, 9), log_warn=w) == "2"
|
||
|
assert smtlib_code(Rational(3, -7), log_warn=w) == "(/ -3 7)"
|
||
|
assert smtlib_code(Rational(-3, -7), log_warn=w) == "(/ 3 7)"
|
||
|
|
||
|
with _check_warns([_W.DEFAULTING_TO_FLOAT, _W.WILL_NOT_ASSERT] * 2) as w:
|
||
|
assert smtlib_code(x + Rational(3, 7), auto_declare=False, log_warn=w) == "(+ (/ 3 7) x)"
|
||
|
assert smtlib_code(Rational(3, 7) * x, log_warn=w) == "(declare-const x Real)\n" \
|
||
|
"(* (/ 3 7) x)"
|
||
|
|
||
|
|
||
|
def test_Relational():
|
||
|
with _check_warns([_W.DEFAULTING_TO_FLOAT] * 12) as w:
|
||
|
assert smtlib_code(Eq(x, y), auto_declare=False, log_warn=w) == "(assert (= x y))"
|
||
|
assert smtlib_code(Ne(x, y), auto_declare=False, log_warn=w) == "(assert (not (= x y)))"
|
||
|
assert smtlib_code(Le(x, y), auto_declare=False, log_warn=w) == "(assert (<= x y))"
|
||
|
assert smtlib_code(Lt(x, y), auto_declare=False, log_warn=w) == "(assert (< x y))"
|
||
|
assert smtlib_code(Gt(x, y), auto_declare=False, log_warn=w) == "(assert (> x y))"
|
||
|
assert smtlib_code(Ge(x, y), auto_declare=False, log_warn=w) == "(assert (>= x y))"
|
||
|
|
||
|
|
||
|
def test_Function():
|
||
|
with _check_warns([_W.DEFAULTING_TO_FLOAT, _W.WILL_NOT_ASSERT]) as w:
|
||
|
assert smtlib_code(sin(x) ** cos(x), auto_declare=False, log_warn=w) == "(pow (sin x) (cos x))"
|
||
|
|
||
|
with _check_warns([_W.WILL_NOT_ASSERT]) as w:
|
||
|
assert smtlib_code(
|
||
|
abs(x),
|
||
|
symbol_table={x: int, y: bool},
|
||
|
known_types={int: "INTEGER_TYPE"},
|
||
|
known_functions={sympy.Abs: "ABSOLUTE_VALUE_OF"},
|
||
|
log_warn=w
|
||
|
) == "(declare-const x INTEGER_TYPE)\n" \
|
||
|
"(ABSOLUTE_VALUE_OF x)"
|
||
|
|
||
|
my_fun1 = Function('f1')
|
||
|
with _check_warns([_W.WILL_NOT_ASSERT]) as w:
|
||
|
assert smtlib_code(
|
||
|
my_fun1(x),
|
||
|
symbol_table={my_fun1: Callable[[bool], float]},
|
||
|
log_warn=w
|
||
|
) == "(declare-const x Bool)\n" \
|
||
|
"(declare-fun f1 (Bool) Real)\n" \
|
||
|
"(f1 x)"
|
||
|
|
||
|
with _check_warns([]) as w:
|
||
|
assert smtlib_code(
|
||
|
my_fun1(x),
|
||
|
symbol_table={my_fun1: Callable[[bool], bool]},
|
||
|
log_warn=w
|
||
|
) == "(declare-const x Bool)\n" \
|
||
|
"(declare-fun f1 (Bool) Bool)\n" \
|
||
|
"(assert (f1 x))"
|
||
|
|
||
|
assert smtlib_code(
|
||
|
Eq(my_fun1(x, z), y),
|
||
|
symbol_table={my_fun1: Callable[[int, bool], bool]},
|
||
|
log_warn=w
|
||
|
) == "(declare-const x Int)\n" \
|
||
|
"(declare-const y Bool)\n" \
|
||
|
"(declare-const z Bool)\n" \
|
||
|
"(declare-fun f1 (Int Bool) Bool)\n" \
|
||
|
"(assert (= (f1 x z) y))"
|
||
|
|
||
|
assert smtlib_code(
|
||
|
Eq(my_fun1(x, z), y),
|
||
|
symbol_table={my_fun1: Callable[[int, bool], bool]},
|
||
|
known_functions={my_fun1: "MY_KNOWN_FUN", Eq: '=='},
|
||
|
log_warn=w
|
||
|
) == "(declare-const x Int)\n" \
|
||
|
"(declare-const y Bool)\n" \
|
||
|
"(declare-const z Bool)\n" \
|
||
|
"(assert (== (MY_KNOWN_FUN x z) y))"
|
||
|
|
||
|
with _check_warns([_W.DEFAULTING_TO_FLOAT] * 3) as w:
|
||
|
assert smtlib_code(
|
||
|
Eq(my_fun1(x, z), y),
|
||
|
known_functions={my_fun1: "MY_KNOWN_FUN", Eq: '=='},
|
||
|
log_warn=w
|
||
|
) == "(declare-const x Real)\n" \
|
||
|
"(declare-const y Real)\n" \
|
||
|
"(declare-const z Real)\n" \
|
||
|
"(assert (== (MY_KNOWN_FUN x z) y))"
|
||
|
|
||
|
|
||
|
def test_Pow():
|
||
|
with _check_warns([_W.DEFAULTING_TO_FLOAT, _W.WILL_NOT_ASSERT]) as w:
|
||
|
assert smtlib_code(x ** 3, auto_declare=False, log_warn=w) == "(pow x 3)"
|
||
|
with _check_warns([_W.DEFAULTING_TO_FLOAT, _W.DEFAULTING_TO_FLOAT, _W.WILL_NOT_ASSERT]) as w:
|
||
|
assert smtlib_code(x ** (y ** 3), auto_declare=False, log_warn=w) == "(pow x (pow y 3))"
|
||
|
with _check_warns([_W.DEFAULTING_TO_FLOAT, _W.WILL_NOT_ASSERT]) as w:
|
||
|
assert smtlib_code(x ** Rational(2, 3), auto_declare=False, log_warn=w) == '(pow x (/ 2 3))'
|
||
|
|
||
|
a = Symbol('a', integer=True)
|
||
|
b = Symbol('b', real=True)
|
||
|
c = Symbol('c')
|
||
|
|
||
|
def g(x): return 2 * x
|
||
|
|
||
|
# if x=1, y=2, then expr=2.333...
|
||
|
expr = 1 / (g(a) * 3.5) ** (a - b ** a) / (a ** 2 + b)
|
||
|
|
||
|
with _check_warns([]) as w:
|
||
|
assert smtlib_code(
|
||
|
[
|
||
|
Eq(a < 2, c),
|
||
|
Eq(b > a, c),
|
||
|
c & True,
|
||
|
Eq(expr, 2 + Rational(1, 3))
|
||
|
],
|
||
|
log_warn=w
|
||
|
) == '(declare-const a Int)\n' \
|
||
|
'(declare-const b Real)\n' \
|
||
|
'(declare-const c Bool)\n' \
|
||
|
'(assert (= (< a 2) c))\n' \
|
||
|
'(assert (= (> b a) c))\n' \
|
||
|
'(assert c)\n' \
|
||
|
'(assert (= ' \
|
||
|
'(* (pow (* 7. a) (+ (pow b a) (* -1 a))) (pow (+ b (pow a 2)) -1)) ' \
|
||
|
'(/ 7 3)' \
|
||
|
'))'
|
||
|
|
||
|
with _check_warns([_W.DEFAULTING_TO_FLOAT, _W.WILL_NOT_ASSERT]) as w:
|
||
|
assert smtlib_code(
|
||
|
Mul(-2, c, Pow(Mul(b, b, evaluate=False), -1, evaluate=False), evaluate=False),
|
||
|
log_warn=w
|
||
|
) == '(declare-const b Real)\n' \
|
||
|
'(declare-const c Real)\n' \
|
||
|
'(* -2 c (pow (* b b) -1))'
|
||
|
|
||
|
|
||
|
def test_basic_ops():
|
||
|
with _check_warns([_W.DEFAULTING_TO_FLOAT, _W.DEFAULTING_TO_FLOAT, _W.WILL_NOT_ASSERT]) as w:
|
||
|
assert smtlib_code(x * y, auto_declare=False, log_warn=w) == "(* x y)"
|
||
|
|
||
|
with _check_warns([_W.DEFAULTING_TO_FLOAT, _W.DEFAULTING_TO_FLOAT, _W.WILL_NOT_ASSERT]) as w:
|
||
|
assert smtlib_code(x + y, auto_declare=False, log_warn=w) == "(+ x y)"
|
||
|
|
||
|
# with _check_warns([_SmtlibWarnings.DEFAULTING_TO_FLOAT, _SmtlibWarnings.DEFAULTING_TO_FLOAT, _SmtlibWarnings.WILL_NOT_ASSERT]) as w:
|
||
|
# todo: implement re-write, currently does '(+ x (* -1 y))' instead
|
||
|
# assert smtlib_code(x - y, auto_declare=False, log_warn=w) == "(- x y)"
|
||
|
|
||
|
with _check_warns([_W.DEFAULTING_TO_FLOAT, _W.WILL_NOT_ASSERT]) as w:
|
||
|
assert smtlib_code(-x, auto_declare=False, log_warn=w) == "(* -1 x)"
|
||
|
|
||
|
|
||
|
def test_quantifier_extensions():
|
||
|
from sympy.logic.boolalg import Boolean
|
||
|
from sympy import Interval, Tuple, sympify
|
||
|
|
||
|
# start For-all quantifier class example
|
||
|
class ForAll(Boolean):
|
||
|
def _smtlib(self, printer):
|
||
|
bound_symbol_declarations = [
|
||
|
printer._s_expr(sym.name, [
|
||
|
printer._known_types[printer.symbol_table[sym]],
|
||
|
Interval(start, end)
|
||
|
]) for sym, start, end in self.limits
|
||
|
]
|
||
|
return printer._s_expr('forall', [
|
||
|
printer._s_expr('', bound_symbol_declarations),
|
||
|
self.function
|
||
|
])
|
||
|
|
||
|
@property
|
||
|
def bound_symbols(self):
|
||
|
return {s for s, _, _ in self.limits}
|
||
|
|
||
|
@property
|
||
|
def free_symbols(self):
|
||
|
bound_symbol_names = {s.name for s in self.bound_symbols}
|
||
|
return {
|
||
|
s for s in self.function.free_symbols
|
||
|
if s.name not in bound_symbol_names
|
||
|
}
|
||
|
|
||
|
def __new__(cls, *args):
|
||
|
limits = [sympify(a) for a in args if isinstance(a, tuple) or isinstance(a, Tuple)]
|
||
|
function = [sympify(a) for a in args if isinstance(a, Boolean)]
|
||
|
assert len(limits) + len(function) == len(args)
|
||
|
assert len(function) == 1
|
||
|
function = function[0]
|
||
|
|
||
|
if isinstance(function, ForAll): return ForAll.__new__(
|
||
|
ForAll, *(limits + function.limits), function.function
|
||
|
)
|
||
|
inst = Boolean.__new__(cls)
|
||
|
inst._args = tuple(limits + [function])
|
||
|
inst.limits = limits
|
||
|
inst.function = function
|
||
|
return inst
|
||
|
|
||
|
# end For-All Quantifier class example
|
||
|
|
||
|
f = Function('f')
|
||
|
with _check_warns([_W.DEFAULTING_TO_FLOAT]) as w:
|
||
|
assert smtlib_code(
|
||
|
ForAll((x, -42, +21), Eq(f(x), f(x))),
|
||
|
symbol_table={f: Callable[[float], float]},
|
||
|
log_warn=w
|
||
|
) == '(assert (forall ( (x Real [-42, 21])) true))'
|
||
|
|
||
|
with _check_warns([_W.DEFAULTING_TO_FLOAT] * 2) as w:
|
||
|
assert smtlib_code(
|
||
|
ForAll(
|
||
|
(x, -42, +21), (y, -100, 3),
|
||
|
Implies(Eq(x, y), Eq(f(x), f(y)))
|
||
|
),
|
||
|
symbol_table={f: Callable[[float], float]},
|
||
|
log_warn=w
|
||
|
) == '(declare-fun f (Real) Real)\n' \
|
||
|
'(assert (' \
|
||
|
'forall ( (x Real [-42, 21]) (y Real [-100, 3])) ' \
|
||
|
'(=> (= x y) (= (f x) (f y)))' \
|
||
|
'))'
|
||
|
|
||
|
a = Symbol('a', integer=True)
|
||
|
b = Symbol('b', real=True)
|
||
|
c = Symbol('c')
|
||
|
|
||
|
with _check_warns([]) as w:
|
||
|
assert smtlib_code(
|
||
|
ForAll(
|
||
|
(a, 2, 100), ForAll(
|
||
|
(b, 2, 100),
|
||
|
Implies(a < b, sqrt(a) < b) | c
|
||
|
)),
|
||
|
log_warn=w
|
||
|
) == '(declare-const c Bool)\n' \
|
||
|
'(assert (forall ( (a Int [2, 100]) (b Real [2, 100])) ' \
|
||
|
'(or c (=> (< a b) (< (pow a (/ 1 2)) b)))' \
|
||
|
'))'
|
||
|
|
||
|
|
||
|
def test_mix_number_mult_symbols():
|
||
|
with _check_warns([_W.WILL_NOT_ASSERT]) as w:
|
||
|
assert smtlib_code(
|
||
|
1 / pi,
|
||
|
known_constants={pi: "MY_PI"},
|
||
|
log_warn=w
|
||
|
) == '(pow MY_PI -1)'
|
||
|
|
||
|
with _check_warns([_W.WILL_NOT_ASSERT]) as w:
|
||
|
assert smtlib_code(
|
||
|
[
|
||
|
Eq(pi, 3.14, evaluate=False),
|
||
|
1 / pi,
|
||
|
],
|
||
|
known_constants={pi: "MY_PI"},
|
||
|
log_warn=w
|
||
|
) == '(assert (= MY_PI 3.14))\n' \
|
||
|
'(pow MY_PI -1)'
|
||
|
|
||
|
with _check_warns([_W.WILL_NOT_ASSERT]) as w:
|
||
|
assert smtlib_code(
|
||
|
Add(S.Zero, S.One, S.NegativeOne, S.Half,
|
||
|
S.Exp1, S.Pi, S.GoldenRatio, evaluate=False),
|
||
|
known_constants={
|
||
|
S.Pi: 'p', S.GoldenRatio: 'g',
|
||
|
S.Exp1: 'e'
|
||
|
},
|
||
|
known_functions={
|
||
|
Add: 'plus',
|
||
|
exp: 'exp'
|
||
|
},
|
||
|
precision=3,
|
||
|
log_warn=w
|
||
|
) == '(plus 0 1 -1 (/ 1 2) (exp 1) p g)'
|
||
|
|
||
|
with _check_warns([_W.WILL_NOT_ASSERT]) as w:
|
||
|
assert smtlib_code(
|
||
|
Add(S.Zero, S.One, S.NegativeOne, S.Half,
|
||
|
S.Exp1, S.Pi, S.GoldenRatio, evaluate=False),
|
||
|
known_constants={
|
||
|
S.Pi: 'p'
|
||
|
},
|
||
|
known_functions={
|
||
|
Add: 'plus',
|
||
|
exp: 'exp'
|
||
|
},
|
||
|
precision=3,
|
||
|
log_warn=w
|
||
|
) == '(plus 0 1 -1 (/ 1 2) (exp 1) p 1.62)'
|
||
|
|
||
|
with _check_warns([_W.WILL_NOT_ASSERT]) as w:
|
||
|
assert smtlib_code(
|
||
|
Add(S.Zero, S.One, S.NegativeOne, S.Half,
|
||
|
S.Exp1, S.Pi, S.GoldenRatio, evaluate=False),
|
||
|
known_functions={Add: 'plus'},
|
||
|
precision=3,
|
||
|
log_warn=w
|
||
|
) == '(plus 0 1 -1 (/ 1 2) 2.72 3.14 1.62)'
|
||
|
|
||
|
with _check_warns([_W.WILL_NOT_ASSERT]) as w:
|
||
|
assert smtlib_code(
|
||
|
Add(S.Zero, S.One, S.NegativeOne, S.Half,
|
||
|
S.Exp1, S.Pi, S.GoldenRatio, evaluate=False),
|
||
|
known_constants={S.Exp1: 'e'},
|
||
|
known_functions={Add: 'plus'},
|
||
|
precision=3,
|
||
|
log_warn=w
|
||
|
) == '(plus 0 1 -1 (/ 1 2) e 3.14 1.62)'
|
||
|
|
||
|
|
||
|
def test_boolean():
|
||
|
with _check_warns([]) as w:
|
||
|
assert smtlib_code(x & y, log_warn=w) == '(declare-const x Bool)\n' \
|
||
|
'(declare-const y Bool)\n' \
|
||
|
'(assert (and x y))'
|
||
|
assert smtlib_code(x | y, log_warn=w) == '(declare-const x Bool)\n' \
|
||
|
'(declare-const y Bool)\n' \
|
||
|
'(assert (or x y))'
|
||
|
assert smtlib_code(~x, log_warn=w) == '(declare-const x Bool)\n' \
|
||
|
'(assert (not x))'
|
||
|
assert smtlib_code(x & y & z, log_warn=w) == '(declare-const x Bool)\n' \
|
||
|
'(declare-const y Bool)\n' \
|
||
|
'(declare-const z Bool)\n' \
|
||
|
'(assert (and x y z))'
|
||
|
|
||
|
with _check_warns([_W.DEFAULTING_TO_FLOAT]) as w:
|
||
|
assert smtlib_code((x & ~y) | (z > 3), log_warn=w) == '(declare-const x Bool)\n' \
|
||
|
'(declare-const y Bool)\n' \
|
||
|
'(declare-const z Real)\n' \
|
||
|
'(assert (or (> z 3) (and x (not y))))'
|
||
|
|
||
|
f = Function('f')
|
||
|
g = Function('g')
|
||
|
h = Function('h')
|
||
|
with _check_warns([_W.DEFAULTING_TO_FLOAT]) as w:
|
||
|
assert smtlib_code(
|
||
|
[Gt(f(x), y),
|
||
|
Lt(y, g(z))],
|
||
|
symbol_table={
|
||
|
f: Callable[[bool], int], g: Callable[[bool], int],
|
||
|
}, log_warn=w
|
||
|
) == '(declare-const x Bool)\n' \
|
||
|
'(declare-const y Real)\n' \
|
||
|
'(declare-const z Bool)\n' \
|
||
|
'(declare-fun f (Bool) Int)\n' \
|
||
|
'(declare-fun g (Bool) Int)\n' \
|
||
|
'(assert (> (f x) y))\n' \
|
||
|
'(assert (< y (g z)))'
|
||
|
|
||
|
with _check_warns([]) as w:
|
||
|
assert smtlib_code(
|
||
|
[Eq(f(x), y),
|
||
|
Lt(y, g(z))],
|
||
|
symbol_table={
|
||
|
f: Callable[[bool], int], g: Callable[[bool], int],
|
||
|
}, log_warn=w
|
||
|
) == '(declare-const x Bool)\n' \
|
||
|
'(declare-const y Int)\n' \
|
||
|
'(declare-const z Bool)\n' \
|
||
|
'(declare-fun f (Bool) Int)\n' \
|
||
|
'(declare-fun g (Bool) Int)\n' \
|
||
|
'(assert (= (f x) y))\n' \
|
||
|
'(assert (< y (g z)))'
|
||
|
|
||
|
with _check_warns([]) as w:
|
||
|
assert smtlib_code(
|
||
|
[Eq(f(x), y),
|
||
|
Eq(g(f(x)), z),
|
||
|
Eq(h(g(f(x))), x)],
|
||
|
symbol_table={
|
||
|
f: Callable[[float], int],
|
||
|
g: Callable[[int], bool],
|
||
|
h: Callable[[bool], float]
|
||
|
},
|
||
|
log_warn=w
|
||
|
) == '(declare-const x Real)\n' \
|
||
|
'(declare-const y Int)\n' \
|
||
|
'(declare-const z Bool)\n' \
|
||
|
'(declare-fun f (Real) Int)\n' \
|
||
|
'(declare-fun g (Int) Bool)\n' \
|
||
|
'(declare-fun h (Bool) Real)\n' \
|
||
|
'(assert (= (f x) y))\n' \
|
||
|
'(assert (= (g (f x)) z))\n' \
|
||
|
'(assert (= (h (g (f x))) x))'
|
||
|
|
||
|
|
||
|
# todo: make smtlib_code support arrays
|
||
|
# def test_containers():
|
||
|
# assert julia_code([1, 2, 3, [4, 5, [6, 7]], 8, [9, 10], 11]) == \
|
||
|
# "Any[1, 2, 3, Any[4, 5, Any[6, 7]], 8, Any[9, 10], 11]"
|
||
|
# assert julia_code((1, 2, (3, 4))) == "(1, 2, (3, 4))"
|
||
|
# assert julia_code([1]) == "Any[1]"
|
||
|
# assert julia_code((1,)) == "(1,)"
|
||
|
# assert julia_code(Tuple(*[1, 2, 3])) == "(1, 2, 3)"
|
||
|
# assert julia_code((1, x * y, (3, x ** 2))) == "(1, x .* y, (3, x .^ 2))"
|
||
|
# # scalar, matrix, empty matrix and empty list
|
||
|
# assert julia_code((1, eye(3), Matrix(0, 0, []), [])) == "(1, [1 0 0;\n0 1 0;\n0 0 1], zeros(0, 0), Any[])"
|
||
|
|
||
|
def test_smtlib_piecewise():
|
||
|
with _check_warns([_W.DEFAULTING_TO_FLOAT, _W.WILL_NOT_ASSERT]) as w:
|
||
|
assert smtlib_code(
|
||
|
Piecewise((x, x < 1),
|
||
|
(x ** 2, True)),
|
||
|
auto_declare=False,
|
||
|
log_warn=w
|
||
|
) == '(ite (< x 1) x (pow x 2))'
|
||
|
|
||
|
with _check_warns([_W.DEFAULTING_TO_FLOAT, _W.WILL_NOT_ASSERT]) as w:
|
||
|
assert smtlib_code(
|
||
|
Piecewise((x ** 2, x < 1),
|
||
|
(x ** 3, x < 2),
|
||
|
(x ** 4, x < 3),
|
||
|
(x ** 5, True)),
|
||
|
auto_declare=False,
|
||
|
log_warn=w
|
||
|
) == '(ite (< x 1) (pow x 2) ' \
|
||
|
'(ite (< x 2) (pow x 3) ' \
|
||
|
'(ite (< x 3) (pow x 4) ' \
|
||
|
'(pow x 5))))'
|
||
|
|
||
|
# Check that Piecewise without a True (default) condition error
|
||
|
expr = Piecewise((x, x < 1), (x ** 2, x > 1), (sin(x), x > 0))
|
||
|
with _check_warns([_W.DEFAULTING_TO_FLOAT, _W.WILL_NOT_ASSERT]) as w:
|
||
|
raises(AssertionError, lambda: smtlib_code(expr, log_warn=w))
|
||
|
|
||
|
|
||
|
def test_smtlib_piecewise_times_const():
|
||
|
pw = Piecewise((x, x < 1), (x ** 2, True))
|
||
|
with _check_warns([_W.DEFAULTING_TO_FLOAT, _W.WILL_NOT_ASSERT]) as w:
|
||
|
assert smtlib_code(2 * pw, log_warn=w) == '(declare-const x Real)\n(* 2 (ite (< x 1) x (pow x 2)))'
|
||
|
with _check_warns([_W.DEFAULTING_TO_FLOAT, _W.WILL_NOT_ASSERT]) as w:
|
||
|
assert smtlib_code(pw / x, log_warn=w) == '(declare-const x Real)\n(* (pow x -1) (ite (< x 1) x (pow x 2)))'
|
||
|
with _check_warns([_W.DEFAULTING_TO_FLOAT, _W.DEFAULTING_TO_FLOAT, _W.WILL_NOT_ASSERT]) as w:
|
||
|
assert smtlib_code(pw / (x * y), log_warn=w) == '(declare-const x Real)\n(declare-const y Real)\n(* (pow x -1) (pow y -1) (ite (< x 1) x (pow x 2)))'
|
||
|
with _check_warns([_W.DEFAULTING_TO_FLOAT, _W.WILL_NOT_ASSERT]) as w:
|
||
|
assert smtlib_code(pw / 3, log_warn=w) == '(declare-const x Real)\n(* (/ 1 3) (ite (< x 1) x (pow x 2)))'
|
||
|
|
||
|
|
||
|
# todo: make smtlib_code support arrays / matrices ?
|
||
|
# def test_smtlib_matrix_assign_to():
|
||
|
# A = Matrix([[1, 2, 3]])
|
||
|
# assert smtlib_code(A, assign_to='a') == "a = [1 2 3]"
|
||
|
# A = Matrix([[1, 2], [3, 4]])
|
||
|
# assert smtlib_code(A, assign_to='A') == "A = [1 2;\n3 4]"
|
||
|
|
||
|
# def test_julia_matrix_1x1():
|
||
|
# A = Matrix([[3]])
|
||
|
# B = MatrixSymbol('B', 1, 1)
|
||
|
# C = MatrixSymbol('C', 1, 2)
|
||
|
# assert julia_code(A, assign_to=B) == "B = [3]"
|
||
|
# raises(ValueError, lambda: julia_code(A, assign_to=C))
|
||
|
|
||
|
# def test_julia_matrix_elements():
|
||
|
# A = Matrix([[x, 2, x * y]])
|
||
|
# assert julia_code(A[0, 0] ** 2 + A[0, 1] + A[0, 2]) == "x .^ 2 + x .* y + 2"
|
||
|
# A = MatrixSymbol('AA', 1, 3)
|
||
|
# assert julia_code(A) == "AA"
|
||
|
# assert julia_code(A[0, 0] ** 2 + sin(A[0, 1]) + A[0, 2]) == \
|
||
|
# "sin(AA[1,2]) + AA[1,1] .^ 2 + AA[1,3]"
|
||
|
# assert julia_code(sum(A)) == "AA[1,1] + AA[1,2] + AA[1,3]"
|
||
|
|
||
|
def test_smtlib_boolean():
|
||
|
with _check_warns([]) as w:
|
||
|
assert smtlib_code(True, auto_assert=False, log_warn=w) == 'true'
|
||
|
assert smtlib_code(True, log_warn=w) == '(assert true)'
|
||
|
assert smtlib_code(S.true, log_warn=w) == '(assert true)'
|
||
|
assert smtlib_code(S.false, log_warn=w) == '(assert false)'
|
||
|
assert smtlib_code(False, log_warn=w) == '(assert false)'
|
||
|
assert smtlib_code(False, auto_assert=False, log_warn=w) == 'false'
|
||
|
|
||
|
|
||
|
def test_not_supported():
|
||
|
f = Function('f')
|
||
|
with _check_warns([_W.DEFAULTING_TO_FLOAT, _W.WILL_NOT_ASSERT]) as w:
|
||
|
raises(KeyError, lambda: smtlib_code(f(x).diff(x), symbol_table={f: Callable[[float], float]}, log_warn=w))
|
||
|
with _check_warns([_W.WILL_NOT_ASSERT]) as w:
|
||
|
raises(KeyError, lambda: smtlib_code(S.ComplexInfinity, log_warn=w))
|