Skip to content

Commit

Permalink
[Enhancement] optimize histogram implementation (StarRocks#52400)
Browse files Browse the repository at this point in the history
Signed-off-by: Murphy <[email protected]>
  • Loading branch information
murphyatwork authored Nov 12, 2024
1 parent b523e82 commit ac8dcb3
Show file tree
Hide file tree
Showing 7 changed files with 253 additions and 86 deletions.
5 changes: 5 additions & 0 deletions be/src/column/column_helper.h
Original file line number Diff line number Diff line change
Expand Up @@ -266,6 +266,11 @@ class ColumnHelper {
return down_cast<RunTimeColumnType<Type>*>(value);
}

// Casts a const Column pointer down to the concrete column class selected by `Type`.
// The conversion goes through down_cast, so the caller is expected to know that
// `value` really is a RunTimeColumnType<Type> (NOTE(review): whether down_cast
// verifies this at runtime is not visible here — confirm).
template <LogicalType Type>
static inline const RunTimeColumnType<Type>* cast_to_raw(const Column* value) {
    using TargetColumn = RunTimeColumnType<Type>;
    return down_cast<const TargetColumn*>(value);
}

/**
* Cast a ColumnPtr to a ColumnPtr of the specified type.
* The caller must make sure the actual column type matches.
Expand Down
10 changes: 10 additions & 0 deletions be/src/exprs/agg/aggregate_traits.h
Original file line number Diff line number Diff line change
Expand Up @@ -38,9 +38,12 @@ struct AggDataTypeTraits<lt, FixedLengthLTGuard<lt>> {
// Appends `value` to the end of `column`.
static void append_value(ColumnType* column, const ValueType& value) { column->append(value); }

// Returns the element at index `row` of the column's data buffer.
static RefType get_row_ref(const ColumnType& column, size_t row) { return column.get_data()[row]; }
// Converts a stored value into its reference form; for this specialization the value is returned as-is.
static RefType get_ref(const ValueType& value) { return value; }

// Replaces `current` with the larger of `current` and `input` (compared as ValueType).
static void update_max(ValueType& current, const RefType& input) { current = std::max<ValueType>(current, input); }
// Replaces `current` with the smaller of `current` and `input` (compared as ValueType).
static void update_min(ValueType& current, const RefType& input) { current = std::min<ValueType>(current, input); }

// Compares the two references with operator==.
static bool is_equal(const RefType& lhs, const RefType& rhs) { return lhs == rhs; }
};

// For pointer ref types
Expand All @@ -55,11 +58,14 @@ struct AggDataTypeTraits<lt, ObjectFamilyLTGuard<lt>> {
// Overwrites the object stored at `row` of `column` with a copy of `ref`.
static void assign_value(ColumnType* column, size_t row, const ValueType& ref) { *column->get_object(row) = ref; }

// Appends `value` to `column`; the column's append is given the address of `value`.
static void append_value(ColumnType* column, const ValueType& value) { column->append(&value); }
// Returns the address of `value` as its reference form, so the result is only
// usable while the referenced `value` is alive.
static RefType get_ref(const ValueType& value) { return &value; }

// Returns what `column.get_object(row)` yields for the given row.
// NOTE(review): the top-level `const` on the return type looks redundant — confirm before removing.
static const RefType get_row_ref(const ColumnType& column, size_t row) { return column.get_object(row); }

// Dereferences `input` and keeps the larger of `current` and `*input` in `current`.
static void update_max(ValueType& current, const RefType& input) { current = std::max<ValueType>(current, *input); }
// Dereferences `input` and keeps the smaller of `current` and `*input` in `current`.
static void update_min(ValueType& current, const RefType& input) { current = std::min<ValueType>(current, *input); }

// Compares the pointed-to objects rather than the pointers; both arguments are dereferenced.
static bool is_equal(const RefType& lhs, const RefType& rhs) { return *lhs == *rhs; }
};

template <LogicalType lt>
Expand All @@ -79,6 +85,8 @@ struct AggDataTypeTraits<lt, StringLTGuard<lt>> {

// Returns the slice stored at `row`, as produced by the column's get_slice().
static RefType get_row_ref(const ColumnType& column, size_t row) { return column.get_slice(row); }

// Builds a Slice from the value's data pointer and size.
// NOTE(review): presumably a non-owning view, so `value` must outlive the result — confirm against Slice.
static RefType get_ref(const ValueType& value) { return Slice(value.data(), value.size()); }

static void update_max(ValueType& current, const RefType& input) {
if (Slice(current.data(), current.size()).compare(input) < 0) {
current.resize(input.size);
Expand All @@ -91,6 +99,8 @@ struct AggDataTypeTraits<lt, StringLTGuard<lt>> {
memcpy(current.data(), input.data, input.size);
}
}

// Compares the two references with operator== of RefType.
static bool is_equal(const RefType& lhs, const RefType& rhs) { return lhs == rhs; }
};

template <LogicalType lt>
Expand Down
160 changes: 78 additions & 82 deletions be/src/exprs/agg/histogram.h
Original file line number Diff line number Diff line change
Expand Up @@ -15,23 +15,42 @@
#pragma once

#include "column/column_helper.h"
#include "column/object_column.h"
#include "column/column_viewer.h"
#include "column/datum_convert.h"
#include "column/nullable_column.h"
#include "column/type_traits.h"
#include "column/vectorized_fwd.h"
#include "exprs/agg/aggregate.h"
#include "exprs/agg/aggregate_traits.h"
#include "gutil/casts.h"
#include "runtime/large_int_value.h"
#include "storage/types.h"

namespace starrocks {

template <typename T>
template <LogicalType LT>
struct Bucket {
public:
using RefType = AggDataRefType<LT>;
using ValueType = AggDataValueType<LT>;

Bucket() = default;
Bucket(T lower, T upper, size_t count, size_t upper_repeats)
: lower(lower), upper(upper), count(count), upper_repeats(upper_repeats), count_in_bucket(1) {}
T lower;
T upper;

Bucket(RefType input_lower, RefType input_upper, size_t count, size_t upper_repeats)
: count(count), upper_repeats(upper_repeats), count_in_bucket(1) {
AggDataTypeTraits<LT>::assign_value(lower, input_lower);
AggDataTypeTraits<LT>::assign_value(upper, input_upper);
}

bool is_equals_to_upper(RefType value) {
return AggDataTypeTraits<LT>::is_equal(value, AggDataTypeTraits<LT>::get_ref(upper));
}

void update_upper(RefType value) { AggDataTypeTraits<LT>::assign_value(upper, value); }

Datum get_lower_datum() { return Datum(AggDataTypeTraits<LT>::get_ref(lower)); }
Datum get_upper_datum() { return Datum(AggDataTypeTraits<LT>::get_ref(upper)); }

ValueType lower;
ValueType upper;
// Cumulative number of values up to and including this bucket
int64_t count;
// The number of values that are equal to the upper boundary
Expand All @@ -40,140 +59,117 @@ struct Bucket {
int64_t count_in_bucket;
};

template <typename T>
template <LogicalType LT>
struct HistogramState {
HistogramState() = default;
std::vector<T> data;
HistogramState() {
auto data = RunTimeColumnType<LT>::create();
column = NullableColumn::create(data, NullColumn::create());
}

ColumnPtr column;
};

template <LogicalType LT, typename T = RunTimeCppType<LT>>
class HistogramAggregationFunction final
: public AggregateFunctionBatchHelper<HistogramState<T>, HistogramAggregationFunction<LT, T>> {
: public AggregateFunctionBatchHelper<HistogramState<LT>, HistogramAggregationFunction<LT, T>> {
public:
using ColumnType = RunTimeColumnType<LT>;

void update(FunctionContext* ctx, const Column** columns, AggDataPtr __restrict state,
size_t row_num) const override {
T v;
if (columns[0]->is_nullable()) {
if (columns[0]->is_null(row_num)) {
return;
}

const auto* data_column = down_cast<const NullableColumn*>(columns[0]);
v = down_cast<const ColumnType*>(data_column->data_column().get())->get_data()[row_num];
} else {
v = down_cast<const ColumnType*>(columns[0])->get_data()[row_num];
}
CHECK(false);
}

this->data(state).data.emplace_back(v);
void update_batch_single_state(FunctionContext* ctx, size_t chunk_size, const Column** columns,
AggDataPtr __restrict state) const override {
this->data(state).column->append(*columns[0], 0, chunk_size);
}

void update_batch_single_state_with_frame(FunctionContext* ctx, AggDataPtr __restrict state, const Column** columns,
int64_t peer_group_start, int64_t peer_group_end, int64_t frame_start,
int64_t frame_end) const override {
// The histogram aggregation function only supports one-stage aggregation
DCHECK(false);
CHECK(false);
}

void merge(FunctionContext* ctx, const Column* column, AggDataPtr __restrict state, size_t row_num) const override {
// The histogram aggregation function only supports one-stage aggregation
DCHECK(false);
CHECK(false);
}

void serialize_to_column(FunctionContext* ctx __attribute__((unused)), ConstAggDataPtr __restrict state,
Column* to) const override {
// The histogram aggregation function only supports one-stage aggregation
DCHECK(false);
CHECK(false);
}

void convert_to_serialize_format(FunctionContext* ctx, const Columns& src, size_t chunk_size,
ColumnPtr* dst) const override {
// The histogram aggregation function only supports one-stage aggregation
DCHECK(false);
CHECK(false);
}

std::string toBucketJson(const std::string& lower, const std::string& upper, size_t count, size_t upper_repeats,
double sample_ratio) const {
return fmt::format(R"(["{}","{}","{}","{}"])", lower, upper, std::to_string((int64_t)(count * sample_ratio)),
std::to_string((int64_t)(upper_repeats * sample_ratio)));
return fmt::format(R"(["{}","{}","{}","{}"])", lower, upper, (int64_t)(count * sample_ratio),
(int64_t)(upper_repeats * sample_ratio));
}

void finalize_to_column(FunctionContext* ctx __attribute__((unused)), ConstAggDataPtr __restrict state,
Column* to) const override {
auto bucket_num = ColumnHelper::get_const_value<TYPE_INT>(ctx->get_constant_column(1));
[[maybe_unused]] double sample_ratio =
1 / ColumnHelper::get_const_value<TYPE_DOUBLE>(ctx->get_constant_column(2));
int bucket_size = this->data(state).data.size() / bucket_num;

//Build bucket
std::vector<Bucket<T>> buckets;
for (int i = 0; i < this->data(state).data.size(); ++i) {
T v = this->data(state).data[i];
int bucket_size = this->data(state).column->size() / bucket_num;

// Build bucket
std::vector<Bucket<LT>> buckets;
ColumnViewer<LT> viewer(this->data(state).column);
for (size_t i = 0; i < viewer.size(); ++i) {
auto v = viewer.value(i);
if (viewer.is_null(i)) {
continue;
}
if (buckets.empty()) {
Bucket<T> bucket(v, v, 1, 1);
Bucket<LT> bucket(v, v, 1, 1);
buckets.emplace_back(bucket);
} else {
Bucket<T>* lastBucket = &buckets.back();
Bucket<LT>* last_bucket = &buckets.back();

if (lastBucket->upper == v) {
lastBucket->count++;
lastBucket->count_in_bucket++;
lastBucket->upper_repeats++;
if (last_bucket->is_equals_to_upper(v)) {
last_bucket->count++;
last_bucket->count_in_bucket++;
last_bucket->upper_repeats++;
} else {
if (lastBucket->count_in_bucket >= bucket_size) {
Bucket<T> bucket(v, v, lastBucket->count + 1, 1);
if (last_bucket->count_in_bucket >= bucket_size) {
Bucket<LT> bucket(v, v, last_bucket->count + 1, 1);
buckets.emplace_back(bucket);
} else {
lastBucket->upper = v;
lastBucket->count++;
lastBucket->count_in_bucket++;
lastBucket->upper_repeats = 1;
last_bucket->update_upper(v);
last_bucket->count++;
last_bucket->count_in_bucket++;
last_bucket->upper_repeats = 1;
}
}
}
}

const auto& type_desc = ctx->get_arg_type(0);
TypeInfoPtr type_info = get_type_info(LT, type_desc->precision, type_desc->scale);
std::string bucket_json;
if (buckets.empty()) {
bucket_json = "[]";
} else {
bucket_json = "[";
if constexpr (lt_is_largeint<LT>) {
for (int i = 0; i < buckets.size(); ++i) {
bucket_json += toBucketJson(LargeIntValue::to_string(buckets[i].lower),
LargeIntValue::to_string(buckets[i].upper), buckets[i].count,
buckets[i].upper_repeats, sample_ratio) +
",";
}
} else if constexpr (lt_is_arithmetic<LT>) {
for (int i = 0; i < buckets.size(); ++i) {
bucket_json += toBucketJson(std::to_string(buckets[i].lower), std::to_string(buckets[i].upper),
buckets[i].count, buckets[i].upper_repeats, sample_ratio) +
",";
}
} else if constexpr (lt_is_date_or_datetime<LT>) {
for (int i = 0; i < buckets.size(); ++i) {
bucket_json += toBucketJson(buckets[i].lower.to_string(), buckets[i].upper.to_string(),
buckets[i].count, buckets[i].upper_repeats, sample_ratio) +
",";
}
} else if constexpr (lt_is_decimal<LT>) {
int scale = ctx->get_arg_type(0)->scale;
int precision = ctx->get_arg_type(0)->precision;
for (int i = 0; i < buckets.size(); ++i) {
bucket_json += toBucketJson(DecimalV3Cast::to_string<T>(buckets[i].lower, precision, scale),
DecimalV3Cast::to_string<T>(buckets[i].upper, precision, scale),
buckets[i].count, buckets[i].upper_repeats, sample_ratio) +
",";
}
} else if constexpr (lt_is_string<LT>) {
for (int i = 0; i < buckets.size(); ++i) {
bucket_json += toBucketJson(buckets[i].lower.to_string(), buckets[i].upper.to_string(),
buckets[i].count, buckets[i].upper_repeats, sample_ratio) +
",";
}

for (int i = 0; i < buckets.size(); ++i) {
std::string lower_str = datum_to_string(type_info.get(), buckets[i].get_lower_datum());
std::string upper_str = datum_to_string(type_info.get(), buckets[i].get_upper_datum());
bucket_json +=
toBucketJson(lower_str, upper_str, buckets[i].count, buckets[i].upper_repeats, sample_ratio) +
",";
}

bucket_json[bucket_json.size() - 1] = ']';
}

Expand Down
2 changes: 1 addition & 1 deletion be/src/exprs/lambda_function.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -82,7 +82,7 @@ Status LambdaFunction::extract_outer_common_exprs(RuntimeState* state, ExprConte
if (is_independent) {
SlotId slot_id = ctx->next_slot_id++;
#ifdef DEBUG
expr_ctx->root()->for_each_slot_id([expr_ctx, new_slot_id = slot_id](SlotId slot_id) {
expr_ctx->root()->for_each_slot_id([new_slot_id = slot_id](SlotId slot_id) {
DCHECK_NE(new_slot_id, slot_id) << "slot_id " << new_slot_id << " already exists in expr_ctx";
});
#endif
Expand Down
8 changes: 5 additions & 3 deletions be/test/exprs/agg/aggregate_test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1724,9 +1724,11 @@ TEST_F(AggregateTest, test_histogram) {
raw_columns[1] = const1.get();
raw_columns[2] = const2.get();
raw_columns[3] = const3.get();
for (int i = 0; i < data_column->size(); ++i) {
histogram_function->update(local_ctx.get(), raw_columns.data(), state->state(), i);
}
histogram_function->update_batch_single_state(local_ctx.get(), data_column->size(), raw_columns.data(),
state->state());
// for (int i = 0; i < data_column->size(); ++i) {
// histogram_function->update(local_ctx.get(), raw_columns.data(), state->state(), i);
// }

auto result_column = NullableColumn::create(BinaryColumn::create(), NullColumn::create());
histogram_function->finalize_to_column(local_ctx.get(), state->state(), result_column.get());
Expand Down
Loading

0 comments on commit ac8dcb3

Please sign in to comment.