diff --git a/src/main/config/emissary.core.SafeUsageChecker.cfg b/src/main/config/emissary.core.SafeUsageChecker.cfg new file mode 100644 index 0000000000..654e7805a2 --- /dev/null +++ b/src/main/config/emissary.core.SafeUsageChecker.cfg @@ -0,0 +1 @@ +ENABLED = TRUE \ No newline at end of file diff --git a/src/main/java/emissary/core/BaseDataObject.java b/src/main/java/emissary/core/BaseDataObject.java index f6c4141738..cef9434604 100755 --- a/src/main/java/emissary/core/BaseDataObject.java +++ b/src/main/java/emissary/core/BaseDataObject.java @@ -182,6 +182,8 @@ public class BaseDataObject implements Serializable, Cloneable, Remote, IBaseDat */ protected SeekableByteChannelFactory seekableByteChannelFactory; + final SafeUsageChecker safeUsageChecker = new SafeUsageChecker(); + protected enum DataState { NO_DATA, CHANNEL_ONLY, BYTE_ARRAY_ONLY, BYTE_ARRAY_AND_CHANNEL } @@ -221,6 +223,18 @@ protected DataState getDataState() { } } + @Override + /** + * {@inheritDoc} + */ + public void checkForUnsafeDataChanges() { + safeUsageChecker.checkForUnsafeDataChanges(); + + if (theData != null) { + safeUsageChecker.recordSnapshot(theData); + } + } + /** * Create an empty BaseDataObject. */ @@ -327,6 +341,10 @@ public void setChannelFactory(final SeekableByteChannelFactory sbcf) { Validate.notNull(sbcf, "Required: SeekableByteChannelFactory not null"); this.theData = null; this.seekableByteChannelFactory = sbcf; + + // calls to setData clear the unsafe state by definition + // reset the usage checker but don't capture a snapshot until someone requests the data in byte[] form + safeUsageChecker.reset(); } /** @@ -378,7 +396,12 @@ public byte[] data() { return theData; case CHANNEL_ONLY: // Max size here is slightly less than the true max size to avoid memory issues - return SeekableByteChannelHelper.getByteArrayFromBdo(this, MAX_BYTE_ARRAY_SIZE); + final byte[] bytes = SeekableByteChannelHelper.getByteArrayFromBdo(this, MAX_BYTE_ARRAY_SIZE); + + // capture a reference to the returned byte[] so we can test for unsafe modifications of its contents + safeUsageChecker.recordSnapshot(bytes); + + return bytes; case NO_DATA: default: return null; // NOSONAR maintains backwards compatibility @@ -391,11 +414,10 @@ public byte[] data() { @Override public void setData(@Nullable final byte[] newData) { this.seekableByteChannelFactory = null; - if (newData == null) { - this.theData = new byte[0]; - } else { - this.theData = newData; - } + this.theData = newData == null ? new byte[0] : newData; + + // calls to setData clear the unsafe state by definition, but we need to capture a new snapshot + safeUsageChecker.resetCacheThenRecordSnapshot(theData); } /** @@ -422,6 +444,9 @@ public void setData(@Nullable final byte[] newData, final int offset, final int this.theData = new byte[length]; System.arraycopy(newData, offset, this.theData, 0, length); } + + // calls to setData clear the unsafe state by definition, but we need to capture a new snapshot + safeUsageChecker.resetCacheThenRecordSnapshot(theData); } /** diff --git a/src/main/java/emissary/core/IBaseDataObject.java b/src/main/java/emissary/core/IBaseDataObject.java index 3f0122c50e..fca8cca582 100755 --- a/src/main/java/emissary/core/IBaseDataObject.java +++ b/src/main/java/emissary/core/IBaseDataObject.java @@ -26,6 +26,14 @@ enum MergePolicy { */ String DEFAULT_PARAM_SEPARATOR = ";"; + /** + * Checks to see if payload byte arrays visible to external classes have any changes not explicitly saved via a call to + * the {@link IBaseDataObject#setData(byte[]) setData(byte[])}, {@link IBaseDataObject#setData(byte[], int, int) + * setData(byte[], int, int)}, or {@link IBaseDataObject#setChannelFactory(SeekableByteChannelFactory) + * setChannelFactory(SeekableByteChannelFactory)} method. + */ + void checkForUnsafeDataChanges(); + /** * Return the data as a byte array. If using a channel to the data, calling this method will only return up to * Integer.MAX_VALUE bytes of the original data. diff --git a/src/main/java/emissary/core/SafeUsageChecker.java b/src/main/java/emissary/core/SafeUsageChecker.java new file mode 100644 index 0000000000..31a3d6055a --- /dev/null +++ b/src/main/java/emissary/core/SafeUsageChecker.java @@ -0,0 +1,100 @@ +package emissary.core; + +import emissary.config.ConfigUtil; +import emissary.config.Configurator; +import emissary.core.channels.SeekableByteChannelFactory; +import emissary.util.ByteUtil; + +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import java.io.IOException; +import java.util.HashMap; +import java.util.Map; + +/** + * Utility for validating that Places safely interact with IBDO payloads in byte array form. Specifically, this class + * helps with validating that changes to a IBDO's payload are followed by a call to the + * {@link IBaseDataObject#setData(byte[]) setData(byte[])}, {@link IBaseDataObject#setData(byte[], int, int) + * setData(byte[], int, int)}, or {@link IBaseDataObject#setChannelFactory(SeekableByteChannelFactory) + * setChannelFactory(SeekableByteChannelFactory)} method. + */ +public class SafeUsageChecker { + protected static final Logger LOGGER = LoggerFactory.getLogger(SafeUsageChecker.class); + + public static final boolean ENABLED; + public static final String UNSAFE_MODIFICATION_DETECTED = "Detected unsafe changes to IBDO byte array contents"; + + static { + boolean enabled = false; + + try { + Configurator configurator = ConfigUtil.getConfigInfo(SafeUsageChecker.class); + + enabled = configurator.findBooleanEntry("ENABLED", enabled); + } catch (IOException e) { + LOGGER.info("Could not get configuration!", e); + } + + ENABLED = enabled; + } + + /** + * Cache that records each {@literal byte[]} reference made available to IBDO clients, along with a sha256 hash of the + * array contents. Used for determining whether the clients modify the array contents without explicitly pushing those + * changes back to the IBDO + */ + private final Map cache = new HashMap<>(); + + /** + * Resets the snapshot cache + */ + public void reset() { + if (ENABLED) { + cache.clear(); + } + } + + /** + * Stores a new integrity snapshot + * + * @param bytes byte[] for which a snapshot should be captured + */ + public void recordSnapshot(final byte[] bytes) { + if (ENABLED) { + cache.put(bytes, ByteUtil.sha256Bytes(bytes)); + } + } + + + /** + * Resets the cache and stores a new integrity snapshot + * + * @param bytes byte[] for which a snapshot should be captured + */ + public void resetCacheThenRecordSnapshot(final byte[] bytes) { + if (ENABLED) { + reset(); + recordSnapshot(bytes); + } + } + + /** + * Uses the snapshot cache to determine whether any of the byte arrays have unsaved changes + * + * @return boolean indication of unsafe changes + */ + public boolean checkForUnsafeDataChanges() { + if (ENABLED) { + boolean isUnsafe = cache.entrySet().stream().anyMatch(e -> !ByteUtil.sha256Bytes(e.getKey()).equals(e.getValue())); + if (isUnsafe) { + LOGGER.warn(UNSAFE_MODIFICATION_DETECTED); + } + reset(); + + return isUnsafe; + } else { + return false; + } + } +} diff --git a/src/main/java/emissary/place/ServiceProviderPlace.java b/src/main/java/emissary/place/ServiceProviderPlace.java index 62bb39aed1..021fecf8d3 100755 --- a/src/main/java/emissary/place/ServiceProviderPlace.java +++ b/src/main/java/emissary/place/ServiceProviderPlace.java @@ -581,6 +581,7 @@ public List agentProcessHeavyDuty(IBaseDataObject payload) thro MDC.put(MDCConstants.SERVICE_LOCATION, this.getKey()); try { List l = processHeavyDuty(payload); + payload.checkForUnsafeDataChanges(); rehash(payload); return l; } catch (Exception e) { diff --git a/src/test/java/emissary/core/BaseDataObjectTest.java b/src/test/java/emissary/core/BaseDataObjectTest.java index 3faf288c17..c7aa8c029b 100755 --- a/src/test/java/emissary/core/BaseDataObjectTest.java +++ b/src/test/java/emissary/core/BaseDataObjectTest.java @@ -8,8 +8,10 @@ import emissary.core.channels.SeekableByteChannelHelper; import emissary.directory.DirectoryEntry; import emissary.pickup.Priority; +import emissary.test.core.junit5.LogbackTester; import emissary.test.core.junit5.UnitTest; +import ch.qos.logback.classic.Level; import com.google.common.collect.ArrayListMultimap; import com.google.common.collect.Multimap; import org.junit.jupiter.api.AfterEach; @@ -25,6 +27,7 @@ import java.lang.reflect.Field; import java.nio.ByteBuffer; import java.nio.channels.SeekableByteChannel; +import java.nio.charset.StandardCharsets; import java.util.ArrayList; import java.util.Arrays; import java.util.Collection; @@ -38,6 +41,7 @@ import java.util.concurrent.ThreadLocalRandom; import java.util.stream.Stream; +import static emissary.core.SafeUsageChecker.UNSAFE_MODIFICATION_DETECTED; import static org.junit.jupiter.api.Assertions.assertArrayEquals; import static org.junit.jupiter.api.Assertions.assertEquals; import static org.junit.jupiter.api.Assertions.assertFalse; @@ -1319,4 +1323,122 @@ void testExtractedRecordClone() { fail("Clone method should have been called", ex); } } + + static final byte[] DATA_MODIFICATION_BYTES = "These are the test bytes!".getBytes(StandardCharsets.US_ASCII); + static final Level[] LEVELS_ONE_WARN = new Level[] {Level.WARN}; + static final String[] ONE_UNSAFE_MODIFICATION_DETECTED = new String[] {UNSAFE_MODIFICATION_DETECTED}; + static final boolean[] NO_THROWABLES = new boolean[] {false}; + + @Test + void testChannelFactoryInArrayOutNoSet() throws IOException { + try (LogbackTester logbackTester = new LogbackTester(SafeUsageChecker.class.getName())) { + final IBaseDataObject ibdo = new BaseDataObject(); + + ibdo.setChannelFactory(InMemoryChannelFactory.create(DATA_MODIFICATION_BYTES)); + + final byte[] data = ibdo.data(); + + Arrays.fill(data, (byte) 0); + + ibdo.checkForUnsafeDataChanges(); + + assertArrayEquals(DATA_MODIFICATION_BYTES, ibdo.data()); + logbackTester.checkLogList(LEVELS_ONE_WARN, ONE_UNSAFE_MODIFICATION_DETECTED, NO_THROWABLES); + } + } + + @Test + void shouldDetectUnsafeChangesIfArrayChangesNotFollowedByOneSet() throws IOException { + try (LogbackTester logbackTester = new LogbackTester(SafeUsageChecker.class.getName())) { + final IBaseDataObject ibdo = new BaseDataObject(); + + ibdo.setChannelFactory(InMemoryChannelFactory.create(DATA_MODIFICATION_BYTES)); + + byte[] data = ibdo.data(); + data = ibdo.data(); + + Arrays.fill(data, (byte) 0); + + ibdo.checkForUnsafeDataChanges(); + + assertArrayEquals(DATA_MODIFICATION_BYTES, ibdo.data()); + logbackTester.checkLogList(LEVELS_ONE_WARN, ONE_UNSAFE_MODIFICATION_DETECTED, NO_THROWABLES); + } + } + + @Test + void shouldDetectUnsafeChangesIfArrayChangesNotFollowedByBothSet() throws IOException { + try (LogbackTester logbackTester = new LogbackTester(SafeUsageChecker.class.getName())) { + final IBaseDataObject ibdo = new BaseDataObject(); + + ibdo.setChannelFactory(InMemoryChannelFactory.create(DATA_MODIFICATION_BYTES)); + + final byte[] data0 = ibdo.data(); + final byte[] data1 = ibdo.data(); + + Arrays.fill(data0, (byte) 0); + Arrays.fill(data1, (byte) 0); + + ibdo.checkForUnsafeDataChanges(); + + assertArrayEquals(DATA_MODIFICATION_BYTES, ibdo.data()); + logbackTester.checkLogList(LEVELS_ONE_WARN, ONE_UNSAFE_MODIFICATION_DETECTED, NO_THROWABLES); + } + } + + @Test + void shouldDetectNoUnsafeChangesImmediatelyAfterSetChannelFactory() throws IOException { + try (LogbackTester logbackTester = new LogbackTester(SafeUsageChecker.class.getName())) { + final IBaseDataObject ibdo = new BaseDataObject(); + + ibdo.setChannelFactory(InMemoryChannelFactory.create(DATA_MODIFICATION_BYTES)); + + final byte[] data = ibdo.data(); + + Arrays.fill(data, (byte) 0); + ibdo.setChannelFactory(InMemoryChannelFactory.create(data)); + + ibdo.checkForUnsafeDataChanges(); + + assertArrayEquals(new byte[DATA_MODIFICATION_BYTES.length], ibdo.data()); + logbackTester.checkLogList(new Level[0], new String[0], new boolean[0]); + } + } + + @Test + void shouldDetectNoUnsafeChangesImmediatelyAfterSetData() throws IOException { + try (LogbackTester logbackTester = new LogbackTester(SafeUsageChecker.class.getName())) { + final IBaseDataObject ibdo = new BaseDataObject(); + + ibdo.setChannelFactory(InMemoryChannelFactory.create(DATA_MODIFICATION_BYTES)); + + final byte[] data = ibdo.data(); + + Arrays.fill(data, (byte) 0); + ibdo.setData(data); + + ibdo.checkForUnsafeDataChanges(); + + assertArrayEquals(new byte[DATA_MODIFICATION_BYTES.length], ibdo.data()); + logbackTester.checkLogList(new Level[0], new String[0], new boolean[0]); + } + } + + @Test + void testArrayInArrayOutNoSet() throws IOException { + try (LogbackTester logbackTester = new LogbackTester(SafeUsageChecker.class.getName())) { + final IBaseDataObject ibdo = new BaseDataObject(); + + ibdo.setData(DATA_MODIFICATION_BYTES); + + final byte[] data = ibdo.data(); + + Arrays.fill(data, (byte) 0); + + ibdo.checkForUnsafeDataChanges(); + + assertArrayEquals(DATA_MODIFICATION_BYTES, ibdo.data()); + logbackTester.checkLogList(LEVELS_ONE_WARN, ONE_UNSAFE_MODIFICATION_DETECTED, NO_THROWABLES); + } + } } diff --git a/src/test/java/emissary/test/core/junit5/LogbackTester.java b/src/test/java/emissary/test/core/junit5/LogbackTester.java new file mode 100644 index 0000000000..d0cef883f8 --- /dev/null +++ b/src/test/java/emissary/test/core/junit5/LogbackTester.java @@ -0,0 +1,55 @@ +package emissary.test.core.junit5; + +import ch.qos.logback.classic.Level; +import ch.qos.logback.classic.Logger; +import ch.qos.logback.classic.spi.ILoggingEvent; +import ch.qos.logback.core.read.ListAppender; +import org.apache.commons.lang3.Validate; +import org.slf4j.LoggerFactory; + +import java.io.Closeable; +import java.io.IOException; + +import static org.junit.jupiter.api.Assertions.assertEquals; + +public class LogbackTester implements Closeable { + public final String name; + public final Logger logger; + public final ListAppender appender; + + public LogbackTester(final String name) { + Validate.notNull(name, "Required: name != null"); + + this.name = name; + logger = (Logger) LoggerFactory.getLogger(name); + appender = new ListAppender<>(); + + appender.setContext(logger.getLoggerContext()); + appender.start(); + logger.addAppender(appender); + logger.setAdditive(false); + } + + public void checkLogList(final Level[] levels, final String[] messages, final boolean[] throwables) { + Validate.notNull(levels, "Required: levels != null"); + Validate.notNull(messages, "Required: messages != null"); + Validate.notNull(throwables, "Required: throwables != null"); + Validate.isTrue(levels.length == messages.length, "Required: levels.length == messages.length"); + Validate.isTrue(levels.length == throwables.length, "Required: levels.length == throwables.length"); + + assertEquals(levels.length, appender.list.size(), "Expected lengths do not match number of log messages"); + + for (int i = 0; i < appender.list.size(); i++) { + final ILoggingEvent item = appender.list.get(i); + + assertEquals(levels[i], item.getLevel(), "Levels not equal for element " + i); + assertEquals(messages[i], item.getFormattedMessage(), "Messages not equal for element " + i); + assertEquals(throwables[i], item.getThrowableProxy() != null, "Throwables not equal for elmeent " + i); + } + } + + @Override + public void close() throws IOException { + logger.detachAndStopAllAppenders(); + } +}