This repository has been archived by the owner on Feb 16, 2024. It is now read-only.

[BAHIR-295] Added backpressure & ratelimit support #101

Open
wants to merge 5 commits into
base: master
Choose a base branch
from
Open
@@ -20,16 +20,19 @@ package org.apache.spark.streaming.pubsub
import java.io.{Externalizable, ObjectInput, ObjectOutput}

import scala.collection.JavaConverters._
import scala.collection.mutable.ArrayBuffer
import scala.util.control.NonFatal

import com.google.api.client.googleapis.javanet.GoogleNetHttpTransport
import com.google.api.client.googleapis.json.GoogleJsonResponseException
import com.google.api.client.json.jackson2.JacksonFactory
import com.google.api.services.pubsub.Pubsub.Builder
import com.google.api.services.pubsub.model.{AcknowledgeRequest, PubsubMessage, PullRequest}
import com.google.api.services.pubsub.model.Subscription
import com.google.api.services.pubsub.model.{AcknowledgeRequest, PubsubMessage, PullRequest, ReceivedMessage, Subscription}
import com.google.cloud.hadoop.util.RetryHttpInitializer
import com.google.common.util.concurrent.RateLimiter

import org.apache.spark.{SparkConf, SparkException}
import org.apache.spark.internal.Logging
import org.apache.spark.storage.StorageLevel
import org.apache.spark.streaming.StreamingContext
import org.apache.spark.streaming.dstream.ReceiverInputDStream
@@ -51,11 +54,18 @@ class PubsubInputDStream(
val subscription: String,
val credential: SparkGCPCredentials,
val _storageLevel: StorageLevel,
val autoAcknowledge: Boolean
val autoAcknowledge: Boolean,
val maxNoOfMessageInRequest: Int,
val rateMultiplierFactor: Double,
val endpoint: String,
conf: SparkConf
) extends ReceiverInputDStream[SparkPubsubMessage](_ssc) {

override def getReceiver(): Receiver[SparkPubsubMessage] = {
new PubsubReceiver(project, topic, subscription, credential, _storageLevel, autoAcknowledge)
new PubsubReceiver(
project, topic, subscription, credential, _storageLevel, autoAcknowledge,
maxNoOfMessageInRequest, rateMultiplierFactor, endpoint, conf
)
}
}

@@ -214,31 +224,74 @@ object ConnectionUtils {
}
}


/**
* Custom Spark receiver to pull messages from a Pub/Sub topic and push them into a reliable store.
* If backpressure is enabled, the message ingestion rate for this receiver will be managed by Spark.
*
* The following Spark configurations can be used to control rates and block size:
* <i>spark.streaming.backpressure.initialRate</i>
* <i>spark.streaming.receiver.maxRate</i>
* <i>spark.streaming.blockQueueSize</i>: controls the block size
* <i>spark.streaming.backpressure.pid.minRate</i>
*
* See the Spark Streaming configuration doc:
* <a href="https://spark.apache.org/docs/latest/configuration.html#spark-streaming">spark-streaming</a>
*
* NOTE: It is assumed that the subscription's ackDeadlineSeconds is long enough that
* messages do not expire while they are buffered for the configured blockIntervalMs.
*
* @param project Google Cloud project id
* @param topic Topic name, used to create the subscription if needed
* @param subscription Pub/Sub subscription name
* @param credential Google Cloud credential used to access the Pub/Sub service
* @param storageLevel Storage level to be used
* @param autoAcknowledge Whether to acknowledge Pub/Sub messages after storing them
* @param maxNoOfMessageInRequest Maximum number of messages in a Pub/Sub pull request
* @param rateMultiplierFactor Multiplier applied to the rate proposed by the PID estimator,
* to take advantage of dynamic allocation of executors.
* Should be 1.0 if dynamic allocation is not enabled
* @param endpoint Pubsub service endpoint
* @param conf Spark config
*/
private[pubsub]
class PubsubReceiver(
project: String,
topic: Option[String],
subscription: String,
credential: SparkGCPCredentials,
storageLevel: StorageLevel,
autoAcknowledge: Boolean)
extends Receiver[SparkPubsubMessage](storageLevel) {
autoAcknowledge: Boolean,
maxNoOfMessageInRequest: Int,
rateMultiplierFactor: Double,
endpoint: String,
conf: SparkConf)
extends Receiver[SparkPubsubMessage](storageLevel) with Logging {

val APP_NAME = "sparkstreaming-pubsub-receiver"

val INIT_BACKOFF = 100 // 100ms

val MAX_BACKOFF = 10 * 1000 // 10s

val MAX_MESSAGE = 1000
val maxRateLimit: Long = conf.getLong("spark.streaming.receiver.maxRate", Long.MaxValue)

val blockSize: Int = conf.getInt("spark.streaming.blockQueueSize", maxNoOfMessageInRequest)

val blockIntervalMs: Long = conf.getTimeAsMs("spark.streaming.blockInterval", "200ms")

var buffer: ArrayBuffer[ReceivedMessage] = createBufferArray()

var latestAttemptToPushInStoreTime: Long = -1

lazy val rateLimiter: RateLimiter = RateLimiter.create(getInitialRateLimit.toDouble)

lazy val client = new Builder(
ConnectionUtils.transport,
ConnectionUtils.jacksonFactory,
new RetryHttpInitializer(credential.provider, APP_NAME))
.setApplicationName(APP_NAME)
.build()
.setApplicationName(APP_NAME)
.setRootUrl(endpoint)
.build()

val projectFullName: String = s"projects/$project"
val subscriptionFullName: String = s"$projectFullName/subscriptions/$subscription"
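The class comment above lists the configuration keys this receiver reads. As a minimal sketch (illustrative only, not part of this patch; values are placeholders, and spark.streaming.backpressure.enabled is the standard Spark switch that turns the PID rate estimator on), a SparkConf carrying those settings could look like this:

import org.apache.spark.SparkConf

object BackpressureConfSketch {
  // Placeholder values; the keys are the ones referenced in this diff.
  val sparkConf: SparkConf = new SparkConf()
    .setAppName("pubsub-backpressure-conf-sketch")
    .set("spark.streaming.backpressure.enabled", "true")       // enable the PID rate estimator
    .set("spark.streaming.backpressure.initialRate", "500")    // caps getInitialRateLimit before the first estimate
    .set("spark.streaming.receiver.maxRate", "1000")           // hard upper bound (maxRateLimit)
    .set("spark.streaming.backpressure.pid.minRate", "100")    // floor for the estimated rate
    .set("spark.streaming.blockQueueSize", "500")              // blockSize: messages per stored block
    .set("spark.streaming.blockInterval", "200ms")             // how long a partial block may be buffered
}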
@@ -248,6 +301,7 @@ class PubsubReceiver(
case Some(t) =>
val sub: Subscription = new Subscription
sub.setTopic(s"$projectFullName/topics/$t")
sub.setAckDeadlineSeconds(30)
try {
client.projects().subscriptions().create(subscriptionFullName, sub).execute()
} catch {
@@ -262,6 +316,7 @@ class PubsubReceiver(
}
case None => // do nothing
}

new Thread() {
override def run() {
receive()
@@ -270,30 +325,31 @@ class PubsubReceiver(
}

def receive(): Unit = {
val pullRequest = new PullRequest().setMaxMessages(MAX_MESSAGE).setReturnImmediately(false)
val pullRequest = new PullRequest()
.setMaxMessages(maxNoOfMessageInRequest).setReturnImmediately(false)
var backoff = INIT_BACKOFF

// Avoid the edge case where the buffer never fills and nothing is ever pushed to the store
latestAttemptToPushInStoreTime = System.currentTimeMillis()

while (!isStopped()) {
try {

val pullResponse =
client.projects().subscriptions().pull(subscriptionFullName, pullRequest).execute()
val receivedMessages = pullResponse.getReceivedMessages

// update rate limit if required
updateRateLimit()

// Put data into buffer
if (receivedMessages != null) {
store(receivedMessages.asScala.toList
.map(x => {
val sm = new SparkPubsubMessage
sm.message = x.getMessage
sm.ackId = x.getAckId
sm
})
.iterator)

if (autoAcknowledge) {
val ackRequest = new AcknowledgeRequest()
ackRequest.setAckIds(receivedMessages.asScala.map(x => x.getAckId).asJava)
client.projects().subscriptions().acknowledge(subscriptionFullName,
ackRequest).execute()
}
buffer.appendAll(receivedMessages.asScala)
}

// Push data from buffer to store
push()

backoff = INIT_BACKOFF
} catch {
case e: GoogleJsonResponseException =>
@@ -308,5 +364,109 @@ class PubsubReceiver(
}
}

def getInitialRateLimit: Long = {
math.min(
conf.getLong("spark.streaming.backpressure.initialRate", maxRateLimit),
maxRateLimit
)
}

/**
* Get the newly recommended rate at which the receiver should push data into the store
* and update the rate limiter with that rate
*/
def updateRateLimit(): Unit = {
val newRateLimit = rateMultiplierFactor * supervisor.getCurrentRateLimit.min(maxRateLimit)
if (rateLimiter.getRate != newRateLimit) {
rateLimiter.setRate(newRateLimit)
logInfo("New rateLimit:: " + newRateLimit)
}
}
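updateRateLimit() feeds the supervisor's current rate estimate into a Guava RateLimiter, which pushToStoreAndAck() later blocks on. A small standalone sketch of that throttling behaviour (illustrative only, not part of this patch):

import com.google.common.util.concurrent.RateLimiter

object RateLimiterSketch {
  def main(args: Array[String]): Unit = {
    // Start at 100 permits (messages) per second, the way the receiver
    // starts from getInitialRateLimit.
    val limiter = RateLimiter.create(100.0)

    // What updateRateLimit() does whenever the estimator proposes a new rate.
    limiter.setRate(250.0)

    // What pushToStoreAndAck() does before store(): block until enough
    // permits are available for the whole block of messages.
    limiter.acquire(500)

    println(s"current rate = ${limiter.getRate} permits/s")
  }
}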

/**
* Push data into the store if
* 1. the buffer size is greater than or equal to blockSize, or
* 2. the block interval has passed and the buffer holds fewer than blockSize messages
*
* Before pushing, the buffer is split into an iterator of complete block(s) and a partial block,
* and the buffer is then cleared, with any leftover partial block appended back.
*
* If a {@link org.apache.spark.SparkException} occurs while pushing data into the store,
* any messages that were not pushed or not acknowledged are lost.
*
* To recover lost messages we rely on Pub/Sub redelivery
* (i.e. once the ack deadline passes, Pub/Sub delivers those messages again)
*/
def push(): Unit = {

val diff = System.currentTimeMillis() - latestAttemptToPushInStoreTime
if (buffer.length >= blockSize || (buffer.length < blockSize && diff >= blockIntervalMs)) {

// grouping messages into complete and partial blocks (if any)
val (completeBlocks, partialBlock) = buffer.grouped(blockSize)
.partition(block => block.length == blockSize)

// If completeBlocks is empty, fewer than blockSize messages arrived within the block
// interval, so the partial block will be pushed
val iterator = if (completeBlocks.nonEmpty) completeBlocks else partialBlock

// Set the partial block aside so it can go back to the buffer when complete blocks are pushed
val partial = if (completeBlocks.nonEmpty && partialBlock.nonEmpty) {
partialBlock.next()
} else null

while (iterator.hasNext) {
try {
pushToStoreAndAck(iterator.next().toList)
} catch {
case e: SparkException => reportError(
"Failed to write messages into reliable store", e)
case NonFatal(e) => reportError(
"Failed to write messages in reliable store", e)
} finally {
latestAttemptToPushInStoreTime = System.currentTimeMillis()
}
}

// clear existing buffer messages
buffer.clear()

// Pushing partial block messages back to buffer if complete blocks formed
if (partial != null) buffer.appendAll(partial)
}
}
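The splitting in push() relies on grouped plus partition over the buffer. A self-contained sketch of that logic, with plain Ints standing in for ReceivedMessage (illustrative only, not part of this patch):

import scala.collection.mutable.ArrayBuffer

object BlockSplitSketch {
  def main(args: Array[String]): Unit = {
    val blockSize = 4
    val buffer = ArrayBuffer(1 to 10: _*)   // ten buffered "messages"

    // Same shape as push(): iterators of complete blocks (length == blockSize)
    // and of at most one partial block.
    val (completeBlocks, partialBlock) =
      buffer.grouped(blockSize).partition(_.length == blockSize)

    completeBlocks.foreach(b => println(s"push block: $b"))      // two blocks of four
    partialBlock.foreach(b => println(s"back to buffer: $b"))    // one block of two
  }
}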

/**
* Push the list of received messages into the store and acknowledge them if autoAcknowledge is true
* @param receivedMessages messages pulled from Pub/Sub
*/
def pushToStoreAndAck(receivedMessages: List[ReceivedMessage]): Unit = {
val messages = receivedMessages
.map(x => {
val sm = new SparkPubsubMessage
sm.message = x.getMessage
sm.ackId = x.getAckId
sm})

rateLimiter.acquire(messages.size)
store(messages.toIterator)
if (autoAcknowledge) acknowledgeIds(messages.map(_.ackId))
}

/**
* Acknowledge the messages with the given ackIds
* @param ackIds
*/
def acknowledgeIds(ackIds: List[String]): Unit = {
val ackRequest = new AcknowledgeRequest()
ackRequest.setAckIds(ackIds.asJava)
client.projects().subscriptions()
.acknowledge(subscriptionFullName, ackRequest).execute()
}

private def createBufferArray(): ArrayBuffer[ReceivedMessage] = {
new ArrayBuffer[ReceivedMessage](2 * math.max(maxNoOfMessageInRequest, blockSize))
}

override def onStop(): Unit = {}
}
@@ -17,12 +17,15 @@

package org.apache.spark.streaming.pubsub

import com.google.api.services.pubsub.Pubsub

import org.apache.spark.storage.StorageLevel
import org.apache.spark.streaming.StreamingContext
import org.apache.spark.streaming.api.java.JavaReceiverInputDStream
import org.apache.spark.streaming.api.java.JavaStreamingContext
import org.apache.spark.streaming.dstream.ReceiverInputDStream


object PubsubUtils {

/**
Expand Down Expand Up @@ -50,7 +53,10 @@ object PubsubUtils {
subscription: String,
credentials: SparkGCPCredentials,
storageLevel: StorageLevel,
autoAcknowledge: Boolean = true): ReceiverInputDStream[SparkPubsubMessage] = {
autoAcknowledge: Boolean = true,
maxNoOfMessageInRequest: Int = 1000,
rateMultiplierFactor: Double = 1.0,
endpoint: String = Pubsub.DEFAULT_ROOT_URL): ReceiverInputDStream[SparkPubsubMessage] = {
ssc.withNamedScope("pubsub stream") {

new PubsubInputDStream(
@@ -60,7 +66,12 @@ object PubsubUtils {
subscription,
credentials,
storageLevel,
autoAcknowledge)
autoAcknowledge,
maxNoOfMessageInRequest,
rateMultiplierFactor,
endpoint,
ssc.conf
)
}
}
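Because every added parameter carries a default (maxNoOfMessageInRequest = 1000, rateMultiplierFactor = 1.0, endpoint = Pubsub.DEFAULT_ROOT_URL), existing callers compile unchanged. A hedged end-to-end sketch of a call that opts into the new knobs; the leading ssc/project/topic parameters are collapsed in the view above and assumed here, as is the SparkGCPCredentials.builder call (taken from the module's examples):

import com.google.api.services.pubsub.Pubsub
import org.apache.spark.SparkConf
import org.apache.spark.storage.StorageLevel
import org.apache.spark.streaming.{Seconds, StreamingContext}
import org.apache.spark.streaming.pubsub.{PubsubUtils, SparkGCPCredentials}

object PubsubBackpressureSketch {
  def main(args: Array[String]): Unit = {
    val conf = new SparkConf()
      .setAppName("pubsub-backpressure-sketch")
      .set("spark.streaming.backpressure.enabled", "true")
    val ssc = new StreamingContext(conf, Seconds(5))

    val stream = PubsubUtils.createStream(
      ssc,
      "my-gcp-project",                       // assumed project id
      None,                                   // topic: None reuses an existing subscription (assumed Option[String])
      "my-subscription",
      SparkGCPCredentials.builder.build(),    // assumed: application-default credentials
      StorageLevel.MEMORY_AND_DISK_SER_2,
      autoAcknowledge = true,
      maxNoOfMessageInRequest = 1000,
      rateMultiplierFactor = 1.0,             // keep 1.0 unless dynamic allocation is enabled
      endpoint = Pubsub.DEFAULT_ROOT_URL)

    stream.count().print()
    ssc.start()
    ssc.awaitTermination()
  }
}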
