Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

with_param #159

Open
wants to merge 5 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 8 additions & 1 deletion src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,8 @@
#[macro_use]
extern crate static_assertions;

use self::{error::Result, http_client::HttpClient};
use self::{error::Result, http_client::HttpClient, sql::ser};
use ::serde::Serialize;
use std::{collections::HashMap, fmt::Display, sync::Arc};

pub use self::{compression::Compression, row::Row};
Expand Down Expand Up @@ -160,6 +161,12 @@ impl Client {
self
}

pub fn with_param(self, name: &str, value: impl Serialize) -> Result<Self, String> {
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What's the purpose of Client::with_param? Parameters are specific for queries, not the whole client

Copy link
Member Author

@serprex serprex Sep 29, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Generally, but if someone wants to share parameter between queries queries they can put it on the client

For example, in a multi tenant app, tenantid could be put as a parameter in the function creating client

let mut param = String::from("");
ser::write_param(&mut param, &value)?;
Ok(self.with_option(format!("param_{name}"), param))
}

/// Used to specify a header that will be passed to all queries.
///
/// # Example
Expand Down
10 changes: 8 additions & 2 deletions src/query.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
use hyper::{header::CONTENT_LENGTH, Method, Request};
use serde::Deserialize;
use serde::{Deserialize, Serialize};
use std::fmt::Display;
use url::Url;

Expand All @@ -10,7 +10,7 @@ use crate::{
request_body::RequestBody,
response::Response,
row::Row,
sql::{Bind, SqlBuilder},
sql::{ser, Bind, SqlBuilder},
Client,
};

Expand Down Expand Up @@ -195,6 +195,12 @@ impl Query {
self.client.add_option(name, value);
self
}

pub fn with_param(self, name: &str, value: impl Serialize) -> Result<Self, String> {
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm still trying to figure out the naming here. This one is consistent with with_option, but that name emphasizes the connection with Client::with_option.

For instance, bind mimics sqlx bind.
A simple param() would be more concise. Yep, it's inconsistent for now, but once Client::builder() is introduced, we can have fn option() instead of with_option.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Result<> can be very inconvenient. It can be used in a concise way only inside -> Result functions (q.with_param("a", 42)?.with_param("b", 43)?), but sometimes you want explicitly handle errors with matching inside the function that doesn't return Result, so ? (via the Try trait) cannot be used.

So, I suggest delayed errors in the same way as in the bind() method.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This method must be well documented. Also, we should specify the difference between param() and bind() (probably in a dedicated section on Query and provide links to that section in docs of both methods).

loyd marked this conversation as resolved.
Show resolved Hide resolved
let mut param = String::from("");
ser::write_param(&mut param, &value)?;
Ok(self.with_option(format!("param_{name}"), param))
}
}

/// A cursor that emits rows.
Expand Down
2 changes: 1 addition & 1 deletion src/sql/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ pub use bind::{Bind, Identifier};

mod bind;
pub(crate) mod escape;
mod ser;
pub(crate) mod ser;

#[derive(Debug, Clone)]
pub(crate) enum SqlBuilder {
Expand Down
199 changes: 183 additions & 16 deletions src/sql/ser.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,32 +8,32 @@ use thiserror::Error;

use super::escape;

// === SqlSerializerError ===
// === SerializerError ===

#[derive(Debug, Error)]
enum SqlSerializerError {
enum SerializerError {
#[error("{0} is unsupported")]
Unsupported(&'static str),
#[error("{0}")]
Custom(String),
}

impl ser::Error for SqlSerializerError {
impl ser::Error for SerializerError {
fn custom<T: fmt::Display>(msg: T) -> Self {
Self::Custom(msg.to_string())
}
}

impl From<fmt::Error> for SqlSerializerError {
impl From<fmt::Error> for SerializerError {
fn from(err: fmt::Error) -> Self {
Self::Custom(err.to_string())
}
}

// === SqlSerializer ===

type Result<T = (), E = SqlSerializerError> = std::result::Result<T, E>;
type Impossible = ser::Impossible<(), SqlSerializerError>;
type Result<T = (), E = SerializerError> = std::result::Result<T, E>;
type Impossible = ser::Impossible<(), SerializerError>;

struct SqlSerializer<'a, W> {
writer: &'a mut W,
Expand All @@ -43,7 +43,7 @@ macro_rules! unsupported {
($ser_method:ident($ty:ty) -> $ret:ty, $($other:tt)*) => {
#[inline]
fn $ser_method(self, _v: $ty) -> $ret {
Err(SqlSerializerError::Unsupported(stringify!($ser_method)))
Err(SerializerError::Unsupported(stringify!($ser_method)))
}
unsupported!($($other)*);
};
Expand All @@ -53,7 +53,7 @@ macro_rules! unsupported {
($ser_method:ident, $($other:tt)*) => {
#[inline]
fn $ser_method(self) -> Result {
Err(SqlSerializerError::Unsupported(stringify!($ser_method)))
Err(SerializerError::Unsupported(stringify!($ser_method)))
}
unsupported!($($other)*);
};
Expand All @@ -73,7 +73,7 @@ macro_rules! forward_to_display {
}

impl<'a, W: Write> Serializer for SqlSerializer<'a, W> {
type Error = SqlSerializerError;
type Error = SerializerError;
type Ok = ();
type SerializeMap = Impossible;
type SerializeSeq = SqlListSerializer<'a, W>;
Expand Down Expand Up @@ -177,12 +177,12 @@ impl<'a, W: Write> Serializer for SqlSerializer<'a, W> {
_variant: &'static str,
_value: &T,
) -> Result {
Err(SqlSerializerError::Unsupported("serialize_newtype_variant"))
Err(SerializerError::Unsupported("serialize_newtype_variant"))
}

#[inline]
fn serialize_tuple_struct(self, _name: &'static str, _len: usize) -> Result<Impossible> {
Err(SqlSerializerError::Unsupported("serialize_tuple_struct"))
Err(SerializerError::Unsupported("serialize_tuple_struct"))
}

#[inline]
Expand All @@ -193,12 +193,12 @@ impl<'a, W: Write> Serializer for SqlSerializer<'a, W> {
_variant: &'static str,
_len: usize,
) -> Result<Impossible> {
Err(SqlSerializerError::Unsupported("serialize_tuple_variant"))
Err(SerializerError::Unsupported("serialize_tuple_variant"))
}

#[inline]
fn serialize_struct(self, _name: &'static str, _len: usize) -> Result<Self::SerializeStruct> {
Err(SqlSerializerError::Unsupported("serialize_struct"))
Err(SerializerError::Unsupported("serialize_struct"))
}

#[inline]
Expand All @@ -209,7 +209,7 @@ impl<'a, W: Write> Serializer for SqlSerializer<'a, W> {
_variant: &'static str,
_len: usize,
) -> Result<Self::SerializeStructVariant> {
Err(SqlSerializerError::Unsupported("serialize_struct_variant"))
Err(SerializerError::Unsupported("serialize_struct_variant"))
}

#[inline]
Expand All @@ -227,7 +227,7 @@ struct SqlListSerializer<'a, W> {
}

impl<'a, W: Write> SerializeSeq for SqlListSerializer<'a, W> {
type Error = SqlSerializerError;
type Error = SerializerError;
type Ok = ();

#[inline]
Expand All @@ -254,7 +254,7 @@ impl<'a, W: Write> SerializeSeq for SqlListSerializer<'a, W> {
}

impl<'a, W: Write> SerializeTuple for SqlListSerializer<'a, W> {
type Error = SqlSerializerError;
type Error = SerializerError;
type Ok = ();

#[inline]
Expand All @@ -271,6 +271,167 @@ impl<'a, W: Write> SerializeTuple for SqlListSerializer<'a, W> {
}
}

// === ParamSerializer ===

struct ParamSerializer<'a, W> {
writer: &'a mut W,
}

impl<'a, W: Write> Serializer for ParamSerializer<'a, W> {
type Error = SerializerError;
type Ok = ();
type SerializeMap = Impossible;
type SerializeSeq = SqlListSerializer<'a, W>;
type SerializeStruct = Impossible;
type SerializeStructVariant = Impossible;
type SerializeTuple = SqlListSerializer<'a, W>;
type SerializeTupleStruct = Impossible;
type SerializeTupleVariant = Impossible;

unsupported!(
serialize_map(Option<usize>) -> Result<Impossible>,
serialize_bytes(&[u8]),
serialize_unit,
serialize_unit_struct(&'static str),
);

forward_to_display!(
serialize_i8(i8),
serialize_i16(i16),
serialize_i32(i32),
serialize_i64(i64),
serialize_i128(i128),
serialize_u8(u8),
serialize_u16(u16),
serialize_u32(u32),
serialize_u64(u64),
serialize_u128(u128),
serialize_f32(f32),
serialize_f64(f64),
serialize_bool(bool),
);

#[inline]
fn serialize_char(self, value: char) -> Result {
let mut tmp = [0u8; 4];
self.serialize_str(value.encode_utf8(&mut tmp))
}

#[inline]
fn serialize_str(self, value: &str) -> Result {
// ClickHouse expects strings in params to be unquoted until inside a nested type
// nested types go through serialize_seq which'll quote strings
let mut rest = value;
while let Some(nextidx) = rest.find('\\') {
let (before, after) = rest.split_at(nextidx + 1);
rest = after;
self.writer.write_str(before)?;
self.writer.write_char('\\')?;
}
self.writer.write_str(rest)?;
Ok(())
}

#[inline]
fn serialize_seq(self, _len: Option<usize>) -> Result<SqlListSerializer<'a, W>> {
self.writer.write_char('[')?;
Ok(SqlListSerializer {
writer: self.writer,
has_items: false,
closing_char: ']',
})
}

#[inline]
fn serialize_tuple(self, _len: usize) -> Result<SqlListSerializer<'a, W>> {
self.writer.write_char('(')?;
Ok(SqlListSerializer {
writer: self.writer,
has_items: false,
closing_char: ')',
})
}

#[inline]
fn serialize_some<T: Serialize + ?Sized>(self, _value: &T) -> Result {
_value.serialize(self)
}

#[inline]
fn serialize_none(self) -> std::result::Result<Self::Ok, Self::Error> {
self.writer.write_str("NULL")?;
Ok(())
}

#[inline]
fn serialize_unit_variant(
self,
_name: &'static str,
_variant_index: u32,
variant: &'static str,
) -> Result {
escape::string(variant, self.writer)?;
Ok(())
}

#[inline]
fn serialize_newtype_struct<T: Serialize + ?Sized>(
self,
_name: &'static str,
value: &T,
) -> Result {
value.serialize(self)
}

#[inline]
fn serialize_newtype_variant<T: Serialize + ?Sized>(
self,
_name: &'static str,
_variant_index: u32,
_variant: &'static str,
_value: &T,
) -> Result {
Err(SerializerError::Unsupported("serialize_newtype_variant"))
}

#[inline]
fn serialize_tuple_struct(self, _name: &'static str, _len: usize) -> Result<Impossible> {
Err(SerializerError::Unsupported("serialize_tuple_struct"))
}

#[inline]
fn serialize_tuple_variant(
self,
_name: &'static str,
_variant_index: u32,
_variant: &'static str,
_len: usize,
) -> Result<Impossible> {
Err(SerializerError::Unsupported("serialize_tuple_variant"))
}

#[inline]
fn serialize_struct(self, _name: &'static str, _len: usize) -> Result<Self::SerializeStruct> {
Err(SerializerError::Unsupported("serialize_struct"))
}

#[inline]
fn serialize_struct_variant(
self,
_name: &'static str,
_variant_index: u32,
_variant: &'static str,
_len: usize,
) -> Result<Self::SerializeStructVariant> {
Err(SerializerError::Unsupported("serialize_struct_variant"))
}

#[inline]
fn is_human_readable(&self) -> bool {
true
}
}

// === Public API ===

pub(crate) fn write_arg(writer: &mut impl Write, value: &impl Serialize) -> Result<(), String> {
Expand All @@ -279,6 +440,12 @@ pub(crate) fn write_arg(writer: &mut impl Write, value: &impl Serialize) -> Resu
.map_err(|err| err.to_string())
}

pub(crate) fn write_param(writer: &mut impl Write, value: &impl Serialize) -> Result<(), String> {
value
.serialize(ParamSerializer { writer })
.map_err(|err| err.to_string())
}

#[cfg(test)]
mod tests {
use super::*;
Expand Down
Loading