(*
 * Copyright (c) Facebook, Inc. and its affiliates.
 *
 * This source code is licensed under the MIT license found in the
 * LICENSE file in the root directory of this source tree.
 *)

(* Proofs about llvm to llair translation *)

open HolKernel boolLib bossLib Parse;
open listTheory arithmeticTheory pred_setTheory finite_mapTheory wordsTheory integer_wordTheory;
open rich_listTheory pathTheory;
open settingsTheory miscTheory memory_modelTheory;
open llvmTheory llvm_propTheory llvm_liveTheory llairTheory llair_propTheory llvm_to_llairTheory;

(* Start the theory in which the definitions and theorems below are stored. *)
new_theory "llvm_to_llair_prop";

(* Build the term grammar from the listed theories only (presumably so that
 * names from the other opened theories do not interfere with parsing --
 * TODO confirm). *)
set_grammar_ancestry ["llvm", "llair", "llvm_to_llair", "llvm_live"];

(* The integer theories are opened above; ask numLib to prefer natural numbers
 * when resolving overloaded numerals and arithmetic. *)
numLib.prefer_num ();

(* Value relation between an LLVM value (first argument) and a llair value
 * (second argument).
 *  - A pointer, or a 1/8/32/64-bit word, is related to the llair integer that
 *    holds its signed interpretation (w2i) together with the matching size;
 *    pointers use llair$pointer_size as the size.
 *  - Aggregates are related when their component lists are related pointwise
 *    (list_rel v_rel).
 *)
Inductive v_rel:
  (∀w. v_rel (FlatV (PtrV w)) (FlatV (IntV (w2i w) llair$pointer_size))) ∧
  (∀w. v_rel (FlatV (W1V w)) (FlatV (IntV (w2i w) 1))) ∧
  (∀w. v_rel (FlatV (W8V w)) (FlatV (IntV (w2i w) 8))) ∧
  (∀w. v_rel (FlatV (W32V w)) (FlatV (IntV (w2i w) 32))) ∧
  (∀w. v_rel (FlatV (W64V w)) (FlatV (IntV (w2i w) 64))) ∧
  (∀vs1 vs2.
    list_rel v_rel vs1 vs2
    ⇒
    v_rel (AggV vs1) (AggV vs2))
End

(* Define when an LLVM state is related to a llair one. Parameterised over a
 * relation on program counters, which should be generated by the
 * transformation. It is not trivial because the translation cuts up blocks at
 * function calls and adds blocks for removing phi nodes.
 *
 * Also parameterised on a map for locals relating LLVM registers to llair
 * expressions that compute the value in that register. This corresponds to part
 * of the translation's state.
 *
 * The relation is the conjunction of four conditions:
 *  1. the LLVM instruction pointer and the llair block pointer are in pc_rel;
 *  2. every register live at the LLVM instruction pointer has a local value and
 *     an emap expression whose llair evaluation is v_rel-related to that value;
 *  3. the llair heap is the LLVM heap with its tags erased;
 *  4. the llair status is the observation computed from the LLVM state.
 *)
Definition state_rel_def:
  state_rel prog pc_rel emap (s:llvm$state) (s':llair$state) ⇔
    pc_rel s.ip s'.bp ∧
    (* Live LLVM registers are mapped and have a related value in the emap
     * (after evaluating) *)
    (∀r. r ∈ live prog s.ip ⇒
      ∃v v' e.
        v_rel v.value v' ∧
        flookup s.locals r = Some v ∧
        flookup emap r = Some e ∧ eval_exp s' e v') ∧
    (* The heaps agree once the tags on the LLVM heap are erased *)
    erase_tags s.heap = s'.heap ∧
    (* The llair status is determined by the LLVM state *)
    s'.status = get_observation prog s
End

(* Related values have the same byte representation: converting the LLVM value
 * with llvm_value_to_bytes gives the same result as converting the llair value
 * with llair_value_to_bytes. *)
Theorem v_rel_bytes:
  ∀v v'. v_rel v v' ⇒ llvm_value_to_bytes v = llair_value_to_bytes v'
Proof
  (* Rule induction over v_rel *)
  ho_match_mp_tac v_rel_ind >>
  rw [v_rel_cases, llvm_value_to_bytes_def, llair_value_to_bytes_def] >>
  rw [value_to_bytes_def, llvmTheory.unconvert_value_def, w2n_i2n,
      llairTheory.unconvert_value_def, llairTheory.pointer_size_def,
      llvmTheory.pointer_size_def] >>
  (* What remains mentions the lists vs1 and vs2 of the aggregate rule; finish
   * by induction on vs2 with vs1 generalised *)
  pop_assum mp_tac >>
  qid_spec_tac `vs1` >>
  Induct_on `vs2` >> rw [] >> rw []
QED

(* Correctness of constant translation, in the three-part form used with
 * const_induction. Given any llair state s' related to some LLVM state:
 *  1. for a constant c, translate_const c evaluates in s' to a value related
 *     to eval_const g c;
 *  2. for a list cs of (type, constant) pairs, the translated constants
 *     evaluate pointwise to values related pointwise to the LLVM values;
 *  3. for a single (type, constant) pair tc, the same as part 1 for snd tc.
 * The globals map g is arbitrary in all three parts.
 *
 * NOTE: the last four subgoals are closed with cheat, so this lemma is not yet
 * fully proved.
 *)
Theorem translate_constant_correct_lem:
  (∀c s prog pc_rel emap s' (g : glob_var |-> β # word64).
   state_rel prog pc_rel emap s s'
   ⇒
   ∃v'. eval_exp s' (translate_const c) v' ∧ v_rel (eval_const g c) v') ∧
  (∀(cs : (ty # const) list) s prog pc_rel emap s' (g : glob_var |-> β # word64).
   state_rel prog pc_rel emap s s'
   ⇒
   ∃v'. list_rel (eval_exp s') (map (translate_const o snd) cs) v' ∧ list_rel v_rel (map (eval_const g o snd) cs) v') ∧
  (∀(tc : ty # const) s prog pc_rel emap s' (g : glob_var |-> β # word64).
   state_rel prog pc_rel emap s s'
   ⇒
   ∃v'. eval_exp s' (translate_const (snd tc)) v' ∧ v_rel (eval_const g (snd tc)) v')
Proof
  ho_match_mp_tac const_induction >> rw [translate_const_def] >>
  simp [Once eval_exp_cases, eval_const_def]
  >- (
    Cases_on `s` >> simp [eval_const_def, translate_size_def, v_rel_cases] >>
    metis_tac [truncate_2comp_i2w_w2i, dimindex_1, dimindex_8, dimindex_32, dimindex_64])
  >- (
    simp [v_rel_cases, PULL_EXISTS, MAP_MAP_o] >>
    fs [combinTheory.o_DEF, pairTheory.LAMBDA_PROD] >>
    metis_tac [])
  >- (
    simp [v_rel_cases, PULL_EXISTS, MAP_MAP_o] >>
    fs [combinTheory.o_DEF, pairTheory.LAMBDA_PROD] >>
    metis_tac [])
  (* Unfinished cases *)
  >- cheat
  >- cheat
  >- cheat
  >- cheat
QED

(* The single-constant part of translate_constant_correct_lem, restated on its
 * own. It depends on that lemma and therefore on the cheats in its proof. *)
Theorem translate_constant_correct:
  ∀c s prog pc_rel emap s' g.
   state_rel prog pc_rel emap s s'
   ⇒
   ∃v'. eval_exp s' (translate_const c) v' ∧ v_rel (eval_const g c) v'
Proof
  metis_tac [translate_constant_correct_lem]
QED

(* Correctness of argument translation. If the LLVM argument a evaluates to v in
 * state s, and the registers given by arg_to_regs a are all live at s.ip, then
 * in any related llair state the translated argument evaluates to a value
 * related to v.value. *)
Theorem translate_arg_correct:
  ∀s a v prog pc_rel emap s'.
  state_rel prog pc_rel emap s s' ∧
  eval s a = Some v ∧
  arg_to_regs a ⊆ live prog s.ip
  ⇒
  ∃v'. eval_exp s' (translate_arg emap a) v' ∧ v_rel v.value v'
Proof
  Cases_on `a` >> rw [eval_def, translate_arg_def] >> rw []
  (* First case: discharged directly by correctness of constant translation *)
  >- metis_tac [translate_constant_correct] >>
  (* Remaining case: use the live-register clause of state_rel *)
  CASE_TAC >> fs [PULL_EXISTS, state_rel_def, arg_to_regs_def] >>
  res_tac >> rfs [] >> metis_tac []
QED

(* Related states agree on allocation: for every i, is_allocated i holds of the
 * LLVM heap exactly when it holds of the llair heap. This follows from the
 * heap clause of state_rel together with the definition of erase_tags. *)
Theorem is_allocated_state_rel:
  ∀prog pc_rel emap s1 s1'.
    state_rel prog pc_rel emap s1 s1'
    ⇒
    (∀i. is_allocated i s1.heap ⇔ is_allocated i s1'.heap)
Proof
  rw [state_rel_def, is_allocated_def, erase_tags_def] >>
  pop_assum mp_tac >> pop_assum (mp_tac o GSYM) >> rw []
QED

(* A restricted injectivity property of i2w: if the integer i lies between
 * INT_MIN and INT_MAX of the word type, and i2w i equals i2w (w2i w), then i is
 * w2i w. *)
Theorem restricted_i2w_11:
  ∀i (w:'a word). INT_MIN (:'a) ≤ i ∧ i ≤ INT_MAX (:'a) ⇒ (i2w i : 'a word) = i2w (w2i w) ⇒ i = w2i w
Proof
  (* Unfolding i2w splits the goal in two *)
  rw [i2w_def]
  >- (
    (* Here i is rewritten as -j with 0 ≤ j; the proof further distinguishes
     * whether the resulting word is INT_MINw *)
    Cases_on `n2w (Num (-i)) = INT_MINw` >>
    rw [w2i_neg, w2i_INT_MINw] >>
    fs [word_L_def] >>
    `?j. 0 ≤ j ∧ i = -j` by intLib.COOPER_TAC >>
    rw [] >>
    fs [] >>
    `INT_MIN (:'a) < dimword (:'a)` by metis_tac [INT_MIN_LT_DIMWORD] >>
    `Num j MOD dimword (:'a) = Num j`
    by (irule LESS_MOD >> intLib.COOPER_TAC) >>
    fs []
    >- intLib.COOPER_TAC
    >- (
      `Num j < INT_MIN (:'a)` by intLib.COOPER_TAC >>
      fs [w2i_n2w_pos, integerTheory.INT_OF_NUM]))
  >- (
    (* Here Num i is shown to be below INT_MIN, so w2i (n2w (Num i)) can be
     * rewritten with w2i_n2w_pos *)
    fs [GSYM INT_MAX, INT_MAX_def] >>
    `Num i < INT_MIN (:'a)` by intLib.COOPER_TAC >>
    rw [w2i_n2w_pos, integerTheory.INT_OF_NUM] >>
    intLib.COOPER_TAC)
QED

(* Correctness of the translation of value extraction. Assume:
 *  - s1 and s1' are related;
 *  - the constants cs evaluate (via signed_v_to_num) to the indices ns;
 *  - extract_value v ns gives result;
 *  - e1' evaluates in s1' to v1', which is related to v.
 * Then folding Select over the translated constants, starting from e1',
 * evaluates in s1' to a value related to result.
 *
 * The quantified variable a does not occur in the body of the statement.
 *)
Theorem translate_extract_correct:
  ∀prog pc_rel emap s1 s1' a v v1' e1' cs ns result.
    state_rel prog pc_rel emap s1 s1' ∧
    map (λci. signed_v_to_num (eval_const s1.globals ci)) cs = map Some ns ∧
    extract_value v ns = Some result ∧
    eval_exp s1' e1' v1' ∧
    v_rel v v1'
    ⇒
    ∃v2'.
      eval_exp s1' (foldl (λe c. Select e (translate_const c)) e1' cs) v2' ∧
      v_rel result v2'
Proof
  (* Induction on the list of index constants; the empty list is immediate *)
  Induct_on `cs` >> rw [] >> fs [extract_value_def]
  >- metis_tac [] >>
  (* Step case: apply the induction hypothesis to the remaining constants *)
  first_x_assum irule >>
  Cases_on `ns` >> fs [] >>
  qmatch_goalsub_rename_tac `translate_const c` >>
  (* The head constant c translates to an expression with a related value *)
  `?v2'. eval_exp s1' (translate_const c) v2' ∧ v_rel (eval_const s1.globals c) v2'`
  by metis_tac [translate_constant_correct] >>
  Cases_on `v` >> fs [extract_value_def] >>
  qpat_x_assum `v_rel (AggV _) _` mp_tac >>
  simp [Once v_rel_cases] >> rw [] >>
  simp [Once eval_exp_cases, PULL_EXISTS] >>
  fs [LIST_REL_EL_EQN] >>
  qmatch_assum_rename_tac `_ = map Some is` >>
  Cases_on `eval_const s1.globals c` >> fs [signed_v_to_num_def, signed_v_to_int_def] >> rw [] >>
  `?i. v2' = FlatV i` by fs [v_rel_cases] >> fs [] >>
  qmatch_assum_rename_tac `option_join _ = Some x` >>
  (* It suffices that the llair index value is the integer &x at some size *)
  `?size. i = IntV (&x) size` suffices_by metis_tac [] >> rw [] >>
  qpat_x_assum `v_rel _ _` mp_tac >>
  simp [v_rel_cases] >> rw [] >> fs [signed_v_to_int_def] >> rw [] >>
  intLib.COOPER_TAC
QED

(* Correctness of the translation of value insertion. Assume:
 *  - s1 and s1' are related;
 *  - the constants cs evaluate (via signed_v_to_num) to the indices ns;
 *  - insert_value v1 v2 ns gives result;
 *  - e1' evaluates in s1' to v1', related to v1;
 *  - e2' evaluates in s1' to v2', related to v2.
 * Then translate_updatevalue e1' e2' cs evaluates in s1' to a value related to
 * result.
 *
 * The quantified variables a and e2 do not occur in the body of the statement.
 *)
Theorem translate_update_correct:
  ∀prog pc_rel emap s1 s1' a v1 v1' v2 v2' e2 e2' e1' cs ns result.
    state_rel prog pc_rel emap s1 s1' ∧
    map (λci. signed_v_to_num (eval_const s1.globals ci)) cs = map Some ns ∧
    insert_value v1 v2 ns = Some result ∧
    eval_exp s1' e1' v1' ∧
    v_rel v1 v1' ∧
    eval_exp s1' e2' v2' ∧
    v_rel v2 v2'
    ⇒
    ∃v3'.
      eval_exp s1' (translate_updatevalue e1' e2' cs) v3' ∧
      v_rel result v3'
Proof
  (* Induction on the list of index constants; the empty list is immediate *)
  Induct_on `cs` >> rw [] >> fs [insert_value_def, translate_updatevalue_def]
  >- metis_tac [] >>
  simp [Once eval_exp_cases, PULL_EXISTS] >>
  Cases_on `ns` >> fs [] >>
  Cases_on `v1` >> fs [insert_value_def] >>
  rename [`insert_value (el x _) _ ns`] >>
  Cases_on `insert_value (el x l) v2 ns` >> fs [] >> rw [] >>
  qpat_x_assum `v_rel (AggV _) _` mp_tac >> simp [Once v_rel_cases] >> rw [] >>
  simp [v_rel_cases] >>
  qmatch_goalsub_rename_tac `translate_const c` >>
  qexists_tac `vs2` >> simp [] >>
  (* The head constant c translates to an expression with a related value *)
  `?v4'. eval_exp s1' (translate_const c) v4' ∧ v_rel (eval_const s1.globals c) v4'`
  by metis_tac [translate_constant_correct] >>
  (* ... and that value is the integer &x at some size *)
  `?idx_size. v4' = FlatV (IntV (&x) idx_size)`
  by (
    pop_assum mp_tac >> simp [Once v_rel_cases] >>
    rw [] >> fs [signed_v_to_num_def, signed_v_to_int_def] >>
    intLib.COOPER_TAC) >>
  (* Instantiate the induction hypothesis at the selected component *)
  first_x_assum drule >>
  disch_then drule >>
  disch_then drule >>
  disch_then (qspecl_then [`el x vs2`, `v2'`, `e2'`, `Select e1' (translate_const c)`] mp_tac) >>
  simp [Once eval_exp_cases] >>
  metis_tac [EVERY2_LUPDATE_same, LIST_REL_LENGTH, LIST_REL_EL_EQN]
QED

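(* NOTE(review): everything from here down to the closing comment marker just
 * before translate_trace_def is disabled. It contains
 * translate_instr_to_exp_correct and translate_instr_to_inst_correct (both
 * still using cheat), two small helper lemmas, and finally a loose tactic
 * fragment that is not attached to any theorem. TODO confirm what is still
 * needed before re-enabling any of it.
 *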
Theorem translate_instr_to_exp_correct:
  ∀emap instr r t s1 s1' s2 prog pc_rel.
    classify_instr instr = Exp r t ∧
    state_rel prog pc_rel emap s1 s1' ∧
    get_instr prog s1.ip instr ∧
    step_instr prog s1 instr s2 ⇒
    ∃v pv.
      eval_exp s1' (translate_instr_to_exp emap instr) v ∧
      flookup s2.locals r = Some pv ∧ v_rel pv.value v
Proof
  recInduct translate_instr_to_exp_ind >>
  simp [translate_instr_to_exp_def, classify_instr_def] >>
  conj_tac
  >- ( (* Sub *)
    rw [step_instr_cases, Once eval_exp_cases, do_sub_def, PULL_EXISTS] >>
    simp [llvmTheory.inc_pc_def, update_result_def, FLOOKUP_UPDATE] >>
    simp [v_rel_cases, PULL_EXISTS] >>
    first_x_assum (mp_then.mp_then mp_then.Any mp_tac translate_arg_correct) >>
    disch_then drule >>
    first_x_assum (mp_then.mp_then mp_then.Any mp_tac translate_arg_correct) >>
    disch_then drule >>
    drule get_instr_live >> simp [uses_def] >> strip_tac >>
    BasicProvers.EVERY_CASE_TAC >> fs [translate_ty_def, translate_size_def] >>
    rfs [v_rel_cases] >>
    pairarg_tac >> fs [] >>
    fs [pairTheory.PAIR_MAP, wordsTheory.FST_ADD_WITH_CARRY] >>
    qmatch_goalsub_abbrev_tac `eval_exp _ _ (FlatV (IntV i1 _))` >> strip_tac >>
    qmatch_goalsub_abbrev_tac `eval_exp _ _ (FlatV (IntV i2 _))` >> strip_tac >>
    qexists_tac `i1` >> qexists_tac `i2` >> simp [] >>
    unabbrev_all_tac >>
    rw []
    >- (
      irule restricted_i2w_11 >> simp [word_sub_i2w] >>
      `dimindex (:1) = 1` by rw [] >>
      drule truncate_2comp_i2w_w2i >>
      rw [word_sub_i2w] >>
      metis_tac [w2i_ge, w2i_le, SIMP_CONV (srw_ss()) [] ``INT_MIN (:1)``,
                 SIMP_CONV (srw_ss()) [] ``INT_MAX (:1)``])
    >- (
      irule restricted_i2w_11 >> simp [word_sub_i2w] >>
      `dimindex (:8) = 8` by rw [] >>
      drule truncate_2comp_i2w_w2i >>
      rw [word_sub_i2w] >>
      metis_tac [w2i_ge, w2i_le, SIMP_CONV (srw_ss()) [] ``INT_MIN (:8)``,
                 SIMP_CONV (srw_ss()) [] ``INT_MAX (:8)``])
    >- (
      irule restricted_i2w_11 >> simp [word_sub_i2w] >>
      `dimindex (:32) = 32` by rw [] >>
      drule truncate_2comp_i2w_w2i >>
      rw [word_sub_i2w] >>
      metis_tac [w2i_ge, w2i_le, SIMP_CONV (srw_ss()) [] ``INT_MIN (:32)``,
                 SIMP_CONV (srw_ss()) [] ``INT_MAX (:32)``])
    >- (
      irule restricted_i2w_11 >> simp [word_sub_i2w] >>
      `dimindex (:64) = 64` by rw [] >>
      drule truncate_2comp_i2w_w2i >>
      rw [word_sub_i2w] >>
      metis_tac [w2i_ge, w2i_le, SIMP_CONV (srw_ss()) [] ``INT_MIN (:64)``,
                 SIMP_CONV (srw_ss()) [] ``INT_MAX (:64)``])) >>
  conj_tac
  >- ( (* Extractvalue *)
    rw [step_instr_cases] >>
    simp [llvmTheory.inc_pc_def, update_result_def, FLOOKUP_UPDATE] >>
    metis_tac [uses_def, get_instr_live, translate_arg_correct, translate_extract_correct]) >>
  conj_tac
  >- ( (* Updatevalue *)
    rw [step_instr_cases] >>
    simp [llvmTheory.inc_pc_def, update_result_def, FLOOKUP_UPDATE] >>
    drule get_instr_live >> simp [uses_def] >>
    metis_tac [get_instr_live, translate_arg_correct, translate_update_correct]) >>
  cheat
QED

Triviality eval_exp_help:
  (s1 with heap := h).locals = s1.locals
Proof
  rw []
QED

Theorem erase_tags_set_bytes:
  ∀p v l h. erase_tags (set_bytes p v l h) = set_bytes () v l (erase_tags h)
Proof
  Induct_on `v` >> rw [set_bytes_def] >>
  irule (METIS_PROVE [] ``x = y ⇒ f a b c x = f a b c y``) >>
  rw [erase_tags_def]
QED

Theorem translate_instr_to_inst_correct:
  ∀prog pc_rel emap instr s1 s1' s2.
    classify_instr instr = Non_exp ∧
    state_rel prog pc_rel emap s1 s1' ∧
    get_instr prog s1.ip instr ∧
    step_instr prog s1 instr s2 ⇒
    ∃s2'.
      step_inst s1' (translate_instr_to_inst emap instr) s2' ∧
      state_rel prog pc_rel emap s2 s2'

Proof

  rw [step_instr_cases] >>
  fs [classify_instr_def, translate_instr_to_inst_def]
  >- ( (* Load *)
    cheat)
  >- ( (* Store *)
    simp [step_inst_cases, PULL_EXISTS] >>
    drule get_instr_live >> rw [uses_def] >>
    drule translate_arg_correct >> disch_then drule >> disch_then drule >>
    qpat_x_assum `eval _ _ = Some _` mp_tac >>
    drule translate_arg_correct >> disch_then drule >> disch_then drule >>
    rw [] >>
    qpat_x_assum `v_rel (FlatV _) _` mp_tac >> simp [Once v_rel_cases] >> rw [] >>
    HINT_EXISTS_TAC >> rw [] >>
    qexists_tac `freeable` >> rw [] >>
    HINT_EXISTS_TAC >> rw []
    >- metis_tac [v_rel_bytes]
    >- (
      fs [w2n_i2n, pointer_size_def] >>
      metis_tac [v_rel_bytes, is_allocated_state_rel, ADD_COMM]) >>
    fs [state_rel_def] >>
    rw []
    >- cheat
    >- (
      fs [llvmTheory.inc_pc_def] >>
      `r ∈ live prog s1.ip`
      by (
        drule live_gen_kill >>
        rw [next_ips_def, assigns_def, uses_def, inc_pc_def]) >>
      first_x_assum drule >> rw [] >>
      metis_tac [eval_exp_ignores, eval_exp_help])
    >- (
      rw [llvmTheory.inc_pc_def, w2n_i2n, pointer_size_def, erase_tags_set_bytes] >>
      metis_tac[v_rel_bytes]))
  >- cheat
  >- cheat
  >- cheat
QED


    simp [step_inst_cases, PULL_EXISTS] >>
    Cases_on `r` >> simp [translate_reg_def] >>
    drule get_instr_live >> rw [uses_def] >>
    drule translate_arg_correct >> disch_then drule >> disch_then drule >>
    simp [Once v_rel_cases] >> rw [] >>
    qexists_tac `IntV (w2i w) pointer_size` >> rw [] >>
    qexists_tac `freeable` >> rw []
    >- (fs [w2n_i2n, pointer_size_def] >> metis_tac [is_allocated_state_rel]) >>
    fs [state_rel_def] >> rw []
    >- cheat
    >- (
      fs [llvmTheory.inc_pc_def, update_results_def, update_result_def] >>
      rw [] >> fs [FLOOKUP_UPDATE] >> rw []
      >- (
        cheat)
      >- (
        `r ∈ live prog s1.ip`
        by (
          drule live_gen_kill >>
          rw [next_ips_def, assigns_def, uses_def, inc_pc_def]) >>
        first_x_assum drule >> rw [] >>
        qexists_tac `v` >>
        qexists_tac `v'` >>
        qexists_tac `e` >>
        rw [] >>
        metis_tac [eval_exp_ignores, eval_exp_help])


    >- fs [update_results_def, llvmTheory.inc_pc_def, update_result_def]

*)

(* Translate one trace label. Tau is kept as Tau. A W label on a global
 * variable keeps its bytes and has the variable translated by
 * translate_glob_var, using the type that the types function assigns to it. *)
Definition translate_trace_def:
  translate_trace types Tau = Tau ∧
  translate_trace types (W gv bytes) =
    W (translate_glob_var gv (types gv)) bytes
End

(* One-step simulation: if s1 and s1' are related and the LLVM program makes a
 * multi_step from s1 to s2 with trace tr, then the translated program has a
 * block b at the block pointer of s1' whose commands and terminator step s1'
 * to some s2' with the labelwise translated trace, and s2 and s2' are related.
 *
 * pc_rel, emap and types are free variables of the statement.
 *
 * NOTE: the proof is a cheat, so this statement is currently assumed rather
 * than proved.
 *)
Theorem multi_step_to_step_block:
  ∀prog s1 s1' tr s2.
    state_rel prog pc_rel emap s1 s1' ∧
    multi_step prog s1 tr s2
    ⇒
    ∃s2' b.
      get_block (translate_prog prog) s1'.bp b ∧
      step_block (translate_prog prog) s1' b.cmnd (map (translate_trace types) tr) b.term s2' ∧
      state_rel prog pc_rel emap s2 s2'
Proof
  cheat
QED

(* Translating a label preserves whether it is Tau: the predicate
 * that a label differs from Tau is unchanged when composed with
 * translate_trace types. *)
Theorem trans_trace_not_tau:
  ∀types. (λx. x ≠ Tau) ∘ translate_trace types = (λx. x ≠ Tau)
Proof
  rw [FUN_EQ_THM] >> eq_tac >> rw [translate_trace_def] >>
  Cases_on `x` >> fs [translate_trace_def]
QED

(* Lifts multi_step_to_step_block to finite paths. For a finite path of LLVM
 * multi_steps whose first state is related to s1', there is a finite path of
 * steps of the translated program that starts at s1', whose labels are the
 * labelwise translation of the LLVM labels, and whose last state is related to
 * the last LLVM state.
 *
 * Depends on multi_step_to_step_block, whose proof is a cheat.
 *)
Theorem translate_prog_correct_lem1:
  ∀path.
    okpath (multi_step prog) path ∧ finite path
    ⇒
    ∀s1'.
    state_rel prog pc_rel emap (first path) s1'
    ⇒
    ∃path'.
      finite path' ∧
      okpath (step (translate_prog prog)) path' ∧
      first path' = s1' ∧
      labels path' = LMAP (map (translate_trace types)) (labels path) ∧
      state_rel prog pc_rel emap (last path) (last path')
Proof
  (* Induction over finite ok paths *)
  ho_match_mp_tac finite_okpath_ind >> rw []
  (* Single-state path: the llair path is just s1' *)
  >- (qexists_tac `stopped_at s1'` >> rw []) >>
  (* Extended path: simulate the first multi_step, then use the induction
   * hypothesis for the rest and prepend the simulated step *)
  drule multi_step_to_step_block >> disch_then drule >>
  disch_then (qspec_then `types` mp_tac) >> rw [] >>
  first_x_assum drule >> rw [] >>
  qexists_tac `pcons s1' (map (translate_trace types) r) path'` >> rw [] >>
  simp [step_cases] >> qexists_tac `b` >> simp [] >>
  fs [state_rel_def] >> simp [get_observation_def] >>
  fs [Once multi_step_cases, last_step_def] >> rw [] >>
  metis_tac [get_instr_func, exit_no_step]
QED

(* Main statement of the file. From related states s1 and s1', the semantics of
 * the translated program from s1' equals the image of the LLVM multi-step
 * semantics from s1 under the map that leaves the first component alone and
 * translates every label of the trace with translate_trace types.
 *
 * NOTE: only the inclusion of the image in sem is proved here; the converse
 * inclusion is closed with cheat. The proved direction also relies on
 * translate_prog_correct_lem1.
 *)
Theorem translate_prog_correct:
  ∀prog s1 s1'.
    state_rel prog pc_rel emap s1 s1'
    ⇒
    image (I ## map (translate_trace types)) (multi_step_sem prog s1) = sem (translate_prog prog) s1'
Proof
  rw [sem_def, multi_step_sem_def, EXTENSION] >> eq_tac >> rw []
  >- (
    (* Obtain the matching llair path and use it as the witness *)
    drule translate_prog_correct_lem1 >> disch_then drule >> disch_then drule >>
    disch_then (qspec_then `types` mp_tac) >> rw [] >>
    qexists_tac `path'` >> rw [] >>
    fs [IN_DEF, observation_prefixes_cases, toList_some] >> rw [] >>
    rfs [lmap_fromList] >>
    rw [GSYM MAP_FLAT, FILTER_MAP, trans_trace_not_tau]
    >- fs [state_rel_def]
    >- fs [state_rel_def] >>
    qexists_tac `map (translate_trace types) l2'` >>
    simp [GSYM MAP_FLAT, FILTER_MAP, trans_trace_not_tau] >>
    (* translate_trace types is injective on the labels involved, so equalities
     * and prefix facts about mapped lists transfer back to the original lists *)
    `INJ (translate_trace types) (set l2' ∪ set (flat l2)) UNIV`
    by (
      simp [INJ_DEF] >> rpt gen_tac >>
      Cases_on `x` >> Cases_on `y` >> simp [translate_trace_def] >>
      Cases_on `a` >> Cases_on `a'` >> simp [translate_glob_var_def]) >>
    fs [INJ_MAP_EQ_IFF, inj_map_prefix_iff] >> rw [] >>
    fs [state_rel_def])
  (* Converse inclusion: not yet proved *)
  >- cheat
QED

(* Finish the theory started by new_theory above and export it. *)
export_theory ();