From 4c9e73d92f25d5ead71fe06c8e7d1fec9a2f93b8 Mon Sep 17 00:00:00 2001 From: Ryan Ernst Date: Fri, 22 Mar 2024 10:55:31 -0700 Subject: [PATCH] try unwrapping memory in compress/decompress --- .../jna/JnaCloseableByteBuffer.java | 2 +- .../nativeaccess/jna/JnaZstdLibrary.java | 24 ++++++++++++----- .../org/elasticsearch/nativeaccess/Zstd.java | 12 ++++----- .../nativeaccess/lib/ZstdLibrary.java | 6 ++--- .../jdk/JdkCloseableByteBuffer.java | 5 +++- .../nativeaccess/jdk/JdkZstdLibrary.java | 26 ++++++++++++------- .../elasticsearch/nativeaccess/ZstdTests.java | 26 +++++++++---------- .../codec/zstd/Zstd814StoredFieldsFormat.java | 4 +-- 8 files changed, 64 insertions(+), 41 deletions(-) diff --git a/libs/native/jna/src/main/java/org/elasticsearch/nativeaccess/jna/JnaCloseableByteBuffer.java b/libs/native/jna/src/main/java/org/elasticsearch/nativeaccess/jna/JnaCloseableByteBuffer.java index e47b17e234705..e987f8042691b 100644 --- a/libs/native/jna/src/main/java/org/elasticsearch/nativeaccess/jna/JnaCloseableByteBuffer.java +++ b/libs/native/jna/src/main/java/org/elasticsearch/nativeaccess/jna/JnaCloseableByteBuffer.java @@ -15,7 +15,7 @@ import java.nio.ByteBuffer; class JnaCloseableByteBuffer implements CloseableByteBuffer { - private final Memory memory; + final Memory memory; private final ByteBuffer bufferView; JnaCloseableByteBuffer(int len) { diff --git a/libs/native/jna/src/main/java/org/elasticsearch/nativeaccess/jna/JnaZstdLibrary.java b/libs/native/jna/src/main/java/org/elasticsearch/nativeaccess/jna/JnaZstdLibrary.java index f0581633ea969..8573203a71697 100644 --- a/libs/native/jna/src/main/java/org/elasticsearch/nativeaccess/jna/JnaZstdLibrary.java +++ b/libs/native/jna/src/main/java/org/elasticsearch/nativeaccess/jna/JnaZstdLibrary.java @@ -9,8 +9,12 @@ package org.elasticsearch.nativeaccess.jna; import com.sun.jna.Library; +import com.sun.jna.Memory; import com.sun.jna.Native; +import com.sun.jna.Pointer; + +import org.elasticsearch.nativeaccess.CloseableByteBuffer; import org.elasticsearch.nativeaccess.lib.ZstdLibrary; import java.nio.ByteBuffer; @@ -20,13 +24,13 @@ class JnaZstdLibrary implements ZstdLibrary { private interface NativeFunctions extends Library { long ZSTD_compressBound(int scrLen); - long ZSTD_compress(ByteBuffer dst, int dstLen, ByteBuffer src, int srcLen, int compressionLevel); + long ZSTD_compress(Pointer dst, int dstLen, Pointer src, int srcLen, int compressionLevel); boolean ZSTD_isError(long code); String ZSTD_getErrorName(long code); - long ZSTD_decompress(ByteBuffer dst, int dstLen, ByteBuffer src, int srcLen); + long ZSTD_decompress(Pointer dst, int dstLen, Pointer src, int srcLen); } private final NativeFunctions functions; @@ -41,8 +45,12 @@ public long compressBound(int scrLen) { } @Override - public long compress(ByteBuffer dst, ByteBuffer src, int compressionLevel) { - return functions.ZSTD_compress(dst, dst.remaining(), src, src.remaining(), compressionLevel); + public long compress(CloseableByteBuffer dst, CloseableByteBuffer src, int compressionLevel) { + assert dst instanceof JnaCloseableByteBuffer; + assert src instanceof JnaCloseableByteBuffer; + var nativeDst = (JnaCloseableByteBuffer) dst; + var nativeSrc = (JnaCloseableByteBuffer) src; + return functions.ZSTD_compress(nativeDst.memory.share(dst.buffer().position()), dst.buffer().remaining(), nativeSrc.memory.share(src.buffer().position()), src.buffer().remaining(), compressionLevel); } @Override @@ -56,7 +64,11 @@ public String getErrorName(long code) { } @Override - public long decompress(ByteBuffer dst, ByteBuffer src) { - return functions.ZSTD_decompress(dst, dst.remaining(), src, src.remaining()); + public long decompress(CloseableByteBuffer dst, CloseableByteBuffer src) { + assert dst instanceof JnaCloseableByteBuffer; + assert src instanceof JnaCloseableByteBuffer; + var nativeDst = (JnaCloseableByteBuffer) dst; + var nativeSrc = (JnaCloseableByteBuffer) src; + return functions.ZSTD_decompress(nativeDst.memory.share(dst.buffer().position()), dst.buffer().remaining(), nativeSrc.memory.share(src.buffer().position()), src.buffer().remaining()); } } diff --git a/libs/native/src/main/java/org/elasticsearch/nativeaccess/Zstd.java b/libs/native/src/main/java/org/elasticsearch/nativeaccess/Zstd.java index 6a0d348d5251b..061525f781369 100644 --- a/libs/native/src/main/java/org/elasticsearch/nativeaccess/Zstd.java +++ b/libs/native/src/main/java/org/elasticsearch/nativeaccess/Zstd.java @@ -25,13 +25,13 @@ public final class Zstd { * Compress the content of {@code src} into {@code dst} at compression level {@code level}, and return the number of compressed bytes. * {@link ByteBuffer#position()} and {@link ByteBuffer#limit()} of both {@link ByteBuffer}s are left unmodified. */ - public int compress(ByteBuffer dst, ByteBuffer src, int level) { + public int compress(CloseableByteBuffer dst, CloseableByteBuffer src, int level) { Objects.requireNonNull(dst, "Null destination buffer"); Objects.requireNonNull(src, "Null source buffer"); - assert dst.isDirect(); + /*assert dst.isDirect(); assert dst.isReadOnly() == false; assert src.isDirect(); - assert src.isReadOnly() == false; + assert src.isReadOnly() == false;*/ long ret = zstdLib.compress(dst, src, level); if (zstdLib.isError(ret)) { throw new IllegalArgumentException(zstdLib.getErrorName(ret)); @@ -45,13 +45,13 @@ public int compress(ByteBuffer dst, ByteBuffer src, int level) { * Compress the content of {@code src} into {@code dst}, and return the number of decompressed bytes. {@link ByteBuffer#position()} and * {@link ByteBuffer#limit()} of both {@link ByteBuffer}s are left unmodified. */ - public int decompress(ByteBuffer dst, ByteBuffer src) { + public int decompress(CloseableByteBuffer dst, CloseableByteBuffer src) { Objects.requireNonNull(dst, "Null destination buffer"); Objects.requireNonNull(src, "Null source buffer"); - assert dst.isDirect(); + /*assert dst.isDirect(); assert dst.isReadOnly() == false; assert src.isDirect(); - assert src.isReadOnly() == false; + assert src.isReadOnly() == false;*/ long ret = zstdLib.decompress(dst, src); if (zstdLib.isError(ret)) { throw new IllegalArgumentException(zstdLib.getErrorName(ret)); diff --git a/libs/native/src/main/java/org/elasticsearch/nativeaccess/lib/ZstdLibrary.java b/libs/native/src/main/java/org/elasticsearch/nativeaccess/lib/ZstdLibrary.java index feb1dbe8e3d61..ea4c8efa5318a 100644 --- a/libs/native/src/main/java/org/elasticsearch/nativeaccess/lib/ZstdLibrary.java +++ b/libs/native/src/main/java/org/elasticsearch/nativeaccess/lib/ZstdLibrary.java @@ -8,17 +8,17 @@ package org.elasticsearch.nativeaccess.lib; -import java.nio.ByteBuffer; +import org.elasticsearch.nativeaccess.CloseableByteBuffer; public non-sealed interface ZstdLibrary extends NativeLibrary { long compressBound(int scrLen); - long compress(ByteBuffer dst, ByteBuffer src, int compressionLevel); + long compress(CloseableByteBuffer dst, CloseableByteBuffer src, int compressionLevel); boolean isError(long code); String getErrorName(long code); - long decompress(ByteBuffer dst, ByteBuffer src); + long decompress(CloseableByteBuffer dst, CloseableByteBuffer src); } diff --git a/libs/native/src/main21/java/org/elasticsearch/nativeaccess/jdk/JdkCloseableByteBuffer.java b/libs/native/src/main21/java/org/elasticsearch/nativeaccess/jdk/JdkCloseableByteBuffer.java index d802fd8be7a67..409839726e839 100644 --- a/libs/native/src/main21/java/org/elasticsearch/nativeaccess/jdk/JdkCloseableByteBuffer.java +++ b/libs/native/src/main21/java/org/elasticsearch/nativeaccess/jdk/JdkCloseableByteBuffer.java @@ -11,15 +11,18 @@ import org.elasticsearch.nativeaccess.CloseableByteBuffer; import java.lang.foreign.Arena; +import java.lang.foreign.MemorySegment; import java.nio.ByteBuffer; class JdkCloseableByteBuffer implements CloseableByteBuffer { private final Arena arena; + final MemorySegment segment; private final ByteBuffer bufferView; JdkCloseableByteBuffer(int len) { this.arena = Arena.ofShared(); - this.bufferView = this.arena.allocate(len).asByteBuffer(); + this.segment = arena.allocate(len); + this.bufferView = segment.asByteBuffer(); } @Override diff --git a/libs/native/src/main21/java/org/elasticsearch/nativeaccess/jdk/JdkZstdLibrary.java b/libs/native/src/main21/java/org/elasticsearch/nativeaccess/jdk/JdkZstdLibrary.java index 632240a844255..789e2c1ef39c0 100644 --- a/libs/native/src/main21/java/org/elasticsearch/nativeaccess/jdk/JdkZstdLibrary.java +++ b/libs/native/src/main21/java/org/elasticsearch/nativeaccess/jdk/JdkZstdLibrary.java @@ -8,12 +8,12 @@ package org.elasticsearch.nativeaccess.jdk; +import org.elasticsearch.nativeaccess.CloseableByteBuffer; import org.elasticsearch.nativeaccess.lib.ZstdLibrary; import java.lang.foreign.FunctionDescriptor; import java.lang.foreign.MemorySegment; import java.lang.invoke.MethodHandle; -import java.nio.ByteBuffer; import static java.lang.foreign.ValueLayout.ADDRESS; import static java.lang.foreign.ValueLayout.JAVA_BOOLEAN; @@ -49,11 +49,15 @@ public long compressBound(int srcLen) { } @Override - public long compress(ByteBuffer dst, ByteBuffer src, int compressionLevel) { - var nativeDst = MemorySegment.ofBuffer(dst); - var nativeSrc = MemorySegment.ofBuffer(src); + public long compress(CloseableByteBuffer dst, CloseableByteBuffer src, int compressionLevel) { + assert dst instanceof JdkCloseableByteBuffer; + assert src instanceof JdkCloseableByteBuffer; + var nativeDst = (JdkCloseableByteBuffer) dst; + var nativeSrc = (JdkCloseableByteBuffer) src; + var segmentDst = nativeDst.segment.asSlice(dst.buffer().position(), dst.buffer().remaining()); + var segmentSrc = nativeSrc.segment.asSlice(src.buffer().position(), src.buffer().remaining()); try { - return (long) compress$mh.invokeExact(nativeDst, dst.remaining(), nativeSrc, src.remaining(), compressionLevel); + return (long) compress$mh.invokeExact(segmentDst, segmentDst.byteSize(), segmentSrc, segmentSrc.byteSize(), compressionLevel); } catch (Throwable t) { throw new AssertionError(t); } @@ -79,11 +83,15 @@ public String getErrorName(long code) { } @Override - public long decompress(ByteBuffer dst, ByteBuffer src) { - var nativeDst = MemorySegment.ofBuffer(dst); - var nativeSrc = MemorySegment.ofBuffer(src); + public long decompress(CloseableByteBuffer dst, CloseableByteBuffer src) { + assert dst instanceof JdkCloseableByteBuffer; + assert src instanceof JdkCloseableByteBuffer; + var nativeDst = (JdkCloseableByteBuffer) dst; + var nativeSrc = (JdkCloseableByteBuffer) src; + var segmentDst = nativeDst.segment.asSlice(dst.buffer().position(), dst.buffer().remaining()); + var segmentSrc = nativeSrc.segment.asSlice(src.buffer().position(), src.buffer().remaining()); try { - return (long) decompress$mh.invokeExact(nativeDst, dst.remaining(), nativeSrc, src.remaining()); + return (long) decompress$mh.invokeExact(segmentDst, segmentDst.byteSize(), segmentSrc, segmentSrc.byteSize()); } catch (Throwable t) { throw new AssertionError(t); } diff --git a/libs/native/src/test/java/org/elasticsearch/nativeaccess/ZstdTests.java b/libs/native/src/test/java/org/elasticsearch/nativeaccess/ZstdTests.java index d051961b06c5f..1282b1fee9206 100644 --- a/libs/native/src/test/java/org/elasticsearch/nativeaccess/ZstdTests.java +++ b/libs/native/src/test/java/org/elasticsearch/nativeaccess/ZstdTests.java @@ -41,16 +41,16 @@ public void testCompressValidation() { var srcBuf = src.buffer(); var dstBuf = dst.buffer(); - var npe1 = expectThrows(NullPointerException.class, () -> zstd.compress(null, srcBuf, 0)); + var npe1 = expectThrows(NullPointerException.class, () -> zstd.compress(null, src, 0)); assertThat(npe1.getMessage(), equalTo("Null destination buffer")); - var npe2 = expectThrows(NullPointerException.class, () -> zstd.compress(dstBuf, null, 0)); + var npe2 = expectThrows(NullPointerException.class, () -> zstd.compress(dst, null, 0)); assertThat(npe2.getMessage(), equalTo("Null source buffer")); // dst capacity too low for (int i = 0; i < srcBuf.remaining(); ++i) { srcBuf.put(i, randomByte()); } - var e = expectThrows(IllegalArgumentException.class, () -> zstd.compress(dstBuf, srcBuf, 0)); + var e = expectThrows(IllegalArgumentException.class, () -> zstd.compress(dst, src, 0)); assertThat(e.getMessage(), equalTo("Destination buffer is too small")); } } @@ -64,21 +64,21 @@ public void testDecompressValidation() { var originalBuf = original.buffer(); var compressedBuf = compressed.buffer(); - var npe1 = expectThrows(NullPointerException.class, () -> zstd.decompress(null, originalBuf)); + var npe1 = expectThrows(NullPointerException.class, () -> zstd.decompress(null, original)); assertThat(npe1.getMessage(), equalTo("Null destination buffer")); - var npe2 = expectThrows(NullPointerException.class, () -> zstd.decompress(compressedBuf, null)); + var npe2 = expectThrows(NullPointerException.class, () -> zstd.decompress(compressed, null)); assertThat(npe2.getMessage(), equalTo("Null source buffer")); // Invalid compressed format for (int i = 0; i < originalBuf.remaining(); ++i) { originalBuf.put(i, (byte) i); } - var e = expectThrows(IllegalArgumentException.class, () -> zstd.decompress(compressedBuf, originalBuf)); + var e = expectThrows(IllegalArgumentException.class, () -> zstd.decompress(compressed, original)); assertThat(e.getMessage(), equalTo("Unknown frame descriptor")); - int compressedLength = zstd.compress(compressedBuf, originalBuf, 0); + int compressedLength = zstd.compress(compressed, original, 0); compressedBuf.limit(compressedLength); - e = expectThrows(IllegalArgumentException.class, () -> zstd.decompress(restored.buffer(), compressedBuf)); + e = expectThrows(IllegalArgumentException.class, () -> zstd.decompress(restored, compressed)); assertThat(e.getMessage(), equalTo("Destination buffer is too small")); } @@ -109,9 +109,9 @@ private void doTestRoundtrip(byte[] data) { var restored = nativeAccess.newBuffer(data.length) ) { original.buffer().put(0, data); - int compressedLength = zstd.compress(compressed.buffer(), original.buffer(), randomIntBetween(-3, 9)); + int compressedLength = zstd.compress(compressed, original, randomIntBetween(-3, 9)); compressed.buffer().limit(compressedLength); - int decompressedLength = zstd.decompress(restored.buffer(), compressed.buffer()); + int decompressedLength = zstd.decompress(restored, compressed); assertThat(restored.buffer(), equalTo(original.buffer())); assertThat(decompressedLength, equalTo(data.length)); } @@ -127,15 +127,15 @@ private void doTestRoundtrip(byte[] data) { original.buffer().put(decompressedOffset, data); original.buffer().position(decompressedOffset); compressed.buffer().position(compressedOffset); - int compressedLength = zstd.compress(compressed.buffer(), original.buffer(), randomIntBetween(-3, 9)); + int compressedLength = zstd.compress(compressed, original, randomIntBetween(-3, 9)); compressed.buffer().limit(compressedOffset + compressedLength); restored.buffer().position(decompressedOffset); - int decompressedLength = zstd.decompress(restored.buffer(), compressed.buffer()); + int decompressedLength = zstd.decompress(restored, compressed); + assertThat(decompressedLength, equalTo(data.length)); assertThat( restored.buffer().slice(decompressedOffset, data.length), equalTo(original.buffer().slice(decompressedOffset, data.length)) ); - assertThat(decompressedLength, equalTo(data.length)); } } } diff --git a/server/src/main/java/org/elasticsearch/index/codec/zstd/Zstd814StoredFieldsFormat.java b/server/src/main/java/org/elasticsearch/index/codec/zstd/Zstd814StoredFieldsFormat.java index 635bd7ffdff4d..e52397581ae08 100644 --- a/server/src/main/java/org/elasticsearch/index/codec/zstd/Zstd814StoredFieldsFormat.java +++ b/server/src/main/java/org/elasticsearch/index/codec/zstd/Zstd814StoredFieldsFormat.java @@ -132,7 +132,7 @@ public void decompress(DataInput in, int originalLength, int offset, int length, } src.buffer().flip(); - final int decompressedLen = zstd.decompress(dest.buffer(), src.buffer()); + final int decompressedLen = zstd.decompress(dest, src); if (decompressedLen != originalLength) { throw new CorruptIndexException("Expected " + originalLength + " decompressed bytes, got " + decompressedLen, in); } @@ -183,7 +183,7 @@ public void compress(ByteBuffersDataInput buffersInput, DataOutput out) throws I } src.buffer().flip(); - final int compressedLen = zstd.compress(dest.buffer(), src.buffer(), level); + final int compressedLen = zstd.compress(dest, src, level); out.writeVInt(compressedLen); for (int written = 0; written < compressedLen;) {