diff --git a/consumer/src/main/java/com/flipkart/varadhi/consumer/GroupedMessageSrc.java b/consumer/src/main/java/com/flipkart/varadhi/consumer/GroupedMessageSrc.java index fe44ad14..f677d7d3 100644 --- a/consumer/src/main/java/com/flipkart/varadhi/consumer/GroupedMessageSrc.java +++ b/consumer/src/main/java/com/flipkart/varadhi/consumer/GroupedMessageSrc.java @@ -7,6 +7,7 @@ import com.flipkart.varadhi.spi.services.PolledMessages; import lombok.AllArgsConstructor; import lombok.RequiredArgsConstructor; +import lombok.extern.slf4j.Slf4j; import org.apache.commons.lang3.StringUtils; import org.apache.commons.lang3.mutable.MutableBoolean; @@ -14,28 +15,46 @@ import java.util.concurrent.CompletableFuture; import java.util.concurrent.ConcurrentHashMap; import java.util.concurrent.ConcurrentLinkedDeque; +import java.util.concurrent.atomic.AtomicBoolean; import java.util.concurrent.atomic.AtomicLong; +import java.util.concurrent.atomic.AtomicReference; /** * Message source that maintains ordering among messages of the same groupId. */ @RequiredArgsConstructor +@Slf4j public class GroupedMessageSrc implements MessageSrc { private final ConcurrentHashMap allGroupedMessages = new ConcurrentHashMap<>(); + private final ConcurrentLinkedDeque freeGroups = new ConcurrentLinkedDeque<>(); + private final Consumer consumer; + + /** + * Used to limit the message buffering. Will be driven via consumer configuration. + */ + private final long maxUnAckedMessages; + /** * Maintains the count of total messages read from the consumer so far. * Required for watermark checks, for when this value runs low we can fetch more messages from the consumer. * Counter gets decremented when the message is committed/consumed. */ - private final AtomicLong totalInFlightMessages = new AtomicLong(0); + private final AtomicLong totalUnAckedMessages = new AtomicLong(0); - // Used for watermark checks against the totalInFlightMessages. Will be driven via consumer configuration. - private final long maxInFlightMessages = 100; // todo(aayush): make configurable + // Internal states to manage async state - private final Consumer consumer; + /** + * flag to indicate whether a task to fetch messages from consumer is ongoing. + */ + private final AtomicBoolean pendingAsyncFetch = new AtomicBoolean(false); + + /** + * holder to keep the incomplete future object while waiting for new messages or groups to get freed up. + */ + private final AtomicReference pendingRequest = new AtomicReference<>(); /** * Attempt to fill the message array with one message from each group. @@ -47,22 +66,65 @@ public class GroupedMessageSrc implements MessageSrc { */ @Override public CompletableFuture nextMessages(MessageTracker[] messages) { - if (!hasMaxInFlightMessages()) { - return replenishAvailableGroups().thenApply(v -> nextMessagesInternal(messages)); + int count = nextMessagesInternal(messages); + if (count > 0) { + return CompletableFuture.completedFuture(count); + } + + NextMsgsRequest request = new NextMsgsRequest(new CompletableFuture<>(), messages); + if (!pendingRequest.compareAndSet(null, request)) { + throw new IllegalStateException( + "nextMessages method is not supposed to be called concurrently. There seems to be a pending nextMessage call"); + } + + // incomplete result is saved. trigger new message fetch. + optionallyFetchNewMessages(); + + // double check, if any free group is available now. + if (isFreeGroupPresent()) { + tryCompletePendingRequest(); + } + + return request.result; + } + + private void tryCompletePendingRequest() { + NextMsgsRequest request; + if ((request = pendingRequest.getAndSet(null)) != null) { + request.result.complete(nextMessagesInternal(request.messages)); + } + } + + private void optionallyFetchNewMessages() { + if (!isMaxUnAckedMessagesBreached() && pendingAsyncFetch.compareAndSet(false, true)) { + // there is more room for new messages. We can initiate a new fetch request, as none is ongoing. + consumer.receiveAsync().whenComplete((polledMessages, ex) -> { + if (ex != null) { + replenishAvailableGroups(polledMessages); + pendingAsyncFetch.set(false); + } else { + log.error("Error while fetching messages from consumer", ex); + throw new IllegalStateException( + "should be unreachable. consumer.receiveAsync() should not throw exception."); + } + }); } - return CompletableFuture.completedFuture(nextMessagesInternal(messages)); } private int nextMessagesInternal(MessageTracker[] messages) { int i = 0; GroupTracker groupTracker; - while (i < messages.length && (groupTracker = getGroupTracker()) != null) { + while (i < messages.length && (groupTracker = pollFreeGroup()) != null) { messages[i++] = new GroupedMessageTracker(groupTracker.messages.getFirst().nextMessage()); } return i; } - private GroupTracker getGroupTracker() { + boolean isFreeGroupPresent() { + return !freeGroups.isEmpty(); + } + + private GroupTracker pollFreeGroup() { String freeGroup = freeGroups.poll(); if (freeGroup == null) { return null; @@ -77,13 +139,6 @@ private GroupTracker getGroupTracker() { return tracker; } - private CompletableFuture replenishAvailableGroups() { - return consumer.receiveAsync().thenApply(polledMessages -> { - replenishAvailableGroups(polledMessages); - return null; - }); - } - private void replenishAvailableGroups(PolledMessages polledMessages) { Map> groupedMessages = groupMessagesByGroupId(polledMessages); for (Map.Entry> group : groupedMessages.entrySet()) { @@ -97,11 +152,12 @@ private void replenishAvailableGroups(PolledMessages polledMessages) { tracker.messages.add(newBatch); return tracker; }); - totalInFlightMessages.addAndGet(newBatch.count()); + totalUnAckedMessages.addAndGet(newBatch.count()); if (isNewGroup.isTrue()) { freeGroups.add(group.getKey()); } } + tryCompletePendingRequest(); } private Map> groupMessagesByGroupId(PolledMessages polledMessages) { @@ -117,8 +173,8 @@ private Map> groupMessagesByGroupId(PolledMessages< return groups; } - private boolean hasMaxInFlightMessages() { - return totalInFlightMessages.get() >= maxInFlightMessages; + boolean isMaxUnAckedMessagesBreached() { + return totalUnAckedMessages.get() >= maxUnAckedMessages; } enum GroupStatus { @@ -158,7 +214,7 @@ private void free(String groupId, MessageConsumptionStatus status) { throw new IllegalStateException(String.format("Tried to free group %s: %s", gId, tracker)); } var messages = tracker.messages; - if (!messages.isEmpty() && messages.getFirst().remaining() == 0) { + while (!messages.isEmpty() && messages.getFirst().remaining() == 0) { messages.removeFirst(); } if (!messages.isEmpty()) { @@ -169,10 +225,14 @@ private void free(String groupId, MessageConsumptionStatus status) { return null; } }); - totalInFlightMessages.decrementAndGet(); + totalUnAckedMessages.decrementAndGet(); if (isRemaining.isTrue()) { freeGroups.addFirst(groupId); + tryCompletePendingRequest(); } } } + + record NextMsgsRequest(CompletableFuture result, MessageTracker[] messages) { + } } diff --git a/consumer/src/main/java/com/flipkart/varadhi/consumer/UnGroupedMessageSrc.java b/consumer/src/main/java/com/flipkart/varadhi/consumer/UnGroupedMessageSrc.java index 500720d1..0774659d 100644 --- a/consumer/src/main/java/com/flipkart/varadhi/consumer/UnGroupedMessageSrc.java +++ b/consumer/src/main/java/com/flipkart/varadhi/consumer/UnGroupedMessageSrc.java @@ -18,15 +18,19 @@ public class UnGroupedMessageSrc implements MessageSrc { private final Consumer consumer; - // flag to indicate whether a future is in progress to fetch messages from the consumer. - private final AtomicBoolean futureInProgress = new AtomicBoolean(false); + /** + * flag to indicate whether a task to fetch messages from consumer is ongoing. + */ + private final AtomicBoolean pendingAsyncFetch = new AtomicBoolean(false); - // Iterator into an ongoing consumer batch that has not been fully processed yet. - private Iterator> ongoingIterator = null; + /** + * Iterator into an ongoing consumer batch that has not been fully processed yet. + */ + private volatile Iterator> ongoingIterator = null; /** * Fetches the next batch of messages from the consumer. - * Prioritise immediate fetch and return over waiting for the consumer. + * Prioritises returning whatever messages are available. * * @param messages Array of message trackers to populate. * @@ -38,44 +42,53 @@ public CompletableFuture nextMessages(MessageTracker[] messages) { // Our first priority is to drain the iterator if it is set and return immediately. // We do not want to proceed with consumer receiveAsync if we have messages in the iterator, // as a slow or empty consumer might block the flow and cause the iterator contents to be stuck. - int offset = fetchFromIterator(ongoingIterator, messages, 0); - if (offset > 0) { - return CompletableFuture.completedFuture(offset); + int count = fetchFromIterator(consumer, messages, ongoingIterator); + if (count > 0) { + return CompletableFuture.completedFuture(count); } // If the iterator is not set, or is empty, then we try to fetch the message batch from the consumer. // However, multiple calls to nextMessages may fire multiple futures concurrently. // Leading to a race condition that overrides the iterator from a previous un-processed batch, causing a lost-update problem. // Therefore, we use the futureInProgress flag to limit the concurrency and ensure only one future is in progress at a time. - if (futureInProgress.compareAndSet(false, true)) { - return consumer.receiveAsync() - .thenApply(polledMessages -> processPolledMessages(polledMessages, messages, offset)) - .whenComplete((result, ex) -> futureInProgress.set( - false)); // any of the above stages can complete exceptionally, so this is to ensure the flag is reset. + ongoingIterator = null; + if (pendingAsyncFetch.compareAndSet(false, true)) { + return consumer.receiveAsync().whenComplete((result, ex) -> pendingAsyncFetch.set(false)) + .thenApply(polledMessages -> processPolledMessages(polledMessages, messages)); + } else { + throw new IllegalStateException( + "nextMessages method is not supposed to be called concurrently. There seems to be an ongoing consumer.receiveAsync() operation."); } - return CompletableFuture.completedFuture(0); } - private int processPolledMessages(PolledMessages polledMessages, MessageTracker[] messages, int startIndex) { - ongoingIterator = polledMessages.iterator(); - return fetchFromIterator(ongoingIterator, messages, startIndex); + private int processPolledMessages(PolledMessages polledMessages, MessageTracker[] messages) { + Iterator> polledMessagesIterator = polledMessages.iterator(); + ongoingIterator = polledMessagesIterator; + return fetchFromIterator(consumer, messages, polledMessagesIterator); } /** * Fetches messages from the iterator and populates the message array. * - * @param iterator Iterator of messages to fetch from. - * @param messages Array of message trackers to populate. - * @param startIndex Index into the messages array from where to start storing the messages. + * @param iterator Iterator of messages to fetch from. + * @param messages Array of message trackers to populate. * * @return Index into the messages array where the next message should be stored. (will be equal to the length if completely full) */ - private int fetchFromIterator( - Iterator> iterator, MessageTracker[] messages, int startIndex + + static int fetchFromIterator( + Consumer consumer, MessageTracker[] messages, Iterator> iterator ) { - while (iterator != null && iterator.hasNext() && startIndex < messages.length) { - messages[startIndex++] = new PolledMessageTracker<>(consumer, iterator.next()); + if (iterator == null || !iterator.hasNext()) { + return 0; + } + + int i = 0; + while (i < messages.length && iterator.hasNext()) { + PolledMessage polledMessage = iterator.next(); + MessageTracker messageTracker = new PolledMessageTracker<>(consumer, polledMessage); + messages[i++] = messageTracker; } - return startIndex; + return i; } } diff --git a/consumer/src/main/java/com/flipkart/varadhi/consumer/impl/SlidingWindowThrottler.java b/consumer/src/main/java/com/flipkart/varadhi/consumer/impl/SlidingWindowThrottler.java index 2cac8508..00d89c0e 100644 --- a/consumer/src/main/java/com/flipkart/varadhi/consumer/impl/SlidingWindowThrottler.java +++ b/consumer/src/main/java/com/flipkart/varadhi/consumer/impl/SlidingWindowThrottler.java @@ -1,7 +1,7 @@ package com.flipkart.varadhi.consumer.impl; -import com.flipkart.varadhi.consumer.ThresholdProvider; import com.flipkart.varadhi.consumer.InternalQueueType; +import com.flipkart.varadhi.consumer.ThresholdProvider; import com.flipkart.varadhi.consumer.Throttler; import com.google.common.base.Ticker; import lombok.RequiredArgsConstructor; @@ -18,7 +18,8 @@ * @param */ @Slf4j -public class SlidingWindowThrottler implements Throttler, ThresholdProvider.ThresholdChangeListener, AutoCloseable { +public class SlidingWindowThrottler + implements Throttler, ThresholdProvider.ThresholdChangeListener, AutoCloseable { /* Approach: @@ -172,7 +173,6 @@ private boolean moveWindow(long currentTick) { if (newWindowBeginTick == windowBeginTick) { return false; } - for (long i = windowBeginTick; i < newWindowBeginTick; ++i) { int beginIdx = (int) (i % totalTicks); int endIdx = (int) ((i + ticksInWindow) % totalTicks); diff --git a/consumer/src/test/java/com/flipkart/varadhi/consumer/ConcurrencyAndRLDemo.java b/consumer/src/test/java/com/flipkart/varadhi/consumer/ConcurrencyAndRLDemo.java index b89c4ffc..26a36522 100644 --- a/consumer/src/test/java/com/flipkart/varadhi/consumer/ConcurrencyAndRLDemo.java +++ b/consumer/src/test/java/com/flipkart/varadhi/consumer/ConcurrencyAndRLDemo.java @@ -1,7 +1,7 @@ package com.flipkart.varadhi.consumer; -import com.codahale.metrics.*; import com.codahale.metrics.Timer; +import com.codahale.metrics.*; import com.flipkart.varadhi.consumer.concurrent.Context; import com.flipkart.varadhi.consumer.concurrent.CustomThread; import com.flipkart.varadhi.consumer.concurrent.EventExecutor; @@ -103,9 +103,16 @@ public static void doSimulation( Meter loadGenMeter = registry.register("load.gen.rate", new Meter()); Meter errorExpMeter = registry.register("task.error.experienced.rate", new Meter()); - Timer completionLatency = registry.register("task.completion.latency", new Timer(new SlidingTimeWindowArrayReservoir(60, TimeUnit.SECONDS))); - Timer throttlerAcquireLatency = registry.register("throttler.acquire.latency", new Timer(new SlidingTimeWindowArrayReservoir(60, TimeUnit.SECONDS))); - Gauge errorThresholdGuage = registry.registerGauge("error.threshold.value", dynamicThreshold::getThreshold); + Timer completionLatency = registry.register( + "task.completion.latency", + new Timer(new SlidingTimeWindowArrayReservoir(60, TimeUnit.SECONDS)) + ); + Timer throttlerAcquireLatency = registry.register( + "throttler.acquire.latency", + new Timer(new SlidingTimeWindowArrayReservoir(60, TimeUnit.SECONDS)) + ); + Gauge errorThresholdGuage = + registry.registerGauge("error.threshold.value", dynamicThreshold::getThreshold); if (metricListener != null) { websocketScheduler.scheduleAtFixedRate(() -> { Map datapoints = new HashMap<>(); @@ -113,7 +120,7 @@ public static void doSimulation( datapoints.put("task.completion.rate", completionLatency.getOneMinuteRate()); datapoints.put("error.threshold.value", (double) errorThresholdGuage.getValue()); metricListener.accept(datapoints); - }, 1_000,2_000, TimeUnit.MILLISECONDS); + }, 1_000, 2_000, TimeUnit.MILLISECONDS); } reporter.start(2, TimeUnit.SECONDS); AtomicInteger throttlePending = new AtomicInteger(0); diff --git a/consumer/src/test/java/com/flipkart/varadhi/consumer/UnGroupedMessageSrcTest.java b/consumer/src/test/java/com/flipkart/varadhi/consumer/UnGroupedMessageSrcTest.java index 8f4c8b05..c27edc71 100644 --- a/consumer/src/test/java/com/flipkart/varadhi/consumer/UnGroupedMessageSrcTest.java +++ b/consumer/src/test/java/com/flipkart/varadhi/consumer/UnGroupedMessageSrcTest.java @@ -3,13 +3,13 @@ import com.flipkart.varadhi.spi.services.DummyConsumer; import com.flipkart.varadhi.spi.services.DummyProducer.DummyOffset; import lombok.extern.slf4j.Slf4j; +import org.junit.jupiter.api.Assertions; import org.junit.jupiter.api.Test; import java.util.ArrayList; import java.util.List; -import static org.junit.jupiter.api.Assertions.assertEquals; -import static org.junit.jupiter.api.Assertions.assertTrue; +import static org.junit.jupiter.api.Assertions.*; @Slf4j class UnGroupedMessageSrcTest { @@ -152,9 +152,15 @@ void testConcurrencyInConsumerFetchNotAllowed() { DummyConsumer.SlowConsumer consumer = new DummyConsumer.SlowConsumer(messages, 3); UnGroupedMessageSrc messageSrc = new UnGroupedMessageSrc<>(consumer); var f1 = messageSrc.nextMessages(messageTrackers); - var f2 = messageSrc.nextMessages(messageTrackers); - assertEquals(0, f2.join()); + try { + assertFalse(f1.isDone()); + var f2 = messageSrc.nextMessages(messageTrackers); + Assertions.fail("concurrent invocation is not expected."); + } catch (IllegalStateException e) { + // expected. + } + assertEquals(messageTrackers.length, f1.join()); // since f1 is completed now, next invocation should return remaining messages diff --git a/consumer/src/test/java/com/flipkart/varadhi/consumer/impl/SlidingWindowThrottlerTest.java b/consumer/src/test/java/com/flipkart/varadhi/consumer/impl/SlidingWindowThrottlerTest.java index 9b380857..b9b0e2da 100644 --- a/consumer/src/test/java/com/flipkart/varadhi/consumer/impl/SlidingWindowThrottlerTest.java +++ b/consumer/src/test/java/com/flipkart/varadhi/consumer/impl/SlidingWindowThrottlerTest.java @@ -79,10 +79,14 @@ public void testExecutePendingTasksFollowsRateLimit() throws Exception { expectedCompleted += (qps / 2); assertions.accept(expectedCompleted); + // there is a gotcha here. advancing 5 more ms, will lead to tick change. so new tasks's permits will get + // assigned to new tick bucket. ticker.advance(5, TimeUnit.MILLISECONDS); expectedCompleted += (qps / 2); assertions.accept(expectedCompleted); + // now after 1 sec exact, the previous 5 tasks's permits will start getting released. + // so if we wait another 10 ms, all permits should be freed up. ticker.advance(1010, TimeUnit.MILLISECONDS); expectedCompleted += qps; assertions.accept(expectedCompleted);