From afad4bbed58f5f6882e89e696150be1b7768e9c4 Mon Sep 17 00:00:00 2001 From: Mryange Date: Thu, 26 Mar 2026 09:32:25 +0800 Subject: [PATCH 1/4] upd --- be/src/exec/common/agg_context_utils.h | 133 +++ be/src/exec/common/groupby_agg_context.cpp | 981 ++++++++++++++++++ be/src/exec/common/groupby_agg_context.h | 317 ++++++ .../exec/common/inline_count_agg_context.cpp | 372 +++++++ be/src/exec/common/inline_count_agg_context.h | 91 ++ be/src/exec/common/ungroupby_agg_context.cpp | 257 +++++ be/src/exec/common/ungroupby_agg_context.h | 118 +++ .../operator/aggregation_sink_operator.cpp | 843 ++------------- .../exec/operator/aggregation_sink_operator.h | 82 +- .../operator/aggregation_source_operator.cpp | 653 +----------- .../operator/aggregation_source_operator.h | 41 - .../partitioned_aggregation_sink_operator.cpp | 43 +- ...artitioned_aggregation_source_operator.cpp | 27 +- .../streaming_aggregation_operator.cpp | 768 ++------------ .../operator/streaming_aggregation_operator.h | 138 +-- be/src/exec/pipeline/dependency.cpp | 142 --- be/src/exec/pipeline/dependency.h | 156 +-- 17 files changed, 2603 insertions(+), 2559 deletions(-) create mode 100644 be/src/exec/common/agg_context_utils.h create mode 100644 be/src/exec/common/groupby_agg_context.cpp create mode 100644 be/src/exec/common/groupby_agg_context.h create mode 100644 be/src/exec/common/inline_count_agg_context.cpp create mode 100644 be/src/exec/common/inline_count_agg_context.h create mode 100644 be/src/exec/common/ungroupby_agg_context.cpp create mode 100644 be/src/exec/common/ungroupby_agg_context.h diff --git a/be/src/exec/common/agg_context_utils.h b/be/src/exec/common/agg_context_utils.h new file mode 100644 index 00000000000000..65d9b41d6d02bd --- /dev/null +++ b/be/src/exec/common/agg_context_utils.h @@ -0,0 +1,133 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. 
See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +#pragma once + +#include "core/block/block.h" +#include "exprs/vectorized_agg_fn.h" +#include "exprs/vexpr_context.h" +#include "exprs/vslot_ref.h" + +/// Utility functions for aggregation context result output. +/// Eliminates duplicate column-preparation and block-assembly patterns +/// across GroupByAggContext, InlineCountAggContext, and UngroupByAggContext. +namespace doris::agg_context_utils { + +/// Take existing columns from block [start, start+count) if mem_reuse, +/// otherwise create new columns via create_fn(index). 
+///
+/// @param block      the output block
+/// @param mem_reuse  whether the block supports memory reuse
+/// @param start      starting column position in the block
+/// @param count      number of columns to take or create
+/// @param create_fn  callable(size_t i) -> MutableColumnPtr for non-reuse path; i is in [0, count)
+/// @return MutableColumns with count elements
+template <typename CreateFn>
+MutableColumns take_or_create_columns(Block* block, bool mem_reuse, size_t start, size_t count,
+                                      CreateFn&& create_fn) {
+    MutableColumns columns;
+    columns.reserve(count);
+    for (size_t i = 0; i < count; ++i) {
+        if (mem_reuse) {
+            columns.emplace_back(std::move(*block->get_by_position(start + i).column).mutate());
+        } else {
+            columns.emplace_back(create_fn(i)); // index is relative to start, not a block position
+        }
+    }
+    return columns;
+}
+
+/// Assemble a finalized output block from schema + key/value columns (non mem_reuse path).
+///
+/// @param block               the output block to overwrite
+/// @param columns_with_schema the target schema; the block is first assigned from it
+/// @param key_columns         key columns to place at [0, key_size) (moved from)
+/// @param value_columns       value columns to place at [key_size, ...) (moved from)
+/// @param key_size            number of key columns
+inline void assemble_finalized_output(Block* block, const ColumnsWithTypeAndName& columns_with_schema,
+                                      MutableColumns& key_columns, MutableColumns& value_columns,
+                                      size_t key_size) {
+    *block = columns_with_schema;
+    MutableColumns columns(block->columns()); // one slot per column of the freshly assigned block
+    for (size_t i = 0; i < columns.size(); ++i) {
+        if (i < key_size) {
+            columns[i] = std::move(key_columns[i]);
+        } else {
+            columns[i] = std::move(value_columns[i - key_size]);
+        }
+    }
+    block->set_columns(std::move(columns));
+}
+
+/// Build a serialized output block from key expr types + value data types (non mem_reuse path).
+///
+/// @param block         the output block to overwrite
+/// @param key_columns   key columns (moved into the new block)
+/// @param key_exprs     groupby expression contexts (for type and name), indexed like key_columns
+/// @param value_columns value columns (moved into the new block)
+/// @param value_types   data types for value columns, indexed like value_columns
+inline void build_serialized_output_block(Block* block, MutableColumns& key_columns,
+                                          const VExprContextSPtrs& key_exprs,
+                                          MutableColumns& value_columns,
+                                          const DataTypes& value_types) {
+    ColumnsWithTypeAndName schema;
+    schema.reserve(key_columns.size() + value_columns.size());
+    for (size_t i = 0; i < key_columns.size(); ++i) {
+        schema.emplace_back(std::move(key_columns[i]), key_exprs[i]->root()->data_type(),
+                            key_exprs[i]->root()->expr_name());
+    }
+    for (size_t i = 0; i < value_columns.size(); ++i) {
+        schema.emplace_back(std::move(value_columns[i]), value_types[i], "");
+    }
+    *block = Block(std::move(schema)); // schema is not used afterwards, so hand it over
+}
+
+/// Overload for streaming agg passthrough: keys come from ColumnRawPtrs (clone + resize).
+///
+/// @param block         the output block to overwrite (via swap)
+/// @param key_columns   raw key column pointers (will be clone_resized)
+/// @param rows          number of rows to clone
+/// @param key_exprs     groupby expression contexts (for type and name), indexed like key_columns
+/// @param value_columns value columns (moved into the new block)
+/// @param value_types   data types for value columns, indexed like value_columns
+inline void build_serialized_output_block(Block* block, ColumnRawPtrs& key_columns, uint32_t rows,
+                                          const VExprContextSPtrs& key_exprs,
+                                          MutableColumns& value_columns,
+                                          const DataTypes& value_types) {
+    ColumnsWithTypeAndName schema;
+    schema.reserve(key_columns.size() + value_columns.size());
+    for (size_t i = 0; i < key_columns.size(); ++i) {
+        schema.emplace_back(key_columns[i]->clone_resized(rows),
+                            key_exprs[i]->root()->data_type(), key_exprs[i]->root()->expr_name());
+    }
+    for (size_t i = 0; i < value_columns.size(); ++i) {
+        schema.emplace_back(std::move(value_columns[i]), value_types[i], "");
+    }
+    block->swap(Block(std::move(schema))); // schema is not used afterwards, so hand it over
+}
+
+/// Get the input column id from an evaluator's single SlotRef input expression.
+/// Unified version used by both GroupByAggContext and UngroupByAggContext.
+inline int get_slot_column_id(const AggFnEvaluator* evaluator) {
+    const auto& ctxs = evaluator->input_exprs_ctxs();
+    DCHECK(ctxs.size() == 1) << "input_exprs_ctxs is invalid"; // guards the ctxs[0] reads below
+    DCHECK(ctxs[0]->root()->is_slot_ref()) << "input_exprs_ctxs is invalid, input_exprs_ctx[0]="
+                                           << ctxs[0]->root()->debug_string();
+    return static_cast<VSlotRef*>(ctxs[0]->root().get())->column_id();
+}
+
+} // namespace doris::agg_context_utils
diff --git a/be/src/exec/common/groupby_agg_context.cpp b/be/src/exec/common/groupby_agg_context.cpp
new file mode 100644
index 00000000000000..4975cc6cc9a69b
--- /dev/null
+++ b/be/src/exec/common/groupby_agg_context.cpp
@@ -0,0 +1,981 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. 
See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +#include "exec/common/groupby_agg_context.h" + +#include + +#include "common/config.h" +#include "common/exception.h" +#include "exec/common/agg_context_utils.h" +#include "exec/common/columns_hashing.h" +#include "exec/common/hash_table/hash_map_context.h" +#include "exec/common/hash_table/hash_map_util.h" +#include "exec/common/template_helpers.hpp" +#include "exprs/vectorized_agg_fn.h" +#include "exprs/vexpr_context.h" +#include "exprs/vslot_ref.h" +#include "runtime/runtime_state.h" + +namespace doris { + +GroupByAggContext::GroupByAggContext(std::vector agg_evaluators, + VExprContextSPtrs groupby_expr_ctxs, Sizes agg_state_offsets, + size_t total_agg_state_size, size_t agg_state_alignment, + bool is_first_phase) + : _hash_table_data(std::make_unique()), + _agg_evaluators(std::move(agg_evaluators)), + _groupby_expr_ctxs(std::move(groupby_expr_ctxs)), + _agg_state_offsets(std::move(agg_state_offsets)), + _total_agg_state_size(total_agg_state_size), + _agg_state_alignment(agg_state_alignment), + _is_first_phase(is_first_phase) {} + +GroupByAggContext::~GroupByAggContext() = default; + +// ==================== Profile initialization ==================== + +void GroupByAggContext::init_sink_profile(RuntimeProfile* profile) { + _hash_table_compute_timer = 
ADD_TIMER(profile, "HashTableComputeTime"); + _hash_table_emplace_timer = ADD_TIMER(profile, "HashTableEmplaceTime"); + _hash_table_input_counter = ADD_COUNTER(profile, "HashTableInputCount", TUnit::UNIT); + _hash_table_limit_compute_timer = ADD_TIMER(profile, "DoLimitComputeTime"); + _build_timer = ADD_TIMER(profile, "BuildTime"); + _merge_timer = ADD_TIMER(profile, "MergeTime"); + _expr_timer = ADD_TIMER(profile, "ExprTime"); + _deserialize_data_timer = ADD_TIMER(profile, "DeserializeAndMergeTime"); + _hash_table_size_counter = ADD_COUNTER(profile, "HashTableSize", TUnit::UNIT); + _hash_table_memory_usage = + ADD_COUNTER_WITH_LEVEL(profile, "MemoryUsageHashTable", TUnit::BYTES, 1); + _serialize_key_arena_memory_usage = + ADD_COUNTER_WITH_LEVEL(profile, "MemoryUsageSerializeKeyArena", TUnit::BYTES, 1); + _memory_usage_container = ADD_COUNTER(profile, "MemoryUsageContainer", TUnit::BYTES); + _memory_usage_arena = ADD_COUNTER(profile, "MemoryUsageArena", TUnit::BYTES); + _memory_used_counter = profile->get_counter("MemoryUsage"); +} + +void GroupByAggContext::init_source_profile(RuntimeProfile* profile) { + _get_results_timer = ADD_TIMER(profile, "GetResultsTime"); + _hash_table_iterate_timer = ADD_TIMER(profile, "HashTableIterateTime"); + _insert_keys_to_column_timer = ADD_TIMER(profile, "InsertKeysToColumnTime"); + _insert_values_to_column_timer = ADD_TIMER(profile, "InsertValuesToColumnTime"); + + // Register overlapping counters on source profile (same names as sink, for + // PartitionedAggLocalState::_update_profile to read from inner source profile). 
+ _source_merge_timer = ADD_TIMER(profile, "MergeTime"); + _source_deserialize_data_timer = ADD_TIMER(profile, "DeserializeAndMergeTime"); + _source_hash_table_compute_timer = ADD_TIMER(profile, "HashTableComputeTime"); + _source_hash_table_emplace_timer = ADD_TIMER(profile, "HashTableEmplaceTime"); + _source_hash_table_input_counter = ADD_COUNTER(profile, "HashTableInputCount", TUnit::UNIT); + _source_hash_table_size_counter = ADD_COUNTER(profile, "HashTableSize", TUnit::UNIT); + _source_hash_table_memory_usage = + ADD_COUNTER_WITH_LEVEL(profile, "MemoryUsageHashTable", TUnit::BYTES, 1); + _source_memory_usage_container = ADD_COUNTER(profile, "MemoryUsageContainer", TUnit::BYTES); + _source_memory_usage_arena = ADD_COUNTER(profile, "MemoryUsageArena", TUnit::BYTES); +} + +// ==================== Hash table management ==================== + +void GroupByAggContext::init_hash_method() { + auto st = doris::init_hash_method( + _hash_table_data.get(), get_data_types(_groupby_expr_ctxs), _is_first_phase); + if (!st.ok()) { + throw Exception(st.code(), st.to_string()); + } +} + +void GroupByAggContext::init_agg_data_container() { + std::visit( + Overload {[&](std::monostate& arg) -> void { + throw doris::Exception(ErrorCode::INTERNAL_ERROR, "uninited hash table"); + }, + [&](auto& agg_method) { + using HashTableType = std::decay_t; + using KeyType = typename HashTableType::Key; + _agg_data_container = std::make_unique( + sizeof(KeyType), + ((_total_agg_state_size + _agg_state_alignment - 1) / + _agg_state_alignment) * + _agg_state_alignment); + }}, + _hash_table_data->method_variant); +} + +size_t GroupByAggContext::hash_table_size() const { + return std::visit(Overload {[&](std::monostate& arg) -> size_t { return 0; }, + [&](auto& agg_method) { return agg_method.hash_table->size(); }}, + _hash_table_data->method_variant); +} + +size_t GroupByAggContext::memory_usage() const { + if (hash_table_size() == 0) { + return 0; + } + size_t usage = 0; + usage += 
_agg_arena.size(); + + if (_agg_data_container) { + usage += _agg_data_container->memory_usage(); + } + + std::visit(Overload {[&](std::monostate& arg) -> void { + throw doris::Exception(ErrorCode::INTERNAL_ERROR, + "uninited hash table"); + }, + [&](auto& agg_method) -> void { + usage += agg_method.hash_table->get_buffer_size_in_bytes(); + }}, + _hash_table_data->method_variant); + + return usage; +} + +void GroupByAggContext::update_memusage() { + std::visit( + Overload {[&](std::monostate& arg) -> void { + throw doris::Exception(ErrorCode::INTERNAL_ERROR, "uninited hash table"); + }, + [&](auto& agg_method) -> void { + auto& data = *agg_method.hash_table; + int64_t memory_usage_arena = _agg_arena.size(); + int64_t memory_usage_container = + _agg_data_container ? _agg_data_container->memory_usage() : 0; + int64_t hash_table_memory_usage = data.get_buffer_size_in_bytes(); + auto ht_size = static_cast(data.size()); + + // Update sink-side counters + if (_memory_usage_arena) { + COUNTER_SET(_memory_usage_arena, memory_usage_arena); + } + if (_memory_usage_container) { + COUNTER_SET(_memory_usage_container, memory_usage_container); + } + if (_hash_table_memory_usage) { + COUNTER_SET(_hash_table_memory_usage, hash_table_memory_usage); + } + if (_hash_table_size_counter) { + COUNTER_SET(_hash_table_size_counter, ht_size); + } + if (_serialize_key_arena_memory_usage) { + COUNTER_SET(_serialize_key_arena_memory_usage, + memory_usage_arena + memory_usage_container); + } + if (_memory_used_counter) { + COUNTER_SET(_memory_used_counter, memory_usage_arena + + memory_usage_container + + hash_table_memory_usage); + } + + // Update source-side counters (for PartitionedAgg source profile) + if (_source_memory_usage_arena) { + COUNTER_SET(_source_memory_usage_arena, memory_usage_arena); + } + if (_source_memory_usage_container) { + COUNTER_SET(_source_memory_usage_container, memory_usage_container); + } + if (_source_hash_table_memory_usage) { + 
COUNTER_SET(_source_hash_table_memory_usage, + hash_table_memory_usage); + } + if (_source_hash_table_size_counter) { + COUNTER_SET(_source_hash_table_size_counter, ht_size); + } + }}, + _hash_table_data->method_variant); +} + +size_t GroupByAggContext::get_reserve_mem_size(RuntimeState* state) const { + size_t size_to_reserve = std::visit( + [&](auto&& arg) -> size_t { + using HashTableCtxType = std::decay_t; + if constexpr (std::is_same_v) { + return 0; + } else { + return arg.hash_table->estimate_memory(state->batch_size()); + } + }, + _hash_table_data->method_variant); + + size_to_reserve += memory_usage_last_executing; + return size_to_reserve; +} + +Status GroupByAggContext::reset_hash_table() { + return std::visit( + Overload { + [&](std::monostate& arg) -> Status { + return Status::InternalError("Uninited hash table"); + }, + [&](auto& agg_method) { + auto& hash_table = *agg_method.hash_table; + using HashTableType = std::decay_t; + + agg_method.arena.clear(); + agg_method.inited_iterator = false; + + hash_table.for_each_mapped([&](auto& mapped) { + if (mapped) { + destroy_agg_state(mapped); + mapped = nullptr; + } + }); + + if (hash_table.has_null_key_data()) { + destroy_agg_state( + hash_table.template get_null_key_data()); + } + + _agg_data_container.reset(new AggregateDataContainer( + sizeof(typename HashTableType::key_type), + ((_total_agg_state_size + _agg_state_alignment - 1) / + _agg_state_alignment) * + _agg_state_alignment)); + agg_method.hash_table.reset(new HashTableType()); + return Status::OK(); + }}, + _hash_table_data->method_variant); +} + +// ==================== Agg state management ==================== + +Status GroupByAggContext::create_agg_state(AggregateDataPtr data) { + for (int i = 0; i < _agg_evaluators.size(); ++i) { + try { + _agg_evaluators[i]->create(data + _agg_state_offsets[i]); + } catch (...) 
{ + for (int j = 0; j < i; ++j) { + _agg_evaluators[j]->destroy(data + _agg_state_offsets[j]); + } + throw; + } + } + return Status::OK(); +} + +void GroupByAggContext::destroy_agg_state(AggregateDataPtr data) { + for (int i = 0; i < _agg_evaluators.size(); ++i) { + _agg_evaluators[i]->function()->destroy(data + _agg_state_offsets[i]); + } +} + +void GroupByAggContext::close() { + std::visit(Overload {[&](std::monostate& arg) -> void { + // Do nothing + }, + [&](auto& agg_method) -> void { + auto& data = *agg_method.hash_table; + data.for_each_mapped([&](auto& mapped) { + if (mapped) { + destroy_agg_state(mapped); + mapped = nullptr; + } + }); + if (data.has_null_key_data()) { + destroy_agg_state( + data.template get_null_key_data()); + } + }}, + _hash_table_data->method_variant); +} + +// ==================== Hash table write operations ==================== + +void GroupByAggContext::emplace_into_hash_table(AggregateDataPtr* places, + ColumnRawPtrs& key_columns, uint32_t num_rows, + RuntimeProfile::Counter* hash_table_compute_timer, + RuntimeProfile::Counter* hash_table_emplace_timer, + RuntimeProfile::Counter* hash_table_input_counter) { + std::visit(Overload {[&](std::monostate& arg) -> void { + throw doris::Exception(ErrorCode::INTERNAL_ERROR, + "uninited hash table"); + }, + [&](auto& agg_method) -> void { + SCOPED_TIMER(hash_table_compute_timer); + using HashMethodType = std::decay_t; + using AggState = typename HashMethodType::State; + AggState state(key_columns); + agg_method.init_serialized_keys(key_columns, num_rows); + + auto creator = [this](const auto& ctor, auto& key, auto& origin) { + HashMethodType::try_presis_key_and_origin(key, origin, _agg_arena); + auto mapped = _agg_data_container->append_data(origin); + auto st = create_agg_state(mapped); + if (!st) { + throw Exception(st.code(), st.to_string()); + } + ctor(key, mapped); + }; + + auto creator_for_null_key = [&](auto& mapped) { + mapped = _agg_arena.aligned_alloc(_total_agg_state_size, + 
_agg_state_alignment); + auto st = create_agg_state(mapped); + if (!st) { + throw Exception(st.code(), st.to_string()); + } + }; + + SCOPED_TIMER(hash_table_emplace_timer); + lazy_emplace_batch( + agg_method, state, num_rows, creator, creator_for_null_key, + [&](uint32_t row, auto& mapped) { places[row] = mapped; }); + + COUNTER_UPDATE(hash_table_input_counter, num_rows); + }}, + _hash_table_data->method_variant); +} + +void GroupByAggContext::find_in_hash_table(AggregateDataPtr* places, ColumnRawPtrs& key_columns, + uint32_t num_rows) { + std::visit(Overload {[&](std::monostate& arg) -> void { + throw doris::Exception(ErrorCode::INTERNAL_ERROR, + "uninited hash table"); + }, + [&](auto& agg_method) -> void { + using HashMethodType = std::decay_t; + using AggState = typename HashMethodType::State; + AggState state(key_columns); + agg_method.init_serialized_keys(key_columns, num_rows); + + find_batch(agg_method, state, num_rows, + [&](uint32_t row, auto& find_result) { + if (find_result.is_found()) { + places[row] = find_result.get_mapped(); + } else { + places[row] = nullptr; + } + }); + }}, + _hash_table_data->method_variant); +} + +bool GroupByAggContext::emplace_into_hash_table_limit(AggregateDataPtr* places, Block* block, + const std::vector* key_locs, + ColumnRawPtrs& key_columns, + uint32_t num_rows) { + return std::visit( + Overload {[&](std::monostate& arg) { + throw doris::Exception(ErrorCode::INTERNAL_ERROR, "uninited hash table"); + return true; + }, + [&](auto&& agg_method) -> bool { + SCOPED_TIMER(_hash_table_compute_timer); + using HashMethodType = std::decay_t; + using AggState = typename HashMethodType::State; + + bool need_filter = false; + { + SCOPED_TIMER(_hash_table_limit_compute_timer); + need_filter = do_limit_filter(num_rows, key_columns); + } + + auto& need_computes = _need_computes; + if (auto need_agg = + std::find(need_computes.begin(), need_computes.end(), 1); + need_agg != need_computes.end()) { + if (need_filter) { + 
Block::filter_block_internal(block, need_computes); + if (key_locs) { + for (int i = 0; i < key_locs->size(); ++i) { + key_columns[i] = + block->get_by_position((*key_locs)[i]) + .column.get(); + } + } + num_rows = (uint32_t)block->rows(); + } + + AggState state(key_columns); + agg_method.init_serialized_keys(key_columns, num_rows); + size_t i = 0; + + auto creator = [&](const auto& ctor, auto& key, auto& origin) { + try { + HashMethodType::try_presis_key_and_origin(key, origin, + _agg_arena); + auto mapped = _agg_data_container->append_data(origin); + auto st = create_agg_state(mapped); + if (!st) { + throw Exception(st.code(), st.to_string()); + } + ctor(key, mapped); + refresh_top_limit(i, key_columns); + } catch (...) { + ctor(key, nullptr); + throw; + } + }; + + auto creator_for_null_key = [&](auto& mapped) { + mapped = _agg_arena.aligned_alloc(_total_agg_state_size, + _agg_state_alignment); + auto st = create_agg_state(mapped); + if (!st) { + throw Exception(st.code(), st.to_string()); + } + refresh_top_limit(i, key_columns); + }; + + SCOPED_TIMER(_hash_table_emplace_timer); + lazy_emplace_batch( + agg_method, state, num_rows, creator, creator_for_null_key, + [&](uint32_t row) { i = row; }, + [&](uint32_t row, auto& mapped) { places[row] = mapped; }); + COUNTER_UPDATE(_hash_table_input_counter, num_rows); + return true; + } + return false; + }}, + _hash_table_data->method_variant); +} + +// ==================== Aggregation execution ==================== + +Status GroupByAggContext::evaluate_groupby_keys(Block* block, ColumnRawPtrs& key_columns, + std::vector* key_locs) { + SCOPED_TIMER(_expr_timer); + const size_t key_size = _groupby_expr_ctxs.size(); + for (size_t i = 0; i < key_size; ++i) { + int result_column_id = -1; + RETURN_IF_ERROR(_groupby_expr_ctxs[i]->execute(block, &result_column_id)); + block->get_by_position(result_column_id).column = + block->get_by_position(result_column_id) + .column->convert_to_full_column_if_const(); + key_columns[i] = 
block->get_by_position(result_column_id).column.get(); + key_columns[i]->assume_mutable()->replace_float_special_values(); + if (key_locs) { + (*key_locs)[i] = result_column_id; + } + } + return Status::OK(); +} + +Status GroupByAggContext::execute_with_serialized_key(Block* block) { + memory_usage_last_executing = 0; + SCOPED_PEAK_MEM(&memory_usage_last_executing); + + SCOPED_TIMER(_build_timer); + DCHECK(!_groupby_expr_ctxs.empty()); + + size_t key_size = _groupby_expr_ctxs.size(); + ColumnRawPtrs key_columns(key_size); + std::vector key_locs(key_size); + RETURN_IF_ERROR(evaluate_groupby_keys(block, key_columns, &key_locs)); + + auto rows = (uint32_t)block->rows(); + if (_places.size() < rows) { + _places.resize(rows); + } + + if (reach_limit && !do_sort_limit) { + find_in_hash_table(_places.data(), key_columns, rows); + + for (int i = 0; i < _agg_evaluators.size(); ++i) { + RETURN_IF_ERROR(_agg_evaluators[i]->execute_batch_add_selected( + block, _agg_state_offsets[i], _places.data(), _agg_arena)); + } + } else { + auto do_aggregate_evaluators = [&] { + for (int i = 0; i < _agg_evaluators.size(); ++i) { + RETURN_IF_ERROR(_agg_evaluators[i]->execute_batch_add( + block, _agg_state_offsets[i], _places.data(), _agg_arena)); + } + return Status::OK(); + }; + + if (reach_limit) { + // do_sort_limit == true here + if (emplace_into_hash_table_limit(_places.data(), block, &key_locs, key_columns, + rows)) { + RETURN_IF_ERROR(do_aggregate_evaluators()); + } + } else { + emplace_into_hash_table(_places.data(), key_columns, rows, + _hash_table_compute_timer, _hash_table_emplace_timer, + _hash_table_input_counter); + RETURN_IF_ERROR(do_aggregate_evaluators()); + + _check_limit_after_emplace(); + } + } + return Status::OK(); +} + +Status GroupByAggContext::emplace_and_forward(AggregateDataPtr* places, ColumnRawPtrs& key_columns, + uint32_t num_rows, Block* block, + bool expand_hash_table) { + emplace_into_hash_table(places, key_columns, num_rows, _hash_table_compute_timer, + 
_hash_table_emplace_timer, _hash_table_input_counter); + + for (int i = 0; i < _agg_evaluators.size(); ++i) { + RETURN_IF_ERROR(_agg_evaluators[i]->execute_batch_add( + block, _agg_state_offsets[i], places, _agg_arena, expand_hash_table)); + } + return Status::OK(); +} + +Status GroupByAggContext::merge_with_serialized_key(Block* block) { + memory_usage_last_executing = 0; + SCOPED_PEAK_MEM(&memory_usage_last_executing); + + if (reach_limit) { + return _merge_with_serialized_key_helper(block); + } else { + return _merge_with_serialized_key_helper(block); + } +} + +Status GroupByAggContext::merge_with_serialized_key_for_spill(Block* block) { + return _merge_with_serialized_key_helper(block); +} + +void GroupByAggContext::_check_limit_after_emplace() { + if (should_limit_output && !enable_spill) { + const size_t ht_size = hash_table_size(); + reach_limit = ht_size >= (do_sort_limit ? limit * config::topn_agg_limit_multiplier + : limit); + if (reach_limit && do_sort_limit) { + build_limit_heap(ht_size); + } + } +} + +void GroupByAggContext::_check_limit_after_emplace_for_merge() { + if (should_limit_output) { + const size_t ht_size = hash_table_size(); + reach_limit = ht_size >= limit; + if (do_sort_limit && reach_limit) { + build_limit_heap(ht_size); + } + } +} + +template +Status GroupByAggContext::_merge_with_serialized_key_helper(Block* block) { + auto* merge_timer = for_spill ? _source_merge_timer : _merge_timer; + auto* deser_timer = for_spill ? 
_source_deserialize_data_timer : _deserialize_data_timer; + SCOPED_TIMER(merge_timer); + + size_t key_size = _groupby_expr_ctxs.size(); + ColumnRawPtrs key_columns(key_size); + std::vector key_locs(key_size); + + if constexpr (for_spill) { + for (int i = 0; i < key_size; ++i) { + key_columns[i] = block->get_by_position(i).column.get(); + key_columns[i]->assume_mutable()->replace_float_special_values(); + key_locs[i] = i; + } + } else { + RETURN_IF_ERROR(evaluate_groupby_keys(block, key_columns, &key_locs)); + } + + size_t rows = block->rows(); + if (_places.size() < rows) { + _places.resize(rows); + } + + if (limit && !do_sort_limit) { + find_in_hash_table(_places.data(), key_columns, (uint32_t)rows); + + for (int i = 0; i < _agg_evaluators.size(); ++i) { + if (_agg_evaluators[i]->is_merge()) { + int col_id = get_slot_column_id(_agg_evaluators[i]); + auto column = block->get_by_position(col_id).column; + + size_t buffer_size = _agg_evaluators[i]->function()->size_of_data() * rows; + if (_deserialize_buffer.size() < buffer_size) { + _deserialize_buffer.resize(buffer_size); + } + + { + SCOPED_TIMER(deser_timer); + _agg_evaluators[i]->function()->deserialize_and_merge_vec_selected( + _places.data(), _agg_state_offsets[i], _deserialize_buffer.data(), + column.get(), _agg_arena, rows); + } + } else { + RETURN_IF_ERROR(_agg_evaluators[i]->execute_batch_add_selected( + block, _agg_state_offsets[i], _places.data(), _agg_arena)); + } + } + } else { + bool need_do_agg = true; + if (limit) { + need_do_agg = emplace_into_hash_table_limit(_places.data(), block, &key_locs, + key_columns, (uint32_t)rows); + rows = block->rows(); + } else { + if constexpr (for_spill) { + emplace_into_hash_table(_places.data(), key_columns, (uint32_t)rows, + _source_hash_table_compute_timer, + _source_hash_table_emplace_timer, + _source_hash_table_input_counter); + } else { + emplace_into_hash_table(_places.data(), key_columns, (uint32_t)rows, + _hash_table_compute_timer, _hash_table_emplace_timer, 
+                                        _hash_table_input_counter);
+            }
+        }
+
+        if (need_do_agg) {
+            for (int i = 0; i < _agg_evaluators.size(); ++i) {
+                if (_agg_evaluators[i]->is_merge() || for_spill) {
+                    size_t col_id = 0;
+                    // In the for_spill case aggregate column i sits right after the key columns.
+                    if constexpr (for_spill) {
+                        col_id = _groupby_expr_ctxs.size() + i;
+                    } else {
+                        col_id = get_slot_column_id(_agg_evaluators[i]);
+                    }
+                    auto column = block->get_by_position(col_id).column;
+
+                    // Grow the scratch buffer to size_of_data() bytes per input row.
+                    size_t buffer_size = _agg_evaluators[i]->function()->size_of_data() * rows;
+                    if (_deserialize_buffer.size() < buffer_size) {
+                        _deserialize_buffer.resize(buffer_size);
+                    }
+
+                    {
+                        SCOPED_TIMER(deser_timer);
+                        _agg_evaluators[i]->function()->deserialize_and_merge_vec(
+                                _places.data(), _agg_state_offsets[i], _deserialize_buffer.data(),
+                                column.get(), _agg_arena, rows);
+                    }
+                } else {
+                    RETURN_IF_ERROR(_agg_evaluators[i]->execute_batch_add(
+                            block, _agg_state_offsets[i], _places.data(), _agg_arena));
+                }
+            }
+        }
+
+        if (!limit && should_limit_output) {
+            _check_limit_after_emplace_for_merge();
+        }
+    }
+
+    return Status::OK();
+}
+
+// Explicit template instantiation
+// NOTE(review): the three lines below are textually identical; the template argument
+// lists appear to have been lost in this copy of the patch. Confirm against the original.
+template Status GroupByAggContext::_merge_with_serialized_key_helper(Block* block);
+template Status GroupByAggContext::_merge_with_serialized_key_helper(Block* block);
+template Status GroupByAggContext::_merge_with_serialized_key_helper(Block* block);
+
+// ==================== Result output ====================
+
+/// Emit the next batch of group keys plus serialized aggregate states into `block`.
+///
+/// Walks `_agg_data_container->iterator` for at most `state->batch_size()` entries,
+/// writes the keys into the key columns and serializes each evaluator's state into
+/// its value column. Once the iterator reaches the end, the null-key entry (if the
+/// hash table has one) is appended and `*eos` is set to true.
+///
+/// @param state  supplies the batch size.
+/// @param block  output block; its columns are reused only when `make_nullable_keys`
+///               is empty and `block->mem_reuse()` is true, otherwise the block is
+///               rebuilt via build_serialized_output_block.
+/// @param eos    set to true when everything has been emitted; never reset to false here.
+/// @return Status::OK(); an uninitialised hash table variant throws doris::Exception.
+Status GroupByAggContext::get_serialized_results(RuntimeState* state, Block* block, bool* eos) {
+    SCOPED_TIMER(_get_results_timer);
+    size_t key_size = _groupby_expr_ctxs.size();
+    size_t agg_size = _agg_evaluators.size();
+    MutableColumns value_columns(agg_size);
+    DataTypes value_data_types(agg_size);
+
+    bool mem_reuse = make_nullable_keys.empty() && block->mem_reuse();
+
+    auto key_columns = agg_context_utils::take_or_create_columns(
+            block, mem_reuse, 0, key_size,
+            [&](size_t i) { return _groupby_expr_ctxs[i]->root()->data_type()->create_column(); });
+
+    std::visit(
+            Overload {
+                    [&](std::monostate& arg) -> void {
+                        throw doris::Exception(ErrorCode::INTERNAL_ERROR, "uninited hash table");
+                    },
+                    [&](auto& agg_method) -> void {
+                        agg_method.init_iterator();
+                        auto& data = *agg_method.hash_table;
+                        const auto size = std::min(data.size(), size_t(state->batch_size()));
+                        // NOTE(review): template arguments look stripped on the next two
+                        // lines (and on get_key() below) in this copy of the patch.
+                        using KeyType = std::decay_t::Key;
+                        std::vector keys(size);
+
+                        // One extra slot: the null-key state is appended below without
+                        // re-checking the batch size.
+                        if (_values.size() < size + 1) {
+                            _values.resize(size + 1);
+                        }
+
+                        uint32_t num_rows = 0;
+                        _agg_data_container->init_once();
+                        auto& iter = _agg_data_container->iterator;
+
+                        {
+                            SCOPED_TIMER(_hash_table_iterate_timer);
+                            while (iter != _agg_data_container->end() &&
+                                   num_rows < state->batch_size()) {
+                                keys[num_rows] = iter.template get_key();
+                                _values[num_rows] = iter.get_aggregate_data();
+                                ++iter;
+                                ++num_rows;
+                            }
+                        }
+
+                        {
+                            SCOPED_TIMER(_insert_keys_to_column_timer);
+                            agg_method.insert_keys_into_columns(keys, key_columns, num_rows);
+                        }
+
+                        if (iter == _agg_data_container->end()) {
+                            if (agg_method.hash_table->has_null_key_data()) {
+                                DCHECK(key_columns.size() == 1);
+                                DCHECK(key_columns[0]->is_nullable());
+                                // NOTE(review): this inner condition repeats the enclosing
+                                // one, so it is always true here. get_finalized_results
+                                // checks the batch size at this point instead; confirm
+                                // which is intended.
+                                if (agg_method.hash_table->has_null_key_data()) {
+                                    key_columns[0]->insert_data(nullptr, 0);
+                                    _values[num_rows] =
+                                            agg_method.hash_table->template get_null_key_data<
+                                                    AggregateDataPtr>();
+                                    ++num_rows;
+                                    *eos = true;
+                                }
+                            } else {
+                                *eos = true;
+                            }
+                        }
+
+                        {
+                            SCOPED_TIMER(_insert_values_to_column_timer);
+                            for (size_t i = 0; i < _agg_evaluators.size(); ++i) {
+                                value_data_types[i] =
+                                        _agg_evaluators[i]->function()->get_serialized_type();
+                                if (mem_reuse) {
+                                    value_columns[i] =
+                                            std::move(*block->get_by_position(i + key_size).column)
+                                                    .mutate();
+                                } else {
+                                    value_columns[i] = _agg_evaluators[i]
+                                                               ->function()
+                                                               ->create_serialize_column();
+                                }
+                                _agg_evaluators[i]->function()->serialize_to_column(
+                                        _values, _agg_state_offsets[i], value_columns[i],
+                                        num_rows);
+                            }
+                        }
+                    }},
+            _hash_table_data->method_variant);
+
+    if (!mem_reuse) {
+        agg_context_utils::build_serialized_output_block(block, key_columns, _groupby_expr_ctxs,
+                                                         value_columns, value_data_types);
+    }
+
+    return Status::OK();
+}
+
+/// Emit the next batch of group keys plus finalized aggregate results into `block`.
+///
+/// Same iteration scheme as get_serialized_results, but values are produced with
+/// insert_result_info_vec / insert_result_info and column types come from
+/// `columns_with_schema` (keys first, then one entry per value column).
+/// The null-key row is appended only while the first key column holds fewer than
+/// `state->batch_size()` rows; otherwise `*eos` is left untouched by this call.
+///
+/// @param state                supplies the batch size.
+/// @param block                output block (reused under the same rule as above).
+/// @param eos                  set to true when everything has been emitted.
+/// @param columns_with_schema  output schema used to create columns and assemble the block.
+/// @return Status::OK(); an uninitialised hash table variant throws doris::Exception.
+Status GroupByAggContext::get_finalized_results(RuntimeState* state, Block* block, bool* eos,
+                                                const ColumnsWithTypeAndName& columns_with_schema) {
+    bool mem_reuse = make_nullable_keys.empty() && block->mem_reuse();
+
+    size_t key_size = _groupby_expr_ctxs.size();
+
+    auto key_columns = agg_context_utils::take_or_create_columns(
+            block, mem_reuse, 0, key_size,
+            [&](size_t i) { return columns_with_schema[i].type->create_column(); });
+    auto value_columns = agg_context_utils::take_or_create_columns(
+            block, mem_reuse, key_size, columns_with_schema.size() - key_size,
+            [&](size_t i) { return columns_with_schema[key_size + i].type->create_column(); });
+
+    SCOPED_TIMER(_get_results_timer);
+    std::visit(
+            Overload {
+                    [&](std::monostate& arg) -> void {
+                        throw doris::Exception(ErrorCode::INTERNAL_ERROR, "uninited hash table");
+                    },
+                    [&](auto& agg_method) -> void {
+                        auto& data = *agg_method.hash_table;
+                        agg_method.init_iterator();
+                        const auto size = std::min(data.size(), size_t(state->batch_size()));
+                        using KeyType = std::decay_t::Key;
+                        std::vector keys(size);
+
+                        // No extra slot is needed here: the null-key result is written
+                        // straight into the value columns, not through _values.
+                        if (_values.size() < size) {
+                            _values.resize(size);
+                        }
+
+                        uint32_t num_rows = 0;
+                        _agg_data_container->init_once();
+                        auto& iter = _agg_data_container->iterator;
+
+                        {
+                            SCOPED_TIMER(_hash_table_iterate_timer);
+                            while (iter != _agg_data_container->end() &&
+                                   num_rows < state->batch_size()) {
+                                keys[num_rows] = iter.template get_key();
+                                _values[num_rows] = iter.get_aggregate_data();
+                                ++iter;
+                                ++num_rows;
+                            }
+                        }
+
+                        {
+                            SCOPED_TIMER(_insert_keys_to_column_timer);
+                            agg_method.insert_keys_into_columns(keys, key_columns, num_rows);
+                        }
+
+                        for (size_t i = 0; i < _agg_evaluators.size(); ++i) {
+                            _agg_evaluators[i]->insert_result_info_vec(
+                                    _values, _agg_state_offsets[i], value_columns[i].get(),
+                                    num_rows);
+                        }
+
+                        if (iter == _agg_data_container->end()) {
+                            if (agg_method.hash_table->has_null_key_data()) {
+                                DCHECK(key_columns.size() == 1);
+                                DCHECK(key_columns[0]->is_nullable());
+                                // Only append the null-key row when the batch has room;
+                                // presumably a later call emits it otherwise (eos stays unset).
+                                if (key_columns[0]->size() < state->batch_size()) {
+                                    key_columns[0]->insert_data(nullptr, 0);
+                                    auto mapped =
+                                            agg_method.hash_table->template get_null_key_data<
+                                                    AggregateDataPtr>();
+                                    for (size_t i = 0; i < _agg_evaluators.size(); ++i) {
+                                        _agg_evaluators[i]->insert_result_info(
+                                                mapped + _agg_state_offsets[i],
+                                                value_columns[i].get());
+                                    }
+                                    *eos = true;
+                                }
+                            } else {
+                                *eos = true;
+                            }
+                        }
+                    }},
+            _hash_table_data->method_variant);
+
+    if (!mem_reuse) {
+        agg_context_utils::assemble_finalized_output(block, columns_with_schema, key_columns,
+                                                     value_columns, key_size);
+    }
+
+    return Status::OK();
+}
+
+// ==================== Sort limit ====================
+
+/// Copy every key currently stored in the aggregate data container into freshly
+/// created key columns (one per group-by expression). If the hash table has a
+/// null key, a null row is appended to column 0 after the container keys.
+/// Throws doris::Exception when the hash table variant is uninitialised.
+MutableColumns GroupByAggContext::_get_keys_hash_table() {
+    return std::visit(
+            Overload {[&](std::monostate& arg) {
+                          throw doris::Exception(ErrorCode::INTERNAL_ERROR, "uninited hash table");
+                          return MutableColumns();
+                      },
+                      [&](auto&& agg_method) -> MutableColumns {
+                          MutableColumns key_columns;
+                          for (int i = 0; i < _groupby_expr_ctxs.size(); ++i) {
+                              key_columns.emplace_back(
+                                      _groupby_expr_ctxs[i]->root()->data_type()->create_column());
+                          }
+                          auto& data = *agg_method.hash_table;
+                          bool has_null_key = data.has_null_key_data();
+                          // assumes data.size() counts the null-key entry; TODO confirm
+                          const auto size = data.size() - has_null_key;
+                          using KeyType = std::decay_t::Key;
+                          std::vector keys(size);
+
+                          uint32_t num_rows = 0;
+                          auto iter = _agg_data_container->begin();
+                          {
+                              while (iter != _agg_data_container->end()) {
+                                  keys[num_rows] = iter.get_key();
+                                  ++iter;
+                                  ++num_rows;
+                              }
+                          }
+                          agg_method.insert_keys_into_columns(keys, key_columns, num_rows);
+                          if (has_null_key) {
+                              key_columns[0]->insert_data(nullptr, 0);
+                          }
+                          return key_columns;
+                      }},
+            _hash_table_data->method_variant);
+}
+
+/// Snapshot the current keys into `_limit_columns`, push rows
+/// [0, hash_table_size_val) onto `_limit_heap`, pop until no more than `limit`
+/// entries are counted, then cache the row id at the top of the heap in
+/// `_limit_columns_min`.
+///
+/// NOTE(review): `_limit_heap.top()` is read unconditionally; if the heap can be
+/// empty here (hash_table_size_val == 0, or limit == 0) that is undefined behaviour.
+/// Confirm that callers guarantee a non-empty heap.
+/// NOTE(review): the loop compares a size_t with `limit`, which the header in this
+/// patch declares as int64_t with a default of -1. Confirm limit is non-negative here.
+void GroupByAggContext::build_limit_heap(size_t hash_table_size_val) {
+    _limit_columns = _get_keys_hash_table();
+    for (size_t i = 0; i < hash_table_size_val; ++i) {
+        _limit_heap.emplace(i, _limit_columns, order_directions, null_directions);
+    }
+    while (hash_table_size_val > limit) {
+        _limit_heap.pop();
+        hash_table_size_val--;
+    }
+    _limit_columns_min = _limit_heap.top()._row_id;
+}
+
+/// Compare `num_rows` incoming key rows against row `_limit_columns_min` of
+/// `_limit_columns`, filling `_cmp_res` and `_need_computes`.
+///
+/// After the per-column compare_internal calls, `_need_computes[i]` is overwritten
+/// with `(_need_computes[i] == _cmp_res[i])`.
+///
+/// @return true if at least one `_need_computes` entry is 0 afterwards; false when
+///         `num_rows` is 0 or no entry is 0.
+bool GroupByAggContext::do_limit_filter(size_t num_rows, const ColumnRawPtrs& key_columns) {
+    if (num_rows) {
+        _cmp_res.resize(num_rows);
+        _need_computes.resize(num_rows);
+        // NOTE(review): memset takes a byte count but is given element counts; this
+        // clears the whole buffer only if both element types are one byte wide. The
+        // element types are not visible in this copy of the patch; confirm.
+        memset(_need_computes.data(), 0, _need_computes.size());
+        memset(_cmp_res.data(), 0, _cmp_res.size());
+
+        const auto key_size = null_directions.size();
+        for (int i = 0; i < key_size; i++) {
+            key_columns[i]->compare_internal(_limit_columns_min, *_limit_columns[i],
+                                             null_directions[i], order_directions[i], _cmp_res,
+                                             _need_computes.data());
+        }
+
+        auto set_computes_arr = [](auto* __restrict res, auto* __restrict computes, size_t rows) {
+            for (size_t i = 0; i < rows; ++i) {
+                computes[i] = computes[i] == res[i];
+            }
+        };
+        set_computes_arr(_cmp_res.data(), _need_computes.data(), num_rows);
+
+        return std::find(_need_computes.begin(), _need_computes.end(), 0) != _need_computes.end();
+    }
+
+    return false;
+}
+
+/// Append row `row_id` of `key_columns` to `_limit_columns`, push the new row onto
+/// the heap, pop the current top, and refresh `_limit_columns_min` from the new top.
+void GroupByAggContext::refresh_top_limit(size_t row_id, const ColumnRawPtrs& key_columns) {
+    for (int j = 0; j < key_columns.size(); ++j) {
+        _limit_columns[j]->insert_from(*key_columns[j], row_id);
+    }
+    _limit_heap.emplace(_limit_columns[0]->size() - 1, _limit_columns, order_directions,
+                        null_directions);
+
+    _limit_heap.pop();
+    _limit_columns_min = _limit_heap.top()._row_id;
+}
+
+/// Find the first row i in [0, rows) with `_cmp_res[i] == 1 && _need_computes[i]`,
+/// apply the same push/pop/refresh sequence as refresh_top_limit for that row, and
+/// stop. Rows after the first match are not examined. Relies on `_cmp_res` and
+/// `_need_computes` as left by a preceding do_limit_filter call.
+void GroupByAggContext::add_limit_heap_top(ColumnRawPtrs& key_columns, size_t rows) {
+    for (size_t i = 0; i < rows; ++i) {
+        if (_cmp_res[i] == 1 && _need_computes[i]) {
+            for (size_t j = 0; j < key_columns.size(); ++j) {
+                _limit_columns[j]->insert_from(*key_columns[j], i);
+            }
+            _limit_heap.emplace(_limit_columns[0]->size() - 1, _limit_columns, order_directions,
+                                null_directions);
+            _limit_heap.pop();
+            _limit_columns_min = _limit_heap.top()._row_id;
+            break;
+        }
+    }
+}
+
+// 
==================== Static utilities ====================
+
+/// Wrap the column and the data type at each position listed in
+/// `make_nullable_keys` with make_nullable(). Does nothing when the block has no rows.
+/// NOTE(review): the element type of the vector parameter appears to have been
+/// stripped in this copy of the patch.
+void GroupByAggContext::make_nullable_output_key(Block* block,
+                                                 const std::vector& make_nullable_keys) {
+    if (block->rows() != 0) {
+        for (auto cid : make_nullable_keys) {
+            block->get_by_position(cid).column = make_nullable(block->get_by_position(cid).column);
+            block->get_by_position(cid).type = make_nullable(block->get_by_position(cid).type);
+        }
+    }
+}
+
+/// Thin forwarder to agg_context_utils::get_slot_column_id.
+int GroupByAggContext::get_slot_column_id(const AggFnEvaluator* evaluator) {
+    return agg_context_utils::get_slot_column_id(evaluator);
+}
+
+} // namespace doris
diff --git a/be/src/exec/common/groupby_agg_context.h b/be/src/exec/common/groupby_agg_context.h
new file mode 100644
index 00000000000000..aeb36750320b59
--- /dev/null
+++ b/be/src/exec/common/groupby_agg_context.h
@@ -0,0 +1,317 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#pragma once
+
+// NOTE(review): the header name after this #include is missing, and angle-bracket
+// contents (template argument and parameter lists) appear stripped throughout this
+// copy of the patch. Confirm against the original file before relying on exact types.
+#include
+
+#include "common/status.h"
+#include "core/arena.h"
+#include "core/block/block.h"
+#include "exec/common/agg_utils.h"
+#include "runtime/runtime_profile.h"
+
+namespace doris {
+
+class AggFnEvaluator;
+class RuntimeState;
+class VExprContext;
+using VExprContextSPtr = std::shared_ptr;
+using VExprContextSPtrs = std::vector;
+
+/// GroupByAggContext encapsulates all hash-table-based aggregation logic for GROUP BY queries.
+/// It is shared between AggSinkLocalState (write path) and AggLocalState (read path) in
+/// 2-phase aggregation, or owned locally by StreamingAggLocalState in 1-phase streaming agg.
+///
+/// InlineCountAggContext (subclass) overrides virtual methods to implement the
+/// inline-count optimization (storing UInt64 count directly in the hash table mapped slot
+/// instead of a full aggregate state).
+class GroupByAggContext {
+public:
+    GroupByAggContext(std::vector agg_evaluators,
+                      VExprContextSPtrs groupby_expr_ctxs, Sizes agg_state_offsets,
+                      size_t total_agg_state_size, size_t agg_state_alignment,
+                      bool is_first_phase);
+
+    virtual ~GroupByAggContext();
+
+    // ==================== Aggregation execution (Sink side) ====================
+
+    /// Update mode: evaluate groupby exprs → emplace → execute_batch_add
+    virtual Status execute_with_serialized_key(Block* block);
+
+    /// Emplace + execute_batch_add with pre-evaluated key columns.
+    /// InlineCountAggContext overrides to only emplace (count++ is done internally).
+    /// Used by StreamingAgg which evaluates key expressions separately.
+    virtual Status emplace_and_forward(AggregateDataPtr* places, ColumnRawPtrs& key_columns,
+                                       uint32_t num_rows, Block* block, bool expand_hash_table);
+
+    /// Merge mode: evaluate groupby exprs → emplace → deserialize_and_merge
+    virtual Status merge_with_serialized_key(Block* block);
+
+    /// Merge for spill restore (keys already materialized as first N columns of block)
+    Status merge_with_serialized_key_for_spill(Block* block);
+
+    // ==================== Result output (Source side) ====================
+
+    /// Serialize mode output (for non-finalize path and StreamingAgg)
+    virtual Status get_serialized_results(RuntimeState* state, Block* block, bool* eos);
+
+    /// Finalize mode output (for AggSource finalize path)
+    virtual Status get_finalized_results(RuntimeState* state, Block* block, bool* eos,
+                                         const ColumnsWithTypeAndName& columns_with_schema);
+
+    // ==================== Agg state management ====================
+
+    virtual Status create_agg_state(AggregateDataPtr data);
+    virtual void close();
+
+    // ==================== Utilities ====================
+
+    size_t hash_table_size() const;
+    size_t memory_usage() const;
+    void update_memusage();
+    void init_hash_method();
+    /// Initialize the AggregateDataContainer after hash method is set up.
+    /// Must be called after init_hash_method().
+    virtual void init_agg_data_container();
+    virtual Status reset_hash_table();
+
+    /// Sink operator calls this to register sink-side profile counters.
+    void init_sink_profile(RuntimeProfile* profile);
+    /// Source operator calls this to register source-side profile counters.
+    void init_source_profile(RuntimeProfile* profile);
+
+    /// Evaluate groupby expressions on block, filling key_columns and optionally key_locs.
+    /// Handles convert_to_full_column_if_const and replace_float_special_values.
+    Status evaluate_groupby_keys(Block* block, ColumnRawPtrs& key_columns,
+                                 std::vector* key_locs = nullptr);
+
+    // ==================== Sort limit ====================
+
+    /// Snapshot the hash table keys and build the top-N heap from the first
+    /// `hash_table_size` rows (see the definition for the exact trimming rule).
+    void build_limit_heap(size_t hash_table_size);
+    /// Compare incoming key rows with the cached heap-top row; returns true if any
+    /// need_computes entry is 0 after the comparison.
+    bool do_limit_filter(size_t num_rows, const ColumnRawPtrs& key_columns);
+    /// Push row `row_id` of key_columns onto the heap, pop the top, refresh the cached top row.
+    void refresh_top_limit(size_t row_id, const ColumnRawPtrs& key_columns);
+    /// Update limit heap with new top-N candidates from passthrough path.
+    /// Finds the first row where cmp_res==1 && need_computes[i], inserts into heap, then breaks.
+    void add_limit_heap_top(ColumnRawPtrs& key_columns, size_t rows);
+
+    /// Emplace with sort-limit filtering. Returns true if aggregation should proceed.
+    /// When key_locs is provided, re-fetches key_columns from block after filtering.
+    bool emplace_into_hash_table_limit(AggregateDataPtr* places, Block* block,
+                                       const std::vector* key_locs,
+                                       ColumnRawPtrs& key_columns, uint32_t num_rows);
+
+    // ==================== Data accessors ====================
+
+    AggregatedDataVariants* hash_table_data() { return _hash_table_data.get(); }
+    Arena& agg_arena() { return _agg_arena; }
+    AggregateDataContainer* agg_data_container() { return _agg_data_container.get(); }
+    std::vector& agg_evaluators() { return _agg_evaluators; }
+    const VExprContextSPtrs& groupby_expr_ctxs() const { return _groupby_expr_ctxs; }
+    const Sizes& agg_state_offsets() const { return _agg_state_offsets; }
+    size_t total_agg_state_size() const { return _total_agg_state_size; }
+    size_t agg_state_alignment() const { return _agg_state_alignment; }
+    PaddedPODArray& need_computes() { return _need_computes; }
+
+    // Sort limit public state
+    int64_t limit = -1;
+    bool do_sort_limit = false;
+    bool reach_limit = false;
+    std::vector order_directions;
+    std::vector null_directions;
+
+    // Limit check configuration (set by operator during open)
+    bool should_limit_output = false;
+    bool enable_spill = false;
+
+    // Key columns that need nullable wrapping in output (left/full join).
+    // When non-empty, mem_reuse must be disabled in get_*_results to avoid
+    // column type mismatch after make_nullable_output_key transforms the block.
+    std::vector make_nullable_keys;
+
+    // Memory tracking for reserve estimation
+    int64_t memory_usage_last_executing = 0;
+
+    // Sink-side profile counters (public for operator-level SCOPED_TIMER access)
+    RuntimeProfile::Counter* build_timer() const { return _build_timer; }
+    RuntimeProfile::Counter* merge_timer() const { return _merge_timer; }
+    RuntimeProfile::Counter* expr_timer() const { return _expr_timer; }
+    RuntimeProfile::Counter* deserialize_data_timer() const { return _deserialize_data_timer; }
+    RuntimeProfile::Counter* hash_table_compute_timer() const { return _hash_table_compute_timer; }
+    RuntimeProfile::Counter* hash_table_emplace_timer() const { return _hash_table_emplace_timer; }
+    RuntimeProfile::Counter* hash_table_input_counter() const { return _hash_table_input_counter; }
+
+    // Source-side profile counters
+    RuntimeProfile::Counter* get_results_timer() const { return _get_results_timer; }
+    RuntimeProfile::Counter* hash_table_iterate_timer() const { return _hash_table_iterate_timer; }
+    RuntimeProfile::Counter* insert_keys_to_column_timer() const {
+        return _insert_keys_to_column_timer;
+    }
+    RuntimeProfile::Counter* insert_values_to_column_timer() const {
+        return _insert_values_to_column_timer;
+    }
+    RuntimeProfile::Counter* hash_table_limit_compute_timer() const {
+        return _hash_table_limit_compute_timer;
+    }
+
+    // For spill: estimate memory needed
+    size_t get_reserve_mem_size(RuntimeState* state) const;
+
+protected:
+    // ==================== Internal hash table operations ====================
+
+    /// Insert keys into the hash table, fill places array. New keys get agg state created.
+    /// Counter parameters allow callers to direct timing to sink or source profile counters.
+    virtual void emplace_into_hash_table(AggregateDataPtr* places, ColumnRawPtrs& key_columns,
+                                         uint32_t num_rows,
+                                         RuntimeProfile::Counter* hash_table_compute_timer,
+                                         RuntimeProfile::Counter* hash_table_emplace_timer,
+                                         RuntimeProfile::Counter* hash_table_input_counter);
+
+    /// Find existing keys in hash table (used when reach_limit && !do_sort_limit).
+    void find_in_hash_table(AggregateDataPtr* places, ColumnRawPtrs& key_columns,
+                            uint32_t num_rows);
+
+    virtual void destroy_agg_state(AggregateDataPtr data);
+
+    /// Convert columns at specified positions to nullable.
+    static void make_nullable_output_key(Block* block,
+                                         const std::vector& make_nullable_keys);
+
+    /// Get the column id from an evaluator's input expression (used in merge path).
+    /// Only valid for 1st phase evaluators with a single SlotRef input.
+    static int get_slot_column_id(const AggFnEvaluator* evaluator);
+
+    // Core hash table data
+    AggregatedDataVariantsUPtr _hash_table_data;
+    Arena _agg_arena;
+    std::unique_ptr _agg_data_container;
+
+    // Aggregation metadata
+    std::vector _agg_evaluators;
+    VExprContextSPtrs _groupby_expr_ctxs;
+    Sizes _agg_state_offsets;
+    size_t _total_agg_state_size;
+    size_t _agg_state_alignment;
+    bool _is_first_phase;
+
+    // Working buffers
+    PODArray _places;
+    std::vector _deserialize_buffer;
+    std::vector _values;
+
+    // Sort limit state
+    MutableColumns _limit_columns;
+    // Row id (into _limit_columns) of the entry currently on top of _limit_heap.
+    int _limit_columns_min = -1;
+    PaddedPODArray _need_computes;
+    std::vector _cmp_res;
+
+    /// Heap entry that identifies one row of the shared limit columns.
+    /// It stores references to the column set and direction vectors it was built
+    /// with, so those objects must outlive every cursor in the heap. Because the
+    /// reference members cannot be rebound, both assignment operators copy only
+    /// `_row_id`.
+    struct HeapLimitCursor {
+        HeapLimitCursor(int row_id, MutableColumns& limit_columns,
+                        std::vector& order_directions, std::vector& null_directions)
+                : _row_id(row_id),
+                  _limit_columns(limit_columns),
+                  _order_directions(order_directions),
+                  _null_directions(null_directions) {}
+
+        HeapLimitCursor(const HeapLimitCursor& other) = default;
+
+        HeapLimitCursor(HeapLimitCursor&& other) noexcept
+                : _row_id(other._row_id),
+                  _limit_columns(other._limit_columns),
+                  _order_directions(other._order_directions),
+                  _null_directions(other._null_directions) {}
+
+        HeapLimitCursor& operator=(const HeapLimitCursor& other) noexcept {
+            _row_id = other._row_id;
+            return *this;
+        }
+
+        HeapLimitCursor& operator=(HeapLimitCursor&& other) noexcept {
+            _row_id = other._row_id;
+            return *this;
+        }
+
+        /// Lexicographic comparison over the limit columns: the first column whose
+        /// compare_at result, multiplied by its order direction, is non-zero decides.
+        /// Rows equal on every column are not less than each other.
+        bool operator<(const HeapLimitCursor& rhs) const {
+            for (int i = 0; i < _limit_columns.size(); ++i) {
+                const auto& col = _limit_columns[i];
+                auto res = col->compare_at(_row_id, rhs._row_id, *col, _null_directions[i]) *
+                           _order_directions[i];
+                if (res < 0) {
+                    return true;
+                } else if (res > 0) {
+                    return false;
+                }
+            }
+            return false;
+        }
+
+        int _row_id;
+        MutableColumns& _limit_columns;
+        std::vector& _order_directions;
+        std::vector& _null_directions;
+    };
+
+    std::priority_queue _limit_heap;
+    MutableColumns _get_keys_hash_table();
+
+    template
+    Status _merge_with_serialized_key_helper(Block* block);
+
+    /// Check and update reach_limit after emplace (execute path)
+    void _check_limit_after_emplace();
+    /// Check and update reach_limit after emplace (merge path, simpler: no topn multiplier)
+    void _check_limit_after_emplace_for_merge();
+
+    // ---- Sink-side profile counters (created by init_sink_profile) ----
+    RuntimeProfile::Counter* _hash_table_compute_timer = nullptr;
+    RuntimeProfile::Counter* _hash_table_emplace_timer = nullptr;
+    RuntimeProfile::Counter* _hash_table_input_counter = nullptr;
+    RuntimeProfile::Counter* _hash_table_limit_compute_timer = nullptr;
+    RuntimeProfile::Counter* _build_timer = nullptr;
+    RuntimeProfile::Counter* _merge_timer = nullptr;
+    RuntimeProfile::Counter* _expr_timer = nullptr;
+    RuntimeProfile::Counter* _deserialize_data_timer = nullptr;
+    RuntimeProfile::Counter* _hash_table_size_counter = nullptr;
+    RuntimeProfile::Counter* _hash_table_memory_usage = nullptr;
+    RuntimeProfile::Counter* _serialize_key_arena_memory_usage = nullptr;
+    RuntimeProfile::Counter* _memory_usage_container = nullptr;
+    RuntimeProfile::Counter* _memory_usage_arena = nullptr;
+    RuntimeProfile::Counter* _memory_used_counter = nullptr;
+
+    // ---- Source-side profile counters (created by init_source_profile) ----
+    RuntimeProfile::Counter* _get_results_timer = nullptr;
+    RuntimeProfile::Counter* _hash_table_iterate_timer = nullptr;
+    RuntimeProfile::Counter* _insert_keys_to_column_timer = nullptr;
+    RuntimeProfile::Counter* _insert_values_to_column_timer = nullptr;
+
+    // Source-side counters for overlapping metrics (same names as sink, different profile).
+    // Used during spill recovery merge path (for_spill=true) so that
+    // PartitionedAggLocalState::_update_profile can read them from the inner source profile.
+    RuntimeProfile::Counter* _source_merge_timer = nullptr;
+    RuntimeProfile::Counter* _source_deserialize_data_timer = nullptr;
+    RuntimeProfile::Counter* _source_hash_table_compute_timer = nullptr;
+    RuntimeProfile::Counter* _source_hash_table_emplace_timer = nullptr;
+    RuntimeProfile::Counter* _source_hash_table_input_counter = nullptr;
+    RuntimeProfile::Counter* _source_hash_table_size_counter = nullptr;
+    RuntimeProfile::Counter* _source_hash_table_memory_usage = nullptr;
+    RuntimeProfile::Counter* _source_memory_usage_container = nullptr;
+    RuntimeProfile::Counter* _source_memory_usage_arena = nullptr;
+};
+
+} // namespace doris
diff --git a/be/src/exec/common/inline_count_agg_context.cpp b/be/src/exec/common/inline_count_agg_context.cpp
new file mode 100644
index 00000000000000..6a74f17010cd20
--- /dev/null
+++ b/be/src/exec/common/inline_count_agg_context.cpp
@@ -0,0 +1,372 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License.
You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +#include "exec/common/inline_count_agg_context.h" + +#include "common/cast_set.h" +#include "common/exception.h" +#include "core/column/column_fixed_length_object.h" +#include "exec/common/agg_context_utils.h" +#include "exec/common/columns_hashing.h" +#include "exec/common/hash_table/hash_map_context.h" +#include "exec/common/template_helpers.hpp" +#include "exprs/aggregate/aggregate_function_count.h" +#include "exprs/vectorized_agg_fn.h" +#include "exprs/vexpr_context.h" +#include "runtime/runtime_state.h" + +namespace doris { + +// ==================== Hash table write ==================== + +void InlineCountAggContext::emplace_into_hash_table(AggregateDataPtr* /*places*/, + ColumnRawPtrs& key_columns, + uint32_t num_rows, + RuntimeProfile::Counter* hash_table_compute_timer, + RuntimeProfile::Counter* hash_table_emplace_timer, + RuntimeProfile::Counter* hash_table_input_counter) { + std::visit(Overload {[&](std::monostate& arg) -> void { + throw doris::Exception(ErrorCode::INTERNAL_ERROR, + "uninited hash table"); + }, + [&](auto& agg_method) -> void { + SCOPED_TIMER(hash_table_compute_timer); + using HashMethodType = std::decay_t; + using AggState = typename HashMethodType::State; + AggState state(key_columns); + agg_method.init_serialized_keys(key_columns, num_rows); + + auto creator = [&](const auto& ctor, auto& key, auto& origin) { + HashMethodType::try_presis_key_and_origin(key, origin, _agg_arena); + AggregateDataPtr mapped = nullptr; + ctor(key, mapped); + }; + + auto creator_for_null_key = [&](auto& mapped) { mapped = nullptr; }; + + 
SCOPED_TIMER(hash_table_emplace_timer); + lazy_emplace_batch(agg_method, state, num_rows, creator, + creator_for_null_key, [&](uint32_t, auto& mapped) { + ++reinterpret_cast(mapped); + }); + + COUNTER_UPDATE(hash_table_input_counter, num_rows); + }}, + _hash_table_data->method_variant); +} + +// ==================== Aggregation execution ==================== + +Status InlineCountAggContext::execute_with_serialized_key(Block* block) { + memory_usage_last_executing = 0; + SCOPED_PEAK_MEM(&memory_usage_last_executing); + SCOPED_TIMER(_build_timer); + DCHECK(!block->empty()); + + ColumnRawPtrs key_columns(_groupby_expr_ctxs.size()); + RETURN_IF_ERROR(evaluate_groupby_keys(block, key_columns)); + + // InlineCount: emplace all keys, count is incremented inside emplace. + // No evaluator execution needed. + emplace_into_hash_table(nullptr, key_columns, cast_set(block->rows()), + _hash_table_compute_timer, _hash_table_emplace_timer, + _hash_table_input_counter); + + return Status::OK(); +} + +Status InlineCountAggContext::emplace_and_forward(AggregateDataPtr* places, + ColumnRawPtrs& key_columns, uint32_t num_rows, + Block* block, bool expand_hash_table) { + // InlineCount: emplace increments UInt64 count directly, no execute_batch_add needed. 
+ emplace_into_hash_table(places, key_columns, num_rows, _hash_table_compute_timer, + _hash_table_emplace_timer, _hash_table_input_counter); + return Status::OK(); +} + +Status InlineCountAggContext::merge_with_serialized_key(Block* block) { + SCOPED_TIMER(_merge_timer); + DCHECK(!block->empty()); + + size_t key_size = _groupby_expr_ctxs.size(); + ColumnRawPtrs key_columns(key_size); + RETURN_IF_ERROR(evaluate_groupby_keys(block, key_columns)); + + const auto rows = block->rows(); + + // Get the serialized count column (ColumnFixedLengthObject containing AggregateFunctionCountData) + DCHECK_EQ(_agg_evaluators.size(), 1); + auto col_id = get_slot_column_id(_agg_evaluators[0]); + auto column = block->get_by_position(col_id).column; + + _merge_inline_count(key_columns, column.get(), cast_set(rows)); + + return Status::OK(); +} + +void InlineCountAggContext::_merge_inline_count(ColumnRawPtrs& key_columns, + const IColumn* merge_column, + uint32_t num_rows) { + std::visit( + Overload {[&](std::monostate& arg) -> void { + throw doris::Exception(ErrorCode::INTERNAL_ERROR, "uninited hash table"); + }, + [&](auto& agg_method) -> void { + SCOPED_TIMER(_hash_table_compute_timer); + using HashMethodType = std::decay_t; + using AggState = typename HashMethodType::State; + AggState state(key_columns); + agg_method.init_serialized_keys(key_columns, num_rows); + + const auto& col = + assert_cast(*merge_column); + const auto* col_data = + reinterpret_cast( + col.get_data().data()); + + auto creator = [&](const auto& ctor, auto& key, auto& origin) { + HashMethodType::try_presis_key_and_origin(key, origin, _agg_arena); + AggregateDataPtr mapped = nullptr; + ctor(key, mapped); + }; + + auto creator_for_null_key = [&](auto& mapped) { mapped = nullptr; }; + + SCOPED_TIMER(_hash_table_emplace_timer); + lazy_emplace_batch(agg_method, state, num_rows, creator, + creator_for_null_key, [&](uint32_t i, auto& mapped) { + reinterpret_cast(mapped) += + col_data[i].count; + }); + + 
COUNTER_UPDATE(_hash_table_input_counter, num_rows); + }}, + _hash_table_data->method_variant); +} + +// ==================== Result output ==================== + +Status InlineCountAggContext::get_serialized_results(RuntimeState* state, Block* block, + bool* eos) { + SCOPED_TIMER(_get_results_timer); + size_t key_size = _groupby_expr_ctxs.size(); + DCHECK_EQ(_agg_evaluators.size(), 1); + + bool mem_reuse = make_nullable_keys.empty() && block->mem_reuse(); + + auto key_columns = agg_context_utils::take_or_create_columns( + block, mem_reuse, 0, key_size, + [&](size_t i) { return _groupby_expr_ctxs[i]->root()->data_type()->create_column(); }); + + MutableColumnPtr value_column; + DataTypePtr value_data_type = _agg_evaluators[0]->function()->get_serialized_type(); + if (mem_reuse) { + value_column = std::move(*block->get_by_position(key_size).column).mutate(); + } else { + value_column = _agg_evaluators[0]->function()->create_serialize_column(); + } + + std::visit( + Overload { + [&](std::monostate& arg) -> void { + throw doris::Exception(ErrorCode::INTERNAL_ERROR, "uninited hash table"); + }, + [&](auto& agg_method) -> void { + agg_method.init_iterator(); + auto& data = *agg_method.hash_table; + const auto size = std::min(data.size(), size_t(state->batch_size())); + using KeyType = std::decay_t::Key; + std::vector keys(size); + + auto& count_col = + assert_cast(*value_column); + uint32_t num_rows = 0; + { + SCOPED_TIMER(_hash_table_iterate_timer); + auto& it = agg_method.begin; + while (it != agg_method.end && num_rows < state->batch_size()) { + keys[num_rows] = it.get_first(); + auto inline_count = + std::bit_cast(it.get_second()); + count_col.insert_data( + reinterpret_cast(&inline_count), + sizeof(UInt64)); + ++it; + ++num_rows; + } + } + + { + SCOPED_TIMER(_insert_keys_to_column_timer); + agg_method.insert_keys_into_columns(keys, key_columns, num_rows); + } + + // Handle null key if present + if (agg_method.begin == agg_method.end) { + if 
(agg_method.hash_table->has_null_key_data()) { + DCHECK(key_columns.size() == 1); + DCHECK(key_columns[0]->is_nullable()); + if (num_rows < state->batch_size()) { + key_columns[0]->insert_data(nullptr, 0); + auto mapped = + agg_method.hash_table->template get_null_key_data< + AggregateDataPtr>(); + auto inline_count = + std::bit_cast(mapped); + count_col.insert_data( + reinterpret_cast(&inline_count), + sizeof(UInt64)); + *eos = true; + } + } else { + *eos = true; + } + } + }}, + _hash_table_data->method_variant); + + if (!mem_reuse) { + MutableColumns value_columns; + value_columns.emplace_back(std::move(value_column)); + DataTypes value_types {value_data_type}; + agg_context_utils::build_serialized_output_block(block, key_columns, _groupby_expr_ctxs, + value_columns, value_types); + } + + return Status::OK(); +} + +Status InlineCountAggContext::get_finalized_results( + RuntimeState* state, Block* block, bool* eos, + const ColumnsWithTypeAndName& columns_with_schema) { + bool mem_reuse = make_nullable_keys.empty() && block->mem_reuse(); + + size_t key_size = _groupby_expr_ctxs.size(); + + auto key_columns = agg_context_utils::take_or_create_columns( + block, mem_reuse, 0, key_size, + [&](size_t i) { return columns_with_schema[i].type->create_column(); }); + MutableColumnPtr value_column; + if (!mem_reuse) { + value_column = columns_with_schema[key_size].type->create_column(); + } else { + value_column = std::move(*block->get_by_position(key_size).column).mutate(); + } + + SCOPED_TIMER(_get_results_timer); + std::visit( + Overload { + [&](std::monostate& arg) -> void { + throw doris::Exception(ErrorCode::INTERNAL_ERROR, "uninited hash table"); + }, + [&](auto& agg_method) -> void { + auto& data = *agg_method.hash_table; + agg_method.init_iterator(); + const auto size = std::min(data.size(), size_t(state->batch_size())); + using KeyType = std::decay_t::Key; + std::vector keys(size); + + DCHECK_EQ(_agg_evaluators.size(), 1); + auto& count_column = 
assert_cast(*value_column); + uint32_t num_rows = 0; + { + SCOPED_TIMER(_hash_table_iterate_timer); + auto& it = agg_method.begin; + while (it != agg_method.end && num_rows < state->batch_size()) { + keys[num_rows] = it.get_first(); + auto& mapped = it.get_second(); + count_column.insert_value(static_cast( + std::bit_cast(mapped))); + ++it; + ++num_rows; + } + } + { + SCOPED_TIMER(_insert_keys_to_column_timer); + agg_method.insert_keys_into_columns(keys, key_columns, num_rows); + } + + // Handle null key if present + if (agg_method.begin == agg_method.end) { + if (agg_method.hash_table->has_null_key_data()) { + DCHECK(key_columns.size() == 1); + DCHECK(key_columns[0]->is_nullable()); + if (key_columns[0]->size() < state->batch_size()) { + key_columns[0]->insert_data(nullptr, 0); + auto mapped = + agg_method.hash_table->template get_null_key_data< + AggregateDataPtr>(); + count_column.insert_value( + static_cast(std::bit_cast(mapped))); + *eos = true; + } + } else { + *eos = true; + } + } + }}, + _hash_table_data->method_variant); + + if (!mem_reuse) { + MutableColumns value_columns; + value_columns.emplace_back(std::move(value_column)); + agg_context_utils::assemble_finalized_output(block, columns_with_schema, key_columns, + value_columns, key_size); + } + + return Status::OK(); +} + +// ==================== Agg state management ==================== + +Status InlineCountAggContext::create_agg_state(AggregateDataPtr /*data*/) { + throw doris::Exception(ErrorCode::INTERNAL_ERROR, + "InlineCountAggContext should never create agg state"); +} + +void InlineCountAggContext::destroy_agg_state(AggregateDataPtr /*data*/) { + throw doris::Exception(ErrorCode::INTERNAL_ERROR, + "InlineCountAggContext should never destroy agg state"); +} + +void InlineCountAggContext::close() { + // InlineCount stores UInt64 directly in mapped slots, not real agg state pointers. + // Skip agg state destruction — the hash table memory is managed by AggregatedDataVariants. 
+} + +Status InlineCountAggContext::reset_hash_table() { + return std::visit( + Overload { + [&](std::monostate& arg) -> Status { + return Status::InternalError("Uninited hash table"); + }, + [&](auto& agg_method) -> Status { + auto& hash_table = *agg_method.hash_table; + using HashTableType = std::decay_t; + + agg_method.arena.clear(); + agg_method.inited_iterator = false; + + // No agg state to destroy — mapped slots hold UInt64 counts. + // No AggregateDataContainer to reset either. + agg_method.hash_table.reset(new HashTableType()); + return Status::OK(); + }}, + _hash_table_data->method_variant); +} + +} // namespace doris diff --git a/be/src/exec/common/inline_count_agg_context.h b/be/src/exec/common/inline_count_agg_context.h new file mode 100644 index 00000000000000..aae136bb0d95eb --- /dev/null +++ b/be/src/exec/common/inline_count_agg_context.h @@ -0,0 +1,91 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +#pragma once + +#include "exec/common/groupby_agg_context.h" + +namespace doris { + +/// InlineCountAggContext: specialized subclass of GroupByAggContext for single count(*) +/// optimization.
Instead of allocating full aggregate states, it stores a UInt64 counter +/// directly in the hash table's mapped slot (reinterpret_cast(mapped)). +/// +/// This avoids: +/// - AggregateDataContainer allocation (no separate agg state storage) +/// - AggFnEvaluator dispatch (direct UInt64 increment) +/// - Agg state create/destroy overhead +/// +/// Usage condition (see AggSinkLocalState::open): GROUP BY present, exactly one evaluator whose function is_simple_count(), no limit-output optimization, spill disabled; update and merge paths are both supported. +class InlineCountAggContext final : public GroupByAggContext { +public: + using GroupByAggContext::GroupByAggContext; + + // ==================== Hash table write (override) ==================== + + /// Emplace keys, increment UInt64 count for each row (new key starts at 1). + void emplace_into_hash_table(AggregateDataPtr* places, ColumnRawPtrs& key_columns, + uint32_t num_rows, + RuntimeProfile::Counter* hash_table_compute_timer, + RuntimeProfile::Counter* hash_table_emplace_timer, + RuntimeProfile::Counter* hash_table_input_counter) override; + + // ==================== Aggregation execution (override) ==================== + + /// Execute path: emplace keys (count already incremented), skip evaluator execution. + Status execute_with_serialized_key(Block* block) override; + + /// Emplace only (count++ done internally), skip execute_batch_add. + Status emplace_and_forward(AggregateDataPtr* places, ColumnRawPtrs& key_columns, + uint32_t num_rows, Block* block, + bool expand_hash_table) override; + + /// Merge path: read count from ColumnFixedLengthObject (AggregateFunctionCountData) + /// and add to mapped UInt64. + Status merge_with_serialized_key(Block* block) override; + + // ==================== Result output (override) ==================== + + /// Serialize output: iterate hash table directly, output ColumnFixedLengthObject with UInt64. + Status get_serialized_results(RuntimeState* state, Block* block, bool* eos) override; + + /// Finalize output: iterate hash table directly, output ColumnInt64.
+ Status get_finalized_results(RuntimeState* state, Block* block, bool* eos, + const ColumnsWithTypeAndName& columns_with_schema) override; + + // ==================== Agg state management (override) ==================== + + /// InlineCount does NOT use aggregate states. Calling these is a logic bug. + Status create_agg_state(AggregateDataPtr data) override; + void destroy_agg_state(AggregateDataPtr data) override; + + /// Skip agg state destruction (mapped slots hold UInt64, not agg state pointers). + void close() override; + + /// Skip agg state destruction: clear the arena and replace the hash table with a fresh one. No AggregateDataContainer needed. + Status reset_hash_table() override; + + /// No-op: InlineCount does not use AggregateDataContainer. + void init_agg_data_container() override {} + +private: + /// Merge helper: emplace keys and add count values from a serialized count column. + void _merge_inline_count(ColumnRawPtrs& key_columns, const IColumn* merge_column, + uint32_t num_rows); +}; + +} // namespace doris diff --git a/be/src/exec/common/ungroupby_agg_context.cpp b/be/src/exec/common/ungroupby_agg_context.cpp new file mode 100644 index 00000000000000..b5de867b612a03 --- /dev/null +++ b/be/src/exec/common/ungroupby_agg_context.cpp @@ -0,0 +1,257 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied.
See the License for the +// specific language governing permissions and limitations +// under the License. + +#include "exec/common/ungroupby_agg_context.h" + +#include "common/exception.h" +#include "exec/common/agg_context_utils.h" +#include "exec/common/util.hpp" +#include "exprs/vectorized_agg_fn.h" +#include "exprs/vslot_ref.h" +#include "runtime/descriptors.h" +#include "runtime/runtime_state.h" + +namespace doris { + +// ==================== Constructor / Destructor ==================== + +UngroupByAggContext::UngroupByAggContext(std::vector agg_evaluators, + Sizes agg_state_offsets, size_t total_agg_state_size, + size_t agg_state_alignment) + : _agg_evaluators(std::move(agg_evaluators)), + _agg_state_offsets(std::move(agg_state_offsets)), + _total_agg_state_size(total_agg_state_size), + _agg_state_alignment(agg_state_alignment) {} + +UngroupByAggContext::~UngroupByAggContext() = default; + +// ==================== Profile ==================== + +void UngroupByAggContext::init_profile(RuntimeProfile* profile) { + _build_timer = ADD_TIMER(profile, "BuildTime"); + _merge_timer = ADD_TIMER(profile, "MergeTime"); + _deserialize_data_timer = ADD_TIMER(profile, "DeserializeAndMergeTime"); + _get_results_timer = ADD_TIMER(profile, "GetResultsTime"); + + auto* memory_usage = + profile->create_child("MemoryUsage", true, true); + _memory_used_counter = profile->get_counter("MemoryUsage"); + _memory_usage_arena = ADD_COUNTER(memory_usage, "Arena", TUnit::BYTES); +} + +// ==================== Agg state management ==================== + +Status UngroupByAggContext::_create_agg_state() { + DCHECK(!_agg_state_created); + _agg_state_data = reinterpret_cast( + _alloc_arena.aligned_alloc(_total_agg_state_size, _agg_state_alignment)); + + for (int i = 0; i < _agg_evaluators.size(); ++i) { + try { + _agg_evaluators[i]->create(_agg_state_data + _agg_state_offsets[i]); + } catch (...)
{ + for (int j = 0; j < i; ++j) { + _agg_evaluators[j]->destroy(_agg_state_data + _agg_state_offsets[j]); + } + throw; + } + } + + _agg_state_created = true; + return Status::OK(); +} + +void UngroupByAggContext::_destroy_agg_state() { + if (!_agg_state_created) { + return; + } + for (int i = 0; i < _agg_evaluators.size(); ++i) { + _agg_evaluators[i]->function()->destroy(_agg_state_data + _agg_state_offsets[i]); + } + _agg_state_created = false; +} + +void UngroupByAggContext::close() { + _destroy_agg_state(); +} + +// ==================== Aggregation execution (Sink side) ==================== + +Status UngroupByAggContext::execute(Block* block) { + // Lazily create the agg state on first use. The pre-refactor sink created it in open(); + // here whichever of execute()/merge()/get_*_result() runs first creates it. + if (!_agg_state_created) { + RETURN_IF_ERROR(_create_agg_state()); + } + + DCHECK(_agg_state_data != nullptr); + SCOPED_TIMER(_build_timer); + memory_usage_last_executing = 0; + SCOPED_PEAK_MEM(&memory_usage_last_executing); + + for (int i = 0; i < _agg_evaluators.size(); ++i) { + RETURN_IF_ERROR(_agg_evaluators[i]->execute_single_add( + block, _agg_state_data + _agg_state_offsets[i], _agg_arena)); + } + return Status::OK(); +} + +Status UngroupByAggContext::merge(Block* block) { + if (!_agg_state_created) { + RETURN_IF_ERROR(_create_agg_state()); + } + + SCOPED_TIMER(_merge_timer); + DCHECK(_agg_state_data != nullptr); + memory_usage_last_executing = 0; + SCOPED_PEAK_MEM(&memory_usage_last_executing); + + for (int i = 0; i < _agg_evaluators.size(); ++i) { + if (_agg_evaluators[i]->is_merge()) { + int col_id = _get_slot_column_id(_agg_evaluators[i]); + auto column = block->get_by_position(col_id).column; + + SCOPED_TIMER(_deserialize_data_timer); + _agg_evaluators[i]->function()->deserialize_and_merge_from_column( + _agg_state_data + _agg_state_offsets[i], *column, _agg_arena); + } else { + RETURN_IF_ERROR(_agg_evaluators[i]->execute_single_add( + block,
_agg_state_data + _agg_state_offsets[i], _agg_arena)); + } + } + return Status::OK(); +} + +// ==================== Result output (Source side) ==================== + +Status UngroupByAggContext::get_serialized_result(RuntimeState* state, Block* block, bool* eos) { + SCOPED_TIMER(_get_results_timer); + + // Ensure agg state exists even if no data flowed through the sink. + if (!_agg_state_created) { + RETURN_IF_ERROR(_create_agg_state()); + } + + // If no data was ever fed, return empty result. + if (UNLIKELY(input_num_rows == 0)) { + *eos = true; + return Status::OK(); + } + block->clear(); + + DCHECK(_agg_state_data != nullptr); + size_t agg_size = _agg_evaluators.size(); + + MutableColumns value_columns(agg_size); + std::vector data_types(agg_size); + + for (int i = 0; i < agg_size; ++i) { + data_types[i] = _agg_evaluators[i]->function()->get_serialized_type(); + value_columns[i] = _agg_evaluators[i]->function()->create_serialize_column(); + } + + for (int i = 0; i < agg_size; ++i) { + _agg_evaluators[i]->function()->serialize_without_key_to_column( + _agg_state_data + _agg_state_offsets[i], *value_columns[i]); + } + + { + ColumnsWithTypeAndName data_with_schema; + for (int i = 0; i < agg_size; ++i) { + ColumnWithTypeAndName column_with_schema = {nullptr, data_types[i], ""}; + data_with_schema.push_back(std::move(column_with_schema)); + } + *block = Block(data_with_schema); + } + + block->set_columns(std::move(value_columns)); + *eos = true; + return Status::OK(); +} + +Status UngroupByAggContext::get_finalized_result(RuntimeState* state, Block* block, bool* eos, + const RowDescriptor& row_desc) { + // Ensure agg state exists even if no data flowed through the sink. + // Without GROUP BY, aggregation always produces one row (e.g., COUNT(*) → 0).
+ if (!_agg_state_created) { + RETURN_IF_ERROR(_create_agg_state()); + } + DCHECK(_agg_state_data != nullptr); + block->clear(); + + *block = VectorizedUtils::create_empty_columnswithtypename(row_desc); + size_t agg_size = _agg_evaluators.size(); + + MutableColumns columns(agg_size); + std::vector data_types(agg_size); + for (int i = 0; i < agg_size; ++i) { + data_types[i] = _agg_evaluators[i]->function()->get_return_type(); + columns[i] = data_types[i]->create_column(); + } + + for (int i = 0; i < agg_size; ++i) { + auto column = columns[i].get(); + _agg_evaluators[i]->insert_result_info(_agg_state_data + _agg_state_offsets[i], column); + } + + const auto& block_schema = block->get_columns_with_type_and_name(); + DCHECK_EQ(block_schema.size(), columns.size()); + for (int i = 0; i < block_schema.size(); ++i) { + const auto column_type = block_schema[i].type; + if (!column_type->equals(*data_types[i])) { + if (column_type->get_primitive_type() != TYPE_ARRAY) { + if (!column_type->is_nullable() || data_types[i]->is_nullable() || + !remove_nullable(column_type)->equals(*data_types[i])) { + return Status::InternalError( + "column_type not match data_types, column_type={}, data_types={}", + column_type->get_name(), data_types[i]->get_name()); + } + } + + // Result of operator is nullable, but aggregate function result is not nullable. + // This happens when: 1) no group by, 2) input empty, 3) all input columns not nullable. + if (column_type->is_nullable() && !data_types[i]->is_nullable()) { + ColumnPtr ptr = std::move(columns[i]); + // Unless count, other aggregate functions on empty set should produce null.
+ ptr = make_nullable(ptr, input_num_rows == 0); + columns[i] = ptr->assume_mutable(); + } + } + } + + block->set_columns(std::move(columns)); + *eos = true; + return Status::OK(); +} + +// ==================== Utilities ==================== + +void UngroupByAggContext::update_memusage() { + int64_t arena_memory_usage = _agg_arena.size(); + if (_memory_used_counter) { + COUNTER_SET(_memory_used_counter, arena_memory_usage); + } + if (_memory_usage_arena) { + COUNTER_SET(_memory_usage_arena, arena_memory_usage); + } +} + +int UngroupByAggContext::_get_slot_column_id(const AggFnEvaluator* evaluator) { + return agg_context_utils::get_slot_column_id(evaluator); +} + +} // namespace doris diff --git a/be/src/exec/common/ungroupby_agg_context.h b/be/src/exec/common/ungroupby_agg_context.h new file mode 100644 index 00000000000000..995bbb3c5152c1 --- /dev/null +++ b/be/src/exec/common/ungroupby_agg_context.h @@ -0,0 +1,118 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License.
+ +#pragma once + +#include "common/status.h" +#include "core/arena.h" +#include "core/block/block.h" +#include "runtime/runtime_profile.h" + +namespace doris { + +class AggFnEvaluator; +class RuntimeState; +class RowDescriptor; +using AggregateDataPtr = char*; +using Sizes = std::vector; + +/// UngroupByAggContext encapsulates aggregation logic for queries WITHOUT GROUP BY. +/// There is no hash table — only a single AggregateDataPtr pointing to one agg state row. +/// +/// This context is used by AggSinkLocalState (sink side: execute/merge) and +/// AggLocalState (source side: get_serialized_result/get_finalized_result). +class UngroupByAggContext { +public: + UngroupByAggContext(std::vector agg_evaluators, Sizes agg_state_offsets, + size_t total_agg_state_size, size_t agg_state_alignment); + + ~UngroupByAggContext(); + + // ==================== Aggregation execution (Sink side) ==================== + + /// Update mode: execute_single_add for each evaluator. + Status execute(Block* block); + + /// Merge mode: deserialize_and_merge or execute_single_add depending on evaluator. + Status merge(Block* block); + + // ==================== Result output (Source side) ==================== + + /// Serialize mode: serialize agg state to output block (for non-finalize path). + Status get_serialized_result(RuntimeState* state, Block* block, bool* eos); + + /// Finalize mode: insert final result info to output block. + Status get_finalized_result(RuntimeState* state, Block* block, bool* eos, + const RowDescriptor& row_desc); + + // ==================== Utilities ==================== + + /// Update memory usage counters. + void update_memusage(); + + /// Create profile counters under the given profile. + void init_profile(RuntimeProfile* profile); + + /// Destroy agg state if created. Safe to call multiple times.
+ void close(); + + AggregateDataPtr agg_state_data() const { return _agg_state_data; } + Arena& agg_arena() { return _agg_arena; } + std::vector& agg_evaluators() { return _agg_evaluators; } + + /// Track total input rows (used by source to detect empty input). + size_t input_num_rows = 0; + + /// Memory tracking for reserve estimation (Sink side). + int64_t memory_usage_last_executing = 0; + + // Profile timer accessors (for SCOPED_TIMER in operator code) + RuntimeProfile::Counter* build_timer() const { return _build_timer; } + RuntimeProfile::Counter* merge_timer() const { return _merge_timer; } + RuntimeProfile::Counter* deserialize_data_timer() const { return _deserialize_data_timer; } + RuntimeProfile::Counter* get_results_timer() const { return _get_results_timer; } + +private: + /// Allocate and initialize aggregate states for all evaluators. + Status _create_agg_state(); + + /// Destroy aggregate states for all evaluators. + void _destroy_agg_state(); + + /// Get the input column id for a merge-mode evaluator; forwards to agg_context_utils::get_slot_column_id.
+ static int _get_slot_column_id(const AggFnEvaluator* evaluator); + + AggregateDataPtr _agg_state_data = nullptr; + Arena _agg_arena; + Arena _alloc_arena; // used to allocate the agg state memory block; its size is not counted by update_memusage() + + std::vector _agg_evaluators; + Sizes _agg_state_offsets; + size_t _total_agg_state_size; + size_t _agg_state_alignment; + bool _agg_state_created = false; + + // ---- Profile counters (created by init_profile) ---- + RuntimeProfile::Counter* _build_timer = nullptr; + RuntimeProfile::Counter* _merge_timer = nullptr; + RuntimeProfile::Counter* _deserialize_data_timer = nullptr; + RuntimeProfile::Counter* _get_results_timer = nullptr; + RuntimeProfile::Counter* _memory_used_counter = nullptr; + RuntimeProfile::Counter* _memory_usage_arena = nullptr; +}; + +} // namespace doris diff --git a/be/src/exec/operator/aggregation_sink_operator.cpp b/be/src/exec/operator/aggregation_sink_operator.cpp index 0f7505423a3552..009a7fee018ff3 100644 --- a/be/src/exec/operator/aggregation_sink_operator.cpp +++ b/be/src/exec/operator/aggregation_sink_operator.cpp @@ -22,14 +22,11 @@ #include "common/cast_set.h" #include "common/status.h" -#include "core/data_type/primitive_type.h" -#include "exec/common/hash_table/hash.h" -#include "exec/operator/operator.h" -#include "exprs/aggregate/aggregate_function_count.h" +#include "exec/common/inline_count_agg_context.h" #include "exprs/aggregate/aggregate_function_simple_factory.h" +#include "exec/operator/operator.h" #include "exprs/vectorized_agg_fn.h" #include "runtime/runtime_profile.h" -#include "runtime/thread_context.h" namespace doris { #include "common/compile_check_begin.h" @@ -59,26 +56,6 @@ Status AggSinkLocalState::init(RuntimeState* state, LocalSinkStateInfo& info) { RETURN_IF_ERROR(Base::init(state, info)); SCOPED_TIMER(Base::exec_time_counter()); SCOPED_TIMER(Base::_init_timer); - _agg_data = Base::_shared_state->agg_data.get(); - _hash_table_size_counter = ADD_COUNTER(custom_profile(), "HashTableSize", TUnit::UNIT); -
_hash_table_memory_usage = - ADD_COUNTER_WITH_LEVEL(Base::custom_profile(), "MemoryUsageHashTable", TUnit::BYTES, 1); - _serialize_key_arena_memory_usage = ADD_COUNTER_WITH_LEVEL( - Base::custom_profile(), "MemoryUsageSerializeKeyArena", TUnit::BYTES, 1); - - _build_timer = ADD_TIMER(Base::custom_profile(), "BuildTime"); - _merge_timer = ADD_TIMER(Base::custom_profile(), "MergeTime"); - _expr_timer = ADD_TIMER(Base::custom_profile(), "ExprTime"); - _deserialize_data_timer = ADD_TIMER(Base::custom_profile(), "DeserializeAndMergeTime"); - _hash_table_compute_timer = ADD_TIMER(Base::custom_profile(), "HashTableComputeTime"); - _hash_table_limit_compute_timer = ADD_TIMER(Base::custom_profile(), "DoLimitComputeTime"); - _hash_table_emplace_timer = ADD_TIMER(Base::custom_profile(), "HashTableEmplaceTime"); - _hash_table_input_counter = - ADD_COUNTER(Base::custom_profile(), "HashTableInputCount", TUnit::UNIT); - - _memory_usage_container = ADD_COUNTER(custom_profile(), "MemoryUsageContainer", TUnit::BYTES); - _memory_usage_arena = ADD_COUNTER(custom_profile(), "MemoryUsageArena", TUnit::BYTES); - return Status::OK(); } @@ -87,740 +64,135 @@ Status AggSinkLocalState::open(RuntimeState* state) { SCOPED_TIMER(Base::_open_timer); RETURN_IF_ERROR(Base::open(state)); auto& p = Base::_parent->template cast(); - Base::_shared_state->align_aggregate_states = p._align_aggregate_states; - Base::_shared_state->total_size_of_aggregate_states = p._total_size_of_aggregate_states; - Base::_shared_state->offsets_of_aggregate_states = p._offsets_of_aggregate_states; Base::_shared_state->make_nullable_keys = p._make_nullable_keys; - Base::_shared_state->probe_expr_ctxs.resize(p._probe_expr_ctxs.size()); - - Base::_shared_state->limit = p._limit; - Base::_shared_state->do_sort_limit = p._do_sort_limit; - Base::_shared_state->null_directions = p._null_directions; - Base::_shared_state->order_directions = p._order_directions; - for (size_t i = 0; i < 
Base::_shared_state->probe_expr_ctxs.size(); i++) { - RETURN_IF_ERROR( - p._probe_expr_ctxs[i]->clone(state, Base::_shared_state->probe_expr_ctxs[i])); - } - - if (Base::_shared_state->probe_expr_ctxs.empty()) { - _agg_data->without_key = reinterpret_cast( - Base::_shared_state->agg_profile_arena.aligned_alloc( - p._total_size_of_aggregate_states, p._align_aggregate_states)); + if (p._probe_expr_ctxs.empty()) { + // ── Without GROUP BY → create UngroupByAggContext ── + std::vector evaluators; + for (auto& evaluator : p._aggregate_evaluators) { + evaluators.push_back(evaluator->clone(state, p._pool)); + } + auto ctx = std::make_unique( + std::move(evaluators), p._offsets_of_aggregate_states, + p._total_size_of_aggregate_states, p._align_aggregate_states); + ctx->init_profile(custom_profile()); + for (auto& evaluator : ctx->agg_evaluators()) { + evaluator->set_timer(ctx->merge_timer(), nullptr); + } + Base::_shared_state->ungroupby_agg_ctx = std::move(ctx); if (p._is_merge) { _executor = std::make_unique>(); } else { _executor = std::make_unique>(); } } else { - RETURN_IF_ERROR(_init_hash_method(Base::_shared_state->probe_expr_ctxs)); - - std::visit(Overload {[&](std::monostate& arg) { - throw doris::Exception(ErrorCode::INTERNAL_ERROR, - "uninited hash table"); - }, - [&](auto& agg_method) { - using HashTableType = std::decay_t; - using KeyType = typename HashTableType::Key; - - /// some aggregate functions (like AVG for decimal) have align issues. 
- Base::_shared_state->aggregate_data_container = - std::make_unique( - sizeof(KeyType), - ((p._total_size_of_aggregate_states + - p._align_aggregate_states - 1) / - p._align_aggregate_states) * - p._align_aggregate_states); - }}, - _agg_data->method_variant); - if (p._is_merge) { - _executor = std::make_unique>(); - } else { - _executor = std::make_unique>(); + // ── With GROUP BY → create GroupByAggContext or InlineCountAggContext ── + VExprContextSPtrs probe_expr_ctxs(p._probe_expr_ctxs.size()); + for (size_t i = 0; i < probe_expr_ctxs.size(); i++) { + RETURN_IF_ERROR(p._probe_expr_ctxs[i]->clone(state, probe_expr_ctxs[i])); + } + std::vector evaluators; + for (auto& evaluator : p._aggregate_evaluators) { + evaluators.push_back(evaluator->clone(state, p._pool)); } - _should_limit_output = p._limit != -1 && // has limit - (!p._have_conjuncts) && // no having conjunct - !Base::_shared_state->enable_spill; - } - for (auto& evaluator : p._aggregate_evaluators) { - Base::_shared_state->aggregate_evaluators.push_back(evaluator->clone(state, p._pool)); - } - for (auto& evaluator : Base::_shared_state->aggregate_evaluators) { - evaluator->set_timer(_merge_timer, _expr_timer); - } - // move _create_agg_status to open not in during prepare, - // because during prepare and open thread is not the same one, - // this could cause unable to get JVM - if (Base::_shared_state->probe_expr_ctxs.empty()) { - // _create_agg_status may acquire a lot of memory, may allocate failed when memory is very few - RETURN_IF_ERROR(_create_agg_status(_agg_data->without_key)); - _shared_state->agg_data_created_without_key = true; - } - - // Determine whether to use simple count aggregation. - // For queries like: SELECT xxx, count(*) / count(not_null_column) FROM table GROUP BY xxx, - // count(*) / count(not_null_column) can store a uint64 counter directly in the hash table, - // instead of storing the full aggregate state, saving memory and computation overhead. - // Requirements: - // 0. 
The aggregation has a GROUP BY clause. - // 1. There is exactly one count aggregate function. - // 2. No limit optimization is applied. - // 3. Spill is not enabled (the spill path accesses aggregate_data_container, which is empty in inline count mode). - // Supports update / merge / finalize / serialize phases, since count's serialization format is UInt64 itself. - - if (!Base::_shared_state->probe_expr_ctxs.empty() /* has GROUP BY */ - && (p._aggregate_evaluators.size() == 1 && - p._aggregate_evaluators[0]->function()->is_simple_count()) /* only one count(*) */ - && !_should_limit_output /* no limit optimization */ && - !Base::_shared_state->enable_spill /* spill not enabled */) { - _shared_state->use_simple_count = true; + bool should_limit_output = p._limit != -1 && !p._have_conjuncts && + !Base::_shared_state->enable_spill; + bool use_simple_count = p._aggregate_evaluators.size() == 1 && + p._aggregate_evaluators[0]->function()->is_simple_count() && + !should_limit_output && !Base::_shared_state->enable_spill; #ifndef NDEBUG // Randomly enable/disable in debug mode to verify correctness of multi-phase agg promotion/demotion. - _shared_state->use_simple_count = rand() % 2 == 0; -#endif - } - - return Status::OK(); -} - -Status AggSinkLocalState::_create_agg_status(AggregateDataPtr data) { - auto& shared_state = *Base::_shared_state; - for (int i = 0; i < shared_state.aggregate_evaluators.size(); ++i) { - try { - shared_state.aggregate_evaluators[i]->create( - data + shared_state.offsets_of_aggregate_states[i]); - } catch (...) 
{ - for (int j = 0; j < i; ++j) { - shared_state.aggregate_evaluators[j]->destroy( - data + shared_state.offsets_of_aggregate_states[j]); - } - throw; + if (use_simple_count) { + use_simple_count = rand() % 2 == 0; } - } - return Status::OK(); -} - -Status AggSinkLocalState::_execute_without_key(Block* block) { - DCHECK(_agg_data->without_key != nullptr); - SCOPED_TIMER(_build_timer); - _memory_usage_last_executing = 0; - SCOPED_PEAK_MEM(&_memory_usage_last_executing); - for (int i = 0; i < Base::_shared_state->aggregate_evaluators.size(); ++i) { - RETURN_IF_ERROR(Base::_shared_state->aggregate_evaluators[i]->execute_single_add( - block, - _agg_data->without_key + Base::_parent->template cast() - ._offsets_of_aggregate_states[i], - Base::_shared_state->agg_arena_pool)); - } - return Status::OK(); -} - -Status AggSinkLocalState::_merge_with_serialized_key(Block* block) { - _memory_usage_last_executing = 0; - SCOPED_PEAK_MEM(&_memory_usage_last_executing); - if (_shared_state->reach_limit) { - return _merge_with_serialized_key_helper(block); - } else { - return _merge_with_serialized_key_helper(block); - } -} - -size_t AggSinkLocalState::_memory_usage() const { - if (0 == get_hash_table_size()) { - return 0; - } - size_t usage = 0; - usage += Base::_shared_state->agg_arena_pool.size(); - - if (Base::_shared_state->aggregate_data_container) { - usage += Base::_shared_state->aggregate_data_container->memory_usage(); - } - - std::visit(Overload {[&](std::monostate& arg) -> void { - throw doris::Exception(ErrorCode::INTERNAL_ERROR, - "uninited hash table"); - }, - [&](auto& agg_method) -> void { - auto data = agg_method.hash_table; - usage += data->get_buffer_size_in_bytes(); - }}, - _agg_data->method_variant); - - return usage; -} - -bool AggSinkLocalState::is_blockable() const { - return std::any_of(Base::_shared_state->aggregate_evaluators.begin(), - Base::_shared_state->aggregate_evaluators.end(), - [](const AggFnEvaluator* evaluator) { return 
evaluator->is_blockable(); }); -} - -void AggSinkLocalState::_update_memusage_with_serialized_key() { - std::visit( - Overload {[&](std::monostate& arg) -> void { - throw doris::Exception(ErrorCode::INTERNAL_ERROR, "uninited hash table"); - }, - [&](auto& agg_method) -> void { - auto& data = *agg_method.hash_table; - int64_t memory_usage_arena = Base::_shared_state->agg_arena_pool.size(); - int64_t memory_usage_container = - _shared_state->aggregate_data_container->memory_usage(); - int64_t hash_table_memory_usage = data.get_buffer_size_in_bytes(); - - COUNTER_SET(_memory_usage_arena, memory_usage_arena); - COUNTER_SET(_memory_usage_container, memory_usage_container); - COUNTER_SET(_hash_table_memory_usage, hash_table_memory_usage); - COUNTER_SET(_serialize_key_arena_memory_usage, - memory_usage_arena + memory_usage_container); - - COUNTER_SET(_memory_used_counter, memory_usage_arena + - memory_usage_container + - hash_table_memory_usage); - }}, - _agg_data->method_variant); -} - -Status AggSinkLocalState::_destroy_agg_status(AggregateDataPtr data) { - auto& shared_state = *Base::_shared_state; - for (int i = 0; i < shared_state.aggregate_evaluators.size(); ++i) { - shared_state.aggregate_evaluators[i]->function()->destroy( - data + shared_state.offsets_of_aggregate_states[i]); - } - return Status::OK(); -} - -template -Status AggSinkLocalState::_merge_with_serialized_key_helper(Block* block) { - SCOPED_TIMER(_merge_timer); - - size_t key_size = Base::_shared_state->probe_expr_ctxs.size(); - ColumnRawPtrs key_columns(key_size); - std::vector key_locs(key_size); - - for (int i = 0; i < key_size; ++i) { - if constexpr (for_spill) { - key_columns[i] = block->get_by_position(i).column.get(); - key_locs[i] = i; - } else { - int& result_column_id = key_locs[i]; - RETURN_IF_ERROR( - Base::_shared_state->probe_expr_ctxs[i]->execute(block, &result_column_id)); - block->replace_by_position_if_const(result_column_id); - key_columns[i] = 
block->get_by_position(result_column_id).column.get(); - } - key_columns[i]->assume_mutable()->replace_float_special_values(); - } - - size_t rows = block->rows(); - if (_places.size() < rows) { - _places.resize(rows); - } - - if (limit && !_shared_state->do_sort_limit) { - _find_in_hash_table(_places.data(), key_columns, (uint32_t)rows); - - for (int i = 0; i < Base::_shared_state->aggregate_evaluators.size(); ++i) { - if (Base::_shared_state->aggregate_evaluators[i]->is_merge()) { - int col_id = AggSharedState::get_slot_column_id( - Base::_shared_state->aggregate_evaluators[i]); - auto column = block->get_by_position(col_id).column; - - size_t buffer_size = - Base::_shared_state->aggregate_evaluators[i]->function()->size_of_data() * - rows; - if (_deserialize_buffer.size() < buffer_size) { - _deserialize_buffer.resize(buffer_size); - } +#endif - { - SCOPED_TIMER(_deserialize_data_timer); - Base::_shared_state->aggregate_evaluators[i] - ->function() - ->deserialize_and_merge_vec_selected( - _places.data(), - Base::_parent->template cast() - ._offsets_of_aggregate_states[i], - _deserialize_buffer.data(), column.get(), - Base::_shared_state->agg_arena_pool, rows); - } - } else { - RETURN_IF_ERROR( - Base::_shared_state->aggregate_evaluators[i]->execute_batch_add_selected( - block, - Base::_parent->template cast() - ._offsets_of_aggregate_states[i], - _places.data(), Base::_shared_state->agg_arena_pool)); - } - } - } else { - bool need_do_agg = true; - if (limit) { - need_do_agg = _emplace_into_hash_table_limit(_places.data(), block, key_locs, - key_columns, (uint32_t)rows); - rows = block->rows(); + std::unique_ptr ctx; + if (use_simple_count) { + ctx = std::make_unique( + std::move(evaluators), std::move(probe_expr_ctxs), + p._offsets_of_aggregate_states, p._total_size_of_aggregate_states, + p._align_aggregate_states, p._is_first_phase); } else { - if (_shared_state->use_simple_count) { - DCHECK(!for_spill); - - auto col_id = AggSharedState::get_slot_column_id( - 
Base::_shared_state->aggregate_evaluators[0]); - - auto column = block->get_by_position(col_id).column; - _merge_into_hash_table_inline_count(key_columns, column.get(), (uint32_t)rows); - need_do_agg = false; - } else { - _emplace_into_hash_table(_places.data(), key_columns, (uint32_t)rows); - } + ctx = std::make_unique( + std::move(evaluators), std::move(probe_expr_ctxs), + p._offsets_of_aggregate_states, p._total_size_of_aggregate_states, + p._align_aggregate_states, p._is_first_phase); } - if (need_do_agg) { - for (int i = 0; i < Base::_shared_state->aggregate_evaluators.size(); ++i) { - if (Base::_shared_state->aggregate_evaluators[i]->is_merge() || for_spill) { - size_t col_id = 0; - if constexpr (for_spill) { - col_id = Base::_shared_state->probe_expr_ctxs.size() + i; - } else { - col_id = AggSharedState::get_slot_column_id( - Base::_shared_state->aggregate_evaluators[i]); - } - auto column = block->get_by_position(col_id).column; - - size_t buffer_size = Base::_shared_state->aggregate_evaluators[i] - ->function() - ->size_of_data() * - rows; - if (_deserialize_buffer.size() < buffer_size) { - _deserialize_buffer.resize(buffer_size); - } - - { - SCOPED_TIMER(_deserialize_data_timer); - Base::_shared_state->aggregate_evaluators[i] - ->function() - ->deserialize_and_merge_vec( - _places.data(), - Base::_parent->template cast() - ._offsets_of_aggregate_states[i], - _deserialize_buffer.data(), column.get(), - Base::_shared_state->agg_arena_pool, rows); - } - } else { - RETURN_IF_ERROR(Base::_shared_state->aggregate_evaluators[i]->execute_batch_add( - block, - Base::_parent->template cast() - ._offsets_of_aggregate_states[i], - _places.data(), Base::_shared_state->agg_arena_pool)); - } - } - } - - if (!limit && _should_limit_output) { - const size_t hash_table_size = get_hash_table_size(); - _shared_state->reach_limit = - hash_table_size >= Base::_parent->template cast()._limit; - if (_shared_state->do_sort_limit && _shared_state->reach_limit) { - 
_shared_state->build_limit_heap(hash_table_size); - } + ctx->limit = p._limit; + ctx->do_sort_limit = p._do_sort_limit; + ctx->order_directions = p._order_directions; + ctx->null_directions = p._null_directions; + ctx->should_limit_output = should_limit_output; + ctx->enable_spill = Base::_shared_state->enable_spill; + ctx->make_nullable_keys = p._make_nullable_keys; + + ctx->init_hash_method(); + ctx->init_agg_data_container(); + ctx->init_sink_profile(custom_profile()); + for (auto& evaluator : ctx->agg_evaluators()) { + evaluator->set_timer(ctx->merge_timer(), ctx->expr_timer()); } - } - - return Status::OK(); -} -Status AggSinkLocalState::_merge_without_key(Block* block) { - SCOPED_TIMER(_merge_timer); - DCHECK(_agg_data->without_key != nullptr); - - _memory_usage_last_executing = 0; - SCOPED_PEAK_MEM(&_memory_usage_last_executing); - for (int i = 0; i < Base::_shared_state->aggregate_evaluators.size(); ++i) { - if (Base::_shared_state->aggregate_evaluators[i]->is_merge()) { - int col_id = AggSharedState::get_slot_column_id( - Base::_shared_state->aggregate_evaluators[i]); - auto column = block->get_by_position(col_id).column; - - SCOPED_TIMER(_deserialize_data_timer); - Base::_shared_state->aggregate_evaluators[i] - ->function() - ->deserialize_and_merge_from_column( - _agg_data->without_key + - Base::_parent->template cast() - ._offsets_of_aggregate_states[i], - *column, Base::_shared_state->agg_arena_pool); + Base::_shared_state->groupby_agg_ctx = std::move(ctx); + if (p._is_merge) { + _executor = std::make_unique>(); } else { - RETURN_IF_ERROR(Base::_shared_state->aggregate_evaluators[i]->execute_single_add( - block, - _agg_data->without_key + Base::_parent->template cast() - ._offsets_of_aggregate_states[i], - Base::_shared_state->agg_arena_pool)); + _executor = std::make_unique>(); } } return Status::OK(); } -void AggSinkLocalState::_update_memusage_without_key() { - int64_t arena_memory_usage = Base::_shared_state->agg_arena_pool.size(); - 
COUNTER_SET(_memory_used_counter, arena_memory_usage); - COUNTER_SET(_serialize_key_arena_memory_usage, arena_memory_usage); -} - -Status AggSinkLocalState::_execute_with_serialized_key(Block* block) { - _memory_usage_last_executing = 0; - SCOPED_PEAK_MEM(&_memory_usage_last_executing); - if (_shared_state->reach_limit) { - return _execute_with_serialized_key_helper(block); - } else { - return _execute_with_serialized_key_helper(block); - } -} - -template -Status AggSinkLocalState::_execute_with_serialized_key_helper(Block* block) { - SCOPED_TIMER(_build_timer); - DCHECK(!Base::_shared_state->probe_expr_ctxs.empty()); - - size_t key_size = Base::_shared_state->probe_expr_ctxs.size(); - ColumnRawPtrs key_columns(key_size); - std::vector key_locs(key_size); - { - SCOPED_TIMER(_expr_timer); - for (size_t i = 0; i < key_size; ++i) { - int& result_column_id = key_locs[i]; - RETURN_IF_ERROR( - Base::_shared_state->probe_expr_ctxs[i]->execute(block, &result_column_id)); - block->get_by_position(result_column_id).column = - block->get_by_position(result_column_id) - .column->convert_to_full_column_if_const(); - key_columns[i] = block->get_by_position(result_column_id).column.get(); - key_columns[i]->assume_mutable()->replace_float_special_values(); - } - } - - auto rows = (uint32_t)block->rows(); - if (_places.size() < rows) { - _places.resize(rows); - } - - if (limit && !_shared_state->do_sort_limit) { - _find_in_hash_table(_places.data(), key_columns, rows); - - for (int i = 0; i < Base::_shared_state->aggregate_evaluators.size(); ++i) { - RETURN_IF_ERROR( - Base::_shared_state->aggregate_evaluators[i]->execute_batch_add_selected( - block, - Base::_parent->template cast() - ._offsets_of_aggregate_states[i], - _places.data(), Base::_shared_state->agg_arena_pool)); +template +Status AggSinkLocalState::Executor::execute( + AggSinkLocalState* local_state, Block* block) { + if constexpr (WithoutKey) { + auto* ctx = local_state->Base::_shared_state->ungroupby_agg_ctx.get(); + 
ctx->input_num_rows += block->rows(); + if constexpr (NeedToMerge) { + return ctx->merge(block); + } else { + return ctx->execute(block); } } else { - auto do_aggregate_evaluators = [&] { - for (int i = 0; i < Base::_shared_state->aggregate_evaluators.size(); ++i) { - RETURN_IF_ERROR(Base::_shared_state->aggregate_evaluators[i]->execute_batch_add( - block, - Base::_parent->template cast() - ._offsets_of_aggregate_states[i], - _places.data(), Base::_shared_state->agg_arena_pool)); - } - return Status::OK(); - }; - - if constexpr (limit) { - if (_emplace_into_hash_table_limit(_places.data(), block, key_locs, key_columns, - rows)) { - RETURN_IF_ERROR(do_aggregate_evaluators()); - } + auto* ctx = local_state->Base::_shared_state->groupby_agg_ctx.get(); + if constexpr (NeedToMerge) { + return ctx->merge_with_serialized_key(block); } else { - _emplace_into_hash_table(_places.data(), key_columns, rows); - if (!_shared_state->use_simple_count) { - RETURN_IF_ERROR(do_aggregate_evaluators()); - } - - if (_should_limit_output && !Base::_shared_state->enable_spill) { - const size_t hash_table_size = get_hash_table_size(); - - _shared_state->reach_limit = - hash_table_size >= - (_shared_state->do_sort_limit - ? 
Base::_parent->template cast()._limit * - config::topn_agg_limit_multiplier - : Base::_parent->template cast()._limit); - if (_shared_state->reach_limit && _shared_state->do_sort_limit) { - _shared_state->build_limit_heap(hash_table_size); - } - } + return ctx->execute_with_serialized_key(block); } } - return Status::OK(); -} - -size_t AggSinkLocalState::get_hash_table_size() const { - return std::visit(Overload {[&](std::monostate& arg) -> size_t { return 0; }, - [&](auto& agg_method) { return agg_method.hash_table->size(); }}, - _agg_data->method_variant); } -void AggSinkLocalState::_emplace_into_hash_table(AggregateDataPtr* places, - ColumnRawPtrs& key_columns, uint32_t num_rows) { - if (_shared_state->use_simple_count) { - _emplace_into_hash_table_inline_count(key_columns, num_rows); - return; +template +void AggSinkLocalState::Executor::update_memusage( + AggSinkLocalState* local_state) { + if constexpr (WithoutKey) { + local_state->Base::_shared_state->ungroupby_agg_ctx->update_memusage(); + } else { + local_state->Base::_shared_state->groupby_agg_ctx->update_memusage(); } - - std::visit(Overload {[&](std::monostate& arg) -> void { - throw doris::Exception(ErrorCode::INTERNAL_ERROR, - "uninited hash table"); - }, - [&](auto& agg_method) -> void { - SCOPED_TIMER(_hash_table_compute_timer); - using HashMethodType = std::decay_t; - using AggState = typename HashMethodType::State; - AggState state(key_columns); - agg_method.init_serialized_keys(key_columns, num_rows); - - auto creator = [this](const auto& ctor, auto& key, auto& origin) { - HashMethodType::try_presis_key_and_origin( - key, origin, Base::_shared_state->agg_arena_pool); - auto mapped = - Base::_shared_state->aggregate_data_container->append_data( - origin); - auto st = _create_agg_status(mapped); - if (!st) { - throw Exception(st.code(), st.to_string()); - } - ctor(key, mapped); - }; - - auto creator_for_null_key = [&](auto& mapped) { - mapped = Base::_shared_state->agg_arena_pool.aligned_alloc( - 
Base::_parent->template cast() - ._total_size_of_aggregate_states, - Base::_parent->template cast() - ._align_aggregate_states); - auto st = _create_agg_status(mapped); - if (!st) { - throw Exception(st.code(), st.to_string()); - } - }; - - SCOPED_TIMER(_hash_table_emplace_timer); - lazy_emplace_batch( - agg_method, state, num_rows, creator, creator_for_null_key, - [&](uint32_t row, auto& mapped) { places[row] = mapped; }); - - COUNTER_UPDATE(_hash_table_input_counter, num_rows); - }}, - _agg_data->method_variant); -} - -// For the agg hashmap, the value is a char* type which is exactly 64 bits. -// Here we treat it as a uint64 counter: each time the same key is encountered, the counter -// is incremented by 1. This avoids storing the full aggregate state, saving memory and computation overhead. -void AggSinkLocalState::_emplace_into_hash_table_inline_count(ColumnRawPtrs& key_columns, - uint32_t num_rows) { - std::visit(Overload {[&](std::monostate& arg) -> void { - throw doris::Exception(ErrorCode::INTERNAL_ERROR, - "uninited hash table"); - }, - [&](auto& agg_method) -> void { - SCOPED_TIMER(_hash_table_compute_timer); - using HashMethodType = std::decay_t; - using AggState = typename HashMethodType::State; - AggState state(key_columns); - agg_method.init_serialized_keys(key_columns, num_rows); - - auto creator = [&](const auto& ctor, auto& key, auto& origin) { - HashMethodType::try_presis_key_and_origin( - key, origin, Base::_shared_state->agg_arena_pool); - AggregateDataPtr mapped = nullptr; - ctor(key, mapped); - }; - - auto creator_for_null_key = [&](auto& mapped) { mapped = nullptr; }; - - SCOPED_TIMER(_hash_table_emplace_timer); - lazy_emplace_batch(agg_method, state, num_rows, creator, - creator_for_null_key, [&](uint32_t, auto& mapped) { - ++reinterpret_cast(mapped); - }); - - COUNTER_UPDATE(_hash_table_input_counter, num_rows); - }}, - _agg_data->method_variant); -} - -void AggSinkLocalState::_merge_into_hash_table_inline_count(ColumnRawPtrs& 
key_columns, - const IColumn* merge_column, - uint32_t num_rows) { - std::visit( - Overload {[&](std::monostate& arg) -> void { - throw doris::Exception(ErrorCode::INTERNAL_ERROR, "uninited hash table"); - }, - [&](auto& agg_method) -> void { - SCOPED_TIMER(_hash_table_compute_timer); - using HashMethodType = std::decay_t; - using AggState = typename HashMethodType::State; - AggState state(key_columns); - agg_method.init_serialized_keys(key_columns, num_rows); - - const auto& col = - assert_cast(*merge_column); - const auto* col_data = - reinterpret_cast( - col.get_data().data()); - - auto creator = [&](const auto& ctor, auto& key, auto& origin) { - HashMethodType::try_presis_key_and_origin( - key, origin, Base::_shared_state->agg_arena_pool); - AggregateDataPtr mapped = nullptr; - ctor(key, mapped); - }; - - auto creator_for_null_key = [&](auto& mapped) { mapped = nullptr; }; - - SCOPED_TIMER(_hash_table_emplace_timer); - lazy_emplace_batch(agg_method, state, num_rows, creator, - creator_for_null_key, [&](uint32_t i, auto& mapped) { - reinterpret_cast(mapped) += - col_data[i].count; - }); - - COUNTER_UPDATE(_hash_table_input_counter, num_rows); - }}, - _agg_data->method_variant); } -bool AggSinkLocalState::_emplace_into_hash_table_limit(AggregateDataPtr* places, Block* block, - const std::vector& key_locs, - ColumnRawPtrs& key_columns, - uint32_t num_rows) { - return std::visit( - Overload {[&](std::monostate& arg) { - throw doris::Exception(ErrorCode::INTERNAL_ERROR, "uninited hash table"); - return true; - }, - [&](auto&& agg_method) -> bool { - SCOPED_TIMER(_hash_table_compute_timer); - using HashMethodType = std::decay_t; - using AggState = typename HashMethodType::State; +// Explicit template instantiations +template struct AggSinkLocalState::Executor; +template struct AggSinkLocalState::Executor; +template struct AggSinkLocalState::Executor; +template struct AggSinkLocalState::Executor; - bool need_filter = false; - { - 
SCOPED_TIMER(_hash_table_limit_compute_timer); - need_filter = - _shared_state->do_limit_filter(block, num_rows, &key_locs); - } - - auto& need_computes = _shared_state->need_computes; - if (auto need_agg = - std::find(need_computes.begin(), need_computes.end(), 1); - need_agg != need_computes.end()) { - if (need_filter) { - Block::filter_block_internal(block, need_computes); - for (int i = 0; i < key_locs.size(); ++i) { - key_columns[i] = - block->get_by_position(key_locs[i]).column.get(); - } - num_rows = (uint32_t)block->rows(); - } - - AggState state(key_columns); - agg_method.init_serialized_keys(key_columns, num_rows); - size_t i = 0; - - auto creator = [&](const auto& ctor, auto& key, auto& origin) { - try { - HashMethodType::try_presis_key_and_origin( - key, origin, Base::_shared_state->agg_arena_pool); - _shared_state->refresh_top_limit(i, key_columns); - auto mapped = - _shared_state->aggregate_data_container->append_data( - origin); - auto st = _create_agg_status(mapped); - if (!st) { - throw Exception(st.code(), st.to_string()); - } - ctor(key, mapped); - } catch (...) { - // Exception-safety - if it can not allocate memory or create status, - // the destructors will not be called. 
- ctor(key, nullptr); - throw; - } - }; - - auto creator_for_null_key = [&](auto& mapped) { - mapped = Base::_shared_state->agg_arena_pool.aligned_alloc( - Base::_parent->template cast() - ._total_size_of_aggregate_states, - Base::_parent->template cast() - ._align_aggregate_states); - auto st = _create_agg_status(mapped); - if (!st) { - throw Exception(st.code(), st.to_string()); - } - _shared_state->refresh_top_limit(i, key_columns); - }; - - SCOPED_TIMER(_hash_table_emplace_timer); - lazy_emplace_batch( - agg_method, state, num_rows, creator, creator_for_null_key, - [&](uint32_t row) { i = row; }, - [&](uint32_t row, auto& mapped) { places[row] = mapped; }); - COUNTER_UPDATE(_hash_table_input_counter, num_rows); - return true; - } - return false; - }}, - _agg_data->method_variant); -} - -void AggSinkLocalState::_find_in_hash_table(AggregateDataPtr* places, ColumnRawPtrs& key_columns, - uint32_t num_rows) { - std::visit(Overload {[&](std::monostate& arg) -> void { - throw doris::Exception(ErrorCode::INTERNAL_ERROR, - "uninited hash table"); - }, - [&](auto& agg_method) -> void { - using HashMethodType = std::decay_t; - using AggState = typename HashMethodType::State; - AggState state(key_columns); - agg_method.init_serialized_keys(key_columns, num_rows); - - /// For all rows. 
- find_batch(agg_method, state, num_rows, - [&](uint32_t row, auto& find_result) { - if (find_result.is_found()) { - places[row] = find_result.get_mapped(); - } else { - places[row] = nullptr; - } - }); - }}, - _agg_data->method_variant); -} - -Status AggSinkLocalState::_init_hash_method(const VExprContextSPtrs& probe_exprs) { - RETURN_IF_ERROR(init_hash_method( - _agg_data, get_data_types(probe_exprs), - Base::_parent->template cast()._is_first_phase)); - return Status::OK(); -} - -size_t AggSinkLocalState::get_reserve_mem_size(RuntimeState* state, bool eos) const { - size_t size_to_reserve = std::visit( - [&](auto&& arg) -> size_t { - using HashTableCtxType = std::decay_t; - if constexpr (std::is_same_v) { - return 0; - } else { - return arg.hash_table->estimate_memory(state->batch_size()); - } - }, - _agg_data->method_variant); - - size_to_reserve += _memory_usage_last_executing; - return size_to_reserve; +bool AggSinkLocalState::is_blockable() const { + if (auto* ctx = Base::_shared_state->groupby_agg_ctx.get()) { + return std::any_of(ctx->agg_evaluators().begin(), ctx->agg_evaluators().end(), + [](const AggFnEvaluator* evaluator) { return evaluator->is_blockable(); }); + } + if (auto* ctx = Base::_shared_state->ungroupby_agg_ctx.get()) { + return std::any_of(ctx->agg_evaluators().begin(), ctx->agg_evaluators().end(), + [](const AggFnEvaluator* evaluator) { return evaluator->is_blockable(); }); + } + return false; } // TODO: Tricky processing if `multi_distinct_` exists which will be re-planed by optimizer. 
@@ -987,12 +359,9 @@ Status AggSinkOperatorX::sink(doris::RuntimeState* state, Block* in_block, bool auto& local_state = get_local_state(state); SCOPED_TIMER(local_state.exec_time_counter()); COUNTER_UPDATE(local_state.rows_input_counter(), (int64_t)in_block->rows()); - local_state._shared_state->input_num_rows += in_block->rows(); if (in_block->rows() > 0) { RETURN_IF_ERROR(local_state._executor->execute(&local_state, in_block)); local_state._executor->update_memusage(&local_state); - COUNTER_SET(local_state._hash_table_size_counter, - (int64_t)local_state.get_hash_table_size()); } if (eos) { local_state._dependency->set_ready_to_read(); @@ -1002,26 +371,27 @@ Status AggSinkOperatorX::sink(doris::RuntimeState* state, Block* in_block, bool size_t AggSinkOperatorX::get_revocable_mem_size(RuntimeState* state) const { auto& local_state = get_local_state(state); - return local_state._memory_usage(); + auto* ctx = local_state.Base::_shared_state->groupby_agg_ctx.get(); + return ctx ? ctx->memory_usage() : 0; } Status AggSinkOperatorX::reset_hash_table(RuntimeState* state) { auto& local_state = get_local_state(state); - auto& ss = *local_state.Base::_shared_state; - RETURN_IF_ERROR(ss.reset_hash_table()); - local_state._serialize_key_arena_memory_usage->set((int64_t)0); - local_state.Base::_shared_state->agg_arena_pool.clear(true); - return Status::OK(); + auto* ctx = local_state.Base::_shared_state->groupby_agg_ctx.get(); + DCHECK(ctx); + return ctx->reset_hash_table(); } size_t AggSinkOperatorX::get_reserve_mem_size(RuntimeState* state, bool eos) { auto& local_state = get_local_state(state); - return local_state.get_reserve_mem_size(state, eos); + auto* ctx = local_state.Base::_shared_state->groupby_agg_ctx.get(); + return ctx ? 
ctx->get_reserve_mem_size(state) : 0; } size_t AggSinkOperatorX::get_hash_table_size(RuntimeState* state) const { auto& local_state = get_local_state(state); - return local_state.get_hash_table_size(); + auto* ctx = local_state.Base::_shared_state->groupby_agg_ctx.get(); + return ctx ? ctx->hash_table_size() : 0; } Status AggSinkLocalState::close(RuntimeState* state, Status exec_status) { @@ -1031,11 +401,6 @@ Status AggSinkLocalState::close(RuntimeState* state, Status exec_status) { return Status::OK(); } _preagg_block.clear(); - PODArray tmp_places; - _places.swap(tmp_places); - - std::vector tmp_deserialize_buffer; - _deserialize_buffer.swap(tmp_deserialize_buffer); return Base::close(state, exec_status); } diff --git a/be/src/exec/operator/aggregation_sink_operator.h b/be/src/exec/operator/aggregation_sink_operator.h index 0a7067ecb4130a..e4f4d8e5827b51 100644 --- a/be/src/exec/operator/aggregation_sink_operator.h +++ b/be/src/exec/operator/aggregation_sink_operator.h @@ -38,7 +38,6 @@ class AggSinkLocalState : public PipelineXSinkLocalState { Status open(RuntimeState* state) override; Status close(RuntimeState* state, Status exec_status) override; bool is_blockable() const override; - size_t get_hash_table_size() const; protected: friend class AggSinkOperatorX; @@ -50,87 +49,12 @@ class AggSinkLocalState : public PipelineXSinkLocalState { }; template struct Executor final : public ExecutorBase { - Status execute(AggSinkLocalState* local_state, Block* block) override { - if constexpr (WithoutKey) { - if constexpr (NeedToMerge) { - return local_state->_merge_without_key(block); - } else { - return local_state->_execute_without_key(block); - } - } else { - if constexpr (NeedToMerge) { - return local_state->_merge_with_serialized_key(block); - } else { - return local_state->_execute_with_serialized_key(block); - } - } - } - void update_memusage(AggSinkLocalState* local_state) override { - if constexpr (WithoutKey) { - local_state->_update_memusage_without_key(); - 
} else { - local_state->_update_memusage_with_serialized_key(); - } - } + Status execute(AggSinkLocalState* local_state, Block* block) override; + void update_memusage(AggSinkLocalState* local_state) override; }; - Status _execute_without_key(Block* block); - Status _merge_without_key(Block* block); - void _update_memusage_without_key(); - Status _init_hash_method(const VExprContextSPtrs& probe_exprs); - Status _execute_with_serialized_key(Block* block); - Status _merge_with_serialized_key(Block* block); - void _update_memusage_with_serialized_key(); - template - - Status _execute_with_serialized_key_helper(Block* block); - void _find_in_hash_table(AggregateDataPtr* places, ColumnRawPtrs& key_columns, - uint32_t num_rows); - void _emplace_into_hash_table(AggregateDataPtr* places, ColumnRawPtrs& key_columns, - uint32_t num_rows); - - void _emplace_into_hash_table_inline_count(ColumnRawPtrs& key_columns, uint32_t num_rows); - void _merge_into_hash_table_inline_count(ColumnRawPtrs& key_columns, - const IColumn* merge_column, uint32_t num_rows); - bool _emplace_into_hash_table_limit(AggregateDataPtr* places, Block* block, - const std::vector& key_locs, - ColumnRawPtrs& key_columns, uint32_t num_rows); - - template - Status _merge_with_serialized_key_helper(Block* block); - - Status _destroy_agg_status(AggregateDataPtr data); - Status _create_agg_status(AggregateDataPtr data); - size_t _memory_usage() const; - - size_t get_reserve_mem_size(RuntimeState* state, bool eos) const; - - RuntimeProfile::Counter* _hash_table_compute_timer = nullptr; - RuntimeProfile::Counter* _hash_table_emplace_timer = nullptr; - RuntimeProfile::Counter* _hash_table_limit_compute_timer = nullptr; - RuntimeProfile::Counter* _hash_table_input_counter = nullptr; - RuntimeProfile::Counter* _build_timer = nullptr; - RuntimeProfile::Counter* _expr_timer = nullptr; - RuntimeProfile::Counter* _merge_timer = nullptr; - RuntimeProfile::Counter* _deserialize_data_timer = nullptr; - 
RuntimeProfile::Counter* _hash_table_memory_usage = nullptr; - RuntimeProfile::Counter* _hash_table_size_counter = nullptr; - RuntimeProfile::Counter* _serialize_key_arena_memory_usage = nullptr; - RuntimeProfile::Counter* _memory_usage_container = nullptr; - RuntimeProfile::Counter* _memory_usage_arena = nullptr; - - bool _should_limit_output = false; - - PODArray _places; - std::vector _deserialize_buffer; - Block _preagg_block; - - AggregatedDataVariants* _agg_data = nullptr; - std::unique_ptr _executor = nullptr; - - int64_t _memory_usage_last_executing = 0; }; class AggSinkOperatorX MOCK_REMOVE(final) : public DataSinkOperatorX { @@ -176,7 +100,7 @@ class AggSinkOperatorX MOCK_REMOVE(final) : public DataSinkOperatorXgroupby_agg_ctx->hash_table_data(); } Status reset_hash_table(RuntimeState* state); diff --git a/be/src/exec/operator/aggregation_source_operator.cpp b/be/src/exec/operator/aggregation_source_operator.cpp index faeeb1a93d2c27..fb739b3635f398 100644 --- a/be/src/exec/operator/aggregation_source_operator.cpp +++ b/be/src/exec/operator/aggregation_source_operator.cpp @@ -21,12 +21,12 @@ #include #include "common/exception.h" -#include "core/column/column_fixed_length_object.h" +#include "exec/common/groupby_agg_context.h" +#include "exec/common/ungroupby_agg_context.h" #include "exec/operator/operator.h" #include "exprs/vectorized_agg_fn.h" -#include "exprs/vexpr_fwd.h" +#include "runtime/descriptors.h" #include "runtime/runtime_profile.h" -#include "runtime/thread_context.h" namespace doris { #include "common/compile_check_begin.h" @@ -37,44 +37,38 @@ Status AggLocalState::init(RuntimeState* state, LocalStateInfo& info) { RETURN_IF_ERROR(Base::init(state, info)); SCOPED_TIMER(exec_time_counter()); SCOPED_TIMER(_init_timer); - _get_results_timer = ADD_TIMER(custom_profile(), "GetResultsTime"); - _hash_table_iterate_timer = ADD_TIMER(custom_profile(), "HashTableIterateTime"); - _insert_keys_to_column_timer = ADD_TIMER(custom_profile(), 
"InsertKeysToColumnTime"); - _insert_values_to_column_timer = ADD_TIMER(custom_profile(), "InsertValuesToColumnTime"); - _merge_timer = ADD_TIMER(Base::custom_profile(), "MergeTime"); - _deserialize_data_timer = ADD_TIMER(Base::custom_profile(), "DeserializeAndMergeTime"); - _hash_table_compute_timer = ADD_TIMER(Base::custom_profile(), "HashTableComputeTime"); - _hash_table_emplace_timer = ADD_TIMER(Base::custom_profile(), "HashTableEmplaceTime"); - _hash_table_input_counter = - ADD_COUNTER_WITH_LEVEL(Base::custom_profile(), "HashTableInputCount", TUnit::UNIT, 1); - _hash_table_memory_usage = - ADD_COUNTER_WITH_LEVEL(Base::custom_profile(), "MemoryUsageHashTable", TUnit::BYTES, 1); - _hash_table_size_counter = - ADD_COUNTER_WITH_LEVEL(Base::custom_profile(), "HashTableSize", TUnit::UNIT, 1); - - _memory_usage_container = ADD_COUNTER(custom_profile(), "MemoryUsageContainer", TUnit::BYTES); - _memory_usage_arena = ADD_COUNTER(custom_profile(), "MemoryUsageArena", TUnit::BYTES); + // Init source-side profile counters on the groupby context if present. 
+ if (_shared_state->groupby_agg_ctx) { + _shared_state->groupby_agg_ctx->init_source_profile(custom_profile()); + } auto& p = _parent->template cast(); if (p._without_key) { if (p._needs_finalize) { _executor.get_result = [this](RuntimeState* state, Block* block, bool* eos) { - return _get_without_key_result(state, block, eos); + return _shared_state->ungroupby_agg_ctx->get_finalized_result( + state, block, eos, + _parent->cast().row_descriptor()); }; } else { _executor.get_result = [this](RuntimeState* state, Block* block, bool* eos) { - return _get_results_without_key(state, block, eos); + return _shared_state->ungroupby_agg_ctx->get_serialized_result(state, block, eos); }; } } else { if (p._needs_finalize) { - _executor.get_result = [this](RuntimeState* state, Block* block, bool* eos) { - return _get_with_serialized_key_result(state, block, eos); + auto columns_with_schema = VectorizedUtils::create_columns_with_type_and_name( + p.row_descriptor()); + _executor.get_result = [this, cols = std::move(columns_with_schema)]( + RuntimeState* state, Block* block, bool* eos) { + return _shared_state->groupby_agg_ctx->get_finalized_results( + state, block, eos, cols); }; } else { _executor.get_result = [this](RuntimeState* state, Block* block, bool* eos) { - return _get_results_with_serialized_key(state, block, eos); + return _shared_state->groupby_agg_ctx->get_serialized_results( + state, block, eos); }; } } @@ -82,463 +76,6 @@ Status AggLocalState::init(RuntimeState* state, LocalStateInfo& info) { return Status::OK(); } -Status AggLocalState::_create_agg_status(AggregateDataPtr data) { - auto& shared_state = *Base::_shared_state; - for (int i = 0; i < shared_state.aggregate_evaluators.size(); ++i) { - try { - shared_state.aggregate_evaluators[i]->create( - data + shared_state.offsets_of_aggregate_states[i]); - } catch (...) 
{ - for (int j = 0; j < i; ++j) { - shared_state.aggregate_evaluators[j]->destroy( - data + shared_state.offsets_of_aggregate_states[j]); - } - throw; - } - } - return Status::OK(); -} - -Status AggLocalState::_get_results_with_serialized_key(RuntimeState* state, Block* block, - bool* eos) { - SCOPED_TIMER(_get_results_timer); - auto& shared_state = *_shared_state; - size_t key_size = _shared_state->probe_expr_ctxs.size(); - size_t agg_size = _shared_state->aggregate_evaluators.size(); - MutableColumns value_columns(agg_size); - DataTypes value_data_types(agg_size); - - // non-nullable column(id in `_make_nullable_keys`) will be converted to nullable. - bool mem_reuse = shared_state.make_nullable_keys.empty() && block->mem_reuse(); - - MutableColumns key_columns; - for (int i = 0; i < key_size; ++i) { - if (mem_reuse) { - key_columns.emplace_back(std::move(*block->get_by_position(i).column).mutate()); - } else { - key_columns.emplace_back( - shared_state.probe_expr_ctxs[i]->root()->data_type()->create_column()); - } - } - - std::visit( - Overload { - [&](std::monostate& arg) -> void { - throw doris::Exception(ErrorCode::INTERNAL_ERROR, "uninited hash table"); - }, - [&](auto& agg_method) -> void { - agg_method.init_iterator(); - auto& data = *agg_method.hash_table; - const auto size = std::min(data.size(), size_t(state->batch_size())); - using KeyType = std::decay_t::Key; - std::vector keys(size); - - if (shared_state.use_simple_count) { - DCHECK_EQ(shared_state.aggregate_evaluators.size(), 1); - - value_data_types[0] = shared_state.aggregate_evaluators[0] - ->function() - ->get_serialized_type(); - if (mem_reuse) { - value_columns[0] = - std::move(*block->get_by_position(key_size).column) - .mutate(); - } else { - value_columns[0] = shared_state.aggregate_evaluators[0] - ->function() - ->create_serialize_column(); - } - - auto& count_col = - assert_cast(*value_columns[0]); - uint32_t num_rows = 0; - { - SCOPED_TIMER(_hash_table_iterate_timer); - auto& it = 
agg_method.begin; - while (it != agg_method.end && num_rows < state->batch_size()) { - keys[num_rows] = it.get_first(); - auto inline_count = - reinterpret_cast(it.get_second()); - count_col.insert_data( - reinterpret_cast(&inline_count), - sizeof(UInt64)); - ++it; - ++num_rows; - } - } - - { - SCOPED_TIMER(_insert_keys_to_column_timer); - agg_method.insert_keys_into_columns(keys, key_columns, num_rows); - } - - // Handle null key if present - if (agg_method.begin == agg_method.end) { - if (agg_method.hash_table->has_null_key_data()) { - DCHECK(key_columns.size() == 1); - DCHECK(key_columns[0]->is_nullable()); - if (num_rows < state->batch_size()) { - key_columns[0]->insert_data(nullptr, 0); - auto mapped = - agg_method.hash_table->template get_null_key_data< - AggregateDataPtr>(); - count_col.resize(num_rows + 1); - *reinterpret_cast(count_col.get_data().data() + - num_rows * sizeof(UInt64)) = - std::bit_cast(mapped); - *eos = true; - } - } else { - *eos = true; - } - } - return; - } - - if (shared_state.values.size() < size + 1) { - shared_state.values.resize(size + 1); - } - - uint32_t num_rows = 0; - shared_state.aggregate_data_container->init_once(); - auto& iter = shared_state.aggregate_data_container->iterator; - - { - SCOPED_TIMER(_hash_table_iterate_timer); - while (iter != shared_state.aggregate_data_container->end() && - num_rows < state->batch_size()) { - keys[num_rows] = iter.template get_key(); - shared_state.values[num_rows] = iter.get_aggregate_data(); - ++iter; - ++num_rows; - } - } - - { - SCOPED_TIMER(_insert_keys_to_column_timer); - agg_method.insert_keys_into_columns(keys, key_columns, num_rows); - } - - if (iter == shared_state.aggregate_data_container->end()) { - if (agg_method.hash_table->has_null_key_data()) { - // only one key of group by support wrap null key - // here need additional processing logic on the null key / value - DCHECK(key_columns.size() == 1); - DCHECK(key_columns[0]->is_nullable()); - if 
(agg_method.hash_table->has_null_key_data()) { - key_columns[0]->insert_data(nullptr, 0); - shared_state.values[num_rows] = - agg_method.hash_table->template get_null_key_data< - AggregateDataPtr>(); - ++num_rows; - *eos = true; - } - } else { - *eos = true; - } - } - - { - SCOPED_TIMER(_insert_values_to_column_timer); - for (size_t i = 0; i < shared_state.aggregate_evaluators.size(); ++i) { - value_data_types[i] = shared_state.aggregate_evaluators[i] - ->function() - ->get_serialized_type(); - if (mem_reuse) { - value_columns[i] = - std::move(*block->get_by_position(i + key_size).column) - .mutate(); - } else { - value_columns[i] = shared_state.aggregate_evaluators[i] - ->function() - ->create_serialize_column(); - } - shared_state.aggregate_evaluators[i] - ->function() - ->serialize_to_column( - shared_state.values, - shared_state.offsets_of_aggregate_states[i], - value_columns[i], num_rows); - } - } - }}, - shared_state.agg_data->method_variant); - - if (!mem_reuse) { - ColumnsWithTypeAndName columns_with_schema; - for (int i = 0; i < key_size; ++i) { - columns_with_schema.emplace_back(std::move(key_columns[i]), - shared_state.probe_expr_ctxs[i]->root()->data_type(), - shared_state.probe_expr_ctxs[i]->root()->expr_name()); - } - for (int i = 0; i < agg_size; ++i) { - columns_with_schema.emplace_back(std::move(value_columns[i]), value_data_types[i], ""); - } - *block = Block(columns_with_schema); - } - - return Status::OK(); -} - -Status AggLocalState::_get_with_serialized_key_result(RuntimeState* state, Block* block, - bool* eos) { - auto& shared_state = *_shared_state; - // non-nullable column(id in `_make_nullable_keys`) will be converted to nullable. 
- bool mem_reuse = shared_state.make_nullable_keys.empty() && block->mem_reuse(); - - auto columns_with_schema = VectorizedUtils::create_columns_with_type_and_name( - _parent->cast().row_descriptor()); - size_t key_size = shared_state.probe_expr_ctxs.size(); - - MutableColumns key_columns; - for (int i = 0; i < key_size; ++i) { - if (!mem_reuse) { - key_columns.emplace_back(columns_with_schema[i].type->create_column()); - } else { - key_columns.emplace_back(std::move(*block->get_by_position(i).column).mutate()); - } - } - MutableColumns value_columns; - for (size_t i = key_size; i < columns_with_schema.size(); ++i) { - if (!mem_reuse) { - value_columns.emplace_back(columns_with_schema[i].type->create_column()); - } else { - value_columns.emplace_back(std::move(*block->get_by_position(i).column).mutate()); - } - } - - SCOPED_TIMER(_get_results_timer); - std::visit( - Overload { - [&](std::monostate& arg) -> void { - throw doris::Exception(ErrorCode::INTERNAL_ERROR, "uninited hash table"); - }, - [&](auto& agg_method) -> void { - auto& data = *agg_method.hash_table; - agg_method.init_iterator(); - const auto size = std::min(data.size(), size_t(state->batch_size())); - using KeyType = std::decay_t::Key; - std::vector keys(size); - - if (shared_state.use_simple_count) { - // Inline count: mapped slot stores UInt64 count directly - // (not a real AggregateDataPtr). Iterate hash table directly. 
- DCHECK_EQ(value_columns.size(), 1); - auto& count_column = assert_cast(*value_columns[0]); - uint32_t num_rows = 0; - { - SCOPED_TIMER(_hash_table_iterate_timer); - auto& it = agg_method.begin; - while (it != agg_method.end && num_rows < state->batch_size()) { - keys[num_rows] = it.get_first(); - auto& mapped = it.get_second(); - count_column.insert_value(static_cast( - reinterpret_cast(mapped))); - ++it; - ++num_rows; - } - } - { - SCOPED_TIMER(_insert_keys_to_column_timer); - agg_method.insert_keys_into_columns(keys, key_columns, num_rows); - } - - // Handle null key if present - if (agg_method.begin == agg_method.end) { - if (agg_method.hash_table->has_null_key_data()) { - DCHECK(key_columns.size() == 1); - DCHECK(key_columns[0]->is_nullable()); - if (key_columns[0]->size() < state->batch_size()) { - key_columns[0]->insert_data(nullptr, 0); - auto mapped = - agg_method.hash_table->template get_null_key_data< - AggregateDataPtr>(); - count_column.insert_value( - static_cast(std::bit_cast(mapped))); - *eos = true; - } - } else { - *eos = true; - } - } - return; - } - - // Normal (non-simple-count) path - if (shared_state.values.size() < size) { - shared_state.values.resize(size); - } - - uint32_t num_rows = 0; - shared_state.aggregate_data_container->init_once(); - auto& iter = shared_state.aggregate_data_container->iterator; - - { - SCOPED_TIMER(_hash_table_iterate_timer); - while (iter != shared_state.aggregate_data_container->end() && - num_rows < state->batch_size()) { - keys[num_rows] = iter.template get_key(); - shared_state.values[num_rows] = iter.get_aggregate_data(); - ++iter; - ++num_rows; - } - } - - { - SCOPED_TIMER(_insert_keys_to_column_timer); - agg_method.insert_keys_into_columns(keys, key_columns, num_rows); - } - - for (size_t i = 0; i < shared_state.aggregate_evaluators.size(); ++i) { - shared_state.aggregate_evaluators[i]->insert_result_info_vec( - shared_state.values, - shared_state.offsets_of_aggregate_states[i], - value_columns[i].get(), 
num_rows); - } - - if (iter == shared_state.aggregate_data_container->end()) { - if (agg_method.hash_table->has_null_key_data()) { - // only one key of group by support wrap null key - // here need additional processing logic on the null key / value - DCHECK(key_columns.size() == 1); - DCHECK(key_columns[0]->is_nullable()); - if (key_columns[0]->size() < state->batch_size()) { - key_columns[0]->insert_data(nullptr, 0); - auto mapped = agg_method.hash_table->template get_null_key_data< - AggregateDataPtr>(); - for (size_t i = 0; i < shared_state.aggregate_evaluators.size(); - ++i) - shared_state.aggregate_evaluators[i]->insert_result_info( - mapped + - shared_state.offsets_of_aggregate_states[i], - value_columns[i].get()); - *eos = true; - } - } else { - *eos = true; - } - } - }}, - shared_state.agg_data->method_variant); - - if (!mem_reuse) { - *block = columns_with_schema; - MutableColumns columns(block->columns()); - for (int i = 0; i < block->columns(); ++i) { - if (i < key_size) { - columns[i] = std::move(key_columns[i]); - } else { - columns[i] = std::move(value_columns[i - key_size]); - } - } - block->set_columns(std::move(columns)); - } - - return Status::OK(); -} - -Status AggLocalState::_get_results_without_key(RuntimeState* state, Block* block, bool* eos) { - SCOPED_TIMER(_get_results_timer); - auto& shared_state = *_shared_state; - // 1. 
`child(0)->rows_returned() == 0` mean not data from child - // in level two aggregation node should return NULL result - // level one aggregation node set `eos = true` return directly - if (UNLIKELY(_shared_state->input_num_rows == 0)) { - *eos = true; - return Status::OK(); - } - block->clear(); - - DCHECK(shared_state.agg_data->without_key != nullptr); - size_t agg_size = shared_state.aggregate_evaluators.size(); - - MutableColumns value_columns(agg_size); - std::vector data_types(agg_size); - // will serialize data to string column - for (int i = 0; i < shared_state.aggregate_evaluators.size(); ++i) { - data_types[i] = shared_state.aggregate_evaluators[i]->function()->get_serialized_type(); - value_columns[i] = - shared_state.aggregate_evaluators[i]->function()->create_serialize_column(); - } - - for (int i = 0; i < shared_state.aggregate_evaluators.size(); ++i) { - shared_state.aggregate_evaluators[i]->function()->serialize_without_key_to_column( - shared_state.agg_data->without_key + shared_state.offsets_of_aggregate_states[i], - *value_columns[i]); - } - - { - ColumnsWithTypeAndName data_with_schema; - for (int i = 0; i < shared_state.aggregate_evaluators.size(); ++i) { - ColumnWithTypeAndName column_with_schema = {nullptr, data_types[i], ""}; - data_with_schema.push_back(std::move(column_with_schema)); - } - *block = Block(data_with_schema); - } - - block->set_columns(std::move(value_columns)); - *eos = true; - return Status::OK(); -} - -Status AggLocalState::_get_without_key_result(RuntimeState* state, Block* block, bool* eos) { - auto& shared_state = *_shared_state; - DCHECK(_shared_state->agg_data->without_key != nullptr); - block->clear(); - - auto& p = _parent->cast(); - *block = VectorizedUtils::create_empty_columnswithtypename(p.row_descriptor()); - size_t agg_size = shared_state.aggregate_evaluators.size(); - - MutableColumns columns(agg_size); - std::vector data_types(agg_size); - for (int i = 0; i < shared_state.aggregate_evaluators.size(); ++i) { 
- data_types[i] = shared_state.aggregate_evaluators[i]->function()->get_return_type(); - columns[i] = data_types[i]->create_column(); - } - - for (int i = 0; i < shared_state.aggregate_evaluators.size(); ++i) { - auto column = columns[i].get(); - shared_state.aggregate_evaluators[i]->insert_result_info( - shared_state.agg_data->without_key + shared_state.offsets_of_aggregate_states[i], - column); - } - - const auto& block_schema = block->get_columns_with_type_and_name(); - DCHECK_EQ(block_schema.size(), columns.size()); - for (int i = 0; i < block_schema.size(); ++i) { - const auto column_type = block_schema[i].type; - if (!column_type->equals(*data_types[i])) { - if (column_type->get_primitive_type() != TYPE_ARRAY) { - if (!column_type->is_nullable() || data_types[i]->is_nullable() || - !remove_nullable(column_type)->equals(*data_types[i])) { - return Status::InternalError( - "node id = {}, column_type not match data_types, column_type={}, " - "data_types={}", - _parent->node_id(), column_type->get_name(), data_types[i]->get_name()); - } - } - - // Result of operator is nullable, but aggregate function result is not nullable - // this happens when: - // 1. no group by - // 2. input of aggregate function is empty - // 3. 
all of input columns are not nullable - if (column_type->is_nullable() && !data_types[i]->is_nullable()) { - ColumnPtr ptr = std::move(columns[i]); - // unless `count`, other aggregate function dispose empty set should be null - // so here check the children row return - ptr = make_nullable(ptr, shared_state.input_num_rows == 0); - columns[i] = ptr->assume_mutable(); - } - } - } - - block->set_columns(std::move(columns)); - *eos = true; - return Status::OK(); -} - AggSourceOperatorX::AggSourceOperatorX(ObjectPool* pool, const TPlanNode& tnode, int operator_id, const DescriptorTbl& descs) : Base(pool, tnode, operator_id, descs), @@ -558,11 +95,26 @@ Status AggSourceOperatorX::get_block(RuntimeState* state, Block* block, bool* eo } void AggLocalState::do_agg_limit(Block* block, bool* eos) { - if (_shared_state->reach_limit) { - if (_shared_state->do_sort_limit && _shared_state->do_limit_filter(block, block->rows())) { - Block::filter_block_internal(block, _shared_state->need_computes); - if (auto rows = block->rows()) { - _num_rows_returned += rows; + auto* ctx = _shared_state->groupby_agg_ctx.get(); + if (!ctx) { + // without key, no limit/sort-limit support + if (auto rows = block->rows()) { + _num_rows_returned += rows; + } + return; + } + if (ctx->reach_limit) { + if (ctx->do_sort_limit) { + const size_t key_size = ctx->groupby_expr_ctxs().size(); + ColumnRawPtrs key_columns(key_size); + for (size_t i = 0; i < key_size; ++i) { + key_columns[i] = block->get_by_position(i).column.get(); + } + if (ctx->do_limit_filter(block->rows(), key_columns)) { + Block::filter_block_internal(block, ctx->need_computes()); + if (auto rows = block->rows()) { + _num_rows_returned += rows; + } } } else { reached_limit(block, eos); @@ -584,43 +136,9 @@ void AggLocalState::make_nullable_output_key(Block* block) { } Status AggLocalState::merge_with_serialized_key_helper(Block* block) { - SCOPED_TIMER(_merge_timer); - SCOPED_PEAK_MEM(&_estimate_memory_usage); - - size_t key_size = 
Base::_shared_state->probe_expr_ctxs.size(); - ColumnRawPtrs key_columns(key_size); - - for (size_t i = 0; i < key_size; ++i) { - key_columns[i] = block->get_by_position(i).column.get(); - } - - uint32_t rows = (uint32_t)block->rows(); - if (_places.size() < rows) { - _places.resize(rows); - } - - _emplace_into_hash_table(_places.data(), key_columns, rows); - - for (int i = 0; i < Base::_shared_state->aggregate_evaluators.size(); ++i) { - auto col_id = Base::_shared_state->probe_expr_ctxs.size() + i; - auto column = block->get_by_position(col_id).column; - - size_t buffer_size = - Base::_shared_state->aggregate_evaluators[i]->function()->size_of_data() * rows; - if (_deserialize_buffer.size() < buffer_size) { - _deserialize_buffer.resize(buffer_size); - } - - { - SCOPED_TIMER(_deserialize_data_timer); - Base::_shared_state->aggregate_evaluators[i]->function()->deserialize_and_merge_vec( - _places.data(), _shared_state->offsets_of_aggregate_states[i], - _deserialize_buffer.data(), column.get(), Base::_shared_state->agg_arena_pool, - rows); - } - } - - return Status::OK(); + auto* ctx = _shared_state->groupby_agg_ctx.get(); + DCHECK(ctx); + return ctx->merge_with_serialized_key_for_spill(block); } Status AggSourceOperatorX::merge_with_serialized_key_helper(RuntimeState* state, Block* block) { @@ -631,88 +149,29 @@ Status AggSourceOperatorX::merge_with_serialized_key_helper(RuntimeState* state, size_t AggSourceOperatorX::get_estimated_memory_size_for_merging(RuntimeState* state, size_t rows) const { auto& local_state = get_local_state(state); + auto* ctx = local_state._shared_state->groupby_agg_ctx.get(); + DCHECK(ctx); size_t size = std::visit( Overload { - [&](std::monostate& arg) -> size_t { - throw doris::Exception(ErrorCode::INTERNAL_ERROR, "uninited hash table"); - return 0; - }, + [&](std::monostate& arg) -> size_t { return 0; }, [&](auto& agg_method) { return agg_method.hash_table->estimate_memory(rows); }}, - 
local_state._shared_state->agg_data->method_variant); - size += local_state._shared_state->aggregate_data_container->estimate_memory(rows); + ctx->hash_table_data()->method_variant); + size += ctx->agg_data_container()->estimate_memory(rows); return size; } Status AggSourceOperatorX::reset_hash_table(RuntimeState* state) { auto& local_state = get_local_state(state); - auto& ss = *local_state.Base::_shared_state; - RETURN_IF_ERROR(ss.reset_hash_table()); - ss.agg_arena_pool.clear(true); - return Status::OK(); + auto* ctx = local_state._shared_state->groupby_agg_ctx.get(); + DCHECK(ctx); + return ctx->reset_hash_table(); } Status AggSourceOperatorX::get_serialized_block(RuntimeState* state, Block* block, bool* eos) { auto& local_state = get_local_state(state); - // Always use the serialized intermediate output path, regardless of _needs_finalize. - return local_state._get_results_with_serialized_key(state, block, eos); -} - -void AggLocalState::_emplace_into_hash_table(AggregateDataPtr* places, ColumnRawPtrs& key_columns, - uint32_t num_rows) { - std::visit( - Overload {[&](std::monostate& arg) -> void { - throw doris::Exception(ErrorCode::INTERNAL_ERROR, "uninited hash table"); - }, - [&](auto& agg_method) -> void { - SCOPED_TIMER(_hash_table_compute_timer); - using HashMethodType = std::decay_t; - using AggState = typename HashMethodType::State; - AggState state(key_columns); - agg_method.init_serialized_keys(key_columns, num_rows); - - auto creator = [this](const auto& ctor, auto& key, auto& origin) { - HashMethodType::try_presis_key_and_origin( - key, origin, Base::_shared_state->agg_arena_pool); - auto mapped = - Base::_shared_state->aggregate_data_container->append_data( - origin); - auto st = _create_agg_status(mapped); - if (!st) { - throw Exception(st.code(), st.to_string()); - } - ctor(key, mapped); - }; - - auto creator_for_null_key = [&](auto& mapped) { - mapped = Base::_shared_state->agg_arena_pool.aligned_alloc( - 
_shared_state->total_size_of_aggregate_states, - _shared_state->align_aggregate_states); - auto st = _create_agg_status(mapped); - if (!st) { - throw Exception(st.code(), st.to_string()); - } - }; - - SCOPED_TIMER(_hash_table_emplace_timer); - lazy_emplace_batch( - agg_method, state, num_rows, creator, creator_for_null_key, - [&](uint32_t row, auto& mapped) { places[row] = mapped; }); - - COUNTER_UPDATE(_hash_table_input_counter, num_rows); - COUNTER_SET(_hash_table_memory_usage, - static_cast( - agg_method.hash_table->get_buffer_size_in_bytes())); - COUNTER_SET(_hash_table_size_counter, - static_cast(agg_method.hash_table->size())); - COUNTER_SET( - _memory_usage_container, - static_cast( - _shared_state->aggregate_data_container->memory_usage())); - COUNTER_SET( - _memory_usage_arena, - static_cast(Base::_shared_state->agg_arena_pool.size())); - }}, - _shared_state->agg_data->method_variant); + auto* ctx = local_state._shared_state->groupby_agg_ctx.get(); + DCHECK(ctx); + return ctx->get_serialized_results(state, block, eos); } Status AggLocalState::close(RuntimeState* state) { @@ -721,12 +180,6 @@ Status AggLocalState::close(RuntimeState* state) { if (_closed) { return Status::OK(); } - - PODArray tmp_places; - _places.swap(tmp_places); - - std::vector tmp_deserialize_buffer; - _deserialize_buffer.swap(tmp_deserialize_buffer); return Base::close(state); } diff --git a/be/src/exec/operator/aggregation_source_operator.h b/be/src/exec/operator/aggregation_source_operator.h index e6443b2a77e3fd..95e4d7b7799b79 100644 --- a/be/src/exec/operator/aggregation_source_operator.h +++ b/be/src/exec/operator/aggregation_source_operator.h @@ -45,43 +45,6 @@ class AggLocalState MOCK_REMOVE(final) : public PipelineXLocalStaterows() != 0) { - auto& shared_state = *Base ::_shared_state; - for (auto cid : shared_state.make_nullable_keys) { - block->get_by_position(cid).column = - make_nullable(block->get_by_position(cid).column); - block->get_by_position(cid).type = 
make_nullable(block->get_by_position(cid).type); - } - } - } - - void _emplace_into_hash_table(AggregateDataPtr* places, ColumnRawPtrs& key_columns, - uint32_t num_rows); - - PODArray _places; - std::vector _deserialize_buffer; - - RuntimeProfile::Counter* _get_results_timer = nullptr; - RuntimeProfile::Counter* _hash_table_iterate_timer = nullptr; - RuntimeProfile::Counter* _insert_keys_to_column_timer = nullptr; - RuntimeProfile::Counter* _insert_values_to_column_timer = nullptr; - - RuntimeProfile::Counter* _hash_table_compute_timer = nullptr; - RuntimeProfile::Counter* _hash_table_emplace_timer = nullptr; - RuntimeProfile::Counter* _hash_table_input_counter = nullptr; - RuntimeProfile::Counter* _hash_table_size_counter = nullptr; - RuntimeProfile::Counter* _hash_table_memory_usage = nullptr; - RuntimeProfile::Counter* _merge_timer = nullptr; - RuntimeProfile::Counter* _deserialize_data_timer = nullptr; - RuntimeProfile::Counter* _memory_usage_container = nullptr; - RuntimeProfile::Counter* _memory_usage_arena = nullptr; - using vectorized_get_result = std::function; @@ -125,10 +88,6 @@ class AggSourceOperatorX : public OperatorX { bool _needs_finalize; bool _without_key; - - // left / full join will change the key nullable make output/input solt - // nullable diff. so we need make nullable of it. 
- std::vector _make_nullable_keys; }; } // namespace doris diff --git a/be/src/exec/operator/partitioned_aggregation_sink_operator.cpp b/be/src/exec/operator/partitioned_aggregation_sink_operator.cpp index 82d29ca4bb84a2..f857e0fc24dbfc 100644 --- a/be/src/exec/operator/partitioned_aggregation_sink_operator.cpp +++ b/be/src/exec/operator/partitioned_aggregation_sink_operator.cpp @@ -52,11 +52,12 @@ Status PartitionedAggSinkLocalState::init(doris::RuntimeState* state, _spill_writers.resize(parent._partition_count); RETURN_IF_ERROR(_setup_in_memory_agg_op(state)); - for (const auto& probe_expr_ctx : Base::_shared_state->_in_mem_shared_state->probe_expr_ctxs) { + for (const auto& probe_expr_ctx : + Base::_shared_state->_in_mem_shared_state->groupby_agg_ctx->groupby_expr_ctxs()) { _key_columns.emplace_back(probe_expr_ctx->root()->data_type()->create_column()); } for (const auto& aggregate_evaluator : - Base::_shared_state->_in_mem_shared_state->aggregate_evaluators) { + Base::_shared_state->_in_mem_shared_state->groupby_agg_ctx->agg_evaluators()) { _value_data_types.emplace_back(aggregate_evaluator->function()->get_serialized_type()); _value_columns.emplace_back(aggregate_evaluator->function()->create_serialize_column()); } @@ -165,8 +166,8 @@ Status PartitionedAggSinkOperatorX::sink(doris::RuntimeState* state, Block* in_b if (local_state._shared_state->_is_spilled) { if (revocable_mem_size(state) >= state->spill_aggregation_sink_mem_limit_bytes()) { RETURN_IF_ERROR(revoke_memory(state)); - DCHECK(local_state._shared_state->_in_mem_shared_state->aggregate_data_container - ->total_count() == 0); + DCHECK(local_state._shared_state->_in_mem_shared_state->groupby_agg_ctx + ->agg_data_container()->total_count() == 0); } } else { auto* sink_local_state = local_state._runtime_state->get_sink_local_state(); @@ -179,8 +180,8 @@ Status PartitionedAggSinkOperatorX::sink(doris::RuntimeState* state, Block* in_b // If there are still memory aggregation data, revoke memory, it is a 
flush operation. if (_agg_sink_operator->get_hash_table_size(runtime_state) > 0) { RETURN_IF_ERROR(revoke_memory(state)); - DCHECK(local_state._shared_state->_in_mem_shared_state->aggregate_data_container - ->total_count() == 0); + DCHECK(local_state._shared_state->_in_mem_shared_state->groupby_agg_ctx + ->agg_data_container()->total_count() == 0); } // Close all writers (finalizes SpillFile metadata) for (auto& writer : local_state._spill_writers) { @@ -271,13 +272,15 @@ Status PartitionedAggSinkLocalState::_to_block(HashTableCtxType& context, values.emplace_back(null_key_data); } - for (size_t i = 0; i < Base::_shared_state->_in_mem_shared_state->aggregate_evaluators.size(); + for (size_t i = 0; + i < Base::_shared_state->_in_mem_shared_state->groupby_agg_ctx->agg_evaluators().size(); ++i) { - Base::_shared_state->_in_mem_shared_state->aggregate_evaluators[i] + Base::_shared_state->_in_mem_shared_state->groupby_agg_ctx->agg_evaluators()[i] ->function() ->serialize_to_column( values, - Base::_shared_state->_in_mem_shared_state->offsets_of_aggregate_states[i], + Base::_shared_state->_in_mem_shared_state->groupby_agg_ctx + ->agg_state_offsets()[i], _value_columns[i], values.size()); } @@ -285,8 +288,14 @@ Status PartitionedAggSinkLocalState::_to_block(HashTableCtxType& context, for (int i = 0; i < _key_columns.size(); ++i) { key_columns_with_schema.emplace_back( std::move(_key_columns[i]), - Base::_shared_state->_in_mem_shared_state->probe_expr_ctxs[i]->root()->data_type(), - Base::_shared_state->_in_mem_shared_state->probe_expr_ctxs[i]->root()->expr_name()); + Base::_shared_state->_in_mem_shared_state->groupby_agg_ctx + ->groupby_expr_ctxs()[i] + ->root() + ->data_type(), + Base::_shared_state->_in_mem_shared_state->groupby_agg_ctx + ->groupby_expr_ctxs()[i] + ->root() + ->expr_name()); } _key_block = key_columns_with_schema; @@ -294,7 +303,7 @@ Status PartitionedAggSinkLocalState::_to_block(HashTableCtxType& context, for (int i = 0; i < _value_columns.size(); ++i) 
{ value_columns_with_schema.emplace_back( std::move(_value_columns[i]), _value_data_types[i], - Base::_shared_state->_in_mem_shared_state->aggregate_evaluators[i] + Base::_shared_state->_in_mem_shared_state->groupby_agg_ctx->agg_evaluators()[i] ->function() ->get_name()); } @@ -361,7 +370,7 @@ Status PartitionedAggSinkLocalState::_spill_hash_table(RuntimeState* state, context.init_iterator(); auto& parent = _parent->template cast(); - Base::_shared_state->_in_mem_shared_state->aggregate_data_container->init_once(); + Base::_shared_state->_in_mem_shared_state->groupby_agg_ctx->agg_data_container()->init_once(); const auto total_rows = parent._agg_sink_operator->get_hash_table_size(_runtime_state.get()); @@ -386,8 +395,12 @@ Status PartitionedAggSinkLocalState::_spill_hash_table(RuntimeState* state, std::vector> spill_infos( parent._partition_count); - auto& iter = Base::_shared_state->_in_mem_shared_state->aggregate_data_container->iterator; - while (iter != Base::_shared_state->_in_mem_shared_state->aggregate_data_container->end() && + auto& iter = + Base::_shared_state->_in_mem_shared_state->groupby_agg_ctx->agg_data_container() + ->iterator; + while (iter != + Base::_shared_state->_in_mem_shared_state->groupby_agg_ctx->agg_data_container() + ->end() && !state->is_cancelled()) { const auto& key = iter.template get_key(); auto partition_index = hash_table.hash(key) % parent._partition_count; diff --git a/be/src/exec/operator/partitioned_aggregation_source_operator.cpp b/be/src/exec/operator/partitioned_aggregation_source_operator.cpp index bfefcfb9af2051..8d68cdbeed1a4a 100644 --- a/be/src/exec/operator/partitioned_aggregation_source_operator.cpp +++ b/be/src/exec/operator/partitioned_aggregation_source_operator.cpp @@ -197,19 +197,19 @@ size_t PartitionedAggSourceOperatorX::revocable_mem_size(RuntimeState* state) co bytes += block.allocated_bytes(); } if (local_state._shared_state->_in_mem_shared_state != nullptr && - 
local_state._shared_state->_in_mem_shared_state->agg_data != nullptr) { - auto* agg_data = local_state._shared_state->_in_mem_shared_state->agg_data.get(); + local_state._shared_state->_in_mem_shared_state->groupby_agg_ctx != nullptr) { + auto* agg_data = + local_state._shared_state->_in_mem_shared_state->groupby_agg_ctx + ->hash_table_data(); bytes += std::visit(Overload {[&](std::monostate& arg) -> size_t { return 0; }, [&](auto& agg_method) -> size_t { return agg_method.hash_table->get_buffer_size_in_bytes(); }}, agg_data->method_variant); - if (auto& aggregate_data_container = - local_state._shared_state->_in_mem_shared_state->aggregate_data_container; - aggregate_data_container) { - bytes += aggregate_data_container->memory_usage(); - } + bytes += local_state._shared_state->_in_mem_shared_state->groupby_agg_ctx + ->agg_data_container() + ->memory_usage(); } return bytes > state->spill_min_revocable_mem() ? bytes : 0; } @@ -239,7 +239,8 @@ Status PartitionedAggSourceOperatorX::get_block(RuntimeState* state, Block* bloc // ── Fast path: not spilled ───────────────────────────────────────── if (!local_state._shared_state->_is_spilled) { auto* runtime_state = local_state._runtime_state.get(); - local_state._shared_state->_in_mem_shared_state->aggregate_data_container->init_once(); + local_state._shared_state->_in_mem_shared_state->groupby_agg_ctx->agg_data_container() + ->init_once(); status = _agg_source_operator->get_block(runtime_state, block, eos); RETURN_IF_ERROR(status); if (*eos) { @@ -317,7 +318,8 @@ Status PartitionedAggSourceOperatorX::get_block(RuntimeState* state, Block* bloc // Phase 4: All spill files consumed and merged — output aggregated results from hash table. 
auto* runtime_state = local_state._runtime_state.get(); - local_state._shared_state->_in_mem_shared_state->aggregate_data_container->init_once(); + local_state._shared_state->_in_mem_shared_state->groupby_agg_ctx->agg_data_container() + ->init_once(); bool inner_eos = false; RETURN_IF_ERROR(_agg_source_operator->get_block(runtime_state, block, &inner_eos)); @@ -434,7 +436,7 @@ Status PartitionedAggLocalState::_flush_hash_table_to_sub_spill_files(RuntimeSta // setup_output must have been called by the caller (_flush_and_repartition) // before calling this function. The repartitioner writes to the persistent output writers. - in_mem_state->aggregate_data_container->init_once(); + in_mem_state->groupby_agg_ctx->agg_data_container()->init_once(); bool inner_eos = false; while (!inner_eos && !state->is_cancelled()) { Block block; @@ -479,12 +481,13 @@ Status PartitionedAggLocalState::_flush_and_repartition(RuntimeState* state) { static_cast(p._partition_count), output_spill_files)); auto* in_mem_state = _shared_state->_in_mem_shared_state; - size_t num_keys = in_mem_state->probe_expr_ctxs.size(); + size_t num_keys = in_mem_state->groupby_agg_ctx->groupby_expr_ctxs().size(); std::vector key_column_indices(num_keys); std::vector key_data_types(num_keys); for (size_t i = 0; i < num_keys; ++i) { key_column_indices[i] = i; - key_data_types[i] = in_mem_state->probe_expr_ctxs[i]->root()->data_type(); + key_data_types[i] = + in_mem_state->groupby_agg_ctx->groupby_expr_ctxs()[i]->root()->data_type(); } _repartitioner.init_with_key_columns(std::move(key_column_indices), std::move(key_data_types), diff --git a/be/src/exec/operator/streaming_aggregation_operator.cpp b/be/src/exec/operator/streaming_aggregation_operator.cpp index 7df37a59911103..6cb4b4cf49b10b 100644 --- a/be/src/exec/operator/streaming_aggregation_operator.cpp +++ b/be/src/exec/operator/streaming_aggregation_operator.cpp @@ -24,7 +24,9 @@ #include "common/cast_set.h" #include "common/compiler_util.h" // IWYU 
pragma: keep -#include "core/column/column_fixed_length_object.h" +#include "exec/common/agg_context_utils.h" +#include "exec/common/groupby_agg_context.h" +#include "exec/common/inline_count_agg_context.h" #include "exec/operator/operator.h" #include "exec/operator/streaming_agg_min_reduction.h" #include "exprs/aggregate/aggregate_function_count.h" @@ -41,39 +43,17 @@ namespace doris { StreamingAggLocalState::StreamingAggLocalState(RuntimeState* state, OperatorXBase* parent) : Base(state, parent), - _agg_data(std::make_unique()), _child_block(Block::create_unique()), _pre_aggregated_block(Block::create_unique()), _is_single_backend(state->get_query_ctx()->is_single_backend_query()) {} +StreamingAggLocalState::~StreamingAggLocalState() = default; + Status StreamingAggLocalState::init(RuntimeState* state, LocalStateInfo& info) { RETURN_IF_ERROR(Base::init(state, info)); SCOPED_TIMER(Base::exec_time_counter()); SCOPED_TIMER(Base::_init_timer); - _hash_table_memory_usage = - ADD_COUNTER_WITH_LEVEL(Base::custom_profile(), "MemoryUsageHashTable", TUnit::BYTES, 1); - _serialize_key_arena_memory_usage = Base::custom_profile()->AddHighWaterMarkCounter( - "MemoryUsageSerializeKeyArena", TUnit::BYTES, "", 1); - - _build_timer = ADD_TIMER(Base::custom_profile(), "BuildTime"); - _merge_timer = ADD_TIMER(Base::custom_profile(), "MergeTime"); - _expr_timer = ADD_TIMER(Base::custom_profile(), "ExprTime"); - _insert_values_to_column_timer = ADD_TIMER(Base::custom_profile(), "InsertValuesToColumnTime"); - _deserialize_data_timer = ADD_TIMER(Base::custom_profile(), "DeserializeAndMergeTime"); - _hash_table_compute_timer = ADD_TIMER(Base::custom_profile(), "HashTableComputeTime"); - _hash_table_limit_compute_timer = - ADD_TIMER(Base::custom_profile(), "HashTableLimitComputeTime"); - _hash_table_emplace_timer = ADD_TIMER(Base::custom_profile(), "HashTableEmplaceTime"); - _hash_table_input_counter = - ADD_COUNTER(Base::custom_profile(), "HashTableInputCount", TUnit::UNIT); - 
_hash_table_size_counter = ADD_COUNTER(custom_profile(), "HashTableSize", TUnit::UNIT); _streaming_agg_timer = ADD_TIMER(custom_profile(), "StreamingAggTime"); - _build_timer = ADD_TIMER(custom_profile(), "BuildTime"); - _expr_timer = ADD_TIMER(Base::custom_profile(), "ExprTime"); - _get_results_timer = ADD_TIMER(custom_profile(), "GetResultsTime"); - _hash_table_iterate_timer = ADD_TIMER(custom_profile(), "HashTableIterateTime"); - _insert_keys_to_column_timer = ADD_TIMER(custom_profile(), "InsertKeysToColumnTime"); - return Status::OK(); } @@ -83,99 +63,62 @@ Status StreamingAggLocalState::open(RuntimeState* state) { RETURN_IF_ERROR(Base::open(state)); auto& p = Base::_parent->template cast(); + + // Clone evaluators and probe expression contexts + std::vector evaluators; for (auto& evaluator : p._aggregate_evaluators) { - _aggregate_evaluators.push_back(evaluator->clone(state, p._pool)); + evaluators.push_back(evaluator->clone(state, p._pool)); } - _probe_expr_ctxs.resize(p._probe_expr_ctxs.size()); - for (size_t i = 0; i < _probe_expr_ctxs.size(); i++) { - RETURN_IF_ERROR(p._probe_expr_ctxs[i]->clone(state, _probe_expr_ctxs[i])); + VExprContextSPtrs probe_expr_ctxs(p._probe_expr_ctxs.size()); + for (size_t i = 0; i < probe_expr_ctxs.size(); i++) { + RETURN_IF_ERROR(p._probe_expr_ctxs[i]->clone(state, probe_expr_ctxs[i])); } - for (auto& evaluator : _aggregate_evaluators) { - evaluator->set_timer(_merge_timer, _expr_timer); - } - - DCHECK(!_probe_expr_ctxs.empty()); - - RETURN_IF_ERROR(_init_hash_method(_probe_expr_ctxs)); + DCHECK(!probe_expr_ctxs.empty()); // Determine whether to use simple count aggregation. - // StreamingAgg only operates in update + serialize mode: input is raw data, output is serialized intermediate state. - // The serialization format of count is UInt64 itself, so it can be inlined into the hash table mapped slot. 
- if (_aggregate_evaluators.size() == 1 && - _aggregate_evaluators[0]->function()->is_simple_count() && p._sort_limit == -1) { - _use_simple_count = true; + bool use_simple_count = evaluators.size() == 1 && + evaluators[0]->function()->is_simple_count() && p._sort_limit == -1; #ifndef NDEBUG - // Randomly enable/disable in debug mode to verify correctness of multi-phase agg promotion/demotion. - _use_simple_count = rand() % 2 == 0; + // Randomly disable simple count in debug mode to test demotion correctness. + // Only demote (true→false), never promote (false→true). + if (use_simple_count) { + use_simple_count = rand() % 2 == 0; + } #endif + + if (use_simple_count) { + _groupby_agg_ctx = std::make_unique( + std::move(evaluators), std::move(probe_expr_ctxs), + p._offsets_of_aggregate_states, p._total_size_of_aggregate_states, + p._align_aggregate_states, p._is_first_phase); + } else { + _groupby_agg_ctx = std::make_unique( + std::move(evaluators), std::move(probe_expr_ctxs), + p._offsets_of_aggregate_states, p._total_size_of_aggregate_states, + p._align_aggregate_states, p._is_first_phase); } - std::visit( - Overload {[&](std::monostate& arg) -> void { - throw doris::Exception(ErrorCode::INTERNAL_ERROR, "uninited hash table"); - }, - [&](auto& agg_method) { - using HashTableType = std::decay_t; - using KeyType = typename HashTableType::Key; - - if (!_use_simple_count) { - /// some aggregate functions (like AVG for decimal) have align issues. 
- _aggregate_data_container = std::make_unique( - sizeof(KeyType), ((p._total_size_of_aggregate_states + - p._align_aggregate_states - 1) / - p._align_aggregate_states) * - p._align_aggregate_states); - } - }}, - _agg_data->method_variant); - - limit = p._sort_limit; - do_sort_limit = p._do_sort_limit; - null_directions = p._null_directions; - order_directions = p._order_directions; + // Configure sort-limit on context + _groupby_agg_ctx->limit = p._sort_limit; + _groupby_agg_ctx->do_sort_limit = p._do_sort_limit; + _groupby_agg_ctx->order_directions = p._order_directions; + _groupby_agg_ctx->null_directions = p._null_directions; - return Status::OK(); -} + _groupby_agg_ctx->init_hash_method(); + _groupby_agg_ctx->init_agg_data_container(); + _groupby_agg_ctx->init_sink_profile(custom_profile()); + _groupby_agg_ctx->init_source_profile(custom_profile()); -size_t StreamingAggLocalState::_get_hash_table_size() { - return std::visit(Overload {[&](std::monostate& arg) -> size_t { - throw doris::Exception(ErrorCode::INTERNAL_ERROR, - "uninited hash table"); - return 0; - }, - [&](auto& agg_method) { return agg_method.hash_table->size(); }}, - _agg_data->method_variant); -} + for (auto& evaluator : _groupby_agg_ctx->agg_evaluators()) { + evaluator->set_timer(_groupby_agg_ctx->merge_timer(), _groupby_agg_ctx->expr_timer()); + } -void StreamingAggLocalState::_update_memusage_with_serialized_key() { - std::visit(Overload {[&](std::monostate& arg) -> void { - throw doris::Exception(ErrorCode::INTERNAL_ERROR, - "uninited hash table"); - }, - [&](auto& agg_method) -> void { - auto& data = *agg_method.hash_table; - int64_t arena_memory_usage = - _agg_arena_pool.size() + - (_aggregate_data_container - ? 
_aggregate_data_container->memory_usage() - : 0); - int64_t hash_table_memory_usage = data.get_buffer_size_in_bytes(); - - COUNTER_SET(_memory_used_counter, - arena_memory_usage + hash_table_memory_usage); - - COUNTER_SET(_serialize_key_arena_memory_usage, arena_memory_usage); - COUNTER_SET(_hash_table_memory_usage, hash_table_memory_usage); - }}, - _agg_data->method_variant); + return Status::OK(); } -Status StreamingAggLocalState::_init_hash_method(const VExprContextSPtrs& probe_exprs) { - RETURN_IF_ERROR(init_hash_method( - _agg_data.get(), get_data_types(probe_exprs), - Base::_parent->template cast()._is_first_phase)); - return Status::OK(); +size_t StreamingAggLocalState::_memory_usage() const { + return _groupby_agg_ctx ? _groupby_agg_ctx->memory_usage() : 0; } Status StreamingAggLocalState::do_pre_agg(RuntimeState* state, Block* input_block, @@ -189,7 +132,7 @@ Status StreamingAggLocalState::do_pre_agg(RuntimeState* state, Block* input_bloc // pre stream agg need use _num_row_return to decide whether to do pre stream agg _cur_num_rows_returned += output_block->rows(); make_nullable_output_key(output_block); - _update_memusage_with_serialized_key(); + _groupby_agg_ctx->update_memusage(); return Status::OK(); } @@ -258,27 +201,7 @@ bool StreamingAggLocalState::_should_expand_preagg_hash_tables() { _should_expand_hash_table = current_reduction > min_reduction; return _should_expand_hash_table; }}, - _agg_data->method_variant); -} - -size_t StreamingAggLocalState::_memory_usage() const { - size_t usage = 0; - usage += _agg_arena_pool.size(); - - if (_aggregate_data_container) { - usage += _aggregate_data_container->memory_usage(); - } - - std::visit(Overload {[&](std::monostate& arg) -> void { - throw doris::Exception(ErrorCode::INTERNAL_ERROR, - "uninited hash table"); - }, - [&](auto& agg_method) { - usage += agg_method.hash_table->get_buffer_size_in_bytes(); - }}, - _agg_data->method_variant); - - return usage; + 
_groupby_agg_ctx->hash_table_data()->method_variant); } bool StreamingAggLocalState::_should_not_do_pre_agg(size_t rows) { @@ -309,53 +232,43 @@ bool StreamingAggLocalState::_should_not_do_pre_agg(size_t rows) { ret_flag = true; } }}, - _agg_data->method_variant); + _groupby_agg_ctx->hash_table_data()->method_variant); return ret_flag; } Status StreamingAggLocalState::_pre_agg_with_serialized_key(doris::Block* in_block, doris::Block* out_block) { - SCOPED_TIMER(_build_timer); - DCHECK(!_probe_expr_ctxs.empty()); + SCOPED_TIMER(_groupby_agg_ctx->build_timer()); + DCHECK(!_groupby_agg_ctx->groupby_expr_ctxs().empty()); auto& p = Base::_parent->template cast(); - size_t key_size = _probe_expr_ctxs.size(); + size_t key_size = _groupby_agg_ctx->groupby_expr_ctxs().size(); ColumnRawPtrs key_columns(key_size); - { - SCOPED_TIMER(_expr_timer); - for (size_t i = 0; i < key_size; ++i) { - int result_column_id = -1; - RETURN_IF_ERROR(_probe_expr_ctxs[i]->execute(in_block, &result_column_id)); - in_block->get_by_position(result_column_id).column = - in_block->get_by_position(result_column_id) - .column->convert_to_full_column_if_const(); - key_columns[i] = in_block->get_by_position(result_column_id).column.get(); - key_columns[i]->assume_mutable()->replace_float_special_values(); - } - } + RETURN_IF_ERROR(_groupby_agg_ctx->evaluate_groupby_keys(in_block, key_columns)); uint32_t rows = (uint32_t)in_block->rows(); _places.resize(rows); if (_should_not_do_pre_agg(rows)) { - if (limit > 0) { - DCHECK(do_sort_limit); - if (need_do_sort_limit == -1) { - const size_t hash_table_size = _get_hash_table_size(); - need_do_sort_limit = hash_table_size >= limit ? 1 : 0; - if (need_do_sort_limit == 1) { - build_limit_heap(hash_table_size); + if (_groupby_agg_ctx->limit > 0) { + DCHECK(_groupby_agg_ctx->do_sort_limit); + if (_need_do_sort_limit == -1) { + const size_t hash_table_size = _groupby_agg_ctx->hash_table_size(); + _need_do_sort_limit = hash_table_size >= _groupby_agg_ctx->limit ? 
1 : 0; + if (_need_do_sort_limit == 1) { + _groupby_agg_ctx->build_limit_heap(hash_table_size); } } - if (need_do_sort_limit == 1) { - if (_do_limit_filter(rows, key_columns)) { + if (_need_do_sort_limit == 1) { + if (_groupby_agg_ctx->do_limit_filter(rows, key_columns)) { + auto& need_computes = _groupby_agg_ctx->need_computes(); bool need_filter = std::find(need_computes.begin(), need_computes.end(), 1) != need_computes.end(); if (need_filter) { - _add_limit_heap_top(key_columns, rows); + _groupby_agg_ctx->add_limit_heap_top(key_columns, rows); Block::filter_block_internal(in_block, need_computes); rows = (uint32_t)in_block->rows(); } else { @@ -366,37 +279,28 @@ Status StreamingAggLocalState::_pre_agg_with_serialized_key(doris::Block* in_blo } bool mem_reuse = p._make_nullable_keys.empty() && out_block->mem_reuse(); - std::vector data_types; - MutableColumns value_columns; - for (int i = 0; i < _aggregate_evaluators.size(); ++i) { - auto data_type = _aggregate_evaluators[i]->function()->get_serialized_type(); - if (mem_reuse) { - value_columns.emplace_back( - std::move(*out_block->get_by_position(i + key_size).column).mutate()); - } else { - value_columns.emplace_back( - _aggregate_evaluators[i]->function()->create_serialize_column()); - } - data_types.emplace_back(data_type); + size_t agg_size = _groupby_agg_ctx->agg_evaluators().size(); + DataTypes data_types(agg_size); + for (size_t i = 0; i < agg_size; ++i) { + data_types[i] = _groupby_agg_ctx->agg_evaluators()[i]->function()->get_serialized_type(); } - - for (int i = 0; i != _aggregate_evaluators.size(); ++i) { - SCOPED_TIMER(_insert_values_to_column_timer); - RETURN_IF_ERROR(_aggregate_evaluators[i]->streaming_agg_serialize_to_column( - in_block, value_columns[i], rows, _agg_arena_pool)); + auto value_columns = agg_context_utils::take_or_create_columns( + out_block, mem_reuse, key_size, agg_size, [&](size_t i) { + return _groupby_agg_ctx->agg_evaluators()[i] + ->function() + ->create_serialize_column(); + 
}); + + for (int i = 0; i != _groupby_agg_ctx->agg_evaluators().size(); ++i) { + SCOPED_TIMER(_groupby_agg_ctx->insert_values_to_column_timer()); + RETURN_IF_ERROR(_groupby_agg_ctx->agg_evaluators()[i]->streaming_agg_serialize_to_column( + in_block, value_columns[i], rows, _groupby_agg_ctx->agg_arena())); } if (!mem_reuse) { - ColumnsWithTypeAndName columns_with_schema; - for (int i = 0; i < key_size; ++i) { - columns_with_schema.emplace_back(key_columns[i]->clone_resized(rows), - _probe_expr_ctxs[i]->root()->data_type(), - _probe_expr_ctxs[i]->root()->expr_name()); - } - for (int i = 0; i < value_columns.size(); ++i) { - columns_with_schema.emplace_back(std::move(value_columns[i]), data_types[i], ""); - } - out_block->swap(Block(columns_with_schema)); + agg_context_utils::build_serialized_output_block( + out_block, key_columns, rows, + _groupby_agg_ctx->groupby_expr_ctxs(), value_columns, data_types); } else { for (int i = 0; i < key_size; ++i) { std::move(*out_block->get_by_position(i).column) @@ -405,219 +309,25 @@ Status StreamingAggLocalState::_pre_agg_with_serialized_key(doris::Block* in_blo } } } else { - bool need_agg = true; - if (need_do_sort_limit != 1) { - if (_use_simple_count) { - _emplace_into_hash_table_inline_count(key_columns, rows); - need_agg = false; - } else { - _emplace_into_hash_table(_places.data(), key_columns, rows); - } + if (_need_do_sort_limit != 1) { + RETURN_IF_ERROR(_groupby_agg_ctx->emplace_and_forward( + _places.data(), key_columns, rows, in_block, _should_expand_hash_table)); } else { - need_agg = _emplace_into_hash_table_limit(_places.data(), in_block, key_columns, rows); - } - - if (need_agg) { - for (int i = 0; i < _aggregate_evaluators.size(); ++i) { - RETURN_IF_ERROR(_aggregate_evaluators[i]->execute_batch_add( - in_block, p._offsets_of_aggregate_states[i], _places.data(), - _agg_arena_pool, _should_expand_hash_table)); - } - if (limit > 0 && need_do_sort_limit == -1 && _get_hash_table_size() >= limit) { - need_do_sort_limit 
= 1; - build_limit_heap(_get_hash_table_size()); - } - } - } - - return Status::OK(); -} - -Status StreamingAggLocalState::_create_agg_status(AggregateDataPtr data) { - auto& p = Base::_parent->template cast(); - for (int i = 0; i < _aggregate_evaluators.size(); ++i) { - try { - _aggregate_evaluators[i]->create(data + p._offsets_of_aggregate_states[i]); - } catch (...) { - for (int j = 0; j < i; ++j) { - _aggregate_evaluators[j]->destroy(data + p._offsets_of_aggregate_states[j]); + bool need_agg = _groupby_agg_ctx->emplace_into_hash_table_limit( + _places.data(), in_block, nullptr, key_columns, rows); + if (need_agg) { + for (int i = 0; i < _groupby_agg_ctx->agg_evaluators().size(); ++i) { + RETURN_IF_ERROR(_groupby_agg_ctx->agg_evaluators()[i]->execute_batch_add( + in_block, _groupby_agg_ctx->agg_state_offsets()[i], _places.data(), + _groupby_agg_ctx->agg_arena(), _should_expand_hash_table)); + } } - throw; - } - } - return Status::OK(); -} - -Status StreamingAggLocalState::_get_results_with_serialized_key(RuntimeState* state, Block* block, - bool* eos) { - SCOPED_TIMER(_get_results_timer); - auto& p = _parent->cast(); - const auto key_size = _probe_expr_ctxs.size(); - const auto agg_size = _aggregate_evaluators.size(); - MutableColumns value_columns(agg_size); - DataTypes value_data_types(agg_size); - - // non-nullable column(id in `_make_nullable_keys`) will be converted to nullable. 
- bool mem_reuse = p._make_nullable_keys.empty() && block->mem_reuse(); - - MutableColumns key_columns; - for (int i = 0; i < key_size; ++i) { - if (mem_reuse) { - key_columns.emplace_back(std::move(*block->get_by_position(i).column).mutate()); - } else { - key_columns.emplace_back(_probe_expr_ctxs[i]->root()->data_type()->create_column()); - } - } - - std::visit( - Overload { - [&](std::monostate& arg) -> void { - throw doris::Exception(ErrorCode::INTERNAL_ERROR, "uninited hash table"); - }, - [&](auto& agg_method) -> void { - agg_method.init_iterator(); - auto& data = *agg_method.hash_table; - const auto size = std::min(data.size(), size_t(state->batch_size())); - using KeyType = std::decay_t::Key; - std::vector keys(size); - - if (_use_simple_count) { - DCHECK_EQ(_aggregate_evaluators.size(), 1); - - value_data_types[0] = - _aggregate_evaluators[0]->function()->get_serialized_type(); - if (mem_reuse) { - value_columns[0] = - std::move(*block->get_by_position(key_size).column) - .mutate(); - } else { - value_columns[0] = _aggregate_evaluators[0] - ->function() - ->create_serialize_column(); - } - - auto& count_col = - assert_cast(*value_columns[0]); - uint32_t num_rows = 0; - { - SCOPED_TIMER(_hash_table_iterate_timer); - auto& it = agg_method.begin; - while (it != agg_method.end && num_rows < state->batch_size()) { - keys[num_rows] = it.get_first(); - auto inline_count = - reinterpret_cast(it.get_second()); - count_col.insert_data( - reinterpret_cast(&inline_count), - sizeof(UInt64)); - ++it; - ++num_rows; - } - } - - { - SCOPED_TIMER(_insert_keys_to_column_timer); - agg_method.insert_keys_into_columns(keys, key_columns, num_rows); - } - - // Handle null key if present - if (agg_method.begin == agg_method.end) { - if (agg_method.hash_table->has_null_key_data()) { - DCHECK(key_columns.size() == 1); - DCHECK(key_columns[0]->is_nullable()); - if (num_rows < state->batch_size()) { - key_columns[0]->insert_data(nullptr, 0); - auto mapped = - 
agg_method.hash_table->template get_null_key_data< - AggregateDataPtr>(); - count_col.resize(num_rows + 1); - *reinterpret_cast(count_col.get_data().data() + - num_rows * sizeof(UInt64)) = - std::bit_cast(mapped); - *eos = true; - } - } else { - *eos = true; - } - } - return; - } - - if (_values.size() < size + 1) { - _values.resize(size + 1); - } - - uint32_t num_rows = 0; - _aggregate_data_container->init_once(); - auto& iter = _aggregate_data_container->iterator; - - { - SCOPED_TIMER(_hash_table_iterate_timer); - while (iter != _aggregate_data_container->end() && - num_rows < state->batch_size()) { - keys[num_rows] = iter.template get_key(); - _values[num_rows] = iter.get_aggregate_data(); - ++iter; - ++num_rows; - } - } - - { - SCOPED_TIMER(_insert_keys_to_column_timer); - agg_method.insert_keys_into_columns(keys, key_columns, num_rows); - } - - if (iter == _aggregate_data_container->end()) { - if (agg_method.hash_table->has_null_key_data()) { - // only one key of group by support wrap null key - // here need additional processing logic on the null key / value - DCHECK(key_columns.size() == 1); - DCHECK(key_columns[0]->is_nullable()); - if (agg_method.hash_table->has_null_key_data()) { - key_columns[0]->insert_data(nullptr, 0); - _values[num_rows] = - agg_method.hash_table->template get_null_key_data< - AggregateDataPtr>(); - ++num_rows; - *eos = true; - } - } else { - *eos = true; - } - } - - { - SCOPED_TIMER(_insert_values_to_column_timer); - for (size_t i = 0; i < _aggregate_evaluators.size(); ++i) { - value_data_types[i] = - _aggregate_evaluators[i]->function()->get_serialized_type(); - if (mem_reuse) { - value_columns[i] = - std::move(*block->get_by_position(i + key_size).column) - .mutate(); - } else { - value_columns[i] = _aggregate_evaluators[i] - ->function() - ->create_serialize_column(); - } - _aggregate_evaluators[i]->function()->serialize_to_column( - _values, p._offsets_of_aggregate_states[i], - value_columns[i], num_rows); - } - } - }}, - 
_agg_data->method_variant); - - if (!mem_reuse) { - ColumnsWithTypeAndName columns_with_schema; - for (int i = 0; i < key_size; ++i) { - columns_with_schema.emplace_back(std::move(key_columns[i]), - _probe_expr_ctxs[i]->root()->data_type(), - _probe_expr_ctxs[i]->root()->expr_name()); } - for (int i = 0; i < agg_size; ++i) { - columns_with_schema.emplace_back(std::move(value_columns[i]), value_data_types[i], ""); + if (_groupby_agg_ctx->limit > 0 && _need_do_sort_limit == -1 && + _groupby_agg_ctx->hash_table_size() >= _groupby_agg_ctx->limit) { + _need_do_sort_limit = 1; + _groupby_agg_ctx->build_limit_heap(_groupby_agg_ctx->hash_table_size()); } - *block = Block(columns_with_schema); } return Status::OK(); @@ -632,270 +342,6 @@ void StreamingAggLocalState::make_nullable_output_key(Block* block) { } } -void StreamingAggLocalState::_destroy_agg_status(AggregateDataPtr data) { - for (int i = 0; i < _aggregate_evaluators.size(); ++i) { - _aggregate_evaluators[i]->function()->destroy( - data + _parent->cast()._offsets_of_aggregate_states[i]); - } -} - -MutableColumns StreamingAggLocalState::_get_keys_hash_table() { - return std::visit( - Overload {[&](std::monostate& arg) { - throw doris::Exception(ErrorCode::INTERNAL_ERROR, "uninited hash table"); - return MutableColumns(); - }, - [&](auto&& agg_method) -> MutableColumns { - MutableColumns key_columns; - for (int i = 0; i < _probe_expr_ctxs.size(); ++i) { - key_columns.emplace_back( - _probe_expr_ctxs[i]->root()->data_type()->create_column()); - } - auto& data = *agg_method.hash_table; - bool has_null_key = data.has_null_key_data(); - const auto size = data.size() - has_null_key; - using KeyType = std::decay_t::Key; - std::vector keys(size); - - uint32_t num_rows = 0; - auto iter = _aggregate_data_container->begin(); - { - while (iter != _aggregate_data_container->end()) { - keys[num_rows] = iter.get_key(); - ++iter; - ++num_rows; - } - } - agg_method.insert_keys_into_columns(keys, key_columns, num_rows); - if 
(has_null_key) { - key_columns[0]->insert_data(nullptr, 0); - } - return key_columns; - }}, - _agg_data->method_variant); -} - -void StreamingAggLocalState::build_limit_heap(size_t hash_table_size) { - limit_columns = _get_keys_hash_table(); - for (size_t i = 0; i < hash_table_size; ++i) { - limit_heap.emplace(i, limit_columns, order_directions, null_directions); - } - while (hash_table_size > limit) { - limit_heap.pop(); - hash_table_size--; - } - limit_columns_min = limit_heap.top()._row_id; -} - -void StreamingAggLocalState::_add_limit_heap_top(ColumnRawPtrs& key_columns, size_t rows) { - for (int i = 0; i < rows; ++i) { - if (cmp_res[i] == 1 && need_computes[i]) { - for (int j = 0; j < key_columns.size(); ++j) { - limit_columns[j]->insert_from(*key_columns[j], i); - } - limit_heap.emplace(limit_columns[0]->size() - 1, limit_columns, order_directions, - null_directions); - limit_heap.pop(); - limit_columns_min = limit_heap.top()._row_id; - break; - } - } -} - -void StreamingAggLocalState::_refresh_limit_heap(size_t i, ColumnRawPtrs& key_columns) { - for (int j = 0; j < key_columns.size(); ++j) { - limit_columns[j]->insert_from(*key_columns[j], i); - } - limit_heap.emplace(limit_columns[0]->size() - 1, limit_columns, order_directions, - null_directions); - limit_heap.pop(); - limit_columns_min = limit_heap.top()._row_id; -} - -bool StreamingAggLocalState::_emplace_into_hash_table_limit(AggregateDataPtr* places, Block* block, - ColumnRawPtrs& key_columns, - uint32_t num_rows) { - return std::visit( - Overload {[&](std::monostate& arg) { - throw doris::Exception(ErrorCode::INTERNAL_ERROR, "uninited hash table"); - return true; - }, - [&](auto&& agg_method) -> bool { - SCOPED_TIMER(_hash_table_compute_timer); - using HashMethodType = std::decay_t; - using AggState = typename HashMethodType::State; - - bool need_filter = _do_limit_filter(num_rows, key_columns); - if (auto need_agg = - std::find(need_computes.begin(), need_computes.end(), 1); - need_agg != 
need_computes.end()) { - if (need_filter) { - Block::filter_block_internal(block, need_computes); - num_rows = (uint32_t)block->rows(); - } - - AggState state(key_columns); - agg_method.init_serialized_keys(key_columns, num_rows); - size_t i = 0; - - auto creator = [&](const auto& ctor, auto& key, auto& origin) { - try { - HashMethodType::try_presis_key_and_origin(key, origin, - _agg_arena_pool); - auto mapped = _aggregate_data_container->append_data(origin); - auto st = _create_agg_status(mapped); - if (!st) { - throw Exception(st.code(), st.to_string()); - } - ctor(key, mapped); - _refresh_limit_heap(i, key_columns); - } catch (...) { - // Exception-safety - if it can not allocate memory or create status, - // the destructors will not be called. - ctor(key, nullptr); - throw; - } - }; - - auto creator_for_null_key = [&](auto& mapped) { - mapped = _agg_arena_pool.aligned_alloc( - Base::_parent->template cast() - ._total_size_of_aggregate_states, - Base::_parent->template cast() - ._align_aggregate_states); - auto st = _create_agg_status(mapped); - if (!st) { - throw Exception(st.code(), st.to_string()); - } - _refresh_limit_heap(i, key_columns); - }; - - SCOPED_TIMER(_hash_table_emplace_timer); - lazy_emplace_batch( - agg_method, state, num_rows, creator, creator_for_null_key, - [&](uint32_t row) { i = row; }, - [&](uint32_t row, auto& mapped) { places[row] = mapped; }); - COUNTER_UPDATE(_hash_table_input_counter, num_rows); - return true; - } - return false; - }}, - _agg_data->method_variant); -} - -bool StreamingAggLocalState::_do_limit_filter(size_t num_rows, ColumnRawPtrs& key_columns) { - SCOPED_TIMER(_hash_table_limit_compute_timer); - if (num_rows) { - cmp_res.resize(num_rows); - need_computes.resize(num_rows); - memset(need_computes.data(), 0, need_computes.size()); - memset(cmp_res.data(), 0, cmp_res.size()); - - const auto key_size = null_directions.size(); - for (int i = 0; i < key_size; i++) { - key_columns[i]->compare_internal(limit_columns_min, 
*limit_columns[i], - null_directions[i], order_directions[i], cmp_res, - need_computes.data()); - } - - auto set_computes_arr = [](auto* __restrict res, auto* __restrict computes, size_t rows) { - for (size_t i = 0; i < rows; ++i) { - computes[i] = computes[i] == res[i]; - } - }; - set_computes_arr(cmp_res.data(), need_computes.data(), num_rows); - - return std::find(need_computes.begin(), need_computes.end(), 0) != need_computes.end(); - } - - return false; -} - -void StreamingAggLocalState::_emplace_into_hash_table(AggregateDataPtr* places, - ColumnRawPtrs& key_columns, - const uint32_t num_rows) { - if (_use_simple_count) { - _emplace_into_hash_table_inline_count(key_columns, num_rows); - return; - } - - std::visit(Overload {[&](std::monostate& arg) -> void { - throw doris::Exception(ErrorCode::INTERNAL_ERROR, - "uninited hash table"); - }, - [&](auto& agg_method) -> void { - SCOPED_TIMER(_hash_table_compute_timer); - using HashMethodType = std::decay_t; - using AggState = typename HashMethodType::State; - AggState state(key_columns); - agg_method.init_serialized_keys(key_columns, num_rows); - - auto creator = [this](const auto& ctor, auto& key, auto& origin) { - HashMethodType::try_presis_key_and_origin(key, origin, - _agg_arena_pool); - auto mapped = _aggregate_data_container->append_data(origin); - auto st = _create_agg_status(mapped); - if (!st) { - throw Exception(st.code(), st.to_string()); - } - ctor(key, mapped); - }; - - auto creator_for_null_key = [&](auto& mapped) { - mapped = _agg_arena_pool.aligned_alloc( - Base::_parent->template cast() - ._total_size_of_aggregate_states, - Base::_parent->template cast() - ._align_aggregate_states); - auto st = _create_agg_status(mapped); - if (!st) { - throw Exception(st.code(), st.to_string()); - } - }; - - SCOPED_TIMER(_hash_table_emplace_timer); - lazy_emplace_batch( - agg_method, state, num_rows, creator, creator_for_null_key, - [&](uint32_t row, auto& mapped) { places[row] = mapped; }); - - 
COUNTER_UPDATE(_hash_table_input_counter, num_rows); - }}, - _agg_data->method_variant); -} - -void StreamingAggLocalState::_emplace_into_hash_table_inline_count(ColumnRawPtrs& key_columns, - uint32_t num_rows) { - std::visit(Overload {[&](std::monostate& arg) -> void { - throw doris::Exception(ErrorCode::INTERNAL_ERROR, - "uninited hash table"); - }, - [&](auto& agg_method) -> void { - SCOPED_TIMER(_hash_table_compute_timer); - using HashMethodType = std::decay_t; - using AggState = typename HashMethodType::State; - AggState state(key_columns); - agg_method.init_serialized_keys(key_columns, num_rows); - - auto creator = [&](const auto& ctor, auto& key, auto& origin) { - HashMethodType::try_presis_key_and_origin(key, origin, - _agg_arena_pool); - AggregateDataPtr mapped = nullptr; - ctor(key, mapped); - }; - - auto creator_for_null_key = [&](auto& mapped) { mapped = nullptr; }; - - SCOPED_TIMER(_hash_table_emplace_timer); - lazy_emplace_batch(agg_method, state, num_rows, creator, - creator_for_null_key, [&](uint32_t, auto& mapped) { - ++reinterpret_cast(mapped); - }); - - COUNTER_UPDATE(_hash_table_input_counter, num_rows); - }}, - _agg_data->method_variant); -} - StreamingAggOperatorX::StreamingAggOperatorX(ObjectPool* pool, int operator_id, const TPlanNode& tnode, const DescriptorTbl& descs) : StatefulOperatorX(pool, tnode, operator_id, descs), @@ -1074,22 +520,10 @@ Status StreamingAggLocalState::close(RuntimeState* state) { PODArray tmp_places; _places.swap(tmp_places); - std::vector tmp_deserialize_buffer; - _deserialize_buffer.swap(tmp_deserialize_buffer); - - /// _hash_table_size_counter may be null if prepare failed. 
- if (_hash_table_size_counter) { - std::visit(Overload {[&](std::monostate& arg) -> void { - // Do nothing - }, - [&](auto& agg_method) { - COUNTER_SET(_hash_table_size_counter, - int64_t(agg_method.hash_table->size())); - }}, - _agg_data->method_variant); + if (_groupby_agg_ctx) { + _groupby_agg_ctx->close(); + _groupby_agg_ctx.reset(); } - _close_with_serialized_key(); - _agg_arena_pool.clear(true); return Base::close(state); } @@ -1099,7 +533,7 @@ Status StreamingAggOperatorX::pull(RuntimeState* state, Block* block, bool* eos) if (!local_state._pre_aggregated_block->empty()) { local_state._pre_aggregated_block->swap(*block); } else { - RETURN_IF_ERROR(local_state._get_results_with_serialized_key(state, block, eos)); + RETURN_IF_ERROR(local_state._groupby_agg_ctx->get_serialized_results(state, block, eos)); local_state.make_nullable_output_key(block); // dispose the having clause, should not be execute in prestreaming agg RETURN_IF_ERROR(local_state.filter_block(local_state._conjuncts, block)); diff --git a/be/src/exec/operator/streaming_aggregation_operator.h b/be/src/exec/operator/streaming_aggregation_operator.h index cf1100f8dc126c..3cbd9692549a28 100644 --- a/be/src/exec/operator/streaming_aggregation_operator.h +++ b/be/src/exec/operator/streaming_aggregation_operator.h @@ -31,6 +31,7 @@ namespace doris { class RuntimeState; class StreamingAggOperatorX; +class GroupByAggContext; class StreamingAggLocalState MOCK_REMOVE(final) : public PipelineXLocalState { public: @@ -38,14 +39,13 @@ class StreamingAggLocalState MOCK_REMOVE(final) : public PipelineXLocalState; ENABLE_FACTORY_CREATOR(StreamingAggLocalState); StreamingAggLocalState(RuntimeState* state, OperatorXBase* parent); - ~StreamingAggLocalState() override = default; + ~StreamingAggLocalState() override; Status init(RuntimeState* state, LocalStateInfo& info) override; Status open(RuntimeState* state) override; Status close(RuntimeState* state) override; Status do_pre_agg(RuntimeState* state, Block* 
input_block, Block* output_block); void make_nullable_output_key(Block* block); - void build_limit_heap(size_t hash_table_size); private: friend class StreamingAggOperatorX; @@ -53,152 +53,28 @@ class StreamingAggLocalState MOCK_REMOVE(final) : public PipelineXLocalState _aggregate_evaluators; - // group by k1,k2 - VExprContextSPtrs _probe_expr_ctxs; - std::unique_ptr _aggregate_data_container = nullptr; - bool _use_simple_count = false; - bool _reach_limit = false; size_t _input_num_rows = 0; - int64_t limit = -1; - int need_do_sort_limit = -1; - bool do_sort_limit = false; - MutableColumns limit_columns; - int limit_columns_min = -1; - PaddedPODArray need_computes; - std::vector cmp_res; - std::vector order_directions; - std::vector null_directions; - - struct HeapLimitCursor { - HeapLimitCursor(int row_id, MutableColumns& limit_columns, - std::vector& order_directions, std::vector& null_directions) - : _row_id(row_id), - _limit_columns(limit_columns), - _order_directions(order_directions), - _null_directions(null_directions) {} - - HeapLimitCursor(const HeapLimitCursor& other) = default; - - HeapLimitCursor(HeapLimitCursor&& other) noexcept - : _row_id(other._row_id), - _limit_columns(other._limit_columns), - _order_directions(other._order_directions), - _null_directions(other._null_directions) {} - - HeapLimitCursor& operator=(const HeapLimitCursor& other) noexcept { - _row_id = other._row_id; - return *this; - } - - HeapLimitCursor& operator=(HeapLimitCursor&& other) noexcept { - _row_id = other._row_id; - return *this; - } - - bool operator<(const HeapLimitCursor& rhs) const { - for (int i = 0; i < _limit_columns.size(); ++i) { - const auto& _limit_column = _limit_columns[i]; - auto res = _limit_column->compare_at(_row_id, rhs._row_id, *_limit_column, - _null_directions[i]) * - _order_directions[i]; - if (res < 0) { - return true; - } else if (res > 0) { - return false; - } - } - return false; - } - - int _row_id; - MutableColumns& _limit_columns; - 
std::vector& _order_directions; - std::vector& _null_directions; - }; - - std::priority_queue limit_heap; - - MutableColumns _get_keys_hash_table(); - + std::unique_ptr _groupby_agg_ctx; PODArray _places; - std::vector _deserialize_buffer; + + // Sort limit: tracks whether sort limit filtering has been activated. + // -1 = not yet determined, 0 = no, 1 = yes + int _need_do_sort_limit = -1; std::unique_ptr _child_block = nullptr; bool _child_eos = false; std::unique_ptr _pre_aggregated_block = nullptr; - std::vector _values; - bool _opened = false; - - void _destroy_agg_status(AggregateDataPtr data); - - void _close_with_serialized_key() { - std::visit(Overload {[&](std::monostate& arg) -> void { - // Do nothing - }, - [&](auto& agg_method) -> void { - if (_use_simple_count) { - // Inline count: mapped slots hold UInt64, - // not real agg state pointers. Skip destroy. - return; - } - auto& data = *agg_method.hash_table; - data.for_each_mapped([&](auto& mapped) { - if (mapped) { - _destroy_agg_status(mapped); - mapped = nullptr; - } - }); - if (data.has_null_key_data()) { - _destroy_agg_status( - data.template get_null_key_data()); - } - }}, - _agg_data->method_variant); - } bool _is_single_backend = false; }; diff --git a/be/src/exec/pipeline/dependency.cpp b/be/src/exec/pipeline/dependency.cpp index d4372de341ff77..ce335762cb5238 100644 --- a/be/src/exec/pipeline/dependency.cpp +++ b/be/src/exec/pipeline/dependency.cpp @@ -25,8 +25,6 @@ #include "exec/pipeline/pipeline_fragment_context.h" #include "exec/pipeline/pipeline_task.h" #include "exec/spill/spill_file_manager.h" -#include "exprs/vectorized_agg_fn.h" -#include "exprs/vslot_ref.h" #include "runtime/exec_env.h" namespace doris { @@ -198,121 +196,6 @@ LocalExchangeSharedState::LocalExchangeSharedState(int num_instances) { mem_counters.resize(num_instances, nullptr); } -MutableColumns AggSharedState::_get_keys_hash_table() { - return std::visit( - Overload {[&](std::monostate& arg) { - throw 
doris::Exception(ErrorCode::INTERNAL_ERROR, "uninited hash table"); - return MutableColumns(); - }, - [&](auto&& agg_method) -> MutableColumns { - MutableColumns key_columns; - for (int i = 0; i < probe_expr_ctxs.size(); ++i) { - key_columns.emplace_back( - probe_expr_ctxs[i]->root()->data_type()->create_column()); - } - auto& data = *agg_method.hash_table; - bool has_null_key = data.has_null_key_data(); - const auto size = data.size() - has_null_key; - using KeyType = std::decay_t::Key; - std::vector keys(size); - - uint32_t num_rows = 0; - auto iter = aggregate_data_container->begin(); - { - while (iter != aggregate_data_container->end()) { - keys[num_rows] = iter.get_key(); - ++iter; - ++num_rows; - } - } - agg_method.insert_keys_into_columns(keys, key_columns, num_rows); - if (has_null_key) { - key_columns[0]->insert_data(nullptr, 0); - } - return key_columns; - }}, - agg_data->method_variant); -} - -void AggSharedState::build_limit_heap(size_t hash_table_size) { - limit_columns = _get_keys_hash_table(); - for (size_t i = 0; i < hash_table_size; ++i) { - limit_heap.emplace(i, limit_columns, order_directions, null_directions); - } - while (hash_table_size > limit) { - limit_heap.pop(); - hash_table_size--; - } - limit_columns_min = limit_heap.top()._row_id; -} - -bool AggSharedState::do_limit_filter(Block* block, size_t num_rows, - const std::vector* key_locs) { - if (num_rows) { - cmp_res.resize(num_rows); - need_computes.resize(num_rows); - memset(need_computes.data(), 0, need_computes.size()); - memset(cmp_res.data(), 0, cmp_res.size()); - - const auto key_size = null_directions.size(); - for (int i = 0; i < key_size; i++) { - block->get_by_position(key_locs ? 
key_locs->operator[](i) : i) - .column->compare_internal(limit_columns_min, *limit_columns[i], - null_directions[i], order_directions[i], cmp_res, - need_computes.data()); - } - - auto set_computes_arr = [](auto* __restrict res, auto* __restrict computes, size_t rows) { - for (size_t i = 0; i < rows; ++i) { - computes[i] = computes[i] == res[i]; - } - }; - set_computes_arr(cmp_res.data(), need_computes.data(), num_rows); - - return std::find(need_computes.begin(), need_computes.end(), 0) != need_computes.end(); - } - - return false; -} - -Status AggSharedState::reset_hash_table() { - return std::visit( - Overload { - [&](std::monostate& arg) -> Status { - return Status::InternalError("Uninited hash table"); - }, - [&](auto& agg_method) { - auto& hash_table = *agg_method.hash_table; - using HashTableType = std::decay_t; - - agg_method.arena.clear(); - agg_method.inited_iterator = false; - - if (!use_simple_count) { - hash_table.for_each_mapped([&](auto& mapped) { - if (mapped) { - _destroy_agg_status(mapped); - mapped = nullptr; - } - }); - - if (hash_table.has_null_key_data()) { - _destroy_agg_status( - hash_table.template get_null_key_data()); - } - - aggregate_data_container.reset(new AggregateDataContainer( - sizeof(typename HashTableType::key_type), - ((total_size_of_aggregate_states + align_aggregate_states - 1) / - align_aggregate_states) * - align_aggregate_states)); - } - agg_method.hash_table.reset(new HashTableType()); - return Status::OK(); - }}, - agg_data->method_variant); -} - void PartitionedAggSharedState::close() { for (auto& partition : _spill_partitions) { if (partition) { @@ -337,20 +220,6 @@ MultiCastSharedState::MultiCastSharedState(ObjectPool* pool, int cast_sender_cou : multi_cast_data_streamer( std::make_unique(pool, cast_sender_count, node_id)) {} -int AggSharedState::get_slot_column_id(const AggFnEvaluator* evaluator) { - auto ctxs = evaluator->input_exprs_ctxs(); - CHECK(ctxs.size() == 1 && ctxs[0]->root()->is_slot_ref()) - << 
"input_exprs_ctxs is invalid, input_exprs_ctx[0]=" - << ctxs[0]->root()->debug_string(); - return ((VSlotRef*)ctxs[0]->root().get())->column_id(); -} - -void AggSharedState::_destroy_agg_status(AggregateDataPtr data) { - for (int i = 0; i < aggregate_evaluators.size(); ++i) { - aggregate_evaluators[i]->function()->destroy(data + offsets_of_aggregate_states[i]); - } -} - LocalExchangeSharedState::~LocalExchangeSharedState() = default; Status SetSharedState::update_build_not_ignore_null(const VExprContextSPtrs& ctxs) { @@ -391,15 +260,4 @@ Status SetSharedState::hash_table_init() { return init_hash_method(hash_table_variants.get(), data_types, true); } -void AggSharedState::refresh_top_limit(size_t row_id, const ColumnRawPtrs& key_columns) { - for (int j = 0; j < key_columns.size(); ++j) { - limit_columns[j]->insert_from(*key_columns[j], row_id); - } - limit_heap.emplace(limit_columns[0]->size() - 1, limit_columns, order_directions, - null_directions); - - limit_heap.pop(); - limit_columns_min = limit_heap.top()._row_id; -} - } // namespace doris diff --git a/be/src/exec/pipeline/dependency.h b/be/src/exec/pipeline/dependency.h index 2f82d4f18e2fe0..210e98625dd761 100644 --- a/be/src/exec/pipeline/dependency.h +++ b/be/src/exec/pipeline/dependency.h @@ -38,7 +38,9 @@ #include "core/block/block.h" #include "core/types.h" #include "exec/common/agg_utils.h" +#include "exec/common/groupby_agg_context.h" #include "exec/common/join_utils.h" +#include "exec/common/ungroupby_agg_context.h" #include "exec/common/set_utils.h" #include "exec/operator/data_queue.h" #include "exec/operator/join/process_hash_table_probe.h" @@ -284,146 +286,34 @@ struct RuntimeFilterTimerQueue { struct AggSharedState : public BasicSharedState { ENABLE_FACTORY_CREATOR(AggSharedState) public: - AggSharedState() { agg_data = std::make_unique(); } + AggSharedState() = default; ~AggSharedState() override { - if (!probe_expr_ctxs.empty()) { - _close_with_serialized_key(); - } else { - 
_close_without_key(); + // Explicitly close contexts before destruction. close() is virtual and must be + // called while the derived object (e.g. InlineCountAggContext) is still alive, + // not from the base class destructor where vtable has already reverted. + // close() is idempotent: GroupByAggContext::close sets mapped=nullptr after destroy; + // UngroupByAggContext::close has _agg_state_created guard. + if (groupby_agg_ctx) { + groupby_agg_ctx->close(); + groupby_agg_ctx.reset(); + } + if (ungroupby_agg_ctx) { + ungroupby_agg_ctx->close(); + ungroupby_agg_ctx.reset(); } } - Status reset_hash_table(); - - bool do_limit_filter(Block* block, size_t num_rows, const std::vector* key_locs = nullptr); - void build_limit_heap(size_t hash_table_size); - - // We should call this function only at 1st phase. - // 1st phase: is_merge=true, only have one SlotRef. - // 2nd phase: is_merge=false, maybe have multiple exprs. - static int get_slot_column_id(const AggFnEvaluator* evaluator); - - AggregatedDataVariantsUPtr agg_data = nullptr; - std::unique_ptr aggregate_data_container; - std::vector aggregate_evaluators; - // group by k1,k2 - VExprContextSPtrs probe_expr_ctxs; - size_t input_num_rows = 0; - std::vector values; - /// The total size of the row from the aggregate functions. - size_t total_size_of_aggregate_states = 0; - size_t align_aggregate_states = 1; - /// The offset to the n-th aggregate function in a row of aggregate functions. - Sizes offsets_of_aggregate_states; + // Exactly one of these is non-null at runtime: + // groupby_agg_ctx — created when the query has GROUP BY + // ungroupby_agg_ctx — created when the query has no GROUP BY + std::unique_ptr groupby_agg_ctx; + std::unique_ptr ungroupby_agg_ctx; + + // Kept in AggSharedState (used by Source operators for output key conversion). std::vector make_nullable_keys; - bool agg_data_created_without_key = false; + // Spill support (set by Sink operator during open). 
bool enable_spill = false; - bool reach_limit = false; - - bool use_simple_count = false; - int64_t limit = -1; - bool do_sort_limit = false; - MutableColumns limit_columns; - int limit_columns_min = -1; - PaddedPODArray need_computes; - std::vector cmp_res; - std::vector order_directions; - std::vector null_directions; - - struct HeapLimitCursor { - HeapLimitCursor(int row_id, MutableColumns& limit_columns, - std::vector& order_directions, std::vector& null_directions) - : _row_id(row_id), - _limit_columns(limit_columns), - _order_directions(order_directions), - _null_directions(null_directions) {} - - HeapLimitCursor(const HeapLimitCursor& other) = default; - - HeapLimitCursor(HeapLimitCursor&& other) noexcept - : _row_id(other._row_id), - _limit_columns(other._limit_columns), - _order_directions(other._order_directions), - _null_directions(other._null_directions) {} - - HeapLimitCursor& operator=(const HeapLimitCursor& other) noexcept { - _row_id = other._row_id; - return *this; - } - - HeapLimitCursor& operator=(HeapLimitCursor&& other) noexcept { - _row_id = other._row_id; - return *this; - } - - bool operator<(const HeapLimitCursor& rhs) const { - for (int i = 0; i < _limit_columns.size(); ++i) { - const auto& _limit_column = _limit_columns[i]; - auto res = _limit_column->compare_at(_row_id, rhs._row_id, *_limit_column, - _null_directions[i]) * - _order_directions[i]; - if (res < 0) { - return true; - } else if (res > 0) { - return false; - } - } - return false; - } - - int _row_id; - MutableColumns& _limit_columns; - std::vector& _order_directions; - std::vector& _null_directions; - }; - - std::priority_queue limit_heap; - - // Refresh the top limit heap with a new row - void refresh_top_limit(size_t row_id, const ColumnRawPtrs& key_columns); - - Arena agg_arena_pool; - Arena agg_profile_arena; - -private: - MutableColumns _get_keys_hash_table(); - - void _close_with_serialized_key() { - std::visit(Overload {[&](std::monostate& arg) -> void { - // Do nothing 
- }, - [&](auto& agg_method) -> void { - if (use_simple_count) { - // Inline count: mapped slots hold UInt64, - // not real agg state pointers. Skip destroy. - return; - } - auto& data = *agg_method.hash_table; - data.for_each_mapped([&](auto& mapped) { - if (mapped) { - _destroy_agg_status(mapped); - mapped = nullptr; - } - }); - if (data.has_null_key_data()) { - _destroy_agg_status( - data.template get_null_key_data()); - } - }}, - agg_data->method_variant); - } - - void _close_without_key() { - //because prepare maybe failed, and couldn't create agg data. - //but finally call close to destory agg data, if agg data has bitmapValue - //will be core dump, it's not initialized - if (agg_data_created_without_key) { - _destroy_agg_status(agg_data->without_key); - agg_data_created_without_key = false; - } - } - void _destroy_agg_status(AggregateDataPtr data); }; struct PartitionedAggSharedState : public BasicSharedState, From 92d9f6605086b386bf238e552e62fa7d52e2fcbe Mon Sep 17 00:00:00 2001 From: Mryange Date: Thu, 26 Mar 2026 18:39:34 +0800 Subject: [PATCH 2/4] upd --- be/src/exec/common/groupby_agg_context.cpp | 14 ++++++-------- .../operator/streaming_aggregation_operator.cpp | 3 --- 2 files changed, 6 insertions(+), 11 deletions(-) diff --git a/be/src/exec/common/groupby_agg_context.cpp b/be/src/exec/common/groupby_agg_context.cpp index 4975cc6cc9a69b..c1eab047138e08 100644 --- a/be/src/exec/common/groupby_agg_context.cpp +++ b/be/src/exec/common/groupby_agg_context.cpp @@ -729,14 +729,12 @@ Status GroupByAggContext::get_serialized_results(RuntimeState* state, Block* blo if (agg_method.hash_table->has_null_key_data()) { DCHECK(key_columns.size() == 1); DCHECK(key_columns[0]->is_nullable()); - if (agg_method.hash_table->has_null_key_data()) { - key_columns[0]->insert_data(nullptr, 0); - _values[num_rows] = - agg_method.hash_table->template get_null_key_data< - AggregateDataPtr>(); - ++num_rows; - *eos = true; - } + key_columns[0]->insert_data(nullptr, 0); + 
_values[num_rows] = + agg_method.hash_table->template get_null_key_data< + AggregateDataPtr>(); + ++num_rows; + *eos = true; } else { *eos = true; } diff --git a/be/src/exec/operator/streaming_aggregation_operator.cpp b/be/src/exec/operator/streaming_aggregation_operator.cpp index 6cb4b4cf49b10b..9675d7d6534f86 100644 --- a/be/src/exec/operator/streaming_aggregation_operator.cpp +++ b/be/src/exec/operator/streaming_aggregation_operator.cpp @@ -513,9 +513,6 @@ Status StreamingAggLocalState::close(RuntimeState* state) { } SCOPED_TIMER(Base::exec_time_counter()); SCOPED_TIMER(Base::_close_timer); - if (Base::_closed) { - return Status::OK(); - } _pre_aggregated_block->clear(); PODArray tmp_places; _places.swap(tmp_places); From bcdabdbdb143fe9e8f4ec2fd1ae9c103ad4c1a7a Mon Sep 17 00:00:00 2001 From: Mryange Date: Fri, 27 Mar 2026 11:28:26 +0800 Subject: [PATCH 3/4] beut --- be/test/exec/operator/agg_operator_test.cpp | 13 ++-- .../exec/operator/agg_shared_state_test.cpp | 73 +++++++++++-------- ...itioned_aggregation_sink_operator_test.cpp | 36 ++++----- ...ioned_aggregation_source_operator_test.cpp | 22 +++--- .../operator/streaming_agg_operator_test.cpp | 16 ++-- .../partitioned_agg_shared_state_test.cpp | 37 ++++++---- 6 files changed, 109 insertions(+), 88 deletions(-) diff --git a/be/test/exec/operator/agg_operator_test.cpp b/be/test/exec/operator/agg_operator_test.cpp index 945fd0f9f1fc81..eed6e829e65dce 100644 --- a/be/test/exec/operator/agg_operator_test.cpp +++ b/be/test/exec/operator/agg_operator_test.cpp @@ -582,17 +582,18 @@ TEST_F(AggOperatorTestWithGroupBy, other_case_2) { { Block block = ColumnHelper::create_nullable_block( {1, 2, 3, 1, 2, 3}, {false, false, false, true, true, true}); - auto* local_state = - static_cast(ctx.state.get_sink_local_state()); ColumnRawPtrs key_columns; key_columns.push_back(block.get_by_position(0).column.get()); - local_state->_places.resize(block.rows()); - local_state->_emplace_into_hash_table(local_state->_places.data(), 
key_columns, - block.rows()); + auto* ctx = + static_cast(shared_state.get())->groupby_agg_ctx.get(); + std::vector places(block.rows()); + EXPECT_TRUE( + ctx->emplace_and_forward(places.data(), key_columns, block.rows(), &block, true) + .ok()); - EXPECT_EQ(local_state->get_hash_table_size(), 4); // [1,2,3,null] + EXPECT_EQ(ctx->hash_table_size(), 4); // [1,2,3,null] } } diff --git a/be/test/exec/operator/agg_shared_state_test.cpp b/be/test/exec/operator/agg_shared_state_test.cpp index aa9d1597539bc5..fda5276931a103 100644 --- a/be/test/exec/operator/agg_shared_state_test.cpp +++ b/be/test/exec/operator/agg_shared_state_test.cpp @@ -19,28 +19,41 @@ #include "core/column/column_vector.h" #include "core/data_type/data_type_number.h" -#include "exec/pipeline/dependency.h" +#include "exec/common/groupby_agg_context.h" namespace doris { -class AggSharedStateTest : public testing::Test { +// Expose protected sort-limit state for unit testing. +class TestableGroupByAggContext : public GroupByAggContext { +public: + TestableGroupByAggContext() + : GroupByAggContext(/*agg_evaluators=*/{}, /*groupby_expr_ctxs=*/{}, + /*agg_state_offsets=*/{}, /*total_agg_state_size=*/0, + /*agg_state_alignment=*/1, /*is_first_phase=*/true) {} + + MutableColumns& limit_columns() { return _limit_columns; } + int& limit_columns_min_ref() { return _limit_columns_min; } + std::priority_queue& limit_heap_ref() { return _limit_heap; } +}; + +class GroupByAggContextLimitTest : public testing::Test { protected: void SetUp() override { - _shared_state = std::make_shared(); + _ctx = std::make_shared(); // Setup test data auto int_type = std::make_shared(); - _shared_state->limit_columns.push_back(int_type->create_column()); + _ctx->limit_columns().push_back(int_type->create_column()); // Setup order directions (ascending) - _shared_state->order_directions = {1}; - _shared_state->null_directions = {1}; + _ctx->order_directions = {1}; + _ctx->null_directions = {1}; // Create test column _test_column = 
int_type->create_column(); auto* col_data = reinterpret_cast(_test_column.get()); - // Insert test values: 5, 3, 1, -2, -1, 0 + // Insert test values: 5, 3, 1, -1, 0, 2 col_data->insert(Field::create_field(5)); col_data->insert(Field::create_field(3)); col_data->insert(Field::create_field(1)); @@ -49,47 +62,47 @@ class AggSharedStateTest : public testing::Test { col_data->insert(Field::create_field(2)); _key_columns.push_back(_test_column.get()); - // prepare the heap data first [5, 3, 1, -2] + // prepare the heap data first [5, 3, 1, -1] for (int i = 0; i < 4; ++i) { - for (int j = 0; j < _key_columns.size(); ++j) { - _shared_state->limit_columns[j]->insert_from(*_key_columns[j], i); + for (size_t j = 0; j < _key_columns.size(); ++j) { + _ctx->limit_columns()[j]->insert_from(*_key_columns[j], i); } // build agg limit heap - _shared_state->limit_heap.emplace( - _shared_state->limit_columns[0]->size() - 1, _shared_state->limit_columns, - _shared_state->order_directions, _shared_state->null_directions); + _ctx->limit_heap_ref().emplace(_ctx->limit_columns()[0]->size() - 1, + _ctx->limit_columns(), _ctx->order_directions, + _ctx->null_directions); } - // keep the top limit values, only 3 value in heap [-1, 3, 1] - _shared_state->limit_heap.pop(); - _shared_state->limit_columns_min = _shared_state->limit_heap.top()._row_id; + // keep the top limit values, only 3 values in heap [-1, 3, 1] + _ctx->limit_heap_ref().pop(); + _ctx->limit_columns_min_ref() = _ctx->limit_heap_ref().top()._row_id; } - std::shared_ptr _shared_state; + std::shared_ptr _ctx; MutableColumnPtr _test_column; ColumnRawPtrs _key_columns; }; -TEST_F(AggSharedStateTest, TestRefreshTopLimit) { +TEST_F(GroupByAggContextLimitTest, TestRefreshTopLimit) { // Test with limit = 3 (keep top 3 values) - _shared_state->limit = 3; + _ctx->limit = 3; // Add values one by one and verify the minimum value is tracked correctly - EXPECT_EQ(_shared_state->limit_columns_min, 1); + EXPECT_EQ(_ctx->limit_columns_min_ref(), 
1); - _shared_state->refresh_top_limit(4, _key_columns); - EXPECT_EQ(_shared_state->limit_columns_min, 2); + _ctx->refresh_top_limit(4, _key_columns); + EXPECT_EQ(_ctx->limit_columns_min_ref(), 2); - _shared_state->refresh_top_limit(5, _key_columns); - EXPECT_EQ(_shared_state->limit_columns_min, 2); // 1 should still be max + _ctx->refresh_top_limit(5, _key_columns); + EXPECT_EQ(_ctx->limit_columns_min_ref(), 2); // 1 should still be max - auto heap_size = _shared_state->limit_heap.size(); + auto heap_size = _ctx->limit_heap_ref().size(); EXPECT_EQ(heap_size, 3); - EXPECT_EQ(_shared_state->limit_heap.top()._row_id, 2); // 1 should be the top value - _shared_state->limit_heap.pop(); - EXPECT_EQ(_shared_state->limit_heap.top()._row_id, 4); // 0 should be the top value - _shared_state->limit_heap.pop(); - EXPECT_EQ(_shared_state->limit_heap.top()._row_id, 3); // -1 should be the top value + EXPECT_EQ(_ctx->limit_heap_ref().top()._row_id, 2); // 1 should be the top value + _ctx->limit_heap_ref().pop(); + EXPECT_EQ(_ctx->limit_heap_ref().top()._row_id, 4); // 0 should be the top value + _ctx->limit_heap_ref().pop(); + EXPECT_EQ(_ctx->limit_heap_ref().top()._row_id, 3); // -1 should be the top value } } // namespace doris diff --git a/be/test/exec/operator/partitioned_aggregation_sink_operator_test.cpp b/be/test/exec/operator/partitioned_aggregation_sink_operator_test.cpp index 42cc320f1832eb..75d66ed47361b1 100644 --- a/be/test/exec/operator/partitioned_aggregation_sink_operator_test.cpp +++ b/be/test/exec/operator/partitioned_aggregation_sink_operator_test.cpp @@ -224,12 +224,12 @@ TEST_F(PartitionedAggregationSinkOperatorTest, SinkWithSpill) { auto* inner_sink_local_state = reinterpret_cast( local_state->_runtime_state->get_sink_local_state()); - ASSERT_GT(inner_sink_local_state->get_hash_table_size(), 0); + ASSERT_GT(inner_sink_local_state->_shared_state->groupby_agg_ctx->hash_table_size(), 0); st = sink_operator->revoke_memory(_helper.runtime_state.get()); 
ASSERT_TRUE(st.ok()) << "revoke_memory failed: " << st.to_string(); - ASSERT_EQ(inner_sink_local_state->get_hash_table_size(), 0); + ASSERT_EQ(inner_sink_local_state->_shared_state->groupby_agg_ctx->hash_table_size(), 0); st = sink_operator->sink(_helper.runtime_state.get(), &block, true); ASSERT_TRUE(st.ok()) << "sink failed: " << st.to_string(); @@ -283,12 +283,12 @@ TEST_F(PartitionedAggregationSinkOperatorTest, SinkWithSpillAndEmptyEOS) { auto* inner_sink_local_state = reinterpret_cast( local_state->_runtime_state->get_sink_local_state()); - ASSERT_GT(inner_sink_local_state->get_hash_table_size(), 0); + ASSERT_GT(inner_sink_local_state->_shared_state->groupby_agg_ctx->hash_table_size(), 0); st = sink_operator->revoke_memory(_helper.runtime_state.get()); ASSERT_TRUE(st.ok()) << "revoke_memory failed: " << st.to_string(); - ASSERT_EQ(inner_sink_local_state->get_hash_table_size(), 0); + ASSERT_EQ(inner_sink_local_state->_shared_state->groupby_agg_ctx->hash_table_size(), 0); block.clear_column_data(); st = sink_operator->sink(_helper.runtime_state.get(), &block, true); @@ -342,7 +342,7 @@ TEST_F(PartitionedAggregationSinkOperatorTest, SinkWithSpillLargeData) { auto* inner_sink_local_state = reinterpret_cast( local_state->_runtime_state->get_sink_local_state()); - ASSERT_GT(inner_sink_local_state->get_hash_table_size(), 0); + ASSERT_GT(inner_sink_local_state->_shared_state->groupby_agg_ctx->hash_table_size(), 0); st = sink_operator->revoke_memory(_helper.runtime_state.get()); ASSERT_TRUE(st.ok()) << "revoke_memory failed: " << st.to_string(); @@ -351,7 +351,7 @@ TEST_F(PartitionedAggregationSinkOperatorTest, SinkWithSpillLargeData) { ASSERT_TRUE(spill_write_rows_counter != nullptr); ASSERT_EQ(spill_write_rows_counter->value(), 4); - ASSERT_EQ(inner_sink_local_state->get_hash_table_size(), 0); + ASSERT_EQ(inner_sink_local_state->_shared_state->groupby_agg_ctx->hash_table_size(), 0); const size_t count = 1048576; std::vector data(count); @@ -412,7 +412,7 @@ 
TEST_F(PartitionedAggregationSinkOperatorTest, SinkWithSpilError) { auto* inner_sink_local_state = reinterpret_cast( local_state->_runtime_state->get_sink_local_state()); - ASSERT_GT(inner_sink_local_state->get_hash_table_size(), 0); + ASSERT_GT(inner_sink_local_state->_shared_state->groupby_agg_ctx->hash_table_size(), 0); SpillableDebugPointHelper dp_helper("fault_inject::spill_file::spill_block"); st = sink_operator->revoke_memory(_helper.runtime_state.get()); @@ -465,20 +465,20 @@ TEST_F(PartitionedAggregationSinkOperatorTest, SinkWithMultipleRevokes) { block.insert(ColumnHelper::create_column_with_name({1, 2, 3, 4, 5})); st = sink_operator->sink(_helper.runtime_state.get(), &block, false); ASSERT_TRUE(st.ok()) << "sink round 1 failed: " << st.to_string(); - ASSERT_GT(inner_sink_local_state->get_hash_table_size(), 0); + ASSERT_GT(inner_sink_local_state->_shared_state->groupby_agg_ctx->hash_table_size(), 0); st = sink_operator->revoke_memory(_helper.runtime_state.get()); ASSERT_TRUE(st.ok()) << "revoke round 1 failed: " << st.to_string(); - ASSERT_EQ(inner_sink_local_state->get_hash_table_size(), 0); + ASSERT_EQ(inner_sink_local_state->_shared_state->groupby_agg_ctx->hash_table_size(), 0); // Round 2: sink more → revoke again auto block2 = ColumnHelper::create_block({6, 7, 8, 9, 10}); block2.insert(ColumnHelper::create_column_with_name({6, 7, 8, 9, 10})); st = sink_operator->sink(_helper.runtime_state.get(), &block2, false); ASSERT_TRUE(st.ok()) << "sink round 2 failed: " << st.to_string(); - ASSERT_GT(inner_sink_local_state->get_hash_table_size(), 0); + ASSERT_GT(inner_sink_local_state->_shared_state->groupby_agg_ctx->hash_table_size(), 0); st = sink_operator->revoke_memory(_helper.runtime_state.get()); ASSERT_TRUE(st.ok()) << "revoke round 2 failed: " << st.to_string(); - ASSERT_EQ(inner_sink_local_state->get_hash_table_size(), 0); + ASSERT_EQ(inner_sink_local_state->_shared_state->groupby_agg_ctx->hash_table_size(), 0); ASSERT_TRUE(shared_state->_is_spilled); 
@@ -570,7 +570,7 @@ TEST_F(PartitionedAggregationSinkOperatorTest, GetHashTableSizeViaAggSinkOperato local_state->_runtime_state->get_sink_local_state()); // Hash table should be empty before any data is sinked - ASSERT_EQ(inner_sink_local_state->get_hash_table_size(), 0); + ASSERT_EQ(inner_sink_local_state->_shared_state->groupby_agg_ctx->hash_table_size(), 0); auto block = ColumnHelper::create_block({1, 2, 3, 4, 5}); block.insert(ColumnHelper::create_column_with_name({1, 2, 3, 4, 5})); @@ -578,13 +578,13 @@ TEST_F(PartitionedAggregationSinkOperatorTest, GetHashTableSizeViaAggSinkOperato ASSERT_TRUE(st.ok()) << "sink failed: " << st.to_string(); // Hash table should have entries after sinked data - ASSERT_GT(inner_sink_local_state->get_hash_table_size(), 0); + ASSERT_GT(inner_sink_local_state->_shared_state->groupby_agg_ctx->hash_table_size(), 0); st = sink_operator->revoke_memory(_helper.runtime_state.get()); ASSERT_TRUE(st.ok()) << "revoke_memory failed: " << st.to_string(); // Hash table should be cleared after revoke - ASSERT_EQ(inner_sink_local_state->get_hash_table_size(), 0); + ASSERT_EQ(inner_sink_local_state->_shared_state->groupby_agg_ctx->hash_table_size(), 0); st = sink_operator->close(_helper.runtime_state.get(), st); ASSERT_TRUE(st.ok()) << "close failed: " << st.to_string(); @@ -668,12 +668,12 @@ TEST_F(PartitionedAggregationNullableKeySinkTest, SinkEOSFlushNullKeyOnly) { block1.insert(ColumnHelper::create_column_with_name({42})); st = sink_operator->sink(_helper.runtime_state.get(), &block1, false); ASSERT_TRUE(st.ok()) << "first sink failed: " << st.to_string(); - ASSERT_EQ(inner_sink_local_state->get_hash_table_size(), 1); + ASSERT_EQ(inner_sink_local_state->_shared_state->groupby_agg_ctx->hash_table_size(), 1); // Spill to disk and mark as spilled. 
st = sink_operator->revoke_memory(_helper.runtime_state.get()); ASSERT_TRUE(st.ok()) << "revoke_memory failed: " << st.to_string(); - ASSERT_EQ(inner_sink_local_state->get_hash_table_size(), 0); + ASSERT_EQ(inner_sink_local_state->_shared_state->groupby_agg_ctx->hash_table_size(), 0); ASSERT_TRUE(shared_state->_is_spilled); auto* spill_write_rows_counter = local_state->custom_profile()->get_counter("SpillWriteRows"); @@ -689,7 +689,7 @@ TEST_F(PartitionedAggregationNullableKeySinkTest, SinkEOSFlushNullKeyOnly) { block2.insert(ColumnHelper::create_column_with_name({10})); st = sink_operator->sink(_helper.runtime_state.get(), &block2, false); ASSERT_TRUE(st.ok()) << "second sink failed: " << st.to_string(); - ASSERT_EQ(inner_sink_local_state->get_hash_table_size(), 1); + ASSERT_EQ(inner_sink_local_state->_shared_state->groupby_agg_ctx->hash_table_size(), 1); // EOS: send an empty block with eos=true. // Old code: aggregate_data_container->total_count() = 0 → SKIP flush → NULL key row LOST! @@ -699,7 +699,7 @@ TEST_F(PartitionedAggregationNullableKeySinkTest, SinkEOSFlushNullKeyOnly) { ASSERT_TRUE(st.ok()) << "EOS sink failed: " << st.to_string(); // Hash table must be empty after EOS flush. - ASSERT_EQ(inner_sink_local_state->get_hash_table_size(), 0); + ASSERT_EQ(inner_sink_local_state->_shared_state->groupby_agg_ctx->hash_table_size(), 0); ASSERT_FALSE(dep->is_blocked_by()); // Two NULL key aggregated rows were spilled (one per revoke/flush cycle). 
diff --git a/be/test/exec/operator/partitioned_aggregation_source_operator_test.cpp b/be/test/exec/operator/partitioned_aggregation_source_operator_test.cpp index f331b529ce5d5d..769aae198a782a 100644 --- a/be/test/exec/operator/partitioned_aggregation_source_operator_test.cpp +++ b/be/test/exec/operator/partitioned_aggregation_source_operator_test.cpp @@ -208,7 +208,7 @@ TEST_F(PartitionedAggregationSourceOperatorTest, GetBlock) { auto* inner_sink_local_state = reinterpret_cast( sink_local_state->_runtime_state->get_sink_local_state()); - ASSERT_GT(inner_sink_local_state->get_hash_table_size(), 0); + ASSERT_GT(inner_sink_local_state->_shared_state->groupby_agg_ctx->hash_table_size(), 0); LocalStateInfo info { .parent_profile = _helper.operator_profile.get(), @@ -302,7 +302,7 @@ TEST_F(PartitionedAggregationSourceOperatorTest, GetBlockWithSpill) { auto* inner_sink_local_state = reinterpret_cast( sink_local_state->_runtime_state->get_sink_local_state()); - ASSERT_EQ(inner_sink_local_state->get_hash_table_size(), 0); + ASSERT_EQ(inner_sink_local_state->_shared_state->groupby_agg_ctx->hash_table_size(), 0); LocalStateInfo info { .parent_profile = _helper.operator_profile.get(), @@ -401,7 +401,7 @@ TEST_F(PartitionedAggregationSourceOperatorTest, GetBlockWithSpillError) { auto* inner_sink_local_state = reinterpret_cast( sink_local_state->_runtime_state->get_sink_local_state()); - ASSERT_EQ(inner_sink_local_state->get_hash_table_size(), 0); + ASSERT_EQ(inner_sink_local_state->_shared_state->groupby_agg_ctx->hash_table_size(), 0); LocalStateInfo info { .parent_profile = _helper.operator_profile.get(), @@ -718,13 +718,15 @@ TEST_F(PartitionedAggregationSourceOperatorTest, RevocableMemSizeWithAggContaine auto agg_sptr = std::make_shared(); shared_state->_in_mem_shared_state_sptr = agg_sptr; shared_state->_in_mem_shared_state = agg_sptr.get(); - agg_sptr->aggregate_data_container = + agg_sptr->groupby_agg_ctx = std::make_unique( + std::vector{}, VExprContextSPtrs{}, 
Sizes{}, 0, 1, true); + agg_sptr->groupby_agg_ctx->_agg_data_container = std::make_unique(sizeof(uint32_t), 8); // ~13 sub-containers of 8192 entries each ≈ 1.28 MB → exceeds 1MB threshold for (uint32_t i = 0; i < 100000; ++i) { - agg_sptr->aggregate_data_container->append_data(i); + agg_sptr->groupby_agg_ctx->agg_data_container()->append_data(i); } - const size_t container_bytes = agg_sptr->aggregate_data_container->memory_usage(); + const size_t container_bytes = agg_sptr->groupby_agg_ctx->agg_data_container()->memory_usage(); ASSERT_GT(container_bytes, 1UL << 20); SpillFileSPtr spill_file; @@ -918,17 +920,17 @@ TEST_F(PartitionedAggregationSourceOperatorTest, FlushHashTableToSubSpillFilesSu auto* in_mem_state = shared_state->_in_mem_shared_state; ASSERT_NE(in_mem_state, nullptr); - ASSERT_NE(in_mem_state->aggregate_data_container, nullptr); + ASSERT_NE(in_mem_state->groupby_agg_ctx->agg_data_container(), nullptr); // Set up the repartitioner the same way _flush_and_repartition does. const int new_level = local_state->_current_partition.level + 1; const int fanout = static_cast(source_operator->_partition_count); - size_t num_keys = in_mem_state->probe_expr_ctxs.size(); + size_t num_keys = in_mem_state->groupby_agg_ctx->groupby_expr_ctxs().size(); std::vector key_column_indices(num_keys); std::vector key_data_types(num_keys); for (size_t i = 0; i < num_keys; ++i) { key_column_indices[i] = i; - key_data_types[i] = in_mem_state->probe_expr_ctxs[i]->root()->data_type(); + key_data_types[i] = in_mem_state->groupby_agg_ctx->groupby_expr_ctxs()[i]->root()->data_type(); } std::vector output_spill_files; ASSERT_TRUE(SpillRepartitioner::create_output_spill_files( @@ -1008,7 +1010,7 @@ TEST_F(PartitionedAggregationSourceOperatorTest, // never reached — the repartitioner does not need setup_output. 
auto* in_mem_state = shared_state->_in_mem_shared_state; ASSERT_NE(in_mem_state, nullptr); - ASSERT_NE(in_mem_state->aggregate_data_container, nullptr); + ASSERT_NE(in_mem_state->groupby_agg_ctx->agg_data_container(), nullptr); auto st = local_state->_flush_hash_table_to_sub_spill_files(_helper.runtime_state.get()); EXPECT_TRUE(st.ok()) << st.to_string(); diff --git a/be/test/exec/operator/streaming_agg_operator_test.cpp b/be/test/exec/operator/streaming_agg_operator_test.cpp index 0421d58bfd256b..9a7eb217df2623 100644 --- a/be/test/exec/operator/streaming_agg_operator_test.cpp +++ b/be/test/exec/operator/streaming_agg_operator_test.cpp @@ -134,7 +134,7 @@ TEST_F(StreamingAggOperatorTest, test1) { auto st = op->push(state.get(), &block, true); EXPECT_TRUE(st.ok()) << st.msg(); - EXPECT_EQ(local_state->_get_hash_table_size(), 3); + EXPECT_EQ(local_state->_groupby_agg_ctx->hash_table_size(), 3); EXPECT_TRUE(op->need_more_input_data(state.get())); } @@ -145,7 +145,7 @@ TEST_F(StreamingAggOperatorTest, test1) { auto st = op->push(state.get(), &block, true); EXPECT_TRUE(st.ok()) << st.msg(); - EXPECT_EQ(local_state->_get_hash_table_size(), 4); + EXPECT_EQ(local_state->_groupby_agg_ctx->hash_table_size(), 4); EXPECT_TRUE(op->need_more_input_data(state.get())); } @@ -190,7 +190,7 @@ TEST_F(StreamingAggOperatorTest, test2) { auto st = op->push(state.get(), &block, true); EXPECT_TRUE(st.ok()) << st.msg(); - EXPECT_EQ(local_state->_get_hash_table_size(), 3); + EXPECT_EQ(local_state->_groupby_agg_ctx->hash_table_size(), 3); EXPECT_TRUE(op->need_more_input_data(state.get())); } @@ -202,7 +202,7 @@ TEST_F(StreamingAggOperatorTest, test2) { auto st = op->push(state.get(), &block, true); EXPECT_TRUE(st.ok()) << st.msg(); - EXPECT_EQ(local_state->_get_hash_table_size(), 3); + EXPECT_EQ(local_state->_groupby_agg_ctx->hash_table_size(), 3); EXPECT_FALSE(op->need_more_input_data(state.get())); } @@ -268,7 +268,7 @@ TEST_F(StreamingAggOperatorTest, test3) { auto st = 
op->push(state.get(), &block, true); EXPECT_TRUE(st.ok()) << st.msg(); - EXPECT_EQ(local_state->_get_hash_table_size(), 3); + EXPECT_EQ(local_state->_groupby_agg_ctx->hash_table_size(), 3); EXPECT_TRUE(op->need_more_input_data(state.get())); } @@ -281,7 +281,7 @@ TEST_F(StreamingAggOperatorTest, test3) { auto st = op->push(state.get(), &block, true); EXPECT_TRUE(st.ok()) << st.msg(); - EXPECT_EQ(local_state->_get_hash_table_size(), 3); + EXPECT_EQ(local_state->_groupby_agg_ctx->hash_table_size(), 3); EXPECT_FALSE(op->need_more_input_data(state.get())); } @@ -352,7 +352,7 @@ TEST_F(StreamingAggOperatorTest, test4) { auto st = op->push(state.get(), &block, true); EXPECT_TRUE(st.ok()) << st.msg(); - EXPECT_EQ(local_state->_get_hash_table_size(), 3); + EXPECT_EQ(local_state->_groupby_agg_ctx->hash_table_size(), 3); EXPECT_TRUE(op->need_more_input_data(state.get())); } @@ -368,7 +368,7 @@ TEST_F(StreamingAggOperatorTest, test4) { auto st = op->push(state.get(), &block, true); EXPECT_TRUE(st.ok()) << st.msg(); - EXPECT_EQ(local_state->_get_hash_table_size(), 4); + EXPECT_EQ(local_state->_groupby_agg_ctx->hash_table_size(), 4); EXPECT_TRUE(op->need_more_input_data(state.get())); } diff --git a/be/test/exec/pipeline/partitioned_agg_shared_state_test.cpp b/be/test/exec/pipeline/partitioned_agg_shared_state_test.cpp index 4118d923e57be6..7558fbebeada5f 100644 --- a/be/test/exec/pipeline/partitioned_agg_shared_state_test.cpp +++ b/be/test/exec/pipeline/partitioned_agg_shared_state_test.cpp @@ -21,6 +21,7 @@ #include #include "exec/common/agg_utils.h" +#include "exec/common/groupby_agg_context.h" #include "exec/pipeline/dependency.h" #include "exec/spill/spill_file_manager.h" #include "io/fs/local_file_system.h" @@ -156,22 +157,22 @@ TEST_F(PartitionedAggSharedStateTest, InMemSharedStateDefaultsNull) { EXPECT_EQ(state._in_mem_shared_state, nullptr); } -// Hash table contribution: AggSharedState constructor always creates agg_data. 
+// Hash table contribution: GroupByAggContext always creates hash_table_data. TEST_F(PartitionedAggSharedStateTest, AggSharedStateCreatesNonNullAggData) { - AggSharedState agg_state; - EXPECT_NE(agg_state.agg_data, nullptr); + GroupByAggContext ctx({}, {}, {}, 0, 1, true); + EXPECT_NE(ctx.hash_table_data(), nullptr); } // Hash table contribution: default method_variant is monostate (index 0) → 0 bytes. TEST_F(PartitionedAggSharedStateTest, AggSharedStateDefaultVariantIsMonostate) { - AggSharedState agg_state; - EXPECT_EQ(agg_state.agg_data->method_variant.index(), 0); + GroupByAggContext ctx({}, {}, {}, 0, 1, true); + EXPECT_EQ(ctx.hash_table_data()->method_variant.index(), 0); } -// Container contribution: aggregate_data_container defaults to null → 0 bytes. +// Container contribution: agg_data_container defaults to null → 0 bytes. TEST_F(PartitionedAggSharedStateTest, AggSharedStateAggContainerDefaultsNull) { - AggSharedState agg_state; - EXPECT_EQ(agg_state.aggregate_data_container, nullptr); + GroupByAggContext ctx({}, {}, {}, 0, 1, true); + EXPECT_EQ(ctx.agg_data_container(), nullptr); } // Container contribution: freshly constructed container has 0 memory_usage. @@ -193,29 +194,33 @@ TEST_F(PartitionedAggSharedStateTest, AggregateDataContainerMemoryGrowsAfterAppe // with monostate variant and null container → 0 bytes from both sources. 
TEST_F(PartitionedAggSharedStateTest, PartitionedAggStateLinkedToAggStateWithDefaultData) { AggSharedState agg_state; + agg_state.groupby_agg_ctx = std::make_unique( + std::vector{}, VExprContextSPtrs{}, Sizes{}, 0, 1, true); PartitionedAggSharedState state; state._in_mem_shared_state = &agg_state; state._is_spilled = true; EXPECT_NE(state._in_mem_shared_state, nullptr); - EXPECT_NE(state._in_mem_shared_state->agg_data, nullptr); + EXPECT_NE(state._in_mem_shared_state->groupby_agg_ctx->hash_table_data(), nullptr); // monostate → hash table contributes 0 bytes - EXPECT_EQ(state._in_mem_shared_state->agg_data->method_variant.index(), 0); + EXPECT_EQ(state._in_mem_shared_state->groupby_agg_ctx->hash_table_data()->method_variant.index(), 0); // null container → container contributes 0 bytes - EXPECT_EQ(state._in_mem_shared_state->aggregate_data_container, nullptr); + EXPECT_EQ(state._in_mem_shared_state->groupby_agg_ctx->agg_data_container(), nullptr); } // Container contribution through AggSharedState: memory_usage reflects arena allocation. 
TEST_F(PartitionedAggSharedStateTest, AggSharedStateContainerMemoryUsage) { AggSharedState agg_state; - agg_state.aggregate_data_container = + agg_state.groupby_agg_ctx = std::make_unique( + std::vector{}, VExprContextSPtrs{}, Sizes{}, 0, 1, true); + agg_state.groupby_agg_ctx->_agg_data_container = std::make_unique(sizeof(uint32_t), 8); - ASSERT_NE(agg_state.aggregate_data_container, nullptr); - EXPECT_EQ(agg_state.aggregate_data_container->memory_usage(), 0); + ASSERT_NE(agg_state.groupby_agg_ctx->agg_data_container(), nullptr); + EXPECT_EQ(agg_state.groupby_agg_ctx->agg_data_container()->memory_usage(), 0); uint32_t key = 99; - agg_state.aggregate_data_container->append_data(key); - EXPECT_GT(agg_state.aggregate_data_container->memory_usage(), 0); + agg_state.groupby_agg_ctx->agg_data_container()->append_data(key); + EXPECT_GT(agg_state.groupby_agg_ctx->agg_data_container()->memory_usage(), 0); } } // namespace doris \ No newline at end of file From 4e6ffa3cd15ccc893a43c789d33be4221cb7f713 Mon Sep 17 00:00:00 2001 From: Mryange Date: Fri, 27 Mar 2026 14:29:48 +0800 Subject: [PATCH 4/4] fix beut --- be/test/exec/operator/agg_operator_test.cpp | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/be/test/exec/operator/agg_operator_test.cpp b/be/test/exec/operator/agg_operator_test.cpp index eed6e829e65dce..ad395bf2b523c8 100644 --- a/be/test/exec/operator/agg_operator_test.cpp +++ b/be/test/exec/operator/agg_operator_test.cpp @@ -589,9 +589,10 @@ TEST_F(AggOperatorTestWithGroupBy, other_case_2) { auto* ctx = static_cast(shared_state.get())->groupby_agg_ctx.get(); std::vector places(block.rows()); - EXPECT_TRUE( - ctx->emplace_and_forward(places.data(), key_columns, block.rows(), &block, true) - .ok()); + ctx->emplace_into_hash_table(places.data(), key_columns, block.rows(), + ctx->hash_table_compute_timer(), + ctx->hash_table_emplace_timer(), + ctx->hash_table_input_counter()); EXPECT_EQ(ctx->hash_table_size(), 4); // [1,2,3,null] }