[sledge] Cleanup and tracing in Trm and Arith

Summary:
No functional change. Mostly reorder definitions. Add some
tracing. Some documentation improvement. Remove a few unneeded
identifiers in patterns.

Reviewed By: jvillard

Differential Revision: D25883712

fbshipit-source-id: 1449a7f54
master
Josh Berdine 4 years ago committed by Facebook GitHub Bot
parent 9945cc9495
commit b6ffeb179f

@ -211,6 +211,10 @@ struct
redundant representations, singleton polynomials are flattened. *) redundant representations, singleton polynomials are flattened. *)
let of_trm : ?power:int -> trm -> t = let of_trm : ?power:int -> trm -> t =
fun ?(power = 1) base -> fun ?(power = 1) base ->
[%trace]
~call:(fun {pf} -> pf "@ %a^%i" Trm.pp base power)
~retn:(fun {pf} (c, m) -> pf "%a×%a" Q.pp c Mono.pp m)
@@ fun () ->
match Embed.get_arith base with match Embed.get_arith base with
| Some poly -> ( | Some poly -> (
match Sum.classify poly with match Sum.classify poly with
@ -230,6 +234,10 @@ struct
polynomials are multiplied by their coefficients directly. *) polynomials are multiplied by their coefficients directly. *)
let to_poly : t -> Poly.t = let to_poly : t -> Poly.t =
fun (coeff, mono) -> fun (coeff, mono) ->
[%trace]
~call:(fun {pf} -> pf "@ %a×%a" Q.pp coeff Mono.pp mono)
~retn:(fun {pf} -> pf "%a" pp)
@@ fun () ->
( match Mono.get_trm mono with ( match Mono.get_trm mono with
| Some trm -> ( | Some trm -> (
match Embed.get_arith trm with match Embed.get_arith trm with
@ -283,6 +291,13 @@ struct
Sum.partition_map poly ~f:(fun _ coeff -> Sum.partition_map poly ~f:(fun _ coeff ->
if Q.sign coeff >= 0 then Left coeff else Right (Q.neg coeff) ) if Q.sign coeff >= 0 then Left coeff else Right (Q.neg coeff) )
(* traverse *)
let monos poly =
Iter.from_iter (fun f -> Sum.iter poly ~f:(fun mono _ -> f mono))
let trms poly = Iter.flat_map ~f:Mono.trms (monos poly)
let map poly ~f = let map poly ~f =
[%trace] [%trace]
~call:(fun {pf} -> pf "@ %a" pp poly) ~call:(fun {pf} -> pf "@ %a" pp poly)
@ -300,13 +315,6 @@ struct
Sum.union p p' ) Sum.union p p' )
|> check invariant |> check invariant
(* traverse *)
let monos poly =
Iter.from_iter (fun f -> Sum.iter poly ~f:(fun mono _ -> f mono))
let trms poly = Iter.flat_map ~f:Mono.trms (monos poly)
(* query *) (* query *)
let vars p = Iter.flat_map ~f:Trm.vars (trms p) let vars p = Iter.flat_map ~f:Trm.vars (trms p)

@ -33,7 +33,7 @@ module type S = sig
val classify : t -> kind val classify : t -> kind
(** [classify a] is [Trm x] iff [get_trm a] is [Some x], [Const q] iff (** [classify a] is [Trm x] iff [get_trm a] is [Some x], [Const q] iff
[get_const a] is [Some q], [Interpreted] if the principal operation of [get_const a] is [Some q], [Interpreted] if the principal operation of
[a] is interpreted, and [Uninterpreted] otherwise *) [a] is interpreted, and [Uninterpreted] otherwise. *)
(** Construct compound terms *) (** Construct compound terms *)
@ -56,15 +56,15 @@ module type S = sig
(** [partition_sign a] is [(p, n)] such that [a] = [p - n] and all (** [partition_sign a] is [(p, n)] such that [a] = [p - n] and all
coefficients in [p] and [n] are non-negative. *) coefficients in [p] and [n] are non-negative. *)
val map : t -> f:(trm -> trm) -> t
(** [map ~f a] is [a] with each maximal non-interpreted subterm
transformed by [f]. *)
(** Traverse *) (** Traverse *)
val trms : t -> trm iter val trms : t -> trm iter
(** [trms a] enumerates the indeterminate terms appearing in [a] *) (** [trms a] enumerates the indeterminate terms appearing in [a] *)
val map : t -> f:(trm -> trm) -> t
(** [map ~f a] is [a] with each maximal non-interpreted subterm
transformed by [f]. *)
(** Solve *) (** Solve *)
val solve_zero_eq : ?for_:trm -> trm -> (t * t) option val solve_zero_eq : ?for_:trm -> trm -> (t * t) option

@ -443,6 +443,28 @@ let concat elts = _Concat elts
let apply sym args = _Apply sym args let apply sym args = _Apply sym args
(** Traverse *)
let trms = function
| Var _ | Z _ | Q _ -> Iter.empty
| Arith a -> Arith.trms a
| Splat x -> Iter.(cons x empty)
| Sized {seq; siz} -> Iter.(cons seq (cons siz empty))
| Extract {seq; off; len} -> Iter.(cons seq (cons off (cons len empty)))
| Concat xs | Apply (_, xs) -> Iter.of_array xs
let is_atomic = function
| Var _ | Z _ | Q _ | Concat [||] | Apply (_, [||]) -> true
| Arith _ | Splat _ | Sized _ | Extract _ | Concat _ | Apply _ -> false
let rec atoms e =
if is_atomic e then Iter.return e else Iter.flat_map ~f:atoms (trms e)
(** Query *)
let fv e = Var.Set.of_iter (vars e)
let rec height e = 1 + Iter.fold ~f:(fun x -> max (height x)) (trms e) 0
(** Transform *) (** Transform *)
let rec map_vars e ~f = let rec map_vars e ~f =
@ -470,26 +492,3 @@ let map e ~f =
map3 f e (fun seq off len -> _Extract ~seq ~off ~len) seq off len map3 f e (fun seq off len -> _Extract ~seq ~off ~len) seq off len
| Concat xs -> mapN f e _Concat xs | Concat xs -> mapN f e _Concat xs
| Apply (g, xs) -> mapN f e (_Apply g) xs | Apply (g, xs) -> mapN f e (_Apply g) xs
(** Traverse *)
let trms = function
| Var _ | Z _ | Q _ -> Iter.empty
| Arith a -> Arith.trms a
| Splat x -> Iter.(cons x empty)
| Sized {seq= x; siz= y} -> Iter.(cons x (cons y empty))
| Extract {seq= x; off= y; len= z} ->
Iter.(cons x (cons y (cons z empty)))
| Concat xs | Apply (_, xs) -> Iter.of_array xs
let is_atomic = function
| Var _ | Z _ | Q _ | Concat [||] | Apply (_, [||]) -> true
| Arith _ | Splat _ | Sized _ | Extract _ | Concat _ | Apply _ -> false
let rec atoms e =
if is_atomic e then Iter.return e else Iter.flat_map ~f:atoms (trms e)
(** Query *)
let fv e = Var.Set.of_iter (vars e)
let rec height e = 1 + Iter.fold ~f:(fun x -> max (height x)) (trms e) 0

@ -91,26 +91,30 @@ val apply : Funsym.t -> t array -> t
val get_z : t -> Z.t option val get_z : t -> Z.t option
val get_q : t -> Q.t option val get_q : t -> Q.t option
(** Transform *)
val map_vars : t -> f:(Var.t -> Var.t) -> t
val map : t -> f:(t -> t) -> t
(** Query *) (** Query *)
val seq_size_exn : t -> t val seq_size_exn : t -> t
val seq_size : t -> t option val seq_size : t -> t option
val is_atomic : t -> bool val is_atomic : t -> bool
val height : t -> int val height : t -> int
val fv : t -> Var.Set.t
(** Traverse *) (** Traverse *)
val trms : t -> t iter
(** The immediate subterms of a term. *)
val vars : t -> Var.t iter val vars : t -> Var.t iter
(** The variables that occur in a term. *) (** The variables that occur in a term. *)
val fv : t -> Var.Set.t
val trms : t -> t iter
(** The immediate subterms. *)
val atoms : t -> t iter val atoms : t -> t iter
(** The atomic reflexive-transitive subterms of a term. *) (** The atomic reflexive-transitive subterms. *)
(** Transform *)
val map_vars : t -> f:(Var.t -> Var.t) -> t
(** Map over the {!vars}. *)
val map : t -> f:(t -> t) -> t
(** Map over the {!trms}. *)

Loading…
Cancel
Save