
Commit 0a79d3e

feat: add semantic chunking to Ollama embedding provider
Implemented full semantic chunking support for the Ollama provider, matching Jina's chunking capabilities for consistent behavior across providers.

Changes to ollama_embedding_provider.rs:
- Added tokenizer: Arc<Tokenizer> field to OllamaEmbeddingProvider
- Added max_tokens_per_text field to OllamaEmbeddingConfig
- Load Qwen2.5-Coder tokenizer in new() for accurate token counting
- Added chunker_config() method returning ChunkerConfig
- Added build_plan_for_nodes() using build_chunk_plan()
- Added prepare_text() that chunks long nodes semantically
- Updated generate_embeddings() to:
  * Chunk each node using prepare_text()
  * Flatten chunks while tracking node ownership
  * Generate embeddings for all chunks
  * Aggregate chunk embeddings by averaging (same as Jina)
- Updated generate_embeddings_with_config() to use chunking

Benefits:
- Long functions/classes now chunked at semantic boundaries
- Respects CODEGRAPH_MAX_CHUNK_TOKENS (default: 512 tokens)
- Multiple chunks per node averaged into a single embedding
- Consistent with Jina provider behavior
- Uses Qwen2.5-Coder tokenizer for accurate token counts

Before: 10K char function → single 10K char embedding (error/truncation)
After:  10K char function → 5 chunks → 5 embeddings → averaged to 1

Example usage:
export CODEGRAPH_MAX_CHUNK_TOKENS=512   # Max tokens per chunk
codegraph index -l rust .               # Uses chunking automatically
1 parent a41ce7c commit 0a79d3e
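For reference, a minimal standalone sketch of the aggregation step described in the commit message, using toy vectors rather than the provider's types (the function and example values here are illustrative, not part of the commit): each node's chunk embeddings are summed element-wise and divided by the chunk count, collapsing several chunk vectors into one node embedding.

// Standalone sketch: average per-chunk embeddings into a single node embedding.
fn average_chunk_embeddings(chunk_embeddings: &[Vec<f32>], dimension: usize) -> Vec<f32> {
    let mut acc = vec![0.0f32; dimension];
    for emb in chunk_embeddings {
        for (i, &v) in emb.iter().enumerate() {
            acc[i] += v; // element-wise accumulation across chunks
        }
    }
    let n = chunk_embeddings.len().max(1) as f32;
    for v in &mut acc {
        *v /= n; // divide by chunk count to get the mean
    }
    acc
}

fn main() {
    // Two toy 4-dimensional "chunk embeddings" for one node.
    let chunks = vec![vec![1.0, 0.0, 2.0, 4.0], vec![3.0, 2.0, 0.0, 0.0]];
    // Prints their element-wise mean: [2.0, 1.0, 1.0, 2.0]
    println!("{:?}", average_chunk_embeddings(&chunks, 4));
}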

File tree

1 file changed (+123, -14 lines)

crates/codegraph-vector/src/ollama_embedding_provider.rs

Lines changed: 123 additions & 14 deletions
@@ -6,10 +6,14 @@ use async_trait::async_trait;
 use codegraph_core::{CodeGraphError, CodeNode, Result};
 use reqwest::Client;
 use serde::{Deserialize, Serialize};
+use std::path::PathBuf;
+use std::sync::Arc;
 use std::time::{Duration, Instant};
 use tokio::time::timeout;
-use tracing::{debug, info};
+use tokenizers::Tokenizer;
+use tracing::{debug, info, warn};

+use crate::prep::chunker::{build_chunk_plan, ChunkPlan, ChunkerConfig, SanitizeMode};
 use crate::providers::{
     BatchConfig, EmbeddingMetrics, EmbeddingProvider, MemoryUsage, ProviderCharacteristics,
 };
@@ -22,6 +26,7 @@ pub struct OllamaEmbeddingConfig {
     pub timeout: Duration,
     pub batch_size: usize,
     pub max_retries: usize,
+    pub max_tokens_per_text: usize,
 }

 impl Default for OllamaEmbeddingConfig {
@@ -32,6 +37,11 @@ impl Default for OllamaEmbeddingConfig {
             .map(|value| value.clamp(1, 4096))
             .unwrap_or(32);

+        let max_tokens_per_text = std::env::var("CODEGRAPH_MAX_CHUNK_TOKENS")
+            .ok()
+            .and_then(|v| v.parse::<usize>().ok())
+            .unwrap_or(512);
+
         Self {
             model_name: std::env::var("CODEGRAPH_EMBEDDING_MODEL")
                 .unwrap_or_else(|_| "nomic-embed-code".to_string()),
@@ -40,6 +50,7 @@ impl Default for OllamaEmbeddingConfig {
             timeout: Duration::from_secs(60),
             batch_size,
             max_retries: 3,
+            max_tokens_per_text,
         }
     }
 }
@@ -56,12 +67,18 @@ impl From<&codegraph_core::EmbeddingConfig> for OllamaEmbeddingConfig {
         // Use batch_size from config (already has env var fallback in config loading)
         let batch_size = config.batch_size.clamp(1, 4096);

+        let max_tokens_per_text = std::env::var("CODEGRAPH_MAX_CHUNK_TOKENS")
+            .ok()
+            .and_then(|v| v.parse::<usize>().ok())
+            .unwrap_or(512);
+
         Self {
             model_name,
             base_url: config.ollama_url.clone(),
             timeout: Duration::from_secs(60),
             batch_size,
             max_retries: 3,
+            max_tokens_per_text,
         }
     }
 }
@@ -86,6 +103,7 @@ pub struct OllamaEmbeddingProvider {
     client: Client,
     config: OllamaEmbeddingConfig,
     characteristics: ProviderCharacteristics,
+    tokenizer: Arc<Tokenizer>,
 }

 impl OllamaEmbeddingProvider {
@@ -99,10 +117,25 @@ impl OllamaEmbeddingProvider {
             memory_usage: MemoryUsage::Medium, // ~500MB-1GB for embedding model
         };

+        // Load Qwen2.5-Coder tokenizer for accurate token counting
+        let tokenizer_path = PathBuf::from(concat!(
+            env!("CARGO_MANIFEST_DIR"),
+            "/tokenizers/qwen2.5-coder.json"
+        ));
+        let tokenizer = Tokenizer::from_file(&tokenizer_path).unwrap_or_else(|e| {
+            warn!(
+                "Failed to load Qwen2.5-Coder tokenizer from {:?}: {}. Using fallback character approximation.",
+                tokenizer_path, e
+            );
+            // Create a minimal fallback tokenizer (shouldn't happen in practice)
+            panic!("Tokenizer required for Ollama chunking");
+        });
+
         Self {
             client: Client::new(),
             config,
             characteristics,
+            tokenizer: Arc::new(tokenizer),
         }
     }

@@ -154,6 +187,35 @@ impl OllamaEmbeddingProvider {
         Ok(has_model)
     }

+    fn chunker_config(&self) -> ChunkerConfig {
+        ChunkerConfig::new(self.config.max_tokens_per_text)
+            .max_texts_per_request(self.config.batch_size)
+            .cache_capacity(2048)
+            .sanitize_mode(SanitizeMode::AsciiFastPath)
+    }
+
+    fn build_plan_for_nodes(&self, nodes: &[CodeNode]) -> ChunkPlan {
+        build_chunk_plan(nodes, Arc::clone(&self.tokenizer), self.chunker_config())
+    }
+
+    fn prepare_text(&self, node: &CodeNode) -> Vec<String> {
+        let plan = self.build_plan_for_nodes(std::slice::from_ref(node));
+        if plan.chunks.is_empty() {
+            return vec![Self::format_node_text(node)];
+        }
+
+        let texts: Vec<String> = plan.chunks.into_iter().map(|chunk| chunk.text).collect();
+
+        if texts.len() > 1 {
+            debug!(
+                "Chunked node '{}' into {} chunks (max {} tokens)",
+                node.name, texts.len(), self.config.max_tokens_per_text
+            );
+        }
+
+        texts
+    }
+
     fn format_node_text(node: &CodeNode) -> String {
         let mut header = format!(
             "{} {} {}",
@@ -287,7 +349,7 @@ impl EmbeddingProvider for OllamaEmbeddingProvider {
         self.generate_single_embedding(&formatted).await
     }

-    /// Generate embeddings for multiple code nodes with batch optimization
+    /// Generate embeddings for multiple code nodes with batch optimization and chunking
     async fn generate_embeddings(&self, nodes: &[CodeNode]) -> Result<Vec<Vec<f32>>> {
         if nodes.is_empty() {
             return Ok(Vec::new());
@@ -300,23 +362,72 @@ impl EmbeddingProvider for OllamaEmbeddingProvider {
         );
         let start_time = Instant::now();

-        // Prepare texts from nodes
-        let texts: Vec<String> = nodes.iter().map(Self::format_node_text).collect();
-        let embeddings = self
-            .generate_embeddings_for_texts(&texts, self.config.batch_size)
+        // Prepare texts from nodes with semantic chunking
+        let node_chunks: Vec<(usize, Vec<String>)> = nodes
+            .iter()
+            .enumerate()
+            .map(|(idx, node)| (idx, self.prepare_text(node)))
+            .collect();
+
+        // Flatten all chunks and track which node they belong to
+        let mut all_texts = Vec::new();
+        let mut chunk_to_node: Vec<usize> = Vec::new();
+
+        for (node_idx, chunks) in &node_chunks {
+            for chunk in chunks {
+                all_texts.push(chunk.clone());
+                chunk_to_node.push(*node_idx);
+            }
+        }
+
+        debug!(
+            "Processing {} nodes with {} total chunks (avg {:.2} chunks/node)",
+            nodes.len(),
+            all_texts.len(),
+            all_texts.len() as f64 / nodes.len() as f64
+        );
+
+        // Generate embeddings for all chunks
+        let chunk_embeddings = self
+            .generate_embeddings_for_texts(&all_texts, self.config.batch_size)
             .await?;

+        // Aggregate chunk embeddings back into node embeddings
+        let dimension = self.embedding_dimension();
+        let mut node_embeddings: Vec<Vec<f32>> = vec![vec![0.0f32; dimension]; nodes.len()];
+        let mut node_chunk_counts = vec![0usize; nodes.len()];
+
+        // Accumulate chunk embeddings for each node
+        for (chunk_idx, chunk_embedding) in chunk_embeddings.into_iter().enumerate() {
+            let node_idx = chunk_to_node[chunk_idx];
+            for (i, &val) in chunk_embedding.iter().enumerate() {
+                node_embeddings[node_idx][i] += val;
+            }
+            node_chunk_counts[node_idx] += 1;
+        }
+
+        // Average the accumulated embeddings
+        for (node_idx, count) in node_chunk_counts.iter().enumerate() {
+            if *count > 0 {
+                let divisor = *count as f32;
+                for val in &mut node_embeddings[node_idx] {
+                    *val /= divisor;
+                }
+            }
+        }
+
         let total_time = start_time.elapsed();
         let embeddings_per_second = nodes.len() as f64 / total_time.as_secs_f64().max(0.001);

         info!(
-            "Ollama embeddings complete: {} embeddings in {:.2}s ({:.1} emb/s)",
+            "Ollama embeddings complete: {} nodes ({} chunks) in {:.2}s ({:.1} emb/s)",
             nodes.len(),
+            all_texts.len(),
             total_time.as_secs_f64(),
             embeddings_per_second
         );

-        Ok(embeddings)
+        Ok(node_embeddings)
     }

     /// Generate embeddings with batch configuration and metrics
@@ -326,13 +437,11 @@ impl EmbeddingProvider for OllamaEmbeddingProvider {
         config: &BatchConfig,
     ) -> Result<(Vec<Vec<f32>>, EmbeddingMetrics)> {
         let start_time = Instant::now();
-        let texts: Vec<String> = nodes.iter().map(Self::format_node_text).collect();
-        let batch_size = self.effective_batch_size(config.batch_size);
-        let embeddings = self
-            .generate_embeddings_for_texts(&texts, batch_size)
-            .await?;
-        let duration = start_time.elapsed();

+        // Use chunking-aware generate_embeddings instead of direct text formatting
+        let embeddings = self.generate_embeddings(nodes).await?;
+
+        let duration = start_time.elapsed();
         let metrics = EmbeddingMetrics::new(
             format!("ollama-{}", self.config.model_name),
             nodes.len(),
