(*
 * Copyright (c) Facebook, Inc. and its affiliates.
 *
 * This source code is licensed under the MIT license found in the
 * LICENSE file in the root directory of this source tree.
 *)
(** Terms *)
[@@@warning "+9"]
(* Unary operators: the head symbol of an [Ap1] term. *)
type op1 =
| Signed of {bits: int}
(* [simp_signed] folds integer constants with [Z.signed_extract _ 0 bits] *)
| Unsigned of {bits: int}
(* [simp_unsigned] folds integer constants with [Z.extract _ 0 bits] *)
| Convert of {src: Llair.Typ.t; dst: Llair.Typ.t}
(* conversion from type [src] to type [dst] *)
| Splat
(* printed as [byt^] *)
| Select of int
(* printed as [rcd[idx]] *)
[@@deriving compare, equal, hash, sexp]
(* Binary operators: the head symbol of an [Ap2] term. *)
type op2 =
| Eq
(* printed as [=] *)
| Dq
(* printed as [≠] *)
| Lt
| Le
| Ord
| Uno
| Div
| Rem
| Xor
| Shl
| Lshr
| Ashr
| Sized
(* printed as [⟨siz,arr⟩] *)
| Update of int
(* printed as [[rcd | idx → elt]] *)
[@@deriving compare, equal, hash, sexp]
(* Ternary operators ([Ap3]): [Conditional] is printed as [cnd ? thn : els]
   and [Extract] as [seq[off,len)]. *)
type op3 = Conditional | Extract [@@deriving compare, equal, hash, sexp]
(* N-ary operators ([ApN]): sequence concatenation and record
   construction. *)
type opN = Concat | Record [@@deriving compare, equal, hash, sexp]
(** Mutually recursive definition of terms ([T]) and of the set ([Set]) and
    multiset ([Qset]) containers of terms that occur inside terms. *)
module rec Set : sig
  include NS.Set.S with type elt := T.t

  val t_of_sexp : Sexp.t -> t
end = struct
  include NS.Set.Make (T)
  include Provide_of_sexp (T)
end

and Qset : sig
  include NS.Qset.S with type elt := T.t

  val t_of_sexp : Sexp.t -> t
end = struct
  include NS.Qset.Make (T)

  (* specialize the generic deserializer to term elements *)
  let t_of_sexp = t_of_sexp T.t_of_sexp
end

and T : sig
  type set = Set.t [@@deriving compare, equal, sexp]
  type qset = Qset.t [@@deriving compare, equal, sexp]

  type t =
    | Var of {id: int; name: string}
    | Ap1 of op1 * t
    | Ap2 of op2 * t * t
    | Ap3 of op3 * t * t * t
    | ApN of opN * t iarray
    | And of set
    | Or of set
    | Add of qset
    | Mul of qset
    | Label of {parent: string; name: string}
    | Float of {data: string}
    | Integer of {data: Z.t}
    | Rational of {data: Q.t}
    | RecRecord of int
  [@@deriving compare, equal, sexp]
end = struct
  type set = Set.t [@@deriving compare, equal, sexp]
  type qset = Qset.t [@@deriving compare, equal, sexp]

  type t =
    | Var of {id: int; name: string}
    | Ap1 of op1 * t
    | Ap2 of op2 * t * t
    | Ap3 of op3 * t * t * t
    | ApN of opN * t iarray
    | And of set
    | Or of set
    | Add of qset
    | Mul of qset
    | Label of {parent: string; name: string}
    | Float of {data: string}
    | Integer of {data: Z.t}
    | Rational of {data: Q.t}
    | RecRecord of int
  [@@deriving compare, equal, sexp]

  (* Note: solve (and invariant) requires Qset.min_elt to return a
     non-coefficient, so Integer and Rational terms must compare higher than
     any valid monomial *)

  (* Physical equality is a shortcut; two variables that both have a positive
     id are ordered by id alone (their names are ignored); everything else
     falls through to the derived comparison. *)
  let compare x y =
    if x == y then 0
    else
      match (x, y) with
      | Var {id= i; name= _}, Var {id= j; name= _} when i > 0 && j > 0 ->
          Int.compare i j
      | _ -> compare x y

  (* Same shortcuts as [compare], so that the two relations agree. *)
  let equal x y =
    x == y
    ||
    match (x, y) with
    | Var {id= i; name= _}, Var {id= j; name= _} when i > 0 && j > 0 ->
        Int.equal i j
    | _ -> equal x y
end
include T
(** Maps keyed by terms, with sexp deserialization of keys provided by [T]. *)
module Map = struct
  include Map.Make (T)
  include Provide_of_sexp (T)
end
(** [ppx strength fs term] pretty-prints [term] on [fs]. Variables whose id
    is [-1] are printed as [@name] and those whose id is [0] as [%name];
    for every other variable [strength] is consulted to choose the styling
    (plain, bold, cyan, or an anonymous [_]). *)
let rec ppx strength fs term =
  let rec pp fs term =
    (* print inside a fresh box with indentation 2, closed afterwards *)
    let pf fmt =
      Format.pp_open_box fs 2 ;
      Format.kfprintf (fun fs -> Format.pp_close_box fs ()) fs fmt
    in
    match term with
    | Var {name; id= -1} -> Trace.pp_styled `Bold "%@%s" fs name
    | Var {name; id= 0} -> Trace.pp_styled `Bold "%%%s" fs name
    | Var {name; id} -> (
      match strength term with
      | None -> pf "%%%s_%d" name id
      | Some `Universal -> Trace.pp_styled `Bold "%%%s_%d" fs name id
      | Some `Existential -> Trace.pp_styled `Cyan "%%%s_%d" fs name id
      | Some `Anonymous -> Trace.pp_styled `Cyan "_" fs )
    | Integer {data} -> Trace.pp_styled `Magenta "%a" fs Z.pp data
    | Rational {data} -> Trace.pp_styled `Magenta "%a" fs Q.pp data
    | Float {data} -> pf "%s" data
    | Label {name} -> pf "%s" name
    | Ap1 (Signed {bits}, arg) -> pf "((s%i)@ %a)" bits pp arg
    | Ap1 (Unsigned {bits}, arg) -> pf "((u%i)@ %a)" bits pp arg
    | Ap1 (Convert {src; dst}, arg) ->
        pf "((%a)(%a)@ %a)" Llair.Typ.pp dst Llair.Typ.pp src pp arg
    | Ap2 (Eq, x, y) -> pf "(%a@ = %a)" pp x pp y
    | Ap2 (Dq, x, y) -> pf "(%a@ @<2>≠ %a)" pp x pp y
    | Ap2 (Lt, x, y) -> pf "(%a@ < %a)" pp x pp y
    | Ap2 (Le, x, y) -> pf "(%a@ @<2>≤ %a)" pp x pp y
    | Ap2 (Ord, x, y) -> pf "(%a@ ord %a)" pp x pp y
    | Ap2 (Uno, x, y) -> pf "(%a@ uno %a)" pp x pp y
    | Add args ->
        (* a constant is stored as a coefficient of the monomial 1, and a
           unit coefficient is left implicit *)
        let pp_poly_term fs (monomial, coefficient) =
          match monomial with
          | Integer {data} when Z.equal Z.one data -> Q.pp fs coefficient
          | _ when Q.equal Q.one coefficient -> pp fs monomial
          | _ ->
              Format.fprintf fs "%a @<1>× %a" Q.pp coefficient pp monomial
        in
        pf "(%a)" (Qset.pp "@ + " pp_poly_term) args
    | Mul args ->
        (* a unit exponent is left implicit *)
        let pp_mono_term fs (factor, exponent) =
          if Q.equal Q.one exponent then pp fs factor
          else Format.fprintf fs "%a^%a" pp factor Q.pp exponent
        in
        pf "(%a)" (Qset.pp "@ @<2>× " pp_mono_term) args
    | Ap2 (Div, x, y) -> pf "(%a@ / %a)" pp x pp y
    | Ap2 (Rem, x, y) -> pf "(%a@ rem %a)" pp x pp y
    | And xs -> pf "(@[%a@])" (Set.pp ~sep:" &&@ " pp) xs
    | Or xs -> pf "(@[%a@])" (Set.pp ~sep:" ||@ " pp) xs
    (* xor with true is printed as negation *)
    | Ap2 (Xor, x, Integer {data}) when Z.is_true data -> pf "¬%a" pp x
    | Ap2 (Xor, Integer {data}, x) when Z.is_true data -> pf "¬%a" pp x
    | Ap2 (Xor, x, y) -> pf "(%a@ xor %a)" pp x pp y
    | Ap2 (Shl, x, y) -> pf "(%a@ shl %a)" pp x pp y
    | Ap2 (Lshr, x, y) -> pf "(%a@ lshr %a)" pp x pp y
    | Ap2 (Ashr, x, y) -> pf "(%a@ ashr %a)" pp x pp y
    | Ap3 (Conditional, cnd, thn, els) ->
        pf "(%a@ ? %a@ : %a)" pp cnd pp thn pp els
    | Ap3 (Extract, seq, off, len) -> pf "%a[%a,%a)" pp seq pp off pp len
    | Ap1 (Splat, byt) -> pf "%a^" pp byt
    | Ap2 (Sized, siz, arr) -> pf "@<1>⟨%a,%a@<1>⟩" pp siz pp arr
    | ApN (Concat, args) when IArray.is_empty args -> pf "@<2>⟨⟩"
    | ApN (Concat, args) -> pf "(%a)" (IArray.pp "@,^" pp) args
    | ApN (Record, elts) -> pf "{%a}" (pp_record strength) elts
    | Ap1 (Select idx, rcd) -> pf "%a[%i]" pp rcd idx
    | Ap2 (Update idx, rcd, elt) ->
        pf "[%a@ @[| %i → %a@]]" pp rcd idx pp elt
    | RecRecord i -> pf "(rec_record %i)" i
  in
  pp fs term
  [@@warning "-9"]

(** [pp_record strength fs elts] prints the elements of a record. If every
    element is an integer constant that converts to a character, the
    record is printed as an escaped string; if that attempt raises, the
    elements are printed as a comma-separated list instead. *)
and pp_record strength fs elts =
  [%Trace.fprintf
    fs "%a"
      (fun fs elts ->
        match
          String.init (IArray.length elts) ~f:(fun i ->
              match IArray.get elts i with
              | Integer {data} -> Char.of_int_exn (Z.to_int data)
              | _ -> raise (Invalid_argument "not a string") )
        with
        | s -> Format.fprintf fs "@[<h>%s@]" (String.escaped s)
        | exception _ ->
            Format.fprintf fs "@[<h>%a@]"
              (IArray.pp ",@ " (ppx strength))
              elts )
      elts]
(** Plain printer: no variable receives a strength annotation. *)
let pp fs term = ppx (fun _ -> None) fs term

(** Alias of [pp], for use where [pp] is shadowed. *)
let pp_t = pp

(** Print a pair of terms as a removed/added diff. *)
let pp_diff fs (before, after) =
  Format.fprintf fs "-- %a ++ %a" pp before pp after
(** Invariant *)
(** Asserts that the term is an [And] none of whose members is itself an
    [And]; any other term is a failure. *)
let assert_conjunction term =
  match term with
  | And conjuncts ->
      let is_conjunction = function And _ -> true | _ -> false in
      Set.iter conjuncts ~f:(fun conjunct ->
          assert (not (is_conjunction conjunct)) )
  | _ -> assert false
(** Asserts that the term is an [Or] none of whose members is itself an
    [Or]; any other term is a failure. *)
let assert_disjunction term =
  match term with
  | Or disjuncts ->
      let is_disjunction = function Or _ -> true | _ -> false in
      Set.iter disjuncts ~f:(fun disjunct ->
          assert (not (is_disjunction disjunct)) )
  | _ -> assert false
(* an indeterminate (factor of a monomial) is any
non-Add/Mul/Integer/Rational term *)
let assert_indeterminate term =
  match term with
  (* constants, sums and products are never indeterminates *)
  | Integer _ | Rational _ | Add _ | Mul _ -> assert false
  | _ -> ()
(* a monomial is a power product of factors, e.g.
 *     ∏ᵢ xᵢ^nᵢ
 * for (non-constant) indeterminates xᵢ and positive integer exponents nᵢ
 *)
let assert_monomial mono =
  match mono with
  | Mul args ->
      Qset.iter args ~f:(fun factor exponent ->
          (* exponents are integers (denominator 1) and positive *)
          assert (Z.equal (Q.den exponent) Z.one) ;
          assert (Q.sign exponent > 0) ;
          assert_indeterminate factor |> Fn.id )
  (* a lone indeterminate is the monomial x^1 *)
  | _ -> assert_indeterminate mono |> Fn.id
(* a polynomial term is a monomial multiplied by a non-zero coefficient
 *     c × ∏ᵢ xᵢ
 *)
let assert_poly_term mono coeff =
  assert (Q.is_real coeff) ;
  assert (Q.sign coeff <> 0) ;
  match mono with
  (* the only constant allowed as a monomial is the empty product 1 *)
  | Integer {data} -> assert (Z.equal Z.one data)
  | Mul args ->
      (* the product must be non-empty, must not start with a constant, and
         must not be the redundant representation x^1 of a single factor *)
      ( match Qset.min_elt args with
      | None | Some ((Integer _ | Rational _), _) -> assert false
      | Some (_, n) -> assert (Qset.length args > 1 || not (Q.equal Q.one n))
      ) ;
      assert_monomial mono |> Fn.id
  | _ -> assert_monomial mono |> Fn.id
(* a polynomial is a linear combination of monomials, e.g.
 *     ∑ᵢ cᵢ × ∏ⱼ xᵢⱼ
 * for non-zero constant coefficients cᵢ
 * and monomials ∏ⱼ xᵢⱼ, one of which may be the empty product 1
 *)
let assert_polynomial poly =
  match poly with
  | Add args ->
      (* the sum must be non-empty, its least element must not be a
         constant, and it must not be the redundant representation 1 × x of
         a single monomial *)
      ( match Qset.min_elt args with
      | None | Some ((Integer _ | Rational _), _) -> assert false
      | Some (_, k) -> assert (Qset.length args > 1 || not (Q.equal Q.one k))
      ) ;
      Qset.iter args ~f:(fun m c -> assert_poly_term m c |> Fn.id)
  | _ -> assert false
(* sequence args of Extract and Concat must be sequence terms, in
particular, not variables *)
let rec assert_sequence term =
  match term with
  | Ap2 (Sized, _, _) -> ()
  | Ap3 (Extract, inner, _, _) -> assert_sequence inner
  | ApN (Concat, parts) ->
      (* a concatenation of exactly one part is not allowed *)
      assert (IArray.length parts <> 1) ;
      IArray.iter parts ~f:assert_sequence
  | _ -> assert false
(** [invariant e] checks the representation invariant of the head
    constructor of [e], dispatching to the [assert_*] helpers. *)
let invariant e =
  let@ () = Invariant.invariant [%here] e [%sexp_of: t] in
  match e with
  | And _ -> assert_conjunction e |> Fn.id
  | Or _ -> assert_disjunction e |> Fn.id
  | Add _ -> assert_polynomial e |> Fn.id
  | Mul _ -> assert_monomial e |> Fn.id
  | Ap2 (Sized, _, _) | Ap3 (Extract, _, _, _) | ApN (Concat, _) ->
      assert_sequence e
  | ApN (Record, elts) -> assert (not (IArray.is_empty elts))
  (* integer-to-integer conversions are never represented by [Convert] *)
  | Ap1 (Convert {src= Integer _; dst= Integer _}, _) -> assert false
  | Ap1 (Convert {src; dst}, _) ->
      assert (Llair.Typ.convertible src dst) ;
      (* [convert] returns its argument unchanged when the types are
         equivalent, so a [Convert] node between equivalent types must never
         exist *)
      assert (
        not (Llair.Typ.equivalent src dst)
        (* avoid redundant representations *) )
  | Rational {data} ->
      (* rationals with denominator 1 are represented by [Integer] *)
      assert (Q.is_real data) ;
      assert (not (Z.equal Z.one (Q.den data)))
  | _ -> ()
  [@@warning "-9"]
(** Construct *)
(* variables *)
let var v = v

(* constants *)

(** Integer constant. *)
let integer data = check invariant (Integer {data})

(** Rational constant; a denominator of 1 yields an [Integer] instead. *)
let rational data =
  let term =
    match Z.equal Z.one (Q.den data) with
    | true -> Integer {data= Q.num data}
    | false -> Rational {data}
  in
  check invariant term

let zero = integer Z.zero
let one = integer Z.one
let minus_one = integer Z.minus_one

(** Booleans are the integers produced by [Z.of_bool]. *)
let bool b = integer (Z.of_bool b)

let true_ = bool true
let false_ = bool false
let float data = check invariant (Float {data})
let label ~parent ~name = check invariant (Label {parent; name})
(* type conversions *)
(** Signed interpretation of the low [bits] bits; constants are folded. *)
let simp_signed bits = function
  | Integer {data= z} -> integer (Z.signed_extract z 0 bits)
  | arg -> Ap1 (Signed {bits}, arg)

(** Unsigned interpretation of the low [bits] bits; constants are folded. *)
let simp_unsigned bits = function
  | Integer {data= z} -> integer (Z.extract z 0 bits)
  | arg -> Ap1 (Unsigned {bits}, arg)

(** Type conversion; no simplification is performed. *)
let simp_convert src dst arg =
  let op = Convert {src; dst} in
  Ap1 (op, arg)
(* arithmetic *)
(* Sums of polynomial terms represented by multisets. A sum ∑ᵢ cᵢ × Xᵢ of
monomials Xᵢ with coefficients cᵢ is represented by a multiset where the
elements are Xᵢ with multiplicities cᵢ. A constant is treated as the
coefficient of the empty monomial, which is the unit of multiplication 1. *)
module Sum = struct
  let empty = Qset.empty

  (** [add coeff term sum] adds [coeff × term] to [sum]. Constant terms are
      folded into the coefficient of the monomial [one]; a zero term leaves
      [sum] unchanged. [coeff] must be non-zero. *)
  let add coeff term sum =
    assert (not (Q.equal Q.zero coeff)) ;
    match term with
    | Integer {data} when Z.equal Z.zero data -> sum
    | Integer {data} -> Qset.add sum one Q.(coeff * of_z data)
    | Rational {data} -> Qset.add sum one Q.(coeff * data)
    | _ -> Qset.add sum term coeff

  (** Singleton sum [coeff × term], with [coeff] defaulting to 1. *)
  let of_ ?(coeff = Q.one) term = add coeff term empty

  (** Rebuild the sum after applying [f] to every monomial. *)
  let map sum ~f =
    Qset.fold sum ~init:empty ~f:(fun e c sum -> add c (f e) sum)

  (** Multiply every coefficient by the non-zero constant [const]. *)
  let mul_const const sum =
    assert (not (Q.equal Q.zero const)) ;
    if Q.equal Q.one const then sum
    else Qset.map_counts ~f:(fun _ -> Q.mul const) sum

  (** Convert to a term, avoiding the redundant representations of the
      empty sum, of a bare constant, and of [1 × x]. *)
  let to_term sum =
    match Qset.classify sum with
    | `Zero -> zero
    | `One (arg, q) -> (
      match arg with
      | Integer {data} ->
          (* the only constant monomial is 1, so the sum is the constant q *)
          assert (Z.equal Z.one data) ;
          rational q
      | _ when Q.equal Q.one q -> arg
      | _ -> Add sum )
    | `Many -> Add sum
end
(* Products of indeterminants represented by multisets. A product ∏ᵢ xᵢ^nᵢ
of indeterminates xᵢ is represented by a multiset where the elements are
xᵢ and the multiplicities are the exponents nᵢ. *)
module Prod = struct
  let empty = Qset.empty

  (** [add term prod] multiplies [prod] by the non-constant factor [term]
      (raises its exponent by one). *)
  let add term prod =
    assert (match term with Integer _ | Rational _ -> false | _ -> true) ;
    Qset.add prod term Q.one

  (** Singleton product [term¹]. *)
  let of_ term = add term empty

  let union = Qset.union

  (** Convert to a term, avoiding the redundant representations of the
      empty product and of [x¹]. *)
  let to_term prod =
    match Qset.pop prod with
    | None -> one
    (* a single factor with exponent 1 is just that factor *)
    | Some (factor, power, prod')
      when Qset.is_empty prod' && Q.equal Q.one power ->
        factor
    | _ -> Mul prod
end
(** [simp_add_ es poly] adds every [coeff × term] pair of the multiset [es]
    to the term [poly], normalizing the result. *)
let rec simp_add_ es poly =
  (* (coeff × term) + poly *)
  let f term coeff poly =
    match (term, poly) with
    (* (0 × e) + s ==> s (optim) *)
    | _ when Q.equal Q.zero coeff -> poly
    (* (c × 0) + s ==> s (optim) *)
    | Integer {data}, _ when Z.equal Z.zero data -> poly
    (* (c × cᵢ) + cⱼ ==> c×cᵢ+cⱼ *)
    | Integer {data= i}, Integer {data= j} ->
        rational Q.((coeff * of_z i) + of_z j)
    | Rational {data= i}, Rational {data= j} -> rational Q.((coeff * i) + j)
    (* (c × ∑ᵢ cᵢ × Xᵢ) + s ==> (∑ᵢ (c × cᵢ) × Xᵢ) + s *)
    | Add args, _ -> simp_add_ (Sum.mul_const coeff args) poly
    (* (c₀ × X₀) + (∑ᵢ₌₁ⁿ cᵢ × Xᵢ) ==> ∑ᵢ₌₀ⁿ cᵢ × Xᵢ *)
    | _, Add args -> Sum.to_term (Sum.add coeff term args)
    (* (c₁ × X₁) + X₂ ==> ∑ᵢ₌₁² cᵢ × Xᵢ for c₂ = 1 *)
    | _ -> Sum.to_term (Sum.add coeff term (Sum.of_ poly))
  in
  Qset.fold ~f es ~init:poly

(** [simp_mul2 e f] is the normalized product of [e] and [f]. *)
and simp_mul2 e f =
  match (e, f) with
  (* c₁ × c₂ ==> c₁×c₂ *)
  | Integer {data= i}, Integer {data= j} -> integer (Z.mul i j)
  | Rational {data= i}, Rational {data= j} -> rational (Q.mul i j)
  (* 0 × f ==> 0 *)
  | Integer {data}, _ when Z.equal Z.zero data -> e
  (* e × 0 ==> 0 *)
  | _, Integer {data} when Z.equal Z.zero data -> f
  (* c × (∑ᵤ cᵤ × ∏ⱼ yᵤⱼ) ==> ∑ᵤ c × cᵤ × ∏ⱼ yᵤⱼ *)
  | Integer {data}, Add args | Add args, Integer {data} ->
      Sum.to_term (Sum.mul_const (Q.of_z data) args)
  | Rational {data}, Add args | Add args, Rational {data} ->
      Sum.to_term (Sum.mul_const data args)
  (* c₁ × x₁ ==> ∑ᵢ₌₁ cᵢ × xᵢ *)
  | Integer {data= c}, x | x, Integer {data= c} ->
      Sum.to_term (Sum.of_ ~coeff:(Q.of_z c) x)
  | Rational {data= c}, x | x, Rational {data= c} ->
      Sum.to_term (Sum.of_ ~coeff:c x)
  (* (∏ᵤ₌₀ⁱ xᵤ) × (∏ᵥ₌ᵢ₊₁ⁿ xᵥ) ==> ∏ⱼ₌₀ⁿ xⱼ *)
  | Mul xs1, Mul xs2 -> Prod.to_term (Prod.union xs1 xs2)
  (* (∏ᵢ xᵢ) × (∑ᵤ cᵤ × ∏ⱼ yᵤⱼ) ==> ∑ᵤ cᵤ × ∏ᵢ xᵢ × ∏ⱼ yᵤⱼ *)
  | (Mul prod as m), Add sum | Add sum, (Mul prod as m) ->
      (* [Sum.map] yields a multiset, which must be converted back to a
         term *)
      Sum.to_term
        (Sum.map sum ~f:(function
          | Mul args -> Prod.to_term (Prod.union prod args)
          | (Integer _ | Rational _) as c -> simp_mul2 c m
          | mono -> Prod.to_term (Prod.add mono prod) ))
  (* x₀ × (∏ᵢ₌₁ⁿ xᵢ) ==> ∏ᵢ₌₀ⁿ xᵢ *)
  | Mul xs1, x | x, Mul xs1 -> Prod.to_term (Prod.add x xs1)
  (* e × (∑ᵤ cᵤ × ∏ⱼ yᵤⱼ) ==> ∑ᵤ e × cᵤ × ∏ⱼ yᵤⱼ *)
  | Add args, e | e, Add args ->
      simp_add_ (Sum.map ~f:(fun m -> simp_mul2 e m) args) zero
  (* x₁ × x₂ ==> ∏ᵢ₌₁² xᵢ *)
  | _ -> Prod.to_term (Prod.add e (Prod.of_ f))
(* [simp_div x y] simplifies the quotient of [x] by [y]. The cases are tried
   in order, so every case after the first can rely on [y] not being the
   integer zero. *)
let simp_div x y =
match (x, y) with
(* e / 0 ==> e / 0 *)
(* division by the integer zero is kept symbolic *)
| _, Integer {data} when Z.equal Z.zero data -> Ap2 (Div, x, y)
(* e / 1 ==> e *)
| e, Integer {data} when Z.equal Z.one data -> e
(* e / -1 ==> -1×e *)
| e, (Integer {data} as c) when Z.equal Z.minus_one data -> simp_mul2 e c
(* i / j ==> i/j *)
| Integer {data= i}, Integer {data= j} -> integer (Z.div i j)
| Rational {data= i}, Rational {data= j} -> rational (Q.div i j)
(* (∑ᵢ cᵢ × Xᵢ) / z ==> ∑ᵢ cᵢ/z × Xᵢ *)
| Add args, Integer {data} ->
Sum.to_term (Sum.mul_const Q.(inv (of_z data)) args)
| Add args, Rational {data} ->
Sum.to_term (Sum.mul_const Q.(inv data) args)
(* x / n ==> 1/n × x *)
| _, Integer {data} -> Sum.to_term (Sum.of_ ~coeff:Q.(inv (of_z data)) x)
| _, Rational {data} -> Sum.to_term (Sum.of_ ~coeff:Q.(inv data) x)
(* x / y *)
| _ -> Ap2 (Div, x, y)
(** [simp_rem x y] simplifies the remainder of [x] by [y]; constants are
    folded only when the divisor is a non-zero integer. *)
let simp_rem x y =
  let nonzero z = not (Z.equal Z.zero z) in
  match (x, y) with
  (* i % j *)
  | Integer {data= num}, Integer {data= dvs} when nonzero dvs ->
      integer (Z.rem num dvs)
  (* (n/d) % i ==> (n / d) % i *)
  | Rational {data= q}, Integer {data= dvs} when nonzero dvs ->
      let quotient = Z.div q.num q.den in
      integer (Z.rem quotient dvs)
  (* e % 1 ==> 0 *)
  | _, Integer {data= dvs} when Z.equal Z.one dvs -> zero
  | _ -> Ap2 (Rem, x, y)
(** Sum of all the terms of a multiset, starting from zero. *)
let simp_add summands = simp_add_ summands zero

(** Binary sum. *)
let simp_add2 lhs rhs =
  let singleton = Sum.of_ lhs in
  simp_add_ singleton rhs

(** Negation, as multiplication by -1. *)
let simp_negate term = simp_mul2 minus_one term
(** [simp_sub x y] is [x - y]; constants of the same kind are folded
    directly. *)
let simp_sub x y =
  match (x, y) with
  (* i - j *)
  | Integer {data= a}, Integer {data= b} -> integer (Z.sub a b)
  | Rational {data= a}, Rational {data= b} -> rational (Q.sub a b)
  (* x - y ==> x + (-1 * y) *)
  | _ ->
      let negated = simp_negate y in
      simp_add2 x negated
(** [simp_mul es] is the product of all the factors of the multiset [es],
    each raised to its multiplicity; factors with a negative exponent are
    divided out. *)
let simp_mul es =
  (* (bas ^ pwr) × term, by repeated multiplication; terminates only when
     [pwr] is a non-negative integer *)
  let rec mul_pwr bas pwr term =
    if Q.equal Q.zero pwr then term
    else mul_pwr bas Q.(pwr - one) (simp_mul2 bas term)
  in
  Qset.fold es ~init:one ~f:(fun bas pwr term ->
      if Q.sign pwr >= 0 then mul_pwr bas pwr term
      else simp_div term (mul_pwr bas (Q.neg pwr) one) )
(* if-then-else *)
(** [simp_cond cnd thn els] is the conditional [cnd ? thn : els], resolved
    immediately when [cnd] is a constant truth value. *)
let simp_cond cnd thn els =
  match cnd with
  | Integer {data} ->
      (* (true ? t : e) ==> t *)
      if Z.is_true data then thn
      (* (false ? t : e) ==> e *)
      else if Z.is_false data then els
      else Ap3 (Conditional, cnd, thn, els)
  | _ -> Ap3 (Conditional, cnd, thn, els)
(* boolean / bitwise *)
(** [is_boolean e] is a syntactic check of whether [e] is known to be a
    boolean: 1-bit conversions and comparisons are; arithmetic, bitwise and
    conditional terms are if either operand/branch is; conjunctions and
    disjunctions are if all their members are. *)
let rec is_boolean = function
  | Ap1 ((Unsigned {bits= 1} | Convert {dst= Integer {bits= 1; _}; _}), _)
   |Ap2 ((Eq | Dq | Lt | Le), _, _) ->
      true
  | Ap2 ((Div | Rem | Xor | Shl | Lshr | Ashr), x, y)
   |Ap3 (Conditional, _, x, y) ->
      is_boolean x || is_boolean y
  | And xs | Or xs -> Set.for_all ~f:is_boolean xs
  | _ -> false
(** [simp_and2 x y] is the normalized conjunction of [x] and [y]. *)
let rec simp_and2 x y =
  match (x, y) with
  (* i && j *)
  | Integer {data= i}, Integer {data= j} -> integer (Z.logand i j)
  (* e && true ==> e *)
  | (Integer {data}, e | e, Integer {data}) when Z.is_true data -> e
  (* e && false ==> false *)
  | ((Integer {data} as f), _ | _, (Integer {data} as f))
    when Z.is_false data ->
      f
  (* e && (c ? t : f) ==> (c ? e && t : e && f) *)
  | e, Ap3 (Conditional, c, t, f) | Ap3 (Conditional, c, t, f), e ->
      simp_cond c (simp_and2 e t) (simp_and2 e f)
  (* e && e ==> e *)
  | _ when equal x y -> x
  | _ ->
      (* nested conjunctions are flattened into a single set *)
      let add s = function And cs -> Set.union s cs | c -> Set.add s c in
      And (add (add Set.empty x) y)
(** Conjunction of all the members of a set, starting from [true_]. *)
let simp_and conjuncts =
  Set.fold conjuncts ~init:true_ ~f:(fun a b -> simp_and2 a b)
(** [simp_or2 x y] is the normalized disjunction of [x] and [y]. *)
let rec simp_or2 x y =
  match (x, y) with
  (* i || j *)
  | Integer {data= i}, Integer {data= j} -> integer (Z.logor i j)
  (* e || true ==> true *)
  | ((Integer {data} as t), _ | _, (Integer {data} as t))
    when Z.is_true data ->
      t
  (* e || false ==> e *)
  | (Integer {data}, e | e, Integer {data}) when Z.is_false data -> e
  (* e || (c ? t : f) ==> (c ? e || t : e || f) *)
  | e, Ap3 (Conditional, c, t, f) | Ap3 (Conditional, c, t, f), e ->
      simp_cond c (simp_or2 e t) (simp_or2 e f)
  (* e || e ==> e *)
  | _ when equal x y -> x
  | _ ->
      (* nested disjunctions are flattened into a single set *)
      let add s = function Or cs -> Set.union s cs | c -> Set.add s c in
      Or (add (add Set.empty x) y)
(** Disjunction of all the members of a set, starting from [false_]. *)
let simp_or disjuncts =
  Set.fold disjuncts ~init:false_ ~f:(fun a b -> simp_or2 a b)
(* sequence sizes *)
(** Size of a sequence term: the size component of a [Sized], the length
    component of an [Extract], or the sum of the sizes of the parts of a
    [Concat].
    @raise Invalid_argument on any other term. *)
let rec seq_size_exn term =
  match term with
  | Ap2 (Sized, size, _) -> size
  | Ap3 (Extract, _, _, size) -> size
  | ApN (Concat, parts) ->
      IArray.fold parts ~init:zero ~f:(fun fst snd ->
          simp_add2 fst (seq_size_exn snd) )
  | _ -> invalid_arg "seq_size_exn"
(** Option-returning variant of [seq_size_exn]. *)
let seq_size e =
  match seq_size_exn e with
  | size -> Some size
  | exception Invalid_argument _ -> None
(* sequences *)
(** The empty sequence: a concatenation of no parts. *)
let empty_seq =
  let no_parts = IArray.of_array [||] in
  ApN (Concat, no_parts)

let simp_splat byt = Ap1 (Splat, byt)

let simp_sized siz arr =
  let unsimplified () = Ap2 (Sized, siz, arr) in
  match seq_size arr with
  (* ⟨n,α⟩ ==> α when n ≡ |α| *)
  | Some known -> if equal siz known then arr else unsimplified ()
  | None -> unsimplified ()
(** Result of a comparison that may be undecidable syntactically. *)
type pcmp = Lt | Eq | Gt | Unknown

(** [partial_compare x y] orders [x] and [y] by the sign of the simplified
    difference [x - y], when that difference is a constant. *)
let partial_compare x y : pcmp =
  let of_sign (s : int) : pcmp =
    match Int.sign s with Neg -> Lt | Zero -> Eq | Pos -> Gt
  in
  match simp_sub x y with
  | Integer {data} -> of_sign (Z.sign data)
  | Rational {data} -> of_sign (Q.sign data)
  | _ -> Unknown

(** [partial_ge x y] holds only when [x ≥ y] is known. *)
let partial_ge x y =
  match partial_compare x y with
  | Lt | Unknown -> false
  | Eq | Gt -> true
(** [simp_extract seq off len] simplifies the subsequence of [seq] starting
    at offset [off] with length [len]. *)
let rec simp_extract seq off len =
  [%Trace.call fun {pf} -> pf "%a" pp (Ap3 (Extract, seq, off, len))]
  ;
  (* _[_,0) ==> ⟨⟩ *)
  ( if equal len zero then empty_seq
  else
    let o_l = simp_add2 off len in
    match seq with
    (* α[m,k)[o,l) ==> α[m+o,l) when k ≥ o+l *)
    | Ap3 (Extract, a, m, k) when partial_ge k o_l ->
        simp_extract a (simp_add2 m off) len
    (* ⟨n,E^⟩[o,l) ==> ⟨l,E^⟩ when n ≥ o+l *)
    | Ap2 (Sized, n, (Ap1 (Splat, _) as e)) when partial_ge n o_l ->
        simp_sized len e
    (* ⟨n,a⟩[0,n) ==> ⟨n,a⟩ *)
    | Ap2 (Sized, n, _) when equal off zero && equal n len -> seq
    (* For (α₀^α₁)[o,l) there are 3 cases:
     *
     * ⟨...⟩^⟨...⟩
     *  [,)
     * o < o+l ≤ |α₀| : (α₀^α₁)[o,l) ==> α₀[o,l) ^ α₁[0,0)
     *
     * ⟨...⟩^⟨...⟩
     *   [  ,  )
     * o ≤ |α₀| < o+l : (α₀^α₁)[o,l) ==> α₀[o,|α₀|-o) ^ α₁[0,l-(|α₀|-o))
     *
     * ⟨...⟩^⟨...⟩
     *        [,)
     * |α₀| ≤ o : (α₀^α₁)[o,l) ==> α₀[o,0) ^ α₁[o-|α₀|,l)
     *
     * So in general:
     *
     * (α₀^α₁)[o,l) ==> α₀[o,l₀) ^ α₁[o₁,l-l₀)
     * where l₀ = max 0 (min l |α₀|-o)
     *       o₁ = max 0 o-|α₀|
     *)
    | ApN (Concat, na1N) -> (
      match len with
      | Integer {data= l} ->
          (* thread the remaining length and the offset into the next part
             through the parts; give up as soon as an offset is not a
             constant *)
          IArray.fold_map_until na1N ~init:(l, off)
            ~f:(fun (l, oI) naI ->
              let nI = seq_size_exn naI in
              if Z.equal Z.zero l then
                Continue ((l, oI), simp_extract naI oI zero)
              else
                let oI_nI = simp_sub oI nI in
                match oI_nI with
                | Integer {data} ->
                    let oJ = if Z.sign data <= 0 then zero else oI_nI in
                    let lI = Z.(max zero (min l (neg data))) in
                    let l = Z.(l - lI) in
                    Continue ((l, oJ), simp_extract naI oI (integer lI))
                | _ -> Stop (Ap3 (Extract, seq, off, len)) )
            ~finish:(fun (_, e1N) -> simp_concat e1N)
      | _ -> Ap3 (Extract, seq, off, len) )
    (* α[o,l) *)
    | _ -> Ap3 (Extract, seq, off, len) )
  |>
  [%Trace.retn fun {pf} -> pf "%a" pp]

(** [simp_concat xs] simplifies the concatenation of the sequences [xs]:
    nested concatenations are flattened, adjacent parts are merged where
    possible, and a single remaining part is returned as is. *)
and simp_concat xs =
  [%Trace.call fun {pf} -> pf "%a" pp (ApN (Concat, xs))]
  ;
  (* (α^(β^γ)^δ) ==> (α^β^γ^δ) *)
  let flatten xs =
    let exists_sub_Concat =
      IArray.exists ~f:(function ApN (Concat, _) -> true | _ -> false)
    in
    let concat_sub_Concat xs =
      IArray.concat
        (IArray.fold_right xs ~init:[] ~f:(fun x s ->
             match x with
             | ApN (Concat, ys) -> ys :: s
             | x -> IArray.of_array [|x|] :: s ))
    in
    if exists_sub_Concat xs then concat_sub_Concat xs else xs
  in
  let simp_adjacent e f =
    match (e, f) with
    (* ⟨n,a⟩[o,k)^⟨n,a⟩[o+k,l) ==> ⟨n,a⟩[o,k+l) when n ≥ o+k+l *)
    | ( Ap3 (Extract, (Ap2 (Sized, n, _) as na), o, k)
      , Ap3 (Extract, na', o_k, l) )
      when equal na na'
           && equal o_k (simp_add2 o k)
           && partial_ge n (simp_add2 o_k l) ->
        Some (simp_extract na o (simp_add2 k l))
    (* ⟨m,E^⟩^⟨n,E^⟩ ==> ⟨m+n,E^⟩ *)
    | Ap2 (Sized, m, (Ap1 (Splat, _) as a)), Ap2 (Sized, n, a')
      when equal a a' ->
        Some (simp_sized (simp_add2 m n) a)
    | _ -> None
  in
  let xs = flatten xs in
  let xs = IArray.combine_adjacent ~f:simp_adjacent xs in
  (if IArray.length xs = 1 then IArray.get xs 0 else ApN (Concat, xs))
  |>
  [%Trace.retn fun {pf} -> pf "%a" pp]
(* comparison *)
(** Strict less-than; constants of the same kind are decided. *)
let simp_lt x y =
  match (x, y) with
  | Integer {data= a}, Integer {data= b} -> bool (Z.lt a b)
  | Rational {data= a}, Rational {data= b} -> bool (Q.lt a b)
  | _ -> Ap2 (Lt, x, y)

(** Less-than-or-equal; constants of the same kind are decided. *)
let simp_le x y =
  match (x, y) with
  | Integer {data= a}, Integer {data= b} -> bool (Z.leq a b)
  | Rational {data= a}, Rational {data= b} -> bool (Q.leq a b)
  | _ -> Ap2 (Le, x, y)

(** [Ord] and [Uno] are not simplified. *)
let simp_ord lhs rhs = Ap2 (Ord, lhs, rhs)

let simp_uno lhs rhs = Ap2 (Uno, lhs, rhs)
(** [simp_eq x y] is the normalized equality of [x] and [y]. The operands
    are first put in increasing [compare] order, so the result does not
    depend on the order of the arguments. *)
let rec simp_eq x y =
  match
    match Ordering.of_int (compare x y) with
    | Equal -> None
    | Less -> Some (x, y)
    | Greater -> Some (y, x)
  with
  (* e = e ==> true *)
  | None -> bool true
  | Some (x, y) -> (
    match (x, y) with
    (* i = j ==> false when i ≠ j *)
    | Integer _, Integer _ | Rational _, Rational _ -> bool false
    (* b = false ==> ¬b *)
    | b, Integer {data} when Z.is_false data && is_boolean b -> simp_not b
    (* b = true ==> b *)
    | b, Integer {data} when Z.is_true data && is_boolean b -> b
    (* e = (c ? t : f) ==> (c ? e = t : e = f) *)
    | e, Ap3 (Conditional, c, t, f) | Ap3 (Conditional, c, t, f), e ->
        simp_cond c (simp_eq e t) (simp_eq e f)
    (* α^β^δ = α^γ^δ ==> β = γ *)
    | ApN (Concat, a), ApN (Concat, b) ->
        let m = IArray.length a in
        let n = IArray.length b in
        let min_len = if m < n then m else n in
        let length_common_prefix =
          (* bounded by the shorter length, so that the scan stays within
             both arrays when one side is a prefix of the other *)
          let rec find_lcp i =
            if i < min_len && equal (IArray.get a i) (IArray.get b i) then
              find_lcp (i + 1)
            else i
          in
          find_lcp 0
        in
        let length_common_suffix =
          (* bounded so that the suffix never overlaps the common prefix *)
          let max_suffix = min_len - length_common_prefix in
          let rec find_lcs i =
            if
              i < max_suffix
              && equal (IArray.get a (m - 1 - i)) (IArray.get b (n - 1 - i))
            then find_lcs (i + 1)
            else i
          in
          find_lcs 0
        in
        let length_common = length_common_prefix + length_common_suffix in
        if length_common = 0 then Ap2 (Eq, x, y)
        else
          let pos = length_common_prefix in
          let a = IArray.sub ~pos ~len:(m - length_common) a in
          let b = IArray.sub ~pos ~len:(n - length_common) b in
          simp_eq (simp_concat a) (simp_concat b)
    | ( (Ap2 (Sized, _, _) | Ap3 (Extract, _, _, _) | ApN (Concat, _))
      , (Ap2 (Sized, _, _) | Ap3 (Extract, _, _, _) | ApN (Concat, _)) ) ->
        Ap2 (Eq, x, y)
    (* x = α ==> ⟨x,|α|⟩ = α *)
    | ( x
      , ((Ap2 (Sized, _, _) | Ap3 (Extract, _, _, _) | ApN (Concat, _)) as a)
      )
     |( ((Ap2 (Sized, _, _) | Ap3 (Extract, _, _, _) | ApN (Concat, _)) as a)
      , x ) ->
        simp_eq (Ap2 (Sized, seq_size_exn a, x)) a
    | x, y -> Ap2 (Eq, x, y) )

(** [simp_dq x y] is the normalized disequality of [x] and [y]. *)
and simp_dq x y =
  match (x, y) with
  (* e ≠ (c ? t : f) ==> (c ? e ≠ t : e ≠ f) *)
  | e, Ap3 (Conditional, c, t, f) | Ap3 (Conditional, c, t, f), e ->
      simp_cond c (simp_dq e t) (simp_dq e f)
  | _ -> (
    match simp_eq x y with
    | Ap2 (Eq, x, y) -> Ap2 (Dq, x, y)
    | b -> simp_not b )

(** [simp_not term] is the negation of [term], pushed inwards. *)
(* negation-normal form *)
and simp_not term =
  match term with
  (* ¬(x = y) ==> x ≠ y *)
  | Ap2 (Eq, x, y) -> simp_dq x y
  (* ¬(x ≠ y) ==> x = y *)
  | Ap2 (Dq, x, y) -> simp_eq x y
  (* ¬(x < y) ==> y <= x *)
  | Ap2 (Lt, x, y) -> simp_le y x
  (* ¬(x <= y) ==> y < x *)
  | Ap2 (Le, x, y) -> simp_lt y x
  (* ¬(x ≠ nan ∧ y ≠ nan) ==> x = nan ∨ y = nan *)
  | Ap2 (Ord, x, y) -> simp_uno x y
  (* ¬(x = nan ∨ y = nan) ==> x ≠ nan ∧ y ≠ nan *)
  | Ap2 (Uno, x, y) -> simp_ord x y
  (* ¬(a ∧ b) ==> ¬a ∨ ¬b *)
  | And xs -> simp_or (Set.map ~f:simp_not xs)
  (* ¬(a ∨ b) ==> ¬a ∧ ¬b *)
  | Or xs -> simp_and (Set.map ~f:simp_not xs)
  (* ¬¬e ==> e *)
  | Ap2 (Xor, Integer {data}, e) when Z.is_true data -> e
  | Ap2 (Xor, e, Integer {data}) when Z.is_true data -> e
  (* ¬(c ? t : e) ==> c ? ¬t : ¬e *)
  | Ap3 (Conditional, cnd, thn, els) ->
      simp_cond cnd (simp_not thn) (simp_not els)
  (* ¬i ==> -i-1 *)
  | Integer {data} -> integer (Z.lognot data)
  (* ¬e ==> true xor e *)
  | e -> Ap2 (Xor, true_, e)
(* bitwise *)
(** Exclusive or; constants are folded, [true xor b] becomes the negation
    of a boolean [b], and [e xor e] is zero. *)
let simp_xor x y =
  let is_true_const = function
    | Integer {data} -> Z.is_true data
    | _ -> false
  in
  match (x, y) with
  (* i xor j *)
  | Integer {data= a}, Integer {data= b} -> integer (Z.logxor a b)
  (* true xor b ==> ¬b *)
  | (Integer _ as c), b when is_true_const c && is_boolean b -> simp_not b
  | b, (Integer _ as c) when is_true_const c && is_boolean b -> simp_not b
  (* e xor e ==> 0 *)
  | _ -> if equal x y then zero else Ap2 (Xor, x, y)
(** Left shift; folded for constants with a non-negative shift amount. *)
let simp_shl x y =
  match (x, y) with
  (* i shl j *)
  | Integer {data= value}, Integer {data= amount} when Z.sign amount >= 0 ->
      let places = Z.to_int amount in
      integer (Z.shift_left value places)
  (* e shl 0 ==> e *)
  | operand, Integer {data= amount} when Z.equal Z.zero amount -> operand
  | _ -> Ap2 (Shl, x, y)
(** Logical right shift; folded (with [Z.shift_right_trunc]) for constants
    with a non-negative shift amount. *)
let simp_lshr x y =
  match (x, y) with
  (* i lshr j *)
  | Integer {data= value}, Integer {data= amount} when Z.sign amount >= 0 ->
      let places = Z.to_int amount in
      integer (Z.shift_right_trunc value places)
  (* e lshr 0 ==> e *)
  | operand, Integer {data= amount} when Z.equal Z.zero amount -> operand
  | _ -> Ap2 (Lshr, x, y)
(** Arithmetic right shift; folded (with [Z.shift_right]) for constants
    with a non-negative shift amount. *)
let simp_ashr x y =
  match (x, y) with
  (* i ashr j *)
  | Integer {data= value}, Integer {data= amount} when Z.sign amount >= 0 ->
      let places = Z.to_int amount in
      integer (Z.shift_right value places)
  (* e ashr 0 ==> e *)
  | operand, Integer {data= amount} when Z.equal Z.zero amount -> operand
  | _ -> Ap2 (Ashr, x, y)
(* records *)
(* Record operations are constructed without simplification. *)
let simp_record fields = ApN (Record, fields)

let simp_select idx rcd =
  let op = Select idx in
  Ap1 (op, rcd)

let simp_update idx rcd elt =
  let op = Update idx in
  Ap2 (op, rcd, elt)

let simp_rec_record index = RecRecord index
(* dispatching for normalization and invariant checking *)
(** Apply the simplifier of a unary operator, then check the invariant. *)
let norm1 op x =
  let simplified =
    match op with
    | Select idx -> simp_select idx x
    | Splat -> simp_splat x
    | Convert {src; dst} -> simp_convert src dst x
    | Unsigned {bits} -> simp_unsigned bits x
    | Signed {bits} -> simp_signed bits x
  in
  check invariant simplified
(** Apply the simplifier of a binary operator, then check the invariant. *)
let norm2 op x y =
  let simplified =
    match op with
    (* comparisons *)
    | Eq -> simp_eq x y
    | Dq -> simp_dq x y
    | Lt -> simp_lt x y
    | Le -> simp_le x y
    | Ord -> simp_ord x y
    | Uno -> simp_uno x y
    (* arithmetic *)
    | Div -> simp_div x y
    | Rem -> simp_rem x y
    (* bitwise *)
    | Xor -> simp_xor x y
    | Shl -> simp_shl x y
    | Lshr -> simp_lshr x y
    | Ashr -> simp_ashr x y
    (* sequences and records *)
    | Sized -> simp_sized x y
    | Update idx -> simp_update idx x y
  in
  check invariant simplified
(** Apply the simplifier of a ternary operator, then check the invariant. *)
let norm3 op x y z =
  let simplified =
    match op with
    | Extract -> simp_extract x y z
    | Conditional -> simp_cond x y z
  in
  check invariant simplified
(** Apply the simplifier of an n-ary operator, then check the invariant. *)
let normN op xs =
  let simplified =
    match op with
    | Record -> simp_record xs
    | Concat -> simp_concat xs
  in
  check invariant simplified
(* exposed interface *)
(* Public constructors: each one normalizes its result and checks the
   invariant. *)

let signed bits term = norm1 (Signed {bits}) term
let unsigned bits term = norm1 (Unsigned {bits}) term

(** Conversion between equivalent types is the identity. *)
let convert src ~to_:dst arg =
  match Llair.Typ.equivalent src dst with
  | true -> arg
  | false -> norm1 (Convert {src; dst}) arg

let eq x y = norm2 Eq x y
let dq x y = norm2 Dq x y
let lt x y = norm2 Lt x y
let le x y = norm2 Le x y
let ord x y = norm2 Ord x y
let uno x y = norm2 Uno x y
let neg e = check invariant (simp_negate e)
let add e f = check invariant (simp_add2 e f)
let addN args = check invariant (simp_add args)
let sub e f = check invariant (simp_sub e f)
let mul e f = check invariant (simp_mul2 e f)

(** Multiplication by a rational constant. *)
let mulq q e =
  let coefficient = rational q in
  mul coefficient e

let mulN args = check invariant (simp_mul args)
let div x y = norm2 Div x y
let rem x y = norm2 Rem x y
let and_ e f = check invariant (simp_and2 e f)
let or_ e f = check invariant (simp_or2 e f)
let andN es = check invariant (simp_and es)
let orN es = check invariant (simp_or es)
let not_ e = check invariant (simp_not e)
let xor x y = norm2 Xor x y
let shl x y = norm2 Shl x y
let lshr x y = norm2 Lshr x y
let ashr x y = norm2 Ashr x y
let conditional ~cnd ~thn ~els = norm3 Conditional cnd thn els
let splat byt = norm1 Splat byt

(** Note that the size is the first operand of [Sized]. *)
let sized ~seq ~siz = norm2 Sized siz seq

let extract ~seq ~off ~len = norm3 Extract seq off len

let concat xs =
  let parts = IArray.of_array xs in
  normN Concat parts

let record elts = normN Record elts
let select ~rcd ~idx = norm1 (Select idx) rcd
let update ~rcd ~idx ~elt = norm2 (Update idx) rcd elt
let rec_record i = check invariant (simp_rec_record i)
(* [binary mk x y] translates both operand expressions and combines them
   with the term constructor [mk]. *)
let rec binary mk x y = mk (of_exp x) (of_exp y)
(* [ubinary mk typ x y] is like [binary], but first wraps each translated
   operand in [unsigned] at the bit size of [typ]. *)
and ubinary mk typ x y =
let unsigned typ = unsigned (Llair.Typ.bit_size_of typ) in
mk (unsigned typ (of_exp x)) (unsigned typ (of_exp y))
(* [of_exp e] translates an [Llair.Exp.t] into a normalized term. *)
and of_exp e =
match (e : Llair.Exp.t) with
(* registers become variables with id -1 when global and 0 otherwise *)
| Reg {name; global; typ= _} -> Var {name; id= (if global then -1 else 0)}
| Label {parent; name} -> label ~parent ~name
| Integer {data; typ= _} -> integer data
| Float {data; typ= _} -> float data
| Ap1 (Signed {bits}, _, x) -> signed bits (of_exp x)
| Ap1 (Unsigned {bits}, _, x) -> unsigned bits (of_exp x)
| Ap1 (Convert {src}, dst, exp) -> convert src ~to_:dst (of_exp exp)
| Ap2 (Eq, _, x, y) -> binary eq x y
| Ap2 (Dq, _, x, y) -> binary dq x y
(* terms have no Gt/Ge: these are expressed as Lt/Le with operands swapped *)
| Ap2 (Gt, _, x, y) -> binary lt y x
| Ap2 (Ge, _, x, y) -> binary le y x
| Ap2 (Lt, _, x, y) -> binary lt x y
| Ap2 (Le, _, x, y) -> binary le x y
(* unsigned comparisons go through [ubinary], again swapping for Ugt/Uge *)
| Ap2 (Ugt, typ, x, y) -> ubinary lt typ y x
| Ap2 (Uge, typ, x, y) -> ubinary le typ y x
| Ap2 (Ult, typ, x, y) -> ubinary lt typ x y
| Ap2 (Ule, typ, x, y) -> ubinary le typ x y
| Ap2 (Ord, _, x, y) -> binary ord x y
| Ap2 (Uno, _, x, y) -> binary uno x y
| Ap2 (Add, _, x, y) -> binary add x y
| Ap2 (Sub, _, x, y) -> binary sub x y
| Ap2 (Mul, _, x, y) -> binary mul x y
| Ap2 (Div, _, x, y) -> binary div x y
| Ap2 (Rem, _, x, y) -> binary rem x y
| Ap2 (Udiv, typ, x, y) -> ubinary div typ x y
| Ap2 (Urem, typ, x, y) -> ubinary rem typ x y
| Ap2 (And, _, x, y) -> binary and_ x y
| Ap2 (Or, _, x, y) -> binary or_ x y
| Ap2 (Xor, _, x, y) -> binary xor x y
| Ap2 (Shl, _, x, y) -> binary shl x y
| Ap2 (Lshr, _, x, y) -> binary lshr x y
| Ap2 (Ashr, _, x, y) -> binary ashr x y
| Ap3 (Conditional, _, cnd, thn, els) ->
conditional ~cnd:(of_exp cnd) ~thn:(of_exp thn) ~els:(of_exp els)
| Ap1 (Splat, _, byt) -> splat (of_exp byt)
| ApN (Record, _, elts) -> record (IArray.map ~f:of_exp elts)
| Ap1 (Select idx, _, rcd) -> select ~rcd:(of_exp rcd) ~idx
| Ap2 (Update idx, _, rcd, elt) ->
update ~rcd:(of_exp rcd) ~idx ~elt:(of_exp elt)
| RecRecord (i, _) -> rec_record i
(** Variables are the terms constructed by [Var] *)
module Var = struct
  include T

  let pp = pp

  (** Classification of a variable used to choose its printing style. *)
  type strength = t -> [`Universal | `Existential | `Anonymous] option

  (** A variable must be a [Var] term that satisfies the term invariant. *)
  let invariant x =
    let@ () = Invariant.invariant [%here] x [%sexp_of: t] in
    match x with Var _ -> invariant x | _ -> assert false

  let id = function Var v -> v.id | x -> violates invariant x
  let name = function Var v -> v.name | x -> violates invariant x

  (** @raise Invalid_argument if the term is not a [Var]. *)
  let of_ = function Var _ as v -> v | _ -> invalid_arg "Var.of_"

  (** Option-returning variant of [of_]. *)
  let of_term = function
    | Var _ as v -> Some (v |> check invariant)
    | _ -> None

  (** Variable for a register, obtained by translating it as an expression. *)
  let of_reg r =
    match of_term (of_exp (r : Llair.Reg.t :> Llair.Exp.t)) with
    | Some v -> v
    | _ -> violates Llair.Reg.invariant r

  (** Program variable: id -1 when global and 0 otherwise. *)
  let program ~name ~global = Var {name; id= (if global then -1 else 0)}

  (** [fresh name ~wrt] is a variable whose id is one more than the id of
      the maximum element of [wrt] (or 1 if [wrt] is empty), together with
      [wrt] extended with that variable. *)
  let fresh name ~wrt =
    let max = match Set.max_elt wrt with None -> 0 | Some max -> id max in
    let x' = Var {name; id= max + 1} in
    (x', Set.add wrt x')

  let identified ~name ~id = Var {name; id}

  module Map = Map

  module Set = struct
    include Set

    let pp vs = Set.pp pp_t vs
    let ppx strength vs = Set.pp (ppx strength) vs

    (** Print a non-empty set as an existential quantifier prefix; print
        nothing for the empty set. *)
    let pp_xs fs xs =
      if not (is_empty xs) then
        Format.fprintf fs "@<2>∃ @[%a@] .@;<1 2>" pp xs

    let of_regs =
      Llair.Reg.Set.fold ~init:empty ~f:(fun s r -> add s (of_reg r))
  end
end
(** Destruct *)
(** The value of an integer constant, if the term is one. *)
let d_int term =
  match term with
  | Integer {data= z} -> Some z
  | _ -> None
(** Access *)
(** The multiplicity of [one] in a sum, i.e. its constant summand; [None]
    for terms that are not sums. *)
let const_of term =
  match term with
  | Add summands -> Some (Qset.count summands one)
  | _ -> None
(** Transform *)
(** [map e ~f] applies [f] to the immediate subterms of [e] and rebuilds
    the term through the normalizing constructors. When [f] returns every
    subterm physically unchanged, [e] itself is returned. *)
let map e ~f =
  let map1 op ~f x =
    let x' = f x in
    if x' == x then e else norm1 op x'
  in
  let map2 op ~f x y =
    let x' = f x in
    let y' = f y in
    if x' == x && y' == y then e else norm2 op x' y'
  in
  let map3 op ~f x y z =
    let x' = f x in
    let y' = f y in
    let z' = f z in
    if x' == x && y' == y && z' == z then e else norm3 op x' y' z'
  in
  let mapN op ~f xs =
    let xs' = IArray.map_endo ~f xs in
    if xs' == xs then e else normN op xs'
  in
  let map_set mk ~f args =
    let args' = Set.map ~f args in
    if args' == args then e else mk args'
  in
  let map_qset mk ~f args =
    (* multiplicities are preserved; only the elements are mapped *)
    let args' = Qset.map ~f:(fun arg q -> (f arg, q)) args in
    if args' == args then e else mk args'
  in
  match e with
  | And args -> map_set andN ~f args
  | Or args -> map_set orN ~f args
  | Add args -> map_qset addN ~f args
  | Mul args -> map_qset mulN ~f args
  | Ap1 (op, x) -> map1 op ~f x
  | Ap2 (op, x, y) -> map2 op ~f x y
  | Ap3 (op, x, y, z) -> map3 op ~f x y z
  | ApN (op, xs) -> mapN op ~f xs
  (* leaves have no subterms *)
  | Var _ | Label _ | Float _ | Integer _ | Rational _ | RecRecord _ -> e
(** [fold_map e ~init ~f] maps [f] over the immediate subterms of [e] (via
    [map]) while threading an accumulator that starts at [init]. [f]
    receives the current accumulator and a subterm, and returns the updated
    accumulator with the replacement subterm. The result is the final
    accumulator paired with the rebuilt term. *)
let fold_map e ~init ~f =
  (* [map] has no accumulator of its own, so the state lives in a local
     reference updated by the wrapper around [f]. *)
  let s = ref init in
  let f x =
    let s', x' = f !s x in
    s := s' ;
    x'
  in
  let e' = map e ~f in
  (!s, e')
(** [map_rec_pre e ~f] rewrites [e] top-down: when [f e] is [Some e'] the
    result is [e'] and its subterms are not visited; when it is [None] the
    immediate subterms of [e] are rewritten recursively through [map]. *)
let rec map_rec_pre e ~f =
  match f e with
  | Some replacement -> replacement
  | None -> map e ~f:(fun sub -> map_rec_pre sub ~f)
(** [fold_map_rec_pre e ~init ~f] is the accumulating counterpart of
    [map_rec_pre]: when [f init e] is [Some (s, e')] that pair is the
    result; when it is [None] the immediate subterms of [e] are processed
    recursively through [fold_map], threading the accumulator. *)
let rec fold_map_rec_pre e ~init:s ~f =
  match f s e with
  | Some result -> result
  | None ->
      let step acc sub = fold_map_rec_pre sub ~init:acc ~f in
      fold_map e ~init:s ~f:step
(** [disjuncts e] returns the list of disjuncts of [e]:
    - an [Or] is flattened recursively into the union of the disjuncts of
      its members;
    - a [Conditional] [(cnd, thn, els)] contributes the two terms
      [cnd ∧ thn] and [¬cnd ∧ els], each rebuilt from the disjuncts of its
      parts;
    - any other term is its own single disjunct. *)
let disjuncts e =
  let rec disjuncts_ e =
    match e with
    | Or es ->
        (* NOTE(review): [Set.pop_exn] presumably raises on an empty set;
           this relies on [Or] never holding an empty set — confirm. *)
        let e0, e1N = Set.pop_exn es in
        Set.fold e1N ~init:(disjuncts_ e0) ~f:(fun cs e ->
            Set.union cs (disjuncts_ e) )
    | Ap3 (Conditional, cnd, thn, els) ->
        Set.add
          (Set.of_ (and_ (orN (disjuncts_ cnd)) (orN (disjuncts_ thn))))
          (and_ (orN (disjuncts_ (not_ cnd))) (orN (disjuncts_ els)))
    | _ -> Set.of_ e
  in
  Set.elements (disjuncts_ e)
(** [rename f e] replaces every variable [v] occurring in [e] by [f v],
    using [map_rec_pre]; all other terms are traversed and rebuilt. *)
let rename f e =
  let subst_var = function Var _ as v -> Some (f v) | _ -> None in
  map_rec_pre e ~f:subst_var
(** Traverse *)
(** [iter e ~f] applies [f] to each immediate subterm of [e]. Atomic terms
    have no subterms, so nothing is called for them. *)
let iter e ~f =
  let on_counted arg _ = f arg in
  match e with
  | Var _ | Label _ | Float _ | Integer _ | Rational _ | RecRecord _ -> ()
  | Ap1 (_, a) -> f a
  | Ap2 (_, a, b) ->
      f a ;
      f b
  | Ap3 (_, a, b, c) ->
      f a ;
      f b ;
      f c
  | ApN (_, subterms) -> IArray.iter ~f subterms
  | And members | Or members -> Set.iter ~f members
  | Add counted | Mul counted -> Qset.iter ~f:on_counted counted
(** [exists e ~f] holds when [f] holds of at least one immediate subterm
    of [e]. Evaluation short-circuits on the first subterm satisfying [f].
    Atomic terms have no subterms, so the result for them is [false]. *)
let exists e ~f =
  match e with
  | Ap1 (_, x) -> f x
  | Ap2 (_, x, y) -> f x || f y
  | Ap3 (_, x, y, z) -> f x || f y || f z
  | ApN (_, xs) -> IArray.exists ~f xs
  | And args | Or args -> Set.exists ~f args
  | Add args | Mul args -> Qset.exists ~f:(fun arg _ -> f arg) args
  | Var _ | Label _ | Float _ | Integer _ | Rational _ | RecRecord _ ->
      false
(** [for_all e ~f] holds when [f] holds of every immediate subterm of [e].
    Atomic terms have no subterms, so the result for them is [true]. *)
let for_all e ~f =
  let on_counted arg _ = f arg in
  match e with
  | Var _ | Label _ | Float _ | Integer _ | Rational _ | RecRecord _ -> true
  | Ap1 (_, a) -> f a
  | Ap2 (_, a, b) -> f a && f b
  | Ap3 (_, a, b, c) -> f a && f b && f c
  | ApN (_, subterms) -> IArray.for_all ~f subterms
  | And members | Or members -> Set.for_all ~f members
  | Add counted | Mul counted -> Qset.for_all ~f:on_counted counted
(** [fold e ~init ~f] threads an accumulator, starting at [init], through
    the immediate subterms of [e] by calling [f subterm acc]. Fixed-arity
    applications are visited left to right; containers are visited in the
    order of their own [fold]. Atomic terms return [init] unchanged. *)
let fold e ~init:s ~f =
  match e with
  | Var _ | Label _ | Float _ | Integer _ | Rational _ | RecRecord _ -> s
  | Ap1 (_, a) -> f a s
  | Ap2 (_, a, b) ->
      let s = f a s in
      f b s
  | Ap3 (_, a, b, c) ->
      let s = f a s in
      let s = f b s in
      f c s
  | ApN (_, subterms) ->
      IArray.fold ~f:(fun acc sub -> f sub acc) subterms ~init:s
  | And members | Or members ->
      Set.fold ~f:(fun acc sub -> f sub acc) members ~init:s
  | Add counted | Mul counted ->
      Qset.fold ~f:(fun sub _ acc -> f sub acc) counted ~init:s
(** [iter_terms e ~f] applies [f] to every subterm of [e], including [e]
    itself, in post-order: all subterms of a term are visited before [f]
    is called on the term. *)
let rec iter_terms e ~f =
  let visit sub = iter_terms ~f sub in
  ( match e with
  | Var _ | Label _ | Float _ | Integer _ | Rational _ | RecRecord _ -> ()
  | Ap1 (_, a) -> visit a
  | Ap2 (_, a, b) ->
      visit a ;
      visit b
  | Ap3 (_, a, b, c) ->
      visit a ;
      visit b ;
      visit c
  | ApN (_, subterms) -> IArray.iter ~f:visit subterms
  | And members | Or members -> Set.iter members ~f:visit
  | Add counted | Mul counted ->
      Qset.iter counted ~f:(fun sub _ -> visit sub) ) ;
  f e
(** [fold_terms e ~init ~f] threads an accumulator through every subterm
    of [e], including [e] itself, in post-order: the accumulator is first
    folded through all subterms of a term, and only then [f acc term] is
    applied to the term. *)
let rec fold_terms e ~init:s ~f =
  (* Positional wrapper around the recursive call, so that the arms below
     can nest calls accumulator-last. *)
  let fold_terms f e s = fold_terms e ~init:s ~f in
  let s =
    match e with
    | Ap1 (_, x) -> fold_terms f x s
    | Ap2 (_, x, y) -> fold_terms f y (fold_terms f x s)
    | Ap3 (_, x, y, z) -> fold_terms f z (fold_terms f y (fold_terms f x s))
    | ApN (_, xs) -> IArray.fold ~f:(fun s x -> fold_terms f x s) xs ~init:s
    | And args | Or args ->
        Set.fold args ~init:s ~f:(fun s x -> fold_terms f x s)
    | Add args | Mul args ->
        Qset.fold args ~init:s ~f:(fun arg _ s -> fold_terms f arg s)
    | Var _ | Label _ | Float _ | Integer _ | Rational _ | RecRecord _ -> s
  in
  f s e
(** [iter_vars e ~f] applies [f] to every occurrence of a variable in [e],
    in the order in which [iter_terms] visits them. *)
let iter_vars e ~f =
  let on_term term =
    match term with Var _ as v -> f (v :> Var.t) | _ -> ()
  in
  iter_terms e ~f:on_term
(** [exists_vars e ~f] holds when [f] holds of some variable occurring in
    [e]. The traversal is abandoned, through [with_return], as soon as
    such a variable is found. *)
let exists_vars e ~f =
  with_return (fun {return} ->
      let stop_if_found v = if f v then return true in
      iter_vars e ~f:stop_if_found ;
      false )
(** [fold_vars e ~init ~f] threads an accumulator through every occurrence
    of a variable in [e], calling [f acc var]; terms that are not
    variables leave the accumulator unchanged. *)
let fold_vars e ~init ~f =
  let step acc term =
    match term with Var _ as v -> f acc (v :> Var.t) | _ -> acc
  in
  fold_terms e ~init ~f:step
(** Query *)
(** [fv e] is the set of variables occurring in [e]. *)
let fv e = fold_vars e ~init:Var.Set.empty ~f:(fun vars v -> Set.add vars v)
(** [is_true e] is [Z.is_true data] when [e] is the literal
    [Integer {data}], and [false] for every other term. *)
let is_true e = match e with Integer {data} -> Z.is_true data | _ -> false
(** [is_false e] is [Z.is_false data] when [e] is the literal
    [Integer {data}], and [false] for every other term. *)
let is_false e = match e with Integer {data} -> Z.is_false data | _ -> false
(** [is_constant e] holds when no variable occurs in [e]: a [Var] is not
    constant, a literal or label is, and any other term is constant when
    all of its immediate subterms (as enumerated by [for_all]) are. *)
let rec is_constant e =
  match e with
  | Var _ -> false
  | Label _ | Float _ | Integer _ | Rational _ -> true
  | compound -> for_all ~f:is_constant compound
(** [height e] is the height of [e] viewed as a tree: [0] for atomic terms
    (variables, labels, literals, [RecRecord]), and one more than the
    greatest height among the immediate subterms otherwise. A compound
    term with no subterms has height [1]. *)
let rec height e =
  (* [taller m a] is the larger of [m] and the height of [a]. *)
  let taller m a = max m (height a) in
  match e with
  | Var _ | Label _ | Float _ | Integer _ | Rational _ | RecRecord _ -> 0
  | Ap1 (_, a) -> 1 + height a
  | Ap2 (_, a, b) -> 1 + taller (height a) b
  | Ap3 (_, a, b, c) -> 1 + max (height a) (taller (height b) c)
  | ApN (_, subterms) -> 1 + IArray.fold subterms ~init:0 ~f:taller
  | And members | Or members -> 1 + Set.fold members ~init:0 ~f:taller
  | Add counted | Mul counted ->
      1 + Qset.fold counted ~init:0 ~f:(fun a _ m -> taller m a)
(** Solve *)
(** [exists_fv_in vs qset] holds when some element of [qset] contains an
    occurrence of a variable that is a member of [vs]. *)
let exists_fv_in vs qset =
  let mentions_vs term _ = exists_vars term ~f:(Var.Set.mem vs) in
  Qset.exists qset ~f:mentions_vs
(** solve [0 = rejected_sum + (coeff × mono) + sum]

    [solve_for_mono rejected_sum coeff mono sum] tries to solve the
    equation above for [mono]. It returns [None] when [mono] is an
    [Integer] literal, or when a free variable of [mono] also occurs in
    [sum]. Otherwise it returns [Some (mono, rhs)], where [rhs] is the
    union of [rejected_sum] and [sum] scaled by [-1/coeff]. *)
let solve_for_mono rejected_sum coeff mono sum =
  match mono with
  | Integer _ -> None
  | _ ->
      (* Only [sum] is checked for occurrences of the variables of [mono];
         [rejected_sum] is not inspected here. *)
      if exists_fv_in (fv mono) sum then None
      else
        (* NOTE(review): the scaling step is written as [Sum.mul_const],
           assumed to multiply every coefficient of the sum by the given
           constant — confirm against the [Sum] module. *)
        Some
          ( mono
          , Sum.to_term
              (Sum.mul_const
                 (Q.inv (Q.neg coeff))
                 (Qset.union rejected_sum sum)) )
(** solve [0 = rejected + sum]

    [solve_sum rejected_sum sum] repeatedly pops the minimum element of
    [sum] and tries to solve for it with [solve_for_mono]. An element that
    cannot be solved for is moved, with its coefficient, into
    [rejected_sum] and the search continues with the remainder. The result
    is [None] once [Qset.pop_min_elt] has nothing left to return. *)
let rec solve_sum rejected_sum sum =
  let* mono, coeff, rest = Qset.pop_min_elt sum in
  match solve_for_mono rejected_sum coeff mono rest with
  | None -> solve_sum (Qset.add rejected_sum mono coeff) rest
  | Some _ as solved -> solved
(** solve [0 = e]

    [solve_zero_eq ?for_ e] tries to solve the equation [0 = e]. Only
    [Add] terms are handled; every other term yields [None].
    - Without [for_], the sum is handed to [solve_sum], which picks the
      term to solve for.
    - With [~for_:mono], [mono] is removed from the sum (yielding [None]
      if [Qset.find_and_remove] finds nothing) and the equation is solved
      for it with [solve_for_mono].

    The result is [Some (c, r)], printed by the trace hook as [c ↦ r], or
    [None]. The return hook also asserts that, when [for_] was given and a
    solution was found, the solved term equals [for_]. *)
let solve_zero_eq ?for_ e =
  [%Trace.call fun {pf} -> pf "0 = %a%a" pp e (Option.pp " for %a" pp) for_]
  ;
  ( match e with
  | Add sum -> (
    match for_ with
    | None -> solve_sum Qset.empty sum
    | Some mono ->
        let* coeff, sum = Qset.find_and_remove sum mono in
        solve_for_mono Qset.empty coeff mono sum )
  | _ -> None )
  |>
  [%Trace.retn fun {pf} s ->
    pf "%a"
      (Option.pp "%a" (fun fs (c, r) ->
           Format.fprintf fs "%a ↦ %a" pp c pp r ))
      s ;
    match (for_, s) with
    | Some f, Some (c, _) -> assert (equal f c)
    | _ -> ()]