Skip to content

Commit

Permalink
Use predictable serialization logic for transport headers (#4264)
Browse files Browse the repository at this point in the history
Signed-off-by: Peter Nied <[email protected]>
  • Loading branch information
peternied authored Apr 26, 2024
1 parent 6b9762b commit b3bfbd4
Show file tree
Hide file tree
Showing 7 changed files with 115 additions and 30 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@
import org.opensearch.security.ssl.util.ExceptionUtils;
import org.opensearch.security.ssl.util.SSLRequestHelper;
import org.opensearch.security.support.ConfigConstants;
import org.opensearch.security.support.SerializationFormat;
import org.opensearch.tasks.Task;
import org.opensearch.threadpool.ThreadPool;
import org.opensearch.transport.TransportChannel;
Expand Down Expand Up @@ -92,7 +93,7 @@ public final void messageReceived(T request, TransportChannel channel, Task task

threadContext.putTransient(
ConfigConstants.USE_JDK_SERIALIZATION,
channel.getVersion().before(ConfigConstants.FIRST_CUSTOM_SERIALIZATION_SUPPORTED_OS_VERSION)
SerializationFormat.determineFormat(channel.getVersion()) == SerializationFormat.JDK
);

if (SSLRequestHelper.containsBadHeader(threadContext, "_opendistro_security_ssl_")) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,6 @@
import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableSet;

import org.opensearch.Version;
import org.opensearch.common.settings.Settings;
import org.opensearch.security.auditlog.impl.AuditCategory;

Expand Down Expand Up @@ -334,7 +333,6 @@ public enum RolesMappingResolution {
public static final String TENANCY_GLOBAL_TENANT_DEFAULT_NAME = "";

public static final String USE_JDK_SERIALIZATION = "plugins.security.use_jdk_serialization";
public static final Version FIRST_CUSTOM_SERIALIZATION_SUPPORTED_OS_VERSION = Version.V_2_11_0;

// On-behalf-of endpoints settings
// CS-SUPPRESS-SINGLE: RegexpSingleline get Extensions Settings
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
/*
* SPDX-License-Identifier: Apache-2.0
*
* The OpenSearch Contributors require contributions made to
* this file be licensed under the Apache-2.0 license or a
* compatible open source license.
*
* Modifications Copyright OpenSearch Contributors. See
* GitHub history for details.
*/

package org.opensearch.security.support;

import org.opensearch.Version;

public enum SerializationFormat {
/** Uses Java's native serialization system */
JDK,
/** Uses a custom serializer built ontop of OpenSearch 2.11 */
CustomSerializer_2_11;

private static final Version FIRST_CUSTOM_SERIALIZATION_SUPPORTED_OS_VERSION = Version.V_2_11_0;
private static final Version CUSTOM_SERIALIZATION_NO_LONGER_SUPPORTED_OS_VERSION = Version.V_2_14_0;

/**
* Determines the format of serialization that should be used from a version identifier
*/
public static SerializationFormat determineFormat(final Version version) {
if (version.onOrAfter(FIRST_CUSTOM_SERIALIZATION_SUPPORTED_OS_VERSION)
&& version.before(CUSTOM_SERIALIZATION_NO_LONGER_SUPPORTED_OS_VERSION)) {
return SerializationFormat.CustomSerializer_2_11;
}
return SerializationFormat.JDK;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,7 @@
import org.opensearch.security.support.Base64Helper;
import org.opensearch.security.support.ConfigConstants;
import org.opensearch.security.support.HeaderHelper;
import org.opensearch.security.support.SerializationFormat;
import org.opensearch.security.user.User;
import org.opensearch.threadpool.ThreadPool;
import org.opensearch.transport.Transport.Connection;
Expand Down Expand Up @@ -150,7 +151,8 @@ public <T extends TransportResponse> void sendRequestDecorate(
final String origCCSTransientMf = getThreadContext().getTransient(ConfigConstants.OPENDISTRO_SECURITY_MASKED_FIELD_CCS);

final boolean isDebugEnabled = log.isDebugEnabled();
final boolean useJDKSerialization = connection.getVersion().before(ConfigConstants.FIRST_CUSTOM_SERIALIZATION_SUPPORTED_OS_VERSION);

final var serializationFormat = SerializationFormat.determineFormat(connection.getVersion());
final boolean isSameNodeRequest = localNode != null && localNode.equals(connection.getNode());

try (ThreadContext.StoredContext stashedContext = getThreadContext().stashContext()) {
Expand Down Expand Up @@ -228,25 +230,28 @@ && getThreadContext().getHeader(ConfigConstants.OPENDISTRO_SECURITY_INJECTED_ROL
);
}

if (useJDKSerialization) {
Map<String, String> jdkSerializedHeaders = new HashMap<>();
HeaderHelper.getAllSerializedHeaderNames()
.stream()
.filter(k -> headerMap.get(k) != null)
.forEach(k -> jdkSerializedHeaders.put(k, Base64Helper.ensureJDKSerialized(headerMap.get(k))));
headerMap.putAll(jdkSerializedHeaders);
try {
if (serializationFormat == SerializationFormat.JDK) {
Map<String, String> jdkSerializedHeaders = new HashMap<>();
HeaderHelper.getAllSerializedHeaderNames()
.stream()
.filter(k -> headerMap.get(k) != null)
.forEach(k -> jdkSerializedHeaders.put(k, Base64Helper.ensureJDKSerialized(headerMap.get(k))));
headerMap.putAll(jdkSerializedHeaders);
}
getThreadContext().putHeader(headerMap);
} catch (IllegalArgumentException iae) {
log.debug("Failed to add headers information onto on thread context", iae);
}

getThreadContext().putHeader(headerMap);

ensureCorrectHeaders(
remoteAddress0,
user0,
origin0,
injectedUserString,
injectedRolesString,
isSameNodeRequest,
useJDKSerialization
serializationFormat
);

if (actionTraceEnabled.get()) {
Expand Down Expand Up @@ -275,7 +280,7 @@ private void ensureCorrectHeaders(
final String injectedUserString,
final String injectedRolesString,
final boolean isSameNodeRequest,
final boolean useJDKSerialization
final SerializationFormat format
) {
// keep original address

Expand Down Expand Up @@ -313,6 +318,7 @@ && getThreadContext().getHeader(ConfigConstants.OPENDISTRO_SECURITY_ORIGIN_HEADE
getThreadContext().putTransient(ConfigConstants.OPENDISTRO_SECURITY_INJECTED_USER, injectedUserString);
}
} else {
final var useJDKSerialization = format == SerializationFormat.JDK;
if (transportAddress != null) {
getThreadContext().putHeader(
ConfigConstants.OPENDISTRO_SECURITY_REMOTE_ADDRESS_HEADER,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,12 +11,17 @@
package org.opensearch.security.support;

import java.io.Serializable;
import java.util.HashMap;
import java.util.stream.IntStream;

import org.junit.Assert;
import org.junit.Test;

import static org.hamcrest.Matchers.closeTo;
import static org.hamcrest.Matchers.equalTo;
import static org.opensearch.security.support.Base64Helper.deserializeObject;
import static org.opensearch.security.support.Base64Helper.serializeObject;
import static org.junit.Assert.assertThat;

public class Base64HelperTest {

Expand Down Expand Up @@ -48,4 +53,22 @@ public void testEnsureJDKSerialized() {
Assert.assertEquals(jdkSerialized, Base64Helper.ensureJDKSerialized(jdkSerialized));
Assert.assertEquals(jdkSerialized, Base64Helper.ensureJDKSerialized(customSerialized));
}

@Test
public void testDuplicatedItemSizes() {
var largeObject = new HashMap<String, Object>();
var hm = new HashMap<>();
IntStream.range(0, 100).forEach(i -> { hm.put("c" + i, "cvalue" + i); });
IntStream.range(0, 100).forEach(i -> { largeObject.put("b" + i, hm); });

final var jdkSerialized = Base64Helper.serializeObject(largeObject, true);
final var customSerialized = Base64Helper.serializeObject(largeObject, false);
final var customSerializedOnlyHashMap = Base64Helper.serializeObject(hm, false);

assertThat(jdkSerialized.length(), equalTo(3832));
// The custom serializer is ~50x larger than the jdk serialized version
assertThat(customSerialized.length(), equalTo(184792));
// Show that the majority of the size of the custom serialized large object is the map duplicated ~100 times
assertThat((double) customSerializedOnlyHashMap.length(), closeTo(customSerialized.length() / 100, 70d));
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,9 @@
import java.net.InetAddress;
import java.net.InetSocketAddress;
import java.net.UnknownHostException;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicReference;

import org.junit.Before;
import org.junit.Test;
Expand Down Expand Up @@ -120,7 +121,7 @@ public class SecurityInterceptorTests {

private AsyncSender sender;
private AsyncSender serializedSender;
private AsyncSender nullSender;
private AtomicReference<CountDownLatch> senderLatch = new AtomicReference<>(new CountDownLatch(1));

@Before
public void setup() {
Expand Down Expand Up @@ -208,6 +209,7 @@ public <T extends TransportResponse> void sendRequest(
) {
String serializedUserHeader = threadPool.getThreadContext().getHeader(ConfigConstants.OPENDISTRO_SECURITY_USER_HEADER);
assertEquals(serializedUserHeader, Base64Helper.serializeObject(user, true));
senderLatch.get().countDown();
}
};

Expand All @@ -222,6 +224,7 @@ public <T extends TransportResponse> void sendRequest(
) {
User transientUser = threadPool.getThreadContext().getTransient(ConfigConstants.OPENDISTRO_SECURITY_USER);
assertEquals(transientUser, user);
senderLatch.get().countDown();
}
};

Expand Down Expand Up @@ -249,17 +252,16 @@ final void completableRequestDecorate(
TransportResponseHandler handler,
DiscoveryNode localNode
) {
securityInterceptor.sendRequestDecorate(sender, connection, action, request, options, handler, localNode);
verifyOriginalContext(user);
try {
senderLatch.get().await(1, TimeUnit.SECONDS);
} catch (final InterruptedException e) {
throw new RuntimeException(e);
}

ExecutorService singleThreadExecutor = Executors.newSingleThreadExecutor();

singleThreadExecutor.execute(() -> {
try {
securityInterceptor.sendRequestDecorate(sender, connection, action, request, options, handler, localNode);
verifyOriginalContext(user);
} finally {
singleThreadExecutor.shutdown();
}
});
// Reset the latch so another request can be processed
senderLatch.set(new CountDownLatch(1));
}

@Test
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -94,9 +94,19 @@ public void testUseJDKSerializationHeaderIsSetOnMessageReceived() throws Excepti
Assert.assertFalse(threadPool.getThreadContext().getTransient(ConfigConstants.USE_JDK_SERIALIZATION));

threadPool.getThreadContext().stashContext();
when(transportChannel.getVersion()).thenReturn(Version.V_3_0_0);
when(transportChannel.getVersion()).thenReturn(Version.V_2_13_0);
Assert.assertThrows(Exception.class, () -> securitySSLRequestHandler.messageReceived(transportRequest, transportChannel, task));
Assert.assertFalse(threadPool.getThreadContext().getTransient(ConfigConstants.USE_JDK_SERIALIZATION));

threadPool.getThreadContext().stashContext();
when(transportChannel.getVersion()).thenReturn(Version.V_2_14_0);
Assert.assertThrows(Exception.class, () -> securitySSLRequestHandler.messageReceived(transportRequest, transportChannel, task));
Assert.assertTrue(threadPool.getThreadContext().getTransient(ConfigConstants.USE_JDK_SERIALIZATION));

threadPool.getThreadContext().stashContext();
when(transportChannel.getVersion()).thenReturn(Version.V_3_0_0);
Assert.assertThrows(Exception.class, () -> securitySSLRequestHandler.messageReceived(transportRequest, transportChannel, task));
Assert.assertTrue(threadPool.getThreadContext().getTransient(ConfigConstants.USE_JDK_SERIALIZATION));
}

@Test
Expand All @@ -118,9 +128,19 @@ public void testUseJDKSerializationHeaderIsSetWithWrapperChannel() throws Except
Assert.assertFalse(threadPool.getThreadContext().getTransient(ConfigConstants.USE_JDK_SERIALIZATION));

threadPool.getThreadContext().stashContext();
when(transportChannel.getVersion()).thenReturn(Version.V_3_0_0);
when(transportChannel.getVersion()).thenReturn(Version.V_2_13_0);
Assert.assertThrows(Exception.class, () -> securitySSLRequestHandler.messageReceived(transportRequest, wrappedChannel, task));
Assert.assertFalse(threadPool.getThreadContext().getTransient(ConfigConstants.USE_JDK_SERIALIZATION));

threadPool.getThreadContext().stashContext();
when(transportChannel.getVersion()).thenReturn(Version.V_2_14_0);
Assert.assertThrows(Exception.class, () -> securitySSLRequestHandler.messageReceived(transportRequest, wrappedChannel, task));
Assert.assertTrue(threadPool.getThreadContext().getTransient(ConfigConstants.USE_JDK_SERIALIZATION));

threadPool.getThreadContext().stashContext();
when(transportChannel.getVersion()).thenReturn(Version.V_3_0_0);
Assert.assertThrows(Exception.class, () -> securitySSLRequestHandler.messageReceived(transportRequest, wrappedChannel, task));
Assert.assertTrue(threadPool.getThreadContext().getTransient(ConfigConstants.USE_JDK_SERIALIZATION));
}

@Test
Expand Down

0 comments on commit b3bfbd4

Please sign in to comment.