diff --git a/native/spark-expr/src/bloom_filter/spark_bloom_filter.rs b/native/spark-expr/src/bloom_filter/spark_bloom_filter.rs
index e84257ea67..ed2c943b26 100644
--- a/native/spark-expr/src/bloom_filter/spark_bloom_filter.rs
+++ b/native/spark-expr/src/bloom_filter/spark_bloom_filter.rs
@@ -160,11 +160,76 @@ impl SparkBloomFilter {
     }
 
     pub fn merge_filter(&mut self, other: &[u8]) {
-        assert_eq!(
-            other.len(),
-            self.bits.byte_size(),
-            "Cannot merge SparkBloomFilters with different lengths."
-        );
-        self.bits.merge_bits(other);
+        // Extract the bits data if `other` is in Spark's full serialization format.
+        // We need to compute the expected size and extract the data before borrowing self.bits mutably.
+        let expected_bits_size = self.bits.byte_size();
+        const SPARK_HEADER_SIZE: usize = 12; // version (4) + num_hash_functions (4) + num_words (4)
+
+        let bits_data = if other.len() >= SPARK_HEADER_SIZE {
+            // Check if this is Spark's serialization format by reading the version
+            let version = i32::from_be_bytes([
+                other[0], other[1], other[2], other[3],
+            ]);
+            if version == SPARK_BLOOM_FILTER_VERSION_1 {
+                // This is Spark's full format; parse it to extract the bits data
+                let num_words = i32::from_be_bytes([
+                    other[8], other[9], other[10], other[11],
+                ]) as usize;
+                let bits_start = SPARK_HEADER_SIZE;
+                let bits_end = bits_start + (num_words * 8);
+
+                // Verify the buffer is large enough
+                if bits_end > other.len() {
+                    panic!(
+                        "Cannot merge SparkBloomFilters: buffer too short. Expected at least {} bytes ({} words), got {} bytes",
+                        bits_end,
+                        num_words,
+                        other.len()
+                    );
+                }
+
+                // Check if the incoming bloom filter has a compatible size
+                let incoming_bits_size = bits_end - bits_start;
+                if incoming_bits_size != expected_bits_size {
+                    panic!(
+                        "Cannot merge SparkBloomFilters with incompatible sizes. Expected {} bytes ({} words), got {} bytes ({} words) from Spark partial aggregate. Full buffer size: {} bytes",
+                        expected_bits_size,
+                        self.bits.word_size(),
+                        incoming_bits_size,
+                        num_words,
+                        other.len()
+                    );
+                }
+
+                // Extract just the bits portion
+                &other[bits_start..bits_end]
+            } else if other.len() == expected_bits_size {
+                // Not Spark's format, but the size matches; treat as raw bits data (Comet format)
+                other
+            } else {
+                // Size doesn't match and not Spark's format: provide a helpful error
+                panic!(
+                    "Cannot merge SparkBloomFilters: unexpected format. Expected {} bytes (Comet format) or Spark format (version 1, {} bytes header + bits), but got {} bytes with version {}",
+                    expected_bits_size,
+                    SPARK_HEADER_SIZE,
+                    other.len(),
+                    version
+                );
+            }
+        } else {
+            // Too short to be Spark's format
+            if other.len() != expected_bits_size {
+                panic!(
+                    "Cannot merge SparkBloomFilters: buffer too short. Expected {} bytes (Comet format) or at least {} bytes (Spark format), got {} bytes",
+                    expected_bits_size,
+                    SPARK_HEADER_SIZE,
+                    other.len()
+                );
+            }
+            // Size matches; treat as raw bits data
+            other
+        };
+
+        self.bits.merge_bits(bits_data);
     }
 }
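
For reference, the Spark-format buffer that the new branch parses is Spark's bloom filter serialization as the patch describes it: a big-endian 12-byte header (version, number of hash functions, number of 64-bit words) followed by the bit words themselves. A minimal, self-contained sketch of that layout (illustrative only, not part of the patch; the helper name is made up):

fn encode_spark_bloom_filter(num_hash_functions: i32, words: &[u64]) -> Vec<u8> {
    // Header: version (4 bytes) + num_hash_functions (4 bytes) + num_words (4 bytes), big-endian
    let mut buf = Vec::with_capacity(12 + words.len() * 8);
    buf.extend_from_slice(&1i32.to_be_bytes()); // SPARK_BLOOM_FILTER_VERSION_1
    buf.extend_from_slice(&num_hash_functions.to_be_bytes());
    buf.extend_from_slice(&(words.len() as i32).to_be_bytes());
    // Body: the bit array, one 64-bit word at a time, also big-endian
    for word in words {
        buf.extend_from_slice(&word.to_be_bytes());
    }
    buf
}

fn main() {
    let buf = encode_spark_bloom_filter(3, &[0u64; 4]);
    assert_eq!(buf.len(), 12 + 4 * 8); // 12-byte header + 32 bytes of bit words
    assert_eq!(i32::from_be_bytes([buf[0], buf[1], buf[2], buf[3]]), 1); // version
}

A raw Comet buffer, by contrast, carries only the bit words, which is why merge_filter falls back to treating `other` as bits when its length equals self.bits.byte_size().
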
diff --git a/spark/src/main/scala/org/apache/spark/sql/comet/operators.scala b/spark/src/main/scala/org/apache/spark/sql/comet/operators.scala
index 0a435e5b7a..7860b8380f 100644
--- a/spark/src/main/scala/org/apache/spark/sql/comet/operators.scala
+++ b/spark/src/main/scala/org/apache/spark/sql/comet/operators.scala
@@ -1069,7 +1069,13 @@ trait CometBaseAggregate {
     val multiMode = modes.size > 1
     // For a final mode HashAggregate, we only need to transform the HashAggregate
     // if there is Comet partial aggregation.
-    val sparkFinalMode = modes.contains(Final) && findCometPartialAgg(aggregate.child).isEmpty
+    // Exception: BloomFilterAggregate supports Spark partial / Comet final because
+    // merge_filter() handles Spark's serialization format (12-byte header + bits).
+    val hasBloomFilterAgg = aggregate.aggregateExpressions.exists(expr =>
+      expr.aggregateFunction.getClass.getSimpleName == "BloomFilterAggregate")
+    val sparkFinalMode = modes.contains(Final) &&
+      findCometPartialAgg(aggregate.child).isEmpty &&
+      !hasBloomFilterAgg
 
     if (multiMode || sparkFinalMode) {
       return None
diff --git a/spark/src/test/scala/org/apache/comet/exec/CometExecSuite.scala b/spark/src/test/scala/org/apache/comet/exec/CometExecSuite.scala
index 1b2373ad71..afa8aab949 100644
--- a/spark/src/test/scala/org/apache/comet/exec/CometExecSuite.scala
+++ b/spark/src/test/scala/org/apache/comet/exec/CometExecSuite.scala
@@ -32,11 +32,12 @@ import org.apache.spark.sql._
 import org.apache.spark.sql.catalyst.{FunctionIdentifier, TableIdentifier}
 import org.apache.spark.sql.catalyst.catalog.{BucketSpec, CatalogStatistics, CatalogTable}
 import org.apache.spark.sql.catalyst.expressions.{Expression, ExpressionInfo, Hex}
-import org.apache.spark.sql.catalyst.expressions.aggregate.{AggregateMode, BloomFilterAggregate}
+import org.apache.spark.sql.catalyst.expressions.aggregate.{AggregateMode, BloomFilterAggregate, Final, Partial}
 import org.apache.spark.sql.comet._
 import org.apache.spark.sql.comet.execution.shuffle.{CometColumnarShuffle, CometShuffleExchangeExec}
-import org.apache.spark.sql.execution.{CollectLimitExec, ProjectExec, SQLExecution, UnionExec}
-import org.apache.spark.sql.execution.adaptive.{AdaptiveSparkPlanExec, BroadcastQueryStageExec}
+import org.apache.spark.sql.execution.{CollectLimitExec, ProjectExec, SparkPlan, SQLExecution, UnionExec}
+import org.apache.spark.sql.execution.adaptive.{AdaptiveSparkPlanExec, BroadcastQueryStageExec, QueryStageExec}
+import org.apache.spark.sql.execution.aggregate.HashAggregateExec
 import org.apache.spark.sql.execution.datasources.parquet.ParquetFileFormat
 import org.apache.spark.sql.execution.exchange.{BroadcastExchangeExec, ReusedExchangeExec, ShuffleExchangeExec}
 import org.apache.spark.sql.execution.joins.{BroadcastHashJoinExec, BroadcastNestedLoopJoinExec, CartesianProductExec, SortMergeJoinExec}
@@ -1149,6 +1150,68 @@ class CometExecSuite extends CometTestBase {
     spark.sessionState.functionRegistry.dropFunction(funcId_bloom_filter_agg)
   }
 
+  test("bloom_filter_agg - Spark partial / Comet final merge") {
+    // This test exercises the merge_filter() fix that handles Spark's full serialization
+    // format (12-byte header + bits) when merging from Spark partial to Comet final aggregates.
+    val funcId_bloom_filter_agg = new FunctionIdentifier("bloom_filter_agg")
+    spark.sessionState.functionRegistry.registerFunction(
+      funcId_bloom_filter_agg,
+      new ExpressionInfo(classOf[BloomFilterAggregate].getName, "bloom_filter_agg"),
+      (children: Seq[Expression]) =>
+        children.size match {
+          case 1 => new BloomFilterAggregate(children.head)
+          case 2 => new BloomFilterAggregate(children.head, children(1))
+          case 3 => new BloomFilterAggregate(children.head, children(1), children(2))
+        })
+
+    // Helper to count operators of a given class in the plan, descending into AQE query stages
+    def countOperators(plan: SparkPlan, opClass: Class[_]): Int = {
+      stripAQEPlan(plan).collect {
+        case stage: QueryStageExec =>
+          countOperators(stage.plan, opClass)
+        case op if opClass.isInstance(op) => 1
+      }.sum
+    }
+
+    withParquetTable(
+      (0 until 1000)
+        .map(_ => (Random.nextInt(1000), Random.nextInt(100))),
+      "tbl") {
+
+      withSQLConf(
+        // Disable Comet partial aggregates to force the Spark partial / Comet final scenario
+        CometConf.COMET_ENABLE_PARTIAL_HASH_AGGREGATE.key -> "false",
+        CometConf.COMET_EXEC_AGGREGATE_ENABLED.key -> "true",
+        CometConf.COMET_EXEC_SHUFFLE_ENABLED.key -> "true") {
+
+        val df = sql(
+          "SELECT bloom_filter_agg(cast(_2 as long), cast(1000 as long)) FROM tbl GROUP BY _1")
+
+        // Verify the query executes successfully (tests merge_filter compatibility)
+        checkSparkAnswer(df)
+
+        // Verify we have Spark partial aggregates and Comet final aggregates
+        val plan = stripAQEPlan(df.queryExecution.executedPlan)
+        val sparkPartialAggs = plan.collect {
+          case agg: HashAggregateExec if agg.aggregateExpressions.exists(_.mode == Partial) => agg
+        }
+        val cometFinalAggs = plan.collect {
+          case agg: CometHashAggregateExec if agg.aggregateExpressions.exists(_.mode == Final) =>
+            agg
+        }
+
+        assert(
+          sparkPartialAggs.nonEmpty,
+          s"Expected Spark partial aggregates but found none. Plan: $plan")
+        assert(
+          cometFinalAggs.nonEmpty,
+          s"Expected Comet final aggregates but found none. Plan: $plan")
+      }
+    }
+
+    spark.sessionState.functionRegistry.dropFunction(funcId_bloom_filter_agg)
+  }
+
   test("sort (non-global)") {
     withParquetTable((0 until 5).map(i => (i, i + 1)), "tbl") {
       val df = sql("SELECT * FROM tbl").sortWithinPartitions($"_1".desc)
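
The end-to-end test above drives Spark partial aggregates into a Comet final aggregate, so the bytes reaching merge_filter are in Spark's full format. Distilled to its core, the detection rule being exercised looks roughly like the sketch below (illustrative only; the function name is made up, and the real merge_filter panics with detailed messages where this sketch returns None):

// Locate the bit words inside a serialized partial-aggregate buffer.
// Accepts either Spark's full format (12-byte header + words) or a raw
// Comet bit buffer whose length already equals the expected size.
fn locate_bits(buf: &[u8], expected_bits_size: usize) -> Option<std::ops::Range<usize>> {
    const HEADER: usize = 12; // version (4) + num_hash_functions (4) + num_words (4)
    if buf.len() >= HEADER && i32::from_be_bytes([buf[0], buf[1], buf[2], buf[3]]) == 1 {
        let num_words = i32::from_be_bytes([buf[8], buf[9], buf[10], buf[11]]) as usize;
        let end = HEADER + num_words * 8;
        if end <= buf.len() && end - HEADER == expected_bits_size {
            return Some(HEADER..end); // Spark partial-aggregate format
        }
    }
    if buf.len() == expected_bits_size {
        return Some(0..buf.len()); // raw Comet bit buffer
    }
    None
}

fn main() {
    // A Spark-format buffer with two 64-bit words: the bit words live at bytes 12..28.
    let mut spark_buf = Vec::new();
    spark_buf.extend_from_slice(&1i32.to_be_bytes()); // version
    spark_buf.extend_from_slice(&3i32.to_be_bytes()); // num hash functions
    spark_buf.extend_from_slice(&2i32.to_be_bytes()); // num words
    spark_buf.extend_from_slice(&[0u8; 16]); // the words themselves
    assert_eq!(locate_bits(&spark_buf, 16), Some(12..28));
    // A raw Comet buffer of the same bit size is accepted as-is.
    assert_eq!(locate_bits(&[0u8; 16], 16), Some(0..16));
}
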