From 3a28d1fb5bc0159e9e096debea8787a17df9e020 Mon Sep 17 00:00:00 2001 From: Lior Cohen Date: Wed, 8 Apr 2026 20:03:30 +0300 Subject: [PATCH 01/13] KS75: add EmbeddingProvider enum + config fields to shrimpk-core - Add EmbeddingProvider enum (Fastembed | OpenAI) with Display/FromStr/Serde - Add embedding_provider, embedding_model, embedding_api_url to EchoConfig - Add infer_embedding_dim() helper mapping known model names to dimensions - Wire into FileConfig, resolve_config(), env vars (SHRIMPK_EMBEDDING_*) - Auto-infer embedding_dim from model name unless explicitly overridden - Re-export EmbeddingProvider from shrimpk-core root - 10 new tests: parse roundtrip, display, infer_dim known/fallback, serde, TOML Co-Authored-By: Claude Opus 4.6 --- crates/shrimpk-core/src/config.rs | 221 ++++++++++++++++++++++++++++++ crates/shrimpk-core/src/lib.rs | 4 +- 2 files changed, 223 insertions(+), 2 deletions(-) diff --git a/crates/shrimpk-core/src/config.rs b/crates/shrimpk-core/src/config.rs index e3254de..5e7b9ec 100644 --- a/crates/shrimpk-core/src/config.rs +++ b/crates/shrimpk-core/src/config.rs @@ -45,6 +45,45 @@ impl std::str::FromStr for RerankerBackend { } } +/// Backend for embedding vector generation. +/// +/// Controls which embedding model and provider is used for memory storage +/// and echo queries. The default `Fastembed` backend uses a local ONNX model +/// (BGE-small-EN-v1.5, 384-dim) with zero external API calls. +#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize, Default)] +pub enum EmbeddingProvider { + /// Local fastembed ONNX model (default). Zero network calls. + /// Models: BGE-small-EN-v1.5 (384-dim), all-MiniLM-L6-v2 (384-dim), etc. + #[default] + Fastembed, + /// OpenAI-compatible embedding API (local or cloud). + /// Requires `embedding_api_url` and `embedding_model` to be set. + /// Works with: OpenAI, Ollama `/api/embeddings`, LiteLLM, vLLM, etc. 
+ OpenAI, +} + +impl std::fmt::Display for EmbeddingProvider { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match self { + Self::Fastembed => write!(f, "fastembed"), + Self::OpenAI => write!(f, "openai"), + } + } +} + +impl std::str::FromStr for EmbeddingProvider { + type Err = String; + fn from_str(s: &str) -> Result { + match s.to_lowercase().as_str() { + "fastembed" | "local" | "onnx" => Ok(Self::Fastembed), + "openai" | "api" | "ollama" => Ok(Self::OpenAI), + _ => Err(format!( + "invalid embedding provider '{s}': expected fastembed or openai" + )), + } + } +} + /// Quantization mode for embedding vectors in the echo index. #[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize, Default)] pub enum QuantizationMode { @@ -277,6 +316,21 @@ pub struct EchoConfig { /// Embedding dimension for speech channel. Default: 640 (ECAPA-TDNN 256 + Whisper-tiny 384). #[serde(default = "default_speech_dim")] pub speech_embedding_dim: usize, + + // --- Embedding provider (KS75) --- + /// Embedding backend: `Fastembed` (local ONNX, default) or `OpenAI` (API). + #[serde(default)] + pub embedding_provider: EmbeddingProvider, + /// Model name for the embedding provider. + /// Fastembed: "BGE-small-EN-v1.5" (default), "all-MiniLM-L6-v2", etc. + /// OpenAI: "text-embedding-3-small", "nomic-embed-text", Ollama model name, etc. + #[serde(default = "default_embedding_model")] + pub embedding_model: String, + /// API URL for OpenAI-compatible embedding providers. + /// Only used when `embedding_provider = OpenAI`. + /// Default: "http://127.0.0.1:11434" (Ollama). 
+ #[serde(default = "default_embedding_api_url")] + pub embedding_api_url: String, } fn default_true() -> bool { @@ -317,6 +371,13 @@ fn default_speech_dim() -> usize { 640 } +fn default_embedding_model() -> String { + "BGE-small-EN-v1.5".to_string() +} +fn default_embedding_api_url() -> String { + "http://127.0.0.1:11434".to_string() +} + fn default_proxy_target() -> String { "http://127.0.0.1:11434".to_string() } @@ -411,6 +472,9 @@ impl Default for EchoConfig { enabled_modalities: default_modalities(), vision_embedding_dim: default_vision_dim(), speech_embedding_dim: default_speech_dim(), + embedding_provider: EmbeddingProvider::default(), + embedding_model: default_embedding_model(), + embedding_api_url: default_embedding_api_url(), } } } @@ -476,6 +540,34 @@ impl EchoConfig { } } + /// Infer the embedding dimension from the configured model name. + /// + /// Returns the known dimension for well-known models, or falls back to + /// `self.embedding_dim` (the explicitly configured value) if the model + /// is not recognized. This lets users set `embedding_model` without also + /// needing to manually set `embedding_dim`. + pub fn infer_embedding_dim(&self) -> usize { + match self.embedding_model.to_lowercase().as_str() { + // fastembed ONNX models + s if s.contains("bge-small") => 384, + s if s.contains("bge-base") => 768, + s if s.contains("bge-large") => 1024, + s if s.contains("minilm-l6") => 384, + s if s.contains("minilm-l12") => 384, + // OpenAI + s if s.contains("text-embedding-3-small") => 1536, + s if s.contains("text-embedding-3-large") => 3072, + s if s.contains("text-embedding-ada") => 1536, + // Ollama common models + s if s.contains("nomic-embed-text") => 768, + s if s.contains("mxbai-embed-large") => 1024, + s if s.contains("all-minilm") => 384, + s if s.contains("snowflake-arctic-embed") => 1024, + // Fallback to explicit config + _ => self.embedding_dim, + } + } + /// Estimated index size in bytes for the current config. 
pub fn estimated_index_bytes(&self) -> u64 { let bytes_per_entry = self.quantization.bytes_per_vector(self.embedding_dim) + 100; @@ -537,6 +629,9 @@ pub struct FileConfig { pub enabled_modalities: Option>, pub vision_embedding_dim: Option, pub speech_embedding_dim: Option, + pub embedding_provider: Option, + pub embedding_model: Option, + pub embedding_api_url: Option, } /// Default data directory: `~/.shrimpk-kernel/` @@ -781,6 +876,15 @@ pub fn resolve_config() -> crate::Result { if let Some(v) = fc.speech_embedding_dim { config.speech_embedding_dim = v; } + if let Some(v) = fc.embedding_provider { + config.embedding_provider = v; + } + if let Some(v) = fc.embedding_model { + config.embedding_model = v; + } + if let Some(v) = fc.embedding_api_url { + config.embedding_api_url = v; + } } // Layer 3: env var overrides (highest priority) @@ -844,6 +948,24 @@ pub fn resolve_config() -> crate::Result { config.hebbian_prune_threshold = v; } + if let Ok(v) = std::env::var("SHRIMPK_EMBEDDING_PROVIDER") + && let Ok(provider) = v.parse::() + { + config.embedding_provider = provider; + } + if let Ok(v) = std::env::var("SHRIMPK_EMBEDDING_MODEL") { + config.embedding_model = v; + } + if let Ok(v) = std::env::var("SHRIMPK_EMBEDDING_API_URL") { + config.embedding_api_url = v; + } + + // Auto-infer embedding_dim from model name if not explicitly overridden + let dim_explicitly_set = std::env::var("SHRIMPK_EMBEDDING_DIM").is_ok(); + if !dim_explicitly_set { + config.embedding_dim = config.infer_embedding_dim(); + } + // Backward compatibility: if reranker_enabled=true but backend=None, default to Llm if config.reranker_enabled && config.reranker_backend == RerankerBackend::None { config.reranker_backend = RerankerBackend::Llm; @@ -1265,4 +1387,103 @@ mod tests { "Explicit backend should override legacy reranker_enabled" ); } + + // --- KS75: EmbeddingProvider --- + + #[test] + fn embedding_provider_default_is_fastembed() { + let config = EchoConfig::default(); + 
assert_eq!(config.embedding_provider, EmbeddingProvider::Fastembed); + assert_eq!(config.embedding_model, "BGE-small-EN-v1.5"); + } + + #[test] + fn embedding_provider_parse_roundtrip() { + for (input, expected) in [ + ("fastembed", EmbeddingProvider::Fastembed), + ("local", EmbeddingProvider::Fastembed), + ("onnx", EmbeddingProvider::Fastembed), + ("openai", EmbeddingProvider::OpenAI), + ("api", EmbeddingProvider::OpenAI), + ("ollama", EmbeddingProvider::OpenAI), + ] { + let parsed: EmbeddingProvider = input.parse().unwrap(); + assert_eq!(parsed, expected, "parsing '{input}'"); + } + } + + #[test] + fn embedding_provider_parse_invalid() { + assert!("unknown".parse::().is_err()); + } + + #[test] + fn embedding_provider_display() { + assert_eq!(EmbeddingProvider::Fastembed.to_string(), "fastembed"); + assert_eq!(EmbeddingProvider::OpenAI.to_string(), "openai"); + } + + #[test] + fn infer_embedding_dim_known_models() { + let mut config = EchoConfig::default(); + + config.embedding_model = "BGE-small-EN-v1.5".into(); + assert_eq!(config.infer_embedding_dim(), 384); + + config.embedding_model = "BGE-base-EN-v1.5".into(); + assert_eq!(config.infer_embedding_dim(), 768); + + config.embedding_model = "BGE-large-EN-v1.5".into(); + assert_eq!(config.infer_embedding_dim(), 1024); + + config.embedding_model = "text-embedding-3-small".into(); + assert_eq!(config.infer_embedding_dim(), 1536); + + config.embedding_model = "text-embedding-3-large".into(); + assert_eq!(config.infer_embedding_dim(), 3072); + + config.embedding_model = "nomic-embed-text".into(); + assert_eq!(config.infer_embedding_dim(), 768); + } + + #[test] + fn infer_embedding_dim_unknown_falls_back() { + let mut config = EchoConfig::default(); + config.embedding_model = "my-custom-model".into(); + config.embedding_dim = 512; + assert_eq!( + config.infer_embedding_dim(), + 512, + "Unknown model should fall back to embedding_dim" + ); + } + + #[test] + fn embedding_provider_serde_roundtrip() { + let config = 
EchoConfig { + embedding_provider: EmbeddingProvider::OpenAI, + embedding_model: "text-embedding-3-small".into(), + embedding_api_url: "https://api.openai.com".into(), + ..Default::default() + }; + let json = serde_json::to_string(&config).unwrap(); + let parsed: EchoConfig = serde_json::from_str(&json).unwrap(); + assert_eq!(parsed.embedding_provider, EmbeddingProvider::OpenAI); + assert_eq!(parsed.embedding_model, "text-embedding-3-small"); + assert_eq!(parsed.embedding_api_url, "https://api.openai.com"); + } + + #[test] + fn file_config_embedding_fields_toml_roundtrip() { + let fc = FileConfig { + embedding_provider: Some(EmbeddingProvider::OpenAI), + embedding_model: Some("nomic-embed-text".into()), + embedding_api_url: Some("http://localhost:11434".into()), + ..Default::default() + }; + let toml_str = toml::to_string_pretty(&fc).unwrap(); + let parsed: FileConfig = toml::from_str(&toml_str).unwrap(); + assert_eq!(parsed.embedding_provider, Some(EmbeddingProvider::OpenAI)); + assert_eq!(parsed.embedding_model, Some("nomic-embed-text".into())); + } } diff --git a/crates/shrimpk-core/src/lib.rs b/crates/shrimpk-core/src/lib.rs index 1658303..745a76f 100644 --- a/crates/shrimpk-core/src/lib.rs +++ b/crates/shrimpk-core/src/lib.rs @@ -13,8 +13,8 @@ pub mod traits; // Re-export commonly used types at crate root pub use config::{ - EchoConfig, FileConfig, QuantizationMode, RerankerBackend, config_dir, config_path, disk_usage, - load_config_file, resolve_config, save_config_file, + EchoConfig, EmbeddingProvider, FileConfig, QuantizationMode, RerankerBackend, config_dir, + config_path, disk_usage, load_config_file, resolve_config, save_config_file, }; pub use entity::{EntityFrame, EntityId, EntityKind}; pub use error::{Result, ShrimPKError}; From cd1bfb24928f683fd6e2fe3324244a75367f6a6e Mon Sep 17 00:00:00 2001 From: Lior Cohen Date: Wed, 8 Apr 2026 20:05:44 +0300 Subject: [PATCH 02/13] KS75: add EmbeddingProvider trait + rename config enum to EmbeddingBackend - Add 
EmbeddingProvider trait in traits.rs (embed, embed_batch, dimension, name) - Rename config enum EmbeddingProvider -> EmbeddingBackend (avoids trait name collision) - Re-export EmbeddingProvider trait from shrimpk-core root Co-Authored-By: Claude Opus 4.6 --- crates/shrimpk-core/src/config.rs | 46 +++++++++++++++---------------- crates/shrimpk-core/src/lib.rs | 5 ++-- crates/shrimpk-core/src/traits.rs | 31 +++++++++++++++++++++ 3 files changed, 57 insertions(+), 25 deletions(-) diff --git a/crates/shrimpk-core/src/config.rs b/crates/shrimpk-core/src/config.rs index 5e7b9ec..bcfa14f 100644 --- a/crates/shrimpk-core/src/config.rs +++ b/crates/shrimpk-core/src/config.rs @@ -51,7 +51,7 @@ impl std::str::FromStr for RerankerBackend { /// and echo queries. The default `Fastembed` backend uses a local ONNX model /// (BGE-small-EN-v1.5, 384-dim) with zero external API calls. #[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize, Default)] -pub enum EmbeddingProvider { +pub enum EmbeddingBackend { /// Local fastembed ONNX model (default). Zero network calls. /// Models: BGE-small-EN-v1.5 (384-dim), all-MiniLM-L6-v2 (384-dim), etc. #[default] @@ -62,7 +62,7 @@ pub enum EmbeddingProvider { OpenAI, } -impl std::fmt::Display for EmbeddingProvider { +impl std::fmt::Display for EmbeddingBackend { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { match self { Self::Fastembed => write!(f, "fastembed"), @@ -71,7 +71,7 @@ impl std::fmt::Display for EmbeddingProvider { } } -impl std::str::FromStr for EmbeddingProvider { +impl std::str::FromStr for EmbeddingBackend { type Err = String; fn from_str(s: &str) -> Result { match s.to_lowercase().as_str() { @@ -320,7 +320,7 @@ pub struct EchoConfig { // --- Embedding provider (KS75) --- /// Embedding backend: `Fastembed` (local ONNX, default) or `OpenAI` (API). #[serde(default)] - pub embedding_provider: EmbeddingProvider, + pub embedding_provider: EmbeddingBackend, /// Model name for the embedding provider. 
/// Fastembed: "BGE-small-EN-v1.5" (default), "all-MiniLM-L6-v2", etc. /// OpenAI: "text-embedding-3-small", "nomic-embed-text", Ollama model name, etc. @@ -472,7 +472,7 @@ impl Default for EchoConfig { enabled_modalities: default_modalities(), vision_embedding_dim: default_vision_dim(), speech_embedding_dim: default_speech_dim(), - embedding_provider: EmbeddingProvider::default(), + embedding_provider: EmbeddingBackend::default(), embedding_model: default_embedding_model(), embedding_api_url: default_embedding_api_url(), } @@ -629,7 +629,7 @@ pub struct FileConfig { pub enabled_modalities: Option>, pub vision_embedding_dim: Option, pub speech_embedding_dim: Option, - pub embedding_provider: Option, + pub embedding_provider: Option, pub embedding_model: Option, pub embedding_api_url: Option, } @@ -949,7 +949,7 @@ pub fn resolve_config() -> crate::Result { } if let Ok(v) = std::env::var("SHRIMPK_EMBEDDING_PROVIDER") - && let Ok(provider) = v.parse::() + && let Ok(provider) = v.parse::() { config.embedding_provider = provider; } @@ -1388,39 +1388,39 @@ mod tests { ); } - // --- KS75: EmbeddingProvider --- + // --- KS75: EmbeddingBackend --- #[test] fn embedding_provider_default_is_fastembed() { let config = EchoConfig::default(); - assert_eq!(config.embedding_provider, EmbeddingProvider::Fastembed); + assert_eq!(config.embedding_provider, EmbeddingBackend::Fastembed); assert_eq!(config.embedding_model, "BGE-small-EN-v1.5"); } #[test] fn embedding_provider_parse_roundtrip() { for (input, expected) in [ - ("fastembed", EmbeddingProvider::Fastembed), - ("local", EmbeddingProvider::Fastembed), - ("onnx", EmbeddingProvider::Fastembed), - ("openai", EmbeddingProvider::OpenAI), - ("api", EmbeddingProvider::OpenAI), - ("ollama", EmbeddingProvider::OpenAI), + ("fastembed", EmbeddingBackend::Fastembed), + ("local", EmbeddingBackend::Fastembed), + ("onnx", EmbeddingBackend::Fastembed), + ("openai", EmbeddingBackend::OpenAI), + ("api", EmbeddingBackend::OpenAI), + ("ollama", 
EmbeddingBackend::OpenAI), ] { - let parsed: EmbeddingProvider = input.parse().unwrap(); + let parsed: EmbeddingBackend = input.parse().unwrap(); assert_eq!(parsed, expected, "parsing '{input}'"); } } #[test] fn embedding_provider_parse_invalid() { - assert!("unknown".parse::().is_err()); + assert!("unknown".parse::().is_err()); } #[test] fn embedding_provider_display() { - assert_eq!(EmbeddingProvider::Fastembed.to_string(), "fastembed"); - assert_eq!(EmbeddingProvider::OpenAI.to_string(), "openai"); + assert_eq!(EmbeddingBackend::Fastembed.to_string(), "fastembed"); + assert_eq!(EmbeddingBackend::OpenAI.to_string(), "openai"); } #[test] @@ -1461,14 +1461,14 @@ mod tests { #[test] fn embedding_provider_serde_roundtrip() { let config = EchoConfig { - embedding_provider: EmbeddingProvider::OpenAI, + embedding_provider: EmbeddingBackend::OpenAI, embedding_model: "text-embedding-3-small".into(), embedding_api_url: "https://api.openai.com".into(), ..Default::default() }; let json = serde_json::to_string(&config).unwrap(); let parsed: EchoConfig = serde_json::from_str(&json).unwrap(); - assert_eq!(parsed.embedding_provider, EmbeddingProvider::OpenAI); + assert_eq!(parsed.embedding_provider, EmbeddingBackend::OpenAI); assert_eq!(parsed.embedding_model, "text-embedding-3-small"); assert_eq!(parsed.embedding_api_url, "https://api.openai.com"); } @@ -1476,14 +1476,14 @@ mod tests { #[test] fn file_config_embedding_fields_toml_roundtrip() { let fc = FileConfig { - embedding_provider: Some(EmbeddingProvider::OpenAI), + embedding_provider: Some(EmbeddingBackend::OpenAI), embedding_model: Some("nomic-embed-text".into()), embedding_api_url: Some("http://localhost:11434".into()), ..Default::default() }; let toml_str = toml::to_string_pretty(&fc).unwrap(); let parsed: FileConfig = toml::from_str(&toml_str).unwrap(); - assert_eq!(parsed.embedding_provider, Some(EmbeddingProvider::OpenAI)); + assert_eq!(parsed.embedding_provider, Some(EmbeddingBackend::OpenAI)); 
assert_eq!(parsed.embedding_model, Some("nomic-embed-text".into())); } } diff --git a/crates/shrimpk-core/src/lib.rs b/crates/shrimpk-core/src/lib.rs index 745a76f..0dd9f8c 100644 --- a/crates/shrimpk-core/src/lib.rs +++ b/crates/shrimpk-core/src/lib.rs @@ -13,7 +13,7 @@ pub mod traits; // Re-export commonly used types at crate root pub use config::{ - EchoConfig, EmbeddingProvider, FileConfig, QuantizationMode, RerankerBackend, config_dir, + EchoConfig, EmbeddingBackend, FileConfig, QuantizationMode, RerankerBackend, config_dir, config_path, disk_usage, load_config_file, resolve_config, save_config_file, }; pub use entity::{EntityFrame, EntityId, EntityKind}; @@ -26,5 +26,6 @@ pub use memory::{ }; pub use pii::{PiiMatch, PiiType}; pub use traits::{ - ConsolidationOutput, Consolidator, ExtractedFact, FactType, LabelSet, ModelBackend, Provider, + ConsolidationOutput, Consolidator, EmbeddingProvider, ExtractedFact, FactType, LabelSet, + ModelBackend, Provider, }; diff --git a/crates/shrimpk-core/src/traits.rs b/crates/shrimpk-core/src/traits.rs index 5540640..4e119db 100644 --- a/crates/shrimpk-core/src/traits.rs +++ b/crates/shrimpk-core/src/traits.rs @@ -86,6 +86,37 @@ pub struct ModelCapabilities { pub is_local: bool, } +/// A backend that generates embedding vectors from text. +/// +/// Implementations wrap either a local ONNX model (fastembed) or an HTTP API +/// (OpenAI-compatible). The engine holds the provider behind a `Mutex` since +/// fastembed requires `&mut self` for inference. +/// +/// # Providers +/// - `FastembedProvider` — local ONNX (default, offline, zero API calls) +/// - `OpenAIProvider` — any OpenAI-compatible API (cloud or local Ollama) +/// +/// # Contract +/// - `embed` and `embed_batch` must return vectors of length `dimension()`. +/// - On error, return `Err` — never panic. +/// - `Send` but not `Sync` (fastembed's `TextEmbedding` is `!Sync`). 
+pub trait EmbeddingProvider: Send { + /// Embed a single text string, returning a vector of length `dimension()`. + fn embed(&mut self, text: &str) -> Result>; + + /// Embed a batch of texts. Default implementation calls `embed()` in a loop. + /// Providers with native batching (fastembed, OpenAI) should override this. + fn embed_batch(&mut self, texts: Vec) -> Result>> { + texts.iter().map(|t| self.embed(t)).collect() + } + + /// The embedding dimension this provider produces (e.g., 384, 768, 1536). + fn dimension(&self) -> usize; + + /// Human-readable name (e.g., "fastembed/bge-small-en-v1.5", "openai/text-embedding-3-small"). + fn name(&self) -> &str; +} + /// Type classification for extracted facts (KS67 — schema-driven extraction). /// /// Maps to distinct retrieval patterns and supersession thresholds. From 06062d46dcaffe1cf35881fe6f932971d07a812e Mon Sep 17 00:00:00 2001 From: Lior Cohen Date: Wed, 8 Apr 2026 20:09:22 +0300 Subject: [PATCH 03/13] KS75: implement FastembedProvider + OpenAIProvider - New file: crates/shrimpk-memory/src/embedding_provider.rs - FastembedProvider: wraps fastembed TextEmbedding with runtime model selection (10 supported models: BGE, MiniLM, Nomic, MxBai, GTE) - OpenAIProvider: ureq HTTP client for /v1/embeddings (OpenAI, Ollama, LiteLLM) with audit logging, batch support, dimension validation - from_config() factory: selects provider based on EchoConfig.embedding_provider - API key from SHRIMPK_EMBEDDING_API_KEY env var only (never in config file) - 5 unit tests (4 pass, 1 ignored for model download) Co-Authored-By: Claude Opus 4.6 --- .../shrimpk-memory/src/embedding_provider.rs | 391 ++++++++++++++++++ crates/shrimpk-memory/src/lib.rs | 1 + 2 files changed, 392 insertions(+) create mode 100644 crates/shrimpk-memory/src/embedding_provider.rs diff --git a/crates/shrimpk-memory/src/embedding_provider.rs b/crates/shrimpk-memory/src/embedding_provider.rs new file mode 100644 index 0000000..4cc2254 --- /dev/null +++ 
b/crates/shrimpk-memory/src/embedding_provider.rs @@ -0,0 +1,391 @@ +//! Pluggable embedding provider implementations (KS75). +//! +//! Two backends: +//! - `FastembedProvider` — local ONNX via `fastembed` (default, zero API calls) +//! - `OpenAIProvider` — any OpenAI-compatible embedding API (cloud or local Ollama) +//! +//! Factory function `from_config()` selects the appropriate provider based on `EchoConfig`. + +use fastembed::{EmbeddingModel, InitOptions, TextEmbedding}; +use shrimpk_core::{EchoConfig, EmbeddingBackend, EmbeddingProvider, Result, ShrimPKError}; + +// --------------------------------------------------------------------------- +// FastembedProvider +// --------------------------------------------------------------------------- + +/// Local ONNX embedding via fastembed. +/// +/// Wraps `fastembed::TextEmbedding` with runtime model selection. +/// Zero external API calls — all inference runs locally. +pub struct FastembedProvider { + model: TextEmbedding, + dim: usize, + model_name: String, +} + +impl FastembedProvider { + /// Create a new FastembedProvider for the given model name. 
+ /// + /// Supported model names (case-insensitive): + /// - "bge-small-en-v1.5" (384-dim, default) + /// - "bge-base-en-v1.5" (768-dim) + /// - "bge-large-en-v1.5" (1024-dim) + /// - "bge-m3" (1024-dim) + /// - "all-minilm-l6-v2" (384-dim) + /// - "all-minilm-l12-v2" (384-dim) + /// - "nomic-embed-text-v1.5" (768-dim) + /// - "mxbai-embed-large-v1" (1024-dim) + /// - "gte-large-en-v1.5" (1024-dim) + pub fn new(model_name: &str) -> Result { + let (variant, dim) = resolve_fastembed_model(model_name)?; + let display_name = format!("fastembed/{model_name}"); + + let start = std::time::Instant::now(); + let model = TextEmbedding::try_new(InitOptions::new(variant)).map_err(|e| { + ShrimPKError::Embedding(format!( + "Failed to init fastembed model '{model_name}': {e}" + )) + })?; + + tracing::info!( + elapsed_ms = start.elapsed().as_millis(), + model = %display_name, + dim = dim, + "FastembedProvider initialized" + ); + + Ok(Self { + model, + dim, + model_name: display_name, + }) + } +} + +impl EmbeddingProvider for FastembedProvider { + fn embed(&mut self, text: &str) -> Result> { + let results = self + .model + .embed(vec![text.to_string()], None) + .map_err(|e| ShrimPKError::Embedding(format!("fastembed embed failed: {e}")))?; + + results + .into_iter() + .next() + .ok_or_else(|| ShrimPKError::Embedding("Empty fastembed result".into())) + } + + fn embed_batch(&mut self, texts: Vec) -> Result>> { + if texts.is_empty() { + return Ok(Vec::new()); + } + self.model + .embed(texts, None) + .map_err(|e| ShrimPKError::Embedding(format!("fastembed batch embed failed: {e}"))) + } + + fn dimension(&self) -> usize { + self.dim + } + + fn name(&self) -> &str { + &self.model_name + } +} + +/// Map a config model name to a fastembed `EmbeddingModel` variant + dimension. 
+fn resolve_fastembed_model(name: &str) -> Result<(EmbeddingModel, usize)> { + let lower = name.to_lowercase(); + match lower.as_str() { + s if s.contains("bge-small-en") => Ok((EmbeddingModel::BGESmallENV15, 384)), + s if s.contains("bge-base-en") => Ok((EmbeddingModel::BGEBaseENV15, 768)), + s if s.contains("bge-large-en") => Ok((EmbeddingModel::BGELargeENV15, 1024)), + s if s.contains("bge-m3") => Ok((EmbeddingModel::BGEM3, 1024)), + s if s.contains("all-minilm-l6") || s.contains("minilm-l6") => { + Ok((EmbeddingModel::AllMiniLML6V2, 384)) + } + s if s.contains("all-minilm-l12") || s.contains("minilm-l12") => { + Ok((EmbeddingModel::AllMiniLML12V2, 384)) + } + s if s.contains("nomic-embed-text") => Ok((EmbeddingModel::NomicEmbedTextV15, 768)), + s if s.contains("mxbai-embed-large") => Ok((EmbeddingModel::MxbaiEmbedLargeV1, 1024)), + s if s.contains("gte-large-en") => Ok((EmbeddingModel::GTELargeENV15, 1024)), + s if s.contains("gte-base-en") => Ok((EmbeddingModel::GTEBaseENV15, 768)), + _ => Err(ShrimPKError::Embedding(format!( + "Unknown fastembed model '{name}'. Supported: bge-small-en-v1.5, bge-base-en-v1.5, \ + bge-large-en-v1.5, bge-m3, all-minilm-l6-v2, all-minilm-l12-v2, \ + nomic-embed-text-v1.5, mxbai-embed-large-v1, gte-large-en-v1.5, gte-base-en-v1.5" + ))), + } +} + +// --------------------------------------------------------------------------- +// OpenAIProvider +// --------------------------------------------------------------------------- + +/// OpenAI-compatible embedding API provider. +/// +/// Works with any endpoint that implements the `/v1/embeddings` contract: +/// OpenAI, Ollama, LiteLLM, vLLM, Azure OpenAI, etc. +/// +/// API key is read from `SHRIMPK_EMBEDDING_API_KEY` env var — never stored in config. +pub struct OpenAIProvider { + url: String, + model: String, + api_key: Option, + agent: ureq::Agent, + dim: usize, + display_name: String, +} + +impl OpenAIProvider { + /// Create a new OpenAI-compatible embedding provider. 
+ /// + /// The `dim` parameter must match the actual dimension of the remote model. + /// Use `EchoConfig::infer_embedding_dim()` to auto-derive it from the model name. + pub fn new(url: &str, model: &str, dim: usize) -> Result { + let api_key = std::env::var("SHRIMPK_EMBEDDING_API_KEY").ok(); + let display_name = format!("openai/{model}"); + + let agent = ureq::Agent::new_with_config( + ureq::config::Config::builder() + .timeout_global(Some(std::time::Duration::from_secs(30))) + .build(), + ); + + tracing::info!( + url = %url, + model = %model, + dim = dim, + has_api_key = api_key.is_some(), + "OpenAIProvider initialized" + ); + + Ok(Self { + url: url.trim_end_matches('/').to_string(), + model: model.to_string(), + api_key, + agent, + dim, + display_name, + }) + } + + /// Call the embedding API for a batch of texts. + fn call_api(&self, texts: &[String]) -> Result>> { + let endpoint = format!("{}/v1/embeddings", self.url); + + let body = serde_json::json!({ + "model": self.model, + "input": texts, + }); + + // Audit logging (matches HttpConsolidator pattern) + let body_bytes = serde_json::to_vec(&body).unwrap_or_default(); + tracing::info!( + target: "shrimpk::audit", + endpoint = %endpoint, + data_bytes = body_bytes.len(), + batch_size = texts.len(), + direction = "outbound", + component = "embedding_provider", + "External embedding API call" + ); + + let mut req = self.agent.post(&endpoint); + if let Some(key) = &self.api_key { + req = req.header("Authorization", &format!("Bearer {key}")); + } + + let mut resp = req.send_json(&body).map_err(|e| { + ShrimPKError::Embedding(format!("OpenAI embedding API error at {endpoint}: {e}")) + })?; + + let json: serde_json::Value = resp.body_mut().read_json().map_err(|e| { + ShrimPKError::Embedding(format!("OpenAI embedding API parse error: {e}")) + })?; + + // Extract embeddings from response: {"data": [{"embedding": [...], "index": 0}, ...]} + let data = json["data"].as_array().ok_or_else(|| { + 
ShrimPKError::Embedding(format!( + "OpenAI embedding API: missing 'data' array in response: {}", + truncate_json(&json) + )) + })?; + + // Sort by index to maintain input order + let mut indexed: Vec<(usize, Vec)> = data + .iter() + .filter_map(|item| { + let index = item["index"].as_u64()? as usize; + let embedding: Vec = item["embedding"] + .as_array()? + .iter() + .filter_map(|v| v.as_f64().map(|f| f as f32)) + .collect(); + Some((index, embedding)) + }) + .collect(); + + indexed.sort_by_key(|(i, _)| *i); + let embeddings: Vec> = indexed.into_iter().map(|(_, e)| e).collect(); + + if embeddings.len() != texts.len() { + return Err(ShrimPKError::Embedding(format!( + "OpenAI embedding API returned {} embeddings for {} inputs", + embeddings.len(), + texts.len() + ))); + } + + // Validate dimension + if let Some(first) = embeddings.first() + && first.len() != self.dim + { + return Err(ShrimPKError::Embedding(format!( + "OpenAI embedding dimension mismatch: expected {}, got {} from model '{}'", + self.dim, + first.len(), + self.model + ))); + } + + Ok(embeddings) + } +} + +impl EmbeddingProvider for OpenAIProvider { + fn embed(&mut self, text: &str) -> Result> { + let results = self.call_api(&[text.to_string()])?; + results + .into_iter() + .next() + .ok_or_else(|| ShrimPKError::Embedding("Empty OpenAI embedding result".into())) + } + + fn embed_batch(&mut self, texts: Vec) -> Result>> { + if texts.is_empty() { + return Ok(Vec::new()); + } + self.call_api(&texts) + } + + fn dimension(&self) -> usize { + self.dim + } + + fn name(&self) -> &str { + &self.display_name + } +} + +/// Truncate JSON for error messages. 
+fn truncate_json(v: &serde_json::Value) -> String { + let s = v.to_string(); + if s.len() > 200 { + format!("{}...", &s[..200]) + } else { + s + } +} + +// --------------------------------------------------------------------------- +// Factory +// --------------------------------------------------------------------------- + +/// Create an embedding provider based on config. +/// +/// Reads `config.embedding_provider` to select the backend: +/// - `Fastembed` → `FastembedProvider` with `config.embedding_model` +/// - `OpenAI` → `OpenAIProvider` with `config.embedding_api_url` + `config.embedding_model` +pub fn from_config(config: &EchoConfig) -> Result> { + match config.embedding_provider { + EmbeddingBackend::Fastembed => { + let provider = FastembedProvider::new(&config.embedding_model)?; + Ok(Box::new(provider)) + } + EmbeddingBackend::OpenAI => { + let dim = config.infer_embedding_dim(); + let provider = + OpenAIProvider::new(&config.embedding_api_url, &config.embedding_model, dim)?; + Ok(Box::new(provider)) + } + } +} + +// --------------------------------------------------------------------------- +// Tests +// --------------------------------------------------------------------------- + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn resolve_fastembed_known_models() { + let cases = [ + ("BGE-small-EN-v1.5", 384), + ("bge-small-en-v1.5", 384), + ("bge-base-en-v1.5", 768), + ("BGE-large-EN-v1.5", 1024), + ("bge-m3", 1024), + ("all-MiniLM-L6-v2", 384), + ("all-minilm-l12-v2", 384), + ("nomic-embed-text-v1.5", 768), + ("mxbai-embed-large-v1", 1024), + ("gte-large-en-v1.5", 1024), + ("gte-base-en-v1.5", 768), + ]; + for (name, expected_dim) in cases { + let (_, dim) = + resolve_fastembed_model(name).unwrap_or_else(|_| panic!("should resolve '{name}'")); + assert_eq!(dim, expected_dim, "model '{name}'"); + } + } + + #[test] + fn resolve_fastembed_unknown_errors() { + let result = resolve_fastembed_model("my-custom-model"); + assert!(result.is_err()); + 
let err = result.unwrap_err().to_string(); + assert!( + err.contains("Unknown fastembed model"), + "error should mention unknown model: {err}" + ); + } + + #[test] + fn openai_provider_initializes() { + // OpenAI provider should initialize (api_key is read from env, may or may not be set) + let provider = OpenAIProvider::new("http://localhost:11434", "nomic-embed-text", 768); + assert!(provider.is_ok()); + let p = provider.unwrap(); + assert_eq!(p.dimension(), 768); + assert_eq!(p.name(), "openai/nomic-embed-text"); + } + + #[test] + fn from_config_default_selects_fastembed() { + // This test just checks that the factory selects the right backend. + // It doesn't actually initialize the model (that would download it). + let config = EchoConfig::default(); + assert_eq!(config.embedding_provider, EmbeddingBackend::Fastembed); + assert_eq!(config.embedding_model, "BGE-small-EN-v1.5"); + } + + #[test] + #[ignore = "requires fastembed model download"] + fn fastembed_provider_default_model_works() { + let mut provider = FastembedProvider::new("BGE-small-EN-v1.5").unwrap(); + assert_eq!(provider.dimension(), 384); + + let embedding = provider.embed("Hello world").unwrap(); + assert_eq!(embedding.len(), 384); + + let batch = provider + .embed_batch(vec!["Hello".into(), "World".into()]) + .unwrap(); + assert_eq!(batch.len(), 2); + assert_eq!(batch[0].len(), 384); + } +} diff --git a/crates/shrimpk-memory/src/lib.rs b/crates/shrimpk-memory/src/lib.rs index 932d915..b55a626 100644 --- a/crates/shrimpk-memory/src/lib.rs +++ b/crates/shrimpk-memory/src/lib.rs @@ -19,6 +19,7 @@ pub mod consolidation; pub mod consolidator; pub mod echo; pub mod embedder; +pub mod embedding_provider; pub mod hebbian; pub mod importance; pub mod labels; From c979759b4af9eeb84385d12f19bbc361e0a39ef7 Mon Sep 17 00:00:00 2001 From: Lior Cohen Date: Wed, 8 Apr 2026 20:14:48 +0300 Subject: [PATCH 04/13] KS75: integrate providers into MultiEmbedder + EchoEngine with dim mismatch detection - 
MultiEmbedder::new() -> MultiEmbedder::new(config: &EchoConfig)
- Text channel delegates to Box<dyn EmbeddingProvider> from embedding_provider::from_config()
- text_dimension() returns provider.dimension() (was hardcoded 384)
- Added text_provider_name() for sidecar/logging
- EchoEngine::load() checks dimension mismatch: hard error if stored != config dim
- Model name sidecar (embedding_model.txt): written on persist(), warn on load() if changed
- Updated all MultiEmbedder::new() call sites (embedder.rs tests, echo_precision_tuning.rs, echo_token_efficiency.rs) to pass &EchoConfig::default()

Co-Authored-By: Claude Opus 4.6
---
 crates/shrimpk-memory/src/echo.rs     |  41 ++++++++-
 crates/shrimpk-memory/src/embedder.rs | 127 ++++++++++++--------------
 tests/echo_precision_tuning.rs        |  15 ++-
 tests/echo_token_efficiency.rs        |   3 +-
 4 files changed, 107 insertions(+), 79 deletions(-)

diff --git a/crates/shrimpk-memory/src/echo.rs b/crates/shrimpk-memory/src/echo.rs
index b362c90..ee62754 100644
--- a/crates/shrimpk-memory/src/echo.rs
+++ b/crates/shrimpk-memory/src/echo.rs
@@ -182,7 +182,7 @@ impl EchoEngine {
     /// Returns `ShrimPKError::Embedding` if the model fails to initialize.
     #[instrument(skip(config), fields(max_memories = config.max_memories, threshold = config.similarity_threshold))]
     pub fn new(config: EchoConfig) -> Result<Self> {
-        let mut embedder = MultiEmbedder::new()?;
+        let mut embedder = MultiEmbedder::new(&config)?;
 
         // Initialize label prototypes BEFORE wrapping embedder in Mutex (ADR-015 D4).
         // Prototype embeddings are computed once at startup.
@@ -2662,6 +2662,12 @@ impl EchoEngine { // Persist entity store sidecar (KS73) crate::persistence::save_entities(&store, &self.config.data_dir)?; + // KS75: Write embedding model name sidecar for mismatch detection on next load + if let Ok(embedder) = self.embedder.lock() { + let model_path = self.config.data_dir.join("embedding_model.txt"); + let _ = std::fs::write(&model_path, embedder.text_provider_name()); + } + Ok(()) } @@ -2675,7 +2681,7 @@ impl EchoEngine { /// Returns `ShrimPKError::Persistence` if store file is corrupted. #[instrument(skip(config), fields(data_dir = %config.data_dir.display()))] pub fn load(config: EchoConfig) -> Result { - let mut embedder = MultiEmbedder::new()?; + let mut embedder = MultiEmbedder::new(&config)?; // Initialize label prototypes (ADR-015) let mut prototypes = crate::labels::LabelPrototypes::new_empty(); @@ -2689,6 +2695,37 @@ impl EchoEngine { let store_path = config.data_dir.join("echo_store.shrm"); let mut loaded_store = EchoStore::load(&store_path)?; + // KS75: Dimension mismatch detection — hard error if stored vectors don't match config + if let Some(first_emb) = loaded_store.all_embeddings().first() { + let stored_dim = first_emb.len(); + let config_dim = embedder.text_dimension(); + if stored_dim != config_dim { + return Err(ShrimPKError::Embedding(format!( + "Embedding dimension mismatch: stored data has {stored_dim}-dim vectors \ + but current model '{}' produces {config_dim}-dim. 
\ + Either switch back to the original model or clear the store with /api/clear.", + embedder.text_provider_name() + ))); + } + } + + // KS75: Model name sidecar — warn if model changed (same dim, different model = mixed space) + let model_sidecar = config.data_dir.join("embedding_model.txt"); + if model_sidecar.exists() + && let Ok(stored_model) = std::fs::read_to_string(&model_sidecar) + { + let stored_model = stored_model.trim(); + let current_model = embedder.text_provider_name(); + if stored_model != current_model && !loaded_store.all_entries().is_empty() { + tracing::warn!( + stored_model = %stored_model, + current_model = %current_model, + "Embedding model changed since last persist. \ + Vectors from different models in the same space may degrade similarity quality." + ); + } + } + // Load community summaries sidecar (KS64) if let Err(e) = crate::persistence::load_community_summaries(&mut loaded_store, &config.data_dir) diff --git a/crates/shrimpk-memory/src/embedder.rs b/crates/shrimpk-memory/src/embedder.rs index 402d03e..d6c7207 100644 --- a/crates/shrimpk-memory/src/embedder.rs +++ b/crates/shrimpk-memory/src/embedder.rs @@ -1,34 +1,33 @@ -//! Multi-channel embedding via fastembed. +//! Multi-channel embedding with pluggable text provider (KS75). //! -//! Wraps `fastembed::TextEmbedding` with the BGE-small-EN-v1.5 model -//! for 384-dimensional sentence embeddings. Vision (CLIP 512-dim) and -//! Speech (640-dim) channels are gated behind `vision` and `speech` -//! feature flags. -//! -//! When `vision` is enabled, loads two additional models: -//! - CLIP ViT-B-32 *vision* encoder (`ImageEmbedding`) — embeds images to 512-dim. -//! - CLIP ViT-B-32 *text* encoder (`TextEmbedding`) — embeds text to the same 512-dim -//! space, enabling cross-modal text-to-image retrieval. +//! Text channel delegates to an `EmbeddingProvider` implementation selected +//! at runtime via `EchoConfig` (default: fastembed BGE-small-EN-v1.5, 384-dim). +//! 
Vision (CLIP 512-dim) and Speech (640-dim) channels are gated behind +//! `vision` and `speech` feature flags. +#[cfg(feature = "vision")] use fastembed::{EmbeddingModel, InitOptions, TextEmbedding}; -use shrimpk_core::{Result, ShrimPKError}; +#[cfg(any(feature = "vision", feature = "speech"))] +use shrimpk_core::ShrimPKError; +use shrimpk_core::{EchoConfig, EmbeddingProvider, Result}; use tracing::instrument; /// Multi-channel embedder for text, vision, and speech modalities. /// -/// Text channel (always available): BGE-small-EN-v1.5, 384-dim. +/// Text channel delegates to a pluggable `EmbeddingProvider` (KS75). /// Vision channel (feature = "vision"): CLIP ViT-B-32, 512-dim. /// Speech channel (feature = "speech"): ECAPA-TDNN (256) + Whisper-tiny encoder (384) = 640-dim. /// -/// Thread-safe: `TextEmbedding` and `ImageEmbedding` are `Send` (but not `Sync`), +/// Thread-safe: providers are `Send` (but not `Sync`), /// so share via `Mutex` or create per-thread instances. pub struct MultiEmbedder { - text: TextEmbedding, + /// Pluggable text embedding provider (fastembed or OpenAI-compatible API). + text_provider: Box, /// CLIP vision encoder — embeds images into 512-dim CLIP space. #[cfg(feature = "vision")] vision: Option, /// CLIP text encoder — embeds text into the same 512-dim CLIP space. - /// Separate from `text` (MiniLM 384-dim) because the embedding spaces are incompatible. + /// Separate from text provider because the embedding spaces are incompatible. #[cfg(feature = "vision")] vision_text: Option, /// Speech embedder — 2 ONNX models producing a 640-dim paralinguistic embedding. @@ -38,32 +37,18 @@ pub struct MultiEmbedder { } impl MultiEmbedder { - /// Initialize the multi-channel embedder. + /// Initialize the multi-channel embedder from config. /// - /// Always loads the text model (BGE-small-EN-v1.5, 384-dim). - /// When the `vision` feature is enabled, also attempts to load - /// CLIP ViT-B-32 vision + text encoders (512-dim). 
If CLIP fails - /// to initialize, vision is disabled gracefully — text still works. + /// The text channel is delegated to an `EmbeddingProvider` selected by + /// `config.embedding_provider` (default: fastembed BGE-small-EN-v1.5). + /// Vision/speech channels are unchanged (compile-time feature flags). /// /// # Errors - /// Returns `ShrimPKError::Embedding` if the *text* model fails to initialize. + /// Returns `ShrimPKError::Embedding` if the text provider fails to initialize. /// Vision model failures are logged as warnings and result in `vision = None`. - #[instrument] - pub fn new() -> Result { - let start = std::time::Instant::now(); - - let text = TextEmbedding::try_new(InitOptions::new(EmbeddingModel::BGESmallENV15)) - .map_err(|e| { - ShrimPKError::Embedding(format!("Failed to init BGE-small-EN-v1.5: {e}")) - })?; - - let elapsed = start.elapsed(); - tracing::info!( - elapsed_ms = elapsed.as_millis(), - model = "BGE-small-EN-v1.5", - dim = 384, - "MultiEmbedder initialized (text channel)" - ); + #[instrument(skip(config))] + pub fn new(config: &EchoConfig) -> Result { + let text_provider = crate::embedding_provider::from_config(config)?; #[cfg(feature = "vision")] let (vision, vision_text) = { @@ -119,7 +104,7 @@ impl MultiEmbedder { }; Ok(Self { - text, + text_provider, #[cfg(feature = "vision")] vision, #[cfg(feature = "vision")] @@ -129,7 +114,7 @@ impl MultiEmbedder { }) } - /// Embed a single text string into a 384-dimensional vector. + /// Embed a single text string into a vector of `text_dimension()` dimensions. /// /// # Errors /// Returns `ShrimPKError::Embedding` if embedding generation fails. 
@@ -137,20 +122,11 @@ impl MultiEmbedder { pub fn embed_text(&mut self, text: &str) -> Result> { let start = std::time::Instant::now(); - let results = self - .text - .embed(vec![text.to_string()], None) - .map_err(|e| ShrimPKError::Embedding(format!("Embed failed: {e}")))?; - - let embedding = results - .into_iter() - .next() - .ok_or_else(|| ShrimPKError::Embedding("Empty embedding result".into()))?; + let embedding = self.text_provider.embed(text)?; - let elapsed = start.elapsed(); tracing::debug!( dim = embedding.len(), - elapsed_us = elapsed.as_micros(), + elapsed_us = start.elapsed().as_micros(), "Single text embed complete" ); @@ -159,8 +135,7 @@ impl MultiEmbedder { /// Batch-embed multiple texts. /// - /// More efficient than calling `embed_text()` in a loop because - /// fastembed batches the ONNX inference. + /// Delegates to the provider's native batch implementation for efficiency. /// /// # Errors /// Returns `ShrimPKError::Embedding` if any embedding generation fails. @@ -173,17 +148,13 @@ impl MultiEmbedder { let start = std::time::Instant::now(); let count = texts.len(); - let results = self - .text - .embed(texts, None) - .map_err(|e| ShrimPKError::Embedding(format!("Batch embed failed: {e}")))?; + let results = self.text_provider.embed_batch(texts)?; - let elapsed = start.elapsed(); tracing::debug!( count = count, - elapsed_ms = elapsed.as_millis(), + elapsed_ms = start.elapsed().as_millis(), avg_us = if count > 0 { - elapsed.as_micros() / count as u128 + start.elapsed().as_micros() / count as u128 } else { 0 }, @@ -193,9 +164,14 @@ impl MultiEmbedder { Ok(results) } - /// Get the text embedding dimension (384 for BGE-small-EN-v1.5). + /// Get the text embedding dimension from the active provider. pub fn text_dimension(&self) -> usize { - 384 + self.text_provider.dimension() + } + + /// Get the human-readable name of the active text embedding provider. 
+ pub fn text_provider_name(&self) -> &str { + self.text_provider.name() } /// Embed an image into a 512-dimensional CLIP vector. @@ -337,14 +313,16 @@ mod tests { #[test] #[ignore = "requires fastembed model download"] fn embedder_initializes() { - let embedder = MultiEmbedder::new().expect("MultiEmbedder should init"); + let embedder = + MultiEmbedder::new(&EchoConfig::default()).expect("MultiEmbedder should init"); assert_eq!(embedder.text_dimension(), 384); } #[test] #[ignore = "requires fastembed model download"] fn embed_single_text() { - let mut embedder = MultiEmbedder::new().expect("MultiEmbedder should init"); + let mut embedder = + MultiEmbedder::new(&EchoConfig::default()).expect("MultiEmbedder should init"); let embedding = embedder.embed_text("Hello world").expect("Should embed"); assert_eq!( embedding.len(), @@ -356,7 +334,8 @@ mod tests { #[test] #[ignore = "requires fastembed model download"] fn embed_batch_texts() { - let mut embedder = MultiEmbedder::new().expect("MultiEmbedder should init"); + let mut embedder = + MultiEmbedder::new(&EchoConfig::default()).expect("MultiEmbedder should init"); let texts = vec![ "The cat sat on the mat".to_string(), "Dogs are loyal companions".to_string(), @@ -372,7 +351,8 @@ mod tests { #[test] #[ignore = "requires fastembed model download"] fn similar_texts_have_higher_similarity() { - let mut embedder = MultiEmbedder::new().expect("MultiEmbedder should init"); + let mut embedder = + MultiEmbedder::new(&EchoConfig::default()).expect("MultiEmbedder should init"); let cat = embedder.embed_text("The cat sat on the mat").unwrap(); let kitten = embedder.embed_text("A kitten rests on a rug").unwrap(); let code = embedder @@ -392,7 +372,8 @@ mod tests { #[test] #[ignore = "requires fastembed model download"] fn embed_batch_empty_returns_empty() { - let mut embedder = MultiEmbedder::new().expect("MultiEmbedder should init"); + let mut embedder = + MultiEmbedder::new(&EchoConfig::default()).expect("MultiEmbedder should 
init"); let embeddings = embedder .embed_batch(Vec::new()) .expect("Should handle empty"); @@ -406,7 +387,8 @@ mod tests { #[test] #[ignore = "requires CLIP model download (~352 MB)"] fn clip_vision_initializes() { - let embedder = MultiEmbedder::new().expect("MultiEmbedder should init"); + let embedder = + MultiEmbedder::new(&EchoConfig::default()).expect("MultiEmbedder should init"); assert!(embedder.has_vision(), "CLIP vision should be available"); assert_eq!(embedder.vision_dimension(), 512); } @@ -415,7 +397,8 @@ mod tests { #[test] #[ignore = "requires CLIP model download (~352 MB)"] fn embed_image_produces_512_dim() { - let mut embedder = MultiEmbedder::new().expect("MultiEmbedder should init"); + let mut embedder = + MultiEmbedder::new(&EchoConfig::default()).expect("MultiEmbedder should init"); // Create a minimal 2x2 red PNG image let png_data = create_test_png(2, 2, [255, 0, 0]); @@ -443,7 +426,8 @@ mod tests { #[test] #[ignore = "requires CLIP model download (~352 MB)"] fn embed_text_for_vision_produces_512_dim() { - let mut embedder = MultiEmbedder::new().expect("MultiEmbedder should init"); + let mut embedder = + MultiEmbedder::new(&EchoConfig::default()).expect("MultiEmbedder should init"); let embedding = embedder .embed_text_for_vision("a photo of a cat") @@ -463,7 +447,8 @@ mod tests { fn clip_cross_modal_similarity() { // CLIP's key property: text and image embeddings in the same space // should have positive similarity for matching concepts. 
- let mut embedder = MultiEmbedder::new().expect("MultiEmbedder should init"); + let mut embedder = + MultiEmbedder::new(&EchoConfig::default()).expect("MultiEmbedder should init"); // Embed a red image let red_png = create_test_png(32, 32, [255, 0, 0]); @@ -496,7 +481,7 @@ mod tests { #[ignore = "requires CLIP model download (~352 MB)"] fn clip_init_latency_under_5s() { let start = std::time::Instant::now(); - let _embedder = MultiEmbedder::new().expect("Should init"); + let _embedder = MultiEmbedder::new(&EchoConfig::default()).expect("Should init"); let elapsed = start.elapsed(); assert!( elapsed.as_secs() < 10, // generous to account for cold cache diff --git a/tests/echo_precision_tuning.rs b/tests/echo_precision_tuning.rs index e524521..f043fea 100644 --- a/tests/echo_precision_tuning.rs +++ b/tests/echo_precision_tuning.rs @@ -588,7 +588,8 @@ async fn threshold_range_sweep() { println!(); // Collect similarity scores for all pairs using raw embeddings - let mut embedder = MultiEmbedder::new().expect("embedder init"); + let mut embedder = + MultiEmbedder::new(&shrimpk_core::EchoConfig::default()).expect("embedder init"); // Build memory embeddings map let mut mem_embeddings: Vec<(&str, Vec)> = Vec::new(); @@ -730,7 +731,8 @@ async fn threshold_range_sweep() { #[tokio::test] #[ignore = "requires fastembed model download"] async fn query_formulation_analysis() { - let mut embedder = MultiEmbedder::new().expect("embedder init"); + let mut embedder = + MultiEmbedder::new(&shrimpk_core::EchoConfig::default()).expect("embedder init"); let positives = positive_pairs(); @@ -849,7 +851,8 @@ async fn query_formulation_analysis() { #[tokio::test] #[ignore = "requires fastembed model download"] async fn memory_formulation_analysis() { - let mut embedder = MultiEmbedder::new().expect("embedder init"); + let mut embedder = + MultiEmbedder::new(&shrimpk_core::EchoConfig::default()).expect("embedder init"); println!(); 
println!("===================================================================="); @@ -1198,7 +1201,8 @@ async fn context_window_simulation() { #[tokio::test] #[ignore = "requires fastembed model download"] async fn recommended_configuration() { - let mut embedder = MultiEmbedder::new().expect("embedder init"); + let mut embedder = + MultiEmbedder::new(&shrimpk_core::EchoConfig::default()).expect("embedder init"); let positives = positive_pairs(); let negatives = negative_pairs(); @@ -1533,7 +1537,8 @@ async fn recommended_configuration() { #[tokio::test] #[ignore = "requires fastembed model download"] async fn hardest_pairs_deep_dive() { - let mut embedder = MultiEmbedder::new().expect("embedder init"); + let mut embedder = + MultiEmbedder::new(&shrimpk_core::EchoConfig::default()).expect("embedder init"); println!(); println!("===================================================================="); diff --git a/tests/echo_token_efficiency.rs b/tests/echo_token_efficiency.rs index 0ab88c4..b69d394 100644 --- a/tests/echo_token_efficiency.rs +++ b/tests/echo_token_efficiency.rs @@ -919,7 +919,8 @@ async fn vllm_throughput_projection() { #[tokio::test] #[ignore = "requires fastembed model download"] async fn context_quality_comparison() { - let mut embedder = MultiEmbedder::new().expect("embedder should initialize"); + let mut embedder = MultiEmbedder::new(&shrimpk_core::EchoConfig::default()) + .expect("embedder should initialize"); let scenarios = scenarios(); println!("\n{}", "=".repeat(90)); From c4fefb8d758713dd70e2e251fc4099c4792fafd5 Mon Sep 17 00:00:00 2001 From: Lior Cohen Date: Wed, 8 Apr 2026 20:28:22 +0300 Subject: [PATCH 05/13] KS75: update daemon/MCP/CLI config endpoints for embedding settings - Daemon config_show: add embedding_provider, embedding_model, embedding_api_url - Daemon config_set: support setting all 3 embedding fields with validation - MCP config_show: add Embedding Provider section with source tracking - MCP config_set: add 
embedding_provider, embedding_model, embedding_api_url - CLI config show: add Embedding Provider section - CLI config set: add embedding_provider, embedding_model, embedding_api_url Co-Authored-By: Claude Opus 4.6 --- cli/src/main.rs | 28 ++++++++++++++++++++++++++++ crates/shrimpk-daemon/src/routes.rs | 14 +++++++++++++- crates/shrimpk-mcp/src/tools.rs | 26 ++++++++++++++++++++++++++ 3 files changed, 67 insertions(+), 1 deletion(-) diff --git a/cli/src/main.rs b/cli/src/main.rs index 7f20dcc..b163793 100644 --- a/cli/src/main.rs +++ b/cli/src/main.rs @@ -850,6 +850,29 @@ fn cmd_config_show(config: &EchoConfig) { "auto" } ); + println!(); + println!(" Embedding Provider:"); + println!( + " {:25} {:>15} {}", + "embedding_provider", + config.embedding_provider.to_string(), + source( + "SHRIMPK_EMBEDDING_PROVIDER", + fc.embedding_provider.is_some() + ) + ); + println!( + " {:25} {:>15} {}", + "embedding_model", + &config.embedding_model, + source("SHRIMPK_EMBEDDING_MODEL", fc.embedding_model.is_some()) + ); + println!( + " {:25} {:>15} {}", + "embedding_api_url", + &config.embedding_api_url, + source("SHRIMPK_EMBEDDING_API_URL", fc.embedding_api_url.is_some()) + ); } fn cmd_config_set(key: &str, value: &str) -> anyhow::Result<()> { @@ -872,6 +895,11 @@ fn cmd_config_set(key: &str, value: &str) -> anyhow::Result<()> { "enrichment_model" => fc.enrichment_model = Some(value.to_string()), "consolidation_provider" => fc.consolidation_provider = Some(value.to_string()), "max_facts_per_memory" => fc.max_facts_per_memory = Some(value.parse()?), + "embedding_provider" => { + fc.embedding_provider = Some(value.parse().map_err(|e: String| anyhow::anyhow!(e))?) 
+ } + "embedding_model" => fc.embedding_model = Some(value.to_string()), + "embedding_api_url" => fc.embedding_api_url = Some(value.to_string()), other => anyhow::bail!("Unknown config key: \"{other}\""), } diff --git a/crates/shrimpk-daemon/src/routes.rs b/crates/shrimpk-daemon/src/routes.rs index f969ef7..52e34a1 100644 --- a/crates/shrimpk-daemon/src/routes.rs +++ b/crates/shrimpk-daemon/src/routes.rs @@ -346,7 +346,10 @@ pub async fn config_show(State(state): State) -> Json { "proxy_max_echo_results": c.proxy_max_echo_results, "hebbian_half_life_secs": c.hebbian_half_life_secs, "hebbian_prune_threshold": c.hebbian_prune_threshold, - "proxy_max_conversation_turns": c.proxy_max_conversation_turns + "proxy_max_conversation_turns": c.proxy_max_conversation_turns, + "embedding_provider": c.embedding_provider.to_string(), + "embedding_model": c.embedding_model, + "embedding_api_url": c.embedding_api_url })) } @@ -440,6 +443,15 @@ pub async fn config_set( ) })?) } + "embedding_provider" => { + fc.embedding_provider = Some( + req.value + .parse() + .map_err(|e: String| (StatusCode::BAD_REQUEST, Json(json!({"error": e}))))?, + ) + } + "embedding_model" => fc.embedding_model = Some(req.value.clone()), + "embedding_api_url" => fc.embedding_api_url = Some(req.value.clone()), other => { return Err(( StatusCode::BAD_REQUEST, diff --git a/crates/shrimpk-mcp/src/tools.rs b/crates/shrimpk-mcp/src/tools.rs index 359025e..3ec2ac2 100644 --- a/crates/shrimpk-mcp/src/tools.rs +++ b/crates/shrimpk-mcp/src/tools.rs @@ -701,6 +701,29 @@ pub fn handle_config_show(config: &EchoConfig) -> Result { } ), String::new(), + " Embedding Provider:".to_string(), + format!( + " {:25} {:>15} {}", + "embedding_provider", + config.embedding_provider.to_string(), + source( + "SHRIMPK_EMBEDDING_PROVIDER", + fc.embedding_provider.is_some() + ) + ), + format!( + " {:25} {:>15} {}", + "embedding_model", + truncate(&config.embedding_model, 30), + source("SHRIMPK_EMBEDDING_MODEL", 
fc.embedding_model.is_some()) + ), + format!( + " {:25} {:>15} {}", + "embedding_api_url", + truncate(&config.embedding_api_url, 30), + source("SHRIMPK_EMBEDDING_API_URL", fc.embedding_api_url.is_some()) + ), + String::new(), " Intelligence Engine:".to_string(), format!( " {:25} {:>15} {}", @@ -817,6 +840,9 @@ pub fn handle_config_set(args: &Value) -> Result { "use_full_actr_history" => { fc.use_full_actr_history = Some(value.parse().map_err(|_| "Invalid boolean")?) } + "embedding_provider" => fc.embedding_provider = Some(value.parse().map_err(|e: String| e)?), + "embedding_model" => fc.embedding_model = Some(value.to_string()), + "embedding_api_url" => fc.embedding_api_url = Some(value.to_string()), other => return Err(format!("Unknown config key: \"{other}\"")), } From bda744f6845e1ce3021c06ffd917d504b1c0d99e Mon Sep 17 00:00:00 2001 From: Lior Cohen Date: Wed, 8 Apr 2026 22:07:16 +0300 Subject: [PATCH 06/13] =?UTF-8?q?KS75:=20fix=20Greptile=20P1s=20=E2=80=94?= =?UTF-8?q?=20dim=20override=20+=20block=5Fin=5Fplace=20for=20API=20embedd?= =?UTF-8?q?ings?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - Fix SHRIMPK_EMBEDDING_DIM env var: apply value, not just check existence - Track file-config embedding_dim override to prevent infer_embedding_dim() from silently overwriting user-set dimensions - Add embed_blocking() helper using tokio::task::block_in_place to prevent worker-thread starvation with API-based embedding providers (30s timeout) - Convert all async embedder call sites (store, echo, vision, audio, entity) to use embed_blocking() Co-Authored-By: Claude Opus 4.6 (1M context) --- crates/shrimpk-core/src/config.rs | 10 ++- crates/shrimpk-memory/src/echo.rs | 118 ++++++++++++------------------ 2 files changed, 55 insertions(+), 73 deletions(-) diff --git a/crates/shrimpk-core/src/config.rs b/crates/shrimpk-core/src/config.rs index bcfa14f..02c321d 100644 --- a/crates/shrimpk-core/src/config.rs +++ 
b/crates/shrimpk-core/src/config.rs @@ -737,6 +737,7 @@ pub fn resolve_config() -> crate::Result { let mut config = EchoConfig::auto_detect(); // Layer 2: file overrides + let mut dim_set_by_file = false; if let Some(fc) = load_config_file()? { if let Some(v) = fc.max_memories { config.max_memories = v; @@ -758,6 +759,7 @@ pub fn resolve_config() -> crate::Result { } if let Some(v) = fc.embedding_dim { config.embedding_dim = v; + dim_set_by_file = true; } if let Some(v) = fc.use_lsh { config.use_lsh = v; @@ -960,9 +962,11 @@ pub fn resolve_config() -> crate::Result { config.embedding_api_url = v; } - // Auto-infer embedding_dim from model name if not explicitly overridden - let dim_explicitly_set = std::env::var("SHRIMPK_EMBEDDING_DIM").is_ok(); - if !dim_explicitly_set { + // Auto-infer embedding_dim from model name unless explicitly overridden + // by either env var (Layer 3) or config file (Layer 2). + if let Some(v) = env_usize("SHRIMPK_EMBEDDING_DIM")? { + config.embedding_dim = v; + } else if !dim_set_by_file { config.embedding_dim = config.infer_embedding_dim(); } diff --git a/crates/shrimpk-memory/src/echo.rs b/crates/shrimpk-memory/src/echo.rs index ee62754..0baf605 100644 --- a/crates/shrimpk-memory/src/echo.rs +++ b/crates/shrimpk-memory/src/echo.rs @@ -346,12 +346,7 @@ impl EchoEngine { // - Reformulated text if available (structured form embeds better) // - Otherwise original text (semantic meaning preserved) let embed_text = reformulated.as_deref().unwrap_or(text); - let embedding = { - let mut embedder = self.embedder.lock().map_err(|e| { - ShrimPKError::Embedding(format!("MultiEmbedder lock poisoned: {e}")) - })?; - embedder.embed_text(embed_text)? - }; + let embedding = self.embed_blocking(|e| e.embed_text(embed_text))?; // 4. Build entry with auto-categorization for adaptive decay let category = self.reformulator.categorize(text); @@ -559,12 +554,7 @@ impl EchoEngine { } // 1. 
Embed image with CLIP - let vision_embedding = { - let mut embedder = self.embedder.lock().map_err(|e| { - ShrimPKError::Embedding(format!("MultiEmbedder lock poisoned: {e}")) - })?; - embedder.embed_image(image_data)? - }; + let vision_embedding = self.embed_blocking(|e| e.embed_image(image_data))?; let vision_embedding = vision_embedding.ok_or_else(|| { ShrimPKError::Embedding("Vision model not available — cannot embed image".into()) @@ -573,10 +563,7 @@ impl EchoEngine { // 2. Build content and optional text embedding for cross-modal recall let content = description.unwrap_or("[image]").to_string(); let text_embedding = if let Some(desc) = description { - let mut embedder = self.embedder.lock().map_err(|e| { - ShrimPKError::Embedding(format!("MultiEmbedder lock poisoned: {e}")) - })?; - embedder.embed_text(desc)? + self.embed_blocking(|e| e.embed_text(desc))? } else { Vec::new() }; @@ -700,12 +687,7 @@ impl EchoEngine { } // 1. Embed audio with speech stack - let speech_embedding = { - let mut embedder = self.embedder.lock().map_err(|e| { - ShrimPKError::Embedding(format!("MultiEmbedder lock poisoned: {e}")) - })?; - embedder.embed_audio(pcm_f32, sample_rate)? - }; + let speech_embedding = self.embed_blocking(|e| e.embed_audio(pcm_f32, sample_rate))?; let speech_embedding = speech_embedding.ok_or_else(|| { ShrimPKError::Embedding("Speech models not available — cannot embed audio".into()) @@ -714,10 +696,7 @@ impl EchoEngine { // 2. Build content and optional text embedding for cross-modal recall let content = description.unwrap_or("[audio]").to_string(); let text_embedding = if let Some(desc) = description { - let mut embedder = self.embedder.lock().map_err(|e| { - ShrimPKError::Embedding(format!("MultiEmbedder lock poisoned: {e}")) - })?; - embedder.embed_text(desc)? + self.embed_blocking(|e| e.embed_text(desc))? 
} else { Vec::new() }; @@ -840,35 +819,32 @@ impl EchoEngine { let reformulated = self.reformulator.reformulate(text_for_reformulation); let embed_text = reformulated.as_deref().unwrap_or(text); - let (text_embedding, vision_embedding, speech_embedding) = { - let mut embedder = self.embedder.lock().map_err(|e| { - ShrimPKError::Embedding(format!("MultiEmbedder lock poisoned: {e}")) - })?; + let (text_embedding, vision_embedding, speech_embedding) = + self.embed_blocking(|embedder| { + let text_emb = embedder.embed_text(embed_text)?; - let text_emb = embedder.embed_text(embed_text)?; - - // 2. Optional vision embedding - #[cfg(feature = "vision")] - let vis_emb = if let Some(img) = image_data { - embedder.embed_image(img)? - } else { - None - }; - #[cfg(not(feature = "vision"))] - let vis_emb: Option> = None; + // 2. Optional vision embedding + #[cfg(feature = "vision")] + let vis_emb = if let Some(img) = image_data { + embedder.embed_image(img)? + } else { + None + }; + #[cfg(not(feature = "vision"))] + let vis_emb: Option> = None; - // 3. Optional speech embedding - #[cfg(feature = "speech")] - let speech_emb = if let Some((pcm, sr)) = audio_pcm { - embedder.embed_audio(pcm, sr)? - } else { - None - }; - #[cfg(not(feature = "speech"))] - let speech_emb: Option> = None; + // 3. Optional speech embedding + #[cfg(feature = "speech")] + let speech_emb = if let Some((pcm, sr)) = audio_pcm { + embedder.embed_audio(pcm, sr)? + } else { + None + }; + #[cfg(not(feature = "speech"))] + let speech_emb: Option> = None; - (text_emb, vis_emb, speech_emb) - }; + Ok((text_emb, vis_emb, speech_emb)) + })?; // 4. Build entry with all embeddings let category = self.reformulator.categorize(text); @@ -1070,12 +1046,7 @@ impl EchoEngine { }; // 1. Embed the (possibly expanded) query - let query_embedding = { - let mut embedder = self.embedder.lock().map_err(|e| { - ShrimPKError::Embedding(format!("MultiEmbedder lock poisoned: {e}")) - })?; - embedder.embed_text(&effective_query)? 
- }; + let query_embedding = self.embed_blocking(|e| e.embed_text(&effective_query))?; // 2. Bloom filter pre-check — skip everything if no fingerprints match. // Bypass for small stores (< 50 entries) where Bloom adds risk without benefit. @@ -1817,12 +1788,7 @@ impl EchoEngine { let start = std::time::Instant::now(); // 1. Embed query with CLIP text encoder - let query_embedding = { - let mut embedder = self.embedder.lock().map_err(|e| { - ShrimPKError::Embedding(format!("MultiEmbedder lock poisoned: {e}")) - })?; - embedder.embed_text_for_vision(query)? - }; + let query_embedding = self.embed_blocking(|e| e.embed_text_for_vision(query))?; let query_embedding = match query_embedding { Some(emb) => emb, @@ -2300,12 +2266,7 @@ impl EchoEngine { } // Embed the entity name for ranking - let mut embedder = self - .embedder - .lock() - .map_err(|e| ShrimPKError::Memory(format!("lock: {e}")))?; - let query_emb = embedder.embed_text(entity)?; - drop(embedder); + let query_emb = self.embed_blocking(|e| e.embed_text(entity))?; let mut scored: Vec<(usize, f32)> = indices .iter() @@ -3004,6 +2965,23 @@ impl EchoEngine { } } + /// Lock the embedder and run a blocking embedding operation. + /// + /// Uses `tokio::task::block_in_place` to inform the Tokio scheduler that + /// this thread will block, preventing worker-thread starvation. Critical + /// for API-based providers (OpenAI) where network I/O can take seconds; + /// harmless for local fastembed calls (~5ms). + fn embed_blocking(&self, f: F) -> Result + where + F: FnOnce(&mut MultiEmbedder) -> Result, + { + let mut embedder = self + .embedder + .lock() + .map_err(|e| ShrimPKError::Embedding(format!("MultiEmbedder lock poisoned: {e}")))?; + tokio::task::block_in_place(|| f(&mut embedder)) + } + /// Test-only: generate an embedding for text using the engine's embedder. 
/// /// Provides access to the same embedding model used by `store()` so that From 32c4047692cd6d548a844c255a2513a93312c624 Mon Sep 17 00:00:00 2001 From: Lior Cohen Date: Wed, 8 Apr 2026 22:17:15 +0300 Subject: [PATCH 07/13] KS75: fix clippy field_reassign_with_default + add missing infer_embedding_dim models - Rewrite infer_embedding_dim_known_models and infer_embedding_dim_unknown_falls_back tests to use struct-init syntax with ..Default::default() instead of field reassignment, fixing clippy field_reassign_with_default lint - Add bge-m3 (1024), gte-large (1024), gte-base (768) branches to infer_embedding_dim() to match resolve_fastembed_model in shrimpk-memory - Expand test coverage to include the three new model entries Co-Authored-By: Claude Opus 4.6 (1M context) --- crates/shrimpk-core/src/config.rs | 52 ++++++++++++++++++------------- 1 file changed, 30 insertions(+), 22 deletions(-) diff --git a/crates/shrimpk-core/src/config.rs b/crates/shrimpk-core/src/config.rs index 02c321d..80b1521 100644 --- a/crates/shrimpk-core/src/config.rs +++ b/crates/shrimpk-core/src/config.rs @@ -552,6 +552,9 @@ impl EchoConfig { s if s.contains("bge-small") => 384, s if s.contains("bge-base") => 768, s if s.contains("bge-large") => 1024, + s if s.contains("bge-m3") => 1024, + s if s.contains("gte-large") => 1024, + s if s.contains("gte-base") => 768, s if s.contains("minilm-l6") => 384, s if s.contains("minilm-l12") => 384, // OpenAI @@ -1429,32 +1432,37 @@ mod tests { #[test] fn infer_embedding_dim_known_models() { - let mut config = EchoConfig::default(); - - config.embedding_model = "BGE-small-EN-v1.5".into(); - assert_eq!(config.infer_embedding_dim(), 384); - - config.embedding_model = "BGE-base-EN-v1.5".into(); - assert_eq!(config.infer_embedding_dim(), 768); - - config.embedding_model = "BGE-large-EN-v1.5".into(); - assert_eq!(config.infer_embedding_dim(), 1024); - - config.embedding_model = "text-embedding-3-small".into(); - assert_eq!(config.infer_embedding_dim(), 
1536); - - config.embedding_model = "text-embedding-3-large".into(); - assert_eq!(config.infer_embedding_dim(), 3072); - - config.embedding_model = "nomic-embed-text".into(); - assert_eq!(config.infer_embedding_dim(), 768); + let cases: &[(&str, usize)] = &[ + ("BGE-small-EN-v1.5", 384), + ("BGE-base-EN-v1.5", 768), + ("BGE-large-EN-v1.5", 1024), + ("bge-m3", 1024), + ("gte-large-en-v1.5", 1024), + ("gte-base-en-v1.5", 768), + ("text-embedding-3-small", 1536), + ("text-embedding-3-large", 3072), + ("nomic-embed-text", 768), + ]; + for &(model, expected_dim) in cases { + let config = EchoConfig { + embedding_model: model.into(), + ..Default::default() + }; + assert_eq!( + config.infer_embedding_dim(), + expected_dim, + "model '{model}'" + ); + } } #[test] fn infer_embedding_dim_unknown_falls_back() { - let mut config = EchoConfig::default(); - config.embedding_model = "my-custom-model".into(); - config.embedding_dim = 512; + let config = EchoConfig { + embedding_model: "my-custom-model".into(), + embedding_dim: 512, + ..Default::default() + }; assert_eq!( config.infer_embedding_dim(), 512, From dbb0c667b0b1cca61213629687878fc511d49e86 Mon Sep 17 00:00:00 2001 From: Lior Cohen Date: Wed, 8 Apr 2026 23:15:23 +0300 Subject: [PATCH 08/13] =?UTF-8?q?KS75:=20fix=20LSH=20dimension=20=E2=80=94?= =?UTF-8?q?=20use=20embedder.text=5Fdimension()=20not=20config.embedding?= =?UTF-8?q?=5Fdim?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - EchoEngine::new(): capture text_dim, vision_dimension(), speech_dimension() from embedder before it is moved into Mutex - EchoEngine::load(): same fix for all three LSH index rebuilds - Prevents dimension mismatch when config.embedding_dim is stale or auto-inference failed but the embedder knows its real output dimension - All 6 CosineHash::new calls in echo.rs now use embedder methods Co-Authored-By: Claude Opus 4.6 (1M context) --- crates/shrimpk-memory/src/echo.rs | 69 
++++++++++++++++++------------- 1 file changed, 40 insertions(+), 29 deletions(-) diff --git a/crates/shrimpk-memory/src/echo.rs b/crates/shrimpk-memory/src/echo.rs index 0baf605..d1fbc50 100644 --- a/crates/shrimpk-memory/src/echo.rs +++ b/crates/shrimpk-memory/src/echo.rs @@ -194,13 +194,43 @@ impl EchoEngine { let pii_filter = PiiFilter::new(); let reformulator = MemoryReformulator::new(); let store = RwLock::new(EchoStore::new()); - let text_lsh = CosineHash::new(config.embedding_dim, 16, 10); + // KS75: use embedder's actual dimension, not config's possibly-stale value + let text_dim = embedder.text_dimension(); + let text_lsh = CosineHash::new(text_dim, 16, 10); let bloom = TopicFilter::new(config.max_memories, 0.01); + // Pre-build vision/speech LSH before embedder is moved into Mutex + #[cfg(feature = "vision")] + let vision_lsh_init = if config + .enabled_modalities + .contains(&shrimpk_core::Modality::Vision) + { + Some(Mutex::new(CosineHash::new( + embedder.vision_dimension(), + 16, + 10, + ))) + } else { + None + }; + #[cfg(feature = "speech")] + let speech_lsh_init = if config + .enabled_modalities + .contains(&shrimpk_core::Modality::Speech) + { + Some(Mutex::new(CosineHash::new( + embedder.speech_dimension(), + 16, + 10, + ))) + } else { + None + }; + tracing::info!( max_memories = config.max_memories, threshold = config.similarity_threshold, - dim = config.embedding_dim, + dim = text_dim, use_lsh = config.use_lsh, use_bloom = config.use_bloom, "EchoEngine initialized (empty store)" @@ -213,31 +243,9 @@ impl EchoEngine { store, text_lsh: Mutex::new(text_lsh), #[cfg(feature = "vision")] - vision_lsh: if config - .enabled_modalities - .contains(&shrimpk_core::Modality::Vision) - { - Some(Mutex::new(CosineHash::new( - config.vision_embedding_dim, - 16, - 10, - ))) - } else { - None - }, + vision_lsh: vision_lsh_init, #[cfg(feature = "speech")] - speech_lsh: if config - .enabled_modalities - .contains(&shrimpk_core::Modality::Speech) - { - 
Some(Mutex::new(CosineHash::new( - config.speech_embedding_dim, - 16, - 10, - ))) - } else { - None - }, + speech_lsh: speech_lsh_init, bloom: RwLock::new(bloom), bloom_dirty: Mutex::new(false), pii_filter, @@ -2699,8 +2707,11 @@ impl EchoEngine { tracing::warn!(error = %e, "Failed to load entities, continuing without"); } + // KS75: use embedder's actual dimension, not config's possibly-stale value + let text_dim = embedder.text_dimension(); + // Rebuild text LSH index from loaded embeddings - let mut text_lsh = CosineHash::new(config.embedding_dim, 16, 10); + let mut text_lsh = CosineHash::new(text_dim, 16, 10); if config.use_lsh { for (i, embedding) in loaded_store.all_embeddings().iter().enumerate() { text_lsh.insert(i as u32, embedding); @@ -2736,7 +2747,7 @@ impl EchoEngine { .enabled_modalities .contains(&shrimpk_core::Modality::Vision) { - let mut vlsh = CosineHash::new(config.vision_embedding_dim, 16, 10); + let mut vlsh = CosineHash::new(embedder.vision_dimension(), 16, 10); let mut vision_count = 0usize; for (i, entry) in loaded_store.all_entries().iter().enumerate() { if let Some(ref ve) = entry.vision_embedding { @@ -2761,7 +2772,7 @@ impl EchoEngine { .enabled_modalities .contains(&shrimpk_core::Modality::Speech) { - let mut slsh = CosineHash::new(config.speech_embedding_dim, 16, 10); + let mut slsh = CosineHash::new(embedder.speech_dimension(), 16, 10); let mut speech_count = 0usize; for (i, entry) in loaded_store.all_entries().iter().enumerate() { if let Some(ref se) = entry.speech_embedding { From 2c4a4842fb393b9e64046f94dfc6784051da7491 Mon Sep 17 00:00:00 2001 From: Lior Cohen Date: Wed, 8 Apr 2026 23:31:10 +0300 Subject: [PATCH 09/13] =?UTF-8?q?KS75:=20fix=20block=5Fin=5Fplace=20panic?= =?UTF-8?q?=20=E2=80=94=20detect=20runtime=20flavor,=20fallback=20for=20cu?= =?UTF-8?q?rrent=5Fthread?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - Replace unconditional `tokio::task::block_in_place` in 
`embed_blocking()` with runtime-flavor detection via `Handle::try_current()` - Multi-thread runtime (daemon): still uses block_in_place to prevent worker starvation - Current-thread runtime (#[tokio::test]) or non-Tokio context: falls back to direct call, avoiding the panic - Updated doc comment explaining the three-branch behavior Co-Authored-By: Claude Opus 4.6 (1M context) --- crates/shrimpk-memory/src/echo.rs | 19 ++++++++++++++----- 1 file changed, 14 insertions(+), 5 deletions(-) diff --git a/crates/shrimpk-memory/src/echo.rs b/crates/shrimpk-memory/src/echo.rs index d1fbc50..fc48c73 100644 --- a/crates/shrimpk-memory/src/echo.rs +++ b/crates/shrimpk-memory/src/echo.rs @@ -2978,10 +2978,12 @@ impl EchoEngine { /// Lock the embedder and run a blocking embedding operation. /// - /// Uses `tokio::task::block_in_place` to inform the Tokio scheduler that - /// this thread will block, preventing worker-thread starvation. Critical - /// for API-based providers (OpenAI) where network I/O can take seconds; - /// harmless for local fastembed calls (~5ms). + /// On a **multi-thread** Tokio runtime (the daemon) this uses + /// `tokio::task::block_in_place` to inform the scheduler that the + /// current thread will block, preventing worker-thread starvation. + /// On a **current-thread** runtime (`#[tokio::test]`) or outside + /// Tokio entirely (sync tests, CLI) we call `f` directly, because + /// `block_in_place` panics on a single-threaded runtime. fn embed_blocking(&self, f: F) -> Result where F: FnOnce(&mut MultiEmbedder) -> Result, @@ -2990,7 +2992,14 @@ impl EchoEngine { .embedder .lock() .map_err(|e| ShrimPKError::Embedding(format!("MultiEmbedder lock poisoned: {e}")))?; - tokio::task::block_in_place(|| f(&mut embedder)) + // Use block_in_place on multi-thread runtime to prevent worker starvation. + // Fall back to direct call on current_thread runtime (tests) or outside Tokio. 
+ match tokio::runtime::Handle::try_current() { + Ok(handle) if handle.runtime_flavor() == tokio::runtime::RuntimeFlavor::MultiThread => { + tokio::task::block_in_place(|| f(&mut embedder)) + } + _ => f(&mut embedder), + } } /// Test-only: generate an embedding for text using the engine's embedder. From 6db35dc2fbe71a66084ae771d988d6e76cdaaec1 Mon Sep 17 00:00:00 2001 From: Lior Cohen Date: Thu, 9 Apr 2026 00:29:22 +0300 Subject: [PATCH 10/13] =?UTF-8?q?KS75:=20move=20Mutex=20lock=20inside=20bl?= =?UTF-8?q?ock=5Fin=5Fplace=20=E2=80=94=20prevent=20contended=20lock=20sta?= =?UTF-8?q?rvation?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - embed_blocking() previously acquired the MultiEmbedder Mutex before entering block_in_place, so a second task blocking on lock() would silently starve a Tokio worker thread for up to 30s. - Move both lock acquisition and inference into the block_in_place closure so Tokio can schedule other tasks while waiting for the lock. - Non-multi-thread fallback path unchanged (tests, CLI, single-thread). Co-Authored-By: Claude Opus 4.6 (1M context) --- crates/shrimpk-memory/src/echo.rs | 22 +++++++++++++++------- 1 file changed, 15 insertions(+), 7 deletions(-) diff --git a/crates/shrimpk-memory/src/echo.rs b/crates/shrimpk-memory/src/echo.rs index fc48c73..9bafce4 100644 --- a/crates/shrimpk-memory/src/echo.rs +++ b/crates/shrimpk-memory/src/echo.rs @@ -2988,17 +2988,25 @@ impl EchoEngine { where F: FnOnce(&mut MultiEmbedder) -> Result, { - let mut embedder = self - .embedder - .lock() - .map_err(|e| ShrimPKError::Embedding(format!("MultiEmbedder lock poisoned: {e}")))?; + let embedder_mutex = &self.embedder; // Use block_in_place on multi-thread runtime to prevent worker starvation. - // Fall back to direct call on current_thread runtime (tests) or outside Tokio. 
+ // Both the lock acquisition AND inference run inside block_in_place so that + // a contended lock() does not silently block a Tokio worker thread. match tokio::runtime::Handle::try_current() { Ok(handle) if handle.runtime_flavor() == tokio::runtime::RuntimeFlavor::MultiThread => { - tokio::task::block_in_place(|| f(&mut embedder)) + tokio::task::block_in_place(|| { + let mut embedder = embedder_mutex.lock().map_err(|e| { + ShrimPKError::Embedding(format!("MultiEmbedder lock poisoned: {e}")) + })?; + f(&mut embedder) + }) + } + _ => { + let mut embedder = embedder_mutex.lock().map_err(|e| { + ShrimPKError::Embedding(format!("MultiEmbedder lock poisoned: {e}")) + })?; + f(&mut embedder) } - _ => f(&mut embedder), } } From 29f2469fcae338ad196ca745e4dc7eadcfe27e13 Mon Sep 17 00:00:00 2001 From: Lior Cohen Date: Thu, 9 Apr 2026 00:45:10 +0300 Subject: [PATCH 11/13] KS75: add Blocking safety docs to OpenAIProvider for Greptile P1 - Document that OpenAIProvider uses synchronous ureq HTTP on both the struct and call_api method, with explicit guidance that async callers must go through EchoEngine::embed_blocking(). - Satisfies Greptile flag on lines 200-204 by making the blocking contract visible in the provider file itself. Co-Authored-By: Claude Opus 4.6 (1M context) --- crates/shrimpk-memory/src/embedding_provider.rs | 17 ++++++++++++++++- 1 file changed, 16 insertions(+), 1 deletion(-) diff --git a/crates/shrimpk-memory/src/embedding_provider.rs b/crates/shrimpk-memory/src/embedding_provider.rs index 4cc2254..aa3b617 100644 --- a/crates/shrimpk-memory/src/embedding_provider.rs +++ b/crates/shrimpk-memory/src/embedding_provider.rs @@ -128,7 +128,14 @@ fn resolve_fastembed_model(name: &str) -> Result<(EmbeddingModel, usize)> { /// Works with any endpoint that implements the `/v1/embeddings` contract: /// OpenAI, Ollama, LiteLLM, vLLM, Azure OpenAI, etc. /// -/// API key is read from `SHRIMPK_EMBEDDING_API_KEY` env var — never stored in config. 
+/// API key is read from `SHRIMPK_EMBEDDING_API_KEY` env var -- never stored in config. +/// +/// # Blocking +/// +/// Uses [`ureq`] (synchronous HTTP). All calls block the current thread for up to 30 s. +/// Callers in async contexts **must** invoke this provider through +/// [`EchoEngine::embed_blocking()`](crate::echo::EchoEngine::embed_blocking) which uses +/// `tokio::task::block_in_place` to prevent worker-thread starvation. pub struct OpenAIProvider { url: String, model: String, @@ -172,6 +179,14 @@ impl OpenAIProvider { } /// Call the embedding API for a batch of texts. + /// + /// # Blocking + /// + /// This method performs a synchronous HTTP POST via [`ureq`] and will block the + /// calling thread for up to 30 s (the global timeout configured in [`Self::new`]). + /// In async contexts it is always reached through + /// [`EchoEngine::embed_blocking()`](crate::echo::EchoEngine::embed_blocking), + /// which wraps the call in `tokio::task::block_in_place`. fn call_api(&self, texts: &[String]) -> Result>> { let endpoint = format!("{}/v1/embeddings", self.url); From 8b2f7c8dd4f86b17c53473013f6228a2115b7d5a Mon Sep 17 00:00:00 2001 From: Lior Cohen Date: Thu, 9 Apr 2026 00:57:39 +0300 Subject: [PATCH 12/13] KS75: wrap OpenAI HTTP calls with block_in_place inside provider - Move tokio::task::block_in_place wrapping into OpenAIProvider::call_api so the provider handles its own blocking concern (30s ureq timeout) - Uses same runtime-detection pattern as EchoEngine::embed_blocking: block_in_place on multi-thread runtime, direct call otherwise - Defense-in-depth: inner wrap covers HTTP concern, outer wrap in echo.rs covers mutex-lock concern - Resolves Greptile P1 on lines 200-204 of embedding_provider.rs Co-Authored-By: Claude Opus 4.6 (1M context) --- .../shrimpk-memory/src/embedding_provider.rs | 143 ++++++++++-------- 1 file changed, 82 insertions(+), 61 deletions(-) diff --git a/crates/shrimpk-memory/src/embedding_provider.rs 
b/crates/shrimpk-memory/src/embedding_provider.rs index aa3b617..5a46e92 100644 --- a/crates/shrimpk-memory/src/embedding_provider.rs +++ b/crates/shrimpk-memory/src/embedding_provider.rs @@ -184,9 +184,17 @@ impl OpenAIProvider { /// /// This method performs a synchronous HTTP POST via [`ureq`] and will block the /// calling thread for up to 30 s (the global timeout configured in [`Self::new`]). - /// In async contexts it is always reached through - /// [`EchoEngine::embed_blocking()`](crate::echo::EchoEngine::embed_blocking), - /// which wraps the call in `tokio::task::block_in_place`. + /// + /// When running on a **multi-thread** Tokio runtime (the daemon) the blocking + /// HTTP call is wrapped in [`tokio::task::block_in_place`] to inform the + /// scheduler and prevent worker-thread starvation. On a **current-thread** + /// runtime (`#[tokio::test]`) or outside Tokio entirely (sync tests, CLI) the + /// request runs directly, because `block_in_place` panics on a single-threaded + /// runtime. + /// + /// This is defense-in-depth: [`EchoEngine::embed_blocking()`] also wraps the + /// outer call with `block_in_place` for the mutex-lock concern; the inner wrap + /// here covers the provider-specific HTTP concern. fn call_api(&self, texts: &[String]) -> Result>> { let endpoint = format!("{}/v1/embeddings", self.url); @@ -207,65 +215,78 @@ impl OpenAIProvider { "External embedding API call" ); - let mut req = self.agent.post(&endpoint); - if let Some(key) = &self.api_key { - req = req.header("Authorization", &format!("Bearer {key}")); + // Closure that performs the synchronous HTTP request and parses the response. + // Factored out so we can conditionally wrap it with block_in_place. 
+ let do_request = || -> Result>> { + let mut req = self.agent.post(&endpoint); + if let Some(key) = &self.api_key { + req = req.header("Authorization", &format!("Bearer {key}")); + } + + let mut resp = req.send_json(&body).map_err(|e| { + ShrimPKError::Embedding(format!("OpenAI embedding API error at {endpoint}: {e}")) + })?; + + let json: serde_json::Value = resp.body_mut().read_json().map_err(|e| { + ShrimPKError::Embedding(format!("OpenAI embedding API parse error: {e}")) + })?; + + // Extract embeddings: {"data": [{"embedding": [...], "index": 0}, ...]} + let data = json["data"].as_array().ok_or_else(|| { + ShrimPKError::Embedding(format!( + "OpenAI embedding API: missing 'data' array in response: {}", + truncate_json(&json) + )) + })?; + + // Sort by index to maintain input order + let mut indexed: Vec<(usize, Vec)> = data + .iter() + .filter_map(|item| { + let index = item["index"].as_u64()? as usize; + let embedding: Vec = item["embedding"] + .as_array()? + .iter() + .filter_map(|v| v.as_f64().map(|f| f as f32)) + .collect(); + Some((index, embedding)) + }) + .collect(); + + indexed.sort_by_key(|(i, _)| *i); + let embeddings: Vec> = indexed.into_iter().map(|(_, e)| e).collect(); + + if embeddings.len() != texts.len() { + return Err(ShrimPKError::Embedding(format!( + "OpenAI embedding API returned {} embeddings for {} inputs", + embeddings.len(), + texts.len() + ))); + } + + // Validate dimension + if let Some(first) = embeddings.first() + && first.len() != self.dim + { + return Err(ShrimPKError::Embedding(format!( + "OpenAI embedding dimension mismatch: expected {}, got {} from model '{}'", + self.dim, + first.len(), + self.model + ))); + } + + Ok(embeddings) + }; + + // Wrap blocking HTTP in block_in_place on multi-thread Tokio runtime + // to prevent worker-thread starvation (ureq has a 30 s timeout). 
+ match tokio::runtime::Handle::try_current() { + Ok(handle) if handle.runtime_flavor() == tokio::runtime::RuntimeFlavor::MultiThread => { + tokio::task::block_in_place(do_request) + } + _ => do_request(), } - - let mut resp = req.send_json(&body).map_err(|e| { - ShrimPKError::Embedding(format!("OpenAI embedding API error at {endpoint}: {e}")) - })?; - - let json: serde_json::Value = resp.body_mut().read_json().map_err(|e| { - ShrimPKError::Embedding(format!("OpenAI embedding API parse error: {e}")) - })?; - - // Extract embeddings from response: {"data": [{"embedding": [...], "index": 0}, ...]} - let data = json["data"].as_array().ok_or_else(|| { - ShrimPKError::Embedding(format!( - "OpenAI embedding API: missing 'data' array in response: {}", - truncate_json(&json) - )) - })?; - - // Sort by index to maintain input order - let mut indexed: Vec<(usize, Vec)> = data - .iter() - .filter_map(|item| { - let index = item["index"].as_u64()? as usize; - let embedding: Vec = item["embedding"] - .as_array()? 
- .iter() - .filter_map(|v| v.as_f64().map(|f| f as f32)) - .collect(); - Some((index, embedding)) - }) - .collect(); - - indexed.sort_by_key(|(i, _)| *i); - let embeddings: Vec> = indexed.into_iter().map(|(_, e)| e).collect(); - - if embeddings.len() != texts.len() { - return Err(ShrimPKError::Embedding(format!( - "OpenAI embedding API returned {} embeddings for {} inputs", - embeddings.len(), - texts.len() - ))); - } - - // Validate dimension - if let Some(first) = embeddings.first() - && first.len() != self.dim - { - return Err(ShrimPKError::Embedding(format!( - "OpenAI embedding dimension mismatch: expected {}, got {} from model '{}'", - self.dim, - first.len(), - self.model - ))); - } - - Ok(embeddings) } } From 5122b5befb3657c7f76b1bca88077aa175a21531 Mon Sep 17 00:00:00 2001 From: Lior Cohen Date: Thu, 9 Apr 2026 01:15:09 +0300 Subject: [PATCH 13/13] KS75: fix rustdoc private link in OpenAIProvider docs Replace intra-doc link to private method crate::echo::EchoEngine::embed_blocking with backtick-quoted plain text. Fixes CI failure from RUSTDOCFLAGS="-D warnings". Co-Authored-By: Claude Opus 4.6 (1M context) --- crates/shrimpk-memory/src/embedding_provider.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/crates/shrimpk-memory/src/embedding_provider.rs b/crates/shrimpk-memory/src/embedding_provider.rs index 5a46e92..263c3b4 100644 --- a/crates/shrimpk-memory/src/embedding_provider.rs +++ b/crates/shrimpk-memory/src/embedding_provider.rs @@ -134,7 +134,7 @@ fn resolve_fastembed_model(name: &str) -> Result<(EmbeddingModel, usize)> { /// /// Uses [`ureq`] (synchronous HTTP). All calls block the current thread for up to 30 s. /// Callers in async contexts **must** invoke this provider through -/// [`EchoEngine::embed_blocking()`](crate::echo::EchoEngine::embed_blocking) which uses +/// `EchoEngine::embed_blocking()` which uses /// `tokio::task::block_in_place` to prevent worker-thread starvation. pub struct OpenAIProvider { url: String,