Inferbo: rewrote subst

Summary:
The motivation is in a following diff: ensuring symbols do not cross procedure boundaries.
- This diff rewrites `Bound.subst` to be based on the substituted bound rather than folding on the symbol map.
- This way we are sure all symbols are substituted and no symbols from another procedure remains in the result.
- All cases from the previous version should still be here, I think I added a few constant approximations of minmax substituted with minmax (that would be return Top).

Side-effects (good):
- `mult_const` has also more constant approximations for minmax,
- substitution should be faster

Reviewed By: jvillard

Differential Revision: D8369993

fbshipit-source-id: 6ed8be8
master
Mehdi Bouaziz 7 years ago committed by Facebook Github Bot
parent 4624ff48d1
commit 3c240fc880

@ -63,6 +63,8 @@ module Symbol = struct
let is_unsigned : t -> bool = fun x -> x.unsigned let is_unsigned : t -> bool = fun x -> x.unsigned
end end
exception Symbol_not_found of Symbol.t
module SymbolMap = struct module SymbolMap = struct
include PrettyPrintable.MakePPMap (Symbol) include PrettyPrintable.MakePPMap (Symbol)
@ -244,16 +246,10 @@ module SymLinear = struct
let is_empty : t -> bool = fun x -> M.is_empty x let is_empty : t -> bool = fun x -> M.is_empty x
let remove : Symbol.t -> t -> t = fun s x -> M.remove s x
let singleton_one : Symbol.t -> t = fun s -> M.singleton s NonZeroInt.one let singleton_one : Symbol.t -> t = fun s -> M.singleton s NonZeroInt.one
let singleton_minus_one : Symbol.t -> t = fun s -> M.singleton s NonZeroInt.minus_one let singleton_minus_one : Symbol.t -> t = fun s -> M.singleton s NonZeroInt.minus_one
let find : Symbol.t -> t -> NonZeroInt.t = fun s x -> M.find s x
let use_symbol : Symbol.t -> t -> bool = fun s x -> M.mem s x
let is_le_zero : t -> bool = let is_le_zero : t -> bool =
fun x -> M.for_all (fun s v -> Symbol.is_unsigned s && NonZeroInt.is_negative v) x fun x -> M.for_all (fun s v -> Symbol.is_unsigned s && NonZeroInt.is_negative v) x
@ -310,23 +306,6 @@ module SymLinear = struct
M.union plus_coeff x y M.union plus_coeff x y
(** [se1] * [c] + [se2] *)
let mult_const_plus : t -> NonZeroInt.t -> t -> t =
fun se1 c se2 ->
let f _ (coeff1: NonZeroInt.t option) (coeff2: NonZeroInt.t option) =
match (coeff1, coeff2) with
| None, None ->
None
| None, (Some _ as some_v) ->
some_v
| Some v, None ->
Some NonZeroInt.(v * c)
| Some v1, Some v2 ->
NonZeroInt.(v1 * c |> plus v2)
in
M.merge f se1 se2
let mult_const : NonZeroInt.t -> t -> t = fun n x -> M.map (NonZeroInt.( * ) n) x let mult_const : NonZeroInt.t -> t -> t = fun n x -> M.map (NonZeroInt.( * ) n) x
let exact_div_const_exn : t -> NonZeroInt.t -> t = let exact_div_const_exn : t -> NonZeroInt.t -> t =
@ -344,6 +323,11 @@ module SymLinear = struct
None None
let fold m ~init ~f =
let f s coeff acc = f acc s coeff in
M.fold f m init
let get_one_symbol_opt : t -> Symbol.t option = one_symbol_of_coeff NonZeroInt.one let get_one_symbol_opt : t -> Symbol.t option = one_symbol_of_coeff NonZeroInt.one
let get_mone_symbol_opt : t -> Symbol.t option = one_symbol_of_coeff NonZeroInt.minus_one let get_mone_symbol_opt : t -> Symbol.t option = one_symbol_of_coeff NonZeroInt.minus_one
@ -385,6 +369,12 @@ module SymLinear = struct
let int_ub x = if is_le_zero x then Some 0 else None let int_ub x = if is_le_zero x then Some 0 else None
end end
module BoundEnd = struct
type t = LowerBound | UpperBound
let neg = function LowerBound -> UpperBound | UpperBound -> LowerBound
end
module Bound = struct module Bound = struct
type sign = Plus | Minus [@@deriving compare] type sign = Plus | Minus [@@deriving compare]
@ -445,6 +435,8 @@ module Bound = struct
F.fprintf fmt "%a(%d, %a)" MinMax.pp m d Symbol.pp x F.fprintf fmt "%a(%d, %a)" MinMax.pp m d Symbol.pp x
let of_bound_end = function BoundEnd.LowerBound -> MInf | BoundEnd.UpperBound -> PInf
let of_int : int -> t = fun n -> Linear (n, SymLinear.empty) let of_int : int -> t = fun n -> Linear (n, SymLinear.empty)
let minus_one = of_int (-1) let minus_one = of_int (-1)
@ -462,14 +454,6 @@ module Bound = struct
true true
let eq_symbol : Symbol.t -> t -> bool =
fun s -> function
| Linear (0, se) -> (
match SymLinear.get_one_symbol_opt se with None -> false | Some s' -> Symbol.equal s s' )
| _ ->
false
let lift_symlinear : (SymLinear.t -> 'a option) -> t -> 'a option = let lift_symlinear : (SymLinear.t -> 'a option) -> t -> 'a option =
fun f -> function Linear (0, se) -> f se | _ -> None fun f -> function Linear (0, se) -> f se | _ -> None
@ -504,103 +488,6 @@ module Bound = struct
else MinMax (c, sign, m, d, s) else MinMax (c, sign, m, d, s)
let use_symbol : Symbol.t -> t -> bool =
fun s -> function
| PInf | MInf ->
false
| Linear (_, se) ->
SymLinear.use_symbol s se
| MinMax (_, _, _, _, s') ->
Symbol.equal s s'
type subst_pos_t = SubstLowerBound | SubstUpperBound
(* [subst1] substitutes [s] in [x0] to [y0].
- If the precise result is expressible by the domain, the
function returns it.
- If the precise result is not expressible, but a compromized
value, with regard to [subst_pos], is expressible by the domain,
the function returns the compromized value.
- Otherwise, it returns the default values, -oo for lower bound and
+oo for upper bound. *)
let subst1
: subst_pos:subst_pos_t -> t bottom_lifted -> Symbol.t -> t bottom_lifted -> t bottom_lifted =
let get_default = function SubstLowerBound -> MInf | SubstUpperBound -> PInf in
let subst1_linears c1 se1 s c2 se2 =
let coeff = SymLinear.find s se1 in
let c' = c1 + ((coeff :> int) * c2) in
let se1 = SymLinear.remove s se1 in
let se' = SymLinear.mult_const_plus se2 coeff se1 in
Linear (c', se')
in
let subst1_non_bottom ~subst_pos x s y =
match (x, y) with
| Linear (c1, se1), Linear (c2, se2) ->
subst1_linears c1 se1 s c2 se2
| Linear (c1, se1), MinMax (c2, sign, min_max, d2, s2) when SymLinear.is_one_symbol se1 ->
assert (Symbol.equal (SymLinear.get_one_symbol se1) s) ;
MinMax (c1 + c2, sign, min_max, d2, s2)
| Linear (c1, se1), MinMax (c2, sign, min_max, d2, s2) when SymLinear.is_mone_symbol se1 ->
assert (Symbol.equal (SymLinear.get_mone_symbol se1) s) ;
MinMax (c1 - c2, Sign.neg sign, min_max, d2, s2)
| Linear (c1, se1), MinMax (c2, bop, min_max, d2, _) ->
let coeff = SymLinear.find s se1 in
let compromisable =
match (subst_pos, min_max) with
| SubstLowerBound, Max | SubstUpperBound, Min ->
NonZeroInt.is_positive coeff
| SubstUpperBound, Max | SubstLowerBound, Min ->
NonZeroInt.is_negative coeff
in
if compromisable then subst1_linears c1 se1 s (Sign.eval_int bop c2 d2) SymLinear.empty
else get_default subst_pos
| MinMax (_, Plus, Min, _, _), MInf ->
MInf
| MinMax (_, Minus, Min, _, _), MInf ->
PInf
| MinMax (_, Plus, Max, _, _), PInf ->
PInf
| MinMax (_, Minus, Max, _, _), PInf ->
MInf
| MinMax (c, sign, Min, d, _), PInf | MinMax (c, sign, Max, d, _), MInf ->
of_int (Sign.eval_int sign c d)
| MinMax (c1, sign, min_max, d1, _), Linear (c2, se) when SymLinear.is_zero se ->
of_int (Sign.eval_int sign c1 (MinMax.eval_int min_max d1 c2))
| MinMax (c, sign, m, d, _), _ when is_one_symbol y ->
mk_MinMax (c, sign, m, d, get_one_symbol y)
| MinMax (c, sign, m, d, _), _ when is_mone_symbol y ->
mk_MinMax (c, Sign.neg sign, MinMax.neg m, -d, get_mone_symbol y)
| MinMax (c1, sign1, min_max, d1, _), MinMax (c2, Plus, min_max', d2, s')
when MinMax.equal min_max min_max' ->
let c = Sign.eval_int sign1 c1 c2 in
let d = MinMax.eval_int min_max (d1 - c2) d2 in
mk_MinMax (c, sign1, min_max, d, s')
| MinMax (c1, sign1, min_max, d1, _), MinMax (c2, Minus, min_max', d2, s')
when MinMax.equal (MinMax.neg min_max) min_max' ->
let c = Sign.eval_int sign1 c1 c2 in
let d = MinMax.eval_int min_max' (c2 - d1) d2 in
mk_MinMax (c, Sign.neg sign1, min_max', d, s')
| _ ->
get_default subst_pos
in
fun ~subst_pos x0 s y0 ->
match (x0, y0) with
| Bottom, _ ->
x0
| NonBottom x, _ when eq_symbol s x ->
y0
| NonBottom x, _ when not (use_symbol s x) ->
x0
| NonBottom _, Bottom ->
NonBottom (get_default subst_pos)
| NonBottom x, NonBottom y ->
NonBottom (subst1_non_bottom ~subst_pos x s y)
let int_ub_of_minmax = function let int_ub_of_minmax = function
| MinMax (c, Plus, Min, d, _) -> | MinMax (c, Plus, Min, d, _) ->
Some (c + d) Some (c + d)
@ -627,6 +514,13 @@ module Bound = struct
assert false assert false
let int_of_minmax = function
| BoundEnd.LowerBound ->
int_lb_of_minmax
| BoundEnd.UpperBound ->
int_ub_of_minmax
let int_lb = function let int_lb = function
| MInf -> | MInf ->
None None
@ -859,26 +753,6 @@ module Bound = struct
fun x -> match x with Linear (c, y) when SymLinear.is_zero y -> Some c | _ -> None fun x -> match x with Linear (c, y) when SymLinear.is_zero y -> Some c | _ -> None
(* substitution symbols in ``x'' with respect to ``map'' *)
let subst : subst_pos:subst_pos_t -> t -> t bottom_lifted SymbolMap.t -> t bottom_lifted =
fun ~subst_pos x map ->
let subst_helper s y x =
let y' =
match y with
| Bottom ->
Bottom
| NonBottom r ->
NonBottom (if Symbol.is_unsigned s then ub ~default:r zero r else r)
in
subst1 ~subst_pos x s y'
in
SymbolMap.fold subst_helper map (NonBottom x)
let subst_lb x map = subst ~subst_pos:SubstLowerBound x map
let subst_ub x map = subst ~subst_pos:SubstUpperBound x map
let plus_common : f:(t -> t -> t) -> t -> t -> t = let plus_common : f:(t -> t -> t) -> t -> t -> t =
fun ~f x y -> fun ~f x y ->
match (x, y) with match (x, y) with
@ -922,8 +796,10 @@ module Bound = struct
PInf ) PInf )
let mult_const : default:t -> NonZeroInt.t -> t -> t = let plus = function BoundEnd.LowerBound -> plus_l | BoundEnd.UpperBound -> plus_u
fun ~default n x ->
let mult_const : BoundEnd.t -> NonZeroInt.t -> t -> t =
fun bound_end n x ->
match x with match x with
| MInf -> | MInf ->
if NonZeroInt.is_positive n then MInf else PInf if NonZeroInt.is_positive n then MInf else PInf
@ -931,13 +807,19 @@ module Bound = struct
if NonZeroInt.is_positive n then PInf else MInf if NonZeroInt.is_positive n then PInf else MInf
| Linear (c, x') -> | Linear (c, x') ->
Linear (c * (n :> int), SymLinear.mult_const n x') Linear (c * (n :> int), SymLinear.mult_const n x')
| _ -> | MinMax _ ->
default let int_bound =
let bound_end' =
if NonZeroInt.is_positive n then bound_end else BoundEnd.neg bound_end
in
int_of_minmax bound_end' x
in
match int_bound with Some i -> of_int (i * (n :> int)) | None -> of_bound_end bound_end
let mult_const_l = mult_const ~default:MInf let mult_const_l = mult_const BoundEnd.LowerBound
let mult_const_u = mult_const ~default:PInf let mult_const_u = mult_const BoundEnd.UpperBound
let neg : t -> t = function let neg : t -> t = function
| MInf -> | MInf ->
@ -989,6 +871,119 @@ module Bound = struct
let is_not_infty : t -> bool = function MInf | PInf -> false | _ -> true let is_not_infty : t -> bool = function MInf | PInf -> false | _ -> true
let lift1 : (t -> t) -> t bottom_lifted -> t bottom_lifted =
fun f x -> match x with Bottom -> Bottom | NonBottom x -> NonBottom (f x)
let lift2 : (t -> t -> t) -> t bottom_lifted -> t bottom_lifted -> t bottom_lifted =
fun f x y ->
match (x, y) with
| Bottom, _ | _, Bottom ->
Bottom
| NonBottom x, NonBottom y ->
NonBottom (f x y)
(** Substitutes ALL symbols in [x] with respect to [map]. Throws [Symbol_not_found] if a symbol in [x] can't be found in [map]. Under/over-Approximate as good as possible according to [subst_pos]. *)
let subst_exn : subst_pos:BoundEnd.t -> t -> t bottom_lifted SymbolMap.t -> t bottom_lifted =
fun ~subst_pos x map ->
let get_exn s =
match SymbolMap.find s map with
| NonBottom x when Symbol.is_unsigned s ->
NonBottom (ub ~default:x zero x)
| x ->
x
in
let get_mult_const s coeff =
try
if NonZeroInt.is_one coeff then get_exn s
else if NonZeroInt.is_minus_one coeff then get_exn s |> lift1 neg
else
match SymbolMap.find s map with
| Bottom ->
Bottom
| NonBottom x ->
let x = mult_const subst_pos coeff x in
if Symbol.is_unsigned s then NonBottom (ub ~default:x zero x) else NonBottom x
with Caml.Not_found ->
(* For unsigned symbols, we can over/under-approximate with zero depending on [subst_pos] and the sign of the coefficient. *)
match (Symbol.is_unsigned s, subst_pos, NonZeroInt.is_positive coeff) with
| true, BoundEnd.LowerBound, true | true, BoundEnd.UpperBound, false ->
NonBottom zero
| _ ->
raise (Symbol_not_found s)
in
match x with
| MInf | PInf ->
NonBottom x
| Linear (c, se) ->
SymLinear.fold se ~init:(NonBottom (of_int c)) ~f:(fun acc s coeff ->
lift2 (plus subst_pos) acc (get_mult_const s coeff) )
| MinMax (c, sign, min_max, d, s) ->
match get_exn s with
| Bottom ->
Bottom
| exception Caml.Not_found -> (
match int_of_minmax subst_pos x with
| Some i ->
NonBottom (of_int i)
| None ->
raise (Symbol_not_found s) )
| NonBottom x' ->
let res =
match (sign, min_max, x') with
| Plus, Min, MInf | Minus, Max, PInf ->
MInf
| Plus, Max, PInf | Minus, Min, MInf ->
PInf
| sign, Min, PInf | sign, Max, MInf ->
of_int (Sign.eval_int sign c d)
| _, _, Linear (c2, se)
-> (
if SymLinear.is_zero se then
of_int (Sign.eval_int sign c (MinMax.eval_int min_max d c2))
else if SymLinear.is_one_symbol se then
mk_MinMax
(Sign.eval_int sign c c2, sign, min_max, d - c2, SymLinear.get_one_symbol se)
else if SymLinear.is_mone_symbol se then
mk_MinMax
( Sign.eval_int sign c c2
, Sign.neg sign
, MinMax.neg min_max
, c2 - d
, SymLinear.get_mone_symbol se )
else
match int_of_minmax subst_pos x with
| Some i ->
of_int i
| None ->
of_bound_end subst_pos )
| _, _, MinMax (c2, sign2, min_max2, d2, s2) ->
match (min_max, sign2, min_max2) with
| Min, Plus, Min | Max, Plus, Max ->
let c' = Sign.eval_int sign c c2 in
let d' = MinMax.eval_int min_max (d - c2) d2 in
mk_MinMax (c', sign, min_max, d', s2)
| Min, Minus, Max | Max, Minus, Min ->
let c' = Sign.eval_int sign c c2 in
let d' = MinMax.eval_int min_max2 (c2 - d) d2 in
mk_MinMax (c', Sign.neg sign, min_max2, d', s2)
| _ ->
let bound_end =
match sign with Plus -> subst_pos | Minus -> BoundEnd.neg subst_pos
in
of_int
(Sign.eval_int sign c
(MinMax.eval_int min_max d
(int_of_minmax bound_end x' |> Option.value ~default:d)))
in
NonBottom res
let subst_lb_exn x map = subst_exn ~subst_pos:BoundEnd.LowerBound x map
let subst_ub_exn x map = subst_exn ~subst_pos:BoundEnd.UpperBound x map
end end
type ('c, 's) valclass = Constant of 'c | Symbolic of 's | ValTop type ('c, 's) valclass = Constant of 'c | Symbolic of 's | ValTop
@ -1023,8 +1018,8 @@ module NonNegativeBound = struct
Constant (NonNegativeInt.of_int_exn c) Constant (NonNegativeInt.of_int_exn c)
let subst b map = let subst_exn b map =
match Bound.subst_ub b map with match Bound.subst_ub_exn b map with
| Bottom -> | Bottom ->
Constant NonNegativeInt.zero Constant NonNegativeInt.zero
| NonBottom b -> | NonBottom b ->
@ -1038,7 +1033,8 @@ module type NonNegativeSymbol = sig
val int_ub : t -> NonNegativeInt.t option val int_ub : t -> NonNegativeInt.t option
val subst : t -> Bound.t bottom_lifted SymbolMap.t -> (NonNegativeInt.t, t) valclass val subst_exn : t -> Bound.t bottom_lifted SymbolMap.t -> (NonNegativeInt.t, t) valclass
(** may throw Symbol_not_found *)
val pp : F.formatter -> t -> unit val pp : F.formatter -> t -> unit
end end
@ -1254,7 +1250,7 @@ module MakePolynomial (S : NonNegativeSymbol) = struct
let rec subst {const; terms} map = let rec subst {const; terms} map =
M.fold M.fold
(fun s p acc -> (fun s p acc ->
match S.subst s map with match S.subst_exn s map with
| Constant c -> ( | Constant c -> (
match PositiveInt.of_int (c :> int) with match PositiveInt.of_int (c :> int) with
| None -> | None ->
@ -1267,7 +1263,9 @@ module MakePolynomial (S : NonNegativeSymbol) = struct
if is_zero p then acc else raise ReturnTop if is_zero p then acc else raise ReturnTop
| Symbolic s -> | Symbolic s ->
let p = subst p map in let p = subst p map in
mult_symb p s |> plus acc ) mult_symb p s |> plus acc
| exception Symbol_not_found _ ->
raise ReturnTop )
terms (of_non_negative_int const) terms (of_non_negative_int const)
in in
fun p map -> match subst p map with p -> NonTop p | exception ReturnTop -> Top fun p map -> match subst p map with p -> NonTop p | exception ReturnTop -> Top
@ -1379,11 +1377,14 @@ module ItvPure = struct
let subst : t -> Bound.t bottom_lifted SymbolMap.t -> t bottom_lifted = let subst : t -> Bound.t bottom_lifted SymbolMap.t -> t bottom_lifted =
fun (l, u) map -> fun (l, u) map ->
match (Bound.subst_lb l map, Bound.subst_ub u map) with match (Bound.subst_lb_exn l map, Bound.subst_ub_exn u map) with
| NonBottom l, NonBottom u -> | NonBottom l, NonBottom u ->
NonBottom (l, u) NonBottom (l, u)
| _ -> | _ ->
Bottom Bottom
| exception Symbol_not_found _ ->
(* For now, let's be VERY aggressive. Under-approximate unknown symbols with Bottom. *)
Bottom
let ( <= ) : lhs:t -> rhs:t -> bool = let ( <= ) : lhs:t -> rhs:t -> bool =

Loading…
Cancel
Save