Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Enhancement] count roaring bitmap memory usage in aggregator #49206

Merged
merged 17 commits
Aug 8, 2024
Merged
6 changes: 0 additions & 6 deletions be/src/column/hash_set.h
Original file line number Diff line number Diff line change
Expand Up @@ -27,10 +27,6 @@ namespace starrocks {
template <typename T>
using HashSet = phmap::flat_hash_set<T, StdHash<T>>;

template <typename T>
using HashSetWithMemoryCounting =
phmap::flat_hash_set<T, StdHash<T>, phmap::priv::hash_default_eq<T>, CountingAllocator<T>>;

// By storing hash value in slice, we can save the cost of
// 1. re-calculate hash value of the slice
// 2. touch slice memory area which may cause high latency of memory access.
Expand Down Expand Up @@ -84,8 +80,6 @@ class TEqualOnSliceWithHash {
};

using SliceHashSet = phmap::flat_hash_set<SliceWithHash, HashOnSliceWithHash, EqualOnSliceWithHash>;
using SliceHashSetWithMemoryCounting = phmap::flat_hash_set<SliceWithHash, HashOnSliceWithHash, EqualOnSliceWithHash,
CountingAllocator<SliceWithHash>>;

using SliceNormalHashSet = phmap::flat_hash_set<Slice, SliceHash, SliceNormalEqual>;

Expand Down
28 changes: 26 additions & 2 deletions be/src/exec/aggregator.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -29,10 +29,12 @@
#include "exec/limited_pipeline_chunk_buffer.h"
#include "exec/pipeline/operator.h"
#include "exec/spill/spiller.hpp"
#include "exprs/agg/aggregate_state_allocator.h"
#include "exprs/anyval_util.h"
#include "gen_cpp/PlanNodes_types.h"
#include "runtime/current_thread.h"
#include "runtime/descriptors.h"
#include "runtime/memory/roaring_hook.h"
#include "types/logical_type.h"
#include "udf/java/utils.h"
#include "util/runtime_profile.h"
Expand Down Expand Up @@ -201,7 +203,23 @@ void AggregatorParams::init() {
#define ALIGN_TO(size, align) ((size + align - 1) / align * align)
#define PAD(size, align) (align - (size % align)) % align;

Aggregator::Aggregator(AggregatorParamsPtr params) : _params(std::move(params)) {}
// RAII guard that installs `allocator` as both the thread-local aggregate-state
// allocator and the thread-local roaring-bitmap allocator for the enclosing
// scope. The previous allocators are restored automatically when the two member
// setters are destroyed.
class ThreadLocalStateAllocatorSetter {
public:
    // `explicit`: a scoped guard should never be created by implicit conversion
    // from a bare Allocator*.
    explicit ThreadLocalStateAllocatorSetter(Allocator* allocator)
            : _agg_state_allocator_setter(allocator), _roaring_allocator_setter(allocator) {}
    ~ThreadLocalStateAllocatorSetter() = default;

    // Copying a scoped guard would restore the saved allocators twice.
    // (Direct-initialization in SCOPED_THREAD_LOCAL_STATE_ALLOCATOR_SETTER is
    // unaffected thanks to C++17 guaranteed copy elision.)
    ThreadLocalStateAllocatorSetter(const ThreadLocalStateAllocatorSetter&) = delete;
    ThreadLocalStateAllocatorSetter& operator=(const ThreadLocalStateAllocatorSetter&) = delete;

private:
    ThreadLocalAggregateStateAllocatorSetter _agg_state_allocator_setter;
    ThreadLocalRoaringAllocatorSetter _roaring_allocator_setter;
};

// Installs `allocator` for both agg-state and roaring-bitmap allocations until
// the end of the enclosing scope (VARNAME_LINENUM gives the guard a unique name).
#define SCOPED_THREAD_LOCAL_STATE_ALLOCATOR_SETTER(allocator) \
auto VARNAME_LINENUM(alloc_setter) = ThreadLocalStateAllocatorSetter(allocator)

// Construct the aggregator; the counting allocator tracks heap memory used by
// aggregate states (and roaring bitmaps) owned by this aggregator.
Aggregator::Aggregator(AggregatorParamsPtr params)
        : _params(std::move(params)), _allocator(std::make_unique<CountingAllocatorWithHook>()) {}

Status Aggregator::open(RuntimeState* state) {
if (_is_opened) {
Expand Down Expand Up @@ -529,6 +547,7 @@ Status Aggregator::reset_state(starrocks::RuntimeState* state, const std::vector
}

Status Aggregator::_reset_state(RuntimeState* state, bool reset_sink_complete) {
SCOPED_THREAD_LOCAL_STATE_ALLOCATOR_SETTER(_allocator.get());
_is_ht_eos = false;
_num_input_rows = 0;
_is_prepared = false;
Expand Down Expand Up @@ -623,6 +642,7 @@ void Aggregator::close(RuntimeState* state) {
if (_mem_pool != nullptr) {
// Note: we must free agg_states object before _mem_pool free_all;
if (_single_agg_state != nullptr) {
SCOPED_THREAD_LOCAL_STATE_ALLOCATOR_SETTER(_allocator.get());
for (int i = 0; i < _agg_functions.size(); i++) {
_agg_functions[i]->destroy(_agg_fn_ctxs[i], _single_agg_state + _agg_states_offsets[i]);
}
Expand Down Expand Up @@ -747,6 +767,7 @@ Status Aggregator::compute_single_agg_state(Chunk* chunk, size_t chunk_size) {
for (size_t i = 0; i < _agg_fn_ctxs.size(); i++) {
// evaluate arguments at i-th agg function
RETURN_IF_ERROR(evaluate_agg_input_column(chunk, agg_expr_ctxs[i], i));
SCOPED_THREAD_LOCAL_STATE_ALLOCATOR_SETTER(_allocator.get());
// batch call update or merge for singe stage
if (!_is_merge_funcs[i] && !use_intermediate) {
_agg_functions[i]->update_batch_single_state(_agg_fn_ctxs[i], chunk_size, _agg_input_raw_columns[i].data(),
Expand All @@ -769,6 +790,7 @@ Status Aggregator::compute_batch_agg_states(Chunk* chunk, size_t chunk_size) {
for (size_t i = 0; i < _agg_fn_ctxs.size(); i++) {
// evaluate arguments at i-th agg function
RETURN_IF_ERROR(evaluate_agg_input_column(chunk, agg_expr_ctxs[i], i));
SCOPED_THREAD_LOCAL_STATE_ALLOCATOR_SETTER(_allocator.get());
// batch call update or merge
if (!_is_merge_funcs[i] && !use_intermediate) {
_agg_functions[i]->update_batch(_agg_fn_ctxs[i], chunk_size, _agg_states_offsets[i],
Expand All @@ -790,7 +812,7 @@ Status Aggregator::compute_batch_agg_states_with_selection(Chunk* chunk, size_t

for (size_t i = 0; i < _agg_fn_ctxs.size(); i++) {
RETURN_IF_ERROR(evaluate_agg_input_column(chunk, agg_expr_ctxs[i], i));

SCOPED_THREAD_LOCAL_STATE_ALLOCATOR_SETTER(_allocator.get());
if (!_is_merge_funcs[i] && !use_intermediate) {
_agg_functions[i]->update_batch_selectively(_agg_fn_ctxs[i], chunk_size, _agg_states_offsets[i],
_agg_input_raw_columns[i].data(), _tmp_agg_states.data(),
Expand Down Expand Up @@ -1074,6 +1096,7 @@ void Aggregator::_finalize_to_chunk(ConstAggDataPtr __restrict state, const Colu
}

void Aggregator::_destroy_state(AggDataPtr __restrict state) {
SCOPED_THREAD_LOCAL_STATE_ALLOCATOR_SETTER(_allocator.get());
for (size_t i = 0; i < _agg_fn_ctxs.size(); i++) {
_agg_functions[i]->destroy(_agg_fn_ctxs[i], state + _agg_states_offsets[i]);
}
Expand Down Expand Up @@ -1539,6 +1562,7 @@ void Aggregator::_release_agg_memory() {
// If all function states are of POD type,
// then we don't have to traverse the hash table to call destroy method.
//
SCOPED_THREAD_LOCAL_STATE_ALLOCATOR_SETTER(_allocator.get());
_hash_map_variant.visit([&](auto& hash_map_with_key) {
bool skip_destroy = std::all_of(_agg_functions.begin(), _agg_functions.end(),
[](auto* func) { return func->is_pod_state(); });
silverbullet233 marked this conversation as resolved.
Show resolved Hide resolved
Expand Down
8 changes: 6 additions & 2 deletions be/src/exec/aggregator.h
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@
#include "runtime/current_thread.h"
#include "runtime/descriptors.h"
#include "runtime/mem_pool.h"
#include "runtime/memory/counting_allocator.h"
#include "runtime/runtime_state.h"
#include "runtime/types.h"
#include "util/defer_op.h"
Expand Down Expand Up @@ -294,12 +295,13 @@ class Aggregator : public pipeline::ContextWithDependency {
const int64_t hash_map_memory_usage() const { return _hash_map_variant.reserved_memory_usage(mem_pool()); }
const int64_t hash_set_memory_usage() const { return _hash_set_variant.reserved_memory_usage(mem_pool()); }
const int64_t agg_state_memory_usage() const { return _agg_state_mem_usage; }
const int64_t allocator_memory_usage() const { return _allocator->memory_usage(); }

const int64_t memory_usage() const {
if (is_hash_set()) {
return hash_set_memory_usage() + agg_state_memory_usage();
return hash_set_memory_usage() + agg_state_memory_usage() + allocator_memory_usage();
} else if (!_group_by_expr_ctxs.empty()) {
return hash_map_memory_usage() + agg_state_memory_usage();
return hash_map_memory_usage() + agg_state_memory_usage() + allocator_memory_usage();
} else {
return 0;
}
Expand Down Expand Up @@ -408,6 +410,8 @@ class Aggregator : public pipeline::ContextWithDependency {

ObjectPool* _pool;
std::unique_ptr<MemPool> _mem_pool;
// used to count heap memory usage of agg states
std::unique_ptr<CountingAllocatorWithHook> _allocator;
// The open phase still relies on the TFunction object for some initialization operations
std::vector<TFunction> _fns;

Expand Down
90 changes: 90 additions & 0 deletions be/src/exprs/agg/aggregate_state_allocator.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,90 @@
// Copyright 2021-present StarRocks, Inc. All rights reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// https://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#pragma once

#include "column/hash_set.h"
#include "common/config.h"
#include "runtime/memory/allocator.h"

namespace starrocks {

inline thread_local Allocator* tls_agg_state_allocator = nullptr;

// Minimal standard-allocator adaptor that forwards every (de)allocation to the
// thread-local `tls_agg_state_allocator`. Containers instantiated with this
// allocator (see HashSetWithAggStateAllocator) therefore charge their memory to
// whichever Allocator the surrounding scope installed via
// ThreadLocalAggregateStateAllocatorSetter.
//
// All instances are stateless and interchangeable, so the propagation traits
// are enabled and equality always holds.
template <class T>
class AggregateStateAllocator {
public:
    using value_type = T;
    using size_type = size_t;
    using propagate_on_container_copy_assignment = std::true_type; // for consistency
    using propagate_on_container_move_assignment = std::true_type; // to avoid the pessimization
    using propagate_on_container_swap = std::true_type;            // to avoid the undefined behavior

    template <typename U>
    struct rebind {
        using other = AggregateStateAllocator<U>;
    };
    AggregateStateAllocator() = default;
    // Converting constructor required for rebinding; intentionally implicit per
    // the Allocator named requirements.
    template <class U>
    AggregateStateAllocator(const AggregateStateAllocator<U>& other) {}

    ~AggregateStateAllocator() = default;

    T* allocate(size_t n) {
        // A thread-local allocator must have been installed by the caller.
        DCHECK(tls_agg_state_allocator != nullptr);
        return static_cast<T*>(tls_agg_state_allocator->alloc(n * sizeof(T)));
    }

    void deallocate(T* ptr, size_t n) {
        DCHECK(tls_agg_state_allocator != nullptr);
        tls_agg_state_allocator->free(ptr);
    }

    AggregateStateAllocator& operator=(const AggregateStateAllocator& rhs) = default;

    template <class U>
    AggregateStateAllocator& operator=(const AggregateStateAllocator<U>& rhs) {
        return *this;
    }

    // Stateless: any two instances compare equal.
    bool operator==(const AggregateStateAllocator& rhs) const { return true; }

    bool operator!=(const AggregateStateAllocator& rhs) const { return false; }

    // No per-instance state to exchange.
    void swap(AggregateStateAllocator& rhs) {}
};
// ADL-visible swap for AggregateStateAllocator; a no-op since the allocator is stateless.
template <class T>
void swap(AggregateStateAllocator<T>& lhs, AggregateStateAllocator<T>& rhs) {
lhs.swap(rhs);
}

// RAII guard that installs `allocator` as the thread-local aggregate-state
// allocator for the current scope and restores the previous value on
// destruction. Saving `_prev` makes nested guards compose correctly.
class ThreadLocalAggregateStateAllocatorSetter {
public:
    // `explicit`: a scoped guard should never be created by implicit conversion.
    explicit ThreadLocalAggregateStateAllocatorSetter(Allocator* allocator) {
        _prev = tls_agg_state_allocator;
        tls_agg_state_allocator = allocator;
    }
    ~ThreadLocalAggregateStateAllocatorSetter() { tls_agg_state_allocator = _prev; }

    // Copying a scoped guard would restore the saved allocator twice.
    ThreadLocalAggregateStateAllocatorSetter(const ThreadLocalAggregateStateAllocatorSetter&) = delete;
    ThreadLocalAggregateStateAllocatorSetter& operator=(const ThreadLocalAggregateStateAllocatorSetter&) = delete;

private:
    Allocator* _prev = nullptr;
};

// flat_hash_set variants whose storage is obtained through the thread-local
// aggregate-state allocator, so their memory is counted against the
// aggregator that installed it (replacing the former *WithMemoryCounting aliases).
template <typename T>
using HashSetWithAggStateAllocator =
phmap::flat_hash_set<T, StdHash<T>, phmap::priv::hash_default_eq<T>, AggregateStateAllocator<T>>;

using SliceHashSetWithAggStateAllocator = phmap::flat_hash_set<SliceWithHash, HashOnSliceWithHash, EqualOnSliceWithHash,
AggregateStateAllocator<SliceWithHash>>;
} // namespace starrocks
Loading
Loading