Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[BugFix] Fix Java UDAF OOM in spill or sorted streaming agg #48618

Merged
merged 1 commit into from
Jul 22, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 14 additions & 0 deletions be/src/exec/aggregator.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -553,6 +553,13 @@ Status Aggregator::_reset_state(RuntimeState* state, bool reset_sink_complete) {
} else if (!_is_only_group_by_columns) {
_release_agg_memory();
}

for (int i = 0; i < _agg_functions.size(); i++) {
if (_agg_fn_ctxs[i]) {
_agg_fn_ctxs[i]->release_mems();
}
}

_mem_pool->free_all();
_agg_state_mem_usage = 0;

Expand All @@ -566,6 +573,7 @@ Status Aggregator::_reset_state(RuntimeState* state, bool reset_sink_complete) {
} else {
TRY_CATCH_BAD_ALLOC(_init_agg_hash_variant(_hash_map_variant));
}

// _state_allocator holds the entries of the hash_map/hash_set, when iterating a hash_map/set, the _state_allocator
// is used to access these entries, so we must reset the _state_allocator along with the hash_map/hash_set.
_state_allocator.reset();
Expand Down Expand Up @@ -623,6 +631,12 @@ void Aggregator::close(RuntimeState* state) {
_mem_pool->free_all();
}

for (int i = 0; i < _agg_functions.size(); i++) {
if (_agg_fn_ctxs[i]) {
_agg_fn_ctxs[i]->release_mems();
}
}

if (_is_only_group_by_columns) {
_hash_set_variant.reset();
} else {
Expand Down
5 changes: 4 additions & 1 deletion be/src/exprs/agg/java_udaf_function.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -88,7 +88,10 @@ Status init_udaf_context(int64_t id, const std::string& url, const std::string&
ASSIGN_OR_RETURN(auto get_func, analyzer->get_method_object(state_clazz.clazz(), "get"));
ASSIGN_OR_RETURN(auto batch_get_func, analyzer->get_method_object(state_clazz.clazz(), "batch_get"));
ASSIGN_OR_RETURN(auto add_func, analyzer->get_method_object(state_clazz.clazz(), "add"));
udaf_ctx->states = std::make_unique<UDAFStateList>(std::move(instance), get_func, batch_get_func, add_func);
ASSIGN_OR_RETURN(auto remove_func, analyzer->get_method_object(state_clazz.clazz(), "remove"));
ASSIGN_OR_RETURN(auto clear_func, analyzer->get_method_object(state_clazz.clazz(), "clear"));
udaf_ctx->states = std::make_unique<UDAFStateList>(std::move(instance), get_func, batch_get_func, add_func,
remove_func, clear_func);
udaf_ctx->_func = std::make_unique<UDAFFunction>(udaf_ctx->handle.handle(), context, udaf_ctx);

return Status::OK();
Expand Down
44 changes: 30 additions & 14 deletions be/src/exprs/agg/java_udaf_function.h
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
#pragma once

#include <cstring>
#include <limits>
#include <memory>
#include <numeric>
#include <string>
Expand Down Expand Up @@ -255,8 +256,8 @@ class JavaUDAFAggregateFunction : public AggregateFunction {

// This is only used to get portion of the entire binary column
template <class StatesProvider, class MergeCaller>
void _merge_batch_process(StatesProvider&& states_provider, MergeCaller&& caller, const Column* column,
size_t start, size_t size) const {
void _merge_batch_process(FunctionContext* ctx, StatesProvider&& states_provider, MergeCaller&& caller,
const Column* column, size_t start, size_t size, bool need_multi_buffer) const {
auto& helper = JVMFunctionHelper::getInstance();
auto* env = helper.getEnv();
// get state lists
Expand All @@ -269,17 +270,32 @@ class JavaUDAFAggregateFunction : public AggregateFunction {
ColumnHelper::get_binary_column(const_cast<Column*>(ColumnHelper::get_data_column(column)));

auto& serialized_bytes = serialized_column->get_bytes();
size_t start_offset = serialized_column->get_offset()[start];
size_t end_offset = serialized_column->get_offset()[start + size];
const auto& offsets = serialized_column->get_offset();

auto buffer =
std::make_unique<DirectByteBuffer>(serialized_bytes.data() + start_offset, end_offset - start_offset);
auto buffer_array = helper.create_object_array(buffer->handle(), size);
RETURN_IF_UNLIKELY_NULL(buffer_array, (void)0);
LOCAL_REF_GUARD_ENV(env, buffer_array);
if (serialized_bytes.size() > std::numeric_limits<int>::max()) {
ctx->set_error("serialized column size is too large");
return;
}

// batch call merge
caller(state_array, buffer_array);
if (!need_multi_buffer) {
size_t start_offset = serialized_column->get_offset()[start];
size_t end_offset = serialized_column->get_offset()[start + size];
// create one buffer will be ok
auto buffer = std::make_unique<DirectByteBuffer>(serialized_bytes.data() + start_offset,
end_offset - start_offset);
auto buffer_array = helper.create_object_array(buffer->handle(), size);
RETURN_IF_UNLIKELY_NULL(buffer_array, (void)0);
LOCAL_REF_GUARD_ENV(env, buffer_array);
// batch call merge
caller(state_array, buffer_array);
} else {
auto buffer_array =
helper.batch_create_bytebuf(serialized_bytes.data(), offsets.data(), start, start + size);
RETURN_IF_UNLIKELY_NULL(buffer_array, (void)0);
LOCAL_REF_GUARD_ENV(env, buffer_array);
// batch call merge
caller(state_array, buffer_array);
}
}

void merge_batch(FunctionContext* ctx, size_t batch_size, size_t state_offset, const Column* column,
Expand All @@ -300,7 +316,7 @@ class JavaUDAFAggregateFunction : public AggregateFunction {
helper.batch_update_state(ctx, ctx->udaf_ctxs()->handle.handle(), ctx->udaf_ctxs()->merge->method.handle(),
state_and_buffer, 2);
};
_merge_batch_process(std::move(provider), std::move(merger), column, 0, batch_size);
_merge_batch_process(ctx, std::move(provider), std::move(merger), column, 0, batch_size, false);
}

void merge_batch_selectively(FunctionContext* ctx, size_t batch_size, size_t state_offset, const Column* column,
Expand All @@ -318,7 +334,7 @@ class JavaUDAFAggregateFunction : public AggregateFunction {
helper.batch_update_if_not_null(ctx, ctx->udaf_ctxs()->handle.handle(),
ctx->udaf_ctxs()->merge->method.handle(), state_array, state_and_buffer, 1);
};
_merge_batch_process(std::move(provider), std::move(merger), column, 0, batch_size);
_merge_batch_process(ctx, std::move(provider), std::move(merger), column, 0, batch_size, true);
}

void merge_batch_single_state(FunctionContext* ctx, AggDataPtr __restrict state, const Column* column, size_t start,
Expand All @@ -336,7 +352,7 @@ class JavaUDAFAggregateFunction : public AggregateFunction {
helper.batch_update_state(ctx, ctx->udaf_ctxs()->handle.handle(), ctx->udaf_ctxs()->merge->method.handle(),
state_and_buffer, 2);
};
_merge_batch_process(std::move(provider), std::move(merger), column, start, size);
_merge_batch_process(ctx, std::move(provider), std::move(merger), column, start, size, false);
}

void batch_serialize(FunctionContext* ctx, size_t batch_size, const Buffer<AggDataPtr>& agg_states,
Expand Down
6 changes: 4 additions & 2 deletions be/src/exprs/agg/java_window_function.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -65,8 +65,10 @@ Status window_init_jvm_context(int64_t fid, const std::string& url, const std::s
ASSIGN_OR_RETURN(auto get_func, analyzer->get_method_object(state_clazz.clazz(), "get"));
ASSIGN_OR_RETURN(auto batch_get_func, analyzer->get_method_object(state_clazz.clazz(), "batch_get"));
ASSIGN_OR_RETURN(auto add_func, analyzer->get_method_object(state_clazz.clazz(), "add"));

udaf_ctx->states = std::make_unique<UDAFStateList>(std::move(instance), get_func, batch_get_func, add_func);
ASSIGN_OR_RETURN(auto remove_func, analyzer->get_method_object(state_clazz.clazz(), "remove"));
ASSIGN_OR_RETURN(auto clear_func, analyzer->get_method_object(state_clazz.clazz(), "clear"));
udaf_ctx->states = std::make_unique<UDAFStateList>(std::move(instance), get_func, batch_get_func, add_func,
remove_func, clear_func);
udaf_ctx->_func = std::make_unique<UDAFFunction>(udaf_ctx->handle.handle(), context, udaf_ctx);

return Status::OK();
Expand Down
8 changes: 8 additions & 0 deletions be/src/exprs/function_context.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
#include "runtime/runtime_state.h"
#include "storage/rowset/bloom_filter.h"
#include "types/logical_type_infra.h"
#include "udf/java/java_udf.h"

namespace starrocks {

Expand Down Expand Up @@ -131,6 +132,13 @@ void* FunctionContext::get_function_state(FunctionStateScope scope) const {
}
}

void FunctionContext::release_mems() {
if (_jvm_udaf_ctxs != nullptr && _jvm_udaf_ctxs->states) {
auto env = JVMFunctionHelper::getInstance().getEnv();
_jvm_udaf_ctxs->states->clear(this, env);
}
}

void FunctionContext::set_error(const char* error_msg, const bool is_udf) {
std::lock_guard<std::mutex> lock(_error_msg_mutex);
if (_error_msg.empty()) {
Expand Down
2 changes: 2 additions & 0 deletions be/src/exprs/function_context.h
Original file line number Diff line number Diff line change
Expand Up @@ -166,6 +166,8 @@ class FunctionContext {

JavaUDAFContext* udaf_ctxs() { return _jvm_udaf_ctxs.get(); }

void release_mems();

ssize_t get_group_concat_max_len() { return group_concat_max_len; }
// min value is 4, default is 1024
void set_group_concat_max_len(ssize_t len) { group_concat_max_len = len < 4 ? 4 : len; }
Expand Down
33 changes: 31 additions & 2 deletions be/src/udf/java/java_udf.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
#include "column/binary_column.h"
#include "column/column.h"
#include "common/status.h"
#include "exprs/function_context.h"
#include "fmt/core.h"
#include "jni.h"
#include "types/logical_type.h"
Expand Down Expand Up @@ -149,6 +150,11 @@ void JVMFunctionHelper::_init() {
_udf_helper_class, "batchUpdateIfNotNull",
"(Ljava/lang/Object;Ljava/lang/reflect/Method;Lcom/starrocks/udf/FunctionStates;[I[Ljava/lang/Object;)V");

_batch_create_bytebuf =
_env->GetStaticMethodID(_udf_helper_class, "batchCreateDirectBuffer", "(J[II)[Ljava/lang/Object;");

CHECK(_batch_create_bytebuf) << " not found method batchCreateDirectBuffer plz check jni-packages";

_int_batch_call = _env->GetStaticMethodID(_udf_helper_class, "batchCall",
"([Ljava/lang/Object;Ljava/lang/reflect/Method;I)[I");
_get_boxed_result =
Expand Down Expand Up @@ -283,6 +289,14 @@ jobject JVMFunctionHelper::create_object_array(jobject o, int num_rows) {
return res_arr;
}

jobject JVMFunctionHelper::batch_create_bytebuf(unsigned char* ptr, const uint32_t* offset, int begin, int end) {
int size = end - begin;
auto offsets = _env->NewIntArray(size + 1);
_env->SetIntArrayRegion(offsets, 0, size + 1, (const int32_t*)offset);
LOCAL_REF_GUARD(offsets);
return _env->CallStaticObjectMethod(_udf_helper_class, _batch_create_bytebuf, ptr, offsets, size);
}

void JVMFunctionHelper::batch_update_single(AggBatchCallStub* stub, int state, jobject* input, int cols, int rows) {
auto obj = convert_handle_to_jobject(stub->ctx(), state);
LOCAL_REF_GUARD(obj);
Expand Down Expand Up @@ -476,15 +490,19 @@ StatusOr<JavaGlobalRef> JVMClass::newInstance() const {
}

UDAFStateList::UDAFStateList(JavaGlobalRef&& handle, JavaGlobalRef&& get, JavaGlobalRef&& batch_get,
JavaGlobalRef&& add)
JavaGlobalRef&& add, JavaGlobalRef&& remove, JavaGlobalRef&& clear)
: _handle(std::move(handle)),
_get_method(std::move(get)),
_batch_get_method(std::move(batch_get)),
_add_method(std::move(add)) {
_add_method(std::move(add)),
_remove_method(std::move(remove)),
_clear_method(std::move(clear)) {
auto* env = JVMFunctionHelper::getInstance().getEnv();
_get_method_id = env->FromReflectedMethod(_get_method.handle());
_batch_get_method_id = env->FromReflectedMethod(_batch_get_method.handle());
_add_method_id = env->FromReflectedMethod(_add_method.handle());
_remove_method_id = env->FromReflectedMethod(_remove_method.handle());
_clear_method_id = env->FromReflectedMethod(_clear_method.handle());
}

jobject UDAFStateList::get_state(FunctionContext* ctx, JNIEnv* env, int state_handle) {
Expand All @@ -505,6 +523,16 @@ int UDAFStateList::add_state(FunctionContext* ctx, JNIEnv* env, jobject state) {
return res;
}

void UDAFStateList::remove(FunctionContext* ctx, JNIEnv* env, int state_handle) {
env->CallVoidMethod(_handle.handle(), _remove_method_id, state_handle);
CHECK_UDF_CALL_EXCEPTION(env, ctx);
}

void UDAFStateList::clear(FunctionContext* ctx, JNIEnv* env) {
env->CallVoidMethod(_handle.handle(), _clear_method_id);
CHECK_UDF_CALL_EXCEPTION(env, ctx);
}

ClassLoader::~ClassLoader() {
_handle.clear();
_clazz.clear();
Expand Down Expand Up @@ -804,6 +832,7 @@ void UDAFFunction::destroy(int state) {
// call destroy
env->CallVoidMethod(_udaf_handle, destory, obj);
CHECK_UDF_CALL_EXCEPTION(env, _function_context);
_ctx->states->remove(_function_context, env, state);
}

jvalue UDAFFunction::finalize(int state) {
Expand Down
16 changes: 15 additions & 1 deletion be/src/udf/java/java_udf.h
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,7 @@ class JVMFunctionHelper {
jobject create_boxed_array(int type, int num_rows, bool nullable, DirectByteBuffer* buffs, int sz);
// create object array with the same elements
jobject create_object_array(jobject o, int num_rows);
jobject batch_create_bytebuf(unsigned char* ptr, const uint32_t* offset, int begin, int end);

// batch update single
void batch_update_single(AggBatchCallStub* stub, int state, jobject* input, int cols, int rows);
Expand Down Expand Up @@ -177,6 +178,7 @@ class JVMFunctionHelper {
jmethodID _batch_update;
jmethodID _batch_update_if_not_null;
jmethodID _batch_update_state;
jmethodID _batch_create_bytebuf;
jmethodID _batch_call;
jmethodID _batch_call_no_args;
jmethodID _int_batch_call;
Expand Down Expand Up @@ -360,10 +362,12 @@ class BatchEvaluateStub {
// UDAF State Lists
// mapping a java object as a int index
// use get method to
// TODO: implement a Java binder to avoid using this class
class UDAFStateList {
public:
static inline const char* clazz_name = "com.starrocks.udf.FunctionStates";
UDAFStateList(JavaGlobalRef&& handle, JavaGlobalRef&& get, JavaGlobalRef&& batch_get, JavaGlobalRef&& add);
UDAFStateList(JavaGlobalRef&& handle, JavaGlobalRef&& get, JavaGlobalRef&& batch_get, JavaGlobalRef&& add,
JavaGlobalRef&& remove, JavaGlobalRef&& clear);

jobject handle() { return _handle.handle(); }

Expand All @@ -376,14 +380,24 @@ class UDAFStateList {
// add a state to StateList
int add_state(FunctionContext* ctx, JNIEnv* env, jobject state);

// remove a state from StateList
void remove(FunctionContext* ctx, JNIEnv* env, int state);

// clear all state in StateList
void clear(FunctionContext* ctx, JNIEnv* env);

private:
JavaGlobalRef _handle;
JavaGlobalRef _get_method;
JavaGlobalRef _batch_get_method;
JavaGlobalRef _add_method;
JavaGlobalRef _remove_method;
JavaGlobalRef _clear_method;
jmethodID _get_method_id;
jmethodID _batch_get_method_id;
jmethodID _add_method_id;
jmethodID _remove_method_id;
jmethodID _clear_method_id;
};

// For loading UDF Class
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,9 @@
import java.util.List;

public class FunctionStates<T> {
private final int PROBE_MAX_TIMES = 16;
private int probeIndex = 0;
private int emptySlots = 0;
public List<T> states = new ArrayList<>();

public T get(int idx) {
Expand All @@ -26,18 +29,53 @@ public T get(int idx) {

Object[] batch_get(int[] idxs) {
Object[] res = new Object[idxs.length];
for(int i = 0;i < idxs.length; ++i) {
for (int i = 0; i < idxs.length; ++i) {
if (idxs[i] != -1) {
res[i] = states.get(idxs[i]);
}
}
return res;
}


public int add(T state) throws Exception {
// probe empty slot from
if (emptySlots > states.size() / 2) {
int probeTimes = 0;
do {
if (states.get(probeIndex) == null) {
states.set(probeIndex, state);
emptySlots--;
final int ret = probeIndex++;
if (probeIndex == states.size()) {
probeIndex = 0;
}
return ret;
}
probeIndex++;
if (probeIndex == states.size()) {
probeIndex = 0;
}
probeTimes++;
} while (probeTimes <= PROBE_MAX_TIMES);
}
states.add(state);
return states.size() - 1;
}

public void remove(int idx) {
emptySlots++;
states.set(idx, null);
}

public void clear() {
states.clear();
emptySlots = 0;
probeIndex = 0;
}

// used for test
public int size() {
return states.size();
}

}
Loading
Loading