diff --git a/sledge/src/fol.ml b/sledge/src/fol.ml index 168a230f6..f1c322f77 100644 --- a/sledge/src/fol.ml +++ b/sledge/src/fol.ml @@ -103,7 +103,9 @@ and Trm : sig val _Record : trm array -> trm val _Ancestor : int -> trm val _Apply : Funsym.t -> trm array -> trm + val add : trm -> trm -> trm val sub : trm -> trm -> trm + val seq_size_exn : trm -> trm end = struct type var = Var.t @@ -186,6 +188,8 @@ end = struct elts ) elts] + let pp = ppx (fun _ -> None) + let invariant e = let@ () = Invariant.invariant [%here] e [%sexp_of: trm] in match e with @@ -224,11 +228,132 @@ end = struct | Compound -> Arith a ) |> check invariant + let add x y = _Arith Arith.(add (trm x) (trm y)) let sub x y = _Arith Arith.(sub (trm x) (trm y)) let _Splat x = Splat x |> check invariant - let _Sized seq siz = Sized {seq; siz} |> check invariant - let _Extract seq off len = Extract {seq; off; len} |> check invariant - let _Concat es = Concat es |> check invariant + + let seq_size_exn = + let invalid = Invalid_argument "seq_size_exn" in + let rec seq_size_exn = function + | Sized {siz= n} | Extract {len= n} -> n + | Concat a0U -> + Array.fold ~f:(fun aJ a0I -> add a0I (seq_size_exn aJ)) a0U zero + | _ -> raise invalid + in + seq_size_exn + + let seq_size e = + try Some (seq_size_exn e) with Invalid_argument _ -> None + + let _Sized seq siz = + ( match seq_size seq with + (* ⟨n,α⟩ ==> α when n ≡ |α| *) + | Some n when equal_trm siz n -> seq + | _ -> Sized {seq; siz} ) + |> check invariant + + let partial_compare x y = + match sub x y with + | Z z -> Some (Int.sign (Z.sign z)) + | Q q -> Some (Int.sign (Q.sign q)) + | _ -> None + + let partial_ge x y = + match partial_compare x y with Some (Pos | Zero) -> true | _ -> false + + let empty_seq = Concat [||] + + let rec _Extract seq off len = + [%trace] + ~call:(fun {pf} -> pf "%a" pp (Extract {seq; off; len})) + ~retn:(fun {pf} -> pf "%a" pp) + @@ fun () -> + (* _[_,0) ==> ⟨⟩ *) + ( if equal_trm len zero then empty_seq + else + let o_l = add off len in + match seq with + (* α[m,k)[o,l) ==> α[m+o,l) when k ≥ o+l *) + | Extract {seq= a; off= m; len= k} when partial_ge k o_l -> + _Extract a (add m off) len + (* ⟨n,E^⟩[o,l) ==> ⟨l,E^⟩ when n ≥ o+l *) + | Sized {siz= n; seq= Splat _ as e} when partial_ge n o_l -> + _Sized e len + (* ⟨n,a⟩[0,n) ==> ⟨n,a⟩ *) + | Sized {siz= n} when equal_trm off zero && equal_trm n len -> seq + (* For (α₀^α₁)[o,l) there are 3 cases: + * + * ⟨...⟩^⟨...⟩ + * [,) + * o < o+l ≤ |α₀| : (α₀^α₁)[o,l) ==> α₀[o,l) ^ α₁[0,0) + * + * ⟨...⟩^⟨...⟩ + * [ , ) + * o ≤ |α₀| < o+l : (α₀^α₁)[o,l) ==> α₀[o,|α₀|-o) ^ α₁[0,l-(|α₀|-o)) + * + * ⟨...⟩^⟨...⟩ + * [,) + * |α₀| ≤ o : (α₀^α₁)[o,l) ==> α₀[o,0) ^ α₁[o-|α₀|,l) + * + * So in general: + * + * (α₀^α₁)[o,l) ==> α₀[o,l₀) ^ α₁[o₁,l-l₀) + * where l₀ = max 0 (min l |α₀|-o) + * o₁ = max 0 o-|α₀| + *) + | Concat na1N -> ( + match len with + | Z l -> + Array.fold_map_until na1N (l, off) + ~f:(fun naI (l, oI) -> + if Z.equal Z.zero l then + `Continue (_Extract naI oI zero, (l, oI)) + else + let nI = seq_size_exn naI in + let oI_nI = sub oI nI in + match oI_nI with + | Z z -> + let oJ = if Z.sign z <= 0 then zero else oI_nI in + let lI = Z.(max zero (min l (neg z))) in + let l = Z.(l - lI) in + `Continue (_Extract naI oI (_Z lI), (l, oJ)) + | _ -> `Stop (Extract {seq; off; len}) ) + ~finish:(fun (e1N, _) -> _Concat e1N) + | _ -> Extract {seq; off; len} ) + (* α[o,l) *) + | _ -> Extract {seq; off; len} ) + |> check invariant + + and _Concat xs = + [%trace] + ~call:(fun {pf} -> pf "%a" pp (Concat xs)) + ~retn:(fun {pf} -> pf "%a" pp) + @@ fun () -> + (* (α^(β^γ)^δ) ==> (α^β^γ^δ) *) + let flatten xs = + if Array.exists ~f:(function Concat _ -> true | _ -> false) xs then + Array.flat_map ~f:(function Concat s -> s | e -> [|e|]) xs + else xs + in + let simp_adjacent e f = + match (e, f) with + (* ⟨n,a⟩[o,k)^⟨n,a⟩[o+k,l) ==> ⟨n,a⟩[o,k+l) when n ≥ o+k+l *) + | ( Extract {seq= Sized {siz= n} as na; off= o; len= k} + , Extract {seq= na'; off= o_k; len= l} ) + when equal_trm na na' + && equal_trm o_k (add o k) + && partial_ge n (add o_k l) -> + Some (_Extract na o (add k l)) + (* ⟨m,E^⟩^⟨n,E^⟩ ==> ⟨m+n,E^⟩ *) + | Sized {siz= m; seq= Splat _ as a}, Sized {siz= n; seq= a'} + when equal_trm a a' -> + Some (_Sized a (add m n)) + | _ -> None + in + let xs = flatten xs in + let xs = Array.reduce_adjacent ~f:simp_adjacent xs in + (if Array.length xs = 1 then xs.(0) else Concat xs) |> check invariant + let _Select idx rcd = Select {idx; rcd} |> check invariant let _Update idx rcd elt = Update {idx; rcd; elt} |> check invariant let _Record es = Record es |> check invariant @@ -275,15 +400,54 @@ module Fml = struct if x == zero then _Eq0 y else if y == zero then _Eq0 x else + let sort_eq x y = + match Sign.of_int (compare_trm x y) with + | Neg -> _Eq x y + | Zero -> tt + | Pos -> _Eq y x + in match (x, y) with (* x = y ==> 0 = x - y when x = y is an arithmetic equality *) | (Z _ | Q _ | Arith _), _ | _, (Z _ | Q _ | Arith _) -> _Eq0 (sub x y) - | _ -> ( - match Sign.of_int (compare_trm x y) with - | Neg -> _Eq x y - | Zero -> tt - | Pos -> _Eq y x ) + (* α^β^δ = α^γ^δ ==> β = γ *) + | Concat a, Concat b -> + let m = Array.length a in + let n = Array.length b in + let l = min m n in + let length_common_prefix = + let rec find_lcp i = + if i < l && equal_trm a.(i) b.(i) then find_lcp (i + 1) else i + in + find_lcp 0 + in + if length_common_prefix = l then tt + else + let length_common_suffix = + let rec find_lcs i = + if equal_trm a.(m - 1 - i) b.(n - 1 - i) then + find_lcs (i + 1) + else i + in + find_lcs 0 + in + let length_common = + length_common_prefix + length_common_suffix + in + if length_common = 0 then sort_eq x y + else + let pos = length_common_prefix in + let a = Array.sub ~pos ~len:(m - length_common) a in + let b = Array.sub ~pos ~len:(n - length_common) b in + _Eq (_Concat a) (_Concat b) + | (Sized _ | Extract _ | Concat _), (Sized _ | Extract _ | Concat _) + -> + sort_eq x y + (* x = α ==> ⟨x,|α|⟩ = α *) + | x, ((Sized _ | Extract _ | Concat _) as a) + |((Sized _ | Extract _ | Concat _) as a), x -> + _Eq (_Sized x (seq_size_exn a)) a + | _ -> sort_eq x y end module Fmls = Prop.Fmls @@ -662,7 +826,7 @@ module Term = struct let integer z = `Trm (_Z z) let rational q = `Trm (_Q q) let neg = ap1t @@ fun x -> _Arith Arith.(neg (trm x)) - let add = ap2t @@ fun x y -> _Arith Arith.(add (trm x) (trm y)) + let add = ap2t Trm.add let sub = ap2t Trm.sub let mulq q = ap1t @@ fun x -> _Arith Arith.(mulc q (trm x)) let mul = ap2t @@ fun x y -> _Arith (Arith.mul x y)