From 9b1ff9c012c3a71a013e54da287a572ea8351fba Mon Sep 17 00:00:00 2001 From: Josh Berdine Date: Wed, 15 Jan 2020 13:17:30 -0800 Subject: [PATCH] [sledge] Factor solution substitutions into Equality.Subst Summary: Equality relies on the result of solving an equation to be a "solution substitution". In constrast to unconstrained Map's, solution substitutions are idempotent and have constraints on the terms that may appear in their domain (they must be "maximal solvables", that is, variables or uninterpreted function applications, which would be variables if explicit "variable abstraction" was done). This diff factors out the manipulation of concrete Map's into a Equality.Subst module, and uses these for the result of `solve`. Reviewed By: ngorogiannis Differential Revision: D19282637 fbshipit-source-id: 4fc825e59 --- sledge/src/symbheap/equality.ml | 253 +++++++++++++++++++------------- 1 file changed, 149 insertions(+), 104 deletions(-) diff --git a/sledge/src/symbheap/equality.ml b/sledge/src/symbheap/equality.ml index 37a50f6a0..718cb7ffc 100644 --- a/sledge/src/symbheap/equality.ml +++ b/sledge/src/symbheap/equality.ml @@ -7,17 +7,121 @@ (** Equality over uninterpreted functions and linear rational arithmetic *) -type 'a term_map = 'a Map.M(Term).t [@@deriving compare, equal, sexp] +(** Classification of Terms by Theory *) -let empty_map = Map.empty (module Term) +type kind = Interpreted | Simplified | Atomic | Uninterpreted +[@@deriving compare] + +let classify e = + match (e : Term.t) with + | Add _ | Mul _ -> Interpreted + | Ap2 ((Eq | Dq), _, _) -> Simplified + | Ap1 _ | Ap2 _ | Ap3 _ | ApN _ -> Uninterpreted + | RecN _ | Var _ | Integer _ | Float _ | Nondet _ | Label _ -> Atomic + +(** Solution Substitutions *) +module Subst : sig + type t [@@deriving compare, equal, sexp] + + val pp : t pp + val pp_sdiff : ?pre:string -> Format.formatter -> t -> t -> unit + val empty : t + val length : t -> int + val mem : t -> Term.t -> bool + val fold : t -> init:'a -> f:(key:Term.t -> data:Term.t -> 'a -> 'a) -> 'a + val iteri : t -> f:(key:Term.t -> data:Term.t -> unit) -> unit + val for_alli : t -> f:(key:Term.t -> data:Term.t -> bool) -> bool + val apply : t -> Term.t -> Term.t + val norm : t -> Term.t -> Term.t + val compose : t -> t -> t + val compose1 : key:Term.t -> data:Term.t -> t -> t + val extend : Term.t -> t -> t option + val map_entries : f:(Term.t -> Term.t) -> t -> t + val to_alist : t -> (Term.t * Term.t) list +end = struct + type t = Term.t Term.Map.t [@@deriving compare, equal, sexp] + + let pp fs s = + Format.fprintf fs "@[<1>[%a]@]" + (List.pp ",@ " (fun fs (k, v) -> + Format.fprintf fs "@[%a ↦ %a@]" Term.pp k Term.pp v )) + (Map.to_alist s) + + let pp_sdiff ?(pre = "") = + let pp_sdiff_elt pp_key pp_val pp_sdiff_val fs = function + | k, `Left v -> + Format.fprintf fs "-- [@[%a@ @<2>↦ %a@]]" pp_key k pp_val v + | k, `Right v -> + Format.fprintf fs "++ [@[%a@ @<2>↦ %a@]]" pp_key k pp_val v + | k, `Unequal vv -> + Format.fprintf fs "[@[%a@ @<2>↦ %a@]]" pp_key k pp_sdiff_val vv + in + let pp_sdiff_map pp_elt_diff equal fs x y = + let sd = + Sequence.to_list (Map.symmetric_diff ~data_equal:equal x y) + in + if not (List.is_empty sd) then + Format.fprintf fs "%s[@[%a@]];@ " pre + (List.pp ";@ " pp_elt_diff) + sd + in + let pp_sdiff_term fs (u, v) = + Format.fprintf fs "-- %a ++ %a" Term.pp u Term.pp v + in + pp_sdiff_map (pp_sdiff_elt Term.pp Term.pp pp_sdiff_term) Term.equal -type subst = Term.t term_map [@@deriving compare, equal, sexp] + let empty = Term.Map.empty + let length = Map.length + let mem = Map.mem + let fold = Map.fold + let iteri = Map.iteri + let for_alli = Map.for_alli + let to_alist = Map.to_alist ~key_order:`Increasing + + (** look up a term in a substitution *) + let apply s a = Map.find s a |> Option.value ~default:a + + (** apply a substitution to maximal non-interpreted subterms *) + let rec norm s a = + match classify a with + | Interpreted -> Term.map ~f:(norm s) a + | Simplified -> apply s (Term.map ~f:(norm s) a) + | Atomic | Uninterpreted -> apply s a + + (** compose two substitutions *) + let compose r s = + let r' = Map.map ~f:(norm s) r in + Map.merge_skewed r' s ~combine:(fun ~key v1 v2 -> + if Term.equal v1 v2 then v1 + else fail "domains intersect: %a" Term.pp key () ) -let pp_subst fs s = - Format.fprintf fs "@[<1>[%a]@]" - (List.pp ",@ " (fun fs (k, v) -> - Format.fprintf fs "@[%a ↦ %a@]" Term.pp k Term.pp v )) - (Map.to_alist s) + (** compose a substitution with a mapping *) + let compose1 ~key ~data s = + if Term.equal key data then s + else compose s (Map.set Term.Map.empty ~key ~data) + + (** add an identity entry if the term is not already present *) + let extend e s = + let exception Found in + match + Map.update s e ~f:(function + | Some _ -> Exn.raise_without_backtrace Found + | None -> e ) + with + | exception Found -> None + | s -> Some s + + (** map over a subst, applying [f] to both domain and range, requires that + [f] is injective and for any set of terms [E], [f\[E\]] is disjoint + from [E] *) + let map_entries ~f s = + Map.fold s ~init:s ~f:(fun ~key ~data s -> + let key' = f key in + let data' = f data in + if Term.equal key' key then + if Term.equal data' data then s else Map.set s ~key ~data:data' + else Map.remove s key |> Map.add_exn ~key:key' ~data:data' ) +end (** Theory Solver *) @@ -32,16 +136,6 @@ let rec is_constant e = Qset.for_all ~f:(fun arg _ -> is_constant arg) args | Label _ | Float _ | Integer _ -> true -type kind = Interpreted | Simplified | Atomic | Uninterpreted -[@@deriving compare] - -let classify e = - match (e : Term.t) with - | Add _ | Mul _ -> Interpreted - | Ap2 ((Eq | Dq), _, _) -> Simplified - | Ap1 _ | Ap2 _ | Ap3 _ | ApN _ -> Uninterpreted - | RecN _ | Var _ | Integer _ | Float _ | Nondet _ | Label _ -> Atomic - let solve e f = [%Trace.call fun {pf} -> pf "%a@ %a" Term.pp e Term.pp f] ; @@ -53,13 +147,13 @@ let solve e f = match (is_constant e, is_constant f) with (* orient equation to discretionarily prefer term that is constant or compares smaller as class representative *) - | true, false -> Some (Map.add_exn s ~key:f ~data:e) - | false, true -> Some (Map.add_exn s ~key:e ~data:f) + | true, false -> Some (Subst.compose1 ~key:f ~data:e s) + | false, true -> Some (Subst.compose1 ~key:e ~data:f s) | _ -> let key, data = if Term.compare e f > 0 then (e, f) else (f, e) in - Some (Map.add_exn s ~key ~data) ) + Some (Subst.compose1 ~key ~data s) ) in let concat_size args = Vector.fold_until args ~init:Term.zero @@ -73,7 +167,7 @@ let solve e f = | (Add _ | Mul _ | Integer _), _ | _, (Add _ | Mul _ | Integer _) -> ( let e_f = Term.sub e f in match Term.solve_zero_eq e_f with - | Some (key, data) -> Some (Map.add_exn s ~key ~data) + | Some (key, data) -> Some (Subst.compose1 ~key ~data s) | None -> solve_uninterp e_f Term.zero ) | ApN (Concat, ms), ApN (Concat, ns) -> ( match (concat_size ms, concat_size ns) with @@ -86,35 +180,32 @@ let solve e f = | _ -> solve_uninterp e f ) | _ -> solve_uninterp e f in - solve_ e f empty_map + solve_ e f Subst.empty |> [%Trace.retn fun {pf} -> - function Some s -> pf "%a" pp_subst s | None -> pf "false"] + function Some s -> pf "%a" Subst.pp s | None -> pf "false"] (** Equality Relations *) (** see also [invariant] *) type t = { sat: bool (** [false] only if constraints are inconsistent *) - ; rep: subst + ; rep: Subst.t (** functional set of oriented equations: map [a] to [a'], indicating that [a = a'] holds, and that [a'] is the 'rep(resentative)' of [a] *) } [@@deriving compare, equal, sexp] -(** apply a subst to a term *) -let apply s a = Map.find s a |> Option.value ~default:a - let classes r = let add key data cls = if Term.equal key data then cls else Map.add_multi cls ~key:data ~data:key in - Map.fold r.rep ~init:empty_map ~f:(fun ~key ~data cls -> + Subst.fold r.rep ~init:Term.Map.empty ~f:(fun ~key ~data cls -> match classify key with | Interpreted | Atomic -> add key data cls | Simplified | Uninterpreted -> - add (Term.map ~f:(apply r.rep) key) data cls ) + add (Term.map ~f:(Subst.apply r.rep) key) data cls ) (** Pretty-printing *) @@ -128,41 +219,20 @@ let pp fs {sat; rep} = let pp_term_v fs (k, v) = if not (Term.equal k v) then Term.pp fs v in Format.fprintf fs "@[{@[sat= %b;@ rep= %a@]}@]" sat (pp_alist Term.pp pp_term_v) - (Map.to_alist rep) + (Subst.to_alist rep) let pp_diff fs (r, s) = - let pp_sdiff_map pp_elt_diff equal nam fs x y = - let sd = Sequence.to_list (Map.symmetric_diff ~data_equal:equal x y) in - if not (List.is_empty sd) then - Format.fprintf fs "%s= [@[%a@]];@ " nam - (List.pp ";@ " pp_elt_diff) - sd - in - let pp_sdiff_elt pp_key pp_val pp_sdiff_val fs = function - | k, `Left v -> - Format.fprintf fs "-- [@[%a@ @<2>↦ %a@]]" pp_key k pp_val v - | k, `Right v -> - Format.fprintf fs "++ [@[%a@ @<2>↦ %a@]]" pp_key k pp_val v - | k, `Unequal vv -> - Format.fprintf fs "[@[%a@ @<2>↦ %a@]]" pp_key k pp_sdiff_val vv - in - let pp_sdiff_term_map = - let pp_sdiff_term fs (u, v) = - Format.fprintf fs "-- %a ++ %a" Term.pp u Term.pp v - in - pp_sdiff_map (pp_sdiff_elt Term.pp Term.pp pp_sdiff_term) Term.equal - in let pp_sat fs = if not (Bool.equal r.sat s.sat) then Format.fprintf fs "sat= @[-- %b@ ++ %b@];@ " r.sat s.sat in - let pp_rep fs = pp_sdiff_term_map "rep" fs r.rep s.rep in + let pp_rep fs = Subst.pp_sdiff ~pre:"rep= " fs r.rep s.rep in Format.fprintf fs "@[{@[%t%t@]}@]" pp_sat pp_rep (** Invariant *) (** test membership in carrier *) -let in_car r e = Map.mem r.rep e +let in_car r e = Subst.mem r.rep e let rec iter_max_solvables e ~f = match classify e with @@ -172,7 +242,7 @@ let rec iter_max_solvables e ~f = let invariant r = Invariant.invariant [%here] r [%sexp_of: t] @@ fun () -> - Map.iteri r.rep ~f:(fun ~key:a ~data:_ -> + Subst.iteri r.rep ~f:(fun ~key:a ~data:_ -> (* no interpreted terms in carrier *) assert (Poly.(classify a <> Interpreted)) ; (* carrier is closed under subterms *) @@ -184,26 +254,23 @@ let invariant r = (** Core operations *) -let true_ = {sat= true; rep= empty_map} |> check invariant - -(** apply a subst to maximal non-interpreted subterms *) -let rec norm s a = - match classify a with - | Interpreted -> Term.map ~f:(norm s) a - | Simplified -> apply s (Term.map ~f:(norm s) a) - | Atomic | Uninterpreted -> apply s a +let true_ = {sat= true; rep= Subst.empty} |> check invariant (** terms are congruent if equal after normalizing subterms *) let congruent r a b = - Term.equal (Term.map ~f:(norm r.rep) a) (Term.map ~f:(norm r.rep) b) + Term.equal + (Term.map ~f:(Subst.norm r.rep) a) + (Term.map ~f:(Subst.norm r.rep) b) (** [lookup r a] is [b'] if [a ~ b = b'] for some equation [b = b'] in rep *) let lookup r a = With_return.with_return @@ fun {return} -> (* congruent specialized to assume [a] canonized and [b] non-interpreted *) - let semi_congruent r a b = Term.equal a (Term.map ~f:(apply r.rep) b) in - Map.iteri r.rep ~f:(fun ~key ~data -> + let semi_congruent r a b = + Term.equal a (Term.map ~f:(Subst.apply r.rep) b) + in + Subst.iteri r.rep ~f:(fun ~key ~data -> if semi_congruent r a key then return data ) ; a @@ -213,35 +280,25 @@ let rec canon r a = match classify a with | Interpreted -> Term.map ~f:(canon r) a | Simplified | Uninterpreted -> lookup r (Term.map ~f:(canon r) a) - | Atomic -> apply r.rep a + | Atomic -> Subst.apply r.rep a (** add a term to the carrier *) let rec extend a r = match classify a with | Interpreted | Simplified -> Term.fold ~f:extend a ~init:r - | Uninterpreted -> - Map.find_or_add r.rep a - ~if_found:(fun _ -> r) - ~default:a - ~if_added:(fun rep -> Term.fold ~f:extend a ~init:{r with rep}) + | Uninterpreted -> ( + match Subst.extend a r.rep with + | Some rep -> Term.fold ~f:extend a ~init:{r with rep} + | None -> r ) | Atomic -> r let extend a r = extend a r |> check invariant -let compose r s = - let rep = Map.map ~f:(norm s) r.rep in - let rep = - Map.merge_skewed rep s ~combine:(fun ~key v1 v2 -> - if Term.equal v1 v2 then v1 - else fail "domains intersect: %a" Term.pp key () ) - in - {r with rep} - let merge a b r = [%Trace.call fun {pf} -> pf "%a@ %a@ %a" Term.pp a Term.pp b pp r] ; ( match solve a b with - | Some s -> compose r s + | Some s -> {r with rep= Subst.compose r.rep s} | None -> {r with sat= false} ) |> [%Trace.retn fun {pf} r' -> @@ -252,8 +309,8 @@ let merge a b r = let find_missing r = With_return.with_return @@ fun {return} -> - Map.iteri r.rep ~f:(fun ~key:a ~data:a' -> - Map.iteri r.rep ~f:(fun ~key:b ~data:b' -> + Subst.iteri r.rep ~f:(fun ~key:a ~data:a' -> + Subst.iteri r.rep ~f:(fun ~key:b ~data:b' -> if Term.compare a b < 0 && (not (Term.equal a' b')) @@ -295,13 +352,13 @@ let and_eq a b r = invariant r'] let is_true {sat; rep} = - sat && Map.for_alli rep ~f:(fun ~key:a ~data:a' -> Term.equal a a') + sat && Subst.for_alli rep ~f:(fun ~key:a ~data:a' -> Term.equal a a') let is_false {sat} = not sat let entails_eq r d e = Term.equal (canon r d) (canon r e) let entails r s = - Map.for_alli s.rep ~f:(fun ~key:e ~data:e' -> entails_eq r e e') + Subst.for_alli s.rep ~f:(fun ~key:e ~data:e' -> entails_eq r e e') let normalize = canon @@ -328,9 +385,9 @@ let and_ r s = else if not s.sat then s else let s, r = - if Map.length s.rep <= Map.length r.rep then (s, r) else (r, s) + if Subst.length s.rep <= Subst.length r.rep then (s, r) else (r, s) in - Map.fold s.rep ~init:r ~f:(fun ~key:e ~data:e' r -> and_eq e e' r) + Subst.fold s.rep ~init:r ~f:(fun ~key:e ~data:e' r -> and_eq e e' r) let or_ r s = [%Trace.call fun {pf} -> pf "@[ %a@ @<2>∨ %a@]" pp r pp s] @@ -355,30 +412,18 @@ let or_ r s = |> [%Trace.retn fun {pf} -> pf "%a" pp] -(* assumes that f is injective and for any set of terms E, f[E] is disjoint - from E *) -let map_terms ({sat= _; rep} as r) ~f = +let rename r sub = [%Trace.call fun {pf} -> pf "%a" pp r] ; - let map m = - Map.fold m ~init:m ~f:(fun ~key ~data m -> - let key' = f key in - let data' = f data in - if Term.equal key' key then - if Term.equal data' data then m else Map.set m ~key ~data:data' - else Map.remove m key |> Map.add_exn ~key:key' ~data:data' ) - in - let rep' = map rep in - (if rep' == rep then r else {r with rep= rep'}) + let rep = Subst.map_entries ~f:(Term.rename sub) r.rep in + (if rep == r.rep then r else {r with rep}) |> [%Trace.retn fun {pf} r' -> pf "%a" pp_diff (r, r') ; invariant r'] -let rename r sub = map_terms r ~f:(Term.rename sub) - let fold_terms r ~init ~f = - Map.fold r.rep ~f:(fun ~key ~data z -> f (f z data) key) ~init + Subst.fold r.rep ~f:(fun ~key ~data z -> f (f z data) key) ~init let fold_vars r ~init ~f = fold_terms r ~init ~f:(fun init -> Term.fold_vars ~f ~init)