Skip to content

Commit

Permalink
[BugFix] Fix Java UDAF OOM in spill or sorted streaming agg (#48618)
Browse files Browse the repository at this point in the history
Signed-off-by: stdpain <[email protected]>
(cherry picked from commit 36b3aae)
  • Loading branch information
stdpain authored and mergify[bot] committed Jul 22, 2024
1 parent 5f9a143 commit 882ceee
Show file tree
Hide file tree
Showing 13 changed files with 292 additions and 22 deletions.
14 changes: 14 additions & 0 deletions be/src/exec/aggregator.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -534,6 +534,13 @@ Status Aggregator::_reset_state(RuntimeState* state, bool reset_sink_complete) {
} else if (!_is_only_group_by_columns) {
_release_agg_memory();
}

for (int i = 0; i < _agg_functions.size(); i++) {
if (_agg_fn_ctxs[i]) {
_agg_fn_ctxs[i]->release_mems();
}
}

_mem_pool->free_all();

if (_group_by_expr_ctxs.empty()) {
Expand All @@ -546,6 +553,7 @@ Status Aggregator::_reset_state(RuntimeState* state, bool reset_sink_complete) {
} else {
TRY_CATCH_BAD_ALLOC(_init_agg_hash_variant(_hash_map_variant));
}

// _state_allocator holds the entries of the hash_map/hash_set, when iterating a hash_map/set, the _state_allocator
// is used to access these entries, so we must reset the _state_allocator along with the hash_map/hash_set.
_state_allocator.reset();
Expand Down Expand Up @@ -604,6 +612,12 @@ void Aggregator::close(RuntimeState* state) {
_mem_pool->free_all();
}

for (int i = 0; i < _agg_functions.size(); i++) {
if (_agg_fn_ctxs[i]) {
_agg_fn_ctxs[i]->release_mems();
}
}

if (_is_only_group_by_columns) {
_hash_set_variant.reset();
} else {
Expand Down
5 changes: 4 additions & 1 deletion be/src/exprs/agg/java_udaf_function.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -88,7 +88,10 @@ Status init_udaf_context(int64_t id, const std::string& url, const std::string&
ASSIGN_OR_RETURN(auto get_func, analyzer->get_method_object(state_clazz.clazz(), "get"));
ASSIGN_OR_RETURN(auto batch_get_func, analyzer->get_method_object(state_clazz.clazz(), "batch_get"));
ASSIGN_OR_RETURN(auto add_func, analyzer->get_method_object(state_clazz.clazz(), "add"));
udaf_ctx->states = std::make_unique<UDAFStateList>(std::move(instance), get_func, batch_get_func, add_func);
ASSIGN_OR_RETURN(auto remove_func, analyzer->get_method_object(state_clazz.clazz(), "remove"));
ASSIGN_OR_RETURN(auto clear_func, analyzer->get_method_object(state_clazz.clazz(), "clear"));
udaf_ctx->states = std::make_unique<UDAFStateList>(std::move(instance), get_func, batch_get_func, add_func,
remove_func, clear_func);
udaf_ctx->_func = std::make_unique<UDAFFunction>(udaf_ctx->handle.handle(), context, udaf_ctx);

return Status::OK();
Expand Down
44 changes: 30 additions & 14 deletions be/src/exprs/agg/java_udaf_function.h
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
#pragma once

#include <cstring>
#include <limits>
#include <memory>
#include <numeric>
#include <string>
Expand Down Expand Up @@ -255,8 +256,8 @@ class JavaUDAFAggregateFunction : public AggregateFunction {

// This is only used to get portion of the entire binary column
template <class StatesProvider, class MergeCaller>
void _merge_batch_process(StatesProvider&& states_provider, MergeCaller&& caller, const Column* column,
size_t start, size_t size) const {
void _merge_batch_process(FunctionContext* ctx, StatesProvider&& states_provider, MergeCaller&& caller,
const Column* column, size_t start, size_t size, bool need_multi_buffer) const {
auto& helper = JVMFunctionHelper::getInstance();
auto* env = helper.getEnv();
// get state lists
Expand All @@ -269,17 +270,32 @@ class JavaUDAFAggregateFunction : public AggregateFunction {
ColumnHelper::get_binary_column(const_cast<Column*>(ColumnHelper::get_data_column(column)));

auto& serialized_bytes = serialized_column->get_bytes();
size_t start_offset = serialized_column->get_offset()[start];
size_t end_offset = serialized_column->get_offset()[start + size];
const auto& offsets = serialized_column->get_offset();

auto buffer =
std::make_unique<DirectByteBuffer>(serialized_bytes.data() + start_offset, end_offset - start_offset);
auto buffer_array = helper.create_object_array(buffer->handle(), size);
RETURN_IF_UNLIKELY_NULL(buffer_array, (void)0);
LOCAL_REF_GUARD_ENV(env, buffer_array);
if (serialized_bytes.size() > std::numeric_limits<int>::max()) {
ctx->set_error("serialized column size is too large");
return;
}

// batch call merge
caller(state_array, buffer_array);
if (!need_multi_buffer) {
size_t start_offset = serialized_column->get_offset()[start];
size_t end_offset = serialized_column->get_offset()[start + size];
// create one buffer will be ok
auto buffer = std::make_unique<DirectByteBuffer>(serialized_bytes.data() + start_offset,
end_offset - start_offset);
auto buffer_array = helper.create_object_array(buffer->handle(), size);
RETURN_IF_UNLIKELY_NULL(buffer_array, (void)0);
LOCAL_REF_GUARD_ENV(env, buffer_array);
// batch call merge
caller(state_array, buffer_array);
} else {
auto buffer_array =
helper.batch_create_bytebuf(serialized_bytes.data(), offsets.data(), start, start + size);
RETURN_IF_UNLIKELY_NULL(buffer_array, (void)0);
LOCAL_REF_GUARD_ENV(env, buffer_array);
// batch call merge
caller(state_array, buffer_array);
}
}

void merge_batch(FunctionContext* ctx, size_t batch_size, size_t state_offset, const Column* column,
Expand All @@ -300,7 +316,7 @@ class JavaUDAFAggregateFunction : public AggregateFunction {
helper.batch_update_state(ctx, ctx->udaf_ctxs()->handle.handle(), ctx->udaf_ctxs()->merge->method.handle(),
state_and_buffer, 2);
};
_merge_batch_process(std::move(provider), std::move(merger), column, 0, batch_size);
_merge_batch_process(ctx, std::move(provider), std::move(merger), column, 0, batch_size, false);
}

void merge_batch_selectively(FunctionContext* ctx, size_t batch_size, size_t state_offset, const Column* column,
Expand All @@ -318,7 +334,7 @@ class JavaUDAFAggregateFunction : public AggregateFunction {
helper.batch_update_if_not_null(ctx, ctx->udaf_ctxs()->handle.handle(),
ctx->udaf_ctxs()->merge->method.handle(), state_array, state_and_buffer, 1);
};
_merge_batch_process(std::move(provider), std::move(merger), column, 0, batch_size);
_merge_batch_process(ctx, std::move(provider), std::move(merger), column, 0, batch_size, true);
}

void merge_batch_single_state(FunctionContext* ctx, AggDataPtr __restrict state, const Column* column, size_t start,
Expand All @@ -336,7 +352,7 @@ class JavaUDAFAggregateFunction : public AggregateFunction {
helper.batch_update_state(ctx, ctx->udaf_ctxs()->handle.handle(), ctx->udaf_ctxs()->merge->method.handle(),
state_and_buffer, 2);
};
_merge_batch_process(std::move(provider), std::move(merger), column, start, size);
_merge_batch_process(ctx, std::move(provider), std::move(merger), column, start, size, false);
}

void batch_serialize(FunctionContext* ctx, size_t batch_size, const Buffer<AggDataPtr>& agg_states,
Expand Down
6 changes: 4 additions & 2 deletions be/src/exprs/agg/java_window_function.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -65,8 +65,10 @@ Status window_init_jvm_context(int64_t fid, const std::string& url, const std::s
ASSIGN_OR_RETURN(auto get_func, analyzer->get_method_object(state_clazz.clazz(), "get"));
ASSIGN_OR_RETURN(auto batch_get_func, analyzer->get_method_object(state_clazz.clazz(), "batch_get"));
ASSIGN_OR_RETURN(auto add_func, analyzer->get_method_object(state_clazz.clazz(), "add"));

udaf_ctx->states = std::make_unique<UDAFStateList>(std::move(instance), get_func, batch_get_func, add_func);
ASSIGN_OR_RETURN(auto remove_func, analyzer->get_method_object(state_clazz.clazz(), "remove"));
ASSIGN_OR_RETURN(auto clear_func, analyzer->get_method_object(state_clazz.clazz(), "clear"));
udaf_ctx->states = std::make_unique<UDAFStateList>(std::move(instance), get_func, batch_get_func, add_func,
remove_func, clear_func);
udaf_ctx->_func = std::make_unique<UDAFFunction>(udaf_ctx->handle.handle(), context, udaf_ctx);

return Status::OK();
Expand Down
8 changes: 8 additions & 0 deletions be/src/exprs/function_context.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
#include "exprs/agg/java_udaf_function.h"
#include "runtime/runtime_state.h"
#include "types/logical_type_infra.h"
#include "udf/java/java_udf.h"

namespace starrocks {

Expand Down Expand Up @@ -132,6 +133,13 @@ void* FunctionContext::get_function_state(FunctionStateScope scope) const {
}
}

void FunctionContext::release_mems() {
if (_jvm_udaf_ctxs != nullptr && _jvm_udaf_ctxs->states) {
auto env = JVMFunctionHelper::getInstance().getEnv();
_jvm_udaf_ctxs->states->clear(this, env);
}
}

void FunctionContext::set_error(const char* error_msg, const bool is_udf) {
std::lock_guard<std::mutex> lock(_error_msg_mutex);
if (_error_msg.empty()) {
Expand Down
2 changes: 2 additions & 0 deletions be/src/exprs/function_context.h
Original file line number Diff line number Diff line change
Expand Up @@ -167,6 +167,8 @@ class FunctionContext {

JavaUDAFContext* udaf_ctxs() { return _jvm_udaf_ctxs.get(); }

void release_mems();

ssize_t get_group_concat_max_len() { return group_concat_max_len; }
// min value is 4, default is 1024
void set_group_concat_max_len(ssize_t len) { group_concat_max_len = len < 4 ? 4 : len; }
Expand Down
33 changes: 31 additions & 2 deletions be/src/udf/java/java_udf.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
#include "column/binary_column.h"
#include "column/column.h"
#include "common/status.h"
#include "exprs/function_context.h"
#include "fmt/core.h"
#include "jni.h"
#include "types/logical_type.h"
Expand Down Expand Up @@ -149,6 +150,11 @@ void JVMFunctionHelper::_init() {
_udf_helper_class, "batchUpdateIfNotNull",
"(Ljava/lang/Object;Ljava/lang/reflect/Method;Lcom/starrocks/udf/FunctionStates;[I[Ljava/lang/Object;)V");

_batch_create_bytebuf =
_env->GetStaticMethodID(_udf_helper_class, "batchCreateDirectBuffer", "(J[II)[Ljava/lang/Object;");

CHECK(_batch_create_bytebuf) << " not found method batchCreateDirectBuffer plz check jni-packages";

_int_batch_call = _env->GetStaticMethodID(_udf_helper_class, "batchCall",
"([Ljava/lang/Object;Ljava/lang/reflect/Method;I)[I");
_get_boxed_result =
Expand Down Expand Up @@ -290,6 +296,14 @@ jobject JVMFunctionHelper::create_object_array(jobject o, int num_rows) {
return res_arr;
}

jobject JVMFunctionHelper::batch_create_bytebuf(unsigned char* ptr, const uint32_t* offset, int begin, int end) {
int size = end - begin;
auto offsets = _env->NewIntArray(size + 1);
_env->SetIntArrayRegion(offsets, 0, size + 1, (const int32_t*)offset);
LOCAL_REF_GUARD(offsets);
return _env->CallStaticObjectMethod(_udf_helper_class, _batch_create_bytebuf, ptr, offsets, size);
}

void JVMFunctionHelper::batch_update_single(AggBatchCallStub* stub, int state, jobject* input, int cols, int rows) {
auto obj = convert_handle_to_jobject(stub->ctx(), state);
LOCAL_REF_GUARD(obj);
Expand Down Expand Up @@ -483,15 +497,19 @@ StatusOr<JavaGlobalRef> JVMClass::newInstance() const {
}

UDAFStateList::UDAFStateList(JavaGlobalRef&& handle, JavaGlobalRef&& get, JavaGlobalRef&& batch_get,
JavaGlobalRef&& add)
JavaGlobalRef&& add, JavaGlobalRef&& remove, JavaGlobalRef&& clear)
: _handle(std::move(handle)),
_get_method(std::move(get)),
_batch_get_method(std::move(batch_get)),
_add_method(std::move(add)) {
_add_method(std::move(add)),
_remove_method(std::move(remove)),
_clear_method(std::move(clear)) {
auto* env = JVMFunctionHelper::getInstance().getEnv();
_get_method_id = env->FromReflectedMethod(_get_method.handle());
_batch_get_method_id = env->FromReflectedMethod(_batch_get_method.handle());
_add_method_id = env->FromReflectedMethod(_add_method.handle());
_remove_method_id = env->FromReflectedMethod(_remove_method.handle());
_clear_method_id = env->FromReflectedMethod(_clear_method.handle());
}

jobject UDAFStateList::get_state(FunctionContext* ctx, JNIEnv* env, int state_handle) {
Expand All @@ -512,6 +530,16 @@ int UDAFStateList::add_state(FunctionContext* ctx, JNIEnv* env, jobject state) {
return res;
}

void UDAFStateList::remove(FunctionContext* ctx, JNIEnv* env, int state_handle) {
env->CallVoidMethod(_handle.handle(), _remove_method_id, state_handle);
CHECK_UDF_CALL_EXCEPTION(env, ctx);
}

void UDAFStateList::clear(FunctionContext* ctx, JNIEnv* env) {
env->CallVoidMethod(_handle.handle(), _clear_method_id);
CHECK_UDF_CALL_EXCEPTION(env, ctx);
}

ClassLoader::~ClassLoader() {
_handle.clear();
_clazz.clear();
Expand Down Expand Up @@ -811,6 +839,7 @@ void UDAFFunction::destroy(int state) {
// call destroy
env->CallVoidMethod(_udaf_handle, destory, obj);
CHECK_UDF_CALL_EXCEPTION(env, _function_context);
_ctx->states->remove(_function_context, env, state);
}

jvalue UDAFFunction::finalize(int state) {
Expand Down
16 changes: 15 additions & 1 deletion be/src/udf/java/java_udf.h
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,7 @@ class JVMFunctionHelper {
jobject create_boxed_array(int type, int num_rows, bool nullable, DirectByteBuffer* buffs, int sz);
// create object array with the same elements
jobject create_object_array(jobject o, int num_rows);
jobject batch_create_bytebuf(unsigned char* ptr, const uint32_t* offset, int begin, int end);

// batch update single
void batch_update_single(AggBatchCallStub* stub, int state, jobject* input, int cols, int rows);
Expand Down Expand Up @@ -176,6 +177,7 @@ class JVMFunctionHelper {
jmethodID _batch_update;
jmethodID _batch_update_if_not_null;
jmethodID _batch_update_state;
jmethodID _batch_create_bytebuf;
jmethodID _batch_call;
jmethodID _batch_call_no_args;
jmethodID _int_batch_call;
Expand Down Expand Up @@ -359,10 +361,12 @@ class BatchEvaluateStub {
// UDAF State Lists
// mapping a java object as a int index
// use get method to
// TODO: implement a Java binder to avoid using this class
class UDAFStateList {
public:
static inline const char* clazz_name = "com.starrocks.udf.FunctionStates";
UDAFStateList(JavaGlobalRef&& handle, JavaGlobalRef&& get, JavaGlobalRef&& batch_get, JavaGlobalRef&& add);
UDAFStateList(JavaGlobalRef&& handle, JavaGlobalRef&& get, JavaGlobalRef&& batch_get, JavaGlobalRef&& add,
JavaGlobalRef&& remove, JavaGlobalRef&& clear);

jobject handle() { return _handle.handle(); }

Expand All @@ -375,14 +379,24 @@ class UDAFStateList {
// add a state to StateList
int add_state(FunctionContext* ctx, JNIEnv* env, jobject state);

// remove a state from StateList
void remove(FunctionContext* ctx, JNIEnv* env, int state);

// clear all state in StateList
void clear(FunctionContext* ctx, JNIEnv* env);

private:
JavaGlobalRef _handle;
JavaGlobalRef _get_method;
JavaGlobalRef _batch_get_method;
JavaGlobalRef _add_method;
JavaGlobalRef _remove_method;
JavaGlobalRef _clear_method;
jmethodID _get_method_id;
jmethodID _batch_get_method_id;
jmethodID _add_method_id;
jmethodID _remove_method_id;
jmethodID _clear_method_id;
};

// For loading UDF Class
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,9 @@
import java.util.List;

public class FunctionStates<T> {
private final int PROBE_MAX_TIMES = 16;
private int probeIndex = 0;
private int emptySlots = 0;
public List<T> states = new ArrayList<>();

public T get(int idx) {
Expand All @@ -26,18 +29,53 @@ public T get(int idx) {

Object[] batch_get(int[] idxs) {
Object[] res = new Object[idxs.length];
for(int i = 0;i < idxs.length; ++i) {
for (int i = 0; i < idxs.length; ++i) {
if (idxs[i] != -1) {
res[i] = states.get(idxs[i]);
}
}
return res;
}


public int add(T state) throws Exception {
// probe empty slot from
if (emptySlots > states.size() / 2) {
int probeTimes = 0;
do {
if (states.get(probeIndex) == null) {
states.set(probeIndex, state);
emptySlots--;
final int ret = probeIndex++;
if (probeIndex == states.size()) {
probeIndex = 0;
}
return ret;
}
probeIndex++;
if (probeIndex == states.size()) {
probeIndex = 0;
}
probeTimes++;
} while (probeTimes <= PROBE_MAX_TIMES);
}
states.add(state);
return states.size() - 1;
}

public void remove(int idx) {
emptySlots++;
states.set(idx, null);
}

public void clear() {
states.clear();
emptySlots = 0;
probeIndex = 0;
}

// used for test
public int size() {
return states.size();
}

}
Loading

0 comments on commit 882ceee

Please sign in to comment.