File ‹~~/src/Tools/eqsubst.ML›
(* Interface of the single-step equational substitution tactics and of the
   term-search functions they are built on. *)
signature EQSUBST =
sig
  (* one way of rewriting at one position of the searched term *)
  type match =
    ((indexname * (sort * typ)) list (* type instantiations *)
      * (indexname * (typ * term)) list) (* term instantiations *)
    * (string * typ) list (* bound-variable environment under fake (":b_"-prefixed) names *)
    * (string * typ) list (* bound-variable environment under the fresh names *)
    * term (* abstraction built around the focus position *)
  (* what a search function needs in order to look for matches *)
  type searchinfo =
    Proof.context
    * int (* maximal variable index *)
    * Zipper.T (* focus term to search under *)
  (* result of skipping occurrences: either the sequence ran out while the given
     number of occurrences was still wanted, or the remaining sequence *)
  datatype 'a skipseq = SkipMore of int | SkipSeq of 'a Seq.seq Seq.seq
  val skip_first_asm_occs_search: ('a -> 'b -> 'c Seq.seq Seq.seq) -> 'a -> int -> 'b -> 'c skipseq
  val skip_first_occs_search: int -> ('a -> 'b -> 'c Seq.seq Seq.seq) -> 'a -> 'b -> 'c Seq.seq
  val skipto_skipseq: int -> 'a Seq.seq Seq.seq -> 'a skipseq
  (* tactics: the primed versions take a single theorem and an explicit search function *)
  val eqsubst_asm_tac: Proof.context -> int list -> thm list -> int -> tactic
  val eqsubst_asm_tac': Proof.context ->
    (searchinfo -> int -> term -> match skipseq) -> int -> thm -> int -> tactic
  val eqsubst_tac: Proof.context ->
    int list -> (* occurrence numbers *)
    thm list -> int -> tactic
  val eqsubst_tac': Proof.context ->
    (searchinfo -> term -> match Seq.seq) (* search function *)
    -> thm (* equation to substitute with *)
    -> int (* subgoal number *)
    -> thm (* proof state *)
    -> thm Seq.seq (* resulting proof states *)
  (* search functions *)
  val valid_match_start: Zipper.T -> bool
  val search_lr_all: Zipper.T -> Zipper.T Seq.seq
  val search_lr_valid: (Zipper.T -> bool) -> Zipper.T -> Zipper.T Seq.seq
  val searchf_lr_unify_all: searchinfo -> term -> match Seq.seq Seq.seq
  val searchf_lr_unify_valid: searchinfo -> term -> match Seq.seq Seq.seq
  val searchf_bt_unify_valid: searchinfo -> term -> match Seq.seq Seq.seq
end;
structure EqSubst: EQSUBST =
struct
(* Turn a theorem into the list of rules produced by Simplifier.mksimps, each with
   its schematic variable indexes zeroed. Callers in this file treat the conclusion
   of every resulting rule as a meta-equation. *)
fun prep_meta_eq ctxt th =
  map Drule.zero_var_indexes (Simplifier.mksimps ctxt th);
(* Generalize th over the given (certified) free variables and then eliminate each
   introduced quantifier again with Thm.forall_elim_var at index 0, so the former
   frees end up as schematic variables. *)
fun unfix_frees frees th =
  th
  |> Drule.forall_intr_list frees
  |> fold (fn _ => Thm.forall_elim_var 0) frees;
(* one way of rewriting at one position of the searched term *)
type match =
  ((indexname * (sort * typ)) list (* type instantiations *)
   * (indexname * (typ * term)) list) (* term instantiations *)
  * (string * typ) list (* bound-variable environment under fake (":b_"-prefixed) names *)
  * (string * typ) list (* bound-variable environment under the fresh names *)
  * term; (* abstraction built around the focus position *)
(* what a search function needs in order to look for matches *)
type searchinfo =
  Proof.context
  * int (* maximal variable index *)
  * Zipper.T; (* focus term to search under *)
(* Result of skipping occurrences in a sequence of match sequences:
   SkipMore n -- the sequence ran out while n occurrences were still wanted;
   SkipSeq ss -- the sequence from the wanted occurrence onwards. *)
datatype 'a skipseq =
  SkipMore of int |
  SkipSeq of 'a Seq.seq Seq.seq;
(* Advance s to its m-th non-empty inner sequence. Empty inner sequences are
   dropped without being counted; any m <= 1 selects the first non-empty one.
   Returns SkipSeq (that inner sequence followed by the rest of s), or
   SkipMore n when s is exhausted with n occurrences still wanted. *)
fun skipto_skipseq m s =
  let
    fun go wanted outer =
      (case Seq.pull outer of
        NONE => SkipMore wanted
      | SOME (inner, rest) =>
          let
            val inner_empty = (case Seq.pull inner of NONE => true | SOME _ => false)
          in
            if inner_empty then go wanted rest
            else if wanted > 1 then go (wanted - 1) rest
            else SkipSeq (Seq.cons inner rest)
          end)
  in go m s end;
(* Build the abstraction
     Abs ("fooabs", T, mkuptermfunc (Bound n $ Bound (n - 1) $ ... $ Bound 0))
   where n = length Ts and T is the function type whose argument types are the
   types of Ts in reverse order and whose result type is the type of t. *)
fun mk_foo_match mkuptermfunc Ts t =
  let
    val nbounds = length Ts
    val abs_ty = rev (map snd Ts) ---> Term.type_of t
    (* apply head successively to Bound (k - 1), ..., Bound 0 *)
    fun app_bounds 0 head = head
      | app_bounds k head = app_bounds (k - 1) (head $ Bound (k - 1))
    val applied = app_bounds nbounds (Bound nbounds)
  in Abs ("fooabs", abs_ty, mkuptermfunc applied) end;
(* name of the fake free variable standing in for the bound variable n *)
fun mk_fake_bound_name n = String.concat [":b_", n];
(* Replace the loose bound variables of t by free variables with fake names.
   Ts is processed from its right end (fold_rev); each name is replaced by a
   variant distinct from the names already chosen. Returns
     (environment under fake names, environment under the fresh names,
      t with its loose bounds substituted by the fake frees). *)
fun fakefree_badbounds Ts t =
  let
    fun add_bound (name, ty) (fakes, reals, used) =
      let val fresh = singleton (Name.variant_list used) name in
        ((mk_fake_bound_name fresh, ty) :: fakes, (fresh, ty) :: reals, fresh :: used)
      end
    val (FakeTs, Ts', _) = fold_rev add_bound Ts ([], [], [])
  in (FakeTs, Ts', Term.subst_bounds (map Free FakeTs, t)) end;
(* Prepare the focus of zipper z for unification: its loose bounds are replaced by
   fake frees, and the zipper context is turned into an abstraction around the
   focus. Returns (faked focus term, (fake env, fresh-name env, abstraction)). *)
fun prep_zipper_match z =
  let
    val zctxt = Zipper.ctxt z
    val (FakeTs, Ts, focus) =
      fakefree_badbounds (Zipper.C.nty_ctxt zctxt) (Zipper.trm z)
    val absterm = mk_foo_match (Zipper.C.apply zctxt) Ts focus
  in (focus, (FakeTs, Ts, absterm)) end;
(* Unify pat with tgt, turning the unification failures handled below into an
   empty result instead of an exception.
   ctxt: proof context; ix: maximal variable index to start from;
   (pat, tgt): the pair of terms to unify.
   Returns a lazy sequence of (type instantiations, term instantiations), one
   element per unifier produced by Unify.smash_unifiers. *)
fun clean_unify ctxt ix (a as (pat, tgt)) =
  let
    (* the types of both terms are recomputed on every call *)
    val pat_ty = Term.type_of pat;
    val tgt_ty = Term.type_of tgt;
    (* unify the types first; NONE if they do not unify *)
    val typs_unify =
      SOME (Sign.typ_unify (Proof_Context.theory_of ctxt) (pat_ty, tgt_ty) (Vartab.empty, ix))
        handle Type.TUNIFY => NONE;
  in
    (case typs_unify of
      SOME (typinsttab, ix2) =>
        let
          (* extract the type and term instantiations of a unifier as lists *)
          fun mk_insts env =
            (Vartab.dest (Envir.type_env env),
             Vartab.dest (Envir.term_env env));
          (* start term unification from the type unifier found above *)
          val initenv =
            Envir.Envir {maxidx = ix2, tenv = Vartab.empty, tyenv = typinsttab};
          val useq = Unify.smash_unifiers (Context.Proof ctxt) [a] initenv
            handle ListPair.UnequalLengths => Seq.empty
              | Term.TERM _ => Seq.empty;
          (* pull unifiers one at a time; an exception raised while pulling
             ends the sequence at that point *)
          fun clean_unify' useq () =
            (case (Seq.pull useq) of
               NONE => NONE
             | SOME (h, t) => SOME (mk_insts h, Seq.make (clean_unify' t)))
            handle ListPair.UnequalLengths => NONE
              | Term.TERM _ => NONE;
        in
          (Seq.make (clean_unify' useq))
        end
    | NONE => Seq.empty)
  end;
(* Unify the (bound-faked) focus of zipper z with pat. Every unifier is paired
   with the environments and the abstraction obtained from prep_zipper_match,
   giving a sequence of matches. *)
fun clean_unify_z ctxt maxidx pat z =
  let
    val (focus, (FakeTs, Ts, absterm)) = prep_zipper_match z
    fun with_env insts = (insts, FakeTs, Ts, absterm)
  in clean_unify ctxt maxidx (focus, pat) |> Seq.map with_env end;
(* Descend through function positions of applications and through abstraction
   bodies until a term that is neither is reached; return that term. *)
fun bot_left_leaf_of t =
  (case t of
    f $ _ => bot_left_leaf_of f
  | Abs (_, _, body) => bot_left_leaf_of body
  | leaf => leaf);
(* A position is a valid place to start matching unless the bottom-left leaf of
   its term is a schematic variable. *)
fun valid_match_start z =
  let val leaf = bot_left_leaf_of (Zipper.trm z)
  in (case leaf of Var _ => false | _ => true) end;
(* enumerate all positions below z, in the order implemented by ZipperSearch.all_bl_ur *)
fun search_lr_all z = ZipperSearch.all_bl_ur z;
(* Lazy search for the positions accepted by validf. At an application the search
   steps are: function part, then the node itself, then the argument part. At an
   abstraction: the node itself, then its body. *)
fun search_lr_valid validf =
  let
    fun steps z =
      let
        val here = if validf z then [Zipper.Here z] else []
        fun look move = Zipper.LookIn (move z)
      in
        (case Zipper.trm z of
          _ $ _ => look Zipper.move_down_left :: here @ [look Zipper.move_down_right]
        | Abs _ => here @ [look Zipper.move_down_abs]
        | _ => here)
      end
  in Zipper.lzy_search steps end;
(* Lazy search for the positions accepted by validf, listing sub-positions before
   the node itself. At an application: function part, argument part, then the
   node. At an abstraction: the body, then the node. *)
fun search_bt_valid validf =
  let
    fun steps z =
      let
        val here = if validf z then [Zipper.Here z] else []
        fun look move = Zipper.LookIn (move z)
      in
        (case Zipper.trm z of
          _ $ _ => look Zipper.move_down_left :: look Zipper.move_down_right :: here
        | Abs _ => look Zipper.move_down_abs :: here
        | _ => here)
      end
  in Zipper.lzy_search steps end;
(* Generic search function: enumerate positions with f (applied to z through
   Zipper.limit_apply) and, at each position, unify with lhs. The result holds
   one sequence of matches per enumerated position. *)
fun searchf_unify_gen f (ctxt, maxidx, z) lhs =
  Zipper.limit_apply f z
  |> Seq.map (clean_unify_z ctxt maxidx lhs);
(* Search functions combining a position enumeration with unification against
   the left-hand side; each yields one sequence of matches per position. *)
(* every position enumerated by search_lr_all *)
val searchf_lr_unify_all = searchf_unify_gen search_lr_all;
(* only positions accepted by valid_match_start, in search_lr_valid order *)
val searchf_lr_unify_valid = searchf_unify_gen (search_lr_valid valid_match_start);
(* only positions accepted by valid_match_start, in search_bt_valid order *)
val searchf_bt_unify_valid = searchf_unify_gen (search_bt_valid valid_match_start);
(* Apply one match m of rule to the conclusion: instantiate and rewrite conclthm
   with RW_Inst.rw, turn the fixed frees cfvs back into schematic variables,
   beta-eta normalize, and resolve the result against subgoal i of st. *)
fun apply_subst_in_concl ctxt i st (cfvs, conclthm) rule m =
  let
    val rewritten = RW_Inst.rw ctxt m rule conclthm
    val normalized =
      Conv.fconv_rule Drule.beta_eta_conversion (unfix_frees cfvs rewritten)
  in resolve_tac ctxt [normalized] i st end;
(* Prepare substitution in the conclusion of subgoal i of proof state gth.
   Returns ((certified fixed variables, trivial theorem over the conclusion),
            (ctxt, maximal index, zipper focused inside that theorem)).
   The zipper is obtained from the proposition of the trivial theorem by moving
   down-left and then down-right from the top. *)
fun prep_concl_subst ctxt i gth =
  let
    (* indexes are raised by 1 before the maximal index is read off *)
    val th = Thm.incr_indexes 1 gth;
    val (fixedbody, fvs) = IsaND.fix_alls_term ctxt i (Thm.prop_of th);
    val cfvs = rev (map (Thm.cterm_of ctxt) fvs);
    val conclthm =
      Thm.trivial (Thm.cterm_of ctxt (Logic.strip_imp_concl fixedbody));
    val ft =
      Thm.prop_of conclthm
      |> Zipper.mktop
      |> Zipper.move_down_left
      |> Zipper.move_down_right;
  in ((cfvs, conclthm), (ctxt, Thm.maxidx_of th, ft)) end;
(* Substitute in the conclusion of subgoal i of st using instepthm. The theorem is
   first split into rules by prep_meta_eq; for each rule, searchf looks for matches
   of its left-hand side and every match yields the resulting proof states. *)
fun eqsubst_tac' ctxt searchf instepthm i st =
  let
    val (concl_info, searchinfo) = prep_concl_subst ctxt i st;
    fun rewrite_with eq =
      let val (lhs, _) = Logic.dest_equals (Thm.concl_of eq) in
        Seq.maps (apply_subst_in_concl ctxt i st concl_info eq) (searchf searchinfo lhs)
      end;
  in Seq.maps rewrite_with (Seq.of_list (prep_meta_eq ctxt instepthm)) end;
(* Run srchf and drop the matches before occurrence occ. The remaining match
   sequences are flattened into one; the result is empty when there are fewer
   than occ occurrences. *)
fun skip_first_occs_search occ srchf sinfo lhs =
  let val found = srchf sinfo lhs in
    (case skipto_skipseq occ found of
      SkipSeq ss => Seq.flat ss
    | SkipMore _ => Seq.empty)
  end;
(* Substitution in the conclusion of a subgoal using any of thms, at the
   occurrences listed in occs. Occurrences are handled in descending order, each
   one on the last subgoal of the state restricted by SELECT_GOAL; finally
   distinct_subgoals_tac is applied to every resulting state. *)
fun eqsubst_tac ctxt occs thms =
  let
    val rules = Seq.of_list thms;
    fun subst_occ occ st =
      let
        val searchf = skip_first_occs_search occ searchf_lr_unify_valid
        val last_goal = Thm.nprems_of st
      in Seq.maps (fn r => eqsubst_tac' ctxt searchf r last_goal st) rules end;
    val occs_desc = Library.sort (rev_order o int_ord) occs;
    val subst_all = Seq.EVERY (map subst_occ occs_desc);
  in SELECT_GOAL (fn st => Seq.maps distinct_subgoals_tac (subst_all st)) end;
(* Apply one match m of rule to assumption number j of subgoal i of st.
   (cfvs, j, _, pth) is the assumption information built by prep_subst_in_asm:
   certified fixed variables, assumption number, (unused here) number of premises
   of the assumption, and the trivial theorem over the assumption.
   Returns the sequence of resulting proof states. *)
fun apply_subst_in_asm ctxt i st rule ((cfvs, j, _, pth),m) =
  let
    (* rotate the assumptions of subgoal i by j - 1; presumably this brings
       assumption j to the front -- confirm against Thm.rotate_rule *)
    val st2 = Thm.rotate_rule (j - 1) i st; 
    val preelimrule =
      RW_Inst.rw ctxt m rule pth
      |> (Seq.hd o prune_params_tac ctxt) (* first result of pruning parameters *)
      |> Thm.permute_prems 0 ~1 (* rotate the premises of the rewritten theorem by ~1 *)
      |> unfix_frees cfvs (* fixed frees back to schematic variables *)
      |> Conv.fconv_rule Drule.beta_eta_conversion; (* beta-eta normal form *)
  in
    (* after the destruct-resolution step, rotate the assumptions of the goal at
       position (number of premises of rule + i) by ~j; presumably this moves the
       rewritten assumption back to its original place -- TODO confirm *)
    Seq.map (Thm.rotate_rule (~ j) (Thm.nprems_of rule + i))
      (dresolve_tac ctxt [preelimrule] i st2)
  end;
(* Prepare substitution in assumption number j (counted from 1) of subgoal i of
   proof state gth. Returns
     ((certified fixed variables, j, number of premises of the assumption,
       trivial theorem over the assumption),
      (ctxt, maximal index, zipper focused inside that theorem)).
   The zipper is obtained from the proposition of the trivial theorem by moving
   down-right from the top. *)
fun prep_subst_in_asm ctxt i gth j =
  let
    (* indexes are raised by 1 before the maximal index is read off *)
    val th = Thm.incr_indexes 1 gth;
    val (fixedbody, fvs) = IsaND.fix_alls_term ctxt i (Thm.prop_of th);
    val cfvs = rev (map (Thm.cterm_of ctxt) fvs);
    val asmt = nth (Logic.strip_imp_prems fixedbody) (j - 1);
    val pth = Thm.trivial (Thm.cterm_of ctxt asmt);
    val asm_nprems = length (Logic.strip_imp_prems asmt);
    val ft =
      Thm.prop_of pth
      |> Zipper.mktop
      |> Zipper.move_down_right;
  in ((cfvs, j, asm_nprems, pth), (ctxt, Thm.maxidx_of th, ft)) end;
(* prepare substitution for every assumption of subgoal i of gth, numbered from 1 *)
fun prep_subst_in_asms ctxt i gth =
  let
    val nasms = length (Logic.prems_of_goal (Thm.prop_of gth) i)
  in map (prep_subst_in_asm ctxt i gth) (Library.upto (1, nasms)) end;
(* Substitute in the assumptions of subgoal i of st using instepthm, starting at
   occurrence skipocc. Occurrences are counted across the assumptions in order:
   while an assumption has too few matches, the remaining count is carried over to
   the next one. Once the wanted occurrence is found, all matches from there on,
   including those in later assumptions, are offered as alternatives. *)
fun eqsubst_asm_tac' ctxt searchf skipocc instepthm i st =
  let
    val asmpreps = prep_subst_in_asms ctxt i st;
    val rules = Seq.of_list (prep_meta_eq ctxt instepthm);
    fun rewrite_with eq =
      let
        val (lhs, _) = Logic.dest_equals (Thm.concl_of eq);
        fun occ_search _ [] = Seq.empty
          | occ_search occ ((asminfo, searchinfo) :: rest) =
              (case searchf searchinfo occ lhs of
                SkipMore remaining => occ_search remaining rest
              | SkipSeq ss =>
                  (* later assumptions are searched from their first occurrence *)
                  Seq.append (Seq.map (Library.pair asminfo) (Seq.flat ss))
                    (occ_search 1 rest))
      in Seq.maps (apply_subst_in_asm ctxt i st eq) (occ_search skipocc asmpreps) end;
  in Seq.maps rewrite_with rules end;
(* Run searchf and skip to occurrence occ, keeping the skipseq result so that a
   shortfall (SkipMore) can be carried over to another search. *)
fun skip_first_asm_occs_search searchf sinfo occ lhs =
  searchf sinfo lhs |> skipto_skipseq occ;
(* Substitution in the assumptions of a subgoal using any of thms, at the
   occurrences listed in occs. Occurrences are handled in descending order, each
   one on the last subgoal of the state restricted by SELECT_GOAL; finally
   distinct_subgoals_tac is applied to every resulting state. *)
fun eqsubst_asm_tac ctxt occs thms =
  let
    val rules = Seq.of_list thms;
    val searchf = skip_first_asm_occs_search searchf_lr_unify_valid;
    fun subst_occ occ st =
      let val last_goal = Thm.nprems_of st
      in Seq.maps (fn r => eqsubst_asm_tac' ctxt searchf occ r last_goal st) rules end;
    val occs_desc = Library.sort (rev_order o int_ord) occs;
    val subst_all = Seq.EVERY (map subst_occ occs_desc);
  in SELECT_GOAL (fn st => Seq.maps distinct_subgoals_tac (subst_all st)) end;
(* Register the proof method "subst". Its arguments are: an optional mode "asm"
   (substitute in the assumptions instead of the conclusion), an optional
   parenthesized list of natural numbers giving the occurrences (default [0]),
   and the theorems to substitute with. *)
val _ =
  Theory.setup
    (Method.setup \<^binding>‹subst›
      (Scan.lift (Args.mode "asm" -- Scan.optional (Args.parens (Scan.repeat Parse.nat)) [0]) --
        Attrib.thms >> (fn ((asm, occs), inthms) => fn ctxt =>
          SIMPLE_METHOD' ((if asm then eqsubst_asm_tac else eqsubst_tac) ctxt occs inthms)))
      "single-step substitution");
end;