diff --git a/sledge/src/fol/context.ml b/sledge/src/fol/context.ml index c591042a6..462d1f14f 100644 --- a/sledge/src/fol/context.ml +++ b/sledge/src/fol/context.ml @@ -234,10 +234,10 @@ let compose1 ?f ~var ~rep (us, xs, s) = in Some (us, xs, s) -let fresh name (us, xs, s) = - let x, us = Var.fresh name ~wrt:us in +let fresh name (wrt, xs, s) = + let x, wrt = Var.fresh name ~wrt in let xs = Var.Set.add x xs in - (Trm.var x, (us, xs, s)) + (Trm.var x, (wrt, xs, s)) let solve_poly ?f p q s = [%trace] @@ -373,10 +373,10 @@ and solve_ ?f d e s = | Some (_, xs, s) -> pf "%a%a" Var.Set.pp_xs xs Subst.pp s | None -> pf "false"] -let solve ?f ~us ~xs d e = +let solve ?f ~wrt ~xs d e = [%Trace.call fun {pf} -> pf "%a@ %a" Trm.pp d Trm.pp e] ; - ( solve_ ?f d e (us, xs, Subst.empty) + ( solve_ ?f d e (wrt, xs, Subst.empty) |>= fun (_, xs, s) -> let xs = Var.Set.inter xs (Subst.fv s) in (xs, s) ) @@ -568,10 +568,10 @@ let extend a r = let rep = extend_ a r.rep in if rep == r.rep then r else {r with rep} |> check pre_invariant -let merge us a b r = +let merge ~wrt a b r = [%Trace.call fun {pf} -> pf "%a@ %a@ %a" Trm.pp a Trm.pp b pp r] ; - ( match solve ~us ~xs:r.xs a b with + ( match solve ~wrt ~xs:r.xs a b with | Some (xs, s) -> {r with xs= Var.Set.union r.xs xs; rep= Subst.compose r.rep s} | None -> {r with sat= false} ) @@ -596,23 +596,23 @@ let find_missing r = in Option.return_if need_a'_eq_b' (a', b') ) ) -let rec close us r = +let rec close ~wrt r = if not r.sat then r else match find_missing r with - | Some (a', b') -> close us (merge us a' b' r) + | Some (a', b') -> close ~wrt (merge ~wrt a' b' r) | None -> r -let close us r = +let close ~wrt r = [%Trace.call fun {pf} -> pf "%a" pp r] ; - close us r + close ~wrt r |> [%Trace.retn fun {pf} r' -> pf "%a" pp_diff (r, r') ; invariant r'] -let and_eq_ us a b r = +let and_eq_ ~wrt a b r = if not r.sat then r else let r0 = r in @@ -622,11 +622,11 @@ let and_eq_ us a b r = let r = extend b' r in if Trm.equal a' b' then r else - let r = merge us a' b' r in + let r = merge ~wrt a' b' r in match (a, b) with | (Var _ as v), _ when not (in_car r0 v) -> r | _, (Var _ as v) when not (in_car r0 v) -> r - | _ -> close us r + | _ -> close ~wrt r let extract_xs r = (r.xs, {r with xs= Var.Set.empty}) @@ -678,21 +678,21 @@ let fold_uses_of r t s ~f = Subst.fold r.rep s ~f:(fun ~key:trm ~data:rep s -> fold_ ~f trm (fold_ ~f rep s) ) -let apply_subst us s r = +let apply_subst wrt s r = [%Trace.call fun {pf} -> pf "%a@ %a" Subst.pp s pp r] ; Trm.Map.fold (classes r) empty ~f:(fun ~key:rep ~data:cls r -> let rep' = Subst.subst_ s rep in List.fold cls r ~f:(fun trm r -> let trm' = Subst.subst_ s trm in - and_eq_ us trm' rep' r ) ) + and_eq_ ~wrt trm' rep' r ) ) |> extract_xs |> [%Trace.retn fun {pf} (xs, r') -> pf "%a%a" Var.Set.pp_xs xs pp_diff (r, r') ; invariant r'] -let union us r s = +let union wrt r s = [%Trace.call fun {pf} -> pf "@[ %a@ @<2>∧ %a@]" pp r pp s] ; ( if not r.sat then r @@ -701,14 +701,14 @@ let union us r s = let s, r = if Subst.length s.rep <= Subst.length r.rep then (s, r) else (r, s) in - Subst.fold s.rep r ~f:(fun ~key:e ~data:e' r -> and_eq_ us e e' r) ) + Subst.fold s.rep r ~f:(fun ~key:e ~data:e' r -> and_eq_ ~wrt e e' r) ) |> extract_xs |> [%Trace.retn fun {pf} (_, r') -> pf "%a" pp_diff (r, r') ; invariant r'] -let inter us r s = +let inter wrt r s = [%Trace.call fun {pf} -> pf "@[ %a@ @<2>∨ %a@]" pp r pp s] ; ( if not s.sat then r @@ -722,7 +722,7 @@ let inter us r s = match List.find ~f:(fun rep -> implies r (Fml.eq exp rep)) reps with - | Some rep -> (reps, and_eq_ us exp rep rs) + | Some rep -> (reps, and_eq_ ~wrt exp rep rs) | None -> (exp :: reps, rs) ) |> snd ) in @@ -741,13 +741,13 @@ let interN us rs = | [] -> (us, unsat) | r :: rs -> List.fold ~f:(fun r (us, s) -> inter us s r) rs (us, r) -let rec add_ us b r = +let rec add_ wrt b r = match (b : Fml.t) with | Tt -> r | Not Tt -> unsat - | And {pos; neg} -> Fml.fold_pos_neg ~f:(add_ us) ~pos ~neg r - | Eq (d, e) -> and_eq_ us d e r - | Eq0 e -> and_eq_ us Trm.zero e r + | And {pos; neg} -> Fml.fold_pos_neg ~f:(add_ wrt) ~pos ~neg r + | Eq (d, e) -> and_eq_ ~wrt d e r + | Eq0 e -> and_eq_ ~wrt Trm.zero e r | Pos _ | Not _ | Or _ | Iff _ | Cond _ | Lit _ -> r let add us b r = @@ -829,7 +829,7 @@ let solve_poly_eq us p' q' subst = [%Trace.retn fun {pf} subst' -> pf "@[%a@]" Subst.pp_diff (subst, Option.value subst' ~default:subst)] -let solve_seq_eq us e' f' subst = +let solve_seq_eq ~wrt us e' f' subst = [%Trace.call fun {pf} -> pf "%a = %a" Trm.pp e' Trm.pp f'] ; let f x u = @@ -842,7 +842,7 @@ let solve_seq_eq us e' f' subst = | Some n -> (a, n) | None -> (Trm.sized ~siz:n ~seq:a, n) in - let+ _, xs, s = solve_concat ~f ms a n (us, Var.Set.empty, subst) in + let+ _, xs, s = solve_concat ~f ms a n (wrt, Var.Set.empty, subst) in assert (Var.Set.disjoint xs (Subst.fv s)) ; s in @@ -858,14 +858,14 @@ let solve_seq_eq us e' f' subst = [%Trace.retn fun {pf} subst' -> pf "@[%a@]" Subst.pp_diff (subst, Option.value subst' ~default:subst)] -let solve_interp_eq us e' (cls, subst) = +let solve_interp_eq ~wrt us e' (cls, subst) = [%Trace.call fun {pf} -> pf "trm: @[%a@]@ cls: @[%a@]@ subst: @[%a@]" Trm.pp e' pp_cls cls Subst.pp subst] ; List.find_map cls ~f:(fun f -> let f' = Subst.norm subst f in - match solve_seq_eq us e' f' subst with + match solve_seq_eq ~wrt us e' f' subst with | Some subst -> Some subst | None -> solve_poly_eq us e' f' subst ) |> @@ -877,7 +877,7 @@ let solve_interp_eq us e' (cls, subst) = and can be expressed, after normalizing with [subst], as [x ↦ u] where [us ∪ xs ⊇ fv x ⊈ us] and [fv u ⊆ us] or else [fv u ⊆ us ∪ xs] *) -let rec solve_interp_eqs us (cls, subst) = +let rec solve_interp_eqs ~wrt us (cls, subst) = [%Trace.call fun {pf} -> pf "cls: @[%a@]@ subst: @[%a@]" pp_cls cls Subst.pp subst] ; @@ -887,13 +887,13 @@ let rec solve_interp_eqs us (cls, subst) = | trm :: cls -> let trm' = Subst.norm subst trm in if is_interpreted trm' then - match solve_interp_eq us trm' (cls, subst) with + match solve_interp_eq ~wrt us trm' (cls, subst) with | Some subst -> solve_interp_eqs_ cls' (cls, subst) | None -> solve_interp_eqs_ (trm' :: cls') (cls, subst) else solve_interp_eqs_ (trm' :: cls') (cls, subst) in let cls', subst' = solve_interp_eqs_ [] (cls, subst) in - ( if subst' != subst then solve_interp_eqs us (cls', subst') + ( if subst' != subst then solve_interp_eqs ~wrt us (cls', subst') else (cls', subst') ) |> [%Trace.retn fun {pf} (cls', subst') -> @@ -982,7 +982,7 @@ let solve_uninterp_eqs us (cls, subst) = [subst] which can be expressed, after normalizing with [subst], as [x ↦ u] where [us ∪ xs ⊇ fv x ⊈ us] and [fv u ⊆ us] or else [fv u ⊆ us ∪ xs] *) -let solve_class us us_xs ~key:rep ~data:cls (classes, subst) = +let solve_class ~wrt us us_xs ~key:rep ~data:cls (classes, subst) = let classes0 = classes in [%Trace.call fun {pf} -> pf "rep: @[%a@]@ cls: @[%a@]@ subst: @[%a@]" Trm.pp rep pp_cls cls @@ -993,7 +993,7 @@ let solve_class us us_xs ~key:rep ~data:cls (classes, subst) = ~f:(fun e -> Var.Set.subset (Trm.fv e) ~of_:us_xs) (rep :: cls) in - let cls, subst = solve_interp_eqs us (cls, subst) in + let cls, subst = solve_interp_eqs ~wrt us (cls, subst) in let cls, subst = solve_uninterp_eqs us (cls, subst) in let cls = List.rev_append cls_not_ito_us_xs cls in let cls = List.remove ~eq:Trm.equal (Subst.norm subst rep) cls in @@ -1069,13 +1069,13 @@ let solve_for_xs r us xs = (** move equations from [classes] to [subst] which can be expressed, after normalizing with [subst], as [x ↦ u] where [us ∪ xs ⊇ fv x ⊈ us] and [fv u ⊆ us] or else [fv u ⊆ us ∪ xs]. *) -let solve_classes r xs (classes, subst, us) = +let solve_classes ~wrt r xs (classes, subst, us) = [%Trace.call fun {pf} -> pf "us: {@[%a@]}@ xs: {@[%a@]}" Var.Set.pp us Var.Set.pp xs] ; let rec solve_classes_ (classes0, subst0, us_xs) = let classes, subst = - Trm.Map.fold ~f:(solve_class us us_xs) classes0 (classes0, subst0) + Trm.Map.fold ~f:(solve_class ~wrt us us_xs) classes0 (classes0, subst0) in if subst != subst0 then solve_classes_ (classes, subst, us_xs) else (classes, subst, us_xs) @@ -1103,10 +1103,12 @@ let solve_for_vars vss r = pf "%a@ @[%a@]@ @[%a@]" pp_vss vss pp_classes r pp r ; invariant r] ; + let wrt = Var.Set.union_list vss in let us, vss = match vss with us :: vss -> (us, vss) | [] -> (Var.Set.empty, vss) in - List.fold ~f:(solve_classes r) vss (classes r, Subst.empty, us) |> snd3 + List.fold ~f:(solve_classes ~wrt r) vss (classes r, Subst.empty, us) + |> snd3 |> [%Trace.retn fun {pf} subst -> pf "%a" Subst.pp subst ;