[pulse] quantifier elimination using var_eqs

Summary:
First stab at quantifier elimination done poorly but fast :)

Easy one: when we know "x = y", and we want to keep x but not y, then
replace y by x everywhere.

Reviewed By: skcho

Differential Revision: D25432207

fbshipit-source-id: 81b142b96
master
Jules Villard 4 years ago committed by Facebook GitHub Bot
parent d3a83beab4
commit b5bd85c967

@ -8,7 +8,7 @@
(modes byte) (modes byte)
(flags (flags
(:standard -w +60)) (:standard -w +60))
(libraries javalib ANSITerminal async atdgen base base64 cmdliner core (libraries javalib ANSITerminal async atdgen base base64 cmdliner core iter
mtime.clock.os ocamlgraph oUnit parmap re sawja sqlite3 str unix xmlm mtime.clock.os ocamlgraph oUnit parmap re sawja sqlite3 str unix xmlm
yojson zarith zip CStubs) yojson zarith zip CStubs)
(modules All_infer_in_one_file) (modules All_infer_in_one_file)

@ -9,15 +9,17 @@ open! IStd
module F = Format module F = Format
module type Element = sig module type Element = sig
type t [@@deriving compare] type t [@@deriving compare, equal]
val is_simpler_than : t -> t -> bool val is_simpler_than : t -> t -> bool
end end
module Make (X : Element) (XSet : Caml.Set.S with type elt = X.t) = struct module Make
module Map = Caml.Map.Make (X) (X : Element)
(XSet : Caml.Set.S with type elt = X.t)
let equal_x = [%compare.equal: X.t] (XMap : Caml.Map.S with type key = X.t) =
struct
module XSet = Iter.Set.Adapt (XSet)
(** the union-find backing data structure: maps elements to their representatives *) (** the union-find backing data structure: maps elements to their representatives *)
module UF : sig module UF : sig
@ -41,14 +43,14 @@ module Make (X : Element) (XSet : Caml.Set.S with type elt = X.t) = struct
end = struct end = struct
type repr = X.t type repr = X.t
type t = X.t Map.t type t = X.t XMap.t
let empty = Map.empty let empty = XMap.empty
let find_opt reprs x = let find_opt reprs x =
let rec find_opt_aux candidate_repr = let rec find_opt_aux candidate_repr =
(* [x] is in the relation and now we are climbing up to the final representative *) (* [x] is in the relation and now we are climbing up to the final representative *)
match Map.find_opt candidate_repr reprs with match XMap.find_opt candidate_repr reprs with
| None -> | None ->
(* [candidate_repr] is the representative *) (* [candidate_repr] is the representative *)
candidate_repr candidate_repr
@ -56,16 +58,16 @@ module Make (X : Element) (XSet : Caml.Set.S with type elt = X.t) = struct
(* keep climbing *) (* keep climbing *)
find_opt_aux candidate_repr' find_opt_aux candidate_repr'
in in
Map.find_opt x reprs |> Option.map ~f:find_opt_aux XMap.find_opt x reprs |> Option.map ~f:find_opt_aux
let find reprs x = find_opt reprs x |> Option.value ~default:x let find reprs x = find_opt reprs x |> Option.value ~default:x
let merge reprs x ~into:y = (* TODO: implement path compression *) Map.add x y reprs let merge reprs x ~into:y = (* TODO: implement path compression *) XMap.add x y reprs
let add_disjoint_class repr xs reprs = XSet.fold (fun x reprs -> Map.add x repr reprs) xs reprs let add_disjoint_class repr xs reprs = XSet.fold (fun x reprs -> XMap.add x repr reprs) xs reprs
module Map = Map module Map = XMap
end end
type repr = UF.repr type repr = UF.repr
@ -89,7 +91,7 @@ module Make (X : Element) (XSet : Caml.Set.S with type elt = X.t) = struct
let union uf x1 x2 = let union uf x1 x2 =
let repr1 = find uf x1 in let repr1 = find uf x1 in
let repr2 = find uf x2 in let repr2 = find uf x2 in
if equal_x (repr1 :> X.t) (repr2 :> X.t) then if X.equal (repr1 :> X.t) (repr2 :> X.t) then
(* avoid creating loops in the relation *) (* avoid creating loops in the relation *)
(uf, None) (uf, None)
else else
@ -124,6 +126,70 @@ module Make (X : Element) (XSet : Caml.Set.S with type elt = X.t) = struct
F.fprintf fmt "@[<hv>%a@]" pp_aux uf F.fprintf fmt "@[<hv>%a@]" pp_aux uf
let of_classes classes =
let reprs =
UF.Map.fold (fun repr xs reprs -> UF.add_disjoint_class repr xs reprs) classes UF.empty
in
{reprs; classes}
let apply_subst subst uf =
let in_subst x = XMap.mem x subst in
(* any variable that doesn't have a better representative according to the substitution should
be kept *)
let should_keep x = not (in_subst x) in
let classes_keep =
fold_congruences uf ~init:UF.Map.empty ~f:(fun classes_keep (repr, clazz) ->
let repr_in_range_opt =
if should_keep (repr :> X.t) then Some repr
else
(* [repr] is not a good representative for the class as the substitution prefers it
another element. Try to find a better representative. *)
XSet.to_seq clazz |> Iter.find_pred should_keep
|> (* HACK: trick [Repr] into casting [repr'] to a [repr], bypassing the private type
*)
Option.map ~f:(fun x_repr -> UF.find UF.empty x_repr)
in
match repr_in_range_opt with
| Some repr_in_range ->
let class_keep =
XSet.filter
(fun x -> (not (X.equal x (repr_in_range :> X.t))) && should_keep x)
clazz
in
if XSet.is_empty class_keep then classes_keep
else UF.Map.add repr_in_range class_keep classes_keep
| None ->
(* none of the elements in the class should be kept; note that this cannot happen if
[subst = reorient ~keep uf] *)
classes_keep )
in
of_classes classes_keep
let reorient ~keep uf =
let should_keep x = XSet.mem x keep in
fold_congruences uf ~init:XMap.empty ~f:(fun subst (repr, clazz) ->
(* map every variable in [repr::clazz] to either [repr] if [repr ∈ keep], or to the smallest
representative of [clazz] that's in [keep], if any *)
if should_keep (repr :> X.t) then
XSet.fold
(fun x subst -> if should_keep x then subst else XMap.add x (repr :> X.t) subst)
clazz subst
else
match XSet.to_seq clazz |> Iter.find_pred should_keep with
| None ->
(* no good representative: just substitute as in the original [uf] relation so that we
can get rid of non-representative variables *)
XSet.fold (fun x subst -> XMap.add x (repr :> X.t) subst) clazz subst
| Some repr' ->
let subst = XMap.add (repr :> X.t) repr' subst in
XSet.fold
(fun x subst ->
if X.equal x repr' || should_keep x then subst else XMap.add x repr' subst )
clazz subst )
let filter_not_in_closed_set ~keep uf = let filter_not_in_closed_set ~keep uf =
let classes = let classes =
UF.Map.filter UF.Map.filter
@ -136,8 +202,5 @@ module Make (X : Element) (XSet : Caml.Set.S with type elt = X.t) = struct
in in
(* rebuild [reprs] directly from [classes]: does path compression and garbage collection on the (* rebuild [reprs] directly from [classes]: does path compression and garbage collection on the
old [reprs] *) old [reprs] *)
let reprs = of_classes classes
UF.Map.fold (fun repr xs reprs -> UF.add_disjoint_class repr xs reprs) classes UF.empty
in
{reprs; classes}
end end

@ -11,13 +11,16 @@ module F = Format
(** A union-find data structure. *) (** A union-find data structure. *)
module type Element = sig module type Element = sig
type t [@@deriving compare] type t [@@deriving compare, equal]
val is_simpler_than : t -> t -> bool val is_simpler_than : t -> t -> bool
(** will be used to choose a "simpler" representative for a given equivalence class when possible *) (** will be used to choose a "simpler" representative for a given equivalence class when possible *)
end end
module Make (X : Element) (XSet : Caml.Set.S with type elt = X.t) : sig module Make
(X : Element)
(XSet : Caml.Set.S with type elt = X.t)
(XMap : Caml.Map.S with type key = X.t) : sig
type t type t
val pp : val pp :
@ -39,6 +42,16 @@ module Make (X : Element) (XSet : Caml.Set.S with type elt = X.t) : sig
(** fold over the equivalence classes of the relation, singling out the representative for each (** fold over the equivalence classes of the relation, singling out the representative for each
class *) class *)
val reorient : keep:XSet.t -> t -> X.t XMap.t
(** the relation [x -> x'] derived from the equality relation that relates all [x], [x'] such that
[xkeep], [x'keep], and [x=x'], as well as [y -> y'] when no element in the equivalence
class of [y] belongs to [keep] and [y'] is the representative of the class *)
val apply_subst : _ XMap.t -> t -> t
(** [apply_subst subst uf] eliminate all variables in the domain of [subst] from [uf], keeping the
smallest representative not in the domain of [subst] for each class. Classes without any such
elements are kept intact. *)
val filter_not_in_closed_set : keep:XSet.t -> t -> t val filter_not_in_closed_set : keep:XSet.t -> t -> t
(** only keep items in [keep], assuming that [keep] is closed under the relation, i.e. that if an (** only keep items in [keep], assuming that [keep] is closed under the relation, i.e. that if an
item [x] is in [keep] then so are all the [y] such that [x=y] according to the relation *) item [x] is in [keep] then so are all the [y] such that [x=y] according to the relation *)

@ -8,7 +8,7 @@
(public_name infer.IStdlib) (public_name infer.IStdlib)
(flags (flags
(:standard -open Core)) (:standard -open Core))
(libraries ANSITerminal core str yojson) (libraries ANSITerminal core iter str yojson)
(preprocess (preprocess
(pps ppx_compare))) (pps ppx_compare)))

@ -928,6 +928,8 @@ module Atom = struct
fold_map_terms a ~init ~f:(fun acc t -> Term.fold_subst_variables t ~init:acc ~f) fold_map_terms a ~init ~f:(fun acc t -> Term.fold_subst_variables t ~init:acc ~f)
let subst_variables l ~f = fold_subst_variables l ~init:() ~f:(fun () v -> ((), f v)) |> snd
let has_var_notin vars atom = let has_var_notin vars atom =
let t1, t2 = get_terms atom in let t1, t2 = get_terms atom in
Term.has_var_notin vars t1 || Term.has_var_notin vars t2 Term.has_var_notin vars t1 || Term.has_var_notin vars t2
@ -958,11 +960,12 @@ let sat_of_eval_result (eval_result : Atom.eval_result) =
module VarUF = module VarUF =
UnionFind.Make UnionFind.Make
(struct (struct
type t = Var.t [@@deriving compare] type t = Var.t [@@deriving compare, equal]
let is_simpler_than (v1 : Var.t) (v2 : Var.t) = (v1 :> int) < (v2 :> int) let is_simpler_than (v1 : Var.t) (v2 : Var.t) = (v1 :> int) < (v2 :> int)
end) end)
(Var.Set) (Var.Set)
(Var.Map)
type new_eq = EqZero of Var.t | Equal of Var.t * Var.t type new_eq = EqZero of Var.t | Equal of Var.t * Var.t
@ -1407,6 +1410,60 @@ let and_fold_subst_variables phi0 ~up_to_f:phi_foreign ~init ~f:f_var =
(acc, {known; pruned; both}, new_eqs) (acc, {known; pruned; both}, new_eqs)
module QuantifierElimination : sig
val eliminate_vars : keep:Var.Set.t -> t -> t SatUnsat.t
(** [eliminate_vars ~keep φ] substitutes every variable [x] in [φ] with [x'] whenever [x'] is a
distinguished representative of the equivalence class of [x] in [φ] such that [x' keep] *)
end = struct
exception Contradiction
let subst_f subst x = match Var.Map.find_opt x subst with Some y -> y | None -> x
let targetted_subst_var subst_var x = VarSubst (subst_f subst_var x)
let subst_var_linear_eqs subst linear_eqs =
Var.Map.fold
(fun x l new_map ->
let x' = subst_f subst x in
let l' = LinArith.subst_variables ~f:(targetted_subst_var subst) l in
match LinArith.solve_eq (LinArith.of_var x') l' with
| Unsat ->
L.d_printfln "Contradiction found: %a=%a became %a=%a with is Unsat" Var.pp x
(LinArith.pp Var.pp) l Var.pp x' (LinArith.pp Var.pp) l' ;
raise Contradiction
| Sat None ->
new_map
| Sat (Some (x'', l'')) ->
Var.Map.add x'' l'' new_map )
linear_eqs Var.Map.empty
let subst_var_atoms subst atoms =
Atom.Set.fold
(fun atom atoms ->
let atom' = Atom.subst_variables ~f:(targetted_subst_var subst) atom in
Atom.Set.add atom' atoms )
atoms Atom.Set.empty
let subst_var_formula subst {Formula.var_eqs; linear_eqs; atoms} =
{ Formula.var_eqs= VarUF.apply_subst subst var_eqs
; linear_eqs= subst_var_linear_eqs subst linear_eqs
; atoms= subst_var_atoms subst atoms }
let subst_var subst phi =
{ known= subst_var_formula subst phi.known
; pruned= subst_var_atoms subst phi.pruned
; both= subst_var_formula subst phi.both }
let eliminate_vars ~keep phi =
let subst = VarUF.reorient ~keep phi.both.var_eqs in
try Sat (subst_var subst phi) with Contradiction -> Unsat
end
module DeadVariables = struct
(** Intermediate step of [simplify]: build an (undirected) graph between variables where an edge (** Intermediate step of [simplify]: build an (undirected) graph between variables where an edge
between two variables means that they appear together in an atom, a linear equation, or an between two variables means that they appear together in an atom, a linear equation, or an
equivalence class. *) equivalence class. *)
@ -1485,15 +1542,12 @@ let get_reachable_from graph vs =
Caml.Hashtbl.to_seq_keys reachable |> Var.Set.of_seq Caml.Hashtbl.to_seq_keys reachable |> Var.Set.of_seq
let simplify ~keep phi = (** Get rid of atoms when they contain only variables that do not appear in atoms mentioning
let open SatUnsat.Import in variables in [keep], or variables appearing in atoms together with variables in [keep], and so
let+ phi, new_eqs = normalize phi in on. In other words, the variables to keep are all the ones transitively reachable from
L.d_printfln_escaped "Simplifying %a wrt {%a}" pp phi Var.Set.pp keep ; variables in [keep] in the graph connecting two variables whenever they appear together in a
(* Get rid of atoms when they contain only variables that do not appear in atoms mentioning same atom of the formula. *)
variables in [keep], or variables appearing in atoms together with variables in [keep], and let eliminate ~keep phi =
so on. In other words, the variables to keep are all the ones transitively reachable from
variables in [keep] in the graph connecting two variables whenever they appear together in
a same atom of the formula. *)
(* We only consider [phi.both] when building the relation. Considering [phi.known] and (* We only consider [phi.both] when building the relation. Considering [phi.known] and
[phi.pruned] as well could lead to us keeping more variables around, but that's not necessarily [phi.pruned] as well could lead to us keeping more variables around, but that's not necessarily
a good idea. Ignoring them means we err on the side of reporting potentially slightly more a good idea. Ignoring them means we err on the side of reporting potentially slightly more
@ -1517,7 +1571,18 @@ let simplify ~keep phi =
let known = simplify_phi phi.known in let known = simplify_phi phi.known in
let both = simplify_phi phi.both in let both = simplify_phi phi.both in
let pruned = Atom.Set.filter filter_atom phi.pruned in let pruned = Atom.Set.filter filter_atom phi.pruned in
({known; pruned; both}, new_eqs) {known; pruned; both}
end
let simplify ~keep phi =
let open SatUnsat.Import in
let* phi, new_eqs = normalize phi in
L.d_printfln_escaped "Simplifying %a wrt {%a}" pp phi Var.Set.pp keep ;
(* get rid of as many variables as possible *)
let+ phi = QuantifierElimination.eliminate_vars ~keep phi in
(* TODO: doing [QuantifierElimination.eliminate_vars; DeadVariables.eliminate] a few times may
eliminate even more variables *)
(DeadVariables.eliminate ~keep phi, new_eqs)
let is_known_zero phi v = let is_known_zero phi v =

@ -231,6 +231,13 @@ let%test_module "normalization" =
let%test_module "variable elimination" = let%test_module "variable elimination" =
( module struct ( module struct
let%expect_test _ =
simplify ~keep:[x_var; y_var] (x = y) ;
[%expect
{|
known=x=y && true (no linear) && true (no atoms), pruned=true (no atoms),
both=x=y && true (no linear) && true (no atoms)|}]
let%expect_test _ = let%expect_test _ =
simplify ~keep:[x_var] (x = i 0 && y = i 1 && z = i 2 && w = i 3) ; simplify ~keep:[x_var] (x = i 0 && y = i 1 && z = i 2 && w = i 3) ;
[%expect [%expect
@ -242,8 +249,8 @@ let%test_module "variable elimination" =
simplify ~keep:[x_var] (x = y + i 1 && x = i 0) ; simplify ~keep:[x_var] (x = y + i 1 && x = i 0) ;
[%expect [%expect
{| {|
known=x=v6 && x = 0 && true (no atoms), pruned=true (no atoms), known=true (no var=var) && x = 0 && true (no atoms), pruned=true (no atoms),
both=x=v6 && x = 0 && true (no atoms)|}] both=true (no var=var) && x = 0 && true (no atoms)|}]
let%expect_test _ = let%expect_test _ =
simplify ~keep:[y_var] (x = y + i 1 && x = i 0) ; simplify ~keep:[y_var] (x = y + i 1 && x = i 0) ;
@ -257,20 +264,20 @@ let%test_module "variable elimination" =
simplify ~keep:[y_var; z_var] (x = y + z && w = x - y && v = w + i 1 && v = i 0) ; simplify ~keep:[y_var; z_var] (x = y + z && w = x - y && v = w + i 1 && v = i 0) ;
[%expect [%expect
{| {|
known=x=v6 z=w=v7 && x = y -1 z = -1 && true (no atoms), pruned=true (no atoms), known=true (no var=var) && x = y -1 z = -1 && true (no atoms), pruned=true (no atoms),
both=x=v6 z=w=v7 && x = y -1 z = -1 && true (no atoms)|}] both=true (no var=var) && x = y -1 z = -1 && true (no atoms)|}]
let%expect_test _ = let%expect_test _ =
simplify ~keep:[x_var; y_var] (x = y + z && w + x + y = i 0 && v = w + i 1) ; simplify ~keep:[x_var; y_var] (x = y + z && w + x + y = i 0 && v = w + i 1) ;
[%expect [%expect
{| {|
known=x=v6 v=v9 known=true (no var=var)
&& &&
x = -v + v7 +1 y = -v7 z = -v + 2·v7 +1 w = v -1 x = -v + v7 +1 y = -v7 z = -v + 2·v7 +1 w = v -1
&& &&
true (no atoms), true (no atoms),
pruned=true (no atoms), pruned=true (no atoms),
both=x=v6 v=v9 both=true (no var=var)
&& &&
x = -v + v7 +1 y = -v7 z = -v + 2·v7 +1 w = v -1 x = -v + v7 +1 y = -v7 z = -v + 2·v7 +1 w = v -1
&& &&
@ -280,8 +287,8 @@ let%test_module "variable elimination" =
simplify ~keep:[x_var; y_var] (x = y + i 4 && x = w && y = z) ; simplify ~keep:[x_var; y_var] (x = y + i 4 && x = w && y = z) ;
[%expect [%expect
{| {|
known=x=w=v6 y=z && x = y +4 && true (no atoms), pruned=true (no atoms), known=true (no var=var) && x = y +4 && true (no atoms), pruned=true (no atoms),
both=x=w=v6 y=z && x = y +4 && true (no atoms)|}] both=true (no var=var) && x = y +4 && true (no atoms)|}]
end ) end )
let%test_module "non-linear simplifications" = let%test_module "non-linear simplifications" =

Loading…
Cancel
Save