Skip to content

Commit

Permalink
Identify invalid IBDO data/array modifications.
Browse files Browse the repository at this point in the history
  • Loading branch information
James Cover jdcove2 committed Sep 20, 2023
1 parent 0f11c9e commit 9c1e6e6
Show file tree
Hide file tree
Showing 8 changed files with 379 additions and 6 deletions.
1 change: 1 addition & 0 deletions src/main/config/emissary.core.SafeUsageChecker.cfg
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
ENABLED = TRUE
37 changes: 31 additions & 6 deletions src/main/java/emissary/core/BaseDataObject.java
Original file line number Diff line number Diff line change
Expand Up @@ -182,6 +182,8 @@ public class BaseDataObject implements Serializable, Cloneable, Remote, IBaseDat
*/
protected SeekableByteChannelFactory seekableByteChannelFactory;

final SafeUsageChecker safeUsageChecker = new SafeUsageChecker();

protected enum DataState {
NO_DATA, CHANNEL_ONLY, BYTE_ARRAY_ONLY, BYTE_ARRAY_AND_CHANNEL
}
Expand Down Expand Up @@ -221,6 +223,18 @@ protected DataState getDataState() {
}
}

@Override
/**
* {@inheritDoc}
*/
public void checkForUnsafeDataChanges() {
safeUsageChecker.checkForUnsafeDataChanges();

if (theData != null) {
safeUsageChecker.recordSnapshot(theData);
}
}

/**
* Create an empty BaseDataObject.
*/
Expand Down Expand Up @@ -327,6 +341,10 @@ public void setChannelFactory(final SeekableByteChannelFactory sbcf) {
Validate.notNull(sbcf, "Required: SeekableByteChannelFactory not null");
this.theData = null;
this.seekableByteChannelFactory = sbcf;

// calls to setData clear the unsafe state by definition
// reset the usage checker but don't capture a snapshot until someone requests the data in byte[] form
safeUsageChecker.reset();
}

/**
Expand Down Expand Up @@ -378,7 +396,12 @@ public byte[] data() {
return theData;
case CHANNEL_ONLY:
// Max size here is slightly less than the true max size to avoid memory issues
return SeekableByteChannelHelper.getByteArrayFromBdo(this, MAX_BYTE_ARRAY_SIZE);
final byte[] bytes = SeekableByteChannelHelper.getByteArrayFromBdo(this, MAX_BYTE_ARRAY_SIZE);

// capture a reference to the returned byte[] so we can test for unsafe modifications of its contents
safeUsageChecker.recordSnapshot(bytes);

return bytes;
case NO_DATA:
default:
return null; // NOSONAR maintains backwards compatibility
Expand All @@ -391,11 +414,10 @@ public byte[] data() {
@Override
public void setData(@Nullable final byte[] newData) {
this.seekableByteChannelFactory = null;
if (newData == null) {
this.theData = new byte[0];
} else {
this.theData = newData;
}
this.theData = newData == null ? new byte[0] : newData;

// calls to setData clear the unsafe state by definition, but we need to capture a new snapshot
safeUsageChecker.resetCacheThenRecordSnapshot(theData);
}

/**
Expand All @@ -422,6 +444,9 @@ public void setData(@Nullable final byte[] newData, final int offset, final int
this.theData = new byte[length];
System.arraycopy(newData, offset, this.theData, 0, length);
}

// calls to setData clear the unsafe state by definition, but we need to capture a new snapshot
safeUsageChecker.resetCacheThenRecordSnapshot(theData);
}

/**
Expand Down
8 changes: 8 additions & 0 deletions src/main/java/emissary/core/IBaseDataObject.java
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,14 @@ enum MergePolicy {
*/
String DEFAULT_PARAM_SEPARATOR = ";";

/**
* Checks to see if payload byte arrays visible to external classes have any changes not explicitly saved via a call to
* the {@link IBaseDataObject#setData(byte[]) setData(byte[])}, {@link IBaseDataObject#setData(byte[], int, int)
* setData(byte[], int, int)}, or {@link IBaseDataObject#setChannelFactory(SeekableByteChannelFactory)
* setChannelFactory(SeekableByteChannelFactory)} method.
*/
void checkForUnsafeDataChanges();

/**
* Return the data as a byte array. If using a channel to the data, calling this method will only return up to
* Integer.MAX_VALUE bytes of the original data.
Expand Down
105 changes: 105 additions & 0 deletions src/main/java/emissary/core/SafeUsageChecker.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,105 @@
package emissary.core;

import emissary.config.ConfigUtil;
import emissary.config.Configurator;
import emissary.core.channels.SeekableByteChannelFactory;
import emissary.util.ByteUtil;

import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import java.io.IOException;
import java.util.HashMap;
import java.util.Map;

/**
* Utility for validating that Places safely interact with IBDO payloads in byte array form. Specifically, this class
* helps with validating that changes to a IBDO's payload are followed by a call to the
* {@link IBaseDataObject#setData(byte[]) setData(byte[])}, {@link IBaseDataObject#setData(byte[], int, int)
* setData(byte[], int, int)}, or {@link IBaseDataObject#setChannelFactory(SeekableByteChannelFactory)
* setChannelFactory(SeekableByteChannelFactory)} method.
*/
public class SafeUsageChecker {
protected static final Logger LOGGER = LoggerFactory.getLogger(SafeUsageChecker.class);

public static final String ENABLED_KEY = "ENABLED";
public static final boolean ENABLED_FROM_CONFIGURATION;
public static final String UNSAFE_MODIFICATION_DETECTED = "Detected unsafe changes to IBDO byte array contents";

// to minimize I/O, we only want to read the config file once regardless of the number of instances created
static {
boolean enabledFromConfiguration = false;

try {
Configurator configurator = ConfigUtil.getConfigInfo(SafeUsageChecker.class);

enabledFromConfiguration = configurator.findBooleanEntry(ENABLED_KEY, enabledFromConfiguration);
} catch (IOException e) {
LOGGER.warn("Could not get configuration!", e);
}

ENABLED_FROM_CONFIGURATION = enabledFromConfiguration;
}

/**
* Cache that records each {@literal byte[]} reference made available to IBDO clients, along with a sha256 hash of the
* array contents. Used for determining whether the clients modify the array contents without explicitly pushing those
* changes back to the IBDO
*/
private final Map<byte[], String> cache = new HashMap<>();
public final boolean enabled;

public SafeUsageChecker() {
enabled = ENABLED_FROM_CONFIGURATION;
}

public SafeUsageChecker(Configurator configurator) {
enabled = configurator.findBooleanEntry(ENABLED_KEY, ENABLED_FROM_CONFIGURATION);
}

/**
* Resets the snapshot cache
*/
public void reset() {
if (enabled) {
cache.clear();
}
}

/**
* Stores a new integrity snapshot
*
* @param bytes byte[] for which a snapshot should be captured
*/
public void recordSnapshot(final byte[] bytes) {
if (enabled) {
cache.put(bytes, ByteUtil.sha256Bytes(bytes));
}
}


/**
* Resets the cache and stores a new integrity snapshot
*
* @param bytes byte[] for which a snapshot should be captured
*/
public void resetCacheThenRecordSnapshot(final byte[] bytes) {
if (enabled) {
reset();
recordSnapshot(bytes);
}
}

/**
* Uses the snapshot cache to determine whether any of the byte arrays have unsaved changes
*/
public void checkForUnsafeDataChanges() {
if (enabled) {
boolean isUnsafe = cache.entrySet().stream().anyMatch(e -> !ByteUtil.sha256Bytes(e.getKey()).equals(e.getValue()));
if (isUnsafe) {
LOGGER.warn(UNSAFE_MODIFICATION_DETECTED);
}
reset();
}
}
}
1 change: 1 addition & 0 deletions src/main/java/emissary/place/ServiceProviderPlace.java
Original file line number Diff line number Diff line change
Expand Up @@ -581,6 +581,7 @@ public List<IBaseDataObject> agentProcessHeavyDuty(IBaseDataObject payload) thro
MDC.put(MDCConstants.SERVICE_LOCATION, this.getKey());
try {
List<IBaseDataObject> l = processHeavyDuty(payload);
payload.checkForUnsafeDataChanges();
rehash(payload);
return l;
} catch (Exception e) {
Expand Down
122 changes: 122 additions & 0 deletions src/test/java/emissary/core/BaseDataObjectTest.java
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,10 @@
import emissary.core.channels.SeekableByteChannelHelper;
import emissary.directory.DirectoryEntry;
import emissary.pickup.Priority;
import emissary.test.core.junit5.LogbackTester;
import emissary.test.core.junit5.UnitTest;

import ch.qos.logback.classic.Level;
import com.google.common.collect.ArrayListMultimap;
import com.google.common.collect.Multimap;
import org.junit.jupiter.api.AfterEach;
Expand All @@ -25,6 +27,7 @@
import java.lang.reflect.Field;
import java.nio.ByteBuffer;
import java.nio.channels.SeekableByteChannel;
import java.nio.charset.StandardCharsets;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collection;
Expand All @@ -38,6 +41,7 @@
import java.util.concurrent.ThreadLocalRandom;
import java.util.stream.Stream;

import static emissary.core.SafeUsageChecker.UNSAFE_MODIFICATION_DETECTED;
import static org.junit.jupiter.api.Assertions.assertArrayEquals;
import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertFalse;
Expand Down Expand Up @@ -1319,4 +1323,122 @@ void testExtractedRecordClone() {
fail("Clone method should have been called", ex);
}
}

static final byte[] DATA_MODIFICATION_BYTES = "These are the test bytes!".getBytes(StandardCharsets.US_ASCII);
static final Level[] LEVELS_ONE_WARN = new Level[] {Level.WARN};
static final String[] ONE_UNSAFE_MODIFICATION_DETECTED = new String[] {UNSAFE_MODIFICATION_DETECTED};
static final boolean[] NO_THROWABLES = new boolean[] {false};

@Test
void testChannelFactoryInArrayOutNoSet() throws IOException {
try (LogbackTester logbackTester = new LogbackTester(SafeUsageChecker.class.getName())) {
final IBaseDataObject ibdo = new BaseDataObject();

ibdo.setChannelFactory(InMemoryChannelFactory.create(DATA_MODIFICATION_BYTES));

final byte[] data = ibdo.data();

Arrays.fill(data, (byte) 0);

ibdo.checkForUnsafeDataChanges();

assertArrayEquals(DATA_MODIFICATION_BYTES, ibdo.data());
logbackTester.checkLogList(LEVELS_ONE_WARN, ONE_UNSAFE_MODIFICATION_DETECTED, NO_THROWABLES);
}
}

@Test
void shouldDetectUnsafeChangesIfArrayChangesNotFollowedByOneSet() throws IOException {
try (LogbackTester logbackTester = new LogbackTester(SafeUsageChecker.class.getName())) {
final IBaseDataObject ibdo = new BaseDataObject();

ibdo.setChannelFactory(InMemoryChannelFactory.create(DATA_MODIFICATION_BYTES));

byte[] data = ibdo.data();
data = ibdo.data();

Arrays.fill(data, (byte) 0);

ibdo.checkForUnsafeDataChanges();

assertArrayEquals(DATA_MODIFICATION_BYTES, ibdo.data());
logbackTester.checkLogList(LEVELS_ONE_WARN, ONE_UNSAFE_MODIFICATION_DETECTED, NO_THROWABLES);
}
}

@Test
void shouldDetectUnsafeChangesIfArrayChangesNotFollowedByBothSet() throws IOException {
try (LogbackTester logbackTester = new LogbackTester(SafeUsageChecker.class.getName())) {
final IBaseDataObject ibdo = new BaseDataObject();

ibdo.setChannelFactory(InMemoryChannelFactory.create(DATA_MODIFICATION_BYTES));

final byte[] data0 = ibdo.data();
final byte[] data1 = ibdo.data();

Arrays.fill(data0, (byte) 0);
Arrays.fill(data1, (byte) 0);

ibdo.checkForUnsafeDataChanges();

assertArrayEquals(DATA_MODIFICATION_BYTES, ibdo.data());
logbackTester.checkLogList(LEVELS_ONE_WARN, ONE_UNSAFE_MODIFICATION_DETECTED, NO_THROWABLES);
}
}

@Test
void shouldDetectNoUnsafeChangesImmediatelyAfterSetChannelFactory() throws IOException {
try (LogbackTester logbackTester = new LogbackTester(SafeUsageChecker.class.getName())) {
final IBaseDataObject ibdo = new BaseDataObject();

ibdo.setChannelFactory(InMemoryChannelFactory.create(DATA_MODIFICATION_BYTES));

final byte[] data = ibdo.data();

Arrays.fill(data, (byte) 0);
ibdo.setChannelFactory(InMemoryChannelFactory.create(data));

ibdo.checkForUnsafeDataChanges();

assertArrayEquals(new byte[DATA_MODIFICATION_BYTES.length], ibdo.data());
logbackTester.checkLogList(new Level[0], new String[0], new boolean[0]);
}
}

@Test
void shouldDetectNoUnsafeChangesImmediatelyAfterSetData() throws IOException {
try (LogbackTester logbackTester = new LogbackTester(SafeUsageChecker.class.getName())) {
final IBaseDataObject ibdo = new BaseDataObject();

ibdo.setChannelFactory(InMemoryChannelFactory.create(DATA_MODIFICATION_BYTES));

final byte[] data = ibdo.data();

Arrays.fill(data, (byte) 0);
ibdo.setData(data);

ibdo.checkForUnsafeDataChanges();

assertArrayEquals(new byte[DATA_MODIFICATION_BYTES.length], ibdo.data());
logbackTester.checkLogList(new Level[0], new String[0], new boolean[0]);
}
}

@Test
void testArrayInArrayOutNoSet() throws IOException {
try (LogbackTester logbackTester = new LogbackTester(SafeUsageChecker.class.getName())) {
final IBaseDataObject ibdo = new BaseDataObject();

ibdo.setData(DATA_MODIFICATION_BYTES);

final byte[] data = ibdo.data();

Arrays.fill(data, (byte) 0);

ibdo.checkForUnsafeDataChanges();

assertArrayEquals(DATA_MODIFICATION_BYTES, ibdo.data());
logbackTester.checkLogList(LEVELS_ONE_WARN, ONE_UNSAFE_MODIFICATION_DETECTED, NO_THROWABLES);
}
}
}
Loading

0 comments on commit 9c1e6e6

Please sign in to comment.