From e3e1651dfa36faf1d5969a43c48b579e29698b0c Mon Sep 17 00:00:00 2001 From: Gopal Srinivasa Date: Tue, 10 Feb 2026 10:10:30 +0530 Subject: [PATCH 1/3] Sync changes from CDB_DiskANN repo - Refactored recall utilities in diskann-benchmark - Updated tokio utilities - Added attribute and format parser improvements in label-filter - Updated ground_truth utilities in diskann-tools --- diskann-benchmark/src/utils/recall.rs | 703 +----------------- diskann-benchmark/src/utils/tokio.rs | 20 +- diskann-label-filter/src/attribute.rs | 1 + diskann-label-filter/src/parser/format.rs | 2 + .../src/utils/flatten_utils.rs | 2 +- diskann-tools/Cargo.toml | 18 +- diskann-tools/src/utils/ground_truth.rs | 161 +++- 7 files changed, 196 insertions(+), 711 deletions(-) diff --git a/diskann-benchmark/src/utils/recall.rs b/diskann-benchmark/src/utils/recall.rs index 5b7fd1594..bfaf46772 100644 --- a/diskann-benchmark/src/utils/recall.rs +++ b/diskann-benchmark/src/utils/recall.rs @@ -3,15 +3,13 @@ * Licensed under the MIT license. */ -use std::{collections::HashSet, hash::Hash}; - -use diskann_utils::strided::StridedView; -use diskann_utils::views::MatrixView; +use diskann_benchmark_core as benchmark_core; +pub(crate) use benchmark_core::recall::knn; use serde::Serialize; -use thiserror::Error; -#[derive(Debug, Serialize)] +#[derive(Debug, Clone, Serialize)] +#[non_exhaustive] pub(crate) struct RecallMetrics { /// The `k` value for `k-recall-at-n`. pub(crate) recall_k: usize, @@ -25,278 +23,19 @@ pub(crate) struct RecallMetrics { pub(crate) minimum: usize, /// The maximum observed recall (max possible value: `recall_k`). pub(crate) maximum: usize, - /// Recall results by query - pub(crate) by_query: Option>, -} - -// impl RecallMetrics { -// pub(crate) fn num_queries(&self) -> usize { -// self.num_queries -// } - -// pub(crate) fn average(&self) -> f64 { -// self.average -// } -// } - -#[derive(Debug, Error)] -pub(crate) enum ComputeRecallError { - #[error("results matrix has {0} rows but ground truth has {1}")] - RowsMismatch(usize, usize), - #[error("distances matrix has {0} rows but ground truth has {1}")] - DistanceRowsMismatch(usize, usize), - #[error("recall k value {0} must be less than or equal to recall n {1}")] - RecallKAndNError(usize, usize), - #[error("number of results per query {0} must be at least the specified recall k {1}")] - NotEnoughResults(usize, usize), - #[error( - "number of groundtruth values per query {0} must be at least the specified recall n {1}" - )] - NotEnoughGroundTruth(usize, usize), - #[error("number of groundtruth distances {0} does not match groundtruth entries {1}")] - NotEnoughGroundTruthDistances(usize, usize), -} - -pub(crate) trait ComputeKnnRecall { - fn compute_knn_recall( - &self, - groundtruth_distances: Option>, - results: StridedView<'_, T>, - recall_k: usize, - recall_n: usize, - allow_insufficient_results: bool, - enhanced_metrics: bool, - ) -> Result; -} - -impl ComputeKnnRecall for MatrixView<'_, T> -where - T: Eq + Hash + Copy + std::fmt::Debug, -{ - fn compute_knn_recall( - &self, - groundtruth_distances: Option>, - results: StridedView<'_, T>, - recall_k: usize, - recall_n: usize, - allow_insufficient_results: bool, - enhanced_metrics: bool, - ) -> Result { - compute_knn_recall( - self, - groundtruth_distances, - results, - recall_k, - recall_n, - allow_insufficient_results, - enhanced_metrics, - ) - } -} - -impl ComputeKnnRecall for Vec> -where - T: Eq + Hash + Copy + std::fmt::Debug, -{ - fn compute_knn_recall( - &self, - groundtruth_distances: Option>, - 
results: StridedView<'_, T>, - recall_k: usize, - recall_n: usize, - allow_insufficient_results: bool, - enhanced_metrics: bool, - ) -> Result { - compute_knn_recall( - self, - groundtruth_distances, - results, - recall_k, - recall_n, - allow_insufficient_results, - enhanced_metrics, - ) - } -} - -pub(crate) trait KnnRecall { - type Item; - - fn nrows(&self) -> usize; - fn ncols(&self) -> Option; - fn row(&self, i: usize) -> &[Self::Item]; -} - -impl KnnRecall for MatrixView<'_, T> { - type Item = T; - - fn nrows(&self) -> usize { - MatrixView::<'_, T>::nrows(self) - } - fn ncols(&self) -> Option { - Some(MatrixView::<'_, T>::ncols(self)) - } - fn row(&self, i: usize) -> &[Self::Item] { - MatrixView::<'_, T>::row(self, i) - } -} - -impl KnnRecall for Vec> { - type Item = T; - - fn nrows(&self) -> usize { - self.len() - } - fn ncols(&self) -> Option { - None - } - fn row(&self, i: usize) -> &[Self::Item] { - &self[i] - } } -impl KnnRecall for StridedView<'_, T> { - type Item = T; - - fn nrows(&self) -> usize { - StridedView::<'_, T>::nrows(self) - } - fn ncols(&self) -> Option { - Some(StridedView::<'_, T>::ncols(self)) - } - fn row(&self, i: usize) -> &[Self::Item] { - StridedView::<'_, T>::row(self, i) - } -} - -fn compute_knn_recall( - groundtruth: &K, - groundtruth_distances: Option>, - results: StridedView<'_, T>, - recall_k: usize, - recall_n: usize, - allow_insufficient_results: bool, - enhanced_metrics: bool, -) -> Result -where - T: Eq + Hash + Copy + std::fmt::Debug, - K: KnnRecall, -{ - if recall_k > recall_n { - return Err(ComputeRecallError::RecallKAndNError(recall_k, recall_n)); - } - - let nrows = results.nrows(); - if nrows != groundtruth.nrows() { - return Err(ComputeRecallError::RowsMismatch(nrows, groundtruth.nrows())); - } - - if results.ncols() < recall_n && !allow_insufficient_results { - return Err(ComputeRecallError::NotEnoughResults( - results.ncols(), - recall_n, - )); - } - - // Validate groundtruth size for fixed-size sources - match groundtruth.ncols() { - Some(ncols) if ncols < recall_k => { - return Err(ComputeRecallError::NotEnoughGroundTruth(ncols, recall_k)); - } - _ => {} - } - - if let Some(distances) = groundtruth_distances { - if nrows != distances.nrows() { - return Err(ComputeRecallError::DistanceRowsMismatch( - distances.nrows(), - nrows, - )); - } - - match groundtruth.ncols() { - Some(ncols) if distances.ncols() != ncols => { - return Err(ComputeRecallError::NotEnoughGroundTruthDistances( - distances.ncols(), - ncols, - )); - } - _ => {} +impl From<&benchmark_core::recall::RecallMetrics> for RecallMetrics { + fn from(m: &benchmark_core::recall::RecallMetrics) -> Self { + Self { + recall_k: m.recall_k, + recall_n: m.recall_n, + num_queries: m.num_queries, + average: m.average, + minimum: m.minimum, + maximum: m.maximum, } } - - // The actual recall computation for fixed-size groundtruth - let mut recall_values: Vec = Vec::new(); - let mut this_groundtruth = HashSet::new(); - let mut this_results = HashSet::new(); - - for (i, result) in results.row_iter().enumerate() { - let gt_row = groundtruth.row(i); - - // Populate the groundtruth using the top-k - this_groundtruth.clear(); - this_groundtruth.extend(gt_row.iter().copied().take(recall_k)); - - // If we have distances, then continue to append distances as long as the distance - // value is constant - if let Some(distances) = groundtruth_distances { - if recall_k > 0 { - let distances_row = distances.row(i); - if distances_row.len() > recall_k - 1 && gt_row.len() > recall_k - 1 { - let last_distance 
= distances_row[recall_k - 1]; - for (d, g) in distances_row.iter().zip(gt_row.iter()).skip(recall_k) { - if *d == last_distance { - this_groundtruth.insert(*g); - } else { - break; - } - } - } - } - } - - this_results.clear(); - this_results.extend(result.iter().copied().take(recall_n)); - - // Count the overlap - let r = this_groundtruth - .iter() - .filter(|i| this_results.contains(i)) - .count() - .min(recall_k); - - recall_values.push(r); - } - - // Perform post-processing - let total: usize = recall_values.iter().sum(); - let minimum = recall_values.iter().min().unwrap_or(&0); - let maximum = recall_values.iter().max().unwrap_or(&0); - - let div = if groundtruth.ncols().is_some() { - recall_k * nrows - } else { - (0..groundtruth.nrows()) - .map(|i| groundtruth.row(i).len()) - .sum::() - .max(1) - }; - - let average = (total as f64) / (div as f64); - - Ok(RecallMetrics { - recall_k, - recall_n, - num_queries: nrows, - average, - minimum: *minimum, - maximum: *maximum, - by_query: if enhanced_metrics { - Some(recall_values) - } else { - None - }, - }) } /// Compute `k-recall-at-n` for all valid combinations of values in `recall_k` and @@ -309,14 +48,13 @@ where feature = "product-quantization" ))] pub(crate) fn compute_multiple_recalls( - results: StridedView<'_, T>, - groundtruth: StridedView<'_, T>, + results: &dyn benchmark_core::recall::Rows, + groundtruth: &dyn benchmark_core::recall::Rows, recall_k: &[usize], recall_n: &[usize], - enhanced_metrics: bool, -) -> Result, ComputeRecallError> +) -> Result, benchmark_core::recall::ComputeRecallError> where - T: Eq + Hash + Copy + std::fmt::Debug, + T: benchmark_core::recall::RecallCompatible, { let mut result = Vec::new(); for k in recall_k { @@ -325,414 +63,27 @@ where continue; } - result.push(compute_knn_recall( - &groundtruth, - None, - results, - *k, - *n, - false, - enhanced_metrics, - )?); + let recall = benchmark_core::recall::knn(groundtruth, None, results, *k, *n, false)?; + result.push((&recall).into()); } } Ok(result) } -#[derive(Debug, Serialize)] -pub(crate) struct APMetrics { +#[derive(Debug, Clone, Serialize)] +#[non_exhaustive] +pub(crate) struct AveragePrecisionMetrics { /// The number of queries. pub(crate) num_queries: usize, /// The average precision pub(crate) average_precision: f64, } -#[derive(Debug, Error)] -pub(crate) enum ComputeAPError { - #[error("results has {0} elements but ground truth has {1}")] - EntriesMismatch(usize, usize), -} - -/// Compute average precision of a range search result -pub(crate) fn compute_average_precision( - results: Vec>, - groundtruth: &[Vec], -) -> Result -where - T: Eq + Hash + Copy + std::fmt::Debug, -{ - if results.len() != groundtruth.len() { - return Err(ComputeAPError::EntriesMismatch( - results.len(), - groundtruth.len(), - )); - } - - // The actual recall computation. - let mut num_gt_results = 0; - let mut num_reported_results = 0; - - let mut scratch = HashSet::new(); - - std::iter::zip(results.iter(), groundtruth.iter()).for_each(|(result, gt)| { - scratch.clear(); - scratch.extend(result.iter().copied()); - num_reported_results += gt.iter().filter(|i| scratch.contains(i)).count(); - num_gt_results += gt.len(); - }); - - // Perform post-processing. 
- let average_precision = (num_reported_results as f64) / (num_gt_results as f64); - - Ok(APMetrics { - average_precision, - num_queries: results.len(), - }) -} - -/////////// -// Tests // -/////////// - -#[cfg(test)] -mod tests { - use diskann_utils::views::Matrix; - - use super::*; - - pub(crate) fn compute_knn_recall( - results: StridedView<'_, u32>, - groundtruth: G, // StridedView - groundtruth_distances: Option>, - recall_k: usize, - recall_n: usize, - allow_insufficient_results: bool, - enhanced_metrics: bool, - ) -> Result - where - G: ComputeKnnRecall + KnnRecall + Clone, - { - groundtruth.compute_knn_recall( - groundtruth_distances, - results, - recall_k, - recall_n, - allow_insufficient_results, - enhanced_metrics, - ) - } - - struct ExpectedRecall { - recall_k: usize, - recall_n: usize, - // Recall for each component. - components: Vec, - } - - impl ExpectedRecall { - fn new(recall_k: usize, recall_n: usize, components: Vec) -> Self { - assert!(recall_k <= recall_n); - components.iter().for_each(|x| { - assert!(*x <= recall_k); - }); - Self { - recall_k, - recall_n, - components, - } - } - - fn compute_recall(&self) -> f64 { - (self.components.iter().sum::() as f64) - / ((self.components.len() * self.recall_k) as f64) - } - } - - #[test] - fn test_happy_path() { - let groundtruth = Matrix::try_from( - vec![ - 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, // row 0 - 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, // row 1 - 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, // row 2 - 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, // row 3 - ] - .into(), - 4, - 10, - ) - .unwrap(); - - let distances = Matrix::try_from( - vec![ - 0.0, 1.0, 2.0, 3.0, 3.0, 3.0, 3.0, 4.0, 5.0, 6.0, // row 0 - 2.0, 3.0, 3.0, 3.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, // row 1 - 0.0, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, // row 2 - 0.0, 1.0, 2.0, 3.0, 3.0, 3.0, 3.0, 4.0, 5.0, 6.0, // row 3 - ] - .into(), - 4, - 10, - ) - .unwrap(); - - // Shift row 0 by one and row 1 by two. - let our_results = Matrix::try_from( - vec![ - 100, 0, 1, 2, 5, 6, // row 0 - 100, 101, 7, 8, 9, 10, // row 1 - 0, 1, 2, 3, 4, 5, // row 2 - 0, 1, 2, 3, 4, 5, // row 3 - ] - .into(), - 4, - 6, - ) - .unwrap(); - - //---------// - // No Ties // - //---------// - let expected_no_ties = vec![ - // Equal `k` and `n` - ExpectedRecall::new(1, 1, vec![0, 0, 1, 1]), - ExpectedRecall::new(2, 2, vec![1, 0, 2, 2]), - ExpectedRecall::new(3, 3, vec![2, 1, 3, 3]), - ExpectedRecall::new(4, 4, vec![3, 2, 4, 4]), - ExpectedRecall::new(5, 5, vec![3, 3, 5, 5]), - ExpectedRecall::new(6, 6, vec![4, 4, 6, 6]), - // Unequal `k` and `n`. 
- ExpectedRecall::new(1, 2, vec![1, 0, 1, 1]), - ExpectedRecall::new(1, 3, vec![1, 0, 1, 1]), - ExpectedRecall::new(2, 3, vec![2, 0, 2, 2]), - ExpectedRecall::new(3, 5, vec![3, 1, 3, 3]), - ]; - let epsilon = 1e-6; // Define a small tolerance - - for (i, expected) in expected_no_ties.iter().enumerate() { - assert_eq!(expected.components.len(), our_results.nrows()); - let recall = compute_knn_recall( - our_results.as_view().into(), - groundtruth.as_view(), - None, - expected.recall_k, - expected.recall_n, - false, - true, - ) - .unwrap(); - - let left = recall.average; - let right = expected.compute_recall(); - assert!( - (left - right).abs() < epsilon, - "left = {}, right = {} on input {}", - left, - right, - i - ); - - assert_eq!(recall.num_queries, our_results.nrows()); - assert_eq!(recall.recall_k, expected.recall_k); - assert_eq!(recall.recall_n, expected.recall_n); - assert_eq!(recall.minimum, *expected.components.iter().min().unwrap()); - assert_eq!(recall.maximum, *expected.components.iter().max().unwrap()); - } - - //-----------// - // With Ties // - //-----------// - let expected_with_ties = vec![ - // Equal `k` and `n` - ExpectedRecall::new(1, 1, vec![0, 0, 1, 1]), - ExpectedRecall::new(2, 2, vec![1, 0, 2, 2]), - ExpectedRecall::new(3, 3, vec![2, 1, 3, 3]), - ExpectedRecall::new(4, 4, vec![3, 2, 4, 4]), - ExpectedRecall::new(5, 5, vec![4, 3, 5, 5]), // tie-breaker kicks in - ExpectedRecall::new(6, 6, vec![5, 4, 6, 6]), // tie-breaker kicks in - // Unequal `k` and `n`. - ExpectedRecall::new(1, 2, vec![1, 0, 1, 1]), - ExpectedRecall::new(1, 3, vec![1, 0, 1, 1]), - ExpectedRecall::new(2, 3, vec![2, 1, 2, 2]), - ExpectedRecall::new(4, 5, vec![4, 3, 4, 4]), - ]; - - for (i, expected) in expected_with_ties.iter().enumerate() { - assert_eq!(expected.components.len(), our_results.nrows()); - let recall = compute_knn_recall( - our_results.as_view().into(), - groundtruth.as_view(), - Some(distances.as_view().into()), - expected.recall_k, - expected.recall_n, - false, - true, - ) - .unwrap(); - - let left = recall.average; - let right = expected.compute_recall(); - assert!( - (left - right).abs() < epsilon, - "left = {}, right = {} on input {}", - left, - right, - i - ); - - assert_eq!(recall.num_queries, our_results.nrows()); - assert_eq!(recall.recall_k, expected.recall_k); - assert_eq!(recall.recall_n, expected.recall_n); - assert_eq!(recall.minimum, *expected.components.iter().min().unwrap()); - assert_eq!(recall.maximum, *expected.components.iter().max().unwrap()); - assert_eq!(recall.by_query, Some(expected.components.clone())); - } - } - - #[test] - fn test_errors() { - // k greater than n - { - let groundtruth = Matrix::::new(0, 10, 10); - let results = Matrix::::new(0, 10, 10); - let err = compute_knn_recall( - results.as_view().into(), - groundtruth.as_view(), - None, - 11, - 10, - false, - true, - ) - .unwrap_err(); - assert!(matches!(err, ComputeRecallError::RecallKAndNError(..))); - } - - // Unequal rows - { - let groundtruth = Matrix::::new(0, 11, 10); - let results = Matrix::::new(0, 10, 10); - let err = compute_knn_recall( - results.as_view().into(), - groundtruth.as_view(), - None, - 10, - 10, - false, - true, - ) - .unwrap_err(); - assert!(matches!(err, ComputeRecallError::RowsMismatch(..))); - let err_allow_insufficient_results = compute_knn_recall( - results.as_view().into(), - groundtruth.as_view(), - None, - 10, - 10, - true, - false, - ) - .unwrap_err(); - assert!(matches!( - err_allow_insufficient_results, - ComputeRecallError::RowsMismatch(..) 
- )); - } - - // Not enough results - { - let groundtruth = Matrix::::new(0, 10, 10); - let results = Matrix::::new(0, 10, 5); - let err = compute_knn_recall( - results.as_view().into(), - groundtruth.as_view(), - None, - 5, - 10, - false, - false, - ) - .unwrap_err(); - assert!(matches!(err, ComputeRecallError::NotEnoughResults(..))); - let _ = compute_knn_recall( - results.as_view().into(), - groundtruth.as_view(), - None, - 5, - 10, - true, - false, - ); - } - - // Not enough groundtruth - { - let groundtruth = Matrix::::new(0, 10, 5); - let results = Matrix::::new(0, 10, 10); - let err = compute_knn_recall( - results.as_view().into(), - groundtruth.as_view(), - None, - 10, - 10, - false, - true, - ) - .unwrap_err(); - assert!(matches!(err, ComputeRecallError::NotEnoughGroundTruth(..))); - let err_allow_insufficient_results = compute_knn_recall( - results.as_view().into(), - groundtruth.as_view(), - None, - 10, - 10, - true, - false, - ) - .unwrap_err(); - assert!(matches!( - err_allow_insufficient_results, - ComputeRecallError::NotEnoughGroundTruth(..) - )); - } - - // Distance Row Mismatch - { - let groundtruth = Matrix::::new(0, 10, 10); - let distances = Matrix::::new(0.0, 9, 10); - let results = Matrix::::new(0, 10, 10); - let err = compute_knn_recall( - results.as_view().into(), - groundtruth.as_view(), - Some(distances.as_view().into()), - 10, - 10, - false, - true, - ) - .unwrap_err(); - assert!(matches!(err, ComputeRecallError::DistanceRowsMismatch(..))); - } - - // Distance Cols Mismatch - { - let groundtruth = Matrix::::new(0, 10, 10); - let distances = Matrix::::new(0.0, 10, 9); - let results = Matrix::::new(0, 10, 10); - let err = compute_knn_recall( - results.as_view().into(), - groundtruth.as_view(), - Some(distances.as_view().into()), - 10, - 10, - false, - true, - ) - .unwrap_err(); - assert!(matches!( - err, - ComputeRecallError::NotEnoughGroundTruthDistances(..) - )); +impl From<&benchmark_core::recall::AveragePrecisionMetrics> for AveragePrecisionMetrics { + fn from(m: &benchmark_core::recall::AveragePrecisionMetrics) -> Self { + Self { + num_queries: m.num_queries, + average_precision: m.average_precision, } } } diff --git a/diskann-benchmark/src/utils/tokio.rs b/diskann-benchmark/src/utils/tokio.rs index a21d3f520..21c78abb2 100644 --- a/diskann-benchmark/src/utils/tokio.rs +++ b/diskann-benchmark/src/utils/tokio.rs @@ -3,7 +3,7 @@ * Licensed under the MIT license. */ -/// Create a multi-threaded runtime with `num_threads`. +/// Create a generic multi-threaded runtime with `num_threads`. 
pub(crate) fn runtime(num_threads: usize) -> anyhow::Result { Ok(tokio::runtime::Builder::new_multi_thread() .worker_threads(num_threads) @@ -18,21 +18,3 @@ pub(crate) fn block_on(future: F) -> F::Output { .expect("current thread runtime initialization failed") .block_on(future) } - -/////////// -// Tests // -/////////// - -#[cfg(test)] -mod tests { - use super::*; - - #[test] - fn test_runtimes() { - for num_threads in [1, 2, 4, 8] { - let rt = runtime(num_threads).unwrap(); - let metrics = rt.metrics(); - assert_eq!(metrics.num_workers(), num_threads); - } - } -} diff --git a/diskann-label-filter/src/attribute.rs b/diskann-label-filter/src/attribute.rs index 9eb7ff500..f0d99bfd9 100644 --- a/diskann-label-filter/src/attribute.rs +++ b/diskann-label-filter/src/attribute.rs @@ -5,6 +5,7 @@ use std::fmt::Display; use std::hash::{Hash, Hasher}; +use std::io::Write; use serde_json::Value; use thiserror::Error; diff --git a/diskann-label-filter/src/parser/format.rs b/diskann-label-filter/src/parser/format.rs index c042d8338..5e9e3a9c1 100644 --- a/diskann-label-filter/src/parser/format.rs +++ b/diskann-label-filter/src/parser/format.rs @@ -15,8 +15,10 @@ pub struct Document { /// label in raw json format #[serde(flatten)] pub label: serde_json::Value, + } + /// Represents a query expression as defined in the RFC. #[derive(Debug, Clone, Serialize, Deserialize)] pub struct QueryExpression { diff --git a/diskann-label-filter/src/utils/flatten_utils.rs b/diskann-label-filter/src/utils/flatten_utils.rs index 16404af4b..83c9f80f9 100644 --- a/diskann-label-filter/src/utils/flatten_utils.rs +++ b/diskann-label-filter/src/utils/flatten_utils.rs @@ -154,7 +154,7 @@ fn flatten_json_pointer_inner( } Value::Array(arr) => { for (i, item) in arr.iter().enumerate() { - flatten_recursive(item, stack.push(&i, separator), out, separator); + flatten_recursive(item, stack.push(&String::from(""), separator), out, separator); } } _ => { diff --git a/diskann-tools/Cargo.toml b/diskann-tools/Cargo.toml index 7f0cb203a..1b4b3408e 100644 --- a/diskann-tools/Cargo.toml +++ b/diskann-tools/Cargo.toml @@ -5,14 +5,13 @@ version.workspace = true authors.workspace = true description.workspace = true documentation.workspace = true -license.workspace = true # See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html [dependencies] byteorder.workspace = true clap = { workspace = true, features = ["derive"] } -diskann-providers = { workspace = true, default-features = false } # see `linalg/Cargo.toml` +diskann-providers = { workspace = true, default-features = false } # see `linalg/Cargo.toml` diskann-vector = { workspace = true } diskann-disk = { workspace = true } diskann-utils = { workspace = true } @@ -24,31 +23,24 @@ ordered-float = "4.2.0" rand_distr.workspace = true rand.workspace = true serde = { workspace = true, features = ["derive"] } -toml = "0.8.13" +serde_json.workspace = true bincode.workspace = true opentelemetry.workspace = true -opentelemetry_sdk.workspace = true -csv.workspace = true -tokio = { workspace = true, features = ["full"] } -arc-swap.workspace = true diskann-quantization = { workspace = true } diskann = { workspace = true } tracing-subscriber = { workspace = true, features = ["env-filter"] } tracing.workspace = true bit-set.workspace = true anyhow.workspace = true -serde_json.workspace = true itertools.workspace = true diskann-label-filter.workspace = true [dev-dependencies] rstest.workspace = true -assert_ok = "1.0.2" -# Use virtual-storage for integration tests 
-diskann-disk = { path = "../diskann-disk", features = ["virtual_storage"] } vfs = { workspace = true } -ureq = { version = "3.0.11", default-features = false, features = ["native-tls", "gzip"] } -diskann-providers = { path = "../diskann-providers", default-features = false, features = ["testing", "virtual_storage"] } +diskann-providers = { workspace = true, default-features = false, features = [ + "virtual_storage", +] } diskann-utils = { workspace = true, features = ["testing"] } [features] diff --git a/diskann-tools/src/utils/ground_truth.rs b/diskann-tools/src/utils/ground_truth.rs index e96f7ae8f..31e69b2b2 100644 --- a/diskann-tools/src/utils/ground_truth.rs +++ b/diskann-tools/src/utils/ground_truth.rs @@ -4,7 +4,7 @@ */ use bit_set::BitSet; -use diskann_label_filter::{eval_query_expr, read_and_parse_queries, read_baselabels}; +use diskann_label_filter::{eval_query_expr, read_and_parse_queries, read_baselabels, ASTExpr}; use std::{io::Write, mem::size_of, str::FromStr}; @@ -25,18 +25,97 @@ use diskann_utils::views::Matrix; use diskann_vector::{distance::Metric, DistanceFunction}; use itertools::Itertools; use rayon::prelude::*; +use serde_json::{Map, Value}; use crate::utils::{search_index_utils, CMDResult, CMDToolError}; +/// Expands a JSON object with array-valued fields into multiple objects with scalar values. +/// For example: {"country": ["AU", "NZ"], "year": 2007} +/// becomes: [{"country": "AU", "year": 2007}, {"country": "NZ", "year": 2007}] +/// +/// If multiple fields have arrays, all combinations are generated. +fn expand_array_fields(value: &Value) -> Vec { + match value { + Value::Object(map) => { + // Start with a single empty object + let mut results: Vec> = vec![Map::new()]; + + for (key, val) in map.iter() { + if let Value::Array(arr) = val { + // Expand: for each existing result, create copies for each array element + let mut new_results: Vec> = Vec::new(); + for existing in results.iter() { + for item in arr.iter() { + let mut new_map: Map = existing.clone(); + new_map.insert(key.clone(), item.clone()); + new_results.push(new_map); + } + } + // If array is empty, keep existing results without this key + if !arr.is_empty() { + results = new_results; + } + } else { + // Non-array field: add to all existing results + for existing in results.iter_mut() { + existing.insert(key.clone(), val.clone()); + } + } + } + + results.into_iter().map(Value::Object).collect() + } + // If not an object, return as-is + _ => vec![value.clone()], + } +} + +/// Evaluates a query expression against a label, expanding array fields first. +/// Returns true if any expanded variant matches the query. 
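+///
+/// Worked example (hypothetical values): the label
+/// `{"country": ["AU", "NZ"], "year": 2007}` expands to
+/// `{"country": "AU", "year": 2007}` and `{"country": "NZ", "year": 2007}`,
+/// so a query such as `country == "NZ"` matches via the second variant.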
+fn eval_query_with_array_expansion(query_expr: &ASTExpr, label: &Value) -> bool { + let expanded = expand_array_fields(label); + expanded.iter().any(|item| eval_query_expr(query_expr, item)) +} + pub fn read_labels_and_compute_bitmap( base_label_filename: &str, query_label_filename: &str, ) -> CMDResult> { // Read base labels let base_labels = read_baselabels(base_label_filename)?; + tracing::info!( + "Loaded {} base labels from {}", + base_labels.len(), + base_label_filename + ); + + // Print first few base labels for debugging + for (i, label) in base_labels.iter().take(3).enumerate() { + tracing::debug!( + "Base label sample [{}]: doc_id={}, label={}", + i, + label.doc_id, + label.label + ); + } // Parse queries and evaluate against labels let parsed_queries = read_and_parse_queries(query_label_filename)?; + tracing::info!( + "Loaded {} queries from {}", + parsed_queries.len(), + query_label_filename + ); + + // Print first few queries for debugging + for (i, (query_id, query_expr)) in parsed_queries.iter().take(3).enumerate() { + tracing::debug!( + "Query sample [{}]: query_id={}, expr={:?}", + i, + query_id, + query_expr + ); + } // using the global threadpool is fine here #[allow(clippy::disallowed_methods)] @@ -45,7 +124,15 @@ pub fn read_labels_and_compute_bitmap( .map(|(_query_id, query_expr)| { let mut bitmap = BitSet::new(); for base_label in base_labels.iter() { - if eval_query_expr(query_expr, &base_label.label) { + // Handle case where base_label.label is an array - check if any element matches + // Also expand array-valued fields within objects (e.g., {"country": ["AU", "NZ"]}) + let matches = if let Some(array) = base_label.label.as_array() { + array.iter().any(|item| eval_query_with_array_expansion(query_expr, item)) + } else { + eval_query_with_array_expansion(query_expr, &base_label.label) + }; + + if matches { bitmap.insert(base_label.doc_id); } } @@ -53,6 +140,38 @@ pub fn read_labels_and_compute_bitmap( }) .collect(); + // Debug: Print match statistics for each query + let total_matches: usize = query_bitmaps.iter().map(|b| b.len()).sum(); + let queries_with_matches = query_bitmaps.iter().filter(|b| !b.is_empty()).count(); + tracing::info!( + "Filter matching summary: {} total matches across {} queries ({} queries have matches)", + total_matches, + query_bitmaps.len(), + queries_with_matches + ); + + // Print per-query match counts + for (i, bitmap) in query_bitmaps.iter().enumerate() { + if i < 10 || bitmap.is_empty() { + tracing::debug!( + "Query {}: {} base vectors matched the filter", + i, + bitmap.len() + ); + } + } + + // If no matches, print more diagnostic info + if total_matches == 0 { + tracing::warn!("WARNING: No base vectors matched any query filters!"); + tracing::warn!("This could indicate a format mismatch between base labels and query filters."); + + // Try to identify what keys exist in base labels vs queries + if let Some(first_label) = base_labels.first() { + tracing::warn!("First base label (full): doc_id={}, label={}", first_label.doc_id, first_label.label); + } + } + Ok(query_bitmaps) } @@ -195,6 +314,44 @@ pub fn compute_ground_truth_from_datafiles< assert_ne!(ground_truth.len(), 0, "No ground-truth results computed"); + // Debug: Print top K matches for each query + tracing::info!( + "Ground truth computed for {} queries with recall_at={}", + ground_truth.len(), + recall_at + ); + for (query_idx, npq) in ground_truth.iter().enumerate() { + let neighbors: Vec<_> = npq.iter().collect(); + let neighbor_count = neighbors.len(); + + if query_idx < 
10 { + // Print top K IDs and distances for first 10 queries + let top_ids: Vec = neighbors.iter().take(10).map(|n| n.id).collect(); + let top_dists: Vec = neighbors.iter().take(10).map(|n| n.distance).collect(); + tracing::debug!( + "Query {}: {} neighbors found. Top IDs: {:?}, Top distances: {:?}", + query_idx, + neighbor_count, + top_ids, + top_dists + ); + } + + if neighbor_count == 0 { + tracing::warn!("Query {} has 0 neighbors in ground truth!", query_idx); + } + } + + // Summary stats + let total_neighbors: usize = ground_truth.iter().map(|npq| npq.iter().count()).sum(); + let queries_with_neighbors = ground_truth.iter().filter(|npq| npq.iter().count() > 0).count(); + tracing::info!( + "Ground truth summary: {} total neighbors, {} queries have neighbors, {} queries have 0 neighbors", + total_neighbors, + queries_with_neighbors, + ground_truth.len() - queries_with_neighbors + ); + if has_vector_filters || has_query_bitmaps { let ground_truth_collection = ground_truth .into_iter() From ec2091ffb510a245970e0fdec83bb46955cbefe7 Mon Sep 17 00:00:00 2001 From: Gopal Srinivasa Date: Tue, 10 Feb 2026 11:08:49 +0530 Subject: [PATCH 2/3] Before merging with main --- Cargo.lock | 340 ------------------ .../src/utils/flatten_utils.rs | 2 +- 2 files changed, 1 insertion(+), 341 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index e80330d7d..665b6c6df 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -2,12 +2,6 @@ # It is not intended for manual editing. version = 4 -[[package]] -name = "adler2" -version = "2.0.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "320119579fcad9c21884f5c4861d16174d0e06250625266f50fe6898340abefa" - [[package]] name = "aho-corasick" version = "1.1.4" @@ -103,30 +97,12 @@ dependencies = [ "rustversion", ] -[[package]] -name = "assert_ok" -version = "1.0.2" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "7c770ef7624541db11cce57929f00e737fef89157d7c1cd1977b20ee74fefd84" - [[package]] name = "autocfg" version = "1.5.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "c08606f8c3cbf4ce6ec8e28fb0014a2c086708fe954eaa885384a6165172e7e8" -[[package]] -name = "base64" -version = "0.22.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "72b3254f16251a8381aa12e40e3c4d2f0199f8c6508fbecb9d91f575e0fbb8c6" - -[[package]] -name = "base64ct" -version = "1.8.3" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "2af50177e190e07a26ab74f8b1efbfe2ef87da2116221318cb1c2e82baf7de06" - [[package]] name = "bf-tree" version = "0.4.5" @@ -225,16 +201,6 @@ version = "0.3.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "37b2a672a2cb129a2e41c10b1224bb368f9f37a2b16b612598138befd7b37eb5" -[[package]] -name = "cc" -version = "1.2.52" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "cd4932aefd12402b36c60956a4fe0035421f544799057659ff86f923657aada3" -dependencies = [ - "find-msvc-tools", - "shlex", -] - [[package]] name = "cfg-if" version = "1.0.0" @@ -327,31 +293,6 @@ dependencies = [ "windows-sys 0.61.2", ] -[[package]] -name = "core-foundation" -version = "0.9.4" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "91e195e091a93c46f7102ec7818a2aa394e1e1771c3ab4825963fa03e45afb8f" -dependencies = [ - "core-foundation-sys", - "libc", -] - -[[package]] -name = "core-foundation-sys" -version = "0.8.7" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = 
"773648b94d0e5d620f64f280777445740e61fe701025087ec8b57f45c791888b" - -[[package]] -name = "crc32fast" -version = "1.5.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "9481c1c90cbf2ac953f07c8d4a58aa3945c425b7185c9154d67a65e4230da511" -dependencies = [ - "cfg-if", -] - [[package]] name = "criterion" version = "0.5.1" @@ -419,43 +360,12 @@ version = "0.2.4" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "460fbee9c2c2f33933d720630a6a0bac33ba7053db5344fac858d4b8952d77d5" -[[package]] -name = "csv" -version = "1.4.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "52cd9d68cf7efc6ddfaaee42e7288d3a99d613d4b50f76ce9827ae0c6e14f938" -dependencies = [ - "csv-core", - "itoa", - "ryu", - "serde_core", -] - -[[package]] -name = "csv-core" -version = "0.1.13" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "704a3c26996a80471189265814dbc2c257598b96b8a7feae2d31ace646bb9782" -dependencies = [ - "memchr", -] - [[package]] name = "defer" version = "0.2.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "930c7171c8df9fb1782bdf9b918ed9ed2d33d1d22300abb754f9085bc48bf8e8" -[[package]] -name = "der" -version = "0.7.10" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e7c1832837b905bbfb5101e07cc24c8deddf52f93225eee6ead5f4d63d53ddcb" -dependencies = [ - "pem-rfc7468", - "zeroize", -] - [[package]] name = "derive_more" version = "2.1.1" @@ -718,14 +628,11 @@ name = "diskann-tools" version = "0.41.0" dependencies = [ "anyhow", - "arc-swap", - "assert_ok", "bincode", "bit-set", "bytemuck", "byteorder", "clap", - "csv", "diskann", "diskann-disk", "diskann-label-filter", @@ -737,7 +644,6 @@ dependencies = [ "itertools 0.13.0", "num_cpus", "opentelemetry", - "opentelemetry_sdk", "ordered-float", "rand 0.9.2", "rand_distr", @@ -745,11 +651,8 @@ dependencies = [ "rstest", "serde", "serde_json", - "tokio", - "toml 0.8.23", "tracing", "tracing-subscriber", - "ureq", "vfs", ] @@ -956,12 +859,6 @@ dependencies = [ "windows-sys 0.60.2", ] -[[package]] -name = "find-msvc-tools" -version = "0.1.7" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f449e6c6c08c865631d4890cfacf252b3d396c9bcc83adb6623cdb02a8336c41" - [[package]] name = "flatbuffers" version = "25.12.19" @@ -972,16 +869,6 @@ dependencies = [ "rustc_version", ] -[[package]] -name = "flate2" -version = "1.1.8" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b375d6465b98090a5f25b1c7703f3859783755aa9a80433b36e0379a3ec2f369" -dependencies = [ - "crc32fast", - "miniz_oxide", -] - [[package]] name = "fnv" version = "1.0.7" @@ -1000,21 +887,6 @@ version = "0.2.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "77ce24cb58228fbb8aa041425bb1050850ac19177686ea6e0f41a70416f56fdb" -[[package]] -name = "foreign-types" -version = "0.3.2" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f6f339eb8adc052cd2ca78910fda869aefa38d22d5cb648e6485e4d3fc06f3b1" -dependencies = [ - "foreign-types-shared", -] - -[[package]] -name = "foreign-types-shared" -version = "0.1.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "00b0228411908ca8685dba7fc2cdd70ec9990a6e753e89b6ac91a84c40fbaf4b" - [[package]] name = "futures" version = "0.3.31" @@ -1322,22 +1194,6 @@ dependencies = [ "paste", ] -[[package]] -name = "http" -version = "1.4.0" -source = 
"registry+https://github.com/rust-lang/crates.io-index" -checksum = "e3ba2a386d7f85a81f119ad7498ebe444d2e22c2af0b86b069416ace48b3311a" -dependencies = [ - "bytes", - "itoa", -] - -[[package]] -name = "httparse" -version = "1.10.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "6dbf3de79e51f3d586ab4cb9d5c3e2c14aa28ed23d180cf89b4df0454a69cc87" - [[package]] name = "iai-callgrind" version = "0.14.2" @@ -1559,16 +1415,6 @@ version = "2.7.6" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "f52b00d39961fc5b2736ea853c9cc86238e165017a493d1d5c8eac6bdc4cc273" -[[package]] -name = "miniz_oxide" -version = "0.8.9" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "1fa76a2c86f704bdb222d66965fb3d63269ce38518b83cb0575fca855ebb6316" -dependencies = [ - "adler2", - "simd-adler32", -] - [[package]] name = "mio" version = "1.1.1" @@ -1650,23 +1496,6 @@ dependencies = [ "nano-gemm-core", ] -[[package]] -name = "native-tls" -version = "0.2.14" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "87de3442987e9dbec73158d5c715e7ad9072fda936bb03d19d7fa10e00520f0e" -dependencies = [ - "libc", - "log", - "openssl", - "openssl-probe", - "openssl-sys", - "schannel", - "security-framework", - "security-framework-sys", - "tempfile", -] - [[package]] name = "never-say-never" version = "6.6.666" @@ -1730,50 +1559,6 @@ version = "11.1.5" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "d6790f58c7ff633d8771f42965289203411a5e5c68388703c06e14f24770b41e" -[[package]] -name = "openssl" -version = "0.10.75" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "08838db121398ad17ab8531ce9de97b244589089e290a384c900cb9ff7434328" -dependencies = [ - "bitflags 2.10.0", - "cfg-if", - "foreign-types", - "libc", - "once_cell", - "openssl-macros", - "openssl-sys", -] - -[[package]] -name = "openssl-macros" -version = "0.1.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a948666b637a0f465e8564c73e89d4dde00d72d4d473cc972f390fc3dcee7d9c" -dependencies = [ - "proc-macro2", - "quote", - "syn 2.0.113", -] - -[[package]] -name = "openssl-probe" -version = "0.1.6" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d05e27ee213611ffe7d6348b942e8f942b37114c00cc03cec254295a4a17852e" - -[[package]] -name = "openssl-sys" -version = "0.9.111" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "82cab2d520aa75e3c58898289429321eb788c3106963d0dc886ec7a5f4adc321" -dependencies = [ - "cc", - "libc", - "pkg-config", - "vcpkg", -] - [[package]] name = "opentelemetry" version = "0.30.0" @@ -1842,15 +1627,6 @@ version = "1.0.15" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "57c0d7b74b563b49d38dae00a0c37d4d6de9b432382b2892f0574ddcae73fd0a" -[[package]] -name = "pem-rfc7468" -version = "0.7.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "88b39c9bfcfc231068454382784bb460aae594343fb030d46e9f50a645418412" -dependencies = [ - "base64ct", -] - [[package]] name = "percent-encoding" version = "2.3.2" @@ -1889,12 +1665,6 @@ version = "0.1.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "8b870d8c151b6f2fb93e84a13146138f05d02ed11c7e7c54f8826aaaf7c9f184" -[[package]] -name = "pkg-config" -version = "0.3.32" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = 
"7edddbd0b52d732b21ad9a5fab5c704c14cd949e5e9a1ec5929a24fded1b904c" - [[package]] name = "plotters" version = "0.3.7" @@ -2326,15 +2096,6 @@ dependencies = [ "windows-sys 0.61.2", ] -[[package]] -name = "rustls-pki-types" -version = "1.13.2" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "21e6f2ab2928ca4291b86736a8bd920a277a399bba1589409d72154ff87c1282" -dependencies = [ - "zeroize", -] - [[package]] name = "rustversion" version = "1.0.22" @@ -2384,15 +2145,6 @@ dependencies = [ "sdd", ] -[[package]] -name = "schannel" -version = "0.1.28" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "891d81b926048e76efe18581bf793546b4c0eaf8448d72be8de2bbee5fd166e1" -dependencies = [ - "windows-sys 0.61.2", -] - [[package]] name = "scopeguard" version = "1.2.0" @@ -2405,29 +2157,6 @@ version = "4.5.3" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "6e63d45f3526312c9c90d717aac28d37010e623fbd7ca6f21503e69784e86f40" -[[package]] -name = "security-framework" -version = "2.11.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "897b2245f0b511c87893af39b033e5ca9cce68824c4d7e7630b5a1d339658d02" -dependencies = [ - "bitflags 2.10.0", - "core-foundation", - "core-foundation-sys", - "libc", - "security-framework-sys", -] - -[[package]] -name = "security-framework-sys" -version = "2.15.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "cc1f0cbffaac4852523ce30d8bd3c5cdc873501d96ff467ca09b6767bb8cd5c0" -dependencies = [ - "core-foundation-sys", - "libc", -] - [[package]] name = "semver" version = "1.0.27" @@ -2523,12 +2252,6 @@ dependencies = [ "lazy_static", ] -[[package]] -name = "shlex" -version = "1.3.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "0fda2ff0d084019ba4d7c6f371c95d8fd75ce3524c3cb8fb653a3023f6323e64" - [[package]] name = "signal-hook-registry" version = "1.4.8" @@ -2539,12 +2262,6 @@ dependencies = [ "libc", ] -[[package]] -name = "simd-adler32" -version = "0.3.8" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e320a6c5ad31d271ad523dcf3ad13e2767ad8b1cb8f047f75a8aeaf8da139da2" - [[package]] name = "slab" version = "0.4.11" @@ -2922,42 +2639,6 @@ version = "0.2.11" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "673aac59facbab8a9007c7f6108d11f63b603f7cabff99fabf650fea5c32b861" -[[package]] -name = "ureq" -version = "3.1.4" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d39cb1dbab692d82a977c0392ffac19e188bd9186a9f32806f0aaa859d75585a" -dependencies = [ - "base64", - "der", - "flate2", - "log", - "native-tls", - "percent-encoding", - "rustls-pki-types", - "ureq-proto", - "utf-8", - "webpki-root-certs", -] - -[[package]] -name = "ureq-proto" -version = "0.5.3" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d81f9efa9df032be5934a46a068815a10a042b494b6a58cb0a1a97bb5467ed6f" -dependencies = [ - "base64", - "http", - "httparse", - "log", -] - -[[package]] -name = "utf-8" -version = "0.7.6" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "09cc8ee72d2a9becf2f2febe0205bbed8fc6615b7cb429ad062dc7b7ddd036a9" - [[package]] name = "utf8parse" version = "0.2.2" @@ -2970,12 +2651,6 @@ version = "0.1.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "ba73ea9cf16a25df0c8caa16c51acb937d5712a8429db78a3ee29d5dcacd3a65" -[[package]] -name = "vcpkg" -version = 
"0.2.15" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "accd4ea62f7bb7a82fe23066fb0957d48ef677f6eeb8215f372f52e48bb32426" - [[package]] name = "version_check" version = "0.9.5" @@ -3090,15 +2765,6 @@ dependencies = [ "wasm-bindgen", ] -[[package]] -name = "webpki-root-certs" -version = "1.0.5" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "36a29fc0408b113f68cf32637857ab740edfafdf460c326cd2afaa2d84cc05dc" -dependencies = [ - "rustls-pki-types", -] - [[package]] name = "winapi-util" version = "0.1.11" @@ -3305,12 +2971,6 @@ dependencies = [ "syn 2.0.113", ] -[[package]] -name = "zeroize" -version = "1.8.2" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b97154e67e32c85465826e8bcc1c59429aaaf107c1e4a9e53c8d8ccd5eff88d0" - [[package]] name = "zmij" version = "1.0.11" diff --git a/diskann-label-filter/src/utils/flatten_utils.rs b/diskann-label-filter/src/utils/flatten_utils.rs index 83c9f80f9..16404af4b 100644 --- a/diskann-label-filter/src/utils/flatten_utils.rs +++ b/diskann-label-filter/src/utils/flatten_utils.rs @@ -154,7 +154,7 @@ fn flatten_json_pointer_inner( } Value::Array(arr) => { for (i, item) in arr.iter().enumerate() { - flatten_recursive(item, stack.push(&String::from(""), separator), out, separator); + flatten_recursive(item, stack.push(&i, separator), out, separator); } } _ => { From a949024b8283390d49b43061973abbb74653d17d Mon Sep 17 00:00:00 2001 From: Gopal Srinivasa Date: Mon, 16 Feb 2026 14:40:55 +0530 Subject: [PATCH 3/3] Working version of inline beta search --- .../example/document-filter.json | 34 + .../src/backend/document_index/benchmark.rs | 1038 ++++++++++++++ .../src/backend/document_index/mod.rs | 13 + diskann-benchmark/src/backend/index/result.rs | 13 + diskann-benchmark/src/backend/mod.rs | 2 + .../src/inputs/document_index.rs | 177 +++ diskann-benchmark/src/inputs/mod.rs | 2 + diskann-benchmark/src/utils/recall.rs | 1 + diskann-benchmark/src/utils/tokio.rs | 7 + diskann-label-filter/src/attribute.rs | 1 - diskann-label-filter/src/document.rs | 4 +- .../ast_label_id_mapper.rs | 15 +- .../document_insert_strategy.rs | 274 ++++ .../document_provider.rs | 2 +- .../encoded_filter_expr.rs | 19 +- .../roaring_attribute_store.rs | 2 +- .../encoded_document_accessor.rs | 14 +- .../inline_beta_search/inline_beta_filter.rs | 67 +- diskann-label-filter/src/lib.rs | 1 + diskann-label-filter/src/parser/format.rs | 2 - .../provider/async_/inmem/full_precision.rs | 1218 +++++++++-------- diskann-tools/src/utils/ground_truth.rs | 37 +- .../disk_index_search/data.256.label.jsonl | 4 +- 23 files changed, 2307 insertions(+), 640 deletions(-) create mode 100644 diskann-benchmark/example/document-filter.json create mode 100644 diskann-benchmark/src/backend/document_index/benchmark.rs create mode 100644 diskann-benchmark/src/backend/document_index/mod.rs create mode 100644 diskann-benchmark/src/inputs/document_index.rs create mode 100644 diskann-label-filter/src/encoded_attribute_provider/document_insert_strategy.rs diff --git a/diskann-benchmark/example/document-filter.json b/diskann-benchmark/example/document-filter.json new file mode 100644 index 000000000..d6e9e13b2 --- /dev/null +++ b/diskann-benchmark/example/document-filter.json @@ -0,0 +1,34 @@ +{ + "search_directories": [ + "test_data/disk_index_search" + ], + "jobs": [ + { + "type": "document-index-build", + "content": { + "build": { + "data_type": "float32", + "data": "disk_index_siftsmall_learn_256pts_data.fbin", + "data_labels": 
"data.256.label.jsonl", + "distance": "squared_l2", + "max_degree": 32, + "l_build": 50, + "alpha": 1.2 + }, + "search": { + "queries": "disk_index_sample_query_10pts.fbin", + "query_predicates": "query.10.label.jsonl", + "groundtruth": "disk_index_10pts_idx_uint32_truth_search_filter_res.bin", + "beta": 0.5, + "runs": [ + { + "search_n": 20, + "search_l": [20, 30, 40], + "recall_k": 10 + } + ] + } + } + } + ] +} \ No newline at end of file diff --git a/diskann-benchmark/src/backend/document_index/benchmark.rs b/diskann-benchmark/src/backend/document_index/benchmark.rs new file mode 100644 index 000000000..dffe669ff --- /dev/null +++ b/diskann-benchmark/src/backend/document_index/benchmark.rs @@ -0,0 +1,1038 @@ +/* + * Copyright (c) Microsoft Corporation. + * Licensed under the MIT license. + */ + +//! Benchmark for DocumentInsertStrategy which allows inserting Documents +//! (vector + attributes) into a DiskANN index built with DocumentProvider. +//! Also benchmarks filtered search using InlineBetaStrategy. + +use std::io::Write; +use std::num::NonZeroUsize; +use std::path::Path; +use std::sync::atomic::{AtomicBool, AtomicUsize, Ordering}; +use std::sync::Arc; + +use anyhow::Result; +use diskann::{ + graph::{ + config::Builder as ConfigBuilder, config::MaxDegree, config::PruneKind, + search_output_buffer, DiskANNIndex, SearchParams, StartPointStrategy, + }, + provider::DefaultContext, + utils::{async_tools, IntoUsize}, +}; +use diskann_benchmark_runner::{ + dispatcher::{DispatchRule, FailureScore, MatchScore}, + output::Output, + registry::Benchmarks, + utils::{datatype::DataType, percentiles, MicroSeconds}, + Any, Checkpoint, +}; +use diskann_label_filter::{ + attribute::{Attribute, AttributeValue}, + document::Document, + encoded_attribute_provider::{ + document_insert_strategy::DocumentInsertStrategy, document_provider::DocumentProvider, + roaring_attribute_store::RoaringAttributeStore, + }, + inline_beta_search::inline_beta_filter::InlineBetaStrategy, + query::FilteredQuery, + read_and_parse_queries, read_baselabels, ASTExpr, +}; +use diskann_providers::model::graph::provider::async_::{ + common::{self, NoStore, TableBasedDeletes}, + inmem::{CreateFullPrecision, DefaultProvider, DefaultProviderParameters, SetStartPoints}, +}; +use diskann_utils::views::Matrix; +use indicatif::{ProgressBar, ProgressStyle}; +use serde::Serialize; + +use crate::{ + inputs::document_index::DocumentIndexBuild, + utils::{ + self, + datafiles::{self, BinFile}, + recall, + }, +}; + +/// Register the document index benchmarks. +pub(crate) fn register_benchmarks(benchmarks: &mut Benchmarks) { + benchmarks.register::>( + "document-index-build", + |job, checkpoint, out| { + let stats = job.run(checkpoint, out)?; + Ok(serde_json::to_value(stats)?) + }, + ); +} + +/// Document index benchmark job. 
+pub(super) struct DocumentIndexJob<'a> { + input: &'a DocumentIndexBuild, +} + +impl<'a> DocumentIndexJob<'a> { + fn new(input: &'a DocumentIndexBuild) -> Self { + Self { input } + } +} + +impl diskann_benchmark_runner::dispatcher::Map for DocumentIndexJob<'static> { + type Type<'a> = DocumentIndexJob<'a>; +} + +// Dispatch from the concrete input type +impl<'a> DispatchRule<&'a DocumentIndexBuild> for DocumentIndexJob<'a> { + type Error = std::convert::Infallible; + + fn try_match(_from: &&'a DocumentIndexBuild) -> Result { + Ok(MatchScore(1)) + } + + fn convert(from: &'a DocumentIndexBuild) -> Result { + Ok(DocumentIndexJob::new(from)) + } + + fn description( + f: &mut std::fmt::Formatter<'_>, + _from: Option<&&'a DocumentIndexBuild>, + ) -> std::fmt::Result { + writeln!(f, "tag: \"{}\"", DocumentIndexBuild::tag()) + } +} + +// Central dispatch mapping from Any +impl<'a> DispatchRule<&'a Any> for DocumentIndexJob<'a> { + type Error = anyhow::Error; + + fn try_match(from: &&'a Any) -> Result { + from.try_match::() + } + + fn convert(from: &'a Any) -> Result { + from.convert::() + } + + fn description(f: &mut std::fmt::Formatter, from: Option<&&'a Any>) -> std::fmt::Result { + Any::description::(f, from, DocumentIndexBuild::tag()) + } +} +/// Convert a HashMap to Vec +fn hashmap_to_attributes(map: std::collections::HashMap) -> Vec { + map.into_iter() + .map(|(k, v)| Attribute::from_value(k, v)) + .collect() +} + +/// Compute the index of the row closest to the medoid (centroid) of the data. +fn compute_medoid_index(data: &Matrix) -> usize +where + T: bytemuck::Pod + Copy + 'static, +{ + use diskann_vector::{distance::SquaredL2, PureDistanceFunction}; + + let dim = data.ncols(); + if dim == 0 || data.nrows() == 0 { + return 0; + } + + // Compute the centroid (mean of all rows) as f64 for precision + let mut sum = vec![0.0f64; dim]; + for i in 0..data.nrows() { + let row = data.row(i); + for (j, &v) in row.iter().enumerate() { + // Convert T to f64 for summation using bytemuck + let f64_val: f64 = if std::any::TypeId::of::() == std::any::TypeId::of::() { + let f32_val: f32 = bytemuck::cast(v); + f32_val as f64 + } else if std::any::TypeId::of::() == std::any::TypeId::of::() { + let u8_val: u8 = bytemuck::cast(v); + u8_val as f64 + } else if std::any::TypeId::of::() == std::any::TypeId::of::() { + let i8_val: i8 = bytemuck::cast(v); + i8_val as f64 + } else { + 0.0 + }; + sum[j] += f64_val; + } + } + + // Convert centroid to f32 and compute distances + let centroid_f32: Vec = sum + .iter() + .map(|s| (s / data.nrows() as f64) as f32) + .collect(); + + // Find the row closest to the centroid + let mut min_dist = f32::MAX; + let mut medoid_idx = 0; + for i in 0..data.nrows() { + let row = data.row(i); + let row_f32: Vec = row + .iter() + .map(|&v| { + if std::any::TypeId::of::() == std::any::TypeId::of::() { + bytemuck::cast(v) + } else if std::any::TypeId::of::() == std::any::TypeId::of::() { + let u8_val: u8 = bytemuck::cast(v); + u8_val as f32 + } else if std::any::TypeId::of::() == std::any::TypeId::of::() { + let i8_val: i8 = bytemuck::cast(v); + i8_val as f32 + } else { + 0.0 + } + }) + .collect(); + let d = SquaredL2::evaluate(centroid_f32.as_slice(), row_f32.as_slice()); + if d < min_dist { + min_dist = d; + medoid_idx = i; + } + } + + medoid_idx +} + +impl<'a> DocumentIndexJob<'a> { + fn run( + self, + _checkpoint: Checkpoint<'_>, + mut output: &mut dyn Output, + ) -> Result { + // Print the input description + writeln!(output, "{}", self.input)?; + + let build = &self.input.build; + 
        // Dispatch based on data type - retain original type without conversion
        match build.data_type {
            DataType::Float32 => self.run_typed::<f32>(output),
            DataType::UInt8 => self.run_typed::<u8>(output),
            DataType::Int8 => self.run_typed::<i8>(output),
            _ => Err(anyhow::anyhow!(
                "Unsupported data type: {:?}. Supported types: float32, uint8, int8.",
                build.data_type
            )),
        }
    }

    fn run_typed<T>(self, mut output: &mut dyn Output) -> Result<DocumentIndexStats>
    where
        T: bytemuck::Pod + Copy + Send + Sync + 'static + std::fmt::Debug,
        T: diskann::graph::SampleableForStart + diskann_utils::future::AsyncFriendly,
        T: diskann::utils::VectorRepr + diskann_utils::sampling::WithApproximateNorm,
    {
        let build = &self.input.build;

        // 1. Load vectors from data file in the original data type
        writeln!(output, "Loading vectors ({})...", build.data_type)?;
        let timer = std::time::Instant::now();
        let data_path: &Path = build.data.as_ref();
        writeln!(output, "Data path is: {}", data_path.to_string_lossy())?;
        let data: Matrix<T> = datafiles::load_dataset(BinFile(data_path))?;
        let data_load_time: MicroSeconds = timer.elapsed().into();
        let num_vectors = data.nrows();
        let dim = data.ncols();
        writeln!(
            output,
            " Loaded {} vectors of dimension {}",
            num_vectors, dim
        )?;

        // 2. Load and parse labels from the data_labels file
        writeln!(output, "Loading labels...")?;
        let timer = std::time::Instant::now();
        let label_path: &Path = build.data_labels.as_ref();
        let labels = read_baselabels(label_path)?;
        let label_load_time: MicroSeconds = timer.elapsed().into();
        let label_count = labels.len();
        writeln!(output, " Loaded {} label documents", label_count)?;

        if num_vectors != label_count {
            return Err(anyhow::anyhow!(
                "Mismatch: {} vectors but {} label documents",
                num_vectors,
                label_count
            ));
        }

        // Convert labels to attribute vectors
        let attributes: Vec<Vec<Attribute>> = labels
            .into_iter()
            .map(|doc| hashmap_to_attributes(doc.flatten_metadata_with_separator("")))
            .collect();

        // 3. Create the index configuration
        let metric = build.distance.into();
        let prune_kind = PruneKind::from_metric(metric);
        let mut config_builder = ConfigBuilder::new(
            build.max_degree, // pruned_degree
            MaxDegree::Same,  // max_degree
            build.l_build,    // l_build
            prune_kind,       // prune_kind
        );
        config_builder.alpha(build.alpha);
        let config = config_builder.build()?;

        // 4. Create the data provider directly
        writeln!(output, "Creating index...")?;
        let params = DefaultProviderParameters {
            max_points: num_vectors,
            frozen_points: diskann::utils::ONE,
            metric,
            dim,
            prefetch_lookahead: None,
            prefetch_cache_line_level: None,
            max_degree: build.max_degree as u32,
        };

        // Create the underlying provider
        let fp_precursor = CreateFullPrecision::<T>::new(dim, None);
        let inner_provider =
            DefaultProvider::new_empty(params, fp_precursor, NoStore, TableBasedDeletes)?;

        // Set start points using medoid strategy
        let start_points = StartPointStrategy::Medoid
            .compute(data.as_view())
            .map_err(|e| anyhow::anyhow!("Failed to compute start points: {}", e))?;
        inner_provider.set_start_points(start_points.row_iter())?;

        // 5. 
Create DocumentProvider wrapping the inner provider + let attribute_store = RoaringAttributeStore::::new(); + + // Store attributes for the start point (medoid) + // Start points are stored at indices num_vectors..num_vectors+frozen_points + let medoid_idx = compute_medoid_index(&data); + let start_point_id = num_vectors as u32; // Start points begin at max_points + let medoid_attrs = attributes.get(medoid_idx).cloned().unwrap_or_default(); + use diskann_label_filter::traits::attribute_store::AttributeStore; + attribute_store.set_element(&start_point_id, &medoid_attrs)?; + + let doc_provider = DocumentProvider::new(inner_provider, attribute_store); + + // Create a new DiskANNIndex with DocumentProvider + let doc_index = Arc::new(DiskANNIndex::new(config, doc_provider, None)); + + // 6. Build index by inserting vectors and attributes (parallel) + writeln!( + output, + "Building index with {} vectors using {} threads...", + num_vectors, build.num_threads + )?; + let timer = std::time::Instant::now(); + + let insert_strategy: DocumentInsertStrategy<_, [T]> = + DocumentInsertStrategy::new(common::FullPrecision); + let rt = utils::tokio::runtime(build.num_threads)?; + + // Create control block for parallel work distribution + let data_arc = Arc::new(data); + let attributes_arc = Arc::new(attributes); + let control_block = DocumentControlBlock::new( + data_arc.clone(), + attributes_arc.clone(), + output.draw_target(), + )?; + + let num_tasks = build.num_threads; + let insert_latencies = rt.block_on(async { + let tasks: Vec<_> = (0..num_tasks) + .map(|_| { + let block = control_block.clone(); + let index = doc_index.clone(); + let strategy = insert_strategy; + tokio::spawn(async move { + let mut latencies = Vec::::new(); + let ctx = DefaultContext; + loop { + match block.next() { + Some((id, vector, attrs)) => { + let doc = Document::new(vector, attrs); + let start = std::time::Instant::now(); + let result = + index.insert(strategy, &ctx, &(id as u32), &doc).await; + latencies.push(MicroSeconds::from(start.elapsed())); + + if let Err(e) = result { + block.cancel(); + return Err(e); + } + } + None => return Ok(latencies), + } + } + }) + }) + .collect(); + + // Collect results from all tasks + let mut all_latencies = Vec::with_capacity(num_vectors); + for task in tasks { + let task_latencies = task.await??; + all_latencies.extend(task_latencies); + } + Ok::<_, anyhow::Error>(all_latencies) + })?; + + let build_time: MicroSeconds = timer.elapsed().into(); + writeln!(output, " Index built in {} s", build_time.as_seconds())?; + + let insert_percentiles = percentiles::compute_percentiles(&mut insert_latencies.clone())?; + // ===================== + // Search Phase + // ===================== + let search_input = &self.input.search; + + // Load query vectors (same type as data for compatible distance computation) + writeln!(output, "\nLoading query vectors...")?; + let query_path: &Path = search_input.queries.as_ref(); + let queries: Matrix = datafiles::load_dataset(BinFile(query_path))?; + let num_queries = queries.nrows(); + writeln!(output, " Loaded {} queries", num_queries)?; + + // Load and parse query predicates + writeln!(output, "Loading query predicates...")?; + let predicate_path: &Path = search_input.query_predicates.as_ref(); + let parsed_predicates = read_and_parse_queries(predicate_path)?; + writeln!(output, " Loaded {} predicates", parsed_predicates.len())?; + + if num_queries != parsed_predicates.len() { + return Err(anyhow::anyhow!( + "Mismatch: {} queries but {} predicates", + 
+                num_queries,
+                parsed_predicates.len()
+            ));
+        }
+
+        // Load groundtruth
+        writeln!(output, "Loading groundtruth...")?;
+        let gt_path: &Path = search_input.groundtruth.as_ref();
+        let groundtruth: Vec<Vec<u32>> = datafiles::load_range_groundtruth(BinFile(gt_path))?;
+        writeln!(
+            output,
+            " Loaded groundtruth with {} rows",
+            groundtruth.len()
+        )?;
+
+        // Run filtered searches
+        writeln!(
+            output,
+            "\nRunning filtered searches (beta={})...",
+            search_input.beta
+        )?;
+        let mut search_results = Vec::new();
+
+        for num_threads in &search_input.num_threads {
+            for run in &search_input.runs {
+                for &search_l in &run.search_l {
+                    writeln!(
+                        output,
+                        " threads={}, search_n={}, search_l={}...",
+                        num_threads, run.search_n, search_l
+                    )?;
+
+                    let search_run_result = run_filtered_search(
+                        &doc_index,
+                        &queries,
+                        &parsed_predicates,
+                        &groundtruth,
+                        search_input.beta,
+                        *num_threads,
+                        run.search_n,
+                        search_l,
+                        run.recall_k,
+                        search_input.reps,
+                    )?;
+
+                    writeln!(
+                        output,
+                        " recall={:.4}, mean_qps={:.1}",
+                        search_run_result.recall.average,
+                        if search_run_result.qps.is_empty() {
+                            0.0
+                        } else {
+                            search_run_result.qps.iter().sum::<f64>()
+                                / search_run_result.qps.len() as f64
+                        }
+                    )?;
+
+                    search_results.push(search_run_result);
+                }
+            }
+        }
+
+        let stats = DocumentIndexStats {
+            num_vectors,
+            dim,
+            label_count,
+            data_load_time,
+            label_load_time,
+            build_time,
+            insert_latencies: insert_percentiles,
+            build_params: BuildParamsStats {
+                max_degree: build.max_degree,
+                l_build: build.l_build,
+                alpha: build.alpha,
+            },
+            search: search_results,
+        };
+
+        writeln!(output, "\n{}", stats)?;
+        Ok(stats)
+    }
+}
+
+/// Local results from a partition of queries.
+struct SearchLocalResults {
+    ids: Matrix<u32>,
+    distances: Vec<Vec<f32>>,
+    latencies: Vec<MicroSeconds>,
+    comparisons: Vec<u32>,
+    hops: Vec<u32>,
+}
+
+impl SearchLocalResults {
+    fn merge(all: &[SearchLocalResults]) -> anyhow::Result<Self> {
+        let first = all
+            .first()
+            .ok_or_else(|| anyhow::anyhow!("empty results"))?;
+        let num_ids = first.ids.ncols();
+        let total_rows: usize = all.iter().map(|r| r.ids.nrows()).sum();
+
+        let mut ids = Matrix::new(0, total_rows, num_ids);
+        let mut output_row = 0;
+        for r in all {
+            for input_row in r.ids.row_iter() {
+                ids.row_mut(output_row).copy_from_slice(input_row);
+                output_row += 1;
+            }
+        }
+
+        let mut distances = Vec::new();
+        let mut latencies = Vec::new();
+        let mut comparisons = Vec::new();
+        let mut hops = Vec::new();
+        for r in all {
+            distances.extend_from_slice(&r.distances);
+            latencies.extend_from_slice(&r.latencies);
+            comparisons.extend_from_slice(&r.comparisons);
+            hops.extend_from_slice(&r.hops);
+        }
+
+        Ok(Self {
+            ids,
+            distances,
+            latencies,
+            comparisons,
+            hops,
+        })
+    }
+}
+
+/// Run filtered search with the given parameters.
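+///
+/// Each (thread-count, `search_l`) combination is repeated `reps` times: recall
+/// is computed from the merged results of the first repetition, while QPS is
+/// derived from the wall-clock time of every repetition.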
+#[allow(clippy::too_many_arguments)]
+fn run_filtered_search<T, DP>(
+    index: &Arc<DiskANNIndex<DP>>,
+    queries: &Matrix<T>,
+    predicates: &[(usize, ASTExpr)],
+    groundtruth: &Vec<Vec<u32>>,
+    beta: f32,
+    num_threads: NonZeroUsize,
+    search_n: usize,
+    search_l: usize,
+    recall_k: usize,
+    reps: NonZeroUsize,
+) -> anyhow::Result<SearchRunStats>
+where
+    T: bytemuck::Pod + Copy + Send + Sync + 'static,
+    DP: diskann::provider::DataProvider<
+            Context = DefaultContext,
+            ExternalId = u32,
+            InternalId = u32,
+        > + Send
+        + Sync
+        + 'static,
+    InlineBetaStrategy<common::FullPrecision>:
+        diskann::graph::glue::SearchStrategy<DP, FilteredQuery<Vec<T>>>,
+{
+    let rt = utils::tokio::runtime(num_threads.get())?;
+    let num_queries = queries.nrows();
+
+    let mut all_rep_results = Vec::with_capacity(reps.get());
+    let mut rep_latencies = Vec::with_capacity(reps.get());
+
+    for _ in 0..reps.get() {
+        let start = std::time::Instant::now();
+        let results = rt.block_on(run_search_parallel(
+            index.clone(),
+            queries,
+            predicates,
+            beta,
+            num_threads,
+            search_n,
+            search_l,
+        ))?;
+        rep_latencies.push(MicroSeconds::from(start.elapsed()));
+        all_rep_results.push(results);
+    }
+
+    // Merge results from first rep for recall calculation
+    let merged = SearchLocalResults::merge(&all_rep_results[0])?;
+
+    // Compute recall
+    let recall_metrics: recall::RecallMetrics =
+        (&recall::knn(groundtruth, None, &merged.ids, recall_k, search_n, false)?).into();
+
+    // Compute per-query details (only for queries with recall < 1)
+    let per_query_details: Vec<PerQueryDetails> = (0..num_queries)
+        .filter_map(|query_idx| {
+            let result_ids: Vec<u32> = merged
+                .ids
+                .row(query_idx)
+                .iter()
+                .copied()
+                .filter(|&id| id != u32::MAX)
+                .collect();
+            let result_distances: Vec<f32> = merged
+                .distances
+                .get(query_idx)
+                .map(|d| d.iter().copied().filter(|&dist| dist != f32::MAX).collect())
+                .unwrap_or_default();
+            // Only keep top 20 from ground truth
+            let gt_ids: Vec<u32> = groundtruth
+                .get(query_idx)
+                .map(|gt| gt.iter().take(20).copied().collect())
+                .unwrap_or_default();
+
+            // Per-query recall: |result_ids ∩ top-recall_k gt ids| / |gt set|
+            let result_set: std::collections::HashSet<u32> = result_ids.iter().copied().collect();
+            let gt_set: std::collections::HashSet<u32> =
+                gt_ids.iter().take(recall_k).copied().collect();
+            let intersection = result_set.intersection(&gt_set).count();
+            let per_query_recall = if gt_set.is_empty() {
+                1.0
+            } else {
+                intersection as f64 / gt_set.len() as f64
+            };
+
+            // Only include queries with imperfect recall
+            if per_query_recall >= 1.0 {
+                return None;
+            }
+
+            let (_, ref ast_expr) = predicates[query_idx];
+            let filter_str = format!("{:?}", ast_expr);
+
+            Some(PerQueryDetails {
+                query_id: query_idx,
+                filter: filter_str,
+                recall: per_query_recall,
+                result_ids,
+                result_distances,
+                groundtruth_ids: gt_ids,
+            })
+        })
+        .collect();
+
+    // Compute QPS from rep latencies
+    let qps: Vec<f64> = rep_latencies
+        .iter()
+        .map(|l| num_queries as f64 / l.as_seconds())
+        .collect();
+
+    // Aggregate per-query latencies across all reps
+    let (all_latencies, all_cmps, all_hops): (Vec<_>, Vec<_>, Vec<_>) = all_rep_results
+        .iter()
+        .map(|results| {
+            let mut lat = Vec::new();
+            let mut cmp = Vec::new();
+            let mut hop = Vec::new();
+            for r in results {
+                lat.extend_from_slice(&r.latencies);
+                cmp.extend_from_slice(&r.comparisons);
+                hop.extend_from_slice(&r.hops);
+            }
+            (lat, cmp, hop)
+        })
+        .fold(
+            (Vec::new(), Vec::new(), Vec::new()),
+            |(mut a, mut b, mut c): (Vec<MicroSeconds>, Vec<u32>, Vec<u32>), (x, y, z)| {
+                a.extend(x);
+                b.extend(y);
+                c.extend(z);
+                (a, b, c)
+            },
+        );
+
+    let mut query_latencies = all_latencies;
+    let percentiles::Percentiles { mean, p90, p99, .. } =
+        percentiles::compute_percentiles(&mut query_latencies)?;
+
+    let mean_cmps = if all_cmps.is_empty() {
+        0.0
+    } else {
+        all_cmps.iter().map(|&x| x as f32).sum::<f32>() / all_cmps.len() as f32
+    };
+    let mean_hops = if all_hops.is_empty() {
+        0.0
+    } else {
+        all_hops.iter().map(|&x| x as f32).sum::<f32>() / all_hops.len() as f32
+    };
+
+    Ok(SearchRunStats {
+        num_threads: num_threads.get(),
+        num_queries,
+        search_n,
+        search_l,
+        recall: recall_metrics,
+        qps,
+        wall_clock_time: rep_latencies,
+        mean_latency: mean,
+        p90_latency: p90,
+        p99_latency: p99,
+        mean_cmps,
+        mean_hops,
+        per_query_details: Some(per_query_details),
+    })
+}
+
+async fn run_search_parallel<T, DP>(
+    index: Arc<DiskANNIndex<DP>>,
+    queries: &Matrix<T>,
+    predicates: &[(usize, ASTExpr)],
+    beta: f32,
+    num_tasks: NonZeroUsize,
+    search_n: usize,
+    search_l: usize,
+) -> anyhow::Result<Vec<SearchLocalResults>>
+where
+    T: bytemuck::Pod + Copy + Send + Sync + 'static,
+    DP: diskann::provider::DataProvider<
+            Context = DefaultContext,
+            ExternalId = u32,
+            InternalId = u32,
+        > + Send
+        + Sync
+        + 'static,
+    InlineBetaStrategy<common::FullPrecision>:
+        diskann::graph::glue::SearchStrategy<DP, FilteredQuery<Vec<T>>>,
+{
+    let num_queries = queries.nrows();
+
+    // Plan query partitions
+    let partitions: Result<Vec<_>, _> = (0..num_tasks.get())
+        .map(|task_id| async_tools::partition(num_queries, num_tasks, task_id))
+        .collect();
+    let partitions = partitions?;
+
+    // We need to clone data for each task
+    let queries_arc = Arc::new(queries.clone());
+    let predicates_arc = Arc::new(predicates.to_vec());
+
+    let handles: Vec<_> = partitions
+        .into_iter()
+        .map(|range| {
+            let index = index.clone();
+            let queries = queries_arc.clone();
+            let predicates = predicates_arc.clone();
+            tokio::spawn(async move {
+                run_search_local(index, queries, predicates, beta, range, search_n, search_l).await
+            })
+        })
+        .collect();
+
+    let mut results = Vec::new();
+    for h in handles {
+        results.push(h.await??);
+    }
+
+    Ok(results)
+}
+
+async fn run_search_local<T, DP>(
+    index: Arc<DiskANNIndex<DP>>,
+    queries: Arc<Matrix<T>>,
+    predicates: Arc<Vec<(usize, ASTExpr)>>,
+    beta: f32,
+    range: std::ops::Range<usize>,
+    search_n: usize,
+    search_l: usize,
+) -> anyhow::Result<SearchLocalResults>
+where
+    T: bytemuck::Pod + Copy + Send + Sync + 'static,
+    DP: diskann::provider::DataProvider<
+            Context = DefaultContext,
+            ExternalId = u32,
+            InternalId = u32,
+        > + Send
+        + Sync,
+    InlineBetaStrategy<common::FullPrecision>:
+        diskann::graph::glue::SearchStrategy<DP, FilteredQuery<Vec<T>>>,
+{
+    let mut ids = Matrix::new(0, range.len(), search_n);
+    let mut all_distances: Vec<Vec<f32>> = Vec::with_capacity(range.len());
+    let mut latencies = Vec::with_capacity(range.len());
+    let mut comparisons = Vec::with_capacity(range.len());
+    let mut hops = Vec::with_capacity(range.len());
+
+    let ctx = DefaultContext;
+    let search_params = SearchParams::new_default(search_n, search_l)?;
+
+    for (output_idx, query_idx) in range.enumerate() {
+        let query_vec = queries.row(query_idx);
+        let (_, ref ast_expr) = predicates[query_idx];
+
+        let strategy = InlineBetaStrategy::new(beta, common::FullPrecision);
+        let query_vec_owned = query_vec.to_vec();
+        let filtered_query: FilteredQuery<Vec<T>> =
+            FilteredQuery::new(query_vec_owned, ast_expr.clone());
+
+        let start = std::time::Instant::now();
+
+        let mut distances = vec![0.0f32; search_n];
+        let result_ids = ids.row_mut(output_idx);
+        let mut result_buffer = search_output_buffer::IdDistance::new(result_ids, &mut distances);
+
+        let stats = index
+            .search(
+                &strategy,
+                &ctx,
+                &filtered_query,
+                &search_params,
+                &mut result_buffer,
+            )
+            .await?;
+
+        let result_count = stats.result_count.into_usize();
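+        // Pad any unused result slots with sentinel values; the recall and
+        // per-query reporting code above filters out u32::MAX ids and
+        // f32::MAX distances.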
+        result_ids[result_count..].fill(u32::MAX);
+        distances[result_count..].fill(f32::MAX);
+
+        latencies.push(MicroSeconds::from(start.elapsed()));
+        comparisons.push(stats.cmps);
+        hops.push(stats.hops);
+        all_distances.push(distances);
+    }
+
+    Ok(SearchLocalResults {
+        ids,
+        distances: all_distances,
+        latencies,
+        comparisons,
+        hops,
+    })
+}
+
+#[derive(Debug, Serialize)]
+pub struct BuildParamsStats {
+    pub max_degree: usize,
+    pub l_build: usize,
+    pub alpha: f32,
+}
+
+/// Helper module for serializing arrays as compact single-line JSON strings
+mod compact_array {
+    use serde::Serializer;
+
+    pub fn serialize_u32_vec<S>(vec: &Vec<u32>, serializer: S) -> Result<S::Ok, S::Error>
+    where
+        S: Serializer,
+    {
+        // Serialize as a string containing the compact JSON array
+        let compact = serde_json::to_string(vec).unwrap_or_default();
+        serializer.serialize_str(&compact)
+    }
+
+    pub fn serialize_f32_vec<S>(vec: &Vec<f32>, serializer: S) -> Result<S::Ok, S::Error>
+    where
+        S: Serializer,
+    {
+        // Serialize as a string containing the compact JSON array
+        let compact = serde_json::to_string(vec).unwrap_or_default();
+        serializer.serialize_str(&compact)
+    }
+}
+
+/// Per-query detailed results for debugging/analysis
+#[derive(Debug, Serialize)]
+pub struct PerQueryDetails {
+    pub query_id: usize,
+    pub filter: String,
+    pub recall: f64,
+    #[serde(serialize_with = "compact_array::serialize_u32_vec")]
+    pub result_ids: Vec<u32>,
+    #[serde(serialize_with = "compact_array::serialize_f32_vec")]
+    pub result_distances: Vec<f32>,
+    #[serde(serialize_with = "compact_array::serialize_u32_vec")]
+    pub groundtruth_ids: Vec<u32>,
+}
+
+/// Results from a single search configuration (one search_l value).
+#[derive(Debug, Serialize)]
+pub struct SearchRunStats {
+    pub num_threads: usize,
+    pub num_queries: usize,
+    pub search_n: usize,
+    pub search_l: usize,
+    pub recall: recall::RecallMetrics,
+    pub qps: Vec<f64>,
+    pub wall_clock_time: Vec<MicroSeconds>,
+    pub mean_latency: f64,
+    pub p90_latency: MicroSeconds,
+    pub p99_latency: MicroSeconds,
+    pub mean_cmps: f32,
+    pub mean_hops: f32,
+    #[serde(skip_serializing_if = "Option::is_none")]
+    pub per_query_details: Option<Vec<PerQueryDetails>>,
+}
+
+#[derive(Debug, Serialize)]
+pub struct DocumentIndexStats {
+    pub num_vectors: usize,
+    pub dim: usize,
+    pub label_count: usize,
+    pub data_load_time: MicroSeconds,
+    pub label_load_time: MicroSeconds,
+    pub build_time: MicroSeconds,
+    pub insert_latencies: percentiles::Percentiles,
+    pub build_params: BuildParamsStats,
+    pub search: Vec<SearchRunStats>,
+}
+
+impl std::fmt::Display for DocumentIndexStats {
+    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
+        writeln!(f, "Document Index Build Stats:")?;
+        writeln!(f, " Vectors: {} x {}", self.num_vectors, self.dim)?;
+        writeln!(f, " Label Count: {}", self.label_count)?;
+        writeln!(
+            f,
+            " Data Load Time: {} s",
+            self.data_load_time.as_seconds()
+        )?;
+        writeln!(
+            f,
+            " Label Load Time: {} s",
+            self.label_load_time.as_seconds()
+        )?;
+        writeln!(f, " Total Build Time: {} s", self.build_time.as_seconds())?;
+        writeln!(f, " Insert Latencies:")?;
+        writeln!(f, "  Mean: {} us", self.insert_latencies.mean)?;
+        writeln!(f, "  P50: {} us", self.insert_latencies.median)?;
+        writeln!(f, "  P90: {} us", self.insert_latencies.p90)?;
+        writeln!(f, "  P99: {} us", self.insert_latencies.p99)?;
+        writeln!(f, " Build Parameters:")?;
+        writeln!(f, "  max_degree (R): {}", self.build_params.max_degree)?;
+        writeln!(f, "  l_build (L): {}", self.build_params.l_build)?;
+        writeln!(f, "  alpha: {}", self.build_params.alpha)?;
+
+        if !self.search.is_empty() {
+            writeln!(f, "\nFiltered Search Results:")?;
+            writeln!(
+                f,
+                " {:>8} {:>8} {:>10} {:>10} {:>15} {:>12} {:>12} {:>10} {:>8} {:>10} {:>12}",
+                "L", "KNN", "Avg Cmps", "Avg Hops", "QPS -mean(max)", "Avg Latency", "p99 Latency", "Recall", "Threads", "Queries", "WallClock(s)"
+            )?;
+            for s in &self.search {
+                let mean_qps = if s.qps.is_empty() {
+                    0.0
+                } else {
+                    s.qps.iter().sum::<f64>() / s.qps.len() as f64
+                };
+                let max_qps = s.qps.iter().cloned().fold(0.0_f64, f64::max);
+                let mean_wall_clock = if s.wall_clock_time.is_empty() {
+                    0.0
+                } else {
+                    s.wall_clock_time.iter().map(|t| t.as_seconds()).sum::<f64>()
+                        / s.wall_clock_time.len() as f64
+                };
+                writeln!(
+                    f,
+                    " {:>8} {:>8} {:>10.1} {:>10.1} {:>7.1}({:>5.1}) {:>12.1} {:>12} {:>10.4} {:>8} {:>10} {:>12.3}",
+                    s.search_l,
+                    s.search_n,
+                    s.mean_cmps,
+                    s.mean_hops,
+                    mean_qps,
+                    max_qps,
+                    s.mean_latency,
+                    s.p99_latency,
+                    s.recall.average,
+                    s.num_threads,
+                    s.num_queries,
+                    mean_wall_clock
+                )?;
+            }
+        }
+        Ok(())
+    }
+}
+
+// ================================
+// Parallel Build Support
+// ================================
+
+fn make_progress_bar(
+    nrows: usize,
+    draw_target: indicatif::ProgressDrawTarget,
+) -> anyhow::Result<ProgressBar> {
+    let progress = ProgressBar::with_draw_target(Some(nrows as u64), draw_target);
+    progress.set_style(ProgressStyle::with_template(
+        "Building [{elapsed_precise}] {wide_bar} {percent}",
+    )?);
+    Ok(progress)
+}
+
+/// Control block for parallel document insertion.
+/// Manages work distribution and progress tracking across multiple tasks.
+struct DocumentControlBlock<T> {
+    data: Arc<Matrix<T>>,
+    attributes: Arc<Vec<Vec<Attribute>>>,
+    position: AtomicUsize,
+    cancel: AtomicBool,
+    progress: ProgressBar,
+}
+
+impl<T> DocumentControlBlock<T> {
+    fn new(
+        data: Arc<Matrix<T>>,
+        attributes: Arc<Vec<Vec<Attribute>>>,
+        draw_target: indicatif::ProgressDrawTarget,
+    ) -> anyhow::Result<Arc<Self>> {
+        let nrows = data.nrows();
+        Ok(Arc::new(Self {
+            data,
+            attributes,
+            position: AtomicUsize::new(0),
+            cancel: AtomicBool::new(false),
+            progress: make_progress_bar(nrows, draw_target)?,
+        }))
+    }
+
+    /// Return the next document data to insert: (id, vector_slice, attributes).
+    fn next(&self) -> Option<(usize, &[T], Vec<Attribute>)> {
+        let cancel = self.cancel.load(Ordering::Relaxed);
+        if cancel {
+            None
+        } else {
+            let i = self.position.fetch_add(1, Ordering::Relaxed);
+            match self.data.get_row(i) {
+                Some(row) => {
+                    let attrs = self.attributes.get(i).cloned().unwrap_or_default();
+                    self.progress.inc(1);
+                    Some((i, row, attrs))
+                }
+                None => None,
+            }
+        }
+    }
+
+    /// Tell all users of the control block to cancel and return early.
+    fn cancel(&self) {
+        self.cancel.store(true, Ordering::Relaxed);
+    }
+}
+
+impl<T> Drop for DocumentControlBlock<T> {
+    fn drop(&mut self) {
+        self.progress.finish();
+    }
+}
diff --git a/diskann-benchmark/src/backend/document_index/mod.rs b/diskann-benchmark/src/backend/document_index/mod.rs
new file mode 100644
index 000000000..9937590cc
--- /dev/null
+++ b/diskann-benchmark/src/backend/document_index/mod.rs
@@ -0,0 +1,13 @@
+/*
+ * Copyright (c) Microsoft Corporation.
+ * Licensed under the MIT license.
+ */
+
+//! Backend benchmark implementation for document index with label filters.
+//!
+//! This benchmark tests the DocumentInsertStrategy which enables inserting
+//! Document objects (vector + attributes) into a DiskANN index.
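+//!
+//! A rough sketch of the benchmark input, mirroring the `Example`
+//! implementation in `inputs/document_index.rs` (file names and the exact
+//! serialized spelling of the distance are placeholders):
+//!
+//! ```json
+//! {
+//!   "build": {
+//!     "data_type": "float32",
+//!     "data": "data.fbin",
+//!     "data_labels": "data.label.jsonl",
+//!     "distance": "squared_l2",
+//!     "max_degree": 32,
+//!     "l_build": 50,
+//!     "alpha": 1.2
+//!   },
+//!   "search": {
+//!     "queries": "queries.fbin",
+//!     "query_predicates": "query.label.jsonl",
+//!     "groundtruth": "groundtruth.bin",
+//!     "beta": 0.5,
+//!     "runs": [{ "search_n": 10, "search_l": [20, 30, 40, 50], "recall_k": 10 }]
+//!   }
+//! }
+//! ```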
+
+mod benchmark;
+
+pub(crate) use benchmark::register_benchmarks;
diff --git a/diskann-benchmark/src/backend/index/result.rs b/diskann-benchmark/src/backend/index/result.rs
index c7e2ab75c..21d74f915 100644
--- a/diskann-benchmark/src/backend/index/result.rs
+++ b/diskann-benchmark/src/backend/index/result.rs
@@ -109,6 +109,7 @@ impl std::fmt::Display for AggregatedSearchResults {
 #[derive(Debug, Serialize)]
 pub(super) struct SearchResults {
     pub(super) num_tasks: usize,
+    pub(super) num_queries: usize,
     pub(super) search_n: usize,
     pub(super) search_l: usize,
     pub(super) qps: Vec<f64>,
@@ -143,6 +144,7 @@ impl SearchResults {
 
         Self {
             num_tasks: setup.tasks.into(),
+            num_queries: recall.num_queries,
             search_n: parameters.k_value,
             search_l: parameters.l_value,
             qps,
@@ -182,6 +184,8 @@ where
                 "p99 Latency",
                 "Recall",
                 "Threads",
+                "Queries",
+                "WallClock(s)",
             ]
         } else {
             &[
@@ -194,6 +198,8 @@ where
                 "p99 Latency",
                 "Recall",
                 "Threads",
+                "Queries",
+                "WallClock(s)",
             ]
         };
 
@@ -237,6 +243,13 @@ where
             );
             row.insert(format!("{:3}", r.recall.average), col_idx + 7);
             row.insert(r.num_tasks, col_idx + 8);
+            row.insert(r.num_queries, col_idx + 9);
+            let mean_wall_clock = if r.search_latencies.is_empty() {
+                0.0
+            } else {
+                r.search_latencies.iter().map(|t| t.as_seconds()).sum::<f64>()
+                    / r.search_latencies.len() as f64
+            };
+            row.insert(format!("{:.3}", mean_wall_clock), col_idx + 10);
         });
 
         write!(f, "{}", table)
diff --git a/diskann-benchmark/src/backend/mod.rs b/diskann-benchmark/src/backend/mod.rs
index 24fe91d7e..5dc1967de 100644
--- a/diskann-benchmark/src/backend/mod.rs
+++ b/diskann-benchmark/src/backend/mod.rs
@@ -4,6 +4,7 @@
  */
 
 mod disk_index;
+mod document_index;
 mod exhaustive;
 mod filters;
 mod index;
@@ -13,4 +14,5 @@ pub(crate) fn register_benchmarks(registry: &mut diskann_benchmark_runner::regis
     disk_index::register_benchmarks(registry);
     index::register_benchmarks(registry);
     filters::register_benchmarks(registry);
+    document_index::register_benchmarks(registry);
 }
diff --git a/diskann-benchmark/src/inputs/document_index.rs b/diskann-benchmark/src/inputs/document_index.rs
new file mode 100644
index 000000000..b1a36e48a
--- /dev/null
+++ b/diskann-benchmark/src/inputs/document_index.rs
@@ -0,0 +1,177 @@
+/*
+ * Copyright (c) Microsoft Corporation.
+ * Licensed under the MIT license.
+ */
+
+//! Input types for document index benchmarks using DocumentInsertStrategy.
+
+use std::num::NonZeroUsize;
+
+use anyhow::Context;
+use diskann_benchmark_runner::{
+    files::InputFile, utils::datatype::DataType, CheckDeserialization, Checker,
+};
+use serde::{Deserialize, Serialize};
+
+use super::async_::GraphSearch;
+use crate::inputs::{as_input, Example, Input};
+
+//////////////
+// Registry //
+//////////////
+
+as_input!(DocumentIndexBuild);
+
+pub(super) fn register_inputs(
+    registry: &mut diskann_benchmark_runner::registry::Inputs,
+) -> anyhow::Result<()> {
+    registry.register(Input::<DocumentIndexBuild>::new())?;
+    Ok(())
+}
+
+/// Build parameters for document index construction.
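+///
+/// `max_degree` corresponds to the graph degree bound R, `l_build` to the
+/// build-time candidate list size L, and `alpha` to the pruning slack factor;
+/// `num_threads` defaults to 1 when omitted.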
+#[derive(Debug, Serialize, Deserialize)]
+pub(crate) struct DocumentBuildParams {
+    pub(crate) data_type: DataType,
+    pub(crate) data: InputFile,
+    pub(crate) data_labels: InputFile,
+    pub(crate) distance: crate::utils::SimilarityMeasure,
+    pub(crate) max_degree: usize,
+    pub(crate) l_build: usize,
+    pub(crate) alpha: f32,
+    #[serde(default = "default_num_threads")]
+    pub(crate) num_threads: usize,
+}
+
+fn default_num_threads() -> usize {
+    1
+}
+
+impl CheckDeserialization for DocumentBuildParams {
+    fn check_deserialization(&mut self, checker: &mut Checker) -> Result<(), anyhow::Error> {
+        self.data.check_deserialization(checker)?;
+        self.data_labels.check_deserialization(checker)?;
+        if self.max_degree == 0 {
+            return Err(anyhow::anyhow!("max_degree must be > 0"));
+        }
+        if self.l_build == 0 {
+            return Err(anyhow::anyhow!("l_build must be > 0"));
+        }
+        if self.alpha <= 0.0 {
+            return Err(anyhow::anyhow!("alpha must be > 0"));
+        }
+        Ok(())
+    }
+}
+
+#[derive(Debug, Serialize, Deserialize)]
+pub(crate) struct DocumentSearchParams {
+    pub(crate) queries: InputFile,
+    pub(crate) query_predicates: InputFile,
+    pub(crate) groundtruth: InputFile,
+    pub(crate) beta: f32,
+    #[serde(default = "default_reps")]
+    pub(crate) reps: NonZeroUsize,
+    #[serde(default = "default_thread_counts")]
+    pub(crate) num_threads: Vec<NonZeroUsize>,
+    pub(crate) runs: Vec<GraphSearch>,
+}
+
+fn default_reps() -> NonZeroUsize {
+    NonZeroUsize::new(5).unwrap()
+}
+
+fn default_thread_counts() -> Vec<NonZeroUsize> {
+    vec![NonZeroUsize::new(1).unwrap()]
+}
+
+impl CheckDeserialization for DocumentSearchParams {
+    fn check_deserialization(&mut self, checker: &mut Checker) -> Result<(), anyhow::Error> {
+        self.queries.check_deserialization(checker)?;
+        self.query_predicates.check_deserialization(checker)?;
+        self.groundtruth.check_deserialization(checker)?;
+        if self.beta <= 0.0 || self.beta > 1.0 {
+            return Err(anyhow::anyhow!(
+                "beta must be in range (0, 1], got: {}",
+                self.beta
+            ));
+        }
+        for (i, run) in self.runs.iter_mut().enumerate() {
+            run.check_deserialization(checker)
+                .with_context(|| format!("search run {}", i))?;
+        }
+        Ok(())
+    }
+}
+
+#[derive(Debug, Serialize, Deserialize)]
+pub(crate) struct DocumentIndexBuild {
+    pub(crate) build: DocumentBuildParams,
+    pub(crate) search: DocumentSearchParams,
+}
+
+impl DocumentIndexBuild {
+    pub(crate) const fn tag() -> &'static str {
+        "document-index-build"
+    }
+}
+
+impl CheckDeserialization for DocumentIndexBuild {
+    fn check_deserialization(&mut self, checker: &mut Checker) -> Result<(), anyhow::Error> {
+        self.build.check_deserialization(checker)?;
+        self.search.check_deserialization(checker)?;
+        Ok(())
+    }
+}
+
+impl Example for DocumentIndexBuild {
+    fn example() -> Self {
+        Self {
+            build: DocumentBuildParams {
+                data_type: DataType::Float32,
+                data: InputFile::new("data.fbin"),
+                data_labels: InputFile::new("data.label.jsonl"),
+                distance: crate::utils::SimilarityMeasure::SquaredL2,
+                max_degree: 32,
+                l_build: 50,
+                alpha: 1.2,
+                num_threads: 1,
+            },
+            search: DocumentSearchParams {
+                queries: InputFile::new("queries.fbin"),
+                query_predicates: InputFile::new("query.label.jsonl"),
+                groundtruth: InputFile::new("groundtruth.bin"),
+                beta: 0.5,
+                reps: default_reps(),
+                num_threads: default_thread_counts(),
+                runs: vec![GraphSearch {
+                    search_n: 10,
+                    search_l: vec![20, 30, 40, 50],
+                    recall_k: 10,
+                }],
+            },
+        }
+    }
+}
+
+impl std::fmt::Display for DocumentIndexBuild {
+    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
+        writeln!(f, "Document Index Build with Label Filters\n")?;
+        writeln!(f, "tag: \"{}\"", Self::tag())?;
+        writeln!(
+            f,
+            "\nBuild: data={}, labels={}, R={}, L={}, alpha={}",
+            self.build.data.display(),
+            self.build.data_labels.display(),
+            self.build.max_degree,
+            self.build.l_build,
+            self.build.alpha
+        )?;
+        writeln!(
+            f,
+            "Search: queries={}, beta={}",
+            self.search.queries.display(),
+            self.search.beta
+        )?;
+        Ok(())
+    }
+}
diff --git a/diskann-benchmark/src/inputs/mod.rs b/diskann-benchmark/src/inputs/mod.rs
index a0ae1a982..65de65a41 100644
--- a/diskann-benchmark/src/inputs/mod.rs
+++ b/diskann-benchmark/src/inputs/mod.rs
@@ -5,6 +5,7 @@
 
 pub(crate) mod async_;
 pub(crate) mod disk;
+pub(crate) mod document_index;
 pub(crate) mod exhaustive;
 pub(crate) mod filters;
 pub(crate) mod save_and_load;
@@ -16,6 +17,7 @@ pub(crate) fn register_inputs(
     exhaustive::register_inputs(registry)?;
     disk::register_inputs(registry)?;
     filters::register_inputs(registry)?;
+    document_index::register_inputs(registry)?;
     Ok(())
 }
diff --git a/diskann-benchmark/src/utils/recall.rs b/diskann-benchmark/src/utils/recall.rs
index dcbe86d94..50ef7e430 100644
--- a/diskann-benchmark/src/utils/recall.rs
+++ b/diskann-benchmark/src/utils/recall.rs
@@ -3,6 +3,7 @@
  * Licensed under the MIT license.
  */
 
+pub(crate) use benchmark_core::recall::knn;
 use diskann_benchmark_core as benchmark_core;
 use serde::Serialize;
 
diff --git a/diskann-benchmark/src/utils/tokio.rs b/diskann-benchmark/src/utils/tokio.rs
index 72dbeb918..21c78abb2 100644
--- a/diskann-benchmark/src/utils/tokio.rs
+++ b/diskann-benchmark/src/utils/tokio.rs
@@ -3,6 +3,13 @@
  * Licensed under the MIT license.
  */
 
+/// Create a generic multi-threaded runtime with `num_threads` worker threads.
+pub(crate) fn runtime(num_threads: usize) -> anyhow::Result<tokio::runtime::Runtime> {
+    Ok(tokio::runtime::Builder::new_multi_thread()
+        .worker_threads(num_threads)
+        .build()?)
+}
+
 /// Create a current-thread runtime and block on the given future.
 /// Only for functions that don't need multi-threading
 pub(crate) fn block_on<F: std::future::Future>(future: F) -> F::Output {
diff --git a/diskann-label-filter/src/attribute.rs b/diskann-label-filter/src/attribute.rs
index f0d99bfd9..9eb7ff500 100644
--- a/diskann-label-filter/src/attribute.rs
+++ b/diskann-label-filter/src/attribute.rs
@@ -5,7 +5,6 @@
 
 use std::fmt::Display;
 use std::hash::{Hash, Hasher};
-use std::io::Write;
 
 use serde_json::Value;
 use thiserror::Error;
diff --git a/diskann-label-filter/src/document.rs b/diskann-label-filter/src/document.rs
index 31cad4772..5c817525c 100644
--- a/diskann-label-filter/src/document.rs
+++ b/diskann-label-filter/src/document.rs
@@ -8,12 +8,12 @@ use diskann_utils::reborrow::Reborrow;
 
 ///Simple container class that clients can use to
 /// supply diskann with a vector and its attributes
-pub struct Document<'a, V> {
+pub struct Document<'a, V: ?Sized> {
     vector: &'a V,
     attributes: Vec<Attribute>,
 }
 
-impl<'a, V> Document<'a, V> {
+impl<'a, V: ?Sized> Document<'a, V> {
     pub fn new(vector: &'a V, attributes: Vec<Attribute>) -> Self {
         Self { vector, attributes }
     }
diff --git a/diskann-label-filter/src/encoded_attribute_provider/ast_label_id_mapper.rs b/diskann-label-filter/src/encoded_attribute_provider/ast_label_id_mapper.rs
index 0fa21cc02..8b39d8731 100644
--- a/diskann-label-filter/src/encoded_attribute_provider/ast_label_id_mapper.rs
+++ b/diskann-label-filter/src/encoded_attribute_provider/ast_label_id_mapper.rs
@@ -31,19 +31,14 @@ impl ASTLabelIdMapper {
         Self { attribute_map }
     }
 
-    fn _lookup(
-        encoder: &AttributeEncoder,
-        attribute: &Attribute,
-        field: &str,
-        op: &CompareOp,
-    ) -> ANNResult<ASTIdExpr<u32>> {
+    fn _lookup(encoder: &AttributeEncoder, attribute: &Attribute) -> ANNResult<ASTIdExpr<u32>> {
         match encoder.get(attribute) {
             Some(attribute_id) => Ok(ASTIdExpr::Terminal(attribute_id)),
             None => Err(ANNError::message(
                 ANNErrorKind::Opaque,
                 format!(
-                    "{}+{} present in the query does not exist in the dataset.",
-                    field, op
+                    "{} present in the query does not exist in the dataset.",
+                    attribute
                 ),
             )),
         }
@@ -120,10 +115,10 @@ impl ASTVisitor for ASTLabelIdMapper {
 
         if let Some(attribute) = label_or_none {
             match self.attribute_map.read() {
-                Ok(guard) => Self::_lookup(&guard, &attribute, field, op),
+                Ok(guard) => Self::_lookup(&guard, &attribute),
                 Err(poison_error) => {
                     let attr_map = poison_error.into_inner();
-                    Self::_lookup(&attr_map, &attribute, field, op)
+                    Self::_lookup(&attr_map, &attribute)
                 }
             }
         } else {
diff --git a/diskann-label-filter/src/encoded_attribute_provider/document_insert_strategy.rs b/diskann-label-filter/src/encoded_attribute_provider/document_insert_strategy.rs
new file mode 100644
index 000000000..850976a32
--- /dev/null
+++ b/diskann-label-filter/src/encoded_attribute_provider/document_insert_strategy.rs
@@ -0,0 +1,274 @@
+/*
+ * Copyright (c) Microsoft Corporation. All rights reserved.
+ * Licensed under the MIT license.
+ */
+
+//! A strategy wrapper that enables insertion of [Document] objects into a
+//! [DiskANNIndex] using a [DocumentProvider].
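+//!
+//! A minimal usage sketch (names and setup are illustrative; the
+//! `document_index` benchmark in `diskann-benchmark` shows a complete,
+//! working configuration):
+//!
+//! ```ignore
+//! // Wrap an inner, vector-only insert strategy such as `FullPrecision`.
+//! let strategy: DocumentInsertStrategy<_, [f32]> =
+//!     DocumentInsertStrategy::new(FullPrecision);
+//!
+//! // A `Document` pairs a vector slice with its attributes.
+//! let doc = Document::new(vector.as_slice(), attributes);
+//!
+//! // Insert through a `DiskANNIndex` built over a `DocumentProvider`.
+//! index.insert(strategy, &ctx, &id, &doc).await?;
+//! ```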
+ +use std::marker::PhantomData; + +use diskann::{ + graph::{ + glue::{ + ExpandBeam, InsertStrategy, PruneStrategy, SearchExt, SearchPostProcess, SearchStrategy, + }, + SearchOutputBuffer, + }, + neighbor::Neighbor, + provider::{Accessor, BuildQueryComputer, DataProvider, DelegateNeighbor, HasId}, + ANNResult, +}; + +use super::document_provider::DocumentProvider; +use crate::document::Document; +use crate::encoded_attribute_provider::roaring_attribute_store::RoaringAttributeStore; + +/// A strategy wrapper that enables insertion of [Document] objects. +pub struct DocumentInsertStrategy { + inner: Inner, + _phantom: PhantomData VT>, +} + +impl Clone for DocumentInsertStrategy { + fn clone(&self) -> Self { + Self { + inner: self.inner.clone(), + _phantom: PhantomData, + } + } +} + +impl Copy for DocumentInsertStrategy {} + +impl DocumentInsertStrategy { + pub fn new(inner: Inner) -> Self { + Self { + inner, + _phantom: PhantomData, + } + } + + pub fn inner(&self) -> &Inner { + &self.inner + } +} + +/// Wrapper accessor for Document queries +pub struct DocumentSearchAccessor { + inner: Inner, + _phantom: PhantomData VT>, +} + +impl DocumentSearchAccessor { + pub fn new(inner: Inner) -> Self { + Self { + inner, + _phantom: PhantomData, + } + } +} + +impl HasId for DocumentSearchAccessor +where + Inner: HasId, + VT: ?Sized, +{ + type Id = Inner::Id; +} + +impl Accessor for DocumentSearchAccessor +where + Inner: Accessor, + VT: ?Sized, +{ + type ElementRef<'a> = Inner::ElementRef<'a>; + type Element<'a> + = Inner::Element<'a> + where + Self: 'a; + type Extended = Inner::Extended; + type GetError = Inner::GetError; + + fn get_element( + &mut self, + id: Self::Id, + ) -> impl std::future::Future, Self::GetError>> + Send { + self.inner.get_element(id) + } + + fn on_elements_unordered( + &mut self, + itr: Itr, + f: F, + ) -> impl std::future::Future> + Send + where + Self: Sync, + Itr: Iterator + Send, + F: Send + for<'b> FnMut(Self::ElementRef<'b>, Self::Id), + { + self.inner.on_elements_unordered(itr, f) + } +} + +impl<'doc, Inner, VT> BuildQueryComputer> for DocumentSearchAccessor +where + Inner: BuildQueryComputer, + VT: ?Sized, +{ + type QueryComputerError = Inner::QueryComputerError; + type QueryComputer = Inner::QueryComputer; + + fn build_query_computer( + &self, + from: &Document<'doc, VT>, + ) -> Result { + self.inner.build_query_computer(from.vector()) + } +} + +impl<'this, Inner, VT> DelegateNeighbor<'this> for DocumentSearchAccessor +where + Inner: DelegateNeighbor<'this>, + VT: ?Sized, +{ + type Delegate = Inner::Delegate; + fn delegate_neighbor(&'this mut self) -> Self::Delegate { + self.inner.delegate_neighbor() + } +} + +impl<'doc, Inner, VT> ExpandBeam> for DocumentSearchAccessor +where + Inner: ExpandBeam, + VT: ?Sized, +{ +} + +impl SearchExt for DocumentSearchAccessor +where + Inner: SearchExt, + VT: ?Sized, +{ + fn starting_points( + &self, + ) -> impl std::future::Future>> + Send { + self.inner.starting_points() + } + fn terminate_early(&mut self) -> bool { + self.inner.terminate_early() + } +} + +#[derive(Debug, Default, Clone, Copy)] +pub struct CopyIdsForDocument; + +impl<'doc, A, VT> SearchPostProcess> for CopyIdsForDocument +where + A: BuildQueryComputer>, + VT: ?Sized, +{ + type Error = std::convert::Infallible; + + fn post_process( + &self, + _accessor: &mut A, + _query: &Document<'doc, VT>, + _computer: &>>::QueryComputer, + candidates: I, + output: &mut B, + ) -> impl std::future::Future> + Send + where + I: Iterator> + Send, + B: SearchOutputBuffer + Send + ?Sized, 
+ { + let count = output.extend(candidates.map(|n| (n.id, n.distance))); + std::future::ready(Ok(count)) + } +} + +impl<'doc, Inner, DP, VT> + SearchStrategy>, Document<'doc, VT>> + for DocumentInsertStrategy +where + Inner: InsertStrategy, + DP: DataProvider, + VT: Sync + Send + ?Sized + 'static, +{ + type QueryComputer = Inner::QueryComputer; + type PostProcessor = CopyIdsForDocument; + type SearchAccessorError = Inner::SearchAccessorError; + type SearchAccessor<'a> = DocumentSearchAccessor, VT>; + + fn search_accessor<'a>( + &'a self, + provider: &'a DocumentProvider>, + context: &'a > as DataProvider>::Context, + ) -> Result, Self::SearchAccessorError> { + let inner_accessor = self + .inner + .search_accessor(provider.inner_provider(), context)?; + Ok(DocumentSearchAccessor::new(inner_accessor)) + } + + fn post_processor(&self) -> Self::PostProcessor { + CopyIdsForDocument + } +} + +impl<'doc, Inner, DP, VT> + InsertStrategy>, Document<'doc, VT>> + for DocumentInsertStrategy +where + Inner: InsertStrategy, + DP: DataProvider, + VT: Sync + Send + ?Sized + 'static, +{ + type PruneStrategy = DocumentPruneStrategy; + + fn prune_strategy(&self) -> Self::PruneStrategy { + DocumentPruneStrategy::new(self.inner.prune_strategy()) + } + + fn insert_search_accessor<'a>( + &'a self, + provider: &'a DocumentProvider>, + context: &'a > as DataProvider>::Context, + ) -> Result, Self::SearchAccessorError> { + let inner_accessor = self + .inner + .insert_search_accessor(provider.inner_provider(), context)?; + Ok(DocumentSearchAccessor::new(inner_accessor)) + } +} + +#[derive(Clone, Copy)] +pub struct DocumentPruneStrategy { + inner: Inner, +} + +impl DocumentPruneStrategy { + pub fn new(inner: Inner) -> Self { + Self { inner } + } +} + +impl PruneStrategy>> + for DocumentPruneStrategy +where + DP: DataProvider, + Inner: PruneStrategy, +{ + type DistanceComputer = Inner::DistanceComputer; + type PruneAccessor<'a> = Inner::PruneAccessor<'a>; + type PruneAccessorError = Inner::PruneAccessorError; + + fn prune_accessor<'a>( + &'a self, + provider: &'a DocumentProvider>, + context: &'a > as DataProvider>::Context, + ) -> Result, Self::PruneAccessorError> { + self.inner + .prune_accessor(provider.inner_provider(), context) + } +} diff --git a/diskann-label-filter/src/encoded_attribute_provider/document_provider.rs b/diskann-label-filter/src/encoded_attribute_provider/document_provider.rs index 6b496271b..1fabf5f54 100644 --- a/diskann-label-filter/src/encoded_attribute_provider/document_provider.rs +++ b/diskann-label-filter/src/encoded_attribute_provider/document_provider.rs @@ -77,7 +77,7 @@ impl<'a, VT, DP, AS> SetElement> for DocumentProvider where DP: DataProvider + Delete + SetElement, AS: AttributeStore + AsyncFriendly, - VT: Sync + Send, + VT: Sync + Send + ?Sized, { type SetError = ANNError; type Guard = >::Guard; diff --git a/diskann-label-filter/src/encoded_attribute_provider/encoded_filter_expr.rs b/diskann-label-filter/src/encoded_attribute_provider/encoded_filter_expr.rs index d56cb13c1..370ef25ae 100644 --- a/diskann-label-filter/src/encoded_attribute_provider/encoded_filter_expr.rs +++ b/diskann-label-filter/src/encoded_attribute_provider/encoded_filter_expr.rs @@ -5,8 +5,6 @@ use std::sync::{Arc, RwLock}; -use diskann::ANNResult; - use crate::{ encoded_attribute_provider::{ ast_id_expr::ASTIdExpr, ast_label_id_mapper::ASTLabelIdMapper, @@ -16,20 +14,21 @@ use crate::{ }; pub(crate) struct EncodedFilterExpr { - ast_id_expr: ASTIdExpr, + ast_id_expr: Option>, } impl EncodedFilterExpr { - pub 
fn new(
-        ast_expr: &ASTExpr,
-        attribute_map: Arc<RwLock<AttributeEncoder>>,
-    ) -> ANNResult<Self> {
+    pub fn new(ast_expr: &ASTExpr, attribute_map: Arc<RwLock<AttributeEncoder>>) -> Self {
         let mut mapper = ASTLabelIdMapper::new(attribute_map);
-        let ast_id_expr = ast_expr.accept(&mut mapper)?;
-        Ok(Self { ast_id_expr })
+        match ast_expr.accept(&mut mapper) {
+            Ok(ast_id_expr) => Self {
+                ast_id_expr: Some(ast_id_expr),
+            },
+            Err(_e) => Self { ast_id_expr: None },
+        }
     }
 
-    pub(crate) fn encoded_filter_expr(&self) -> &ASTIdExpr<u32> {
+    pub(crate) fn encoded_filter_expr(&self) -> &Option<ASTIdExpr<u32>> {
         &self.ast_id_expr
     }
 }
diff --git a/diskann-label-filter/src/encoded_attribute_provider/roaring_attribute_store.rs b/diskann-label-filter/src/encoded_attribute_provider/roaring_attribute_store.rs
index 6b82a68b1..c69589ba0 100644
--- a/diskann-label-filter/src/encoded_attribute_provider/roaring_attribute_store.rs
+++ b/diskann-label-filter/src/encoded_attribute_provider/roaring_attribute_store.rs
@@ -15,7 +15,7 @@ use diskann::{utils::VectorId, ANNError, ANNErrorKind, ANNResult};
 use diskann_utils::future::AsyncFriendly;
 use std::sync::{Arc, RwLock};
 
-pub(crate) struct RoaringAttributeStore<IT>
+pub struct RoaringAttributeStore<IT>
 where
     IT: VectorId + AsyncFriendly,
 {
diff --git a/diskann-label-filter/src/inline_beta_search/encoded_document_accessor.rs b/diskann-label-filter/src/inline_beta_search/encoded_document_accessor.rs
index 962d361d7..1def9a406 100644
--- a/diskann-label-filter/src/inline_beta_search/encoded_document_accessor.rs
+++ b/diskann-label-filter/src/inline_beta_search/encoded_document_accessor.rs
@@ -28,7 +28,7 @@ use crate::{
 
 type AttrAccessor<IA> = EncodedAttributeAccessor<RoaringAttributeStore<<IA as HasId>::Id>>;
 
-pub(crate) struct EncodedDocumentAccessor<IA>
+pub struct EncodedDocumentAccessor<IA>
 where
     IA: HasId,
 {
@@ -136,7 +136,7 @@ where
             Some(set) => Ok(set.into_owned()),
             None => Err(ANNError::message(
                 ANNErrorKind::IndexError,
-                "No labels were found for vector",
+                format!("No labels were found for vector: {:?}", id),
             )),
         }
     })?;
@@ -220,12 +220,20 @@ where
             .inner_accessor
             .build_query_computer(from.query())
             .into_ann_result()?;
-        let id_query = EncodedFilterExpr::new(from.filter_expr(), self.attribute_map.clone())?;
+        let id_query = EncodedFilterExpr::new(from.filter_expr(), self.attribute_map.clone());
+        let is_valid_filter = id_query.encoded_filter_expr().is_some();
+        if !is_valid_filter {
+            tracing::warn!(
+                "Failed to convert {} into an id expr; falling back to an unfiltered search.",
+                from.filter_expr()
+            );
+        }
 
         Ok(InlineBetaComputer::new(
             inner_computer,
             self.beta_value,
             id_query,
+            is_valid_filter,
         ))
     }
 }
diff --git a/diskann-label-filter/src/inline_beta_search/inline_beta_filter.rs b/diskann-label-filter/src/inline_beta_search/inline_beta_filter.rs
index b25b1746f..f03f36c12 100644
--- a/diskann-label-filter/src/inline_beta_search/inline_beta_filter.rs
+++ b/diskann-label-filter/src/inline_beta_search/inline_beta_filter.rs
@@ -28,6 +28,13 @@ pub struct InlineBetaStrategy<Strategy> {
     inner: Strategy,
 }
 
+impl<Strategy> InlineBetaStrategy<Strategy> {
+    /// Create a new InlineBetaStrategy with the given beta value and inner strategy.
+    pub fn new(beta: f32, inner: Strategy) -> Self {
+        Self { beta, inner }
+    }
+}
+
 impl SearchStrategy>, FilteredQuery> for InlineBetaStrategy
@@ -72,6 +79,7 @@ pub struct InlineBetaComputer<Inner> {
     inner_computer: Inner,
     beta_value: f32,
     filter_expr: EncodedFilterExpr,
+    is_valid_filter: bool, // Optimization: avoid evaluating invalid or empty predicates.
 }
 
 impl<Inner> InlineBetaComputer<Inner> {
@@ -79,17 +87,23 @@
     pub(crate) fn new(
         inner_computer: Inner,
         beta_value: f32,
         filter_expr: EncodedFilterExpr,
+        is_valid_filter: bool,
     ) -> Self {
         Self {
             inner_computer,
             beta_value,
             filter_expr,
+            is_valid_filter,
         }
     }
 
     pub(crate) fn filter_expr(&self) -> &EncodedFilterExpr {
         &self.filter_expr
     }
+
+    pub(crate) fn is_valid_filter(&self) -> bool {
+        self.is_valid_filter
+    }
 }
 
 impl PreprocessedDistanceFunction, f32>
@@ -101,22 +115,35 @@ where
     {
         let (vec, attrs) = changing.destructure();
         let sim = self.inner_computer.evaluate_similarity(vec);
         let pred_eval = PredicateEvaluator::new(attrs);
-        match self.filter_expr.encoded_filter_expr().accept(&pred_eval) {
-            Ok(matched) => {
-                if matched {
-                    sim * self.beta_value
-                } else {
-                    sim
+        if self.is_valid_filter {
+            match self
+                .filter_expr
+                .encoded_filter_expr()
+                .as_ref()
+                .unwrap()
+                .accept(&pred_eval)
+            {
+                Ok(matched) => {
+                    if matched {
+                        return sim * self.beta_value;
+                    } else {
+                        return sim;
+                    }
+                }
+                Err(_) => {
+                    // If predicate evaluation fails for any reason, we simply revert
+                    // to an unfiltered search.
+                    tracing::warn!(
+                        "Predicate evaluation failed in InlineBetaComputer::evaluate_similarity()"
+                    );
+                    return sim;
                 }
             }
-            Err(_) => {
-                //TODO: If predicate evaluation fails, we are taking the approach that we will simply
-                //return the score returned by the inner computer, as though no predicate was specified.
-                tracing::warn!(
-                    "Predicate evaluation failed in OnlineBetaComputer::evaluate_similarity()"
-                );
-                sim
-            }
+        } else {
+            // No valid filter was supplied; behave as an unfiltered search and
+            // return the inner computer's score unchanged.
+            sim
         }
     }
 }
@@ -155,8 +182,16 @@ where
         let doc = accessor.get_element(candidate.id).await?;
         let pe = PredicateEvaluator::new(doc.attributes());
 
-        if computer.filter_expr().encoded_filter_expr().accept(&pe)? {
-            filtered_candidates.push(Neighbor::new(candidate.id, candidate.distance));
+        if computer.is_valid_filter() {
+            if computer
+                .filter_expr()
+                .encoded_filter_expr()
+                .as_ref()
+                .unwrap()
+                .accept(&pe)?
+            {
+                filtered_candidates.push(Neighbor::new(candidate.id, candidate.distance));
+            }
+        } else {
+            // No valid filter: keep every candidate so the search degrades to
+            // unfiltered behavior instead of returning no results.
+            filtered_candidates.push(Neighbor::new(candidate.id, candidate.distance));
         }
     }
diff --git a/diskann-label-filter/src/lib.rs b/diskann-label-filter/src/lib.rs
index 106845f98..273475b15 100644
--- a/diskann-label-filter/src/lib.rs
+++ b/diskann-label-filter/src/lib.rs
@@ -40,6 +40,7 @@ pub mod encoded_attribute_provider {
     pub(crate) mod ast_id_expr;
     pub(crate) mod ast_label_id_mapper;
     pub(crate) mod attribute_encoder;
+    pub mod document_insert_strategy;
     pub mod document_provider;
     pub mod encoded_attribute_accessor;
     pub(crate) mod encoded_filter_expr;
diff --git a/diskann-label-filter/src/parser/format.rs b/diskann-label-filter/src/parser/format.rs
index 5e9e3a9c1..c042d8338 100644
--- a/diskann-label-filter/src/parser/format.rs
+++ b/diskann-label-filter/src/parser/format.rs
@@ -15,10 +15,8 @@ pub struct Document {
     /// label in raw json format
     #[serde(flatten)]
     pub label: serde_json::Value,
-
 }
-
 /// Represents a query expression as defined in the RFC.
#[derive(Debug, Clone, Serialize, Deserialize)] pub struct QueryExpression { diff --git a/diskann-providers/src/model/graph/provider/async_/inmem/full_precision.rs b/diskann-providers/src/model/graph/provider/async_/inmem/full_precision.rs index e74419a46..9a48488fe 100644 --- a/diskann-providers/src/model/graph/provider/async_/inmem/full_precision.rs +++ b/diskann-providers/src/model/graph/provider/async_/inmem/full_precision.rs @@ -1,580 +1,638 @@ -/* - * Copyright (c) Microsoft Corporation. - * Licensed under the MIT license. - */ - -use std::{collections::HashMap, fmt::Debug, future::Future}; - -use diskann::{ - ANNError, ANNResult, - graph::{ - SearchOutputBuffer, - glue::{ - self, ExpandBeam, FillSet, FilterStartPoints, InplaceDeleteStrategy, InsertStrategy, - PruneStrategy, SearchExt, SearchStrategy, - }, - }, - neighbor::Neighbor, - provider::{ - Accessor, BuildDistanceComputer, BuildQueryComputer, DefaultContext, DelegateNeighbor, - ExecutionContext, HasId, - }, - utils::{IntoUsize, VectorRepr}, -}; -use diskann_utils::future::AsyncFriendly; -use diskann_vector::{DistanceFunction, distance::Metric}; - -use crate::model::graph::{ - provider::async_::{ - FastMemoryVectorProviderAsync, SimpleNeighborProviderAsync, - common::{ - CreateVectorStore, FullPrecision, Internal, NoDeletes, NoStore, Panics, - PrefetchCacheLineLevel, SetElementHelper, - }, - inmem::DefaultProvider, - postprocess::{AsDeletionCheck, DeletionCheck, RemoveDeletedIdsAndCopy}, - }, - traits::AdHoc, -}; - -/// A type alias for the DefaultProvider with full-precision as the primary vector store. -pub type FullPrecisionProvider = - DefaultProvider, Q, D, Ctx>; - -/// The default full-precision vector store. -pub type FullPrecisionStore = FastMemoryVectorProviderAsync>; - -/// A default full-precision vector store provider. -#[derive(Clone)] -pub struct CreateFullPrecision { - dim: usize, - prefetch_cache_line_level: Option, - _phantom: std::marker::PhantomData, -} - -impl CreateFullPrecision -where - T: VectorRepr, -{ - /// Create a new full-precision vector store provider. - pub fn new(dim: usize, prefetch_cache_line_level: Option) -> Self { - Self { - dim, - prefetch_cache_line_level, - _phantom: std::marker::PhantomData, - } - } -} - -impl CreateVectorStore for CreateFullPrecision -where - T: VectorRepr, -{ - type Target = FullPrecisionStore; - fn create( - self, - max_points: usize, - metric: Metric, - prefetch_lookahead: Option, - ) -> Self::Target { - FullPrecisionStore::new( - max_points, - self.dim, - metric, - self.prefetch_cache_line_level, - prefetch_lookahead, - ) - } -} - -//////////////// -// SetElement // -//////////////// - -impl SetElementHelper for FullPrecisionStore -where - T: VectorRepr, -{ - /// Set the element at the given index. - fn set_element(&self, id: &u32, element: &[T]) -> Result<(), ANNError> { - unsafe { self.set_vector_sync(id.into_usize(), element) } - } -} - -////////////////// -// FullAccessor // -////////////////// - -/// An accessor for retrieving full-precision vectors from the `DefaultProvider`. -/// -/// This type implements the following traits: -/// -/// * [`Accessor`] for the [`DefaultProvider`]. -/// * [`ComputerAccessor`] for comparing full-precision distances. -/// * [`BuildQueryComputer`]. -pub struct FullAccessor<'a, T, Q, D, Ctx> -where - T: VectorRepr, -{ - /// The host provider. - provider: &'a FullPrecisionProvider, - - /// A buffer for resolving iterators given during bulk operations. 
- /// - /// The accessor reuses this allocation to amortize allocation cost over multiple bulk - /// operations. - id_buffer: Vec, -} - -impl GetFullPrecision for FullAccessor<'_, T, Q, D, Ctx> -where - T: VectorRepr, -{ - type Repr = T; - fn as_full_precision(&self) -> &FullPrecisionStore { - &self.provider.base_vectors - } -} - -impl HasId for FullAccessor<'_, T, Q, D, Ctx> -where - T: VectorRepr, -{ - type Id = u32; -} - -impl SearchExt for FullAccessor<'_, T, Q, D, Ctx> -where - T: VectorRepr, - Q: AsyncFriendly, - D: AsyncFriendly, - Ctx: ExecutionContext, -{ - fn starting_points(&self) -> impl Future>> { - std::future::ready(self.provider.starting_points()) - } -} - -impl<'a, T, Q, D, Ctx> FullAccessor<'a, T, Q, D, Ctx> -where - T: VectorRepr, - Q: AsyncFriendly, - D: AsyncFriendly, - Ctx: ExecutionContext, -{ - pub fn new(provider: &'a FullPrecisionProvider) -> Self { - Self { - provider, - id_buffer: Vec::new(), - } - } -} - -impl<'a, T, Q, D, Ctx> DelegateNeighbor<'a> for FullAccessor<'_, T, Q, D, Ctx> -where - T: VectorRepr, - Q: AsyncFriendly, - D: AsyncFriendly, - Ctx: ExecutionContext, -{ - type Delegate = &'a SimpleNeighborProviderAsync; - - fn delegate_neighbor(&'a mut self) -> Self::Delegate { - self.provider.neighbors() - } -} - -impl<'a, T, Q, D, Ctx> Accessor for FullAccessor<'a, T, Q, D, Ctx> -where - T: VectorRepr, - Q: AsyncFriendly, - D: AsyncFriendly, - Ctx: ExecutionContext, -{ - /// The extended element inherets the lifetime of the Accessor. - type Extended = &'a [T]; - - /// This accessor returns raw slices. There *is* a chance of racing when the fast - /// providers are used. We just have to live with it. - /// - /// NOTE: We intentionally don't use `'b` here since our implementation borrows - /// the inner `Opaque` from the underlying provider. - type Element<'b> - = &'a [T] - where - Self: 'b; - - /// `ElementRef` has an arbitrarily short lifetime. - type ElementRef<'b> = &'b [T]; - - /// Choose to panic on an out-of-bounds access rather than propagate an error. - type GetError = Panics; - - /// Return the full-precision vector stored at index `i`. - /// - /// This function always completes synchronously. - #[inline(always)] - fn get_element( - &mut self, - id: Self::Id, - ) -> impl Future, Self::GetError>> + Send { - // SAFETY: We've decided to live with UB (undefined behavior) that can result from - // potentially mixing unsynchronized reads and writes on the underlying memory. - std::future::ready(Ok(unsafe { - self.provider.base_vectors.get_vector_sync(id.into_usize()) - })) - } - - /// Perform a bulk operation. - /// - /// This implementation uses prefetching. - fn on_elements_unordered( - &mut self, - itr: Itr, - mut f: F, - ) -> impl Future> + Send - where - Self: Sync, - Itr: Iterator + Send, - F: Send + for<'b> FnMut(Self::ElementRef<'b>, Self::Id), - { - // Reuse the internal buffer to collect the results and give us random access - // capabilities. - let id_buffer = &mut self.id_buffer; - id_buffer.clear(); - id_buffer.extend(itr); - - let len = id_buffer.len(); - let lookahead = self.provider.base_vectors.prefetch_lookahead(); - - // Prefetch the first few vectors. - for id in id_buffer.iter().take(lookahead) { - self.provider.base_vectors.prefetch_hint(id.into_usize()); - } - - for (i, id) in id_buffer.iter().enumerate() { - // Prefetch `lookahead` iterations ahead as long as it is safe. 
- if lookahead > 0 && i + lookahead < len { - self.provider - .base_vectors - .prefetch_hint(id_buffer[i + lookahead].into_usize()); - } - - // Invoke the passed closure on the full-precision vector. - // - // SAFETY: We're accepting the consequences of potential unsynchronized, - // concurrent mutation. - f( - unsafe { self.provider.base_vectors.get_vector_sync(id.into_usize()) }, - *id, - ) - } - - std::future::ready(Ok(())) - } -} - -impl BuildDistanceComputer for FullAccessor<'_, T, Q, D, Ctx> -where - T: VectorRepr, - Q: AsyncFriendly, - D: AsyncFriendly, - Ctx: ExecutionContext, -{ - type DistanceComputerError = Panics; - type DistanceComputer = T::Distance; - - fn build_distance_computer( - &self, - ) -> Result { - Ok(T::distance( - self.provider.metric, - Some(self.provider.base_vectors.dim()), - )) - } -} - -impl BuildQueryComputer<[T]> for FullAccessor<'_, T, Q, D, Ctx> -where - T: VectorRepr, - Q: AsyncFriendly, - D: AsyncFriendly, - Ctx: ExecutionContext, -{ - type QueryComputerError = Panics; - type QueryComputer = T::QueryDistance; - - fn build_query_computer( - &self, - from: &[T], - ) -> Result { - Ok(T::query_distance(from, self.provider.metric)) - } -} - -impl ExpandBeam<[T]> for FullAccessor<'_, T, Q, D, Ctx> -where - T: VectorRepr, - Q: AsyncFriendly, - D: AsyncFriendly, - Ctx: ExecutionContext, -{ -} - -impl FillSet for FullAccessor<'_, T, Q, D, Ctx> -where - T: VectorRepr, - Q: AsyncFriendly, - D: AsyncFriendly, - Ctx: ExecutionContext, -{ - async fn fill_set( - &mut self, - set: &mut HashMap, - itr: Itr, - ) -> Result<(), Self::GetError> - where - Itr: Iterator + Send + Sync, - { - for i in itr { - set.entry(i).or_insert_with(|| unsafe { - self.provider.base_vectors.get_vector_sync(i.into_usize()) - }); - } - Ok(()) - } -} - -//-------------------// -// In-mem Extensions // -//-------------------// - -impl<'a, T, Q, D, Ctx> AsDeletionCheck for FullAccessor<'a, T, Q, D, Ctx> -where - T: VectorRepr, - Q: AsyncFriendly, - D: AsyncFriendly + DeletionCheck, - Ctx: ExecutionContext, -{ - type Checker = D; - fn as_deletion_check(&self) -> &D { - &self.provider.deleted - } -} - -////////////////// -// Post Process // -////////////////// - -pub trait GetFullPrecision { - type Repr: VectorRepr; - fn as_full_precision(&self) -> &FastMemoryVectorProviderAsync>; -} - -/// A [`SearchPostProcess`]or that: -/// -/// 1. Filters out deleted ids from being returned. -/// 2. Reranks a candidate stream using full-precision distances. -/// 3. Copies back the results to the output buffer. -#[derive(Debug, Default, Clone, Copy)] -pub struct Rerank; - -impl glue::SearchPostProcess for Rerank -where - T: VectorRepr, - A: BuildQueryComputer<[T], Id = u32> + GetFullPrecision + AsDeletionCheck, -{ - type Error = Panics; - - fn post_process( - &self, - accessor: &mut A, - query: &[T], - _computer: &A::QueryComputer, - candidates: I, - output: &mut B, - ) -> impl Future> + Send - where - I: Iterator>, - B: SearchOutputBuffer + ?Sized, - { - let full = accessor.as_full_precision(); - let checker = accessor.as_deletion_check(); - let f = full.distance(); - - // Filter before computing the full precision distances. - let mut reranked: Vec<(u32, f32)> = candidates - .filter_map(|n| { - if checker.deletion_check(n.id) { - None - } else { - Some(( - n.id, - f.evaluate_similarity(query, unsafe { - full.get_vector_sync(n.id.into_usize()) - }), - )) - } - }) - .collect(); - - // Sort the full precision distances. 
- reranked - .sort_unstable_by(|a, b| (a.1).partial_cmp(&b.1).unwrap_or(std::cmp::Ordering::Equal)); - // Store the reranked results. - std::future::ready(Ok(output.extend(reranked))) - } -} - -//////////////// -// Strategies // -//////////////// - -// A layered approach is used for search strategies. The `Internal` version does the heavy -// lifting in terms of establishing accessors and post processing. -// -// However, during post-processing, the `Internal` versions of strategies will not filter -// out the start points. The publicly exposed types *will* filter out the start points. -// -// This layered approach allows algorithms like `InplaceDeleteStrategy` that need to adjust -// the adjacency list for the start point to reuse the `Internal` strategies. - -/// Perform a search entirely in the full-precision space. -/// -/// Starting points are not filtered out of the final results. -impl SearchStrategy, [T]> - for Internal -where - T: VectorRepr, - Q: AsyncFriendly, - D: AsyncFriendly + DeletionCheck, - Ctx: ExecutionContext, -{ - type QueryComputer = T::QueryDistance; - type SearchAccessor<'a> = FullAccessor<'a, T, Q, D, Ctx>; - type SearchAccessorError = Panics; - type PostProcessor = RemoveDeletedIdsAndCopy; - - fn search_accessor<'a>( - &'a self, - provider: &'a FullPrecisionProvider, - _context: &'a Ctx, - ) -> Result, Self::SearchAccessorError> { - Ok(FullAccessor::new(provider)) - } - - fn post_processor(&self) -> Self::PostProcessor { - Default::default() - } -} - -/// Perform a search entirely in the full-precision space. -/// -/// Starting points are not filtered out of the final results. -impl SearchStrategy, [T]> for FullPrecision -where - T: VectorRepr, - Q: AsyncFriendly, - D: AsyncFriendly + DeletionCheck, - Ctx: ExecutionContext, -{ - type QueryComputer = T::QueryDistance; - type SearchAccessor<'a> = FullAccessor<'a, T, Q, D, Ctx>; - type SearchAccessorError = Panics; - type PostProcessor = glue::Pipeline; - - fn search_accessor<'a>( - &'a self, - provider: &'a FullPrecisionProvider, - _context: &'a Ctx, - ) -> Result, Self::SearchAccessorError> { - Ok(FullAccessor::new(provider)) - } - - fn post_processor(&self) -> Self::PostProcessor { - Default::default() - } -} - -// Pruning -impl PruneStrategy> for FullPrecision -where - T: VectorRepr, - Q: AsyncFriendly, - D: AsyncFriendly, - Ctx: ExecutionContext, -{ - type DistanceComputer = T::Distance; - type PruneAccessor<'a> = FullAccessor<'a, T, Q, D, Ctx>; - type PruneAccessorError = diskann::error::Infallible; - - fn prune_accessor<'a>( - &'a self, - provider: &'a FullPrecisionProvider, - _context: &'a Ctx, - ) -> Result, Self::PruneAccessorError> { - Ok(FullAccessor::new(provider)) - } -} - -/// Implementing this trait allows `FullPrecision` to be used for multi-insert. 
-impl<'a, T, Q, D, Ctx> glue::AsElement<&'a [T]> for FullAccessor<'a, T, Q, D, Ctx> -where - T: VectorRepr, - Q: AsyncFriendly, - D: AsyncFriendly, - Ctx: ExecutionContext, -{ - type Error = diskann::error::Infallible; - fn as_element( - &mut self, - vector: &'a [T], - _id: Self::Id, - ) -> impl Future, Self::Error>> + Send { - std::future::ready(Ok(vector)) - } -} - -impl InsertStrategy, [T]> for FullPrecision -where - T: VectorRepr, - Q: AsyncFriendly, - D: AsyncFriendly + DeletionCheck, - Ctx: ExecutionContext, -{ - type PruneStrategy = Self; - fn prune_strategy(&self) -> Self::PruneStrategy { - *self - } -} - -// Inplace Delete // -impl InplaceDeleteStrategy> for FullPrecision -where - T: VectorRepr, - Q: AsyncFriendly, - D: AsyncFriendly + DeletionCheck, - Ctx: ExecutionContext, -{ - type DeleteElementError = Panics; - type DeleteElement<'a> = [T]; - type DeleteElementGuard = Box<[T]>; - type PruneStrategy = Self; - type SearchStrategy = Internal; - fn search_strategy(&self) -> Self::SearchStrategy { - Internal(Self) - } - - fn prune_strategy(&self) -> Self::PruneStrategy { - Self - } - - async fn get_delete_element<'a>( - &'a self, - provider: &'a FullPrecisionProvider, - _context: &'a Ctx, - id: u32, - ) -> Result { - Ok(unsafe { provider.base_vectors.get_vector_sync(id.into_usize()) }.into()) - } -} +/* + * Copyright (c) Microsoft Corporation. + * Licensed under the MIT license. + */ + +use std::{collections::HashMap, fmt::Debug, future::Future}; + +use diskann::{ + ANNError, ANNResult, + graph::{ + SearchOutputBuffer, + glue::{ + self, ExpandBeam, FillSet, FilterStartPoints, InplaceDeleteStrategy, InsertStrategy, + PruneStrategy, SearchExt, SearchStrategy, + }, + }, + neighbor::Neighbor, + provider::{ + Accessor, BuildDistanceComputer, BuildQueryComputer, DefaultContext, DelegateNeighbor, + ExecutionContext, HasId, + }, + utils::{IntoUsize, VectorRepr}, +}; +use diskann_utils::future::AsyncFriendly; +use diskann_vector::{DistanceFunction, distance::Metric}; + +use crate::model::graph::{ + provider::async_::{ + FastMemoryVectorProviderAsync, SimpleNeighborProviderAsync, + common::{ + CreateVectorStore, FullPrecision, Internal, NoDeletes, NoStore, Panics, + PrefetchCacheLineLevel, SetElementHelper, + }, + inmem::DefaultProvider, + postprocess::{AsDeletionCheck, DeletionCheck, RemoveDeletedIdsAndCopy}, + }, + traits::AdHoc, +}; + +/// A type alias for the DefaultProvider with full-precision as the primary vector store. +pub type FullPrecisionProvider = + DefaultProvider, Q, D, Ctx>; + +/// The default full-precision vector store. +pub type FullPrecisionStore = FastMemoryVectorProviderAsync>; + +/// A default full-precision vector store provider. +#[derive(Clone)] +pub struct CreateFullPrecision { + dim: usize, + prefetch_cache_line_level: Option, + _phantom: std::marker::PhantomData, +} + +impl CreateFullPrecision +where + T: VectorRepr, +{ + /// Create a new full-precision vector store provider. 
+impl<T> CreateVectorStore for CreateFullPrecision<T>
+where
+    T: VectorRepr,
+{
+    type Target = FullPrecisionStore<T>;
+    fn create(
+        self,
+        max_points: usize,
+        metric: Metric,
+        prefetch_lookahead: Option<usize>,
+    ) -> Self::Target {
+        FullPrecisionStore::new(
+            max_points,
+            self.dim,
+            metric,
+            self.prefetch_cache_line_level,
+            prefetch_lookahead,
+        )
+    }
+}
+
+////////////////
+// SetElement //
+////////////////
+
+impl<T> SetElementHelper<T> for FullPrecisionStore<T>
+where
+    T: VectorRepr,
+{
+    /// Set the element at the given index.
+    fn set_element(&self, id: &u32, element: &[T]) -> Result<(), ANNError> {
+        unsafe { self.set_vector_sync(id.into_usize(), element) }
+    }
+}
+
+//////////////////
+// FullAccessor //
+//////////////////
+
+/// An accessor for retrieving full-precision vectors from the `DefaultProvider`.
+///
+/// This type implements the following traits:
+///
+/// * [`Accessor`] for the [`DefaultProvider`].
+/// * [`ComputerAccessor`] for comparing full-precision distances.
+/// * [`BuildQueryComputer`].
+pub struct FullAccessor<'a, T, Q, D, Ctx>
+where
+    T: VectorRepr,
+{
+    /// The host provider.
+    provider: &'a FullPrecisionProvider<T, Q, D, Ctx>,
+
+    /// A buffer for resolving iterators given during bulk operations.
+    ///
+    /// The accessor reuses this allocation to amortize allocation cost over multiple bulk
+    /// operations.
+    id_buffer: Vec<u32>,
+}
+
+impl<T, Q, D, Ctx> GetFullPrecision for FullAccessor<'_, T, Q, D, Ctx>
+where
+    T: VectorRepr,
+{
+    type Repr = T;
+    fn as_full_precision(&self) -> &FullPrecisionStore<T> {
+        &self.provider.base_vectors
+    }
+}
+
+impl<T, Q, D, Ctx> HasId for FullAccessor<'_, T, Q, D, Ctx>
+where
+    T: VectorRepr,
+{
+    type Id = u32;
+}
+
+impl<T, Q, D, Ctx> SearchExt for FullAccessor<'_, T, Q, D, Ctx>
+where
+    T: VectorRepr,
+    Q: AsyncFriendly,
+    D: AsyncFriendly,
+    Ctx: ExecutionContext,
+{
+    fn starting_points(&self) -> impl Future<Output = ANNResult<Vec<u32>>> {
+        std::future::ready(self.provider.starting_points())
+    }
+}
+
+impl<'a, T, Q, D, Ctx> FullAccessor<'a, T, Q, D, Ctx>
+where
+    T: VectorRepr,
+    Q: AsyncFriendly,
+    D: AsyncFriendly,
+    Ctx: ExecutionContext,
+{
+    pub fn new(provider: &'a FullPrecisionProvider<T, Q, D, Ctx>) -> Self {
+        Self {
+            provider,
+            id_buffer: Vec::new(),
+        }
+    }
+}
+
+impl<'a, T, Q, D, Ctx> DelegateNeighbor<'a> for FullAccessor<'_, T, Q, D, Ctx>
+where
+    T: VectorRepr,
+    Q: AsyncFriendly,
+    D: AsyncFriendly,
+    Ctx: ExecutionContext,
+{
+    type Delegate = &'a SimpleNeighborProviderAsync;
+
+    fn delegate_neighbor(&'a mut self) -> Self::Delegate {
+        self.provider.neighbors()
+    }
+}
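+
+// A usage sketch for the `Accessor` implementation that follows; the
+// `provider` value and the ids are assumed for illustration:
+//
+//     let mut accessor = FullAccessor::new(&provider);
+//     accessor
+//         .on_elements_unordered([0u32, 7, 42].into_iter(), |vector, id| {
+//             // `vector` is a &[T] borrowed directly from the store.
+//             let _ = (vector.len(), id);
+//         })
+//         .await?;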
+impl<'a, T, Q, D, Ctx> Accessor for FullAccessor<'a, T, Q, D, Ctx>
+where
+    T: VectorRepr,
+    Q: AsyncFriendly,
+    D: AsyncFriendly,
+    Ctx: ExecutionContext,
+{
+    /// The extended element inherits the lifetime of the Accessor.
+    type Extended = &'a [T];
+
+    /// This accessor returns raw slices. There *is* a chance of racing when the fast
+    /// providers are used. We just have to live with it.
+    ///
+    /// NOTE: We intentionally don't use `'b` here since our implementation borrows
+    /// the inner `Opaque` from the underlying provider.
+    type Element<'b>
+        = &'a [T]
+    where
+        Self: 'b;
+
+    /// `ElementRef` has an arbitrarily short lifetime.
+    type ElementRef<'b> = &'b [T];
+
+    /// Choose to panic on an out-of-bounds access rather than propagate an error.
+    type GetError = Panics;
+
+    /// Return the full-precision vector stored at index `id`.
+    ///
+    /// This function always completes synchronously.
+    #[inline(always)]
+    fn get_element(
+        &mut self,
+        id: Self::Id,
+    ) -> impl Future<Output = Result<Self::Element<'_>, Self::GetError>> + Send {
+        // SAFETY: We've decided to live with UB (undefined behavior) that can result from
+        // potentially mixing unsynchronized reads and writes on the underlying memory.
+        std::future::ready(Ok(unsafe {
+            self.provider.base_vectors.get_vector_sync(id.into_usize())
+        }))
+    }
+
+    /// Perform a bulk operation.
+    ///
+    /// This implementation uses prefetching.
+    fn on_elements_unordered<Itr, F>(
+        &mut self,
+        itr: Itr,
+        mut f: F,
+    ) -> impl Future<Output = Result<(), Self::GetError>> + Send
+    where
+        Self: Sync,
+        Itr: Iterator<Item = Self::Id> + Send,
+        F: Send + for<'b> FnMut(Self::ElementRef<'b>, Self::Id),
+    {
+        // Reuse the internal buffer to collect the results and give us random access
+        // capabilities.
+        let id_buffer = &mut self.id_buffer;
+        id_buffer.clear();
+        id_buffer.extend(itr);
+
+        let len = id_buffer.len();
+        let lookahead = self.provider.base_vectors.prefetch_lookahead();
+
+        // Prefetch the first few vectors.
+        for id in id_buffer.iter().take(lookahead) {
+            self.provider.base_vectors.prefetch_hint(id.into_usize());
+        }
+
+        for (i, id) in id_buffer.iter().enumerate() {
+            // Prefetch `lookahead` iterations ahead as long as it is safe.
+            if lookahead > 0 && i + lookahead < len {
+                self.provider
+                    .base_vectors
+                    .prefetch_hint(id_buffer[i + lookahead].into_usize());
+            }
+
+            // Invoke the passed closure on the full-precision vector.
+            //
+            // SAFETY: We're accepting the consequences of potential unsynchronized,
+            // concurrent mutation.
+            f(
+                unsafe { self.provider.base_vectors.get_vector_sync(id.into_usize()) },
+                *id,
+            )
+        }
+
+        std::future::ready(Ok(()))
+    }
+}
+
+impl<T, Q, D, Ctx> BuildDistanceComputer for FullAccessor<'_, T, Q, D, Ctx>
+where
+    T: VectorRepr,
+    Q: AsyncFriendly,
+    D: AsyncFriendly,
+    Ctx: ExecutionContext,
+{
+    type DistanceComputerError = Panics;
+    type DistanceComputer = T::Distance;
+
+    fn build_distance_computer(
+        &self,
+    ) -> Result<Self::DistanceComputer, Self::DistanceComputerError> {
+        Ok(T::distance(
+            self.provider.metric,
+            Some(self.provider.base_vectors.dim()),
+        ))
+    }
+}
+
+impl<T, Q, D, Ctx> BuildQueryComputer<[T]> for FullAccessor<'_, T, Q, D, Ctx>
+where
+    T: VectorRepr,
+    Q: AsyncFriendly,
+    D: AsyncFriendly,
+    Ctx: ExecutionContext,
+{
+    type QueryComputerError = Panics;
+    type QueryComputer = T::QueryDistance;
+
+    fn build_query_computer(
+        &self,
+        from: &[T],
+    ) -> Result<Self::QueryComputer, Self::QueryComputerError> {
+        Ok(T::query_distance(from, self.provider.metric))
+    }
+}
+
+impl<T, Q, D, Ctx> ExpandBeam<[T]> for FullAccessor<'_, T, Q, D, Ctx>
+where
+    T: VectorRepr,
+    Q: AsyncFriendly,
+    D: AsyncFriendly,
+    Ctx: ExecutionContext,
+{
+}
+
+/// Support for Vec<T> queries that delegates to the [T] impl via deref.
+/// This allows InlineBetaStrategy to use Vec<T> queries with FullAccessor.
+impl<T, Q, D, Ctx> BuildQueryComputer<Vec<T>> for FullAccessor<'_, T, Q, D, Ctx>
+where
+    T: VectorRepr,
+    Q: AsyncFriendly,
+    D: AsyncFriendly,
+    Ctx: ExecutionContext,
+{
+    type QueryComputerError = Panics;
+    type QueryComputer = T::QueryDistance;
+
+    fn build_query_computer(
+        &self,
+        from: &Vec<T>,
+    ) -> Result<Self::QueryComputer, Self::QueryComputerError> {
+        // Delegate to [T] impl via deref
+        Ok(T::query_distance(from.as_slice(), self.provider.metric))
+    }
+}
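+
+// Sketch of the two query entry points above: a borrowed slice and an owned
+// vector build the same computer, so callers may pass whichever they hold
+// (`accessor` and `query: Vec<T>` are assumed values):
+//
+//     let by_slice = accessor.build_query_computer(query.as_slice())?;
+//     let by_vec = accessor.build_query_computer(&query)?;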
+/// Support for Vec<T> queries that delegates to the [T] impl.
+impl<T, Q, D, Ctx> ExpandBeam<Vec<T>> for FullAccessor<'_, T, Q, D, Ctx>
+where
+    T: VectorRepr + Clone,
+    Q: AsyncFriendly,
+    D: AsyncFriendly,
+    Ctx: ExecutionContext,
+{
+}
+
+impl<T, Q, D, Ctx> FillSet for FullAccessor<'_, T, Q, D, Ctx>
+where
+    T: VectorRepr,
+    Q: AsyncFriendly,
+    D: AsyncFriendly,
+    Ctx: ExecutionContext,
+{
+    async fn fill_set<Itr>(
+        &mut self,
+        set: &mut HashMap<Self::Id, Self::Extended>,
+        itr: Itr,
+    ) -> Result<(), Self::GetError>
+    where
+        Itr: Iterator<Item = Self::Id> + Send + Sync,
+    {
+        for i in itr {
+            set.entry(i).or_insert_with(|| unsafe {
+                self.provider.base_vectors.get_vector_sync(i.into_usize())
+            });
+        }
+        Ok(())
+    }
+}
+
+//-------------------//
+// In-mem Extensions //
+//-------------------//
+
+impl<'a, T, Q, D, Ctx> AsDeletionCheck for FullAccessor<'a, T, Q, D, Ctx>
+where
+    T: VectorRepr,
+    Q: AsyncFriendly,
+    D: AsyncFriendly + DeletionCheck,
+    Ctx: ExecutionContext,
+{
+    type Checker = D;
+    fn as_deletion_check(&self) -> &D {
+        &self.provider.deleted
+    }
+}
+
+//////////////////
+// Post Process //
+//////////////////
+
+pub trait GetFullPrecision {
+    type Repr: VectorRepr;
+    fn as_full_precision(&self) -> &FastMemoryVectorProviderAsync<Vec<Self::Repr>>;
+}
+
+/// A [`SearchPostProcess`]or that:
+///
+/// 1. Filters out deleted ids from being returned.
+/// 2. Reranks a candidate stream using full-precision distances.
+/// 3. Copies back the results to the output buffer.
+#[derive(Debug, Default, Clone, Copy)]
+pub struct Rerank;
+
+impl<T, A> glue::SearchPostProcess<A, [T]> for Rerank
+where
+    T: VectorRepr,
+    A: BuildQueryComputer<[T], Id = u32> + GetFullPrecision<Repr = T> + AsDeletionCheck,
+{
+    type Error = Panics;
+
+    fn post_process<I, B>(
+        &self,
+        accessor: &mut A,
+        query: &[T],
+        _computer: &A::QueryComputer,
+        candidates: I,
+        output: &mut B,
+    ) -> impl Future<Output = Result<usize, Self::Error>> + Send
+    where
+        I: Iterator<Item = Neighbor<u32>>,
+        B: SearchOutputBuffer + ?Sized,
+    {
+        let full = accessor.as_full_precision();
+        let checker = accessor.as_deletion_check();
+        let f = full.distance();
+
+        // Filter before computing the full precision distances.
+        let mut reranked: Vec<(u32, f32)> = candidates
+            .filter_map(|n| {
+                if checker.deletion_check(n.id) {
+                    None
+                } else {
+                    Some((
+                        n.id,
+                        f.evaluate_similarity(query, unsafe {
+                            full.get_vector_sync(n.id.into_usize())
+                        }),
+                    ))
+                }
+            })
+            .collect();
+
+        // Sort the full precision distances.
+        reranked
+            .sort_unstable_by(|a, b| (a.1).partial_cmp(&b.1).unwrap_or(std::cmp::Ordering::Equal));
+        // Store the reranked results.
+        std::future::ready(Ok(output.extend(reranked)))
+    }
+}
+
+////////////////
+// Strategies //
+////////////////
+
+// A layered approach is used for search strategies. The `Internal` version does the heavy
+// lifting in terms of establishing accessors and post processing.
+//
+// However, during post-processing, the `Internal` versions of strategies will not filter
+// out the start points. The publicly exposed types *will* filter out the start points.
+//
+// This layered approach allows algorithms like `InplaceDeleteStrategy` that need to adjust
+// the adjacency list for the start point to reuse the `Internal` strategies.
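+
+// Schematically, the two layers differ only in their post-processing (as
+// wired below):
+//
+//     Internal(FullPrecision) -> RemoveDeletedIdsAndCopy
+//     FullPrecision           -> Pipeline<FilterStartPoints, RemoveDeletedIdsAndCopy>
+//
+// which is why `InplaceDeleteStrategy` searches through
+// `Internal(FullPrecision)`: it must still see the start point whose
+// adjacency list it is adjusting.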
+/// Perform a search entirely in the full-precision space.
+///
+/// Starting points are not filtered out of the final results.
+impl<T, Q, D, Ctx> SearchStrategy<FullPrecisionProvider<T, Q, D, Ctx>, [T]>
+    for Internal<FullPrecision>
+where
+    T: VectorRepr,
+    Q: AsyncFriendly,
+    D: AsyncFriendly + DeletionCheck,
+    Ctx: ExecutionContext,
+{
+    type QueryComputer = T::QueryDistance;
+    type SearchAccessor<'a> = FullAccessor<'a, T, Q, D, Ctx>;
+    type SearchAccessorError = Panics;
+    type PostProcessor = RemoveDeletedIdsAndCopy;
+
+    fn search_accessor<'a>(
+        &'a self,
+        provider: &'a FullPrecisionProvider<T, Q, D, Ctx>,
+        _context: &'a Ctx,
+    ) -> Result<Self::SearchAccessor<'a>, Self::SearchAccessorError> {
+        Ok(FullAccessor::new(provider))
+    }
+
+    fn post_processor(&self) -> Self::PostProcessor {
+        Default::default()
+    }
+}
+
+/// Perform a search entirely in the full-precision space.
+///
+/// Starting points are filtered out of the final results.
+impl<T, Q, D, Ctx> SearchStrategy<FullPrecisionProvider<T, Q, D, Ctx>, [T]> for FullPrecision
+where
+    T: VectorRepr,
+    Q: AsyncFriendly,
+    D: AsyncFriendly + DeletionCheck,
+    Ctx: ExecutionContext,
+{
+    type QueryComputer = T::QueryDistance;
+    type SearchAccessor<'a> = FullAccessor<'a, T, Q, D, Ctx>;
+    type SearchAccessorError = Panics;
+    type PostProcessor = glue::Pipeline<FilterStartPoints, RemoveDeletedIdsAndCopy>;
+
+    fn search_accessor<'a>(
+        &'a self,
+        provider: &'a FullPrecisionProvider<T, Q, D, Ctx>,
+        _context: &'a Ctx,
+    ) -> Result<Self::SearchAccessor<'a>, Self::SearchAccessorError> {
+        Ok(FullAccessor::new(provider))
+    }
+
+    fn post_processor(&self) -> Self::PostProcessor {
+        Default::default()
+    }
+}
+
+/// Support for Vec<T> queries that delegates to the [T] impl.
+/// This allows InlineBetaStrategy to use Vec<T> queries with FullPrecision.
+impl<T, Q, D, Ctx> SearchStrategy<FullPrecisionProvider<T, Q, D, Ctx>, Vec<T>> for FullPrecision
+where
+    T: VectorRepr + Clone,
+    Q: AsyncFriendly,
+    D: AsyncFriendly + DeletionCheck,
+    Ctx: ExecutionContext,
+{
+    type QueryComputer = T::QueryDistance;
+    type SearchAccessor<'a> = FullAccessor<'a, T, Q, D, Ctx>;
+    type SearchAccessorError = Panics;
+    type PostProcessor = glue::Pipeline<FilterStartPoints, RemoveDeletedIdsAndCopy>;
+
+    fn search_accessor<'a>(
+        &'a self,
+        provider: &'a FullPrecisionProvider<T, Q, D, Ctx>,
+        _context: &'a Ctx,
+    ) -> Result<Self::SearchAccessor<'a>, Self::SearchAccessorError> {
+        Ok(FullAccessor::new(provider))
+    }
+
+    fn post_processor(&self) -> Self::PostProcessor {
+        Default::default()
+    }
+}
+
+// Pruning
+impl<T, Q, D, Ctx> PruneStrategy<FullPrecisionProvider<T, Q, D, Ctx>> for FullPrecision
+where
+    T: VectorRepr,
+    Q: AsyncFriendly,
+    D: AsyncFriendly,
+    Ctx: ExecutionContext,
+{
+    type DistanceComputer = T::Distance;
+    type PruneAccessor<'a> = FullAccessor<'a, T, Q, D, Ctx>;
+    type PruneAccessorError = diskann::error::Infallible;
+
+    fn prune_accessor<'a>(
+        &'a self,
+        provider: &'a FullPrecisionProvider<T, Q, D, Ctx>,
+        _context: &'a Ctx,
+    ) -> Result<Self::PruneAccessor<'a>, Self::PruneAccessorError> {
+        Ok(FullAccessor::new(provider))
+    }
+}
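+
+// Rough insert-time flow (the `index` handle and `insert` call are
+// hypothetical; real call sites live in the graph layer, not in this file):
+//
+//     let strategy = FullPrecision;
+//     index.insert(&strategy, new_id, vector.as_slice()).await?;
+//
+// Because the `InsertStrategy` impl below returns `*self` as its prune
+// strategy, the freshly inserted node's neighborhood is pruned with the same
+// full-precision accessor and distance computer.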
+/// Implementing this trait allows `FullPrecision` to be used for multi-insert.
+impl<'a, T, Q, D, Ctx> glue::AsElement<&'a [T]> for FullAccessor<'a, T, Q, D, Ctx>
+where
+    T: VectorRepr,
+    Q: AsyncFriendly,
+    D: AsyncFriendly,
+    Ctx: ExecutionContext,
+{
+    type Error = diskann::error::Infallible;
+    fn as_element(
+        &mut self,
+        vector: &'a [T],
+        _id: Self::Id,
+    ) -> impl Future<Output = Result<Self::Element<'_>, Self::Error>> + Send {
+        std::future::ready(Ok(vector))
+    }
+}
+
+impl<T, Q, D, Ctx> InsertStrategy<FullPrecisionProvider<T, Q, D, Ctx>, [T]> for FullPrecision
+where
+    T: VectorRepr,
+    Q: AsyncFriendly,
+    D: AsyncFriendly + DeletionCheck,
+    Ctx: ExecutionContext,
+{
+    type PruneStrategy = Self;
+    fn prune_strategy(&self) -> Self::PruneStrategy {
+        *self
+    }
+}
+
+// Inplace Delete //
+impl<T, Q, D, Ctx> InplaceDeleteStrategy<FullPrecisionProvider<T, Q, D, Ctx>> for FullPrecision
+where
+    T: VectorRepr,
+    Q: AsyncFriendly,
+    D: AsyncFriendly + DeletionCheck,
+    Ctx: ExecutionContext,
+{
+    type DeleteElementError = Panics;
+    type DeleteElement<'a> = [T];
+    type DeleteElementGuard = Box<[T]>;
+    type PruneStrategy = Self;
+    type SearchStrategy = Internal<FullPrecision>;
+    fn search_strategy(&self) -> Self::SearchStrategy {
+        Internal(Self)
+    }
+
+    fn prune_strategy(&self) -> Self::PruneStrategy {
+        Self
+    }
+
+    async fn get_delete_element<'a>(
+        &'a self,
+        provider: &'a FullPrecisionProvider<T, Q, D, Ctx>,
+        _context: &'a Ctx,
+        id: u32,
+    ) -> Result<Self::DeleteElementGuard, Self::DeleteElementError> {
+        Ok(unsafe { provider.base_vectors.get_vector_sync(id.into_usize()) }.into())
+    }
+}
diff --git a/diskann-tools/src/utils/ground_truth.rs b/diskann-tools/src/utils/ground_truth.rs
index 31e69b2b2..8c2fa29f6 100644
--- a/diskann-tools/src/utils/ground_truth.rs
+++ b/diskann-tools/src/utils/ground_truth.rs
@@ -32,14 +32,14 @@ use crate::utils::{search_index_utils, CMDResult, CMDToolError};
 /// Expands a JSON object with array-valued fields into multiple objects with scalar values.
 /// For example: {"country": ["AU", "NZ"], "year": 2007}
 /// becomes: [{"country": "AU", "year": 2007}, {"country": "NZ", "year": 2007}]
-/// 
+///
 /// If multiple fields have arrays, all combinations are generated.
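+/// For instance, {"country": ["AU", "NZ"], "year": [2007, 2008]} expands to
+/// the four objects {"country": "AU", "year": 2007}, {"country": "AU", "year": 2008},
+/// {"country": "NZ", "year": 2007}, and {"country": "NZ", "year": 2008}.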
 fn expand_array_fields(value: &Value) -> Vec<Value> {
     match value {
         Value::Object(map) => {
             // Start with a single empty object
             let mut results: Vec<Map<String, Value>> = vec![Map::new()];
-            
+
             for (key, val) in map.iter() {
                 if let Value::Array(arr) = val {
                     // Expand: for each existing result, create copies for each array element
@@ -62,7 +62,7 @@ fn expand_array_fields(value: &Value) -> Vec<Value> {
                     }
                 }
             }
-            
+
             results.into_iter().map(Value::Object).collect()
         }
         // If not an object, return as-is
@@ -74,7 +74,9 @@ fn expand_array_fields(value: &Value) -> Vec<Value> {
 /// Returns true if any expanded variant matches the query.
 fn eval_query_with_array_expansion(query_expr: &ASTExpr, label: &Value) -> bool {
     let expanded = expand_array_fields(label);
-    expanded.iter().any(|item| eval_query_expr(query_expr, item))
+    expanded
+        .iter()
+        .any(|item| eval_query_expr(query_expr, item))
 }
 
 pub fn read_labels_and_compute_bitmap(
@@ -127,11 +129,13 @@ pub fn read_labels_and_compute_bitmap(
         // Handle case where base_label.label is an array - check if any element matches
         // Also expand array-valued fields within objects (e.g., {"country": ["AU", "NZ"]})
         let matches = if let Some(array) = base_label.label.as_array() {
-            array.iter().any(|item| eval_query_with_array_expansion(query_expr, item))
+            array
+                .iter()
+                .any(|item| eval_query_with_array_expansion(query_expr, item))
         } else {
             eval_query_with_array_expansion(query_expr, &base_label.label)
         };
-        
+
         if matches {
             bitmap.insert(base_label.doc_id);
         }
@@ -164,11 +168,17 @@ pub fn read_labels_and_compute_bitmap(
     // If no matches, print more diagnostic info
     if total_matches == 0 {
         tracing::warn!("WARNING: No base vectors matched any query filters!");
-        tracing::warn!("This could indicate a format mismatch between base labels and query filters.");
-
+        tracing::warn!(
+            "This could indicate a format mismatch between base labels and query filters."
+        );
+
         // Try to identify what keys exist in base labels vs queries
         if let Some(first_label) = base_labels.first() {
-            tracing::warn!("First base label (full): doc_id={}, label={}", first_label.doc_id, first_label.label);
+            tracing::warn!(
+                "First base label (full): doc_id={}, label={}",
+                first_label.doc_id,
+                first_label.label
+            );
         }
     }
 
@@ -323,7 +333,7 @@ pub fn compute_ground_truth_from_datafiles<
     for (query_idx, npq) in ground_truth.iter().enumerate() {
         let neighbors: Vec<_> = npq.iter().collect();
        let neighbor_count = neighbors.len();
-        
+
         if query_idx < 10 {
             // Print top K IDs and distances for first 10 queries
             let top_ids: Vec<u32> = neighbors.iter().take(10).map(|n| n.id).collect();
@@ -336,7 +346,7 @@ pub fn compute_ground_truth_from_datafiles<
                 top_dists
             );
         }
-        
+
         if neighbor_count == 0 {
             tracing::warn!("Query {} has 0 neighbors in ground truth!", query_idx);
         }
@@ -344,7 +354,10 @@ pub fn compute_ground_truth_from_datafiles<
 
     // Summary stats
     let total_neighbors: usize = ground_truth.iter().map(|npq| npq.iter().count()).sum();
-    let queries_with_neighbors = ground_truth.iter().filter(|npq| npq.iter().count() > 0).count();
+    let queries_with_neighbors = ground_truth
+        .iter()
+        .filter(|npq| npq.iter().count() > 0)
+        .count();
     tracing::info!(
         "Ground truth summary: {} total neighbors, {} queries have neighbors, {} queries have 0 neighbors",
         total_neighbors,
diff --git a/test_data/disk_index_search/data.256.label.jsonl b/test_data/disk_index_search/data.256.label.jsonl
index 83254af7b..a99cde8e2 100644
--- a/test_data/disk_index_search/data.256.label.jsonl
+++ b/test_data/disk_index_search/data.256.label.jsonl
@@ -1,3 +1,3 @@
 version https://git-lfs.github.com/spec/v1
-oid sha256:7f8b6b99ca32173557689712d3fb5da30c5e4111130fd2accbccf32f5ce3e47e
-size 17702
+oid sha256:92576896b10780a2cd80a16030f8384610498b76453f57fadeacb854379e0acf
+size 17701