Skip to content

Commit

Permalink
ParamSerializer
Browse files Browse the repository at this point in the history
  • Loading branch information
serprex committed Sep 24, 2024
1 parent 70aa220 commit 7e3aa8f
Show file tree
Hide file tree
Showing 3 changed files with 180 additions and 20 deletions.
6 changes: 3 additions & 3 deletions src/query.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ use crate::{
request_body::RequestBody,
response::Response,
row::Row,
sql::{Bind, SqlBuilder},
sql::{Bind, SqlBuilder, ser},
Client,
};

Expand Down Expand Up @@ -196,9 +196,9 @@ impl Query {
self
}

pub fn with_param(self, name: &str, value: impl Bind + Serialize) -> Result<Self, String> {
pub fn with_param<T>(self, name: &str, value: T) -> Result<Self, String> where T: Serialize {
let mut param = String::from("");
Bind::write(&value, &mut param)?;
ser::write_param(&mut param, &value)?;
Ok(self.with_option(format!("param_{name}"), param))
}
}
Expand Down
2 changes: 1 addition & 1 deletion src/sql/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ pub use bind::{Bind, Identifier};

mod bind;
pub(crate) mod escape;
mod ser;
pub(crate) mod ser;

#[derive(Debug, Clone)]
pub(crate) enum SqlBuilder {
Expand Down
192 changes: 176 additions & 16 deletions src/sql/ser.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,32 +8,32 @@ use thiserror::Error;

use super::escape;

// === SqlSerializerError ===
// === SerializerError ===

#[derive(Debug, Error)]
enum SqlSerializerError {
enum SerializerError {
#[error("{0} is unsupported")]
Unsupported(&'static str),
#[error("{0}")]
Custom(String),
}

impl ser::Error for SqlSerializerError {
impl ser::Error for SerializerError {
fn custom<T: fmt::Display>(msg: T) -> Self {
Self::Custom(msg.to_string())
}
}

impl From<fmt::Error> for SqlSerializerError {
impl From<fmt::Error> for SerializerError {
fn from(err: fmt::Error) -> Self {
Self::Custom(err.to_string())
}
}

// === SqlSerializer ===

type Result<T = (), E = SqlSerializerError> = std::result::Result<T, E>;
type Impossible = ser::Impossible<(), SqlSerializerError>;
type Result<T = (), E = SerializerError> = std::result::Result<T, E>;
type Impossible = ser::Impossible<(), SerializerError>;

struct SqlSerializer<'a, W> {
writer: &'a mut W,
Expand All @@ -43,7 +43,7 @@ macro_rules! unsupported {
($ser_method:ident($ty:ty) -> $ret:ty, $($other:tt)*) => {
#[inline]
fn $ser_method(self, _v: $ty) -> $ret {
Err(SqlSerializerError::Unsupported(stringify!($ser_method)))
Err(SerializerError::Unsupported(stringify!($ser_method)))
}
unsupported!($($other)*);
};
Expand All @@ -53,7 +53,7 @@ macro_rules! unsupported {
($ser_method:ident, $($other:tt)*) => {
#[inline]
fn $ser_method(self) -> Result {
Err(SqlSerializerError::Unsupported(stringify!($ser_method)))
Err(SerializerError::Unsupported(stringify!($ser_method)))
}
unsupported!($($other)*);
};
Expand All @@ -73,7 +73,7 @@ macro_rules! forward_to_display {
}

impl<'a, W: Write> Serializer for SqlSerializer<'a, W> {
type Error = SqlSerializerError;
type Error = SerializerError;
type Ok = ();
type SerializeMap = Impossible;
type SerializeSeq = SqlListSerializer<'a, W>;
Expand Down Expand Up @@ -177,12 +177,12 @@ impl<'a, W: Write> Serializer for SqlSerializer<'a, W> {
_variant: &'static str,
_value: &T,
) -> Result {
Err(SqlSerializerError::Unsupported("serialize_newtype_variant"))
Err(SerializerError::Unsupported("serialize_newtype_variant"))
}

#[inline]
fn serialize_tuple_struct(self, _name: &'static str, _len: usize) -> Result<Impossible> {
Err(SqlSerializerError::Unsupported("serialize_tuple_struct"))
Err(SerializerError::Unsupported("serialize_tuple_struct"))
}

#[inline]
Expand All @@ -193,12 +193,12 @@ impl<'a, W: Write> Serializer for SqlSerializer<'a, W> {
_variant: &'static str,
_len: usize,
) -> Result<Impossible> {
Err(SqlSerializerError::Unsupported("serialize_tuple_variant"))
Err(SerializerError::Unsupported("serialize_tuple_variant"))
}

#[inline]
fn serialize_struct(self, _name: &'static str, _len: usize) -> Result<Self::SerializeStruct> {
Err(SqlSerializerError::Unsupported("serialize_struct"))
Err(SerializerError::Unsupported("serialize_struct"))
}

#[inline]
Expand All @@ -209,7 +209,7 @@ impl<'a, W: Write> Serializer for SqlSerializer<'a, W> {
_variant: &'static str,
_len: usize,
) -> Result<Self::SerializeStructVariant> {
Err(SqlSerializerError::Unsupported("serialize_struct_variant"))
Err(SerializerError::Unsupported("serialize_struct_variant"))
}

#[inline]
Expand All @@ -227,7 +227,7 @@ struct SqlListSerializer<'a, W> {
}

impl<'a, W: Write> SerializeSeq for SqlListSerializer<'a, W> {
type Error = SqlSerializerError;
type Error = SerializerError;
type Ok = ();

#[inline]
Expand All @@ -254,7 +254,7 @@ impl<'a, W: Write> SerializeSeq for SqlListSerializer<'a, W> {
}

impl<'a, W: Write> SerializeTuple for SqlListSerializer<'a, W> {
type Error = SqlSerializerError;
type Error = SerializerError;
type Ok = ();

#[inline]
Expand All @@ -271,6 +271,160 @@ impl<'a, W: Write> SerializeTuple for SqlListSerializer<'a, W> {
}
}

// === ParamSerializer ===

struct ParamSerializer<'a, W> {
writer: &'a mut W,
}

impl<'a, W: Write> Serializer for ParamSerializer<'a, W> {
type Error = SerializerError;
type Ok = ();
type SerializeMap = Impossible;
type SerializeSeq = SqlListSerializer<'a, W>;
type SerializeStruct = Impossible;
type SerializeStructVariant = Impossible;
type SerializeTuple = SqlListSerializer<'a, W>;
type SerializeTupleStruct = Impossible;
type SerializeTupleVariant = Impossible;

unsupported!(
serialize_map(Option<usize>) -> Result<Impossible>,
serialize_bytes(&[u8]),
serialize_unit,
serialize_unit_struct(&'static str),
);

forward_to_display!(
serialize_i8(i8),
serialize_i16(i16),
serialize_i32(i32),
serialize_i64(i64),
serialize_i128(i128),
serialize_u8(u8),
serialize_u16(u16),
serialize_u32(u32),
serialize_u64(u64),
serialize_u128(u128),
serialize_f32(f32),
serialize_f64(f64),
serialize_bool(bool),
);

#[inline]
fn serialize_char(self, value: char) -> Result {
let mut tmp = [0u8; 4];
self.serialize_str(value.encode_utf8(&mut tmp))
}

#[inline]
fn serialize_str(self, value: &str) -> Result {
// ClickHouse expects strings in params to be unquoted until inside a nested type
// nested types go through serialize_seq which'll quote strings
self.writer.write_str(value)?;
Ok(())
}

#[inline]
fn serialize_seq(self, _len: Option<usize>) -> Result<SqlListSerializer<'a, W>> {
self.writer.write_char('[')?;
Ok(SqlListSerializer {
writer: self.writer,
has_items: false,
closing_char: ']',
})
}

#[inline]
fn serialize_tuple(self, _len: usize) -> Result<SqlListSerializer<'a, W>> {
self.writer.write_char('(')?;
Ok(SqlListSerializer {
writer: self.writer,
has_items: false,
closing_char: ')',
})
}

#[inline]
fn serialize_some<T: Serialize + ?Sized>(self, _value: &T) -> Result {
_value.serialize(self)
}

#[inline]
fn serialize_none(self) -> std::result::Result<Self::Ok, Self::Error> {
self.writer.write_str("NULL")?;
Ok(())
}

#[inline]
fn serialize_unit_variant(
self,
_name: &'static str,
_variant_index: u32,
variant: &'static str,
) -> Result {
escape::string(variant, self.writer)?;
Ok(())
}

#[inline]
fn serialize_newtype_struct<T: Serialize + ?Sized>(
self,
_name: &'static str,
value: &T,
) -> Result {
value.serialize(self)
}

#[inline]
fn serialize_newtype_variant<T: Serialize + ?Sized>(
self,
_name: &'static str,
_variant_index: u32,
_variant: &'static str,
_value: &T,
) -> Result {
Err(SerializerError::Unsupported("serialize_newtype_variant"))
}

#[inline]
fn serialize_tuple_struct(self, _name: &'static str, _len: usize) -> Result<Impossible> {
Err(SerializerError::Unsupported("serialize_tuple_struct"))
}

#[inline]
fn serialize_tuple_variant(
self,
_name: &'static str,
_variant_index: u32,
_variant: &'static str,
_len: usize,
) -> Result<Impossible> {
Err(SerializerError::Unsupported("serialize_tuple_variant"))
}

#[inline]
fn serialize_struct(self, _name: &'static str, _len: usize) -> Result<Self::SerializeStruct> {
Err(SerializerError::Unsupported("serialize_struct"))
}

#[inline]
fn serialize_struct_variant(
self,
_name: &'static str,
_variant_index: u32,
_variant: &'static str,
_len: usize,
) -> Result<Self::SerializeStructVariant> {
Err(SerializerError::Unsupported("serialize_struct_variant"))
}

#[inline]
fn is_human_readable(&self) -> bool {
true
}
}

// === Public API ===

pub(crate) fn write_arg(writer: &mut impl Write, value: &impl Serialize) -> Result<(), String> {
Expand All @@ -279,6 +433,12 @@ pub(crate) fn write_arg(writer: &mut impl Write, value: &impl Serialize) -> Resu
.map_err(|err| err.to_string())
}

pub(crate) fn write_param(writer: &mut impl Write, value: &impl Serialize) -> Result<(), String> {
value
.serialize(ParamSerializer { writer })
.map_err(|err| err.to_string())
}

#[cfg(test)]
mod tests {
use super::*;
Expand Down

0 comments on commit 7e3aa8f

Please sign in to comment.