diff --git a/sledge/src/symbheap/equality.ml b/sledge/src/symbheap/equality.ml index 37a50f6a0..718cb7ffc 100644 --- a/sledge/src/symbheap/equality.ml +++ b/sledge/src/symbheap/equality.ml @@ -7,17 +7,121 @@ (** Equality over uninterpreted functions and linear rational arithmetic *) -type 'a term_map = 'a Map.M(Term).t [@@deriving compare, equal, sexp] +(** Classification of Terms by Theory *) -let empty_map = Map.empty (module Term) +type kind = Interpreted | Simplified | Atomic | Uninterpreted +[@@deriving compare] + +let classify e = + match (e : Term.t) with + | Add _ | Mul _ -> Interpreted + | Ap2 ((Eq | Dq), _, _) -> Simplified + | Ap1 _ | Ap2 _ | Ap3 _ | ApN _ -> Uninterpreted + | RecN _ | Var _ | Integer _ | Float _ | Nondet _ | Label _ -> Atomic + +(** Solution Substitutions *) +module Subst : sig + type t [@@deriving compare, equal, sexp] + + val pp : t pp + val pp_sdiff : ?pre:string -> Format.formatter -> t -> t -> unit + val empty : t + val length : t -> int + val mem : t -> Term.t -> bool + val fold : t -> init:'a -> f:(key:Term.t -> data:Term.t -> 'a -> 'a) -> 'a + val iteri : t -> f:(key:Term.t -> data:Term.t -> unit) -> unit + val for_alli : t -> f:(key:Term.t -> data:Term.t -> bool) -> bool + val apply : t -> Term.t -> Term.t + val norm : t -> Term.t -> Term.t + val compose : t -> t -> t + val compose1 : key:Term.t -> data:Term.t -> t -> t + val extend : Term.t -> t -> t option + val map_entries : f:(Term.t -> Term.t) -> t -> t + val to_alist : t -> (Term.t * Term.t) list +end = struct + type t = Term.t Term.Map.t [@@deriving compare, equal, sexp] + + let pp fs s = + Format.fprintf fs "@[<1>[%a]@]" + (List.pp ",@ " (fun fs (k, v) -> + Format.fprintf fs "@[%a ↦ %a@]" Term.pp k Term.pp v )) + (Map.to_alist s) + + let pp_sdiff ?(pre = "") = + let pp_sdiff_elt pp_key pp_val pp_sdiff_val fs = function + | k, `Left v -> + Format.fprintf fs "-- [@[%a@ @<2>↦ %a@]]" pp_key k pp_val v + | k, `Right v -> + Format.fprintf fs "++ [@[%a@ @<2>↦ %a@]]" pp_key k pp_val v + | k, `Unequal vv -> + Format.fprintf fs "[@[%a@ @<2>↦ %a@]]" pp_key k pp_sdiff_val vv + in + let pp_sdiff_map pp_elt_diff equal fs x y = + let sd = + Sequence.to_list (Map.symmetric_diff ~data_equal:equal x y) + in + if not (List.is_empty sd) then + Format.fprintf fs "%s[@[%a@]];@ " pre + (List.pp ";@ " pp_elt_diff) + sd + in + let pp_sdiff_term fs (u, v) = + Format.fprintf fs "-- %a ++ %a" Term.pp u Term.pp v + in + pp_sdiff_map (pp_sdiff_elt Term.pp Term.pp pp_sdiff_term) Term.equal -type subst = Term.t term_map [@@deriving compare, equal, sexp] + let empty = Term.Map.empty + let length = Map.length + let mem = Map.mem + let fold = Map.fold + let iteri = Map.iteri + let for_alli = Map.for_alli + let to_alist = Map.to_alist ~key_order:`Increasing + + (** look up a term in a substitution *) + let apply s a = Map.find s a |> Option.value ~default:a + + (** apply a substitution to maximal non-interpreted subterms *) + let rec norm s a = + match classify a with + | Interpreted -> Term.map ~f:(norm s) a + | Simplified -> apply s (Term.map ~f:(norm s) a) + | Atomic | Uninterpreted -> apply s a + + (** compose two substitutions *) + let compose r s = + let r' = Map.map ~f:(norm s) r in + Map.merge_skewed r' s ~combine:(fun ~key v1 v2 -> + if Term.equal v1 v2 then v1 + else fail "domains intersect: %a" Term.pp key () ) -let pp_subst fs s = - Format.fprintf fs "@[<1>[%a]@]" - (List.pp ",@ " (fun fs (k, v) -> - Format.fprintf fs "@[%a ↦ %a@]" Term.pp k Term.pp v )) - (Map.to_alist s) + (** compose a substitution with a mapping *) + let compose1 ~key ~data s = + if Term.equal key data then s + else compose s (Map.set Term.Map.empty ~key ~data) + + (** add an identity entry if the term is not already present *) + let extend e s = + let exception Found in + match + Map.update s e ~f:(function + | Some _ -> Exn.raise_without_backtrace Found + | None -> e ) + with + | exception Found -> None + | s -> Some s + + (** map over a subst, applying [f] to both domain and range, requires that + [f] is injective and for any set of terms [E], [f\[E\]] is disjoint + from [E] *) + let map_entries ~f s = + Map.fold s ~init:s ~f:(fun ~key ~data s -> + let key' = f key in + let data' = f data in + if Term.equal key' key then + if Term.equal data' data then s else Map.set s ~key ~data:data' + else Map.remove s key |> Map.add_exn ~key:key' ~data:data' ) +end (** Theory Solver *) @@ -32,16 +136,6 @@ let rec is_constant e = Qset.for_all ~f:(fun arg _ -> is_constant arg) args | Label _ | Float _ | Integer _ -> true -type kind = Interpreted | Simplified | Atomic | Uninterpreted -[@@deriving compare] - -let classify e = - match (e : Term.t) with - | Add _ | Mul _ -> Interpreted - | Ap2 ((Eq | Dq), _, _) -> Simplified - | Ap1 _ | Ap2 _ | Ap3 _ | ApN _ -> Uninterpreted - | RecN _ | Var _ | Integer _ | Float _ | Nondet _ | Label _ -> Atomic - let solve e f = [%Trace.call fun {pf} -> pf "%a@ %a" Term.pp e Term.pp f] ; @@ -53,13 +147,13 @@ let solve e f = match (is_constant e, is_constant f) with (* orient equation to discretionarily prefer term that is constant or compares smaller as class representative *) - | true, false -> Some (Map.add_exn s ~key:f ~data:e) - | false, true -> Some (Map.add_exn s ~key:e ~data:f) + | true, false -> Some (Subst.compose1 ~key:f ~data:e s) + | false, true -> Some (Subst.compose1 ~key:e ~data:f s) | _ -> let key, data = if Term.compare e f > 0 then (e, f) else (f, e) in - Some (Map.add_exn s ~key ~data) ) + Some (Subst.compose1 ~key ~data s) ) in let concat_size args = Vector.fold_until args ~init:Term.zero @@ -73,7 +167,7 @@ let solve e f = | (Add _ | Mul _ | Integer _), _ | _, (Add _ | Mul _ | Integer _) -> ( let e_f = Term.sub e f in match Term.solve_zero_eq e_f with - | Some (key, data) -> Some (Map.add_exn s ~key ~data) + | Some (key, data) -> Some (Subst.compose1 ~key ~data s) | None -> solve_uninterp e_f Term.zero ) | ApN (Concat, ms), ApN (Concat, ns) -> ( match (concat_size ms, concat_size ns) with @@ -86,35 +180,32 @@ let solve e f = | _ -> solve_uninterp e f ) | _ -> solve_uninterp e f in - solve_ e f empty_map + solve_ e f Subst.empty |> [%Trace.retn fun {pf} -> - function Some s -> pf "%a" pp_subst s | None -> pf "false"] + function Some s -> pf "%a" Subst.pp s | None -> pf "false"] (** Equality Relations *) (** see also [invariant] *) type t = { sat: bool (** [false] only if constraints are inconsistent *) - ; rep: subst + ; rep: Subst.t (** functional set of oriented equations: map [a] to [a'], indicating that [a = a'] holds, and that [a'] is the 'rep(resentative)' of [a] *) } [@@deriving compare, equal, sexp] -(** apply a subst to a term *) -let apply s a = Map.find s a |> Option.value ~default:a - let classes r = let add key data cls = if Term.equal key data then cls else Map.add_multi cls ~key:data ~data:key in - Map.fold r.rep ~init:empty_map ~f:(fun ~key ~data cls -> + Subst.fold r.rep ~init:Term.Map.empty ~f:(fun ~key ~data cls -> match classify key with | Interpreted | Atomic -> add key data cls | Simplified | Uninterpreted -> - add (Term.map ~f:(apply r.rep) key) data cls ) + add (Term.map ~f:(Subst.apply r.rep) key) data cls ) (** Pretty-printing *) @@ -128,41 +219,20 @@ let pp fs {sat; rep} = let pp_term_v fs (k, v) = if not (Term.equal k v) then Term.pp fs v in Format.fprintf fs "@[{@[sat= %b;@ rep= %a@]}@]" sat (pp_alist Term.pp pp_term_v) - (Map.to_alist rep) + (Subst.to_alist rep) let pp_diff fs (r, s) = - let pp_sdiff_map pp_elt_diff equal nam fs x y = - let sd = Sequence.to_list (Map.symmetric_diff ~data_equal:equal x y) in - if not (List.is_empty sd) then - Format.fprintf fs "%s= [@[%a@]];@ " nam - (List.pp ";@ " pp_elt_diff) - sd - in - let pp_sdiff_elt pp_key pp_val pp_sdiff_val fs = function - | k, `Left v -> - Format.fprintf fs "-- [@[%a@ @<2>↦ %a@]]" pp_key k pp_val v - | k, `Right v -> - Format.fprintf fs "++ [@[%a@ @<2>↦ %a@]]" pp_key k pp_val v - | k, `Unequal vv -> - Format.fprintf fs "[@[%a@ @<2>↦ %a@]]" pp_key k pp_sdiff_val vv - in - let pp_sdiff_term_map = - let pp_sdiff_term fs (u, v) = - Format.fprintf fs "-- %a ++ %a" Term.pp u Term.pp v - in - pp_sdiff_map (pp_sdiff_elt Term.pp Term.pp pp_sdiff_term) Term.equal - in let pp_sat fs = if not (Bool.equal r.sat s.sat) then Format.fprintf fs "sat= @[-- %b@ ++ %b@];@ " r.sat s.sat in - let pp_rep fs = pp_sdiff_term_map "rep" fs r.rep s.rep in + let pp_rep fs = Subst.pp_sdiff ~pre:"rep= " fs r.rep s.rep in Format.fprintf fs "@[{@[%t%t@]}@]" pp_sat pp_rep (** Invariant *) (** test membership in carrier *) -let in_car r e = Map.mem r.rep e +let in_car r e = Subst.mem r.rep e let rec iter_max_solvables e ~f = match classify e with @@ -172,7 +242,7 @@ let rec iter_max_solvables e ~f = let invariant r = Invariant.invariant [%here] r [%sexp_of: t] @@ fun () -> - Map.iteri r.rep ~f:(fun ~key:a ~data:_ -> + Subst.iteri r.rep ~f:(fun ~key:a ~data:_ -> (* no interpreted terms in carrier *) assert (Poly.(classify a <> Interpreted)) ; (* carrier is closed under subterms *) @@ -184,26 +254,23 @@ let invariant r = (** Core operations *) -let true_ = {sat= true; rep= empty_map} |> check invariant - -(** apply a subst to maximal non-interpreted subterms *) -let rec norm s a = - match classify a with - | Interpreted -> Term.map ~f:(norm s) a - | Simplified -> apply s (Term.map ~f:(norm s) a) - | Atomic | Uninterpreted -> apply s a +let true_ = {sat= true; rep= Subst.empty} |> check invariant (** terms are congruent if equal after normalizing subterms *) let congruent r a b = - Term.equal (Term.map ~f:(norm r.rep) a) (Term.map ~f:(norm r.rep) b) + Term.equal + (Term.map ~f:(Subst.norm r.rep) a) + (Term.map ~f:(Subst.norm r.rep) b) (** [lookup r a] is [b'] if [a ~ b = b'] for some equation [b = b'] in rep *) let lookup r a = With_return.with_return @@ fun {return} -> (* congruent specialized to assume [a] canonized and [b] non-interpreted *) - let semi_congruent r a b = Term.equal a (Term.map ~f:(apply r.rep) b) in - Map.iteri r.rep ~f:(fun ~key ~data -> + let semi_congruent r a b = + Term.equal a (Term.map ~f:(Subst.apply r.rep) b) + in + Subst.iteri r.rep ~f:(fun ~key ~data -> if semi_congruent r a key then return data ) ; a @@ -213,35 +280,25 @@ let rec canon r a = match classify a with | Interpreted -> Term.map ~f:(canon r) a | Simplified | Uninterpreted -> lookup r (Term.map ~f:(canon r) a) - | Atomic -> apply r.rep a + | Atomic -> Subst.apply r.rep a (** add a term to the carrier *) let rec extend a r = match classify a with | Interpreted | Simplified -> Term.fold ~f:extend a ~init:r - | Uninterpreted -> - Map.find_or_add r.rep a - ~if_found:(fun _ -> r) - ~default:a - ~if_added:(fun rep -> Term.fold ~f:extend a ~init:{r with rep}) + | Uninterpreted -> ( + match Subst.extend a r.rep with + | Some rep -> Term.fold ~f:extend a ~init:{r with rep} + | None -> r ) | Atomic -> r let extend a r = extend a r |> check invariant -let compose r s = - let rep = Map.map ~f:(norm s) r.rep in - let rep = - Map.merge_skewed rep s ~combine:(fun ~key v1 v2 -> - if Term.equal v1 v2 then v1 - else fail "domains intersect: %a" Term.pp key () ) - in - {r with rep} - let merge a b r = [%Trace.call fun {pf} -> pf "%a@ %a@ %a" Term.pp a Term.pp b pp r] ; ( match solve a b with - | Some s -> compose r s + | Some s -> {r with rep= Subst.compose r.rep s} | None -> {r with sat= false} ) |> [%Trace.retn fun {pf} r' -> @@ -252,8 +309,8 @@ let merge a b r = let find_missing r = With_return.with_return @@ fun {return} -> - Map.iteri r.rep ~f:(fun ~key:a ~data:a' -> - Map.iteri r.rep ~f:(fun ~key:b ~data:b' -> + Subst.iteri r.rep ~f:(fun ~key:a ~data:a' -> + Subst.iteri r.rep ~f:(fun ~key:b ~data:b' -> if Term.compare a b < 0 && (not (Term.equal a' b')) @@ -295,13 +352,13 @@ let and_eq a b r = invariant r'] let is_true {sat; rep} = - sat && Map.for_alli rep ~f:(fun ~key:a ~data:a' -> Term.equal a a') + sat && Subst.for_alli rep ~f:(fun ~key:a ~data:a' -> Term.equal a a') let is_false {sat} = not sat let entails_eq r d e = Term.equal (canon r d) (canon r e) let entails r s = - Map.for_alli s.rep ~f:(fun ~key:e ~data:e' -> entails_eq r e e') + Subst.for_alli s.rep ~f:(fun ~key:e ~data:e' -> entails_eq r e e') let normalize = canon @@ -328,9 +385,9 @@ let and_ r s = else if not s.sat then s else let s, r = - if Map.length s.rep <= Map.length r.rep then (s, r) else (r, s) + if Subst.length s.rep <= Subst.length r.rep then (s, r) else (r, s) in - Map.fold s.rep ~init:r ~f:(fun ~key:e ~data:e' r -> and_eq e e' r) + Subst.fold s.rep ~init:r ~f:(fun ~key:e ~data:e' r -> and_eq e e' r) let or_ r s = [%Trace.call fun {pf} -> pf "@[ %a@ @<2>∨ %a@]" pp r pp s] @@ -355,30 +412,18 @@ let or_ r s = |> [%Trace.retn fun {pf} -> pf "%a" pp] -(* assumes that f is injective and for any set of terms E, f[E] is disjoint - from E *) -let map_terms ({sat= _; rep} as r) ~f = +let rename r sub = [%Trace.call fun {pf} -> pf "%a" pp r] ; - let map m = - Map.fold m ~init:m ~f:(fun ~key ~data m -> - let key' = f key in - let data' = f data in - if Term.equal key' key then - if Term.equal data' data then m else Map.set m ~key ~data:data' - else Map.remove m key |> Map.add_exn ~key:key' ~data:data' ) - in - let rep' = map rep in - (if rep' == rep then r else {r with rep= rep'}) + let rep = Subst.map_entries ~f:(Term.rename sub) r.rep in + (if rep == r.rep then r else {r with rep}) |> [%Trace.retn fun {pf} r' -> pf "%a" pp_diff (r, r') ; invariant r'] -let rename r sub = map_terms r ~f:(Term.rename sub) - let fold_terms r ~init ~f = - Map.fold r.rep ~f:(fun ~key ~data z -> f (f z data) key) ~init + Subst.fold r.rep ~f:(fun ~key ~data z -> f (f z data) key) ~init let fold_vars r ~init ~f = fold_terms r ~init ~f:(fun init -> Term.fold_vars ~f ~init)