diff --git a/components/context/src/main/java/datadog/context/ContextHelpers.java b/components/context/src/main/java/datadog/context/ContextHelpers.java
new file mode 100644
index 00000000000..62c83ecbff6
--- /dev/null
+++ b/components/context/src/main/java/datadog/context/ContextHelpers.java
@@ -0,0 +1,143 @@
+package datadog.context;
+
+import static java.lang.Math.max;
+import static java.util.Arrays.copyOfRange;
+import static java.util.Objects.requireNonNull;
+
+import java.util.ArrayList;
+import java.util.List;
+import java.util.function.BinaryOperator;
+
+/**
+ * Static helpers to manipulate context collections.
+ *
+ *
Typical usages include:
+ *
+ *
{@code
+ * // Finding a context value from multiple sources:
+ * Span span = findFirst(spanKey, message, request, CURRENT)
+ * // Find all context values from different sources:
+ * List errors = findAll(errorKey, message, request, CURRENT)
+ * // Capture multiple contexts in a single one:
+ * Context aggregate = combine(message, request, CURRENT)
+ * // Combine multiple contexts into a single one using custom merge rules:
+ * Context combined = combine(
+ * (current, next) -> {
+ * var metric = current.get(metricKey);
+ * var nextMetric = next.get(metricKey);
+ * return current.with(metricKey, metric.add(nextMetric));
+ * }, message, request, CURRENT);
+ * }
+ *
+ * where {@link #CURRENT} denotes a carrier with the current context.
+ */
+public final class ContextHelpers {
+ /** A helper object carrying the {@link Context#current()} context. */
+ public static final Object CURRENT = new Object();
+
+ private ContextHelpers() {}
+
+ /**
+ * Find the first context value from given context carriers.
+ *
+ * @param key The key used to store the value.
+ * @param carriers The carrier to get context and value from.
+ * @param The type of the value to look for.
+ * @return The first context value found, {@code null} if not found.
+ */
+ public static T findFirst(ContextKey key, Object... carriers) {
+ requireNonNull(key, "key cannot be null");
+ for (Object carrier : carriers) {
+ requireNonNull(carrier, "carrier cannot be null");
+ Context context = carrier == CURRENT ? Context.current() : Context.from(carrier);
+ T value = context.get(key);
+ if (value != null) {
+ return value;
+ }
+ }
+ return null;
+ }
+
+ /**
+ * Find all the context values from the given context carriers.
+ *
+ * @param key The key used to store the value.
+ * @param carriers The carriers to get context and value from.
+ * @param The type of the values to look for.
+ * @return A list of all values found, in context order.
+ */
+ public static List findAll(ContextKey key, Object... carriers) {
+ requireNonNull(key, "key cannot be null");
+ List values = new ArrayList<>(carriers.length);
+ for (Object carrier : carriers) {
+ requireNonNull(carrier, "carrier cannot be null");
+ Context context = carrier == CURRENT ? Context.current() : Context.from(carrier);
+ T value = context.get(key);
+ if (value != null) {
+ values.add(value);
+ }
+ }
+ return values;
+ }
+
+ /**
+ * Combine contexts and their values, keeping the first founds.
+ *
+ * @param contexts The contexts to combine.
+ * @return A context containing all the values from all the given context, keeping the first value
+ * found for a given key.
+ */
+ public static Context combine(Context... contexts) {
+ return combine(ContextHelpers::combineKeepingFirst, contexts);
+ }
+
+ /**
+ * Combine multiple contexts into a single one.
+ *
+ * @param combiner The context combiner, taking already combined context as first parameter, any
+ * following one as second parameter, and returning the combined context.
+ * @param contexts The contexts to combine.
+ * @return The combined context.
+ */
+ public static Context combine(BinaryOperator combiner, Context... contexts) {
+ requireNonNull(combiner, "combiner cannot be null");
+ Context result = new IndexedContext(new Object[0]);
+ for (Context context : contexts) {
+ requireNonNull(context, "context cannot be null");
+ result = combiner.apply(result, context);
+ }
+ return result;
+ }
+
+ private static Context combineKeepingFirst(Context current, Context next) {
+ if (!(current instanceof IndexedContext)) {
+ throw new IllegalStateException("Left context is supposed to be an IndexedContext");
+ }
+ IndexedContext currentIndexed = (IndexedContext) current;
+ if (next instanceof EmptyContext) {
+ return current;
+ } else if (next instanceof SingletonContext) {
+ SingletonContext nextSingleton = (SingletonContext) next;
+ Object[] store =
+ copyOfRange(
+ currentIndexed.store, 0, max(currentIndexed.store.length, nextSingleton.index + 1));
+ if (store[nextSingleton.index] == null) {
+ store[nextSingleton.index] = nextSingleton.value;
+ }
+ return new IndexedContext(store);
+ } else if (next instanceof IndexedContext) {
+ IndexedContext nextIndexed = (IndexedContext) next;
+ Object[] store =
+ copyOfRange(
+ currentIndexed.store, 0, max(currentIndexed.store.length, nextIndexed.store.length));
+ for (int i = 0; i < nextIndexed.store.length; i++) {
+ Object nextValue = nextIndexed.store[i];
+ if (nextValue != null && store[i] == null) {
+ store[i] = nextValue;
+ }
+ }
+ return new IndexedContext(store);
+ }
+ throw new IllegalStateException("Unsupported context type: " + next.getClass().getName());
+ }
+}
diff --git a/components/context/src/main/java/datadog/context/IndexedContext.java b/components/context/src/main/java/datadog/context/IndexedContext.java
index cadc481d707..d7aa1731b2a 100644
--- a/components/context/src/main/java/datadog/context/IndexedContext.java
+++ b/components/context/src/main/java/datadog/context/IndexedContext.java
@@ -11,7 +11,7 @@
/** {@link Context} containing many values. */
@ParametersAreNonnullByDefault
final class IndexedContext implements Context {
- private final Object[] store;
+ final Object[] store;
IndexedContext(Object[] store) {
this.store = store;
diff --git a/components/context/src/main/java/datadog/context/SingletonContext.java b/components/context/src/main/java/datadog/context/SingletonContext.java
index 7a8a4e98b6f..4aa5e3f04cf 100644
--- a/components/context/src/main/java/datadog/context/SingletonContext.java
+++ b/components/context/src/main/java/datadog/context/SingletonContext.java
@@ -10,8 +10,8 @@
/** {@link Context} containing a single value. */
@ParametersAreNonnullByDefault
final class SingletonContext implements Context {
- private final int index;
- private final Object value;
+ final int index;
+ final Object value;
SingletonContext(int index, Object value) {
this.index = index;
diff --git a/components/context/src/test/java/datadog/context/ContextHelpersTest.java b/components/context/src/test/java/datadog/context/ContextHelpersTest.java
new file mode 100644
index 00000000000..51e88774376
--- /dev/null
+++ b/components/context/src/test/java/datadog/context/ContextHelpersTest.java
@@ -0,0 +1,208 @@
+package datadog.context;
+
+import static datadog.context.Context.current;
+import static datadog.context.Context.root;
+import static datadog.context.ContextHelpers.CURRENT;
+import static datadog.context.ContextHelpers.combine;
+import static datadog.context.ContextHelpers.findAll;
+import static datadog.context.ContextHelpers.findFirst;
+import static datadog.context.ContextTest.BOOLEAN_KEY;
+import static datadog.context.ContextTest.FLOAT_KEY;
+import static datadog.context.ContextTest.STRING_KEY;
+import static java.util.Arrays.asList;
+import static java.util.Collections.emptyList;
+import static java.util.Collections.singleton;
+import static java.util.logging.Level.ALL;
+import static java.util.logging.Level.INFO;
+import static java.util.logging.Level.SEVERE;
+import static java.util.logging.Level.WARNING;
+import static org.junit.jupiter.api.Assertions.assertEquals;
+import static org.junit.jupiter.api.Assertions.assertIterableEquals;
+import static org.junit.jupiter.api.Assertions.assertNotNull;
+import static org.junit.jupiter.api.Assertions.assertNull;
+import static org.junit.jupiter.api.Assertions.assertThrows;
+import static org.junit.jupiter.params.provider.Arguments.arguments;
+
+import java.util.function.BinaryOperator;
+import java.util.logging.Level;
+import java.util.stream.Stream;
+import org.junit.jupiter.api.BeforeAll;
+import org.junit.jupiter.api.Test;
+import org.junit.jupiter.params.ParameterizedTest;
+import org.junit.jupiter.params.provider.Arguments;
+import org.junit.jupiter.params.provider.MethodSource;
+
+class ContextHelpersTest {
+ private static final Object CARRIER_1 = new Object();
+ private static final Object CARRIER_2 = new Object();
+ private static final Object UNSET_CARRIER = new Object();
+ private static final Object NON_CARRIER = new Object();
+ private static final String VALUE_1 = "value1";
+ private static final String VALUE_2 = "value2";
+
+ @BeforeAll
+ static void init() {
+ Context context1 = root().with(STRING_KEY, VALUE_1);
+ context1.attachTo(CARRIER_1);
+
+ Context context2 = root().with(STRING_KEY, VALUE_2);
+ context2.attachTo(CARRIER_2);
+
+ root().attachTo(UNSET_CARRIER);
+ }
+
+ @ParameterizedTest
+ @MethodSource("findFirstArguments")
+ void testFindFirst(Object[] carriers, String expected) {
+ assertEquals(expected, findFirst(STRING_KEY, carriers), "Cannot find first value");
+ }
+
+ static Stream findFirstArguments() {
+ return Stream.of(
+ arguments(emptyArray(), null),
+ arguments(arrayOf(NON_CARRIER), null),
+ arguments(arrayOf(UNSET_CARRIER), null),
+ arguments(arrayOf(CARRIER_1), VALUE_1),
+ arguments(arrayOf(CARRIER_1, CARRIER_2), VALUE_1),
+ arguments(arrayOf(NON_CARRIER, CARRIER_1), VALUE_1),
+ arguments(arrayOf(UNSET_CARRIER, CARRIER_1), VALUE_1),
+ arguments(arrayOf(CARRIER_1, NON_CARRIER), VALUE_1),
+ arguments(arrayOf(CARRIER_1, UNSET_CARRIER), VALUE_1));
+ }
+
+ @ParameterizedTest
+ @MethodSource("findAllArguments")
+ void testFindAll(Object[] carriers, Iterable expected) {
+ assertIterableEquals(expected, findAll(STRING_KEY, carriers), "Cannot find all values");
+ }
+
+ static Stream findAllArguments() {
+ return Stream.of(
+ arguments(emptyArray(), emptyList()),
+ arguments(arrayOf(CARRIER_1), singleton(VALUE_1)),
+ arguments(arrayOf(CARRIER_1, CARRIER_2), asList(VALUE_1, VALUE_2)),
+ arguments(arrayOf(NON_CARRIER, CARRIER_1), singleton(VALUE_1)),
+ arguments(arrayOf(UNSET_CARRIER, CARRIER_1), singleton(VALUE_1)),
+ arguments(arrayOf(CARRIER_1, NON_CARRIER), singleton(VALUE_1)),
+ arguments(arrayOf(CARRIER_1, UNSET_CARRIER), singleton(VALUE_1)));
+ }
+
+ @Test
+ void testNullCarriers() {
+ assertThrows(
+ NullPointerException.class, () -> findFirst(null, CARRIER_1), "Should fail on null key");
+ assertThrows(
+ NullPointerException.class,
+ () -> findFirst(STRING_KEY, (Object) null),
+ "Should fail on null context");
+ assertThrows(
+ NullPointerException.class,
+ () -> findFirst(STRING_KEY, null, CARRIER_1),
+ "Should fail on null context");
+ assertThrows(
+ NullPointerException.class, () -> findAll(null, CARRIER_1), "Should fail on null key");
+ assertThrows(
+ NullPointerException.class,
+ () -> findAll(STRING_KEY, (Object) null),
+ "Should fail on null context");
+ assertThrows(
+ NullPointerException.class,
+ () -> findAll(STRING_KEY, null, CARRIER_1),
+ "Should fail on null context");
+ }
+
+ @Test
+ void testCurrent() {
+ assertEquals(root(), current(), "Current context is already set");
+ Context context = root().with(STRING_KEY, VALUE_1);
+ try (ContextScope ignored = context.attach()) {
+ assertEquals(
+ VALUE_1, findFirst(STRING_KEY, CURRENT), "Failed to get value from current context");
+ assertIterableEquals(
+ singleton(VALUE_1),
+ findAll(STRING_KEY, CURRENT),
+ "Failed to get value from current context");
+ }
+ assertEquals(root(), current(), "Current context stayed attached");
+ }
+
+ @Test
+ void testCombine() {
+ Context context1 = root().with(STRING_KEY, VALUE_1).with(BOOLEAN_KEY, true);
+ Context context2 = root().with(STRING_KEY, VALUE_2).with(FLOAT_KEY, 3.14F);
+ Context context3 = root();
+ Context context4 = root().with(FLOAT_KEY, 567F);
+
+ Context combined = combine(context1, context2, context3, context4);
+ assertEquals(VALUE_1, combined.get(STRING_KEY), "First duplicate value should be kept");
+ assertEquals(true, combined.get(BOOLEAN_KEY), "Values from first context should be kept");
+ assertEquals(3.14f, combined.get(FLOAT_KEY), "Values from second context should be kept");
+ }
+
+ @Test
+ void testCombiner() {
+ ContextKey errorKey = ContextKey.named("error");
+ Context context1 = root().with(errorKey, ErrorStats.from(INFO, 12)).with(STRING_KEY, VALUE_1);
+ Context context2 = root().with(errorKey, ErrorStats.from(SEVERE, 1)).with(FLOAT_KEY, 3.14F);
+ Context context3 = root().with(errorKey, ErrorStats.from(WARNING, 6)).with(BOOLEAN_KEY, true);
+
+ BinaryOperator errorStatsMerger =
+ (left, right) -> {
+ ErrorStats mergedStats = ErrorStats.merge(left.get(errorKey), right.get(errorKey));
+ return left.with(errorKey, mergedStats);
+ };
+ Context combined = combine(errorStatsMerger, context1, context2, context3);
+ ErrorStats combinedStats = combined.get(errorKey);
+ assertNotNull(combinedStats, "Failed to combined error stats");
+ assertEquals(19, combinedStats.errorCount, "Failed to combine error stats");
+ assertEquals(SEVERE, combinedStats.maxLevel, "Failed to combine error stats");
+ assertNull(combined.get(STRING_KEY), "Combiner should drop any other context values");
+ assertNull(combined.get(FLOAT_KEY), "Combiner should drop any other context values");
+ assertNull(combined.get(BOOLEAN_KEY), "Combiner should drop any other context values");
+ }
+
+ @Test
+ void testNullCombine() {
+ assertThrows(
+ NullPointerException.class,
+ () -> combine((BinaryOperator) null, root()),
+ "Should fail on null combiner");
+ assertThrows(
+ NullPointerException.class,
+ () -> combine((left, right) -> left, (Context) null),
+ "Should fail on null context");
+ }
+
+ private static class ErrorStats {
+ int errorCount;
+ Level maxLevel;
+
+ public ErrorStats() {
+ this.errorCount = 0;
+ this.maxLevel = ALL;
+ }
+
+ public static ErrorStats from(Level logLevel, int count) {
+ ErrorStats stats = new ErrorStats();
+ stats.errorCount = count;
+ stats.maxLevel = logLevel;
+ return stats;
+ }
+
+ public static ErrorStats merge(ErrorStats a, ErrorStats b) {
+ if (a == null) {
+ return b;
+ }
+ Level maxLevel = a.maxLevel.intValue() > b.maxLevel.intValue() ? a.maxLevel : b.maxLevel;
+ return from(maxLevel, a.errorCount + b.errorCount);
+ }
+ }
+
+ private static Object[] emptyArray() {
+ return new Object[0];
+ }
+
+ private static Object[] arrayOf(Object... objects) {
+ return objects;
+ }
+}