Skip to content

Commit 3f303d3

Browse files
authored
fix(udf): remove sqrt from l2_distance to match USearch L2sq metric (#8)
l2_distance UDF was computing actual L2 (with sqrt) while USearch and the rewritten execution paths all use L2sq (no sqrt). This caused the same query to return different numeric distance values depending on whether the optimizer rewrote it. Remove sqrt to match USearch's MetricKind::L2sq and DuckDB VSS's array_distance behavior. All paths now return consistent L2sq values.
1 parent 80059cc commit 3f303d3

4 files changed

Lines changed: 63 additions & 5 deletions

File tree

README.md

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -272,11 +272,11 @@ All three distance functions are **lower-is-closer**:
272272

273273
| SQL function | Index metric | Kernel |
274274
|---|---|---|
275-
| `l2_distance(a, b)` | `L2sq` | `sqrt(sum((a_i - b_i)^2))` (UDF) / `sum((a_i - b_i)^2)` (index) |
275+
| `l2_distance(a, b)` | `L2sq` | `sum((a_i - b_i)^2)` |
276276
| `cosine_distance(a, b)` | `Cos` | `1 - dot(a,b) / (norm(a) * norm(b))` |
277277
| `negative_dot_product(a, b)` | `IP` | `-(a . b)` |
278278

279-
Note: `l2_distance` UDF returns actual L2 (with sqrt) for human-readable distances; USearch uses L2sq internally (no sqrt). The sort order is identical.
279+
`l2_distance` returns squared L2 (no sqrt), matching USearch's `MetricKind::L2sq`. This ensures numeric consistency between the UDF, the rewritten index path, and the brute-force path.
280280

281281
### Running tests
282282

src/lib.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -93,7 +93,7 @@ use datafusion::prelude::SessionContext;
9393
/// Register all extension components with a DataFusion [`SessionContext`].
9494
///
9595
/// Registers:
96-
/// - `l2_distance(col, query)` — Euclidean distance (L2)
96+
/// - `l2_distance(col, query)` — squared Euclidean distance (L2sq)
9797
/// - `cosine_distance(col, query)` — cosine distance
9898
/// - `negative_dot_product(col, query)` — negated inner product
9999
/// - `vector_usearch(table, query, k)` — explicit ANN table function

src/udf.rs

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,12 +22,13 @@ use datafusion::scalar::ScalarValue;
2222

2323
type Kernel = fn(&[f32], &[f32]) -> f32;
2424

25+
// Returns L2sq (no sqrt) — matches USearch MetricKind::L2sq and keeps numeric
26+
// values consistent between the UDF path and the optimizer-rewritten index path.
2527
fn l2_kernel(a: &[f32], b: &[f32]) -> f32 {
2628
a.iter()
2729
.zip(b.iter())
2830
.map(|(x, y)| (x - y) * (x - y))
2931
.sum::<f32>()
30-
.sqrt()
3132
}
3233

3334
fn cosine_kernel(a: &[f32], b: &[f32]) -> f32 {

tests/execution.rs

Lines changed: 58 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@
1919
use std::sync::Arc;
2020

2121
use arrow_array::builder::{FixedSizeListBuilder, Float32Builder};
22-
use arrow_array::{FixedSizeListArray, RecordBatch, StringArray, UInt64Array};
22+
use arrow_array::{FixedSizeListArray, Float32Array, RecordBatch, StringArray, UInt64Array};
2323
use arrow_schema::{DataType, Field, Schema};
2424
use datafusion::execution::session_state::SessionStateBuilder;
2525
use datafusion::prelude::SessionContext;
@@ -453,3 +453,60 @@ async fn exec_parquet_native_where_no_matches() {
453453
let ids = collect_ids(&ctx, &sql).await;
454454
assert!(ids.is_empty(), "no rows should match; got {ids:?}");
455455
}
456+
457+
// ═══════════════════════════════════════════════════════════════════════════════
458+
// Numeric regression — l2_distance must return L2sq (no sqrt)
459+
// ═══════════════════════════════════════════════════════════════════════════════
460+
461+
/// l2_distance must return squared L2, not actual L2.
462+
/// Row 1 = [1,0,0,0], query = [1,0,0,0] → L2sq = 0.0
463+
/// Row 2 = [0,1,0,0], query = [1,0,0,0] → L2sq = 2.0 (L2 would be ~1.414)
464+
#[tokio::test]
465+
async fn exec_l2_distance_returns_l2sq() {
466+
let ctx = make_exec_ctx("items::vector").await;
467+
let sql =
468+
format!("SELECT id, l2_distance(vector, {Q}) AS dist FROM items ORDER BY dist ASC LIMIT 4");
469+
let df = ctx.sql(&sql).await.expect("sql");
470+
let batches = df.collect().await.expect("collect");
471+
472+
let mut dists: Vec<(u64, f32)> = vec![];
473+
for batch in &batches {
474+
let id_idx = batch.schema().index_of("id").unwrap();
475+
let dist_idx = batch.schema().index_of("dist").unwrap();
476+
let ids = batch
477+
.column(id_idx)
478+
.as_any()
479+
.downcast_ref::<UInt64Array>()
480+
.unwrap();
481+
let ds = batch
482+
.column(dist_idx)
483+
.as_any()
484+
.downcast_ref::<Float32Array>()
485+
.unwrap();
486+
for i in 0..batch.num_rows() {
487+
dists.push((ids.value(i), ds.value(i)));
488+
}
489+
}
490+
491+
// Row 1: exact match → 0.0
492+
let row1 = dists
493+
.iter()
494+
.find(|(id, _)| *id == 1)
495+
.expect("row 1 missing");
496+
assert!(
497+
(row1.1 - 0.0).abs() < 1e-6,
498+
"row 1 distance must be 0.0 (L2sq); got {}",
499+
row1.1
500+
);
501+
502+
// Row 2: [0,1,0,0] vs [1,0,0,0] → L2sq = 2.0, NOT sqrt(2) ≈ 1.414
503+
let row2 = dists
504+
.iter()
505+
.find(|(id, _)| *id == 2)
506+
.expect("row 2 missing");
507+
assert!(
508+
(row2.1 - 2.0).abs() < 1e-6,
509+
"row 2 distance must be 2.0 (L2sq), not {:.4} (would be ~1.414 if L2)",
510+
row2.1
511+
);
512+
}

0 commit comments

Comments
 (0)