(* Theory HOL-Library.Multiset *)
section ‹(Finite) Multisets›
theory Multiset
  imports Cancellation
begin
subsection ‹The type of multisets›
(* The multiset type: functions 'a ⇒ nat with finite support, i.e. only finitely
   many elements have a positive count.  The representation function is named
   count and the abstraction function Abs_multiset.  The type is non-empty
   because the constant-zero function has empty (hence finite) support. *)
typedef 'a multiset = ‹{f :: 'a ⇒ nat. finite {x. f x > 0}}›
  morphisms count Abs_multiset
proof
  show ‹(λx. 0::nat) ∈ {f. finite {x. f x > 0}}›
    by simp
qed
(* Register the typedef with the Lifting package, so that operations below can
   be introduced with lift_definition and their laws proved with transfer. *)
setup_lifting type_definition_multiset
(* Abs_multiset f has count f, provided f has finite support. *)
lemma count_Abs_multiset:
  ‹count (Abs_multiset f) = f› if ‹finite {x. f x > 0}›
  by (rule Abs_multiset_inverse) (simp add: that)
(* Extensionality: two multisets are equal iff they agree on every count. *)
lemma multiset_eq_iff: "M = N ⟷ (∀a. count M a = count N a)"
  by (simp only: count_inject [symmetric] fun_eq_iff)
(* Introduction-rule form of multiset_eq_iff. *)
lemma multiset_eqI: "(⋀x. count A x = count B x) ⟹ A = B"
  using multiset_eq_iff by auto
text ‹Preservation of the representing set \<^term>‹multiset›.›
(* Pointwise truncated subtraction M x - N x keeps finite support: its support
   is a subset of the support of M.  This discharges the well-definedness
   obligation of minus_multiset. *)
lemma diff_preserves_multiset:
  ‹finite {x. 0 < M x - N x}› if ‹finite {x. 0 < M x}› for M N :: ‹'a ⇒ nat›
  using that by (rule rev_finite_subset) auto
(* Restricting M to the elements satisfying P keeps finite support: the support
   of the restricted function is a subset of the support of M.  Only M needs a
   type constraint; P is inferred as 'a ⇒ bool. *)
lemma filter_preserves_multiset:
  ‹finite {x. 0 < (if P x then M x else 0)}› if ‹finite {x. 0 < M x}› for M :: ‹'a ⇒ nat›
  using that by (rule rev_finite_subset) auto
(* Collects the two support-preservation lemmas above under a single name. *)
lemmas in_multiset = diff_preserves_multiset filter_preserves_multiset
subsection ‹Representing multisets›
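text ‹Multiset enumeration: the empty multiset, union and difference, followed by single-element insertion (add_mset) and the enumeration syntax.›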
(* Multisets form a cancellative commutative additive monoid:
   - 0 (also written {#}) is the multiset in which every count is 0;
   - M + N adds counts pointwise;
   - M - N subtracts counts pointwise on nat, i.e. truncated at 0.
   Each operation is lifted from the function representation, and the class
   laws are proved by transfer to pointwise statements on 'a ⇒ nat. *)
instantiation multiset :: (type) cancel_comm_monoid_add
begin
lift_definition zero_multiset :: ‹'a multiset›
  is ‹λa. 0›
  by simp
(* {#} is notation for the zero multiset. *)
abbreviation empty_mset :: ‹'a multiset› (‹{#}›)
  where ‹empty_mset ≡ 0›
lift_definition plus_multiset :: ‹'a multiset ⇒ 'a multiset ⇒ 'a multiset›
  is ‹λM N a. M a + N a›
  by simp
(* Well-definedness of truncated subtraction is exactly diff_preserves_multiset. *)
lift_definition minus_multiset :: ‹'a multiset ⇒ 'a multiset ⇒ 'a multiset›
  is ‹λM N a. M a - N a›
  by (rule diff_preserves_multiset)
instance
  by (standard; transfer) (simp_all add: fun_eq_iff)
end
(* Multiset.is_empty A holds iff A is the empty multiset.  It is qualified, so it
   must be referred to as Multiset.is_empty.  The [code_abbrev] attribute
   registers the defining equation with the code generator's pre/postprocessing. *)
context
begin
qualified definition is_empty :: "'a multiset ⇒ bool" where
  [code_abbrev]: "is_empty A ⟷ A = {#}"
end
(* Incrementing the count of a single element a keeps finite support; this is
   the well-definedness obligation of add_mset. *)
lemma add_mset_in_multiset:
  ‹finite {x. 0 < (if x = a then Suc (M x) else M x)}›
  if ‹finite {x. 0 < M x}›
  using that by (simp add: flip: insert_Collect)
(* add_mset a M adds one more occurrence of a to M; all other counts are kept. *)
lift_definition add_mset :: "'a ⇒ 'a multiset ⇒ 'a multiset" is
  "λa M b. if b = a then Suc (M b) else M b"
by (rule add_mset_in_multiset)
(* Enumeration syntax: {#x#} stands for add_mset x {#}, and {#x, xs#} stands for
   add_mset x {#xs#}, so {#x, y, z#} = add_mset x (add_mset y (add_mset z {#})). *)
syntax
  "_multiset" :: "args ⇒ 'a multiset"  (‹(‹indent=2 notation=‹mixfix multiset enumeration››{#_#})›)
syntax_consts
  "_multiset" ⇌ add_mset
translations
  "{#x, xs#}" == "CONST add_mset x {#xs#}"
  "{#x#}" == "CONST add_mset x {#}"
(* Every element has count 0 in the empty multiset. *)
lemma count_empty [simp]: "count {#} a = 0"
  by (simp add: zero_multiset.rep_eq)
(* add_mset b increments the count of b and leaves all other counts alone. *)
lemma count_add_mset [simp]:
  "count (add_mset b A) a = (if b = a then Suc (count A a) else count A a)"
  by (simp add: add_mset.rep_eq)
(* Count in a singleton multiset. *)
lemma count_single: "count {#b#} a = (if b = a then 1 else 0)"
  by simp
(* Inserting an element never yields the empty multiset (both orientations). *)
lemma
  add_mset_not_empty [simp]: ‹add_mset a A ≠ {#}› and
  empty_not_add_mset [simp]: "{#} ≠ add_mset a A"
  by (auto simp: multiset_eq_iff)
(* add_mset a is injective. *)
lemma add_mset_add_mset_same_iff [simp]:
  "add_mset a A = add_mset a B ⟷ A = B"
  by (auto simp: multiset_eq_iff)
(* The order of insertions is irrelevant.  Not declared [simp]: it is a
   permutative rule (presumably kept out of the simpset for that reason). *)
lemma add_mset_commute:
  "add_mset x (add_mset y M) = add_mset y (add_mset x M)"
  by (auto simp: multiset_eq_iff)
subsection ‹Basic operations›
subsubsection ‹Conversion to set and membership›
(* The underlying set (support) of a multiset: elements with positive count. *)
definition set_mset :: ‹'a multiset ⇒ 'a set›
  where ‹set_mset M = {x. count M x > 0}›
(* Membership a ∈# M is an abbreviation for a ∈ set_mset M. *)
abbreviation member_mset :: ‹'a ⇒ 'a multiset ⇒ bool›
  where ‹member_mset a M ≡ a ∈ set_mset M›
notation
  member_mset  (‹'(∈#')›) and
  member_mset  (‹(‹notation=‹infix ∈#››_/ ∈# _)› [50, 51] 50)
(* ASCII alternative :# for ∈#. *)
notation  (ASCII)
  member_mset  (‹'(:#')›) and
  member_mset  (‹(‹notation=‹infix :#››_/ :# _)› [50, 51] 50)
(* Non-membership a ∉# M, i.e. a ∉ set_mset M. *)
abbreviation not_member_mset :: ‹'a ⇒ 'a multiset ⇒ bool›
  where ‹not_member_mset a M ≡ a ∉ set_mset M›
notation
  not_member_mset  (‹'(∉#')›) and
  not_member_mset  (‹(‹notation=‹infix ∉#››_/ ∉# _)› [50, 51] 50)
(* ASCII alternative ~:# for ∉#. *)
notation  (ASCII)
  not_member_mset  (‹'(~:#')›) and
  not_member_mset  (‹(‹notation=‹infix ~:#››_/ ~:# _)› [50, 51] 50)
(* Bounded quantifiers over a multiset range over its support set_mset M.
   They are qualified, so they are written Multiset.Ball / Multiset.Bex. *)
context
begin
qualified abbreviation Ball :: "'a multiset ⇒ ('a ⇒ bool) ⇒ bool"
  where "Ball M ≡ Set.Ball (set_mset M)"
qualified abbreviation Bex :: "'a multiset ⇒ ('a ⇒ bool) ⇒ bool"
  where "Bex M ≡ Set.Bex (set_mset M)"
end
(* Binder syntax ∀x∈#A. P and ∃x∈#A. P, plus ASCII variants with :#. *)
syntax
  "_MBall"       :: "pttrn ⇒ 'a set ⇒ bool ⇒ bool"
    (‹(‹indent=3 notation=‹binder ∀››∀_∈#_./ _)› [0, 0, 10] 10)
  "_MBex"        :: "pttrn ⇒ 'a set ⇒ bool ⇒ bool"
    (‹(‹indent=3 notation=‹binder ∃››∃_∈#_./ _)› [0, 0, 10] 10)
syntax  (ASCII)
  "_MBall"       :: "pttrn ⇒ 'a set ⇒ bool ⇒ bool"
    (‹(‹indent=3 notation=‹binder ∀››∀_:#_./ _)› [0, 0, 10] 10)
  "_MBex"        :: "pttrn ⇒ 'a set ⇒ bool ⇒ bool"
    (‹(‹indent=3 notation=‹binder ∃››∃_:#_./ _)› [0, 0, 10] 10)
syntax_consts
  "_MBall" ⇌ Multiset.Ball and
  "_MBex" ⇌ Multiset.Bex
translations
  "∀x∈#A. P" ⇌ "CONST Multiset.Ball A (λx. P)"
  "∃x∈#A. P" ⇌ "CONST Multiset.Bex A (λx. P)"
(* Print translations so that Multiset.Ball / Multiset.Bex terms are displayed
   using the binder syntax declared above. *)
typed_print_translation ‹
 [(\<^const_syntax>‹Multiset.Ball›, Syntax_Trans.preserve_binder_abs2_tr' \<^syntax_const>‹_MBall›),
  (\<^const_syntax>‹Multiset.Bex›, Syntax_Trans.preserve_binder_abs2_tr' \<^syntax_const>‹_MBex›)]
›
(* A count of zero means non-membership. *)
lemma count_eq_zero_iff:
  "count M x = 0 ⟷ x ∉# M"
  by (auto simp add: set_mset_def)
(* The same equivalence, oriented towards counts. *)
lemma not_in_iff:
  "x ∉# M ⟷ count M x = 0"
  by (auto simp add: count_eq_zero_iff)
(* Simp rule: a positive count is rewritten to membership. *)
lemma count_greater_zero_iff [simp]:
  "count M x > 0 ⟷ x ∈# M"
  by (auto simp add: set_mset_def)
(* Introduction rule: x ∈# M follows if count M x = 0 is contradictory. *)
lemma count_inI:
  assumes "count M x = 0 ⟹ False"
  shows "x ∈# M"
proof (rule ccontr)
  assume "x ∉# M"
  with assms show False by (simp add: not_in_iff)
qed
(* Elimination rule: a member has count Suc n for some n. *)
lemma in_countE:
  assumes "x ∈# M"
  obtains n where "count M x = Suc n"
proof -
  from assms have "count M x > 0" by simp
  then obtain n where "count M x = Suc n"
    using gr0_conv_Suc by blast
  with that show thesis .
qed
(* Simp rules: count ≥ Suc 0, and count ≥ 1, are rewritten to membership. *)
lemma count_greater_eq_Suc_zero_iff [simp]:
  "count M x ≥ Suc 0 ⟷ x ∈# M"
  by (simp add: Suc_le_eq)
lemma count_greater_eq_one_iff [simp]:
  "count M x ≥ 1 ⟷ x ∈# M"
  by simp
(* Support of the empty multiset and of singletons. *)
lemma set_mset_empty [simp]:
  "set_mset {#} = {}"
  by (simp add: set_mset_def)
lemma set_mset_single:
  "set_mset {#b#} = {b}"
  by (simp add: set_mset_def)
(* The support is empty exactly for the empty multiset. *)
lemma set_mset_eq_empty_iff [simp]:
  "set_mset M = {} ⟷ M = {#}"
  by (auto simp add: multiset_eq_iff count_eq_zero_iff)
(* The support is always finite (it is the finite support from the typedef). *)
lemma finite_set_mset [iff]:
  "finite (set_mset M)"
  using count [of M] by simp
(* Inserting a into A inserts a into the support. *)
lemma set_mset_add_mset_insert [simp]: ‹set_mset (add_mset a A) = insert a (set_mset A)›
  by (auto simp flip: count_greater_eq_Suc_zero_iff split: if_splits)
(* A non-empty multiset has some member. *)
lemma multiset_nonemptyE [elim]:
  assumes "A ≠ {#}"
  obtains x where "x ∈# A"
proof -
  have "∃x. x ∈# A" by (rule ccontr) (insert assms, auto)
  with that show ?thesis by blast
qed
(* Any count exceeding some n (hence positive) implies membership. *)
lemma count_gt_imp_in_mset: "count M x > n ⟹ x ∈# M"
  using count_greater_zero_iff by fastforce
subsubsection ‹Union›
(* Union adds counts pointwise. *)
lemma count_union [simp]:
  "count (M + N) a = count M a + count N a"
  by (simp add: plus_multiset.rep_eq)
(* The support of a sum is the union of the supports. *)
lemma set_mset_union [simp]:
  "set_mset (M + N) = set_mset M ∪ set_mset N"
  by (simp only: set_eq_iff count_greater_zero_iff [symmetric] count_union) simp
(* Simp rules pulling add_mset out of either summand. *)
lemma union_mset_add_mset_left [simp]:
  "add_mset a A + B = add_mset a (A + B)"
  by (auto simp: multiset_eq_iff)
lemma union_mset_add_mset_right [simp]:
  "A + add_mset a B = add_mset a (A + B)"
  by (auto simp: multiset_eq_iff)
(* Inserting a equals adding the singleton {#a#}. *)
lemma add_mset_add_single: ‹add_mset a A = A + {#a#}›
  by (subst union_mset_add_mset_right, subst add.comm_neutral) standard
subsubsection ‹Difference›
(* Truncated subtraction makes multisets a comm_monoid_diff. *)
instance multiset :: (type) comm_monoid_diff
  by standard (transfer; simp add: fun_eq_iff)
(* Difference subtracts counts pointwise (truncated at 0, as on nat). *)
lemma count_diff [simp]:
  "count (M - N) a = count M a - count N a"
  by (simp add: minus_multiset.rep_eq)
(* A common inserted element cancels on both sides of a difference. *)
lemma add_mset_diff_bothsides:
  ‹add_mset a M - add_mset a A = M - A›
  by (auto simp: multiset_eq_iff)
(* a survives in M - N iff M contains strictly more copies of a than N. *)
lemma in_diff_count:
  "a ∈# M - N ⟷ count N a < count M a"
  by (simp add: set_mset_def)
(* Introduction rule for x ∈# M - N: it suffices to refute count N x ≥ count M x,
   stated as count N x = n + count M x. *)
lemma count_in_diffI:
  assumes "⋀n. count N x = n + count M x ⟹ False"
  shows "x ∈# M - N"
proof (rule ccontr)
  assume "x ∉# M - N"
  then have "count N x = (count N x - count M x) + count M x"
    by (simp add: in_diff_count not_less)
  with assms show False by auto
qed
(* Elimination rule: if x ∈# M - N then count M x exceeds count N x by Suc n. *)
lemma in_diff_countE:
  assumes "x ∈# M - N"
  obtains n where "count M x = Suc n + count N x"
proof -
  from assms have "count M x - count N x > 0" by (simp add: in_diff_count)
  then have "count M x > count N x" by simp
  then obtain n where "count M x = Suc n + count N x"
    using less_iff_Suc_add by auto
  with that show thesis .
qed
(* Members of a difference are members of the minuend. *)
lemma in_diffD:
  assumes "a ∈# M - N"
  shows "a ∈# M"
proof -
  have "0 ≤ count N a" by simp
  also from assms have "count N a < count M a"
    by (simp add: in_diff_count)
  finally show ?thesis by simp
qed
(* Support of a difference, in terms of counts. *)
lemma set_mset_diff:
  "set_mset (M - N) = {a. count N a < count M a}"
  by (simp add: set_mset_def)
(* Subtracting {#} is the identity; subtracting from {#} gives {#}. *)
lemma diff_empty [simp]: "M - {#} = M ∧ {#} - M = {#}"
  by rule (fact Groups.diff_zero, fact Groups.zero_diff)
(* Multiset-specific instances of the generic monoid-difference laws. *)
lemma diff_cancel: "A - A = {#}"
  by (fact Groups.diff_cancel)
lemma diff_union_cancelR: "M + N - N = (M::'a multiset)"
  by (fact add_diff_cancel_right')
lemma diff_union_cancelL: "N + M - N = (M::'a multiset)"
  by (fact add_diff_cancel_left')
lemma diff_right_commute:
  fixes M N Q :: "'a multiset"
  shows "M - N - Q = M - Q - N"
  by (fact diff_right_commute)
lemma diff_add:
  fixes M N Q :: "'a multiset"
  shows "M - (N + Q) = M - N - Q"
  by (rule sym) (fact diff_diff_add)
(* Removing one copy of a member x and re-inserting it restores M. *)
lemma insert_DiffM [simp]: "x ∈# M ⟹ add_mset x (M - {#x#}) = M"
  by (clarsimp simp: multiset_eq_iff)
(* The same with the singleton added on the right via +. *)
lemma insert_DiffM2: "x ∈# M ⟹ (M - {#x#}) + {#x#} = M"
  by simp
(* For a ≠ b, inserting b commutes with removing one copy of a. *)
lemma diff_union_swap: "a ≠ b ⟹ add_mset b (M - {#a#}) = add_mset b M - {#a#}"
  by (auto simp add: multiset_eq_iff)
(* Inserting b commutes with subtracting A, provided b does not occur in A
   (so A removes no copy of b).  Proved pointwise on counts. *)
lemma diff_add_mset_swap [simp]: "b ∉# A ⟹ add_mset b M - A = add_mset b (M - A)"
  by (auto simp add: multiset_eq_iff not_in_iff)
(* If y ∈# M, removing one y after inserting x equals inserting x after removing y. *)
lemma diff_union_swap2 [simp]: "y ∈# M ⟹ add_mset x M - {#y#} = add_mset x (M - {#y#})"
  by (metis add_mset_diff_bothsides diff_union_swap diff_zero insert_DiffM)
(* Simp normal form: iterated differences are merged into one difference. *)
lemma diff_diff_add_mset [simp]: "(M::'a multiset) - N - P = M - (N + P)"
  by (rule diff_diff_add)
(* If a ∈# J, removing one a from I + J can be done inside J. *)
lemma diff_union_single_conv:
  "a ∈# J ⟹ I + J - {#a#} = I + (J - {#a#})"
  by (simp add: multiset_eq_iff Suc_le_eq)
(* Elimination rule: a member a splits A as add_mset a B. *)
lemma mset_add [elim?]:
  assumes "a ∈# A"
  obtains B where "A = add_mset a B"
proof -
  from assms have "A = add_mset a (A - {#a#})"
    by simp
  with that show thesis .
qed
(* Membership in a sum is membership in either summand. *)
lemma union_iff:
  "a ∈# A + B ⟷ a ∈# A ∨ a ∈# B"
  by auto
(* The count of y in M2 - M1 is below its count in M1 - M2 iff y ∈# M1 - M2. *)
lemma count_minus_inter_lt_count_minus_inter_iff:
  "count (M2 - M1) y < count (M1 - M2) y ⟷ y ∈# M1 - M2"
  by (meson count_greater_zero_iff gr_implies_not_zero in_diff_count leI order.strict_trans2
      order_less_asym)
(* The two one-sided differences are equal iff their supports are equal. *)
lemma minus_inter_eq_minus_inter_iff:
  "(M1 - M2) = (M2 - M1) ⟷ set_mset (M1 - M2) = set_mset (M2 - M1)"
  by (metis add.commute count_diff count_eq_zero_iff diff_add_zero in_diff_countE multiset_eq_iff)
subsubsection ‹Min and Max›
(* Minimum / maximum element of a multiset over a linear order, taken over
   its support set_mset m. *)
abbreviation Min_mset :: "'a::linorder multiset ⇒ 'a" where
"Min_mset m ≡ Min (set_mset m)"
abbreviation Max_mset :: "'a::linorder multiset ⇒ 'a" where
"Max_mset m ≡ Max (set_mset m)"
(* For a non-empty multiset, its minimum and maximum are members. *)
lemma
  Min_in_mset: "M ≠ {#} ⟹ Min_mset M ∈# M" and
  Max_in_mset: "M ≠ {#} ⟹ Max_mset M ∈# M"
  by simp+
subsubsection ‹Equality of multisets›
(* Singletons are equal iff their elements are. *)
lemma single_eq_single [simp]: "{#a#} = {#b#} ⟷ a = b"
  by (auto simp add: multiset_eq_iff)
(* A sum is empty iff both summands are (both orientations). *)
lemma union_eq_empty [iff]: "M + N = {#} ⟷ M = {#} ∧ N = {#}"
  by (auto simp add: multiset_eq_iff)
lemma empty_eq_union [iff]: "{#} = M + N ⟷ M = {#} ∧ N = {#}"
  by (auto simp add: multiset_eq_iff)
(* Inserting an element always changes the multiset. *)
lemma multi_self_add_other_not_self [simp]: "M = add_mset x M ⟷ False"
  by (auto simp add: multiset_eq_iff)
(* Removing one x right after inserting x gives back M. *)
lemma add_mset_remove_trivial [simp]: ‹add_mset x M - {#x#} = M›
  by (auto simp: multiset_eq_iff)
(* Removing a non-member has no effect. *)
lemma diff_single_trivial: "¬ x ∈# M ⟹ M - {#x#} = M"
  by (auto simp add: multiset_eq_iff not_in_iff)
(* For a member x, M - {#x#} = N iff M = add_mset x N. *)
lemma diff_single_eq_union: "x ∈# M ⟹ M - {#x#} = N ⟷ M = add_mset x N"
  by auto
(* Consequences of add_mset x M = N: M = N - {#x#} and x ∈# N. *)
lemma union_single_eq_diff: "add_mset x M = N ⟹ M = N - {#x#}"
  unfolding add_mset_add_single[of _ M] by (fact add_implies_diff)
lemma union_single_eq_member: "add_mset x M = N ⟹ x ∈# N"
  by auto
(* Remove-then-insert gives N if a ∈# N, and otherwise adds a new copy of a. *)
lemma add_mset_remove_trivial_If:
  "add_mset a (N - {#a#}) = (if a ∈# N then N else add_mset a N)"
  by (simp add: diff_single_trivial)
lemma add_mset_remove_trivial_eq: ‹N = add_mset a (N - {#a#}) ⟷ a ∈# N›
  by (auto simp: add_mset_remove_trivial_If)
(* A sum is a singleton {#a#} iff one summand is {#a#} and the other is {#}. *)
lemma union_is_single:
  "M + N = {#a#} ⟷ M = {#a#} ∧ N = {#} ∨ M = {#} ∧ N = {#a#}"
  (is "?lhs = ?rhs")
proof
  show ?lhs if ?rhs using that by auto
  show ?rhs if ?lhs
    by (metis Multiset.diff_cancel add.commute add_diff_cancel_left' diff_add_zero diff_single_trivial insert_DiffM that)
qed
(* Symmetric version with the singleton on the left. *)
lemma single_is_union: "{#a#} = M + N ⟷ {#a#} = M ∧ N = {#} ∨ M = {#} ∧ {#a#} = N"
  by (auto simp add: eq_commute [of "{#a#}" "M + N"] union_is_single)
(* add_mset a M = add_mset b N holds iff either a = b and M = N, or each side
   contains the other's inserted element, as spelled out on the right. *)
lemma add_eq_conv_diff:
  "add_mset a M = add_mset b N ⟷ M = N ∧ a = b ∨ M = add_mset b (N - {#a#}) ∧ N = add_mset a (M - {#b#})"
  (is "?lhs ⟷ ?rhs")
proof
  show ?lhs if ?rhs
    using that
    by (auto simp add: add_mset_commute[of a b])
  show ?rhs if ?lhs
  proof (cases "a = b")
    case True with ‹?lhs› show ?thesis by simp
  next
    case False
    from ‹?lhs› have "a ∈# add_mset b N" by (rule union_single_eq_member)
    with False have "a ∈# N" by auto
    moreover from ‹?lhs› have "M = add_mset b N - {#a#}" by (rule union_single_eq_diff)
    moreover note False
    ultimately show ?thesis by (auto simp add: diff_right_commute [of _ "{#a#}"])
  qed
qed
(* An insertion equals a singleton iff the elements agree and the rest is empty. *)
lemma add_mset_eq_single [iff]: "add_mset b M = {#a#} ⟷ b = a ∧ M = {#}"
  by (auto simp: add_eq_conv_diff)
lemma single_eq_add_mset [iff]: "{#a#} = add_mset b M ⟷ b = a ∧ M = {#}"
  by (auto simp: add_eq_conv_diff)
(* If add_mset b B = add_mset c C with b ≠ c, then c already occurs in B:
   rewriting with BC shows c ∈# add_mset b B, and c is not the inserted b. *)
lemma insert_noteq_member:
  assumes BC: "add_mset b B = add_mset c C"
   and bnotc: "b ≠ c"
  shows "c ∈# B"
proof -
  have nc: "c ∉# {#b#}" using bnotc by simp
  then have "c ∈# add_mset b B" using BC by simp
  then show "c ∈# B" using nc by simp
qed
(* Existential form of add_eq_conv_diff: distinct insertions agree iff both
   sides share a common remainder K. *)
lemma add_eq_conv_ex:
  "(add_mset a M = add_mset b N) =
    (M = N ∧ a = b ∨ (∃K. M = add_mset b K ∧ N = add_mset a K))"
  by (auto simp add: add_eq_conv_diff)
(* A member x splits M as add_mset x A (witness A = M - {#x#}). *)
lemma multi_member_split: "x ∈# M ⟹ ∃A. M = add_mset x A"
  by (rule exI [where x = "M - {#x#}"]) simp
(* For a member c of B and b ≠ c, inserting b commutes with removing one c. *)
lemma multiset_add_sub_el_shuffle:
  assumes "c ∈# B"
    and "b ≠ c"
  shows "add_mset b (B - {#c#}) = add_mset b B - {#c#}"
proof -
  from ‹c ∈# B› obtain A where B: "B = add_mset c A"
    by (blast dest: multi_member_split)
  have "add_mset b A = add_mset c (add_mset b A) - {#c#}" by simp
  then have "add_mset b A = add_mset b (add_mset c A) - {#c#}"
    by (simp add: ‹b ≠ c›)
  then show ?thesis using B by simp
qed
(* An insertion is a singleton iff the remainder is empty and the elements agree. *)
lemma add_mset_eq_singleton_iff[iff]:
  "add_mset x M = {#y#} ⟷ M = {#} ∧ x = y"
  by auto
subsubsection ‹Pointwise ordering induced by count›
(* Sub-multiset: A ⊆# B iff every count in A is at most the count in B. *)
definition subseteq_mset :: "'a multiset ⇒ 'a multiset ⇒ bool"  (infix ‹⊆#› 50)
  where "A ⊆# B ⟷ (∀a. count A a ≤ count B a)"
(* Strict sub-multiset: A ⊆# B and A ≠ B. *)
definition subset_mset :: "'a multiset ⇒ 'a multiset ⇒ bool" (infix ‹⊂#› 50)
  where "A ⊂# B ⟷ A ⊆# B ∧ A ≠ B"
(* Input-only converse relations ⊇# and ⊃#. *)
abbreviation (input) supseteq_mset :: "'a multiset ⇒ 'a multiset ⇒ bool"  (infix ‹⊇#› 50)
  where "supseteq_mset A B ≡ B ⊆# A"
abbreviation (input) supset_mset :: "'a multiset ⇒ 'a multiset ⇒ bool"  (infix ‹⊃#› 50)
  where "supset_mset A B ≡ B ⊂# A"
(* Input-only alternative spellings ≤# and ≥#. *)
notation (input)
  subseteq_mset  (infix ‹≤#› 50) and
  supseteq_mset  (infix ‹≥#› 50)
(* ASCII notation for the four relations. *)
notation (ASCII)
  subseteq_mset  (infix ‹<=#› 50) and
  subset_mset  (infix ‹<#› 50) and
  supseteq_mset  (infix ‹>=#› 50) and
  supset_mset  (infix ‹>#› 50)
(* (⊆#) with strict part (⊂#) instantiates the ordering locale; its facts become
   available under the subset_mset prefix.  The locale axioms reduce to
   pointwise facts about counts once the definitions are unfolded. *)
global_interpretation subset_mset: ordering ‹(⊆#)› ‹(⊂#)›
  by standard (auto simp add: subset_mset_def subseteq_mset_def multiset_eq_iff intro: order.trans order.antisym)