Skip to content

Commit

Permalink
[BugFix] Fix Java UDAF OOM when spill
Browse files Browse the repository at this point in the history
In previous implementations, the entries held in FunctionStates were never released.

For sorted streaming aggregation and spilled aggregation, the states were not released even after destroy was called.
This PR adds a remove state interface to Function States.

After destroy is called on a state, the backend now also calls FunctionStates.remove for that state.

Signed-off-by: stdpain <[email protected]>
  • Loading branch information
stdpain committed Jul 19, 2024
1 parent 6132193 commit 3613a87
Show file tree
Hide file tree
Showing 10 changed files with 270 additions and 22 deletions.
5 changes: 4 additions & 1 deletion be/src/exprs/agg/java_udaf_function.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -88,7 +88,10 @@ Status init_udaf_context(int64_t id, const std::string& url, const std::string&
ASSIGN_OR_RETURN(auto get_func, analyzer->get_method_object(state_clazz.clazz(), "get"));
ASSIGN_OR_RETURN(auto batch_get_func, analyzer->get_method_object(state_clazz.clazz(), "batch_get"));
ASSIGN_OR_RETURN(auto add_func, analyzer->get_method_object(state_clazz.clazz(), "add"));
udaf_ctx->states = std::make_unique<UDAFStateList>(std::move(instance), get_func, batch_get_func, add_func);
ASSIGN_OR_RETURN(auto remove_func, analyzer->get_method_object(state_clazz.clazz(), "remove"));
ASSIGN_OR_RETURN(auto clear_func, analyzer->get_method_object(state_clazz.clazz(), "clear"));
udaf_ctx->states = std::make_unique<UDAFStateList>(std::move(instance), get_func, batch_get_func, add_func,
remove_func, clear_func);
udaf_ctx->_func = std::make_unique<UDAFFunction>(udaf_ctx->handle.handle(), context, udaf_ctx);

return Status::OK();
Expand Down
44 changes: 30 additions & 14 deletions be/src/exprs/agg/java_udaf_function.h
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
#pragma once

#include <cstring>
#include <limits>
#include <memory>
#include <numeric>
#include <string>
Expand Down Expand Up @@ -255,8 +256,8 @@ class JavaUDAFAggregateFunction : public AggregateFunction {

// This is only used to get portion of the entire binary column
template <class StatesProvider, class MergeCaller>
void _merge_batch_process(StatesProvider&& states_provider, MergeCaller&& caller, const Column* column,
size_t start, size_t size) const {
void _merge_batch_process(FunctionContext* ctx, StatesProvider&& states_provider, MergeCaller&& caller,
const Column* column, size_t start, size_t size, bool need_multi_buffer) const {
auto& helper = JVMFunctionHelper::getInstance();
auto* env = helper.getEnv();
// get state lists
Expand All @@ -269,17 +270,32 @@ class JavaUDAFAggregateFunction : public AggregateFunction {
ColumnHelper::get_binary_column(const_cast<Column*>(ColumnHelper::get_data_column(column)));

auto& serialized_bytes = serialized_column->get_bytes();
size_t start_offset = serialized_column->get_offset()[start];
size_t end_offset = serialized_column->get_offset()[start + size];
const auto& offsets = serialized_column->get_offset();

auto buffer =
std::make_unique<DirectByteBuffer>(serialized_bytes.data() + start_offset, end_offset - start_offset);
auto buffer_array = helper.create_object_array(buffer->handle(), size);
RETURN_IF_UNLIKELY_NULL(buffer_array, (void)0);
LOCAL_REF_GUARD_ENV(env, buffer_array);
if (serialized_bytes.size() > std::numeric_limits<int>::max()) {
ctx->set_error("serialized column size is too large");
return;
}

// batch call merge
caller(state_array, buffer_array);
if (!need_multi_buffer) {
size_t start_offset = serialized_column->get_offset()[start];
size_t end_offset = serialized_column->get_offset()[start + size];
// create one buffer will be ok
auto buffer = std::make_unique<DirectByteBuffer>(serialized_bytes.data() + start_offset,
end_offset - start_offset);
auto buffer_array = helper.create_object_array(buffer->handle(), size);
RETURN_IF_UNLIKELY_NULL(buffer_array, (void)0);
LOCAL_REF_GUARD_ENV(env, buffer_array);
// batch call merge
caller(state_array, buffer_array);
} else {
auto buffer_array =
helper.batch_create_bytebuf(serialized_bytes.data(), offsets.data(), start, start + size);
RETURN_IF_UNLIKELY_NULL(buffer_array, (void)0);
LOCAL_REF_GUARD_ENV(env, buffer_array);
// batch call merge
caller(state_array, buffer_array);
}
}

void merge_batch(FunctionContext* ctx, size_t batch_size, size_t state_offset, const Column* column,
Expand All @@ -300,7 +316,7 @@ class JavaUDAFAggregateFunction : public AggregateFunction {
helper.batch_update_state(ctx, ctx->udaf_ctxs()->handle.handle(), ctx->udaf_ctxs()->merge->method.handle(),
state_and_buffer, 2);
};
_merge_batch_process(std::move(provider), std::move(merger), column, 0, batch_size);
_merge_batch_process(ctx, std::move(provider), std::move(merger), column, 0, batch_size, false);
}

void merge_batch_selectively(FunctionContext* ctx, size_t batch_size, size_t state_offset, const Column* column,
Expand All @@ -318,7 +334,7 @@ class JavaUDAFAggregateFunction : public AggregateFunction {
helper.batch_update_if_not_null(ctx, ctx->udaf_ctxs()->handle.handle(),
ctx->udaf_ctxs()->merge->method.handle(), state_array, state_and_buffer, 1);
};
_merge_batch_process(std::move(provider), std::move(merger), column, 0, batch_size);
_merge_batch_process(ctx, std::move(provider), std::move(merger), column, 0, batch_size, true);
}

void merge_batch_single_state(FunctionContext* ctx, AggDataPtr __restrict state, const Column* column, size_t start,
Expand All @@ -336,7 +352,7 @@ class JavaUDAFAggregateFunction : public AggregateFunction {
helper.batch_update_state(ctx, ctx->udaf_ctxs()->handle.handle(), ctx->udaf_ctxs()->merge->method.handle(),
state_and_buffer, 2);
};
_merge_batch_process(std::move(provider), std::move(merger), column, start, size);
_merge_batch_process(ctx, std::move(provider), std::move(merger), column, start, size, false);
}

void batch_serialize(FunctionContext* ctx, size_t batch_size, const Buffer<AggDataPtr>& agg_states,
Expand Down
6 changes: 4 additions & 2 deletions be/src/exprs/agg/java_window_function.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -65,8 +65,10 @@ Status window_init_jvm_context(int64_t fid, const std::string& url, const std::s
ASSIGN_OR_RETURN(auto get_func, analyzer->get_method_object(state_clazz.clazz(), "get"));
ASSIGN_OR_RETURN(auto batch_get_func, analyzer->get_method_object(state_clazz.clazz(), "batch_get"));
ASSIGN_OR_RETURN(auto add_func, analyzer->get_method_object(state_clazz.clazz(), "add"));

udaf_ctx->states = std::make_unique<UDAFStateList>(std::move(instance), get_func, batch_get_func, add_func);
ASSIGN_OR_RETURN(auto remove_func, analyzer->get_method_object(state_clazz.clazz(), "remove"));
ASSIGN_OR_RETURN(auto clear_func, analyzer->get_method_object(state_clazz.clazz(), "clear"));
udaf_ctx->states = std::make_unique<UDAFStateList>(std::move(instance), get_func, batch_get_func, add_func,
remove_func, clear_func);
udaf_ctx->_func = std::make_unique<UDAFFunction>(udaf_ctx->handle.handle(), context, udaf_ctx);

return Status::OK();
Expand Down
33 changes: 31 additions & 2 deletions be/src/udf/java/java_udf.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
#include "column/binary_column.h"
#include "column/column.h"
#include "common/status.h"
#include "exprs/function_context.h"
#include "fmt/core.h"
#include "jni.h"
#include "types/logical_type.h"
Expand Down Expand Up @@ -149,6 +150,11 @@ void JVMFunctionHelper::_init() {
_udf_helper_class, "batchUpdateIfNotNull",
"(Ljava/lang/Object;Ljava/lang/reflect/Method;Lcom/starrocks/udf/FunctionStates;[I[Ljava/lang/Object;)V");

_batch_create_bytebuf =
_env->GetStaticMethodID(_udf_helper_class, "batchCreateDirectBuffer", "(J[II)[Ljava/lang/Object;");

CHECK(_batch_create_bytebuf) << " not found method batchCreateDirectBuffer plz check jni-packages";

_int_batch_call = _env->GetStaticMethodID(_udf_helper_class, "batchCall",
"([Ljava/lang/Object;Ljava/lang/reflect/Method;I)[I");
_get_boxed_result =
Expand Down Expand Up @@ -283,6 +289,14 @@ jobject JVMFunctionHelper::create_object_array(jobject o, int num_rows) {
return res_arr;
}

// Creates a Java Object[] holding one direct ByteBuffer per row in [begin, end).
//
// ptr:    base address of the serialized bytes; each row's address is computed on the
//         Java side as ptr + offsets[i].
// offset: row offsets relative to ptr; entries offset[begin] .. offset[end] are read,
//         i.e. (end - begin + 1) values.
// Returns the array produced by the static Java helper bound to _batch_create_bytebuf,
// or nullptr if the temporary int[] could not be allocated or the Java call produced
// no result. The caller owns the returned local reference.
jobject JVMFunctionHelper::batch_create_bytebuf(unsigned char* ptr, const uint32_t* offset, int begin, int end) {
    const int size = end - begin;
    auto offsets = _env->NewIntArray(size + 1);
    // NewIntArray returns NULL when the array cannot be constructed; do not touch it then.
    if (offsets == nullptr) {
        return nullptr;
    }
    LOCAL_REF_GUARD(offsets);
    // Copy the window that starts at `begin`; copying from offset[0] would silently
    // describe the first `size` rows instead of rows [begin, end).
    _env->SetIntArrayRegion(offsets, 0, size + 1, reinterpret_cast<const jint*>(offset + begin));
    // The Java signature is (J[II): the address must be passed as a jlong, not as a raw
    // pointer, because variadic JNI calls read arguments by their declared Java type.
    return _env->CallStaticObjectMethod(_udf_helper_class, _batch_create_bytebuf, reinterpret_cast<jlong>(ptr),
                                        offsets, size);
}

void JVMFunctionHelper::batch_update_single(AggBatchCallStub* stub, int state, jobject* input, int cols, int rows) {
auto obj = convert_handle_to_jobject(stub->ctx(), state);
LOCAL_REF_GUARD(obj);
Expand Down Expand Up @@ -476,15 +490,19 @@ StatusOr<JavaGlobalRef> JVMClass::newInstance() const {
}

// Takes ownership of the Java state-list instance (`handle`) and of the reflected
// Method objects passed in (get, batch_get, add, remove, clear), then converts each
// reflected Method into a jmethodID with JNIEnv::FromReflectedMethod so later calls
// can invoke the Java methods directly.
// NOTE(review): this hunk shows both the previous and the new parameter-list tail and
// initializer-list tail (the lines ending in `add)` / `_add_method(std::move(add)) {`
// are the pre-change versions) — confirm against the real file before editing.
UDAFStateList::UDAFStateList(JavaGlobalRef&& handle, JavaGlobalRef&& get, JavaGlobalRef&& batch_get,
JavaGlobalRef&& add)
JavaGlobalRef&& add, JavaGlobalRef&& remove, JavaGlobalRef&& clear)
: _handle(std::move(handle)),
_get_method(std::move(get)),
_batch_get_method(std::move(batch_get)),
_add_method(std::move(add)) {
_add_method(std::move(add)),
_remove_method(std::move(remove)),
_clear_method(std::move(clear)) {
auto* env = JVMFunctionHelper::getInstance().getEnv();
// Resolve every reflected Method to its jmethodID once, at construction time.
_get_method_id = env->FromReflectedMethod(_get_method.handle());
_batch_get_method_id = env->FromReflectedMethod(_batch_get_method.handle());
_add_method_id = env->FromReflectedMethod(_add_method.handle());
_remove_method_id = env->FromReflectedMethod(_remove_method.handle());
_clear_method_id = env->FromReflectedMethod(_clear_method.handle());
}

jobject UDAFStateList::get_state(FunctionContext* ctx, JNIEnv* env, int state_handle) {
Expand All @@ -505,6 +523,16 @@ int UDAFStateList::add_state(FunctionContext* ctx, JNIEnv* env, jobject state) {
return res;
}

// Invokes the Java-side method bound to _remove_method_id on the wrapped state list,
// passing `state_handle` as its single int argument.
// CHECK_UDF_CALL_EXCEPTION is then applied to (env, ctx) — presumably it surfaces a
// pending Java exception through ctx; confirm against the macro definition.
void UDAFStateList::remove(FunctionContext* ctx, JNIEnv* env, int state_handle) {
    jobject state_list = _handle.handle();
    env->CallVoidMethod(state_list, _remove_method_id, state_handle);
    CHECK_UDF_CALL_EXCEPTION(env, ctx);
}

// Invokes the Java-side method bound to _clear_method_id on the wrapped state list
// (no arguments).
// CHECK_UDF_CALL_EXCEPTION is then applied to (env, ctx) — presumably it surfaces a
// pending Java exception through ctx; confirm against the macro definition.
void UDAFStateList::clear(FunctionContext* ctx, JNIEnv* env) {
    jobject state_list = _handle.handle();
    env->CallVoidMethod(state_list, _clear_method_id);
    CHECK_UDF_CALL_EXCEPTION(env, ctx);
}

ClassLoader::~ClassLoader() {
_handle.clear();
_clazz.clear();
Expand Down Expand Up @@ -804,6 +832,7 @@ void UDAFFunction::destroy(int state) {
// call destroy
env->CallVoidMethod(_udaf_handle, destory, obj);
CHECK_UDF_CALL_EXCEPTION(env, _function_context);
_ctx->states->remove(_function_context, env, state);
}

jvalue UDAFFunction::finalize(int state) {
Expand Down
16 changes: 15 additions & 1 deletion be/src/udf/java/java_udf.h
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,7 @@ class JVMFunctionHelper {
jobject create_boxed_array(int type, int num_rows, bool nullable, DirectByteBuffer* buffs, int sz);
// create object array with the same elements
jobject create_object_array(jobject o, int num_rows);
jobject batch_create_bytebuf(unsigned char* ptr, const uint32_t* offset, int begin, int end);

// batch update single
void batch_update_single(AggBatchCallStub* stub, int state, jobject* input, int cols, int rows);
Expand Down Expand Up @@ -177,6 +178,7 @@ class JVMFunctionHelper {
jmethodID _batch_update;
jmethodID _batch_update_if_not_null;
jmethodID _batch_update_state;
jmethodID _batch_create_bytebuf;
jmethodID _batch_call;
jmethodID _batch_call_no_args;
jmethodID _int_batch_call;
Expand Down Expand Up @@ -360,10 +362,12 @@ class BatchEvaluateStub {
// UDAF State Lists
// maps each Java state object to an int index
// use the get method to fetch the Java object stored under an index
// TODO: implement a Java binder to avoid using this class
class UDAFStateList {
public:
static inline const char* clazz_name = "com.starrocks.udf.FunctionStates";
UDAFStateList(JavaGlobalRef&& handle, JavaGlobalRef&& get, JavaGlobalRef&& batch_get, JavaGlobalRef&& add);
UDAFStateList(JavaGlobalRef&& handle, JavaGlobalRef&& get, JavaGlobalRef&& batch_get, JavaGlobalRef&& add,
JavaGlobalRef&& remove, JavaGlobalRef&& clear);

jobject handle() { return _handle.handle(); }

Expand All @@ -376,14 +380,24 @@ class UDAFStateList {
// add a state to StateList
int add_state(FunctionContext* ctx, JNIEnv* env, jobject state);

// remove a state from StateList
void remove(FunctionContext* ctx, JNIEnv* env, int state);

// clear all state in StateList
void clear(FunctionContext* ctx, JNIEnv* env);

private:
JavaGlobalRef _handle;
JavaGlobalRef _get_method;
JavaGlobalRef _batch_get_method;
JavaGlobalRef _add_method;
JavaGlobalRef _remove_method;
JavaGlobalRef _clear_method;
jmethodID _get_method_id;
jmethodID _batch_get_method_id;
jmethodID _add_method_id;
jmethodID _remove_method_id;
jmethodID _clear_method_id;
};

// For loading UDF Class
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,9 @@
import java.util.List;

public class FunctionStates<T> {
private final int PROBE_MAX_TIMES = 16;
private int probeIndex = 0;
private int emptySlots = 0;
public List<T> states = new ArrayList<>();

public T get(int idx) {
Expand All @@ -26,18 +29,53 @@ public T get(int idx) {

// Looks up several states at once.
// Returns an array with the same length as idxs; position i holds states.get(idxs[i]),
// except that positions whose index is -1 are left null.
// NOTE(review): two equivalent loop headers appear below (old and new formatting of the
// same line from the diff view) — confirm against the real file before editing.
Object[] batch_get(int[] idxs) {
Object[] res = new Object[idxs.length];
for(int i = 0;i < idxs.length; ++i) {
for (int i = 0; i < idxs.length; ++i) {
// -1 marks "no state" for this row; leave the result slot null.
if (idxs[i] != -1) {
res[i] = states.get(idxs[i]);
}
}
return res;
}


/**
 * Stores {@code state} and returns the index it can later be fetched with.
 *
 * When more than half of the list consists of released slots, up to
 * PROBE_MAX_TIMES + 1 slots are inspected, starting at the rotating probe cursor,
 * and the first empty one is reused. If none is found (or the list is not sparse
 * enough) the state is appended at the end.
 *
 * @param state the state object to store
 * @return the index assigned to the state
 */
public int add(T state) throws Exception {
    final boolean mostlyEmpty = emptySlots > states.size() / 2;
    if (mostlyEmpty) {
        for (int attempt = 0; attempt <= PROBE_MAX_TIMES; attempt++) {
            final int slot = probeIndex;
            // Cursor position after this slot, wrapping to the start of the list.
            final int next = (slot + 1 == states.size()) ? 0 : slot + 1;
            if (states.get(slot) == null) {
                states.set(slot, state);
                emptySlots--;
                probeIndex = next;
                return slot;
            }
            probeIndex = next;
        }
    }
    // No reusable slot found within the probe budget: grow the list.
    states.add(state);
    return states.size() - 1;
}

/**
 * Releases the state stored at {@code idx} so that add can reuse the slot.
 *
 * The slot is cleared first and the empty-slot counter is only incremented when the
 * slot actually held a state. An out-of-range index therefore throws without
 * changing the counter, and removing the same index twice is counted once, which
 * keeps emptySlots consistent with the number of null entries that add probes for.
 *
 * @param idx index previously returned by add
 * @throws IndexOutOfBoundsException if idx is not a valid position in the list
 */
public void remove(int idx) {
    final T previous = states.set(idx, null);
    if (previous != null) {
        emptySlots++;
    }
}

/**
 * Drops every stored state and resets the slot-reuse bookkeeping
 * (probe cursor and empty-slot counter) back to zero.
 */
public void clear() {
    states.clear();
    probeIndex = 0;
    emptySlots = 0;
}

// used for test
// Returns the length of the backing list, counting released (null) slots as well.
public int size() {
    final int slotCount = states.size();
    return slotCount;
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
import com.starrocks.utils.Platform;

import java.lang.reflect.Array;
import java.lang.reflect.Constructor;
import java.lang.reflect.InvocationTargetException;
import java.lang.reflect.Method;
import java.math.BigDecimal;
Expand Down Expand Up @@ -724,6 +725,21 @@ public static void batchUpdateIfNotNull(Object o, Method method, FunctionStates
}
}

/**
 * Builds {@code size} buffer objects over a contiguous native memory region.
 *
 * Buffer i is created through the reflective (long, int) constructor of
 * java.nio.DirectByteBuffer with address {@code data + offsets[i]} and length
 * {@code offsets[i + 1] - offsets[i]}.
 *
 * @param data    base native address
 * @param offsets at least size + 1 offsets relative to data
 * @param size    number of buffers to create
 * @return an array of length size holding the created buffer objects
 * @throws Exception if the class or constructor cannot be resolved, made
 *                   accessible, or invoked
 */
public static Object[] batchCreateDirectBuffer(long data, int[] offsets, int size) throws Exception {
    final Constructor<?> bufferCtor =
            Class.forName("java.nio.DirectByteBuffer").getDeclaredConstructor(long.class, int.class);
    bufferCtor.setAccessible(true);

    final Object[] buffers = new Object[size];
    for (int row = 0; row < size; row++) {
        final int rowStart = offsets[row];
        final int rowLength = offsets[row + 1] - rowStart;
        buffers[row] = bufferCtor.newInstance(data + rowStart, rowLength);
    }
    return buffers;
}

// batch call Object(Object...)
public static Object[] batchCall(Object o, Method method, int batchSize, Object[] column)
throws Throwable {
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
// Copyright 2021-present StarRocks, Inc. All rights reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// https://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

package com.starrocks.udf;

import org.junit.Assert;
import org.junit.Test;

import java.util.ArrayList;
import java.util.List;

public class FunctionStatesTest {
    // Empty marker type; the test only cares about slot bookkeeping, not state contents.
    private static class State {

    }

    // Verifies that slots released by remove are reused by add: after 100 rounds of
    // removing and re-adding 4095 states the backing list must stay at 8191 entries
    // instead of growing by 4095 per round.
    @Test
    public void testAddFunctionStates() throws Exception {
        FunctionStates<State> states = new FunctionStates<>();
        List<Integer> lst = new ArrayList<>();
        for (int i = 0; i < 4096; i++) {
            lst.add(states.add(new State()));
        }
        for (int j = 0; j < 100; j++) {
            for (int i = 0; i < 4095; i++) {
                states.remove(lst.get(i));
            }
            lst.clear();
            for (int i = 0; i < 4095; i++) {
                lst.add(states.add(new State()));
            }
        }

        // JUnit expects (expected, actual); the swapped order produced a misleading
        // failure message.
        Assert.assertEquals(8191, states.size());
    }
}
Loading

0 comments on commit 3613a87

Please sign in to comment.