[sledge] NFC minor cleanups

Reviewed By: jvillard

Differential Revision: D17665255

fbshipit-source-id: 0f18e5777
master
Josh Berdine 5 years ago committed by Facebook Github Bot
parent 8ee0c67d1f
commit 3003a8e646

@ -248,7 +248,7 @@ module Make (Dom : Domain_sig.Dom) = struct
| None -> [%Trace.info "queue empty"] ; ()
end
let used_globals : exec_opts -> Reg.reg -> Reg.Set.t =
let used_globals : exec_opts -> Reg.t -> Reg.Set.t =
fun opts fn ->
[%Trace.call fun {pf} -> pf "%a" Reg.pp fn]
;

@ -14,8 +14,7 @@ let report_fmt_thunk = Fn.flip pp
let empty = Reg.Set.empty
let init globals =
[%Trace.info
"pgm globals: {%a}" (Vector.pp ", " Llair_.Global.pp) globals] ;
[%Trace.info "pgm globals: {%a}" (Vector.pp ", " Global.pp) globals] ;
empty
let join l r = Some (Set.union l r)

@ -186,10 +186,6 @@ and pp_record fs elts =
elts]
[@@warning "-9"]
type exp = t
let pp_exp = pp
(** Invariant *)
let valid_idx idx elts = 0 <= idx && idx < Vector.length elts
@ -288,6 +284,10 @@ and typ_of exp =
let typ = typ_of
type exp = t
let pp_exp = pp
(** Registers are the expressions constructed by [Reg] *)
module Reg = struct
include T
@ -371,31 +371,14 @@ module Reg = struct
|> check invariant
end
(** Access *)
let fold_exps e ~init ~f =
let fold_exps_ fold_exps_ e z =
let z =
match e.desc with
| Ap1 (_, _, x) -> fold_exps_ x z
| Ap2 (_, _, x, y) -> fold_exps_ y (fold_exps_ x z)
| Ap3 (_, _, w, x, y) -> fold_exps_ w (fold_exps_ y (fold_exps_ x z))
| ApN (_, _, xs) ->
Vector.fold xs ~init:z ~f:(fun z elt -> fold_exps_ elt z)
| _ -> z
in
f z e
in
fix fold_exps_ (fun _ z -> z) e init
let fold_regs e ~init ~f =
fold_exps e ~init ~f:(fun z x ->
match x.desc with Reg _ -> f z (x :> Reg.t) | _ -> z )
(** Construct *)
(* registers *)
let reg x = x
(* constants *)
let nondet typ msg =
{desc= Nondet {msg; typ}; term= Term.nondet msg} |> check invariant
@ -412,6 +395,8 @@ let bool b = integer Typ.bool (Z.of_bool b)
let float typ data =
{desc= Float {data; typ}; term= Term.float data} |> check invariant
(* type conversions *)
let convert ?(unsigned = false) ~dst ~src exp =
( if (not unsigned) && Typ.equal dst src then exp
else
@ -419,9 +404,7 @@ let convert ?(unsigned = false) ~dst ~src exp =
; term= Term.convert ~unsigned ~dst ~src exp.term } )
|> check invariant
let select typ rcd idx =
{desc= Ap1 (Select idx, typ, rcd); term= Term.select ~rcd:rcd.term ~idx}
|> check invariant
(* comparisons *)
let unsigned typ op x y =
let bits = Option.value_exn (Typ.prim_bit_size_of typ) in
@ -477,6 +460,8 @@ let uno typ x y =
{desc= Ap2 (Uno, typ, x, y); term= Term.uno x.term y.term}
|> check invariant
(* arithmetic *)
let add typ x y =
{desc= Ap2 (Add, typ, x, y); term= Term.add x.term y.term}
|> check invariant
@ -505,6 +490,8 @@ let urem typ x y =
{desc= Ap2 (Urem, typ, x, y); term= unsigned typ Term.rem x.term y.term}
|> check invariant
(* boolean / bitwise *)
let and_ typ x y =
{desc= Ap2 (And, typ, x, y); term= Term.and_ x.term y.term}
|> check invariant
@ -513,6 +500,8 @@ let or_ typ x y =
{desc= Ap2 (Or, typ, x, y); term= Term.or_ x.term y.term}
|> check invariant
(* bitwise *)
let xor typ x y =
{desc= Ap2 (Xor, typ, x, y); term= Term.xor x.term y.term}
|> check invariant
@ -529,21 +518,29 @@ let ashr typ x y =
{desc= Ap2 (Ashr, typ, x, y); term= Term.ashr x.term y.term}
|> check invariant
let update typ ~rcd idx ~elt =
{ desc= Ap2 (Update idx, typ, rcd, elt)
; term= Term.update ~rcd:rcd.term ~idx ~elt:elt.term }
|> check invariant
(* if-then-else *)
let conditional typ ~cnd ~thn ~els =
{ desc= Ap3 (Conditional, typ, cnd, thn, els)
; term= Term.conditional ~cnd:cnd.term ~thn:thn.term ~els:els.term }
|> check invariant
(* records (struct / array values) *)
let record typ elts =
{ desc= ApN (Record, typ, elts)
; term= Term.record (Vector.map ~f:(fun elt -> elt.term) elts) }
|> check invariant
let select typ rcd idx =
{desc= Ap1 (Select idx, typ, rcd); term= Term.select ~rcd:rcd.term ~idx}
|> check invariant
let update typ ~rcd idx ~elt =
{ desc= Ap2 (Update idx, typ, rcd, elt)
; term= Term.update ~rcd:rcd.term ~idx ~elt:elt.term }
|> check invariant
let struct_rec key =
let memo_id = Hashtbl.create key in
let rec_app = (Staged.unstage (Term.rec_app key)) Term.Record in
@ -569,6 +566,27 @@ let struct_rec key =
forcing the recursive thunks also updates this value. *)
{desc= ApN (Struct_rec, typ, elts); term= rec_app ~id Vector.empty}
(** Traverse *)
let fold_exps e ~init ~f =
let fold_exps_ fold_exps_ e z =
let z =
match e.desc with
| Ap1 (_, _, x) -> fold_exps_ x z
| Ap2 (_, _, x, y) -> fold_exps_ y (fold_exps_ x z)
| Ap3 (_, _, w, x, y) -> fold_exps_ w (fold_exps_ y (fold_exps_ x z))
| ApN (_, _, xs) ->
Vector.fold xs ~init:z ~f:(fun z elt -> fold_exps_ elt z)
| _ -> z
in
f z e
in
fix fold_exps_ (fun _ z -> z) e init
let fold_regs e ~init ~f =
fold_exps e ~init ~f:(fun z x ->
match x.desc with Reg _ -> f z (x :> Reg.t) | _ -> z )
(** Query *)
let is_true e =

@ -125,13 +125,21 @@ end
(** Construct *)
(* registers *)
val reg : Reg.t -> t
(* constants *)
val nondet : Typ.t -> string -> t
val label : parent:string -> name:string -> t
val null : t
val bool : bool -> t
val integer : Typ.t -> Z.t -> t
val float : Typ.t -> string -> t
(* type conversions *)
val convert : ?unsigned:bool -> dst:Typ.t -> src:Typ.t -> t -> t
(* comparisons *)
val eq : Typ.t -> t -> t -> t
val dq : Typ.t -> t -> t -> t
val gt : Typ.t -> t -> t -> t
@ -144,6 +152,8 @@ val ult : Typ.t -> t -> t -> t
val ule : Typ.t -> t -> t -> t
val ord : Typ.t -> t -> t -> t
val uno : Typ.t -> t -> t -> t
(* arithmetic *)
val add : Typ.t -> t -> t -> t
val sub : Typ.t -> t -> t -> t
val mul : Typ.t -> t -> t -> t
@ -151,13 +161,21 @@ val div : Typ.t -> t -> t -> t
val rem : Typ.t -> t -> t -> t
val udiv : Typ.t -> t -> t -> t
val urem : Typ.t -> t -> t -> t
(* boolean / bitwise *)
val and_ : Typ.t -> t -> t -> t
val or_ : Typ.t -> t -> t -> t
(* bitwise *)
val xor : Typ.t -> t -> t -> t
val shl : Typ.t -> t -> t -> t
val lshr : Typ.t -> t -> t -> t
val ashr : Typ.t -> t -> t -> t
(* if-then-else *)
val conditional : Typ.t -> cnd:t -> thn:t -> els:t -> t
(* records (struct / array values) *)
val record : Typ.t -> t vector -> t
val select : Typ.t -> t -> int -> t
val update : Typ.t -> rcd:t -> int -> elt:t -> t
@ -173,15 +191,13 @@ val struct_rec :
one point on each cycle. Failure to obey these requirements will lead to
stack overflow. *)
val convert : ?unsigned:bool -> dst:Typ.t -> src:Typ.t -> t -> t
(** Access *)
(** Traverse *)
val fold_regs : t -> init:'a -> f:('a -> Reg.t -> 'a) -> 'a
(** Query *)
val term : t -> Term.t
val typ : t -> Typ.t
val is_true : t -> bool
val is_false : t -> bool
val typ : t -> Typ.t

@ -714,7 +714,7 @@ let exception_typs =
the PHIs of [dst] translated to a move. *)
let xlate_jump :
x
-> ?reg_exps:(Reg.reg * Exp.t) list
-> ?reg_exps:(Reg.t * Exp.t) list
-> Llvm.llvalue
-> Llvm.llbasicblock
-> Loc.t
@ -859,7 +859,7 @@ let xlate_instr :
| ConstantExpr -> (
match Llvm.constexpr_opcode maybe_llfunc with
| BitCast -> Llvm.operand maybe_llfunc 0
| IntToPtr -> todo "maybe handle calls with inttoptr" ()
| IntToPtr -> todo "calls with inttoptr" ()
| _ ->
fail "Unknown value in a call instruction %a" pp_llvalue
maybe_llfunc () )

@ -7,6 +7,8 @@
(** Terms *)
[@@@warning "+9"]
module Z = struct
include Z
@ -19,63 +21,48 @@ module rec T : sig
type qset = Qset.M(T).t [@@deriving compare, equal, hash, sexp]
type op1 =
(* conversion *)
| Extract of {unsigned: bool; bits: int}
| Convert of {unsigned: bool; dst: Typ.t; src: Typ.t}
(* array/struct *)
| Select of int
[@@deriving compare, equal, hash, sexp]
type op2 =
(* memory *)
| Splat
| Memory
(* comparison *)
| Eq
| Dq
| Lt
| Le
| Ord
| Uno
(* arithmetic *)
| Div
| Rem
(* boolean / bitwise *)
| And
| Or
| Xor
| Shl
| Lshr
| Ashr
(* array/struct *)
| Splat
| Memory
| Update of int
[@@deriving compare, equal, hash, sexp]
type op3 = (* if-then-else *)
| Conditional
[@@deriving compare, equal, hash, sexp]
type op3 = Conditional [@@deriving compare, equal, hash, sexp]
type opN = Concat | Record [@@deriving compare, equal, hash, sexp]
type recN = Record [@@deriving compare, equal, hash, sexp]
type t =
(* nary arithmetic *)
| Add of qset
| Mul of qset
(* nullary *)
| Var of {id: int; name: string}
| Nondet of {msg: string}
| Label of {parent: string; name: string}
(* application *)
| Ap1 of op1 * t
| Ap2 of op2 * t * t
| Ap3 of op3 * t * t * t
| ApN of opN * t vector
(* recursive application *)
| RecN of recN * t vector (** NOTE: cyclic *)
(* numeric constants *)
| Integer of {data: Z.t}
| Float of {data: string}
| Nondet of {msg: string}
| Label of {parent: string; name: string}
[@@deriving compare, equal, hash, sexp]
(* Note: solve (and invariant) requires Qset.min_elt to return a
@ -100,8 +87,6 @@ and T0 : sig
[@@deriving compare, equal, hash, sexp]
type op2 =
| Splat
| Memory
| Eq
| Dq
| Lt
@ -116,6 +101,8 @@ and T0 : sig
| Shl
| Lshr
| Ashr
| Splat
| Memory
| Update of int
[@@deriving compare, equal, hash, sexp]
@ -127,8 +114,6 @@ and T0 : sig
| Add of qset
| Mul of qset
| Var of {id: int; name: string}
| Nondet of {msg: string}
| Label of {parent: string; name: string}
| Ap1 of op1 * t
| Ap2 of op2 * t * t
| Ap3 of op3 * t * t * t
@ -136,6 +121,8 @@ and T0 : sig
| RecN of recN * t vector
| Integer of {data: Z.t}
| Float of {data: string}
| Nondet of {msg: string}
| Label of {parent: string; name: string}
[@@deriving compare, equal, hash, sexp]
end = struct
type qset = Qset.M(T).t [@@deriving compare, equal, hash, sexp]
@ -147,8 +134,6 @@ end = struct
[@@deriving compare, equal, hash, sexp]
type op2 =
| Splat
| Memory
| Eq
| Dq
| Lt
@ -163,6 +148,8 @@ end = struct
| Shl
| Lshr
| Ashr
| Splat
| Memory
| Update of int
[@@deriving compare, equal, hash, sexp]
@ -174,8 +161,6 @@ end = struct
| Add of qset
| Mul of qset
| Var of {id: int; name: string}
| Nondet of {msg: string}
| Label of {parent: string; name: string}
| Ap1 of op1 * t
| Ap2 of op2 * t * t
| Ap3 of op3 * t * t * t
@ -183,6 +168,8 @@ end = struct
| RecN of recN * t vector
| Integer of {data: Z.t}
| Float of {data: string}
| Nondet of {msg: string}
| Label of {parent: string; name: string}
[@@deriving compare, equal, hash, sexp]
end
@ -226,13 +213,18 @@ let rec pp ?is_x fs term =
Trace.pp_styled (get_var_style var) "%%%s" fs name
| Var {name; id} as var ->
Trace.pp_styled (get_var_style var) "%%%s_%d" fs name id
| Nondet {msg} -> pf "nondet \"%s\"" msg
| Label {name} -> pf "%s" name
| Ap2 (Splat, byt, siz) -> pf "%a^%a" pp byt pp siz
| Ap2 (Memory, siz, arr) -> pf "@<1>⟨%a,%a@<1>⟩" pp siz pp arr
| ApN (Concat, args) -> pf "%a" (Vector.pp "@,^" pp) args
| Integer {data} -> Trace.pp_styled `Magenta "%a" fs Z.pp data
| Float {data} -> pf "%s" data
| Nondet {msg} -> pf "nondet \"%s\"" msg
| Label {name} -> pf "%s" name
| Ap1 (Extract {unsigned; bits}, arg) ->
pf "(%s%i)@ %a" (if unsigned then "u" else "i") bits pp arg
| Ap1 (Convert {dst; unsigned= true; src= Integer {bits}}, arg) ->
pf "((%a)(u%i)@ %a)" Typ.pp dst bits pp arg
| Ap1 (Convert {unsigned= true; dst= Integer {bits}; src}, arg) ->
pf "((u%i)(%a)@ %a)" bits Typ.pp src pp arg
| Ap1 (Convert {dst; src}, arg) ->
pf "((%a)(%a)@ %a)" Typ.pp dst Typ.pp src pp arg
| Ap2 (Eq, x, y) -> pf "(%a@ = %a)" pp x pp y
| Ap2 (Dq, x, y) -> pf "(%a@ @<2>≠ %a)" pp x pp y
| Ap2 (Lt, x, y) -> pf "(%a@ < %a)" pp x pp y
@ -266,21 +258,17 @@ let rec pp ?is_x fs term =
| Ap2 (Ashr, x, y) -> pf "(%a@ ashr %a)" pp x pp y
| Ap3 (Conditional, cnd, thn, els) ->
pf "(%a@ ? %a@ : %a)" pp cnd pp thn pp els
| Ap2 (Splat, byt, siz) -> pf "%a^%a" pp byt pp siz
| Ap2 (Memory, siz, arr) -> pf "@<1>⟨%a,%a@<1>⟩" pp siz pp arr
| ApN (Concat, args) -> pf "%a" (Vector.pp "@,^" pp) args
| ApN (Record, elts) -> pf "{%a}" pp_record elts
| RecN (Record, elts) -> pf "{|%a|}" (Vector.pp ",@ " pp) elts
| Ap1 (Select idx, rcd) -> pf "%a[%i]" pp rcd idx
| Ap2 (Update idx, rcd, elt) ->
pf "[%a@ @[| %i → %a@]]" pp rcd idx pp elt
| ApN (Record, elts) -> pf "{%a}" pp_record elts
| RecN (Record, elts) -> pf "{|%a|}" (Vector.pp ",@ " pp) elts
| Ap1 (Extract {unsigned; bits}, arg) ->
pf "(%s%i)@ %a" (if unsigned then "u" else "i") bits pp arg
| Ap1 (Convert {dst; unsigned= true; src= Integer {bits}}, arg) ->
pf "((%a)(u%i)@ %a)" Typ.pp dst bits pp arg
| Ap1 (Convert {unsigned= true; dst= Integer {bits}; src}, arg) ->
pf "((u%i)(%a)@ %a)" bits Typ.pp src pp arg
| Ap1 (Convert {dst; src}, arg) ->
pf "((%a)(%a)@ %a)" Typ.pp dst Typ.pp src pp arg
in
fix_flip pp_ (fun _ _ -> ()) fs term
[@@warning "-9"]
and pp_record fs elts =
[%Trace.fprintf
@ -356,28 +344,15 @@ let invariant e =
Invariant.invariant [%here] e [%sexp_of: t]
@@ fun () ->
match e with
| Var _ | Nondet _ | Label _ | Integer _ | Float _ -> ()
| Ap1 (Extract _, _) -> ()
| Ap1 (Convert {dst; src}, _) -> assert (Typ.convertible src dst)
| Add _ -> assert_polynomial e |> Fn.id
| Mul _ -> assert_monomial e |> Fn.id
| Ap2
( ( Eq | Dq | Lt | Le | Ord | Uno | Div | Rem | And | Or | Xor | Shl
| Lshr | Ashr )
, _
, _ ) ->
()
| ApN (Concat, args) -> assert (Vector.length args <> 1)
| Ap2 (Splat, _, siz) -> (
match siz with
| Integer {data} -> assert (not (Z.equal Z.zero data))
| _ -> () )
| Ap2 (Memory, _, _) -> ()
| Ap1 (Select _, _) -> ()
| Ap3 (Conditional, _, _, _) -> ()
| Ap2 (Update _, _, _) -> ()
| Ap2 (Splat, _, Integer {data}) -> assert (not (Z.equal Z.zero data))
| ApN (Concat, mems) -> assert (Vector.length mems <> 1)
| ApN (Record, elts) | RecN (Record, elts) ->
assert (not (Vector.is_empty elts))
| Ap1 (Convert {dst; src}, _) -> assert (Typ.convertible src dst)
| _ -> ()
[@@warning "-9"]
(** Variables are the terms constructed by [Var] *)
module Var = struct
@ -415,8 +390,8 @@ module Var = struct
Invariant.invariant [%here] x [%sexp_of: t]
@@ fun () -> match x with Var _ -> invariant x | _ -> assert false
let id = function Var {id} -> id | x -> violates invariant x
let name = function Var {name} -> name | x -> violates invariant x
let id = function Var v -> v.id | x -> violates invariant x
let name = function Var v -> v.name | x -> violates invariant x
let of_term = function
| Var _ as v -> Some (v |> check invariant)
@ -490,34 +465,14 @@ module Var = struct
end
end
let fold_terms e ~init ~f =
let fold_terms_ fold_terms_ e s =
let s =
match e with
| Ap1 (_, x) -> fold_terms_ x s
| Ap2 (_, x, y) -> fold_terms_ y (fold_terms_ x s)
| Ap3 (_, x, y, z) -> fold_terms_ z (fold_terms_ y (fold_terms_ x s))
| ApN (_, xs) | RecN (_, xs) ->
Vector.fold ~f:(fun s x -> fold_terms_ x s) xs ~init:s
| Add args | Mul args ->
Qset.fold args ~init:s ~f:(fun arg _ s -> fold_terms_ arg s)
| _ -> s
in
f s e
in
fix fold_terms_ (fun _ s -> s) e init
(** Construct *)
let fold_vars e ~init ~f =
fold_terms e ~init ~f:(fun z -> function
| Var _ as v -> f z (v :> Var.t) | _ -> z )
(* variables *)
let fv e = fold_vars e ~f:Set.add ~init:Var.Set.empty
let var x = x
(** Construct *)
(* constants *)
let var x = x
let nondet msg = Nondet {msg} |> check invariant
let label ~parent ~name = Label {parent; name} |> check invariant
let integer data = Integer {data} |> check invariant
let null = integer Z.zero
let zero = integer Z.zero
@ -527,6 +482,10 @@ let bool b = integer (Z.of_bool b)
let true_ = bool true
let false_ = bool false
let float data = Float {data} |> check invariant
let nondet msg = Nondet {msg} |> check invariant
let label ~parent ~name = Label {parent; name} |> check invariant
(* type conversions *)
let simp_extract ~unsigned bits arg =
match arg with
@ -541,47 +500,7 @@ let simp_convert ~unsigned dst src arg =
integer (Z.extract ~unsigned (min m n) data)
| _ -> Ap1 (Convert {unsigned; dst; src}, arg)
let simp_record elts = ApN (Record, elts)
let simp_select idx rcd = Ap1 (Select idx, rcd)
let simp_update idx rcd elt = Ap2 (Update idx, rcd, elt)
let simp_concat xs =
if Vector.length xs = 1 then Vector.get xs 0
else
let args =
if
Vector.for_all xs ~f:(function
| ApN (Concat, _) -> false
| _ -> true )
then xs
else
Vector.concat
(Vector.fold_right xs ~init:[] ~f:(fun x s ->
match x with
| ApN (Concat, args) -> args :: s
| x -> Vector.of_array [|x|] :: s ))
in
ApN (Concat, args)
let simp_splat byt siz =
match siz with
| Integer {data} when Z.equal Z.zero data -> simp_concat Vector.empty
| _ -> Ap2 (Splat, byt, siz)
let simp_memory siz arr = Ap2 (Memory, siz, arr)
let simp_lt x y =
match (x, y) with
| Integer {data= i}, Integer {data= j} -> bool (Z.lt i j)
| _ -> Ap2 (Lt, x, y)
let simp_le x y =
match (x, y) with
| Integer {data= i}, Integer {data= j} -> bool (Z.leq i j)
| _ -> Ap2 (Le, x, y)
let simp_ord x y = Ap2 (Ord, x, y)
let simp_uno x y = Ap2 (Uno, x, y)
(* arithmetic *)
let sum_mul_const const sum =
assert (not (Q.equal Q.zero const)) ;
@ -731,6 +650,8 @@ let simp_sub x y =
(* x - y ==> x + (-1 * y) *)
| _ -> simp_add2 x (simp_negate y)
(* if-then-else *)
let simp_cond cnd thn els =
match cnd with
(* ¬(true ? t : e) ==> t *)
@ -739,6 +660,17 @@ let simp_cond cnd thn els =
| Integer {data} when Z.is_false data -> els
| _ -> Ap3 (Conditional, cnd, thn, els)
(* boolean / bitwise *)
let rec is_boolean = function
| Ap1 ((Extract {bits= 1; _} | Convert {dst= Integer {bits= 1; _}; _}), _)
|Ap2 ((Eq | Dq | Lt | Le), _, _) ->
true
| Ap2 ((Div | Rem | And | Or | Xor | Shl | Lshr | Ashr), x, y)
|Ap3 (Conditional, _, x, y) ->
is_boolean x || is_boolean y
| _ -> false
let rec simp_and x y =
match (x, y) with
(* i && j *)
@ -773,42 +705,22 @@ let rec simp_or x y =
| _ when equal x y -> x
| _ -> Ap2 (Or, x, y)
let rec is_boolean = function
| Ap1 ((Extract {bits= 1} | Convert {dst= Integer {bits= 1}}), _)
|Ap2 ((Eq | Dq | Lt | Le), _, _) ->
true
| Ap2 ((Div | Rem | And | Or | Xor | Shl | Lshr | Ashr), x, y)
|Ap3 (Conditional, _, x, y) ->
is_boolean x || is_boolean y
| _ -> false
(* comparison *)
let rec simp_not term =
match term with
(* ¬(x = y) ==> x ≠ y *)
| Ap2 (Eq, x, y) -> simp_dq x y
(* ¬(x ≠ y) ==> x = y *)
| Ap2 (Dq, x, y) -> simp_eq x y
(* ¬(x < y) ==> y <= x *)
| Ap2 (Lt, x, y) -> simp_le y x
(* ¬(x <= y) ==> y < x *)
| Ap2 (Le, x, y) -> simp_lt y x
(* ¬(x ≠ nan ∧ y ≠ nan) ==> x = nan y = nan *)
| Ap2 (Ord, x, y) -> simp_uno x y
(* ¬(x = nan y = nan) ==> x ≠ nan ∧ y ≠ nan *)
| Ap2 (Uno, x, y) -> simp_ord x y
(* ¬(a ∧ b) ==> ¬a ¬b *)
| Ap2 (And, x, y) -> simp_or (simp_not x) (simp_not y)
(* ¬(a b) ==> ¬a ∧ ¬b *)
| Ap2 (Or, x, y) -> simp_and (simp_not x) (simp_not y)
(* ¬(c ? t : e) ==> c ? ¬t : ¬e *)
| Ap3 (Conditional, cnd, thn, els) ->
simp_cond cnd (simp_not thn) (simp_not els)
(* ¬i ==> -i-1 *)
| Integer {data} -> integer (Z.lognot data)
(* ¬e ==> true xor e *)
| e -> Ap2 (Xor, true_, e)
let simp_lt x y =
match (x, y) with
| Integer {data= i}, Integer {data= j} -> bool (Z.lt i j)
| _ -> Ap2 (Lt, x, y)
and simp_eq x y =
let simp_le x y =
match (x, y) with
| Integer {data= i}, Integer {data= j} -> bool (Z.leq i j)
| _ -> Ap2 (Le, x, y)
let simp_ord x y = Ap2 (Ord, x, y)
let simp_uno x y = Ap2 (Uno, x, y)
let rec simp_eq x y =
match (x, y) with
(* i = j *)
| Integer {data= i}, Integer {data= j} -> bool (Z.equal i j)
@ -835,6 +747,35 @@ and simp_dq x y =
| Ap2 (Eq, x, y) -> Ap2 (Dq, x, y)
| b -> simp_not b )
(* negation-normal form *)
and simp_not term =
match term with
(* ¬(x = y) ==> x ≠ y *)
| Ap2 (Eq, x, y) -> simp_dq x y
(* ¬(x ≠ y) ==> x = y *)
| Ap2 (Dq, x, y) -> simp_eq x y
(* ¬(x < y) ==> y <= x *)
| Ap2 (Lt, x, y) -> simp_le y x
(* ¬(x <= y) ==> y < x *)
| Ap2 (Le, x, y) -> simp_lt y x
(* ¬(x ≠ nan ∧ y ≠ nan) ==> x = nan y = nan *)
| Ap2 (Ord, x, y) -> simp_uno x y
(* ¬(x = nan y = nan) ==> x ≠ nan ∧ y ≠ nan *)
| Ap2 (Uno, x, y) -> simp_ord x y
(* ¬(a ∧ b) ==> ¬a ¬b *)
| Ap2 (And, x, y) -> simp_or (simp_not x) (simp_not y)
(* ¬(a b) ==> ¬a ∧ ¬b *)
| Ap2 (Or, x, y) -> simp_and (simp_not x) (simp_not y)
(* ¬(c ? t : e) ==> c ? ¬t : ¬e *)
| Ap3 (Conditional, cnd, thn, els) ->
simp_cond cnd (simp_not thn) (simp_not els)
(* ¬i ==> -i-1 *)
| Integer {data} -> integer (Z.lognot data)
(* ¬e ==> true xor e *)
| e -> Ap2 (Xor, true_, e)
(* bitwise *)
let simp_xor x y =
match (x, y) with
(* i xor j *)
@ -871,26 +812,62 @@ let simp_ashr x y =
| e, Integer {data} when Z.equal Z.zero data -> e
| _ -> Ap2 (Ashr, x, y)
(** Access *)
(* memory *)
let iter e ~f =
match e with
| Ap1 (_, x) -> f x
| Ap2 (_, x, y) -> f x ; f y
| Ap3 (_, x, y, z) -> f x ; f y ; f z
| ApN (_, xs) | RecN (_, xs) -> Vector.iter ~f xs
| Add args | Mul args -> Qset.iter ~f:(fun arg _ -> f arg) args
| _ -> ()
let simp_concat xs =
if Vector.length xs = 1 then Vector.get xs 0
else
let args =
if
Vector.for_all xs ~f:(function
| ApN (Concat, _) -> false
| _ -> true )
then xs
else
Vector.concat
(Vector.fold_right xs ~init:[] ~f:(fun x s ->
match x with
| ApN (Concat, args) -> args :: s
| x -> Vector.of_array [|x|] :: s ))
in
ApN (Concat, args)
let fold e ~init:s ~f =
match e with
| Ap1 (_, x) -> f x s
| Ap2 (_, x, y) -> f y (f x s)
| Ap3 (_, x, y, z) -> f z (f y (f x s))
| ApN (_, xs) | RecN (_, xs) ->
Vector.fold ~f:(fun s x -> f x s) xs ~init:s
| Add args | Mul args -> Qset.fold ~f:(fun e _ s -> f e s) args ~init:s
| _ -> s
let simp_splat byt siz =
match siz with
| Integer {data} when Z.equal Z.zero data -> simp_concat Vector.empty
| _ -> Ap2 (Splat, byt, siz)
let simp_memory siz arr = Ap2 (Memory, siz, arr)
(* records *)
let simp_record elts = ApN (Record, elts)
let simp_select idx rcd = Ap1 (Select idx, rcd)
let simp_update idx rcd elt = Ap2 (Update idx, rcd, elt)
let rec_app key =
let memo_id = Hashtbl.create key in
let dummy = null in
Staged.stage
@@ fun ~id op elt_thks ->
match Hashtbl.find memo_id id with
| None ->
(* Add placeholder to prevent computing [elts] in calls to [rec_app]
from [elt_thks] for recursive occurrences of [id]. *)
let elta = Array.create ~len:(Vector.length elt_thks) dummy in
let elts = Vector.of_array elta in
Hashtbl.set memo_id ~key:id ~data:elts ;
Vector.iteri elt_thks ~f:(fun i (lazy elt) -> elta.(i) <- elt) ;
RecN (op, elts) |> check invariant
| Some elts ->
(* Do not check invariant as invariant will be checked above after the
thunks are forced, before which invariant-checking may spuriously
fail. Note that it is important that the value constructed here
shares the array in the memo table, so that the update after
forcing the recursive thunks also updates this value. *)
RecN (op, elts)
(* dispatching for normalization and invariant checking *)
let norm1 op x =
( match op with
@ -927,63 +904,43 @@ let normN op xs =
(match op with Concat -> simp_concat xs | Record -> simp_record xs)
|> check invariant
let addN args = simp_add args |> check invariant
let mulN args = simp_mul args |> check invariant
let concat xs = normN Concat (Vector.of_array xs)
let splat ~byt ~siz = norm2 Splat byt siz
let memory ~siz ~arr = norm2 Memory siz arr
(* exposed interface *)
let extract ?(unsigned = false) ~bits term =
norm1 (Extract {unsigned; bits}) term
let convert ?(unsigned = false) ~dst ~src term =
norm1 (Convert {unsigned; dst; src}) term
let eq = norm2 Eq
let dq = norm2 Dq
let lt = norm2 Lt
let le = norm2 Le
let ord = norm2 Ord
let uno = norm2 Uno
let neg = simp_negate
let add = simp_add2
let sub = simp_sub
let mul = simp_mul2
let neg e = simp_negate e |> check invariant
let add e f = simp_add2 e f |> check invariant
let addN args = simp_add args |> check invariant
let sub e f = simp_sub e f |> check invariant
let mul e f = simp_mul2 e f |> check invariant
let mulN args = simp_mul args |> check invariant
let div = norm2 Div
let rem = norm2 Rem
let and_ = norm2 And
let or_ = norm2 Or
let not_ e = simp_not e |> check invariant
let xor = norm2 Xor
let not_ = simp_not
let shl = norm2 Shl
let lshr = norm2 Lshr
let ashr = norm2 Ashr
let conditional ~cnd ~thn ~els = norm3 Conditional cnd thn els
let splat ~byt ~siz = norm2 Splat byt siz
let memory ~siz ~arr = norm2 Memory siz arr
let concat xs = normN Concat (Vector.of_array xs)
let record elts = normN Record elts
let select ~rcd ~idx = norm1 (Select idx) rcd
let update ~rcd ~idx ~elt = norm2 (Update idx) rcd elt
let rec_app key =
let memo_id = Hashtbl.create key in
let dummy = null in
Staged.stage
@@ fun ~id op elt_thks ->
match Hashtbl.find memo_id id with
| None ->
(* Add placeholder to prevent computing [elts] in calls to [rec_app]
from [elt_thks] for recursive occurrences of [id]. *)
let elta = Array.create ~len:(Vector.length elt_thks) dummy in
let elts = Vector.of_array elta in
Hashtbl.set memo_id ~key:id ~data:elts ;
Vector.iteri elt_thks ~f:(fun i (lazy elt) -> elta.(i) <- elt) ;
RecN (op, elts) |> check invariant
| Some elts ->
(* Do not check invariant as invariant will be checked above after the
thunks are forced, before which invariant-checking may spuriously
fail. Note that it is important that the value constructed here
shares the array in the memo table, so that the update after
forcing the recursive thunks also updates this value. *)
RecN (op, elts)
let extract ?(unsigned = false) ~bits term =
norm1 (Extract {unsigned; bits}) term
let convert ?(unsigned = false) ~dst ~src term =
norm1 (Convert {unsigned; dst; src}) term
let size_of t =
Option.bind (Typ.prim_bit_size_of t) ~f:(fun n ->
if n % 8 = 0 then Some (integer (Z.of_int (n / 8))) else None )
@ -1062,8 +1019,51 @@ let rename sub e =
| Var _ as v -> Some (Var.Subst.apply sub v)
| _ -> None )
(** Traverse *)
let iter e ~f =
match e with
| Ap1 (_, x) -> f x
| Ap2 (_, x, y) -> f x ; f y
| Ap3 (_, x, y, z) -> f x ; f y ; f z
| ApN (_, xs) | RecN (_, xs) -> Vector.iter ~f xs
| Add args | Mul args -> Qset.iter ~f:(fun arg _ -> f arg) args
| _ -> ()
let fold e ~init:s ~f =
match e with
| Ap1 (_, x) -> f x s
| Ap2 (_, x, y) -> f y (f x s)
| Ap3 (_, x, y, z) -> f z (f y (f x s))
| ApN (_, xs) | RecN (_, xs) ->
Vector.fold ~f:(fun s x -> f x s) xs ~init:s
| Add args | Mul args -> Qset.fold ~f:(fun e _ s -> f e s) args ~init:s
| _ -> s
let fold_terms e ~init ~f =
let fold_terms_ fold_terms_ e s =
let s =
match e with
| Ap1 (_, x) -> fold_terms_ x s
| Ap2 (_, x, y) -> fold_terms_ y (fold_terms_ x s)
| Ap3 (_, x, y, z) -> fold_terms_ z (fold_terms_ y (fold_terms_ x s))
| ApN (_, xs) | RecN (_, xs) ->
Vector.fold ~f:(fun s x -> fold_terms_ x s) xs ~init:s
| Add args | Mul args ->
Qset.fold args ~init:s ~f:(fun arg _ s -> fold_terms_ arg s)
| _ -> s
in
f s e
in
fix fold_terms_ (fun _ s -> s) e init
let fold_vars e ~init ~f =
fold_terms e ~init ~f:(fun z -> function
| Var _ as v -> f z (v :> Var.t) | _ -> z )
(** Query *)
let fv e = fold_vars e ~f:Set.add ~init:Var.Set.empty
let is_true = function Integer {data} -> Z.is_true data | _ -> false
let is_false = function Integer {data} -> Z.is_false data | _ -> false
@ -1082,7 +1082,7 @@ let classify = function
| Add _ | Mul _ -> `Interpreted
| Ap2 ((Eq | Dq), _, _) -> `Simplified
| Ap1 _ | Ap2 _ | Ap3 _ | ApN _ -> `Uninterpreted
| RecN _ | Var _ | Nondet _ | Label _ | Integer _ | Float _ -> `Atomic
| RecN _ | Var _ | Integer _ | Float _ | Nondet _ | Label _ -> `Atomic
let solve e f =
[%Trace.call fun {pf} -> pf "%a@ %a" pp e pp f]

@ -8,13 +8,7 @@
(** Terms
Pure (heap-independent) terms are complex arithmetic, bitwise-logical,
etc. operations over literal values and variables.
Terms for operations that are uninterpreted in the analyzer are
represented in curried form, where [App] is an application of a function
symbol to an argument. This is done to simplify the definition of
'subterm' and make it explicit. The specific constructor functions
indicate and check the expected arity of the function symbols. *)
etc. operations over literal values and variables. *)
type comparator_witness
@ -32,8 +26,6 @@ type op1 =
[@@deriving compare, equal, hash, sexp]
type op2 =
| Splat (** Iterated concatenation of a single byte *)
| Memory (** Size-tagged byte-array *)
| Eq (** Equal test *)
| Dq (** Disequal test *)
| Lt (** Less-than test *)
@ -48,6 +40,8 @@ type op2 =
| Shl (** Shift left, bitwise *)
| Lshr (** Logical shift right, bitwise *)
| Ashr (** Arithmetic shift right, bitwise *)
| Splat (** Iterated concatenation of a single byte *)
| Memory (** Size-tagged byte-array *)
| Update of int (** Constant record with updated index *)
[@@deriving compare, equal, hash, sexp]
@ -59,42 +53,38 @@ type opN =
| Record (** Record (array / struct) constant *)
[@@deriving compare, equal, hash, sexp]
type recN =
| Record
(** Record constant that may recursively refer to itself
(transitively) from its args. NOTE: represented by cyclic values. *)
type recN = Record (** Recursive record (array / struct) constant *)
[@@deriving compare, equal, hash, sexp]
type qset = (t, comparator_witness) Qset.t
and t = private
| Add of qset (** Addition *)
| Mul of qset (** Multiplication *)
| Add of qset (** Sum of terms with rational coefficients *)
| Mul of qset (** Product of terms with rational exponents *)
| Var of {id: int; name: string} (** Local variable / virtual register *)
| Ap1 of op1 * t (** Unary application *)
| Ap2 of op2 * t * t (** Binary application *)
| Ap3 of op3 * t * t * t (** Ternary application *)
| ApN of opN * t vector (** N-ary application *)
| RecN of recN * t vector
(** Recursive n-ary application, may recursively refer to itself
(transitively) from its args. NOTE: represented by cyclic values. *)
| Integer of {data: Z.t} (** Integer constant *)
| Float of {data: string} (** Floating-point constant *)
| Nondet of {msg: string}
(** Anonymous local variable with arbitrary value, representing
non-deterministic approximation of value described by [msg] *)
| Label of {parent: string; name: string}
(** Address of named code block within parent function *)
| Ap1 of op1 * t
| Ap2 of op2 * t * t
| Ap3 of op3 * t * t * t
| ApN of opN * t vector
| RecN of recN * t vector
| Integer of {data: Z.t}
(** Integer constant, or if [typ] is a [Pointer], null pointer value
that never refers to an object *)
| Float of {data: string} (** Floating-point constant *)
[@@deriving compare, equal, hash, sexp]
val comparator : (t, comparator_witness) Comparator.t
type term = t
val pp_full : ?is_x:(term -> bool) -> t pp
val pp_full : ?is_x:(t -> bool) -> t pp
val pp : t pp
val invariant : t -> unit
type term = t
(** Term.Var is re-exported as Var *)
module Var : sig
type t = private term [@@deriving compare, equal, hash, sexp]
@ -140,65 +130,86 @@ end
(** Construct *)
(* variables *)
val var : Var.t -> t
(* constants *)
val nondet : string -> t
val label : parent:string -> name:string -> t
val null : t
val bool : bool -> t
val true_ : t
val false_ : t
val null : t
val integer : Z.t -> t
val zero : t
val one : t
val minus_one : t
val splat : byt:t -> siz:t -> t
val memory : siz:t -> arr:t -> t
val concat : t array -> t
val bool : bool -> t
val integer : Z.t -> t
val rational : Q.t -> t
val float : string -> t
(* type conversions *)
val extract : ?unsigned:bool -> bits:int -> t -> t
val convert : ?unsigned:bool -> dst:Typ.t -> src:Typ.t -> t -> t
(* comparisons *)
val eq : t -> t -> t
val dq : t -> t -> t
val lt : t -> t -> t
val le : t -> t -> t
val ord : t -> t -> t
val uno : t -> t -> t
(* arithmetic *)
val neg : t -> t
val add : t -> t -> t
val sub : t -> t -> t
val mul : t -> t -> t
val div : t -> t -> t
val rem : t -> t -> t
(* boolean / bitwise *)
val and_ : t -> t -> t
val or_ : t -> t -> t
val xor : t -> t -> t
val not_ : t -> t
(* bitwise *)
val xor : t -> t -> t
val shl : t -> t -> t
val lshr : t -> t -> t
val ashr : t -> t -> t
(* if-then-else *)
val conditional : cnd:t -> thn:t -> els:t -> t
(* memory contents *)
val splat : byt:t -> siz:t -> t
val memory : siz:t -> arr:t -> t
val concat : t array -> t
(* records (struct / array values) *)
val record : t vector -> t
val select : rcd:t -> idx:int -> t
val update : rcd:t -> idx:int -> elt:t -> t
val extract : ?unsigned:bool -> bits:int -> t -> t
val convert : ?unsigned:bool -> dst:Typ.t -> src:Typ.t -> t -> t
val size_of : Typ.t -> t option
(* recursive n-ary application *)
val rec_app :
(module Hashtbl.Key with type t = 'id)
-> (id:'id -> recN -> t lazy_t vector -> t) Staged.t
(** Access *)
val iter : t -> f:(t -> unit) -> unit
val fold_vars : t -> init:'a -> f:('a -> Var.t -> 'a) -> 'a
val fold_terms : t -> init:'a -> f:('a -> t -> 'a) -> 'a
val fold : t -> init:'a -> f:(t -> 'a -> 'a) -> 'a
val size_of : Typ.t -> t option
(** Transform *)
val map : t -> f:(t -> t) -> t
val rename : Var.Subst.t -> t -> t
(** Traverse *)
val iter : t -> f:(t -> unit) -> unit
val fold : t -> init:'a -> f:(t -> 'a -> 'a) -> 'a
val fold_vars : t -> init:'a -> f:('a -> Var.t -> 'a) -> 'a
val fold_terms : t -> init:'a -> f:('a -> t -> 'a) -> 'a
(** Query *)
val fv : t -> Var.Set.t

Loading…
Cancel
Save