[sledge] Change type of fold functions for improved composition

Summary:
Change the type of `fold` functions to enable them to compose
better. The guiding reasoning behind using types such as:
```
val fold : 'a t -> 's -> f:('a -> 's -> 's) -> 's
```
is:

1. The function argument should be labeled. This is so that it can be
   reordered relative to the others, since it is often a multi-line
   `fun` expression.

2. The function argument should come last. This enables its
   arguments (which are often polymorphic) to benefit from type-based
   disambiguation information determined by the types of the other
   arguments at the call sites.

3. The function argument's type should produce an
   accumulator-transformer when partially-applied. That is,
   `f x : 's -> 's`. This composes well with other functions designed
   to produce transformers/endofunctions when partially applied, and
   in particular improves the common case of composing folds into
   "state-passing style" code.

4. The fold function itself should produce an accumulator-transformer
   when partially applied. So `'a t -> 's -> f:_ -> 's` rather than
   `'s -> 'a t -> f:_ -> 's` or  `'a t -> init:'s -> f:_ -> 's` etc.

Reviewed By: jvillard

Differential Revision: D24306063

fbshipit-source-id: 13bd8bbee
master
Josh Berdine 5 years ago committed by Facebook GitHub Bot
parent ec4cb61db3
commit 920c553902

@ -31,7 +31,7 @@ let bindings (itv : t) =
let sexp_of_t (itv : t) = let sexp_of_t (itv : t) =
let sexps = let sexps =
Array.fold (bindings itv) ~init:[] ~f:(fun acc (v, {inf; sup}) -> Array.fold_right (bindings itv) [] ~f:(fun (v, {inf; sup}) acc ->
Sexp.List Sexp.List
[ Sexp.Atom (Var.to_string v) [ Sexp.Atom (Var.to_string v)
; Sexp.Atom (Scalar.to_string inf) ; Sexp.Atom (Scalar.to_string inf)
@ -166,40 +166,39 @@ let exec_assume q e =
| _ -> Some q | _ -> Some q
(** existentially quantify killed register [r] out of state [q] *) (** existentially quantify killed register [r] out of state [q] *)
let exec_kill q r = let exec_kill r q =
let apron_v = apron_var_of_reg r in let apron_v = apron_var_of_reg r in
if Environment.mem_var (Abstract1.env q) apron_v then if Environment.mem_var (Abstract1.env q) apron_v then
Abstract1.forget_array (Lazy.force man) q [|apron_v|] false Abstract1.forget_array (Lazy.force man) q [|apron_v|] false
else q else q
(** perform a series [move_vec] of reg:=exp moves at state [q] *) (** perform a series [move_vec] of reg:=exp moves at state [q] *)
let exec_move q move_vec = let exec_move move_vec q =
let defs, uses = let defs, uses =
IArray.fold move_vec ~init:(Llair.Reg.Set.empty, Llair.Reg.Set.empty) IArray.fold move_vec (Llair.Reg.Set.empty, Llair.Reg.Set.empty)
~f:(fun (defs, uses) (r, e) -> ~f:(fun (r, e) (defs, uses) ->
( Llair.Reg.Set.add r defs ( Llair.Reg.Set.add r defs
, Llair.Exp.fold_regs e ~init:uses ~f:(Fun.flip Llair.Reg.Set.add) , Llair.Exp.fold_regs ~f:Llair.Reg.Set.add e uses ) )
) )
in in
assert (Llair.Reg.Set.disjoint defs uses) ; assert (Llair.Reg.Set.disjoint defs uses) ;
IArray.fold move_vec ~init:q ~f:(fun a (r, e) -> assign r e a) IArray.fold ~f:(fun (r, e) q -> assign r e q) move_vec q
let exec_inst q i = let exec_inst i q =
match (i : Llair.inst) with match (i : Llair.inst) with
| Move {reg_exps; loc= _} -> Some (exec_move q reg_exps) | Move {reg_exps; loc= _} -> Some (exec_move reg_exps q)
| Store {ptr; exp; len= _; loc= _} -> ( | Store {ptr; exp; len= _; loc= _} -> (
match Llair.Reg.of_exp ptr with match Llair.Reg.of_exp ptr with
| Some reg -> Some (assign reg exp q) | Some reg -> Some (assign reg exp q)
| None -> Some q ) | None -> Some q )
| Load {reg; ptr; len= _; loc= _} -> Some (assign reg ptr q) | Load {reg; ptr; len= _; loc= _} -> Some (assign reg ptr q)
| Nondet {reg= Some reg; msg= _; loc= _} -> Some (exec_kill q reg) | Nondet {reg= Some reg; msg= _; loc= _} -> Some (exec_kill reg q)
| Nondet {reg= None; msg= _; loc= _} | Nondet {reg= None; msg= _; loc= _}
|Alloc _ | Memset _ | Memcpy _ | Memmov _ | Free _ -> |Alloc _ | Memset _ | Memcpy _ | Memmov _ | Free _ ->
Some q Some q
| Abort _ -> None | Abort _ -> None
(** Treat any intrinsic function as havoc on the return register [aret] *) (** Treat any intrinsic function as havoc on the return register [aret] *)
let exec_intrinsic ~skip_throw:_ pre aret i _ = let exec_intrinsic ~skip_throw:_ aret i _ pre =
let name = Llair.Reg.name i in let name = Llair.Reg.name i in
if if
List.exists List.exists
@ -224,7 +223,9 @@ let exec_intrinsic ~skip_throw:_ pre aret i _ =
; "__cxa_allocate_exception" ; "__cxa_allocate_exception"
; "_ZN5folly13usingJEMallocEv" ] ; "_ZN5folly13usingJEMallocEv" ]
~f:(String.equal name) ~f:(String.equal name)
then Option.map ~f:(Option.return << exec_kill pre) aret then
let+ aret = aret in
Some (exec_kill aret pre)
else None else None
type from_call = {areturn: Llair.Reg.t option; caller_q: t} type from_call = {areturn: Llair.Reg.t option; caller_q: t}
@ -235,7 +236,7 @@ let recursion_beyond_bound = `prune
(** existentially quantify locals *) (** existentially quantify locals *)
let post locals _ (q : t) = let post locals _ (q : t) =
let locals = let locals =
Llair.Reg.Set.fold locals ~init:[] ~f:(fun r a -> Llair.Reg.Set.fold locals [] ~f:(fun r a ->
let v = apron_var_of_reg r in let v = apron_var_of_reg r in
if Environment.mem_var q.env v then v :: a else a ) if Environment.mem_var q.env v then v :: a else a )
|> Array.of_list |> Array.of_list
@ -264,7 +265,7 @@ let retn _ freturn {areturn; caller_q} callee_q =
Abstract1.rename_array man result Abstract1.rename_array man result
[|apron_var_of_reg fret|] [|apron_var_of_reg fret|]
[|apron_var_of_reg aret|] [|apron_var_of_reg aret|]
| Some aret, None -> exec_kill caller_q aret | Some aret, None -> exec_kill aret caller_q
| None, _ -> caller_q | None, _ -> caller_q
(** map actuals to formals (via temporary registers), stash constraints on (** map actuals to formals (via temporary registers), stash constraints on
@ -280,11 +281,9 @@ let call ~summaries ~globals:_ ~actuals ~areturn ~formals ~freturn:_
Llair.Reg.program (Llair.Reg.typ r) ("__tmp__" ^ Llair.Reg.name r) Llair.Reg.program (Llair.Reg.typ r) ("__tmp__" ^ Llair.Reg.name r)
in in
let args = List.combine_exn formals actuals in let args = List.combine_exn formals actuals in
let q' = let q' = List.fold ~f:(fun (f, a) q -> assign (mangle f) a q) args q in
List.fold args ~init:q ~f:(fun q (f, a) -> assign (mangle f) a q)
in
let callee_env = let callee_env =
List.fold formals ~init:([], []) ~f:(fun (is, fs) f -> List.fold formals ([], []) ~f:(fun f (is, fs) ->
match apron_typ_of_llair_typ (Llair.Reg.typ f) with match apron_typ_of_llair_typ (Llair.Reg.typ f) with
| None -> (is, fs) | None -> (is, fs)
| Some Texpr1.Int -> (apron_var_of_reg (mangle f) :: is, fs) | Some Texpr1.Int -> (apron_var_of_reg (mangle f) :: is, fs)

@ -865,10 +865,10 @@ let xlate_jump :
match xlate_jump_ reg_exps (Llvm.instr_begin dst) with match xlate_jump_ reg_exps (Llvm.instr_begin dst) with
| [] -> ([], jmp, blocks) | [] -> ([], jmp, blocks)
| rev_reg_pre_exps -> | rev_reg_pre_exps ->
let rev_pre, rev_reg_exps = let rev_reg_exps, rev_pre =
List.fold_map rev_reg_pre_exps ~init:[] List.fold_map rev_reg_pre_exps []
~f:(fun rev_pre (reg, (pre, exp)) -> ~f:(fun (reg, (pre, exp)) rev_pre ->
(List.rev_append pre rev_pre, (reg, exp)) ) ((reg, exp), List.rev_append pre rev_pre) )
in in
let mov = let mov =
Inst.move ~reg_exps:(IArray.of_list_rev rev_reg_exps) ~loc Inst.move ~reg_exps:(IArray.of_list_rev rev_reg_exps) ~loc

@ -95,8 +95,8 @@ let used_globals pgm preanalyze : Domain_used_globals.r =
(Llair.Reg.Map.map summary_table ~f:Llair.Reg.Set.union_list) (Llair.Reg.Map.map summary_table ~f:Llair.Reg.Set.union_list)
else else
Declared Declared
(IArray.fold pgm.globals ~init:Llair.Reg.Set.empty ~f:(fun acc g -> (IArray.fold pgm.globals Llair.Reg.Set.empty ~f:(fun g ->
Llair.Reg.Set.add g.reg acc )) Llair.Reg.Set.add g.reg ))
let analyze = let analyze =
let%map_open bound = let%map_open bound =

@ -45,9 +45,9 @@ let assert_term term =
let top = top () in let top = top () in
top.asserts <- term :: top.asserts top.asserts <- term :: top.asserts
let rec x_let init nes = let rec x_let env nes =
List.fold nes ~init ~f:(fun n (name, term) -> List.fold nes env ~f:(fun (name, term) ->
VarEnv.add_exn ~key:name ~data:(x_trm init term) n ) VarEnv.add_exn ~key:name ~data:(x_trm env term) )
and x_trm : var_env -> Smt.Ast.term -> Term.t = and x_trm : var_env -> Smt.Ast.term -> Term.t =
fun n term -> fun n term ->
@ -60,13 +60,13 @@ and x_trm : var_env -> Smt.Ast.term -> Term.t =
try Term.rational (Q.of_float (Float.of_string_exn s)) try Term.rational (Q.of_float (Float.of_string_exn s))
with _ -> fail "not a rational: %a" Smt.Ast.pp_term term () ) ) ) with _ -> fail "not a rational: %a" Smt.Ast.pp_term term () ) ) )
| Arith (Add, e :: es) -> | Arith (Add, e :: es) ->
List.fold ~f:(fun s e -> Term.add s (x_trm n e)) ~init:(x_trm n e) es List.fold ~f:(fun e -> Term.add (x_trm n e)) es (x_trm n e)
| Arith (Minus, e :: es) -> | Arith (Minus, e :: es) ->
List.fold ~f:(fun s e -> Term.sub s (x_trm n e)) ~init:(x_trm n e) es List.fold ~f:(fun e -> Term.sub (x_trm n e)) es (x_trm n e)
| Arith (Mult, es) -> ( | Arith (Mult, es) -> (
match List.map ~f:(x_trm n) es with match List.map ~f:(x_trm n) es with
| e :: es -> | e :: es ->
List.fold es ~init:e ~f:(fun p e -> List.fold es e ~f:(fun e p ->
match Term.get_const e with match Term.get_const e with
| Some q -> Term.mulq q p | Some q -> Term.mulq q p
| None -> ( | None -> (
@ -77,7 +77,7 @@ and x_trm : var_env -> Smt.Ast.term -> Term.t =
| Arith (Div, es) -> ( | Arith (Div, es) -> (
match List.map ~f:(x_trm n) es with match List.map ~f:(x_trm n) es with
| e :: es -> | e :: es ->
List.fold es ~init:e ~f:(fun p e -> List.fold es e ~f:(fun e p ->
match Term.get_const e with match Term.get_const e with
| Some q -> Term.mulq (Q.inv q) p | Some q -> Term.mulq (Q.inv q) p
| None -> fail "nonlinear: %a" Smt.Ast.pp_term term () ) | None -> fail "nonlinear: %a" Smt.Ast.pp_term term () )

@ -34,7 +34,7 @@ let build_info =
|> List.sort ~cmp:[%compare: string * string] |> List.sort ~cmp:[%compare: string * string]
in in
let max_length = let max_length =
List.fold_left libs ~init:0 ~f:(fun n (name, _) -> List.fold_left libs 0 ~f:(fun n (name, _) ->
max n (String.length name) ) max n (String.length name) )
in in
String.concat ~sep:"\n" String.concat ~sep:"\n"

@ -137,6 +137,7 @@ module Either : sig
val right : 'a -> ('b, 'a) t val right : 'a -> ('b, 'a) t
end end
module Pair = Containers.Pair
module List = List module List = List
module Array = Array module Array = Array
module IArray = IArray module IArray = IArray

@ -127,6 +127,7 @@ module Either = struct
let right v = Right v let right v = Right v
end end
module Pair = Containers.Pair
module FHeap = Fheap [@@warning "-49"] module FHeap = Fheap [@@warning "-49"]
module HashQueue = Core_kernel.Hash_queue module HashQueue = Core_kernel.Hash_queue

@ -74,29 +74,31 @@ let iter xs ~f = iter ~f xs
let iteri xs ~f = iteri ~f xs let iteri xs ~f = iteri ~f xs
let exists xs ~f = exists ~f xs let exists xs ~f = exists ~f xs
let for_all xs ~f = for_all ~f xs let for_all xs ~f = for_all ~f xs
let fold xs ~init ~f = fold ~f ~init xs let fold xs init ~f = fold ~f:(fun s x -> f x s) ~init xs
let fold_right xs ~init ~f = fold_right ~f ~init xs let fold_right xs init ~f = fold_right ~f ~init xs
let fold_map xs ~init ~f = fold_map ~f ~init xs
let fold_map_until xs ~init ~f ~finish = let fold_map xs init ~f =
Pair.swap (fold_map ~f:(fun s x -> Pair.swap (f x s)) ~init xs)
let fold_map_until xs s ~f ~finish =
let l = length xs in let l = length xs in
if l = 0 then finish (init, [||]) if l = 0 then finish ([||], s)
else else
match f init xs.(0) with match f xs.(0) s with
| `Stop r -> r | `Stop r -> r
| `Continue (s, y) -> | `Continue (y, s) ->
let ys = make l y in let ys = make l y in
let rec fold_map_until_ s i = let rec fold_map_until_ s i =
if i = l then finish (s, ys) if i = l then finish (ys, s)
else else
match f s xs.(i) with match f xs.(i) s with
| `Stop r -> r | `Stop r -> r
| `Continue (s, y) -> | `Continue (y, s) ->
ys.(i) <- y ; ys.(i) <- y ;
fold_map_until_ s (i + 1) fold_map_until_ s (i + 1)
in in
fold_map_until_ s 1 fold_map_until_ s 1
let for_all2_exn xs ys ~f = for_all2 ~f xs ys let for_all2_exn xs ys ~f = for_all2 ~f xs ys
let to_list_rev_map xs ~f = fold ~f:(fun ys x -> f x :: ys) ~init:[] xs let to_list_rev_map xs ~f = fold ~f:(fun x ys -> f x :: ys) xs []
let pp sep pp_elt fs a = List.pp sep pp_elt fs (to_list a) let pp sep pp_elt fs a = List.pp sep pp_elt fs (to_list a)

@ -29,15 +29,15 @@ val iteri : 'a t -> f:(int -> 'a -> unit) -> unit
val exists : 'a t -> f:('a -> bool) -> bool val exists : 'a t -> f:('a -> bool) -> bool
val for_all : 'a t -> f:('a -> bool) -> bool val for_all : 'a t -> f:('a -> bool) -> bool
val for_all2_exn : 'a t -> 'b t -> f:('a -> 'b -> bool) -> bool val for_all2_exn : 'a t -> 'b t -> f:('a -> 'b -> bool) -> bool
val fold : 'a array -> init:'s -> f:('s -> 'a -> 's) -> 's val fold : 'a t -> 's -> f:('a -> 's -> 's) -> 's
val fold_right : 'a t -> init:'s -> f:('a -> 's -> 's) -> 's val fold_right : 'a t -> 's -> f:('a -> 's -> 's) -> 's
val fold_map : 'a t -> init:'s -> f:('s -> 'a -> 's * 'b) -> 's * 'b t val fold_map : 'a t -> 's -> f:('a -> 's -> 'b * 's) -> 'b t * 's
val fold_map_until : val fold_map_until :
'a t 'a t
-> init:'s -> 's
-> f:('s -> 'a -> [`Continue of 's * 'b | `Stop of 'c]) -> f:('a -> 's -> [`Continue of 'b * 's | `Stop of 'c])
-> finish:('s * 'b t -> 'c) -> finish:('b t * 's -> 'c)
-> 'c -> 'c
val to_list_rev_map : 'a array -> f:('a -> 'b) -> 'b list val to_list_rev_map : 'a array -> f:('a -> 'b) -> 'b list

@ -37,7 +37,5 @@ module Make (Key : HashedType) = struct
Option.get_exn !found Option.get_exn !found
let iteri tbl ~f = iter (fun key data -> f ~key ~data) tbl let iteri tbl ~f = iter (fun key data -> f ~key ~data) tbl
let fold tbl s ~f = fold (fun key data acc -> f ~key ~data acc) tbl s
let fold tbl ~init ~f =
fold (fun key data acc -> f ~key ~data acc) tbl init
end end

@ -19,5 +19,5 @@ module type S = sig
val find : 'a t -> key -> 'a option val find : 'a t -> key -> 'a option
val find_or_add : 'a t -> key -> default:(unit -> 'a) -> 'a val find_or_add : 'a t -> key -> default:(unit -> 'a) -> 'a
val iteri : 'a t -> f:(key:key -> data:'a -> unit) -> unit val iteri : 'a t -> f:(key:key -> data:'a -> unit) -> unit
val fold : 'a t -> init:'s -> f:(key:key -> data:'a -> 's -> 's) -> 's val fold : 'a t -> 's -> f:(key:key -> data:'a -> 's -> 's) -> 's
end end

@ -57,15 +57,13 @@ val iteri : 'a t -> f:(int -> 'a -> unit) -> unit
val exists : 'a t -> f:('a -> bool) -> bool val exists : 'a t -> f:('a -> bool) -> bool
val for_all : 'a t -> f:('a -> bool) -> bool val for_all : 'a t -> f:('a -> bool) -> bool
val for_all2_exn : 'a t -> 'b t -> f:('a -> 'b -> bool) -> bool val for_all2_exn : 'a t -> 'b t -> f:('a -> 'b -> bool) -> bool
val fold : 'a t -> init:'s -> f:('s -> 'a -> 's) -> 's val fold : 'a t -> 's -> f:('a -> 's -> 's) -> 's
val fold_right : 'a t -> init:'s -> f:('a -> 's -> 's) -> 's val fold_right : 'a t -> 's -> f:('a -> 's -> 's) -> 's
val fold_map : 'a t -> 's -> f:('a -> 's -> 'b * 's) -> 'b t * 's
val fold_map :
'a t -> init:'accum -> f:('accum -> 'a -> 'accum * 'b) -> 'accum * 'b t
val fold_map_until : val fold_map_until :
'a t 'a t
-> init:'s -> 's
-> f:('s -> 'a -> [`Continue of 's * 'b | `Stop of 'c]) -> f:('a -> 's -> [`Continue of 'b * 's | `Stop of 'c])
-> finish:('s * 'b t -> 'c) -> finish:('b t * 's -> 'c)
-> 'c -> 'c

@ -37,6 +37,7 @@ let pop seq =
let find_map seq ~f = find_map ~f seq let find_map seq ~f = find_map ~f seq
let find seq ~f = find (CCOpt.if_ f) seq let find seq ~f = find (CCOpt.if_ f) seq
let find_exn seq ~f = CCOpt.get_exn (find ~f seq) let find_exn seq ~f = CCOpt.get_exn (find ~f seq)
let fold seq init ~f = fold ~f:(fun s x -> f x s) ~init seq
let contains_dup (type elt) seq ~cmp = let contains_dup (type elt) seq ~cmp =
let module S = CCSet.Make (struct let module S = CCSet.Make (struct
@ -46,41 +47,41 @@ let contains_dup (type elt) seq ~cmp =
end) in end) in
let exception Found_dup in let exception Found_dup in
try try
fold ~init:S.empty seq ~f:(fun elts x -> fold seq S.empty ~f:(fun x elts ->
let elts' = S.add x elts in let elts' = S.add x elts in
if elts' == elts then raise_notrace Found_dup else elts' ) if elts' == elts then raise_notrace Found_dup else elts' )
|> ignore ; |> ignore ;
false false
with Found_dup -> true with Found_dup -> true
let fold_opt seq ~init ~f = let fold_opt seq s ~f =
let state = ref init in let state = ref s in
let exception Stop in let exception Stop in
try try
seq (fun x -> seq (fun x ->
match f !state x with match f x !state with
| Some s -> state := s | Some s -> state := s
| None -> raise_notrace Stop ) ; | None -> raise_notrace Stop ) ;
Some !state Some !state
with Stop -> None with Stop -> None
let fold_until (type res) seq ~init ~f ~finish = let fold_until (type res) seq s ~f ~finish =
let state = ref init in let state = ref s in
let exception Stop of res in let exception Stop of res in
try try
seq (fun x -> seq (fun x ->
match f !state x with match f x !state with
| `Continue s -> state := s | `Continue s -> state := s
| `Stop r -> raise_notrace (Stop r) ) ; | `Stop r -> raise_notrace (Stop r) ) ;
finish !state finish !state
with Stop r -> r with Stop r -> r
let fold_result (type s e) seq ~init ~f = let fold_result (type s e) seq s ~f =
let state = ref init in let state = ref s in
let exception Stop of (s, e) result in let exception Stop of (s, e) result in
try try
seq (fun x -> seq (fun x ->
match f !state x with match f x !state with
| Ok s -> state := s | Ok s -> state := s
| Error _ as e -> raise_notrace (Stop e) ) ; | Error _ as e -> raise_notrace (Stop e) ) ;
Ok !state Ok !state

@ -66,18 +66,19 @@ val find : 'a t -> f:('a -> bool) -> 'a option
val find_exn : 'a t -> f:('a -> bool) -> 'a val find_exn : 'a t -> f:('a -> bool) -> 'a
val find_map : 'a iter -> f:('a -> 'b option) -> 'b option val find_map : 'a iter -> f:('a -> 'b option) -> 'b option
val contains_dup : 'a iter -> cmp:('a -> 'a -> int) -> bool val contains_dup : 'a iter -> cmp:('a -> 'a -> int) -> bool
val fold : 'a t -> 's -> f:('a -> 's -> 's) -> 's
val fold_opt : 'a t -> init:'s -> f:('s -> 'a -> 's option) -> 's option val fold_opt : 'a t -> 's -> f:('a -> 's -> 's option) -> 's option
(** [fold_option t ~init ~f] is a short-circuiting version of [fold] that (** [fold_opt t s ~f] is a short-circuiting version of [fold] that runs in
runs in the [Option] monad. If [f] returns [None], [None] is returned the [Option] monad. If [f] returns [None], [None] is returned without
without any additional invocations of [f]. *) any additional invocations of [f]. *)
val fold_until : val fold_until :
'a t 'a t
-> init:'s -> 's
-> f:('s -> 'a -> [`Continue of 's | `Stop of 'b]) -> f:('a -> 's -> [`Continue of 's | `Stop of 'b])
-> finish:('s -> 'b) -> finish:('s -> 'b)
-> 'b -> 'b
val fold_result : val fold_result :
'a t -> init:'s -> f:('s -> 'a -> ('s, 'e) result) -> ('s, 'e) result 'a t -> 's -> f:('a -> 's -> ('s, 'e) result) -> ('s, 'e) result

@ -51,12 +51,19 @@ let group_by seq ~hash ~eq = group_by ~hash ~eq seq
let join_by ~eq ~hash k1 k2 ~merge = join_by ~eq ~hash k1 k2 ~merge let join_by ~eq ~hash k1 k2 ~merge = join_by ~eq ~hash k1 k2 ~merge
let join_all_by ~eq ~hash k1 k2 ~merge = join_all_by ~eq ~hash k1 k2 ~merge let join_all_by ~eq ~hash k1 k2 ~merge = join_all_by ~eq ~hash k1 k2 ~merge
let group_join_by ~eq ~hash = group_join_by ~eq ~hash let group_join_by ~eq ~hash = group_join_by ~eq ~hash
let fold xs ~init ~f = fold_left ~f ~init xs let fold xs init ~f = fold_left ~f:(fun s x -> f x s) ~init xs
let fold_left xs init ~f = fold_left ~f ~init xs
let fold_right xs init ~f = fold_right ~f ~init xs
let reduce xs ~f = let reduce xs ~f =
match xs with [] -> None | x :: xs -> Some (fold xs ~init:x ~f) match xs with [] -> None | x :: xs -> Some (fold ~f xs x)
let fold_map xs init ~f =
Pair.swap (fold_map ~f:(fun s x -> Pair.swap (f x s)) ~init xs)
let fold2_exn xs ys init ~f =
fold_left2 ~f:(fun s x y -> f x y s) ~init xs ys
let fold2_exn xs ys ~init ~f = fold_left2 ~f ~init xs ys
let group_succ ~eq xs = group_succ ~eq:(fun y x -> eq x y) xs let group_succ ~eq xs = group_succ ~eq:(fun y x -> eq x y) xs
let symmetric_diff ~cmp xs ys = let symmetric_diff ~cmp xs ys =

@ -86,9 +86,12 @@ val group_join_by :
-> 'b t -> 'b t
-> ('a * 'b list) t -> ('a * 'b list) t
val fold : 'a list -> init:'s -> f:('s -> 'a -> 's) -> 's val fold : 'a list -> 's -> f:('a -> 's -> 's) -> 's
val fold_left : 'a list -> 's -> f:('s -> 'a -> 's) -> 's
val fold_right : 'a list -> 's -> f:('a -> 's -> 's) -> 's
val reduce : 'a t -> f:('a -> 'a -> 'a) -> 'a option val reduce : 'a t -> f:('a -> 'a -> 'a) -> 'a option
val fold2_exn : 'a t -> 'b t -> init:'s -> f:('s -> 'a -> 'b -> 's) -> 's val fold_map : 'a t -> 's -> f:('a -> 's -> 'b * 's) -> 'b t * 's
val fold2_exn : 'a t -> 'b t -> 's -> f:('a -> 'b -> 's -> 's) -> 's
val symmetric_diff : val symmetric_diff :
cmp:('a -> 'a -> int) -> 'a t -> 'a t -> ('a, 'a) Either.t t cmp:('a -> 'a -> int) -> 'a t -> 'a t -> ('a, 'a) Either.t t

@ -158,7 +158,7 @@ end) : S with type key = Key.t = struct
let iteri m ~f = M.iter (fun key data -> f ~key ~data) m let iteri m ~f = M.iter (fun key data -> f ~key ~data) m
let existsi m ~f = M.exists (fun key data -> f ~key ~data) m let existsi m ~f = M.exists (fun key data -> f ~key ~data) m
let for_alli m ~f = M.for_all (fun key data -> f ~key ~data) m let for_alli m ~f = M.for_all (fun key data -> f ~key ~data) m
let fold m ~init ~f = M.fold (fun key data acc -> f ~key ~data acc) m init let fold m s ~f = M.fold (fun key data acc -> f ~key ~data acc) m s
let keys = M.keys let keys = M.keys
let values = M.values let values = M.values
let to_iter = M.to_iter let to_iter = M.to_iter

@ -105,7 +105,7 @@ module type S = sig
val iteri : 'a t -> f:(key:key -> data:'a -> unit) -> unit val iteri : 'a t -> f:(key:key -> data:'a -> unit) -> unit
val existsi : 'a t -> f:(key:key -> data:'a -> bool) -> bool val existsi : 'a t -> f:(key:key -> data:'a -> bool) -> bool
val for_alli : 'a t -> f:(key:key -> data:'a -> bool) -> bool val for_alli : 'a t -> f:(key:key -> data:'a -> bool) -> bool
val fold : 'a t -> init:'s -> f:(key:key -> data:'a -> 's -> 's) -> 's val fold : 'a t -> 's -> f:(key:key -> data:'a -> 's -> 's) -> 's
(** {1 Convert} *) (** {1 Convert} *)

@ -26,9 +26,8 @@ struct
let hash_fold_t hash_fold_elt s m = let hash_fold_t hash_fold_elt s m =
let hash_fold_mul s i = Hash.fold_int s (Mul.hash i) in let hash_fold_mul s i = Hash.fold_int s (Mul.hash i) in
M.fold m let init = Hash.fold_int s (M.length m) in
~init:(Hash.fold_int s (M.length m)) M.fold m init ~f:(fun ~key ~data state ->
~f:(fun ~key ~data state ->
hash_fold_mul (hash_fold_elt state key) data ) hash_fold_mul (hash_fold_elt state key) data )
let sexp_of_t s = let sexp_of_t s =
@ -39,10 +38,10 @@ struct
let t_of_sexp elt_of_sexp sexp = let t_of_sexp elt_of_sexp sexp =
List.fold_left List.fold_left
~f:(fun m (key, data) -> M.add_exn ~key ~data m) ~f:(fun m (key, data) -> M.add_exn ~key ~data m)
~init:M.empty
(List.t_of_sexp (List.t_of_sexp
(Sexplib.Conv.pair_of_sexp elt_of_sexp Mul.t_of_sexp) (Sexplib.Conv.pair_of_sexp elt_of_sexp Mul.t_of_sexp)
sexp) sexp)
M.empty
let pp sep pp_elt fs s = let pp sep pp_elt fs s =
List.pp sep pp_elt fs (Iter.to_list (M.to_iter s)) List.pp sep pp_elt fs (Iter.to_list (M.to_iter s))
@ -71,7 +70,7 @@ struct
let map m ~f = let map m ~f =
let m' = empty in let m' = empty in
let m, m' = let m, m' =
M.fold m ~init:(m, m') ~f:(fun ~key:x ~data:i (m, m') -> M.fold m (m, m') ~f:(fun ~key:x ~data:i (m, m') ->
let x', i' = f x i in let x', i' = f x i in
if x' == x then if x' == x then
if Mul.equal i' i then (m, m') else (M.add ~key:x ~data:i' m, m') if Mul.equal i' i then (m, m') else (M.add ~key:x ~data:i' m, m')
@ -85,7 +84,7 @@ struct
let flat_map m ~f = let flat_map m ~f =
let m' = empty in let m' = empty in
let m, m' = let m, m' =
M.fold m ~init:(m, m') ~f:(fun ~key:x ~data:i (m, m') -> M.fold m (m, m') ~f:(fun ~key:x ~data:i (m, m') ->
let d = f x i in let d = f x i in
match M.only_binding d with match M.only_binding d with
| Some (x', i') -> | Some (x', i') ->
@ -112,5 +111,5 @@ struct
let iter m ~f = M.iteri ~f:(fun ~key ~data -> f key data) m let iter m ~f = M.iteri ~f:(fun ~key ~data -> f key data) m
let exists m ~f = M.existsi ~f:(fun ~key ~data -> f key data) m let exists m ~f = M.existsi ~f:(fun ~key ~data -> f key data) m
let for_all m ~f = M.for_alli ~f:(fun ~key ~data -> f key data) m let for_all m ~f = M.for_alli ~f:(fun ~key ~data -> f key data) m
let fold m ~init ~f = M.fold ~f:(fun ~key ~data -> f key data) m ~init let fold m s ~f = M.fold ~f:(fun ~key ~data -> f key data) m s
end end

@ -115,6 +115,6 @@ module type S = sig
val for_all : t -> f:(elt -> mul -> bool) -> bool val for_all : t -> f:(elt -> mul -> bool) -> bool
(** Test whether all elements satisfy a predicate. *) (** Test whether all elements satisfy a predicate. *)
val fold : t -> init:'s -> f:(elt -> mul -> 's -> 's) -> 's val fold : t -> 's -> f:(elt -> mul -> 's -> 's) -> 's
(** Fold over the elements in ascending order. *) (** Fold over the elements in ascending order. *)
end end

@ -20,4 +20,4 @@ let bind xo ~f = bind xo f
let iter xo ~f = iter f xo let iter xo ~f = iter f xo
let exists xo ~f = exists f xo let exists xo ~f = exists f xo
let for_all xo ~f = for_all f xo let for_all xo ~f = for_all f xo
let fold xo ~init ~f = fold f init xo let fold xo s ~f = fold (fun x s -> f s x) s xo

@ -17,4 +17,4 @@ val bind : 'a t -> f:('a -> 'b t) -> 'b t
val iter : 'a t -> f:('a -> unit) -> unit val iter : 'a t -> f:('a -> unit) -> unit
val exists : 'a t -> f:('a -> bool) -> bool val exists : 'a t -> f:('a -> bool) -> bool
val for_all : 'a t -> f:('a -> bool) -> bool val for_all : 'a t -> f:('a -> bool) -> bool
val fold : 'a t -> init:'s -> f:('s -> 'a -> 's) -> 's val fold : 'a t -> 's -> f:('a -> 's -> 's) -> 's

@ -31,13 +31,13 @@ end) : S with type elt = Elt.t = struct
let of_option xo = Option.map_or ~f:S.singleton xo ~default:empty let of_option xo = Option.map_or ~f:S.singleton xo ~default:empty
let of_list = S.of_list let of_list = S.of_list
let add x s = S.add x s let add x s = S.add x s
let add_option xo s = Option.fold ~f:(Fun.flip add) ~init:s xo let add_option = Option.fold ~f:add
let add_list xs s = S.add_list s xs let add_list xs s = S.add_list s xs
let diff = S.diff let diff = S.diff
let inter = S.inter let inter = S.inter
let union = S.union let union = S.union
let diff_inter s t = (diff s t, inter s t) let diff_inter s t = (diff s t, inter s t)
let union_list ss = List.fold ~f:union ~init:empty ss let union_list ss = List.fold ~f:union ss empty
let is_empty = S.is_empty let is_empty = S.is_empty
let cardinal = S.cardinal let cardinal = S.cardinal
let mem s x = S.mem x s let mem s x = S.mem x s
@ -78,7 +78,7 @@ end) : S with type elt = Elt.t = struct
let iter s ~f = S.iter f s let iter s ~f = S.iter f s
let exists s ~f = S.exists f s let exists s ~f = S.exists f s
let for_all s ~f = S.for_all f s let for_all s ~f = S.for_all f s
let fold s ~init ~f = S.fold f s init let fold s z ~f = S.fold f s z
let to_iter = S.to_iter let to_iter = S.to_iter
let pp ?pre ?suf ?(sep = (",@ " : (unit, unit) fmt)) pp_elt fs x = let pp ?pre ?suf ?(sep = (",@ " : (unit, unit) fmt)) pp_elt fs x =

@ -55,7 +55,7 @@ module type S = sig
val iter : t -> f:(elt -> unit) -> unit val iter : t -> f:(elt -> unit) -> unit
val exists : t -> f:(elt -> bool) -> bool val exists : t -> f:(elt -> bool) -> bool
val for_all : t -> f:(elt -> bool) -> bool val for_all : t -> f:(elt -> bool) -> bool
val fold : t -> init:'s -> f:(elt -> 's -> 's) -> 's val fold : t -> 's -> f:(elt -> 's -> 's) -> 's
(** {1 Convert} *) (** {1 Convert} *)

@ -41,12 +41,12 @@ let times_of_raw {Report.etime; utime; stime; cutime; cstime} =
let etime = etime in let etime = etime in
{etime; utime; stime} {etime; utime; stime}
let add_time base_times row ptimes = let add_time base_times ptimes row =
let tustimes = times_of_raw ptimes in let tustimes = times_of_raw ptimes in
let times = tustimes :: row.times in let times = tustimes :: row.times in
let times_deltas = let times_deltas =
Option.fold base_times ~init:row.times_deltas Option.fold base_times row.times_deltas
~f:(fun times_deltas {etime= btt; utime= but; stime= bst} -> ~f:(fun {etime= btt; utime= but; stime= bst} times_deltas ->
let {etime= tt; utime= ut; stime= st} = tustimes in let {etime= tt; utime= ut; stime= st} = tustimes in
{etime= tt -. btt; utime= ut -. but; stime= st -. bst} {etime= tt -. btt; utime= ut -. but; stime= st -. bst}
:: times_deltas ) :: times_deltas )
@ -56,12 +56,12 @@ let add_time base_times row ptimes =
let add_times base_times times row = let add_times base_times times row =
if List.is_empty times then if List.is_empty times then
{row with times_deltas= Option.to_list base_times} {row with times_deltas= Option.to_list base_times}
else List.fold ~f:(add_time base_times) ~init:row times else List.fold ~f:(add_time base_times) times row
let add_gc base_gcs row gc = let add_gc base_gcs gc row =
let gcs = gc :: row.gcs in let gcs = gc :: row.gcs in
let gcs_deltas = let gcs_deltas =
Option.fold base_gcs ~init:row.gcs_deltas ~f:(fun gcs_deltas bgc -> Option.fold base_gcs row.gcs_deltas ~f:(fun bgc gcs_deltas ->
Report. Report.
{ allocated= gc.allocated -. bgc.allocated { allocated= gc.allocated -. bgc.allocated
; promoted= gc.promoted -. bgc.promoted ; promoted= gc.promoted -. bgc.promoted
@ -72,9 +72,9 @@ let add_gc base_gcs row gc =
let add_gcs base_gcs gcs row = let add_gcs base_gcs gcs row =
if List.is_empty gcs then {row with gcs_deltas= Option.to_list base_gcs} if List.is_empty gcs then {row with gcs_deltas= Option.to_list base_gcs}
else List.fold ~f:(add_gc base_gcs) ~init:row gcs else List.fold ~f:(add_gc base_gcs) gcs row
let add_status base_status row status = let add_status base_status status row =
if List.mem ~eq:Report.equal_status status row.status then row if List.mem ~eq:Report.equal_status status row.status then row
else else
match base_status with match base_status with
@ -88,13 +88,13 @@ let add_status base_status row status =
| _ -> {row with status= status :: row.status} | _ -> {row with status= status :: row.status}
let add_statuses base_status statuses row = let add_statuses base_status statuses row =
List.fold ~f:(add_status base_status) ~init:row statuses List.fold ~f:(add_status base_status) statuses row
let ave_floats flts = let ave_floats flts =
assert (not (Iter.is_empty flts)) ; assert (not (Iter.is_empty flts)) ;
let min, max, sum, num = let min, max, sum, num =
Iter.fold flts ~init:(Float.infinity, Float.neg_infinity, 0., 0) Iter.fold flts (Float.infinity, Float.neg_infinity, 0., 0)
~f:(fun (min, max, sum, num) flt -> ~f:(fun flt (min, max, sum, num) ->
(Float.min min flt, Float.max max flt, sum +. flt, num + 1) ) (Float.min min flt, Float.max max flt, sum +. flt, num + 1) )
in in
if num >= 5 then (sum -. min -. max) /. Float.of_int (num - 2) if num >= 5 then (sum -. min -. max) /. Float.of_int (num - 2)
@ -108,15 +108,12 @@ let combine name b_result c_result =
if List.is_empty times then None if List.is_empty times then None
else else
let etimes, utimes, stimes, cutimes, cstimes = let etimes, utimes, stimes, cutimes, cstimes =
List.fold times let init =
~init: (Iter.empty, Iter.empty, Iter.empty, Iter.empty, Iter.empty)
( Iter.empty in
, Iter.empty List.fold times init
, Iter.empty ~f:(fun {Report.etime; utime; stime; cutime; cstime}
, Iter.empty (etimes, utimes, stimes, cutimes, cstimes)
, Iter.empty )
~f:(fun (etimes, utimes, stimes, cutimes, cstimes)
{Report.etime; utime; stime; cutime; cstime}
-> ->
( Iter.cons etime etimes ( Iter.cons etime etimes
, Iter.cons utime utimes , Iter.cons utime utimes
@ -136,9 +133,9 @@ let combine name b_result c_result =
if List.is_empty gcs then None if List.is_empty gcs then None
else else
let allocs, promos, peaks = let allocs, promos, peaks =
List.fold gcs ~init:(Iter.empty, Iter.empty, Iter.empty) List.fold gcs (Iter.empty, Iter.empty, Iter.empty)
~f:(fun (allocs, promos, peaks) ~f:(fun {Report.allocated; promoted; peak_size}
{Report.allocated; promoted; peak_size} (allocs, promos, peaks)
-> ->
( Iter.cons allocated allocs ( Iter.cons allocated allocs
, Iter.cons promoted promos , Iter.cons promoted promos
@ -198,20 +195,20 @@ let ranges rows =
; max_peak= 0. ; max_peak= 0.
; pct_peak= 0. } ; pct_peak= 0. }
in in
Iter.fold rows ~init ~f:(fun acc {times; times_deltas; gcs; gcs_deltas} -> Iter.fold rows init ~f:(fun {times; times_deltas; gcs; gcs_deltas} acc ->
Option.fold times_deltas ~init:acc ~f:(fun acc deltas -> Option.fold times_deltas acc ~f:(fun deltas acc ->
let max_time = Float.max acc.max_time (Float.abs deltas.etime) in let max_time = Float.max acc.max_time (Float.abs deltas.etime) in
let pct_time = let pct_time =
Option.fold times ~init:acc.pct_time ~f:(fun pct_time times -> Option.fold times acc.pct_time ~f:(fun times pct_time ->
let pct = 100. *. deltas.etime /. times.etime in let pct = 100. *. deltas.etime /. times.etime in
Float.max pct_time (Float.abs pct) ) Float.max pct_time (Float.abs pct) )
in in
{acc with max_time; pct_time} ) {acc with max_time; pct_time} )
|> fun init -> |> fun acc ->
Option.fold gcs_deltas ~init ~f:(fun acc deltas -> Option.fold gcs_deltas acc ~f:(fun deltas acc ->
let max_alloc = Float.max acc.max_alloc deltas.Report.allocated in let max_alloc = Float.max acc.max_alloc deltas.Report.allocated in
let pct_alloc = let pct_alloc =
Option.fold gcs ~init:acc.pct_alloc ~f:(fun pct_alloc gcs -> Option.fold gcs acc.pct_alloc ~f:(fun gcs pct_alloc ->
let pct = let pct =
100. *. deltas.Report.allocated /. gcs.Report.allocated 100. *. deltas.Report.allocated /. gcs.Report.allocated
in in
@ -219,7 +216,7 @@ let ranges rows =
in in
let max_promo = Float.max acc.max_promo deltas.Report.promoted in let max_promo = Float.max acc.max_promo deltas.Report.promoted in
let pct_promo = let pct_promo =
Option.fold gcs ~init:acc.pct_promo ~f:(fun pct_promo gcs -> Option.fold gcs acc.pct_promo ~f:(fun gcs pct_promo ->
let pct = let pct =
100. *. deltas.Report.promoted /. gcs.Report.promoted 100. *. deltas.Report.promoted /. gcs.Report.promoted
in in
@ -227,7 +224,7 @@ let ranges rows =
in in
let max_peak = Float.max acc.max_peak deltas.Report.peak_size in let max_peak = Float.max acc.max_peak deltas.Report.peak_size in
let pct_peak = let pct_peak =
Option.fold gcs ~init:acc.pct_peak ~f:(fun pct_peak gcs -> Option.fold gcs acc.pct_peak ~f:(fun gcs pct_peak ->
let pct = let pct =
100. *. deltas.Report.peak_size /. gcs.Report.peak_size 100. *. deltas.Report.peak_size /. gcs.Report.peak_size
in in
@ -424,8 +421,8 @@ let average row =
if List.is_empty times then None if List.is_empty times then None
else else
let etimes, utimes, stimes = let etimes, utimes, stimes =
List.fold times ~init:(Iter.empty, Iter.empty, Iter.empty) List.fold times (Iter.empty, Iter.empty, Iter.empty)
~f:(fun (etimes, utimes, stimes) {etime; utime; stime} -> ~f:(fun {etime; utime; stime} (etimes, utimes, stimes) ->
( Iter.cons etime etimes ( Iter.cons etime etimes
, Iter.cons utime utimes , Iter.cons utime utimes
, Iter.cons stime stimes ) ) , Iter.cons stime stimes ) )
@ -441,9 +438,9 @@ let average row =
if List.is_empty gcs then None if List.is_empty gcs then None
else else
let alloc, promo, peak = let alloc, promo, peak =
List.fold gcs ~init:(Iter.empty, Iter.empty, Iter.empty) List.fold gcs (Iter.empty, Iter.empty, Iter.empty)
~f:(fun (alloc, promo, peak) ~f:(fun {Report.allocated; promoted; peak_size}
{Report.allocated; promoted; peak_size} (alloc, promo, peak)
-> ->
( Iter.cons allocated alloc ( Iter.cons allocated alloc
, Iter.cons promoted promo , Iter.cons promoted promo
@ -470,7 +467,7 @@ let add_total rows =
; status_deltas= None } ; status_deltas= None }
in in
let total = let total =
Iter.fold rows ~init ~f:(fun total row -> Iter.fold rows init ~f:(fun total row ->
let times = let times =
match (total.times, row.times) with match (total.times, row.times) with
| Some total_times, Some row_times -> | Some total_times, Some row_times ->
@ -545,9 +542,7 @@ let input_rows ?baseline current =
let names = let names =
let keys = Tbl.keys c_tbl in let keys = Tbl.keys c_tbl in
let keys = let keys =
Option.fold Option.fold ~f:(fun t -> Iter.append (Tbl.keys t)) b_tbl keys
~f:(fun i t -> Iter.append (Tbl.keys t) i)
~init:keys b_tbl
in in
Iter.sort_uniq ~cmp:String.compare keys Iter.sort_uniq ~cmp:String.compare keys
in in

@ -242,10 +242,10 @@ module Representation (Trm : INDETERMINATE) = struct
let map poly ~f = let map poly ~f =
let p, p' = (poly, Sum.empty) in let p, p' = (poly, Sum.empty) in
let p, p' = let p, p' =
Sum.fold poly ~init:(p, p') ~f:(fun mono coeff (p, p') -> Sum.fold poly (p, p') ~f:(fun mono coeff (p, p') ->
let m, cm' = (mono, CM.one) in let m, cm' = (mono, CM.one) in
let m, cm' = let m, cm' =
Prod.fold mono ~init:(m, cm') ~f:(fun trm power (m, cm') -> Prod.fold mono (m, cm') ~f:(fun trm power (m, cm') ->
let trm' = f trm in let trm' = f trm in
if trm == trm' then (m, cm') if trm == trm' then (m, cm')
else else

@ -68,8 +68,8 @@ module type S = sig
type product type product
val fold_factors : product -> init:'s -> f:(trm -> int -> 's -> 's) -> 's val fold_factors : product -> 's -> f:(trm -> int -> 's -> 's) -> 's
val fold_monomials : t -> init:'s -> f:(product -> Q.t -> 's -> 's) -> 's val fold_monomials : t -> 's -> f:(product -> Q.t -> 's -> 's) -> 's
end end
(** Indeterminate terms, treated as atomic / variables except when they can (** Indeterminate terms, treated as atomic / variables except when they can

@ -30,7 +30,7 @@ module Make (Dom : Domain_intf.Dom) = struct
val pop_throw : val pop_throw :
t t
-> init:'a -> 'a
-> unwind: -> unwind:
( Llair.Reg.t list ( Llair.Reg.t list
-> Llair.Reg.Set.t -> Llair.Reg.Set.t
@ -130,7 +130,7 @@ module Make (Dom : Domain_intf.Dom) = struct
| Return {from_call; dst; stk} -> Some (from_call, dst, stk) | Return {from_call; dst; stk} -> Some (from_call, dst, stk)
| Empty -> None | Empty -> None
let pop_throw stk ~init ~unwind = let pop_throw stk state ~unwind =
let rec pop_throw_ state = function let rec pop_throw_ state = function
| Return {formals; locals; from_call; stk} -> | Return {formals; locals; from_call; stk} ->
pop_throw_ (unwind formals locals from_call state) stk pop_throw_ (unwind formals locals from_call state) stk
@ -139,7 +139,7 @@ module Make (Dom : Domain_intf.Dom) = struct
| Empty -> None | Empty -> None
| Throw _ as stk -> violates invariant stk | Throw _ as stk -> violates invariant stk
in in
pop_throw_ init stk pop_throw_ state stk
end end
module Work : sig module Work : sig
@ -240,7 +240,7 @@ module Make (Dom : Domain_intf.Dom) = struct
| Some (q :: qs, ws) -> | Some (q :: qs, ws) ->
let join (qa, da) (q, d) = (Dom.join q qa, Depths.join d da) in let join (qa, da) (q, d) = (Dom.join q qa, Depths.join d da) in
let skipped, (qs, depths) = let skipped, (qs, depths) =
List.fold qs ~init:([], q) ~f:(fun (skipped, joined) curr -> List.fold qs ([], q) ~f:(fun curr (skipped, joined) ->
match join curr joined with match join curr joined with
| Some joined, depths -> (skipped, (joined, depths)) | Some joined, depths -> (skipped, (joined, depths))
| None, _ -> (curr :: skipped, joined) ) | None, _ -> (curr :: skipped, joined) )
@ -275,7 +275,7 @@ module Make (Dom : Domain_intf.Dom) = struct
let domain_call = let domain_call =
Dom.call ~globals ~actuals ~areturn ~formals ~freturn ~locals Dom.call ~globals ~actuals ~areturn ~formals ~freturn ~locals
in in
List.fold ~init:Work.skip dnf_states ~f:(fun acc state -> List.fold dnf_states Work.skip ~f:(fun state acc ->
match match
if not opts.function_summaries then None if not opts.function_summaries then None
else else
@ -337,7 +337,7 @@ module Make (Dom : Domain_intf.Dom) = struct
let exit_state = let exit_state =
match (freturn, exp) with match (freturn, exp) with
| Some freturn, Some return_val -> | Some freturn, Some return_val ->
Dom.exec_move pre_state (IArray.of_ (freturn, return_val)) Dom.exec_move (IArray.of_ (freturn, return_val)) pre_state
| None, None -> pre_state | None, None -> pre_state
| _ -> violates Llair.Func.invariant block.parent | _ -> violates Llair.Func.invariant block.parent
in in
@ -365,11 +365,11 @@ module Make (Dom : Domain_intf.Dom) = struct
Dom.retn formals (Some func.fthrow) from_call Dom.retn formals (Some func.fthrow) from_call
(Dom.post scope from_call state) (Dom.post scope from_call state)
in in
( match Stack.pop_throw stk ~unwind ~init:pre_state with ( match Stack.pop_throw stk ~unwind pre_state with
| Some (from_call, retn_site, stk, unwind_state) -> | Some (from_call, retn_site, stk, unwind_state) ->
let fthrow = func.fthrow in let fthrow = func.fthrow in
let exit_state = let exit_state =
Dom.exec_move unwind_state (IArray.of_ (fthrow, exc)) Dom.exec_move (IArray.of_ (fthrow, exc)) unwind_state
in in
let post_state = Dom.post func.locals from_call exit_state in let post_state = Dom.post func.locals from_call exit_state in
let retn_state = let retn_state =
@ -389,7 +389,7 @@ module Make (Dom : Domain_intf.Dom) = struct
-> Work.x = -> Work.x =
fun stk state block areturn return -> fun stk state block areturn return ->
Report.unknown_call block.term ; Report.unknown_call block.term ;
let state = Option.fold ~f:Dom.exec_kill ~init:state areturn in let state = Option.fold ~f:Dom.exec_kill areturn state in
exec_jump stk state block return exec_jump stk state block return
let exec_term : let exec_term :
@ -405,22 +405,21 @@ module Make (Dom : Domain_intf.Dom) = struct
Report.step () ; Report.step () ;
match block.term with match block.term with
| Switch {key; tbl; els} -> | Switch {key; tbl; els} ->
IArray.fold tbl IArray.fold
~f:(fun x (case, jump) -> ~f:(fun (case, jump) x ->
match Dom.exec_assume state (Llair.Exp.eq key case) with match Dom.exec_assume state (Llair.Exp.eq key case) with
| Some state -> exec_jump stk state block jump |> Work.seq x | Some state -> exec_jump stk state block jump |> Work.seq x
| None -> x ) | None -> x )
~init: tbl
( match ( match
Dom.exec_assume state Dom.exec_assume state
(IArray.fold tbl ~init:Llair.Exp.true_ (IArray.fold tbl Llair.Exp.true_ ~f:(fun (case, _) b ->
~f:(fun b (case, _) -> Llair.Exp.and_ (Llair.Exp.dq key case) b ))
Llair.Exp.and_ (Llair.Exp.dq key case) b )) with
with | Some state -> exec_jump stk state block els
| Some state -> exec_jump stk state block els | None -> Work.skip )
| None -> Work.skip )
| Iswitch {ptr; tbl} -> | Iswitch {ptr; tbl} ->
IArray.fold tbl ~init:Work.skip ~f:(fun x (jump : Llair.jump) -> IArray.fold tbl Work.skip ~f:(fun (jump : Llair.jump) x ->
match match
Dom.exec_assume state Dom.exec_assume state
(Llair.Exp.eq ptr (Llair.Exp.eq ptr
@ -438,10 +437,10 @@ module Make (Dom : Domain_intf.Dom) = struct
match callees with match callees with
| [] -> exec_skip_func stk state block areturn return | [] -> exec_skip_func stk state block areturn return
| callees -> | callees ->
List.fold callees ~init:Work.skip ~f:(fun x callee -> List.fold callees Work.skip ~f:(fun callee x ->
( match ( match
Dom.exec_intrinsic ~skip_throw:opts.skip_throw state Dom.exec_intrinsic ~skip_throw:opts.skip_throw areturn
areturn callee.name.reg actuals callee.name.reg actuals state
with with
| Some None -> | Some None ->
Report.invalid_access_term Report.invalid_access_term
@ -463,13 +462,13 @@ module Make (Dom : Domain_intf.Dom) = struct
else exec_throw stk state block exc else exec_throw stk state block exc
| Unreachable -> Work.skip | Unreachable -> Work.skip
let exec_inst : Dom.t -> Llair.inst -> (Dom.t, Dom.t * Llair.inst) result let exec_inst : Llair.inst -> Dom.t -> (Dom.t, Dom.t * Llair.inst) result
= =
fun state inst -> fun inst state ->
[%Trace.info [%Trace.info
"@[<2>exec inst@\n@[%a@]@\n%a@]" Dom.pp state Llair.Inst.pp inst] ; "@[<2>exec inst@\n@[%a@]@\n%a@]" Dom.pp state Llair.Inst.pp inst] ;
Report.step () ; Report.step () ;
Dom.exec_inst state inst Dom.exec_inst inst state
|> function |> function
| Some state -> Result.Ok state | None -> Result.Error (state, inst) | Some state -> Result.Ok state | None -> Result.Error (state, inst)
@ -483,7 +482,7 @@ module Make (Dom : Domain_intf.Dom) = struct
fun opts pgm stk state block -> fun opts pgm stk state block ->
[%Trace.info "exec block %%%s" block.lbl] ; [%Trace.info "exec block %%%s" block.lbl] ;
match match
Iter.fold_result ~f:exec_inst ~init:state (IArray.to_iter block.cmnd) Iter.fold_result ~f:exec_inst (IArray.to_iter block.cmnd) state
with with
| Ok state -> exec_term opts pgm stk state block | Ok state -> exec_term opts pgm stk state block
| Error (state, inst) -> | Error (state, inst) ->
@ -517,7 +516,6 @@ module Make (Dom : Domain_intf.Dom) = struct
let compute_summaries opts pgm : Dom.summary list Llair.Reg.Map.t = let compute_summaries opts pgm : Dom.summary list Llair.Reg.Map.t =
assert opts.function_summaries ; assert opts.function_summaries ;
exec_pgm opts pgm ; exec_pgm opts pgm ;
RegTbl.fold summary_table ~init:Llair.Reg.Map.empty RegTbl.fold summary_table Llair.Reg.Map.empty ~f:(fun ~key ~data map ->
~f:(fun ~key ~data map ->
match data with [] -> map | _ -> Llair.Reg.Map.add ~key ~data map ) match data with [] -> map | _ -> Llair.Reg.Map.add ~key ~data map )
end end

@ -16,16 +16,16 @@ module type Dom = sig
val is_false : t -> bool val is_false : t -> bool
val dnf : t -> t list val dnf : t -> t list
val exec_assume : t -> Llair.Exp.t -> t option val exec_assume : t -> Llair.Exp.t -> t option
val exec_kill : t -> Llair.Reg.t -> t val exec_kill : Llair.Reg.t -> t -> t
val exec_move : t -> (Llair.Reg.t * Llair.Exp.t) iarray -> t val exec_move : (Llair.Reg.t * Llair.Exp.t) iarray -> t -> t
val exec_inst : t -> Llair.inst -> t option val exec_inst : Llair.inst -> t -> t option
val exec_intrinsic : val exec_intrinsic :
skip_throw:bool skip_throw:bool
-> t
-> Llair.Reg.t option -> Llair.Reg.t option
-> Llair.Reg.t -> Llair.Reg.t
-> Llair.Exp.t list -> Llair.Exp.t list
-> t
-> t option option -> t option option
type from_call [@@deriving sexp_of] type from_call [@@deriving sexp_of]

@ -45,21 +45,21 @@ module Make (State_domain : State_domain_sig) = struct
let+ next = State_domain.exec_assume current cnd in let+ next = State_domain.exec_assume current cnd in
(entry, next) (entry, next)
let exec_kill (entry, current) reg = let exec_kill reg (entry, current) =
(entry, State_domain.exec_kill current reg) (entry, State_domain.exec_kill reg current)
let exec_move (entry, current) reg_exps = let exec_move reg_exps (entry, current) =
(entry, State_domain.exec_move current reg_exps) (entry, State_domain.exec_move reg_exps current)
let exec_inst (entry, current) inst = let exec_inst inst (entry, current) =
let+ next = State_domain.exec_inst current inst in let+ next = State_domain.exec_inst inst current in
(entry, next) (entry, next)
let exec_intrinsic ~skip_throw (entry, current) areturn intrinsic actuals let exec_intrinsic ~skip_throw areturn intrinsic actuals (entry, current)
= =
let+ next_opt = let+ next_opt =
State_domain.exec_intrinsic ~skip_throw current areturn intrinsic State_domain.exec_intrinsic ~skip_throw areturn intrinsic actuals
actuals current
in in
let+ next = next_opt in let+ next = next_opt in
(entry, next) (entry, next)

@ -20,13 +20,14 @@ let simplify_states = ref true
let simplify q = if !simplify_states then Sh.simplify q else q let simplify q = if !simplify_states then Sh.simplify q else q
let init globals = let init globals =
IArray.fold globals ~init:Sh.emp ~f:(fun q -> function IArray.fold globals Sh.emp ~f:(fun global q ->
| {Llair.Global.reg; init= Some (seq, siz)} -> match global with
let loc = Term.var (X.reg reg) in | {Llair.Global.reg; init= Some (seq, siz)} ->
let len = Term.integer (Z.of_int siz) in let loc = Term.var (X.reg reg) in
let seq = X.term seq in let len = Term.integer (Z.of_int siz) in
Sh.star q (Sh.seg {loc; bas= loc; len; siz= len; seq}) let seq = X.term seq in
| _ -> q ) Sh.star q (Sh.seg {loc; bas= loc; len; siz= len; seq})
| _ -> q )
let join p q = let join p q =
[%Trace.call fun {pf} -> pf "%a@ %a" pp p pp q] [%Trace.call fun {pf} -> pf "%a@ %a" pp p pp q]
@ -38,13 +39,13 @@ let join p q =
let is_false = Sh.is_false let is_false = Sh.is_false
let dnf = Sh.dnf let dnf = Sh.dnf
let exec_assume q b = Exec.assume q (X.formula b) |> Option.map ~f:simplify let exec_assume q b = Exec.assume q (X.formula b) |> Option.map ~f:simplify
let exec_kill q r = Exec.kill q (X.reg r) |> simplify let exec_kill r q = Exec.kill q (X.reg r) |> simplify
let exec_move q res = let exec_move res q =
Exec.move q (IArray.map res ~f:(fun (r, e) -> (X.reg r, X.term e))) Exec.move q (IArray.map res ~f:(fun (r, e) -> (X.reg r, X.term e)))
|> simplify |> simplify
let exec_inst pre inst = let exec_inst inst pre =
( match (inst : Llair.inst) with ( match (inst : Llair.inst) with
| Move {reg_exps; _} -> | Move {reg_exps; _} ->
Some Some
@ -67,7 +68,7 @@ let exec_inst pre inst =
| Abort _ -> Exec.abort pre ) | Abort _ -> Exec.abort pre )
|> Option.map ~f:simplify |> Option.map ~f:simplify
let exec_intrinsic ~skip_throw q r i es = let exec_intrinsic ~skip_throw r i es q =
Exec.intrinsic ~skip_throw q (Option.map ~f:X.reg r) (X.reg i) Exec.intrinsic ~skip_throw q (Option.map ~f:X.reg r) (X.reg i)
(List.map ~f:X.term es) (List.map ~f:X.term es)
|> Option.map ~f:(Option.map ~f:simplify) |> Option.map ~f:(Option.map ~f:simplify)
@ -94,10 +95,10 @@ let garbage_collect (q : t) ~wrt =
if Var.Set.equal previous current then current if Var.Set.equal previous current then current
else else
let new_set = let new_set =
List.fold ~init:current q.heap ~f:(fun current seg -> List.fold q.heap current ~f:(fun seg current ->
if term_eq_class_has_only_vars_in current q.ctx seg.loc then if term_eq_class_has_only_vars_in current q.ctx seg.loc then
List.fold (Context.class_of q.ctx seg.seq) ~init:current List.fold (Context.class_of q.ctx seg.seq) current
~f:(fun c e -> Var.Set.union c (Term.fv e)) ~f:(fun e c -> Var.Set.union c (Term.fv e))
else current ) else current )
in in
all_reachable_vars current new_set q all_reachable_vars current new_set q
@ -109,11 +110,11 @@ let garbage_collect (q : t) ~wrt =
[%Trace.retn fun {pf} -> pf "%a" pp] [%Trace.retn fun {pf} -> pf "%a" pp]
let and_eqs sub formals actuals q = let and_eqs sub formals actuals q =
let and_eq q formal actual = let and_eq formal actual q =
let actual' = Term.rename sub actual in let actual' = Term.rename sub actual in
Sh.and_ (Formula.eq (Term.var formal) actual') q Sh.and_ (Formula.eq (Term.var formal) actual') q
in in
List.fold2_exn ~f:and_eq formals actuals ~init:q List.fold2_exn ~f:and_eq formals actuals q
let localize_entry globals actuals formals freturn locals shadow pre entry = let localize_entry globals actuals formals freturn locals shadow pre entry =
(* Add the formals here to do garbage collection and then get rid of them *) (* Add the formals here to do garbage collection and then get rid of them *)
@ -257,7 +258,7 @@ let create_summary ~locals ~formals ~entry ~current:(post : Sh.t) =
let foot = Sh.exists locals entry in let foot = Sh.exists locals entry in
let foot, subst = Sh.freshen ~wrt:(Var.Set.union foot.us post.us) foot in let foot, subst = Sh.freshen ~wrt:(Var.Set.union foot.us post.us) foot in
let restore_formals q = let restore_formals q =
Var.Set.fold formals ~init:q ~f:(fun var q -> Var.Set.fold formals q ~f:(fun var q ->
let var = Term.var var in let var = Term.var var in
let renamed_var = Term.rename subst var in let renamed_var = Term.rename subst var in
Sh.and_ (Formula.eq renamed_var var) q ) Sh.and_ (Formula.eq renamed_var var) q )

@ -15,9 +15,9 @@ let init _ = ()
let join () () = Some () let join () () = Some ()
let is_false _ = false let is_false _ = false
let exec_assume () _ = Some () let exec_assume () _ = Some ()
let exec_kill () _ = () let exec_kill _ () = ()
let exec_move () _ = () let exec_move _ () = ()
let exec_inst () _ = Some () let exec_inst _ () = Some ()
let exec_intrinsic ~skip_throw:_ _ _ _ _ : t option option = None let exec_intrinsic ~skip_throw:_ _ _ _ _ : t option option = None
type from_call = unit [@@deriving compare, equal, sexp] type from_call = unit [@@deriving compare, equal, sexp]

@ -24,30 +24,25 @@ let post _ _ state = state
let retn _ _ from_call post = Llair.Reg.Set.union from_call post let retn _ _ from_call post = Llair.Reg.Set.union from_call post
let dnf t = [t] let dnf t = [t]
let add_if_global gs v = let add_if_global v gs =
if Llair.Reg.is_global v then Llair.Reg.Set.add v gs else gs if Llair.Reg.is_global v then Llair.Reg.Set.add v gs else gs
let used_globals ?(init = empty) exp = let used_globals exp s = Llair.Exp.fold_regs ~f:add_if_global exp s
Llair.Exp.fold_regs exp ~init ~f:add_if_global let exec_assume st exp = Some (used_globals exp st)
let exec_kill _ st = st
let exec_assume st exp = Some (used_globals ~init:st exp) let exec_move reg_exps st =
let exec_kill st _ = st IArray.fold ~f:(fun (_, rhs) -> used_globals rhs) reg_exps st
let exec_move st reg_exps = let exec_inst inst st =
IArray.fold reg_exps ~init:st ~f:(fun st (_, rhs) ->
used_globals ~init:st rhs )
let exec_inst st inst =
[%Trace.call fun {pf} -> pf "pre:{%a} %a" pp st Llair.Inst.pp inst] [%Trace.call fun {pf} -> pf "pre:{%a} %a" pp st Llair.Inst.pp inst]
; ;
Some Some (Llair.Inst.fold_exps ~f:used_globals inst st)
(Llair.Inst.fold_exps inst ~init:st ~f:(fun acc e ->
used_globals ~init:acc e ))
|> |>
[%Trace.retn fun {pf} -> [%Trace.retn fun {pf} ->
Option.iter ~f:(fun uses -> pf "post:{%a}" pp uses)] Option.iter ~f:(fun uses -> pf "post:{%a}" pp uses)]
let exec_intrinsic ~skip_throw:_ st _ intrinsic actuals = let exec_intrinsic ~skip_throw:_ _ intrinsic actuals st =
let name = Llair.Reg.name intrinsic in let name = Llair.Reg.name intrinsic in
if if
List.exists List.exists
@ -72,9 +67,7 @@ let exec_intrinsic ~skip_throw:_ st _ intrinsic actuals =
; "__cxa_allocate_exception" ; "__cxa_allocate_exception"
; "_ZN5folly13usingJEMallocEv" ] ; "_ZN5folly13usingJEMallocEv" ]
~f:(String.equal name) ~f:(String.equal name)
then then List.fold ~f:used_globals actuals st |> fun res -> Some (Some res)
List.fold actuals ~init:st ~f:(fun s a -> used_globals ~init:s a)
|> fun res -> Some (Some res)
else None else None
type from_call = t [@@deriving sexp] type from_call = t [@@deriving sexp]
@ -82,10 +75,10 @@ type from_call = t [@@deriving sexp]
(* Set abstract state to bottom (i.e. empty set) at function entry *) (* Set abstract state to bottom (i.e. empty set) at function entry *)
let call ~summaries:_ ~globals:_ ~actuals ~areturn:_ ~formals:_ ~freturn:_ let call ~summaries:_ ~globals:_ ~actuals ~areturn:_ ~formals:_ ~freturn:_
~locals:_ st = ~locals:_ st =
(empty, List.fold actuals ~init:st ~f:(fun s a -> used_globals ~init:s a)) (empty, List.fold ~f:used_globals actuals st)
let resolve_callee lookup ptr st = let resolve_callee lookup ptr st =
let st = used_globals ~init:st ptr in let st = used_globals ptr st in
match Llair.Reg.of_exp ptr with match Llair.Reg.of_exp ptr with
| Some callee -> (lookup (Llair.Reg.name callee), st) | Some callee -> (lookup (Llair.Reg.name callee), st)
| None -> ([], st) | None -> ([], st)

@ -102,13 +102,13 @@ open Fresh.Import
let move_spec reg_exps = let move_spec reg_exps =
let foot = Sh.emp in let foot = Sh.emp in
let ws, rs = let ws, rs =
IArray.fold reg_exps ~init:(Var.Set.empty, Var.Set.empty) IArray.fold reg_exps (Var.Set.empty, Var.Set.empty)
~f:(fun (ws, rs) (reg, exp) -> ~f:(fun (reg, exp) (ws, rs) ->
(Var.Set.add reg ws, Var.Set.union rs (Term.fv exp)) ) (Var.Set.add reg ws, Var.Set.union rs (Term.fv exp)) )
in in
let+ sub, ms = Fresh.assign ~ws ~rs in let+ sub, ms = Fresh.assign ~ws ~rs in
let post = let post =
IArray.fold reg_exps ~init:Sh.emp ~f:(fun post (reg, exp) -> IArray.fold reg_exps Sh.emp ~f:(fun (reg, exp) post ->
Sh.and_ (Formula.eq (Term.var reg) (Term.rename sub exp)) post ) Sh.and_ (Formula.eq (Term.var reg) (Term.rename sub exp)) post )
in in
{foot; sub; ms; post} {foot; sub; ms; post}

@ -350,51 +350,46 @@ let pp = ppx (fun _ -> None)
(** fold_vars *) (** fold_vars *)
let fold_pos_neg ~pos ~neg ~init ~f = let fold_pos_neg ~pos ~neg s ~f =
let f_not p s = f s (_Not p) in let f_not p s = f (_Not p) s in
Fmls.fold ~init:(Fmls.fold ~init ~f:(Fun.flip f) pos) ~f:f_not neg Fmls.fold ~f:f_not neg (Fmls.fold ~f pos s)
let rec fold_vars_t e ~init ~f = let rec fold_vars_t e s ~f =
match e with match e with
| Z _ | Q _ | Ancestor _ -> init | Z _ | Q _ | Ancestor _ -> s
| Var _ as v -> f init (Var.of_ v) | Var _ as v -> f (Var.of_ v) s
| Splat x | Select {rcd= x} -> fold_vars_t ~f x ~init | Splat x | Select {rcd= x} -> fold_vars_t ~f x s
| Sized {seq= x; siz= y} | Update {rcd= x; elt= y} -> | Sized {seq= x; siz= y} | Update {rcd= x; elt= y} ->
fold_vars_t ~f x ~init:(fold_vars_t ~f y ~init) fold_vars_t ~f x (fold_vars_t ~f y s)
| Extract {seq= x; off= y; len= z} -> | Extract {seq= x; off= y; len= z} ->
fold_vars_t ~f x fold_vars_t ~f x (fold_vars_t ~f y (fold_vars_t ~f z s))
~init:(fold_vars_t ~f y ~init:(fold_vars_t ~f z ~init))
| Concat xs | Record xs | Apply (_, xs) -> | Concat xs | Record xs | Apply (_, xs) ->
Array.fold ~f:(fun init -> fold_vars_t ~f ~init) xs ~init Array.fold ~f:(fold_vars_t ~f) xs s
| Arith a -> | Arith a -> Iter.fold ~f:(fold_vars_t ~f) (Arith.iter a) s
Iter.fold
~f:(fun s x -> fold_vars_t ~f x ~init:s)
~init (Arith.iter a)
let rec fold_vars_f ~init p ~f = let rec fold_vars_f p s ~f =
match (p : fml) with match (p : fml) with
| Tt -> init | Tt -> s
| Eq (x, y) -> fold_vars_t ~f x ~init:(fold_vars_t ~f y ~init) | Eq (x, y) -> fold_vars_t ~f x (fold_vars_t ~f y s)
| Eq0 x | Pos x -> fold_vars_t ~f x ~init | Eq0 x | Pos x -> fold_vars_t ~f x s
| Not x -> fold_vars_f ~f x ~init | Not x -> fold_vars_f ~f x s
| And {pos; neg} | Or {pos; neg} -> | And {pos; neg} | Or {pos; neg} ->
fold_pos_neg ~f:(fun init -> fold_vars_f ~f ~init) ~pos ~neg ~init fold_pos_neg ~f:(fold_vars_f ~f) ~pos ~neg s
| Iff (x, y) -> fold_vars_f ~f x ~init:(fold_vars_f ~f y ~init) | Iff (x, y) -> fold_vars_f ~f x (fold_vars_f ~f y s)
| Cond {cnd; pos; neg} -> | Cond {cnd; pos; neg} ->
fold_vars_f ~f cnd fold_vars_f ~f cnd (fold_vars_f ~f pos (fold_vars_f ~f neg s))
~init:(fold_vars_f ~f pos ~init:(fold_vars_f ~f neg ~init)) | Lit (_, xs) -> Array.fold ~f:(fold_vars_t ~f) xs s
| Lit (_, xs) -> Array.fold ~f:(fun init -> fold_vars_t ~f ~init) xs ~init
let rec fold_vars_c ~init ~f = function let rec fold_vars_c c s ~f =
match c with
| `Ite (cnd, thn, els) -> | `Ite (cnd, thn, els) ->
fold_vars_f ~f cnd fold_vars_f ~f cnd (fold_vars_c ~f thn (fold_vars_c ~f els s))
~init:(fold_vars_c ~f thn ~init:(fold_vars_c ~f els ~init)) | `Trm t -> fold_vars_t ~f t s
| `Trm t -> fold_vars_t ~f t ~init
let fold_vars ~init e ~f = let fold_vars e s ~f =
match e with match e with
| `Fml p -> fold_vars_f ~f ~init p | `Fml p -> fold_vars_f ~f p s
| #cnd as c -> fold_vars_c ~f ~init c | #cnd as c -> fold_vars_c ~f c s
(** map *) (** map *)
@ -704,21 +699,21 @@ module Term = struct
let map_vars = map_vars let map_vars = map_vars
let fold_map_vars e ~init ~f = let fold_map_vars e s0 ~f =
let s = ref init in let s = ref s0 in
let f x = let f x =
let s', x' = f !s x in let x', s' = f x !s in
s := s' ; s := s' ;
x' x'
in in
let e' = map_vars ~f e in let e' = map_vars ~f e in
(!s, e') (e', !s)
let rename s e = map_vars ~f:(Var.Subst.apply s) e let rename s e = map_vars ~f:(Var.Subst.apply s) e
(** Query *) (** Query *)
let fv e = fold_vars e ~f:(Fun.flip Var.Set.add) ~init:Var.Set.empty let fv e = fold_vars ~f:Var.Set.add e Var.Set.empty
end end
(* (*
@ -763,9 +758,9 @@ module Formula = struct
(* connectives *) (* connectives *)
let and_ = and_ let and_ = and_
let andN = function [] -> tt | b :: bs -> List.fold ~init:b ~f:and_ bs let andN = function [] -> tt | b :: bs -> List.fold ~f:and_ bs b
let or_ = or_ let or_ = or_
let orN = function [] -> ff | b :: bs -> List.fold ~init:b ~f:or_ bs let orN = function [] -> ff | b :: bs -> List.fold ~f:or_ bs b
let iff = _Iff let iff = _Iff
let xor p q = _Not (_Iff p q) let xor p q = _Not (_Iff p q)
let cond ~cnd ~pos ~neg = _Cond cnd pos neg let cond ~cnd ~pos ~neg = _Cond cnd pos neg
@ -773,7 +768,7 @@ module Formula = struct
(** Query *) (** Query *)
let fv e = fold_vars_f e ~f:(Fun.flip Var.Set.add) ~init:Var.Set.empty let fv e = fold_vars_f ~f:Var.Set.add e Var.Set.empty
(** Traverse *) (** Traverse *)
@ -808,15 +803,15 @@ module Formula = struct
| Cond {cnd; pos; neg} -> map3 (map_terms ~f) b _Cond cnd pos neg | Cond {cnd; pos; neg} -> map3 (map_terms ~f) b _Cond cnd pos neg
| Lit (p, xs) -> lift_mapN f b (_Lit p) xs | Lit (p, xs) -> lift_mapN f b (_Lit p) xs
let fold_map_vars ~init e ~f = let fold_map_vars e s0 ~f =
let s = ref init in let s = ref s0 in
let f x = let f x =
let s', x' = f !s x in let x', s' = f x !s in
s := s' ; s := s' ;
x' x'
in in
let e' = map_vars ~f e in let e' = map_vars ~f e in
(!s, e') (e', !s)
let rename s e = map_vars ~f:(Var.Subst.apply s) e let rename s e = map_vars ~f:(Var.Subst.apply s) e
@ -828,25 +823,26 @@ module Formula = struct
-> 'formula -> 'formula
-> 'disjunction = -> 'disjunction =
fun ~meet1 ~join1 ~top ~bot fml -> fun ~meet1 ~join1 ~top ~bot fml ->
let rec add_conjunct (cjn, splits) fml = let rec add_conjunct fml (cjn, splits) =
match fml with match fml with
| Tt | Eq _ | Eq0 _ | Pos _ | Iff _ | Lit _ | Not _ -> | Tt | Eq _ | Eq0 _ | Pos _ | Iff _ | Lit _ | Not _ ->
(meet1 fml cjn, splits) (meet1 fml cjn, splits)
| And {pos; neg} -> | And {pos; neg} ->
fold_pos_neg ~f:add_conjunct ~init:(cjn, splits) ~pos ~neg fold_pos_neg ~f:add_conjunct ~pos ~neg (cjn, splits)
| Or {pos; neg} -> (cjn, (pos, neg) :: splits) | Or {pos; neg} -> (cjn, (pos, neg) :: splits)
| Cond {cnd; pos; neg} -> | Cond {cnd; pos; neg} ->
add_conjunct (cjn, splits) add_conjunct
(or_ (and_ cnd pos) (and_ (not_ cnd) neg)) (or_ (and_ cnd pos) (and_ (not_ cnd) neg))
(cjn, splits)
in in
let rec add_disjunct (cjn, splits) djn fml = let rec add_disjunct (cjn, splits) fml djn =
let cjn, splits = add_conjunct (cjn, splits) fml in let cjn, splits = add_conjunct fml (cjn, splits) in
match splits with match splits with
| (pos, neg) :: splits -> | (pos, neg) :: splits ->
fold_pos_neg ~f:(add_disjunct (cjn, splits)) ~init:djn ~pos ~neg fold_pos_neg ~f:(add_disjunct (cjn, splits)) ~pos ~neg djn
| [] -> join1 cjn djn | [] -> join1 cjn djn
in in
add_disjunct (top, []) bot fml add_disjunct (top, []) fml bot
end end
(* (*
@ -858,14 +854,14 @@ let v_to_ses : var -> Ses.Var.t =
let vs_to_ses : Var.Set.t -> Ses.Var.Set.t = let vs_to_ses : Var.Set.t -> Ses.Var.Set.t =
fun vs -> fun vs ->
Var.Set.fold vs ~init:Ses.Var.Set.empty ~f:(fun v vs -> Var.Set.fold vs Ses.Var.Set.empty ~f:(fun v ->
Ses.Var.Set.add (v_to_ses v) vs ) Ses.Var.Set.add (v_to_ses v) )
let rec arith_to_ses poly = let rec arith_to_ses poly =
Arith.fold_monomials poly ~init:Ses.Term.zero ~f:(fun mono coeff e -> Arith.fold_monomials poly Ses.Term.zero ~f:(fun mono coeff e ->
Ses.Term.add e Ses.Term.add e
(Ses.Term.mulq coeff (Ses.Term.mulq coeff
(Arith.fold_factors mono ~init:Ses.Term.one ~f:(fun trm pow f -> (Arith.fold_factors mono Ses.Term.one ~f:(fun trm pow f ->
let rec exp b i = let rec exp b i =
assert (i > 0) ; assert (i > 0) ;
if i = 1 then b else Ses.Term.mul b (exp f (i - 1)) if i = 1 then b else Ses.Term.mul b (exp f (i - 1))
@ -911,12 +907,12 @@ let rec f_to_ses : fml -> Ses.Term.t = function
| Not p -> Ses.Term.not_ (f_to_ses p) | Not p -> Ses.Term.not_ (f_to_ses p)
| And {pos; neg} -> | And {pos; neg} ->
fold_pos_neg fold_pos_neg
~f:(fun p f -> Ses.Term.and_ p (f_to_ses f)) ~f:(fun f p -> Ses.Term.and_ p (f_to_ses f))
~init:Ses.Term.true_ ~pos ~neg ~pos ~neg Ses.Term.true_
| Or {pos; neg} -> | Or {pos; neg} ->
fold_pos_neg fold_pos_neg
~f:(fun p f -> Ses.Term.or_ p (f_to_ses f)) ~f:(fun f p -> Ses.Term.or_ p (f_to_ses f))
~init:Ses.Term.false_ ~pos ~neg ~pos ~neg Ses.Term.false_
| Iff (p, q) -> Ses.Term.eq (f_to_ses p) (f_to_ses q) | Iff (p, q) -> Ses.Term.eq (f_to_ses p) (f_to_ses q)
| Cond {cnd; pos; neg} -> | Cond {cnd; pos; neg} ->
Ses.Term.conditional ~cnd:(f_to_ses cnd) ~thn:(f_to_ses pos) Ses.Term.conditional ~cnd:(f_to_ses cnd) ~thn:(f_to_ses pos)
@ -941,8 +937,7 @@ let v_of_ses : Ses.Var.t -> var =
let vs_of_ses : Ses.Var.Set.t -> Var.Set.t = let vs_of_ses : Ses.Var.Set.t -> Var.Set.t =
fun vs -> fun vs ->
Ses.Var.Set.fold vs ~init:Var.Set.empty ~f:(fun v vs -> Ses.Var.Set.fold ~f:(fun v -> Var.Set.add (v_of_ses v)) vs Var.Set.empty
Var.Set.add (v_of_ses v) vs )
let uap1 f = ap1t (fun x -> _Apply f [|x|]) let uap1 f = ap1t (fun x -> _Apply f [|x|])
let uap2 f = ap2t (fun x y -> _Apply f [|x; y|]) let uap2 f = ap2t (fun x y -> _Apply f [|x; y|])
@ -960,7 +955,7 @@ and ap2_f mk_f mk_t a b = ap2 mk_f (fun x y -> `Fml (mk_t x y)) a b
and apN mk_f mk_t mk_unit es = and apN mk_f mk_t mk_unit es =
match match
Ses.Term.Set.fold ~init:(None, None) es ~f:(fun e (fs, ts) -> Ses.Term.Set.fold es (None, None) ~f:(fun e (fs, ts) ->
match of_ses e with match of_ses e with
| `Fml f -> | `Fml f ->
(Some (match fs with None -> f | Some g -> mk_f f g), ts) (Some (match fs with None -> f | Some g -> mk_f f g), ts)
@ -1001,8 +996,7 @@ and of_ses : Ses.Term.t -> exp =
| `Trm (Q r) -> rational (Q.mul q r) | `Trm (Q r) -> rational (Q.mul q r)
| t -> mulq q t | t -> mulq q t
in in
Ses.Term.Qset.fold sum ~init:(mul e q) ~f:(fun e q s -> Ses.Term.Qset.fold ~f:(fun e q -> add (mul e q)) sum (mul e q) )
add (mul e q) s ) )
| Mul prod -> ( | Mul prod -> (
match Ses.Term.Qset.pop_min_elt prod with match Ses.Term.Qset.pop_min_elt prod with
| None -> one | None -> one
@ -1018,8 +1012,7 @@ and of_ses : Ses.Term.t -> exp =
else if sn > 0 then expn (of_ses e) n else if sn > 0 then expn (of_ses e) n
else div one (expn (of_ses e) (Z.neg n)) else div one (expn (of_ses e) (Z.neg n))
in in
Ses.Term.Qset.fold prod ~init:(exp e q) ~f:(fun e q s -> Ses.Term.Qset.fold ~f:(fun e q -> mul (exp e q)) prod (exp e q) )
mul (exp e q) s ) )
| Ap2 (Div, d, e) -> div (of_ses d) (of_ses e) | Ap2 (Div, d, e) -> div (of_ses d) (of_ses e)
| Ap2 (Rem, d, e) -> uap_ttt Rem d e | Ap2 (Rem, d, e) -> uap_ttt Rem d e
| And es -> apN and_ (uap2 BitAnd) tt es | And es -> apN and_ (uap2 BitAnd) tt es
@ -1073,10 +1066,10 @@ module Context = struct
(* Query *) (* Query *)
let fold_vars ~init x ~f = let fold_vars x s ~f =
Ses.Equality.fold_vars x ~init ~f:(fun s v -> f s (v_of_ses v)) Ses.Equality.fold_vars ~f:(fun v -> f (v_of_ses v)) x s
let fv e = fold_vars e ~f:(Fun.flip Var.Set.add) ~init:Var.Set.empty let fv e = fold_vars ~f:Var.Set.add e Var.Set.empty
let is_empty x = Ses.Equality.is_true x let is_empty x = Ses.Equality.is_true x
let is_unsat x = Ses.Equality.is_false x let is_unsat x = Ses.Equality.is_false x
let implies x b = Ses.Equality.implies x (f_to_ses b) let implies x b = Ses.Equality.implies x (f_to_ses b)
@ -1091,7 +1084,7 @@ module Context = struct
let class_of x e = List.map ~f:of_ses (Ses.Equality.class_of x (to_ses e)) let class_of x e = List.map ~f:of_ses (Ses.Equality.class_of x (to_ses e))
let classes x = let classes x =
Ses.Term.Map.fold (Ses.Equality.classes x) ~init:Term.Map.empty Ses.Term.Map.fold (Ses.Equality.classes x) Term.Map.empty
~f:(fun ~key:rep ~data:cls clss -> ~f:(fun ~key:rep ~data:cls clss ->
let rep' = of_ses rep in let rep' = of_ses rep in
let cls' = List.map ~f:of_ses cls in let cls' = List.map ~f:of_ses cls in
@ -1178,8 +1171,8 @@ module Context = struct
let pp = Ses.Equality.Subst.pp let pp = Ses.Equality.Subst.pp
let is_empty = Ses.Equality.Subst.is_empty let is_empty = Ses.Equality.Subst.is_empty
let fold s ~init ~f = let fold s z ~f =
Ses.Equality.Subst.fold s ~init ~f:(fun ~key ~data -> Ses.Equality.Subst.fold s z ~f:(fun ~key ~data ->
f ~key:(of_ses key) ~data:(of_ses data) ) f ~key:(of_ses key) ~data:(of_ses data) )
let subst s = ses_map (Ses.Equality.Subst.subst s) let subst s = ses_map (Ses.Equality.Subst.subst s)

@ -76,15 +76,12 @@ module rec Term : sig
(** Traverse *) (** Traverse *)
val fold_vars : init:'a -> t -> f:('a -> Var.t -> 'a) -> 'a val fold_vars : t -> 's -> f:(Var.t -> 's -> 's) -> 's
(** Transform *) (** Transform *)
val map_vars : f:(Var.t -> Var.t) -> t -> t val map_vars : f:(Var.t -> Var.t) -> t -> t
val fold_map_vars : t -> 's -> f:(Var.t -> 's -> Var.t * 's) -> t * 's
val fold_map_vars :
t -> init:'a -> f:('a -> Var.t -> 'a * Var.t) -> 'a * t
val rename : Var.Subst.t -> t -> t val rename : Var.Subst.t -> t -> t
end end
@ -137,16 +134,13 @@ and Formula : sig
(** Traverse *) (** Traverse *)
val fold_vars : init:'a -> t -> f:('a -> Var.t -> 'a) -> 'a val fold_vars : t -> 's -> f:(Var.t -> 's -> 's) -> 's
(** Transform *) (** Transform *)
val map_terms : f:(Term.t -> Term.t) -> t -> t val map_terms : f:(Term.t -> Term.t) -> t -> t
val map_vars : f:(Var.t -> Var.t) -> t -> t val map_vars : f:(Var.t -> Var.t) -> t -> t
val fold_map_vars : t -> 's -> f:(Var.t -> 's -> Var.t * 's) -> t * 's
val fold_map_vars :
init:'a -> t -> f:('a -> Var.t -> 'a * Var.t) -> 'a * t
val rename : Var.Subst.t -> t -> t val rename : Var.Subst.t -> t -> t
end end
@ -209,7 +203,7 @@ module Context : sig
(** Equivalence class of [e]: all the terms [f] in the context such that (** Equivalence class of [e]: all the terms [f] in the context such that
[e = f] is implied by the assumptions. *) [e = f] is implied by the assumptions. *)
val fold_vars : init:'a -> t -> f:('a -> Var.t -> 'a) -> 'a val fold_vars : t -> 's -> f:(Var.t -> 's -> 's) -> 's
(** Enumerate the variables occurring in the terms of the context. *) (** Enumerate the variables occurring in the terms of the context. *)
val fv : t -> Var.Set.t val fv : t -> Var.Set.t
@ -221,9 +215,7 @@ module Context : sig
val pp : t pp val pp : t pp
val is_empty : t -> bool val is_empty : t -> bool
val fold : t -> 's -> f:(key:Term.t -> data:Term.t -> 's -> 's) -> 's
val fold :
t -> init:'a -> f:(key:Term.t -> data:Term.t -> 'a -> 'a) -> 'a
val subst : t -> Term.t -> Term.t val subst : t -> Term.t -> Term.t
(** Apply a substitution recursively to subterms. *) (** Apply a substitution recursively to subterms. *)

@ -410,24 +410,19 @@ let rec_record i typ = RecRecord (i, typ)
(** Traverse *) (** Traverse *)
let fold_exps e ~init ~f = let rec fold_exps e z ~f =
let rec fold_exps_ e z = f e
let z = ( match e with
match e with | Ap1 (_, _, x) -> fold_exps ~f x z
| Ap1 (_, _, x) -> fold_exps_ x z | Ap2 (_, _, x, y) -> fold_exps ~f y (fold_exps ~f x z)
| Ap2 (_, _, x, y) -> fold_exps_ y (fold_exps_ x z) | Ap3 (_, _, w, x, y) ->
| Ap3 (_, _, w, x, y) -> fold_exps_ w (fold_exps_ y (fold_exps_ x z)) fold_exps ~f w (fold_exps ~f y (fold_exps ~f x z))
| ApN (_, _, xs) -> | ApN (_, _, xs) -> IArray.fold xs z ~f:(fold_exps ~f)
IArray.fold xs ~init:z ~f:(fun z elt -> fold_exps_ elt z) | _ -> z )
| _ -> z
in let fold_regs e z ~f =
f z e fold_exps e z ~f:(fun x z ->
in match x with Reg _ -> f (x :> Reg.t) z | _ -> z )
fold_exps_ e init
let fold_regs e ~init ~f =
fold_exps e ~init ~f:(fun z x ->
match x with Reg _ -> f z (x :> Reg.t) | _ -> z )
(** Query *) (** Query *)

@ -183,7 +183,7 @@ val rec_record : int -> Typ.t -> t
(** Traverse *) (** Traverse *)
val fold_regs : t -> init:'a -> f:('a -> Reg.t -> 'a) -> 'a val fold_regs : t -> 's -> f:(Reg.t -> 's -> 's) -> 's
(** Query *) (** Query *)

@ -260,9 +260,7 @@ module Inst = struct
let union_locals inst vs = let union_locals inst vs =
match inst with match inst with
| Move {reg_exps; _} -> | Move {reg_exps; _} ->
IArray.fold IArray.fold ~f:(fun (reg, _) vs -> Reg.Set.add reg vs) reg_exps vs
~f:(fun vs (reg, _) -> Reg.Set.add reg vs)
~init:vs reg_exps
| Load {reg; _} | Alloc {reg; _} | Nondet {reg= Some reg; _} -> | Load {reg; _} | Alloc {reg; _} | Nondet {reg= Some reg; _} ->
Reg.Set.add reg vs Reg.Set.add reg vs
| Store _ | Memcpy _ | Memmov _ | Memset _ | Free _ | Store _ | Memcpy _ | Memmov _ | Memset _ | Free _
@ -272,19 +270,19 @@ module Inst = struct
let locals inst = union_locals inst Reg.Set.empty let locals inst = union_locals inst Reg.Set.empty
let fold_exps inst ~init ~f = let fold_exps inst s ~f =
match inst with match inst with
| Move {reg_exps; loc= _} -> | Move {reg_exps; loc= _} ->
IArray.fold reg_exps ~init ~f:(fun acc (_reg, exp) -> f acc exp) IArray.fold ~f:(fun (_reg, exp) -> f exp) reg_exps s
| Load {reg= _; ptr; len; loc= _} -> f (f init ptr) len | Load {reg= _; ptr; len; loc= _} -> f len (f ptr s)
| Store {ptr; exp; len; loc= _} -> f (f (f init ptr) exp) len | Store {ptr; exp; len; loc= _} -> f len (f exp (f ptr s))
| Memset {dst; byt; len; loc= _} -> f (f (f init dst) byt) len | Memset {dst; byt; len; loc= _} -> f len (f byt (f dst s))
| Memcpy {dst; src; len; loc= _} | Memmov {dst; src; len; loc= _} -> | Memcpy {dst; src; len; loc= _} | Memmov {dst; src; len; loc= _} ->
f (f (f init dst) src) len f len (f src (f dst s))
| Alloc {reg= _; num; len= _; loc= _} -> f init num | Alloc {reg= _; num; len= _; loc= _} -> f num s
| Free {ptr; loc= _} -> f init ptr | Free {ptr; loc= _} -> f ptr s
| Nondet {reg= _; msg= _; loc= _} -> init | Nondet {reg= _; msg= _; loc= _} -> s
| Abort {loc= _} -> init | Abort {loc= _} -> s
end end
(** Jumps *) (** Jumps *)
@ -416,34 +414,27 @@ module Func = struct
| {entry= {cmnd; term= Unreachable; _}; _} -> IArray.is_empty cmnd | {entry= {cmnd; term= Unreachable; _}; _} -> IArray.is_empty cmnd
| _ -> false | _ -> false
let fold_cfg ~init ~f func = let fold_cfg ~f func s =
let seen = BlockS.create 0 in let seen = BlockS.create 0 in
let rec fold_cfg_ s blk = let rec fold_cfg_ blk s =
if not (BlockS.add seen blk) then s if not (BlockS.add seen blk) then s
else else
let s = let s =
let f s j = fold_cfg_ s j.dst in let f j s = fold_cfg_ j.dst s in
match blk.term with match blk.term with
| Switch {tbl; els; _} -> | Switch {tbl; els; _} ->
let s = IArray.fold ~f:(fun s (_, j) -> f s j) ~init:s tbl in let s = IArray.fold ~f:(fun (_, j) -> f j) tbl s in
f s els f els s
| Iswitch {tbl; _} -> IArray.fold ~f ~init:s tbl | Iswitch {tbl; _} -> IArray.fold ~f tbl s
| Call {return; throw; _} -> | Call {return; throw; _} -> Option.fold ~f throw (f return s)
let s = f s return in
Option.fold ~f ~init:s throw
| Return _ | Throw _ | Unreachable -> s | Return _ | Throw _ | Unreachable -> s
in in
f s blk f blk s
in in
fold_cfg_ init func.entry fold_cfg_ func.entry s
let fold_term func ~init ~f = let iter_term func ~f = fold_cfg ~f:(fun blk () -> f blk.term) func ()
fold_cfg func ~init ~f:(fun s blk -> f s blk.term) let entry_cfg func = fold_cfg ~f:(fun blk cfg -> blk :: cfg) func []
let iter_term func ~f =
fold_cfg func ~init:() ~f:(fun () blk -> f blk.term)
let entry_cfg func = fold_cfg ~init:[] ~f:(fun cfg blk -> blk :: cfg) func
let pp fs func = let pp fs func =
let {name; formals; freturn; entry; _} = func in let {name; formals; freturn; entry; _} = func in
@ -489,13 +480,12 @@ module Func = struct
let mk ~(name : Global.t) ~formals ~freturn ~fthrow ~entry ~cfg = let mk ~(name : Global.t) ~formals ~freturn ~fthrow ~entry ~cfg =
let locals = let locals =
let locals_cmnd locals cmnd = let locals_cmnd locals cmnd =
IArray.fold_right ~f:Inst.union_locals cmnd ~init:locals IArray.fold_right ~f:Inst.union_locals cmnd locals
in in
let locals_block locals block = let locals_block block locals =
locals_cmnd (Term.union_locals block.term locals) block.cmnd locals_cmnd (Term.union_locals block.term locals) block.cmnd
in in
let init = locals_block Reg.Set.empty entry in IArray.fold ~f:locals_block cfg (locals_block entry Reg.Set.empty)
IArray.fold ~f:locals_block cfg ~init
in in
let func = {name; formals; freturn; fthrow; locals; entry} in let func = {name; formals; freturn; fthrow; locals; entry} in
let resolve_parent_and_jumps block = let resolve_parent_and_jumps block =
@ -535,13 +525,14 @@ let set_derived_metadata functions =
String.Map.iter functions ~f:(fun func -> String.Map.iter functions ~f:(fun func ->
FuncQ.enqueue_back_exn roots func.name.reg func ) ; FuncQ.enqueue_back_exn roots func.name.reg func ) ;
String.Map.iter functions ~f:(fun func -> String.Map.iter functions ~f:(fun func ->
Func.fold_term func ~init:() ~f:(fun () -> function Func.iter_term func ~f:(fun term ->
| Call {callee; _} -> ( match term with
match Reg.of_exp callee with | Call {callee; _} -> (
| Some callee -> match Reg.of_exp callee with
FuncQ.remove roots callee |> (ignore : [> ] -> unit) | Some callee ->
| None -> () ) FuncQ.remove roots callee |> (ignore : [> ] -> unit)
| _ -> () ) ) ; | None -> () )
| _ -> () ) ) ;
roots roots
in in
let topsort functions roots = let topsort functions roots =
@ -588,7 +579,7 @@ let set_derived_metadata functions =
index := !index - 1 ) index := !index - 1 )
in in
let functions = let functions =
List.fold functions ~init:String.Map.empty ~f:(fun m func -> List.fold functions String.Map.empty ~f:(fun func m ->
String.Map.add_exn ~key:(Reg.name func.name.reg) ~data:func m ) String.Map.add_exn ~key:(Reg.name func.name.reg) ~data:func m )
in in
let roots = compute_roots functions in let roots = compute_roots functions in

@ -121,7 +121,7 @@ module Inst : sig
val abort : loc:Loc.t -> inst val abort : loc:Loc.t -> inst
val loc : inst -> Loc.t val loc : inst -> Loc.t
val locals : inst -> Reg.Set.t val locals : inst -> Reg.Set.t
val fold_exps : inst -> init:'a -> f:('a -> Exp.t -> 'a) -> 'a val fold_exps : inst -> 's -> f:(Exp.t -> 's -> 's) -> 's
end end
module Jump : sig module Jump : sig

@ -17,9 +17,7 @@ let reg r =
Var.program ~name ~global Var.program ~name ~global
let regs rs = let regs rs =
Llair.Reg.Set.fold Llair.Reg.Set.fold ~f:(fun r -> Var.Set.add (reg r)) rs Var.Set.empty
~f:(fun r -> Var.Set.add (reg r))
rs ~init:Var.Set.empty
let uap0 f = T.apply f [||] let uap0 f = T.apply f [||]
let uap1 f a = T.apply f [|a|] let uap1 f a = T.apply f [|a|]

@ -25,9 +25,9 @@ let interpreted e = equal_kind (classify e) Interpreted
let non_interpreted e = not (interpreted e) let non_interpreted e = not (interpreted e)
let uninterpreted e = equal_kind (classify e) Uninterpreted let uninterpreted e = equal_kind (classify e) Uninterpreted
let rec fold_max_solvables e ~init ~f = let rec fold_max_solvables e s ~f =
if non_interpreted e then f e init if non_interpreted e then f e s
else Term.fold e ~init ~f:(fun d s -> fold_max_solvables ~f d ~init:s) else Term.fold ~f:(fold_max_solvables ~f) e s
(** Solution Substitutions *) (** Solution Substitutions *)
module Subst : sig module Subst : sig
@ -40,7 +40,7 @@ module Subst : sig
val length : t -> int val length : t -> int
val mem : Term.t -> t -> bool val mem : Term.t -> t -> bool
val find : Term.t -> t -> Term.t option val find : Term.t -> t -> Term.t option
val fold : t -> init:'a -> f:(key:Term.t -> data:Term.t -> 'a -> 'a) -> 'a val fold : t -> 's -> f:(key:Term.t -> data:Term.t -> 's -> 's) -> 's
val iteri : t -> f:(key:Term.t -> data:Term.t -> unit) -> unit val iteri : t -> f:(key:Term.t -> data:Term.t -> unit) -> unit
val for_alli : t -> f:(key:Term.t -> data:Term.t -> bool) -> bool val for_alli : t -> f:(key:Term.t -> data:Term.t -> bool) -> bool
val apply : t -> Term.t -> Term.t val apply : t -> Term.t -> Term.t
@ -116,13 +116,13 @@ end = struct
(** remove entries for vars *) (** remove entries for vars *)
let remove xs s = let remove xs s =
Var.Set.fold ~f:(fun x s -> Term.Map.remove (Term.var x) s) xs ~init:s Var.Set.fold ~f:(fun x -> Term.Map.remove (Term.var x)) xs s
(** map over a subst, applying [f] to both domain and range, requires that (** map over a subst, applying [f] to both domain and range, requires that
[f] is injective and for any set of terms [E], [f\[E\]] is disjoint [f] is injective and for any set of terms [E], [f\[E\]] is disjoint
from [E] *) from [E] *)
let map_entries ~f s = let map_entries ~f s =
Term.Map.fold s ~init:s ~f:(fun ~key ~data s -> Term.Map.fold s s ~f:(fun ~key ~data s ->
let key' = f key in let key' = f key in
let data' = f data in let data' = f data in
if Term.equal key' key then if Term.equal key' key then
@ -159,7 +159,7 @@ end = struct
valid, so loop until no change. *) valid, so loop until no change. *)
let rec partition_valid_ t ks s = let rec partition_valid_ t ks s =
let t', ks', s' = let t', ks', s' =
Term.Map.fold s ~init:(t, ks, s) ~f:(fun ~key ~data (t, ks, s) -> Term.Map.fold s (t, ks, s) ~f:(fun ~key ~data (t, ks, s) ->
if is_valid_eq ks key data then (t, ks, s) if is_valid_eq ks key data then (t, ks, s)
else else
let t = Term.Map.add ~key ~data t let t = Term.Map.add ~key ~data t
@ -244,8 +244,8 @@ let rec solve_extract ?f a o l b s =
(* α₀^…^αᵢ^αⱼ^…^αᵥ = β ==> |α₀^…^αᵥ| = |β| ∧ … ∧ αⱼ = β[n₀+…+nᵢ,nⱼ) ∧ … (* α₀^…^αᵢ^αⱼ^…^αᵥ = β ==> |α₀^…^αᵥ| = |β| ∧ … ∧ αⱼ = β[n₀+…+nᵢ,nⱼ) ∧ …
where n |α| and m = |β| *) where n |α| and m = |β| *)
and solve_concat ?f a0V b m s = and solve_concat ?f a0V b m s =
Iter.fold_until (IArray.to_iter a0V) ~init:(s, Term.zero) Iter.fold_until (IArray.to_iter a0V) (s, Term.zero)
~f:(fun (s, oI) aJ -> ~f:(fun aJ (s, oI) ->
let nJ = Term.seq_size_exn aJ in let nJ = Term.seq_size_exn aJ in
let oJ = Term.add oI nJ in let oJ = Term.add oI nJ in
match solve_ ?f aJ (Term.extract ~seq:b ~off:oI ~len:nJ) s with match solve_ ?f aJ (Term.extract ~seq:b ~off:oI ~len:nJ) s with
@ -344,7 +344,7 @@ let classes r =
if Term.equal key data then cls if Term.equal key data then cls
else Term.Map.add_multi ~key:data ~data:key cls else Term.Map.add_multi ~key:data ~data:key cls
in in
Subst.fold r.rep ~init:Term.Map.empty ~f:(fun ~key ~data cls -> Subst.fold r.rep Term.Map.empty ~f:(fun ~key ~data cls ->
match classify key with match classify key with
| Interpreted | Atomic -> add key data cls | Interpreted | Atomic -> add key data cls
| Uninterpreted -> add (Term.map ~f:(Subst.apply r.rep) key) data cls ) | Uninterpreted -> add (Term.map ~f:(Subst.apply r.rep) key) data cls )
@ -477,12 +477,12 @@ let rec extend_ a r =
match (a : Term.t) with match (a : Term.t) with
| Integer _ | Rational _ -> r | Integer _ | Rational _ -> r
| _ -> ( | _ -> (
if interpreted a then Term.fold ~f:extend_ a ~init:r if interpreted a then Term.fold ~f:extend_ a r
else else
(* add uninterpreted terms *) (* add uninterpreted terms *)
match Subst.extend a r with match Subst.extend a r with
(* and their subterms if newly added *) (* and their subterms if newly added *)
| Some r -> Term.fold ~f:extend_ a ~init:r | Some r -> Term.fold ~f:extend_ a r
| None -> r ) | None -> r )
(** add a term to the carrier *) (** add a term to the carrier *)
@ -572,26 +572,22 @@ let class_of r e =
let e' = normalize r e in let e' = normalize r e in
e' :: Term.Map.find_multi e' (classes r) e' :: Term.Map.find_multi e' (classes r)
let fold_uses_of r t ~init ~f = let fold_uses_of r t s ~f =
let rec fold_ e ~init:s ~f = let rec fold_ e s ~f =
let s = let s =
Term.fold e ~init:s ~f:(fun sub s -> Term.fold e s ~f:(fun sub s -> if Term.equal t sub then f s e else s)
if Term.equal t sub then f s e else s )
in in
if interpreted e then if interpreted e then Term.fold ~f:(fold_ ~f) e s else s
Term.fold e ~init:s ~f:(fun d s -> fold_ ~f d ~init:s)
else s
in in
Subst.fold r.rep ~init ~f:(fun ~key:trm ~data:rep s -> Subst.fold r.rep s ~f:(fun ~key:trm ~data:rep s ->
let f trm s = fold_ trm ~init:s ~f in fold_ ~f trm (fold_ ~f rep s) )
f trm (f rep s) )
let apply_subst us s r = let apply_subst us s r =
[%Trace.call fun {pf} -> pf "%a@ %a" Subst.pp s pp r] [%Trace.call fun {pf} -> pf "%a@ %a" Subst.pp s pp r]
; ;
Term.Map.fold (classes r) ~init:true_ ~f:(fun ~key:rep ~data:cls r -> Term.Map.fold (classes r) true_ ~f:(fun ~key:rep ~data:cls r ->
let rep' = Subst.subst s rep in let rep' = Subst.subst s rep in
List.fold cls ~init:r ~f:(fun r trm -> List.fold cls r ~f:(fun trm r ->
let trm' = Subst.subst s trm in let trm' = Subst.subst s trm in
and_eq_ us trm' rep' r ) ) and_eq_ us trm' rep' r ) )
|> extract_xs |> extract_xs
@ -609,8 +605,7 @@ let and_ us r s =
let s, r = let s, r =
if Subst.length s.rep <= Subst.length r.rep then (s, r) else (r, s) if Subst.length s.rep <= Subst.length r.rep then (s, r) else (r, s)
in in
Subst.fold s.rep ~init:r ~f:(fun ~key:e ~data:e' r -> and_eq_ us e e' r) Subst.fold s.rep r ~f:(fun ~key:e ~data:e' r -> and_eq_ us e e' r) )
)
|> extract_xs |> extract_xs
|> |>
[%Trace.retn fun {pf} (_, r') -> [%Trace.retn fun {pf} (_, r') ->
@ -624,10 +619,10 @@ let or_ us r s =
else if not r.sat then s else if not r.sat then s
else else
let merge_mems rs r s = let merge_mems rs r s =
Term.Map.fold (classes s) ~init:rs ~f:(fun ~key:rep ~data:cls rs -> Term.Map.fold (classes s) rs ~f:(fun ~key:rep ~data:cls rs ->
List.fold cls List.fold cls
~init:([rep], rs) ([rep], rs)
~f:(fun (reps, rs) exp -> ~f:(fun exp (reps, rs) ->
match match
List.find ~f:(fun rep -> implies r (Term.eq exp rep)) reps List.find ~f:(fun rep -> implies r (Term.eq exp rep)) reps
with with
@ -648,13 +643,13 @@ let or_ us r s =
let orN us rs = let orN us rs =
match rs with match rs with
| [] -> (us, false_) | [] -> (us, false_)
| r :: rs -> List.fold ~f:(fun (us, s) r -> or_ us s r) ~init:(us, r) rs | r :: rs -> List.fold ~f:(fun r (us, s) -> or_ us s r) rs (us, r)
let rec and_term_ us e r = let rec and_term_ us e r =
let eq_false b r = and_eq_ us b Term.false_ r in let eq_false b r = and_eq_ us b Term.false_ r in
match (e : Term.t) with match (e : Term.t) with
| Integer {data} -> if Z.is_false data then false_ else r | Integer {data} -> if Z.is_false data then false_ else r
| And cs -> Term.Set.fold ~f:(and_term_ us) cs ~init:r | And cs -> Term.Set.fold ~f:(and_term_ us) cs r
| Ap2 (Eq, a, b) -> and_eq_ us a b r | Ap2 (Eq, a, b) -> and_eq_ us a b r
| Ap2 (Xor, Integer {data}, a) when Z.is_true data -> eq_false a r | Ap2 (Xor, Integer {data}, a) when Z.is_true data -> eq_false a r
| Ap2 (Xor, a, Integer {data}) when Z.is_true data -> eq_false a r | Ap2 (Xor, a, Integer {data}) when Z.is_true data -> eq_false a r
@ -688,11 +683,10 @@ let rename r sub =
pf "%a" pp_diff (r, r') ; pf "%a" pp_diff (r, r') ;
invariant r'] invariant r']
let fold_terms r ~init ~f = let fold_terms r z ~f =
Subst.fold r.rep ~f:(fun ~key ~data z -> f (f z data) key) ~init Subst.fold ~f:(fun ~key ~data z -> f key (f data z)) r.rep z
let fold_vars r ~init ~f = let fold_vars r z ~f = fold_terms ~f:(Term.fold_vars ~f) r z
fold_terms r ~init ~f:(fun init -> Term.fold_vars ~f ~init)
(** Existential Witnessing and Elimination *) (** Existential Witnessing and Elimination *)
@ -721,7 +715,7 @@ let solve_poly_eq us p' q' subst =
; ;
let diff = Term.sub p' q' in let diff = Term.sub p' q' in
let max_solvables_not_ito_us = let max_solvables_not_ito_us =
fold_max_solvables diff ~init:Zero ~f:(fun solvable_subterm -> function fold_max_solvables diff Zero ~f:(fun solvable_subterm -> function
| Many -> Many | Many -> Many
| zom when Var.Set.subset (Term.fv solvable_subterm) ~of_:us -> zom | zom when Var.Set.subset (Term.fv solvable_subterm) ~of_:us -> zom
| One _ -> Many | One _ -> Many
@ -833,8 +827,8 @@ let solve_uninterp_eqs us (cls, subst) =
[%compare: kind * Term.t] (classify e, e) (classify f, f) [%compare: kind * Term.t] (classify e, e) (classify f, f)
in in
let {rep_us; cls_us; rep_xs; cls_xs} = let {rep_us; cls_us; rep_xs; cls_xs} =
List.fold cls ~init:{rep_us= None; cls_us= []; rep_xs= None; cls_xs= []} List.fold cls {rep_us= None; cls_us= []; rep_xs= None; cls_xs= []}
~f:(fun ({rep_us; cls_us; rep_xs; cls_xs} as s) trm -> ~f:(fun trm ({rep_us; cls_us; rep_xs; cls_xs} as s) ->
if Var.Set.subset (Term.fv trm) ~of_:us then if Var.Set.subset (Term.fv trm) ~of_:us then
match rep_us with match rep_us with
| Some rep when compare rep trm <= 0 -> | Some rep when compare rep trm <= 0 ->
@ -867,7 +861,7 @@ let solve_uninterp_eqs us (cls, subst) =
| None -> (cls, cls_xs) | None -> (cls, cls_xs)
in in
let subst = let subst =
List.fold cls_xs ~init:subst ~f:(fun subst trm_xs -> List.fold cls_xs subst ~f:(fun trm_xs subst ->
Subst.compose1 ~key:trm_xs ~data:rep_us subst ) Subst.compose1 ~key:trm_xs ~data:rep_us subst )
in in
(cls, subst) (cls, subst)
@ -876,7 +870,7 @@ let solve_uninterp_eqs us (cls, subst) =
| Some rep_xs -> | Some rep_xs ->
let cls = rep_xs :: cls_us in let cls = rep_xs :: cls_us in
let subst = let subst =
List.fold cls_xs ~init:subst ~f:(fun subst trm_xs -> List.fold cls_xs subst ~f:(fun trm_xs subst ->
Subst.compose1 ~key:trm_xs ~data:rep_xs subst ) Subst.compose1 ~key:trm_xs ~data:rep_xs subst )
in in
(cls, subst) (cls, subst)
@ -920,9 +914,9 @@ let solve_concat_extracts_eq r x =
[%Trace.call fun {pf} -> pf "%a@ %a" Term.pp x pp r] [%Trace.call fun {pf} -> pf "%a@ %a" Term.pp x pp r]
; ;
let uses = let uses =
fold_uses_of r x ~init:[] ~f:(fun uses -> function fold_uses_of r x [] ~f:(fun uses -> function
| Ap2 (Sized, _, _) as m -> | Ap2 (Sized, _, _) as m ->
fold_uses_of r m ~init:uses ~f:(fun uses -> function fold_uses_of r m uses ~f:(fun uses -> function
| Ap3 (Extract, _, _, _) as e -> e :: uses | _ -> uses ) | Ap3 (Extract, _, _, _) as e -> e :: uses | _ -> uses )
| _ -> uses ) | _ -> uses )
in in
@ -933,8 +927,8 @@ let solve_concat_extracts_eq r x =
| _ -> false ) | _ -> false )
in in
let rec find_extracts full_rev_extracts rev_prefix off = let rec find_extracts full_rev_extracts rev_prefix off =
List.fold (find_extracts_at_off off) ~init:full_rev_extracts List.fold (find_extracts_at_off off) full_rev_extracts
~f:(fun full_rev_extracts e -> ~f:(fun e full_rev_extracts ->
match e with match e with
| Ap3 (Extract, Ap2 (Sized, n, _), o, l) -> | Ap3 (Extract, Ap2 (Sized, n, _), o, l) ->
let o_l = Term.add o l in let o_l = Term.add o l in
@ -951,10 +945,9 @@ let solve_concat_extracts_eq r x =
let solve_concat_extracts r us x (classes, subst, us_xs) = let solve_concat_extracts r us x (classes, subst, us_xs) =
match match
List.filter_map (solve_concat_extracts_eq r x) ~f:(fun rev_extracts -> List.filter_map (solve_concat_extracts_eq r x) ~f:(fun rev_extracts ->
Iter.fold_opt (Iter.of_list rev_extracts) ~init:[] Iter.fold_opt (Iter.of_list rev_extracts) [] ~f:(fun e suffix ->
~f:(fun suffix e ->
let+ rep_ito_us = let+ rep_ito_us =
List.fold (cls_of r e) ~init:None ~f:(fun rep_ito_us trm -> List.fold (cls_of r e) None ~f:(fun trm rep_ito_us ->
match rep_ito_us with match rep_ito_us with
| Some rep when Term.compare rep trm <= 0 -> rep_ito_us | Some rep when Term.compare rep trm <= 0 -> rep_ito_us
| _ when Var.Set.subset (Term.fv trm) ~of_:us -> Some trm | _ when Var.Set.subset (Term.fv trm) ~of_:us -> Some trm
@ -970,9 +963,8 @@ let solve_concat_extracts r us x (classes, subst, us_xs) =
(classes, subst, us_xs) (classes, subst, us_xs)
| None -> (classes, subst, us_xs) | None -> (classes, subst, us_xs)
let solve_for_xs r us xs (classes, subst, us_xs) = let solve_for_xs r us xs =
Var.Set.fold xs ~init:(classes, subst, us_xs) Var.Set.fold xs ~f:(fun x (classes, subst, us_xs) ->
~f:(fun x (classes, subst, us_xs) ->
let x = Term.var x in let x = Term.var x in
if Subst.mem x subst then (classes, subst, us_xs) if Subst.mem x subst then (classes, subst, us_xs)
else solve_concat_extracts r us x (classes, subst, us_xs) ) else solve_concat_extracts r us x (classes, subst, us_xs) )
@ -980,14 +972,13 @@ let solve_for_xs r us xs (classes, subst, us_xs) =
(** move equations from [classes] to [subst] which can be expressed, after (** move equations from [classes] to [subst] which can be expressed, after
normalizing with [subst], as [x u] where [us xs fv x us] normalizing with [subst], as [x u] where [us xs fv x us]
and [fv u us] or else [fv u us xs]. *) and [fv u us] or else [fv u us xs]. *)
let solve_classes r (classes, subst, us) xs = let solve_classes r xs (classes, subst, us) =
[%Trace.call fun {pf} -> [%Trace.call fun {pf} ->
pf "us: {@[%a@]}@ xs: {@[%a@]}" Var.Set.pp us Var.Set.pp xs] pf "us: {@[%a@]}@ xs: {@[%a@]}" Var.Set.pp us Var.Set.pp xs]
; ;
let rec solve_classes_ (classes0, subst0, us_xs) = let rec solve_classes_ (classes0, subst0, us_xs) =
let classes, subst = let classes, subst =
Term.Map.fold ~f:(solve_class us us_xs) classes0 Term.Map.fold ~f:(solve_class us us_xs) classes0 (classes0, subst0)
~init:(classes0, subst0)
in in
if subst != subst0 then solve_classes_ (classes, subst, us_xs) if subst != subst0 then solve_classes_ (classes, subst, us_xs)
else (classes, subst, us_xs) else (classes, subst, us_xs)
@ -1018,8 +1009,7 @@ let solve_for_vars vss r =
let us, vss = let us, vss =
match vss with us :: vss -> (us, vss) | [] -> (Var.Set.empty, vss) match vss with us :: vss -> (us, vss) | [] -> (Var.Set.empty, vss)
in in
List.fold ~f:(solve_classes r) ~init:(classes r, Subst.empty, us) vss List.fold ~f:(solve_classes r) vss (classes r, Subst.empty, us) |> snd3
|> snd3
|> |>
[%Trace.retn fun {pf} subst -> [%Trace.retn fun {pf} subst ->
pf "%a" Subst.pp subst ; pf "%a" Subst.pp subst ;
@ -1029,8 +1019,8 @@ let solve_for_vars vss r =
|| fail "@[%a@ = %a@ not entailed by@ @[%a@]@]" Term.pp key || fail "@[%a@ = %a@ not entailed by@ @[%a@]@]" Term.pp key
Term.pp data pp_classes r () ) ; Term.pp data pp_classes r () ) ;
assert ( assert (
Iter.fold_until (Iter.of_list vss) ~init:us Iter.fold_until (Iter.of_list vss) us
~f:(fun us xs -> ~f:(fun xs us ->
let us_xs = Var.Set.union us xs in let us_xs = Var.Set.union us xs in
let ks = Term.fv key in let ks = Term.fv key in
let ds = Term.fv data in let ds = Term.fv data in

@ -65,7 +65,7 @@ val normalize : t -> Term.t -> Term.t
relation, where [e'] and its subterms are expressed in terms of the relation, where [e'] and its subterms are expressed in terms of the
relation's canonical representatives of each equivalence class. *) relation's canonical representatives of each equivalence class. *)
val fold_vars : t -> init:'a -> f:('a -> Var.t -> 'a) -> 'a val fold_vars : t -> 's -> f:(Var.t -> 's -> 's) -> 's
(** Solution Substitutions *) (** Solution Substitutions *)
module Subst : sig module Subst : sig
@ -73,7 +73,7 @@ module Subst : sig
val pp : t pp val pp : t pp
val is_empty : t -> bool val is_empty : t -> bool
val fold : t -> init:'a -> f:(key:Term.t -> data:Term.t -> 'a -> 'a) -> 'a val fold : t -> 's -> f:(key:Term.t -> data:Term.t -> 's -> 's) -> 's
val subst : t -> Term.t -> Term.t val subst : t -> Term.t -> Term.t
(** Apply a substitution recursively to subterms. *) (** Apply a substitution recursively to subterms. *)

@ -373,9 +373,7 @@ module Sum = struct
| _ -> Qset.add term coeff sum | _ -> Qset.add term coeff sum
let of_ ?(coeff = Q.one) term = add coeff term empty let of_ ?(coeff = Q.one) term = add coeff term empty
let map sum ~f = Qset.fold ~f:(fun e c sum -> add c (f e) sum) sum empty
let map sum ~f =
Qset.fold sum ~init:empty ~f:(fun e c sum -> add c (f e) sum)
let mul_const const sum = let mul_const const sum =
assert (not (Q.equal Q.zero const)) ; assert (not (Q.equal Q.zero const)) ;
@ -436,7 +434,7 @@ let rec simp_add_ es poly =
(* (c₁ × X₁) + X₂ ==> ∑ᵢ₌₁² cᵢ × Xᵢ for c₂ = 1 *) (* (c₁ × X₁) + X₂ ==> ∑ᵢ₌₁² cᵢ × Xᵢ for c₂ = 1 *)
| _ -> Sum.to_term (Sum.add coeff term (Sum.of_ poly)) | _ -> Sum.to_term (Sum.add coeff term (Sum.of_ poly))
in in
Qset.fold ~f es ~init:poly Qset.fold ~f es poly
and simp_mul2 e f = and simp_mul2 e f =
match (e, f) with match (e, f) with
@ -526,7 +524,7 @@ let simp_mul es =
if Q.equal Q.zero pwr then term if Q.equal Q.zero pwr then term
else mul_pwr bas Q.(pwr - one) (simp_mul2 bas term) else mul_pwr bas Q.(pwr - one) (simp_mul2 bas term)
in in
Qset.fold es ~init:one ~f:(fun bas pwr term -> Qset.fold es one ~f:(fun bas pwr term ->
if Q.sign pwr >= 0 then mul_pwr bas pwr term if Q.sign pwr >= 0 then mul_pwr bas pwr term
else simp_div term (mul_pwr bas (Q.neg pwr) one) ) else simp_div term (mul_pwr bas (Q.neg pwr) one) )
@ -569,7 +567,7 @@ let rec simp_and2 x y =
let add s = function And cs -> Set.union s cs | c -> Set.add c s in let add s = function And cs -> Set.union s cs | c -> Set.add c s in
And (add (add Set.empty x) y) And (add (add Set.empty x) y)
let simp_and xs = Set.fold xs ~init:true_ ~f:simp_and2 let simp_and xs = Set.fold ~f:simp_and2 xs true_
let rec simp_or2 x y = let rec simp_or2 x y =
match (x, y) with match (x, y) with
@ -590,15 +588,16 @@ let rec simp_or2 x y =
let add s = function Or cs -> Set.union s cs | c -> Set.add c s in let add s = function Or cs -> Set.union s cs | c -> Set.add c s in
Or (add (add Set.empty x) y) Or (add (add Set.empty x) y)
let simp_or xs = Set.fold xs ~init:false_ ~f:simp_or2 let simp_or xs = Set.fold ~f:simp_or2 xs false_
(* sequence sizes *) (* sequence sizes *)
let rec seq_size_exn = function let rec seq_size_exn = function
| Ap2 (Sized, n, _) | Ap3 (Extract, _, _, n) -> n | Ap2 (Sized, n, _) | Ap3 (Extract, _, _, n) -> n
| ApN (Concat, a0U) -> | ApN (Concat, a0U) ->
IArray.fold a0U ~init:zero ~f:(fun a0I aJ -> IArray.fold
simp_add2 a0I (seq_size_exn aJ) ) ~f:(fun aJ a0I -> simp_add2 a0I (seq_size_exn aJ))
a0U zero
| _ -> invalid_arg "seq_size_exn" | _ -> invalid_arg "seq_size_exn"
let seq_size e = try Some (seq_size_exn e) with Invalid_argument _ -> None let seq_size e = try Some (seq_size_exn e) with Invalid_argument _ -> None
@ -666,11 +665,11 @@ let rec simp_extract seq off len =
| ApN (Concat, na1N) -> ( | ApN (Concat, na1N) -> (
match len with match len with
| Integer {data= l} -> | Integer {data= l} ->
IArray.fold_map_until na1N ~init:(l, off) IArray.fold_map_until na1N (l, off)
~f:(fun (l, oI) naI -> ~f:(fun naI (l, oI) ->
let nI = seq_size_exn naI in let nI = seq_size_exn naI in
if Z.equal Z.zero l then if Z.equal Z.zero l then
`Continue ((l, oI), simp_extract naI oI zero) `Continue (simp_extract naI oI zero, (l, oI))
else else
let oI_nI = simp_sub oI nI in let oI_nI = simp_sub oI nI in
match oI_nI with match oI_nI with
@ -678,9 +677,9 @@ let rec simp_extract seq off len =
let oJ = if Z.sign data <= 0 then zero else oI_nI in let oJ = if Z.sign data <= 0 then zero else oI_nI in
let lI = Z.(max zero (min l (neg data))) in let lI = Z.(max zero (min l (neg data))) in
let l = Z.(l - lI) in let l = Z.(l - lI) in
`Continue ((l, oJ), simp_extract naI oI (integer lI)) `Continue (simp_extract naI oI (integer lI), (l, oJ))
| _ -> `Stop (Ap3 (Extract, seq, off, len)) ) | _ -> `Stop (Ap3 (Extract, seq, off, len)) )
~finish:(fun (_, e1N) -> simp_concat e1N) ~finish:(fun (e1N, _) -> simp_concat e1N)
| _ -> Ap3 (Extract, seq, off, len) ) | _ -> Ap3 (Extract, seq, off, len) )
(* α[o,l) *) (* α[o,l) *)
| _ -> Ap3 (Extract, seq, off, len) ) | _ -> Ap3 (Extract, seq, off, len) )
@ -697,7 +696,7 @@ and simp_concat xs =
in in
let concat_sub_Concat xs = let concat_sub_Concat xs =
IArray.concat IArray.concat
(IArray.fold_right xs ~init:[] ~f:(fun x s -> (IArray.fold_right xs [] ~f:(fun x s ->
match x with match x with
| ApN (Concat, ys) -> ys :: s | ApN (Concat, ys) -> ys :: s
| x -> IArray.of_array [|x|] :: s )) | x -> IArray.of_array [|x|] :: s ))
@ -1018,31 +1017,30 @@ let map e ~f =
| NegLit (sym, xs) -> mapN (simp_neglit sym) ~f xs | NegLit (sym, xs) -> mapN (simp_neglit sym) ~f xs
| Var _ | Integer _ | Rational _ | RecRecord _ -> e | Var _ | Integer _ | Rational _ | RecRecord _ -> e
let fold_map e ~init ~f = let fold_map e s0 ~f =
let s = ref init in let s = ref s0 in
let f x = let f x =
let s', x' = f !s x in let x', s' = f x !s in
s := s' ; s := s' ;
x' x'
in in
let e' = map e ~f in let e' = map e ~f in
(!s, e') (e', !s)
let rec map_rec_pre e ~f = let rec map_rec_pre e ~f =
match f e with Some e' -> e' | None -> map ~f:(map_rec_pre ~f) e match f e with Some e' -> e' | None -> map ~f:(map_rec_pre ~f) e
let rec fold_map_rec_pre e ~init:s ~f = let rec fold_map_rec_pre e s ~f =
match f s e with match f e s with
| Some (s, e') -> (s, e') | Some (e', s) -> (e', s)
| None -> fold_map ~f:(fun s e -> fold_map_rec_pre ~f ~init:s e) ~init:s e | None -> fold_map ~f:(fold_map_rec_pre ~f) e s
let disjuncts e = let disjuncts e =
let rec disjuncts_ e = let rec disjuncts_ e =
match e with match e with
| Or es -> | Or es ->
let e0, e1N = Set.pop_exn es in let e0, e1N = Set.pop_exn es in
Set.fold e1N ~init:(disjuncts_ e0) ~f:(fun e cs -> Set.fold ~f:(fun e -> Set.union (disjuncts_ e)) e1N (disjuncts_ e0)
Set.union cs (disjuncts_ e) )
| Ap3 (Conditional, cnd, thn, els) -> | Ap3 (Conditional, cnd, thn, els) ->
Set.add Set.add
(and_ (orN (disjuncts_ (not_ cnd))) (orN (disjuncts_ els))) (and_ (orN (disjuncts_ (not_ cnd))) (orN (disjuncts_ els)))
@ -1095,15 +1093,15 @@ let for_all e ~f =
| Add args | Mul args -> Qset.for_all ~f:(fun arg _ -> f arg) args | Add args | Mul args -> Qset.for_all ~f:(fun arg _ -> f arg) args
| Var _ | Integer _ | Rational _ | RecRecord _ -> true | Var _ | Integer _ | Rational _ | RecRecord _ -> true
let fold e ~init:s ~f = let fold e s ~f =
match e with match e with
| Ap1 (_, x) -> f x s | Ap1 (_, x) -> f x s
| Ap2 (_, x, y) -> f y (f x s) | Ap2 (_, x, y) -> f y (f x s)
| Ap3 (_, x, y, z) -> f z (f y (f x s)) | Ap3 (_, x, y, z) -> f z (f y (f x s))
| ApN (_, xs) | Apply (_, xs) | PosLit (_, xs) | NegLit (_, xs) -> | ApN (_, xs) | Apply (_, xs) | PosLit (_, xs) | NegLit (_, xs) ->
IArray.fold ~f:(fun s x -> f x s) xs ~init:s IArray.fold ~f xs s
| And args | Or args -> Set.fold ~f args ~init:s | And args | Or args -> Set.fold ~f args s
| Add args | Mul args -> Qset.fold ~f:(fun e _ s -> f e s) args ~init:s | Add args | Mul args -> Qset.fold ~f:(fun e _ s -> f e s) args s
| Var _ | Integer _ | Rational _ | RecRecord _ -> s | Var _ | Integer _ | Rational _ | RecRecord _ -> s
let rec iter_terms e ~f = let rec iter_terms e ~f =
@ -1124,21 +1122,19 @@ let rec iter_terms e ~f =
| Var _ | Integer _ | Rational _ | RecRecord _ -> () ) ; | Var _ | Integer _ | Rational _ | RecRecord _ -> () ) ;
f e f e
let rec fold_terms e ~init:s ~f = let rec fold_terms e s ~f =
let fold_terms f e s = fold_terms e ~init:s ~f in f e
let s = ( match e with
match e with | Ap1 (_, x) -> fold_terms ~f x s
| Ap1 (_, x) -> fold_terms f x s | Ap2 (_, x, y) -> fold_terms ~f y (fold_terms ~f x s)
| Ap2 (_, x, y) -> fold_terms f y (fold_terms f x s) | Ap3 (_, x, y, z) ->
| Ap3 (_, x, y, z) -> fold_terms f z (fold_terms f y (fold_terms f x s)) fold_terms ~f z (fold_terms ~f y (fold_terms ~f x s))
| ApN (_, xs) | Apply (_, xs) | PosLit (_, xs) | NegLit (_, xs) -> | ApN (_, xs) | Apply (_, xs) | PosLit (_, xs) | NegLit (_, xs) ->
IArray.fold ~f:(fun s x -> fold_terms f x s) xs ~init:s IArray.fold ~f:(fold_terms ~f) xs s
| And args | Or args -> Set.fold args ~init:s ~f:(fold_terms f) | And args | Or args -> Set.fold ~f:(fold_terms ~f) args s
| Add args | Mul args -> | Add args | Mul args ->
Qset.fold args ~init:s ~f:(fun arg _ s -> fold_terms f arg s) Qset.fold ~f:(fun arg _ -> fold_terms ~f arg) args s
| Var _ | Integer _ | Rational _ | RecRecord _ -> s | Var _ | Integer _ | Rational _ | RecRecord _ -> s )
in
f s e
let iter_vars e ~f = let iter_vars e ~f =
iter_terms ~f:(fun e -> Option.iter ~f (Var.of_term e)) e iter_terms ~f:(fun e -> Option.iter ~f (Var.of_term e)) e
@ -1146,12 +1142,12 @@ let iter_vars e ~f =
let exists_vars e ~f = let exists_vars e ~f =
Iter.exists ~f (Iter.from_labelled_iter (iter_vars e)) Iter.exists ~f (Iter.from_labelled_iter (iter_vars e))
let fold_vars e ~init ~f = let fold_vars e s ~f =
fold_terms ~f:(fun s e -> Option.fold ~f ~init:s (Var.of_term e)) ~init e fold_terms ~f:(fun e -> Option.fold ~f (Var.of_term e)) e s
(** Query *) (** Query *)
let fv e = fold_vars e ~f:(Fun.flip Var.Set.add) ~init:Var.Set.empty let fv e = fold_vars ~f:Var.Set.add e Var.Set.empty
let is_true = function Integer {data} -> Z.is_true data | _ -> false let is_true = function Integer {data} -> Z.is_true data | _ -> false
let is_false = function Integer {data} -> Z.is_false data | _ -> false let is_false = function Integer {data} -> Z.is_false data | _ -> false
@ -1166,11 +1162,9 @@ let rec height = function
| Ap2 (_, a, b) -> 1 + max (height a) (height b) | Ap2 (_, a, b) -> 1 + max (height a) (height b)
| Ap3 (_, a, b, c) -> 1 + max (height a) (max (height b) (height c)) | Ap3 (_, a, b, c) -> 1 + max (height a) (max (height b) (height c))
| ApN (_, v) | Apply (_, v) | PosLit (_, v) | NegLit (_, v) -> | ApN (_, v) | Apply (_, v) | PosLit (_, v) | NegLit (_, v) ->
1 + IArray.fold v ~init:0 ~f:(fun m a -> max m (height a)) 1 + IArray.fold ~f:(fun a m -> max m (height a)) v 0
| And bs | Or bs -> | And bs | Or bs -> 1 + Set.fold ~f:(fun a m -> max m (height a)) bs 0
1 + Set.fold bs ~init:0 ~f:(fun a m -> max m (height a)) | Add qs | Mul qs -> 1 + Qset.fold ~f:(fun a _ m -> max m (height a)) qs 0
| Add qs | Mul qs ->
1 + Qset.fold qs ~init:0 ~f:(fun a _ m -> max m (height a))
| Integer _ | Rational _ | RecRecord _ -> 0 | Integer _ | Rational _ | RecRecord _ -> 0
(** Solve *) (** Solve *)

@ -199,11 +199,8 @@ val map_rec_pre : t -> f:(t -> t option) -> t
to the subterms of [x], followed by rebuilding the term structure on the to the subterms of [x], followed by rebuilding the term structure on the
transformed subterms. *) transformed subterms. *)
val fold_map : t -> init:'a -> f:('a -> t -> 'a * t) -> 'a * t val fold_map : t -> 's -> f:(t -> 's -> t * 's) -> t * 's
val fold_map_rec_pre : t -> 's -> f:(t -> 's -> (t * 's) option) -> t * 's
val fold_map_rec_pre :
t -> init:'a -> f:('a -> t -> ('a * t) option) -> 'a * t
val disjuncts : t -> t list val disjuncts : t -> t list
val rename : (Var.t -> Var.t) -> t -> t val rename : (Var.t -> Var.t) -> t -> t
@ -211,9 +208,9 @@ val rename : (Var.t -> Var.t) -> t -> t
val iter : t -> f:(t -> unit) -> unit val iter : t -> f:(t -> unit) -> unit
val exists : t -> f:(t -> bool) -> bool val exists : t -> f:(t -> bool) -> bool
val fold : t -> init:'a -> f:(t -> 'a -> 'a) -> 'a val fold : t -> 'a -> f:(t -> 'a -> 'a) -> 'a
val fold_vars : t -> init:'a -> f:('a -> Var.t -> 'a) -> 'a val fold_vars : t -> 'a -> f:(Var.t -> 'a -> 'a) -> 'a
val fold_terms : t -> init:'a -> f:('a -> t -> 'a) -> 'a val fold_terms : t -> 'a -> f:(t -> 'a -> 'a) -> 'a
(** Query *) (** Query *)

@ -66,7 +66,7 @@ module Make (T : REPR) = struct
let invariant s = let invariant s =
let@ () = Invariant.invariant [%here] s [%sexp_of: t] in let@ () = Invariant.invariant [%here] s [%sexp_of: t] in
let domain, range = let domain, range =
Map.fold s ~init:(Set.empty, Set.empty) Map.fold s (Set.empty, Set.empty)
~f:(fun ~key ~data (domain, range) -> ~f:(fun ~key ~data (domain, range) ->
(* substs are injective *) (* substs are injective *)
assert (not (Set.mem range data)) ; assert (not (Set.mem range data)) ;
@ -84,8 +84,7 @@ module Make (T : REPR) = struct
else else
let wrt = Set.union wrt vs in let wrt = Set.union wrt vs in
let sub, rng, wrt = let sub, rng, wrt =
Set.fold dom ~init:(empty, Set.empty, wrt) Set.fold dom (empty, Set.empty, wrt) ~f:(fun x (sub, rng, wrt) ->
~f:(fun x (sub, rng, wrt) ->
let x', wrt = fresh (name x) ~wrt in let x', wrt = fresh (name x) ~wrt in
let sub = Map.add_exn ~key:x ~data:x' sub in let sub = Map.add_exn ~key:x ~data:x' sub in
let rng = Set.add x' rng in let rng = Set.add x' rng in
@ -94,24 +93,21 @@ module Make (T : REPR) = struct
({sub; dom; rng}, wrt) ) ({sub; dom; rng}, wrt) )
|> check (fun ({sub; _}, _) -> invariant sub) |> check (fun ({sub; _}, _) -> invariant sub)
let fold sub ~init ~f = let fold sub z ~f = Map.fold ~f:(fun ~key ~data -> f key data) sub z
Map.fold sub ~init ~f:(fun ~key ~data s -> f key data s)
let domain sub = let domain sub =
Map.fold sub ~init:Set.empty ~f:(fun ~key ~data:_ domain -> Map.fold ~f:(fun ~key ~data:_ -> Set.add key) sub Set.empty
Set.add key domain )
let range sub = let range sub =
Map.fold sub ~init:Set.empty ~f:(fun ~key:_ ~data range -> Map.fold ~f:(fun ~key:_ ~data -> Set.add data) sub Set.empty
Set.add data range )
let invert sub = let invert sub =
Map.fold sub ~init:empty ~f:(fun ~key ~data sub' -> Map.fold sub empty ~f:(fun ~key ~data sub' ->
Map.add_exn ~key:data ~data:key sub' ) Map.add_exn ~key:data ~data:key sub' )
|> check invariant |> check invariant
let restrict sub vs = let restrict sub vs =
Map.fold sub ~init:{sub; dom= Set.empty; rng= Set.empty} Map.fold sub {sub; dom= Set.empty; rng= Set.empty}
~f:(fun ~key ~data z -> ~f:(fun ~key ~data z ->
if Set.mem vs key then if Set.mem vs key then
{z with dom= Set.add key z.dom; rng= Set.add data z.rng} {z with dom= Set.add key z.dom; rng= Set.add data z.rng}

@ -64,7 +64,7 @@ module type VAR = sig
val is_empty : t -> bool val is_empty : t -> bool
val domain : t -> Set.t val domain : t -> Set.t
val range : t -> Set.t val range : t -> Set.t
val fold : t -> init:'a -> f:(var -> var -> 'a -> 'a) -> 'a val fold : t -> 's -> f:(var -> var -> 's -> 's) -> 's
val apply : t -> var -> var val apply : t -> var -> var
end end
end end

@ -68,24 +68,21 @@ let map ~f_sjn ~f_ctx ~f_trm ~f_fml ({us; xs= _; ctx; pure; heap; djns} as q)
then q then q
else {q with ctx; pure; heap; djns} else {q with ctx; pure; heap; djns}
let fold_terms_seg {loc; bas; len; siz; seq} ~init ~f = let fold_terms_seg {loc; bas; len; siz; seq} s ~f =
let f b s = f s b in f loc (f bas (f len (f siz (f seq s))))
f loc (f bas (f len (f siz (f seq init))))
let fold_vars_seg seg ~init ~f = let fold_vars_seg seg s ~f = fold_terms_seg ~f:(Term.fold_vars ~f) seg s
fold_terms_seg seg ~init ~f:(fun init -> Term.fold_vars ~f ~init)
let fold_vars_stem ?ignore_ctx ?ignore_pure let fold_vars_stem ?ignore_ctx ?ignore_pure
{us= _; xs= _; ctx; pure; heap; djns= _} ~init ~f = {us= _; xs= _; ctx; pure; heap; djns= _} s ~f =
let unless flag f init = if Option.is_some flag then init else f ~init in let unless flag f s = if Option.is_some flag then s else f s in
List.fold ~f:(fun init -> fold_vars_seg ~f ~init) heap ~init List.fold ~f:(fold_vars_seg ~f) heap s
|> unless ignore_pure (Formula.fold_vars ~f pure) |> unless ignore_pure (Formula.fold_vars ~f pure)
|> unless ignore_ctx (Context.fold_vars ~f ctx) |> unless ignore_ctx (Context.fold_vars ~f ctx)
let fold_vars ?ignore_ctx ?ignore_pure fold_vars q ~init ~f = let fold_vars ?ignore_ctx ?ignore_pure fold_vars q s ~f =
fold_vars_stem ?ignore_ctx ?ignore_pure ~init ~f q fold_vars_stem ?ignore_ctx ?ignore_pure ~f q s
|> fun init -> |> List.fold ~f:(List.fold ~f:fold_vars) q.djns
List.fold ~init q.djns ~f:(fun init -> List.fold ~init ~f:fold_vars)
(** Pretty-printing *) (** Pretty-printing *)
@ -98,13 +95,13 @@ let rec var_strength_ xs m q =
in in
let xs = Var.Set.union xs q.xs in let xs = Var.Set.union xs q.xs in
let m_stem = let m_stem =
fold_vars_stem ~ignore_ctx:() q ~init:m ~f:(fun m var -> fold_vars_stem ~ignore_ctx:() q m ~f:(fun var m ->
if not (Var.Set.mem xs var) then if not (Var.Set.mem xs var) then
Var.Map.add ~key:var ~data:`Universal m Var.Map.add ~key:var ~data:`Universal m
else add var m ) else add var m )
in in
let m = let m =
List.fold ~init:m_stem q.djns ~f:(fun m djn -> List.fold q.djns m_stem ~f:(fun djn m ->
let ms = List.map ~f:(fun dj -> snd (var_strength_ xs m dj)) djn in let ms = List.map ~f:(fun dj -> snd (var_strength_ xs m dj)) djn in
List.reduce ms ~f:(fun m1 m2 -> List.reduce ms ~f:(fun m1 m2 ->
Var.Map.union m1 m2 ~f:(fun _ s1 s2 -> Var.Map.union m1 m2 ~f:(fun _ s1 s2 ->
@ -118,8 +115,8 @@ let rec var_strength_ xs m q =
let var_strength ?(xs = Var.Set.empty) q = let var_strength ?(xs = Var.Set.empty) q =
let m = let m =
Var.Set.fold xs ~init:Var.Map.empty ~f:(fun x m -> Var.Set.fold xs Var.Map.empty ~f:(fun x ->
Var.Map.add ~key:x ~data:`Existential m ) Var.Map.add ~key:x ~data:`Existential )
in in
var_strength_ xs m q var_strength_ xs m q
@ -146,7 +143,7 @@ let pp_block x fs segs =
match Term.d_int len with match Term.d_int len with
| Some data -> ( | Some data -> (
match match
List.fold segs ~init:(Some Z.zero) ~f:(fun len seg -> List.fold segs (Some Z.zero) ~f:(fun seg len ->
match (len, Term.d_int seg.siz) with match (len, Term.d_int seg.siz) with
| Some len, Some data -> Some (Z.add len data) | Some len, Some data -> Some (Z.add len data)
| _ -> None ) | _ -> None )
@ -258,17 +255,15 @@ let pp_djn fs d =
let pp_raw fs q = let pp_raw fs q =
pp_ ?var_strength:None Var.Set.empty Var.Set.empty Context.empty fs q pp_ ?var_strength:None Var.Set.empty Var.Set.empty Context.empty fs q
let fv_seg seg = let fv_seg seg = fold_vars_seg ~f:Var.Set.add seg Var.Set.empty
fold_vars_seg seg ~f:(Fun.flip Var.Set.add) ~init:Var.Set.empty
let fv ?ignore_ctx ?ignore_pure q = let fv ?ignore_ctx ?ignore_pure q =
let rec fv_union init q = let rec fv_union q s =
Var.Set.diff Var.Set.diff
(fold_vars ?ignore_ctx ?ignore_pure fv_union q ~init (fold_vars ?ignore_ctx ?ignore_pure fv_union ~f:Var.Set.add q s)
~f:(Fun.flip Var.Set.add))
q.xs q.xs
in in
fv_union Var.Set.empty q fv_union q Var.Set.empty
let invariant_pure p = assert (not Formula.(equal ff p)) let invariant_pure p = assert (not Formula.(equal ff p))
let invariant_seg _ = () let invariant_seg _ = ()
@ -482,7 +477,7 @@ let star q1 q2 =
let starN = function let starN = function
| [] -> emp | [] -> emp
| [q] -> q | [q] -> q
| q :: qs -> List.fold ~f:star ~init:q qs | q :: qs -> List.fold ~f:star qs q
let or_ q1 q2 = let or_ q1 q2 =
[%Trace.call fun {pf} -> pf "(%a)@ (%a)" pp_raw q1 pp_raw q2] [%Trace.call fun {pf} -> pf "(%a)@ (%a)" pp_raw q1 pp_raw q2]
@ -514,13 +509,13 @@ let or_ q1 q2 =
let orN = function let orN = function
| [] -> false_ Var.Set.empty | [] -> false_ Var.Set.empty
| [q] -> q | [q] -> q
| q :: qs -> List.fold ~f:or_ ~init:q qs | q :: qs -> List.fold ~f:or_ qs q
let pure (p : Formula.t) = let pure (p : Formula.t) =
[%Trace.call fun {pf} -> pf "%a" Formula.pp p] [%Trace.call fun {pf} -> pf "%a" Formula.pp p]
; ;
Iter.fold (Context.dnf p) ~init:(false_ Var.Set.empty) Iter.fold (Context.dnf p) (false_ Var.Set.empty)
~f:(fun q (xs, pure, ctx) -> ~f:(fun (xs, pure, ctx) q ->
let us = Formula.fv pure in let us = Formula.fv pure in
if Context.is_unsat ctx then extend_us us q if Context.is_unsat ctx then extend_us us q
else or_ q (exists_fresh xs {emp with us; ctx; pure}) ) else or_ q (exists_fresh xs {emp with us; ctx; pure}) )
@ -536,7 +531,7 @@ let and_subst subst q =
; ;
Context.Subst.fold Context.Subst.fold
~f:(fun ~key ~data -> and_ (Formula.eq key data)) ~f:(fun ~key ~data -> and_ (Formula.eq key data))
subst ~init:q subst q
|> |>
[%Trace.retn fun {pf} q -> [%Trace.retn fun {pf} q ->
pf "%a" pp q ; pf "%a" pp q ;
@ -546,7 +541,7 @@ let subst sub q =
[%Trace.call fun {pf} -> pf "@[%a@]@ %a" Var.Subst.pp sub pp q] [%Trace.call fun {pf} -> pf "@[%a@]@ %a" Var.Subst.pp sub pp q]
; ;
let dom, eqs = let dom, eqs =
Var.Subst.fold sub ~init:(Var.Set.empty, Formula.tt) Var.Subst.fold sub (Var.Set.empty, Formula.tt)
~f:(fun var trm (dom, eqs) -> ~f:(fun var trm (dom, eqs) ->
( Var.Set.add var dom ( Var.Set.add var dom
, Formula.and_ (Formula.eq (Term.var var) (Term.var trm)) eqs ) ) , Formula.and_ (Formula.eq (Term.var var) (Term.var trm)) eqs ) )
@ -580,11 +575,9 @@ let rec is_empty q =
let rec pure_approx q = let rec pure_approx q =
Formula.andN Formula.andN
( [q.pure] ( [q.pure]
|> fun init -> |> List.fold q.heap ~f:(fun seg p -> Formula.dq0 seg.loc :: p)
List.fold ~init q.heap ~f:(fun p seg -> Formula.dq0 seg.loc :: p) |> List.fold q.djns ~f:(fun djn p ->
|> fun init -> Formula.orN (List.map djn ~f:pure_approx) :: p ) )
List.fold ~init q.djns ~f:(fun p djn ->
Formula.orN (List.map djn ~f:pure_approx) :: p ) )
let pure_approx q = let pure_approx q =
[%Trace.call fun {pf} -> pf "%a" pp q] [%Trace.call fun {pf} -> pf "%a" pp q]
@ -608,7 +601,7 @@ let fold_dnf ~conj ~disj sjn (xs, conjuncts) disjuncts =
and split_case pending_splits (xs, conjuncts) disjuncts = and split_case pending_splits (xs, conjuncts) disjuncts =
match Iter.pop pending_splits with match Iter.pop pending_splits with
| Some (split, pending_splits) -> | Some (split, pending_splits) ->
List.fold split ~init:disjuncts ~f:(fun disjuncts sjn -> List.fold split disjuncts ~f:(fun sjn disjuncts ->
add_disjunct pending_splits sjn (xs, conjuncts) disjuncts ) add_disjunct pending_splits sjn (xs, conjuncts) disjuncts )
| None -> disj (xs, conjuncts) disjuncts | None -> disj (xs, conjuncts) disjuncts
in in
@ -658,15 +651,15 @@ let rec freshen_nested_xs q =
(* trim xs to those that appear in the stem and sink the rest *) (* trim xs to those that appear in the stem and sink the rest *)
let fv_stem = fv {q with xs= Var.Set.empty; djns= []} in let fv_stem = fv {q with xs= Var.Set.empty; djns= []} in
let xs_sink, xs = Var.Set.diff_inter q.xs fv_stem in let xs_sink, xs = Var.Set.diff_inter q.xs fv_stem in
let xs_below, djns = let djns, xs_below =
List.fold_map ~init:Var.Set.empty q.djns ~f:(fun xs_below djn -> List.fold_map q.djns Var.Set.empty ~f:(fun djn xs_below ->
List.fold_map ~init:xs_below djn ~f:(fun xs_below dj -> List.fold_map djn xs_below ~f:(fun dj xs_below ->
(* quantify xs not in stem and freshen disjunct *) (* quantify xs not in stem and freshen disjunct *)
let dj' = let dj' =
freshen_nested_xs (exists (Var.Set.inter xs_sink dj.us) dj) freshen_nested_xs (exists (Var.Set.inter xs_sink dj.us) dj)
in in
let xs_below' = Var.Set.union xs_below dj'.xs in let xs_below' = Var.Set.union xs_below dj'.xs in
(xs_below', dj') ) ) (dj', xs_below') ) )
in in
(* rename xs to miss all xs in subformulas *) (* rename xs to miss all xs in subformulas *)
freshen_xs {q with xs; djns} ~wrt:(Var.Set.union q.us xs_below) freshen_xs {q with xs; djns} ~wrt:(Var.Set.union q.us xs_below)
@ -688,7 +681,7 @@ let rec propagate_context_ ancestor_vs ancestor_ctx q =
let ancestor_stem = and_ctx_ ancestor_ctx stem in let ancestor_stem = and_ctx_ ancestor_ctx stem in
let ancestor_ctx = ancestor_stem.ctx in let ancestor_ctx = ancestor_stem.ctx in
exists xs exists xs
(List.fold djns ~init:ancestor_stem ~f:(fun q' djn -> (List.fold djns ancestor_stem ~f:(fun djn q' ->
let dj_ctxs, djn = let dj_ctxs, djn =
List.rev_map_split djn ~f:(fun dj -> List.rev_map_split djn ~f:(fun dj ->
let dj = propagate_context_ ancestor_vs ancestor_ctx dj in let dj = propagate_context_ ancestor_vs ancestor_ctx dj in

@ -93,13 +93,11 @@ end = struct
let us = Option.value us ~default:Var.Set.empty in let us = Option.value us ~default:Var.Set.empty in
let us = let us =
Option.fold Option.fold
~f:(fun us sub -> Var.Set.union (Var.Set.diff sub.Sh.us xs) us) ~f:(fun sub -> Var.Set.union (Var.Set.diff sub.Sh.us xs))
sub ~init:us sub us
in in
let union_us q_opt us' = let union_us q_opt us' =
Option.fold Option.fold ~f:(fun q -> Var.Set.union q.Sh.us) q_opt us'
~f:(fun us' q -> Var.Set.union q.Sh.us us')
q_opt ~init:us'
in in
union_us com (union_us min us) union_us com (union_us min us)
in in
@ -655,8 +653,8 @@ let excise_dnf : Sh.t -> Var.Set.t -> Sh.t -> Sh.t option =
let dnf_subtrahend = Sh.dnf subtrahend in let dnf_subtrahend = Sh.dnf subtrahend in
Iter.fold_opt Iter.fold_opt
(Iter.of_list dnf_minuend) (Iter.of_list dnf_minuend)
~init:(Sh.false_ (Var.Set.union minuend.us xs)) (Sh.false_ (Var.Set.union minuend.us xs))
~f:(fun remainders minuend -> ~f:(fun minuend remainders ->
[%trace] [%trace]
~call:(fun {pf} -> pf "@[<2>minuend@ %a@]" Sh.pp minuend) ~call:(fun {pf} -> pf "@[<2>minuend@ %a@]" Sh.pp minuend)
~retn:(fun {pf} -> pf "%a" (Option.pp "%a" Sh.pp)) ~retn:(fun {pf} -> pf "%a" (Option.pp "%a" Sh.pp))

@ -42,9 +42,7 @@ let%test_module _ =
let z = Term.var z_ let z = Term.var z_
let of_eqs l = let of_eqs l =
List.fold ~init:(wrt, true_) List.fold ~f:(fun (a, b) (us, r) -> and_eq us a b r) l (wrt, true_)
~f:(fun (us, r) (a, b) -> and_eq us a b r)
l
|> snd |> snd
let implies_eq r a b = implies r (Term.eq a b) let implies_eq r a b = implies r (Term.eq a b)

@ -47,9 +47,9 @@ let%test_module _ =
let g = Term.mul let g = Term.mul
let of_eqs l = let of_eqs l =
List.fold ~init:(wrt, empty) List.fold
~f:(fun (us, r) (a, b) -> add us (Formula.eq a b) r) ~f:(fun (a, b) (us, r) -> add us (Formula.eq a b) r)
l l (wrt, empty)
|> snd |> snd
let add_eq a b r = add wrt (Formula.eq a b) r |> snd let add_eq a b r = add wrt (Formula.eq a b) r |> snd

@ -52,7 +52,7 @@ let%test_module _ =
(Array.map ~f:(fun (siz, seq) -> Term.sized ~siz ~seq) ms)) (Array.map ~f:(fun (siz, seq) -> Term.sized ~siz ~seq) ms))
let of_eqs l = let of_eqs l =
List.fold ~init:emp ~f:(fun q (a, b) -> and_ (Formula.eq a b) q) l List.fold ~f:(fun (a, b) q -> and_ (Formula.eq a b) q) l emp
let%expect_test _ = let%expect_test _ =
pp pp
@ -143,8 +143,8 @@ let%test_module _ =
( ( 0 = _ emp) ( ( 0 = _ emp)
( ( ( 1 = _ = %y_7 emp) ( 2 = _ emp) )) ( ( ( 1 = _ = %y_7 emp) ( 2 = _ emp) ))
) )
( ( 1 = %y_7 emp) ( emp) ( emp) ) |}] ( ( emp) ( 1 = %y_7 emp) ( emp) ) |}]
let%expect_test _ = let%expect_test _ =
let q = exists ~$[x_] (of_eqs [(f x, x); (f y, y - !1)]) in let q = exists ~$[x_] (of_eqs [(f x, x); (f y, y - !1)]) in

Loading…
Cancel
Save