@ -502,8 +502,16 @@ module Prod = struct
assert ( match term with Integer _ | Rational _ -> false | _ -> true ) ;
Qset . add prod term Q . one
let singleton term = add term empty
let of_ term = add term empty
let union = Qset . union
let to_term prod =
match Qset . pop prod with
| None -> one
| Some ( factor , power , prod' )
when Qset . is_empty prod' && Q . equal Q . one power ->
factor
| _ -> Mul prod
end
let rec simp_add_ es poly =
@ -547,21 +555,21 @@ and simp_mul2 e f =
| Rational { data = c } , x | x , Rational { data = c } ->
Sum . to_term ( Sum . of_ ~ coeff : c x )
(* ( ∏ᵤ₌₀ⁱ xᵤ ) × ( ∏ᵥ₌ᵢ₊₁ⁿ xᵥ ) ==> ∏ⱼ₌₀ⁿ xⱼ *)
| Mul xs1 , Mul xs2 -> Mul ( Prod . union xs1 xs2 )
| Mul xs1 , Mul xs2 -> Prod . to_term ( Prod . union xs1 xs2 )
(* ( ∏ᵢ xᵢ ) × ( ∑ᵤ cᵤ × ∏ⱼ yᵤⱼ ) ==> ∑ᵤ cᵤ × ∏ᵢ xᵢ × ∏ⱼ yᵤⱼ *)
| ( Mul prod as m ) , Add sum | Add sum , ( Mul prod as m ) ->
Sum . to_term
( Sum . map sum ~ f : ( function
| Mul args -> Mul ( Prod . union prod args )
| Mul args -> Prod . to_term ( Prod . union prod args )
| ( Integer _ | Rational _ ) as c -> simp_mul2 c m
| mono -> Mul ( Prod . add mono prod ) ) )
| mono -> Prod . to_term ( Prod . add mono prod ) ) )
(* x₀ × ( ∏ᵢ₌₁ⁿ xᵢ ) ==> ∏ᵢ₌₀ⁿ xᵢ *)
| Mul xs1 , x | x , Mul xs1 -> Mul ( Prod . add x xs1 )
| Mul xs1 , x | x , Mul xs1 -> Prod . to_term ( Prod . add x xs1 )
(* e × ( ∑ᵤ cᵤ × ∏ⱼ yᵤⱼ ) ==> ∑ᵤ e × cᵤ × ∏ⱼ yᵤⱼ *)
| Add args , e | e , Add args ->
simp_add_ ( Sum . map ~ f : ( fun m -> simp_mul2 e m ) args ) zero
(* x₁ × x₂ ==> ∏ᵢ₌₁² xᵢ *)
| _ -> Mul ( Prod . add e ( Prod . singleton f ) )
| _ -> Prod . to_term ( Prod . add e ( Prod . of_ f ) )
let rec simp_div x y =
match ( x , y ) with
@ -1222,6 +1230,11 @@ let fold_terms e ~init ~f =
let iter_vars e ~ f =
iter_terms e ~ f : ( function Var _ as v -> f ( v :> Var . t ) | _ -> () )
let exists_vars e ~ f =
with_return ( fun { return } ->
iter_vars e ~ f : ( fun v -> if f v then return true ) ;
false )
let fold_vars e ~ init ~ f =
fold_terms e ~ init ~ f : ( fun s -> function
| Var _ as v -> f s ( v :> Var . t ) | _ -> s )
@ -1248,42 +1261,69 @@ let height e =
(* * Solve *)
let find_for ? for_ args =
let exists_var args ~ f =
with_return ( fun { return } ->
Qset . iter args ~ f : ( fun arg _ ->
iter_vars arg ~ f : ( fun v -> if f v then return true ) ) ;
false )
in
let remove_if_non_occuring rejected args c q =
let args = Qset . remove args c in
let fv_c = fv c in
if exists_var ~ f : ( Var . Set . mem fv_c ) args then None
else Some ( c , q , Qset . union rejected args )
in
let rec find_for_ rejected args =
let * c , q = Qset . min_elt args in
remove_if_non_occuring rejected args c q
| > Option . or_else ~ f : ( fun () ->
find_for_ ( Qset . add rejected c q ) ( Qset . remove args c ) )
let exists_fv_in vs qset =
Qset . exists qset ~ f : ( fun e _ -> exists_vars e ~ f : ( Var . Set . mem vs ) )
let exists_fv_in4 vs w x y z =
exists_fv_in vs w | | exists_fv_in vs x | | exists_fv_in vs y
| | exists_fv_in vs z
(* solve [0 = rejected_sum + ( coeff × prod ) + sum] *)
let solve_for_factor rejected_sum coeff prod sum =
let rec find_factor rejected_prod prod =
let * factor , power , prod = Qset . pop_min_elt prod in
if
( not ( Q . equal Q . one power ) )
| | exists_fv_in4 ( fv factor ) rejected_sum rejected_prod prod sum
then find_factor ( Qset . add rejected_prod factor power ) prod
else Some ( factor , Qset . union rejected_prod prod )
in
match for_ with
| Some c ->
let q = Qset . count args c in
if Q . equal Q . zero q then None
else remove_if_non_occuring Qset . empty args c q
| None -> find_for_ Qset . empty args
let + factor , prod = find_factor Qset . empty prod in
(* solve [0 = rejected_sum + ( coeff × factor × prod ) + sum] yielding
[ factor = ( rejected_sum + sum ) / ( - coeff × prod ) ] * )
( factor
, div
( Sum . to_term ( Qset . union rejected_sum sum ) )
( mul ( rational ( Q . neg coeff ) ) ( Prod . to_term prod ) ) )
(* solve [0 = rejected_sum + ( coeff × mono ) + sum] *)
let solve_for_mono rejected_sum coeff mono sum =
match mono with
| Mul prod -> solve_for_factor rejected_sum coeff prod sum
| _ ->
if exists_fv_in ( fv mono ) sum then None
else
Some
( mono
, Sum . to_term
( Sum . mul_const
( Q . inv ( Q . neg coeff ) )
( Qset . union rejected_sum sum ) ) )
(* solve [0 = rejected + sum] *)
let rec solve_sum rejected_sum sum =
let * mono , coeff , sum = Qset . pop_min_elt sum in
solve_for_mono rejected_sum coeff mono sum
| > Option . or_else ~ f : ( fun () ->
solve_sum ( Qset . add rejected_sum mono coeff ) sum )
let rec solve_div = function
(* [n / d = t] ==> [n = d × t] *)
| Some ( Ap2 ( Div , num , den ) , trm ) -> solve_div ( Some ( num , mul den trm ) )
| o -> o
(* solve [0 = e] *)
let solve_zero_eq ? for_ e =
[ % Trace . call fun { pf } -> pf " %a%a " pp e ( Option . pp " for %a " pp ) for_ ]
[ % Trace . call fun { pf } -> pf " 0 = %a%a" pp e ( Option . pp " for %a " pp ) for_ ]
;
( match e with
| Add args ->
let + c , q , args = find_for ? for_ args in
let n = Sum . to_term ( Qset . remove args c ) in
let d = rational ( Q . neg q ) in
let r = div n d in
( c , r )
| Add sum ->
( match for_ with
| None -> solve_sum Qset . empty sum
| Some mono ->
let * coeff , sum = Qset . find_and_remove sum mono in
solve_for_mono Qset . empty coeff mono sum )
| > solve_div
| _ -> None )
| >
[ % Trace . retn fun { pf } s ->
@ -1292,5 +1332,7 @@ let solve_zero_eq ?for_ e =
Format . fprintf fs " %a ↦ %a " pp c pp r ) )
s ;
match ( for_ , s ) with
| Some ( Mul prod ) , Some ( var , _ ) ->
assert ( not ( Q . equal Q . zero ( Qset . count prod var ) ) )
| Some f , Some ( c , _ ) -> assert ( equal f c )
| _ -> () ]