Skip to content

Commit

Permalink
[BugFix] Fix Java UDAF OOM when spill
Browse files Browse the repository at this point in the history
In past implementations, Function States would not be released.

For sorted streaming agg or spill aggregates, they will not be released even if destroy is called.
This PR adds a remove state interface to Function States.

After calling destroy, it will call function states-> remove.

Signed-off-by: stdpain <[email protected]>
  • Loading branch information
stdpain committed Jul 20, 2024
1 parent 6132193 commit be7d803
Show file tree
Hide file tree
Showing 13 changed files with 292 additions and 22 deletions.
14 changes: 14 additions & 0 deletions be/src/exec/aggregator.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -553,6 +553,13 @@ Status Aggregator::_reset_state(RuntimeState* state, bool reset_sink_complete) {
} else if (!_is_only_group_by_columns) {
_release_agg_memory();
}

for (int i = 0; i < _agg_functions.size(); i++) {
if (_agg_fn_ctxs[i]) {
_agg_fn_ctxs[i]->release_mems();
}
}

_mem_pool->free_all();
_agg_state_mem_usage = 0;

Expand All @@ -566,6 +573,7 @@ Status Aggregator::_reset_state(RuntimeState* state, bool reset_sink_complete) {
} else {
TRY_CATCH_BAD_ALLOC(_init_agg_hash_variant(_hash_map_variant));
}

// _state_allocator holds the entries of the hash_map/hash_set, when iterating a hash_map/set, the _state_allocator
// is used to access these entries, so we must reset the _state_allocator along with the hash_map/hash_set.
_state_allocator.reset();
Expand Down Expand Up @@ -623,6 +631,12 @@ void Aggregator::close(RuntimeState* state) {
_mem_pool->free_all();
}

for (int i = 0; i < _agg_functions.size(); i++) {
if (_agg_fn_ctxs[i]) {
_agg_fn_ctxs[i]->release_mems();
}
}

if (_is_only_group_by_columns) {
_hash_set_variant.reset();
} else {
Expand Down
5 changes: 4 additions & 1 deletion be/src/exprs/agg/java_udaf_function.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -88,7 +88,10 @@ Status init_udaf_context(int64_t id, const std::string& url, const std::string&
ASSIGN_OR_RETURN(auto get_func, analyzer->get_method_object(state_clazz.clazz(), "get"));
ASSIGN_OR_RETURN(auto batch_get_func, analyzer->get_method_object(state_clazz.clazz(), "batch_get"));
ASSIGN_OR_RETURN(auto add_func, analyzer->get_method_object(state_clazz.clazz(), "add"));
udaf_ctx->states = std::make_unique<UDAFStateList>(std::move(instance), get_func, batch_get_func, add_func);
ASSIGN_OR_RETURN(auto remove_func, analyzer->get_method_object(state_clazz.clazz(), "remove"));
ASSIGN_OR_RETURN(auto clear_func, analyzer->get_method_object(state_clazz.clazz(), "clear"));
udaf_ctx->states = std::make_unique<UDAFStateList>(std::move(instance), get_func, batch_get_func, add_func,
remove_func, clear_func);
udaf_ctx->_func = std::make_unique<UDAFFunction>(udaf_ctx->handle.handle(), context, udaf_ctx);

return Status::OK();
Expand Down
44 changes: 30 additions & 14 deletions be/src/exprs/agg/java_udaf_function.h
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
#pragma once

#include <cstring>
#include <limits>
#include <memory>
#include <numeric>
#include <string>
Expand Down Expand Up @@ -255,8 +256,8 @@ class JavaUDAFAggregateFunction : public AggregateFunction {

// This is only used to get portion of the entire binary column
template <class StatesProvider, class MergeCaller>
void _merge_batch_process(StatesProvider&& states_provider, MergeCaller&& caller, const Column* column,
size_t start, size_t size) const {
void _merge_batch_process(FunctionContext* ctx, StatesProvider&& states_provider, MergeCaller&& caller,
const Column* column, size_t start, size_t size, bool need_multi_buffer) const {
auto& helper = JVMFunctionHelper::getInstance();
auto* env = helper.getEnv();
// get state lists
Expand All @@ -269,17 +270,32 @@ class JavaUDAFAggregateFunction : public AggregateFunction {
ColumnHelper::get_binary_column(const_cast<Column*>(ColumnHelper::get_data_column(column)));

auto& serialized_bytes = serialized_column->get_bytes();
size_t start_offset = serialized_column->get_offset()[start];
size_t end_offset = serialized_column->get_offset()[start + size];
const auto& offsets = serialized_column->get_offset();

auto buffer =
std::make_unique<DirectByteBuffer>(serialized_bytes.data() + start_offset, end_offset - start_offset);
auto buffer_array = helper.create_object_array(buffer->handle(), size);
RETURN_IF_UNLIKELY_NULL(buffer_array, (void)0);
LOCAL_REF_GUARD_ENV(env, buffer_array);
if (serialized_bytes.size() > std::numeric_limits<int>::max()) {
ctx->set_error("serialized column size is too large");
return;
}

// batch call merge
caller(state_array, buffer_array);
if (!need_multi_buffer) {
size_t start_offset = serialized_column->get_offset()[start];
size_t end_offset = serialized_column->get_offset()[start + size];
// create one buffer will be ok
auto buffer = std::make_unique<DirectByteBuffer>(serialized_bytes.data() + start_offset,
end_offset - start_offset);
auto buffer_array = helper.create_object_array(buffer->handle(), size);
RETURN_IF_UNLIKELY_NULL(buffer_array, (void)0);
LOCAL_REF_GUARD_ENV(env, buffer_array);
// batch call merge
caller(state_array, buffer_array);
} else {
auto buffer_array =
helper.batch_create_bytebuf(serialized_bytes.data(), offsets.data(), start, start + size);
RETURN_IF_UNLIKELY_NULL(buffer_array, (void)0);
LOCAL_REF_GUARD_ENV(env, buffer_array);
// batch call merge
caller(state_array, buffer_array);
}
}

void merge_batch(FunctionContext* ctx, size_t batch_size, size_t state_offset, const Column* column,
Expand All @@ -300,7 +316,7 @@ class JavaUDAFAggregateFunction : public AggregateFunction {
helper.batch_update_state(ctx, ctx->udaf_ctxs()->handle.handle(), ctx->udaf_ctxs()->merge->method.handle(),
state_and_buffer, 2);
};
_merge_batch_process(std::move(provider), std::move(merger), column, 0, batch_size);
_merge_batch_process(ctx, std::move(provider), std::move(merger), column, 0, batch_size, false);
}

void merge_batch_selectively(FunctionContext* ctx, size_t batch_size, size_t state_offset, const Column* column,
Expand All @@ -318,7 +334,7 @@ class JavaUDAFAggregateFunction : public AggregateFunction {
helper.batch_update_if_not_null(ctx, ctx->udaf_ctxs()->handle.handle(),
ctx->udaf_ctxs()->merge->method.handle(), state_array, state_and_buffer, 1);
};
_merge_batch_process(std::move(provider), std::move(merger), column, 0, batch_size);
_merge_batch_process(ctx, std::move(provider), std::move(merger), column, 0, batch_size, true);
}

void merge_batch_single_state(FunctionContext* ctx, AggDataPtr __restrict state, const Column* column, size_t start,
Expand All @@ -336,7 +352,7 @@ class JavaUDAFAggregateFunction : public AggregateFunction {
helper.batch_update_state(ctx, ctx->udaf_ctxs()->handle.handle(), ctx->udaf_ctxs()->merge->method.handle(),
state_and_buffer, 2);
};
_merge_batch_process(std::move(provider), std::move(merger), column, start, size);
_merge_batch_process(ctx, std::move(provider), std::move(merger), column, start, size, false);
}

void batch_serialize(FunctionContext* ctx, size_t batch_size, const Buffer<AggDataPtr>& agg_states,
Expand Down
6 changes: 4 additions & 2 deletions be/src/exprs/agg/java_window_function.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -65,8 +65,10 @@ Status window_init_jvm_context(int64_t fid, const std::string& url, const std::s
ASSIGN_OR_RETURN(auto get_func, analyzer->get_method_object(state_clazz.clazz(), "get"));
ASSIGN_OR_RETURN(auto batch_get_func, analyzer->get_method_object(state_clazz.clazz(), "batch_get"));
ASSIGN_OR_RETURN(auto add_func, analyzer->get_method_object(state_clazz.clazz(), "add"));

udaf_ctx->states = std::make_unique<UDAFStateList>(std::move(instance), get_func, batch_get_func, add_func);
ASSIGN_OR_RETURN(auto remove_func, analyzer->get_method_object(state_clazz.clazz(), "remove"));
ASSIGN_OR_RETURN(auto clear_func, analyzer->get_method_object(state_clazz.clazz(), "clear"));
udaf_ctx->states = std::make_unique<UDAFStateList>(std::move(instance), get_func, batch_get_func, add_func,
remove_func, clear_func);
udaf_ctx->_func = std::make_unique<UDAFFunction>(udaf_ctx->handle.handle(), context, udaf_ctx);

return Status::OK();
Expand Down
8 changes: 8 additions & 0 deletions be/src/exprs/function_context.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
#include "runtime/runtime_state.h"
#include "storage/rowset/bloom_filter.h"
#include "types/logical_type_infra.h"
#include "udf/java/java_udf.h"

namespace starrocks {

Expand Down Expand Up @@ -131,6 +132,13 @@ void* FunctionContext::get_function_state(FunctionStateScope scope) const {
}
}

void FunctionContext::release_mems() {
if (_jvm_udaf_ctxs != nullptr && _jvm_udaf_ctxs->states) {
auto env = JVMFunctionHelper::getInstance().getEnv();
_jvm_udaf_ctxs->states->clear(this, env);
}
}

void FunctionContext::set_error(const char* error_msg, const bool is_udf) {
std::lock_guard<std::mutex> lock(_error_msg_mutex);
if (_error_msg.empty()) {
Expand Down
2 changes: 2 additions & 0 deletions be/src/exprs/function_context.h
Original file line number Diff line number Diff line change
Expand Up @@ -166,6 +166,8 @@ class FunctionContext {

JavaUDAFContext* udaf_ctxs() { return _jvm_udaf_ctxs.get(); }

void release_mems();

ssize_t get_group_concat_max_len() { return group_concat_max_len; }
// min value is 4, default is 1024
void set_group_concat_max_len(ssize_t len) { group_concat_max_len = len < 4 ? 4 : len; }
Expand Down
33 changes: 31 additions & 2 deletions be/src/udf/java/java_udf.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
#include "column/binary_column.h"
#include "column/column.h"
#include "common/status.h"
#include "exprs/function_context.h"
#include "fmt/core.h"
#include "jni.h"
#include "types/logical_type.h"
Expand Down Expand Up @@ -149,6 +150,11 @@ void JVMFunctionHelper::_init() {
_udf_helper_class, "batchUpdateIfNotNull",
"(Ljava/lang/Object;Ljava/lang/reflect/Method;Lcom/starrocks/udf/FunctionStates;[I[Ljava/lang/Object;)V");

_batch_create_bytebuf =
_env->GetStaticMethodID(_udf_helper_class, "batchCreateDirectBuffer", "(J[II)[Ljava/lang/Object;");

CHECK(_batch_create_bytebuf) << " not found method batchCreateDirectBuffer plz check jni-packages";

_int_batch_call = _env->GetStaticMethodID(_udf_helper_class, "batchCall",
"([Ljava/lang/Object;Ljava/lang/reflect/Method;I)[I");
_get_boxed_result =
Expand Down Expand Up @@ -283,6 +289,14 @@ jobject JVMFunctionHelper::create_object_array(jobject o, int num_rows) {
return res_arr;
}

jobject JVMFunctionHelper::batch_create_bytebuf(unsigned char* ptr, const uint32_t* offset, int begin, int end) {
int size = end - begin;
auto offsets = _env->NewIntArray(size + 1);
_env->SetIntArrayRegion(offsets, 0, size + 1, (const int32_t*)offset);
LOCAL_REF_GUARD(offsets);
return _env->CallStaticObjectMethod(_udf_helper_class, _batch_create_bytebuf, ptr, offsets, size);
}

void JVMFunctionHelper::batch_update_single(AggBatchCallStub* stub, int state, jobject* input, int cols, int rows) {
auto obj = convert_handle_to_jobject(stub->ctx(), state);
LOCAL_REF_GUARD(obj);
Expand Down Expand Up @@ -476,15 +490,19 @@ StatusOr<JavaGlobalRef> JVMClass::newInstance() const {
}

UDAFStateList::UDAFStateList(JavaGlobalRef&& handle, JavaGlobalRef&& get, JavaGlobalRef&& batch_get,
JavaGlobalRef&& add)
JavaGlobalRef&& add, JavaGlobalRef&& remove, JavaGlobalRef&& clear)
: _handle(std::move(handle)),
_get_method(std::move(get)),
_batch_get_method(std::move(batch_get)),
_add_method(std::move(add)) {
_add_method(std::move(add)),
_remove_method(std::move(remove)),
_clear_method(std::move(clear)) {
auto* env = JVMFunctionHelper::getInstance().getEnv();
_get_method_id = env->FromReflectedMethod(_get_method.handle());
_batch_get_method_id = env->FromReflectedMethod(_batch_get_method.handle());
_add_method_id = env->FromReflectedMethod(_add_method.handle());
_remove_method_id = env->FromReflectedMethod(_remove_method.handle());
_clear_method_id = env->FromReflectedMethod(_clear_method.handle());
}

jobject UDAFStateList::get_state(FunctionContext* ctx, JNIEnv* env, int state_handle) {
Expand All @@ -505,6 +523,16 @@ int UDAFStateList::add_state(FunctionContext* ctx, JNIEnv* env, jobject state) {
return res;
}

void UDAFStateList::remove(FunctionContext* ctx, JNIEnv* env, int state_handle) {
env->CallVoidMethod(_handle.handle(), _remove_method_id, state_handle);
CHECK_UDF_CALL_EXCEPTION(env, ctx);
}

void UDAFStateList::clear(FunctionContext* ctx, JNIEnv* env) {
env->CallVoidMethod(_handle.handle(), _clear_method_id);
CHECK_UDF_CALL_EXCEPTION(env, ctx);
}

ClassLoader::~ClassLoader() {
_handle.clear();
_clazz.clear();
Expand Down Expand Up @@ -804,6 +832,7 @@ void UDAFFunction::destroy(int state) {
// call destroy
env->CallVoidMethod(_udaf_handle, destory, obj);
CHECK_UDF_CALL_EXCEPTION(env, _function_context);
_ctx->states->remove(_function_context, env, state);
}

jvalue UDAFFunction::finalize(int state) {
Expand Down
16 changes: 15 additions & 1 deletion be/src/udf/java/java_udf.h
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,7 @@ class JVMFunctionHelper {
jobject create_boxed_array(int type, int num_rows, bool nullable, DirectByteBuffer* buffs, int sz);
// create object array with the same elements
jobject create_object_array(jobject o, int num_rows);
jobject batch_create_bytebuf(unsigned char* ptr, const uint32_t* offset, int begin, int end);

// batch update single
void batch_update_single(AggBatchCallStub* stub, int state, jobject* input, int cols, int rows);
Expand Down Expand Up @@ -177,6 +178,7 @@ class JVMFunctionHelper {
jmethodID _batch_update;
jmethodID _batch_update_if_not_null;
jmethodID _batch_update_state;
jmethodID _batch_create_bytebuf;
jmethodID _batch_call;
jmethodID _batch_call_no_args;
jmethodID _int_batch_call;
Expand Down Expand Up @@ -360,10 +362,12 @@ class BatchEvaluateStub {
// UDAF State Lists
// mapping a java object as a int index
// use get method to
// TODO: implement a Java binder to avoid using this class
class UDAFStateList {
public:
static inline const char* clazz_name = "com.starrocks.udf.FunctionStates";
UDAFStateList(JavaGlobalRef&& handle, JavaGlobalRef&& get, JavaGlobalRef&& batch_get, JavaGlobalRef&& add);
UDAFStateList(JavaGlobalRef&& handle, JavaGlobalRef&& get, JavaGlobalRef&& batch_get, JavaGlobalRef&& add,
JavaGlobalRef&& remove, JavaGlobalRef&& clear);

jobject handle() { return _handle.handle(); }

Expand All @@ -376,14 +380,24 @@ class UDAFStateList {
// add a state to StateList
int add_state(FunctionContext* ctx, JNIEnv* env, jobject state);

// remove a state from StateList
void remove(FunctionContext* ctx, JNIEnv* env, int state);

// clear all state in StateList
void clear(FunctionContext* ctx, JNIEnv* env);

private:
JavaGlobalRef _handle;
JavaGlobalRef _get_method;
JavaGlobalRef _batch_get_method;
JavaGlobalRef _add_method;
JavaGlobalRef _remove_method;
JavaGlobalRef _clear_method;
jmethodID _get_method_id;
jmethodID _batch_get_method_id;
jmethodID _add_method_id;
jmethodID _remove_method_id;
jmethodID _clear_method_id;
};

// For loading UDF Class
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,9 @@
import java.util.List;

public class FunctionStates<T> {
private final int PROBE_MAX_TIMES = 16;
private int probeIndex = 0;
private int emptySlots = 0;
public List<T> states = new ArrayList<>();

public T get(int idx) {
Expand All @@ -26,18 +29,53 @@ public T get(int idx) {

Object[] batch_get(int[] idxs) {
Object[] res = new Object[idxs.length];
for(int i = 0;i < idxs.length; ++i) {
for (int i = 0; i < idxs.length; ++i) {
if (idxs[i] != -1) {
res[i] = states.get(idxs[i]);
}
}
return res;
}


public int add(T state) throws Exception {
// probe empty slot from
if (emptySlots > states.size() / 2) {
int probeTimes = 0;
do {
if (states.get(probeIndex) == null) {
states.set(probeIndex, state);
emptySlots--;
final int ret = probeIndex++;
if (probeIndex == states.size()) {
probeIndex = 0;
}
return ret;
}
probeIndex++;
if (probeIndex == states.size()) {
probeIndex = 0;
}
probeTimes++;
} while (probeTimes <= PROBE_MAX_TIMES);
}
states.add(state);
return states.size() - 1;
}

public void remove(int idx) {
emptySlots++;
states.set(idx, null);
}

public void clear() {
states.clear();
emptySlots = 0;
probeIndex = 0;
}

// used for test
public int size() {
return states.size();
}

}
Loading

0 comments on commit be7d803

Please sign in to comment.