[sledge] Fix a fresh name clash when solving extract equations

Summary:
The first-order solver sometimes needs to generate fresh variables to
express the solution of equations. It needs to ensure that these
generated variables do not clash. Before this diff, there was a
confusion where new variables were fresh with respect to only the
current set of "universal" variables. This is wrong, and this diff
adds the full set of variables instead.

Reviewed By: jvillard

Differential Revision: D25196732

fbshipit-source-id: afc56834a
master
Josh Berdine 4 years ago committed by Facebook GitHub Bot
parent 77c630b7f4
commit bb52f96ded

@ -234,10 +234,10 @@ let compose1 ?f ~var ~rep (us, xs, s) =
in in
Some (us, xs, s) Some (us, xs, s)
let fresh name (us, xs, s) = let fresh name (wrt, xs, s) =
let x, us = Var.fresh name ~wrt:us in let x, wrt = Var.fresh name ~wrt in
let xs = Var.Set.add x xs in let xs = Var.Set.add x xs in
(Trm.var x, (us, xs, s)) (Trm.var x, (wrt, xs, s))
let solve_poly ?f p q s = let solve_poly ?f p q s =
[%trace] [%trace]
@ -373,10 +373,10 @@ and solve_ ?f d e s =
| Some (_, xs, s) -> pf "%a%a" Var.Set.pp_xs xs Subst.pp s | Some (_, xs, s) -> pf "%a%a" Var.Set.pp_xs xs Subst.pp s
| None -> pf "false"] | None -> pf "false"]
let solve ?f ~us ~xs d e = let solve ?f ~wrt ~xs d e =
[%Trace.call fun {pf} -> pf "%a@ %a" Trm.pp d Trm.pp e] [%Trace.call fun {pf} -> pf "%a@ %a" Trm.pp d Trm.pp e]
; ;
( solve_ ?f d e (us, xs, Subst.empty) ( solve_ ?f d e (wrt, xs, Subst.empty)
|>= fun (_, xs, s) -> |>= fun (_, xs, s) ->
let xs = Var.Set.inter xs (Subst.fv s) in let xs = Var.Set.inter xs (Subst.fv s) in
(xs, s) ) (xs, s) )
@ -568,10 +568,10 @@ let extend a r =
let rep = extend_ a r.rep in let rep = extend_ a r.rep in
if rep == r.rep then r else {r with rep} |> check pre_invariant if rep == r.rep then r else {r with rep} |> check pre_invariant
let merge us a b r = let merge ~wrt a b r =
[%Trace.call fun {pf} -> pf "%a@ %a@ %a" Trm.pp a Trm.pp b pp r] [%Trace.call fun {pf} -> pf "%a@ %a@ %a" Trm.pp a Trm.pp b pp r]
; ;
( match solve ~us ~xs:r.xs a b with ( match solve ~wrt ~xs:r.xs a b with
| Some (xs, s) -> | Some (xs, s) ->
{r with xs= Var.Set.union r.xs xs; rep= Subst.compose r.rep s} {r with xs= Var.Set.union r.xs xs; rep= Subst.compose r.rep s}
| None -> {r with sat= false} ) | None -> {r with sat= false} )
@ -596,23 +596,23 @@ let find_missing r =
in in
Option.return_if need_a'_eq_b' (a', b') ) ) Option.return_if need_a'_eq_b' (a', b') ) )
let rec close us r = let rec close ~wrt r =
if not r.sat then r if not r.sat then r
else else
match find_missing r with match find_missing r with
| Some (a', b') -> close us (merge us a' b' r) | Some (a', b') -> close ~wrt (merge ~wrt a' b' r)
| None -> r | None -> r
let close us r = let close ~wrt r =
[%Trace.call fun {pf} -> pf "%a" pp r] [%Trace.call fun {pf} -> pf "%a" pp r]
; ;
close us r close ~wrt r
|> |>
[%Trace.retn fun {pf} r' -> [%Trace.retn fun {pf} r' ->
pf "%a" pp_diff (r, r') ; pf "%a" pp_diff (r, r') ;
invariant r'] invariant r']
let and_eq_ us a b r = let and_eq_ ~wrt a b r =
if not r.sat then r if not r.sat then r
else else
let r0 = r in let r0 = r in
@ -622,11 +622,11 @@ let and_eq_ us a b r =
let r = extend b' r in let r = extend b' r in
if Trm.equal a' b' then r if Trm.equal a' b' then r
else else
let r = merge us a' b' r in let r = merge ~wrt a' b' r in
match (a, b) with match (a, b) with
| (Var _ as v), _ when not (in_car r0 v) -> r | (Var _ as v), _ when not (in_car r0 v) -> r
| _, (Var _ as v) when not (in_car r0 v) -> r | _, (Var _ as v) when not (in_car r0 v) -> r
| _ -> close us r | _ -> close ~wrt r
let extract_xs r = (r.xs, {r with xs= Var.Set.empty}) let extract_xs r = (r.xs, {r with xs= Var.Set.empty})
@ -678,21 +678,21 @@ let fold_uses_of r t s ~f =
Subst.fold r.rep s ~f:(fun ~key:trm ~data:rep s -> Subst.fold r.rep s ~f:(fun ~key:trm ~data:rep s ->
fold_ ~f trm (fold_ ~f rep s) ) fold_ ~f trm (fold_ ~f rep s) )
let apply_subst us s r = let apply_subst wrt s r =
[%Trace.call fun {pf} -> pf "%a@ %a" Subst.pp s pp r] [%Trace.call fun {pf} -> pf "%a@ %a" Subst.pp s pp r]
; ;
Trm.Map.fold (classes r) empty ~f:(fun ~key:rep ~data:cls r -> Trm.Map.fold (classes r) empty ~f:(fun ~key:rep ~data:cls r ->
let rep' = Subst.subst_ s rep in let rep' = Subst.subst_ s rep in
List.fold cls r ~f:(fun trm r -> List.fold cls r ~f:(fun trm r ->
let trm' = Subst.subst_ s trm in let trm' = Subst.subst_ s trm in
and_eq_ us trm' rep' r ) ) and_eq_ ~wrt trm' rep' r ) )
|> extract_xs |> extract_xs
|> |>
[%Trace.retn fun {pf} (xs, r') -> [%Trace.retn fun {pf} (xs, r') ->
pf "%a%a" Var.Set.pp_xs xs pp_diff (r, r') ; pf "%a%a" Var.Set.pp_xs xs pp_diff (r, r') ;
invariant r'] invariant r']
let union us r s = let union wrt r s =
[%Trace.call fun {pf} -> pf "@[<hv 1> %a@ @<2>∧ %a@]" pp r pp s] [%Trace.call fun {pf} -> pf "@[<hv 1> %a@ @<2>∧ %a@]" pp r pp s]
; ;
( if not r.sat then r ( if not r.sat then r
@ -701,14 +701,14 @@ let union us r s =
let s, r = let s, r =
if Subst.length s.rep <= Subst.length r.rep then (s, r) else (r, s) if Subst.length s.rep <= Subst.length r.rep then (s, r) else (r, s)
in in
Subst.fold s.rep r ~f:(fun ~key:e ~data:e' r -> and_eq_ us e e' r) ) Subst.fold s.rep r ~f:(fun ~key:e ~data:e' r -> and_eq_ ~wrt e e' r) )
|> extract_xs |> extract_xs
|> |>
[%Trace.retn fun {pf} (_, r') -> [%Trace.retn fun {pf} (_, r') ->
pf "%a" pp_diff (r, r') ; pf "%a" pp_diff (r, r') ;
invariant r'] invariant r']
let inter us r s = let inter wrt r s =
[%Trace.call fun {pf} -> pf "@[<hv 1> %a@ @<2> %a@]" pp r pp s] [%Trace.call fun {pf} -> pf "@[<hv 1> %a@ @<2> %a@]" pp r pp s]
; ;
( if not s.sat then r ( if not s.sat then r
@ -722,7 +722,7 @@ let inter us r s =
match match
List.find ~f:(fun rep -> implies r (Fml.eq exp rep)) reps List.find ~f:(fun rep -> implies r (Fml.eq exp rep)) reps
with with
| Some rep -> (reps, and_eq_ us exp rep rs) | Some rep -> (reps, and_eq_ ~wrt exp rep rs)
| None -> (exp :: reps, rs) ) | None -> (exp :: reps, rs) )
|> snd ) |> snd )
in in
@ -741,13 +741,13 @@ let interN us rs =
| [] -> (us, unsat) | [] -> (us, unsat)
| r :: rs -> List.fold ~f:(fun r (us, s) -> inter us s r) rs (us, r) | r :: rs -> List.fold ~f:(fun r (us, s) -> inter us s r) rs (us, r)
let rec add_ us b r = let rec add_ wrt b r =
match (b : Fml.t) with match (b : Fml.t) with
| Tt -> r | Tt -> r
| Not Tt -> unsat | Not Tt -> unsat
| And {pos; neg} -> Fml.fold_pos_neg ~f:(add_ us) ~pos ~neg r | And {pos; neg} -> Fml.fold_pos_neg ~f:(add_ wrt) ~pos ~neg r
| Eq (d, e) -> and_eq_ us d e r | Eq (d, e) -> and_eq_ ~wrt d e r
| Eq0 e -> and_eq_ us Trm.zero e r | Eq0 e -> and_eq_ ~wrt Trm.zero e r
| Pos _ | Not _ | Or _ | Iff _ | Cond _ | Lit _ -> r | Pos _ | Not _ | Or _ | Iff _ | Cond _ | Lit _ -> r
let add us b r = let add us b r =
@ -829,7 +829,7 @@ let solve_poly_eq us p' q' subst =
[%Trace.retn fun {pf} subst' -> [%Trace.retn fun {pf} subst' ->
pf "@[%a@]" Subst.pp_diff (subst, Option.value subst' ~default:subst)] pf "@[%a@]" Subst.pp_diff (subst, Option.value subst' ~default:subst)]
let solve_seq_eq us e' f' subst = let solve_seq_eq ~wrt us e' f' subst =
[%Trace.call fun {pf} -> pf "%a = %a" Trm.pp e' Trm.pp f'] [%Trace.call fun {pf} -> pf "%a = %a" Trm.pp e' Trm.pp f']
; ;
let f x u = let f x u =
@ -842,7 +842,7 @@ let solve_seq_eq us e' f' subst =
| Some n -> (a, n) | Some n -> (a, n)
| None -> (Trm.sized ~siz:n ~seq:a, n) | None -> (Trm.sized ~siz:n ~seq:a, n)
in in
let+ _, xs, s = solve_concat ~f ms a n (us, Var.Set.empty, subst) in let+ _, xs, s = solve_concat ~f ms a n (wrt, Var.Set.empty, subst) in
assert (Var.Set.disjoint xs (Subst.fv s)) ; assert (Var.Set.disjoint xs (Subst.fv s)) ;
s s
in in
@ -858,14 +858,14 @@ let solve_seq_eq us e' f' subst =
[%Trace.retn fun {pf} subst' -> [%Trace.retn fun {pf} subst' ->
pf "@[%a@]" Subst.pp_diff (subst, Option.value subst' ~default:subst)] pf "@[%a@]" Subst.pp_diff (subst, Option.value subst' ~default:subst)]
let solve_interp_eq us e' (cls, subst) = let solve_interp_eq ~wrt us e' (cls, subst) =
[%Trace.call fun {pf} -> [%Trace.call fun {pf} ->
pf "trm: @[%a@]@ cls: @[%a@]@ subst: @[%a@]" Trm.pp e' pp_cls cls pf "trm: @[%a@]@ cls: @[%a@]@ subst: @[%a@]" Trm.pp e' pp_cls cls
Subst.pp subst] Subst.pp subst]
; ;
List.find_map cls ~f:(fun f -> List.find_map cls ~f:(fun f ->
let f' = Subst.norm subst f in let f' = Subst.norm subst f in
match solve_seq_eq us e' f' subst with match solve_seq_eq ~wrt us e' f' subst with
| Some subst -> Some subst | Some subst -> Some subst
| None -> solve_poly_eq us e' f' subst ) | None -> solve_poly_eq us e' f' subst )
|> |>
@ -877,7 +877,7 @@ let solve_interp_eq us e' (cls, subst) =
and can be expressed, after normalizing with [subst], as [x u] where and can be expressed, after normalizing with [subst], as [x u] where
[us xs fv x us] and [fv u us] or else [us xs fv x us] and [fv u us] or else
[fv u us xs] *) [fv u us xs] *)
let rec solve_interp_eqs us (cls, subst) = let rec solve_interp_eqs ~wrt us (cls, subst) =
[%Trace.call fun {pf} -> [%Trace.call fun {pf} ->
pf "cls: @[%a@]@ subst: @[%a@]" pp_cls cls Subst.pp subst] pf "cls: @[%a@]@ subst: @[%a@]" pp_cls cls Subst.pp subst]
; ;
@ -887,13 +887,13 @@ let rec solve_interp_eqs us (cls, subst) =
| trm :: cls -> | trm :: cls ->
let trm' = Subst.norm subst trm in let trm' = Subst.norm subst trm in
if is_interpreted trm' then if is_interpreted trm' then
match solve_interp_eq us trm' (cls, subst) with match solve_interp_eq ~wrt us trm' (cls, subst) with
| Some subst -> solve_interp_eqs_ cls' (cls, subst) | Some subst -> solve_interp_eqs_ cls' (cls, subst)
| None -> solve_interp_eqs_ (trm' :: cls') (cls, subst) | None -> solve_interp_eqs_ (trm' :: cls') (cls, subst)
else solve_interp_eqs_ (trm' :: cls') (cls, subst) else solve_interp_eqs_ (trm' :: cls') (cls, subst)
in in
let cls', subst' = solve_interp_eqs_ [] (cls, subst) in let cls', subst' = solve_interp_eqs_ [] (cls, subst) in
( if subst' != subst then solve_interp_eqs us (cls', subst') ( if subst' != subst then solve_interp_eqs ~wrt us (cls', subst')
else (cls', subst') ) else (cls', subst') )
|> |>
[%Trace.retn fun {pf} (cls', subst') -> [%Trace.retn fun {pf} (cls', subst') ->
@ -982,7 +982,7 @@ let solve_uninterp_eqs us (cls, subst) =
[subst] which can be expressed, after normalizing with [subst], as [subst] which can be expressed, after normalizing with [subst], as
[x u] where [us xs fv x us] and [fv u us] or else [x u] where [us xs fv x us] and [fv u us] or else
[fv u us xs] *) [fv u us xs] *)
let solve_class us us_xs ~key:rep ~data:cls (classes, subst) = let solve_class ~wrt us us_xs ~key:rep ~data:cls (classes, subst) =
let classes0 = classes in let classes0 = classes in
[%Trace.call fun {pf} -> [%Trace.call fun {pf} ->
pf "rep: @[%a@]@ cls: @[%a@]@ subst: @[%a@]" Trm.pp rep pp_cls cls pf "rep: @[%a@]@ cls: @[%a@]@ subst: @[%a@]" Trm.pp rep pp_cls cls
@ -993,7 +993,7 @@ let solve_class us us_xs ~key:rep ~data:cls (classes, subst) =
~f:(fun e -> Var.Set.subset (Trm.fv e) ~of_:us_xs) ~f:(fun e -> Var.Set.subset (Trm.fv e) ~of_:us_xs)
(rep :: cls) (rep :: cls)
in in
let cls, subst = solve_interp_eqs us (cls, subst) in let cls, subst = solve_interp_eqs ~wrt us (cls, subst) in
let cls, subst = solve_uninterp_eqs us (cls, subst) in let cls, subst = solve_uninterp_eqs us (cls, subst) in
let cls = List.rev_append cls_not_ito_us_xs cls in let cls = List.rev_append cls_not_ito_us_xs cls in
let cls = List.remove ~eq:Trm.equal (Subst.norm subst rep) cls in let cls = List.remove ~eq:Trm.equal (Subst.norm subst rep) cls in
@ -1069,13 +1069,13 @@ let solve_for_xs r us xs =
(** move equations from [classes] to [subst] which can be expressed, after (** move equations from [classes] to [subst] which can be expressed, after
normalizing with [subst], as [x u] where [us xs fv x us] normalizing with [subst], as [x u] where [us xs fv x us]
and [fv u us] or else [fv u us xs]. *) and [fv u us] or else [fv u us xs]. *)
let solve_classes r xs (classes, subst, us) = let solve_classes ~wrt r xs (classes, subst, us) =
[%Trace.call fun {pf} -> [%Trace.call fun {pf} ->
pf "us: {@[%a@]}@ xs: {@[%a@]}" Var.Set.pp us Var.Set.pp xs] pf "us: {@[%a@]}@ xs: {@[%a@]}" Var.Set.pp us Var.Set.pp xs]
; ;
let rec solve_classes_ (classes0, subst0, us_xs) = let rec solve_classes_ (classes0, subst0, us_xs) =
let classes, subst = let classes, subst =
Trm.Map.fold ~f:(solve_class us us_xs) classes0 (classes0, subst0) Trm.Map.fold ~f:(solve_class ~wrt us us_xs) classes0 (classes0, subst0)
in in
if subst != subst0 then solve_classes_ (classes, subst, us_xs) if subst != subst0 then solve_classes_ (classes, subst, us_xs)
else (classes, subst, us_xs) else (classes, subst, us_xs)
@ -1103,10 +1103,12 @@ let solve_for_vars vss r =
pf "%a@ @[%a@]@ @[%a@]" pp_vss vss pp_classes r pp r ; pf "%a@ @[%a@]@ @[%a@]" pp_vss vss pp_classes r pp r ;
invariant r] invariant r]
; ;
let wrt = Var.Set.union_list vss in
let us, vss = let us, vss =
match vss with us :: vss -> (us, vss) | [] -> (Var.Set.empty, vss) match vss with us :: vss -> (us, vss) | [] -> (Var.Set.empty, vss)
in in
List.fold ~f:(solve_classes r) vss (classes r, Subst.empty, us) |> snd3 List.fold ~f:(solve_classes ~wrt r) vss (classes r, Subst.empty, us)
|> snd3
|> |>
[%Trace.retn fun {pf} subst -> [%Trace.retn fun {pf} subst ->
pf "%a" Subst.pp subst ; pf "%a" Subst.pp subst ;

Loading…
Cancel
Save