Issue #205: toRedisByteLIST() (#216)
* issue #205: a prototype of toRedisByteLIST() function

* issue #205: added doc
fe2s authored Mar 8, 2020
1 parent 6a19452 commit ecd0c15
Showing 4 changed files with 98 additions and 4 deletions.
9 changes: 9 additions & 0 deletions doc/rdd.md
@@ -137,6 +137,15 @@ sc.toRedisFixedLIST(listRDD, listName, listSize)
The `listRDD` is an RDD that contains all of the list's string elements in order, and `listName` is the list's key name.
`listSize` is an integer which specifies the size of the Redis list; it is optional, and will default to an unlimited size.

Use the following to store an RDD of binary values in a Redis List:

```scala
sc.toRedisByteLIST(byteListRDD)
```

The `byteListRDD` is an RDD of (`list name`, `list values`) tuples, where both the name and the values are byte arrays.
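
For example, an illustrative snippet (the key names `list1`/`list2` are arbitrary):

```scala
// Build an RDD of (list name, list values) pairs, all as byte arrays.
val byteListRDD = sc.parallelize(Seq(
  ("list1".getBytes, Seq("a".getBytes, "b".getBytes)),
  ("list2".getBytes, Seq("c".getBytes))
))

// RPUSH each value onto its list.
sc.toRedisByteLIST(byteListRDD)
```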


#### Sets
For storing data in a Redis Set, use `toRedisSET` as follows:

32 changes: 28 additions & 4 deletions src/main/scala/com/redislabs/provider/redis/RedisConfig.scala
@@ -166,14 +166,25 @@ class RedisConfig(val initialHost: RedisEndpoint) extends Serializable {
}

/**
 * *IMPORTANT* Please remember to close after using
 *
 * @param key
 * @return jedis that is a connection for a given key
 */
def connectionForKey(key: String): Jedis = {
getHost(key).connect()
}

/**
* *IMPORTANT* Please remember to close after using
*
* @param key
* @return jedis that is a connection for a given key
*/
def connectionForKey(key: Array[Byte]): Jedis = {
getHost(key).connect()
}
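
A sketch of the *IMPORTANT* note above: the `withConnection` loan helper used elsewhere in this commit closes the connection even when the body throws (the key and value here are illustrative):

```scala
import com.redislabs.provider.redis.util.ConnectionUtils.withConnection

val key = "events".getBytes
withConnection(redisConfig.connectionForKey(key)) { conn =>
  conn.rpush(key, "payload".getBytes) // connection is closed afterwards
}
```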

/**
* @param initialHost any redis endpoint of a cluster or a single server
* @return true if the target server is in cluster mode
@@ -195,9 +206,22 @@
*/
def getHost(key: String): RedisNode = {
  val slot = JedisClusterCRC16.getSlot(key)
  getHostBySlot(slot)
}

/**
 * @param key
 * @return the host whose slot range covers the given key
 */
def getHost(key: Array[Byte]): RedisNode = {
val slot = JedisClusterCRC16.getSlot(key)
getHostBySlot(slot)
}

private def getHostBySlot(slot: Int): RedisNode = {
  hosts.filter { host =>
    host.startSlot <= slot && host.endSlot >= slot
  }(0)
}


38 changes: 38 additions & 0 deletions src/main/scala/com/redislabs/provider/redis/redisFunctions.scala
@@ -1,6 +1,7 @@
package com.redislabs.provider.redis

import com.redislabs.provider.redis.rdd._
import com.redislabs.provider.redis.util.ConnectionUtils.withConnection
import com.redislabs.provider.redis.util.PipelineUtils._
import org.apache.spark.SparkContext
import org.apache.spark.rdd.RDD
@@ -299,6 +300,19 @@ class RedisContext(@transient val sc: SparkContext) extends Serializable {
vs.foreachPartition(partition => setList(listName, partition, ttl, redisConfig, readWriteConfig))
}

/**
* Write RDD of binary values to Redis List.
*
* @param rdd RDD of tuples (list name, list values)
* @param ttl time to live in seconds; 0 (the default) disables expiry
*/
def toRedisByteLIST(rdd: RDD[(Array[Byte], Seq[Array[Byte]])], ttl: Int = 0)
(implicit
redisConfig: RedisConfig = RedisConfig.fromSparkConf(sc.getConf),
readWriteConfig: ReadWriteConfig = ReadWriteConfig.fromSparkConf(sc.getConf)) {
rdd.foreachPartition(partition => setList(partition, ttl, redisConfig, readWriteConfig))
}
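
A hedged usage sketch: the implicits default to the Spark configuration, but can be supplied explicitly. The `RedisEndpoint("localhost", 6379)` below is an assumed local setup, not part of this change:

```scala
// Assumption: a standalone Redis reachable on localhost:6379.
implicit val redisConfig: RedisConfig =
  new RedisConfig(RedisEndpoint("localhost", 6379))

// Write the lists and expire each one an hour after it is written.
sc.toRedisByteLIST(rdd, ttl = 3600)
```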

/**
* @param vs RDD of values
* @param listName target list's name which hold all the vs
@@ -415,6 +429,30 @@ object RedisContext extends Serializable {
conn.close()
}


/**
 * Write an iterator of (list name, list values) pairs to Redis Lists,
 * grouping the keys by node and pipelining the RPUSH commands.
 *
 * @param keyValues iterator of (list name, list values), both as byte arrays
 * @param ttl time to live in seconds; 0 disables expiry
 */
def setList(keyValues: Iterator[(Array[Byte], Seq[Array[Byte]])],
            ttl: Int,
            redisConfig: RedisConfig,
            readWriteConfig: ReadWriteConfig) {
implicit val rwConf: ReadWriteConfig = readWriteConfig

keyValues
.map { case (key, listValues) =>
(redisConfig.getHost(key), (key, listValues))
}
.toArray
.groupBy(_._1)
.foreach { case (node, arr) =>
withConnection(node.endpoint.connect()) { conn =>
foreachWithPipeline(conn, arr) { (pipeline, a) =>
val (key, listVals) = a._2
pipeline.rpush(key, listVals: _*)
if (ttl > 0) pipeline.expire(key, ttl)
}
}
}
}
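
Grouping by node before pipelining matters in cluster mode: a pipeline is bound to a single connection, so commands for keys hashed to other nodes would fail. A minimal sketch of roughly what `foreachWithPipeline` does per node, written with plain Jedis (standalone Redis on localhost assumed):

```scala
import redis.clients.jedis.Jedis

val conn = new Jedis("localhost", 6379)
try {
  val pipeline = conn.pipelined()
  pipeline.rpush("list1".getBytes, "a".getBytes, "b".getBytes) // queued locally
  pipeline.expire("list1".getBytes, 60)                        // queued locally
  pipeline.sync() // one round trip flushes the whole batch
} finally {
  conn.close()
}
```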

/**
* @param key
* @param listSize
23 changes: 23 additions & 0 deletions src/test/scala/com/redislabs/provider/redis/rdd/RedisRddSuite.scala
@@ -1,7 +1,9 @@
package com.redislabs.provider.redis.rdd

import com.redislabs.provider.redis.util.ConnectionUtils.withConnection
import com.redislabs.provider.redis.{RedisConfig, SparkRedisSuite, toRedisContext}
import org.scalatest.Matchers
import scala.collection.JavaConverters._

import scala.io.Source.fromInputStream

@@ -109,6 +111,27 @@ trait RedisRddSuite extends SparkRedisSuite with Keys with Matchers {
setContents should be(ws)
}

test("toRedisLIST, byte array") {
val list1 = Seq("a1", "b1", "c1")
val list2 = Seq("a2", "b2", "c2")
val keyValues = Seq(
("list1", list1),
("list2", list2)
)
val keyValueBytes = keyValues.map { case (k, list) => (k.getBytes, list.map(_.getBytes)) }
val rdd = sc.parallelize(keyValueBytes)
sc.toRedisByteLIST(rdd)

def verify(list: String, vals: Seq[String]): Unit = {
withConnection(redisConfig.getHost(list).endpoint.connect()) { conn =>
conn.lrange(list, 0, vals.size).asScala should be(vals.toList)
}
}

verify("list1", list1)
verify("list2", list2)
}

test("Expire") {
val expireTime = 1
val prefix = s"#expire in $expireTime#:"
