Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -272,11 +272,11 @@ All three distance functions are **lower-is-closer**:

| SQL function | Index metric | Kernel |
|---|---|---|
| `l2_distance(a, b)` | `L2sq` | `sqrt(sum((a_i - b_i)^2))` (UDF) / `sum((a_i - b_i)^2)` (index) |
| `l2_distance(a, b)` | `L2sq` | `sum((a_i - b_i)^2)` |
| `cosine_distance(a, b)` | `Cos` | `1 - dot(a,b) / (norm(a) * norm(b))` |
| `negative_dot_product(a, b)` | `IP` | `-(a . b)` |

Note: `l2_distance` UDF returns actual L2 (with sqrt) for human-readable distances; USearch uses L2sq internally (no sqrt). The sort order is identical.
`l2_distance` returns squared L2 (no sqrt), matching USearch's `MetricKind::L2sq`. This ensures numeric consistency between the UDF, the rewritten index path, and the brute-force path.

### Running tests

Expand Down
2 changes: 1 addition & 1 deletion src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -93,7 +93,7 @@ use datafusion::prelude::SessionContext;
/// Register all extension components with a DataFusion [`SessionContext`].
///
/// Registers:
/// - `l2_distance(col, query)` — Euclidean distance (L2)
/// - `l2_distance(col, query)` — squared Euclidean distance (L2sq)
/// - `cosine_distance(col, query)` — cosine distance
/// - `negative_dot_product(col, query)` — negated inner product
/// - `vector_usearch(table, query, k)` — explicit ANN table function
Expand Down
3 changes: 2 additions & 1 deletion src/udf.rs
Original file line number Diff line number Diff line change
Expand Up @@ -22,12 +22,13 @@ use datafusion::scalar::ScalarValue;

type Kernel = fn(&[f32], &[f32]) -> f32;

/// Squared Euclidean distance (L2sq) between `a` and `b`.
///
/// Computes `sum((a_i - b_i)^2)` and deliberately does NOT take the square
/// root. The omission is intentional: it matches USearch `MetricKind::L2sq`
/// and keeps numeric values consistent between the UDF path and the
/// optimizer-rewritten index path. Do not add `.sqrt()` back — the sort order
/// would be unchanged, but the returned values would no longer agree.
///
/// Elements are paired with `zip`, so if the slices differ in length the
/// trailing elements of the longer one are ignored. Two empty slices yield
/// `0.0`.
fn l2_kernel(a: &[f32], b: &[f32]) -> f32 {
    a.iter()
        .zip(b.iter())
        .map(|(x, y)| {
            let diff = x - y;
            diff * diff
        })
        .sum::<f32>()
}
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

P2 (non-blocking): No comment here explaining that the sqrt omission is intentional. A future maintainer may add .sqrt() back thinking it's a missing step, silently reintroducing the inconsistency. Consider:

Suggested change
}
// Returns L2sq (no sqrt) — matches USearch MetricKind::L2sq and keeps numeric
// values consistent between the UDF path and the optimizer-rewritten index path.
}


fn cosine_kernel(a: &[f32], b: &[f32]) -> f32 {
Expand Down
59 changes: 58 additions & 1 deletion tests/execution.rs
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
use std::sync::Arc;

use arrow_array::builder::{FixedSizeListBuilder, Float32Builder};
use arrow_array::{FixedSizeListArray, RecordBatch, StringArray, UInt64Array};
use arrow_array::{FixedSizeListArray, Float32Array, RecordBatch, StringArray, UInt64Array};
use arrow_schema::{DataType, Field, Schema};
use datafusion::execution::session_state::SessionStateBuilder;
use datafusion::prelude::SessionContext;
Expand Down Expand Up @@ -453,3 +453,60 @@ async fn exec_parquet_native_where_no_matches() {
let ids = collect_ids(&ctx, &sql).await;
assert!(ids.is_empty(), "no rows should match; got {ids:?}");
}

// ═══════════════════════════════════════════════════════════════════════════════
// Numeric regression — l2_distance must return L2sq (no sqrt)
// ═══════════════════════════════════════════════════════════════════════════════

/// Numeric regression guard: `l2_distance` yields squared L2, never plain L2.
///
/// Expected values against the query vector [1,0,0,0]:
/// - row 1 = [1,0,0,0] → L2sq = 0.0
/// - row 2 = [0,1,0,0] → L2sq = 2.0 (plain L2 would be ~1.414)
#[tokio::test]
async fn exec_l2_distance_returns_l2sq() {
    let ctx = make_exec_ctx("items::vector").await;
    let sql =
        format!("SELECT id, l2_distance(vector, {Q}) AS dist FROM items ORDER BY dist ASC LIMIT 4");
    let batches = ctx
        .sql(&sql)
        .await
        .expect("sql")
        .collect()
        .await
        .expect("collect");

    // Flatten every result batch into (id, distance) pairs.
    let mut dists: Vec<(u64, f32)> = Vec::new();
    for batch in &batches {
        let schema = batch.schema();
        let id_idx = schema.index_of("id").unwrap();
        let dist_idx = schema.index_of("dist").unwrap();
        let id_col = batch
            .column(id_idx)
            .as_any()
            .downcast_ref::<UInt64Array>()
            .unwrap();
        let dist_col = batch
            .column(dist_idx)
            .as_any()
            .downcast_ref::<Float32Array>()
            .unwrap();
        dists.extend((0..batch.num_rows()).map(|row| (id_col.value(row), dist_col.value(row))));
    }

    // Distance recorded for a given id, if that id came back at all.
    let distance_of = |wanted: u64| {
        dists
            .iter()
            .find(|(id, _)| *id == wanted)
            .map(|&(_, dist)| dist)
    };
    // Absolute-tolerance comparison used by both assertions.
    let close_to = |got: f32, want: f32| (got - want).abs() < 1e-6;

    // Row 1 is an exact match for the query → 0.0.
    let row1_dist = distance_of(1).expect("row 1 missing");
    assert!(
        close_to(row1_dist, 0.0),
        "row 1 distance must be 0.0 (L2sq); got {}",
        row1_dist
    );

    // Row 2: [0,1,0,0] vs [1,0,0,0] → L2sq = 2.0, NOT sqrt(2) ≈ 1.414.
    let row2_dist = distance_of(2).expect("row 2 missing");
    assert!(
        close_to(row2_dist, 2.0),
        "row 2 distance must be 2.0 (L2sq), not {:.4} (would be ~1.414 if L2)",
        row2_dist
    );
}
Loading