[sledge] Improve Equality invariant checking and debugging support

Summary:
Ensure all entry points check the representation invariant before
returning, and strengthen it to check the constraints on preference
between representative terms, and to check that the relation is
closed.

Reviewed By: jvillard

Differential Revision: D20612566

fbshipit-source-id: b345397c4
master
Josh Berdine 5 years ago committed by Facebook GitHub Bot
parent 3c0924cf01
commit 7e4673cbeb

@ -308,6 +308,7 @@ let apply_summary q ({xs; foot; post} as fs) =
let%test_module _ = let%test_module _ =
( module struct ( module struct
let () = Trace.init ~margin:68 ()
let pp = Format.printf "@.%a@." Sh.pp let pp = Format.printf "@.%a@." Sh.pp
let wrt = Var.Set.empty let wrt = Var.Set.empty
let main_, wrt = Var.fresh "main" ~wrt let main_, wrt = Var.fresh "main" ~wrt

@ -177,33 +177,35 @@ end
(** Theory Solver *) (** Theory Solver *)
(** orient equations s.t. Var < Memory < Extract < Concat < others, then (** prefer representative terms that are minimal in the order s.t. Var <
using height of aggregate nesting, and then using Term.compare *) Memory < Extract < Concat < others, then using height of aggregate
let orient e f = nesting, and then using Term.compare *)
let compare e f = let prefer e f =
let rank e = let rank e =
match (e : Term.t) with match (e : Term.t) with
| Var _ -> 0 | Var _ -> 0
| Ap2 (Memory, _, _) -> 1 | Ap2 (Memory, _, _) -> 1
| Ap3 (Extract, _, _, _) -> 2 | Ap3 (Extract, _, _, _) -> 2
| ApN (Concat, _) -> 3 | ApN (Concat, _) -> 3
| _ -> 4 | _ -> 4
in
let rec height e =
match (e : Term.t) with
| Ap2 (Memory, _, x) -> 1 + height x
| Ap3 (Extract, x, _, _) -> 1 + height x
| ApN (Concat, xs) ->
1 + IArray.fold ~init:0 ~f:(fun h x -> max h (height x)) xs
| _ -> 0
in
let o = compare (rank e) (rank f) in
if o <> 0 then o
else
let o = compare (height e) (height f) in
if o <> 0 then o else Term.compare e f
in in
match Ordering.of_int (compare e f) with let rec height e =
match (e : Term.t) with
| Ap2 (Memory, _, x) -> 1 + height x
| Ap3 (Extract, x, _, _) -> 1 + height x
| ApN (Concat, xs) ->
1 + IArray.fold ~init:0 ~f:(fun h x -> max h (height x)) xs
| _ -> 0
in
let o = compare (rank e) (rank f) in
if o <> 0 then o
else
let o = compare (height e) (height f) in
if o <> 0 then o else Term.compare e f
(** orient equations based on representative preference *)
let orient e f =
match Ordering.of_int (prefer e f) with
| Less -> Some (e, f) | Less -> Some (e, f)
| Equal -> None | Equal -> None
| Greater -> Some (f, e) | Greater -> Some (f, e)
@ -402,23 +404,44 @@ let pp_diff_clss =
Term.Map.pp_diff ~data_equal:(List.equal Term.equal) Term.pp pp_cls Term.Map.pp_diff ~data_equal:(List.equal Term.equal) Term.pp pp_cls
pp_diff_cls pp_diff_cls
(** Invariant *) (** Basic queries *)
(** test membership in carrier *) (** test membership in carrier *)
let in_car r e = Subst.mem r.rep e let in_car r e = Subst.mem r.rep e
let invariant r = (** terms are congruent if equal after normalizing subterms *)
let congruent r a b =
Term.equal
(Term.map ~f:(Subst.norm r.rep) a)
(Term.map ~f:(Subst.norm r.rep) b)
(** Invariant *)
let pre_invariant r =
Invariant.invariant [%here] r [%sexp_of: t] Invariant.invariant [%here] r [%sexp_of: t]
@@ fun () -> @@ fun () ->
Subst.iteri r.rep ~f:(fun ~key:a ~data:_ -> Subst.iteri r.rep ~f:(fun ~key:trm ~data:_ ->
(* no interpreted terms in carrier *) (* no interpreted terms in carrier *)
assert (non_interpreted a) ; assert (non_interpreted trm || fail "non-interp %a" Term.pp trm ()) ;
(* carrier is closed under subterms *) (* carrier is closed under subterms *)
iter_max_solvables a ~f:(fun b -> iter_max_solvables trm ~f:(fun subtrm ->
assert ( assert (
in_car r b in_car r subtrm
|| fail "@[subterm %a of %a not in carrier of@ %a@]" Term.pp b || fail "@[subterm %a of %a not in carrier of@ %a@]" Term.pp
Term.pp a pp r () ) ) ) subtrm Term.pp trm pp r () ) ) )
let invariant r =
Invariant.invariant [%here] r [%sexp_of: t]
@@ fun () ->
pre_invariant r ;
assert (
(not r.sat)
|| Subst.for_alli r.rep ~f:(fun ~key:a ~data:a' ->
Subst.for_alli r.rep ~f:(fun ~key:b ~data:b' ->
Term.compare a b >= 0
|| congruent r a b ==> Term.equal a' b'
|| fail "not congruent %a@ %a@ in@ %a" Term.pp a Term.pp b pp
r () ) ) )
(** Core operations *) (** Core operations *)
@ -427,15 +450,9 @@ let true_ =
let false_ = {true_ with sat= false} let false_ = {true_ with sat= false}
(** terms are congruent if equal after normalizing subterms *)
let congruent r a b =
Term.equal
(Term.map ~f:(Subst.norm r.rep) a)
(Term.map ~f:(Subst.norm r.rep) b)
(** [lookup r a] is [b'] if [a ~ b = b'] for some equation [b = b'] in rep *) (** [lookup r a] is [b'] if [a ~ b = b'] for some equation [b = b'] in rep *)
let lookup r a = let lookup r a =
[%Trace.call fun {pf} -> pf "%a@ %a" Term.pp a pp r] [%Trace.call fun {pf} -> pf "%a" Term.pp a]
; ;
( with_return ( with_return
@@ fun {return} -> @@ fun {return} ->
@ -452,7 +469,7 @@ let lookup r a =
(** rewrite a term into canonical form using rep and, for non-interpreted (** rewrite a term into canonical form using rep and, for non-interpreted
terms, congruence composed with rep *) terms, congruence composed with rep *)
let rec canon r a = let rec canon r a =
[%Trace.call fun {pf} -> pf "%a@ %a" Term.pp a pp r] [%Trace.call fun {pf} -> pf "%a" Term.pp a]
; ;
( match classify a with ( match classify a with
| Atomic -> Subst.apply r.rep a | Atomic -> Subst.apply r.rep a
@ -478,7 +495,7 @@ let rec extend_ a r =
(** add a term to the carrier *) (** add a term to the carrier *)
let extend a r = let extend a r =
let rep = extend_ a r.rep in let rep = extend_ a r.rep in
if rep == r.rep then r else {r with rep} |> check invariant if rep == r.rep then r else {r with rep} |> check pre_invariant
let merge us a b r = let merge us a b r =
[%Trace.call fun {pf} -> pf "%a@ %a@ %a" Term.pp a Term.pp b pp r] [%Trace.call fun {pf} -> pf "%a@ %a@ %a" Term.pp a Term.pp b pp r]
@ -490,7 +507,7 @@ let merge us a b r =
|> |>
[%Trace.retn fun {pf} r' -> [%Trace.retn fun {pf} r' ->
pf "%a" pp_diff (r, r') ; pf "%a" pp_diff (r, r') ;
invariant r'] pre_invariant r']
(** find an unproved equation between congruent terms *) (** find an unproved equation between congruent terms *)
let find_missing r = let find_missing r =
@ -521,7 +538,7 @@ let close us r =
pf "%a" pp_diff (r, r') ; pf "%a" pp_diff (r, r') ;
invariant r'] invariant r']
let and_eq us a b r = let and_eq_ us a b r =
if not r.sat then r if not r.sat then r
else else
let a' = canon r a in let a' = canon r a in
@ -538,7 +555,13 @@ let is_true {sat; rep} =
sat && Subst.for_alli rep ~f:(fun ~key:a ~data:a' -> Term.equal a a') sat && Subst.for_alli rep ~f:(fun ~key:a ~data:a' -> Term.equal a a')
let is_false {sat} = not sat let is_false {sat} = not sat
let entails_eq r d e = Term.is_true (canon r (Term.eq d e))
let entails_eq r d e =
[%Trace.call fun {pf} -> pf "%a = %a@ %a" Term.pp d Term.pp e pp r]
;
Term.is_true (canon r (Term.eq d e))
|>
[%Trace.retn fun {pf} -> pf "%b"]
let entails r s = let entails r s =
Subst.for_alli s.rep ~f:(fun ~key:e ~data:e' -> entails_eq r e e') Subst.for_alli s.rep ~f:(fun ~key:e ~data:e' -> entails_eq r e e')
@ -584,21 +607,29 @@ let apply_subst us s r =
let rep' = Subst.subst s rep in let rep' = Subst.subst s rep in
List.fold cls ~init:r ~f:(fun r trm -> List.fold cls ~init:r ~f:(fun r trm ->
let trm' = Subst.subst s trm in let trm' = Subst.subst s trm in
and_eq us trm' rep' r ) ) and_eq_ us trm' rep' r ) )
|> extract_xs |> extract_xs
|> |>
[%Trace.retn fun {pf} (xs, r') -> pf "%a%a" Var.Set.pp_xs xs pp r'] [%Trace.retn fun {pf} (xs, r') ->
pf "%a%a" Var.Set.pp_xs xs pp_diff (r, r') ;
invariant r']
let and_ us r s = let and_ us r s =
[%Trace.call fun {pf} -> pf "@[<hv 1> %a@ @<2>∧ %a@]" pp r pp s]
;
( if not r.sat then r ( if not r.sat then r
else if not s.sat then s else if not s.sat then s
else else
let s, r = let s, r =
if Subst.length s.rep <= Subst.length r.rep then (s, r) else (r, s) if Subst.length s.rep <= Subst.length r.rep then (s, r) else (r, s)
in in
Subst.fold s.rep ~init:r ~f:(fun ~key:e ~data:e' r -> and_eq us e e' r) Subst.fold s.rep ~init:r ~f:(fun ~key:e ~data:e' r -> and_eq_ us e e' r)
) )
|> extract_xs |> extract_xs
|>
[%Trace.retn fun {pf} (_, r') ->
pf "%a" pp_diff (r, r') ;
invariant r']
let or_ us r s = let or_ us r s =
[%Trace.call fun {pf} -> pf "@[<hv 1> %a@ @<2> %a@]" pp r pp s] [%Trace.call fun {pf} -> pf "@[<hv 1> %a@ @<2> %a@]" pp r pp s]
@ -612,7 +643,7 @@ let or_ us r s =
~init:([rep], rs) ~init:([rep], rs)
~f:(fun (reps, rs) exp -> ~f:(fun (reps, rs) exp ->
match List.find ~f:(entails_eq r exp) reps with match List.find ~f:(entails_eq r exp) reps with
| Some rep -> (reps, and_eq us exp rep rs) | Some rep -> (reps, and_eq_ us exp rep rs)
| None -> (exp :: reps, rs) ) | None -> (exp :: reps, rs) )
|> snd ) |> snd )
in in
@ -622,7 +653,9 @@ let or_ us r s =
rs ) rs )
|> extract_xs |> extract_xs
|> |>
[%Trace.retn fun {pf} (_, r) -> pf "%a" pp r] [%Trace.retn fun {pf} (_, r') ->
pf "%a" pp_diff (r, r') ;
invariant r']
let orN us rs = let orN us rs =
match rs with match rs with
@ -630,21 +663,28 @@ let orN us rs =
| r :: rs -> List.fold ~f:(fun (us, s) r -> or_ us s r) ~init:(us, r) rs | r :: rs -> List.fold ~f:(fun (us, s) r -> or_ us s r) ~init:(us, r) rs
let rec and_term_ us e r = let rec and_term_ us e r =
let eq_false b r = and_eq us b Term.false_ r in let eq_false b r = and_eq_ us b Term.false_ r in
match (e : Term.t) with match (e : Term.t) with
| Integer {data} -> if Z.is_false data then false_ else r | Integer {data} -> if Z.is_false data then false_ else r
| Ap2 (And, a, b) -> and_term_ us a (and_term_ us b r) | Ap2 (And, a, b) -> and_term_ us a (and_term_ us b r)
| Ap2 (Eq, a, b) -> and_eq us a b r | Ap2 (Eq, a, b) -> and_eq_ us a b r
| Ap2 (Xor, Integer {data}, a) when Z.is_true data -> eq_false a r | Ap2 (Xor, Integer {data}, a) when Z.is_true data -> eq_false a r
| Ap2 (Xor, a, Integer {data}) when Z.is_true data -> eq_false a r | Ap2 (Xor, a, Integer {data}) when Z.is_true data -> eq_false a r
| _ -> r | _ -> r
let and_term us e r = and_term_ us e r |> extract_xs let and_term us e r =
[%Trace.call fun {pf} -> pf "%a@ %a" Term.pp e pp r]
;
and_term_ us e r |> extract_xs
|>
[%Trace.retn fun {pf} (_, r') ->
pf "%a" pp_diff (r, r') ;
invariant r']
let and_eq us a b r = let and_eq us a b r =
[%Trace.call fun {pf} -> pf "%a = %a@ %a" Term.pp a Term.pp b pp r] [%Trace.call fun {pf} -> pf "%a = %a@ %a" Term.pp a Term.pp b pp r]
; ;
and_eq us a b r |> extract_xs and_eq_ us a b r |> extract_xs
|> |>
[%Trace.retn fun {pf} (_, r') -> [%Trace.retn fun {pf} (_, r') ->
pf "%a" pp_diff (r, r') ; pf "%a" pp_diff (r, r') ;
@ -710,6 +750,8 @@ type 'a zom = Zero | One of 'a | Many
[fv kill us]; solve [p = q] for [kill]; extend subst mapping [kill] [fv kill us]; solve [p = q] for [kill]; extend subst mapping [kill]
to the solution *) to the solution *)
let solve_poly_eq us p' q' subst = let solve_poly_eq us p' q' subst =
[%Trace.call fun {pf} -> pf "%a = %a" Term.pp p' Term.pp q']
;
let diff = Term.sub p' q' in let diff = Term.sub p' q' in
let max_solvables_not_ito_us = let max_solvables_not_ito_us =
fold_max_solvables diff ~init:Zero ~f:(fun solvable_subterm -> function fold_max_solvables diff ~init:Zero ~f:(fun solvable_subterm -> function
@ -718,11 +760,14 @@ let solve_poly_eq us p' q' subst =
| One _ -> Many | One _ -> Many
| Zero -> One solvable_subterm ) | Zero -> One solvable_subterm )
in in
match max_solvables_not_ito_us with ( match max_solvables_not_ito_us with
| One kill -> | One kill ->
let+ kill, keep = Term.solve_zero_eq diff ~for_:kill in let+ kill, keep = Term.solve_zero_eq diff ~for_:kill in
Subst.compose1 ~key:kill ~data:keep subst Subst.compose1 ~key:kill ~data:keep subst
| Many | Zero -> None | Many | Zero -> None )
|>
[%Trace.retn fun {pf} subst' ->
pf "@[%a@]" Subst.pp_diff (subst, Option.value subst' ~default:subst)]
let solve_memory_eq us e' f' subst = let solve_memory_eq us e' f' subst =
[%Trace.call fun {pf} -> pf "%a = %a" Term.pp e' Term.pp f'] [%Trace.call fun {pf} -> pf "%a = %a" Term.pp e' Term.pp f']
@ -1001,7 +1046,9 @@ let pp_vss fs vss =
[fv u ¹ v] if possible and otherwise [fv u ¹ v] if possible and otherwise
[fv u v] *) [fv u v] *)
let solve_for_vars vss r = let solve_for_vars vss r =
[%Trace.call fun {pf} -> pf "%a@ @[%a@]" pp_vss vss pp_classes r] [%Trace.call fun {pf} ->
pf "%a@ @[%a@]@ @[%a@]" pp_vss vss pp_classes r pp r ;
invariant r]
; ;
let us, vss = let us, vss =
match vss with us :: vss -> (us, vss) | [] -> (Var.Set.empty, vss) match vss with us :: vss -> (us, vss) | [] -> (Var.Set.empty, vss)
@ -1014,8 +1061,8 @@ let solve_for_vars vss r =
Subst.iteri subst ~f:(fun ~key ~data -> Subst.iteri subst ~f:(fun ~key ~data ->
assert ( assert (
entails_eq r key data entails_eq r key data
|| fail "@[%a = %a not entailed by@ %a@]" Term.pp key Term.pp data || fail "@[%a@ = %a@ not entailed by@ @[%a@]@]" Term.pp key
pp_classes r () ) ; Term.pp data pp_classes r () ) ;
assert ( assert (
List.fold_until vss ~init:us List.fold_until vss ~init:us
~f:(fun us xs -> ~f:(fun us xs ->

@ -52,6 +52,9 @@ let%test_module _ =
let and_eq a b r = and_eq wrt a b r |> snd let and_eq a b r = and_eq wrt a b r |> snd
let and_ r s = and_ wrt r s |> snd let and_ r s = and_ wrt r s |> snd
let or_ r s = or_ wrt r s |> snd let or_ r s = or_ wrt r s |> snd
(* tests *)
let f1 = of_eqs [(!0, !1)] let f1 = of_eqs [(!0, !1)]
let%test _ = is_false f1 let%test _ = is_false f1

@ -73,7 +73,7 @@ end) : S with type key = Key.t = struct
let pp pp_k pp_v fs m = let pp pp_k pp_v fs m =
Format.fprintf fs "@[<1>[%a]@]" Format.fprintf fs "@[<1>[%a]@]"
(List.pp ",@ " (fun fs (k, v) -> (List.pp ",@ " (fun fs (k, v) ->
Format.fprintf fs "@[%a @<2>↦ %a@]" pp_k k pp_v v )) Format.fprintf fs "@[%a@ @<2>↦ %a@]" pp_k k pp_v v ))
(to_alist m) (to_alist m)
let pp_diff ~data_equal pp_key pp_val pp_diff_val fs (x, y) = let pp_diff ~data_equal pp_key pp_val pp_diff_val fs (x, y) =

@ -5,8 +5,6 @@
* LICENSE file in the root directory of this source tree. * LICENSE file in the root directory of this source tree.
*) *)
(* [@@@warning "-32"] *)
let%test_module _ = let%test_module _ =
( module struct ( module struct
open Sh open Sh
@ -14,7 +12,9 @@ let%test_module _ =
let () = Trace.init ~margin:68 () let () = Trace.init ~margin:68 ()
(* let () = (* let () =
* Trace.init ~margin:160 ~config:(Result.ok_exn (Trace.parse "+Sh")) () *) * Trace.init ~margin:160 ~config:(Result.ok_exn (Trace.parse "+Sh")) ()
*
* [@@@warning "-32"] *)
let pp = Format.printf "@\n%a@." pp let pp = Format.printf "@\n%a@." pp
let pp_raw = Format.printf "@\n%a@." pp_raw let pp_raw = Format.printf "@\n%a@." pp_raw
@ -41,6 +41,9 @@ let%test_module _ =
let x = Term.var x_ let x = Term.var x_
let y = Term.var y_ let y = Term.var y_
let of_eqs l =
List.fold ~init:emp ~f:(fun q (a, b) -> and_ (Term.eq a b) q) l
let%expect_test _ = let%expect_test _ =
let p = exists ~$[x_] (extend_us ~$[x_] emp) in let p = exists ~$[x_] (extend_us ~$[x_] emp) in
let q = pure (x = !0) in let q = pure (x = !0) in
@ -125,9 +128,6 @@ let%test_module _ =
( ( 1 = %y_7 emp) ( emp) ( emp) ) |}] ( ( 1 = %y_7 emp) ( emp) ( emp) ) |}]
let of_eqs l =
List.fold ~init:emp ~f:(fun q (a, b) -> and_ (Term.eq a b) q) l
let%expect_test _ = let%expect_test _ =
let q = exists ~$[x_] (of_eqs [(f x, x); (f y, y - !1)]) in let q = exists ~$[x_] (of_eqs [(f x, x); (f y, y - !1)]) in
pp q ; pp q ;

@ -5,8 +5,6 @@
* LICENSE file in the root directory of this source tree. * LICENSE file in the root directory of this source tree.
*) *)
(* [@@@warning "-32"] *)
let%test_module _ = let%test_module _ =
( module struct ( module struct
let () = let () =
@ -18,7 +16,9 @@ let%test_module _ =
* Trace.init ~margin:160 * Trace.init ~margin:160
* ~config: * ~config:
* (Result.ok_exn (Trace.parse "+Solver.infer_frame+Solver.excise")) * (Result.ok_exn (Trace.parse "+Solver.infer_frame+Solver.excise"))
* () *) * ()
*
* [@@@warning "-32"] *)
let infer_frame p xs q = let infer_frame p xs q =
Solver.infer_frame p (Var.Set.of_list xs) q Solver.infer_frame p (Var.Set.of_list xs) q

@ -213,6 +213,6 @@ let fail fmt =
let margin = Format.pp_get_margin fs () in let margin = Format.pp_get_margin fs () in
raisef ~margin raisef ~margin
(fun msg -> (fun msg ->
Format.fprintf fs "@\n@[<2>| %s@]@." msg ; Format.fprintf fs "@\n%s@." msg ;
Failure msg ) Failure msg )
fmt fmt

Loading…
Cancel
Save