From 1abdb192e94d69c996eb8153f0679d65b13d590f Mon Sep 17 00:00:00 2001
From: shekharrajak
Date: Sun, 28 Dec 2025 15:12:58 +0530
Subject: [PATCH 1/7] Fix BloomFilter buffer incompatibility between Spark and Comet

Handle Spark's full serialization format (12-byte header + bits) in
merge_filter() to support Spark partial / Comet final execution. The fix
automatically detects the format and extracts bits data accordingly.

Fixes #2889
---
 .../src/bloom_filter/spark_bloom_filter.rs    | 29 +++++++++++++++++--
 1 file changed, 26 insertions(+), 3 deletions(-)

diff --git a/native/spark-expr/src/bloom_filter/spark_bloom_filter.rs b/native/spark-expr/src/bloom_filter/spark_bloom_filter.rs
index e84257ea67..656171c947 100644
--- a/native/spark-expr/src/bloom_filter/spark_bloom_filter.rs
+++ b/native/spark-expr/src/bloom_filter/spark_bloom_filter.rs
@@ -159,12 +159,35 @@ impl SparkBloomFilter {
         self.bits.to_bytes()
     }
 
+    /// Extracts bits data from Spark's full serialization format.
+    /// Spark's format includes a 12-byte header (version + num_hash_functions + num_words)
+    /// followed by the bits data. This function extracts just the bits data.
+    fn extract_bits_from_spark_format(&self, buf: &[u8]) -> &[u8] {
+        const SPARK_HEADER_SIZE: usize = 12; // version (4) + num_hash_functions (4) + num_words (4)
+
+        // Check if this is Spark's full serialization format
+        let expected_bits_size = self.bits.byte_size();
+        if buf.len() == SPARK_HEADER_SIZE + expected_bits_size {
+            // This is Spark's full format, extract bits data (skip header)
+            &buf[SPARK_HEADER_SIZE..]
+        } else {
+            // This is already just bits data (Comet format)
+            buf
+        }
+    }
+
     pub fn merge_filter(&mut self, other: &[u8]) {
+        // Extract bits data if other is in Spark's full serialization format
+        let bits_data = self.extract_bits_from_spark_format(other);
+
         assert_eq!(
-            other.len(),
+            bits_data.len(),
             self.bits.byte_size(),
+            "Cannot merge SparkBloomFilters with different lengths. Expected {} bytes, got {} bytes (full buffer size: {} bytes)",
             self.bits.byte_size(),
-            "Cannot merge SparkBloomFilters with different lengths."
+            bits_data.len(),
+            other.len()
         );
-        self.bits.merge_bits(other);
+        self.bits.merge_bits(bits_data);
     }
 }

From 5994c3fd3400c23e3cf7f8d08ae3ee0dfe76d09b Mon Sep 17 00:00:00 2001
From: shekharrajak
Date: Sat, 3 Jan 2026 19:54:50 +0530
Subject: [PATCH 2/7] minor change

---
 native/spark-expr/src/bloom_filter/spark_bloom_filter.rs | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/native/spark-expr/src/bloom_filter/spark_bloom_filter.rs b/native/spark-expr/src/bloom_filter/spark_bloom_filter.rs
index 656171c947..1e315e9f88 100644
--- a/native/spark-expr/src/bloom_filter/spark_bloom_filter.rs
+++ b/native/spark-expr/src/bloom_filter/spark_bloom_filter.rs
@@ -164,7 +164,7 @@ impl SparkBloomFilter {
     /// followed by the bits data. This function extracts just the bits data.
     fn extract_bits_from_spark_format(&self, buf: &[u8]) -> &[u8] {
         const SPARK_HEADER_SIZE: usize = 12; // version (4) + num_hash_functions (4) + num_words (4)
-
+
         // Check if this is Spark's full serialization format
         let expected_bits_size = self.bits.byte_size();
         if buf.len() == SPARK_HEADER_SIZE + expected_bits_size {
@@ -179,7 +179,7 @@ impl SparkBloomFilter {
     pub fn merge_filter(&mut self, other: &[u8]) {
         // Extract bits data if other is in Spark's full serialization format
         let bits_data = self.extract_bits_from_spark_format(other);
-
+
         assert_eq!(
             bits_data.len(),
             self.bits.byte_size(),

From 030c67b6174f5e5ca05f376a15216041f2cc3538 Mon Sep 17 00:00:00 2001
From: shekharrajak
Date: Tue, 6 Jan 2026 10:05:29 +0530
Subject: [PATCH 3/7] Fix Rust lifetime and borrow checker errors in merge_filter

---
 .../src/bloom_filter/spark_bloom_filter.rs    | 27 +++++++------------
 1 file changed, 10 insertions(+), 17 deletions(-)

diff --git a/native/spark-expr/src/bloom_filter/spark_bloom_filter.rs b/native/spark-expr/src/bloom_filter/spark_bloom_filter.rs
index 1e315e9f88..f5ed086d27 100644
--- a/native/spark-expr/src/bloom_filter/spark_bloom_filter.rs
+++ b/native/spark-expr/src/bloom_filter/spark_bloom_filter.rs
@@ -159,32 +159,25 @@ impl SparkBloomFilter {
         self.bits.to_bytes()
     }
 
-    /// Extracts bits data from Spark's full serialization format.
-    /// Spark's format includes a 12-byte header (version + num_hash_functions + num_words)
-    /// followed by the bits data. This function extracts just the bits data.
-    fn extract_bits_from_spark_format(&self, buf: &[u8]) -> &[u8] {
+    pub fn merge_filter(&mut self, other: &[u8]) {
+        // Extract bits data if other is in Spark's full serialization format
+        // We need to compute the expected size and extract data before borrowing self.bits mutably
+        let expected_bits_size = self.bits.byte_size();
         const SPARK_HEADER_SIZE: usize = 12; // version (4) + num_hash_functions (4) + num_words (4)
 
-        // Check if this is Spark's full serialization format
-        let expected_bits_size = self.bits.byte_size();
-        if buf.len() == SPARK_HEADER_SIZE + expected_bits_size {
+        let bits_data = if other.len() == SPARK_HEADER_SIZE + expected_bits_size {
             // This is Spark's full format, extract bits data (skip header)
-            &buf[SPARK_HEADER_SIZE..]
+            &other[SPARK_HEADER_SIZE..]
         } else {
             // This is already just bits data (Comet format)
-            buf
-        }
-    }
-
-    pub fn merge_filter(&mut self, other: &[u8]) {
-        // Extract bits data if other is in Spark's full serialization format
-        let bits_data = self.extract_bits_from_spark_format(other);
+            other
+        };
 
         assert_eq!(
             bits_data.len(),
-            self.bits.byte_size(),
+            expected_bits_size,
             "Cannot merge SparkBloomFilters with different lengths. Expected {} bytes, got {} bytes (full buffer size: {} bytes)",
-            self.bits.byte_size(),
+            expected_bits_size,
             bits_data.len(),
             other.len()
         );

From 49169a6aee875146961de287227f3a3e05de9ec6 Mon Sep 17 00:00:00 2001
From: shekharrajak
Date: Thu, 8 Jan 2026 20:41:58 +0530
Subject: [PATCH 4/7] Remove fallback and add test for Spark partial / Comet final BloomFilterAggregate merge

---
 .../apache/spark/sql/comet/operators.scala    |  8 ++-
 .../apache/comet/exec/CometExecSuite.scala    | 65 ++++++++++++++++++-
 2 files changed, 71 insertions(+), 2 deletions(-)

diff --git a/spark/src/main/scala/org/apache/spark/sql/comet/operators.scala b/spark/src/main/scala/org/apache/spark/sql/comet/operators.scala
index 0a435e5b7a..7860b8380f 100644
--- a/spark/src/main/scala/org/apache/spark/sql/comet/operators.scala
+++ b/spark/src/main/scala/org/apache/spark/sql/comet/operators.scala
@@ -1069,7 +1069,13 @@ trait CometBaseAggregate {
     val multiMode = modes.size > 1
     // For a final mode HashAggregate, we only need to transform the HashAggregate
     // if there is Comet partial aggregation.
-    val sparkFinalMode = modes.contains(Final) && findCometPartialAgg(aggregate.child).isEmpty
+    // Exception: BloomFilterAggregate supports Spark partial / Comet final because
+    // merge_filter() handles Spark's serialization format (12-byte header + bits).
+    val hasBloomFilterAgg = aggregate.aggregateExpressions.exists(expr =>
+      expr.aggregateFunction.getClass.getSimpleName == "BloomFilterAggregate")
+    val sparkFinalMode = modes.contains(Final) &&
+      findCometPartialAgg(aggregate.child).isEmpty &&
+      !hasBloomFilterAgg
 
     if (multiMode || sparkFinalMode) {
       return None
diff --git a/spark/src/test/scala/org/apache/comet/exec/CometExecSuite.scala b/spark/src/test/scala/org/apache/comet/exec/CometExecSuite.scala
index 1b2373ad71..b40ce6c2bb 100644
--- a/spark/src/test/scala/org/apache/comet/exec/CometExecSuite.scala
+++ b/spark/src/test/scala/org/apache/comet/exec/CometExecSuite.scala
@@ -36,7 +36,8 @@ import org.apache.spark.sql.catalyst.expressions.aggregate.{AggregateMode, Bloom
 import org.apache.spark.sql.comet._
 import org.apache.spark.sql.comet.execution.shuffle.{CometColumnarShuffle, CometShuffleExchangeExec}
 import org.apache.spark.sql.execution.{CollectLimitExec, ProjectExec, SQLExecution, UnionExec}
-import org.apache.spark.sql.execution.adaptive.{AdaptiveSparkPlanExec, BroadcastQueryStageExec}
+import org.apache.spark.sql.execution.adaptive.{AdaptiveSparkPlanExec, BroadcastQueryStageExec, QueryStageExec}
+import org.apache.spark.sql.execution.aggregate.HashAggregateExec
 import org.apache.spark.sql.execution.datasources.parquet.ParquetFileFormat
 import org.apache.spark.sql.execution.exchange.{BroadcastExchangeExec, ReusedExchangeExec, ShuffleExchangeExec}
 import org.apache.spark.sql.execution.joins.{BroadcastHashJoinExec, BroadcastNestedLoopJoinExec, CartesianProductExec, SortMergeJoinExec}
@@ -1149,6 +1150,68 @@ class CometExecSuite extends CometTestBase {
     spark.sessionState.functionRegistry.dropFunction(funcId_bloom_filter_agg)
   }
 
+  test("bloom_filter_agg - Spark partial / Comet final merge") {
+    // This test exercises the merge_filter() fix that handles Spark's full serialization
+    // format (12-byte header + bits) when merging from Spark partial to Comet final aggregates.
+    val funcId_bloom_filter_agg = new FunctionIdentifier("bloom_filter_agg")
+    spark.sessionState.functionRegistry.registerFunction(
+      funcId_bloom_filter_agg,
+      new ExpressionInfo(classOf[BloomFilterAggregate].getName, "bloom_filter_agg"),
+      (children: Seq[Expression]) =>
+        children.size match {
+          case 1 => new BloomFilterAggregate(children.head)
+          case 2 => new BloomFilterAggregate(children.head, children(1))
+          case 3 => new BloomFilterAggregate(children.head, children(1), children(2))
+        })
+
+    // Helper to count operators in plan
+    def countOperators(plan: SparkPlan, opClass: Class[_]): Int = {
+      stripAQEPlan(plan).collect {
+        case stage: QueryStageExec =>
+          countOperators(stage.plan, opClass)
+        case op if op.getClass.isAssignableFrom(opClass) => 1
+      }.sum
+    }
+
+    withParquetTable(
+      (0 until 1000)
+        .map(_ => (Random.nextInt(1000), Random.nextInt(100))),
+      "tbl") {
+
+      withSQLConf(
+        // Disable Comet partial aggregates to force Spark partial / Comet final scenario
+        CometConf.COMET_ENABLE_PARTIAL_HASH_AGGREGATE.key -> "false",
+        CometConf.COMET_EXEC_AGGREGATE_ENABLED.key -> "true",
+        CometConf.COMET_EXEC_SHUFFLE_ENABLED.key -> "true") {
+
+        val df = sql(
+          "SELECT bloom_filter_agg(cast(_2 as long), cast(1000 as long)) FROM tbl GROUP BY _1")
+
+        // Verify the query executes successfully (tests merge_filter compatibility)
+        checkSparkAnswer(df)
+
+        // Verify we have Spark partial aggregates and Comet final aggregates
+        val plan = stripAQEPlan(df.queryExecution.executedPlan)
+        val sparkPartialAggs = plan.collect {
+          case agg: HashAggregateExec if agg.aggregateExpressions.exists(_.mode == Partial) => agg
+        }
+        val cometFinalAggs = plan.collect {
+          case agg: CometHashAggregateExec if agg.aggregateExpressions.exists(_.mode == Final) =>
+            agg
+        }
+
+        assert(
+          sparkPartialAggs.nonEmpty,
+          s"Expected Spark partial aggregates but found none. Plan: $plan")
+        assert(
+          cometFinalAggs.nonEmpty,
+          s"Expected Comet final aggregates but found none. Plan: $plan")
Plan: $plan") + } + } + + spark.sessionState.functionRegistry.dropFunction(funcId_bloom_filter_agg) + } + test("sort (non-global)") { withParquetTable((0 until 5).map(i => (i, i + 1)), "tbl") { val df = sql("SELECT * FROM tbl").sortWithinPartitions($"_1".desc) From 0b34826a677fe580c4030f00516ea08a102af776 Mon Sep 17 00:00:00 2001 From: shekharrajak Date: Wed, 14 Jan 2026 16:03:52 +0530 Subject: [PATCH 5/7] Fix missing imports in CometExecSuite: add SparkPlan, Partial, and Final --- .../src/test/scala/org/apache/comet/exec/CometExecSuite.scala | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/spark/src/test/scala/org/apache/comet/exec/CometExecSuite.scala b/spark/src/test/scala/org/apache/comet/exec/CometExecSuite.scala index b40ce6c2bb..afa8aab949 100644 --- a/spark/src/test/scala/org/apache/comet/exec/CometExecSuite.scala +++ b/spark/src/test/scala/org/apache/comet/exec/CometExecSuite.scala @@ -32,10 +32,10 @@ import org.apache.spark.sql._ import org.apache.spark.sql.catalyst.{FunctionIdentifier, TableIdentifier} import org.apache.spark.sql.catalyst.catalog.{BucketSpec, CatalogStatistics, CatalogTable} import org.apache.spark.sql.catalyst.expressions.{Expression, ExpressionInfo, Hex} -import org.apache.spark.sql.catalyst.expressions.aggregate.{AggregateMode, BloomFilterAggregate} +import org.apache.spark.sql.catalyst.expressions.aggregate.{AggregateMode, BloomFilterAggregate, Final, Partial} import org.apache.spark.sql.comet._ import org.apache.spark.sql.comet.execution.shuffle.{CometColumnarShuffle, CometShuffleExchangeExec} -import org.apache.spark.sql.execution.{CollectLimitExec, ProjectExec, SQLExecution, UnionExec} +import org.apache.spark.sql.execution.{CollectLimitExec, ProjectExec, SparkPlan, SQLExecution, UnionExec} import org.apache.spark.sql.execution.adaptive.{AdaptiveSparkPlanExec, BroadcastQueryStageExec, QueryStageExec} import org.apache.spark.sql.execution.aggregate.HashAggregateExec import org.apache.spark.sql.execution.datasources.parquet.ParquetFileFormat From 559caecda1865509b2ccb6d80885c00262688517 Mon Sep 17 00:00:00 2001 From: shekharrajak Date: Wed, 14 Jan 2026 16:11:19 +0530 Subject: [PATCH 6/7] Fix Rust compilation errors: make allocators and HDFS features mutually exclusive --- native/core/src/lib.rs | 10 ++++++++-- native/core/src/parquet/parquet_support.rs | 4 ++-- 2 files changed, 10 insertions(+), 4 deletions(-) diff --git a/native/core/src/lib.rs b/native/core/src/lib.rs index 10ecefad5b..1282aa5b93 100644 --- a/native/core/src/lib.rs +++ b/native/core/src/lib.rs @@ -42,7 +42,10 @@ use once_cell::sync::OnceCell; #[cfg(all(not(target_env = "msvc"), feature = "jemalloc"))] use tikv_jemallocator::Jemalloc; -#[cfg(feature = "mimalloc")] +#[cfg(all( + feature = "mimalloc", + not(all(not(target_env = "msvc"), feature = "jemalloc")) +))] use mimalloc::MiMalloc; use errors::{try_unwrap_or_throw, CometError, CometResult}; @@ -59,7 +62,10 @@ pub mod parquet; #[global_allocator] static GLOBAL: Jemalloc = Jemalloc; -#[cfg(feature = "mimalloc")] +#[cfg(all( + feature = "mimalloc", + not(all(not(target_env = "msvc"), feature = "jemalloc")) +))] #[global_allocator] static GLOBAL: MiMalloc = MiMalloc; diff --git a/native/core/src/parquet/parquet_support.rs b/native/core/src/parquet/parquet_support.rs index 0b5c45d24d..f87dbce7b9 100644 --- a/native/core/src/parquet/parquet_support.rs +++ b/native/core/src/parquet/parquet_support.rs @@ -385,7 +385,7 @@ fn parse_hdfs_url(url: &Url) -> Result<(Box, Path), object_stor } } -#[cfg(feature = "hdfs-opendal")] 
+#[cfg(all(feature = "hdfs-opendal", not(feature = "hdfs")))]
 fn parse_hdfs_url(url: &Url) -> Result<(Box<dyn ObjectStore>, Path), object_store::Error> {
     let name_node = get_name_node_uri(url)?;
     let builder = opendal::services::Hdfs::default().name_node(&name_node);
@@ -401,7 +401,7 @@ fn parse_hdfs_url(url: &Url) -> Result<(Box<dyn ObjectStore>, Path), object_stor
     Ok((Box::new(store), path))
 }
 
-#[cfg(feature = "hdfs-opendal")]
+#[cfg(all(feature = "hdfs-opendal", not(feature = "hdfs")))]
 fn get_name_node_uri(url: &Url) -> Result<String, object_store::Error> {
     use std::fmt::Write;
     if let Some(host) = url.host() {

From 6018e4ae47cad1fcd58f9ad3325ab9b10912bf56 Mon Sep 17 00:00:00 2001
From: shekharrajak
Date: Wed, 14 Jan 2026 22:31:43 +0530
Subject: [PATCH 7/7] Improve bloom filter merge to handle Spark partial aggregate format with better error messages

---
 native/core/src/parquet/parquet_support.rs    |  2 +-
 .../src/bloom_filter/spark_bloom_filter.rs    | 73 ++++++++++++++++---
 2 files changed, 62 insertions(+), 13 deletions(-)

diff --git a/native/core/src/parquet/parquet_support.rs b/native/core/src/parquet/parquet_support.rs
index f63a3ef687..c9a27d7dcb 100644
--- a/native/core/src/parquet/parquet_support.rs
+++ b/native/core/src/parquet/parquet_support.rs
@@ -406,7 +406,7 @@ fn create_hdfs_object_store(
     Ok((Box::new(store), path))
 }
 
-#[cfg(all(feature = "hdfs-opendal", not(feature = "hdfs")))]
+#[cfg(feature = "hdfs-opendal")]
 fn get_name_node_uri(url: &Url) -> Result<String, object_store::Error> {
     use std::fmt::Write;
     if let Some(host) = url.host() {
diff --git a/native/spark-expr/src/bloom_filter/spark_bloom_filter.rs b/native/spark-expr/src/bloom_filter/spark_bloom_filter.rs
index f5ed086d27..ed2c943b26 100644
--- a/native/spark-expr/src/bloom_filter/spark_bloom_filter.rs
+++ b/native/spark-expr/src/bloom_filter/spark_bloom_filter.rs
@@ -165,22 +165,71 @@ impl SparkBloomFilter {
         let expected_bits_size = self.bits.byte_size();
         const SPARK_HEADER_SIZE: usize = 12; // version (4) + num_hash_functions (4) + num_words (4)
 
-        let bits_data = if other.len() == SPARK_HEADER_SIZE + expected_bits_size {
-            // This is Spark's full format, extract bits data (skip header)
-            &other[SPARK_HEADER_SIZE..]
+        let bits_data = if other.len() >= SPARK_HEADER_SIZE {
+            // Check if this is Spark's serialization format by reading the version
+            let version = i32::from_be_bytes([
+                other[0], other[1], other[2], other[3],
+            ]);
+            if version == SPARK_BLOOM_FILTER_VERSION_1 {
+                // This is Spark's full format, parse it to extract bits data
+                let num_words = i32::from_be_bytes([
+                    other[8], other[9], other[10], other[11],
+                ]) as usize;
+                let bits_start = SPARK_HEADER_SIZE;
+                let bits_end = bits_start + (num_words * 8);
+
+                // Verify the buffer is large enough
+                if bits_end > other.len() {
+                    panic!(
+                        "Cannot merge SparkBloomFilters: buffer too short. Expected at least {} bytes ({} words), got {} bytes",
+                        bits_end,
+                        num_words,
+                        other.len()
+                    );
+                }
+
+                // Check if the incoming bloom filter has compatible size
+                let incoming_bits_size = bits_end - bits_start;
+                if incoming_bits_size != expected_bits_size {
+                    panic!(
+                        "Cannot merge SparkBloomFilters with incompatible sizes. Expected {} bytes ({} words), got {} bytes ({} words) from Spark partial aggregate. Full buffer size: {} bytes",
+                        expected_bits_size,
+                        self.bits.word_size(),
+                        incoming_bits_size,
+                        num_words,
+                        other.len()
+                    );
+                }
+
+                // Extract just the bits portion
+                &other[bits_start..bits_end]
+            } else if other.len() == expected_bits_size {
+                // Not Spark format but size matches, treat as raw bits data (Comet format)
+                other
+            } else {
+                // Size doesn't match and not Spark format - provide helpful error
+                panic!(
+                    "Cannot merge SparkBloomFilters: unexpected format. Expected {} bytes (Comet format) or Spark format (version 1, {} bytes header + bits), but got {} bytes with version {}",
+                    expected_bits_size,
+                    SPARK_HEADER_SIZE,
+                    other.len(),
+                    version
+                );
+            }
         } else {
-            // This is already just bits data (Comet format)
+            // Too short to be Spark format
+            if other.len() != expected_bits_size {
+                panic!(
+                    "Cannot merge SparkBloomFilters: buffer too short. Expected {} bytes (Comet format) or at least {} bytes (Spark format), got {} bytes",
+                    expected_bits_size,
+                    SPARK_HEADER_SIZE,
+                    other.len()
+                );
+            }
+            // Size matches, treat as raw bits data
             other
         };
 
-        assert_eq!(
-            bits_data.len(),
-            expected_bits_size,
-            "Cannot merge SparkBloomFilters with different lengths. Expected {} bytes, got {} bytes (full buffer size: {} bytes)",
-            expected_bits_size,
-            bits_data.len(),
-            other.len()
-        );
         self.bits.merge_bits(bits_data);
     }
 }
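
Note for reviewers (appended to the series, not part of any patch): the sketch below is a minimal, self-contained illustration of the buffer layout these patches assume — a big-endian i32 version, i32 number of hash functions, and i32 word count, followed by num_words * 8 bytes of bit words — together with the same header check that merge_filter now performs. It does not touch Comet's SparkBitArray/merge_bits API, and the helper names (encode_spark_format, bits_region) are invented for this example.

// Standalone sketch (not Comet code). Spark's "full" BloomFilter buffer:
//   version: i32 BE | num_hash_functions: i32 BE | num_words: i32 BE | num_words * 8 bytes of words
const SPARK_HEADER_SIZE: usize = 12;
const SPARK_BLOOM_FILTER_VERSION_1: i32 = 1;

/// Hypothetical helper: build a Spark-style full buffer from raw 64-bit words.
fn encode_spark_format(num_hash_functions: i32, words: &[u64]) -> Vec<u8> {
    let mut buf = Vec::with_capacity(SPARK_HEADER_SIZE + words.len() * 8);
    buf.extend_from_slice(&SPARK_BLOOM_FILTER_VERSION_1.to_be_bytes());
    buf.extend_from_slice(&num_hash_functions.to_be_bytes());
    buf.extend_from_slice(&(words.len() as i32).to_be_bytes());
    for word in words {
        buf.extend_from_slice(&word.to_be_bytes());
    }
    buf
}

/// Mirrors the header detection in merge_filter: returns only the bits region.
/// (A real implementation would bounds-check before slicing, as PATCH 7 does.)
fn bits_region(buf: &[u8]) -> &[u8] {
    if buf.len() >= SPARK_HEADER_SIZE
        && i32::from_be_bytes(buf[0..4].try_into().unwrap()) == SPARK_BLOOM_FILTER_VERSION_1
    {
        let num_words = i32::from_be_bytes(buf[8..12].try_into().unwrap()) as usize;
        &buf[SPARK_HEADER_SIZE..SPARK_HEADER_SIZE + num_words * 8]
    } else {
        // Already just bits data (Comet format)
        buf
    }
}

fn main() {
    // 16 words = 128 bytes of bit data, 3 hash functions.
    let words = vec![0u64; 16];
    let spark_buf = encode_spark_format(3, &words);
    assert_eq!(spark_buf.len(), SPARK_HEADER_SIZE + 128);
    // A Spark partial-aggregate buffer and a raw Comet buffer reduce to the same bits region.
    assert_eq!(bits_region(&spark_buf).len(), 128);
    let comet_buf: Vec<u8> = words.iter().flat_map(|w| w.to_be_bytes()).collect();
    assert_eq!(bits_region(&comet_buf).len(), 128);
}

As in PATCH 7, format detection keys off the leading version word; a raw bits buffer whose first four bytes happened to equal the version constant would be routed down the Spark path, which is why the patch also cross-checks the decoded word count against the expected size before merging.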