diff --git a/sledge/src/arithmetic.ml b/sledge/src/arithmetic.ml index 52652712f..9a6a2988f 100644 --- a/sledge/src/arithmetic.ml +++ b/sledge/src/arithmetic.ml @@ -7,9 +7,13 @@ (** Arithmetic terms *) +open Ses.Var_intf include Arithmetic_intf -module Representation (Trm : INDETERMINATE) = struct +module Representation + (Var : VAR) + (Trm : INDETERMINATE with type var := Var.t) = +struct module Prod = struct include Multiset.Make (Int) @@ -64,6 +68,16 @@ module Representation (Trm : INDETERMINATE) = struct (** [get_trm m] is [Some x] iff [equal m (of_ x 1)] *) let get_trm mono = match Prod.only_elt mono with Some (trm, 1) -> Some trm | _ -> None + + (* traverse *) + + let trms mono = + Iter.from_iter (fun f -> Prod.iter mono ~f:(fun trm _ -> f trm)) + + (* query *) + + let vars p = Iter.flat_map ~f:Trm.vars (trms p) + let fv p = Var.Set.of_iter (vars p) end module Sum = struct @@ -99,6 +113,8 @@ module Representation (Trm : INDETERMINATE) = struct (Sum.pp "@ + " pp_coeff_mono) poly + let pp = ppx (fun _ -> None) + let mono_invariant mono = let@ () = Invariant.invariant [%here] mono [%sexp_of: Mono.t] in Prod.iter mono ~f:(fun base power -> @@ -268,14 +284,72 @@ module Representation (Trm : INDETERMINATE) = struct (* traverse *) - let iter poly = - Iter.from_iter (fun f -> - Sum.iter poly ~f:(fun mono _ -> - Prod.iter mono ~f:(fun trm _ -> f trm) ) ) + let monos poly = + Iter.from_iter (fun f -> Sum.iter poly ~f:(fun mono _ -> f mono)) + + let trms poly = Iter.flat_map ~f:Mono.trms (monos poly) type product = Prod.t let fold_factors = Prod.fold let fold_monomials = Sum.fold + + (* query *) + + let vars p = Iter.flat_map ~f:Trm.vars (trms p) + + (* solve *) + + let exists_fv_in vs poly = Iter.exists ~f:(Var.Set.mem vs) (vars poly) + + (** [solve_for_mono r c m p] solves [0 = r + (c×m) + p] as [m = q] + ([Some (m, q)]) such that [r + (c×m) + p = m - q] *) + let solve_for_mono rejected_poly coeff mono poly = + if Mono.equal_one mono || exists_fv_in (Mono.fv mono) poly then None + else + Some + ( Sum.of_ mono Q.one + , mulc (Q.inv (Q.neg coeff)) (Sum.union rejected_poly poly) ) + + (** [solve_poly r p] solves [0 = r + p] as [m = q] ([Some (m, q)]) such + that [r + p = m - q] *) + let rec solve_poly rejected poly = + [%trace] + ~call:(fun {pf} -> pf "0 = (%a) + (%a)" pp rejected pp poly) + ~retn:(fun {pf} s -> + pf "%a" + (Option.pp "%a" (fun fs (v, q) -> + Format.fprintf fs "%a ↦ %a" pp v pp q )) + s ) + @@ fun () -> + let* mono, coeff, poly = Sum.pop_min_elt poly in + match solve_for_mono rejected coeff mono poly with + | Some _ as soln -> soln + | None -> solve_poly (Sum.add mono coeff rejected) poly + + (* solve [0 = e] *) + let solve_zero_eq ?for_ e = + [%trace] + ~call:(fun {pf} -> + pf "0 = %a%a" Trm.pp e (Option.pp " for %a" Trm.pp) for_ ) + ~retn:(fun {pf} s -> + pf "%a" + (Option.pp "%a" (fun fs (c, r) -> + Format.fprintf fs "%a ↦ %a" pp c pp r )) + s ; + match (for_, s) with + | Some f, Some (c, _) -> assert (equal (trm f) c) + | _ -> () ) + @@ fun () -> + let* a = Embed.get_arith e in + match for_ with + | None -> solve_poly Sum.empty a + | Some for_ -> ( + let* for_poly = Embed.get_arith for_ in + match get_mono for_poly with + | Some m -> + let* c, p = Sum.find_and_remove m a in + solve_for_mono Sum.empty c m p + | _ -> None ) end end diff --git a/sledge/src/arithmetic.mli b/sledge/src/arithmetic.mli index 96bf41b65..fdf3d92c1 100644 --- a/sledge/src/arithmetic.mli +++ b/sledge/src/arithmetic.mli @@ -7,9 +7,10 @@ (** Arithmetic terms *) +open Ses.Var_intf include module type of Arithmetic_intf -module Representation (Indeterminate : INDETERMINATE) : - REPRESENTATION - with type var := Indeterminate.var - with type trm := Indeterminate.trm +module Representation + (Var : VAR) + (Indeterminate : INDETERMINATE with type var := Var.t) : + REPRESENTATION with type var := Var.t with type trm := Indeterminate.trm diff --git a/sledge/src/arithmetic_intf.ml b/sledge/src/arithmetic_intf.ml index c306102c4..69f321a32 100644 --- a/sledge/src/arithmetic_intf.ml +++ b/sledge/src/arithmetic_intf.ml @@ -61,8 +61,15 @@ module type S = sig (** Traverse *) - val iter : t -> trm iter - (** [iter a] enumerates the indeterminate terms appearing in [a] *) + val trms : t -> trm iter + (** [trms a] enumerates the indeterminate terms appearing in [a] *) + + (** Solve *) + + val solve_zero_eq : ?for_:trm -> trm -> (t * t) option + (** [solve_zero_eq d] is [Some (e, f)] if [0 = d] can be equivalently + expressed as [e = f] for some monomial subterm [e] of [d]. If [for_] + is passed, then the subterm [e] must be [for_]. *) (**/**) @@ -79,6 +86,8 @@ module type INDETERMINATE = sig type var val ppx : var Var.strength -> trm pp + val pp : trm pp + val vars : trm -> var iter end (** An embedding of arithmetic terms [t] into indeterminates [trm], diff --git a/sledge/src/trm.ml b/sledge/src/trm.ml index eaea18654..ac94572aa 100644 --- a/sledge/src/trm.ml +++ b/sledge/src/trm.ml @@ -40,12 +40,13 @@ end and Arith0 : (Arithmetic.REPRESENTATION with type var := Var.t with type trm := Trm.t) = -Arithmetic.Representation (struct - type trm = Trm.t [@@deriving compare, equal, sexp] - type var = Var.t + Arithmetic.Representation + (Var) + (struct + include Trm - let ppx = Trm.ppx -end) + type trm = t [@@deriving compare, equal, sexp] + end) and Arith : (Arithmetic.S @@ -89,6 +90,7 @@ and Trm : sig [@@deriving compare, equal, sexp] val ppx : Var.t Var.strength -> t pp + val pp : t pp val _Var : int -> string -> t val _Z : Z.t -> t val _Q : Q.t -> t @@ -106,6 +108,7 @@ and Trm : sig val sub : t -> t -> t val seq_size_exn : t -> t val seq_size : t -> t option + val vars : t -> Var.t iter end = struct type t = | Var of {id: int; name: string} @@ -198,12 +201,12 @@ end = struct | Trm _ | Const _ -> assert false ) | _ -> () - (* destructors *) + (** Destruct *) let get_z = function Z z -> Some z | _ -> None let get_q = function Q q -> Some q | Z z -> Some (Q.of_z z) | _ -> None - (* constructors *) + (** Construct *) let _Var id name = Var {id; name} |> check invariant @@ -361,6 +364,26 @@ end = struct | Some c -> c | None -> Apply (f, es) ) |> check invariant + + (** Traverse *) + + let rec iter_vars e ~f = + match e with + | Var _ as v -> f (Var.of_ v) + | Z _ | Q _ | Ancestor _ -> () + | Splat x | Select {rcd= x} -> iter_vars ~f x + | Sized {seq= x; siz= y} | Update {rcd= x; elt= y} -> + iter_vars ~f x ; + iter_vars ~f y + | Extract {seq= x; off= y; len= z} -> + iter_vars ~f x ; + iter_vars ~f y ; + iter_vars ~f z + | Concat xs | Record xs | Apply (_, xs) -> + Array.iter ~f:(iter_vars ~f) xs + | Arith a -> Iter.iter ~f:(iter_vars ~f) (Arith.trms a) + + let vars e = Iter.from_labelled_iter (iter_vars e) end type arith = Arith.t @@ -424,22 +447,3 @@ let rec map_vars e ~f = | Record xs -> mapN (map_vars ~f) e _Record xs | Ancestor _ -> e | Apply (g, xs) -> mapN (map_vars ~f) e (_Apply g) xs - -(** Traverse *) - -let rec iter_vars e ~f = - match e with - | Var _ as v -> f (Var.of_ v) - | Z _ | Q _ | Ancestor _ -> () - | Splat x | Select {rcd= x} -> iter_vars ~f x - | Sized {seq= x; siz= y} | Update {rcd= x; elt= y} -> - iter_vars ~f x ; - iter_vars ~f y - | Extract {seq= x; off= y; len= z} -> - iter_vars ~f x ; - iter_vars ~f y ; - iter_vars ~f z - | Concat xs | Record xs | Apply (_, xs) -> Array.iter ~f:(iter_vars ~f) xs - | Arith a -> Iter.iter ~f:(iter_vars ~f) (Arith.iter a) - -let vars e = Iter.from_labelled_iter (iter_vars e)