@ -13,6 +13,86 @@ let empty_map = Map.empty (module Term)
type subst = Term . t term_map [ @@ deriving compare , equal , sexp ]
let pp_subst fs s =
Format . fprintf fs " @[<1>[%a]@] "
( List . pp " ,@ " ( fun fs ( k , v ) ->
Format . fprintf fs " @[%a ↦ %a@] " Term . pp k Term . pp v ) )
( Map . to_alist s )
(* * Theory Solver *)
let rec is_constant e =
match ( e : Term . t ) with
| Var _ | Nondet _ -> false
| Ap1 ( _ , x ) -> is_constant x
| Ap2 ( _ , x , y ) -> is_constant x && is_constant y
| Ap3 ( _ , x , y , z ) -> is_constant x && is_constant y && is_constant z
| ApN ( _ , xs ) | RecN ( _ , xs ) -> Vector . for_all ~ f : is_constant xs
| Add args | Mul args ->
Qset . for_all ~ f : ( fun arg _ -> is_constant arg ) args
| Label _ | Float _ | Integer _ -> true
type kind = Interpreted | Simplified | Atomic | Uninterpreted
[ @@ deriving compare ]
let classify e =
match ( e : Term . t ) with
| Add _ | Mul _ -> Interpreted
| Ap2 ( ( Eq | Dq ) , _ , _ ) -> Simplified
| Ap1 _ | Ap2 _ | Ap3 _ | ApN _ -> Uninterpreted
| RecN _ | Var _ | Integer _ | Float _ | Nondet _ | Label _ -> Atomic
let solve e f =
[ % Trace . call fun { pf } -> pf " %a@ %a " Term . pp e Term . pp f ]
;
let rec solve_ e f s =
let solve_uninterp e f =
match ( ( e : Term . t ) , ( f : Term . t ) ) with
| Integer { data = m } , Integer { data = n } when not ( Z . equal m n ) -> None
| _ -> (
match ( is_constant e , is_constant f ) with
(* orient equation to discretionarily prefer term that is constant
or compares smaller as class representative * )
| true , false -> Some ( Map . add_exn s ~ key : f ~ data : e )
| false , true -> Some ( Map . add_exn s ~ key : e ~ data : f )
| _ ->
let key , data =
if Term . compare e f > 0 then ( e , f ) else ( f , e )
in
Some ( Map . add_exn s ~ key ~ data ) )
in
let concat_size args =
Vector . fold_until args ~ init : Term . zero
~ f : ( fun sum m ->
match ( m : Term . t ) with
| Ap2 ( Memory , siz , _ ) -> Continue ( Term . add siz sum )
| _ -> Stop None )
~ finish : ( fun _ -> None )
in
match ( ( e : Term . t ) , ( f : Term . t ) ) with
| ( Add _ | Mul _ | Integer _ ) , _ | _ , ( Add _ | Mul _ | Integer _ ) -> (
let e_f = Term . sub e f in
match Term . solve_zero_eq e_f with
| Some ( key , data ) -> Some ( Map . add_exn s ~ key ~ data )
| None -> solve_uninterp e_f Term . zero )
| ApN ( Concat , ms ) , ApN ( Concat , ns ) -> (
match ( concat_size ms , concat_size ns ) with
| Some p , Some q -> solve_uninterp e f > > = solve_ p q
| _ -> solve_uninterp e f )
| Ap2 ( Memory , m , _ ) , ApN ( Concat , ns )
| ApN ( Concat , ns ) , Ap2 ( Memory , m , _ ) -> (
match concat_size ns with
| Some p -> solve_uninterp e f > > = solve_ p m
| _ -> solve_uninterp e f )
| _ -> solve_uninterp e f
in
solve_ e f empty_map
| >
[ % Trace . retn fun { pf } ->
function Some s -> pf " %a " pp_subst s | None -> pf " false " ]
(* * Equality Relations *)
(* * see also [invariant] *)
type t =
{ sat : bool (* * [false] only if constraints are inconsistent *)
@ -31,7 +111,7 @@ let classes r =
else Map . add_multi cls ~ key : data ~ data : key
in
Map . fold r . rep ~ init : empty_map ~ f : ( fun ~ key ~ data cls ->
match Term . classify key with
match classify key with
| Interpreted | Atomic -> add key data cls
| Simplified | Uninterpreted ->
add ( Term . map ~ f : ( apply r . rep ) key ) data cls )
@ -85,7 +165,7 @@ let pp_diff fs (r, s) =
let in_car r e = Map . mem r . rep e
let rec iter_max_solvables e ~ f =
match Term . classify e with
match classify e with
| Interpreted -> Term . iter ~ f : ( iter_max_solvables ~ f ) e
| _ -> f e
@ -94,7 +174,7 @@ let invariant r =
@@ fun () ->
Map . iteri r . rep ~ f : ( fun ~ key : a ~ data : _ ->
(* no interpreted terms in carrier *)
assert ( Poly . ( Term . classify a < > Interpreted ) ) ;
assert ( Poly . ( classify a < > Interpreted ) ) ;
(* carrier is closed under subterms *)
iter_max_solvables a ~ f : ( fun b ->
assert (
@ -108,7 +188,7 @@ let true_ = {sat= true; rep= empty_map} |> check invariant
(* * apply a subst to maximal non-interpreted subterms *)
let rec norm s a =
match Term . classify a with
match classify a with
| Interpreted -> Term . map ~ f : ( norm s ) a
| Simplified -> apply s ( Term . map ~ f : ( norm s ) a )
| Atomic | Uninterpreted -> apply s a
@ -130,14 +210,14 @@ let lookup r a =
(* * rewrite a term into canonical form using rep and, for uninterpreted
terms , congruence composed with rep * )
let rec canon r a =
match Term . classify a with
match classify a with
| Interpreted -> Term . map ~ f : ( canon r ) a
| Simplified | Uninterpreted -> lookup r ( Term . map ~ f : ( canon r ) a )
| Atomic -> apply r . rep a
(* * add a term to the carrier *)
let rec extend a r =
match Term . classify a with
match classify a with
| Interpreted | Simplified -> Term . fold ~ f : extend a ~ init : r
| Uninterpreted ->
Map . find_or_add r . rep a
@ -160,7 +240,7 @@ let compose r s =
let merge a b r =
[ % Trace . call fun { pf } -> pf " %a@ %a@ %a " Term . pp a Term . pp b pp r ]
;
( match Term . solve a b with
( match solve a b with
| Some s -> compose r s
| None -> { r with sat = false } )
| >