@ -89,13 +89,13 @@ module LinArith : sig
val subst_variables : t -> f : ( Var . t -> subst_target ) -> t
val subst_variables : t -> f : ( Var . t -> subst_target ) -> t
end = struct
end = struct
(* * invariant: the representation is always "canonical": coefficients cannot be [Q.zero] *)
(* * invariant: the representation is always "canonical": coefficients cannot be [Q.zero] *)
type t = Q . t Var . Map . t * Q . t [@@ deriving compare ]
type t = Q . t * Q . t Var . Map . t [@@ deriving compare , equal ]
let yojson_of_t ( vs, c ) = ` List [ Var . Map . yojson_of_t Q . yojson_of_t vs ; Q . yojson_of_t c ]
let yojson_of_t ( c, vs ) = ` List [ Var . Map . yojson_of_t Q . yojson_of_t vs ; Q . yojson_of_t c ]
type subst_target = QSubst of Q . t | VarSubst of Var . t | LinSubst of t
type subst_target = QSubst of Q . t | VarSubst of Var . t | LinSubst of t
let pp pp_var fmt ( vs, c ) =
let pp pp_var fmt ( c, vs ) =
if Var . Map . is_empty vs then Q . pp_print fmt c
if Var . Map . is_empty vs then Q . pp_print fmt c
else
else
let pp_c fmt c =
let pp_c fmt c =
@ -118,30 +118,30 @@ end = struct
F . fprintf fmt " @[<h>%a%a@] " pp_vs vs pp_c c
F . fprintf fmt " @[<h>%a%a@] " pp_vs vs pp_c c
let add ( vs1 , c1 ) ( vs2 , c2 ) =
let add ( c1 , vs1 ) ( c2 , vs2 ) =
( Var . Map . union
( Q . add c1 c2
, Var . Map . union
( fun _ v c1 c2 ->
( fun _ v c1 c2 ->
let c = Q . add c1 c2 in
let c = Q . add c1 c2 in
if Q . is_zero c then None else Some c )
if Q . is_zero c then None else Some c )
vs1 vs2
vs1 vs2 )
, Q . add c1 c2 )
let minus ( vs, c ) = ( Var . Map . map ( fun c -> Q . neg c ) vs , Q . neg c )
let minus ( c, vs ) = ( Q . neg c , Var . Map . map ( fun c -> Q . neg c ) vs )
let subtract l1 l2 = add l1 ( minus l2 )
let subtract l1 l2 = add l1 ( minus l2 )
let zero = ( Var. Map . empty , Q . zero )
let zero = ( Q. zero , Var. Map . empty )
let is_zero ( vs, c ) = Q . is_zero c && Var . Map . is_empty vs
let is_zero ( c, vs ) = Q . is_zero c && Var . Map . is_empty vs
let mult q ( ( vs, c ) as l ) =
let mult q ( ( c, vs ) as l ) =
if Q . is_zero q then (* needed for correction: coeffs cannot be zero *) zero
if Q . is_zero q then (* needed for correction: coeffs cannot be zero *) zero
else if Q . is_one q then (* purely an optimisation *) l
else if Q . is_one q then (* purely an optimisation *) l
else ( Var. Map . map ( fun c -> Q . mul q c ) vs , Q . mul q c )
else ( Q. mul q c , Var. Map . map ( fun c -> Q . mul q c ) vs )
let solve_eq_zero ( vs, c ) =
let solve_eq_zero ( c, vs ) =
match Var . Map . min_binding_opt vs with
match Var . Map . min_binding_opt vs with
| None ->
| None ->
if Q . is_zero c then Sat None else Unsat
if Q . is_zero c then Sat None else Unsat
@ -154,18 +154,18 @@ end = struct
vs Var . Map . empty
vs Var . Map . empty
in
in
let c' = Q . div c d in
let c' = Q . div c d in
Sat ( Some ( x , ( vs', c ') ) )
Sat ( Some ( x , ( c', vs ') ) )
let solve_eq l1 l2 = solve_eq_zero ( subtract l1 l2 )
let solve_eq l1 l2 = solve_eq_zero ( subtract l1 l2 )
let of_var v = ( Var. Map . singleton v Q . one , Q . zero )
let of_var v = ( Q. zero , Var. Map . singleton v Q . one )
let of_q q = ( Var . Map . empty , q )
let of_q q = ( q , Var . Map . empty )
let get_as_const ( vs, c ) = if Var . Map . is_empty vs then Some c else None
let get_as_const ( c, vs ) = if Var . Map . is_empty vs then Some c else None
let get_as_var ( vs, c ) =
let get_as_var ( c, vs ) =
if Q . is_zero c then
if Q . is_zero c then
match Var . Map . is_singleton_or_more vs with
match Var . Map . is_singleton_or_more vs with
| Singleton ( x , cx ) when Q . is_one cx ->
| Singleton ( x , cx ) when Q . is_one cx ->
@ -175,20 +175,20 @@ end = struct
else None
else None
let has_var x ( vs , _ ) = Var . Map . mem x vs
let has_var x ( _ , vs ) = Var . Map . mem x vs
let subst x y ( ( vs, c ) as l ) =
let subst x y ( ( c, vs ) as l ) =
match Var . Map . find_opt x vs with
match Var . Map . find_opt x vs with
| None ->
| None ->
l
l
| Some cx ->
| Some cx ->
let vs' = Var . Map . remove x vs | > Var . Map . add y cx in
let vs' = Var . Map . remove x vs | > Var . Map . add y cx in
( vs', c )
( c, vs' )
let of_subst_target = function QSubst q -> of_q q | VarSubst v -> of_var v | LinSubst l -> l
let of_subst_target = function QSubst q -> of_q q | VarSubst v -> of_var v | LinSubst l -> l
let fold_subst_variables ( ( vs_foreign, c ) as l0 ) ~ init ~ f =
let fold_subst_variables ( ( c, vs_foreign) as l0 ) ~ init ~ f =
let changed = ref false in
let changed = ref false in
let acc_f , l' =
let acc_f , l' =
Var . Map . fold
Var . Map . fold
@ -197,7 +197,7 @@ end = struct
( match op with VarSubst v when Var . equal v v_foreign -> () | _ -> changed := true ) ;
( match op with VarSubst v when Var . equal v v_foreign -> () | _ -> changed := true ) ;
( acc_f , add ( mult q0 ( of_subst_target op ) ) l ) )
( acc_f , add ( mult q0 ( of_subst_target op ) ) l ) )
vs_foreign
vs_foreign
( init , ( Var . Map . empty , c ) )
( init , ( c , Var . Map . empty ) )
in
in
let l' = if ! changed then l' else l0 in
let l' = if ! changed then l' else l0 in
( acc_f , l' )
( acc_f , l' )
@ -205,7 +205,7 @@ end = struct
let subst_variables l ~ f = fold_subst_variables l ~ init : () ~ f : ( fun () v -> ( () , f v ) ) | > snd
let subst_variables l ~ f = fold_subst_variables l ~ init : () ~ f : ( fun () v -> ( () , f v ) ) | > snd
let get_variables ( vs , _ ) = Var . Map . to_seq vs | > Seq . map fst
let get_variables ( _ , vs ) = Var . Map . to_seq vs | > Seq . map fst
end
end
type subst_target = LinArith . subst_target =
type subst_target = LinArith . subst_target =