[sledge] Factor solution substitutions into Equality.Subst

Summary:
Equality relies on the result of solving an equation to be a "solution
substitution". In constrast to unconstrained Map's, solution
substitutions are idempotent and have constraints on the terms that
may appear in their domain (they must be "maximal solvables", that is,
variables or uninterpreted function applications, which would be
variables if explicit "variable abstraction" was done).

This diff factors out the manipulation of concrete Map's into a
Equality.Subst module, and uses these for the result of `solve`.

Reviewed By: ngorogiannis

Differential Revision: D19282637

fbshipit-source-id: 4fc825e59
master
Josh Berdine 5 years ago committed by Facebook Github Bot
parent bfdb379fe3
commit 9b1ff9c012

@ -7,18 +7,122 @@
(** Equality over uninterpreted functions and linear rational arithmetic *)
type 'a term_map = 'a Map.M(Term).t [@@deriving compare, equal, sexp]
(** Classification of Terms by Theory *)
let empty_map = Map.empty (module Term)
type kind = Interpreted | Simplified | Atomic | Uninterpreted
[@@deriving compare]
type subst = Term.t term_map [@@deriving compare, equal, sexp]
let classify e =
match (e : Term.t) with
| Add _ | Mul _ -> Interpreted
| Ap2 ((Eq | Dq), _, _) -> Simplified
| Ap1 _ | Ap2 _ | Ap3 _ | ApN _ -> Uninterpreted
| RecN _ | Var _ | Integer _ | Float _ | Nondet _ | Label _ -> Atomic
let pp_subst fs s =
(** Solution Substitutions *)
module Subst : sig
type t [@@deriving compare, equal, sexp]
val pp : t pp
val pp_sdiff : ?pre:string -> Format.formatter -> t -> t -> unit
val empty : t
val length : t -> int
val mem : t -> Term.t -> bool
val fold : t -> init:'a -> f:(key:Term.t -> data:Term.t -> 'a -> 'a) -> 'a
val iteri : t -> f:(key:Term.t -> data:Term.t -> unit) -> unit
val for_alli : t -> f:(key:Term.t -> data:Term.t -> bool) -> bool
val apply : t -> Term.t -> Term.t
val norm : t -> Term.t -> Term.t
val compose : t -> t -> t
val compose1 : key:Term.t -> data:Term.t -> t -> t
val extend : Term.t -> t -> t option
val map_entries : f:(Term.t -> Term.t) -> t -> t
val to_alist : t -> (Term.t * Term.t) list
end = struct
type t = Term.t Term.Map.t [@@deriving compare, equal, sexp]
let pp fs s =
Format.fprintf fs "@[<1>[%a]@]"
(List.pp ",@ " (fun fs (k, v) ->
Format.fprintf fs "@[%a ↦ %a@]" Term.pp k Term.pp v ))
(Map.to_alist s)
let pp_sdiff ?(pre = "") =
let pp_sdiff_elt pp_key pp_val pp_sdiff_val fs = function
| k, `Left v ->
Format.fprintf fs "-- [@[%a@ @<2>↦ %a@]]" pp_key k pp_val v
| k, `Right v ->
Format.fprintf fs "++ [@[%a@ @<2>↦ %a@]]" pp_key k pp_val v
| k, `Unequal vv ->
Format.fprintf fs "[@[%a@ @<2>↦ %a@]]" pp_key k pp_sdiff_val vv
in
let pp_sdiff_map pp_elt_diff equal fs x y =
let sd =
Sequence.to_list (Map.symmetric_diff ~data_equal:equal x y)
in
if not (List.is_empty sd) then
Format.fprintf fs "%s[@[<hv>%a@]];@ " pre
(List.pp ";@ " pp_elt_diff)
sd
in
let pp_sdiff_term fs (u, v) =
Format.fprintf fs "-- %a ++ %a" Term.pp u Term.pp v
in
pp_sdiff_map (pp_sdiff_elt Term.pp Term.pp pp_sdiff_term) Term.equal
let empty = Term.Map.empty
let length = Map.length
let mem = Map.mem
let fold = Map.fold
let iteri = Map.iteri
let for_alli = Map.for_alli
let to_alist = Map.to_alist ~key_order:`Increasing
(** look up a term in a substitution *)
let apply s a = Map.find s a |> Option.value ~default:a
(** apply a substitution to maximal non-interpreted subterms *)
let rec norm s a =
match classify a with
| Interpreted -> Term.map ~f:(norm s) a
| Simplified -> apply s (Term.map ~f:(norm s) a)
| Atomic | Uninterpreted -> apply s a
(** compose two substitutions *)
let compose r s =
let r' = Map.map ~f:(norm s) r in
Map.merge_skewed r' s ~combine:(fun ~key v1 v2 ->
if Term.equal v1 v2 then v1
else fail "domains intersect: %a" Term.pp key () )
(** compose a substitution with a mapping *)
let compose1 ~key ~data s =
if Term.equal key data then s
else compose s (Map.set Term.Map.empty ~key ~data)
(** add an identity entry if the term is not already present *)
let extend e s =
let exception Found in
match
Map.update s e ~f:(function
| Some _ -> Exn.raise_without_backtrace Found
| None -> e )
with
| exception Found -> None
| s -> Some s
(** map over a subst, applying [f] to both domain and range, requires that
[f] is injective and for any set of terms [E], [f\[E\]] is disjoint
from [E] *)
let map_entries ~f s =
Map.fold s ~init:s ~f:(fun ~key ~data s ->
let key' = f key in
let data' = f data in
if Term.equal key' key then
if Term.equal data' data then s else Map.set s ~key ~data:data'
else Map.remove s key |> Map.add_exn ~key:key' ~data:data' )
end
(** Theory Solver *)
let rec is_constant e =
@ -32,16 +136,6 @@ let rec is_constant e =
Qset.for_all ~f:(fun arg _ -> is_constant arg) args
| Label _ | Float _ | Integer _ -> true
type kind = Interpreted | Simplified | Atomic | Uninterpreted
[@@deriving compare]
let classify e =
match (e : Term.t) with
| Add _ | Mul _ -> Interpreted
| Ap2 ((Eq | Dq), _, _) -> Simplified
| Ap1 _ | Ap2 _ | Ap3 _ | ApN _ -> Uninterpreted
| RecN _ | Var _ | Integer _ | Float _ | Nondet _ | Label _ -> Atomic
let solve e f =
[%Trace.call fun {pf} -> pf "%a@ %a" Term.pp e Term.pp f]
;
@ -53,13 +147,13 @@ let solve e f =
match (is_constant e, is_constant f) with
(* orient equation to discretionarily prefer term that is constant
or compares smaller as class representative *)
| true, false -> Some (Map.add_exn s ~key:f ~data:e)
| false, true -> Some (Map.add_exn s ~key:e ~data:f)
| true, false -> Some (Subst.compose1 ~key:f ~data:e s)
| false, true -> Some (Subst.compose1 ~key:e ~data:f s)
| _ ->
let key, data =
if Term.compare e f > 0 then (e, f) else (f, e)
in
Some (Map.add_exn s ~key ~data) )
Some (Subst.compose1 ~key ~data s) )
in
let concat_size args =
Vector.fold_until args ~init:Term.zero
@ -73,7 +167,7 @@ let solve e f =
| (Add _ | Mul _ | Integer _), _ | _, (Add _ | Mul _ | Integer _) -> (
let e_f = Term.sub e f in
match Term.solve_zero_eq e_f with
| Some (key, data) -> Some (Map.add_exn s ~key ~data)
| Some (key, data) -> Some (Subst.compose1 ~key ~data s)
| None -> solve_uninterp e_f Term.zero )
| ApN (Concat, ms), ApN (Concat, ns) -> (
match (concat_size ms, concat_size ns) with
@ -86,35 +180,32 @@ let solve e f =
| _ -> solve_uninterp e f )
| _ -> solve_uninterp e f
in
solve_ e f empty_map
solve_ e f Subst.empty
|>
[%Trace.retn fun {pf} ->
function Some s -> pf "%a" pp_subst s | None -> pf "false"]
function Some s -> pf "%a" Subst.pp s | None -> pf "false"]
(** Equality Relations *)
(** see also [invariant] *)
type t =
{ sat: bool (** [false] only if constraints are inconsistent *)
; rep: subst
; rep: Subst.t
(** functional set of oriented equations: map [a] to [a'],
indicating that [a = a'] holds, and that [a'] is the
'rep(resentative)' of [a] *) }
[@@deriving compare, equal, sexp]
(** apply a subst to a term *)
let apply s a = Map.find s a |> Option.value ~default:a
let classes r =
let add key data cls =
if Term.equal key data then cls
else Map.add_multi cls ~key:data ~data:key
in
Map.fold r.rep ~init:empty_map ~f:(fun ~key ~data cls ->
Subst.fold r.rep ~init:Term.Map.empty ~f:(fun ~key ~data cls ->
match classify key with
| Interpreted | Atomic -> add key data cls
| Simplified | Uninterpreted ->
add (Term.map ~f:(apply r.rep) key) data cls )
add (Term.map ~f:(Subst.apply r.rep) key) data cls )
(** Pretty-printing *)
@ -128,41 +219,20 @@ let pp fs {sat; rep} =
let pp_term_v fs (k, v) = if not (Term.equal k v) then Term.pp fs v in
Format.fprintf fs "@[{@[<hv>sat= %b;@ rep= %a@]}@]" sat
(pp_alist Term.pp pp_term_v)
(Map.to_alist rep)
(Subst.to_alist rep)
let pp_diff fs (r, s) =
let pp_sdiff_map pp_elt_diff equal nam fs x y =
let sd = Sequence.to_list (Map.symmetric_diff ~data_equal:equal x y) in
if not (List.is_empty sd) then
Format.fprintf fs "%s= [@[<hv>%a@]];@ " nam
(List.pp ";@ " pp_elt_diff)
sd
in
let pp_sdiff_elt pp_key pp_val pp_sdiff_val fs = function
| k, `Left v ->
Format.fprintf fs "-- [@[%a@ @<2>↦ %a@]]" pp_key k pp_val v
| k, `Right v ->
Format.fprintf fs "++ [@[%a@ @<2>↦ %a@]]" pp_key k pp_val v
| k, `Unequal vv ->
Format.fprintf fs "[@[%a@ @<2>↦ %a@]]" pp_key k pp_sdiff_val vv
in
let pp_sdiff_term_map =
let pp_sdiff_term fs (u, v) =
Format.fprintf fs "-- %a ++ %a" Term.pp u Term.pp v
in
pp_sdiff_map (pp_sdiff_elt Term.pp Term.pp pp_sdiff_term) Term.equal
in
let pp_sat fs =
if not (Bool.equal r.sat s.sat) then
Format.fprintf fs "sat= @[-- %b@ ++ %b@];@ " r.sat s.sat
in
let pp_rep fs = pp_sdiff_term_map "rep" fs r.rep s.rep in
let pp_rep fs = Subst.pp_sdiff ~pre:"rep= " fs r.rep s.rep in
Format.fprintf fs "@[{@[<hv>%t%t@]}@]" pp_sat pp_rep
(** Invariant *)
(** test membership in carrier *)
let in_car r e = Map.mem r.rep e
let in_car r e = Subst.mem r.rep e
let rec iter_max_solvables e ~f =
match classify e with
@ -172,7 +242,7 @@ let rec iter_max_solvables e ~f =
let invariant r =
Invariant.invariant [%here] r [%sexp_of: t]
@@ fun () ->
Map.iteri r.rep ~f:(fun ~key:a ~data:_ ->
Subst.iteri r.rep ~f:(fun ~key:a ~data:_ ->
(* no interpreted terms in carrier *)
assert (Poly.(classify a <> Interpreted)) ;
(* carrier is closed under subterms *)
@ -184,26 +254,23 @@ let invariant r =
(** Core operations *)
let true_ = {sat= true; rep= empty_map} |> check invariant
(** apply a subst to maximal non-interpreted subterms *)
let rec norm s a =
match classify a with
| Interpreted -> Term.map ~f:(norm s) a
| Simplified -> apply s (Term.map ~f:(norm s) a)
| Atomic | Uninterpreted -> apply s a
let true_ = {sat= true; rep= Subst.empty} |> check invariant
(** terms are congruent if equal after normalizing subterms *)
let congruent r a b =
Term.equal (Term.map ~f:(norm r.rep) a) (Term.map ~f:(norm r.rep) b)
Term.equal
(Term.map ~f:(Subst.norm r.rep) a)
(Term.map ~f:(Subst.norm r.rep) b)
(** [lookup r a] is [b'] if [a ~ b = b'] for some equation [b = b'] in rep *)
let lookup r a =
With_return.with_return
@@ fun {return} ->
(* congruent specialized to assume [a] canonized and [b] non-interpreted *)
let semi_congruent r a b = Term.equal a (Term.map ~f:(apply r.rep) b) in
Map.iteri r.rep ~f:(fun ~key ~data ->
let semi_congruent r a b =
Term.equal a (Term.map ~f:(Subst.apply r.rep) b)
in
Subst.iteri r.rep ~f:(fun ~key ~data ->
if semi_congruent r a key then return data ) ;
a
@ -213,35 +280,25 @@ let rec canon r a =
match classify a with
| Interpreted -> Term.map ~f:(canon r) a
| Simplified | Uninterpreted -> lookup r (Term.map ~f:(canon r) a)
| Atomic -> apply r.rep a
| Atomic -> Subst.apply r.rep a
(** add a term to the carrier *)
let rec extend a r =
match classify a with
| Interpreted | Simplified -> Term.fold ~f:extend a ~init:r
| Uninterpreted ->
Map.find_or_add r.rep a
~if_found:(fun _ -> r)
~default:a
~if_added:(fun rep -> Term.fold ~f:extend a ~init:{r with rep})
| Uninterpreted -> (
match Subst.extend a r.rep with
| Some rep -> Term.fold ~f:extend a ~init:{r with rep}
| None -> r )
| Atomic -> r
let extend a r = extend a r |> check invariant
let compose r s =
let rep = Map.map ~f:(norm s) r.rep in
let rep =
Map.merge_skewed rep s ~combine:(fun ~key v1 v2 ->
if Term.equal v1 v2 then v1
else fail "domains intersect: %a" Term.pp key () )
in
{r with rep}
let merge a b r =
[%Trace.call fun {pf} -> pf "%a@ %a@ %a" Term.pp a Term.pp b pp r]
;
( match solve a b with
| Some s -> compose r s
| Some s -> {r with rep= Subst.compose r.rep s}
| None -> {r with sat= false} )
|>
[%Trace.retn fun {pf} r' ->
@ -252,8 +309,8 @@ let merge a b r =
let find_missing r =
With_return.with_return
@@ fun {return} ->
Map.iteri r.rep ~f:(fun ~key:a ~data:a' ->
Map.iteri r.rep ~f:(fun ~key:b ~data:b' ->
Subst.iteri r.rep ~f:(fun ~key:a ~data:a' ->
Subst.iteri r.rep ~f:(fun ~key:b ~data:b' ->
if
Term.compare a b < 0
&& (not (Term.equal a' b'))
@ -295,13 +352,13 @@ let and_eq a b r =
invariant r']
let is_true {sat; rep} =
sat && Map.for_alli rep ~f:(fun ~key:a ~data:a' -> Term.equal a a')
sat && Subst.for_alli rep ~f:(fun ~key:a ~data:a' -> Term.equal a a')
let is_false {sat} = not sat
let entails_eq r d e = Term.equal (canon r d) (canon r e)
let entails r s =
Map.for_alli s.rep ~f:(fun ~key:e ~data:e' -> entails_eq r e e')
Subst.for_alli s.rep ~f:(fun ~key:e ~data:e' -> entails_eq r e e')
let normalize = canon
@ -328,9 +385,9 @@ let and_ r s =
else if not s.sat then s
else
let s, r =
if Map.length s.rep <= Map.length r.rep then (s, r) else (r, s)
if Subst.length s.rep <= Subst.length r.rep then (s, r) else (r, s)
in
Map.fold s.rep ~init:r ~f:(fun ~key:e ~data:e' r -> and_eq e e' r)
Subst.fold s.rep ~init:r ~f:(fun ~key:e ~data:e' r -> and_eq e e' r)
let or_ r s =
[%Trace.call fun {pf} -> pf "@[<hv 1> %a@ @<2> %a@]" pp r pp s]
@ -355,30 +412,18 @@ let or_ r s =
|>
[%Trace.retn fun {pf} -> pf "%a" pp]
(* assumes that f is injective and for any set of terms E, f[E] is disjoint
from E *)
let map_terms ({sat= _; rep} as r) ~f =
let rename r sub =
[%Trace.call fun {pf} -> pf "%a" pp r]
;
let map m =
Map.fold m ~init:m ~f:(fun ~key ~data m ->
let key' = f key in
let data' = f data in
if Term.equal key' key then
if Term.equal data' data then m else Map.set m ~key ~data:data'
else Map.remove m key |> Map.add_exn ~key:key' ~data:data' )
in
let rep' = map rep in
(if rep' == rep then r else {r with rep= rep'})
let rep = Subst.map_entries ~f:(Term.rename sub) r.rep in
(if rep == r.rep then r else {r with rep})
|>
[%Trace.retn fun {pf} r' ->
pf "%a" pp_diff (r, r') ;
invariant r']
let rename r sub = map_terms r ~f:(Term.rename sub)
let fold_terms r ~init ~f =
Map.fold r.rep ~f:(fun ~key ~data z -> f (f z data) key) ~init
Subst.fold r.rep ~f:(fun ~key ~data z -> f (f z data) key) ~init
let fold_vars r ~init ~f =
fold_terms r ~init ~f:(fun init -> Term.fold_vars ~f ~init)

Loading…
Cancel
Save