diff --git a/jni/include/faiss_stream_support.h b/jni/include/faiss_stream_support.h index a04ff9d5d..65f1631d4 100644 --- a/jni/include/faiss_stream_support.h +++ b/jni/include/faiss_stream_support.h @@ -42,11 +42,11 @@ class NativeEngineIndexInputMediator { copyBytesMethod(getCopyBytesMethod(_jni_interface, _env)) { } - void copyBytes(int32_t nbytes, uint8_t *destination) { + void copyBytes(int64_t nbytes, uint8_t *destination) { while (nbytes > 0) { // Call `copyBytes` to read bytes as many as possible. const auto readBytes = - jni_interface->CallIntMethodInt(env, indexInput, copyBytesMethod, nbytes); + jni_interface->CallIntMethodLong(env, indexInput, copyBytesMethod, nbytes); // === Critical Section Start === diff --git a/jni/include/jni_util.h b/jni/include/jni_util.h index 49b1c0c1b..6b1b926e7 100644 --- a/jni/include/jni_util.h +++ b/jni/include/jni_util.h @@ -138,7 +138,7 @@ namespace knn_jni { virtual void ReleasePrimitiveArrayCritical(JNIEnv * env, jarray array, void *carray, jint mode) = 0; - virtual jint CallIntMethodInt(JNIEnv * env, jobject obj, jmethodID methodID, int intArg) = 0; + virtual jint CallIntMethodLong(JNIEnv * env, jobject obj, jmethodID methodID, int64_t longArg) = 0; // -------------------------------------------------------------------------- }; @@ -194,7 +194,7 @@ namespace knn_jni { jclass FindClassFromJNIEnv(JNIEnv * env, const char *name) final; jmethodID GetMethodID(JNIEnv * env, jclass clazz, const char *name, const char *sig) final; jfieldID GetFieldID(JNIEnv * env, jclass clazz, const char *name, const char *sig) final; - jint CallIntMethodInt(JNIEnv * env, jobject obj, jmethodID methodID, int intArg) final; + jint CallIntMethodLong(JNIEnv * env, jobject obj, jmethodID methodID, int64_t longArg) final; void * GetPrimitiveArrayCritical(JNIEnv * env, jarray array, jboolean *isCopy) final; void ReleasePrimitiveArrayCritical(JNIEnv * env, jarray array, void *carray, jint mode) final; diff --git a/jni/include/nmslib_stream_support.h b/jni/include/nmslib_stream_support.h new file mode 100644 index 000000000..b84a496d3 --- /dev/null +++ b/jni/include/nmslib_stream_support.h @@ -0,0 +1,14 @@ +// +// Created by Kim, Dooyong on 9/23/24. +// + +#ifndef JNI_INCLUDE_NMSLIB_STREAM_SUPPORT_H_ +#define JNI_INCLUDE_NMSLIB_STREAM_SUPPORT_H_ + +namespace knn_jni { namespace stream { + + + +}} + +#endif //JNI_INCLUDE_NMSLIB_STREAM_SUPPORT_H_ diff --git a/jni/src/jni_util.cpp b/jni/src/jni_util.cpp index 1358fddb8..3eaf3b0a1 100644 --- a/jni/src/jni_util.cpp +++ b/jni/src/jni_util.cpp @@ -563,8 +563,8 @@ jfieldID knn_jni::JNIUtil::GetFieldID(JNIEnv * env, jclass clazz, const char *na return env->GetFieldID(clazz, name, sig); } -jint knn_jni::JNIUtil::CallIntMethodInt(JNIEnv * env, jobject obj, jmethodID methodID, int intArg) { - return env->CallIntMethod(obj, methodID, intArg); +jint knn_jni::JNIUtil::CallIntMethodLong(JNIEnv * env, jobject obj, jmethodID methodID, int64_t longArg) { + return env->CallIntMethod(obj, methodID, longArg); } void * knn_jni::JNIUtil::GetPrimitiveArrayCritical(JNIEnv * env, jarray array, jboolean *isCopy) { diff --git a/jni/tests/faiss_stream_support_test.cpp b/jni/tests/faiss_stream_support_test.cpp index cceec8ef2..4045985bb 100644 --- a/jni/tests/faiss_stream_support_test.cpp +++ b/jni/tests/faiss_stream_support_test.cpp @@ -29,12 +29,12 @@ struct JavaIndexInputMock { } // This method is simulating `copyBytes` in IndexInputWithBuffer. - int32_t simulateCopyReads(int32_t readBytes) { - readBytes = std::min(readBytes, (int32_t) buffer.size()); - readBytes = std::min(readBytes, (int32_t) (readTargetBytes.size() - nextReadIdx)); + int32_t simulateCopyReads(int64_t readBytes) { + readBytes = std::min(readBytes, (int64_t) buffer.size()); + readBytes = std::min(readBytes, (int64_t) (readTargetBytes.size() - nextReadIdx)); std::memcpy(buffer.data(), readTargetBytes.data() + nextReadIdx, readBytes); nextReadIdx += readBytes; - return readBytes; + return (int32_t) readBytes; } static std::string makeRandomBytes(int32_t bytesSize) { @@ -63,21 +63,21 @@ struct JavaIndexInputMock { } std::string readTargetBytes; - int32_t nextReadIdx; + int64_t nextReadIdx; std::vector buffer; }; // struct JavaIndexInputMock -void setUpMockJNIUtil(JavaIndexInputMock& javaIndexInputMock, MockJNIUtil& mockJni) { +void setUpMockJNIUtil(JavaIndexInputMock &javaIndexInputMock, MockJNIUtil &mockJni) { // Set up mocking values + mocking behavior in a method. ON_CALL(mockJni, FindClassFromJNIEnv).WillByDefault(Return((jclass) 1)); ON_CALL(mockJni, GetMethodID).WillByDefault(Return((jmethodID) 1)); ON_CALL(mockJni, GetFieldID).WillByDefault(Return((jfieldID) 1)); ON_CALL(mockJni, GetObjectField).WillByDefault(Return((jobject) 1)); - ON_CALL(mockJni, CallIntMethodInt).WillByDefault([&javaIndexInputMock](JNIEnv *env, - jobject obj, - jmethodID methodID, - int intArg) { - return javaIndexInputMock.simulateCopyReads(intArg); + ON_CALL(mockJni, CallIntMethodLong).WillByDefault([&javaIndexInputMock](JNIEnv *env, + jobject obj, + jmethodID methodID, + int64_t longArg) { + return javaIndexInputMock.simulateCopyReads(longArg); }); ON_CALL(mockJni, GetPrimitiveArrayCritical).WillByDefault([&javaIndexInputMock](JNIEnv *env, jarray array, @@ -97,7 +97,7 @@ TEST(FaissStreamSupportTest, NativeEngineIndexInputMediatorCopyWhenEmpty) { // Prepare copying NativeEngineIndexInputMediator mediator{&mockJni, nullptr, nullptr}; - std::string readBuffer (javaIndexInputMock.readTargetBytes.size(), '\0'); + std::string readBuffer(javaIndexInputMock.readTargetBytes.size(), '\0'); // Call copyBytes mediator.copyBytes((int32_t) javaIndexInputMock.readTargetBytes.size(), (uint8_t *) readBuffer.data()); diff --git a/jni/tests/test_util.h b/jni/tests/test_util.h index 4ef1de90a..286000c08 100644 --- a/jni/tests/test_util.h +++ b/jni/tests/test_util.h @@ -111,7 +111,7 @@ namespace test_util { MOCK_METHOD(jclass, FindClassFromJNIEnv, (JNIEnv * env, const char *name)); MOCK_METHOD(jmethodID, GetMethodID, (JNIEnv * env, jclass clazz, const char *name, const char *sig)); MOCK_METHOD(jfieldID, GetFieldID, (JNIEnv * env, jclass clazz, const char *name, const char *sig)); - MOCK_METHOD(jint, CallIntMethodInt, (JNIEnv * env, jobject obj, jmethodID methodID, int intArg)); + MOCK_METHOD(jint, CallIntMethodLong, (JNIEnv * env, jobject obj, jmethodID methodID, int64_t longArg)); MOCK_METHOD(void *, GetPrimitiveArrayCritical, (JNIEnv * env, jarray array, jboolean *isCopy)); MOCK_METHOD(void, ReleasePrimitiveArrayCritical, (JNIEnv * env, jarray array, void *carray, jint mode)); }; diff --git a/src/main/java/org/opensearch/knn/index/store/IndexInputWithBuffer.java b/src/main/java/org/opensearch/knn/index/store/IndexInputWithBuffer.java index 426232252..273a4deac 100644 --- a/src/main/java/org/opensearch/knn/index/store/IndexInputWithBuffer.java +++ b/src/main/java/org/opensearch/knn/index/store/IndexInputWithBuffer.java @@ -34,7 +34,7 @@ public IndexInputWithBuffer(@NonNull IndexInput indexInput) { * @throws IOException */ private int copyBytes(long nbytes) throws IOException { - final int readBytes = Math.min(Math.toIntExact(nbytes), buffer.length); + final int readBytes = (int) Math.min(nbytes, buffer.length); indexInput.readBytes(buffer, 0, readBytes); return readBytes; }