[sledge] Factor solution substitutions into Equality.Subst

Summary:
Equality relies on the result of solving an equation to be a "solution
substitution". In constrast to unconstrained Map's, solution
substitutions are idempotent and have constraints on the terms that
may appear in their domain (they must be "maximal solvables", that is,
variables or uninterpreted function applications, which would be
variables if explicit "variable abstraction" was done).

This diff factors out the manipulation of concrete Map's into a
Equality.Subst module, and uses these for the result of `solve`.

Reviewed By: ngorogiannis

Differential Revision: D19282637

fbshipit-source-id: 4fc825e59
master
Josh Berdine 5 years ago committed by Facebook Github Bot
parent bfdb379fe3
commit 9b1ff9c012

@ -7,18 +7,122 @@
(** Equality over uninterpreted functions and linear rational arithmetic *) (** Equality over uninterpreted functions and linear rational arithmetic *)
type 'a term_map = 'a Map.M(Term).t [@@deriving compare, equal, sexp] (** Classification of Terms by Theory *)
let empty_map = Map.empty (module Term) type kind = Interpreted | Simplified | Atomic | Uninterpreted
[@@deriving compare]
type subst = Term.t term_map [@@deriving compare, equal, sexp] let classify e =
match (e : Term.t) with
| Add _ | Mul _ -> Interpreted
| Ap2 ((Eq | Dq), _, _) -> Simplified
| Ap1 _ | Ap2 _ | Ap3 _ | ApN _ -> Uninterpreted
| RecN _ | Var _ | Integer _ | Float _ | Nondet _ | Label _ -> Atomic
let pp_subst fs s = (** Solution Substitutions *)
module Subst : sig
type t [@@deriving compare, equal, sexp]
val pp : t pp
val pp_sdiff : ?pre:string -> Format.formatter -> t -> t -> unit
val empty : t
val length : t -> int
val mem : t -> Term.t -> bool
val fold : t -> init:'a -> f:(key:Term.t -> data:Term.t -> 'a -> 'a) -> 'a
val iteri : t -> f:(key:Term.t -> data:Term.t -> unit) -> unit
val for_alli : t -> f:(key:Term.t -> data:Term.t -> bool) -> bool
val apply : t -> Term.t -> Term.t
val norm : t -> Term.t -> Term.t
val compose : t -> t -> t
val compose1 : key:Term.t -> data:Term.t -> t -> t
val extend : Term.t -> t -> t option
val map_entries : f:(Term.t -> Term.t) -> t -> t
val to_alist : t -> (Term.t * Term.t) list
end = struct
type t = Term.t Term.Map.t [@@deriving compare, equal, sexp]
let pp fs s =
Format.fprintf fs "@[<1>[%a]@]" Format.fprintf fs "@[<1>[%a]@]"
(List.pp ",@ " (fun fs (k, v) -> (List.pp ",@ " (fun fs (k, v) ->
Format.fprintf fs "@[%a ↦ %a@]" Term.pp k Term.pp v )) Format.fprintf fs "@[%a ↦ %a@]" Term.pp k Term.pp v ))
(Map.to_alist s) (Map.to_alist s)
let pp_sdiff ?(pre = "") =
let pp_sdiff_elt pp_key pp_val pp_sdiff_val fs = function
| k, `Left v ->
Format.fprintf fs "-- [@[%a@ @<2>↦ %a@]]" pp_key k pp_val v
| k, `Right v ->
Format.fprintf fs "++ [@[%a@ @<2>↦ %a@]]" pp_key k pp_val v
| k, `Unequal vv ->
Format.fprintf fs "[@[%a@ @<2>↦ %a@]]" pp_key k pp_sdiff_val vv
in
let pp_sdiff_map pp_elt_diff equal fs x y =
let sd =
Sequence.to_list (Map.symmetric_diff ~data_equal:equal x y)
in
if not (List.is_empty sd) then
Format.fprintf fs "%s[@[<hv>%a@]];@ " pre
(List.pp ";@ " pp_elt_diff)
sd
in
let pp_sdiff_term fs (u, v) =
Format.fprintf fs "-- %a ++ %a" Term.pp u Term.pp v
in
pp_sdiff_map (pp_sdiff_elt Term.pp Term.pp pp_sdiff_term) Term.equal
let empty = Term.Map.empty
let length = Map.length
let mem = Map.mem
let fold = Map.fold
let iteri = Map.iteri
let for_alli = Map.for_alli
let to_alist = Map.to_alist ~key_order:`Increasing
(** look up a term in a substitution *)
let apply s a = Map.find s a |> Option.value ~default:a
(** apply a substitution to maximal non-interpreted subterms *)
let rec norm s a =
match classify a with
| Interpreted -> Term.map ~f:(norm s) a
| Simplified -> apply s (Term.map ~f:(norm s) a)
| Atomic | Uninterpreted -> apply s a
(** compose two substitutions *)
let compose r s =
let r' = Map.map ~f:(norm s) r in
Map.merge_skewed r' s ~combine:(fun ~key v1 v2 ->
if Term.equal v1 v2 then v1
else fail "domains intersect: %a" Term.pp key () )
(** compose a substitution with a mapping *)
let compose1 ~key ~data s =
if Term.equal key data then s
else compose s (Map.set Term.Map.empty ~key ~data)
(** add an identity entry if the term is not already present *)
let extend e s =
let exception Found in
match
Map.update s e ~f:(function
| Some _ -> Exn.raise_without_backtrace Found
| None -> e )
with
| exception Found -> None
| s -> Some s
(** map over a subst, applying [f] to both domain and range, requires that
[f] is injective and for any set of terms [E], [f\[E\]] is disjoint
from [E] *)
let map_entries ~f s =
Map.fold s ~init:s ~f:(fun ~key ~data s ->
let key' = f key in
let data' = f data in
if Term.equal key' key then
if Term.equal data' data then s else Map.set s ~key ~data:data'
else Map.remove s key |> Map.add_exn ~key:key' ~data:data' )
end
(** Theory Solver *) (** Theory Solver *)
let rec is_constant e = let rec is_constant e =
@ -32,16 +136,6 @@ let rec is_constant e =
Qset.for_all ~f:(fun arg _ -> is_constant arg) args Qset.for_all ~f:(fun arg _ -> is_constant arg) args
| Label _ | Float _ | Integer _ -> true | Label _ | Float _ | Integer _ -> true
type kind = Interpreted | Simplified | Atomic | Uninterpreted
[@@deriving compare]
let classify e =
match (e : Term.t) with
| Add _ | Mul _ -> Interpreted
| Ap2 ((Eq | Dq), _, _) -> Simplified
| Ap1 _ | Ap2 _ | Ap3 _ | ApN _ -> Uninterpreted
| RecN _ | Var _ | Integer _ | Float _ | Nondet _ | Label _ -> Atomic
let solve e f = let solve e f =
[%Trace.call fun {pf} -> pf "%a@ %a" Term.pp e Term.pp f] [%Trace.call fun {pf} -> pf "%a@ %a" Term.pp e Term.pp f]
; ;
@ -53,13 +147,13 @@ let solve e f =
match (is_constant e, is_constant f) with match (is_constant e, is_constant f) with
(* orient equation to discretionarily prefer term that is constant (* orient equation to discretionarily prefer term that is constant
or compares smaller as class representative *) or compares smaller as class representative *)
| true, false -> Some (Map.add_exn s ~key:f ~data:e) | true, false -> Some (Subst.compose1 ~key:f ~data:e s)
| false, true -> Some (Map.add_exn s ~key:e ~data:f) | false, true -> Some (Subst.compose1 ~key:e ~data:f s)
| _ -> | _ ->
let key, data = let key, data =
if Term.compare e f > 0 then (e, f) else (f, e) if Term.compare e f > 0 then (e, f) else (f, e)
in in
Some (Map.add_exn s ~key ~data) ) Some (Subst.compose1 ~key ~data s) )
in in
let concat_size args = let concat_size args =
Vector.fold_until args ~init:Term.zero Vector.fold_until args ~init:Term.zero
@ -73,7 +167,7 @@ let solve e f =
| (Add _ | Mul _ | Integer _), _ | _, (Add _ | Mul _ | Integer _) -> ( | (Add _ | Mul _ | Integer _), _ | _, (Add _ | Mul _ | Integer _) -> (
let e_f = Term.sub e f in let e_f = Term.sub e f in
match Term.solve_zero_eq e_f with match Term.solve_zero_eq e_f with
| Some (key, data) -> Some (Map.add_exn s ~key ~data) | Some (key, data) -> Some (Subst.compose1 ~key ~data s)
| None -> solve_uninterp e_f Term.zero ) | None -> solve_uninterp e_f Term.zero )
| ApN (Concat, ms), ApN (Concat, ns) -> ( | ApN (Concat, ms), ApN (Concat, ns) -> (
match (concat_size ms, concat_size ns) with match (concat_size ms, concat_size ns) with
@ -86,35 +180,32 @@ let solve e f =
| _ -> solve_uninterp e f ) | _ -> solve_uninterp e f )
| _ -> solve_uninterp e f | _ -> solve_uninterp e f
in in
solve_ e f empty_map solve_ e f Subst.empty
|> |>
[%Trace.retn fun {pf} -> [%Trace.retn fun {pf} ->
function Some s -> pf "%a" pp_subst s | None -> pf "false"] function Some s -> pf "%a" Subst.pp s | None -> pf "false"]
(** Equality Relations *) (** Equality Relations *)
(** see also [invariant] *) (** see also [invariant] *)
type t = type t =
{ sat: bool (** [false] only if constraints are inconsistent *) { sat: bool (** [false] only if constraints are inconsistent *)
; rep: subst ; rep: Subst.t
(** functional set of oriented equations: map [a] to [a'], (** functional set of oriented equations: map [a] to [a'],
indicating that [a = a'] holds, and that [a'] is the indicating that [a = a'] holds, and that [a'] is the
'rep(resentative)' of [a] *) } 'rep(resentative)' of [a] *) }
[@@deriving compare, equal, sexp] [@@deriving compare, equal, sexp]
(** apply a subst to a term *)
let apply s a = Map.find s a |> Option.value ~default:a
let classes r = let classes r =
let add key data cls = let add key data cls =
if Term.equal key data then cls if Term.equal key data then cls
else Map.add_multi cls ~key:data ~data:key else Map.add_multi cls ~key:data ~data:key
in in
Map.fold r.rep ~init:empty_map ~f:(fun ~key ~data cls -> Subst.fold r.rep ~init:Term.Map.empty ~f:(fun ~key ~data cls ->
match classify key with match classify key with
| Interpreted | Atomic -> add key data cls | Interpreted | Atomic -> add key data cls
| Simplified | Uninterpreted -> | Simplified | Uninterpreted ->
add (Term.map ~f:(apply r.rep) key) data cls ) add (Term.map ~f:(Subst.apply r.rep) key) data cls )
(** Pretty-printing *) (** Pretty-printing *)
@ -128,41 +219,20 @@ let pp fs {sat; rep} =
let pp_term_v fs (k, v) = if not (Term.equal k v) then Term.pp fs v in let pp_term_v fs (k, v) = if not (Term.equal k v) then Term.pp fs v in
Format.fprintf fs "@[{@[<hv>sat= %b;@ rep= %a@]}@]" sat Format.fprintf fs "@[{@[<hv>sat= %b;@ rep= %a@]}@]" sat
(pp_alist Term.pp pp_term_v) (pp_alist Term.pp pp_term_v)
(Map.to_alist rep) (Subst.to_alist rep)
let pp_diff fs (r, s) = let pp_diff fs (r, s) =
let pp_sdiff_map pp_elt_diff equal nam fs x y =
let sd = Sequence.to_list (Map.symmetric_diff ~data_equal:equal x y) in
if not (List.is_empty sd) then
Format.fprintf fs "%s= [@[<hv>%a@]];@ " nam
(List.pp ";@ " pp_elt_diff)
sd
in
let pp_sdiff_elt pp_key pp_val pp_sdiff_val fs = function
| k, `Left v ->
Format.fprintf fs "-- [@[%a@ @<2>↦ %a@]]" pp_key k pp_val v
| k, `Right v ->
Format.fprintf fs "++ [@[%a@ @<2>↦ %a@]]" pp_key k pp_val v
| k, `Unequal vv ->
Format.fprintf fs "[@[%a@ @<2>↦ %a@]]" pp_key k pp_sdiff_val vv
in
let pp_sdiff_term_map =
let pp_sdiff_term fs (u, v) =
Format.fprintf fs "-- %a ++ %a" Term.pp u Term.pp v
in
pp_sdiff_map (pp_sdiff_elt Term.pp Term.pp pp_sdiff_term) Term.equal
in
let pp_sat fs = let pp_sat fs =
if not (Bool.equal r.sat s.sat) then if not (Bool.equal r.sat s.sat) then
Format.fprintf fs "sat= @[-- %b@ ++ %b@];@ " r.sat s.sat Format.fprintf fs "sat= @[-- %b@ ++ %b@];@ " r.sat s.sat
in in
let pp_rep fs = pp_sdiff_term_map "rep" fs r.rep s.rep in let pp_rep fs = Subst.pp_sdiff ~pre:"rep= " fs r.rep s.rep in
Format.fprintf fs "@[{@[<hv>%t%t@]}@]" pp_sat pp_rep Format.fprintf fs "@[{@[<hv>%t%t@]}@]" pp_sat pp_rep
(** Invariant *) (** Invariant *)
(** test membership in carrier *) (** test membership in carrier *)
let in_car r e = Map.mem r.rep e let in_car r e = Subst.mem r.rep e
let rec iter_max_solvables e ~f = let rec iter_max_solvables e ~f =
match classify e with match classify e with
@ -172,7 +242,7 @@ let rec iter_max_solvables e ~f =
let invariant r = let invariant r =
Invariant.invariant [%here] r [%sexp_of: t] Invariant.invariant [%here] r [%sexp_of: t]
@@ fun () -> @@ fun () ->
Map.iteri r.rep ~f:(fun ~key:a ~data:_ -> Subst.iteri r.rep ~f:(fun ~key:a ~data:_ ->
(* no interpreted terms in carrier *) (* no interpreted terms in carrier *)
assert (Poly.(classify a <> Interpreted)) ; assert (Poly.(classify a <> Interpreted)) ;
(* carrier is closed under subterms *) (* carrier is closed under subterms *)
@ -184,26 +254,23 @@ let invariant r =
(** Core operations *) (** Core operations *)
let true_ = {sat= true; rep= empty_map} |> check invariant let true_ = {sat= true; rep= Subst.empty} |> check invariant
(** apply a subst to maximal non-interpreted subterms *)
let rec norm s a =
match classify a with
| Interpreted -> Term.map ~f:(norm s) a
| Simplified -> apply s (Term.map ~f:(norm s) a)
| Atomic | Uninterpreted -> apply s a
(** terms are congruent if equal after normalizing subterms *) (** terms are congruent if equal after normalizing subterms *)
let congruent r a b = let congruent r a b =
Term.equal (Term.map ~f:(norm r.rep) a) (Term.map ~f:(norm r.rep) b) Term.equal
(Term.map ~f:(Subst.norm r.rep) a)
(Term.map ~f:(Subst.norm r.rep) b)
(** [lookup r a] is [b'] if [a ~ b = b'] for some equation [b = b'] in rep *) (** [lookup r a] is [b'] if [a ~ b = b'] for some equation [b = b'] in rep *)
let lookup r a = let lookup r a =
With_return.with_return With_return.with_return
@@ fun {return} -> @@ fun {return} ->
(* congruent specialized to assume [a] canonized and [b] non-interpreted *) (* congruent specialized to assume [a] canonized and [b] non-interpreted *)
let semi_congruent r a b = Term.equal a (Term.map ~f:(apply r.rep) b) in let semi_congruent r a b =
Map.iteri r.rep ~f:(fun ~key ~data -> Term.equal a (Term.map ~f:(Subst.apply r.rep) b)
in
Subst.iteri r.rep ~f:(fun ~key ~data ->
if semi_congruent r a key then return data ) ; if semi_congruent r a key then return data ) ;
a a
@ -213,35 +280,25 @@ let rec canon r a =
match classify a with match classify a with
| Interpreted -> Term.map ~f:(canon r) a | Interpreted -> Term.map ~f:(canon r) a
| Simplified | Uninterpreted -> lookup r (Term.map ~f:(canon r) a) | Simplified | Uninterpreted -> lookup r (Term.map ~f:(canon r) a)
| Atomic -> apply r.rep a | Atomic -> Subst.apply r.rep a
(** add a term to the carrier *) (** add a term to the carrier *)
let rec extend a r = let rec extend a r =
match classify a with match classify a with
| Interpreted | Simplified -> Term.fold ~f:extend a ~init:r | Interpreted | Simplified -> Term.fold ~f:extend a ~init:r
| Uninterpreted -> | Uninterpreted -> (
Map.find_or_add r.rep a match Subst.extend a r.rep with
~if_found:(fun _ -> r) | Some rep -> Term.fold ~f:extend a ~init:{r with rep}
~default:a | None -> r )
~if_added:(fun rep -> Term.fold ~f:extend a ~init:{r with rep})
| Atomic -> r | Atomic -> r
let extend a r = extend a r |> check invariant let extend a r = extend a r |> check invariant
let compose r s =
let rep = Map.map ~f:(norm s) r.rep in
let rep =
Map.merge_skewed rep s ~combine:(fun ~key v1 v2 ->
if Term.equal v1 v2 then v1
else fail "domains intersect: %a" Term.pp key () )
in
{r with rep}
let merge a b r = let merge a b r =
[%Trace.call fun {pf} -> pf "%a@ %a@ %a" Term.pp a Term.pp b pp r] [%Trace.call fun {pf} -> pf "%a@ %a@ %a" Term.pp a Term.pp b pp r]
; ;
( match solve a b with ( match solve a b with
| Some s -> compose r s | Some s -> {r with rep= Subst.compose r.rep s}
| None -> {r with sat= false} ) | None -> {r with sat= false} )
|> |>
[%Trace.retn fun {pf} r' -> [%Trace.retn fun {pf} r' ->
@ -252,8 +309,8 @@ let merge a b r =
let find_missing r = let find_missing r =
With_return.with_return With_return.with_return
@@ fun {return} -> @@ fun {return} ->
Map.iteri r.rep ~f:(fun ~key:a ~data:a' -> Subst.iteri r.rep ~f:(fun ~key:a ~data:a' ->
Map.iteri r.rep ~f:(fun ~key:b ~data:b' -> Subst.iteri r.rep ~f:(fun ~key:b ~data:b' ->
if if
Term.compare a b < 0 Term.compare a b < 0
&& (not (Term.equal a' b')) && (not (Term.equal a' b'))
@ -295,13 +352,13 @@ let and_eq a b r =
invariant r'] invariant r']
let is_true {sat; rep} = let is_true {sat; rep} =
sat && Map.for_alli rep ~f:(fun ~key:a ~data:a' -> Term.equal a a') sat && Subst.for_alli rep ~f:(fun ~key:a ~data:a' -> Term.equal a a')
let is_false {sat} = not sat let is_false {sat} = not sat
let entails_eq r d e = Term.equal (canon r d) (canon r e) let entails_eq r d e = Term.equal (canon r d) (canon r e)
let entails r s = let entails r s =
Map.for_alli s.rep ~f:(fun ~key:e ~data:e' -> entails_eq r e e') Subst.for_alli s.rep ~f:(fun ~key:e ~data:e' -> entails_eq r e e')
let normalize = canon let normalize = canon
@ -328,9 +385,9 @@ let and_ r s =
else if not s.sat then s else if not s.sat then s
else else
let s, r = let s, r =
if Map.length s.rep <= Map.length r.rep then (s, r) else (r, s) if Subst.length s.rep <= Subst.length r.rep then (s, r) else (r, s)
in in
Map.fold s.rep ~init:r ~f:(fun ~key:e ~data:e' r -> and_eq e e' r) Subst.fold s.rep ~init:r ~f:(fun ~key:e ~data:e' r -> and_eq e e' r)
let or_ r s = let or_ r s =
[%Trace.call fun {pf} -> pf "@[<hv 1> %a@ @<2> %a@]" pp r pp s] [%Trace.call fun {pf} -> pf "@[<hv 1> %a@ @<2> %a@]" pp r pp s]
@ -355,30 +412,18 @@ let or_ r s =
|> |>
[%Trace.retn fun {pf} -> pf "%a" pp] [%Trace.retn fun {pf} -> pf "%a" pp]
(* assumes that f is injective and for any set of terms E, f[E] is disjoint let rename r sub =
from E *)
let map_terms ({sat= _; rep} as r) ~f =
[%Trace.call fun {pf} -> pf "%a" pp r] [%Trace.call fun {pf} -> pf "%a" pp r]
; ;
let map m = let rep = Subst.map_entries ~f:(Term.rename sub) r.rep in
Map.fold m ~init:m ~f:(fun ~key ~data m -> (if rep == r.rep then r else {r with rep})
let key' = f key in
let data' = f data in
if Term.equal key' key then
if Term.equal data' data then m else Map.set m ~key ~data:data'
else Map.remove m key |> Map.add_exn ~key:key' ~data:data' )
in
let rep' = map rep in
(if rep' == rep then r else {r with rep= rep'})
|> |>
[%Trace.retn fun {pf} r' -> [%Trace.retn fun {pf} r' ->
pf "%a" pp_diff (r, r') ; pf "%a" pp_diff (r, r') ;
invariant r'] invariant r']
let rename r sub = map_terms r ~f:(Term.rename sub)
let fold_terms r ~init ~f = let fold_terms r ~init ~f =
Map.fold r.rep ~f:(fun ~key ~data z -> f (f z data) key) ~init Subst.fold r.rep ~f:(fun ~key ~data z -> f (f z data) key) ~init
let fold_vars r ~init ~f = let fold_vars r ~init ~f =
fold_terms r ~init ~f:(fun init -> Term.fold_vars ~f ~init) fold_terms r ~init ~f:(fun init -> Term.fold_vars ~f ~init)

Loading…
Cancel
Save