diff --git a/src/lib.rs b/src/lib.rs index 6df4a50..fda89e1 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -5,7 +5,8 @@ #[macro_use] extern crate static_assertions; -use self::{error::Result, http_client::HttpClient}; +use self::{error::Result, http_client::HttpClient, sql::ser}; +use ::serde::Serialize; use std::{collections::HashMap, fmt::Display, sync::Arc}; pub use self::{compression::Compression, row::Row}; @@ -160,6 +161,12 @@ impl Client { self } + pub fn with_param(self, name: &str, value: impl Serialize) -> Result { + let mut param = String::from(""); + ser::write_param(&mut param, &value)?; + Ok(self.with_option(format!("param_{name}"), param)) + } + /// Used to specify a header that will be passed to all queries. /// /// # Example diff --git a/src/query.rs b/src/query.rs index 926fdb5..f18a60d 100644 --- a/src/query.rs +++ b/src/query.rs @@ -1,5 +1,5 @@ use hyper::{header::CONTENT_LENGTH, Method, Request}; -use serde::Deserialize; +use serde::{Deserialize, Serialize}; use std::fmt::Display; use url::Url; @@ -10,7 +10,7 @@ use crate::{ request_body::RequestBody, response::Response, row::Row, - sql::{Bind, SqlBuilder}, + sql::{ser, Bind, SqlBuilder}, Client, }; @@ -195,6 +195,12 @@ impl Query { self.client.add_option(name, value); self } + + pub fn with_param(self, name: &str, value: impl Serialize) -> Result { + let mut param = String::from(""); + ser::write_param(&mut param, &value)?; + Ok(self.with_option(format!("param_{name}"), param)) + } } /// A cursor that emits rows. diff --git a/src/sql/mod.rs b/src/sql/mod.rs index 7417be7..66330f6 100644 --- a/src/sql/mod.rs +++ b/src/sql/mod.rs @@ -9,7 +9,7 @@ pub use bind::{Bind, Identifier}; mod bind; pub(crate) mod escape; -mod ser; +pub(crate) mod ser; #[derive(Debug, Clone)] pub(crate) enum SqlBuilder { diff --git a/src/sql/ser.rs b/src/sql/ser.rs index 00ea606..5715c12 100644 --- a/src/sql/ser.rs +++ b/src/sql/ser.rs @@ -8,23 +8,23 @@ use thiserror::Error; use super::escape; -// === SqlSerializerError === +// === SerializerError === #[derive(Debug, Error)] -enum SqlSerializerError { +enum SerializerError { #[error("{0} is unsupported")] Unsupported(&'static str), #[error("{0}")] Custom(String), } -impl ser::Error for SqlSerializerError { +impl ser::Error for SerializerError { fn custom(msg: T) -> Self { Self::Custom(msg.to_string()) } } -impl From for SqlSerializerError { +impl From for SerializerError { fn from(err: fmt::Error) -> Self { Self::Custom(err.to_string()) } @@ -32,8 +32,8 @@ impl From for SqlSerializerError { // === SqlSerializer === -type Result = std::result::Result; -type Impossible = ser::Impossible<(), SqlSerializerError>; +type Result = std::result::Result; +type Impossible = ser::Impossible<(), SerializerError>; struct SqlSerializer<'a, W> { writer: &'a mut W, @@ -43,7 +43,7 @@ macro_rules! unsupported { ($ser_method:ident($ty:ty) -> $ret:ty, $($other:tt)*) => { #[inline] fn $ser_method(self, _v: $ty) -> $ret { - Err(SqlSerializerError::Unsupported(stringify!($ser_method))) + Err(SerializerError::Unsupported(stringify!($ser_method))) } unsupported!($($other)*); }; @@ -53,7 +53,7 @@ macro_rules! unsupported { ($ser_method:ident, $($other:tt)*) => { #[inline] fn $ser_method(self) -> Result { - Err(SqlSerializerError::Unsupported(stringify!($ser_method))) + Err(SerializerError::Unsupported(stringify!($ser_method))) } unsupported!($($other)*); }; @@ -73,7 +73,7 @@ macro_rules! forward_to_display { } impl<'a, W: Write> Serializer for SqlSerializer<'a, W> { - type Error = SqlSerializerError; + type Error = SerializerError; type Ok = (); type SerializeMap = Impossible; type SerializeSeq = SqlListSerializer<'a, W>; @@ -177,12 +177,12 @@ impl<'a, W: Write> Serializer for SqlSerializer<'a, W> { _variant: &'static str, _value: &T, ) -> Result { - Err(SqlSerializerError::Unsupported("serialize_newtype_variant")) + Err(SerializerError::Unsupported("serialize_newtype_variant")) } #[inline] fn serialize_tuple_struct(self, _name: &'static str, _len: usize) -> Result { - Err(SqlSerializerError::Unsupported("serialize_tuple_struct")) + Err(SerializerError::Unsupported("serialize_tuple_struct")) } #[inline] @@ -193,12 +193,12 @@ impl<'a, W: Write> Serializer for SqlSerializer<'a, W> { _variant: &'static str, _len: usize, ) -> Result { - Err(SqlSerializerError::Unsupported("serialize_tuple_variant")) + Err(SerializerError::Unsupported("serialize_tuple_variant")) } #[inline] fn serialize_struct(self, _name: &'static str, _len: usize) -> Result { - Err(SqlSerializerError::Unsupported("serialize_struct")) + Err(SerializerError::Unsupported("serialize_struct")) } #[inline] @@ -209,7 +209,7 @@ impl<'a, W: Write> Serializer for SqlSerializer<'a, W> { _variant: &'static str, _len: usize, ) -> Result { - Err(SqlSerializerError::Unsupported("serialize_struct_variant")) + Err(SerializerError::Unsupported("serialize_struct_variant")) } #[inline] @@ -227,7 +227,7 @@ struct SqlListSerializer<'a, W> { } impl<'a, W: Write> SerializeSeq for SqlListSerializer<'a, W> { - type Error = SqlSerializerError; + type Error = SerializerError; type Ok = (); #[inline] @@ -254,7 +254,7 @@ impl<'a, W: Write> SerializeSeq for SqlListSerializer<'a, W> { } impl<'a, W: Write> SerializeTuple for SqlListSerializer<'a, W> { - type Error = SqlSerializerError; + type Error = SerializerError; type Ok = (); #[inline] @@ -271,6 +271,167 @@ impl<'a, W: Write> SerializeTuple for SqlListSerializer<'a, W> { } } +// === ParamSerializer === + +struct ParamSerializer<'a, W> { + writer: &'a mut W, +} + +impl<'a, W: Write> Serializer for ParamSerializer<'a, W> { + type Error = SerializerError; + type Ok = (); + type SerializeMap = Impossible; + type SerializeSeq = SqlListSerializer<'a, W>; + type SerializeStruct = Impossible; + type SerializeStructVariant = Impossible; + type SerializeTuple = SqlListSerializer<'a, W>; + type SerializeTupleStruct = Impossible; + type SerializeTupleVariant = Impossible; + + unsupported!( + serialize_map(Option) -> Result, + serialize_bytes(&[u8]), + serialize_unit, + serialize_unit_struct(&'static str), + ); + + forward_to_display!( + serialize_i8(i8), + serialize_i16(i16), + serialize_i32(i32), + serialize_i64(i64), + serialize_i128(i128), + serialize_u8(u8), + serialize_u16(u16), + serialize_u32(u32), + serialize_u64(u64), + serialize_u128(u128), + serialize_f32(f32), + serialize_f64(f64), + serialize_bool(bool), + ); + + #[inline] + fn serialize_char(self, value: char) -> Result { + let mut tmp = [0u8; 4]; + self.serialize_str(value.encode_utf8(&mut tmp)) + } + + #[inline] + fn serialize_str(self, value: &str) -> Result { + // ClickHouse expects strings in params to be unquoted until inside a nested type + // nested types go through serialize_seq which'll quote strings + let mut rest = value; + while let Some(nextidx) = rest.find('\\') { + let (before, after) = rest.split_at(nextidx + 1); + rest = after; + self.writer.write_str(before)?; + self.writer.write_char('\\')?; + } + self.writer.write_str(rest)?; + Ok(()) + } + + #[inline] + fn serialize_seq(self, _len: Option) -> Result> { + self.writer.write_char('[')?; + Ok(SqlListSerializer { + writer: self.writer, + has_items: false, + closing_char: ']', + }) + } + + #[inline] + fn serialize_tuple(self, _len: usize) -> Result> { + self.writer.write_char('(')?; + Ok(SqlListSerializer { + writer: self.writer, + has_items: false, + closing_char: ')', + }) + } + + #[inline] + fn serialize_some(self, _value: &T) -> Result { + _value.serialize(self) + } + + #[inline] + fn serialize_none(self) -> std::result::Result { + self.writer.write_str("NULL")?; + Ok(()) + } + + #[inline] + fn serialize_unit_variant( + self, + _name: &'static str, + _variant_index: u32, + variant: &'static str, + ) -> Result { + escape::string(variant, self.writer)?; + Ok(()) + } + + #[inline] + fn serialize_newtype_struct( + self, + _name: &'static str, + value: &T, + ) -> Result { + value.serialize(self) + } + + #[inline] + fn serialize_newtype_variant( + self, + _name: &'static str, + _variant_index: u32, + _variant: &'static str, + _value: &T, + ) -> Result { + Err(SerializerError::Unsupported("serialize_newtype_variant")) + } + + #[inline] + fn serialize_tuple_struct(self, _name: &'static str, _len: usize) -> Result { + Err(SerializerError::Unsupported("serialize_tuple_struct")) + } + + #[inline] + fn serialize_tuple_variant( + self, + _name: &'static str, + _variant_index: u32, + _variant: &'static str, + _len: usize, + ) -> Result { + Err(SerializerError::Unsupported("serialize_tuple_variant")) + } + + #[inline] + fn serialize_struct(self, _name: &'static str, _len: usize) -> Result { + Err(SerializerError::Unsupported("serialize_struct")) + } + + #[inline] + fn serialize_struct_variant( + self, + _name: &'static str, + _variant_index: u32, + _variant: &'static str, + _len: usize, + ) -> Result { + Err(SerializerError::Unsupported("serialize_struct_variant")) + } + + #[inline] + fn is_human_readable(&self) -> bool { + true + } +} + // === Public API === pub(crate) fn write_arg(writer: &mut impl Write, value: &impl Serialize) -> Result<(), String> { @@ -279,6 +440,12 @@ pub(crate) fn write_arg(writer: &mut impl Write, value: &impl Serialize) -> Resu .map_err(|err| err.to_string()) } +pub(crate) fn write_param(writer: &mut impl Write, value: &impl Serialize) -> Result<(), String> { + value + .serialize(ParamSerializer { writer }) + .map_err(|err| err.to_string()) +} + #[cfg(test)] mod tests { use super::*; diff --git a/tests/it/query.rs b/tests/it/query.rs index 80e0158..8d1e2cd 100644 --- a/tests/it/query.rs +++ b/tests/it/query.rs @@ -85,6 +85,49 @@ async fn fetch_one_and_optional() { assert_eq!(got_string, "bar"); } +#[tokio::test] +async fn server_side_param() { + let client = prepare_database!() + .with_param("val1", 42) + .expect("failed to bind 42"); + + let result = client + .query("SELECT plus({val1: Int32}, {val2: Int32}) AS result") + .with_param("val2", 144) + .expect("failed to bind 144") + .fetch_one::() + .await + .expect("failed to fetch u64"); + assert_eq!(result, 186); + + let result = client + .query("SELECT {val1: String} AS result") + .with_param("val1", "string") + .expect("failed to bind \"string\"") + .fetch_one::() + .await + .expect("failed to fetch string"); + assert_eq!(result, "string"); + + let result = client + .query("SELECT {val1: String} AS result") + .with_param("val1", "\x01\x02\x03\\ \"\'") + .expect("failed to bind weird string") + .fetch_one::() + .await + .expect("failed to fetch string"); + assert_eq!(result, "\x01\x02\x03\\ \"\'"); + + let result = client + .query("SELECT {val1: Array(String)} AS result") + .with_param("val1", vec!["a", "bc"]) + .expect("failed to bind string array") + .fetch_one::>() + .await + .expect("failed to fetch string"); + assert_eq!(result, &["a", "bc"]); +} + // See #19. #[tokio::test] async fn long_query() {