You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

214 lines
6.2 KiB

(*
* Copyright (c) Facebook, Inc. and its affiliates.
*
* This source code is licensed under the MIT license found in the
* LICENSE file in the root directory of this source tree.
*)
module Var = struct
  include Ses.Term.Var

  (** Variable renaming substitutions: maps from variables to variables.
      [invariant] checks that a substitution is injective and that its
      domain and range are disjoint. *)
  module Subst = struct
    type nonrec t = t Map.t [@@deriving compare, equal, sexp_of]

    (** A substitution [sub] bundled with the variable sets [dom] and [rng]
        that [freshen] and [restrict] compute alongside it. *)
    type x = {sub: t; dom: Set.t; rng: Set.t}

    let t_of_sexp = Map.t_of_sexp t_of_sexp

    (** [invariant s] asserts that no two keys of [s] map to the same
        variable and that the set of keys is disjoint from the set of
        images. *)
    let invariant s =
      let@ () = Invariant.invariant [%here] s [%sexp_of: t] in
      let collect ~key ~data (keys, images) =
        (* substs are injective: an image may only be produced once *)
        assert (not (Set.mem images data)) ;
        (Set.add keys key, Set.add images data)
      in
      let keys, images =
        Map.fold s ~init:(Set.empty, Set.empty) ~f:collect
      in
      assert (Set.disjoint keys images)

    let pp = Map.pp pp pp
    let empty = Map.empty
    let is_empty = Map.is_empty

    (** [freshen vs ~wrt] builds a substitution for the variables of [vs]
        that also occur in [wrt], mapping each such variable [x] to the
        result of [fresh (name x)]. Returns the substitution record paired
        with a [wrt] set: when nothing clashes this is the original [wrt]
        unchanged, otherwise it is [wrt] united with [vs] and then threaded
        through every call to [fresh]. The resulting substitution is
        checked against [invariant]. *)
    let freshen vs ~wrt =
      let dom = Set.inter wrt vs in
      let result =
        if Set.is_empty dom then
          ({sub= empty; dom= Set.empty; rng= Set.empty}, wrt)
        else
          (* rename one clashing variable, threading the growing [wrt] *)
          let rename_one (sub, rng, wrt) x =
            let x', wrt = fresh (name x) ~wrt in
            let sub = Map.add_exn sub ~key:x ~data:x' in
            (sub, Set.add rng x', wrt)
          in
          let sub, rng, wrt =
            Set.fold dom
              ~init:(empty, Set.empty, Set.union wrt vs)
              ~f:rename_one
          in
          ({sub; dom; rng}, wrt)
      in
      result |> check (fun ({sub; _}, _) -> invariant sub)

    (** [fold sub ~init ~f] folds [f] over the bindings of [sub], passing
        the key and the image positionally. *)
    let fold sub ~init ~f =
      let step ~key ~data acc = f key data acc in
      Map.fold sub ~init ~f:step

    (** [domain sub] is the set of keys of [sub]. *)
    let domain sub =
      let add_key ~key ~data:_ keys = Set.add keys key in
      Map.fold sub ~init:Set.empty ~f:add_key

    (** [range sub] is the set of images of [sub]. *)
    let range sub =
      let add_image ~key:_ ~data images = Set.add images data in
      Map.fold sub ~init:Set.empty ~f:add_image

    (** [invert sub] maps each image of [sub] back to its key. Uses
        [Map.add_exn], so it relies on [sub] being injective; the result is
        checked against [invariant]. *)
    let invert sub =
      let flipped =
        Map.fold sub ~init:empty ~f:(fun ~key:src ~data:dst acc ->
            Map.add_exn acc ~key:dst ~data:src )
      in
      flipped |> check invariant

    (** [restrict sub vs] removes from [sub] every binding whose key is not
        in [vs]. The returned record holds the restricted substitution
        together with the keys ([dom]) and images ([rng]) of the bindings
        that were kept; both sets are checked against the restricted
        substitution. *)
    let restrict sub vs =
      let init = {sub; dom= Set.empty; rng= Set.empty} in
      let visit ~key ~data (acc : x) =
        if Set.mem vs key then
          {acc with dom= Set.add acc.dom key; rng= Set.add acc.rng data}
        else (
          assert (
            (* all substs are injective, so the current mapping is the
               only one that can cause [data] to be in [rng] *)
            (not (Set.mem (range (Map.remove sub key)) data))
            || violates invariant sub ) ;
          {acc with sub= Map.remove acc.sub key} )
      in
      Map.fold sub ~init ~f:visit
      |> check (fun {sub= kept; dom; rng} ->
             assert (Set.equal dom (domain kept)) ;
             assert (Set.equal rng (range kept)) )

    (** [apply sub v] is the image of [v] under [sub], or [v] itself when
        [sub] has no binding for it. *)
    let apply sub v =
      match Map.find sub v with Some v' -> v' | None -> v
  end
end
module Term = struct
  (* Re-export [Ses.Term] with its [Var] submodule substituted away, so
     that [Var] below refers to the [Var] module defined in this file. *)
  include (
    Ses.Term :
      module type of Ses.Term with module Var := Ses.Term.Var )

  (** [ite] is another name for [conditional]. *)
  let ite = conditional

  (** [rename s e] applies the underlying [rename] to [e], using the
      variable mapping obtained from the substitution [s] via
      [Var.Subst.apply]. *)
  let rename s e =
    let on_var v = Var.Subst.apply s v in
    rename on_var e
end
module Formula = struct
  (* Formulas share the representation and operations of terms. *)
  include Term

  (** [inject b] returns [b] unchanged. *)
  let inject b = b

  (** [project e] always succeeds, wrapping [e] in [Some]. *)
  let project e = Some e

  (** The constants [true_] and [false_] under shorter names. *)
  let tt = true_

  let ff = false_

  (** [cond ~cnd ~pos ~neg] is [conditional] with the branches labelled
      [pos] and [neg] instead of [thn] and [els]. *)
  let cond ~cnd ~pos:thn ~neg:els = conditional ~cnd ~thn ~els
end
module Context = struct
  include Ses.Equality

  (** Formula-flavoured names for the term-level operations. *)
  let and_formula = and_term

  let normalizef = normalize

  (** [rename x sub] renames [x] using the variable mapping induced by the
      substitution [sub]. *)
  let rename x sub =
    let renaming = Var.Subst.apply sub in
    rename x renaming

  module Subst = struct
    include Subst

    (** [substf] is another name for [subst]. *)
    let substf = subst
  end

  (* Replay debugging *)

  (** A recorded call to one of the wrapped operations below, together with
      its arguments, (de)serializable as an s-expression. *)
  type call =
    | Normalize of t * Term.t
    | And_formula of Var.Set.t * Formula.t * t
    | And_ of Var.Set.t * t * t
    | OrN of Var.Set.t * t list
    | Rename of t * Var.Subst.t
    | Apply_subst of Var.Set.t * Subst.t * t
    | Solve_for_vars of Var.Set.t list * t
  [@@deriving sexp]

  (** [replay c] parses the string [c] as a serialized [call] and re-runs
      the corresponding operation, discarding its result. It is defined
      before the debug wrappers, so it invokes the unwrapped operations. *)
  let replay c =
    match c |> Sexp.of_string |> call_of_sexp with
    | Normalize (r, e) -> ignore (normalize r e)
    | And_formula (us, e, r) -> ignore (and_formula us e r)
    | And_ (us, r, s) -> ignore (and_ us r s)
    | OrN (us, rs) -> ignore (orN us rs)
    | Rename (r, s) -> ignore (rename r s)
    | Apply_subst (us, s, r) -> ignore (apply_subst us s r)
    | Solve_for_vars (vss, r) -> ignore (solve_for_vars vss r)

  (* Debug wrappers *)

  (** [report] prints one line of timing statistics for a timer to stderr. *)
  let report ~name ~elapsed ~aggregate ~count =
    Format.eprintf "%15s time: %12.3f ms %12.3f ms %12d calls@." name elapsed
      aggregate count

  (** Elapsed-time threshold above which [wrap] dumps the offending call.
      It is compared against the same [elapsed] value that [report] prints
      as ms, and is doubled after every dump. *)
  let dump_threshold = ref 1000.

  (** [wrap tmr f call] runs [f ()] between [Timer.start tmr] and
      [Timer.stop_report tmr], reporting the timing and, when the elapsed
      time exceeds [!dump_threshold], printing the s-expression of
      [call ()] to stderr. When [[%debug]] holds, an exception escaping
      the timed run is re-raised via [raise_s] paired with the
      s-expression of [call ()].

      NOTE(review): if [f ()] raises, [Timer.stop_report] is not reached;
      confirm that leaving [tmr] started is harmless. *)
  let wrap tmr f call =
    let timed () =
      Timer.start tmr ;
      let result = f () in
      let on_stop ~name ~elapsed ~aggregate ~count =
        report ~name ~elapsed ~aggregate ~count ;
        if Float.(elapsed > !dump_threshold) then (
          dump_threshold := 2. *. !dump_threshold ;
          Format.eprintf "@\n%a@\n@." Sexp.pp_hum (sexp_of_call (call ())) )
      in
      Timer.stop_report tmr on_stop ;
      result
    in
    if [%debug] then (
      try timed ()
      with exn -> raise_s ([%sexp_of: exn * call] (exn, call ())) )
    else timed ()

  (* Each operation is shadowed by a version that goes through [wrap] with
     its own timer. The timers are created in the same relative order as
     before; only their placement next to the wrappers is new. *)

  let normalize_tmr = Timer.create "normalize" ~at_exit:report

  let normalize r e =
    let run () = normalize r e in
    let record () = Normalize (r, e) in
    wrap normalize_tmr run record

  let and_formula_tmr = Timer.create "and_formula" ~at_exit:report

  let and_formula us e r =
    let run () = and_formula us e r in
    let record () = And_formula (us, e, r) in
    wrap and_formula_tmr run record

  let and_tmr = Timer.create "and_" ~at_exit:report

  let and_ us r s =
    let run () = and_ us r s in
    let record () = And_ (us, r, s) in
    wrap and_tmr run record

  let orN_tmr = Timer.create "orN" ~at_exit:report

  let orN us rs =
    let run () = orN us rs in
    let record () = OrN (us, rs) in
    wrap orN_tmr run record

  let rename_tmr = Timer.create "rename" ~at_exit:report

  let rename r s =
    let run () = rename r s in
    let record () = Rename (r, s) in
    wrap rename_tmr run record

  let apply_subst_tmr = Timer.create "apply_subst" ~at_exit:report

  let apply_subst us s r =
    let run () = apply_subst us s r in
    let record () = Apply_subst (us, s, r) in
    wrap apply_subst_tmr run record

  let solve_for_vars_tmr = Timer.create "solve_for_vars" ~at_exit:report

  let solve_for_vars vss r =
    let run () = solve_for_vars vss r in
    let record () = Solve_for_vars (vss, r) in
    wrap solve_for_vars_tmr run record
end
(*
* Convert from Llair
*)
module Term_of_Llair = struct
  (** [exp] is [Ses.Term.of_exp]. *)
  let exp = Ses.Term.(of_exp)
end
module Formula_of_Llair = struct
  (* Provides [exp], the very same conversion as [Term_of_Llair.exp]. *)
  include Term_of_Llair
end
module Var_of_Llair = struct
  (** [reg] is [Ses.Var.of_reg]. *)
  let reg =
    let open Ses.Var in
    of_reg

  (** [regs] is [Ses.Var.Set.of_regs]. *)
  let regs =
    let open Ses.Var in
    Set.of_regs
end