From ff1c7d3a02196341dc4d47c299c5318cf27915c0 Mon Sep 17 00:00:00 2001 From: Guy Korland Date: Thu, 19 Jan 2023 12:39:35 +0200 Subject: [PATCH] Refactor string DMA (#263) * refactor string DMA * add append * add DerefMut & Deref support --- Cargo.toml | 4 ++ examples/hello.rs | 1 + examples/string.rs | 46 +++++++++++++++++ src/key.rs | 114 +++++++++++++++++++++++++++++++------------ src/lib.rs | 14 ------ src/raw.rs | 7 ++- tests/integration.rs | 20 ++++++++ 7 files changed, 161 insertions(+), 45 deletions(-) create mode 100644 examples/string.rs diff --git a/Cargo.toml b/Cargo.toml index abfbcd78..bc219f15 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -15,6 +15,10 @@ categories = ["database", "api-bindings"] name = "hello" crate-type = ["cdylib"] +[[example]] +name = "string" +crate-type = ["cdylib"] + [[example]] name = "keys_pos" crate-type = ["cdylib"] diff --git a/examples/hello.rs b/examples/hello.rs index 55db1667..e31d6faf 100644 --- a/examples/hello.rs +++ b/examples/hello.rs @@ -2,6 +2,7 @@ extern crate redis_module; use redis_module::{Context, RedisError, RedisResult, RedisString}; + fn hello_mul(_: &Context, args: Vec) -> RedisResult { if args.len() < 2 { return Err(RedisError::WrongArity); diff --git a/examples/string.rs b/examples/string.rs new file mode 100644 index 00000000..421c6c17 --- /dev/null +++ b/examples/string.rs @@ -0,0 +1,46 @@ +#[macro_use] +extern crate redis_module; + +use redis_module::{Context, NextArg, RedisError, RedisResult, RedisString, RedisValue}; + +fn string_set(ctx: &Context, args: Vec) -> RedisResult { + if args.len() < 3 { + return Err(RedisError::WrongArity); + } + + let mut args = args.into_iter().skip(1); + let key_name = args.next_arg()?; + let value = args.next_arg()?; + + let key = ctx.open_key_writable(&key_name); + let mut dma = key.as_string_dma()?; + dma.write(value.as_slice()) + .map(|_| RedisValue::SimpleStringStatic("OK")) +} + +fn string_get(ctx: &Context, args: Vec) -> RedisResult { + if args.len() < 2 { + return Err(RedisError::WrongArity); + } + + let mut args = args.into_iter().skip(1); + let key_name = args.next_arg()?; + + let key = ctx.open_key(&key_name); + let res = key + .read()? + .map_or(RedisValue::Null, |v| RedisValue::StringBuffer(Vec::from(v))); + Ok(res) +} + +////////////////////////////////////////////////////// + +redis_module! { + name: "string", + version: 1, + data_types: [], + commands: [ + ["string.set", string_set, "", 1, 1, 1], + ["string.get", string_get, "", 1, 1, 1], + ], +} diff --git a/src/key.rs b/src/key.rs index ef18f9a9..ee20e8d8 100644 --- a/src/key.rs +++ b/src/key.rs @@ -1,14 +1,14 @@ use std::convert::TryFrom; +use std::ops::Deref; +use std::ops::DerefMut; use std::os::raw::c_void; use std::ptr; -use std::str::Utf8Error; use std::time::Duration; use libc::size_t; use raw::KeyType; -use crate::from_byte_string; use crate::native_types::RedisType; use crate::raw; use crate::redismodule::REDIS_OK; @@ -75,13 +75,20 @@ impl RedisKey { self.key_inner == null_key } - pub fn read(&self) -> Result, RedisError> { - let val = if self.is_null() { - None + pub fn read(&self) -> Result, RedisError> { + if self.is_null() { + Ok(None) } else { - Some(read_key(self.key_inner)?) - }; - Ok(val) + let mut length: size_t = 0; + let dma = raw::string_dma(self.key_inner, &mut length, raw::KeyMode::READ); + if dma.is_null() { + Err(RedisError::Str("Could not read key")) + } else { + Ok(Some(unsafe { + std::slice::from_raw_parts(dma.cast::(), length) + })) + } + } } pub fn hash_get(&self, field: &str) -> Result, RedisError> { @@ -144,20 +151,15 @@ impl RedisKeyWritable { /// as you open the key in read mode, but when asking for write Redis /// returns a non-null pointer to allow us to write to even an empty key, /// so we have to check the key's value instead. - /* - fn is_empty_old(&self) -> Result { - match self.read()? { - Some(s) => match s.as_str() { - "" => Ok(true), - _ => Ok(false), - }, - _ => Ok(false), - } - } - */ - - pub fn read(&self) -> Result, RedisError> { - Ok(Some(read_key(self.key_inner)?)) + /// + /// ``` + /// fn is_empty_old(key: &RedisKeyWritable) -> Result { + /// let s = key.as_string_dma(); + /// s.write(b"new value")?; + /// } + /// ``` + pub fn as_string_dma(&self) -> Result { + StringDMA::new(self) } #[allow(clippy::must_use_candidate)] @@ -437,6 +439,66 @@ where } } +pub struct StringDMA<'a> { + key: &'a RedisKeyWritable, + buffer: &'a mut [u8], +} + +impl<'a> Deref for StringDMA<'a> { + type Target = [u8]; + + fn deref(&self) -> &Self::Target { + self.buffer + } +} + +impl<'a> DerefMut for StringDMA<'a> { + fn deref_mut(&mut self) -> &mut Self::Target { + self.buffer + } +} + +impl<'a> StringDMA<'a> { + fn new(key: &'a RedisKeyWritable) -> Result, RedisError> { + let mut length: size_t = 0; + let dma = raw::string_dma(key.key_inner, &mut length, raw::KeyMode::WRITE); + if dma.is_null() { + Err(RedisError::Str("Could not read key")) + } else { + let buffer = unsafe { std::slice::from_raw_parts_mut(dma.cast::(), length) }; + Ok(StringDMA { key, buffer }) + } + } + + pub fn write(&mut self, data: &[u8]) -> Result<&mut Self, RedisError> { + if self.buffer.len() != data.len() { + if raw::Status::Ok == raw::string_truncate(self.key.key_inner, data.len()) { + let mut length: size_t = 0; + let dma = raw::string_dma(self.key.key_inner, &mut length, raw::KeyMode::WRITE); + self.buffer = unsafe { std::slice::from_raw_parts_mut(dma.cast::(), length) }; + } else { + return Err(RedisError::Str("Failed to truncate string")); + } + } + self.buffer[..data.len()].copy_from_slice(data); + Ok(self) + } + + pub fn append(&mut self, data: &[u8]) -> Result<&mut Self, RedisError> { + let current_len = self.buffer.len(); + let new_len = current_len + data.len(); + if raw::Status::Ok == raw::string_truncate(self.key.key_inner, new_len) { + let mut length: size_t = 0; + let dma = raw::string_dma(self.key.key_inner, &mut length, raw::KeyMode::WRITE); + self.buffer = unsafe { std::slice::from_raw_parts_mut(dma.cast::(), length) }; + } else { + return Err(RedisError::Str("Failed to truncate string")); + } + self.buffer[current_len..new_len].copy_from_slice(data); + Ok(self) + } +} + impl From for Result<(), RedisError> { fn from(s: raw::Status) -> Self { match s { @@ -453,14 +515,6 @@ impl Drop for RedisKeyWritable { } } -fn read_key(key: *mut raw::RedisModuleKey) -> Result { - let mut length: size_t = 0; - from_byte_string( - raw::string_dma(key, &mut length, raw::KeyMode::READ), - length, - ) -} - /// Get an arbitrary number of hash fields from a key by batching calls /// to `raw::hash_get_multi`. fn hash_mget_key( diff --git a/src/lib.rs b/src/lib.rs index 61c89025..75893f9d 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -1,13 +1,9 @@ //#![allow(dead_code)] pub use crate::context::InfoContext; -use std::os::raw::c_char; -use std::str::Utf8Error; use strum_macros::AsRefStr; extern crate num_traits; -use libc::size_t; - pub mod alloc; pub mod error; pub mod native_types; @@ -52,16 +48,6 @@ pub enum LogLevel { Warning, } -fn from_byte_string(byte_str: *const c_char, length: size_t) -> Result { - let mut vec_str: Vec = Vec::with_capacity(length as usize); - for j in 0..length { - let byte = unsafe { *byte_str.add(j) } as u8; - vec_str.insert(j, byte); - } - - String::from_utf8(vec_str).map_err(|e| e.utf8_error()) -} - pub fn base_info_func( ctx: &InfoContext, for_crash_report: bool, diff --git a/src/raw.rs b/src/raw.rs index ae122442..b7219e91 100644 --- a/src/raw.rs +++ b/src/raw.rs @@ -280,10 +280,15 @@ pub fn set_expire(key: *mut RedisModuleKey, expire: c_longlong) -> Status { } #[allow(clippy::not_unsafe_ptr_arg_deref)] -pub fn string_dma(key: *mut RedisModuleKey, len: *mut size_t, mode: KeyMode) -> *const c_char { +pub fn string_dma(key: *mut RedisModuleKey, len: *mut size_t, mode: KeyMode) -> *mut c_char { unsafe { RedisModule_StringDMA.unwrap()(key, len, mode.bits) } } +#[allow(clippy::not_unsafe_ptr_arg_deref)] +pub fn string_truncate(key: *mut RedisModuleKey, new_len: size_t) -> Status { + unsafe { RedisModule_StringTruncate.unwrap()(key, new_len).into() } +} + #[allow(clippy::not_unsafe_ptr_arg_deref)] pub fn hash_get_multi( key: *mut RedisModuleKey, diff --git a/tests/integration.rs b/tests/integration.rs index e6ad5404..8847625d 100644 --- a/tests/integration.rs +++ b/tests/integration.rs @@ -148,3 +148,23 @@ fn test_test_helper_err() -> Result<()> { Ok(()) } + +#[test] +fn test_string() -> Result<()> { + let port: u16 = 6485; + let _guards = vec![start_redis_server_with_module("string", port) + .with_context(|| "failed to start redis server")?]; + let mut con = + get_redis_connection(port).with_context(|| "failed to connect to redis server")?; + + redis::cmd("string.set") + .arg(&["key", "value"]) + .query(&mut con) + .with_context(|| "failed to run string.set")?; + + let res: String = redis::cmd("string.get").arg(&["key"]).query(&mut con)?; + + assert_eq!(&res, "value"); + + Ok(()) +}