diff --git a/Cargo.lock b/Cargo.lock index b65475315..8e947cccf 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -682,6 +682,7 @@ dependencies = [ "rayon", "rstest", "serde", + "serde_json", "tracing", "tracing-subscriber", "vfs", diff --git a/diskann-benchmark/example/document-filter.json b/diskann-benchmark/example/document-filter.json new file mode 100644 index 000000000..d6e9e13b2 --- /dev/null +++ b/diskann-benchmark/example/document-filter.json @@ -0,0 +1,34 @@ +{ + "search_directories": [ + "test_data/disk_index_search" + ], + "jobs": [ + { + "type": "document-index-build", + "content": { + "build": { + "data_type": "float32", + "data": "disk_index_siftsmall_learn_256pts_data.fbin", + "data_labels": "data.256.label.jsonl", + "distance": "squared_l2", + "max_degree": 32, + "l_build": 50, + "alpha": 1.2 + }, + "search": { + "queries": "disk_index_sample_query_10pts.fbin", + "query_predicates": "query.10.label.jsonl", + "groundtruth": "disk_index_10pts_idx_uint32_truth_search_filter_res.bin", + "beta": 0.5, + "runs": [ + { + "search_n": 20, + "search_l": [20, 30, 40], + "recall_k": 10 + } + ] + } + } + } + ] +} \ No newline at end of file diff --git a/diskann-benchmark/src/backend/document_index/benchmark.rs b/diskann-benchmark/src/backend/document_index/benchmark.rs new file mode 100644 index 000000000..dffe669ff --- /dev/null +++ b/diskann-benchmark/src/backend/document_index/benchmark.rs @@ -0,0 +1,1038 @@ +/* + * Copyright (c) Microsoft Corporation. + * Licensed under the MIT license. + */ + +//! Benchmark for DocumentInsertStrategy which allows inserting Documents +//! (vector + attributes) into a DiskANN index built with DocumentProvider. +//! Also benchmarks filtered search using InlineBetaStrategy. + +use std::io::Write; +use std::num::NonZeroUsize; +use std::path::Path; +use std::sync::atomic::{AtomicBool, AtomicUsize, Ordering}; +use std::sync::Arc; + +use anyhow::Result; +use diskann::{ + graph::{ + config::Builder as ConfigBuilder, config::MaxDegree, config::PruneKind, + search_output_buffer, DiskANNIndex, SearchParams, StartPointStrategy, + }, + provider::DefaultContext, + utils::{async_tools, IntoUsize}, +}; +use diskann_benchmark_runner::{ + dispatcher::{DispatchRule, FailureScore, MatchScore}, + output::Output, + registry::Benchmarks, + utils::{datatype::DataType, percentiles, MicroSeconds}, + Any, Checkpoint, +}; +use diskann_label_filter::{ + attribute::{Attribute, AttributeValue}, + document::Document, + encoded_attribute_provider::{ + document_insert_strategy::DocumentInsertStrategy, document_provider::DocumentProvider, + roaring_attribute_store::RoaringAttributeStore, + }, + inline_beta_search::inline_beta_filter::InlineBetaStrategy, + query::FilteredQuery, + read_and_parse_queries, read_baselabels, ASTExpr, +}; +use diskann_providers::model::graph::provider::async_::{ + common::{self, NoStore, TableBasedDeletes}, + inmem::{CreateFullPrecision, DefaultProvider, DefaultProviderParameters, SetStartPoints}, +}; +use diskann_utils::views::Matrix; +use indicatif::{ProgressBar, ProgressStyle}; +use serde::Serialize; + +use crate::{ + inputs::document_index::DocumentIndexBuild, + utils::{ + self, + datafiles::{self, BinFile}, + recall, + }, +}; + +/// Register the document index benchmarks. +pub(crate) fn register_benchmarks(benchmarks: &mut Benchmarks) { + benchmarks.register::>( + "document-index-build", + |job, checkpoint, out| { + let stats = job.run(checkpoint, out)?; + Ok(serde_json::to_value(stats)?) + }, + ); +} + +/// Document index benchmark job. 
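+///
+/// Borrows the parsed [DocumentIndexBuild] input; `run` loads the vectors and
+/// label documents, builds the filtered index, executes the configured search
+/// runs, and returns a serializable [DocumentIndexStats].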
+pub(super) struct DocumentIndexJob<'a> { + input: &'a DocumentIndexBuild, +} + +impl<'a> DocumentIndexJob<'a> { + fn new(input: &'a DocumentIndexBuild) -> Self { + Self { input } + } +} + +impl diskann_benchmark_runner::dispatcher::Map for DocumentIndexJob<'static> { + type Type<'a> = DocumentIndexJob<'a>; +} + +// Dispatch from the concrete input type +impl<'a> DispatchRule<&'a DocumentIndexBuild> for DocumentIndexJob<'a> { + type Error = std::convert::Infallible; + + fn try_match(_from: &&'a DocumentIndexBuild) -> Result { + Ok(MatchScore(1)) + } + + fn convert(from: &'a DocumentIndexBuild) -> Result { + Ok(DocumentIndexJob::new(from)) + } + + fn description( + f: &mut std::fmt::Formatter<'_>, + _from: Option<&&'a DocumentIndexBuild>, + ) -> std::fmt::Result { + writeln!(f, "tag: \"{}\"", DocumentIndexBuild::tag()) + } +} + +// Central dispatch mapping from Any +impl<'a> DispatchRule<&'a Any> for DocumentIndexJob<'a> { + type Error = anyhow::Error; + + fn try_match(from: &&'a Any) -> Result { + from.try_match::() + } + + fn convert(from: &'a Any) -> Result { + from.convert::() + } + + fn description(f: &mut std::fmt::Formatter, from: Option<&&'a Any>) -> std::fmt::Result { + Any::description::(f, from, DocumentIndexBuild::tag()) + } +} +/// Convert a HashMap to Vec +fn hashmap_to_attributes(map: std::collections::HashMap) -> Vec { + map.into_iter() + .map(|(k, v)| Attribute::from_value(k, v)) + .collect() +} + +/// Compute the index of the row closest to the medoid (centroid) of the data. +fn compute_medoid_index(data: &Matrix) -> usize +where + T: bytemuck::Pod + Copy + 'static, +{ + use diskann_vector::{distance::SquaredL2, PureDistanceFunction}; + + let dim = data.ncols(); + if dim == 0 || data.nrows() == 0 { + return 0; + } + + // Compute the centroid (mean of all rows) as f64 for precision + let mut sum = vec![0.0f64; dim]; + for i in 0..data.nrows() { + let row = data.row(i); + for (j, &v) in row.iter().enumerate() { + // Convert T to f64 for summation using bytemuck + let f64_val: f64 = if std::any::TypeId::of::() == std::any::TypeId::of::() { + let f32_val: f32 = bytemuck::cast(v); + f32_val as f64 + } else if std::any::TypeId::of::() == std::any::TypeId::of::() { + let u8_val: u8 = bytemuck::cast(v); + u8_val as f64 + } else if std::any::TypeId::of::() == std::any::TypeId::of::() { + let i8_val: i8 = bytemuck::cast(v); + i8_val as f64 + } else { + 0.0 + }; + sum[j] += f64_val; + } + } + + // Convert centroid to f32 and compute distances + let centroid_f32: Vec = sum + .iter() + .map(|s| (s / data.nrows() as f64) as f32) + .collect(); + + // Find the row closest to the centroid + let mut min_dist = f32::MAX; + let mut medoid_idx = 0; + for i in 0..data.nrows() { + let row = data.row(i); + let row_f32: Vec = row + .iter() + .map(|&v| { + if std::any::TypeId::of::() == std::any::TypeId::of::() { + bytemuck::cast(v) + } else if std::any::TypeId::of::() == std::any::TypeId::of::() { + let u8_val: u8 = bytemuck::cast(v); + u8_val as f32 + } else if std::any::TypeId::of::() == std::any::TypeId::of::() { + let i8_val: i8 = bytemuck::cast(v); + i8_val as f32 + } else { + 0.0 + } + }) + .collect(); + let d = SquaredL2::evaluate(centroid_f32.as_slice(), row_f32.as_slice()); + if d < min_dist { + min_dist = d; + medoid_idx = i; + } + } + + medoid_idx +} + +impl<'a> DocumentIndexJob<'a> { + fn run( + self, + _checkpoint: Checkpoint<'_>, + mut output: &mut dyn Output, + ) -> Result { + // Print the input description + writeln!(output, "{}", self.input)?; + + let build = &self.input.build; + 
+ // Dispatch based on data type - retain original type without conversion + match build.data_type { + DataType::Float32 => self.run_typed::(output), + DataType::UInt8 => self.run_typed::(output), + DataType::Int8 => self.run_typed::(output), + _ => Err(anyhow::anyhow!( + "Unsupported data type: {:?}. Supported types: float32, uint8, int8.", + build.data_type + )), + } + } + + fn run_typed(self, mut output: &mut dyn Output) -> Result + where + T: bytemuck::Pod + Copy + Send + Sync + 'static + std::fmt::Debug, + T: diskann::graph::SampleableForStart + diskann_utils::future::AsyncFriendly, + T: diskann::utils::VectorRepr + diskann_utils::sampling::WithApproximateNorm, + { + let build = &self.input.build; + + // 1. Load vectors from data file in the original data type + writeln!(output, "Loading vectors ({})...", build.data_type)?; + let timer = std::time::Instant::now(); + let data_path: &Path = build.data.as_ref(); + writeln!(output, "Data path is: {}", data_path.to_string_lossy())?; + let data: Matrix = datafiles::load_dataset(BinFile(data_path))?; + let data_load_time: MicroSeconds = timer.elapsed().into(); + let num_vectors = data.nrows(); + let dim = data.ncols(); + writeln!( + output, + " Loaded {} vectors of dimension {}", + num_vectors, dim + )?; + + // 2. Load and parse labels from the data_labels file + writeln!(output, "Loading labels...")?; + let timer = std::time::Instant::now(); + let label_path: &Path = build.data_labels.as_ref(); + let labels = read_baselabels(label_path)?; + let label_load_time: MicroSeconds = timer.elapsed().into(); + let label_count = labels.len(); + writeln!(output, " Loaded {} label documents", label_count)?; + + if num_vectors != label_count { + return Err(anyhow::anyhow!( + "Mismatch: {} vectors but {} label documents", + num_vectors, + label_count + )); + } + + // Convert labels to attribute vectors + let attributes: Vec> = labels + .into_iter() + .map(|doc| hashmap_to_attributes(doc.flatten_metadata_with_separator(""))) + .collect(); + + // 3. Create the index configuration + let metric = build.distance.into(); + let prune_kind = PruneKind::from_metric(metric); + let mut config_builder = ConfigBuilder::new( + build.max_degree, // pruned_degree + MaxDegree::Same, // max_degree + build.l_build, // l_build + prune_kind, // prune_kind + ); + config_builder.alpha(build.alpha); + let config = config_builder.build()?; + + // 4. Create the data provider directly + writeln!(output, "Creating index...")?; + let params = DefaultProviderParameters { + max_points: num_vectors, + frozen_points: diskann::utils::ONE, + metric, + dim, + prefetch_lookahead: None, + prefetch_cache_line_level: None, + max_degree: build.max_degree as u32, + }; + + // Create the underlying provider + let fp_precursor = CreateFullPrecision::::new(dim, None); + let inner_provider = + DefaultProvider::new_empty(params, fp_precursor, NoStore, TableBasedDeletes)?; + + // Set start points using medoid strategy + let start_points = StartPointStrategy::Medoid + .compute(data.as_view()) + .map_err(|e| anyhow::anyhow!("Failed to compute start points: {}", e))?; + inner_provider.set_start_points(start_points.row_iter())?; + + // 5. 
Create DocumentProvider wrapping the inner provider + let attribute_store = RoaringAttributeStore::::new(); + + // Store attributes for the start point (medoid) + // Start points are stored at indices num_vectors..num_vectors+frozen_points + let medoid_idx = compute_medoid_index(&data); + let start_point_id = num_vectors as u32; // Start points begin at max_points + let medoid_attrs = attributes.get(medoid_idx).cloned().unwrap_or_default(); + use diskann_label_filter::traits::attribute_store::AttributeStore; + attribute_store.set_element(&start_point_id, &medoid_attrs)?; + + let doc_provider = DocumentProvider::new(inner_provider, attribute_store); + + // Create a new DiskANNIndex with DocumentProvider + let doc_index = Arc::new(DiskANNIndex::new(config, doc_provider, None)); + + // 6. Build index by inserting vectors and attributes (parallel) + writeln!( + output, + "Building index with {} vectors using {} threads...", + num_vectors, build.num_threads + )?; + let timer = std::time::Instant::now(); + + let insert_strategy: DocumentInsertStrategy<_, [T]> = + DocumentInsertStrategy::new(common::FullPrecision); + let rt = utils::tokio::runtime(build.num_threads)?; + + // Create control block for parallel work distribution + let data_arc = Arc::new(data); + let attributes_arc = Arc::new(attributes); + let control_block = DocumentControlBlock::new( + data_arc.clone(), + attributes_arc.clone(), + output.draw_target(), + )?; + + let num_tasks = build.num_threads; + let insert_latencies = rt.block_on(async { + let tasks: Vec<_> = (0..num_tasks) + .map(|_| { + let block = control_block.clone(); + let index = doc_index.clone(); + let strategy = insert_strategy; + tokio::spawn(async move { + let mut latencies = Vec::::new(); + let ctx = DefaultContext; + loop { + match block.next() { + Some((id, vector, attrs)) => { + let doc = Document::new(vector, attrs); + let start = std::time::Instant::now(); + let result = + index.insert(strategy, &ctx, &(id as u32), &doc).await; + latencies.push(MicroSeconds::from(start.elapsed())); + + if let Err(e) = result { + block.cancel(); + return Err(e); + } + } + None => return Ok(latencies), + } + } + }) + }) + .collect(); + + // Collect results from all tasks + let mut all_latencies = Vec::with_capacity(num_vectors); + for task in tasks { + let task_latencies = task.await??; + all_latencies.extend(task_latencies); + } + Ok::<_, anyhow::Error>(all_latencies) + })?; + + let build_time: MicroSeconds = timer.elapsed().into(); + writeln!(output, " Index built in {} s", build_time.as_seconds())?; + + let insert_percentiles = percentiles::compute_percentiles(&mut insert_latencies.clone())?; + // ===================== + // Search Phase + // ===================== + let search_input = &self.input.search; + + // Load query vectors (same type as data for compatible distance computation) + writeln!(output, "\nLoading query vectors...")?; + let query_path: &Path = search_input.queries.as_ref(); + let queries: Matrix = datafiles::load_dataset(BinFile(query_path))?; + let num_queries = queries.nrows(); + writeln!(output, " Loaded {} queries", num_queries)?; + + // Load and parse query predicates + writeln!(output, "Loading query predicates...")?; + let predicate_path: &Path = search_input.query_predicates.as_ref(); + let parsed_predicates = read_and_parse_queries(predicate_path)?; + writeln!(output, " Loaded {} predicates", parsed_predicates.len())?; + + if num_queries != parsed_predicates.len() { + return Err(anyhow::anyhow!( + "Mismatch: {} queries but {} predicates", + 
num_queries, + parsed_predicates.len() + )); + } + + // Load groundtruth + writeln!(output, "Loading groundtruth...")?; + let gt_path: &Path = search_input.groundtruth.as_ref(); + let groundtruth: Vec> = datafiles::load_range_groundtruth(BinFile(gt_path))?; + writeln!( + output, + " Loaded groundtruth with {} rows", + groundtruth.len() + )?; + + // Run filtered searches + writeln!( + output, + "\nRunning filtered searches (beta={})...", + search_input.beta + )?; + let mut search_results = Vec::new(); + + for num_threads in &search_input.num_threads { + for run in &search_input.runs { + for &search_l in &run.search_l { + writeln!( + output, + " threads={}, search_n={}, search_l={}...", + num_threads, run.search_n, search_l + )?; + + let search_run_result = run_filtered_search( + &doc_index, + &queries, + &parsed_predicates, + &groundtruth, + search_input.beta, + *num_threads, + run.search_n, + search_l, + run.recall_k, + search_input.reps, + )?; + + writeln!( + output, + " recall={:.4}, mean_qps={:.1}", + search_run_result.recall.average, + if search_run_result.qps.is_empty() { + 0.0 + } else { + search_run_result.qps.iter().sum::() + / search_run_result.qps.len() as f64 + } + )?; + + search_results.push(search_run_result); + } + } + } + + let stats = DocumentIndexStats { + num_vectors, + dim, + label_count, + data_load_time, + label_load_time, + build_time, + insert_latencies: insert_percentiles, + build_params: BuildParamsStats { + max_degree: build.max_degree, + l_build: build.l_build, + alpha: build.alpha, + }, + search: search_results, + }; + + writeln!(output, "\n{}", stats)?; + Ok(stats) + } +} +/// Local results from a partition of queries. +struct SearchLocalResults { + ids: Matrix, + distances: Vec>, + latencies: Vec, + comparisons: Vec, + hops: Vec, +} + +impl SearchLocalResults { + fn merge(all: &[SearchLocalResults]) -> anyhow::Result { + let first = all + .first() + .ok_or_else(|| anyhow::anyhow!("empty results"))?; + let num_ids = first.ids.ncols(); + let total_rows: usize = all.iter().map(|r| r.ids.nrows()).sum(); + + let mut ids = Matrix::new(0, total_rows, num_ids); + let mut output_row = 0; + for r in all { + for input_row in r.ids.row_iter() { + ids.row_mut(output_row).copy_from_slice(input_row); + output_row += 1; + } + } + + let mut distances = Vec::new(); + let mut latencies = Vec::new(); + let mut comparisons = Vec::new(); + let mut hops = Vec::new(); + for r in all { + distances.extend_from_slice(&r.distances); + latencies.extend_from_slice(&r.latencies); + comparisons.extend_from_slice(&r.comparisons); + hops.extend_from_slice(&r.hops); + } + + Ok(Self { + ids, + distances, + latencies, + comparisons, + hops, + }) + } +} + +/// Run filtered search with the given parameters. 
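+///
+/// Executes the query set `reps` times on a fresh multi-threaded runtime,
+/// derives QPS from each repetition's wall-clock time, and computes recall
+/// from the first repetition. Per-query details are recorded only for queries
+/// whose individual recall falls below 1.0.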
+#[allow(clippy::too_many_arguments)] +fn run_filtered_search( + index: &Arc>, + queries: &Matrix, + predicates: &[(usize, ASTExpr)], + groundtruth: &Vec>, + beta: f32, + num_threads: NonZeroUsize, + search_n: usize, + search_l: usize, + recall_k: usize, + reps: NonZeroUsize, +) -> anyhow::Result +where + T: bytemuck::Pod + Copy + Send + Sync + 'static, + DP: diskann::provider::DataProvider< + Context = DefaultContext, + ExternalId = u32, + InternalId = u32, + > + Send + + Sync + + 'static, + InlineBetaStrategy: + diskann::graph::glue::SearchStrategy>>, +{ + let rt = utils::tokio::runtime(num_threads.get())?; + let num_queries = queries.nrows(); + + let mut all_rep_results = Vec::with_capacity(reps.get()); + let mut rep_latencies = Vec::with_capacity(reps.get()); + + for _ in 0..reps.get() { + let start = std::time::Instant::now(); + let results = rt.block_on(run_search_parallel( + index.clone(), + queries, + predicates, + beta, + num_threads, + search_n, + search_l, + ))?; + rep_latencies.push(MicroSeconds::from(start.elapsed())); + all_rep_results.push(results); + } + + // Merge results from first rep for recall calculation + let merged = SearchLocalResults::merge(&all_rep_results[0])?; + + // Compute recall + let recall_metrics: recall::RecallMetrics = + (&recall::knn(groundtruth, None, &merged.ids, recall_k, search_n, false)?).into(); + + // Compute per-query details (only for queries with recall < 1) + let per_query_details: Vec = (0..num_queries) + .filter_map(|query_idx| { + let result_ids: Vec = merged + .ids + .row(query_idx) + .iter() + .copied() + .filter(|&id| id != u32::MAX) + .collect(); + let result_distances: Vec = merged + .distances + .get(query_idx) + .map(|d| d.iter().copied().filter(|&dist| dist != f32::MAX).collect()) + .unwrap_or_default(); + // Only keep top 20 from ground truth + let gt_ids: Vec = groundtruth + .get(query_idx) + .map(|gt| gt.iter().take(20).copied().collect()) + .unwrap_or_default(); + + // Compute per-query recall: intersection of result_ids with gt_ids / recall_k + let result_set: std::collections::HashSet = result_ids.iter().copied().collect(); + let gt_set: std::collections::HashSet = + gt_ids.iter().take(recall_k).copied().collect(); + let intersection = result_set.intersection(>_set).count(); + let per_query_recall = if gt_set.is_empty() { + 1.0 + } else { + intersection as f64 / gt_set.len() as f64 + }; + + // Only include queries with imperfect recall + if per_query_recall >= 1.0 { + return None; + } + + let (_, ref ast_expr) = predicates[query_idx]; + let filter_str = format!("{:?}", ast_expr); + + Some(PerQueryDetails { + query_id: query_idx, + filter: filter_str, + recall: per_query_recall, + result_ids, + result_distances, + groundtruth_ids: gt_ids, + }) + }) + .collect(); + + // Compute QPS from rep latencies + let qps: Vec = rep_latencies + .iter() + .map(|l| num_queries as f64 / l.as_seconds()) + .collect(); + + // Aggregate per-query latencies across all reps + let (all_latencies, all_cmps, all_hops): (Vec<_>, Vec<_>, Vec<_>) = all_rep_results + .iter() + .map(|results| { + let mut lat = Vec::new(); + let mut cmp = Vec::new(); + let mut hop = Vec::new(); + for r in results { + lat.extend_from_slice(&r.latencies); + cmp.extend_from_slice(&r.comparisons); + hop.extend_from_slice(&r.hops); + } + (lat, cmp, hop) + }) + .fold( + (Vec::new(), Vec::new(), Vec::new()), + |(mut a, mut b, mut c): (Vec, Vec, Vec), (x, y, z)| { + a.extend(x); + b.extend(y); + c.extend(z); + (a, b, c) + }, + ); + + let mut query_latencies = all_latencies; + let 
percentiles::Percentiles { mean, p90, p99, .. } = + percentiles::compute_percentiles(&mut query_latencies)?; + + let mean_cmps = if all_cmps.is_empty() { + 0.0 + } else { + all_cmps.iter().map(|&x| x as f32).sum::() / all_cmps.len() as f32 + }; + let mean_hops = if all_hops.is_empty() { + 0.0 + } else { + all_hops.iter().map(|&x| x as f32).sum::() / all_hops.len() as f32 + }; + + Ok(SearchRunStats { + num_threads: num_threads.get(), + num_queries, + search_n, + search_l, + recall: recall_metrics, + qps, + wall_clock_time: rep_latencies, + mean_latency: mean, + p90_latency: p90, + p99_latency: p99, + mean_cmps, + mean_hops, + per_query_details: Some(per_query_details), + }) +} +async fn run_search_parallel( + index: Arc>, + queries: &Matrix, + predicates: &[(usize, ASTExpr)], + beta: f32, + num_tasks: NonZeroUsize, + search_n: usize, + search_l: usize, +) -> anyhow::Result> +where + T: bytemuck::Pod + Copy + Send + Sync + 'static, + DP: diskann::provider::DataProvider< + Context = DefaultContext, + ExternalId = u32, + InternalId = u32, + > + Send + + Sync + + 'static, + InlineBetaStrategy: + diskann::graph::glue::SearchStrategy>>, +{ + let num_queries = queries.nrows(); + + // Plan query partitions + let partitions: Result, _> = (0..num_tasks.get()) + .map(|task_id| async_tools::partition(num_queries, num_tasks, task_id)) + .collect(); + let partitions = partitions?; + + // We need to clone data for each task + let queries_arc = Arc::new(queries.clone()); + let predicates_arc = Arc::new(predicates.to_vec()); + + let handles: Vec<_> = partitions + .into_iter() + .map(|range| { + let index = index.clone(); + let queries = queries_arc.clone(); + let predicates = predicates_arc.clone(); + tokio::spawn(async move { + run_search_local(index, queries, predicates, beta, range, search_n, search_l).await + }) + }) + .collect(); + + let mut results = Vec::new(); + for h in handles { + results.push(h.await??); + } + + Ok(results) +} + +async fn run_search_local( + index: Arc>, + queries: Arc>, + predicates: Arc>, + beta: f32, + range: std::ops::Range, + search_n: usize, + search_l: usize, +) -> anyhow::Result +where + T: bytemuck::Pod + Copy + Send + Sync + 'static, + DP: diskann::provider::DataProvider< + Context = DefaultContext, + ExternalId = u32, + InternalId = u32, + > + Send + + Sync, + InlineBetaStrategy: + diskann::graph::glue::SearchStrategy>>, +{ + let mut ids = Matrix::new(0, range.len(), search_n); + let mut all_distances: Vec> = Vec::with_capacity(range.len()); + let mut latencies = Vec::with_capacity(range.len()); + let mut comparisons = Vec::with_capacity(range.len()); + let mut hops = Vec::with_capacity(range.len()); + + let ctx = DefaultContext; + let search_params = SearchParams::new_default(search_n, search_l)?; + + for (output_idx, query_idx) in range.enumerate() { + let query_vec = queries.row(query_idx); + let (_, ref ast_expr) = predicates[query_idx]; + + let strategy = InlineBetaStrategy::new(beta, common::FullPrecision); + let query_vec_owned = query_vec.to_vec(); + let filtered_query: FilteredQuery> = + FilteredQuery::new(query_vec_owned, ast_expr.clone()); + + let start = std::time::Instant::now(); + + let mut distances = vec![0.0f32; search_n]; + let result_ids = ids.row_mut(output_idx); + let mut result_buffer = search_output_buffer::IdDistance::new(result_ids, &mut distances); + + let stats = index + .search( + &strategy, + &ctx, + &filtered_query, + &search_params, + &mut result_buffer, + ) + .await?; + + let result_count = stats.result_count.into_usize(); + 
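+        // Pad the unused result slots with sentinel values; the merge and
+        // per-query reporting code later skips u32::MAX ids and f32::MAX
+        // distances.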
result_ids[result_count..].fill(u32::MAX); + distances[result_count..].fill(f32::MAX); + + latencies.push(MicroSeconds::from(start.elapsed())); + comparisons.push(stats.cmps); + hops.push(stats.hops); + all_distances.push(distances); + } + + Ok(SearchLocalResults { + ids, + distances: all_distances, + latencies, + comparisons, + hops, + }) +} +#[derive(Debug, Serialize)] +pub struct BuildParamsStats { + pub max_degree: usize, + pub l_build: usize, + pub alpha: f32, +} + +/// Helper module for serializing arrays as compact single-line JSON strings +mod compact_array { + use serde::Serializer; + + pub fn serialize_u32_vec(vec: &Vec, serializer: S) -> Result + where + S: Serializer, + { + // Serialize as a string containing the compact JSON array + let compact = serde_json::to_string(vec).unwrap_or_default(); + serializer.serialize_str(&compact) + } + + pub fn serialize_f32_vec(vec: &Vec, serializer: S) -> Result + where + S: Serializer, + { + // Serialize as a string containing the compact JSON array + let compact = serde_json::to_string(vec).unwrap_or_default(); + serializer.serialize_str(&compact) + } +} + +/// Per-query detailed results for debugging/analysis +#[derive(Debug, Serialize)] +pub struct PerQueryDetails { + pub query_id: usize, + pub filter: String, + pub recall: f64, + #[serde(serialize_with = "compact_array::serialize_u32_vec")] + pub result_ids: Vec, + #[serde(serialize_with = "compact_array::serialize_f32_vec")] + pub result_distances: Vec, + #[serde(serialize_with = "compact_array::serialize_u32_vec")] + pub groundtruth_ids: Vec, +} + +/// Results from a single search configuration (one search_l value). +#[derive(Debug, Serialize)] +pub struct SearchRunStats { + pub num_threads: usize, + pub num_queries: usize, + pub search_n: usize, + pub search_l: usize, + pub recall: recall::RecallMetrics, + pub qps: Vec, + pub wall_clock_time: Vec, + pub mean_latency: f64, + pub p90_latency: MicroSeconds, + pub p99_latency: MicroSeconds, + pub mean_cmps: f32, + pub mean_hops: f32, + #[serde(skip_serializing_if = "Option::is_none")] + pub per_query_details: Option>, +} + +#[derive(Debug, Serialize)] +pub struct DocumentIndexStats { + pub num_vectors: usize, + pub dim: usize, + pub label_count: usize, + pub data_load_time: MicroSeconds, + pub label_load_time: MicroSeconds, + pub build_time: MicroSeconds, + pub insert_latencies: percentiles::Percentiles, + pub build_params: BuildParamsStats, + pub search: Vec, +} + +impl std::fmt::Display for DocumentIndexStats { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + writeln!(f, "Document Index Build Stats:")?; + writeln!(f, " Vectors: {} x {}", self.num_vectors, self.dim)?; + writeln!(f, " Label Count: {}", self.label_count)?; + writeln!( + f, + " Data Load Time: {} s", + self.data_load_time.as_seconds() + )?; + writeln!( + f, + " Label Load Time: {} s", + self.label_load_time.as_seconds() + )?; + writeln!(f, " Total Build Time: {} s", self.build_time.as_seconds())?; + writeln!(f, " Insert Latencies:")?; + writeln!(f, " Mean: {} us", self.insert_latencies.mean)?; + writeln!(f, " P50: {} us", self.insert_latencies.median)?; + writeln!(f, " P90: {} us", self.insert_latencies.p90)?; + writeln!(f, " P99: {} us", self.insert_latencies.p99)?; + writeln!(f, " Build Parameters:")?; + writeln!(f, " max_degree (R): {}", self.build_params.max_degree)?; + writeln!(f, " l_build (L): {}", self.build_params.l_build)?; + writeln!(f, " alpha: {}", self.build_params.alpha)?; + + if !self.search.is_empty() { + writeln!(f, "\nFiltered 
Search Results:")?; + writeln!( + f, + " {:>8} {:>8} {:>10} {:>10} {:>15} {:>12} {:>12} {:>10} {:>8} {:>10} {:>12}", + "L", "KNN", "Avg Cmps", "Avg Hops", "QPS -mean(max)", "Avg Latency", "p99 Latency", "Recall", "Threads", "Queries", "WallClock(s)" + )?; + for s in &self.search { + let mean_qps = if s.qps.is_empty() { + 0.0 + } else { + s.qps.iter().sum::() / s.qps.len() as f64 + }; + let max_qps = s.qps.iter().cloned().fold(0.0_f64, f64::max); + let mean_wall_clock = if s.wall_clock_time.is_empty() { + 0.0 + } else { + s.wall_clock_time.iter().map(|t| t.as_seconds()).sum::() / s.wall_clock_time.len() as f64 + }; + writeln!( + f, + " {:>8} {:>8} {:>10.1} {:>10.1} {:>7.1}({:>5.1}) {:>12.1} {:>12} {:>10.4} {:>8} {:>10} {:>12.3}", + s.search_l, + s.search_n, + s.mean_cmps, + s.mean_hops, + mean_qps, + max_qps, + s.mean_latency, + s.p99_latency, + s.recall.average, + s.num_threads, + s.num_queries, + mean_wall_clock + )?; + } + } + Ok(()) + } +} + +// ================================ +// Parallel Build Support +// ================================ + +fn make_progress_bar( + nrows: usize, + draw_target: indicatif::ProgressDrawTarget, +) -> anyhow::Result { + let progress = ProgressBar::with_draw_target(Some(nrows as u64), draw_target); + progress.set_style(ProgressStyle::with_template( + "Building [{elapsed_precise}] {wide_bar} {percent}", + )?); + Ok(progress) +} + +/// Control block for parallel document insertion. +/// Manages work distribution and progress tracking across multiple tasks. +struct DocumentControlBlock { + data: Arc>, + attributes: Arc>>, + position: AtomicUsize, + cancel: AtomicBool, + progress: ProgressBar, +} + +impl DocumentControlBlock { + fn new( + data: Arc>, + attributes: Arc>>, + draw_target: indicatif::ProgressDrawTarget, + ) -> anyhow::Result> { + let nrows = data.nrows(); + Ok(Arc::new(Self { + data, + attributes, + position: AtomicUsize::new(0), + cancel: AtomicBool::new(false), + progress: make_progress_bar(nrows, draw_target)?, + })) + } + + /// Return the next document data to insert: (id, vector_slice, attributes). + fn next(&self) -> Option<(usize, &[T], Vec)> { + let cancel = self.cancel.load(Ordering::Relaxed); + if cancel { + None + } else { + let i = self.position.fetch_add(1, Ordering::Relaxed); + match self.data.get_row(i) { + Some(row) => { + let attrs = self.attributes.get(i).cloned().unwrap_or_default(); + self.progress.inc(1); + Some((i, row, attrs)) + } + None => None, + } + } + } + + /// Tell all users of the control block to cancel and return early. + fn cancel(&self) { + self.cancel.store(true, Ordering::Relaxed); + } +} + +impl Drop for DocumentControlBlock { + fn drop(&mut self) { + self.progress.finish(); + } +} diff --git a/diskann-benchmark/src/backend/document_index/mod.rs b/diskann-benchmark/src/backend/document_index/mod.rs new file mode 100644 index 000000000..9937590cc --- /dev/null +++ b/diskann-benchmark/src/backend/document_index/mod.rs @@ -0,0 +1,13 @@ +/* + * Copyright (c) Microsoft Corporation. + * Licensed under the MIT license. + */ + +//! Backend benchmark implementation for document index with label filters. +//! +//! This benchmark tests the DocumentInsertStrategy which enables inserting +//! Document objects (vector + attributes) into a DiskANN index. 
+ +mod benchmark; + +pub(crate) use benchmark::register_benchmarks; diff --git a/diskann-benchmark/src/backend/index/result.rs b/diskann-benchmark/src/backend/index/result.rs index c7e2ab75c..21d74f915 100644 --- a/diskann-benchmark/src/backend/index/result.rs +++ b/diskann-benchmark/src/backend/index/result.rs @@ -109,6 +109,7 @@ impl std::fmt::Display for AggregatedSearchResults { #[derive(Debug, Serialize)] pub(super) struct SearchResults { pub(super) num_tasks: usize, + pub(super) num_queries: usize, pub(super) search_n: usize, pub(super) search_l: usize, pub(super) qps: Vec, @@ -143,6 +144,7 @@ impl SearchResults { Self { num_tasks: setup.tasks.into(), + num_queries: recall.num_queries, search_n: parameters.k_value, search_l: parameters.l_value, qps, @@ -182,6 +184,8 @@ where "p99 Latency", "Recall", "Threads", + "Queries", + "WallClock(s)", ] } else { &[ @@ -194,6 +198,8 @@ where "p99 Latency", "Recall", "Threads", + "Queries", + "WallClock(s)", ] }; @@ -237,6 +243,13 @@ where ); row.insert(format!("{:3}", r.recall.average), col_idx + 7); row.insert(r.num_tasks, col_idx + 8); + row.insert(r.num_queries, col_idx + 9); + let mean_wall_clock = if r.search_latencies.is_empty() { + 0.0 + } else { + r.search_latencies.iter().map(|t| t.as_seconds()).sum::() / r.search_latencies.len() as f64 + }; + row.insert(format!("{:.3}", mean_wall_clock), col_idx + 10); }); write!(f, "{}", table) diff --git a/diskann-benchmark/src/backend/mod.rs b/diskann-benchmark/src/backend/mod.rs index 24fe91d7e..5dc1967de 100644 --- a/diskann-benchmark/src/backend/mod.rs +++ b/diskann-benchmark/src/backend/mod.rs @@ -4,6 +4,7 @@ */ mod disk_index; +mod document_index; mod exhaustive; mod filters; mod index; @@ -13,4 +14,5 @@ pub(crate) fn register_benchmarks(registry: &mut diskann_benchmark_runner::regis disk_index::register_benchmarks(registry); index::register_benchmarks(registry); filters::register_benchmarks(registry); + document_index::register_benchmarks(registry); } diff --git a/diskann-benchmark/src/inputs/document_index.rs b/diskann-benchmark/src/inputs/document_index.rs new file mode 100644 index 000000000..b1a36e48a --- /dev/null +++ b/diskann-benchmark/src/inputs/document_index.rs @@ -0,0 +1,177 @@ +/* + * Copyright (c) Microsoft Corporation. + * Licensed under the MIT license. + */ + +//! Input types for document index benchmarks using DocumentInsertStrategy. + +use std::num::NonZeroUsize; + +use anyhow::Context; +use diskann_benchmark_runner::{ + files::InputFile, utils::datatype::DataType, CheckDeserialization, Checker, +}; +use serde::{Deserialize, Serialize}; + +use super::async_::GraphSearch; +use crate::inputs::{as_input, Example, Input}; + +////////////// +// Registry // +////////////// + +as_input!(DocumentIndexBuild); + +pub(super) fn register_inputs( + registry: &mut diskann_benchmark_runner::registry::Inputs, +) -> anyhow::Result<()> { + registry.register(Input::::new())?; + Ok(()) +} + +/// Build parameters for document index construction. 
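+///
+/// Mirrors the `build` block of the benchmark job file; the values below are
+/// taken from the bundled `example/document-filter.json`:
+///
+/// ```json
+/// {
+///     "data_type": "float32",
+///     "data": "disk_index_siftsmall_learn_256pts_data.fbin",
+///     "data_labels": "data.256.label.jsonl",
+///     "distance": "squared_l2",
+///     "max_degree": 32,
+///     "l_build": 50,
+///     "alpha": 1.2
+/// }
+/// ```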
+#[derive(Debug, Serialize, Deserialize)] +pub(crate) struct DocumentBuildParams { + pub(crate) data_type: DataType, + pub(crate) data: InputFile, + pub(crate) data_labels: InputFile, + pub(crate) distance: crate::utils::SimilarityMeasure, + pub(crate) max_degree: usize, + pub(crate) l_build: usize, + pub(crate) alpha: f32, + #[serde(default = "default_num_threads")] + pub(crate) num_threads: usize, +} + +fn default_num_threads() -> usize { + 1 +} + +impl CheckDeserialization for DocumentBuildParams { + fn check_deserialization(&mut self, checker: &mut Checker) -> Result<(), anyhow::Error> { + self.data.check_deserialization(checker)?; + self.data_labels.check_deserialization(checker)?; + if self.max_degree == 0 { + return Err(anyhow::anyhow!("max_degree must be > 0")); + } + if self.l_build == 0 { + return Err(anyhow::anyhow!("l_build must be > 0")); + } + if self.alpha <= 0.0 { + return Err(anyhow::anyhow!("alpha must be > 0")); + } + Ok(()) + } +} + +#[derive(Debug, Serialize, Deserialize)] +pub(crate) struct DocumentSearchParams { + pub(crate) queries: InputFile, + pub(crate) query_predicates: InputFile, + pub(crate) groundtruth: InputFile, + pub(crate) beta: f32, + #[serde(default = "default_reps")] + pub(crate) reps: NonZeroUsize, + #[serde(default = "default_thread_counts")] + pub(crate) num_threads: Vec, + pub(crate) runs: Vec, +} + +fn default_reps() -> NonZeroUsize { + NonZeroUsize::new(5).unwrap() +} +fn default_thread_counts() -> Vec { + vec![NonZeroUsize::new(1).unwrap()] +} + +impl CheckDeserialization for DocumentSearchParams { + fn check_deserialization(&mut self, checker: &mut Checker) -> Result<(), anyhow::Error> { + self.queries.check_deserialization(checker)?; + self.query_predicates.check_deserialization(checker)?; + self.groundtruth.check_deserialization(checker)?; + if self.beta <= 0.0 || self.beta > 1.0 { + return Err(anyhow::anyhow!( + "beta must be in range (0, 1], got: {}", + self.beta + )); + } + for (i, run) in self.runs.iter_mut().enumerate() { + run.check_deserialization(checker) + .with_context(|| format!("search run {}", i))?; + } + Ok(()) + } +} + +#[derive(Debug, Serialize, Deserialize)] +pub(crate) struct DocumentIndexBuild { + pub(crate) build: DocumentBuildParams, + pub(crate) search: DocumentSearchParams, +} + +impl DocumentIndexBuild { + pub(crate) const fn tag() -> &'static str { + "document-index-build" + } +} + +impl CheckDeserialization for DocumentIndexBuild { + fn check_deserialization(&mut self, checker: &mut Checker) -> Result<(), anyhow::Error> { + self.build.check_deserialization(checker)?; + self.search.check_deserialization(checker)?; + Ok(()) + } +} + +impl Example for DocumentIndexBuild { + fn example() -> Self { + Self { + build: DocumentBuildParams { + data_type: DataType::Float32, + data: InputFile::new("data.fbin"), + data_labels: InputFile::new("data.label.jsonl"), + distance: crate::utils::SimilarityMeasure::SquaredL2, + max_degree: 32, + l_build: 50, + alpha: 1.2, + num_threads: 1, + }, + search: DocumentSearchParams { + queries: InputFile::new("queries.fbin"), + query_predicates: InputFile::new("query.label.jsonl"), + groundtruth: InputFile::new("groundtruth.bin"), + beta: 0.5, + reps: default_reps(), + num_threads: default_thread_counts(), + runs: vec![GraphSearch { + search_n: 10, + search_l: vec![20, 30, 40, 50], + recall_k: 10, + }], + }, + } + } +} + +impl std::fmt::Display for DocumentIndexBuild { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + writeln!(f, "Document Index Build with Label 
Filters\n")?; + writeln!(f, "tag: \"{}\"", Self::tag())?; + writeln!( + f, + "\nBuild: data={}, labels={}, R={}, L={}, alpha={}", + self.build.data.display(), + self.build.data_labels.display(), + self.build.max_degree, + self.build.l_build, + self.build.alpha + )?; + writeln!( + f, + "Search: queries={}, beta={}", + self.search.queries.display(), + self.search.beta + )?; + Ok(()) + } +} diff --git a/diskann-benchmark/src/inputs/mod.rs b/diskann-benchmark/src/inputs/mod.rs index a0ae1a982..65de65a41 100644 --- a/diskann-benchmark/src/inputs/mod.rs +++ b/diskann-benchmark/src/inputs/mod.rs @@ -5,6 +5,7 @@ pub(crate) mod async_; pub(crate) mod disk; +pub(crate) mod document_index; pub(crate) mod exhaustive; pub(crate) mod filters; pub(crate) mod save_and_load; @@ -16,6 +17,7 @@ pub(crate) fn register_inputs( exhaustive::register_inputs(registry)?; disk::register_inputs(registry)?; filters::register_inputs(registry)?; + document_index::register_inputs(registry)?; Ok(()) } diff --git a/diskann-benchmark/src/utils/recall.rs b/diskann-benchmark/src/utils/recall.rs index dcbe86d94..50ef7e430 100644 --- a/diskann-benchmark/src/utils/recall.rs +++ b/diskann-benchmark/src/utils/recall.rs @@ -3,6 +3,7 @@ * Licensed under the MIT license. */ +pub(crate) use benchmark_core::recall::knn; use diskann_benchmark_core as benchmark_core; use serde::Serialize; diff --git a/diskann-benchmark/src/utils/tokio.rs b/diskann-benchmark/src/utils/tokio.rs index 72dbeb918..21c78abb2 100644 --- a/diskann-benchmark/src/utils/tokio.rs +++ b/diskann-benchmark/src/utils/tokio.rs @@ -3,6 +3,13 @@ * Licensed under the MIT license. */ +/// Create a generic multi-threaded runtime with `num_threads`. +pub(crate) fn runtime(num_threads: usize) -> anyhow::Result { + Ok(tokio::runtime::Builder::new_multi_thread() + .worker_threads(num_threads) + .build()?) +} + /// Create a current-thread runtime and block on the given future. 
/// Only for functions that don't need multi-threading pub(crate) fn block_on(future: F) -> F::Output { diff --git a/diskann-label-filter/src/document.rs b/diskann-label-filter/src/document.rs index 31cad4772..5c817525c 100644 --- a/diskann-label-filter/src/document.rs +++ b/diskann-label-filter/src/document.rs @@ -8,12 +8,12 @@ use diskann_utils::reborrow::Reborrow; ///Simple container class that clients can use to /// supply diskann with a vector and its attributes -pub struct Document<'a, V> { +pub struct Document<'a, V: ?Sized> { vector: &'a V, attributes: Vec, } -impl<'a, V> Document<'a, V> { +impl<'a, V: ?Sized> Document<'a, V> { pub fn new(vector: &'a V, attributes: Vec) -> Self { Self { vector, attributes } } diff --git a/diskann-label-filter/src/encoded_attribute_provider/ast_label_id_mapper.rs b/diskann-label-filter/src/encoded_attribute_provider/ast_label_id_mapper.rs index 0fa21cc02..8b39d8731 100644 --- a/diskann-label-filter/src/encoded_attribute_provider/ast_label_id_mapper.rs +++ b/diskann-label-filter/src/encoded_attribute_provider/ast_label_id_mapper.rs @@ -31,19 +31,14 @@ impl ASTLabelIdMapper { Self { attribute_map } } - fn _lookup( - encoder: &AttributeEncoder, - attribute: &Attribute, - field: &str, - op: &CompareOp, - ) -> ANNResult> { + fn _lookup(encoder: &AttributeEncoder, attribute: &Attribute) -> ANNResult> { match encoder.get(attribute) { Some(attribute_id) => Ok(ASTIdExpr::Terminal(attribute_id)), None => Err(ANNError::message( ANNErrorKind::Opaque, format!( - "{}+{} present in the query does not exist in the dataset.", - field, op + "{} present in the query does not exist in the dataset.", + attribute ), )), } @@ -120,10 +115,10 @@ impl ASTVisitor for ASTLabelIdMapper { if let Some(attribute) = label_or_none { match self.attribute_map.read() { - Ok(guard) => Self::_lookup(&guard, &attribute, field, op), + Ok(guard) => Self::_lookup(&guard, &attribute), Err(poison_error) => { let attr_map = poison_error.into_inner(); - Self::_lookup(&attr_map, &attribute, field, op) + Self::_lookup(&attr_map, &attribute) } } } else { diff --git a/diskann-label-filter/src/encoded_attribute_provider/document_insert_strategy.rs b/diskann-label-filter/src/encoded_attribute_provider/document_insert_strategy.rs new file mode 100644 index 000000000..850976a32 --- /dev/null +++ b/diskann-label-filter/src/encoded_attribute_provider/document_insert_strategy.rs @@ -0,0 +1,274 @@ +/* + * Copyright (c) Microsoft Corporation. All rights reserved. + * Licensed under the MIT license. + */ + +//! A strategy wrapper that enables insertion of [Document] objects into a +//! [DiskANNIndex] using a [DocumentProvider]. + +use std::marker::PhantomData; + +use diskann::{ + graph::{ + glue::{ + ExpandBeam, InsertStrategy, PruneStrategy, SearchExt, SearchPostProcess, SearchStrategy, + }, + SearchOutputBuffer, + }, + neighbor::Neighbor, + provider::{Accessor, BuildQueryComputer, DataProvider, DelegateNeighbor, HasId}, + ANNResult, +}; + +use super::document_provider::DocumentProvider; +use crate::document::Document; +use crate::encoded_attribute_provider::roaring_attribute_store::RoaringAttributeStore; + +/// A strategy wrapper that enables insertion of [Document] objects. 
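+///
+/// Search and prune are delegated to the wrapped `Inner` strategy against the
+/// provider's vector store, while the query computer is built from
+/// [Document::vector], so the attached attributes ride along and are persisted
+/// by the [DocumentProvider].
+///
+/// A minimal usage sketch, mirroring this change's benchmark code (`index`,
+/// `ctx`, `vector`, and `attrs` are assumed bindings):
+///
+/// ```ignore
+/// let strategy: DocumentInsertStrategy<_, [f32]> =
+///     DocumentInsertStrategy::new(FullPrecision);
+/// let doc = Document::new(vector, attrs);
+/// index.insert(strategy, &ctx, &id, &doc).await?;
+/// ```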
+pub struct DocumentInsertStrategy { + inner: Inner, + _phantom: PhantomData VT>, +} + +impl Clone for DocumentInsertStrategy { + fn clone(&self) -> Self { + Self { + inner: self.inner.clone(), + _phantom: PhantomData, + } + } +} + +impl Copy for DocumentInsertStrategy {} + +impl DocumentInsertStrategy { + pub fn new(inner: Inner) -> Self { + Self { + inner, + _phantom: PhantomData, + } + } + + pub fn inner(&self) -> &Inner { + &self.inner + } +} + +/// Wrapper accessor for Document queries +pub struct DocumentSearchAccessor { + inner: Inner, + _phantom: PhantomData VT>, +} + +impl DocumentSearchAccessor { + pub fn new(inner: Inner) -> Self { + Self { + inner, + _phantom: PhantomData, + } + } +} + +impl HasId for DocumentSearchAccessor +where + Inner: HasId, + VT: ?Sized, +{ + type Id = Inner::Id; +} + +impl Accessor for DocumentSearchAccessor +where + Inner: Accessor, + VT: ?Sized, +{ + type ElementRef<'a> = Inner::ElementRef<'a>; + type Element<'a> + = Inner::Element<'a> + where + Self: 'a; + type Extended = Inner::Extended; + type GetError = Inner::GetError; + + fn get_element( + &mut self, + id: Self::Id, + ) -> impl std::future::Future, Self::GetError>> + Send { + self.inner.get_element(id) + } + + fn on_elements_unordered( + &mut self, + itr: Itr, + f: F, + ) -> impl std::future::Future> + Send + where + Self: Sync, + Itr: Iterator + Send, + F: Send + for<'b> FnMut(Self::ElementRef<'b>, Self::Id), + { + self.inner.on_elements_unordered(itr, f) + } +} + +impl<'doc, Inner, VT> BuildQueryComputer> for DocumentSearchAccessor +where + Inner: BuildQueryComputer, + VT: ?Sized, +{ + type QueryComputerError = Inner::QueryComputerError; + type QueryComputer = Inner::QueryComputer; + + fn build_query_computer( + &self, + from: &Document<'doc, VT>, + ) -> Result { + self.inner.build_query_computer(from.vector()) + } +} + +impl<'this, Inner, VT> DelegateNeighbor<'this> for DocumentSearchAccessor +where + Inner: DelegateNeighbor<'this>, + VT: ?Sized, +{ + type Delegate = Inner::Delegate; + fn delegate_neighbor(&'this mut self) -> Self::Delegate { + self.inner.delegate_neighbor() + } +} + +impl<'doc, Inner, VT> ExpandBeam> for DocumentSearchAccessor +where + Inner: ExpandBeam, + VT: ?Sized, +{ +} + +impl SearchExt for DocumentSearchAccessor +where + Inner: SearchExt, + VT: ?Sized, +{ + fn starting_points( + &self, + ) -> impl std::future::Future>> + Send { + self.inner.starting_points() + } + fn terminate_early(&mut self) -> bool { + self.inner.terminate_early() + } +} + +#[derive(Debug, Default, Clone, Copy)] +pub struct CopyIdsForDocument; + +impl<'doc, A, VT> SearchPostProcess> for CopyIdsForDocument +where + A: BuildQueryComputer>, + VT: ?Sized, +{ + type Error = std::convert::Infallible; + + fn post_process( + &self, + _accessor: &mut A, + _query: &Document<'doc, VT>, + _computer: &>>::QueryComputer, + candidates: I, + output: &mut B, + ) -> impl std::future::Future> + Send + where + I: Iterator> + Send, + B: SearchOutputBuffer + Send + ?Sized, + { + let count = output.extend(candidates.map(|n| (n.id, n.distance))); + std::future::ready(Ok(count)) + } +} + +impl<'doc, Inner, DP, VT> + SearchStrategy>, Document<'doc, VT>> + for DocumentInsertStrategy +where + Inner: InsertStrategy, + DP: DataProvider, + VT: Sync + Send + ?Sized + 'static, +{ + type QueryComputer = Inner::QueryComputer; + type PostProcessor = CopyIdsForDocument; + type SearchAccessorError = Inner::SearchAccessorError; + type SearchAccessor<'a> = DocumentSearchAccessor, VT>; + + fn search_accessor<'a>( + &'a self, + provider: &'a 
DocumentProvider>, + context: &'a > as DataProvider>::Context, + ) -> Result, Self::SearchAccessorError> { + let inner_accessor = self + .inner + .search_accessor(provider.inner_provider(), context)?; + Ok(DocumentSearchAccessor::new(inner_accessor)) + } + + fn post_processor(&self) -> Self::PostProcessor { + CopyIdsForDocument + } +} + +impl<'doc, Inner, DP, VT> + InsertStrategy>, Document<'doc, VT>> + for DocumentInsertStrategy +where + Inner: InsertStrategy, + DP: DataProvider, + VT: Sync + Send + ?Sized + 'static, +{ + type PruneStrategy = DocumentPruneStrategy; + + fn prune_strategy(&self) -> Self::PruneStrategy { + DocumentPruneStrategy::new(self.inner.prune_strategy()) + } + + fn insert_search_accessor<'a>( + &'a self, + provider: &'a DocumentProvider>, + context: &'a > as DataProvider>::Context, + ) -> Result, Self::SearchAccessorError> { + let inner_accessor = self + .inner + .insert_search_accessor(provider.inner_provider(), context)?; + Ok(DocumentSearchAccessor::new(inner_accessor)) + } +} + +#[derive(Clone, Copy)] +pub struct DocumentPruneStrategy { + inner: Inner, +} + +impl DocumentPruneStrategy { + pub fn new(inner: Inner) -> Self { + Self { inner } + } +} + +impl PruneStrategy>> + for DocumentPruneStrategy +where + DP: DataProvider, + Inner: PruneStrategy, +{ + type DistanceComputer = Inner::DistanceComputer; + type PruneAccessor<'a> = Inner::PruneAccessor<'a>; + type PruneAccessorError = Inner::PruneAccessorError; + + fn prune_accessor<'a>( + &'a self, + provider: &'a DocumentProvider>, + context: &'a > as DataProvider>::Context, + ) -> Result, Self::PruneAccessorError> { + self.inner + .prune_accessor(provider.inner_provider(), context) + } +} diff --git a/diskann-label-filter/src/encoded_attribute_provider/document_provider.rs b/diskann-label-filter/src/encoded_attribute_provider/document_provider.rs index 6b496271b..1fabf5f54 100644 --- a/diskann-label-filter/src/encoded_attribute_provider/document_provider.rs +++ b/diskann-label-filter/src/encoded_attribute_provider/document_provider.rs @@ -77,7 +77,7 @@ impl<'a, VT, DP, AS> SetElement> for DocumentProvider where DP: DataProvider + Delete + SetElement, AS: AttributeStore + AsyncFriendly, - VT: Sync + Send, + VT: Sync + Send + ?Sized, { type SetError = ANNError; type Guard = >::Guard; diff --git a/diskann-label-filter/src/encoded_attribute_provider/encoded_filter_expr.rs b/diskann-label-filter/src/encoded_attribute_provider/encoded_filter_expr.rs index d56cb13c1..370ef25ae 100644 --- a/diskann-label-filter/src/encoded_attribute_provider/encoded_filter_expr.rs +++ b/diskann-label-filter/src/encoded_attribute_provider/encoded_filter_expr.rs @@ -5,8 +5,6 @@ use std::sync::{Arc, RwLock}; -use diskann::ANNResult; - use crate::{ encoded_attribute_provider::{ ast_id_expr::ASTIdExpr, ast_label_id_mapper::ASTLabelIdMapper, @@ -16,20 +14,21 @@ use crate::{ }; pub(crate) struct EncodedFilterExpr { - ast_id_expr: ASTIdExpr, + ast_id_expr: Option>, } impl EncodedFilterExpr { - pub fn new( - ast_expr: &ASTExpr, - attribute_map: Arc>, - ) -> ANNResult { + pub fn new(ast_expr: &ASTExpr, attribute_map: Arc>) -> Self { let mut mapper = ASTLabelIdMapper::new(attribute_map); - let ast_id_expr = ast_expr.accept(&mut mapper)?; - Ok(Self { ast_id_expr }) + match ast_expr.accept(&mut mapper) { + Ok(ast_id_expr) => Self { + ast_id_expr: Some(ast_id_expr), + }, + Err(_e) => Self { ast_id_expr: None }, + } } - pub(crate) fn encoded_filter_expr(&self) -> &ASTIdExpr { + pub(crate) fn encoded_filter_expr(&self) -> &Option> { &self.ast_id_expr } } 
diff --git a/diskann-label-filter/src/encoded_attribute_provider/roaring_attribute_store.rs b/diskann-label-filter/src/encoded_attribute_provider/roaring_attribute_store.rs index 6b82a68b1..c69589ba0 100644 --- a/diskann-label-filter/src/encoded_attribute_provider/roaring_attribute_store.rs +++ b/diskann-label-filter/src/encoded_attribute_provider/roaring_attribute_store.rs @@ -15,7 +15,7 @@ use diskann::{utils::VectorId, ANNError, ANNErrorKind, ANNResult}; use diskann_utils::future::AsyncFriendly; use std::sync::{Arc, RwLock}; -pub(crate) struct RoaringAttributeStore +pub struct RoaringAttributeStore where IT: VectorId + AsyncFriendly, { diff --git a/diskann-label-filter/src/inline_beta_search/encoded_document_accessor.rs b/diskann-label-filter/src/inline_beta_search/encoded_document_accessor.rs index 962d361d7..1def9a406 100644 --- a/diskann-label-filter/src/inline_beta_search/encoded_document_accessor.rs +++ b/diskann-label-filter/src/inline_beta_search/encoded_document_accessor.rs @@ -28,7 +28,7 @@ use crate::{ type AttrAccessor = EncodedAttributeAccessor::Id>>; -pub(crate) struct EncodedDocumentAccessor +pub struct EncodedDocumentAccessor where IA: HasId, { @@ -136,7 +136,7 @@ where Some(set) => Ok(set.into_owned()), None => Err(ANNError::message( ANNErrorKind::IndexError, - "No labels were found for vector", + format!("No labels were found for vector:{:?}", id), )), } })?; @@ -220,12 +220,20 @@ where .inner_accessor .build_query_computer(from.query()) .into_ann_result()?; - let id_query = EncodedFilterExpr::new(from.filter_expr(), self.attribute_map.clone())?; + let id_query = EncodedFilterExpr::new(from.filter_expr(), self.attribute_map.clone()); + let is_valid_filter = id_query.encoded_filter_expr().is_some(); + if !is_valid_filter { + tracing::warn!( + "Failed to convert {} into an id expr. This will now be an unfiltered search.", + from.filter_expr() + ); + } Ok(InlineBetaComputer::new( inner_computer, self.beta_value, id_query, + is_valid_filter, )) } } diff --git a/diskann-label-filter/src/inline_beta_search/inline_beta_filter.rs b/diskann-label-filter/src/inline_beta_search/inline_beta_filter.rs index b25b1746f..f03f36c12 100644 --- a/diskann-label-filter/src/inline_beta_search/inline_beta_filter.rs +++ b/diskann-label-filter/src/inline_beta_search/inline_beta_filter.rs @@ -28,6 +28,13 @@ pub struct InlineBetaStrategy { inner: Strategy, } +impl InlineBetaStrategy { + /// Create a new InlineBetaStrategy with the given beta value and inner strategy. + pub fn new(beta: f32, inner: Strategy) -> Self { + Self { beta, inner } + } +} + impl SearchStrategy>, FilteredQuery> for InlineBetaStrategy @@ -72,6 +79,7 @@ pub struct InlineBetaComputer { inner_computer: Inner, beta_value: f32, filter_expr: EncodedFilterExpr, + is_valid_filter: bool, //optimization to avoid evaluating empty predicates. 
 }
 
 impl<Inner> InlineBetaComputer<Inner> {
@@ -79,17 +87,23 @@
         inner_computer: Inner,
         beta_value: f32,
         filter_expr: EncodedFilterExpr,
+        is_valid_filter: bool,
     ) -> Self {
         Self {
             inner_computer,
             beta_value,
             filter_expr,
+            is_valid_filter,
         }
     }
 
     pub(crate) fn filter_expr(&self) -> &EncodedFilterExpr {
         &self.filter_expr
     }
+
+    pub(crate) fn is_valid_filter(&self) -> bool {
+        self.is_valid_filter
+    }
 }
 
 impl<Inner, Elem> PreprocessedDistanceFunction<EncodedDocument<'_, Elem>, f32>
@@ -101,22 +115,35 @@ where
         let (vec, attrs) = changing.destructure();
         let sim = self.inner_computer.evaluate_similarity(vec);
         let pred_eval = PredicateEvaluator::new(attrs);
-        match self.filter_expr.encoded_filter_expr().accept(&pred_eval) {
-            Ok(matched) => {
-                if matched {
-                    sim * self.beta_value
-                } else {
-                    sim
-                }
-            }
-            Err(_) => {
-                //TODO: If predicate evaluation fails, we are taking the approach that we will simply
-                //return the score returned by the inner computer, as though no predicate was specified.
-                tracing::warn!(
-                    "Predicate evaluation failed in OnlineBetaComputer::evaluate_similarity()"
-                );
-                sim
-            }
-        }
+        if self.is_valid_filter {
+            match self
+                .filter_expr
+                .encoded_filter_expr()
+                .as_ref()
+                .unwrap()
+                .accept(&pred_eval)
+            {
+                Ok(matched) => {
+                    if matched {
+                        sim * self.beta_value
+                    } else {
+                        sim
+                    }
+                }
+                Err(_) => {
+                    // If predicate evaluation fails for any reason, fall back
+                    // to the unfiltered score.
+                    tracing::warn!(
+                        "Predicate evaluation failed in InlineBetaComputer::evaluate_similarity()"
+                    );
+                    sim
+                }
+            }
+        } else {
+            // No valid filter was built for this query; a warning was already
+            // emitted when the query computer was constructed, so score as an
+            // unfiltered search.
+            sim
+        }
     }
 }
@@ -155,8 +182,16 @@ where
         let doc = accessor.get_element(candidate.id).await?;
         let pe = PredicateEvaluator::new(doc.attributes());
 
-        if computer.filter_expr().encoded_filter_expr().accept(&pe)? {
+        // An invalid filter falls back to an unfiltered search, so keep every
+        // candidate in that case instead of dropping them all.
+        if !computer.is_valid_filter()
+            || computer
+                .filter_expr()
+                .encoded_filter_expr()
+                .as_ref()
+                .unwrap()
+                .accept(&pe)?
+        {
             filtered_candidates.push(Neighbor::new(candidate.id, candidate.distance));
         }
     }
diff --git a/diskann-label-filter/src/lib.rs b/diskann-label-filter/src/lib.rs
index 106845f98..273475b15 100644
--- a/diskann-label-filter/src/lib.rs
+++ b/diskann-label-filter/src/lib.rs
@@ -40,6 +40,7 @@ pub mod encoded_attribute_provider {
     pub(crate) mod ast_id_expr;
     pub(crate) mod ast_label_id_mapper;
     pub(crate) mod attribute_encoder;
+    pub mod document_insert_strategy;
     pub mod document_provider;
     pub mod encoded_attribute_accessor;
     pub(crate) mod encoded_filter_expr;
diff --git a/diskann-providers/src/model/graph/provider/async_/inmem/full_precision.rs b/diskann-providers/src/model/graph/provider/async_/inmem/full_precision.rs
index e74419a46..9a48488fe 100644
--- a/diskann-providers/src/model/graph/provider/async_/inmem/full_precision.rs
+++ b/diskann-providers/src/model/graph/provider/async_/inmem/full_precision.rs
@@ -1,580 +1,638 @@
-/*
- * Copyright (c) Microsoft Corporation.
- * Licensed under the MIT license.
- */ - -use std::{collections::HashMap, fmt::Debug, future::Future}; - -use diskann::{ - ANNError, ANNResult, - graph::{ - SearchOutputBuffer, - glue::{ - self, ExpandBeam, FillSet, FilterStartPoints, InplaceDeleteStrategy, InsertStrategy, - PruneStrategy, SearchExt, SearchStrategy, - }, - }, - neighbor::Neighbor, - provider::{ - Accessor, BuildDistanceComputer, BuildQueryComputer, DefaultContext, DelegateNeighbor, - ExecutionContext, HasId, - }, - utils::{IntoUsize, VectorRepr}, -}; -use diskann_utils::future::AsyncFriendly; -use diskann_vector::{DistanceFunction, distance::Metric}; - -use crate::model::graph::{ - provider::async_::{ - FastMemoryVectorProviderAsync, SimpleNeighborProviderAsync, - common::{ - CreateVectorStore, FullPrecision, Internal, NoDeletes, NoStore, Panics, - PrefetchCacheLineLevel, SetElementHelper, - }, - inmem::DefaultProvider, - postprocess::{AsDeletionCheck, DeletionCheck, RemoveDeletedIdsAndCopy}, - }, - traits::AdHoc, -}; - -/// A type alias for the DefaultProvider with full-precision as the primary vector store. -pub type FullPrecisionProvider = - DefaultProvider, Q, D, Ctx>; - -/// The default full-precision vector store. -pub type FullPrecisionStore = FastMemoryVectorProviderAsync>; - -/// A default full-precision vector store provider. -#[derive(Clone)] -pub struct CreateFullPrecision { - dim: usize, - prefetch_cache_line_level: Option, - _phantom: std::marker::PhantomData, -} - -impl CreateFullPrecision -where - T: VectorRepr, -{ - /// Create a new full-precision vector store provider. - pub fn new(dim: usize, prefetch_cache_line_level: Option) -> Self { - Self { - dim, - prefetch_cache_line_level, - _phantom: std::marker::PhantomData, - } - } -} - -impl CreateVectorStore for CreateFullPrecision -where - T: VectorRepr, -{ - type Target = FullPrecisionStore; - fn create( - self, - max_points: usize, - metric: Metric, - prefetch_lookahead: Option, - ) -> Self::Target { - FullPrecisionStore::new( - max_points, - self.dim, - metric, - self.prefetch_cache_line_level, - prefetch_lookahead, - ) - } -} - -//////////////// -// SetElement // -//////////////// - -impl SetElementHelper for FullPrecisionStore -where - T: VectorRepr, -{ - /// Set the element at the given index. - fn set_element(&self, id: &u32, element: &[T]) -> Result<(), ANNError> { - unsafe { self.set_vector_sync(id.into_usize(), element) } - } -} - -////////////////// -// FullAccessor // -////////////////// - -/// An accessor for retrieving full-precision vectors from the `DefaultProvider`. -/// -/// This type implements the following traits: -/// -/// * [`Accessor`] for the [`DefaultProvider`]. -/// * [`ComputerAccessor`] for comparing full-precision distances. -/// * [`BuildQueryComputer`]. -pub struct FullAccessor<'a, T, Q, D, Ctx> -where - T: VectorRepr, -{ - /// The host provider. - provider: &'a FullPrecisionProvider, - - /// A buffer for resolving iterators given during bulk operations. - /// - /// The accessor reuses this allocation to amortize allocation cost over multiple bulk - /// operations. 
- id_buffer: Vec, -} - -impl GetFullPrecision for FullAccessor<'_, T, Q, D, Ctx> -where - T: VectorRepr, -{ - type Repr = T; - fn as_full_precision(&self) -> &FullPrecisionStore { - &self.provider.base_vectors - } -} - -impl HasId for FullAccessor<'_, T, Q, D, Ctx> -where - T: VectorRepr, -{ - type Id = u32; -} - -impl SearchExt for FullAccessor<'_, T, Q, D, Ctx> -where - T: VectorRepr, - Q: AsyncFriendly, - D: AsyncFriendly, - Ctx: ExecutionContext, -{ - fn starting_points(&self) -> impl Future>> { - std::future::ready(self.provider.starting_points()) - } -} - -impl<'a, T, Q, D, Ctx> FullAccessor<'a, T, Q, D, Ctx> -where - T: VectorRepr, - Q: AsyncFriendly, - D: AsyncFriendly, - Ctx: ExecutionContext, -{ - pub fn new(provider: &'a FullPrecisionProvider) -> Self { - Self { - provider, - id_buffer: Vec::new(), - } - } -} - -impl<'a, T, Q, D, Ctx> DelegateNeighbor<'a> for FullAccessor<'_, T, Q, D, Ctx> -where - T: VectorRepr, - Q: AsyncFriendly, - D: AsyncFriendly, - Ctx: ExecutionContext, -{ - type Delegate = &'a SimpleNeighborProviderAsync; - - fn delegate_neighbor(&'a mut self) -> Self::Delegate { - self.provider.neighbors() - } -} - -impl<'a, T, Q, D, Ctx> Accessor for FullAccessor<'a, T, Q, D, Ctx> -where - T: VectorRepr, - Q: AsyncFriendly, - D: AsyncFriendly, - Ctx: ExecutionContext, -{ - /// The extended element inherets the lifetime of the Accessor. - type Extended = &'a [T]; - - /// This accessor returns raw slices. There *is* a chance of racing when the fast - /// providers are used. We just have to live with it. - /// - /// NOTE: We intentionally don't use `'b` here since our implementation borrows - /// the inner `Opaque` from the underlying provider. - type Element<'b> - = &'a [T] - where - Self: 'b; - - /// `ElementRef` has an arbitrarily short lifetime. - type ElementRef<'b> = &'b [T]; - - /// Choose to panic on an out-of-bounds access rather than propagate an error. - type GetError = Panics; - - /// Return the full-precision vector stored at index `i`. - /// - /// This function always completes synchronously. - #[inline(always)] - fn get_element( - &mut self, - id: Self::Id, - ) -> impl Future, Self::GetError>> + Send { - // SAFETY: We've decided to live with UB (undefined behavior) that can result from - // potentially mixing unsynchronized reads and writes on the underlying memory. - std::future::ready(Ok(unsafe { - self.provider.base_vectors.get_vector_sync(id.into_usize()) - })) - } - - /// Perform a bulk operation. - /// - /// This implementation uses prefetching. - fn on_elements_unordered( - &mut self, - itr: Itr, - mut f: F, - ) -> impl Future> + Send - where - Self: Sync, - Itr: Iterator + Send, - F: Send + for<'b> FnMut(Self::ElementRef<'b>, Self::Id), - { - // Reuse the internal buffer to collect the results and give us random access - // capabilities. - let id_buffer = &mut self.id_buffer; - id_buffer.clear(); - id_buffer.extend(itr); - - let len = id_buffer.len(); - let lookahead = self.provider.base_vectors.prefetch_lookahead(); - - // Prefetch the first few vectors. - for id in id_buffer.iter().take(lookahead) { - self.provider.base_vectors.prefetch_hint(id.into_usize()); - } - - for (i, id) in id_buffer.iter().enumerate() { - // Prefetch `lookahead` iterations ahead as long as it is safe. - if lookahead > 0 && i + lookahead < len { - self.provider - .base_vectors - .prefetch_hint(id_buffer[i + lookahead].into_usize()); - } - - // Invoke the passed closure on the full-precision vector. 
- // - // SAFETY: We're accepting the consequences of potential unsynchronized, - // concurrent mutation. - f( - unsafe { self.provider.base_vectors.get_vector_sync(id.into_usize()) }, - *id, - ) - } - - std::future::ready(Ok(())) - } -} - -impl BuildDistanceComputer for FullAccessor<'_, T, Q, D, Ctx> -where - T: VectorRepr, - Q: AsyncFriendly, - D: AsyncFriendly, - Ctx: ExecutionContext, -{ - type DistanceComputerError = Panics; - type DistanceComputer = T::Distance; - - fn build_distance_computer( - &self, - ) -> Result { - Ok(T::distance( - self.provider.metric, - Some(self.provider.base_vectors.dim()), - )) - } -} - -impl BuildQueryComputer<[T]> for FullAccessor<'_, T, Q, D, Ctx> -where - T: VectorRepr, - Q: AsyncFriendly, - D: AsyncFriendly, - Ctx: ExecutionContext, -{ - type QueryComputerError = Panics; - type QueryComputer = T::QueryDistance; - - fn build_query_computer( - &self, - from: &[T], - ) -> Result { - Ok(T::query_distance(from, self.provider.metric)) - } -} - -impl ExpandBeam<[T]> for FullAccessor<'_, T, Q, D, Ctx> -where - T: VectorRepr, - Q: AsyncFriendly, - D: AsyncFriendly, - Ctx: ExecutionContext, -{ -} - -impl FillSet for FullAccessor<'_, T, Q, D, Ctx> -where - T: VectorRepr, - Q: AsyncFriendly, - D: AsyncFriendly, - Ctx: ExecutionContext, -{ - async fn fill_set( - &mut self, - set: &mut HashMap, - itr: Itr, - ) -> Result<(), Self::GetError> - where - Itr: Iterator + Send + Sync, - { - for i in itr { - set.entry(i).or_insert_with(|| unsafe { - self.provider.base_vectors.get_vector_sync(i.into_usize()) - }); - } - Ok(()) - } -} - -//-------------------// -// In-mem Extensions // -//-------------------// - -impl<'a, T, Q, D, Ctx> AsDeletionCheck for FullAccessor<'a, T, Q, D, Ctx> -where - T: VectorRepr, - Q: AsyncFriendly, - D: AsyncFriendly + DeletionCheck, - Ctx: ExecutionContext, -{ - type Checker = D; - fn as_deletion_check(&self) -> &D { - &self.provider.deleted - } -} - -////////////////// -// Post Process // -////////////////// - -pub trait GetFullPrecision { - type Repr: VectorRepr; - fn as_full_precision(&self) -> &FastMemoryVectorProviderAsync>; -} - -/// A [`SearchPostProcess`]or that: -/// -/// 1. Filters out deleted ids from being returned. -/// 2. Reranks a candidate stream using full-precision distances. -/// 3. Copies back the results to the output buffer. -#[derive(Debug, Default, Clone, Copy)] -pub struct Rerank; - -impl glue::SearchPostProcess for Rerank -where - T: VectorRepr, - A: BuildQueryComputer<[T], Id = u32> + GetFullPrecision + AsDeletionCheck, -{ - type Error = Panics; - - fn post_process( - &self, - accessor: &mut A, - query: &[T], - _computer: &A::QueryComputer, - candidates: I, - output: &mut B, - ) -> impl Future> + Send - where - I: Iterator>, - B: SearchOutputBuffer + ?Sized, - { - let full = accessor.as_full_precision(); - let checker = accessor.as_deletion_check(); - let f = full.distance(); - - // Filter before computing the full precision distances. - let mut reranked: Vec<(u32, f32)> = candidates - .filter_map(|n| { - if checker.deletion_check(n.id) { - None - } else { - Some(( - n.id, - f.evaluate_similarity(query, unsafe { - full.get_vector_sync(n.id.into_usize()) - }), - )) - } - }) - .collect(); - - // Sort the full precision distances. - reranked - .sort_unstable_by(|a, b| (a.1).partial_cmp(&b.1).unwrap_or(std::cmp::Ordering::Equal)); - // Store the reranked results. 
- std::future::ready(Ok(output.extend(reranked))) - } -} - -//////////////// -// Strategies // -//////////////// - -// A layered approach is used for search strategies. The `Internal` version does the heavy -// lifting in terms of establishing accessors and post processing. -// -// However, during post-processing, the `Internal` versions of strategies will not filter -// out the start points. The publicly exposed types *will* filter out the start points. -// -// This layered approach allows algorithms like `InplaceDeleteStrategy` that need to adjust -// the adjacency list for the start point to reuse the `Internal` strategies. - -/// Perform a search entirely in the full-precision space. -/// -/// Starting points are not filtered out of the final results. -impl SearchStrategy, [T]> - for Internal -where - T: VectorRepr, - Q: AsyncFriendly, - D: AsyncFriendly + DeletionCheck, - Ctx: ExecutionContext, -{ - type QueryComputer = T::QueryDistance; - type SearchAccessor<'a> = FullAccessor<'a, T, Q, D, Ctx>; - type SearchAccessorError = Panics; - type PostProcessor = RemoveDeletedIdsAndCopy; - - fn search_accessor<'a>( - &'a self, - provider: &'a FullPrecisionProvider, - _context: &'a Ctx, - ) -> Result, Self::SearchAccessorError> { - Ok(FullAccessor::new(provider)) - } - - fn post_processor(&self) -> Self::PostProcessor { - Default::default() - } -} - -/// Perform a search entirely in the full-precision space. -/// -/// Starting points are not filtered out of the final results. -impl SearchStrategy, [T]> for FullPrecision -where - T: VectorRepr, - Q: AsyncFriendly, - D: AsyncFriendly + DeletionCheck, - Ctx: ExecutionContext, -{ - type QueryComputer = T::QueryDistance; - type SearchAccessor<'a> = FullAccessor<'a, T, Q, D, Ctx>; - type SearchAccessorError = Panics; - type PostProcessor = glue::Pipeline; - - fn search_accessor<'a>( - &'a self, - provider: &'a FullPrecisionProvider, - _context: &'a Ctx, - ) -> Result, Self::SearchAccessorError> { - Ok(FullAccessor::new(provider)) - } - - fn post_processor(&self) -> Self::PostProcessor { - Default::default() - } -} - -// Pruning -impl PruneStrategy> for FullPrecision -where - T: VectorRepr, - Q: AsyncFriendly, - D: AsyncFriendly, - Ctx: ExecutionContext, -{ - type DistanceComputer = T::Distance; - type PruneAccessor<'a> = FullAccessor<'a, T, Q, D, Ctx>; - type PruneAccessorError = diskann::error::Infallible; - - fn prune_accessor<'a>( - &'a self, - provider: &'a FullPrecisionProvider, - _context: &'a Ctx, - ) -> Result, Self::PruneAccessorError> { - Ok(FullAccessor::new(provider)) - } -} - -/// Implementing this trait allows `FullPrecision` to be used for multi-insert. 
-impl<'a, T, Q, D, Ctx> glue::AsElement<&'a [T]> for FullAccessor<'a, T, Q, D, Ctx> -where - T: VectorRepr, - Q: AsyncFriendly, - D: AsyncFriendly, - Ctx: ExecutionContext, -{ - type Error = diskann::error::Infallible; - fn as_element( - &mut self, - vector: &'a [T], - _id: Self::Id, - ) -> impl Future, Self::Error>> + Send { - std::future::ready(Ok(vector)) - } -} - -impl InsertStrategy, [T]> for FullPrecision -where - T: VectorRepr, - Q: AsyncFriendly, - D: AsyncFriendly + DeletionCheck, - Ctx: ExecutionContext, -{ - type PruneStrategy = Self; - fn prune_strategy(&self) -> Self::PruneStrategy { - *self - } -} - -// Inplace Delete // -impl InplaceDeleteStrategy> for FullPrecision -where - T: VectorRepr, - Q: AsyncFriendly, - D: AsyncFriendly + DeletionCheck, - Ctx: ExecutionContext, -{ - type DeleteElementError = Panics; - type DeleteElement<'a> = [T]; - type DeleteElementGuard = Box<[T]>; - type PruneStrategy = Self; - type SearchStrategy = Internal; - fn search_strategy(&self) -> Self::SearchStrategy { - Internal(Self) - } - - fn prune_strategy(&self) -> Self::PruneStrategy { - Self - } - - async fn get_delete_element<'a>( - &'a self, - provider: &'a FullPrecisionProvider, - _context: &'a Ctx, - id: u32, - ) -> Result { - Ok(unsafe { provider.base_vectors.get_vector_sync(id.into_usize()) }.into()) - } -} +/* + * Copyright (c) Microsoft Corporation. + * Licensed under the MIT license. + */ + +use std::{collections::HashMap, fmt::Debug, future::Future}; + +use diskann::{ + ANNError, ANNResult, + graph::{ + SearchOutputBuffer, + glue::{ + self, ExpandBeam, FillSet, FilterStartPoints, InplaceDeleteStrategy, InsertStrategy, + PruneStrategy, SearchExt, SearchStrategy, + }, + }, + neighbor::Neighbor, + provider::{ + Accessor, BuildDistanceComputer, BuildQueryComputer, DefaultContext, DelegateNeighbor, + ExecutionContext, HasId, + }, + utils::{IntoUsize, VectorRepr}, +}; +use diskann_utils::future::AsyncFriendly; +use diskann_vector::{DistanceFunction, distance::Metric}; + +use crate::model::graph::{ + provider::async_::{ + FastMemoryVectorProviderAsync, SimpleNeighborProviderAsync, + common::{ + CreateVectorStore, FullPrecision, Internal, NoDeletes, NoStore, Panics, + PrefetchCacheLineLevel, SetElementHelper, + }, + inmem::DefaultProvider, + postprocess::{AsDeletionCheck, DeletionCheck, RemoveDeletedIdsAndCopy}, + }, + traits::AdHoc, +}; + +/// A type alias for the DefaultProvider with full-precision as the primary vector store. +pub type FullPrecisionProvider = + DefaultProvider, Q, D, Ctx>; + +/// The default full-precision vector store. +pub type FullPrecisionStore = FastMemoryVectorProviderAsync>; + +/// A default full-precision vector store provider. +#[derive(Clone)] +pub struct CreateFullPrecision { + dim: usize, + prefetch_cache_line_level: Option, + _phantom: std::marker::PhantomData, +} + +impl CreateFullPrecision +where + T: VectorRepr, +{ + /// Create a new full-precision vector store provider. 
+    pub fn new(dim: usize, prefetch_cache_line_level: Option) -> Self {
+        Self {
+            dim,
+            prefetch_cache_line_level,
+            _phantom: std::marker::PhantomData,
+        }
+    }
+}
+
+impl CreateVectorStore for CreateFullPrecision
+where
+    T: VectorRepr,
+{
+    type Target = FullPrecisionStore;
+    fn create(
+        self,
+        max_points: usize,
+        metric: Metric,
+        prefetch_lookahead: Option,
+    ) -> Self::Target {
+        FullPrecisionStore::new(
+            max_points,
+            self.dim,
+            metric,
+            self.prefetch_cache_line_level,
+            prefetch_lookahead,
+        )
+    }
+}
+
+////////////////
+// SetElement //
+////////////////
+
+impl SetElementHelper for FullPrecisionStore
+where
+    T: VectorRepr,
+{
+    /// Set the element at the given index.
+    fn set_element(&self, id: &u32, element: &[T]) -> Result<(), ANNError> {
+        unsafe { self.set_vector_sync(id.into_usize(), element) }
+    }
+}
+
+//////////////////
+// FullAccessor //
+//////////////////
+
+/// An accessor for retrieving full-precision vectors from the `DefaultProvider`.
+///
+/// This type implements the following traits:
+///
+/// * [`Accessor`] for the [`DefaultProvider`].
+/// * [`ComputerAccessor`] for comparing full-precision distances.
+/// * [`BuildQueryComputer`].
+pub struct FullAccessor<'a, T, Q, D, Ctx>
+where
+    T: VectorRepr,
+{
+    /// The host provider.
+    provider: &'a FullPrecisionProvider,
+
+    /// A buffer for resolving iterators given during bulk operations.
+    ///
+    /// The accessor reuses this allocation to amortize allocation cost over multiple bulk
+    /// operations.
+    id_buffer: Vec,
+}
+
+impl GetFullPrecision for FullAccessor<'_, T, Q, D, Ctx>
+where
+    T: VectorRepr,
+{
+    type Repr = T;
+    fn as_full_precision(&self) -> &FullPrecisionStore {
+        &self.provider.base_vectors
+    }
+}
+
+impl HasId for FullAccessor<'_, T, Q, D, Ctx>
+where
+    T: VectorRepr,
+{
+    type Id = u32;
+}
+
+impl SearchExt for FullAccessor<'_, T, Q, D, Ctx>
+where
+    T: VectorRepr,
+    Q: AsyncFriendly,
+    D: AsyncFriendly,
+    Ctx: ExecutionContext,
+{
+    fn starting_points(&self) -> impl Future>> {
+        std::future::ready(self.provider.starting_points())
+    }
+}
+
+impl<'a, T, Q, D, Ctx> FullAccessor<'a, T, Q, D, Ctx>
+where
+    T: VectorRepr,
+    Q: AsyncFriendly,
+    D: AsyncFriendly,
+    Ctx: ExecutionContext,
+{
+    pub fn new(provider: &'a FullPrecisionProvider) -> Self {
+        Self {
+            provider,
+            id_buffer: Vec::new(),
+        }
+    }
+}
+
+impl<'a, T, Q, D, Ctx> DelegateNeighbor<'a> for FullAccessor<'_, T, Q, D, Ctx>
+where
+    T: VectorRepr,
+    Q: AsyncFriendly,
+    D: AsyncFriendly,
+    Ctx: ExecutionContext,
+{
+    type Delegate = &'a SimpleNeighborProviderAsync;
+
+    fn delegate_neighbor(&'a mut self) -> Self::Delegate {
+        self.provider.neighbors()
+    }
+}
+
+impl<'a, T, Q, D, Ctx> Accessor for FullAccessor<'a, T, Q, D, Ctx>
+where
+    T: VectorRepr,
+    Q: AsyncFriendly,
+    D: AsyncFriendly,
+    Ctx: ExecutionContext,
+{
+    /// The extended element inherits the lifetime of the Accessor.
+    type Extended = &'a [T];
+
+    /// This accessor returns raw slices. There *is* a chance of racing when the fast
+    /// providers are used. We just have to live with it.
+    ///
+    /// NOTE: We intentionally don't use `'b` here since our implementation borrows
+    /// the inner `Opaque` from the underlying provider.
+    type Element<'b>
+        = &'a [T]
+    where
+        Self: 'b;
+
+    /// `ElementRef` has an arbitrarily short lifetime.
+    type ElementRef<'b> = &'b [T];
+
+    /// Choose to panic on an out-of-bounds access rather than propagate an error.
+    type GetError = Panics;
+
+    /// Return the full-precision vector stored at index `i`.
+ /// + /// This function always completes synchronously. + #[inline(always)] + fn get_element( + &mut self, + id: Self::Id, + ) -> impl Future, Self::GetError>> + Send { + // SAFETY: We've decided to live with UB (undefined behavior) that can result from + // potentially mixing unsynchronized reads and writes on the underlying memory. + std::future::ready(Ok(unsafe { + self.provider.base_vectors.get_vector_sync(id.into_usize()) + })) + } + + /// Perform a bulk operation. + /// + /// This implementation uses prefetching. + fn on_elements_unordered( + &mut self, + itr: Itr, + mut f: F, + ) -> impl Future> + Send + where + Self: Sync, + Itr: Iterator + Send, + F: Send + for<'b> FnMut(Self::ElementRef<'b>, Self::Id), + { + // Reuse the internal buffer to collect the results and give us random access + // capabilities. + let id_buffer = &mut self.id_buffer; + id_buffer.clear(); + id_buffer.extend(itr); + + let len = id_buffer.len(); + let lookahead = self.provider.base_vectors.prefetch_lookahead(); + + // Prefetch the first few vectors. + for id in id_buffer.iter().take(lookahead) { + self.provider.base_vectors.prefetch_hint(id.into_usize()); + } + + for (i, id) in id_buffer.iter().enumerate() { + // Prefetch `lookahead` iterations ahead as long as it is safe. + if lookahead > 0 && i + lookahead < len { + self.provider + .base_vectors + .prefetch_hint(id_buffer[i + lookahead].into_usize()); + } + + // Invoke the passed closure on the full-precision vector. + // + // SAFETY: We're accepting the consequences of potential unsynchronized, + // concurrent mutation. + f( + unsafe { self.provider.base_vectors.get_vector_sync(id.into_usize()) }, + *id, + ) + } + + std::future::ready(Ok(())) + } +} + +impl BuildDistanceComputer for FullAccessor<'_, T, Q, D, Ctx> +where + T: VectorRepr, + Q: AsyncFriendly, + D: AsyncFriendly, + Ctx: ExecutionContext, +{ + type DistanceComputerError = Panics; + type DistanceComputer = T::Distance; + + fn build_distance_computer( + &self, + ) -> Result { + Ok(T::distance( + self.provider.metric, + Some(self.provider.base_vectors.dim()), + )) + } +} + +impl BuildQueryComputer<[T]> for FullAccessor<'_, T, Q, D, Ctx> +where + T: VectorRepr, + Q: AsyncFriendly, + D: AsyncFriendly, + Ctx: ExecutionContext, +{ + type QueryComputerError = Panics; + type QueryComputer = T::QueryDistance; + + fn build_query_computer( + &self, + from: &[T], + ) -> Result { + Ok(T::query_distance(from, self.provider.metric)) + } +} + +impl ExpandBeam<[T]> for FullAccessor<'_, T, Q, D, Ctx> +where + T: VectorRepr, + Q: AsyncFriendly, + D: AsyncFriendly, + Ctx: ExecutionContext, +{ +} + +/// Support for Vec queries that delegates to the [T] impl via deref. +/// This allows InlineBetaStrategy to use Vec queries with FullAccessor. +impl BuildQueryComputer> for FullAccessor<'_, T, Q, D, Ctx> +where + T: VectorRepr, + Q: AsyncFriendly, + D: AsyncFriendly, + Ctx: ExecutionContext, +{ + type QueryComputerError = Panics; + type QueryComputer = T::QueryDistance; + + fn build_query_computer( + &self, + from: &Vec, + ) -> Result { + // Delegate to [T] impl via deref + Ok(T::query_distance(from.as_slice(), self.provider.metric)) + } +} + +/// Support for Vec queries that delegates to the [T] impl. 
+impl ExpandBeam> for FullAccessor<'_, T, Q, D, Ctx> +where + T: VectorRepr + Clone, + Q: AsyncFriendly, + D: AsyncFriendly, + Ctx: ExecutionContext, +{ +} + +impl FillSet for FullAccessor<'_, T, Q, D, Ctx> +where + T: VectorRepr, + Q: AsyncFriendly, + D: AsyncFriendly, + Ctx: ExecutionContext, +{ + async fn fill_set( + &mut self, + set: &mut HashMap, + itr: Itr, + ) -> Result<(), Self::GetError> + where + Itr: Iterator + Send + Sync, + { + for i in itr { + set.entry(i).or_insert_with(|| unsafe { + self.provider.base_vectors.get_vector_sync(i.into_usize()) + }); + } + Ok(()) + } +} + +//-------------------// +// In-mem Extensions // +//-------------------// + +impl<'a, T, Q, D, Ctx> AsDeletionCheck for FullAccessor<'a, T, Q, D, Ctx> +where + T: VectorRepr, + Q: AsyncFriendly, + D: AsyncFriendly + DeletionCheck, + Ctx: ExecutionContext, +{ + type Checker = D; + fn as_deletion_check(&self) -> &D { + &self.provider.deleted + } +} + +////////////////// +// Post Process // +////////////////// + +pub trait GetFullPrecision { + type Repr: VectorRepr; + fn as_full_precision(&self) -> &FastMemoryVectorProviderAsync>; +} + +/// A [`SearchPostProcess`]or that: +/// +/// 1. Filters out deleted ids from being returned. +/// 2. Reranks a candidate stream using full-precision distances. +/// 3. Copies back the results to the output buffer. +#[derive(Debug, Default, Clone, Copy)] +pub struct Rerank; + +impl glue::SearchPostProcess for Rerank +where + T: VectorRepr, + A: BuildQueryComputer<[T], Id = u32> + GetFullPrecision + AsDeletionCheck, +{ + type Error = Panics; + + fn post_process( + &self, + accessor: &mut A, + query: &[T], + _computer: &A::QueryComputer, + candidates: I, + output: &mut B, + ) -> impl Future> + Send + where + I: Iterator>, + B: SearchOutputBuffer + ?Sized, + { + let full = accessor.as_full_precision(); + let checker = accessor.as_deletion_check(); + let f = full.distance(); + + // Filter before computing the full precision distances. + let mut reranked: Vec<(u32, f32)> = candidates + .filter_map(|n| { + if checker.deletion_check(n.id) { + None + } else { + Some(( + n.id, + f.evaluate_similarity(query, unsafe { + full.get_vector_sync(n.id.into_usize()) + }), + )) + } + }) + .collect(); + + // Sort the full precision distances. + reranked + .sort_unstable_by(|a, b| (a.1).partial_cmp(&b.1).unwrap_or(std::cmp::Ordering::Equal)); + // Store the reranked results. + std::future::ready(Ok(output.extend(reranked))) + } +} + +//////////////// +// Strategies // +//////////////// + +// A layered approach is used for search strategies. The `Internal` version does the heavy +// lifting in terms of establishing accessors and post processing. +// +// However, during post-processing, the `Internal` versions of strategies will not filter +// out the start points. The publicly exposed types *will* filter out the start points. +// +// This layered approach allows algorithms like `InplaceDeleteStrategy` that need to adjust +// the adjacency list for the start point to reuse the `Internal` strategies. + +/// Perform a search entirely in the full-precision space. +/// +/// Starting points are not filtered out of the final results. 
+impl SearchStrategy, [T]> + for Internal +where + T: VectorRepr, + Q: AsyncFriendly, + D: AsyncFriendly + DeletionCheck, + Ctx: ExecutionContext, +{ + type QueryComputer = T::QueryDistance; + type SearchAccessor<'a> = FullAccessor<'a, T, Q, D, Ctx>; + type SearchAccessorError = Panics; + type PostProcessor = RemoveDeletedIdsAndCopy; + + fn search_accessor<'a>( + &'a self, + provider: &'a FullPrecisionProvider, + _context: &'a Ctx, + ) -> Result, Self::SearchAccessorError> { + Ok(FullAccessor::new(provider)) + } + + fn post_processor(&self) -> Self::PostProcessor { + Default::default() + } +} + +/// Perform a search entirely in the full-precision space. +/// +/// Starting points are not filtered out of the final results. +impl SearchStrategy, [T]> for FullPrecision +where + T: VectorRepr, + Q: AsyncFriendly, + D: AsyncFriendly + DeletionCheck, + Ctx: ExecutionContext, +{ + type QueryComputer = T::QueryDistance; + type SearchAccessor<'a> = FullAccessor<'a, T, Q, D, Ctx>; + type SearchAccessorError = Panics; + type PostProcessor = glue::Pipeline; + + fn search_accessor<'a>( + &'a self, + provider: &'a FullPrecisionProvider, + _context: &'a Ctx, + ) -> Result, Self::SearchAccessorError> { + Ok(FullAccessor::new(provider)) + } + + fn post_processor(&self) -> Self::PostProcessor { + Default::default() + } +} + +/// Support for Vec queries that delegates to the [T] impl. +/// This allows InlineBetaStrategy to use Vec queries with FullPrecision. +impl SearchStrategy, Vec> for FullPrecision +where + T: VectorRepr + Clone, + Q: AsyncFriendly, + D: AsyncFriendly + DeletionCheck, + Ctx: ExecutionContext, +{ + type QueryComputer = T::QueryDistance; + type SearchAccessor<'a> = FullAccessor<'a, T, Q, D, Ctx>; + type SearchAccessorError = Panics; + type PostProcessor = glue::Pipeline; + + fn search_accessor<'a>( + &'a self, + provider: &'a FullPrecisionProvider, + _context: &'a Ctx, + ) -> Result, Self::SearchAccessorError> { + Ok(FullAccessor::new(provider)) + } + + fn post_processor(&self) -> Self::PostProcessor { + Default::default() + } +} + +// Pruning +impl PruneStrategy> for FullPrecision +where + T: VectorRepr, + Q: AsyncFriendly, + D: AsyncFriendly, + Ctx: ExecutionContext, +{ + type DistanceComputer = T::Distance; + type PruneAccessor<'a> = FullAccessor<'a, T, Q, D, Ctx>; + type PruneAccessorError = diskann::error::Infallible; + + fn prune_accessor<'a>( + &'a self, + provider: &'a FullPrecisionProvider, + _context: &'a Ctx, + ) -> Result, Self::PruneAccessorError> { + Ok(FullAccessor::new(provider)) + } +} + +/// Implementing this trait allows `FullPrecision` to be used for multi-insert. 
+impl<'a, T, Q, D, Ctx> glue::AsElement<&'a [T]> for FullAccessor<'a, T, Q, D, Ctx> +where + T: VectorRepr, + Q: AsyncFriendly, + D: AsyncFriendly, + Ctx: ExecutionContext, +{ + type Error = diskann::error::Infallible; + fn as_element( + &mut self, + vector: &'a [T], + _id: Self::Id, + ) -> impl Future, Self::Error>> + Send { + std::future::ready(Ok(vector)) + } +} + +impl InsertStrategy, [T]> for FullPrecision +where + T: VectorRepr, + Q: AsyncFriendly, + D: AsyncFriendly + DeletionCheck, + Ctx: ExecutionContext, +{ + type PruneStrategy = Self; + fn prune_strategy(&self) -> Self::PruneStrategy { + *self + } +} + +// Inplace Delete // +impl InplaceDeleteStrategy> for FullPrecision +where + T: VectorRepr, + Q: AsyncFriendly, + D: AsyncFriendly + DeletionCheck, + Ctx: ExecutionContext, +{ + type DeleteElementError = Panics; + type DeleteElement<'a> = [T]; + type DeleteElementGuard = Box<[T]>; + type PruneStrategy = Self; + type SearchStrategy = Internal; + fn search_strategy(&self) -> Self::SearchStrategy { + Internal(Self) + } + + fn prune_strategy(&self) -> Self::PruneStrategy { + Self + } + + async fn get_delete_element<'a>( + &'a self, + provider: &'a FullPrecisionProvider, + _context: &'a Ctx, + id: u32, + ) -> Result { + Ok(unsafe { provider.base_vectors.get_vector_sync(id.into_usize()) }.into()) + } +} diff --git a/diskann-tools/Cargo.toml b/diskann-tools/Cargo.toml index ae987dca9..1b4b3408e 100644 --- a/diskann-tools/Cargo.toml +++ b/diskann-tools/Cargo.toml @@ -5,14 +5,13 @@ version.workspace = true authors.workspace = true description.workspace = true documentation.workspace = true -license.workspace = true # See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html [dependencies] byteorder.workspace = true clap = { workspace = true, features = ["derive"] } -diskann-providers = { workspace = true, default-features = false } # see `linalg/Cargo.toml` +diskann-providers = { workspace = true, default-features = false } # see `linalg/Cargo.toml` diskann-vector = { workspace = true } diskann-disk = { workspace = true } diskann-utils = { workspace = true } @@ -24,6 +23,7 @@ ordered-float = "4.2.0" rand_distr.workspace = true rand.workspace = true serde = { workspace = true, features = ["derive"] } +serde_json.workspace = true bincode.workspace = true opentelemetry.workspace = true diskann-quantization = { workspace = true } diff --git a/diskann-tools/src/utils/ground_truth.rs b/diskann-tools/src/utils/ground_truth.rs index e96f7ae8f..8c2fa29f6 100644 --- a/diskann-tools/src/utils/ground_truth.rs +++ b/diskann-tools/src/utils/ground_truth.rs @@ -4,7 +4,7 @@ */ use bit_set::BitSet; -use diskann_label_filter::{eval_query_expr, read_and_parse_queries, read_baselabels}; +use diskann_label_filter::{eval_query_expr, read_and_parse_queries, read_baselabels, ASTExpr}; use std::{io::Write, mem::size_of, str::FromStr}; @@ -25,18 +25,99 @@ use diskann_utils::views::Matrix; use diskann_vector::{distance::Metric, DistanceFunction}; use itertools::Itertools; use rayon::prelude::*; +use serde_json::{Map, Value}; use crate::utils::{search_index_utils, CMDResult, CMDToolError}; +/// Expands a JSON object with array-valued fields into multiple objects with scalar values. +/// For example: {"country": ["AU", "NZ"], "year": 2007} +/// becomes: [{"country": "AU", "year": 2007}, {"country": "NZ", "year": 2007}] +/// +/// If multiple fields have arrays, all combinations are generated. 
+fn expand_array_fields(value: &Value) -> Vec { + match value { + Value::Object(map) => { + // Start with a single empty object + let mut results: Vec> = vec![Map::new()]; + + for (key, val) in map.iter() { + if let Value::Array(arr) = val { + // Expand: for each existing result, create copies for each array element + let mut new_results: Vec> = Vec::new(); + for existing in results.iter() { + for item in arr.iter() { + let mut new_map: Map = existing.clone(); + new_map.insert(key.clone(), item.clone()); + new_results.push(new_map); + } + } + // If array is empty, keep existing results without this key + if !arr.is_empty() { + results = new_results; + } + } else { + // Non-array field: add to all existing results + for existing in results.iter_mut() { + existing.insert(key.clone(), val.clone()); + } + } + } + + results.into_iter().map(Value::Object).collect() + } + // If not an object, return as-is + _ => vec![value.clone()], + } +} + +/// Evaluates a query expression against a label, expanding array fields first. +/// Returns true if any expanded variant matches the query. +fn eval_query_with_array_expansion(query_expr: &ASTExpr, label: &Value) -> bool { + let expanded = expand_array_fields(label); + expanded + .iter() + .any(|item| eval_query_expr(query_expr, item)) +} + pub fn read_labels_and_compute_bitmap( base_label_filename: &str, query_label_filename: &str, ) -> CMDResult> { // Read base labels let base_labels = read_baselabels(base_label_filename)?; + tracing::info!( + "Loaded {} base labels from {}", + base_labels.len(), + base_label_filename + ); + + // Print first few base labels for debugging + for (i, label) in base_labels.iter().take(3).enumerate() { + tracing::debug!( + "Base label sample [{}]: doc_id={}, label={}", + i, + label.doc_id, + label.label + ); + } // Parse queries and evaluate against labels let parsed_queries = read_and_parse_queries(query_label_filename)?; + tracing::info!( + "Loaded {} queries from {}", + parsed_queries.len(), + query_label_filename + ); + + // Print first few queries for debugging + for (i, (query_id, query_expr)) in parsed_queries.iter().take(3).enumerate() { + tracing::debug!( + "Query sample [{}]: query_id={}, expr={:?}", + i, + query_id, + query_expr + ); + } // using the global threadpool is fine here #[allow(clippy::disallowed_methods)] @@ -45,7 +126,17 @@ pub fn read_labels_and_compute_bitmap( .map(|(_query_id, query_expr)| { let mut bitmap = BitSet::new(); for base_label in base_labels.iter() { - if eval_query_expr(query_expr, &base_label.label) { + // Handle case where base_label.label is an array - check if any element matches + // Also expand array-valued fields within objects (e.g., {"country": ["AU", "NZ"]}) + let matches = if let Some(array) = base_label.label.as_array() { + array + .iter() + .any(|item| eval_query_with_array_expansion(query_expr, item)) + } else { + eval_query_with_array_expansion(query_expr, &base_label.label) + }; + + if matches { bitmap.insert(base_label.doc_id); } } @@ -53,6 +144,44 @@ pub fn read_labels_and_compute_bitmap( }) .collect(); + // Debug: Print match statistics for each query + let total_matches: usize = query_bitmaps.iter().map(|b| b.len()).sum(); + let queries_with_matches = query_bitmaps.iter().filter(|b| !b.is_empty()).count(); + tracing::info!( + "Filter matching summary: {} total matches across {} queries ({} queries have matches)", + total_matches, + query_bitmaps.len(), + queries_with_matches + ); + + // Print per-query match counts + for (i, bitmap) in query_bitmaps.iter().enumerate() 
{ + if i < 10 || bitmap.is_empty() { + tracing::debug!( + "Query {}: {} base vectors matched the filter", + i, + bitmap.len() + ); + } + } + + // If no matches, print more diagnostic info + if total_matches == 0 { + tracing::warn!("WARNING: No base vectors matched any query filters!"); + tracing::warn!( + "This could indicate a format mismatch between base labels and query filters." + ); + + // Try to identify what keys exist in base labels vs queries + if let Some(first_label) = base_labels.first() { + tracing::warn!( + "First base label (full): doc_id={}, label={}", + first_label.doc_id, + first_label.label + ); + } + } + Ok(query_bitmaps) } @@ -195,6 +324,47 @@ pub fn compute_ground_truth_from_datafiles< assert_ne!(ground_truth.len(), 0, "No ground-truth results computed"); + // Debug: Print top K matches for each query + tracing::info!( + "Ground truth computed for {} queries with recall_at={}", + ground_truth.len(), + recall_at + ); + for (query_idx, npq) in ground_truth.iter().enumerate() { + let neighbors: Vec<_> = npq.iter().collect(); + let neighbor_count = neighbors.len(); + + if query_idx < 10 { + // Print top K IDs and distances for first 10 queries + let top_ids: Vec = neighbors.iter().take(10).map(|n| n.id).collect(); + let top_dists: Vec = neighbors.iter().take(10).map(|n| n.distance).collect(); + tracing::debug!( + "Query {}: {} neighbors found. Top IDs: {:?}, Top distances: {:?}", + query_idx, + neighbor_count, + top_ids, + top_dists + ); + } + + if neighbor_count == 0 { + tracing::warn!("Query {} has 0 neighbors in ground truth!", query_idx); + } + } + + // Summary stats + let total_neighbors: usize = ground_truth.iter().map(|npq| npq.iter().count()).sum(); + let queries_with_neighbors = ground_truth + .iter() + .filter(|npq| npq.iter().count() > 0) + .count(); + tracing::info!( + "Ground truth summary: {} total neighbors, {} queries have neighbors, {} queries have 0 neighbors", + total_neighbors, + queries_with_neighbors, + ground_truth.len() - queries_with_neighbors + ); + if has_vector_filters || has_query_bitmaps { let ground_truth_collection = ground_truth .into_iter() diff --git a/test_data/disk_index_search/data.256.label.jsonl b/test_data/disk_index_search/data.256.label.jsonl index 83254af7b..a99cde8e2 100644 --- a/test_data/disk_index_search/data.256.label.jsonl +++ b/test_data/disk_index_search/data.256.label.jsonl @@ -1,3 +1,3 @@ version https://git-lfs.github.com/spec/v1 -oid sha256:7f8b6b99ca32173557689712d3fb5da30c5e4111130fd2accbccf32f5ce3e47e -size 17702 +oid sha256:92576896b10780a2cd80a16030f8384610498b76453f57fadeacb854379e0acf +size 17701
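
Note on the inline_beta_filter.rs hunks above: the filter does not hard-reject non-matching candidates during graph traversal; it re-weights their scores so that matching documents rank ahead while the search frontier still explores the graph. Below is a minimal, self-contained sketch of that rule. The names here (FilterOutcome, beta_weighted_score) are illustrative stand-ins, not the crate's API, and it assumes a distance-like score where smaller is better and 0 < beta < 1.

// Minimal sketch of the beta re-weighting rule; names are illustrative
// stand-ins, not the crate's API.

/// Outcome of evaluating a query's filter against a document's attributes.
enum FilterOutcome {
    /// No valid filter was supplied for this query.
    NoFilter,
    /// The predicate evaluated successfully to the contained value.
    Evaluated(bool),
    /// Predicate evaluation failed; degrade to unfiltered scoring.
    Failed,
}

/// Re-weight a raw similarity score: matching documents have their distance
/// shrunk by `beta`, so they are visited and ranked ahead of non-matching
/// documents without hard-rejecting anything during traversal.
fn beta_weighted_score(sim: f32, outcome: FilterOutcome, beta: f32) -> f32 {
    match outcome {
        FilterOutcome::Evaluated(true) => sim * beta,
        // Non-matching, missing, and failed filters all fall back to the
        // unfiltered score rather than discarding the candidate.
        _ => sim,
    }
}

fn main() {
    let beta = 0.5;
    // A matching document at distance 1.0 now ranks like a non-matching
    // document at distance 0.5.
    assert_eq!(beta_weighted_score(1.0, FilterOutcome::Evaluated(true), beta), 0.5);
    assert_eq!(beta_weighted_score(1.0, FilterOutcome::Evaluated(false), beta), 1.0);
    assert_eq!(beta_weighted_score(1.0, FilterOutcome::NoFilter, beta), 1.0);
    assert_eq!(beta_weighted_score(1.0, FilterOutcome::Failed, beta), 1.0);
}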
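
Note on the full_precision.rs additions: the new BuildQueryComputer, ExpandBeam, and SearchStrategy impls for Vec-typed queries exist only to forward owned vectors to the existing slice impls, so InlineBetaStrategy callers holding owned queries need no conversion at the call site. Here is a generic sketch of that delegation pattern, using a hypothetical stand-in trait (BuildComputer) rather than the real diskann traits.

// Generic sketch of the "delegate Vec<T> queries to the existing [T] impl"
// pattern; `BuildComputer` is a hypothetical stand-in, not the diskann
// `BuildQueryComputer` trait.
trait BuildComputer<Q: ?Sized> {
    type Computer;
    fn build(&self, query: &Q) -> Self::Computer;
}

struct L2;
struct L2Computer(Vec<f32>);

// The slice impl owns the real logic.
impl BuildComputer<[f32]> for L2 {
    type Computer = L2Computer;
    fn build(&self, query: &[f32]) -> L2Computer {
        L2Computer(query.to_vec())
    }
}

// The owned-query impl just derefs to a slice and forwards.
impl BuildComputer<Vec<f32>> for L2 {
    type Computer = L2Computer;
    fn build(&self, query: &Vec<f32>) -> L2Computer {
        <L2 as BuildComputer<[f32]>>::build(self, query.as_slice())
    }
}

fn main() {
    let owned: Vec<f32> = vec![1.0, 2.0];
    let _a = <L2 as BuildComputer<Vec<f32>>>::build(&L2, &owned);
    let _b = <L2 as BuildComputer<[f32]>>::build(&L2, owned.as_slice());
}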
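
Note on expand_array_fields in ground_truth.rs: array-valued fields are expanded into the cartesian product of scalar-valued objects before predicate evaluation, so a label like {"country": ["AU", "NZ"], "year": 2007} matches a query if either expanded variant does. The standalone program below mirrors the patch's logic; it is a simplified reimplementation for illustration, not the crate's code.

// Mirrors the array-field expansion added in ground_truth.rs (simplified).
use serde_json::{json, Map, Value};

fn expand_array_fields(value: &Value) -> Vec<Value> {
    let Value::Object(map) = value else {
        // Non-objects have nothing to expand.
        return vec![value.clone()];
    };

    // Start with a single empty object and take a cartesian product over
    // every array-valued field.
    let mut results: Vec<Map<String, Value>> = vec![Map::new()];
    for (key, val) in map {
        if let Value::Array(arr) = val {
            if arr.is_empty() {
                // Empty arrays contribute nothing; drop the key entirely.
                continue;
            }
            let mut expanded = Vec::with_capacity(results.len() * arr.len());
            for existing in &results {
                for item in arr {
                    let mut m = existing.clone();
                    m.insert(key.clone(), item.clone());
                    expanded.push(m);
                }
            }
            results = expanded;
        } else {
            // Scalar field: copy it into every variant produced so far.
            for existing in &mut results {
                existing.insert(key.clone(), val.clone());
            }
        }
    }
    results.into_iter().map(Value::Object).collect()
}

fn main() {
    let label = json!({ "country": ["AU", "NZ"], "year": 2007 });
    for variant in expand_array_fields(&label) {
        println!("{variant}");
    }
    // Prints two objects:
    //   {"country":"AU","year":2007}
    //   {"country":"NZ","year":2007}
}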
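
Note on read_labels_and_compute_bitmap: each query produces one BitSet with a bit set for every base document whose label satisfies the query's filter; these bitmaps then restrict the ground-truth neighbor computation. A minimal sketch of that construction using the bit-set crate follows; the `matches` closure stands in for the real eval_query_with_array_expansion call over parsed labels.

// Sketch of the per-query bitmap construction (illustrative only).
use bit_set::BitSet;

fn bitmap_for_query(matches: impl Fn(usize) -> bool, num_docs: usize) -> BitSet {
    let mut bitmap = BitSet::new();
    for doc_id in 0..num_docs {
        if matches(doc_id) {
            // Record that this base document passes the query's filter.
            bitmap.insert(doc_id);
        }
    }
    bitmap
}

fn main() {
    // Hypothetical filter: even doc ids match.
    let bitmap = bitmap_for_query(|id| id % 2 == 0, 8);
    assert_eq!(bitmap.len(), 4);
    assert!(bitmap.contains(0) && !bitmap.contains(1));
}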