Skip to content

Commit

Permalink
Fix a casting bugs when it tries to laod more than 4G sized index file.
Browse files Browse the repository at this point in the history
Signed-off-by: Dooyong Kim <[email protected]>
  • Loading branch information
Dooyong Kim committed Oct 1, 2024
1 parent fe33151 commit da17102
Show file tree
Hide file tree
Showing 7 changed files with 34 additions and 20 deletions.
4 changes: 2 additions & 2 deletions jni/include/faiss_stream_support.h
Original file line number Diff line number Diff line change
Expand Up @@ -42,11 +42,11 @@ class NativeEngineIndexInputMediator {
copyBytesMethod(getCopyBytesMethod(_jni_interface, _env)) {
}

void copyBytes(int32_t nbytes, uint8_t *destination) {
void copyBytes(int64_t nbytes, uint8_t *destination) {
while (nbytes > 0) {
// Call `copyBytes` to read bytes as many as possible.
const auto readBytes =
jni_interface->CallIntMethodInt(env, indexInput, copyBytesMethod, nbytes);
jni_interface->CallIntMethodLong(env, indexInput, copyBytesMethod, nbytes);

// === Critical Section Start ===

Expand Down
4 changes: 2 additions & 2 deletions jni/include/jni_util.h
Original file line number Diff line number Diff line change
Expand Up @@ -138,7 +138,7 @@ namespace knn_jni {

virtual void ReleasePrimitiveArrayCritical(JNIEnv * env, jarray array, void *carray, jint mode) = 0;

virtual jint CallIntMethodInt(JNIEnv * env, jobject obj, jmethodID methodID, int intArg) = 0;
virtual jint CallIntMethodLong(JNIEnv * env, jobject obj, jmethodID methodID, int64_t longArg) = 0;

// --------------------------------------------------------------------------
};
Expand Down Expand Up @@ -194,7 +194,7 @@ namespace knn_jni {
jclass FindClassFromJNIEnv(JNIEnv * env, const char *name) final;
jmethodID GetMethodID(JNIEnv * env, jclass clazz, const char *name, const char *sig) final;
jfieldID GetFieldID(JNIEnv * env, jclass clazz, const char *name, const char *sig) final;
jint CallIntMethodInt(JNIEnv * env, jobject obj, jmethodID methodID, int intArg) final;
jint CallIntMethodLong(JNIEnv * env, jobject obj, jmethodID methodID, int64_t longArg) final;
void * GetPrimitiveArrayCritical(JNIEnv * env, jarray array, jboolean *isCopy) final;
void ReleasePrimitiveArrayCritical(JNIEnv * env, jarray array, void *carray, jint mode) final;

Expand Down
14 changes: 14 additions & 0 deletions jni/include/nmslib_stream_support.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
//
// Created by Kim, Dooyong on 9/23/24.
//

#ifndef JNI_INCLUDE_NMSLIB_STREAM_SUPPORT_H_
#define JNI_INCLUDE_NMSLIB_STREAM_SUPPORT_H_

namespace knn_jni { namespace stream {



}}

#endif //JNI_INCLUDE_NMSLIB_STREAM_SUPPORT_H_
4 changes: 2 additions & 2 deletions jni/src/jni_util.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -563,8 +563,8 @@ jfieldID knn_jni::JNIUtil::GetFieldID(JNIEnv * env, jclass clazz, const char *na
return env->GetFieldID(clazz, name, sig);
}

jint knn_jni::JNIUtil::CallIntMethodInt(JNIEnv * env, jobject obj, jmethodID methodID, int intArg) {
return env->CallIntMethod(obj, methodID, intArg);
jint knn_jni::JNIUtil::CallIntMethodLong(JNIEnv * env, jobject obj, jmethodID methodID, int64_t longArg) {
return env->CallIntMethod(obj, methodID, longArg);
}

void * knn_jni::JNIUtil::GetPrimitiveArrayCritical(JNIEnv * env, jarray array, jboolean *isCopy) {
Expand Down
24 changes: 12 additions & 12 deletions jni/tests/faiss_stream_support_test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -29,12 +29,12 @@ struct JavaIndexInputMock {
}

// This method is simulating `copyBytes` in IndexInputWithBuffer.
int32_t simulateCopyReads(int32_t readBytes) {
readBytes = std::min(readBytes, (int32_t) buffer.size());
readBytes = std::min(readBytes, (int32_t) (readTargetBytes.size() - nextReadIdx));
int32_t simulateCopyReads(int64_t readBytes) {
readBytes = std::min(readBytes, (int64_t) buffer.size());
readBytes = std::min(readBytes, (int64_t) (readTargetBytes.size() - nextReadIdx));
std::memcpy(buffer.data(), readTargetBytes.data() + nextReadIdx, readBytes);
nextReadIdx += readBytes;
return readBytes;
return (int32_t) readBytes;
}

static std::string makeRandomBytes(int32_t bytesSize) {
Expand Down Expand Up @@ -63,21 +63,21 @@ struct JavaIndexInputMock {
}

std::string readTargetBytes;
int32_t nextReadIdx;
int64_t nextReadIdx;
std::vector<char> buffer;
}; // struct JavaIndexInputMock

void setUpMockJNIUtil(JavaIndexInputMock& javaIndexInputMock, MockJNIUtil& mockJni) {
void setUpMockJNIUtil(JavaIndexInputMock &javaIndexInputMock, MockJNIUtil &mockJni) {
// Set up mocking values + mocking behavior in a method.
ON_CALL(mockJni, FindClassFromJNIEnv).WillByDefault(Return((jclass) 1));
ON_CALL(mockJni, GetMethodID).WillByDefault(Return((jmethodID) 1));
ON_CALL(mockJni, GetFieldID).WillByDefault(Return((jfieldID) 1));
ON_CALL(mockJni, GetObjectField).WillByDefault(Return((jobject) 1));
ON_CALL(mockJni, CallIntMethodInt).WillByDefault([&javaIndexInputMock](JNIEnv *env,
jobject obj,
jmethodID methodID,
int intArg) {
return javaIndexInputMock.simulateCopyReads(intArg);
ON_CALL(mockJni, CallIntMethodLong).WillByDefault([&javaIndexInputMock](JNIEnv *env,
jobject obj,
jmethodID methodID,
int64_t longArg) {
return javaIndexInputMock.simulateCopyReads(longArg);
});
ON_CALL(mockJni, GetPrimitiveArrayCritical).WillByDefault([&javaIndexInputMock](JNIEnv *env,
jarray array,
Expand All @@ -97,7 +97,7 @@ TEST(FaissStreamSupportTest, NativeEngineIndexInputMediatorCopyWhenEmpty) {

// Prepare copying
NativeEngineIndexInputMediator mediator{&mockJni, nullptr, nullptr};
std::string readBuffer (javaIndexInputMock.readTargetBytes.size(), '\0');
std::string readBuffer(javaIndexInputMock.readTargetBytes.size(), '\0');

// Call copyBytes
mediator.copyBytes((int32_t) javaIndexInputMock.readTargetBytes.size(), (uint8_t *) readBuffer.data());
Expand Down
2 changes: 1 addition & 1 deletion jni/tests/test_util.h
Original file line number Diff line number Diff line change
Expand Up @@ -111,7 +111,7 @@ namespace test_util {
MOCK_METHOD(jclass, FindClassFromJNIEnv, (JNIEnv * env, const char *name));
MOCK_METHOD(jmethodID, GetMethodID, (JNIEnv * env, jclass clazz, const char *name, const char *sig));
MOCK_METHOD(jfieldID, GetFieldID, (JNIEnv * env, jclass clazz, const char *name, const char *sig));
MOCK_METHOD(jint, CallIntMethodInt, (JNIEnv * env, jobject obj, jmethodID methodID, int intArg));
MOCK_METHOD(jint, CallIntMethodLong, (JNIEnv * env, jobject obj, jmethodID methodID, int64_t longArg));
MOCK_METHOD(void *, GetPrimitiveArrayCritical, (JNIEnv * env, jarray array, jboolean *isCopy));
MOCK_METHOD(void, ReleasePrimitiveArrayCritical, (JNIEnv * env, jarray array, void *carray, jint mode));
};
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ public IndexInputWithBuffer(@NonNull IndexInput indexInput) {
* @throws IOException
*/
private int copyBytes(long nbytes) throws IOException {
final int readBytes = Math.min(Math.toIntExact(nbytes), buffer.length);
final int readBytes = (int) Math.min(nbytes, buffer.length);
indexInput.readBytes(buffer, 0, readBytes);
return readBytes;
}
Expand Down

0 comments on commit da17102

Please sign in to comment.