diff --git a/README.md b/README.md index e7bb74f..8c086aa 100644 --- a/README.md +++ b/README.md @@ -272,11 +272,11 @@ All three distance functions are **lower-is-closer**: | SQL function | Index metric | Kernel | |---|---|---| -| `l2_distance(a, b)` | `L2sq` | `sqrt(sum((a_i - b_i)^2))` (UDF) / `sum((a_i - b_i)^2)` (index) | +| `l2_distance(a, b)` | `L2sq` | `sum((a_i - b_i)^2)` | | `cosine_distance(a, b)` | `Cos` | `1 - dot(a,b) / (norm(a) * norm(b))` | | `negative_dot_product(a, b)` | `IP` | `-(a . b)` | -Note: `l2_distance` UDF returns actual L2 (with sqrt) for human-readable distances; USearch uses L2sq internally (no sqrt). The sort order is identical. +`l2_distance` returns squared L2 (no sqrt), matching USearch's `MetricKind::L2sq`. This ensures numeric consistency between the UDF, the rewritten index path, and the brute-force path. ### Running tests diff --git a/src/lib.rs b/src/lib.rs index 538752d..5aee261 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -93,7 +93,7 @@ use datafusion::prelude::SessionContext; /// Register all extension components with a DataFusion [`SessionContext`]. /// /// Registers: -/// - `l2_distance(col, query)` — Euclidean distance (L2) +/// - `l2_distance(col, query)` — squared Euclidean distance (L2sq) /// - `cosine_distance(col, query)` — cosine distance /// - `negative_dot_product(col, query)` — negated inner product /// - `vector_usearch(table, query, k)` — explicit ANN table function diff --git a/src/udf.rs b/src/udf.rs index 226c89c..233fcc1 100644 --- a/src/udf.rs +++ b/src/udf.rs @@ -22,12 +22,13 @@ use datafusion::scalar::ScalarValue; type Kernel = fn(&[f32], &[f32]) -> f32; +// Returns L2sq (no sqrt) — matches USearch MetricKind::L2sq and keeps numeric +// values consistent between the UDF path and the optimizer-rewritten index path. fn l2_kernel(a: &[f32], b: &[f32]) -> f32 { a.iter() .zip(b.iter()) .map(|(x, y)| (x - y) * (x - y)) .sum::() - .sqrt() } fn cosine_kernel(a: &[f32], b: &[f32]) -> f32 { diff --git a/tests/execution.rs b/tests/execution.rs index 8d6c5e8..542b5fe 100644 --- a/tests/execution.rs +++ b/tests/execution.rs @@ -19,7 +19,7 @@ use std::sync::Arc; use arrow_array::builder::{FixedSizeListBuilder, Float32Builder}; -use arrow_array::{FixedSizeListArray, RecordBatch, StringArray, UInt64Array}; +use arrow_array::{FixedSizeListArray, Float32Array, RecordBatch, StringArray, UInt64Array}; use arrow_schema::{DataType, Field, Schema}; use datafusion::execution::session_state::SessionStateBuilder; use datafusion::prelude::SessionContext; @@ -453,3 +453,60 @@ async fn exec_parquet_native_where_no_matches() { let ids = collect_ids(&ctx, &sql).await; assert!(ids.is_empty(), "no rows should match; got {ids:?}"); } + +// ═══════════════════════════════════════════════════════════════════════════════ +// Numeric regression — l2_distance must return L2sq (no sqrt) +// ═══════════════════════════════════════════════════════════════════════════════ + +/// l2_distance must return squared L2, not actual L2. +/// Row 1 = [1,0,0,0], query = [1,0,0,0] → L2sq = 0.0 +/// Row 2 = [0,1,0,0], query = [1,0,0,0] → L2sq = 2.0 (L2 would be ~1.414) +#[tokio::test] +async fn exec_l2_distance_returns_l2sq() { + let ctx = make_exec_ctx("items::vector").await; + let sql = + format!("SELECT id, l2_distance(vector, {Q}) AS dist FROM items ORDER BY dist ASC LIMIT 4"); + let df = ctx.sql(&sql).await.expect("sql"); + let batches = df.collect().await.expect("collect"); + + let mut dists: Vec<(u64, f32)> = vec![]; + for batch in &batches { + let id_idx = batch.schema().index_of("id").unwrap(); + let dist_idx = batch.schema().index_of("dist").unwrap(); + let ids = batch + .column(id_idx) + .as_any() + .downcast_ref::() + .unwrap(); + let ds = batch + .column(dist_idx) + .as_any() + .downcast_ref::() + .unwrap(); + for i in 0..batch.num_rows() { + dists.push((ids.value(i), ds.value(i))); + } + } + + // Row 1: exact match → 0.0 + let row1 = dists + .iter() + .find(|(id, _)| *id == 1) + .expect("row 1 missing"); + assert!( + (row1.1 - 0.0).abs() < 1e-6, + "row 1 distance must be 0.0 (L2sq); got {}", + row1.1 + ); + + // Row 2: [0,1,0,0] vs [1,0,0,0] → L2sq = 2.0, NOT sqrt(2) ≈ 1.414 + let row2 = dists + .iter() + .find(|(id, _)| *id == 2) + .expect("row 2 missing"); + assert!( + (row2.1 - 2.0).abs() < 1e-6, + "row 2 distance must be 2.0 (L2sq), not {:.4} (would be ~1.414 if L2)", + row2.1 + ); +}