File ‹~~/src/Tools/IsaPlanner/rw_inst.ML›
signature RW_INST =
sig
  (* rw ctxt ((typinsts, insts), FakeTs, Ts, outerterm) rule target_thm
     rewrites target_thm with the meta-equality rule, inside the context
     outerterm, after applying the given instantiations.
       typinsts  - instantiations of schematic type variables:
                   (indexname, (sort, instance type));
       insts     - instantiations of schematic term variables:
                   (indexname, (type, instance term));
       FakeTs    - names of temporary free variables occurring in the rule;
                   each is renamed using the corresponding entry of Ts
                   (the type component of FakeTs is not used for this);
       Ts        - name hints and types for those frees, presumably the
                   bound variables above the redex (confirm with callers);
                   zipped pairwise with FakeTs, so both lists need the
                   same length;
       outerterm - the context of the redex: it is applied to both
                   lambda-abstracted sides of the rule.
     The premises of the rule, quantified over the Ts variables, become
     premises of the resulting theorem. *)
  val rw: Proof.context ->
    ((indexname * (sort * typ)) list * 
     (indexname * (typ * term)) list) 
    * (string * typ) list 
    * (string * typ) list 
    * term -> 
    thm -> 
    thm -> 
    thm  
end;
structure RW_Inst: RW_INST =
struct
(* allify_conditions ctxt Ts th: universally quantifies every premise of th
   over the free variables Ts, assumes each quantified premise, and uses
   those assumptions to discharge the premises of th one at a time with
   COMP.
   Returns (th', asms): th' carries the quantified premises as hypotheses
   instead of premises, and asms are the certified quantified premises, so
   a caller can re-introduce them later (rw does so with
   Drule.implies_intr_list). *)
fun allify_conditions ctxt Ts th =
  let
    (* meta-level universal quantification of Free (x, T) in t *)
    fun allify (x, T) t =
      Logic.all_const T $ Abs (x, T, Term.abstract_over (Free (x, T), t));
    val cTs = map (Thm.cterm_of ctxt o Free) Ts;
    (* fold_rev: the first element of Ts becomes the outermost binder *)
    val cterm_asms = map (Thm.cterm_of ctxt o fold_rev allify Ts) (Thm.prems_of th);
    (* assume each quantified premise, then eliminate its binders again with
       the original frees (outermost first, matching the order above),
       giving a theorem  asm |- premise  for every premise *)
    val allifyied_asm_thms = map (Drule.forall_elim_list cTs o Thm.assume) cterm_asms;
    (* fold applies  a COMP acc  for each assumption theorem a in order *)
  in (fold (curry op COMP) allifyied_asm_thms th, cterm_asms) end;
(* mk_abstractedrule ctxt TsFake Ts rule: renames the temporary frees named
   in TsFake to fresh variants of the names in Ts (the types are taken from
   Ts), removes the premises with allify_conditions, and lambda-abstracts
   the resulting unconditional rule over the renamed frees with
   Thm.abstract_rule (so rule is presumably a meta-equality).
   Returns (cprems, abstract_rule): the quantified premises that became
   hypotheses, and the abstracted rule. *)
fun mk_abstractedrule ctxt TsFake Ts rule =
  let
    (* fresh names based on the names in Ts, chosen with respect to the rule
       and its hypotheses; presumably IsaND.variant_names avoids clashes
       with names already used there -- confirm *)
    val ns =
      IsaND.variant_names ctxt (Thm.full_prop_of rule :: Thm.hyps_of rule) (map fst Ts);
    (* walk TsFake, Ts and ns in lockstep (~~ needs equal lengths).  All
       three accumulators are built with ::, so they come out in REVERSE
       order.  The type of both the old and the new free comes from Ts. *)
    val (fromnames, tonames, Ts') =
      fold (fn (((faken, _), (n, ty)), n2) => fn (rnf, rnt, Ts'') =>
              (Thm.cterm_of ctxt (Free(faken,ty)) :: rnf,
               Thm.cterm_of ctxt (Free(n2,ty)) :: rnt,
               (n2,ty) :: Ts''))
            (TsFake ~~ Ts ~~ ns) ([], [], []);
    (* rename the fake frees: generalize over them, then instantiate the
       binders with the freshly named frees *)
    val rule' = rule
      |> Drule.forall_intr_list fromnames
      |> Drule.forall_elim_list tonames;
    (* rev Ts' restores the original order of Ts *)
    val (uncond_rule, cprems) = allify_conditions ctxt (rev Ts') rule';
    (* pairs in original order; fold abstracts over the first pair first,
       so the first renamed free ends up as the innermost lambda *)
    val abstractions = rev (Ts' ~~ tonames);
    val abstract_rule =
      fold (fn ((n, ty), ct) => Thm.abstract_rule n ct)
        abstractions uncond_rule;
  in (cprems, abstract_rule) end;
(* mk_renamings ctxt tgt rule_inst: pairs every schematic variable occurring
   in tgt or in a premise of rule_inst with a new Free of the same type.
   The new names come from IsaND.variant_names, applied to the base names
   of the variables with respect to tgt and the premises.
   Returns a list of ((indexname, typ), Free) pairs. *)
fun mk_renamings ctxt tgt rule_inst =
  let
    val rule_conds = Thm.prems_of rule_inst;
    (* NOTE(review): only the term variables (second component) are used;
       the accumulated type variables are computed and then discarded *)
    val (_, cond_vs) =
      fold (fn t => fn (tyvs, vs) =>
        (union (op =) (Misc_Legacy.term_tvars t) tyvs,
         union (op =) (map Term.dest_Var (Misc_Legacy.term_vars t)) vs)) rule_conds ([], []);
    val termvars = map Term.dest_Var (Misc_Legacy.term_vars tgt);
    (* merge target variables and premise variables without duplicates *)
    val vars_to_fix = union (op =) termvars cond_vs;
    (* only the base name of each indexname is used as the name hint *)
    val ys = IsaND.variant_names ctxt (tgt :: rule_conds) (map (fst o fst) vars_to_fix);
  in map2 (fn (xi, T) => fn y => ((xi, T), Free (y, T))) vars_to_fix ys end;
(* new_tfree (ix, sort) (pairs, used): invents a TFree for the schematic
   type variable ix, named by a variant of string_of_indexname ix that is
   not in used.  Prepends the instantiation (ix, (sort, TFree)) to pairs and
   adds the new name to used. *)
fun new_tfree (tv as (ix,sort)) (pairs, used) =
  let val v = singleton (Name.variant_list used) (string_of_indexname ix)
  in ((ix,(sort,TFree(v,sort)))::pairs, v::used) end;
(* mk_fixtvar_tyinsts ignore_insts ts: collects the TVars and TFrees of the
   terms ts and creates, via new_tfree, a fresh TFree instantiation for
   every TVar whose indexname is not a key of ignore_insts.  The new names
   avoid the names of the TFrees already present in ts.
   Returns (fixtyinsts, tfrees), where tfrees are those pre-existing
   TFrees. *)
fun mk_fixtvar_tyinsts ignore_insts ts =
  let
    val ignore_ixs = map fst ignore_insts;
    val (tvars, tfrees) =
      fold_rev (fn t => fn (varixs, tfrees) =>
        (Misc_Legacy.add_term_tvars (t,varixs),
         Misc_Legacy.add_term_tfrees (t,tfrees))) ts ([], []);
    val unfixed_tvars = filter (fn (ix,s) => not (member (op =) ignore_ixs ix)) tvars;
    (* the set of used names starts with the existing TFree names *)
    val (fixtyinsts, _) = fold_rev new_tfree unfixed_tvars ([], map fst tfrees)
  in (fixtyinsts, tfrees) end;
(* cross_inst insts: propagates term instantiations (ix, (ty, t)) into each
   other.  Entries are taken in list order; each one substitutes its term
   for Var ix in the terms of all other entries, both those not yet
   processed and those already processed.  The entry order of the result is
   the same as the input order (the accumulator is reversed at the end). *)
fun cross_inst insts =
  let
    fun instL (ix, (ty,t)) = map (fn (ix2,(ty2,t2)) =>
      (ix2, (ty2,Term.subst_vars ([], [(ix, t)]) t2)));
    fun cross_instL ([], l) = rev l
      | cross_instL ((ix, t) :: insts, l) =
          cross_instL (instL (ix, t) insts, (ix, t) :: (instL (ix, t) l));
  in cross_instL (insts, []) end;
(* cross_inst_typs insts: the same propagation as cross_inst, for type
   instantiations (ix, (sort, ty)): each instance type is substituted for
   TVar ix in the instance types of all other entries.  Order is
   preserved. *)
fun cross_inst_typs insts =
  let
    fun instL (ix, (srt,ty)) =
      map (fn (ix2,(srt2,ty2)) => (ix2, (srt2,Term.typ_subst_TVars [(ix, ty)] ty2)));
    fun cross_instL ([], l) = rev l
      | cross_instL ((ix, t) :: insts, l) =
          cross_instL (instL (ix, t) insts, (ix, t) :: (instL (ix, t) l));
  in cross_instL (insts, []) end;
(* rw ctxt ((nonfixed_typinsts, unprepinsts), FakeTs, Ts, outerterm) rule
   target_thm: rewrites target_thm with the instantiated meta-equality rule
   inside the context outerterm.  Outline:
     1. fix every TVar of rule and target that nonfixed_typinsts does not
        instantiate with a fresh TFree, then type-instantiate the rule, the
        target, outerterm, FakeTs, Ts and the term instantiations;
     2. instantiate the schematic term variables of rule and target;
     3. fix the remaining Vars of the target and of the rule premises as
        fresh Frees;
     4. build  outerterm (abstracted lhs) == outerterm (abstracted rhs),
        beta-eta normalize it, and apply it to the normalized target with
        Thm.equal_elim (so the target must match its left-hand side);
     5. re-introduce the rule premises, turn the fixed Frees back into
        Vars, generalize type frees and renumber schematic indexes. *)
fun rw ctxt ((nonfixed_typinsts, unprepinsts), FakeTs, Ts, outerterm) rule target_thm =
  let
    (* step 1: fresh TFrees for TVars not covered by nonfixed_typinsts;
       othertfrees are the TFrees already present in rule and target *)
    val (fixtyinsts, othertfrees) = 
      mk_fixtvar_tyinsts nonfixed_typinsts
        [Thm.prop_of rule, Thm.prop_of target_thm];
    (* all type instantiations, propagated into each other *)
    val typinsts = cross_inst_typs (nonfixed_typinsts @ fixtyinsts);
    (* certified type instantiations *)
    val ctyp_insts = TVars.make (map (fn (ix, (s, ty)) => ((ix, s), Thm.ctyp_of ctxt ty)) typinsts);
    (* type-instantiated target and rule *)
    val tgt_th_tyinst = Thm.instantiate (ctyp_insts,Vars.empty) target_thm;
    val rule_tyinst =  Thm.instantiate (ctyp_insts,Vars.empty) rule;
    val term_typ_inst = map (fn (ix,(_,ty)) => (ix,ty)) typinsts;
    (* apply the same type instantiation to outerterm, FakeTs and Ts *)
    val outerterm_tyinst = Term.subst_TVars term_typ_inst outerterm;
    val FakeTs_tyinst = map (apsnd (Term.typ_subst_TVars term_typ_inst)) FakeTs;
    val Ts_tyinst = map (apsnd (Term.typ_subst_TVars term_typ_inst)) Ts;
    (* type-instantiate both components of every term instantiation,
       keeping the original order *)
    val insts_tyinst =
      fold_rev (fn (ix, (ty, t)) => fn insts_tyinst =>
        (ix, (Term.typ_subst_TVars term_typ_inst ty, Term.subst_TVars term_typ_inst t))
          :: insts_tyinst) unprepinsts [];
    (* step 2: propagate the term instantiations into each other *)
    val insts_tyinst_inst = cross_inst insts_tyinst;
    (* certified term instantiations *)
    val cinsts_tyinst =
      Vars.make (map (fn (ix, (ty, t)) => ((ix, ty), Thm.cterm_of ctxt t)) insts_tyinst_inst);
    (* the fully instantiated rule *)
    val rule_inst = rule_tyinst |> Thm.instantiate (TVars.empty, cinsts_tyinst);
    (* step 3: Vars left in the target and in the instantiated rule premises,
       each paired with a fresh Free; also as certified (Var, Free) pairs *)
    val renamings = mk_renamings ctxt (Thm.prop_of tgt_th_tyinst) rule_inst;
    val cterm_renamings = map (fn (x, y) => apply2 (Thm.cterm_of ctxt) (Var x, y)) renamings;
    (* step 4: outerterm with the term instantiations and the fixing Frees
       substituted, then its reflexivity theorem  outer == outer *)
    val outerterm_inst =
      outerterm_tyinst
      |> Term.subst_Vars (map (fn (ix, (ty, t)) => (ix, t)) insts_tyinst_inst)
      |> Term.subst_Vars (map (fn ((ix, ty), t) => (ix, t)) renamings);
    val couter_inst = Thm.reflexive (Thm.cterm_of ctxt outerterm_inst);
    (* fix the rule Vars with the Frees, then rename the fake frees and
       abstract; cprems are the rule premises quantified over Ts *)
    val (cprems, abstract_rule_inst) =
      rule_inst
      |> Thm.instantiate (TVars.empty, Vars.make (map (apfst (dest_Var o Thm.term_of)) cterm_renamings))
      |> mk_abstractedrule ctxt FakeTs_tyinst Ts_tyinst;
    (* outer (lambda-lhs) == outer (lambda-rhs), beta-eta normalized *)
    val specific_tgt_rule =
      Conv.fconv_rule Drule.beta_eta_conversion
        (Thm.combination couter_inst abstract_rule_inst);
    (* the target with the same term instantiations and fixing Frees *)
    val tgt_th_inst =
      tgt_th_tyinst
      |> Thm.instantiate (TVars.empty, cinsts_tyinst)
      |> Thm.instantiate (TVars.empty, Vars.make (map (apfst (dest_Var o Thm.term_of)) cterm_renamings));
    (* vars: the certified Vars; frees_of_fixed_vars: the Frees that
       replaced them, in corresponding order *)
    val (vars,frees_of_fixed_vars) = Library.split_list cterm_renamings;
  in
    (* rewrite the normalized target, then (step 5) undo the fixing *)
    Conv.fconv_rule Drule.beta_eta_conversion tgt_th_inst
    |> Thm.equal_elim specific_tgt_rule
    (* rule premises become premises of the result *)
    |> Drule.implies_intr_list cprems
    (* generalize over the fixing Frees and re-instantiate with the Vars *)
    |> Drule.forall_intr_list frees_of_fixed_vars
    |> Drule.forall_elim_list vars
    (* NOTE(review): othertfrees are presumably the TFrees kept fixed here,
       so only the TFrees introduced in step 1 become TVars again -- confirm
       against Thm.varifyT_global' *)
    |> Thm.varifyT_global' (TFrees.make_set othertfrees)
    (* keep only the theorem component of the pair and renumber indexes *)
    |-> K Drule.zero_var_indexes
  end;
end;