diff --git a/infer/src/bufferoverrun/itv.ml b/infer/src/bufferoverrun/itv.ml index b17bba3a5..5cea7d2d0 100644 --- a/infer/src/bufferoverrun/itv.ml +++ b/infer/src/bufferoverrun/itv.ml @@ -1005,6 +1005,13 @@ module MakePolynomial (S : NonNegativeSymbol) = struct let increasing_union ~f m1 m2 = union (fun _ v1 v2 -> Some (f v1 v2)) m1 m2 + let zip m1 m2 = merge (fun _ opt1 opt2 -> Some (opt1, opt2)) m1 m2 + + let fold_no_key m ~init ~f = + let f _k v acc = f acc v in + fold f m init + + let le ~le_elt m1 m2 = match merge @@ -1022,6 +1029,12 @@ module MakePolynomial (S : NonNegativeSymbol) = struct true | exception Exit -> false + + + let xcompare ~xcompare_elt ~lhs ~rhs = + (* TODO: avoid creating zipped map *) + zip lhs rhs + |> PartialOrder.container ~fold:fold_no_key ~xcompare_elt:(PartialOrder.of_opt ~xcompare_elt) end (** If x < y < z then @@ -1078,7 +1091,9 @@ module MakePolynomial (S : NonNegativeSymbol) = struct let is_one : t -> bool = fun {const; terms} -> NonNegativeInt.is_one const && M.is_empty terms - let is_symbolic : t -> bool = fun {terms} -> not (M.is_empty terms) + let is_constant : t -> bool = fun {terms} -> M.is_empty terms + + let is_symbolic : t -> bool = fun p -> not (is_constant p) let rec plus : t -> t -> t = fun p1 p2 -> @@ -1132,6 +1147,14 @@ module MakePolynomial (S : NonNegativeSymbol) = struct NonNegativeInt.( <= ) ~lhs:lhs.const ~rhs:rhs.const && M.le ~le_elt:( <= ) lhs.terms rhs.terms + let rec xcompare ~lhs ~rhs = + let cmp_const = + PartialOrder.of_compare ~compare:NonNegativeInt.compare ~lhs:lhs.const ~rhs:rhs.const + in + let cmp_terms = M.xcompare ~xcompare_elt:xcompare ~lhs:lhs.terms ~rhs:rhs.terms in + PartialOrder.join cmp_const cmp_terms + + (* Possible optimization for later: x join x^2 = x^2 instead of x + x^2 *) let rec join : t -> t -> t = fun p1 p2 -> @@ -1141,11 +1164,15 @@ module MakePolynomial (S : NonNegativeSymbol) = struct (* assumes symbols are not comparable *) (* TODO: improve this for comparable symbols *) - let min : t -> t -> t = + let min_default_left : t -> t -> t = fun p1 p2 -> - if ( <= ) ~lhs:p1 ~rhs:p2 then p1 - else (* either can't decide which one is smaller or p2 is smaller *) - p2 + match xcompare ~lhs:p1 ~rhs:p2 with + | `Equal | `LeftSmallerThanRight -> + p1 + | `RightSmallerThanLeft -> + p2 + | `NotComparable -> + if is_constant p1 then p1 else if is_constant p2 then p2 else p1 let widen : prev:t -> next:t -> num_iters:int -> t = @@ -1233,12 +1260,12 @@ module NonNegativePolynomial = struct let mult = top_lifted_increasing ~f:NonNegativeNonTopPolynomial.mult - let min p1 p2 = + let min_default_left p1 p2 = match (p1, p2) with | Top, x | x, Top -> x | NonTop p1, NonTop p2 -> - NonTop (NonNegativeNonTopPolynomial.min p1 p2) + NonTop (NonNegativeNonTopPolynomial.min_default_left p1 p2) let widen ~prev ~next ~num_iters:_ = if ( <= ) ~lhs:next ~rhs:prev then prev else Top diff --git a/infer/src/bufferoverrun/itv.mli b/infer/src/bufferoverrun/itv.mli index f6e8c2b82..4fa3256e6 100644 --- a/infer/src/bufferoverrun/itv.mli +++ b/infer/src/bufferoverrun/itv.mli @@ -72,7 +72,7 @@ module NonNegativePolynomial : sig val mult : astate -> astate -> astate - val min : astate -> astate -> astate + val min_default_left : astate -> astate -> astate val subst : astate -> Bound.t bottom_lifted SymbolMap.t -> astate end diff --git a/infer/src/checkers/cost.ml b/infer/src/checkers/cost.ml index b3f21744f..3cfe70af0 100644 --- a/infer/src/checkers/cost.ml +++ b/infer/src/checkers/cost.ml @@ -594,7 +594,7 @@ module MinTree = struct | Leaf (_, c) -> c | Min l -> - evaluate_operator BasicCost.min l + evaluate_operator BasicCost.min_default_left l | Plus l -> evaluate_operator BasicCost.plus l diff --git a/infer/src/istd/PartialOrder.ml b/infer/src/istd/PartialOrder.ml index a9f98c59a..3b8274c42 100644 --- a/infer/src/istd/PartialOrder.ml +++ b/infer/src/istd/PartialOrder.ml @@ -11,6 +11,30 @@ type total = [`LeftSmallerThanRight | `Equal | `RightSmallerThanLeft] type t = [total | `NotComparable] +let join t1 t2 = + match (t1, t2) with + | `Equal, `Equal -> + `Equal + | (`LeftSmallerThanRight | `Equal), (`LeftSmallerThanRight | `Equal) -> + `LeftSmallerThanRight + | (`RightSmallerThanLeft | `Equal), (`RightSmallerThanLeft | `Equal) -> + `RightSmallerThanLeft + | `LeftSmallerThanRight, `RightSmallerThanLeft + | `RightSmallerThanLeft, `LeftSmallerThanRight + | _, `NotComparable + | `NotComparable, _ -> + `NotComparable + + +type 'a xcompare = lhs:'a -> rhs:'a -> t + +type 'a xcompare_total = lhs:'a -> rhs:'a -> total + +let of_compare ~compare ~lhs ~rhs = + let r = compare lhs rhs in + if r < 0 then `LeftSmallerThanRight else if r > 0 then `RightSmallerThanLeft else `Equal + + let of_le ~le ~lhs ~rhs = let ller = le lhs rhs in let rlel = le rhs lhs in @@ -23,3 +47,23 @@ let of_le ~le ~lhs ~rhs = `RightSmallerThanLeft | false, false -> `NotComparable + + +let of_opt ~xcompare_elt ~lhs ~rhs = + match (lhs, rhs) with + | None, None -> + `Equal + | None, Some _ -> + `LeftSmallerThanRight + | Some _, None -> + `RightSmallerThanLeft + | Some lhs, Some rhs -> + xcompare_elt ~lhs ~rhs + + +let join_lazy t1 ~xcompare ~lhs ~rhs = + match t1 with `NotComparable -> `NotComparable | _ -> join t1 (xcompare ~lhs ~rhs) + + +let container ~(fold: ('t, 'a * 'a, t) Container.fold) cont ~xcompare_elt = + fold cont ~init:`Equal ~f:(fun acc (lhs, rhs) -> join_lazy acc ~xcompare:xcompare_elt ~lhs ~rhs) diff --git a/infer/src/istd/PartialOrder.mli b/infer/src/istd/PartialOrder.mli index bb27e24be..e12909522 100644 --- a/infer/src/istd/PartialOrder.mli +++ b/infer/src/istd/PartialOrder.mli @@ -11,4 +11,16 @@ type total = [`LeftSmallerThanRight | `Equal | `RightSmallerThanLeft] type t = [total | `NotComparable] -val of_le : le:('a -> 'a -> bool) -> lhs:'a -> rhs:'a -> t +val join : [< t] -> [< t] -> t + +type 'a xcompare = lhs:'a -> rhs:'a -> t + +type 'a xcompare_total = lhs:'a -> rhs:'a -> total + +val of_compare : compare:('a -> 'a -> int) -> 'a xcompare_total + +val of_le : le:('a -> 'a -> bool) -> 'a xcompare + +val of_opt : xcompare_elt:'a xcompare -> 'a option xcompare + +val container : fold:('t, 'a * 'a, t) Container.fold -> 't -> xcompare_elt:'a xcompare -> t