From 7e4673cbeb277807c7fd0ac576958e563f5ec86c Mon Sep 17 00:00:00 2001 From: Josh Berdine Date: Thu, 16 Apr 2020 03:38:25 -0700 Subject: [PATCH] [sledge] Improve Equality invariant checking and debugging support Summary: Ensure all entry points check the representation invariant before returning, and strengthen it to check the constraints on preference between representative terms, and to check that the relation is closed. Reviewed By: jvillard Differential Revision: D20612566 fbshipit-source-id: b345397c4 --- sledge/lib/domain_sh.ml | 1 + sledge/lib/equality.ml | 167 ++++++++++++++++++++------------ sledge/lib/equality_test.ml | 19 ++-- sledge/lib/import/map.ml | 2 +- sledge/lib/sh_test.ml | 16 +-- sledge/lib/solver_test.ml | 6 +- sledge/ppx_trace/trace/trace.ml | 2 +- 7 files changed, 132 insertions(+), 81 deletions(-) diff --git a/sledge/lib/domain_sh.ml b/sledge/lib/domain_sh.ml index 280c652dc..327b6ec61 100644 --- a/sledge/lib/domain_sh.ml +++ b/sledge/lib/domain_sh.ml @@ -308,6 +308,7 @@ let apply_summary q ({xs; foot; post} as fs) = let%test_module _ = ( module struct + let () = Trace.init ~margin:68 () let pp = Format.printf "@.%a@." Sh.pp let wrt = Var.Set.empty let main_, wrt = Var.fresh "main" ~wrt diff --git a/sledge/lib/equality.ml b/sledge/lib/equality.ml index f4965e29b..ec88f3214 100644 --- a/sledge/lib/equality.ml +++ b/sledge/lib/equality.ml @@ -177,33 +177,35 @@ end (** Theory Solver *) -(** orient equations s.t. Var < Memory < Extract < Concat < others, then - using height of aggregate nesting, and then using Term.compare *) -let orient e f = - let compare e f = - let rank e = - match (e : Term.t) with - | Var _ -> 0 - | Ap2 (Memory, _, _) -> 1 - | Ap3 (Extract, _, _, _) -> 2 - | ApN (Concat, _) -> 3 - | _ -> 4 - in - let rec height e = - match (e : Term.t) with - | Ap2 (Memory, _, x) -> 1 + height x - | Ap3 (Extract, x, _, _) -> 1 + height x - | ApN (Concat, xs) -> - 1 + IArray.fold ~init:0 ~f:(fun h x -> max h (height x)) xs - | _ -> 0 - in - let o = compare (rank e) (rank f) in - if o <> 0 then o - else - let o = compare (height e) (height f) in - if o <> 0 then o else Term.compare e f +(** prefer representative terms that are minimal in the order s.t. Var < + Memory < Extract < Concat < others, then using height of aggregate + nesting, and then using Term.compare *) +let prefer e f = + let rank e = + match (e : Term.t) with + | Var _ -> 0 + | Ap2 (Memory, _, _) -> 1 + | Ap3 (Extract, _, _, _) -> 2 + | ApN (Concat, _) -> 3 + | _ -> 4 in - match Ordering.of_int (compare e f) with + let rec height e = + match (e : Term.t) with + | Ap2 (Memory, _, x) -> 1 + height x + | Ap3 (Extract, x, _, _) -> 1 + height x + | ApN (Concat, xs) -> + 1 + IArray.fold ~init:0 ~f:(fun h x -> max h (height x)) xs + | _ -> 0 + in + let o = compare (rank e) (rank f) in + if o <> 0 then o + else + let o = compare (height e) (height f) in + if o <> 0 then o else Term.compare e f + +(** orient equations based on representative preference *) +let orient e f = + match Ordering.of_int (prefer e f) with | Less -> Some (e, f) | Equal -> None | Greater -> Some (f, e) @@ -402,23 +404,44 @@ let pp_diff_clss = Term.Map.pp_diff ~data_equal:(List.equal Term.equal) Term.pp pp_cls pp_diff_cls -(** Invariant *) +(** Basic queries *) (** test membership in carrier *) let in_car r e = Subst.mem r.rep e -let invariant r = +(** terms are congruent if equal after normalizing subterms *) +let congruent r a b = + Term.equal + (Term.map ~f:(Subst.norm r.rep) a) + (Term.map ~f:(Subst.norm r.rep) b) + +(** Invariant *) + +let pre_invariant r = Invariant.invariant [%here] r [%sexp_of: t] @@ fun () -> - Subst.iteri r.rep ~f:(fun ~key:a ~data:_ -> + Subst.iteri r.rep ~f:(fun ~key:trm ~data:_ -> (* no interpreted terms in carrier *) - assert (non_interpreted a) ; + assert (non_interpreted trm || fail "non-interp %a" Term.pp trm ()) ; (* carrier is closed under subterms *) - iter_max_solvables a ~f:(fun b -> + iter_max_solvables trm ~f:(fun subtrm -> assert ( - in_car r b - || fail "@[subterm %a of %a not in carrier of@ %a@]" Term.pp b - Term.pp a pp r () ) ) ) + in_car r subtrm + || fail "@[subterm %a of %a not in carrier of@ %a@]" Term.pp + subtrm Term.pp trm pp r () ) ) ) + +let invariant r = + Invariant.invariant [%here] r [%sexp_of: t] + @@ fun () -> + pre_invariant r ; + assert ( + (not r.sat) + || Subst.for_alli r.rep ~f:(fun ~key:a ~data:a' -> + Subst.for_alli r.rep ~f:(fun ~key:b ~data:b' -> + Term.compare a b >= 0 + || congruent r a b ==> Term.equal a' b' + || fail "not congruent %a@ %a@ in@ %a" Term.pp a Term.pp b pp + r () ) ) ) (** Core operations *) @@ -427,15 +450,9 @@ let true_ = let false_ = {true_ with sat= false} -(** terms are congruent if equal after normalizing subterms *) -let congruent r a b = - Term.equal - (Term.map ~f:(Subst.norm r.rep) a) - (Term.map ~f:(Subst.norm r.rep) b) - (** [lookup r a] is [b'] if [a ~ b = b'] for some equation [b = b'] in rep *) let lookup r a = - [%Trace.call fun {pf} -> pf "%a@ %a" Term.pp a pp r] + [%Trace.call fun {pf} -> pf "%a" Term.pp a] ; ( with_return @@ fun {return} -> @@ -452,7 +469,7 @@ let lookup r a = (** rewrite a term into canonical form using rep and, for non-interpreted terms, congruence composed with rep *) let rec canon r a = - [%Trace.call fun {pf} -> pf "%a@ %a" Term.pp a pp r] + [%Trace.call fun {pf} -> pf "%a" Term.pp a] ; ( match classify a with | Atomic -> Subst.apply r.rep a @@ -478,7 +495,7 @@ let rec extend_ a r = (** add a term to the carrier *) let extend a r = let rep = extend_ a r.rep in - if rep == r.rep then r else {r with rep} |> check invariant + if rep == r.rep then r else {r with rep} |> check pre_invariant let merge us a b r = [%Trace.call fun {pf} -> pf "%a@ %a@ %a" Term.pp a Term.pp b pp r] @@ -490,7 +507,7 @@ let merge us a b r = |> [%Trace.retn fun {pf} r' -> pf "%a" pp_diff (r, r') ; - invariant r'] + pre_invariant r'] (** find an unproved equation between congruent terms *) let find_missing r = @@ -521,7 +538,7 @@ let close us r = pf "%a" pp_diff (r, r') ; invariant r'] -let and_eq us a b r = +let and_eq_ us a b r = if not r.sat then r else let a' = canon r a in @@ -538,7 +555,13 @@ let is_true {sat; rep} = sat && Subst.for_alli rep ~f:(fun ~key:a ~data:a' -> Term.equal a a') let is_false {sat} = not sat -let entails_eq r d e = Term.is_true (canon r (Term.eq d e)) + +let entails_eq r d e = + [%Trace.call fun {pf} -> pf "%a = %a@ %a" Term.pp d Term.pp e pp r] + ; + Term.is_true (canon r (Term.eq d e)) + |> + [%Trace.retn fun {pf} -> pf "%b"] let entails r s = Subst.for_alli s.rep ~f:(fun ~key:e ~data:e' -> entails_eq r e e') @@ -584,21 +607,29 @@ let apply_subst us s r = let rep' = Subst.subst s rep in List.fold cls ~init:r ~f:(fun r trm -> let trm' = Subst.subst s trm in - and_eq us trm' rep' r ) ) + and_eq_ us trm' rep' r ) ) |> extract_xs |> - [%Trace.retn fun {pf} (xs, r') -> pf "%a%a" Var.Set.pp_xs xs pp r'] + [%Trace.retn fun {pf} (xs, r') -> + pf "%a%a" Var.Set.pp_xs xs pp_diff (r, r') ; + invariant r'] let and_ us r s = + [%Trace.call fun {pf} -> pf "@[ %a@ @<2>∧ %a@]" pp r pp s] + ; ( if not r.sat then r else if not s.sat then s else let s, r = if Subst.length s.rep <= Subst.length r.rep then (s, r) else (r, s) in - Subst.fold s.rep ~init:r ~f:(fun ~key:e ~data:e' r -> and_eq us e e' r) + Subst.fold s.rep ~init:r ~f:(fun ~key:e ~data:e' r -> and_eq_ us e e' r) ) |> extract_xs + |> + [%Trace.retn fun {pf} (_, r') -> + pf "%a" pp_diff (r, r') ; + invariant r'] let or_ us r s = [%Trace.call fun {pf} -> pf "@[ %a@ @<2>∨ %a@]" pp r pp s] @@ -612,7 +643,7 @@ let or_ us r s = ~init:([rep], rs) ~f:(fun (reps, rs) exp -> match List.find ~f:(entails_eq r exp) reps with - | Some rep -> (reps, and_eq us exp rep rs) + | Some rep -> (reps, and_eq_ us exp rep rs) | None -> (exp :: reps, rs) ) |> snd ) in @@ -622,7 +653,9 @@ let or_ us r s = rs ) |> extract_xs |> - [%Trace.retn fun {pf} (_, r) -> pf "%a" pp r] + [%Trace.retn fun {pf} (_, r') -> + pf "%a" pp_diff (r, r') ; + invariant r'] let orN us rs = match rs with @@ -630,21 +663,28 @@ let orN us rs = | r :: rs -> List.fold ~f:(fun (us, s) r -> or_ us s r) ~init:(us, r) rs let rec and_term_ us e r = - let eq_false b r = and_eq us b Term.false_ r in + let eq_false b r = and_eq_ us b Term.false_ r in match (e : Term.t) with | Integer {data} -> if Z.is_false data then false_ else r | Ap2 (And, a, b) -> and_term_ us a (and_term_ us b r) - | Ap2 (Eq, a, b) -> and_eq us a b r + | Ap2 (Eq, a, b) -> and_eq_ us a b r | Ap2 (Xor, Integer {data}, a) when Z.is_true data -> eq_false a r | Ap2 (Xor, a, Integer {data}) when Z.is_true data -> eq_false a r | _ -> r -let and_term us e r = and_term_ us e r |> extract_xs +let and_term us e r = + [%Trace.call fun {pf} -> pf "%a@ %a" Term.pp e pp r] + ; + and_term_ us e r |> extract_xs + |> + [%Trace.retn fun {pf} (_, r') -> + pf "%a" pp_diff (r, r') ; + invariant r'] let and_eq us a b r = [%Trace.call fun {pf} -> pf "%a = %a@ %a" Term.pp a Term.pp b pp r] ; - and_eq us a b r |> extract_xs + and_eq_ us a b r |> extract_xs |> [%Trace.retn fun {pf} (_, r') -> pf "%a" pp_diff (r, r') ; @@ -710,6 +750,8 @@ type 'a zom = Zero | One of 'a | Many [fv kill ⊈ us]; solve [p = q] for [kill]; extend subst mapping [kill] to the solution *) let solve_poly_eq us p' q' subst = + [%Trace.call fun {pf} -> pf "%a = %a" Term.pp p' Term.pp q'] + ; let diff = Term.sub p' q' in let max_solvables_not_ito_us = fold_max_solvables diff ~init:Zero ~f:(fun solvable_subterm -> function @@ -718,11 +760,14 @@ let solve_poly_eq us p' q' subst = | One _ -> Many | Zero -> One solvable_subterm ) in - match max_solvables_not_ito_us with + ( match max_solvables_not_ito_us with | One kill -> let+ kill, keep = Term.solve_zero_eq diff ~for_:kill in Subst.compose1 ~key:kill ~data:keep subst - | Many | Zero -> None + | Many | Zero -> None ) + |> + [%Trace.retn fun {pf} subst' -> + pf "@[%a@]" Subst.pp_diff (subst, Option.value subst' ~default:subst)] let solve_memory_eq us e' f' subst = [%Trace.call fun {pf} -> pf "%a = %a" Term.pp e' Term.pp f'] @@ -1001,7 +1046,9 @@ let pp_vss fs vss = [fv u ⊆ ⋃ⱼ₌₁ⁱ⁻¹ vⱼ] if possible and otherwise [fv u ⊆ ⋃ⱼ₌₁ⁱ vⱼ] *) let solve_for_vars vss r = - [%Trace.call fun {pf} -> pf "%a@ @[%a@]" pp_vss vss pp_classes r] + [%Trace.call fun {pf} -> + pf "%a@ @[%a@]@ @[%a@]" pp_vss vss pp_classes r pp r ; + invariant r] ; let us, vss = match vss with us :: vss -> (us, vss) | [] -> (Var.Set.empty, vss) @@ -1014,8 +1061,8 @@ let solve_for_vars vss r = Subst.iteri subst ~f:(fun ~key ~data -> assert ( entails_eq r key data - || fail "@[%a = %a not entailed by@ %a@]" Term.pp key Term.pp data - pp_classes r () ) ; + || fail "@[%a@ = %a@ not entailed by@ @[%a@]@]" Term.pp key + Term.pp data pp_classes r () ) ; assert ( List.fold_until vss ~init:us ~f:(fun us xs -> diff --git a/sledge/lib/equality_test.ml b/sledge/lib/equality_test.ml index 0c6d25599..28960ad14 100644 --- a/sledge/lib/equality_test.ml +++ b/sledge/lib/equality_test.ml @@ -52,6 +52,9 @@ let%test_module _ = let and_eq a b r = and_eq wrt a b r |> snd let and_ r s = and_ wrt r s |> snd let or_ r s = or_ wrt r s |> snd + + (* tests *) + let f1 = of_eqs [(!0, !1)] let%test _ = is_false f1 @@ -100,7 +103,7 @@ let%test_module _ = [%expect {| %x_5 = %y_6 - + {sat= true; rep= [[%y_6 ↦ %x_5]]} |}] let%test _ = entails_eq r1 x y @@ -137,9 +140,9 @@ let%test_module _ = [%expect {| {sat= true; rep= [[%y_6 ↦ %w_4]; [%z_7 ↦ %w_4]]} - + {sat= true; rep= [[%y_6 ↦ %x_5]; [%z_7 ↦ %x_5]]} - + {sat= true; rep= [[%z_7 ↦ %y_6]]} |}] let%test _ = @@ -215,7 +218,7 @@ let%test_module _ = [%expect {| %v_3 = %x_5 ∧ %w_4 = %y_6 = %z_7 - + {sat= true; rep= [[%x_5 ↦ %v_3]; [%y_6 ↦ %w_4]; [%z_7 ↦ %w_4]]} {sat= true; @@ -263,7 +266,7 @@ let%test_module _ = [%expect {| (13 × %z_7) = %x_5 ∧ 14 = %y_6 - + {sat= true; rep= [[%x_5 ↦ (13 × %z_7)]; [%y_6 ↦ 14]]} |}] let%test _ = entails_eq r8 y !14 @@ -276,7 +279,7 @@ let%test_module _ = [%expect {| (%z_7 + -16) = %x_5 - + {sat= true; rep= [[%x_5 ↦ (%z_7 + -16)]]} |}] let%test _ = difference r9 z (x + !8) |> Poly.equal (Some (Z.of_int 8)) @@ -293,9 +296,9 @@ let%test_module _ = [%expect {| (%z_7 + -16) = %x_5 - + {sat= true; rep= [[%x_5 ↦ (%z_7 + -16)]]} - + (-1 × %x_5 + %z_7 + -8) 8 diff --git a/sledge/lib/import/map.ml b/sledge/lib/import/map.ml index bead4fd89..57f093710 100644 --- a/sledge/lib/import/map.ml +++ b/sledge/lib/import/map.ml @@ -73,7 +73,7 @@ end) : S with type key = Key.t = struct let pp pp_k pp_v fs m = Format.fprintf fs "@[<1>[%a]@]" (List.pp ",@ " (fun fs (k, v) -> - Format.fprintf fs "@[%a @<2>↦ %a@]" pp_k k pp_v v )) + Format.fprintf fs "@[%a@ @<2>↦ %a@]" pp_k k pp_v v )) (to_alist m) let pp_diff ~data_equal pp_key pp_val pp_diff_val fs (x, y) = diff --git a/sledge/lib/sh_test.ml b/sledge/lib/sh_test.ml index fe68382d3..303cb3eac 100644 --- a/sledge/lib/sh_test.ml +++ b/sledge/lib/sh_test.ml @@ -5,8 +5,6 @@ * LICENSE file in the root directory of this source tree. *) -(* [@@@warning "-32"] *) - let%test_module _ = ( module struct open Sh @@ -14,7 +12,9 @@ let%test_module _ = let () = Trace.init ~margin:68 () (* let () = - * Trace.init ~margin:160 ~config:(Result.ok_exn (Trace.parse "+Sh")) () *) + * Trace.init ~margin:160 ~config:(Result.ok_exn (Trace.parse "+Sh")) () + * + * [@@@warning "-32"] *) let pp = Format.printf "@\n%a@." pp let pp_raw = Format.printf "@\n%a@." pp_raw @@ -41,6 +41,9 @@ let%test_module _ = let x = Term.var x_ let y = Term.var y_ + let of_eqs l = + List.fold ~init:emp ~f:(fun q (a, b) -> and_ (Term.eq a b) q) l + let%expect_test _ = let p = exists ~$[x_] (extend_us ~$[x_] emp) in let q = pure (x = !0) in @@ -50,9 +53,9 @@ let%test_module _ = [%expect {| ∃ %x_6 . emp - + 0 = %x_6 ∧ emp - + 0 = %x_6 ∧ emp |}] let%expect_test _ = @@ -125,9 +128,6 @@ let%test_module _ = ( ( 1 = %y_7 ∧ emp) ∨ ( emp) ∨ ( emp) ) |}] - let of_eqs l = - List.fold ~init:emp ~f:(fun q (a, b) -> and_ (Term.eq a b) q) l - let%expect_test _ = let q = exists ~$[x_] (of_eqs [(f x, x); (f y, y - !1)]) in pp q ; diff --git a/sledge/lib/solver_test.ml b/sledge/lib/solver_test.ml index aa4c72902..696bac247 100644 --- a/sledge/lib/solver_test.ml +++ b/sledge/lib/solver_test.ml @@ -5,8 +5,6 @@ * LICENSE file in the root directory of this source tree. *) -(* [@@@warning "-32"] *) - let%test_module _ = ( module struct let () = @@ -18,7 +16,9 @@ let%test_module _ = * Trace.init ~margin:160 * ~config: * (Result.ok_exn (Trace.parse "+Solver.infer_frame+Solver.excise")) - * () *) + * () + * + * [@@@warning "-32"] *) let infer_frame p xs q = Solver.infer_frame p (Var.Set.of_list xs) q diff --git a/sledge/ppx_trace/trace/trace.ml b/sledge/ppx_trace/trace/trace.ml index 1120a78d9..c887e44de 100644 --- a/sledge/ppx_trace/trace/trace.ml +++ b/sledge/ppx_trace/trace/trace.ml @@ -213,6 +213,6 @@ let fail fmt = let margin = Format.pp_get_margin fs () in raisef ~margin (fun msg -> - Format.fprintf fs "@\n@[<2>| %s@]@." msg ; + Format.fprintf fs "@\n%s@." msg ; Failure msg ) fmt