[sledge] Fix Arith.map on noninterpreted polynomials

Summary:
Currently when Arith.map encounters a noninterpreted polynomial, it
does not descend to its subterms. This diff fixes this bug, and
documents that the `map` and `trms` operations agree on the set of
subterms.

Reviewed By: jvillard

Differential Revision: D25883725

fbshipit-source-id: 1dc06f1fc
master
Josh Berdine 4 years ago committed by Facebook GitHub Bot
parent af58e744a2
commit a4f8ecfb2b

@ -180,15 +180,6 @@ struct
| _ -> Uninterpreted ) | _ -> Uninterpreted )
| `Many -> Interpreted | `Many -> Interpreted
let is_noninterpreted poly =
match Sum.only_elt poly with
| Some (mono, _) -> (
match Prod.classify mono with
| `Zero -> false
| `One (_, n) -> n <> 1
| `Many -> true )
| None -> false
let get_const poly = let get_const poly =
match Sum.classify poly with match Sum.classify poly with
| `Zero -> Some Q.zero | `Zero -> Some Q.zero
@ -303,21 +294,40 @@ struct
| Some mono -> Mono.trms mono | Some mono -> Mono.trms mono
| None -> Iter.map ~f:trm_of_mono (monos poly) | None -> Iter.map ~f:trm_of_mono (monos poly)
(* map over [trms] *)
let map poly ~f = let map poly ~f =
[%trace] [%trace]
~call:(fun {pf} -> pf "@ %a" pp poly) ~call:(fun {pf} -> pf "@ %a" pp poly)
~retn:(fun {pf} -> pf "%a" pp) ~retn:(fun {pf} -> pf "%a" pp)
@@ fun () -> @@ fun () ->
( if is_noninterpreted poly then poly ( match get_mono poly with
| Some mono ->
let mono', (coeff, mono_delta) =
Prod.fold mono
(mono, (Q.one, Mono.one))
~f:(fun base power (mono', delta) ->
let base' = f base in
if base' == base then (mono', delta)
else
( Prod.remove base mono'
, CM.mul delta (CM.of_trm ~power base') ) )
in
if mono' == mono then poly
else CM.to_poly (coeff, Mono.mul mono' mono_delta)
| None ->
let poly', delta =
Sum.fold poly (poly, Sum.empty)
~f:(fun mono coeff (poly', delta) ->
if Mono.equal_one mono then (poly', delta)
else else
let p, p' =
Sum.fold poly (poly, Sum.empty) ~f:(fun mono coeff (p, p') ->
let e = trm_of_mono mono in let e = trm_of_mono mono in
let e' = f e in let e' = f e in
if e == e' then (p, p') if e == e' then (poly', delta)
else (Sum.remove mono p, Sum.union p' (mulc coeff (trm e'))) ) else
( Sum.remove mono poly'
, Sum.union delta (mulc coeff (trm e')) ) )
in in
Sum.union p p' ) Sum.union poly' delta )
|> check invariant |> check invariant
(* query *) (* query *)

@ -68,8 +68,7 @@ module type S = sig
of factors [X] for each [j]. *) of factors [X] for each [j]. *)
val map : t -> f:(trm -> trm) -> t val map : t -> f:(trm -> trm) -> t
(** [map ~f a] is [a] with each maximal non-interpreted subterm (** Map over the {!trms}. *)
transformed by [f]. *)
(** Solve *) (** Solve *)

Loading…
Cancel
Save