diff --git a/diskann-benchmark-core/src/search/graph/knn.rs b/diskann-benchmark-core/src/search/graph/knn.rs index ad07f2c82..427169044 100644 --- a/diskann-benchmark-core/src/search/graph/knn.rs +++ b/diskann-benchmark-core/src/search/graph/knn.rs @@ -5,7 +5,7 @@ //! A built-in helper for benchmarking K-nearest neighbors. -use std::{num::NonZeroUsize, sync::Arc}; +use std::sync::Arc; use diskann::{ ANNResult, @@ -29,7 +29,7 @@ use crate::{ /// the latter. Result aggregation for [`search::search_all`] is provided /// by the [`Aggregator`] type. /// -/// The provided implementation of [`Search`] accepts [`graph::SearchParams`] +/// The provided implementation of [`Search`] accepts [`graph::KnnSearch`] /// and returns [`Metrics`] as additional output. #[derive(Debug)] pub struct KNN @@ -92,7 +92,7 @@ where T: AsyncFriendly + Clone, { type Id = DP::ExternalId; - type Parameters = graph::SearchParams; + type Parameters = graph::KnnSearch; type Output = Metrics; fn num_queries(&self) -> usize { @@ -100,7 +100,7 @@ where } fn id_count(&self, parameters: &Self::Parameters) -> search::IdCount { - search::IdCount::Fixed(NonZeroUsize::new(parameters.k_value).unwrap_or(diskann::utils::ONE)) + search::IdCount::Fixed(parameters.k_value()) } async fn search( @@ -113,13 +113,14 @@ where O: graph::SearchOutputBuffer + Send, { let context = DP::Context::default(); + let mut knn_search = *parameters; let stats = self .index .search( + &mut knn_search, self.strategy.get(index)?, &context, self.queries.row(index), - parameters, buffer, ) .await?; @@ -142,7 +143,7 @@ pub struct Summary { pub setup: search::Setup, /// The [`Search::Parameters`] used for the batch of runs. - pub parameters: graph::SearchParams, + pub parameters: graph::KnnSearch, /// The end-to-end latency for each repetition in the batch. pub end_to_end_latencies: Vec, @@ -207,7 +208,7 @@ impl<'a, I> Aggregator<'a, I> { } } -impl search::Aggregate for Aggregator<'_, I> +impl search::Aggregate for Aggregator<'_, I> where I: crate::recall::RecallCompatible, { @@ -215,7 +216,7 @@ where fn aggregate( &mut self, - run: search::Run, + run: search::Run, mut results: Vec>, ) -> anyhow::Result { // Compute the recall using just the first result. @@ -280,13 +281,15 @@ where #[cfg(test)] mod tests { + use std::num::NonZeroUsize; + use super::*; use diskann::graph::test::provider; #[test] fn test_knn() { - let nearest_neighbors = 5; + let nearest_neighbors = NonZeroUsize::new(5).unwrap(); let index = search::graph::test_grid_provider(); @@ -310,7 +313,7 @@ mod tests { let rt = crate::tokio::runtime(2).unwrap(); let results = search::search( knn.clone(), - graph::SearchParams::new(nearest_neighbors, 10, None).unwrap(), + graph::KnnSearch::new(nearest_neighbors.get(), 10, None).unwrap(), NonZeroUsize::new(2).unwrap(), &rt, ) @@ -321,7 +324,7 @@ mod tests { assert_eq!(*rows.row(0).first().unwrap(), 0); for r in 0..rows.nrows() { - assert_eq!(rows.row(r).len(), nearest_neighbors); + assert_eq!(rows.row(r).len(), nearest_neighbors.get()); } const TWO: NonZeroUsize = NonZeroUsize::new(2).unwrap(); @@ -334,17 +337,17 @@ mod tests { // Try the aggregated strategy. let parameters = [ search::Run::new( - graph::SearchParams::new(nearest_neighbors, 10, None).unwrap(), + graph::KnnSearch::new(nearest_neighbors.get(), 10, None).unwrap(), setup.clone(), ), search::Run::new( - graph::SearchParams::new(nearest_neighbors, 15, None).unwrap(), + graph::KnnSearch::new(nearest_neighbors.get(), 15, None).unwrap(), setup.clone(), ), ]; - let recall_k = nearest_neighbors; - let recall_n = nearest_neighbors; + let recall_k = nearest_neighbors.get(); + let recall_n = nearest_neighbors.get(); let all = search::search_all(knn, parameters, Aggregator::new(rows, recall_k, recall_n)).unwrap(); diff --git a/diskann-benchmark-core/src/search/graph/multihop.rs b/diskann-benchmark-core/src/search/graph/multihop.rs index 6dfb646bb..245f1717b 100644 --- a/diskann-benchmark-core/src/search/graph/multihop.rs +++ b/diskann-benchmark-core/src/search/graph/multihop.rs @@ -3,7 +3,7 @@ * Licensed under the MIT license. */ -use std::{num::NonZeroUsize, sync::Arc}; +use std::sync::Arc; use diskann::{ ANNResult, @@ -22,7 +22,7 @@ use crate::search::{self, Search, graph::Strategy}; /// [`search::search_all`] is provided by the [`search::graph::knn::Aggregator`] type (same /// aggregator as [`search::graph::KNN`]). /// -/// The provided implementation of [`Search`] accepts [`graph::SearchParams`] +/// The provided implementation of [`Search`] accepts [`graph::KnnSearch`] /// and returns [`search::graph::knn::Metrics`] as additional output. #[derive(Debug)] pub struct MultiHop @@ -90,7 +90,7 @@ where T: AsyncFriendly + Clone, { type Id = DP::ExternalId; - type Parameters = graph::SearchParams; + type Parameters = graph::KnnSearch; type Output = super::knn::Metrics; fn num_queries(&self) -> usize { @@ -98,7 +98,7 @@ where } fn id_count(&self, parameters: &Self::Parameters) -> search::IdCount { - search::IdCount::Fixed(NonZeroUsize::new(parameters.k_value).unwrap_or(diskann::utils::ONE)) + search::IdCount::Fixed(parameters.k_value()) } async fn search( @@ -111,15 +111,15 @@ where O: graph::SearchOutputBuffer + Send, { let context = DP::Context::default(); + let mut multihop_search = graph::MultihopSearch::new(*parameters, &*self.labels[index]); let stats = self .index - .multihop_search( + .search( + &mut multihop_search, self.strategy.get(index)?, &context, self.queries.row(index), - parameters, buffer, - &*self.labels[index], ) .await?; @@ -136,6 +136,8 @@ where #[cfg(test)] mod tests { + use std::num::NonZeroUsize; + use super::*; use diskann::graph::{index::QueryLabelProvider, test::provider}; @@ -152,7 +154,7 @@ mod tests { #[test] fn test_multihop() { - let nearest_neighbors = 5; + let nearest_neighbors = NonZeroUsize::new(5).unwrap(); let index = search::graph::test_grid_provider(); @@ -179,7 +181,7 @@ mod tests { let rt = crate::tokio::runtime(2).unwrap(); let results = search::search( multihop.clone(), - graph::SearchParams::new(nearest_neighbors, 10, None).unwrap(), + graph::KnnSearch::new(nearest_neighbors.get(), 10, None).unwrap(), NonZeroUsize::new(2).unwrap(), &rt, ) @@ -191,7 +193,7 @@ mod tests { // Check that only even IDs are returned. for r in 0..rows.nrows() { - assert_eq!(rows.row(r).len(), nearest_neighbors); + assert_eq!(rows.row(r).len(), nearest_neighbors.get()); for &id in rows.row(r) { assert_eq!(id % 2, 0, "Found odd ID {} in row {}", id, r); } @@ -207,17 +209,17 @@ mod tests { // Try the aggregated strategy. let parameters = [ search::Run::new( - graph::SearchParams::new(nearest_neighbors, 10, None).unwrap(), + graph::KnnSearch::new(nearest_neighbors.get(), 10, None).unwrap(), setup.clone(), ), search::Run::new( - graph::SearchParams::new(nearest_neighbors, 15, None).unwrap(), + graph::KnnSearch::new(nearest_neighbors.get(), 15, None).unwrap(), setup.clone(), ), ]; - let recall_k = nearest_neighbors; - let recall_n = nearest_neighbors; + let recall_k = nearest_neighbors.get(); + let recall_n = nearest_neighbors.get(); let all = search::search_all( multihop, diff --git a/diskann-benchmark-core/src/search/graph/range.rs b/diskann-benchmark-core/src/search/graph/range.rs index a5669ae25..e66fef5ec 100644 --- a/diskann-benchmark-core/src/search/graph/range.rs +++ b/diskann-benchmark-core/src/search/graph/range.rs @@ -27,7 +27,7 @@ use crate::{ /// by the [`Aggregator`] type. /// /// The provided implementation of [`Search`] accepts -/// [`graph::RangeSearchParams`] and returns [`Metrics`] as additional output. +/// [`graph::RangeSearch`] and returns [`Metrics`] as additional output. #[derive(Debug)] pub struct Range where @@ -83,7 +83,7 @@ where T: AsyncFriendly + Clone, { type Id = DP::ExternalId; - type Parameters = graph::RangeSearchParams; + type Parameters = graph::RangeSearch; type Output = Metrics; fn num_queries(&self) -> usize { @@ -91,7 +91,7 @@ where } fn id_count(&self, parameters: &Self::Parameters) -> search::IdCount { - search::IdCount::Dynamic(NonZeroUsize::new(parameters.starting_l_value)) + search::IdCount::Dynamic(NonZeroUsize::new(parameters.starting_l())) } async fn search( @@ -104,16 +104,21 @@ where O: graph::SearchOutputBuffer + Send, { let context = DP::Context::default(); - let (_, ids, distances) = self + let mut range_search = *parameters; + let result = self .index - .range_search( + .search( + &mut range_search, self.strategy.get(index)?, &context, self.queries.row(index), - parameters, + &mut (), ) .await?; - buffer.extend(std::iter::zip(ids.into_iter(), distances.into_iter())); + buffer.extend(std::iter::zip( + result.ids.into_iter(), + result.distances.into_iter(), + )); Ok(Metrics {}) } @@ -129,8 +134,8 @@ pub struct Summary { /// The [`search::Setup`] used for the batch of runs. pub setup: search::Setup, - /// The [`graph::RangeSearchParams`] used for the batch of runs. - pub parameters: graph::RangeSearchParams, + /// The [`graph::RangeSearch`] used for the batch of runs. + pub parameters: graph::RangeSearch, /// The end-to-end latency for each repetition in the batch. pub end_to_end_latencies: Vec, @@ -174,7 +179,7 @@ impl<'a, I> Aggregator<'a, I> { } } -impl search::Aggregate for Aggregator<'_, I> +impl search::Aggregate for Aggregator<'_, I> where I: crate::recall::RecallCompatible, { @@ -183,7 +188,7 @@ where #[inline(never)] fn aggregate( &mut self, - run: search::Run, + run: search::Run, mut results: Vec>, ) -> anyhow::Result { // Compute the recall using just the first result. @@ -261,7 +266,7 @@ mod tests { let rt = crate::tokio::runtime(2).unwrap(); let results = search::search( range.clone(), - graph::RangeSearchParams::new(None, 10, None, 2.0, None, 0.8, 1.2).unwrap(), + graph::RangeSearch::with_options(None, 10, None, 2.0, None, 0.8, 1.2).unwrap(), NonZeroUsize::new(2).unwrap(), &rt, ) @@ -280,11 +285,11 @@ mod tests { // Try the aggregated strategy. let parameters = [ search::Run::new( - graph::RangeSearchParams::new(None, 10, None, 2.0, None, 0.8, 1.2).unwrap(), + graph::RangeSearch::with_options(None, 10, None, 2.0, None, 0.8, 1.2).unwrap(), setup.clone(), ), search::Run::new( - graph::RangeSearchParams::new(None, 15, None, 2.0, None, 0.8, 1.2).unwrap(), + graph::RangeSearch::with_options(None, 15, None, 2.0, None, 0.8, 1.2).unwrap(), setup.clone(), ), ]; diff --git a/diskann-benchmark/src/backend/index/result.rs b/diskann-benchmark/src/backend/index/result.rs index c7e2ab75c..1d6102f9b 100644 --- a/diskann-benchmark/src/backend/index/result.rs +++ b/diskann-benchmark/src/backend/index/result.rs @@ -143,8 +143,8 @@ impl SearchResults { Self { num_tasks: setup.tasks.into(), - search_n: parameters.k_value, - search_l: parameters.l_value, + search_n: parameters.k_value().get(), + search_l: parameters.l_value().get(), qps, search_latencies: end_to_end_latencies, mean_latencies, @@ -284,7 +284,7 @@ impl RangeSearchResults { Self { num_tasks: setup.tasks.into(), - initial_l: parameters.starting_l_value, + initial_l: parameters.starting_l(), qps, search_latencies: end_to_end_latencies, mean_latencies, diff --git a/diskann-benchmark/src/backend/index/search/knn.rs b/diskann-benchmark/src/backend/index/search/knn.rs index 723d32155..2ad3a8dd5 100644 --- a/diskann-benchmark/src/backend/index/search/knn.rs +++ b/diskann-benchmark/src/backend/index/search/knn.rs @@ -49,8 +49,9 @@ pub(crate) fn run( .search_l .iter() .map(|search_l| { - let search_params = - diskann::graph::SearchParams::new(run.search_n, *search_l, None).unwrap(); + let k = run.search_n; + let l = *search_l; + let search_params = diskann::graph::KnnSearch::new(k, l, None).unwrap(); core_search::Run::new(search_params, setup.clone()) }) @@ -63,7 +64,7 @@ pub(crate) fn run( Ok(all) } -type Run = core_search::Run; +type Run = core_search::Run; pub(crate) trait Knn { fn search_all( &self, @@ -83,13 +84,13 @@ where DP: diskann::provider::DataProvider, core_search::graph::KNN: core_search::Search< Id = DP::InternalId, - Parameters = diskann::graph::SearchParams, + Parameters = diskann::graph::KnnSearch, Output = core_search::graph::knn::Metrics, >, { fn search_all( &self, - parameters: Vec>, + parameters: Vec>, groundtruth: &dyn benchmark_core::recall::Rows, recall_k: usize, recall_n: usize, @@ -109,13 +110,13 @@ where DP: diskann::provider::DataProvider, core_search::graph::MultiHop: core_search::Search< Id = DP::InternalId, - Parameters = diskann::graph::SearchParams, + Parameters = diskann::graph::KnnSearch, Output = core_search::graph::knn::Metrics, >, { fn search_all( &self, - parameters: Vec>, + parameters: Vec>, groundtruth: &dyn benchmark_core::recall::Rows, recall_k: usize, recall_n: usize, diff --git a/diskann-benchmark/src/backend/index/search/range.rs b/diskann-benchmark/src/backend/index/search/range.rs index 6ed6dc25f..d78a66bc8 100644 --- a/diskann-benchmark/src/backend/index/search/range.rs +++ b/diskann-benchmark/src/backend/index/search/range.rs @@ -30,7 +30,7 @@ impl<'a> RangeSearchSteps<'a> { } } -type Run = core_search::Run; +type Run = core_search::Run; pub(crate) trait Range { fn search_all( @@ -79,13 +79,13 @@ where DP: diskann::provider::DataProvider, core_search::graph::Range: core_search::Search< Id = DP::InternalId, - Parameters = diskann::graph::RangeSearchParams, + Parameters = diskann::graph::RangeSearch, Output = core_search::graph::range::Metrics, >, { fn search_all( &self, - parameters: Vec>, + parameters: Vec>, groundtruth: &dyn benchmark_core::recall::Rows, ) -> anyhow::Result> { let results = core_search::search_all( diff --git a/diskann-benchmark/src/inputs/async_.rs b/diskann-benchmark/src/inputs/async_.rs index e12c26419..d6fe09c45 100644 --- a/diskann-benchmark/src/inputs/async_.rs +++ b/diskann-benchmark/src/inputs/async_.rs @@ -8,7 +8,7 @@ use std::num::{NonZeroU32, NonZeroUsize}; use anyhow::{anyhow, Context}; use diskann::{ - graph::{self, config, RangeSearchParams, RangeSearchParamsError, StartPointStrategy}, + graph::{self, config, RangeSearch, RangeSearchError, StartPointStrategy}, utils::IntoUsize, }; use diskann_benchmark_core::streaming::executors::bigann; @@ -90,13 +90,11 @@ pub(crate) struct GraphRangeSearch { } impl GraphRangeSearch { - pub(crate) fn construct_params( - &self, - ) -> Result, RangeSearchParamsError> { + pub(crate) fn construct_params(&self) -> Result, RangeSearchError> { self.initial_search_l .iter() .map(|&l| { - RangeSearchParams::new( + RangeSearch::with_options( self.max_returned, l, self.beam_width, @@ -111,7 +109,7 @@ impl GraphRangeSearch { } impl CheckDeserialization for GraphRangeSearch { - // all necessary checks are carried out when RangeSearchParams is initialized + // all necessary checks are carried out when RangeSearch is initialized fn check_deserialization(&mut self, _checker: &mut Checker) -> Result<(), anyhow::Error> { self.construct_params() .context("invalid range search params")?; diff --git a/diskann-disk/src/search/provider/disk_provider.rs b/diskann-disk/src/search/provider/disk_provider.rs index ab0a4f4e7..8f32bf3a9 100644 --- a/diskann-disk/src/search/provider/disk_provider.rs +++ b/diskann-disk/src/search/provider/disk_provider.rs @@ -19,7 +19,7 @@ use diskann::{ graph::{ self, glue::{self, ExpandBeam, IdIterator, SearchExt, SearchPostProcess, SearchStrategy}, - search_output_buffer, AdjacencyList, DiskANNIndex, SearchOutputBuffer, SearchParams, + search_output_buffer, AdjacencyList, DiskANNIndex, KnnSearch, SearchOutputBuffer, }, neighbor::Neighbor, provider::{ @@ -983,21 +983,24 @@ where let strategy = self.search_strategy(query, vector_filter); let timer = Instant::now(); + let k = k_value; + let l = search_list_size as usize; let stats = if is_flat_search { self.runtime.block_on(self.index.flat_search( &strategy, &DefaultContext, strategy.query, vector_filter, - &SearchParams::new(k_value, search_list_size as usize, beam_width)?, + &KnnSearch::new(k, l, beam_width)?, &mut result_output_buffer, ))? } else { + let mut knn_search = KnnSearch::new(k, l, beam_width)?; self.runtime.block_on(self.index.search( + &mut knn_search, &strategy, &DefaultContext, strategy.query, - &SearchParams::new(k_value, search_list_size as usize, beam_width)?, &mut result_output_buffer, ))? }; @@ -1040,7 +1043,7 @@ fn ensure_vertex_loaded>( #[cfg(test)] mod disk_provider_tests { use diskann::{ - graph::{search::record::VisitedSearchRecord, SearchParamsError}, + graph::{search::record::VisitedSearchRecord, KnnSearch, KnnSearchError}, utils::IntoUsize, ANNErrorKind, }; @@ -1529,17 +1532,15 @@ mod disk_provider_tests { "index_path is not correct" ); - let res = SearchParams::new_default(0, 10); + // Test error case: l < k + let res = KnnSearch::new_default(20, 10); assert!(res.is_err()); assert_eq!( - >::into(res.unwrap_err()).kind(), + >::into(res.unwrap_err()).kind(), ANNErrorKind::IndexError ); - let res = SearchParams::new_default(20, 10); - assert!(res.is_err()); - let res = SearchParams::new_default(10, 0); - assert!(res.is_err()); - let res = SearchParams::new(10, 10, Some(0)); + // Test error case: beam_width = 0 + let res = KnnSearch::new(10, 10, Some(0)); assert!(res.is_err()); let search_engine = @@ -1624,15 +1625,17 @@ mod disk_provider_tests { ); let strategy = search_engine.search_strategy(&query_vector, &|_| true); let mut search_record = VisitedSearchRecord::new(0); + let search_params = KnnSearch::new(10, 10, Some(4)).unwrap(); + let mut recorded_search = + diskann::graph::search::RecordedKnnSearch::new(search_params, &mut search_record); search_engine .runtime - .block_on(search_engine.index.search_recorded( + .block_on(search_engine.index.search( + &mut recorded_search, &strategy, &DefaultContext, - &query_vector, - &SearchParams::new(10, 10, Some(4)).unwrap(), + query_vector.as_slice(), &mut result_output_buffer, - &mut search_record, )) .unwrap(); @@ -1743,7 +1746,6 @@ mod disk_provider_tests { &mut associated_data, ); let strategy = search_engine.search_strategy(&query_vector, &|_| true); - let mut search_record = VisitedSearchRecord::new(0); // Create diverse search parameters with attribute provider let diverse_params = DiverseSearchParams::new( @@ -1752,31 +1754,24 @@ mod disk_provider_tests { attribute_provider.clone(), ); - let search_params = SearchParams::new(10, 20, None).unwrap(); + let search_params = KnnSearch::new(10, 20, None).unwrap(); - search_engine + let mut diverse_search = diskann::graph::DiverseSearch::new(search_params, diverse_params); + let stats = search_engine .runtime - .block_on(search_engine.index.diverse_search_experimental( + .block_on(search_engine.index.search( + &mut diverse_search, &strategy, &DefaultContext, - &query_vector, - &search_params, - &diverse_params, + query_vector.as_slice(), &mut result_output_buffer, - &mut search_record, )) .unwrap(); - let ids = search_record - .visited - .iter() - .map(|n| n.id) - .collect::>(); - - // Verify that search was performed and visited some nodes + // Verify that search was performed and returned some results assert!( - !ids.is_empty(), - "Expected to visit some nodes during diversity search" + stats.result_count > 0, + "Expected to get some results during diversity search" ); let return_list_size = 10; @@ -1788,7 +1783,7 @@ mod disk_provider_tests { attribute_provider.clone(), ); - // Test diverse search using the experimental API + // Test diverse search using the search API let mut indices2 = vec![0u32; return_list_size as usize]; let mut distances2 = vec![0f32; return_list_size as usize]; let mut associated_data2 = vec![(); return_list_size as usize]; @@ -1798,20 +1793,19 @@ mod disk_provider_tests { &mut associated_data2, ); let strategy2 = search_engine.search_strategy(&query_vector, &|_| true); - let mut search_record2 = VisitedSearchRecord::new(0); let search_params2 = - SearchParams::new(return_list_size as usize, search_list_size as usize, None).unwrap(); + KnnSearch::new(return_list_size as usize, search_list_size as usize, None).unwrap(); + let mut diverse_search2 = + diskann::graph::DiverseSearch::new(search_params2, diverse_params); let stats = search_engine .runtime - .block_on(search_engine.index.diverse_search_experimental( + .block_on(search_engine.index.search( + &mut diverse_search2, &strategy2, &DefaultContext, - &query_vector, - &search_params2, - &diverse_params, + query_vector.as_slice(), &mut result_output_buffer2, - &mut search_record2, )) .unwrap(); @@ -2086,15 +2080,17 @@ mod disk_provider_tests { let strategy = search_engine.search_strategy(&query_vector, &|_| true); let mut search_record = VisitedSearchRecord::new(0); + let search_params = KnnSearch::new(10, 10, Some(4)).unwrap(); + let mut recorded_search = + diskann::graph::search::RecordedKnnSearch::new(search_params, &mut search_record); search_engine .runtime - .block_on(search_engine.index.search_recorded( + .block_on(search_engine.index.search( + &mut recorded_search, &strategy, &DefaultContext, - &query_vector, - &SearchParams::new(10, 10, Some(4)).unwrap(), + query_vector.as_slice(), &mut result_output_buffer, - &mut search_record, )) .unwrap(); let visited_ids = search_record diff --git a/diskann-providers/src/index/diskann_async.rs b/diskann-providers/src/index/diskann_async.rs index 5451af3ad..e47371ee1 100644 --- a/diskann-providers/src/index/diskann_async.rs +++ b/diskann-providers/src/index/diskann_async.rs @@ -173,8 +173,8 @@ pub(crate) mod tests { use crate::storage::VirtualStorageProvider; use diskann::{ graph::{ - self, AdjacencyList, ConsolidateKind, InplaceDeleteMethod, RangeSearchParams, - SearchParams, StartPointStrategy, + self, AdjacencyList, ConsolidateKind, InplaceDeleteMethod, KnnSearch, RangeSearch, + StartPointStrategy, config::IntraBatchCandidates, glue::{AsElement, InplaceDeleteStrategy, InsertStrategy, SearchStrategy, aliases}, index::{PartitionedNeighbors, QueryLabelProvider, QueryVisitDecision}, @@ -215,6 +215,8 @@ pub(crate) mod tests { // Callbacks for use with `simplified_builder`. fn no_modify(_: &mut diskann::graph::config::Builder) {} + ////////////////////////// + // Test helper functions // ///////////////////////////////////////// // Tests from the original async index // ///////////////////////////////////////// @@ -354,12 +356,14 @@ pub(crate) mod tests { let mut distances = vec![0.0; parameters.search_k]; let mut result_output_buffer = search_output_buffer::IdDistance::new(&mut ids, &mut distances); + let mut graph_search = + graph::KnnSearch::new_default(parameters.search_k, parameters.search_l).unwrap(); index .search( + &mut graph_search, &strategy, ¶meters.context, query, - &SearchParams::new_default(parameters.search_k, parameters.search_l).unwrap(), &mut result_output_buffer, ) .await @@ -400,14 +404,16 @@ pub(crate) mod tests { let mut distances = vec![0.0; parameters.search_k]; let mut result_output_buffer = search_output_buffer::IdDistance::new(&mut ids, &mut distances); + let search_params = + KnnSearch::new_default(parameters.search_k, parameters.search_l).unwrap(); + let mut multihop = graph::MultihopSearch::new(search_params, filter); index - .multihop_search( + .search( + &mut multihop, strategy, ¶meters.context, query, - &SearchParams::new_default(parameters.search_k, parameters.search_l).unwrap(), &mut result_output_buffer, - filter, ) .await .unwrap(); @@ -1443,14 +1449,16 @@ pub(crate) mod tests { let filter = CallbackFilter::new(blocked, adjusted, 0.5); + let search_params = + KnnSearch::new_default(parameters.search_k, parameters.search_l).unwrap(); + let mut multihop = graph::MultihopSearch::new(search_params, &filter); let stats = index - .multihop_search( + .search( + &mut multihop, &FullPrecision, ¶meters.context, query.as_slice(), - &SearchParams::new_default(parameters.search_k, parameters.search_l).unwrap(), &mut result_output_buffer, - &filter, ) .await .unwrap(); @@ -2190,13 +2198,14 @@ pub(crate) mod tests { { let mut result_output_buffer = search_output_buffer::IdDistance::new(&mut ids, &mut distances); + let mut graph_search = graph::KnnSearch::new_default(top_k, search_l).unwrap(); // Full Precision Search. index .search( + &mut graph_search, &FullPrecision, ctx, query, - &SearchParams::new_default(top_k, search_l).unwrap(), &mut result_output_buffer, ) .await @@ -2207,13 +2216,14 @@ pub(crate) mod tests { { let mut result_output_buffer = search_output_buffer::IdDistance::new(&mut ids, &mut distances); + let mut graph_search = graph::KnnSearch::new_default(top_k, search_l).unwrap(); // Quantized Search index .search( + &mut graph_search, &Hybrid::new(None), ctx, query, - &SearchParams::new_default(top_k, search_l).unwrap(), &mut result_output_buffer, ) .await @@ -2272,76 +2282,60 @@ pub(crate) mod tests { let gt = groundtruth(data.as_view(), query, |a, b| SquaredL2::evaluate(a, b)); { // Full Precision Search. - let (_, ids, _) = index - .range_search( - &FullPrecision, - ctx, - query, - &RangeSearchParams::new_default(starting_l_value, radius).unwrap(), - ) + let mut range_search = RangeSearch::new(starting_l_value, radius).unwrap(); + let result = index + .search(&mut range_search, &FullPrecision, ctx, query, &mut ()) .await .unwrap(); - assert_range_results_exactly_match(q, >, &ids, radius, None); + assert_range_results_exactly_match(q, >, &result.ids, radius, None); } { // Quantized Search - let (_, ids, _) = index - .range_search( - &Hybrid::new(None), - ctx, - query, - &RangeSearchParams::new_default(starting_l_value, radius).unwrap(), - ) + let mut range_search = RangeSearch::new(starting_l_value, radius).unwrap(); + let result = index + .search(&mut range_search, &Hybrid::new(None), ctx, query, &mut ()) .await .unwrap(); - assert_range_results_exactly_match(q, >, &ids, radius, None); + assert_range_results_exactly_match(q, >, &result.ids, radius, None); } { // Test with an inner radius assert!(inner_radius <= radius); - let (_, ids, _) = index - .range_search( - &FullPrecision, - ctx, - query, - &RangeSearchParams::new( - None, - starting_l_value, - None, - radius, - Some(inner_radius), - 1.0, - 1.0, - ) - .unwrap(), - ) + let mut range_search = RangeSearch::with_options( + None, + starting_l_value, + None, + radius, + Some(inner_radius), + 1.0, + 1.0, + ) + .unwrap(); + let result = index + .search(&mut range_search, &FullPrecision, ctx, query, &mut ()) .await .unwrap(); - assert_range_results_exactly_match(q, >, &ids, radius, Some(inner_radius)); + assert_range_results_exactly_match(q, >, &result.ids, radius, Some(inner_radius)); } { // Test with a lower initial beam to trigger more two-round searches // We don't expect results to exactly match here - let (_, ids, _) = index - .range_search( - &FullPrecision, - ctx, - query, - &RangeSearchParams::new_default(lower_l_value, radius).unwrap(), - ) + let mut range_search = RangeSearch::new(lower_l_value, radius).unwrap(); + let result = index + .search(&mut range_search, &FullPrecision, ctx, query, &mut ()) .await .unwrap(); // check that ids don't have duplicates let mut ids_set = std::collections::HashSet::new(); - for id in &ids { + for id in &result.ids { assert!(ids_set.insert(*id)); } } @@ -2456,13 +2450,15 @@ pub(crate) mod tests { { let mut result_output_buffer = search_output_buffer::IdDistance::new(&mut ids, &mut distances); + let mut graph_search = + graph::KnnSearch::new_default(top_k, search_l).unwrap(); // Full Precision Search. index .search( + &mut graph_search, &FullPrecision, ctx, query, - &SearchParams::new_default(top_k, search_l).unwrap(), &mut result_output_buffer, ) .await @@ -2473,13 +2469,15 @@ pub(crate) mod tests { { let mut result_output_buffer = search_output_buffer::IdDistance::new(&mut ids, &mut distances); + let mut graph_search = + graph::KnnSearch::new_default(top_k, search_l).unwrap(); // Quantized Search index .search( + &mut graph_search, &Quantized, ctx, query, - &SearchParams::new_default(top_k, search_l).unwrap(), &mut result_output_buffer, ) .await @@ -2559,13 +2557,14 @@ pub(crate) mod tests { { let mut result_output_buffer = search_output_buffer::IdDistance::new(&mut ids, &mut distances); + let mut graph_search = graph::KnnSearch::new_default(top_k, top_k).unwrap(); // Quantized Search index .search( + &mut graph_search, &Quantized, ctx, query, - &SearchParams::new_default(top_k, top_k).unwrap(), &mut result_output_buffer, ) .await @@ -2672,14 +2671,9 @@ pub(crate) mod tests { // Full Precision Search. let mut output = search_output_buffer::IdDistance::new(&mut ids, &mut distances); + let mut graph_search = graph::KnnSearch::new_default(top_k, search_l).unwrap(); index - .search( - &FullPrecision, - ctx, - query, - &SearchParams::new_default(top_k, search_l).unwrap(), - &mut output, - ) + .search(&mut graph_search, &FullPrecision, ctx, query, &mut output) .await .unwrap(); assert_top_k_exactly_match(q, >, &ids, &distances, top_k); @@ -2689,15 +2683,10 @@ pub(crate) mod tests { let strategy = inmem::spherical::Quantized::search( diskann_quantization::spherical::iface::QueryLayout::FourBitTransposed, ); + let mut graph_search = graph::KnnSearch::new_default(top_k, search_l).unwrap(); index - .search( - &strategy, - ctx, - query, - &SearchParams::new_default(top_k, search_l).unwrap(), - &mut output, - ) + .search(&mut graph_search, &strategy, ctx, query, &mut output) .await .unwrap(); assert_top_k_exactly_match(q, >, &ids, &distances, top_k); @@ -2797,15 +2786,10 @@ pub(crate) mod tests { let strategy = inmem::spherical::Quantized::search( diskann_quantization::spherical::iface::QueryLayout::FourBitTransposed, ); + let mut graph_search = graph::KnnSearch::new_default(top_k, search_l).unwrap(); index - .search( - &strategy, - ctx, - query, - &SearchParams::new_default(top_k, search_l).unwrap(), - &mut output, - ) + .search(&mut graph_search, &strategy, ctx, query, &mut output) .await .unwrap(); @@ -2889,13 +2873,14 @@ pub(crate) mod tests { let mut result_output_buffer = search_output_buffer::IdDistance::new(&mut ids, &mut distances); + let mut graph_search = graph::KnnSearch::new_default(top_k, search_l).unwrap(); // Full Precision Search. index .search( + &mut graph_search, &Quantized, ctx, query, - &SearchParams::new_default(top_k, search_l).unwrap(), &mut result_output_buffer, ) .await @@ -3469,13 +3454,14 @@ pub(crate) mod tests { let gt = groundtruth(queries.as_view(), query, |a, b| SquaredL2::evaluate(a, b)); let mut result_output_buffer = search_output_buffer::IdDistance::new(&mut ids, &mut distances); + let mut graph_search = graph::KnnSearch::new_default(top_k, search_l).unwrap(); // Full Precision Search. index .search( + &mut graph_search, &Hybrid::new(max_fp_vecs_per_prune), ctx, query, - &SearchParams::new_default(top_k, search_l).unwrap(), &mut result_output_buffer, ) .await @@ -3615,13 +3601,14 @@ pub(crate) mod tests { let gt = groundtruth(data.as_view(), query, |a, b| SquaredL2::evaluate(a, b)); let mut result_output_buffer = search_output_buffer::IdDistance::new(&mut ids, &mut distances); + let mut graph_search = graph::KnnSearch::new_default(top_k, search_l).unwrap(); // Full Precision Search. index .search( + &mut graph_search, &FullPrecision, ctx, query, - &SearchParams::new_default(top_k, search_l).unwrap(), &mut result_output_buffer, ) .await @@ -3883,25 +3870,22 @@ pub(crate) mod tests { attribute_provider.clone(), ); - let search_params = diskann::graph::SearchParams::new( + let search_params = diskann::graph::KnnSearch::new( return_list_size, search_list_size, None, // beam_width ) .unwrap(); - use diskann::graph::search::record::NoopSearchRecord; - let mut search_record = NoopSearchRecord::new(); + let mut diverse_search = diskann::graph::DiverseSearch::new(search_params, diverse_params); let result = index - .diverse_search_experimental( + .search( + &mut diverse_search, &FullPrecision, &DefaultContext, - &query, - &search_params, - &diverse_params, + query.as_slice(), &mut result_output_buffer, - &mut search_record, ) .await; @@ -4104,14 +4088,15 @@ pub(crate) mod tests { // but reject everything via on_visit let filter = RejectAllFilter::only([0_u32]); + let search_params = KnnSearch::new_default(10, 20).unwrap(); + let mut multihop = graph::MultihopSearch::new(search_params, &filter); let stats = index - .multihop_search( + .search( + &mut multihop, &FullPrecision, &DefaultContext, query.as_slice(), - &SearchParams::new_default(10, 20).unwrap(), &mut result_output_buffer, - &filter, ) .await .unwrap(); @@ -4166,14 +4151,15 @@ pub(crate) mod tests { let target = (num_points / 2) as u32; let filter = TerminatingFilter::new(target); + let search_params = KnnSearch::new_default(10, 40).unwrap(); + let mut multihop = graph::MultihopSearch::new(search_params, &filter); let stats = index - .multihop_search( + .search( + &mut multihop, &FullPrecision, &DefaultContext, query.as_slice(), - &SearchParams::new_default(10, 40).unwrap(), &mut result_output_buffer, - &filter, ) .await .unwrap(); @@ -4230,14 +4216,15 @@ pub(crate) mod tests { let mut baseline_buffer = search_output_buffer::IdDistance::new(&mut baseline_ids, &mut baseline_distances); + let search_params = KnnSearch::new_default(10, 20).unwrap(); + let mut multihop = graph::MultihopSearch::new(search_params, &EvenFilter); let baseline_stats = index - .multihop_search( + .search( + &mut multihop, &FullPrecision, &DefaultContext, query.as_slice(), - &SearchParams::new_default(10, 20).unwrap(), &mut baseline_buffer, - &EvenFilter, // Just filter to even IDs ) .await .unwrap(); @@ -4251,14 +4238,15 @@ pub(crate) mod tests { let mut adjusted_buffer = search_output_buffer::IdDistance::new(&mut adjusted_ids, &mut adjusted_distances); + let search_params = KnnSearch::new_default(10, 20).unwrap(); + let mut multihop = graph::MultihopSearch::new(search_params, &filter); let adjusted_stats = index - .multihop_search( + .search( + &mut multihop, &FullPrecision, &DefaultContext, query.as_slice(), - &SearchParams::new_default(10, 20).unwrap(), &mut adjusted_buffer, - &filter, ) .await .unwrap(); @@ -4377,14 +4365,15 @@ pub(crate) mod tests { let max_visits = 5; let filter = TerminateAfterN::new(max_visits); + let search_params = KnnSearch::new_default(10, 100).unwrap(); // Large L to ensure we'd visit more without termination + let mut multihop = graph::MultihopSearch::new(search_params, &filter); let _stats = index - .multihop_search( + .search( + &mut multihop, &FullPrecision, &DefaultContext, query.as_slice(), - &SearchParams::new_default(10, 100).unwrap(), // Large L to ensure we'd visit more without termination &mut result_output_buffer, - &filter, ) .await .unwrap(); diff --git a/diskann-providers/src/index/wrapped_async.rs b/diskann-providers/src/index/wrapped_async.rs index d3d37416d..0b4787213 100644 --- a/diskann-providers/src/index/wrapped_async.rs +++ b/diskann-providers/src/index/wrapped_async.rs @@ -8,7 +8,7 @@ use std::{num::NonZeroUsize, sync::Arc}; use diskann::{ ANNResult, graph::{ - self, ConsolidateKind, InplaceDeleteMethod, SearchParams, + self, ConsolidateKind, InplaceDeleteMethod, KnnSearch, glue::{ self, AsElement, InplaceDeleteStrategy, InsertStrategy, PruneStrategy, SearchStrategy, }, @@ -226,7 +226,7 @@ where strategy: &S, context: &DP::Context, query: &T, - search_params: &SearchParams, + search_params: &KnnSearch, output: &mut OB, ) -> ANNResult where @@ -235,9 +235,10 @@ where O: Send, OB: search_output_buffer::SearchOutputBuffer + Send, { + let mut knn_search = *search_params; self.handle.block_on( self.inner - .search(strategy, context, query, search_params, output), + .search(&mut knn_search, strategy, context, query, output), ) } diff --git a/diskann-utils/Cargo.toml b/diskann-utils/Cargo.toml index e2dc5ee12..06c161978 100644 --- a/diskann-utils/Cargo.toml +++ b/diskann-utils/Cargo.toml @@ -38,4 +38,4 @@ default = ["rayon"] # Enable Rayon-based Parallelism for tagged kernels. rayon = ["dep:rayon"] # Enable testing utilities like test_data_root() -testing = [] \ No newline at end of file +testing = [] diff --git a/diskann/src/error/ann_error.rs b/diskann/src/error/ann_error.rs index 6a87a8f6d..040e06037 100644 --- a/diskann/src/error/ann_error.rs +++ b/diskann/src/error/ann_error.rs @@ -718,17 +718,6 @@ where } } -pub(crate) fn ensure_positive(value: T, error: E) -> Result -where - T: PartialOrd + Default + Debug, -{ - if value > T::default() { - Ok(value) - } else { - Err(error) - } -} - // /// An internal macro for creating opaque, adhoc errors to help when debugging. // macro_rules! ann_error { // ($($arg:tt)+) => {{ diff --git a/diskann/src/error/mod.rs b/diskann/src/error/mod.rs index 7c052c4d1..3a9f9ab50 100644 --- a/diskann/src/error/mod.rs +++ b/diskann/src/error/mod.rs @@ -4,7 +4,6 @@ */ pub(crate) mod ann_error; -pub(crate) use ann_error::ensure_positive; pub use ann_error::{ANNError, ANNErrorKind, ANNResult, DiskANNError, ErrorContext, IntoANNResult}; pub(crate) mod ranked; diff --git a/diskann/src/graph/index.rs b/diskann/src/graph/index.rs index ea48adc0b..b7521c912 100644 --- a/diskann/src/graph/index.rs +++ b/diskann/src/graph/index.rs @@ -24,11 +24,10 @@ use thiserror::Error; use tokio::task::JoinSet; use super::{ - AdjacencyList, Config, ConsolidateKind, InplaceDeleteMethod, RangeSearchParams, SearchParams, + AdjacencyList, Config, ConsolidateKind, InplaceDeleteMethod, KnnSearch, glue::{ - self, AsElement, ExpandBeam, FillSet, HybridPredicate, IdIterator, InplaceDeleteStrategy, - InsertStrategy, Predicate, PredicateMut, PruneStrategy, SearchExt, SearchPostProcess, - SearchStrategy, aliases, + self, AsElement, ExpandBeam, FillSet, IdIterator, InplaceDeleteStrategy, InsertStrategy, + PruneStrategy, SearchExt, SearchPostProcess, SearchStrategy, aliases, }, internal::{BackedgeBuffer, SortedNeighbors, prune}, search::{ @@ -38,9 +37,6 @@ use super::{ search_output_buffer, }; -#[cfg(feature = "experimental_diversity_search")] -use super::DiverseSearchParams; - use crate::{ ANNError, ANNErrorKind, ANNResult, error::{ErrorExt, IntoANNResult}, @@ -110,6 +106,7 @@ pub struct DegreeStats { /// This struct provides detailed metrics about the search process, including /// the number of nodes visited, the number of distance computations performed, /// the number of hops taken during the search, and the total number of results returned. +#[derive(Debug, Clone, Copy)] pub struct SearchStats { /// The total number of distance computations performed during the search. pub cmps: u32, @@ -220,53 +217,6 @@ struct SetBatchElements { batch: Arc<[VectorIdBoxSlice]>, } -pub struct NotInMutWithLabelCheck<'a, K> -where - K: VectorId, -{ - visited_set: &'a mut hashbrown::HashSet, - query_label_evaluator: &'a dyn QueryLabelProvider, -} - -impl<'a, K> NotInMutWithLabelCheck<'a, K> -where - K: VectorId, -{ - /// Construct a new `NotInMutWithLabelCheck` around `visited_set`. - pub fn new( - visited_set: &'a mut hashbrown::HashSet, - query_label_evaluator: &'a dyn QueryLabelProvider, - ) -> Self { - Self { - visited_set, - query_label_evaluator, - } - } -} - -impl Predicate for NotInMutWithLabelCheck<'_, K> -where - K: VectorId, -{ - fn eval(&self, item: &K) -> bool { - !self.visited_set.contains(item) && self.query_label_evaluator.is_match(*item) - } -} - -impl PredicateMut for NotInMutWithLabelCheck<'_, K> -where - K: VectorId, -{ - fn eval_mut(&mut self, item: &K) -> bool { - if self.query_label_evaluator.is_match(*item) { - return self.visited_set.insert(*item); - } - false - } -} - -impl HybridPredicate for NotInMutWithLabelCheck<'_, K> where K: VectorId {} - impl DiskANNIndex where DP: DataProvider, @@ -296,7 +246,7 @@ where /// * `l`: The default window size to use. /// * `additional`: Extra capacity, usually to allow start points to be filtered from /// the result. - fn search_scratch( + pub(crate) fn search_scratch( &self, l: usize, additional: usize, @@ -2061,7 +2011,7 @@ where } // A is the accessor type, T is the query type used for BuildQueryComputer - fn search_internal( + pub(crate) fn search_internal( &self, beam_width: Option, start_ids: &[DP::InternalId], @@ -2136,204 +2086,6 @@ where } } - // A is the accessor type, T is the query type used for BuildQueryComputer - // scratch.in_range is guaranteed to include the starting points - fn range_search_internal( - &self, - search_params: &RangeSearchParams, - accessor: &mut A, - computer: &A::QueryComputer, - scratch: &mut SearchScratch, - ) -> impl SendFuture> - where - A: ExpandBeam + SearchExt, - T: ?Sized, - { - async move { - let beam_width = search_params.beam_width.unwrap_or(1); - - for neighbor in &scratch.in_range { - scratch.range_frontier.push_back(neighbor.id); - } - - let mut neighbors = Vec::with_capacity(self.max_degree_with_slack()); - - let max_returned = search_params.max_returned.unwrap_or(usize::MAX); - - while !scratch.range_frontier.is_empty() { - scratch.beam_nodes.clear(); - - // In this loop we are going to find the beam_width number of remaining nodes within the radius - // Each of these nodes will be a frontier node. - while !scratch.range_frontier.is_empty() && scratch.beam_nodes.len() < beam_width { - let next = scratch.range_frontier.pop_front(); - if let Some(next_node) = next { - scratch.beam_nodes.push(next_node); - } - } - - neighbors.clear(); - accessor - .expand_beam( - scratch.beam_nodes.iter().copied(), - computer, - glue::NotInMut::new(&mut scratch.visited), - |distance, id| neighbors.push(Neighbor::new(id, distance)), - ) - .await?; - - // The predicate ensure that the contents of `neighbors` are unique. - for neighbor in neighbors.iter() { - if neighbor.distance <= search_params.radius * search_params.range_search_slack - && scratch.in_range.len() < max_returned - { - scratch.in_range.push(*neighbor); - scratch.range_frontier.push_back(neighbor.id); - } - } - scratch.cmps += neighbors.len() as u32; - scratch.hops += scratch.beam_nodes.len() as u32; - } - - Ok(InternalSearchStats { - cmps: scratch.cmps, - hops: scratch.hops, - range_search_second_round: true, - }) - } - } - - // A is the accessor type, T is the query type used for BuildQueryComputer - fn multihop_search_internal( - &self, - search_params: &SearchParams, - accessor: &mut A, - computer: &A::QueryComputer, - scratch: &mut SearchScratch, - search_record: &mut SR, - query_label_evaluator: &dyn QueryLabelProvider, - ) -> impl SendFuture> - where - A: ExpandBeam + SearchExt, - T: ?Sized, - SR: SearchRecord + ?Sized, - { - async move { - let beam_width = search_params.beam_width.unwrap_or(1); - - // Helper to build the final stats from scratch state. - let make_stats = |scratch: &SearchScratch| InternalSearchStats { - cmps: scratch.cmps, - hops: scratch.hops, - range_search_second_round: false, - }; - - // Initialize search state if not already initialized. - // This allows paged search to call multihop_search_internal multiple times - if scratch.visited.is_empty() { - let start_ids = accessor.starting_points().await?; - - for id in start_ids { - scratch.visited.insert(id); - let element = accessor - .get_element(id) - .await - .escalate("start point retrieval must succeed")?; - let dist = computer.evaluate_similarity(element.reborrow()); - scratch.best.insert(Neighbor::new(id, dist)); - } - } - - // Pre-allocate with good capacity to avoid repeated allocations - let mut one_hop_neighbors = Vec::with_capacity(self.max_degree_with_slack()); - let mut two_hop_neighbors = Vec::with_capacity(self.max_degree_with_slack()); - let mut candidates_two_hop_expansion = Vec::with_capacity(self.max_degree_with_slack()); - - while scratch.best.has_notvisited_node() && !accessor.terminate_early() { - scratch.beam_nodes.clear(); - one_hop_neighbors.clear(); - candidates_two_hop_expansion.clear(); - two_hop_neighbors.clear(); - - // In this loop we are going to find the beam_width number of nodes that are closest to the query. - // Each of these nodes will be a frontier node. - while scratch.best.has_notvisited_node() && scratch.beam_nodes.len() < beam_width { - let closest_node = scratch.best.closest_notvisited(); - search_record.record(closest_node, scratch.hops, scratch.cmps); - scratch.beam_nodes.push(closest_node.id); - } - - // compute distances from query to one-hop neighbors, and mark them visited - accessor - .expand_beam( - scratch.beam_nodes.iter().copied(), - computer, - glue::NotInMut::new(&mut scratch.visited), - |distance, id| one_hop_neighbors.push(Neighbor::new(id, distance)), - ) - .await?; - - // Process one-hop neighbors based on on_visit() decision - for neighbor in one_hop_neighbors.iter().copied() { - match query_label_evaluator.on_visit(neighbor) { - QueryVisitDecision::Accept(accepted) => { - scratch.best.insert(accepted); - } - QueryVisitDecision::Reject => { - // Rejected nodes: still add to two-hop expansion so we can traverse through them - candidates_two_hop_expansion.push(neighbor); - } - QueryVisitDecision::Terminate => { - scratch.cmps += one_hop_neighbors.len() as u32; - scratch.hops += scratch.beam_nodes.len() as u32; - return Ok(make_stats(scratch)); - } - } - } - - scratch.cmps += one_hop_neighbors.len() as u32; - scratch.hops += scratch.beam_nodes.len() as u32; - - // sort the candidates for two-hop expansion by distance to query point - candidates_two_hop_expansion.sort_unstable_by(|a, b| { - a.distance - .partial_cmp(&b.distance) - .unwrap_or(std::cmp::Ordering::Equal) - }); - - // limit the number of two-hop candidates to avoid too many expansions - candidates_two_hop_expansion.truncate(self.max_degree_with_slack() / 2); - - // Expand each two-hop candidate: if its neighbor is a match, compute its distance - // to the query and insert into `scratch.visited` - // If it is not a match, do nothing - let two_hop_expansion_candidate_ids: Vec = - candidates_two_hop_expansion.iter().map(|n| n.id).collect(); - - accessor - .expand_beam( - two_hop_expansion_candidate_ids.iter().copied(), - computer, - NotInMutWithLabelCheck::new(&mut scratch.visited, query_label_evaluator), - |distance, id| { - two_hop_neighbors.push(Neighbor::new(id, distance)); - }, - ) - .await?; - - // Next, insert the new matches into `scratch.best` and increment stats counters - two_hop_neighbors - .iter() - .for_each(|neighbor| scratch.best.insert(*neighbor)); - - scratch.cmps += two_hop_neighbors.len() as u32; - scratch.hops += two_hop_expansion_candidate_ids.len() as u32; - } - - Ok(make_stats(scratch)) - } - } - /// Filter out start nodes from the best candidates in the scratch. fn filter_search_candidates( &self, @@ -2363,136 +2115,47 @@ where } } - /// Performs a graph-based search towards a target query vector recording the path taken. - /// - /// This method executes a search using the provided `strategy` to access and process elements. - /// It computes the similarity between the query vector and the elements in the index, moving towards the - /// nearest neighbors according to the search parameters. - /// The path taken is recorded according to the search_record object passed in. - /// - /// # Arguments - /// - /// * `strategy` - The search strategy to use for accessing and processing elements. - /// * `context` - The context to pass through to providers. - /// * `query` - The query vector for which nearest neighbors are sought. - /// * `search_params` - Parameters controlling the search behavior, such as search depth (`l_value`) and beam width. - /// * `output` - A mutable buffer to store the search results. Must be pre-allocated by the caller. - /// * `search_record` - A mutable reference to a search record object that will record the path taken during the search. - /// - /// # Returns - /// - /// Returns a tuple containing: - /// - An optional vector of visited nodes (if requested in `search_params`). - /// - The number of distance computations performed. - /// - The number of hops (always zero for flat search, as no graph traversal occurs). - /// - /// # Errors - /// - /// Returns an error if there is a failure accessing elements or if the provided parameters are invalid. - #[allow(clippy::too_many_arguments)] - pub fn search_recorded( - &self, - strategy: &S, - context: &DP::Context, - query: &T, - search_params: &SearchParams, - output: &mut OB, - search_record: &mut SR, - ) -> impl SendFuture> - where - T: Sync + ?Sized, - S: SearchStrategy, - O: Send, - OB: search_output_buffer::SearchOutputBuffer + Send + ?Sized, - SR: SearchRecord + ?Sized, - { - async move { - let mut accessor = strategy - .search_accessor(&self.data_provider, context) - .into_ann_result()?; - - let computer = accessor.build_query_computer(query).into_ann_result()?; - let start_ids = accessor.starting_points().await?; - - let mut scratch = self.search_scratch(search_params.l_value, start_ids.len()); - - let stats = self - .search_internal( - search_params.beam_width, - &start_ids, - &mut accessor, - &computer, - &mut scratch, - search_record, - ) - .await?; - - let result_count = strategy - .post_processor() - .post_process( - &mut accessor, - query, - &computer, - scratch.best.iter().take(search_params.l_value.into_usize()), - output, - ) - .send() - .await - .into_ann_result()?; - - Ok(stats.finish(result_count as u32)) - } - } - - /// Performs a graph-based search towards a target query vector. + /// Execute a search using the unified search interface. /// - /// This method executes a search using the provided `strategy` to access and process elements. - /// It computes the similarity between the query vector and the elements in the index, moving towards the - /// nearest neighbors according to the search parameters. + /// This method provides a single entry point for all search types. The `search_params` argument + /// implements [`search::Search`], which defines the complete search behavior including + /// algorithm selection and post-processing. /// - /// # Arguments + /// # Supported Search Types /// - /// * `strategy` - The search strategy to use for accessing and processing elements. - /// * `context` - The context to pass through to providers. - /// * `query` - The query vector for which nearest neighbors are sought. - /// * `search_params` - Parameters controlling the search behavior, such as search depth (`l_value`) and beam width. - /// * `output` - A mutable buffer to store the search results. Must be pre-allocated by the caller. + /// - [`search::KnnSearch`]: Standard k-NN graph-based search + /// - [`search::MultihopSearch`]: Label-filtered search with multi-hop expansion + /// - [`search::RangeSearch`]: Range-based search within a distance radius + /// - [`search::DiverseSearch`]: Diversity-aware search (feature-gated) /// - /// # Returns + /// # Example /// - /// Returns a tuple containing: - /// - An optional vector of visited nodes (if requested in `search_params`). - /// - The number of distance computations performed. - /// - The number of hops (always zero for flat search, as no graph traversal occurs). + /// ```ignore + /// use diskann::graph::{KnnSearch, RangeSearch, Search}; /// - /// # Errors + /// // Standard k-NN search + /// let mut params = KnnSearch::new(10, 100, None)?; + /// let stats = index.search(&mut params, &strategy, &context, &query, &mut output).await?; /// - /// Returns an error if there is a failure accessing elements or if the provided parameters are invalid. - pub fn search( - &self, - strategy: &S, - context: &DP::Context, - query: &T, - search_params: &SearchParams, - output: &mut OB, - ) -> impl SendFuture> + /// // Range search (note: uses () as output buffer, results in Output type) + /// let mut params = RangeSearch::new(100, 0.5)?; + /// let result = index.search(&mut params, &strategy, &context, &query, &mut ()).await?; + /// // result.ids and result.distances contain the matches + /// ``` + pub fn search<'a, S, T, O: 'a, OB, P>( + &'a self, + search_params: &'a mut P, + strategy: &'a S, + context: &'a DP::Context, + query: &'a T, + output: &'a mut OB, + ) -> impl SendFuture> + 'a where - T: Sync + ?Sized, - S: SearchStrategy, - O: Send, - OB: search_output_buffer::SearchOutputBuffer + Send + ?Sized, + P: super::search::Search, + T: ?Sized, + OB: ?Sized, { - async move { - self.search_recorded( - strategy, - context, - query, - search_params, - output, - &mut NoopSearchRecord::new(), - ) - .await - } + search_params.search(self, strategy, context, query, output) } /// Performs a brute-force flat search over the points matching a provided filter function. @@ -2507,15 +2170,12 @@ where /// * `context` - The context to pass through to providers. /// * `query` - The query vector for which nearest neighbors are sought. /// * `vector_filter` - A predicate function used to filter candidate vectors based on their external IDs. - /// * `search_params` - Parameters controlling the search behavior, such as search depth (`l_value`) and beam width. + /// * `search_params` - Parameters controlling the search behavior, such as search depth (`l_value`). /// * `output` - A mutable buffer to store the search results. Must be pre-allocated by the caller. /// /// # Returns /// - /// Returns a tuple containing: - /// - An optional vector of visited nodes (if requested in `search_params`). - /// - The number of distance computations performed. - /// - The number of hops (always zero for flat search, as no graph traversal occurs). + /// Returns search statistics including the number of distance computations performed. /// /// # Errors /// @@ -2531,7 +2191,7 @@ where context: &'a DP::Context, query: &T, vector_filter: &(dyn Fn(&DP::ExternalId) -> bool + Send + Sync), - search_params: &SearchParams, + search_params: &KnnSearch, output: &mut OB, ) -> ANNResult where @@ -2548,7 +2208,7 @@ where let mut scratch = { let num_start_points = accessor.starting_points().await?.len(); - self.search_scratch(search_params.l_value, num_start_points) + self.search_scratch(search_params.l_value().get(), num_start_points) }; let id_iterator = accessor.id_iterator().await?; @@ -2576,7 +2236,7 @@ where &mut accessor, query, &computer, - scratch.best.iter().take(search_params.l_value.into_usize()), + scratch.best.iter().take(search_params.l_value().get()), output, ) .send() @@ -2591,229 +2251,6 @@ where }) } - /// A helper function for range search that allows an external application - /// to perform their own post-processing on the raw in-range results - #[allow(clippy::type_complexity)] - pub fn range_search_raw( - &self, - strategy: &S, - context: &DP::Context, - query: &T, - search_params: &RangeSearchParams, - ) -> impl SendFuture>)>> - where - T: Sync + ?Sized, - S: SearchStrategy, - O: Send + Default + Clone, - { - async move { - let mut accessor = strategy - .search_accessor(&self.data_provider, context) - .into_ann_result()?; - let computer = accessor.build_query_computer(query).into_ann_result()?; - let start_ids = accessor.starting_points().await?; - - let mut scratch = self.search_scratch(search_params.starting_l_value, start_ids.len()); - - let initial_stats = self - .search_internal( - search_params.beam_width, - &start_ids, - &mut accessor, - &computer, - &mut scratch, - &mut NoopSearchRecord::new(), - ) - .await?; - - let mut in_range = Vec::with_capacity(search_params.starting_l_value.into_usize()); - - for neighbor in scratch - .best - .iter() - .take(search_params.starting_l_value.into_usize()) - { - if neighbor.distance <= search_params.radius { - in_range.push(neighbor); - } - } - - // clear the visited set and repopulate it with just the in-range points - scratch.visited.clear(); - for neighbor in in_range.iter() { - scratch.visited.insert(neighbor.id); - } - scratch.in_range = in_range; - - let stats = if scratch.in_range.len() - >= ((search_params.starting_l_value as f32) * search_params.initial_search_slack) - as usize - { - // Move to range search - let range_stats = self - .range_search_internal(search_params, &mut accessor, &computer, &mut scratch) - .await?; - - InternalSearchStats { - cmps: initial_stats.cmps, - hops: initial_stats.hops + range_stats.hops, - range_search_second_round: true, - } - } else { - initial_stats - }; - - Ok(( - stats.finish(scratch.in_range.len() as u32), - scratch.in_range.to_vec(), - )) - } - } - - /// Given a `query` vector, search for all results within a specified radius - /// `l_value` is the search depth of the initial search phase - /// - /// Note that the radii in `search_params` are raw distances, not similarity scores; - /// the user is expected to execute any necessary transformations to their desired - /// radius before calling this function. - /// - /// We allow complicated types here to avoid needing an entirely new type definition - /// for just one function - #[allow(clippy::type_complexity)] - pub fn range_search( - &self, - strategy: &S, - context: &DP::Context, - query: &T, - search_params: &RangeSearchParams, - ) -> impl SendFuture, Vec)>> - where - T: Sync + ?Sized, - S: SearchStrategy, - O: Send + Default + Clone, - { - async move { - let mut accessor = strategy - .search_accessor(&self.data_provider, context) - .into_ann_result()?; - let computer = accessor.build_query_computer(query).into_ann_result()?; - - let (mut stats, in_range) = self - .range_search_raw(strategy, context, query, search_params) - .await?; - // create a new output buffer for the range search - // need to initialize distance buffer to max value because of later filtering step - let mut result_ids: Vec = vec![O::default(); in_range.len()]; - let mut result_dists: Vec = vec![f32::MAX; in_range.len()]; - - let mut output_buffer = search_output_buffer::IdDistance::new( - result_ids.as_mut_slice(), - result_dists.as_mut_slice(), - ); - - let _ = strategy - .post_processor() - .post_process( - &mut accessor, - query, - &computer, - in_range.into_iter(), - &mut output_buffer, - ) - .send() - .await - .into_ann_result()?; - - // Filter the output buffer for points with distance between inner and outer radius - // Note this takes a dependency on the output of `post_process` being sorted by distance - - let inner_cutoff = if let Some(inner_radius) = search_params.inner_radius { - result_dists - .iter() - .position(|dist| *dist > inner_radius) - .unwrap_or(result_dists.len()) - } else { - 0 - }; - - let outer_cutoff = result_dists - .iter() - .position(|dist| *dist > search_params.radius) - .unwrap_or(result_dists.len()); - - result_ids.truncate(outer_cutoff); - result_ids.drain(0..inner_cutoff); - - result_dists.truncate(outer_cutoff); - result_dists.drain(0..inner_cutoff); - - let result_count = result_ids.len(); - - stats.result_count = result_count as u32; - - Ok((stats, result_ids, result_dists)) - } - } - - /// Graph search that takes into account label filter matching by expanding - /// each non-matching neighborhood to search for matching nodes - /// Label provider must be included as a function argument - /// Note that if the Strategy is of type BetaFilter, this function assumes - /// but does not enforce that the label provider used in the strategy - /// is the same as the one in the function argument - pub fn multihop_search( - &self, - strategy: &S, - context: &DP::Context, - query: &T, - search_params: &SearchParams, - output: &mut OB, - query_label_evaluator: &dyn QueryLabelProvider, - ) -> impl SendFuture> - where - T: Sync + ?Sized, - S: SearchStrategy, - O: Send, - OB: search_output_buffer::SearchOutputBuffer + Send, - { - async move { - let mut accessor = strategy - .search_accessor(&self.data_provider, context) - .into_ann_result()?; - let computer = accessor.build_query_computer(query).into_ann_result()?; - - let start_ids = accessor.starting_points().await?; - - let mut scratch = self.search_scratch(search_params.l_value, start_ids.len()); - - let stats = self - .multihop_search_internal( - search_params, - &mut accessor, - &computer, - &mut scratch, - &mut NoopSearchRecord::new(), - query_label_evaluator, - ) - .await?; - - let result_count = strategy - .post_processor() - .post_process( - &mut accessor, - query, - &computer, - scratch.best.iter().take(search_params.l_value.into_usize()), - output, - ) - .send() - .await - .into_ann_result()?; - - Ok(stats.finish(result_count as u32)) - } - } - ////////////////// // Paged Search // ////////////////// @@ -3595,15 +3032,15 @@ struct InplaceDeleteWorkList { in_neighbors: Vec, } -/// Private internal struct for recording search statistics. -struct InternalSearchStats { - cmps: u32, - hops: u32, - range_search_second_round: bool, +/// Internal struct for recording search statistics. +pub(crate) struct InternalSearchStats { + pub(crate) cmps: u32, + pub(crate) hops: u32, + pub(crate) range_search_second_round: bool, } impl InternalSearchStats { - fn finish(self, result_count: u32) -> SearchStats { + pub(crate) fn finish(self, result_count: u32) -> SearchStats { SearchStats { cmps: self.cmps, hops: self.hops, @@ -3612,137 +3049,3 @@ impl InternalSearchStats { } } } - -#[cfg(feature = "experimental_diversity_search")] -impl DiskANNIndex -where - DP: DataProvider, -{ - /// Create a diverse search scratch with DiverseNeighborQueue - fn create_diverse_scratch

( - &self, - l_value: usize, - beam_width: Option, - diverse_params: &DiverseSearchParams

, - k_value: usize, - ) -> SearchScratch> - where - P: crate::neighbor::AttributeValueProvider, - { - use crate::neighbor::DiverseNeighborQueue; - - let attribute_provider = diverse_params.attribute_provider.clone(); - let diverse_queue = DiverseNeighborQueue::new( - l_value, - // SAFETY: k_value is guaranteed to be non-zero by SearchParams validation by caller - #[allow(clippy::expect_used)] - NonZeroUsize::new(k_value).expect("k_value must be non-zero"), - diverse_params.diverse_results_k, - attribute_provider, - ); - - SearchScratch { - best: diverse_queue, - visited: HashSet::with_capacity(self.estimate_visited_set_capacity(Some(l_value))), - id_scratch: Vec::with_capacity(self.max_degree_with_slack()), - beam_nodes: Vec::with_capacity(beam_width.unwrap_or(1)), - range_frontier: std::collections::VecDeque::new(), - in_range: Vec::new(), - hops: 0, - cmps: 0, - } - } - - /// Experimental diverse search implementation using DiverseNeighborQueue. - /// - /// This method performs a graph-based search with diversity constraints, using the provided - /// diverse search parameters to filter results based on attribute values. - /// - /// # Arguments - /// - /// * `strategy` - The search strategy to use for accessing and processing elements. - /// * `context` - The context to pass through to providers. - /// * `query` - The query vector for which nearest neighbors are sought. - /// * `search_params` - Parameters controlling the search behavior, including l_value, beam width, and k_value. - /// * `diverse_params` - Diversity parameters including attribute provider and alpha value. - /// * `output` - A mutable buffer to store the search results. Must be pre-allocated by the caller. - /// * `search_record` - A mutable reference to a search record object that will record the path taken during the search. - /// - /// # Returns - /// - /// Returns search statistics including comparisons and hops performed. - /// - /// # Errors - /// - /// Returns an error if there is a failure accessing elements or if the provided parameters are invalid. - #[allow(clippy::too_many_arguments)] - pub fn diverse_search_experimental( - &self, - strategy: &S, - context: &DP::Context, - query: &T, - search_params: &SearchParams, - diverse_params: &DiverseSearchParams

, - output: &mut OB, - search_record: &mut SR, - ) -> impl SendFuture> - where - T: Sync + ?Sized, - S: glue::SearchStrategy, - O: Send, - OB: search_output_buffer::SearchOutputBuffer + Send, - SR: super::search::record::SearchRecord + ?Sized, - P: crate::neighbor::AttributeValueProvider, - { - async move { - let mut accessor = strategy - .search_accessor(&self.data_provider, context) - .into_ann_result()?; - - let computer = accessor.build_query_computer(query).into_ann_result()?; - let start_ids = accessor.starting_points().await?; - - // Use diverse search with DiverseNeighborQueue - // TODO: Use scratch pool in future PRs to avoid allocation. - let mut diverse_scratch = self.create_diverse_scratch( - search_params.l_value, - search_params.beam_width, - diverse_params, - search_params.k_value, - ); - - let stats = self - .search_internal( - search_params.beam_width, - &start_ids, - &mut accessor, - &computer, - &mut diverse_scratch, - search_record, - ) - .await?; - - // Post-process diverse results to keep only diverse_results_k items - diverse_scratch.best.post_process(); - - // TODO: Post processing will change for diverse search in future PRs - let result_count = strategy - .post_processor() - .post_process( - &mut accessor, - query, - &computer, - diverse_scratch - .best - .iter() - .take(search_params.l_value.into_usize()), - output, - ) - .send() - .await - .into_ann_result()?; - - Ok(stats.finish(result_count as u32)) - } - } -} diff --git a/diskann/src/graph/misc.rs b/diskann/src/graph/misc.rs index 8c58f6edb..067bd2d21 100644 --- a/diskann/src/graph/misc.rs +++ b/diskann/src/graph/misc.rs @@ -3,10 +3,6 @@ * Licensed under the MIT license. */ -use thiserror::Error; - -use crate::{ANNError, ANNErrorKind, error::ensure_positive}; - // enum used to return the status of the vector that `consolidate_vector` // was called on: Deleted if the vector was already deleted, and Complete // if the vector was not deleted (and thus is now consolidated) @@ -35,141 +31,6 @@ pub enum InplaceDeleteMethod { OneHop, } -// Parameters for the search algorithm -#[derive(Copy, Clone, Debug)] -pub struct SearchParams { - pub k_value: usize, - pub l_value: usize, - pub beam_width: Option, -} - -#[derive(Debug, Error)] -pub enum SearchParamsError { - #[error("l_value ({l_value}) cannot be less than k_value ({k_value})")] - LLessThanK { l_value: usize, k_value: usize }, - #[error("beam width cannot be zero")] - BeamWidthZero, - #[error("l_value cannot be zero")] - LZero, - #[error("k_value cannot be zero")] - KZero, -} - -impl From for ANNError { - fn from(err: SearchParamsError) -> Self { - Self::new(ANNErrorKind::IndexError, err) - } -} - -impl SearchParams { - pub fn new( - k_value: usize, - l_value: usize, - beam_width: Option, - ) -> Result { - if k_value > l_value { - return Err(SearchParamsError::LLessThanK { l_value, k_value }); - } - if let Some(beam_width) = beam_width { - ensure_positive(beam_width, SearchParamsError::BeamWidthZero)?; - } - ensure_positive(k_value, SearchParamsError::KZero)?; - ensure_positive(l_value, SearchParamsError::LZero)?; - - Ok(Self { - k_value, - l_value, - beam_width, - }) - } - - pub fn new_default(k_value: usize, l_value: usize) -> Result { - SearchParams::new(k_value, l_value, None) - } -} - -// Parameters for the search algorithm -#[derive(Copy, Clone, Debug)] -pub struct RangeSearchParams { - pub max_returned: Option, - pub starting_l_value: usize, - pub beam_width: Option, - pub radius: f32, - pub inner_radius: Option, - pub initial_search_slack: f32, - pub range_search_slack: f32, -} - -#[derive(Debug, Error)] -pub enum RangeSearchParamsError { - #[error("beam width cannot be zero")] - BeamWidthZero, - #[error("l_value cannot be zero")] - LZero, - #[error("initial_search_slack must be between 0 and 1.0")] - StartingListSlackValueError, - #[error("range_search_slack must be greater than or equal to 1.0")] - RangeSearchSlackValueError, - #[error("inner_radius must be less than or equal to radius")] - InnerRadiusValueError, -} - -impl From for ANNError { - fn from(err: RangeSearchParamsError) -> Self { - Self::new(ANNErrorKind::IndexError, err) - } -} - -impl RangeSearchParams { - pub fn new( - max_returned: Option, - starting_l_value: usize, - beam_width: Option, - radius: f32, - inner_radius: Option, - initial_search_slack: f32, - range_search_slack: f32, - ) -> Result { - // note that radius is allowed to be negative due to inner product metrics - if let Some(beam_width) = beam_width { - ensure_positive(beam_width, RangeSearchParamsError::BeamWidthZero)?; - } - ensure_positive(starting_l_value, RangeSearchParamsError::LZero)?; - if !(0.0..=1.0).contains(&initial_search_slack) { - return Err(RangeSearchParamsError::StartingListSlackValueError); - } - if range_search_slack < 1.0 { - return Err(RangeSearchParamsError::RangeSearchSlackValueError); - } - if let Some(inner_radius) = inner_radius - && inner_radius > radius - { - return Err(RangeSearchParamsError::InnerRadiusValueError); - } - - Ok(Self { - max_returned, - starting_l_value, - beam_width, - radius, - inner_radius, - initial_search_slack, - range_search_slack, - }) - } - - pub fn new_default( - starting_l_value: usize, - radius: f32, - ) -> Result { - RangeSearchParams::new(None, starting_l_value, None, radius, None, 1.0, 1.0) - } - - pub fn l_value(&self) -> usize { - self.starting_l_value - } -} - // Parameters for diverse search #[cfg(feature = "experimental_diversity_search")] #[derive(Clone, Debug)] @@ -224,75 +85,4 @@ mod tests { _ => panic!("Expected not deleted variant"), } } - - #[test] - fn test_range_search_params_error_cases() { - { - // test starting list slack factor error - let res = RangeSearchParams::new( - None, // max returned - 10, // starting l value - None, // beam width - 1.0, // radius - None, // inner radius - 1.1, // initial search slack - 1.0, // range search slack - ); - assert!(res.is_err()); - assert_eq!( - res.unwrap_err().to_string(), - "initial_search_slack must be between 0 and 1.0" - ); - } - { - // test range search slack factor error - let res = RangeSearchParams::new( - None, // max returned - 10, // starting l value - None, // beam width - 1.0, // radius - None, // inner radius - 1.0, // initial search slack - 0.9, // range search slack - ); - assert!(res.is_err()); - assert_eq!( - res.unwrap_err().to_string(), - "range_search_slack must be greater than or equal to 1.0" - ); - } - { - // test inner radius error - let res = RangeSearchParams::new( - None, // max returned - 10, // starting l value - None, // beam width - 1.0, // radius - Some(2.0), // inner radius - 1.0, // initial search slack - 1.0, // range search slack - ); - assert!(res.is_err()); - assert_eq!( - res.unwrap_err().to_string(), - "inner_radius must be less than or equal to radius" - ); - } - } - - #[test] - fn test_range_search_params_impl() { - let res = RangeSearchParams::new( - None, // max returned - 10, // starting l value - None, // beam width - 1.0, // radius - None, // inner radius - 1.0, // initial search slack - 1.0, // range search slack - ) - .unwrap(); - - assert_eq!(res.l_value(), 10); - } } diff --git a/diskann/src/graph/mod.rs b/diskann/src/graph/mod.rs index d203a74b0..011734e28 100644 --- a/diskann/src/graph/mod.rs +++ b/diskann/src/graph/mod.rs @@ -21,10 +21,7 @@ mod start_point; pub use start_point::{SampleableForStart, StartPointStrategy}; mod misc; -pub use misc::{ - ConsolidateKind, InplaceDeleteMethod, RangeSearchParams, RangeSearchParamsError, SearchParams, - SearchParamsError, -}; +pub use misc::{ConsolidateKind, InplaceDeleteMethod}; #[cfg(feature = "experimental_diversity_search")] pub use misc::DiverseSearchParams; @@ -32,6 +29,15 @@ pub use misc::DiverseSearchParams; pub mod glue; pub mod search; +// Re-export unified search interface as the primary API. +pub use search::{ + KnnSearch, KnnSearchError, MultihopSearch, RangeSearch, RangeSearchError, RangeSearchOutput, + RecordedKnnSearch, Search, +}; + +#[cfg(feature = "experimental_diversity_search")] +pub use search::DiverseSearch; + mod internal; // Integration tests and test providers. diff --git a/diskann/src/graph/search/diverse_search.rs b/diskann/src/graph/search/diverse_search.rs new file mode 100644 index 000000000..dc919fc08 --- /dev/null +++ b/diskann/src/graph/search/diverse_search.rs @@ -0,0 +1,154 @@ +/* + * Copyright (c) Microsoft Corporation. + * Licensed under the MIT license. + */ + +//! Diversity-aware search. + +use diskann_utils::future::{AssertSend, SendFuture}; +use hashbrown::HashSet; + +use super::{KnnSearch, Search, record::NoopSearchRecord, scratch::SearchScratch}; +use crate::{ + ANNResult, + error::IntoANNResult, + graph::{ + DiverseSearchParams, + glue::{SearchExt, SearchPostProcess, SearchStrategy}, + index::{DiskANNIndex, SearchStats}, + search_output_buffer::SearchOutputBuffer, + }, + neighbor::{AttributeValueProvider, DiverseNeighborQueue, NeighborQueue}, + provider::{BuildQueryComputer, DataProvider}, +}; + +/// Parameters for diversity-aware search. +/// +/// Returns results that are diverse across a specified attribute. +#[derive(Debug)] +pub struct DiverseSearch

+where + P: AttributeValueProvider, +{ + /// Base k-NN search parameters. + inner: KnnSearch, + /// Diversity-specific parameters. + diverse_params: DiverseSearchParams

, +} + +impl

DiverseSearch

+where + P: AttributeValueProvider, +{ + /// Create new diverse search parameters. + pub fn new(inner: KnnSearch, diverse_params: DiverseSearchParams

) -> Self { + Self { + inner, + diverse_params, + } + } + + /// Returns a reference to the inner k-NN search parameters. + #[inline] + pub fn inner(&self) -> &KnnSearch { + &self.inner + } + + /// Returns a reference to the diversity-specific parameters. + #[inline] + pub fn diverse_params(&self) -> &DiverseSearchParams

{ + &self.diverse_params + } + + /// Create search scratch with DiverseNeighborQueue for this search. + fn create_scratch( + &self, + index: &DiskANNIndex, + ) -> SearchScratch> + where + DP: DataProvider, + P: AttributeValueProvider, + { + let attribute_provider = self.diverse_params.attribute_provider.clone(); + let diverse_queue = DiverseNeighborQueue::new( + self.inner.l_value().get(), + self.inner.k_value(), + self.diverse_params.diverse_results_k, + attribute_provider, + ); + + SearchScratch { + best: diverse_queue, + visited: HashSet::with_capacity( + index.estimate_visited_set_capacity(Some(self.inner.l_value().get())), + ), + id_scratch: Vec::with_capacity(index.max_degree_with_slack()), + beam_nodes: Vec::with_capacity(self.inner.beam_width().map_or(1, |nz| nz.get())), + range_frontier: std::collections::VecDeque::new(), + in_range: Vec::new(), + hops: 0, + cmps: 0, + } + } +} + +impl Search for DiverseSearch

+where + DP: DataProvider, + T: Sync + ?Sized, + S: SearchStrategy, + O: Send, + OB: SearchOutputBuffer + Send, + P: AttributeValueProvider, +{ + type Output = SearchStats; + + fn search( + &mut self, + index: &DiskANNIndex, + strategy: &S, + context: &DP::Context, + query: &T, + output: &mut OB, + ) -> impl SendFuture> { + async move { + let mut accessor = strategy + .search_accessor(&index.data_provider, context) + .into_ann_result()?; + + let computer = accessor.build_query_computer(query).into_ann_result()?; + let start_ids = accessor.starting_points().await?; + + let mut diverse_scratch = self.create_scratch(index); + + let stats = index + .search_internal( + self.inner.beam_width().map(|nz| nz.get()), + &start_ids, + &mut accessor, + &computer, + &mut diverse_scratch, + &mut NoopSearchRecord::new(), + ) + .await?; + + // Post-process diverse results + diverse_scratch.best.post_process(); + + let result_count = strategy + .post_processor() + .post_process( + &mut accessor, + query, + &computer, + diverse_scratch.best.iter().take(self.inner.l_value().get()), + output, + ) + .send() + .await + .into_ann_result()?; + + Ok(stats.finish(result_count as u32)) + } + } +} diff --git a/diskann/src/graph/search/knn_search.rs b/diskann/src/graph/search/knn_search.rs new file mode 100644 index 000000000..f97fa7643 --- /dev/null +++ b/diskann/src/graph/search/knn_search.rs @@ -0,0 +1,352 @@ +/* + * Copyright (c) Microsoft Corporation. + * Licensed under the MIT license. + */ + +//! Standard k-NN (k-nearest neighbor) graph-based search. + +use std::{fmt::Debug, num::NonZeroUsize}; + +use diskann_utils::future::{AssertSend, SendFuture}; +use thiserror::Error; + +use super::Search; +use crate::{ + ANNError, ANNErrorKind, ANNResult, + error::IntoANNResult, + graph::{ + glue::{SearchExt, SearchPostProcess, SearchStrategy}, + index::{DiskANNIndex, SearchStats}, + search::record::NoopSearchRecord, + search_output_buffer::SearchOutputBuffer, + }, + provider::{BuildQueryComputer, DataProvider}, + utils::IntoUsize, +}; + +/// Error type for [`KnnSearch`] parameter validation. +#[derive(Debug, Error)] +pub enum KnnSearchError { + #[error("l_value ({l_value}) cannot be less than k_value ({k_value})")] + LLessThanK { l_value: usize, k_value: usize }, + #[error("beam width cannot be zero")] + BeamWidthZero, + #[error("k_value cannot be zero")] + KZero, + #[error("l_value cannot be zero")] + LZero, +} + +impl From for ANNError { + #[track_caller] + fn from(err: KnnSearchError) -> Self { + Self::new(ANNErrorKind::IndexError, err) + } +} + +/// Parameters for standard k-NN (k-nearest neighbor) graph-based search. +/// +/// This is the primary search mode, using the Vamana graph structure for efficient +/// approximate nearest neighbor traversal. +#[derive(Debug, Clone, Copy)] +pub struct KnnSearch { + /// Number of results to return (k in k-NN). + k_value: NonZeroUsize, + /// Search list size - controls accuracy vs speed tradeoff. + l_value: NonZeroUsize, + /// Optional beam width for parallel graph exploration. + beam_width: Option, +} + +impl KnnSearch { + /// Create new k-NN search parameters. + /// + /// # Errors + /// + /// Returns an error if `k_value` is zero, `l_value` is zero, + /// `l_value < k_value`, or if `beam_width` is zero. + pub fn new( + k_value: usize, + l_value: usize, + beam_width: Option, + ) -> Result { + if k_value == 0 { + return Err(KnnSearchError::KZero); + } + if l_value == 0 { + return Err(KnnSearchError::LZero); + } + if k_value > l_value { + return Err(KnnSearchError::LLessThanK { l_value, k_value }); + } + if let Some(bw) = beam_width + && bw == 0 + { + return Err(KnnSearchError::BeamWidthZero); + } + + // SAFETY: We've validated k_value != 0 and l_value != 0 above + Ok(Self { + k_value: unsafe { NonZeroUsize::new_unchecked(k_value) }, + l_value: unsafe { NonZeroUsize::new_unchecked(l_value) }, + beam_width: beam_width.and_then(NonZeroUsize::new), + }) + } + + /// Create parameters with default beam width. + pub fn new_default(k_value: usize, l_value: usize) -> Result { + Self::new(k_value, l_value, None) + } + + /// Returns the number of results to return (k in k-NN). + #[inline] + pub fn k_value(&self) -> NonZeroUsize { + self.k_value + } + + /// Returns the search list size. + #[inline] + pub fn l_value(&self) -> NonZeroUsize { + self.l_value + } + + /// Returns the optional beam width for parallel graph exploration. + #[inline] + pub fn beam_width(&self) -> Option { + self.beam_width + } +} + +/// Standard k-NN graph-based search implementation. +/// +/// This is the primary search type for approximate nearest neighbor queries. It performs +/// a greedy beam search over the graph, maintaining a priority queue of the best candidates +/// found so far. The search explores neighbors of promising candidates until convergence. +/// +/// # Algorithm +/// +/// 1. Initialize with starting points +/// 2. Compute distances from query to starting points +/// 3. Greedily expand the most promising unexplored candidate +/// 4. Add the candidate's neighbors to the frontier +/// 5. Repeat until no unexplored candidates remain within the search list +/// 6. Return the top-k results from the best candidates found +/// +/// # Parameters +/// +/// - `k_value`: Number of nearest neighbors to return +/// - `l_value`: Search list size (larger values improve recall at cost of latency) +/// - `beam_width`: Optional parallel exploration width +/// +/// # Example +/// +/// ```ignore +/// use diskann::graph::{search::KnnSearch, Search}; +/// +/// let mut params = KnnSearch::new(10, 100, None)?; +/// let stats = index.search(&mut params, &strategy, &context, &query, &mut output).await?; +/// ``` +impl Search for KnnSearch +where + DP: DataProvider, + T: Sync + ?Sized, + S: SearchStrategy, + O: Send, + OB: SearchOutputBuffer + Send + ?Sized, +{ + type Output = SearchStats; + + /// Execute the k-NN search on the given index. + /// + /// This method executes a search using the provided `strategy` to access and process elements. + /// It computes the similarity between the query vector and the elements in the index, traversing + /// the graph towards the nearest neighbors according to the search parameters. + /// + /// # Arguments + /// + /// * `index` - The DiskANN index to search. + /// * `strategy` - The search strategy to use for accessing and processing elements. + /// * `context` - The context to pass through to providers. + /// * `query` - The query vector for which nearest neighbors are sought. + /// * `output` - A mutable buffer to store the search results. Must be pre-allocated by the caller. + /// + /// # Returns + /// + /// Returns [`SearchStats`] containing: + /// - The number of distance computations performed. + /// - The number of hops (graph traversal steps). + /// - Timing information for the search operation. + /// + /// # Errors + /// + /// Returns an error if there is a failure accessing elements or computing distances. + fn search( + &mut self, + index: &DiskANNIndex, + strategy: &S, + context: &DP::Context, + query: &T, + output: &mut OB, + ) -> impl SendFuture> { + async move { + let mut accessor = strategy + .search_accessor(&index.data_provider, context) + .into_ann_result()?; + + let computer = accessor.build_query_computer(query).into_ann_result()?; + let start_ids = accessor.starting_points().await?; + + let mut scratch = index.search_scratch(self.l_value.get(), start_ids.len()); + + let stats = index + .search_internal( + self.beam_width.map(|nz| nz.get()), + &start_ids, + &mut accessor, + &computer, + &mut scratch, + &mut NoopSearchRecord::new(), + ) + .await?; + + let result_count = strategy + .post_processor() + .post_process( + &mut accessor, + query, + &computer, + scratch.best.iter().take(self.l_value.get().into_usize()), + output, + ) + .send() + .await + .into_ann_result()?; + + Ok(stats.finish(result_count as u32)) + } + } +} + +//////////////////////// +// Recorded KnnSearch // +//////////////////////// + +/// K-NN search with traversal path recording. +/// +/// Records the path taken during search for debugging or analysis. +#[derive(Debug)] +pub struct RecordedKnnSearch<'r, SR: ?Sized> { + /// Base k-NN search parameters. + pub inner: KnnSearch, + /// The recorder to capture search path. + pub recorder: &'r mut SR, +} + +impl<'r, SR: ?Sized> RecordedKnnSearch<'r, SR> { + /// Create new recorded search parameters. + pub fn new(inner: KnnSearch, recorder: &'r mut SR) -> Self { + Self { inner, recorder } + } +} + +impl<'r, DP, S, T, O, OB, SR> Search for RecordedKnnSearch<'r, SR> +where + DP: DataProvider, + T: Sync + ?Sized, + S: SearchStrategy, + O: Send, + OB: SearchOutputBuffer + Send + ?Sized, + SR: super::record::SearchRecord + ?Sized, +{ + type Output = SearchStats; + + fn search( + &mut self, + index: &DiskANNIndex, + strategy: &S, + context: &DP::Context, + query: &T, + output: &mut OB, + ) -> impl SendFuture> { + async move { + let mut accessor = strategy + .search_accessor(&index.data_provider, context) + .into_ann_result()?; + + let computer = accessor.build_query_computer(query).into_ann_result()?; + let start_ids = accessor.starting_points().await?; + + let mut scratch = index.search_scratch(self.inner.l_value.get(), start_ids.len()); + + let stats = index + .search_internal( + self.inner.beam_width.map(|nz| nz.get()), + &start_ids, + &mut accessor, + &computer, + &mut scratch, + self.recorder, + ) + .await?; + + let result_count = strategy + .post_processor() + .post_process( + &mut accessor, + query, + &computer, + scratch + .best + .iter() + .take(self.inner.l_value.get().into_usize()), + output, + ) + .send() + .await + .into_ann_result()?; + + Ok(stats.finish(result_count as u32)) + } + } +} + +/////////// +// Tests // +/////////// + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_knn_search_validation() { + // Valid + assert!(KnnSearch::new(10, 100, None).is_ok()); + assert!(KnnSearch::new(10, 100, Some(4)).is_ok()); + assert!(KnnSearch::new(10, 10, None).is_ok()); // k == l is valid + + // Invalid: k = 0 + assert!(matches!( + KnnSearch::new(0, 100, None), + Err(KnnSearchError::KZero) + )); + + // Invalid: l = 0 + assert!(matches!( + KnnSearch::new(10, 0, None), + Err(KnnSearchError::LZero) + )); + + // Invalid: l < k + assert!(matches!( + KnnSearch::new(100, 10, None), + Err(KnnSearchError::LLessThanK { .. }) + )); + + // Invalid: zero beam_width + assert!(matches!( + KnnSearch::new(10, 100, Some(0)), + Err(KnnSearchError::BeamWidthZero) + )); + } +} diff --git a/diskann/src/graph/search/mod.rs b/diskann/src/graph/search/mod.rs index 2b02ac39f..49df6dbac 100644 --- a/diskann/src/graph/search/mod.rs +++ b/diskann/src/graph/search/mod.rs @@ -3,5 +3,90 @@ * Licensed under the MIT license. */ +//! Unified search execution framework. +//! +//! This module provides the primary search interface for DiskANN. All search types +//! are represented as parameter structs that implement [`Search`], which +//! contains the complete search logic. +//! +//! # Usage +//! +//! ```ignore +//! use diskann::graph::{KnnSearch, RangeSearch, MultihopSearch, Search}; +//! +//! // Standard k-NN search +//! let mut params = KnnSearch::new(10, 100, None)?; +//! let stats = index.search(&mut params, &strategy, &context, &query, &mut output).await?; +//! +//! // Range search +//! let mut params = RangeSearch::new(100, 0.5)?; +//! let result = index.search(&mut params, &strategy, &context, &query, &mut ()).await?; +//! println!("Found {} points within radius", result.ids.len()); +//! ``` + +use diskann_utils::future::SendFuture; + +use crate::{ANNResult, graph::index::DiskANNIndex, provider::DataProvider}; + +mod knn_search; +mod multihop_search; +mod range_search; + pub mod record; pub(crate) mod scratch; + +/// Trait for search parameter types that execute their own search logic. +/// +/// Each search type (graph search, range search, etc.) implements this trait +/// to define its complete search behavior. The [`DiskANNIndex::search`] method +/// delegates to the `search` method. +pub trait Search +where + DP: DataProvider, +{ + /// The result type returned by this search. + type Output; + + /// Execute the search operation with full search logic. + /// + /// This method executes a search using the provided `strategy` to access and process elements. + /// It computes the similarity between the query vector and the elements in the index, + /// finding nearest neighbors according to the search parameters. + /// + /// # Arguments + /// + /// * `index` - The DiskANN index to search. + /// * `strategy` - The search strategy to use for accessing and processing elements. + /// * `context` - The context to pass through to providers. + /// * `query` - The query vector for which nearest neighbors are sought. + /// * `output` - A mutable buffer to store the search results. Must be pre-allocated by the caller. + /// + /// # Returns + /// + /// Returns `Self::Output` which varies by search type (e.g., [`SearchStats`](super::index::SearchStats) + /// for k-NN, [`RangeSearchOutput`] for range search). + /// + /// # Errors + /// + /// Returns an error if there is a failure accessing elements or computing distances. + fn search( + &mut self, + index: &DiskANNIndex, + strategy: &S, + context: &DP::Context, + query: &T, + output: &mut OB, + ) -> impl SendFuture>; +} + +// Re-export search parameter types. +pub use knn_search::{KnnSearch, KnnSearchError, RecordedKnnSearch}; +pub use multihop_search::MultihopSearch; +pub use range_search::{RangeSearch, RangeSearchError, RangeSearchOutput}; + +// Feature-gated diverse search. +#[cfg(feature = "experimental_diversity_search")] +mod diverse_search; + +#[cfg(feature = "experimental_diversity_search")] +pub use diverse_search::DiverseSearch; diff --git a/diskann/src/graph/search/multihop_search.rs b/diskann/src/graph/search/multihop_search.rs new file mode 100644 index 000000000..240c77b86 --- /dev/null +++ b/diskann/src/graph/search/multihop_search.rs @@ -0,0 +1,300 @@ +/* + * Copyright (c) Microsoft Corporation. + * Licensed under the MIT license. + */ + +//! Label-filtered search using multi-hop expansion. + +use diskann_utils::Reborrow; +use diskann_utils::future::{AssertSend, SendFuture}; +use diskann_vector::PreprocessedDistanceFunction; +use hashbrown::HashSet; + +use super::{KnnSearch, Search, record::SearchRecord, scratch::SearchScratch}; +use crate::{ + ANNResult, + error::{ErrorExt, IntoANNResult}, + graph::{ + glue::{ + self, ExpandBeam, HybridPredicate, Predicate, PredicateMut, SearchExt, + SearchPostProcess, SearchStrategy, + }, + index::{ + DiskANNIndex, InternalSearchStats, QueryLabelProvider, QueryVisitDecision, SearchStats, + }, + search::record::NoopSearchRecord, + search_output_buffer::SearchOutputBuffer, + }, + neighbor::Neighbor, + provider::{BuildQueryComputer, DataProvider}, + utils::VectorId, +}; + +/// Parameters for label-filtered search using multi-hop expansion. +/// +/// This search extends standard graph search by expanding through non-matching +/// nodes to find matching neighbors. More efficient than flat search when the +/// matching subset is reasonably large. +#[derive(Debug)] +pub struct MultihopSearch<'q, InternalId> { + /// Base graph search parameters. + pub inner: KnnSearch, + /// Label evaluator for determining node matches. + pub label_evaluator: &'q dyn QueryLabelProvider, +} + +impl<'q, InternalId> MultihopSearch<'q, InternalId> { + /// Create new multihop search parameters. + pub fn new(inner: KnnSearch, label_evaluator: &'q dyn QueryLabelProvider) -> Self { + Self { + inner, + label_evaluator, + } + } +} + +impl<'q, DP, S, T, O, OB> Search for MultihopSearch<'q, DP::InternalId> +where + DP: DataProvider, + T: Sync + ?Sized, + S: SearchStrategy, + O: Send, + OB: SearchOutputBuffer + Send, +{ + type Output = SearchStats; + + fn search( + &mut self, + index: &DiskANNIndex, + strategy: &S, + context: &DP::Context, + query: &T, + output: &mut OB, + ) -> impl SendFuture> { + let params = self.inner; + async move { + let mut accessor = strategy + .search_accessor(&index.data_provider, context) + .into_ann_result()?; + let computer = accessor.build_query_computer(query).into_ann_result()?; + + let start_ids = accessor.starting_points().await?; + + let mut scratch = index.search_scratch(params.l_value().get(), start_ids.len()); + + let stats = multihop_search_internal( + index.max_degree_with_slack(), + ¶ms, + &mut accessor, + &computer, + &mut scratch, + &mut NoopSearchRecord::new(), + self.label_evaluator, + ) + .await?; + + let result_count = strategy + .post_processor() + .post_process( + &mut accessor, + query, + &computer, + scratch.best.iter().take(params.l_value().get()), + output, + ) + .send() + .await + .into_ann_result()?; + + Ok(stats.finish(result_count as u32)) + } + } +} + +///////////////////////////// +// Internal Implementation // +///////////////////////////// + +/// A predicate that checks if an item is not in the visited set AND matches the label filter. +/// +/// Used during two-hop expansion to filter neighbors based on both visitation +/// status and label matching criteria. +pub struct NotInMutWithLabelCheck<'a, K> +where + K: VectorId, +{ + visited_set: &'a mut HashSet, + query_label_evaluator: &'a dyn QueryLabelProvider, +} + +impl<'a, K> NotInMutWithLabelCheck<'a, K> +where + K: VectorId, +{ + /// Construct a new `NotInMutWithLabelCheck` around `visited_set`. + pub fn new( + visited_set: &'a mut HashSet, + query_label_evaluator: &'a dyn QueryLabelProvider, + ) -> Self { + Self { + visited_set, + query_label_evaluator, + } + } +} + +impl Predicate for NotInMutWithLabelCheck<'_, K> +where + K: VectorId, +{ + fn eval(&self, item: &K) -> bool { + !self.visited_set.contains(item) && self.query_label_evaluator.is_match(*item) + } +} + +impl PredicateMut for NotInMutWithLabelCheck<'_, K> +where + K: VectorId, +{ + fn eval_mut(&mut self, item: &K) -> bool { + if self.query_label_evaluator.is_match(*item) { + return self.visited_set.insert(*item); + } + false + } +} + +impl HybridPredicate for NotInMutWithLabelCheck<'_, K> where K: VectorId {} + +/// Internal multihop search implementation. +/// +/// Performs label-filtered search by expanding through non-matching nodes +/// to find matching neighbors within two hops. +pub(crate) async fn multihop_search_internal( + max_degree_with_slack: usize, + search_params: &KnnSearch, + accessor: &mut A, + computer: &A::QueryComputer, + scratch: &mut SearchScratch, + search_record: &mut SR, + query_label_evaluator: &dyn QueryLabelProvider, +) -> ANNResult +where + I: VectorId, + A: ExpandBeam + SearchExt, + T: ?Sized, + SR: SearchRecord + ?Sized, +{ + let beam_width = search_params.beam_width().map_or(1, |nz| nz.get()); + + // Helper to build the final stats from scratch state. + let make_stats = |scratch: &SearchScratch| InternalSearchStats { + cmps: scratch.cmps, + hops: scratch.hops, + range_search_second_round: false, + }; + + // Initialize search state if not already initialized. + // This allows paged search to call multihop_search_internal multiple times + if scratch.visited.is_empty() { + let start_ids = accessor.starting_points().await?; + + for id in start_ids { + scratch.visited.insert(id); + let element = accessor + .get_element(id) + .await + .escalate("start point retrieval must succeed")?; + let dist = computer.evaluate_similarity(element.reborrow()); + scratch.best.insert(Neighbor::new(id, dist)); + } + } + + // Pre-allocate with good capacity to avoid repeated allocations + let mut one_hop_neighbors = Vec::with_capacity(max_degree_with_slack); + let mut two_hop_neighbors = Vec::with_capacity(max_degree_with_slack); + let mut candidates_two_hop_expansion = Vec::with_capacity(max_degree_with_slack); + + while scratch.best.has_notvisited_node() && !accessor.terminate_early() { + scratch.beam_nodes.clear(); + one_hop_neighbors.clear(); + candidates_two_hop_expansion.clear(); + two_hop_neighbors.clear(); + + // In this loop we are going to find the beam_width number of nodes that are closest to the query. + // Each of these nodes will be a frontier node. + while scratch.best.has_notvisited_node() && scratch.beam_nodes.len() < beam_width { + let closest_node = scratch.best.closest_notvisited(); + search_record.record(closest_node, scratch.hops, scratch.cmps); + scratch.beam_nodes.push(closest_node.id); + } + + // compute distances from query to one-hop neighbors, and mark them visited + accessor + .expand_beam( + scratch.beam_nodes.iter().copied(), + computer, + glue::NotInMut::new(&mut scratch.visited), + |distance, id| one_hop_neighbors.push(Neighbor::new(id, distance)), + ) + .await?; + + // Process one-hop neighbors based on on_visit() decision + for neighbor in one_hop_neighbors.iter().copied() { + match query_label_evaluator.on_visit(neighbor) { + QueryVisitDecision::Accept(accepted) => { + scratch.best.insert(accepted); + } + QueryVisitDecision::Reject => { + // Rejected nodes: still add to two-hop expansion so we can traverse through them + candidates_two_hop_expansion.push(neighbor); + } + QueryVisitDecision::Terminate => { + scratch.cmps += one_hop_neighbors.len() as u32; + scratch.hops += scratch.beam_nodes.len() as u32; + return Ok(make_stats(scratch)); + } + } + } + + scratch.cmps += one_hop_neighbors.len() as u32; + scratch.hops += scratch.beam_nodes.len() as u32; + + // sort the candidates for two-hop expansion by distance to query point + candidates_two_hop_expansion.sort_unstable_by(|a, b| { + a.distance + .partial_cmp(&b.distance) + .unwrap_or(std::cmp::Ordering::Equal) + }); + + // limit the number of two-hop candidates to avoid too many expansions + candidates_two_hop_expansion.truncate(max_degree_with_slack / 2); + + // Expand each two-hop candidate: if its neighbor is a match, compute its distance + // to the query and insert into `scratch.visited` + // If it is not a match, do nothing + let two_hop_expansion_candidate_ids: Vec = + candidates_two_hop_expansion.iter().map(|n| n.id).collect(); + + accessor + .expand_beam( + two_hop_expansion_candidate_ids.iter().copied(), + computer, + NotInMutWithLabelCheck::new(&mut scratch.visited, query_label_evaluator), + |distance, id| { + two_hop_neighbors.push(Neighbor::new(id, distance)); + }, + ) + .await?; + + // Next, insert the new matches into `scratch.best` and increment stats counters + two_hop_neighbors + .iter() + .for_each(|neighbor| scratch.best.insert(*neighbor)); + + scratch.cmps += two_hop_neighbors.len() as u32; + scratch.hops += two_hop_expansion_candidate_ids.len() as u32; + } + + Ok(make_stats(scratch)) +} diff --git a/diskann/src/graph/search/range_search.rs b/diskann/src/graph/search/range_search.rs new file mode 100644 index 000000000..8d878eaff --- /dev/null +++ b/diskann/src/graph/search/range_search.rs @@ -0,0 +1,403 @@ +/* + * Copyright (c) Microsoft Corporation. + * Licensed under the MIT license. + */ + +//! Range-based search within a distance radius. + +use diskann_utils::future::{AssertSend, SendFuture}; +use thiserror::Error; + +use super::{Search, scratch::SearchScratch}; +use crate::{ + ANNError, ANNErrorKind, ANNResult, + error::IntoANNResult, + graph::{ + glue::{self, ExpandBeam, SearchExt, SearchPostProcess, SearchStrategy}, + index::{DiskANNIndex, InternalSearchStats, SearchStats}, + search::record::NoopSearchRecord, + search_output_buffer, + }, + neighbor::Neighbor, + provider::{BuildQueryComputer, DataProvider}, + utils::IntoUsize, +}; + +/// Result from a range search operation. +pub struct RangeSearchOutput { + /// Search statistics. + pub stats: SearchStats, + /// IDs of points within the radius. + pub ids: Vec, + /// Distances corresponding to each ID. + pub distances: Vec, +} + +/// Error type for [`RangeSearch`] parameter validation. +#[derive(Debug, Error)] +pub enum RangeSearchError { + #[error("beam width cannot be zero")] + BeamWidthZero, + #[error("l_value cannot be zero")] + LZero, + #[error("initial_search_slack must be between 0 and 1.0")] + StartingListSlackValueError, + #[error("range_search_slack must be greater than or equal to 1.0")] + RangeSearchSlackValueError, + #[error("inner_radius must be less than or equal to radius")] + InnerRadiusValueError, +} + +impl From for ANNError { + #[track_caller] + fn from(err: RangeSearchError) -> Self { + Self::new(ANNErrorKind::IndexError, err) + } +} + +/// Parameters for range-based search. +/// +/// Finds all points within a specified distance radius from the query. +#[derive(Debug, Clone, Copy)] +pub struct RangeSearch { + /// Maximum results to return (None = unlimited). + max_returned: Option, + /// Initial search list size. + starting_l: usize, + /// Optional beam width. + beam_width: Option, + /// Outer radius - points within this distance are candidates. + radius: f32, + /// Inner radius - points closer than this are excluded. + inner_radius: Option, + /// Slack factor for initial search phase (0.0 to 1.0). + initial_slack: f32, + /// Slack factor for range expansion (>= 1.0). + range_slack: f32, +} + +impl RangeSearch { + /// Create range search with default slack values. + pub fn new(starting_l: usize, radius: f32) -> Result { + Self::with_options(None, starting_l, None, radius, None, 1.0, 1.0) + } + + /// Create range search with full options. + #[allow(clippy::too_many_arguments)] + pub fn with_options( + max_returned: Option, + starting_l: usize, + beam_width: Option, + radius: f32, + inner_radius: Option, + initial_slack: f32, + range_slack: f32, + ) -> Result { + if let Some(bw) = beam_width + && bw == 0 + { + return Err(RangeSearchError::BeamWidthZero); + } + if starting_l == 0 { + return Err(RangeSearchError::LZero); + } + if !(0.0..=1.0).contains(&initial_slack) { + return Err(RangeSearchError::StartingListSlackValueError); + } + if range_slack < 1.0 { + return Err(RangeSearchError::RangeSearchSlackValueError); + } + if let Some(inner) = inner_radius + && inner > radius + { + return Err(RangeSearchError::InnerRadiusValueError); + } + + Ok(Self { + max_returned, + starting_l, + beam_width, + radius, + inner_radius, + initial_slack, + range_slack, + }) + } + + /// Returns the maximum number of results to return. + #[inline] + pub fn max_returned(&self) -> Option { + self.max_returned + } + + /// Returns the initial search list size. + #[inline] + pub fn starting_l(&self) -> usize { + self.starting_l + } + + /// Returns the optional beam width. + #[inline] + pub fn beam_width(&self) -> Option { + self.beam_width + } + + /// Returns the outer radius. + #[inline] + pub fn radius(&self) -> f32 { + self.radius + } + + /// Returns the inner radius (points closer are excluded). + #[inline] + pub fn inner_radius(&self) -> Option { + self.inner_radius + } + + /// Returns the initial search slack factor. + #[inline] + pub fn initial_slack(&self) -> f32 { + self.initial_slack + } + + /// Returns the range search slack factor. + #[inline] + pub fn range_slack(&self) -> f32 { + self.range_slack + } +} + +impl Search for RangeSearch +where + DP: DataProvider, + T: Sync + ?Sized, + S: SearchStrategy, + O: Send + Default + Clone, +{ + type Output = RangeSearchOutput; + + fn search( + &mut self, + index: &DiskANNIndex, + strategy: &S, + context: &DP::Context, + query: &T, + _output: &mut (), + ) -> impl SendFuture> { + let search_params = *self; + async move { + let mut accessor = strategy + .search_accessor(&index.data_provider, context) + .into_ann_result()?; + let computer = accessor.build_query_computer(query).into_ann_result()?; + let start_ids = accessor.starting_points().await?; + + let mut scratch = index.search_scratch(search_params.starting_l(), start_ids.len()); + + let initial_stats = index + .search_internal( + search_params.beam_width(), + &start_ids, + &mut accessor, + &computer, + &mut scratch, + &mut NoopSearchRecord::new(), + ) + .await?; + + let mut in_range = Vec::with_capacity(search_params.starting_l().into_usize()); + + for neighbor in scratch + .best + .iter() + .take(search_params.starting_l().into_usize()) + { + if neighbor.distance <= search_params.radius() { + in_range.push(neighbor); + } + } + + // clear the visited set and repopulate it with just the in-range points + scratch.visited.clear(); + for neighbor in in_range.iter() { + scratch.visited.insert(neighbor.id); + } + scratch.in_range = in_range; + + let stats = if scratch.in_range.len() + >= ((search_params.starting_l() as f32) * search_params.initial_slack()) as usize + { + // Move to range search + let range_stats = range_search_internal( + index.max_degree_with_slack(), + &search_params, + &mut accessor, + &computer, + &mut scratch, + ) + .await?; + + InternalSearchStats { + cmps: initial_stats.cmps, + hops: initial_stats.hops + range_stats.hops, + range_search_second_round: true, + } + } else { + initial_stats + }; + + // Post-process results + let mut result_ids: Vec = vec![O::default(); scratch.in_range.len()]; + let mut result_dists: Vec = vec![f32::MAX; scratch.in_range.len()]; + + let mut output_buffer = search_output_buffer::IdDistance::new( + result_ids.as_mut_slice(), + result_dists.as_mut_slice(), + ); + + let _ = strategy + .post_processor() + .post_process( + &mut accessor, + query, + &computer, + scratch.in_range.iter().copied(), + &mut output_buffer, + ) + .send() + .await + .into_ann_result()?; + + // Filter by inner/outer radius + let inner_cutoff = if let Some(inner_radius) = search_params.inner_radius() { + result_dists + .iter() + .position(|dist| *dist > inner_radius) + .unwrap_or(result_dists.len()) + } else { + 0 + }; + + let outer_cutoff = result_dists + .iter() + .position(|dist| *dist > search_params.radius()) + .unwrap_or(result_dists.len()); + + result_ids.truncate(outer_cutoff); + result_ids.drain(0..inner_cutoff); + + result_dists.truncate(outer_cutoff); + result_dists.drain(0..inner_cutoff); + + let result_count = result_ids.len(); + + Ok(RangeSearchOutput { + stats: SearchStats { + cmps: stats.cmps, + hops: stats.hops, + result_count: result_count as u32, + range_search_second_round: stats.range_search_second_round, + }, + ids: result_ids, + distances: result_dists, + }) + } + } +} + +///////////////////////////// +// Internal Implementation // +///////////////////////////// + +/// Internal range search implementation. +/// +/// Expands the search frontier to find all points within the specified radius. +/// Called after the initial graph search has identified starting candidates. +pub(crate) async fn range_search_internal( + max_degree_with_slack: usize, + search_params: &RangeSearch, + accessor: &mut A, + computer: &A::QueryComputer, + scratch: &mut SearchScratch, +) -> ANNResult +where + I: crate::utils::VectorId, + A: ExpandBeam + SearchExt, + T: ?Sized, +{ + let beam_width = search_params.beam_width().unwrap_or(1); + + for neighbor in &scratch.in_range { + scratch.range_frontier.push_back(neighbor.id); + } + + let mut neighbors = Vec::with_capacity(max_degree_with_slack); + + let max_returned = search_params.max_returned().unwrap_or(usize::MAX); + + while !scratch.range_frontier.is_empty() { + scratch.beam_nodes.clear(); + + // In this loop we are going to find the beam_width number of remaining nodes within the radius + // Each of these nodes will be a frontier node. + while !scratch.range_frontier.is_empty() && scratch.beam_nodes.len() < beam_width { + let next = scratch.range_frontier.pop_front(); + if let Some(next_node) = next { + scratch.beam_nodes.push(next_node); + } + } + + neighbors.clear(); + accessor + .expand_beam( + scratch.beam_nodes.iter().copied(), + computer, + glue::NotInMut::new(&mut scratch.visited), + |distance, id| neighbors.push(Neighbor::new(id, distance)), + ) + .await?; + + // The predicate ensures that the contents of `neighbors` are unique. + for neighbor in neighbors.iter() { + if neighbor.distance <= search_params.radius() * search_params.range_slack() + && scratch.in_range.len() < max_returned + { + scratch.in_range.push(*neighbor); + scratch.range_frontier.push_back(neighbor.id); + } + } + scratch.cmps += neighbors.len() as u32; + scratch.hops += scratch.beam_nodes.len() as u32; + } + + Ok(InternalSearchStats { + cmps: scratch.cmps, + hops: scratch.hops, + range_search_second_round: true, + }) +} + +/////////// +// Tests // +/////////// + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_range_search_validation() { + // Valid + assert!(RangeSearch::new(100, 0.5).is_ok()); + + // Invalid: zero l + assert!(RangeSearch::new(0, 0.5).is_err()); + + // Invalid slack values + assert!(RangeSearch::with_options(None, 100, None, 0.5, None, 1.5, 1.0).is_err()); + assert!(RangeSearch::with_options(None, 100, None, 0.5, None, 1.0, 0.5).is_err()); + + // Invalid inner radius > radius + assert!(RangeSearch::with_options(None, 100, None, 0.5, Some(1.0), 1.0, 1.0).is_err()); + } +} diff --git a/diskann/src/graph/test/cases/grid.rs b/diskann/src/graph/test/cases/grid.rs index 2ea40b677..4c4fe047f 100644 --- a/diskann/src/graph/test/cases/grid.rs +++ b/diskann/src/graph/test/cases/grid.rs @@ -9,7 +9,7 @@ use diskann_vector::distance::Metric; use crate::{ graph::{ - self, DiskANNIndex, + self, DiskANNIndex, KnnSearch, test::{provider as test_provider, synthetic::Grid}, }, neighbor::Neighbor, @@ -126,10 +126,10 @@ fn _grid_search(grid: Grid, size: usize, mut parent: TestPath<'_>) { // are correct. let index = setup_grid_search(grid, size); - let params = graph::SearchParams::new(10, 10, Some(beam_width)).unwrap(); + let mut params = KnnSearch::new(10, 10, Some(beam_width)).unwrap(); let context = test_provider::Context::new(); - let mut neighbors = vec![Neighbor::::default(); params.k_value]; + let mut neighbors = vec![Neighbor::::default(); params.k_value().get()]; let graph::index::SearchStats { cmps, hops, @@ -137,17 +137,17 @@ fn _grid_search(grid: Grid, size: usize, mut parent: TestPath<'_>) { range_search_second_round, } = rt .block_on(index.search( + &mut params, &test_provider::Strategy::new(), &context, query.as_slice(), - ¶ms, &mut crate::neighbor::BackInserter::new(neighbors.as_mut_slice()), )) .unwrap(); assert_eq!( result_count.into_usize(), - params.k_value, + params.k_value().get(), "grid search should be configured to always return the requested number of neighbors", );