diff --git a/debezium-server-kinesis/src/main/java/io/debezium/server/kinesis/KinesisChangeConsumer.java b/debezium-server-kinesis/src/main/java/io/debezium/server/kinesis/KinesisChangeConsumer.java index d5ac963c..c8132008 100644 --- a/debezium-server-kinesis/src/main/java/io/debezium/server/kinesis/KinesisChangeConsumer.java +++ b/debezium-server-kinesis/src/main/java/io/debezium/server/kinesis/KinesisChangeConsumer.java @@ -7,8 +7,11 @@ import java.net.URI; import java.time.Duration; +import java.util.ArrayList; import java.util.List; +import java.util.Map; import java.util.Optional; +import java.util.stream.Collectors; import jakarta.annotation.PostConstruct; import jakarta.annotation.PreDestroy; @@ -38,7 +41,10 @@ import software.amazon.awssdk.services.kinesis.KinesisClient; import software.amazon.awssdk.services.kinesis.KinesisClientBuilder; import software.amazon.awssdk.services.kinesis.model.KinesisException; -import software.amazon.awssdk.services.kinesis.model.PutRecordRequest; +import software.amazon.awssdk.services.kinesis.model.PutRecordsRequest; +import software.amazon.awssdk.services.kinesis.model.PutRecordsRequestEntry; +import software.amazon.awssdk.services.kinesis.model.PutRecordsResponse; +import software.amazon.awssdk.services.kinesis.model.PutRecordsResultEntry; /** * Implementation of the consumer that delivers the messages into Amazon Kinesis destination. @@ -56,12 +62,18 @@ public class KinesisChangeConsumer extends BaseChangeConsumer implements Debeziu private static final String PROP_REGION_NAME = PROP_PREFIX + "region"; private static final String PROP_ENDPOINT_NAME = PROP_PREFIX + "endpoint"; private static final String PROP_CREDENTIALS_PROFILE = PROP_PREFIX + "credentials.profile"; + private static final String PROP_BATCH_SIZE = PROP_PREFIX + "batch.size"; + private static final String PROP_RETRIES = PROP_PREFIX + "default.retries"; + + private static final int DEFAULT_RETRY_COUNT = 5; + private static final int MAX_BATCH_SIZE = 500; + private static final Duration RETRY_INTERVAL = Duration.ofSeconds(1); private String region; private Optional endpointOverride; private Optional credentialsProfile; - private static final int DEFAULT_RETRIES = 5; - private static final Duration RETRY_INTERVAL = Duration.ofSeconds(1); + private Integer batchSize; + private Integer maxRetries; @ConfigProperty(name = PROP_PREFIX + "null.key", defaultValue = "default") String nullKey; @@ -74,13 +86,23 @@ public class KinesisChangeConsumer extends BaseChangeConsumer implements Debeziu @PostConstruct void connect() { + final Config config = ConfigProvider.getConfig(); + batchSize = config.getOptionalValue(PROP_BATCH_SIZE, Integer.class).orElse(MAX_BATCH_SIZE); + maxRetries = config.getOptionalValue(PROP_RETRIES, Integer.class).orElse(DEFAULT_RETRY_COUNT); + + if (batchSize <= 0) { + throw new DebeziumException("Batch size must be greater than 0"); + } + else if (batchSize > MAX_BATCH_SIZE) { + throw new DebeziumException("Batch size must be less than or equal to MAX_BATCH_SIZE"); + } + if (customClient.isResolvable()) { client = customClient.get(); LOGGER.info("Obtained custom configured KinesisClient '{}'", client); return; } - final Config config = ConfigProvider.getConfig(); region = config.getValue(PROP_REGION_NAME, String.class); endpointOverride = config.getOptionalValue(PROP_ENDPOINT_NAME, String.class); credentialsProfile = config.getOptionalValue(PROP_CREDENTIALS_PROFILE, String.class); @@ -106,41 +128,103 @@ void close() { @Override public void handleBatch(List> records, RecordCommitter> committer) throws InterruptedException { - for (ChangeEvent record : records) { - LOGGER.trace("Received event '{}'", record); - - int attempts = 0; - while (!recordSent(record)) { - attempts++; - if (attempts >= DEFAULT_RETRIES) { - throw new DebeziumException("Exceeded maximum number of attempts to publish event " + record); + + // Guard if records are empty + if (records.isEmpty()) { + committer.markBatchFinished(); + return; + } + + String streamName; + List> batch = new ArrayList<>(); + + // Group the records by destination + Map>> segmentedBatches = records.stream().collect(Collectors.groupingBy(record -> record.destination())); + + // Iterate over the segmentedBatches + for (List> segmentedBatch : segmentedBatches.values()) { + // Iterate over the batch + + for (int i = 0; i < segmentedBatch.size(); i += batchSize) { + + // Create a sublist of the batch given the batchSize + batch = segmentedBatch.subList(i, Math.min(i + batchSize, segmentedBatch.size())); + List putRecordsRequestEntryList = new ArrayList<>(); + streamName = batch.get(0).destination(); + + for (ChangeEvent record : batch) { + + Object rv = record.value(); + if (rv == null) { + rv = ""; + } + PutRecordsRequestEntry putRecordsRequestEntry = PutRecordsRequestEntry.builder() + .partitionKey((record.key() != null) ? getString(record.key()) : nullKey) + .data(SdkBytes.fromByteArray(getBytes(rv))).build(); + putRecordsRequestEntryList.add(putRecordsRequestEntry); + } + + // Handle Error + boolean notSuccesful = true; + int attempts = 0; + List batchRequest = putRecordsRequestEntryList; + + while (notSuccesful) { + + if (attempts >= maxRetries) { + throw new DebeziumException("Exceeded maximum number of attempts to publish event"); + } + + try { + PutRecordsResponse response = recordsSent(batchRequest, streamName); + attempts++; + if (response.failedRecordCount() > 0) { + LOGGER.warn("Failed to send {} number of records, retrying", response.failedRecordCount()); + Metronome.sleeper(RETRY_INTERVAL, Clock.SYSTEM).pause(); + + final List putRecordsResults = response.records(); + List failedRecordsList = new ArrayList<>(); + + for (int index = 0; index < putRecordsResults.size(); index++) { + PutRecordsResultEntry entryResult = putRecordsResults.get(index); + if (entryResult.errorCode() != null) { + failedRecordsList.add(putRecordsRequestEntryList.get(index)); + } + } + batchRequest = failedRecordsList; + + } + else { + notSuccesful = false; + attempts = 0; + } + + } + catch (KinesisException exception) { + LOGGER.warn("Failed to send record to {}", streamName, exception); + attempts++; + Metronome.sleeper(RETRY_INTERVAL, Clock.SYSTEM).pause(); + } + } + + for (ChangeEvent record : batch) { + committer.markProcessed(record); } - Metronome.sleeper(RETRY_INTERVAL, Clock.SYSTEM).pause(); } - committer.markProcessed(record); } + + // Mark Batch Finished committer.markBatchFinished(); } - private boolean recordSent(ChangeEvent record) { - Object rv = record.value(); - if (rv == null) { - rv = ""; - } + private PutRecordsResponse recordsSent(List putRecordsRequestEntryList, String streamName) { - final PutRecordRequest putRecord = PutRecordRequest.builder() - .partitionKey((record.key() != null) ? getString(record.key()) : nullKey) - .streamName(streamNameMapper.map(record.destination())) - .data(SdkBytes.fromByteArray(getBytes(rv))) - .build(); + // Create a PutRecordsRequest + PutRecordsRequest putRecordsRequest = PutRecordsRequest.builder().streamName(streamNameMapper.map(streamName)).records(putRecordsRequestEntryList).build(); - try { - client.putRecord(putRecord); - return true; - } - catch (KinesisException exception) { - LOGGER.warn("Failed to send record to {}", record.destination(), exception); - return false; - } + // Send Request + PutRecordsResponse putRecordsResponse = client.putRecords(putRecordsRequest); + LOGGER.trace("Response Receieved: " + putRecordsResponse); + return putRecordsResponse; } } diff --git a/debezium-server-kinesis/src/test/java/io/debezium/server/kinesis/KinesisUnitTest.java b/debezium-server-kinesis/src/test/java/io/debezium/server/kinesis/KinesisUnitTest.java new file mode 100644 index 00000000..573890f3 --- /dev/null +++ b/debezium-server-kinesis/src/test/java/io/debezium/server/kinesis/KinesisUnitTest.java @@ -0,0 +1,333 @@ +/* + * Copyright Debezium Authors. + * + * Licensed under the Apache Software License version 2.0, available at http://www.apache.org/licenses/LICENSE-2.0 + */ +package io.debezium.server.kinesis; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertFalse; +import static org.junit.jupiter.api.Assertions.assertTrue; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.Mockito.doAnswer; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.reset; +import static org.mockito.Mockito.spy; +import static org.mockito.Mockito.when; + +import java.util.ArrayList; +import java.util.List; +import java.util.concurrent.atomic.AtomicBoolean; +import java.util.concurrent.atomic.AtomicInteger; +import java.util.stream.Collectors; + +import jakarta.enterprise.inject.Instance; + +import org.junit.jupiter.api.AfterEach; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; + +import io.debezium.engine.ChangeEvent; +import io.debezium.engine.DebeziumEngine.RecordCommitter; +import io.debezium.engine.Header; +import io.debezium.testing.testcontainers.PostgresTestResourceLifecycleManager; +import io.quarkus.test.common.QuarkusTestResource; +import io.quarkus.test.junit.QuarkusTest; + +import software.amazon.awssdk.auth.credentials.ProfileCredentialsProvider; +import software.amazon.awssdk.regions.Region; +import software.amazon.awssdk.services.kinesis.KinesisClient; +import software.amazon.awssdk.services.kinesis.model.KinesisException; +import software.amazon.awssdk.services.kinesis.model.PutRecordsRequest; +import software.amazon.awssdk.services.kinesis.model.PutRecordsRequestEntry; +import software.amazon.awssdk.services.kinesis.model.PutRecordsResponse; +import software.amazon.awssdk.services.kinesis.model.PutRecordsResultEntry; + +@QuarkusTest +@QuarkusTestResource(PostgresTestResourceLifecycleManager.class) +public class KinesisUnitTest { + + private KinesisChangeConsumer kinesisChangeConsumer; + private KinesisClient spyClient; + private AtomicInteger counter; + private AtomicBoolean threwException; + List> changeEvents; + RecordCommitter> committer; + + @BeforeEach + public void setup() { + counter = new AtomicInteger(0); + threwException = new AtomicBoolean(false); + changeEvents = createChangeEvents(500, "key", "destination"); + committer = RecordCommitter(); + spyClient = spy(KinesisClient.builder().region(Region.of(KinesisTestConfigSource.KINESIS_REGION)) + .credentialsProvider(ProfileCredentialsProvider.create("default")).build()); + + Instance mockInstance = mock(Instance.class); + when(mockInstance.isResolvable()).thenReturn(true); + when(mockInstance.get()).thenReturn(spyClient); + + kinesisChangeConsumer = new KinesisChangeConsumer(); + kinesisChangeConsumer.customClient = mockInstance; + } + + @AfterEach + public void tearDown() { + reset(spyClient); + } + + @SuppressWarnings({ "rawtypes", "unchecked" }) + private static List> createChangeEvents(int size, String key, String destination) { + List> changeEvents = new ArrayList<>(); + for (int i = 0; i < size; i++) { + ChangeEvent result = mock(ChangeEvent.class); + when(result.key()).thenReturn(key); + when(result.value()).thenReturn(Integer.toString(i)); + when(result.destination()).thenReturn(destination); + Header header = mock(Header.class); + when(header.getKey()).thenReturn(key); + when(header.getValue()).thenReturn(Integer.toString(i)); + when(result.headers()).thenReturn(List.of(header)); + changeEvents.add(result); + } + return changeEvents; + } + + @SuppressWarnings({ "unchecked" }) + private static RecordCommitter> RecordCommitter() { + RecordCommitter> result = mock(RecordCommitter.class); + return result; + } + + // 1. Test that continous sending of Kinesis response containing error yields exception after 5 attempts + @Test + public void testValidResponseWithErrorCode() throws Exception { + // Arrange + doAnswer(invocation -> { + PutRecordsRequest request = invocation.getArgument(0); + List records = request.records(); + counter.incrementAndGet(); + List failedEntries = records.stream().map(record -> PutRecordsResultEntry.builder().errorCode("ProvisionedThroughputExceededException") + .errorMessage("The request rate for the stream is too high").build()).collect(Collectors.toList()); + + return PutRecordsResponse.builder().failedRecordCount(records.size()).records(failedEntries).build(); + }).when(spyClient).putRecords(any(PutRecordsRequest.class)); + + // Act + try { + kinesisChangeConsumer.connect(); + kinesisChangeConsumer.handleBatch(changeEvents, RecordCommitter()); + } + catch (Exception e) { + threwException.getAndSet(true); + } + + // Assert + assertTrue(threwException.get()); + // DEFAULT_RETRIES is 5 times + assertEquals(5, counter.get()); + } + + // 2. Test that continous return of exception yields Debezium exception after 5 attempts + @Test + public void testExceptionWhileWritingData() throws Exception { + // Arrange + doAnswer(invocation -> { + counter.incrementAndGet(); + throw KinesisException.builder().message("Kinesis Exception").build(); + }).when(spyClient).putRecords(any(PutRecordsRequest.class)); + + // Act + try { + kinesisChangeConsumer.connect(); + kinesisChangeConsumer.handleBatch(changeEvents, committer); + } + catch (Exception e) { + threwException.getAndSet(true); + } + + // Assert + assertTrue(threwException.get()); + // DEFAULT_RETRIES is 5 times + assertEquals(5, counter.get()); + } + + // 3. Test that only failed records are re-sent + @Test + public void testResendFailedRecords() throws Exception { + // Arrange + AtomicBoolean firstCall = new AtomicBoolean(true); + List failedRecordsFromFirstCall = new ArrayList<>(); + List recordsFromSecondCall = new ArrayList<>(); + doAnswer(invocation -> { + List response = new ArrayList<>(); + PutRecordsRequest request = invocation.getArgument(0); + List records = request.records(); + counter.incrementAndGet(); + + if (firstCall.get()) { + int failedEntries = 100; + for (int i = 0; i < records.size(); i++) { + PutRecordsResultEntry recordResult; + if (i < failedEntries) { + recordResult = PutRecordsResultEntry.builder().errorCode("ProvisionedThroughputExceededException") + .errorMessage("The request rate for the stream is too high").build(); + + failedRecordsFromFirstCall.add(records.get(i)); + } + else { + recordResult = PutRecordsResultEntry.builder().shardId("shardId").sequenceNumber("sequenceNumber").build(); + } + response.add(recordResult); + } + firstCall.getAndSet(false); + return PutRecordsResponse.builder().failedRecordCount(failedEntries).records(response).build(); + } + else { + for (PutRecordsRequestEntry record : records) { + recordsFromSecondCall.add(record); + PutRecordsResultEntry recordResult = PutRecordsResultEntry.builder().shardId("shardId").sequenceNumber("sequenceNumber").build(); + response.add(recordResult); + } + return PutRecordsResponse.builder().failedRecordCount(0).records(response).build(); + } + }).when(spyClient).putRecords(any(PutRecordsRequest.class)); + + // Act + try { + kinesisChangeConsumer.connect(); + kinesisChangeConsumer.handleBatch(changeEvents, committer); + } + catch (Exception e) { + threwException.getAndSet(true); + } + + // Assert + assertFalse(threwException.get()); + assertEquals(2, counter.get()); + assertEquals(recordsFromSecondCall.size(), failedRecordsFromFirstCall.size()); + for (int i = 0; i < recordsFromSecondCall.size(); i++) { + assertEquals(failedRecordsFromFirstCall.get(i).data(), recordsFromSecondCall.get(i).data()); + } + } + + // 4. Create 600 ChangeEvents to destination 1 and 600 to destination 2 and test that they are correctly batched + @Test + public void testBatchesAreCorrect() throws Exception { + // Arrange + List> changeEvents = new ArrayList<>(); + String destinationOne = "dest1"; + String destinationTwo = "dest2"; + + // call createEvents with 600 records for destination 1 and 600 records for destination 2 + changeEvents = createChangeEvents(600, destinationOne, destinationOne); + changeEvents.addAll(createChangeEvents(600, destinationTwo, destinationTwo)); + + AtomicInteger numRecordsDestinationOne = new AtomicInteger(0); + AtomicInteger numRrecordsDestinationTwo = new AtomicInteger(0); + AtomicInteger numBatches = new AtomicInteger(0); + + doAnswer(invocation -> { + List response = new ArrayList<>(); + PutRecordsRequest request = invocation.getArgument(0); + List records = request.records(); + for (PutRecordsRequestEntry record : records) { + if (record.partitionKey().equals(destinationOne)) { + numRecordsDestinationOne.incrementAndGet(); + } + else if (record.partitionKey().equals(destinationTwo)) { + numRrecordsDestinationTwo.incrementAndGet(); + } + PutRecordsResultEntry recordResult = PutRecordsResultEntry.builder().shardId("shardId").sequenceNumber("sequenceNumber").build(); + response.add(recordResult); + } + numBatches.incrementAndGet(); + return PutRecordsResponse.builder().failedRecordCount(0).records(response).build(); + }).when(spyClient).putRecords(any(PutRecordsRequest.class)); + + // Act + try { + kinesisChangeConsumer.connect(); + kinesisChangeConsumer.handleBatch(changeEvents, committer); + } + catch (Exception e) { + threwException.getAndSet(true); + } + + // Assert + // No exception should be thrown + assertFalse(threwException.get()); + // 2 destinations, 600 records each + assertEquals(600, numRecordsDestinationOne.get()); + assertEquals(600, numRrecordsDestinationTwo.get()); + // 2 destinations, 2 batches each + assertEquals(4, numBatches.get()); + } + + // 5. Test that empty records are handled correctly + @Test + public void testEmptyRecords() throws Exception { + // Arrange + List> changeEvents = new ArrayList<>(); + + // Act + try { + kinesisChangeConsumer.connect(); + kinesisChangeConsumer.handleBatch(changeEvents, committer); + } + catch (Exception e) { + threwException.getAndSet(true); + } + + // Assert + assertFalse(threwException.get()); + } + + // 6. Test that a batch of 1000 records is correctly split into 2 batches of 500 records + @Test + public void testBatchSplitting() throws Exception { + // Arrange + List> changeEvents = createChangeEvents(1000, "key", "destination"); + + AtomicInteger numBatches = new AtomicInteger(0); + AtomicInteger numRecordsBatchOne = new AtomicInteger(0); + AtomicInteger numRecordsBatchTwo = new AtomicInteger(0); + AtomicBoolean firstBatch = new AtomicBoolean(true); + + doAnswer(invocation -> { + List response = new ArrayList<>(); + PutRecordsRequest request = invocation.getArgument(0); + List records = request.records(); + + for (PutRecordsRequestEntry record : records) { + if (firstBatch.get()) { + numRecordsBatchOne.incrementAndGet(); + } + else { + numRecordsBatchTwo.incrementAndGet(); + } + PutRecordsResultEntry recordResult = PutRecordsResultEntry.builder().shardId("shardId").sequenceNumber("sequenceNumber").build(); + response.add(recordResult); + } + numBatches.incrementAndGet(); + firstBatch.getAndSet(false); + return PutRecordsResponse.builder().failedRecordCount(0).records(response).build(); + }).when(spyClient).putRecords(any(PutRecordsRequest.class)); + + // Act + try { + kinesisChangeConsumer.connect(); + kinesisChangeConsumer.handleBatch(changeEvents, committer); + } + catch (Exception e) { + threwException.getAndSet(true); + } + + // Assert + assertFalse(threwException.get()); + assertEquals(2, numBatches.get()); + assertEquals(500, numRecordsBatchOne.get()); + assertEquals(500, numRecordsBatchTwo.get()); + } + +}