diff --git a/streaming-pubsub/src/main/scala/org/apache/spark/streaming/pubsub/PubsubInputDStream.scala b/streaming-pubsub/src/main/scala/org/apache/spark/streaming/pubsub/PubsubInputDStream.scala
index 7357d231..18c6b2b6 100644
--- a/streaming-pubsub/src/main/scala/org/apache/spark/streaming/pubsub/PubsubInputDStream.scala
+++ b/streaming-pubsub/src/main/scala/org/apache/spark/streaming/pubsub/PubsubInputDStream.scala
@@ -20,16 +20,19 @@ package org.apache.spark.streaming.pubsub
import java.io.{Externalizable, ObjectInput, ObjectOutput}
import scala.collection.JavaConverters._
+import scala.collection.mutable.ArrayBuffer
import scala.util.control.NonFatal
import com.google.api.client.googleapis.javanet.GoogleNetHttpTransport
import com.google.api.client.googleapis.json.GoogleJsonResponseException
import com.google.api.client.json.jackson2.JacksonFactory
import com.google.api.services.pubsub.Pubsub.Builder
-import com.google.api.services.pubsub.model.{AcknowledgeRequest, PubsubMessage, PullRequest}
-import com.google.api.services.pubsub.model.Subscription
+import com.google.api.services.pubsub.model.{AcknowledgeRequest, PubsubMessage, PullRequest, ReceivedMessage, Subscription}
import com.google.cloud.hadoop.util.RetryHttpInitializer
+import com.google.common.util.concurrent.RateLimiter
+import org.apache.spark.{SparkConf, SparkException}
+import org.apache.spark.internal.Logging
import org.apache.spark.storage.StorageLevel
import org.apache.spark.streaming.StreamingContext
import org.apache.spark.streaming.dstream.ReceiverInputDStream
@@ -51,11 +54,18 @@ class PubsubInputDStream(
val subscription: String,
val credential: SparkGCPCredentials,
val _storageLevel: StorageLevel,
- val autoAcknowledge: Boolean
+ val autoAcknowledge: Boolean,
+ val maxNoOfMessageInRequest: Int,
+ val rateMultiplierFactor: Double,
+ val endpoint: String,
+ conf: SparkConf
) extends ReceiverInputDStream[SparkPubsubMessage](_ssc) {
override def getReceiver(): Receiver[SparkPubsubMessage] = {
- new PubsubReceiver(project, topic, subscription, credential, _storageLevel, autoAcknowledge)
+ new PubsubReceiver(
+ project, topic, subscription, credential, _storageLevel, autoAcknowledge,
+ maxNoOfMessageInRequest, rateMultiplierFactor, endpoint, conf
+ )
}
}
@@ -214,7 +224,35 @@ object ConnectionUtils {
}
}
-
+/**
+ * Custom Spark receiver that pulls messages from a Pub/Sub topic and pushes them into the store.
+ * If backpressure is enabled, the message ingestion rate for this receiver is managed by Spark.
+ *
+ * The following Spark configurations can be used to control the rates and the block size:
+ * spark.streaming.backpressure.initialRate
+ * spark.streaming.receiver.maxRate
+ * spark.streaming.blockQueueSize: controls the block size
+ * spark.streaming.backpressure.pid.minRate
+ *
+ * See the Spark Streaming configuration documentation for details.
+ *
val sub: Subscription = new Subscription
sub.setTopic(s"$projectFullName/topics/$t")
+ sub.setAckDeadlineSeconds(30)
try {
client.projects().subscriptions().create(subscriptionFullName, sub).execute()
} catch {
@@ -262,6 +316,7 @@ class PubsubReceiver(
}
case None => // do nothing
}
+
new Thread() {
override def run() {
receive()
@@ -270,30 +325,31 @@ class PubsubReceiver(
}
def receive(): Unit = {
- val pullRequest = new PullRequest().setMaxMessages(MAX_MESSAGE).setReturnImmediately(false)
+ val pullRequest = new PullRequest()
+ .setMaxMessages(maxNoOfMessageInRequest).setReturnImmediately(false)
var backoff = INIT_BACKOFF
+
+    // Avoid the edge case where the buffer never fills and no message is pushed to the store
+ latestAttemptToPushInStoreTime = System.currentTimeMillis()
+
while (!isStopped()) {
try {
+
val pullResponse =
client.projects().subscriptions().pull(subscriptionFullName, pullRequest).execute()
val receivedMessages = pullResponse.getReceivedMessages
+
+ // update rate limit if required
+ updateRateLimit()
+
+ // Put data into buffer
if (receivedMessages != null) {
- store(receivedMessages.asScala.toList
- .map(x => {
- val sm = new SparkPubsubMessage
- sm.message = x.getMessage
- sm.ackId = x.getAckId
- sm
- })
- .iterator)
-
- if (autoAcknowledge) {
- val ackRequest = new AcknowledgeRequest()
- ackRequest.setAckIds(receivedMessages.asScala.map(x => x.getAckId).asJava)
- client.projects().subscriptions().acknowledge(subscriptionFullName,
- ackRequest).execute()
- }
+ buffer.appendAll(receivedMessages.asScala)
}
+
+ // Push data from buffer to store
+ push()
+
backoff = INIT_BACKOFF
} catch {
case e: GoogleJsonResponseException =>
@@ -308,5 +364,109 @@ class PubsubReceiver(
}
}
+  /** Initial push rate: spark.streaming.backpressure.initialRate, capped at maxRateLimit. */
+  def getInitialRateLimit: Long = {
+ math.min(
+ conf.getLong("spark.streaming.backpressure.initialRate", maxRateLimit),
+ maxRateLimit
+ )
+ }
+
+ /**
+   * Get the new recommended rate at which the receiver should push data into the store,
+   * and update the rate limiter accordingly.
+ */
+ def updateRateLimit(): Unit = {
+ val newRateLimit = rateMultiplierFactor * supervisor.getCurrentRateLimit.min(maxRateLimit)
+ if (rateLimiter.getRate != newRateLimit) {
+ rateLimiter.setRate(newRateLimit)
+      logInfo("New rate limit: " + newRateLimit)
+ }
+ }
+
+ /**
+   * Push data into the store if
+   * 1. the buffer size is greater than or equal to blockSize, or
+   * 2. the block interval has elapsed and the buffer holds fewer than blockSize messages.
+   *
+   * Before pushing, the buffered messages are split into an iterator of complete block(s)
+   * and at most one partial block, and the buffer is reset for the next pull.
+   *
+   * If a {@link org.apache.spark.SparkException} occurs while pushing data into the store,
+   * any messages that were not pushed or not acknowledged are lost.
+   *
+   * Those messages are recovered through Pub/Sub itself: once their ack deadline expires,
+   * Pub/Sub redelivers them.
+ */
+ def push(): Unit = {
+
+ val diff = System.currentTimeMillis() - latestAttemptToPushInStoreTime
+ if (buffer.length >= blockSize || (buffer.length < blockSize && diff >= blockIntervalMs)) {
+
+ // grouping messages into complete and partial blocks (if any)
+ val (completeBlocks, partialBlock) = buffer.grouped(blockSize)
+ .partition(block => block.length == blockSize)
+
+      // If completeBlocks is empty, fewer than blockSize messages arrived within the
+      // block interval, so push the partial block instead
+ val iterator = if (completeBlocks.nonEmpty) completeBlocks else partialBlock
+
+      // If complete blocks were formed, hold the partial block back for the buffer
+ val partial = if (completeBlocks.nonEmpty && partialBlock.nonEmpty) {
+ partialBlock.next()
+ } else null
+
+ while (iterator.hasNext) {
+ try {
+ pushToStoreAndAck(iterator.next().toList)
+ } catch {
+ case e: SparkException => reportError(
+ "Failed to write messages into reliable store", e)
+          case NonFatal(e) => reportError(
+            "Failed to write messages into reliable store", e)
+ } finally {
+ latestAttemptToPushInStoreTime = System.currentTimeMillis()
+ }
+ }
+
+ // clear existing buffer messages
+ buffer.clear()
+
+      // Push the partial block's messages back into the buffer if complete blocks were formed
+ if (partial != null) buffer.appendAll(partial)
+ }
+ }
+
+ /**
+   * Push the list of received messages into the store and ack them if autoAcknowledge is true.
+   * @param receivedMessages messages pulled from the Pub/Sub subscription
+ */
+ def pushToStoreAndAck(receivedMessages: List[ReceivedMessage]): Unit = {
+    val messages = receivedMessages.map { x =>
+      val sm = new SparkPubsubMessage
+      sm.message = x.getMessage
+      sm.ackId = x.getAckId
+      sm
+    }
+
+ rateLimiter.acquire(messages.size)
+ store(messages.toIterator)
+ if (autoAcknowledge) acknowledgeIds(messages.map(_.ackId))
+ }
+
+ /**
+   * Acknowledge the messages identified by the given ack IDs.
+   * @param ackIds ack IDs of the messages to acknowledge
+ */
+ def acknowledgeIds(ackIds: List[String]): Unit = {
+ val ackRequest = new AcknowledgeRequest()
+ ackRequest.setAckIds(ackIds.asJava)
+ client.projects().subscriptions()
+ .acknowledge(subscriptionFullName, ackRequest).execute()
+ }
+
+  /** Receive buffer with an initial capacity of twice the larger of pull size and block size. */
+  private def createBufferArray(): ArrayBuffer[ReceivedMessage] = {
+ new ArrayBuffer[ReceivedMessage](2 * math.max(maxNoOfMessageInRequest, blockSize))
+ }
+
override def onStop(): Unit = {}
}
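
A minimal, self-contained sketch (not part of the patch) of the two mechanisms the new receive
path combines: splitting the buffered messages into blockSize-sized blocks with grouped/partition,
as push() does above, and pacing each push through a Guava RateLimiter whose rate would follow the
backpressure recommendation via updateRateLimit(). Plain strings stand in for ReceivedMessage, and
the rates and sizes are illustrative values only.

import scala.collection.mutable.ArrayBuffer

import com.google.common.util.concurrent.RateLimiter

object ReceiverPushSketch {
  def main(args: Array[String]): Unit = {
    val blockSize = 15
    // Stand-in for the receiver buffer: 38 "messages" pulled in one or more requests.
    val buffer = ArrayBuffer.tabulate(38)(i => s"message-$i")

    // The receiver would start from getInitialRateLimit and later call setRate with
    // rateMultiplierFactor * the rate recommended by Spark's backpressure mechanism.
    val rateLimiter = RateLimiter.create(50.0)
    rateLimiter.setRate(1.0 * 100.0)

    // grouped(blockSize) yields chunks of exactly blockSize elements plus at most one
    // shorter trailing chunk; partition separates complete blocks from the partial one.
    val (completeBlocks, partialBlock) =
      buffer.grouped(blockSize).partition(_.length == blockSize)

    val iterator = if (completeBlocks.nonEmpty) completeBlocks else partialBlock
    val leftover =
      if (completeBlocks.nonEmpty && partialBlock.nonEmpty) Some(partialBlock.next()) else None

    while (iterator.hasNext) {
      val block = iterator.next()
      // acquire(n) blocks until n permits are available, pacing the store() calls.
      rateLimiter.acquire(block.size)
      println(s"pushed block of ${block.size} messages at ${rateLimiter.getRate} msg/s")
    }

    // The partial block is carried over to the next pull/push cycle.
    buffer.clear()
    leftover.foreach(chunk => buffer.appendAll(chunk))
    println(s"${buffer.size} messages carried over")
  }
}

Running this prints two complete blocks of 15 messages and carries the remaining 8 over, which is
the same grouping behaviour the "check block size" test below asserts on partition sizes.
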
diff --git a/streaming-pubsub/src/main/scala/org/apache/spark/streaming/pubsub/PubsubUtils.scala b/streaming-pubsub/src/main/scala/org/apache/spark/streaming/pubsub/PubsubUtils.scala
index 05214c34..c87cbcd9 100644
--- a/streaming-pubsub/src/main/scala/org/apache/spark/streaming/pubsub/PubsubUtils.scala
+++ b/streaming-pubsub/src/main/scala/org/apache/spark/streaming/pubsub/PubsubUtils.scala
@@ -17,12 +17,15 @@
package org.apache.spark.streaming.pubsub
+import com.google.api.services.pubsub.Pubsub
+
import org.apache.spark.storage.StorageLevel
import org.apache.spark.streaming.StreamingContext
import org.apache.spark.streaming.api.java.JavaReceiverInputDStream
import org.apache.spark.streaming.api.java.JavaStreamingContext
import org.apache.spark.streaming.dstream.ReceiverInputDStream
+
object PubsubUtils {
/**
@@ -50,7 +53,10 @@ object PubsubUtils {
subscription: String,
credentials: SparkGCPCredentials,
storageLevel: StorageLevel,
- autoAcknowledge: Boolean = true): ReceiverInputDStream[SparkPubsubMessage] = {
+ autoAcknowledge: Boolean = true,
+ maxNoOfMessageInRequest: Int = 1000,
+ rateMultiplierFactor: Double = 1.0,
+ endpoint: String = Pubsub.DEFAULT_ROOT_URL): ReceiverInputDStream[SparkPubsubMessage] = {
ssc.withNamedScope("pubsub stream") {
new PubsubInputDStream(
@@ -60,7 +66,12 @@ object PubsubUtils {
subscription,
credentials,
storageLevel,
- autoAcknowledge)
+ autoAcknowledge,
+ maxNoOfMessageInRequest,
+ rateMultiplierFactor,
+ endpoint,
+ ssc.conf
+ )
}
}
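
A hedged usage sketch of the extended createStream signature above (not part of the patch). The
project, topic, and subscription names are placeholders; SparkGCPCredentials.builder.build() and
SparkPubsubMessage.getData() are assumed from the connector's existing API; the backpressure and
block settings mirror the configuration keys listed in the receiver's scaladoc.

import java.nio.charset.StandardCharsets

import org.apache.spark.SparkConf
import org.apache.spark.storage.StorageLevel
import org.apache.spark.streaming.{Seconds, StreamingContext}
import org.apache.spark.streaming.pubsub.{PubsubUtils, SparkGCPCredentials}

object PubsubBackpressureExample {
  def main(args: Array[String]): Unit = {
    // Backpressure and block settings are read by the receiver through ssc.conf.
    val conf = new SparkConf()
      .setMaster("local[2]")
      .setAppName("pubsub-backpressure-example")
      .set("spark.streaming.backpressure.enabled", "true")
      .set("spark.streaming.backpressure.initialRate", "50")
      .set("spark.streaming.receiver.maxRate", "100")
      .set("spark.streaming.backpressure.pid.minRate", "10")
      .set("spark.streaming.blockQueueSize", "15")
      .set("spark.streaming.blockInterval", "1000ms")

    val ssc = new StreamingContext(conf, Seconds(1))

    // "my-project", "my-topic" and "my-subscription" are placeholders.
    val stream = PubsubUtils.createStream(
      ssc,
      "my-project",
      Some("my-topic"),
      "my-subscription",
      SparkGCPCredentials.builder.build(), // application-default credentials (assumed builder API)
      StorageLevel.MEMORY_AND_DISK_SER_2,
      autoAcknowledge = true,
      maxNoOfMessageInRequest = 500,
      rateMultiplierFactor = 1.0)

    stream
      .map(msg => new String(msg.getData(), StandardCharsets.UTF_8))
      .print()

    ssc.start()
    ssc.awaitTermination()
  }
}
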
diff --git a/streaming-pubsub/src/test/scala/org/apache/spark/streaming/pubsub/PubsubStreamSuite.scala b/streaming-pubsub/src/test/scala/org/apache/spark/streaming/pubsub/PubsubStreamSuite.scala
index 8a67038d..64b3f0e7 100644
--- a/streaming-pubsub/src/test/scala/org/apache/spark/streaming/pubsub/PubsubStreamSuite.scala
+++ b/streaming-pubsub/src/test/scala/org/apache/spark/streaming/pubsub/PubsubStreamSuite.scala
@@ -25,7 +25,7 @@ import scala.language.postfixOps
import org.scalatest.BeforeAndAfter
import org.scalatest.concurrent.Eventually
-import org.apache.spark.ConditionalSparkFunSuite
+import org.apache.spark.{ConditionalSparkFunSuite, SparkConf}
import org.apache.spark.storage.StorageLevel
import org.apache.spark.streaming.Seconds
import org.apache.spark.streaming.StreamingContext
@@ -35,6 +35,8 @@ class PubsubStreamSuite extends ConditionalSparkFunSuite with Eventually with Be
val batchDuration = Seconds(1)
+ val blockSize = 15
+
private val master: String = "local[2]"
private val appName: String = this.getClass.getSimpleName
@@ -70,6 +72,17 @@ class PubsubStreamSuite extends ConditionalSparkFunSuite with Eventually with Be
}
}
+
+  def setSparkBackPressureConf(conf: SparkConf): Unit = {
+ conf.set("spark.streaming.backpressure.enabled", "true")
+ conf.set("spark.streaming.backpressure.initialRate", "50")
+ conf.set("spark.streaming.receiver.maxRate", "100")
+ conf.set("spark.streaming.backpressure.pid.minRate", "10")
+ conf.set("spark.streaming.blockQueueSize", blockSize.toString)
+ conf.set("spark.streaming.blockInterval", "1000ms")
+ }
+
before {
ssc = new StreamingContext(master, appName, batchDuration)
}
@@ -113,6 +126,27 @@ class PubsubStreamSuite extends ConditionalSparkFunSuite with Eventually with Be
sendReceiveMessages(receiveStream)
}
+ testIf("check block size", () => PubsubTestUtils.shouldRunTest()) {
+ setSparkBackPressureConf(ssc.sparkContext.conf)
+ val receiveStream = PubsubUtils.createStream(
+ ssc, PubsubTestUtils.projectId, Some(topicName), subForCreateName,
+ PubsubTestUtils.credential, StorageLevel.MEMORY_AND_DISK_SER_2, autoAcknowledge = true, 50)
+
+ @volatile var partitionSize: Set[Int] = Set[Int]()
+ receiveStream.foreachRDD(rdd => {
+ rdd.collectPartitions().foreach(partition => {
+ partitionSize += partition.length
+ })
+ })
+
+ ssc.start()
+
+ eventually(timeout(100000 milliseconds), interval(1000 milliseconds)) {
+ pubsubTestUtils.publishData(topicFullName, pubsubTestUtils.generatorMessages(100))
+ assert(partitionSize.max == blockSize)
+ }
+ }
+
private def sendReceiveMessages(receiveStream: ReceiverInputDStream[SparkPubsubMessage]): Unit = {
@volatile var receiveMessages: List[SparkPubsubMessage] = List()
receiveStream.foreachRDD { rdd =>