Skip to content

Commit

Permalink
Refactor string DMA (#263)
Browse files Browse the repository at this point in the history
* refactor string DMA
* add append
* add DerefMut & Deref support
  • Loading branch information
gkorland authored Jan 19, 2023
1 parent 3c5862f commit ff1c7d3
Show file tree
Hide file tree
Showing 7 changed files with 161 additions and 45 deletions.
4 changes: 4 additions & 0 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,10 @@ categories = ["database", "api-bindings"]
name = "hello"
crate-type = ["cdylib"]

[[example]]
name = "string"
crate-type = ["cdylib"]

[[example]]
name = "keys_pos"
crate-type = ["cdylib"]
Expand Down
1 change: 1 addition & 0 deletions examples/hello.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
extern crate redis_module;

use redis_module::{Context, RedisError, RedisResult, RedisString};

fn hello_mul(_: &Context, args: Vec<RedisString>) -> RedisResult {
if args.len() < 2 {
return Err(RedisError::WrongArity);
Expand Down
46 changes: 46 additions & 0 deletions examples/string.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,46 @@
#[macro_use]
extern crate redis_module;

use redis_module::{Context, NextArg, RedisError, RedisResult, RedisString, RedisValue};

fn string_set(ctx: &Context, args: Vec<RedisString>) -> RedisResult {
if args.len() < 3 {
return Err(RedisError::WrongArity);
}

let mut args = args.into_iter().skip(1);
let key_name = args.next_arg()?;
let value = args.next_arg()?;

let key = ctx.open_key_writable(&key_name);
let mut dma = key.as_string_dma()?;
dma.write(value.as_slice())
.map(|_| RedisValue::SimpleStringStatic("OK"))
}

fn string_get(ctx: &Context, args: Vec<RedisString>) -> RedisResult {
if args.len() < 2 {
return Err(RedisError::WrongArity);
}

let mut args = args.into_iter().skip(1);
let key_name = args.next_arg()?;

let key = ctx.open_key(&key_name);
let res = key
.read()?
.map_or(RedisValue::Null, |v| RedisValue::StringBuffer(Vec::from(v)));
Ok(res)
}

//////////////////////////////////////////////////////

redis_module! {
name: "string",
version: 1,
data_types: [],
commands: [
["string.set", string_set, "", 1, 1, 1],
["string.get", string_get, "", 1, 1, 1],
],
}
114 changes: 84 additions & 30 deletions src/key.rs
Original file line number Diff line number Diff line change
@@ -1,14 +1,14 @@
use std::convert::TryFrom;
use std::ops::Deref;
use std::ops::DerefMut;
use std::os::raw::c_void;
use std::ptr;
use std::str::Utf8Error;
use std::time::Duration;

use libc::size_t;

use raw::KeyType;

use crate::from_byte_string;
use crate::native_types::RedisType;
use crate::raw;
use crate::redismodule::REDIS_OK;
Expand Down Expand Up @@ -75,13 +75,20 @@ impl RedisKey {
self.key_inner == null_key
}

pub fn read(&self) -> Result<Option<String>, RedisError> {
let val = if self.is_null() {
None
pub fn read(&self) -> Result<Option<&[u8]>, RedisError> {
if self.is_null() {
Ok(None)
} else {
Some(read_key(self.key_inner)?)
};
Ok(val)
let mut length: size_t = 0;
let dma = raw::string_dma(self.key_inner, &mut length, raw::KeyMode::READ);
if dma.is_null() {
Err(RedisError::Str("Could not read key"))
} else {
Ok(Some(unsafe {
std::slice::from_raw_parts(dma.cast::<u8>(), length)
}))
}
}
}

pub fn hash_get(&self, field: &str) -> Result<Option<RedisString>, RedisError> {
Expand Down Expand Up @@ -144,20 +151,15 @@ impl RedisKeyWritable {
/// as you open the key in read mode, but when asking for write Redis
/// returns a non-null pointer to allow us to write to even an empty key,
/// so we have to check the key's value instead.
/*
fn is_empty_old(&self) -> Result<bool, Error> {
match self.read()? {
Some(s) => match s.as_str() {
"" => Ok(true),
_ => Ok(false),
},
_ => Ok(false),
}
}
*/

pub fn read(&self) -> Result<Option<String>, RedisError> {
Ok(Some(read_key(self.key_inner)?))
///
/// ```
/// fn is_empty_old(key: &RedisKeyWritable) -> Result<bool, Error> {
/// let s = key.as_string_dma();
/// s.write(b"new value")?;
/// }
/// ```
pub fn as_string_dma(&self) -> Result<StringDMA, RedisError> {
StringDMA::new(self)
}

#[allow(clippy::must_use_candidate)]
Expand Down Expand Up @@ -437,6 +439,66 @@ where
}
}

pub struct StringDMA<'a> {
key: &'a RedisKeyWritable,
buffer: &'a mut [u8],
}

impl<'a> Deref for StringDMA<'a> {
type Target = [u8];

fn deref(&self) -> &Self::Target {
self.buffer
}
}

impl<'a> DerefMut for StringDMA<'a> {
fn deref_mut(&mut self) -> &mut Self::Target {
self.buffer
}
}

impl<'a> StringDMA<'a> {
fn new(key: &'a RedisKeyWritable) -> Result<StringDMA<'a>, RedisError> {
let mut length: size_t = 0;
let dma = raw::string_dma(key.key_inner, &mut length, raw::KeyMode::WRITE);
if dma.is_null() {
Err(RedisError::Str("Could not read key"))
} else {
let buffer = unsafe { std::slice::from_raw_parts_mut(dma.cast::<u8>(), length) };
Ok(StringDMA { key, buffer })
}
}

pub fn write(&mut self, data: &[u8]) -> Result<&mut Self, RedisError> {
if self.buffer.len() != data.len() {
if raw::Status::Ok == raw::string_truncate(self.key.key_inner, data.len()) {
let mut length: size_t = 0;
let dma = raw::string_dma(self.key.key_inner, &mut length, raw::KeyMode::WRITE);
self.buffer = unsafe { std::slice::from_raw_parts_mut(dma.cast::<u8>(), length) };
} else {
return Err(RedisError::Str("Failed to truncate string"));
}
}
self.buffer[..data.len()].copy_from_slice(data);
Ok(self)
}

pub fn append(&mut self, data: &[u8]) -> Result<&mut Self, RedisError> {
let current_len = self.buffer.len();
let new_len = current_len + data.len();
if raw::Status::Ok == raw::string_truncate(self.key.key_inner, new_len) {
let mut length: size_t = 0;
let dma = raw::string_dma(self.key.key_inner, &mut length, raw::KeyMode::WRITE);
self.buffer = unsafe { std::slice::from_raw_parts_mut(dma.cast::<u8>(), length) };
} else {
return Err(RedisError::Str("Failed to truncate string"));
}
self.buffer[current_len..new_len].copy_from_slice(data);
Ok(self)
}
}

impl From<raw::Status> for Result<(), RedisError> {
fn from(s: raw::Status) -> Self {
match s {
Expand All @@ -453,14 +515,6 @@ impl Drop for RedisKeyWritable {
}
}

fn read_key(key: *mut raw::RedisModuleKey) -> Result<String, Utf8Error> {
let mut length: size_t = 0;
from_byte_string(
raw::string_dma(key, &mut length, raw::KeyMode::READ),
length,
)
}

/// Get an arbitrary number of hash fields from a key by batching calls
/// to `raw::hash_get_multi`.
fn hash_mget_key<T>(
Expand Down
14 changes: 0 additions & 14 deletions src/lib.rs
Original file line number Diff line number Diff line change
@@ -1,13 +1,9 @@
//#![allow(dead_code)]

pub use crate::context::InfoContext;
use std::os::raw::c_char;
use std::str::Utf8Error;
use strum_macros::AsRefStr;
extern crate num_traits;

use libc::size_t;

pub mod alloc;
pub mod error;
pub mod native_types;
Expand Down Expand Up @@ -52,16 +48,6 @@ pub enum LogLevel {
Warning,
}

fn from_byte_string(byte_str: *const c_char, length: size_t) -> Result<String, Utf8Error> {
let mut vec_str: Vec<u8> = Vec::with_capacity(length as usize);
for j in 0..length {
let byte = unsafe { *byte_str.add(j) } as u8;
vec_str.insert(j, byte);
}

String::from_utf8(vec_str).map_err(|e| e.utf8_error())
}

pub fn base_info_func(
ctx: &InfoContext,
for_crash_report: bool,
Expand Down
7 changes: 6 additions & 1 deletion src/raw.rs
Original file line number Diff line number Diff line change
Expand Up @@ -280,10 +280,15 @@ pub fn set_expire(key: *mut RedisModuleKey, expire: c_longlong) -> Status {
}

#[allow(clippy::not_unsafe_ptr_arg_deref)]
pub fn string_dma(key: *mut RedisModuleKey, len: *mut size_t, mode: KeyMode) -> *const c_char {
pub fn string_dma(key: *mut RedisModuleKey, len: *mut size_t, mode: KeyMode) -> *mut c_char {
unsafe { RedisModule_StringDMA.unwrap()(key, len, mode.bits) }
}

#[allow(clippy::not_unsafe_ptr_arg_deref)]
pub fn string_truncate(key: *mut RedisModuleKey, new_len: size_t) -> Status {
unsafe { RedisModule_StringTruncate.unwrap()(key, new_len).into() }
}

#[allow(clippy::not_unsafe_ptr_arg_deref)]
pub fn hash_get_multi<T>(
key: *mut RedisModuleKey,
Expand Down
20 changes: 20 additions & 0 deletions tests/integration.rs
Original file line number Diff line number Diff line change
Expand Up @@ -148,3 +148,23 @@ fn test_test_helper_err() -> Result<()> {

Ok(())
}

#[test]
fn test_string() -> Result<()> {
let port: u16 = 6485;
let _guards = vec![start_redis_server_with_module("string", port)
.with_context(|| "failed to start redis server")?];
let mut con =
get_redis_connection(port).with_context(|| "failed to connect to redis server")?;

redis::cmd("string.set")
.arg(&["key", "value"])
.query(&mut con)
.with_context(|| "failed to run string.set")?;

let res: String = redis::cmd("string.get").arg(&["key"]).query(&mut con)?;

assert_eq!(&res, "value");

Ok(())
}

0 comments on commit ff1c7d3

Please sign in to comment.