Feature/241 deduplicate save memory 2 (#242)
* works

* Fix test

* Fix code style

* Fix scalastyle
kevinwallimann authored Oct 12, 2021
1 parent bab72cc commit eec8181
Showing 5 changed files with 96 additions and 43 deletions.
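Summary of the change: DeduplicateKafkaSinkTransformer previously kept full ConsumerRecord[GenericRecord, GenericRecord] instances in memory while comparing source and sink ids. This commit threads a pruning function through KafkaUtil so that each polled record is immediately reduced to a new PrunedConsumerRecord (topic, partition, offset, and only the configured id columns), keeping the memory footprint proportional to the id fields rather than the full Avro payloads. Below is a minimal Scala sketch of the pruning-function pattern the diff introduces; the wrapper object and the hard-coded "record_id" column are illustrative only (the column name is taken from the tests, and the transformer itself derives the columns from its configuration).

import org.apache.avro.generic.GenericRecord
import org.apache.kafka.clients.consumer.ConsumerRecord
import za.co.absa.hyperdrive.ingestor.implementation.transformer.deduplicate.kafka.PrunedConsumerRecord

object PruningFnSketch {
  // Reduce a full Avro consumer record to the coordinates and id values that
  // deduplication actually needs; the heavy GenericRecord payload is dropped here.
  val pruningFn: ConsumerRecord[GenericRecord, GenericRecord] => PrunedConsumerRecord = r =>
    PrunedConsumerRecord(r.topic(), r.partition(), r.offset(), Seq(r.value().get("record_id")))
}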
File 1 of 5: KafkaToKafkaDeduplicationAfterRetryDockerTest
@@ -19,18 +19,18 @@ import java.time.Duration
import java.util
import java.util.UUID.randomUUID
import java.util.{Collections, Properties}

import org.apache.avro.Schema.Parser
import org.apache.avro.generic.{GenericData, GenericRecord}
import org.apache.hadoop.fs.{FileSystem, Path}
import org.apache.kafka.clients.admin.{AdminClient, AdminClientConfig, NewTopic}
import org.apache.kafka.clients.consumer.KafkaConsumer
import org.apache.kafka.clients.consumer.{ConsumerRecord, KafkaConsumer}
import org.apache.kafka.clients.producer.{KafkaProducer, ProducerConfig, ProducerRecord}
import org.scalatest.{BeforeAndAfter, FlatSpec, Matchers}
import za.co.absa.abris.avro.read.confluent.SchemaManagerFactory
import za.co.absa.commons.io.TempDirectory
import za.co.absa.commons.spark.SparkTestBase
import za.co.absa.abris.avro.registry.SchemaSubject
import za.co.absa.hyperdrive.ingestor.implementation.transformer.deduplicate.kafka.PrunedConsumerRecord
import za.co.absa.hyperdrive.ingestor.implementation.utils.KafkaUtil
import za.co.absa.hyperdrive.shared.exceptions.IngestionException

@@ -61,6 +61,12 @@ class KafkaToKafkaDeduplicationAfterRetryDockerTest extends FlatSpec with Matche

private val fs = FileSystem.get(spark.sparkContext.hadoopConfiguration)
private var baseDir: TempDirectory = _
private val pruningFn = (r: ConsumerRecord[GenericRecord, GenericRecord]) => PrunedConsumerRecord(
r.topic(),
r.partition(),
r.offset(),
Seq(r.value().get("record_id"))
)

behavior of "CommandLineIngestionDriver"

@@ -79,10 +85,11 @@ class KafkaToKafkaDeduplicationAfterRetryDockerTest extends FlatSpec with Matche
executeTestCase(deduplicatorConfig, recordIdsV1, recordIdsV2, kafkaSchemaRegistryWrapper, destinationTopic)

val consumer = createConsumer(kafkaSchemaRegistryWrapper)
val records = getAllMessages(consumer, destinationTopic)
val valueFieldNames = records.head.value().getSchema.getFields.asScala.map(_.name())
val valueFieldNames = getValueSchema(consumer, destinationTopic).getFields.asScala.map(_.name())
val consumer2 = createConsumer(kafkaSchemaRegistryWrapper)
val records = getAllMessages(consumer2, destinationTopic, pruningFn)
val actualRecordIds = records.flatMap(_.data.map(_.asInstanceOf[Int]))
valueFieldNames should contain theSameElementsAs List("record_id", "value_field", "hyperdrive_id")
val actualRecordIds = records.map(_.value().get("record_id"))
actualRecordIds.distinct.size shouldBe actualRecordIds.size
actualRecordIds should contain theSameElementsAs recordIdsV1 ++ recordIdsV2
}
@@ -96,10 +103,11 @@ class KafkaToKafkaDeduplicationAfterRetryDockerTest extends FlatSpec with Matche
executeTestCase(Map(), recordIdsV1, recordIdsV2, kafkaSchemaRegistryWrapper, destinationTopic)

val consumer = createConsumer(kafkaSchemaRegistryWrapper)
val records = getAllMessages(consumer, destinationTopic)
val valueFieldNames = records.head.value().getSchema.getFields.asScala.map(_.name())
val valueFieldNames = getValueSchema(consumer, destinationTopic).getFields.asScala.map(_.name())
val consumer2 = createConsumer(kafkaSchemaRegistryWrapper)
val records = getAllMessages(consumer2, destinationTopic, pruningFn)
val actualRecordIds = records.flatMap(_.data)
valueFieldNames should contain theSameElementsAs List("record_id", "value_field", "hyperdrive_id")
val actualRecordIds = records.map(_.value().get("record_id"))
actualRecordIds.distinct.size should be < actualRecordIds.size
}

@@ -265,10 +273,15 @@ class KafkaToKafkaDeduplicationAfterRetryDockerTest extends FlatSpec with Matche
kafkaSchemaRegistryWrapper.createConsumer(props)
}

private def getAllMessages[K, V](consumer: KafkaConsumer[K, V], topic: String) = {
private def getValueSchema(consumer: KafkaConsumer[GenericRecord, GenericRecord], topic: String) = {
consumer.subscribe(Seq(topic).asJava)
consumer.poll(Duration.ofSeconds(10L)).asScala.head.value().getSchema
}

private def getAllMessages[K, V](consumer: KafkaConsumer[K, V], topic: String, pruningFn: ConsumerRecord[K, V] => PrunedConsumerRecord) = {
val topicPartitions = KafkaUtil.getTopicPartitions(consumer, topic)
val offsets = consumer.endOffsets(topicPartitions.asJava)
implicit val kafkaConsumerTimeout: Duration = Duration.ofSeconds(10L)
KafkaUtil.getMessagesAtLeastToOffset(consumer, offsets.asScala.mapValues(Long2long).toMap)
KafkaUtil.getMessagesAtLeastToOffset(consumer, offsets.asScala.mapValues(Long2long).toMap, pruningFn)
}
}
File 2 of 5: DeduplicateKafkaSinkTransformer
@@ -81,17 +81,17 @@ private[transformer] class DeduplicateKafkaSinkTransformer(
logOffsets(latestOffsetsOpt)

val sourceRecords = latestOffsetsOpt.map(latestOffset => consumeAndClose(sourceConsumer,
consumer => KafkaUtil.getMessagesAtLeastToOffset(consumer, latestOffset))).getOrElse(Seq())
val sourceIds = sourceRecords.map(extractIdFieldsFromRecord(_, sourceIdColumnNames))
consumer => KafkaUtil.getMessagesAtLeastToOffset(consumer, latestOffset, pruneRecord(sourceIdColumnNames)))).getOrElse(Seq())
val sourceIds = sourceRecords.map(_.data)

val sinkConsumer = createConsumer(writerBrokers, writerExtraOptions, encoderSchemaRegistryConfig)
val sinkTopicPartitions = KafkaUtil.getTopicPartitions(sinkConsumer, writerTopic)
val recordsPerPartition = sinkTopicPartitions.map(p => p -> sourceRecords.size.toLong).toMap
val latestSinkRecords = consumeAndClose(sinkConsumer, consumer =>
KafkaUtil.getAtLeastNLatestRecordsFromPartition(consumer, recordsPerPartition))
KafkaUtil.getAtLeastNLatestRecordsFromPartition(consumer, recordsPerPartition, pruneRecord(destinationIdColumnNames)))
logConsumedSinkRecords(latestSinkRecords)

val publishedIds = latestSinkRecords.map(extractIdFieldsFromRecord(_, destinationIdColumnNames))
val publishedIds = latestSinkRecords.map(_.data)
val duplicatedIds = sourceIds.intersect(publishedIds)
logDuplicatedIds(duplicatedIds)
val duplicatedIdsLit = duplicatedIds.map(duplicatedId => struct(duplicatedId.map(lit): _*))
@@ -123,8 +123,8 @@ private[transformer] class DeduplicateKafkaSinkTransformer(
logger.info(s"Reset source offsets by partition to { ${currentPositions} }")
}

private def logConsumedSinkRecords(latestSinkRecords: Seq[ConsumerRecord[GenericRecord, GenericRecord]]): Unit = {
val offsetsByPartition = latestSinkRecords.map(r => r.partition() -> r.offset())
private def logConsumedSinkRecords(latestSinkRecords: Seq[PrunedConsumerRecord]): Unit = {
val offsetsByPartition = latestSinkRecords.map(r => r.partition -> r.offset)
.groupBy(_._1)
.mapValues(_.map(_._2))
.toSeq
@@ -138,11 +138,18 @@ private[transformer] class DeduplicateKafkaSinkTransformer(
logger.info(s"Found ${duplicatedIds.size} duplicated ids. First three: ${duplicatedIds.take(3)}.")
}

private def extractIdFieldsFromRecord(record: ConsumerRecord[GenericRecord, GenericRecord], idColumnNames: Seq[String]): Seq[Any] = {
idColumnNames.map(idColumnName =>
AvroUtil.getFromConsumerRecord(record, idColumnName)
.getOrElse(throw new IllegalArgumentException(s"Could not find value for field $idColumnName"))
)
private def pruneRecord(idColumnNames: Seq[String]): ConsumerRecord[GenericRecord, GenericRecord] => PrunedConsumerRecord = {
record: ConsumerRecord[GenericRecord, GenericRecord] =>
val prunedPayload = idColumnNames.map(idColumnName =>
AvroUtil.getFromConsumerRecord(record, idColumnName)
.getOrElse(throw new IllegalArgumentException(s"Could not find value for field $idColumnName"))
)
PrunedConsumerRecord(
record.topic(),
record.partition(),
record.offset(),
prunedPayload
)
}

private def consumeAndClose[T](consumer: KafkaConsumer[GenericRecord, GenericRecord], consume: KafkaConsumer[GenericRecord, GenericRecord] => T) = {
File 3 of 5: PrunedConsumerRecord (new file)
@@ -0,0 +1,23 @@
/*
* Copyright 2018 ABSA Group Limited
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package za.co.absa.hyperdrive.ingestor.implementation.transformer.deduplicate.kafka

case class PrunedConsumerRecord(
topic: String,
partition: Int,
offset: Long,
data: Seq[Any]
)
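The case class above is the only data retained per consumed record. For context, here is a short sketch of how the reworked KafkaUtil API (changed in the next file) is driven with a pruning function, mirroring the Docker tests further down. The package, object, and column name are assumptions for illustration; a package under za.co.absa.hyperdrive is needed because KafkaUtil is private[hyperdrive].

package za.co.absa.hyperdrive.example // hypothetical package; KafkaUtil is private[hyperdrive]

import java.time.Duration
import org.apache.avro.generic.GenericRecord
import org.apache.kafka.clients.consumer.{ConsumerRecord, KafkaConsumer}
import org.apache.kafka.common.TopicPartition
import za.co.absa.hyperdrive.ingestor.implementation.transformer.deduplicate.kafka.PrunedConsumerRecord
import za.co.absa.hyperdrive.ingestor.implementation.utils.KafkaUtil

object KafkaUtilUsageSketch {
  // Consume everything up to the given offsets, keeping only the pruned projection of each record.
  def readPrunedIds(consumer: KafkaConsumer[GenericRecord, GenericRecord],
                    toOffsets: Map[TopicPartition, Long]): Seq[Any] = {
    implicit val kafkaConsumerTimeout: Duration = Duration.ofSeconds(10L)
    val pruningFn = (r: ConsumerRecord[GenericRecord, GenericRecord]) =>
      PrunedConsumerRecord(r.topic(), r.partition(), r.offset(), Seq(r.value().get("record_id")))
    val pruned: Seq[PrunedConsumerRecord] =
      KafkaUtil.getMessagesAtLeastToOffset(consumer, toOffsets, pruningFn)
    pruned.flatMap(_.data) // only the id values were ever buffered
  }
}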
File 4 of 5: KafkaUtil
@@ -22,26 +22,28 @@ import org.apache.logging.log4j.LogManager
import org.apache.spark.sql.execution.streaming.{CommitLog, OffsetSeqLog}
import org.apache.spark.sql.kafka010.KafkaSourceOffsetProxy
import za.co.absa.hyperdrive.compatibility.provider.CompatibleOffsetProvider
import za.co.absa.hyperdrive.ingestor.implementation.transformer.deduplicate.kafka.PrunedConsumerRecord

import scala.collection.JavaConverters._
import scala.collection.mutable

private[hyperdrive] object KafkaUtil {
private val logger = LogManager.getLogger

def getAtLeastNLatestRecordsFromPartition[K, V](consumer: KafkaConsumer[K, V], numberOfRecords: Map[TopicPartition, Long])
(implicit kafkaConsumerTimeout: Duration): Seq[ConsumerRecord[K, V]] = {
def getAtLeastNLatestRecordsFromPartition[K, V](consumer: KafkaConsumer[K, V], numberOfRecords: Map[TopicPartition, Long],
pruningFn: ConsumerRecord[K, V] => PrunedConsumerRecord)
(implicit kafkaConsumerTimeout: Duration): Seq[PrunedConsumerRecord] = {
consumer.assign(numberOfRecords.keySet.asJava)
val endOffsets = consumer.endOffsets(numberOfRecords.keySet.asJava).asScala.mapValues(Long2long)
val topicPartitions = endOffsets.keySet

var records: Seq[ConsumerRecord[K, V]] = Seq()
var records: Seq[PrunedConsumerRecord] = Seq()
val offsetLowerBounds = mutable.Map(endOffsets.toSeq: _*)
import scala.util.control.Breaks._
breakable {
while (true) {
val recordSizes = records
.groupBy(r => new TopicPartition(r.topic(), r.partition()))
.groupBy(r => new TopicPartition(r.topic, r.partition))
.mapValues(records => records.size)
val unfinishedPartitions = topicPartitions.filter(p => recordSizes.getOrElse(p, 0) < numberOfRecords(p) && offsetLowerBounds(p) != 0)
if (unfinishedPartitions.isEmpty) {
@@ -54,15 +56,16 @@ private[hyperdrive] object KafkaUtil {
offsetLowerBounds.foreach {
case (partition, offset) => consumer.seek(partition, offset)
}
records = getMessagesAtLeastToOffset(consumer, endOffsets.toMap)
records = getMessagesAtLeastToOffset(consumer, endOffsets.toMap, pruningFn)
}
}

records
}

def getMessagesAtLeastToOffset[K, V](consumer: KafkaConsumer[K, V], toOffsets: Map[TopicPartition, Long])
(implicit kafkaConsumerTimeout: Duration): Seq[ConsumerRecord[K, V]] = {
def getMessagesAtLeastToOffset[K, V](consumer: KafkaConsumer[K, V], toOffsets: Map[TopicPartition, Long],
pruningFn: ConsumerRecord[K, V] => PrunedConsumerRecord)
(implicit kafkaConsumerTimeout: Duration): Seq[PrunedConsumerRecord] = {
consumer.assign(toOffsets.keySet.asJava)
val endOffsets = consumer.endOffsets(toOffsets.keys.toSeq.asJava).asScala
endOffsets.foreach { case (topicPartition, offset) =>
@@ -74,11 +77,11 @@ }
}

import scala.util.control.Breaks._
var records: Seq[ConsumerRecord[K, V]] = mutable.Seq()
var records: Seq[PrunedConsumerRecord] = mutable.Seq()
breakable {
while (true) {
val newRecords = consumer.poll(kafkaConsumerTimeout).asScala.toSeq
records ++= newRecords
records ++= newRecords.map(pruningFn)
if (newRecords.isEmpty || offsetsHaveBeenReached(consumer, toOffsets)) {
break()
}
File 5 of 5: TestKafkaUtilDockerTest
@@ -19,7 +19,6 @@ import java.time.Duration
import java.util
import java.util.UUID.randomUUID
import java.util.{Collections, Properties}

import org.apache.kafka.clients.admin.{AdminClient, AdminClientConfig, NewTopic}
import org.apache.kafka.clients.consumer.{ConsumerRecord, KafkaConsumer}
import org.apache.kafka.clients.producer.{KafkaProducer, ProducerConfig, ProducerRecord}
@@ -28,6 +27,7 @@ import org.apache.kafka.common.serialization.{StringDeserializer, StringSerializ
import org.scalatest.{AppendedClues, BeforeAndAfter, FlatSpec, Matchers}
import org.testcontainers.containers.KafkaContainer
import org.testcontainers.utility.DockerImageName
import za.co.absa.hyperdrive.ingestor.implementation.transformer.deduplicate.kafka.PrunedConsumerRecord

import scala.collection.JavaConverters._
import scala.collection.mutable
@@ -39,6 +39,12 @@ class TestKafkaUtilDockerTest extends FlatSpec with Matchers with BeforeAndAfter
private val kafkaInsufficientTimeout = Duration.ofMillis(1L)
private val topic = "test-topic"
private val maxPollRecords = 10
private val pruningFn = (r: ConsumerRecord[String, String]) => PrunedConsumerRecord(
r.topic(),
r.partition(),
r.offset(),
Seq(r.value())
)

before{
kafka.start()
@@ -62,10 +68,10 @@ class TestKafkaUtilDockerTest extends FlatSpec with Matchers with BeforeAndAfter

// when
implicit val kafkaConsumerTimeout: Duration = kafkaSufficientTimeout
val records = KafkaUtil.getMessagesAtLeastToOffset(consumer, offsets)
val records = KafkaUtil.getMessagesAtLeastToOffset(consumer, offsets, pruningFn)

// then
val actualMessages = records.map(_.value()).toList.sorted
val actualMessages = records.map(_.data.head.asInstanceOf[String]).toList.sorted
actualMessages should contain theSameElementsAs messages
}

@@ -99,10 +105,10 @@ class TestKafkaUtilDockerTest extends FlatSpec with Matchers with BeforeAndAfter

// when
implicit val kafkaConsumerTimeout: Duration = kafkaSufficientTimeout
val records = KafkaUtil.getMessagesAtLeastToOffset(consumer, offsets)
val records = KafkaUtil.getMessagesAtLeastToOffset(consumer, offsets, pruningFn)

// then
val actualMessages = records.map(_.value()).toList.sorted
val actualMessages = records.map(_.data.head.asInstanceOf[String]).toList.sorted
actualMessages should contain allElementsOf messages

// cleanup
@@ -118,7 +124,7 @@ class TestKafkaUtilDockerTest extends FlatSpec with Matchers with BeforeAndAfter

// when
implicit val kafkaConsumerTimeout: Duration = kafkaInsufficientTimeout
val exception = the[Exception] thrownBy KafkaUtil.getMessagesAtLeastToOffset(consumer, Map(new TopicPartition(topic, 0) -> 0))
val exception = the[Exception] thrownBy KafkaUtil.getMessagesAtLeastToOffset(consumer, Map(new TopicPartition(topic, 0) -> 0), pruningFn)

// then
exception.getMessage should include ("Subscription to topics, partitions and pattern are mutually exclusive")
@@ -140,7 +146,7 @@ class TestKafkaUtilDockerTest extends FlatSpec with Matchers with BeforeAndAfter

// when
implicit val kafkaConsumerTimeout: Duration = kafkaInsufficientTimeout
val exception = the[Exception] thrownBy KafkaUtil.getMessagesAtLeastToOffset(consumer, offsets)
val exception = the[Exception] thrownBy KafkaUtil.getMessagesAtLeastToOffset(consumer, offsets, pruningFn)

// then
exception.getMessage should include ("Not all expected messages were consumed")
@@ -160,7 +166,7 @@ class TestKafkaUtilDockerTest extends FlatSpec with Matchers with BeforeAndAfter

// when
implicit val kafkaConsumerTimeout: Duration = kafkaInsufficientTimeout
val exception = the[Exception] thrownBy KafkaUtil.getMessagesAtLeastToOffset(consumer, offsets)
val exception = the[Exception] thrownBy KafkaUtil.getMessagesAtLeastToOffset(consumer, offsets, pruningFn)

// then
exception.getMessage should include ("Requested consumption")
@@ -209,8 +215,8 @@ class TestKafkaUtilDockerTest extends FlatSpec with Matchers with BeforeAndAfter
implicit val kafkaConsumerTimeout: Duration = kafkaSufficientTimeout
val topicPartitions = KafkaUtil.getTopicPartitions(consumer, topic)
val recordsPerPartition = topicPartitions.map(p => p -> 4L).toMap
val actualRecords = KafkaUtil.getAtLeastNLatestRecordsFromPartition(consumer, recordsPerPartition)
val values = actualRecords.map(_.value())
val actualRecords = KafkaUtil.getAtLeastNLatestRecordsFromPartition(consumer, recordsPerPartition, pruningFn)
val values = actualRecords.map(_.data.head.asInstanceOf[String])

values.size should be >= 12
values should contain allElementsOf Seq("msg_103", "msg_102", "msg_101", "msg_100", "msg_99", "msg_97", "msg_95",
@@ -231,10 +237,10 @@ class TestKafkaUtilDockerTest extends FlatSpec with Matchers with BeforeAndAfter
// when
implicit val kafkaConsumerTimeout: Duration = kafkaSufficientTimeout
val recordsPerPartition = topicPartitions.map(t => t -> 1000L).toMap
val records = KafkaUtil.getAtLeastNLatestRecordsFromPartition(consumer, recordsPerPartition)
val records = KafkaUtil.getAtLeastNLatestRecordsFromPartition(consumer, recordsPerPartition, pruningFn)

// then
val actualMessages = records.map(_.value()).toList.sorted
val actualMessages = records.map(_.data.head.asInstanceOf[String]).toList.sorted
actualMessages should contain theSameElementsAs messages
}

@@ -248,7 +254,8 @@ class TestKafkaUtilDockerTest extends FlatSpec with Matchers with BeforeAndAfter

val consumer = createConsumer(kafka)
implicit val kafkaConsumerTimeout: Duration = kafkaInsufficientTimeout
val result = the[Exception] thrownBy KafkaUtil.getAtLeastNLatestRecordsFromPartition(consumer, Map(new TopicPartition(topic, 0) -> 10))
val result = the[Exception] thrownBy KafkaUtil.getAtLeastNLatestRecordsFromPartition(consumer,
Map(new TopicPartition(topic, 0) -> 10), pruningFn)
result.getMessage should include("increasing the consumer timeout")
}
