diff --git a/streaming-pubsub/src/main/scala/org/apache/spark/streaming/pubsub/PubsubInputDStream.scala b/streaming-pubsub/src/main/scala/org/apache/spark/streaming/pubsub/PubsubInputDStream.scala
index 7357d231..18c6b2b6 100644
--- a/streaming-pubsub/src/main/scala/org/apache/spark/streaming/pubsub/PubsubInputDStream.scala
+++ b/streaming-pubsub/src/main/scala/org/apache/spark/streaming/pubsub/PubsubInputDStream.scala
@@ -20,16 +20,19 @@ package org.apache.spark.streaming.pubsub
 import java.io.{Externalizable, ObjectInput, ObjectOutput}
 
 import scala.collection.JavaConverters._
+import scala.collection.mutable.ArrayBuffer
 import scala.util.control.NonFatal
 
 import com.google.api.client.googleapis.javanet.GoogleNetHttpTransport
 import com.google.api.client.googleapis.json.GoogleJsonResponseException
 import com.google.api.client.json.jackson2.JacksonFactory
 import com.google.api.services.pubsub.Pubsub.Builder
-import com.google.api.services.pubsub.model.{AcknowledgeRequest, PubsubMessage, PullRequest}
-import com.google.api.services.pubsub.model.Subscription
+import com.google.api.services.pubsub.model.{AcknowledgeRequest, PubsubMessage, PullRequest, ReceivedMessage, Subscription}
 import com.google.cloud.hadoop.util.RetryHttpInitializer
+import com.google.common.util.concurrent.RateLimiter
 
+import org.apache.spark.{SparkConf, SparkException}
+import org.apache.spark.internal.Logging
 import org.apache.spark.storage.StorageLevel
 import org.apache.spark.streaming.StreamingContext
 import org.apache.spark.streaming.dstream.ReceiverInputDStream
@@ -51,11 +54,18 @@ class PubsubInputDStream(
     val subscription: String,
     val credential: SparkGCPCredentials,
     val _storageLevel: StorageLevel,
-    val autoAcknowledge: Boolean
+    val autoAcknowledge: Boolean,
+    val maxNoOfMessageInRequest: Int,
+    val rateMultiplierFactor: Double,
+    val endpoint: String,
+    conf: SparkConf
   ) extends ReceiverInputDStream[SparkPubsubMessage](_ssc) {
 
   override def getReceiver(): Receiver[SparkPubsubMessage] = {
-    new PubsubReceiver(project, topic, subscription, credential, _storageLevel, autoAcknowledge)
+    new PubsubReceiver(
+      project, topic, subscription, credential, _storageLevel, autoAcknowledge,
+      maxNoOfMessageInRequest, rateMultiplierFactor, endpoint, conf
+    )
   }
 }
 
@@ -214,7 +224,35 @@ object ConnectionUtils {
     }
   }
 
-
+/**
+ * Custom Spark receiver to pull messages from a Pub/Sub topic and push them into a reliable store.
+ * If backpressure is enabled, the message ingestion rate for this receiver will be managed by Spark.
+ *
+ * The following Spark configurations can be used to control the rate and block size:
+ * spark.streaming.backpressure.initialRate
+ * spark.streaming.receiver.maxRate
+ * spark.streaming.blockQueueSize: controls the block size
+ * spark.streaming.backpressure.pid.minRate
+ *
+ * See the Spark Streaming configuration documentation.
+ */
 class PubsubReceiver(
     project: String,
     topic: Option[String],
     subscription: String,
     credential: SparkGCPCredentials,
     storageLevel: StorageLevel,
-    autoAcknowledge: Boolean)
-    extends Receiver[SparkPubsubMessage](storageLevel) {
+    autoAcknowledge: Boolean,
+    maxNoOfMessageInRequest: Int,
+    rateMultiplierFactor: Double,
+    endpoint: String,
+    conf: SparkConf)
+    extends Receiver[SparkPubsubMessage](storageLevel) with Logging {
@@ -262,6 +316,7 @@
       }
     case None => // do nothing
   }
+
   new Thread() {
     override def run() {
       receive()
@@ -270,30 +325,31 @@
     }
   }
 
   def receive(): Unit = {
-    val pullRequest = new PullRequest().setMaxMessages(MAX_MESSAGE).setReturnImmediately(false)
+    val pullRequest = new PullRequest()
+      .setMaxMessages(maxNoOfMessageInRequest).setReturnImmediately(false)
     var backoff = INIT_BACKOFF
+
+    // Avoid the edge case where the buffer never fills and no message is pushed to the store
+    latestAttemptToPushInStoreTime = System.currentTimeMillis()
+
     while (!isStopped()) {
       try {
+
         val pullResponse =
           client.projects().subscriptions().pull(subscriptionFullName, pullRequest).execute()
         val receivedMessages = pullResponse.getReceivedMessages
+
+        // Update the rate limit if required
+        updateRateLimit()
+
+        // Put data into the buffer
         if (receivedMessages != null) {
-          store(receivedMessages.asScala.toList
-            .map(x => {
-              val sm = new SparkPubsubMessage
-              sm.message = x.getMessage
-              sm.ackId = x.getAckId
-              sm
-            })
-            .iterator)
-
-          if (autoAcknowledge) {
-            val ackRequest = new AcknowledgeRequest()
-            ackRequest.setAckIds(receivedMessages.asScala.map(x => x.getAckId).asJava)
-            client.projects().subscriptions().acknowledge(subscriptionFullName,
-              ackRequest).execute()
-          }
+          buffer.appendAll(receivedMessages.asScala)
         }
+
+        // Push data from the buffer to the store
+        push()
+
         backoff = INIT_BACKOFF
       } catch {
         case e: GoogleJsonResponseException =>
@@ -308,5 +364,109 @@
     }
   }
 
+  def getInitialRateLimit: Long = {
+    math.min(
+      conf.getLong("spark.streaming.backpressure.initialRate", maxRateLimit),
+      maxRateLimit
+    )
+  }
+
+  /**
+   * Get the new recommended rate at which the receiver should push data into the store
+   * and update the rate limiter accordingly.
+   */
+  def updateRateLimit(): Unit = {
+    val newRateLimit = rateMultiplierFactor * supervisor.getCurrentRateLimit.min(maxRateLimit)
+    if (rateLimiter.getRate != newRateLimit) {
+      rateLimiter.setRate(newRateLimit)
+      logInfo("New rateLimit:: " + newRateLimit)
+    }
+  }
+
+  /**
+   * Push data into the store if
+   * 1. the buffer size is greater than or equal to blockSize, or
+   * 2. the block interval has passed and the buffer size is less than blockSize.
+   *
+   * Before pushing, the messages are first grouped into an iterator of complete block(s)
+   * and a partial block, and the buffer is reset.
+   *
+   * If a {@link org.apache.spark.SparkException} occurs while pushing data into the store,
+   * all unpushed or unacknowledged messages will be lost.
+   *
+   * To recover lost messages we rely on Pub/Sub redelivery
+   * (i.e. once the ack deadline has passed, Pub/Sub will deliver those messages again).
+   */
+  def push(): Unit = {
+
+    val diff = System.currentTimeMillis() - latestAttemptToPushInStoreTime
+    if (buffer.length >= blockSize || (buffer.length < blockSize && diff >= blockIntervalMs)) {
+
+      // Group messages into complete blocks and a partial block (if any)
+      val (completeBlocks, partialBlock) = buffer.grouped(blockSize)
+        .partition(block => block.length == blockSize)
+
+      // If completeBlocks is empty, fewer than blockSize messages arrived within the
+      // block interval, so the partial block will be pushed
+      val iterator = if (completeBlocks.nonEmpty) completeBlocks else partialBlock
+
+      // Keep the partial block so it can be pushed back to the buffer when complete blocks exist
+      val partial = if (completeBlocks.nonEmpty && partialBlock.nonEmpty) {
+        partialBlock.next()
+      } else null
+
+      while (iterator.hasNext) {
+        try {
+          pushToStoreAndAck(iterator.next().toList)
+        } catch {
+          case e: SparkException => reportError(
+            "Failed to write messages into reliable store", e)
+          case NonFatal(e) => reportError(
+            "Failed to write messages into reliable store", e)
+        } finally {
+          latestAttemptToPushInStoreTime = System.currentTimeMillis()
+        }
+      }
+
+      // Clear the existing buffer messages
+      buffer.clear()
+
+      // Push the partial block messages back to the buffer if complete blocks were formed
+      if (partial != null) buffer.appendAll(partial)
+    }
+  }
+
+  /**
+   * Push the list of received messages into the store and ack them if autoAcknowledge is true.
+   * @param receivedMessages messages pulled from Pub/Sub
+   */
+  def pushToStoreAndAck(receivedMessages: List[ReceivedMessage]): Unit = {
+    val messages = receivedMessages
+      .map(x => {
+        val sm = new SparkPubsubMessage
+        sm.message = x.getMessage
+        sm.ackId = x.getAckId
+        sm
+      })
+
+    rateLimiter.acquire(messages.size)
+    store(messages.toIterator)
+    if (autoAcknowledge) acknowledgeIds(messages.map(_.ackId))
+  }
+
+  /**
+   * Acknowledge messages by their ack IDs.
+   * @param ackIds ack IDs of the messages to acknowledge
+   */
+  def acknowledgeIds(ackIds: List[String]): Unit = {
+    val ackRequest = new AcknowledgeRequest()
+    ackRequest.setAckIds(ackIds.asJava)
+    client.projects().subscriptions()
+      .acknowledge(subscriptionFullName, ackRequest).execute()
+  }
+
+  private def createBufferArray(): ArrayBuffer[ReceivedMessage] = {
+    new ArrayBuffer[ReceivedMessage](2 * math.max(maxNoOfMessageInRequest, blockSize))
+  }
+
   override def onStop(): Unit = {}
 }
diff --git a/streaming-pubsub/src/main/scala/org/apache/spark/streaming/pubsub/PubsubUtils.scala b/streaming-pubsub/src/main/scala/org/apache/spark/streaming/pubsub/PubsubUtils.scala
index 05214c34..c87cbcd9 100644
--- a/streaming-pubsub/src/main/scala/org/apache/spark/streaming/pubsub/PubsubUtils.scala
+++ b/streaming-pubsub/src/main/scala/org/apache/spark/streaming/pubsub/PubsubUtils.scala
@@ -17,12 +17,15 @@
 
 package org.apache.spark.streaming.pubsub
 
+import com.google.api.services.pubsub.Pubsub
+
 import org.apache.spark.storage.StorageLevel
 import org.apache.spark.streaming.StreamingContext
 import org.apache.spark.streaming.api.java.JavaReceiverInputDStream
 import org.apache.spark.streaming.api.java.JavaStreamingContext
 import org.apache.spark.streaming.dstream.ReceiverInputDStream
 
+
 object PubsubUtils {
 
   /**
@@ -50,7 +53,10 @@
       subscription: String,
       credentials: SparkGCPCredentials,
       storageLevel: StorageLevel,
-      autoAcknowledge: Boolean = true): ReceiverInputDStream[SparkPubsubMessage] = {
+      autoAcknowledge: Boolean = true,
+      maxNoOfMessageInRequest: Int = 1000,
+      rateMultiplierFactor: Double = 1.0,
+      endpoint: String = Pubsub.DEFAULT_ROOT_URL): ReceiverInputDStream[SparkPubsubMessage] = {
     ssc.withNamedScope("pubsub stream") {
       new PubsubInputDStream(
@@ -60,7 +66,12 @@
         subscription,
         credentials,
         storageLevel,
-        autoAcknowledge)
+        autoAcknowledge,
+        maxNoOfMessageInRequest,
+        rateMultiplierFactor,
+        endpoint,
+        ssc.conf
+      )
     }
   }
 
diff --git a/streaming-pubsub/src/test/scala/org/apache/spark/streaming/pubsub/PubsubStreamSuite.scala b/streaming-pubsub/src/test/scala/org/apache/spark/streaming/pubsub/PubsubStreamSuite.scala
index 8a67038d..64b3f0e7 100644
--- a/streaming-pubsub/src/test/scala/org/apache/spark/streaming/pubsub/PubsubStreamSuite.scala
+++ b/streaming-pubsub/src/test/scala/org/apache/spark/streaming/pubsub/PubsubStreamSuite.scala
@@ -25,7 +25,7 @@ import scala.language.postfixOps
 import org.scalatest.BeforeAndAfter
 import org.scalatest.concurrent.Eventually
 
-import org.apache.spark.ConditionalSparkFunSuite
+import org.apache.spark.{ConditionalSparkFunSuite, SparkConf}
 import org.apache.spark.storage.StorageLevel
 import org.apache.spark.streaming.Seconds
 import org.apache.spark.streaming.StreamingContext
@@ -35,6 +35,8 @@ class PubsubStreamSuite extends ConditionalSparkFunSuite with Eventually with Be
 
   val batchDuration = Seconds(1)
 
+  val blockSize = 15
+
   private val master: String = "local[2]"
 
   private val appName: String = this.getClass.getSimpleName
@@ -70,6 +72,17 @@ class PubsubStreamSuite extends ConditionalSparkFunSuite with Eventually with Be
     }
   }
 
+
+  def setSparkBackPressureConf(conf: SparkConf): Unit = {
+    conf.set("spark.streaming.backpressure.enabled", "true")
+    conf.set("spark.streaming.backpressure.initialRate", "50")
+    conf.set("spark.streaming.receiver.maxRate", "100")
+    conf.set("spark.streaming.backpressure.pid.minRate", "10")
+    conf.set("spark.streaming.blockQueueSize", blockSize.toString)
+    conf.set("spark.streaming.blockInterval", "1000ms")
+  }
+
+
   before {
     ssc = new StreamingContext(master, appName, batchDuration)
   }
@@ -113,6 +126,27 @@ class PubsubStreamSuite extends ConditionalSparkFunSuite with Eventually with Be
     sendReceiveMessages(receiveStream)
   }
 
+  testIf("check block size", () => PubsubTestUtils.shouldRunTest()) {
+    setSparkBackPressureConf(ssc.sparkContext.conf)
+    val receiveStream = PubsubUtils.createStream(
+      ssc, PubsubTestUtils.projectId, Some(topicName), subForCreateName,
+      PubsubTestUtils.credential, StorageLevel.MEMORY_AND_DISK_SER_2, autoAcknowledge = true, 50)
+
+    @volatile var partitionSize: Set[Int] = Set[Int]()
+    receiveStream.foreachRDD(rdd => {
+      rdd.collectPartitions().foreach(partition => {
+        partitionSize += partition.length
+      })
+    })
+
+    ssc.start()
+
+    eventually(timeout(100000 milliseconds), interval(1000 milliseconds)) {
+      pubsubTestUtils.publishData(topicFullName, pubsubTestUtils.generatorMessages(100))
+      assert(partitionSize.max == blockSize)
+    }
+  }
+
   private def sendReceiveMessages(receiveStream: ReceiverInputDStream[SparkPubsubMessage]): Unit = {
     @volatile var receiveMessages: List[SparkPubsubMessage] = List()
     receiveStream.foreachRDD { rdd =>
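
For reference, the snippet below is a minimal driver sketch, not part of this patch, showing how the new createStream parameters and the backpressure settings documented in the receiver scaladoc fit together. The object name, master, project/topic/subscription IDs, and the chosen parameter values are placeholders; the credential builder and the SparkPubsubMessage.getData() accessor are assumed to follow the existing connector API.

import org.apache.spark.SparkConf
import org.apache.spark.storage.StorageLevel
import org.apache.spark.streaming.{Seconds, StreamingContext}
import org.apache.spark.streaming.pubsub.{PubsubUtils, SparkGCPCredentials}

object PubsubBackpressureExample {
  def main(args: Array[String]): Unit = {
    // Backpressure and block-size settings read by the receiver (see the scaladoc above).
    val conf = new SparkConf()
      .setMaster("local[2]")
      .setAppName("PubsubBackpressureExample")
      .set("spark.streaming.backpressure.enabled", "true")
      .set("spark.streaming.backpressure.initialRate", "50")
      .set("spark.streaming.receiver.maxRate", "100")
      .set("spark.streaming.backpressure.pid.minRate", "10")
      .set("spark.streaming.blockQueueSize", "15")     // block size used by push()
      .set("spark.streaming.blockInterval", "1000ms")  // flush a partial block after this interval

    val ssc = new StreamingContext(conf, Seconds(1))

    // Placeholder project, topic, and subscription names.
    val stream = PubsubUtils.createStream(
      ssc,
      "my-gcp-project",
      Some("my-topic"),
      "my-subscription",
      SparkGCPCredentials.builder.build(),   // application-default credentials
      StorageLevel.MEMORY_AND_DISK_SER_2,
      autoAcknowledge = true,
      maxNoOfMessageInRequest = 500,         // messages requested per pull
      rateMultiplierFactor = 1.0)            // use Spark's estimated rate as-is

    stream.map(msg => new String(msg.getData())).print()

    ssc.start()
    ssc.awaitTermination()
  }
}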