From 21f3287e42469a36cdfd37da9c9b1f36ad6c173d Mon Sep 17 00:00:00 2001 From: Josh Berdine Date: Tue, 27 Oct 2020 12:16:57 -0700 Subject: [PATCH] [sledge] Do not expose the internal Trm interface Reviewed By: ngorogiannis Differential Revision: D24532364 fbshipit-source-id: 95d9bbe4e --- sledge/src/fml.ml | 4 +-- sledge/src/fol.ml | 68 ++++++++++++++++++++++++++-------------------- sledge/src/trm.ml | 36 ++++++++++++++++++++++++ sledge/src/trm.mli | 58 ++++++++++++++++++++++++++++----------- 4 files changed, 118 insertions(+), 48 deletions(-) diff --git a/sledge/src/fml.ml b/sledge/src/fml.ml index 658c741b1..55ff908d8 100644 --- a/sledge/src/fml.ml +++ b/sledge/src/fml.ml @@ -110,13 +110,13 @@ let _Eq x y = let pos = length_common_prefix in let a = Array.sub ~pos ~len:(m - length_common) a in let b = Array.sub ~pos ~len:(n - length_common) b in - _Eq (Trm._Concat a) (Trm._Concat b) + _Eq (Trm.concat a) (Trm.concat b) | (Sized _ | Extract _ | Concat _), (Sized _ | Extract _ | Concat _) -> sort_eq x y (* x = α ==> ⟨x,|α|⟩ = α *) | x, ((Sized _ | Extract _ | Concat _) as a) |((Sized _ | Extract _ | Concat _) as a), x -> - _Eq (Trm._Sized x (Trm.seq_size_exn a)) a + _Eq (Trm.sized ~siz:(Trm.seq_size_exn a) ~seq:x) a | _ -> sort_eq x y let map_pos_neg f e cons ~pos ~neg = diff --git a/sledge/src/fol.ml b/sledge/src/fol.ml index 134db085c..30408d31b 100644 --- a/sledge/src/fol.ml +++ b/sledge/src/fol.ml @@ -5,7 +5,6 @@ * LICENSE file in the root directory of this source tree. *) -open Trm open Fml type var = Var.t @@ -70,7 +69,7 @@ let rec map_cnd : (fml -> 'a -> 'a -> 'a) -> (trm -> 'a) -> cnd -> 'a = let embed_into_cnd : exp -> cnd = function | #cnd as c -> c (* p ==> (p ? 1 : 0) *) - | `Fml fml -> `Ite (fml, `Trm one, `Trm zero) + | `Fml fml -> `Ite (fml, `Trm Trm.one, `Trm Trm.zero) (** Project out a formula that is embedded into a conditional term. @@ -78,7 +77,8 @@ let embed_into_cnd : exp -> cnd = function that [project_out_fml (embed_into_cnd (`Fml f)) = Some f]. *) let project_out_fml : cnd -> fml option = function (* (p ? 1 : 0) ==> p *) - | `Ite (cnd, `Trm one', `Trm zero') when one == one' && zero == zero' -> + | `Ite (cnd, `Trm one', `Trm zero') + when Trm.one == one' && Trm.zero == zero' -> Some cnd | _ -> None @@ -215,35 +215,41 @@ module Term = struct (* arithmetic *) - let zero = `Trm zero - let one = `Trm one - let integer z = `Trm (_Z z) - let rational q = `Trm (_Q q) - let neg = ap1t @@ fun x -> _Arith Arith.(neg (trm x)) + let zero = `Trm Trm.zero + let one = `Trm Trm.one + let integer z = `Trm (Trm.integer z) + let rational q = `Trm (Trm.rational q) + let neg = ap1t Trm.neg let add = ap2t Trm.add let sub = ap2t Trm.sub - let mulq q = ap1t @@ fun x -> _Arith Arith.(mulc q (trm x)) - let mul = ap2t @@ fun x y -> _Arith (Arith.mul x y) - let div = ap2t @@ fun x y -> _Arith (Arith.div x y) - let pow x i = (ap1t @@ fun x -> _Arith (Arith.pow x i)) x + let mulq q = ap1t (Trm.mulq q) + let mul = ap2t Trm.mul + let div = ap2t Trm.div + let pow x i = (ap1t @@ fun x -> Trm.pow x i) x (* sequences *) - let splat = ap1t _Splat - let sized ~seq ~siz = ap2t _Sized seq siz - let extract ~seq ~off ~len = ap3t _Extract seq off len - let concat elts = apNt _Concat elts + let splat = ap1t Trm.splat + let sized ~seq ~siz = ap2t (fun seq siz -> Trm.sized ~seq ~siz) seq siz + + let extract ~seq ~off ~len = + ap3t (fun seq off len -> Trm.extract ~seq ~off ~len) seq off len + + let concat elts = apNt Trm.concat elts (* records *) - let select ~rcd ~idx = ap1t (_Select idx) rcd - let update ~rcd ~idx ~elt = ap2t (_Update idx) rcd elt - let record elts = apNt _Record elts - let ancestor i = `Trm (_Ancestor i) + let select ~rcd ~idx = ap1t (fun rcd -> Trm.select ~rcd ~idx) rcd + + let update ~rcd ~idx ~elt = + ap2t (fun rcd elt -> Trm.update ~rcd ~idx ~elt) rcd elt + + let record elts = apNt Trm.record elts + let ancestor i = `Trm (Trm.ancestor i) (* uninterpreted *) - let apply sym args = apNt (_Apply sym) args + let apply sym args = apNt (Trm.apply sym) args (* if-then-else *) @@ -251,21 +257,23 @@ module Term = struct (** Destruct *) - let d_int = function `Trm (Z z) -> Some z | _ -> None + let d_int e = match (e : t) with `Trm (Z z) -> Some z | _ -> None - let get_const = function + let get_const e = + match (e : t) with | `Trm (Z z) -> Some (Q.of_z z) | `Trm (Q q) -> Some q | _ -> None (** Access *) - let split_const = function + let split_const e = + match (e : t) with | `Trm (Z z) -> (zero, Q.of_z z) | `Trm (Q q) -> (zero, q) | `Trm (Arith a) -> - let a_c, c = Arith.split_const a in - (`Trm (_Arith a_c), c) + let a_c, c = Trm.Arith.split_const a in + (`Trm (Trm.arith a_c), c) | e -> (e, Q.zero) (** Traverse *) @@ -464,10 +472,10 @@ let vs_to_ses : Var.Set.t -> Ses.Var.Set.t = fun vs -> Ses.Var.Set.of_iter (Iter.map ~f:v_to_ses (Var.Set.to_iter vs)) let rec arith_to_ses poly = - Arith.fold_monomials poly Ses.Term.zero ~f:(fun mono coeff e -> + Trm.Arith.fold_monomials poly Ses.Term.zero ~f:(fun mono coeff e -> Ses.Term.add e (Ses.Term.mulq coeff - (Arith.fold_factors mono Ses.Term.one ~f:(fun trm pow f -> + (Trm.Arith.fold_factors mono Ses.Term.one ~f:(fun trm pow f -> let rec exp b i = assert (i > 0) ; if i = 1 then b else Ses.Term.mul b (exp f (i - 1)) @@ -544,8 +552,8 @@ let v_of_ses : Ses.Var.t -> var = let vs_of_ses : Ses.Var.Set.t -> Var.Set.t = fun vs -> Var.Set.of_iter (Iter.map ~f:v_of_ses (Ses.Var.Set.to_iter vs)) -let uap1 f = ap1t (fun x -> _Apply f [|x|]) -let uap2 f = ap2t (fun x y -> _Apply f [|x; y|]) +let uap1 f = ap1t (fun x -> Trm.apply f [|x|]) +let uap2 f = ap2t (fun x y -> Trm.apply f [|x; y|]) let litN p = apNf (_Lit p) let rec uap_tt f a = uap1 f (of_ses a) diff --git a/sledge/src/trm.ml b/sledge/src/trm.ml index 836457e2a..eaea18654 100644 --- a/sledge/src/trm.ml +++ b/sledge/src/trm.ml @@ -367,8 +367,44 @@ type arith = Arith.t include Trm +(** Construct *) + +(* variables *) + +let var v = (v : Var.t :> t) + +(* arithmetic *) + let zero = _Z Z.zero let one = _Z Z.one +let integer z = _Z z +let rational q = _Q q +let neg x = _Arith Arith.(neg (trm x)) +let add = Trm.add +let sub = Trm.sub +let mulq q x = _Arith Arith.(mulc q (trm x)) +let mul x y = _Arith (Arith.mul x y) +let div x y = _Arith (Arith.div x y) +let pow x i = _Arith (Arith.pow x i) +let arith = _Arith + +(* sequences *) + +let splat = _Splat +let sized ~seq ~siz = _Sized seq siz +let extract ~seq ~off ~len = _Extract seq off len +let concat elts = _Concat elts + +(* records *) + +let select ~rcd ~idx = _Select idx rcd +let update ~rcd ~idx ~elt = _Update idx rcd elt +let record elts = _Record elts +let ancestor i = _Ancestor i + +(* uninterpreted *) + +let apply sym args = _Apply sym args (** Transform *) diff --git a/sledge/src/trm.mli b/sledge/src/trm.mli index d6f29b74a..7410b5e43 100644 --- a/sledge/src/trm.mli +++ b/sledge/src/trm.mli @@ -42,24 +42,50 @@ module Arith : Arithmetic.S with type var := Var.t with type trm := t with type t = arith val ppx : Var.t Var.strength -> t pp -val _Var : int -> string -> t -val _Z : Z.t -> t -val _Q : Q.t -> t -val _Arith : Arith.t -> t -val _Splat : t -> t -val _Sized : t -> t -> t -val _Extract : t -> t -> t -> t -val _Concat : t array -> t -val _Select : int -> t -> t -val _Update : int -> t -> t -> t -val _Record : t array -> t -val _Ancestor : int -> t -val _Apply : Ses.Funsym.t -> t array -> t + +(** Construct *) + +(* variables *) +val var : Var.t -> t + +(* arithmetic *) +val zero : t +val one : t +val integer : Z.t -> t +val rational : Q.t -> t +val neg : t -> t val add : t -> t -> t val sub : t -> t -> t +val mulq : Q.t -> t -> t +val mul : t -> t -> t +val div : t -> t -> t +val pow : t -> int -> t +val arith : Arith.t -> t + +(* sequences (of flexible size) *) +val splat : t -> t +val sized : seq:t -> siz:t -> t +val extract : seq:t -> off:t -> len:t -> t +val concat : t array -> t + +(* records (with fixed indices) *) +val select : rcd:t -> idx:int -> t +val update : rcd:t -> idx:int -> elt:t -> t +val record : t array -> t +val ancestor : int -> t + +(* uninterpreted *) +val apply : Ses.Funsym.t -> t array -> t + +(** Transform *) + +val map_vars : t -> f:(Var.t -> Var.t) -> t + +(** Query *) + val seq_size_exn : t -> t val seq_size : t -> t option + +(** Traverse *) + val vars : t -> Var.t iter -val zero : t -val one : t -val map_vars : t -> f:(Var.t -> Var.t) -> t