[sledge] Improve Equality invariant checking and debugging support

Summary:
Ensure all entry points check the representation invariant before
returning, and strengthen it to check the constraints on preference
between representative terms, and to check that the relation is
closed.

Reviewed By: jvillard

Differential Revision: D20612566

fbshipit-source-id: b345397c4
master
Josh Berdine 5 years ago committed by Facebook GitHub Bot
parent 3c0924cf01
commit 7e4673cbeb

@ -308,6 +308,7 @@ let apply_summary q ({xs; foot; post} as fs) =
let%test_module _ =
( module struct
let () = Trace.init ~margin:68 ()
let pp = Format.printf "@.%a@." Sh.pp
let wrt = Var.Set.empty
let main_, wrt = Var.fresh "main" ~wrt

@ -177,10 +177,10 @@ end
(** Theory Solver *)
(** orient equations s.t. Var < Memory < Extract < Concat < others, then
using height of aggregate nesting, and then using Term.compare *)
let orient e f =
let compare e f =
(** prefer representative terms that are minimal in the order s.t. Var <
Memory < Extract < Concat < others, then using height of aggregate
nesting, and then using Term.compare *)
let prefer e f =
let rank e =
match (e : Term.t) with
| Var _ -> 0
@ -202,8 +202,10 @@ let orient e f =
else
let o = compare (height e) (height f) in
if o <> 0 then o else Term.compare e f
in
match Ordering.of_int (compare e f) with
(** orient equations based on representative preference *)
let orient e f =
match Ordering.of_int (prefer e f) with
| Less -> Some (e, f)
| Equal -> None
| Greater -> Some (f, e)
@ -402,23 +404,44 @@ let pp_diff_clss =
Term.Map.pp_diff ~data_equal:(List.equal Term.equal) Term.pp pp_cls
pp_diff_cls
(** Invariant *)
(** Basic queries *)
(** test membership in carrier *)
let in_car r e = Subst.mem r.rep e
let invariant r =
(** terms are congruent if equal after normalizing subterms *)
let congruent r a b =
Term.equal
(Term.map ~f:(Subst.norm r.rep) a)
(Term.map ~f:(Subst.norm r.rep) b)
(** Invariant *)
let pre_invariant r =
Invariant.invariant [%here] r [%sexp_of: t]
@@ fun () ->
Subst.iteri r.rep ~f:(fun ~key:a ~data:_ ->
Subst.iteri r.rep ~f:(fun ~key:trm ~data:_ ->
(* no interpreted terms in carrier *)
assert (non_interpreted a) ;
assert (non_interpreted trm || fail "non-interp %a" Term.pp trm ()) ;
(* carrier is closed under subterms *)
iter_max_solvables a ~f:(fun b ->
iter_max_solvables trm ~f:(fun subtrm ->
assert (
in_car r subtrm
|| fail "@[subterm %a of %a not in carrier of@ %a@]" Term.pp
subtrm Term.pp trm pp r () ) ) )
let invariant r =
Invariant.invariant [%here] r [%sexp_of: t]
@@ fun () ->
pre_invariant r ;
assert (
in_car r b
|| fail "@[subterm %a of %a not in carrier of@ %a@]" Term.pp b
Term.pp a pp r () ) ) )
(not r.sat)
|| Subst.for_alli r.rep ~f:(fun ~key:a ~data:a' ->
Subst.for_alli r.rep ~f:(fun ~key:b ~data:b' ->
Term.compare a b >= 0
|| congruent r a b ==> Term.equal a' b'
|| fail "not congruent %a@ %a@ in@ %a" Term.pp a Term.pp b pp
r () ) ) )
(** Core operations *)
@ -427,15 +450,9 @@ let true_ =
let false_ = {true_ with sat= false}
(** terms are congruent if equal after normalizing subterms *)
let congruent r a b =
Term.equal
(Term.map ~f:(Subst.norm r.rep) a)
(Term.map ~f:(Subst.norm r.rep) b)
(** [lookup r a] is [b'] if [a ~ b = b'] for some equation [b = b'] in rep *)
let lookup r a =
[%Trace.call fun {pf} -> pf "%a@ %a" Term.pp a pp r]
[%Trace.call fun {pf} -> pf "%a" Term.pp a]
;
( with_return
@@ fun {return} ->
@ -452,7 +469,7 @@ let lookup r a =
(** rewrite a term into canonical form using rep and, for non-interpreted
terms, congruence composed with rep *)
let rec canon r a =
[%Trace.call fun {pf} -> pf "%a@ %a" Term.pp a pp r]
[%Trace.call fun {pf} -> pf "%a" Term.pp a]
;
( match classify a with
| Atomic -> Subst.apply r.rep a
@ -478,7 +495,7 @@ let rec extend_ a r =
(** add a term to the carrier *)
let extend a r =
let rep = extend_ a r.rep in
if rep == r.rep then r else {r with rep} |> check invariant
if rep == r.rep then r else {r with rep} |> check pre_invariant
let merge us a b r =
[%Trace.call fun {pf} -> pf "%a@ %a@ %a" Term.pp a Term.pp b pp r]
@ -490,7 +507,7 @@ let merge us a b r =
|>
[%Trace.retn fun {pf} r' ->
pf "%a" pp_diff (r, r') ;
invariant r']
pre_invariant r']
(** find an unproved equation between congruent terms *)
let find_missing r =
@ -521,7 +538,7 @@ let close us r =
pf "%a" pp_diff (r, r') ;
invariant r']
let and_eq us a b r =
let and_eq_ us a b r =
if not r.sat then r
else
let a' = canon r a in
@ -538,7 +555,13 @@ let is_true {sat; rep} =
sat && Subst.for_alli rep ~f:(fun ~key:a ~data:a' -> Term.equal a a')
let is_false {sat} = not sat
let entails_eq r d e = Term.is_true (canon r (Term.eq d e))
let entails_eq r d e =
[%Trace.call fun {pf} -> pf "%a = %a@ %a" Term.pp d Term.pp e pp r]
;
Term.is_true (canon r (Term.eq d e))
|>
[%Trace.retn fun {pf} -> pf "%b"]
let entails r s =
Subst.for_alli s.rep ~f:(fun ~key:e ~data:e' -> entails_eq r e e')
@ -584,21 +607,29 @@ let apply_subst us s r =
let rep' = Subst.subst s rep in
List.fold cls ~init:r ~f:(fun r trm ->
let trm' = Subst.subst s trm in
and_eq us trm' rep' r ) )
and_eq_ us trm' rep' r ) )
|> extract_xs
|>
[%Trace.retn fun {pf} (xs, r') -> pf "%a%a" Var.Set.pp_xs xs pp r']
[%Trace.retn fun {pf} (xs, r') ->
pf "%a%a" Var.Set.pp_xs xs pp_diff (r, r') ;
invariant r']
let and_ us r s =
[%Trace.call fun {pf} -> pf "@[<hv 1> %a@ @<2>∧ %a@]" pp r pp s]
;
( if not r.sat then r
else if not s.sat then s
else
let s, r =
if Subst.length s.rep <= Subst.length r.rep then (s, r) else (r, s)
in
Subst.fold s.rep ~init:r ~f:(fun ~key:e ~data:e' r -> and_eq us e e' r)
Subst.fold s.rep ~init:r ~f:(fun ~key:e ~data:e' r -> and_eq_ us e e' r)
)
|> extract_xs
|>
[%Trace.retn fun {pf} (_, r') ->
pf "%a" pp_diff (r, r') ;
invariant r']
let or_ us r s =
[%Trace.call fun {pf} -> pf "@[<hv 1> %a@ @<2> %a@]" pp r pp s]
@ -612,7 +643,7 @@ let or_ us r s =
~init:([rep], rs)
~f:(fun (reps, rs) exp ->
match List.find ~f:(entails_eq r exp) reps with
| Some rep -> (reps, and_eq us exp rep rs)
| Some rep -> (reps, and_eq_ us exp rep rs)
| None -> (exp :: reps, rs) )
|> snd )
in
@ -622,7 +653,9 @@ let or_ us r s =
rs )
|> extract_xs
|>
[%Trace.retn fun {pf} (_, r) -> pf "%a" pp r]
[%Trace.retn fun {pf} (_, r') ->
pf "%a" pp_diff (r, r') ;
invariant r']
let orN us rs =
match rs with
@ -630,21 +663,28 @@ let orN us rs =
| r :: rs -> List.fold ~f:(fun (us, s) r -> or_ us s r) ~init:(us, r) rs
let rec and_term_ us e r =
let eq_false b r = and_eq us b Term.false_ r in
let eq_false b r = and_eq_ us b Term.false_ r in
match (e : Term.t) with
| Integer {data} -> if Z.is_false data then false_ else r
| Ap2 (And, a, b) -> and_term_ us a (and_term_ us b r)
| Ap2 (Eq, a, b) -> and_eq us a b r
| Ap2 (Eq, a, b) -> and_eq_ us a b r
| Ap2 (Xor, Integer {data}, a) when Z.is_true data -> eq_false a r
| Ap2 (Xor, a, Integer {data}) when Z.is_true data -> eq_false a r
| _ -> r
let and_term us e r = and_term_ us e r |> extract_xs
let and_term us e r =
[%Trace.call fun {pf} -> pf "%a@ %a" Term.pp e pp r]
;
and_term_ us e r |> extract_xs
|>
[%Trace.retn fun {pf} (_, r') ->
pf "%a" pp_diff (r, r') ;
invariant r']
let and_eq us a b r =
[%Trace.call fun {pf} -> pf "%a = %a@ %a" Term.pp a Term.pp b pp r]
;
and_eq us a b r |> extract_xs
and_eq_ us a b r |> extract_xs
|>
[%Trace.retn fun {pf} (_, r') ->
pf "%a" pp_diff (r, r') ;
@ -710,6 +750,8 @@ type 'a zom = Zero | One of 'a | Many
[fv kill us]; solve [p = q] for [kill]; extend subst mapping [kill]
to the solution *)
let solve_poly_eq us p' q' subst =
[%Trace.call fun {pf} -> pf "%a = %a" Term.pp p' Term.pp q']
;
let diff = Term.sub p' q' in
let max_solvables_not_ito_us =
fold_max_solvables diff ~init:Zero ~f:(fun solvable_subterm -> function
@ -718,11 +760,14 @@ let solve_poly_eq us p' q' subst =
| One _ -> Many
| Zero -> One solvable_subterm )
in
match max_solvables_not_ito_us with
( match max_solvables_not_ito_us with
| One kill ->
let+ kill, keep = Term.solve_zero_eq diff ~for_:kill in
Subst.compose1 ~key:kill ~data:keep subst
| Many | Zero -> None
| Many | Zero -> None )
|>
[%Trace.retn fun {pf} subst' ->
pf "@[%a@]" Subst.pp_diff (subst, Option.value subst' ~default:subst)]
let solve_memory_eq us e' f' subst =
[%Trace.call fun {pf} -> pf "%a = %a" Term.pp e' Term.pp f']
@ -1001,7 +1046,9 @@ let pp_vss fs vss =
[fv u ¹ v] if possible and otherwise
[fv u v] *)
let solve_for_vars vss r =
[%Trace.call fun {pf} -> pf "%a@ @[%a@]" pp_vss vss pp_classes r]
[%Trace.call fun {pf} ->
pf "%a@ @[%a@]@ @[%a@]" pp_vss vss pp_classes r pp r ;
invariant r]
;
let us, vss =
match vss with us :: vss -> (us, vss) | [] -> (Var.Set.empty, vss)
@ -1014,8 +1061,8 @@ let solve_for_vars vss r =
Subst.iteri subst ~f:(fun ~key ~data ->
assert (
entails_eq r key data
|| fail "@[%a = %a not entailed by@ %a@]" Term.pp key Term.pp data
pp_classes r () ) ;
|| fail "@[%a@ = %a@ not entailed by@ @[%a@]@]" Term.pp key
Term.pp data pp_classes r () ) ;
assert (
List.fold_until vss ~init:us
~f:(fun us xs ->

@ -52,6 +52,9 @@ let%test_module _ =
let and_eq a b r = and_eq wrt a b r |> snd
let and_ r s = and_ wrt r s |> snd
let or_ r s = or_ wrt r s |> snd
(* tests *)
let f1 = of_eqs [(!0, !1)]
let%test _ = is_false f1

@ -73,7 +73,7 @@ end) : S with type key = Key.t = struct
let pp pp_k pp_v fs m =
Format.fprintf fs "@[<1>[%a]@]"
(List.pp ",@ " (fun fs (k, v) ->
Format.fprintf fs "@[%a @<2>↦ %a@]" pp_k k pp_v v ))
Format.fprintf fs "@[%a@ @<2>↦ %a@]" pp_k k pp_v v ))
(to_alist m)
let pp_diff ~data_equal pp_key pp_val pp_diff_val fs (x, y) =

@ -5,8 +5,6 @@
* LICENSE file in the root directory of this source tree.
*)
(* [@@@warning "-32"] *)
let%test_module _ =
( module struct
open Sh
@ -14,7 +12,9 @@ let%test_module _ =
let () = Trace.init ~margin:68 ()
(* let () =
* Trace.init ~margin:160 ~config:(Result.ok_exn (Trace.parse "+Sh")) () *)
* Trace.init ~margin:160 ~config:(Result.ok_exn (Trace.parse "+Sh")) ()
*
* [@@@warning "-32"] *)
let pp = Format.printf "@\n%a@." pp
let pp_raw = Format.printf "@\n%a@." pp_raw
@ -41,6 +41,9 @@ let%test_module _ =
let x = Term.var x_
let y = Term.var y_
let of_eqs l =
List.fold ~init:emp ~f:(fun q (a, b) -> and_ (Term.eq a b) q) l
let%expect_test _ =
let p = exists ~$[x_] (extend_us ~$[x_] emp) in
let q = pure (x = !0) in
@ -125,9 +128,6 @@ let%test_module _ =
( ( 1 = %y_7 emp) ( emp) ( emp) ) |}]
let of_eqs l =
List.fold ~init:emp ~f:(fun q (a, b) -> and_ (Term.eq a b) q) l
let%expect_test _ =
let q = exists ~$[x_] (of_eqs [(f x, x); (f y, y - !1)]) in
pp q ;

@ -5,8 +5,6 @@
* LICENSE file in the root directory of this source tree.
*)
(* [@@@warning "-32"] *)
let%test_module _ =
( module struct
let () =
@ -18,7 +16,9 @@ let%test_module _ =
* Trace.init ~margin:160
* ~config:
* (Result.ok_exn (Trace.parse "+Solver.infer_frame+Solver.excise"))
* () *)
* ()
*
* [@@@warning "-32"] *)
let infer_frame p xs q =
Solver.infer_frame p (Var.Set.of_list xs) q

@ -213,6 +213,6 @@ let fail fmt =
let margin = Format.pp_get_margin fs () in
raisef ~margin
(fun msg ->
Format.fprintf fs "@\n@[<2>| %s@]@." msg ;
Format.fprintf fs "@\n%s@." msg ;
Failure msg )
fmt

Loading…
Cancel
Save