[sledge] Add Shostak canonizer for aggregate theory

Reviewed By: ngorogiannis

Differential Revision: D19286631

fbshipit-source-id: 0ee4b164c
master
Josh Berdine 5 years ago committed by Facebook Github Bot
parent 7bb1ec073a
commit 06fcb210c9

@ -27,6 +27,28 @@ module Infix = struct
end
let concat_map x ~f = v (Array.concat_map (a x) ~f:(fun y -> a (f y)))
let map_adjacent ~f dummy xs_v =
let xs0 = a xs_v in
let copy_xs = lazy (Array.copy xs0) in
let n = Array.length xs0 - 1 in
let rec map_adjacent_ i xs =
if i < n then
let xs =
match f xs.(i) xs.(i + 1) with
| None -> xs
| Some x ->
let xs = Lazy.force copy_xs in
xs.(i) <- dummy ;
xs.(i + 1) <- x ;
xs
in
map_adjacent_ (i + 1) xs
else if phys_equal xs xs0 then xs
else Array.filter xs ~f:(fun x -> not (phys_equal dummy x))
in
v (map_adjacent_ 0 xs0)
let create ~len x = v (Array.create ~len x)
let empty = v [||]

@ -126,6 +126,9 @@ val fold_right : 'a t -> f:('a -> 'b -> 'b) -> init:'b -> 'b
val concat_map : 'a t -> f:('a -> 'b t) -> 'b t
(* val concat_mapi : 'a t -> f:(int -> 'a -> 'b t) -> 'b t *)
val map_adjacent : f:('a -> 'a -> 'a option) -> 'a -> 'a t -> 'a t
(* val partition_tf : 'a t -> f:('a -> bool) -> 'a t * 'a t *)
(* val partitioni_tf : 'a t -> f:(int -> 'a -> bool) -> 'a t * 'a t *)
(* val cartesian_product : 'a t -> 'b t -> ('a * 'b) t *)

@ -780,29 +780,144 @@ let simp_ashr x y =
| e, Integer {data} when Z.equal Z.zero data -> e
| _ -> Ap2 (Ashr, x, y)
(* aggregate sizes *)
let rec agg_size_exn = function
| Ap2 (Memory, n, _) | Ap3 (Extract, _, _, n) -> n
| ApN (Concat, a0U) ->
Vector.fold a0U ~init:zero ~f:(fun a0I aJ ->
simp_add2 a0I (agg_size_exn aJ) )
| _ -> invalid_arg "agg_size_exn"
let agg_size e = try Some (agg_size_exn e) with Invalid_argument _ -> None
(* memory *)
let simp_concat xs =
if Vector.length xs = 1 then Vector.get xs 0
let empty_agg = ApN (Concat, Vector.of_array [||])
let simp_splat byt = Ap1 (Splat, byt)
let simp_memory siz arr =
(* ⟨n,α⟩ ==> α when n ≡ |α| *)
match agg_size arr with
| Some n when equal siz n -> arr
| _ -> Ap2 (Memory, siz, arr)
type pcmp = Lt | Eq | Gt | Unknown
let partial_compare x y : pcmp =
match simp_sub x y with
| Integer {data} -> (
match Int.sign (Z.sign data) with Neg -> Lt | Zero -> Eq | Pos -> Gt )
| _ -> Unknown
let partial_ge x y =
match partial_compare x y with Gt | Eq -> true | Lt | Unknown -> false
let rec simp_extract agg off len =
[%Trace.call fun {pf} -> pf "%a" pp (Ap3 (Extract, agg, off, len))]
;
(* _[_,0) ==> ⟨⟩ *)
( if equal len zero then empty_agg
else
let args =
if
Vector.for_all xs ~f:(function
| ApN (Concat, _) -> false
| _ -> true )
then xs
else
Vector.concat
(Vector.fold_right xs ~init:[] ~f:(fun x s ->
match x with
| ApN (Concat, args) -> args :: s
| x -> Vector.of_array [|x|] :: s ))
in
ApN (Concat, args)
let o_l = simp_add2 off len in
match agg with
(* α[m,k)[o,l) ==> α[m+o,l) when k ≥ o+l *)
| Ap3 (Extract, a, m, k) when partial_ge k o_l ->
simp_extract a (simp_add2 m off) len
(* ⟨n,E^⟩[o,l) ==> ⟨l,E^⟩ when n ≥ o+l *)
| Ap2 (Memory, n, (Ap1 (Splat, _) as e)) when partial_ge n o_l ->
simp_memory len e
(* ⟨n,a⟩[0,n) ==> ⟨n,a⟩ *)
| Ap2 (Memory, n, _) when equal off zero && equal n len -> agg
(* (α₀^…^αᵢ^…^αⱼ^…) [0+n₀+…+nᵢ₋₁, nᵢ+…+nⱼ) ==> αᵢ^…^αⱼ where nₓ ≡ |αₓ| *)
| ApN (Concat, na1N) ->
let n = Vector.length na1N in
(* invariant: oI = ∑ᵥ₌₀ⁱ⁻¹ nᵥ *)
let rec find_off oI i =
[%Trace.call fun {pf} -> pf "o_0^%i = %a" i pp oI]
;
( if i = n then Ap3 (Extract, agg, off, len)
else
match Vector.get na1N i with
| Ap2 (Memory, nI, _) | Ap3 (Extract, _, _, nI) -> (
match (oI, off) with
| Integer {data= y}, Integer {data= z} when Z.gt y z ->
Ap3 (Extract, agg, off, len)
| _ when not (equal oI off) ->
find_off (simp_add2 oI nI) (i + 1)
| _ ->
(* invariant: lIJ = ∑ᵥ₌ᵢʲ⁻¹ nᵥ *)
let rec find_len lIJ j =
[%Trace.call fun {pf} -> pf "l_%i^%i = %a" i j pp lIJ]
;
( if j = n then find_off (simp_add2 oI nI) (i + 1)
else
match Vector.get na1N j with
| Ap2 (Memory, nJ, _) | Ap3 (Extract, _, _, nJ) -> (
let lIJ = simp_add2 lIJ nJ in
match (lIJ, len) with
| Integer {data= y}, Integer {data= z}
when Z.gt y z ->
Ap3 (Extract, agg, off, len)
| _ when not (equal lIJ len) ->
find_len lIJ (j + 1)
| _ ->
let naIJ =
Vector.sub ~pos:i ~len:(j - i + 1) na1N
in
simp_concat naIJ )
| _ -> violates invariant agg )
|>
[%Trace.retn fun {pf} -> pf "%a" pp]
in
find_len zero i )
| _ -> violates invariant agg )
|>
[%Trace.retn fun {pf} -> pf "%a" pp]
in
find_off zero 0
(* α[o,l) *)
| _ -> Ap3 (Extract, agg, off, len) )
|>
[%Trace.retn fun {pf} -> pf "%a" pp]
let simp_splat byt = Ap1 (Splat, byt)
let simp_memory siz arr = Ap2 (Memory, siz, arr)
let simp_extract agg off len = Ap3 (Extract, agg, off, len)
and simp_concat xs =
[%Trace.call fun {pf} -> pf "%a" pp (ApN (Concat, xs))]
;
(* (α^(β^γ)) ==> (α^β^γ) *)
let flatten xs =
let exists_sub_Concat =
Vector.exists ~f:(function ApN (Concat, _) -> true | _ -> false)
in
let concat_sub_Concat xs =
Vector.concat
(Vector.fold_right xs ~init:[] ~f:(fun x s ->
match x with
| ApN (Concat, ys) -> ys :: s
| x -> Vector.of_array [|x|] :: s ))
in
if exists_sub_Concat xs then concat_sub_Concat xs else xs
in
let simp_adjacent e f =
match (e, f) with
(* ⟨n,a⟩[o,k)^⟨n,a⟩[o+k,l) ==> ⟨n,a⟩[o,k+l) when n ≥ o+k+l *)
| ( Ap3 (Extract, (Ap2 (Memory, n, _) as na), o, k)
, Ap3 (Extract, na', o_k, l) )
when equal na na'
&& equal o_k (simp_add2 o k)
&& partial_ge n (simp_add2 o_k l) ->
Some (simp_extract na o (simp_add2 k l))
(* ⟨m,E^⟩^⟨n,E^⟩ ==> ⟨m+n,E^⟩ *)
| Ap2 (Memory, m, (Ap1 (Splat, _) as a)), Ap2 (Memory, n, a')
when equal a a' ->
Some (simp_memory (simp_add2 m n) a)
| _ -> None
in
let xs = flatten xs in
let xs = Vector.map_adjacent empty_agg xs ~f:simp_adjacent in
(if Vector.length xs = 1 then Vector.get xs 0 else ApN (Concat, xs))
|>
[%Trace.retn fun {pf} -> pf "%a" pp]
(* records *)

Loading…
Cancel
Save