[sledge] Fix Arith.map on noninterpreted polynomials

Summary:
Currently when Arith.map encounters a noninterpreted polynomial, it
does not descend to its subterms. This diff fixes this bug, and
documents that the `map` and `trms` operations agree on the set of
subterms.

Reviewed By: jvillard

Differential Revision: D25883725

fbshipit-source-id: 1dc06f1fc
master
Josh Berdine 4 years ago committed by Facebook GitHub Bot
parent af58e744a2
commit a4f8ecfb2b

@ -180,15 +180,6 @@ struct
| _ -> Uninterpreted )
| `Many -> Interpreted
let is_noninterpreted poly =
match Sum.only_elt poly with
| Some (mono, _) -> (
match Prod.classify mono with
| `Zero -> false
| `One (_, n) -> n <> 1
| `Many -> true )
| None -> false
let get_const poly =
match Sum.classify poly with
| `Zero -> Some Q.zero
@ -303,21 +294,40 @@ struct
| Some mono -> Mono.trms mono
| None -> Iter.map ~f:trm_of_mono (monos poly)
(* map over [trms] *)
let map poly ~f =
[%trace]
~call:(fun {pf} -> pf "@ %a" pp poly)
~retn:(fun {pf} -> pf "%a" pp)
@@ fun () ->
( if is_noninterpreted poly then poly
else
let p, p' =
Sum.fold poly (poly, Sum.empty) ~f:(fun mono coeff (p, p') ->
let e = trm_of_mono mono in
let e' = f e in
if e == e' then (p, p')
else (Sum.remove mono p, Sum.union p' (mulc coeff (trm e'))) )
in
Sum.union p p' )
( match get_mono poly with
| Some mono ->
let mono', (coeff, mono_delta) =
Prod.fold mono
(mono, (Q.one, Mono.one))
~f:(fun base power (mono', delta) ->
let base' = f base in
if base' == base then (mono', delta)
else
( Prod.remove base mono'
, CM.mul delta (CM.of_trm ~power base') ) )
in
if mono' == mono then poly
else CM.to_poly (coeff, Mono.mul mono' mono_delta)
| None ->
let poly', delta =
Sum.fold poly (poly, Sum.empty)
~f:(fun mono coeff (poly', delta) ->
if Mono.equal_one mono then (poly', delta)
else
let e = trm_of_mono mono in
let e' = f e in
if e == e' then (poly', delta)
else
( Sum.remove mono poly'
, Sum.union delta (mulc coeff (trm e')) ) )
in
Sum.union poly' delta )
|> check invariant
(* query *)

@ -68,8 +68,7 @@ module type S = sig
of factors [X] for each [j]. *)
val map : t -> f:(trm -> trm) -> t
(** [map ~f a] is [a] with each maximal non-interpreted subterm
transformed by [f]. *)
(** Map over the {!trms}. *)
(** Solve *)

Loading…
Cancel
Save