@ -177,33 +177,35 @@ end
(* * Theory Solver *)
(* * orient equations s.t. Var < Memory < Extract < Concat < others, then
using height of aggregate nesting , and then using Term . compare * )
let orient e f =
let compare e f =
let rank e =
match ( e : Term . t ) with
| Var _ -> 0
| Ap2 ( Memory , _ , _ ) -> 1
| Ap3 ( Extract , _ , _ , _ ) -> 2
| ApN ( Concat , _ ) -> 3
| _ -> 4
in
let rec height e =
match ( e : Term . t ) with
| Ap2 ( Memory , _ , x ) -> 1 + height x
| Ap3 ( Extract , x , _ , _ ) -> 1 + height x
| ApN ( Concat , xs ) ->
1 + IArray . fold ~ init : 0 ~ f : ( fun h x -> max h ( height x ) ) xs
| _ -> 0
in
let o = compare ( rank e ) ( rank f ) in
if o < > 0 then o
else
let o = compare ( height e ) ( height f ) in
if o < > 0 then o else Term . compare e f
(* * prefer representative terms that are minimal in the order s.t. Var <
Memory < Extract < Concat < others , then using height of aggregate
nesting , and then using Term . compare * )
let prefer e f =
let rank e =
match ( e : Term . t ) with
| Var _ -> 0
| Ap2 ( Memory , _ , _ ) -> 1
| Ap3 ( Extract , _ , _ , _ ) -> 2
| ApN ( Concat , _ ) -> 3
| _ -> 4
in
match Ordering . of_int ( compare e f ) with
let rec height e =
match ( e : Term . t ) with
| Ap2 ( Memory , _ , x ) -> 1 + height x
| Ap3 ( Extract , x , _ , _ ) -> 1 + height x
| ApN ( Concat , xs ) ->
1 + IArray . fold ~ init : 0 ~ f : ( fun h x -> max h ( height x ) ) xs
| _ -> 0
in
let o = compare ( rank e ) ( rank f ) in
if o < > 0 then o
else
let o = compare ( height e ) ( height f ) in
if o < > 0 then o else Term . compare e f
(* * orient equations based on representative preference *)
let orient e f =
match Ordering . of_int ( prefer e f ) with
| Less -> Some ( e , f )
| Equal -> None
| Greater -> Some ( f , e )
@ -402,23 +404,44 @@ let pp_diff_clss =
Term . Map . pp_diff ~ data_equal : ( List . equal Term . equal ) Term . pp pp_cls
pp_diff_cls
(* * Invariant *)
(* * Basic queries *)
(* * test membership in carrier *)
let in_car r e = Subst . mem r . rep e
let invariant r =
(* * terms are congruent if equal after normalizing subterms *)
let congruent r a b =
Term . equal
( Term . map ~ f : ( Subst . norm r . rep ) a )
( Term . map ~ f : ( Subst . norm r . rep ) b )
(* * Invariant *)
let pre_invariant r =
Invariant . invariant [ % here ] r [ % sexp_of : t ]
@@ fun () ->
Subst . iteri r . rep ~ f : ( fun ~ key : a ~ data : _ ->
Subst . iteri r . rep ~ f : ( fun ~ key : trm ~ data : _ ->
(* no interpreted terms in carrier *)
assert ( non_interpreted a ) ;
assert ( non_interpreted trm | | f ail " non-interp %a " Term . pp trm () ) ;
(* carrier is closed under subterms *)
iter_max_solvables a ~ f : ( fun b ->
iter_max_solvables trm ~ f : ( fun su btrm ->
assert (
in_car r b
| | fail " @[subterm %a of %a not in carrier of@ %a@] " Term . pp b
Term . pp a pp r () ) ) )
in_car r subtrm
| | fail " @[subterm %a of %a not in carrier of@ %a@] " Term . pp
subtrm Term . pp trm pp r () ) ) )
let invariant r =
Invariant . invariant [ % here ] r [ % sexp_of : t ]
@@ fun () ->
pre_invariant r ;
assert (
( not r . sat )
| | Subst . for_alli r . rep ~ f : ( fun ~ key : a ~ data : a' ->
Subst . for_alli r . rep ~ f : ( fun ~ key : b ~ data : b' ->
Term . compare a b > = 0
| | congruent r a b = = > Term . equal a' b'
| | fail " not congruent %a@ %a@ in@ %a " Term . pp a Term . pp b pp
r () ) ) )
(* * Core operations *)
@ -427,15 +450,9 @@ let true_ =
let false _ = { true _ with sat = false }
(* * terms are congruent if equal after normalizing subterms *)
let congruent r a b =
Term . equal
( Term . map ~ f : ( Subst . norm r . rep ) a )
( Term . map ~ f : ( Subst . norm r . rep ) b )
(* * [lookup r a] is [b'] if [a ~ b = b'] for some equation [b = b'] in rep *)
let lookup r a =
[ % Trace . call fun { pf } -> pf " %a @ %a " Term . pp a pp r ]
[ % Trace . call fun { pf } -> pf " %a " Term . pp a ]
;
( with_return
@@ fun { return } ->
@ -452,7 +469,7 @@ let lookup r a =
(* * rewrite a term into canonical form using rep and, for non-interpreted
terms , congruence composed with rep * )
let rec canon r a =
[ % Trace . call fun { pf } -> pf " %a @ %a " Term . pp a pp r ]
[ % Trace . call fun { pf } -> pf " %a " Term . pp a ]
;
( match classify a with
| Atomic -> Subst . apply r . rep a
@ -478,7 +495,7 @@ let rec extend_ a r =
(* * add a term to the carrier *)
let extend a r =
let rep = extend_ a r . rep in
if rep = = r . rep then r else { r with rep } | > check invariant
if rep = = r . rep then r else { r with rep } | > check pre_ invariant
let merge us a b r =
[ % Trace . call fun { pf } -> pf " %a@ %a@ %a " Term . pp a Term . pp b pp r ]
@ -490,7 +507,7 @@ let merge us a b r =
| >
[ % Trace . retn fun { pf } r' ->
pf " %a " pp_diff ( r , r' ) ;
invariant r' ]
pre_ invariant r' ]
(* * find an unproved equation between congruent terms *)
let find_missing r =
@ -521,7 +538,7 @@ let close us r =
pf " %a " pp_diff ( r , r' ) ;
invariant r' ]
let and_eq us a b r =
let and_eq _ us a b r =
if not r . sat then r
else
let a' = canon r a in
@ -538,7 +555,13 @@ let is_true {sat; rep} =
sat && Subst . for_alli rep ~ f : ( fun ~ key : a ~ data : a' -> Term . equal a a' )
let is_false { sat } = not sat
let entails_eq r d e = Term . is_true ( canon r ( Term . eq d e ) )
let entails_eq r d e =
[ % Trace . call fun { pf } -> pf " %a = %a@ %a " Term . pp d Term . pp e pp r ]
;
Term . is_true ( canon r ( Term . eq d e ) )
| >
[ % Trace . retn fun { pf } -> pf " %b " ]
let entails r s =
Subst . for_alli s . rep ~ f : ( fun ~ key : e ~ data : e' -> entails_eq r e e' )
@ -584,21 +607,29 @@ let apply_subst us s r =
let rep' = Subst . subst s rep in
List . fold cls ~ init : r ~ f : ( fun r trm ->
let trm' = Subst . subst s trm in
and_eq us trm' rep' r ) )
and_eq _ us trm' rep' r ) )
| > extract_xs
| >
[ % Trace . retn fun { pf } ( xs , r' ) -> pf " %a%a " Var . Set . pp_xs xs pp r' ]
[ % Trace . retn fun { pf } ( xs , r' ) ->
pf " %a%a " Var . Set . pp_xs xs pp_diff ( r , r' ) ;
invariant r' ]
let and_ us r s =
[ % Trace . call fun { pf } -> pf " @[<hv 1> %a@ @<2>∧ %a@] " pp r pp s ]
;
( if not r . sat then r
else if not s . sat then s
else
let s , r =
if Subst . length s . rep < = Subst . length r . rep then ( s , r ) else ( r , s )
in
Subst . fold s . rep ~ init : r ~ f : ( fun ~ key : e ~ data : e' r -> and_eq us e e' r )
Subst . fold s . rep ~ init : r ~ f : ( fun ~ key : e ~ data : e' r -> and_eq _ us e e' r )
)
| > extract_xs
| >
[ % Trace . retn fun { pf } ( _ , r' ) ->
pf " %a " pp_diff ( r , r' ) ;
invariant r' ]
let or_ us r s =
[ % Trace . call fun { pf } -> pf " @[<hv 1> %a@ @<2>∨ %a@] " pp r pp s ]
@ -612,7 +643,7 @@ let or_ us r s =
~ init : ( [ rep ] , rs )
~ f : ( fun ( reps , rs ) exp ->
match List . find ~ f : ( entails_eq r exp ) reps with
| Some rep -> ( reps , and_eq us exp rep rs )
| Some rep -> ( reps , and_eq _ us exp rep rs )
| None -> ( exp :: reps , rs ) )
| > snd )
in
@ -622,7 +653,9 @@ let or_ us r s =
rs )
| > extract_xs
| >
[ % Trace . retn fun { pf } ( _ , r ) -> pf " %a " pp r ]
[ % Trace . retn fun { pf } ( _ , r' ) ->
pf " %a " pp_diff ( r , r' ) ;
invariant r' ]
let orN us rs =
match rs with
@ -630,21 +663,28 @@ let orN us rs =
| r :: rs -> List . fold ~ f : ( fun ( us , s ) r -> or_ us s r ) ~ init : ( us , r ) rs
let rec and_term_ us e r =
let eq_false b r = and_eq us b Term . false_ r in
let eq_false b r = and_eq _ us b Term . false_ r in
match ( e : Term . t ) with
| Integer { data } -> if Z . is_false data then false _ else r
| Ap2 ( And , a , b ) -> and_term_ us a ( and_term_ us b r )
| Ap2 ( Eq , a , b ) -> and_eq us a b r
| Ap2 ( Eq , a , b ) -> and_eq _ us a b r
| Ap2 ( Xor , Integer { data } , a ) when Z . is_true data -> eq_false a r
| Ap2 ( Xor , a , Integer { data } ) when Z . is_true data -> eq_false a r
| _ -> r
let and_term us e r = and_term_ us e r | > extract_xs
let and_term us e r =
[ % Trace . call fun { pf } -> pf " %a@ %a " Term . pp e pp r ]
;
and_term_ us e r | > extract_xs
| >
[ % Trace . retn fun { pf } ( _ , r' ) ->
pf " %a " pp_diff ( r , r' ) ;
invariant r' ]
let and_eq us a b r =
[ % Trace . call fun { pf } -> pf " %a = %a@ %a " Term . pp a Term . pp b pp r ]
;
and_eq us a b r | > extract_xs
and_eq _ us a b r | > extract_xs
| >
[ % Trace . retn fun { pf } ( _ , r' ) ->
pf " %a " pp_diff ( r , r' ) ;
@ -710,6 +750,8 @@ type 'a zom = Zero | One of 'a | Many
[ fv kill ⊈ us ] ; solve [ p = q ] for [ kill ] ; extend subst mapping [ kill ]
to the solution * )
let solve_poly_eq us p' q' subst =
[ % Trace . call fun { pf } -> pf " %a = %a " Term . pp p' Term . pp q' ]
;
let diff = Term . sub p' q' in
let max_solvables_not_ito_us =
fold_max_solvables diff ~ init : Zero ~ f : ( fun solvable_subterm -> function
@ -718,11 +760,14 @@ let solve_poly_eq us p' q' subst =
| One _ -> Many
| Zero -> One solvable_subterm )
in
match max_solvables_not_ito_us with
( match max_solvables_not_ito_us with
| One kill ->
let + kill , keep = Term . solve_zero_eq diff ~ for_ : kill in
Subst . compose1 ~ key : kill ~ data : keep subst
| Many | Zero -> None
| Many | Zero -> None )
| >
[ % Trace . retn fun { pf } subst' ->
pf " @[%a@] " Subst . pp_diff ( subst , Option . value subst' ~ default : subst ) ]
let solve_memory_eq us e' f' subst =
[ % Trace . call fun { pf } -> pf " %a = %a " Term . pp e' Term . pp f' ]
@ -1001,7 +1046,9 @@ let pp_vss fs vss =
[ fv u ⊆ ⋃ ⱼ ₌ ₁ ⁱ ⁻ ¹ v ⱼ ] if possible and otherwise
[ fv u ⊆ ⋃ ⱼ ₌ ₁ ⁱ v ⱼ ] * )
let solve_for_vars vss r =
[ % Trace . call fun { pf } -> pf " %a@ @[%a@] " pp_vss vss pp_classes r ]
[ % Trace . call fun { pf } ->
pf " %a@ @[%a@]@ @[%a@] " pp_vss vss pp_classes r pp r ;
invariant r ]
;
let us , vss =
match vss with us :: vss -> ( us , vss ) | [] -> ( Var . Set . empty , vss )
@ -1014,8 +1061,8 @@ let solve_for_vars vss r =
Subst . iteri subst ~ f : ( fun ~ key ~ data ->
assert (
entails_eq r key data
| | fail " @[%a = %a not entailed by@ %a@]" Term . pp key Term . pp data
pp_classes r () ) ;
| | fail " @[%a @ = %a@ not entailed by@ @[ %a@]@] " Term . pp key
Term . pp data pp_classes r () ) ;
assert (
List . fold_until vss ~ init : us
~ f : ( fun us xs ->