Skip to content

Commit

Permalink
spring-projectsGH-2687: Use JDK ObjectInputFilter for serialization…
Browse files Browse the repository at this point in the history
… security

Fixes: spring-projects#2687
  • Loading branch information
artembilan committed Jul 12, 2024
1 parent 682aefe commit 16c9e8f
Show file tree
Hide file tree
Showing 6 changed files with 56 additions and 41 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -19,8 +19,8 @@
import java.io.ByteArrayInputStream;
import java.io.ByteArrayOutputStream;
import java.io.IOException;
import java.io.ObjectInputFilter;
import java.io.ObjectInputStream;
import java.io.ObjectStreamClass;
import java.io.UnsupportedEncodingException;
import java.nio.charset.StandardCharsets;

Expand Down Expand Up @@ -160,18 +160,16 @@ private Object asString(Message message, MessageProperties properties) {
}

private Object deserialize(ByteArrayInputStream inputStream) throws IOException {
try (ObjectInputStream objectInputStream = new ConfigurableObjectInputStream(inputStream,
this.defaultDeserializerClassLoader) {

@Override
protected Class<?> resolveClass(ObjectStreamClass classDesc)
throws IOException, ClassNotFoundException {
Class<?> clazz = super.resolveClass(classDesc);
checkAllowedList(clazz);
return clazz;
}

}) {
ObjectInputStream objectInputStream =
new ConfigurableObjectInputStream(inputStream, this.defaultDeserializerClassLoader);
objectInputStream.setObjectInputFilter(
ObjectInputFilter.allowFilter(aClass -> {
checkAllowedList(aClass);
return true;
},
ObjectInputFilter.Status.REJECTED));

try (objectInputStream) {
return objectInputStream.readObject();
}
catch (ClassNotFoundException ex) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,8 +19,8 @@
import java.io.ByteArrayInputStream;
import java.io.IOException;
import java.io.InputStream;
import java.io.ObjectInputFilter;
import java.io.ObjectInputStream;
import java.io.ObjectStreamClass;
import java.io.Serializable;
import java.io.UnsupportedEncodingException;

Expand Down Expand Up @@ -152,16 +152,14 @@ else if (object instanceof Serializable) {
* @throws IOException if creation of the ObjectInputStream failed
*/
protected ObjectInputStream createObjectInputStream(InputStream is) throws IOException {
return new ConfigurableObjectInputStream(is, this.classLoader) {

@Override
protected Class<?> resolveClass(ObjectStreamClass classDesc) throws IOException, ClassNotFoundException {
Class<?> clazz = super.resolveClass(classDesc);
checkAllowedList(clazz);
return clazz;
}

};
ObjectInputStream objectInputStream = new ConfigurableObjectInputStream(is, this.classLoader);
objectInputStream.setObjectInputFilter(
ObjectInputFilter.allowFilter(aClass -> {
checkAllowedList(aClass);
return true;
},
ObjectInputFilter.Status.REJECTED));
return objectInputStream;
}

}
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* Copyright 2006-2023 the original author or authors.
* Copyright 2006-2024 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand Down Expand Up @@ -145,6 +145,7 @@ protected Class<?> resolveClass(ObjectStreamClass classDesc)
* Verify that the class is in the allowed list.
* @param clazz the class.
* @param patterns the patterns.
* @throws SecurityException if class to deserialized is not allowed
* @since 2.1
*/
public static void checkAllowedList(Class<?> clazz, Set<String> patterns) {
Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* Copyright 2016-2023 the original author or authors.
* Copyright 2016-2024 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand Down Expand Up @@ -30,22 +30,29 @@

/**
* @author Gary Russell
* @author Artem Bilan
* @since 1.5.5
*
*/
public class AllowedListDeserializingMessageConverterTests {

@Test
public void testAllowedList() throws Exception {
public void testAllowedList() {
SerializerMessageConverter converter = new SerializerMessageConverter();
TestBean testBean = new TestBean("foo");
Message message = converter.toMessage(testBean, new MessageProperties());
// when env var not set
// assertThatExceptionOfType(SecurityException.class).isThrownBy(() -> converter.fromMessage(message));
Object fromMessage;
// when env var set.
fromMessage = converter.fromMessage(message);
assertThat(fromMessage).isEqualTo(testBean);
// See build.gradle `tasks.withType(Test).all`
if ("true".equals(System.getenv("SPRING_AMQP_DESERIALIZATION_TRUST_ALL"))) {
fromMessage = converter.fromMessage(message);
assertThat(fromMessage).isEqualTo(testBean);
}
else {
assertThatExceptionOfType(MessageConversionException.class)
.isThrownBy(() -> converter.fromMessage(message))
.withRootCauseInstanceOf(SecurityException.class)
.withStackTraceContaining("Attempt to deserialize unauthorized");
}

converter.setAllowedListPatterns(Collections.singletonList("*"));
fromMessage = converter.fromMessage(message);
Expand All @@ -59,7 +66,10 @@ public void testAllowedList() throws Exception {
assertThat(fromMessage).isEqualTo(testBean);

converter.setAllowedListPatterns(Collections.singletonList("foo.*"));
assertThatExceptionOfType(SecurityException.class).isThrownBy(() -> converter.fromMessage(message));
assertThatExceptionOfType(MessageConversionException.class)
.isThrownBy(() -> converter.fromMessage(message))
.withRootCauseInstanceOf(SecurityException.class)
.withStackTraceContaining("Attempt to deserialize unauthorized");
}

@SuppressWarnings("serial")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
import java.io.ObjectInputStream;
import java.io.ObjectOutputStream;
import java.nio.charset.StandardCharsets;
import java.util.List;

import org.junit.jupiter.api.Test;

Expand Down Expand Up @@ -74,6 +75,7 @@ public void messageToBytes() {
@Test
public void messageToSerializedObject() throws Exception {
SerializerMessageConverter converter = new SerializerMessageConverter();
converter.setAllowedListPatterns(List.of("*"));
MessageProperties properties = new MessageProperties();
properties.setContentType(MessageProperties.CONTENT_TYPE_SERIALIZED_OBJECT);
ByteArrayOutputStream byteStream = new ByteArrayOutputStream();
Expand All @@ -92,6 +94,7 @@ public void messageToSerializedObject() throws Exception {
@Test
public void messageToSerializedObjectNoContentType() throws Exception {
SerializerMessageConverter converter = new SerializerMessageConverter();
converter.setAllowedListPatterns(List.of(TestBean.class.getName()));
converter.setIgnoreContentType(true);
MessageProperties properties = new MessageProperties();
ByteArrayOutputStream byteStream = new ByteArrayOutputStream();
Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* Copyright 2002-2019 the original author or authors.
* Copyright 2002-2024 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand All @@ -24,6 +24,8 @@
import java.io.ByteArrayOutputStream;
import java.io.ObjectInputStream;
import java.io.ObjectOutputStream;
import java.nio.charset.StandardCharsets;
import java.util.List;

import org.junit.jupiter.api.Test;

Expand All @@ -33,27 +35,28 @@
/**
* @author Mark Fisher
* @author Gary Russell
* @author Artem Bilan
*/
public class SimpleMessageConverterTests extends AllowedListDeserializingMessageConverterTests {

@Test
public void bytesAsDefaultMessageBodyType() throws Exception {
public void bytesAsDefaultMessageBodyType() {
SimpleMessageConverter converter = new SimpleMessageConverter();
Message message = new Message("test".getBytes(), new MessageProperties());
Object result = converter.fromMessage(message);
assertThat(result.getClass()).isEqualTo(byte[].class);
assertThat(new String((byte[]) result, "UTF-8")).isEqualTo("test");
assertThat(new String((byte[]) result, StandardCharsets.UTF_8)).isEqualTo("test");
}

@Test
public void noMessageIdByDefault() throws Exception {
public void noMessageIdByDefault() {
SimpleMessageConverter converter = new SimpleMessageConverter();
Message message = converter.toMessage("foo", null);
assertThat(message.getMessageProperties().getMessageId()).isNull();
}

@Test
public void optionalMessageId() throws Exception {
public void optionalMessageId() {
SimpleMessageConverter converter = new SimpleMessageConverter();
converter.setCreateMessageIds(true);
Message message = converter.toMessage("foo", null);
Expand Down Expand Up @@ -87,6 +90,7 @@ public void messageToBytes() {
@Test
public void messageToSerializedObject() throws Exception {
SimpleMessageConverter converter = new SimpleMessageConverter();
converter.setAllowedListPatterns(List.of("*"));
MessageProperties properties = new MessageProperties();
properties.setContentType(MessageProperties.CONTENT_TYPE_SERIALIZED_OBJECT);
ByteArrayOutputStream byteStream = new ByteArrayOutputStream();
Expand Down Expand Up @@ -114,7 +118,7 @@ public void stringToMessage() throws Exception {
}

@Test
public void bytesToMessage() throws Exception {
public void bytesToMessage() {
SimpleMessageConverter converter = new SimpleMessageConverter();
Message message = converter.toMessage(new byte[] { 1, 2, 3 }, new MessageProperties());
String contentType = message.getMessageProperties().getContentType();
Expand All @@ -140,7 +144,7 @@ public void serializedObjectToMessage() throws Exception {
}

@Test
public void messageConversionExceptionForClassNotFound() throws Exception {
public void messageConversionExceptionForClassNotFound() {
SimpleMessageConverter converter = new SimpleMessageConverter();
TestBean testBean = new TestBean("foo");
Message message = converter.toMessage(testBean, new MessageProperties());
Expand All @@ -163,7 +167,8 @@ class Foo {
fail("Expected exception");
}
catch (IllegalArgumentException e) {
assertThat(e.getMessage()).contains("SimpleMessageConverter only supports String, byte[] and Serializable payloads, received:");
assertThat(e.getMessage())
.contains("SimpleMessageConverter only supports String, byte[] and Serializable payloads, received:");
}
}

Expand Down

0 comments on commit 16c9e8f

Please sign in to comment.