From 3003a8e6469aa41e9a65ac793d7c8bc6977bb392 Mon Sep 17 00:00:00 2001 From: Josh Berdine Date: Wed, 9 Oct 2019 08:36:49 -0700 Subject: [PATCH] [sledge] NFC minor cleanups Reviewed By: jvillard Differential Revision: D17665255 fbshipit-source-id: 0f18e5777 --- sledge/src/control.ml | 2 +- sledge/src/domain/used_globals.ml | 3 +- sledge/src/llair/exp.ml | 82 +++--- sledge/src/llair/exp.mli | 24 +- sledge/src/llair/frontend.ml | 4 +- sledge/src/llair/term.ml | 438 +++++++++++++++--------------- sledge/src/llair/term.mli | 99 ++++--- 7 files changed, 348 insertions(+), 304 deletions(-) diff --git a/sledge/src/control.ml b/sledge/src/control.ml index 45b4c8d59..a5075bf4e 100644 --- a/sledge/src/control.ml +++ b/sledge/src/control.ml @@ -248,7 +248,7 @@ module Make (Dom : Domain_sig.Dom) = struct | None -> [%Trace.info "queue empty"] ; () end - let used_globals : exec_opts -> Reg.reg -> Reg.Set.t = + let used_globals : exec_opts -> Reg.t -> Reg.Set.t = fun opts fn -> [%Trace.call fun {pf} -> pf "%a" Reg.pp fn] ; diff --git a/sledge/src/domain/used_globals.ml b/sledge/src/domain/used_globals.ml index 479733b09..99e2f9e47 100644 --- a/sledge/src/domain/used_globals.ml +++ b/sledge/src/domain/used_globals.ml @@ -14,8 +14,7 @@ let report_fmt_thunk = Fn.flip pp let empty = Reg.Set.empty let init globals = - [%Trace.info - "pgm globals: {%a}" (Vector.pp ", " Llair_.Global.pp) globals] ; + [%Trace.info "pgm globals: {%a}" (Vector.pp ", " Global.pp) globals] ; empty let join l r = Some (Set.union l r) diff --git a/sledge/src/llair/exp.ml b/sledge/src/llair/exp.ml index 2666af1f7..e3d228d68 100644 --- a/sledge/src/llair/exp.ml +++ b/sledge/src/llair/exp.ml @@ -186,10 +186,6 @@ and pp_record fs elts = elts] [@@warning "-9"] -type exp = t - -let pp_exp = pp - (** Invariant *) let valid_idx idx elts = 0 <= idx && idx < Vector.length elts @@ -288,6 +284,10 @@ and typ_of exp = let typ = typ_of +type exp = t + +let pp_exp = pp + (** Registers are the expressions constructed by [Reg] *) module Reg = struct include T @@ -371,31 +371,14 @@ module Reg = struct |> check invariant end -(** Access *) - -let fold_exps e ~init ~f = - let fold_exps_ fold_exps_ e z = - let z = - match e.desc with - | Ap1 (_, _, x) -> fold_exps_ x z - | Ap2 (_, _, x, y) -> fold_exps_ y (fold_exps_ x z) - | Ap3 (_, _, w, x, y) -> fold_exps_ w (fold_exps_ y (fold_exps_ x z)) - | ApN (_, _, xs) -> - Vector.fold xs ~init:z ~f:(fun z elt -> fold_exps_ elt z) - | _ -> z - in - f z e - in - fix fold_exps_ (fun _ z -> z) e init - -let fold_regs e ~init ~f = - fold_exps e ~init ~f:(fun z x -> - match x.desc with Reg _ -> f z (x :> Reg.t) | _ -> z ) - (** Construct *) +(* registers *) + let reg x = x +(* constants *) + let nondet typ msg = {desc= Nondet {msg; typ}; term= Term.nondet msg} |> check invariant @@ -412,6 +395,8 @@ let bool b = integer Typ.bool (Z.of_bool b) let float typ data = {desc= Float {data; typ}; term= Term.float data} |> check invariant +(* type conversions *) + let convert ?(unsigned = false) ~dst ~src exp = ( if (not unsigned) && Typ.equal dst src then exp else @@ -419,9 +404,7 @@ let convert ?(unsigned = false) ~dst ~src exp = ; term= Term.convert ~unsigned ~dst ~src exp.term } ) |> check invariant -let select typ rcd idx = - {desc= Ap1 (Select idx, typ, rcd); term= Term.select ~rcd:rcd.term ~idx} - |> check invariant +(* comparisons *) let unsigned typ op x y = let bits = Option.value_exn (Typ.prim_bit_size_of typ) in @@ -477,6 +460,8 @@ let uno typ x y = {desc= Ap2 (Uno, typ, x, y); term= Term.uno x.term y.term} |> check invariant +(* arithmetic *) + let add typ x y = {desc= Ap2 (Add, typ, x, y); term= Term.add x.term y.term} |> check invariant @@ -505,6 +490,8 @@ let urem typ x y = {desc= Ap2 (Urem, typ, x, y); term= unsigned typ Term.rem x.term y.term} |> check invariant +(* boolean / bitwise *) + let and_ typ x y = {desc= Ap2 (And, typ, x, y); term= Term.and_ x.term y.term} |> check invariant @@ -513,6 +500,8 @@ let or_ typ x y = {desc= Ap2 (Or, typ, x, y); term= Term.or_ x.term y.term} |> check invariant +(* bitwise *) + let xor typ x y = {desc= Ap2 (Xor, typ, x, y); term= Term.xor x.term y.term} |> check invariant @@ -529,21 +518,29 @@ let ashr typ x y = {desc= Ap2 (Ashr, typ, x, y); term= Term.ashr x.term y.term} |> check invariant -let update typ ~rcd idx ~elt = - { desc= Ap2 (Update idx, typ, rcd, elt) - ; term= Term.update ~rcd:rcd.term ~idx ~elt:elt.term } - |> check invariant +(* if-then-else *) let conditional typ ~cnd ~thn ~els = { desc= Ap3 (Conditional, typ, cnd, thn, els) ; term= Term.conditional ~cnd:cnd.term ~thn:thn.term ~els:els.term } |> check invariant +(* records (struct / array values) *) + let record typ elts = { desc= ApN (Record, typ, elts) ; term= Term.record (Vector.map ~f:(fun elt -> elt.term) elts) } |> check invariant +let select typ rcd idx = + {desc= Ap1 (Select idx, typ, rcd); term= Term.select ~rcd:rcd.term ~idx} + |> check invariant + +let update typ ~rcd idx ~elt = + { desc= Ap2 (Update idx, typ, rcd, elt) + ; term= Term.update ~rcd:rcd.term ~idx ~elt:elt.term } + |> check invariant + let struct_rec key = let memo_id = Hashtbl.create key in let rec_app = (Staged.unstage (Term.rec_app key)) Term.Record in @@ -569,6 +566,27 @@ let struct_rec key = forcing the recursive thunks also updates this value. *) {desc= ApN (Struct_rec, typ, elts); term= rec_app ~id Vector.empty} +(** Traverse *) + +let fold_exps e ~init ~f = + let fold_exps_ fold_exps_ e z = + let z = + match e.desc with + | Ap1 (_, _, x) -> fold_exps_ x z + | Ap2 (_, _, x, y) -> fold_exps_ y (fold_exps_ x z) + | Ap3 (_, _, w, x, y) -> fold_exps_ w (fold_exps_ y (fold_exps_ x z)) + | ApN (_, _, xs) -> + Vector.fold xs ~init:z ~f:(fun z elt -> fold_exps_ elt z) + | _ -> z + in + f z e + in + fix fold_exps_ (fun _ z -> z) e init + +let fold_regs e ~init ~f = + fold_exps e ~init ~f:(fun z x -> + match x.desc with Reg _ -> f z (x :> Reg.t) | _ -> z ) + (** Query *) let is_true e = diff --git a/sledge/src/llair/exp.mli b/sledge/src/llair/exp.mli index c2a760a74..99c35ba29 100644 --- a/sledge/src/llair/exp.mli +++ b/sledge/src/llair/exp.mli @@ -125,13 +125,21 @@ end (** Construct *) +(* registers *) val reg : Reg.t -> t + +(* constants *) val nondet : Typ.t -> string -> t val label : parent:string -> name:string -> t val null : t val bool : bool -> t val integer : Typ.t -> Z.t -> t val float : Typ.t -> string -> t + +(* type conversions *) +val convert : ?unsigned:bool -> dst:Typ.t -> src:Typ.t -> t -> t + +(* comparisons *) val eq : Typ.t -> t -> t -> t val dq : Typ.t -> t -> t -> t val gt : Typ.t -> t -> t -> t @@ -144,6 +152,8 @@ val ult : Typ.t -> t -> t -> t val ule : Typ.t -> t -> t -> t val ord : Typ.t -> t -> t -> t val uno : Typ.t -> t -> t -> t + +(* arithmetic *) val add : Typ.t -> t -> t -> t val sub : Typ.t -> t -> t -> t val mul : Typ.t -> t -> t -> t @@ -151,13 +161,21 @@ val div : Typ.t -> t -> t -> t val rem : Typ.t -> t -> t -> t val udiv : Typ.t -> t -> t -> t val urem : Typ.t -> t -> t -> t + +(* boolean / bitwise *) val and_ : Typ.t -> t -> t -> t val or_ : Typ.t -> t -> t -> t + +(* bitwise *) val xor : Typ.t -> t -> t -> t val shl : Typ.t -> t -> t -> t val lshr : Typ.t -> t -> t -> t val ashr : Typ.t -> t -> t -> t + +(* if-then-else *) val conditional : Typ.t -> cnd:t -> thn:t -> els:t -> t + +(* records (struct / array values) *) val record : Typ.t -> t vector -> t val select : Typ.t -> t -> int -> t val update : Typ.t -> rcd:t -> int -> elt:t -> t @@ -173,15 +191,13 @@ val struct_rec : one point on each cycle. Failure to obey these requirements will lead to stack overflow. *) -val convert : ?unsigned:bool -> dst:Typ.t -> src:Typ.t -> t -> t - -(** Access *) +(** Traverse *) val fold_regs : t -> init:'a -> f:('a -> Reg.t -> 'a) -> 'a (** Query *) val term : t -> Term.t +val typ : t -> Typ.t val is_true : t -> bool val is_false : t -> bool -val typ : t -> Typ.t diff --git a/sledge/src/llair/frontend.ml b/sledge/src/llair/frontend.ml index b7c731389..e4b8d4b38 100644 --- a/sledge/src/llair/frontend.ml +++ b/sledge/src/llair/frontend.ml @@ -714,7 +714,7 @@ let exception_typs = the PHIs of [dst] translated to a move. *) let xlate_jump : x - -> ?reg_exps:(Reg.reg * Exp.t) list + -> ?reg_exps:(Reg.t * Exp.t) list -> Llvm.llvalue -> Llvm.llbasicblock -> Loc.t @@ -859,7 +859,7 @@ let xlate_instr : | ConstantExpr -> ( match Llvm.constexpr_opcode maybe_llfunc with | BitCast -> Llvm.operand maybe_llfunc 0 - | IntToPtr -> todo "maybe handle calls with inttoptr" () + | IntToPtr -> todo "calls with inttoptr" () | _ -> fail "Unknown value in a call instruction %a" pp_llvalue maybe_llfunc () ) diff --git a/sledge/src/llair/term.ml b/sledge/src/llair/term.ml index 109f6544c..1d25009c1 100644 --- a/sledge/src/llair/term.ml +++ b/sledge/src/llair/term.ml @@ -7,6 +7,8 @@ (** Terms *) +[@@@warning "+9"] + module Z = struct include Z @@ -19,63 +21,48 @@ module rec T : sig type qset = Qset.M(T).t [@@deriving compare, equal, hash, sexp] type op1 = - (* conversion *) | Extract of {unsigned: bool; bits: int} | Convert of {unsigned: bool; dst: Typ.t; src: Typ.t} - (* array/struct *) | Select of int [@@deriving compare, equal, hash, sexp] type op2 = - (* memory *) - | Splat - | Memory - (* comparison *) | Eq | Dq | Lt | Le | Ord | Uno - (* arithmetic *) | Div | Rem - (* boolean / bitwise *) | And | Or | Xor | Shl | Lshr | Ashr - (* array/struct *) + | Splat + | Memory | Update of int [@@deriving compare, equal, hash, sexp] - type op3 = (* if-then-else *) - | Conditional - [@@deriving compare, equal, hash, sexp] - + type op3 = Conditional [@@deriving compare, equal, hash, sexp] type opN = Concat | Record [@@deriving compare, equal, hash, sexp] type recN = Record [@@deriving compare, equal, hash, sexp] type t = - (* nary arithmetic *) | Add of qset | Mul of qset - (* nullary *) | Var of {id: int; name: string} - | Nondet of {msg: string} - | Label of {parent: string; name: string} - (* application *) | Ap1 of op1 * t | Ap2 of op2 * t * t | Ap3 of op3 * t * t * t | ApN of opN * t vector - (* recursive application *) | RecN of recN * t vector (** NOTE: cyclic *) - (* numeric constants *) | Integer of {data: Z.t} | Float of {data: string} + | Nondet of {msg: string} + | Label of {parent: string; name: string} [@@deriving compare, equal, hash, sexp] (* Note: solve (and invariant) requires Qset.min_elt to return a @@ -100,8 +87,6 @@ and T0 : sig [@@deriving compare, equal, hash, sexp] type op2 = - | Splat - | Memory | Eq | Dq | Lt @@ -116,6 +101,8 @@ and T0 : sig | Shl | Lshr | Ashr + | Splat + | Memory | Update of int [@@deriving compare, equal, hash, sexp] @@ -127,8 +114,6 @@ and T0 : sig | Add of qset | Mul of qset | Var of {id: int; name: string} - | Nondet of {msg: string} - | Label of {parent: string; name: string} | Ap1 of op1 * t | Ap2 of op2 * t * t | Ap3 of op3 * t * t * t @@ -136,6 +121,8 @@ and T0 : sig | RecN of recN * t vector | Integer of {data: Z.t} | Float of {data: string} + | Nondet of {msg: string} + | Label of {parent: string; name: string} [@@deriving compare, equal, hash, sexp] end = struct type qset = Qset.M(T).t [@@deriving compare, equal, hash, sexp] @@ -147,8 +134,6 @@ end = struct [@@deriving compare, equal, hash, sexp] type op2 = - | Splat - | Memory | Eq | Dq | Lt @@ -163,6 +148,8 @@ end = struct | Shl | Lshr | Ashr + | Splat + | Memory | Update of int [@@deriving compare, equal, hash, sexp] @@ -174,8 +161,6 @@ end = struct | Add of qset | Mul of qset | Var of {id: int; name: string} - | Nondet of {msg: string} - | Label of {parent: string; name: string} | Ap1 of op1 * t | Ap2 of op2 * t * t | Ap3 of op3 * t * t * t @@ -183,6 +168,8 @@ end = struct | RecN of recN * t vector | Integer of {data: Z.t} | Float of {data: string} + | Nondet of {msg: string} + | Label of {parent: string; name: string} [@@deriving compare, equal, hash, sexp] end @@ -226,13 +213,18 @@ let rec pp ?is_x fs term = Trace.pp_styled (get_var_style var) "%%%s" fs name | Var {name; id} as var -> Trace.pp_styled (get_var_style var) "%%%s_%d" fs name id - | Nondet {msg} -> pf "nondet \"%s\"" msg - | Label {name} -> pf "%s" name - | Ap2 (Splat, byt, siz) -> pf "%a^%a" pp byt pp siz - | Ap2 (Memory, siz, arr) -> pf "@<1>⟨%a,%a@<1>⟩" pp siz pp arr - | ApN (Concat, args) -> pf "%a" (Vector.pp "@,^" pp) args | Integer {data} -> Trace.pp_styled `Magenta "%a" fs Z.pp data | Float {data} -> pf "%s" data + | Nondet {msg} -> pf "nondet \"%s\"" msg + | Label {name} -> pf "%s" name + | Ap1 (Extract {unsigned; bits}, arg) -> + pf "(%s%i)@ %a" (if unsigned then "u" else "i") bits pp arg + | Ap1 (Convert {dst; unsigned= true; src= Integer {bits}}, arg) -> + pf "((%a)(u%i)@ %a)" Typ.pp dst bits pp arg + | Ap1 (Convert {unsigned= true; dst= Integer {bits}; src}, arg) -> + pf "((u%i)(%a)@ %a)" bits Typ.pp src pp arg + | Ap1 (Convert {dst; src}, arg) -> + pf "((%a)(%a)@ %a)" Typ.pp dst Typ.pp src pp arg | Ap2 (Eq, x, y) -> pf "(%a@ = %a)" pp x pp y | Ap2 (Dq, x, y) -> pf "(%a@ @<2>≠ %a)" pp x pp y | Ap2 (Lt, x, y) -> pf "(%a@ < %a)" pp x pp y @@ -266,21 +258,17 @@ let rec pp ?is_x fs term = | Ap2 (Ashr, x, y) -> pf "(%a@ ashr %a)" pp x pp y | Ap3 (Conditional, cnd, thn, els) -> pf "(%a@ ? %a@ : %a)" pp cnd pp thn pp els + | Ap2 (Splat, byt, siz) -> pf "%a^%a" pp byt pp siz + | Ap2 (Memory, siz, arr) -> pf "@<1>⟨%a,%a@<1>⟩" pp siz pp arr + | ApN (Concat, args) -> pf "%a" (Vector.pp "@,^" pp) args + | ApN (Record, elts) -> pf "{%a}" pp_record elts + | RecN (Record, elts) -> pf "{|%a|}" (Vector.pp ",@ " pp) elts | Ap1 (Select idx, rcd) -> pf "%a[%i]" pp rcd idx | Ap2 (Update idx, rcd, elt) -> pf "[%a@ @[| %i → %a@]]" pp rcd idx pp elt - | ApN (Record, elts) -> pf "{%a}" pp_record elts - | RecN (Record, elts) -> pf "{|%a|}" (Vector.pp ",@ " pp) elts - | Ap1 (Extract {unsigned; bits}, arg) -> - pf "(%s%i)@ %a" (if unsigned then "u" else "i") bits pp arg - | Ap1 (Convert {dst; unsigned= true; src= Integer {bits}}, arg) -> - pf "((%a)(u%i)@ %a)" Typ.pp dst bits pp arg - | Ap1 (Convert {unsigned= true; dst= Integer {bits}; src}, arg) -> - pf "((u%i)(%a)@ %a)" bits Typ.pp src pp arg - | Ap1 (Convert {dst; src}, arg) -> - pf "((%a)(%a)@ %a)" Typ.pp dst Typ.pp src pp arg in fix_flip pp_ (fun _ _ -> ()) fs term + [@@warning "-9"] and pp_record fs elts = [%Trace.fprintf @@ -356,28 +344,15 @@ let invariant e = Invariant.invariant [%here] e [%sexp_of: t] @@ fun () -> match e with - | Var _ | Nondet _ | Label _ | Integer _ | Float _ -> () - | Ap1 (Extract _, _) -> () - | Ap1 (Convert {dst; src}, _) -> assert (Typ.convertible src dst) | Add _ -> assert_polynomial e |> Fn.id | Mul _ -> assert_monomial e |> Fn.id - | Ap2 - ( ( Eq | Dq | Lt | Le | Ord | Uno | Div | Rem | And | Or | Xor | Shl - | Lshr | Ashr ) - , _ - , _ ) -> - () - | ApN (Concat, args) -> assert (Vector.length args <> 1) - | Ap2 (Splat, _, siz) -> ( - match siz with - | Integer {data} -> assert (not (Z.equal Z.zero data)) - | _ -> () ) - | Ap2 (Memory, _, _) -> () - | Ap1 (Select _, _) -> () - | Ap3 (Conditional, _, _, _) -> () - | Ap2 (Update _, _, _) -> () + | Ap2 (Splat, _, Integer {data}) -> assert (not (Z.equal Z.zero data)) + | ApN (Concat, mems) -> assert (Vector.length mems <> 1) | ApN (Record, elts) | RecN (Record, elts) -> assert (not (Vector.is_empty elts)) + | Ap1 (Convert {dst; src}, _) -> assert (Typ.convertible src dst) + | _ -> () + [@@warning "-9"] (** Variables are the terms constructed by [Var] *) module Var = struct @@ -415,8 +390,8 @@ module Var = struct Invariant.invariant [%here] x [%sexp_of: t] @@ fun () -> match x with Var _ -> invariant x | _ -> assert false - let id = function Var {id} -> id | x -> violates invariant x - let name = function Var {name} -> name | x -> violates invariant x + let id = function Var v -> v.id | x -> violates invariant x + let name = function Var v -> v.name | x -> violates invariant x let of_term = function | Var _ as v -> Some (v |> check invariant) @@ -490,34 +465,14 @@ module Var = struct end end -let fold_terms e ~init ~f = - let fold_terms_ fold_terms_ e s = - let s = - match e with - | Ap1 (_, x) -> fold_terms_ x s - | Ap2 (_, x, y) -> fold_terms_ y (fold_terms_ x s) - | Ap3 (_, x, y, z) -> fold_terms_ z (fold_terms_ y (fold_terms_ x s)) - | ApN (_, xs) | RecN (_, xs) -> - Vector.fold ~f:(fun s x -> fold_terms_ x s) xs ~init:s - | Add args | Mul args -> - Qset.fold args ~init:s ~f:(fun arg _ s -> fold_terms_ arg s) - | _ -> s - in - f s e - in - fix fold_terms_ (fun _ s -> s) e init +(** Construct *) -let fold_vars e ~init ~f = - fold_terms e ~init ~f:(fun z -> function - | Var _ as v -> f z (v :> Var.t) | _ -> z ) +(* variables *) -let fv e = fold_vars e ~f:Set.add ~init:Var.Set.empty +let var x = x -(** Construct *) +(* constants *) -let var x = x -let nondet msg = Nondet {msg} |> check invariant -let label ~parent ~name = Label {parent; name} |> check invariant let integer data = Integer {data} |> check invariant let null = integer Z.zero let zero = integer Z.zero @@ -527,6 +482,10 @@ let bool b = integer (Z.of_bool b) let true_ = bool true let false_ = bool false let float data = Float {data} |> check invariant +let nondet msg = Nondet {msg} |> check invariant +let label ~parent ~name = Label {parent; name} |> check invariant + +(* type conversions *) let simp_extract ~unsigned bits arg = match arg with @@ -541,47 +500,7 @@ let simp_convert ~unsigned dst src arg = integer (Z.extract ~unsigned (min m n) data) | _ -> Ap1 (Convert {unsigned; dst; src}, arg) -let simp_record elts = ApN (Record, elts) -let simp_select idx rcd = Ap1 (Select idx, rcd) -let simp_update idx rcd elt = Ap2 (Update idx, rcd, elt) - -let simp_concat xs = - if Vector.length xs = 1 then Vector.get xs 0 - else - let args = - if - Vector.for_all xs ~f:(function - | ApN (Concat, _) -> false - | _ -> true ) - then xs - else - Vector.concat - (Vector.fold_right xs ~init:[] ~f:(fun x s -> - match x with - | ApN (Concat, args) -> args :: s - | x -> Vector.of_array [|x|] :: s )) - in - ApN (Concat, args) - -let simp_splat byt siz = - match siz with - | Integer {data} when Z.equal Z.zero data -> simp_concat Vector.empty - | _ -> Ap2 (Splat, byt, siz) - -let simp_memory siz arr = Ap2 (Memory, siz, arr) - -let simp_lt x y = - match (x, y) with - | Integer {data= i}, Integer {data= j} -> bool (Z.lt i j) - | _ -> Ap2 (Lt, x, y) - -let simp_le x y = - match (x, y) with - | Integer {data= i}, Integer {data= j} -> bool (Z.leq i j) - | _ -> Ap2 (Le, x, y) - -let simp_ord x y = Ap2 (Ord, x, y) -let simp_uno x y = Ap2 (Uno, x, y) +(* arithmetic *) let sum_mul_const const sum = assert (not (Q.equal Q.zero const)) ; @@ -731,6 +650,8 @@ let simp_sub x y = (* x - y ==> x + (-1 * y) *) | _ -> simp_add2 x (simp_negate y) +(* if-then-else *) + let simp_cond cnd thn els = match cnd with (* ¬(true ? t : e) ==> t *) @@ -739,6 +660,17 @@ let simp_cond cnd thn els = | Integer {data} when Z.is_false data -> els | _ -> Ap3 (Conditional, cnd, thn, els) +(* boolean / bitwise *) + +let rec is_boolean = function + | Ap1 ((Extract {bits= 1; _} | Convert {dst= Integer {bits= 1; _}; _}), _) + |Ap2 ((Eq | Dq | Lt | Le), _, _) -> + true + | Ap2 ((Div | Rem | And | Or | Xor | Shl | Lshr | Ashr), x, y) + |Ap3 (Conditional, _, x, y) -> + is_boolean x || is_boolean y + | _ -> false + let rec simp_and x y = match (x, y) with (* i && j *) @@ -773,42 +705,22 @@ let rec simp_or x y = | _ when equal x y -> x | _ -> Ap2 (Or, x, y) -let rec is_boolean = function - | Ap1 ((Extract {bits= 1} | Convert {dst= Integer {bits= 1}}), _) - |Ap2 ((Eq | Dq | Lt | Le), _, _) -> - true - | Ap2 ((Div | Rem | And | Or | Xor | Shl | Lshr | Ashr), x, y) - |Ap3 (Conditional, _, x, y) -> - is_boolean x || is_boolean y - | _ -> false +(* comparison *) -let rec simp_not term = - match term with - (* ¬(x = y) ==> x ≠ y *) - | Ap2 (Eq, x, y) -> simp_dq x y - (* ¬(x ≠ y) ==> x = y *) - | Ap2 (Dq, x, y) -> simp_eq x y - (* ¬(x < y) ==> y <= x *) - | Ap2 (Lt, x, y) -> simp_le y x - (* ¬(x <= y) ==> y < x *) - | Ap2 (Le, x, y) -> simp_lt y x - (* ¬(x ≠ nan ∧ y ≠ nan) ==> x = nan ∨ y = nan *) - | Ap2 (Ord, x, y) -> simp_uno x y - (* ¬(x = nan ∨ y = nan) ==> x ≠ nan ∧ y ≠ nan *) - | Ap2 (Uno, x, y) -> simp_ord x y - (* ¬(a ∧ b) ==> ¬a ∨ ¬b *) - | Ap2 (And, x, y) -> simp_or (simp_not x) (simp_not y) - (* ¬(a ∨ b) ==> ¬a ∧ ¬b *) - | Ap2 (Or, x, y) -> simp_and (simp_not x) (simp_not y) - (* ¬(c ? t : e) ==> c ? ¬t : ¬e *) - | Ap3 (Conditional, cnd, thn, els) -> - simp_cond cnd (simp_not thn) (simp_not els) - (* ¬i ==> -i-1 *) - | Integer {data} -> integer (Z.lognot data) - (* ¬e ==> true xor e *) - | e -> Ap2 (Xor, true_, e) +let simp_lt x y = + match (x, y) with + | Integer {data= i}, Integer {data= j} -> bool (Z.lt i j) + | _ -> Ap2 (Lt, x, y) -and simp_eq x y = +let simp_le x y = + match (x, y) with + | Integer {data= i}, Integer {data= j} -> bool (Z.leq i j) + | _ -> Ap2 (Le, x, y) + +let simp_ord x y = Ap2 (Ord, x, y) +let simp_uno x y = Ap2 (Uno, x, y) + +let rec simp_eq x y = match (x, y) with (* i = j *) | Integer {data= i}, Integer {data= j} -> bool (Z.equal i j) @@ -835,6 +747,35 @@ and simp_dq x y = | Ap2 (Eq, x, y) -> Ap2 (Dq, x, y) | b -> simp_not b ) +(* negation-normal form *) +and simp_not term = + match term with + (* ¬(x = y) ==> x ≠ y *) + | Ap2 (Eq, x, y) -> simp_dq x y + (* ¬(x ≠ y) ==> x = y *) + | Ap2 (Dq, x, y) -> simp_eq x y + (* ¬(x < y) ==> y <= x *) + | Ap2 (Lt, x, y) -> simp_le y x + (* ¬(x <= y) ==> y < x *) + | Ap2 (Le, x, y) -> simp_lt y x + (* ¬(x ≠ nan ∧ y ≠ nan) ==> x = nan ∨ y = nan *) + | Ap2 (Ord, x, y) -> simp_uno x y + (* ¬(x = nan ∨ y = nan) ==> x ≠ nan ∧ y ≠ nan *) + | Ap2 (Uno, x, y) -> simp_ord x y + (* ¬(a ∧ b) ==> ¬a ∨ ¬b *) + | Ap2 (And, x, y) -> simp_or (simp_not x) (simp_not y) + (* ¬(a ∨ b) ==> ¬a ∧ ¬b *) + | Ap2 (Or, x, y) -> simp_and (simp_not x) (simp_not y) + (* ¬(c ? t : e) ==> c ? ¬t : ¬e *) + | Ap3 (Conditional, cnd, thn, els) -> + simp_cond cnd (simp_not thn) (simp_not els) + (* ¬i ==> -i-1 *) + | Integer {data} -> integer (Z.lognot data) + (* ¬e ==> true xor e *) + | e -> Ap2 (Xor, true_, e) + +(* bitwise *) + let simp_xor x y = match (x, y) with (* i xor j *) @@ -871,26 +812,62 @@ let simp_ashr x y = | e, Integer {data} when Z.equal Z.zero data -> e | _ -> Ap2 (Ashr, x, y) -(** Access *) +(* memory *) -let iter e ~f = - match e with - | Ap1 (_, x) -> f x - | Ap2 (_, x, y) -> f x ; f y - | Ap3 (_, x, y, z) -> f x ; f y ; f z - | ApN (_, xs) | RecN (_, xs) -> Vector.iter ~f xs - | Add args | Mul args -> Qset.iter ~f:(fun arg _ -> f arg) args - | _ -> () +let simp_concat xs = + if Vector.length xs = 1 then Vector.get xs 0 + else + let args = + if + Vector.for_all xs ~f:(function + | ApN (Concat, _) -> false + | _ -> true ) + then xs + else + Vector.concat + (Vector.fold_right xs ~init:[] ~f:(fun x s -> + match x with + | ApN (Concat, args) -> args :: s + | x -> Vector.of_array [|x|] :: s )) + in + ApN (Concat, args) -let fold e ~init:s ~f = - match e with - | Ap1 (_, x) -> f x s - | Ap2 (_, x, y) -> f y (f x s) - | Ap3 (_, x, y, z) -> f z (f y (f x s)) - | ApN (_, xs) | RecN (_, xs) -> - Vector.fold ~f:(fun s x -> f x s) xs ~init:s - | Add args | Mul args -> Qset.fold ~f:(fun e _ s -> f e s) args ~init:s - | _ -> s +let simp_splat byt siz = + match siz with + | Integer {data} when Z.equal Z.zero data -> simp_concat Vector.empty + | _ -> Ap2 (Splat, byt, siz) + +let simp_memory siz arr = Ap2 (Memory, siz, arr) + +(* records *) + +let simp_record elts = ApN (Record, elts) +let simp_select idx rcd = Ap1 (Select idx, rcd) +let simp_update idx rcd elt = Ap2 (Update idx, rcd, elt) + +let rec_app key = + let memo_id = Hashtbl.create key in + let dummy = null in + Staged.stage + @@ fun ~id op elt_thks -> + match Hashtbl.find memo_id id with + | None -> + (* Add placeholder to prevent computing [elts] in calls to [rec_app] + from [elt_thks] for recursive occurrences of [id]. *) + let elta = Array.create ~len:(Vector.length elt_thks) dummy in + let elts = Vector.of_array elta in + Hashtbl.set memo_id ~key:id ~data:elts ; + Vector.iteri elt_thks ~f:(fun i (lazy elt) -> elta.(i) <- elt) ; + RecN (op, elts) |> check invariant + | Some elts -> + (* Do not check invariant as invariant will be checked above after the + thunks are forced, before which invariant-checking may spuriously + fail. Note that it is important that the value constructed here + shares the array in the memo table, so that the update after + forcing the recursive thunks also updates this value. *) + RecN (op, elts) + +(* dispatching for normalization and invariant checking *) let norm1 op x = ( match op with @@ -927,63 +904,43 @@ let normN op xs = (match op with Concat -> simp_concat xs | Record -> simp_record xs) |> check invariant -let addN args = simp_add args |> check invariant -let mulN args = simp_mul args |> check invariant -let concat xs = normN Concat (Vector.of_array xs) -let splat ~byt ~siz = norm2 Splat byt siz -let memory ~siz ~arr = norm2 Memory siz arr +(* exposed interface *) + +let extract ?(unsigned = false) ~bits term = + norm1 (Extract {unsigned; bits}) term + +let convert ?(unsigned = false) ~dst ~src term = + norm1 (Convert {unsigned; dst; src}) term + let eq = norm2 Eq let dq = norm2 Dq let lt = norm2 Lt let le = norm2 Le let ord = norm2 Ord let uno = norm2 Uno -let neg = simp_negate -let add = simp_add2 -let sub = simp_sub -let mul = simp_mul2 +let neg e = simp_negate e |> check invariant +let add e f = simp_add2 e f |> check invariant +let addN args = simp_add args |> check invariant +let sub e f = simp_sub e f |> check invariant +let mul e f = simp_mul2 e f |> check invariant +let mulN args = simp_mul args |> check invariant let div = norm2 Div let rem = norm2 Rem let and_ = norm2 And let or_ = norm2 Or +let not_ e = simp_not e |> check invariant let xor = norm2 Xor -let not_ = simp_not let shl = norm2 Shl let lshr = norm2 Lshr let ashr = norm2 Ashr let conditional ~cnd ~thn ~els = norm3 Conditional cnd thn els +let splat ~byt ~siz = norm2 Splat byt siz +let memory ~siz ~arr = norm2 Memory siz arr +let concat xs = normN Concat (Vector.of_array xs) let record elts = normN Record elts let select ~rcd ~idx = norm1 (Select idx) rcd let update ~rcd ~idx ~elt = norm2 (Update idx) rcd elt -let rec_app key = - let memo_id = Hashtbl.create key in - let dummy = null in - Staged.stage - @@ fun ~id op elt_thks -> - match Hashtbl.find memo_id id with - | None -> - (* Add placeholder to prevent computing [elts] in calls to [rec_app] - from [elt_thks] for recursive occurrences of [id]. *) - let elta = Array.create ~len:(Vector.length elt_thks) dummy in - let elts = Vector.of_array elta in - Hashtbl.set memo_id ~key:id ~data:elts ; - Vector.iteri elt_thks ~f:(fun i (lazy elt) -> elta.(i) <- elt) ; - RecN (op, elts) |> check invariant - | Some elts -> - (* Do not check invariant as invariant will be checked above after the - thunks are forced, before which invariant-checking may spuriously - fail. Note that it is important that the value constructed here - shares the array in the memo table, so that the update after - forcing the recursive thunks also updates this value. *) - RecN (op, elts) - -let extract ?(unsigned = false) ~bits term = - norm1 (Extract {unsigned; bits}) term - -let convert ?(unsigned = false) ~dst ~src term = - norm1 (Convert {unsigned; dst; src}) term - let size_of t = Option.bind (Typ.prim_bit_size_of t) ~f:(fun n -> if n % 8 = 0 then Some (integer (Z.of_int (n / 8))) else None ) @@ -1062,8 +1019,51 @@ let rename sub e = | Var _ as v -> Some (Var.Subst.apply sub v) | _ -> None ) +(** Traverse *) + +let iter e ~f = + match e with + | Ap1 (_, x) -> f x + | Ap2 (_, x, y) -> f x ; f y + | Ap3 (_, x, y, z) -> f x ; f y ; f z + | ApN (_, xs) | RecN (_, xs) -> Vector.iter ~f xs + | Add args | Mul args -> Qset.iter ~f:(fun arg _ -> f arg) args + | _ -> () + +let fold e ~init:s ~f = + match e with + | Ap1 (_, x) -> f x s + | Ap2 (_, x, y) -> f y (f x s) + | Ap3 (_, x, y, z) -> f z (f y (f x s)) + | ApN (_, xs) | RecN (_, xs) -> + Vector.fold ~f:(fun s x -> f x s) xs ~init:s + | Add args | Mul args -> Qset.fold ~f:(fun e _ s -> f e s) args ~init:s + | _ -> s + +let fold_terms e ~init ~f = + let fold_terms_ fold_terms_ e s = + let s = + match e with + | Ap1 (_, x) -> fold_terms_ x s + | Ap2 (_, x, y) -> fold_terms_ y (fold_terms_ x s) + | Ap3 (_, x, y, z) -> fold_terms_ z (fold_terms_ y (fold_terms_ x s)) + | ApN (_, xs) | RecN (_, xs) -> + Vector.fold ~f:(fun s x -> fold_terms_ x s) xs ~init:s + | Add args | Mul args -> + Qset.fold args ~init:s ~f:(fun arg _ s -> fold_terms_ arg s) + | _ -> s + in + f s e + in + fix fold_terms_ (fun _ s -> s) e init + +let fold_vars e ~init ~f = + fold_terms e ~init ~f:(fun z -> function + | Var _ as v -> f z (v :> Var.t) | _ -> z ) + (** Query *) +let fv e = fold_vars e ~f:Set.add ~init:Var.Set.empty let is_true = function Integer {data} -> Z.is_true data | _ -> false let is_false = function Integer {data} -> Z.is_false data | _ -> false @@ -1082,7 +1082,7 @@ let classify = function | Add _ | Mul _ -> `Interpreted | Ap2 ((Eq | Dq), _, _) -> `Simplified | Ap1 _ | Ap2 _ | Ap3 _ | ApN _ -> `Uninterpreted - | RecN _ | Var _ | Nondet _ | Label _ | Integer _ | Float _ -> `Atomic + | RecN _ | Var _ | Integer _ | Float _ | Nondet _ | Label _ -> `Atomic let solve e f = [%Trace.call fun {pf} -> pf "%a@ %a" pp e pp f] diff --git a/sledge/src/llair/term.mli b/sledge/src/llair/term.mli index ec37fcf74..f91714cd0 100644 --- a/sledge/src/llair/term.mli +++ b/sledge/src/llair/term.mli @@ -8,13 +8,7 @@ (** Terms Pure (heap-independent) terms are complex arithmetic, bitwise-logical, - etc. operations over literal values and variables. - - Terms for operations that are uninterpreted in the analyzer are - represented in curried form, where [App] is an application of a function - symbol to an argument. This is done to simplify the definition of - 'subterm' and make it explicit. The specific constructor functions - indicate and check the expected arity of the function symbols. *) + etc. operations over literal values and variables. *) type comparator_witness @@ -32,8 +26,6 @@ type op1 = [@@deriving compare, equal, hash, sexp] type op2 = - | Splat (** Iterated concatenation of a single byte *) - | Memory (** Size-tagged byte-array *) | Eq (** Equal test *) | Dq (** Disequal test *) | Lt (** Less-than test *) @@ -48,6 +40,8 @@ type op2 = | Shl (** Shift left, bitwise *) | Lshr (** Logical shift right, bitwise *) | Ashr (** Arithmetic shift right, bitwise *) + | Splat (** Iterated concatenation of a single byte *) + | Memory (** Size-tagged byte-array *) | Update of int (** Constant record with updated index *) [@@deriving compare, equal, hash, sexp] @@ -59,42 +53,38 @@ type opN = | Record (** Record (array / struct) constant *) [@@deriving compare, equal, hash, sexp] -type recN = - | Record - (** Record constant that may recursively refer to itself - (transitively) from its args. NOTE: represented by cyclic values. *) +type recN = Record (** Recursive record (array / struct) constant *) [@@deriving compare, equal, hash, sexp] type qset = (t, comparator_witness) Qset.t and t = private - | Add of qset (** Addition *) - | Mul of qset (** Multiplication *) + | Add of qset (** Sum of terms with rational coefficients *) + | Mul of qset (** Product of terms with rational exponents *) | Var of {id: int; name: string} (** Local variable / virtual register *) + | Ap1 of op1 * t (** Unary application *) + | Ap2 of op2 * t * t (** Binary application *) + | Ap3 of op3 * t * t * t (** Ternary application *) + | ApN of opN * t vector (** N-ary application *) + | RecN of recN * t vector + (** Recursive n-ary application, may recursively refer to itself + (transitively) from its args. NOTE: represented by cyclic values. *) + | Integer of {data: Z.t} (** Integer constant *) + | Float of {data: string} (** Floating-point constant *) | Nondet of {msg: string} (** Anonymous local variable with arbitrary value, representing non-deterministic approximation of value described by [msg] *) | Label of {parent: string; name: string} (** Address of named code block within parent function *) - | Ap1 of op1 * t - | Ap2 of op2 * t * t - | Ap3 of op3 * t * t * t - | ApN of opN * t vector - | RecN of recN * t vector - | Integer of {data: Z.t} - (** Integer constant, or if [typ] is a [Pointer], null pointer value - that never refers to an object *) - | Float of {data: string} (** Floating-point constant *) [@@deriving compare, equal, hash, sexp] val comparator : (t, comparator_witness) Comparator.t - -type term = t - -val pp_full : ?is_x:(term -> bool) -> t pp +val pp_full : ?is_x:(t -> bool) -> t pp val pp : t pp val invariant : t -> unit +type term = t + (** Term.Var is re-exported as Var *) module Var : sig type t = private term [@@deriving compare, equal, hash, sexp] @@ -140,65 +130,86 @@ end (** Construct *) +(* variables *) val var : Var.t -> t + +(* constants *) val nondet : string -> t val label : parent:string -> name:string -> t +val null : t +val bool : bool -> t val true_ : t val false_ : t -val null : t +val integer : Z.t -> t val zero : t val one : t val minus_one : t -val splat : byt:t -> siz:t -> t -val memory : siz:t -> arr:t -> t -val concat : t array -> t -val bool : bool -> t -val integer : Z.t -> t val rational : Q.t -> t val float : string -> t + +(* type conversions *) +val extract : ?unsigned:bool -> bits:int -> t -> t +val convert : ?unsigned:bool -> dst:Typ.t -> src:Typ.t -> t -> t + +(* comparisons *) val eq : t -> t -> t val dq : t -> t -> t val lt : t -> t -> t val le : t -> t -> t val ord : t -> t -> t val uno : t -> t -> t + +(* arithmetic *) val neg : t -> t val add : t -> t -> t val sub : t -> t -> t val mul : t -> t -> t val div : t -> t -> t val rem : t -> t -> t + +(* boolean / bitwise *) val and_ : t -> t -> t val or_ : t -> t -> t -val xor : t -> t -> t val not_ : t -> t + +(* bitwise *) +val xor : t -> t -> t val shl : t -> t -> t val lshr : t -> t -> t val ashr : t -> t -> t + +(* if-then-else *) val conditional : cnd:t -> thn:t -> els:t -> t + +(* memory contents *) +val splat : byt:t -> siz:t -> t +val memory : siz:t -> arr:t -> t +val concat : t array -> t + +(* records (struct / array values) *) val record : t vector -> t val select : rcd:t -> idx:int -> t val update : rcd:t -> idx:int -> elt:t -> t -val extract : ?unsigned:bool -> bits:int -> t -> t -val convert : ?unsigned:bool -> dst:Typ.t -> src:Typ.t -> t -> t -val size_of : Typ.t -> t option +(* recursive n-ary application *) val rec_app : (module Hashtbl.Key with type t = 'id) -> (id:'id -> recN -> t lazy_t vector -> t) Staged.t -(** Access *) - -val iter : t -> f:(t -> unit) -> unit -val fold_vars : t -> init:'a -> f:('a -> Var.t -> 'a) -> 'a -val fold_terms : t -> init:'a -> f:('a -> t -> 'a) -> 'a -val fold : t -> init:'a -> f:(t -> 'a -> 'a) -> 'a +val size_of : Typ.t -> t option (** Transform *) val map : t -> f:(t -> t) -> t val rename : Var.Subst.t -> t -> t +(** Traverse *) + +val iter : t -> f:(t -> unit) -> unit +val fold : t -> init:'a -> f:(t -> 'a -> 'a) -> 'a +val fold_vars : t -> init:'a -> f:('a -> Var.t -> 'a) -> 'a +val fold_terms : t -> init:'a -> f:('a -> t -> 'a) -> 'a + (** Query *) val fv : t -> Var.Set.t