diff --git a/crates/openproof-cli/src/autonomous.rs b/crates/openproof-cli/src/autonomous.rs index 4307c36..2f15923 100644 --- a/crates/openproof-cli/src/autonomous.rs +++ b/crates/openproof-cli/src/autonomous.rs @@ -42,6 +42,7 @@ enum TacticProposerBackend { Standard, Ollama, Codex, + Mlx, } impl TacticProposerBackend { @@ -50,6 +51,7 @@ impl TacticProposerBackend { Self::Standard => "standard", Self::Ollama => "ollama", Self::Codex => "codex", + Self::Mlx => "mlx", } } } @@ -1345,6 +1347,7 @@ fn tactic_proposer_backend(config: Option<&crate::setup::SetupResult>) -> Tactic return match normalized.as_str() { "codex" => TacticProposerBackend::Codex, "ollama" => TacticProposerBackend::Ollama, + "mlx" => TacticProposerBackend::Mlx, "standard" | "fallback" | "none" => TacticProposerBackend::Standard, _ => { eprintln!( @@ -1361,6 +1364,10 @@ fn tactic_proposer_backend(config: Option<&crate::setup::SetupResult>) -> Tactic fn tactic_proposer_backend_from_config( config: Option<&crate::setup::SetupResult>, ) -> TacticProposerBackend { + // Auto-detect MLX on macOS if model is installed. + if cfg!(target_os = "macos") && openproof_search::mlx::mlx_model_exists() { + return TacticProposerBackend::Mlx; + } match config { Some(cfg) if cfg.model_provider == "codex" => TacticProposerBackend::Codex, Some(cfg) if cfg.prover_model.is_some() => TacticProposerBackend::Ollama, @@ -1789,6 +1796,60 @@ fn spawn_tactic_search_for_sorrys( None }; + // Check if MLX tactic model is available (macOS Apple Silicon). + let mlx_proposer = if proposer_backend == TacticProposerBackend::Mlx { + let proposer = openproof_search::mlx::MlxProposer::new(); + if proposer.is_available() { + eprintln!("[tactic-search] Using MLX tactic proposer (already running)"); + Some(Arc::new(proposer)) + } else { + // Auto-spawn mlx_lm.server. 
+ eprintln!("[tactic-search] Spawning MLX server..."); + let port = proposer.port(); + let model_path = proposer.model_path().to_string(); + match std::process::Command::new("python3") + .args([ + "-m", + "mlx_lm.server", + "--model", + &model_path, + "--port", + &port.to_string(), + ]) + .stdout(std::process::Stdio::null()) + .stderr(std::process::Stdio::null()) + .spawn() + { + Ok(_child) => { + // Poll for readiness (up to 30s). + let mut ready = false; + for _ in 0..60 { + std::thread::sleep(std::time::Duration::from_millis(500)); + if proposer.is_available() { + ready = true; + break; + } + } + if ready { + eprintln!("[tactic-search] MLX server ready"); + Some(Arc::new(proposer)) + } else { + eprintln!( + "[tactic-search] MLX server failed to start, using fallback tactics" + ); + None + } + } + Err(e) => { + eprintln!("[tactic-search] Failed to spawn MLX server: {e}"); + None + } + } + } + } else { + None + }; + if proposer_backend == TacticProposerBackend::Codex { eprintln!( "[tactic-search] Using Codex tactic proposer ({})", @@ -1869,6 +1930,7 @@ fn spawn_tactic_search_for_sorrys( let tactics = standard_tactics.clone(); let store_for_propose = store.clone(); let ollama = ollama_proposer.clone(); + let mlx = mlx_proposer.clone(); let codex_model = codex_model.clone(); let codex_cache = codex_cache.clone(); let export = ExpertExportContext { @@ -1936,6 +1998,13 @@ fn spawn_tactic_search_for_sorrys( } } } + TacticProposerBackend::Mlx => { + if let Some(ref proposer) = mlx { + if let Ok(model_tactics) = proposer.propose_tactics(goal, k) { + candidates.extend(model_tactics); + } + } + } TacticProposerBackend::Standard => {} } diff --git a/crates/openproof-cli/src/expert_gen.rs b/crates/openproof-cli/src/expert_gen.rs new file mode 100644 index 0000000..716a9e2 --- /dev/null +++ b/crates/openproof-cli/src/expert_gen.rs @@ -0,0 +1,157 @@ +//! Batch generation of tactic proposals using Codex. +//! +//! 
Reads goal states from stdin (one per line), sends each to Codex asking for +//! smart goal-specific tactics, writes unverified pairs to stdout as JSONL. +//! +//! Usage: +//! openproof expert-gen < goals.txt > unverified_pairs.jsonl + +use anyhow::Result; +use openproof_model::{CodexTurnRequest, TurnMessage}; +use serde::{Deserialize, Serialize}; +use std::io::{self, BufRead, Write}; +use std::time::Duration; + +const SYSTEM_PROMPT: &str = "\ +You are an expert Lean 4 tactic advisor. Given a goal state, propose 1-5 specific tactics \ +that would CLOSE the goal entirely (not just simplify it). Think carefully about the goal \ +structure. Use specific Mathlib lemma names when appropriate. Prefer concrete tactics over \ +generic ones -- e.g. 'nlinarith [sq_nonneg (a - b)]' over just 'nlinarith'. \ +For algebraic goals, use ring/linarith with specific witnesses. \ +For logical goals, use exact/apply with specific lemma names. \ +Reply with ONLY a JSON object: {\"tactics\":[\"tactic1\",\"tactic2\"]}. \ +No markdown, no explanation. 
Never use sorry, admit, or native_decide."; + +const DEFAULT_MODEL: &str = "gpt-5.4"; + +#[derive(Serialize)] +struct UnverifiedPair { + goal_state: String, + proposed_tactic: String, +} + +#[derive(Deserialize)] +struct TacticResponse { + tactics: Vec, +} + +fn parse_tactics(response: &str) -> Vec { + // Try direct parse + if let Ok(parsed) = serde_json::from_str::(response) { + return parsed.tactics; + } + + // Try stripping markdown + let trimmed = response + .trim() + .trim_start_matches("```json") + .trim_start_matches("```") + .trim_end_matches("```") + .trim(); + if let Ok(parsed) = serde_json::from_str::(trimmed) { + return parsed.tactics; + } + + // Try finding JSON in text; `get` (unlike indexing) returns None instead of panicking when the last '}' precedes the first '{' + if let Some(start) = trimmed.find('{') { + if let Some(candidate) = trimmed.rfind('}').and_then(|end| trimmed.get(start..=end)) { + if let Ok(parsed) = serde_json::from_str::(candidate) { + return parsed.tactics; + } + } + } + + Vec::new() +} + +pub async fn run_expert_gen() -> Result<()> { + let model = + std::env::var("OPENPROOF_TACTIC_MODEL").unwrap_or_else(|_| DEFAULT_MODEL.to_string()); + + // Sync auth + openproof_model::sync_auth_from_codex_cli()?; + + let stdin = io::stdin(); + let stdout = io::stdout(); + let mut out = io::BufWriter::new(stdout.lock()); + + let mut total = 0usize; + let mut generated = 0usize; + let mut errors = 0usize; + + for line in stdin.lock().lines() { + let goal = match line { + Ok(l) => l.trim().to_string(), + Err(_) => break, + }; + if goal.is_empty() { + continue; + } + total += 1; + + let prompt = format!("Goal state:\n{goal}"); + let messages = vec![ + TurnMessage::chat("system", SYSTEM_PROMPT), + TurnMessage::chat("user", prompt), + ]; + let session_id = format!("expert-gen-{}", chrono::Utc::now().timestamp_millis()); + let request = CodexTurnRequest { + session_id: &session_id, + messages: &messages, + model: &model, + reasoning_effort: "low", + include_tools: false, + }; + + match tokio::time::timeout( + Duration::from_secs(300), + openproof_model::run_codex_turn(request), + ) +
.await + { + Ok(Ok(response)) => { + let tactics = parse_tactics(&response); + for tactic in &tactics { + let tactic = tactic.trim(); + if tactic.is_empty() { + continue; + } + let lower = tactic.to_lowercase(); + if lower == "sorry" || lower == "admit" || lower == "native_decide" { + continue; + } + let pair = UnverifiedPair { + goal_state: goal.clone(), + proposed_tactic: tactic.to_string(), + }; + if let Ok(json) = serde_json::to_string(&pair) { + writeln!(out, "{json}")?; // stdout is the only product of this command: abort on a dead pipe instead of discarding results + generated += 1; // count only pairs that were actually emitted + } + } + if tactics.is_empty() { + errors += 1; + } + } + Ok(Err(e)) => { + eprintln!("[expert-gen] Error on goal {total}: {e}"); + errors += 1; + } + Err(_) => { + eprintln!("[expert-gen] Timeout on goal {total}"); + errors += 1; + } + } + + #[allow(clippy::manual_is_multiple_of)] + if total % 50 == 0 { + eprintln!( + "[expert-gen] {total} goals processed, {generated} tactics generated, {errors} errors" + ); + } + } + + out.flush()?; // surface a failed final flush rather than reporting success with truncated output + eprintln!("[expert-gen] Done: {total} goals, {generated} tactics, {errors} errors"); + Ok(()) +} diff --git a/crates/openproof-search/Cargo.toml b/crates/openproof-search/Cargo.toml index 154269f..5fe84db 100644 --- a/crates/openproof-search/Cargo.toml +++ b/crates/openproof-search/Cargo.toml @@ -6,6 +6,7 @@ license.workspace = true [dependencies] anyhow.workspace = true +directories.workspace = true reqwest = { workspace = true, features = ["blocking"] } serde.workspace = true serde_json.workspace = true diff --git a/crates/openproof-search/src/lib.rs b/crates/openproof-search/src/lib.rs index f682586..8ba3a38 100644 --- a/crates/openproof-search/src/lib.rs +++ b/crates/openproof-search/src/lib.rs @@ -7,5 +7,6 @@ pub mod cache; pub mod config; pub mod lsp_search; +pub mod mlx; pub mod ollama; pub mod search; diff --git a/crates/openproof-search/src/mlx.rs b/crates/openproof-search/src/mlx.rs new file mode 100644 index 0000000..8582b48 --- /dev/null +++ b/crates/openproof-search/src/mlx.rs @@ -0,0 +1,188 @@ +//!
MLX-based tactic proposer for Apple Silicon. +//! +//! Connects to `mlx_lm.server` which serves an OpenAI-compatible HTTP API. +//! Candidates are sampled one request at a time (up to `2*k` sequential calls per goal), not batched with `n=k`. + +use anyhow::{Context, Result}; +use serde::Deserialize; +use std::collections::HashSet; +use std::time::Duration; + +use crate::ollama::filter_tactic; + +const DEFAULT_PORT: u16 = 8321; +const DEFAULT_MODEL_DIR: &str = ".openproof/models/openproof-tactic-2b"; +const MAX_TOKENS: usize = 256; +const REQUEST_TIMEOUT_SECS: u64 = 30; + +/// MLX tactic proposer -- talks to `mlx_lm.server` via OpenAI completions API. +pub struct MlxProposer { + client: reqwest::blocking::Client, + url: String, + model_path: String, + temperature: f64, + top_p: f64, +} + +#[derive(Deserialize)] +struct CompletionChoice { + text: String, +} + +#[derive(Deserialize)] +struct CompletionResponse { + choices: Vec, +} + +impl Default for MlxProposer { + fn default() -> Self { + Self::new() + } +} + +impl MlxProposer { + /// Create a proposer for `DEFAULT_MODEL_DIR` under the home directory (an empty base path if none is found) on `DEFAULT_PORT`. + pub fn new() -> Self { + let home = directories::BaseDirs::new() + .map(|d| d.home_dir().to_path_buf()) + .unwrap_or_default(); + let model_path = home.join(DEFAULT_MODEL_DIR).display().to_string(); + Self::with_config(&model_path, DEFAULT_PORT) + } + + /// Create a proposer with custom model path and port. Requests time out after `REQUEST_TIMEOUT_SECS`; if the client builder fails, a default client is used instead. + pub fn with_config(model_path: &str, port: u16) -> Self { + Self { + client: reqwest::blocking::Client::builder() + .timeout(Duration::from_secs(REQUEST_TIMEOUT_SECS)) + .build() + .unwrap_or_default(), + url: format!("http://localhost:{port}"), + model_path: model_path.to_string(), + temperature: 0.8, + top_p: 0.95, + } + } + + /// Check if the MLX server answers `GET /v1/models` with a success status within 2 seconds; any request error counts as unavailable. + pub fn is_available(&self) -> bool { + self.client + .get(format!("{}/v1/models", self.url)) + .timeout(Duration::from_secs(2)) + .send() + .map(|r| r.status().is_success()) + .unwrap_or(false) + } + + /// Path to the MLX model directory.
+ pub fn model_path(&self) -> &str { + &self.model_path + } + + /// Port the server should run on, parsed from the text after the last `:` of the base URL (`DEFAULT_PORT` if that does not parse). + pub fn port(&self) -> u16 { + self.url + .rsplit(':') + .next() + .and_then(|s| s.parse().ok()) + .unwrap_or(DEFAULT_PORT) + } + + /// Generate a single tactic for a goal state: posts `{goal_state}:::` to `/v1/completions` and passes the first choice through `filter_tactic`; `Ok(None)` means the output was empty or filtered out. + fn generate_one(&self, goal_state: &str) -> Result> { + let prompt = format!("{goal_state}:::"); + + let body = serde_json::json!({ + "prompt": prompt, + "max_tokens": MAX_TOKENS, + "temperature": self.temperature, + "top_p": self.top_p, + "stop": ["\n\n", ":::"], + }); + + let resp: CompletionResponse = self + .client + .post(format!("{}/v1/completions", self.url)) + .json(&body) + .send() + .context("MLX server request failed")? + .json() + .context("MLX response parse failed")?; + + let tactic = resp + .choices + .first() + .map(|c| c.text.trim().to_string()) + .filter(|s| !s.is_empty()); + + Ok(tactic.and_then(|t| filter_tactic(&t))) + } + + /// Generate up to `k` distinct tactic candidates for a goal state. + /// Makes up to `2*k` sequential calls with temperature sampling for diversity; a failed request ends the loop and whatever was collected (possibly nothing) is still returned as `Ok`. + pub fn propose_tactics(&self, goal_state: &str, k: usize) -> Result> { + let mut tactics = Vec::with_capacity(k); + let mut seen = HashSet::new(); + + for _ in 0..(k * 2) { + if tactics.len() >= k { + break; + } + match self.generate_one(goal_state) { + Ok(Some(tactic)) => { + if seen.insert(tactic.clone()) { + tactics.push(tactic); + } + } + Ok(None) => continue, + Err(_) => break, + } + } + + Ok(tactics) + } +} + +/// Build a `ProposeFn` that uses MLX for model-based proposals, +/// falling back to the provided standard tactics.
+pub fn make_mlx_propose_fn( + proposer: MlxProposer, + fallback_tactics: Vec, +) -> crate::search::ProposeFn { + Box::new(move |goal: &str, _context: &str, k: usize| { + let mut candidates = Vec::with_capacity(k); + + if let Ok(model_tactics) = proposer.propose_tactics(goal, k) { + candidates.extend(model_tactics); + } + + let mut seen: HashSet = candidates.iter().cloned().collect(); + for t in &fallback_tactics { + if candidates.len() >= k { + break; + } + if seen.insert(t.clone()) { + candidates.push(t.clone()); + } + } + + candidates.truncate(k); + Ok(candidates) + }) +} + +/// Check if an MLX model is installed at the default location, i.e. whether `config.json` exists in `DEFAULT_MODEL_DIR` under the home directory. +pub fn mlx_model_exists() -> bool { + let home = directories::BaseDirs::new() + .map(|d| d.home_dir().to_path_buf()) + .unwrap_or_default(); + home.join(DEFAULT_MODEL_DIR).join("config.json").exists() +} + +/// Default model path: `DEFAULT_MODEL_DIR` joined onto the home directory (onto an empty path if no home directory is found). +pub fn default_model_path() -> String { + let home = directories::BaseDirs::new() + .map(|d| d.home_dir().to_path_buf()) + .unwrap_or_default(); + home.join(DEFAULT_MODEL_DIR).display().to_string() +} diff --git a/crates/openproof-search/src/ollama.rs b/crates/openproof-search/src/ollama.rs index 35bd961..3bc292d 100644 --- a/crates/openproof-search/src/ollama.rs +++ b/crates/openproof-search/src/ollama.rs @@ -140,7 +140,7 @@ impl OllamaProposer { } /// Filter out banned tactics and clean up model output. -fn filter_tactic(raw: &str) -> Option { +pub fn filter_tactic(raw: &str) -> Option { let tactic = raw .lines() .next()