Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[opt](lambda) let lambda expression support refer outer slot #45186

Merged
merged 2 commits into from
Dec 11, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 8 additions & 1 deletion be/src/vec/exprs/lambda_function/lambda_function.h
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@

#pragma once

#include <fmt/core.h>
#include <runtime/runtime_state.h>

#include "common/status.h"
#include "vec/core/block.h"
Expand All @@ -31,9 +31,16 @@ class LambdaFunction {

virtual std::string get_name() const = 0;

virtual doris::Status prepare(RuntimeState* state) {
batch_size = state->batch_size();
return Status::OK();
}

virtual doris::Status execute(VExprContext* context, doris::vectorized::Block* block,
int* result_column_id, const DataTypePtr& result_type,
const VExprSPtrs& children) = 0;

int batch_size;
};

using LambdaFunctionPtr = std::shared_ptr<LambdaFunction>;
Expand Down
212 changes: 184 additions & 28 deletions be/src/vec/exprs/lambda_function/varray_map_function.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,10 @@
// specific language governing permissions and limitations
// under the License.

#include <vec/data_types/data_type_number.h>
#include <vec/exprs/vcolumn_ref.h>
#include <vec/exprs/vslot_ref.h>

#include <memory>
#include <string>
#include <utility>
Expand Down Expand Up @@ -47,6 +51,26 @@ class VExprContext;

namespace doris::vectorized {

// extend a block with all required parameters
struct LambdaArgs {
// the lambda function need the column ids of all the slots
std::vector<int> output_slot_ref_indexs;
// which line is extended to the original block
int64_t current_row_idx = 0;
// when a block is filled, the array may be truncated, recording where it was truncated
int64_t current_offset_in_array = 0;
// the beginning position of the array
size_t array_start = 0;
// the size of the array
int64_t cur_size = 0;
// offset of column array
const ColumnArray::Offsets64* offsets_ptr = nullptr;
// expend data of repeat times
int current_repeat_times = 0;
// whether the current row of the original block has been extended
bool current_row_eos = false;
};

class ArrayMapFunction : public LambdaFunction {
ENABLE_FACTORY_CREATOR(ArrayMapFunction);

Expand All @@ -62,8 +86,33 @@ class ArrayMapFunction : public LambdaFunction {
doris::Status execute(VExprContext* context, doris::vectorized::Block* block,
int* result_column_id, const DataTypePtr& result_type,
const VExprSPtrs& children) override {
///* array_map(lambda,arg1,arg2,.....) *///
LambdaArgs args;
// collect used slot ref in lambda function body
_collect_slot_ref_column_id(children[0], args);

int gap = 0;
if (!args.output_slot_ref_indexs.empty()) {
auto max_id = std::max_element(args.output_slot_ref_indexs.begin(),
args.output_slot_ref_indexs.end());
gap = *max_id + 1;
_set_column_ref_column_id(children[0], gap);
}

std::vector<std::string> names(gap);
DataTypes data_types(gap);

for (int i = 0; i < gap; ++i) {
if (_contains_column_id(args, i)) {
names[i] = block->get_by_position(i).name;
data_types[i] = block->get_by_position(i).type;
} else {
// padding some mock data
names[i] = "temp";
data_types[i] = std::make_shared<DataTypeUInt8>();
}
}

///* array_map(lambda,arg1,arg2,.....) *///
//1. child[1:end]->execute(src_block)
doris::vectorized::ColumnNumbers arguments(children.size() - 1);
for (int i = 1; i < children.size(); ++i) {
Expand All @@ -82,14 +131,13 @@ class ArrayMapFunction : public LambdaFunction {
MutableColumnPtr array_column_offset;
int nested_array_column_rows = 0;
ColumnPtr first_array_offsets = nullptr;

//2. get the result column from executed expr, and the needed is nested column of array
Block lambda_block;
std::vector<ColumnPtr> lambda_datas(arguments.size());

for (int i = 0; i < arguments.size(); ++i) {
const auto& array_column_type_name = block->get_by_position(arguments[i]);
auto column_array = array_column_type_name.column->convert_to_full_column_if_const();
auto type_array = array_column_type_name.type;

if (type_array->is_nullable()) {
// get the nullmap of nullable column
const auto& column_array_nullmap =
Expand Down Expand Up @@ -118,6 +166,7 @@ class ArrayMapFunction : public LambdaFunction {
auto& off_data = assert_cast<const ColumnArray::ColumnOffsets&>(
col_array.get_offsets_column());
array_column_offset = off_data.clone_resized(col_array.get_offsets_column().size());
args.offsets_ptr = &col_array.get_offsets();
} else {
// select array_map((x,y)->x+y,c_array1,[0,1,2,3]) from array_test2;
// c_array1: [0,1,2,3,4,5,6,7,8,9]
Expand All @@ -136,57 +185,164 @@ class ArrayMapFunction : public LambdaFunction {
nested_array_column_rows, i + 1, col_array.get_data_ptr()->size());
}
}

// insert the data column to the new block
ColumnWithTypeAndName data_column {col_array.get_data_ptr(), col_type.get_nested_type(),
"R" + array_column_type_name.name};
lambda_block.insert(std::move(data_column));
lambda_datas[i] = col_array.get_data_ptr();
names.push_back("R" + array_column_type_name.name);
data_types.push_back(col_type.get_nested_type());
}

//3. child[0]->execute(new_block)
RETURN_IF_ERROR(children[0]->execute(context, &lambda_block, result_column_id));
ColumnPtr result_col = nullptr;
DataTypePtr res_type;
std::string res_name;

//process first row
args.array_start = (*args.offsets_ptr)[args.current_row_idx - 1];
args.cur_size = (*args.offsets_ptr)[args.current_row_idx] - args.array_start;

while (args.current_row_idx < block->rows()) {
Block lambda_block;
for (int i = 0; i < names.size(); i++) {
ColumnWithTypeAndName data_column;
if (_contains_column_id(args, i) || i >= gap) {
data_column = ColumnWithTypeAndName(data_types[i], names[i]);
} else {
data_column = ColumnWithTypeAndName(
data_types[i]->create_column_const_with_default_value(0), data_types[i],
names[i]);
}
lambda_block.insert(std::move(data_column));
}

MutableColumns columns = lambda_block.mutate_columns();
while (columns[gap]->size() < batch_size) {
long max_step = batch_size - columns[gap]->size();
long current_step =
std::min(max_step, (long)(args.cur_size - args.current_offset_in_array));
size_t pos = args.array_start + args.current_offset_in_array;
for (int i = 0; i < arguments.size(); ++i) {
columns[gap + i]->insert_range_from(*lambda_datas[i], pos, current_step);
}
args.current_offset_in_array += current_step;
args.current_repeat_times += current_step;
if (args.current_offset_in_array >= args.cur_size) {
args.current_row_eos = true;
}
_extend_data(columns, block, args, gap);
if (args.current_row_eos) {
args.current_row_idx++;
args.current_offset_in_array = 0;
if (args.current_row_idx >= block->rows()) {
break;
}
args.current_row_eos = false;
args.array_start = (*args.offsets_ptr)[args.current_row_idx - 1];
args.cur_size = (*args.offsets_ptr)[args.current_row_idx] - args.array_start;
}
}

lambda_block.set_columns(std::move(columns));

//3. child[0]->execute(new_block)
RETURN_IF_ERROR(children[0]->execute(context, &lambda_block, result_column_id));

auto res_col = lambda_block.get_by_position(*result_column_id)
.column->convert_to_full_column_if_const();
auto res_type = lambda_block.get_by_position(*result_column_id).type;
auto res_name = lambda_block.get_by_position(*result_column_id).name;
auto res_col = lambda_block.get_by_position(*result_column_id)
.column->convert_to_full_column_if_const();
res_type = lambda_block.get_by_position(*result_column_id).type;
res_name = lambda_block.get_by_position(*result_column_id).name;
if (!result_col) {
result_col = std::move(res_col);
} else {
MutableColumnPtr column = (*std::move(result_col)).mutate();
column->insert_range_from(*res_col, 0, res_col->size());
}
}

//4. get the result column after execution, reassemble it into a new array column, and return.
ColumnWithTypeAndName result_arr;
if (result_type->is_nullable()) {
if (res_type->is_nullable()) {
result_arr = {ColumnNullable::create(
ColumnArray::create(res_col, std::move(array_column_offset)),
std::move(outside_null_map)),
result_type, res_name};
result_arr = {
ColumnNullable::create(
ColumnArray::create(result_col, std::move(array_column_offset)),
std::move(outside_null_map)),
result_type, res_name};
} else {
// deal with eg: select array_map(x -> x is null, [null, 1, 2]);
// need to create the nested column null map for column array
auto nested_null_map = ColumnUInt8::create(res_col->size(), 0);
auto nested_null_map = ColumnUInt8::create(result_col->size(), 0);
result_arr = {
ColumnNullable::create(
ColumnArray::create(
ColumnNullable::create(res_col, std::move(nested_null_map)),
std::move(array_column_offset)),
ColumnArray::create(ColumnNullable::create(
result_col, std::move(nested_null_map)),
std::move(array_column_offset)),
std::move(outside_null_map)),
result_type, res_name};
}
} else {
if (res_type->is_nullable()) {
result_arr = {ColumnArray::create(res_col, std::move(array_column_offset)),
result_arr = {ColumnArray::create(result_col, std::move(array_column_offset)),
result_type, res_name};
} else {
auto nested_null_map = ColumnUInt8::create(res_col->size(), 0);
result_arr = {ColumnArray::create(
ColumnNullable::create(res_col, std::move(nested_null_map)),
std::move(array_column_offset)),
auto nested_null_map = ColumnUInt8::create(result_col->size(), 0);
result_arr = {ColumnArray::create(ColumnNullable::create(
result_col, std::move(nested_null_map)),
std::move(array_column_offset)),
result_type, res_name};
}
}
block->insert(std::move(result_arr));
*result_column_id = block->columns() - 1;

return Status::OK();
}

private:
bool _contains_column_id(LambdaArgs& args, int id) {
const auto it = std::find(args.output_slot_ref_indexs.begin(),
args.output_slot_ref_indexs.end(), id);
return it != args.output_slot_ref_indexs.end();
}

void _set_column_ref_column_id(VExprSPtr expr, int gap) {
for (const auto& child : expr->children()) {
if (child->is_column_ref()) {
auto* ref = static_cast<VColumnRef*>(child.get());
ref->set_gap(gap);
} else {
_set_column_ref_column_id(child, gap);
}
}
}

void _collect_slot_ref_column_id(VExprSPtr expr, LambdaArgs& args) {
for (const auto& child : expr->children()) {
if (child->is_slot_ref()) {
const auto* ref = static_cast<VSlotRef*>(child.get());
args.output_slot_ref_indexs.push_back(ref->column_id());
} else {
_collect_slot_ref_column_id(child, args);
}
}
}

void _extend_data(std::vector<MutableColumnPtr>& columns, Block* block, LambdaArgs& args,
int size) {
if (!args.current_repeat_times || !size) {
return;
}
for (int i = 0; i < size; i++) {
if (_contains_column_id(args, i)) {
auto src_column =
block->get_by_position(i).column->convert_to_full_column_if_const();
columns[i]->insert_many_from(*src_column, args.current_row_idx,
args.current_repeat_times);
} else {
// must be column const
DCHECK(is_column_const(*columns[i]));
columns[i]->resize(columns[i]->size() + args.current_repeat_times);
}
}
args.current_repeat_times = 0;
}
};

void register_function_array_map(doris::vectorized::LambdaFunctionFactory& factory) {
Expand Down
11 changes: 10 additions & 1 deletion be/src/vec/exprs/vcolumn_ref.h
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,8 @@
// under the License.

#pragma once
#include <atomic>

#include "runtime/descriptors.h"
#include "runtime/runtime_state.h"
#include "vec/exprs/vexpr.h"
Expand Down Expand Up @@ -57,7 +59,7 @@ class VColumnRef final : public VExpr {

Status execute(VExprContext* context, Block* block, int* result_column_id) override {
DCHECK(_open_finished || _getting_const_col);
*result_column_id = _column_id;
*result_column_id = _column_id + _gap;
return Status::OK();
}

Expand All @@ -67,6 +69,12 @@ class VColumnRef final : public VExpr {

const std::string& expr_name() const override { return _column_name; }

void set_gap(int gap) {
if (_gap == 0) {
_gap = gap;
}
}

std::string debug_string() const override {
std::stringstream out;
out << "VColumnRef(slot_id: " << _column_id << ",column_name: " << _column_name
Expand All @@ -76,6 +84,7 @@ class VColumnRef final : public VExpr {

private:
int _column_id;
std::atomic<int> _gap = 0;
std::string _column_name;
};
} // namespace vectorized
Expand Down
3 changes: 3 additions & 0 deletions be/src/vec/exprs/vexpr.h
Original file line number Diff line number Diff line change
Expand Up @@ -142,6 +142,9 @@ class VExpr {
TypeDescriptor type() { return _type; }

bool is_slot_ref() const { return _node_type == TExprNodeType::SLOT_REF; }

bool is_column_ref() const { return _node_type == TExprNodeType::COLUMN_REF; }

virtual bool is_literal() const { return false; }

TExprNodeType::type node_type() const { return _node_type; }
Expand Down
1 change: 1 addition & 0 deletions be/src/vec/exprs/vlambda_function_call_expr.h
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,7 @@ class VLambdaFunctionCallExpr : public VExpr {
return Status::InternalError("Lambda Function {} is not implemented.",
_fn.name.function_name);
}
RETURN_IF_ERROR(_lambda_function->prepare(state));
_prepare_finished = true;
return Status::OK();
}
Expand Down
Loading
Loading