From a4f8ecfb2bcaccb5bbaa5db7b03b0148fbc00e77 Mon Sep 17 00:00:00 2001 From: Josh Berdine Date: Tue, 2 Feb 2021 04:34:31 -0800 Subject: [PATCH] [sledge] Fix Arith.map on noninterpreted polynomials Summary: Currently when Arith.map encounters a noninterpreted polynomial, it does not descend to its subterms. This diff fixes this bug, and documents that the `map` and `trms` operations agree on the set of subterms. Reviewed By: jvillard Differential Revision: D25883725 fbshipit-source-id: 1dc06f1fc --- sledge/src/fol/arithmetic.ml | 48 +++++++++++++++++++------------ sledge/src/fol/arithmetic_intf.ml | 3 +- 2 files changed, 30 insertions(+), 21 deletions(-) diff --git a/sledge/src/fol/arithmetic.ml b/sledge/src/fol/arithmetic.ml index 5110eab19..532727e68 100644 --- a/sledge/src/fol/arithmetic.ml +++ b/sledge/src/fol/arithmetic.ml @@ -180,15 +180,6 @@ struct | _ -> Uninterpreted ) | `Many -> Interpreted - let is_noninterpreted poly = - match Sum.only_elt poly with - | Some (mono, _) -> ( - match Prod.classify mono with - | `Zero -> false - | `One (_, n) -> n <> 1 - | `Many -> true ) - | None -> false - let get_const poly = match Sum.classify poly with | `Zero -> Some Q.zero @@ -303,21 +294,40 @@ struct | Some mono -> Mono.trms mono | None -> Iter.map ~f:trm_of_mono (monos poly) + (* map over [trms] *) let map poly ~f = [%trace] ~call:(fun {pf} -> pf "@ %a" pp poly) ~retn:(fun {pf} -> pf "%a" pp) @@ fun () -> - ( if is_noninterpreted poly then poly - else - let p, p' = - Sum.fold poly (poly, Sum.empty) ~f:(fun mono coeff (p, p') -> - let e = trm_of_mono mono in - let e' = f e in - if e == e' then (p, p') - else (Sum.remove mono p, Sum.union p' (mulc coeff (trm e'))) ) - in - Sum.union p p' ) + ( match get_mono poly with + | Some mono -> + let mono', (coeff, mono_delta) = + Prod.fold mono + (mono, (Q.one, Mono.one)) + ~f:(fun base power (mono', delta) -> + let base' = f base in + if base' == base then (mono', delta) + else + ( Prod.remove base mono' + , CM.mul delta (CM.of_trm ~power base') ) ) + in + if mono' == mono then poly + else CM.to_poly (coeff, Mono.mul mono' mono_delta) + | None -> + let poly', delta = + Sum.fold poly (poly, Sum.empty) + ~f:(fun mono coeff (poly', delta) -> + if Mono.equal_one mono then (poly', delta) + else + let e = trm_of_mono mono in + let e' = f e in + if e == e' then (poly', delta) + else + ( Sum.remove mono poly' + , Sum.union delta (mulc coeff (trm e')) ) ) + in + Sum.union poly' delta ) |> check invariant (* query *) diff --git a/sledge/src/fol/arithmetic_intf.ml b/sledge/src/fol/arithmetic_intf.ml index 9bd81cc6a..1a092010a 100644 --- a/sledge/src/fol/arithmetic_intf.ml +++ b/sledge/src/fol/arithmetic_intf.ml @@ -68,8 +68,7 @@ module type S = sig of factors [Xⱼ] for each [j]. *) val map : t -> f:(trm -> trm) -> t - (** [map ~f a] is [a] with each maximal non-interpreted subterm - transformed by [f]. *) + (** Map over the {!trms}. *) (** Solve *)