From f250194538edc11d2270cf0e86d3b2c78cddcd48 Mon Sep 17 00:00:00 2001 From: Kim Morrison Date: Tue, 3 Mar 2026 05:21:35 +0000 Subject: [PATCH 1/4] feat(Query): add query complexity framework with sorting lower bounds MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Add a monad-generic framework for proving upper and lower bounds on query complexity of algorithms, demonstrated on comparison sorting. Key components: - `Query.Basic`: WPMonad-based framework with `UpperBound`/`LowerBound` predicates - `QueryTree`: decision tree reification for lower bound proofs - `IsMonadicSort`: correctness specification for comparison sorts - Verified insertion sort (O(n²)) and merge sort (O(n log n)) implementations - `lowerBound_infinite`: any correct comparison sort on an infinite type requires ≥ ⌈log₂(n!)⌉ comparisons for every input size n Co-Authored-By: Claude Opus 4.6 --- Cslib.lean | 13 +- .../Algorithms/Lean/MergeSort/MergeSort.lean | 207 ------------- Cslib/Algorithms/Lean/Query/Basic.lean | 286 ++++++++++++++++++ Cslib/Algorithms/Lean/Query/LowerBound.lean | 50 +++ .../Algorithms/Lean/Query/MonadicExample.lean | 105 +++++++ Cslib/Algorithms/Lean/Query/QueryTree.lean | 194 ++++++++++++ .../Lean/Query/Sort/Insertion/Defs.lean | 32 ++ .../Lean/Query/Sort/Insertion/Lemmas.lean | 227 ++++++++++++++ .../Lean/Query/Sort/LowerBound.lean | 194 ++++++++++++ .../Lean/Query/Sort/Merge/Defs.lean | 44 +++ .../Lean/Query/Sort/Merge/Lemmas.lean | 247 +++++++++++++++ .../Lean/Query/Sort/MonadicSort.lean | 99 ++++++ .../Algorithms/Lean/Query/Sort/QueryTree.lean | 107 +++++++ Cslib/Algorithms/Lean/Query/UpperBound.lean | 74 +++++ Cslib/Algorithms/Lean/TimeM.lean | 142 --------- lakefile.toml | 1 + 16 files changed, 1671 insertions(+), 351 deletions(-) delete mode 100644 Cslib/Algorithms/Lean/MergeSort/MergeSort.lean create mode 100644 Cslib/Algorithms/Lean/Query/Basic.lean create mode 100644 Cslib/Algorithms/Lean/Query/LowerBound.lean create mode 
100644 Cslib/Algorithms/Lean/Query/MonadicExample.lean create mode 100644 Cslib/Algorithms/Lean/Query/QueryTree.lean create mode 100644 Cslib/Algorithms/Lean/Query/Sort/Insertion/Defs.lean create mode 100644 Cslib/Algorithms/Lean/Query/Sort/Insertion/Lemmas.lean create mode 100644 Cslib/Algorithms/Lean/Query/Sort/LowerBound.lean create mode 100644 Cslib/Algorithms/Lean/Query/Sort/Merge/Defs.lean create mode 100644 Cslib/Algorithms/Lean/Query/Sort/Merge/Lemmas.lean create mode 100644 Cslib/Algorithms/Lean/Query/Sort/MonadicSort.lean create mode 100644 Cslib/Algorithms/Lean/Query/Sort/QueryTree.lean create mode 100644 Cslib/Algorithms/Lean/Query/UpperBound.lean delete mode 100644 Cslib/Algorithms/Lean/TimeM.lean diff --git a/Cslib.lean b/Cslib.lean index a9d5ffc3e..d05d7c4be 100644 --- a/Cslib.lean +++ b/Cslib.lean @@ -1,7 +1,16 @@ module -- shake: keep-all -public import Cslib.Algorithms.Lean.MergeSort.MergeSort -public import Cslib.Algorithms.Lean.TimeM +public import Cslib.Algorithms.Lean.Query.Basic +public import Cslib.Algorithms.Lean.Query.Sort.Insertion.Defs +public import Cslib.Algorithms.Lean.Query.Sort.Insertion.Lemmas +public import Cslib.Algorithms.Lean.Query.LowerBound +public import Cslib.Algorithms.Lean.Query.Sort.LowerBound +public import Cslib.Algorithms.Lean.Query.Sort.Merge.Defs +public import Cslib.Algorithms.Lean.Query.Sort.Merge.Lemmas +public import Cslib.Algorithms.Lean.Query.Sort.MonadicSort +public import Cslib.Algorithms.Lean.Query.MonadicExample +public import Cslib.Algorithms.Lean.Query.QueryTree +public import Cslib.Algorithms.Lean.Query.UpperBound public import Cslib.Computability.Automata.Acceptors.Acceptor public import Cslib.Computability.Automata.Acceptors.OmegaAcceptor public import Cslib.Computability.Automata.DA.Basic diff --git a/Cslib/Algorithms/Lean/MergeSort/MergeSort.lean b/Cslib/Algorithms/Lean/MergeSort/MergeSort.lean deleted file mode 100644 index 8ba55d461..000000000 --- a/Cslib/Algorithms/Lean/MergeSort/MergeSort.lean 
+++ /dev/null @@ -1,207 +0,0 @@ -/- -Copyright (c) 2025 Sorrachai Yingchareonthawornhcai. All rights reserved. -Released under Apache 2.0 license as described in the file LICENSE. -Authors: Sorrachai Yingchareonthawornhcai --/ - -module - -public import Cslib.Algorithms.Lean.TimeM -public import Mathlib.Data.Nat.Cast.Order.Ring -public import Mathlib.Data.Nat.Lattice -public import Mathlib.Data.Nat.Log - -@[expose] public section - -/-! -# MergeSort on a list - -In this file we introduce `merge` and `mergeSort` algorithms that returns a time monad -over the list `TimeM ℕ (List α)`. The time complexity of `mergeSort` is the number of comparisons. - --- -## Main results - -- `mergeSort_correct`: `mergeSort` permutes the list into a sorted one. -- `mergeSort_time`: The number of comparisons of `mergeSort` is at most `n*⌈log₂ n⌉`. - --/ - -set_option autoImplicit false - -namespace Cslib.Algorithms.Lean.TimeM - -variable {α : Type} [LinearOrder α] - -/-- Merges two lists into a single list, counting comparisons as time cost. -Returns a `TimeM ℕ (List α)` where the time represents the number of comparisons performed. -/ -def merge : List α → List α → TimeM ℕ (List α) - | [], ys => return ys - | xs, [] => return xs - | x::xs', y::ys' => do - ✓ let c := (x ≤ y : Bool) - if c then - let rest ← merge xs' (y::ys') - return (x :: rest) - else - let rest ← merge (x::xs') ys' - return (y :: rest) - -/-- Sorts a list using the merge sort algorithm, counting comparisons as time cost. -Returns a `TimeM ℕ (List α)` where the time represents the total number of comparisons. -/ -def mergeSort (xs : List α) : TimeM ℕ (List α) := do - if xs.length < 2 then return xs - else - let half := xs.length / 2 - let left := xs.take half - let right := xs.drop half - let sortedLeft ← mergeSort left - let sortedRight ← mergeSort right - merge sortedLeft sortedRight - -section Correctness - -open List - -/-- Our merge computes the one already in mathlib. 
-/ -@[simp, grind =] -theorem ret_merge (xs ys : List α) : ⟪merge xs ys⟫ = xs.merge ys := by - fun_induction merge with grind [nil_merge, merge_right, cons_merge_cons] - -/-- A list is sorted if it satisfies the `Pairwise (· ≤ ·)` predicate. -/ -abbrev IsSorted (l : List α) : Prop := List.Pairwise (· ≤ ·) l - -/-- `x` is a minimum element of list `l` if `x ≤ b` for all `b ∈ l`. -/ -abbrev MinOfList (x : α) (l : List α) : Prop := ∀ b ∈ l, x ≤ b - -@[grind →] -theorem mem_either_merge (xs ys : List α) (z : α) (hz : z ∈ ⟪merge xs ys⟫) : z ∈ xs ∨ z ∈ ys := by - grind [List.mem_merge] - -theorem min_all_merge (x : α) (xs ys : List α) (hxs : MinOfList x xs) (hys : MinOfList x ys) : - MinOfList x ⟪merge xs ys⟫ := by - grind - -theorem sorted_merge {l1 l2 : List α} (hxs : IsSorted l1) (hys : IsSorted l2) : - IsSorted ⟪merge l1 l2⟫ := by - grind [hxs.merge hys] - -theorem mergeSort_sorted (xs : List α) : IsSorted ⟪mergeSort xs⟫ := by - fun_induction mergeSort xs with - | case1 x => - rcases x with _ | ⟨a, _ | ⟨b, rest⟩⟩ <;> grind - | case2 _ _ _ _ _ ih2 ih1 => exact sorted_merge ih2 ih1 - -lemma merge_perm (l₁ l₂ : List α) : ⟪merge l₁ l₂⟫ ~ l₁ ++ l₂ := by - fun_induction merge with grind [List.merge_perm_append] - -theorem mergeSort_perm (xs : List α) : ⟪mergeSort xs⟫ ~ xs := by - fun_induction mergeSort xs with - | case1 => simp - | case2 x _ _ left right ih2 ih1 => - simp only [ret_bind] - calc - ⟪merge ⟪mergeSort left⟫ ⟪mergeSort right⟫⟫ ~ - ⟪mergeSort left⟫ ++ ⟪mergeSort right⟫ := by apply merge_perm - _ ~ left++right := Perm.append ih2 ih1 - _ ~ x := by simp only [take_append_drop, Perm.refl, left, right] - -/-- MergeSort is functionally correct. -/ -theorem mergeSort_correct (xs : List α) : IsSorted ⟪mergeSort xs⟫ ∧ ⟪mergeSort xs⟫ ~ xs := - ⟨mergeSort_sorted xs, mergeSort_perm xs⟩ - -end Correctness - -section TimeComplexity - -/-- Recurrence relation for the time complexity of merge sort. 
-For a list of length `n`, this counts the total number of comparisons: -- Base cases: 0 comparisons for lists of length 0 or 1 -- Recursive case: split the list, sort both halves, - then merge (which takes at most `n` comparisons) -/ -def timeMergeSortRec : ℕ → ℕ -| 0 => 0 -| 1 => 0 -| n@(_+2) => timeMergeSortRec (n/2) + timeMergeSortRec ((n-1)/2 + 1) + n - -open Nat (clog) - -/-- Key Lemma: ⌈log2 ⌈n/2⌉⌉ ≤ ⌈log2 n⌉ - 1 for n > 1 -/ -@[grind →] -lemma clog2_half_le (n : ℕ) (h : n > 1) : clog 2 ((n + 1) / 2) ≤ clog 2 n - 1 := by - grind [Nat.clog_of_one_lt one_lt_two h] - -/-- Same logic for the floor half: ⌈log2 ⌊n/2⌋⌉ ≤ ⌈log2 n⌉ - 1 -/ -@[grind →] -lemma clog2_floor_half_le (n : ℕ) (h : n > 1) : clog 2 (n / 2) ≤ clog 2 n - 1 := by - apply Nat.le_trans _ (clog2_half_le n h) - apply Nat.clog_monotone - grind - -private lemma some_algebra (n : ℕ) : - (n / 2 + 1) * clog 2 (n / 2 + 1) + ((n + 1) / 2 + 1) * clog 2 ((n + 1) / 2 + 1) + (n + 2) ≤ - (n + 2) * clog 2 (n + 2) := by - -- 1. Substitution: Let N = n_1 + 2 to clean up the expression - let N := n + 2 - have hN : N ≥ 2 := by omega - -- 2. 
Rewrite the terms using N - have t1 : n / 2 + 1 = N / 2 := by omega - have t2 : (n + 1) / 2 + 1 = (N + 1) / 2 := by omega - have t3 : n + 1 + 1 = N := by omega - let k := clog 2 N - have h_bound_l : clog 2 (N / 2) ≤ k - 1 := clog2_floor_half_le N hN - have h_bound_r : clog 2 ((N + 1) / 2) ≤ k - 1 := clog2_half_le N hN - have h_split : N / 2 + (N + 1) / 2 = N := by omega - grw [t1, t2, t3, h_bound_l, h_bound_r, ←Nat.add_mul, h_split] - exact Nat.le_refl (N * (k - 1) + N) - -/-- Upper bound function for merge sort time complexity: `T(n) = n * ⌈log₂ n⌉` -/ -abbrev T (n : ℕ) : ℕ := n * clog 2 n - -/-- Solve the recurrence -/ -theorem timeMergeSortRec_le (n : ℕ) : timeMergeSortRec n ≤ T n := by - fun_induction timeMergeSortRec with - | case1 => grind - | case2 => grind - | case3 n ih2 ih1 => - grw [ih1,ih2] - have := some_algebra n - grind [Nat.add_div_right] - -theorem merge_ret_length_eq_sum (xs ys : List α) : - ⟪merge xs ys⟫.length = xs.length + ys.length := by - simp - -@[simp] theorem mergeSort_same_length (xs : List α) : - ⟪mergeSort xs⟫.length = xs.length := by - fun_induction mergeSort - · simp - · grind [List.length_merge] - -@[simp] theorem merge_time (xs ys : List α) : (merge xs ys).time ≤ xs.length + ys.length := by - fun_induction merge with - | case3 => - grind - | _ => simp - -theorem mergeSort_time_le (xs : List α) : - (mergeSort xs).time ≤ timeMergeSortRec xs.length := by - fun_induction mergeSort with - | case1 => - grind - | case2 _ _ _ _ _ ih2 ih1 => - simp only [time_bind] - grw [merge_time] - simp only [mergeSort_same_length] - unfold timeMergeSortRec - grind - -/-- Time complexity of mergeSort -/ -theorem mergeSort_time (xs : List α) : - let n := xs.length - (mergeSort xs).time ≤ n * clog 2 n := by - grind [mergeSort_time_le, timeMergeSortRec_le] - -end TimeComplexity - -end Cslib.Algorithms.Lean.TimeM diff --git a/Cslib/Algorithms/Lean/Query/Basic.lean b/Cslib/Algorithms/Lean/Query/Basic.lean new file mode 100644 index 000000000..ef36586c8 --- 
/dev/null +++ b/Cslib/Algorithms/Lean/Query/Basic.lean @@ -0,0 +1,286 @@ +/- +Copyright (c) 2025 Sorrachai Yingchareonthawornhcai. All rights reserved. +Released under Apache 2.0 license as described in the file LICENSE. +Authors: Sorrachai Yingchareonthawornhcai, Eric Wieser, Kim Morrison +-/ +module + +import Cslib.Init +public import Std.Do.Triple.Basic +import Std.Tactic.Do + +/-! # Query Complexity Framework + +This file provides infrastructure for proving upper bounds on the number of *queries* +(comparisons, oracle calls, etc.) that an algorithm makes. + +## Approach + +We define a monad transformer `TimeT m` that wraps `StateT` with an internal tick counter. +Each call to `tick` increments this counter by 1. The predicate `Costs prog k` asserts that +`prog` increments the counter by at most `k`, expressed as a Hoare triple: +for any initial count `c`, if `prog` starts with count `≤ c`, it finishes with count `≤ c + k`. + +The `Costs` combinators (`pure`, `bind`, `bind_spec`, `ite`, `map`, `le`, etc.) form +a small algebra for composing cost bounds, mirroring the structure of monadic programs. + +## Why this works + +The key to soundness is that algorithms are written as **monad-generic** functions: +``` +def myAlgorithm [Monad m] (query : α → m β) (input : γ) : m δ := ... +``` +Because `myAlgorithm` is polymorphic in `m`, it cannot inspect or manipulate the tick counter +directly — it can only interact with it through `query`. When we specialize `m` to `TimeT` +and wrap `query` with a call to `tick`, every query invocation is faithfully counted. +The monad abstraction acts as an information barrier: the algorithm cannot distinguish +the instrumented monad from any other, so it cannot game the counter. + +See `Cslib.Algorithms.Lean.Query.UpperBound` for the `UpperBound` and `UpperBoundT` predicates that +package this specialization step, and for a discussion of the computability caveat. 
+-/ + +open Std.Do + +set_option mvcgen.warning false + +public section + +namespace Cslib.Query + +/-- Internal state for `TimeT`: tracks the number of ticks (queries) performed. -/ +structure TimeT.State where + /-- The current tick count. -/ + count : Nat + +/-- A monad transformer that adds tick-counting to any monad `m`. -/ +@[expose] def TimeT (m : Type → Type) (α : Type) := StateT TimeT.State m α + +/-- The tick-counting monad, specializing `TimeT` to `Id`. -/ +@[expose] def TimeM (α : Type) := TimeT Id α + +namespace TimeT + +/-- Wrap a `StateT TimeT.State m` computation as a `TimeT m` computation. -/ +@[expose, inline] def mk (x : StateT State m α) : TimeT m α := x + +/-- Unwrap a `TimeT m` computation to `StateT TimeT.State m`. -/ +@[expose, inline] def unfold (x : TimeT m α) : StateT State m α := x + +@[simp] theorem unfold_mk (x : StateT State m α) : (mk x).unfold = x := rfl +@[simp] theorem mk_unfold (x : TimeT m α) : mk x.unfold = x := rfl + +@[ext] theorem ext {x y : TimeT m α} (h : x.unfold = y.unfold) : x = y := h + +instance [Monad m] : Monad (TimeT m) where + pure a := mk (pure a) + bind x f := mk (x.unfold >>= fun a => (f a).unfold) + +instance [Monad m] [LawfulMonad m] : LawfulMonad (TimeT m) := + inferInstanceAs (LawfulMonad (StateT State m)) +instance [WP m ps] : Std.Do.WP (TimeT m) (.arg State ps) := + inferInstanceAs (Std.Do.WP (StateT State m) _) + +instance [Monad m] [WPMonad m ps] : WPMonad (TimeT m) (.arg State ps) := + inferInstanceAs (WPMonad (StateT State m) _) + +instance [Monad m] : MonadLift m (TimeT m) where + monadLift x := mk (StateT.lift x) + +/-- Run a `TimeT` computation, starting with tick count 0, + returning the result and the final tick count. -/ +def run [Monad m] (x : TimeT m α) : m (α × Nat) := do + let (a, s) ← x.unfold.run ⟨0⟩ + pure (a, s.count) + +/-- Run a `TimeT` computation, starting with tick count 0, discarding the tick count. 
-/ +def run' [Monad m] (x : TimeT m α) : m α := Prod.fst <$> x.unfold.run ⟨0⟩ + +/-- Increment the tick counter by 1. -/ +@[expose] def tick [Monad m] : TimeT m Unit := + mk (modify fun s => ⟨s.count + 1⟩) + +@[simp] theorem tick_unfold [Monad m] : + (tick : TimeT m Unit).unfold = modify fun s => ⟨s.count + 1⟩ := rfl + +/-- Instrument a pure function as a tick-counted query. + `counted f a` increments the tick counter by 1 and returns `f a`. -/ +@[expose] def counted [Monad m] (f : α → β) (a : α) : TimeT m β := do tick; pure (f a) + +/-- Instrument a monadic function as a tick-counted query. + `countedM f a` increments the tick counter by 1, then runs `f a` in the base monad. -/ +@[expose] def countedM [Monad m] (f : α → m β) (a : α) : TimeT m β := do + tick; MonadLift.monadLift (f a) + +/-- Assertion: the tick count is at most `k`. -/ +@[expose] def checkBound {ps : PostShape} (k : Nat) : + Assertion (.arg State ps) := + fun s => ⌜s.count ≤ k⌝ + +/-- `Costs prog k` asserts that `prog` uses at most `k` ticks. -/ +@[expose] def Costs {n : Type → Type} {ps : PostShape} [WP n ps] + (prog : TimeT n α) (k : Nat) : Prop := + ∀ c, ⦃checkBound c⦄ prog ⦃⇓ _ => checkBound (c + k)⦄ + +/-- Spec for `tick` with schematic postcondition. + To satisfy any postcondition `Q` after `tick`, + it suffices to have `Q` hold with count incremented by 1. -/ +@[spec] +theorem tick_spec [Monad n] [WPMonad n ps] {Q : PostCond Unit (.arg State ps)} : + ⦃fun s => Q.1 () ⟨s.count + 1⟩⦄ (tick : TimeT n Unit) ⦃Q⦄ := by + simp only [Triple.iff] + unfold tick + change _ ⊢ₛ (PredTrans.pushArg fun s => wp (pure ((), { count := s.count + 1 }) : n _)).apply Q + simp only [PredTrans.apply_pushArg, WP.pure]; exact .rfl + +/-- `tick` costs 1. -/ +public theorem tick_costs [Monad n] [WPMonad n ps] : Costs (tick : TimeT n Unit) 1 := by + intro c + mvcgen + simp_all [checkBound] + +/-- WP of `MonadLift.monadLift` through `TimeT`: passes through the tick state unchanged. 
-/ +@[simp, spec] +theorem wp_monadLift [Monad m] [WPMonad m ps] (x : m α) + (Q : PostCond α (.arg State ps)) : + wp⟦(MonadLift.monadLift x : TimeT m α)⟧ Q = fun s => wp⟦x⟧ (fun a => Q.1 a s, Q.2) := + Std.Do.WP.monadLift_StateT x Q + +/-- `pure` costs 0. -/ +public theorem Costs.pure [Monad n] [WPMonad n ps] (a : α) : + Costs (Pure.pure a : TimeT n α) 0 := by + intro c + exact Triple.pure a .rfl + +/-- Sequential composition: costs add. -/ +public theorem Costs.bind [Monad n] [WPMonad n ps] + {x : TimeT n α} {f : α → TimeT n β} + (hx : Costs x k₁) (hf : ∀ a, Costs (f a) k₂) : + Costs (x >>= f) (k₁ + k₂) := by + intro c + apply Triple.bind _ _ (hx c) (fun a => ?_) + have := hf a (c + k₁) + rwa [Nat.add_assoc] at this + +-- Upstreamed in https://github.com/leanprover/lean4/pull/12760 +private theorem ExceptConds.and_elim_left (x y : ExceptConds ps) : + (x ∧ₑ y).entails x := by + induction ps with + | pure => exact ⟨⟩ | arg _ _ ih => exact ih _ _ + | except _ _ ih => exact ⟨fun _ => SPred.and_elim_l, ih _ _⟩ + +-- Upstreamed in https://github.com/leanprover/lean4/pull/12760 +private theorem ExceptConds.and_elim_right (x y : ExceptConds ps) : + (x ∧ₑ y).entails y := by + induction ps with + | pure => exact ⟨⟩ | arg _ _ ih => exact ih _ _ + | except _ _ ih => exact ⟨fun _ => SPred.and_elim_r, ih _ _⟩ + +/-- Sequential composition with specification: when the continuation's cost + depends on a predicate established by the first computation. 
-/ +public theorem Costs.bind_spec [Monad n] [WPMonad n ps] + {x : TimeT n α} {f : α → TimeT n β} {P : α → Prop} + (hx_cost : Costs x k₁) (hx_spec : ⦃⌜True⌝⦄ x ⦃⇓a => ⌜P a⌝⦄) + (hf : ∀ a, P a → Costs (f a) k₂) : + Costs (x >>= f) (k₁ + k₂) := by + intro c + have hcombined := Triple.and _ (hx_cost c) hx_spec + apply Triple.bind _ _ + · apply SPred.entails.trans + (SPred.entails.trans (SPred.and_intro .rfl (SPred.pure_intro trivial)) hcombined) + · apply (wp x).mono + exact ⟨fun _ => .rfl, ExceptConds.and_elim_left _ _⟩ + · intro a + simp only [Triple] + apply SPred.pure_elim_r + intro ha + have := hf a ha (c + k₁) + rwa [Nat.add_assoc] at this + +/-- Branching: cost of either branch. -/ +public theorem Costs.ite [Monad n] [WPMonad n ps] + {t e : TimeT n α} (b : Bool) (ht : Costs t k) (he : Costs e k) : + Costs (if b then t else e) k := by + intro c; cases b + · exact he c + · exact ht c + +/-- Functorial map preserves cost (postcondition is result-independent). -/ +public theorem Costs.map [Monad n] [WPMonad n ps] + {x : TimeT n α} {f : α → β} (h : Costs x k) : + Costs (f <$> x) k := by + intro c; simp only [Triple, WP.map]; exact h c + +/-- Lifting from the base monad costs nothing, provided the computation doesn't throw. -/ +public theorem Costs.monadLift [Monad n] [WPMonad n ps] (a : n α) + (ha : ∀ (P : Prop), ⦃⌜P⌝⦄ a ⦃⇓_ => ⌜P⌝⦄) : + Costs (MonadLift.monadLift a : TimeT n α) 0 := by + intro c + apply SPred.entails.trans _ (Spec.monadLift_StateT a _) + simp only [checkBound, Nat.add_zero] + intro s + exact ha (s.count ≤ c) + +/-- Weakening: increase the bound. -/ +public theorem Costs.le [Monad n] [WPMonad n ps] + {prog : TimeT n α} (h : Costs prog k) (hle : k ≤ k') : + Costs prog k' := by + intro c + exact Triple.entails_wp_of_post (h c) (by + simp only [PostCond.entails_noThrow] + intro _ s + exact SPred.pure_mono (fun hs => Nat.le_trans hs (Nat.add_le_add_left hle c))) + +/-- `pure` costs at most `k`, for any `k`. 
-/ +public theorem Costs.pure_le [Monad n] [WPMonad n ps] (a : α) (k : Nat) : + Costs (Pure.pure a : TimeT n α) k := + Costs.le (Costs.pure a) (Nat.zero_le k) + +/-- Branching with different costs: bounded by `max`. -/ +public theorem Costs.ite_max [Monad n] [WPMonad n ps] + {t e : TimeT n α} (b : Bool) (ht : Costs t kt) (he : Costs e ke) : + Costs (if b then t else e) (max kt ke) := + Costs.ite b (Costs.le ht (Nat.le_max_left kt ke)) (Costs.le he (Nat.le_max_right kt ke)) + +/-- `counted f a` costs 1. -/ +public theorem counted_costs [Monad n] [WPMonad n ps] (f : α → β) (a : α) : + Costs (counted (m := n) f a) 1 := + Costs.bind tick_costs (fun _ => Costs.pure (f a)) + +/-- `countedM f a` costs 1, provided the underlying computation preserves propositions. -/ +public theorem countedM_costs [Monad n] [WPMonad n ps] (f : α → n β) (a : α) + (hf : ∀ (P : Prop), ⦃⌜P⌝⦄ f a ⦃⇓_ => ⌜P⌝⦄) : + Costs (countedM (m := n) f a) 1 := + Costs.bind tick_costs (fun _ => Costs.monadLift (f a) hf) + +end TimeT + +/-- A monadic function has a pure return: its output is determined by a pure function + of its input, regardless of monadic effects. -/ +@[expose] def PureReturn {ps : PostShape} [Monad m] [WPMonad m ps] + (f : α → m β) (f' : α → β) : Prop := + ∀ a, ⦃⌜True⌝⦄ f a ⦃⇓b => ⌜b = f' a⌝⦄ + +/-- `pure ∘ f'` has pure return `f'`. -/ +theorem PureReturn.pure {ps : PostShape} [Monad m] [WPMonad m ps] (f' : α → β) : + PureReturn (fun a => Pure.pure (f' a) : α → m β) f' := by + intro a; mvcgen + +/-- A function with a pure return is non-failing. 
-/ +theorem PureReturn.nonFailing {ps : PostShape} [Monad m] [WPMonad m ps] + {f : α → m β} {f' : α → β} (h : PureReturn f f') : + ∀ a, ⦃⌜True⌝⦄ f a ⦃⇓_ => ⌜True⌝⦄ := by + intro a + exact Triple.entails_wp_of_post (h a) (by + simp only [PostCond.entails_noThrow]; intro _; exact SPred.pure_mono (fun _ => trivial)) + +instance : Monad TimeM := inferInstanceAs (Monad (TimeT Id)) +instance : LawfulMonad TimeM := inferInstanceAs (LawfulMonad (TimeT Id)) +instance : Std.Do.WP TimeM (.arg TimeT.State .pure) := + inferInstanceAs (Std.Do.WP (TimeT Id) _) +instance : WPMonad TimeM (.arg TimeT.State .pure) := + inferInstanceAs (WPMonad (TimeT Id) _) + +end Cslib.Query + +end -- public section diff --git a/Cslib/Algorithms/Lean/Query/LowerBound.lean b/Cslib/Algorithms/Lean/Query/LowerBound.lean new file mode 100644 index 000000000..d04478534 --- /dev/null +++ b/Cslib/Algorithms/Lean/Query/LowerBound.lean @@ -0,0 +1,50 @@ +/- +Copyright (c) 2026 Lean FRO, LLC. All rights reserved. +Released under Apache 2.0 license as described in the file LICENSE. +Authors: Kim Morrison, Eric Wieser +-/ +module + +public import Cslib.Algorithms.Lean.Query.QueryTree + +/-! # LowerBound: Query Complexity Lower Bounds via Decision Trees + +`Query.LowerBound f size bound` asserts a worst-case lower bound on the number of queries +made by a monad-generic algorithm `f`. + +## The decision tree argument + +To prove lower bounds, we specialize the algorithm to the `QueryTree` monad (the free monad +over queries), which reifies the algorithm's query pattern as an explicit decision tree. +Each internal node corresponds to a query, and each leaf to a final result. + +The predicate `LowerBound f size bound` states: for every input size `n`, there exists +an input `x` with `size x = n` and an oracle such that the algorithm makes at least +`bound n` queries. This is the worst-case formulation — for each size we exhibit a hard +input, and choose the oracle after seeing the algorithm's strategy (the tree). 
+-/ + +open Cslib.Query + +public section + +namespace Cslib.Query + +/-- `LowerBound f size bound` asserts that for every input size `n`, there exists + an input `x` with `size x = n` and an oracle making `f` perform at least `bound n` queries. + + The algorithm `f` is specialized to `QueryTree`, reifying its query pattern + as a decision tree. The oracle determines which path through the tree is taken. + Unlike `UpperBound` (upper bounds), no parametricity assumption is needed: the tree + structure itself forces enough branching for correctness. + + Note: one could potentially relax `∀ n` to "for infinitely many n" for some applications. -/ +@[expose] def LowerBound + (f : ∀ {m : Type → Type} [Monad m], (α → m β) → γ → m δ) + (size : γ → Nat) (bound : Nat → Nat) : Prop := + ∀ n, ∃ x, size x = n ∧ ∃ (oracle : α → β), + (f QueryTree.ask x : QueryTree α β δ).queriesOn oracle ≥ bound n + +end Cslib.Query + +end -- public section diff --git a/Cslib/Algorithms/Lean/Query/MonadicExample.lean b/Cslib/Algorithms/Lean/Query/MonadicExample.lean new file mode 100644 index 000000000..0e419131a --- /dev/null +++ b/Cslib/Algorithms/Lean/Query/MonadicExample.lean @@ -0,0 +1,105 @@ +/- +Copyright (c) 2026 Lean FRO, LLC. All rights reserved. +Released under Apache 2.0 license as described in the file LICENSE. +Authors: Kim Morrison +-/ +module + +public import Cslib.Algorithms.Lean.Query.UpperBound +import Std.Tactic.Do + +/-! # Monadic Map-Sum Demo + +A minimal demonstration of `UpperBoundT` with a non-trivial base monad. + +`mapSum f xs` applies the monadic function `f` to each element of `xs` +and accumulates the results into a `StateM Int` counter. + +Since `f` is called exactly once per element, `mapSum` runs in `xs.length` queries. +This is expressed via `UpperBoundT (n := StateM Int)`, which instruments the query with tick-counting +while preserving the `StateM Int` layer used for accumulation. 
+-/ + +open Std.Do Cslib.Query + +set_option mvcgen.warning false + +public section + +namespace Cslib.Query + +/-- Apply a monadic function to each element of a list and accumulate the results in `StateM Int`. + The function `f` is the query whose invocations we measure. -/ +@[expose] def mapSum [Monad m] [MonadLiftT (StateM Int) m] + (f : Int → m Int) : List Int → m Unit + | [] => pure () + | x :: xs => do + let y ← f x + (modify (· + y) : StateM Int Unit) + mapSum f xs + +/-- `mapSum` calls the function exactly once per element. -/ +public theorem mapSum_runsInT : + UpperBoundT (n := StateM Int) mapSum (fun xs => xs.length) := by + intro query hquery xs + induction xs with + | nil => exact TimeT.Costs.pure () + | cons x xs ih => + simp only [List.length]; rw [Nat.add_comm] + have ih : TimeT.Costs (mapSum query xs) xs.length := ih + exact TimeT.Costs.bind (hquery x) (fun y => by + have := TimeT.Costs.bind + (TimeT.Costs.monadLift (modify (· + y) : StateM Int Unit) (fun P => by mvcgen)) + (fun _ => ih) + rwa [Nat.zero_add] at this) + +/-- `mapSum` with a state-preserving monadic function accumulates `(xs.map g).sum`. + + The predicate family `pre c` captures "the Int state is c" within the + abstract postcondition shape `ps`. The hypotheses `hf` and `h_modify` + assert that `f` preserves this predicate and the lifted `modify` transitions it. 
-/ +public theorem mapSum_spec_general [Monad m] [MonadLiftT (StateM Int) m] [WPMonad m ps] + (f : Int → m Int) (g : Int → Int) + (pre : Int → Assertion ps) + (hf : ∀ x c, ⦃pre c⦄ f x ⦃⇓ y => pre c ∧ ⌜y = g x⌝⦄) + (h_modify : ∀ v c, ⦃pre c⦄ + (MonadLiftT.monadLift (modify (· + v) : StateM Int Unit) : m Unit) + ⦃⇓ _ => pre (c + v)⦄) + (xs : List Int) : + ∀ c, ⦃pre c⦄ mapSum f xs ⦃⇓ _ => pre (c + (xs.map g).sum)⦄ := by + induction xs with + | nil => + intro c; simp only [mapSum]; mvcgen; simp_all + | cons x xs ih => + intro c; dsimp only [mapSum]; mvcgen [hf, h_modify, ih] + subst_vars; simp only [List.map_cons, List.sum_cons, ← Int.add_assoc]; exact .rfl + +/-- `mapSum` with a pure function accumulates `(xs.map f).sum`. + Special case of `mapSum_spec_general`. -/ +public theorem mapSum_spec (f : Int → Int) (xs : List Int) : + ∀ c, ⦃fun n => ⌜n = c⌝⦄ + mapSum (m := StateM Int) (fun a => pure (f a)) xs + ⦃⇓ _ => fun n => ⌜n = c + (xs.map f).sum⌝⦄ := + mapSum_spec_general (m := StateM Int) (fun a => pure (f a)) f + (fun c => (fun n => ⌜n = c⌝ : Assertion _)) + (by intro x c; mvcgen) + (by intro v c; mvcgen; subst_vars; rfl) + xs + +/-- `mapSum` with a tick-instrumented pure function still accumulates `(xs.map f).sum`. + Special case of `mapSum_spec_general`. 
-/ +public theorem mapSum_spec_tick (f : Int → Int) (xs : List Int) : + ∀ c, ⦃fun _ => fun n => ⌜n = c⌝⦄ + mapSum (m := TimeT (StateM Int)) (TimeT.counted f) xs + ⦃⇓ _ => fun _ => fun n => ⌜n = c + (xs.map f).sum⌝⦄ := + mapSum_spec_general (m := TimeT (StateM Int)) (TimeT.counted f) f + (fun c => (fun _ => fun n => ⌜n = c⌝ : Assertion _)) + (by intro x c; simp only [TimeT.counted]; mvcgen) + (by intro v c; simp only [Triple]; mvcgen + simp only [TimeT.wp_monadLift, Std.Do.WP.modifyGet_StateT] + intro _ _ h; subst h; rfl) + xs + +end Cslib.Query + +end -- public section diff --git a/Cslib/Algorithms/Lean/Query/QueryTree.lean b/Cslib/Algorithms/Lean/Query/QueryTree.lean new file mode 100644 index 000000000..431a22d4e --- /dev/null +++ b/Cslib/Algorithms/Lean/Query/QueryTree.lean @@ -0,0 +1,194 @@ +/- +Copyright (c) 2026 Lean FRO, LLC. All rights reserved. +Released under Apache 2.0 license as described in the file LICENSE. +Authors: Kim Morrison +-/ +module + +public import Cslib.Algorithms.Lean.Query.Basic +public import Mathlib.Data.Nat.Log +public import Mathlib.Data.Fintype.Card +public import Mathlib.Data.Fintype.Sum + +/-! # QueryTree: Decision Trees for Query Complexity Lower Bounds + +`QueryTree Q R α` is a free monad specialized to a single query type: queries take +input `Q` and return `R`, with final results of type `α`. It reifies a monad-generic +algorithm's query pattern as an explicit decision tree. + +## Motivation + +The upper-bound framework (`UpperBound`) uses parametricity: a monad-generic algorithm +is specialized to `TimeM` (tick-counting monad) to count queries. This doesn't work +for lower bounds because a noncomputable algorithm could detect `TimeM` and cheat +(e.g., `if h : m = TimeM then ... else ...`). + +Instead, we specialize the algorithm to `QueryTree Q R`, which *is* the decision tree. +The tree structure directly records every query the algorithm makes, regardless of whether +the algorithm is computable.
Combinatorial arguments on the tree (leaf counting, depth +bounds) then yield lower bounds. + +## Design + +`QueryTree Q R` is isomorphic to `FreeM (fun _ => Q)` restricted to operations returning `R`, +but is defined as a dedicated inductive to avoid universe issues with `FreeM`'s existential +`ι` parameter (which would require producing values of arbitrary types during evaluation). + +Note that the graph-theoretic depth of the tree can be strictly larger than +`sup_oracle queriesOn oracle`, because the same query may appear at the root and inside a +subtree, and a single oracle must give consistent answers. + +## OracleQueryTree + +`OracleQueryTree Q R oracle` is a type alias for `QueryTree Q R` that bakes a fixed oracle +into the type. This allows a `WPMonad` instance where `wp t = pure (t.eval oracle)`, +connecting `QueryTree` evaluation to the Hoare-triple framework used by `IsMonadicSort`. + +## Main Definitions + +- `QueryTree Q R α` — the decision tree type +- `QueryTree.ask` — the canonical single-query tree +- `QueryTree.eval` — evaluate with a specific oracle +- `QueryTree.queriesOn` — count queries along an oracle-determined path +- `OracleQueryTree Q R oracle` — type alias with `WPMonad` instance +-/ + +open Std.Do Cslib.Query + +public section + +namespace Cslib.Query + +/-- A decision tree over queries of type `Q → R`, with results of type `α`. + +This is the free monad specialized to a single fixed-type operation, used to reify +monad-generic algorithms as explicit trees for query complexity lower bounds. -/ +inductive QueryTree (Q : Type) (R : Type) (α : Type) where + /-- A completed computation returning value `a`. -/ + | pure (a : α) : QueryTree Q R α + /-- A query node: asks query `q`, then continues based on the response. -/ + | query (q : Q) (cont : R → QueryTree Q R α) : QueryTree Q R α + +namespace QueryTree + +variable {Q R α β γ : Type} + +/-- Lift a single query into the tree. 
-/ +@[expose] def ask (q : Q) : QueryTree Q R R := .query q .pure + +/-- Monadic bind for query trees. -/ +@[expose] protected def bind : QueryTree Q R α → (α → QueryTree Q R β) → QueryTree Q R β + | .pure a, f => f a + | .query q cont, f => .query q (fun r => (cont r).bind f) + +/-- Functorial map for query trees. -/ +@[expose] protected def map (f : α → β) : QueryTree Q R α → QueryTree Q R β + | .pure a => .pure (f a) + | .query q cont => .query q (fun r => (cont r).map f) + +protected theorem bind_pure : ∀ (x : QueryTree Q R α), x.bind .pure = x + | .pure _ => rfl + | .query _ cont => by simp [QueryTree.bind, QueryTree.bind_pure] + +protected theorem bind_assoc : + ∀ (x : QueryTree Q R α) (f : α → QueryTree Q R β) (g : β → QueryTree Q R γ), + (x.bind f).bind g = x.bind (fun a => (f a).bind g) + | .pure _, _, _ => rfl + | .query _ cont, f, g => by simp [QueryTree.bind, QueryTree.bind_assoc] + +protected theorem bind_pure_comp (f : α → β) : + ∀ (x : QueryTree Q R α), x.bind (.pure ∘ f) = x.map f + | .pure _ => rfl + | .query _ cont => by simp [QueryTree.bind, QueryTree.map, QueryTree.bind_pure_comp] + +protected theorem id_map : ∀ (x : QueryTree Q R α), x.map id = x + | .pure _ => rfl + | .query _ cont => by simp [QueryTree.map, QueryTree.id_map] + +instance : Monad (QueryTree Q R) where + pure := .pure + bind := .bind + +instance : LawfulMonad (QueryTree Q R) := LawfulMonad.mk' + (bind_pure_comp := fun _ _ => rfl) + (id_map := QueryTree.bind_pure) + (pure_bind := fun _ _ => rfl) + (bind_assoc := QueryTree.bind_assoc) + +-- Core operations + +/-- Evaluate a query tree with a specific oracle, returning the final result. -/ +@[expose] def eval (oracle : Q → R) : QueryTree Q R α → α + | .pure a => a + | .query q cont => eval oracle (cont (oracle q)) + +/-- Count the number of queries along the path determined by `oracle`. 
-/ +@[expose] def queriesOn (oracle : Q → R) : QueryTree Q R α → Nat + | .pure _ => 0 + | .query q cont => 1 + queriesOn oracle (cont (oracle q)) + +-- Simp lemmas + +@[simp] theorem eval_pure' (oracle : Q → R) (a : α) : + (QueryTree.pure a : QueryTree Q R α).eval oracle = a := rfl + +@[simp] theorem eval_query (oracle : Q → R) (q : Q) (cont : R → QueryTree Q R α) : + (QueryTree.query q cont).eval oracle = (cont (oracle q)).eval oracle := rfl + +@[simp] theorem eval_bind (oracle : Q → R) (t : QueryTree Q R α) (f : α → QueryTree Q R β) : + (t.bind f).eval oracle = (f (t.eval oracle)).eval oracle := by + induction t with + | pure a => rfl + | query q cont ih => exact ih (oracle q) + +@[simp] theorem queriesOn_pure' (oracle : Q → R) (a : α) : + (QueryTree.pure a : QueryTree Q R α).queriesOn oracle = 0 := rfl + +@[simp] theorem queriesOn_query (oracle : Q → R) (q : Q) (cont : R → QueryTree Q R α) : + (QueryTree.query q cont).queriesOn oracle = 1 + (cont (oracle q)).queriesOn oracle := rfl + +/-- Queries of `t.bind f` = queries of `t` + queries of the continuation. -/ +@[simp] theorem queriesOn_bind (oracle : Q → R) (t : QueryTree Q R α) (f : α → QueryTree Q R β) : + (t.bind f).queriesOn oracle = + t.queriesOn oracle + (f (t.eval oracle)).queriesOn oracle := by + induction t with + | pure a => simp [QueryTree.bind, queriesOn, eval] + | query q cont ih => simp only [QueryTree.bind, queriesOn_query, eval_query, ih (oracle q)]; omega + +@[simp] theorem queriesOn_ask (oracle : Q → R) (q : Q) : + (ask q : QueryTree Q R R).queriesOn oracle = 1 := rfl + +@[simp] theorem eval_ask (oracle : Q → R) (q : Q) : + (ask q : QueryTree Q R R).eval oracle = oracle q := rfl + + +end QueryTree + +-- ## OracleQueryTree: WPMonad instance for QueryTree with a fixed oracle + +/-- `OracleQueryTree Q R oracle` is `QueryTree Q R` with a fixed oracle baked into the type, + enabling a `WPMonad` instance where `wp t = pure (t.eval oracle)`. 
-/ +abbrev OracleQueryTree (Q R : Type) (_oracle : Q → R) := QueryTree Q R + +namespace OracleQueryTree + +variable {Q R : Type} {oracle : Q → R} + +instance instWP : WP (OracleQueryTree Q R oracle) .pure where + wp t := pure (t.eval oracle) + +instance instWPMonad : WPMonad (OracleQueryTree Q R oracle) .pure where + wp_pure a := by simp [wp, QueryTree.eval] + wp_bind x f := by + simp only [wp] + congr 1 + exact QueryTree.eval_bind oracle x f + +@[simp] theorem wp_eq (t : OracleQueryTree Q R oracle α) : + instWP.wp t = pure (t.eval oracle) := rfl + +end OracleQueryTree + +end Cslib.Query + +end -- public section diff --git a/Cslib/Algorithms/Lean/Query/Sort/Insertion/Defs.lean b/Cslib/Algorithms/Lean/Query/Sort/Insertion/Defs.lean new file mode 100644 index 000000000..d269dd427 --- /dev/null +++ b/Cslib/Algorithms/Lean/Query/Sort/Insertion/Defs.lean @@ -0,0 +1,32 @@ +/- +Copyright (c) 2026 Lean FRO, LLC. All rights reserved. +Released under Apache 2.0 license as described in the file LICENSE. +Authors: Kim Morrison +-/ +module + +public import Cslib.Algorithms.Lean.Query.Basic + +namespace Cslib.Query + +/-- Insert `x` into a sorted list, using the monadic comparator `cmp`. -/ +@[expose] public def orderedInsert [Monad m] (cmp : α × α → m Bool) (x : α) : + List α → m (List α) + | [] => pure [x] + | y :: ys => do + let lt ← cmp (x, y) + if lt then + pure (x :: y :: ys) + else + let rest ← orderedInsert cmp x ys + pure (y :: rest) + +/-- Sort a list using insertion sort with the monadic comparator `cmp`. 
-/ +@[expose] public def insertionSort [Monad m] (cmp : α × α → m Bool) : + List α → m (List α) + | [] => pure [] + | x :: xs => do + let sorted ← insertionSort cmp xs + orderedInsert cmp x sorted + +end Cslib.Query diff --git a/Cslib/Algorithms/Lean/Query/Sort/Insertion/Lemmas.lean b/Cslib/Algorithms/Lean/Query/Sort/Insertion/Lemmas.lean new file mode 100644 index 000000000..832aaf03b --- /dev/null +++ b/Cslib/Algorithms/Lean/Query/Sort/Insertion/Lemmas.lean @@ -0,0 +1,227 @@ +/- +Copyright (c) 2026 Lean FRO, LLC. All rights reserved. +Released under Apache 2.0 license as described in the file LICENSE. +Authors: Kim Morrison +-/ +module + +public import Cslib.Algorithms.Lean.Query.UpperBound +public import Cslib.Algorithms.Lean.Query.Sort.Insertion.Defs +import Std.Tactic.Do +public import Mathlib.Data.List.Sort + +open Std.Do Cslib.Query TimeT + +set_option mvcgen.warning false + +public section + +namespace Cslib.Query + +/-- `orderedInsert` produces a permutation of `x :: xs`, for any non-failing monadic comparator. -/ +public theorem orderedInsert_perm {ps : PostShape} [Monad m] [WPMonad m ps] + (cmp : α × α → m Bool) (hcmp : ∀ p, ⦃⌜True⌝⦄ cmp p ⦃⇓_ => ⌜True⌝⦄) + (x : α) (xs : List α) : + ⦃⌜True⌝⦄ orderedInsert cmp x xs ⦃⇓result => ⌜List.Perm result (x :: xs)⌝⦄ := by + induction xs with + | nil => + simp only [orderedInsert] + mvcgen + | cons y ys ih => + simp only [orderedInsert] + mvcgen [ih, hcmp] + · mpure_intro; exact (List.Perm.cons _ ‹_›).trans (List.Perm.swap _ _ _) + +/-- Variant of `orderedInsert_perm` with a permutation precondition: + if `sorted` is a permutation of `xs`, + then `orderedInsert` produces a permutation of `x :: xs`. 
-/ +private theorem orderedInsert_perm' {ps : PostShape} [Monad m] [WPMonad m ps] + (cmp : α × α → m Bool) (hcmp : ∀ p, ⦃⌜True⌝⦄ cmp p ⦃⇓_ => ⌜True⌝⦄) + (x : α) (xs : List α) (sorted : List α) : + ⦃⌜List.Perm sorted xs⌝⦄ orderedInsert cmp x sorted + ⦃⇓ result => ⌜List.Perm result (x :: xs)⌝⦄ := by + simp only [Triple] + apply SPred.pure_elim' + intro hsorted + exact Triple.entails_wp_of_post (orderedInsert_perm cmp hcmp x sorted) (by + simp only [PostCond.entails_noThrow] + intro result + exact SPred.pure_mono fun hperm => hperm.trans (List.Perm.cons x hsorted)) + +/-- `insertionSort` produces a permutation of its input, for any non-failing monadic comparator. -/ +public theorem insertionSort_perm {ps : PostShape} [Monad m] [WPMonad m ps] + (cmp : α × α → m Bool) (hcmp : ∀ p, ⦃⌜True⌝⦄ cmp p ⦃⇓_ => ⌜True⌝⦄) + (xs : List α) : + ⦃⌜True⌝⦄ insertionSort cmp xs ⦃⇓result => ⌜List.Perm result xs⌝⦄ := by + induction xs with + | nil => + simp only [insertionSort] + mvcgen + | cons x xs ih => + simp only [insertionSort] + have := orderedInsert_perm' cmp hcmp x xs + mvcgen [ih, this] + +/-- `orderedInsert` uses at most `xs.length` queries. -/ +public theorem orderedInsert_runsIn (x : α) : + UpperBound (fun cmp xs => orderedInsert cmp x xs) List.length := by + change ∀ (query : (α × α) → TimeM Bool), (∀ a, TimeT.Costs (query a) 1) → + ∀ xs, TimeT.Costs (orderedInsert query x xs) xs.length + intro query hquery xs + induction xs with + | nil => + simp only [orderedInsert] + exact Costs.pure _ + | cons y ys ih => + dsimp only [orderedInsert] + apply Costs.le + · exact Costs.bind (hquery (x, y)) fun lt => + Costs.ite lt (Costs.pure_le _ _) (Costs.map ih) + · simp only [List.length]; omega + +/-- `insertionSort` uses at most `xs.length ^ 2` queries. 
-/ +public theorem insertionSort_runsIn : + UpperBound (insertionSort (α := α)) (fun xs => xs.length ^ 2) := by + change ∀ (query : (α × α) → TimeM Bool), (∀ a, TimeT.Costs (query a) 1) → + ∀ xs, TimeT.Costs (insertionSort query xs) (xs.length ^ 2) + intro query hquery xs + induction xs with + | nil => + simp only [insertionSort] + exact Costs.pure _ + | cons x xs ih => + dsimp only [insertionSort] + apply Costs.le + · exact Costs.bind_spec ih + (insertionSort_perm query (fun p => SPred.pure_intro trivial) xs) + fun sorted hperm => by + have := orderedInsert_runsIn x query hquery sorted + rwa [List.Perm.length_eq hperm] at this + · simp only [List.length, Nat.pow_two]; have := Nat.mul_succ xs.length xs.length; grind + +/-- The monadic `orderedInsert` at `m := Id` agrees with `List.orderedInsert`. -/ +public theorem id_run_orderedInsert (r : α → α → Prop) [DecidableRel r] (x : α) (xs : List α) : + Id.run (orderedInsert (fun p => pure (decide (r p.1 p.2))) x xs) = + List.orderedInsert r x xs := by + induction xs with + | nil => simp [orderedInsert, Id.run_pure] + | cons y ys ih => + simp only [orderedInsert, Id.run_bind, Id.run_pure, List.orderedInsert_cons] + split <;> simp_all [decide_eq_true_eq] + +/-- The monadic `insertionSort` at `m := Id` agrees with `List.insertionSort`. -/ +public theorem id_run_insertionSort (r : α → α → Prop) [DecidableRel r] (xs : List α) : + Id.run (insertionSort (fun p => pure (decide (r p.1 p.2))) xs) = + List.insertionSort r xs := by + induction xs with + | nil => simp [insertionSort, Id.run_pure] + | cons x xs ih => + simp only [insertionSort, Id.run_bind, List.insertionSort_cons, ih] + exact id_run_orderedInsert r x (List.insertionSort r xs) + +-- Sorted results + +section Sorted + +variable (r : α → α → Prop) [DecidableRel r] [Std.Total r] [IsTrans α r] + +/-- `orderedInsert` preserves sortedness and produces a permutation, for any monadic comparator + with a pure return reflecting `r`. 
This combined version is needed because the sortedness + proof in the recursive case requires knowing the result is a permutation of the input. -/ +private theorem orderedInsert_spec {ps : PostShape} [Monad m] [WPMonad m ps] + (cmp : α × α → m Bool) (hcmp : PureReturn cmp (fun p => decide (r p.1 p.2))) + (x : α) (xs : List α) (hpw : List.Pairwise r xs) : + ⦃⌜True⌝⦄ orderedInsert cmp x xs + ⦃⇓result => ⌜List.Pairwise r result ∧ List.Perm result (x :: xs)⌝⦄ := by + induction xs with + | nil => + simp only [orderedInsert] + mvcgen [hcmp] + · mpure_intro; exact ⟨List.pairwise_singleton r x, .refl _⟩ + | cons y ys ih => + simp only [orderedInsert] + have ih' := ih hpw.of_cons + have hcmp' : ∀ p, ⦃⌜True⌝⦄ cmp p ⦃⇓b => ⌜b = decide (r p.1 p.2)⌝⦄ := hcmp + mvcgen [ih', hcmp'] + · mpure_intro + have hlt : r x y := by simp_all [decide_eq_true_eq] + exact ⟨List.pairwise_cons.mpr ⟨fun z hz => + match List.mem_cons.mp hz with + | Or.inl h => h ▸ hlt + | Or.inr h => _root_.trans hlt (List.rel_of_pairwise_cons hpw h), hpw⟩, .refl _⟩ + · mpure_intro + rename_i _ _ _ hrest + obtain ⟨hrest_pw, hrest_perm⟩ := hrest + have hlt : ¬ r x y := by simp_all [decide_eq_true_eq] + exact ⟨List.pairwise_cons.mpr ⟨fun z hz => + match List.mem_cons.mp (hrest_perm.mem_iff.mp hz) with + | Or.inl h => h ▸ (Std.Total.total y x).resolve_right hlt + | Or.inr h => List.rel_of_pairwise_cons hpw h, hrest_pw⟩, + (List.Perm.cons y hrest_perm).trans (List.Perm.swap x y ys)⟩ + +/-- `orderedInsert` preserves sortedness, for any monadic comparator with a pure return + reflecting `r`. 
-/ +public theorem orderedInsert_sorted {ps : PostShape} [Monad m] [WPMonad m ps] + (cmp : α × α → m Bool) (hcmp : PureReturn cmp (fun p => decide (r p.1 p.2))) + (x : α) (xs : List α) : + ⦃⌜List.Pairwise r xs⌝⦄ orderedInsert cmp x xs + ⦃⇓result => ⌜List.Pairwise r result⌝⦄ := by + simp only [Triple] + apply SPred.pure_elim' + intro hpw + exact Triple.entails_wp_of_post (orderedInsert_spec r cmp hcmp x xs hpw) (by + simp only [PostCond.entails_noThrow]; intro _; exact SPred.pure_mono And.left) + +/-- `insertionSort` produces a sorted list, for any monadic comparator with a pure return + reflecting `r`. -/ +public theorem insertionSort_sorted {ps : PostShape} [Monad m] [WPMonad m ps] + (cmp : α × α → m Bool) (hcmp : PureReturn cmp (fun p => decide (r p.1 p.2))) + (xs : List α) : + ⦃⌜True⌝⦄ insertionSort cmp xs + ⦃⇓result => ⌜List.Pairwise r result⌝⦄ := by + induction xs with + | nil => + simp only [insertionSort] + mvcgen + · mpure_intro; exact List.Pairwise.nil + | cons x xs ih => + simp only [insertionSort] + have hord := orderedInsert_sorted r cmp hcmp x + mvcgen [ih, hord] + +/-- At `m := Id`, `orderedInsert` preserves sortedness. -/ +public theorem orderedInsert_sorted_id + (x : α) (xs : List α) (h : List.Pairwise r xs) : + List.Pairwise r (Id.run (orderedInsert (fun p => pure (decide (r p.1 p.2))) x xs)) := by + have := orderedInsert_sorted r (m := Id) _ (PureReturn.pure _) x xs + simp only [Triple] at this + exact this h + +/-- At `m := Id`, `insertionSort` produces a sorted list. -/ +public theorem insertionSort_sorted_id (xs : List α) : + List.Pairwise r (Id.run (insertionSort (fun p => pure (decide (r p.1 p.2))) xs)) := by + have := insertionSort_sorted r (m := Id) _ (PureReturn.pure _) xs + simp only [Triple] at this + exact this trivial + +/-- At `m := TimeT n`, `orderedInsert` preserves sortedness (with a pure comparator). 
-/ +public theorem orderedInsert_sorted_timeT {ps : PostShape} [Monad n] [WPMonad n ps] + (x : α) (xs : List α) : + ⦃⌜List.Pairwise r xs⌝⦄ + orderedInsert (m := TimeT n) (fun p => pure (decide (r p.1 p.2))) x xs + ⦃⇓result => ⌜List.Pairwise r result⌝⦄ := + orderedInsert_sorted r _ (PureReturn.pure _) x xs + +/-- At `m := TimeT n`, `insertionSort` produces a sorted list (with a pure comparator). -/ +public theorem insertionSort_sorted_timeT {ps : PostShape} [Monad n] [WPMonad n ps] + (xs : List α) : + ⦃⌜True⌝⦄ + insertionSort (m := TimeT n) (fun p => pure (decide (r p.1 p.2))) xs + ⦃⇓result => ⌜List.Pairwise r result⌝⦄ := + insertionSort_sorted r _ (PureReturn.pure _) xs + +end Sorted + +end Cslib.Query + +end -- public section diff --git a/Cslib/Algorithms/Lean/Query/Sort/LowerBound.lean b/Cslib/Algorithms/Lean/Query/Sort/LowerBound.lean new file mode 100644 index 000000000..0b4ba7219 --- /dev/null +++ b/Cslib/Algorithms/Lean/Query/Sort/LowerBound.lean @@ -0,0 +1,194 @@ +/- +Copyright (c) 2026 Lean FRO, LLC. All rights reserved. +Released under Apache 2.0 license as described in the file LICENSE. +Authors: Kim Morrison, Eric Wieser +-/ +module + +public import Cslib.Algorithms.Lean.Query.LowerBound +public import Cslib.Algorithms.Lean.Query.Sort.MonadicSort +public import Cslib.Algorithms.Lean.Query.Sort.QueryTree +public import Mathlib.Data.List.Sort +public import Mathlib.Data.Nat.Factorial.Basic +public import Mathlib.Data.Fintype.Perm +public import Mathlib.Data.List.FinRange +public import Mathlib.SetTheory.Cardinal.Order + +/-! # Comparison Sorting Lower Bound + +`IsMonadicSort.lowerBound_infinite`: any correct comparison sort on an infinite type +has query complexity at least `⌈log₂(n!)⌉` for every input size `n`. + +The proof constructs `n!` distinct total orders on `α` (one per permutation of `n` +embedded elements), shows they produce distinct sorted outputs via `queryTree_correct`, +and applies `QueryTree.exists_queriesOn_ge_clog`. 
+-/ + +open Std.Do Cslib.Query + +public section + +namespace Cslib.Query + +/-- Specializing `IsMonadicSort` to `OracleQueryTree` gives the `QueryTree` evaluation property + needed for lower bound proofs. For any total order `r`, evaluating the sort's decision tree + with `r`'s comparison function produces a sorted permutation. -/ +theorem IsMonadicSort.queryTree_correct + {sort : ∀ {m : Type → Type} [Monad m], (α × α → m Bool) → List α → m (List α)} + (h : IsMonadicSort sort) + (r : α → α → Prop) [DecidableRel r] [IsTrans α r] [Std.Total r] + (xs : List α) : + let oracle := fun p : α × α => decide (r p.1 p.2) + let result := (sort (fun p => QueryTree.ask p) xs : + QueryTree (α × α) Bool (List α)).eval oracle + result.Perm xs ∧ result.Pairwise r := by + constructor + · -- Permutation: specialize h.perm to OracleQueryTree + have := @h.perm (OracleQueryTree (α × α) Bool (fun p => decide (r p.1 p.2))) + .pure _ OracleQueryTree.instWPMonad + (fun p => QueryTree.ask p) + (fun p => by simp [Triple, OracleQueryTree.wp_eq]) + xs + exact this trivial + · -- Sortedness: specialize h.sorted to OracleQueryTree + have := @h.sorted r _ _ _ + (OracleQueryTree (α × α) Bool (fun p => decide (r p.1 p.2))) + .pure _ OracleQueryTree.instWPMonad + (fun p => QueryTree.ask p) + (fun p => by simp [Triple, OracleQueryTree.wp_eq]) + xs + exact this trivial + +open Classical in +/-- A total order on an infinite type `α` that orders `n` embedded elements + (via `Infinite.natEmbedding`) according to `σ⁻¹`, with embedded elements + preceding all others, and a well-ordering among non-embedded elements. 
-/ +private noncomputable def infinitePermOrder [Infinite α] (n : Nat) + (σ : Equiv.Perm (Fin n)) (a b : α) : Prop := + if ha : ∃ i : Fin n, (Infinite.natEmbedding α) i.val = a then + if hb : ∃ j : Fin n, (Infinite.natEmbedding α) j.val = b then + σ.symm ha.choose ≤ σ.symm hb.choose + else True + else + if _ : ∃ j : Fin n, (Infinite.natEmbedding α) j.val = b then False + else @LE.le α (IsWellOrder.linearOrder (α := α) WellOrderingRel).toLE a b + +private noncomputable instance infinitePermOrder.instDecidableRel [Infinite α] : + DecidableRel (infinitePermOrder (α := α) n σ) := Classical.decRel _ + +private theorem infinitePermOrder.choose_eq [Infinite α] {i : Fin n} + (h : ∃ j : Fin n, (Infinite.natEmbedding α) j.val = (Infinite.natEmbedding α) i.val) : + h.choose = i := by + have := h.choose_spec + have := (Infinite.natEmbedding α).injective this + exact Fin.ext this + +private theorem infinitePermOrder.instIsTrans [Infinite α] : + IsTrans α (infinitePermOrder (α := α) n σ) where + trans a b c hab hbc := by + letI : LinearOrder α := IsWellOrder.linearOrder WellOrderingRel + unfold infinitePermOrder at * + by_cases ha : ∃ i : Fin n, (Infinite.natEmbedding α) i.val = a <;> + by_cases hb : ∃ j : Fin n, (Infinite.natEmbedding α) j.val = b <;> + by_cases hc : ∃ k : Fin n, (Infinite.natEmbedding α) k.val = c <;> + simp_all only [↓reduceDIte, not_exists, exists_false] <;> + exact le_trans ‹_› ‹_› + +private theorem infinitePermOrder.instTotal [Infinite α] : + Std.Total (infinitePermOrder (α := α) n σ) where + total a b := by + letI : LinearOrder α := IsWellOrder.linearOrder WellOrderingRel + unfold infinitePermOrder + by_cases ha : ∃ i : Fin n, (Infinite.natEmbedding α) i.val = a <;> + by_cases hb : ∃ j : Fin n, (Infinite.natEmbedding α) j.val = b <;> + simp_all [le_total] + +private theorem infinitePermOrder.instAntisymm [Infinite α] : + Std.Antisymm (infinitePermOrder (α := α) n σ) where + antisymm a b hab hba := by + letI : LinearOrder α := IsWellOrder.linearOrder 
WellOrderingRel + simp only [infinitePermOrder] at hab hba + by_cases ha : ∃ i : Fin n, (Infinite.natEmbedding α) i.val = a <;> + by_cases hb : ∃ j : Fin n, (Infinite.natEmbedding α) j.val = b <;> + simp_all only [↓reduceDIte, not_exists] + · calc a = (Infinite.natEmbedding α) ha.choose.val := ha.choose_spec.symm + _ = (Infinite.natEmbedding α) hb.choose.val := by + congr 1; exact congrArg Fin.val (σ.symm.injective (le_antisymm hab hba)) + _ = b := hb.choose_spec + · exact le_antisymm hab hba + +/-- `infinitePermOrder` restricted to embedded values matches `σ⁻¹(·) ≤ σ⁻¹(·)`. -/ +private theorem infinitePermOrder_on_embedded [Infinite α] {i j : Fin n} : + infinitePermOrder (α := α) n σ ((Infinite.natEmbedding α) i.val) + ((Infinite.natEmbedding α) j.val) ↔ σ.symm i ≤ σ.symm j := by + have hi : ∃ k : Fin n, (Infinite.natEmbedding α) k.val = (Infinite.natEmbedding α) i.val := + ⟨i, rfl⟩ + have hj : ∃ k : Fin n, (Infinite.natEmbedding α) k.val = (Infinite.natEmbedding α) j.val := + ⟨j, rfl⟩ + unfold infinitePermOrder + rw [dif_pos hi, dif_pos hj, infinitePermOrder.choose_eq hi, infinitePermOrder.choose_eq hj] + +/-- `map (ι ∘ Fin.val ∘ σ) (finRange n)` is pairwise sorted by `infinitePermOrder n σ`. -/ +private theorem pairwise_map_infinitePermOrder [Infinite α] (σ : Equiv.Perm (Fin n)) : + List.Pairwise (infinitePermOrder (α := α) n σ) + ((List.finRange n).map (fun i => (Infinite.natEmbedding α) (σ i).val)) := by + rw [List.pairwise_map] + exact (List.pairwise_le_finRange n).imp fun hab => by + simp only [infinitePermOrder_on_embedded, Equiv.symm_apply_apply] + exact hab + +/-- `map (ι ∘ Fin.val ∘ σ) (finRange n)` is a permutation of `map (ι ∘ Fin.val) (finRange n)`. 
-/ +private theorem map_perm_of_infinite_embedding [Infinite α] (σ : Equiv.Perm (Fin n)) : + ((List.finRange n).map (fun i => (Infinite.natEmbedding α) (σ i).val)).Perm + ((List.finRange n).map (fun i => (Infinite.natEmbedding α) i.val)) := by + rw [show (fun i => (Infinite.natEmbedding α) (σ i).val) = + (fun i => (Infinite.natEmbedding α) i.val) ∘ σ from rfl, ← List.map_map] + exact (Equiv.Perm.map_finRange_perm σ).map _ + +/-- Different permutations give different `map (ι ∘ Fin.val ∘ σ) (finRange n)`. -/ +private theorem map_infinite_embedding_injective [Infinite α] : + Function.Injective (fun σ : Equiv.Perm (Fin n) => + (List.finRange n).map (fun i => (Infinite.natEmbedding α) (σ i).val)) := by + intro σ τ h + exact Equiv.ext fun i => by + have := List.map_inj_left.mp h i (List.mem_finRange i) + exact Fin.val_injective ((Infinite.natEmbedding α).injective this) + +/-- Any correct comparison sort on an infinite type has query complexity at least `⌈log₂(n!)⌉` + for every input size `n`. 
-/ +theorem IsMonadicSort.lowerBound_infinite [Infinite α] + {sort : ∀ {m : Type → Type} [Monad m], (α × α → m Bool) → List α → m (List α)} + (h : IsMonadicSort sort) : + LowerBound sort List.length (fun n => Nat.clog 2 (Nat.factorial n)) := by + intro n + set ι := Infinite.natEmbedding α + refine ⟨(List.finRange n).map (fun i => ι i.val), by simp, ?_⟩ + set xs := (List.finRange n).map (fun i => ι i.val) + set tree := (sort (fun p => QueryTree.ask p) xs : QueryTree (α × α) Bool (List α)) + have hcard : Fintype.card (Equiv.Perm (Fin n)) = Nat.factorial n := by + rw [Fintype.card_perm, Fintype.card_fin] + let e := Fintype.equivFinOfCardEq hcard + let oracles : Fin (Nat.factorial n) → (α × α → Bool) := + fun i p => decide (infinitePermOrder n (e.symm i) p.1 p.2) + have h_inj : Function.Injective (fun i => tree.eval (oracles i)) := by + intro i j h_eval + suffices key : ∀ i, tree.eval (oracles i) = + (List.finRange n).map (fun k => ι ((e.symm i) k).val) by + have h_eval' : tree.eval (oracles i) = tree.eval (oracles j) := h_eval + rw [key, key] at h_eval' + exact e.symm.injective (map_infinite_embedding_injective h_eval') + intro i + set σ := e.symm i + letI := infinitePermOrder.instDecidableRel (α := α) (n := n) (σ := σ) + letI := infinitePermOrder.instIsTrans (α := α) (n := n) (σ := σ) + letI := infinitePermOrder.instTotal (α := α) (n := n) (σ := σ) + have hc := h.queryTree_correct (infinitePermOrder (α := α) n σ) xs + haveI := infinitePermOrder.instAntisymm (α := α) (n := n) (σ := σ) + exact hc.1.trans (map_perm_of_infinite_embedding σ).symm |>.eq_of_pairwise' + hc.2 (pairwise_map_infinitePermOrder σ) + obtain ⟨i, hi⟩ := QueryTree.exists_queriesOn_ge_clog tree oracles (Nat.factorial_pos n) h_inj + exact ⟨oracles i, hi⟩ + +end Cslib.Query + +end -- public section diff --git a/Cslib/Algorithms/Lean/Query/Sort/Merge/Defs.lean b/Cslib/Algorithms/Lean/Query/Sort/Merge/Defs.lean new file mode 100644 index 000000000..b0633d4e0 --- /dev/null +++ 
b/Cslib/Algorithms/Lean/Query/Sort/Merge/Defs.lean @@ -0,0 +1,44 @@ +/- +Copyright (c) 2025 Sorrachai Yingchareonthawornhcai. All rights reserved. +Released under Apache 2.0 license as described in the file LICENSE. +Authors: Sorrachai Yingchareonthawornhcai, Kim Morrison +-/ +module + +public import Cslib.Algorithms.Lean.Query.Basic + +namespace Cslib.Query + +/-- Monad-generic merge: merges two lists using a monadic comparator `cmp`. + At `m := Id`, this agrees with `List.merge` (see `id_run_merge`). -/ +@[expose] public def merge [Monad m] (cmp : α × α → m Bool) : + List α → List α → m (List α) + | [], ys => pure ys + | xs, [] => pure xs + | x :: xs, y :: ys => do + let le ← cmp (x, y) + if le then + let rest ← merge cmp xs (y :: ys) + pure (x :: rest) + else + let rest ← merge cmp (x :: xs) ys + pure (y :: rest) + +open List.MergeSort.Internal in +/-- Monad-generic merge sort: sorts a list using a monadic comparator `cmp`. + At `m := Id`, this is meant to agree with `List.mergeSort` (no conformance lemma is proved). -/ +@[expose] public def mergeSort [Monad m] (cmp : α × α → m Bool) : + List α → m (List α) + | [] => pure [] + | [a] => pure [a] + | a :: b :: xs => + let lr := splitInTwo ⟨a :: b :: xs, rfl⟩ + have : lr.1.1.length < (a :: b :: xs).length := by simp [lr.1.2]; omega + have : lr.2.1.length < (a :: b :: xs).length := by simp [lr.2.2]; omega + do + let left ← mergeSort cmp lr.1.1 + let right ← mergeSort cmp lr.2.1 + merge cmp left right +termination_by xs => xs.length + +end Cslib.Query diff --git a/Cslib/Algorithms/Lean/Query/Sort/Merge/Lemmas.lean b/Cslib/Algorithms/Lean/Query/Sort/Merge/Lemmas.lean new file mode 100644 index 000000000..5aa72bdd6 --- /dev/null +++ b/Cslib/Algorithms/Lean/Query/Sort/Merge/Lemmas.lean @@ -0,0 +1,247 @@ +/- +Copyright (c) 2025 Sorrachai Yingchareonthawornhcai. All rights reserved. +Released under Apache 2.0 license as described in the file LICENSE.
+Authors: Sorrachai Yingchareonthawornhcai, Kim Morrison +-/ +module + +public import Cslib.Algorithms.Lean.Query.UpperBound +public import Cslib.Algorithms.Lean.Query.Sort.Merge.Defs +import Std.Tactic.Do +import Mathlib.Tactic.Linarith +public import Mathlib.Data.List.Sort +public import Mathlib.Data.Nat.Log + +open Std.Do Cslib.Query TimeT + +set_option mvcgen.warning false + +namespace Cslib.Query + +/-- The monadic `merge` at `m := Id` agrees with `List.merge`. -/ +public theorem id_run_merge (le : α → α → Bool) (xs ys : List α) : + Id.run (merge (fun p => pure (le p.1 p.2)) xs ys) = List.merge xs ys le := by + fun_induction merge (m := Id) (fun p => pure (le p.1 p.2)) xs ys with + | case1 => simp [Id.run_pure] + | case2 xs => simp [Id.run_pure, List.merge_right] + | case3 x xs y ys ih_t ih_f => + simp only [Id.run_bind, Id.run_pure] at ih_t ih_f ⊢ + rw [List.cons_merge_cons] + split <;> simp_all + +-- Unlike `id_run_merge` above, we don't prove a conformance lemma +-- `id_run_mergeSort : Id.run (mergeSort ...) = List.mergeSort ...` +-- because Lean's `module` system does not expose equational lemmas +-- (e.g. `List.mergeSort.eq_3`) to downstream modules. +-- Instead, we validate our definition via specifications: +-- `merge_perm`, `mergeSort_perm`, and `mergeSort_runsIn`. + +/-- `merge` produces a permutation of `xs ++ ys`, for any non-failing monadic comparator. 
-/ +public theorem merge_perm {ps : PostShape} [Monad m] [WPMonad m ps] + (cmp : α × α → m Bool) (hcmp : ∀ p, ⦃⌜True⌝⦄ cmp p ⦃⇓_ => ⌜True⌝⦄) + (xs ys : List α) : + ⦃⌜True⌝⦄ merge cmp xs ys ⦃⇓result => ⌜List.Perm result (xs ++ ys)⌝⦄ := by + fun_induction merge (m := m) cmp xs ys with + | case1 => mvcgen + | case2 xs => + mvcgen + · mpure_intro; simp [List.append_nil] + | case3 x xs y ys ih_t ih_f => + mvcgen [ih_t, ih_f, hcmp] + · mpure_intro; exact List.Perm.cons _ ‹_› + · mpure_intro + exact (List.Perm.cons _ ‹_›).trans + ((List.Perm.swap x y _).trans (List.Perm.cons x (List.perm_middle.symm))) + +/-- `mergeSort` produces a permutation of its input, for any non-failing monadic comparator. -/ +public theorem mergeSort_perm {ps : PostShape} [Monad m] [WPMonad m ps] + (cmp : α × α → m Bool) (hcmp : ∀ p, ⦃⌜True⌝⦄ cmp p ⦃⇓_ => ⌜True⌝⦄) + (xs : List α) : + ⦃⌜True⌝⦄ mergeSort cmp xs ⦃⇓result => ⌜List.Perm result xs⌝⦄ := by + fun_induction mergeSort (m := m) cmp xs with + | case1 => mvcgen + | case2 => mvcgen + | case3 a b xs lr _ _ ih_left ih_right => + have hmerge := merge_perm cmp hcmp + mvcgen [ih_left, ih_right, hmerge] + · apply SPred.pure_mono + intro h_merge + rename_i _ _ _ h_left _ h_right _ + have hsplit := List.MergeSort.Internal.splitInTwo_fst_append_splitInTwo_snd + ⟨a :: b :: xs, rfl⟩ + exact h_merge.trans ((h_left.append h_right).trans (.of_eq hsplit)) + +/-- `merge` uses at most `xs.length + ys.length` queries. 
-/ +public theorem merge_costs (query : (α × α) → TimeM Bool) (hquery : ∀ a, Costs (query a) 1) + (xs ys : List α) : Costs (merge query xs ys) (xs.length + ys.length) := by + fun_induction merge (m := TimeM) query xs ys with + | case1 => exact Costs.pure_le _ _ + | case2 xs => exact Costs.pure_le _ _ + | case3 x xs y ys ih_t ih_f => + apply Costs.le + · exact Costs.bind (hquery (x, y)) fun le => + Costs.ite_max le (Costs.map ih_t) (Costs.map ih_f) + · simp only [List.length_cons]; omega + +/-- The key arithmetic inequality for the merge sort recurrence: +`⌈n/2⌉ * clog(⌈n/2⌉) + ⌊n/2⌋ * clog(⌊n/2⌋) + n ≤ n * clog(n)`. -/ +private theorem mergeSort_bound (n : ℕ) (hn : 2 ≤ n) : + ((n + 1) / 2) * Nat.clog 2 ((n + 1) / 2) + + (n / 2 * Nat.clog 2 (n / 2) + ((n + 1) / 2 + n / 2)) ≤ + n * Nat.clog 2 n := by + -- clog n = clog ⌈n/2⌉ + 1 + have hclog := Nat.clog_of_one_lt (by omega : (1 : Nat) < 2) hn + have hceil : Nat.clog 2 ((n + 1) / 2) + 1 ≤ Nat.clog 2 n := le_of_eq hclog.symm + have hfloor : Nat.clog 2 (n / 2) + 1 ≤ Nat.clog 2 n := + (Nat.add_le_add_right (Nat.clog_mono_right 2 (by omega)) 1).trans hceil + have hsum : (n + 1) / 2 + n / 2 = n := by omega + have h1 := Nat.mul_le_mul_left ((n + 1) / 2) hceil + have h2 := Nat.mul_le_mul_left (n / 2) hfloor + nlinarith [ + Nat.mul_succ ((n + 1) / 2) (Nat.clog 2 ((n + 1) / 2)), + Nat.mul_succ (n / 2) (Nat.clog 2 (n / 2))] + +/-- `mergeSort` uses at most `xs.length * Nat.clog 2 xs.length` queries. 
-/ +public theorem mergeSort_runsIn : + UpperBound (mergeSort (α := α)) (fun xs => xs.length * Nat.clog 2 xs.length) := by + change ∀ (query : (α × α) → TimeM Bool), (∀ a, Costs (query a) 1) → + ∀ xs, Costs (mergeSort query xs) (xs.length * Nat.clog 2 xs.length) + intro query hquery xs + fun_induction mergeSort (m := TimeM) query xs with + | case1 => exact Costs.pure _ + | case2 => exact Costs.pure _ + | case3 a b xs lr _ _ ih_left ih_right => + have hperm := mergeSort_perm query (fun p => SPred.pure_intro trivial) + apply Costs.le + · exact Costs.bind_spec ih_left (hperm _) fun left h_perm_left => + Costs.bind_spec ih_right (hperm _) fun right h_perm_right => by + have := merge_costs query hquery left right + rwa [h_perm_left.length_eq, h_perm_right.length_eq] at this + · simp only [lr.1.2, lr.2.2] + exact mergeSort_bound _ (by simp only [List.length_cons]; omega) + +-- Sorted results + +section Sorted + +variable (r : α → α → Prop) [DecidableRel r] [Std.Total r] [IsTrans α r] + +/-- `merge` preserves sortedness and produces a permutation, for any monadic comparator + with a pure return reflecting `r`. This combined version is needed because the sortedness + proof requires knowing the result is a permutation of the input. 
-/ +private theorem merge_spec {ps : PostShape} [Monad m] [WPMonad m ps] + (cmp : α × α → m Bool) (hcmp : PureReturn cmp (fun p => decide (r p.1 p.2))) + (xs ys : List α) (hxs : List.Pairwise r xs) (hys : List.Pairwise r ys) : + ⦃⌜True⌝⦄ merge cmp xs ys + ⦃⇓result => ⌜List.Pairwise r result ∧ List.Perm result (xs ++ ys)⌝⦄ := by + fun_induction merge (m := m) cmp xs ys with + | case1 => + mvcgen + | case2 xs => + mvcgen + · mpure_intro; exact ⟨hxs, by simp⟩ + | case3 x xs y ys ih_t ih_f => + have ih_t' := ih_t hxs.of_cons hys + have ih_f' := ih_f hxs hys.of_cons + have hcmp' : ∀ p, ⦃⌜True⌝⦄ cmp p ⦃⇓b => ⌜b = decide (r p.1 p.2)⌝⦄ := hcmp + mvcgen [ih_t', ih_f', hcmp'] + · mpure_intro + rename_i _ _ hrest + obtain ⟨hrest_pw, hrest_perm⟩ := hrest + have hlt : r x y := by simp_all [decide_eq_true_eq] + exact ⟨List.pairwise_cons.mpr ⟨fun z hz => + match List.mem_append.mp (hrest_perm.mem_iff.mp hz) with + | Or.inl hmem_xs => List.rel_of_pairwise_cons hxs hmem_xs + | Or.inr hmem_yys => + match List.mem_cons.mp hmem_yys with + | Or.inl h => h ▸ hlt + | Or.inr hmem_ys => _root_.trans hlt (List.rel_of_pairwise_cons hys hmem_ys), + hrest_pw⟩, List.Perm.cons _ hrest_perm⟩ + · mpure_intro + rename_i _ _ hrest + obtain ⟨hrest_pw, hrest_perm⟩ := hrest + have hlt : ¬ r x y := by simp_all [decide_eq_true_eq] + have hyx : r y x := (Std.Total.total y x).resolve_right hlt + exact ⟨List.pairwise_cons.mpr ⟨fun z hz => + match List.mem_cons.mp (hrest_perm.mem_iff.mp hz) with + | Or.inl h => h ▸ hyx + | Or.inr hmem => + match List.mem_append.mp hmem with + | Or.inl hmem_xs => _root_.trans hyx (List.rel_of_pairwise_cons hxs hmem_xs) + | Or.inr hmem_ys => List.rel_of_pairwise_cons hys hmem_ys, + hrest_pw⟩, + (List.Perm.cons _ hrest_perm).trans + ((List.Perm.swap x y _).trans (List.Perm.cons x (List.perm_middle.symm)))⟩ + +/-- `merge` preserves sortedness, for any monadic comparator with a pure return + reflecting `r`. 
-/ +public theorem merge_sorted {ps : PostShape} [Monad m] [WPMonad m ps] + (cmp : α × α → m Bool) (hcmp : PureReturn cmp (fun p => decide (r p.1 p.2))) + (xs ys : List α) : + ⦃⌜List.Pairwise r xs ∧ List.Pairwise r ys⌝⦄ merge cmp xs ys + ⦃⇓result => ⌜List.Pairwise r result⌝⦄ := by + simp only [Triple] + apply SPred.pure_elim' + intro ⟨hxs, hys⟩ + exact Triple.entails_wp_of_post (merge_spec r cmp hcmp xs ys hxs hys) (by + simp only [PostCond.entails_noThrow]; intro _; exact SPred.pure_mono And.left) + +/-- `mergeSort` produces a sorted list, for any monadic comparator with a pure return + reflecting `r`. -/ +public theorem mergeSort_sorted {ps : PostShape} [Monad m] [WPMonad m ps] + (cmp : α × α → m Bool) (hcmp : PureReturn cmp (fun p => decide (r p.1 p.2))) + (xs : List α) : + ⦃⌜True⌝⦄ mergeSort cmp xs + ⦃⇓result => ⌜List.Pairwise r result⌝⦄ := by + fun_induction mergeSort (m := m) cmp xs with + | case1 => + mvcgen + · mpure_intro; exact List.Pairwise.nil + | case2 => + mvcgen + · mpure_intro; exact List.pairwise_singleton r _ + | case3 a b xs lr _ _ ih_left ih_right => + apply Triple.bind _ _ ih_left + intro left + simp only [Triple]; apply SPred.pure_elim'; intro hleft + have hmerge : ∀ right, ⦃⌜List.Pairwise r right⌝⦄ merge cmp left right + ⦃⇓result => ⌜List.Pairwise r result⌝⦄ := by + intro right; simp only [Triple]; apply SPred.pure_elim'; intro hright + exact Triple.entails_wp_of_post (merge_spec r cmp hcmp left right hleft hright) + (by simp only [PostCond.entails_noThrow]; intro _; exact SPred.pure_mono And.left) + mvcgen [ih_right, hmerge] + +/-- At `m := TimeT n`, `merge` preserves sortedness (with a pure comparator). 
-/ +public theorem merge_sorted_timeT {ps : PostShape} [Monad n] [WPMonad n ps] + (xs ys : List α) : + ⦃⌜List.Pairwise r xs ∧ List.Pairwise r ys⌝⦄ + merge (m := TimeT n) (fun p => pure (decide (r p.1 p.2))) xs ys + ⦃⇓result => ⌜List.Pairwise r result⌝⦄ := + merge_sorted r _ (PureReturn.pure _) xs ys + +/-- At `m := TimeT n`, `mergeSort` produces a sorted list (with a pure comparator). -/ +public theorem mergeSort_sorted_timeT {ps : PostShape} [Monad n] [WPMonad n ps] + (xs : List α) : + ⦃⌜True⌝⦄ + mergeSort (m := TimeT n) (fun p => pure (decide (r p.1 p.2))) xs + ⦃⇓result => ⌜List.Pairwise r result⌝⦄ := + mergeSort_sorted r _ (PureReturn.pure _) xs + +/-- At `m := Id`, `merge` preserves sortedness. -/ +public theorem merge_sorted_id + (xs ys : List α) (hxs : List.Pairwise r xs) (hys : List.Pairwise r ys) : + List.Pairwise r (Id.run (merge (fun p => pure (decide (r p.1 p.2))) xs ys)) := by + have := merge_sorted r (m := Id) _ (PureReturn.pure _) xs ys + simp only [Triple] at this + exact this ⟨hxs, hys⟩ + +/-- At `m := Id`, `mergeSort` produces a sorted list. -/ +public theorem mergeSort_sorted_id (xs : List α) : + List.Pairwise r (Id.run (mergeSort (fun p => pure (decide (r p.1 p.2))) xs)) := by + have := mergeSort_sorted r (m := Id) _ (PureReturn.pure _) xs + simp only [Triple] at this + exact this trivial + +end Sorted + +end Cslib.Query diff --git a/Cslib/Algorithms/Lean/Query/Sort/MonadicSort.lean b/Cslib/Algorithms/Lean/Query/Sort/MonadicSort.lean new file mode 100644 index 000000000..fb3e04449 --- /dev/null +++ b/Cslib/Algorithms/Lean/Query/Sort/MonadicSort.lean @@ -0,0 +1,99 @@ +/- +Copyright (c) 2026 Lean FRO, LLC. All rights reserved. +Released under Apache 2.0 license as described in the file LICENSE. +Authors: Kim Morrison +-/ +module + +public import Cslib.Algorithms.Lean.Query.Sort.Insertion.Lemmas +public import Cslib.Algorithms.Lean.Query.Sort.Merge.Lemmas + +/-! 
# Monadic Comparison Sort Specification + +`IsMonadicSort sort` asserts that a monad-generic function `sort` is a correct comparison sort: +for any non-failing comparator it produces a permutation, and for any comparator reflecting a +total order it produces a sorted list. + +Because the function is polymorphic in the monad `m`, a computable `def` cannot observe any +instrumentation layered into `m` (see `UpperBound` for details). This allows us to combine +`IsMonadicSort` with `UpperBound` to state complexity claims about comparison sorts. + +## Polynomial-time sorting example + +`PolyNatSort k` packages an algorithm with proofs that it is a correct comparison sort using +at most `n ^ k` comparisons. We exhibit `insertionSort` as a `PolyNatSort 2`, +and assert that the framework is non-trivial in the sense that no adversary +can computably inhabit `PolyNatSort 1`. +-/ + +open Std.Do Cslib.Query + +public section + +namespace Cslib.Query + +-- ## IsMonadicSort: correctness specification for comparison sorts + +/-- A monad-generic function is a monadic comparison sort if it always produces a permutation + of its input (for any non-failing comparator), and always produces a sorted list (for any + comparator reflecting a total order `r`). + + The universal quantification over `r` in the `sorted` field is essential: it prevents the + algorithm from using any built-in ordering on the element type (e.g., `Nat.ble`), forcing + it to learn the ordering exclusively through comparator queries. -/ +structure IsMonadicSort + (sort : ∀ {m : Type → Type} [Monad m], (α × α → m Bool) → List α → m (List α)) : Prop where + /-- The sort produces a permutation of its input, for any non-failing monadic comparator. 
-/ + perm : ∀ {m : Type → Type} {ps : PostShape} [Monad m] [WPMonad m ps] + (cmp : α × α → m Bool) (_ : ∀ p, ⦃⌜True⌝⦄ cmp p ⦃⇓_ => ⌜True⌝⦄) + (xs : List α), + ⦃⌜True⌝⦄ sort cmp xs ⦃⇓result => ⌜result.Perm xs⌝⦄ + /-- The sort produces a sorted list, for any comparator with a pure return reflecting a + total order `r`. -/ + sorted : ∀ (r : α → α → Prop) [DecidableRel r] [Std.Total r] [IsTrans α r] + {m : Type → Type} {ps : PostShape} [Monad m] [WPMonad m ps] + (cmp : α × α → m Bool) (_ : PureReturn cmp (fun p => decide (r p.1 p.2))) + (xs : List α), + ⦃⌜True⌝⦄ sort cmp xs ⦃⇓result => ⌜result.Pairwise r⌝⦄ + +/-- `insertionSort` is a monadic comparison sort. -/ +public theorem insertionSort_isMonadicSort : IsMonadicSort (insertionSort (α := α)) where + perm := insertionSort_perm + sorted r := insertionSort_sorted r + +-- ## Example: polynomial-time comparison sorting predicates + +/-- A computable inhabitant of this type would demonstrate that lists of natural numbers + can be sorted using at most `n^k` comparison queries. + + The `IsMonadicSort` component ensures the algorithm is a genuine comparison sort: + it must produce a sorted permutation for *every* total order on `Nat` (not just `≤`), + so it cannot bypass the comparator by using `Nat`'s built-in ordering. + + The `UpperBound` component, combined with monad parametricity and computability, + ensures the algorithm makes at most `xs.length ^ k` comparator queries. -/ +structure PolyNatSort (k : Nat) where + /-- The sorting algorithm, generic over the monad. -/ + sort : ∀ {m : Type → Type} [Monad m], (Nat × Nat → m Bool) → List Nat → m (List Nat) + /-- The algorithm is a correct comparison sort. -/ + isSort : IsMonadicSort sort + /-- The algorithm uses at most `xs.length ^ k` comparisons. -/ + runsIn : UpperBound sort (fun xs => xs.length ^ k) + +/-- `insertionSort` is a correct quadratic comparison sort. 
-/ +public def insertionSort_quadraticNatSort : PolyNatSort 2 where + sort := insertionSort + isSort := insertionSort_isMonadicSort + runsIn := insertionSort_runsIn + +/-! +This is a non-trivial claim: `PolyNatSort 1` cannot be computably inhabited! +Any approach to query complexity upper bounds should allow us +1. to write true upper bound statements, and inhabit them, and +2. to write false upper bound statements, + without an adversary being able to (computably) inhabit them. +-/ + +end Cslib.Query + +end -- public section diff --git a/Cslib/Algorithms/Lean/Query/Sort/QueryTree.lean b/Cslib/Algorithms/Lean/Query/Sort/QueryTree.lean new file mode 100644 index 000000000..08b21174c --- /dev/null +++ b/Cslib/Algorithms/Lean/Query/Sort/QueryTree.lean @@ -0,0 +1,107 @@ +/- +Copyright (c) 2026 Lean FRO, LLC. All rights reserved. +Released under Apache 2.0 license as described in the file LICENSE. +Authors: Kim Morrison +-/ +module + +public import Cslib.Algorithms.Lean.Query.QueryTree + +/-! # Lower-Bound Lemma for Binary Query Trees + +`QueryTree.exists_queriesOn_ge_clog`: if `n` oracles produce `n` distinct evaluation results +from a binary query tree, then one of those oracles makes at least `⌈log₂ n⌉` queries. + +The proof uses the adversarial/partition argument: at each query node, the `n` oracles split by +their answer; the larger group (size ≥ ⌈n/2⌉) still produces distinct results in the +corresponding subtree, and the induction proceeds there. +-/ + +open Cslib.Query + +public section + +namespace Cslib.Query.QueryTree + +/-- In a partition of `Fin n` by a Boolean function into two groups, the larger group has + size at least `⌈n/2⌉ = (n + 1) / 2`. 
-/ +private theorem exists_large_fiber (p : Fin n → Bool) : + ∃ b : Bool, (n + 1) / 2 ≤ Fintype.card {i : Fin n // p i = b} := by + by_contra h + push_neg at h + have key : Fintype.card {i : Fin n // p i = true} + + Fintype.card {i : Fin n // p i = false} = n := by + have := Fintype.card_congr (α := {i : Fin n // p i = true} ⊕ {i : Fin n // p i = false}) + (β := Fin n) + { toFun := fun s => match s with | .inl ⟨i, _⟩ => i | .inr ⟨i, _⟩ => i + invFun := fun i => if h : p i = true then .inl ⟨i, h⟩ + else .inr ⟨i, Bool.eq_false_iff.mpr h⟩ + left_inv := fun s => by rcases s with ⟨i, hi⟩ | ⟨i, hi⟩ <;> simp_all + right_inv := fun i => by simp only; split_ifs <;> rfl } + rw [Fintype.card_sum, Fintype.card_fin] at this; exact this + have ht := h true + have hf := h false + omega + +/-- If `n` oracles produce `n` distinct evaluation results from a binary query tree, + then one of those oracles makes at least `⌈log₂ n⌉` queries. + + This is the core combinatorial lemma for query complexity lower bounds. + The proof uses the adversarial/partition argument: at each query node, the `n` oracles + split by their answer to the query; the larger group (size ≥ ⌈n/2⌉) still produces + distinct results in the corresponding subtree, and the induction proceeds there. 
-/ +theorem exists_queriesOn_ge_clog + (t : QueryTree Q Bool α) (oracles : Fin n → (Q → Bool)) + (hn : 0 < n) + (h_inj : Function.Injective (fun i => t.eval (oracles i))) : + ∃ i : Fin n, t.queriesOn (oracles i) ≥ Nat.clog 2 n := by + induction t generalizing n with + | pure a => + -- All oracles evaluate to the same `a`, so injectivity forces n ≤ 1 + have : n ≤ 1 := by + by_contra h + push_neg at h + exact absurd (h_inj (show a = a from rfl)) + (show (⟨0, by omega⟩ : Fin n) ≠ ⟨1, by omega⟩ by simp [Fin.ext_iff]) + exact ⟨⟨0, hn⟩, by simp [queriesOn, Nat.clog_of_right_le_one this]⟩ + | query q cont ih => + -- Partition oracles by their answer to query q + obtain ⟨b, hm⟩ := exists_large_fiber (fun i => oracles i q) + set m := Fintype.card {i : Fin n // oracles i q = b} + -- Re-index the larger fiber as Fin m + let e := Fintype.equivFin {i : Fin n // oracles i q = b} + let oracles' : Fin m → (Q → Bool) := fun j => oracles (e.symm j).val + -- Injectivity transfers to the subtree + have h_inj' : Function.Injective (fun j => (cont b).eval (oracles' j)) := by + intro j₁ j₂ h + have hj₁ := (e.symm j₁).property + have hj₂ := (e.symm j₂).property + -- eval through query q cont with oracle answering b goes to cont b + have he : ∀ j, (QueryTree.query q cont).eval (oracles (e.symm j).val) = + (cont b).eval (oracles (e.symm j).val) := by + intro j; simp [eval, (e.symm j).property] + have := h_inj (show (QueryTree.query q cont).eval (oracles (e.symm j₁).val) = + (QueryTree.query q cont).eval (oracles (e.symm j₂).val) by rw [he, he]; exact h) + exact e.symm.injective (Subtype.val_injective this ▸ rfl) + -- Apply IH to the subtree + have hm_pos : 0 < m := by omega + obtain ⟨j, hj⟩ := ih b oracles' hm_pos h_inj' + -- Lift back to Fin n and add 1 for the root query + refine ⟨(e.symm j).val, ?_⟩ + have hqb : oracles (e.symm j).val q = b := (e.symm j).property + simp only [queriesOn_query, hqb] + -- 1 + queriesOn on subtree ≥ 1 + clog 2 m ≥ clog 2 n + calc Nat.clog 2 n + ≤ 1 + 
Nat.clog 2 m := by + by_cases h1 : n ≤ 1 + · simp [Nat.clog_of_right_le_one h1] + · push_neg at h1 + rw [Nat.clog_of_two_le (by omega) (by omega)] + have := Nat.clog_mono_right 2 (show (n + 2 - 1) / 2 ≤ m by omega) + omega + _ ≤ 1 + (cont b).queriesOn (oracles' j) := by omega + _ = 1 + (cont b).queriesOn (oracles (e.symm j).val) := rfl + +end Cslib.Query.QueryTree + +end -- public section diff --git a/Cslib/Algorithms/Lean/Query/UpperBound.lean b/Cslib/Algorithms/Lean/Query/UpperBound.lean new file mode 100644 index 000000000..9810e97af --- /dev/null +++ b/Cslib/Algorithms/Lean/Query/UpperBound.lean @@ -0,0 +1,74 @@ +/- +Copyright (c) 2025 Lean FRO, LLC. All rights reserved. +Released under Apache 2.0 license as described in the file LICENSE. +Authors: Kim Morrison +-/ +module + +public import Cslib.Algorithms.Lean.Query.Basic + +/-! # UpperBound: Query Complexity Bounds via Monad Parametricity + +`UpperBound f bound` (and its generalization `UpperBoundT`) assert that a monad-generic algorithm `f` +makes at most `bound x` queries on input `x`. + +## The parametricity argument + +An algorithm like +``` +def insertionSort [Monad m] (cmp : α × α → m Bool) : List α → m (List α) := ... +``` +is written generically over the monad `m`. To measure its query complexity, we specialize +`m` to `TimeM` (or `TimeT n` for algorithms with additional effects) and provide a +`cmp` implementation that calls `tick` once per invocation. + +Because `insertionSort` is parametric in `m`, it **cannot observe the tick instrumentation**. +It must call `cmp` the same number of times regardless of which monad it runs in. +Therefore any upper bound proved via `TimeM` is a true bound on query count in all monads. + +`UpperBoundT` handles algorithms that use a base monad `n` for their own effects +(e.g., `StateM` for accumulation). The function must be generic over monads extending `n` +via `MonadLiftT`, and we specialize to `TimeT n` which layers tick-counting on top of `n`. 
+The same parametricity argument applies: the algorithm cannot distinguish `TimeT n` from +any other monad that lifts `n`. + +## Computability caveat + +The parametricity argument is only valid for **computable** algorithms. A `noncomputable` +definition could use `Classical.choice` to inspect `m` or the query function and subvert +the instrumentation. Since Lean's type theory does not enforce parametricity, the soundness +guarantee is informal: `UpperBound` and `UpperBoundT` theorems should only be proved about computable +algorithms. This framework is designed for proving upper bounds on query complexity, not lower +bounds. +-/ + +open Std.Do Cslib.Query + +public section + +namespace Cslib.Query + +/-- `UpperBoundT n f bound` asserts that when the monad-generic function `f` + is specialized to `TimeT n`, with any query that calls `tick` at most once per invocation, + the total number of ticks is bounded by `bound x`. + + The function `f` is generic over monads that extend `n` via `MonadLiftT`, + ensuring it cannot observe the tick instrumentation. -/ +@[expose] def UpperBoundT {n : Type → Type} {ps : PostShape} [Monad n] [WP n ps] + (f : ∀ {m : Type → Type} [Monad m] [MonadLiftT n m], (α → m β) → γ → m δ) + (bound : γ → Nat) : Prop := + ∀ (query : α → TimeT n β), (∀ a, TimeT.Costs (query a) 1) → + ∀ x, TimeT.Costs (f query x) (bound x) + +/-- `UpperBound f bound` asserts that when the monad-generic function `f` is specialized to `TimeM`, + with any query that calls `tick` at most once per invocation, + the total number of ticks is bounded by `bound x`. 
-/ +@[expose] def UpperBound + (f : ∀ {m : Type → Type} [Monad m], (α → m β) → γ → m δ) + (bound : γ → Nat) : Prop := + ∀ (query : α → TimeM β), (∀ a, TimeT.Costs (query a) 1) → + ∀ x, TimeT.Costs (f query x) (bound x) + +end Cslib.Query + +end -- public section diff --git a/Cslib/Algorithms/Lean/TimeM.lean b/Cslib/Algorithms/Lean/TimeM.lean deleted file mode 100644 index c4a60f063..000000000 --- a/Cslib/Algorithms/Lean/TimeM.lean +++ /dev/null @@ -1,142 +0,0 @@ -/- -Copyright (c) 2025 Sorrachai Yingchareonthawornhcai. All rights reserved. -Released under Apache 2.0 license as described in the file LICENSE. -Authors: Sorrachai Yingchareonthawornhcai, Eric Wieser --/ - -module - -public import Cslib.Init -public import Mathlib.Algebra.Group.Defs - - -@[expose] public section - -/-! - -# TimeM: Time Complexity Monad -`TimeM T α` represents a computation that produces a value of type `α` and tracks its time cost. - -`T` is usually instantiated as `ℕ` to count operations, but can be instantiated as `ℝ` to count -actual wall time, or as more complex types in order to model more general costs. - -## Design Principles -1. **Pure inputs, timed outputs**: Functions take plain values and return `TimeM` results -2. **Time annotations are trusted**: The `time` field is NOT verified against actual cost. - You must manually ensure annotations match the algorithm's complexity in your cost model. -3. **Separation of concerns**: Prove correctness properties on `.ret`, prove complexity on `.time` - -## Cost Model -**Document your cost model explicitly** Decide and be consistent about: -- **What costs 1 unit?** (comparison, arithmetic operation, etc.) -- **What is free?** (variable lookup, pattern matching, etc.) -- **Recursive calls:** Do you charge for the call itself? - -## Notation -- **`✓`** : A tick of time, see `tick`. -- **`⟪tm⟫`** : Extract the pure value from a `TimeM` computation (notation for `tm.ret`) - -## References - -See [Danielsson2008] for the discussion. 
--/ -namespace Cslib.Algorithms.Lean - -/-- A monad for tracking time complexity of computations. -`TimeM T α` represents a computation that returns a value of type `α` -and accumulates a time cost (represented as a type `T`, typically `ℕ`). -/ -@[ext] -structure TimeM (T : Type*) (α : Type*) where - /-- The return value of the computation -/ - ret : α - /-- The accumulated time cost of the computation -/ - time : T - -namespace TimeM - -/-- Lifts a pure value into a `TimeM` computation with zero time cost. - -Prefer to use `pure` instead of `TimeM.pure`. -/ -protected def pure [Zero T] {α} (a : α) : TimeM T α := - ⟨a, 0⟩ - -instance [Zero T] : Pure (TimeM T) where - pure := TimeM.pure - -/-- Sequentially composes two `TimeM` computations, summing their time costs. - -Prefer to use the `>>=` notation. -/ -protected def bind {α β} [Add T] (m : TimeM T α) (f : α → TimeM T β) : TimeM T β := - let r := f m.ret - ⟨r.ret, m.time + r.time⟩ - -instance [Add T] : Bind (TimeM T) where - bind := TimeM.bind - -instance : Functor (TimeM T) where - map f x := ⟨f x.ret, x.time⟩ - -instance [Add T] : Seq (TimeM T) where - seq f x := ⟨f.ret (x ()).ret, f.time + (x ()).time⟩ - -instance [Add T] : SeqLeft (TimeM T) where - seqLeft x y := ⟨x.ret, x.time + (y ()).time⟩ - -instance [Add T] : SeqRight (TimeM T) where - seqRight x y := ⟨(y ()).ret, x.time + (y ()).time⟩ - -instance [AddZero T] : Monad (TimeM T) where - pure := Pure.pure - bind := Bind.bind - map := Functor.map - seq := Seq.seq - seqLeft := SeqLeft.seqLeft - seqRight := SeqRight.seqRight - -@[simp, grind =] theorem ret_pure {α} [Zero T] (a : α) : (pure a : TimeM T α).ret = a := rfl -@[simp, grind =] theorem ret_bind {α β} [Add T] (m : TimeM T α) (f : α → TimeM T β) : - (m >>= f).ret = (f m.ret).ret := rfl -@[simp, grind =] theorem ret_map {α β} (f : α → β) (x : TimeM T α) : (f <$> x).ret = f x.ret := rfl -@[simp] theorem ret_seqRight {α} (x : TimeM T α) (y : Unit → TimeM T β) [Add T] : - (SeqRight.seqRight x y).ret = (y 
()).ret := rfl -@[simp] theorem ret_seqLeft {α} [Add T] (x : TimeM T α) (y : Unit → TimeM T β) : - (SeqLeft.seqLeft x y).ret = x.ret := rfl -@[simp] theorem ret_seq {α β} [Add T] (f : TimeM T (α → β)) (x : Unit → TimeM T α) : - (Seq.seq f x).ret = f.ret (x ()).ret := rfl - -@[simp, grind =] theorem time_bind {α β} [Add T] (m : TimeM T α) (f : α → TimeM T β) : - (m >>= f).time = m.time + (f m.ret).time := rfl -@[simp, grind =] theorem time_pure {α} [Zero T] (a : α) : (pure a : TimeM T α).time = 0 := rfl -@[simp, grind =] theorem time_map {α β} (f : α → β) (x : TimeM T α) : (f <$> x).time = x.time := rfl -@[simp] theorem time_seqRight {α} [Add T] (x : TimeM T α) (y : Unit → TimeM T β) : - (SeqRight.seqRight x y).time = x.time + (y ()).time := rfl -@[simp] theorem time_seqLeft {α} [Add T] (x : TimeM T α) (y : Unit → TimeM T β) : - (SeqLeft.seqLeft x y).time = x.time + (y ()).time := rfl -@[simp] theorem time_seq {α β} [Add T] (f : TimeM T (α → β)) (x : Unit → TimeM T α) : - (Seq.seq f x).time = f.time + (x ()).time := rfl - -/-- `TimeM` is lawful so long as addition in the cost is associative and absorbs zero. -/ -instance [AddMonoid T] : LawfulMonad (TimeM T) := .mk' - (id_map := fun x => rfl) - (pure_bind := fun _ _ => by ext <;> simp) - (bind_assoc := fun _ _ _ => by ext <;> simp [add_assoc]) - (seqLeft_eq := fun _ _ => by ext <;> simp) - (bind_pure_comp := fun _ _ => by ext <;> simp) - -/-- Creates a `TimeM` computation with a time cost. -/ -def tick (c : T) : TimeM T PUnit := ⟨.unit, c⟩ - -@[simp, grind =] theorem ret_tick (c : T) : (tick c).ret = () := rfl -@[simp, grind =] theorem time_tick (c : T) : (tick c).time = c := rfl - -/-- `✓[c] x` adds `c` ticks, then executes `x`. -/ -macro "✓[" c:term "]" body:doElem : doElem => `(doElem| do TimeM.tick $c; $body:doElem) - -/-- `✓ x` is a shorthand for `✓[1] x`, which adds one tick and executes `x`. 
-/ -macro "✓" body:doElem : doElem => `(doElem| ✓[1] $body) - -/-- Notation for extracting the return value from a `TimeM` computation: `⟪tm⟫` -/ -scoped notation:max "⟪" tm "⟫" => (TimeM.ret tm) - -end TimeM -end Cslib.Algorithms.Lean diff --git a/lakefile.toml b/lakefile.toml index 81ea3e7d8..fa7d59e87 100644 --- a/lakefile.toml +++ b/lakefile.toml @@ -20,6 +20,7 @@ scope = "leanprover-community" [[lean_lib]] name = "Cslib" globs = ["Cslib.*"] +leanOptions.experimental.module = true [[lean_lib]] name = "CslibTests" From 0f3b27c3192112503022df895986aa52f714766b Mon Sep 17 00:00:00 2001 From: Kim Morrison Date: Tue, 3 Mar 2026 05:59:12 +0000 Subject: [PATCH 2/4] feat(Query): prove sorting lower bound and decision tree depth lemma MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Replace the `Fin n` + `Fintype.equivFin` re-indexing proof of `exists_queriesOn_ge_clog` with a Finset-based auxiliary (`exists_mem_queriesOn_ge_clog`) that works over an arbitrary `Finset ι`. This eliminates the `exists_large_fiber` helper and all `.val`/`.property`/`Subtype.val_injective` bookkeeping, since oracles never change between recursive calls. Co-Authored-By: Claude Opus 4.6 --- .../Algorithms/Lean/Query/Sort/QueryTree.lean | 120 ++++++++---------- 1 file changed, 55 insertions(+), 65 deletions(-) diff --git a/Cslib/Algorithms/Lean/Query/Sort/QueryTree.lean b/Cslib/Algorithms/Lean/Query/Sort/QueryTree.lean index 08b21174c..05e69afc4 100644 --- a/Cslib/Algorithms/Lean/Query/Sort/QueryTree.lean +++ b/Cslib/Algorithms/Lean/Query/Sort/QueryTree.lean @@ -6,6 +6,7 @@ Authors: Kim Morrison module public import Cslib.Algorithms.Lean.Query.QueryTree +public import Mathlib.Data.Set.Function /-! 
# Lower-Bound Lemma for Binary Query Trees @@ -15,6 +16,9 @@ from a binary query tree, then one of those oracles makes at least `⌈log₂ n The proof uses the adversarial/partition argument: at each query node, the `n` oracles split by their answer; the larger group (size ≥ ⌈n/2⌉) still produces distinct results in the corresponding subtree, and the induction proceeds there. + +The proof works over an arbitrary `Finset ι` of oracle indices (avoiding re-indexing via +`Fintype.equivFin`), then derives the `Fin n` version as a corollary. -/ open Cslib.Query @@ -23,25 +27,53 @@ public section namespace Cslib.Query.QueryTree -/-- In a partition of `Fin n` by a Boolean function into two groups, the larger group has - size at least `⌈n/2⌉ = (n + 1) / 2`. -/ -private theorem exists_large_fiber (p : Fin n → Bool) : - ∃ b : Bool, (n + 1) / 2 ≤ Fintype.card {i : Fin n // p i = b} := by - by_contra h - push_neg at h - have key : Fintype.card {i : Fin n // p i = true} + - Fintype.card {i : Fin n // p i = false} = n := by - have := Fintype.card_congr (α := {i : Fin n // p i = true} ⊕ {i : Fin n // p i = false}) - (β := Fin n) - { toFun := fun s => match s with | .inl ⟨i, _⟩ => i | .inr ⟨i, _⟩ => i - invFun := fun i => if h : p i = true then .inl ⟨i, h⟩ - else .inr ⟨i, Bool.eq_false_iff.mpr h⟩ - left_inv := fun s => by rcases s with ⟨i, hi⟩ | ⟨i, hi⟩ <;> simp_all - right_inv := fun i => by simp only; split_ifs <;> rfl } - rw [Fintype.card_sum, Fintype.card_fin] at this; exact this - have ht := h true - have hf := h false - omega +/-- Finset-based version: if the oracles indexed by `S` produce `|S|`-many distinct evaluation + results, then some oracle in `S` makes at least `⌈log₂ |S|⌉` queries. 
-/ +private theorem exists_mem_queriesOn_ge_clog + {ι : Type} (t : QueryTree Q Bool α) (S : Finset ι) (hS : S.Nonempty) + (oracles : ι → (Q → Bool)) + (h_inj : Set.InjOn (fun i => t.eval (oracles i)) ↑S) : + ∃ i ∈ S, t.queriesOn (oracles i) ≥ Nat.clog 2 S.card := by + induction t generalizing ι S with + | pure a => + obtain ⟨i, hi⟩ := hS + exact ⟨i, hi, by simp [queriesOn, Nat.clog_of_right_le_one + (Finset.card_le_one.mpr fun _ ha _ hb => h_inj ha hb rfl)]⟩ + | query q cont ih => + by_cases hle : S.card ≤ 1 + · obtain ⟨i, hi⟩ := hS; exact ⟨i, hi, by simp [Nat.clog_of_right_le_one hle]⟩ + · push_neg at hle + -- Find b : Bool such that S.filter (oracles · q = b) has ≥ ⌈|S|/2⌉ elements + have ⟨b, hb⟩ : ∃ b : Bool, + (S.card + 1) / 2 ≤ (S.filter (fun i => oracles i q = b)).card := by + by_contra h; push_neg at h + have ht := h true; have hf := h false + have hpart : (S.filter (fun i => oracles i q = true)).card + + (S.filter (fun i => oracles i q = false)).card = S.card := by + have := Finset.card_filter_add_card_filter_not (s := S) (fun i => oracles i q = true) + rwa [show S.filter (fun i => ¬(oracles i q = true)) = + S.filter (fun i => oracles i q = false) from + Finset.filter_congr fun i _ => by cases oracles i q <;> simp] at this + omega + set S' := S.filter (fun i => oracles i q = b) + have hS' : S'.Nonempty := Finset.card_pos.mp (by omega) + -- Restricted injectivity: eval through query q cont agrees with cont b on S' + have h_inj' : Set.InjOn (fun i => (cont b).eval (oracles i)) ↑S' := by + intro i hi j hj heq + have him := Finset.mem_coe.mp hi |> Finset.mem_filter.mp + have hjm := Finset.mem_coe.mp hj |> Finset.mem_filter.mp + exact h_inj (Finset.mem_coe.mpr him.1) (Finset.mem_coe.mpr hjm.1) + (by simp [eval, him.2, hjm.2, heq]) + obtain ⟨i, hi, hiq⟩ := ih b S' hS' oracles h_inj' + have him := Finset.mem_filter.mp hi + refine ⟨i, him.1, ?_⟩ + simp only [queriesOn_query, him.2] + calc Nat.clog 2 S.card + ≤ 1 + Nat.clog 2 S'.card := by + rw [Nat.clog_of_two_le 
(by omega) (by omega)] + have := Nat.clog_mono_right 2 (show (S.card + 2 - 1) / 2 ≤ S'.card by omega) + omega + _ ≤ 1 + (cont b).queriesOn (oracles i) := by omega /-- If `n` oracles produce `n` distinct evaluation results from a binary query tree, then one of those oracles makes at least `⌈log₂ n⌉` queries. @@ -55,52 +87,10 @@ theorem exists_queriesOn_ge_clog (hn : 0 < n) (h_inj : Function.Injective (fun i => t.eval (oracles i))) : ∃ i : Fin n, t.queriesOn (oracles i) ≥ Nat.clog 2 n := by - induction t generalizing n with - | pure a => - -- All oracles evaluate to the same `a`, so injectivity forces n ≤ 1 - have : n ≤ 1 := by - by_contra h - push_neg at h - exact absurd (h_inj (show a = a from rfl)) - (show (⟨0, by omega⟩ : Fin n) ≠ ⟨1, by omega⟩ by simp [Fin.ext_iff]) - exact ⟨⟨0, hn⟩, by simp [queriesOn, Nat.clog_of_right_le_one this]⟩ - | query q cont ih => - -- Partition oracles by their answer to query q - obtain ⟨b, hm⟩ := exists_large_fiber (fun i => oracles i q) - set m := Fintype.card {i : Fin n // oracles i q = b} - -- Re-index the larger fiber as Fin m - let e := Fintype.equivFin {i : Fin n // oracles i q = b} - let oracles' : Fin m → (Q → Bool) := fun j => oracles (e.symm j).val - -- Injectivity transfers to the subtree - have h_inj' : Function.Injective (fun j => (cont b).eval (oracles' j)) := by - intro j₁ j₂ h - have hj₁ := (e.symm j₁).property - have hj₂ := (e.symm j₂).property - -- eval through query q cont with oracle answering b goes to cont b - have he : ∀ j, (QueryTree.query q cont).eval (oracles (e.symm j).val) = - (cont b).eval (oracles (e.symm j).val) := by - intro j; simp [eval, (e.symm j).property] - have := h_inj (show (QueryTree.query q cont).eval (oracles (e.symm j₁).val) = - (QueryTree.query q cont).eval (oracles (e.symm j₂).val) by rw [he, he]; exact h) - exact e.symm.injective (Subtype.val_injective this ▸ rfl) - -- Apply IH to the subtree - have hm_pos : 0 < m := by omega - obtain ⟨j, hj⟩ := ih b oracles' hm_pos h_inj' - -- Lift 
back to Fin n and add 1 for the root query - refine ⟨(e.symm j).val, ?_⟩ - have hqb : oracles (e.symm j).val q = b := (e.symm j).property - simp only [queriesOn_query, hqb] - -- 1 + queriesOn on subtree ≥ 1 + clog 2 m ≥ clog 2 n - calc Nat.clog 2 n - ≤ 1 + Nat.clog 2 m := by - by_cases h1 : n ≤ 1 - · simp [Nat.clog_of_right_le_one h1] - · push_neg at h1 - rw [Nat.clog_of_two_le (by omega) (by omega)] - have := Nat.clog_mono_right 2 (show (n + 2 - 1) / 2 ≤ m by omega) - omega - _ ≤ 1 + (cont b).queriesOn (oracles' j) := by omega - _ = 1 + (cont b).queriesOn (oracles (e.symm j).val) := rfl + have ⟨i, _, hi⟩ := exists_mem_queriesOn_ge_clog t Finset.univ + (Finset.univ_nonempty_iff.mpr ⟨⟨0, hn⟩⟩) oracles (h_inj.injOn) + rw [Finset.card_univ, Fintype.card_fin] at hi + exact ⟨i, hi⟩ end Cslib.Query.QueryTree From 871e515bd9f5f27dbbcf61e483f6d373a5fc3a6f Mon Sep 17 00:00:00 2001 From: Kim Morrison Date: Tue, 3 Mar 2026 06:04:26 +0000 Subject: [PATCH 3/4] fix(Query): suppress unusedArguments lint for OracleQueryTree The `_oracle` parameter is structurally unused in the abbrev body but is essential for the WP/WPMonad typeclass instances that depend on it. Co-Authored-By: Claude Opus 4.6 --- Cslib/Algorithms/Lean/Query/QueryTree.lean | 1 + 1 file changed, 1 insertion(+) diff --git a/Cslib/Algorithms/Lean/Query/QueryTree.lean b/Cslib/Algorithms/Lean/Query/QueryTree.lean index 431a22d4e..2c3cb39fc 100644 --- a/Cslib/Algorithms/Lean/Query/QueryTree.lean +++ b/Cslib/Algorithms/Lean/Query/QueryTree.lean @@ -168,6 +168,7 @@ end QueryTree /-- `OracleQueryTree Q R oracle` is `QueryTree Q R` with a fixed oracle baked into the type, enabling a `WPMonad` instance where `wp t = pure (t.eval oracle)`. 
-/ +@[nolint unusedArguments] abbrev OracleQueryTree (Q R : Type) (_oracle : Q → R) := QueryTree Q R namespace OracleQueryTree From 97b082f35e62f123968ccdffe0fa8cc7d1c26e2e Mon Sep 17 00:00:00 2001 From: Kim Morrison Date: Tue, 3 Mar 2026 06:15:07 +0000 Subject: [PATCH 4/4] fix(Query): correct author name spelling Co-Authored-By: Claude Opus 4.6 --- Cslib/Algorithms/Lean/Query/Basic.lean | 4 ++-- Cslib/Algorithms/Lean/Query/Sort/Merge/Defs.lean | 4 ++-- Cslib/Algorithms/Lean/Query/Sort/Merge/Lemmas.lean | 4 ++-- 3 files changed, 6 insertions(+), 6 deletions(-) diff --git a/Cslib/Algorithms/Lean/Query/Basic.lean b/Cslib/Algorithms/Lean/Query/Basic.lean index ef36586c8..6c22d8f63 100644 --- a/Cslib/Algorithms/Lean/Query/Basic.lean +++ b/Cslib/Algorithms/Lean/Query/Basic.lean @@ -1,7 +1,7 @@ /- -Copyright (c) 2025 Sorrachai Yingchareonthawornhcai. All rights reserved. +Copyright (c) 2025 Sorrachai Yingchareonthawornchai. All rights reserved. Released under Apache 2.0 license as described in the file LICENSE. -Authors: Sorrachai Yingchareonthawornhcai, Eric Wieser, Kim Morrison +Authors: Sorrachai Yingchareonthawornchai, Eric Wieser, Kim Morrison -/ module diff --git a/Cslib/Algorithms/Lean/Query/Sort/Merge/Defs.lean b/Cslib/Algorithms/Lean/Query/Sort/Merge/Defs.lean index b0633d4e0..25c6c158c 100644 --- a/Cslib/Algorithms/Lean/Query/Sort/Merge/Defs.lean +++ b/Cslib/Algorithms/Lean/Query/Sort/Merge/Defs.lean @@ -1,7 +1,7 @@ /- -Copyright (c) 2025 Sorrachai Yingchareonthawornhcai. All rights reserved. +Copyright (c) 2025 Sorrachai Yingchareonthawornchai. All rights reserved. Released under Apache 2.0 license as described in the file LICENSE. 
-Authors: Sorrachai Yingchareonthawornhcai, Kim Morrison +Authors: Sorrachai Yingchareonthawornchai, Kim Morrison -/ module diff --git a/Cslib/Algorithms/Lean/Query/Sort/Merge/Lemmas.lean b/Cslib/Algorithms/Lean/Query/Sort/Merge/Lemmas.lean index 5aa72bdd6..4ad0418ef 100644 --- a/Cslib/Algorithms/Lean/Query/Sort/Merge/Lemmas.lean +++ b/Cslib/Algorithms/Lean/Query/Sort/Merge/Lemmas.lean @@ -1,7 +1,7 @@ /- -Copyright (c) 2025 Sorrachai Yingchareonthawornhcai. All rights reserved. +Copyright (c) 2025 Sorrachai Yingchareonthawornchai. All rights reserved. Released under Apache 2.0 license as described in the file LICENSE. -Authors: Sorrachai Yingchareonthawornhcai, Kim Morrison +Authors: Sorrachai Yingchareonthawornchai, Kim Morrison -/ module