Cost: Polynomial.min_default_left

Reviewed By: ddino

Differential Revision: D8348293

fbshipit-source-id: 1a351f1
master
Mehdi Bouaziz 7 years ago committed by Facebook Github Bot
parent e379132412
commit dc49cb6124

@ -1005,6 +1005,13 @@ module MakePolynomial (S : NonNegativeSymbol) = struct
let increasing_union ~f m1 m2 = union (fun _ v1 v2 -> Some (f v1 v2)) m1 m2 let increasing_union ~f m1 m2 = union (fun _ v1 v2 -> Some (f v1 v2)) m1 m2
let zip m1 m2 = merge (fun _ opt1 opt2 -> Some (opt1, opt2)) m1 m2
let fold_no_key m ~init ~f =
let f _k v acc = f acc v in
fold f m init
let le ~le_elt m1 m2 = let le ~le_elt m1 m2 =
match match
merge merge
@ -1022,6 +1029,12 @@ module MakePolynomial (S : NonNegativeSymbol) = struct
true true
| exception Exit -> | exception Exit ->
false false
let xcompare ~xcompare_elt ~lhs ~rhs =
(* TODO: avoid creating zipped map *)
zip lhs rhs
|> PartialOrder.container ~fold:fold_no_key ~xcompare_elt:(PartialOrder.of_opt ~xcompare_elt)
end end
(** If x < y < z then (** If x < y < z then
@ -1078,7 +1091,9 @@ module MakePolynomial (S : NonNegativeSymbol) = struct
let is_one : t -> bool = fun {const; terms} -> NonNegativeInt.is_one const && M.is_empty terms let is_one : t -> bool = fun {const; terms} -> NonNegativeInt.is_one const && M.is_empty terms
let is_symbolic : t -> bool = fun {terms} -> not (M.is_empty terms) let is_constant : t -> bool = fun {terms} -> M.is_empty terms
let is_symbolic : t -> bool = fun p -> not (is_constant p)
let rec plus : t -> t -> t = let rec plus : t -> t -> t =
fun p1 p2 -> fun p1 p2 ->
@ -1132,6 +1147,14 @@ module MakePolynomial (S : NonNegativeSymbol) = struct
NonNegativeInt.( <= ) ~lhs:lhs.const ~rhs:rhs.const && M.le ~le_elt:( <= ) lhs.terms rhs.terms NonNegativeInt.( <= ) ~lhs:lhs.const ~rhs:rhs.const && M.le ~le_elt:( <= ) lhs.terms rhs.terms
let rec xcompare ~lhs ~rhs =
let cmp_const =
PartialOrder.of_compare ~compare:NonNegativeInt.compare ~lhs:lhs.const ~rhs:rhs.const
in
let cmp_terms = M.xcompare ~xcompare_elt:xcompare ~lhs:lhs.terms ~rhs:rhs.terms in
PartialOrder.join cmp_const cmp_terms
(* Possible optimization for later: x join x^2 = x^2 instead of x + x^2 *) (* Possible optimization for later: x join x^2 = x^2 instead of x + x^2 *)
let rec join : t -> t -> t = let rec join : t -> t -> t =
fun p1 p2 -> fun p1 p2 ->
@ -1141,11 +1164,15 @@ module MakePolynomial (S : NonNegativeSymbol) = struct
(* assumes symbols are not comparable *) (* assumes symbols are not comparable *)
(* TODO: improve this for comparable symbols *) (* TODO: improve this for comparable symbols *)
let min : t -> t -> t = let min_default_left : t -> t -> t =
fun p1 p2 -> fun p1 p2 ->
if ( <= ) ~lhs:p1 ~rhs:p2 then p1 match xcompare ~lhs:p1 ~rhs:p2 with
else (* either can't decide which one is smaller or p2 is smaller *) | `Equal | `LeftSmallerThanRight ->
p1
| `RightSmallerThanLeft ->
p2 p2
| `NotComparable ->
if is_constant p1 then p1 else if is_constant p2 then p2 else p1
let widen : prev:t -> next:t -> num_iters:int -> t = let widen : prev:t -> next:t -> num_iters:int -> t =
@ -1233,12 +1260,12 @@ module NonNegativePolynomial = struct
let mult = top_lifted_increasing ~f:NonNegativeNonTopPolynomial.mult let mult = top_lifted_increasing ~f:NonNegativeNonTopPolynomial.mult
let min p1 p2 = let min_default_left p1 p2 =
match (p1, p2) with match (p1, p2) with
| Top, x | x, Top -> | Top, x | x, Top ->
x x
| NonTop p1, NonTop p2 -> | NonTop p1, NonTop p2 ->
NonTop (NonNegativeNonTopPolynomial.min p1 p2) NonTop (NonNegativeNonTopPolynomial.min_default_left p1 p2)
let widen ~prev ~next ~num_iters:_ = if ( <= ) ~lhs:next ~rhs:prev then prev else Top let widen ~prev ~next ~num_iters:_ = if ( <= ) ~lhs:next ~rhs:prev then prev else Top

@ -72,7 +72,7 @@ module NonNegativePolynomial : sig
val mult : astate -> astate -> astate val mult : astate -> astate -> astate
val min : astate -> astate -> astate val min_default_left : astate -> astate -> astate
val subst : astate -> Bound.t bottom_lifted SymbolMap.t -> astate val subst : astate -> Bound.t bottom_lifted SymbolMap.t -> astate
end end

@ -594,7 +594,7 @@ module MinTree = struct
| Leaf (_, c) -> | Leaf (_, c) ->
c c
| Min l -> | Min l ->
evaluate_operator BasicCost.min l evaluate_operator BasicCost.min_default_left l
| Plus l -> | Plus l ->
evaluate_operator BasicCost.plus l evaluate_operator BasicCost.plus l

@ -11,6 +11,30 @@ type total = [`LeftSmallerThanRight | `Equal | `RightSmallerThanLeft]
type t = [total | `NotComparable] type t = [total | `NotComparable]
let join t1 t2 =
match (t1, t2) with
| `Equal, `Equal ->
`Equal
| (`LeftSmallerThanRight | `Equal), (`LeftSmallerThanRight | `Equal) ->
`LeftSmallerThanRight
| (`RightSmallerThanLeft | `Equal), (`RightSmallerThanLeft | `Equal) ->
`RightSmallerThanLeft
| `LeftSmallerThanRight, `RightSmallerThanLeft
| `RightSmallerThanLeft, `LeftSmallerThanRight
| _, `NotComparable
| `NotComparable, _ ->
`NotComparable
type 'a xcompare = lhs:'a -> rhs:'a -> t
type 'a xcompare_total = lhs:'a -> rhs:'a -> total
let of_compare ~compare ~lhs ~rhs =
let r = compare lhs rhs in
if r < 0 then `LeftSmallerThanRight else if r > 0 then `RightSmallerThanLeft else `Equal
let of_le ~le ~lhs ~rhs = let of_le ~le ~lhs ~rhs =
let ller = le lhs rhs in let ller = le lhs rhs in
let rlel = le rhs lhs in let rlel = le rhs lhs in
@ -23,3 +47,23 @@ let of_le ~le ~lhs ~rhs =
`RightSmallerThanLeft `RightSmallerThanLeft
| false, false -> | false, false ->
`NotComparable `NotComparable
let of_opt ~xcompare_elt ~lhs ~rhs =
match (lhs, rhs) with
| None, None ->
`Equal
| None, Some _ ->
`LeftSmallerThanRight
| Some _, None ->
`RightSmallerThanLeft
| Some lhs, Some rhs ->
xcompare_elt ~lhs ~rhs
let join_lazy t1 ~xcompare ~lhs ~rhs =
match t1 with `NotComparable -> `NotComparable | _ -> join t1 (xcompare ~lhs ~rhs)
let container ~(fold: ('t, 'a * 'a, t) Container.fold) cont ~xcompare_elt =
fold cont ~init:`Equal ~f:(fun acc (lhs, rhs) -> join_lazy acc ~xcompare:xcompare_elt ~lhs ~rhs)

@ -11,4 +11,16 @@ type total = [`LeftSmallerThanRight | `Equal | `RightSmallerThanLeft]
type t = [total | `NotComparable] type t = [total | `NotComparable]
val of_le : le:('a -> 'a -> bool) -> lhs:'a -> rhs:'a -> t val join : [< t] -> [< t] -> t
type 'a xcompare = lhs:'a -> rhs:'a -> t
type 'a xcompare_total = lhs:'a -> rhs:'a -> total
val of_compare : compare:('a -> 'a -> int) -> 'a xcompare_total
val of_le : le:('a -> 'a -> bool) -> 'a xcompare
val of_opt : xcompare_elt:'a xcompare -> 'a option xcompare
val container : fold:('t, 'a * 'a, t) Container.fold -> 't -> xcompare_elt:'a xcompare -> t

Loading…
Cancel
Save