From b6ffeb179fbf0aedbaa6030544c3140aeae7e32e Mon Sep 17 00:00:00 2001 From: Josh Berdine Date: Tue, 2 Feb 2021 04:34:06 -0800 Subject: [PATCH] [sledge] Cleanup and tracing in Trm and Arith Summary: No functional change. Mostly reorder definitions. Add some tracing. Some documentation improvement. Remove a few unneeded identifiers in patterns. Reviewed By: jvillard Differential Revision: D25883712 fbshipit-source-id: 1449a7f54 --- sledge/src/fol/arithmetic.ml | 22 ++++++++++----- sledge/src/fol/arithmetic_intf.ml | 10 +++---- sledge/src/fol/trm.ml | 45 +++++++++++++++---------------- sledge/src/fol/trm.mli | 24 ++++++++++------- 4 files changed, 56 insertions(+), 45 deletions(-) diff --git a/sledge/src/fol/arithmetic.ml b/sledge/src/fol/arithmetic.ml index 461c0a296..9b03436a3 100644 --- a/sledge/src/fol/arithmetic.ml +++ b/sledge/src/fol/arithmetic.ml @@ -211,6 +211,10 @@ struct redundant representations, singleton polynomials are flattened. *) let of_trm : ?power:int -> trm -> t = fun ?(power = 1) base -> + [%trace] + ~call:(fun {pf} -> pf "@ %a^%i" Trm.pp base power) + ~retn:(fun {pf} (c, m) -> pf "%a×%a" Q.pp c Mono.pp m) + @@ fun () -> match Embed.get_arith base with | Some poly -> ( match Sum.classify poly with @@ -230,6 +234,10 @@ struct polynomials are multiplied by their coefficients directly. *) let to_poly : t -> Poly.t = fun (coeff, mono) -> + [%trace] + ~call:(fun {pf} -> pf "@ %a×%a" Q.pp coeff Mono.pp mono) + ~retn:(fun {pf} -> pf "%a" pp) + @@ fun () -> ( match Mono.get_trm mono with | Some trm -> ( match Embed.get_arith trm with @@ -283,6 +291,13 @@ struct Sum.partition_map poly ~f:(fun _ coeff -> if Q.sign coeff >= 0 then Left coeff else Right (Q.neg coeff) ) + (* traverse *) + + let monos poly = + Iter.from_iter (fun f -> Sum.iter poly ~f:(fun mono _ -> f mono)) + + let trms poly = Iter.flat_map ~f:Mono.trms (monos poly) + let map poly ~f = [%trace] ~call:(fun {pf} -> pf "@ %a" pp poly) @@ -300,13 +315,6 @@ struct Sum.union p p' ) |> check invariant - (* traverse *) - - let monos poly = - Iter.from_iter (fun f -> Sum.iter poly ~f:(fun mono _ -> f mono)) - - let trms poly = Iter.flat_map ~f:Mono.trms (monos poly) - (* query *) let vars p = Iter.flat_map ~f:Trm.vars (trms p) diff --git a/sledge/src/fol/arithmetic_intf.ml b/sledge/src/fol/arithmetic_intf.ml index fe99f475e..0c768d13a 100644 --- a/sledge/src/fol/arithmetic_intf.ml +++ b/sledge/src/fol/arithmetic_intf.ml @@ -33,7 +33,7 @@ module type S = sig val classify : t -> kind (** [classify a] is [Trm x] iff [get_trm a] is [Some x], [Const q] iff [get_const a] is [Some q], [Interpreted] if the principal operation of - [a] is interpreted, and [Uninterpreted] otherwise *) + [a] is interpreted, and [Uninterpreted] otherwise. *) (** Construct compound terms *) @@ -56,15 +56,15 @@ module type S = sig (** [partition_sign a] is [(p, n)] such that [a] = [p - n] and all coefficients in [p] and [n] are non-negative. *) - val map : t -> f:(trm -> trm) -> t - (** [map ~f a] is [a] with each maximal non-interpreted subterm - transformed by [f]. *) - (** Traverse *) val trms : t -> trm iter (** [trms a] enumerates the indeterminate terms appearing in [a] *) + val map : t -> f:(trm -> trm) -> t + (** [map ~f a] is [a] with each maximal non-interpreted subterm + transformed by [f]. *) + (** Solve *) val solve_zero_eq : ?for_:trm -> trm -> (t * t) option diff --git a/sledge/src/fol/trm.ml b/sledge/src/fol/trm.ml index b623b2f4e..ebcea49d7 100644 --- a/sledge/src/fol/trm.ml +++ b/sledge/src/fol/trm.ml @@ -443,6 +443,28 @@ let concat elts = _Concat elts let apply sym args = _Apply sym args +(** Traverse *) + +let trms = function + | Var _ | Z _ | Q _ -> Iter.empty + | Arith a -> Arith.trms a + | Splat x -> Iter.(cons x empty) + | Sized {seq; siz} -> Iter.(cons seq (cons siz empty)) + | Extract {seq; off; len} -> Iter.(cons seq (cons off (cons len empty))) + | Concat xs | Apply (_, xs) -> Iter.of_array xs + +let is_atomic = function + | Var _ | Z _ | Q _ | Concat [||] | Apply (_, [||]) -> true + | Arith _ | Splat _ | Sized _ | Extract _ | Concat _ | Apply _ -> false + +let rec atoms e = + if is_atomic e then Iter.return e else Iter.flat_map ~f:atoms (trms e) + +(** Query *) + +let fv e = Var.Set.of_iter (vars e) +let rec height e = 1 + Iter.fold ~f:(fun x -> max (height x)) (trms e) 0 + (** Transform *) let rec map_vars e ~f = @@ -470,26 +492,3 @@ let map e ~f = map3 f e (fun seq off len -> _Extract ~seq ~off ~len) seq off len | Concat xs -> mapN f e _Concat xs | Apply (g, xs) -> mapN f e (_Apply g) xs - -(** Traverse *) - -let trms = function - | Var _ | Z _ | Q _ -> Iter.empty - | Arith a -> Arith.trms a - | Splat x -> Iter.(cons x empty) - | Sized {seq= x; siz= y} -> Iter.(cons x (cons y empty)) - | Extract {seq= x; off= y; len= z} -> - Iter.(cons x (cons y (cons z empty))) - | Concat xs | Apply (_, xs) -> Iter.of_array xs - -let is_atomic = function - | Var _ | Z _ | Q _ | Concat [||] | Apply (_, [||]) -> true - | Arith _ | Splat _ | Sized _ | Extract _ | Concat _ | Apply _ -> false - -let rec atoms e = - if is_atomic e then Iter.return e else Iter.flat_map ~f:atoms (trms e) - -(** Query *) - -let fv e = Var.Set.of_iter (vars e) -let rec height e = 1 + Iter.fold ~f:(fun x -> max (height x)) (trms e) 0 diff --git a/sledge/src/fol/trm.mli b/sledge/src/fol/trm.mli index 92b602f65..7a91149b9 100644 --- a/sledge/src/fol/trm.mli +++ b/sledge/src/fol/trm.mli @@ -91,26 +91,30 @@ val apply : Funsym.t -> t array -> t val get_z : t -> Z.t option val get_q : t -> Q.t option -(** Transform *) - -val map_vars : t -> f:(Var.t -> Var.t) -> t -val map : t -> f:(t -> t) -> t - (** Query *) val seq_size_exn : t -> t val seq_size : t -> t option val is_atomic : t -> bool val height : t -> int -val fv : t -> Var.Set.t (** Traverse *) -val trms : t -> t iter -(** The immediate subterms of a term. *) - val vars : t -> Var.t iter (** The variables that occur in a term. *) +val fv : t -> Var.Set.t + +val trms : t -> t iter +(** The immediate subterms. *) + val atoms : t -> t iter -(** The atomic reflexive-transitive subterms of a term. *) +(** The atomic reflexive-transitive subterms. *) + +(** Transform *) + +val map_vars : t -> f:(Var.t -> Var.t) -> t +(** Map over the {!vars}. *) + +val map : t -> f:(t -> t) -> t +(** Map over the {!trms}. *)