[sledge] Add Arithmetic.solve_zero_eq

Reviewed By: jvillard

Differential Revision: D24532343

fbshipit-source-id: d7d2f6fd2
master
Josh Berdine 4 years ago committed by Facebook GitHub Bot
parent f007b774f4
commit e4749098b2

@ -7,9 +7,13 @@
(** Arithmetic terms *) (** Arithmetic terms *)
open Ses.Var_intf
include Arithmetic_intf include Arithmetic_intf
module Representation (Trm : INDETERMINATE) = struct module Representation
(Var : VAR)
(Trm : INDETERMINATE with type var := Var.t) =
struct
module Prod = struct module Prod = struct
include Multiset.Make include Multiset.Make
(Int) (Int)
@ -64,6 +68,16 @@ module Representation (Trm : INDETERMINATE) = struct
(** [get_trm m] is [Some x] iff [equal m (of_ x 1)] *) (** [get_trm m] is [Some x] iff [equal m (of_ x 1)] *)
let get_trm mono = let get_trm mono =
match Prod.only_elt mono with Some (trm, 1) -> Some trm | _ -> None match Prod.only_elt mono with Some (trm, 1) -> Some trm | _ -> None
(* traverse *)
let trms mono =
Iter.from_iter (fun f -> Prod.iter mono ~f:(fun trm _ -> f trm))
(* query *)
let vars p = Iter.flat_map ~f:Trm.vars (trms p)
let fv p = Var.Set.of_iter (vars p)
end end
module Sum = struct module Sum = struct
@ -99,6 +113,8 @@ module Representation (Trm : INDETERMINATE) = struct
(Sum.pp "@ + " pp_coeff_mono) (Sum.pp "@ + " pp_coeff_mono)
poly poly
let pp = ppx (fun _ -> None)
let mono_invariant mono = let mono_invariant mono =
let@ () = Invariant.invariant [%here] mono [%sexp_of: Mono.t] in let@ () = Invariant.invariant [%here] mono [%sexp_of: Mono.t] in
Prod.iter mono ~f:(fun base power -> Prod.iter mono ~f:(fun base power ->
@ -268,14 +284,72 @@ module Representation (Trm : INDETERMINATE) = struct
(* traverse *) (* traverse *)
let iter poly = let monos poly =
Iter.from_iter (fun f -> Iter.from_iter (fun f -> Sum.iter poly ~f:(fun mono _ -> f mono))
Sum.iter poly ~f:(fun mono _ ->
Prod.iter mono ~f:(fun trm _ -> f trm) ) ) let trms poly = Iter.flat_map ~f:Mono.trms (monos poly)
type product = Prod.t type product = Prod.t
let fold_factors = Prod.fold let fold_factors = Prod.fold
let fold_monomials = Sum.fold let fold_monomials = Sum.fold
(* query *)
let vars p = Iter.flat_map ~f:Trm.vars (trms p)
(* solve *)
let exists_fv_in vs poly = Iter.exists ~f:(Var.Set.mem vs) (vars poly)
(** [solve_for_mono r c m p] solves [0 = r + (c×m) + p] as [m = q]
([Some (m, q)]) such that [r + (c×m) + p = m - q] *)
let solve_for_mono rejected_poly coeff mono poly =
if Mono.equal_one mono || exists_fv_in (Mono.fv mono) poly then None
else
Some
( Sum.of_ mono Q.one
, mulc (Q.inv (Q.neg coeff)) (Sum.union rejected_poly poly) )
(** [solve_poly r p] solves [0 = r + p] as [m = q] ([Some (m, q)]) such
that [r + p = m - q] *)
let rec solve_poly rejected poly =
[%trace]
~call:(fun {pf} -> pf "0 = (%a) + (%a)" pp rejected pp poly)
~retn:(fun {pf} s ->
pf "%a"
(Option.pp "%a" (fun fs (v, q) ->
Format.fprintf fs "%a ↦ %a" pp v pp q ))
s )
@@ fun () ->
let* mono, coeff, poly = Sum.pop_min_elt poly in
match solve_for_mono rejected coeff mono poly with
| Some _ as soln -> soln
| None -> solve_poly (Sum.add mono coeff rejected) poly
(* solve [0 = e] *)
let solve_zero_eq ?for_ e =
[%trace]
~call:(fun {pf} ->
pf "0 = %a%a" Trm.pp e (Option.pp " for %a" Trm.pp) for_ )
~retn:(fun {pf} s ->
pf "%a"
(Option.pp "%a" (fun fs (c, r) ->
Format.fprintf fs "%a ↦ %a" pp c pp r ))
s ;
match (for_, s) with
| Some f, Some (c, _) -> assert (equal (trm f) c)
| _ -> () )
@@ fun () ->
let* a = Embed.get_arith e in
match for_ with
| None -> solve_poly Sum.empty a
| Some for_ -> (
let* for_poly = Embed.get_arith for_ in
match get_mono for_poly with
| Some m ->
let* c, p = Sum.find_and_remove m a in
solve_for_mono Sum.empty c m p
| _ -> None )
end end
end end

@ -7,9 +7,10 @@
(** Arithmetic terms *) (** Arithmetic terms *)
open Ses.Var_intf
include module type of Arithmetic_intf include module type of Arithmetic_intf
module Representation (Indeterminate : INDETERMINATE) : module Representation
REPRESENTATION (Var : VAR)
with type var := Indeterminate.var (Indeterminate : INDETERMINATE with type var := Var.t) :
with type trm := Indeterminate.trm REPRESENTATION with type var := Var.t with type trm := Indeterminate.trm

@ -61,8 +61,15 @@ module type S = sig
(** Traverse *) (** Traverse *)
val iter : t -> trm iter val trms : t -> trm iter
(** [iter a] enumerates the indeterminate terms appearing in [a] *) (** [trms a] enumerates the indeterminate terms appearing in [a] *)
(** Solve *)
val solve_zero_eq : ?for_:trm -> trm -> (t * t) option
(** [solve_zero_eq d] is [Some (e, f)] if [0 = d] can be equivalently
expressed as [e = f] for some monomial subterm [e] of [d]. If [for_]
is passed, then the subterm [e] must be [for_]. *)
(**/**) (**/**)
@ -79,6 +86,8 @@ module type INDETERMINATE = sig
type var type var
val ppx : var Var.strength -> trm pp val ppx : var Var.strength -> trm pp
val pp : trm pp
val vars : trm -> var iter
end end
(** An embedding of arithmetic terms [t] into indeterminates [trm], (** An embedding of arithmetic terms [t] into indeterminates [trm],

@ -40,12 +40,13 @@ end
and Arith0 : and Arith0 :
(Arithmetic.REPRESENTATION with type var := Var.t with type trm := Trm.t) = (Arithmetic.REPRESENTATION with type var := Var.t with type trm := Trm.t) =
Arithmetic.Representation (struct Arithmetic.Representation
type trm = Trm.t [@@deriving compare, equal, sexp] (Var)
type var = Var.t (struct
include Trm
let ppx = Trm.ppx type trm = t [@@deriving compare, equal, sexp]
end) end)
and Arith : and Arith :
(Arithmetic.S (Arithmetic.S
@ -89,6 +90,7 @@ and Trm : sig
[@@deriving compare, equal, sexp] [@@deriving compare, equal, sexp]
val ppx : Var.t Var.strength -> t pp val ppx : Var.t Var.strength -> t pp
val pp : t pp
val _Var : int -> string -> t val _Var : int -> string -> t
val _Z : Z.t -> t val _Z : Z.t -> t
val _Q : Q.t -> t val _Q : Q.t -> t
@ -106,6 +108,7 @@ and Trm : sig
val sub : t -> t -> t val sub : t -> t -> t
val seq_size_exn : t -> t val seq_size_exn : t -> t
val seq_size : t -> t option val seq_size : t -> t option
val vars : t -> Var.t iter
end = struct end = struct
type t = type t =
| Var of {id: int; name: string} | Var of {id: int; name: string}
@ -198,12 +201,12 @@ end = struct
| Trm _ | Const _ -> assert false ) | Trm _ | Const _ -> assert false )
| _ -> () | _ -> ()
(* destructors *) (** Destruct *)
let get_z = function Z z -> Some z | _ -> None let get_z = function Z z -> Some z | _ -> None
let get_q = function Q q -> Some q | Z z -> Some (Q.of_z z) | _ -> None let get_q = function Q q -> Some q | Z z -> Some (Q.of_z z) | _ -> None
(* constructors *) (** Construct *)
let _Var id name = Var {id; name} |> check invariant let _Var id name = Var {id; name} |> check invariant
@ -361,6 +364,26 @@ end = struct
| Some c -> c | Some c -> c
| None -> Apply (f, es) ) | None -> Apply (f, es) )
|> check invariant |> check invariant
(** Traverse *)
let rec iter_vars e ~f =
match e with
| Var _ as v -> f (Var.of_ v)
| Z _ | Q _ | Ancestor _ -> ()
| Splat x | Select {rcd= x} -> iter_vars ~f x
| Sized {seq= x; siz= y} | Update {rcd= x; elt= y} ->
iter_vars ~f x ;
iter_vars ~f y
| Extract {seq= x; off= y; len= z} ->
iter_vars ~f x ;
iter_vars ~f y ;
iter_vars ~f z
| Concat xs | Record xs | Apply (_, xs) ->
Array.iter ~f:(iter_vars ~f) xs
| Arith a -> Iter.iter ~f:(iter_vars ~f) (Arith.trms a)
let vars e = Iter.from_labelled_iter (iter_vars e)
end end
type arith = Arith.t type arith = Arith.t
@ -424,22 +447,3 @@ let rec map_vars e ~f =
| Record xs -> mapN (map_vars ~f) e _Record xs | Record xs -> mapN (map_vars ~f) e _Record xs
| Ancestor _ -> e | Ancestor _ -> e
| Apply (g, xs) -> mapN (map_vars ~f) e (_Apply g) xs | Apply (g, xs) -> mapN (map_vars ~f) e (_Apply g) xs
(** Traverse *)
let rec iter_vars e ~f =
match e with
| Var _ as v -> f (Var.of_ v)
| Z _ | Q _ | Ancestor _ -> ()
| Splat x | Select {rcd= x} -> iter_vars ~f x
| Sized {seq= x; siz= y} | Update {rcd= x; elt= y} ->
iter_vars ~f x ;
iter_vars ~f y
| Extract {seq= x; off= y; len= z} ->
iter_vars ~f x ;
iter_vars ~f y ;
iter_vars ~f z
| Concat xs | Record xs | Apply (_, xs) -> Array.iter ~f:(iter_vars ~f) xs
| Arith a -> Iter.iter ~f:(iter_vars ~f) (Arith.iter a)
let vars e = Iter.from_labelled_iter (iter_vars e)

Loading…
Cancel
Save