Skip to content

Commit

Permalink
Optimize the performance of arrays_overlap
Browse files Browse the repository at this point in the history
Signed-off-by: trueeyu <[email protected]>
  • Loading branch information
trueeyu committed Sep 19, 2024
1 parent 809192e commit 8b79bbe
Show file tree
Hide file tree
Showing 5 changed files with 307 additions and 78 deletions.
32 changes: 32 additions & 0 deletions be/src/column/column_helper.h
Original file line number Diff line number Diff line change
Expand Up @@ -333,6 +333,38 @@ class ColumnHelper {
}
}

template <LogicalType LT>
static const RunTimeColumnType<LT>* get_data_column_by_type(const Column* column) {
using ColumnType = RunTimeColumnType<LT>;
if (column->is_nullable()) {
const auto* nullable_column = down_cast<const NullableColumn*>(column);
return down_cast<const ColumnType*>(&nullable_column->data_column_ref());
} else if (column->is_constant()) {
const auto* const_column = down_cast<const ConstColumn*>(column);
return down_cast<const ColumnType*>(const_column->data_column().get());
} else {
return reinterpret_cast<const ColumnType*>(column);
}
}

static NullColumn* get_null_column(const Column* column) {
if (column->is_nullable()) {
auto* nullable_column = down_cast<const NullableColumn*>(column);
return nullable_column->null_column().get();
} else {
return nullptr;
}
}

static NullColumnPtr clone_null_column(const ColumnPtr& column) {
if (column->is_nullable()) {
ColumnPtr result = ColumnHelper::as_raw_column<NullableColumn>(column)->null_column()->clone();
return ColumnHelper::cast_to<TYPE_NULL>(result);
} else {
return nullptr;
}
}

static const Column* get_data_column(const Column* column) {
if (column->is_nullable()) {
auto* nullable_column = down_cast<const NullableColumn*>(column);
Expand Down
10 changes: 10 additions & 0 deletions be/src/exprs/array_functions.h
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,16 @@ class ArrayFunctions {
return ArrayOverlap<type>::process(context, columns);
}

template <LogicalType type>
static Status array_overlap_prepare(FunctionContext* context, FunctionContext::FunctionStateScope scope) {
return ArrayOverlap<type>::prepare(context, scope);
}

template <LogicalType type>
static Status array_overlap_close(FunctionContext* context, FunctionContext::FunctionStateScope scope) {
return ArrayOverlap<type>::close(context, scope);
}

template <LogicalType type>
static StatusOr<ColumnPtr> array_intersect(FunctionContext* context, const Columns& columns) {
return ArrayIntersect<type>::process(context, columns);
Expand Down
295 changes: 234 additions & 61 deletions be/src/exprs/array_functions.tpp
Original file line number Diff line number Diff line change
Expand Up @@ -252,98 +252,271 @@ private:
}
};

template <typename HashSet>
struct ArrayOverlapState {
bool left_is_notnull_const = false;
bool right_is_notnull_const = false;
bool has_overlapping = false;
bool has_null = false;
std::unique_ptr<HashSet> hash_set;
};

template <LogicalType LT>
class ArrayOverlap {
public:
using CppType = RunTimeCppType<LT>;
using ColumnType = RunTimeColumnType<LT>;
using DataArray = RunTimeProxyContainerType<LT>;
using HashFunc = PhmapDefaultHashFunc<LT, PhmapSeed1>;
using HashSet = phmap::flat_hash_set<CppType, HashFunc>;

static ColumnPtr process(FunctionContext* ctx, const Columns& columns) {
static_assert(lt_is_largeint<LT> || lt_is_decimal128<LT> || lt_is_fixedlength<LT> || lt_is_string<LT>);
static Status prepare(FunctionContext* ctx, FunctionContext::FunctionStateScope scope) {
if (scope != FunctionContext::FRAGMENT_LOCAL) {
return Status::OK();
}

auto* state = new ArrayOverlapState<HashSet>();
ctx->set_function_state(scope, state);

if (!ctx->is_notnull_constant_column(0) && !ctx->is_notnull_constant_column(1)) {
return Status::OK();
}

if (ctx->is_notnull_constant_column(1)) {
const auto* array_column = ColumnHelper::get_data_column_by_type<TYPE_ARRAY>(
ctx->get_constant_column(1).get());
state->right_is_notnull_const = true;
state->hash_set = std::make_unique<HashSet>();
state->has_null = _put_array_to_hash_set(*array_column, 0, state->hash_set.get());
}

if (ctx->is_notnull_constant_column(0)) {
const auto* array_column = ColumnHelper::get_data_column_by_type<TYPE_ARRAY>(
ctx->get_constant_column(0).get());
state->left_is_notnull_const = true;

if (state->right_is_notnull_const) {
const auto* elements_column = &array_column->elements();

if (elements_column->is_nullable()) {
state->has_overlapping =
_check_column_overlap_nullable(*state->hash_set, *array_column, 0, state->has_null);
} else {
state->has_overlapping = _check_column_overlap(*state->hash_set, *array_column, 0);
}
} else {
state->hash_set = std::make_unique<HashSet>();
state->has_null = _put_array_to_hash_set(*array_column, 0, state->hash_set.get());
}
}

return Status::OK();
}

static Status close(FunctionContext* ctx, FunctionContext::FunctionStateScope scope) {
if (scope == FunctionContext::FRAGMENT_LOCAL) {
auto* state = reinterpret_cast<ArrayOverlapState<HashSet>*>(
ctx->get_function_state(FunctionContext::FRAGMENT_LOCAL));
delete state;
}

return Status::OK();
}

static StatusOr<ColumnPtr> process(FunctionContext* ctx, const Columns& columns) {
RETURN_IF_COLUMNS_ONLY_NULL(columns);
static_assert(PhmapDefaultHashFunc<LT, PhmapSeed1>::is_supported());

return _array_overlap<phmap::flat_hash_set<CppType, PhmapDefaultHashFunc<LT, PhmapSeed1>>>(columns);
auto* state =
reinterpret_cast<ArrayOverlapState<HashSet>*>(ctx->get_function_state(FunctionContext::FRAGMENT_LOCAL));
if (UNLIKELY(state == nullptr)) {
return Status::InternalError("array_overloap get state failed");
}

bool is_nullable = columns[0]->is_nullable() || columns[1]->is_nullable();
auto chunk_size = columns[0]->size();

if (state->left_is_notnull_const && state->right_is_notnull_const) {
ColumnPtr result_column;
if (state->has_overlapping) {
result_column = ColumnHelper::create_const_column<TYPE_BOOLEAN>(1, chunk_size);
} else {
result_column = ColumnHelper::create_const_column<TYPE_BOOLEAN>(0, chunk_size);
}
if (is_nullable) {
result_column = ColumnHelper::cast_to_nullable_column(result_column);
}
return result_column;
} else if (state->left_is_notnull_const) {
return _array_overlap_const(*state, *columns[1]);
} else if (state->right_is_notnull_const) {
return _array_overlap_const(*state, *columns[0]);
} else {
return _array_overlap(columns);
}
}

private:
template <typename HashSet>
static ColumnPtr _array_overlap(const Columns& original_columns) {
size_t chunk_size = original_columns[0]->size();
auto result_column = BooleanColumn::create(chunk_size, 0);
Columns columns;
for (const auto& col : original_columns) {
columns.push_back(ColumnHelper::unpack_and_duplicate_const_column(chunk_size, col));
static ColumnPtr _array_overlap_const(const ArrayOverlapState<HashSet>& state, const Column& column) {
size_t chunk_size = column.size();
auto result_data_column = BooleanColumn::create(chunk_size, 0);
auto& result_data = result_data_column->get_data();
NullColumnPtr result_null_column;
const ArrayColumn* src_data_column = ColumnHelper::get_data_column_by_type<TYPE_ARRAY>(&column);
const NullColumn* src_null_column = ColumnHelper::get_null_column(&column);

if (src_null_column != nullptr) {
result_null_column = ColumnHelper::as_column<UInt8Column>(src_null_column->clone_shared());
}

bool is_nullable = false;
bool has_null = false;
std::vector<ArrayColumn*> src_columns;
src_columns.reserve(columns.size());
NullColumnPtr null_result = NullColumn::create();
null_result->resize(chunk_size);
bool elements_column_is_nullable = src_data_column->elements_column()->is_nullable();

for (const auto& column : columns) {
if (column->is_nullable()) {
is_nullable = true;
has_null = (column->has_null() || has_null);
const auto* src_nullable_column = down_cast<const NullableColumn*>(column.get());
src_columns.emplace_back(down_cast<ArrayColumn*>(src_nullable_column->data_column().get()));
null_result = FunctionHelper::union_null_column(null_result, src_nullable_column->null_column());
} else {
src_columns.emplace_back(down_cast<ArrayColumn*>(column.get()));
if (elements_column_is_nullable) {
for (size_t i = 0; i < chunk_size; i++) {
result_data[i] = _check_column_overlap_nullable(*state.hash_set, *src_data_column, i, state.has_null);
}
} else {
for (size_t i = 0; i < chunk_size; i++) {
result_data[i] = _check_column_overlap(*state.hash_set, *src_data_column, i);
}
}

HashSet hash_set;
for (size_t i = 0; i < chunk_size; i++) {
_array_overlap_item<HashSet>(src_columns, i, &hash_set,
static_cast<BooleanColumn*>(result_column.get())->get_data().data());
hash_set.clear();
if (result_null_column != nullptr) {
return NullableColumn::create(result_data_column, result_null_column);
} else {
return result_data_column;
}
}

if (is_nullable) {
return NullableColumn::create(result_column, null_result);
static ColumnPtr _array_overlap(const Columns& columns) {
size_t chunk_size = columns[0]->size();
auto result_data_column = BooleanColumn::create(chunk_size, 0);
auto& result_data = result_data_column->get_data();

const auto* src_data_column_0 = ColumnHelper::get_data_column_by_type<TYPE_ARRAY>(columns[0].get());
const auto* src_data_column_1 = ColumnHelper::get_data_column_by_type<TYPE_ARRAY>(columns[1].get());

NullColumnPtr result_null_column = FunctionHelper::union_nullable_column(columns[0], columns[1]);

//TODO: use small array to build hash set
bool element_column_is_nullable = src_data_column_0->elements_column()->is_nullable();

if (element_column_is_nullable) {
for (size_t i = 0; i < chunk_size; i++) {
HashSet hash_set;
bool has_null = _put_array_to_hash_set(*src_data_column_1, i, &hash_set);
result_data[i] = _check_column_overlap_nullable(hash_set, *src_data_column_0, i, has_null);
}
} else {
for (size_t i = 0; i < chunk_size; i++) {
HashSet hash_set;
(void)_put_array_to_hash_set(*src_data_column_1, i, &hash_set);
result_data[i] = _check_column_overlap(hash_set, *src_data_column_0, i);
hash_set.clear();
}
}

return result_column;
if (result_null_column != nullptr) {
return NullableColumn::create(result_data_column, result_null_column);
} else {
return result_data_column;
}
}

template <typename HashSet>
static void _array_overlap_item(const std::vector<ArrayColumn*>& columns, size_t index, HashSet* hash_set,
uint8_t* data) {
static bool _put_array_to_hash_set(const ArrayColumn& column, size_t index, HashSet* hash_set) {
const auto* elements_column = column.elements_column().get();
const auto& offsets = column.offsets().get_data();
bool has_null = false;
uint32_t start = offsets[index];
uint32_t end = offsets[index + 1];

{
Datum v = columns[0]->get(index);
const auto& items = v.get<DatumArray>();
for (const auto& item : items) {
if (item.is_null()) {
has_null = true;
} else {
hash_set->emplace(item.get<CppType>());
if (elements_column->is_nullable()) {
const NullableColumn* nullable_column = down_cast<const NullableColumn*>(elements_column);
const auto& datas = GetContainer<LT>::get_data(nullable_column->data_column());
const auto& nulls = nullable_column->null_column()->get_data();

if (nullable_column->has_null()) {
for (size_t i = start; i < end; i++) {
if (nulls[i]) {
has_null = true;
} else {
hash_set->emplace(datas[i]);
}
}
} else {
for (size_t i = start; i < end; i++) {
hash_set->emplace(datas[i]);
}
}
} else {
const auto& datas = GetContainer<LT>::get_data(elements_column);
for (size_t i = start; i < end; i++) {
hash_set->emplace(datas[i]);
}
}

{
Datum v = columns[1]->get(index);
const auto& items = v.get<DatumArray>();
for (const auto& item : items) {
if (item.is_null()) {
if (has_null) {
data[index] = 1;
return;
}
} else {
auto iter = hash_set->find(item.get<CppType>());
if (iter != hash_set->end()) {
data[index] = 1;
return;
}
}
return has_null;
}

static bool _check_column_overlap_nullable(const HashSet& hash_set, const ArrayColumn& column, size_t index,
bool has_null) {
const auto* elements_column = column.elements_column().get();
const auto& offsets = column.offsets().get_data();
uint32_t start = offsets[index];
uint32_t end = offsets[index + 1];
bool overlap = false;

if (elements_column->is_nullable()) {
const NullableColumn* nullable_elements_column = down_cast<const NullableColumn*>(elements_column);
const auto& datas = GetContainer<LT>::get_data(nullable_elements_column->data_column());

if (nullable_elements_column->has_null()) {
const auto& nulls = nullable_elements_column->null_column()->get_data();

overlap = _check_overlap_nullable(hash_set, datas, nulls, start, end, has_null, index);
} else {
overlap = _check_overlap(hash_set, datas, start, end, index);
}
} else {
const auto& datas = GetContainer<LT>::get_data(elements_column);
overlap = _check_overlap(hash_set, datas, start, end, index);
}
return overlap;
}

data[index] = 0;
static bool _check_column_overlap(const HashSet& hash_set, const ArrayColumn& column, size_t index) {
const auto& datas = GetContainer<LT>::get_data(column.elements_column().get());
const auto& offsets = column.offsets().get_data();
uint32_t start = offsets[index];
uint32_t end = offsets[index + 1];

return _check_overlap(hash_set, datas, start, end, index);
}

static bool _check_overlap(const HashSet& hash_set, const DataArray& data, uint32_t start, uint32_t end,
size_t index) {
for (auto i = start; i < end; i++) {
if (hash_set.contains(data[i])) {
return true;
}
}
return false;
}

static bool _check_overlap_nullable(const HashSet& hash_set, const DataArray& data, const NullData& null_data,
uint32_t start, uint32_t end, bool has_null, size_t index) {
for (auto i = start; i < end; i++) {
if (null_data[i] == 1) {
if (has_null) {
return true;
}
} else {
if (hash_set.contains(data[i])) {
return true;
}
}
}
return false;
}
};

Expand Down
2 changes: 1 addition & 1 deletion be/src/exprs/function_context.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -108,7 +108,7 @@ bool FunctionContext::is_notnull_constant_column(int i) const {
return col && col->is_constant() && !col->is_null(0);
}

starrocks::ColumnPtr FunctionContext::get_constant_column(int i) const {
ColumnPtr FunctionContext::get_constant_column(int i) const {
if (i < 0 || i >= _constant_columns.size()) {
return nullptr;
}
Expand Down
Loading

0 comments on commit 8b79bbe

Please sign in to comment.