[pulse] add a Linear variant to terms

Summary:
More scaffolding, nothing creates `Linear _` terms yet. Some changes to
variables substitution to allow substituting variables for linear terms
(as well as constants and other variables).

Reviewed By: skcho

Differential Revision: D23241461

fbshipit-source-id: fc870255e
master
Jules Villard 4 years ago committed by Facebook GitHub Bot
parent 45894a7dd9
commit 69995cebb6

@ -41,7 +41,9 @@ end
(** Linear Arithmetic*)
module LinArith : sig
(** linear combination of variables, eg [2·x + 3/4·y + 12] *)
type t
type t [@@deriving compare]
type subst_target = QSubst of Q.t | VarSubst of Var.t | LinSubst of t
val pp : (F.formatter -> Var.t -> unit) -> F.formatter -> t -> unit
@ -77,14 +79,16 @@ module LinArith : sig
val subst : Var.t -> Var.t -> t -> t
val subst_vars : f:(Var.t -> t) -> t -> t
val get_variables : t -> Var.t Seq.t
val fold_map_variables : t -> init:'a -> f:('a -> Var.t -> 'a * Var.t) -> 'a * t
val fold_map_variables : t -> init:'a -> f:('a -> Var.t -> 'a * subst_target) -> 'a * t
val map_variables : t -> f:(Var.t -> subst_target) -> t
end = struct
(** invariant: the representation is always "canonical": coefficients cannot be [Q.zero] *)
type t = Q.t Var.Map.t * Q.t
type t = Q.t Var.Map.t * Q.t [@@deriving compare]
type subst_target = QSubst of Q.t | VarSubst of Var.t | LinSubst of t
let pp pp_var fmt (vs, c) =
if Var.Map.is_empty vs then Q.pp_print fmt c
@ -181,30 +185,30 @@ end = struct
(vs', c)
let subst_vars ~f (vs, c) = Var.Map.fold (fun v q l -> mult q (f v) |> add l) vs (Var.Map.empty, c)
let of_subst_target = function QSubst q -> of_q q | VarSubst v -> of_var v | LinSubst l -> l
let fold_map_variables (vs_foreign, c) ~init ~f =
let acc_f, vs =
let acc_f, l =
Var.Map.fold
(fun v_foreign q0 (acc_f, vs) ->
let acc_f, v = f acc_f v_foreign in
let vs =
match Var.Map.find_opt v vs with
| None ->
Var.Map.add v q0 vs
| Some q ->
let q' = Q.add q q0 in
if Q.is_zero q' then Var.Map.remove v vs else Var.Map.add v q vs
in
(acc_f, vs) )
vs_foreign (init, Var.Map.empty)
(fun v_foreign q0 (acc_f, l) ->
let acc_f, op = f acc_f v_foreign in
(acc_f, add (mult q0 (of_subst_target op)) l) )
vs_foreign
(init, (Var.Map.empty, c))
in
(acc_f, (vs, c))
(acc_f, l)
let map_variables l ~f = fold_map_variables l ~init:() ~f:(fun () v -> ((), f v)) |> snd
let get_variables (vs, _) = Var.Map.to_seq vs |> Seq.map fst
end
type subst_target = LinArith.subst_target =
| QSubst of Q.t
| VarSubst of Var.t
| LinSubst of LinArith.t
(** Expressive term structure to be able to express all of SIL, but the main smarts of the formulas
are for the equality between variables and linear arithmetic subsets. Terms (and atoms, below)
are kept as a last-resort for when outside that fragment. *)
@ -212,6 +216,7 @@ module Term = struct
type t =
| Const of Q.t
| Var of Var.t
| Linear of LinArith.t
| Add of t * t
| Minus of t
| LessThan of t * t
@ -242,6 +247,8 @@ module Term = struct
(* negative and/or a fraction *) true
| Var _ ->
false
| Linear _ ->
false
| Minus _
| BitNot _
| Not _
@ -272,6 +279,8 @@ module Term = struct
pp_var fmt v
| Const c ->
Q.pp_print fmt c
| Linear l ->
F.fprintf fmt "[%a]" (LinArith.pp pp_var) l
| Minus t ->
F.fprintf fmt "-%a" (pp_paren pp_var ~needs_paren) t
| BitNot t ->
@ -314,13 +323,20 @@ module Term = struct
F.fprintf fmt "%a≠%a" (pp_paren pp_var ~needs_paren) t1 (pp_paren pp_var ~needs_paren) t2
let of_intlit i = Const (Q.of_bigint (IntLit.to_big_int i))
let of_q q = Const q
let of_operand = function
| AbstractValueOperand v ->
Var v
| LiteralOperand i ->
IntLit.to_big_int i |> Q.of_bigint |> of_q
let of_operand = function AbstractValueOperand v -> Var v | LiteralOperand i -> of_intlit i
let one = Const Q.one
let of_subst_target = function QSubst q -> of_q q | VarSubst v -> Var v | LinSubst l -> Linear l
let zero = Const Q.zero
let one = of_q Q.one
let zero = of_q Q.zero
let of_bool b = if b then one else zero
@ -375,7 +391,7 @@ module Term = struct
(** Fold [f] on the strict sub-terms of [t], if any. Preserve physical equality if [f] does. *)
let fold_map_direct_subterms t ~init ~f =
match t with
| Var _ | Const _ ->
| Var _ | Const _ | Linear _ ->
(init, t)
| Minus t_not | BitNot t_not | Not t_not ->
let acc, t_not' = f init t_not in
@ -391,6 +407,7 @@ module Term = struct
Not t_not'
| Var _
| Const _
| Linear _
| Add _
| Mult _
| Div _
@ -460,7 +477,7 @@ module Term = struct
Equal (t1', t2')
| NotEqual _ ->
NotEqual (t1', t2')
| Var _ | Const _ | Minus _ | BitNot _ | Not _ ->
| Var _ | Const _ | Linear _ | Minus _ | BitNot _ | Not _ ->
assert false
in
(acc, t')
@ -473,15 +490,18 @@ module Term = struct
let rec fold_map_variables t ~init ~f =
match t with
| Var v ->
let acc, t_v = f init v in
let t' = match t_v with Var v' when Var.equal v v' -> t | _ -> t_v in
let acc, op = f init v in
let t' = match op with VarSubst v' when Var.equal v v' -> t | _ -> of_subst_target op in
(acc, t')
| Linear l ->
let acc, l' = LinArith.fold_map_variables l ~init ~f in
(acc, Linear l')
| _ ->
fold_map_direct_subterms t ~init ~f:(fun acc t' -> fold_map_variables t' ~init:acc ~f)
let fold_variables t ~init ~f =
fold_map_variables t ~init ~f:(fun acc v -> (f acc v, Var v)) |> fst
fold_map_variables t ~init ~f:(fun acc v -> (f acc v, VarSubst v)) |> fst
let iter_variables t ~f = fold_variables t ~init:() ~f:(fun () v -> f v)
@ -513,6 +533,8 @@ module Term = struct
match t0 with
| Const _ | Var _ ->
t0
| Linear l ->
LinArith.get_as_const l |> Option.value_map ~default:t0 ~f:(fun c -> Const c)
| Minus t' ->
q_map t' Q.(mul minus_one)
| Add (t1, t2) ->
@ -646,6 +668,8 @@ module Term = struct
Some (LinArith.of_var v)
| Const c ->
Some (LinArith.of_q c)
| Linear l ->
Some l
| Minus t ->
let+ l = to_lin_arith t in
LinArith.minus l
@ -923,13 +947,13 @@ end = struct
(** substitute vars in [l] *once* with their linear form to discover more simplification
opportunities *)
let apply phi l =
LinArith.subst_vars l ~f:(fun v ->
LinArith.map_variables l ~f:(fun v ->
let repr = (get_repr phi v :> Var.t) in
match Var.Map.find_opt repr phi.linear_eqs with
| None ->
LinArith.of_var repr
VarSubst repr
| Some l' ->
l' )
LinSubst l' )
let rec solve_eq ~fuel t1 t2 phi =
@ -1028,7 +1052,7 @@ end = struct
let open Option.Monad_infix in
Var.Map.find_opt v_canon phi.linear_eqs >>= LinArith.get_as_const
in
match q_opt with None -> Var v_canon | Some q -> Const q )
match q_opt with None -> VarSubst v_canon | Some q -> QSubst q )
in
let atom' = Atom.map_terms atom ~f:(fun t -> normalize_term phi t) in
match Atom.eval atom' with
@ -1205,7 +1229,11 @@ let and_fold_map_variables phi0 ~up_to_f:phi_foreign ~init ~f =
IContainer.fold_of_pervasives_map_fold Var.Map.fold phi_foreign.linear_eqs ~init:acc
~f:(fun (acc_f, phi) (v_foreign, l_foreign) ->
let acc_f, v = f acc_f v_foreign in
let acc_f, l = LinArith.fold_map_variables l_foreign ~init:acc_f ~f in
let acc_f, l =
LinArith.fold_map_variables l_foreign ~init:acc_f ~f:(fun acc v ->
let acc', v' = f acc v in
(acc', VarSubst v') )
in
let phi = Normalizer.and_var_linarith v l phi |> sat_value_exn in
(acc_f, phi) )
in
@ -1215,7 +1243,7 @@ let and_fold_map_variables phi0 ~up_to_f:phi_foreign ~init ~f =
let acc_f, atom =
Atom.fold_map_variables atom_foreign ~init:acc_f ~f:(fun acc_f v ->
let acc_f, v' = f acc_f v in
(acc_f, Term.Var v') )
(acc_f, VarSubst v') )
in
let phi = and_atom atom phi in
(acc_f, phi) )

Loading…
Cancel
Save