From bcf3fc08a8dbcc2b9ce81537ef4f47c400943f56 Mon Sep 17 00:00:00 2001 From: Chojan Shang Date: Mon, 19 Jan 2026 17:51:20 +0000 Subject: [PATCH 1/3] feat: add kll sketch implementation and tests --- datasketches/src/kll/helper.rs | 126 +++ datasketches/src/kll/mod.rs | 53 ++ datasketches/src/kll/serialization.rs | 48 ++ datasketches/src/kll/sketch.rs | 857 +++++++++++++++++++ datasketches/src/kll/sorted_view.rs | 191 +++++ datasketches/src/lib.rs | 1 + datasketches/tests/kll_serialization_test.rs | 306 +++++++ datasketches/tests/kll_test.rs | 317 +++++++ 8 files changed, 1899 insertions(+) create mode 100644 datasketches/src/kll/helper.rs create mode 100644 datasketches/src/kll/mod.rs create mode 100644 datasketches/src/kll/serialization.rs create mode 100644 datasketches/src/kll/sketch.rs create mode 100644 datasketches/src/kll/sorted_view.rs create mode 100644 datasketches/tests/kll_serialization_test.rs create mode 100644 datasketches/tests/kll_test.rs diff --git a/datasketches/src/kll/helper.rs b/datasketches/src/kll/helper.rs new file mode 100644 index 0000000..4004150 --- /dev/null +++ b/datasketches/src/kll/helper.rs @@ -0,0 +1,126 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. 
use std::cell::Cell;
use std::time::SystemTime;
use std::time::UNIX_EPOCH;

/// Powers of three: `POWERS_OF_THREE[i] == 3^i` for `i` in `0..=30`.
const POWERS_OF_THREE: [u64; 31] = [
    1,
    3,
    9,
    27,
    81,
    243,
    729,
    2187,
    6561,
    19683,
    59049,
    177147,
    531441,
    1594323,
    4782969,
    14348907,
    43046721,
    129140163,
    387420489,
    1162261467,
    3486784401,
    10460353203,
    31381059609,
    94143178827,
    282429536481,
    847288609443,
    2541865828329,
    7625597484987,
    22876792454961,
    68630377364883,
    205891132094649,
];

/// Largest depth accepted by [`int_cap_aux`].
const MAX_DEPTH: u8 = 60;

/// Seed used when the clock-derived seed is zero. A xorshift generator whose
/// state is zero stays at zero forever, so zero must never be used as a seed.
const FALLBACK_SEED: u64 = 0x9E37_79B9_7F4A_7C15;

/// Returns the sum of the capacities of all `num_levels` levels.
pub fn compute_total_capacity(k: u16, m: u8, num_levels: usize) -> u32 {
    (0..num_levels)
        .map(|level| level_capacity(k, num_levels, level, m))
        .sum()
}

/// Returns the capacity of the level at `height` in a sketch with `num_levels`
/// levels; the result is never smaller than `min_wid`.
///
/// # Panics
///
/// Panics if `height >= num_levels` or if the resulting depth
/// (`num_levels - height - 1`) exceeds 60.
pub fn level_capacity(k: u16, num_levels: usize, height: usize, min_wid: u8) -> u32 {
    assert!(height < num_levels, "height must be < num_levels");
    let depth = num_levels - height - 1;
    // Check before narrowing: `depth as u8` would silently wrap for depths
    // above 255 and slip past the range check in `int_cap_aux`.
    assert!(depth <= usize::from(MAX_DEPTH), "depth must be <= 60");
    let cap = int_cap_aux(k, depth as u8);
    u32::from(cap).max(u32::from(min_wid))
}

/// Computes the capacity for the given depth, i.e. `k * (2/3)^depth` rounded
/// to nearest, splitting the computation in two steps for depths above 30.
///
/// # Panics
///
/// Panics if `depth > 60`.
pub fn int_cap_aux(k: u16, depth: u8) -> u16 {
    if depth > MAX_DEPTH {
        panic!("depth must be <= 60");
    }
    if depth <= 30 {
        return int_cap_aux_aux(k, depth);
    }
    let half = depth / 2;
    let rest = depth - half;
    let tmp = int_cap_aux_aux(k, half);
    int_cap_aux_aux(tmp, rest)
}

/// Computes `k * (2/3)^depth` rounded to nearest using integer arithmetic.
///
/// # Panics
///
/// Panics if `depth > 30`.
pub fn int_cap_aux_aux(k: u16, depth: u8) -> u16 {
    if depth > 30 {
        panic!("depth must be <= 30");
    }
    // Work with 2k so the final `(tmp + 1) >> 1` rounds to nearest.
    let twok = (k as u64) << 1;
    let tmp = (twok << depth) / POWERS_OF_THREE[depth as usize];
    let result = (tmp + 1) >> 1;
    assert!(result <= k as u64, "capacity result exceeds k");
    result as u16
}

/// Returns the total weight represented by the given level sizes, where an
/// item at level `i` carries weight `2^i`.
pub fn sum_the_sample_weights(level_sizes: &[usize]) -> u64 {
    let mut total = 0u64;
    let mut weight = 1u64;
    for &size in level_sizes {
        total += weight * size as u64;
        weight <<= 1;
    }
    total
}

/// Maps a raw seed to a valid xorshift seed by replacing zero with
/// [`FALLBACK_SEED`].
fn non_zero_seed(raw: u64) -> u64 {
    if raw == 0 { FALLBACK_SEED } else { raw }
}

/// Derives a non-zero seed from the wall clock (low 64 bits of the
/// nanoseconds since the Unix epoch).
fn seed() -> u64 {
    let nanos = SystemTime::now()
        .duration_since(UNIX_EPOCH)
        .unwrap_or_default()
        .as_nanos();
    non_zero_seed(nanos as u64)
}

/// Returns a pseudo-random bit (0 or 1) from a thread-local xorshift generator.
pub fn random_bit() -> u32 {
    thread_local! {
        static RNG_STATE: Cell<u64> = Cell::new(seed());
    }

    RNG_STATE.with(|state| {
        let mut x = state.get();
        x ^= x << 13;
        x ^= x >> 7;
        x ^= x << 17;
        state.set(x);
        (x & 1) as u32
    })
}

/*
The remainder of this patch segment belongs to the next file in the patch and
is kept here (de-collapsed) so that no patch content is lost:

diff --git a/datasketches/src/kll/mod.rs b/datasketches/src/kll/mod.rs
new file mode 100644
index 0000000..bfde855
--- /dev/null
+++ b/datasketches/src/kll/mod.rs
@@ -0,0 +1,53 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements.  See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership.  The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License.  You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied.  See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+//! KLL sketch implementation for estimating quantiles and ranks.
+//!
+//! KLL is a compact, streaming quantiles sketch with lazy compaction and
+//! near-optimal accuracy per retained item. It supports one-pass updates,
+//! approximate quantiles, ranks, PMF, and CDF queries.
+//!
+//! This implementation follows Apache DataSketches semantics (Java KllSketch
+//! / KllPreambleUtil, C++ kll_sketch) and uses the same binary serialization
+//! format as those implementations.
+//!
+//! # Usage
+//!
+//! ```rust
+//! # use datasketches::kll::KllSketch;
+//! let mut sketch = KllSketch::<f64>::new(200);
+//! sketch.update(1.0);
+//! sketch.update(2.0);
+//! let q = sketch.quantile(0.5, true).unwrap();
+//! assert!(q >= 1.0 && q <= 2.0);
+//!
*/
``` + +mod helper; +mod serialization; +mod sketch; +mod sorted_view; + +pub use self::sketch::KllSketch; + +/// Default value of parameter k. +pub const DEFAULT_K: u16 = 200; +/// Default value of parameter m. +pub const DEFAULT_M: u8 = 8; +/// Minimum value of parameter k. +pub const MIN_K: u16 = DEFAULT_M as u16; +/// Maximum value of parameter k. +pub const MAX_K: u16 = u16::MAX; diff --git a/datasketches/src/kll/serialization.rs b/datasketches/src/kll/serialization.rs new file mode 100644 index 0000000..998add5 --- /dev/null +++ b/datasketches/src/kll/serialization.rs @@ -0,0 +1,48 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! Binary serialization format constants for KLL sketches. +//! +//! Naming and layout follow the Apache DataSketches Java implementation +//! (`KllPreambleUtil`) and the C++ `kll_sketch` serialization format. + +/// Family ID for KLL sketches in DataSketches format (KllPreambleUtil.KLL_FAMILY). +pub const KLL_FAMILY_ID: u8 = 15; + +/// Serialization version for empty or full sketches (KllPreambleUtil.SERIAL_VERSION_EMPTY_FULL). +pub const SERIAL_VERSION_1: u8 = 1; +/// Serialization version for single-item sketches (KllPreambleUtil.SERIAL_VERSION_SINGLE). 
+pub const SERIAL_VERSION_2: u8 = 2; + +/// Preamble ints for empty and single-item sketches (KllPreambleUtil.PREAMBLE_INTS_EMPTY_SINGLE). +pub const PREAMBLE_INTS_SHORT: u8 = 2; +/// Preamble ints for sketches with more than one item (KllPreambleUtil.PREAMBLE_INTS_FULL). +pub const PREAMBLE_INTS_FULL: u8 = 5; + +/// Flag indicating the sketch is empty (KllPreambleUtil.EMPTY_BIT_MASK). +pub const FLAG_EMPTY: u8 = 1 << 0; +/// Flag indicating level zero is sorted (KllPreambleUtil.LEVEL_ZERO_SORTED_BIT_MASK). +pub const FLAG_LEVEL_ZERO_SORTED: u8 = 1 << 1; +/// Flag indicating the sketch has a single item (KllPreambleUtil.SINGLE_ITEM_BIT_MASK). +pub const FLAG_SINGLE_ITEM: u8 = 1 << 2; + +/// Serialized size for an empty sketch in bytes (KllPreambleUtil.DATA_START_ADR_SINGLE_ITEM). +pub const EMPTY_SIZE_BYTES: usize = 8; +/// Data offset for single-item sketches (KllPreambleUtil.DATA_START_ADR_SINGLE_ITEM). +pub const DATA_START_SINGLE_ITEM: usize = 8; +/// Data offset for sketches with more than one item (KllPreambleUtil.DATA_START_ADR). +pub const DATA_START: usize = 20; diff --git a/datasketches/src/kll/sketch.rs b/datasketches/src/kll/sketch.rs new file mode 100644 index 0000000..6a68bbb --- /dev/null +++ b/datasketches/src/kll/sketch.rs @@ -0,0 +1,857 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. 
See the License for the +// specific language governing permissions and limitations +// under the License. + +use std::cmp::Ordering; + +use super::DEFAULT_K; +use super::DEFAULT_M; +use super::MAX_K; +use super::MIN_K; +use super::helper::compute_total_capacity; +use super::helper::level_capacity; +use super::helper::random_bit; +use super::helper::sum_the_sample_weights; +use super::serialization::DATA_START; +use super::serialization::DATA_START_SINGLE_ITEM; +use super::serialization::EMPTY_SIZE_BYTES; +use super::serialization::FLAG_EMPTY; +use super::serialization::FLAG_LEVEL_ZERO_SORTED; +use super::serialization::FLAG_SINGLE_ITEM; +use super::serialization::KLL_FAMILY_ID; +use super::serialization::PREAMBLE_INTS_FULL; +use super::serialization::PREAMBLE_INTS_SHORT; +use super::serialization::SERIAL_VERSION_1; +use super::serialization::SERIAL_VERSION_2; +use super::sorted_view::build_sorted_view; +use crate::codec::SketchBytes; +use crate::codec::SketchSlice; +use crate::error::Error; + +/// Trait implemented by item types supported by [`KllSketch`]. +pub(crate) trait KllItem: Clone { + /// Compare two items. + fn cmp(a: &Self, b: &Self) -> Ordering; + + /// Returns true if the item is NaN. + fn is_nan(_value: &Self) -> bool { + false + } + + /// Serialized size in bytes. + fn serialized_size(value: &Self) -> usize; + + /// Serialize a single item into the buffer. + fn serialize(value: &Self, bytes: &mut SketchBytes); + + /// Deserialize a single item from the input. + fn deserialize(input: &mut SketchSlice<'_>) -> Result; +} + +/// KLL sketch for estimating quantiles and ranks. +/// +/// See the [kll module level documentation](crate::kll) for more. 
// `KllItem` is crate-private on purpose; the bound only restricts which item
// types the public sketch can be instantiated with.
#[allow(private_bounds)]
#[derive(Debug, Clone, PartialEq)]
pub struct KllSketch<T: KllItem> {
    /// Configured accuracy parameter.
    k: u16,
    /// Minimum level width.
    m: u8,
    /// Smallest k among this sketch and the estimation-mode sketches merged into it.
    min_k: u16,
    /// Total weight of the stream.
    n: u64,
    /// Whether level zero is currently sorted.
    is_level_zero_sorted: bool,
    /// Retained items; an item stored at level `i` carries weight `2^i`.
    levels: Vec<Vec<T>>,
    /// Smallest item seen, `None` while the sketch is empty.
    min_item: Option<T>,
    /// Largest item seen, `None` while the sketch is empty.
    max_item: Option<T>,
}

impl<T: KllItem> Default for KllSketch<T> {
    fn default() -> Self {
        Self::new(DEFAULT_K)
    }
}

#[allow(private_bounds)]
impl<T: KllItem> KllSketch<T> {
    /// Creates a new sketch with the given value of k.
    ///
    /// # Panics
    ///
    /// Panics if k is not in [MIN_K, MAX_K].
    ///
    /// # Examples
    ///
    /// ```
    /// # use datasketches::kll::KllSketch;
    /// let sketch = KllSketch::<f64>::new(200);
    /// assert_eq!(sketch.k(), 200);
    /// ```
    pub fn new(k: u16) -> Self {
        assert!(
            (MIN_K..=MAX_K).contains(&k),
            "k must be in [{MIN_K}, {MAX_K}], got {k}"
        );
        Self::make(k, k, 0, vec![Vec::new()], None, None, false)
    }

    /// Returns parameter k used to configure this sketch.
    pub fn k(&self) -> u16 {
        self.k
    }

    /// Returns the minimum k used when merging sketches.
    pub fn min_k(&self) -> u16 {
        self.min_k
    }

    /// Returns total weight of the stream.
    pub fn n(&self) -> u64 {
        self.n
    }

    /// Returns true if the sketch has not seen any data.
    pub fn is_empty(&self) -> bool {
        self.n == 0
    }

    /// Returns the number of retained items.
    pub fn num_retained(&self) -> usize {
        self.levels.iter().map(|level| level.len()).sum()
    }

    /// Returns true if the sketch is in estimation mode, i.e. it has more
    /// than one level.
    pub fn is_estimation_mode(&self) -> bool {
        self.levels.len() > 1
    }

    /// Returns the minimum item seen by the sketch.
    pub fn min_item(&self) -> Option<&T> {
        self.min_item.as_ref()
    }

    /// Returns the maximum item seen by the sketch.
    pub fn max_item(&self) -> Option<&T> {
        self.max_item.as_ref()
    }

    /// Updates the sketch with a new item.
    ///
    /// NaN values are ignored for floating-point types.
+ pub fn update(&mut self, item: T) { + if T::is_nan(&item) { + return; + } + self.update_min_max(&item); + self.internal_update(item); + } + + /// Merges another sketch into this one. + /// + /// # Panics + /// + /// Panics if the sketches have incompatible parameters. + pub fn merge(&mut self, other: &KllSketch) { + if other.is_empty() { + return; + } + + assert_eq!( + self.m, other.m, + "incompatible m values: {} and {}", + self.m, other.m + ); + + self.update_min_max_from_other(other); + + let final_n = self.n + other.n; + for item in &other.levels[0] { + self.internal_update(item.clone()); + } + + if other.levels.len() >= 2 { + self.merge_higher_levels(other); + } + + self.n = final_n; + if other.is_estimation_mode() { + self.min_k = self.min_k.min(other.min_k); + } + + debug_assert_eq!(self.total_weight(), self.n, "total weight does not match n"); + } + + /// Returns the normalized rank of the given item. + pub fn rank(&self, item: &T, inclusive: bool) -> Option { + if self.is_empty() { + return None; + } + let view = build_sorted_view(&self.levels); + Some(view.rank(item, inclusive)) + } + + /// Returns the quantile for the given normalized rank. + /// + /// # Panics + /// + /// Panics if rank is not in [0.0, 1.0]. + pub fn quantile(&self, rank: f64, inclusive: bool) -> Option { + if self.is_empty() { + return None; + } + assert!((0.0..=1.0).contains(&rank), "rank must be in [0.0, 1.0]"); + let view = build_sorted_view(&self.levels); + Some(view.quantile(rank, inclusive)) + } + + /// Returns the approximate CDF for the given split points. + pub fn cdf(&self, split_points: &[T], inclusive: bool) -> Option> { + if self.is_empty() { + return None; + } + let view = build_sorted_view(&self.levels); + Some(view.cdf(split_points, inclusive)) + } + + /// Returns the approximate PMF for the given split points. 
+ pub fn pmf(&self, split_points: &[T], inclusive: bool) -> Option> { + if self.is_empty() { + return None; + } + let view = build_sorted_view(&self.levels); + Some(view.pmf(split_points, inclusive)) + } + + /// Returns normalized rank error for the configured k. + pub fn normalized_rank_error(&self, pmf: bool) -> f64 { + normalized_rank_error(self.min_k, pmf) + } + + /// Serializes the sketch to bytes. + pub fn serialize(&self) -> Vec { + let size = self.serialized_size(); + let mut bytes = SketchBytes::with_capacity(size); + + let is_empty = self.is_empty(); + let is_single_item = self.n == 1; + + let preamble_ints = if is_empty || is_single_item { + PREAMBLE_INTS_SHORT + } else { + PREAMBLE_INTS_FULL + }; + let serial_version = if is_single_item { + SERIAL_VERSION_2 + } else { + SERIAL_VERSION_1 + }; + + let flags = (if is_empty { FLAG_EMPTY } else { 0 }) + | (if self.is_level_zero_sorted { + FLAG_LEVEL_ZERO_SORTED + } else { + 0 + }) + | (if is_single_item { FLAG_SINGLE_ITEM } else { 0 }); + + bytes.write_u8(preamble_ints); + bytes.write_u8(serial_version); + bytes.write_u8(KLL_FAMILY_ID); + bytes.write_u8(flags); + bytes.write_u16_le(self.k); + bytes.write_u8(self.m); + bytes.write_u8(0); + + if is_empty { + return bytes.into_bytes(); + } + + if !is_single_item { + bytes.write_u64_le(self.n); + bytes.write_u16_le(self.min_k); + bytes.write_u8(self.levels.len() as u8); + bytes.write_u8(0); + + let level_offsets = self.level_offsets(); + for offset in level_offsets.iter().take(self.levels.len()) { + bytes.write_u32_le(*offset); + } + + if let Some(min_item) = &self.min_item { + T::serialize(min_item, &mut bytes); + } + if let Some(max_item) = &self.max_item { + T::serialize(max_item, &mut bytes); + } + } + + for level in &self.levels { + for item in level { + T::serialize(item, &mut bytes); + } + } + + bytes.into_bytes() + } + + /// Deserializes a sketch from bytes. 
+ pub fn deserialize(bytes: &[u8]) -> Result, Error> { + fn make_error(tag: &'static str) -> impl FnOnce(std::io::Error) -> Error { + move |_| Error::insufficient_data(tag) + } + + let mut cursor = SketchSlice::new(bytes); + + let preamble_ints = cursor.read_u8().map_err(make_error("preamble_ints"))?; + let serial_version = cursor.read_u8().map_err(make_error("serial_version"))?; + let family_id = cursor.read_u8().map_err(make_error("family_id"))?; + let flags = cursor.read_u8().map_err(make_error("flags"))?; + let k = cursor.read_u16_le().map_err(make_error("k"))?; + let m = cursor.read_u8().map_err(make_error("m"))?; + let _unused = cursor.read_u8().map_err(make_error("unused"))?; + + if m != DEFAULT_M { + return Err(Error::deserial(format!( + "invalid m: expected {DEFAULT_M}, got {m}" + ))); + } + if family_id != KLL_FAMILY_ID { + return Err(Error::invalid_family(KLL_FAMILY_ID, family_id, "KLL")); + } + if serial_version != SERIAL_VERSION_1 && serial_version != SERIAL_VERSION_2 { + return Err(Error::deserial(format!( + "invalid serial version: {serial_version}" + ))); + } + + let is_empty = (flags & FLAG_EMPTY) != 0; + let is_single_item = (flags & FLAG_SINGLE_ITEM) != 0; + let is_level_zero_sorted = (flags & FLAG_LEVEL_ZERO_SORTED) != 0; + if is_empty || is_single_item { + if preamble_ints != PREAMBLE_INTS_SHORT { + return Err(Error::deserial(format!( + "invalid preamble ints: expected {PREAMBLE_INTS_SHORT}, got {preamble_ints}" + ))); + } + } else if preamble_ints != PREAMBLE_INTS_FULL { + return Err(Error::deserial(format!( + "invalid preamble ints: expected {PREAMBLE_INTS_FULL}, got {preamble_ints}" + ))); + } + + if !(MIN_K..=MAX_K).contains(&k) { + return Err(Error::deserial(format!("k out of range: {k}"))); + } + + if is_empty { + return Ok(Self::make( + k, + k, + 0, + vec![Vec::new()], + None, + None, + is_level_zero_sorted, + )); + } + + let (n, min_k, num_levels) = if is_single_item { + (1u64, k, 1usize) + } else { + let n = 
cursor.read_u64_le().map_err(make_error("n"))?; + let min_k = cursor.read_u16_le().map_err(make_error("min_k"))?; + let num_levels = cursor.read_u8().map_err(make_error("num_levels"))?; + let _unused = cursor.read_u8().map_err(make_error("unused2"))?; + (n, min_k, num_levels as usize) + }; + + if num_levels == 0 { + return Err(Error::deserial("num_levels must be > 0")); + } + if min_k < MIN_K || min_k > k { + return Err(Error::deserial(format!( + "min_k must be in [{MIN_K}, {k}], got {min_k}" + ))); + } + + let capacity = compute_total_capacity(k, m, num_levels) as u32; + let mut level_offsets = Vec::with_capacity(num_levels + 1); + if !is_single_item { + for _ in 0..num_levels { + let offset = cursor.read_u32_le().map_err(make_error("levels"))?; + level_offsets.push(offset); + } + } else { + level_offsets.push(capacity - 1); + } + level_offsets.push(capacity); + + if level_offsets.is_empty() { + return Err(Error::deserial("levels array is empty")); + } + if level_offsets[0] > capacity { + return Err(Error::deserial("levels[0] exceeds capacity")); + } + for window in level_offsets.windows(2) { + if window[1] < window[0] { + return Err(Error::deserial("levels array must be non-decreasing")); + } + } + let last = *level_offsets.last().unwrap(); + if last != capacity { + return Err(Error::deserial("levels last offset must equal capacity")); + } + + let min_item = if is_single_item { + None + } else { + Some(T::deserialize(&mut cursor)?) + }; + let max_item = if is_single_item { + None + } else { + Some(T::deserialize(&mut cursor)?) 
+ }; + + let mut levels = Vec::with_capacity(num_levels); + for level in 0..num_levels { + let size = (level_offsets[level + 1] - level_offsets[level]) as usize; + let mut items = Vec::with_capacity(size); + for _ in 0..size { + items.push(T::deserialize(&mut cursor)?); + } + levels.push(items); + } + + let mut sketch = Self::make( + k, + min_k, + n, + levels, + min_item, + max_item, + is_level_zero_sorted, + ); + + if is_single_item { + if let Some(item) = sketch.levels[0].first().cloned() { + sketch.min_item = Some(item.clone()); + sketch.max_item = Some(item); + } + } + + Ok(sketch) + } + + fn make( + k: u16, + min_k: u16, + n: u64, + levels: Vec>, + min_item: Option, + max_item: Option, + is_level_zero_sorted: bool, + ) -> Self { + Self { + k, + m: DEFAULT_M, + min_k, + n, + is_level_zero_sorted, + levels, + min_item, + max_item, + } + } + + fn capacity(&self) -> usize { + compute_total_capacity(self.k, self.m, self.levels.len()) as usize + } + + fn level_offsets(&self) -> Vec { + let capacity = self.capacity() as u32; + let retained = self.num_retained() as u32; + assert!(capacity >= retained, "capacity must be >= retained"); + + let mut offsets = Vec::with_capacity(self.levels.len() + 1); + let mut offset = capacity - retained; + offsets.push(offset); + for level in &self.levels { + offset += level.len() as u32; + offsets.push(offset); + } + offsets + } + + fn serialized_size(&self) -> usize { + if self.is_empty() { + return EMPTY_SIZE_BYTES; + } + if self.n == 1 { + let item = &self.levels[0][0]; + return DATA_START_SINGLE_ITEM + T::serialized_size(item); + } + + let mut size = DATA_START + self.levels.len() * 4; + if let Some(min_item) = &self.min_item { + size += T::serialized_size(min_item); + } + if let Some(max_item) = &self.max_item { + size += T::serialized_size(max_item); + } + for level in &self.levels { + for item in level { + size += T::serialized_size(item); + } + } + size + } + + fn update_min_max(&mut self, item: &T) { + match 
self.min_item.as_ref() { + None => { + self.min_item = Some(item.clone()); + self.max_item = Some(item.clone()); + } + Some(min) => { + if T::cmp(item, min) == Ordering::Less { + self.min_item = Some(item.clone()); + } + if let Some(max) = &self.max_item { + if T::cmp(max, item) == Ordering::Less { + self.max_item = Some(item.clone()); + } + } + } + } + } + + fn update_min_max_from_other(&mut self, other: &KllSketch) { + match (&self.min_item, &self.max_item) { + (None, None) => { + self.min_item = other.min_item.clone(); + self.max_item = other.max_item.clone(); + } + (Some(min), Some(max)) => { + if let Some(other_min) = &other.min_item { + if T::cmp(other_min, min) == Ordering::Less { + self.min_item = Some(other_min.clone()); + } + } + if let Some(other_max) = &other.max_item { + if T::cmp(max, other_max) == Ordering::Less { + self.max_item = Some(other_max.clone()); + } + } + } + _ => { + self.min_item = other.min_item.clone(); + self.max_item = other.max_item.clone(); + } + } + } + + fn internal_update(&mut self, item: T) { + if self.num_retained() >= self.capacity() { + self.compress_while_updating(); + } + self.n += 1; + self.is_level_zero_sorted = false; + self.levels[0].insert(0, item); + } + + fn compress_while_updating(&mut self) { + let level = self.find_level_to_compact(); + if level + 1 == self.levels.len() { + self.levels.push(Vec::new()); + } + + let mut current = std::mem::take(&mut self.levels[level]); + let mut above = std::mem::take(&mut self.levels[level + 1]); + + let odd = current.len() % 2 == 1; + let mut leftover = None; + if odd { + leftover = Some(current.remove(0)); + } + + if level == 0 && !self.is_level_zero_sorted { + current.sort_by(T::cmp); + } + + let use_up = above.is_empty(); + let promoted = downsample(current, random_bit(), use_up); + if above.is_empty() { + above = promoted; + } else { + above = merge_sorted_vec(promoted, above); + } + self.levels[level + 1] = above; + + let mut new_level = Vec::new(); + if let Some(item) = 
leftover { + new_level.push(item); + } + self.levels[level] = new_level; + } + + fn find_level_to_compact(&self) -> usize { + let num_levels = self.levels.len(); + for level in 0..num_levels { + let pop = self.levels[level].len() as u32; + let cap = level_capacity(self.k, num_levels, level, self.m); + if pop >= cap { + return level; + } + } + panic!("no level to compact"); + } + + fn merge_higher_levels(&mut self, other: &KllSketch) { + let provisional_levels = self.levels.len().max(other.levels.len()); + let mut self_levels = std::mem::take(&mut self.levels); + let mut work_levels = vec![Vec::new(); provisional_levels]; + work_levels[0] = std::mem::take(&mut self_levels[0]); + + for level in 1..provisional_levels { + let left = if level < self_levels.len() { + std::mem::take(&mut self_levels[level]) + } else { + Vec::new() + }; + let right = other.levels.get(level).cloned().unwrap_or_default(); + + work_levels[level] = if left.is_empty() { + right + } else if right.is_empty() { + left + } else { + merge_sorted_vec(left, right) + }; + } + + self.levels = general_compress(work_levels, self.k, self.m, self.is_level_zero_sorted); + } + + fn total_weight(&self) -> u64 { + let sizes: Vec = self.levels.iter().map(|level| level.len()).collect(); + sum_the_sample_weights(&sizes) + } +} + +fn normalized_rank_error(k: u16, pmf: bool) -> f64 { + let k = k as f64; + if pmf { + 2.446 / k.powf(0.9433) + } else { + 2.296 / k.powf(0.9723) + } +} + +fn downsample(items: Vec, offset: u32, use_up: bool) -> Vec { + let len = items.len(); + debug_assert!(len % 2 == 0, "length must be even"); + let offset = (offset & 1) as usize; + let parity = if use_up { + (len - 1 - offset) % 2 + } else { + offset + }; + + items + .into_iter() + .enumerate() + .filter_map(|(idx, item)| if idx % 2 == parity { Some(item) } else { None }) + .collect() +} + +fn merge_sorted_vec(left: Vec, right: Vec) -> Vec { + let mut merged = Vec::with_capacity(left.len() + right.len()); + let mut left_iter = 
left.into_iter().peekable(); + let mut right_iter = right.into_iter().peekable(); + + while let (Some(l), Some(r)) = (left_iter.peek(), right_iter.peek()) { + if T::cmp(l, r) == Ordering::Less { + merged.push(left_iter.next().unwrap()); + } else { + merged.push(right_iter.next().unwrap()); + } + } + merged.extend(left_iter); + merged.extend(right_iter); + merged +} + +fn general_compress( + mut levels_in: Vec>, + k: u16, + m: u8, + is_level_zero_sorted: bool, +) -> Vec> { + let mut current_num_levels = levels_in.len(); + let mut current_item_count: usize = levels_in.iter().map(|level| level.len()).sum(); + let mut target_item_count = compute_total_capacity(k, m, current_num_levels) as usize; + let mut levels_out = Vec::with_capacity(current_num_levels + 1); + + let mut current_level = 0usize; + while current_level < current_num_levels { + if current_level + 1 >= levels_in.len() { + levels_in.push(Vec::new()); + } + + let raw_pop = levels_in[current_level].len(); + let cap = level_capacity(k, current_num_levels, current_level, m) as usize; + + if current_item_count < target_item_count || raw_pop < cap { + levels_out.push(std::mem::take(&mut levels_in[current_level])); + } else { + let mut current = std::mem::take(&mut levels_in[current_level]); + let mut above = std::mem::take(&mut levels_in[current_level + 1]); + + let odd = current.len() % 2 == 1; + let mut leftover = None; + if odd { + leftover = Some(current.remove(0)); + } + + if current_level == 0 && !is_level_zero_sorted { + current.sort_by(T::cmp); + } + + let use_up = above.is_empty(); + let promoted = downsample(current, random_bit(), use_up); + let promoted_len = promoted.len(); + if above.is_empty() { + above = promoted; + } else { + above = merge_sorted_vec(promoted, above); + } + levels_in[current_level + 1] = above; + + let mut out_level = Vec::new(); + if let Some(item) = leftover { + out_level.push(item); + } + levels_out.push(out_level); + + current_item_count = 
current_item_count.saturating_sub(promoted_len); + + if current_level == current_num_levels - 1 { + current_num_levels += 1; + target_item_count += level_capacity(k, current_num_levels, 0, m) as usize; + if levels_in.len() < current_num_levels + 1 { + levels_in.resize_with(current_num_levels + 1, Vec::new); + } + } + } + current_level += 1; + } + + levels_out.truncate(current_num_levels); + levels_out +} + +impl KllItem for f32 { + fn cmp(a: &Self, b: &Self) -> Ordering { + a.partial_cmp(b).unwrap_or(Ordering::Greater) + } + + fn is_nan(value: &Self) -> bool { + value.is_nan() + } + + fn serialized_size(_value: &Self) -> usize { + 4 + } + + fn serialize(value: &Self, bytes: &mut SketchBytes) { + bytes.write_f32_le(*value); + } + + fn deserialize(input: &mut SketchSlice<'_>) -> Result { + input + .read_f32_le() + .map_err(|_| Error::insufficient_data("f32")) + } +} + +impl KllItem for f64 { + fn cmp(a: &Self, b: &Self) -> Ordering { + a.partial_cmp(b).unwrap_or(Ordering::Greater) + } + + fn is_nan(value: &Self) -> bool { + value.is_nan() + } + + fn serialized_size(_value: &Self) -> usize { + 8 + } + + fn serialize(value: &Self, bytes: &mut SketchBytes) { + bytes.write_f64_le(*value); + } + + fn deserialize(input: &mut SketchSlice<'_>) -> Result { + input + .read_f64_le() + .map_err(|_| Error::insufficient_data("f64")) + } +} + +impl KllItem for i64 { + fn cmp(a: &Self, b: &Self) -> Ordering { + a.cmp(b) + } + + fn serialized_size(_value: &Self) -> usize { + 8 + } + + fn serialize(value: &Self, bytes: &mut SketchBytes) { + bytes.write_i64_le(*value); + } + + fn deserialize(input: &mut SketchSlice<'_>) -> Result { + input + .read_i64_le() + .map_err(|_| Error::insufficient_data("i64")) + } +} + +impl KllItem for String { + fn cmp(a: &Self, b: &Self) -> Ordering { + a.cmp(b) + } + + fn serialized_size(value: &Self) -> usize { + 4 + value.len() + } + + fn serialize(value: &Self, bytes: &mut SketchBytes) { + bytes.write_u32_le(value.len() as u32); + 
bytes.write(value.as_bytes()); + } + + fn deserialize(input: &mut SketchSlice<'_>) -> Result { + let len = input + .read_u32_le() + .map_err(|_| Error::insufficient_data("string_len"))? as usize; + let mut buf = vec![0u8; len]; + input + .read_exact(&mut buf) + .map_err(|_| Error::insufficient_data("string_bytes"))?; + String::from_utf8(buf).map_err(|_| Error::deserial("invalid utf-8 string")) + } +} diff --git a/datasketches/src/kll/sorted_view.rs b/datasketches/src/kll/sorted_view.rs new file mode 100644 index 0000000..0918c52 --- /dev/null +++ b/datasketches/src/kll/sorted_view.rs @@ -0,0 +1,191 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. 
+ +use std::cmp::Ordering; + +use super::sketch::KllItem; + +#[derive(Debug, Clone)] +pub(crate) struct SortedView { + entries: Vec>, + total_weight: u64, +} + +#[derive(Debug, Clone)] +struct Entry { + item: T, + weight: u64, +} + +impl SortedView { + fn new(mut entries: Vec>) -> Self { + entries.sort_by(|a, b| T::cmp(&a.item, &b.item)); + let mut total_weight = 0u64; + for entry in &mut entries { + total_weight += entry.weight; + entry.weight = total_weight; + } + Self { + entries, + total_weight, + } + } + + pub fn rank(&self, item: &T, inclusive: bool) -> f64 { + if self.entries.is_empty() { + return 0.0; + } + + let idx = if inclusive { + upper_bound(&self.entries, item) + } else { + lower_bound(&self.entries, item) + }; + + if idx == 0 { + return 0.0; + } + let weight = self.entries[idx - 1].weight; + weight as f64 / self.total_weight as f64 + } + + pub fn quantile(&self, rank: f64, inclusive: bool) -> T { + let weight = if inclusive { + (rank * self.total_weight as f64).ceil() as u64 + } else { + (rank * self.total_weight as f64) as u64 + }; + + let idx = if inclusive { + lower_bound_by_weight(&self.entries, weight) + } else { + upper_bound_by_weight(&self.entries, weight) + }; + + if idx >= self.entries.len() { + return self.entries[self.entries.len() - 1].item.clone(); + } + self.entries[idx].item.clone() + } + + pub fn cdf(&self, split_points: &[T], inclusive: bool) -> Vec { + check_split_points(split_points); + let mut ranks = Vec::with_capacity(split_points.len() + 1); + for item in split_points { + ranks.push(self.rank(item, inclusive)); + } + ranks.push(1.0); + ranks + } + + pub fn pmf(&self, split_points: &[T], inclusive: bool) -> Vec { + let mut buckets = self.cdf(split_points, inclusive); + for i in (1..buckets.len()).rev() { + buckets[i] -= buckets[i - 1]; + } + buckets + } +} + +pub(crate) fn build_sorted_view(levels: &[Vec]) -> SortedView { + let num_retained: usize = levels.iter().map(|level| level.len()).sum(); + let mut entries = 
Vec::with_capacity(num_retained); + + for (level_idx, level) in levels.iter().enumerate() { + let weight = 1u64 << level_idx; + for item in level { + entries.push(Entry { + item: item.clone(), + weight, + }); + } + } + + SortedView::new(entries) +} + +fn check_split_points(split_points: &[T]) { + let len = split_points.len(); + if len == 1 && T::is_nan(&split_points[0]) { + panic!("split_points must not contain NaN values"); + } + for i in 0..len.saturating_sub(1) { + if T::is_nan(&split_points[i]) { + panic!("split_points must not contain NaN values"); + } + if T::cmp(&split_points[i], &split_points[i + 1]) == Ordering::Less { + continue; + } + panic!("split_points must be unique and monotonically increasing"); + } +} + +fn lower_bound(entries: &[Entry], item: &T) -> usize { + let mut left = 0usize; + let mut right = entries.len(); + while left < right { + let mid = left + (right - left) / 2; + if T::cmp(&entries[mid].item, item) == Ordering::Less { + left = mid + 1; + } else { + right = mid; + } + } + left +} + +fn upper_bound(entries: &[Entry], item: &T) -> usize { + let mut left = 0usize; + let mut right = entries.len(); + while left < right { + let mid = left + (right - left) / 2; + if T::cmp(&entries[mid].item, item) == Ordering::Greater { + right = mid; + } else { + left = mid + 1; + } + } + left +} + +fn lower_bound_by_weight(entries: &[Entry], weight: u64) -> usize { + let mut left = 0usize; + let mut right = entries.len(); + while left < right { + let mid = left + (right - left) / 2; + if entries[mid].weight < weight { + left = mid + 1; + } else { + right = mid; + } + } + left +} + +fn upper_bound_by_weight(entries: &[Entry], weight: u64) -> usize { + let mut left = 0usize; + let mut right = entries.len(); + while left < right { + let mid = left + (right - left) / 2; + if entries[mid].weight > weight { + right = mid; + } else { + left = mid + 1; + } + } + left +} diff --git a/datasketches/src/lib.rs b/datasketches/src/lib.rs index 009fd9e..9034d51 100644 
--- a/datasketches/src/lib.rs +++ b/datasketches/src/lib.rs @@ -36,6 +36,7 @@ pub mod countmin; pub mod error; pub mod frequencies; pub mod hll; +pub mod kll; pub mod tdigest; pub mod theta; diff --git a/datasketches/tests/kll_serialization_test.rs b/datasketches/tests/kll_serialization_test.rs new file mode 100644 index 0000000..ee281ea --- /dev/null +++ b/datasketches/tests/kll_serialization_test.rs @@ -0,0 +1,306 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! KLL Sketch Serialization Compatibility Tests +//! +//! These tests verify binary compatibility with Apache DataSketches implementations: +//! - Java (datasketches-java) +//! - C++ (datasketches-cpp) +//! +//! Test data is generated by the reference implementations and stored in: +//! 
`tests/serialization_test_data/` + +mod common; + +use std::fs; +use std::path::PathBuf; + +use common::serialization_test_data; +use datasketches::kll::KllSketch; + +const DEFAULT_K: usize = 200; + +fn test_f32_file(path: PathBuf, expected_n: usize) { + let bytes = fs::read(&path).unwrap(); + let sketch = KllSketch::::deserialize(&bytes).unwrap(); + + assert_eq!( + sketch.k() as usize, + DEFAULT_K, + "wrong k in {}", + path.display() + ); + assert_eq!( + sketch.n() as usize, + expected_n, + "wrong n in {}", + path.display() + ); + assert_eq!( + sketch.is_estimation_mode(), + expected_n > DEFAULT_K, + "wrong estimation mode in {}", + path.display() + ); + assert_eq!( + sketch.is_empty(), + expected_n == 0, + "wrong empty flag in {}", + path.display() + ); + + if expected_n == 0 { + assert!(sketch.min_item().is_none(), "min should be None"); + assert!(sketch.max_item().is_none(), "max should be None"); + } else { + assert_eq!( + sketch.min_item().cloned(), + Some(1.0), + "min item mismatch in {}", + path.display() + ); + assert_eq!( + sketch.max_item().cloned(), + Some(expected_n as f32), + "max item mismatch in {}", + path.display() + ); + } +} + +fn test_f64_file(path: PathBuf, expected_n: usize) { + let bytes = fs::read(&path).unwrap(); + let sketch = KllSketch::::deserialize(&bytes).unwrap(); + + assert_eq!( + sketch.k() as usize, + DEFAULT_K, + "wrong k in {}", + path.display() + ); + assert_eq!( + sketch.n() as usize, + expected_n, + "wrong n in {}", + path.display() + ); + assert_eq!( + sketch.is_estimation_mode(), + expected_n > DEFAULT_K, + "wrong estimation mode in {}", + path.display() + ); + assert_eq!( + sketch.is_empty(), + expected_n == 0, + "wrong empty flag in {}", + path.display() + ); + + if expected_n == 0 { + assert!(sketch.min_item().is_none(), "min should be None"); + assert!(sketch.max_item().is_none(), "max should be None"); + } else { + assert_eq!( + sketch.min_item().cloned(), + Some(1.0), + "min item mismatch in {}", + path.display() + ); 
+ assert_eq!( + sketch.max_item().cloned(), + Some(expected_n as f64), + "max item mismatch in {}", + path.display() + ); + } +} + +fn test_i64_file(path: PathBuf, expected_n: usize) { + let bytes = fs::read(&path).unwrap(); + let sketch = KllSketch::::deserialize(&bytes).unwrap(); + + assert_eq!( + sketch.k() as usize, + DEFAULT_K, + "wrong k in {}", + path.display() + ); + assert_eq!( + sketch.n() as usize, + expected_n, + "wrong n in {}", + path.display() + ); + assert_eq!( + sketch.is_estimation_mode(), + expected_n > DEFAULT_K, + "wrong estimation mode in {}", + path.display() + ); + assert_eq!( + sketch.is_empty(), + expected_n == 0, + "wrong empty flag in {}", + path.display() + ); + + if expected_n == 0 { + assert!(sketch.min_item().is_none(), "min should be None"); + assert!(sketch.max_item().is_none(), "max should be None"); + } else { + assert_eq!( + sketch.min_item().cloned(), + Some(1), + "min item mismatch in {}", + path.display() + ); + assert_eq!( + sketch.max_item().cloned(), + Some(expected_n as i64), + "max item mismatch in {}", + path.display() + ); + } +} + +fn parse_string_value(value: &str) -> u64 { + value + .trim_start() + .parse::() + .expect("string value should be numeric") +} + +fn test_string_file(path: PathBuf, expected_n: usize) { + let bytes = fs::read(&path).unwrap(); + let sketch = KllSketch::::deserialize(&bytes).unwrap(); + + assert_eq!( + sketch.k() as usize, + DEFAULT_K, + "wrong k in {}", + path.display() + ); + assert_eq!( + sketch.n() as usize, + expected_n, + "wrong n in {}", + path.display() + ); + assert_eq!( + sketch.is_estimation_mode(), + expected_n > DEFAULT_K, + "wrong estimation mode in {}", + path.display() + ); + assert_eq!( + sketch.is_empty(), + expected_n == 0, + "wrong empty flag in {}", + path.display() + ); + + if expected_n == 0 { + assert!(sketch.min_item().is_none(), "min should be None"); + assert!(sketch.max_item().is_none(), "max should be None"); + } else { + let min_item = 
sketch.min_item().expect("missing min item"); + let max_item = sketch.max_item().expect("missing max item"); + assert_eq!( + parse_string_value(min_item), + 1, + "min item mismatch in {}", + path.display() + ); + assert_eq!( + parse_string_value(max_item), + expected_n as u64, + "max item mismatch in {}", + path.display() + ); + } +} + +#[test] +fn test_java_kll_float_compatibility() { + let test_cases = [0, 1, 10, 100, 1000, 10000, 100000, 1000000]; + for n in test_cases { + let filename = format!("kll_float_n{n}_java.sk"); + let path = serialization_test_data("java_generated_files", &filename); + test_f32_file(path, n); + } +} + +#[test] +fn test_java_kll_double_compatibility() { + let test_cases = [0, 1, 10, 100, 1000, 10000, 100000, 1000000]; + for n in test_cases { + let filename = format!("kll_double_n{n}_java.sk"); + let path = serialization_test_data("java_generated_files", &filename); + test_f64_file(path, n); + } +} + +#[test] +fn test_java_kll_long_compatibility() { + let test_cases = [0, 1, 10, 100, 1000, 10000, 100000, 1000000]; + for n in test_cases { + let filename = format!("kll_long_n{n}_java.sk"); + let path = serialization_test_data("java_generated_files", &filename); + test_i64_file(path, n); + } +} + +#[test] +fn test_java_kll_string_compatibility() { + let test_cases = [0, 1, 10, 100, 1000, 10000, 100000, 1000000]; + for n in test_cases { + let filename = format!("kll_string_n{n}_java.sk"); + let path = serialization_test_data("java_generated_files", &filename); + test_string_file(path, n); + } +} + +#[test] +fn test_cpp_kll_float_compatibility() { + let test_cases = [0, 1, 10, 100, 1000, 10000, 100000, 1000000]; + for n in test_cases { + let filename = format!("kll_float_n{n}_cpp.sk"); + let path = serialization_test_data("cpp_generated_files", &filename); + test_f32_file(path, n); + } +} + +#[test] +fn test_cpp_kll_double_compatibility() { + let test_cases = [0, 1, 10, 100, 1000, 10000, 100000, 1000000]; + for n in test_cases { + let 
filename = format!("kll_double_n{n}_cpp.sk"); + let path = serialization_test_data("cpp_generated_files", &filename); + test_f64_file(path, n); + } +} + +#[test] +fn test_cpp_kll_string_compatibility() { + let test_cases = [0, 1, 10, 100, 1000, 10000, 100000, 1000000]; + for n in test_cases { + let filename = format!("kll_string_n{n}_cpp.sk"); + let path = serialization_test_data("cpp_generated_files", &filename); + test_string_file(path, n); + } +} diff --git a/datasketches/tests/kll_test.rs b/datasketches/tests/kll_test.rs new file mode 100644 index 0000000..9473c28 --- /dev/null +++ b/datasketches/tests/kll_test.rs @@ -0,0 +1,317 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. 
+ +use datasketches::kll::KllSketch; + +const DEFAULT_K: u16 = 200; +const RANK_EPS_FOR_K_200: f64 = 0.0133; +const NUMERIC_NOISE_TOLERANCE: f64 = 1e-6; + +fn assert_approx_eq(actual: f64, expected: f64, tolerance: f64) { + let delta = (actual - expected).abs(); + assert!( + delta <= tolerance, + "expected {expected} +/- {tolerance}, got {actual}" + ); +} + +#[test] +fn test_k_limits() { + let _min = KllSketch::::new(8); + let _max = KllSketch::::new(u16::MAX); +} + +#[test] +#[should_panic(expected = "k must be in")] +fn test_k_too_small_panics() { + KllSketch::::new(7); +} + +#[test] +fn test_empty() { + let sketch = KllSketch::::new(DEFAULT_K); + assert!(sketch.is_empty()); + assert!(!sketch.is_estimation_mode()); + assert_eq!(sketch.n(), 0); + assert_eq!(sketch.num_retained(), 0); + assert!(sketch.min_item().is_none()); + assert!(sketch.max_item().is_none()); + assert!(sketch.rank(&0.0, true).is_none()); + assert!(sketch.quantile(0.5, true).is_none()); + assert!(sketch.pmf(&[0.0f32], true).is_none()); + assert!(sketch.cdf(&[0.0f32], true).is_none()); +} + +#[test] +#[should_panic(expected = "rank must be in [0.0, 1.0]")] +fn test_quantile_out_of_range_panics() { + let mut sketch = KllSketch::::new(DEFAULT_K); + sketch.update(0.0); + sketch.quantile(-1.0, true); +} + +#[test] +fn test_one_item() { + let mut sketch = KllSketch::::new(DEFAULT_K); + sketch.update(1.0); + assert!(!sketch.is_empty()); + assert!(!sketch.is_estimation_mode()); + assert_eq!(sketch.n(), 1); + assert_eq!(sketch.num_retained(), 1); + assert_eq!(sketch.rank(&1.0, false), Some(0.0)); + assert_eq!(sketch.rank(&1.0, true), Some(1.0)); + assert_eq!(sketch.rank(&2.0, false), Some(1.0)); + assert_eq!(sketch.min_item().cloned(), Some(1.0)); + assert_eq!(sketch.max_item().cloned(), Some(1.0)); + assert_eq!(sketch.quantile(0.5, true), Some(1.0)); +} + +#[test] +fn test_nan_is_ignored() { + let mut sketch = KllSketch::::new(DEFAULT_K); + sketch.update(f32::NAN); + assert!(sketch.is_empty()); + 
sketch.update(0.0); + sketch.update(f32::NAN); + assert_eq!(sketch.n(), 1); +} + +#[test] +fn test_many_items_exact_mode() { + let mut sketch = KllSketch::::new(DEFAULT_K); + let n = DEFAULT_K as usize; + for i in 1..=n { + sketch.update(i as f32); + assert_eq!(sketch.n(), i as u64); + } + assert!(!sketch.is_empty()); + assert!(!sketch.is_estimation_mode()); + assert_eq!(sketch.num_retained(), n); + assert_eq!(sketch.min_item().cloned(), Some(1.0)); + assert_eq!(sketch.quantile(0.0, true), Some(1.0)); + assert_eq!(sketch.max_item().cloned(), Some(n as f32)); + assert_eq!(sketch.quantile(1.0, true), Some(n as f32)); + + for i in 1..=n { + let inclusive_rank = i as f64 / n as f64; + assert_eq!(sketch.rank(&(i as f32), true), Some(inclusive_rank)); + let exclusive_rank = (i - 1) as f64 / n as f64; + assert_eq!(sketch.rank(&(i as f32), false), Some(exclusive_rank)); + } +} + +#[test] +fn test_ten_items_quantiles() { + let mut sketch = KllSketch::::new(DEFAULT_K); + for i in 1..=10 { + sketch.update(i as f32); + } + assert_eq!(sketch.quantile(0.0, true), Some(1.0)); + assert_eq!(sketch.quantile(0.5, true), Some(5.0)); + assert_eq!(sketch.quantile(0.99, true), Some(10.0)); + assert_eq!(sketch.quantile(1.0, true), Some(10.0)); +} + +#[test] +fn test_hundred_items_quantiles() { + let mut sketch = KllSketch::::new(DEFAULT_K); + for i in 0..100 { + sketch.update(i as f32); + } + assert_eq!(sketch.quantile(0.0, true), Some(0.0)); + assert_eq!(sketch.quantile(0.01, true), Some(0.0)); + assert_eq!(sketch.quantile(0.5, true), Some(49.0)); + assert_eq!(sketch.quantile(0.99, true), Some(98.0)); + assert_eq!(sketch.quantile(1.0, true), Some(99.0)); +} + +#[test] +fn test_many_items_estimation_mode_rank_error() { + let mut sketch = KllSketch::::new(DEFAULT_K); + let n = 10_000; + for i in 0..n { + sketch.update(i as f32); + } + assert!(!sketch.is_empty()); + assert!(sketch.is_estimation_mode()); + assert_eq!(sketch.min_item().cloned(), Some(0.0)); + 
assert_eq!(sketch.max_item().cloned(), Some((n - 1) as f32)); + + for i in (0..n).step_by(10) { + let true_rank = i as f64 / n as f64; + let rank = sketch.rank(&(i as f32), false).unwrap(); + assert_approx_eq(rank, true_rank, RANK_EPS_FOR_K_200); + } + + assert!(sketch.num_retained() > 0); +} + +#[test] +fn test_rank_cdf_pmf_consistency() { + let mut sketch = KllSketch::::new(DEFAULT_K); + let n = 200; + let mut values = Vec::with_capacity(n); + for i in 0..n { + sketch.update(i as f32); + values.push(i as f32); + } + + let ranks = sketch.cdf(&values, false).unwrap(); + let pmf = sketch.pmf(&values, false).unwrap(); + + let mut subtotal = 0.0; + for i in 0..n { + let rank = sketch.rank(&values[i], false).unwrap(); + assert_eq!(rank, ranks[i]); + subtotal += pmf[i]; + assert!( + (ranks[i] - subtotal).abs() <= NUMERIC_NOISE_TOLERANCE, + "cdf vs pmf mismatch at index {i}" + ); + } + + let ranks = sketch.cdf(&values, true).unwrap(); + let pmf = sketch.pmf(&values, true).unwrap(); + + let mut subtotal = 0.0; + for i in 0..n { + let rank = sketch.rank(&values[i], true).unwrap(); + assert_eq!(rank, ranks[i]); + subtotal += pmf[i]; + assert!( + (ranks[i] - subtotal).abs() <= NUMERIC_NOISE_TOLERANCE, + "cdf vs pmf mismatch at index {i}" + ); + } +} + +#[test] +#[should_panic(expected = "split_points must be unique and monotonically increasing")] +fn test_out_of_order_split_points_panics() { + let mut sketch = KllSketch::::new(DEFAULT_K); + sketch.update(0.0); + let split_points = [1.0, 0.0]; + let _ = sketch.cdf(&split_points, true); +} + +#[test] +#[should_panic(expected = "split_points must not contain NaN values")] +fn test_nan_split_point_panics() { + let mut sketch = KllSketch::::new(DEFAULT_K); + sketch.update(0.0); + let split_points = [f32::NAN]; + let _ = sketch.cdf(&split_points, true); +} + +#[test] +fn test_merge() { + let mut sketch1 = KllSketch::::new(DEFAULT_K); + let mut sketch2 = KllSketch::::new(DEFAULT_K); + let n = 10_000; + for i in 0..n { + 
sketch1.update(i as f32); + sketch2.update((2 * n - i - 1) as f32); + } + + assert_eq!(sketch1.min_item().cloned(), Some(0.0)); + assert_eq!(sketch1.max_item().cloned(), Some((n - 1) as f32)); + assert_eq!(sketch2.min_item().cloned(), Some(n as f32)); + assert_eq!(sketch2.max_item().cloned(), Some((2 * n - 1) as f32)); + + sketch1.merge(&sketch2); + + assert!(!sketch1.is_empty()); + assert_eq!(sketch1.n(), (2 * n) as u64); + assert_eq!(sketch1.min_item().cloned(), Some(0.0)); + assert_eq!(sketch1.max_item().cloned(), Some((2 * n - 1) as f32)); + let median = sketch1.quantile(0.5, true).unwrap(); + assert_approx_eq(median as f64, n as f64, n as f64 * RANK_EPS_FOR_K_200); +} + +#[test] +fn test_merge_lower_k() { + let mut sketch1 = KllSketch::::new(256); + let mut sketch2 = KllSketch::::new(128); + let n = 10_000; + for i in 0..n { + sketch1.update(i as f32); + sketch2.update((2 * n - i - 1) as f32); + } + + sketch1.merge(&sketch2); + + assert_eq!(sketch1.n(), (2 * n) as u64); + assert_eq!(sketch1.min_item().cloned(), Some(0.0)); + assert_eq!(sketch1.max_item().cloned(), Some((2 * n - 1) as f32)); + assert_eq!( + sketch1.normalized_rank_error(false), + sketch2.normalized_rank_error(false) + ); + assert_eq!( + sketch1.normalized_rank_error(true), + sketch2.normalized_rank_error(true) + ); + let median = sketch1.quantile(0.5, true).unwrap(); + assert_approx_eq(median as f64, n as f64, n as f64 * RANK_EPS_FOR_K_200); +} + +#[test] +fn test_merge_exact_mode_lower_k() { + let mut sketch1 = KllSketch::::new(256); + let sketch2 = KllSketch::::new(128); + let n = 10_000; + for i in 0..n { + sketch1.update(i as f32); + } + + let err_before = sketch1.normalized_rank_error(true); + sketch1.merge(&sketch2); + assert_eq!(sketch1.normalized_rank_error(true), err_before); + + assert_eq!(sketch1.n(), n as u64); + assert_eq!(sketch1.min_item().cloned(), Some(0.0)); + assert_eq!(sketch1.max_item().cloned(), Some((n - 1) as f32)); + let median = sketch1.quantile(0.5, true).unwrap(); + 
assert_approx_eq( + median as f64, + (n / 2) as f64, + (n as f64 / 2.0) * RANK_EPS_FOR_K_200, + ); +} + +#[test] +fn test_merge_min_max_from_other() { + let mut sketch1 = KllSketch::::new(DEFAULT_K); + let mut sketch2 = KllSketch::::new(DEFAULT_K); + sketch1.update(1.0); + sketch2.update(2.0); + sketch2.merge(&sketch1); + assert_eq!(sketch2.min_item().cloned(), Some(1.0)); + assert_eq!(sketch2.max_item().cloned(), Some(2.0)); +} + +#[test] +fn test_merge_min_max_large_other() { + let mut sketch1 = KllSketch::::new(DEFAULT_K); + for i in 0..1_000_000 { + sketch1.update(i as f32); + } + let mut sketch2 = KllSketch::::new(DEFAULT_K); + sketch2.merge(&sketch1); + assert_eq!(sketch2.min_item().cloned(), Some(0.0)); + assert_eq!(sketch2.max_item().cloned(), Some(999_999.0)); +} From 6ee9097143101d74d52a3407cdfeb3762eb5b9d3 Mon Sep 17 00:00:00 2001 From: Chojan Shang Date: Mon, 19 Jan 2026 18:14:30 +0000 Subject: [PATCH 2/3] feat: align kll serialization with other sketches --- datasketches/src/error.rs | 6 + datasketches/src/kll/sketch.rs | 502 +++++++++++-------- datasketches/src/kll/sorted_view.rs | 1 + datasketches/tests/kll_serialization_test.rs | 39 +- datasketches/tests/kll_test.rs | 31 +- 5 files changed, 321 insertions(+), 258 deletions(-) diff --git a/datasketches/src/error.rs b/datasketches/src/error.rs index 31559d3..59ff32b 100644 --- a/datasketches/src/error.rs +++ b/datasketches/src/error.rs @@ -124,6 +124,12 @@ impl Error { "invalid preamble longs: expected {expected}, got {actual}" )) } + + pub(crate) fn invalid_preamble_ints(expected: u8, actual: u8) -> Self { + Self::deserial(format!( + "invalid preamble ints: expected {expected}, got {actual}" + )) + } } impl fmt::Debug for Error { diff --git a/datasketches/src/kll/sketch.rs b/datasketches/src/kll/sketch.rs index 6a68bbb..455e1ad 100644 --- a/datasketches/src/kll/sketch.rs +++ b/datasketches/src/kll/sketch.rs @@ -42,7 +42,7 @@ use crate::codec::SketchSlice; use crate::error::Error; /// Trait 
implemented by item types supported by [`KllSketch`]. -pub(crate) trait KllItem: Clone { +pub trait KllItem: Clone { /// Compare two items. fn cmp(a: &Self, b: &Self) -> Ordering; @@ -50,7 +50,9 @@ pub(crate) trait KllItem: Clone { fn is_nan(_value: &Self) -> bool { false } +} +pub(crate) trait KllSerde: KllItem { /// Serialized size in bytes. fn serialized_size(value: &Self) -> usize; @@ -64,7 +66,6 @@ pub(crate) trait KllItem: Clone { /// KLL sketch for estimating quantiles and ranks. /// /// See the [kll module level documentation](crate::kll) for more. -#[allow(private_bounds)] #[derive(Debug, Clone, PartialEq)] pub struct KllSketch { k: u16, @@ -83,7 +84,6 @@ impl Default for KllSketch { } } -#[allow(private_bounds)] impl KllSketch { /// Creates a new sketch with the given value of k. /// @@ -103,7 +103,16 @@ impl KllSketch { (MIN_K..=MAX_K).contains(&k), "k must be in [{MIN_K}, {MAX_K}], got {k}" ); - Self::make(k, k, 0, vec![Vec::new()], None, None, false) + Self { + k, + m: DEFAULT_M, + min_k: k, + n: 0, + is_level_zero_sorted: false, + levels: vec![Vec::new()], + min_item: None, + max_item: None, + } } /// Returns parameter k used to configure this sketch. @@ -237,223 +246,302 @@ impl KllSketch { pub fn normalized_rank_error(&self, pmf: bool) -> f64 { normalized_rank_error(self.min_k, pmf) } +} - /// Serializes the sketch to bytes. 
- pub fn serialize(&self) -> Vec { - let size = self.serialized_size(); - let mut bytes = SketchBytes::with_capacity(size); - - let is_empty = self.is_empty(); - let is_single_item = self.n == 1; +fn serialized_size(sketch: &KllSketch) -> usize { + if sketch.is_empty() { + return EMPTY_SIZE_BYTES; + } + if sketch.n == 1 { + let item = &sketch.levels[0][0]; + return DATA_START_SINGLE_ITEM + T::serialized_size(item); + } - let preamble_ints = if is_empty || is_single_item { - PREAMBLE_INTS_SHORT - } else { - PREAMBLE_INTS_FULL - }; - let serial_version = if is_single_item { - SERIAL_VERSION_2 - } else { - SERIAL_VERSION_1 - }; + let mut size = DATA_START + sketch.levels.len() * 4; + if let Some(min_item) = &sketch.min_item { + size += T::serialized_size(min_item); + } + if let Some(max_item) = &sketch.max_item { + size += T::serialized_size(max_item); + } + for level in &sketch.levels { + for item in level { + size += T::serialized_size(item); + } + } + size +} - let flags = (if is_empty { FLAG_EMPTY } else { 0 }) - | (if self.is_level_zero_sorted { - FLAG_LEVEL_ZERO_SORTED - } else { - 0 - }) - | (if is_single_item { FLAG_SINGLE_ITEM } else { 0 }); - - bytes.write_u8(preamble_ints); - bytes.write_u8(serial_version); - bytes.write_u8(KLL_FAMILY_ID); - bytes.write_u8(flags); - bytes.write_u16_le(self.k); - bytes.write_u8(self.m); - bytes.write_u8(0); +fn serialize_with_serde(sketch: &KllSketch) -> Vec { + let size = serialized_size(sketch); + let mut bytes = SketchBytes::with_capacity(size); - if is_empty { - return bytes.into_bytes(); - } + let is_empty = sketch.is_empty(); + let is_single_item = sketch.n == 1; - if !is_single_item { - bytes.write_u64_le(self.n); - bytes.write_u16_le(self.min_k); - bytes.write_u8(self.levels.len() as u8); - bytes.write_u8(0); + let preamble_ints = if is_empty || is_single_item { + PREAMBLE_INTS_SHORT + } else { + PREAMBLE_INTS_FULL + }; + let serial_version = if is_single_item { + SERIAL_VERSION_2 + } else { + SERIAL_VERSION_1 + }; - 
let level_offsets = self.level_offsets(); - for offset in level_offsets.iter().take(self.levels.len()) { - bytes.write_u32_le(*offset); - } + let flags = (if is_empty { FLAG_EMPTY } else { 0 }) + | (if sketch.is_level_zero_sorted { + FLAG_LEVEL_ZERO_SORTED + } else { + 0 + }) + | (if is_single_item { FLAG_SINGLE_ITEM } else { 0 }); + + bytes.write_u8(preamble_ints); + bytes.write_u8(serial_version); + bytes.write_u8(KLL_FAMILY_ID); + bytes.write_u8(flags); + bytes.write_u16_le(sketch.k); + bytes.write_u8(sketch.m); + bytes.write_u8(0); + + if is_empty { + return bytes.into_bytes(); + } + + if !is_single_item { + bytes.write_u64_le(sketch.n); + bytes.write_u16_le(sketch.min_k); + bytes.write_u8(sketch.levels.len() as u8); + bytes.write_u8(0); - if let Some(min_item) = &self.min_item { - T::serialize(min_item, &mut bytes); - } - if let Some(max_item) = &self.max_item { - T::serialize(max_item, &mut bytes); - } + let level_offsets = sketch.level_offsets(); + for offset in level_offsets.iter().take(sketch.levels.len()) { + bytes.write_u32_le(*offset); } - for level in &self.levels { - for item in level { - T::serialize(item, &mut bytes); - } + if let Some(min_item) = &sketch.min_item { + T::serialize(min_item, &mut bytes); + } + if let Some(max_item) = &sketch.max_item { + T::serialize(max_item, &mut bytes); } - - bytes.into_bytes() } - /// Deserializes a sketch from bytes. 
- pub fn deserialize(bytes: &[u8]) -> Result, Error> { - fn make_error(tag: &'static str) -> impl FnOnce(std::io::Error) -> Error { - move |_| Error::insufficient_data(tag) - } - - let mut cursor = SketchSlice::new(bytes); - - let preamble_ints = cursor.read_u8().map_err(make_error("preamble_ints"))?; - let serial_version = cursor.read_u8().map_err(make_error("serial_version"))?; - let family_id = cursor.read_u8().map_err(make_error("family_id"))?; - let flags = cursor.read_u8().map_err(make_error("flags"))?; - let k = cursor.read_u16_le().map_err(make_error("k"))?; - let m = cursor.read_u8().map_err(make_error("m"))?; - let _unused = cursor.read_u8().map_err(make_error("unused"))?; - - if m != DEFAULT_M { - return Err(Error::deserial(format!( - "invalid m: expected {DEFAULT_M}, got {m}" - ))); - } - if family_id != KLL_FAMILY_ID { - return Err(Error::invalid_family(KLL_FAMILY_ID, family_id, "KLL")); - } - if serial_version != SERIAL_VERSION_1 && serial_version != SERIAL_VERSION_2 { - return Err(Error::deserial(format!( - "invalid serial version: {serial_version}" - ))); - } - - let is_empty = (flags & FLAG_EMPTY) != 0; - let is_single_item = (flags & FLAG_SINGLE_ITEM) != 0; - let is_level_zero_sorted = (flags & FLAG_LEVEL_ZERO_SORTED) != 0; - if is_empty || is_single_item { - if preamble_ints != PREAMBLE_INTS_SHORT { - return Err(Error::deserial(format!( - "invalid preamble ints: expected {PREAMBLE_INTS_SHORT}, got {preamble_ints}" - ))); - } - } else if preamble_ints != PREAMBLE_INTS_FULL { - return Err(Error::deserial(format!( - "invalid preamble ints: expected {PREAMBLE_INTS_FULL}, got {preamble_ints}" - ))); + for level in &sketch.levels { + for item in level { + T::serialize(item, &mut bytes); } + } - if !(MIN_K..=MAX_K).contains(&k) { - return Err(Error::deserial(format!("k out of range: {k}"))); - } + bytes.into_bytes() +} - if is_empty { - return Ok(Self::make( - k, - k, - 0, - vec![Vec::new()], - None, - None, - is_level_zero_sorted, +fn 
deserialize_with_serde(bytes: &[u8]) -> Result, Error> { + fn make_error(tag: &'static str) -> impl FnOnce(std::io::Error) -> Error { + move |_| Error::insufficient_data(tag) + } + + let mut cursor = SketchSlice::new(bytes); + + let preamble_ints = cursor.read_u8().map_err(make_error("preamble_ints"))?; + let serial_version = cursor.read_u8().map_err(make_error("serial_version"))?; + let family_id = cursor.read_u8().map_err(make_error("family_id"))?; + let flags = cursor.read_u8().map_err(make_error("flags"))?; + let k = cursor.read_u16_le().map_err(make_error("k"))?; + let m = cursor.read_u8().map_err(make_error("m"))?; + let _unused = cursor.read_u8().map_err(make_error("unused"))?; + + if m != DEFAULT_M { + return Err(Error::deserial(format!( + "invalid m: expected {DEFAULT_M}, got {m}" + ))); + } + if family_id != KLL_FAMILY_ID { + return Err(Error::invalid_family(KLL_FAMILY_ID, family_id, "KLL")); + } + let is_empty = (flags & FLAG_EMPTY) != 0; + let is_single_item = (flags & FLAG_SINGLE_ITEM) != 0; + let is_level_zero_sorted = (flags & FLAG_LEVEL_ZERO_SORTED) != 0; + if is_empty || is_single_item { + if preamble_ints != PREAMBLE_INTS_SHORT { + return Err(Error::invalid_preamble_ints( + PREAMBLE_INTS_SHORT, + preamble_ints, )); } + } else if preamble_ints != PREAMBLE_INTS_FULL { + return Err(Error::invalid_preamble_ints( + PREAMBLE_INTS_FULL, + preamble_ints, + )); + } + let expected_version = if is_single_item { + SERIAL_VERSION_2 + } else { + SERIAL_VERSION_1 + }; + if serial_version != expected_version { + return Err(Error::unsupported_serial_version( + expected_version, + serial_version, + )); + } - let (n, min_k, num_levels) = if is_single_item { - (1u64, k, 1usize) - } else { - let n = cursor.read_u64_le().map_err(make_error("n"))?; - let min_k = cursor.read_u16_le().map_err(make_error("min_k"))?; - let num_levels = cursor.read_u8().map_err(make_error("num_levels"))?; - let _unused = cursor.read_u8().map_err(make_error("unused2"))?; - (n, min_k, 
num_levels as usize) - }; - - if num_levels == 0 { - return Err(Error::deserial("num_levels must be > 0")); - } - if min_k < MIN_K || min_k > k { - return Err(Error::deserial(format!( - "min_k must be in [{MIN_K}, {k}], got {min_k}" - ))); - } - - let capacity = compute_total_capacity(k, m, num_levels) as u32; - let mut level_offsets = Vec::with_capacity(num_levels + 1); - if !is_single_item { - for _ in 0..num_levels { - let offset = cursor.read_u32_le().map_err(make_error("levels"))?; - level_offsets.push(offset); - } - } else { - level_offsets.push(capacity - 1); - } - level_offsets.push(capacity); + if !(MIN_K..=MAX_K).contains(&k) { + return Err(Error::deserial(format!("k out of range: {k}"))); + } - if level_offsets.is_empty() { - return Err(Error::deserial("levels array is empty")); - } - if level_offsets[0] > capacity { - return Err(Error::deserial("levels[0] exceeds capacity")); - } - for window in level_offsets.windows(2) { - if window[1] < window[0] { - return Err(Error::deserial("levels array must be non-decreasing")); - } + if is_empty { + return Ok(KllSketch::make( + k, + k, + 0, + vec![Vec::new()], + None, + None, + is_level_zero_sorted, + )); + } + + let (n, min_k, num_levels) = if is_single_item { + (1u64, k, 1usize) + } else { + let n = cursor.read_u64_le().map_err(make_error("n"))?; + let min_k = cursor.read_u16_le().map_err(make_error("min_k"))?; + let num_levels = cursor.read_u8().map_err(make_error("num_levels"))?; + let _unused = cursor.read_u8().map_err(make_error("unused2"))?; + (n, min_k, num_levels as usize) + }; + + if num_levels == 0 { + return Err(Error::deserial("num_levels must be > 0")); + } + if min_k < MIN_K || min_k > k { + return Err(Error::deserial(format!( + "min_k must be in [{MIN_K}, {k}], got {min_k}" + ))); + } + + let capacity = compute_total_capacity(k, m, num_levels) as u32; + let mut level_offsets = Vec::with_capacity(num_levels + 1); + if !is_single_item { + for _ in 0..num_levels { + let offset = 
cursor.read_u32_le().map_err(make_error("levels"))?; + level_offsets.push(offset); } - let last = *level_offsets.last().unwrap(); - if last != capacity { - return Err(Error::deserial("levels last offset must equal capacity")); + } else { + level_offsets.push(capacity - 1); + } + level_offsets.push(capacity); + + if level_offsets.is_empty() { + return Err(Error::deserial("levels array is empty")); + } + if level_offsets[0] > capacity { + return Err(Error::deserial("levels[0] exceeds capacity")); + } + for window in level_offsets.windows(2) { + if window[1] < window[0] { + return Err(Error::deserial("levels array must be non-decreasing")); } + } + let last = *level_offsets.last().unwrap(); + if last != capacity { + return Err(Error::deserial("levels last offset must equal capacity")); + } - let min_item = if is_single_item { - None - } else { - Some(T::deserialize(&mut cursor)?) - }; - let max_item = if is_single_item { - None - } else { - Some(T::deserialize(&mut cursor)?) - }; + let min_item = if is_single_item { + None + } else { + Some(T::deserialize(&mut cursor)?) + }; + let max_item = if is_single_item { + None + } else { + Some(T::deserialize(&mut cursor)?) 
+ }; - let mut levels = Vec::with_capacity(num_levels); - for level in 0..num_levels { - let size = (level_offsets[level + 1] - level_offsets[level]) as usize; - let mut items = Vec::with_capacity(size); - for _ in 0..size { - items.push(T::deserialize(&mut cursor)?); - } - levels.push(items); + let mut levels = Vec::with_capacity(num_levels); + for level in 0..num_levels { + let size = (level_offsets[level + 1] - level_offsets[level]) as usize; + let mut items = Vec::with_capacity(size); + for _ in 0..size { + items.push(T::deserialize(&mut cursor)?); } + levels.push(items); + } - let mut sketch = Self::make( - k, - min_k, - n, - levels, - min_item, - max_item, - is_level_zero_sorted, - ); + let mut sketch = KllSketch::make( + k, + min_k, + n, + levels, + min_item, + max_item, + is_level_zero_sorted, + ); - if is_single_item { - if let Some(item) = sketch.levels[0].first().cloned() { - sketch.min_item = Some(item.clone()); - sketch.max_item = Some(item); - } + if is_single_item { + if let Some(item) = sketch.levels[0].first().cloned() { + sketch.min_item = Some(item.clone()); + sketch.max_item = Some(item); } + } + + Ok(sketch) +} + +impl KllSketch { + /// Serializes the sketch to bytes. + pub fn serialize(&self) -> Vec { + serialize_with_serde(self) + } + + /// Deserializes a sketch from bytes. + pub fn deserialize(bytes: &[u8]) -> Result { + deserialize_with_serde(bytes) + } +} + +impl KllSketch { + /// Serializes the sketch to bytes. + pub fn serialize(&self) -> Vec { + serialize_with_serde(self) + } - Ok(sketch) + /// Deserializes a sketch from bytes. + pub fn deserialize(bytes: &[u8]) -> Result { + deserialize_with_serde(bytes) + } +} + +impl KllSketch { + /// Serializes the sketch to bytes. + pub fn serialize(&self) -> Vec { + serialize_with_serde(self) + } + + /// Deserializes a sketch from bytes. + pub fn deserialize(bytes: &[u8]) -> Result { + deserialize_with_serde(bytes) + } +} + +impl KllSketch { + /// Serializes the sketch to bytes. 
+ pub fn serialize(&self) -> Vec { + serialize_with_serde(self) + } + + /// Deserializes a sketch from bytes. + pub fn deserialize(bytes: &[u8]) -> Result { + deserialize_with_serde(bytes) } +} +impl KllSketch { fn make( k: u16, min_k: u16, @@ -494,30 +582,6 @@ impl KllSketch { offsets } - fn serialized_size(&self) -> usize { - if self.is_empty() { - return EMPTY_SIZE_BYTES; - } - if self.n == 1 { - let item = &self.levels[0][0]; - return DATA_START_SINGLE_ITEM + T::serialized_size(item); - } - - let mut size = DATA_START + self.levels.len() * 4; - if let Some(min_item) = &self.min_item { - size += T::serialized_size(min_item); - } - if let Some(max_item) = &self.max_item { - size += T::serialized_size(max_item); - } - for level in &self.levels { - for item in level { - size += T::serialized_size(item); - } - } - size - } - fn update_min_max(&mut self, item: &T) { match self.min_item.as_ref() { None => { @@ -770,7 +834,9 @@ impl KllItem for f32 { fn is_nan(value: &Self) -> bool { value.is_nan() } +} +impl KllSerde for f32 { fn serialized_size(_value: &Self) -> usize { 4 } @@ -794,7 +860,9 @@ impl KllItem for f64 { fn is_nan(value: &Self) -> bool { value.is_nan() } +} +impl KllSerde for f64 { fn serialized_size(_value: &Self) -> usize { 8 } @@ -814,7 +882,9 @@ impl KllItem for i64 { fn cmp(a: &Self, b: &Self) -> Ordering { a.cmp(b) } +} +impl KllSerde for i64 { fn serialized_size(_value: &Self) -> usize { 8 } @@ -834,7 +904,9 @@ impl KllItem for String { fn cmp(a: &Self, b: &Self) -> Ordering { a.cmp(b) } +} +impl KllSerde for String { fn serialized_size(value: &Self) -> usize { 4 + value.len() } diff --git a/datasketches/src/kll/sorted_view.rs b/datasketches/src/kll/sorted_view.rs index 0918c52..655fd05 100644 --- a/datasketches/src/kll/sorted_view.rs +++ b/datasketches/src/kll/sorted_view.rs @@ -118,6 +118,7 @@ pub(crate) fn build_sorted_view(levels: &[Vec]) -> SortedView SortedView::new(entries) } +#[track_caller] fn check_split_points(split_points: &[T]) { let 
len = split_points.len(); if len == 1 && T::is_nan(&split_points[0]) { diff --git a/datasketches/tests/kll_serialization_test.rs b/datasketches/tests/kll_serialization_test.rs index ee281ea..3c1dd57 100644 --- a/datasketches/tests/kll_serialization_test.rs +++ b/datasketches/tests/kll_serialization_test.rs @@ -30,20 +30,14 @@ use std::fs; use std::path::PathBuf; use common::serialization_test_data; +use datasketches::kll::DEFAULT_K; use datasketches::kll::KllSketch; -const DEFAULT_K: usize = 200; - fn test_f32_file(path: PathBuf, expected_n: usize) { let bytes = fs::read(&path).unwrap(); let sketch = KllSketch::::deserialize(&bytes).unwrap(); - assert_eq!( - sketch.k() as usize, - DEFAULT_K, - "wrong k in {}", - path.display() - ); + assert_eq!(sketch.k(), DEFAULT_K, "wrong k in {}", path.display()); assert_eq!( sketch.n() as usize, expected_n, @@ -52,7 +46,7 @@ fn test_f32_file(path: PathBuf, expected_n: usize) { ); assert_eq!( sketch.is_estimation_mode(), - expected_n > DEFAULT_K, + expected_n > DEFAULT_K as usize, "wrong estimation mode in {}", path.display() ); @@ -86,12 +80,7 @@ fn test_f64_file(path: PathBuf, expected_n: usize) { let bytes = fs::read(&path).unwrap(); let sketch = KllSketch::::deserialize(&bytes).unwrap(); - assert_eq!( - sketch.k() as usize, - DEFAULT_K, - "wrong k in {}", - path.display() - ); + assert_eq!(sketch.k(), DEFAULT_K, "wrong k in {}", path.display()); assert_eq!( sketch.n() as usize, expected_n, @@ -100,7 +89,7 @@ fn test_f64_file(path: PathBuf, expected_n: usize) { ); assert_eq!( sketch.is_estimation_mode(), - expected_n > DEFAULT_K, + expected_n > DEFAULT_K as usize, "wrong estimation mode in {}", path.display() ); @@ -134,12 +123,7 @@ fn test_i64_file(path: PathBuf, expected_n: usize) { let bytes = fs::read(&path).unwrap(); let sketch = KllSketch::::deserialize(&bytes).unwrap(); - assert_eq!( - sketch.k() as usize, - DEFAULT_K, - "wrong k in {}", - path.display() - ); + assert_eq!(sketch.k(), DEFAULT_K, "wrong k in {}", 
path.display()); assert_eq!( sketch.n() as usize, expected_n, @@ -148,7 +132,7 @@ fn test_i64_file(path: PathBuf, expected_n: usize) { ); assert_eq!( sketch.is_estimation_mode(), - expected_n > DEFAULT_K, + expected_n > DEFAULT_K as usize, "wrong estimation mode in {}", path.display() ); @@ -189,12 +173,7 @@ fn test_string_file(path: PathBuf, expected_n: usize) { let bytes = fs::read(&path).unwrap(); let sketch = KllSketch::::deserialize(&bytes).unwrap(); - assert_eq!( - sketch.k() as usize, - DEFAULT_K, - "wrong k in {}", - path.display() - ); + assert_eq!(sketch.k(), DEFAULT_K, "wrong k in {}", path.display()); assert_eq!( sketch.n() as usize, expected_n, @@ -203,7 +182,7 @@ fn test_string_file(path: PathBuf, expected_n: usize) { ); assert_eq!( sketch.is_estimation_mode(), - expected_n > DEFAULT_K, + expected_n > DEFAULT_K as usize, "wrong estimation mode in {}", path.display() ); diff --git a/datasketches/tests/kll_test.rs b/datasketches/tests/kll_test.rs index 9473c28..08ceaa7 100644 --- a/datasketches/tests/kll_test.rs +++ b/datasketches/tests/kll_test.rs @@ -15,10 +15,11 @@ // specific language governing permissions and limitations // under the License. 
+use datasketches::kll::DEFAULT_K; use datasketches::kll::KllSketch; +use datasketches::kll::MAX_K; +use datasketches::kll::MIN_K; -const DEFAULT_K: u16 = 200; -const RANK_EPS_FOR_K_200: f64 = 0.0133; const NUMERIC_NOISE_TOLERANCE: f64 = 1e-6; fn assert_approx_eq(actual: f64, expected: f64, tolerance: f64) { @@ -29,16 +30,20 @@ fn assert_approx_eq(actual: f64, expected: f64, tolerance: f64) { ); } +fn rank_eps(sketch: &KllSketch) -> f64 { + sketch.normalized_rank_error(false) +} + #[test] fn test_k_limits() { - let _min = KllSketch::::new(8); - let _max = KllSketch::::new(u16::MAX); + let _min = KllSketch::::new(MIN_K); + let _max = KllSketch::::new(MAX_K); } #[test] #[should_panic(expected = "k must be in")] fn test_k_too_small_panics() { - KllSketch::::new(7); + KllSketch::::new(MIN_K - 1); } #[test] @@ -151,10 +156,11 @@ fn test_many_items_estimation_mode_rank_error() { assert_eq!(sketch.min_item().cloned(), Some(0.0)); assert_eq!(sketch.max_item().cloned(), Some((n - 1) as f32)); + let rank_eps = rank_eps(&sketch); for i in (0..n).step_by(10) { let true_rank = i as f64 / n as f64; let rank = sketch.rank(&(i as f32), false).unwrap(); - assert_approx_eq(rank, true_rank, RANK_EPS_FOR_K_200); + assert_approx_eq(rank, true_rank, rank_eps); } assert!(sketch.num_retained() > 0); @@ -239,7 +245,8 @@ fn test_merge() { assert_eq!(sketch1.min_item().cloned(), Some(0.0)); assert_eq!(sketch1.max_item().cloned(), Some((2 * n - 1) as f32)); let median = sketch1.quantile(0.5, true).unwrap(); - assert_approx_eq(median as f64, n as f64, n as f64 * RANK_EPS_FOR_K_200); + let rank_eps = rank_eps(&sketch1); + assert_approx_eq(median as f64, n as f64, n as f64 * rank_eps); } #[test] @@ -266,7 +273,8 @@ fn test_merge_lower_k() { sketch2.normalized_rank_error(true) ); let median = sketch1.quantile(0.5, true).unwrap(); - assert_approx_eq(median as f64, n as f64, n as f64 * RANK_EPS_FOR_K_200); + let rank_eps = rank_eps(&sketch1); + assert_approx_eq(median as f64, n as f64, n as f64 * 
rank_eps); } #[test] @@ -286,11 +294,8 @@ fn test_merge_exact_mode_lower_k() { assert_eq!(sketch1.min_item().cloned(), Some(0.0)); assert_eq!(sketch1.max_item().cloned(), Some((n - 1) as f32)); let median = sketch1.quantile(0.5, true).unwrap(); - assert_approx_eq( - median as f64, - (n / 2) as f64, - (n as f64 / 2.0) * RANK_EPS_FOR_K_200, - ); + let rank_eps = rank_eps(&sketch1); + assert_approx_eq(median as f64, (n / 2) as f64, (n as f64 / 2.0) * rank_eps); } #[test] From f3fb6a23569be7e657c407cede5ceeabf0816e07 Mon Sep 17 00:00:00 2001 From: Chojan Shang Date: Mon, 19 Jan 2026 18:23:37 +0000 Subject: [PATCH 3/3] feat: refine kll public api --- datasketches/src/kll/mod.rs | 9 +++++++++ datasketches/src/kll/sketch.rs | 6 +++++- 2 files changed, 14 insertions(+), 1 deletion(-) diff --git a/datasketches/src/kll/mod.rs b/datasketches/src/kll/mod.rs index bfde855..5c27c0f 100644 --- a/datasketches/src/kll/mod.rs +++ b/datasketches/src/kll/mod.rs @@ -43,6 +43,15 @@ mod sorted_view; pub use self::sketch::KllSketch; +/// KLL sketch specialized for `f64`. +pub type KllSketchF64 = KllSketch; +/// KLL sketch specialized for `f32`. +pub type KllSketchF32 = KllSketch; +/// KLL sketch specialized for `i64`. +pub type KllSketchI64 = KllSketch; +/// KLL sketch specialized for `String`. +pub type KllSketchString = KllSketch; + /// Default value of parameter k. pub const DEFAULT_K: u16 = 200; /// Default value of parameter m. diff --git a/datasketches/src/kll/sketch.rs b/datasketches/src/kll/sketch.rs index 455e1ad..ace1476 100644 --- a/datasketches/src/kll/sketch.rs +++ b/datasketches/src/kll/sketch.rs @@ -42,6 +42,10 @@ use crate::codec::SketchSlice; use crate::error::Error; /// Trait implemented by item types supported by [`KllSketch`]. +/// +/// Implementations must provide a total ordering via `cmp`. +/// For floating-point types, ensure `cmp` handles NaN consistently and `is_nan` +/// returns true for values that should be ignored by updates. 
pub trait KllItem: Clone { /// Compare two items. fn cmp(a: &Self, b: &Self) -> Ordering; @@ -67,7 +71,7 @@ pub(crate) trait KllSerde: KllItem { /// /// See the [kll module level documentation](crate::kll) for more. #[derive(Debug, Clone, PartialEq)] -pub struct KllSketch { +pub struct KllSketch { k: u16, m: u8, min_k: u16,