diff --git a/.gitignore b/.gitignore index 948e2e563..946937707 100644 --- a/.gitignore +++ b/.gitignore @@ -333,4 +333,4 @@ target/ *.info # ignore VS Code local history -.history/ \ No newline at end of file +.history/ diff --git a/Cargo.lock b/Cargo.lock index 426fe8795..d830c9aed 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -110,7 +110,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "8cce2075c711f351f0aa52c05e645cc41f1b3cc0cdba1ad12c5f67c121c1bb7d" dependencies = [ "cfg-if", - "io-uring", + "io-uring 0.6.4", "libc", "rand 0.8.5", "rand 0.9.2", @@ -531,7 +531,7 @@ dependencies = [ "half", "hashbrown 0.16.1", "iai-callgrind", - "io-uring", + "io-uring 0.7.11", "libc", "opentelemetry", "rand 0.9.2", @@ -586,7 +586,7 @@ dependencies = [ name = "diskann-platform" version = "0.45.0" dependencies = [ - "io-uring", + "io-uring 0.7.11", "libc", "tracing", "windows-sys 0.59.0", @@ -1294,6 +1294,17 @@ dependencies = [ "libc", ] +[[package]] +name = "io-uring" +version = "0.7.11" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "fdd7bddefd0a8833b88a4b68f90dae22c7450d11b354198baee3874fd811b344" +dependencies = [ + "bitflags 2.10.0", + "cfg-if", + "libc", +] + [[package]] name = "is-terminal" version = "0.4.17" diff --git a/diskann-benchmark/example/pipe-search.json b/diskann-benchmark/example/pipe-search.json new file mode 100644 index 000000000..6b0a97871 --- /dev/null +++ b/diskann-benchmark/example/pipe-search.json @@ -0,0 +1,79 @@ +{ + "search_directories": [ + "test_data/disk_index_search" + ], + "jobs": [ + { + "type": "disk-index", + "content": { + "source": { + "disk-index-source": "Load", + "data_type": "float32", + "load_path": "test_data/disk_index_search/disk_index_sift_learn_R4_L50_A1.2_truth_search" + }, + "search_phase": { + "queries": "disk_index_sample_query_10pts.fbin", + "groundtruth": "disk_index_10pts_idx_uint32_truth_search_res.bin", + "search_list": [10, 20, 40, 80], + "beam_width": 4, + "recall_at": 10, + "num_threads": 1, + "is_flat_search": false, + "distance": "squared_l2", + "vector_filters_file": null, + "search_mode": { + "mode": "BeamSearch" + } + } + } + }, + { + "type": "disk-index", + "content": { + "source": { + "disk-index-source": "Load", + "data_type": "float32", + "load_path": "test_data/disk_index_search/disk_index_sift_learn_R4_L50_A1.2_truth_search" + }, + "search_phase": { + "queries": "disk_index_sample_query_10pts.fbin", + "groundtruth": "disk_index_10pts_idx_uint32_truth_search_res.bin", + "search_list": [10, 20, 40, 80], + "beam_width": 4, + "recall_at": 10, + "num_threads": 1, + "is_flat_search": false, + "distance": "squared_l2", + "vector_filters_file": null, + "search_mode": { + "mode": "PipeSearch" + } + } + } + }, + { + "type": "disk-index", + "content": { + "source": { + "disk-index-source": "Load", + "data_type": "float32", + "load_path": "test_data/disk_index_search/disk_index_sift_learn_R4_L50_A1.2_truth_search" + }, + "search_phase": { + "queries": "disk_index_sample_query_10pts.fbin", + "groundtruth": "disk_index_10pts_idx_uint32_truth_search_res.bin", + "search_list": [10, 20, 40, 80], + "beam_width": 4, + "recall_at": 10, + "num_threads": 1, + "is_flat_search": false, + "distance": "squared_l2", + "vector_filters_file": null, + "search_mode": { + "mode": "PipeSearch" + } + } + } + } + ] +} diff --git a/diskann-benchmark/src/backend/disk_index/search.rs b/diskann-benchmark/src/backend/disk_index/search.rs index 65e5804a7..7f8f7b9e2 100644 --- 
a/diskann-benchmark/src/backend/disk_index/search.rs +++ b/diskann-benchmark/src/backend/disk_index/search.rs @@ -4,18 +4,24 @@ */ use rayon::prelude::*; -use std::{collections::HashSet, fmt, sync::atomic::AtomicBool, time::Instant}; +use std::{collections::HashSet, fmt, sync::atomic::AtomicBool, sync::Arc, time::Instant}; use opentelemetry::{global, trace::Span, trace::Tracer}; use opentelemetry_sdk::trace::SdkTracerProvider; -use diskann::utils::VectorRepr; +use diskann::{utils::VectorRepr, ANNResult}; use diskann_benchmark_runner::{files::InputFile, utils::MicroSeconds}; +#[cfg(target_os = "linux")] +use diskann_disk::search::provider::pipelined_accessor::PipelinedConfig; +#[cfg(target_os = "linux")] +use diskann_disk::storage::PipelinedReaderConfig; use diskann_disk::{ data_model::CachingStrategy, search::provider::{ - disk_provider::DiskIndexSearcher, disk_vertex_provider_factory::DiskVertexProviderFactory, + disk_provider::{DiskIndexSearcher, SearchResult}, + disk_vertex_provider_factory::DiskVertexProviderFactory, }, + search::traits::VertexProviderFactory, storage::disk_index_reader::DiskIndexReader, utils::{instrumentation::PerfLogger, statistics, AlignedFileReaderFactory, QueryStatistics}, }; @@ -32,7 +38,7 @@ use serde::Serialize; use crate::{ backend::disk_index::{graph_data_type::GraphData, json_spancollector::JsonSpanCollector}, - inputs::disk::{DiskIndexLoad, DiskSearchPhase}, + inputs::disk::{DiskIndexLoad, DiskSearchPhase, SearchMode}, utils::{datafiles, SimilarityMeasure}, }; @@ -44,6 +50,7 @@ pub(super) struct DiskSearchStats { pub(crate) is_flat_search: bool, pub(crate) distance: SimilarityMeasure, pub(crate) uses_vector_filters: bool, + pub(super) search_mode: String, pub(super) num_nodes_to_cache: Option, pub(super) search_results_per_l: Vec, span_metrics: serde_json::Value, @@ -154,6 +161,101 @@ impl DiskSearchResult { } } +/// Write a single query's search result into pre-allocated buffers. +fn write_query_result( + result: ANNResult>, + recall_at: usize, + stats: &mut QueryStatistics, + rc: &mut u32, + id_chunk: &mut [u32], + dist_chunk: &mut [f32], + has_any_search_failed: &AtomicBool, + error_label: &str, +) { + match result { + Ok(search_result) => { + *stats = search_result.stats.query_statistics; + *rc = search_result.results.len() as u32; + let actual_results = search_result.results.len().min(recall_at); + for (i, result_item) in search_result + .results + .iter() + .take(actual_results) + .enumerate() + { + id_chunk[i] = result_item.vertex_id; + dist_chunk[i] = result_item.distance; + } + } + Err(e) => { + eprintln!("{} failed for query: {:?}", error_label, e); + *rc = 0; + id_chunk.fill(0); + dist_chunk.fill(0.0); + has_any_search_failed.store(true, std::sync::atomic::Ordering::Release); + } + } +} + +/// Execute the per-L search iteration loop, handling buffer allocation, timing, +/// span management, error checking, and result aggregation. 
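// Illustrative call shape only (it mirrors the BeamSearch arm further below): the
// closure passed as `iteration_body` fills the per-query buffers in place for the
// current L value, while `run_search_loop` handles timing, spans, failure checks,
// and aggregation into DiskSearchResult values.
//
//   let results = run_search_loop(
//       &search_params.search_list,
//       search_params.recall_at,
//       search_params.beam_width,
//       num_queries,
//       "search",
//       &has_any_search_failed,
//       &gt_context,
//       |l, statistics_vec, result_counts, result_ids, result_dists| {
//           // run one search per query at list size `l`, writing ids/distances into
//           // the recall_at-sized chunks and per-query stats/counts
//       },
//   )?;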
+fn run_search_loop( + search_list: &[u32], + recall_at: u32, + beam_width: usize, + num_queries: usize, + span_prefix: &str, + has_any_search_failed: &AtomicBool, + gt_context: &GroundTruthContext, + mut iteration_body: impl FnMut(u32, &mut [QueryStatistics], &mut [u32], &mut [u32], &mut [f32]), +) -> anyhow::Result> { + let mut results = Vec::with_capacity(search_list.len()); + + for &l in search_list.iter() { + let mut statistics_vec = vec![QueryStatistics::default(); num_queries]; + let mut result_counts = vec![0u32; num_queries]; + let mut result_ids = vec![0u32; (recall_at as usize) * num_queries]; + let mut result_dists = vec![0.0f32; (recall_at as usize) * num_queries]; + + let start = Instant::now(); + + let mut l_span = { + let tracer = global::tracer(""); + let span_name = format!("{}-with-L={}-bw={}", span_prefix, l, beam_width); + tracer.start(span_name) + }; + + iteration_body( + l, + &mut statistics_vec, + &mut result_counts, + &mut result_ids, + &mut result_dists, + ); + + let total_time = start.elapsed(); + + if has_any_search_failed.load(std::sync::atomic::Ordering::Acquire) { + anyhow::bail!("One or more searches failed. See logs for details."); + } + + let search_result = DiskSearchResult::new( + &statistics_vec, + &result_ids, + &result_counts, + l, + total_time.as_secs_f32(), + num_queries, + gt_context, + )?; + + l_span.end(); + results.push(search_result); + } + + Ok(results) +} + pub(super) fn search_disk_index( index_load: &DiskIndexLoad, search_params: &DiskSearchPhase, @@ -214,113 +316,189 @@ where CachingStrategy::None }; - let reader_factory = AlignedFileReaderFactory::new(disk_index_path); + let reader_factory = AlignedFileReaderFactory::new(disk_index_path.clone()); let vertex_provider_factory = DiskVertexProviderFactory::new(reader_factory, caching_strategy)?; - let searcher = &DiskIndexSearcher::, _>::new( - search_params.num_threads, - if let Some(lim) = search_params.search_io_limit { - lim - } else { - usize::MAX - }, - &index_reader, - vertex_provider_factory, - search_params.distance.into(), - None, - )?; - - logger.log_checkpoint("index_loaded"); - let pool = create_thread_pool(search_params.num_threads)?; - let mut search_results_per_l = Vec::with_capacity(search_params.search_list.len()); + let search_results_per_l; let has_any_search_failed = AtomicBool::new(false); - // Execute search iterations - for &l in search_params.search_list.iter() { - let mut statistics_vec: Vec = - vec![QueryStatistics::default(); num_queries]; - let mut result_counts: Vec = vec![0; num_queries]; - let mut result_ids: Vec = vec![0; (search_params.recall_at as usize) * num_queries]; - let mut result_dists: Vec = - vec![0.0; (search_params.recall_at as usize) * num_queries]; - - let start = Instant::now(); - - let mut l_span = { - let tracer = global::tracer(""); - let span_name = format!("search-with-L={}-bw={}", l, search_params.beam_width); - tracer.start(span_name) - }; + match &search_params.search_mode { + SearchMode::BeamSearch => { + let searcher = &DiskIndexSearcher::, _>::new( + search_params.num_threads, + search_params.search_io_limit.unwrap_or(usize::MAX), + &index_reader, + vertex_provider_factory, + search_params.distance.into(), + None, + )?; - let zipped = queries - .par_row_iter() - .zip(vector_filters.par_iter()) - .zip(result_ids.par_chunks_mut(search_params.recall_at as usize)) - .zip(result_dists.par_chunks_mut(search_params.recall_at as usize)) - .zip(statistics_vec.par_iter_mut()) - .zip(result_counts.par_iter_mut()); - - 
zipped.for_each_in_pool(&pool, |(((((q, vf), id_chunk), dist_chunk), stats), rc)| { - let vector_filter = if search_params.vector_filters_file.is_none() { - None - } else { - Some(Box::new(move |vid: &u32| vf.contains(vid)) - as Box bool + Send + Sync>) - }; + logger.log_checkpoint("index_loaded"); - match searcher.search( - q, + search_results_per_l = run_search_loop( + &search_params.search_list, search_params.recall_at, - l, - Some(search_params.beam_width), - vector_filter, - search_params.is_flat_search, - ) { - Ok(search_result) => { - *stats = search_result.stats.query_statistics; - *rc = search_result.results.len() as u32; - let actual_results = search_result - .results - .len() - .min(search_params.recall_at as usize); - for (i, result_item) in search_result - .results - .iter() - .take(actual_results) - .enumerate() - { - id_chunk[i] = result_item.vertex_id; - dist_chunk[i] = result_item.distance; - } - } - Err(e) => { - eprintln!("Search failed for query: {:?}", e); - *rc = 0; - id_chunk.fill(0); - dist_chunk.fill(0.0); - has_any_search_failed.store(true, std::sync::atomic::Ordering::Release); - } + search_params.beam_width, + num_queries, + "search", + &has_any_search_failed, + >_context, + |l, statistics_vec, result_counts, result_ids, result_dists| { + let zipped = queries + .par_row_iter() + .zip(vector_filters.par_iter()) + .zip(result_ids.par_chunks_mut(search_params.recall_at as usize)) + .zip(result_dists.par_chunks_mut(search_params.recall_at as usize)) + .zip(statistics_vec.par_iter_mut()) + .zip(result_counts.par_iter_mut()); + + zipped.for_each_in_pool( + &pool, + |(((((q, vf), id_chunk), dist_chunk), stats), rc)| { + let vector_filter = if search_params.vector_filters_file.is_none() { + None + } else { + Some(Box::new(move |vid: &u32| vf.contains(vid)) + as Box bool + Send + Sync>) + }; + + write_query_result( + searcher.search( + q, + search_params.recall_at, + l, + Some(search_params.beam_width), + vector_filter, + search_params.is_flat_search, + ), + search_params.recall_at as usize, + stats, + rc, + id_chunk, + dist_chunk, + &has_any_search_failed, + "Search", + ); + }, + ); + }, + )?; + } + // Pipelined search — for read-only search on completed (static) indices only. + // Uses io_uring for IO/compute overlap through the generic search loop. 
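// For reference, selecting this arm from a benchmark job is a one-line change to the
// search phase (a sketch based on the SearchMode enum in diskann-benchmark/src/inputs/disk.rs;
// the 1000 ms value is purely illustrative):
//
//   "search_mode": { "mode": "PipeSearch", "sqpoll_idle_ms": 1000 }
//
// Omitting "sqpoll_idle_ms" leaves kernel-side SQ polling disabled, and omitting
// "search_mode" entirely falls back to the default BeamSearch.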
+ SearchMode::PipeSearch { sqpoll_idle_ms } => { + #[cfg(target_os = "linux")] + { + use diskann::utils::object_pool::ObjectPool; + use diskann_disk::data_model::Cache; + use diskann_disk::search::provider::pipelined_accessor::{ + PipelinedScratch, PipelinedScratchArgs, + }; + + let reader_config = PipelinedReaderConfig { + sqpoll_idle_ms: *sqpoll_idle_ms, + }; + + // Extract the node cache before moving vertex_provider_factory into the searcher + let node_cache: Arc>> = vertex_provider_factory + .cache + .clone() + .unwrap_or_else(|| Arc::new(Cache::new(0, 0).expect("empty cache"))); + + // Derive pool args from the graph header before moving factory into searcher + let graph_header = vertex_provider_factory.get_header()?; + let pq_data = index_reader.get_pq_data(); + let metadata = graph_header.metadata(); + let block_size = graph_header.effective_block_size(); + let num_sectors_per_node = graph_header.num_sectors_per_node(); + let slot_size = num_sectors_per_node * block_size; + let bw = search_params.beam_width; + let max_slots = (bw * 2).clamp(16, 128); + + let scratch_args = PipelinedScratchArgs { + disk_index_path: disk_index_path.clone(), + max_slots, + slot_size, + alignment: block_size, + graph_degree: graph_header.max_degree::()?, + dims: metadata.dims, + num_pq_chunks: pq_data.get_num_chunks(), + num_pq_centers: pq_data.get_num_centers(), + reader_config, + }; + let scratch_pool = Arc::new(ObjectPool::::try_new( + scratch_args.clone(), + 0, + None, + )?); + + let mut searcher = DiskIndexSearcher::, _>::new( + search_params.num_threads, + search_params.search_io_limit.unwrap_or(usize::MAX), + &index_reader, + vertex_provider_factory, + search_params.distance.into(), + None, + )?; + + searcher.with_pipelined_config(PipelinedConfig { + beam_width: search_params.beam_width, + node_cache, + scratch_pool, + scratch_args, + }); + + let searcher = &searcher; + + logger.log_checkpoint("index_loaded"); + + search_results_per_l = run_search_loop( + &search_params.search_list, + search_params.recall_at, + search_params.beam_width, + num_queries, + "pipesearch", + &has_any_search_failed, + >_context, + |l, statistics_vec, result_counts, result_ids, result_dists| { + let zipped = queries + .par_row_iter() + .zip(result_ids.par_chunks_mut(search_params.recall_at as usize)) + .zip(result_dists.par_chunks_mut(search_params.recall_at as usize)) + .zip(statistics_vec.par_iter_mut()) + .zip(result_counts.par_iter_mut()); + + zipped.for_each_in_pool( + &pool, + |((((q, id_chunk), dist_chunk), stats), rc)| { + write_query_result( + searcher.search_pipelined( + q, + search_params.recall_at, + l, + search_params.beam_width, + None, + ), + search_params.recall_at as usize, + stats, + rc, + id_chunk, + dist_chunk, + &has_any_search_failed, + "PipeSearch", + ); + }, + ); + }, + )?; + } + #[cfg(not(target_os = "linux"))] + { + let _ = sqpoll_idle_ms; + anyhow::bail!("PipeSearch is only supported on Linux"); } - }); - let total_time = start.elapsed(); - - if has_any_search_failed.load(std::sync::atomic::Ordering::Acquire) { - anyhow::bail!("One or more searches failed. 
See logs for details."); } - - let search_result = DiskSearchResult::new( - &statistics_vec, - &result_ids, - &result_counts, - l, - total_time.as_secs_f32(), - num_queries, - >_context, - )?; - - l_span.end(); - search_results_per_l.push(search_result); } // Log search completed checkpoint @@ -343,6 +521,7 @@ where is_flat_search: search_params.is_flat_search, distance: search_params.distance, uses_vector_filters: search_params.vector_filters_file.is_some(), + search_mode: format!("{}", search_params.search_mode), num_nodes_to_cache: search_params.num_nodes_to_cache, search_results_per_l, span_metrics, @@ -427,6 +606,7 @@ impl fmt::Display for DiskSearchStats { writeln!(f, "Flat search, : {}", self.is_flat_search)?; writeln!(f, "Distance, : {}", self.distance)?; writeln!(f, "Vector filters, : {}", self.uses_vector_filters)?; + writeln!(f, "Search mode, : {}", self.search_mode)?; writeln!( f, "Nodes to cache, : {}", diff --git a/diskann-benchmark/src/inputs/disk.rs b/diskann-benchmark/src/inputs/disk.rs index bf843d72f..cf995cd09 100644 --- a/diskann-benchmark/src/inputs/disk.rs +++ b/diskann-benchmark/src/inputs/disk.rs @@ -71,6 +71,39 @@ pub(crate) struct DiskIndexBuild { pub(crate) save_path: String, } +/// Search algorithm to use for disk index search. +#[derive(Debug, Serialize, Deserialize, Clone)] +#[serde(tag = "mode")] +#[derive(Default)] +pub(crate) enum SearchMode { + /// Standard beam search (default, current behavior). + #[default] + BeamSearch, + /// Pipelined search through the generic search loop (queue-based ExpandBeam). + /// Overlaps IO and compute using io_uring on Linux. + #[serde(alias = "UnifiedPipeSearch")] + PipeSearch { + /// Enable kernel-side SQ polling (ms idle timeout). None = disabled. + #[serde(default)] + sqpoll_idle_ms: Option, + }, +} + +impl fmt::Display for SearchMode { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + match self { + SearchMode::BeamSearch => write!(f, "BeamSearch"), + SearchMode::PipeSearch { sqpoll_idle_ms } => { + write!(f, "PipeSearch")?; + if let Some(sq) = sqpoll_idle_ms { + write!(f, "(sqpoll={}ms)", sq)?; + } + Ok(()) + } + } + } +} + /// Search phase configuration #[derive(Debug, Deserialize, Serialize)] pub(crate) struct DiskSearchPhase { @@ -85,6 +118,9 @@ pub(crate) struct DiskSearchPhase { pub(crate) vector_filters_file: Option, pub(crate) num_nodes_to_cache: Option, pub(crate) search_io_limit: Option, + /// Search algorithm to use (defaults to BeamSearch). + #[serde(default)] + pub(crate) search_mode: SearchMode, } ///////// @@ -234,6 +270,10 @@ impl CheckDeserialization for DiskSearchPhase { anyhow::bail!("search_io_limit must be positive if specified"); } } + match &self.search_mode { + SearchMode::BeamSearch => {} + SearchMode::PipeSearch { .. 
} => {} + } Ok(()) } } @@ -272,6 +312,7 @@ impl Example for DiskIndexOperation { vector_filters_file: None, num_nodes_to_cache: None, search_io_limit: None, + search_mode: SearchMode::default(), }; Self { @@ -397,6 +438,7 @@ impl DiskSearchPhase { Some(lim) => write_field!(f, "Search IO Limit", format!("{lim}"))?, None => write_field!(f, "Search IO Limit", "none (defaults to `usize::MAX`)")?, } + write_field!(f, "Search Mode", format!("{:?}", self.search_mode))?; Ok(()) } } diff --git a/diskann-disk/Cargo.toml b/diskann-disk/Cargo.toml index c68d65769..204cf36e0 100644 --- a/diskann-disk/Cargo.toml +++ b/diskann-disk/Cargo.toml @@ -47,7 +47,7 @@ vfs = { workspace = true } opentelemetry = { workspace = true, optional = true } [target.'cfg(target_os = "linux")'.dependencies] -io-uring = "0.6.4" +io-uring = "0.7" libc = "0.2.148" [dev-dependencies] diff --git a/diskann-disk/src/data_model/graph_header.rs b/diskann-disk/src/data_model/graph_header.rs index f04803e4a..67999ad36 100644 --- a/diskann-disk/src/data_model/graph_header.rs +++ b/diskann-disk/src/data_model/graph_header.rs @@ -12,6 +12,7 @@ use thiserror::Error; use super::{GraphLayoutVersion, GraphMetadata}; /// GraphHeader. The header is stored in the first sector of the disk index file, or the first segment of the JET stream. +#[derive(Clone)] pub struct GraphHeader { // Graph metadata. metadata: GraphMetadata, @@ -85,6 +86,28 @@ impl GraphHeader { &self.layout_version } + /// Returns the effective block size, falling back to the default (4096) for + /// legacy (v0.0) layouts or when the stored value is zero. + pub fn effective_block_size(&self) -> usize { + let bs = self.block_size as usize; + if (self.layout_version.major_version() == 0 && self.layout_version.minor_version() == 0) + || bs == 0 + { + 4096 + } else { + bs + } + } + + /// Returns the number of disk sectors required to store a single graph node. + pub fn num_sectors_per_node(&self) -> usize { + if self.metadata.num_nodes_per_block > 0 { + 1 + } else { + (self.metadata.node_len as usize).div_ceil(self.effective_block_size()) + } + } + /// Returns the maximum degree of the graph /// /// # Type Parameters diff --git a/diskann-disk/src/search/mod.rs b/diskann-disk/src/search/mod.rs index 915956ad4..1f0d8f148 100644 --- a/diskann-disk/src/search/mod.rs +++ b/diskann-disk/src/search/mod.rs @@ -7,3 +7,5 @@ pub mod provider; pub mod traits; + +pub(crate) mod sector_math; diff --git a/diskann-disk/src/search/provider/disk_provider.rs b/diskann-disk/src/search/provider/disk_provider.rs index ab0a4f4e7..a49abe22d 100644 --- a/diskann-disk/src/search/provider/disk_provider.rs +++ b/diskann-disk/src/search/provider/disk_provider.rs @@ -71,19 +71,19 @@ where Data: GraphDataType, { /// Holds the graph header information that contains metadata about disk-index file. - graph_header: GraphHeader, + pub(crate) graph_header: GraphHeader, // Full precision distance comparer used in post_process to reorder results. - distance_comparer: ::Distance, + pub(crate) distance_comparer: ::Distance, /// The PQ data used for quantization. - pq_data: Arc, + pub(crate) pq_data: Arc, /// The number of points in the graph. - num_points: usize, + pub(crate) num_points: usize, /// Metric used for distance computation. - metric: Metric, + pub(crate) metric: Metric, /// The number of IO operations that can be done in parallel. search_io_limit: usize, @@ -373,8 +373,8 @@ where /// The query computer for the disk provider. This is used to compute the distance between the query vector and the PQ coordinates. 
pub struct DiskQueryComputer { - num_pq_chunks: usize, - query_centroid_l2_distance: Vec, + pub(crate) num_pq_chunks: usize, + pub(crate) query_centroid_l2_distance: Vec, } impl PreprocessedDistanceFunction<&[u8], f32> for DiskQueryComputer { @@ -783,14 +783,18 @@ pub struct DiskIndexSearcher< Data: GraphDataType, ProviderFactory: VertexProviderFactory, { - index: DiskANNIndex>, - runtime: Runtime, + pub(crate) index: DiskANNIndex>, + pub(crate) runtime: Runtime, /// The vertex provider factory is used to create the vertex provider for each search instance. vertex_provider_factory: ProviderFactory, /// Scratch pool for disk search operations that need allocations. scratch_pool: Arc>>, + + /// Optional pipelined search configuration (Linux only, io_uring-based). + #[cfg(target_os = "linux")] + pub(crate) pipelined_config: Option>, } #[derive(Debug)] @@ -891,6 +895,8 @@ where runtime, vertex_provider_factory, scratch_pool, + #[cfg(target_os = "linux")] + pipelined_config: None, }) } @@ -959,6 +965,84 @@ where Ok(search_result) } + /// Perform a search with explicit [`SearchParams`] for full control over + /// adaptive beam width, relaxed monotonicity, etc. + pub fn search_with_params( + &self, + query: &[Data::VectorDataType], + search_params: &SearchParams, + vector_filter: Option>, + is_flat_search: bool, + ) -> ANNResult> { + let k_value = search_params.k_value; + let mut query_stats = QueryStatistics::default(); + let mut indices = vec![0u32; k_value]; + let mut distances = vec![0f32; k_value]; + let mut associated_data = vec![Data::AssociatedDataType::default(); k_value]; + + let mut result_output_buffer = search_output_buffer::IdDistanceAssociatedData::new( + &mut indices[..k_value], + &mut distances[..k_value], + &mut associated_data[..k_value], + ); + + let filter = vector_filter.unwrap_or(default_vector_filter::()); + let strategy = self.search_strategy(query, &filter); + let timer = Instant::now(); + let stats = if is_flat_search { + self.runtime.block_on(self.index.flat_search( + &strategy, + &DefaultContext, + strategy.query, + &filter, + search_params, + &mut result_output_buffer, + ))? + } else { + self.runtime.block_on(self.index.search( + &strategy, + &DefaultContext, + strategy.query, + search_params, + &mut result_output_buffer, + ))? + }; + query_stats.total_comparisons = stats.cmps; + query_stats.search_hops = stats.hops; + query_stats.total_execution_time_us = timer.elapsed().as_micros(); + query_stats.io_time_us = IOTracker::time(&strategy.io_tracker.io_time_us) as u128; + query_stats.total_io_operations = strategy.io_tracker.io_count() as u32; + query_stats.total_vertices_loaded = strategy.io_tracker.io_count() as u32; + query_stats.query_pq_preprocess_time_us = + IOTracker::time(&strategy.io_tracker.preprocess_time_us) as u128; + query_stats.cpu_time_us = query_stats.total_execution_time_us + - query_stats.io_time_us + - query_stats.query_pq_preprocess_time_us; + + let mut search_result = SearchResult { + results: Vec::with_capacity(k_value), + stats: SearchResultStats { + cmps: query_stats.total_comparisons, + result_count: stats.result_count, + query_statistics: query_stats, + }, + }; + + for ((vertex_id, distance), associated_data) in indices + .into_iter() + .zip(distances.into_iter()) + .zip(associated_data.into_iter()) + { + search_result.results.push(SearchResultItem { + vertex_id, + distance, + data: associated_data, + }); + } + + Ok(search_result) + } + /// Perform a raw search on the disk index. 
/// This is a lower-level API that allows more control over the search parameters and output buffers. #[allow(clippy::too_many_arguments)] diff --git a/diskann-disk/src/search/provider/disk_sector_graph.rs b/diskann-disk/src/search/provider/disk_sector_graph.rs index 1f00ad6db..700cf84df 100644 --- a/diskann-disk/src/search/provider/disk_sector_graph.rs +++ b/diskann-disk/src/search/provider/disk_sector_graph.rs @@ -15,9 +15,7 @@ use crate::{ utils::aligned_file_reader::{traits::AlignedFileReader, AlignedRead}, }; -const DEFAULT_DISK_SECTOR_LEN: usize = 4096; - -/// Sector graph read from disk index +/// Sector graph read from diskindex pub struct DiskSectorGraph { /// Ensure `sector_reader` is dropped before `sectors_data` by placing it before `sectors_data`. /// Graph storage to read sectors @@ -57,19 +55,11 @@ impl DiskSectorGraph { header: &GraphHeader, max_n_batch_sector_read: usize, ) -> ANNResult { - let mut block_size = header.block_size() as usize; - let version = header.layout_version(); - if (version.major_version() == 0 && version.minor_version() == 0) || block_size == 0 { - block_size = DEFAULT_DISK_SECTOR_LEN; - } + let block_size = header.effective_block_size(); let num_nodes_per_sector = header.metadata().num_nodes_per_block; let node_len = header.metadata().node_len; - let num_sectors_per_node = if num_nodes_per_sector > 0 { - 1 - } else { - (node_len as usize).div_ceil(block_size) - }; + let num_sectors_per_node = header.num_sectors_per_node(); Ok(Self { sector_reader, @@ -152,23 +142,21 @@ impl DiskSectorGraph { /// Get offset of node in sectors_data #[inline] fn get_node_offset(&self, vertex_id: u32) -> usize { - if self.num_nodes_per_sector == 0 { - // multi-sector node - 0 - } else { - // multi node in a sector - (vertex_id as u64 % self.num_nodes_per_sector * self.node_len) as usize - } + crate::search::sector_math::node_offset_in_sector( + vertex_id, + self.num_nodes_per_sector, + self.node_len, + ) } #[inline] /// Gets the index for the sector that contains the node with the given vertex_id pub fn node_sector_index(&self, vertex_id: u32) -> u64 { - 1 + if self.num_nodes_per_sector > 0 { - vertex_id as u64 / self.num_nodes_per_sector - } else { - vertex_id as u64 * self.num_sectors_per_node as u64 - } + crate::search::sector_math::node_sector_index( + vertex_id, + self.num_nodes_per_sector, + self.num_sectors_per_node, + ) } } @@ -190,6 +178,8 @@ mod disk_sector_graph_test { use super::*; use crate::data_model::{GraphLayoutVersion, GraphMetadata}; + const DEFAULT_DISK_SECTOR_LEN: usize = 4096; + fn test_index_path() -> String { test_data_root() .join("disk_index_misc/disk_index_siftsmall_learn_256pts_R4_L50_A1.2_aligned_reader_test.index") diff --git a/diskann-disk/src/search/provider/mod.rs b/diskann-disk/src/search/provider/mod.rs index 69f697d84..2a168522a 100644 --- a/diskann-disk/src/search/provider/mod.rs +++ b/diskann-disk/src/search/provider/mod.rs @@ -13,3 +13,6 @@ pub mod disk_provider; pub mod disk_sector_graph; pub mod disk_vertex_provider; pub mod disk_vertex_provider_factory; + +#[cfg(target_os = "linux")] +pub mod pipelined_accessor; diff --git a/diskann-disk/src/search/provider/pipelined_accessor.rs b/diskann-disk/src/search/provider/pipelined_accessor.rs new file mode 100644 index 000000000..d875f5b3f --- /dev/null +++ b/diskann-disk/src/search/provider/pipelined_accessor.rs @@ -0,0 +1,1011 @@ +/* + * Copyright (c) Microsoft Corporation. + * Licensed under the MIT license. + */ + +//! 
Queue-based pipelined disk accessor that integrates with the generic search loop +//! via the `ExpandBeam` trait's `submit_expand` / `expand_available` / `has_pending` methods. +//! +//! Plugs into `DiskANNIndex::search_internal()` and overlaps IO with computation +//! using io_uring under the hood. + +use std::collections::{HashMap, VecDeque}; +use std::future::Future; +use std::ops::Range; +use std::sync::atomic::{AtomicU32, Ordering}; +use std::sync::Arc; +use std::time::Instant; + +use byteorder::{ByteOrder, LittleEndian}; +use diskann::{ + graph::{ + glue::{ + ExpandBeam, HybridPredicate, IdIterator, SearchExt, SearchPostProcess, SearchStrategy, + }, + search_output_buffer, AdjacencyList, SearchOutputBuffer, SearchParams, + }, + neighbor::Neighbor, + provider::{ + Accessor, BuildQueryComputer, DefaultContext, DelegateNeighbor, HasId, NeighborAccessor, + }, + utils::object_pool::{ObjectPool, PoolOption, TryAsPooled}, + ANNError, ANNResult, +}; +use diskann_providers::model::{ + compute_pq_distance, graph::traits::GraphDataType, pq::quantizer_preprocess, PQScratch, +}; +use diskann_vector::DistanceFunction; + +use crate::data_model::Cache; +use crate::storage::{PipelinedReader, PipelinedReaderConfig}; + +use crate::search::sector_math::{node_offset_in_sector, node_sector_index}; +use crate::search::traits::VertexProviderFactory; +use crate::utils::QueryStatistics; + +use super::disk_provider::{ + DiskIndexSearcher, DiskProvider, DiskQueryComputer, SearchResult, SearchResultItem, + SearchResultStats, +}; + +/// A loaded node parsed from sector data. +struct LoadedNode { + fp_vector: Vec, + adjacency_list: Vec, + /// Submission rank (lower = higher priority / submitted earlier). + /// Nodes submitted first via closest_notvisited() have better PQ distance, + /// so expanding them first (like PipeSearch) improves search quality. + rank: u64, +} + +impl LoadedNode { + /// Reset and fill from sector buffer, reusing existing Vec capacity. + fn parse_from( + &mut self, + sector_buf: &[u8], + vertex_id: u32, + num_nodes_per_sector: u64, + node_len: u64, + fp_vector_len: u64, + rank: u64, + ) -> ANNResult<()> { + let offset = node_offset_in_sector(vertex_id, num_nodes_per_sector, node_len); + let end = offset + node_len as usize; + let node_data = sector_buf.get(offset..end).ok_or_else(|| { + ANNError::log_index_error(format_args!( + "Node data out of bounds: vertex {} offset {}..{} in buffer of len {}", + vertex_id, + offset, + end, + sector_buf.len() + )) + })?; + + let fp_len = fp_vector_len as usize; + if fp_len > node_data.len() { + return Err(ANNError::log_index_error(format_args!( + "fp_vector_len {} exceeds node_data len {}", + fp_len, + node_data.len() + ))); + } + + self.fp_vector.clear(); + self.fp_vector.extend_from_slice(&node_data[..fp_len]); + + let neighbor_data = &node_data[fp_len..]; + let num_neighbors = LittleEndian::read_u32(&neighbor_data[..4]) as usize; + let max_neighbors = (neighbor_data.len().saturating_sub(4)) / 4; + let num_neighbors = num_neighbors.min(max_neighbors); + + self.adjacency_list.clear(); + for i in 0..num_neighbors { + let start = 4 + i * 4; + self.adjacency_list + .push(LittleEndian::read_u32(&neighbor_data[start..start + 4])); + } + + self.rank = rank; + Ok(()) + } +} + +/// Tracks an in-flight IO request. 
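// On-disk node layout assumed by `parse_from` above, reconstructed from the reads it
// performs (little-endian throughout; sizes come from the graph header):
//
//   [ full-precision vector : fp_vector_len bytes ]
//   [ num_neighbors         : u32                 ]
//   [ neighbor ids          : num_neighbors * u32 ]
//
// For example, a float32 node with dims = 128 has fp_vector_len = 128 * 4 = 512, so
// its neighbor count starts at byte offset 512 within the node slice.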
+struct InFlightIo { + vertex_id: u32, + slot_id: usize, + rank: u64, +} + +// --------------------------------------------------------------------------- +// Poolable scratch: PipelinedReader + PQScratch, reused across queries +// --------------------------------------------------------------------------- + +/// Reusable scratch state for pipelined search, pooled to avoid per-query +/// allocation of io_uring rings, file descriptors, and PQ scratch buffers. +pub struct PipelinedScratch { + pub reader: PipelinedReader, + pub pq_scratch: PQScratch, + // Per-query scratch collections, cleared between queries but retain capacity + in_flight_ios: VecDeque, + loaded_nodes: HashMap, + distance_cache: HashMap, + /// Reusable buffer for neighbor IDs during expand_available + neighbor_buf: Vec, + /// Freelist of LoadedNode instances to avoid per-node allocation + node_pool: Vec, + /// Reusable buffer for completed slot IDs from poll/wait. + completed_buf: Vec, +} + +/// Arguments for creating or resetting a [`PipelinedScratch`]. +#[derive(Clone)] +pub struct PipelinedScratchArgs { + pub disk_index_path: String, + pub max_slots: usize, + pub slot_size: usize, + pub alignment: usize, + pub graph_degree: usize, + pub dims: usize, + pub num_pq_chunks: usize, + pub num_pq_centers: usize, + pub reader_config: PipelinedReaderConfig, +} + +impl TryAsPooled for PipelinedScratch { + type Error = ANNError; + + fn try_create(args: PipelinedScratchArgs) -> Result { + let reader = PipelinedReader::new( + &args.disk_index_path, + args.max_slots, + args.slot_size, + args.alignment, + &args.reader_config, + )?; + let pq_scratch = PQScratch::new( + args.graph_degree, + args.dims, + args.num_pq_chunks, + args.num_pq_centers, + )?; + Ok(Self { + reader, + pq_scratch, + in_flight_ios: VecDeque::new(), + loaded_nodes: HashMap::new(), + distance_cache: HashMap::new(), + neighbor_buf: Vec::new(), + node_pool: Vec::new(), + completed_buf: Vec::new(), + }) + } + + fn try_modify(&mut self, _args: PipelinedScratchArgs) -> Result<(), Self::Error> { + self.reader.reset(); + // Return all loaded_nodes back to the pool before clearing + self.node_pool + .extend(self.loaded_nodes.drain().map(|(_, node)| node)); + self.in_flight_ios.clear(); + self.distance_cache.clear(); + self.neighbor_buf.clear(); + self.completed_buf.clear(); + Ok(()) + } +} + +impl PipelinedScratch { + /// Get a LoadedNode from the pool, or create a new empty one. + fn acquire_node(&mut self) -> LoadedNode { + self.node_pool.pop().unwrap_or_else(|| LoadedNode { + fp_vector: Vec::new(), + adjacency_list: Vec::new(), + rank: 0, + }) + } + + /// Return a LoadedNode to the pool for reuse. + fn release_node(&mut self, node: LoadedNode) { + self.node_pool.push(node); + } +} + +// --------------------------------------------------------------------------- +// PipelinedDiskAccessor +// --------------------------------------------------------------------------- + +/// Pipelined disk accessor that overlaps IO and compute via io_uring. 
+/// +/// Implements the `ExpandBeam` trait's queue-based methods: +/// - `submit_expand`: submits non-blocking io_uring reads for the given node IDs +/// - `expand_available`: polls for completed reads and expands those nodes +/// - `has_pending`: returns true when IO operations are in-flight +pub struct PipelinedDiskAccessor<'a, Data: GraphDataType> { + provider: &'a DiskProvider, + scratch: PoolOption, + query: &'a [Data::VectorDataType], + + // Graph geometry (cached from GraphHeader) + num_nodes_per_sector: u64, + num_sectors_per_node: usize, + block_size: usize, + node_len: u64, + fp_vector_len: u64, + num_points: usize, + + // Node cache (shared, read-only) for avoiding disk IO on hot nodes + node_cache: Arc>, + + // IO state (now lives in scratch for reuse, accessed via self.scratch) + /// Monotonically increasing submission rank for priority-ordered expansion. + next_rank: u64, + + // IO statistics + io_count: u32, + cache_hits: u32, + /// Accumulated IO time (submission + polling + waiting) + io_time: std::time::Duration, + /// Accumulated CPU time (fp distance + PQ distance + node parsing) + cpu_time: std::time::Duration, + /// PQ preprocess time (distance table construction) + preprocess_time: std::time::Duration, + // Shared stats written on drop so caller can read them after search + shared_io_stats: Arc, +} + +impl<'a, Data> PipelinedDiskAccessor<'a, Data> +where + Data: GraphDataType, +{ + /// Create a new pipelined disk accessor using a pooled scratch. + pub fn new( + provider: &'a DiskProvider, + query: &'a [Data::VectorDataType], + scratch: PoolOption, + node_cache: Arc>, + shared_io_stats: Arc, + ) -> ANNResult { + let metadata = provider.graph_header.metadata(); + let dims = metadata.dims; + let num_nodes_per_sector = metadata.num_nodes_per_block; + let node_len = metadata.node_len; + let fp_vector_len = (dims * std::mem::size_of::()) as u64; + + let block_size = provider.graph_header.effective_block_size(); + let num_sectors_per_node = provider.graph_header.num_sectors_per_node(); + + Ok(Self { + provider, + scratch, + query, + num_nodes_per_sector, + num_sectors_per_node, + block_size, + node_len, + fp_vector_len, + num_points: provider.num_points, + node_cache, + next_rank: 0, + io_count: 0, + cache_hits: 0, + io_time: std::time::Duration::ZERO, + cpu_time: std::time::Duration::ZERO, + preprocess_time: std::time::Duration::ZERO, + shared_io_stats, + }) + } + + /// Preprocess PQ distance tables for this query. Must be called before search. + pub fn preprocess_query(&mut self) -> ANNResult<()> { + let timer = std::time::Instant::now(); + let metadata = self.provider.graph_header.metadata(); + let dims = metadata.dims; + let medoid = metadata.medoid as u32; + self.scratch.pq_scratch.set(dims, self.query, 1.0)?; + quantizer_preprocess( + &mut self.scratch.pq_scratch, + &self.provider.pq_data, + self.provider.metric, + &[medoid], + )?; + self.preprocess_time = timer.elapsed(); + Ok(()) + } + + fn pq_distances_inner( + pq: &mut PQScratch, + provider: &DiskProvider, + ids: &[u32], + f: &mut F, + ) -> ANNResult<()> + where + F: FnMut(f32, u32), + { + compute_pq_distance( + ids, + provider.pq_data.get_num_chunks(), + &pq.aligned_pqtable_dist_scratch, + provider.pq_data.pq_compressed_data().get_data(), + &mut pq.aligned_pq_coord_scratch, + &mut pq.aligned_dist_scratch, + )?; + for (i, id) in ids.iter().enumerate() { + f(pq.aligned_dist_scratch[i], *id); + } + Ok(()) + } + + /// Returns the number of disk IO operations performed. 
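// Construction sketch: this is what PipelinedSearchStrategy::search_accessor (further
// below) does per query — take a pooled scratch, build the accessor, and build the
// per-query PQ distance table before the search loop touches it. Names here are the
// same local values that search_accessor clones.
//
//   let scratch = PoolOption::try_pooled(&config.scratch_pool, config.scratch_args.clone())?;
//   let mut accessor = PipelinedDiskAccessor::new(provider, query, scratch, node_cache, io_stats)?;
//   accessor.preprocess_query()?;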
+ pub fn io_count(&self) -> u32 { + self.io_count + } + + /// Returns the number of cache hits (nodes served from cache without IO). + pub fn cache_hits(&self) -> u32 { + self.cache_hits + } + + /// Poll completed IOs and move data from reader buffers into loaded_nodes. + fn drain_completions(&mut self) -> ANNResult<()> { + if self.scratch.in_flight_ios.is_empty() { + return Ok(()); + } + + let io_start = Instant::now(); + // Split borrows: reader and completed_buf are separate fields. + let PipelinedScratch { + reader, + completed_buf, + .. + } = &mut *self.scratch; + reader.poll_completions(completed_buf)?; + self.io_time += io_start.elapsed(); + + if completed_buf.is_empty() { + return Ok(()); + } + + Self::process_completed_ios_inner( + &mut self.scratch, + self.num_nodes_per_sector, + self.node_len, + self.fp_vector_len, + ) + } + /// Block until at least one IO completes, then eagerly drain all available. + fn wait_and_drain(&mut self) -> ANNResult<()> { + let io_start = Instant::now(); + let PipelinedScratch { + reader, + completed_buf, + .. + } = &mut *self.scratch; + reader.wait_completions(completed_buf)?; + self.io_time += io_start.elapsed(); + + if completed_buf.is_empty() { + return Ok(()); + } + + Self::process_completed_ios_inner( + &mut self.scratch, + self.num_nodes_per_sector, + self.node_len, + self.fp_vector_len, + ) + } + + /// Shared logic: process completed slot IDs, parse nodes, retain in-flight. + /// Uses linear scan on completed_buf (small, bounded by max_slots) to + /// avoid per-poll HashSet allocation. Reuses LoadedNode instances from the + /// node pool to avoid per-IO Vec allocations. + fn process_completed_ios_inner( + scratch: &mut PipelinedScratch, + num_nodes_per_sector: u64, + node_len: u64, + fp_vector_len: u64, + ) -> ANNResult<()> { + let mut i = 0; + while i < scratch.in_flight_ios.len() { + let io = &scratch.in_flight_ios[i]; + if scratch.completed_buf.contains(&io.slot_id) { + let io = scratch.in_flight_ios.swap_remove_back(i).unwrap(); + // Acquire node first (mutably borrows node_pool), + // then get sector buf (immutably borrows reader) — no conflict. + let mut node = scratch.node_pool.pop().unwrap_or_else(|| LoadedNode { + fp_vector: Vec::new(), + adjacency_list: Vec::new(), + rank: 0, + }); + let sector_buf = scratch.reader.get_slot_buf(io.slot_id); + node.parse_from( + sector_buf, + io.vertex_id, + num_nodes_per_sector, + node_len, + fp_vector_len, + io.rank, + )?; + // Release the slot back to the reader's free-list now that + // we've copied the data out. + scratch.reader.release_slot(io.slot_id); + scratch.loaded_nodes.insert(io.vertex_id, node); + } else { + i += 1; + } + } + Ok(()) + } +} + +impl HasId for PipelinedDiskAccessor<'_, Data> +where + Data: GraphDataType, +{ + type Id = u32; +} + +impl<'a, Data> Accessor for PipelinedDiskAccessor<'a, Data> +where + Data: GraphDataType, +{ + type Extended = &'a [u8]; + type Element<'b> + = &'a [u8] + where + Self: 'b; + type ElementRef<'b> = &'b [u8]; + type GetError = ANNError; + + fn get_element( + &mut self, + id: Self::Id, + ) -> impl Future, Self::GetError>> + Send { + std::future::ready(self.provider.pq_data.get_compressed_vector(id as usize)) + } +} + +impl IdIterator> for PipelinedDiskAccessor<'_, Data> +where + Data: GraphDataType, +{ + async fn id_iterator(&mut self) -> Result, ANNError> { + Ok(0..self.num_points as u32) + } +} + +/// Delegate for neighbor access (required by AsNeighbor). 
+pub struct PipelinedNeighborDelegate<'a, 'b, Data: GraphDataType>( + #[allow(dead_code)] &'a mut PipelinedDiskAccessor<'b, Data>, +); + +impl HasId for PipelinedNeighborDelegate<'_, '_, Data> +where + Data: GraphDataType, +{ + type Id = u32; +} + +impl NeighborAccessor for PipelinedNeighborDelegate<'_, '_, Data> +where + Data: GraphDataType, +{ + async fn get_neighbors( + self, + _id: Self::Id, + _neighbors: &mut AdjacencyList, + ) -> ANNResult { + // Neighbor expansion is handled by expand_available, not get_neighbors + Ok(self) + } +} + +impl<'a, 'b, Data> DelegateNeighbor<'a> for PipelinedDiskAccessor<'b, Data> +where + Data: GraphDataType, +{ + type Delegate = PipelinedNeighborDelegate<'a, 'b, Data>; + fn delegate_neighbor(&'a mut self) -> Self::Delegate { + PipelinedNeighborDelegate(self) + } +} + +impl BuildQueryComputer<[Data::VectorDataType]> for PipelinedDiskAccessor<'_, Data> +where + Data: GraphDataType, +{ + type QueryComputerError = ANNError; + type QueryComputer = DiskQueryComputer; + + fn build_query_computer( + &self, + _from: &[Data::VectorDataType], + ) -> Result { + Ok(DiskQueryComputer { + num_pq_chunks: self.provider.pq_data.get_num_chunks(), + query_centroid_l2_distance: self + .scratch + .pq_scratch + .aligned_pqtable_dist_scratch + .as_slice() + .to_vec(), + }) + } + + async fn distances_unordered( + &mut self, + vec_id_itr: Itr, + _computer: &Self::QueryComputer, + f: F, + ) -> Result<(), Self::GetError> + where + F: Send + FnMut(f32, Self::Id), + Itr: Iterator, + { + self.scratch.neighbor_buf.clear(); + self.scratch.neighbor_buf.extend(vec_id_itr); + let mut f = f; + let PipelinedScratch { + ref mut pq_scratch, + ref neighbor_buf, + .. + } = *self.scratch; + Self::pq_distances_inner(pq_scratch, self.provider, neighbor_buf, &mut f) + } +} + +impl ExpandBeam<[Data::VectorDataType]> for PipelinedDiskAccessor<'_, Data> +where + Data: GraphDataType, +{ + /// Submit non-blocking io_uring reads for the given node IDs. + /// Nodes found in the node cache are placed directly into `loaded_nodes`, + /// skipping disk IO entirely. Returns IDs that could not be submitted. + fn submit_expand(&mut self, ids: impl Iterator + Send) -> Vec { + let io_start = Instant::now(); + let mut rejected = Vec::new(); + let mut hit_slot_limit = false; + let mut enqueued = 0u32; + for id in ids { + if self.scratch.loaded_nodes.contains_key(&id) { + continue; // Already loaded from a previous IO + } + + // Check node cache first — if the node is cached, build a LoadedNode + // from the cache and skip IO entirely. + if let (Some(vec_data), Some(adj_list)) = ( + self.node_cache.get_vector(&id), + self.node_cache.get_adjacency_list(&id), + ) { + let mut node = self.scratch.acquire_node(); + node.fp_vector.clear(); + node.fp_vector + .extend_from_slice(bytemuck::cast_slice(vec_data)); + node.adjacency_list.clear(); + node.adjacency_list.extend(adj_list.iter().copied()); + node.rank = self.next_rank; + self.next_rank += 1; + self.scratch.loaded_nodes.insert(id, node); + self.cache_hits += 1; + continue; + } + + // Don't submit if no free io_uring slots are available. + if hit_slot_limit || !self.scratch.reader.has_free_slot() { + hit_slot_limit = true; + rejected.push(id); + continue; + } + + let sector_idx = + node_sector_index(id, self.num_nodes_per_sector, self.num_sectors_per_node); + let sector_offset = sector_idx * self.block_size as u64; + let rank = self.next_rank; + self.next_rank += 1; + // enqueue_read allocates a slot internally and pushes the SQE. 
+ // On failure the slot stays free inside the reader. + match self.scratch.reader.enqueue_read(sector_offset) { + Ok(slot_id) => { + self.scratch.in_flight_ios.push_back(InFlightIo { + vertex_id: id, + slot_id, + rank, + }); + self.io_count += 1; + enqueued += 1; + } + Err(_) => { + rejected.push(id); + } + } + } + // Flush all enqueued SQEs in a single syscall. + if enqueued > 0 { + if let Err(e) = self.scratch.reader.flush() { + // Slots remain InFlight; they'll be drained on drop/reset. + self.io_time += io_start.elapsed(); + tracing::warn!("PipelinedReader::flush failed: {e}"); + return rejected; + } + } + self.io_time += io_start.elapsed(); + rejected + } + + /// Poll for completed reads and expand the best loaded node. + /// + /// Uses two selection strategies: + /// 1. If `ids` provides candidates, pick the first loaded match (queue order) + /// 2. Otherwise, pick the loaded node with the lowest submission rank + /// (earliest submitted = best PQ distance at submission time) + async fn expand_available( + &mut self, + ids: impl Iterator + Send, + _computer: &Self::QueryComputer, + mut pred: P, + mut on_neighbors: F, + expanded_ids: &mut Vec, + ) -> ANNResult<()> + where + P: HybridPredicate + Send + Sync, + F: FnMut(f32, Self::Id) + Send, + { + expanded_ids.clear(); + + // Non-blocking poll for completions + self.drain_completions()?; + + if self.scratch.loaded_nodes.is_empty() { + return Ok(()); + } + + // Try caller's priority order first + let mut best_vid: Option = None; + for id in ids { + if self.scratch.loaded_nodes.contains_key(&id) { + best_vid = Some(id); + break; + } + } + + // Fallback: pick loaded node with lowest rank (best PQ at submission) + if best_vid.is_none() { + best_vid = self + .scratch + .loaded_nodes + .iter() + .min_by_key(|(_, node)| node.rank) + .map(|(&id, _)| id); + } + + let vid = match best_vid { + Some(id) => id, + None => return Ok(()), + }; + let node = self.scratch.loaded_nodes.remove(&vid).unwrap(); + expanded_ids.push(vid); + + // Compute full-precision distance and cache it for post-processing + let cpu_start = Instant::now(); + let fp_vec: &[Data::VectorDataType] = bytemuck::cast_slice(&node.fp_vector); + let fp_dist = self + .provider + .distance_comparer + .evaluate_similarity(self.query, fp_vec); + self.scratch.distance_cache.insert(vid, fp_dist); + + // Get unvisited neighbors into reusable buffer + self.scratch.neighbor_buf.clear(); + self.scratch.neighbor_buf.extend( + node.adjacency_list + .iter() + .copied() + .filter(|&nbr| (nbr as usize) < self.num_points && pred.eval_mut(&nbr)), + ); + + if !self.scratch.neighbor_buf.is_empty() { + let PipelinedScratch { + ref mut pq_scratch, + ref neighbor_buf, + .. + } = *self.scratch; + Self::pq_distances_inner(pq_scratch, self.provider, neighbor_buf, &mut on_neighbors)?; + } + self.cpu_time += cpu_start.elapsed(); + + // Return node to pool for reuse + self.scratch.release_node(node); + + Ok(()) + } + + /// Returns true when there are in-flight IO operations. 
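// Rough sketch of how the generic search loop is expected to drive the queue-based
// methods in this impl (inferred from the trait methods here; the real loop lives in
// DiskANNIndex::search_internal and differs in detail):
//
//   loop {
//       let rejected = accessor.submit_expand(closest_unvisited.iter().copied());
//       accessor
//           .expand_available(priority_ids, &computer, pred, on_neighbors, &mut expanded)
//           .await?;
//       if expanded.is_empty() && accessor.has_pending() {
//           accessor.wait_for_io()?; // block only when nothing is ready to expand
//       }
//       // terminate once the candidate list converges and no IO is pending
//   }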
+ fn has_pending(&self) -> bool { + !self.scratch.in_flight_ios.is_empty() + } + + fn inflight_count(&self) -> usize { + self.scratch.in_flight_ios.len() + } + + fn wait_for_io(&mut self) -> ANNResult<()> { + // Only block if there are actually in-flight IOs to wait for + if !self.scratch.in_flight_ios.is_empty() { + self.wait_and_drain()?; + } + Ok(()) + } +} + +impl SearchExt for PipelinedDiskAccessor<'_, Data> +where + Data: GraphDataType, +{ + async fn starting_points(&self) -> ANNResult> { + let start_vertex_id = self.provider.graph_header.metadata().medoid as u32; + Ok(vec![start_vertex_id]) + } + + fn terminate_early(&mut self) -> bool { + false + } +} + +impl Drop for PipelinedDiskAccessor<'_, Data> +where + Data: GraphDataType, +{ + fn drop(&mut self) { + self.shared_io_stats + .io_count + .fetch_add(self.io_count, Ordering::Relaxed); + self.shared_io_stats + .cache_hits + .fetch_add(self.cache_hits, Ordering::Relaxed); + self.shared_io_stats + .io_us + .fetch_add(self.io_time.as_micros() as u64, Ordering::Relaxed); + self.shared_io_stats + .cpu_us + .fetch_add(self.cpu_time.as_micros() as u64, Ordering::Relaxed); + self.shared_io_stats + .preprocess_us + .fetch_add(self.preprocess_time.as_micros() as u64, Ordering::Relaxed); + } +} + +// --------------------------------------------------------------------------- +// SearchStrategy + PostProcessor for pipelined search +// --------------------------------------------------------------------------- + +/// Configuration for creating a pipelined search through DiskIndexSearcher. +pub struct PipelinedConfig> { + pub beam_width: usize, + /// Shared node cache. Nodes found here skip disk IO entirely. + pub node_cache: Arc>, + /// Pooled scratch (io_uring reader + PQ buffers), created once and reused. + pub scratch_pool: Arc>, + /// Args for retrieving/creating pooled scratch instances. + pub scratch_args: PipelinedScratchArgs, +} + +/// Shared IO statistics written by the accessor and read by the caller after search. +/// Uses atomics so the accessor (which lives inside search_internal) can write stats +/// that the caller can read after the search completes. +pub struct PipelinedIoStats { + pub io_count: AtomicU32, + pub cache_hits: AtomicU32, + pub io_us: std::sync::atomic::AtomicU64, + pub cpu_us: std::sync::atomic::AtomicU64, + pub preprocess_us: std::sync::atomic::AtomicU64, +} + +impl Default for PipelinedIoStats { + fn default() -> Self { + Self { + io_count: AtomicU32::new(0), + cache_hits: AtomicU32::new(0), + io_us: std::sync::atomic::AtomicU64::new(0), + cpu_us: std::sync::atomic::AtomicU64::new(0), + preprocess_us: std::sync::atomic::AtomicU64::new(0), + } + } +} + +/// Search strategy that creates PipelinedDiskAccessor instances. +pub struct PipelinedSearchStrategy<'a, Data: GraphDataType> { + query: &'a [Data::VectorDataType], + config: &'a PipelinedConfig, + vector_filter: &'a (dyn Fn(&u32) -> bool + Send + Sync), + io_stats: Arc, +} + +/// Post-processor for pipelined search that reranks using cached full-precision distances. 
+#[derive(Clone, Copy)] +pub struct PipelinedPostProcessor<'a> { + filter: &'a (dyn Fn(&u32) -> bool + Send + Sync), +} + +impl + SearchPostProcess< + PipelinedDiskAccessor<'_, Data>, + [Data::VectorDataType], + (u32, Data::AssociatedDataType), + > for PipelinedPostProcessor<'_> +where + Data: GraphDataType, +{ + type Error = ANNError; + + async fn post_process( + &self, + accessor: &mut PipelinedDiskAccessor<'_, Data>, + _query: &[Data::VectorDataType], + _computer: &DiskQueryComputer, + _candidates: I, + output: &mut B, + ) -> Result + where + I: Iterator> + Send, + B: SearchOutputBuffer<(u32, Data::AssociatedDataType)> + Send + ?Sized, + { + // Rerank using ALL expanded nodes' cached fp-distances, not just + // candidates from the priority queue. This matches PipeANN's + // full_retset approach: every expanded node contributes to results + // regardless of its PQ distance ranking. + let mut reranked: Vec<((u32, Data::AssociatedDataType), f32)> = accessor + .scratch + .distance_cache + .iter() + .filter(|(id, _)| (self.filter)(id)) + .map(|(&id, &dist)| ((id, Data::AssociatedDataType::default()), dist)) + .collect(); + + reranked.sort_unstable_by(|a, b| a.1.total_cmp(&b.1)); + Ok(output.extend(reranked)) + } +} + +impl<'this, Data> + SearchStrategy, [Data::VectorDataType], (u32, Data::AssociatedDataType)> + for PipelinedSearchStrategy<'this, Data> +where + Data: GraphDataType, +{ + type QueryComputer = DiskQueryComputer; + type SearchAccessor<'a> = PipelinedDiskAccessor<'a, Data>; + type SearchAccessorError = ANNError; + type PostProcessor = PipelinedPostProcessor<'this>; + + fn search_accessor<'a>( + &'a self, + provider: &'a DiskProvider, + _context: &DefaultContext, + ) -> Result, Self::SearchAccessorError> { + let scratch = + PoolOption::try_pooled(&self.config.scratch_pool, self.config.scratch_args.clone())?; + let mut accessor = PipelinedDiskAccessor::new( + provider, + self.query, + scratch, + self.config.node_cache.clone(), + self.io_stats.clone(), + )?; + accessor.preprocess_query()?; + Ok(accessor) + } + + fn post_processor(&self) -> Self::PostProcessor { + PipelinedPostProcessor { + filter: self.vector_filter, + } + } +} + +// --------------------------------------------------------------------------- +// DiskIndexSearcher integration (search_pipelined method) +// --------------------------------------------------------------------------- + +impl DiskIndexSearcher +where + Data: GraphDataType, + ProviderFactory: VertexProviderFactory, +{ + /// Attach a pipelined configuration to this searcher. + pub fn with_pipelined_config(&mut self, config: PipelinedConfig) { + self.pipelined_config = Some(config); + } + + /// Perform a pipelined search through the unified search loop. + /// + /// Requires that `with_pipelined_config()` was called first. 
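// Usage sketch, mirroring the benchmark's PipeSearch arm in
// diskann-benchmark/src/backend/disk_index/search.rs (field values elided):
//
//   searcher.with_pipelined_config(PipelinedConfig {
//       beam_width,
//       node_cache,
//       scratch_pool,
//       scratch_args,
//   });
//   let result = searcher.search_pipelined(query, recall_at, l, beam_width, None)?;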
+ pub fn search_pipelined( + &self, + query: &[Data::VectorDataType], + return_list_size: u32, + search_list_size: u32, + beam_width: usize, + vector_filter: Option<&(dyn Fn(&u32) -> bool + Send + Sync)>, + ) -> ANNResult> { + let config = self + .pipelined_config + .as_ref() + .ok_or_else(|| ANNError::log_index_error("pipelined_config not set"))?; + + let default_filter: Box bool + Send + Sync> = Box::new(|_| true); + let filter: &(dyn Fn(&u32) -> bool + Send + Sync) = + vector_filter.unwrap_or(default_filter.as_ref()); + + let io_stats = Arc::new(PipelinedIoStats::default()); + + let strategy = PipelinedSearchStrategy { + query, + config, + vector_filter: filter, + io_stats: io_stats.clone(), + }; + + let search_params = SearchParams::new( + return_list_size as usize, + search_list_size as usize, + Some(beam_width), + )?; + + let mut indices = vec![0u32; return_list_size as usize]; + let mut distances = vec![0f32; return_list_size as usize]; + let mut associated_data = + vec![Data::AssociatedDataType::default(); return_list_size as usize]; + let mut result_output_buffer = search_output_buffer::IdDistanceAssociatedData::new( + &mut indices[..], + &mut distances[..], + &mut associated_data[..], + ); + + let mut query_stats = QueryStatistics::default(); + let timer = std::time::Instant::now(); + + // Preprocess PQ distance table: the accessor's build_query_computer relies + // on the pq_scratch having been preprocessed for this query. + let stats = self.runtime.block_on(self.index.search( + &strategy, + &DefaultContext, + query, + &search_params, + &mut result_output_buffer, + ))?; + + query_stats.total_comparisons = stats.cmps; + query_stats.search_hops = stats.hops; + query_stats.total_execution_time_us = timer.elapsed().as_micros(); + query_stats.total_io_operations = io_stats.io_count.load(Ordering::Relaxed); + query_stats.total_vertices_loaded = + io_stats.io_count.load(Ordering::Relaxed) + io_stats.cache_hits.load(Ordering::Relaxed); + query_stats.io_time_us = io_stats.io_us.load(Ordering::Relaxed) as u128; + query_stats.cpu_time_us = io_stats.cpu_us.load(Ordering::Relaxed) as u128; + query_stats.query_pq_preprocess_time_us = + io_stats.preprocess_us.load(Ordering::Relaxed) as u128; + + let mut search_result = SearchResult { + results: Vec::with_capacity(return_list_size as usize), + stats: SearchResultStats { + cmps: stats.cmps, + result_count: stats.result_count, + query_statistics: query_stats, + }, + }; + + for ((vertex_id, distance), data) in indices + .into_iter() + .zip(distances.into_iter()) + .zip(associated_data.into_iter()) + { + search_result.results.push(SearchResultItem { + vertex_id, + distance, + data, + }); + } + + Ok(search_result) + } +} diff --git a/diskann-disk/src/search/sector_math.rs b/diskann-disk/src/search/sector_math.rs new file mode 100644 index 000000000..0313f33ab --- /dev/null +++ b/diskann-disk/src/search/sector_math.rs @@ -0,0 +1,33 @@ +/* + * Copyright (c) Microsoft Corporation. + * Licensed under the MIT license. + */ + +//! Shared sector-layout arithmetic used by both beam search and pipelined search. + +/// Compute the sector index that contains the given vertex. +/// +/// The first sector (index 0) is reserved for the graph header, so data sectors +/// start at index 1. 
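// Worked example for the two helpers that follow (values chosen purely for
// illustration). With 2 nodes per sector and node_len = 2048 bytes:
//   node_sector_index(5, 2, 1)        == 1 + 5 / 2      == 3
//   node_offset_in_sector(5, 2, 2048) == (5 % 2) * 2048 == 2048
// With multi-sector nodes (num_nodes_per_sector == 0, 3 sectors per node):
//   node_sector_index(5, 0, 3)        == 1 + 5 * 3      == 16
//   node_offset_in_sector(5, 0, 2048) == 0   // node starts at its sector boundary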
+#[inline] +pub fn node_sector_index( + vertex_id: u32, + num_nodes_per_sector: u64, + num_sectors_per_node: usize, +) -> u64 { + 1 + if num_nodes_per_sector > 0 { + vertex_id as u64 / num_nodes_per_sector + } else { + vertex_id as u64 * num_sectors_per_node as u64 + } +} + +/// Compute the byte offset of a node within its sector. +#[inline] +pub fn node_offset_in_sector(vertex_id: u32, num_nodes_per_sector: u64, node_len: u64) -> usize { + if num_nodes_per_sector == 0 { + 0 + } else { + (vertex_id as u64 % num_nodes_per_sector * node_len) as usize + } +} diff --git a/diskann-disk/src/storage/mod.rs b/diskann-disk/src/storage/mod.rs index 410e39a0a..0e03d6875 100644 --- a/diskann-disk/src/storage/mod.rs +++ b/diskann-disk/src/storage/mod.rs @@ -21,4 +21,9 @@ pub use cached_writer::CachedWriter; pub mod quant; +#[cfg(target_os = "linux")] +pub(crate) mod pipelined_reader; +#[cfg(target_os = "linux")] +pub use pipelined_reader::{PipelinedReader, PipelinedReaderConfig, MAX_IO_CONCURRENCY}; + pub mod api; diff --git a/diskann-disk/src/storage/pipelined_reader.rs b/diskann-disk/src/storage/pipelined_reader.rs new file mode 100644 index 000000000..b27c127e0 --- /dev/null +++ b/diskann-disk/src/storage/pipelined_reader.rs @@ -0,0 +1,787 @@ +/* + * Copyright (c) Microsoft Corporation. + * Licensed under the MIT license. + */ + +//! Pipelined IO reader using io_uring with non-blocking submit/poll semantics. +//! +//! # Safety model +//! +//! The kernel writes to slot buffers via DMA, which is invisible to the Rust +//! compiler. To avoid aliasing UB we **never** form `&[u8]` or `&mut [u8]` +//! references to the backing allocation while any IO is in-flight. Instead we: +//! +//! 1. Obtain the base raw pointer (`*mut u8`) **once** at construction — before +//! any IO is submitted — and store it for later use. +//! 2. Pass raw pointers to io_uring for kernel DMA targets. +//! 3. Only materialise `&[u8]` slices via [`std::slice::from_raw_parts`] for +//! slots whose state is [`SlotState::Completed`] (kernel has finished writing). +//! +//! Slot lifecycle: `Free → InFlight → Completed → Free`. +//! +//! [`PipelinedReader`] owns the free-list and state machine so callers never +//! need `unsafe` for normal operation. + +use std::{ + collections::VecDeque, + fs::OpenOptions, + os::{fd::AsRawFd, unix::fs::OpenOptionsExt}, +}; + +use diskann::{ANNError, ANNResult}; +use diskann_providers::common::AlignedBoxWithSlice; +use io_uring::IoUring; + +/// Maximum number of concurrent IO operations supported by the ring. +pub const MAX_IO_CONCURRENCY: usize = 128; + +/// Configuration for io_uring-based pipelined reader. +#[derive(Debug, Clone, Default)] +pub struct PipelinedReaderConfig { + /// Enable kernel-side SQ polling. If `Some(idle_ms)`, a kernel thread polls + /// the submission queue, eliminating the syscall per submit. After `idle_ms` + /// milliseconds of inactivity the kernel thread sleeps (resumed automatically + /// on next `submit()`). Requires Linux kernel >= 5.11 (>= 5.13 unprivileged). + pub sqpoll_idle_ms: Option, +} + +/// State of each buffer slot in the pool. +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +enum SlotState { + /// Slot is available for a new IO submission. + Free, + /// SQE has been pushed (and possibly submitted). Kernel may be DMA-ing. + InFlight, + /// CQE has been reaped — data is ready. Safe to create `&[u8]`. + Completed, +} + +/// A pipelined IO reader that wraps `io_uring` for non-blocking submit/poll. 
+/// +/// Unlike `LinuxAlignedFileReader` which uses `submit_and_wait` (blocking), +/// this reader submits reads and polls completions independently, enabling +/// IO/compute overlap within a single search query. +/// +/// The reader owns both the ring buffer allocation and the slot state machine. +/// Callers interact through a safe API: +/// +/// 1. [`enqueue_read`](Self::enqueue_read) — push an SQE, returns `slot_id`. +/// 2. [`flush`](Self::flush) — submit all enqueued SQEs to the kernel (one syscall). +/// 3. [`poll_completions`](Self::poll_completions) / +/// [`wait_completions`](Self::wait_completions) — drain CQEs. +/// 4. [`get_slot_buf`](Self::get_slot_buf) — borrow data for a `Completed` slot. +/// 5. [`release_slot`](Self::release_slot) — return a `Completed` slot to `Free`. +pub struct PipelinedReader { + ring: IoUring, + /// Owns the aligned allocation. **Must not be dereferenced** while any IO is + /// in-flight — see the module-level safety discussion. + _slot_bufs: AlignedBoxWithSlice, + /// Raw pointer to the start of the buffer, obtained once at construction. + /// All subsequent slot access goes through pointer arithmetic on this base. + buf_base: *mut u8, + /// Size of each slot buffer in bytes. + slot_size: usize, + /// Maximum number of slots available. + max_slots: usize, + /// Per-slot state. + slot_states: Vec, + /// FIFO free-list for O(1) slot allocation. + free_slots: VecDeque, + /// Number of slots whose SQEs have been submitted to the kernel (InFlight). + in_flight: usize, + /// Keep the file handle alive for the lifetime of the reader. + _file: std::fs::File, +} + +// SAFETY: The raw pointer `buf_base` is derived from an owned allocation +// (`_slot_bufs`) and is never shared — all mutable access requires `&mut self`. +// The io_uring ring and file descriptor are kernel-side resources with no +// thread-affinity. Moving the reader between threads is safe. +unsafe impl Send for PipelinedReader {} +// SAFETY: `&self` methods only access completed slot data (kernel has finished +// writing). All mutation requires `&mut self`. +unsafe impl Sync for PipelinedReader {} + +impl PipelinedReader { + /// Create a new pipelined reader. + /// + /// # Arguments + /// * `file_path` - Path to the disk index file. + /// * `max_slots` - Number of buffer slots (clamped to [`MAX_IO_CONCURRENCY`]). + /// * `slot_size` - Size of each buffer slot in bytes (should be sector-aligned). + /// * `alignment` - Memory alignment for the buffer (typically 4096 for O_DIRECT). + /// * `config` - Optional io_uring tuning (e.g. SQPOLL). + pub fn new( + file_path: &str, + max_slots: usize, + slot_size: usize, + alignment: usize, + config: &PipelinedReaderConfig, + ) -> ANNResult { + let file = OpenOptions::new() + .read(true) + .custom_flags(libc::O_DIRECT) + .open(file_path) + .map_err(ANNError::log_io_error)?; + + let max_slots = max_slots.min(MAX_IO_CONCURRENCY); + let entries = max_slots as u32; + let ring = if let Some(idle_ms) = config.sqpoll_idle_ms { + let mut builder = IoUring::builder(); + builder.setup_sqpoll(idle_ms); + builder.build(entries)? + } else { + IoUring::new(entries)? + }; + let fd = file.as_raw_fd(); + ring.submitter().register_files(std::slice::from_ref(&fd))?; + + let mut slot_bufs = AlignedBoxWithSlice::new(max_slots * slot_size, alignment)?; + + // SAFETY: No IOs are in-flight yet, so creating a `&mut [u8]` is sound. + // We extract the raw pointer here and never form a reference again. 
+ let buf_base: *mut u8 = slot_bufs.as_mut_slice().as_mut_ptr(); + + Ok(Self { + ring, + _slot_bufs: slot_bufs, + buf_base, + slot_size, + max_slots, + slot_states: vec![SlotState::Free; max_slots], + free_slots: (0..max_slots).collect(), + in_flight: 0, + _file: file, + }) + } + + // ------------------------------------------------------------------ + // Submission + // ------------------------------------------------------------------ + + /// Enqueue an asynchronous read for `sector_offset` into a newly-allocated + /// buffer slot. Returns the `slot_id` on success. + /// + /// The SQE is pushed to the submission queue but **not submitted** to the + /// kernel. Call [`flush`](Self::flush) after enqueuing a batch to submit + /// them all in a single syscall. + /// + /// Returns an error if no free slots are available. + pub fn enqueue_read(&mut self, sector_offset: u64) -> ANNResult { + let slot_id = self.free_slots.pop_front().ok_or_else(|| { + ANNError::log_index_error(format_args!( + "PipelinedReader: no free slots (max_slots={})", + self.max_slots + )) + })?; + debug_assert_eq!(self.slot_states[slot_id], SlotState::Free); + + // Raw pointer arithmetic — no reference to the backing buffer. + let buf_ptr = unsafe { self.buf_base.add(slot_id * self.slot_size) }; + + let read_op = + io_uring::opcode::Read::new(io_uring::types::Fixed(0), buf_ptr, self.slot_size as u32) + .offset(sector_offset) + .build() + .user_data(slot_id as u64); + + // SAFETY: `buf_ptr` points into a pre-allocated, aligned region that + // outlives the reader. The slot is being transitioned to InFlight so no + // other code will access this memory region. + let push_result = unsafe { self.ring.submission().push(&read_op) }; + if let Err(e) = push_result { + // SQE queue full — return slot to free-list. + self.free_slots.push_back(slot_id); + return Err(ANNError::log_push_error(e)); + } + + self.slot_states[slot_id] = SlotState::InFlight; + self.in_flight += 1; + Ok(slot_id) + } + + /// Submit all enqueued SQEs to the kernel in a single syscall. + /// + /// Retries automatically on `EINTR`. On fatal errors the enqueued slots + /// remain `InFlight` and will be drained on [`Drop`]. + pub fn flush(&mut self) -> ANNResult<()> { + loop { + match self.ring.submit() { + Ok(_) => return Ok(()), + Err(ref e) if e.raw_os_error() == Some(libc::EINTR) => continue, + Err(e) => return Err(ANNError::log_io_error(e)), + } + } + } + + // ------------------------------------------------------------------ + // Completion + // ------------------------------------------------------------------ + + /// Poll for completed IO operations (non-blocking). + /// + /// Appends completed `slot_id`s to `completed`. Slots transition from + /// `InFlight` → `Completed`. The caller must eventually call + /// [`release_slot`](Self::release_slot) for each returned slot. + /// + /// On IO errors or short reads the affected slot is freed automatically and + /// an error is returned. Successfully completed slots in `completed` are + /// still valid and should be processed first. + pub fn poll_completions(&mut self, completed: &mut Vec) -> ANNResult<()> { + self.drain_cqes(completed) + } + + /// Block until at least one IO completes, then drain all available CQEs. + /// + /// Same contract as [`poll_completions`](Self::poll_completions). + pub fn wait_completions(&mut self, completed: &mut Vec) -> ANNResult<()> { + if self.in_flight == 0 { + completed.clear(); + return Ok(()); + } + // submit_and_wait also flushes any un-submitted SQEs. 
+ loop { + match self.ring.submit_and_wait(1) { + Ok(_) => break, + Err(ref e) if e.raw_os_error() == Some(libc::EINTR) => continue, + Err(e) => return Err(ANNError::log_io_error(e)), + } + } + self.drain_cqes(completed) + } + + /// Drain all available CQEs from the completion queue. + /// + /// Processes every available CQE. On error or short-read the affected slot + /// is returned to `Free` and the first error is propagated after all CQEs + /// have been consumed (so no CQEs are left unprocessed). + fn drain_cqes(&mut self, completed: &mut Vec) -> ANNResult<()> { + completed.clear(); + let mut first_error: Option = None; + + for cqe in self.ring.completion() { + let slot_id = cqe.user_data() as usize; + debug_assert!(slot_id < self.max_slots); + debug_assert_eq!(self.slot_states[slot_id], SlotState::InFlight); + self.in_flight -= 1; + + if cqe.result() < 0 { + self.slot_states[slot_id] = SlotState::Free; + self.free_slots.push_back(slot_id); + if first_error.is_none() { + first_error = Some(ANNError::log_io_error(std::io::Error::from_raw_os_error( + -cqe.result(), + ))); + } + continue; + } + + let bytes_read = cqe.result() as usize; + if bytes_read < self.slot_size { + self.slot_states[slot_id] = SlotState::Free; + self.free_slots.push_back(slot_id); + if first_error.is_none() { + first_error = Some(ANNError::log_io_error(std::io::Error::new( + std::io::ErrorKind::UnexpectedEof, + format!( + "short read: expected {} bytes, got {}", + self.slot_size, bytes_read + ), + ))); + } + continue; + } + + self.slot_states[slot_id] = SlotState::Completed; + completed.push(slot_id); + } + + match first_error { + Some(e) => Err(e), + None => Ok(()), + } + } + + // ------------------------------------------------------------------ + // Slot access + // ------------------------------------------------------------------ + + /// Returns the read buffer for a completed slot. + /// + /// # Panics + /// Panics if `slot_id` is out of range or the slot is not in `Completed` + /// state (i.e. data is not yet ready or has already been released). + pub fn get_slot_buf(&self, slot_id: usize) -> &[u8] { + assert!(slot_id < self.max_slots, "slot_id out of range"); + assert_eq!( + self.slot_states[slot_id], + SlotState::Completed, + "slot {slot_id} is not Completed (state: {:?})", + self.slot_states[slot_id], + ); + // SAFETY: The slot is Completed — the kernel has finished writing. + // `buf_base` was derived from a valid, aligned allocation that outlives + // `self`. The slice covers exactly `slot_size` bytes within bounds. + unsafe { + std::slice::from_raw_parts(self.buf_base.add(slot_id * self.slot_size), self.slot_size) + } + } + + /// Release a completed slot back to the free-list for reuse. + /// + /// # Panics + /// Panics if the slot is not in `Completed` state. + pub fn release_slot(&mut self, slot_id: usize) { + assert!(slot_id < self.max_slots, "slot_id out of range"); + assert_eq!( + self.slot_states[slot_id], + SlotState::Completed, + "cannot release slot {slot_id}: not Completed (state: {:?})", + self.slot_states[slot_id], + ); + self.slot_states[slot_id] = SlotState::Free; + self.free_slots.push_back(slot_id); + } + + // ------------------------------------------------------------------ + // Lifecycle helpers + // ------------------------------------------------------------------ + + /// Returns `true` if a free slot is available for [`enqueue_read`](Self::enqueue_read). 
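+    ///
+    /// A typical way to keep the pipeline full is to gate submissions on this check
+    /// (illustrative sketch; `offsets` is an assumed peekable iterator of sector byte
+    /// offsets and `process` an assumed consumer of the returned buffer):
+    ///
+    /// ```ignore
+    /// let mut completed = Vec::new();
+    /// while let Some(&offset) = offsets.peek() {
+    ///     if reader.has_free_slot() {
+    ///         reader.enqueue_read(offset)?;
+    ///         offsets.next();
+    ///     } else {
+    ///         reader.flush()?;
+    ///         reader.wait_completions(&mut completed)?;
+    ///         for slot in completed.drain(..) {
+    ///             process(reader.get_slot_buf(slot));
+    ///             reader.release_slot(slot);
+    ///         }
+    ///     }
+    /// }
+    /// // Drain whatever is still in flight after the last submission.
+    /// reader.flush()?;
+    /// while reader.in_flight_count() > 0 {
+    ///     reader.wait_completions(&mut completed)?;
+    ///     for slot in completed.drain(..) {
+    ///         process(reader.get_slot_buf(slot));
+    ///         reader.release_slot(slot);
+    ///     }
+    /// }
+    /// ```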
+ pub fn has_free_slot(&self) -> bool { + !self.free_slots.is_empty() + } + + /// Returns the number of submitted but not yet completed reads. + pub fn in_flight_count(&self) -> usize { + self.in_flight + } + + /// Returns the slot size in bytes. + pub fn slot_size(&self) -> usize { + self.slot_size + } + + /// Returns the maximum number of buffer slots. + pub fn max_slots(&self) -> usize { + self.max_slots + } + + /// Reset the reader for reuse: drain all in-flight IOs, release all + /// completed slots, then restore every slot to `Free`. + pub fn reset(&mut self) { + self.drain_all(); + } + + /// Drain all in-flight IOs, blocking until they complete, then reset all + /// slot states to `Free`. + /// + /// On transient errors (`EINTR`) retries automatically. On unrecoverable + /// errors aborts the process — deallocating the buffer while the kernel + /// still holds DMA references would cause memory corruption. + fn drain_all(&mut self) { + let mut remaining = self.in_flight; + while remaining > 0 { + match self.ring.submit_and_wait(remaining) { + Ok(_) => {} + Err(ref e) if e.raw_os_error() == Some(libc::EINTR) => continue, + Err(_) => { + // Cannot safely deallocate while kernel may have DMA refs. + std::process::abort(); + } + } + for cqe in self.ring.completion() { + let _ = cqe; + remaining = remaining.saturating_sub(1); + } + } + self.in_flight = 0; + for state in &mut self.slot_states { + *state = SlotState::Free; + } + self.free_slots.clear(); + self.free_slots.extend(0..self.max_slots); + } +} + +impl Drop for PipelinedReader { + fn drop(&mut self) { + // Must wait for all in-flight kernel IOs to complete before the + // allocation backing `_slot_bufs` is freed. + self.drain_all(); + } +} + +#[cfg(test)] +mod tests { + use super::*; + use rand::rngs::StdRng; + use rand::{Rng, SeedableRng}; + use std::io::Write; + + const SECTOR: usize = 4096; + + /// Create a temp file with `n_sectors` sectors of known data. + /// Each sector is filled with the byte `(sector_index & 0xFF) as u8`. + fn make_test_file(n_sectors: usize) -> tempfile::NamedTempFile { + let mut f = tempfile::NamedTempFile::new().expect("create tempfile"); + for i in 0..n_sectors { + let byte = (i & 0xFF) as u8; + f.write_all(&vec![byte; SECTOR]).expect("write sector"); + } + f.flush().expect("flush"); + f + } + + /// Create a reader backed by a temp file. Returns both so the file + /// outlives the reader. + fn make_reader( + n_sectors: usize, + max_slots: usize, + ) -> (tempfile::NamedTempFile, PipelinedReader) { + let file = make_test_file(n_sectors); + let reader = PipelinedReader::new( + file.path().to_str().unwrap(), + max_slots, + SECTOR, + SECTOR, + &PipelinedReaderConfig::default(), + ) + .unwrap(); + (file, reader) + } + + /// Enqueue reads for `sectors`, flush, wait for all completions. + /// Returns the slot IDs in enqueue order. + fn enqueue_flush_wait( + reader: &mut PipelinedReader, + sectors: impl IntoIterator, + ) -> Vec { + let mut slots = Vec::new(); + for s in sectors { + slots.push(reader.enqueue_read((s * SECTOR) as u64).unwrap()); + } + reader.flush().unwrap(); + drain_all_completions(reader); + slots + } + + /// Wait until all in-flight IOs complete. + fn drain_all_completions(reader: &mut PipelinedReader) { + let mut buf = Vec::new(); + while reader.in_flight_count() > 0 { + reader.wait_completions(&mut buf).unwrap(); + } + } + + /// Assert that a completed slot contains the expected fill byte for a + /// given sector index (test files fill sector N with byte N & 0xFF). 
+ fn assert_sector_data(reader: &PipelinedReader, slot: usize, sector: usize) { + let buf = reader.get_slot_buf(slot); + let expected = (sector & 0xFF) as u8; + assert!( + buf.iter().all(|&b| b == expected), + "slot {slot} (sector {sector}): expected 0x{expected:02x}, got 0x{:02x}", + buf[0], + ); + } + + // =================================================================== + // Unit tests — each tests a single API behavior + // =================================================================== + + #[test] + fn slot_lifecycle_round_trip() { + let (_f, mut reader) = make_reader(4, 4); + + // Enqueue → flush → wait → get_buf → release + let slot = reader.enqueue_read(0).unwrap(); + assert_eq!(reader.slot_states[slot], SlotState::InFlight); + + reader.flush().unwrap(); + drain_all_completions(&mut reader); + assert_eq!(reader.slot_states[slot], SlotState::Completed); + + assert_sector_data(&reader, slot, 0); + reader.release_slot(slot); + assert_eq!(reader.slot_states[slot], SlotState::Free); + + // Reuse the slot for a different sector + let slots = enqueue_flush_wait(&mut reader, [1]); + assert_sector_data(&reader, slots[0], 1); + reader.release_slot(slots[0]); + } + + #[test] + fn slot_exhaustion_returns_error() { + let (_f, mut reader) = make_reader(8, 4); + for i in 0..4 { + reader.enqueue_read((i * SECTOR) as u64).unwrap(); + } + assert!(reader.enqueue_read(0).is_err()); + } + + #[test] + #[should_panic(expected = "not Completed")] + fn double_release_panics() { + let (_f, mut reader) = make_reader(1, 2); + let slots = enqueue_flush_wait(&mut reader, [0]); + reader.release_slot(slots[0]); + reader.release_slot(slots[0]); // should panic + } + + #[test] + #[should_panic(expected = "not Completed")] + fn get_buf_on_free_slot_panics() { + let (_f, reader) = make_reader(1, 2); + reader.get_slot_buf(0); + } + + #[test] + #[should_panic(expected = "not Completed")] + fn get_buf_on_inflight_slot_panics() { + let (_f, mut reader) = make_reader(1, 2); + let slot = reader.enqueue_read(0).unwrap(); + reader.flush().unwrap(); + reader.get_slot_buf(slot); // still InFlight + } + + #[test] + fn drop_drains_in_flight() { + let (_f, mut reader) = make_reader(4, 4); + for i in 0..4 { + reader.enqueue_read((i * SECTOR) as u64).unwrap(); + } + reader.flush().unwrap(); + drop(reader); // must not panic or leak + } + + #[test] + fn data_integrity_multi_slot() { + let (_f, mut reader) = make_reader(8, 4); + let slots = enqueue_flush_wait(&mut reader, 0..4); + for (slot, sector) in slots.iter().zip(0..4) { + assert_sector_data(&reader, *slot, sector); + reader.release_slot(*slot); + } + } + + #[test] + fn reset_clears_all_state() { + let (_f, mut reader) = make_reader(4, 4); + enqueue_flush_wait(&mut reader, [0, 1]); + reader.enqueue_read(2 * SECTOR as u64).unwrap(); + reader.flush().unwrap(); + + reader.reset(); + assert_eq!(reader.in_flight, 0); + assert_eq!(reader.free_slots.len(), 4); + assert!(reader.slot_states.iter().all(|&s| s == SlotState::Free)); + } + + #[test] + fn poll_and_wait_return_empty_when_idle() { + let (_f, mut reader) = make_reader(1, 2); + let mut buf = Vec::new(); + reader.poll_completions(&mut buf).unwrap(); + assert!(buf.is_empty()); + reader.wait_completions(&mut buf).unwrap(); + assert!(buf.is_empty()); + } + + #[test] + fn short_read_detected_as_error() { + let mut f = tempfile::NamedTempFile::new().unwrap(); + f.write_all(&vec![0xABu8; 512]).unwrap(); // < SECTOR + f.flush().unwrap(); + + let mut reader = PipelinedReader::new( + f.path().to_str().unwrap(), + 1, + SECTOR, + 
SECTOR,
+            &PipelinedReaderConfig::default(),
+        )
+        .unwrap();
+        reader.enqueue_read(0).unwrap();
+        reader.flush().unwrap();
+
+        let mut completed = Vec::new();
+        let result = reader.wait_completions(&mut completed);
+        assert!(result.is_err(), "short read should be detected");
+        assert!(completed.is_empty());
+    }
+
+    #[test]
+    fn drop_with_unflushed_sqes() {
+        let (_f, mut reader) = make_reader(8, 8);
+        for i in 0..8 {
+            reader.enqueue_read((i * SECTOR) as u64).unwrap();
+        }
+        // Enqueued but never flushed — drain_all's submit_and_wait handles it
+        drop(reader);
+    }
+
+    // ===================================================================
+    // Stress tests — exercise the state machine at scale
+    // ===================================================================
+
+    /// Randomized state-machine fuzzer using seeded RNG for reproducibility.
+    /// Exercises random interleavings of enqueue, flush, poll, wait, release,
+    /// and reset with data verification.
+    #[test]
+    fn stress_random_slot_lifecycle() {
+        let (_f, mut reader) = make_reader(256, 16);
+        let mut rng = StdRng::seed_from_u64(0xDEAD_BEEF);
+        let mut pending_completed: Vec<usize> = Vec::new();
+        let mut total_verified = 0u64;
+
+        for _ in 0..2000 {
+            match rng.random_range(0u32..100) {
+                0..40 => {
+                    if reader.has_free_slot() {
+                        let sector = rng.random_range(0usize..256);
+                        reader.enqueue_read((sector * SECTOR) as u64).unwrap();
+                    }
+                }
+                40..55 => {
+                    reader.flush().unwrap();
+                }
+                55..70 => {
+                    let mut buf = Vec::new();
+                    reader.poll_completions(&mut buf).unwrap();
+                    pending_completed.extend_from_slice(&buf);
+                }
+                70..80 => {
+                    if reader.in_flight_count() > 0 {
+                        reader.flush().unwrap();
+                        let mut buf = Vec::new();
+                        reader.wait_completions(&mut buf).unwrap();
+                        pending_completed.extend_from_slice(&buf);
+                    }
+                }
+                80..95 => {
+                    if let Some(slot) = pending_completed.pop() {
+                        let buf = reader.get_slot_buf(slot);
+                        let first = buf[0];
+                        assert!(
+                            buf.iter().all(|&b| b == first),
+                            "data corruption in slot {slot}"
+                        );
+                        reader.release_slot(slot);
+                        total_verified += 1;
+                    }
+                }
+                _ => {
+                    pending_completed.clear();
+                    reader.reset();
+                }
+            }
+        }
+
+        // Cleanup: flush + drain remaining
+        reader.flush().unwrap();
+        let mut buf = Vec::new();
+        while reader.in_flight_count() > 0 {
+            reader.wait_completions(&mut buf).unwrap();
+            for &slot in &buf {
+                let data = reader.get_slot_buf(slot);
+                assert!(data.iter().all(|&b| b == data[0]));
+                reader.release_slot(slot);
+                total_verified += 1;
+            }
+        }
+        for &slot in &pending_completed {
+            let data = reader.get_slot_buf(slot);
+            assert!(data.iter().all(|&b| b == data[0]));
+            reader.release_slot(slot);
+            total_verified += 1;
+        }
+        assert!(total_verified > 0, "stress test verified zero reads");
+    }
+
+    /// Saturate all slots, drain, repeat — catches off-by-one in free-list.
+    #[test]
+    fn stress_saturate_and_drain_cycles() {
+        let max_slots = 32;
+        let (_f, mut reader) = make_reader(max_slots, max_slots);
+
+        for cycle in 0..100 {
+            let sectors: Vec<usize> = (0..max_slots)
+                .map(|i| (cycle * max_slots + i) % max_slots)
+                .collect();
+            let slots = enqueue_flush_wait(&mut reader, sectors.iter().copied());
+            assert!(reader.enqueue_read(0).is_err());
+
+            for (slot, &sector) in slots.iter().zip(sectors.iter()) {
+                assert_sector_data(&reader, *slot, sector);
+                reader.release_slot(*slot);
+            }
+        }
+    }
+
+    /// 1-slot reader: max state transitions per slot.
+ #[test] + fn stress_single_slot_rapid_reuse() { + let n_sectors = 64; + let (_f, mut reader) = make_reader(n_sectors, 1); + + for i in 0..500 { + let sector = i % n_sectors; + let slots = enqueue_flush_wait(&mut reader, [sector]); + assert_sector_data(&reader, slots[0], sector); + reader.release_slot(slots[0]); + } + } + + /// Drop with 0, 1, 2, … max_slots in-flight IOs. + #[test] + fn stress_drop_at_various_inflight_counts() { + let max_slots = 16; + for inflight in 0..=max_slots { + let (_f, mut reader) = make_reader(max_slots, max_slots); + for i in 0..inflight { + reader.enqueue_read((i * SECTOR) as u64).unwrap(); + } + if inflight > 0 { + reader.flush().unwrap(); + } + drop(reader); + } + } + + /// Read every sector in a 256-sector file through 8 slots, verify all. + #[test] + fn stress_full_file_sequential_scan() { + let n_sectors = 256; + let max_slots = 8; + let (_f, mut reader) = make_reader(n_sectors, max_slots); + + let mut sectors_verified = vec![false; n_sectors]; + let mut slot_to_sector = [0usize; 128]; + let mut next_sector = 0usize; + let mut buf = Vec::new(); + + while next_sector < n_sectors || reader.in_flight_count() > 0 { + while next_sector < n_sectors && reader.has_free_slot() { + let slot = reader.enqueue_read((next_sector * SECTOR) as u64).unwrap(); + slot_to_sector[slot] = next_sector; + next_sector += 1; + } + reader.flush().unwrap(); + + reader.wait_completions(&mut buf).unwrap(); + for &slot in &buf { + let sector = slot_to_sector[slot]; + assert_sector_data(&reader, slot, sector); + sectors_verified[sector] = true; + reader.release_slot(slot); + } + } + + assert!( + sectors_verified.iter().all(|&v| v), + "not all sectors verified" + ); + } +} diff --git a/diskann-platform/Cargo.toml b/diskann-platform/Cargo.toml index 07bbd3a33..5eac8fe68 100644 --- a/diskann-platform/Cargo.toml +++ b/diskann-platform/Cargo.toml @@ -15,7 +15,7 @@ documentation.workspace = true tracing.workspace = true [target.'cfg(target_os = "linux")'.dependencies] -io-uring = "0.6.4" +io-uring = "0.7" libc = "0.2.148" [target.'cfg(target_os = "windows")'.dependencies.windows-sys] diff --git a/diskann-tools/src/utils/search_disk_index.rs b/diskann-tools/src/utils/search_disk_index.rs index 1a2936f5e..575b18280 100644 --- a/diskann-tools/src/utils/search_disk_index.rs +++ b/diskann-tools/src/utils/search_disk_index.rs @@ -389,7 +389,7 @@ where _span.set_attribute(KeyValue::new("latency_95", latency_95 as f64)); _span.set_attribute(KeyValue::new("mean_cpus", mean_cpus)); _span.set_attribute(KeyValue::new("mean_io_time", mean_io_time)); - _span.set_attribute(KeyValue::new("mean_ios", mean_ios as f64)); + _span.set_attribute(KeyValue::new("mean_ios", mean_ios)); _span.set_attribute(KeyValue::new("mean_comps", mean_comps)); _span.set_attribute(KeyValue::new("mean_hops", mean_hops)); _span.set_attribute(KeyValue::new("recall", recall as f64)); diff --git a/diskann/src/graph/glue.rs b/diskann/src/graph/glue.rs index 411a97031..581cc9838 100644 --- a/diskann/src/graph/glue.rs +++ b/diskann/src/graph/glue.rs @@ -255,6 +255,81 @@ pub trait ExpandBeam: BuildQueryComputer + AsNeighbor + Sized where T: ?Sized, { + /// Submit IDs for expansion. + /// + /// For non-pipelined providers (default), this is a no-op — IDs are passed + /// directly to [`expand_beam`] in [`expand_available`]. For pipelined providers, + /// this submits non-blocking IO requests. Any IDs that could not be submitted + /// (e.g., no free IO slots) are returned so the caller can revert their state. 
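+    ///
+    /// The unified search loop drives these hooks roughly as follows (illustrative
+    /// sketch of the calling protocol; `pick_candidates` and `revert` stand in for
+    /// the priority-queue bookkeeping, and loop termination is elided):
+    ///
+    /// ```ignore
+    /// loop {
+    ///     // Expand whatever has already arrived (non-blocking for pipelined providers).
+    ///     accessor
+    ///         .expand_available(beam.iter().copied(), computer, pred, on_neighbor, &mut expanded)
+    ///         .await?;
+    ///     // Top the pipeline back up to `beam_width` outstanding requests.
+    ///     let budget = beam_width.saturating_sub(accessor.inflight_count());
+    ///     let rejected = accessor.submit_expand(pick_candidates(budget));
+    ///     revert(rejected);
+    ///     // Block only when nothing was expanded but IO is still pending.
+    ///     if expanded.is_empty() && accessor.has_pending() {
+    ///         accessor.wait_for_io()?;
+    ///     }
+    /// }
+    /// ```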
+ fn submit_expand(&mut self, _ids: impl Iterator + Send) -> Vec { + Vec::new() // Default: all accepted + } + + /// Expand nodes whose data is available, invoking `on_neighbors` for each discovered + /// neighbor. + /// + /// For non-pipelined providers (default), this expands all the `ids` passed in + /// synchronously via [`expand_beam`]. For pipelined providers, this polls for + /// completed IO operations and expands only the nodes whose data has arrived, + /// returning immediately without blocking. + /// + /// The IDs of nodes actually expanded are written into `expanded_ids`. + fn expand_available( + &mut self, + ids: impl Iterator + Send, + computer: &Self::QueryComputer, + pred: P, + on_neighbors: F, + expanded_ids: &mut Vec, + ) -> impl std::future::Future> + Send + where + P: HybridPredicate + Send + Sync, + F: FnMut(f32, Self::Id) + Send, + { + async move { + expanded_ids.clear(); + expanded_ids.extend(ids); + self.expand_beam(expanded_ids.iter().copied(), computer, pred, on_neighbors) + .await?; + Ok(()) + } + } + + /// Returns true if there are submitted but not-yet-expanded nodes pending. + /// + /// For non-pipelined providers (default), this always returns `false` since + /// [`expand_available`] processes everything synchronously. Pipelined providers + /// return `true` when IO operations are in-flight. + fn has_pending(&self) -> bool { + false + } + + /// Returns the number of IOs currently in-flight (submitted but not completed). + /// + /// The search loop uses this to cap submissions at `beam_width`. + /// Default: 0 (non-pipelined providers have no in-flight IO). + fn inflight_count(&self) -> usize { + 0 + } + + /// Block until at least one IO completes, then eagerly drain all available. + /// + /// Called by the search loop only when it cannot make progress: nothing was + /// submitted (no candidates or inflight cap reached) AND nothing was expanded + /// (no completions available). Blocking here yields the CPU thread instead of + /// spin-polling, while the eager drain ensures we process bursts efficiently. + /// + /// Default: no-op (non-pipelined providers never need to wait). + fn wait_for_io(&mut self) -> ANNResult<()> { + Ok(()) + } + + /// Expand all `ids` synchronously: load data, get neighbors, compute distances. + /// + /// This is the original single-shot expansion method. For non-pipelined providers, + /// the default [`expand_available`] delegates to this. Pipelined providers may + /// override [`submit_expand`] and [`expand_available`] instead and leave this as + /// the default. 
fn expand_beam( &mut self, ids: Itr, diff --git a/diskann/src/graph/index.rs b/diskann/src/graph/index.rs index ea48adc0b..ad99ecbce 100644 --- a/diskann/src/graph/index.rs +++ b/diskann/src/graph/index.rs @@ -397,7 +397,6 @@ where for attempt in 0..num_insert_attempts { let mut search_record = VisitedSearchRecord::new(self.estimate_visited_set_capacity(Some(search_l))); - self.search_internal( None, // beam_width &start_ids, @@ -521,7 +520,6 @@ where let mut search_record = VisitedSearchRecord::new( self.estimate_visited_set_capacity(Some(scratch.best.search_l())), ); - self.search_internal( None, // beam_width &start_ids, @@ -1329,7 +1327,6 @@ where let start_ids = search_accessor.starting_points().await?; let mut scratch = self.search_scratch(l_value, start_ids.len()); - self.search_internal( None, // beam_width &start_ids, @@ -2094,38 +2091,63 @@ where } } - let mut neighbors = Vec::with_capacity(self.max_degree_with_slack()); - while scratch.best.has_notvisited_node() && !accessor.terminate_early() { - scratch.beam_nodes.clear(); + scratch.neighbors.clear(); - // In this loop we are going to find the beam_width number of nodes that are closest to the query. - // Each of these nodes will be a frontier node. - while scratch.best.has_notvisited_node() && scratch.beam_nodes.len() < beam_width { - let closest_node = scratch.best.closest_notvisited(); - search_record.record(closest_node, scratch.hops, scratch.cmps); - scratch.beam_nodes.push(closest_node.id); - } + let mut expanded_ids = Vec::new(); - neighbors.clear(); + while (scratch.best.has_notvisited_node() + || scratch.best.peek_best_unsubmitted().is_some() + || accessor.has_pending()) + && !accessor.terminate_early() + { + // Phase 1: Expand nodes whose data is available. + // Non-pipelined: synchronously expands beam_nodes from previous submit. + // Pipelined: polls IO completions and expands one loaded node. + // On the first iteration beam_nodes is empty — a no-op for both paths. + scratch.neighbors.clear(); accessor - .expand_beam( + .expand_available( scratch.beam_nodes.iter().copied(), computer, glue::NotInMut::new(&mut scratch.visited), - |distance, id| neighbors.push(Neighbor::new(id, distance)), + |distance, id| scratch.neighbors.push(Neighbor::new(id, distance)), + &mut expanded_ids, ) .await?; - // The predicate ensures that the contents of `neighbors` are unique. - // - // We insert into the priority queue outside of the expansion for - // code-locality purposes. - neighbors + for &id in &expanded_ids { + scratch.best.mark_visited_by_id(&id); + } + + scratch + .neighbors .iter() .for_each(|neighbor| scratch.best.insert(*neighbor)); + scratch.cmps += scratch.neighbors.len() as u32; + scratch.hops += expanded_ids.len() as u32; - scratch.cmps += neighbors.len() as u32; - scratch.hops += scratch.beam_nodes.len() as u32; + // Phase 2: Select and submit candidates to fill the pipeline. + // Non-pipelined: inflight is always 0, so this submits beam_width nodes. + // Pipelined: submits enough to keep beam_width IOs in flight. 
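+            // Example: with beam_width = 4 and 3 IOs already in flight, the budget
+            // below is 4 - 3 = 1, so a single candidate is popped and submitted
+            // this iteration.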
+ scratch.beam_nodes.clear(); + let slots = beam_width.saturating_sub(accessor.inflight_count()); + while scratch.beam_nodes.len() < slots { + if let Some(closest_node) = scratch.best.pop_best_unsubmitted() { + search_record.record(closest_node, scratch.hops, scratch.cmps); + scratch.beam_nodes.push(closest_node.id); + } else { + break; + } + } + let rejected = accessor.submit_expand(scratch.beam_nodes.iter().copied()); + for id in rejected { + scratch.best.revert_submitted(&id); + } + + // Phase 3: Block only when no progress was made but IOs are pending. + if expanded_ids.is_empty() && accessor.has_pending() { + accessor.wait_for_io()?; + } } Ok(InternalSearchStats { @@ -2156,8 +2178,6 @@ where scratch.range_frontier.push_back(neighbor.id); } - let mut neighbors = Vec::with_capacity(self.max_degree_with_slack()); - let max_returned = search_params.max_returned.unwrap_or(usize::MAX); while !scratch.range_frontier.is_empty() { @@ -2172,18 +2192,18 @@ where } } - neighbors.clear(); + scratch.neighbors.clear(); accessor .expand_beam( scratch.beam_nodes.iter().copied(), computer, glue::NotInMut::new(&mut scratch.visited), - |distance, id| neighbors.push(Neighbor::new(id, distance)), + |distance, id| scratch.neighbors.push(Neighbor::new(id, distance)), ) .await?; // The predicate ensure that the contents of `neighbors` are unique. - for neighbor in neighbors.iter() { + for neighbor in scratch.neighbors.iter() { if neighbor.distance <= search_params.radius * search_params.range_search_slack && scratch.in_range.len() < max_returned { @@ -2191,7 +2211,7 @@ where scratch.range_frontier.push_back(neighbor.id); } } - scratch.cmps += neighbors.len() as u32; + scratch.cmps += scratch.neighbors.len() as u32; scratch.hops += scratch.beam_nodes.len() as u32; } @@ -2614,7 +2634,6 @@ where let start_ids = accessor.starting_points().await?; let mut scratch = self.search_scratch(search_params.starting_l_value, start_ids.len()); - let initial_stats = self .search_internal( search_params.beam_width, @@ -3644,7 +3663,7 @@ where SearchScratch { best: diverse_queue, visited: HashSet::with_capacity(self.estimate_visited_set_capacity(Some(l_value))), - id_scratch: Vec::with_capacity(self.max_degree_with_slack()), + neighbors: Vec::with_capacity(self.max_degree_with_slack()), beam_nodes: Vec::with_capacity(beam_width.unwrap_or(1)), range_frontier: std::collections::VecDeque::new(), in_range: Vec::new(), diff --git a/diskann/src/graph/search/scratch.rs b/diskann/src/graph/search/scratch.rs index 2a4706821..75ca5f54d 100644 --- a/diskann/src/graph/search/scratch.rs +++ b/diskann/src/graph/search/scratch.rs @@ -20,12 +20,9 @@ pub const GRAPH_SLACK_FACTOR: f64 = 1.3_f64; /// Scratch space used during graph search. /// -/// This struct contains three important members used by both the sync and async indexes: -/// `query`, `best`, and `visited`. -/// -/// The member `id_scratch` is only used by the sync index. -/// -/// Members `labels` and `beta` are used by the async index for beta-filtered search. +/// This struct holds reusable buffers that are cleared between searches but retain their +/// heap allocations. The key members are `best` (priority queue), `visited` (dedup set), +/// and `neighbors`/`submitted` (per-hop buffers used by the search loop). #[derive(Debug)] pub struct SearchScratch> where @@ -34,26 +31,16 @@ where /// A priority queue of the best candidates seen during search. This data structure is /// also responsible for determining the best unvisited candidate. /// - /// Used by both sync and async. 
- /// /// When used in a paged search context, this queue is unbounded. pub best: Q, /// A record of all ids visited during a search. /// - /// Used by both sync and async. - /// /// This is used to prevent multiple requests to the same `id` from the vector providers. pub visited: HashSet, - /// A buffer for adjacency lists. - /// - /// Only used by sync. - /// - /// Adjacency lists in the sync provider are guarded by read/write locks. The - /// `id_scratch` is used to copy out the contents of an adjacency list to minimize the - /// duration the lock is held. - pub id_scratch: Vec, + /// A reusable buffer for collecting neighbor distances during expansion. + pub neighbors: Vec>, /// A list of beam search nodes used during search. This is used when beam search is enabled /// to temporarily hold beam of nodes in each hop. @@ -123,7 +110,7 @@ where Self { best, visited, - id_scratch: Vec::new(), + neighbors: Vec::new(), beam_nodes: Vec::new(), in_range: Vec::new(), range_frontier: VecDeque::new(), @@ -147,7 +134,7 @@ where pub fn clear(&mut self) { self.best.clear(); self.visited.clear(); - self.id_scratch.clear(); + self.neighbors.clear(); self.beam_nodes.clear(); self.in_range.clear(); self.range_frontier.clear(); @@ -244,7 +231,7 @@ mod tests { assert_eq!(x.visited.capacity(), 0); assert!(x.visited.is_empty()); - assert!(x.id_scratch.is_empty()); + assert!(x.neighbors.is_empty()); assert!(x.hops == 0); assert!(x.cmps == 0); @@ -262,7 +249,7 @@ mod tests { assert_eq!(x.visited.capacity(), 0); assert!(x.visited.is_empty()); - assert!(x.id_scratch.is_empty()); + assert!(x.neighbors.is_empty()); assert!(x.hops == 0); assert!(x.cmps == 0); @@ -299,8 +286,8 @@ mod tests { x.visited.insert(1); x.visited.insert(10); - x.id_scratch.push(1); - x.id_scratch.push(10); + x.neighbors.push(Neighbor::new(1, 1.0)); + x.neighbors.push(Neighbor::new(10, 2.0)); x.best.insert(Neighbor::new(1, 1.0)); x.best.insert(Neighbor::new(10, 2.0)); @@ -309,7 +296,7 @@ mod tests { // Do the clear. 
x.clear(); assert!(x.visited.is_empty()); - assert!(x.id_scratch.is_empty()); + assert!(x.neighbors.is_empty()); assert_eq!(x.best.size(), 0); assert!(x.hops == 0); diff --git a/diskann/src/neighbor/diverse_priority_queue.rs b/diskann/src/neighbor/diverse_priority_queue.rs index b5373ad44..8d8609916 100644 --- a/diskann/src/neighbor/diverse_priority_queue.rs +++ b/diskann/src/neighbor/diverse_priority_queue.rs @@ -12,7 +12,7 @@ use std::{ }; use crate::neighbor::{ - Neighbor, + Neighbor, NodeState, queue::{ BestCandidatesIterator, NeighborPriorityQueue, NeighborPriorityQueueIdType, NeighborQueue, }, @@ -260,6 +260,60 @@ where let sz = self.global_queue.search_l().min(self.global_queue.size()); BestCandidatesIterator::new(sz, self) } + + fn peek_best_unsubmitted(&self) -> Option> { + self.global_queue + .peek_best_unsubmitted() + .map(|n| Neighbor::new(n.id.id, n.distance)) + } + + fn pop_best_unsubmitted(&mut self) -> Option> { + self.global_queue + .pop_best_unsubmitted() + .map(|n| Neighbor::new(n.id.id, n.distance)) + } + + fn mark_visited_by_id(&mut self, id: &P::Id) -> bool { + let limit = self.global_queue.search_l().min(self.global_queue.size()); + for i in self.global_queue.cursor..limit { + if self.global_queue.get(i).id.id == *id { + self.global_queue.set_state(i, NodeState::Visited); + // Advance cursor past consecutive Visited nodes + if i == self.global_queue.cursor { + while self.global_queue.cursor < limit + && self.global_queue.get_state(self.global_queue.cursor) + == NodeState::Visited + { + self.global_queue.cursor += 1; + } + } + return true; + } + } + false + } + + fn mark_submitted(&mut self, id: &P::Id) -> bool { + let limit = self.global_queue.search_l().min(self.global_queue.size()); + for i in self.global_queue.cursor..limit { + if self.global_queue.get(i).id.id == *id { + self.global_queue.set_state(i, NodeState::Submitted); + return true; + } + } + false + } + + fn revert_submitted(&mut self, id: &P::Id) -> bool { + let limit = self.global_queue.search_l().min(self.global_queue.size()); + for i in self.global_queue.cursor..limit { + if self.global_queue.get(i).id.id == *id { + self.global_queue.set_state(i, NodeState::Unvisited); + return true; + } + } + false + } } /// Trait for providing attribute values for vector IDs. diff --git a/diskann/src/neighbor/mod.rs b/diskann/src/neighbor/mod.rs index 29ee87981..a29b6bb4b 100644 --- a/diskann/src/neighbor/mod.rs +++ b/diskann/src/neighbor/mod.rs @@ -10,7 +10,7 @@ use crate::graph::{SearchOutputBuffer, search_output_buffer}; // Exports mod queue; -pub use queue::{NeighborPriorityQueue, NeighborPriorityQueueIdType, NeighborQueue}; +pub use queue::{NeighborPriorityQueue, NeighborPriorityQueueIdType, NeighborQueue, NodeState}; #[cfg(feature = "experimental_diversity_search")] mod diverse_priority_queue; diff --git a/diskann/src/neighbor/queue.rs b/diskann/src/neighbor/queue.rs index 48453ca2f..917300ccf 100644 --- a/diskann/src/neighbor/queue.rs +++ b/diskann/src/neighbor/queue.rs @@ -8,17 +8,39 @@ use std::marker::PhantomData; use super::Neighbor; +/// Tri-state for nodes in the priority queue. +/// +/// - `Unvisited`: candidate not yet selected for expansion. +/// - `Submitted`: selected and submitted for IO (pipelined) or expansion, but not yet expanded. +/// - `Visited`: fully expanded — neighbors have been processed. 
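+///
+/// Typical transitions, as driven by the queue methods: `Unvisited → Submitted` via
+/// `pop_best_unsubmitted` or `mark_submitted`, `Submitted → Visited` via
+/// `mark_visited_by_id` once the node's data has been expanded, and
+/// `Submitted → Unvisited` via `revert_submitted` when a submission is rejected
+/// (e.g. no free IO slots). The non-pipelined path moves `Unvisited → Visited`
+/// directly through `closest_notvisited`.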
+#[repr(u8)] +#[derive(Clone, Copy, PartialEq, Eq, Debug, Default)] +pub enum NodeState { + #[default] + Unvisited = 0, + Submitted = 1, + Visited = 2, +} + /// Shared trait for type the generic `I` parameter used by the /// `NeighborPeriorityQueue`. pub trait NeighborPriorityQueueIdType: - Default + Eq + Clone + Copy + std::fmt::Debug + std::fmt::Display + Send + Sync + Default + Eq + Clone + Copy + std::fmt::Debug + std::fmt::Display + std::hash::Hash + Send + Sync { } /// Any type that implements all the individual requirements for /// `NeighborPriorityQueueIdType` implements the full trait. impl NeighborPriorityQueueIdType for T where - T: Default + Eq + Clone + Copy + std::fmt::Debug + std::fmt::Display + Send + Sync + T: Default + + Eq + + Clone + + Copy + + std::fmt::Debug + + std::fmt::Display + + std::hash::Hash + + Send + + Sync { } @@ -59,6 +81,35 @@ pub trait NeighborQueue: std::fmt::Debug + Send /// Return an iterator over the best candidates. fn iter(&self) -> Self::Iter<'_>; + + /// Return the first node that is `Unvisited` (not `Submitted` or `Visited`), + /// scanning from the cursor. Does not modify any state. + fn peek_best_unsubmitted(&self) -> Option> { + None + } + + /// Find the first `Unvisited` node, mark it `Submitted`, and return it — single pass. + fn pop_best_unsubmitted(&mut self) -> Option> { + None + } + + /// Find the node with matching `id`, mark it visited, and advance the cursor if needed. + /// Returns true if found and marked, false otherwise. + fn mark_visited_by_id(&mut self, _id: &I) -> bool { + false + } + + /// Transition a node from `Unvisited` to `Submitted`. + /// Returns true if found and transitioned, false otherwise. + fn mark_submitted(&mut self, _id: &I) -> bool { + false + } + + /// Transition a node from `Submitted` back to `Unvisited` (for rejected submissions). + /// Returns true if found and reverted, false otherwise. + fn revert_submitted(&mut self, _id: &I) -> bool { + false + } } /// Neighbor priority Queue based on the distance to the query node @@ -76,11 +127,11 @@ pub struct NeighborPriorityQueue { capacity: usize, /// The current notvisited neighbor whose distance is smallest among all notvisited neighbor - cursor: usize, + pub(crate) cursor: usize, - /// The neighbor (id, visited) collection. + /// The neighbor (id, state) collection. /// These are stored together to make inserts cheaper. 
- id_visiteds: Vec<(I, bool)>, + id_states: Vec<(I, NodeState)>, /// The neighbor distance collection distances: Vec, @@ -101,7 +152,7 @@ impl NeighborPriorityQueue { size: 0, capacity: search_param_l, cursor: 0, - id_visiteds: Vec::with_capacity(search_param_l), + id_states: Vec::with_capacity(search_param_l), distances: Vec::with_capacity(search_param_l), auto_resizable: false, search_param_l, @@ -114,7 +165,7 @@ impl NeighborPriorityQueue { size: 0, capacity: search_param_l, cursor: 0, - id_visiteds: Vec::with_capacity(search_param_l), + id_states: Vec::with_capacity(search_param_l), distances: Vec::with_capacity(search_param_l), auto_resizable: true, search_param_l, @@ -148,17 +199,18 @@ impl NeighborPriorityQueue { }; if self.size == self.capacity { - self.id_visiteds.truncate(self.size - 1); + self.id_states.truncate(self.size - 1); self.distances.truncate(self.size - 1); self.size -= 1; } - self.id_visiteds.insert(insert_idx, (nbr.id, false)); + self.id_states + .insert(insert_idx, (nbr.id, NodeState::Unvisited)); self.distances.insert(insert_idx, nbr.distance); self.size += 1; - debug_assert!(self.size == self.id_visiteds.len()); + debug_assert!(self.size == self.id_states.len()); debug_assert!(self.size == self.distances.len()); if insert_idx < self.cursor { @@ -175,11 +227,11 @@ impl NeighborPriorityQueue { // Copy the first L best candidates to the result vector for (i, res) in result.iter_mut().enumerate().take(extract_size) { - *res = Neighbor::new(self.id_visiteds[i].0, self.distances[i]); + *res = Neighbor::new(self.id_states[i].0, self.distances[i]); } // Remove the first L best candidates from the priority queue - self.id_visiteds.drain(0..extract_size); + self.id_states.drain(0..extract_size); self.distances.drain(0..extract_size); // Update the size and cursor of the priority queue @@ -192,7 +244,7 @@ impl NeighborPriorityQueue { /// Drain candidates from the front, signaling that they have been consumed. 
pub fn drain_best(&mut self, count: usize) { let count = count.min(self.size); - self.id_visiteds.drain(0..count); + self.id_states.drain(0..count); self.distances.drain(0..count); self.size -= count; self.cursor = 0; @@ -224,7 +276,7 @@ impl NeighborPriorityQueue { // Check if we found the exact neighbor (both id and distance must match) if index < self.size && self.get_unchecked(index).id == nbr.id { // Remove the neighbor from both collections - self.id_visiteds.remove(index); + self.id_states.remove(index); self.distances.remove(index); self.size -= 1; @@ -233,7 +285,7 @@ impl NeighborPriorityQueue { self.cursor -= 1; } - debug_assert!(self.size == self.id_visiteds.len()); + debug_assert!(self.size == self.id_states.len()); debug_assert!(self.size == self.distances.len()); return true; @@ -301,7 +353,7 @@ impl NeighborPriorityQueue { /// Get the neighbor at index - SAFETY: index must be less than size fn get_unchecked(&self, index: usize) -> Neighbor { debug_assert!(index < self.size); - let id = unsafe { self.id_visiteds.get_unchecked(index).0 }; + let id = unsafe { self.id_states.get_unchecked(index).0 }; let distance = unsafe { *self.distances.get_unchecked(index) }; Neighbor::new(id, distance) } @@ -315,11 +367,11 @@ impl NeighborPriorityQueue { /// Get the closest and notvisited neighbor pub fn closest_notvisited(&mut self) -> Neighbor { let current = self.cursor; - self.set_visited(current, true); + self.set_state(current, NodeState::Visited); - // Look for the next notvisited neighbor + // Advance cursor past Visited nodes (stop at Submitted or Unvisited) self.cursor += 1; - while self.cursor < self.size && self.get_visited(self.cursor) { + while self.cursor < self.size && self.get_state(self.cursor) == NodeState::Visited { self.cursor += 1; } self.get_unchecked(current) @@ -352,14 +404,14 @@ impl NeighborPriorityQueue { pub fn reconfigure(&mut self, search_param_l: usize) { self.search_param_l = search_param_l; if search_param_l < self.size { - self.id_visiteds.truncate(search_param_l); + self.id_states.truncate(search_param_l); self.distances.truncate(search_param_l); self.size = search_param_l; self.cursor = self.cursor.min(search_param_l); } else if search_param_l > self.capacity { // Grow the backing store. let additional = search_param_l - self.size; - self.id_visiteds.reserve(additional); + self.id_states.reserve(additional); self.distances.reserve(additional); } self.capacity = search_param_l; @@ -373,7 +425,7 @@ impl NeighborPriorityQueue { /// /// Most of the time, you want `reconfigure`. fn reserve(&mut self, additional: usize) { - self.id_visiteds.reserve(additional); + self.id_states.reserve(additional); self.distances.reserve(additional); self.capacity += additional; } @@ -381,23 +433,21 @@ impl NeighborPriorityQueue { /// Set size (and cursor) to 0. This must be called to reset the queue when reusing /// between searched. 
pub fn clear(&mut self) { - self.id_visiteds.clear(); + self.id_states.clear(); self.distances.clear(); self.size = 0; self.cursor = 0; } - fn set_visited(&mut self, index: usize, flag: bool) { - // SAFETY: index must be less than size - assert!(index <= self.size); + pub(crate) fn set_state(&mut self, index: usize, state: NodeState) { + assert!(index < self.size); assert!(self.size <= self.capacity); - unsafe { self.id_visiteds.get_unchecked_mut(index) }.1 = flag; + unsafe { self.id_states.get_unchecked_mut(index) }.1 = state; } - fn get_visited(&self, index: usize) -> bool { - // SAFETY: index must be less than size + pub(crate) fn get_state(&self, index: usize) -> NodeState { assert!(index < self.size); - unsafe { self.id_visiteds.get_unchecked(index).1 } + unsafe { self.id_states.get_unchecked(index).1 } } /// Return whether or not the queue is auto resizeable (for paged search). @@ -414,7 +464,7 @@ impl NeighborPriorityQueue { fn dbgassert_unique_insert(&self, id: I) { for i in 0..self.size { debug_assert!( - self.id_visiteds[i].0 != id, + self.id_states[i].0 != id, "Neighbor with ID {} already exists in the priority queue", id ); @@ -455,11 +505,11 @@ impl NeighborPriorityQueue { // If this item should be kept, move it to write position if f(&neighbor) { if write_idx != read_idx { - self.id_visiteds[write_idx] = self.id_visiteds[read_idx]; + self.id_states[write_idx] = self.id_states[read_idx]; self.distances[write_idx] = self.distances[read_idx]; } - // Reset visited state since compaction invalidates previous state - self.id_visiteds[write_idx].1 = false; + // Reset state since compaction invalidates previous state + self.id_states[write_idx].1 = NodeState::Unvisited; write_idx += 1; } } @@ -479,12 +529,84 @@ impl NeighborPriorityQueue { pub fn truncate(&mut self, len: usize) { let new_size = len; if new_size < self.size { - self.id_visiteds.truncate(new_size); + self.id_states.truncate(new_size); self.distances.truncate(new_size); self.size = new_size; self.cursor = 0; } } + + /// Return the first `Unvisited` node, scanning from cursor. + /// Does not modify any state. + pub fn peek_best_unsubmitted(&self) -> Option> { + let limit = self.search_param_l.min(self.size); + for i in self.cursor..limit { + if self.id_states[i].1 == NodeState::Unvisited { + return Some(Neighbor::new(self.id_states[i].0, self.distances[i])); + } + } + None + } + + /// Find the first `Unvisited` node, mark it `Submitted`, and return it — single pass. + pub fn pop_best_unsubmitted(&mut self) -> Option> { + let limit = self.search_param_l.min(self.size); + for i in self.cursor..limit { + if self.id_states[i].1 == NodeState::Unvisited { + self.id_states[i].1 = NodeState::Submitted; + return Some(Neighbor::new(self.id_states[i].0, self.distances[i])); + } + } + None + } + + /// Find the node with matching `id`, mark it `Visited`, and advance the cursor if needed. + /// Returns true if found and marked, false otherwise. + pub fn mark_visited_by_id(&mut self, id: &I) -> bool { + for i in self.cursor..self.size { + if self.id_states[i].0 == *id { + self.id_states[i].1 = NodeState::Visited; + // If the cursor was pointing at this node, advance past Visited nodes + if self.cursor == i { + self.cursor += 1; + while self.cursor < self.size + && self.get_state(self.cursor) == NodeState::Visited + { + self.cursor += 1; + } + } + return true; + } + } + false + } + + /// Transition a node from `Unvisited` to `Submitted`. + /// Returns true if found and transitioned, false otherwise. 
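+    ///
+    /// Unlike [`mark_visited_by_id`](Self::mark_visited_by_id), this never advances the
+    /// cursor: `has_notvisited_node()` keeps returning `true` while a node is merely
+    /// `Submitted`, which keeps the pipelined search loop running until the IO completes.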
+ pub fn mark_submitted(&mut self, id: &I) -> bool { + let limit = self.search_param_l.min(self.size); + for i in self.cursor..limit { + if self.id_states[i].0 == *id && self.id_states[i].1 == NodeState::Unvisited { + self.id_states[i].1 = NodeState::Submitted; + return true; + } + } + false + } + + /// Transition a node from `Submitted` back to `Unvisited`. + /// Used when submit_expand rejects an ID (no free IO slots). + /// Returns true if found and reverted, false otherwise. + pub fn revert_submitted(&mut self, id: &I) -> bool { + for i in self.cursor..self.size { + if self.id_states[i].0 == *id && self.id_states[i].1 == NodeState::Submitted { + debug_assert!(i >= self.cursor); + self.id_states[i].1 = NodeState::Unvisited; + return true; + } + } + false + } } impl NeighborQueue for NeighborPriorityQueue { @@ -529,6 +651,26 @@ impl NeighborQueue for NeighborPriorityQueue< fn iter(&self) -> Self::Iter<'_> { self.iter() } + + fn peek_best_unsubmitted(&self) -> Option> { + self.peek_best_unsubmitted() + } + + fn pop_best_unsubmitted(&mut self) -> Option> { + self.pop_best_unsubmitted() + } + + fn mark_visited_by_id(&mut self, id: &I) -> bool { + self.mark_visited_by_id(id) + } + + fn mark_submitted(&mut self, id: &I) -> bool { + self.mark_submitted(id) + } + + fn revert_submitted(&mut self, id: &I) -> bool { + self.revert_submitted(id) + } } /// Enable the following syntax for iteration over the valid elements in the queue. @@ -692,23 +834,23 @@ mod neighbor_priority_queue_test { let mut queue = NeighborPriorityQueue::new(3); queue.insert(Neighbor::new(1, 1.0)); queue.insert(Neighbor::new(2, 0.5)); - assert!(!queue.get_visited(0)); + assert!(queue.get_state(0) != NodeState::Visited); queue.insert(Neighbor::new(3, 1.5)); // node id in queue should be [2,1,3] assert!(queue.has_notvisited_node()); let nbr = queue.closest_notvisited(); assert_eq!(nbr.id, 2); assert_eq!(nbr.distance, 0.5); - assert!(queue.get_visited(0)); // super unfortunate test. We know based on above id 2 should be 0th index + assert!(queue.get_state(0) == NodeState::Visited); // super unfortunate test. 
We know based on above id 2 should be 0th index assert!(queue.has_notvisited_node()); let nbr = queue.closest_notvisited(); assert_eq!(nbr.id, 1); assert_eq!(nbr.distance, 1.0); - assert!(queue.get_visited(1)); + assert!(queue.get_state(1) == NodeState::Visited); assert!(queue.has_notvisited_node()); let nbr = queue.closest_notvisited(); assert_eq!(nbr.id, 3); assert_eq!(nbr.distance, 1.5); - assert!(queue.get_visited(2)); + assert!(queue.get_state(2) == NodeState::Visited); assert!(!queue.has_notvisited_node()); } @@ -728,7 +870,7 @@ mod neighbor_priority_queue_test { fn test_reserve() { let mut queue = NeighborPriorityQueue::::new(5); queue.reconfigure(10); - assert_eq!(queue.id_visiteds.len(), 0); + assert_eq!(queue.id_states.len(), 0); assert_eq!(queue.distances.len(), 0); assert_eq!(queue.capacity, 10); } @@ -738,7 +880,7 @@ mod neighbor_priority_queue_test { let mut queue = NeighborPriorityQueue::::new(10); queue.reconfigure(5); assert_eq!(queue.capacity, 5); - assert_eq!(queue.id_visiteds.len(), 0); + assert_eq!(queue.id_states.len(), 0); assert_eq!(queue.distances.len(), 0); queue.reconfigure(11); @@ -752,7 +894,7 @@ mod neighbor_priority_queue_test { assert_eq!(resizable_queue.capacity(), 10); assert_eq!(resizable_queue.size(), 0); assert!(resizable_queue.auto_resizable); - assert_eq!(resizable_queue.id_visiteds.len(), 0); + assert_eq!(resizable_queue.id_states.len(), 0); assert_eq!(resizable_queue.distances.len(), 0); } @@ -1427,4 +1569,230 @@ mod neighbor_priority_queue_test { assert_eq!(queue.size(), 1); assert_eq!(queue.cursor, 0); // cursor is always reset to 0 } + + #[test] + fn test_peek_best_unsubmitted_basic() { + let mut queue = NeighborPriorityQueue::new(5); + queue.insert(Neighbor::new(1, 1.0)); + queue.insert(Neighbor::new(2, 0.5)); + queue.insert(Neighbor::new(3, 1.5)); + // Queue sorted: [2(0.5), 1(1.0), 3(1.5)] + + let result = queue.peek_best_unsubmitted(); + assert!(result.is_some()); + assert_eq!(result.unwrap().id, 2); // closest unvisited, unsubmitted + } + + #[test] + fn test_peek_best_unsubmitted_skips_submitted() { + let mut queue = NeighborPriorityQueue::new(5); + queue.insert(Neighbor::new(1, 1.0)); + queue.insert(Neighbor::new(2, 0.5)); + queue.insert(Neighbor::new(3, 1.5)); + // Queue sorted: [2(0.5), 1(1.0), 3(1.5)] + + queue.mark_submitted(&2); + let result = queue.peek_best_unsubmitted(); + assert!(result.is_some()); + assert_eq!(result.unwrap().id, 1); // 2 is submitted, so next is 1 + } + + #[test] + fn test_peek_best_unsubmitted_skips_visited() { + let mut queue = NeighborPriorityQueue::new(5); + queue.insert(Neighbor::new(1, 1.0)); + queue.insert(Neighbor::new(2, 0.5)); + queue.insert(Neighbor::new(3, 1.5)); + // Queue sorted: [2(0.5), 1(1.0), 3(1.5)] + + queue.closest_notvisited(); // visits 2 + + let result = queue.peek_best_unsubmitted(); + assert!(result.is_some()); + assert_eq!(result.unwrap().id, 1); // 2 is visited, so next is 1 + } + + #[test] + fn test_peek_best_unsubmitted_none_when_all_excluded() { + let mut queue = NeighborPriorityQueue::new(5); + queue.insert(Neighbor::new(1, 1.0)); + queue.insert(Neighbor::new(2, 0.5)); + + queue.mark_submitted(&1); + queue.mark_submitted(&2); + let result = queue.peek_best_unsubmitted(); + assert!(result.is_none()); + } + + #[test] + fn test_peek_best_unsubmitted_respects_search_l() { + let mut queue = NeighborPriorityQueue::auto_resizable_with_search_param_l(2); + queue.insert(Neighbor::new(1, 1.0)); + queue.insert(Neighbor::new(2, 0.5)); + queue.insert(Neighbor::new(3, 1.5)); + 
queue.insert(Neighbor::new(4, 2.0)); + // Queue sorted: [2(0.5), 1(1.0), 3(1.5), 4(2.0)], search_l=2 + + queue.mark_submitted(&2); + queue.mark_submitted(&1); + // Both nodes within search_l window are submitted + let result = queue.peek_best_unsubmitted(); + assert!(result.is_none()); + } + + #[test] + fn test_peek_best_unsubmitted_does_not_modify_state() { + let mut queue = NeighborPriorityQueue::new(5); + queue.insert(Neighbor::new(1, 1.0)); + queue.insert(Neighbor::new(2, 0.5)); + + let _ = queue.peek_best_unsubmitted(); + let _ = queue.peek_best_unsubmitted(); + + // Cursor should still be at 0 (no state modification) + assert_eq!(queue.cursor, 0); + assert!(queue.has_notvisited_node()); + } + + #[test] + fn test_peek_best_unsubmitted_empty_queue() { + let queue = NeighborPriorityQueue::::new(5); + assert!(queue.peek_best_unsubmitted().is_none()); + } + + #[test] + fn test_mark_visited_by_id_basic() { + let mut queue = NeighborPriorityQueue::new(5); + queue.insert(Neighbor::new(1, 1.0)); + queue.insert(Neighbor::new(2, 0.5)); + queue.insert(Neighbor::new(3, 1.5)); + // Queue sorted: [2(0.5), 1(1.0), 3(1.5)] + + assert!(queue.mark_visited_by_id(&1)); + assert_eq!(queue.get_state(1), NodeState::Visited); // id=1 is at index 1 + } + + #[test] + fn test_mark_visited_by_id_not_found() { + let mut queue = NeighborPriorityQueue::new(5); + queue.insert(Neighbor::new(1, 1.0)); + + assert!(!queue.mark_visited_by_id(&99)); + } + + #[test] + fn test_mark_visited_by_id_advances_cursor() { + let mut queue = NeighborPriorityQueue::new(5); + queue.insert(Neighbor::new(1, 1.0)); + queue.insert(Neighbor::new(2, 0.5)); + queue.insert(Neighbor::new(3, 1.5)); + // Queue sorted: [2(0.5), 1(1.0), 3(1.5)], cursor=0 + + // Mark the node at cursor (id=2 at index 0) + assert!(queue.mark_visited_by_id(&2)); + // Cursor should advance past this visited node to index 1 + assert_eq!(queue.cursor, 1); + } + + #[test] + fn test_mark_visited_by_id_cursor_skips_consecutive_visited() { + let mut queue = NeighborPriorityQueue::new(5); + queue.insert(Neighbor::new(1, 1.0)); + queue.insert(Neighbor::new(2, 0.5)); + queue.insert(Neighbor::new(3, 1.5)); + // Queue sorted: [2(0.5), 1(1.0), 3(1.5)], cursor=0 + + // Visit id=1 (index 1) first - cursor stays at 0 + assert!(queue.mark_visited_by_id(&1)); + assert_eq!(queue.cursor, 0); + + // Now visit id=2 (index 0, where cursor is) - cursor should skip past both visited nodes + assert!(queue.mark_visited_by_id(&2)); + assert_eq!(queue.cursor, 2); // skips index 0 (visited) and index 1 (visited) + } + + #[test] + fn test_mark_visited_by_id_does_not_move_cursor_for_non_cursor_node() { + let mut queue = NeighborPriorityQueue::new(5); + queue.insert(Neighbor::new(1, 1.0)); + queue.insert(Neighbor::new(2, 0.5)); + queue.insert(Neighbor::new(3, 1.5)); + // Queue sorted: [2(0.5), 1(1.0), 3(1.5)], cursor=0 + + // Mark id=3 (index 2) as visited - cursor should stay at 0 + assert!(queue.mark_visited_by_id(&3)); + assert_eq!(queue.cursor, 0); + } + + #[test] + fn test_peek_and_mark_workflow() { + let mut queue = NeighborPriorityQueue::new(5); + queue.insert(Neighbor::new(1, 1.0)); + queue.insert(Neighbor::new(2, 0.5)); + queue.insert(Neighbor::new(3, 1.5)); + // Queue sorted: [2(0.5), 1(1.0), 3(1.5)] + + // Peek - should return id=2 + let node = queue.peek_best_unsubmitted().unwrap(); + assert_eq!(node.id, 2); + queue.mark_submitted(&node.id); + + // Peek again - should return id=1 (2 is submitted) + let node = queue.peek_best_unsubmitted().unwrap(); + assert_eq!(node.id, 1); + 
queue.mark_submitted(&node.id); + + // Mark id=2 as visited (IO completed) + assert!(queue.mark_visited_by_id(&2)); + + // Peek - should return id=3 (2 visited, 1 submitted) + let node = queue.peek_best_unsubmitted().unwrap(); + assert_eq!(node.id, 3); + } + + #[test] + fn test_mark_submitted_and_revert() { + let mut queue = NeighborPriorityQueue::new(5); + queue.insert(Neighbor::new(1, 1.0)); + queue.insert(Neighbor::new(2, 0.5)); + // Queue sorted: [2(0.5), 1(1.0)] + + // Mark id=2 as submitted + assert!(queue.mark_submitted(&2)); + assert_eq!(queue.get_state(0), NodeState::Submitted); + + // peek should skip submitted + let node = queue.peek_best_unsubmitted().unwrap(); + assert_eq!(node.id, 1); + + // Revert id=2 back to unvisited + assert!(queue.revert_submitted(&2)); + assert_eq!(queue.get_state(0), NodeState::Unvisited); + + // Now peek should return id=2 again + let node = queue.peek_best_unsubmitted().unwrap(); + assert_eq!(node.id, 2); + } + + #[test] + fn test_cursor_stops_at_submitted() { + let mut queue = NeighborPriorityQueue::new(5); + queue.insert(Neighbor::new(1, 1.0)); + queue.insert(Neighbor::new(2, 0.5)); + queue.insert(Neighbor::new(3, 1.5)); + // Queue sorted: [2(0.5), 1(1.0), 3(1.5)], cursor=0 + + // Mark id=2 as submitted, then visited — cursor should advance past it + // but stop at id=1 (Unvisited) + queue.mark_submitted(&2); + queue.mark_visited_by_id(&2); + assert_eq!(queue.cursor, 1); + + // Mark id=1 as submitted — cursor should NOT advance (Submitted ≠ Visited) + queue.mark_submitted(&1); + assert_eq!(queue.cursor, 1); + + // has_notvisited_node still true (cursor < limit and id=1 is Submitted, not Visited) + assert!(queue.has_notvisited_node()); + } }
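The new `Submitted` state separates "a read has been issued for this node" from "this node's neighbors have been expanded", which is what lets a pipelined search keep its best candidates in flight without the queue treating them as finished. Below is a minimal, self-contained sketch of that lifecycle under stated assumptions: the `ToyQueue`, the two-slot IO budget, and the `main` driver are hypothetical stand-ins for illustration only, not DiskANN's `NeighborPriorityQueue` or its io_uring plumbing.

```rust
// Sketch only: a toy model of the Unvisited -> Submitted -> Visited lifecycle
// that peek_best_unsubmitted / mark_submitted / revert_submitted /
// mark_visited_by_id are designed to support. Types and constants here are
// illustrative, not the real DiskANN implementation.

#[derive(Clone, Copy, PartialEq, Eq, Debug)]
enum NodeState {
    Unvisited,
    Submitted, // read issued, completion not yet processed
    Visited,   // completion processed, neighbors expanded
}

#[derive(Clone, Copy, Debug)]
struct Candidate {
    id: u32,
    distance: f32,
    state: NodeState,
}

// Toy candidate list kept sorted by ascending distance.
struct ToyQueue(Vec<Candidate>);

impl ToyQueue {
    fn insert(&mut self, id: u32, distance: f32) {
        self.0.push(Candidate { id, distance, state: NodeState::Unvisited });
        self.0.sort_by(|a, b| a.distance.total_cmp(&b.distance));
    }
    // Best candidate that is neither submitted nor visited (read-only peek).
    fn peek_best_unsubmitted(&self) -> Option<Candidate> {
        self.0.iter().copied().find(|c| c.state == NodeState::Unvisited)
    }
    fn set_state(&mut self, id: u32, to: NodeState) -> bool {
        if let Some(c) = self.0.iter_mut().find(|c| c.id == id) {
            c.state = to;
            true
        } else {
            false
        }
    }
    fn has_unvisited(&self) -> bool {
        self.0.iter().any(|c| c.state != NodeState::Visited)
    }
}

fn main() {
    let mut queue = ToyQueue(Vec::new());
    for (id, d) in [(1u32, 1.0f32), (2, 0.5), (3, 1.5), (4, 2.0)] {
        queue.insert(id, d);
    }

    // Hypothetical IO budget: at most 2 reads in flight at once.
    const IO_SLOTS: usize = 2;
    let mut in_flight: Vec<u32> = Vec::new();

    while queue.has_unvisited() {
        // 1. Submit the best unsubmitted candidates while slots remain.
        while let Some(best) = queue.peek_best_unsubmitted() {
            queue.set_state(best.id, NodeState::Submitted); // mark_submitted
            if in_flight.len() < IO_SLOTS {
                in_flight.push(best.id); // pretend a disk read was queued
            } else {
                // No free slot: roll back the transition (revert_submitted).
                queue.set_state(best.id, NodeState::Unvisited);
                break;
            }
        }

        // 2. "Complete" one outstanding read and expand its neighbors.
        if let Some(done) = in_flight.pop() {
            queue.set_state(done, NodeState::Visited); // mark_visited_by_id
            println!("expanded node {done}");
            // A real search would insert the expanded node's neighbors here.
        }
    }
}
```

The submit-then-revert pattern is optimistic on purpose: the caller transitions the candidate to `Submitted`, attempts to enqueue the read, and rolls the state back only if the submission is rejected (the "no free IO slots" case named in the `revert_submitted` doc comment). That keeps the queue's bookkeeping consistent with the IO ring without a separate capacity check before every peek.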