diff --git a/spring-amqp/src/main/java/org/springframework/amqp/support/converter/SerializerMessageConverter.java b/spring-amqp/src/main/java/org/springframework/amqp/support/converter/SerializerMessageConverter.java index cfe8e923d5..1a84d9801f 100644 --- a/spring-amqp/src/main/java/org/springframework/amqp/support/converter/SerializerMessageConverter.java +++ b/spring-amqp/src/main/java/org/springframework/amqp/support/converter/SerializerMessageConverter.java @@ -19,8 +19,8 @@ import java.io.ByteArrayInputStream; import java.io.ByteArrayOutputStream; import java.io.IOException; +import java.io.ObjectInputFilter; import java.io.ObjectInputStream; -import java.io.ObjectStreamClass; import java.io.UnsupportedEncodingException; import java.nio.charset.StandardCharsets; @@ -160,18 +160,16 @@ private Object asString(Message message, MessageProperties properties) { } private Object deserialize(ByteArrayInputStream inputStream) throws IOException { - try (ObjectInputStream objectInputStream = new ConfigurableObjectInputStream(inputStream, - this.defaultDeserializerClassLoader) { - - @Override - protected Class resolveClass(ObjectStreamClass classDesc) - throws IOException, ClassNotFoundException { - Class clazz = super.resolveClass(classDesc); - checkAllowedList(clazz); - return clazz; - } - - }) { + ObjectInputStream objectInputStream = + new ConfigurableObjectInputStream(inputStream, this.defaultDeserializerClassLoader); + objectInputStream.setObjectInputFilter( + ObjectInputFilter.allowFilter(aClass -> { + checkAllowedList(aClass); + return true; + }, + ObjectInputFilter.Status.REJECTED)); + + try (objectInputStream) { return objectInputStream.readObject(); } catch (ClassNotFoundException ex) { diff --git a/spring-amqp/src/main/java/org/springframework/amqp/support/converter/SimpleMessageConverter.java b/spring-amqp/src/main/java/org/springframework/amqp/support/converter/SimpleMessageConverter.java index 56897f80fb..3fa32cabab 100644 --- a/spring-amqp/src/main/java/org/springframework/amqp/support/converter/SimpleMessageConverter.java +++ b/spring-amqp/src/main/java/org/springframework/amqp/support/converter/SimpleMessageConverter.java @@ -19,8 +19,8 @@ import java.io.ByteArrayInputStream; import java.io.IOException; import java.io.InputStream; +import java.io.ObjectInputFilter; import java.io.ObjectInputStream; -import java.io.ObjectStreamClass; import java.io.Serializable; import java.io.UnsupportedEncodingException; @@ -152,16 +152,14 @@ else if (object instanceof Serializable) { * @throws IOException if creation of the ObjectInputStream failed */ protected ObjectInputStream createObjectInputStream(InputStream is) throws IOException { - return new ConfigurableObjectInputStream(is, this.classLoader) { - - @Override - protected Class resolveClass(ObjectStreamClass classDesc) throws IOException, ClassNotFoundException { - Class clazz = super.resolveClass(classDesc); - checkAllowedList(clazz); - return clazz; - } - - }; + ObjectInputStream objectInputStream = new ConfigurableObjectInputStream(is, this.classLoader); + objectInputStream.setObjectInputFilter( + ObjectInputFilter.allowFilter(aClass -> { + checkAllowedList(aClass); + return true; + }, + ObjectInputFilter.Status.REJECTED)); + return objectInputStream; } } diff --git a/spring-amqp/src/main/java/org/springframework/amqp/utils/SerializationUtils.java b/spring-amqp/src/main/java/org/springframework/amqp/utils/SerializationUtils.java index 6e9cb8de8c..d8c7a69501 100644 --- a/spring-amqp/src/main/java/org/springframework/amqp/utils/SerializationUtils.java +++ b/spring-amqp/src/main/java/org/springframework/amqp/utils/SerializationUtils.java @@ -1,5 +1,5 @@ /* - * Copyright 2006-2023 the original author or authors. + * Copyright 2006-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -145,6 +145,7 @@ protected Class resolveClass(ObjectStreamClass classDesc) * Verify that the class is in the allowed list. * @param clazz the class. * @param patterns the patterns. + * @throws SecurityException if class to deserialized is not allowed * @since 2.1 */ public static void checkAllowedList(Class clazz, Set patterns) { diff --git a/spring-amqp/src/test/java/org/springframework/amqp/support/converter/AllowedListDeserializingMessageConverterTests.java b/spring-amqp/src/test/java/org/springframework/amqp/support/converter/AllowedListDeserializingMessageConverterTests.java index e56a568a16..b7412f5207 100644 --- a/spring-amqp/src/test/java/org/springframework/amqp/support/converter/AllowedListDeserializingMessageConverterTests.java +++ b/spring-amqp/src/test/java/org/springframework/amqp/support/converter/AllowedListDeserializingMessageConverterTests.java @@ -1,5 +1,5 @@ /* - * Copyright 2016-2023 the original author or authors. + * Copyright 2016-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -30,22 +30,29 @@ /** * @author Gary Russell + * @author Artem Bilan * @since 1.5.5 * */ public class AllowedListDeserializingMessageConverterTests { @Test - public void testAllowedList() throws Exception { + public void testAllowedList() { SerializerMessageConverter converter = new SerializerMessageConverter(); TestBean testBean = new TestBean("foo"); Message message = converter.toMessage(testBean, new MessageProperties()); - // when env var not set -// assertThatExceptionOfType(SecurityException.class).isThrownBy(() -> converter.fromMessage(message)); Object fromMessage; - // when env var set. - fromMessage = converter.fromMessage(message); - assertThat(fromMessage).isEqualTo(testBean); + // See build.gradle `tasks.withType(Test).all` + if ("true".equals(System.getenv("SPRING_AMQP_DESERIALIZATION_TRUST_ALL"))) { + fromMessage = converter.fromMessage(message); + assertThat(fromMessage).isEqualTo(testBean); + } + else { + assertThatExceptionOfType(MessageConversionException.class) + .isThrownBy(() -> converter.fromMessage(message)) + .withRootCauseInstanceOf(SecurityException.class) + .withStackTraceContaining("Attempt to deserialize unauthorized"); + } converter.setAllowedListPatterns(Collections.singletonList("*")); fromMessage = converter.fromMessage(message); @@ -59,7 +66,10 @@ public void testAllowedList() throws Exception { assertThat(fromMessage).isEqualTo(testBean); converter.setAllowedListPatterns(Collections.singletonList("foo.*")); - assertThatExceptionOfType(SecurityException.class).isThrownBy(() -> converter.fromMessage(message)); + assertThatExceptionOfType(MessageConversionException.class) + .isThrownBy(() -> converter.fromMessage(message)) + .withRootCauseInstanceOf(SecurityException.class) + .withStackTraceContaining("Attempt to deserialize unauthorized"); } @SuppressWarnings("serial") diff --git a/spring-amqp/src/test/java/org/springframework/amqp/support/converter/SerializerMessageConverterTests.java b/spring-amqp/src/test/java/org/springframework/amqp/support/converter/SerializerMessageConverterTests.java index 026fe7a308..39b728e2c3 100644 --- a/spring-amqp/src/test/java/org/springframework/amqp/support/converter/SerializerMessageConverterTests.java +++ b/spring-amqp/src/test/java/org/springframework/amqp/support/converter/SerializerMessageConverterTests.java @@ -25,6 +25,7 @@ import java.io.ObjectInputStream; import java.io.ObjectOutputStream; import java.nio.charset.StandardCharsets; +import java.util.List; import org.junit.jupiter.api.Test; @@ -74,6 +75,7 @@ public void messageToBytes() { @Test public void messageToSerializedObject() throws Exception { SerializerMessageConverter converter = new SerializerMessageConverter(); + converter.setAllowedListPatterns(List.of("*")); MessageProperties properties = new MessageProperties(); properties.setContentType(MessageProperties.CONTENT_TYPE_SERIALIZED_OBJECT); ByteArrayOutputStream byteStream = new ByteArrayOutputStream(); @@ -92,6 +94,7 @@ public void messageToSerializedObject() throws Exception { @Test public void messageToSerializedObjectNoContentType() throws Exception { SerializerMessageConverter converter = new SerializerMessageConverter(); + converter.setAllowedListPatterns(List.of(TestBean.class.getName())); converter.setIgnoreContentType(true); MessageProperties properties = new MessageProperties(); ByteArrayOutputStream byteStream = new ByteArrayOutputStream(); diff --git a/spring-amqp/src/test/java/org/springframework/amqp/support/converter/SimpleMessageConverterTests.java b/spring-amqp/src/test/java/org/springframework/amqp/support/converter/SimpleMessageConverterTests.java index ec6e72d9fb..14e58f042d 100644 --- a/spring-amqp/src/test/java/org/springframework/amqp/support/converter/SimpleMessageConverterTests.java +++ b/spring-amqp/src/test/java/org/springframework/amqp/support/converter/SimpleMessageConverterTests.java @@ -1,5 +1,5 @@ /* - * Copyright 2002-2019 the original author or authors. + * Copyright 2002-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -24,6 +24,8 @@ import java.io.ByteArrayOutputStream; import java.io.ObjectInputStream; import java.io.ObjectOutputStream; +import java.nio.charset.StandardCharsets; +import java.util.List; import org.junit.jupiter.api.Test; @@ -33,27 +35,28 @@ /** * @author Mark Fisher * @author Gary Russell + * @author Artem Bilan */ public class SimpleMessageConverterTests extends AllowedListDeserializingMessageConverterTests { @Test - public void bytesAsDefaultMessageBodyType() throws Exception { + public void bytesAsDefaultMessageBodyType() { SimpleMessageConverter converter = new SimpleMessageConverter(); Message message = new Message("test".getBytes(), new MessageProperties()); Object result = converter.fromMessage(message); assertThat(result.getClass()).isEqualTo(byte[].class); - assertThat(new String((byte[]) result, "UTF-8")).isEqualTo("test"); + assertThat(new String((byte[]) result, StandardCharsets.UTF_8)).isEqualTo("test"); } @Test - public void noMessageIdByDefault() throws Exception { + public void noMessageIdByDefault() { SimpleMessageConverter converter = new SimpleMessageConverter(); Message message = converter.toMessage("foo", null); assertThat(message.getMessageProperties().getMessageId()).isNull(); } @Test - public void optionalMessageId() throws Exception { + public void optionalMessageId() { SimpleMessageConverter converter = new SimpleMessageConverter(); converter.setCreateMessageIds(true); Message message = converter.toMessage("foo", null); @@ -87,6 +90,7 @@ public void messageToBytes() { @Test public void messageToSerializedObject() throws Exception { SimpleMessageConverter converter = new SimpleMessageConverter(); + converter.setAllowedListPatterns(List.of("*")); MessageProperties properties = new MessageProperties(); properties.setContentType(MessageProperties.CONTENT_TYPE_SERIALIZED_OBJECT); ByteArrayOutputStream byteStream = new ByteArrayOutputStream(); @@ -114,7 +118,7 @@ public void stringToMessage() throws Exception { } @Test - public void bytesToMessage() throws Exception { + public void bytesToMessage() { SimpleMessageConverter converter = new SimpleMessageConverter(); Message message = converter.toMessage(new byte[] { 1, 2, 3 }, new MessageProperties()); String contentType = message.getMessageProperties().getContentType(); @@ -140,7 +144,7 @@ public void serializedObjectToMessage() throws Exception { } @Test - public void messageConversionExceptionForClassNotFound() throws Exception { + public void messageConversionExceptionForClassNotFound() { SimpleMessageConverter converter = new SimpleMessageConverter(); TestBean testBean = new TestBean("foo"); Message message = converter.toMessage(testBean, new MessageProperties()); @@ -163,7 +167,8 @@ class Foo { fail("Expected exception"); } catch (IllegalArgumentException e) { - assertThat(e.getMessage()).contains("SimpleMessageConverter only supports String, byte[] and Serializable payloads, received:"); + assertThat(e.getMessage()) + .contains("SimpleMessageConverter only supports String, byte[] and Serializable payloads, received:"); } }