diff --git a/sledge/src/llair/exp.ml b/sledge/src/llair/exp.ml index e11783ae3..62ccb7b3d 100644 --- a/sledge/src/llair/exp.ml +++ b/sledge/src/llair/exp.ml @@ -7,476 +7,596 @@ (** Expressions *) -type t = - | Var of {name: string; typ: Typ.t; loc: Loc.t} - | Global of {name: string; init: t option; typ: Typ.t; loc: Loc.t} - | Nondet of {typ: Typ.t; loc: Loc.t; msg: string} - | Label of {parent: string; name: string; loc: Loc.t} - | Null of {typ: Typ.t} - | App of {op: t; arg: t; loc: Loc.t} - | AppN of {op: t; args: t vector; loc: Loc.t} - (* NOTE: may be cyclic *) - | PtrFld of {fld: int} - | PtrIdx - | PrjFld of {fld: int} - | PrjIdx - | UpdFld of {fld: int} - | UpdIdx - | Integer of {data: Z.t; typ: Typ.t} - | Float of {data: string; typ: Typ.t} - | Array of {typ: Typ.t} - | Struct of {typ: Typ.t} - | Cast of {typ: Typ.t} - | Conv of {signed: bool; typ: Typ.t} - | Select - (* binary: comparison *) - | Eq - | Ne - | Gt - | Ge - | Lt - | Le - | Ugt - | Uge - | Ult - | Ule - | Ord - | Uno - (* binary: boolean / bitwise *) - | And - | Or - | Xor - | Shl - | LShr - | AShr - (* binary: arithmetic *) - | Add - | Sub - | Mul - | Div - | UDiv - | Rem - | URem -[@@deriving compare, sexp] - -let equal = [%compare.equal: t] - -let uncurry exp = - let rec uncurry_ args op = - match op with - | App {op; arg} -> uncurry_ (arg :: args) op - | AppN {op; args} -> (op, Vector.to_list args) - | _ -> (op, args) - in - uncurry_ [] exp - - -let rec fmt ff exp = - let pf fmt = - Format.pp_open_box ff 2 ; - Format.kfprintf (fun ff -> Format.pp_close_box ff ()) ff fmt - in - match[@warning "p"] uncurry exp with - | Var {name}, [] -> pf "%%%s" name - | Global {name}, [] -> - pf "@%s%t" name (fun ff -> - let demangled = Llvm.demangle name in - if not (String.is_empty demangled || String.equal name demangled) - then Format.fprintf ff "“%s”" demangled ) - | Nondet {msg}, [] -> pf "nondet \"%s\"" msg - | Label {name}, [] -> pf "%s" name - | Null _, [] -> pf "null" - | Integer {data}, [] -> 
pf "%a" Z.pp_print data - | Float {data}, [] -> pf "%s" data - | PtrFld {fld}, [ptr] -> pf "%a ⊕ %i" fmt ptr fld - | PtrIdx, [arr; idx] -> pf "%a ⊕ %a" fmt arr fmt idx - | PrjFld {fld}, [ptr] -> pf "%a[%i]" fmt ptr fld - | PrjIdx, [arr; idx] -> pf "%a[%a]" fmt arr fmt idx - | UpdFld {fld}, [agg; elt] -> - pf "{%a@ @[| %i → %a@]}" fmt agg fld fmt elt - | UpdIdx, [agg; elt; idx] -> pf "[%a | %a → %a]" fmt agg fmt idx fmt elt - | Array _, elts -> pf "[%a]" (list_fmt ",@ " fmt) elts - | Struct _, elts -> pf "{%a}" (list_fmt ",@ " fmt) elts - | Cast {typ}, [arg] -> pf "(@[(%a)@ %a@])" Typ.fmt typ fmt arg - | Conv {typ}, [arg] -> pf "(@[%c%a>@ %a@])" '<' Typ.fmt typ fmt arg - | Select, [cnd; thn; els] -> pf "(%a@ ? %a@ : %a)" fmt cnd fmt thn fmt els - | Eq, [x; y] -> pf "(%a@ = %a)" fmt x fmt y - | Ne, [x; y] -> pf "(%a@ != %a)" fmt x fmt y - | Gt, [x; y] -> pf "(%a@ > %a)" fmt x fmt y - | Ge, [x; y] -> pf "(%a@ >= %a)" fmt x fmt y - | Lt, [x; y] -> pf "(%a@ < %a)" fmt x fmt y - | Le, [x; y] -> pf "(%a@ <= %a)" fmt x fmt y - | Ugt, [x; y] -> pf "(%a@ u> %a)" fmt x fmt y - | Uge, [x; y] -> pf "(%a@ u>= %a)" fmt x fmt y - | Ult, [x; y] -> pf "(%a@ u< %a)" fmt x fmt y - | Ule, [x; y] -> pf "(%a@ u<= %a)" fmt x fmt y - | Ord, [x; y] -> pf "(%a@ ord %a)" fmt x fmt y - | Uno, [x; y] -> pf "(%a@ uno %a)" fmt x fmt y - | And, [x; y] -> pf "(%a@ && %a)" fmt x fmt y - | Or, [x; y] -> pf "(%a@ || %a)" fmt x fmt y - | Xor, [x; y] -> pf "(%a@ ^ %a)" fmt x fmt y - | Shl, [x; y] -> pf "(%a@ << %a)" fmt x fmt y - | LShr, [x; y] -> pf "(%a@ >> %a)" fmt x fmt y - | AShr, [x; y] -> pf "(%a@ >>a %a)" fmt x fmt y - | Add, [x; y] -> pf "(%a@ + %a)" fmt x fmt y - | Sub, [x; y] -> pf "(%a@ - %a)" fmt x fmt y - | Mul, [x; y] -> pf "(%a@ * %a)" fmt x fmt y - | Div, [x; y] -> pf "(%a@ / %a)" fmt x fmt y - | UDiv, [x; y] -> pf "(%a@ u/ %a)" fmt x fmt y - | Rem, [x; y] -> pf "(%a@ %% %a)" fmt x fmt y - | URem, [x; y] -> pf "(%a@ u%% %a)" fmt x fmt y - - -(** Queries *) - -let rec typ_of : t -> Typ.t 
= function[@warning "p"] - | Var {typ} - |Global {typ} - |Nondet {typ} - |Null {typ} - |Integer {typ} - |Float {typ} - |Array {typ} - |Struct {typ} - |App {op= Cast {typ} | Conv {typ}} -> - typ - | Label _ -> Typ.i8p - | App {op= PtrFld {fld}; arg} -> ( - match[@warning "p"] typ_of arg - with Pointer {elt= Tuple {elts} | Struct {elts}} -> - Typ.mkPointer ~elt:(Vector.get elts fld) ) - | App {op= App {op= PtrIdx; arg}} -> ( - match[@warning "p"] typ_of arg with Pointer {elt= Array {elt}} -> - Typ.mkPointer ~elt ) - | App {op= PrjFld {fld}; arg} -> ( - match[@warning "p"] typ_of arg with - | Tuple {elts} | Struct {elts} -> Vector.get elts fld ) - | App {op= App {op= PrjIdx; arg}} -> ( - match[@warning "p"] typ_of arg with Array {elt} -> elt ) - | App {op= App {op= UpdFld _; arg}} - |App {op= App {op= App {op= UpdIdx; arg}}} -> - typ_of arg - | App - { op= - App - { op= - ( Eq | Ne | Gt | Ge | Lt | Le | Ugt | Uge | Ult | Ule | Ord - | Uno ) } } -> - Typ.i1 - | App - { op= - App - { op= - ( And | Or | Xor | Shl | LShr | AShr | Add | Sub | Mul | Div - | UDiv | Rem | URem ) - ; arg } } - |App {op= App {op= App {op= Select}}; arg} -> - typ_of arg - | AppN {op} -> typ_of op - - -let valid_fld fld elts = 0 <= fld && fld < Vector.length elts - -(** Re-exported modules *) - -(* Variables are the expressions constructed by [Var] *) -module Var = struct - module T = struct - type nonrec t = t [@@deriving compare, sexp] - - let equal = equal - end - - include T - include Comparator.Make (T) - - let fmt = fmt - - let mk ?(loc = Loc.none) name typ = Var {name; typ; loc} +module T0 = struct + type t = + | Var of {id: int; name: string} + | Nondet of {msg: string} + | Label of {parent: string; name: string} + | App of {op: t; arg: t} + (* pointer and memory constants and operations *) + | Null + | Splat + | Memory + | Concat + (* numeric constants *) + | Integer of {data: Z.t} + | Float of {data: string} + (* binary: comparison *) + | Eq + | Dq + | Gt + | Ge + | Lt + | Le + | Ugt + 
| Uge + | Ult + | Ule + | Ord + | Uno + (* binary: arithmetic, numeric and pointer *) + | Add + | Sub + | Mul + | Div + | Udiv + | Rem + | Urem + (* binary: boolean / bitwise *) + | And + | Or + | Xor + | Shl + | Lshr + | Ashr + (* ternary: conditional *) + | Conditional + (* array/struct constants and operations *) + | Record + | Select + | Update + | Struct_rec of {elts: t vector} (** NOTE: may be cyclic *) + (* unary: conversion *) + | Convert of {signed: bool; dst: Typ.t; src: Typ.t} + [@@deriving compare, hash, sexp] + + let equal = [%compare.equal: t] +end - let name = function[@warning "p"] Var {name} -> name +module T = struct + include T0 + include Comparator.Make (T0) + + let fix (f : (t -> 'a as 'f) -> 'f) (bot : 'f) (e : t) : 'a = + let rec fix_f seen e = + match e with + | Struct_rec _ -> + if List.mem ~equal:( == ) seen e then f bot e + else f (fix_f (e :: seen)) e + | _ -> f (fix_f seen) e + in + let rec fix_f_seen_nil e = + match e with + | Struct_rec _ -> f (fix_f [e]) e + | _ -> f fix_f_seen_nil e + in + fix_f_seen_nil e - let typ = function[@warning "p"] Var {typ} -> typ + let fix_flip (f : ('z -> t -> 'a as 'f) -> 'f) (bot : 'f) (z : 'z) (e : t) + = + fix (fun f' e z -> f (fun z e -> f' e z) z e) (fun e z -> bot z e) e z - let loc = function[@warning "p"] Var {loc} -> loc + let uncurry = + let rec uncurry_ args = function + | App {op; arg} -> uncurry_ (arg :: args) op + | op -> (op, args) + in + uncurry_ [] + + let pp fs exp = + let pp_ pp fs exp = + let pf fmt = + Format.pp_open_box fs 2 ; + Format.kfprintf (fun fs -> Format.pp_close_box fs ()) fs fmt + in + match exp with + | Var {name; id= 0} -> pf "%%%s" name + | Var {name; id} -> pf "%%%s_%d" name id + | Nondet {msg} -> pf "nondet \"%s\"" msg + | Label {name} -> pf "%s" name + | Null -> pf "null" + | Splat -> pf "^" + | Memory -> pf "⟨_,_⟩" + | App {op= App {op= Memory; arg= siz}; arg= bytes} -> + pf "@<1>⟨%a,%a@<1>⟩" pp siz pp bytes + | Concat -> pf "^" + | Integer {data} -> pf "%a" 
Z.pp_print data + | Float {data} -> pf "%s" data + | Eq -> pf "=" + | Dq -> pf "!=" + | Gt -> pf ">" + | Ge -> pf ">=" + | Lt -> pf "<" + | Le -> pf "<=" + | Ugt -> pf "u>" + | Uge -> pf "u>=" + | Ult -> pf "u<" + | Ule -> pf "u<=" + | Ord -> pf "ord" + | Uno -> pf "uno" + | Add -> pf "+" + | Sub -> pf "-" + | Mul -> pf "*" + | Div -> pf "/" + | Udiv -> pf "udiv" + | Rem -> pf "rem" + | Urem -> pf "urem" + | And -> pf "&&" + | Or -> pf "||" + | Xor -> pf "xor" + | App {op= App {op= Xor; arg}; arg= Integer {data}} + when Z.equal Z.minus_one data -> + pf "¬%a" pp arg + | App {op= App {op= Xor; arg= Integer {data}}; arg} + when Z.equal Z.minus_one data -> + pf "¬%a" pp arg + | Shl -> pf "shl" + | Lshr -> pf "lshr" + | Ashr -> pf "ashr" + | Conditional -> pf "(_?_:_)" + | App + {op= App {op= App {op= Conditional; arg= cnd}; arg= thn}; arg= els} + -> + pf "(%a@ ? %a@ : %a)" pp cnd pp thn pp els + | Select -> pf "_[_]" + | App {op= App {op= Select; arg= rcd}; arg= idx} -> + pf "%a[%a]" pp rcd pp idx + | Update -> pf "[_|_→_]" + | App {op= App {op= App {op= Update; arg= rcd}; arg= elt}; arg= idx} + -> + pf "[%a@ @[| %a → %a@]]" pp rcd pp idx pp elt + | Record -> pf "{_}" + | App {op; arg} -> ( + match uncurry exp with + | Record, elts -> pf "{@[%a@]}" (List.pp ",@ " pp) elts + | op, [x; y] -> pf "(%a@ %a %a)" pp x pp op pp y + | _ -> pf "(%a@ %a)" pp op pp arg ) + | Struct_rec {elts} -> pf "{|%a|}" (Vector.pp ",@ " pp) elts + | Convert {dst; src} -> pf "(%a)(%a)" Typ.pp dst Typ.pp src + in + fix_flip pp_ (fun _ _ -> ()) fs exp end -(* Globals are the expressions constructed by [Global] *) -module Global = struct - type init = t +include T - module T = struct - type nonrec t = t [@@deriving compare, sexp] +type exp = t - let equal = equal - - let hash = Hashtbl.hash - end +(** Invariant *) +let invariant ?(partial = false) e = + Invariant.invariant [%here] e [%sexp_of: t] + @@ fun () -> + let op, args = uncurry e in + let assert_arity arity = + let nargs = List.length args 
in + assert (nargs = arity || (partial && nargs < arity)) + in + match op with + | Var _ | Nondet _ | Label _ | Null | Integer _ | Float _ -> + assert_arity 0 + | Convert {dst; src} -> + assert (Typ.convertible src dst) ; + assert_arity 1 + | Splat | Memory | Concat | Eq | Dq | Gt | Ge | Lt | Le | Ugt | Uge + |Ult | Ule | Ord | Uno | Add | Sub | Mul | Div | Udiv | Rem | Urem + |And | Or | Xor | Shl | Lshr | Ashr | Select -> + assert_arity 2 + | Conditional | Update -> assert_arity 3 + | Record -> assert (partial || not (List.is_empty args)) + | Struct_rec {elts} -> + assert (not (Vector.is_empty elts)) ; + assert_arity 0 + | App _ -> fail "uncurry cannot return App" () + +(** Variables are the expressions constructed by [Var] *) +module Var = struct include T - include Comparator.Make (T) - - let fmt_defn ff g = - let[@warning "p"] (Global {init; typ}) = g in - let[@warning "p"] (Typ.Pointer {elt= typ}) = typ in - Format.fprintf ff "@[<2>%a %a%a@]" Typ.fmt typ fmt g - (option_fmt " =@ @[%a@]" fmt) - init - - let fmt = fmt + type var = t - let mk ?init ?(loc = Loc.none) name typ = - assert ( - Option.for_all init ~f:(fun exp -> - Typ.equal typ (Typ.mkPointer ~elt:(typ_of exp)) ) ) ; - Global {name; init; typ; loc} + module Set = struct + include ( + Set : + module type of Set with type ('elt, 'cmp) t := ('elt, 'cmp) Set.t ) + type t = Set.M(T).t [@@deriving compare, sexp] - let of_exp e = match e with Global _ -> Some e | _ -> None - - let name = function[@warning "p"] Global {name} -> name - - let typ = function[@warning "p"] Global {typ} -> typ + let pp vs = Set.pp T.pp vs + let empty = Set.empty (module T) + let of_vector = Set.of_vector (module T) + end - let loc = function[@warning "p"] Global {loc} -> loc + let invariant x = + Invariant.invariant [%here] x [%sexp_of: t] + @@ fun () -> match x with Var _ -> invariant x | _ -> assert false + + let id = function Var {id} -> id | x -> violates invariant x + let name = function Var {name} -> name | x -> violates 
invariant x + + let of_exp = function + | Var _ as v -> Some (v |> check invariant) + | _ -> None + + let program name = Var {id= 0; name} |> check invariant + + let fresh name ~(wrt : Set.t) = + let max = match Set.max_elt wrt with None -> 0 | Some max -> id max in + let x' = Var {name; id= max + 1} in + (x', Set.add wrt x') + + (** Variable renaming substitutions *) + module Subst = struct + type t = T.t Map.M(T).t [@@deriving compare, sexp] + + let invariant s = + Invariant.invariant [%here] s [%sexp_of: t] + @@ fun () -> + let domain, range = + Map.fold s ~init:(Set.empty, Set.empty) + ~f:(fun ~key ~data (domain, range) -> + assert (not (Set.mem range data)) ; + (Set.add domain key, Set.add range data) ) + in + assert (Set.disjoint domain range) + + let pp fs s = + Format.fprintf fs "@[<1>[%a]@]" + (List.pp ",@ " (fun fs (k, v) -> + Format.fprintf fs "@[[%a ↦ %a]@]" T.pp k T.pp v )) + (Map.to_alist s) + + let empty = Map.empty (module T) + let is_empty = Map.is_empty + + let freshen vs ~wrt = + let xs = Set.inter wrt vs in + let wrt = Set.union wrt vs in + Set.fold xs ~init:(empty, wrt) ~f:(fun (sub, wrt) x -> + let x', wrt = fresh (name x) ~wrt in + let sub = Map.add_exn sub ~key:x ~data:x' in + (sub, wrt) ) + |> fst |> check invariant + + let extend sub ~replace ~with_ = + ( match Map.add sub ~key:replace ~data:with_ with + | `Duplicate -> sub + | `Ok sub -> + Map.map_preserving_phys_equal sub ~f:(fun v -> + if equal v replace then with_ else v ) ) + |> check invariant + + let invert sub = + Map.fold sub ~init:empty ~f:(fun ~key ~data sub' -> + Map.add_exn sub' ~key:data ~data:key ) + |> check invariant + + let exclude sub vs = + Set.fold vs ~init:sub ~f:Map.remove |> check invariant + + let domain sub = + Map.fold sub ~init:Set.empty ~f:(fun ~key ~data:_ domain -> + Set.add domain key ) + + let range sub = + Map.fold sub ~init:Set.empty ~f:(fun ~key:_ ~data range -> + Set.add range data ) + + let apply sub v = try Map.find_exn sub v with Caml.Not_found -> v 
+ + let apply_set sub vs = + Map.fold sub ~init:vs ~f:(fun ~key ~data vs -> + let vs' = Set.remove vs key in + if Set.to_tree vs' == Set.to_tree vs then vs + else ( + assert (not (Set.equal vs' vs)) ; + Set.add vs' data ) ) + + let close_set sub vs = + Map.fold sub ~init:vs ~f:(fun ~key:_ ~data vs -> Set.add vs data) + end end -(** Constructors *) - -let locate loc exp = - match exp with - | Var {name; typ} -> Var {name; typ; loc} - | Global {name; init; typ} -> Global {name; init; typ; loc} - | Nondet {typ; msg} -> Nondet {typ; loc; msg} - | Label {parent; name} -> Label {parent; name; loc} - | App {op; arg} -> App {op; arg; loc} - | AppN {op; args} -> AppN {op; args; loc} - | _ -> exp - - -let mkApp1 op arg = App {op; arg; loc= Loc.none} - -let mkApp2 op x y = mkApp1 (mkApp1 op x) y - -let mkApp3 op x y z = mkApp1 (mkApp1 (mkApp1 op x) y) z - -let mkAppN op args = AppN {op; args; loc= Loc.none} - -let mkVar = Fn.id - -let mkGlobal = Fn.id - -let mkNondet (typ : Typ.t) msg = - assert (match typ with Function _ -> false | _ -> true) ; - Nondet {typ; loc= Loc.none; msg} - - -let mkLabel ~parent ~name = Label {parent; name; loc= Loc.none} - -let mkNull (typ : Typ.t) = - assert (match typ with Opaque _ | Function _ -> false | _ -> true) ; - Null {typ} - - -let mkPtrFld ~ptr ~fld = - assert ( - match typ_of ptr with - | Pointer {elt= Tuple {elts} | Struct {elts}} -> valid_fld fld elts - | _ -> false ) ; - mkApp1 (PtrFld {fld}) ptr - - -let mkPtrIdx ~ptr ~idx = - assert ( - match (typ_of ptr, typ_of idx) with - | Pointer {elt= Array _}, Integer _ -> true - | _ -> false ) ; - mkApp2 PtrIdx ptr idx - - -let mkPrjFld ~agg ~fld = - assert ( - match typ_of agg with - | Tuple {elts} | Struct {elts} -> valid_fld fld elts - | _ -> false ) ; - mkApp1 (PrjFld {fld}) agg - - -let mkPrjIdx ~arr ~idx = - assert ( - match (typ_of arr, typ_of idx) with - | Array _, Integer _ -> true - | _ -> false ) ; - mkApp2 PrjIdx arr idx - - -let mkUpdFld ~agg ~elt ~fld = - assert ( - match typ_of 
agg with - | Tuple {elts} | Struct {elts} -> - valid_fld fld elts && Typ.equal (Vector.get elts fld) (typ_of elt) - | _ -> false ) ; - mkApp2 (UpdFld {fld}) agg elt - - -let mkUpdIdx ~arr ~elt ~idx = - assert ( - match (typ_of arr, typ_of idx) with - | Array {elt= typ}, Integer _ -> Typ.equal typ (typ_of elt) - | _ -> false ) ; - mkApp3 UpdIdx arr elt idx - - -let mkInteger data (typ : Typ.t) = - assert ( - let in_range num bits = - let lb = Z.(-(if bits = 1 then ~$1 else ~$1 lsl Int.(bits - 1))) - and ub = Z.(~$1 lsl bits) in - Z.(leq lb num && lt num ub) +let fold_exps e ~init ~f = + let fold_exps_ fold_exps_ e z = + let z = + match e with + | App {op; arg} -> fold_exps_ op (fold_exps_ arg z) + | Struct_rec {elts} -> + Vector.fold elts ~init:z ~f:(fun z elt -> fold_exps_ elt z) + | _ -> z in - match typ with Integer {bits} -> in_range data bits | _ -> false ) ; - Integer {data; typ} - - -let mkBool b = mkInteger (Z.of_int (Bool.to_int b)) Typ.i1 - -let mkFloat data (typ : Typ.t) = - assert (match typ with Float _ -> true | _ -> false) ; - Float {data; typ} - - -let mkArray elts (typ : Typ.t) = - assert ( - match typ with - | Array {elt= elt_typ; len} -> - Vector.for_all elts ~f:(fun elt -> Typ.equal (typ_of elt) elt_typ) - && Vector.length elts = len - | _ -> false ) ; - mkAppN (Array {typ}) elts - - -let mkStruct elts (typ : Typ.t) = - assert ( - match typ with - | Tuple {elts= elt_typs} | Struct {elts= elt_typs} -> - Vector.for_all2_exn elts elt_typs ~f:(fun elt elt_typ -> - Typ.equal (typ_of elt) elt_typ ) - | _ -> false ) ; - mkAppN (Struct {typ}) elts - - -let mkStruct_rec key = - let memo_id = Hashtbl.create key () in - let dummy = Null {typ= Typ.mkBytes} in - let mkStruct_ ~id elt_thks typ = - match Hashtbl.find memo_id id with - | None -> - (* Add placeholder to prevent computing [elts] in calls to - [mkStruct_rec] from [elts] for recursive occurrences of [id]. 
*) - let elta = Array.create ~len:(Vector.length elt_thks) dummy in - let elts = Vector.of_array elta in - Hashtbl.set memo_id ~key:id ~data:elts ; - Vector.iteri elt_thks ~f:(fun i elt_thk -> - elta.(i) <- Lazy.force elt_thk ) ; - mkStruct elts typ - | Some elts -> - (* Do not call [mkStruct] as types will be checked by the call to - [mkStruct] above after the thunks are forced, before which - type-checking may spuriously fail. Note that it is important that - the value constructed here shares the array in the memo table, so - that the update after forcing the recursive thunks also updates - this value. *) - mkAppN (Struct {typ}) elts + f z e in - Staged.stage mkStruct_ - - -let mkCast exp typ = - assert (Typ.compatible (typ_of exp) typ) ; - mkApp1 (Cast {typ}) exp - - -let mkConv exp ?(signed = false) typ = - assert (Typ.compatible (typ_of exp) typ) ; - mkApp1 (Conv {signed; typ}) exp - - -let mkSelect ~cnd ~thn ~els = - assert ( - match (typ_of cnd, typ_of thn, typ_of els) with - | Integer {bits= 1}, s, t -> Typ.equal s t - | Array {elt= Integer {bits= 1}; len= m}, (Array {len= n} as s), t -> - m = n && Typ.equal s t - | _ -> false ) ; - mkApp3 Select cnd thn els - - -let binop op x y = - assertf - (let typ = typ_of x in - match (op, typ) with - | ( (Eq | Ne | Gt | Ge | Lt | Le | Ugt | Uge | Ult | Ule) - , (Integer _ | Float _ | Pointer _) ) - |(Add | Sub | Mul | Div | Rem), (Integer _ | Float _) - |(And | Or | Xor | Shl | LShr | AShr | UDiv | URem), Integer _ - |(Ord | Uno), Float _ -> - Typ.equal typ (typ_of y) - | _ -> false) - "ill-typed: %a" fmt (mkApp2 op x y) () ; - mkApp2 op x y - - -let mkEq = binop Eq - -let mkNe = binop Ne - -let mkGt = binop Gt - -let mkGe = binop Ge - -let mkLt = binop Lt - -let mkLe = binop Le - -let mkUgt = binop Ugt - -let mkUge = binop Uge - -let mkUlt = binop Ult - -let mkUle = binop Ule - -let mkAnd = binop And - -let mkOr = binop Or - -let mkXor = binop Xor - -let mkShl = binop Shl - -let mkLShr = binop LShr - -let mkAShr = 
binop AShr - -let mkAdd = binop Add - -let mkSub = binop Sub - -let mkMul = binop Mul - -let mkDiv = binop Div - -let mkUDiv = binop UDiv - -let mkRem = binop Rem - -let mkURem = binop URem - -let mkOrd = binop Ord + fix fold_exps_ (fun _ z -> z) e init + +let fold_vars e ~init ~f = + fold_exps e ~init ~f:(fun z -> function + | Var _ as v -> f z (v :> Var.t) | _ -> z ) + +let fv e = fold_vars e ~f:Set.add ~init:Var.Set.empty + +(** Construct *) + +let var x = x +let nondet msg = Nondet {msg} |> check invariant +let label ~parent ~name = Label {parent; name} |> check invariant +let null = Null |> check invariant +let integer data = Integer {data} |> check invariant +let bool b = integer (Z.of_int (Bool.to_int b)) +let float data = Float {data} |> check invariant + +let simp_convert signed (dst : Typ.t) (src : Typ.t) arg = + match (signed, dst, src, arg) with + | _, Integer {bits}, _, Integer {data} when Z.numbits data <= bits -> + integer data + | false, Integer {bits= m}, Integer {bits= n}, _ when m >= n -> arg + | _ -> App {op= Convert {signed; dst; src}; arg} + +let rec simp_eq x y = + match (x, y) with + (* i = j ==> i=j *) + | Integer {data= i}, Integer {data= j} -> bool (Z.equal i j) + (* e+i = j ==> e = j-i *) + | ( App {op= App {op= Add; arg= e}; arg= Integer {data= i}} + , Integer {data= j} ) -> + simp_eq e (integer (Z.sub j i)) + (* e = e ==> 1 *) + | _ when equal x y -> bool true + | _ -> App {op= App {op= Eq; arg= x}; arg= y} + +let simp_dq x y = + match (x, y) with + (* i != j ==> i!=j *) + | Integer {data= i}, Integer {data= j} -> bool (not (Z.equal i j)) + (* e != e ==> 0 *) + | _ when equal x y -> bool false + | _ -> App {op= App {op= Dq; arg= x}; arg= y} + +let simp_gt x y = + match (x, y) with + (* i > j ==> i>j *) + | Integer {data= i}, Integer {data= j} -> bool (Z.gt i j) + | _ -> App {op= App {op= Gt; arg= x}; arg= y} + +let simp_ge x y = + match (x, y) with + (* i >= j ==> i>=j *) + | Integer {data= i}, Integer {data= j} -> bool (Z.geq i j) + | 
_ -> App {op= App {op= Ge; arg= x}; arg= y} + +let simp_lt x y = + match (x, y) with + (* i < j ==> i<j *) + | Integer {data= i}, Integer {data= j} -> bool (Z.lt i j) + | _ -> App {op= App {op= Lt; arg= x}; arg= y} + +let simp_le x y = + match (x, y) with + (* i <= j ==> i<=j *) + | Integer {data= i}, Integer {data= j} -> bool (Z.leq i j) + | _ -> App {op= App {op= Le; arg= x}; arg= y} + +let rec simp_add x y = + match (x, y) with + (* i + j ==> i+j *) + | Integer {data= i}, Integer {data= j} -> integer (Z.add i j) + (* i + e ==> e + i *) + | Integer _, _ -> simp_add y x + (* e + 0 ==> e *) + | _, Integer {data} when Z.equal Z.zero data -> x + (* (e+i) + j ==> e+(i+j) *) + | App {op= App {op= Add; arg}; arg= Integer {data= i}}, Integer {data= j} + -> + simp_add arg (integer (Z.add i j)) + (* (i-e) + j ==> (i+j)-e *) + | App {op= App {op= Sub; arg= Integer {data= i}}; arg}, Integer {data= j} + -> + simp_sub (integer (Z.add i j)) arg + | _ -> App {op= App {op= Add; arg= x}; arg= y} + +and simp_sub x y = + match (x, y) with + (* i - j ==> i-j *) + | Integer {data= i}, Integer {data= j} -> integer (Z.sub i j) + (* e - i ==> e + (-i) *) + | _, Integer {data} -> simp_add x (integer (Z.neg data)) + (* e - e ==> 0 *) + | _ when equal x y -> integer Z.zero + | _ -> App {op= App {op= Sub; arg= x}; arg= y} + +let simp_mul x y = + match (x, y) with + (* i * j ==> i*j *) + | Integer {data= i}, Integer {data= j} -> integer (Z.mul i j) + (* e * 1 ==> e *) + | (Integer {data}, e | e, Integer {data}) when Z.equal Z.one data -> e + | _ -> App {op= App {op= Mul; arg= x}; arg= y} + +let simp_div x y = + match (x, y) with + (* i / j ==> i/j *) + | Integer {data= i}, Integer {data= j} -> integer (Z.div i j) + | _ -> App {op= App {op= Div; arg= x}; arg= y} + +let simp_rem x y = + match (x, y) with + (* i % j ==> i%j *) + | Integer {data= i}, Integer {data= j} -> integer (Z.( mod ) i j) + | _ -> App {op= App {op= Rem; arg= x}; arg= y} + +let simp_and x y = + match (x, y) with + (* i && j ==> i logand j *) + | Integer {data= i}, Integer 
{data= j} -> integer (Z.logand i j) + (* e && 1 ==> e *) + | (Integer {data}, e | e, Integer {data}) when Z.equal Z.one data -> e + (* e && 0 ==> 0 *) + | ((Integer {data} as z), _ | _, (Integer {data} as z)) + when Z.equal Z.zero data -> + z + | _ -> App {op= App {op= And; arg= x}; arg= y} + +let simp_or x y = + match (x, y) with + (* i || j ==> i logor j *) + | Integer {data= i}, Integer {data= j} -> integer (Z.logor i j) + (* e || 1 ==> 1 *) + | (Integer {data}, _ | _, Integer {data}) when Z.equal Z.one data -> + integer Z.one + (* e || 0 ==> e *) + | (Integer {data}, e | e, Integer {data}) when Z.equal Z.zero data -> e + | _ -> App {op= App {op= Or; arg= x}; arg= y} + +let simp_xor x y = + match (x, y) with + (* i xor j ==> i logxor j *) + | Integer {data= i}, Integer {data= j} -> integer (Z.logxor i j) + (* ¬(x=y) ==> x!=y *) + | App {op= App {op= Eq; arg= x}; arg= y}, Integer {data} + |Integer {data}, App {op= App {op= Eq; arg= x}; arg= y} + when Z.equal Z.minus_one data -> + simp_dq x y + (* ¬(x!=y) ==> x=y *) + | App {op= App {op= Dq; arg= x}; arg= y}, Integer {data} + |Integer {data}, App {op= App {op= Dq; arg= x}; arg= y} + when Z.equal Z.minus_one data -> + simp_eq x y + | _ -> App {op= App {op= Xor; arg= x}; arg= y} + +let app1 ?(partial = false) op arg = + ( match (op, arg) with + | Convert {signed; dst; src}, x -> simp_convert signed dst src x + | App {op= Eq; arg= x}, y -> simp_eq x y + | App {op= Dq; arg= x}, y -> simp_dq x y + | App {op= Gt; arg= x}, y -> simp_gt x y + | App {op= Ge; arg= x}, y -> simp_ge x y + | App {op= Lt; arg= x}, y -> simp_lt x y + | App {op= Le; arg= x}, y -> simp_le x y + | App {op= Add; arg= x}, y -> simp_add x y + | App {op= Sub; arg= x}, y -> simp_sub x y + | App {op= Mul; arg= x}, y -> simp_mul x y + | App {op= Div; arg= x}, y -> simp_div x y + | App {op= Rem; arg= x}, y -> simp_rem x y + | App {op= And; arg= x}, y -> simp_and x y + | App {op= Or; arg= x}, y -> simp_or x y + | App {op= Xor; arg= x}, y -> simp_xor x y + | 
_ -> App {op; arg} ) + |> check (invariant ~partial) + +let app2 op x y = app1 (app1 ~partial:true op x) y +let app3 op x y z = app1 (app1 ~partial:true (app1 ~partial:true op x) y) z +let appN op xs = List.fold xs ~init:op ~f:app1 +let splat ~byt ~siz = app2 Splat byt siz +let memory ~siz ~arr = app2 Memory siz arr +let concat = app2 Concat +let eq = app2 Eq +let dq = app2 Dq +let gt = app2 Gt +let ge = app2 Ge +let lt = app2 Lt +let le = app2 Le +let ugt = app2 Ugt +let uge = app2 Uge +let ult = app2 Ult +let ule = app2 Ule +let ord = app2 Ord +let uno = app2 Uno +let add = app2 Add +let sub = app2 Sub +let mul = app2 Mul +let div = app2 Div +let udiv = app2 Udiv +let rem = app2 Rem +let urem = app2 Urem +let and_ = app2 And +let or_ = app2 Or +let xor = app2 Xor +let shl = app2 Shl +let lshr = app2 Lshr +let ashr = app2 Ashr +let conditional ~cnd ~thn ~els = app3 Conditional cnd thn els +let record elts = appN Record elts |> check invariant +let select ~rcd ~idx = app2 Select rcd idx +let update ~rcd ~elt ~idx = app3 Update rcd elt idx + +let struct_rec key = + let memo_id = Hashtbl.create key in + let dummy = Null in + Staged.stage + @@ fun ~id elt_thks -> + match Hashtbl.find memo_id id with + | None -> + (* Add placeholder to prevent computing [elts] in calls to + [struct_rec] from [elt_thks] for recursive occurrences of [id]. *) + let elta = Array.create ~len:(Vector.length elt_thks) dummy in + let elts = Vector.of_array elta in + Hashtbl.set memo_id ~key:id ~data:elts ; + Vector.iteri elt_thks ~f:(fun i (lazy elt) -> elta.(i) <- elt) ; + Struct_rec {elts} |> check invariant + | Some elts -> + (* Do not check invariant as invariant will be checked above after the + thunks are forced, before which invariant-checking may spuriously + fail. Note that it is important that the value constructed here + shares the array in the memo table, so that the update after + forcing the recursive thunks also updates this value. 
*) + Struct_rec {elts} + +let convert ?(signed = false) ~dst ~src exp = + app1 (Convert {signed; dst; src}) exp + +(** Access *) + +let fold e ~init:z ~f = + match e with + | App {op; arg; _} -> + let z = f z op in + let z = f z arg in + z + | _ -> z + +let fold_map e ~init:z ~f = + match e with + | App {op; arg} -> + let z, op' = f z op in + let z, arg' = f z arg in + if op' == op && arg' == arg then (z, e) + else (z, app1 ~partial:true op' arg') + | _ -> (z, e) + +let map e ~f = + match e with + | App {op; arg} -> + let op' = f op in + let arg' = f arg in + if op' == op && arg' == arg then e else app1 ~partial:true op' arg' + | _ -> e + +(** Update *) + +let rename e sub = + let rec rename_ e sub = + match e with + | Var _ -> Var.Subst.apply sub e + | _ -> map e ~f:(fun f -> rename_ f sub) + in + rename_ e sub |> check (invariant ~partial:true) -let mkUno = binop Uno +(** Query *) -(** Queries *) +let is_true = function Integer {data} -> Z.equal Z.one data | _ -> false +let is_false = function Integer {data} -> Z.equal Z.zero data | _ -> false -let typ = typ_of +let rec is_constant = function + | Var _ | Nondet _ -> false + | App {op; arg} -> is_constant arg && is_constant op + | _ -> true diff --git a/sledge/src/llair/exp.mli b/sledge/src/llair/exp.mli index 7520d7793..390199392 100644 --- a/sledge/src/llair/exp.mli +++ b/sledge/src/llair/exp.mli @@ -10,48 +10,34 @@ Pure (heap-independent) expressions are complex arithmetic, bitwise-logical, etc. operations over literal values and registers. - Expressions are represented in curried form, where the only recursive + Expressions are represented in curried form, where the only† recursive constructor is [App], which is an application of a function symbol to an argument. This is done to simplify the definition of 'subexpression' and make it explicit, which is a significant help for treating equality between expressions using congruence closure. 
The specific constructor - functions indicate and check the expected arity and types of the - function symbols. *) + functions indicate and check the expected arity of the function symbols. + + [†] [Struct_rec] is also a recursive constructor, but its values are + treated as atomic since, as they are recursive, doing otherwise would + require inductive reasoning. *) type t = private - | Var of {name: string; typ: Typ.t; loc: Loc.t} - (** Local variable / virtual register *) - | Global of {name: string; init: t option; typ: Typ.t; loc: Loc.t} - (** Global variable, with initalizer *) - | Nondet of {typ: Typ.t; loc: Loc.t; msg: string} + | Var of {id: int; name: string} (** Local variable / virtual register *) + | Nondet of {msg: string} (** Anonymous local variable with arbitrary value, representing - non-deterministic approximation of value described by [msg]. *) - | Label of {parent: string; name: string; loc: Loc.t} + non-deterministic approximation of value described by [msg] *) + | Label of {parent: string; name: string} (** Address of named code block within parent function *) - | Null of {typ: Typ.t} - (** Pointer value that never refers to an object *) - | App of {op: t; arg: t; loc: Loc.t} + | App of {op: t; arg: t} (** Application of function symbol to argument, curried *) - | AppN of {op: t; args: t vector; loc: Loc.t} - (** Application of function symbol to arguments. NOTE: may be cyclic - when [op] is [Struct]. 
*) - | PtrFld of {fld: int} (** Pointer to a field of a struct *) - | PtrIdx (** Pointer to an index of an array *) - | PrjFld of {fld: int} (** Project a field from a constant struct *) - | PrjIdx (** Project an index from a constant array *) - | UpdFld of {fld: int} (** Constant struct with updated field *) - | UpdIdx (** Constant array with updated index *) - | Integer of {data: Z.t; typ: Typ.t} (** Integer constant *) - | Float of {data: string; typ: Typ.t} (** Floating-point constant *) - | Array of {typ: Typ.t} (** Array constant *) - | Struct of {typ: Typ.t} (** Struct constant *) - | Cast of {typ: Typ.t} (** Cast to specified type, invertible *) - | Conv of {signed: bool; typ: Typ.t} - (** Convert to specified type, possibly with loss of information *) - | Select (** Conditional *) - (* binary: comparison *) + | Null (** Pointer value that never refers to an object *) + | Splat (** Iterated concatenation of a single byte *) + | Memory (** Size-tagged byte-array *) + | Concat (** Byte-array concatenation *) + | Integer of {data: Z.t} (** Integer constant *) + | Float of {data: string} (** Floating-point constant *) | Eq (** Equal test *) - | Ne (** Not-equal test *) + | Dq (** Disequal test *) | Gt (** Greater-than test *) | Ge (** Greater-than-or-equal test *) | Lt (** Less-than test *) @@ -62,194 +48,151 @@ type t = private | Ule (** Unordered or less-than-or-equal test *) | Ord (** Ordered test (neither arg is nan) *) | Uno (** Unordered test (some arg is nan) *) - (* binary: boolean / bitwise *) - | And (** Conjunction *) - | Or (** Disjunction *) - | Xor (** Exclusive-or / Boolean disequality *) - | Shl (** Shift left *) - | LShr (** Logical shift right *) - | AShr (** Arithmetic shift right *) - (* binary: arithmetic *) | Add (** Addition *) | Sub (** Subtraction *) | Mul (** Multiplication *) | Div (** Division *) - | UDiv (** Unsigned division *) + | Udiv (** Unsigned division *) | Rem (** Remainder of division *) - | URem (** Remainder of unsigned 
division *) - -val compare : t -> t -> int + | Urem (** Remainder of unsigned division *) + | And (** Conjunction *) + | Or (** Disjunction *) + | Xor (** Exclusive-or / Boolean disequality *) + | Shl (** Shift left *) + | Lshr (** Logical shift right *) + | Ashr (** Arithmetic shift right *) + | Conditional (** If-then-else *) + | Record (** Record (array / struct) constant *) + | Select (** Select an index from a record *) + | Update (** Constant record with updated index *) + | Struct_rec of {elts: t vector} + (** Struct constant that may recursively refer to itself + (transitively) from [elts]. NOTE: represented by cyclic values. *) + | Convert of {signed: bool; dst: Typ.t; src: Typ.t} + (** Convert between specified types, possibly with loss of information *) +[@@deriving compare, hash, sexp] + +type exp = t + +include Comparator.S with type t := t val equal : t -> t -> bool +val pp : t pp +val invariant : ?partial:bool -> t -> unit -val t_of_sexp : Sexp.t -> t - -val sexp_of_t : t -> Sexp.t - -val fmt : t fmt - -(** Re-exported modules *) - +(** Exp.Var is re-exported as Var *) module Var : sig - type nonrec t = private t + type t = private exp [@@deriving compare, hash, sexp] + type var = t include Comparator.S with type t := t - val compare : t -> t -> int - - val equal : t -> t -> bool - - val t_of_sexp : Sexp.t -> t - - val sexp_of_t : t -> Sexp.t - - val fmt : t fmt - - val mk : ?loc:Loc.t -> string -> Typ.t -> t - - val name : t -> string - - val typ : t -> Typ.t - - val loc : t -> Loc.t -end - -module Global : sig - type init = t + module Set : sig + type t = (var, comparator_witness) Set.t [@@deriving compare, sexp] - type nonrec t = private t - - include Comparator.S with type t := t - - val compare : t -> t -> int + val pp : t pp + val empty : t + val of_vector : var vector -> t + end val equal : t -> t -> bool + val pp : t pp - val t_of_sexp : Sexp.t -> t - - val sexp_of_t : t -> Sexp.t - - val hash : t -> int - - val fmt : t fmt - - val fmt_defn : 
t fmt - - val mk : ?init:init -> ?loc:Loc.t -> string -> Typ.t -> t - - val of_exp : init -> t option + include Invariant.S with type t := t + val of_exp : exp -> t option + val program : string -> t + val fresh : string -> wrt:Set.t -> t * Set.t + val id : t -> int val name : t -> string - val typ : t -> Typ.t - - val loc : t -> Loc.t + module Subst : sig + type t [@@deriving compare, sexp] + + val pp : t pp + val empty : t + val freshen : Set.t -> wrt:Set.t -> t + val extend : t -> replace:var -> with_:var -> t + val invert : t -> t + val exclude : t -> Set.t -> t + val is_empty : t -> bool + val domain : t -> Set.t + val range : t -> Set.t + val apply_set : t -> Set.t -> Set.t + val close_set : t -> Set.t -> Set.t + end end -(** Constructors *) - -val mkVar : Var.t -> t - -val mkGlobal : Global.t -> t - -val mkNondet : Typ.t -> string -> t - -val mkLabel : parent:string -> name:string -> t - -val mkNull : Typ.t -> t - -val mkPtrFld : ptr:t -> fld:int -> t - -val mkPtrIdx : ptr:t -> idx:t -> t - -val mkPrjFld : agg:t -> fld:int -> t - -val mkPrjIdx : arr:t -> idx:t -> t - -val mkUpdFld : agg:t -> elt:t -> fld:int -> t - -val mkUpdIdx : arr:t -> elt:t -> idx:t -> t - -val mkBool : bool -> t - -val mkInteger : Z.t -> Typ.t -> t - -val mkFloat : string -> Typ.t -> t - -val mkArray : t vector -> Typ.t -> t - -val mkStruct : t vector -> Typ.t -> t - -val mkStruct_rec : - (module Hashtbl.Key_plain with type t = 'id) - -> (id:'id -> t lazy_t vector -> Typ.t -> t) Staged.t -(** [mkStruct_rec Id id element_thunks typ] constructs a possibly-cyclic - [Struct] value. Cycles are detected using [Id]. The caller of - [mkStruct_rec Id] must ensure that a single unstaging of [mkStruct_rec - Id] is used for each complete cyclic value. Also, the caller must ensure - that recursive calls to [mkStruct_rec Id] provide [id] values that - uniquely identify at least one point on each cycle. Failure to obey - these requirements will lead to stack overflow. 
*) - -val mkCast : t -> Typ.t -> t - -val mkConv : t -> ?signed:bool -> Typ.t -> t - -val mkSelect : cnd:t -> thn:t -> els:t -> t - -val mkEq : t -> t -> t - -val mkNe : t -> t -> t - -val mkGt : t -> t -> t - -val mkGe : t -> t -> t - -val mkLt : t -> t -> t - -val mkLe : t -> t -> t - -val mkUgt : t -> t -> t - -val mkUge : t -> t -> t - -val mkUlt : t -> t -> t - -val mkUle : t -> t -> t - -val mkOrd : t -> t -> t - -val mkUno : t -> t -> t - -val mkAnd : t -> t -> t - -val mkOr : t -> t -> t - -val mkXor : t -> t -> t - -val mkShl : t -> t -> t - -val mkLShr : t -> t -> t - -val mkAShr : t -> t -> t - -val mkAdd : t -> t -> t - -val mkSub : t -> t -> t - -val mkMul : t -> t -> t - -val mkDiv : t -> t -> t - -val mkUDiv : t -> t -> t - -val mkRem : t -> t -> t - -val mkURem : t -> t -> t - -val locate : Loc.t -> t -> t -(** Update the debug location *) - -(** Queries *) - -val typ : t -> Typ.t +(** Construct *) + +val var : Var.t -> t +val nondet : string -> t +val label : parent:string -> name:string -> t +val null : t +val splat : byt:t -> siz:t -> t +val memory : siz:t -> arr:t -> t +val concat : t -> t -> t +val bool : bool -> t +val integer : Z.t -> t +val float : string -> t +val eq : t -> t -> t +val dq : t -> t -> t +val gt : t -> t -> t +val ge : t -> t -> t +val lt : t -> t -> t +val le : t -> t -> t +val ugt : t -> t -> t +val uge : t -> t -> t +val ult : t -> t -> t +val ule : t -> t -> t +val ord : t -> t -> t +val uno : t -> t -> t +val add : t -> t -> t +val sub : t -> t -> t +val mul : t -> t -> t +val div : t -> t -> t +val udiv : t -> t -> t +val rem : t -> t -> t +val urem : t -> t -> t +val and_ : t -> t -> t +val or_ : t -> t -> t +val xor : t -> t -> t +val shl : t -> t -> t +val lshr : t -> t -> t +val ashr : t -> t -> t +val conditional : cnd:t -> thn:t -> els:t -> t +val record : t list -> t +val select : rcd:t -> idx:t -> t +val update : rcd:t -> elt:t -> idx:t -> t + +val struct_rec : + (module Hashtbl.Key with type t = 'id) + -> 
(id:'id -> t lazy_t vector -> t) Staged.t +(** [struct_rec Id id element_thunks] constructs a possibly-cyclic [Struct_rec] + value. Cycles are detected using [Id]. The caller of [struct_rec Id] + must ensure that a single unstaging of [struct_rec Id] is used for each + complete cyclic value. Also, the caller must ensure that recursive calls + to [struct_rec Id] provide [id] values that uniquely identify at least + one point on each cycle. Failure to obey these requirements will lead to + stack overflow. *) + +val convert : ?signed:bool -> dst:Typ.t -> src:Typ.t -> t -> t + +(** Access *) + +val fold_vars : t -> init:'a -> f:('a -> Var.t -> 'a) -> 'a +val fold_exps : t -> init:'a -> f:('a -> t -> 'a) -> 'a +val fold : t -> init:'a -> f:('a -> t -> 'a) -> 'a +val fold_map : t -> init:'a -> f:('a -> t -> 'a * t) -> 'a * t +val map : t -> f:(t -> t) -> t + +(** Update *) + +val rename : t -> Var.Subst.t -> t + +(** Query *) + +val fv : t -> Var.Set.t +val is_true : t -> bool +val is_false : t -> bool +val is_constant : t -> bool