diff --git a/sledge/src/llair/term.ml b/sledge/src/llair/term.ml index bcebdbce0..67bb4f433 100644 --- a/sledge/src/llair/term.ml +++ b/sledge/src/llair/term.ml @@ -1021,25 +1021,7 @@ let fv e = fold_vars e ~f:Set.add ~init:Var.Set.empty let is_true = function Integer {data} -> Z.is_true data | _ -> false let is_false = function Integer {data} -> Z.is_false data | _ -> false -let rec is_constant e = - match e with - | Var _ | Nondet _ -> false - | Ap1 (_, x) -> is_constant x - | Ap2 (_, x, y) -> is_constant x && is_constant y - | Ap3 (_, x, y, z) -> is_constant x && is_constant y && is_constant z - | ApN (_, xs) | RecN (_, xs) -> Vector.for_all ~f:is_constant xs - | Add args | Mul args -> - Qset.for_all ~f:(fun arg _ -> is_constant arg) args - | Label _ | Float _ | Integer _ -> true - -type kind = Interpreted | Simplified | Atomic | Uninterpreted -[@@deriving compare] - -let classify = function - | Add _ | Mul _ -> Interpreted - | Ap2 ((Eq | Dq), _, _) -> Simplified - | Ap1 _ | Ap2 _ | Ap3 _ | ApN _ -> Uninterpreted - | RecN _ | Var _ | Integer _ | Float _ | Nondet _ | Label _ -> Atomic +(** Solve *) let solve_zero_eq = function | Add args -> @@ -1049,49 +1031,3 @@ let solve_zero_eq = function let r = div n d in Some (c, r) | _ -> None - -let solve e f = - [%Trace.call fun {pf} -> pf "%a@ %a" pp e pp f] - ; - let rec solve_ e f s = - let solve_uninterp e f = - match (e, f) with - | Integer {data= m}, Integer {data= n} when not (Z.equal m n) -> None - | _ -> ( - match (is_constant e, is_constant f) with - (* orient equation to discretionarily prefer term that is constant - or compares smaller as class representative *) - | true, false -> Some (Map.add_exn s ~key:f ~data:e) - | false, true -> Some (Map.add_exn s ~key:e ~data:f) - | _ -> - let key, data = if compare e f > 0 then (e, f) else (f, e) in - Some (Map.add_exn s ~key ~data) ) - in - let concat_size args = - Vector.fold_until args ~init:zero - ~f:(fun sum -> function - | Ap2 (Memory, siz, _) -> Continue (add siz sum) | _ -> Stop None - ) - ~finish:(fun _ -> None) - in - match (e, f) with - | (Add _ | Mul _ | Integer _), _ | _, (Add _ | Mul _ | Integer _) -> ( - let e_f = sub e f in - match solve_zero_eq e_f with - | Some (key, data) -> Some (Map.add_exn s ~key ~data) - | None -> solve_uninterp e_f zero ) - | ApN (Concat, ms), ApN (Concat, ns) -> ( - match (concat_size ms, concat_size ns) with - | Some p, Some q -> solve_uninterp e f >>= solve_ p q - | _ -> solve_uninterp e f ) - | Ap2 (Memory, m, _), ApN (Concat, ns) - |ApN (Concat, ns), Ap2 (Memory, m, _) -> ( - match concat_size ns with - | Some p -> solve_uninterp e f >>= solve_ p m - | _ -> solve_uninterp e f ) - | _ -> solve_uninterp e f - in - solve_ e f Map.empty - |> - [%Trace.retn fun {pf} -> - function Some s -> pf "%a" Var.Subst.pp s | None -> pf "false"] diff --git a/sledge/src/llair/term.mli b/sledge/src/llair/term.mli index d4125c52b..589d4e39b 100644 --- a/sledge/src/llair/term.mli +++ b/sledge/src/llair/term.mli @@ -244,9 +244,9 @@ val fold_terms : t -> init:'a -> f:('a -> t -> 'a) -> 'a val fv : t -> Var.Set.t val is_true : t -> bool val is_false : t -> bool -val is_constant : t -> bool -type kind = Interpreted | Simplified | Atomic | Uninterpreted +(** Solve *) -val classify : t -> kind -val solve : t -> t -> t Map.t option +val solve_zero_eq : t -> (t * t) option +(** [solve_zero_eq d] is [Some (e, f)] if [d = 0] can be equivalently + expressed as [e = f] for some monomial subterm [e] of [d]. *) diff --git a/sledge/src/symbheap/equality.ml b/sledge/src/symbheap/equality.ml index 57a39388b..37a50f6a0 100644 --- a/sledge/src/symbheap/equality.ml +++ b/sledge/src/symbheap/equality.ml @@ -13,6 +13,86 @@ let empty_map = Map.empty (module Term) type subst = Term.t term_map [@@deriving compare, equal, sexp] +let pp_subst fs s = + Format.fprintf fs "@[<1>[%a]@]" + (List.pp ",@ " (fun fs (k, v) -> + Format.fprintf fs "@[%a ↦ %a@]" Term.pp k Term.pp v )) + (Map.to_alist s) + +(** Theory Solver *) + +let rec is_constant e = + match (e : Term.t) with + | Var _ | Nondet _ -> false + | Ap1 (_, x) -> is_constant x + | Ap2 (_, x, y) -> is_constant x && is_constant y + | Ap3 (_, x, y, z) -> is_constant x && is_constant y && is_constant z + | ApN (_, xs) | RecN (_, xs) -> Vector.for_all ~f:is_constant xs + | Add args | Mul args -> + Qset.for_all ~f:(fun arg _ -> is_constant arg) args + | Label _ | Float _ | Integer _ -> true + +type kind = Interpreted | Simplified | Atomic | Uninterpreted +[@@deriving compare] + +let classify e = + match (e : Term.t) with + | Add _ | Mul _ -> Interpreted + | Ap2 ((Eq | Dq), _, _) -> Simplified + | Ap1 _ | Ap2 _ | Ap3 _ | ApN _ -> Uninterpreted + | RecN _ | Var _ | Integer _ | Float _ | Nondet _ | Label _ -> Atomic + +let solve e f = + [%Trace.call fun {pf} -> pf "%a@ %a" Term.pp e Term.pp f] + ; + let rec solve_ e f s = + let solve_uninterp e f = + match ((e : Term.t), (f : Term.t)) with + | Integer {data= m}, Integer {data= n} when not (Z.equal m n) -> None + | _ -> ( + match (is_constant e, is_constant f) with + (* orient equation to discretionarily prefer term that is constant + or compares smaller as class representative *) + | true, false -> Some (Map.add_exn s ~key:f ~data:e) + | false, true -> Some (Map.add_exn s ~key:e ~data:f) + | _ -> + let key, data = + if Term.compare e f > 0 then (e, f) else (f, e) + in + Some (Map.add_exn s ~key ~data) ) + in + let concat_size args = + Vector.fold_until args ~init:Term.zero + ~f:(fun sum m -> + match (m : Term.t) with + | Ap2 (Memory, siz, _) -> Continue (Term.add siz sum) + | _ -> Stop None ) + ~finish:(fun _ -> None) + in + match ((e : Term.t), (f : Term.t)) with + | (Add _ | Mul _ | Integer _), _ | _, (Add _ | Mul _ | Integer _) -> ( + let e_f = Term.sub e f in + match Term.solve_zero_eq e_f with + | Some (key, data) -> Some (Map.add_exn s ~key ~data) + | None -> solve_uninterp e_f Term.zero ) + | ApN (Concat, ms), ApN (Concat, ns) -> ( + match (concat_size ms, concat_size ns) with + | Some p, Some q -> solve_uninterp e f >>= solve_ p q + | _ -> solve_uninterp e f ) + | Ap2 (Memory, m, _), ApN (Concat, ns) + |ApN (Concat, ns), Ap2 (Memory, m, _) -> ( + match concat_size ns with + | Some p -> solve_uninterp e f >>= solve_ p m + | _ -> solve_uninterp e f ) + | _ -> solve_uninterp e f + in + solve_ e f empty_map + |> + [%Trace.retn fun {pf} -> + function Some s -> pf "%a" pp_subst s | None -> pf "false"] + +(** Equality Relations *) + (** see also [invariant] *) type t = { sat: bool (** [false] only if constraints are inconsistent *) @@ -31,7 +111,7 @@ let classes r = else Map.add_multi cls ~key:data ~data:key in Map.fold r.rep ~init:empty_map ~f:(fun ~key ~data cls -> - match Term.classify key with + match classify key with | Interpreted | Atomic -> add key data cls | Simplified | Uninterpreted -> add (Term.map ~f:(apply r.rep) key) data cls ) @@ -85,7 +165,7 @@ let pp_diff fs (r, s) = let in_car r e = Map.mem r.rep e let rec iter_max_solvables e ~f = - match Term.classify e with + match classify e with | Interpreted -> Term.iter ~f:(iter_max_solvables ~f) e | _ -> f e @@ -94,7 +174,7 @@ let invariant r = @@ fun () -> Map.iteri r.rep ~f:(fun ~key:a ~data:_ -> (* no interpreted terms in carrier *) - assert (Poly.(Term.classify a <> Interpreted)) ; + assert (Poly.(classify a <> Interpreted)) ; (* carrier is closed under subterms *) iter_max_solvables a ~f:(fun b -> assert ( @@ -108,7 +188,7 @@ let true_ = {sat= true; rep= empty_map} |> check invariant (** apply a subst to maximal non-interpreted subterms *) let rec norm s a = - match Term.classify a with + match classify a with | Interpreted -> Term.map ~f:(norm s) a | Simplified -> apply s (Term.map ~f:(norm s) a) | Atomic | Uninterpreted -> apply s a @@ -130,14 +210,14 @@ let lookup r a = (** rewrite a term into canonical form using rep and, for uninterpreted terms, congruence composed with rep *) let rec canon r a = - match Term.classify a with + match classify a with | Interpreted -> Term.map ~f:(canon r) a | Simplified | Uninterpreted -> lookup r (Term.map ~f:(canon r) a) | Atomic -> apply r.rep a (** add a term to the carrier *) let rec extend a r = - match Term.classify a with + match classify a with | Interpreted | Simplified -> Term.fold ~f:extend a ~init:r | Uninterpreted -> Map.find_or_add r.rep a @@ -160,7 +240,7 @@ let compose r s = let merge a b r = [%Trace.call fun {pf} -> pf "%a@ %a@ %a" Term.pp a Term.pp b pp r] ; - ( match Term.solve a b with + ( match solve a b with | Some s -> compose r s | None -> {r with sat= false} ) |>