Skip to content

Commit

Permalink
Identify invalid IBDO data/array modifications.
Browse files Browse the repository at this point in the history
  • Loading branch information
James Cover jdcove2 committed Aug 22, 2023
1 parent 4bae59a commit 95fc212
Show file tree
Hide file tree
Showing 6 changed files with 301 additions and 6 deletions.
37 changes: 31 additions & 6 deletions src/main/java/emissary/core/BaseDataObject.java
Original file line number Diff line number Diff line change
Expand Up @@ -182,6 +182,8 @@ public class BaseDataObject implements Serializable, Cloneable, Remote, IBaseDat
*/
protected SeekableByteChannelFactory seekableByteChannelFactory;

final SafeUsageChecker safeUsageChecker = new SafeUsageChecker();

protected enum DataState {
NO_DATA, CHANNEL_ONLY, BYTE_ARRAY_ONLY, BYTE_ARRAY_AND_CHANNEL
}
Expand Down Expand Up @@ -221,6 +223,18 @@ protected DataState getDataState() {
}
}

@Override
/**
* {@inheritDoc}
*/
public void checkForUnsafeDataChanges() {
safeUsageChecker.checkForUnsafeDataChanges();

if (theData != null) {
safeUsageChecker.recordSnapshot(theData);
}
}

/**
* Create an empty BaseDataObject.
*/
Expand Down Expand Up @@ -327,6 +341,10 @@ public void setChannelFactory(final SeekableByteChannelFactory sbcf) {
Validate.notNull(sbcf, "Required: SeekableByteChannelFactory not null");
this.theData = null;
this.seekableByteChannelFactory = sbcf;

// calls to setData clear the unsafe state by definition
// reset the usage checker but don't capture a snapshot until someone requests the data in byte[] form
safeUsageChecker.reset();
}

/**
Expand Down Expand Up @@ -378,7 +396,12 @@ public byte[] data() {
return theData;
case CHANNEL_ONLY:
// Max size here is slightly less than the true max size to avoid memory issues
return SeekableByteChannelHelper.getByteArrayFromBdo(this, MAX_BYTE_ARRAY_SIZE);
final byte[] bytes = SeekableByteChannelHelper.getByteArrayFromBdo(this, MAX_BYTE_ARRAY_SIZE);

// capture a reference to the returned byte[] so we can test for unsafe modifications of its contents
safeUsageChecker.recordSnapshot(bytes);

return bytes;
case NO_DATA:
default:
return null; // NOSONAR maintains backwards compatibility
Expand All @@ -391,11 +414,10 @@ public byte[] data() {
@Override
public void setData(@Nullable final byte[] newData) {
this.seekableByteChannelFactory = null;
if (newData == null) {
this.theData = new byte[0];
} else {
this.theData = newData;
}
this.theData = newData == null ? new byte[0] : newData;

// calls to setData clear the unsafe state by definition, but we need to capture a new snapshot
safeUsageChecker.resetCacheThenRecordSnapshot(theData);
}

/**
Expand All @@ -422,6 +444,9 @@ public void setData(@Nullable final byte[] newData, final int offset, final int
this.theData = new byte[length];
System.arraycopy(newData, offset, this.theData, 0, length);
}

// calls to setData clear the unsafe state by definition, but we need to capture a new snapshot
safeUsageChecker.resetCacheThenRecordSnapshot(theData);
}

/**
Expand Down
8 changes: 8 additions & 0 deletions src/main/java/emissary/core/IBaseDataObject.java
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,14 @@ static enum MergePolicy {
*/
String DEFAULT_PARAM_SEPARATOR = ";";

/**
* Checks to see if payload byte arrays visible to external classes have any changes not explicitly saved via a call to
* the {@link IBaseDataObject#setData(byte[]) setData(byte[])}, {@link IBaseDataObject#setData(byte[], int, int)
* setData(byte[], int, int)}, or {@link IBaseDataObject#setChannelFactory(SeekableByteChannelFactory)
* setChannelFactory(SeekableByteChannelFactory)} method.
*/
void checkForUnsafeDataChanges();

/**
* Return the data as a byte array. If using a channel to the data, calling this method will only return up to
* Integer.MAX_VALUE bytes of the original data.
Expand Down
70 changes: 70 additions & 0 deletions src/main/java/emissary/core/SafeUsageChecker.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,70 @@
package emissary.core;

import emissary.core.channels.SeekableByteChannelFactory;
import emissary.util.ByteUtil;

import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import java.util.HashMap;
import java.util.Map;

/**
* Utility for validating that Places safely interact with IBDO payloads in byte array form. Specifically, this class
* helps with validating that changes to a IBDO's payload are followed by a call to the
* {@link IBaseDataObject#setData(byte[]) setData(byte[])}, {@link IBaseDataObject#setData(byte[], int, int)
* setData(byte[], int, int)}, or {@link IBaseDataObject#setChannelFactory(SeekableByteChannelFactory)
* setChannelFactory(SeekableByteChannelFactory)} method.
*/
public class SafeUsageChecker {
static final Logger logger = LoggerFactory.getLogger(SafeUsageChecker.class);
public static final String UNSAFE_MODIFICATION_DETECTED = "Detected unsafe changes to IBDO byte array contents";

/**
* Cache that records each {@literal byte[]} reference made available to IBDO clients, along with a sha256 hash of the
* array contents. Used for determining whether the clients modify the array contents without explicitly pushing those
* changes back to the IBDO
*/
private final Map<byte[], String> cache = new HashMap<>();

/**
* Resets the snapshot cache
*/
public void reset() {
cache.clear();
}

/**
* Stores a new integrity snapshot
*
* @param bytes byte[] for which a snapshot should be captured
*/
public void recordSnapshot(final byte[] bytes) {
cache.put(bytes, ByteUtil.sha256Bytes(bytes));
}


/**
* Resets the cache and stores a new integrity snapshot
*
* @param bytes byte[] for which a snapshot should be captured
*/
public void resetCacheThenRecordSnapshot(final byte[] bytes) {
reset();
recordSnapshot(bytes);
}

/**
* Uses the snapshot cache to determine whether any of the byte arrays have unsaved changes
*
* @return boolean indication of unsafe changes
*/
public boolean checkForUnsafeDataChanges() {
boolean isUnsafe = cache.entrySet().stream().anyMatch(e -> !ByteUtil.sha256Bytes(e.getKey()).equals(e.getValue()));
if (isUnsafe) {
logger.warn(UNSAFE_MODIFICATION_DETECTED);
}
reset();
return isUnsafe;
}
}
1 change: 1 addition & 0 deletions src/main/java/emissary/place/ServiceProviderPlace.java
Original file line number Diff line number Diff line change
Expand Up @@ -570,6 +570,7 @@ public List<IBaseDataObject> agentProcessHeavyDuty(IBaseDataObject payload) thro
MDC.put(MDCConstants.SERVICE_LOCATION, this.getKey());
try {
List<IBaseDataObject> l = processHeavyDuty(payload);
payload.checkForUnsafeDataChanges();
rehash(payload);
return l;
} catch (Exception e) {
Expand Down
136 changes: 136 additions & 0 deletions src/test/java/emissary/core/BaseDataObjectTest.java
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,10 @@
import emissary.core.channels.SeekableByteChannelHelper;
import emissary.directory.DirectoryEntry;
import emissary.pickup.Priority;
import emissary.test.core.junit5.LogbackTester;
import emissary.test.core.junit5.UnitTest;

import ch.qos.logback.classic.Level;
import com.google.common.collect.ArrayListMultimap;
import com.google.common.collect.Multimap;
import org.junit.jupiter.api.AfterEach;
Expand All @@ -25,6 +27,7 @@
import java.lang.reflect.Field;
import java.nio.ByteBuffer;
import java.nio.channels.SeekableByteChannel;
import java.nio.charset.StandardCharsets;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collection;
Expand All @@ -38,6 +41,7 @@
import java.util.concurrent.ThreadLocalRandom;
import java.util.stream.Stream;

import static emissary.core.SafeUsageChecker.UNSAFE_MODIFICATION_DETECTED;
import static org.junit.jupiter.api.Assertions.assertArrayEquals;
import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertFalse;
Expand Down Expand Up @@ -1319,4 +1323,136 @@ void testExtractedRecordClone() {
fail("Clone method should have been called", ex);
}
}

@Test
void testChannelFactoryInArrayOutNoSet() throws IOException {
final byte[] bytes = "These are the test bytes!".getBytes(StandardCharsets.US_ASCII);
final Level[] levels = new Level[] {Level.WARN};
final String[] messages = new String[] {UNSAFE_MODIFICATION_DETECTED};
final boolean[] throwables = new boolean[] {false};

try (LogbackTester logbackTester = new LogbackTester(SafeUsageChecker.class.getName())) {
final IBaseDataObject ibdo = new BaseDataObject();

ibdo.setChannelFactory(InMemoryChannelFactory.create(bytes));

final byte[] data = ibdo.data();

Arrays.fill(data, (byte) 0);

ibdo.checkForUnsafeDataChanges();

assertArrayEquals(bytes, ibdo.data());
logbackTester.checkLogList(levels, messages, throwables);
}
}

final static byte[] DATA_MODIFICATION_BYTES = "These are the test bytes!".getBytes(StandardCharsets.US_ASCII);

@Test
void testChannelFactoryInTwoArrayOutNoSet1() throws IOException {
final Level[] levels = new Level[] {Level.WARN};
final String[] messages = new String[] {UNSAFE_MODIFICATION_DETECTED};
final boolean[] throwables = new boolean[] {false};

try (LogbackTester logbackTester = new LogbackTester(SafeUsageChecker.class.getName())) {
final IBaseDataObject ibdo = new BaseDataObject();

ibdo.setChannelFactory(InMemoryChannelFactory.create(DATA_MODIFICATION_BYTES));

final byte[] data0 = ibdo.data();
final byte[] data1 = ibdo.data();

Arrays.fill(data1, (byte) 0);

ibdo.checkForUnsafeDataChanges();

assertArrayEquals(DATA_MODIFICATION_BYTES, ibdo.data());
logbackTester.checkLogList(levels, messages, throwables);
}
}

@Test
void testChannelFactoryInTwoArrayOutNoSet2() throws IOException {
final Level[] levels = new Level[] {Level.WARN};
final String[] messages = new String[] {UNSAFE_MODIFICATION_DETECTED};
final boolean[] throwables = new boolean[] {false};

try (LogbackTester logbackTester = new LogbackTester(SafeUsageChecker.class.getName())) {
final IBaseDataObject ibdo = new BaseDataObject();

ibdo.setChannelFactory(InMemoryChannelFactory.create(DATA_MODIFICATION_BYTES));

final byte[] data0 = ibdo.data();
final byte[] data1 = ibdo.data();

Arrays.fill(data0, (byte) 0);
Arrays.fill(data1, (byte) 0);

ibdo.checkForUnsafeDataChanges();

assertArrayEquals(DATA_MODIFICATION_BYTES, ibdo.data());
logbackTester.checkLogList(levels, messages, throwables);
}
}

@Test
void testChannelFactoryInArrayOutSetChannelFactory() throws IOException {
try (LogbackTester logbackTester = new LogbackTester(SafeUsageChecker.class.getName())) {
final IBaseDataObject ibdo = new BaseDataObject();

ibdo.setChannelFactory(InMemoryChannelFactory.create(DATA_MODIFICATION_BYTES));

final byte[] data = ibdo.data();

Arrays.fill(data, (byte) 0);
ibdo.setChannelFactory(InMemoryChannelFactory.create(data));

ibdo.checkForUnsafeDataChanges();

assertArrayEquals(new byte[DATA_MODIFICATION_BYTES.length], ibdo.data());
logbackTester.checkLogList(new Level[0], new String[0], new boolean[0]);
}
}

@Test
void testChannelFactoryInArrayOutSetArray() throws IOException {
try (LogbackTester logbackTester = new LogbackTester(SafeUsageChecker.class.getName())) {
final IBaseDataObject ibdo = new BaseDataObject();

ibdo.setChannelFactory(InMemoryChannelFactory.create(DATA_MODIFICATION_BYTES));

final byte[] data = ibdo.data();

Arrays.fill(data, (byte) 0);
ibdo.setData(data);

ibdo.checkForUnsafeDataChanges();

assertArrayEquals(new byte[DATA_MODIFICATION_BYTES.length], ibdo.data());
logbackTester.checkLogList(new Level[0], new String[0], new boolean[0]);
}
}

@Test
void testArrayInArrayOutNoSet() throws IOException {
final Level[] levels = new Level[] {Level.WARN};
final String[] messages = new String[] {UNSAFE_MODIFICATION_DETECTED};
final boolean[] throwables = new boolean[] {false};

try (LogbackTester logbackTester = new LogbackTester(SafeUsageChecker.class.getName())) {
final IBaseDataObject ibdo = new BaseDataObject();

ibdo.setData(DATA_MODIFICATION_BYTES);

final byte[] data = ibdo.data();

Arrays.fill(data, (byte) 0);

ibdo.checkForUnsafeDataChanges();

assertArrayEquals(DATA_MODIFICATION_BYTES, ibdo.data());
logbackTester.checkLogList(levels, messages, throwables);
}
}
}
55 changes: 55 additions & 0 deletions src/test/java/emissary/test/core/junit5/LogbackTester.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,55 @@
package emissary.test.core.junit5;

import ch.qos.logback.classic.Level;
import ch.qos.logback.classic.Logger;
import ch.qos.logback.classic.spi.ILoggingEvent;
import ch.qos.logback.core.read.ListAppender;
import org.apache.commons.lang3.Validate;
import org.slf4j.LoggerFactory;

import java.io.Closeable;
import java.io.IOException;

import static org.junit.jupiter.api.Assertions.assertEquals;

public class LogbackTester implements Closeable {
public final String name;
public final Logger logger;
public final ListAppender<ILoggingEvent> appender;

public LogbackTester(final String name) {
Validate.notNull(name, "Required: name != null");

this.name = name;
logger = (Logger) LoggerFactory.getLogger(name);
appender = new ListAppender<>();

appender.setContext(logger.getLoggerContext());
appender.start();
logger.addAppender(appender);
logger.setAdditive(false);
}

public void checkLogList(final Level[] levels, final String[] messages, final boolean[] throwables) {
Validate.notNull(levels, "Required: levels != null");
Validate.notNull(messages, "Required: messages != null");
Validate.notNull(throwables, "Required: throwables != null");
Validate.isTrue(levels.length == messages.length, "Required: levels.length == messages.length");
Validate.isTrue(levels.length == throwables.length, "Required: levels.length == throwables.length");

assertEquals(levels.length, appender.list.size(), "Expected lengths do not match number of log messages");

for (int i = 0; i < appender.list.size(); i++) {
final ILoggingEvent item = appender.list.get(i);

assertEquals(levels[i], item.getLevel(), "Levels not equal for element " + i);
assertEquals(messages[i], item.getFormattedMessage(), "Messages not equal for element " + i);
assertEquals(throwables[i], item.getThrowableProxy() != null, "Throwables not equal for elmeent " + i);
}
}

@Override
public void close() throws IOException {
logger.detachAndStopAllAppenders();
}
}

0 comments on commit 95fc212

Please sign in to comment.