@ -7,18 +7,122 @@
(* * Equality over uninterpreted functions and linear rational arithmetic *)
type ' a term_map = ' a Map . M ( Term ) . t [ @@ deriving compare , equal , sexp ]
(* * Classification of Terms by Theory *)
let empty_map = Map . empty ( module Term )
type kind = Interpreted | Simplified | Atomic | Uninterpreted
[ @@ deriving compare ]
type subst = Term . t term_map [ @@ deriving compare , equal , sexp ]
let classify e =
match ( e : Term . t ) with
| Add _ | Mul _ -> Interpreted
| Ap2 ( ( Eq | Dq ) , _ , _ ) -> Simplified
| Ap1 _ | Ap2 _ | Ap3 _ | ApN _ -> Uninterpreted
| RecN _ | Var _ | Integer _ | Float _ | Nondet _ | Label _ -> Atomic
let pp_subst fs s =
(* * Solution Substitutions *)
module Subst : sig
type t [ @@ deriving compare , equal , sexp ]
val pp : t pp
val pp_sdiff : ? pre : string -> Format . formatter -> t -> t -> unit
val empty : t
val length : t -> int
val mem : t -> Term . t -> bool
val fold : t -> init : ' a -> f : ( key : Term . t -> data : Term . t -> ' a -> ' a ) -> ' a
val iteri : t -> f : ( key : Term . t -> data : Term . t -> unit ) -> unit
val for_alli : t -> f : ( key : Term . t -> data : Term . t -> bool ) -> bool
val apply : t -> Term . t -> Term . t
val norm : t -> Term . t -> Term . t
val compose : t -> t -> t
val compose1 : key : Term . t -> data : Term . t -> t -> t
val extend : Term . t -> t -> t option
val map_entries : f : ( Term . t -> Term . t ) -> t -> t
val to_alist : t -> ( Term . t * Term . t ) list
end = struct
type t = Term . t Term . Map . t [ @@ deriving compare , equal , sexp ]
let pp fs s =
Format . fprintf fs " @[<1>[%a]@] "
( List . pp " ,@ " ( fun fs ( k , v ) ->
Format . fprintf fs " @[%a ↦ %a@] " Term . pp k Term . pp v ) )
( Map . to_alist s )
let pp_sdiff ? ( pre = " " ) =
let pp_sdiff_elt pp_key pp_val pp_sdiff_val fs = function
| k , ` Left v ->
Format . fprintf fs " -- [@[%a@ @<2>↦ %a@]] " pp_key k pp_val v
| k , ` Right v ->
Format . fprintf fs " ++ [@[%a@ @<2>↦ %a@]] " pp_key k pp_val v
| k , ` Unequal vv ->
Format . fprintf fs " [@[%a@ @<2>↦ %a@]] " pp_key k pp_sdiff_val vv
in
let pp_sdiff_map pp_elt_diff equal fs x y =
let sd =
Sequence . to_list ( Map . symmetric_diff ~ data_equal : equal x y )
in
if not ( List . is_empty sd ) then
Format . fprintf fs " %s[@[<hv>%a@]];@ " pre
( List . pp " ;@ " pp_elt_diff )
sd
in
let pp_sdiff_term fs ( u , v ) =
Format . fprintf fs " -- %a ++ %a " Term . pp u Term . pp v
in
pp_sdiff_map ( pp_sdiff_elt Term . pp Term . pp pp_sdiff_term ) Term . equal
let empty = Term . Map . empty
let length = Map . length
let mem = Map . mem
let fold = Map . fold
let iteri = Map . iteri
let for_alli = Map . for_alli
let to_alist = Map . to_alist ~ key_order : ` Increasing
(* * look up a term in a substitution *)
let apply s a = Map . find s a | > Option . value ~ default : a
(* * apply a substitution to maximal non-interpreted subterms *)
let rec norm s a =
match classify a with
| Interpreted -> Term . map ~ f : ( norm s ) a
| Simplified -> apply s ( Term . map ~ f : ( norm s ) a )
| Atomic | Uninterpreted -> apply s a
(* * compose two substitutions *)
let compose r s =
let r' = Map . map ~ f : ( norm s ) r in
Map . merge_skewed r' s ~ combine : ( fun ~ key v1 v2 ->
if Term . equal v1 v2 then v1
else fail " domains intersect: %a " Term . pp key () )
(* * compose a substitution with a mapping *)
let compose1 ~ key ~ data s =
if Term . equal key data then s
else compose s ( Map . set Term . Map . empty ~ key ~ data )
(* * add an identity entry if the term is not already present *)
let extend e s =
let exception Found in
match
Map . update s e ~ f : ( function
| Some _ -> Exn . raise_without_backtrace Found
| None -> e )
with
| exception Found -> None
| s -> Some s
(* * map over a subst, applying [f] to both domain and range, requires that
[ f ] is injective and for any set of terms [ E ] , [ f \ [ E \ ] ] is disjoint
from [ E ] * )
let map_entries ~ f s =
Map . fold s ~ init : s ~ f : ( fun ~ key ~ data s ->
let key' = f key in
let data' = f data in
if Term . equal key' key then
if Term . equal data' data then s else Map . set s ~ key ~ data : data'
else Map . remove s key | > Map . add_exn ~ key : key' ~ data : data' )
end
(* * Theory Solver *)
let rec is_constant e =
@ -32,16 +136,6 @@ let rec is_constant e =
Qset . for_all ~ f : ( fun arg _ -> is_constant arg ) args
| Label _ | Float _ | Integer _ -> true
type kind = Interpreted | Simplified | Atomic | Uninterpreted
[ @@ deriving compare ]
let classify e =
match ( e : Term . t ) with
| Add _ | Mul _ -> Interpreted
| Ap2 ( ( Eq | Dq ) , _ , _ ) -> Simplified
| Ap1 _ | Ap2 _ | Ap3 _ | ApN _ -> Uninterpreted
| RecN _ | Var _ | Integer _ | Float _ | Nondet _ | Label _ -> Atomic
let solve e f =
[ % Trace . call fun { pf } -> pf " %a@ %a " Term . pp e Term . pp f ]
;
@ -53,13 +147,13 @@ let solve e f =
match ( is_constant e , is_constant f ) with
(* orient equation to discretionarily prefer term that is constant
or compares smaller as class representative * )
| true , false -> Some ( Map. add_exn s ~ key : f ~ data : e )
| false , true -> Some ( Map. add_exn s ~ key : e ~ data : f )
| true , false -> Some ( Subst. compose1 ~ key : f ~ data : e s )
| false , true -> Some ( Subst. compose1 ~ key : e ~ data : f s )
| _ ->
let key , data =
if Term . compare e f > 0 then ( e , f ) else ( f , e )
in
Some ( Map. add_exn s ~ key ~ data ) )
Some ( Subst. compose1 ~ key ~ data s ) )
in
let concat_size args =
Vector . fold_until args ~ init : Term . zero
@ -73,7 +167,7 @@ let solve e f =
| ( Add _ | Mul _ | Integer _ ) , _ | _ , ( Add _ | Mul _ | Integer _ ) -> (
let e_f = Term . sub e f in
match Term . solve_zero_eq e_f with
| Some ( key , data ) -> Some ( Map. add_exn s ~ key ~ data )
| Some ( key , data ) -> Some ( Subst. compose1 ~ key ~ data s )
| None -> solve_uninterp e_f Term . zero )
| ApN ( Concat , ms ) , ApN ( Concat , ns ) -> (
match ( concat_size ms , concat_size ns ) with
@ -86,35 +180,32 @@ let solve e f =
| _ -> solve_uninterp e f )
| _ -> solve_uninterp e f
in
solve_ e f empty _map
solve_ e f Subst . empty
| >
[ % Trace . retn fun { pf } ->
function Some s -> pf " %a " pp_subst s | None -> pf " false " ]
function Some s -> pf " %a " Subst . pp s | None -> pf " false " ]
(* * Equality Relations *)
(* * see also [invariant] *)
type t =
{ sat : bool (* * [false] only if constraints are inconsistent *)
; rep : subs t
; rep : Subst . t
(* * functional set of oriented equations: map [a] to [a'],
indicating that [ a = a' ] holds , and that [ a' ] is the
' rep ( resentative ) ' of [ a ] * ) }
[ @@ deriving compare , equal , sexp ]
(* * apply a subst to a term *)
let apply s a = Map . find s a | > Option . value ~ default : a
let classes r =
let add key data cls =
if Term . equal key data then cls
else Map . add_multi cls ~ key : data ~ data : key
in
Map . fold r . rep ~ init : empty _map ~ f : ( fun ~ key ~ data cls ->
Subst . fold r . rep ~ init : Term . Map . empty ~ f : ( fun ~ key ~ data cls ->
match classify key with
| Interpreted | Atomic -> add key data cls
| Simplified | Uninterpreted ->
add ( Term . map ~ f : ( apply r . rep ) key ) data cls )
add ( Term . map ~ f : ( Subst . apply r . rep ) key ) data cls )
(* * Pretty-printing *)
@ -128,41 +219,20 @@ let pp fs {sat; rep} =
let pp_term_v fs ( k , v ) = if not ( Term . equal k v ) then Term . pp fs v in
Format . fprintf fs " @[{@[<hv>sat= %b;@ rep= %a@]}@] " sat
( pp_alist Term . pp pp_term_v )
( Map . to_alist rep )
( Subst . to_alist rep )
let pp_diff fs ( r , s ) =
let pp_sdiff_map pp_elt_diff equal nam fs x y =
let sd = Sequence . to_list ( Map . symmetric_diff ~ data_equal : equal x y ) in
if not ( List . is_empty sd ) then
Format . fprintf fs " %s= [@[<hv>%a@]];@ " nam
( List . pp " ;@ " pp_elt_diff )
sd
in
let pp_sdiff_elt pp_key pp_val pp_sdiff_val fs = function
| k , ` Left v ->
Format . fprintf fs " -- [@[%a@ @<2>↦ %a@]] " pp_key k pp_val v
| k , ` Right v ->
Format . fprintf fs " ++ [@[%a@ @<2>↦ %a@]] " pp_key k pp_val v
| k , ` Unequal vv ->
Format . fprintf fs " [@[%a@ @<2>↦ %a@]] " pp_key k pp_sdiff_val vv
in
let pp_sdiff_term_map =
let pp_sdiff_term fs ( u , v ) =
Format . fprintf fs " -- %a ++ %a " Term . pp u Term . pp v
in
pp_sdiff_map ( pp_sdiff_elt Term . pp Term . pp pp_sdiff_term ) Term . equal
in
let pp_sat fs =
if not ( Bool . equal r . sat s . sat ) then
Format . fprintf fs " sat= @[-- %b@ ++ %b@];@ " r . sat s . sat
in
let pp_rep fs = pp_sdiff _term_map " rep " fs r . rep s . rep in
let pp_rep fs = Subst . pp_sdiff ~ pre : " rep= " fs r . rep s . rep in
Format . fprintf fs " @[{@[<hv>%t%t@]}@] " pp_sat pp_rep
(* * Invariant *)
(* * test membership in carrier *)
let in_car r e = Map . mem r . rep e
let in_car r e = Subst . mem r . rep e
let rec iter_max_solvables e ~ f =
match classify e with
@ -172,7 +242,7 @@ let rec iter_max_solvables e ~f =
let invariant r =
Invariant . invariant [ % here ] r [ % sexp_of : t ]
@@ fun () ->
Map . iteri r . rep ~ f : ( fun ~ key : a ~ data : _ ->
Subst . iteri r . rep ~ f : ( fun ~ key : a ~ data : _ ->
(* no interpreted terms in carrier *)
assert ( Poly . ( classify a < > Interpreted ) ) ;
(* carrier is closed under subterms *)
@ -184,26 +254,23 @@ let invariant r =
(* * Core operations *)
let true _ = { sat = true ; rep = empty_map } | > check invariant
(* * apply a subst to maximal non-interpreted subterms *)
let rec norm s a =
match classify a with
| Interpreted -> Term . map ~ f : ( norm s ) a
| Simplified -> apply s ( Term . map ~ f : ( norm s ) a )
| Atomic | Uninterpreted -> apply s a
let true _ = { sat = true ; rep = Subst . empty } | > check invariant
(* * terms are congruent if equal after normalizing subterms *)
let congruent r a b =
Term . equal ( Term . map ~ f : ( norm r . rep ) a ) ( Term . map ~ f : ( norm r . rep ) b )
Term . equal
( Term . map ~ f : ( Subst . norm r . rep ) a )
( Term . map ~ f : ( Subst . norm r . rep ) b )
(* * [lookup r a] is [b'] if [a ~ b = b'] for some equation [b = b'] in rep *)
let lookup r a =
With_return . with_return
@@ fun { return } ->
(* congruent specialized to assume [a] canonized and [b] non-interpreted *)
let semi_congruent r a b = Term . equal a ( Term . map ~ f : ( apply r . rep ) b ) in
Map . iteri r . rep ~ f : ( fun ~ key ~ data ->
let semi_congruent r a b =
Term . equal a ( Term . map ~ f : ( Subst . apply r . rep ) b )
in
Subst . iteri r . rep ~ f : ( fun ~ key ~ data ->
if semi_congruent r a key then return data ) ;
a
@ -213,35 +280,25 @@ let rec canon r a =
match classify a with
| Interpreted -> Term . map ~ f : ( canon r ) a
| Simplified | Uninterpreted -> lookup r ( Term . map ~ f : ( canon r ) a )
| Atomic -> apply r . rep a
| Atomic -> Subst . apply r . rep a
(* * add a term to the carrier *)
let rec extend a r =
match classify a with
| Interpreted | Simplified -> Term . fold ~ f : extend a ~ init : r
| Uninterpreted ->
Map . find_or_add r . rep a
~ if_found : ( fun _ -> r )
~ default : a
~ if_added : ( fun rep -> Term . fold ~ f : extend a ~ init : { r with rep } )
| Uninterpreted -> (
match Subst . extend a r . rep with
| Some rep -> Term . fold ~ f : extend a ~ init : { r with rep }
| None -> r )
| Atomic -> r
let extend a r = extend a r | > check invariant
let compose r s =
let rep = Map . map ~ f : ( norm s ) r . rep in
let rep =
Map . merge_skewed rep s ~ combine : ( fun ~ key v1 v2 ->
if Term . equal v1 v2 then v1
else fail " domains intersect: %a " Term . pp key () )
in
{ r with rep }
let merge a b r =
[ % Trace . call fun { pf } -> pf " %a@ %a@ %a " Term . pp a Term . pp b pp r ]
;
( match solve a b with
| Some s -> compose r s
| Some s -> { r with rep = Subst . compose r . rep s }
| None -> { r with sat = false } )
| >
[ % Trace . retn fun { pf } r' ->
@ -252,8 +309,8 @@ let merge a b r =
let find_missing r =
With_return . with_return
@@ fun { return } ->
Map . iteri r . rep ~ f : ( fun ~ key : a ~ data : a' ->
Map . iteri r . rep ~ f : ( fun ~ key : b ~ data : b' ->
Subst . iteri r . rep ~ f : ( fun ~ key : a ~ data : a' ->
Subst . iteri r . rep ~ f : ( fun ~ key : b ~ data : b' ->
if
Term . compare a b < 0
&& ( not ( Term . equal a' b' ) )
@ -295,13 +352,13 @@ let and_eq a b r =
invariant r' ]
let is_true { sat ; rep } =
sat && Map . for_alli rep ~ f : ( fun ~ key : a ~ data : a' -> Term . equal a a' )
sat && Subst . for_alli rep ~ f : ( fun ~ key : a ~ data : a' -> Term . equal a a' )
let is_false { sat } = not sat
let entails_eq r d e = Term . equal ( canon r d ) ( canon r e )
let entails r s =
Map . for_alli s . rep ~ f : ( fun ~ key : e ~ data : e' -> entails_eq r e e' )
Subst . for_alli s . rep ~ f : ( fun ~ key : e ~ data : e' -> entails_eq r e e' )
let normalize = canon
@ -328,9 +385,9 @@ let and_ r s =
else if not s . sat then s
else
let s , r =
if Map. length s . rep < = Map . length r . rep then ( s , r ) else ( r , s )
if Subst. length s . rep < = Subst . length r . rep then ( s , r ) else ( r , s )
in
Map . fold s . rep ~ init : r ~ f : ( fun ~ key : e ~ data : e' r -> and_eq e e' r )
Subst . fold s . rep ~ init : r ~ f : ( fun ~ key : e ~ data : e' r -> and_eq e e' r )
let or_ r s =
[ % Trace . call fun { pf } -> pf " @[<hv 1> %a@ @<2>∨ %a@] " pp r pp s ]
@ -355,30 +412,18 @@ let or_ r s =
| >
[ % Trace . retn fun { pf } -> pf " %a " pp ]
(* assumes that f is injective and for any set of terms E, f[E] is disjoint
from E * )
let map_terms ( { sat = _ ; rep } as r ) ~ f =
let rename r sub =
[ % Trace . call fun { pf } -> pf " %a " pp r ]
;
let map m =
Map . fold m ~ init : m ~ f : ( fun ~ key ~ data m ->
let key' = f key in
let data' = f data in
if Term . equal key' key then
if Term . equal data' data then m else Map . set m ~ key ~ data : data'
else Map . remove m key | > Map . add_exn ~ key : key' ~ data : data' )
in
let rep' = map rep in
( if rep' = = rep then r else { r with rep = rep' } )
let rep = Subst . map_entries ~ f : ( Term . rename sub ) r . rep in
( if rep = = r . rep then r else { r with rep } )
| >
[ % Trace . retn fun { pf } r' ->
pf " %a " pp_diff ( r , r' ) ;
invariant r' ]
let rename r sub = map_terms r ~ f : ( Term . rename sub )
let fold_terms r ~ init ~ f =
Map . fold r . rep ~ f : ( fun ~ key ~ data z -> f ( f z data ) key ) ~ init
Subst . fold r . rep ~ f : ( fun ~ key ~ data z -> f ( f z data ) key ) ~ init
let fold_vars r ~ init ~ f =
fold_terms r ~ init ~ f : ( fun init -> Term . fold_vars ~ f ~ init )